xc-llm-ascend/vllm_ascend/eplb/eplb_updator.py
Mercykid-bash 132f3c5d0a Support per-step heat collection and enhance FlashLB for multi-stage load balancing (#6477)
# Feature: FlashLB algorithm

## Purpose

This Pull Request enhances the EPLB (Expert Parallelism Load Balancing)
system by introducing a new load balancing algorithm, FlashLB. It
addresses three limitations of the default algorithm:

1. The default algorithm uses two separate sub-procedures that optimize
expert replication and placement independently:

a. **Expert Replica Allotment Sub-procedure**: Determines the number of
replicas for each expert. At each step, it greedily adds one more
replica to the expert with the highest per-replica load, aiming to
minimize load skew at the expert replica granularity (Min Max Replica,
MMR).

b. **Expert Replica Placement Sub-procedure**: Distributes all replicas
across devices. It first sorts the generated replicas in descending
order of hotness, then iteratively places the hottest remaining replica
onto the least-loaded device that still has a free slot.

However, this simple combination of two separate procedures lacks
synergy and often leads to sub-optimal load balancing. For example, in
the scenario illustrated below, given 8 logical experts with
hotness values [600, 560, 120, 120, 20, 10, 10, 10] and 8 devices with
2 replica slots each, the default EPLB algorithm
results in a maximum per-device hotness of 232 (peak-average load ratio
1.28), while FlashLB reduces this value to 205
(peak-average load ratio 1.13).

<figure><img
src="https://github.com/user-attachments/assets/b9b10fab-651e-4524-9942-adbca8d044a4"
width="90%"></figure>
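
The two default sub-procedures described above can be sketched in a few lines of plain Python. This is a minimal illustration written from the prose description, not the actual vllm-ascend implementation; the function names `allot_replicas` and `place_replicas` are hypothetical, and real implementations may add constraints (for example on co-locating replicas of the same expert) that are omitted here.

```python
def allot_replicas(hotness, total_slots):
    """Greedy allotment: repeatedly give one more replica to the expert
    whose per-replica load is currently the highest."""
    counts = [1] * len(hotness)  # every expert needs at least one replica
    for _ in range(total_slots - len(hotness)):
        hottest = max(range(len(hotness)), key=lambda e: hotness[e] / counts[e])
        counts[hottest] += 1
    return counts


def place_replicas(hotness, counts, num_devices, slots_per_device):
    """Greedy placement: hottest replica first, onto the least-loaded
    device that still has a free slot."""
    replicas = sorted(
        (hotness[e] / counts[e] for e in range(len(hotness)) for _ in range(counts[e])),
        reverse=True,
    )
    loads = [0.0] * num_devices
    free = [slots_per_device] * num_devices
    for load in replicas:
        target = min((d for d in range(num_devices) if free[d] > 0), key=lambda d: loads[d])
        loads[target] += load
        free[target] -= 1
    return loads


# The example from the text: 8 experts, 8 devices, 2 slots per device.
hotness = [600, 560, 120, 120, 20, 10, 10, 10]
counts = allot_replicas(hotness, total_slots=16)
loads = place_replicas(hotness, counts, num_devices=8, slots_per_device=2)
peak = max(loads)
ratio = peak / (sum(hotness) / 8)
print(counts, peak, round(ratio, 2))
```

Under these assumptions the sketch reproduces the figures quoted for the default algorithm: the two hottest experts receive 5 replicas each, the peak per-device hotness is 232, and the peak-average ratio is 1.28 (average per-device hotness is 1450 / 8 = 181.25).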

2. The default algorithm simply aggregates hotness measurements across
the entire profiling window. While this provides a coarse approximation
of the hotness distribution, it fails to capture the time-phased
variations and temporal correlations in expert hotness (both within and
between experts) across iterations—phenomena that have been observed in
real-world scenarios. Such single-point hotness estimation degrades the
solution quality of the load balancing algorithm.

3. The default algorithm regularly recalculates updated expert placement
results for all layers without discrimination. Considering that
excessive expert updates can impact Service Level Objectives (SLOs),
such full-scale redeployment leads to excessively high adjustment
overhead, which negatively affects end-to-end performance.

## FlashLB Algorithm Principle

### 1. Joint Optimization of Replica Allotment and Placement

FlashLB jointly optimizes replica allotment and placement through a
novel tree search, combined with carefully designed efficient pruning
and lightweight look-ahead estimation. It partitions all experts into
several subsets and, for each subset, hierarchically determines the
optimal replica count and placement. Throughout the search, the process
aims to optimize the globally expected inter-device load balance degree
(considering both deployed and unexplored experts) while remaining
computationally efficient. Precompilation techniques are additionally
employed for acceleration, delivering load balancing that is both
high-quality and practically efficient.

### 2. Multi-Episode Enhancement

Instead of performing full-duration averaging like the default
algorithm, FlashLB partitions each profiling interval (e.g., 1024
iterations) into multiple consecutive smaller episodes (e.g., 16
iterations). This preserves hotness fluctuation and correlation
information. It then constructs a multi-objective optimization problem
to co-optimize these episodes simultaneously, enabling adaptability to
interleaved hotness patterns and improving statistical robustness.
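
To illustrate why episode-level statistics carry more information than a single full-window aggregate, here is a minimal sketch in plain Python. The helper `split_into_episodes` and the toy data are hypothetical and only demonstrate the partitioning step; they are not the FlashLB optimizer itself.

```python
def split_into_episodes(per_step_hotness, episode_len):
    """Sum per-step hotness within consecutive episodes of `episode_len` steps.

    per_step_hotness: list of per-iteration lists, one value per expert.
    Returns one aggregated hotness vector per episode.
    """
    episodes = []
    for start in range(0, len(per_step_hotness), episode_len):
        chunk = per_step_hotness[start:start + episode_len]
        episodes.append([sum(col) for col in zip(*chunk)])
    return episodes


# Two experts whose load alternates in phase over an 8-step window.
steps = [[10, 0]] * 4 + [[0, 10]] * 4

whole_window = [sum(col) for col in zip(*steps)]     # full-duration aggregate
episodes = split_into_episodes(steps, episode_len=4)  # per-episode aggregates

print(whole_window)  # both experts look equally hot
print(episodes)      # the episodes show they are never hot at the same time
```

In this toy case the full-window view reports identical hotness for both experts, while the per-episode view reveals that their load is perfectly interleaved, which is the kind of pattern a placement decision could exploit.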

### 3. Layer-wise Cherry-Picking Redeployment

To reduce the overhead of frequent expert redeployment, FlashLB
introduces a cherry-picking redeployment scheme. During each algorithmic
decision cycle, it tracks the load balance degree of all layers in real
time and triggers expert placement updates only for those layers whose
peak-average ratio exceeds a predefined threshold. This avoids
unnecessary redeployment for stable layers, reducing adjustment
overhead and thereby improving end-to-end performance.
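
The selection rule can be sketched as follows. This is a minimal illustration assuming per-layer device loads are already available; the function name `layers_to_redeploy` and the threshold value are hypothetical and are not taken from the FlashLB code.

```python
def layers_to_redeploy(device_loads_per_layer, threshold):
    """Return the ids of layers whose peak-average load ratio exceeds `threshold`."""
    selected = []
    for layer_id, loads in enumerate(device_loads_per_layer):
        mean = sum(loads) / len(loads)
        # Treat an idle layer as perfectly balanced to avoid dividing by zero.
        ratio = max(loads) / mean if mean > 0 else 1.0
        if ratio > threshold:
            selected.append(layer_id)
    return selected


# Per-device load for three layers on four devices (toy numbers).
loads = [
    [100, 100, 100, 100],  # ratio 1.0: balanced, left untouched
    [250, 50, 50, 50],     # ratio 2.5: heavily skewed, redeployed
    [120, 110, 90, 80],    # ratio 1.2: mildly skewed, below the threshold
]
picked = layers_to_redeploy(loads, threshold=1.25)
print(picked)
```

Only the skewed layer is selected, so the weight transfers for the other two layers are skipped in this cycle.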

## Co-author:

Co-authored-by: Skywalker-EP 173723846@qq.com

This PR mainly introduces two key optimizations for load balancing
scheduling:
1. **Add per-step heat collection function**:
Support real-time collection of per-step heat information during model
inference. This enables more fine-grained load balancing decisions by
taking per-step heat as the optimization target, improving scheduling
accuracy for dynamic and fluctuating workloads.

2. **Update FlashLB algorithm**:
Upgrade the FlashLB scheduling logic to better adapt to multi-stage heat
distribution scenarios. The improved algorithm can comprehensively
perceive and utilize multi-stage heat characteristics, achieving more
stable and efficient load balancing under complex expert deployment and
dynamic traffic patterns.

---------

Signed-off-by: Mercykid-bash <ruanche0218@gmail.com>
Signed-off-by: xuzewei28 <xuzewei2@h-partners.com>
Co-authored-by: xuzewei28 <xuzewei2@h-partners.com>
2026-03-12 15:49:09 +08:00

178 lines
7.0 KiB
Python

#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-ascend project.
#
# Todo: Once https://github.com/vllm-project/vllm/issues/22246 is merged in vllm. Remove this updator.
import numpy
import torch
import torch.distributed as dist
import vllm.envs as envs
from vllm.logger import logger
from vllm_ascend.distributed.parallel_state import get_dynamic_eplb_group
from vllm_ascend.eplb.adaptor.vllm_adaptor import VllmEplbAdaptor
from vllm_ascend.eplb.core.eplb_device_transfer_loader import D2DExpertWeightLoader
from vllm_ascend.eplb.core.eplb_worker import EplbProcess


class EplbUpdator:
    def __init__(self, eplb_config, loader: D2DExpertWeightLoader, eplb_process: EplbProcess, process):
        self.eplb_config = eplb_config
        self.multi_stage = eplb_config.eplb_policy_type == 3
        self.init_eplb(self.eplb_config.expert_map_path, process)
        self.eplb_loader = loader
        self.eplb_process = eplb_process
        self.shared_dict = self.eplb_process.shared_dict
        self.comm_group = get_dynamic_eplb_group()

    def set_adaptor(self, adaptor: VllmEplbAdaptor):
        self.adaptor = adaptor
        self.num_moe_layers = self.adaptor.num_moe_layers
        local_load = self.adaptor.get_rank_expert_workload()
        self.world_size = dist.get_world_size()
        self.device = local_load.device
        self.eplb_loader.num_layers = self.adaptor.num_dense_layers + self.adaptor.num_moe_layers

    def init_eplb(self, expert_map_path, process):
        self.rank_id = dist.get_rank()
        self.num_expert_load_gather = 10
        self.periodic_load_gather = True
        self.expert_heat_collection_interval: torch.int64 = self.eplb_config.expert_heat_collection_interval
        self.expert_map_path = expert_map_path
        self.expert_map_record_path = self.eplb_config.expert_map_record_path

        try:
            if not envs.VLLM_ALLOW_EXPERT_LOAD_COLLECTING:
                self.num_expert_load_gather = self.expert_heat_collection_interval
                self.periodic_load_gather = False
        except Exception:
            self.num_expert_load_gather = self.expert_heat_collection_interval
            self.periodic_load_gather = False

        self.reqs = []
        self.update_info_all = []

        self.cur_iterations: torch.int64 = 0

        self.algorithm_execution_interval: torch.int64 = self.eplb_config.algorithm_execution_interval

        self.process = process

        logger.info(f"[ModelRunner] Launched EPLB process (pid={self.process.pid})")

    def update_iteration(self):
        self.cur_iterations += 1
        if self.cur_iterations == (
            self.expert_heat_collection_interval + self.algorithm_execution_interval + self.num_moe_layers
        ):
            if self.expert_map_record_path is not None:
                self.adaptor._export_tensor_to_file(self.shared_dict["expert_maps"], self.expert_map_record_path)

            self.adaptor.model.clear_all_moe_loads()
            self.cur_iterations = 0

    def get_update_info_flag(self):
        return self.cur_iterations == (self.expert_heat_collection_interval + self.algorithm_execution_interval - 1)

    def wakeup_eplb_worker_flag(self):
        return self.cur_iterations == (self.expert_heat_collection_interval - 1)

    def update_expert_weight_flag(self):
        weight_update_counter = self.cur_iterations - (
            self.expert_heat_collection_interval + self.algorithm_execution_interval
        )
        return weight_update_counter >= 0 and weight_update_counter < self.num_moe_layers

    def wakeup_eplb_worker(self):
        self.eplb_process.planner_q.put(1)

    def forward_before(self):
        # Batch after eplb process being triggered, get update info provided by eplb process
        if self.get_update_info_flag():
            self.update_info_all = self.eplb_process.block_update_q.get()

        if self.update_expert_weight_flag():
            (expert_send_info, expert_recv_info, updated_expert_map, log2phy_map, layer_id) = self.update_info_all.pop(
                0
            )
            log2phy_map_this_rank = torch.from_numpy(numpy.array(log2phy_map))
            self.eplb_loader.set_log2phy_map(log2phy_map_this_rank)
            updated_expert_map_this_rank = torch.from_numpy(numpy.array(updated_expert_map))
            self.eplb_loader.generate_expert_d2d_transfer_task(
                expert_send_info,
                expert_recv_info,
                updated_expert_map_this_rank,
                layer_id + self.adaptor.num_dense_layers,
            )

            # set asynchronous stream for d2d expert weight update
            self.reqs = []
            self.eplb_loader.asyn_expert_weight_transfer(self.reqs)

    def forward_end(self):
        if self.wakeup_eplb_worker_flag():
            self.compute_and_set_moe_load()
            self.wakeup_eplb_worker()

        if self.update_expert_weight_flag() and self.expert_map_record_path is None:
            self.eplb_loader.update_expert_map_and_weight(self.reqs)

        self.update_iteration()

    def compute_and_set_moe_load(self):
        local_load = self.adaptor.get_rank_expert_workload()
        moe_load = (
            self.comm_group.all_gather(local_load, dim=0).reshape(-1, self.world_size, *local_load.shape[1:]).cpu()
        )
        if self.multi_stage:
            moe_load = moe_load.permute(2, 0, 1, 3)
        self.shared_dict["moe_load"] = moe_load
        logger.debug(f"[ModelRunner] Updated shared_dict['moe_load'] shape={moe_load.shape}")
        return moe_load

    def warm_up_eplb(self):
        self.shared_dict["expert_maps"] = self.adaptor.get_global_expert_map()
        self.compute_and_set_moe_load()

        src_tensor = torch.empty((1,), device=self.device)

        comm_op_list = []

        for dst_rank in range(self.world_size):
            if dst_rank == self.rank_id:
                continue
            comm_op_list.append(dist.P2POp(dist.isend, src_tensor, dst_rank))

        for src_rank in range(self.world_size):
            if src_rank == self.rank_id:
                continue
            comm_op_list.append(dist.P2POp(dist.irecv, src_tensor, src_rank))

        if comm_op_list:
            reqs = dist.batch_isend_irecv(comm_op_list)
            for req in reqs:
                req.wait()

    def shutdown(self):
        """
        Clean up the EPLB process.
        """
        if self.process.is_alive():
            self.process.terminate()
            self.process.join()
            logger.info("[ModelRunner] EPLB process terminated")
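
The iteration schedule implied by the flag methods of `EplbUpdator` can be visualized with a small standalone sketch. It mirrors only the arithmetic of `wakeup_eplb_worker_flag`, `get_update_info_flag`, and `update_expert_weight_flag`; the helper `phase_at`, the action labels, and the interval values are hypothetical and chosen for illustration.

```python
def phase_at(it, collect, algo, num_moe_layers):
    """Return the actions triggered when the iteration counter equals `it`.

    collect: expert_heat_collection_interval
    algo:    algorithm_execution_interval
    """
    actions = []
    if it == collect - 1:
        # forward_end: gather the MoE load and wake the EPLB worker
        actions.append("wake_worker")
    if it == collect + algo - 1:
        # forward_before: fetch the update plan produced by the worker
        actions.append("fetch_update_info")
    if 0 <= it - (collect + algo) < num_moe_layers:
        # one MoE layer has its expert weights updated per iteration
        actions.append("update_layer_weights")
    return actions


collect, algo, layers = 4, 2, 3
cycle_len = collect + algo + layers  # the counter resets to 0 after this many steps
schedule = {it: phase_at(it, collect, algo, layers) for it in range(cycle_len)}
for it, actions in schedule.items():
    print(it, actions)
```

With these toy values, one cycle spans 9 iterations: heat is collected during iterations 0 to 3, the worker is woken at iteration 3, its plan is fetched at iteration 5, and the three MoE layers are updated one per iteration at iterations 6, 7, and 8.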