FlashLB algorithm (#3042)

## Purpose
This Pull Request enhances the EPLB (Expert Parallelism Load Balancing)
system by introducing a novel balancing algorithm: FlashLB.

## Motivation
1. The default algorithm adopts a two-stage greedy strategy:
   a. Replica allotment: determine the number of replicas per expert by
      minimizing the maximum load per replica (Min-Max Replica, MMR).
   b. Replica placement: distribute replicas across devices by repeatedly
      assigning the heaviest replica to the least-loaded device (Longest
      Processing Time First, LPT).

However, this sequential process lacks inter-stage collaborative
optimization and often yields suboptimal load balancing. For example, in
the simple case shown in the figure below: given 8 logical experts with
hotness values of 600, 560, 120, 120, 20, 10, 10, 10, and 2 replica
slots per device across 8 devices, the default EPLB algorithm yields a
maximum per-device hotness of 232, while the proposed FlashLB algorithm
reduces this value to 205.
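The two-stage baseline is easy to reproduce on this example. Below is a minimal sketch (a hypothetical reimplementation for illustration, not the actual EPLB code) of MMR allotment followed by LPT placement; running it on the hotness values above yields the quoted 232:

```python
def greedy_eplb(hotness, num_devices, slots_per_device):
    """Two-stage greedy baseline: MMR replica allotment, then LPT placement."""
    n = len(hotness)
    total_slots = num_devices * slots_per_device
    pieces = [1] * n
    # Stage 1 (MMR): give each extra replica to the expert whose
    # per-replica load is currently the largest.
    for _ in range(total_slots - n):
        j = max(range(n), key=lambda e: hotness[e] / pieces[e])
        pieces[j] += 1
    # Stage 2 (LPT): place replicas heaviest-first onto the least-loaded
    # device that still has a free slot.
    replicas = sorted(
        (hotness[e] / pieces[e] for e in range(n) for _ in range(pieces[e])),
        reverse=True)
    loads = [0.0] * num_devices
    free = [slots_per_device] * num_devices
    for r in replicas:
        j = min((d for d in range(num_devices) if free[d] > 0),
                key=lambda d: loads[d])
        loads[j] += r
        free[j] -= 1
    return max(loads)

print(greedy_eplb([600, 560, 120, 120, 20, 10, 10, 10], 8, 2))  # 232.0
```

Here MMR splits the two hot experts into 5 replicas each (per-replica loads 120 and 112), and LPT then packs four devices to 232, which no later placement step can undo.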

2. The default algorithm relies on the averaged expert hotness over a
fixed time window for optimization. While this provides a coarse
approximation of the hotness distribution, it fails to capture
oscillatory deviations and temporal correlations of expert hotness
observed across iterations in real-world scenarios, limiting
optimization quality.

3. The default algorithm periodically regenerates the expert placement
table. However, it generates the table for each individual layer, and
the new table does not account for correlations with the previous one;
these two factors collectively lead to nearly full-scale expert
reassignment.

## FlashLB Algorithm Principle
1. Joint Optimization
FlashLB achieves joint optimization of replica allotment and placement
through group-based decision-making. Each group gradually determines the
replica count and placement for a subset of experts, ensuring that the
expected inter-device load balance (considering both deployed and
pending expert replicas) is holistically optimized. To attain superior
load balancing, FlashLB employs tree search to expand the solution space
while integrating pruning and precompilation techniques for
acceleration, thereby delivering load balancing that is both
high-quality and practically efficient.

2. Multi-Shot Enhancement
FlashLB partitions each profiling interval (e.g., 1024 iterations) into
consecutive smaller sub-intervals (e.g., 16 iterations), each capturing
independent hotness measurements. It then performs multi-shot
optimization to co-optimize these sub-intervals simultaneously—enabling
adaptation to time-variant expert hotness while enhancing robustness.
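As an illustration of the partitioning (hypothetical shapes and data, not the PR's code), splitting one profiling window into sub-interval hotness snapshots is a reshape-and-sum over per-iteration counts; each row of the multi-shot array then acts as an independent "stage" in the optimization:

```python
import numpy as np

# Hypothetical per-iteration token counts: 1024 iterations x 8 experts.
rng = np.random.default_rng(0)
per_iter = rng.integers(0, 100, size=(1024, 8))

# Single-shot: one averaged hotness vector over the whole window.
single_shot = per_iter.mean(axis=0)                   # shape (8,)

# Multi-shot: 64 sub-intervals of 16 iterations, each kept separate.
multi_shot = per_iter.reshape(64, 16, 8).sum(axis=1)  # shape (64, 8)

print(single_shot.shape, multi_shot.shape)
```

The single-shot vector discards how hotness oscillates across sub-intervals; the multi-shot matrix preserves it, at the cost of a larger optimization input.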

3. Incremental Adjustment
To reduce the overhead of frequent expert re-deployment, FlashLB
introduces an incremental adjustment scheme operating at both
inter-layer and intra-layer levels:
a. Inter-Layer: Hotness variations are tracked at the layer level. Only
layers with fluctuations exceeding a predefined threshold trigger
re-computation of expert placement, avoiding unnecessary redeployment
for stable layers;
b. Intra-Layer (Optional): A lightweight incremental LPT algorithm
(LPT-Incremental) is applied. Instead of recomputing full placement for
all experts in a layer, it selectively adjusts only the hottest experts
or those with replica count changes, further reducing migration
overhead.

This incremental strategy significantly reduces adjustment costs while
maintaining balanced performance across layers and devices.
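The inter-layer trigger can be sketched as follows (a paraphrase of the patch's `need_update` logic; the `threshold_ratio` value here is illustrative, not a default from the PR). A layer is recomputed only when its current imbalance ratio ("par", max load over mean load) exceeds a fraction of its last recorded value:

```python
def need_update(current_par, par_history, layer_id, threshold_ratio=1.05):
    # Recompute placement only when the layer's current imbalance (par)
    # exceeds threshold_ratio times its last recorded value; unseen
    # layers default to threshold 0.0 and therefore always update.
    threshold = par_history.get(layer_id, 0.0)
    return current_par >= threshold_ratio * threshold

history = {0: 1.10}
print(need_update(1.20, history, 0))  # True: 1.20 >= 1.05 * 1.10
print(need_update(1.12, history, 0))  # False: layer 0 is considered stable
```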

## Co-author:

Co-authored-by: Skywalker-EP 173723846@qq.com

- vLLM version: v0.10.2
- vLLM main: 9607d5eb44

---------

Signed-off-by: sdmyzlp <lrwei2@petalmail.com>
Signed-off-by: Che Ruan <cr623@ic.ac.uk>
Signed-off-by: Shanshan Shen <87969357+shen-shanshan@users.noreply.github.com>
Signed-off-by: shen-shanshan <467638484@qq.com>
Signed-off-by: Yikun Jiang <yikunkero@gmail.com>
Signed-off-by: 22dimensions <waitingwind@foxmail.com>
Signed-off-by: zhanghaiwen <zhanghaiwen@cmss.chinamobile.com>
Signed-off-by: hfadzxy <starmoon_zhang@163.com>
Signed-off-by: Lucas Kabela <lucaskabela@meta.com>
Signed-off-by: wangli <wangli858794774@gmail.com>
Signed-off-by: MengqingCao <cmq0113@163.com>
Signed-off-by: wangxiyuan <wangxiyuan1007@gmail.com>
Signed-off-by: Icey <1790571317@qq.com>
Signed-off-by: linfeng-yuan <1102311262@qq.com>
Signed-off-by: dependabot[bot] <support@github.com>
Signed-off-by: tangtianyi <tangtianyi4@huawei.com>
Signed-off-by: Angazenn <supperccell@163.com>
Signed-off-by: Yizhou Liu <liu_yizhou@outlook.com>
Signed-off-by: rjg-lyh <1318825571@qq.com>
Signed-off-by: Pr0Wh1teGivee <calvin_zhu0210@outlook.com>
Signed-off-by: fems14 <1804143737@qq.com>
Co-authored-by: sdmyzlp <117554856+sdmyzlp@users.noreply.github.com>
Co-authored-by: Che Ruan <cr623@ic.ac.uk>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Co-authored-by: Shanshan Shen <467638484@qq.com>
Co-authored-by: Yikun Jiang <yikunkero@gmail.com>
Co-authored-by: 22dimensions <waitingwind@foxmail.com>
Co-authored-by: zhanghw0354 <zhanghaiwencmss@139.com>
Co-authored-by: zhanghaiwen <zhanghaiwen@cmss.chinamobile.com>
Co-authored-by: zhangxinyuehfad <59153331+zhangxinyuehfad@users.noreply.github.com>
Co-authored-by: Lucas Kabela <lucasakabela@gmail.com>
Co-authored-by: Li Wang <wangli858794774@gmail.com>
Co-authored-by: MengqingCao <cmq0113@163.com>
Co-authored-by: wangxiyuan <wangxiyuan1007@gmail.com>
Co-authored-by: Icey <1790571317@qq.com>
Co-authored-by: linfeng-yuan <1102311262@qq.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
Co-authored-by: tianyitang <tangtianyi4@huawei.com>
Co-authored-by: Angazenn <supperccell@163.com>
Co-authored-by: Yizhou <136800916+yiz-liu@users.noreply.github.com>
Co-authored-by: rjg-lyh <83491835+rjg-lyh@users.noreply.github.com>
Co-authored-by: weichen <132029610+Pr0Wh1teGivee@users.noreply.github.com>
Co-authored-by: weijinqian0 <12153182+weijinqian0@users.noreply.github.com>
Co-authored-by: fems14 <74094523+fems14@users.noreply.github.com>
Committed by Mercykid-bash on 2025-09-23 10:27:14 +08:00 (via GitHub).
Commit 29c173ab48, parent 8dd53c8860. 2 changed files with 659 additions
and 1 deletion.
@@ -3,6 +3,7 @@
 from .policy_abstract import DynamicConfig, EplbPolicy
 from .policy_dynamic_ep import DynamicEplb
 from .policy_dynamic_ep_v2 import DynamicEplbV2
+from .policy_flashlb import FlashLB
 from .policy_random import RandomLoadBalance
@@ -22,5 +23,11 @@ class PolicyFactory:
             DynamicEplb,  # Dynamic EPLB policy: overall expert replacement based on current moe load
             2:
             DynamicEplbV2,  # Dynamic EPLB policy V2: expert replacement with constrained number of expert shuffle
+            3:
+            FlashLB,  # FlashLB EPLB policy: expert replacement based on Joint Optimization, Multi-Shot Enhancement and Incremental Adjustment
         }
-        return policy.get(policy_type, RandomLoadBalance)(config)
+        policy_class = policy.get(policy_type, RandomLoadBalance)
+        policy_instance = policy_class(config)
+        if policy_type == 3:
+            policy_instance.warm_up()
+        return policy_instance


@@ -0,0 +1,651 @@
# Copyright Huawei Technologies Co., Ltd. 2024-2025. All rights reserved.
# Todo: Once https://github.com/vllm-project/vllm/pull/24069 is merged in vllm. Remove this policy.
import logging
from collections import deque
from typing import Dict

import numpy as np
import torch
from numba import njit  # type: ignore

from .policy_abstract import DynamicConfig, EplbPolicy

numba_logger = logging.getLogger("numba")
numba_logger.setLevel(logging.WARNING)


@njit
def compute_piece_counts(X, P, stage_weights):
    n_stage, N = X.shape
    S = P - N
    pieces = np.ones(N, dtype=np.int32)
    unit = X / pieces  # unit[i, j] = X[i, j] / pieces[j]
    for _ in range(S):
        deltas = np.zeros(N, dtype=np.float32)
        for i in range(n_stage):
            # Find top1 and top2
            idx1 = -1
            idx2 = -1
            val1 = -1.0
            val2 = -1.0
            for j in range(N):
                v = unit[i, j]
                if v > val1:
                    val2 = val1
                    idx2 = idx1
                    val1 = v
                    idx1 = j
                elif v > val2:
                    val2 = v
                    idx2 = j
            origin = unit[i, idx1]
            secv = unit[i, idx2]
            alt = X[i, idx1] / (pieces[idx1] + 1)
            delta = origin - (alt if alt > secv else secv)
            deltas[idx1] += delta * stage_weights[i] if np.any(
                delta) != 0 else stage_weights[i]
        max_idx = np.argmax(deltas)
        pieces[max_idx] += 1
        for i in range(n_stage):
            unit[i, max_idx] = X[i, max_idx] / pieces[max_idx]
    # Compute max load
    max_load = 0.0
    for j in range(N):
        total = 0.0
        for i in range(n_stage):
            total += unit[i, j]
        if total > max_load:
            max_load = total
    return pieces
@njit
def jsq_placement(X, pieces, M, stage_weights):
    n_stage, N = X.shape
    total_piece = pieces.sum()
    num_per_group = total_piece // M
    # 1. Compute unit_hotness
    unit_hotness = np.empty((n_stage, N), dtype=np.float32)
    for i in range(N):
        if pieces[i] > 0:
            for s in range(n_stage):
                unit_hotness[s, i] = X[s, i] / pieces[i]
        else:
            for s in range(n_stage):
                unit_hotness[s, i] = 0.0
    # 2. Sort by total hotness
    scores = np.zeros(N, dtype=np.float32)
    for i in range(N):
        for s in range(n_stage):
            scores[i] += unit_hotness[s, i]
    idx = np.argsort(-scores)
    # 3. Initialization
    loads = np.zeros((n_stage, M), dtype=np.float32)
    dev_phy_exp_n = np.zeros(M, dtype=np.int32)
    deployment = -np.ones((M, num_per_group), dtype=np.int32)
    dep_ptr = np.zeros(M, dtype=np.int32)
    # 4. Main loop
    for t in range(N):
        i = idx[t]
        used_device = list()
        for _ in range(pieces[i]):
            # 4.1 Construct w vector
            w = np.empty(n_stage, dtype=np.float32)
            for s in range(n_stage):
                w[s] = unit_hotness[s, i]
            # 4.2 Compute stage-level maximum load
            stage_max = np.empty(n_stage, dtype=np.float32)
            for s in range(n_stage):
                max_val = loads[s, 0]
                for k in range(1, M):
                    if loads[s, k] > max_val:
                        max_val = loads[s, k]
                stage_max[s] = max_val
            # 4.3 Compute denominator
            denom = np.empty(n_stage, dtype=np.float32)
            for s in range(n_stage):
                sum_tmp = 0.0
                for j in range(M):
                    sum_tmp += loads[s, j] + w[s]
                denom[s] = sum_tmp / M + 1e-2
            # 4.4 Find best device j
            best_j = -1
            best_val = 1e30
            for j in range(M):
                if dev_phy_exp_n[j] >= num_per_group:
                    continue
                if j in used_device:
                    continue
                score = 0.0
                for s in range(n_stage):
                    tmp_sj = loads[s, j] + w[s]
                    numer_sj = tmp_sj if tmp_sj > stage_max[s] else stage_max[s]
                    score += stage_weights[s] * (numer_sj / denom[s])
                if score < best_val:
                    best_val = score
                    best_j = j
            if best_j == -1:
                continue
            used_device.append(best_j)
            # 4.5 Update status
            for s in range(n_stage):
                loads[s, best_j] += w[s]
            ptr = dep_ptr[best_j]
            deployment[best_j, ptr] = i
            dep_ptr[best_j] += 1
            dev_phy_exp_n[best_j] += 1
    # Handle remaining -1 values: fill with random elements from range(N) not in current row
    for rank in range(M):
        for col in range(num_per_group):
            if deployment[rank, col] == -1:
                # Get elements already on the current rank
                current_rank_elements = set(deployment[rank, :])
                # Filter elements from range(N) not yet on this rank
                available = [
                    x for x in range(N) if x not in current_rank_elements
                ]
                # Randomly select an available element to fill
                if len(available) > 0:
                    rand_idx = np.random.randint(0, len(available))
                    deployment[rank, col] = available[rand_idx]
                elif N > 0:
                    # All unique experts are already on this rank, so we can pick any expert randomly.
                    deployment[rank, col] = np.random.randint(0, N)
    return deployment
@njit
def slice_values(X, pieces):
    total_len = 0
    for i in range(X.shape[0]):
        total_len += pieces[i]
    result = np.empty(total_len, dtype=np.float32)
    idx = 0
    for i in range(X.shape[0]):
        val = X[i] / pieces[i]
        for _ in range(pieces[i]):
            result[idx] = val
            idx += 1
    return result
@njit
def group_based_adaptive_bloating_kernel(X, P, M, simulated_pieces,
                                         simulated_deployment, stage_weights):
    n_stage, N = X.shape
    num_group = P // M
    X_all = np.zeros(N, dtype=np.float32)
    for i in range(n_stage):
        for j in range(N):
            X_all[j] += X[i, j]
    sort_idx = np.argsort(np.negative(X_all))
    X_sorted = X[:, sort_idx]
    unit_load = np.empty(N, dtype=np.float32)
    for j in range(N):
        unit_load[j] = X_all[j] / simulated_pieces[j]
    flat_deployment = simulated_deployment.reshape(-1)
    simulated_load = np.zeros(M, dtype=np.float32)
    for i in range(flat_deployment.shape[0]):
        simulated_load[i // (flat_deployment.shape[0] // M)] += unit_load[
            flat_deployment[i]]
    slice_vals = slice_values(X_all, simulated_pieces)
    sorted_slices = np.sort(slice_vals)[::-1]
    simulated_slopes = (sorted_slices[:-M + 1] - sorted_slices[M - 1:]) / M
    cumulative_slices_used = np.zeros(N, dtype=np.int32)
    acc = 0
    for i in range(N):
        acc += simulated_pieces[sort_idx[i]]
        cumulative_slices_used[i] = acc
    group_boundary_indices = np.zeros(num_group, dtype=np.int32)
    for i in range(1, num_group + 1):
        for j in range(N):
            if cumulative_slices_used[j] >= i * M:
                group_boundary_indices[i - 1] = j
                break
    slices_used_per_group = np.zeros(num_group, dtype=np.int32)
    slices_used_per_group[0] = group_boundary_indices[0]
    for i in range(1, num_group):
        slices_used_per_group[i] = (group_boundary_indices[i] -
                                    group_boundary_indices[i - 1])
    slices_used_per_group = M - slices_used_per_group
    loads = np.zeros(M, dtype=np.float32)
    pieces = np.zeros(N, dtype=np.int32)
    num_remain_slice = P - N
    current_idx = 0
    for g in range(num_group):
        window = X_sorted[:, current_idx:current_idx + 2 * M]
        low = max(0, current_idx + M - N)
        high = min(num_remain_slice, M - 1)
        while (high - low) > 1:
            mid = int((high + low) // 2)
            keep = M - mid
            current_group = window[:, :keep]
            current_pieces = compute_piece_counts(current_group, M,
                                                  stage_weights)
            current_pieces = np.maximum(current_pieces, 1)
            current_slice = slice_values(current_group.sum(0), current_pieces)
            current_slice_sorted = np.sort(current_slice)
            current_loads = loads + current_slice_sorted
            current_max: np.float32 = np.max(current_loads)
            current_min: np.float32 = np.min(current_loads)
            current_slope = (current_max - current_min) / M
            next_slope: np.float32 = np.max(simulated_slopes[current_idx +
                                                             keep:])
            if abs(current_slope) > abs(next_slope):
                low = mid
            else:
                high = mid
        S = high
        keep = M - S
        current_group = window[:, :keep]
        current_pieces = compute_piece_counts(current_group, M, stage_weights)
        for i in range(keep):
            pieces[sort_idx[current_idx + i]] = current_pieces[i]
        current_slice = slice_values(current_group.sum(0), current_pieces)
        current_slice_sorted = np.sort(current_slice)
        loads += current_slice_sorted
        loads = np.sort(loads)[::-1]
        current_idx += keep
        num_remain_slice -= S
    return pieces
@njit
def compute_objective(deployment, X, pieces):
    M, P = deployment.shape
    loads = np.zeros(M)
    for i in range(M):
        for j in range(P):
            expert = deployment[i, j]
            if pieces[expert] == 0:
                continue
            loads[i] += X[expert] / pieces[expert]
    mean_load = np.mean(loads)
    max_load: np.float32 = np.max(loads)
    obj = max_load / mean_load
    return obj, loads
@njit
def auto_fix_new_placement(old_placement, new_placement):
    """
    Adjust the new_placement matrix to ensure elements (including duplicates)
    that exist in both old_placement and new_placement remain in their original
    positions from old_placement. New elements (unique to new_placement) fill
    the remaining empty positions.

    Args:
        old_placement: Old deployment matrix with shape (num_ranks, num_experts)
        new_placement: New deployment matrix to be fixed, must have the same
            shape as old_placement

    Returns:
        fixed_new: adjusted version of the new_placement matrix
    """
    num_ranks, num_experts = old_placement.shape
    fixed_new = np.empty_like(new_placement)
    max_expert_old = old_placement.max() if num_experts > 0 else 0
    max_expert_new = new_placement.max() if num_experts > 0 else 0
    max_expert = max(max_expert_old, max_expert_new)
    for rank_id in range(num_ranks):
        old_row = old_placement[rank_id]
        new_row = new_placement[rank_id]
        index_array = np.full((max_expert + 1, num_experts), -1,
                              dtype=np.int32)
        count_array = np.zeros(max_expert + 1, dtype=np.int32)
        for idx in range(num_experts):
            val = old_row[idx]
            if val >= 0 and val <= max_expert:
                pos = count_array[val]
                index_array[val, pos] = idx
                count_array[val] += 1
        old_counter = np.zeros(max_expert + 1, dtype=np.int32)
        for idx in range(num_experts):
            val = old_row[idx]
            if val >= 0 and val <= max_expert:
                old_counter[val] += 1
        retain_elements = np.empty(num_experts, dtype=new_placement.dtype)
        new_elements = np.empty(num_experts, dtype=new_placement.dtype)
        retain_ptr = 0
        new_ptr = 0
        for val in new_row:
            if val >= 0 and val <= max_expert and old_counter[val] > 0:
                retain_elements[retain_ptr] = val
                retain_ptr += 1
                old_counter[val] -= 1
            else:
                new_elements[new_ptr] = val
                new_ptr += 1
        current_fixed = np.full(num_experts, -1, dtype=new_placement.dtype)
        for i in range(retain_ptr):
            val = retain_elements[i]
            if val >= 0 and val <= max_expert:
                pos = count_array[val] - 1
                if pos >= 0:
                    idx = index_array[val, pos]
                    current_fixed[idx] = val
                    count_array[val] -= 1
        empty_indices = np.empty(num_experts, dtype=np.int32)
        empty_ptr = 0
        for idx in range(num_experts):
            if current_fixed[idx] == -1:
                empty_indices[empty_ptr] = idx
                empty_ptr += 1
        for i in range(new_ptr):
            if i < empty_ptr:
                current_fixed[empty_indices[i]] = new_elements[i]
        fixed_new[rank_id] = current_fixed
    return fixed_new
class FlashLB(EplbPolicy):

    def __init__(self, config: DynamicConfig):
        super().__init__(config)
        self.par_history: Dict[int, float] = {}
        self.hotness_window: Dict[int, deque[float]] = {}
        self.max_stage_window = (config.max_stage_window if hasattr(
            config, "max_stage_window") else 1)
        self.buffer_expert_layer_num = (
            config.buffer_expert_layer_num if hasattr(
                config, "buffer_expert_layer_num") else 58)
        self.threshold_ratio = (config.threshold_ratio if hasattr(
            config, "threshold_ratio") else 0)

    def compute_expert_hotness(self, num_of_expert: int,
                               deployment: np.ndarray, rank_load: np.ndarray):
        hotness = np.zeros(num_of_expert, dtype=rank_load.dtype)
        deployment_flat = deployment.ravel()
        rank_load_flat = rank_load.ravel()
        np.add.at(hotness, deployment_flat, rank_load_flat)
        return hotness

    def compute_rank_load(self, deployment: np.ndarray, hotness: np.ndarray):
        n_stage, N = hotness.shape
        if np.any(deployment < 0):
            print(f"Invalid deployment with negative values: {deployment}")
            raise ValueError("Deployment table contains negative values.")
        counts = np.bincount(deployment.reshape(-1), minlength=N)
        unit_hotness = np.divide(hotness,
                                 counts,
                                 out=np.zeros_like(hotness, dtype=float),
                                 where=counts != 0)
        stage_par = np.zeros(n_stage)
        for i in range(n_stage):
            stage_load = unit_hotness[i][deployment].sum(-1)
            stage_par[i] = stage_load.max() / stage_load.mean()
        return stage_par.mean()
    def group_based_adaptive_bloating(self,
                                      X,
                                      P,
                                      M,
                                      stage_weights=None,
                                      recorsive=False):
        n_stage, N = X.shape
        if stage_weights is None:
            stage_weights = np.ones(n_stage, dtype=np.float32)
        if recorsive:
            (
                simulated_deployment,
                simulated_pieces,
            ) = self.group_based_adaptive_bloating(X,
                                                   P,
                                                   M,
                                                   stage_weights,
                                                   recorsive=False)
        else:
            simulated_pieces = compute_piece_counts(X, P, stage_weights)
            simulated_deployment = jsq_placement(X, simulated_pieces, M,
                                                 stage_weights)
        pieces = group_based_adaptive_bloating_kernel(
            X.astype(np.float32),
            P,
            M,
            simulated_pieces.astype(np.int32),
            simulated_deployment.astype(np.int32),
            stage_weights.astype(np.float32),
        )
        deployment = jsq_placement(X, pieces, M, stage_weights)
        X_all = X.sum(0)
        unit_load = np.divide(X_all,
                              pieces,
                              out=np.zeros_like(X_all, dtype=float),
                              where=pieces != 0)
        load = unit_load[deployment].sum(-1)
        sim_unit_load = X_all / simulated_pieces
        sim_load = sim_unit_load[simulated_deployment].sum(-1)
        if load.max() > sim_load.max():
            return simulated_deployment, simulated_pieces
        return deployment, pieces

    def need_update(self, current_par, layer_id=0):
        threshold = self.par_history.get(layer_id, 0.0)
        return current_par >= self.threshold_ratio * threshold

    def compute_stage_weight(self, hotness):
        n_stage = hotness.shape[0]
        stage_weights = np.zeros(n_stage)
        for i in range(n_stage):
            stage_weights[i] = hotness[i].sum()
        stage_weights = stage_weights / stage_weights.max()
        return stage_weights
    def rebalance_layer(self, deployment, hotness, layer_id=0):
        num_rank, expert_per_rank = deployment.shape
        num_expert = np.unique(deployment.reshape(-1)).shape[0]
        num_of_redundant_expert = num_rank * expert_per_rank - num_expert
        current_par = self.compute_rank_load(deployment, hotness)
        if not self.need_update(current_par, layer_id):
            return deployment, current_par, current_par
        stage_weights = self.compute_stage_weight(hotness)
        new_deployment, _ = self.group_based_adaptive_bloating(
            hotness,
            num_expert + num_of_redundant_expert,
            num_rank,
            stage_weights,
            recorsive=False,
        )
        if np.any(new_deployment < 0):
            print(f"{new_deployment=}")
        new_par = self.compute_rank_load(new_deployment, hotness)
        return new_deployment, new_par, current_par

    def register_hotness(self, deployment, rank_load, num_layer, num_expert):
        for layer in range(num_layer):
            if layer not in self.hotness_window:
                self.hotness_window[layer] = deque(
                    maxlen=self.max_stage_window)
            hotness = self.compute_expert_hotness(num_expert,
                                                  deployment[layer],
                                                  rank_load[layer])
            self.hotness_window[layer].append(hotness)

    def compress_by_avg_pooling_fast_nd(self, arr, m):
        n, d = arr.shape
        idx = (np.arange(n) * m // n)
        result = np.zeros((m, d))
        counts = np.zeros((m, 1))
        np.add.at(result, idx, arr)
        np.add.at(counts, idx, 1)
        return result / counts
    def rebalance_experts(self, current_expert_table, expert_workload):
        current_deployment = np.array(current_expert_table)
        expert_workload = np.array(expert_workload)
        expert_workload += 1
        num_layer = expert_workload.shape[0]
        num_expert = np.unique(current_expert_table[0].reshape(-1)).shape[0]
        self.register_hotness(current_deployment, expert_workload, num_layer,
                              num_expert)
        new_deployment = current_deployment.copy()
        layers_need_update = np.arange(num_layer)
        new_par = np.zeros(layers_need_update.shape[0])
        current_par = np.zeros(layers_need_update.shape[0])
        for i, layer in enumerate(layers_need_update):
            hotness = np.array(self.hotness_window[layer])
            if hotness.shape[0] > self.max_stage_window:
                hotness = self.compress_by_avg_pooling_fast_nd(
                    hotness, self.max_stage_window)
            (
                new_deployment[layer],
                new_par[i],
                current_par[i],
            ) = self.rebalance_layer(current_deployment[layer],
                                     hotness,
                                     layer_id=layer)
        priority = new_par / current_par
        priority_idx = np.argsort(priority)
        priority_idx = priority_idx[priority[priority_idx] <
                                    1][:self.buffer_expert_layer_num]
        if np.all(expert_workload == 1):
            for _, layer in enumerate(layers_need_update):
                self.hotness_window[layer].pop()
            return False, np.array([], dtype=int), current_deployment
        change = len(priority_idx) > 0
        if change:
            for idx in priority_idx:
                self.par_history[layers_need_update[idx]] = new_par[idx]
        layers_need_update = priority_idx
        deployment = current_deployment
        for layer in layers_need_update:
            deployment[layer] = auto_fix_new_placement(
                current_deployment[layer], new_deployment[layer])
        return change, layers_need_update, deployment
def generate_layered_experts(num_layers=58,
                             layer_shape=(32, 9),
                             expert_min=0,
                             expert_max=255):
    """
    Generate expert deployment matrix meeting the following conditions:
    - Total of num_layers layers
    - Each layer has shape layer_shape (32, 9)
    - Each expert from expert_min to expert_max (0 to 255) appears at least
      once in each layer

    Args:
        num_layers: Number of layers, default 58
        layer_shape: Shape of a single layer, default (32, 9)
        expert_min: Minimum expert ID, default 0
        expert_max: Maximum expert ID, default 255

    Returns:
        torch.Tensor: Tensor with shape (num_layers, layer_shape[0], layer_shape[1])
    """
    # 1. Basic parameter calculation
    expert_num = expert_max - expert_min + 1  # Total number of experts: 256 (0~255)
    layer_total = layer_shape[0] * layer_shape[1]  # Total elements in a single layer: 32*9=288
    extra_slots = layer_total - expert_num  # Number of random positions to fill per layer: 288-256=32
    # 2. Verify feasibility (total elements must be >= number of experts to cover all experts)
    assert layer_total >= expert_num, (
        f"Number of elements in a single layer {layer_total} < number of experts {expert_num}, "
        "cannot cover all experts")
    # 3. Generate layers one by one
    layers = []
    for _ in range(num_layers):
        # 3.1 Generate "complete expert sequence" (ensure each expert from 0 to 255 is included)
        full_experts = torch.arange(expert_min,
                                    expert_max + 1,
                                    dtype=torch.int64)  # shape (256,)
        # 3.2 Generate "supplementary random experts" (fill remaining 32 positions, randomly selected from 0~255)
        extra_experts = torch.randint(expert_min,
                                      expert_max + 1,
                                      size=(extra_slots, ),
                                      dtype=torch.int64)  # shape (32,)
        # 3.3 Concatenate and shuffle (ensure random distribution of experts in each layer)
        layer_flat = torch.cat([full_experts, extra_experts],
                               dim=0)  # shape (288,)
        # Shuffle order (use randperm to generate random indices to avoid repeated shuffling issues)
        shuffle_idx = torch.randperm(layer_flat.shape[0])
        layer_shuffled = layer_flat[shuffle_idx]
        # 3.4 Reshape to layer_shape (32, 9)
        layer = layer_shuffled.reshape(layer_shape)
        layers.append(layer)
    # 4. Stack all layers to get the final tensor
    return torch.stack(layers, dim=0)  # shape (58, 32, 9)


def warm_up():
    exam_config = DynamicConfig()
    exam_config.ep_worldsize = 32
    exam_config.num_die_per_host = 16
    algo = FlashLB(exam_config)
    # Generate target tensor
    expert_tensor = generate_layered_experts(num_layers=58,
                                             layer_shape=(32, 9))
    algo.rebalance_experts(expert_tensor, torch.randint(1, 1000, (58, 32, 9)))