From 2df9d40aa672feb3e4943d9848dce604ec96e938 Mon Sep 17 00:00:00 2001
From: fzyzcjy <5236035+fzyzcjy@users.noreply.github.com>
Date: Sat, 17 May 2025 10:06:03 +0800
Subject: [PATCH] Minor code cleanup refactor for DeepSeek models (#6324)

---
 python/sglang/srt/layers/moe/ep_moe/layer.py | 11 ++++-
 python/sglang/srt/models/deepseek_v2.py      | 50 +++++++-------
 2 files changed, 26 insertions(+), 35 deletions(-)

diff --git a/python/sglang/srt/layers/moe/ep_moe/layer.py b/python/sglang/srt/layers/moe/ep_moe/layer.py
index b39e15f4b..f91b8d5a6 100644
--- a/python/sglang/srt/layers/moe/ep_moe/layer.py
+++ b/python/sglang/srt/layers/moe/ep_moe/layer.py
@@ -5,6 +5,7 @@ import torch
 from torch.nn import Module
 
 from sglang.srt.layers.quantization.deep_gemm import _ENABLE_JIT_DEEPGEMM
+from sglang.srt.managers.schedule_batch import global_server_args_dict
 
 try:
     from deep_gemm import (
@@ -40,7 +41,7 @@ from sglang.srt.layers.moe.ep_moe.kernels import (
     tma_align_input_scale,
 )
 from sglang.srt.layers.moe.fused_moe_triton import FusedMoeWeightScaleSupported
-from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoEMethodBase
+from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE, FusedMoEMethodBase
 from sglang.srt.layers.moe.topk import select_experts
 from sglang.srt.layers.quantization.base_config import (
     QuantizationConfig,
@@ -1173,3 +1174,11 @@ class DeepEPMoE(EPMoE):
         )
 
         return down_output
+
+
+def get_moe_impl_class():
+    if global_server_args_dict["enable_deepep_moe"]:
+        return DeepEPMoE
+    if global_server_args_dict["enable_ep_moe"]:
+        return EPMoE
+    return FusedMoE
diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py
index 758a50f53..436e966db 100644
--- a/python/sglang/srt/models/deepseek_v2.py
+++ b/python/sglang/srt/models/deepseek_v2.py
@@ -52,7 +52,7 @@ from sglang.srt.layers.linear import (
     RowParallelLinear,
 )
 from sglang.srt.layers.logits_processor import LogitsProcessor
-from sglang.srt.layers.moe.ep_moe.layer import DeepEPMoE, EPMoE
+from sglang.srt.layers.moe.ep_moe.layer import DeepEPMoE, EPMoE, get_moe_impl_class
 from sglang.srt.layers.moe.ep_moe.token_dispatcher import DeepEPDispatcher
 from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
 from sglang.srt.layers.moe.topk import select_experts
@@ -222,13 +222,7 @@ class DeepseekV2MoE(nn.Module):
 
         self.gate = MoEGate(config=config, prefix=add_prefix("gate", prefix))
 
-        MoEImpl = (
-            DeepEPMoE
-            if global_server_args_dict["enable_deepep_moe"]
-            else (EPMoE if global_server_args_dict["enable_ep_moe"] else FusedMoE)
-        )
-
-        self.experts = MoEImpl(
+        self.experts = get_moe_impl_class()(
             num_experts=config.n_routed_experts + self.n_share_experts_fusion,
             top_k=config.num_experts_per_tok + min(self.n_share_experts_fusion, 1),
             hidden_size=config.hidden_size,
@@ -251,26 +245,19 @@ class DeepseekV2MoE(nn.Module):
         if config.n_shared_experts is not None and self.n_share_experts_fusion == 0:
             intermediate_size = config.moe_intermediate_size * config.n_shared_experts
             # disable tp for shared experts when enable deepep moe
-            if not global_server_args_dict["enable_deepep_moe"]:
-                self.shared_experts = DeepseekV2MLP(
-                    hidden_size=config.hidden_size,
-                    intermediate_size=intermediate_size,
-                    hidden_act=config.hidden_act,
-                    quant_config=quant_config,
-                    reduce_results=False,
-                    prefix=add_prefix("shared_experts", prefix),
-                )
-            else:
-                self.shared_experts = DeepseekV2MLP(
-                    hidden_size=config.hidden_size,
-                    intermediate_size=intermediate_size,
-                    hidden_act=config.hidden_act,
-                    quant_config=quant_config,
-                    reduce_results=False,
-                    prefix=add_prefix("shared_experts", prefix),
-                    tp_rank=0,
-                    tp_size=1,
-                )
+            self.shared_experts = DeepseekV2MLP(
+                hidden_size=config.hidden_size,
+                intermediate_size=intermediate_size,
+                hidden_act=config.hidden_act,
+                quant_config=quant_config,
+                reduce_results=False,
+                prefix=add_prefix("shared_experts", prefix),
+                **(
+                    dict(tp_rank=0, tp_size=1)
+                    if global_server_args_dict["enable_deepep_moe"]
+                    else {}
+                ),
+            )
 
         if global_server_args_dict["enable_deepep_moe"]:
             # TODO: we will support tp < ep in the future
@@ -1726,12 +1713,7 @@ class DeepseekV2ForCausalLM(nn.Module):
 
         # Params for weights, fp8 weight scales, fp8 activation scales
         # (param_name, weight_name, expert_id, shard_id)
-        MoEImpl = (
-            DeepEPMoE
-            if global_server_args_dict["enable_deepep_moe"]
-            else (EPMoE if global_server_args_dict["enable_ep_moe"] else FusedMoE)
-        )
-        expert_params_mapping = MoEImpl.make_expert_params_mapping(
+        expert_params_mapping = get_moe_impl_class().make_expert_params_mapping(
             ckpt_gate_proj_name="gate_proj",
             ckpt_down_proj_name="down_proj",
             ckpt_up_proj_name="up_proj",
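
Reviewer note: the new get_moe_impl_class() helper centralizes a backend choice
that was previously duplicated as a nested ternary at two call sites in
deepseek_v2.py. A minimal standalone sketch of the dispatch pattern follows;
the flag dict and the three classes are stand-ins for sglang's real
global_server_args_dict, FusedMoE, EPMoE, and DeepEPMoE:

    # Stand-ins for illustration only; the real objects live in sglang.
    global_server_args_dict = {"enable_deepep_moe": False, "enable_ep_moe": True}

    class FusedMoE: ...          # default Triton fused-MoE layer
    class EPMoE: ...             # expert-parallel variant
    class DeepEPMoE(EPMoE): ...  # DeepEP variant; subclasses EPMoE as in the patch

    def get_moe_impl_class():
        # Most specific backend wins; FusedMoE is the fallback.
        if global_server_args_dict["enable_deepep_moe"]:
            return DeepEPMoE
        if global_server_args_dict["enable_ep_moe"]:
            return EPMoE
        return FusedMoE

    print(get_moe_impl_class().__name__)  # -> EPMoE with the flags above

Because the helper returns the class itself rather than an instance, call
sites can either instantiate it, as in get_moe_impl_class()(...), or invoke
classmethods on it, as load_weights does with make_expert_params_mapping.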
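The shared_experts change relies on keyword-argument splatting: unpacking an
empty dict into a call adds no arguments, so the two near-identical
DeepseekV2MLP constructions collapse into one. A sketch of the pattern, with a
hypothetical MLP function standing in for DeepseekV2MLP:

    # Hypothetical stand-in; only the call pattern matters here.
    def MLP(hidden_size, tp_rank=None, tp_size=None):
        return hidden_size, tp_rank, tp_size

    enable_deepep_moe = True

    mlp = MLP(
        hidden_size=1024,
        # Force tp off (rank 0 of world size 1) only when DeepEP MoE is on.
        **(dict(tp_rank=0, tp_size=1) if enable_deepep_moe else {}),
    )
    assert mlp == (1024, 0, 1)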