Files
xc-llm-ascend/vllm_ascend/ops/fused_moe/moe_comm_method.py

309 lines
13 KiB
Python
Raw Normal View History

# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# Copyright 2023 The vLLM team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-ascend project.
from __future__ import annotations
from abc import ABC, abstractmethod
from typing import Any, Dict, Optional
import torch
from vllm.forward_context import get_forward_context
from vllm.model_executor.layers.fused_moe import FusedMoEConfig
from vllm_ascend.ascend_forward_context import MoECommType
from vllm_ascend.ops.fused_moe.moe_mlp import unified_apply_mlp
from vllm_ascend.ops.fused_moe.prepare_finalize import (
PrepareAndFinalizeWithAll2All, PrepareAndFinalizeWithAllGather,
PrepareAndFinalizeWithMC2, QuantType)
from vllm_ascend.ops.fused_moe.token_dispatcher import (
TokenDispatcherWithAll2AllV, TokenDispatcherWithAllGather,
TokenDispatcherWithMC2)
# Module-level registry of MoE communication method instances, keyed by
# MoECommType. It starts empty; setup_moe_comm_method() fills it and
# get_moe_comm_method() reads from it.
_MoECommMethods: Dict[Optional[MoECommType], MoECommMethod] = {}
def get_moe_comm_method(
        moe_comm_type: Optional[MoECommType]) -> Optional[MoECommMethod]:
    """Look up the communication method registered for ``moe_comm_type``.

    Args:
        moe_comm_type: Registry key; ``None`` is accepted as a key as well.

    Returns:
        The registered ``MoECommMethod`` instance, or ``None`` when nothing
        is stored under that key.
    """
    registered = _MoECommMethods.get(moe_comm_type)
    return registered
def setup_moe_comm_method(moe_config):
    """Instantiate every MoE communication method and register it.

    Each call builds fresh instances from ``moe_config`` and overwrites
    whatever was previously stored under the same ``MoECommType`` keys.

    Args:
        moe_config: Configuration object handed to each implementation's
            constructor.
    """
    # Listed in construction order: ALLTOALL, ALLGATHER, MC2, FUSED_ALLTOALL.
    implementations = (
        (MoECommType.ALLTOALL, AlltoAllCommImpl),
        (MoECommType.ALLGATHER, AllGatherCommImpl),
        (MoECommType.MC2, MC2CommImpl),
        (MoECommType.FUSED_ALLTOALL, FusedAlltoAllCommImpl),
    )
    for comm_type, impl_cls in implementations:
        _MoECommMethods[comm_type] = impl_cls(moe_config)
