diff --git a/python/sglang/srt/layers/attention/trtllm_mla_backend.py b/python/sglang/srt/layers/attention/trtllm_mla_backend.py
index 97dce19fd..85e535b07 100755
--- a/python/sglang/srt/layers/attention/trtllm_mla_backend.py
+++ b/python/sglang/srt/layers/attention/trtllm_mla_backend.py
@@ -568,12 +568,35 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
         save_kv_cache: bool = True,
         q_rope: Optional[torch.Tensor] = None,
         k_rope: Optional[torch.Tensor] = None,
+        cos_sin_cache: Optional[torch.Tensor] = None,
+        is_neox: Optional[bool] = False,
     ) -> torch.Tensor:
         if forward_batch.forward_mode.is_draft_extend():
             return super().forward_extend(
                 q, k, v, layer, forward_batch, save_kv_cache, q_rope, k_rope
             )
 
+        # TODO refactor to avoid code duplication
+        merge_query = q_rope is not None
+        if (
+            self.data_type == torch.float8_e4m3fn
+        ) and forward_batch.forward_mode.is_target_verify():
+            # For FP8 path, we quantize the query and rope parts and merge them into a single tensor
+            # Note: rope application in deepseek_v2.py:forward_absorb_prepare is skipped for FP8 decode path of this trtllm_mla backend
+            assert all(
+                x is not None for x in [q_rope, k_rope, cos_sin_cache]
+            ), "For FP8 path and using flashinfer.rope.mla_rope_quantize we need all of q_rope, k_rope and cos_sin_cache to be not None."
+            q, k, k_rope = self.quantize_and_rope_for_fp8(
+                q,
+                q_rope,
+                k.squeeze(1),
+                k_rope.squeeze(1),
+                forward_batch,
+                cos_sin_cache,
+                is_neox,
+            )
+            merge_query = False
+
         # Save KV cache if requested
         if save_kv_cache:
             assert (
@@ -583,12 +606,18 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
                 layer, forward_batch.out_cache_loc, k, k_rope
             )
 
-        if q_rope is not None:
-            q = q.view(-1, layer.tp_q_head_num, layer.v_head_dim)
-            q_rope = q_rope.view(
+        # TODO refactor to avoid code duplication
+        # Prepare query tensor inline
+        if merge_query:
+            # For FP16 path, we merge the query and rope parts into a single tensor
+            q_nope = q.view(-1, layer.tp_q_head_num, layer.v_head_dim)
+            q_rope_reshaped = q_rope.view(
                 -1, layer.tp_q_head_num, layer.head_dim - layer.v_head_dim
             )
-            q = _concat_mla_absorb_q_general(q, q_rope)
+            q = _concat_mla_absorb_q_general(q_nope, q_rope_reshaped)
+        else:
+            # For FP8 path, we already have the query and rope parts merged because of the quantize_and_rope_for_fp8 function
+            q = q.view(-1, layer.tp_q_head_num, layer.head_dim)
 
         q = q.view(-1, layer.tp_q_head_num, layer.head_dim)
 
diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py
index 73ff4c1c7..74d207103 100644
--- a/python/sglang/srt/models/deepseek_v2.py
+++ b/python/sglang/srt/models/deepseek_v2.py
@@ -1399,7 +1399,10 @@ class DeepseekV2AttentionMLA(nn.Module):
         """
         return (
             self.current_attention_backend == "trtllm_mla"
-            and forward_batch.forward_mode.is_decode_or_idle()
+            and (
+                forward_batch.forward_mode.is_decode_or_idle()
+                or forward_batch.forward_mode.is_target_verify()
+            )
             and forward_batch.attn_backend.data_type == torch.float8_e4m3fn
         )
 
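For reviewers, the query-preparation control flow that the first hunk introduces can be restated in isolation. This is a minimal illustrative sketch, not the backend code itself: `prepare_query` and the `fp8_merged` flag are hypothetical names, `torch.cat` stands in for `_concat_mla_absorb_q_general`, and the shapes follow the `layer.tp_q_head_num` / `layer.head_dim` / `layer.v_head_dim` conventions used in the diff.

```python
import torch
from typing import Optional


def prepare_query(
    q: torch.Tensor,
    q_rope: Optional[torch.Tensor],
    tp_q_head_num: int,
    head_dim: int,
    v_head_dim: int,
    fp8_merged: bool,
) -> torch.Tensor:
    """Return q shaped (tokens, tp_q_head_num, head_dim)."""
    if fp8_merged:
        # FP8 target-verify path: quantize_and_rope_for_fp8 has already
        # fused the rope part into q, so only a reshape is needed.
        return q.view(-1, tp_q_head_num, head_dim)
    if q_rope is not None:
        # FP16 path: concatenate the no-rope and rope parts along the
        # last dim, mirroring _concat_mla_absorb_q_general.
        q_nope = q.view(-1, tp_q_head_num, v_head_dim)
        q_pe = q_rope.view(-1, tp_q_head_num, head_dim - v_head_dim)
        return torch.cat([q_nope, q_pe], dim=-1)
    return q.view(-1, tp_q_head_num, head_dim)


# Usage with DeepSeek-style MLA dims (v_head_dim=512, rope dim=64):
q = torch.randn(4 * 16, 512)
q_rope = torch.randn(4 * 16, 64)
out = prepare_query(q, q_rope, 16, 576, 512, fp8_merged=False)
assert out.shape == (4, 16, 576)
```

Note that the trailing context line `q = q.view(-1, layer.tp_q_head_num, layer.head_dim)` left unchanged by the hunk is a no-op reshape on both paths, since `q` already has that shape by the time it is reached.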
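The deepseek_v2.py hunk widens the predicate that decides whether rope application is deferred from `forward_absorb_prepare` to the attention backend, so target-verify batches now take the same FP8 path as decode/idle batches. A hedged restatement of that predicate, where `ForwardMode` is a stand-in enum rather than sglang's actual class:

```python
from enum import Enum, auto

import torch


class ForwardMode(Enum):
    # Stand-in for sglang's ForwardMode; only the modes relevant here.
    DECODE = auto()
    IDLE = auto()
    TARGET_VERIFY = auto()
    EXTEND = auto()


def fp8_rope_handled_by_backend(
    backend: str, mode: ForwardMode, kv_dtype: torch.dtype
) -> bool:
    # Rope is skipped in forward_absorb_prepare and applied inside the
    # trtllm_mla backend on the FP8 decode/idle and (now) target-verify paths.
    return (
        backend == "trtllm_mla"
        and mode in (ForwardMode.DECODE, ForwardMode.IDLE, ForwardMode.TARGET_VERIFY)
        and kv_dtype == torch.float8_e4m3fn
    )


assert fp8_rope_handled_by_backend(
    "trtllm_mla", ForwardMode.TARGET_VERIFY, torch.float8_e4m3fn
)
assert not fp8_rope_handled_by_backend(
    "trtllm_mla", ForwardMode.EXTEND, torch.float8_e4m3fn
)
```

Keeping this predicate in one place matters because both sides must agree: if the model skips rope here but the backend does not apply it (or vice versa), queries and keys would be attended without positional encoding on one side.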