【EPLB】Eplb Redundant Experts Bugfix (#4232)

### What this PR does / why we need it?
Redundant experts bugfix
The calculation logic for redundant experts has been fixed, allowing the
correct number of redundant experts to be calculated using the map.
Therefore, there is no longer a need to set the redundant expert
parameter when passing the map.

### Does this PR introduce _any_ user-facing change?
After configuring the path for experts_map, users do not need to
configure iinit_redundancy_expert.

### How was this patch tested?
The accuracy of EPLB was tested with and without the use of redundant
experts.

---------

Signed-off-by: shenchuxiaofugui <1311027364@qq.com>
This commit is contained in:
LI SHENGYONG
2025-12-03 12:00:05 +08:00
committed by GitHub
parent b6d63bbd52
commit 593a96056c
9 changed files with 45 additions and 65 deletions

View File

@@ -64,21 +64,17 @@ def test_generate_log2phy_map_multiple_rank_holding(monkeypatch):
def test_determine_default_log2phy_map_world_size_1():
log2phy = eplb_utils.determine_default_log2phy_map(
global_expert_num=3,
world_size=1,
rank_id=0,
global_redundant_expert_num=0)
log2phy = eplb_utils.determine_default_log2phy_map(global_expert_num=3,
world_size=1,
rank_id=0)
assert log2phy.shape == (3, )
assert (log2phy >= 0).all()
def test_determine_default_log2phy_map_world_size_multiple():
log2phy = eplb_utils.determine_default_log2phy_map(
global_expert_num=6,
world_size=2,
rank_id=1,
global_redundant_expert_num=1)
log2phy = eplb_utils.determine_default_log2phy_map(global_expert_num=6,
world_size=2,
rank_id=1)
assert log2phy.shape == (6, )
assert (log2phy >= 0).all()

View File

@@ -48,8 +48,7 @@ class TestExpertLoadBalancer(TestBase):
with open(json_file, 'r') as f:
self.expert_map: MockData = json.load(f)
self.expert_load_balancer = ExpertLoadBalancer(json_file,
global_expert_num=8)
self.expert_load_balancer = ExpertLoadBalancer(json_file, 8)
def test_init(self):
@@ -83,7 +82,7 @@ class TestExpertLoadBalancer(TestBase):
)
self.assertEqual(expert_placement_map.shape,
(self.expert_load_balancer.layers_num,
self.expert_load_balancer.ranks_num, 8))
self.expert_load_balancer.ranks_num, 10))
self.assertTrue(torch.all(expert_placement_map >= -1))
def test_generate_log2phy_expert_map(self):
@@ -91,7 +90,7 @@ class TestExpertLoadBalancer(TestBase):
log2phy_map = self.expert_load_balancer.generate_log2phy_expert_map(
layer_id)
self.assertEqual(log2phy_map.shape,
(self.expert_load_balancer.ranks_num, 8))
(self.expert_load_balancer.ranks_num, 10))
self.assertTrue(torch.all(log2phy_map >= -1))
@mock.patch("torch_npu.npu._lazy_init")
@@ -102,7 +101,7 @@ class TestExpertLoadBalancer(TestBase):
rank_local_expert_num, rank_expert_map = self.expert_load_balancer.get_rank_placement_map(
layer_id, rank_id)
self.assertEqual(rank_local_expert_num, 5)
expected_tensor = torch.tensor([2, -1, 1, 3, -1, 4, -1, 0],
expected_tensor = torch.tensor([2, -1, 1, 3, -1, 4, -1, 0, -1, -1],
dtype=torch.int32).to(
rank_expert_map.device)
self.assertTrue(rank_expert_map.equal(expected_tensor))
@@ -110,7 +109,7 @@ class TestExpertLoadBalancer(TestBase):
rank_id = 1
rank_local_expert_num, rank_expert_map = self.expert_load_balancer.get_rank_placement_map(
layer_id, rank_id)
expected_tensor = torch.tensor([-1, 1, 4, -1, 2, -1, 0, 3],
expected_tensor = torch.tensor([-1, 1, 4, -1, 2, -1, 0, 3, -1, -1],
dtype=torch.int32).to(
rank_expert_map.device)
self.assertTrue(rank_expert_map.equal(expected_tensor))
@@ -120,7 +119,7 @@ class TestExpertLoadBalancer(TestBase):
rank_id = 0
log2phy_map = self.expert_load_balancer.get_rank_log2phy_map(
layer_id, rank_id)
expected_tensor = torch.tensor([2, 6, 1, 3, 7, 4, 5, 0],
expected_tensor = torch.tensor([2, 6, 1, 3, 7, 4, 5, 0, -1, -1],
dtype=torch.int32).to(
log2phy_map.device)
self.assertTrue(log2phy_map.equal(expected_tensor))
@@ -128,7 +127,7 @@ class TestExpertLoadBalancer(TestBase):
rank_id = 1
log2phy_map = self.expert_load_balancer.get_rank_log2phy_map(
layer_id, rank_id)
expected_tensor = torch.tensor([2, 6, 9, 3, 7, 4, 5, 8],
expected_tensor = torch.tensor([2, 6, 9, 3, 7, 4, 5, 8, -1, -1],
dtype=torch.int32).to(
log2phy_map.device)
self.assertTrue(log2phy_map.equal(expected_tensor))