[BugFix]Fix group list type of mc2. (#3890)
### What this PR does / why we need it? Fix the precision issue caused by the inconsistency between the group list type used by mc2 and that of eplb. --------- Signed-off-by: offline0806 <3337230449@qq.com>
This commit is contained in:
@@ -244,7 +244,8 @@ class AscendFusedMoE(FusedMoE):
|
||||
self.expert_map != -1) if self.expert_map is not None else
|
||||
self.global_num_experts)
|
||||
if self.dynamic_eplb:
|
||||
self.moe_load = torch.zeros(local_num_experts, dtype=torch.int64)
|
||||
self.moe_load = torch.zeros(local_num_experts,
|
||||
dtype=torch.int64).npu()
|
||||
|
||||
self.moe_config.num_experts = self.global_num_experts
|
||||
self.moe_config.num_local_experts = self.local_num_experts
|
||||
@@ -340,9 +341,9 @@ class AscendFusedMoE(FusedMoE):
|
||||
if isinstance(final_hidden_states, tuple):
|
||||
final_hidden_states, group_list_type, expert_tokens = final_hidden_states
|
||||
|
||||
if self.dynamic_eplb:
|
||||
self.moe_load += expert_tokens if group_list_type else \
|
||||
torch.cat([expert_tokens[:1], expert_tokens[1:] - expert_tokens[:-1]])
|
||||
if self.dynamic_eplb:
|
||||
self.moe_load += expert_tokens if group_list_type == 1 else \
|
||||
torch.cat([expert_tokens[:1], expert_tokens[1:] - expert_tokens[:-1]])
|
||||
|
||||
final_hidden_states = forward_context.moe_comm_method.finalize(
|
||||
hidden_states=final_hidden_states,
|
||||
|
||||
Reference in New Issue
Block a user