diff --git a/python/sglang/srt/layers/attention/flashinfer_mla_backend.py b/python/sglang/srt/layers/attention/flashinfer_mla_backend.py
index b3acc8b01..05e9bef80 100644
--- a/python/sglang/srt/layers/attention/flashinfer_mla_backend.py
+++ b/python/sglang/srt/layers/attention/flashinfer_mla_backend.py
@@ -96,6 +96,7 @@ class FlashInferMhaChunkKVRunner:
     def update_wrapper(
         self,
         forward_batch: ForwardBatch,
+        disable_flashinfer_ragged: bool = False,
     ):
         assert forward_batch.num_prefix_chunks is not None
         num_prefix_chunks = forward_batch.num_prefix_chunks
@@ -128,16 +129,17 @@ class FlashInferMhaChunkKVRunner:
                 causal=False,
             )
         # ragged prefill
-        self.ragged_wrapper.begin_forward(
-            qo_indptr=qo_indptr,
-            kv_indptr=qo_indptr,
-            num_qo_heads=self.num_local_heads,
-            num_kv_heads=self.num_local_heads,
-            head_dim_qk=self.qk_nope_head_dim + self.qk_rope_head_dim,
-            head_dim_vo=self.v_head_dim,
-            q_data_type=self.q_data_type,
-            causal=True,
-        )
+        if not disable_flashinfer_ragged:
+            self.ragged_wrapper.begin_forward(
+                qo_indptr=qo_indptr,
+                kv_indptr=qo_indptr,
+                num_qo_heads=self.num_local_heads,
+                num_kv_heads=self.num_local_heads,
+                head_dim_qk=self.qk_nope_head_dim + self.qk_rope_head_dim,
+                head_dim_vo=self.v_head_dim,
+                q_data_type=self.q_data_type,
+                causal=True,
+            )

     def forward(
         self,
@@ -491,9 +493,11 @@ class FlashInferMLAAttnBackend(AttentionBackend):
     def get_cuda_graph_seq_len_fill_value(self):
         return 1

-    def init_mha_chunk_metadata(self, forward_batch: ForwardBatch):
+    def init_mha_chunk_metadata(
+        self, forward_batch: ForwardBatch, disable_flashinfer_ragged: bool = False
+    ):
         """Init the metadata for a forward pass."""
-        self.mha_chunk_kv_cache.update_wrapper(forward_batch)
+        self.mha_chunk_kv_cache.update_wrapper(forward_batch, disable_flashinfer_ragged)

     def forward_extend(
         self,
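For context: the new `disable_flashinfer_ragged` flag lets a subclass keep the chunked-KV
paged wrappers while skipping the FlashInfer ragged prefill plan (the TRT-LLM backend below
serves ragged prefill with its own kernel). A minimal sketch of the opt-out pattern, assuming
a hypothetical subclass name:

    from sglang.srt.layers.attention.flashinfer_mla_backend import (
        FlashInferMLAAttnBackend,
    )

    class MyCustomMLABackend(FlashInferMLAAttnBackend):  # hypothetical example
        def init_mha_chunk_metadata(self, forward_batch):
            # Plan only the paged (prefix-chunk) wrappers; the ragged prefill
            # begin_forward() is skipped because a custom kernel handles it.
            super().init_mha_chunk_metadata(
                forward_batch, disable_flashinfer_ragged=True
            )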
diff --git a/python/sglang/srt/layers/attention/trtllm_mla_backend.py b/python/sglang/srt/layers/attention/trtllm_mla_backend.py
index e37071697..408a66257 100755
--- a/python/sglang/srt/layers/attention/trtllm_mla_backend.py
+++ b/python/sglang/srt/layers/attention/trtllm_mla_backend.py
@@ -45,6 +45,15 @@ TRTLLM_BLOCK_CONSTRAINT = 128

 global_zero_init_workspace_buffer = None


+@dataclass
+class TRTLLMMLAPrefillMetadata:
+    """Metadata for TRTLLM MLA prefill operations."""
+
+    max_seq_len: int
+    cum_seq_lens: torch.Tensor
+    seq_lens: torch.Tensor
+
+
 @dataclass
 class TRTLLMMLADecodeMetadata:
     """Metadata for TRTLLM MLA decode operations."""
@@ -101,7 +110,8 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
         # CUDA graph state
         self.decode_cuda_graph_metadata = {}
         self.decode_cuda_graph_kv_indices = None
-        self.forward_metadata: Union[TRTLLMMLADecodeMetadata, None] = None
+        self.forward_prefill_metadata: Optional[TRTLLMMLAPrefillMetadata] = None
+        self.forward_decode_metadata: Optional[TRTLLMMLADecodeMetadata] = None

     def _calc_padded_blocks(self, max_seq_len: int) -> int:
         """
@@ -235,7 +245,7 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
             max_seq_len_val,
         )
         self.decode_cuda_graph_metadata[bs] = metadata
-        self.forward_metadata = metadata
+        self.forward_decode_metadata = metadata

     def init_forward_metadata_replay_cuda_graph(
         self,
@@ -291,31 +301,52 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
     def init_forward_metadata(self, forward_batch: ForwardBatch):
         """Initialize the metadata for a forward pass."""
-        # Delegate to parent for non-decode modes.
-        if not forward_batch.forward_mode.is_decode_or_idle():
+        # Handle normal extend and decode here; delegate other modes to the parent.
+        if (
+            forward_batch.forward_mode.is_extend()
+            and not forward_batch.forward_mode.is_target_verify()
+            and not forward_batch.forward_mode.is_draft_extend()
+        ):
+            seq_lens = forward_batch.seq_lens - forward_batch.extend_prefix_lens
+            cum_seq_lens_q = torch.cat(
+                (
+                    torch.tensor([0], device=forward_batch.seq_lens.device),
+                    torch.cumsum(seq_lens, dim=0),
+                )
+            ).int()
+            max_seq_len = max(forward_batch.extend_seq_lens_cpu)
+            self.forward_prefill_metadata = TRTLLMMLAPrefillMetadata(
+                max_seq_len,
+                cum_seq_lens_q,
+                seq_lens,
+            )
+        elif forward_batch.forward_mode.is_decode_or_idle():
+            bs = forward_batch.batch_size
+
+            # Get maximum sequence length.
+            if getattr(forward_batch, "seq_lens_cpu", None) is not None:
+                max_seq = forward_batch.seq_lens_cpu.max().item()
+            else:
+                max_seq = forward_batch.seq_lens.max().item()
+
+            max_seqlen_pad = self._calc_padded_blocks(max_seq)
+            block_kv_indices = self._create_block_kv_indices(
+                bs,
+                max_seqlen_pad,
+                forward_batch.req_pool_indices,
+                forward_batch.seq_lens,
+                forward_batch.seq_lens.device,
+            )
+
+            max_seq_len_val = int(max_seq)
+            self.forward_decode_metadata = TRTLLMMLADecodeMetadata(
+                self.workspace_buffer, block_kv_indices, max_seq_len_val
+            )
+            forward_batch.decode_trtllm_mla_metadata = self.forward_decode_metadata
+        else:
             return super().init_forward_metadata(forward_batch)

-        bs = forward_batch.batch_size
-
-        # Get maximum sequence length.
-        if getattr(forward_batch, "seq_lens_cpu", None) is not None:
-            max_seq = forward_batch.seq_lens_cpu.max().item()
-        else:
-            max_seq = forward_batch.seq_lens.max().item()
-
-        max_seqlen_pad = self._calc_padded_blocks(max_seq)
-        block_kv_indices = self._create_block_kv_indices(
-            bs,
-            max_seqlen_pad,
-            forward_batch.req_pool_indices,
-            forward_batch.seq_lens,
-            forward_batch.seq_lens.device,
-        )
-
-        max_seq_len_val = int(max_seq)
-        self.forward_metadata = TRTLLMMLADecodeMetadata(
-            self.workspace_buffer, block_kv_indices, max_seq_len_val
-        )
-        forward_batch.decode_trtllm_mla_metadata = self.forward_metadata
+    def init_mha_chunk_metadata(self, forward_batch: ForwardBatch):
+        super().init_mha_chunk_metadata(forward_batch, disable_flashinfer_ragged=True)

     def quantize_and_rope_for_fp8(
         self,
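To make the new prefill metadata concrete, here is a standalone sketch of the
cum_seq_lens_q computation above (plain torch, illustrative values):

    import torch

    # New tokens per request in the extend batch (seq_lens - extend_prefix_lens).
    seq_lens = torch.tensor([3, 5, 2])
    cum_seq_lens_q = torch.cat(
        (torch.tensor([0]), torch.cumsum(seq_lens, dim=0))
    ).int()
    print(cum_seq_lens_q)  # tensor([ 0,  3,  8, 10], dtype=torch.int32)
    max_seq_len = int(seq_lens.max())  # 5, mirroring max(extend_seq_lens_cpu)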
@@ -459,7 +490,7 @@
         # Get metadata
         metadata = (
             getattr(forward_batch, "decode_trtllm_mla_metadata", None)
-            or self.forward_metadata
+            or self.forward_decode_metadata
         )

         # Scale computation for TRTLLM MLA kernel BMM1 operation:
@@ -496,6 +527,55 @@
         output = raw_out.view(-1, layer.tp_q_head_num * layer.v_head_dim)
         return output

+    def forward_extend(
+        self,
+        q: torch.Tensor,
+        k: torch.Tensor,
+        v: torch.Tensor,
+        layer: RadixAttention,
+        forward_batch: ForwardBatch,
+        save_kv_cache: bool = True,
+        q_rope: Optional[torch.Tensor] = None,
+        k_rope: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        if (
+            forward_batch.forward_mode.is_target_verify()
+            or forward_batch.forward_mode.is_draft_extend()
+        ):
+            return super().forward_extend(
+                q, k, v, layer, forward_batch, save_kv_cache, q_rope, k_rope
+            )
+
+        if not forward_batch.attn_attend_prefix_cache:
+            q = q.view(-1, layer.tp_q_head_num, layer.head_dim)
+            k = k.view(-1, layer.tp_k_head_num, layer.head_dim)
+            v = v.view(-1, layer.tp_k_head_num, layer.v_head_dim)
+            output = flashinfer.prefill.trtllm_ragged_attention_deepseek(
+                query=q,
+                key=k,
+                value=v,
+                workspace_buffer=self.workspace_buffer,
+                seq_lens=self.forward_prefill_metadata.seq_lens,
+                max_q_len=self.forward_prefill_metadata.max_seq_len,
+                max_kv_len=self.forward_prefill_metadata.max_seq_len,
+                bmm1_scale=layer.scaling,
+                bmm2_scale=1.0,
+                o_sf_scale=1.0,
+                batch_size=forward_batch.batch_size,
+                window_left=-1,
+                cum_seq_lens_q=self.forward_prefill_metadata.cum_seq_lens,
+                cum_seq_lens_kv=self.forward_prefill_metadata.cum_seq_lens,
+                enable_pdl=False,
+                is_causal=True,
+                return_lse=forward_batch.mha_return_lse,
+            )
+        else:
+            # TODO: replace with trtllm ragged attention once accuracy is resolved.
+            output = super().forward_extend(
+                q, k, v, layer, forward_batch, save_kv_cache, q_rope, k_rope
+            )
+        return output
+

 class TRTLLMMLAMultiStepDraftBackend(FlashInferMLAMultiStepDraftBackend):
     """Multi-step draft backend for TRT-LLM MLA used by EAGLE."""
diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py
index 32726d11b..1a56e87c6 100644
--- a/python/sglang/srt/models/deepseek_v2.py
+++ b/python/sglang/srt/models/deepseek_v2.py
@@ -1050,7 +1050,6 @@ class DeepseekV2AttentionMLA(nn.Module):
             attention_backend == "flashinfer"
             or attention_backend == "fa3"
             or attention_backend == "flashmla"
-            or attention_backend == "trtllm_mla"
             or attention_backend == "cutlass_mla"
         ):
             # Use MHA with chunked KV cache when prefilling on long sequences.
@@ -1079,6 +1078,15 @@ class DeepseekV2AttentionMLA(nn.Module):
                 return AttnForwardMethod.MHA_CHUNKED_KV
             else:
                 return _dispatch_mla_subtype()
+        elif attention_backend == "trtllm_mla":
+            if (
+                forward_batch.forward_mode.is_extend()
+                and not forward_batch.forward_mode.is_target_verify()
+                and not forward_batch.forward_mode.is_draft_extend()
+            ):
+                return AttnForwardMethod.MHA_CHUNKED_KV
+            else:
+                return _dispatch_mla_subtype()
         elif attention_backend == "aiter":
             if (
                 forward_batch.forward_mode.is_extend()
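The deepseek_v2.py branch above and the backend's init_forward_metadata/forward_extend
guards all apply the same mode predicate. Restated as a standalone helper (illustrative
only; `mode` stands for forward_batch.forward_mode):

    def uses_trtllm_ragged_prefill(mode) -> bool:
        # Normal prefill takes the new TRT-LLM ragged path (via MHA_CHUNKED_KV);
        # target-verify and draft-extend stay on the MLA decode-style path.
        return (
            mode.is_extend()
            and not mode.is_target_verify()
            and not mode.is_draft_extend()
        )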
diff --git a/python/sglang/test/attention/test_trtllm_mla_backend.py b/python/sglang/test/attention/test_trtllm_mla_backend.py
index b2017066b..6f610baf0 100755
--- a/python/sglang/test/attention/test_trtllm_mla_backend.py
+++ b/python/sglang/test/attention/test_trtllm_mla_backend.py
@@ -41,6 +41,10 @@ DEFAULT_CONFIG = {
     "v_head_dim": 512,
     "num_kv_heads": 1,
     "layer_id": 0,
+    "tp_q_head_num": 128,
+    "tp_k_head_num": 128,
+    "prefill_head_dim": 192,
+    "prefill_v_head_dim": 128,
 }

 ROPE_BASE = 10000
@@ -92,7 +96,7 @@ TEST_CASES = {
             "description": "Medium-scale batch",
         },
     ],
-    "decode_output_match": [
+    "output_match": [
        {
             "name": "single_fp16",
             "batch_size": 1,
@@ -322,7 +326,7 @@ class TestTRTLLMMLA(CustomTestCase):
         config.update(test_case)
         return config

-    def _create_model_components(self, config):
+    def _create_model_components(self, config, is_prefill=False):
         """Create model runners, backends, and layer for testing."""
         # Create model runners
         model_runner_trtllm = MockModelRunner(config)
@@ -332,14 +336,23 @@ class TestTRTLLMMLA(CustomTestCase):
         model_runner_reference = MockModelRunner(config)
         trtllm_backend = TRTLLMMLABackend(model_runner_trtllm)
         reference_backend = FlashInferMLAAttnBackend(model_runner_reference)

+        head_dim = (
+            config["kv_lora_rank"] + config["qk_rope_head_dim"]
+            if not is_prefill
+            else config["prefill_head_dim"]
+        )
+        v_head_dim = (
+            config["v_head_dim"] if not is_prefill else config["prefill_v_head_dim"]
+        )
+
         # Create RadixAttention layer
         layer = RadixAttention(
             num_heads=config["num_attention_heads"],
-            head_dim=config["kv_lora_rank"] + config["qk_rope_head_dim"],
+            head_dim=head_dim,
             scaling=model_runner_trtllm.model_config.scaling,
             num_kv_heads=config["num_kv_heads"],
             layer_id=config["layer_id"],
-            v_head_dim=config["v_head_dim"],
+            v_head_dim=v_head_dim,
             prefix="attn_mqa",
         )
@@ -524,7 +537,7 @@ class TestTRTLLMMLA(CustomTestCase):
         """Test that TRTLLM and FlashInfer MLA backends produce matching outputs."""
         print(f"\nRunning decode output matching tests...")

-        for test_case in TEST_CASES["decode_output_match"]:
+        for test_case in TEST_CASES["output_match"]:
             with self.subTest(test_case=test_case["name"]):
                 print(f"  Testing {test_case['name']}: {test_case['description']}")
@@ -1099,6 +1112,157 @@ class TestTRTLLMMLA(CustomTestCase):
         self.assertIsNotNone(metadata_3.block_kv_indices)
         self.assertEqual(metadata_3.block_kv_indices.shape[0], config["batch_size"])

+    def test_prefill_output_match_self_attention(self):
+        """Test prefill (forward) behavior of TRTLLM MLA backend vs reference."""
+        print("\nRunning prefill output tests...")
+
+        for test_case in TEST_CASES["output_match"][:2]:  # Just a subset for speed
+            with self.subTest(test_case=test_case["name"]):
+                print(
+                    f"  Prefill testing {test_case['name']}: {test_case['description']}"
+                )
+
+                config = self._merge_config(test_case)
+                batch_size = config["batch_size"]
+                max_seq_len = config["max_seq_len"]
+
+                # Create components
+                (
+                    model_runner_trtllm,
+                    model_runner_reference,
+                    trtllm_backend,
+                    reference_backend,
+                    layer,
+                ) = self._create_model_components(config, is_prefill=True)
+
+                # Prefill uses full sequences
+                seq_lens = torch.full(
+                    (batch_size,), max_seq_len, device=config["device"]
+                )
+
+                def _create_forward_batch_prefill(
+                    batch_size,
+                    seq_lens,
+                    extend_prefix_lens,
+                    backend,
+                    model_runner,
+                    config,
+                ):
+                    """Create a prefill forward batch for the given backend."""
+
+                    fb = ForwardBatch(
+                        batch_size=batch_size,
+                        input_ids=torch.randint(
+                            0, 100, (batch_size, 1), device=config["device"]
+                        ),
+                        out_cache_loc=torch.arange(batch_size, device=config["device"]),
+                        seq_lens_sum=int(seq_lens.sum().item()),
+                        extend_prefix_lens=extend_prefix_lens,
+                        extend_prefix_lens_cpu=extend_prefix_lens.cpu().int().tolist(),
+                        extend_seq_lens_cpu=(seq_lens - extend_prefix_lens)
+                        .cpu()
+                        .int()
+                        .tolist(),
+                        forward_mode=ForwardMode.EXTEND,
+                        req_pool_indices=torch.arange(
+                            batch_size, device=config["device"]
+                        ),
+                        seq_lens=seq_lens,
+                        seq_lens_cpu=seq_lens.cpu(),
+                        attn_attend_prefix_cache=False,
+                        mha_return_lse=False,
+                        attn_backend=backend,
+                    )
+                    fb.req_to_token_pool = model_runner.req_to_token_pool
+                    fb.token_to_kv_pool = model_runner.token_to_kv_pool
+
+                    # Add position information for RoPE
+                    fb.positions = torch.arange(batch_size, device=config["device"])
+
+                    return fb
+
+                # Create forward batches
+                fb_trtllm = _create_forward_batch_prefill(
+                    batch_size,
+                    seq_lens.clone(),
+                    torch.zeros(batch_size, device=config["device"], dtype=torch.int32),
+                    trtllm_backend,
+                    model_runner_trtllm,
+                    config,
+                )
+                fb_reference = _create_forward_batch_prefill(
+                    batch_size,
+                    seq_lens.clone(),
+                    torch.zeros(batch_size, device=config["device"], dtype=torch.int32),
+                    reference_backend,
+                    model_runner_reference,
+                    config,
+                )
+
+                # Initialize metadata for both backends
+                trtllm_backend.init_forward_metadata(fb_trtllm)
+                reference_backend.init_forward_metadata(fb_reference)
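+                # Note: with is_prefill=True the layer uses prefill_head_dim (192,
+                # i.e. qk_nope_head_dim 128 + qk_rope_head_dim 64 in DeepSeek-style
+                # configs) and prefill_v_head_dim (128), not the MQA 576/512 dims.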
+                # Create Q, K, V tensors for prefill
+                torch.manual_seed(config["seed_qkv"])
+
+                def _create_qkv_tensors_prefill(
+                    batch_size, seq_len, config, dtype_override=None
+                ):
+                    """Create Q, K, V tensors for prefill, using config for head_num and head_dim."""
+                    device = config["device"]
+                    dtype = dtype_override or config["dtype"]
+
+                    total_tokens = batch_size * seq_len
+
+                    tp_q_head_num = config["tp_q_head_num"]
+                    tp_k_head_num = config["tp_k_head_num"]
+                    head_dim = config["prefill_head_dim"]
+                    v_head_dim = config["prefill_v_head_dim"]
+
+                    q = torch.randn(
+                        (total_tokens, tp_q_head_num * head_dim),
+                        dtype=dtype,
+                        device=device,
+                    )
+                    k = torch.randn(
+                        (total_tokens, tp_k_head_num * head_dim),
+                        dtype=dtype,
+                        device=device,
+                    )
+                    v = torch.randn(
+                        (total_tokens, tp_k_head_num * v_head_dim),
+                        dtype=dtype,
+                        device=device,
+                    )
+
+                    # Reshape to the (tokens, heads, head_dim) layout expected by forward_extend.
+                    q = q.view(-1, tp_q_head_num, head_dim)
+                    k = k.view(-1, tp_k_head_num, head_dim)
+                    v = v.view(-1, tp_k_head_num, v_head_dim)
+
+                    return q, k, v
+
+                q, k, v = _create_qkv_tensors_prefill(batch_size, max_seq_len, config)
+                # Run prefill on both backends
+                out_trtllm = trtllm_backend.forward_extend(
+                    q, k, v, layer, fb_trtllm, False
+                ).view(-1, layer.tp_q_head_num * layer.v_head_dim)
+                out_reference = reference_backend.forward_extend(
+                    q, k, v, layer, fb_reference, False
+                )
+
+                tolerance = config.get("tolerance", 1e-2)
+                comparison_passed = compare_outputs(
+                    out_trtllm, out_reference, tolerance=tolerance
+                )
+                self.assertTrue(
+                    comparison_passed,
+                    f"TRTLLM and Reference prefill outputs differ beyond tolerance. "
+                    f"Config: {test_case['name']}, "
+                    f"Max diff: {(out_trtllm - out_reference).abs().max().item()}",
+                )
+

 if __name__ == "__main__":
     unittest.main()
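To exercise the new prefill path end to end, the new test can be run directly through
unittest (module path as in this diff; assumes the sglang package is importable and a
CUDA device with FlashInfer's TRT-LLM kernels is available):

    python -m unittest \
        sglang.test.attention.test_trtllm_mla_backend.TestTRTLLMMLA.test_prefill_output_match_self_attention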