# sglang/python/sglang/srt/layers/attention/trtllm_mha_backend.py
from __future__ import annotations

"""
Attention backend for the TRTLLM MHA kernels from flashinfer.
The kernels support sm100 only, with sliding-window and attention-sink features.
"""

from dataclasses import dataclass
from typing import TYPE_CHECKING, Optional

import torch

from sglang.srt.layers.attention.flashinfer_backend import FlashInferAttnBackend
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
from sglang.srt.utils import is_flashinfer_available

if is_flashinfer_available():
    import flashinfer

if TYPE_CHECKING:
    from sglang.srt.layers.radix_attention import RadixAttention
    from sglang.srt.model_executor.model_runner import ModelRunner
    from sglang.srt.speculative.spec_info import SpecInfo

# Constants
DEFAULT_WORKSPACE_SIZE_MB = 512  # Memory workspace size in MB, todo(Yingyi): read from config

# Reuse this workspace buffer across all TRTLLM MHA wrappers
global_zero_init_workspace_buffer = None
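# Illustrative note (a sketch of the intent, not an upstream guarantee):
# because the buffer above is module-global, every backend instance in the
# same process shares one 512 MiB scratch allocation instead of allocating
# its own, e.g.:
#
#   backend_a = TRTLLMHAAttnBackend(model_runner)  # allocates the buffer
#   backend_b = TRTLLMHAAttnBackend(model_runner)  # reuses the same buffer
#   assert backend_a.workspace_buffer.data_ptr() == backend_b.workspace_buffer.data_ptr()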

@dataclass
class TRTLLMMHAMetadata:
    # Sequence lengths for the forward batch
    cache_seqlens_int32: Optional[torch.Tensor] = None
    # Maximum sequence length for query
    max_seq_len_q: int = 1
    # Maximum sequence length for key
    max_seq_len_k: int = 0
    # Cumulative sequence lengths for query
    cu_seqlens_q: Optional[torch.Tensor] = None
    # Cumulative sequence lengths for key
    cu_seqlens_k: Optional[torch.Tensor] = None
    # Page table: the indices of the KV cache pages/blocks
    page_table: Optional[torch.Tensor] = None
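
# Worked example (hypothetical values): for a prefill batch of two requests
# with 3 and 5 query tokens, no cached prefix, and page_size == 1, the
# metadata built by init_forward_metadata below would be
#   cache_seqlens_int32 = [3, 5]
#   max_seq_len_q = 5, max_seq_len_k = 5
#   cu_seqlens_q = cu_seqlens_k = [0, 3, 8]
#   page_table[i][j] = index of the KV page holding token j of request i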

class TRTLLMHAAttnBackend(FlashInferAttnBackend):
    """TRTLLM MHA attention kernel from flashinfer."""

    def __init__(
        self,
        model_runner: ModelRunner,
        skip_prefill: bool = False,
        kv_indptr_buf: Optional[torch.Tensor] = None,
        q_indptr_decode_buf: Optional[torch.Tensor] = None,
    ):
        super().__init__(model_runner, skip_prefill, kv_indptr_buf, q_indptr_decode_buf)
        config = model_runner.model_config

        # MHA-specific dimensions
        self.max_context_len = config.context_len
        self.hidden_size = config.hidden_size

        # Runtime parameters
        self.data_type = model_runner.kv_cache_dtype
        self.q_data_type = model_runner.dtype
        self.page_size = model_runner.page_size
        self.req_to_token = model_runner.req_to_token_pool.req_to_token
        self.device = model_runner.device

        # Workspace allocation
        self.workspace_size = DEFAULT_WORKSPACE_SIZE_MB * 1024 * 1024

        # Allocate buffers
        global global_zero_init_workspace_buffer
        if global_zero_init_workspace_buffer is None:
            global_zero_init_workspace_buffer = torch.zeros(
                self.workspace_size,
                dtype=torch.uint8,
                device=model_runner.device,
            )
        self.workspace_buffer = global_zero_init_workspace_buffer

        # CUDA graph state
        self.decode_cuda_graph_metadata = {}

        # Forward metadata
        self.forward_metadata: Optional[TRTLLMMHAMetadata] = None
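
    # Usage sketch (assumes a configured ModelRunner from this repo; not a
    # public API): the model runner constructs the backend once, then per
    # batch calls init_forward_metadata(forward_batch) before forward_decode
    # or forward_extend, e.g.:
    #
    #   backend = TRTLLMHAAttnBackend(model_runner)
    #   backend.init_forward_metadata(forward_batch)
    #   out = backend.forward_decode(q, k, v, layer, forward_batch)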

    def init_cuda_graph_state(
        self,
        max_bs: int,
        max_num_tokens: int,
        kv_indices_buf: Optional[torch.Tensor] = None,
    ):
        """Initialize CUDA graph state for TRTLLM MHA."""
        self.decode_cuda_graph_metadata = {
            "cache_seqlens": torch.zeros(max_bs, dtype=torch.int32, device=self.device),
            "page_table": torch.zeros(
                max_bs,
                (self.max_context_len + self.page_size - 1) // self.page_size,
                dtype=torch.int32,
                device=self.device,
            ),
            "strided_indices": torch.arange(
                0, self.max_context_len, self.page_size, device=self.device
            ),
        }
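
    # Note: "strided_indices" samples one token slot per page. req_to_token is
    # token-granular, so reading columns 0, page_size, 2 * page_size, ... and
    # integer-dividing by page_size turns a row of token slots into a row of
    # page ids (see the replay path below). E.g., with page_size = 16 the
    # sampled columns are [0, 16, 32, ...].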

    def init_forward_metadata_capture_cuda_graph(
        self,
        bs: int,
        num_tokens: int,
        req_pool_indices: torch.Tensor,
        seq_lens: torch.Tensor,
        encoder_lens: Optional[torch.Tensor],
        forward_mode: ForwardMode,
        spec_info: Optional[SpecInfo],
    ):
        """Initialize metadata for CUDA graph capture."""
        metadata = TRTLLMMHAMetadata()

        # Get sequence information
        metadata.cache_seqlens_int32 = seq_lens[:bs].to(torch.int32)

        # Precompute maximum sequence length
        metadata.max_seq_len_k = seq_lens[:bs].max().item()

        # Precompute page table
        metadata.page_table = self.decode_cuda_graph_metadata["page_table"][:bs, :]

        self.decode_cuda_graph_metadata[bs] = metadata
        self.forward_metadata = metadata
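
    # The capture path above allocates max-sized buffers whose pointers get
    # baked into the CUDA graph; the replay path below only copy_()-s fresh
    # values into those same buffers, so the graph can be re-run without any
    # reallocation.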

    def init_forward_metadata_replay_cuda_graph(
        self,
        bs: int,
        req_pool_indices: torch.Tensor,
        seq_lens: torch.Tensor,
        seq_lens_sum: int,
        encoder_lens: Optional[torch.Tensor],
        forward_mode: ForwardMode,
        spec_info: Optional[SpecInfo],
        seq_lens_cpu: Optional[torch.Tensor],
    ):
        """Replay CUDA graph with new inputs."""
        seq_lens = seq_lens[:bs]
        seq_lens_cpu = seq_lens_cpu[:bs]
        req_pool_indices = req_pool_indices[:bs]

        # Normal decode: reuse the metadata captured for this batch size
        metadata = self.decode_cuda_graph_metadata[bs]
        max_len = seq_lens_cpu.max().item()
        max_seq_pages = (max_len + self.page_size - 1) // self.page_size
        metadata.max_seq_len_k = max_len

        metadata.cache_seqlens_int32.copy_(seq_lens)
        page_indices = self.req_to_token[
            req_pool_indices[:, None],
            self.decode_cuda_graph_metadata["strided_indices"][:max_seq_pages][None, :],
        ]
        metadata.page_table[:, :max_seq_pages].copy_(page_indices // self.page_size)
        self.forward_metadata = metadata
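
    # Worked example for the page-table update above (hypothetical numbers):
    # with page_size = 16 and a request whose tokens occupy slots 128..191 of
    # the token pool, strided_indices picks slots [128, 144, 160, 176] and
    # integer-dividing by page_size yields page ids [8, 9, 10, 11].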

    def get_cuda_graph_seq_len_fill_value(self) -> int:
        """Get the fill value for sequence lengths in CUDA graph."""
        return 1

    def init_forward_metadata(self, forward_batch: ForwardBatch):
        """Initialize the metadata for a forward pass."""
        metadata = TRTLLMMHAMetadata()
        seqlens_in_batch = forward_batch.seq_lens

        if forward_batch.forward_mode.is_decode_or_idle():
            # Normal decode
            metadata.cache_seqlens_int32 = seqlens_in_batch.to(torch.int32)
            metadata.max_seq_len_k = forward_batch.seq_lens_cpu.max().item()
            metadata.page_table = forward_batch.req_to_token_pool.req_to_token[
                forward_batch.req_pool_indices, : metadata.max_seq_len_k
            ]
        else:
            metadata.cache_seqlens_int32 = seqlens_in_batch.to(torch.int32)
            metadata.max_seq_len_k = forward_batch.seq_lens_cpu.max().item()
            metadata.cu_seqlens_k = torch.nn.functional.pad(
                torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)
            )
            metadata.page_table = forward_batch.req_to_token_pool.req_to_token[
                forward_batch.req_pool_indices, : metadata.max_seq_len_k
            ]

            if any(forward_batch.extend_prefix_lens_cpu):
                extend_seq_lens = forward_batch.extend_seq_lens
                metadata.max_seq_len_q = max(forward_batch.extend_seq_lens_cpu)
                metadata.cu_seqlens_q = torch.nn.functional.pad(
                    torch.cumsum(extend_seq_lens, dim=0, dtype=torch.int32), (1, 0)
                )
            else:
                metadata.max_seq_len_q = metadata.max_seq_len_k
                metadata.cu_seqlens_q = metadata.cu_seqlens_k

        # Convert the page table to a strided format
        if self.page_size > 1:
            self.strided_indices = torch.arange(
                0, metadata.page_table.shape[1], self.page_size, device=self.device
            )
            metadata.page_table = (
                metadata.page_table[:, self.strided_indices] // self.page_size
            )

        self.forward_metadata = metadata
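
    # Illustrative example of the strided conversion above (hypothetical
    # numbers): with page_size = 4 and a 10-token request stored in token
    # slots 20..29, page_table starts as the token slots themselves; keeping
    # columns [0, 4, 8] gives [20, 24, 28], and // 4 yields page ids
    # [5, 6, 7]. Note that the decode branch never needs cu_seqlens_q/k,
    # so they stay None there; only the extend path fills them in.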

    def forward_decode(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        layer: RadixAttention,
        forward_batch: ForwardBatch,
        save_kv_cache: bool = True,
        **kwargs,
    ) -> torch.Tensor:
        """Run forward for decode using the TRTLLM MHA kernel."""
        cache_loc = forward_batch.out_cache_loc
        if save_kv_cache and k is not None:
            forward_batch.token_to_kv_pool.set_kv_buffer(
                layer, cache_loc, k, v, layer.k_scale, layer.v_scale
            )

        q = q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim)
        k_cache, v_cache = forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id)

        # Shape conversion:
        # [num_pages, page_size, num_kv_heads, head_dim] -> [num_pages, num_kv_heads, page_size, head_dim]
        k_cache = k_cache.view(
            -1, self.page_size, layer.tp_k_head_num, layer.head_dim
        ).permute(0, 2, 1, 3)
        v_cache = v_cache.view(
            -1, self.page_size, layer.tp_v_head_num, layer.head_dim
        ).permute(0, 2, 1, 3)
        kv_cache = (k_cache, v_cache)

        # TODO: add support for quantization
        q_scale = 1.0
        k_scale = (
            layer.k_scale_float
            if getattr(layer, "k_scale_float", None) is not None
            else 1.0
        )
        bmm1_scale = q_scale * k_scale * layer.scaling
        bmm2_scale = 1.0

        # sink: additional value per head in the denominator of the softmax.
        attention_sink = kwargs.get("sinks", None)

        # Call the TRT-LLM kernel.
        # raw_out: like q, [bs, acc_q_len, num_q_heads, head_dim], but with the output dtype
        o = flashinfer.decode.trtllm_batch_decode_with_kv_cache(
            query=q,
            kv_cache=kv_cache,
            workspace_buffer=self.workspace_buffer,
            block_tables=self.forward_metadata.page_table,
            seq_lens=self.forward_metadata.cache_seqlens_int32,
            max_seq_len=self.max_context_len,
            bmm1_scale=bmm1_scale,
            bmm2_scale=bmm2_scale,
            window_left=layer.sliding_window_size,
            # TODO: add attention_sink operation or nvfp4 scale factor if needed
            sinks=attention_sink,
        )

        return o.view(-1, layer.tp_q_head_num * layer.head_dim)
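
    # Note on scales (shared by the decode and extend paths): the TRT-LLM
    # kernels fold the softmax scale into the first batched matmul, so
    # bmm1_scale = q_scale * k_scale * layer.scaling scales q @ k^T before
    # the softmax, while bmm2_scale rescales the probs @ v output and stays
    # 1.0 until quantization support lands.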

    def forward_extend(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        layer: RadixAttention,
        forward_batch: ForwardBatch,
        save_kv_cache: bool = True,
        **kwargs,
    ) -> torch.Tensor:
        """Run forward for extend (prefill) using the TRTLLM MHA kernel."""
        cache_loc = forward_batch.out_cache_loc
        if save_kv_cache and k is not None:
            forward_batch.token_to_kv_pool.set_kv_buffer(
                layer, cache_loc, k, v, layer.k_scale, layer.v_scale
            )

        q = q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim)

        # Shape conversion:
        # [num_pages, page_size, num_kv_heads, head_dim] -> [num_pages, num_kv_heads, page_size, head_dim]
        k_cache, v_cache = forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id)
        k_cache = k_cache.view(
            -1, self.page_size, layer.tp_k_head_num, layer.head_dim
        ).permute(0, 2, 1, 3)
        v_cache = v_cache.view(
            -1, self.page_size, layer.tp_v_head_num, layer.head_dim
        ).permute(0, 2, 1, 3)
        kv_cache = (k_cache, v_cache)

        # sink: additional value per head in the denominator of the softmax.
        attention_sink = kwargs.get("sinks", None)

        # TODO: add support for quantization
        q_scale = 1.0
        k_scale = (
            layer.k_scale_float
            if getattr(layer, "k_scale_float", None) is not None
            else 1.0
        )
        bmm1_scale = q_scale * k_scale * layer.scaling
        bmm2_scale = 1.0

        o = flashinfer.prefill.trtllm_batch_context_with_kv_cache(
            query=q,
            kv_cache=kv_cache,
            workspace_buffer=self.workspace_buffer,
            block_tables=self.forward_metadata.page_table,
            seq_lens=self.forward_metadata.cache_seqlens_int32,
            max_q_len=self.forward_metadata.max_seq_len_q,
            max_kv_len=self.max_context_len,
            bmm1_scale=bmm1_scale,
            bmm2_scale=bmm2_scale,
            batch_size=forward_batch.batch_size,
            cum_seq_lens_q=self.forward_metadata.cu_seqlens_q,
            cum_seq_lens_kv=self.forward_metadata.cu_seqlens_k,
            window_left=layer.sliding_window_size,
            # TODO: add attention_sink operation or nvfp4 scale factor if needed
            sinks=attention_sink,
        )

        return o.view(-1, layer.tp_q_head_num * layer.head_dim)
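

# Minimal sanity sketch for the KV-cache layout conversion used in both
# forward paths above (illustrative shapes only; independent of the backend):
# the pool stores tokens as [num_tokens, num_kv_heads, head_dim], which is
# viewed as [num_pages, page_size, num_kv_heads, head_dim] and permuted to
# the [num_pages, num_kv_heads, page_size, head_dim] layout the TRT-LLM
# kernels consume.
if __name__ == "__main__":
    num_pages, page_size, num_kv_heads, head_dim = 3, 16, 8, 128
    kv = torch.empty(num_pages * page_size, num_kv_heads, head_dim)
    hnd = kv.view(num_pages, page_size, num_kv_heads, head_dim).permute(0, 2, 1, 3)
    assert hnd.shape == (num_pages, num_kv_heads, page_size, head_dim)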