xc-llm-ascend/vllm_ascend/eplb/utils.py
offline893 76844eec78 Dynamic Expert Load Balance with Zero-like-overhead (#2956)
### Motivation
Currently, dynamic expert load balancing stops the world.
Asynchronous expert load balancing would avoid the following
problems:

Host-bound latency:
EPLB involves many CPU operations, such as running the EPLB algorithm,
creating p2p ops, and converting the log2phy expert map, which together
take a long CPU time (~1s).
Communication latency: Transfer time is high without NVLink. The
weights of one expert may be transferred to multiple new positions,
so a single expert can need N send/recv operations, which results in
long latency. In our tests, batch_isend_irecv took more than 100ms to
transfer the weights of 16 experts on an Ascend A2 server.

SwiftBalancer no longer stops the world. In our tests on NPU, it costs
1~2ms per layer while improving decode latency by 5ms-8ms with
ep_size = 64.
The following updates have been made:
1. Expert distribution recording with lower cost.
2. Asynchronous CPU computation for the EPLB algorithm and other Python
operations.
3. A new EPLB algorithm that rebalances fewer experts with almost the
same effect.
### Proposed Change
We will gradually migrate the EPLB logic to the vLLM community and
implement a generalized design. Relevant RFC:
https://github.com/vllm-project/vllm/issues/22246
The overall workflow involves:
<img width="801" height="302"
alt="474430541-23b06f58-23bc-44a3-a1be-00f268aeb15c"
src="https://github.com/user-attachments/assets/1d73a459-1b23-4b0a-812a-bf0a75debfed"
/>
1. Record the expert distribution during forward. We use
expert_token_num after dispatch instead of topk_ids, which gives a much
smaller tensor and reduces the cost of HBM recording and the add
operator.
2. All-gather the expert distribution. All-gather is used instead of
all-reduce because it has less traffic volume.
3. Wake up the EPLB worker process with the expert distribution when
num_iterations is reached. Run the EPLB algorithm in the EPLB worker.
4. Generate p2p send/recv ops and run other operations, such as log2phy
conversion, that take a long CPU time.
5. Launch ibatch_send_recv in async_stream before forward.
6. After forward, wait for ibatch_send_recv to finish, then update the
expert map and expert weights.
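The hand-off between the forward loop and the EPLB worker (steps 3 to 6) can be illustrated with a small standard-library sketch. This is not the code from this commit: it uses a thread and two queues instead of a worker process, omits the weight transfer entirely, and `rebalance`, `eplb_worker`, and `run_decode_loop` are hypothetical names invented for the illustration. The point it tries to show is that the forward loop only submits recorded loads and polls for results, so it never blocks on the balancing algorithm.

```python
import queue
import threading


def rebalance(loads):
    """Hypothetical stand-in for the EPLB algorithm: order experts by load."""
    return sorted(range(len(loads)), key=lambda e: -loads[e])


def eplb_worker(requests, results):
    """Background worker: turns recorded loads into a new expert placement."""
    while True:
        loads = requests.get()
        if loads is None:  # shutdown sentinel
            break
        results.put(rebalance(loads))


def run_decode_loop(num_steps, num_iterations, num_experts=4):
    requests, results = queue.Queue(), queue.Queue()
    worker = threading.Thread(target=eplb_worker, args=(requests, results))
    worker.start()

    expert_map = list(range(num_experts))
    loads = [0] * num_experts
    for step in range(1, num_steps + 1):
        # Toy "forward": expert e receives e + 1 tokens per step.
        for e in range(num_experts):
            loads[e] += e + 1
        # Every num_iterations steps, hand the loads to the worker and reset.
        if step % num_iterations == 0:
            requests.put(list(loads))
            loads = [0] * num_experts
        # Pick up a finished placement if one is ready; never block on it.
        try:
            expert_map = results.get_nowait()
        except queue.Empty:
            pass

    # Shut down and apply any placement that finished after the last poll.
    requests.put(None)
    worker.join()
    while not results.empty():
        expert_map = results.get()
    return expert_map


final_map = run_decode_loop(num_steps=6, num_iterations=3)
```

In the real workflow the "apply" step also has to wait for the asynchronous weight transfer to finish before the expert map is swapped; the sketch only models the CPU-side hand-off.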
### Co-author
Co-authored-by: raindaywhu raindaywhu@raindaywhu@ 163.con
Co-authored-by: njuyuan yuanjl19@smail.nju.edu.cn
Co-authored-by: qmkakaxi wjh1594260677@qq.com
Co-authored-by: Skywalker-EP 173723846@qq.com


- vLLM version: v0.10.2
- vLLM main: 567939953b

---------

Signed-off-by: offline0806 <z00858301@china.huawei.com>
Co-authored-by: offline0806 <z00858301@china.huawei.com>
2025-09-17 10:36:43 +08:00

78 lines
2.8 KiB
Python

#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-ascend project.
#
# TODO: Once https://github.com/vllm-project/vllm/pull/23553 is merged into vLLM, remove this model register.
import types

import torch


def get_expert_map(self, layer_id):
    # layer_id is the absolute layer index, including any leading dense layers.
    return self.model.layers[layer_id].mlp.experts.get_map()


def get_log2phy_map(self, layer_id):
    return self.model.layers[layer_id].mlp.experts.get_log2phy_map()


def get_all_expert_map(self, num_moe_layers):
    all_loads = []
    # Skip the leading dense layers, if the model declares any.
    num_dense_layers = self.num_dense_layers if hasattr(
        self, "num_dense_layers") else 0
    for layer_id in range(num_moe_layers):
        load_tensor = self.get_expert_map(
            layer_id + num_dense_layers)  # (num_experts_per_layer,)
        all_loads.append(load_tensor)

    return torch.stack(all_loads, dim=0)


def get_all_moe_loads(self):
    num_dense_layers = self.num_dense_layers if hasattr(
        self, "num_dense_layers") else 0
    # Stack the per-layer load tensors along a new leading layer dimension.
    all_moe_loads = torch.stack(
        [self.model.layers[layer_id + num_dense_layers].mlp.experts.moe_load
         for layer_id in range(self.num_moe_layers)],
        dim=0
    )
    return all_moe_loads


def clear_all_moe_loads(self):
    num_dense_layers = self.num_dense_layers if hasattr(
        self, "num_dense_layers") else 0
    for layer_id in range(self.num_moe_layers):
        self.model.layers[layer_id +
                          num_dense_layers].mlp.experts.clear_moe_load()


def model_register(model, model_config):
    # Bind the helper functions above to this model instance as methods.
    model.get_expert_map = types.MethodType(get_expert_map, model)
    model.get_log2phy_map = types.MethodType(get_log2phy_map, model)
    model.get_all_expert_map = types.MethodType(get_all_expert_map, model)
    model.get_all_moe_loads = types.MethodType(get_all_moe_loads, model)
    model.clear_all_moe_loads = types.MethodType(clear_all_moe_loads, model)

    config = model_config.hf_config

    if config.model_type == "qwen3_moe":
        model.num_moe_layers = config.num_hidden_layers
    elif config.model_type == "deepseek_v2" or config.model_type == "deepseek_v3":
        # The first first_k_dense_replace layers are dense, not MoE.
        num_dense_layers = config.first_k_dense_replace
        model.num_moe_layers = config.num_hidden_layers - num_dense_layers
    else:
        raise NotImplementedError("EPLB is not supported.")
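
The binding pattern used by `model_register` and the dense-layer offset used by the helpers can be tried without torch or vLLM. The following is a minimal sketch under stated assumptions: `make_toy_model` and the list-valued `get_map` are hypothetical stand-ins built from `types.SimpleNamespace`, not part of vllm-ascend. The sketch also sets `num_dense_layers` on the toy model explicitly, since `model_register` itself only computes it as a local variable and the helpers fall back to 0 when the attribute is absent.

```python
import types
from types import SimpleNamespace


def get_expert_map(self, layer_id):
    # Same shape as the helper in this file; the experts object is a toy.
    return self.model.layers[layer_id].mlp.experts.get_map()


def make_toy_model(num_hidden_layers, first_k_dense_replace):
    """Build a hypothetical model whose first layers are dense (no experts)."""
    layers = []
    for layer_id in range(num_hidden_layers):
        if layer_id < first_k_dense_replace:
            layers.append(SimpleNamespace(mlp=SimpleNamespace()))
        else:
            # Toy expert map: a list tagged with the absolute layer index.
            experts = SimpleNamespace(
                get_map=lambda layer_id=layer_id: [layer_id] * 2)
            layers.append(SimpleNamespace(mlp=SimpleNamespace(experts=experts)))

    model = SimpleNamespace(model=SimpleNamespace(layers=layers))
    model.num_dense_layers = first_k_dense_replace  # assumption, see above
    model.num_moe_layers = num_hidden_layers - first_k_dense_replace
    # Bind the plain function to this instance, as model_register does.
    model.get_expert_map = types.MethodType(get_expert_map, model)
    return model


toy = make_toy_model(num_hidden_layers=5, first_k_dense_replace=2)
# MoE layer i lives at absolute index i + num_dense_layers.
maps = [
    toy.get_expert_map(i + toy.num_dense_layers)
    for i in range(toy.num_moe_layers)
]
```

Binding with `types.MethodType` attaches the helpers to one model instance without touching its class, which fits the TODO above: the registration can be removed once the methods exist upstream.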