diff --git a/python/sglang/srt/layers/attention/base_attn_backend.py b/python/sglang/srt/layers/attention/base_attn_backend.py index 2ad35914a..52bcd5fba 100644 --- a/python/sglang/srt/layers/attention/base_attn_backend.py +++ b/python/sglang/srt/layers/attention/base_attn_backend.py @@ -62,6 +62,7 @@ class AttentionBackend(ABC): layer: RadixAttention, forward_batch: ForwardBatch, save_kv_cache: bool = True, + **kwargs, ): """Run forward on an attention layer.""" if forward_batch.forward_mode.is_decode(): @@ -72,6 +73,7 @@ class AttentionBackend(ABC): layer, forward_batch, save_kv_cache=save_kv_cache, + **kwargs, ) else: return self.forward_extend( @@ -81,6 +83,7 @@ class AttentionBackend(ABC): layer, forward_batch, save_kv_cache=save_kv_cache, + **kwargs, ) def forward_decode( diff --git a/python/sglang/srt/layers/attention/flashattention_backend.py b/python/sglang/srt/layers/attention/flashattention_backend.py index 4bdb21820..c3f533b27 100644 --- a/python/sglang/srt/layers/attention/flashattention_backend.py +++ b/python/sglang/srt/layers/attention/flashattention_backend.py @@ -623,6 +623,8 @@ class FlashAttentionBackend(AttentionBackend): layer: RadixAttention, forward_batch: ForwardBatch, save_kv_cache=True, + # For multi-head latent attention + q_rope: Optional[torch.Tensor] = None, ): if k is not None: assert v is not None @@ -815,9 +817,15 @@ class FlashAttentionBackend(AttentionBackend): c_kv_cache = c_kv.view( -1, self.page_size, layer.tp_v_head_num, layer.v_head_dim ) - q_all = q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim) - q_nope = q_all[:, :, : layer.v_head_dim] - q_rope = q_all[:, :, layer.v_head_dim :] + if q_rope is not None: + q_nope = q.view(-1, layer.tp_q_head_num, layer.v_head_dim) + q_rope = q_rope.view( + -1, layer.tp_q_head_num, layer.head_dim - layer.v_head_dim + ) + else: + q_all = q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim) + q_nope = q_all[:, :, : layer.v_head_dim] + q_rope = q_all[:, :, layer.v_head_dim :] result = flash_attn_with_kvcache( q=q_rope, @@ -877,6 +885,8 @@ class FlashAttentionBackend(AttentionBackend): layer: RadixAttention, forward_batch: ForwardBatch, save_kv_cache=True, + # For multi-head latent attention + q_rope: Optional[torch.Tensor] = None, ) -> torch.Tensor: if k is not None: assert v is not None @@ -1047,9 +1057,15 @@ class FlashAttentionBackend(AttentionBackend): -1, self.page_size, layer.tp_v_head_num, layer.v_head_dim ) - q_all = q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim) - q_nope = q_all[:, :, : layer.v_head_dim] - q_rope = q_all[:, :, layer.v_head_dim :] + if q_rope is not None: + q_nope = q.view(-1, layer.tp_q_head_num, layer.v_head_dim) + q_rope = q_rope.view( + -1, layer.tp_q_head_num, layer.head_dim - layer.v_head_dim + ) + else: + q_all = q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim) + q_nope = q_all[:, :, : layer.v_head_dim] + q_rope = q_all[:, :, layer.v_head_dim :] max_seqlen_q = metadata.max_seq_len_q result = flash_attn_with_kvcache( diff --git a/python/sglang/srt/layers/radix_attention.py b/python/sglang/srt/layers/radix_attention.py index 3c10a3924..741937a72 100644 --- a/python/sglang/srt/layers/radix_attention.py +++ b/python/sglang/srt/layers/radix_attention.py @@ -87,6 +87,7 @@ class RadixAttention(nn.Module): v, forward_batch: ForwardBatch, save_kv_cache: bool = True, + **kwargs, ): if k is not None: # For cross-layer sharing, kv can be None @@ -95,5 +96,11 @@ class RadixAttention(nn.Module): v = v.view(-1, self.tp_v_head_num, self.v_head_dim) return forward_batch.attn_backend.forward( - q, k, v, self, forward_batch, save_kv_cache + q, + k, + v, + self, + forward_batch, + save_kv_cache, + **kwargs, ) diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py index c5afc559d..8fab2c488 100644 --- a/python/sglang/srt/models/deepseek_v2.py +++ b/python/sglang/srt/models/deepseek_v2.py @@ -751,10 +751,15 @@ class DeepseekV2AttentionMLA(nn.Module): q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe) - q = torch.cat([q_nope_out, q_pe], dim=-1) k = torch.cat([k_nope, k_pe], dim=-1) - attn_output = self.attn_mqa(q, k, k_nope, forward_batch) + if self.attention_backend == "fa3": + attn_output = self.attn_mqa( + q_nope_out, k, k_nope, forward_batch, q_rope=q_pe + ) + else: + q = torch.cat([q_nope_out, q_pe], dim=-1) + attn_output = self.attn_mqa(q, k, k_nope, forward_batch) attn_output = attn_output.view(-1, self.num_local_heads, self.kv_lora_rank) if self.use_deep_gemm_bmm: