From b65db0287b55f4fbeaf318f03e4ab76ff6a2e7d8 Mon Sep 17 00:00:00 2001 From: fzyzcjy <5236035+fzyzcjy@users.noreply.github.com> Date: Thu, 2 Oct 2025 21:54:52 +0800 Subject: [PATCH] Tiny cleanup deepseek_v2.py (#11163) --- .../srt/layers/moe/fused_moe_triton/layer.py | 13 ++-- python/sglang/srt/models/deepseek_v2.py | 63 +++++++++---------- 2 files changed, 38 insertions(+), 38 deletions(-) diff --git a/python/sglang/srt/layers/moe/fused_moe_triton/layer.py b/python/sglang/srt/layers/moe/fused_moe_triton/layer.py index 27d41184e..d3583975d 100644 --- a/python/sglang/srt/layers/moe/fused_moe_triton/layer.py +++ b/python/sglang/srt/layers/moe/fused_moe_triton/layer.py @@ -234,6 +234,13 @@ class FusedMoE(torch.nn.Module): self.quant_method.create_moe_runner(self, self.moe_runner_config) self.dispatcher = StandardDispatcher() + self.should_fuse_routed_scaling_factor_in_topk = isinstance( + self.quant_method, ModelOptNvFp4FusedMoEMethod + ) or ( + isinstance(self.quant_method, Fp8MoEMethod) + and self.quant_method.use_cutlass_fused_experts_fp8 + ) + def _load_per_tensor_weight_scale( self, shard_id: str, @@ -936,12 +943,6 @@ class FusedMoE(torch.nn.Module): for shard_id in ["w1", "w2", "w3"] ] - def should_fuse_routed_scaling_factor_in_topk(self): - return isinstance(self.quant_method, ModelOptNvFp4FusedMoEMethod) or ( - isinstance(self.quant_method, Fp8MoEMethod) - and self.quant_method.use_cutlass_fused_experts_fp8 - ) - class FlashInferFusedMoE(FusedMoE): def __init__(self, *args, **kwargs): diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py index 131786946..336a5b68c 100644 --- a/python/sglang/srt/models/deepseek_v2.py +++ b/python/sglang/srt/models/deepseek_v2.py @@ -166,16 +166,15 @@ if _is_cuda: elif _is_cpu and _is_cpu_amx_available: pass elif _is_hip: + from sglang.srt.layers.attention.triton_ops.rocm_mla_decode_rope import ( + decode_attention_fwd_grouped_rope, + ) from sglang.srt.layers.quantization.awq_triton import ( awq_dequantize_triton as awq_dequantize, ) else: pass -if _is_hip: - from sglang.srt.layers.attention.triton_ops.rocm_mla_decode_rope import ( - decode_attention_fwd_grouped_rope, - ) _is_flashinfer_available = is_flashinfer_available() _is_sm100_supported = is_cuda() and is_sm100_supported() @@ -229,7 +228,7 @@ def _dispatch_mla_subtype(attn, forward_batch): return AttnForwardMethod.MLA -class BackendRegistry: +class AttentionBackendRegistry: _handlers = {} @classmethod @@ -241,7 +240,7 @@ class BackendRegistry: return cls._handlers.get(backend_name, cls._handlers.get("triton")) -def handle_ascend(attn, forward_batch): +def handle_attention_ascend(attn, forward_batch): if ( forward_batch.forward_mode.is_extend() and not forward_batch.forward_mode.is_target_verify() @@ -268,7 +267,7 @@ def _is_extend_without_speculative(forward_batch): ) -def _handle_backend(attn, forward_batch, backend_name): +def _handle_attention_backend(attn, forward_batch, backend_name): sum_extend_prefix_lens = _get_sum_extend_prefix_lens(forward_batch) disable_ragged = ( backend_name in ["flashinfer", "flashmla"] @@ -290,28 +289,28 @@ def _handle_backend(attn, forward_batch, backend_name): return _dispatch_mla_subtype(attn, forward_batch) -def handle_flashinfer(attn, forward_batch): - return _handle_backend(attn, forward_batch, "flashinfer") +def handle_attention_flashinfer(attn, forward_batch): + return _handle_attention_backend(attn, forward_batch, "flashinfer") -def handle_fa3(attn, forward_batch): - return _handle_backend(attn, forward_batch, "fa3") +def handle_attention_fa3(attn, forward_batch): + return _handle_attention_backend(attn, forward_batch, "fa3") -def handle_flashmla(attn, forward_batch): - return _handle_backend(attn, forward_batch, "flashmla") +def handle_attention_flashmla(attn, forward_batch): + return _handle_attention_backend(attn, forward_batch, "flashmla") -def handle_cutlass_mla(attn, forward_batch): - return _handle_backend(attn, forward_batch, "cutlass_mla") +def handle_attention_cutlass_mla(attn, forward_batch): + return _handle_attention_backend(attn, forward_batch, "cutlass_mla") -def handle_fa4(attn, forward_batch): +def handle_attention_fa4(attn, forward_batch): # TODO(cicirori): use FA4 MHA for DeepSeekV3 for now return AttnForwardMethod.MHA_CHUNKED_KV -def handle_trtllm_mla(attn, forward_batch): +def handle_attention_trtllm_mla(attn, forward_batch): sum_extend_prefix_lens = _get_sum_extend_prefix_lens(forward_batch) if _is_extend_without_speculative(forward_batch) and ( not attn.disable_chunked_prefix_cache or sum_extend_prefix_lens == 0 @@ -321,7 +320,7 @@ def handle_trtllm_mla(attn, forward_batch): return _dispatch_mla_subtype(attn, forward_batch) -def handle_aiter(attn, forward_batch): +def handle_attention_aiter(attn, forward_batch): if _is_extend_without_speculative(forward_batch): if is_dp_attention_enabled(): if sum(forward_batch.extend_prefix_lens_cpu) == 0: @@ -334,7 +333,7 @@ def handle_aiter(attn, forward_batch): return AttnForwardMethod.MLA -def handle_triton(attn, forward_batch): +def handle_attention_triton(attn, forward_batch): if ( _is_extend_without_speculative(forward_batch) and sum(forward_batch.extend_prefix_lens_cpu) == 0 @@ -541,7 +540,7 @@ class DeepseekV2MoE(nn.Module): correction_bias=self.gate.e_score_correction_bias, quant_config=quant_config, routed_scaling_factor=self.routed_scaling_factor, - apply_routed_scaling_factor_on_output=self.experts.should_fuse_routed_scaling_factor_in_topk(), + apply_routed_scaling_factor_on_output=self.experts.should_fuse_routed_scaling_factor_in_topk, # Some Fp4 MoE backends require the output format to be bypassed but the MTP layers are unquantized # and requires the output format to be standard. We use quant_config to determine the output format. output_format=TopKOutputFormat.STANDARD if quant_config is None else None, @@ -838,13 +837,13 @@ class DeepseekV2MoE(nn.Module): if shared_output is not None: x = shared_output - if self.experts.should_fuse_routed_scaling_factor_in_topk(): + if self.experts.should_fuse_routed_scaling_factor_in_topk: x.add_(final_hidden_states) else: x.add_(final_hidden_states, alpha=self.routed_scaling_factor) final_hidden_states = x else: - if not self.experts.should_fuse_routed_scaling_factor_in_topk(): + if not self.experts.should_fuse_routed_scaling_factor_in_topk: final_hidden_states *= self.routed_scaling_factor return final_hidden_states @@ -1217,7 +1216,7 @@ class DeepseekV2AttentionMLA(nn.Module): attention_backend = global_server_args_dict["prefill_attention_backend"] self.current_attention_backend = attention_backend - handler = BackendRegistry.get_handler(attention_backend) + handler = AttentionBackendRegistry.get_handler(attention_backend) return handler(self, forward_batch) def op_prepare(self, state): @@ -3092,15 +3091,15 @@ class DeepseekV2ForCausalLM(nn.Module): ) -BackendRegistry.register("ascend", handle_ascend) -BackendRegistry.register("flashinfer", handle_flashinfer) -BackendRegistry.register("fa3", handle_fa3) -BackendRegistry.register("flashmla", handle_flashmla) -BackendRegistry.register("cutlass_mla", handle_cutlass_mla) -BackendRegistry.register("fa4", handle_fa4) -BackendRegistry.register("trtllm_mla", handle_trtllm_mla) -BackendRegistry.register("aiter", handle_aiter) -BackendRegistry.register("triton", handle_triton) +AttentionBackendRegistry.register("ascend", handle_attention_ascend) +AttentionBackendRegistry.register("flashinfer", handle_attention_flashinfer) +AttentionBackendRegistry.register("fa3", handle_attention_fa3) +AttentionBackendRegistry.register("flashmla", handle_attention_flashmla) +AttentionBackendRegistry.register("cutlass_mla", handle_attention_cutlass_mla) +AttentionBackendRegistry.register("fa4", handle_attention_fa4) +AttentionBackendRegistry.register("trtllm_mla", handle_attention_trtllm_mla) +AttentionBackendRegistry.register("aiter", handle_attention_aiter) +AttentionBackendRegistry.register("triton", handle_attention_triton) class DeepseekV3ForCausalLM(DeepseekV2ForCausalLM):