From 9fed2636cbf478ad8a57c4deed62679f2b8d2510 Mon Sep 17 00:00:00 2001
From: LI SHENGYONG <49200266+shenchuxiaofugui@users.noreply.github.com>
Date: Mon, 19 Jan 2026 09:23:28 +0800
Subject: [PATCH] [EPLB][Nightly][Bugfix] Get expert from moe layer only (#5908)

### What this PR does / why we need it?
1. If the model has dense layers, the current code attempts to obtain the routing experts from the dense layers as well, which causes an error. This is fixed by skipping the dense layers and reading the routing experts from the MoE layers only.
2. The `global_expert_map` that `init_eplb_config` directly outputs affects the performance of dsv3.2: callers had to index the stacked map by rank themselves. The function now also returns the per-rank local expert map (see the sketch below).
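To make the two fixes concrete, here is a minimal runnable sketch of the intended behaviour. Everything in it is illustrative: `init_eplb_config_sketch`, the contiguous expert placement, and the layer counts are assumptions for demonstration, not the actual vllm-ascend implementation.

```python
import torch


def init_eplb_config_sketch(n_experts: int, ep_size: int, ep_rank: int):
    # Stand-in for init_eplb_config; contiguous placement is assumed here,
    # while the real code derives placement from the EPLB config.
    n_local = n_experts // ep_size
    global_expert_map, local_expert_map = [], None
    for rankid in range(ep_size):
        expert_map = torch.full((n_experts, ), -1, dtype=torch.int32)
        owned = torch.arange(rankid * n_local, (rankid + 1) * n_local)
        expert_map[owned] = torch.arange(n_local, dtype=torch.int32)
        global_expert_map.append(expert_map)
        if rankid == ep_rank:
            # Fix 2 in miniature: hand the caller its own rank's map directly
            # instead of making it index the stacked global map afterwards.
            local_expert_map = expert_map
    return torch.stack(global_expert_map), local_expert_map


# Fix 1 in miniature: dense layers precede the MoE layers in the decoder
# stack, so MoE layer i lives at model.layers[num_dense_layers + i];
# indexing with i alone would hit a dense layer without routing experts.
num_dense_layers, num_moe_layers = 3, 58  # illustrative counts
moe_layer_ids = [num_dense_layers + i for i in range(num_moe_layers)]

global_map, local_map = init_eplb_config_sketch(n_experts=8,
                                                ep_size=2,
                                                ep_rank=0)
print(global_map.shape)  # torch.Size([2, 8]); local_map is rank 0's row
```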
### Does this PR introduce _any_ user-facing change?

### How was this patch tested?
DeepSeek V3.1 conversation works normally.

#### aime precision test (dsv3.1)

baseline without eplb

| dataset | version | metric | mode | vllm-api-general-chat |
| ----- | ----- | ----- | ----- | ----- |
| aime2024 | 604a78 | accuracy | gen | 66.67 |

eplb

| dataset | version | metric | mode | vllm-api-general-chat |
| ----- | ----- | ----- | ----- | ----- |
| aime2024 | 604a78 | accuracy | gen | 70.00 |

- vLLM version: v0.13.0
- vLLM main: https://github.com/vllm-project/vllm/commit/11b6af5280d6d6dfb8953af16e67b25f819b3be9

Signed-off-by: shenchuxiaofugui <1311027364@qq.com>
---
 tests/ut/eplb/core/test_eplb_utils.py    | 13 ++++++-------
 vllm_ascend/eplb/adaptor/vllm_adaptor.py |  2 +-
 vllm_ascend/eplb/core/eplb_utils.py      |  6 ++++--
 vllm_ascend/ops/fused_moe/fused_moe.py   |  4 +---
 4 files changed, 12 insertions(+), 13 deletions(-)

diff --git a/tests/ut/eplb/core/test_eplb_utils.py b/tests/ut/eplb/core/test_eplb_utils.py
index 4d7f8fee..bc112c89 100644
--- a/tests/ut/eplb/core/test_eplb_utils.py
+++ b/tests/ut/eplb/core/test_eplb_utils.py
@@ -34,15 +34,14 @@ class TestAscendConfig(unittest.TestCase):
         self.moe_config = moe_config
         self.mock_npu = patch("torch.Tensor.npu",
                               new=lambda self: self).start()
-        self.rank = 1
 
     def test_init_eplb_config_with_eplb(self):
         eplb_config = init_ascend_config(self.vllm_config).eplb_config
-        expert_map, log2phy, redundant_experts = init_eplb_config(
+        _, expert_map, log2phy, redundant_experts = init_eplb_config(
             eplb_config, 0, self.moe_config)
         gt_expert_map = torch.tensor([4, -1, -1, -1, 0, 1, 2, 3])
         gt_log2phy = torch.tensor([9, 1, 2, 3, 5, 6, 7, 8])
-        self.assertTrue(torch.equal(expert_map[self.rank], gt_expert_map))
+        self.assertTrue(torch.equal(expert_map, gt_expert_map))
         self.assertTrue(torch.equal(log2phy, gt_log2phy))
         self.assertEqual(redundant_experts, 2)
 
@@ -51,20 +50,20 @@
         self.vllm_config.additional_config["eplb_config"][
             "expert_map_path"] = _TEST_DIR + "/expert_map.json"
         eplb_config = init_ascend_config(self.vllm_config).eplb_config
-        expert_map, log2phy, redundant_experts = init_eplb_config(
+        _, expert_map, log2phy, redundant_experts = init_eplb_config(
             eplb_config, 0, self.moe_config)
         gt_expert_map = torch.tensor([-1, 1, 4, -1, 2, -1, 0, 3])
         gt_log2phy = torch.tensor([2, 6, 9, 3, 7, 4, 5, 8])
-        self.assertTrue(torch.equal(expert_map[self.rank], gt_expert_map))
+        self.assertTrue(torch.equal(expert_map, gt_expert_map))
         self.assertTrue(torch.equal(log2phy, gt_log2phy))
         self.assertEqual(redundant_experts, 2)
 
     def test_init_eplb_config_without_eplb(self):
         self.vllm_config.additional_config = {"refresh": True}
         eplb_config = init_ascend_config(self.vllm_config).eplb_config
-        expert_map, log2phy, redundant_experts = init_eplb_config(
+        _, expert_map, log2phy, redundant_experts = init_eplb_config(
             eplb_config, 0, self.moe_config)
         gt_expert_map = torch.tensor([-1, -1, -1, -1, 0, 1, 2, 3])
         print(expert_map, log2phy, redundant_experts)
-        self.assertTrue(torch.equal(expert_map[self.rank], gt_expert_map))
+        self.assertTrue(torch.equal(expert_map, gt_expert_map))
         self.assertEqual(redundant_experts, 0)
diff --git a/vllm_ascend/eplb/adaptor/vllm_adaptor.py b/vllm_ascend/eplb/adaptor/vllm_adaptor.py
index 2a347539..94cedc1d 100644
--- a/vllm_ascend/eplb/adaptor/vllm_adaptor.py
+++ b/vllm_ascend/eplb/adaptor/vllm_adaptor.py
@@ -188,7 +188,7 @@ class VllmEplbAdaptor(EplbAdaptor):
         all_layer_global_expert_map = []
         for layer_id in range(self.num_moe_layers):
             map_cpu = self.model.model.layers[
-                layer_id].mlp.experts.global_expert_map.cpu()
+                self.num_dense_layers + layer_id].mlp.experts.global_expert_map.cpu()
             all_layer_global_expert_map.append(map_cpu)
             self.expert_map_per_layer_cpu[self.num_dense_layers +
                                           layer_id] = map_cpu[self.rank_id]
diff --git a/vllm_ascend/eplb/core/eplb_utils.py b/vllm_ascend/eplb/core/eplb_utils.py
index 4faab717..88032d09 100644
--- a/vllm_ascend/eplb/core/eplb_utils.py
+++ b/vllm_ascend/eplb/core/eplb_utils.py
@@ -81,7 +81,7 @@ def init_eplb_config(eplb_config, layer_id, moe_config):
 
     if ep_size == 1:
         assert not eplb_enable, "EPLB must used in expert parallelism."
-        return None, None, n_redundant
+        return None, None, None, n_redundant
     global_expert_map = []
     for rankid in range(ep_size):
         expert_map = torch.full((n_experts, ), -1, dtype=torch.int32)
@@ -89,10 +89,12 @@
         expert_map[local_placement] = torch.arange(local_placement.shape[0],
                                                    dtype=torch.int32)
         global_expert_map.append(expert_map)
+        if rankid == moe_config.ep_rank:
+            local_expert_map = expert_map.npu()
     log2phy = generate_log2phy_map(
         global_expert_map,
         moe_config.ep_rank).npu() if eplb_enable else None
-    return torch.stack(global_expert_map), log2phy, n_redundant
+    return torch.stack(global_expert_map), local_expert_map, log2phy, n_redundant
 
 
 def generate_log2phy_map(global_expert_map, ep_rank):
diff --git a/vllm_ascend/ops/fused_moe/fused_moe.py b/vllm_ascend/ops/fused_moe/fused_moe.py
index 81bf0796..0aed4a61 100644
--- a/vllm_ascend/ops/fused_moe/fused_moe.py
+++ b/vllm_ascend/ops/fused_moe/fused_moe.py
@@ -202,10 +202,8 @@ class AscendFusedMoE(FusedMoE):
 
         # init moe
         eplb_config = ascend_config.eplb_config
-        self.global_expert_map, self.log2phy, self.global_redundant_expert_num = init_eplb_config(
+        self.global_expert_map, self._expert_map, self.log2phy, self.global_redundant_expert_num = init_eplb_config(
             eplb_config, self.moe_instance_id, self.moe_config)
-        if self.global_expert_map is not None:
-            self._expert_map = self.global_expert_map[self.ep_rank].npu()
         self.global_num_experts = num_experts + self.global_redundant_expert_num
         self.dynamic_eplb = eplb_config.dynamic_eplb and (self.log2phy
                                                           is not None)