feat: add trtllm-gen mha from direct call (#8782)
Co-authored-by: Baizhou Zhang <sobereddiezhang@gmail.com>
This commit is contained in:
321
python/sglang/srt/layers/attention/trtllm_mha_backend.py
Normal file
321
python/sglang/srt/layers/attention/trtllm_mha_backend.py
Normal file
@@ -0,0 +1,321 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from python.sglang.srt.layers.radix_attention import RadixAttention
|
||||
|
||||
"""
|
||||
Support attention backend for TRTLLM MLA kernels from flashinfer.
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
|
||||
import torch
|
||||
|
||||
from sglang.srt.layers.attention.flashinfer_backend import FlashInferAttnBackend
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
|
||||
from sglang.srt.utils import is_flashinfer_available
|
||||
|
||||
if is_flashinfer_available():
|
||||
import flashinfer
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from sglang.srt.layers.radix_attention import RadixAttention
|
||||
from sglang.srt.model_executor.model_runner import ModelRunner
|
||||
from sglang.srt.speculative.spec_info import SpecInfo
|
||||
|
||||
# Constants
|
||||
DEFAULT_WORKSPACE_SIZE_MB = 128 # Memory workspace size in MB
|
||||
|
||||
|
||||
@dataclass
|
||||
class TRTLLMMHAMetadata:
|
||||
# Sequence lengths for the forward batch
|
||||
cache_seqlens_int32: torch.Tensor = None
|
||||
# Maximum sequence length for query
|
||||
max_seq_len_q: int = 1
|
||||
# Maximum sequence length for key
|
||||
max_seq_len_k: int = 0
|
||||
# Cumulative sequence lengths for `query
|
||||
cu_seqlens_q: torch.Tensor = None
|
||||
# Cumulative sequence lengths for key
|
||||
cu_seqlens_k: torch.Tensor = None
|
||||
# Page table, the index of KV Cache Tables/Blocks
|
||||
page_table: torch.Tensor = None
|
||||
|
||||
|
||||
class TRTLLMHAAttnBackend(FlashInferAttnBackend):
|
||||
"""TRTLLM MHA attention kernel from flashinfer."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model_runner: ModelRunner,
|
||||
skip_prefill: bool = False,
|
||||
kv_indptr_buf: Optional[torch.Tensor] = None,
|
||||
q_indptr_decode_buf: Optional[torch.Tensor] = None,
|
||||
):
|
||||
super().__init__(model_runner, skip_prefill, kv_indptr_buf, q_indptr_decode_buf)
|
||||
|
||||
config = model_runner.model_config
|
||||
|
||||
# MHA-specific dimensions
|
||||
self.max_context_len = model_runner.model_config.context_len
|
||||
self.sliding_window_size = (
|
||||
model_runner.sliding_window_size
|
||||
if model_runner.sliding_window_size is not None
|
||||
else -1 # -1 indicates full attention
|
||||
)
|
||||
self.hidden_size = config.hidden_size
|
||||
|
||||
# Runtime parameters
|
||||
self.data_type = model_runner.kv_cache_dtype
|
||||
self.q_data_type = model_runner.dtype
|
||||
self.page_size = model_runner.page_size
|
||||
self.req_to_token = model_runner.req_to_token_pool.req_to_token
|
||||
self.device = model_runner.device
|
||||
|
||||
# Workspace allocation
|
||||
self.workspace_size = DEFAULT_WORKSPACE_SIZE_MB * 1024 * 1024
|
||||
self.workspace_buffer = torch.empty(
|
||||
self.workspace_size, dtype=torch.int8, device=self.device
|
||||
)
|
||||
|
||||
# CUDA graph state
|
||||
self.decode_cuda_graph_metadata = {}
|
||||
|
||||
# Forward metadata
|
||||
self.forward_metadata: Optional[TRTLLMMHAMetadata] = None
|
||||
|
||||
def init_cuda_graph_state(
|
||||
self,
|
||||
max_bs: int,
|
||||
max_num_tokens: int,
|
||||
kv_indices_buf: Optional[torch.Tensor] = None,
|
||||
):
|
||||
"""Initialize CUDA graph state for TRTLLM MHA."""
|
||||
self.decode_cuda_graph_metadata = {
|
||||
"cache_seqlens": torch.zeros(max_bs, dtype=torch.int32, device=self.device),
|
||||
"page_table": torch.zeros(
|
||||
max_bs,
|
||||
(self.max_context_len + self.page_size - 1) // self.page_size,
|
||||
dtype=torch.int32,
|
||||
device=self.device,
|
||||
),
|
||||
"strided_indices": torch.arange(
|
||||
0, self.max_context_len, self.page_size, device=self.device
|
||||
),
|
||||
}
|
||||
|
||||
def init_forward_metadata_capture_cuda_graph(
|
||||
self,
|
||||
bs: int,
|
||||
num_tokens: int,
|
||||
req_pool_indices: torch.Tensor,
|
||||
seq_lens: torch.Tensor,
|
||||
encoder_lens: Optional[torch.Tensor],
|
||||
forward_mode: ForwardMode,
|
||||
spec_info: Optional[SpecInfo],
|
||||
):
|
||||
"""Initialize metadata for CUDA graph capture."""
|
||||
metadata = TRTLLMMHAMetadata()
|
||||
|
||||
# Get sequence information
|
||||
metadata.cache_seqlens_int32 = seq_lens.to(torch.int32)
|
||||
|
||||
# Precompute maximum sequence length
|
||||
metadata.max_seq_len_k = seq_lens.max().item()
|
||||
|
||||
# Precompute page table
|
||||
metadata.page_table = self.decode_cuda_graph_metadata["page_table"][:bs, :]
|
||||
self.decode_cuda_graph_metadata[bs] = metadata
|
||||
self.forward_metadata = metadata
|
||||
|
||||
def init_forward_metadata_replay_cuda_graph(
|
||||
self,
|
||||
bs: int,
|
||||
req_pool_indices: torch.Tensor,
|
||||
seq_lens: torch.Tensor,
|
||||
seq_lens_sum: int,
|
||||
encoder_lens: Optional[torch.Tensor],
|
||||
forward_mode: ForwardMode,
|
||||
spec_info: Optional[SpecInfo],
|
||||
seq_lens_cpu: Optional[torch.Tensor],
|
||||
):
|
||||
"""Replay CUDA graph with new inputs."""
|
||||
seq_lens = seq_lens[:bs]
|
||||
seq_lens_cpu = seq_lens_cpu[:bs]
|
||||
req_pool_indices = req_pool_indices[:bs]
|
||||
device = seq_lens.device
|
||||
metadata = None
|
||||
|
||||
# Normal Decode
|
||||
metadata = self.decode_cuda_graph_metadata[bs]
|
||||
max_len = seq_lens_cpu.max().item()
|
||||
max_seq_pages = (max_len + self.page_size - 1) // self.page_size
|
||||
metadata.max_seq_len_k = max_len
|
||||
|
||||
metadata.cache_seqlens_int32.copy_(seq_lens)
|
||||
page_indices = self.req_to_token[
|
||||
req_pool_indices[:, None],
|
||||
self.decode_cuda_graph_metadata["strided_indices"][:max_seq_pages][None, :],
|
||||
]
|
||||
metadata.page_table[:, :max_seq_pages].copy_(page_indices // self.page_size)
|
||||
self.forward_metadata = metadata
|
||||
|
||||
def get_cuda_graph_seq_len_fill_value(self) -> int:
|
||||
"""Get the fill value for sequence lengths in CUDA graph."""
|
||||
return 1
|
||||
|
||||
def init_forward_metadata(self, forward_batch: ForwardBatch):
|
||||
"""Initialize the metadata for a forward pass."""
|
||||
|
||||
metadata = TRTLLMMHAMetadata()
|
||||
seqlens_in_batch = forward_batch.seq_lens
|
||||
batch_size = forward_batch.batch_size
|
||||
device = seqlens_in_batch.device
|
||||
|
||||
if forward_batch.forward_mode.is_decode_or_idle():
|
||||
# Normal Decode
|
||||
metadata.cache_seqlens_int32 = seqlens_in_batch.to(torch.int32)
|
||||
metadata.max_seq_len_k = forward_batch.seq_lens_cpu.max().item()
|
||||
metadata.page_table = forward_batch.req_to_token_pool.req_to_token[
|
||||
forward_batch.req_pool_indices, : metadata.max_seq_len_k
|
||||
]
|
||||
else:
|
||||
metadata.cache_seqlens_int32 = seqlens_in_batch.to(torch.int32)
|
||||
metadata.max_seq_len_k = forward_batch.seq_lens_cpu.max().item()
|
||||
metadata.cu_seqlens_k = torch.nn.functional.pad(
|
||||
torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)
|
||||
)
|
||||
metadata.page_table = forward_batch.req_to_token_pool.req_to_token[
|
||||
forward_batch.req_pool_indices, : metadata.max_seq_len_k
|
||||
]
|
||||
|
||||
if any(forward_batch.extend_prefix_lens_cpu):
|
||||
extend_seq_lens = forward_batch.extend_seq_lens
|
||||
metadata.max_seq_len_q = max(forward_batch.extend_seq_lens_cpu)
|
||||
metadata.cu_seqlens_q = torch.nn.functional.pad(
|
||||
torch.cumsum(extend_seq_lens, dim=0, dtype=torch.int32), (1, 0)
|
||||
)
|
||||
else:
|
||||
metadata.max_seq_len_q = metadata.max_seq_len_k
|
||||
metadata.cu_seqlens_q = metadata.cu_seqlens_k
|
||||
|
||||
# Convert the page table to a strided format
|
||||
if self.page_size > 1:
|
||||
self.strided_indices = torch.arange(
|
||||
0, metadata.page_table.shape[1], self.page_size, device=self.device
|
||||
)
|
||||
metadata.page_table = (
|
||||
metadata.page_table[:, self.strided_indices] // self.page_size
|
||||
)
|
||||
|
||||
self.forward_metadata = metadata
|
||||
|
||||
def forward_decode(
|
||||
self,
|
||||
q: torch.Tensor,
|
||||
k: torch.Tensor,
|
||||
v: torch.Tensor,
|
||||
layer: RadixAttention,
|
||||
forward_batch: ForwardBatch,
|
||||
save_kv_cache: bool = True,
|
||||
) -> torch.Tensor:
|
||||
"""Run forward for decode using TRTLLM MHA kernel."""
|
||||
cache_loc = forward_batch.out_cache_loc
|
||||
if save_kv_cache and k is not None:
|
||||
forward_batch.token_to_kv_pool.set_kv_buffer(
|
||||
layer, cache_loc, k, v, layer.k_scale, layer.v_scale
|
||||
)
|
||||
|
||||
q = q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim)
|
||||
k_cache, v_cache = forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id)
|
||||
# shape conversion:
|
||||
# [bs, page_size, num_kv_heads, head_dim] -> [bs, num_kv_heads, page_size, head_dim]
|
||||
k_cache = k_cache.view(
|
||||
-1, self.page_size, layer.tp_k_head_num, layer.head_dim
|
||||
).permute(0, 2, 1, 3)
|
||||
v_cache = v_cache.view(
|
||||
-1, self.page_size, layer.tp_v_head_num, layer.head_dim
|
||||
).permute(0, 2, 1, 3)
|
||||
kv_cache = (k_cache, v_cache)
|
||||
|
||||
# TODO: bmm1_scale and bmm2_scale might require modification
|
||||
q_scale = 1.0
|
||||
k_scale = (
|
||||
layer.k_scale_float
|
||||
if getattr(layer, "k_scale_float", None) is not None
|
||||
else 1.0
|
||||
)
|
||||
bmm1_scale = q_scale * k_scale * layer.scaling
|
||||
bmm2_scale = 1.0
|
||||
|
||||
# Call TRT-LLM kernel
|
||||
# raw_out: like q, [bs, acc_q_len, num_q_heads, head_dim] but with output dtype
|
||||
o = flashinfer.decode.trtllm_batch_decode_with_kv_cache(
|
||||
query=q,
|
||||
kv_cache=kv_cache,
|
||||
workspace_buffer=self.workspace_buffer,
|
||||
block_tables=self.forward_metadata.page_table,
|
||||
seq_lens=self.forward_metadata.cache_seqlens_int32,
|
||||
max_seq_len=self.forward_metadata.max_seq_len_k,
|
||||
bmm1_scale=bmm1_scale,
|
||||
bmm2_scale=bmm2_scale,
|
||||
window_left=self.sliding_window_size,
|
||||
# TODO: add attention_sink operation or nvfp4 scale factor if needed
|
||||
)
|
||||
|
||||
return o.view(-1, layer.tp_q_head_num * layer.head_dim)
|
||||
|
||||
def forward_extend(
|
||||
self,
|
||||
q: torch.Tensor,
|
||||
k: torch.Tensor,
|
||||
v: torch.Tensor,
|
||||
layer: RadixAttention,
|
||||
forward_batch: ForwardBatch,
|
||||
save_kv_cache=True,
|
||||
):
|
||||
cache_loc = forward_batch.out_cache_loc
|
||||
if save_kv_cache and k is not None:
|
||||
forward_batch.token_to_kv_pool.set_kv_buffer(
|
||||
layer, cache_loc, k, v, layer.k_scale, layer.v_scale
|
||||
)
|
||||
q = q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim)
|
||||
k_cache, v_cache = forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id)
|
||||
k_cache = k_cache.view(
|
||||
-1, self.page_size, layer.tp_k_head_num, layer.head_dim
|
||||
).permute(0, 2, 1, 3)
|
||||
v_cache = v_cache.view(
|
||||
-1, self.page_size, layer.tp_v_head_num, layer.head_dim
|
||||
).permute(0, 2, 1, 3)
|
||||
kv_cache = (k_cache, v_cache)
|
||||
|
||||
# TODO: bmm1_scale and bmm2_scale might require modification
|
||||
# TODO: Change once quantization is supported
|
||||
q_scale = 1.0
|
||||
k_scale = (
|
||||
layer.k_scale_float
|
||||
if getattr(layer, "k_scale_float", None) is not None
|
||||
else 1.0
|
||||
)
|
||||
bmm1_scale = q_scale * k_scale * layer.scaling
|
||||
bmm2_scale = 1.0
|
||||
|
||||
o = flashinfer.prefill.trtllm_batch_context_with_kv_cache(
|
||||
query=q,
|
||||
kv_cache=kv_cache,
|
||||
workspace_buffer=self.workspace_buffer,
|
||||
block_tables=self.forward_metadata.page_table,
|
||||
seq_lens=self.forward_metadata.cache_seqlens_int32,
|
||||
max_q_len=self.forward_metadata.max_seq_len_q,
|
||||
max_kv_len=self.forward_metadata.max_seq_len_k,
|
||||
bmm1_scale=bmm1_scale,
|
||||
bmm2_scale=bmm2_scale,
|
||||
batch_size=forward_batch.batch_size,
|
||||
cum_seq_lens_q=self.forward_metadata.cu_seqlens_q,
|
||||
cum_seq_lens_kv=self.forward_metadata.cu_seqlens_k,
|
||||
window_left=self.sliding_window_size,
|
||||
# TODO: add attention_sink operation or nvfp4 scale factor if needed
|
||||
)
|
||||
|
||||
return o.view(-1, layer.tp_q_head_num * layer.head_dim)
|
||||
@@ -1705,6 +1705,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
|
||||
or attention_backend_str == "flashmla"
|
||||
or attention_backend_str == "cutlass_mla"
|
||||
or attention_backend_str == "ascend"
|
||||
or attention_backend_str == "trtllm_mha"
|
||||
or global_server_args_dict["enable_two_batch_overlap"]
|
||||
):
|
||||
seq_lens_cpu = (
|
||||
|
||||
@@ -1449,6 +1449,17 @@ class ModelRunner:
|
||||
from sglang.srt.layers.attention.trtllm_mla_backend import TRTLLMMLABackend
|
||||
|
||||
return TRTLLMMLABackend(self)
|
||||
elif self.server_args.attention_backend == "trtllm_mha":
|
||||
if self.use_mla_backend:
|
||||
raise ValueError(
|
||||
"trtllm_mha backend can only be used with non-MLA models."
|
||||
)
|
||||
from sglang.srt.layers.attention.trtllm_mha_backend import (
|
||||
TRTLLMHAAttnBackend,
|
||||
)
|
||||
|
||||
return TRTLLMHAAttnBackend(self)
|
||||
|
||||
elif self.server_args.attention_backend == "intel_amx":
|
||||
from sglang.srt.layers.attention.intel_amx_backend import (
|
||||
IntelAMXAttnBackend,
|
||||
|
||||
@@ -441,6 +441,23 @@ class ServerArgs:
|
||||
"trtllm_mla backend does not support speculative decoding yet."
|
||||
)
|
||||
|
||||
if self.attention_backend == "trtllm_mha":
|
||||
if not is_sm100_supported():
|
||||
raise ValueError(
|
||||
"TRTLLM MHA backend is only supported on Blackwell GPUs (SM100). Please use a different backend."
|
||||
)
|
||||
|
||||
if self.page_size not in [16, 32, 64]:
|
||||
logger.warning(
|
||||
f"TensorRT-LLM MHA only supports page_size of 16, 32 or 64, changing page_size from {self.page_size} to 64."
|
||||
)
|
||||
self.page_size = 64
|
||||
|
||||
if self.speculative_algorithm is not None:
|
||||
raise ValueError(
|
||||
"trtllm_mla backend does not support speculative decoding yet."
|
||||
)
|
||||
|
||||
# Set page size
|
||||
if self.page_size is None:
|
||||
self.page_size = 1
|
||||
@@ -1275,6 +1292,7 @@ class ServerArgs:
|
||||
"ascend",
|
||||
"triton",
|
||||
"trtllm_mla",
|
||||
"trtllm_mha",
|
||||
],
|
||||
default=ServerArgs.attention_backend,
|
||||
help="Choose the kernels for attention layers.",
|
||||
|
||||
Reference in New Issue
Block a user