From b1bb8e7490d19b02ae7391cf8162fd4ce0976867 Mon Sep 17 00:00:00 2001
From: pranavm-nvidia <49246958+pranavm-nvidia@users.noreply.github.com>
Date: Mon, 22 Sep 2025 15:54:00 -0700
Subject: [PATCH] Enables TRT-LLM backend to be used for target_verify (#10281)

Co-authored-by: Pranav Marathe
Co-authored-by: fzyzcjy
---
 .../layers/attention/trtllm_mla_backend.py | 200 ++++++++++++------
 1 file changed, 134 insertions(+), 66 deletions(-)

diff --git a/python/sglang/srt/layers/attention/trtllm_mla_backend.py b/python/sglang/srt/layers/attention/trtllm_mla_backend.py
index 185764ad7..3613afd17 100755
--- a/python/sglang/srt/layers/attention/trtllm_mla_backend.py
+++ b/python/sglang/srt/layers/attention/trtllm_mla_backend.py
@@ -127,6 +127,8 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
             "disable_chunked_prefix_cache"
         ]
 
+        self.num_draft_tokens = model_runner.server_args.speculative_num_draft_tokens
+
     def _calc_padded_blocks(self, max_seq_len: int) -> int:
         """
         Calculate padded block count that satisfies both TRT-LLM and Triton constraints.
@@ -217,7 +219,7 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
         """Initialize metadata for CUDA graph capture."""
 
         # Delegate to parent for non-decode modes.
-        if not forward_mode.is_decode_or_idle():
+        if not forward_mode.is_decode_or_idle() and not forward_mode.is_target_verify():
             return super().init_forward_metadata_capture_cuda_graph(
                 bs,
                 num_tokens,
@@ -228,6 +230,9 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
                 spec_info,
             )
 
+        if forward_mode.is_target_verify():
+            seq_lens = seq_lens + self.num_draft_tokens
+
         # Custom fast-path for decode/idle.
         # Capture with full width so future longer sequences are safe during replay
         max_blocks_per_seq = self._calc_padded_blocks(self.max_context_len)
@@ -270,7 +275,7 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
     ):
         """Replay CUDA graph with new inputs."""
         # Delegate to parent for non-decode modes.
-        if not forward_mode.is_decode_or_idle():
+        if not forward_mode.is_decode_or_idle() and not forward_mode.is_target_verify():
             return super().init_forward_metadata_replay_cuda_graph(
                 bs,
                 req_pool_indices,
@@ -282,6 +287,10 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
                 seq_lens_cpu,
             )
 
+        if forward_mode.is_target_verify():
+            seq_lens = seq_lens + self.num_draft_tokens
+            del seq_lens_sum  # not handle "num_draft_tokens" but we do not need it
+
         metadata = self.decode_cuda_graph_metadata[bs]
 
         # Update block indices for new sequences.
@@ -332,7 +341,10 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
                 cum_seq_lens_q,
                 seq_lens,
             )
-        elif forward_batch.forward_mode.is_decode_or_idle():
+        elif (
+            forward_batch.forward_mode.is_decode_or_idle()
+            or forward_batch.forward_mode.is_target_verify()
+        ):
             bs = forward_batch.batch_size
 
             # Get maximum sequence length.
@@ -341,13 +353,19 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
             else:
                 max_seq = forward_batch.seq_lens.max().item()
 
+            seq_lens = forward_batch.seq_lens
+
+            if forward_batch.forward_mode.is_target_verify():
+                max_seq = max_seq + self.num_draft_tokens
+                seq_lens = seq_lens + self.num_draft_tokens
+
             max_seqlen_pad = self._calc_padded_blocks(max_seq)
             block_kv_indices = self._create_block_kv_indices(
                 bs,
                 max_seqlen_pad,
                 forward_batch.req_pool_indices,
-                forward_batch.seq_lens,
-                forward_batch.seq_lens.device,
+                seq_lens,
+                seq_lens.device,
             )
 
             max_seq_len_val = int(max_seq)
@@ -553,84 +571,134 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
         save_kv_cache: bool = True,
         q_rope: Optional[torch.Tensor] = None,
         k_rope: Optional[torch.Tensor] = None,
-    ):
-        if (
-            forward_batch.forward_mode.is_target_verify()
-            or forward_batch.forward_mode.is_draft_extend()
-        ):
-            return super().forward_extend(
-                q, k, v, layer, forward_batch, save_kv_cache, q_rope, k_rope
-            )
-
-        # chunked prefix cache is not enabled, use Flashinfer MLA prefill kernel
-        if forward_batch.attn_attend_prefix_cache is None:
+    ) -> torch.Tensor:
+        if forward_batch.forward_mode.is_draft_extend():
             return super().forward_extend(
                 q, k, v, layer, forward_batch, save_kv_cache, q_rope, k_rope
             )
 
-        if not forward_batch.attn_attend_prefix_cache:
-            q = q.view(-1, layer.tp_q_head_num, layer.head_dim)
-            k = k.view(-1, layer.tp_k_head_num, layer.head_dim)
-            v = v.view(-1, layer.tp_k_head_num, layer.v_head_dim)
-            output = flashinfer.prefill.trtllm_ragged_attention_deepseek(
+        # Save KV cache if requested
+        if save_kv_cache:
+            assert (
+                k is not None and k_rope is not None
+            ), "For populating trtllm_mla kv cache, both k_nope and k_rope should be not None."
+            forward_batch.token_to_kv_pool.set_mla_kv_buffer(
+                layer, forward_batch.out_cache_loc, k, k_rope
+            )
+
+        if q_rope is not None:
+            q = q.view(-1, layer.tp_q_head_num, layer.v_head_dim)
+            q_rope = q_rope.view(
+                -1, layer.tp_q_head_num, layer.head_dim - layer.v_head_dim
+            )
+            q = torch.cat([q, q_rope], dim=-1)
+
+        q = q.view(-1, layer.tp_q_head_num, layer.head_dim)
+
+        if k_rope is not None:
+            k = torch.cat([k, k_rope], dim=-1)
+        k = k.view(-1, layer.tp_k_head_num, layer.head_dim)
+
+        v = v.view(-1, layer.tp_k_head_num, layer.v_head_dim)
+
+        if forward_batch.forward_mode.is_target_verify():
+            metadata = (
+                getattr(forward_batch, "decode_trtllm_mla_metadata", None)
+                or self.forward_decode_metadata
+            )
+
+            # Ensure query has shape [bs, num_draft_tokens, num_q_heads, head_dim]
+            bs = forward_batch.batch_size
+            q = q.view(bs, -1, layer.tp_q_head_num, layer.head_dim)
+
+            k_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
+            kv_cache = k_cache.view(-1, self.page_size, self.kv_cache_dim).unsqueeze(1)
+
+            q_scale = 1.0
+            k_scale = (
+                layer.k_scale_float
+                if getattr(layer, "k_scale_float", None) is not None
+                else 1.0
+            )
+
+            bmm1_scale = q_scale * k_scale * layer.scaling
+
+            seq_lens = (
+                forward_batch.seq_lens.to(torch.int32)
+                + forward_batch.spec_info.draft_token_num
+            )
+            max_seq_len = metadata.max_seq_len + forward_batch.spec_info.draft_token_num
+
+            # TODO may use `mla_rope_quantize_fp8` fusion
+            q = q.to(self.data_type)
+            assert kv_cache.dtype == self.data_type
+
+            raw_out = flashinfer.decode.trtllm_batch_decode_with_kv_cache_mla(
+                query=q,
+                kv_cache=kv_cache,
+                workspace_buffer=self.workspace_buffer,
+                qk_nope_head_dim=self.qk_nope_head_dim,
+                kv_lora_rank=self.kv_lora_rank,
+                qk_rope_head_dim=self.qk_rope_head_dim,
+                block_tables=metadata.block_kv_indices,
+                seq_lens=seq_lens,
+                max_seq_len=max_seq_len,
+                bmm1_scale=bmm1_scale,
+            )
+
+            # Reshape output directly without slicing
+            output = raw_out.view(-1, layer.tp_q_head_num * layer.v_head_dim)
+            return output
+
+        if forward_batch.attn_attend_prefix_cache:
+            # MHA for chunked prefix kv cache when running model with MLA
+            assert forward_batch.prefix_chunk_idx is not None
+            assert forward_batch.prefix_chunk_cu_seq_lens is not None
+            assert q_rope is None
+            assert k_rope is None
+            chunk_idx = forward_batch.prefix_chunk_idx
+
+            output_shape = (q.shape[0], layer.tp_q_head_num, layer.v_head_dim)
+            return flashinfer.prefill.trtllm_ragged_attention_deepseek(
                 query=q,
                 key=k,
                 value=v,
                 workspace_buffer=self.workspace_buffer,
-                seq_lens=self.forward_prefill_metadata.seq_lens,
+                seq_lens=forward_batch.prefix_chunk_seq_lens[chunk_idx],
                 max_q_len=self.forward_prefill_metadata.max_seq_len,
-                max_kv_len=self.forward_prefill_metadata.max_seq_len,
+                max_kv_len=forward_batch.prefix_chunk_max_seq_lens[chunk_idx],
                 bmm1_scale=layer.scaling,
                 bmm2_scale=1.0,
-                o_sf_scale=1.0,
+                o_sf_scale=-1.0,
                 batch_size=forward_batch.batch_size,
                 window_left=-1,
                 cum_seq_lens_q=self.forward_prefill_metadata.cum_seq_lens,
-                cum_seq_lens_kv=self.forward_prefill_metadata.cum_seq_lens,
+                cum_seq_lens_kv=forward_batch.prefix_chunk_cu_seq_lens[chunk_idx],
                 enable_pdl=False,
-                is_causal=True,
-                return_lse=forward_batch.mha_return_lse,
+                is_causal=False,
+                return_lse=True,
+                out=torch.zeros(*output_shape, dtype=q.dtype, device=q.device),
             )
-        else:
-            if not (
-                forward_batch.attn_attend_prefix_cache is not None
-                and forward_batch.mha_return_lse
-            ):
-                output = super().forward_extend(
-                    q, k, v, layer, forward_batch, save_kv_cache, q_rope, k_rope
-                )
-            else:
-                # MHA for chunked prefix kv cache when running model with MLA
-                assert forward_batch.prefix_chunk_idx is not None
-                assert forward_batch.prefix_chunk_cu_seq_lens is not None
-                assert q_rope is None
-                assert k_rope is None
-                chunk_idx = forward_batch.prefix_chunk_idx
-                q = q.view(-1, layer.tp_q_head_num, layer.head_dim)
-                k = k.view(-1, layer.tp_k_head_num, layer.head_dim).to(q.dtype)
-                v = v.view(-1, layer.tp_k_head_num, layer.v_head_dim).to(q.dtype)
-                output_shape = (q.shape[0], layer.tp_q_head_num, layer.v_head_dim)
-                output = flashinfer.prefill.trtllm_ragged_attention_deepseek(
-                    query=q,
-                    key=k,
-                    value=v,
-                    workspace_buffer=self.workspace_buffer,
-                    seq_lens=forward_batch.prefix_chunk_seq_lens[chunk_idx],
-                    max_q_len=self.forward_prefill_metadata.max_seq_len,
-                    max_kv_len=forward_batch.prefix_chunk_max_seq_lens[chunk_idx],
-                    bmm1_scale=layer.scaling,
-                    bmm2_scale=1.0,
-                    o_sf_scale=-1.0,
-                    batch_size=forward_batch.batch_size,
-                    window_left=-1,
-                    cum_seq_lens_q=self.forward_prefill_metadata.cum_seq_lens,
-                    cum_seq_lens_kv=forward_batch.prefix_chunk_cu_seq_lens[chunk_idx],
-                    enable_pdl=False,
-                    is_causal=False,
-                    return_lse=True,
-                    out=torch.zeros(*output_shape, dtype=q.dtype, device=q.device),
-                )
-        return output
+
+        return flashinfer.prefill.trtllm_ragged_attention_deepseek(
+            query=q,
+            key=k,
+            value=v,
+            workspace_buffer=self.workspace_buffer,
+            seq_lens=self.forward_prefill_metadata.seq_lens,
+            max_q_len=self.forward_prefill_metadata.max_seq_len,
+            max_kv_len=self.forward_prefill_metadata.max_seq_len,
+            bmm1_scale=layer.scaling,
+            bmm2_scale=1.0,
            o_sf_scale=1.0,
            batch_size=forward_batch.batch_size,
            window_left=-1,
            cum_seq_lens_q=self.forward_prefill_metadata.cum_seq_lens,
            cum_seq_lens_kv=self.forward_prefill_metadata.cum_seq_lens,
            enable_pdl=False,
            is_causal=True,
            return_lse=forward_batch.mha_return_lse,
        )
 
 
 class TRTLLMMLAMultiStepDraftBackend(FlashInferMLAMultiStepDraftBackend):
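
The bookkeeping this patch adds is small: in target_verify mode the KV cache already holds the speculative draft tokens, so every sequence length used to size block tables and max_seq_len is shifted by speculative_num_draft_tokens before the TRT-LLM decode kernel is invoked. The snippet below is a minimal standalone sketch of that sizing step; PAGE_SIZE, NUM_DRAFT_TOKENS, and BLOCK_ALIGN are illustrative placeholders rather than the backend's real attributes or its actual TRT-LLM/Triton alignment rule.

    import torch

    PAGE_SIZE = 64        # stand-in for self.page_size (tokens per KV-cache block)
    NUM_DRAFT_TOKENS = 4  # stand-in for speculative_num_draft_tokens
    BLOCK_ALIGN = 4       # illustrative alignment constant, not the real constraint


    def calc_padded_blocks(max_seq_len: int) -> int:
        """Blocks needed to hold max_seq_len tokens, rounded up to BLOCK_ALIGN."""
        blocks = -(-max_seq_len // PAGE_SIZE)  # ceil(max_seq_len / PAGE_SIZE)
        return -(-blocks // BLOCK_ALIGN) * BLOCK_ALIGN


    seq_lens = torch.tensor([100, 37, 256], dtype=torch.int32)

    # Decode: block tables are sized from the committed sequence lengths only.
    decode_blocks = calc_padded_blocks(int(seq_lens.max()))

    # Target verify: draft tokens are already in the KV cache, so both the
    # per-request lengths and the maximum are shifted before sizing.
    verify_seq_lens = seq_lens + NUM_DRAFT_TOKENS
    verify_blocks = calc_padded_blocks(int(verify_seq_lens.max()))

    print(decode_blocks, verify_blocks)  # 4 8 with the placeholder values above

The same shift is what the CUDA-graph capture and replay paths apply to seq_lens before building block_kv_indices, which is why graphs captured for target_verify remain valid when replayed with longer effective sequences.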