Refactor DeepSeek attention dispatching (#6476)

This commit is contained in:
fzyzcjy
2025-05-21 17:03:39 +08:00
committed by GitHub
parent 7c347259ff
commit d6e1d28c8a

View File

@@ -127,6 +127,9 @@ class AttnForwardMethod(IntEnum):
     # This method can avoid OOM when prefix lengths are long.
     MHA_CHUNKED_KV = auto()
+    # Use MLA but with fused RoPE
+    MLA_FUSED_ROPE = auto()

 class DeepseekV2MLP(nn.Module):
def __init__( def __init__(
@@ -609,6 +612,18 @@ class DeepseekV2AttentionMLA(nn.Module):
     def dispatch_attn_forward_method(
         self, forward_batch: ForwardBatch
     ) -> AttnForwardMethod:
+        def _dispatch_mla_subtype():
+            if _is_hip:
+                if (
+                    self.rocm_fused_decode_mla
+                    and forward_batch.forward_mode.is_decode()
+                ):
+                    return AttnForwardMethod.MLA_FUSED_ROPE
+                else:
+                    return AttnForwardMethod.MLA
+            else:
+                return AttnForwardMethod.MLA
+
         if self.attention_backend == "flashinfer":
             # Flashinfer MLA: Do not absorb when enabling ragged prefill
             if (
@@ -620,7 +635,7 @@ class DeepseekV2AttentionMLA(nn.Module):
             ):
                 return AttnForwardMethod.MHA
             else:
-                return AttnForwardMethod.MLA
+                return _dispatch_mla_subtype()
         elif self.attention_backend == "fa3":
             # Flash Attention: Use MHA with chunked KV cache when prefilling on long sequences.
             if forward_batch.extend_prefix_lens_cpu is not None:
@@ -637,7 +652,7 @@ class DeepseekV2AttentionMLA(nn.Module):
             ):
                 return AttnForwardMethod.MHA_CHUNKED_KV
             else:
-                return AttnForwardMethod.MLA
+                return _dispatch_mla_subtype()
         else:
             # Triton: Use normal computation for prefill and use weight absorption for extend/decode
             if (
@@ -648,7 +663,7 @@ class DeepseekV2AttentionMLA(nn.Module):
             ):
                 return AttnForwardMethod.MHA
             else:
-                return AttnForwardMethod.MLA
+                return _dispatch_mla_subtype()

     def forward(
         self,
@@ -671,23 +686,16 @@ class DeepseekV2AttentionMLA(nn.Module):
             return self.forward_normal_chunked_kv(
                 positions, hidden_states, forward_batch
             )
+        elif attn_forward_method == AttnForwardMethod.MLA:
+            return self.forward_absorb(
+                positions, hidden_states, forward_batch, zero_allocator
+            )
+        elif attn_forward_method == AttnForwardMethod.MLA_FUSED_ROPE:
+            return self.forward_absorb_fused_mla_rope(
+                positions, hidden_states, forward_batch
+            )
         else:
-            if _is_hip:
-                if (
-                    self.rocm_fused_decode_mla
-                    and forward_batch.forward_mode.is_decode()
-                ):
-                    return self.forward_absorb_fused_mla_rope(
-                        positions, hidden_states, forward_batch
-                    )
-                else:
-                    return self.forward_absorb(
-                        positions, hidden_states, forward_batch, zero_allocator
-                    )
-            else:
-                return self.forward_absorb(
-                    positions, hidden_states, forward_batch, zero_allocator
-                )
+            raise NotImplementedError
     def forward_normal(
         self,