diff --git a/python/sglang/srt/layers/attention/trtllm_mla_backend.py b/python/sglang/srt/layers/attention/trtllm_mla_backend.py
index ea316150e..185764ad7 100755
--- a/python/sglang/srt/layers/attention/trtllm_mla_backend.py
+++ b/python/sglang/srt/layers/attention/trtllm_mla_backend.py
@@ -553,7 +553,7 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
         save_kv_cache: bool = True,
         q_rope: Optional[torch.Tensor] = None,
         k_rope: Optional[torch.Tensor] = None,
-    ) -> torch.Tensor:
+    ):
         if (
             forward_batch.forward_mode.is_target_verify()
             or forward_batch.forward_mode.is_draft_extend()
@@ -591,10 +591,45 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
                 return_lse=forward_batch.mha_return_lse,
             )
         else:
-            # replace with trtllm ragged attention once accuracy is resolved.
-            output = super().forward_extend(
-                q, k, v, layer, forward_batch, save_kv_cache, q_rope, k_rope
-            )
+            if not (
+                forward_batch.attn_attend_prefix_cache is not None
+                and forward_batch.mha_return_lse
+            ):
+                output = super().forward_extend(
+                    q, k, v, layer, forward_batch, save_kv_cache, q_rope, k_rope
+                )
+            else:
+                # MHA for chunked prefix kv cache when running model with MLA
+                assert forward_batch.prefix_chunk_idx is not None
+                assert forward_batch.prefix_chunk_cu_seq_lens is not None
+                assert q_rope is None
+                assert k_rope is None
+                chunk_idx = forward_batch.prefix_chunk_idx
+
+                q = q.view(-1, layer.tp_q_head_num, layer.head_dim)
+                k = k.view(-1, layer.tp_k_head_num, layer.head_dim).to(q.dtype)
+                v = v.view(-1, layer.tp_k_head_num, layer.v_head_dim).to(q.dtype)
+                output_shape = (q.shape[0], layer.tp_q_head_num, layer.v_head_dim)
+                output = flashinfer.prefill.trtllm_ragged_attention_deepseek(
+                    query=q,
+                    key=k,
+                    value=v,
+                    workspace_buffer=self.workspace_buffer,
+                    seq_lens=forward_batch.prefix_chunk_seq_lens[chunk_idx],
+                    max_q_len=self.forward_prefill_metadata.max_seq_len,
+                    max_kv_len=forward_batch.prefix_chunk_max_seq_lens[chunk_idx],
+                    bmm1_scale=layer.scaling,
+                    bmm2_scale=1.0,
+                    o_sf_scale=-1.0,
+                    batch_size=forward_batch.batch_size,
+                    window_left=-1,
+                    cum_seq_lens_q=self.forward_prefill_metadata.cum_seq_lens,
+                    cum_seq_lens_kv=forward_batch.prefix_chunk_cu_seq_lens[chunk_idx],
+                    enable_pdl=False,
+                    is_causal=False,
+                    return_lse=True,
+                    out=torch.zeros(*output_shape, dtype=q.dtype, device=q.device),
+                )

         return output
