diff --git a/vllm_ascend/eplb/core/eplb_worker.py b/vllm_ascend/eplb/core/eplb_worker.py index dc32ce24..b9a62a0d 100644 --- a/vllm_ascend/eplb/core/eplb_worker.py +++ b/vllm_ascend/eplb/core/eplb_worker.py @@ -34,6 +34,7 @@ class EplbWorker: self.old_expert_maps = None self.enable_d2d = enable_d2d self.rank_id = dist.get_rank() + self.multi_stage = policy_type == 3 def do_update(self): # put data in to queue @@ -62,7 +63,10 @@ class EplbWorker: _, _, new_placement = self.calculate_rebalance_experts(load_info, old_placement) if self.rank_id == 0: - hotness = self._calculate_hotness(old_placement, load_info) + if self.multi_stage: + hotness = self._calculate_hotness(old_placement, load_info.sum(0)) + else: + hotness = self._calculate_hotness(old_placement, load_info) current_mean, current_max = self._compute_imbalance(old_placement, hotness) update_mean, update_max = self._compute_imbalance(new_placement, hotness) logger.info( diff --git a/vllm_ascend/eplb/core/policy/policy_flashlb.py b/vllm_ascend/eplb/core/policy/policy_flashlb.py index 117b8ed6..a33f1d39 100644 --- a/vllm_ascend/eplb/core/policy/policy_flashlb.py +++ b/vllm_ascend/eplb/core/policy/policy_flashlb.py @@ -1,597 +1,1022 @@ # Copyright Huawei Technologies Co., Ltd. 2024-2025. All rights reserved. -# Todo: Once https://github.com/vllm-project/vllm/pull/24069 is merged in vllm. Remove this policy. +# Todo: Remove this policy after vllm-project/vllm/pull/24069 is merged into vllm. import logging -from collections import deque +import math +from typing import Any import numpy as np import torch from numba import njit # type: ignore +from scipy import stats # type: ignore +from scipy.optimize import linear_sum_assignment # type: ignore from .policy_abstract import DynamicConfig, EplbPolicy +# Suppress numba log warnings numba_logger = logging.getLogger("numba") numba_logger.setLevel(logging.WARNING) -@njit -def compute_piece_counts(X, P, stage_weights): - n_stage, N = X.shape - S = P - N - pieces = np.ones(N, dtype=np.int32) - unit = X / pieces # unit[i, j] = X[i, j] / pieces[j] - - for _ in range(S): - deltas = np.zeros(N, dtype=np.float32) - for i in range(n_stage): - # Find top1 and top2 - idx1 = -1 - idx2 = -1 - val1 = -1.0 - val2 = -1.0 - for j in range(N): - v = unit[i, j] - if v > val1: - val2 = val1 - idx2 = idx1 - val1 = v - idx1 = j - elif v > val2: - val2 = v - idx2 = j - - origin = unit[i, idx1] - secv = unit[i, idx2] - alt = X[i, idx1] / (pieces[idx1] + 1) - delta = origin - (alt if alt > secv else secv) - deltas[idx1] += delta * stage_weights[i] if np.any(delta) != 0 else stage_weights[i] - - max_idx = np.argmax(deltas) - pieces[max_idx] += 1 - for i in range(n_stage): - unit[i, max_idx] = X[i, max_idx] / pieces[max_idx] - - # Compute max load - max_load = 0.0 - for j in range(N): - total = 0.0 - for i in range(n_stage): - total += unit[i, j] - if total > max_load: - max_load = total - - return pieces - - -@njit -def jsq_placement(X, pieces, M, stage_weights): - n_stage, N = X.shape - total_piece = pieces.sum() - num_per_group = total_piece // M - - # 1. Compute unit_hotness - unit_hotness = np.empty((n_stage, N), dtype=np.float32) - for i in range(N): - if pieces[i] > 0: - for s in range(n_stage): - unit_hotness[s, i] = X[s, i] / pieces[i] - else: - for s in range(n_stage): - unit_hotness[s, i] = 0.0 - - # 2. Sort by total hotness - scores = np.zeros(N, dtype=np.float32) - for i in range(N): - for s in range(n_stage): - scores[i] += unit_hotness[s, i] - idx = np.argsort(-scores) - - # 3. Initialization - loads = np.zeros((n_stage, M), dtype=np.float32) - dev_phy_exp_n = np.zeros(M, dtype=np.int32) - deployment = -np.ones((M, num_per_group), dtype=np.int32) - dep_ptr = np.zeros(M, dtype=np.int32) - - # 4. Main loop - for t in range(N): - i = idx[t] - used_device = list() - for _ in range(pieces[i]): - # 4.1 Construct w vector - w = np.empty(n_stage, dtype=np.float32) - for s in range(n_stage): - w[s] = unit_hotness[s, i] - - # 4.2 Compute stage-level maximum load - stage_max = np.empty(n_stage, dtype=np.float32) - for s in range(n_stage): - max_val = loads[s, 0] - for k in range(1, M): - if loads[s, k] > max_val: - max_val = loads[s, k] - stage_max[s] = max_val - - # 4.3 Compute denominator - denom = np.empty(n_stage, dtype=np.float32) - for s in range(n_stage): - sum_tmp = 0.0 - for j in range(M): - sum_tmp += loads[s, j] + w[s] - denom[s] = sum_tmp / M + 1e-2 - - # 4.4 Find best device j - best_j = -1 - best_val = 1e30 - for j in range(M): - if dev_phy_exp_n[j] >= num_per_group: - continue - if j in used_device: - continue - score = 0.0 - for s in range(n_stage): - tmp_sj = loads[s, j] + w[s] - number_sj = tmp_sj if tmp_sj > stage_max[s] else stage_max[s] - score += stage_weights[s] * (number_sj / denom[s]) - if score < best_val: - best_val = score - best_j = j - if best_j == -1: - continue - - used_device.append(best_j) - - # 4.5 Update status - for s in range(n_stage): - loads[s, best_j] += w[s] - ptr = dep_ptr[best_j] - deployment[best_j, ptr] = i - dep_ptr[best_j] += 1 - dev_phy_exp_n[best_j] += 1 - - # Handle remaining -1 values: fill with random elements from range(N) not in current column - for rank in range(M): - for col in range(num_per_group): - if deployment[rank, col] == -1: - # Get elements already in current column - current_rank_elements = set(deployment[rank, :]) - # Filter elements from range(N) not in current column - available = [x for x in range(N) if x not in current_rank_elements] - # Randomly select an available element to fill - if len(available) > 0: - rand_idx = np.random.randint(0, len(available)) - deployment[rank, col] = available[rand_idx] - elif N > 0: - # All unique experts are already in this rank's column, so we can pick any expert randomly. - deployment[rank, col] = np.random.randint(0, N) - - return deployment - - -@njit -def slice_values(X, pieces): - total_len = 0 - for i in range(X.shape[0]): - total_len += pieces[i] - result = np.empty(total_len, dtype=np.float32) - idx = 0 - for i in range(X.shape[0]): - val = X[i] / pieces[i] - for _ in range(pieces[i]): - result[idx] = val - idx += 1 - return result - - -@njit -def group_based_adaptive_bloating_kernel(X, P, M, simulated_pieces, simulated_deployment, stage_weights): - n_stage, N = X.shape - num_group = P // M - - X_all = np.zeros(N, dtype=np.float32) - for i in range(n_stage): - for j in range(N): - X_all[j] += X[i, j] - - sort_idx = np.argsort(np.negative(X_all)) - X_sorted = X[:, sort_idx] - - unit_load = np.empty(N, dtype=np.float32) - for j in range(N): - unit_load[j] = X_all[j] / simulated_pieces[j] - - flat_deployment = simulated_deployment.reshape(-1) - simulated_load = np.zeros(M, dtype=np.float32) - for i in range(flat_deployment.shape[0]): - simulated_load[i // (flat_deployment.shape[0] // M)] += unit_load[flat_deployment[i]] - - slice_vals = slice_values(X_all, simulated_pieces) - sorted_slices = np.sort(slice_vals)[::-1] - simulated_slopes = (sorted_slices[: -M + 1] - sorted_slices[M - 1 :]) / M - - cumulative_slices_used = np.zeros(N, dtype=np.int32) - acc = 0 - for i in range(N): - acc += simulated_pieces[sort_idx[i]] - cumulative_slices_used[i] = acc - - group_boundary_indices = np.zeros(num_group, dtype=np.int32) - for i in range(1, num_group + 1): - for j in range(N): - if cumulative_slices_used[j] >= i * M: - group_boundary_indices[i - 1] = j - break - - slices_used_per_group = np.zeros(num_group, dtype=np.int32) - slices_used_per_group[0] = group_boundary_indices[0] - for i in range(1, num_group): - slices_used_per_group[i] = group_boundary_indices[i] - group_boundary_indices[i - 1] - slices_used_per_group = M - slices_used_per_group - - loads = np.zeros(M, dtype=np.float32) - pieces = np.zeros(N, dtype=np.int32) - num_remain_slice = P - N - current_idx = 0 - - for g in range(num_group): - window = X_sorted[:, current_idx : current_idx + 2 * M] - low = max(0, current_idx + M - N) - high = min(num_remain_slice, M - 1) - - while (high - low) > 1: - mid = int((high + low) // 2) - keep = M - mid - current_group = window[:, :keep] - current_pieces = compute_piece_counts(current_group, M, stage_weights) - current_pieces = np.maximum(current_pieces, 1) - current_slice = slice_values(current_group.sum(0), current_pieces) - current_slice_sorted = np.sort(current_slice) - current_loads = loads + current_slice_sorted - current_max: np.float32 = np.max(current_loads) - current_min: np.float32 = np.min(current_loads) - current_slope = (current_max - current_min) / M - next_slope: np.float32 = np.max(simulated_slopes[current_idx + keep :]) - - if abs(current_slope) > abs(next_slope): - low = mid - else: - high = mid - - S = high - keep = M - S - current_group = window[:, :keep] - current_pieces = compute_piece_counts(current_group, M, stage_weights) - - for i in range(keep): - pieces[sort_idx[current_idx + i]] = current_pieces[i] - - current_slice = slice_values(current_group.sum(0), current_pieces) - current_slice_sorted = np.sort(current_slice) - loads += current_slice_sorted - loads = np.sort(loads)[::-1] - - current_idx += keep - num_remain_slice -= S - - return pieces - - -@njit -def compute_objective(deployment, X, pieces): - M, P = deployment.shape - loads = np.zeros(M) - - for i in range(M): - for j in range(P): - expert = deployment[i, j] - if pieces[expert] == 0: - continue - loads[i] += X[expert] / pieces[expert] - - mean_load = np.mean(loads) - max_load: np.float32 = np.max(loads) - obj = max_load / mean_load - return obj, loads - - -@njit -def auto_fix_new_placement(old_placement, new_placement): +@njit(fastmath=True, cache=True) +def min_max_replica( + mu: np.ndarray, var: np.ndarray, num_available_replicas: int, current_replicas: np.ndarray, z_score: float +) -> tuple[np.ndarray, np.ndarray]: """ - Adjust the new_placement matrix to ensure elements (including duplicates) that exist in both - old_placement and new_placement remain in their original positions from old_placement. - New elements (unique to new_placement) will fill the remaining empty positions. + Original min-max replica allocation algorithm + Allocates replicas iteratively to expert with maximum unit load value Args: - old_placement: Old deployment matrix with shape (num_ranks, num_experts) - new_placement: New deployment matrix to be fixed, must have the same shape as old_placement + mu: Mean load of each expert (N,) + var: Variance of each expert's load (N,) + num_available_replicas: Total available replicas to allocate + current_replicas: Initial replica count per expert (N,) + z_score: Z-score for risk calculation (confidence level) Returns: - fixed_new: adjusted version of the new_placement matrix + current_replicas: Updated replica count per expert (N,) + replicas_history: Replica allocation history (num_available_replicas+1, N) """ - num_ranks, num_experts = old_placement.shape - fixed_new = np.empty_like(new_placement) + N = mu.shape[0] + unit_value = (mu + z_score * np.sqrt(var)) / current_replicas + replicas_history = np.ones((num_available_replicas + 1, N), dtype=np.int32) + replicas_history[0, :] = current_replicas[:] - max_expert_old = old_placement.max() if num_experts > 0 else 0 - max_expert_new = new_placement.max() if num_experts > 0 else 0 - max_expert = max(max_expert_old, max_expert_new) + # Allocate replicas to expert with maximum unit value iteratively + for r in range(num_available_replicas): + max_idx = -1 + max_value = -1.0 + for idx in range(N): + value = unit_value[idx] + if value > max_value: + max_value = value + max_idx = idx - for rank_id in range(num_ranks): - old_row = old_placement[rank_id] - new_row = new_placement[rank_id] + current_replicas[max_idx] += 1 + unit_value[max_idx] = (mu[max_idx] + z_score * np.sqrt(var[max_idx])) / current_replicas[max_idx] + replicas_history[r + 1, :] = current_replicas[:] - index_array = np.full((max_expert + 1, num_experts), -1, dtype=np.int32) - count_array = np.zeros(max_expert + 1, dtype=np.int32) + return current_replicas, replicas_history - for idx in range(num_experts): - val = old_row[idx] - if val >= 0 and val <= max_expert: - pos = count_array[val] - index_array[val, pos] = idx - count_array[val] += 1 - old_counter = np.zeros(max_expert + 1, dtype=np.int32) - for idx in range(num_experts): - val = old_row[idx] - if val >= 0 and val <= max_expert: - old_counter[val] += 1 +@njit +def max_delta_replica( + mu: np.ndarray, var: np.ndarray, num_available_replicas: int, current_replicas: np.ndarray, z_score: float +) -> tuple[np.ndarray, np.ndarray]: + """ + Maximum delta replica allocation algorithm + Allocates replicas by maximum unit value delta after increment - retain_elements = np.empty(num_experts, dtype=new_placement.dtype) - new_elements = np.empty(num_experts, dtype=new_placement.dtype) - retain_ptr = 0 - new_ptr = 0 + Args: + mu: Mean load of each expert (N,) + var: Variance of each expert's load (N,) + num_available_replicas: Total available replicas to allocate + current_replicas: Initial replica count per expert (N,) + z_score: Z-score for risk calculation (confidence level) - for val in new_row: - if val >= 0 and val <= max_expert and old_counter[val] > 0: - retain_elements[retain_ptr] = val - retain_ptr += 1 - old_counter[val] -= 1 - else: - new_elements[new_ptr] = val - new_ptr += 1 + Returns: + current_replicas: Updated replica count per expert (N,) + replicas_history: Replica allocation history (num_available_replicas+1, N) + """ + N = mu.shape[0] + unit_value = (mu + z_score * np.sqrt(var)) / current_replicas + replicas_history = np.ones((num_available_replicas + 1, N), dtype=np.int32) + replicas_history[0, :] = current_replicas[:] - current_fixed = np.full(num_experts, -1, dtype=new_placement.dtype) + # Allocate replicas by maximum unit value delta after increment + for r in range(num_available_replicas): + max_idx = -1 + max_value = -1.0 + for idx in range(N): + value = unit_value[idx] / (current_replicas[idx] + 1) + if value > max_value: + max_value = value + max_idx = idx - for i in range(retain_ptr): - val = retain_elements[i] - if val >= 0 and val <= max_expert: - pos = count_array[val] - 1 - if pos >= 0: - idx = index_array[val, pos] - current_fixed[idx] = val - count_array[val] -= 1 + current_replicas[max_idx] += 1 + unit_value[max_idx] = (mu[max_idx] + z_score * np.sqrt(var[max_idx])) / current_replicas[max_idx] + replicas_history[r + 1, :] = current_replicas[:] - empty_indices = np.empty(num_experts, dtype=np.int32) - empty_ptr = 0 - for idx in range(num_experts): - if current_fixed[idx] == -1: - empty_indices[empty_ptr] = idx - empty_ptr += 1 + return current_replicas, replicas_history - for i in range(new_ptr): - if i < empty_ptr: - current_fixed[empty_indices[i]] = new_elements[i] - fixed_new[rank_id] = current_fixed +@njit +def percentage_replica( + mu: np.ndarray, var: np.ndarray, num_available_replicas: int, current_replicas: np.ndarray, z_score: float +) -> tuple[np.ndarray, np.ndarray]: + """ + Proportional replica allocation algorithm + Allocates replicas proportionally to expert total load - return fixed_new + Args: + mu: Mean load of each expert (N,) + var: Variance of each expert's load (N,) + num_available_replicas: Total available replicas to allocate + current_replicas: Initial replica count per expert (N,) + z_score: Z-score for risk calculation (confidence level) + + Returns: + current_replicas: Updated replica count per expert (N,) + replicas_history: Replica allocation history (num_available_replicas+1, N) + """ + N = mu.shape[0] + total_load = mu + z_score * np.sqrt(var) + sum_total_load: float = np.sum(total_load) # Add type annotation for mypy + + replicas_history = np.ones((num_available_replicas + 1, N), dtype=np.int32) + replicas_history[0, :] = current_replicas[:] + + # Allocate replicas proportionally to expert total load + for r in range(1, num_available_replicas + 1): + add_slots = np.zeros(N, dtype=np.int32) + + if sum_total_load == 0.0: + # Average allocation if total load is zero + base_add = r // N + extra = r % N + add_slots[:] = base_add + add_slots[:extra] += 1 + else: + # Proportional allocation with remainder compensation + quotas = (total_load / sum_total_load) * r + base_add = np.floor(quotas).astype(np.int32) + add_slots[:] = base_add + remaining = r - np.sum(base_add) + + if remaining > 0: + fractions = quotas - base_add + indices = np.argsort(-fractions) + add_slots[indices[:remaining]] += 1 + + replicas_history[r] = current_replicas + add_slots + + return replicas_history[-1], replicas_history + + +def make_replica( + mu: np.ndarray, + var: np.ndarray, + num_available_replicas: int, + current_replicas: np.ndarray, + z_score: float, + method: str = "percentage", +) -> tuple[np.ndarray, np.ndarray]: + if method == "percentage": + return percentage_replica(mu, var, num_available_replicas, current_replicas, z_score) + elif method == "max_delta": + return max_delta_replica(mu, var, num_available_replicas, current_replicas, z_score) + else: + return min_max_replica(mu, var, num_available_replicas, current_replicas, z_score) + + +@njit(fastmath=True, cache=True) +def compute_updated_device_variance( + new_expert_id: int, + device_slots: np.ndarray, + current_device_var: float, + expert_var: np.ndarray, + expert_cov: np.ndarray, + expert_replicas: np.ndarray, +) -> float: + """ + Compute updated device variance after adding a new expert + Includes both individual variance and covariance with existing experts + + Args: + new_expert_id: ID of the new expert to add + device_slots: Current expert slots on the device (-1 for empty) + current_device_var: Current variance of the device + expert_var: Variance of each expert's load (N,) + expert_cov: Covariance matrix between experts (N,N) + expert_replicas: Replica count per expert (N,) + + Returns: + new_device_var: Updated device variance after adding the expert + """ + # Add variance of new expert + new_device_var = current_device_var + expert_var[new_expert_id] / expert_replicas[new_expert_id] ** 2 + + # Add covariance between new expert and existing experts on device + for slot in device_slots: + if slot == -1: + break + new_device_var += 2 * expert_cov[new_expert_id, slot] / expert_replicas[new_expert_id] / expert_replicas[slot] + + return new_device_var + + +@njit(fastmath=True, cache=True) +def lpt_deployment( + mu: np.ndarray, + var: np.ndarray, + cov: np.ndarray, + deployment: np.ndarray, + deployed_replicas: np.ndarray, + total_replicas: np.ndarray, + z_score: float, +) -> np.ndarray: + """ + Largest Processing Time (LPT) deployment algorithm + Greedily deploys experts to device with minimal risk (mean + z*std) + + Args: + mu: Mean load of each expert (N,) + var: Variance of each expert's load (N,) + cov: Covariance matrix between experts (N,N) + deployment: Initial deployment matrix (num_devices, num_slots_per_device) + deployed_replicas: Already deployed replica count per expert (N,) + total_replicas: Total target replica count per expert (N,) + z_score: Z-score for risk calculation (confidence level) + + Returns: + new_deployment: Updated expert deployment matrix + """ + num_devices, num_slots_per_device = deployment.shape + + # Initialize unit value and sort experts by load + unit_value = mu / total_replicas + sorted_indices = np.argsort(-unit_value) + + new_deployment = -np.ones_like(deployment) + device_mu = np.zeros(num_devices, dtype=np.float32) + device_var = np.zeros(num_devices, dtype=np.float32) + dev_ptr = np.zeros(num_devices, dtype=np.int32) + + # Copy existing deployment first + for dev in range(num_devices): + for slot in deployment[dev]: + if slot != -1: + device_mu[dev] += mu[slot] / total_replicas[slot] + device_var[dev] += compute_updated_device_variance( + slot, new_deployment[dev], device_var[dev], var, cov, total_replicas + ) + new_deployment[dev, dev_ptr[dev]] = slot + dev_ptr[dev] += 1 + + # Greedily deploy remaining replicas to device with minimal risk + for idx in sorted_indices: + for _ in range(total_replicas[idx] - deployed_replicas[idx]): + best_dev = -1 + best_risk = 1e30 + best_mu = -1.0 + best_var = -1.0 + for dev in range(num_devices): + if dev_ptr[dev] >= num_slots_per_device: + continue + if idx in new_deployment[dev]: + continue + # Calculate temporary device load and risk + temp_mu = device_mu[dev] + mu[idx] / total_replicas[idx] + temp_var = compute_updated_device_variance( + idx, new_deployment[dev], device_var[dev], var, cov, total_replicas + ) + + risk = temp_mu + z_score * np.sqrt(temp_var) + if risk < best_risk: + best_risk = risk + best_dev = dev + best_mu = temp_mu + best_var = temp_var + + # Update device state with best deployment choice + device_mu[best_dev] = best_mu + device_var[best_dev] = best_var + new_deployment[best_dev, dev_ptr[best_dev]] = idx + dev_ptr[best_dev] += 1 + + return new_deployment + + +@njit(fastmath=True, cache=True) +def compute_score(val_data: np.ndarray, simulated_replicas: np.ndarray, simulated_deployment: np.ndarray) -> np.float32: + """ + Calculate load balance score: (max_device_load * num_devices) / total_load + Lower score means better load balance + + Args: + val_data: Validation load data (T, N) - T time steps, N experts + simulated_replicas: Replica count per expert (N,) + simulated_deployment: Expert deployment matrix (D, K) - D devices, K slots + + Returns: + mean_score: Average load balance score over time steps + """ + T, N = val_data.shape + D, K = simulated_deployment.shape + scores = np.empty((T,), dtype=np.float32) + for t in range(T): + max_load = 0.0 # Explicit float type to avoid int/float mix + tot_load = 0.0 + for d in range(D): + s = 0.0 + for k in range(K): + idx = simulated_deployment[d, k] + s += val_data[t, idx] / simulated_replicas[idx] + tot_load += s + max_load = max(max_load, s) + # Add small epsilon to avoid division by zero + scores[t] = (max_load * D + 1e-2) / (tot_load + 1e-2) + + return np.mean(scores) + + +class FlashTree: + def __init__(self, X, num_replicas, num_devices, z_score=0.674, depth=4, width=8): + super().__init__() + self.num_replicas = num_replicas + self.num_devices = num_devices + self.z_score = z_score + self.depth = depth + self.width = width + + self.X = X + self.mu, self.var, self.cov = FlashTree.compute_statistics(X) + + @staticmethod + def compute_statistics(X): + T, N = X.shape + mean_ = np.mean(X, axis=0) + if T > 1: + X_centered = X - mean_ + variance_ = np.sum(X_centered**2, axis=0) / (T - 1) + cov_matrix = (X_centered.T @ X_centered) / (T - 1) + else: + variance_ = np.zeros((N,)) + cov_matrix = np.zeros((N, N)) + return mean_, variance_, cov_matrix + + def neighbor_search( + self, low: int, high: int, initial: int, max_range: int, get_score: Any, *args: Any + ) -> tuple[int, float, np.ndarray]: + """ + Local neighbor search for optimal replica number + Search [initial-max_range, initial+max_range] within [low, high] + + Args: + low: Lower bound of search range + high: Upper bound of search range + initial: Initial replica number to start search + max_range: Maximum search range from initial value + get_score: Function to compute score for a given replica number + *args: Additional arguments for get_score function + + Returns: + best_x: Optimal replica number + best_score: Best load balance score + best_sim: Corresponding deployment simulation result + """ + max_range = min(max(initial - low, high - initial), max_range) + best_x = initial + best_score, best_sim = get_score(initial, *args) + + # Search left and right neighbors + for r in range(1, max_range + 1): + left = initial - r + if left >= low: + score, sim = get_score(left, *args) + if score < best_score: + best_x, best_score, best_sim = left, score, sim + + right = initial + r + if right <= high: + score, sim = get_score(right, *args) + if score < best_score: + best_x, best_score, best_sim = right, score, sim + + return best_x, best_score, best_sim + + def optimize_balanceness(self): + X_row = self.X + mu, var, cov = self.mu, self.var, self.cov + num_total_replicas = self.num_replicas + num_devices = self.num_devices + z_score = self.z_score + depth, width = self.depth, self.width + + num_experts = mu.shape[0] + num_available_replicas = num_total_replicas - num_experts + + if depth <= 1: + default_replicas = np.ones(num_experts, dtype=np.int32) + default_replicas = make_replica(mu, var, num_available_replicas, default_replicas, z_score)[0] + default_deployment = -np.ones((num_devices, num_total_replicas // num_devices), dtype=np.int32) + default_deployment = lpt_deployment( + mu, var, cov, default_deployment, np.zeros(num_experts, dtype=np.int32), default_replicas, z_score + ) + default_par = compute_score(X_row, default_replicas, default_deployment) + return default_deployment, default_replicas, default_par + + interval_size = math.ceil(num_experts / depth) + weight = mu + z_score * np.sqrt(var) + idx = np.argsort(-weight) + + deployed_replicas = np.zeros(num_experts, dtype=np.int32) + deployment = -np.ones((num_devices, num_total_replicas // num_devices), dtype=np.int32) + + def _lpt_deployment(replicas): + nonlocal mu, var, cov, deployment, deployed_replicas, z_score + return lpt_deployment(mu, var, cov, deployment, np.zeros_like(replicas), replicas, z_score) + + def get_score( + f: Any, + val_data: np.ndarray, + deployed_replicas: np.ndarray, + current_idx: np.ndarray, + current_replicas: np.ndarray, + remaind_idx: np.ndarray, + remaind_replicas: np.ndarray, + ) -> tuple[float, np.ndarray]: + """ + Wrapper to compute load balance score for replica allocation simulation + + Args: + f: Deployment function (e.g., _lpt_deployment) + val_data: Validation load data (T, N) + deployed_replicas: Already deployed replica count per expert (N,) + current_idx: Indices of current expert group + current_replicas: Replica count for current expert group + remaind_idx: Indices of remaining expert group + remaind_replicas: Replica count for remaining expert group + + Returns: + score: Load balance score + simulated_deployment: Simulated expert deployment matrix + """ + # Simulate replica allocation and deployment + simulated_replicas = deployed_replicas.copy() + simulated_replicas[current_idx] = current_replicas + simulated_replicas[remaind_idx] = remaind_replicas + simulated_deployment = f(simulated_replicas) + + # Calculate load balance score + score = compute_score(val_data, simulated_replicas, simulated_deployment) + return score, simulated_deployment + + for node in range(depth - 1): + low, high = 0, num_available_replicas + simulation_idx = idx[node * interval_size :] + current_idx = idx[node * interval_size : (node + 1) * interval_size] + remaind_idx = idx[(node + 1) * interval_size :] + + simulation_replicas = make_replica( + mu[simulation_idx], var[simulation_idx], high, np.ones(simulation_idx.shape[0], dtype=np.int32), z_score + )[0] + current_replicas_f = make_replica( + mu[current_idx], var[current_idx], high, np.ones(current_idx.shape[0], dtype=np.int32), z_score + )[1] + remaind_replicas_f = make_replica( + mu[remaind_idx], var[remaind_idx], high, np.ones(remaind_idx.shape[0], dtype=np.int32), z_score + )[1] + + initial_replicas = (simulation_replicas[:interval_size] - 1).sum() + + best_replica, _, _ = self.neighbor_search( + low, + high, + initial_replicas, + width, + lambda mid, + ci=current_idx, + crf=current_replicas_f, + ri=remaind_idx, + rrf=remaind_replicas_f, + nar=num_available_replicas: get_score( + _lpt_deployment, X_row, deployed_replicas, ci, crf[mid], ri, rrf[nar - mid] + ), + ) + + deployed_replicas[current_idx] = current_replicas_f[best_replica] + num_available_replicas -= best_replica + + if not num_available_replicas or node == depth - 2: + deployed_replicas[remaind_idx] = remaind_replicas_f[num_available_replicas] + break + + final_deployment = -np.ones((num_devices, num_total_replicas // num_devices), dtype=np.int32) + final_deployment = lpt_deployment( + mu, var, cov, final_deployment, np.zeros_like(deployed_replicas), deployed_replicas, z_score + ) + final_par = compute_score(X_row, deployed_replicas, final_deployment) + + return final_deployment, deployed_replicas, final_par class FlashLB(EplbPolicy): + """ + Flash Load Balancing (FlashLB) policy for expert deployment optimization + Implements layered tree search with load balance score optimization + """ + def __init__(self, config: DynamicConfig): + """ + Initialize FlashLB policy with dynamic configuration + + Args: + config: Dynamic configuration object containing policy parameters + """ super().__init__(config) - self.par_history: dict[int, float] = {} - self.hotness_window: dict[int, deque[float]] = {} - self.max_stage_window = config.max_stage_window if hasattr(config, "max_stage_window") else 1 - self.buffer_expert_layer_num = ( - config.buffer_expert_layer_num if hasattr(config, "buffer_expert_layer_num") else 58 - ) - self.threshold_ratio = config.threshold_ratio if hasattr(config, "threshold_ratio") else 0 - - def compute_expert_hotness(self, num_of_expert: int, deployment: np.ndarray, rank_load: np.ndarray): - hotness = np.zeros(num_of_expert, dtype=rank_load.dtype) - deployment_flat = deployment.ravel() - rank_load_flat = rank_load.ravel() - np.add.at(hotness, deployment_flat, rank_load_flat) - return hotness - - def compute_rank_load(self, deployment: np.ndarray, hotness: np.ndarray): - n_stage, N = hotness.shape - if np.any(deployment < 0): - raise ValueError("Deployment table contains negative values.") - counts = np.bincount(deployment.reshape(-1), minlength=N) - unit_hotness = np.divide(hotness, counts, out=np.zeros_like(hotness, dtype=float), where=counts != 0) - stage_par = np.zeros(n_stage) - for i in range(n_stage): - stage_load = unit_hotness[i][deployment].sum(-1) - stage_par[i] = stage_load.max() / stage_load.mean() - return stage_par.mean() - - def group_based_adaptive_bloating(self, X, P, M, stage_weights=None, recorsive=False): - n_stage, N = X.shape - if stage_weights is None: - stage_weights = np.ones(n_stage, dtype=np.float32) - - if recorsive: - ( - simulated_deployment, - simulated_pieces, - ) = self.group_based_adaptive_bloating(X, P, M, stage_weights, recorsive=False) + if config.ep_worldsize >= 32: + threshold_ratio = 0.9 + threshold_value = 0.85 else: - simulated_pieces = compute_piece_counts(X, P, stage_weights) - simulated_deployment = jsq_placement(X, simulated_pieces, M, stage_weights) + threshold_ratio = 0.95 + threshold_value = 0.9 - pieces = group_based_adaptive_bloating_kernel( - X.astype(np.float32), - P, - M, - simulated_pieces.astype(np.int32), - simulated_deployment.astype(np.int32), - stage_weights.astype(np.float32), + # Max window size for expert hotness observation + self.max_observation_window = config.max_stage_window if hasattr(config, "max_stage_window") else 2000 + # Threshold ratio for load balance update trigger + self.update_threshold_ratio = config.threshold_ratio if hasattr(config, "threshold_ratio") else threshold_ratio + # Threshold value for load balance update trigger + self.update_threshold_value = config.threshold_value if hasattr(config, "threshold_value") else threshold_value + # Upper bound of layers to update per iteration + self.update_layers_upper_bound = config.layers_upper_bound if hasattr(config, "layers_upper_bound") else -1 + # Z-score for risk calculation (default: 75% confidence) + self.z_score = config.z_score if hasattr(config, "z_score") else stats.norm.ppf(0.75) + # Tree search depth for flash_tree algorithm + self.depth = config.depth if hasattr(config, "depth") else 4 + # Tree search width for neighbor search + self.width = config.width if hasattr(config, "width") else 8 + self.sample_size = ( + config.sample_size if hasattr(config, "sample_size") else min(self.max_observation_window, 64) ) - deployment = jsq_placement(X, pieces, M, stage_weights) + # Runtime state storage with type annotations + self.average_to_peak_history: dict[int, float] = {} # Layer-wise load balance history + self.hotness_window: dict[int, dict[str, Any]] = {} # Layer-wise hotness stats and buffer + self.current_deployment: dict[int, np.ndarray] = {} # Current expert deployment per layer + self.current_deployed_replicas: dict[int, np.ndarray] = {} # Current replica count per expert per layer - X_all = X.sum(0) - unit_load = np.divide(X_all, pieces, out=np.zeros_like(X_all, dtype=float), where=pieces != 0) - load = unit_load[deployment].sum(-1) + def min_max_replica( + self, mu: np.ndarray, var: np.ndarray, num_available_replicas: int, current_replicas: np.ndarray, z_score: float + ) -> tuple[np.ndarray, np.ndarray]: + """ + Wrapper for original min-max replica allocation - sim_unit_load = X_all / simulated_pieces - sim_load = sim_unit_load[simulated_deployment].sum(-1) + Args: + mu: Mean load of each expert (N,) + var: Variance of each expert's load (N,) + num_available_replicas: Total available replicas to allocate + current_replicas: Initial replica count per expert (N,) + z_score: Z-score for risk calculation (confidence level) - if load.max() > sim_load.max(): - return simulated_deployment, simulated_pieces - return deployment, pieces + Returns: + current_replicas: Updated replica count per expert (N,) + replicas_history: Replica allocation history (num_available_replicas+1, N) + """ + return min_max_replica(mu, var, num_available_replicas, current_replicas, z_score) - def need_update(self, current_par, layer_id=0): - threshold = self.par_history.get(layer_id, 0.0) - return current_par >= self.threshold_ratio * threshold + @staticmethod + def compute_statistics(X: np.ndarray) -> tuple[np.ndarray, np.ndarray, np.ndarray]: + """ + Compute mean, variance and covariance matrix from time series data - def compute_stage_weight(self, hotness): - n_stage = hotness.shape[0] - stage_weights = np.zeros(n_stage) - for i in range(n_stage): - stage_weights[i] = hotness[i].sum() + Args: + X: Time series data (T, N) - T steps, N experts - stage_weights = stage_weights / stage_weights.max() - return stage_weights + Returns: + mean: Mean load per expert (N,) + variance: Variance of load per expert (N,) + cov_matrix: Covariance matrix between experts (N,N) + """ + T, N = X.shape + mean_ = np.mean(X, axis=0) + if T > 1: + X_centered = X - mean_ + variance_ = np.sum(X_centered**2, axis=0) / (T - 1) + cov_matrix = (X_centered.T @ X_centered) / (T - 1) + else: + # Zero stats if only one sample + variance_ = np.zeros((N,)) + cov_matrix = np.zeros((N, N)) + return mean_, variance_, cov_matrix - def rebalance_layer(self, deployment, hotness, layer_id=0): - num_rank, expert_per_rank = deployment.shape - num_expert = np.unique(deployment.reshape(-1)).shape[0] - num_of_redundant_expert = num_rank * expert_per_rank - num_expert + @staticmethod + def sliding_update_stats( + mean: np.ndarray, cov: np.ndarray, x_old: np.ndarray, x_new: np.ndarray, T: int + ) -> tuple[np.ndarray, np.ndarray, np.ndarray]: + """ + Update statistics with sliding window (replace old data with new data) - current_par = self.compute_rank_load(deployment, hotness) + Args: + mean: Current mean statistics (N,) + cov: Current covariance matrix (N,N) + x_old: Old data batch to remove (t, N) + x_new: New data batch to add (t, N) + T: Window size - if not self.need_update(current_par, layer_id): - return deployment, current_par, current_par + Returns: + new_mean: Updated mean (N,) + new_var: Updated variance (N,) + new_cov: Updated covariance matrix (N,N) + """ + assert x_new.shape == x_old.shape + mean = mean.astype(np.float64, copy=False) + cov = cov.astype(np.float64, copy=False) + x_old = x_old.astype(np.float64, copy=False) + x_new = x_new.astype(np.float64, copy=False) - stage_weights = self.compute_stage_weight(hotness) - new_deployment, _ = self.group_based_adaptive_bloating( - hotness, - num_expert + num_of_redundant_expert, - num_rank, - stage_weights, - recorsive=False, - ) - new_par = self.compute_rank_load(new_deployment, hotness) + # Update mean + sum_old = np.sum(x_old, axis=0) + sum_new = np.sum(x_new, axis=0) + deltaS = sum_new - sum_old + new_mean = mean + deltaS / T - return new_deployment, new_par, current_par + # Update covariance matrix + x_old_centered = x_old - mean + x_new_centered = x_new - mean + + SA_mu = np.dot(x_old_centered.T, x_old_centered) + SB_mu = np.dot(x_new_centered.T, x_new_centered) + + Sigma = cov * (T - 1) + Sigma_new = Sigma + SB_mu - SA_mu - np.outer(deltaS, deltaS) / T + new_cov = Sigma_new / (T - 1) + + new_var = np.diag(new_cov) + return new_mean, new_var, new_cov + + @staticmethod + def incremental_update_stats( + mean: np.ndarray, cov: np.ndarray, x_new: np.ndarray, T: int + ) -> tuple[np.ndarray, np.ndarray, np.ndarray, int]: + """ + Incrementally update statistics with new data (expand window) + + Args: + mean: Current mean statistics (N,) + cov: Current covariance matrix (N,N) + x_new: New data batch to add (t, N) + T: Current window size + + Returns: + new_mean: Updated mean (N,) + new_var: Updated variance (N,) + new_cov: Updated covariance matrix (N,N) + new_T: Updated window size + """ + t, N = x_new.shape + sum_new = np.sum(x_new, axis=0) + new_T = T + t + + # Update mean + new_mean = (T * mean + sum_new) / new_T + + # Update covariance matrix + if T > 1: + x_new_centered = x_new - new_mean + cov_new = cov * (T - 1) + cov_new += np.dot(x_new_centered.T, x_new_centered) + cov_new += T * np.outer(mean - new_mean, mean - new_mean) + new_cov = cov_new / (new_T - 1) + else: + # Special case for initial single sample + x_old = mean.reshape(1, -1) + x_old_centered = x_old - new_mean + x_new_centered = x_new - new_mean + sum_squares = np.dot(x_old_centered.T, x_old_centered) + np.dot(x_new_centered.T, x_new_centered) + new_cov = sum_squares / (new_T - 1) + + new_var = np.diag(new_cov) + return new_mean, new_var, new_cov, new_T + + def register_hotness( + self, deployment: np.ndarray, rank_load: np.ndarray, num_layers: int, num_experts: int + ) -> None: + """ + Update expert hotness statistics with sliding window for all layers + + Args: + deployment: Expert deployment matrix (num_layers, num_devices, num_slots) + rank_load: Load data (num_stages, num_layers, num_devices, num_slots) + num_layers: Total number of layers + num_experts: Total number of experts + """ + num_stage = rank_load.shape[0] + hotness = np.zeros((num_stage, num_layers, num_experts), dtype=rank_load.dtype) + for stage in range(num_stage): + for layer in range(num_layers): + deployment_flat = deployment[layer].ravel() + rank_load_flat = rank_load[stage, layer].ravel() + np.add.at(hotness[stage, layer], deployment_flat, rank_load_flat) + + hotness += 1 + window_length = self.max_observation_window + + for layer in range(num_layers): + new_X = hotness[-window_length:, layer, :] + t = new_X.shape[0] - def register_hotness(self, deployment, rank_load, num_layer, num_expert): - for layer in range(num_layer): if layer not in self.hotness_window: - self.hotness_window[layer] = deque(maxlen=self.max_stage_window) - hotness = self.compute_expert_hotness(num_expert, deployment[layer], rank_load[layer]) - self.hotness_window[layer].append(hotness) + self.hotness_window[layer] = { + "buffer": np.zeros((window_length, num_experts), dtype=new_X.dtype), + "start": 0, + "length": 0, + } - def compress_by_avg_pooling_fast_nd(self, arr, m): - n, d = arr.shape - idx = np.arange(n) * m // n - result = np.zeros((m, d)) - counts = np.zeros((m, 1)) - np.add.at(result, idx, arr) - np.add.at(counts, idx, 1) - return result / counts + info = self.hotness_window[layer] + buf = info["buffer"] + start = info["start"] + length = info["length"] - def rebalance_experts(self, current_expert_table, expert_workload): + if start + t <= window_length: + buf[start : start + t] = new_X + else: + first_part = window_length - start + buf[start:] = new_X[:first_part] + buf[: t - first_part] = new_X[first_part:] + + start = (start + t) % window_length + length = min(window_length, length + t) + + self.hotness_window[layer]["buffer"] = buf + self.hotness_window[layer]["start"] = start + self.hotness_window[layer]["length"] = length + + def need_update(self, layer_id: int = 0) -> bool: + """ + Check if layer needs load balance update + Trigger update if load balance ratio drops below threshold + + Args: + layer_id: Layer index to check + + Returns: + bool: True if update is needed, False otherwise + """ + past_average_to_peak_ratio = self.average_to_peak_history.get(layer_id, 0.0) + if past_average_to_peak_ratio == 0.0: + # Force update for first iteration + return True + + # Calculate current load balance ratio (average/peak load) + hotness = self.hotness_window[layer_id]["buffer"] + average_to_peak_ratio = 1 / compute_score( + hotness, self.current_deployed_replicas[layer_id], self.current_deployment[layer_id] + ) + + # Check update conditions + return ( + average_to_peak_ratio < past_average_to_peak_ratio * self.update_threshold_ratio + or average_to_peak_ratio < self.update_threshold_value + ) + + @staticmethod + @njit + def compute_match(src_counts: np.ndarray, dst_counts: np.ndarray, N: int, M: int) -> np.ndarray: + """ + Compute match matrix between source and destination expert counts + match[i,j] = total min(src_counts[i,k], dst_counts[j,k]) for all k + + Args: + src_counts: Source expert count histogram (N, max_val) + dst_counts: Destination expert count histogram (N, max_val) + N: Number of rows in deployment matrix + M: Number of columns in deployment matrix + + Returns: + matches: Match matrix (N, N) + """ + matches = np.zeros((N, N), dtype=np.int32) + for i in range(N): + for j in range(N): + match = 0 + for k in range(N * M): + match += min(src_counts[i, k], dst_counts[j, k]) + matches[i, j] = match + return matches + + @staticmethod + def minimize_redeploy_with_inner_permutation(src: np.ndarray, dst: np.ndarray) -> np.ndarray: + """ + Minimize expert redeployment by permuting destination rows/columns + Aligns destination deployment with source to reduce expert movement + + Args: + src: Source deployment matrix (N, M) + dst: Destination deployment matrix (N, M) + + Returns: + dst_reordered: Reordered destination deployment matrix + """ + if src.shape != dst.shape: + raise ValueError("src and dst must have same shape (N, M)") + N, M = src.shape + valid_src = src + valid_dst = dst + + # Calculate expert count histogram for each row + max_val = N * M + src_counts = np.array([np.bincount(row[row != -1], minlength=max_val) for row in valid_src], dtype=np.int32) + dst_counts = np.array([np.bincount(row[row != -1], minlength=max_val) for row in valid_dst], dtype=np.int32) + + # Compute match matrix and optimal row mapping (Hungarian algorithm) + matches = FlashLB.compute_match(src_counts, dst_counts, N, M) + cost = M - matches + row_ind, col_ind = linear_sum_assignment(cost) + mapping = list(zip(row_ind.tolist(), col_ind.tolist())) + + # Reorder dst rows and columns to align with src + dst_reordered = np.empty_like(dst) + for src_idx, dst_idx in mapping: + s_row = src[src_idx] + d_row = dst[dst_idx] + # Map expert values to their positions in dst row + val_to_positions: dict[int, list[int]] = {} # Add type annotation for mypy + for pos, v in enumerate(d_row): + val_to_positions.setdefault(v, []).append(pos) + + reordered = np.empty(M, dtype=dst.dtype) + assigned = [False] * M + used_dst_positions = set() + + # Assign existing experts first to minimize movement + for pos_src, v in enumerate(s_row): + positions = val_to_positions.get(v) + if positions: + dst_pos = positions.pop() + reordered[pos_src] = v + assigned[pos_src] = True + used_dst_positions.add(dst_pos) + + # Fill remaining positions with unassigned experts + remaining = [d_row[p] for p in range(M) if p not in used_dst_positions] + ri = 0 + for pos in range(M): + if not assigned[pos]: + reordered[pos] = remaining[ri] + ri += 1 + dst_reordered[src_idx] = reordered + return dst_reordered + + def rebalance_experts( + self, current_expert_table: torch.Tensor, expert_workload: torch.Tensor + ) -> tuple[bool, np.ndarray, np.ndarray]: + """ + Main expert rebalance entry point + Optimizes expert deployment to improve load balance + + Args: + current_expert_table: Current expert deployment (layers, devices, slots) + expert_workload: Expert load data (stages, layers, devices, slots) + + Returns: + change: True if any layers were updated + priority_idx: Indices of updated layers (sorted by improvement) + new_deployment: Updated expert deployment matrix + """ current_deployment = np.array(current_expert_table) expert_workload = np.array(expert_workload) - expert_workload += 1 - num_layer = expert_workload.shape[0] - num_expert = np.unique(current_expert_table[0].reshape(-1)).shape[0] - self.register_hotness(current_deployment, expert_workload, num_layer, num_expert) - new_deployment = current_deployment.copy() + # Add batch dimension if missing + if expert_workload.ndim == 3: + expert_workload = expert_workload[np.newaxis, ...] + num_layers = expert_workload.shape[1] + num_expert = np.unique(current_deployment[0].reshape(-1)).shape[0] + num_devices = current_deployment.shape[1] + num_replicas = len(current_deployment[0].reshape(-1)) - layers_need_update = np.arange(num_layer) + # Update expert hotness statistics + self.register_hotness(current_deployment, expert_workload, num_layers, num_expert) - new_par = np.zeros(layers_need_update.shape[0]) - current_par = np.zeros(layers_need_update.shape[0]) - for i, layer in enumerate(layers_need_update): - hotness = np.array(self.hotness_window[layer]) - if hotness.shape[0] > self.max_stage_window: - hotness = self.compress_by_avg_pooling_fast_nd(hotness, self.max_stage_window) + # Initialize current deployment state for all layers + for layer in range(num_layers): + self.current_deployment[layer] = current_deployment[layer] + self.current_deployed_replicas[layer] = np.bincount( + current_deployment[layer].reshape(-1), minlength=num_expert + ) - ( - new_deployment[layer], - new_par[i], - current_par[i], - ) = self.rebalance_layer(current_deployment[layer], hotness, layer_id=layer) + # Initialize output variables + new_par = np.zeros((num_layers,), dtype=np.float32) + new_deployment = np.zeros((num_layers, num_devices, num_replicas // num_devices), dtype=np.int32) + new_deployed_replicas = np.zeros((num_layers, num_expert), dtype=np.int32) + new_average_to_peak_ratio = np.zeros((num_layers,), dtype=np.float32) + delta_average_to_peak_ratio = np.zeros((num_layers,), dtype=np.float32) + pars = np.zeros((num_layers,), dtype=np.float32) - priority = new_par / current_par - priority_idx = np.argsort(priority) - priority_idx = priority_idx[priority[priority_idx] < 1][: self.buffer_expert_layer_num] + # Optimize each layer + for layer in range(num_layers): + if not self.need_update(layer): + # Keep current deployment if no update needed + new_deployment[layer] = self.current_deployment[layer] + new_deployed_replicas[layer] = self.current_deployed_replicas[layer] + new_average_to_peak_ratio[layer] = self.average_to_peak_history.get(layer, 0.0) + new_par[layer] = 1 / new_average_to_peak_ratio[layer] if new_average_to_peak_ratio[layer] != 0 else 0.0 + delta_average_to_peak_ratio[layer] = 0 + continue - if np.all(expert_workload == 1): - for _, layer in enumerate(layers_need_update): - self.hotness_window[layer].pop() - return False, np.array([], dtype=int), current_deployment + # Get layer hotness stats + layer_info = self.hotness_window[layer] + buf = layer_info["buffer"] + start = layer_info["start"] + length = layer_info["length"] + + # Get valid hotness data from sliding window + idx = np.arange(start, start + length) % self.max_observation_window + data = buf[idx] + + shape = data.shape + window = max(length // self.sample_size, 1) + data = data[-window * self.sample_size :].reshape((-1, window, *shape[1:])).sum(1) + # Flash tree search for optimal deployment + flash_tree = FlashTree(data, num_replicas, num_devices, self.z_score, self.depth, self.width) + best_deployment, best_replicas, best_score = flash_tree.optimize_balanceness() + + # Update layer state + new_deployed_replicas[layer] = best_replicas + new_average_to_peak_ratio[layer] = 1 / best_score + + current_deployment = self.current_deployment.get(layer, None) + + new_deployment[layer] = best_deployment + # Minimize redeployment by permuting new deployment + new_deployment[layer] = FlashLB.minimize_redeploy_with_inner_permutation( + current_deployment, best_deployment + ) + current_average_to_peak_ratio = 1 / compute_score( + buf, self.current_deployed_replicas.get(layer), current_deployment + ) + delta_average_to_peak_ratio[layer] = new_average_to_peak_ratio[layer] - current_average_to_peak_ratio + pars[layer] = best_score + + # Select layers to update (sorted by improvement, positive delta only) + priority_idx = np.argsort(-delta_average_to_peak_ratio) + priority_idx = priority_idx[delta_average_to_peak_ratio[priority_idx] > 0] + # Apply upper bound of layers to update + if self.update_layers_upper_bound > 0: + priority_idx = priority_idx[: self.update_layers_upper_bound] + + # Update global state with optimal deployments + for layer in priority_idx: + self.current_deployment[layer] = new_deployment[layer] + self.current_deployed_replicas[layer] = new_deployed_replicas[layer] + self.average_to_peak_history[layer] = new_average_to_peak_ratio[layer] + + # Return update flag and results change = len(priority_idx) > 0 - if change: - for idx in priority_idx: - self.par_history[layers_need_update[idx]] = new_par[idx] - - layers_need_update = priority_idx - deployment = current_deployment - for layer in layers_need_update: - deployment[layer] = auto_fix_new_placement(current_deployment[layer], new_deployment[layer]) - - return change, layers_need_update, deployment + return change, priority_idx, new_deployment -def generate_layered_experts(num_layers=58, layer_shape=(32, 9), expert_min=0, expert_max=255): +def generate_layered_experts( + num_layers: int = 58, layer_shape: tuple[int, int] = (32, 9), expert_min: int = 0, expert_max: int = 255 +) -> torch.Tensor: """ - Generate expert deployment matrix meeting the following conditions: - - Total of num_layers layers - - Each layer has shape layer_shape (32,9) - - Each expert from expert_min to expert_max (0 to 255) appears at least once in each layer + Generate layered expert deployment matrix + Each layer contains all experts [expert_min, expert_max] at least once + Remaining slots filled with random experts Args: - num_layers: Number of layers, default 58 - layer_shape: Shape of a single layer, default (32,9) - expert_min: Minimum expert ID, default 0 - expert_max: Maximum expert ID, default 255 - Returns: - torch.Tensor: Tensor with shape (num_layers, layer_shape[0], layer_shape[1]) - """ - # 1. Basic parameter calculation - expert_num = expert_max - expert_min + 1 # Total number of experts: 256 (0~255) - layer_total = layer_shape[0] * layer_shape[1] # Total elements in a single layer: 32*9=288 - extra_slots = layer_total - expert_num # Number of random positions to fill per layer: 288-256=32 + num_layers: Number of layers to generate + layer_shape: Shape of each layer's deployment matrix (rows, cols) + expert_min: Minimum expert ID (inclusive) + expert_max: Maximum expert ID (inclusive) - # 2. Verify feasibility (total elements must be ≥ number of experts to cover all experts) + Returns: + layers: Layered expert deployment tensor (num_layers, *layer_shape) + """ + # Basic parameter calculation + expert_num = expert_max - expert_min + 1 + layer_total = layer_shape[0] * layer_shape[1] + extra_slots = layer_total - expert_num + + # Feasibility check: layer capacity ≥ expert count assert layer_total >= expert_num, ( - f"Number of elements in a single layer {layer_total} < number of experts {expert_num}, cannot cover all experts" + f"Layer element count {layer_total} < expert count {expert_num}, cannot cover all experts" ) - # 3. Generate layers one by one - layers = [] + # Generate each layer + layers: list[torch.Tensor] = [] for _ in range(num_layers): - # 3.1 Generate "complete expert sequence" (ensure each expert from 0 to 255 is included) - full_experts = torch.arange(expert_min, expert_max + 1, dtype=torch.int64) # shape (256,) + # Full expert sequence (cover all experts once) + full_experts = torch.arange(expert_min, expert_max + 1, dtype=torch.int64) # (expert_num,) - # 3.2 Generate "supplementary random experts" (fill remaining 32 positions, randomly selected from 0~255) - extra_experts = torch.randint(expert_min, expert_max + 1, size=(extra_slots,), dtype=torch.int64) # shape (32,) + # Random extra experts for remaining slots + extra_experts = torch.randint( + expert_min, expert_max + 1, size=(extra_slots,), dtype=torch.int64 + ) # (extra_slots,) - # 3.3 Concatenate and shuffle (ensure random distribution of experts in each layer) - layer_flat = torch.cat([full_experts, extra_experts], dim=0) # shape (288,) - # Shuffle order (use randperm to generate random indices to avoid repeated shuffling issues) + # Concatenate and shuffle for random distribution + layer_flat = torch.cat([full_experts, extra_experts], dim=0) # (layer_total,) shuffle_idx = torch.randperm(layer_flat.shape[0]) layer_shuffled = layer_flat[shuffle_idx] - # 3.4 Reshape to layer_shape (32,9) + # Reshape to target layer shape layer = layer_shuffled.reshape(layer_shape) layers.append(layer) - # 4. Stack all layers to get the final tensor - return torch.stack(layers, dim=0) # shape (58,32,9) + # Stack all layers + return torch.stack(layers, dim=0) # (num_layers, *layer_shape) -def warm_up(): +def warm_up() -> None: + """ + Warm up FlashLB algorithm with dummy data + Pre-compiles numba functions and initializes state + """ exam_config = DynamicConfig() exam_config.ep_worldsize = 32 exam_config.num_die_per_host = 16 algo = FlashLB(exam_config) - # Generate target tensor + # Generate dummy expert deployment tensor expert_tensor = generate_layered_experts(num_layers=58, layer_shape=(32, 9)) - - algo.rebalance_experts(expert_tensor, torch.randint(1, 1000, (58, 32, 9))) + # Run rebalance with dummy workload data + algo.rebalance_experts(expert_tensor, torch.randint(1, 1000, (100, 58, 32, 9))) diff --git a/vllm_ascend/eplb/eplb_updator.py b/vllm_ascend/eplb/eplb_updator.py index 4d3da725..fa06737d 100644 --- a/vllm_ascend/eplb/eplb_updator.py +++ b/vllm_ascend/eplb/eplb_updator.py @@ -30,6 +30,7 @@ from vllm_ascend.eplb.core.eplb_worker import EplbProcess class EplbUpdator: def __init__(self, eplb_config, loader: D2DExpertWeightLoader, eplb_process: EplbProcess, process): self.eplb_config = eplb_config + self.multi_stage = eplb_config.eplb_policy_type == 3 self.init_eplb(self.eplb_config.expert_map_path, process) self.eplb_loader = loader self.eplb_process = eplb_process @@ -131,9 +132,14 @@ class EplbUpdator: def compute_and_set_moe_load(self): local_load = self.adaptor.get_rank_expert_workload() - moe_load = self.comm_group.all_gather(local_load, dim=0).reshape(-1, self.world_size, *local_load.shape[1:]) + moe_load = ( + self.comm_group.all_gather(local_load, dim=0).reshape(-1, self.world_size, *local_load.shape[1:]).cpu() + ) - self.shared_dict["moe_load"] = moe_load.cpu() + if self.multi_stage: + moe_load = moe_load.permute(2, 0, 1, 3) + + self.shared_dict["moe_load"] = moe_load logger.debug(f"[ModelRunner] Updated shared_dict['moe_load'] shape={moe_load.shape}") return moe_load diff --git a/vllm_ascend/ops/fused_moe/fused_moe.py b/vllm_ascend/ops/fused_moe/fused_moe.py index d5df0ad8..4304931b 100644 --- a/vllm_ascend/ops/fused_moe/fused_moe.py +++ b/vllm_ascend/ops/fused_moe/fused_moe.py @@ -299,7 +299,13 @@ class AscendFusedMoE(FusedMoE): get_compressed_expert_map(self._expert_map), ) if self.dynamic_eplb: + self.multi_stage = False self.moe_load = torch.zeros(self.local_num_experts, dtype=torch.int64).npu() + if eplb_config.eplb_policy_type == 3: + self.multi_stage = True + self.load_counter = torch.tensor(0, dtype=torch.int32, device="npu") + self.num_iter = eplb_config.expert_heat_collection_interval + self.moe_load = torch.zeros((self.num_iter, self.local_num_experts), dtype=torch.int32, device="npu") self.moe_config.num_experts = self.global_num_experts self.moe_config.num_local_experts = self.local_num_experts @@ -361,6 +367,8 @@ class AscendFusedMoE(FusedMoE): def clear_moe_load(self): if self.moe_load is not None: self.moe_load.zero_() + if self.multi_stage: + self.load_counter.zero_() def maybe_all_reduce_tensor_model_parallel(self, final_hidden_states: torch.Tensor): """NOTE(Yizhou): This is to override the parent class method. In `mc2commimpl`, @@ -493,7 +501,14 @@ class AscendFusedMoE(FusedMoE): if group_list_type == 1 else torch.cat([expert_tokens[:1], expert_tokens[1:] - expert_tokens[:-1]]) ) - self.moe_load.add_(local_load) + if self.multi_stage: + cur_iter = torch.remainder(self.load_counter, self.num_iter) + self.moe_load.index_add_( + dim=0, index=cur_iter, source=local_load.to(torch.int32, non_blocking=True).view(1, -1) + ) + self.load_counter.add_(1) + else: + self.moe_load.add_(local_load) routed_out = forward_context.moe_comm_method.finalize( hidden_states=fused_experts_results.routed_out, reduce_results=self.reduce_results,