[EPLB][Bugfix] Get expert map from layers (#5817)

### What this PR does / why we need it?
The eplb module initializes expert_map differently from the fused_moe
module. This PR removes the eplb module's own expert_map initialization
so that both modules share the same initialization path, making the
resulting maps consistent.

#### before bugfix
self._expert_map=tensor([64, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
-1, -1, -1, -1, -1, -1, 1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 0, 1, 2,
3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22,
23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40,
41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58,
59, 60, 61, 62, 63], device='npu:1', dtype=torch.int32)

self.shared_dict["expert_maps"][0]=tensor([[-1, -1, -1, -1, -1, -1, -1,
-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
-1, -1, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17,
18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35,
36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53,
54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64]], dtype=torch.int32)

### How was this patch tested?

#### qwen3-235B-w8a8 aime
| dataset | version | metric | mode | vllm-api-general-chat |
|----- | ----- | ----- | ----- | -----|
| aime2024 | 604a78 | accuracy | gen | 86.67 |

- vLLM version: v0.13.0
- vLLM main:
2f4e6548ef

Signed-off-by: shenchuxiaofugui <1311027364@qq.com>
This commit is contained in:
LI SHENGYONG
2026-01-14 09:16:51 +08:00
committed by GitHub
parent 48ec97821a
commit ecf2fa482e
7 changed files with 23 additions and 173 deletions

View File

@@ -12,9 +12,6 @@ class DummyAdaptor(EplbAdaptor):
def get_rank_expert_workload(self):
return "workload"
def get_init_expert_map(self, num_moe_layers):
return {"layers": num_moe_layers}
def do_update_expert_map(self, layer_id, updated_expert_map):
return {"layer_id": layer_id, "map": updated_expert_map}
@@ -31,8 +28,6 @@ def test_base_class_methods_raise():
adaptor = EplbAdaptor()
with pytest.raises(NotImplementedError):
adaptor.get_rank_expert_workload()
with pytest.raises(NotImplementedError):
adaptor.get_init_expert_map(1)
with pytest.raises(NotImplementedError):
adaptor.do_update_expert_map(1, {})
with pytest.raises(NotImplementedError):
@@ -50,13 +45,6 @@ def test_get_rank_expert_workload():
assert result == "workload"
def test_get_init_expert_map():
adaptor = DummyAdaptor()
result = adaptor.get_init_expert_map(5)
assert isinstance(result, dict)
assert result["layers"] == 5
def test_do_update_expert_map():
adaptor = DummyAdaptor()
updated = {"expert": 1}

View File

@@ -32,13 +32,14 @@ class TestAscendConfig(unittest.TestCase):
self.moe_config = moe_config
self.mock_npu = patch("torch.Tensor.npu",
new=lambda self: self).start()
self.rank = 1
def test_init_eplb_config_with_eplb(self):
expert_map, log2phy, redundant_experts = init_eplb_config(
self.ascend_config, 0, self.moe_config)
gt_expert_map = torch.tensor([4, -1, -1, -1, 0, 1, 2, 3])
gt_log2phy = torch.tensor([9, 1, 2, 3, 5, 6, 7, 8])
self.assertTrue(torch.equal(expert_map, gt_expert_map))
self.assertTrue(torch.equal(expert_map[self.rank], gt_expert_map))
self.assertTrue(torch.equal(log2phy, gt_log2phy))
self.assertEqual(redundant_experts, 2)
@@ -49,7 +50,7 @@ class TestAscendConfig(unittest.TestCase):
self.ascend_config, 0, self.moe_config)
gt_expert_map = torch.tensor([-1, 1, 4, -1, 2, -1, 0, 3])
gt_log2phy = torch.tensor([2, 6, 9, 3, 7, 4, 5, 8])
self.assertTrue(torch.equal(expert_map, gt_expert_map))
self.assertTrue(torch.equal(expert_map[self.rank], gt_expert_map))
self.assertTrue(torch.equal(log2phy, gt_log2phy))
self.assertEqual(redundant_experts, 2)
@@ -60,7 +61,7 @@ class TestAscendConfig(unittest.TestCase):
self.ascend_config, 0, self.moe_config)
gt_expert_map = torch.tensor([-1, -1, -1, -1, 0, 1, 2, 3])
print(expert_map, log2phy, redundant_experts)
self.assertTrue(torch.equal(expert_map, gt_expert_map))
self.assertTrue(torch.equal(expert_map[self.rank], gt_expert_map))
self.assertEqual(redundant_experts, 0)

View File

