xc-llm-ascend/vllm_ascend/ops/common_fused_moe.py


#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# This file is a part of the vllm-ascend project.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import os.path
from typing import Callable, Optional
import torch
import torch_npu
from vllm.config import CompilationLevel, get_current_vllm_config
from vllm.distributed import (get_dp_group, get_ep_group, get_tp_group,
                              tensor_model_parallel_all_reduce)
from vllm.forward_context import get_forward_context
from vllm.model_executor.layers.fused_moe.config import \
    FusedMoEParallelConfig  # isort: skip
from vllm.model_executor.layers.fused_moe.layer import (
FusedMoE, UnquantizedFusedMoEMethod, determine_expert_map)
from vllm.model_executor.layers.shared_fused_moe import SharedFusedMoE
from vllm_ascend.ascend_config import get_ascend_config
from vllm_ascend.distributed.parallel_state import get_mc2_group
from vllm_ascend.eplb.core.eplb_utils import (determine_default_expert_map,
                                              determine_default_log2phy_map)
from vllm_ascend.ops.expert_load_balancer import ExpertLoadBalancer
from vllm_ascend.ops.moe.experts_selector import select_experts
from vllm_ascend.ops.moe.moe_comm_method import (AllGatherCommImpl,
                                                 AlltoAllCommImpl, MC2CommImpl,
                                                 NaiveMulticastCommImpl)
from vllm_ascend.utils import ACL_FORMAT_FRACTAL_NZ, is_310p, npu_stream_switch
original_unquantized_fused_moe_init_func = UnquantizedFusedMoEMethod.__init__


def fused_experts_moge(
    hidden_states: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    moe_parallel_config: FusedMoEParallelConfig,
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    top_k: int,
    global_num_experts: int,
    expert_map: Optional[torch.Tensor] = None,
    apply_router_weight_on_input: bool = False,
) -> torch.Tensor:
    """
    Run the routed experts with grouped matmuls: sort tokens by expert, apply
    the gate/up projection, SwiGLU and the down projection, then restore the
    original token order and sum over the selected experts.

    Args:
        hidden_states: Hidden states of shape (num_tokens, hidden_size).
        w1: Expert weights1 of shape (num_experts, intermediate_size * 2, hidden_size).
        w2: Expert weights2 of shape (num_experts, hidden_size, intermediate_size).
        moe_parallel_config: Parallel configuration; only ep_size is read here.
        topk_weights: Routing weights of shape (num_tokens, top_k).
        topk_ids: Selected expert IDs of shape (num_tokens, top_k).
        top_k: Number of experts to select.
        global_num_experts: Total number of experts across all EP ranks.
        expert_map: Expert mapping of shape (num_experts,). Unused in this function.
        apply_router_weight_on_input: Unused in this function.

    Returns:
        hidden_states: Hidden states after routing.
    """
    ep_size = moe_parallel_config.ep_size
    local_num_experts = global_num_experts // ep_size
    local_num_group = top_k // ep_size

    bsz, _ = hidden_states.shape
    flatten_topk_ids = topk_ids.view(-1)
    # Sort the flattened (token, slot) entries by expert ID.
    sorted_topk_ids = torch.argsort(flatten_topk_ids.float())
    sorted_topk_ids = sorted_topk_ids.to(torch.int32)
    sorted_hidden_states = hidden_states.index_select(
        0, sorted_topk_ids // local_num_group)

    experts_id = torch.arange(0,
                              local_num_experts,
                              dtype=topk_ids.dtype,
                              device=topk_ids.device)
    num_tokens_per_expert = (flatten_topk_ids.unsqueeze(-1) == experts_id).to(
        torch.float32).sum(0)
    topk_scales = topk_weights.view(-1).index_select(
        0, sorted_topk_ids).unsqueeze(-1)
    # Cumulative token counts mark the per-expert group boundaries.
    group_list = num_tokens_per_expert.cumsum(dim=0).to(torch.int64)

    gate_up_out = torch_npu.npu_grouped_matmul(
        x=[sorted_hidden_states],
        weight=[w1],
        split_item=2,
        group_list_type=0,
        group_type=0,
        group_list=group_list,
    )[0]

    if is_310p():
        gate_up_out = torch_npu.npu_swiglu(gate_up_out.to(torch.float32)).to(
            torch.float16)
    else:
        gate_up_out = torch_npu.npu_swiglu(gate_up_out)
    gate_up_out *= topk_scales

    down_out_list = torch_npu.npu_grouped_matmul(
        x=[gate_up_out],
        weight=[w2],
        split_item=2,
        group_list_type=0,
        group_type=0,
        group_list=group_list,
    )[0]

    # Invert the sort permutation to restore the original entry order.
    unsorted_topk_ids = torch.argsort(sorted_topk_ids.float()).to(torch.int32)
    unsorted_hidden_states = down_out_list.index_select(0, unsorted_topk_ids)
    final_hidden_states = unsorted_hidden_states.reshape(
        bsz, top_k // ep_size, -1).sum(1)

    return final_hidden_states
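
# Illustrative sketch, not part of the original module: the sort / group /
# restore bookkeeping in fused_experts_moge can be mimicked in plain Python.
# It assumes ep_size == 1 (so the token index is flat_index // top_k) and a
# stable sort; route_and_restore and its return values are hypothetical names
# chosen for this example only.

```python
from itertools import accumulate


def route_and_restore(topk_ids, num_experts):
    """Group flattened (token, slot) entries by expert, then undo the grouping."""
    top_k = len(topk_ids[0])
    flat = [expert for row in topk_ids for expert in row]

    # Stable argsort by expert ID: order[i] is the flat index at sorted position i.
    order = sorted(range(len(flat)), key=lambda i: flat[i])
    # Token that each sorted entry came from (assumes ep_size == 1).
    token_of = [i // top_k for i in order]

    # Cumulative per-expert counts, analogous to group_list above.
    counts = [flat.count(expert) for expert in range(num_experts)]
    group_list = list(accumulate(counts))

    # Inverse permutation: inverse[j] is the sorted position of flat index j.
    inverse = sorted(range(len(order)), key=lambda i: order[i])
    sorted_ids = [flat[i] for i in order]
    restored = [sorted_ids[p] for p in inverse]
    return order, token_of, group_list, restored
```

# Applying the inverse permutation to the sorted entries yields the original
# flattened order, which is why the result can be reshaped per token and summed.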


def unquantized_fused_moe_init_func(self, *args, **kwargs):
    original_unquantized_fused_moe_init_func(self, *args, **kwargs)

    # NOTE: self.use_aclgraph is currently only consulted by
    # UnquantizedFusedMoEMethod.forward_oot (see ops/fused_moe.py:568) to
    # work around torch.randint_like not being supported.
    # Once torch.randint_like is supported or removed, this flag can be removed.
    vllm_config = get_current_vllm_config()
    ascend_config = get_ascend_config()

    if ascend_config.torchair_graph_config.enabled:
        self.use_aclgraph = False
    else:
        self.use_aclgraph = (vllm_config.compilation_config.level
                             == CompilationLevel.PIECEWISE
                             and not vllm_config.model_config.enforce_eager)
    self.transpose = True


def forward_oot(
    self,
    layer: torch.nn.Module,
    x: torch.Tensor,
    use_grouped_topk: bool,
    top_k: int,
    router_logits: torch.Tensor,
    renormalize: bool,
    topk_group: Optional[int] = None,
    num_expert_group: Optional[int] = None,
    custom_routing_function: Optional[Callable] = None,
    scoring_func: str = "softmax",
    routed_scaling_factor: float = 1.0,
    e_score_correction_bias: Optional[torch.Tensor] = None,
    global_num_experts: int = -1,
    expert_map: Optional[torch.Tensor] = None,
    apply_router_weight_on_input: bool = False,
    activation: str = "silu",
    enable_eplb: bool = False,
    expert_load_view: Optional[torch.Tensor] = None,
    logical_to_physical_map: Optional[torch.Tensor] = None,
    logical_replica_count: Optional[torch.Tensor] = None) -> torch.Tensor:

    topk_weights, topk_ids, row_idx = select_experts(
        hidden_states=x,
        router_logits=router_logits,
        top_k=top_k,
        use_grouped_topk=use_grouped_topk,
        renormalize=renormalize,
        topk_group=topk_group,
        num_expert_group=num_expert_group,
        custom_routing_function=custom_routing_function,
        scoring_func=scoring_func,
        routed_scaling_factor=routed_scaling_factor,
        e_score_correction_bias=e_score_correction_bias,
        global_num_experts=global_num_experts)

    if topk_ids.shape[1] < top_k or is_310p():
        assert global_num_experts is not None
        return fused_experts_moge(
            hidden_states=x,
            w1=layer.w13_weight,
            w2=layer.w2_weight,
            moe_parallel_config=self.moe.moe_parallel_config,
            topk_weights=topk_weights,
            topk_ids=topk_ids,
            top_k=top_k,
            global_num_experts=global_num_experts,
            expert_map=expert_map,
            apply_router_weight_on_input=apply_router_weight_on_input)

    moe_comm_method = get_forward_context().moe_comm_method
    return moe_comm_method.fused_experts(hidden_states=x,
                                         w1=layer.w13_weight,
                                         w2=layer.w2_weight,
                                         topk_weights=topk_weights,
                                         topk_ids=topk_ids,
                                         row_idx=row_idx,
                                         global_num_experts=global_num_experts,
                                         expert_map=expert_map)
def process_weights_after_loading(self, layer):
super(UnquantizedFusedMoEMethod, self).process_weights_after_loading(layer)
if self.transpose:
w13_data = self._maybe_pad_weight(layer.w13_weight.data).transpose(
1, 2).contiguous()
layer.w13_weight = torch.nn.Parameter(w13_data, requires_grad=False)
        w2_data = self._maybe_pad_weight(layer.w2_weight.data).transpose(
            1, 2).contiguous()
        layer.w2_weight = torch.nn.Parameter(w2_data, requires_grad=False)

        self.transpose = False
    else:
        w13_data = self._maybe_pad_weight(layer.w13_weight.data)
        layer.w13_weight = torch.nn.Parameter(w13_data, requires_grad=False)

        w2_data = self._maybe_pad_weight(layer.w2_weight.data)
        layer.w2_weight = torch.nn.Parameter(w2_data, requires_grad=False)

    if not is_310p():
        layer.w13_weight.data = torch_npu.npu_format_cast(
            layer.w13_weight.data, ACL_FORMAT_FRACTAL_NZ)
        layer.w2_weight.data = torch_npu.npu_format_cast(
            layer.w2_weight.data, ACL_FORMAT_FRACTAL_NZ)


class AscendFusedMoE(FusedMoE):
    moe_counter = -1

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        AscendFusedMoE.moe_counter += 1
        self.moe_instance_id = AscendFusedMoE.moe_counter
        self.moe_config.tp_group = get_tp_group()
        self.moe_config.dp_group = get_dp_group()
        self.moe_config.ep_group = get_ep_group()
        self.moe_config.mc2_group = get_mc2_group()
        ascend_config = get_ascend_config()
        self.dynamic_eplb = ascend_config.dynamic_eplb
        self.expert_map_path = ascend_config.expert_map_path
        self.global_redundant_expert_num = ascend_config.init_redundancy_expert
        # Static EPLB: initialize from the expert map file when it exists
        # and is readable.
        if self.expert_map_path and os.path.exists(
                self.expert_map_path) and os.access(self.expert_map_path,
                                                    os.R_OK):
            self.expert_load_balancer = ExpertLoadBalancer(
                self.expert_map_path, self.global_num_experts)
            self.local_num_experts, self.expert_map = (
                self.expert_load_balancer.get_rank_placement_map(
                    self.moe_instance_id, self.ep_rank))
            self.log2phy = self.expert_load_balancer.get_rank_log2phy_map(
                self.moe_instance_id, self.ep_rank).npu()
            self.global_redundant_expert_num = (
                self.expert_load_balancer.get_global_redundant_expert_num())
        else:
            # Default MoE initialization without an expert map file.
            self.local_num_experts, self.expert_map = determine_expert_map(
                self.ep_size, self.ep_rank, self.global_num_experts)
            # Dynamic EPLB: build the default placement and log2phy maps
            # when no expert_map_path is provided.
            if self.dynamic_eplb:
                self.global_redundant_expert_num = ascend_config.init_redundancy_expert
                self.local_num_experts, self.expert_map = determine_default_expert_map(
                    self.global_num_experts, self.ep_size, self.ep_rank,
                    self.global_redundant_expert_num)
                self.log2phy = determine_default_log2phy_map(
                    self.global_num_experts, self.ep_size, self.ep_rank,
                    self.global_redundant_expert_num)
        local_num_experts = (torch.sum(
            self.expert_map != -1) if self.expert_map is not None else
                             self.global_num_experts)
        if self.dynamic_eplb:
            self.moe_load = torch.zeros(local_num_experts, dtype=torch.int64)

        for method in {
                AllGatherCommImpl, AlltoAllCommImpl, MC2CommImpl,
                NaiveMulticastCommImpl
        }:
            setattr(
                self, method.__name__.lower(),
                method(moe_config=self.moe_config))  # type: ignore[abstract]

    def update_expert_map(self, new_expert_map):
        self.expert_map = new_expert_map

    def get_map(self):
        return self.expert_map

    def get_log2phy_map(self):
        return self.logical_to_physical_map

    def clear_moe_load(self):
        if self.moe_load is not None:
            self.moe_load.zero_()

    def maybe_all_reduce_tensor_model_parallel(
            self, final_hidden_states: torch.Tensor):
        """NOTE(Yizhou): This is to override the parent class method. In `mc2commimpl`,
        and `alltoallcommimpl`, we do not need to all-reduce the final outputs since
        the outputs are already aggregated across tensor parallel ranks in the
        `finalize` function. In `allgathercommimpl`, we still need to all-reduce the
        outputs since each rank only has partial outputs.
        """
        forward_context = get_forward_context()
        moe_comm_method_name = forward_context.moe_comm_method_name
        if moe_comm_method_name in {"alltoallcommimpl", "mc2commimpl"}:
            return final_hidden_states
        else:
            return tensor_model_parallel_all_reduce(final_hidden_states)

    def forward_impl(self, hidden_states: torch.Tensor,
                     router_logits: torch.Tensor):
        assert self.quant_method is not None

        forward_context = get_forward_context()
        moe_comm_method_name = forward_context.moe_comm_method_name

        forward_context.moe_comm_method = getattr(self, moe_comm_method_name)

        hidden_states, router_logits = forward_context.moe_comm_method.prepare(
            hidden_states=hidden_states, router_logits=router_logits)

        # Matrix multiply.
        final_hidden_states = self.quant_method.apply(
            layer=self,
            x=hidden_states,
            router_logits=router_logits,
            top_k=self.top_k,
            renormalize=self.renormalize,
            use_grouped_topk=self.use_grouped_topk,
            global_num_experts=self.global_num_experts,
            expert_map=self.expert_map,
            topk_group=self.topk_group,
            num_expert_group=self.num_expert_group,
            custom_routing_function=self.custom_routing_function,
            scoring_func=self.scoring_func,
            e_score_correction_bias=self.e_score_correction_bias,
            activation=self.activation,
            apply_router_weight_on_input=self.apply_router_weight_on_input,
            enable_eplb=self.enable_eplb,
            expert_load_view=self.expert_load_view,
            logical_to_physical_map=self.logical_to_physical_map,
            logical_replica_count=self.logical_replica_count,
        )
        if isinstance(final_hidden_states, tuple):
            final_hidden_states, group_list_type, expert_tokens = final_hidden_states

            if self.dynamic_eplb:
                self.moe_load += expert_tokens if group_list_type else \
                    torch.cat([expert_tokens[:1], expert_tokens[1:] - expert_tokens[:-1]])

        final_hidden_states = forward_context.moe_comm_method.finalize(
            hidden_states=final_hidden_states,
            reduce_results=self.reduce_results)

        return final_hidden_states

    def transpose_weight(self, loaded_weight, expert_data, shard_dim):
        # Ensure training and inference weight shapes match during RL weight updates
        if (
            loaded_weight.shape[1] != expert_data.shape[1] and \
            loaded_weight.shape[0] != expert_data.shape[0]
        ):
            shard_dim = int(not shard_dim)
            loaded_weight = loaded_weight.transpose(0, 1).contiguous()
        return loaded_weight, shard_dim

    def _load_w13(self,
                  expert_data: torch.Tensor,
                  shard_dim: int,
                  shard_id: str,
                  loaded_weight: torch.Tensor,
                  tp_rank: int,
                  load_full: bool = False):
        # Index the loaded weight for tp sharding.
        # gate_up_proj: "MergedColumnParallel", so tp sharding on output_dim
        loaded_weight, shard_dim = self.transpose_weight(
            loaded_weight, expert_data, shard_dim)
        shard_size = expert_data.shape[shard_dim] // 2
        if not load_full:
            loaded_weight = loaded_weight.narrow(shard_dim,
                                                 shard_size * tp_rank,
                                                 shard_size)
        # Narrow parameter and load.
        # w1, gate_proj: Load into first logical weight of w13.
        if shard_id == "w1":
            expert_data = expert_data.narrow(shard_dim, 0, shard_size)
        # w3, up_proj: Load into second logical weight of w13.
        else:
            assert shard_id == "w3"
            expert_data = expert_data.narrow(shard_dim, shard_size, shard_size)
        expert_data.copy_(loaded_weight)

    def _load_w2(self,
                 expert_data: torch.Tensor,
                 shard_dim: int,
                 loaded_weight: torch.Tensor,
                 tp_rank: int,
                 load_full: bool = False):
        # Index the loaded weight for tp sharding.
        # down_proj: "RowParallel" so tp sharding on input_dim
        # Narrow parameter and load.
        loaded_weight, shard_dim = self.transpose_weight(
            loaded_weight, expert_data, shard_dim)
        shard_size = expert_data.shape[shard_dim]
        if not load_full:
            loaded_weight = loaded_weight.narrow(shard_dim,
                                                 shard_size * tp_rank,
                                                 shard_size)
        # w2, down_proj: Load into only logical weight of w2.
        expert_data.copy_(loaded_weight)
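
# Illustrative sketch (not part of vllm-ascend): the index arithmetic used by
# `_load_w13` above, restated on plain Python lists instead of tensors. The
# helper name `load_w13_shard` and the toy row labels are hypothetical; a list
# slice stands in for `Tensor.narrow` along `shard_dim`, and the optional
# transpose and `load_full` paths are left out.

```python
def load_w13_shard(expert_data, loaded_weight, shard_id, tp_rank):
    """Copy one rank's slice of a gate (w1) or up (w3) weight into w13.

    expert_data holds this rank's merged w13 rows: the first half is the
    gate_proj slice and the second half is the up_proj slice.
    """
    shard_size = len(expert_data) // 2
    # Slice of the full checkpoint weight owned by this tensor-parallel rank.
    start = shard_size * tp_rank
    shard = loaded_weight[start:start + shard_size]
    # w1 goes into the first half of w13, w3 into the second half.
    assert shard_id in ("w1", "w3")
    offset = 0 if shard_id == "w1" else shard_size
    expert_data[offset:offset + shard_size] = shard


# Toy setup: tp_size = 2, full gate/up weights have 4 rows each, so every
# rank stores 2 gate rows followed by 2 up rows.
gate = ["g0", "g1", "g2", "g3"]
up = ["u0", "u1", "u2", "u3"]
rank1_w13 = [None] * 4
load_w13_shard(rank1_w13, gate, "w1", tp_rank=1)
load_w13_shard(rank1_w13, up, "w3", tp_rank=1)
print(rank1_w13)
```

# In this toy setup rank 1 ends up with the second half of each projection,
# gate rows first and up rows second.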


class AscendSharedFusedMoE(SharedFusedMoE, AscendFusedMoE):

    def __init__(
        self,
        shared_experts: torch.nn.Module,
        use_overlapped: bool = True,
        **kwargs,
    ):
        AscendFusedMoE.__init__(self, **kwargs)
        self._shared_experts = shared_experts
        self.use_overlapped = use_overlapped
        self.shared_expert_stream = None
        ascend_config = get_ascend_config()
        self.multistream_overlap_shared_expert = ascend_config.multistream_overlap_shared_expert
        if self.multistream_overlap_shared_expert:
            self.shared_expert_stream = torch.npu.Stream()

    def forward(
        self,
        hidden_states: torch.Tensor,
        router_logits: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        # Make sure the shared experts stream begins after hidden_states are ready.
        if self.multistream_overlap_shared_expert:
            self.shared_expert_stream.wait_stream(  # type: ignore
                torch.npu.current_stream())
        with npu_stream_switch(self.shared_expert_stream,
                               enabled=self.multistream_overlap_shared_expert):
            # Use a separate stream to run shared experts.
            shared_out = self._shared_experts(hidden_states)

            # NOTE: This is exactly the opposite of `maybe_all_reduce_tensor_model_parallel`
            forward_context = get_forward_context()
            moe_comm_method_name = forward_context.moe_comm_method_name
            if moe_comm_method_name in {"alltoallcommimpl", "mc2commimpl"}:
                shared_out = tensor_model_parallel_all_reduce(shared_out)

        _, fused_out = AscendFusedMoE.forward(
            self,
            hidden_states=hidden_states,
            router_logits=router_logits,
        )
        # Make sure the default stream waits for the shared experts stream to finish.
        if self.multistream_overlap_shared_expert:
            torch.npu.current_stream().wait_stream(self.shared_expert_stream)
        return shared_out, fused_out

    def forward_impl(self, hidden_states: torch.Tensor,
                     router_logits: torch.Tensor):
        shared_output = torch.empty(1)
        fused_output = AscendFusedMoE.forward_impl(
            self,
            hidden_states=hidden_states,
            router_logits=router_logits,
        )
        return shared_output, fused_output


UnquantizedFusedMoEMethod.__init__ = unquantized_fused_moe_init_func
UnquantizedFusedMoEMethod.process_weights_after_loading = process_weights_after_loading
UnquantizedFusedMoEMethod.forward_oot = forward_oot
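
# Illustrative sketch (not part of vllm-ascend): how the `moe_load` update in
# `AscendFusedMoE.forward_impl` appears to treat `expert_tokens`. Reading the
# expression there, a truthy `group_list_type` means the tensor already holds
# per-expert token counts, while a falsy value means it holds a running
# (cumulative) total, so adjacent differences recover the counts. The helper
# name `counts_from_group_list` and the sample numbers are hypothetical, and
# plain lists stand in for tensors.

```python
def counts_from_group_list(expert_tokens, group_list_type):
    """Return per-expert token counts from either supported layout."""
    if group_list_type:
        # Already per-expert counts: use as-is.
        return list(expert_tokens)
    # Cumulative layout: keep the first entry, then take adjacent differences.
    return expert_tokens[:1] + [
        cur - prev for prev, cur in zip(expert_tokens, expert_tokens[1:])
    ]


# Accumulate load over two forward passes, one in each layout.
moe_load = [0, 0, 0, 0]
for tokens, kind in (([3, 5, 5, 9], 0), ([1, 0, 2, 1], 1)):
    step = counts_from_group_list(tokens, kind)
    moe_load = [total + count for total, count in zip(moe_load, step)]
print(moe_load)
```

# Both layouts end up contributing plain per-expert counts, so the
# accumulated load stays comparable across communication methods.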