Support per-step heat collection and enhance FlashLB for multi-stage load balancing (#6477)
# Feature: FlashLB algorithm ## Purpose This Pull Request enhances the EPLB (Expert Parallelism Load Balancing) system by introducing a novel load balancing algorithm: FlashLB. 1. The default algorithm adopts two separate sub-procedures to optimize expert replication and placement independently: a. **Expert Replica Allotment Sub-procedure** : Determines the number of replicas for all experts. At each step, it greedily adds one more replica to the expert with the highest per-replica load, aiming to minimize load skew at the expert replica granularity (Min Max Replica, MMR). b. **Expert Replica Placement Sub-procedure** : Distributes all replicas across devices. First, it sorts the generated replicas in descending order of hotness, then iteratively places the currently hottest replica onto the device with the lowest cumulative load and available slots. However, this simplistic combination of two separate procedures lacks synergy and often leads to sub-optimal load balancing. For example, in the simple scenario illustrated below: Given 8 logical experts with hotness values [600, 560, 120, 120, 20, 10, 10, 10], and 2 replicas allocated per device across 8 devices, the default EPLB algorithm results in a maximum per-device hotness of 232 (peak-average load ratio 1.28), while our proposed FlashLB algorithm reduces this value to 205 (peak-average load ratio 1.13). <figure><img src="https://github.com/user-attachments/assets/b9b10fab-651e-4524-9942-adbca8d044a4" width="90%"</figure> 2. The default algorithm simply aggregates hotness measurements across the entire profiling window. While this provides a coarse approximation of the hotness distribution, it fails to capture the time-phased variations and temporal correlations in expert hotness (both within and between experts) across iterations—phenomena that have been observed in real-world scenarios. Such single-point hotness estimation degrades the solution quality of the load balancing algorithm. 3. The default algorithm regularly recalculates updated expert placement results for all layers without discrimination. Considering that excessive expert updates can impact Service Level Objectives (SLOs), such full-scale redeployment leads to excessively high adjustment overhead, which negatively affects end-to-end performance. ## FlashLB Algorithm Principle ### 1. Joint Optimization of Replica Allotment and Placement FlashLB achieves joint optimization of replica allotment and placement through a novel tree search approach, combined with carefully designed e Fl fficient pruning and lightweight look-ahead estimation. We partition all experts into several subsets, and for each subset, hierarchically determine the optimal replica count and placement. Leveraging efficient pruning and lightweight look-ahead estimation, the process consistently aims to optimize the globally expected inter-device load balance degree (considering both deployed and unexplored experts) while ensuring sufficient computational efficiency. Additionally, precompilation techniques are employed for acceleration, delivering load balancing that is both high-quality and practically efficient. ### 2. Multi-Episode Enhancement Instead of performing full-duration averaging like the default algorithm, FlashLB partitions each profiling interval (e.g., 1024 iterations) into multiple consecutive smaller episodes (e.g., 16 iterations). This preserves hotness fluctuation and correlation information. It then constructs a multi-objective optimization problem to co-optimize these episodes simultaneously, enabling adaptability to interleaved hotness patterns and improving statistical robustness. ### 3. Layer-wise Cherry-Picking Redeployment To reduce the overhead of frequent expert redeployment, FlashLB introduces a cherry-picking redeployment scheme. During each algorithmic decision cycle, it real-time tracks load balance degree of all layers and triggers expert placement updates only for those layers whose peak-average ratio exceeds a predefined threshold. This avoids unnecessary redeployment for stable layers, significantly reducing adjustment overhead and thereby improving end-to-end performance gains. ## Co-author: Co-authored-by: Skywalker-EP 173723846@qq.com This PR mainly introduces two key optimizations for load balancing scheduling: 1. **Add per-step heat collection function**: Support real-time collection of per-step heat information during model inference. This enables more fine-grained load balancing decisions by taking per-step heat as the optimization target, improving scheduling accuracy for dynamic and fluctuating workloads. 2. **Update FlashLB algorithm**: Upgrade the FlashLB scheduling logic to better adapt to multi-stage heat distribution scenarios. The improved algorithm can comprehensively perceive and utilize multi-stage heat characteristics, achieving more stable and efficient load balancing under complex expert deployment and dynamic traffic patterns. --------- Signed-off-by: Mercykid-bash <ruanche0218@gmail.com> Signed-off-by: xuzewei28 <xuzewei2@h-partners.com> Co-authored-by: xuzewei28 <xuzewei2@h-partners.com>
This commit is contained in:
@@ -34,6 +34,7 @@ class EplbWorker:
|
||||
self.old_expert_maps = None
|
||||
self.enable_d2d = enable_d2d
|
||||
self.rank_id = dist.get_rank()
|
||||
self.multi_stage = policy_type == 3
|
||||
|
||||
def do_update(self):
|
||||
# put data in to queue
|
||||
@@ -62,7 +63,10 @@ class EplbWorker:
|
||||
_, _, new_placement = self.calculate_rebalance_experts(load_info, old_placement)
|
||||
|
||||
if self.rank_id == 0:
|
||||
hotness = self._calculate_hotness(old_placement, load_info)
|
||||
if self.multi_stage:
|
||||
hotness = self._calculate_hotness(old_placement, load_info.sum(0))
|
||||
else:
|
||||
hotness = self._calculate_hotness(old_placement, load_info)
|
||||
current_mean, current_max = self._compute_imbalance(old_placement, hotness)
|
||||
update_mean, update_max = self._compute_imbalance(new_placement, hotness)
|
||||
logger.info(
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -30,6 +30,7 @@ from vllm_ascend.eplb.core.eplb_worker import EplbProcess
|
||||
class EplbUpdator:
|
||||
def __init__(self, eplb_config, loader: D2DExpertWeightLoader, eplb_process: EplbProcess, process):
|
||||
self.eplb_config = eplb_config
|
||||
self.multi_stage = eplb_config.eplb_policy_type == 3
|
||||
self.init_eplb(self.eplb_config.expert_map_path, process)
|
||||
self.eplb_loader = loader
|
||||
self.eplb_process = eplb_process
|
||||
@@ -131,9 +132,14 @@ class EplbUpdator:
|
||||
|
||||
def compute_and_set_moe_load(self):
|
||||
local_load = self.adaptor.get_rank_expert_workload()
|
||||
moe_load = self.comm_group.all_gather(local_load, dim=0).reshape(-1, self.world_size, *local_load.shape[1:])
|
||||
moe_load = (
|
||||
self.comm_group.all_gather(local_load, dim=0).reshape(-1, self.world_size, *local_load.shape[1:]).cpu()
|
||||
)
|
||||
|
||||
self.shared_dict["moe_load"] = moe_load.cpu()
|
||||
if self.multi_stage:
|
||||
moe_load = moe_load.permute(2, 0, 1, 3)
|
||||
|
||||
self.shared_dict["moe_load"] = moe_load
|
||||
logger.debug(f"[ModelRunner] Updated shared_dict['moe_load'] shape={moe_load.shape}")
|
||||
|
||||
return moe_load
|
||||
|
||||
@@ -299,7 +299,13 @@ class AscendFusedMoE(FusedMoE):
|
||||
get_compressed_expert_map(self._expert_map),
|
||||
)
|
||||
if self.dynamic_eplb:
|
||||
self.multi_stage = False
|
||||
self.moe_load = torch.zeros(self.local_num_experts, dtype=torch.int64).npu()
|
||||
if eplb_config.eplb_policy_type == 3:
|
||||
self.multi_stage = True
|
||||
self.load_counter = torch.tensor(0, dtype=torch.int32, device="npu")
|
||||
self.num_iter = eplb_config.expert_heat_collection_interval
|
||||
self.moe_load = torch.zeros((self.num_iter, self.local_num_experts), dtype=torch.int32, device="npu")
|
||||
|
||||
self.moe_config.num_experts = self.global_num_experts
|
||||
self.moe_config.num_local_experts = self.local_num_experts
|
||||
@@ -361,6 +367,8 @@ class AscendFusedMoE(FusedMoE):
|
||||
def clear_moe_load(self):
|
||||
if self.moe_load is not None:
|
||||
self.moe_load.zero_()
|
||||
if self.multi_stage:
|
||||
self.load_counter.zero_()
|
||||
|
||||
def maybe_all_reduce_tensor_model_parallel(self, final_hidden_states: torch.Tensor):
|
||||
"""NOTE(Yizhou): This is to override the parent class method. In `mc2commimpl`,
|
||||
@@ -493,7 +501,14 @@ class AscendFusedMoE(FusedMoE):
|
||||
if group_list_type == 1
|
||||
else torch.cat([expert_tokens[:1], expert_tokens[1:] - expert_tokens[:-1]])
|
||||
)
|
||||
self.moe_load.add_(local_load)
|
||||
if self.multi_stage:
|
||||
cur_iter = torch.remainder(self.load_counter, self.num_iter)
|
||||
self.moe_load.index_add_(
|
||||
dim=0, index=cur_iter, source=local_load.to(torch.int32, non_blocking=True).view(1, -1)
|
||||
)
|
||||
self.load_counter.add_(1)
|
||||
else:
|
||||
self.moe_load.add_(local_load)
|
||||
routed_out = forward_context.moe_comm_method.finalize(
|
||||
hidden_states=fused_experts_results.routed_out,
|
||||
reduce_results=self.reduce_results,
|
||||
|
||||
Reference in New Issue
Block a user