From d58e35447203c293933ad300c066c8f581cb0935 Mon Sep 17 00:00:00 2001
From: Xiaoyu Zhang <35585791+BBuf@users.noreply.github.com>
Date: Sun, 20 Apr 2025 04:17:35 +0800
Subject: [PATCH] simplify the control logic for using shared experts fusion
 (#5504)

---
 python/sglang/srt/layers/moe/ep_moe/layer.py  |  3 ++
 .../sglang/srt/layers/moe/fused_moe_native.py |  4 ++
 .../layers/moe/fused_moe_triton/fused_moe.py  |  2 +
 .../srt/layers/moe/fused_moe_triton/layer.py  |  7 +++
 python/sglang/srt/layers/moe/topk.py          | 25 +++++----
 .../srt/layers/quantization/__init__.py       |  1 +
 .../srt/layers/quantization/blockwise_int8.py |  2 +
 .../compressed_tensors_moe.py                 |  4 ++
 python/sglang/srt/layers/quantization/fp8.py  |  2 +
 .../srt/layers/quantization/moe_wna16.py      |  2 +
 .../srt/layers/quantization/w8a8_fp8.py       |  2 +
 .../srt/layers/quantization/w8a8_int8.py      |  2 +
 python/sglang/srt/managers/schedule_batch.py  |  1 -
 .../sglang/srt/model_executor/model_runner.py |  1 -
 python/sglang/srt/models/deepseek_v2.py       | 52 ++++++++-----------
 python/sglang/srt/server_args.py              | 13 +----
 16 files changed, 69 insertions(+), 54 deletions(-)

diff --git a/python/sglang/srt/layers/moe/ep_moe/layer.py b/python/sglang/srt/layers/moe/ep_moe/layer.py
index a5a9eb738..1979c8d3a 100644
--- a/python/sglang/srt/layers/moe/ep_moe/layer.py
+++ b/python/sglang/srt/layers/moe/ep_moe/layer.py
@@ -136,6 +136,7 @@ class EPMoE(torch.nn.Module):
         correction_bias: Optional[torch.Tensor] = None,
         custom_routing_function: Optional[Callable] = None,
         activation: str = "silu",
+        routed_scaling_factor: Optional[float] = None,
     ):
         super().__init__()
 
@@ -164,6 +165,7 @@ class EPMoE(torch.nn.Module):
         self.correction_bias = correction_bias
         self.custom_routing_function = custom_routing_function
         self.activation = activation
+        self.routed_scaling_factor = routed_scaling_factor
 
         if quant_config is None:
             self.quant_method: Optional[QuantizeMethodBase] = UnquantizedEPMoEMethod()
@@ -215,6 +217,7 @@ class EPMoE(torch.nn.Module):
             num_expert_group=self.num_expert_group,
             correction_bias=self.correction_bias,
             custom_routing_function=self.custom_routing_function,
+            routed_scaling_factor=self.routed_scaling_factor,
         )
 
         reorder_topk_ids, src2dst, seg_indptr = run_moe_ep_preproess(
diff --git a/python/sglang/srt/layers/moe/fused_moe_native.py b/python/sglang/srt/layers/moe/fused_moe_native.py
index 57e910943..ce9940e1a 100644
--- a/python/sglang/srt/layers/moe/fused_moe_native.py
+++ b/python/sglang/srt/layers/moe/fused_moe_native.py
@@ -26,6 +26,7 @@ def fused_moe_forward_native(
     apply_router_weight_on_input: bool = False,
     inplace: bool = True,
     no_combine: bool = False,
+    routed_scaling_factor: Optional[float] = None,
 ) -> torch.Tensor:
 
     if apply_router_weight_on_input:
@@ -41,6 +42,7 @@ def fused_moe_forward_native(
         num_expert_group=num_expert_group,
         custom_routing_function=custom_routing_function,
         correction_bias=correction_bias,
+        routed_scaling_factor=routed_scaling_factor,
         torch_native=True,
     )
 
@@ -71,6 +73,7 @@ def moe_forward_native(
     custom_routing_function: Optional[Callable] = None,
     correction_bias: Optional[torch.Tensor] = None,
     activation: str = "silu",
+    routed_scaling_factor: Optional[float] = None,
 ) -> torch.Tensor:
     from sglang.srt.layers.activation import GeluAndMul, SiluAndMul
 
@@ -86,6 +89,7 @@ def moe_forward_native(
         custom_routing_function=custom_routing_function,
         correction_bias=correction_bias,
         torch_native=True,
+        routed_scaling_factor=routed_scaling_factor,
     )
 
     # Ref code from https://huggingface.co/deepseek-ai/DeepSeek-V2/blob/e0828e3cc0a03408724b80c3cc92c8e072db8d01/modeling_deepseek.py#L589
diff --git a/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py b/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py
index 61400787a..f237f2135 100644
--- a/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py
+++ b/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py
@@ -1547,6 +1547,7 @@ def fused_moe(
     a2_scale: Optional[torch.Tensor] = None,
     block_shape: Optional[List[int]] = None,
     no_combine: bool = False,
+    routed_scaling_factor: Optional[float] = None,
 ) -> torch.Tensor:
     """
     This function computes a Mixture of Experts (MoE) layer using two sets of
@@ -1601,6 +1602,7 @@ def fused_moe(
         topk_group=topk_group,
         num_expert_group=num_expert_group,
         custom_routing_function=custom_routing_function,
+        routed_scaling_factor=routed_scaling_factor,
     )
 
     return fused_experts(
diff --git a/python/sglang/srt/layers/moe/fused_moe_triton/layer.py b/python/sglang/srt/layers/moe/fused_moe_triton/layer.py
index a33cf691f..6ce240cb3 100644
--- a/python/sglang/srt/layers/moe/fused_moe_triton/layer.py
+++ b/python/sglang/srt/layers/moe/fused_moe_triton/layer.py
@@ -131,6 +131,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
         apply_router_weight_on_input: bool = False,
         inplace: bool = True,
         no_combine: bool = False,
+        routed_scaling_factor: Optional[float] = None,
     ) -> torch.Tensor:
         return self.forward(
             x=x,
@@ -147,6 +148,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
             apply_router_weight_on_input=apply_router_weight_on_input,
             inplace=inplace,
             no_combine=no_combine,
+            routed_scaling_factor=routed_scaling_factor,
         )
 
     def forward_cuda(
@@ -165,6 +167,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
         apply_router_weight_on_input: bool = False,
         inplace: bool = True,
         no_combine: bool = False,
+        routed_scaling_factor: Optional[float] = None,
     ) -> torch.Tensor:
         topk_weights, topk_ids = select_experts(
             hidden_states=x,
@@ -176,6 +179,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
             num_expert_group=num_expert_group,
             custom_routing_function=custom_routing_function,
             correction_bias=correction_bias,
+            routed_scaling_factor=routed_scaling_factor,
         )
 
         if _is_hip and get_bool_env_var("CK_MOE"):
@@ -284,6 +288,7 @@ class FusedMoE(torch.nn.Module):
         use_presharded_weights: bool = False,
         inplace: bool = True,
         no_combine: bool = False,
+        routed_scaling_factor: Optional[float] = None,
     ):
         super().__init__()
 
@@ -293,6 +298,7 @@ class FusedMoE(torch.nn.Module):
         self.tp_size = (
             tp_size if tp_size is not None else get_tensor_model_parallel_world_size()
         )
+        self.routed_scaling_factor = routed_scaling_factor
         self.top_k = top_k
         self.num_experts = num_experts
         assert intermediate_size % self.tp_size == 0
@@ -637,6 +643,7 @@ class FusedMoE(torch.nn.Module):
             correction_bias=self.correction_bias,
             activation=self.activation,
             apply_router_weight_on_input=self.apply_router_weight_on_input,
+            routed_scaling_factor=self.routed_scaling_factor,
         )
 
         if self.reduce_results and self.tp_size > 1:
diff --git a/python/sglang/srt/layers/moe/topk.py b/python/sglang/srt/layers/moe/topk.py
index e59f0e299..47915cf40 100644
--- a/python/sglang/srt/layers/moe/topk.py
+++ b/python/sglang/srt/layers/moe/topk.py
@@ -98,6 +98,7 @@ def grouped_topk(
     num_expert_group: int = 0,
     topk_group: int = 0,
     n_share_experts_fusion: int = 0,
+    routed_scaling_factor: Optional[float] = None,
 ):
     assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch"
 
@@ -127,9 +128,7 @@ def grouped_topk(
             dtype=topk_ids.dtype,
             device=topk_ids.device,
         )
-        topk_weights[:, -1] = (
-            topk_weights[:, :-1].sum(dim=-1) / 2.5
-        )  # 2.5 is the routed_scaling_factor.
+        topk_weights[:, -1] = topk_weights[:, :-1].sum(dim=-1) / routed_scaling_factor
 
     if renormalize:
         topk_weights_sum = (
@@ -151,6 +150,7 @@ def biased_grouped_topk_impl(
     num_expert_group: int = 0,
     topk_group: int = 0,
     n_share_experts_fusion: int = 0,
+    routed_scaling_factor: Optional[float] = None,
 ):
     assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch"
 
@@ -187,9 +187,7 @@ def biased_grouped_topk_impl(
             dtype=topk_ids.dtype,
             device=topk_ids.device,
         )
-        topk_weights[:, -1] = (
-            topk_weights[:, :-1].sum(dim=-1) / 2.5
-        )  # 2.5 is the routed_scaling_factor.
+        topk_weights[:, -1] = topk_weights[:, :-1].sum(dim=-1) / routed_scaling_factor
 
     if renormalize:
         topk_weights_sum = (
@@ -216,13 +214,16 @@ def biased_grouped_topk(
     topk_group: int = 0,
     compiled: bool = True,
     n_share_experts_fusion: int = 0,
+    routed_scaling_factor: Optional[float] = None,
 ):
+    assert (
+        routed_scaling_factor is not None
+    ), "routed_scaling_factor is required for biased_grouped_topk"
     # TODO: moe_fused_gate kernel is not supported for n_share_experts_fusion > 0 now.
     if (
         _is_cuda
         and gating_output.shape[1] // num_expert_group
         <= 32  # moe_fused_gate kernel ensures that num_experts/num_expert_group does not exceed MAX_VPT=32 now. And when kernel can handle MAX_VPT > 32, we can remove this assertion.
-        and n_share_experts_fusion == 0
         and is_power_of_two(correction_bias.shape[0])
     ):
         return moe_fused_gate(
@@ -231,6 +232,8 @@ def biased_grouped_topk(
             num_expert_group,
             topk_group,
             topk,
+            n_share_experts_fusion,
+            routed_scaling_factor,
         )
     else:
         biased_grouped_topk_fn = (
@@ -249,6 +252,7 @@ def biased_grouped_topk(
             num_expert_group,
             topk_group,
             n_share_experts_fusion=n_share_experts_fusion,
+            routed_scaling_factor=routed_scaling_factor,
         )
 
 
@@ -263,10 +267,9 @@ def select_experts(
     custom_routing_function: Optional[Callable] = None,
     correction_bias: Optional[torch.Tensor] = None,
     torch_native: bool = False,
+    routed_scaling_factor: Optional[float] = None,
 ):
-    n_share_experts_fusion = 0
-    if global_server_args_dict["n_share_experts_fusion"] is not None:
-        n_share_experts_fusion = global_server_args_dict["n_share_experts_fusion"]
+    n_share_experts_fusion = global_server_args_dict["n_share_experts_fusion"]
     # DeepSeek V2/V3/R1 series models use grouped_topk
     if use_grouped_topk:
         assert topk_group is not None
@@ -280,6 +283,7 @@ def select_experts(
                 num_expert_group=num_expert_group,
                 topk_group=topk_group,
                 n_share_experts_fusion=n_share_experts_fusion,
+                routed_scaling_factor=routed_scaling_factor,
             )
         else:
             topk_weights, topk_ids = biased_grouped_topk(
@@ -291,6 +295,7 @@ def select_experts(
                 num_expert_group=num_expert_group,
                 topk_group=topk_group,
                 n_share_experts_fusion=n_share_experts_fusion,
+                routed_scaling_factor=routed_scaling_factor,
             )
     elif torch_native and custom_routing_function is None:
         topk_weights, topk_ids = fused_topk_native(
diff --git a/python/sglang/srt/layers/quantization/__init__.py b/python/sglang/srt/layers/quantization/__init__.py
index 885f9fe50..6b8719bfa 100644
--- a/python/sglang/srt/layers/quantization/__init__.py
+++ b/python/sglang/srt/layers/quantization/__init__.py
@@ -290,6 +290,7 @@ def monkey_patch_moe_apply(class_obj: "FusedMoEMethodBase"):
         apply_router_weight_on_input: bool = False,
         inplace: bool = True,
         no_combine: bool = False,
+        routed_scaling_factor: Optional[float] = None,
     ):
         assert activation == "silu"
         assert inplace and not no_combine
diff --git a/python/sglang/srt/layers/quantization/blockwise_int8.py b/python/sglang/srt/layers/quantization/blockwise_int8.py
index c147d5b2f..25c91da6e 100644
--- a/python/sglang/srt/layers/quantization/blockwise_int8.py
+++ b/python/sglang/srt/layers/quantization/blockwise_int8.py
@@ -373,6 +373,7 @@ class BlockInt8MoEMethod:
         apply_router_weight_on_input: bool = False,
         inplace: bool = True,
         no_combine: bool = False,
+        routed_scaling_factor: Optional[float] = None,
     ) -> torch.Tensor:
         from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
         from sglang.srt.layers.moe.topk import select_experts
@@ -388,6 +389,7 @@ class BlockInt8MoEMethod:
             num_expert_group=num_expert_group,
             custom_routing_function=custom_routing_function,
             correction_bias=correction_bias,
+            routed_scaling_factor=routed_scaling_factor,
         )
 
         # Expert fusion with INT8 quantization
diff --git a/python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py b/python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py
index b8d9d637e..67496b14b 100644
--- a/python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py
+++ b/python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py
@@ -283,6 +283,7 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
         inplace: bool = True,
         no_combine: bool = False,
         apply_router_weight_on_input: bool = False,
+        routed_scaling_factor: Optional[float] = None,
     ) -> torch.Tensor:
         from sglang.srt.layers.moe.fused_moe_triton import fused_experts
         from sglang.srt.layers.moe.topk import select_experts
@@ -297,6 +298,7 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
             num_expert_group=num_expert_group,
             custom_routing_function=custom_routing_function,
             correction_bias=correction_bias,
+            routed_scaling_factor=routed_scaling_factor,
         )
 
         return fused_experts(
@@ -633,6 +635,7 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
         scoring_func: str = "softmax",
         correction_bias: Optional[torch.Tensor] = None,
         activation: str = "silu",
+        routed_scaling_factor: Optional[float] = None,
     ) -> torch.Tensor:
         from sglang.srt.layers.moe.topk import select_experts
 
@@ -653,6 +656,7 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
             custom_routing_function=custom_routing_function,
             scoring_func=scoring_func,
             correction_bias=correction_bias,
+            routed_scaling_factor=routed_scaling_factor,
        )
 
         return torch.ops.vllm.fused_marlin_moe(
diff --git a/python/sglang/srt/layers/quantization/fp8.py b/python/sglang/srt/layers/quantization/fp8.py
index 5ba2b3fb8..b37a86887 100644
--- a/python/sglang/srt/layers/quantization/fp8.py
+++ b/python/sglang/srt/layers/quantization/fp8.py
@@ -892,6 +892,7 @@ class Fp8MoEMethod:
         apply_router_weight_on_input: bool = False,
         inplace: bool = True,
         no_combine: bool = False,
+        routed_scaling_factor: Optional[float] = None,
     ) -> torch.Tensor:
         from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
         from sglang.srt.layers.moe.topk import select_experts
@@ -907,6 +908,7 @@ class Fp8MoEMethod:
             num_expert_group=num_expert_group,
             custom_routing_function=custom_routing_function,
             correction_bias=correction_bias,
+            routed_scaling_factor=routed_scaling_factor,
         )
 
         if _is_hip:
diff --git a/python/sglang/srt/layers/quantization/moe_wna16.py b/python/sglang/srt/layers/quantization/moe_wna16.py
index 4c3e1dfc7..122f310dc 100644
--- a/python/sglang/srt/layers/quantization/moe_wna16.py
+++ b/python/sglang/srt/layers/quantization/moe_wna16.py
@@ -347,6 +347,7 @@ class MoeWNA16Method:
         apply_router_weight_on_input: bool = False,
         inplace: bool = True,
         no_combine: bool = False,
+        routed_scaling_factor: Optional[float] = None,
     ) -> torch.Tensor:
         # avoid circular import
         from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
@@ -363,6 +364,7 @@ class MoeWNA16Method:
             num_expert_group=num_expert_group,
             custom_routing_function=custom_routing_function,
             correction_bias=correction_bias,
+            routed_scaling_factor=routed_scaling_factor,
         )
 
         weight_bits = self.quant_config.weight_bits
diff --git a/python/sglang/srt/layers/quantization/w8a8_fp8.py b/python/sglang/srt/layers/quantization/w8a8_fp8.py
index 12d1eab19..48cf5db34 100644
--- a/python/sglang/srt/layers/quantization/w8a8_fp8.py
+++ b/python/sglang/srt/layers/quantization/w8a8_fp8.py
@@ -294,6 +294,7 @@ class W8A8FP8MoEMethod:
         activation: str = "silu",
         inplace: bool = True,
         no_combine: bool = False,
+        routed_scaling_factor: Optional[float] = None,
     ) -> torch.Tensor:
         from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
         from sglang.srt.layers.moe.topk import select_experts
@@ -309,6 +310,7 @@ class W8A8FP8MoEMethod:
             num_expert_group=num_expert_group,
             custom_routing_function=custom_routing_function,
             correction_bias=correction_bias,
+            routed_scaling_factor=routed_scaling_factor,
         )
 
         return fused_experts(
diff --git a/python/sglang/srt/layers/quantization/w8a8_int8.py b/python/sglang/srt/layers/quantization/w8a8_int8.py
index df345a0a2..829e9e8ae 100644
--- a/python/sglang/srt/layers/quantization/w8a8_int8.py
+++ b/python/sglang/srt/layers/quantization/w8a8_int8.py
@@ -231,6 +231,7 @@ class W8A8Int8MoEMethod:
         apply_router_weight_on_input: bool = False,
         inplace: bool = True,
         no_combine: bool = False,
+        routed_scaling_factor: Optional[float] = None,
     ) -> torch.Tensor:
         from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
         from sglang.srt.layers.moe.topk import select_experts
@@ -246,6 +247,7 @@ class W8A8Int8MoEMethod:
             num_expert_group=num_expert_group,
             custom_routing_function=custom_routing_function,
             correction_bias=correction_bias,
+            routed_scaling_factor=routed_scaling_factor,
         )
 
         return fused_experts(
diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py
index 2b5cd1b62..a324a2e5d 100644
--- a/python/sglang/srt/managers/schedule_batch.py
+++ b/python/sglang/srt/managers/schedule_batch.py
@@ -81,7 +81,6 @@ global_server_args_dict = {
     "moe_dense_tp_size": ServerArgs.moe_dense_tp_size,
     "chunked_prefill_size": ServerArgs.chunked_prefill_size,
     "n_share_experts_fusion": ServerArgs.n_share_experts_fusion,
-    "disable_shared_experts_fusion": ServerArgs.disable_shared_experts_fusion,
     "disable_chunked_prefix_cache": ServerArgs.disable_chunked_prefix_cache,
 }
 
diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py
index 5f226f870..97f5888ae 100644
--- a/python/sglang/srt/model_executor/model_runner.py
+++ b/python/sglang/srt/model_executor/model_runner.py
@@ -163,7 +163,6 @@ class ModelRunner:
                 "debug_tensor_dump_output_folder": server_args.debug_tensor_dump_output_folder,
                 "debug_tensor_dump_inject": server_args.debug_tensor_dump_inject,
                 "n_share_experts_fusion": server_args.n_share_experts_fusion,
-                "disable_shared_experts_fusion": server_args.disable_shared_experts_fusion,
                 "disable_chunked_prefix_cache": server_args.disable_chunked_prefix_cache,
                 "use_mla_backend": self.use_mla_backend,
             }
diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py
index 26073bd67..96c63d0ae 100644
--- a/python/sglang/srt/models/deepseek_v2.py
+++ b/python/sglang/srt/models/deepseek_v2.py
@@ -189,11 +189,7 @@ class DeepseekV2MoE(nn.Module):
         self.tp_size = get_tensor_model_parallel_world_size()
         self.routed_scaling_factor = config.routed_scaling_factor
         self.n_shared_experts = config.n_shared_experts
-        self.n_share_experts_fusion = (
-            global_server_args_dict["n_share_experts_fusion"]
-            if global_server_args_dict["n_share_experts_fusion"] is not None
-            else 0
-        )
+        self.n_share_experts_fusion = global_server_args_dict["n_share_experts_fusion"]
 
         if self.tp_size > config.n_routed_experts:
             raise ValueError(
@@ -226,6 +222,7 @@ class DeepseekV2MoE(nn.Module):
             num_expert_group=config.n_group,
             topk_group=config.topk_group,
             correction_bias=self.gate.e_score_correction_bias,
+            routed_scaling_factor=self.routed_scaling_factor,
             prefix=add_prefix("experts", prefix),
             **(
                 dict(deepep_mode=DeepEPMode[global_server_args_dict["deepep_mode"]])
@@ -334,6 +331,7 @@ class DeepseekV2MoE(nn.Module):
             topk_group=self.topk_group,
             num_expert_group=self.num_expert_group,
             correction_bias=self.correction_bias,
+            routed_scaling_factor=self.routed_scaling_factor,
         )
         if self.ep_size > 1:
             # TODO(ch-wan): allow users to set num_max_dispatch_tokens_per_rank value
@@ -374,7 +372,7 @@ class DeepseekV2MoE(nn.Module):
         return final_hidden_states
 
     def _forward_shared_experts(self, hidden_states):
-        if self.n_shared_experts is not None and self.n_share_experts_fusion == 0:
+        if self.n_share_experts_fusion == 0:
             return self.shared_experts(hidden_states)
         else:
             return None
@@ -1346,24 +1344,21 @@ class DeepseekV2ForCausalLM(nn.Module):
         self.tp_size = get_tensor_model_parallel_world_size()
         self.quant_config = quant_config
         self.n_share_experts_fusion = global_server_args_dict["n_share_experts_fusion"]
-        # Only Deepseek V3/R1 can use shared experts fusion optimization now.
-        if (
-            global_server_args_dict.get("disable_shared_experts_fusion", False)
-            or self.config.architectures[0] != "DeepseekV3ForCausalLM"
-            or self.config.n_routed_experts != 256
-            or self.config.routed_scaling_factor != 2.5
-        ):
-            self.n_share_experts_fusion = None
-            global_server_args_dict["n_share_experts_fusion"] = None
-            logger.info(
-                "Only Deepseek V3/R1 can use shared experts fusion optimization. Shared experts fusion optimization is disabled."
-            )
-        elif self.n_share_experts_fusion is None:
-            global_server_args_dict["n_share_experts_fusion"] = self.tp_size
-            self.n_share_experts_fusion = self.tp_size
-            logger.info(
-                f"Shared experts fusion optimization is default enabled in DeepSeek V3/R1, and n_share_experts_fusion is set to {self.tp_size}. You can tune it by setting --n_share_experts_fusion or disable it by setting --disable_shared_experts_fusion."
-            )
+        if self.n_share_experts_fusion > 0:
+            # Only DeepSeek V3/R1 can use the shared experts fusion optimization now.
+            if (
+                self.config.architectures[0] != "DeepseekV3ForCausalLM"
+                or self.config.n_routed_experts != 256
+            ):
+                self.n_share_experts_fusion = 0
+                global_server_args_dict["n_share_experts_fusion"] = 0
+                logger.info(
+                    "Only DeepSeek V3/R1 can use shared experts fusion optimization. Disabling it for this model."
+                )
+            else:
+                assert (
+                    self.n_share_experts_fusion == self.tp_size
+                ), f"Shared experts fusion optimization is enabled in DeepSeek V3/R1; set n_share_experts_fusion to {self.tp_size} (the TP size) for the best performance."
 
         self.model = DeepseekV2Model(
             config, quant_config, prefix=add_prefix("model", prefix)
@@ -1484,7 +1479,7 @@ class DeepseekV2ForCausalLM(nn.Module):
             ("gate_up_proj", "gate_proj", 0),
             ("gate_up_proj", "up_proj", 1),
         ]
-        if self.n_share_experts_fusion is not None and self.n_share_experts_fusion > 0:
+        if self.n_share_experts_fusion > 0:
             weights_list = list(weights)
             weights_dict = dict(weights_list)
             if self.quant_config.get_name() == "w8a8_int8":
@@ -1543,12 +1538,7 @@ class DeepseekV2ForCausalLM(nn.Module):
             ckpt_gate_proj_name="gate_proj",
             ckpt_down_proj_name="down_proj",
             ckpt_up_proj_name="up_proj",
-            num_experts=self.config.n_routed_experts
-            + (
-                self.n_share_experts_fusion
-                if self.n_share_experts_fusion is not None
-                else 0
-            ),
+            num_experts=self.config.n_routed_experts + self.n_share_experts_fusion,
         )
 
         params_dict = dict(self.named_parameters())
diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py
index 25ad02a77..bb8887168 100644
--- a/python/sglang/srt/server_args.py
+++ b/python/sglang/srt/server_args.py
@@ -183,7 +183,6 @@ class ServerArgs:
     warmups: Optional[str] = None
     moe_dense_tp_size: Optional[int] = None
     n_share_experts_fusion: int = 0
-    disable_shared_experts_fusion: bool = False
    disable_chunked_prefix_cache: bool = False
     disable_fast_image_processor: bool = False
 
@@ -229,9 +228,6 @@
         # GPU memory is not known yet or no GPU is available.
         gpu_mem = None
 
-        if is_hip():
-            self.disable_shared_experts_fusion = True
-
         # Set mem fraction static, which depends on the tensor parallelism size
         if self.mem_fraction_static is None:
             if self.tp_size >= 16:
@@ -1126,13 +1122,8 @@
             "--n-share-experts-fusion",
             type=int,
             default=0,
-            help="The number of shared_experts need to be replica to fuse with normal experts in deepseek v3/r1 "
-            "we use tp_size by default.",
-        )
-        parser.add_argument(
-            "--disable-shared-experts-fusion",
-            action="store_true",
-            help="Disable shared experts fusion by setting n_share_experts_fusion to 0.",
+            help="The number of shared_experts to be replicated and fused with the normal experts in DeepSeek V3/R1; "
+            "setting it to tp_size gives the best performance.",
         )
         parser.add_argument(
             "--disable-chunked-prefix-cache",
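Note on the routing change that most of the hunks above thread through: when shared experts fusion is active, the router appends one extra expert slot per token whose weight is the sum of the routed top-k weights divided by routed_scaling_factor. Previously grouped_topk and biased_grouped_topk_impl hard-coded DeepSeek V3/R1's factor of 2.5; now the model's configured value is passed down through select_experts. Below is a minimal standalone sketch of that weight construction; the toy sizes are made up, and the slot-appending shapes mirror grouped_topk, whose pre-extension code is only partially visible in the hunks.

    import torch

    # Toy dimensions for illustration only; real values come from the model config.
    num_tokens, top_k, n_routed_experts = 4, 8, 256
    routed_scaling_factor = 2.5  # DeepSeek V3/R1 config value, no longer hard-coded

    topk_weights = torch.rand(num_tokens, top_k)  # routed-expert weights
    topk_ids = torch.randint(0, n_routed_experts, (num_tokens, top_k))

    # Append one fused shared-expert slot per token. Its id points at a shared-expert
    # replica stored after the routed experts; its weight is the sum of the routed
    # weights scaled down by routed_scaling_factor.
    shared_ids = torch.full((num_tokens, 1), n_routed_experts, dtype=topk_ids.dtype)
    shared_weights = topk_weights.sum(dim=-1, keepdim=True) / routed_scaling_factor
    topk_ids = torch.cat([topk_ids, shared_ids], dim=1)
    topk_weights = torch.cat([topk_weights, shared_weights], dim=1)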
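For readers of the deepseek_v2.py hunk, the control flow it replaces the old auto-enable/disable logic with is roughly the following. This is a condensed paraphrase, not the verbatim code; config and tp_size stand in for the HF model config and the TP world size.

    # Condensed paraphrase of the new DeepseekV2ForCausalLM.__init__ logic.
    n_share_experts_fusion = global_server_args_dict["n_share_experts_fusion"]  # default 0 = off
    if n_share_experts_fusion > 0:
        if config.architectures[0] != "DeepseekV3ForCausalLM" or config.n_routed_experts != 256:
            # Only DeepSeek V3/R1 supports the fusion; reset to 0 (the real code logs this).
            n_share_experts_fusion = 0
        else:
            # One shared-expert replica per TP rank.
            assert n_share_experts_fusion == tp_size

In other words, the optimization is now strictly opt-in: pass --n-share-experts-fusion equal to the TP size (e.g. 8 when running with tp 8) to enable it, and the former --disable-shared-experts-fusion flag is removed along with the is_hip() special case.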