[EPLB][refactor] Rework the initialization logic for expert_map and log2phy (depends on PR 5285) (#5311)
### What this PR does / why we need it?
Unify the loading logic for expert_map and log2phy.
1. The map generated when redundant experts are enabled is incorrect.
The community's map-generation function only accepts the number of global
experts, so when we pass in the number of logical experts plus redundant
experts, the local expert IDs on the last card index into expert IDs that
do not exist. We now ensure that every index points to a real, existing
expert ID and that every expert remains reachable (see the sketch after
this list). When redundant experts are not enabled, the output of our
function stays consistent with the community's function.
2. The map we generate is based on the number of physical experts, but in
reality we only need the number of logical experts. Since the map is
padded later anyway, we can simply generate a map with the length of the
logical experts.
3. Unify the initialization logic across the different scenarios and
simplify the code in fused_moe.
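
To make point 1 concrete, here is a minimal, self-contained sketch of the improved placement idea, mirroring the `generate_global_placement` helper in the diff below (the dropped empty-group branch and the final assert are my additions for illustration): redundant slots are filled round-robin with copies of real logical expert IDs, so no entry can index past the logical expert table.

```python
import numpy as np
import torch


def generate_global_placement(n_expert: int, ep_size: int, n_redundant: int):
    # Split the logical expert IDs evenly across the EP ranks.
    groups = np.array_split(np.arange(n_expert), ep_size)
    # Hand out redundant slots round-robin over the last ranks; each slot
    # duplicates an existing logical expert ID (mod n_expert), so every
    # entry remains a valid index into the logical expert table.
    for i in range(n_redundant):
        j = i % ep_size + 1
        groups[-j] = np.append(groups[-j], (groups[-j][-1] + 1) % n_expert)
    return torch.tensor(np.stack(groups), dtype=torch.int32)


placement = generate_global_placement(256, 16, 16)
print(placement.shape)   # torch.Size([16, 17]): 17 physical slots per rank
print(placement[15])     # rank 15 holds experts 240..255 plus a copy of expert 0
assert placement.max().item() < 256  # every slot points at a real expert
```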
**Before refactoring**
- map path is not None:
  - expert map: get_rank_placement_map from _'expert_load_balancer.py'_, which maintains the maps for all ranks and all layers.
  - log2phy: get_rank_log2phy_map from _'expert_load_balancer.py'_, which maintains the maps for all ranks and all layers.
- map path is None:
  - expert map: determine_expert_map from vLLM's _'layer.py'_; the function does not support vllm-ascend's redundant experts.
  - log2phy: determine_default_log2phy_map from _'eplb_utils.py'_; the function does not support vllm-ascend's redundant experts.

**Refactoring**
- _'eplb_utils.py'_
  - init_eplb_config
    - generate placement
    - generate expert map
    - generate log2phy
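
For orientation, this is roughly how the unified entry point is consumed; `ascend_config`, `moe_config`, and `layer_id` stand in for the real objects, and the attribute names follow the diff below (a sketch, not the exact fused_moe call site):

```python
# expert_map:  this rank's map of length num_experts (logical length, per point 2),
#              with -1 for experts the rank does not host.
# log2phy:     this rank's logical->physical map, or None when EPLB is disabled.
# n_redundant: the number of redundant experts actually in effect.
expert_map, log2phy, n_redundant = init_eplb_config(
    ascend_config,  # expert_map_path, dynamic_eplb, init_redundancy_expert, ...
    layer_id,       # index of the current MoE layer
    moe_config,     # num_experts, ep_size, ep_rank, supports_eplb
)
```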
### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?
Expert map generation test:
ep size: 16, num of experts: 256, num of redundant experts: 16
+++++++++++++++++++++++++++++++++++++++++
Expert mapping for rank 15 (entries other than -1 are the experts this rank serves):
vllm map:
[-1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1
-1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1
-1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 0 1 2 3 4 5 6 7 8
9 10 11 12 13 14 15 16]
+++++++++++++++++++++++++++++++++++++++++
Improved map:
[16 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1
-1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1
0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15]

Expert map generation test:
ep size: 16, num of experts: 256, num of redundant experts: 0
+++++++++++++++++++++++++++++++++++++++++
Expert mapping for rank 15 (entries other than -1 are the experts this rank serves):
vllm map:
[-1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1
-1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1
0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15]
+++++++++++++++++++++++++++++++++++++++++
Improved map:
[-1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1
-1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1
0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15]
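
The improved rank-15 map above can be reproduced with a few lines built on the `generate_global_placement` sketch earlier, constructing the per-rank expert map the same way `init_eplb_config` does in the diff below (print formatting here is illustrative):

```python
import torch

n_experts, ep_size, n_redundant, rank = 256, 16, 16, 15
placement = generate_global_placement(n_experts, ep_size, n_redundant)

# Map of logical length: expert_map[logical_id] = local slot on this rank, else -1.
expert_map = torch.full((n_experts, ), -1, dtype=torch.int32)
local_placement = placement[rank]
expert_map[local_placement] = torch.arange(local_placement.shape[0],
                                           dtype=torch.int32)

print((expert_map != -1).nonzero().flatten())  # tensor([  0, 240, 241, ..., 255])
print(expert_map[0].item())  # 16: the redundant slot holding a copy of expert 0
```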
dsr1 baseline:
| dataset | version | metric | mode | vllm-api-general-chat |
| ----- | ----- | ----- | ----- | ----- |
| gsm8k-lite | 7cd45e | accuracy | gen | 100.00 |

dsr1 eplb:
| dataset | version | metric | mode | vllm-api-general-chat |
| ----- | ----- | ----- | ----- | ----- |
| gsm8k-lite | 7cd45e | accuracy | gen | 100.00 |
- vLLM version: release/v0.13.0
- vLLM main: 5fbfa8d9ef
Signed-off-by: shenchuxiaofugui <1311027364@qq.com>
Co-authored-by: weijinqian0 <1184188277@qq.com>
@@ -15,87 +15,111 @@
 # This file is a part of the vllm-ascend project.
 #
 # Todo: Once https://github.com/vllm-project/vllm/issues/22246 is merged in vllm. Remove eplb utils.
 import json
 import os.path
 import random
 import sys
 from collections import defaultdict
 
 import numpy as np
 import torch
 from vllm.logger import logger
 
 import vllm_ascend.envs as envs_ascend
 
 
-def generate_log2phy_map(expert_map):
-    num_local_experts = expert_map.max() + 1
-    log2phy_map = expert_map.clone()
-    num_ranks, num_global_expert = log2phy_map.shape
-
-    row_indices = torch.arange(num_ranks).view(-1, 1).expand(num_ranks, \
-        num_global_expert) * num_local_experts
-    log2phy_map[log2phy_map != -1] += row_indices[log2phy_map != -1]
-
-    for idx in range(num_global_expert):
-        positive_rank_idx = torch.where(log2phy_map[:, idx] != -1)[0]
-        negative_rank_idx = torch.where(log2phy_map[:, idx] == -1)[0]
-        num_rank_holding_expert = positive_rank_idx.size(0)
-
-        if num_rank_holding_expert == 0:
-            log2phy_map[:, idx] = torch.full((num_ranks, ),
-                                             0,
-                                             dtype=log2phy_map.dtype)
-
-        if num_rank_holding_expert == 1:
-            log2phy_map[negative_rank_idx, idx] = torch.full(
-                (num_ranks - 1, ),
-                log2phy_map[positive_rank_idx, idx].item(),
-                dtype=log2phy_map.dtype)
-        else:
-            try:
-                random_list = [
-                    random.choice(log2phy_map[positive_rank_idx, idx])
-                    for _ in range(num_ranks - num_rank_holding_expert)
-                ]
-                log2phy_map[negative_rank_idx,
-                            idx] = torch.tensor(random_list,
-                                                dtype=log2phy_map.dtype)
-            except Exception as e:
-                logger.error(f"Fail to get log2phy_map: {str(e)}")
-
-    return log2phy_map
+def expert_file_to_tensor(expert_map_path, layer_id):
+    with open(expert_map_path, "r") as f:
+        data = json.load(f)
+        physical_count = 0
+        device_data = []
+        if layer_id > data["moe_layer_count"]:
+            raise ValueError("Invalid EPLB Table")
+        if layer_id == data["moe_layer_count"]:
+            logger.warning("Init expert map of mtp/eagle when using sample.")
+            return None, None
+        for device in data["layer_list"][layer_id]["device_list"]:
+            physical_count += len(device["device_expert"])
+            device_data.append(device["device_expert"])
+        global_placement = torch.tensor(device_data, dtype=torch.int32)
+        return global_placement, physical_count
+
+
+def generate_global_placement(n_expert, ep_size, n_redundant):
+    all_experts = np.arange(n_expert)
+    groups = np.array_split(all_experts, ep_size)
+    for i in range(n_redundant):
+        j = i % ep_size + 1
+        if len(groups[-j]) == 0:
+            groups[-j] = np.append(groups[-j], j)
+        else:
+            groups[-j] = np.append(groups[-j],
+                                   (groups[-j][-1] + 1) % n_expert)
+    return torch.tensor(groups, dtype=torch.int32)
+
+
+def init_eplb_config(ascend_config, layer_id, moe_config):
+    expert_map_path = ascend_config.expert_map_path
+    n_experts = moe_config.num_experts
+    ep_size = moe_config.ep_size
+    global_placement = None
+    eplb_enable = ascend_config.dynamic_eplb or ascend_config.expert_map_record_path
+    n_redundant = ascend_config.init_redundancy_expert if eplb_enable else 0
+    if expert_map_path:
+        if not (os.path.exists(expert_map_path)
+                and os.access(expert_map_path, os.R_OK)):
+            raise ValueError("Invalid EPLB path")
+        eplb_enable = True
+        global_placement, physical_count = expert_file_to_tensor(
+            expert_map_path, layer_id)
+        if physical_count is not None:
+            n_redundant = physical_count - n_experts
+            if not moe_config.supports_eplb:
+                raise ValueError(
+                    "Eplb supports only w8a8_dynamic quantization.")
+        else:
+            eplb_enable = False
+
+    if global_placement is None:
+        global_placement = generate_global_placement(n_experts, ep_size,
+                                                     n_redundant)
+
+    if ep_size == 1:
+        return None, None, n_redundant
+    global_expert_map = []
+    for rankid in range(ep_size):
+        expert_map = torch.full((n_experts, ), -1, dtype=torch.int32)
+        local_placement = global_placement[rankid]
+        expert_map[local_placement] = torch.arange(local_placement.shape[0],
+                                                   dtype=torch.int32)
+        global_expert_map.append(expert_map)
+    local_expert_map = global_expert_map[moe_config.ep_rank].npu()
+    log2phy = generate_log2phy_map(
+        global_expert_map, moe_config.ep_rank).npu() if eplb_enable else None
+
+    return local_expert_map, log2phy, n_redundant
+
+
+def generate_log2phy_map(global_expert_map, ep_rank):
+    log2phy_map = defaultdict(list)
+    valid_count = torch.sum(global_expert_map[0] != -1)
+    for rankid, map_per_rank in enumerate(global_expert_map):
+        for idx, val in enumerate(map_per_rank):
+            val = val.item()
+            # physical id = current local id + rank id * number of valid entries per rank
+            if val != -1:
+                log2phy_map[idx].append(val + rankid * valid_count)
+
+    for key in log2phy_map.keys():
+        num_of_duplications = len(log2phy_map[key])
+        log2phy_map[key] = log2phy_map[key][ep_rank % num_of_duplications]
+
+    log2phy_map = torch.scatter(
+        torch.zeros(len(log2phy_map.keys()), dtype=torch.int32), 0,
+        torch.tensor(list(log2phy_map.keys()), dtype=torch.int64),
+        torch.tensor(list(log2phy_map.values()), dtype=torch.int32))
+
+    return log2phy_map
-
-
-def determine_default_log2phy_map(global_expert_num, world_size, rank_id):
-    if world_size == 1:
-        local_ids = torch.arange(global_expert_num, dtype=torch.int32)
-        expert_map_all = local_ids.unsqueeze(0).expand(world_size, -1)
-        log2phy_map_all = generate_log2phy_map(expert_map_all)
-        return log2phy_map_all[rank_id]
-
-    local_num_experts = global_expert_num // world_size
-
-    expert_map_all = torch.full((world_size, global_expert_num),
-                                -1,
-                                dtype=torch.int32)
-
-    for r in range(world_size):
-        if r < world_size - 1:
-            start = r * local_num_experts
-            end = (r + 1) * local_num_experts
-            local_count = local_num_experts
-        else:
-            start = r * local_num_experts
-            end = global_expert_num
-            local_count = global_expert_num - r * local_num_experts
-
-        if isinstance(local_count, int):
-            local_ids = torch.arange(local_count, dtype=torch.int32)
-            expert_map_all[r, start:end] = local_ids
-
-    log2phy_map_all = generate_log2phy_map(expert_map_all)
-
-    return log2phy_map_all[rank_id]
 
 
 class EPLBParamUtils:
 
     @staticmethod
@@ -377,8 +377,8 @@ class EplbWorker:
 
             maps.append(new_expert_map[self.rank_id].numpy().tolist())
 
-            log2phy_map = generate_log2phy_map(new_expert_map)
-            log2phy_all.append(log2phy_map[self.rank_id].numpy().tolist())
+            log2phy_map = generate_log2phy_map(new_expert_map, self.rank_id)
+            log2phy_all.append(log2phy_map.numpy().tolist())
 
             layer_ids.append(layer_id)
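
To illustrate the behavioural change in the hunk above: the new `generate_log2phy_map` takes the rank and returns a single per-rank row, and a duplicated expert is resolved deterministically via `ep_rank % num_of_duplications` rather than `random.choice`. A toy run with invented values (3 logical experts on 2 ranks, expert 0 duplicated), assuming the helper from this diff is importable:

```python
import torch

# Rank 0 hosts logical experts {0, 1}; rank 1 hosts {2, 0} (expert 0 duplicated),
# so there are 4 physical slots in total, two per rank.
global_expert_map = [
    torch.tensor([0, 1, -1], dtype=torch.int32),  # rank 0: logical 0 -> slot 0, 1 -> slot 1
    torch.tensor([1, -1, 0], dtype=torch.int32),  # rank 1: logical 2 -> slot 0, 0 -> slot 1
]

print(generate_log2phy_map(global_expert_map, ep_rank=0))  # tensor([0, 1, 2], dtype=torch.int32)
print(generate_log2phy_map(global_expert_map, ep_rank=1))  # tensor([3, 1, 2], dtype=torch.int32)
# Logical expert 0 lives at physical slots 0 (rank 0) and 3 (rank 1);
# each rank picks its own replica instead of a random one.
```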