[V0.18.0][EPLB][BugFix] Fix moe_load precision in allgather (#7890)

### What this PR does / why we need it?
Fixed the bug of incorrect reshape usage.
For example:
ori_tensor: [[1, 2, 3], [4, 5, 6]]
after reshape:
[[1, 2], [3, 4], [5, 6]]
after permute:
[[1, 4], [2, 5], [3, 6]]
Now, we will directly use squeeze for a more intuitive understanding.
pr for main:
#7887 

### Does this PR introduce _any_ user-facing change?
The actual peak-to-average ratio has successfully decreased.

Signed-off-by: shenchuxiaofugui <1311027364@qq.com>
This commit is contained in:
LI SHENGYONG
2026-04-02 09:20:31 +08:00
committed by GitHub
parent 99e1ea0fe6
commit 4b2f0130bc
2 changed files with 83 additions and 4 deletions

View File

@@ -131,10 +131,8 @@ class EplbUpdator:
self.update_iteration()
def compute_and_set_moe_load(self):
local_load = self.adaptor.get_rank_expert_workload()
moe_load = (
self.comm_group.all_gather(local_load, dim=0).reshape(-1, self.world_size, *local_load.shape[1:]).cpu()
)
local_load = self.adaptor.get_rank_expert_workload().unsqueeze(1)
moe_load = self.comm_group.all_gather(local_load, dim=1).cpu()
if self.multi_stage:
moe_load = moe_load.permute(2, 0, 1, 3)