From e30c273bc9bbb2ea31268b22b67ba8221e058e2c Mon Sep 17 00:00:00 2001
From: xu-yfei
Date: Fri, 9 May 2025 14:17:14 +0800
Subject: [PATCH] opt flashinfer mla cat (#5822)

Co-authored-by: xuyongfei.xyf
---
 .../attention/flashinfer_mla_backend.py | 72 +++++++++++++++----
 python/sglang/srt/models/deepseek_v2.py |  2 +-
 2 files changed, 60 insertions(+), 14 deletions(-)

diff --git a/python/sglang/srt/layers/attention/flashinfer_mla_backend.py b/python/sglang/srt/layers/attention/flashinfer_mla_backend.py
index 58982f3e8..cd7778418 100644
--- a/python/sglang/srt/layers/attention/flashinfer_mla_backend.py
+++ b/python/sglang/srt/layers/attention/flashinfer_mla_backend.py
@@ -339,22 +339,38 @@ class FlashInferMLAAttnBackend(AttentionBackend):
         layer: RadixAttention,
         forward_batch: ForwardBatch,
         save_kv_cache: bool = True,
+        q_rope: Optional[torch.Tensor] = None,
+        k_rope: Optional[torch.Tensor] = None,
     ):
 
         cache_loc = forward_batch.out_cache_loc
         logits_soft_cap = layer.logit_cap
         prefill_wrapper_paged = self.forward_metadata.prefill_wrapper
-        qall = q.view(-1, layer.tp_q_head_num, layer.head_dim)
         k_buf = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
 
         # Save kv cache
         if save_kv_cache and k is not None:
             assert v is not None
             if save_kv_cache:
-                forward_batch.token_to_kv_pool.set_kv_buffer(layer, cache_loc, k, v)
+                if k_rope is not None:
+                    forward_batch.token_to_kv_pool.set_mla_kv_buffer(
+                        layer, cache_loc, k, k_rope
+                    )
+                else:
+                    forward_batch.token_to_kv_pool.set_kv_buffer(layer, cache_loc, k, v)
+        if q_rope is not None:
+            q = q.view(-1, layer.tp_q_head_num, layer.v_head_dim)
+            q_rope = q_rope.view(
+                -1, layer.tp_q_head_num, layer.head_dim - layer.v_head_dim
+            )
 
         if self.forward_metadata.use_ragged:
             # ragged prefill
+            if q_rope is not None:
+                q = torch.cat([q, q_rope], dim=-1)
+            qall = q.view(-1, layer.tp_q_head_num, layer.head_dim)
+            if k_rope is not None:
+                k = torch.cat([k, k_rope], dim=-1)
             o = self.prefill_wrapper_ragged.forward(
                 qall,
                 k.view(-1, layer.tp_k_head_num, layer.head_dim),
@@ -365,11 +381,19 @@ class FlashInferMLAAttnBackend(AttentionBackend):
             )
         else:
             # mla paged prefill
+            if q_rope is None:
+                qall = q.view(-1, layer.tp_q_head_num, layer.head_dim)
+                q, q_rope = (
+                    qall[:, :, : layer.v_head_dim],
+                    qall[:, :, layer.v_head_dim :],
+                )
+            o = q.new_empty(q.shape)
             o = prefill_wrapper_paged.run(
-                qall[:, :, : layer.v_head_dim],
-                qall[:, :, layer.v_head_dim :],
+                q,
+                q_rope,
                 k_buf[:, :, : layer.v_head_dim],
                 k_buf[:, :, layer.v_head_dim :],
+                out=o,
             )
 
         return o.view(-1, layer.tp_q_head_num * layer.v_head_dim)
@@ -382,6 +406,9 @@ class FlashInferMLAAttnBackend(AttentionBackend):
         layer: RadixAttention,
         forward_batch: ForwardBatch,
         save_kv_cache: bool = True,
+        # For multi-head latent attention
+        q_rope: Optional[torch.Tensor] = None,
+        k_rope: Optional[torch.Tensor] = None,
     ):
         decode_wrapper = self.forward_metadata.decode_wrapper
         cache_loc = forward_batch.out_cache_loc
@@ -389,23 +416,42 @@ class FlashInferMLAAttnBackend(AttentionBackend):
         if k is not None:
             assert v is not None
             if save_kv_cache:
-                forward_batch.token_to_kv_pool.set_kv_buffer(
-                    layer,
-                    cache_loc,
-                    k,
-                    v,
-                )
+                if k_rope is not None:
+                    forward_batch.token_to_kv_pool.set_mla_kv_buffer(
+                        layer,
+                        cache_loc,
+                        k,
+                        k_rope,
+                    )
+                else:
+                    forward_batch.token_to_kv_pool.set_kv_buffer(
+                        layer,
+                        cache_loc,
+                        k,
+                        v,
+                    )
 
         # Reshape inputs
-        reshaped_q = q.view(-1, layer.tp_q_head_num, layer.head_dim)
+        if q_rope is not None:
+            q_nope = q.view(-1, layer.tp_q_head_num, layer.v_head_dim)
+            q_rope = q_rope.view(
+                -1, layer.tp_q_head_num, layer.head_dim - layer.v_head_dim
+            )
+        else:
+            reshaped_q = q.view(-1, layer.tp_q_head_num, layer.head_dim)
+            q_nope = reshaped_q[:, :, : layer.v_head_dim]
+            q_rope = reshaped_q[:, :, layer.v_head_dim :]
+
         k_buffer = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
 
+        o = q_nope.new_empty(q_nope.shape)
         # Direct call to run without the wrapper
         o = decode_wrapper.run(
-            reshaped_q[:, :, : layer.v_head_dim],
-            reshaped_q[:, :, layer.v_head_dim :],
+            q_nope,
+            q_rope,
             k_buffer[:, :, : layer.v_head_dim],
             k_buffer[:, :, layer.v_head_dim :],
+            out=o,
         )
 
         return o.view(-1, layer.tp_q_head_num * layer.v_head_dim)

diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py
index 3ee7a5d76..7763fb18c 100644
--- a/python/sglang/srt/models/deepseek_v2.py
+++ b/python/sglang/srt/models/deepseek_v2.py
@@ -777,7 +777,7 @@ class DeepseekV2AttentionMLA(nn.Module):
         q_nope_out = q_nope_out.transpose(0, 1)
 
         q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe)
-        if self.attention_backend == "fa3":
+        if self.attention_backend == "fa3" or self.attention_backend == "flashinfer":
             attn_output = self.attn_mqa(
                 q_nope_out, k_nope, k_nope, forward_batch, q_rope=q_pe, k_rope=k_pe
             )
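
Note on the core of the change: before this patch, DeepseekV2AttentionMLA concatenated the no-RoPE and RoPE halves of the query with torch.cat, and the flashinfer backend immediately sliced that tensor back apart before calling the paged MLA kernel, so the concat was pure overhead (one extra allocation plus copies per layer). After the patch, the two halves are passed through as q and q_rope and the kernel consumes them directly. A minimal sketch of why this is safe, with purely illustrative shapes (128 heads and 512/64 widths are stand-ins, not values read from the patch):

import torch

num_tokens, num_heads = 16, 128
v_head_dim, rope_dim = 512, 64  # "nope" width and RoPE width (illustrative)

q_nope = torch.randn(num_tokens, num_heads, v_head_dim)
q_rope = torch.randn(num_tokens, num_heads, rope_dim)

# Pre-patch pattern: materialize one head_dim-wide tensor, then slice it
# back apart for the MLA kernel -- an extra allocation plus two copies.
qall = torch.cat([q_nope, q_rope], dim=-1)
q_nope_old, q_rope_old = qall[..., :v_head_dim], qall[..., v_head_dim:]

# Post-patch pattern: hand the two halves to the kernel directly.
# The data the kernel sees is identical, so the concat can be skipped.
assert torch.equal(q_nope_old, q_nope)
assert torch.equal(q_rope_old, q_rope)

The ragged prefill path still concatenates (q = torch.cat([q, q_rope], dim=-1)) because prefill_wrapper_ragged.forward takes a single head_dim-wide query; only the paged prefill and decode paths take the split inputs.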
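The same idea is applied on the cache-write side: set_mla_kv_buffer stores the compressed "nope" part and the RoPE part of the key into the MLA cache in one call, so the caller no longer has to build a concatenated key first. The toy cache below only sketches that layout; ToyMLAKVCache and its shapes are invented for illustration and are not sglang's actual token_to_kv_pool class.

import torch

class ToyMLAKVCache:
    def __init__(self, num_slots: int, kv_lora_rank: int, rope_dim: int):
        # One latent head per token: [slots, 1, kv_lora_rank + rope_dim]
        self.buf = torch.zeros(num_slots, 1, kv_lora_rank + rope_dim)
        self.kv_lora_rank = kv_lora_rank

    def set_mla_kv_buffer(self, loc, k_nope, k_rope):
        # Write each part into its own slice of the cache row; no
        # intermediate concatenated key tensor is ever materialized.
        self.buf[loc, :, : self.kv_lora_rank] = k_nope
        self.buf[loc, :, self.kv_lora_rank :] = k_rope

cache = ToyMLAKVCache(num_slots=64, kv_lora_rank=512, rope_dim=64)
loc = torch.tensor([3, 7])
cache.set_mla_kv_buffer(loc, torch.randn(2, 1, 512), torch.randn(2, 1, 64))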
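The patch also preallocates the attention output with new_empty and passes it to the flashinfer wrapper via out=, so the kernel writes into an existing buffer instead of allocating a fresh result tensor on every call. A hedged stand-in for that pattern using torch.matmul's out= argument (the flashinfer wrapper itself is not invoked here, and the shapes are again illustrative):

import torch

a = torch.randn(16, 128, 512)
b = torch.randn(16, 512, 64)
o = a.new_empty(16, 128, 64)  # same dtype/device as the inputs
torch.matmul(a, b, out=o)     # result lands in the preallocated buffer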