[V0.18.0][EPLB][BugFix] Fix moe_load precision in allgather (#7890)
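### What this PR does / why we need it? Fixes a bug caused by incorrect `reshape` usage when building `moe_load` from the all-gathered expert workload: `reshape` only re-chunks the elements in their existing order, it does not reorder them the way `permute` does. For example: ori_tensor: [[1, 2, 3], [4, 5, 6]] after reshape: [[1, 2], [3, 4], [5, 6]] after permute: [[1, 4], [2, 5], [3, 6]] Now we `unsqueeze` the local load and all-gather directly along the new dimension, which is more intuitive to understand. PR for main: #7887 ### Does this PR introduce _any_ user-facing change? The actual peak-to-average ratio has successfully decreased. Signed-off-by: shenchuxiaofugui <1311027364@qq.com>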
This commit is contained in:
81
tests/ut/eplb/test_eplb_updator.py
Normal file
81
tests/ut/eplb/test_eplb_updator.py
Normal file
@@ -0,0 +1,81 @@
|
||||
import unittest
|
||||
from unittest.mock import MagicMock, patch
|
||||
import torch
|
||||
from vllm_ascend.eplb.eplb_updator import EplbUpdator
|
||||
|
||||
|
||||
class TestEplbUpdatorComputeAndSetMoeLoad(unittest.TestCase):
    """Unit tests for ``EplbUpdator.compute_and_set_moe_load``.

    The tests check the *values* of the gathered load, not only its shape.
    An implementation that all-gathers along dim 0 and then reshapes to
    ``(-1, world_size, ...)`` yields exactly the same output shape as the
    ``unsqueeze(1)`` + all-gather-along-dim-1 implementation, so a
    shape-only assertion cannot tell a correct layout from a scrambled one.
    """

    def setUp(self):
        # ====================== 1. Mock environment ======================
        self.rank = 0
        self.world_size = 4
        self.device = torch.device("cpu")

        # ====================== 2. Mock comm group ======================
        self.mock_comm_group = MagicMock()

        def mock_all_gather(tensor, dim):
            # Simulate every rank contributing a *different* tensor (local
            # load + rank index) so that a wrong rank/layer layout in the
            # gathered result is observable in the values.
            per_rank = [tensor + r for r in range(self.world_size)]
            return torch.cat(per_rank, dim=dim)

        self.mock_comm_group.all_gather = mock_all_gather

        # Patch torch.distributed and the EPLB comm-group lookup; every
        # patch is undone automatically when the test finishes.
        patchers = (
            patch("torch.distributed.get_rank", return_value=self.rank),
            patch("torch.distributed.get_world_size",
                  return_value=self.world_size),
            patch("vllm_ascend.eplb.eplb_updator.get_dynamic_eplb_group",
                  return_value=self.mock_comm_group),
        )
        for patcher in patchers:
            self.addCleanup(patcher.stop)
            patcher.start()

        # ====================== 3. Mock EplbUpdator ======================
        self.eplb_config = MagicMock()
        self.loader = MagicMock()
        self.eplb_process = MagicMock()
        self.process = MagicMock()
        self.eplb_process.shared_dict = {}

        self.updator = EplbUpdator(
            eplb_config=self.eplb_config,
            loader=self.loader,
            eplb_process=self.eplb_process,
            process=self.process,
        )

        # ====================== 4. Mock adaptor ======================
        self.adaptor = MagicMock()
        self.adaptor.num_moe_layers = 4
        self.adaptor.num_dense_layers = 2
        self.mock_local_load = torch.randn(58, 100, 8, device=self.device)
        self.adaptor.get_rank_expert_workload.return_value = self.mock_local_load

        self.updator.set_adaptor(self.adaptor)

    def _expected_rank_load(self, rank):
        """Return the load the mocked all-gather attributes to ``rank``."""
        return self.mock_local_load + rank

    def test_compute_and_set_moe_load_normal(self):
        self.updator.multi_stage = False

        moe_load = self.updator.compute_and_set_moe_load()

        self.assertEqual(moe_load.shape, (58, self.world_size, 100, 8))
        self.assertIn("moe_load", self.updator.shared_dict)
        self.assertEqual(moe_load.device.type, "cpu")

        # Slot ``r`` of dim 1 must hold exactly rank ``r``'s local load.
        for r in range(self.world_size):
            self.assertTrue(
                torch.equal(moe_load[:, r], self._expected_rank_load(r)),
                msg=f"gathered load for rank {r} does not match its local load",
            )

    def test_compute_and_set_moe_load_multi_stage(self):
        self.updator.multi_stage = True

        moe_load = self.updator.compute_and_set_moe_load()

        self.assertEqual(moe_load.shape, (100, 58, self.world_size, 8))
        self.assertIn("moe_load", self.updator.shared_dict)
        self.assertEqual(moe_load.device.type, "cpu")

        # After the multi-stage permute the rank axis is dim 2 and the
        # first two axes of the local load are swapped.
        for r in range(self.world_size):
            expected = self._expected_rank_load(r).permute(1, 0, 2)
            self.assertTrue(
                torch.equal(moe_load[:, :, r], expected),
                msg=f"multi-stage load for rank {r} does not match its local load",
            )
|
||||
|
||||
|
||||
# Script entry point: run every test case in this module when the file is
# executed directly instead of being collected by a test runner.
if __name__ == '__main__':
    unittest.main()
|
||||
@@ -131,10 +131,8 @@ class EplbUpdator:
|
||||
self.update_iteration()
|
||||
|
||||
def compute_and_set_moe_load(self):
|
||||
local_load = self.adaptor.get_rank_expert_workload()
|
||||
moe_load = (
|
||||
self.comm_group.all_gather(local_load, dim=0).reshape(-1, self.world_size, *local_load.shape[1:]).cpu()
|
||||
)
|
||||
local_load = self.adaptor.get_rank_expert_workload().unsqueeze(1)
|
||||
moe_load = self.comm_group.all_gather(local_load, dim=1).cpu()
|
||||
|
||||
if self.multi_stage:
|
||||
moe_load = moe_load.permute(2, 0, 1, 3)
|
||||
|
||||
Reference in New Issue
Block a user