Fix Bailing MoE model bugs (#10362)
Co-authored-by: luoyuan.luo <luoyuan.luo@antgroup.com>
Co-authored-by: 羽癫 <yudian.zy@antgroup.com>
This commit is contained in:
@@ -128,7 +128,9 @@ class BailingMoEMLP(nn.Module):
|
|||||||
|
|
||||||
gate_up, _ = self.gate_up_proj(hidden_states)
|
gate_up, _ = self.gate_up_proj(hidden_states)
|
||||||
hidden_states = self.act_fn(gate_up)
|
hidden_states = self.act_fn(gate_up)
|
||||||
hidden_states, _ = self.down_proj(hidden_states)
|
hidden_states, _ = self.down_proj(
|
||||||
|
hidden_states, skip_all_reduce=use_reduce_scatter
|
||||||
|
)
|
||||||
return hidden_states
|
return hidden_states
|
||||||
|
|
||||||
|
|
||||||
@@ -328,7 +330,7 @@ class BailingMoESparseMoeBlock(nn.Module):
|
|||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
current_stream = torch.cuda.current_stream()
|
current_stream = torch.cuda.current_stream()
|
||||||
self.alt_stream.wait_stream(current_stream)
|
self.alt_stream.wait_stream(current_stream)
|
||||||
shared_output = self._forward_shared_experts(hidden_states)
|
shared_output = self._forward_shared_experts(hidden_states.clone())
|
||||||
|
|
||||||
with torch.cuda.stream(self.alt_stream):
|
with torch.cuda.stream(self.alt_stream):
|
||||||
router_output = self._forward_router_experts(hidden_states)
|
router_output = self._forward_router_experts(hidden_states)
|
||||||
@@ -347,8 +349,9 @@ class BailingMoESparseMoeBlock(nn.Module):
|
|||||||
DUAL_STREAM_TOKEN_THRESHOLD = 1024
|
DUAL_STREAM_TOKEN_THRESHOLD = 1024
|
||||||
if (
|
if (
|
||||||
self.alt_stream is not None
|
self.alt_stream is not None
|
||||||
and num_tokens > 0
|
and hidden_states.shape[0] > 0
|
||||||
and num_tokens <= DUAL_STREAM_TOKEN_THRESHOLD
|
and hidden_states.shape[0] <= DUAL_STREAM_TOKEN_THRESHOLD
|
||||||
|
and get_is_capture_mode()
|
||||||
):
|
):
|
||||||
final_hidden_states, shared_output = self.forward_normal_dual_stream(
|
final_hidden_states, shared_output = self.forward_normal_dual_stream(
|
||||||
hidden_states
|
hidden_states
|
||||||
|
|||||||
@@ -757,7 +757,7 @@ class ServerArgs:
|
|||||||
if model_arch in [
|
if model_arch in [
|
||||||
"DeepseekV3ForCausalLM",
|
"DeepseekV3ForCausalLM",
|
||||||
"Glm4MoeForCausalLM",
|
"Glm4MoeForCausalLM",
|
||||||
"BailingMoeV2ForCausalLM",
|
"BailingMoeForCausalLM",
|
||||||
"BailingMoeV2ForCausalLM",
|
"BailingMoeV2ForCausalLM",
|
||||||
]:
|
]:
|
||||||
# Auto set draft_model_path DeepSeek-V3/R1
|
# Auto set draft_model_path DeepSeek-V3/R1
|
||||||
|
|||||||
Reference in New Issue
Block a user