[CPU] fix no attribute 'can_fuse_mlp_allreduce' error (#8010)
This commit is contained in:
@@ -462,7 +462,7 @@ class DeepseekV2MoE(nn.Module):
|
|||||||
if hasattr(self, "shared_experts") and use_intel_amx_backend(
|
if hasattr(self, "shared_experts") and use_intel_amx_backend(
|
||||||
self.shared_experts.gate_up_proj
|
self.shared_experts.gate_up_proj
|
||||||
):
|
):
|
||||||
return self.forward_cpu(hidden_states)
|
return self.forward_cpu(hidden_states, can_fuse_mlp_allreduce)
|
||||||
|
|
||||||
shared_output = self._forward_shared_experts(hidden_states)
|
shared_output = self._forward_shared_experts(hidden_states)
|
||||||
# router_logits: (num_tokens, n_experts)
|
# router_logits: (num_tokens, n_experts)
|
||||||
@@ -479,7 +479,9 @@ class DeepseekV2MoE(nn.Module):
|
|||||||
final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
|
final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
|
||||||
return final_hidden_states
|
return final_hidden_states
|
||||||
|
|
||||||
def forward_cpu(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
def forward_cpu(
|
||||||
|
self, hidden_states: torch.Tensor, can_fuse_mlp_allreduce: bool = False
|
||||||
|
) -> torch.Tensor:
|
||||||
# router_logits: (num_tokens, n_experts)
|
# router_logits: (num_tokens, n_experts)
|
||||||
router_logits = self.gate(hidden_states)
|
router_logits = self.gate(hidden_states)
|
||||||
fused_experts_out = self.experts(
|
fused_experts_out = self.experts(
|
||||||
@@ -528,7 +530,7 @@ class DeepseekV2MoE(nn.Module):
|
|||||||
None, # a2_scale
|
None, # a2_scale
|
||||||
True, # is_vnni
|
True, # is_vnni
|
||||||
)
|
)
|
||||||
if self.tp_size > 1 and not self.can_fuse_mlp_allreduce:
|
if self.tp_size > 1 and not can_fuse_mlp_allreduce:
|
||||||
final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
|
final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
|
||||||
return final_hidden_states
|
return final_hidden_states
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user