[EPLB][Nightly][Bugfix] Get expert from moe layer only (#5908)

### What this PR does / why we need it?
1. If the model has dense layers, the current code attempts to obtain the
routed experts from the dense layers as well, which causes an error. This is
fixed by skipping the dense layers (offsetting the layer index by the number
of dense layers) when obtaining the routed experts; see the sketch after this
list.
2. Having `init_eplb_config` directly return only the full global_expert_map
affects dsv3.2 performance; the function now also returns the per-rank local
expert map, so the caller no longer has to index the global map by `ep_rank`
and move it to the NPU.
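
A minimal sketch of fix (1), assuming a DeepSeek-style layout where the first
`num_dense_layers` decoder layers use a dense MLP and only the remaining
`num_moe_layers` layers carry routed experts. The attribute names mirror the
diff below, but the standalone helper itself is hypothetical:

```python
def collect_expert_maps(model, num_dense_layers, num_moe_layers, rank_id):
    """Fetch per-layer expert maps from MoE layers only (illustrative sketch)."""
    per_layer_rank_maps = {}
    for layer_id in range(num_moe_layers):
        # Offset by the number of dense layers so the lookup always lands on
        # an MoE decoder layer that actually owns `mlp.experts`.
        layer = model.model.layers[num_dense_layers + layer_id]
        map_cpu = layer.mlp.experts.global_expert_map.cpu()
        per_layer_rank_maps[num_dense_layers + layer_id] = map_cpu[rank_id]
    return per_layer_rank_maps
```
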
### Does this PR introduce _any_ user-facing change?

### How was this patch tested?

DeepSeek V3.1 conversation works normally.

#### AIME precision test (dsv3.1)
Baseline (without EPLB):
| dataset | version | metric | mode | vllm-api-general-chat |
|----- | ----- | ----- | ----- | -----|
| aime2024 | 604a78 | accuracy | gen | 66.67 |

With EPLB:
| dataset | version | metric | mode | vllm-api-general-chat |
|----- | ----- | ----- | ----- | -----|
| aime2024 | 604a78 | accuracy | gen | 70.00 |

- vLLM version: v0.13.0
- vLLM main: 11b6af5280

Signed-off-by: shenchuxiaofugui <1311027364@qq.com>
LI SHENGYONG
2026-01-19 09:23:28 +08:00
committed by GitHub
parent ad3a1eaf70
commit 9fed2636cb
4 changed files with 12 additions and 13 deletions


```diff
@@ -34,15 +34,14 @@ class TestAscendConfig(unittest.TestCase):
         self.moe_config = moe_config
         self.mock_npu = patch("torch.Tensor.npu",
                               new=lambda self: self).start()
-        self.rank = 1

     def test_init_eplb_config_with_eplb(self):
         eplb_config = init_ascend_config(self.vllm_config).eplb_config
-        expert_map, log2phy, redundant_experts = init_eplb_config(
+        _, expert_map, log2phy, redundant_experts = init_eplb_config(
             eplb_config, 0, self.moe_config)
         gt_expert_map = torch.tensor([4, -1, -1, -1, 0, 1, 2, 3])
         gt_log2phy = torch.tensor([9, 1, 2, 3, 5, 6, 7, 8])
-        self.assertTrue(torch.equal(expert_map[self.rank], gt_expert_map))
+        self.assertTrue(torch.equal(expert_map, gt_expert_map))
         self.assertTrue(torch.equal(log2phy, gt_log2phy))
         self.assertEqual(redundant_experts, 2)
@@ -51,20 +50,20 @@ class TestAscendConfig(unittest.TestCase):
         self.vllm_config.additional_config["eplb_config"][
             "expert_map_path"] = _TEST_DIR + "/expert_map.json"
         eplb_config = init_ascend_config(self.vllm_config).eplb_config
-        expert_map, log2phy, redundant_experts = init_eplb_config(
+        _, expert_map, log2phy, redundant_experts = init_eplb_config(
             eplb_config, 0, self.moe_config)
         gt_expert_map = torch.tensor([-1, 1, 4, -1, 2, -1, 0, 3])
         gt_log2phy = torch.tensor([2, 6, 9, 3, 7, 4, 5, 8])
-        self.assertTrue(torch.equal(expert_map[self.rank], gt_expert_map))
+        self.assertTrue(torch.equal(expert_map, gt_expert_map))
         self.assertTrue(torch.equal(log2phy, gt_log2phy))
         self.assertEqual(redundant_experts, 2)

     def test_init_eplb_config_without_eplb(self):
         self.vllm_config.additional_config = {"refresh": True}
         eplb_config = init_ascend_config(self.vllm_config).eplb_config
-        expert_map, log2phy, redundant_experts = init_eplb_config(
+        _, expert_map, log2phy, redundant_experts = init_eplb_config(
             eplb_config, 0, self.moe_config)
         gt_expert_map = torch.tensor([-1, -1, -1, -1, 0, 1, 2, 3])
         print(expert_map, log2phy, redundant_experts)
-        self.assertTrue(torch.equal(expert_map[self.rank], gt_expert_map))
+        self.assertTrue(torch.equal(expert_map, gt_expert_map))
         self.assertEqual(redundant_experts, 0)
```


```diff
@@ -188,7 +188,7 @@ class VllmEplbAdaptor(EplbAdaptor):
         all_layer_global_expert_map = []
         for layer_id in range(self.num_moe_layers):
             map_cpu = self.model.model.layers[
-                layer_id].mlp.experts.global_expert_map.cpu()
+                self.num_dense_layers + layer_id].mlp.experts.global_expert_map.cpu()
             all_layer_global_expert_map.append(map_cpu)
             self.expert_map_per_layer_cpu[self.num_dense_layers +
                                           layer_id] = map_cpu[self.rank_id]
```


```diff
@@ -81,7 +81,7 @@ def init_eplb_config(eplb_config, layer_id, moe_config):
     if ep_size == 1:
         assert not eplb_enable, "EPLB must used in expert parallelism."
-        return None, None, n_redundant
+        return None, None, None, n_redundant

     global_expert_map = []
     for rankid in range(ep_size):
         expert_map = torch.full((n_experts, ), -1, dtype=torch.int32)
@@ -89,10 +89,12 @@ def init_eplb_config(eplb_config, layer_id, moe_config):
         expert_map[local_placement] = torch.arange(local_placement.shape[0],
                                                    dtype=torch.int32)
         global_expert_map.append(expert_map)
+        if rankid == moe_config.ep_rank:
+            local_expert_map = expert_map.npu()

     log2phy = generate_log2phy_map(
         global_expert_map, moe_config.ep_rank).npu() if eplb_enable else None
-    return torch.stack(global_expert_map), log2phy, n_redundant
+    return torch.stack(global_expert_map), local_expert_map, log2phy, n_redundant


 def generate_log2phy_map(global_expert_map, ep_rank):
```
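
For reference, a tiny self-contained example of how one rank's `expert_map` is
built in the loop above; the placement values here are made up, but they
reproduce the expected map from the unit test (`[-1, -1, -1, -1, 0, 1, 2, 3]`):

```python
import torch

n_experts = 8
# Hypothetical placement: this rank owns physical experts 4..7.
local_placement = torch.tensor([4, 5, 6, 7])

expert_map = torch.full((n_experts, ), -1, dtype=torch.int32)
expert_map[local_placement] = torch.arange(local_placement.shape[0],
                                           dtype=torch.int32)
print(expert_map)  # tensor([-1, -1, -1, -1,  0,  1,  2,  3], dtype=torch.int32)
```
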


```diff
@@ -202,10 +202,8 @@ class AscendFusedMoE(FusedMoE):
         # init moe
         eplb_config = ascend_config.eplb_config
-        self.global_expert_map, self.log2phy, self.global_redundant_expert_num = init_eplb_config(
+        self.global_expert_map, self._expert_map, self.log2phy, self.global_redundant_expert_num = init_eplb_config(
             eplb_config, self.moe_instance_id, self.moe_config)
-        if self.global_expert_map is not None:
-            self._expert_map = self.global_expert_map[self.ep_rank].npu()
         self.global_num_experts = num_experts + self.global_redundant_expert_num
         self.dynamic_eplb = eplb_config.dynamic_eplb and (self.log2phy
                                                           is not None)
```