Dynamic Expert Load Balance with Zero-like-overhead (#2956)

### Motivation
Currently dynamically experts balancing would stop-the-world.
Asynchronously expert load balancing would be better without flowing
problems:

Host-bound latency:
There are many cpu operations during EPLB such as
eplb-algorithm、creating p2p ops、and log2phy expert converting would
spend long cpu time, as ~1s.
Communication latency: The transfer time would cost much in the
situation without nvlink. As the weight of an expert maybe transfer to
multiple new positions, thus N times send/recv for one expert, with
result long latency. We had tested that batch_isend_irecv cost more
100ms for 16 experts weight transmission in A2 server of ascend.

SwiftBalancer would not stop-the-world anymore, in out test on NPU 1~2ms
cost for each layer while benefit 5ms-8ms decode latency with ep_size =
64.
The following updates have been made:
1、expert distribution recording with lower cost.
2、async cpu computing for eplb algo and other python operator.
3、new eplb algo with less expert rebalancing while almost the same
effect.
### Proposed Change
We will gradually migrate the EPLB logic to the VLLM community and
implement a generalized design. Relevant RFC:
https://github.com/vllm-project/vllm/issues/22246
The overall workflow involves:
<img width="801" height="302"
alt="474430541-23b06f58-23bc-44a3-a1be-00f268aeb15c"
src="https://github.com/user-attachments/assets/1d73a459-1b23-4b0a-812a-bf0a75debfed"
/>
1. Record experts distribution during forward. We using expert_token_num
after disptach instead of topk_ids, thus we got much smaller tensor
shape to reduce cost of hbm recording and add-operator.
2. Do all-gather for experts distribution. Using all-gather instead of
all-reduce as less traffic volume.
3. Wake up eplb worker process with experts distribution when
num_iterations comes. Run eplb algorithm in eplb worker.
4. Generate p2p send/recv ops and other operator such as log2phy would
cost long cpu time.
5. Lanch ibatch_send_recv in async_stream before forward.
6. After forward, wait for the ibatch_send_recv finish, then do uapte
expert map and expert weights.
### Co-author
Co-authored-by: raindaywhu raindaywhu@raindaywhu@ 163.con
Co-authored-by: njuyuan yuanjl19@smail.nju.edu.cn
Co-authored-by: qmkakaxi wjh1594260677@qq.com
Co-authored-by: Skywalker-EP 173723846@qq.com


- vLLM version: v0.10.2
- vLLM main:
567939953b

---------

Signed-off-by: offline0806 <z00858301@china.huawei.com>
Co-authored-by: offline0806 <z00858301@china.huawei.com>
This commit is contained in:
offline893
2025-09-17 10:36:43 +08:00
committed by GitHub
parent ae758dda05
commit 76844eec78
30 changed files with 2891 additions and 47 deletions

View File

@@ -26,6 +26,7 @@ from collections.abc import Iterator
from contextlib import contextmanager, nullcontext
from copy import deepcopy
from dataclasses import dataclass
from multiprocessing import Manager
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union, cast
import numpy as np
@@ -93,6 +94,12 @@ from vllm_ascend.attention.attention_mask import AttentionMaskBuilder
from vllm_ascend.attention.attention_v1 import AscendAttentionState
from vllm_ascend.attention.utils import AscendCommonAttentionMetadata
from vllm_ascend.compilation.acl_graph import ACLGraphWrapper
from vllm_ascend.eplb.adaptor.vllm_adaptor import VllmEplbAdaptor
from vllm_ascend.eplb.core.eplb_device_transfer_loader import \
D2DExpertWeightLoader
from vllm_ascend.eplb.core.eplb_worker import EplbProcess
from vllm_ascend.eplb.eplb_updator import EplbUpdator
from vllm_ascend.eplb.utils import model_register
from vllm_ascend.models.layers.mla import AscendMultiHeadLatentAttention
from vllm_ascend.multistream.ms_split import compute_split_seq_index
from vllm_ascend.platform import NPUPlatform
@@ -422,6 +429,23 @@ class NPUModelRunner(LoRAModelRunnerMixin):
dtype=torch.bool,
device=self.device,
)
self.dynamic_eplb = ascend_config.dynamic_eplb
if self.dynamic_eplb:
self.is_eplb_warmuped = False
self.eplb_loader = D2DExpertWeightLoader()
self.manager = Manager()
self.shared_dict = self.manager.dict({
"expert_map": None,
"moe_load": None,
"expert_maps": None
})
self.eplb_process = EplbProcess(shared_dict=self.shared_dict,
policy_type=1,
enable_d2d=True)
self.process = self.eplb_process._launch_process()
ascend_config = get_ascend_config()
self.eplb_updator = EplbUpdator(ascend_config, self.eplb_loader,
self.eplb_process, self.process)
self.use_async_scheduling = self.scheduler_config.async_scheduling
self.async_output_copy_stream = torch.npu.Stream() if \
@@ -1736,12 +1760,19 @@ class NPUModelRunner(LoRAModelRunnerMixin):
# Return empty ModelRunnerOuptut if there's no work to do.
return EMPTY_MODEL_RUNNER_OUTPUT
return self.kv_connector_no_forward(scheduler_output)
if self.dynamic_eplb:
self.eplb_updator.forward_before()
(attn_metadata, positions, num_scheduled_tokens_np,
num_input_tokens, num_tokens_across_dp, maybe_padded_num_tokens,
logits_indices, spec_decode_metadata, input_ids, inputs_embeds,
intermediate_tensors) = (self._prepare_inputs(
scheduler_output, intermediate_tensors))
if self.dynamic_eplb:
self.eplb_updator.take_update_info_from_eplb_process()
moe_comm_method = self._select_moe_comm_method(num_input_tokens,
self.with_prefill)
@@ -2004,7 +2035,8 @@ class NPUModelRunner(LoRAModelRunnerMixin):
captured_name = "Decode" if self.attn_state == AscendAttentionState.DecodeOnly else "Prefill"
logger.info("Profile execute duration [%s]:%s", captured_name,
" ".join(dr_str))
if self.dynamic_eplb:
self.eplb_updator.forward_end()
if not self.use_async_scheduling:
return model_runner_output
@@ -2169,6 +2201,9 @@ class NPUModelRunner(LoRAModelRunnerMixin):
num_reqs,
skip_attn=True)
if not self.in_profile_run and self.dynamic_eplb:
self.eplb_updator.forward_before()
with self.maybe_dummy_run_with_lora(self.lora_config,
num_scheduled_tokens):
if self.is_multimodal_model:
@@ -2251,6 +2286,11 @@ class NPUModelRunner(LoRAModelRunnerMixin):
num_tokens_across_dp=num_tokens_across_dp)
if need_dummy_logits:
dummy_compute_logits(hidden_states)
if self.in_profile_run and self.dynamic_eplb:
self.model.clear_all_moe_loads()
if not self.in_profile_run and self.dynamic_eplb:
self.eplb_updator.take_update_info_from_eplb_process()
self.eplb_updator.forward_end()
return hidden_states
@contextmanager
@@ -2357,12 +2397,21 @@ class NPUModelRunner(LoRAModelRunnerMixin):
max_task = max(output_size.items(), key=lambda x: x[1])[0]
return self._dummy_pooler_run_task(hidden_states, max_task)
def eplb_warmup(self):
if self.dynamic_eplb and not self.is_eplb_warmuped:
self.is_eplb_warmuped = True
self.eplb_adaptor = VllmEplbAdaptor(model=self.model)
self.eplb_loader.set_adator(self.eplb_adaptor)
self.eplb_updator.set_adaptor(self.eplb_adaptor)
self.eplb_updator.warm_up_eplb()
def load_model(self) -> None:
logger.info("Starting to load model %s...", self.model_config.model)
with DeviceMemoryProfiler() as m: # noqa: SIM117
self.model = get_model(vllm_config=self.vllm_config)
if self.dynamic_eplb:
model_register(self.model, self.model_config)
if is_310p():
from vllm.model_executor.layers.linear import (
MergedColumnParallelLinear, QKVParallelLinear,

View File

@@ -250,6 +250,7 @@ class NPUWorker(WorkerBase):
def compile_or_warm_up_model(self) -> None:
# Note: need to adapt for graph mode.
self.model_runner.eplb_warmup()
warmup_sizes = (self.vllm_config.compilation_config.compile_sizes
or []).copy()
if not self.model_config.enforce_eager: