fix glm4 moe (#8883)
This commit is contained in:
@@ -527,7 +527,10 @@ class Glm4MoeSparseMoeBlock(DeepseekV2MoE):
|
||||
self._enable_deepep_moe = global_server_args_dict["moe_a2a_backend"].is_deepep()
|
||||
|
||||
def forward_normal_dual_stream(
|
||||
self, hidden_states: torch.Tensor, can_fuse_mlp_allreduce: bool = False
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
can_fuse_mlp_allreduce: bool = False,
|
||||
use_reduce_scatter: bool = False,
|
||||
) -> torch.Tensor:
|
||||
|
||||
current_stream = torch.cuda.current_stream()
|
||||
@@ -548,21 +551,32 @@ class Glm4MoeSparseMoeBlock(DeepseekV2MoE):
|
||||
current_stream.wait_stream(self.alt_stream)
|
||||
|
||||
if self.ep_size > 1:
|
||||
if self.tp_size > 1 and not can_fuse_mlp_allreduce:
|
||||
if (
|
||||
self.tp_size > 1
|
||||
and not can_fuse_mlp_allreduce
|
||||
and not use_reduce_scatter
|
||||
):
|
||||
final_hidden_states = tensor_model_parallel_all_reduce(
|
||||
final_hidden_states
|
||||
)
|
||||
final_hidden_states += shared_output
|
||||
else:
|
||||
final_hidden_states += shared_output
|
||||
if self.tp_size > 1 and not can_fuse_mlp_allreduce:
|
||||
if (
|
||||
self.tp_size > 1
|
||||
and not can_fuse_mlp_allreduce
|
||||
and not use_reduce_scatter
|
||||
):
|
||||
final_hidden_states = tensor_model_parallel_all_reduce(
|
||||
final_hidden_states
|
||||
)
|
||||
return final_hidden_states
|
||||
|
||||
def forward_normal(
|
||||
self, hidden_states: torch.Tensor, can_fuse_mlp_allreduce: bool = False
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
can_fuse_mlp_allreduce: bool = False,
|
||||
use_reduce_scatter: bool = False,
|
||||
) -> torch.Tensor:
|
||||
if hasattr(self, "shared_experts") and use_intel_amx_backend(
|
||||
self.shared_experts.gate_up_proj
|
||||
@@ -681,6 +695,7 @@ class Glm4MoeDecoderLayer(DeepseekV2DecoderLayer):
|
||||
layer_scatter_modes=self.layer_scatter_modes,
|
||||
input_layernorm=self.input_layernorm,
|
||||
post_attention_layernorm=self.post_attention_layernorm,
|
||||
allow_reduce_scatter=True,
|
||||
)
|
||||
|
||||
def forward(
|
||||
|
||||
Reference in New Issue
Block a user