[Auto Sync] Update deepseek_v2.py (20250920) (#10683)
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
This commit is contained in:
@@ -177,6 +177,20 @@ _is_sm100_supported = is_cuda() and is_sm100_supported()
|
||||
|
||||
# Module-level logger for backend-registration messages.
logger = logging.getLogger(__name__)

# Attention backends whose forward_absorb_core path is supported out of the
# box. Extendable at runtime via add_forward_absorb_core_attention_backend().
FORWARD_ABSORB_CORE_ATTENTION_BACKENDS = [
    "fa3",
    "flashinfer",
    "cutlass_mla",
    "trtllm_mla",
    "ascend",
]


def add_forward_absorb_core_attention_backend(backend_name):
    """Register *backend_name* as supporting the forward-absorb core path.

    Idempotent: a backend that is already registered is not appended twice
    and no log line is emitted for it.
    """
    if backend_name not in FORWARD_ABSORB_CORE_ATTENTION_BACKENDS:
        FORWARD_ABSORB_CORE_ATTENTION_BACKENDS.append(backend_name)
        # Lazy %-style args: the message is only formatted if INFO is enabled.
        logger.info(
            "Added %s to FORWARD_ABSORB_CORE_ATTENTION_BACKENDS.", backend_name
        )
|
||||
|
||||
|
||||
class AttnForwardMethod(IntEnum):
|
||||
# Use multi-head attention
|
||||
@@ -196,6 +210,134 @@ class AttnForwardMethod(IntEnum):
|
||||
MLA_FUSED_ROPE_CPU = auto()
|
||||
|
||||
|
||||
def _dispatch_mla_subtype(attn, forward_batch):
    """Pick the MLA variant (plain or fused-RoPE) for the current platform."""
    if _is_hip:
        # ROCm: the fused decode MLA kernel only applies to pure decode batches.
        if attn.rocm_fused_decode_mla and forward_batch.forward_mode.is_decode():
            return AttnForwardMethod.MLA_FUSED_ROPE
        return AttnForwardMethod.MLA

    # Non-HIP: fuse RoPE on CPU when the combined qkv_a projection exists and
    # the Intel AMX backend is in use.
    if hasattr(attn, "fused_qkv_a_proj_with_mqa") and use_intel_amx_backend(attn):
        return AttnForwardMethod.MLA_FUSED_ROPE_CPU
    return AttnForwardMethod.MLA
|
||||
|
||||
|
||||
class BackendRegistry:
    """Registry mapping attention-backend names to dispatch handlers."""

    # Shared mapping of backend_name -> handler(attn, forward_batch).
    _handlers = {}

    @classmethod
    def register(cls, backend_name, handler_func):
        """Associate *backend_name* with *handler_func*, overwriting any prior entry."""
        cls._handlers[backend_name] = handler_func

    @classmethod
    def get_handler(cls, backend_name):
        """Return the handler for *backend_name*, falling back to the "triton" handler."""
        fallback = cls._handlers.get("triton")
        return cls._handlers.get(backend_name, fallback)
|
||||
|
||||
|
||||
def handle_ascend(attn, forward_batch):
    """Ascend backend: MHA for a plain extend batch, MLA otherwise.

    Consistency fix: reuses _is_extend_without_speculative() instead of
    re-inlining the same three forward-mode checks, matching how the other
    backend handlers in this module express the condition.
    """
    if _is_extend_without_speculative(forward_batch):
        return AttnForwardMethod.MHA
    return AttnForwardMethod.MLA
|
||||
|
||||
|
||||
def _get_sum_extend_prefix_lens(forward_batch):
|
||||
return (
|
||||
sum(forward_batch.extend_prefix_lens_cpu)
|
||||
if forward_batch.extend_prefix_lens_cpu is not None
|
||||
else 0
|
||||
)
|
||||
|
||||
|
||||
def _is_extend_without_speculative(forward_batch):
|
||||
return (
|
||||
forward_batch.forward_mode.is_extend()
|
||||
and not forward_batch.forward_mode.is_target_verify()
|
||||
and not forward_batch.forward_mode.is_draft_extend()
|
||||
)
|
||||
|
||||
|
||||
def _handle_backend(attn, forward_batch, backend_name):
    """Shared MHA-chunked-KV vs. MLA dispatch policy for several backends.

    Chunked-KV MHA is used for plain extend batches whose prefix is either
    empty or long enough to cross the chunked-prefix-cache threshold, unless
    ragged prefill is disabled for flashinfer/flashmla.
    """
    prefix_total = _get_sum_extend_prefix_lens(forward_batch)

    # Flashinfer/FlashMLA may opt out of ragged prefill entirely.
    ragged_disabled = (
        backend_name in ("flashinfer", "flashmla")
        and attn.flashinfer_mla_disable_ragged
    )

    chunked_ok = prefix_total == 0 or (
        prefix_total >= attn.chunked_prefix_cache_threshold
        and not attn.disable_chunked_prefix_cache
    )

    if (
        not ragged_disabled
        and _is_extend_without_speculative(forward_batch)
        and chunked_ok
    ):
        return AttnForwardMethod.MHA_CHUNKED_KV
    return _dispatch_mla_subtype(attn, forward_batch)
|
||||
|
||||
|
||||
def handle_flashinfer(attn, forward_batch):
    """Dispatch for the "flashinfer" backend via the shared MLA/MHA policy."""
    return _handle_backend(attn, forward_batch, "flashinfer")
|
||||
|
||||
|
||||
def handle_fa3(attn, forward_batch):
    """Dispatch for the "fa3" backend via the shared MLA/MHA policy."""
    return _handle_backend(attn, forward_batch, "fa3")
|
||||
|
||||
|
||||
def handle_flashmla(attn, forward_batch):
    """Dispatch for the "flashmla" backend via the shared MLA/MHA policy."""
    return _handle_backend(attn, forward_batch, "flashmla")
|
||||
|
||||
|
||||
def handle_cutlass_mla(attn, forward_batch):
    """Dispatch for the "cutlass_mla" backend via the shared MLA/MHA policy."""
    return _handle_backend(attn, forward_batch, "cutlass_mla")
|
||||
|
||||
|
||||
def handle_fa4(attn, forward_batch):
    """FA4 backend: unconditionally uses MHA with a chunked KV cache."""
    # TODO(cicirori): use FA4 MHA for DeepSeekV3 for now
    return AttnForwardMethod.MHA_CHUNKED_KV
|
||||
|
||||
|
||||
def handle_trtllm_mla(attn, forward_batch):
    """TRT-LLM MLA: chunked-KV MHA for eligible extend batches, MLA otherwise."""
    if _is_extend_without_speculative(forward_batch):
        # Chunked-KV MHA unless the chunked prefix cache is disabled and
        # there is a non-empty prefix to absorb.
        prefix_total = _get_sum_extend_prefix_lens(forward_batch)
        if prefix_total == 0 or not attn.disable_chunked_prefix_cache:
            return AttnForwardMethod.MHA_CHUNKED_KV
    return _dispatch_mla_subtype(attn, forward_batch)
|
||||
|
||||
|
||||
def handle_aiter(attn, forward_batch):
    """AITER backend dispatch.

    Plain extend batches use MHA, except under DP attention where a
    non-zero prefix forces the MLA (weight-absorption) path. Decode and
    speculative modes always use MLA.
    """
    if not _is_extend_without_speculative(forward_batch):
        return AttnForwardMethod.MLA
    if is_dp_attention_enabled():
        # Robustness/consistency fix: the other handlers treat a missing
        # extend_prefix_lens_cpu as a zero prefix via the shared helper;
        # the bare sum() here raised TypeError on None.
        if _get_sum_extend_prefix_lens(forward_batch) == 0:
            return AttnForwardMethod.MHA
        return AttnForwardMethod.MLA
    return AttnForwardMethod.MHA
|
||||
|
||||
|
||||
def handle_triton(attn, forward_batch):
    """Triton backend: plain MHA for fresh prefill, weight absorption otherwise.

    Robustness/consistency fix: uses _get_sum_extend_prefix_lens(), which
    treats a missing extend_prefix_lens_cpu as zero, instead of a bare
    sum() that raised TypeError on None.
    """
    if (
        _is_extend_without_speculative(forward_batch)
        and _get_sum_extend_prefix_lens(forward_batch) == 0
    ):
        return AttnForwardMethod.MHA
    return _dispatch_mla_subtype(attn, forward_batch)
|
||||
|
||||
|
||||
class DeepseekV2MLP(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
@@ -1039,23 +1181,6 @@ class DeepseekV2AttentionMLA(nn.Module):
|
||||
def dispatch_attn_forward_method(
|
||||
self, forward_batch: ForwardBatch
|
||||
) -> AttnForwardMethod:
|
||||
def _dispatch_mla_subtype():
|
||||
if _is_hip:
|
||||
if (
|
||||
self.rocm_fused_decode_mla
|
||||
and forward_batch.forward_mode.is_decode()
|
||||
):
|
||||
return AttnForwardMethod.MLA_FUSED_ROPE
|
||||
else:
|
||||
return AttnForwardMethod.MLA
|
||||
else:
|
||||
if hasattr(self, "fused_qkv_a_proj_with_mqa") and use_intel_amx_backend(
|
||||
self
|
||||
):
|
||||
return AttnForwardMethod.MLA_FUSED_ROPE_CPU
|
||||
else:
|
||||
return AttnForwardMethod.MLA
|
||||
|
||||
# Determine attention backend used by current forward batch
|
||||
if forward_batch.forward_mode.is_decode_or_idle():
|
||||
attention_backend = global_server_args_dict["decode_attention_backend"]
|
||||
@@ -1072,94 +1197,8 @@ class DeepseekV2AttentionMLA(nn.Module):
|
||||
attention_backend = global_server_args_dict["prefill_attention_backend"]
|
||||
self.current_attention_backend = attention_backend
|
||||
|
||||
if attention_backend == "ascend":
|
||||
if (
|
||||
forward_batch.forward_mode.is_extend()
|
||||
and not forward_batch.forward_mode.is_target_verify()
|
||||
and not forward_batch.forward_mode.is_draft_extend()
|
||||
):
|
||||
return AttnForwardMethod.MHA
|
||||
else:
|
||||
return AttnForwardMethod.MLA
|
||||
elif (
|
||||
attention_backend == "flashinfer"
|
||||
or attention_backend == "fa3"
|
||||
or attention_backend == "flashmla"
|
||||
or attention_backend == "cutlass_mla"
|
||||
):
|
||||
# Use MHA with chunked KV cache when prefilling on long sequences.
|
||||
sum_extend_prefix_lens = (
|
||||
sum(forward_batch.extend_prefix_lens_cpu)
|
||||
if forward_batch.extend_prefix_lens_cpu is not None
|
||||
else 0
|
||||
)
|
||||
# Flashinfer MLA: Do not absorb when enabling ragged prefill
|
||||
disable_ragged = (
|
||||
attention_backend == "flashinfer" or attention_backend == "flashmla"
|
||||
) and self.flashinfer_mla_disable_ragged
|
||||
|
||||
if (
|
||||
not disable_ragged
|
||||
and forward_batch.forward_mode.is_extend()
|
||||
and not forward_batch.forward_mode.is_target_verify()
|
||||
and not forward_batch.forward_mode.is_draft_extend()
|
||||
and (
|
||||
(
|
||||
sum_extend_prefix_lens >= self.chunked_prefix_cache_threshold
|
||||
and not self.disable_chunked_prefix_cache
|
||||
)
|
||||
or sum_extend_prefix_lens == 0
|
||||
)
|
||||
):
|
||||
return AttnForwardMethod.MHA_CHUNKED_KV
|
||||
else:
|
||||
return _dispatch_mla_subtype()
|
||||
elif attention_backend == "fa4":
|
||||
# TODO(cicirori): use FA4 MHA for DeepSeekV3 for now
|
||||
return AttnForwardMethod.MHA_CHUNKED_KV
|
||||
elif attention_backend == "trtllm_mla":
|
||||
sum_extend_prefix_lens = (
|
||||
sum(forward_batch.extend_prefix_lens_cpu)
|
||||
if forward_batch.extend_prefix_lens_cpu is not None
|
||||
else 0
|
||||
)
|
||||
if (
|
||||
forward_batch.forward_mode.is_extend()
|
||||
and not forward_batch.forward_mode.is_target_verify()
|
||||
and not forward_batch.forward_mode.is_draft_extend()
|
||||
and (
|
||||
not self.disable_chunked_prefix_cache or sum_extend_prefix_lens == 0
|
||||
)
|
||||
):
|
||||
return AttnForwardMethod.MHA_CHUNKED_KV
|
||||
else:
|
||||
return _dispatch_mla_subtype()
|
||||
elif attention_backend == "aiter":
|
||||
if (
|
||||
forward_batch.forward_mode.is_extend()
|
||||
and not forward_batch.forward_mode.is_target_verify()
|
||||
and not forward_batch.forward_mode.is_draft_extend()
|
||||
):
|
||||
if is_dp_attention_enabled():
|
||||
if sum(forward_batch.extend_prefix_lens_cpu) == 0:
|
||||
return AttnForwardMethod.MHA
|
||||
else:
|
||||
return AttnForwardMethod.MLA
|
||||
else:
|
||||
return AttnForwardMethod.MHA
|
||||
else:
|
||||
return AttnForwardMethod.MLA
|
||||
else:
|
||||
# Triton: Use normal computation for prefill and use weight absorption for extend/decode
|
||||
if (
|
||||
forward_batch.forward_mode.is_extend()
|
||||
and not forward_batch.forward_mode.is_target_verify()
|
||||
and not forward_batch.forward_mode.is_draft_extend()
|
||||
and sum(forward_batch.extend_prefix_lens_cpu) == 0
|
||||
):
|
||||
return AttnForwardMethod.MHA
|
||||
else:
|
||||
return _dispatch_mla_subtype()
|
||||
handler = BackendRegistry.get_handler(attention_backend)
|
||||
return handler(self, forward_batch)
|
||||
|
||||
def op_prepare(self, state):
|
||||
state.attn_intermediate_state = self.forward_prepare(
|
||||
@@ -1456,13 +1495,7 @@ class DeepseekV2AttentionMLA(nn.Module):
|
||||
def forward_absorb_core(
|
||||
self, q_pe, k_pe, q_nope_out, k_nope, forward_batch, zero_allocator, positions
|
||||
):
|
||||
if (
|
||||
self.current_attention_backend == "fa3"
|
||||
or self.current_attention_backend == "flashinfer"
|
||||
or self.current_attention_backend == "cutlass_mla"
|
||||
or self.current_attention_backend == "trtllm_mla"
|
||||
or self.current_attention_backend == "ascend"
|
||||
):
|
||||
if self.current_attention_backend in FORWARD_ABSORB_CORE_ATTENTION_BACKENDS:
|
||||
extra_args = {}
|
||||
if self._fuse_rope_for_trtllm_mla(forward_batch):
|
||||
extra_args = {
|
||||
@@ -3016,6 +3049,17 @@ class DeepseekV2ForCausalLM(nn.Module):
|
||||
)
|
||||
|
||||
|
||||
# Install the default handler for every known attention backend; "triton"
# also serves as the fallback inside BackendRegistry.get_handler().
for _backend_name, _backend_handler in (
    ("ascend", handle_ascend),
    ("flashinfer", handle_flashinfer),
    ("fa3", handle_fa3),
    ("flashmla", handle_flashmla),
    ("cutlass_mla", handle_cutlass_mla),
    ("fa4", handle_fa4),
    ("trtllm_mla", handle_trtllm_mla),
    ("aiter", handle_aiter),
    ("triton", handle_triton),
):
    BackendRegistry.register(_backend_name, _backend_handler)
del _backend_name, _backend_handler
|
||||
|
||||
|
||||
class DeepseekV3ForCausalLM(DeepseekV2ForCausalLM):
    """DeepSeek-V3 causal LM; identical implementation to the V2 model class."""

    pass
|
||||
|
||||
|
||||
Reference in New Issue
Block a user