From 40e3b2beebef63234ed17c40015677aa252b93b5 Mon Sep 17 00:00:00 2001
From: eigen <52445717+yyihuang@users.noreply.github.com>
Date: Tue, 5 Aug 2025 06:28:39 -0400
Subject: [PATCH] feat: add trtllm-gen mha from direct call (#8782)

Co-authored-by: Baizhou Zhang
---
 .../layers/attention/trtllm_mha_backend.py    | 321 ++++++++++++++++++
 python/sglang/srt/managers/schedule_batch.py  |   1 +
 .../sglang/srt/model_executor/model_runner.py |  11 +
 python/sglang/srt/server_args.py              |  18 +
 4 files changed, 351 insertions(+)
 create mode 100644 python/sglang/srt/layers/attention/trtllm_mha_backend.py

diff --git a/python/sglang/srt/layers/attention/trtllm_mha_backend.py b/python/sglang/srt/layers/attention/trtllm_mha_backend.py
new file mode 100644
index 000000000..2e7c67758
--- /dev/null
+++ b/python/sglang/srt/layers/attention/trtllm_mha_backend.py
@@ -0,0 +1,321 @@
"""
Attention backend for TRTLLM MHA kernels from flashinfer.
"""

from __future__ import annotations

from dataclasses import dataclass
from typing import TYPE_CHECKING, Optional

import torch

from sglang.srt.layers.attention.flashinfer_backend import FlashInferAttnBackend
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
from sglang.srt.utils import is_flashinfer_available

if is_flashinfer_available():
    import flashinfer

if TYPE_CHECKING:
    from sglang.srt.layers.radix_attention import RadixAttention
    from sglang.srt.model_executor.model_runner import ModelRunner
    from sglang.srt.speculative.spec_info import SpecInfo

# Constants
DEFAULT_WORKSPACE_SIZE_MB = 128  # Memory workspace size in MB


@dataclass
class TRTLLMMHAMetadata:
    # Sequence lengths for the forward batch
    cache_seqlens_int32: torch.Tensor = None
    # Maximum sequence length for query
    max_seq_len_q: int = 1
    # Maximum sequence length for key
    max_seq_len_k: int = 0
    # Cumulative sequence lengths for query
    cu_seqlens_q: torch.Tensor = None
    # Cumulative sequence lengths for key
    cu_seqlens_k: torch.Tensor = None
    # Page table, the indices of the KV cache tables/blocks
    page_table: torch.Tensor = None


class TRTLLMHAAttnBackend(FlashInferAttnBackend):
    """TRTLLM MHA attention kernel from flashinfer."""

    def __init__(
        self,
        model_runner: ModelRunner,
        skip_prefill: bool = False,
        kv_indptr_buf: Optional[torch.Tensor] = None,
        q_indptr_decode_buf: Optional[torch.Tensor] = None,
    ):
        super().__init__(model_runner, skip_prefill, kv_indptr_buf, q_indptr_decode_buf)

        config = model_runner.model_config

        # MHA-specific dimensions
        self.max_context_len = config.context_len
        self.sliding_window_size = (
            model_runner.sliding_window_size
            if model_runner.sliding_window_size is not None
            else -1  # -1 indicates full attention
        )
        self.hidden_size = config.hidden_size

        # Runtime parameters
        self.data_type = model_runner.kv_cache_dtype
        self.q_data_type = model_runner.dtype
        self.page_size = model_runner.page_size
        self.req_to_token = model_runner.req_to_token_pool.req_to_token
        self.device = model_runner.device

        # Workspace allocation
        self.workspace_size = DEFAULT_WORKSPACE_SIZE_MB * 1024 * 1024
        self.workspace_buffer = torch.empty(
            self.workspace_size, dtype=torch.int8, device=self.device
        )

        # CUDA graph state
        self.decode_cuda_graph_metadata = {}

        # Forward metadata
        self.forward_metadata: Optional[TRTLLMMHAMetadata] = None
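
    # Sizing note (illustrative, not part of the kernel contract): the
    # CUDA-graph page table allocated in init_cuda_graph_state below reserves
    # one slot per KV page of the longest context, i.e.
    # ceil(max_context_len / page_size) columns per request. For example,
    # with max_context_len = 8192 and page_size = 64:
    #   (8192 + 64 - 1) // 64 == 128 page slots per row.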

    def init_cuda_graph_state(
        self,
        max_bs: int,
        max_num_tokens: int,
        kv_indices_buf: Optional[torch.Tensor] = None,
    ):
        """Initialize CUDA graph state for TRTLLM MHA."""
        self.decode_cuda_graph_metadata = {
            "cache_seqlens": torch.zeros(max_bs, dtype=torch.int32, device=self.device),
            "page_table": torch.zeros(
                max_bs,
                (self.max_context_len + self.page_size - 1) // self.page_size,
                dtype=torch.int32,
                device=self.device,
            ),
            "strided_indices": torch.arange(
                0, self.max_context_len, self.page_size, device=self.device
            ),
        }

    def init_forward_metadata_capture_cuda_graph(
        self,
        bs: int,
        num_tokens: int,
        req_pool_indices: torch.Tensor,
        seq_lens: torch.Tensor,
        encoder_lens: Optional[torch.Tensor],
        forward_mode: ForwardMode,
        spec_info: Optional[SpecInfo],
    ):
        """Initialize metadata for CUDA graph capture."""
        metadata = TRTLLMMHAMetadata()

        # Get sequence information
        metadata.cache_seqlens_int32 = seq_lens.to(torch.int32)

        # Precompute maximum sequence length
        metadata.max_seq_len_k = seq_lens.max().item()

        # Precompute page table
        metadata.page_table = self.decode_cuda_graph_metadata["page_table"][:bs, :]
        self.decode_cuda_graph_metadata[bs] = metadata
        self.forward_metadata = metadata

    def init_forward_metadata_replay_cuda_graph(
        self,
        bs: int,
        req_pool_indices: torch.Tensor,
        seq_lens: torch.Tensor,
        seq_lens_sum: int,
        encoder_lens: Optional[torch.Tensor],
        forward_mode: ForwardMode,
        spec_info: Optional[SpecInfo],
        seq_lens_cpu: Optional[torch.Tensor],
    ):
        """Replay CUDA graph with new inputs."""
        seq_lens = seq_lens[:bs]
        seq_lens_cpu = seq_lens_cpu[:bs]
        req_pool_indices = req_pool_indices[:bs]

        # Normal Decode
        metadata = self.decode_cuda_graph_metadata[bs]
        max_len = seq_lens_cpu.max().item()
        max_seq_pages = (max_len + self.page_size - 1) // self.page_size
        metadata.max_seq_len_k = max_len

        metadata.cache_seqlens_int32.copy_(seq_lens)
        page_indices = self.req_to_token[
            req_pool_indices[:, None],
            self.decode_cuda_graph_metadata["strided_indices"][:max_seq_pages][None, :],
        ]
        metadata.page_table[:, :max_seq_pages].copy_(page_indices // self.page_size)
        self.forward_metadata = metadata

    def get_cuda_graph_seq_len_fill_value(self) -> int:
        """Get the fill value for sequence lengths in CUDA graph."""
        return 1
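
    # Illustrative note on the cumulative-length layout built in
    # init_forward_metadata below: padding the cumsum of per-request lengths
    # with a leading zero yields exclusive prefix sums, so request i's tokens
    # occupy the half-open range [cu_seqlens[i], cu_seqlens[i + 1]).
    # For example (hypothetical lengths):
    #   extend_seq_lens = [3, 5, 2]
    #   cu_seqlens_q    = [0, 3, 8, 10]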

    def init_forward_metadata(self, forward_batch: ForwardBatch):
        """Initialize the metadata for a forward pass."""
        metadata = TRTLLMMHAMetadata()
        seqlens_in_batch = forward_batch.seq_lens

        if forward_batch.forward_mode.is_decode_or_idle():
            # Normal Decode
            metadata.cache_seqlens_int32 = seqlens_in_batch.to(torch.int32)
            metadata.max_seq_len_k = forward_batch.seq_lens_cpu.max().item()
            metadata.page_table = forward_batch.req_to_token_pool.req_to_token[
                forward_batch.req_pool_indices, : metadata.max_seq_len_k
            ]
        else:
            metadata.cache_seqlens_int32 = seqlens_in_batch.to(torch.int32)
            metadata.max_seq_len_k = forward_batch.seq_lens_cpu.max().item()
            metadata.cu_seqlens_k = torch.nn.functional.pad(
                torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)
            )
            metadata.page_table = forward_batch.req_to_token_pool.req_to_token[
                forward_batch.req_pool_indices, : metadata.max_seq_len_k
            ]

            if any(forward_batch.extend_prefix_lens_cpu):
                extend_seq_lens = forward_batch.extend_seq_lens
                metadata.max_seq_len_q = max(forward_batch.extend_seq_lens_cpu)
                metadata.cu_seqlens_q = torch.nn.functional.pad(
                    torch.cumsum(extend_seq_lens, dim=0, dtype=torch.int32), (1, 0)
                )
            else:
                metadata.max_seq_len_q = metadata.max_seq_len_k
                metadata.cu_seqlens_q = metadata.cu_seqlens_k

        # Convert the page table to a strided format
        if self.page_size > 1:
            self.strided_indices = torch.arange(
                0, metadata.page_table.shape[1], self.page_size, device=self.device
            )
            metadata.page_table = (
                metadata.page_table[:, self.strided_indices] // self.page_size
            )

        self.forward_metadata = metadata

    def forward_decode(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        layer: RadixAttention,
        forward_batch: ForwardBatch,
        save_kv_cache: bool = True,
    ) -> torch.Tensor:
        """Run forward for decode using the TRTLLM MHA kernel."""
        cache_loc = forward_batch.out_cache_loc
        if save_kv_cache and k is not None:
            forward_batch.token_to_kv_pool.set_kv_buffer(
                layer, cache_loc, k, v, layer.k_scale, layer.v_scale
            )

        q = q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim)
        k_cache, v_cache = forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id)
        # Shape conversion:
        # [bs, page_size, num_kv_heads, head_dim] -> [bs, num_kv_heads, page_size, head_dim]
        k_cache = k_cache.view(
            -1, self.page_size, layer.tp_k_head_num, layer.head_dim
        ).permute(0, 2, 1, 3)
        v_cache = v_cache.view(
            -1, self.page_size, layer.tp_v_head_num, layer.head_dim
        ).permute(0, 2, 1, 3)
        kv_cache = (k_cache, v_cache)

        # TODO: bmm1_scale and bmm2_scale might require modification
        q_scale = 1.0
        k_scale = (
            layer.k_scale_float
            if getattr(layer, "k_scale_float", None) is not None
            else 1.0
        )
        bmm1_scale = q_scale * k_scale * layer.scaling
        bmm2_scale = 1.0

        # Call the TRT-LLM kernel.
        # The output has the same layout as q, [bs, acc_q_len, num_q_heads, head_dim],
        # but in the output dtype.
        o = flashinfer.decode.trtllm_batch_decode_with_kv_cache(
            query=q,
            kv_cache=kv_cache,
            workspace_buffer=self.workspace_buffer,
            block_tables=self.forward_metadata.page_table,
            seq_lens=self.forward_metadata.cache_seqlens_int32,
            max_seq_len=self.forward_metadata.max_seq_len_k,
            bmm1_scale=bmm1_scale,
            bmm2_scale=bmm2_scale,
            window_left=self.sliding_window_size,
            # TODO: add attention_sink operation or nvfp4 scale factor if needed
        )

        return o.view(-1, layer.tp_q_head_num * layer.head_dim)
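
    # Illustrative note on the BMM scales used above and in forward_extend
    # below: bmm1_scale folds the query/key dequantization factors into the
    # softmax scaling. For instance, with the usual
    # layer.scaling = head_dim ** -0.5, head_dim = 128, and unquantized KV
    # (q_scale = k_scale = 1.0):
    #   bmm1_scale = 1.0 * 1.0 * 128 ** -0.5 ≈ 0.0884
    # bmm2_scale stays 1.0 until FP8/NVFP4 output scaling is wired up.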

    def forward_extend(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        layer: RadixAttention,
        forward_batch: ForwardBatch,
        save_kv_cache: bool = True,
    ):
        cache_loc = forward_batch.out_cache_loc
        if save_kv_cache and k is not None:
            forward_batch.token_to_kv_pool.set_kv_buffer(
                layer, cache_loc, k, v, layer.k_scale, layer.v_scale
            )
        q = q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim)
        k_cache, v_cache = forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id)
        k_cache = k_cache.view(
            -1, self.page_size, layer.tp_k_head_num, layer.head_dim
        ).permute(0, 2, 1, 3)
        v_cache = v_cache.view(
            -1, self.page_size, layer.tp_v_head_num, layer.head_dim
        ).permute(0, 2, 1, 3)
        kv_cache = (k_cache, v_cache)

        # TODO: bmm1_scale and bmm2_scale might require modification
        # TODO: Change once quantization is supported
        q_scale = 1.0
        k_scale = (
            layer.k_scale_float
            if getattr(layer, "k_scale_float", None) is not None
            else 1.0
        )
        bmm1_scale = q_scale * k_scale * layer.scaling
        bmm2_scale = 1.0

        o = flashinfer.prefill.trtllm_batch_context_with_kv_cache(
            query=q,
            kv_cache=kv_cache,
            workspace_buffer=self.workspace_buffer,
            block_tables=self.forward_metadata.page_table,
            seq_lens=self.forward_metadata.cache_seqlens_int32,
            max_q_len=self.forward_metadata.max_seq_len_q,
            max_kv_len=self.forward_metadata.max_seq_len_k,
            bmm1_scale=bmm1_scale,
            bmm2_scale=bmm2_scale,
            batch_size=forward_batch.batch_size,
            cum_seq_lens_q=self.forward_metadata.cu_seqlens_q,
            cum_seq_lens_kv=self.forward_metadata.cu_seqlens_k,
            window_left=self.sliding_window_size,
            # TODO: add attention_sink operation or nvfp4 scale factor if needed
        )

        return o.view(-1, layer.tp_q_head_num * layer.head_dim)
diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py
index 759bb6afa..99ea56965 100644
--- a/python/sglang/srt/managers/schedule_batch.py
+++ b/python/sglang/srt/managers/schedule_batch.py
@@ -1705,6 +1705,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
             or attention_backend_str == "flashmla"
             or attention_backend_str == "cutlass_mla"
             or attention_backend_str == "ascend"
+            or attention_backend_str == "trtllm_mha"
             or global_server_args_dict["enable_two_batch_overlap"]
         ):
             seq_lens_cpu = (
diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py
index 0ce13abc2..53c3d51f6 100644
--- a/python/sglang/srt/model_executor/model_runner.py
+++ b/python/sglang/srt/model_executor/model_runner.py
@@ -1449,6 +1449,17 @@ class ModelRunner:
             from sglang.srt.layers.attention.trtllm_mla_backend import TRTLLMMLABackend

             return TRTLLMMLABackend(self)
+        elif self.server_args.attention_backend == "trtllm_mha":
+            if self.use_mla_backend:
+                raise ValueError(
+                    "trtllm_mha backend can only be used with non-MLA models."
+                )
+            from sglang.srt.layers.attention.trtllm_mha_backend import (
+                TRTLLMHAAttnBackend,
+            )
+
+            return TRTLLMHAAttnBackend(self)
+
         elif self.server_args.attention_backend == "intel_amx":
             from sglang.srt.layers.attention.intel_amx_backend import (
                 IntelAMXAttnBackend,
diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py
index 60d8efb9e..6c4a818ae 100644
--- a/python/sglang/srt/server_args.py
+++ b/python/sglang/srt/server_args.py
@@ -441,6 +441,23 @@ class ServerArgs:
                 "trtllm_mla backend does not support speculative decoding yet."
             )

+        if self.attention_backend == "trtllm_mha":
+            if not is_sm100_supported():
+                raise ValueError(
+                    "TRTLLM MHA backend is only supported on Blackwell GPUs (SM100). Please use a different backend."
+                )
+
+            if self.page_size not in [16, 32, 64]:
+                logger.warning(
+                    f"TensorRT-LLM MHA only supports page_size of 16, 32 or 64, changing page_size from {self.page_size} to 64."
+                )
+                self.page_size = 64
+
+            if self.speculative_algorithm is not None:
+                raise ValueError(
+                    "trtllm_mha backend does not support speculative decoding yet."
+                )
+
         # Set page size
         if self.page_size is None:
             self.page_size = 1
@@ -1275,6 +1292,7 @@ class ServerArgs:
             "ascend",
             "triton",
             "trtllm_mla",
+            "trtllm_mha",
         ],
         default=ServerArgs.attention_backend,
         help="Choose the kernels for attention layers.",
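
Usage note (illustrative, not part of the patch): on a Blackwell (SM100) GPU
with a non-MLA model, the new backend is selected through the existing
launch flags, e.g.

    python -m sglang.launch_server \
        --model-path <your-model> \
        --attention-backend trtllm_mha \
        --page-size 64

The model path is a placeholder; per the server_args check above, page sizes
other than 16/32/64 are coerced to 64 with a warning, and speculative
decoding is rejected.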