class MoECommMethod(ABC):
    """Base class for MoE communication methods.

    A concrete subclass supplies a token dispatcher and a prepare/finalize
    helper through the two abstract factory methods. This base class stores
    them and exposes ``prepare``, ``fused_experts`` and ``finalize`` on top
    of them.
    """

    def __init__(self, moe_config: FusedMoEConfig):
        """Store ``moe_config`` and build the dispatcher and helper objects."""
        # Subclass factories read ``self.moe_config``, so it must be assigned
        # before they are invoked.
        self.moe_config = moe_config
        self.token_dispatcher = self._get_token_dispatcher()
        self.prepare_finalize = self._get_prepare_finalize()

    def prepare(
        self,
        hidden_states: torch.Tensor,
        router_logits: torch.Tensor,
        enable_shared_expert_dp: bool = False,
        replace_allreduce: bool = False,
        quant_type: QuantType = QuantType.NONE,
    ) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor],
               Optional[dict]]:
        """Delegate pre-MoE preparation to ``self.prepare_finalize.prepare``.

        All arguments are forwarded positionally and unchanged.

        Returns:
            ``(hidden_states, router_logits, mc2_mask, context_metadata)``
            exactly as produced by the helper.
            NOTE(review): ``context_metadata`` is annotated as an optional
            dict to match what ``finalize`` accepts; confirm against the
            helper implementations.
        """
        hidden_states, router_logits, mc2_mask, context_metadata = (
            self.prepare_finalize.prepare(hidden_states, router_logits,
                                          enable_shared_expert_dp,
                                          replace_allreduce, quant_type))
        return hidden_states, router_logits, mc2_mask, context_metadata

    def finalize(self,
                 hidden_states: torch.Tensor,
                 reduce_results: bool,
                 context_metadata: Optional[dict] = None) -> torch.Tensor:
        """Delegate post-MoE finalization to ``self.prepare_finalize.finalize``.

        Returns:
            The tensor returned by the helper.
        """
        hidden_states = self.prepare_finalize.finalize(hidden_states,
                                                       reduce_results,
                                                       context_metadata)
        return hidden_states

    def fused_experts(
            self,
            hidden_states: torch.Tensor,
            w1: torch.Tensor | list[torch.Tensor],
            w2: torch.Tensor | list[torch.Tensor],
            topk_weights: torch.Tensor,
            topk_ids: torch.Tensor,
            activation: str = "silu",
            apply_router_weight_on_input: bool = False,
            use_int8_w8a8: bool = False,
            use_int4_w4a8: bool = False,
            use_int4_w4a16: bool = False,
            global_num_experts: Optional[int] = None,
            expert_map: Optional[torch.Tensor] = None,
            w1_scale: Optional[list[torch.Tensor]] = None,
            w2_scale: Optional[list[torch.Tensor]] = None,
            w1_scale_bias: Optional[torch.Tensor] = None,
            w2_scale_bias: Optional[torch.Tensor] = None,
            w1_offset: Optional[torch.Tensor] = None,
            w2_offset: Optional[torch.Tensor] = None,
            # For Cube/Vector parallel
            shared_experts: Optional[Any] = None,
            quantized_x_for_share: Optional[Any] = None,
            dynamic_scale_for_share: Optional[Any] = None,
            # For load balance
            log2phy: Optional[torch.Tensor] = None,
            global_redundant_expert_num: int = 0,
            need_trans: bool = False,
            dynamic_eplb: bool = False,
            mc2_mask: Optional[torch.Tensor] = None,
            pertoken_scale: Optional[torch.Tensor] = None):
        """Run token dispatch, the expert MLP and token combine in sequence.

        The method calls ``self.token_dispatcher.token_dispatch``, feeds the
        result to ``unified_apply_mlp`` and passes the MLP output to
        ``self.token_dispatcher.token_combine``.

        ``activation`` and ``global_num_experts`` are accepted but not read
        by this implementation.

        Returns:
            The combined hidden states. When ``dynamic_eplb`` is true, the
            tuple ``(final_hidden_states, group_list_type, expert_tokens)``
            is returned instead.

        Raises:
            AssertionError: If ``hidden_states.dtype`` is not one of
                float32, float16, bfloat16 or int8, or if the forward
                context has no ``moe_comm_method``.
        """
        # Check constraints
        assert hidden_states.dtype in [
            torch.float32, torch.float16, torch.bfloat16, torch.int8
        ]

        moe_comm_method = get_forward_context().moe_comm_method
        assert moe_comm_method is not None, "Missing communication context"

        # NOTE(review): the dispatch step does not count use_int4_w4a16 as
        # quantized while the MLP step below does - presumably intentional;
        # confirm against the dispatcher implementations.
        results = self.token_dispatcher.token_dispatch(
            hidden_states=hidden_states,
            topk_weights=topk_weights,
            topk_ids=topk_ids,
            expert_map=expert_map,
            log2phy=log2phy,
            global_redundant_expert_num=global_redundant_expert_num,
            shared_experts=shared_experts,
            quantized_x_for_share=quantized_x_for_share,
            dynamic_scale_for_share=dynamic_scale_for_share,
            mc2_mask=mc2_mask,
            apply_router_weight_on_input=apply_router_weight_on_input,
            with_quant=use_int8_w8a8 or use_int4_w4a8,
            dynamic_eplb=dynamic_eplb,
            pertoken_scale=pertoken_scale)

        # "hidden_states", "group_list" and "group_list_type" are required
        # keys; the remaining entries are optional and default to None.
        permuted_hidden_states = results["hidden_states"]
        expert_tokens = results["group_list"]
        dynamic_scale = results.get("dynamic_scale")
        group_list_type = results["group_list_type"]
        topk_scales = results.get("topk_scales")
        context_metadata = results.get("context_metadata")

        mlp_output = unified_apply_mlp(hidden_states=permuted_hidden_states,
                                       w1=w1,
                                       w1_scale=w1_scale,
                                       w2=w2,
                                       w2_scale=w2_scale,
                                       group_list=expert_tokens,
                                       dynamic_scale=dynamic_scale,
                                       group_list_type=group_list_type,
                                       w1_scale_bias=w1_scale_bias,
                                       w2_scale_bias=w2_scale_bias,
                                       w1_offset=w1_offset,
                                       w2_offset=w2_offset,
                                       topk_scales=topk_scales,
                                       with_quant=use_int8_w8a8
                                       or use_int4_w4a8 or use_int4_w4a16,
                                       fusion=use_int8_w8a8,
                                       need_trans=need_trans,
                                       dynamic_eplb=dynamic_eplb)

        final_hidden_states = self.token_dispatcher.token_combine(
            hidden_states=mlp_output, context_metadata=context_metadata)

        if dynamic_eplb:
            # Also hand back the per-expert token information from dispatch.
            return (final_hidden_states, group_list_type, expert_tokens)

        return final_hidden_states

    @abstractmethod
    def _get_token_dispatcher(self):
        """Return the token dispatcher used by ``fused_experts``."""
        raise NotImplementedError(
            "_get_token_dispatcher function not implemented.")

    @abstractmethod
    def _get_prepare_finalize(self):
        """Return the helper used by ``prepare`` and ``finalize``."""
        raise NotImplementedError(
            "_get_prepare_finalize function not implemented.")
