From 132f3c5d0ae9a99a304e632dae77db1c3094eb06 Mon Sep 17 00:00:00 2001 From: Mercykid-bash Date: Thu, 12 Mar 2026 15:49:09 +0800 Subject: [PATCH] Support per-step heat collection and enhance FlashLB for multi-stage load balancing (#6477) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit # Feature: FlashLB algorithm ## Purpose This Pull Request enhances the EPLB (Expert Parallelism Load Balancing) system by introducing a novel load balancing algorithm: FlashLB. 1. The default algorithm adopts two separate sub-procedures to optimize expert replication and placement independently: a. **Expert Replica Allotment Sub-procedure** : Determines the number of replicas for all experts. At each step, it greedily adds one more replica to the expert with the highest per-replica load, aiming to minimize load skew at the expert replica granularity (Min Max Replica, MMR). b. **Expert Replica Placement Sub-procedure** : Distributes all replicas across devices. First, it sorts the generated replicas in descending order of hotness, then iteratively places the currently hottest replica onto the device with the lowest cumulative load and available slots. However, this simplistic combination of two separate procedures lacks synergy and often leads to sub-optimal load balancing. For example, in the simple scenario illustrated below: Given 8 logical experts with hotness values [600, 560, 120, 120, 20, 10, 10, 10], and 2 replicas allocated per device across 8 devices, the default EPLB algorithm results in a maximum per-device hotness of 232 (peak-average load ratio 1.28), while our proposed FlashLB algorithm reduces this value to 205 (peak-average load ratio 1.13).
2. The default algorithm simply aggregates hotness measurements across the entire profiling window. While this provides a coarse approximation of the hotness distribution, it fails to capture the time-phased variations and temporal correlations in expert hotness (both within and between experts) across iterations—phenomena that have been observed in real-world scenarios. Such single-point hotness estimation degrades the solution quality of the load balancing algorithm. 3. The default algorithm regularly recalculates updated expert placement results for all layers without discrimination. Considering that excessive expert updates can impact Service Level Objectives (SLOs), such full-scale redeployment leads to excessively high adjustment overhead, which negatively affects end-to-end performance. ## FlashLB Algorithm Principle ### 1. Joint Optimization of Replica Allotment and Placement FlashLB achieves joint optimization of replica allotment and placement through a novel tree search approach, combined with carefully designed e Fl fficient pruning and lightweight look-ahead estimation. We partition all experts into several subsets, and for each subset, hierarchically determine the optimal replica count and placement. Leveraging efficient pruning and lightweight look-ahead estimation, the process consistently aims to optimize the globally expected inter-device load balance degree (considering both deployed and unexplored experts) while ensuring sufficient computational efficiency. Additionally, precompilation techniques are employed for acceleration, delivering load balancing that is both high-quality and practically efficient. ### 2. Multi-Episode Enhancement Instead of performing full-duration averaging like the default algorithm, FlashLB partitions each profiling interval (e.g., 1024 iterations) into multiple consecutive smaller episodes (e.g., 16 iterations). This preserves hotness fluctuation and correlation information. It then constructs a multi-objective optimization problem to co-optimize these episodes simultaneously, enabling adaptability to interleaved hotness patterns and improving statistical robustness. ### 3. Layer-wise Cherry-Picking Redeployment To reduce the overhead of frequent expert redeployment, FlashLB introduces a cherry-picking redeployment scheme. During each algorithmic decision cycle, it real-time tracks load balance degree of all layers and triggers expert placement updates only for those layers whose peak-average ratio exceeds a predefined threshold. This avoids unnecessary redeployment for stable layers, significantly reducing adjustment overhead and thereby improving end-to-end performance gains. ## Co-author: Co-authored-by: Skywalker-EP 173723846@qq.com This PR mainly introduces two key optimizations for load balancing scheduling: 1. **Add per-step heat collection function**: Support real-time collection of per-step heat information during model inference. This enables more fine-grained load balancing decisions by taking per-step heat as the optimization target, improving scheduling accuracy for dynamic and fluctuating workloads. 2. **Update FlashLB algorithm**: Upgrade the FlashLB scheduling logic to better adapt to multi-stage heat distribution scenarios. The improved algorithm can comprehensively perceive and utilize multi-stage heat characteristics, achieving more stable and efficient load balancing under complex expert deployment and dynamic traffic patterns. --------- Signed-off-by: Mercykid-bash Signed-off-by: xuzewei28 Co-authored-by: xuzewei28 --- vllm_ascend/eplb/core/eplb_worker.py | 6 +- .../eplb/core/policy/policy_flashlb.py | 1437 +++++++++++------ vllm_ascend/eplb/eplb_updator.py | 10 +- vllm_ascend/ops/fused_moe/fused_moe.py | 17 +- 4 files changed, 960 insertions(+), 510 deletions(-) diff --git a/vllm_ascend/eplb/core/eplb_worker.py b/vllm_ascend/eplb/core/eplb_worker.py index dc32ce24..b9a62a0d 100644 --- a/vllm_ascend/eplb/core/eplb_worker.py +++ b/vllm_ascend/eplb/core/eplb_worker.py @@ -34,6 +34,7 @@ class EplbWorker: self.old_expert_maps = None self.enable_d2d = enable_d2d self.rank_id = dist.get_rank() + self.multi_stage = policy_type == 3 def do_update(self): # put data in to queue @@ -62,7 +63,10 @@ class EplbWorker: _, _, new_placement = self.calculate_rebalance_experts(load_info, old_placement) if self.rank_id == 0: - hotness = self._calculate_hotness(old_placement, load_info) + if self.multi_stage: + hotness = self._calculate_hotness(old_placement, load_info.sum(0)) + else: + hotness = self._calculate_hotness(old_placement, load_info) current_mean, current_max = self._compute_imbalance(old_placement, hotness) update_mean, update_max = self._compute_imbalance(new_placement, hotness) logger.info( diff --git a/vllm_ascend/eplb/core/policy/policy_flashlb.py b/vllm_ascend/eplb/core/policy/policy_flashlb.py index 117b8ed6..a33f1d39 100644 --- a/vllm_ascend/eplb/core/policy/policy_flashlb.py +++ b/vllm_ascend/eplb/core/policy/policy_flashlb.py @@ -1,597 +1,1022 @@ # Copyright Huawei Technologies Co., Ltd. 2024-2025. All rights reserved. -# Todo: Once https://github.com/vllm-project/vllm/pull/24069 is merged in vllm. Remove this policy. +# Todo: Remove this policy after vllm-project/vllm/pull/24069 is merged into vllm. import logging -from collections import deque +import math +from typing import Any import numpy as np import torch from numba import njit # type: ignore +from scipy import stats # type: ignore +from scipy.optimize import linear_sum_assignment # type: ignore from .policy_abstract import DynamicConfig, EplbPolicy +# Suppress numba log warnings numba_logger = logging.getLogger("numba") numba_logger.setLevel(logging.WARNING) -@njit -def compute_piece_counts(X, P, stage_weights): - n_stage, N = X.shape - S = P - N - pieces = np.ones(N, dtype=np.int32) - unit = X / pieces # unit[i, j] = X[i, j] / pieces[j] - - for _ in range(S): - deltas = np.zeros(N, dtype=np.float32) - for i in range(n_stage): - # Find top1 and top2 - idx1 = -1 - idx2 = -1 - val1 = -1.0 - val2 = -1.0 - for j in range(N): - v = unit[i, j] - if v > val1: - val2 = val1 - idx2 = idx1 - val1 = v - idx1 = j - elif v > val2: - val2 = v - idx2 = j - - origin = unit[i, idx1] - secv = unit[i, idx2] - alt = X[i, idx1] / (pieces[idx1] + 1) - delta = origin - (alt if alt > secv else secv) - deltas[idx1] += delta * stage_weights[i] if np.any(delta) != 0 else stage_weights[i] - - max_idx = np.argmax(deltas) - pieces[max_idx] += 1 - for i in range(n_stage): - unit[i, max_idx] = X[i, max_idx] / pieces[max_idx] - - # Compute max load - max_load = 0.0 - for j in range(N): - total = 0.0 - for i in range(n_stage): - total += unit[i, j] - if total > max_load: - max_load = total - - return pieces - - -@njit -def jsq_placement(X, pieces, M, stage_weights): - n_stage, N = X.shape - total_piece = pieces.sum() - num_per_group = total_piece // M - - # 1. Compute unit_hotness - unit_hotness = np.empty((n_stage, N), dtype=np.float32) - for i in range(N): - if pieces[i] > 0: - for s in range(n_stage): - unit_hotness[s, i] = X[s, i] / pieces[i] - else: - for s in range(n_stage): - unit_hotness[s, i] = 0.0 - - # 2. Sort by total hotness - scores = np.zeros(N, dtype=np.float32) - for i in range(N): - for s in range(n_stage): - scores[i] += unit_hotness[s, i] - idx = np.argsort(-scores) - - # 3. Initialization - loads = np.zeros((n_stage, M), dtype=np.float32) - dev_phy_exp_n = np.zeros(M, dtype=np.int32) - deployment = -np.ones((M, num_per_group), dtype=np.int32) - dep_ptr = np.zeros(M, dtype=np.int32) - - # 4. Main loop - for t in range(N): - i = idx[t] - used_device = list() - for _ in range(pieces[i]): - # 4.1 Construct w vector - w = np.empty(n_stage, dtype=np.float32) - for s in range(n_stage): - w[s] = unit_hotness[s, i] - - # 4.2 Compute stage-level maximum load - stage_max = np.empty(n_stage, dtype=np.float32) - for s in range(n_stage): - max_val = loads[s, 0] - for k in range(1, M): - if loads[s, k] > max_val: - max_val = loads[s, k] - stage_max[s] = max_val - - # 4.3 Compute denominator - denom = np.empty(n_stage, dtype=np.float32) - for s in range(n_stage): - sum_tmp = 0.0 - for j in range(M): - sum_tmp += loads[s, j] + w[s] - denom[s] = sum_tmp / M + 1e-2 - - # 4.4 Find best device j - best_j = -1 - best_val = 1e30 - for j in range(M): - if dev_phy_exp_n[j] >= num_per_group: - continue - if j in used_device: - continue - score = 0.0 - for s in range(n_stage): - tmp_sj = loads[s, j] + w[s] - number_sj = tmp_sj if tmp_sj > stage_max[s] else stage_max[s] - score += stage_weights[s] * (number_sj / denom[s]) - if score < best_val: - best_val = score - best_j = j - if best_j == -1: - continue - - used_device.append(best_j) - - # 4.5 Update status - for s in range(n_stage): - loads[s, best_j] += w[s] - ptr = dep_ptr[best_j] - deployment[best_j, ptr] = i - dep_ptr[best_j] += 1 - dev_phy_exp_n[best_j] += 1 - - # Handle remaining -1 values: fill with random elements from range(N) not in current column - for rank in range(M): - for col in range(num_per_group): - if deployment[rank, col] == -1: - # Get elements already in current column - current_rank_elements = set(deployment[rank, :]) - # Filter elements from range(N) not in current column - available = [x for x in range(N) if x not in current_rank_elements] - # Randomly select an available element to fill - if len(available) > 0: - rand_idx = np.random.randint(0, len(available)) - deployment[rank, col] = available[rand_idx] - elif N > 0: - # All unique experts are already in this rank's column, so we can pick any expert randomly. - deployment[rank, col] = np.random.randint(0, N) - - return deployment - - -@njit -def slice_values(X, pieces): - total_len = 0 - for i in range(X.shape[0]): - total_len += pieces[i] - result = np.empty(total_len, dtype=np.float32) - idx = 0 - for i in range(X.shape[0]): - val = X[i] / pieces[i] - for _ in range(pieces[i]): - result[idx] = val - idx += 1 - return result - - -@njit -def group_based_adaptive_bloating_kernel(X, P, M, simulated_pieces, simulated_deployment, stage_weights): - n_stage, N = X.shape - num_group = P // M - - X_all = np.zeros(N, dtype=np.float32) - for i in range(n_stage): - for j in range(N): - X_all[j] += X[i, j] - - sort_idx = np.argsort(np.negative(X_all)) - X_sorted = X[:, sort_idx] - - unit_load = np.empty(N, dtype=np.float32) - for j in range(N): - unit_load[j] = X_all[j] / simulated_pieces[j] - - flat_deployment = simulated_deployment.reshape(-1) - simulated_load = np.zeros(M, dtype=np.float32) - for i in range(flat_deployment.shape[0]): - simulated_load[i // (flat_deployment.shape[0] // M)] += unit_load[flat_deployment[i]] - - slice_vals = slice_values(X_all, simulated_pieces) - sorted_slices = np.sort(slice_vals)[::-1] - simulated_slopes = (sorted_slices[: -M + 1] - sorted_slices[M - 1 :]) / M - - cumulative_slices_used = np.zeros(N, dtype=np.int32) - acc = 0 - for i in range(N): - acc += simulated_pieces[sort_idx[i]] - cumulative_slices_used[i] = acc - - group_boundary_indices = np.zeros(num_group, dtype=np.int32) - for i in range(1, num_group + 1): - for j in range(N): - if cumulative_slices_used[j] >= i * M: - group_boundary_indices[i - 1] = j - break - - slices_used_per_group = np.zeros(num_group, dtype=np.int32) - slices_used_per_group[0] = group_boundary_indices[0] - for i in range(1, num_group): - slices_used_per_group[i] = group_boundary_indices[i] - group_boundary_indices[i - 1] - slices_used_per_group = M - slices_used_per_group - - loads = np.zeros(M, dtype=np.float32) - pieces = np.zeros(N, dtype=np.int32) - num_remain_slice = P - N - current_idx = 0 - - for g in range(num_group): - window = X_sorted[:, current_idx : current_idx + 2 * M] - low = max(0, current_idx + M - N) - high = min(num_remain_slice, M - 1) - - while (high - low) > 1: - mid = int((high + low) // 2) - keep = M - mid - current_group = window[:, :keep] - current_pieces = compute_piece_counts(current_group, M, stage_weights) - current_pieces = np.maximum(current_pieces, 1) - current_slice = slice_values(current_group.sum(0), current_pieces) - current_slice_sorted = np.sort(current_slice) - current_loads = loads + current_slice_sorted - current_max: np.float32 = np.max(current_loads) - current_min: np.float32 = np.min(current_loads) - current_slope = (current_max - current_min) / M - next_slope: np.float32 = np.max(simulated_slopes[current_idx + keep :]) - - if abs(current_slope) > abs(next_slope): - low = mid - else: - high = mid - - S = high - keep = M - S - current_group = window[:, :keep] - current_pieces = compute_piece_counts(current_group, M, stage_weights) - - for i in range(keep): - pieces[sort_idx[current_idx + i]] = current_pieces[i] - - current_slice = slice_values(current_group.sum(0), current_pieces) - current_slice_sorted = np.sort(current_slice) - loads += current_slice_sorted - loads = np.sort(loads)[::-1] - - current_idx += keep - num_remain_slice -= S - - return pieces - - -@njit -def compute_objective(deployment, X, pieces): - M, P = deployment.shape - loads = np.zeros(M) - - for i in range(M): - for j in range(P): - expert = deployment[i, j] - if pieces[expert] == 0: - continue - loads[i] += X[expert] / pieces[expert] - - mean_load = np.mean(loads) - max_load: np.float32 = np.max(loads) - obj = max_load / mean_load - return obj, loads - - -@njit -def auto_fix_new_placement(old_placement, new_placement): +@njit(fastmath=True, cache=True) +def min_max_replica( + mu: np.ndarray, var: np.ndarray, num_available_replicas: int, current_replicas: np.ndarray, z_score: float +) -> tuple[np.ndarray, np.ndarray]: """ - Adjust the new_placement matrix to ensure elements (including duplicates) that exist in both - old_placement and new_placement remain in their original positions from old_placement. - New elements (unique to new_placement) will fill the remaining empty positions. + Original min-max replica allocation algorithm + Allocates replicas iteratively to expert with maximum unit load value Args: - old_placement: Old deployment matrix with shape (num_ranks, num_experts) - new_placement: New deployment matrix to be fixed, must have the same shape as old_placement + mu: Mean load of each expert (N,) + var: Variance of each expert's load (N,) + num_available_replicas: Total available replicas to allocate + current_replicas: Initial replica count per expert (N,) + z_score: Z-score for risk calculation (confidence level) Returns: - fixed_new: adjusted version of the new_placement matrix + current_replicas: Updated replica count per expert (N,) + replicas_history: Replica allocation history (num_available_replicas+1, N) """ - num_ranks, num_experts = old_placement.shape - fixed_new = np.empty_like(new_placement) + N = mu.shape[0] + unit_value = (mu + z_score * np.sqrt(var)) / current_replicas + replicas_history = np.ones((num_available_replicas + 1, N), dtype=np.int32) + replicas_history[0, :] = current_replicas[:] - max_expert_old = old_placement.max() if num_experts > 0 else 0 - max_expert_new = new_placement.max() if num_experts > 0 else 0 - max_expert = max(max_expert_old, max_expert_new) + # Allocate replicas to expert with maximum unit value iteratively + for r in range(num_available_replicas): + max_idx = -1 + max_value = -1.0 + for idx in range(N): + value = unit_value[idx] + if value > max_value: + max_value = value + max_idx = idx - for rank_id in range(num_ranks): - old_row = old_placement[rank_id] - new_row = new_placement[rank_id] + current_replicas[max_idx] += 1 + unit_value[max_idx] = (mu[max_idx] + z_score * np.sqrt(var[max_idx])) / current_replicas[max_idx] + replicas_history[r + 1, :] = current_replicas[:] - index_array = np.full((max_expert + 1, num_experts), -1, dtype=np.int32) - count_array = np.zeros(max_expert + 1, dtype=np.int32) + return current_replicas, replicas_history - for idx in range(num_experts): - val = old_row[idx] - if val >= 0 and val <= max_expert: - pos = count_array[val] - index_array[val, pos] = idx - count_array[val] += 1 - old_counter = np.zeros(max_expert + 1, dtype=np.int32) - for idx in range(num_experts): - val = old_row[idx] - if val >= 0 and val <= max_expert: - old_counter[val] += 1 +@njit +def max_delta_replica( + mu: np.ndarray, var: np.ndarray, num_available_replicas: int, current_replicas: np.ndarray, z_score: float +) -> tuple[np.ndarray, np.ndarray]: + """ + Maximum delta replica allocation algorithm + Allocates replicas by maximum unit value delta after increment - retain_elements = np.empty(num_experts, dtype=new_placement.dtype) - new_elements = np.empty(num_experts, dtype=new_placement.dtype) - retain_ptr = 0 - new_ptr = 0 + Args: + mu: Mean load of each expert (N,) + var: Variance of each expert's load (N,) + num_available_replicas: Total available replicas to allocate + current_replicas: Initial replica count per expert (N,) + z_score: Z-score for risk calculation (confidence level) - for val in new_row: - if val >= 0 and val <= max_expert and old_counter[val] > 0: - retain_elements[retain_ptr] = val - retain_ptr += 1 - old_counter[val] -= 1 - else: - new_elements[new_ptr] = val - new_ptr += 1 + Returns: + current_replicas: Updated replica count per expert (N,) + replicas_history: Replica allocation history (num_available_replicas+1, N) + """ + N = mu.shape[0] + unit_value = (mu + z_score * np.sqrt(var)) / current_replicas + replicas_history = np.ones((num_available_replicas + 1, N), dtype=np.int32) + replicas_history[0, :] = current_replicas[:] - current_fixed = np.full(num_experts, -1, dtype=new_placement.dtype) + # Allocate replicas by maximum unit value delta after increment + for r in range(num_available_replicas): + max_idx = -1 + max_value = -1.0 + for idx in range(N): + value = unit_value[idx] / (current_replicas[idx] + 1) + if value > max_value: + max_value = value + max_idx = idx - for i in range(retain_ptr): - val = retain_elements[i] - if val >= 0 and val <= max_expert: - pos = count_array[val] - 1 - if pos >= 0: - idx = index_array[val, pos] - current_fixed[idx] = val - count_array[val] -= 1 + current_replicas[max_idx] += 1 + unit_value[max_idx] = (mu[max_idx] + z_score * np.sqrt(var[max_idx])) / current_replicas[max_idx] + replicas_history[r + 1, :] = current_replicas[:] - empty_indices = np.empty(num_experts, dtype=np.int32) - empty_ptr = 0 - for idx in range(num_experts): - if current_fixed[idx] == -1: - empty_indices[empty_ptr] = idx - empty_ptr += 1 + return current_replicas, replicas_history - for i in range(new_ptr): - if i < empty_ptr: - current_fixed[empty_indices[i]] = new_elements[i] - fixed_new[rank_id] = current_fixed +@njit +def percentage_replica( + mu: np.ndarray, var: np.ndarray, num_available_replicas: int, current_replicas: np.ndarray, z_score: float +) -> tuple[np.ndarray, np.ndarray]: + """ + Proportional replica allocation algorithm + Allocates replicas proportionally to expert total load - return fixed_new + Args: + mu: Mean load of each expert (N,) + var: Variance of each expert's load (N,) + num_available_replicas: Total available replicas to allocate + current_replicas: Initial replica count per expert (N,) + z_score: Z-score for risk calculation (confidence level) + + Returns: + current_replicas: Updated replica count per expert (N,) + replicas_history: Replica allocation history (num_available_replicas+1, N) + """ + N = mu.shape[0] + total_load = mu + z_score * np.sqrt(var) + sum_total_load: float = np.sum(total_load) # Add type annotation for mypy + + replicas_history = np.ones((num_available_replicas + 1, N), dtype=np.int32) + replicas_history[0, :] = current_replicas[:] + + # Allocate replicas proportionally to expert total load + for r in range(1, num_available_replicas + 1): + add_slots = np.zeros(N, dtype=np.int32) + + if sum_total_load == 0.0: + # Average allocation if total load is zero + base_add = r // N + extra = r % N + add_slots[:] = base_add + add_slots[:extra] += 1 + else: + # Proportional allocation with remainder compensation + quotas = (total_load / sum_total_load) * r + base_add = np.floor(quotas).astype(np.int32) + add_slots[:] = base_add + remaining = r - np.sum(base_add) + + if remaining > 0: + fractions = quotas - base_add + indices = np.argsort(-fractions) + add_slots[indices[:remaining]] += 1 + + replicas_history[r] = current_replicas + add_slots + + return replicas_history[-1], replicas_history + + +def make_replica( + mu: np.ndarray, + var: np.ndarray, + num_available_replicas: int, + current_replicas: np.ndarray, + z_score: float, + method: str = "percentage", +) -> tuple[np.ndarray, np.ndarray]: + if method == "percentage": + return percentage_replica(mu, var, num_available_replicas, current_replicas, z_score) + elif method == "max_delta": + return max_delta_replica(mu, var, num_available_replicas, current_replicas, z_score) + else: + return min_max_replica(mu, var, num_available_replicas, current_replicas, z_score) + + +@njit(fastmath=True, cache=True) +def compute_updated_device_variance( + new_expert_id: int, + device_slots: np.ndarray, + current_device_var: float, + expert_var: np.ndarray, + expert_cov: np.ndarray, + expert_replicas: np.ndarray, +) -> float: + """ + Compute updated device variance after adding a new expert + Includes both individual variance and covariance with existing experts + + Args: + new_expert_id: ID of the new expert to add + device_slots: Current expert slots on the device (-1 for empty) + current_device_var: Current variance of the device + expert_var: Variance of each expert's load (N,) + expert_cov: Covariance matrix between experts (N,N) + expert_replicas: Replica count per expert (N,) + + Returns: + new_device_var: Updated device variance after adding the expert + """ + # Add variance of new expert + new_device_var = current_device_var + expert_var[new_expert_id] / expert_replicas[new_expert_id] ** 2 + + # Add covariance between new expert and existing experts on device + for slot in device_slots: + if slot == -1: + break + new_device_var += 2 * expert_cov[new_expert_id, slot] / expert_replicas[new_expert_id] / expert_replicas[slot] + + return new_device_var + + +@njit(fastmath=True, cache=True) +def lpt_deployment( + mu: np.ndarray, + var: np.ndarray, + cov: np.ndarray, + deployment: np.ndarray, + deployed_replicas: np.ndarray, + total_replicas: np.ndarray, + z_score: float, +) -> np.ndarray: + """ + Largest Processing Time (LPT) deployment algorithm + Greedily deploys experts to device with minimal risk (mean + z*std) + + Args: + mu: Mean load of each expert (N,) + var: Variance of each expert's load (N,) + cov: Covariance matrix between experts (N,N) + deployment: Initial deployment matrix (num_devices, num_slots_per_device) + deployed_replicas: Already deployed replica count per expert (N,) + total_replicas: Total target replica count per expert (N,) + z_score: Z-score for risk calculation (confidence level) + + Returns: + new_deployment: Updated expert deployment matrix + """ + num_devices, num_slots_per_device = deployment.shape + + # Initialize unit value and sort experts by load + unit_value = mu / total_replicas + sorted_indices = np.argsort(-unit_value) + + new_deployment = -np.ones_like(deployment) + device_mu = np.zeros(num_devices, dtype=np.float32) + device_var = np.zeros(num_devices, dtype=np.float32) + dev_ptr = np.zeros(num_devices, dtype=np.int32) + + # Copy existing deployment first + for dev in range(num_devices): + for slot in deployment[dev]: + if slot != -1: + device_mu[dev] += mu[slot] / total_replicas[slot] + device_var[dev] += compute_updated_device_variance( + slot, new_deployment[dev], device_var[dev], var, cov, total_replicas + ) + new_deployment[dev, dev_ptr[dev]] = slot + dev_ptr[dev] += 1 + + # Greedily deploy remaining replicas to device with minimal risk + for idx in sorted_indices: + for _ in range(total_replicas[idx] - deployed_replicas[idx]): + best_dev = -1 + best_risk = 1e30 + best_mu = -1.0 + best_var = -1.0 + for dev in range(num_devices): + if dev_ptr[dev] >= num_slots_per_device: + continue + if idx in new_deployment[dev]: + continue + # Calculate temporary device load and risk + temp_mu = device_mu[dev] + mu[idx] / total_replicas[idx] + temp_var = compute_updated_device_variance( + idx, new_deployment[dev], device_var[dev], var, cov, total_replicas + ) + + risk = temp_mu + z_score * np.sqrt(temp_var) + if risk < best_risk: + best_risk = risk + best_dev = dev + best_mu = temp_mu + best_var = temp_var + + # Update device state with best deployment choice + device_mu[best_dev] = best_mu + device_var[best_dev] = best_var + new_deployment[best_dev, dev_ptr[best_dev]] = idx + dev_ptr[best_dev] += 1 + + return new_deployment + + +@njit(fastmath=True, cache=True) +def compute_score(val_data: np.ndarray, simulated_replicas: np.ndarray, simulated_deployment: np.ndarray) -> np.float32: + """ + Calculate load balance score: (max_device_load * num_devices) / total_load + Lower score means better load balance + + Args: + val_data: Validation load data (T, N) - T time steps, N experts + simulated_replicas: Replica count per expert (N,) + simulated_deployment: Expert deployment matrix (D, K) - D devices, K slots + + Returns: + mean_score: Average load balance score over time steps + """ + T, N = val_data.shape + D, K = simulated_deployment.shape + scores = np.empty((T,), dtype=np.float32) + for t in range(T): + max_load = 0.0 # Explicit float type to avoid int/float mix + tot_load = 0.0 + for d in range(D): + s = 0.0 + for k in range(K): + idx = simulated_deployment[d, k] + s += val_data[t, idx] / simulated_replicas[idx] + tot_load += s + max_load = max(max_load, s) + # Add small epsilon to avoid division by zero + scores[t] = (max_load * D + 1e-2) / (tot_load + 1e-2) + + return np.mean(scores) + + +class FlashTree: + def __init__(self, X, num_replicas, num_devices, z_score=0.674, depth=4, width=8): + super().__init__() + self.num_replicas = num_replicas + self.num_devices = num_devices + self.z_score = z_score + self.depth = depth + self.width = width + + self.X = X + self.mu, self.var, self.cov = FlashTree.compute_statistics(X) + + @staticmethod + def compute_statistics(X): + T, N = X.shape + mean_ = np.mean(X, axis=0) + if T > 1: + X_centered = X - mean_ + variance_ = np.sum(X_centered**2, axis=0) / (T - 1) + cov_matrix = (X_centered.T @ X_centered) / (T - 1) + else: + variance_ = np.zeros((N,)) + cov_matrix = np.zeros((N, N)) + return mean_, variance_, cov_matrix + + def neighbor_search( + self, low: int, high: int, initial: int, max_range: int, get_score: Any, *args: Any + ) -> tuple[int, float, np.ndarray]: + """ + Local neighbor search for optimal replica number + Search [initial-max_range, initial+max_range] within [low, high] + + Args: + low: Lower bound of search range + high: Upper bound of search range + initial: Initial replica number to start search + max_range: Maximum search range from initial value + get_score: Function to compute score for a given replica number + *args: Additional arguments for get_score function + + Returns: + best_x: Optimal replica number + best_score: Best load balance score + best_sim: Corresponding deployment simulation result + """ + max_range = min(max(initial - low, high - initial), max_range) + best_x = initial + best_score, best_sim = get_score(initial, *args) + + # Search left and right neighbors + for r in range(1, max_range + 1): + left = initial - r + if left >= low: + score, sim = get_score(left, *args) + if score < best_score: + best_x, best_score, best_sim = left, score, sim + + right = initial + r + if right <= high: + score, sim = get_score(right, *args) + if score < best_score: + best_x, best_score, best_sim = right, score, sim + + return best_x, best_score, best_sim + + def optimize_balanceness(self): + X_row = self.X + mu, var, cov = self.mu, self.var, self.cov + num_total_replicas = self.num_replicas + num_devices = self.num_devices + z_score = self.z_score + depth, width = self.depth, self.width + + num_experts = mu.shape[0] + num_available_replicas = num_total_replicas - num_experts + + if depth <= 1: + default_replicas = np.ones(num_experts, dtype=np.int32) + default_replicas = make_replica(mu, var, num_available_replicas, default_replicas, z_score)[0] + default_deployment = -np.ones((num_devices, num_total_replicas // num_devices), dtype=np.int32) + default_deployment = lpt_deployment( + mu, var, cov, default_deployment, np.zeros(num_experts, dtype=np.int32), default_replicas, z_score + ) + default_par = compute_score(X_row, default_replicas, default_deployment) + return default_deployment, default_replicas, default_par + + interval_size = math.ceil(num_experts / depth) + weight = mu + z_score * np.sqrt(var) + idx = np.argsort(-weight) + + deployed_replicas = np.zeros(num_experts, dtype=np.int32) + deployment = -np.ones((num_devices, num_total_replicas // num_devices), dtype=np.int32) + + def _lpt_deployment(replicas): + nonlocal mu, var, cov, deployment, deployed_replicas, z_score + return lpt_deployment(mu, var, cov, deployment, np.zeros_like(replicas), replicas, z_score) + + def get_score( + f: Any, + val_data: np.ndarray, + deployed_replicas: np.ndarray, + current_idx: np.ndarray, + current_replicas: np.ndarray, + remaind_idx: np.ndarray, + remaind_replicas: np.ndarray, + ) -> tuple[float, np.ndarray]: + """ + Wrapper to compute load balance score for replica allocation simulation + + Args: + f: Deployment function (e.g., _lpt_deployment) + val_data: Validation load data (T, N) + deployed_replicas: Already deployed replica count per expert (N,) + current_idx: Indices of current expert group + current_replicas: Replica count for current expert group + remaind_idx: Indices of remaining expert group + remaind_replicas: Replica count for remaining expert group + + Returns: + score: Load balance score + simulated_deployment: Simulated expert deployment matrix + """ + # Simulate replica allocation and deployment + simulated_replicas = deployed_replicas.copy() + simulated_replicas[current_idx] = current_replicas + simulated_replicas[remaind_idx] = remaind_replicas + simulated_deployment = f(simulated_replicas) + + # Calculate load balance score + score = compute_score(val_data, simulated_replicas, simulated_deployment) + return score, simulated_deployment + + for node in range(depth - 1): + low, high = 0, num_available_replicas + simulation_idx = idx[node * interval_size :] + current_idx = idx[node * interval_size : (node + 1) * interval_size] + remaind_idx = idx[(node + 1) * interval_size :] + + simulation_replicas = make_replica( + mu[simulation_idx], var[simulation_idx], high, np.ones(simulation_idx.shape[0], dtype=np.int32), z_score + )[0] + current_replicas_f = make_replica( + mu[current_idx], var[current_idx], high, np.ones(current_idx.shape[0], dtype=np.int32), z_score + )[1] + remaind_replicas_f = make_replica( + mu[remaind_idx], var[remaind_idx], high, np.ones(remaind_idx.shape[0], dtype=np.int32), z_score + )[1] + + initial_replicas = (simulation_replicas[:interval_size] - 1).sum() + + best_replica, _, _ = self.neighbor_search( + low, + high, + initial_replicas, + width, + lambda mid, + ci=current_idx, + crf=current_replicas_f, + ri=remaind_idx, + rrf=remaind_replicas_f, + nar=num_available_replicas: get_score( + _lpt_deployment, X_row, deployed_replicas, ci, crf[mid], ri, rrf[nar - mid] + ), + ) + + deployed_replicas[current_idx] = current_replicas_f[best_replica] + num_available_replicas -= best_replica + + if not num_available_replicas or node == depth - 2: + deployed_replicas[remaind_idx] = remaind_replicas_f[num_available_replicas] + break + + final_deployment = -np.ones((num_devices, num_total_replicas // num_devices), dtype=np.int32) + final_deployment = lpt_deployment( + mu, var, cov, final_deployment, np.zeros_like(deployed_replicas), deployed_replicas, z_score + ) + final_par = compute_score(X_row, deployed_replicas, final_deployment) + + return final_deployment, deployed_replicas, final_par class FlashLB(EplbPolicy): + """ + Flash Load Balancing (FlashLB) policy for expert deployment optimization + Implements layered tree search with load balance score optimization + """ + def __init__(self, config: DynamicConfig): + """ + Initialize FlashLB policy with dynamic configuration + + Args: + config: Dynamic configuration object containing policy parameters + """ super().__init__(config) - self.par_history: dict[int, float] = {} - self.hotness_window: dict[int, deque[float]] = {} - self.max_stage_window = config.max_stage_window if hasattr(config, "max_stage_window") else 1 - self.buffer_expert_layer_num = ( - config.buffer_expert_layer_num if hasattr(config, "buffer_expert_layer_num") else 58 - ) - self.threshold_ratio = config.threshold_ratio if hasattr(config, "threshold_ratio") else 0 - - def compute_expert_hotness(self, num_of_expert: int, deployment: np.ndarray, rank_load: np.ndarray): - hotness = np.zeros(num_of_expert, dtype=rank_load.dtype) - deployment_flat = deployment.ravel() - rank_load_flat = rank_load.ravel() - np.add.at(hotness, deployment_flat, rank_load_flat) - return hotness - - def compute_rank_load(self, deployment: np.ndarray, hotness: np.ndarray): - n_stage, N = hotness.shape - if np.any(deployment < 0): - raise ValueError("Deployment table contains negative values.") - counts = np.bincount(deployment.reshape(-1), minlength=N) - unit_hotness = np.divide(hotness, counts, out=np.zeros_like(hotness, dtype=float), where=counts != 0) - stage_par = np.zeros(n_stage) - for i in range(n_stage): - stage_load = unit_hotness[i][deployment].sum(-1) - stage_par[i] = stage_load.max() / stage_load.mean() - return stage_par.mean() - - def group_based_adaptive_bloating(self, X, P, M, stage_weights=None, recorsive=False): - n_stage, N = X.shape - if stage_weights is None: - stage_weights = np.ones(n_stage, dtype=np.float32) - - if recorsive: - ( - simulated_deployment, - simulated_pieces, - ) = self.group_based_adaptive_bloating(X, P, M, stage_weights, recorsive=False) + if config.ep_worldsize >= 32: + threshold_ratio = 0.9 + threshold_value = 0.85 else: - simulated_pieces = compute_piece_counts(X, P, stage_weights) - simulated_deployment = jsq_placement(X, simulated_pieces, M, stage_weights) + threshold_ratio = 0.95 + threshold_value = 0.9 - pieces = group_based_adaptive_bloating_kernel( - X.astype(np.float32), - P, - M, - simulated_pieces.astype(np.int32), - simulated_deployment.astype(np.int32), - stage_weights.astype(np.float32), + # Max window size for expert hotness observation + self.max_observation_window = config.max_stage_window if hasattr(config, "max_stage_window") else 2000 + # Threshold ratio for load balance update trigger + self.update_threshold_ratio = config.threshold_ratio if hasattr(config, "threshold_ratio") else threshold_ratio + # Threshold value for load balance update trigger + self.update_threshold_value = config.threshold_value if hasattr(config, "threshold_value") else threshold_value + # Upper bound of layers to update per iteration + self.update_layers_upper_bound = config.layers_upper_bound if hasattr(config, "layers_upper_bound") else -1 + # Z-score for risk calculation (default: 75% confidence) + self.z_score = config.z_score if hasattr(config, "z_score") else stats.norm.ppf(0.75) + # Tree search depth for flash_tree algorithm + self.depth = config.depth if hasattr(config, "depth") else 4 + # Tree search width for neighbor search + self.width = config.width if hasattr(config, "width") else 8 + self.sample_size = ( + config.sample_size if hasattr(config, "sample_size") else min(self.max_observation_window, 64) ) - deployment = jsq_placement(X, pieces, M, stage_weights) + # Runtime state storage with type annotations + self.average_to_peak_history: dict[int, float] = {} # Layer-wise load balance history + self.hotness_window: dict[int, dict[str, Any]] = {} # Layer-wise hotness stats and buffer + self.current_deployment: dict[int, np.ndarray] = {} # Current expert deployment per layer + self.current_deployed_replicas: dict[int, np.ndarray] = {} # Current replica count per expert per layer - X_all = X.sum(0) - unit_load = np.divide(X_all, pieces, out=np.zeros_like(X_all, dtype=float), where=pieces != 0) - load = unit_load[deployment].sum(-1) + def min_max_replica( + self, mu: np.ndarray, var: np.ndarray, num_available_replicas: int, current_replicas: np.ndarray, z_score: float + ) -> tuple[np.ndarray, np.ndarray]: + """ + Wrapper for original min-max replica allocation - sim_unit_load = X_all / simulated_pieces - sim_load = sim_unit_load[simulated_deployment].sum(-1) + Args: + mu: Mean load of each expert (N,) + var: Variance of each expert's load (N,) + num_available_replicas: Total available replicas to allocate + current_replicas: Initial replica count per expert (N,) + z_score: Z-score for risk calculation (confidence level) - if load.max() > sim_load.max(): - return simulated_deployment, simulated_pieces - return deployment, pieces + Returns: + current_replicas: Updated replica count per expert (N,) + replicas_history: Replica allocation history (num_available_replicas+1, N) + """ + return min_max_replica(mu, var, num_available_replicas, current_replicas, z_score) - def need_update(self, current_par, layer_id=0): - threshold = self.par_history.get(layer_id, 0.0) - return current_par >= self.threshold_ratio * threshold + @staticmethod + def compute_statistics(X: np.ndarray) -> tuple[np.ndarray, np.ndarray, np.ndarray]: + """ + Compute mean, variance and covariance matrix from time series data - def compute_stage_weight(self, hotness): - n_stage = hotness.shape[0] - stage_weights = np.zeros(n_stage) - for i in range(n_stage): - stage_weights[i] = hotness[i].sum() + Args: + X: Time series data (T, N) - T steps, N experts - stage_weights = stage_weights / stage_weights.max() - return stage_weights + Returns: + mean: Mean load per expert (N,) + variance: Variance of load per expert (N,) + cov_matrix: Covariance matrix between experts (N,N) + """ + T, N = X.shape + mean_ = np.mean(X, axis=0) + if T > 1: + X_centered = X - mean_ + variance_ = np.sum(X_centered**2, axis=0) / (T - 1) + cov_matrix = (X_centered.T @ X_centered) / (T - 1) + else: + # Zero stats if only one sample + variance_ = np.zeros((N,)) + cov_matrix = np.zeros((N, N)) + return mean_, variance_, cov_matrix - def rebalance_layer(self, deployment, hotness, layer_id=0): - num_rank, expert_per_rank = deployment.shape - num_expert = np.unique(deployment.reshape(-1)).shape[0] - num_of_redundant_expert = num_rank * expert_per_rank - num_expert + @staticmethod + def sliding_update_stats( + mean: np.ndarray, cov: np.ndarray, x_old: np.ndarray, x_new: np.ndarray, T: int + ) -> tuple[np.ndarray, np.ndarray, np.ndarray]: + """ + Update statistics with sliding window (replace old data with new data) - current_par = self.compute_rank_load(deployment, hotness) + Args: + mean: Current mean statistics (N,) + cov: Current covariance matrix (N,N) + x_old: Old data batch to remove (t, N) + x_new: New data batch to add (t, N) + T: Window size - if not self.need_update(current_par, layer_id): - return deployment, current_par, current_par + Returns: + new_mean: Updated mean (N,) + new_var: Updated variance (N,) + new_cov: Updated covariance matrix (N,N) + """ + assert x_new.shape == x_old.shape + mean = mean.astype(np.float64, copy=False) + cov = cov.astype(np.float64, copy=False) + x_old = x_old.astype(np.float64, copy=False) + x_new = x_new.astype(np.float64, copy=False) - stage_weights = self.compute_stage_weight(hotness) - new_deployment, _ = self.group_based_adaptive_bloating( - hotness, - num_expert + num_of_redundant_expert, - num_rank, - stage_weights, - recorsive=False, - ) - new_par = self.compute_rank_load(new_deployment, hotness) + # Update mean + sum_old = np.sum(x_old, axis=0) + sum_new = np.sum(x_new, axis=0) + deltaS = sum_new - sum_old + new_mean = mean + deltaS / T - return new_deployment, new_par, current_par + # Update covariance matrix + x_old_centered = x_old - mean + x_new_centered = x_new - mean + + SA_mu = np.dot(x_old_centered.T, x_old_centered) + SB_mu = np.dot(x_new_centered.T, x_new_centered) + + Sigma = cov * (T - 1) + Sigma_new = Sigma + SB_mu - SA_mu - np.outer(deltaS, deltaS) / T + new_cov = Sigma_new / (T - 1) + + new_var = np.diag(new_cov) + return new_mean, new_var, new_cov + + @staticmethod + def incremental_update_stats( + mean: np.ndarray, cov: np.ndarray, x_new: np.ndarray, T: int + ) -> tuple[np.ndarray, np.ndarray, np.ndarray, int]: + """ + Incrementally update statistics with new data (expand window) + + Args: + mean: Current mean statistics (N,) + cov: Current covariance matrix (N,N) + x_new: New data batch to add (t, N) + T: Current window size + + Returns: + new_mean: Updated mean (N,) + new_var: Updated variance (N,) + new_cov: Updated covariance matrix (N,N) + new_T: Updated window size + """ + t, N = x_new.shape + sum_new = np.sum(x_new, axis=0) + new_T = T + t + + # Update mean + new_mean = (T * mean + sum_new) / new_T + + # Update covariance matrix + if T > 1: + x_new_centered = x_new - new_mean + cov_new = cov * (T - 1) + cov_new += np.dot(x_new_centered.T, x_new_centered) + cov_new += T * np.outer(mean - new_mean, mean - new_mean) + new_cov = cov_new / (new_T - 1) + else: + # Special case for initial single sample + x_old = mean.reshape(1, -1) + x_old_centered = x_old - new_mean + x_new_centered = x_new - new_mean + sum_squares = np.dot(x_old_centered.T, x_old_centered) + np.dot(x_new_centered.T, x_new_centered) + new_cov = sum_squares / (new_T - 1) + + new_var = np.diag(new_cov) + return new_mean, new_var, new_cov, new_T + + def register_hotness( + self, deployment: np.ndarray, rank_load: np.ndarray, num_layers: int, num_experts: int + ) -> None: + """ + Update expert hotness statistics with sliding window for all layers + + Args: + deployment: Expert deployment matrix (num_layers, num_devices, num_slots) + rank_load: Load data (num_stages, num_layers, num_devices, num_slots) + num_layers: Total number of layers + num_experts: Total number of experts + """ + num_stage = rank_load.shape[0] + hotness = np.zeros((num_stage, num_layers, num_experts), dtype=rank_load.dtype) + for stage in range(num_stage): + for layer in range(num_layers): + deployment_flat = deployment[layer].ravel() + rank_load_flat = rank_load[stage, layer].ravel() + np.add.at(hotness[stage, layer], deployment_flat, rank_load_flat) + + hotness += 1 + window_length = self.max_observation_window + + for layer in range(num_layers): + new_X = hotness[-window_length:, layer, :] + t = new_X.shape[0] - def register_hotness(self, deployment, rank_load, num_layer, num_expert): - for layer in range(num_layer): if layer not in self.hotness_window: - self.hotness_window[layer] = deque(maxlen=self.max_stage_window) - hotness = self.compute_expert_hotness(num_expert, deployment[layer], rank_load[layer]) - self.hotness_window[layer].append(hotness) + self.hotness_window[layer] = { + "buffer": np.zeros((window_length, num_experts), dtype=new_X.dtype), + "start": 0, + "length": 0, + } - def compress_by_avg_pooling_fast_nd(self, arr, m): - n, d = arr.shape - idx = np.arange(n) * m // n - result = np.zeros((m, d)) - counts = np.zeros((m, 1)) - np.add.at(result, idx, arr) - np.add.at(counts, idx, 1) - return result / counts + info = self.hotness_window[layer] + buf = info["buffer"] + start = info["start"] + length = info["length"] - def rebalance_experts(self, current_expert_table, expert_workload): + if start + t <= window_length: + buf[start : start + t] = new_X + else: + first_part = window_length - start + buf[start:] = new_X[:first_part] + buf[: t - first_part] = new_X[first_part:] + + start = (start + t) % window_length + length = min(window_length, length + t) + + self.hotness_window[layer]["buffer"] = buf + self.hotness_window[layer]["start"] = start + self.hotness_window[layer]["length"] = length + + def need_update(self, layer_id: int = 0) -> bool: + """ + Check if layer needs load balance update + Trigger update if load balance ratio drops below threshold + + Args: + layer_id: Layer index to check + + Returns: + bool: True if update is needed, False otherwise + """ + past_average_to_peak_ratio = self.average_to_peak_history.get(layer_id, 0.0) + if past_average_to_peak_ratio == 0.0: + # Force update for first iteration + return True + + # Calculate current load balance ratio (average/peak load) + hotness = self.hotness_window[layer_id]["buffer"] + average_to_peak_ratio = 1 / compute_score( + hotness, self.current_deployed_replicas[layer_id], self.current_deployment[layer_id] + ) + + # Check update conditions + return ( + average_to_peak_ratio < past_average_to_peak_ratio * self.update_threshold_ratio + or average_to_peak_ratio < self.update_threshold_value + ) + + @staticmethod + @njit + def compute_match(src_counts: np.ndarray, dst_counts: np.ndarray, N: int, M: int) -> np.ndarray: + """ + Compute match matrix between source and destination expert counts + match[i,j] = total min(src_counts[i,k], dst_counts[j,k]) for all k + + Args: + src_counts: Source expert count histogram (N, max_val) + dst_counts: Destination expert count histogram (N, max_val) + N: Number of rows in deployment matrix + M: Number of columns in deployment matrix + + Returns: + matches: Match matrix (N, N) + """ + matches = np.zeros((N, N), dtype=np.int32) + for i in range(N): + for j in range(N): + match = 0 + for k in range(N * M): + match += min(src_counts[i, k], dst_counts[j, k]) + matches[i, j] = match + return matches + + @staticmethod + def minimize_redeploy_with_inner_permutation(src: np.ndarray, dst: np.ndarray) -> np.ndarray: + """ + Minimize expert redeployment by permuting destination rows/columns + Aligns destination deployment with source to reduce expert movement + + Args: + src: Source deployment matrix (N, M) + dst: Destination deployment matrix (N, M) + + Returns: + dst_reordered: Reordered destination deployment matrix + """ + if src.shape != dst.shape: + raise ValueError("src and dst must have same shape (N, M)") + N, M = src.shape + valid_src = src + valid_dst = dst + + # Calculate expert count histogram for each row + max_val = N * M + src_counts = np.array([np.bincount(row[row != -1], minlength=max_val) for row in valid_src], dtype=np.int32) + dst_counts = np.array([np.bincount(row[row != -1], minlength=max_val) for row in valid_dst], dtype=np.int32) + + # Compute match matrix and optimal row mapping (Hungarian algorithm) + matches = FlashLB.compute_match(src_counts, dst_counts, N, M) + cost = M - matches + row_ind, col_ind = linear_sum_assignment(cost) + mapping = list(zip(row_ind.tolist(), col_ind.tolist())) + + # Reorder dst rows and columns to align with src + dst_reordered = np.empty_like(dst) + for src_idx, dst_idx in mapping: + s_row = src[src_idx] + d_row = dst[dst_idx] + # Map expert values to their positions in dst row + val_to_positions: dict[int, list[int]] = {} # Add type annotation for mypy + for pos, v in enumerate(d_row): + val_to_positions.setdefault(v, []).append(pos) + + reordered = np.empty(M, dtype=dst.dtype) + assigned = [False] * M + used_dst_positions = set() + + # Assign existing experts first to minimize movement + for pos_src, v in enumerate(s_row): + positions = val_to_positions.get(v) + if positions: + dst_pos = positions.pop() + reordered[pos_src] = v + assigned[pos_src] = True + used_dst_positions.add(dst_pos) + + # Fill remaining positions with unassigned experts + remaining = [d_row[p] for p in range(M) if p not in used_dst_positions] + ri = 0 + for pos in range(M): + if not assigned[pos]: + reordered[pos] = remaining[ri] + ri += 1 + dst_reordered[src_idx] = reordered + return dst_reordered + + def rebalance_experts( + self, current_expert_table: torch.Tensor, expert_workload: torch.Tensor + ) -> tuple[bool, np.ndarray, np.ndarray]: + """ + Main expert rebalance entry point + Optimizes expert deployment to improve load balance + + Args: + current_expert_table: Current expert deployment (layers, devices, slots) + expert_workload: Expert load data (stages, layers, devices, slots) + + Returns: + change: True if any layers were updated + priority_idx: Indices of updated layers (sorted by improvement) + new_deployment: Updated expert deployment matrix + """ current_deployment = np.array(current_expert_table) expert_workload = np.array(expert_workload) - expert_workload += 1 - num_layer = expert_workload.shape[0] - num_expert = np.unique(current_expert_table[0].reshape(-1)).shape[0] - self.register_hotness(current_deployment, expert_workload, num_layer, num_expert) - new_deployment = current_deployment.copy() + # Add batch dimension if missing + if expert_workload.ndim == 3: + expert_workload = expert_workload[np.newaxis, ...] + num_layers = expert_workload.shape[1] + num_expert = np.unique(current_deployment[0].reshape(-1)).shape[0] + num_devices = current_deployment.shape[1] + num_replicas = len(current_deployment[0].reshape(-1)) - layers_need_update = np.arange(num_layer) + # Update expert hotness statistics + self.register_hotness(current_deployment, expert_workload, num_layers, num_expert) - new_par = np.zeros(layers_need_update.shape[0]) - current_par = np.zeros(layers_need_update.shape[0]) - for i, layer in enumerate(layers_need_update): - hotness = np.array(self.hotness_window[layer]) - if hotness.shape[0] > self.max_stage_window: - hotness = self.compress_by_avg_pooling_fast_nd(hotness, self.max_stage_window) + # Initialize current deployment state for all layers + for layer in range(num_layers): + self.current_deployment[layer] = current_deployment[layer] + self.current_deployed_replicas[layer] = np.bincount( + current_deployment[layer].reshape(-1), minlength=num_expert + ) - ( - new_deployment[layer], - new_par[i], - current_par[i], - ) = self.rebalance_layer(current_deployment[layer], hotness, layer_id=layer) + # Initialize output variables + new_par = np.zeros((num_layers,), dtype=np.float32) + new_deployment = np.zeros((num_layers, num_devices, num_replicas // num_devices), dtype=np.int32) + new_deployed_replicas = np.zeros((num_layers, num_expert), dtype=np.int32) + new_average_to_peak_ratio = np.zeros((num_layers,), dtype=np.float32) + delta_average_to_peak_ratio = np.zeros((num_layers,), dtype=np.float32) + pars = np.zeros((num_layers,), dtype=np.float32) - priority = new_par / current_par - priority_idx = np.argsort(priority) - priority_idx = priority_idx[priority[priority_idx] < 1][: self.buffer_expert_layer_num] + # Optimize each layer + for layer in range(num_layers): + if not self.need_update(layer): + # Keep current deployment if no update needed + new_deployment[layer] = self.current_deployment[layer] + new_deployed_replicas[layer] = self.current_deployed_replicas[layer] + new_average_to_peak_ratio[layer] = self.average_to_peak_history.get(layer, 0.0) + new_par[layer] = 1 / new_average_to_peak_ratio[layer] if new_average_to_peak_ratio[layer] != 0 else 0.0 + delta_average_to_peak_ratio[layer] = 0 + continue - if np.all(expert_workload == 1): - for _, layer in enumerate(layers_need_update): - self.hotness_window[layer].pop() - return False, np.array([], dtype=int), current_deployment + # Get layer hotness stats + layer_info = self.hotness_window[layer] + buf = layer_info["buffer"] + start = layer_info["start"] + length = layer_info["length"] + + # Get valid hotness data from sliding window + idx = np.arange(start, start + length) % self.max_observation_window + data = buf[idx] + + shape = data.shape + window = max(length // self.sample_size, 1) + data = data[-window * self.sample_size :].reshape((-1, window, *shape[1:])).sum(1) + # Flash tree search for optimal deployment + flash_tree = FlashTree(data, num_replicas, num_devices, self.z_score, self.depth, self.width) + best_deployment, best_replicas, best_score = flash_tree.optimize_balanceness() + + # Update layer state + new_deployed_replicas[layer] = best_replicas + new_average_to_peak_ratio[layer] = 1 / best_score + + current_deployment = self.current_deployment.get(layer, None) + + new_deployment[layer] = best_deployment + # Minimize redeployment by permuting new deployment + new_deployment[layer] = FlashLB.minimize_redeploy_with_inner_permutation( + current_deployment, best_deployment + ) + current_average_to_peak_ratio = 1 / compute_score( + buf, self.current_deployed_replicas.get(layer), current_deployment + ) + delta_average_to_peak_ratio[layer] = new_average_to_peak_ratio[layer] - current_average_to_peak_ratio + pars[layer] = best_score + + # Select layers to update (sorted by improvement, positive delta only) + priority_idx = np.argsort(-delta_average_to_peak_ratio) + priority_idx = priority_idx[delta_average_to_peak_ratio[priority_idx] > 0] + # Apply upper bound of layers to update + if self.update_layers_upper_bound > 0: + priority_idx = priority_idx[: self.update_layers_upper_bound] + + # Update global state with optimal deployments + for layer in priority_idx: + self.current_deployment[layer] = new_deployment[layer] + self.current_deployed_replicas[layer] = new_deployed_replicas[layer] + self.average_to_peak_history[layer] = new_average_to_peak_ratio[layer] + + # Return update flag and results change = len(priority_idx) > 0 - if change: - for idx in priority_idx: - self.par_history[layers_need_update[idx]] = new_par[idx] - - layers_need_update = priority_idx - deployment = current_deployment - for layer in layers_need_update: - deployment[layer] = auto_fix_new_placement(current_deployment[layer], new_deployment[layer]) - - return change, layers_need_update, deployment + return change, priority_idx, new_deployment -def generate_layered_experts(num_layers=58, layer_shape=(32, 9), expert_min=0, expert_max=255): +def generate_layered_experts( + num_layers: int = 58, layer_shape: tuple[int, int] = (32, 9), expert_min: int = 0, expert_max: int = 255 +) -> torch.Tensor: """ - Generate expert deployment matrix meeting the following conditions: - - Total of num_layers layers - - Each layer has shape layer_shape (32,9) - - Each expert from expert_min to expert_max (0 to 255) appears at least once in each layer + Generate layered expert deployment matrix + Each layer contains all experts [expert_min, expert_max] at least once + Remaining slots filled with random experts Args: - num_layers: Number of layers, default 58 - layer_shape: Shape of a single layer, default (32,9) - expert_min: Minimum expert ID, default 0 - expert_max: Maximum expert ID, default 255 - Returns: - torch.Tensor: Tensor with shape (num_layers, layer_shape[0], layer_shape[1]) - """ - # 1. Basic parameter calculation - expert_num = expert_max - expert_min + 1 # Total number of experts: 256 (0~255) - layer_total = layer_shape[0] * layer_shape[1] # Total elements in a single layer: 32*9=288 - extra_slots = layer_total - expert_num # Number of random positions to fill per layer: 288-256=32 + num_layers: Number of layers to generate + layer_shape: Shape of each layer's deployment matrix (rows, cols) + expert_min: Minimum expert ID (inclusive) + expert_max: Maximum expert ID (inclusive) - # 2. Verify feasibility (total elements must be ≥ number of experts to cover all experts) + Returns: + layers: Layered expert deployment tensor (num_layers, *layer_shape) + """ + # Basic parameter calculation + expert_num = expert_max - expert_min + 1 + layer_total = layer_shape[0] * layer_shape[1] + extra_slots = layer_total - expert_num + + # Feasibility check: layer capacity ≥ expert count assert layer_total >= expert_num, ( - f"Number of elements in a single layer {layer_total} < number of experts {expert_num}, cannot cover all experts" + f"Layer element count {layer_total} < expert count {expert_num}, cannot cover all experts" ) - # 3. Generate layers one by one - layers = [] + # Generate each layer + layers: list[torch.Tensor] = [] for _ in range(num_layers): - # 3.1 Generate "complete expert sequence" (ensure each expert from 0 to 255 is included) - full_experts = torch.arange(expert_min, expert_max + 1, dtype=torch.int64) # shape (256,) + # Full expert sequence (cover all experts once) + full_experts = torch.arange(expert_min, expert_max + 1, dtype=torch.int64) # (expert_num,) - # 3.2 Generate "supplementary random experts" (fill remaining 32 positions, randomly selected from 0~255) - extra_experts = torch.randint(expert_min, expert_max + 1, size=(extra_slots,), dtype=torch.int64) # shape (32,) + # Random extra experts for remaining slots + extra_experts = torch.randint( + expert_min, expert_max + 1, size=(extra_slots,), dtype=torch.int64 + ) # (extra_slots,) - # 3.3 Concatenate and shuffle (ensure random distribution of experts in each layer) - layer_flat = torch.cat([full_experts, extra_experts], dim=0) # shape (288,) - # Shuffle order (use randperm to generate random indices to avoid repeated shuffling issues) + # Concatenate and shuffle for random distribution + layer_flat = torch.cat([full_experts, extra_experts], dim=0) # (layer_total,) shuffle_idx = torch.randperm(layer_flat.shape[0]) layer_shuffled = layer_flat[shuffle_idx] - # 3.4 Reshape to layer_shape (32,9) + # Reshape to target layer shape layer = layer_shuffled.reshape(layer_shape) layers.append(layer) - # 4. Stack all layers to get the final tensor - return torch.stack(layers, dim=0) # shape (58,32,9) + # Stack all layers + return torch.stack(layers, dim=0) # (num_layers, *layer_shape) -def warm_up(): +def warm_up() -> None: + """ + Warm up FlashLB algorithm with dummy data + Pre-compiles numba functions and initializes state + """ exam_config = DynamicConfig() exam_config.ep_worldsize = 32 exam_config.num_die_per_host = 16 algo = FlashLB(exam_config) - # Generate target tensor + # Generate dummy expert deployment tensor expert_tensor = generate_layered_experts(num_layers=58, layer_shape=(32, 9)) - - algo.rebalance_experts(expert_tensor, torch.randint(1, 1000, (58, 32, 9))) + # Run rebalance with dummy workload data + algo.rebalance_experts(expert_tensor, torch.randint(1, 1000, (100, 58, 32, 9))) diff --git a/vllm_ascend/eplb/eplb_updator.py b/vllm_ascend/eplb/eplb_updator.py index 4d3da725..fa06737d 100644 --- a/vllm_ascend/eplb/eplb_updator.py +++ b/vllm_ascend/eplb/eplb_updator.py @@ -30,6 +30,7 @@ from vllm_ascend.eplb.core.eplb_worker import EplbProcess class EplbUpdator: def __init__(self, eplb_config, loader: D2DExpertWeightLoader, eplb_process: EplbProcess, process): self.eplb_config = eplb_config + self.multi_stage = eplb_config.eplb_policy_type == 3 self.init_eplb(self.eplb_config.expert_map_path, process) self.eplb_loader = loader self.eplb_process = eplb_process @@ -131,9 +132,14 @@ class EplbUpdator: def compute_and_set_moe_load(self): local_load = self.adaptor.get_rank_expert_workload() - moe_load = self.comm_group.all_gather(local_load, dim=0).reshape(-1, self.world_size, *local_load.shape[1:]) + moe_load = ( + self.comm_group.all_gather(local_load, dim=0).reshape(-1, self.world_size, *local_load.shape[1:]).cpu() + ) - self.shared_dict["moe_load"] = moe_load.cpu() + if self.multi_stage: + moe_load = moe_load.permute(2, 0, 1, 3) + + self.shared_dict["moe_load"] = moe_load logger.debug(f"[ModelRunner] Updated shared_dict['moe_load'] shape={moe_load.shape}") return moe_load diff --git a/vllm_ascend/ops/fused_moe/fused_moe.py b/vllm_ascend/ops/fused_moe/fused_moe.py index d5df0ad8..4304931b 100644 --- a/vllm_ascend/ops/fused_moe/fused_moe.py +++ b/vllm_ascend/ops/fused_moe/fused_moe.py @@ -299,7 +299,13 @@ class AscendFusedMoE(FusedMoE): get_compressed_expert_map(self._expert_map), ) if self.dynamic_eplb: + self.multi_stage = False self.moe_load = torch.zeros(self.local_num_experts, dtype=torch.int64).npu() + if eplb_config.eplb_policy_type == 3: + self.multi_stage = True + self.load_counter = torch.tensor(0, dtype=torch.int32, device="npu") + self.num_iter = eplb_config.expert_heat_collection_interval + self.moe_load = torch.zeros((self.num_iter, self.local_num_experts), dtype=torch.int32, device="npu") self.moe_config.num_experts = self.global_num_experts self.moe_config.num_local_experts = self.local_num_experts @@ -361,6 +367,8 @@ class AscendFusedMoE(FusedMoE): def clear_moe_load(self): if self.moe_load is not None: self.moe_load.zero_() + if self.multi_stage: + self.load_counter.zero_() def maybe_all_reduce_tensor_model_parallel(self, final_hidden_states: torch.Tensor): """NOTE(Yizhou): This is to override the parent class method. In `mc2commimpl`, @@ -493,7 +501,14 @@ class AscendFusedMoE(FusedMoE): if group_list_type == 1 else torch.cat([expert_tokens[:1], expert_tokens[1:] - expert_tokens[:-1]]) ) - self.moe_load.add_(local_load) + if self.multi_stage: + cur_iter = torch.remainder(self.load_counter, self.num_iter) + self.moe_load.index_add_( + dim=0, index=cur_iter, source=local_load.to(torch.int32, non_blocking=True).view(1, -1) + ) + self.load_counter.add_(1) + else: + self.moe_load.add_(local_load) routed_out = forward_context.moe_comm_method.finalize( hidden_states=fused_experts_results.routed_out, reduce_results=self.reduce_results,