diff --git a/python/sglang/srt/models/bailing_moe.py b/python/sglang/srt/models/bailing_moe.py index b16d22320..5eb3e9373 100644 --- a/python/sglang/srt/models/bailing_moe.py +++ b/python/sglang/srt/models/bailing_moe.py @@ -128,7 +128,9 @@ class BailingMoEMLP(nn.Module): gate_up, _ = self.gate_up_proj(hidden_states) hidden_states = self.act_fn(gate_up) - hidden_states, _ = self.down_proj(hidden_states) + hidden_states, _ = self.down_proj( + hidden_states, skip_all_reduce=use_reduce_scatter + ) return hidden_states @@ -328,7 +330,7 @@ class BailingMoESparseMoeBlock(nn.Module): ) -> torch.Tensor: current_stream = torch.cuda.current_stream() self.alt_stream.wait_stream(current_stream) - shared_output = self._forward_shared_experts(hidden_states) + shared_output = self._forward_shared_experts(hidden_states.clone()) with torch.cuda.stream(self.alt_stream): router_output = self._forward_router_experts(hidden_states) @@ -347,8 +349,9 @@ class BailingMoESparseMoeBlock(nn.Module): DUAL_STREAM_TOKEN_THRESHOLD = 1024 if ( self.alt_stream is not None - and num_tokens > 0 - and num_tokens <= DUAL_STREAM_TOKEN_THRESHOLD + and hidden_states.shape[0] > 0 + and hidden_states.shape[0] <= DUAL_STREAM_TOKEN_THRESHOLD + and get_is_capture_mode() ): final_hidden_states, shared_output = self.forward_normal_dual_stream( hidden_states diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index 88fe9d6fe..77199e7d3 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -757,7 +757,7 @@ class ServerArgs: if model_arch in [ "DeepseekV3ForCausalLM", "Glm4MoeForCausalLM", - "BailingMoeV2ForCausalLM", + "BailingMoeForCausalLM", "BailingMoeV2ForCausalLM", ]: # Auto set draft_model_path DeepSeek-V3/R1