From 29c173ab48001f3d26805db2d833be712687fb1a Mon Sep 17 00:00:00 2001 From: Mercykid-bash Date: Tue, 23 Sep 2025 10:27:14 +0800 Subject: [PATCH] FlashLB algorithm (#3042) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## Purpose This Pull Request enhances the EPLB (Expert Parallelism Load Balancing) system by introducing a novel balancing algorithm: FlashLB. ## Motivation 1. The default algorithm adopts a two-stage greedy strategy: a. Replica allotment: Determine the number of expert replicas by minimizing the maximum load per replica (Min Max Replica, MMR). b. Replica placement: Distribute replicas across devices by repeatedly assigning the heaviest replica to the least loaded device (Longest Processing Time First, LPT). However, this sequential process lacks inter-stage collaborative optimization, often leading to suboptimal load balancing. For example, in the simple case shown in the figure below: given 8 logical experts with hotness values of 600, 560, 120, 120, 20, 10, 10, 10, and 2 replicas allocated per device across 8 devices, the EPLB algorithm yields a maximum per-device hotness of 232, while our proposed FlashLB algorithm can reduce this value to 205. 2. The default algorithm relies on the averaged expert hotness over a fixed time window for optimization. While this provides a coarse approximation of the hotness distribution, it fails to capture oscillatory deviations and temporal correlations of expert hotness observed across iterations in real-world scenarios, limiting optimization quality. 3. The default algorithm periodically regenerates the expert placement table. However, it generates the table for each individual layer, and the new table does not account for correlations with the previous one; these two factors collectively lead to nearly full-scale expert reassignment. ## FlashLB Algorithm Principle 1. Joint Optimization FlashLB achieves joint optimization of replica allotment and placement through group-based decision-making. Each group gradually determines the replica count and placement for a subset of experts, ensuring that the expected inter-device load balance (considering both deployed and pending expert replicas) is holistically optimized. To attain superior load balancing, FlashLB employs tree search to expand the solution space while integrating pruning and precompilation techniques for acceleration, thereby delivering load balancing that is both high-quality and practically efficient. 2. Multi-Shot Enhancement FlashLB partitions each profiling interval (e.g., 1024 iterations) into consecutive smaller sub-intervals (e.g., 16 iterations), each capturing independent hotness measurements. It then performs multi-shot optimization to co-optimize these sub-intervals simultaneously—enabling adaptation to time-variant expert hotness while enhancing robustness. 3. Incremental Adjustment To reduce the overhead of frequent expert re-deployment, FlashLB introduces an incremental adjustment scheme operating at both inter-layer and intra-layer levels: a. Inter-Layer: Hotness variations are tracked at the layer level. Only layers with fluctuations exceeding a predefined threshold trigger re-computation of expert placement, avoiding unnecessary redeployment for stable layers; b. Intra-Layer (Optional): A lightweight incremental LPT algorithm (LPT-Incremental) is applied. Instead of recomputing full placement for all experts in a layer, it selectively adjusts only the hottest experts or those with replica count changes, further reducing migration overhead. This incremental strategy significantly reduces adjustment costs while maintaining balanced performance across layers and devices. ## Co-author: Co-authored-by: Skywalker-EP 173723846@qq.com - vLLM version: v0.10.2 - vLLM main: https://github.com/vllm-project/vllm/commit/9607d5eb449711b349d4c2bee0a9c94afcc7ed14 --------- Signed-off-by: sdmyzlp Signed-off-by: Che Ruan Signed-off-by: Shanshan Shen <87969357+shen-shanshan@users.noreply.github.com> Signed-off-by: shen-shanshan <467638484@qq.com> Signed-off-by: Yikun Jiang Signed-off-by: 22dimensions Signed-off-by: zhanghaiwen Signed-off-by: hfadzxy Signed-off-by: Lucas Kabela Signed-off-by: wangli Signed-off-by: MengqingCao Signed-off-by: wangxiyuan Signed-off-by: Icey <1790571317@qq.com> Signed-off-by: linfeng-yuan <1102311262@qq.com> Signed-off-by: dependabot[bot] Signed-off-by: tangtianyi Signed-off-by: Angazenn Signed-off-by: Yizhou Liu Signed-off-by: rjg-lyh <1318825571@qq.com> Signed-off-by: Pr0Wh1teGivee Signed-off-by: fems14 <1804143737@qq.com> Co-authored-by: sdmyzlp <117554856+sdmyzlp@users.noreply.github.com> Co-authored-by: Che Ruan Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> Co-authored-by: Shanshan Shen <467638484@qq.com> Co-authored-by: Yikun Jiang Co-authored-by: 22dimensions Co-authored-by: zhanghw0354 Co-authored-by: zhanghaiwen Co-authored-by: zhangxinyuehfad <59153331+zhangxinyuehfad@users.noreply.github.com> Co-authored-by: Lucas Kabela Co-authored-by: Li Wang Co-authored-by: MengqingCao Co-authored-by: wangxiyuan Co-authored-by: Icey <1790571317@qq.com> Co-authored-by: linfeng-yuan <1102311262@qq.com> Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> Co-authored-by: tianyitang Co-authored-by: Angazenn Co-authored-by: Yizhou <136800916+yiz-liu@users.noreply.github.com> Co-authored-by: rjg-lyh <83491835+rjg-lyh@users.noreply.github.com> Co-authored-by: weichen <132029610+Pr0Wh1teGivee@users.noreply.github.com> Co-authored-by: weijinqian0 <12153182+weijinqian0@users.noreply.github.com> Co-authored-by: fems14 <74094523+fems14@users.noreply.github.com> --- .../eplb/core/policy/policy_factory.py | 9 +- .../eplb/core/policy/policy_flashlb.py | 651 ++++++++++++++++++ 2 files changed, 659 insertions(+), 1 deletion(-) create mode 100644 vllm_ascend/eplb/core/policy/policy_flashlb.py diff --git a/vllm_ascend/eplb/core/policy/policy_factory.py b/vllm_ascend/eplb/core/policy/policy_factory.py index 03a57e3..bbf7315 100644 --- a/vllm_ascend/eplb/core/policy/policy_factory.py +++ b/vllm_ascend/eplb/core/policy/policy_factory.py @@ -3,6 +3,7 @@ from .policy_abstract import DynamicConfig, EplbPolicy from .policy_dynamic_ep import DynamicEplb from .policy_dynamic_ep_v2 import DynamicEplbV2 +from .policy_flashlb import FlashLB from .policy_random import RandomLoadBalance @@ -22,5 +23,11 @@ class PolicyFactory: DynamicEplb, # Dynamic EPLB policy: overall expert replacement based on current moe load 2: DynamicEplbV2, # Dynamic EPLB policy V2: expert replacement with constrained number of expert shuffle + 3: + FlashLB, # FlashLB EPLB policy: expert replacement based on Joint Optimization, Multi-Shot Enhancement and Incremental Adjustment } - return policy.get(policy_type, RandomLoadBalance)(config) + policy_class = policy.get(policy_type, RandomLoadBalance) + policy_instance = policy_class(config) + if policy_type == 3: + policy_instance.warm_up() + return policy_instance \ No newline at end of file diff --git a/vllm_ascend/eplb/core/policy/policy_flashlb.py b/vllm_ascend/eplb/core/policy/policy_flashlb.py new file mode 100644 index 0000000..2bf6551 --- /dev/null +++ b/vllm_ascend/eplb/core/policy/policy_flashlb.py @@ -0,0 +1,651 @@ +# Copyright Huawei Technologies Co., Ltd. 2024-2025. All rights reserved. +# Todo: Once https://github.com/vllm-project/vllm/pull/24069 is merged in vllm. Remove this policy. + +import logging +from collections import deque +from typing import Dict + +import numpy as np +import torch +from numba import njit # type: ignore + +from .policy_abstract import DynamicConfig, EplbPolicy + +numba_logger = logging.getLogger("numba") +numba_logger.setLevel(logging.WARNING) + + +@njit +def compute_piece_counts(X, P, stage_weights): + n_stage, N = X.shape + S = P - N + pieces = np.ones(N, dtype=np.int32) + unit = X / pieces # unit[i, j] = X[i, j] / pieces[j] + + for _ in range(S): + deltas = np.zeros(N, dtype=np.float32) + for i in range(n_stage): + # Find top1 and top2 + idx1 = -1 + idx2 = -1 + val1 = -1.0 + val2 = -1.0 + for j in range(N): + v = unit[i, j] + if v > val1: + val2 = val1 + idx2 = idx1 + val1 = v + idx1 = j + elif v > val2: + val2 = v + idx2 = j + + origin = unit[i, idx1] + secv = unit[i, idx2] + alt = X[i, idx1] / (pieces[idx1] + 1) + delta = origin - (alt if alt > secv else secv) + deltas[idx1] += delta * stage_weights[i] if np.any( + delta) != 0 else stage_weights[i] + + max_idx = np.argmax(deltas) + pieces[max_idx] += 1 + for i in range(n_stage): + unit[i, max_idx] = X[i, max_idx] / pieces[max_idx] + + # Compute max load + max_load = 0.0 + for j in range(N): + total = 0.0 + for i in range(n_stage): + total += unit[i, j] + if total > max_load: + max_load = total + + return pieces + + +@njit +def jsq_placement(X, pieces, M, stage_weights): + n_stage, N = X.shape + total_piece = pieces.sum() + num_per_group = total_piece // M + + # 1. Compute unit_hotness + unit_hotness = np.empty((n_stage, N), dtype=np.float32) + for i in range(N): + if pieces[i] > 0: + for s in range(n_stage): + unit_hotness[s, i] = X[s, i] / pieces[i] + else: + for s in range(n_stage): + unit_hotness[s, i] = 0.0 + + # 2. Sort by total hotness + scores = np.zeros(N, dtype=np.float32) + for i in range(N): + for s in range(n_stage): + scores[i] += unit_hotness[s, i] + idx = np.argsort(-scores) + + # 3. Initialization + loads = np.zeros((n_stage, M), dtype=np.float32) + dev_phy_exp_n = np.zeros(M, dtype=np.int32) + deployment = -np.ones((M, num_per_group), dtype=np.int32) + dep_ptr = np.zeros(M, dtype=np.int32) + + # 4. Main loop + for t in range(N): + i = idx[t] + used_device = list() + for _ in range(pieces[i]): + # 4.1 Construct w vector + w = np.empty(n_stage, dtype=np.float32) + for s in range(n_stage): + w[s] = unit_hotness[s, i] + + # 4.2 Compute stage-level maximum load + stage_max = np.empty(n_stage, dtype=np.float32) + for s in range(n_stage): + max_val = loads[s, 0] + for k in range(1, M): + if loads[s, k] > max_val: + max_val = loads[s, k] + stage_max[s] = max_val + + # 4.3 Compute denominator + denom = np.empty(n_stage, dtype=np.float32) + for s in range(n_stage): + sum_tmp = 0.0 + for j in range(M): + sum_tmp += loads[s, j] + w[s] + denom[s] = sum_tmp / M + 1e-2 + + # 4.4 Find best device j + best_j = -1 + best_val = 1e30 + for j in range(M): + if dev_phy_exp_n[j] >= num_per_group: + continue + if j in used_device: + continue + score = 0.0 + for s in range(n_stage): + tmp_sj = loads[s, j] + w[s] + numer_sj = tmp_sj if tmp_sj > stage_max[s] else stage_max[s] + score += stage_weights[s] * (numer_sj / denom[s]) + if score < best_val: + best_val = score + best_j = j + if best_j == -1: + continue + + used_device.append(best_j) + + # 4.5 Update status + for s in range(n_stage): + loads[s, best_j] += w[s] + ptr = dep_ptr[best_j] + deployment[best_j, ptr] = i + dep_ptr[best_j] += 1 + dev_phy_exp_n[best_j] += 1 + + # Handle remaining -1 values: fill with random elements from range(N) not in current column + for rank in range(M): + for col in range(num_per_group): + if deployment[rank, col] == -1: + # Get elements already in current column + current_rank_elements = set(deployment[rank, :]) + # Filter elements from range(N) not in current column + available = [ + x for x in range(N) if x not in current_rank_elements + ] + # Randomly select an available element to fill + if len(available) > 0: + rand_idx = np.random.randint(0, len(available)) + deployment[rank, col] = available[rand_idx] + elif N > 0: + # All unique experts are already in this rank's column, so we can pick any expert randomly. + deployment[rank, col] = np.random.randint(0, N) + + return deployment + + +@njit +def slice_values(X, pieces): + total_len = 0 + for i in range(X.shape[0]): + total_len += pieces[i] + result = np.empty(total_len, dtype=np.float32) + idx = 0 + for i in range(X.shape[0]): + val = X[i] / pieces[i] + for _ in range(pieces[i]): + result[idx] = val + idx += 1 + return result + + +@njit +def group_based_adaptive_bloating_kernel(X, P, M, simulated_pieces, + simulated_deployment, stage_weights): + n_stage, N = X.shape + num_group = P // M + + X_all = np.zeros(N, dtype=np.float32) + for i in range(n_stage): + for j in range(N): + X_all[j] += X[i, j] + + sort_idx = np.argsort(np.negative(X_all)) + X_sorted = X[:, sort_idx] + + unit_load = np.empty(N, dtype=np.float32) + for j in range(N): + unit_load[j] = X_all[j] / simulated_pieces[j] + + flat_deployment = simulated_deployment.reshape(-1) + simulated_load = np.zeros(M, dtype=np.float32) + for i in range(flat_deployment.shape[0]): + simulated_load[i // (flat_deployment.shape[0] // + M)] += unit_load[flat_deployment[i]] + + slice_vals = slice_values(X_all, simulated_pieces) + sorted_slices = np.sort(slice_vals)[::-1] + simulated_slopes = (sorted_slices[:-M + 1] - sorted_slices[M - 1:]) / M + + cumulative_slices_used = np.zeros(N, dtype=np.int32) + acc = 0 + for i in range(N): + acc += simulated_pieces[sort_idx[i]] + cumulative_slices_used[i] = acc + + group_boundary_indices = np.zeros(num_group, dtype=np.int32) + for i in range(1, num_group + 1): + for j in range(N): + if cumulative_slices_used[j] >= i * M: + group_boundary_indices[i - 1] = j + break + + slices_used_per_group = np.zeros(num_group, dtype=np.int32) + slices_used_per_group[0] = group_boundary_indices[0] + for i in range(1, num_group): + slices_used_per_group[ + i] = group_boundary_indices[i] - group_boundary_indices[i - 1] + slices_used_per_group = M - slices_used_per_group + + loads = np.zeros(M, dtype=np.float32) + pieces = np.zeros(N, dtype=np.int32) + num_remain_slice = P - N + current_idx = 0 + + for g in range(num_group): + window = X_sorted[:, current_idx:current_idx + 2 * M] + low = max(0, current_idx + M - N) + high = min(num_remain_slice, M - 1) + + while (high - low) > 1: + mid = int((high + low) // 2) + keep = M - mid + current_group = window[:, :keep] + current_pieces = compute_piece_counts(current_group, M, + stage_weights) + current_pieces = np.maximum(current_pieces, 1) + current_slice = slice_values(current_group.sum(0), current_pieces) + current_slice_sorted = np.sort(current_slice) + current_loads = loads + current_slice_sorted + current_max: np.float32 = np.max(current_loads) + current_min: np.float32 = np.min(current_loads) + current_slope = (current_max - current_min) / M + next_slope: np.float32 = np.max(simulated_slopes[current_idx + + keep:]) + + if abs(current_slope) > abs(next_slope): + low = mid + else: + high = mid + + S = high + keep = M - S + current_group = window[:, :keep] + current_pieces = compute_piece_counts(current_group, M, stage_weights) + + for i in range(keep): + pieces[sort_idx[current_idx + i]] = current_pieces[i] + + current_slice = slice_values(current_group.sum(0), current_pieces) + current_slice_sorted = np.sort(current_slice) + loads += current_slice_sorted + loads = np.sort(loads)[::-1] + + current_idx += keep + num_remain_slice -= S + + return pieces + + +@njit +def compute_objective(deployment, X, pieces): + M, P = deployment.shape + loads = np.zeros(M) + + for i in range(M): + for j in range(P): + expert = deployment[i, j] + if pieces[expert] == 0: + continue + loads[i] += X[expert] / pieces[expert] + + mean_load = np.mean(loads) + max_load: np.float32 = np.max(loads) + obj = max_load / mean_load + return obj, loads + + +@njit +def auto_fix_new_placement(old_placement, new_placement): + """ + Adjust the new_placement matrix to ensure elements (including duplicates) that exist in both + old_placement and new_placement remain in their original positions from old_placement. + New elements (unique to new_placement) will fill the remaining empty positions. + + Args: + old_placement: Old deployment matrix with shape (num_ranks, num_experts) + new_placement: New deployment matrix to be fixed, must have the same shape as old_placement + + Returns: + fixed_new: adjusted version of the new_placement matrix + """ + num_ranks, num_experts = old_placement.shape + fixed_new = np.empty_like(new_placement) + + max_expert_old = old_placement.max() if num_experts > 0 else 0 + max_expert_new = new_placement.max() if num_experts > 0 else 0 + max_expert = max(max_expert_old, max_expert_new) + + for rank_id in range(num_ranks): + old_row = old_placement[rank_id] + new_row = new_placement[rank_id] + + index_array = np.full((max_expert + 1, num_experts), + -1, + dtype=np.int32) + count_array = np.zeros(max_expert + 1, dtype=np.int32) + + for idx in range(num_experts): + val = old_row[idx] + if val >= 0 and val <= max_expert: + pos = count_array[val] + index_array[val, pos] = idx + count_array[val] += 1 + + old_counter = np.zeros(max_expert + 1, dtype=np.int32) + for idx in range(num_experts): + val = old_row[idx] + if val >= 0 and val <= max_expert: + old_counter[val] += 1 + + retain_elements = np.empty(num_experts, dtype=new_placement.dtype) + new_elements = np.empty(num_experts, dtype=new_placement.dtype) + retain_ptr = 0 + new_ptr = 0 + + for val in new_row: + if val >= 0 and val <= max_expert and old_counter[val] > 0: + retain_elements[retain_ptr] = val + retain_ptr += 1 + old_counter[val] -= 1 + else: + new_elements[new_ptr] = val + new_ptr += 1 + + current_fixed = np.full(num_experts, -1, dtype=new_placement.dtype) + + for i in range(retain_ptr): + val = retain_elements[i] + if val >= 0 and val <= max_expert: + pos = count_array[val] - 1 + if pos >= 0: + idx = index_array[val, pos] + current_fixed[idx] = val + count_array[val] -= 1 + + empty_indices = np.empty(num_experts, dtype=np.int32) + empty_ptr = 0 + for idx in range(num_experts): + if current_fixed[idx] == -1: + empty_indices[empty_ptr] = idx + empty_ptr += 1 + + for i in range(new_ptr): + if i < empty_ptr: + current_fixed[empty_indices[i]] = new_elements[i] + + fixed_new[rank_id] = current_fixed + + return fixed_new + + +class FlashLB(EplbPolicy): + + def __init__(self, config: DynamicConfig): + super().__init__(config) + self.par_history: Dict[int, float] = {} + self.hotness_window: Dict[int, deque[float]] = {} + self.max_stage_window = (config.max_stage_window if hasattr( + config, "max_stage_window") else 1) + self.buffer_expert_layer_num = ( + config.buffer_expert_layer_num if hasattr( + config, "buffer_expert_layer_num") else 58) + self.threshold_ratio = (config.threshold_ratio if hasattr( + config, "threshold_ratio") else 0) + + def compute_expert_hotness(self, num_of_expert: int, + deployment: np.ndarray, rank_load: np.ndarray): + hotness = np.zeros(num_of_expert, dtype=rank_load.dtype) + deployment_flat = deployment.ravel() + rank_load_flat = rank_load.ravel() + np.add.at(hotness, deployment_flat, rank_load_flat) + return hotness + + def compute_rank_load(self, deployment: np.ndarray, hotness: np.ndarray): + n_stage, N = hotness.shape + if np.any(deployment < 0): + print(f"Invalid deployment with negative values: {deployment}") + raise ValueError("Deployment table contains negative values.") + counts = np.bincount(deployment.reshape(-1), minlength=N) + unit_hotness = np.divide(hotness, + counts, + out=np.zeros_like(hotness, dtype=float), + where=counts != 0) + stage_par = np.zeros(n_stage) + for i in range(n_stage): + stage_load = unit_hotness[i][deployment].sum(-1) + stage_par[i] = stage_load.max() / stage_load.mean() + return stage_par.mean() + + def group_based_adaptive_bloating(self, + X, + P, + M, + stage_weights=None, + recorsive=False): + n_stage, N = X.shape + if stage_weights is None: + stage_weights = np.ones(n_stage, dtype=np.float32) + + if recorsive: + ( + simulated_deployment, + simulated_pieces, + ) = self.group_based_adaptive_bloating(X, + P, + M, + stage_weights, + recorsive=False) + else: + simulated_pieces = compute_piece_counts(X, P, stage_weights) + simulated_deployment = jsq_placement(X, simulated_pieces, M, + stage_weights) + + pieces = group_based_adaptive_bloating_kernel( + X.astype(np.float32), + P, + M, + simulated_pieces.astype(np.int32), + simulated_deployment.astype(np.int32), + stage_weights.astype(np.float32), + ) + + deployment = jsq_placement(X, pieces, M, stage_weights) + + X_all = X.sum(0) + unit_load = np.divide(X_all, + pieces, + out=np.zeros_like(X_all, dtype=float), + where=pieces != 0) + load = unit_load[deployment].sum(-1) + + sim_unit_load = X_all / simulated_pieces + sim_load = sim_unit_load[simulated_deployment].sum(-1) + + if load.max() > sim_load.max(): + return simulated_deployment, simulated_pieces + return deployment, pieces + + def need_update(self, current_par, layer_id=0): + threshold = self.par_history.get(layer_id, 0.0) + return current_par >= self.threshold_ratio * threshold + + def compute_stage_weight(self, hotness): + n_stage = hotness.shape[0] + stage_weights = np.zeros(n_stage) + for i in range(n_stage): + stage_weights[i] = hotness[i].sum() + + stage_weights = stage_weights / stage_weights.max() + return stage_weights + + def rebalance_layer(self, deployment, hotness, layer_id=0): + num_rank, expert_per_rank = deployment.shape + num_expert = np.unique(deployment.reshape(-1)).shape[0] + num_of_redundant_expert = num_rank * expert_per_rank - num_expert + + current_par = self.compute_rank_load(deployment, hotness) + + if not self.need_update(current_par, layer_id): + return deployment, current_par, current_par + + stage_weights = self.compute_stage_weight(hotness) + new_deployment, _ = self.group_based_adaptive_bloating( + hotness, + num_expert + num_of_redundant_expert, + num_rank, + stage_weights, + recorsive=False, + ) + if np.any(new_deployment < 0): + print(f"{new_deployment=}") + new_par = self.compute_rank_load(new_deployment, hotness) + + return new_deployment, new_par, current_par + + def register_hotness(self, deployment, rank_load, num_layer, num_expert): + for layer in range(num_layer): + if layer not in self.hotness_window: + self.hotness_window[layer] = deque( + maxlen=self.max_stage_window) + hotness = self.compute_expert_hotness(num_expert, + deployment[layer], + rank_load[layer]) + self.hotness_window[layer].append(hotness) + + def compress_by_avg_pooling_fast_nd(self, arr, m): + n, d = arr.shape + idx = (np.arange(n) * m // n) + result = np.zeros((m, d)) + counts = np.zeros((m, 1)) + np.add.at(result, idx, arr) + np.add.at(counts, idx, 1) + return result / counts + + def rebalance_experts(self, current_expert_table, expert_workload): + current_deployment = np.array(current_expert_table) + expert_workload = np.array(expert_workload) + expert_workload += 1 + num_layer = expert_workload.shape[0] + num_expert = np.unique(current_expert_table[0].reshape(-1)).shape[0] + self.register_hotness(current_deployment, expert_workload, num_layer, + num_expert) + + new_deployment = current_deployment.copy() + + layers_need_update = np.arange(num_layer) + + new_par = np.zeros(layers_need_update.shape[0]) + current_par = np.zeros(layers_need_update.shape[0]) + for i, layer in enumerate(layers_need_update): + hotness = np.array(self.hotness_window[layer]) + if hotness.shape[0] > self.max_stage_window: + hotness = self.compress_by_avg_pooling_fast_nd( + hotness, self.max_stage_window) + + ( + new_deployment[layer], + new_par[i], + current_par[i], + ) = self.rebalance_layer(current_deployment[layer], + hotness, + layer_id=layer) + + priority = new_par / current_par + priority_idx = np.argsort(priority) + priority_idx = priority_idx[priority[priority_idx] < + 1][:self.buffer_expert_layer_num] + + if np.all(expert_workload == 1): + for _, layer in enumerate(layers_need_update): + self.hotness_window[layer].pop() + return False, np.array([], dtype=int), current_deployment + change = len(priority_idx) > 0 + if change: + for idx in priority_idx: + self.par_history[layers_need_update[idx]] = new_par[idx] + + layers_need_update = priority_idx + deployment = current_deployment + for layer in layers_need_update: + deployment[layer] = auto_fix_new_placement( + current_deployment[layer], new_deployment[layer]) + + return change, layers_need_update, deployment + + +def generate_layered_experts(num_layers=58, + layer_shape=(32, 9), + expert_min=0, + expert_max=255): + """ + Generate expert deployment matrix meeting the following conditions: + - Total of num_layers layers + - Each layer has shape layer_shape (32,9) + - Each expert from expert_min to expert_max (0 to 255) appears at least once in each layer + + Args: + num_layers: Number of layers, default 58 + layer_shape: Shape of a single layer, default (32,9) + expert_min: Minimum expert ID, default 0 + expert_max: Maximum expert ID, default 255 + Returns: + torch.Tensor: Tensor with shape (num_layers, layer_shape[0], layer_shape[1]) + """ + # 1. Basic parameter calculation + expert_num = expert_max - expert_min + 1 # Total number of experts: 256 (0~255) + layer_total = layer_shape[0] * layer_shape[ + 1] # Total elements in a single layer: 32*9=288 + extra_slots = layer_total - expert_num # Number of random positions to fill per layer: 288-256=32 + + # 2. Verify feasibility (total elements must be ≥ number of experts to cover all experts) + assert layer_total >= expert_num, ( + f"Number of elements in a single layer {layer_total} < number of experts {expert_num}, " + "cannot cover all experts") + + # 3. Generate layers one by one + layers = [] + for _ in range(num_layers): + # 3.1 Generate "complete expert sequence" (ensure each expert from 0 to 255 is included) + full_experts = torch.arange(expert_min, + expert_max + 1, + dtype=torch.int64) # shape (256,) + + # 3.2 Generate "supplementary random experts" (fill remaining 32 positions, randomly selected from 0~255) + extra_experts = torch.randint(expert_min, + expert_max + 1, + size=(extra_slots, ), + dtype=torch.int64) # shape (32,) + + # 3.3 Concatenate and shuffle (ensure random distribution of experts in each layer) + layer_flat = torch.cat([full_experts, extra_experts], + dim=0) # shape (288,) + # Shuffle order (use randperm to generate random indices to avoid repeated shuffling issues) + shuffle_idx = torch.randperm(layer_flat.shape[0]) + layer_shuffled = layer_flat[shuffle_idx] + + # 3.4 Reshape to layer_shape (32,9) + layer = layer_shuffled.reshape(layer_shape) + layers.append(layer) + + # 4. Stack all layers to get the final tensor + return torch.stack(layers, dim=0) # shape (58,32,9) + + +def warm_up(): + exam_config = DynamicConfig() + exam_config.ep_worldsize = 32 + exam_config.num_die_per_host = 16 + algo = FlashLB(exam_config) + # Generate target tensor + expert_tensor = generate_layered_experts(num_layers=58, + layer_shape=(32, 9)) + + algo.rebalance_experts(expert_tensor, torch.randint(1, 1000, (58, 32, 9)))