[Auto Sync] Update deepseek_v2.py (20250920) (#10683)
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
This commit is contained in:
@@ -177,6 +177,20 @@ _is_sm100_supported = is_cuda() and is_sm100_supported()
|
|||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# Attention backends that support the weight-absorbed core attention path
# (forward_absorb_core). Extended at runtime via
# add_forward_absorb_core_attention_backend().
FORWARD_ABSORB_CORE_ATTENTION_BACKENDS = [
    "fa3",
    "flashinfer",
    "cutlass_mla",
    "trtllm_mla",
    "ascend",
]
|
||||||
|
|
||||||
|
|
||||||
|
def add_forward_absorb_core_attention_backend(backend_name):
    """Register *backend_name* for the weight-absorbed core attention path.

    Idempotent: re-registering an already-known backend does nothing
    (and emits no log line).
    """
    if backend_name in FORWARD_ABSORB_CORE_ATTENTION_BACKENDS:
        return
    FORWARD_ABSORB_CORE_ATTENTION_BACKENDS.append(backend_name)
    logger.info(f"Added {backend_name} to FORWARD_ABSORB_CORE_ATTENTION_BACKENDS.")
|
||||||
|
|
||||||
|
|
||||||
class AttnForwardMethod(IntEnum):
|
class AttnForwardMethod(IntEnum):
|
||||||
# Use multi-head attention
|
# Use multi-head attention
|
||||||
@@ -196,6 +210,134 @@ class AttnForwardMethod(IntEnum):
|
|||||||
MLA_FUSED_ROPE_CPU = auto()
|
MLA_FUSED_ROPE_CPU = auto()
|
||||||
|
|
||||||
|
|
||||||
|
def _dispatch_mla_subtype(attn, forward_batch):
    """Select the concrete MLA forward method for *attn* on this batch.

    On ROCm, decode batches may use the fused-RoPE MLA kernel; on CPU with
    the Intel AMX backend (and a fused qkv_a projection present), the
    fused-RoPE CPU variant is used. Everything else falls back to plain MLA.
    """
    if _is_hip:
        # ROCm: fused decode kernel only when enabled and actually decoding.
        if attn.rocm_fused_decode_mla and forward_batch.forward_mode.is_decode():
            return AttnForwardMethod.MLA_FUSED_ROPE
        return AttnForwardMethod.MLA

    if hasattr(attn, "fused_qkv_a_proj_with_mqa") and use_intel_amx_backend(attn):
        return AttnForwardMethod.MLA_FUSED_ROPE_CPU
    return AttnForwardMethod.MLA
|
||||||
|
|
||||||
|
|
||||||
|
class BackendRegistry:
    """Process-wide mapping from attention-backend names to handler callables.

    Handlers have the signature ``handler(attn, forward_batch)`` and return an
    ``AttnForwardMethod``. Lookups for unknown names fall back to the
    ``"triton"`` handler.
    """

    # Shared mutable registry; populated at import time via register().
    _handlers = {}

    @classmethod
    def register(cls, backend_name, handler_func):
        """Bind *backend_name* to *handler_func*, overwriting any previous one."""
        cls._handlers[backend_name] = handler_func

    @classmethod
    def get_handler(cls, backend_name):
        """Return the handler for *backend_name*, or the "triton" fallback."""
        fallback = cls._handlers.get("triton")
        return cls._handlers.get(backend_name, fallback)
|
||||||
|
|
||||||
|
|
||||||
|
def handle_ascend(attn, forward_batch):
    """Dispatch rule for the ascend backend.

    Plain (non-speculative) extend batches run MHA; all other modes use MLA.
    """
    mode = forward_batch.forward_mode
    plain_extend = (
        mode.is_extend()
        and not mode.is_target_verify()
        and not mode.is_draft_extend()
    )
    return AttnForwardMethod.MHA if plain_extend else AttnForwardMethod.MLA
|
||||||
|
|
||||||
|
|
||||||
|
def _get_sum_extend_prefix_lens(forward_batch):
|
||||||
|
return (
|
||||||
|
sum(forward_batch.extend_prefix_lens_cpu)
|
||||||
|
if forward_batch.extend_prefix_lens_cpu is not None
|
||||||
|
else 0
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _is_extend_without_speculative(forward_batch):
|
||||||
|
return (
|
||||||
|
forward_batch.forward_mode.is_extend()
|
||||||
|
and not forward_batch.forward_mode.is_target_verify()
|
||||||
|
and not forward_batch.forward_mode.is_draft_extend()
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _handle_backend(attn, forward_batch, backend_name):
    """Shared dispatch for backends that can prefill with chunked-KV MHA.

    Chooses MHA_CHUNKED_KV for a plain extend batch when ragged prefill is
    allowed and the cached prefix is either empty or long enough (>=
    ``attn.chunked_prefix_cache_threshold`` with chunked prefix cache
    enabled); otherwise defers to the MLA subtype dispatch.
    """
    # Flashinfer/FlashMLA MLA: do not absorb when ragged prefill is disabled.
    ragged_disabled = (
        backend_name in ("flashinfer", "flashmla")
    ) and attn.flashinfer_mla_disable_ragged

    if ragged_disabled or not _is_extend_without_speculative(forward_batch):
        return _dispatch_mla_subtype(attn, forward_batch)

    prefix_total = _get_sum_extend_prefix_lens(forward_batch)
    long_enough_for_chunking = (
        prefix_total >= attn.chunked_prefix_cache_threshold
        and not attn.disable_chunked_prefix_cache
    )
    if long_enough_for_chunking or prefix_total == 0:
        return AttnForwardMethod.MHA_CHUNKED_KV
    return _dispatch_mla_subtype(attn, forward_batch)
|
||||||
|
|
||||||
|
|
||||||
|
def handle_flashinfer(attn, forward_batch):
    """Dispatch rule for the flashinfer backend (shared chunked-KV logic)."""
    return _handle_backend(attn, forward_batch, backend_name="flashinfer")
|
||||||
|
|
||||||
|
|
||||||
|
def handle_fa3(attn, forward_batch):
    """Dispatch rule for the fa3 backend (shared chunked-KV logic)."""
    return _handle_backend(attn, forward_batch, backend_name="fa3")
|
||||||
|
|
||||||
|
|
||||||
|
def handle_flashmla(attn, forward_batch):
    """Dispatch rule for the flashmla backend (shared chunked-KV logic)."""
    return _handle_backend(attn, forward_batch, backend_name="flashmla")
|
||||||
|
|
||||||
|
|
||||||
|
def handle_cutlass_mla(attn, forward_batch):
    """Dispatch rule for the cutlass_mla backend (shared chunked-KV logic)."""
    return _handle_backend(attn, forward_batch, backend_name="cutlass_mla")
|
||||||
|
|
||||||
|
|
||||||
|
def handle_fa4(attn, forward_batch):
    """Dispatch rule for the fa4 backend: unconditionally chunked-KV MHA."""
    # TODO(cicirori): use FA4 MHA for DeepSeekV3 for now
    return AttnForwardMethod.MHA_CHUNKED_KV
|
||||||
|
|
||||||
|
|
||||||
|
def handle_trtllm_mla(attn, forward_batch):
    """Dispatch rule for the trtllm_mla backend.

    Plain extend batches use chunked-KV MHA when the chunked prefix cache is
    enabled or the cached prefix is empty; otherwise (and for all other
    modes) the MLA subtype dispatch decides.
    """
    if not _is_extend_without_speculative(forward_batch):
        return _dispatch_mla_subtype(attn, forward_batch)

    prefix_total = _get_sum_extend_prefix_lens(forward_batch)
    if prefix_total == 0 or not attn.disable_chunked_prefix_cache:
        return AttnForwardMethod.MHA_CHUNKED_KV
    return _dispatch_mla_subtype(attn, forward_batch)
|
||||||
|
|
||||||
|
|
||||||
|
def handle_aiter(attn, forward_batch):
    """Dispatch rule for the aiter (ROCm) backend.

    Plain extend batches use MHA, except under DP attention with a non-empty
    cached prefix (then MLA). All other modes use MLA.
    """
    if not _is_extend_without_speculative(forward_batch):
        return AttnForwardMethod.MLA

    if is_dp_attention_enabled():
        # Fix: use the None-safe helper instead of calling
        # sum(forward_batch.extend_prefix_lens_cpu) directly, which raises
        # TypeError when the CPU-side lengths are unavailable. This matches
        # how the other backend handlers compute the prefix total.
        if _get_sum_extend_prefix_lens(forward_batch) == 0:
            return AttnForwardMethod.MHA
        return AttnForwardMethod.MLA

    return AttnForwardMethod.MHA
|
||||||
|
|
||||||
|
|
||||||
|
def handle_triton(attn, forward_batch):
    """Dispatch rule for the triton (default/fallback) backend.

    Uses plain MHA for a prefill with no cached prefix; otherwise defers to
    the weight-absorbed MLA subtype dispatch.
    """
    # Fix: use the None-safe helper rather than sum(...) on a value that may
    # be None, for consistency with _handle_backend/handle_trtllm_mla.
    if (
        _is_extend_without_speculative(forward_batch)
        and _get_sum_extend_prefix_lens(forward_batch) == 0
    ):
        return AttnForwardMethod.MHA
    return _dispatch_mla_subtype(attn, forward_batch)
|
||||||
|
|
||||||
|
|
||||||
class DeepseekV2MLP(nn.Module):
|
class DeepseekV2MLP(nn.Module):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
@@ -1039,23 +1181,6 @@ class DeepseekV2AttentionMLA(nn.Module):
|
|||||||
def dispatch_attn_forward_method(
|
def dispatch_attn_forward_method(
|
||||||
self, forward_batch: ForwardBatch
|
self, forward_batch: ForwardBatch
|
||||||
) -> AttnForwardMethod:
|
) -> AttnForwardMethod:
|
||||||
def _dispatch_mla_subtype():
|
|
||||||
if _is_hip:
|
|
||||||
if (
|
|
||||||
self.rocm_fused_decode_mla
|
|
||||||
and forward_batch.forward_mode.is_decode()
|
|
||||||
):
|
|
||||||
return AttnForwardMethod.MLA_FUSED_ROPE
|
|
||||||
else:
|
|
||||||
return AttnForwardMethod.MLA
|
|
||||||
else:
|
|
||||||
if hasattr(self, "fused_qkv_a_proj_with_mqa") and use_intel_amx_backend(
|
|
||||||
self
|
|
||||||
):
|
|
||||||
return AttnForwardMethod.MLA_FUSED_ROPE_CPU
|
|
||||||
else:
|
|
||||||
return AttnForwardMethod.MLA
|
|
||||||
|
|
||||||
# Determine attention backend used by current forward batch
|
# Determine attention backend used by current forward batch
|
||||||
if forward_batch.forward_mode.is_decode_or_idle():
|
if forward_batch.forward_mode.is_decode_or_idle():
|
||||||
attention_backend = global_server_args_dict["decode_attention_backend"]
|
attention_backend = global_server_args_dict["decode_attention_backend"]
|
||||||
@@ -1072,94 +1197,8 @@ class DeepseekV2AttentionMLA(nn.Module):
|
|||||||
attention_backend = global_server_args_dict["prefill_attention_backend"]
|
attention_backend = global_server_args_dict["prefill_attention_backend"]
|
||||||
self.current_attention_backend = attention_backend
|
self.current_attention_backend = attention_backend
|
||||||
|
|
||||||
if attention_backend == "ascend":
|
handler = BackendRegistry.get_handler(attention_backend)
|
||||||
if (
|
return handler(self, forward_batch)
|
||||||
forward_batch.forward_mode.is_extend()
|
|
||||||
and not forward_batch.forward_mode.is_target_verify()
|
|
||||||
and not forward_batch.forward_mode.is_draft_extend()
|
|
||||||
):
|
|
||||||
return AttnForwardMethod.MHA
|
|
||||||
else:
|
|
||||||
return AttnForwardMethod.MLA
|
|
||||||
elif (
|
|
||||||
attention_backend == "flashinfer"
|
|
||||||
or attention_backend == "fa3"
|
|
||||||
or attention_backend == "flashmla"
|
|
||||||
or attention_backend == "cutlass_mla"
|
|
||||||
):
|
|
||||||
# Use MHA with chunked KV cache when prefilling on long sequences.
|
|
||||||
sum_extend_prefix_lens = (
|
|
||||||
sum(forward_batch.extend_prefix_lens_cpu)
|
|
||||||
if forward_batch.extend_prefix_lens_cpu is not None
|
|
||||||
else 0
|
|
||||||
)
|
|
||||||
# Flashinfer MLA: Do not absorb when enabling ragged prefill
|
|
||||||
disable_ragged = (
|
|
||||||
attention_backend == "flashinfer" or attention_backend == "flashmla"
|
|
||||||
) and self.flashinfer_mla_disable_ragged
|
|
||||||
|
|
||||||
if (
|
|
||||||
not disable_ragged
|
|
||||||
and forward_batch.forward_mode.is_extend()
|
|
||||||
and not forward_batch.forward_mode.is_target_verify()
|
|
||||||
and not forward_batch.forward_mode.is_draft_extend()
|
|
||||||
and (
|
|
||||||
(
|
|
||||||
sum_extend_prefix_lens >= self.chunked_prefix_cache_threshold
|
|
||||||
and not self.disable_chunked_prefix_cache
|
|
||||||
)
|
|
||||||
or sum_extend_prefix_lens == 0
|
|
||||||
)
|
|
||||||
):
|
|
||||||
return AttnForwardMethod.MHA_CHUNKED_KV
|
|
||||||
else:
|
|
||||||
return _dispatch_mla_subtype()
|
|
||||||
elif attention_backend == "fa4":
|
|
||||||
# TODO(cicirori): use FA4 MHA for DeepSeekV3 for now
|
|
||||||
return AttnForwardMethod.MHA_CHUNKED_KV
|
|
||||||
elif attention_backend == "trtllm_mla":
|
|
||||||
sum_extend_prefix_lens = (
|
|
||||||
sum(forward_batch.extend_prefix_lens_cpu)
|
|
||||||
if forward_batch.extend_prefix_lens_cpu is not None
|
|
||||||
else 0
|
|
||||||
)
|
|
||||||
if (
|
|
||||||
forward_batch.forward_mode.is_extend()
|
|
||||||
and not forward_batch.forward_mode.is_target_verify()
|
|
||||||
and not forward_batch.forward_mode.is_draft_extend()
|
|
||||||
and (
|
|
||||||
not self.disable_chunked_prefix_cache or sum_extend_prefix_lens == 0
|
|
||||||
)
|
|
||||||
):
|
|
||||||
return AttnForwardMethod.MHA_CHUNKED_KV
|
|
||||||
else:
|
|
||||||
return _dispatch_mla_subtype()
|
|
||||||
elif attention_backend == "aiter":
|
|
||||||
if (
|
|
||||||
forward_batch.forward_mode.is_extend()
|
|
||||||
and not forward_batch.forward_mode.is_target_verify()
|
|
||||||
and not forward_batch.forward_mode.is_draft_extend()
|
|
||||||
):
|
|
||||||
if is_dp_attention_enabled():
|
|
||||||
if sum(forward_batch.extend_prefix_lens_cpu) == 0:
|
|
||||||
return AttnForwardMethod.MHA
|
|
||||||
else:
|
|
||||||
return AttnForwardMethod.MLA
|
|
||||||
else:
|
|
||||||
return AttnForwardMethod.MHA
|
|
||||||
else:
|
|
||||||
return AttnForwardMethod.MLA
|
|
||||||
else:
|
|
||||||
# Triton: Use normal computation for prefill and use weight absorption for extend/decode
|
|
||||||
if (
|
|
||||||
forward_batch.forward_mode.is_extend()
|
|
||||||
and not forward_batch.forward_mode.is_target_verify()
|
|
||||||
and not forward_batch.forward_mode.is_draft_extend()
|
|
||||||
and sum(forward_batch.extend_prefix_lens_cpu) == 0
|
|
||||||
):
|
|
||||||
return AttnForwardMethod.MHA
|
|
||||||
else:
|
|
||||||
return _dispatch_mla_subtype()
|
|
||||||
|
|
||||||
def op_prepare(self, state):
|
def op_prepare(self, state):
|
||||||
state.attn_intermediate_state = self.forward_prepare(
|
state.attn_intermediate_state = self.forward_prepare(
|
||||||
@@ -1456,13 +1495,7 @@ class DeepseekV2AttentionMLA(nn.Module):
|
|||||||
def forward_absorb_core(
|
def forward_absorb_core(
|
||||||
self, q_pe, k_pe, q_nope_out, k_nope, forward_batch, zero_allocator, positions
|
self, q_pe, k_pe, q_nope_out, k_nope, forward_batch, zero_allocator, positions
|
||||||
):
|
):
|
||||||
if (
|
if self.current_attention_backend in FORWARD_ABSORB_CORE_ATTENTION_BACKENDS:
|
||||||
self.current_attention_backend == "fa3"
|
|
||||||
or self.current_attention_backend == "flashinfer"
|
|
||||||
or self.current_attention_backend == "cutlass_mla"
|
|
||||||
or self.current_attention_backend == "trtllm_mla"
|
|
||||||
or self.current_attention_backend == "ascend"
|
|
||||||
):
|
|
||||||
extra_args = {}
|
extra_args = {}
|
||||||
if self._fuse_rope_for_trtllm_mla(forward_batch):
|
if self._fuse_rope_for_trtllm_mla(forward_batch):
|
||||||
extra_args = {
|
extra_args = {
|
||||||
@@ -3016,6 +3049,17 @@ class DeepseekV2ForCausalLM(nn.Module):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# Wire each attention backend name to its dispatch handler. "triton" doubles
# as the fallback returned by BackendRegistry.get_handler for unknown names.
for _backend_name, _backend_handler in (
    ("ascend", handle_ascend),
    ("flashinfer", handle_flashinfer),
    ("fa3", handle_fa3),
    ("flashmla", handle_flashmla),
    ("cutlass_mla", handle_cutlass_mla),
    ("fa4", handle_fa4),
    ("trtllm_mla", handle_trtllm_mla),
    ("aiter", handle_aiter),
    ("triton", handle_triton),
):
    BackendRegistry.register(_backend_name, _backend_handler)
|
||||||
|
|
||||||
|
|
||||||
class DeepseekV3ForCausalLM(DeepseekV2ForCausalLM):
    """DeepSeek-V3 causal LM; reuses the DeepSeek-V2 implementation unchanged."""

    pass
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user