From 4b2f0130bc08d71f8f406583a3bb674331a09631 Mon Sep 17 00:00:00 2001 From: LI SHENGYONG <49200266+shenchuxiaofugui@users.noreply.github.com> Date: Thu, 2 Apr 2026 09:20:31 +0800 Subject: [PATCH] [V0.18.0][EPLB][BugFix] Fix moe_load precision in allgather (#7890) ### What this PR does / why we need it? Fixed the bug of incorrect reshape usage. For example: ori_tensor: [[1, 2, 3], [4, 5, 6]] after reshape: [[1, 2], [3, 4], [5, 6]] after permute: [[1, 4], [2, 5], [3, 6]] Now, we will directly use unsqueeze for a more intuitive understanding. pr for main: #7887 ### Does this PR introduce _any_ user-facing change? The actual peak-to-average ratio has successfully decreased. Signed-off-by: shenchuxiaofugui <1311027364@qq.com> --- tests/ut/eplb/test_eplb_updator.py | 81 ++++++++++++++++++++++++++++++ vllm_ascend/eplb/eplb_updator.py | 6 +-- 2 files changed, 83 insertions(+), 4 deletions(-) create mode 100644 tests/ut/eplb/test_eplb_updator.py diff --git a/tests/ut/eplb/test_eplb_updator.py b/tests/ut/eplb/test_eplb_updator.py new file mode 100644 index 00000000..d45ad2b0 --- /dev/null +++ b/tests/ut/eplb/test_eplb_updator.py @@ -0,0 +1,81 @@ +import unittest +from unittest.mock import MagicMock, patch +import torch +from vllm_ascend.eplb.eplb_updator import EplbUpdator + + +class TestEplbUpdatorComputeAndSetMoeLoad(unittest.TestCase): + def setUp(self): + + # ====================== 1. Mock environment ====================== + self.rank = 0 + self.world_size = 4 + self.device = torch.device("cpu") + + # mock dist + p1 = patch("torch.distributed.get_rank", return_value=self.rank) + p2 = patch("torch.distributed.get_world_size", return_value=self.world_size) + self.addCleanup(p1.stop) + self.addCleanup(p2.stop) + p1.start() + p2.start() + + # ====================== 2. 
Mock comm group ====================== + self.mock_comm_group = MagicMock() + + def mock_all_gather(tensor, dim): + gathered = torch.cat([tensor for _ in range(self.world_size)], dim=dim) + return gathered + + self.mock_comm_group.all_gather = mock_all_gather + + p3 = patch("vllm_ascend.eplb.eplb_updator.get_dynamic_eplb_group", + return_value=self.mock_comm_group) + self.addCleanup(p3.stop) + p3.start() + + # ====================== 3. Mock EplbUpdator ====================== + self.eplb_config = MagicMock() + self.loader = MagicMock() + self.eplb_process = MagicMock() + self.process = MagicMock() + self.eplb_process.shared_dict = {} + + self.updator = EplbUpdator( + eplb_config=self.eplb_config, + loader=self.loader, + eplb_process=self.eplb_process, + process=self.process + ) + + # ====================== 4. Mock adaptor ====================== + self.adaptor = MagicMock() + self.adaptor.num_moe_layers = 4 + self.adaptor.num_dense_layers = 2 + self.mock_local_load = torch.randn(58, 100, 8, device=self.device) + self.adaptor.get_rank_expert_workload.return_value = self.mock_local_load + + self.updator.set_adaptor(self.adaptor) + + def test_compute_and_set_moe_load_normal(self): + self.updator.multi_stage = False + + moe_load = self.updator.compute_and_set_moe_load() + + self.assertEqual(moe_load.shape, (58, self.world_size, 100, 8)) + self.assertTrue("moe_load" in self.updator.shared_dict) + self.assertEqual(moe_load.device.type, "cpu") + self.assertEqual(moe_load.shape[1], self.world_size) + + def test_compute_and_set_moe_load_multi_stage(self): + self.updator.multi_stage = True + + moe_load = self.updator.compute_and_set_moe_load() + + self.assertEqual(moe_load.shape, (100, 58, self.world_size, 8)) + self.assertTrue("moe_load" in self.updator.shared_dict) + self.assertEqual(moe_load.device.type, "cpu") + + +if __name__ == '__main__': + unittest.main() \ No newline at end of file diff --git a/vllm_ascend/eplb/eplb_updator.py b/vllm_ascend/eplb/eplb_updator.py index 
285bd486..c3e1f32f 100644 --- a/vllm_ascend/eplb/eplb_updator.py +++ b/vllm_ascend/eplb/eplb_updator.py @@ -131,10 +131,8 @@ class EplbUpdator: self.update_iteration() def compute_and_set_moe_load(self): - local_load = self.adaptor.get_rank_expert_workload() - moe_load = ( - self.comm_group.all_gather(local_load, dim=0).reshape(-1, self.world_size, *local_load.shape[1:]).cpu() - ) + local_load = self.adaptor.get_rank_expert_workload().unsqueeze(1) + moe_load = self.comm_group.all_gather(local_load, dim=1).cpu() if self.multi_stage: moe_load = moe_load.permute(2, 0, 1, 3)