class AllGatherCommImpl(MoECommMethod):
    """All-gather based MoE communication built on NPU-specific ops.

    Equivalent to NativeAllGatherCommImpl, but relies on NPU-specific ops
    for better performance. It should be compatible with all scenarios and
    is therefore the default MoE communication method.

    Pre-processing uses `torch_npu.npu_moe_init_routing_v2` and
    post-processing uses `torch_npu.npu_moe_token_unpermute` to handle the
    token-to-expert mapping and communication efficiently.

    NOTE(Yizhou): The expected pairings would be
    `torch_npu.npu_moe_init_routing_v2` with
    `torch_npu.npu_moe_finalize_routing`, or
    `torch_npu.npu_moe_token_permute` with
    `torch_npu.npu_moe_token_unpermute`. The mixed pairing is used because
    `npu_moe_finalize_routing` leads to accuracy issues. This is a
    workaround and should be removed after that issue is fixed.
    """

    def _get_token_dispatcher(self):
        """Build the all-gather token dispatcher from the MoE config."""
        config = self.moe_config
        return TokenDispatcherWithAllGather(
            top_k=config.experts_per_token,
            num_experts=config.num_experts,
            num_local_experts=config.num_local_experts,
        )

    def _get_prepare_finalize(self):
        """Build the all-gather prepare/finalize helper."""
        helper = PrepareAndFinalizeWithAllGather(self.moe_config)
        return helper
class MC2CommImpl(MoECommMethod):
    """MoE communication using the MC2 method.

    Intended for scenarios where:
    1. `enable_expert_parallel=True` (`enable_expert_parallel=False` is not
       supported).
    2. `npu_moe_distribute_dispatch` and `npu_moe_distribute_combine` are
       available.

    MC2 is optimized for Communication and Computation parallelism on
    Ascend devices.
    """

    def _get_token_dispatcher(self):
        """Build the MC2 token dispatcher; it takes no arguments."""
        dispatcher = TokenDispatcherWithMC2()
        return dispatcher

    def _get_prepare_finalize(self):
        """Build the MC2 prepare/finalize helper from the MoE config."""
        helper = PrepareAndFinalizeWithMC2(self.moe_config)
        return helper
class AlltoAllCommImpl(MoECommMethod):
    """MoE communication using all-to-all token exchange.

    Intended for scenarios where:
    1. `enable_expert_parallel=True`.
    2. `npu_grouped_matmul` is available.

    Tokens are exchanged between data parallel ranks with all-to-all
    communication before and after the MLP computation. It should perform
    better than AllGatherCommImpl when DP size > 1.
    """

    def _get_token_dispatcher(self):
        """Build the all-to-all-v token dispatcher from the MoE config."""
        config = self.moe_config
        return TokenDispatcherWithAll2AllV(
            top_k=config.experts_per_token,
            num_experts=config.num_experts,
            num_local_experts=config.num_local_experts,
        )

    def _get_prepare_finalize(self):
        """Build the all-to-all prepare/finalize helper."""
        helper = PrepareAndFinalizeWithAll2All(self.moe_config)
        return helper
