Refactor DeepSeek attention dispatching (#6476)
@@ -127,6 +127,9 @@ class AttnForwardMethod(IntEnum):
     # This method can avoid OOM when prefix lengths are long.
     MHA_CHUNKED_KV = auto()
 
+    # Use MLA but with fused RoPE
+    MLA_FUSED_ROPE = auto()
+
 
 class DeepseekV2MLP(nn.Module):
     def __init__(
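For context, here is a minimal runnable sketch of the enum this hunk extends. It is an illustration, not the sglang source: member order, values, and any members beyond those visible in the diff are assumptions.

    # Sketch: an IntEnum dispatch tag like AttnForwardMethod.
    # auto() assigns the next integer, so appending a member is backward-safe.
    from enum import IntEnum, auto

    class AttnForwardMethod(IntEnum):
        MHA = auto()             # plain multi-head attention
        MHA_CHUNKED_KV = auto()  # MHA over chunked KV; avoids OOM on long prefixes
        MLA = auto()             # weight-absorbed multi-latent attention
        MLA_FUSED_ROPE = auto()  # MLA with RoPE fused into the kernel (added here)

    assert list(AttnForwardMethod) == [1, 2, 3, 4]  # IntEnum members compare as ints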
@@ -609,6 +612,18 @@ class DeepseekV2AttentionMLA(nn.Module):
     def dispatch_attn_forward_method(
         self, forward_batch: ForwardBatch
     ) -> AttnForwardMethod:
+        def _dispatch_mla_subtype():
+            if _is_hip:
+                if (
+                    self.rocm_fused_decode_mla
+                    and forward_batch.forward_mode.is_decode()
+                ):
+                    return AttnForwardMethod.MLA_FUSED_ROPE
+                else:
+                    return AttnForwardMethod.MLA
+            else:
+                return AttnForwardMethod.MLA
+
         if self.attention_backend == "flashinfer":
             # Flashinfer MLA: Do not absorb when enabling ragged prefill
             if (
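The point of the new _dispatch_mla_subtype closure is deduplication: every backend branch below that used to return AttnForwardMethod.MLA directly now funnels through it, so the ROCm fused-RoPE-on-decode special case lives in exactly one place. A standalone sketch of the same decision (the flat signature and names are assumptions for illustration, not the sglang API):

    from enum import IntEnum, auto

    class Method(IntEnum):
        MLA = auto()
        MLA_FUSED_ROPE = auto()

    def dispatch_mla_subtype(is_hip: bool, fused_decode_mla: bool, is_decode: bool) -> Method:
        # On ROCm (HIP), prefer the fused-RoPE MLA kernel for decode batches
        # when the feature flag is on; otherwise use plain absorbed MLA.
        if is_hip and fused_decode_mla and is_decode:
            return Method.MLA_FUSED_ROPE
        return Method.MLA

    assert dispatch_mla_subtype(True, True, True) is Method.MLA_FUSED_ROPE
    assert dispatch_mla_subtype(True, True, False) is Method.MLA
    assert dispatch_mla_subtype(False, True, True) is Method.MLA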
@@ -620,7 +635,7 @@ class DeepseekV2AttentionMLA(nn.Module):
             ):
                 return AttnForwardMethod.MHA
             else:
-                return AttnForwardMethod.MLA
+                return _dispatch_mla_subtype()
         elif self.attention_backend == "fa3":
             # Flash Attention: Use MHA with chunked KV cache when prefilling on long sequences.
             if forward_batch.extend_prefix_lens_cpu is not None:
@@ -637,7 +652,7 @@ class DeepseekV2AttentionMLA(nn.Module):
             ):
                 return AttnForwardMethod.MHA_CHUNKED_KV
             else:
-                return AttnForwardMethod.MLA
+                return _dispatch_mla_subtype()
         else:
             # Triton: Use normal computation for prefill and use weight absorption for extend/decode
             if (
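The MHA_CHUNKED_KV path selected here is the one annotated earlier as avoiding OOM when prefix lengths are long. The underlying idea, sketched below in plain torch (a conceptual illustration, not sglang's forward_normal_chunked_kv), is that softmax attention can be computed one KV chunk at a time and merged with running log-sum-exp statistics, so the full [q_len, kv_len] score matrix is never materialized:

    import torch

    def chunked_attention(q, k, v, chunk=256):
        # q: [q_len, d]; k, v: [kv_len, d]
        scale = q.shape[-1] ** -0.5
        acc = torch.zeros_like(q)                         # running output
        lse = torch.full((q.shape[0], 1), float("-inf"))  # running log-sum-exp
        for s in range(0, k.shape[0], chunk):
            scores = (q @ k[s : s + chunk].T) * scale     # [q_len, <=chunk]
            new_lse = torch.logaddexp(lse, torch.logsumexp(scores, -1, keepdim=True))
            # Rescale what we have so far, then add this chunk's contribution.
            acc = acc * torch.exp(lse - new_lse) + torch.exp(scores - new_lse) @ v[s : s + chunk]
            lse = new_lse
        return acc

    q, k, v = torch.randn(4, 8), torch.randn(1000, 8), torch.randn(1000, 8)
    ref = torch.softmax((q @ k.T) * q.shape[-1] ** -0.5, -1) @ v
    assert torch.allclose(chunked_attention(q, k, v), ref, atol=1e-4)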
@@ -648,7 +663,7 @@ class DeepseekV2AttentionMLA(nn.Module):
             ):
                 return AttnForwardMethod.MHA
             else:
-                return AttnForwardMethod.MLA
+                return _dispatch_mla_subtype()
 
     def forward(
         self,
@@ -671,23 +686,16 @@ class DeepseekV2AttentionMLA(nn.Module):
             return self.forward_normal_chunked_kv(
                 positions, hidden_states, forward_batch
             )
-        else:
-            if _is_hip:
-                if (
-                    self.rocm_fused_decode_mla
-                    and forward_batch.forward_mode.is_decode()
-                ):
-                    return self.forward_absorb_fused_mla_rope(
-                        positions, hidden_states, forward_batch
-                    )
-                else:
-                    return self.forward_absorb(
-                        positions, hidden_states, forward_batch, zero_allocator
-                    )
-            else:
-                return self.forward_absorb(
-                    positions, hidden_states, forward_batch, zero_allocator
-                )
+        elif attn_forward_method == AttnForwardMethod.MLA:
+            return self.forward_absorb(
+                positions, hidden_states, forward_batch, zero_allocator
+            )
+        elif attn_forward_method == AttnForwardMethod.MLA_FUSED_ROPE:
+            return self.forward_absorb_fused_mla_rope(
+                positions, hidden_states, forward_batch
+            )
+        else:
+            raise NotImplementedError
 
     def forward_normal(
         self,
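Net effect on forward(): the hardware-specific else branch is gone, dispatch_attn_forward_method() is now the single place that decides, and forward() becomes a flat enum-to-handler mapping that fails loudly on an unhandled value. A toy sketch of that shape (handler names echo the real methods; everything else is illustrative):

    from enum import IntEnum, auto

    class Method(IntEnum):
        MHA = auto()
        MHA_CHUNKED_KV = auto()
        MLA = auto()
        MLA_FUSED_ROPE = auto()

    HANDLERS = {
        Method.MHA: "forward_normal",
        Method.MHA_CHUNKED_KV: "forward_normal_chunked_kv",
        Method.MLA: "forward_absorb",
        Method.MLA_FUSED_ROPE: "forward_absorb_fused_mla_rope",
    }

    def forward(method: Method) -> str:
        # One branch per enum member; anything unmapped raises, matching the
        # refactored code's trailing `raise NotImplementedError`.
        if method not in HANDLERS:
            raise NotImplementedError(method)
        return HANDLERS[method]

    assert forward(Method.MLA_FUSED_ROPE) == "forward_absorb_fused_mla_rope"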