@@ -28,10 +28,6 @@ class EplbAdaptor():
def get_rank_expert_workload(self):
raise NotImplementedError
@abstractmethod
def get_init_expert_map(self, num_moe_layers: Any) -> Any:
raise NotImplementedError
@abstractmethod
def do_update_expert_map(self, layer_id: Any,
updated_expert_map: Any) -> Any:

View File

@@ -22,7 +22,6 @@ import torch
import torch.distributed as dist
from vllm.logger import logger
from vllm_ascend.ascend_config import get_ascend_config
from vllm_ascend.eplb.adaptor.abstract_adaptor import EplbAdaptor
@@ -41,8 +40,6 @@ class VllmEplbAdaptor(EplbAdaptor):
self.num_dense_layers = self.model.config.first_k_dense_replace
self.global_expert_num = self.model.config.n_routed_experts
self.num_moe_layers = self.model.config.num_hidden_layers - self.num_dense_layers
self.init_redundancy_expert = get_ascend_config(
).init_redundancy_expert
for i in range(self.num_dense_layers,
self.model.config.num_hidden_layers):
@@ -139,67 +136,11 @@ class VllmEplbAdaptor(EplbAdaptor):
self.moe_load = self.model.get_all_moe_loads()
return self.moe_load
def get_init_expert_map(self, num_moe_layers):
expert_map = self.model.get_all_expert_map(num_moe_layers)
if dist.is_initialized():
world_size = dist.get_world_size()
gathered = torch.empty(
(world_size, *expert_map.shape), # [W, L, E]
dtype=expert_map.dtype,
device=expert_map.device)
dist.all_gather_into_tensor(gathered, expert_map)
all_maps = gathered.permute(1, 0, 2)
all_expert_maps = all_maps.cpu()
for layer_idx in range(num_moe_layers):
self.expert_map_per_layer_cpu[self.num_dense_layers + layer_idx] = \
all_expert_maps[layer_idx][self.rank_id]
return all_expert_maps
def get_init_expert_map_from_file(self, num_moe_layers, expert_map_path):
try:
expert_map_tensor, layers_num, ranks_num = self._expert_file_to_tensor(
expert_map_path)
expert_map_all = self.local2global(expert_map_tensor)
except (TypeError, FileNotFoundError, OSError):
expert_map_all = self.determine_expert_map_all()
for layer_idx in range(num_moe_layers):
if self.model.config.model_type == "qwen3_moe":
self.expert_map_per_layer_cpu[layer_idx] = \
expert_map_all[layer_idx][self.rank_id]
else:
self.expert_map_per_layer_cpu[layer_idx + self.num_dense_layers] = \
expert_map_all[layer_idx][self.rank_id]
return expert_map_all
def _expert_file_to_tensor(self, expert_map_path: str):
with open(expert_map_path, "r") as f:
data = json.load(f)
layers_num = data["moe_layer_count"]
gpus_num = data["layer_list"][0]["device_count"]
tensor_data = []
for layer in data["layer_list"]:
device_data = []
for device in layer["device_list"]:
device_data.append(device["device_expert"])
tensor_data.append(device_data)
expert_map_tensor = torch.tensor(tensor_data, dtype=torch.int32)
return expert_map_tensor, layers_num, gpus_num
logger.error(f"failed to read expert_map_path: {expert_map_path}")
def _export_tensor_to_file(self, expert_maps, expert_map_record_path: str):
if self.rank_id == 0:
num_local_experts = expert_maps.max() + 1
expert_maps_local = self.global2local(expert_maps,
num_local_experts)
expert_maps_list = expert_maps_local.tolist()
expert_maps_list = expert_maps.tolist()
record: dict[str, Any] = {
"moe_layer_count": len(expert_maps_list),
"layer_list": []
@@ -213,9 +154,12 @@ class VllmEplbAdaptor(EplbAdaptor):
}
for device_idx, experts in enumerate(layer_data):
placement = [
experts.index(i) for i in range(num_local_experts)
]
device_record = {
"device_id": device_idx,
"device_expert": experts
"device_expert": placement
}
layer_record["device_list"].append(device_record)
@@ -240,81 +184,13 @@ class VllmEplbAdaptor(EplbAdaptor):
if self.log2phy_map_per_layer[layer_id] is not None:
self.log2phy_map_per_layer[layer_id].copy_(updated_log2phy_map)
def global2local(self, placement: torch.Tensor,
E_local: int) -> torch.Tensor:
def get_global_expert_map(self):
all_layer_global_expert_map = []
for layer_id in range(self.num_moe_layers):
map_cpu = self.model.model.layers[
layer_id].mlp.experts.global_expert_map.cpu()
all_layer_global_expert_map.append(map_cpu)
self.expert_map_per_layer_cpu[self.num_dense_layers +
layer_id] = map_cpu[self.rank_id]
L, G, _ = placement.shape
device = placement.device
pt_local = torch.full((L, G, E_local),
fill_value=-1,
dtype=torch.long,
device=device)
valid = placement >= 0
l_idx, g_idx, k_idx = valid.nonzero(as_tuple=True)
slot_idx = placement[l_idx, g_idx, k_idx]
pt_local[l_idx, g_idx, slot_idx] = k_idx
return pt_local
def local2global(self, placement_local: torch.Tensor) -> torch.Tensor:
L, G, E_local = placement_local.shape
device = placement_local.device
max_id = torch.max(placement_local)
E_global = (max_id + 1).item() if max_id >= 0 else 0
if E_global == 0:
return torch.empty((L, G, 0), dtype=torch.long, device=device)
placement_global = torch.full((L, G, E_global),
fill_value=-1,
dtype=torch.long,
device=device)
valid = placement_local >= 0
l_idx, g_idx, slot_idx = valid.nonzero(as_tuple=True)
gid_idx = placement_local[l_idx, g_idx, slot_idx]
placement_global[l_idx, g_idx, gid_idx] = slot_idx
return placement_global
def determine_expert_map_all(self):
if self.world_size == 1:
local_ids = torch.arange(self.global_expert_num, dtype=torch.int32)
return local_ids.view(1, 1, -1).expand(self.num_moe_layers, 1, -1)
local_num_experts = self.global_expert_num // self.world_size
expert_map_all = torch.full(
(self.num_moe_layers, self.world_size, self.global_expert_num),
-1,
dtype=torch.int32)
for r in range(self.world_size):
if r < self.world_size - 1:
start = r * local_num_experts
end = (r + 1) * local_num_experts
local_count = local_num_experts
else:
start = r * local_num_experts
end = self.global_expert_num
local_count = self.global_expert_num - r * local_num_experts
if r < self.init_redundancy_expert:
local_count += 1
if end < self.global_expert_num:
end += 1
else:
start -= 1
local_ids = torch.arange(local_count, dtype=torch.int32)
expert_map_all[:, r, start:end] = local_ids.unsqueeze(0).expand(
self.num_moe_layers, -1)
return expert_map_all
return torch.stack(all_layer_global_expert_map)

