diff --git a/python/sglang/srt/models/glm4_moe.py b/python/sglang/srt/models/glm4_moe.py index 4744c0c31..67ef6ca79 100644 --- a/python/sglang/srt/models/glm4_moe.py +++ b/python/sglang/srt/models/glm4_moe.py @@ -527,7 +527,10 @@ class Glm4MoeSparseMoeBlock(DeepseekV2MoE): self._enable_deepep_moe = global_server_args_dict["moe_a2a_backend"].is_deepep() def forward_normal_dual_stream( - self, hidden_states: torch.Tensor, can_fuse_mlp_allreduce: bool = False + self, + hidden_states: torch.Tensor, + can_fuse_mlp_allreduce: bool = False, + use_reduce_scatter: bool = False, ) -> torch.Tensor: current_stream = torch.cuda.current_stream() @@ -548,21 +551,32 @@ class Glm4MoeSparseMoeBlock(DeepseekV2MoE): current_stream.wait_stream(self.alt_stream) if self.ep_size > 1: - if self.tp_size > 1 and not can_fuse_mlp_allreduce: + if ( + self.tp_size > 1 + and not can_fuse_mlp_allreduce + and not use_reduce_scatter + ): final_hidden_states = tensor_model_parallel_all_reduce( final_hidden_states ) final_hidden_states += shared_output else: final_hidden_states += shared_output - if self.tp_size > 1 and not can_fuse_mlp_allreduce: + if ( + self.tp_size > 1 + and not can_fuse_mlp_allreduce + and not use_reduce_scatter + ): final_hidden_states = tensor_model_parallel_all_reduce( final_hidden_states ) return final_hidden_states def forward_normal( - self, hidden_states: torch.Tensor, can_fuse_mlp_allreduce: bool = False + self, + hidden_states: torch.Tensor, + can_fuse_mlp_allreduce: bool = False, + use_reduce_scatter: bool = False, ) -> torch.Tensor: if hasattr(self, "shared_experts") and use_intel_amx_backend( self.shared_experts.gate_up_proj @@ -681,6 +695,7 @@ class Glm4MoeDecoderLayer(DeepseekV2DecoderLayer): layer_scatter_modes=self.layer_scatter_modes, input_layernorm=self.input_layernorm, post_attention_layernorm=self.post_attention_layernorm, + allow_reduce_scatter=True, ) def forward(