diff --git a/python/sglang/srt/layers/attention/trtllm_mla_backend.py b/python/sglang/srt/layers/attention/trtllm_mla_backend.py
index 61ae74cac..38b7111b2 100755
--- a/python/sglang/srt/layers/attention/trtllm_mla_backend.py
+++ b/python/sglang/srt/layers/attention/trtllm_mla_backend.py
@@ -24,6 +24,7 @@ from sglang.srt.layers.dp_attention import get_attention_tp_size
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
 from sglang.srt.server_args import get_global_server_args
 from sglang.srt.utils import is_cuda, is_flashinfer_available
+from sglang.srt.utils.common import cached_triton_kernel
 
 if is_flashinfer_available():
     import flashinfer
@@ -50,6 +51,7 @@ DEFAULT_WORKSPACE_SIZE_MB = 128  # Memory workspace size in MB
 TRTLLM_BLOCK_CONSTRAINT = 128
 
 
+@cached_triton_kernel(lambda _, kwargs: (kwargs["BLOCK_SIZE"]))
 @triton.jit
 def pad_draft_extend_query_kernel(
     q_ptr,  # Input query tensor [total_seq_len, num_heads, head_dim]
@@ -123,6 +125,7 @@ def pad_draft_extend_query_kernel(
     )
 
 
+@cached_triton_kernel(lambda _, kwargs: (kwargs["BLOCK_SIZE"]))
 @triton.jit
 def unpad_draft_extend_output_kernel(
     raw_out_ptr,  # Input raw output tensor (batch_size, token_per_batch, tp_q_head_num, v_head_dim)
@@ -389,7 +392,7 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
         if (
             not forward_mode.is_decode_or_idle()
             and not forward_mode.is_target_verify()
-            and not forward_mode.is_draft_extend()
+            and not forward_mode.is_draft_extend(include_v2=True)
         ):
             return super().init_forward_metadata_capture_cuda_graph(
                 bs,
@@ -429,7 +432,7 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
             block_kv_indices,
             max_seq_len_val,
         )
-        if forward_mode.is_draft_extend():
+        if forward_mode.is_draft_extend(include_v2=True):
            num_tokens_per_bs = num_tokens // bs
            metadata.max_seq_len_q = num_tokens_per_bs + 1
            metadata.sum_seq_lens_q = num_tokens_per_bs * bs
@@ -462,7 +465,7 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
         if (
             not forward_mode.is_decode_or_idle()
             and not forward_mode.is_target_verify()
-            and not forward_mode.is_draft_extend()
+            and not forward_mode.is_draft_extend(include_v2=True)
         ):
             return super().init_forward_metadata_replay_cuda_graph(
                 bs,
@@ -481,7 +484,7 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
 
         metadata = self.decode_cuda_graph_metadata[bs]
 
-        if forward_mode.is_draft_extend():
+        if forward_mode.is_draft_extend(include_v2=True):
             accept_length = spec_info.accept_length[:bs]
             if spec_info.accept_length_cpu:
                 metadata.max_seq_len_q = max(spec_info.accept_length_cpu[:bs])
@@ -523,7 +526,7 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
         if (
             forward_batch.forward_mode.is_extend()
             and not forward_batch.forward_mode.is_target_verify()
-            and not forward_batch.forward_mode.is_draft_extend()
+            and not forward_batch.forward_mode.is_draft_extend(include_v2=True)
         ):
             if self.disable_chunked_prefix_cache:
                 super().init_forward_metadata(forward_batch)
@@ -544,7 +547,7 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
         elif (
             forward_batch.forward_mode.is_decode_or_idle()
             or forward_batch.forward_mode.is_target_verify()
-            or forward_batch.forward_mode.is_draft_extend()
+            or forward_batch.forward_mode.is_draft_extend(include_v2=True)
         ):
             bs = forward_batch.batch_size
 
@@ -573,7 +576,7 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
             self.forward_decode_metadata = TRTLLMMLADecodeMetadata(
                 block_kv_indices, max_seq_len_val
             )
-            if forward_batch.forward_mode.is_draft_extend():
+            if forward_batch.forward_mode.is_draft_extend(include_v2=True):
                 max_seq = forward_batch.seq_lens_cpu.max().item()
                 sum_seq_lens_q = sum(forward_batch.extend_seq_lens_cpu)
@@ -922,7 +925,7 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
 
         if (
             forward_batch.forward_mode.is_target_verify()
-            or forward_batch.forward_mode.is_draft_extend()
+            or forward_batch.forward_mode.is_draft_extend(include_v2=True)
         ):
             metadata = (
                 getattr(forward_batch, "decode_trtllm_mla_metadata", None)
@@ -994,7 +997,7 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
 
         # Reshape output directly without slicing
-        if forward_batch.forward_mode.is_draft_extend():
+        if forward_batch.forward_mode.is_draft_extend(include_v2=True):
             raw_out = self.unpad_draft_extend_output(
                 raw_out,
                 metadata.cu_seqlens_q,