View File

@@ -91,11 +91,10 @@ def init_eplb_config(ascend_config, layer_id, moe_config):
expert_map[local_placement] = torch.arange(local_placement.shape[0],
dtype=torch.int32)
global_expert_map.append(expert_map)
local_expert_map = global_expert_map[moe_config.ep_rank].npu()
log2phy = generate_log2phy_map(
global_expert_map, moe_config.ep_rank).npu() if eplb_enable else None
return local_expert_map, log2phy, n_redundant
return torch.stack(global_expert_map), log2phy, n_redundant
def generate_log2phy_map(global_expert_map, ep_rank):

View File

@@ -58,7 +58,6 @@ class EplbUpdator:
self.num_expert_load_gather = self.num_iterations_eplb_update
self.periodic_load_gather = False
self.expert_map_initialized = False
self.gate_eplb = self.ascend_config.gate_eplb
self.reqs = []
@@ -101,17 +100,6 @@ class EplbUpdator:
return (weight_update_counter >= 0
and weight_update_counter < self.num_moe_layers)
def get_init_expert_map(self):
try:
if not self.expert_map_initialized:
self.shared_dict[
"expert_maps"] = self.adaptor.get_init_expert_map_from_file(
self.num_moe_layers, self.expert_map_path)
self.expert_map_initialized = True
except Exception as e:
logger.warning(f"[ModelRunner] Failed to wake EPLB process: {e}",
exc_info=True)
def wakeup_eplb_worker(self):
self.eplb_process.planner_q.put(1)
@@ -218,7 +206,7 @@ class EplbUpdator:
def warm_up_eplb(self):
self.get_init_expert_map()
self.shared_dict["expert_maps"] = self.adaptor.get_global_expert_map()
self.compute_and_set_moe_load()
src_tensor = torch.empty((1, ), device=self.device)

View File

@@ -187,8 +187,10 @@ class AscendFusedMoE(FusedMoE):
dtype=vllm_config.model_config.dtype)
# init moe
self._expert_map, self.log2phy, self.global_redundant_expert_num = init_eplb_config(
self.global_expert_map, self.log2phy, self.global_redundant_expert_num = init_eplb_config(
ascend_config, self.moe_instance_id, self.moe_config)
if self.global_expert_map is not None:
self._expert_map = self.global_expert_map[self.ep_rank].npu()
self.global_num_experts = num_experts + self.global_redundant_expert_num
self.dynamic_eplb = (ascend_config.dynamic_eplb
or ascend_config.expert_map_record_path) and (