From 2ae95d17e80710d5ed1189398f36905ad43f5baa Mon Sep 17 00:00:00 2001
From: Minglei Zhu
Date: Fri, 1 Aug 2025 12:02:35 -0700
Subject: [PATCH] Disable tp for shared experts under expert parallelism for
 GLM4.5 model (#8647)

Co-authored-by: Stefan He
Co-authored-by: Cheng Wan <54331508+ch-wan@users.noreply.github.com>
---
 python/sglang/srt/models/glm4_moe.py | 78 ++++++++++++++++++++++++++--
 1 file changed, 73 insertions(+), 5 deletions(-)

diff --git a/python/sglang/srt/models/glm4_moe.py b/python/sglang/srt/models/glm4_moe.py
index ab9a83c73..badbb56ca 100644
--- a/python/sglang/srt/models/glm4_moe.py
+++ b/python/sglang/srt/models/glm4_moe.py
@@ -387,6 +387,7 @@ class Glm4MoeSparseMoeBlock(DeepseekV2MoE):
     ):
         nn.Module.__init__(self)
         self.tp_size = get_tensor_model_parallel_world_size()
+        self.ep_size = get_moe_expert_parallel_world_size()
         self.routed_scaling_factor = config.routed_scaling_factor
         self.n_shared_experts = config.n_shared_experts
         self.num_fused_shared_experts = (
@@ -480,11 +481,7 @@ class Glm4MoeSparseMoeBlock(DeepseekV2MoE):
                 quant_config=quant_config,
                 reduce_results=False,
                 prefix=add_prefix("shared_experts", prefix),
-                **(
-                    dict(tp_rank=0, tp_size=1)
-                    if global_server_args_dict["moe_a2a_backend"].is_deepep()
-                    else {}
-                ),
+                **(dict(tp_rank=0, tp_size=1) if self.ep_size > 1 else {}),
             )
             is_packed_weight = hasattr(
                 self.shared_experts.gate_up_proj.quant_method, "quant_config"
@@ -531,6 +528,77 @@ class Glm4MoeSparseMoeBlock(DeepseekV2MoE):
 
         self._enable_deepep_moe = global_server_args_dict["moe_a2a_backend"].is_deepep()
 
+    def forward_normal_dual_stream(
+        self, hidden_states: torch.Tensor, can_fuse_mlp_allreduce: bool = False
+    ) -> torch.Tensor:
+
+        current_stream = torch.cuda.current_stream()
+        self.alt_stream.wait_stream(current_stream)
+        shared_output = self._forward_shared_experts(hidden_states)
+
+        with torch.cuda.stream(self.alt_stream):
+            # router_logits: (num_tokens, n_experts)
+            router_logits = self.gate(hidden_states)
+            kwargs = {"hidden_states": hidden_states}
+            if self.topk is not None:
+                kwargs["topk_output"] = self.topk(hidden_states, router_logits)
+            else:
+                kwargs["router_logits"] = router_logits
+            final_hidden_states = self.experts(**kwargs)
+            if not _is_cuda:
+                final_hidden_states *= self.routed_scaling_factor
+        current_stream.wait_stream(self.alt_stream)
+
+        if self.ep_size > 1:
+            if self.tp_size > 1 and not can_fuse_mlp_allreduce:
+                final_hidden_states = tensor_model_parallel_all_reduce(
+                    final_hidden_states
+                )
+            final_hidden_states += shared_output
+        else:
+            final_hidden_states += shared_output
+            if self.tp_size > 1 and not can_fuse_mlp_allreduce:
+                final_hidden_states = tensor_model_parallel_all_reduce(
+                    final_hidden_states
+                )
+        return final_hidden_states
+
+    def forward_normal(
+        self, hidden_states: torch.Tensor, can_fuse_mlp_allreduce: bool = False
+    ) -> torch.Tensor:
+        if hasattr(self, "shared_experts") and use_intel_amx_backend(
+            self.shared_experts.gate_up_proj
+        ):
+            return self.forward_cpu(hidden_states, can_fuse_mlp_allreduce)
+
+        shared_output = self._forward_shared_experts(hidden_states)
+        # router_logits: (num_tokens, n_experts)
+        router_logits = self.gate(hidden_states)
+        kwargs = {"hidden_states": hidden_states}
+        if self.topk is not None:
+            kwargs["topk_output"] = self.topk(hidden_states, router_logits)
+        else:
+            kwargs["router_logits"] = router_logits
+        final_hidden_states = self.experts(**kwargs)
+        if not _is_cuda and not _use_aiter:
+            # fused in biased_grouped_topk so we can skip here
+            final_hidden_states *= self.routed_scaling_factor
+        if self.ep_size > 1:
+            if self.tp_size > 1 and not can_fuse_mlp_allreduce:
+                final_hidden_states = tensor_model_parallel_all_reduce(
+                    final_hidden_states
+                )
+            if shared_output is not None:
+                final_hidden_states += shared_output
+        else:
+            if shared_output is not None:
+                final_hidden_states += shared_output
+            if self.tp_size > 1 and not can_fuse_mlp_allreduce:
+                final_hidden_states = tensor_model_parallel_all_reduce(
+                    final_hidden_states
+                )
+        return final_hidden_states
+
 
 class Glm4MoeDecoderLayer(DeepseekV2DecoderLayer):
     def __init__(
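
Why both new forward paths branch on self.ep_size > 1: with expert parallelism the shared experts are built with tp_rank=0, tp_size=1, so each rank computes the full shared-expert output instead of a TP partial sum. The routed partials must then be all-reduced before the shared output is added once; in the pure-TP path both tensors are partials, so they are added first and all-reduced together. Below is a minimal sketch of that arithmetic. The toy all_reduce helper and the two-rank lists are illustrative assumptions, not sglang's tensor_model_parallel_all_reduce or its process groups, and equating the unsharded shared output to the sum of TP partials is the identity the sketch relies on.

import torch

def all_reduce(partials):
    # Toy stand-in for tensor_model_parallel_all_reduce: sum across "ranks"
    # and hand every rank a copy of the total.
    total = torch.stack(partials).sum(dim=0)
    return [total.clone() for _ in partials]

tp_size = 2
routed_partials = [torch.randn(4) for _ in range(tp_size)]
shared_partials = [torch.randn(4) for _ in range(tp_size)]

# Pure-TP path (ep_size == 1): the shared experts are sharded too, so each
# rank holds only a partial shared output; add locally, then all-reduce once.
tp_out = all_reduce([r + s for r, s in zip(routed_partials, shared_partials)])

# EP path (this patch, ep_size > 1): shared experts run with tp_rank=0,
# tp_size=1, so every rank already holds the full shared output; all-reduce
# the routed partials first, then add the shared output exactly once.
shared_full = sum(shared_partials)  # what the unsharded shared expert computes
ep_out = [r + shared_full for r in all_reduce(routed_partials)]

assert all(torch.allclose(a, b) for a, b in zip(tp_out, ep_out))

Both orderings yield the same result up to floating-point rounding, which is why the two branches can coexist without changing numerics.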