From d6e1d28c8aff0ba6cca1d33985d1e1663a734f6f Mon Sep 17 00:00:00 2001
From: fzyzcjy <5236035+fzyzcjy@users.noreply.github.com>
Date: Wed, 21 May 2025 17:03:39 +0800
Subject: [PATCH] Refactor DeepSeek attention dispatching (#6476)

---
 python/sglang/srt/models/deepseek_v2.py | 46 +++++++++++++++----------
 1 file changed, 27 insertions(+), 19 deletions(-)

diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py
index 0974102e3..b90fadec4 100644
--- a/python/sglang/srt/models/deepseek_v2.py
+++ b/python/sglang/srt/models/deepseek_v2.py
@@ -127,6 +127,9 @@ class AttnForwardMethod(IntEnum):
     # This method can avoid OOM when prefix lengths are long.
     MHA_CHUNKED_KV = auto()
 
+    # Use MLA but with fused RoPE
+    MLA_FUSED_ROPE = auto()
+
 
 class DeepseekV2MLP(nn.Module):
     def __init__(
@@ -609,6 +612,18 @@ class DeepseekV2AttentionMLA(nn.Module):
     def dispatch_attn_forward_method(
         self, forward_batch: ForwardBatch
     ) -> AttnForwardMethod:
+        def _dispatch_mla_subtype():
+            if _is_hip:
+                if (
+                    self.rocm_fused_decode_mla
+                    and forward_batch.forward_mode.is_decode()
+                ):
+                    return AttnForwardMethod.MLA_FUSED_ROPE
+                else:
+                    return AttnForwardMethod.MLA
+            else:
+                return AttnForwardMethod.MLA
+
         if self.attention_backend == "flashinfer":
             # Flashinfer MLA: Do not absorb when enabling ragged prefill
             if (
@@ -620,7 +635,7 @@ class DeepseekV2AttentionMLA(nn.Module):
             ):
                 return AttnForwardMethod.MHA
             else:
-                return AttnForwardMethod.MLA
+                return _dispatch_mla_subtype()
         elif self.attention_backend == "fa3":
             # Flash Attention: Use MHA with chunked KV cache when prefilling on long sequences.
             if forward_batch.extend_prefix_lens_cpu is not None:
@@ -637,7 +652,7 @@ class DeepseekV2AttentionMLA(nn.Module):
             ):
                 return AttnForwardMethod.MHA_CHUNKED_KV
             else:
-                return AttnForwardMethod.MLA
+                return _dispatch_mla_subtype()
         else:
             # Triton: Use normal computation for prefill and use weight absorption for extend/decode
             if (
@@ -648,7 +663,7 @@ class DeepseekV2AttentionMLA(nn.Module):
             ):
                 return AttnForwardMethod.MHA
             else:
-                return AttnForwardMethod.MLA
+                return _dispatch_mla_subtype()
 
     def forward(
         self,
@@ -671,23 +686,16 @@ class DeepseekV2AttentionMLA(nn.Module):
             return self.forward_normal_chunked_kv(
                 positions, hidden_states, forward_batch
             )
+        elif attn_forward_method == AttnForwardMethod.MLA:
+            return self.forward_absorb(
+                positions, hidden_states, forward_batch, zero_allocator
+            )
+        elif attn_forward_method == AttnForwardMethod.MLA_FUSED_ROPE:
+            return self.forward_absorb_fused_mla_rope(
+                positions, hidden_states, forward_batch
+            )
         else:
-            if _is_hip:
-                if (
-                    self.rocm_fused_decode_mla
-                    and forward_batch.forward_mode.is_decode()
-                ):
-                    return self.forward_absorb_fused_mla_rope(
-                        positions, hidden_states, forward_batch
-                    )
-                else:
-                    return self.forward_absorb(
-                        positions, hidden_states, forward_batch, zero_allocator
-                    )
-            else:
-                return self.forward_absorb(
-                    positions, hidden_states, forward_batch, zero_allocator
-                )
+            raise NotImplementedError
 
     def forward_normal(
         self,
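
For illustration, here is a minimal, self-contained sketch of the dispatch pattern this patch introduces: the HIP/fused-RoPE platform check moves out of forward() into the dispatcher, which returns an explicit MLA_FUSED_ROPE enum member, and forward() becomes a flat branch per member with an explicit NotImplementedError fallback. FakeBatch and FakeAttention are hypothetical stand-ins, not sglang's ForwardBatch or DeepseekV2AttentionMLA; _is_hip and rocm_fused_decode_mla do exist in the real file, but are hard-coded here for the sake of a runnable example.

from enum import IntEnum, auto

_is_hip = True  # assumption: the real code detects ROCm at import time


class AttnForwardMethod(IntEnum):
    MHA = auto()
    MHA_CHUNKED_KV = auto()
    MLA = auto()
    MLA_FUSED_ROPE = auto()  # MLA with fused RoPE, added by this patch


class FakeBatch:
    # Hypothetical stand-in for ForwardBatch / ForwardMode.
    def __init__(self, decode):
        self.decode = decode

    def is_decode(self):
        return self.decode


class FakeAttention:
    rocm_fused_decode_mla = True  # illustrative config flag

    def dispatch_attn_forward_method(self, batch):
        # Same shape as the patch's _dispatch_mla_subtype helper: use the
        # fused-RoPE kernel only for ROCm decode batches, plain MLA otherwise.
        if _is_hip and self.rocm_fused_decode_mla and batch.is_decode():
            return AttnForwardMethod.MLA_FUSED_ROPE
        return AttnForwardMethod.MLA

    def forward(self, batch):
        # Same shape as the refactored forward(): one branch per enum member,
        # with an explicit error for anything unhandled instead of a silent
        # catch-all else.
        method = self.dispatch_attn_forward_method(batch)
        if method == AttnForwardMethod.MLA:
            return "forward_absorb"
        elif method == AttnForwardMethod.MLA_FUSED_ROPE:
            return "forward_absorb_fused_mla_rope"
        else:
            raise NotImplementedError


attn = FakeAttention()
assert attn.forward(FakeBatch(decode=True)) == "forward_absorb_fused_mla_rope"
assert attn.forward(FakeBatch(decode=False)) == "forward_absorb"

The design point, under these assumptions: with the platform logic folded into the dispatcher, forward() maps one enum member to one kernel, and adding a new attention variant means adding an enum member plus one branch rather than nesting another platform check inside forward().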