diff --git a/python/sglang/srt/layers/attention/trtllm_mha_backend.py b/python/sglang/srt/layers/attention/trtllm_mha_backend.py
index 1d197c5da..d9868b307 100644
--- a/python/sglang/srt/layers/attention/trtllm_mha_backend.py
+++ b/python/sglang/srt/layers/attention/trtllm_mha_backend.py
@@ -1,7 +1,8 @@
 from __future__ import annotations
 
 """
-Support attention backend for TRTLLM MLA kernels from flashinfer.
+Support attention backend for TRTLLM MHA kernels from flashinfer.
+The kernel supports sm100 only, with sliding window and attention sink features.
 """
 
 from dataclasses import dataclass
@@ -57,11 +58,6 @@ class TRTLLMHAAttnBackend(FlashInferAttnBackend):
 
         # MHA-specific dimensions
        self.max_context_len = model_runner.model_config.context_len
-        self.sliding_window_size = (
-            model_runner.sliding_window_size
-            if model_runner.sliding_window_size is not None
-            else -1  # -1 indicates full attention
-        )
         self.hidden_size = config.hidden_size
 
         # Runtime parameters
@@ -117,10 +113,10 @@ class TRTLLMHAAttnBackend(FlashInferAttnBackend):
             metadata = TRTLLMMHAMetadata()
 
             # Get sequence information
-            metadata.cache_seqlens_int32 = seq_lens.to(torch.int32)
+            metadata.cache_seqlens_int32 = seq_lens[:bs].to(torch.int32)
 
             # Precompute maximum sequence length
-            metadata.max_seq_len_k = seq_lens.max().item()
+            metadata.max_seq_len_k = self.max_context_len
 
             # Precompute page table
             metadata.page_table = self.decode_cuda_graph_metadata["page_table"][:bs, :]
@@ -149,7 +145,7 @@ class TRTLLMHAAttnBackend(FlashInferAttnBackend):
             metadata = self.decode_cuda_graph_metadata[bs]
             max_len = seq_lens_cpu.max().item()
             max_seq_pages = (max_len + self.page_size - 1) // self.page_size
-            metadata.max_seq_len_k = max_len
+            metadata.max_seq_len_k = self.max_context_len
 
             metadata.cache_seqlens_int32.copy_(seq_lens)
             page_indices = self.req_to_token[
@@ -217,6 +213,7 @@ class TRTLLMHAAttnBackend(FlashInferAttnBackend):
         layer: RadixAttention,
         forward_batch: ForwardBatch,
         save_kv_cache: bool = True,
+        **kwargs,
     ) -> torch.Tensor:
         """Run forward for decode using TRTLLM MHA kernel."""
         cache_loc = forward_batch.out_cache_loc
@@ -228,7 +225,7 @@ class TRTLLMHAAttnBackend(FlashInferAttnBackend):
         q = q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim)
         k_cache, v_cache = forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id)
         # shape conversion:
-        # [bs, page_size, num_kv_heads, head_dim] -> [bs, num_kv_heads, page_size, head_dim]
+        # [num_pages, page_size, num_kv_heads, head_dim] -> [num_pages, num_kv_heads, page_size, head_dim]
         k_cache = k_cache.view(
             -1, self.page_size, layer.tp_k_head_num, layer.head_dim
         ).permute(0, 2, 1, 3)
@@ -237,7 +234,7 @@ class TRTLLMHAAttnBackend(FlashInferAttnBackend):
         ).permute(0, 2, 1, 3)
         kv_cache = (k_cache, v_cache)
 
-        # TODO: bmm1_scale and bmm2_scale might require modification
+        # TODO: add support for quantization
         q_scale = 1.0
         k_scale = (
             layer.k_scale_float
@@ -246,6 +243,8 @@ class TRTLLMHAAttnBackend(FlashInferAttnBackend):
         )
         bmm1_scale = q_scale * k_scale * layer.scaling
         bmm2_scale = 1.0
+        # sink: additional value per head in the denominator of the softmax.
+        attention_sink = kwargs.get("sinks", None)
 
         # Call TRT-LLM kernel
         # raw_out: like q, [bs, acc_q_len, num_q_heads, head_dim] but with output dtype
@@ -258,8 +257,9 @@ class TRTLLMHAAttnBackend(FlashInferAttnBackend):
             max_seq_len=self.forward_metadata.max_seq_len_k,
             bmm1_scale=bmm1_scale,
             bmm2_scale=bmm2_scale,
-            window_left=self.sliding_window_size,
+            window_left=layer.sliding_window_size,
             # TODO: add attention_sink operation or nvfp4 scale factor if needed
+            sinks=attention_sink,
         )
 
         return o.view(-1, layer.tp_q_head_num * layer.head_dim)
@@ -272,6 +272,7 @@ class TRTLLMHAAttnBackend(FlashInferAttnBackend):
         layer: RadixAttention,
         forward_batch: ForwardBatch,
         save_kv_cache=True,
+        **kwargs,
     ):
         cache_loc = forward_batch.out_cache_loc
         if save_kv_cache and k is not None:
@@ -279,6 +280,7 @@ class TRTLLMHAAttnBackend(FlashInferAttnBackend):
                 layer, cache_loc, k, v, layer.k_scale, layer.v_scale
             )
         q = q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim)
+        # [num_pages, page_size, num_kv_heads, head_dim] -> [num_pages, num_kv_heads, page_size, head_dim]
         k_cache, v_cache = forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id)
         k_cache = k_cache.view(
             -1, self.page_size, layer.tp_k_head_num, layer.head_dim
@@ -288,8 +290,9 @@ class TRTLLMHAAttnBackend(FlashInferAttnBackend):
         ).permute(0, 2, 1, 3)
         kv_cache = (k_cache, v_cache)
 
-        # TODO: bmm1_scale and bmm2_scale might require modification
-        # TODO: Change once quantization is supported
+        # sink: additional value per head in the denominator of the softmax.
+        attention_sink = kwargs.get("sinks", None)
+        # TODO: add support for quantization
         q_scale = 1.0
         k_scale = (
             layer.k_scale_float
@@ -312,8 +315,9 @@ class TRTLLMHAAttnBackend(FlashInferAttnBackend):
             batch_size=forward_batch.batch_size,
             cum_seq_lens_q=self.forward_metadata.cu_seqlens_q,
             cum_seq_lens_kv=self.forward_metadata.cu_seqlens_k,
-            window_left=self.sliding_window_size,
+            window_left=layer.sliding_window_size,
             # TODO: add attention_sink operation or nvfp4 scale factor if needed
+            sinks=attention_sink,
         )
 
         return o.view(-1, layer.tp_q_head_num * layer.head_dim)
diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py
index 53c3d51f6..fe5d2c478 100644
--- a/python/sglang/srt/model_executor/model_runner.py
+++ b/python/sglang/srt/model_executor/model_runner.py
@@ -1443,13 +1443,13 @@ class ModelRunner:
             )
             return CutlassMLABackend(self)
-        elif self.server_args.attention_backend == "trtllm_mla":
+        elif backend_str == "trtllm_mla":
             if not self.use_mla_backend:
                 raise ValueError("trtllm_mla backend can only be used with MLA models.")
             from sglang.srt.layers.attention.trtllm_mla_backend import TRTLLMMLABackend
 
             return TRTLLMMLABackend(self)
-        elif self.server_args.attention_backend == "trtllm_mha":
+        elif backend_str == "trtllm_mha":
             if self.use_mla_backend:
                 raise ValueError(
                     "trtllm_mha backend can only be used with non-MLA models."
                 )
@@ -1460,7 +1460,7 @@ class ModelRunner:
 
             return TRTLLMHAAttnBackend(self)
 
-        elif self.server_args.attention_backend == "intel_amx":
+        elif backend_str == "intel_amx":
             from sglang.srt.layers.attention.intel_amx_backend import (
                 IntelAMXAttnBackend,
             )
diff --git a/python/sglang/srt/models/gpt_oss.py b/python/sglang/srt/models/gpt_oss.py
index 58b68fb38..b523c2e1b 100644
--- a/python/sglang/srt/models/gpt_oss.py
+++ b/python/sglang/srt/models/gpt_oss.py
@@ -301,7 +301,7 @@ class GptOssAttention(nn.Module):
         hidden_states, forward_batch, inner_state = intermediate_state
         if inner_state is None:
             return hidden_states
-        attn_output = self.attn(*inner_state, sinks=self.sinks)
+        attn_output = self.attn(*inner_state, sinks=self.sinks.to(torch.float32))
         output, _ = self.o_proj(attn_output)
         return output
 
diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py
index 6412398bb..605214a98 100644
--- a/python/sglang/srt/server_args.py
+++ b/python/sglang/srt/server_args.py
@@ -445,7 +445,11 @@ class ServerArgs:
                 "trtllm_mla backend does not support speculative decoding yet."
             )
 
-        if self.attention_backend == "trtllm_mha":
+        if (
+            self.attention_backend == "trtllm_mha"
+            or self.decode_attention_backend == "trtllm_mha"
+            or self.prefill_attention_backend == "trtllm_mha"
+        ):
             if not is_sm100_supported():
                 raise ValueError(
                     "TRTLLM MHA backend is only supported on Blackwell GPUs (SM100). Please use a different backend."
@@ -459,11 +463,18 @@ class ServerArgs:
 
         if self.speculative_algorithm is not None:
             raise ValueError(
-                "trtllm_mla backend does not support speculative decoding yet."
+                "trtllm_mha backend does not support speculative decoding yet."
             )
 
+        model_arch = self.get_hf_config().architectures[0]
         if model_arch in ["GptOssForCausalLM"]:
-            self.attention_backend = "triton"
+            if self.attention_backend is None:
+                # Default to triton; trtllm_mha is also a valid option.
+                self.attention_backend = "triton"
+            assert (
+                self.attention_backend == "trtllm_mha"
+                or self.attention_backend == "triton"
+            )
 
         # Check if FlashInfer MXFP4 MoE is enabled
         from sglang.srt.utils import get_bool_env_var
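
Note on the `sinks` argument threaded through both forward paths above: an attention sink is one extra learnable logit per head that participates in the softmax normalization without contributing a value vector, letting a head attend "nowhere" (useful alongside sliding-window attention). The `.to(torch.float32)` cast added in gpt_oss.py suggests the kernel consumes float32 sink values even when the model weights are lower precision. Below is a minimal, kernel-free PyTorch sketch of these semantics for reference; the function name and tensor shapes are illustrative assumptions, not the flashinfer API.

import torch

def attention_with_sinks(q, k, v, sinks, scale):
    # Reference semantics only; the real work happens inside the TRT-LLM kernel.
    # q: [num_heads, q_len, head_dim]; k, v: [num_heads, kv_len, head_dim]
    # sinks: [num_heads] float32, one sink logit per head
    logits = torch.einsum("hqd,hkd->hqk", q, k) * scale         # [h, q_len, kv_len]
    sink = sinks.view(-1, 1, 1).expand(-1, logits.shape[1], 1)  # [h, q_len, 1]
    # The sink logit joins the softmax denominator alongside the real keys...
    probs = torch.softmax(torch.cat([logits, sink], dim=-1), dim=-1)
    # ...but its probability mass is discarded: no value vector is attached to it.
    return torch.einsum("hqk,hkd->hqd", probs[..., :-1], v)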