add `dispatch_gmm_combine` kernel (#3532) ### What this PR does / why we need it? This PR introduces the Ascend implementation of the `dispatch_ffn_combine` kernel and wires it into the vLLM-Ascend runtime, together with follow‑up fixes to ensure the kernel builds and runs correctly in CI. - Add full host and device implementation of the `dispatch_ffn_combine` kernel under `csrc/dispatch_ffn_combine`, including tiling logic, MOE routing helpers, and kernel utilities for quantized FFN dispatch. - Integrate the new kernel with the PyTorch binding (csrc/torch_binding.cpp, csrc/torch_binding_meta.cpp) and the Ascend runtime (vllm_ascend/ascend_forward_context.py, vllm_ascend/worker/model_runner_v1.py). - Extend fused MoE communication and token dispatch support in `vllm_ascend/ops/fused_moe`, adding methods/utilities needed by the new dispatch path. - Update quantization logic in vllm_ascend/quantization/w8a8_dynamic.py to support the new FFN dispatch flow. - Fix kernel build issues by adjusting `csrc/build_aclnn.sh`, CMake configuration, and include/namespace usage in the new kernel files. - Add an end‑to‑end nightly test `tests/e2e/nightly/ops/test_dispatch_ffn_combine.py` and helper utilities in `vllm_ascend/utils.py` to validate the new kernel. ### Does this PR introduce _any_ user-facing change? ### How was this patch tested? - vLLM version: v0.12.0 - vLLM main: https://github.com/vllm-project/vllm/commit/v0.12.0 --------- Signed-off-by: mojave2 <chenchen145@huawei.com> Co-authored-by: wangxiyuan <wangxiyuan1007@gmail.com>
2025-12-04 23:00:59 +08:00
class FusedAlltoAllCommImpl(MoECommMethod):
    """MoE communication backed by the fused ``dispatch_ffn_combine`` op.

    The token dispatcher and prepare/finalize helper are the same ones that
    ``AlltoAllCommImpl`` builds (``TokenDispatcherWithAll2AllV`` and
    ``PrepareAndFinalizeWithAll2All``). The difference lies in
    ``fused_experts``: instead of the base class's separate dispatch, MLP
    and combine steps, it issues one call to
    ``torch.ops._C_ascend.dispatch_ffn_combine``.
    """

    # Forwarded as ``max_output_size`` to ``dispatch_ffn_combine``.
    # NOTE(review): the unit and exact meaning of this bound are defined by
    # the kernel and are not visible here - confirm before changing it.
    _MAX_OUTPUT_SIZE = 65536

    def _get_token_dispatcher(self):
        """Build the all-to-all-v token dispatcher from the MoE config."""
        return TokenDispatcherWithAll2AllV(
            top_k=self.moe_config.experts_per_token,
            num_experts=self.moe_config.num_experts,
            num_local_experts=self.moe_config.num_local_experts)

    def _get_prepare_finalize(self):
        """Build the all-to-all prepare/finalize helper."""
        return PrepareAndFinalizeWithAll2All(self.moe_config)

    def fused_experts(
            self,
            hidden_states: torch.Tensor,
            w1: torch.Tensor,
            w2: torch.Tensor,
            topk_weights: torch.Tensor,
            topk_ids: torch.Tensor,
            activation: str = "silu",
            apply_router_weight_on_input: bool = False,
            use_int8_w8a8: bool = False,
            use_int4_w4a8: bool = False,
            use_int4_w4a16: bool = False,
            global_num_experts: Optional[int] = None,
            expert_map: Optional[torch.Tensor] = None,
            w1_scale: Optional[torch.Tensor] = None,
            w2_scale: Optional[torch.Tensor] = None,
            w1_scale_bias: Optional[torch.Tensor] = None,
            w2_scale_bias: Optional[torch.Tensor] = None,
            w1_offset: Optional[torch.Tensor] = None,
            w2_offset: Optional[torch.Tensor] = None,
            # For Cube/Vector parallel
            shared_experts: Optional[Any] = None,
            quantized_x_for_share: Optional[Any] = None,
            dynamic_scale_for_share: Optional[Any] = None,
            # For load balance
            log2phy: Optional[torch.Tensor] = None,
            global_redundant_expert_num: int = 0,
            need_trans: bool = False,
            dynamic_eplb: bool = False,
            mc2_mask: Optional[torch.Tensor] = None,
            pertoken_scale: Optional[torch.Tensor] = None):
        """Run the experts through a single ``dispatch_ffn_combine`` call.

        Only ``hidden_states``, ``w1``, ``w2``, ``topk_ids``, ``w1_scale``,
        ``w2_scale`` and ``topk_weights`` are forwarded to the kernel. The
        remaining parameters mirror the base-class signature and are not
        read here.

        Returns:
            A tensor allocated with ``torch.empty_like(hidden_states)`` and
            passed to the kernel as its ``out`` argument. Unlike the base
            class, a single tensor is returned even when ``dynamic_eplb`` is
            true.
        """
        out = torch.empty_like(hidden_states)
        torch.ops._C_ascend.dispatch_ffn_combine(
            x=hidden_states,
            weight1=w1,
            weight2=w2,
            expert_idx=topk_ids,
            scale1=w1_scale,
            scale2=w2_scale,
            # Routing weights are handed to the kernel as float32.
            probs=topk_weights.to(torch.float32),
            group=self.token_dispatcher.moe_all_to_all_group_name,
            max_output_size=self._MAX_OUTPUT_SIZE,
            out=out,
        )
        return out