diff --git a/python/sglang/srt/layers/moe/ep_moe/layer.py b/python/sglang/srt/layers/moe/ep_moe/layer.py
index 3bd48b241..f0c0c5f6e 100644
--- a/python/sglang/srt/layers/moe/ep_moe/layer.py
+++ b/python/sglang/srt/layers/moe/ep_moe/layer.py
@@ -180,6 +180,9 @@ class EPMoE(torch.nn.Module):
         self.layer_id = layer_id
         self.num_experts = num_experts
         assert self.num_experts % self.tp_size == 0
+        assert (
+            num_fused_shared_experts == 0
+        ), "num_fused_shared_experts is not supported in EP"
         self.num_experts_per_partition = self.num_experts // self.tp_size
         self.start_expert_id = self.tp_rank * self.num_experts_per_partition
         self.end_expert_id = self.start_expert_id + self.num_experts_per_partition - 1
@@ -191,7 +194,6 @@ class EPMoE(torch.nn.Module):
         if self.use_grouped_topk:
             assert num_expert_group is not None and topk_group is not None
         self.num_expert_group = num_expert_group
-        self.num_fused_shared_experts = num_fused_shared_experts
         self.topk_group = topk_group
         self.correction_bias = correction_bias
         self.custom_routing_function = custom_routing_function
@@ -252,7 +254,6 @@ class EPMoE(torch.nn.Module):
             renormalize=self.renormalize,
             topk_group=self.topk_group,
             num_expert_group=self.num_expert_group,
-            num_fused_shared_experts=self.num_fused_shared_experts,
             correction_bias=self.correction_bias,
             custom_routing_function=self.custom_routing_function,
             routed_scaling_factor=self.routed_scaling_factor,
@@ -886,6 +887,7 @@ class DeepEPMoE(EPMoE):
         renormalize: bool = True,
         use_grouped_topk: bool = False,
         num_expert_group: Optional[int] = None,
+        num_fused_shared_experts: int = 0,
         topk_group: Optional[int] = None,
         quant_config: Optional[QuantizationConfig] = None,
         tp_size: Optional[int] = None,
@@ -897,23 +899,24 @@ class DeepEPMoE(EPMoE):
         deepep_mode: DeepEPMode = DeepEPMode.auto,
     ):
         super().__init__(
-            num_experts,
-            top_k,
-            hidden_size,
-            intermediate_size,
-            layer_id,
-            params_dtype,
-            renormalize,
-            use_grouped_topk,
-            num_expert_group,
-            topk_group,
-            quant_config,
-            tp_size,
-            prefix,
-            correction_bias,
-            custom_routing_function,
-            activation,
-            routed_scaling_factor,
+            num_experts=num_experts,
+            top_k=top_k,
+            hidden_size=hidden_size,
+            intermediate_size=intermediate_size,
+            layer_id=layer_id,
+            params_dtype=params_dtype,
+            renormalize=renormalize,
+            use_grouped_topk=use_grouped_topk,
+            num_expert_group=num_expert_group,
+            num_fused_shared_experts=num_fused_shared_experts,
+            topk_group=topk_group,
+            quant_config=quant_config,
+            tp_size=tp_size,
+            prefix=prefix,
+            correction_bias=correction_bias,
+            custom_routing_function=custom_routing_function,
+            activation=activation,
+            routed_scaling_factor=routed_scaling_factor,
         )
         self.deepep_mode = deepep_mode
         if self.deepep_mode.enable_low_latency():
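
For context, a minimal sketch of the behavior this diff enforces; the stripped-down constructors below are hypothetical stand-ins (only the arguments relevant here are kept, not the real sglang classes). Constructing an EP MoE layer with `num_fused_shared_experts != 0` now fails fast with the new assertion, and `DeepEPMoE` forwards the flag to the parent by keyword, so inserting the new parameter into the signature cannot silently shift the meaning of the remaining arguments in the `super().__init__` call.

```python
# Minimal sketch, not the real sglang constructors: EPMoE rejects fused shared
# experts up front, and DeepEPMoE accepts the flag only to forward it by keyword.
class EPMoE:
    def __init__(self, num_experts: int, tp_size: int, num_fused_shared_experts: int = 0):
        assert (
            num_fused_shared_experts == 0
        ), "num_fused_shared_experts is not supported in EP"
        # Experts are split evenly across the expert-parallel ranks.
        assert num_experts % tp_size == 0
        self.num_experts = num_experts
        self.num_experts_per_partition = num_experts // tp_size


class DeepEPMoE(EPMoE):
    def __init__(self, num_experts: int, tp_size: int, num_fused_shared_experts: int = 0):
        # Keyword arguments keep the parent call unambiguous when a parameter
        # is added in the middle of the signature.
        super().__init__(
            num_experts=num_experts,
            tp_size=tp_size,
            num_fused_shared_experts=num_fused_shared_experts,
        )


if __name__ == "__main__":
    EPMoE(num_experts=16, tp_size=4)  # fine: no fused shared experts
    try:
        DeepEPMoE(num_experts=16, tp_size=4, num_fused_shared_experts=1)
    except AssertionError as e:
        print(e)  # -> num_fused_shared_experts is not supported in EP
```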