[EPLB] support deepseek eplb strategy (#1196)

### What this PR does / why we need it?

This PR implements the DeepSeek Expert Parallel Load Balancing (EPLB)
strategy to optimize expert distribution in vllm-ascend. The
implementation:
- Adapts the expert-map format to work with vllm-ascend's architecture
- Provides DeepSeek-provided mechanism to balance expert workload across
devices

### Does this PR introduce _any_ user-facing change?

This PR adds a new script that allows users to:
- Generate expert map configurations based on workload analysis
- Optimize expert distribution for their specific use case

### How was this patch tested?

To use this feature:
1. First collect expert heat information during model execution
2. Run the provided script to generate the expert map configuration
3. Apply the generated configuration to your vllm-ascend deployment

User example:

```bash
# expert_load_view.pt:  dumped expert heat info file
python3 examples/eplb/eplb_strategy.py --exp_name 'deepseek_demo' \
    --input_path expert_load_view.pt  --output_path examples/eplb/results/demo \
    --num_nodes 4
```

---------

Signed-off-by: ZhengWG <zwg0606@gmail.com>
This commit is contained in:
Zheng Wengang
2025-07-07 17:22:08 +08:00
committed by GitHub
parent 4e29c5a808
commit 9c886d0a1f
2 changed files with 388 additions and 0 deletions

View File

@@ -0,0 +1,205 @@
# SPDX-License-Identifier: Apache-2.0
"""
Expert parallelism load balancer (EPLB) for vLLM.
The rearrangement algorithm is adapted from
[DeepSeek EPLB](https://github.com/deepseek-ai/eplb).
"""
from typing import Tuple
import torch
def balanced_packing(weight: torch.Tensor,
num_packs: int) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Pack n weighted objects to m packs, such that each bin contains exactly n/m objects and the weights of all packs
are as balanced as possible.
Parameters:
weight: [X, n], the weight of each item
num_packs: number of packs
Returns:
pack_index: [X, n], the pack index of each item
rank_in_pack: [X, n], the rank of the item in the pack
"""
num_layers, num_groups = weight.shape
assert num_groups % num_packs == 0
groups_per_pack = num_groups // num_packs
if groups_per_pack == 1:
pack_index = torch.arange(weight.size(-1),
dtype=torch.int64,
device=weight.device).expand(weight.shape)
rank_in_pack = torch.zeros_like(weight, dtype=torch.int64)
return pack_index, rank_in_pack
indices = weight.float().sort(-1, descending=True).indices.cpu()
pack_index = torch.full_like(weight,
fill_value=-1,
dtype=torch.int64,
device='cpu')
rank_in_pack = torch.full_like(pack_index, fill_value=-1)
for i in range(num_layers):
pack_weights = [0] * num_packs
pack_items = [0] * num_packs
for group in indices[i]:
pack = min(
(i
for i in range(num_packs) if pack_items[i] < groups_per_pack),
key=pack_weights.__getitem__)
assert pack_items[pack] < groups_per_pack
pack_index[i, group] = pack
rank_in_pack[i, group] = pack_items[pack]
pack_weights[pack] += weight[i, group]
pack_items[pack] += 1
return pack_index, rank_in_pack
def replicate_experts(
weight: torch.Tensor,
num_phy: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Replicate `num_log` experts to `num_phy` replicas, such that the maximum load of all replicas is minimized.
Parameters:
weight: [X, num_log]
num_phy: total number of experts after replication
Returns:
phy2log: [X, num_phy], logical expert id of each physical expert
rank: [X, num_phy], the replica rank
logcnt: [X, num_log], number of replicas for each logical expert
"""
n, num_log = weight.shape
num_redundant = num_phy - num_log
assert num_redundant >= 0
device = weight.device
phy2log = torch.arange(num_phy, dtype=torch.int64,
device=device).repeat(n, 1)
rank = torch.zeros(n, num_phy, dtype=torch.int64, device=device)
logcnt = torch.ones(n, num_log, dtype=torch.int64, device=device)
arangen = torch.arange(n, dtype=torch.int64, device=device)
for i in range(num_log, num_phy):
redundant_indices = (weight / logcnt).max(dim=-1).indices
phy2log[:, i] = redundant_indices
rank[:, i] = logcnt[arangen, redundant_indices]
logcnt[arangen, redundant_indices] += 1
return phy2log, rank, logcnt
def rebalance_experts_hierarchical(weight: torch.Tensor,
num_physical_experts: int, num_groups: int,
num_nodes: int, num_gpus: int):
"""
Parameters:
weight: [num_moe_layers, num_logical_experts]
num_physical_experts: number of physical experts after replication
num_groups: number of expert groups
num_nodes: number of server nodes, where the intra-node network (e.g, NVLink) is faster
num_gpus: number of GPUs, must be a multiple of `num_nodes`
Returns:
physical_to_logical_map: [num_moe_layers, num_physical_experts]
logical_to_physical_map: [num_moe_layers, num_logical_experts, X]
logical_count: [num_moe_layers, num_logical_experts]
"""
num_layers, num_logical_experts = weight.shape
assert num_logical_experts % num_groups == 0
group_size = num_logical_experts // num_groups
assert num_groups % num_nodes == 0
groups_per_node = num_groups // num_nodes
assert num_gpus % num_nodes == 0
assert num_physical_experts % num_gpus == 0
phy_experts_per_gpu = num_physical_experts // num_gpus
def inverse(perm: torch.Tensor) -> torch.Tensor:
inv = torch.empty_like(perm)
inv.scatter_(
1, perm,
torch.arange(perm.size(1), dtype=torch.int64,
device=perm.device).expand(perm.shape))
return inv
# Step 1: pack groups to nodes
tokens_per_group = weight.unflatten(-1, (num_groups, group_size)).sum(-1)
group_pack_index, group_rank_in_pack = balanced_packing(
tokens_per_group, num_nodes)
log2mlog = (((group_pack_index * groups_per_node + group_rank_in_pack) *
group_size).unsqueeze(-1) +
torch.arange(group_size,
dtype=torch.int64,
device=group_pack_index.device)).flatten(-2)
mlog2log = inverse(log2mlog)
# Step 2: construct redundant experts within nodes
# [num_layers * num_nodes, num_logical_experts // num_nodes]
tokens_per_mlog = weight.gather(-1, mlog2log).view(
-1, num_logical_experts // num_nodes)
phy2mlog, phyrank, mlogcnt = replicate_experts(
tokens_per_mlog, num_physical_experts // num_nodes)
# Step 3: pack physical_experts to GPUs
# [num_layers * num_nodes, num_physical_experts // num_nodes]
tokens_per_phy = (tokens_per_mlog / mlogcnt).gather(-1, phy2mlog)
pack_index, rank_in_pack = balanced_packing(tokens_per_phy,
num_gpus // num_nodes)
phy2pphy = pack_index * phy_experts_per_gpu + rank_in_pack
pphy2phy = inverse(phy2pphy)
pphy2mlog = phy2mlog.gather(
-1, pphy2phy) # [num_layers * num_nodes, num_log_per_nodes]
pphy2mlog = (pphy2mlog.view(num_layers, num_nodes, -1) + torch.arange(
0,
num_logical_experts,
num_logical_experts // num_nodes,
device=group_pack_index.device).view(1, -1, 1)).flatten(-2)
pphy2log = mlog2log.gather(-1, pphy2mlog)
pphyrank = phyrank.gather(-1, pphy2phy).view(num_layers, -1)
logcnt = mlogcnt.view(num_layers, -1).gather(-1, log2mlog)
return pphy2log, pphyrank, logcnt
def rebalance_experts(
weight: torch.Tensor, num_replicas: int, num_groups: int,
num_nodes: int,
num_gpus: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Entry point for expert-parallelism load balancer.
Parameters:
weight: [layers, num_logical_experts], the load statistics for all logical experts
num_replicas: number of physical experts, must be a multiple of `num_gpus`
num_groups: number of expert groups
num_nodes: number of server nodes, where the intra-node network (e.g, NVLink) is faster
num_gpus: number of GPUs, must be a multiple of `num_nodes`
Returns:
physical_to_logical_map: [layers, num_replicas], the expert index of each replica
logical_to_physical_map: [layers, num_logical_experts, X], the replica indices for each expert
expert_count: [layers, num_logical_experts], number of physical replicas for each logical expert
"""
num_layers, num_logical_experts = weight.shape
weight = weight.float().cpu()
if num_groups % num_nodes == 0:
# use hierarchical load-balance policy
phy2log, phyrank, logcnt = rebalance_experts_hierarchical(
weight, num_replicas, num_groups, num_nodes, num_gpus)
else:
# use global load-balance policy
phy2log, phyrank, logcnt = rebalance_experts_hierarchical(
weight, num_replicas, 1, 1, num_gpus)
maxlogcnt = logcnt.max().item()
log2phy: torch.Tensor = torch.full(
(num_layers, num_logical_experts, maxlogcnt),
-1,
dtype=torch.int64,
device=logcnt.device)
log2phy.view(num_layers, -1).scatter_(
-1, phy2log * maxlogcnt + phyrank,
torch.arange(num_replicas, dtype=torch.int64,
device=log2phy.device).expand(num_layers, -1))
return phy2log, log2phy, logcnt
__all__ = ['rebalance_experts']

View File

@@ -0,0 +1,183 @@
# coding=utf-8
# Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved.
import json
import logging
import os
import matplotlib.pyplot as plt # type: ignore
import numpy as np
import torch
logger = logging.getLogger("msit_logger")
def save_matrix_to_json(output_path, file_name, deployment):
num_layers = deployment.shape[0]
num_cards = deployment.shape[1]
data = {"moe_layer_count": num_layers}
layer_list = []
for i in range(num_layers):
layer = {"layer_id": i, "device_count": num_cards}
device_list = []
for j in range(num_cards):
device = {
"device_id": j,
"device_expert": deployment[i, j].tolist()
}
device_list.append(device)
layer["device_list"] = device_list
layer_list.append(layer)
data["layer_list"] = layer_list
file_name = f"{output_path}{file_name}.json"
# Save as JSON file
try:
with open(file_name, 'w') as f:
json.dump(data, f, indent=4)
except Exception as e:
print(f"write {file_name} failed: {e}")
def calculate_average(lst):
"""calculate the average of a list"""
if not lst:
raise ValueError("list is empty")
total = 0.0
count = 0
for element in lst:
# Check if element is numeric
if isinstance(element, (int, float, np.int64, np.float64)):
total += float(element)
count += 1
else:
# Non-numeric elements will be ignored with a warning
print(f"warning: element {element} is not a number, ignored")
if count == 0:
raise ValueError("list does not contain any number")
return total / count
def layer_imblance_polt(y_list, label_names, device_num, output_path,
file_name):
plt.rcParams['font.sans-serif'] = ['Arial']
plt.rcParams['axes.unicode_minus'] = False
x = [i for i in range(58)]
for index, y in enumerate(y_list):
plt.plot(x,
y,
label=rf'{label_names[index]}avg={calculate_average(y)}')
plt.legend()
plt.title(rf'Load Distribution (num_gpus={device_num})')
plt.xlabel('layer')
plt.ylabel('Device Load')
# Show grid lines
plt.grid(True)
plt.savefig(os.path.join(output_path, file_name), dpi=300)
# Clear current plot
plt.close()
def deepseek_deploy(workload, num_redundancy_expert, num_groups, num_nodes,
num_gpus, num_original_expert):
from eplb_deepseek import rebalance_experts
num_replicas = num_original_expert + num_redundancy_expert
hy2log, log2phy, logcnt = rebalance_experts(workload, num_replicas,
num_groups, num_nodes,
num_gpus)
# Convert to global_deployment
workload = workload.cpu().numpy()
global_deployment = []
layer_num = log2phy.shape[0]
num_physical_experts_local = (num_original_expert +
num_redundancy_expert) // num_gpus
for layer_idx in range(layer_num):
layer_deployment = []
for gpu_idx in range(num_gpus):
local_deployment = hy2log[layer_idx][gpu_idx *
num_physical_experts_local:
(gpu_idx + 1) *
num_physical_experts_local]
local_deployment = local_deployment.flatten()
layer_deployment.append(local_deployment.tolist())
global_deployment.append(layer_deployment)
# Remap expert distribution according to log2phy
original_weights = []
max_weights = []
average_weights = []
y_list = []
for layer_idx in range(layer_num):
new_value = workload[layer_idx].reshape(num_gpus, -1)
row_sum = np.sum(new_value, axis=1)
original_weights.append(row_sum.max())
average_weights.append((np.sum(workload[layer_idx]) / num_gpus))
opt_workload = np.zeros((num_original_expert + num_redundancy_expert),
dtype=np.float64)
for expert_idx in range(num_original_expert):
physical_expert_idxs = log2phy[layer_idx][expert_idx]
physical_expert_idxs = physical_expert_idxs.flatten()
physical_expert_idxs = physical_expert_idxs[
physical_expert_idxs != -1]
for physical_expert_idx in physical_expert_idxs:
opt_workload[physical_expert_idx] += workload[layer_idx][
expert_idx] / len(physical_expert_idxs)
opt_workload = opt_workload.reshape(num_gpus, -1)
row_sum = np.sum(opt_workload, axis=1)
max_weights.append(row_sum.max())
y_list = [original_weights, max_weights, average_weights]
return global_deployment, y_list
if __name__ == '__main__':
import argparse
parser = argparse.ArgumentParser()
parser.add_argument("--exp_name", type=str, default="gsm8k_temp0.0")
parser.add_argument("--num_original_expert", type=int, default=256)
parser.add_argument("--input_path", type=str, default="")
parser.add_argument("--output_path", type=str, default="")
parser.add_argument("--num_redundancy_expert", type=int, default=0)
parser.add_argument("--num_devices", type=int, default=32)
parser.add_argument("--num_groups", type=int, default=8)
parser.add_argument("--num_nodes", type=int, default=4)
args = parser.parse_args()
exp_name = args.exp_name
input_path = args.input_path
output_path = args.output_path
os.makedirs(output_path, exist_ok=True)
num_redundancy_expert = args.num_redundancy_expert
num_devices = args.num_devices
num_original_expert = args.num_original_expert
num_groups = args.num_groups
num_nodes = args.num_nodes
# NOTE: assume input workload format: [layer_num, num_experts]
workload = torch.load(input_path, map_location=torch.device('cpu'))
global_deployment, y_list = deepseek_deploy(workload,
num_redundancy_expert,
num_groups, num_nodes,
num_devices,
num_original_expert)
file_name = f"{exp_name}_{num_devices}_{num_redundancy_expert}"
save_matrix_to_json(output_path, file_name, np.array(global_deployment))
label_names = [
'default deployment max load', 'balanced load max load',
'balanced load avg load'
]
new_file_name = f"{exp_name}_{num_devices}_{num_redundancy_expert}.png"
layer_imblance_polt(y_list, label_names, num_devices, output_path,
new_file_name)