fix glm4 moe (#8883)
@@ -527,7 +527,10 @@ class Glm4MoeSparseMoeBlock(DeepseekV2MoE):
         self._enable_deepep_moe = global_server_args_dict["moe_a2a_backend"].is_deepep()
 
     def forward_normal_dual_stream(
-        self, hidden_states: torch.Tensor, can_fuse_mlp_allreduce: bool = False
+        self,
+        hidden_states: torch.Tensor,
+        can_fuse_mlp_allreduce: bool = False,
+        use_reduce_scatter: bool = False,
     ) -> torch.Tensor:
 
         current_stream = torch.cuda.current_stream()
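
For context, a minimal sketch (not part of the patch; the function below is a stand-in, not the real Glm4MoeSparseMoeBlock) of why this signature change is backward compatible: the new use_reduce_scatter argument defaults to False, so call sites that predate this fix keep the existing all-reduce behaviour unless they opt in.

def forward_normal_dual_stream(hidden_states, can_fuse_mlp_allreduce=False, use_reduce_scatter=False):
    # Stand-in that only reports which flags the caller passed.
    return {"fuse": can_fuse_mlp_allreduce, "reduce_scatter": use_reduce_scatter}

assert forward_normal_dual_stream("h") == {"fuse": False, "reduce_scatter": False}  # old callers unchanged
assert forward_normal_dual_stream("h", use_reduce_scatter=True)["reduce_scatter"]   # new path is opt-in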
@@ -548,21 +551,32 @@ class Glm4MoeSparseMoeBlock(DeepseekV2MoE):
         current_stream.wait_stream(self.alt_stream)
 
         if self.ep_size > 1:
-            if self.tp_size > 1 and not can_fuse_mlp_allreduce:
+            if (
+                self.tp_size > 1
+                and not can_fuse_mlp_allreduce
+                and not use_reduce_scatter
+            ):
                 final_hidden_states = tensor_model_parallel_all_reduce(
                     final_hidden_states
                 )
             final_hidden_states += shared_output
         else:
             final_hidden_states += shared_output
-            if self.tp_size > 1 and not can_fuse_mlp_allreduce:
+            if (
+                self.tp_size > 1
+                and not can_fuse_mlp_allreduce
+                and not use_reduce_scatter
+            ):
                 final_hidden_states = tensor_model_parallel_all_reduce(
                     final_hidden_states
                 )
         return final_hidden_states
 
     def forward_normal(
-        self, hidden_states: torch.Tensor, can_fuse_mlp_allreduce: bool = False
+        self,
+        hidden_states: torch.Tensor,
+        can_fuse_mlp_allreduce: bool = False,
+        use_reduce_scatter: bool = False,
     ) -> torch.Tensor:
         if hasattr(self, "shared_experts") and use_intel_amx_backend(
             self.shared_experts.gate_up_proj
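
A standalone restatement of the new guard (a readability sketch, not the actual module code): after this hunk, the MoE block issues its own tensor-parallel all-reduce only when TP is active and neither the fused MLP all-reduce path nor the downstream reduce-scatter path will perform the reduction instead.

def needs_moe_allreduce(tp_size: int, can_fuse_mlp_allreduce: bool, use_reduce_scatter: bool) -> bool:
    # Mirrors the condition added in this hunk.
    return tp_size > 1 and not can_fuse_mlp_allreduce and not use_reduce_scatter

assert needs_moe_allreduce(4, False, False)        # plain TP: reduce here
assert not needs_moe_allreduce(4, False, True)     # reduce-scatter will reduce later
assert not needs_moe_allreduce(4, True, False)     # fused MLP all-reduce handles it
assert not needs_moe_allreduce(1, False, False)    # no TP, nothing to reduce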
@@ -681,6 +695,7 @@ class Glm4MoeDecoderLayer(DeepseekV2DecoderLayer):
             layer_scatter_modes=self.layer_scatter_modes,
             input_layernorm=self.input_layernorm,
             post_attention_layernorm=self.post_attention_layernorm,
+            allow_reduce_scatter=True,
         )
 
     def forward(
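
Passing allow_reduce_scatter=True lets the decoder layer's communicator use a reduce-scatter on the MLP output instead of relying on the MoE block's full all-reduce. The non-distributed sketch below (illustrative only; names, shapes, and the simulated collectives are assumptions, not from the patch) shows the relationship between the two collectives that makes it safe for the block to skip its all-reduce and leave the reduction to the communicator.

import torch

tp_size, tokens, hidden = 4, 8, 16
# Each TP rank holds a partial sum of the MoE output for every token.
partials = [torch.randn(tokens, hidden) for _ in range(tp_size)]

all_reduce_result = sum(partials)  # what a tensor-parallel all-reduce would produce on every rank

# Reduce-scatter: each rank ends up with the fully reduced values for its own token slice.
per_rank_chunks = [p.chunk(tp_size, dim=0) for p in partials]
reduce_scatter_results = [
    sum(per_rank_chunks[rank][slot] for rank in range(tp_size)) for slot in range(tp_size)
]

# Concatenating the per-rank slices reproduces the all-reduce result, so reducing
# in the MoE block and again in the communicator would count the partial sums twice.
assert torch.allclose(all_reduce_result, torch.cat(reduce_scatter_results, dim=0))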