Support per-step heat collection and enhance FlashLB for multi-stage load balancing (#6477)

# Feature: FlashLB algorithm

## Purpose

This pull request enhances the EPLB (Expert Parallelism Load Balancing)
system by introducing a novel load balancing algorithm, FlashLB. It
addresses three shortcomings of the default EPLB algorithm:

1. The default algorithm adopts two separate sub-procedures that
optimize expert replication and placement independently:

a. **Expert Replica Allotment Sub-procedure** : Determines the number of
replicas for all experts. At each step, it greedily adds one more
replica to the expert with the highest per-replica load, aiming to
minimize load skew at the expert replica granularity (Min Max Replica,
MMR).

b. **Expert Replica Placement Sub-procedure**: Distributes all replicas
across devices. It first sorts the generated replicas in descending
order of hotness, then iteratively places the currently hottest replica
onto the device with the lowest cumulative load that still has an
available slot.

However, this simplistic combination of two separate procedures lacks
synergy and often leads to sub-optimal load balancing. Consider the
simple scenario illustrated below: given 8 logical experts with hotness
values [600, 560, 120, 120, 20, 10, 10, 10] and 2 replica slots per
device across 8 devices, the default EPLB algorithm yields a maximum
per-device hotness of 232 (peak-average load ratio 1.28), whereas
FlashLB reduces it to 205 (peak-average load ratio 1.13). A minimal
sketch reproducing the default heuristic follows the figure.

<figure><img
src="https://github.com/user-attachments/assets/b9b10fab-651e-4524-9942-adbca8d044a4"
width="90%"></figure>

2. The default algorithm simply aggregates hotness measurements across
the entire profiling window. While this provides a coarse approximation
of the hotness distribution, it fails to capture the time-phased
variations and temporal correlations in expert hotness (both within and
between experts) across iterations—phenomena that have been observed in
real-world scenarios. Such single-point hotness estimation degrades the
solution quality of the load balancing algorithm.

3. The default algorithm periodically recomputes expert placements for
all layers indiscriminately. Since excessive expert updates can impact
Service Level Objectives (SLOs), such full-scale redeployment incurs
excessively high adjustment overhead, which degrades end-to-end
performance.

## FlashLB Algorithm Principle

### 1. Joint Optimization of Replica Allotment and Placement

FlashLB achieves joint optimization of replica allotment and placement
through a novel tree-search approach, combined with carefully designed,
efficient pruning and lightweight look-ahead estimation. We partition
all experts into several subsets and, for each subset, hierarchically
determine the optimal replica count and placement. Leveraging the
pruning and look-ahead estimation, the search consistently aims to
optimize the globally expected inter-device load balance (considering
both deployed and not-yet-placed experts) while maintaining
computational efficiency. Additionally, precompilation techniques are
employed for acceleration, delivering load balancing that is both
high-quality and practically efficient. A highly simplified sketch of
the search structure is given below.
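To give a feel for the search structure only, here is a deliberately
tiny, hypothetical sketch: a depth-first search over per-expert replica
counts (hottest expert first) with greedy placement, a dominance prune
against the best solution found so far, and the mean device load as a
look-ahead lower bound. The actual FlashLB search (expert-subset
partitioning, look-ahead estimation over unexplored experts,
precompilation) is considerably more elaborate; nothing below is taken
from the code base.

```python
def joint_search(hotness, num_devices, slots_per_device):
    order = sorted(range(len(hotness)), key=lambda i: -hotness[i])
    mean_load = sum(hotness) / num_devices  # the peak can never drop below the mean
    best = [float("inf")]

    def dfs(idx, loads, slots, free_slots):
        if max(loads) >= best[0]:  # prune: branch can no longer beat the incumbent
            return
        if idx == len(order):
            best[0] = max(loads)
            return
        e = order[idx]
        experts_left_after = len(order) - idx - 1
        # branch on the replica count, keeping >= 1 slot for each later expert
        for r in range(1, free_slots - experts_left_after + 1):
            per_replica = hotness[e] / r
            new_loads, new_slots = list(loads), list(slots)
            for _ in range(r):  # place this expert's replicas greedily
                d = min((k for k in range(num_devices) if new_slots[k] > 0),
                        key=new_loads.__getitem__)
                new_loads[d] += per_replica
                new_slots[d] -= 1
            dfs(idx + 1, new_loads, new_slots, free_slots - r)
            if best[0] <= mean_load:  # look-ahead: already provably optimal
                return

    dfs(0, [0.0] * num_devices, [slots_per_device] * num_devices,
        num_devices * slots_per_device)
    return best[0]

best = joint_search([600, 560, 120, 120, 20, 10, 10, 10], 8, 2)
print(best)  # peak device load of the best plan this toy search finds
```

At real scale an exhaustive search like this is intractable, which is
where FlashLB's subset partitioning, pruning, and precompilation come in.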
### 2. Multi-Episode Enhancement

Instead of performing full-duration averaging like the default
algorithm, FlashLB partitions each profiling interval (e.g., 1024
iterations) into multiple consecutive smaller episodes (e.g., 16
iterations). This preserves hotness fluctuation and correlation
information. It then constructs a multi-objective optimization problem
to co-optimize these episodes simultaneously, enabling adaptability to
interleaved hotness patterns and improving statistical robustness.
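As a rough illustration, the sketch below partitions a per-iteration
load trace into episodes and scores a candidate placement by its worst
episode rather than the window-wide aggregate, one plausible
scalarization of the multi-objective problem. Replication is ignored for
brevity, and the tensor layouts and names are assumptions, not the code
base's API (the per-step collection added by this PR provides exactly
this kind of trace):

```python
import torch

def episode_hotness(load_trace: torch.Tensor, episode_len: int = 16) -> torch.Tensor:
    """load_trace: (num_iters, num_experts) per-iteration token counts."""
    num_iters, num_experts = load_trace.shape
    assert num_iters % episode_len == 0
    # (num_episodes, episode_len, num_experts) -> sum within each episode
    return load_trace.view(-1, episode_len, num_experts).sum(dim=1)

def worst_episode_imbalance(hotness: torch.Tensor, expert_to_device: torch.Tensor,
                            num_devices: int) -> torch.Tensor:
    """hotness: (num_episodes, num_experts); expert_to_device: (num_experts,)."""
    device_loads = torch.zeros(hotness.shape[0], num_devices)
    device_loads.index_add_(1, expert_to_device, hotness)  # per-episode device loads
    ratio = device_loads.max(dim=1).values / device_loads.mean(dim=1)
    return ratio.max()  # optimize the worst episode instead of the flat average

# e.g. a 1024-iteration window becomes 64 episodes of 16 iterations each
trace = torch.randint(0, 100, (1024, 8)).float()
placement = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7])
print(worst_episode_imbalance(episode_hotness(trace), placement, num_devices=8))
```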

### 3. Layer-wise Cherry-Picking Redeployment

To reduce the overhead of frequent expert redeployment, FlashLB
introduces a cherry-picking redeployment scheme. During each algorithmic
decision cycle, it tracks the load-balance degree of all layers in real
time and triggers expert placement updates only for those layers whose
peak-average ratio exceeds a predefined threshold. This avoids
unnecessary redeployment of already-balanced layers, significantly
reducing adjustment overhead and thereby improving end-to-end
performance. A minimal sketch of the selection rule follows.
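A minimal sketch of the selection rule, assuming per-layer device loads
are available; the threshold value and tensor layout are illustrative:

```python
import torch

def pick_layers_to_update(device_loads: torch.Tensor, threshold: float = 1.1) -> torch.Tensor:
    """device_loads: (num_layers, num_devices) current per-layer device hotness."""
    ratio = device_loads.max(dim=1).values / device_loads.mean(dim=1)
    return torch.nonzero(ratio > threshold).flatten()  # only these layers get redeployed

loads = torch.tensor([[100., 105., 98., 101.],   # balanced layer -> left alone
                      [300., 80., 60., 40.]])    # skewed layer -> redeploy
print(pick_layers_to_update(loads))  # tensor([1])
```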

## Co-author:

Co-authored-by: Skywalker-EP 173723846@qq.com

This PR mainly introduces two key optimizations for load balancing
scheduling:
1. **Add per-step heat collection**:
Support real-time collection of per-step heat information during model
inference. This enables more fine-grained load balancing decisions by
taking per-step heat as the optimization target, improving scheduling
accuracy for dynamic, fluctuating workloads (see the sketch after this
list).

2. **Update the FlashLB algorithm**:
Upgrade the FlashLB scheduling logic to adapt to multi-stage heat
distribution scenarios. The improved algorithm perceives and exploits
multi-stage heat characteristics, achieving more stable and efficient
load balancing under complex expert deployments and dynamic traffic
patterns.
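The heart of the per-step collection can be sketched in isolation as
follows (the real change lives inside AscendFusedMoE; see the last diff
hunk below). Instead of one accumulator per local expert, a
(num_iter, num_local_experts) ring buffer records one row per forward
step; all sizes here are illustrative:

```python
import torch

num_iter, num_local_experts = 16, 4  # illustrative sizes
moe_load = torch.zeros((num_iter, num_local_experts), dtype=torch.int32)
load_counter = torch.tensor(0, dtype=torch.int32)

def record_step(local_load: torch.Tensor) -> None:
    """local_load: (num_local_experts,) tokens routed to each local expert this step."""
    cur_iter = torch.remainder(load_counter, num_iter).view(1)  # ring-buffer row
    moe_load.index_add_(0, cur_iter, local_load.to(torch.int32).view(1, -1))
    load_counter.add_(1)  # advanced in place, as in the actual change

for _ in range(3):
    record_step(torch.tensor([5, 0, 2, 1]))
print(moe_load[:3])  # three identical per-step rows: [5, 0, 2, 1]
```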

---------

Signed-off-by: Mercykid-bash <ruanche0218@gmail.com>
Signed-off-by: xuzewei28 <xuzewei2@h-partners.com>
Co-authored-by: xuzewei28 <xuzewei2@h-partners.com>
4 changed files with 960 additions and 510 deletions.

```diff
@@ -34,6 +34,7 @@ class EplbWorker:
         self.old_expert_maps = None
         self.enable_d2d = enable_d2d
         self.rank_id = dist.get_rank()
+        self.multi_stage = policy_type == 3

     def do_update(self):
         # put data in to queue
@@ -62,7 +63,10 @@ class EplbWorker:
         _, _, new_placement = self.calculate_rebalance_experts(load_info, old_placement)

         if self.rank_id == 0:
-            hotness = self._calculate_hotness(old_placement, load_info)
+            if self.multi_stage:
+                hotness = self._calculate_hotness(old_placement, load_info.sum(0))
+            else:
+                hotness = self._calculate_hotness(old_placement, load_info)
             current_mean, current_max = self._compute_imbalance(old_placement, hotness)
             update_mean, update_max = self._compute_imbalance(new_placement, hotness)
             logger.info(
```

*(The diff of one file is suppressed because it is too large.)*


```diff
@@ -30,6 +30,7 @@ from vllm_ascend.eplb.core.eplb_worker import EplbProcess
 class EplbUpdator:
     def __init__(self, eplb_config, loader: D2DExpertWeightLoader, eplb_process: EplbProcess, process):
         self.eplb_config = eplb_config
+        self.multi_stage = eplb_config.eplb_policy_type == 3
         self.init_eplb(self.eplb_config.expert_map_path, process)
         self.eplb_loader = loader
         self.eplb_process = eplb_process
@@ -131,9 +132,14 @@ class EplbUpdator:
     def compute_and_set_moe_load(self):
         local_load = self.adaptor.get_rank_expert_workload()
-        moe_load = self.comm_group.all_gather(local_load, dim=0).reshape(-1, self.world_size, *local_load.shape[1:])
-        self.shared_dict["moe_load"] = moe_load.cpu()
+        moe_load = (
+            self.comm_group.all_gather(local_load, dim=0).reshape(-1, self.world_size, *local_load.shape[1:]).cpu()
+        )
+        if self.multi_stage:
+            moe_load = moe_load.permute(2, 0, 1, 3)
+        self.shared_dict["moe_load"] = moe_load
         logger.debug(f"[ModelRunner] Updated shared_dict['moe_load'] shape={moe_load.shape}")
         return moe_load
```
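For intuition on the `permute(2, 0, 1, 3)` above: assuming the
multi-stage local load is stacked as (num_layers, num_iter,
num_local_experts) per rank, the gathered-and-reshaped tensor appears to
be (num_layers, world_size, num_iter, num_local_experts), so the permute
moves the per-step dimension to the front. A toy shape check (all sizes
made up):

```python
import torch

num_layers, world_size, num_iter, num_local = 2, 4, 16, 8
gathered = torch.zeros(num_layers, world_size, num_iter, num_local)  # after all_gather + reshape
per_step = gathered.permute(2, 0, 1, 3)
print(per_step.shape)  # torch.Size([16, 2, 4, 8]) -> one (layers, ranks, experts) slice per step
```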


```diff
@@ -299,7 +299,13 @@ class AscendFusedMoE(FusedMoE):
                 get_compressed_expert_map(self._expert_map),
             )
         if self.dynamic_eplb:
+            self.multi_stage = False
             self.moe_load = torch.zeros(self.local_num_experts, dtype=torch.int64).npu()
+            if eplb_config.eplb_policy_type == 3:
+                self.multi_stage = True
+                self.load_counter = torch.tensor(0, dtype=torch.int32, device="npu")
+                self.num_iter = eplb_config.expert_heat_collection_interval
+                self.moe_load = torch.zeros((self.num_iter, self.local_num_experts), dtype=torch.int32, device="npu")

         self.moe_config.num_experts = self.global_num_experts
         self.moe_config.num_local_experts = self.local_num_experts
@@ -361,6 +367,8 @@ class AscendFusedMoE(FusedMoE):
     def clear_moe_load(self):
         if self.moe_load is not None:
             self.moe_load.zero_()
+        if self.multi_stage:
+            self.load_counter.zero_()

     def maybe_all_reduce_tensor_model_parallel(self, final_hidden_states: torch.Tensor):
         """NOTE(Yizhou): This is to override the parent class method. In `mc2commimpl`,
@@ -493,7 +501,14 @@ class AscendFusedMoE(FusedMoE):
                 if group_list_type == 1
                 else torch.cat([expert_tokens[:1], expert_tokens[1:] - expert_tokens[:-1]])
             )
-            self.moe_load.add_(local_load)
+            if self.multi_stage:
+                cur_iter = torch.remainder(self.load_counter, self.num_iter)
+                self.moe_load.index_add_(
+                    dim=0, index=cur_iter, source=local_load.to(torch.int32, non_blocking=True).view(1, -1)
+                )
+                self.load_counter.add_(1)
+            else:
+                self.moe_load.add_(local_load)
             routed_out = forward_context.moe_comm_method.finalize(
                 hidden_states=fused_experts_results.routed_out,
                 reduce_results=self.reduce_results,
```