From 24dc2bee97423356a539877650db7ca92f87ec2c Mon Sep 17 00:00:00 2001 From: Yuan Luo Date: Fri, 12 Sep 2025 15:36:02 +0800 Subject: [PATCH] Fix Bailing MoE model bugs (#10362) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: luoyuan.luo Co-authored-by: 羽癫 --- python/sglang/srt/models/bailing_moe.py | 11 +++++++---- python/sglang/srt/server_args.py | 2 +- 2 files changed, 8 insertions(+), 5 deletions(-) diff --git a/python/sglang/srt/models/bailing_moe.py b/python/sglang/srt/models/bailing_moe.py index b16d22320..5eb3e9373 100644 --- a/python/sglang/srt/models/bailing_moe.py +++ b/python/sglang/srt/models/bailing_moe.py @@ -128,7 +128,9 @@ class BailingMoEMLP(nn.Module): gate_up, _ = self.gate_up_proj(hidden_states) hidden_states = self.act_fn(gate_up) - hidden_states, _ = self.down_proj(hidden_states) + hidden_states, _ = self.down_proj( + hidden_states, skip_all_reduce=use_reduce_scatter + ) return hidden_states @@ -328,7 +330,7 @@ class BailingMoESparseMoeBlock(nn.Module): ) -> torch.Tensor: current_stream = torch.cuda.current_stream() self.alt_stream.wait_stream(current_stream) - shared_output = self._forward_shared_experts(hidden_states) + shared_output = self._forward_shared_experts(hidden_states.clone()) with torch.cuda.stream(self.alt_stream): router_output = self._forward_router_experts(hidden_states) @@ -347,8 +349,9 @@ class BailingMoESparseMoeBlock(nn.Module): DUAL_STREAM_TOKEN_THRESHOLD = 1024 if ( self.alt_stream is not None - and num_tokens > 0 - and num_tokens <= DUAL_STREAM_TOKEN_THRESHOLD + and hidden_states.shape[0] > 0 + and hidden_states.shape[0] <= DUAL_STREAM_TOKEN_THRESHOLD + and get_is_capture_mode() ): final_hidden_states, shared_output = self.forward_normal_dual_stream( hidden_states diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index 88fe9d6fe..77199e7d3 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -757,7 +757,7 @@ class ServerArgs: if model_arch in [ "DeepseekV3ForCausalLM", "Glm4MoeForCausalLM", - "BailingMoeV2ForCausalLM", + "BailingMoeForCausalLM", "BailingMoeV2ForCausalLM", ]: # Auto set draft_model_path DeepSeek-V3/R1