diff --git a/python/sglang/srt/layers/attention/trtllm_mla_backend.py b/python/sglang/srt/layers/attention/trtllm_mla_backend.py
index ee69c0cb9..e37071697 100755
--- a/python/sglang/srt/layers/attention/trtllm_mla_backend.py
+++ b/python/sglang/srt/layers/attention/trtllm_mla_backend.py
@@ -51,6 +51,7 @@ class TRTLLMMLADecodeMetadata:
 
     workspace: Optional[torch.Tensor] = None
     block_kv_indices: Optional[torch.Tensor] = None
+    max_seq_len: Optional[int] = None
 
 
 class TRTLLMMLABackend(FlashInferMLAAttnBackend):
@@ -207,8 +208,9 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
             )
 
         # Custom fast-path for decode/idle.
-        max_seqlen_pad = self._calc_padded_blocks(seq_lens.max().item())
-        block_kv_indices = self.decode_cuda_graph_kv_indices[:bs, :max_seqlen_pad]
+        # Capture at full width so longer future sequences are safe during replay.
+        max_blocks_per_seq = self._calc_padded_blocks(self.max_context_len)
+        block_kv_indices = self.decode_cuda_graph_kv_indices[:bs, :max_blocks_per_seq]
 
         create_flashmla_kv_indices_triton[(bs,)](
             self.req_to_token,
@@ -217,13 +219,20 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
             None,
             block_kv_indices,
             self.req_to_token.stride(0),
-            max_seqlen_pad,
+            max_blocks_per_seq,
             NUM_PAGE_PER_BLOCK=TRITON_PAD_NUM_PAGE_PER_BLOCK,
             PAGED_SIZE=self.page_size,
         )
 
+        # Record the true maximum sequence length for this capture batch so that
+        # the kernel launch path (which requires an int, not a tensor) can reuse
+        # it safely during both capture and replay.
+        max_seq_len_val = int(seq_lens.max().item())
+
         metadata = TRTLLMMLADecodeMetadata(
-            self.decode_cuda_graph_workspace, block_kv_indices
+            self.decode_cuda_graph_workspace,
+            block_kv_indices,
+            max_seq_len_val,
         )
         self.decode_cuda_graph_metadata[bs] = metadata
         self.forward_metadata = metadata
@@ -268,6 +277,13 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
             PAGED_SIZE=self.page_size,
         )
 
+        # Update the stored max_seq_len so subsequent kernel calls use the correct value.
+        # Prefer the CPU tensor to avoid a GPU synchronization when it is available.
+        if seq_lens_cpu is not None:
+            metadata.max_seq_len = int(seq_lens_cpu.max().item())
+        else:
+            metadata.max_seq_len = int(seq_lens.max().item())
+
     def get_cuda_graph_seq_len_fill_value(self) -> int:
         """Get the fill value for sequence lengths in CUDA graph."""
         return 1
@@ -295,8 +311,9 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
             forward_batch.seq_lens.device,
         )
 
+        max_seq_len_val = int(max_seq)
         self.forward_metadata = TRTLLMMLADecodeMetadata(
-            self.workspace_buffer, block_kv_indices
+            self.workspace_buffer, block_kv_indices, max_seq_len_val
        )
         forward_batch.decode_trtllm_mla_metadata = self.forward_metadata
 
@@ -471,14 +488,12 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
             qk_rope_head_dim=self.qk_rope_head_dim,
             block_tables=metadata.block_kv_indices,
             seq_lens=forward_batch.seq_lens.to(torch.int32),
-            max_seq_len=int(metadata.block_kv_indices.shape[1] * self.page_size),
+            max_seq_len=metadata.max_seq_len,
             bmm1_scale=bmm1_scale,
         )
 
-        # Extract value projection part and reshape
-        raw_out_v = raw_out[..., : layer.v_head_dim].contiguous()
-        output = raw_out_v.view(-1, layer.tp_q_head_num * layer.v_head_dim)
-
+        # Reshape output directly without slicing
+        output = raw_out.view(-1, layer.tp_q_head_num * layer.v_head_dim)
         return output
 
 
diff --git a/python/sglang/test/attention/test_trtllm_mla_backend.py b/python/sglang/test/attention/test_trtllm_mla_backend.py
index 18a7f77ea..b2017066b 100755
--- a/python/sglang/test/attention/test_trtllm_mla_backend.py
+++ b/python/sglang/test/attention/test_trtllm_mla_backend.py
@@ -208,6 +208,15 @@ class MockModelRunner:
         self.kv_cache_dtype = config["kv_cache_dtype"]
         self.page_size = config["page_size"]
 
+        # Server args stub - needed by attention backends
+        self.server_args = type(
+            "ServerArgs",
+            (),
+            {
+                "enable_dp_attention": False,  # Default value for testing
+            },
+        )
+
         # Model-config stub with MLA attributes
         self.model_config = type(
             "ModelConfig",
@@ -833,7 +842,7 @@ class TestTRTLLMMLA(CustomTestCase):
 
         # Test workspace properties
         self.assertEqual(metadata.workspace.device.type, "cuda")
-        self.assertEqual(metadata.workspace.dtype, torch.int8)
+        self.assertEqual(metadata.workspace.dtype, torch.uint8)
         self.assertGreater(
             metadata.workspace.numel(), 0, "Workspace should have non-zero size"
         )
@@ -993,8 +1002,8 @@ class TestTRTLLMMLA(CustomTestCase):
         )
 
         # Verify CUDA graph buffers are allocated
-        self.assertIsNotNone(backend.cuda_graph_kv_indices)
-        self.assertIsNotNone(backend.cuda_graph_workspace)
+        self.assertIsNotNone(backend.decode_cuda_graph_kv_indices)
+        self.assertIsNotNone(backend.decode_cuda_graph_workspace)
 
         # Test capture metadata
         seq_lens = torch.full(
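
Reviewer note: the capture/replay contract in the first file is subtle, so below is a minimal, self-contained sketch of the bookkeeping this patch introduces. All names here (PAGE_SIZE, MAX_CONTEXT_LEN, pad_blocks, DecodeMetadata, capture, replay_update) are hypothetical stand-ins for the backend's page_size, max_context_len, _calc_padded_blocks, and TRTLLMMLADecodeMetadata; it illustrates the pattern under those assumptions, not the actual SGLang implementation.

from dataclasses import dataclass
from typing import Optional

import torch

PAGE_SIZE = 64          # hypothetical paged-KV page size
MAX_CONTEXT_LEN = 8192  # hypothetical model context limit


def pad_blocks(seq_len: int) -> int:
    # Stand-in for _calc_padded_blocks: round a length up to whole KV pages.
    return (seq_len + PAGE_SIZE - 1) // PAGE_SIZE


@dataclass
class DecodeMetadata:
    # Mirrors TRTLLMMLADecodeMetadata: max_seq_len is a plain Python int
    # because the kernel launch argument cannot be a captured GPU tensor.
    workspace: Optional[torch.Tensor] = None
    block_kv_indices: Optional[torch.Tensor] = None
    max_seq_len: Optional[int] = None


def capture(bs: int, seq_lens: torch.Tensor, kv_indices_buf: torch.Tensor) -> DecodeMetadata:
    # Slice the persistent buffer at full context width so a replayed graph
    # can index any sequence up to the limit, not just capture-time lengths.
    block_kv_indices = kv_indices_buf[:bs, : pad_blocks(MAX_CONTEXT_LEN)]
    # Record the true capture-time maximum as a host int.
    return DecodeMetadata(None, block_kv_indices, int(seq_lens.max().item()))


def replay_update(
    meta: DecodeMetadata,
    seq_lens: torch.Tensor,
    seq_lens_cpu: Optional[torch.Tensor],
) -> None:
    # Refresh before each replay; prefer the CPU copy, since calling .item()
    # on a GPU tensor forces a device synchronization.
    src = seq_lens_cpu if seq_lens_cpu is not None else seq_lens
    meta.max_seq_len = int(src.max().item())


# Usage: capture with short sequences, then replay with longer ones.
buf = torch.zeros(8, pad_blocks(MAX_CONTEXT_LEN), dtype=torch.int32)
meta = capture(bs=2, seq_lens=torch.tensor([17, 33]), kv_indices_buf=buf)
replay_update(meta, torch.tensor([120, 256]), seq_lens_cpu=torch.tensor([120, 256]))
assert meta.max_seq_len == 256

The design point the sketch makes explicit: the kv-indices slice is sized from the context limit at capture time (so the graph stays valid for any later batch), while max_seq_len is deliberately kept outside the graph as host metadata and refreshed on every replay.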