diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py index 5bd796daa..df2a89029 100644 --- a/python/sglang/srt/models/deepseek_v2.py +++ b/python/sglang/srt/models/deepseek_v2.py @@ -177,6 +177,20 @@ _is_sm100_supported = is_cuda() and is_sm100_supported() logger = logging.getLogger(__name__) +FORWARD_ABSORB_CORE_ATTENTION_BACKENDS = [ + "fa3", + "flashinfer", + "cutlass_mla", + "trtllm_mla", + "ascend", +] + + +def add_forward_absorb_core_attention_backend(backend_name): + if backend_name not in FORWARD_ABSORB_CORE_ATTENTION_BACKENDS: + FORWARD_ABSORB_CORE_ATTENTION_BACKENDS.append(backend_name) + logger.info(f"Added {backend_name} to FORWARD_ABSORB_CORE_ATTENTION_BACKENDS.") + class AttnForwardMethod(IntEnum): # Use multi-head attention @@ -196,6 +210,134 @@ class AttnForwardMethod(IntEnum): MLA_FUSED_ROPE_CPU = auto() +def _dispatch_mla_subtype(attn, forward_batch): + if _is_hip: + if attn.rocm_fused_decode_mla and forward_batch.forward_mode.is_decode(): + return AttnForwardMethod.MLA_FUSED_ROPE + else: + return AttnForwardMethod.MLA + else: + if hasattr(attn, "fused_qkv_a_proj_with_mqa") and use_intel_amx_backend(attn): + return AttnForwardMethod.MLA_FUSED_ROPE_CPU + else: + return AttnForwardMethod.MLA + + +class BackendRegistry: + _handlers = {} + + @classmethod + def register(cls, backend_name, handler_func): + cls._handlers[backend_name] = handler_func + + @classmethod + def get_handler(cls, backend_name): + return cls._handlers.get(backend_name, cls._handlers.get("triton")) + + +def handle_ascend(attn, forward_batch): + if ( + forward_batch.forward_mode.is_extend() + and not forward_batch.forward_mode.is_target_verify() + and not forward_batch.forward_mode.is_draft_extend() + ): + return AttnForwardMethod.MHA + else: + return AttnForwardMethod.MLA + + +def _get_sum_extend_prefix_lens(forward_batch): + return ( + sum(forward_batch.extend_prefix_lens_cpu) + if forward_batch.extend_prefix_lens_cpu is not None + else 0 + ) + + +def _is_extend_without_speculative(forward_batch): + return ( + forward_batch.forward_mode.is_extend() + and not forward_batch.forward_mode.is_target_verify() + and not forward_batch.forward_mode.is_draft_extend() + ) + + +def _handle_backend(attn, forward_batch, backend_name): + sum_extend_prefix_lens = _get_sum_extend_prefix_lens(forward_batch) + disable_ragged = ( + backend_name in ["flashinfer", "flashmla"] + ) and attn.flashinfer_mla_disable_ragged + + if ( + not disable_ragged + and _is_extend_without_speculative(forward_batch) + and ( + ( + sum_extend_prefix_lens >= attn.chunked_prefix_cache_threshold + and not attn.disable_chunked_prefix_cache + ) + or sum_extend_prefix_lens == 0 + ) + ): + return AttnForwardMethod.MHA_CHUNKED_KV + else: + return _dispatch_mla_subtype(attn, forward_batch) + + +def handle_flashinfer(attn, forward_batch): + return _handle_backend(attn, forward_batch, "flashinfer") + + +def handle_fa3(attn, forward_batch): + return _handle_backend(attn, forward_batch, "fa3") + + +def handle_flashmla(attn, forward_batch): + return _handle_backend(attn, forward_batch, "flashmla") + + +def handle_cutlass_mla(attn, forward_batch): + return _handle_backend(attn, forward_batch, "cutlass_mla") + + +def handle_fa4(attn, forward_batch): + # TODO(cicirori): use FA4 MHA for DeepSeekV3 for now + return AttnForwardMethod.MHA_CHUNKED_KV + + +def handle_trtllm_mla(attn, forward_batch): + sum_extend_prefix_lens = _get_sum_extend_prefix_lens(forward_batch) + if _is_extend_without_speculative(forward_batch) and ( + not attn.disable_chunked_prefix_cache or sum_extend_prefix_lens == 0 + ): + return AttnForwardMethod.MHA_CHUNKED_KV + else: + return _dispatch_mla_subtype(attn, forward_batch) + + +def handle_aiter(attn, forward_batch): + if _is_extend_without_speculative(forward_batch): + if is_dp_attention_enabled(): + if sum(forward_batch.extend_prefix_lens_cpu) == 0: + return AttnForwardMethod.MHA + else: + return AttnForwardMethod.MLA + else: + return AttnForwardMethod.MHA + else: + return AttnForwardMethod.MLA + + +def handle_triton(attn, forward_batch): + if ( + _is_extend_without_speculative(forward_batch) + and sum(forward_batch.extend_prefix_lens_cpu) == 0 + ): + return AttnForwardMethod.MHA + else: + return _dispatch_mla_subtype(attn, forward_batch) + + class DeepseekV2MLP(nn.Module): def __init__( self, @@ -1039,23 +1181,6 @@ class DeepseekV2AttentionMLA(nn.Module): def dispatch_attn_forward_method( self, forward_batch: ForwardBatch ) -> AttnForwardMethod: - def _dispatch_mla_subtype(): - if _is_hip: - if ( - self.rocm_fused_decode_mla - and forward_batch.forward_mode.is_decode() - ): - return AttnForwardMethod.MLA_FUSED_ROPE - else: - return AttnForwardMethod.MLA - else: - if hasattr(self, "fused_qkv_a_proj_with_mqa") and use_intel_amx_backend( - self - ): - return AttnForwardMethod.MLA_FUSED_ROPE_CPU - else: - return AttnForwardMethod.MLA - # Determine attention backend used by current forward batch if forward_batch.forward_mode.is_decode_or_idle(): attention_backend = global_server_args_dict["decode_attention_backend"] @@ -1072,94 +1197,8 @@ class DeepseekV2AttentionMLA(nn.Module): attention_backend = global_server_args_dict["prefill_attention_backend"] self.current_attention_backend = attention_backend - if attention_backend == "ascend": - if ( - forward_batch.forward_mode.is_extend() - and not forward_batch.forward_mode.is_target_verify() - and not forward_batch.forward_mode.is_draft_extend() - ): - return AttnForwardMethod.MHA - else: - return AttnForwardMethod.MLA - elif ( - attention_backend == "flashinfer" - or attention_backend == "fa3" - or attention_backend == "flashmla" - or attention_backend == "cutlass_mla" - ): - # Use MHA with chunked KV cache when prefilling on long sequences. - sum_extend_prefix_lens = ( - sum(forward_batch.extend_prefix_lens_cpu) - if forward_batch.extend_prefix_lens_cpu is not None - else 0 - ) - # Flashinfer MLA: Do not absorb when enabling ragged prefill - disable_ragged = ( - attention_backend == "flashinfer" or attention_backend == "flashmla" - ) and self.flashinfer_mla_disable_ragged - - if ( - not disable_ragged - and forward_batch.forward_mode.is_extend() - and not forward_batch.forward_mode.is_target_verify() - and not forward_batch.forward_mode.is_draft_extend() - and ( - ( - sum_extend_prefix_lens >= self.chunked_prefix_cache_threshold - and not self.disable_chunked_prefix_cache - ) - or sum_extend_prefix_lens == 0 - ) - ): - return AttnForwardMethod.MHA_CHUNKED_KV - else: - return _dispatch_mla_subtype() - elif attention_backend == "fa4": - # TODO(cicirori): use FA4 MHA for DeepSeekV3 for now - return AttnForwardMethod.MHA_CHUNKED_KV - elif attention_backend == "trtllm_mla": - sum_extend_prefix_lens = ( - sum(forward_batch.extend_prefix_lens_cpu) - if forward_batch.extend_prefix_lens_cpu is not None - else 0 - ) - if ( - forward_batch.forward_mode.is_extend() - and not forward_batch.forward_mode.is_target_verify() - and not forward_batch.forward_mode.is_draft_extend() - and ( - not self.disable_chunked_prefix_cache or sum_extend_prefix_lens == 0 - ) - ): - return AttnForwardMethod.MHA_CHUNKED_KV - else: - return _dispatch_mla_subtype() - elif attention_backend == "aiter": - if ( - forward_batch.forward_mode.is_extend() - and not forward_batch.forward_mode.is_target_verify() - and not forward_batch.forward_mode.is_draft_extend() - ): - if is_dp_attention_enabled(): - if sum(forward_batch.extend_prefix_lens_cpu) == 0: - return AttnForwardMethod.MHA - else: - return AttnForwardMethod.MLA - else: - return AttnForwardMethod.MHA - else: - return AttnForwardMethod.MLA - else: - # Triton: Use normal computation for prefill and use weight absorption for extend/decode - if ( - forward_batch.forward_mode.is_extend() - and not forward_batch.forward_mode.is_target_verify() - and not forward_batch.forward_mode.is_draft_extend() - and sum(forward_batch.extend_prefix_lens_cpu) == 0 - ): - return AttnForwardMethod.MHA - else: - return _dispatch_mla_subtype() + handler = BackendRegistry.get_handler(attention_backend) + return handler(self, forward_batch) def op_prepare(self, state): state.attn_intermediate_state = self.forward_prepare( @@ -1456,13 +1495,7 @@ class DeepseekV2AttentionMLA(nn.Module): def forward_absorb_core( self, q_pe, k_pe, q_nope_out, k_nope, forward_batch, zero_allocator, positions ): - if ( - self.current_attention_backend == "fa3" - or self.current_attention_backend == "flashinfer" - or self.current_attention_backend == "cutlass_mla" - or self.current_attention_backend == "trtllm_mla" - or self.current_attention_backend == "ascend" - ): + if self.current_attention_backend in FORWARD_ABSORB_CORE_ATTENTION_BACKENDS: extra_args = {} if self._fuse_rope_for_trtllm_mla(forward_batch): extra_args = { @@ -3016,6 +3049,17 @@ class DeepseekV2ForCausalLM(nn.Module): ) +BackendRegistry.register("ascend", handle_ascend) +BackendRegistry.register("flashinfer", handle_flashinfer) +BackendRegistry.register("fa3", handle_fa3) +BackendRegistry.register("flashmla", handle_flashmla) +BackendRegistry.register("cutlass_mla", handle_cutlass_mla) +BackendRegistry.register("fa4", handle_fa4) +BackendRegistry.register("trtllm_mla", handle_trtllm_mla) +BackendRegistry.register("aiter", handle_aiter) +BackendRegistry.register("triton", handle_triton) + + class DeepseekV3ForCausalLM(DeepseekV2ForCausalLM): pass