From 5b6acc1495f4c4d44bfdb0ce8090426de280b002 Mon Sep 17 00:00:00 2001 From: Cheng Wan <54331508+ch-wan@users.noreply.github.com> Date: Wed, 6 Aug 2025 18:02:31 -0700 Subject: [PATCH] fix glm4 moe (#8883) --- python/sglang/srt/models/glm4_moe.py | 23 +++++++++++++++++++---- 1 file changed, 19 insertions(+), 4 deletions(-) diff --git a/python/sglang/srt/models/glm4_moe.py b/python/sglang/srt/models/glm4_moe.py index 4744c0c31..67ef6ca79 100644 --- a/python/sglang/srt/models/glm4_moe.py +++ b/python/sglang/srt/models/glm4_moe.py @@ -527,7 +527,10 @@ class Glm4MoeSparseMoeBlock(DeepseekV2MoE): self._enable_deepep_moe = global_server_args_dict["moe_a2a_backend"].is_deepep() def forward_normal_dual_stream( - self, hidden_states: torch.Tensor, can_fuse_mlp_allreduce: bool = False + self, + hidden_states: torch.Tensor, + can_fuse_mlp_allreduce: bool = False, + use_reduce_scatter: bool = False, ) -> torch.Tensor: current_stream = torch.cuda.current_stream() @@ -548,21 +551,32 @@ class Glm4MoeSparseMoeBlock(DeepseekV2MoE): current_stream.wait_stream(self.alt_stream) if self.ep_size > 1: - if self.tp_size > 1 and not can_fuse_mlp_allreduce: + if ( + self.tp_size > 1 + and not can_fuse_mlp_allreduce + and not use_reduce_scatter + ): final_hidden_states = tensor_model_parallel_all_reduce( final_hidden_states ) final_hidden_states += shared_output else: final_hidden_states += shared_output - if self.tp_size > 1 and not can_fuse_mlp_allreduce: + if ( + self.tp_size > 1 + and not can_fuse_mlp_allreduce + and not use_reduce_scatter + ): final_hidden_states = tensor_model_parallel_all_reduce( final_hidden_states ) return final_hidden_states def forward_normal( - self, hidden_states: torch.Tensor, can_fuse_mlp_allreduce: bool = False + self, + hidden_states: torch.Tensor, + can_fuse_mlp_allreduce: bool = False, + use_reduce_scatter: bool = False, ) -> torch.Tensor: if hasattr(self, "shared_experts") and use_intel_amx_backend( self.shared_experts.gate_up_proj @@ -681,6 +695,7 @@ class Glm4MoeDecoderLayer(DeepseekV2DecoderLayer): layer_scatter_modes=self.layer_scatter_modes, input_layernorm=self.input_layernorm, post_attention_layernorm=self.post_attention_layernorm, + allow_reduce_scatter=True, ) def forward(