#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# This file is a part of the vllm-ascend project.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from typing import Callable, Optional
import torch
import torch_npu
from vllm.config import CompilationLevel, get_current_vllm_config
from vllm.distributed import get_dp_group, get_ep_group, get_tp_group
from vllm.forward_context import get_forward_context
from vllm.model_executor.layers.fused_moe.config import \
FusedMoEParallelConfig # isort: skip
from vllm.model_executor.layers.fused_moe.layer import (
FusedMoE, UnquantizedFusedMoEMethod)
from vllm_ascend.ascend_config import get_ascend_config
from vllm_ascend.distributed.parallel_state import get_mc2_group
from vllm_ascend.ops.moe.experts_selector import select_experts
from vllm_ascend.ops.moe.moe_comm_method import (AllGatherCommImpl,
AlltoAllCommImpl, MC2CommImpl)
from vllm_ascend.ops.moe.token_dispatcher import setup_token_dispatchers
from vllm_ascend.utils import ACL_FORMAT_FRACTAL_NZ, is_310p, vllm_version_is
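
# Keep a handle to the original __init__ so the patched version below can delegate to it.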
original_unquantized_fused_moe_init_func = UnquantizedFusedMoEMethod.__init__
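

# MoGE-style fused experts: token/expert assignments are sorted by expert id and the
# gate-up and down projections run as grouped matmuls. Used when the router returns
# fewer than top_k experts per token or when running on Ascend 310P.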
def fused_experts_moge(
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
moe_parallel_config: FusedMoEParallelConfig,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
top_k: int,
global_num_experts: int,
    expert_map: Optional[torch.Tensor] = None,
apply_router_weight_on_input: bool = False,
) -> torch.Tensor:
"""
Args:
hidden_states: Hidden states of shape (num_tokens, hidden_size).
w1: Expert weights1 of shape (num_experts, intermediate_size * 2, hidden_size).
w2: Expert weights2 of shape (num_experts, hidden_size, intermediate_size).
topk_weights: Routing weights of shape (num_tokens, top_k).
topk_ids: Selected expert IDs of shape (num_tokens, top_k).
top_k: Number of experts to select.
expert_map: Expert mapping of shape (num_experts,).
Returns:
hidden_states: Hidden states after routing.
"""
ep_size = moe_parallel_config.ep_size
local_num_experts = global_num_experts // ep_size
local_num_group = top_k // ep_size
bsz, _ = hidden_states.shape
flatten_topk_ids = topk_ids.view(-1)
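    # Sort the flattened (token, expert) assignments by expert id so rows routed to
    # the same expert are contiguous for the grouped matmuls below.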
sorted_topk_ids = torch.argsort(flatten_topk_ids.float())
sorted_topk_ids = sorted_topk_ids.to(torch.int32)
sorted_hidden_states = hidden_states.index_select(
0, sorted_topk_ids // local_num_group)
experts_id = torch.arange(0,
local_num_experts,
dtype=topk_ids.dtype,
device=topk_ids.device)
num_tokens_per_expert = (flatten_topk_ids.unsqueeze(-1) == experts_id).to(
torch.float32).sum(0)
topk_scales = topk_weights.view(-1).index_select(
0, sorted_topk_ids).unsqueeze(-1)
group_list = num_tokens_per_expert.cumsum(dim=0).to(torch.int64)
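    # Cumulative number of rows assigned to each local expert; passed to
    # npu_grouped_matmul as the group boundary list.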
gate_up_out = torch_npu.npu_grouped_matmul(
x=[sorted_hidden_states],
weight=[w1],
split_item=2,
group_list_type=0,
group_type=0,
group_list=group_list,
)[0]
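    # SwiGLU activation; on 310P the activation is computed in float32 and the result
    # cast back to float16.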
if is_310p():
gate_up_out = torch_npu.npu_swiglu(gate_up_out.to(torch.float32)).to(
torch.float16)
else:
gate_up_out = torch_npu.npu_swiglu(gate_up_out)
gate_up_out *= topk_scales
down_out_list = torch_npu.npu_grouped_matmul(
x=[gate_up_out],
weight=[w2],
split_item=2,
group_list_type=0,
group_type=0,
group_list=group_list,
)[0]
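    # Undo the expert-id sort to restore the original token order, then sum the
    # per-expert outputs for each token.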
unsorted_topk_ids = torch.argsort(sorted_topk_ids.float()).to(torch.int32)
unsorted_hidden_states = down_out_list.index_select(0, unsorted_topk_ids)
final_hidden_states = unsorted_hidden_states.reshape(
bsz, top_k // ep_size, -1).sum(1)
return final_hidden_states
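

# Replacement for UnquantizedFusedMoEMethod.__init__: delegates to the original
# initializer, then records whether ACL graph mode (piecewise compilation without
# enforce_eager, and torchair graph disabled) is in use.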
def unquantized_fused_moe_init_func(self, *args, **kwargs):
original_unquantized_fused_moe_init_func(self, *args, **kwargs)
    # NOTE: self.use_aclgraph is currently only read in
    # UnquantizedFusedMoEMethod.forward_oot (see ops/fused_moe.py:568) to work around
    # torch.randint_like not being supported. Once torch.randint_like is supported or
    # that workaround is removed, this flag can be removed.
vllm_config = get_current_vllm_config()
ascend_config = get_ascend_config()
if ascend_config.torchair_graph_config.enabled:
self.use_aclgraph = False
else:
self.use_aclgraph = (vllm_config.compilation_config.level
== CompilationLevel.PIECEWISE
and not vllm_config.model_config.enforce_eager)
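

# forward_oot variant for vLLM 0.10.1 / 0.10.1.1, whose interface does not include
# routed_scaling_factor; expert selection uses a fixed factor of 1.0.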
def forward_oot_v01011(
self,
layer: torch.nn.Module,
x: torch.Tensor,
use_grouped_topk: bool,
top_k: int,
router_logits: torch.Tensor,
renormalize: bool,
topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None,
custom_routing_function: Optional[Callable] = None,
scoring_func: str = "softmax",
e_score_correction_bias: Optional[torch.Tensor] = None,
global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None,
apply_router_weight_on_input: bool = False,
activation: str = "silu",
enable_eplb: bool = False,
expert_load_view: Optional[torch.Tensor] = None,
logical_to_physical_map: Optional[torch.Tensor] = None,
logical_replica_count: Optional[torch.Tensor] = None) -> torch.Tensor:
topk_weights, topk_ids, row_idx = select_experts(
hidden_states=x,
router_logits=router_logits,
top_k=top_k,
use_grouped_topk=use_grouped_topk,
renormalize=renormalize,
topk_group=topk_group,
num_expert_group=num_expert_group,
custom_routing_function=custom_routing_function,
scoring_func=scoring_func,
routed_scaling_factor=1.0,
e_score_correction_bias=e_score_correction_bias,
global_num_experts=global_num_experts)
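    # Same branch structure as forward_oot below: MoGE path for partial top_k or 310P,
    # otherwise the fused_experts of the per-forward-pass communication method.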
if topk_ids.shape[1] < top_k or is_310p():
assert global_num_experts is not None
return fused_experts_moge(
hidden_states=x,
w1=layer.w13_weight,
w2=layer.w2_weight,
moe_parallel_config=self.moe.moe_parallel_config,
topk_weights=topk_weights,
topk_ids=topk_ids,
top_k=top_k,
global_num_experts=global_num_experts,
expert_map=expert_map,
apply_router_weight_on_input=apply_router_weight_on_input)
moe_comm_method = get_forward_context().moe_comm_method
return moe_comm_method.fused_experts(hidden_states=x,
w1=layer.w13_weight,
w2=layer.w2_weight,
topk_weights=topk_weights,
topk_ids=topk_ids,
row_idx=row_idx,
global_num_experts=global_num_experts,
expert_map=expert_map)
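

# forward_oot for current vLLM versions, whose interface passes routed_scaling_factor
# through to expert selection.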
def forward_oot(
self,
layer: torch.nn.Module,
x: torch.Tensor,
use_grouped_topk: bool,
top_k: int,
router_logits: torch.Tensor,
renormalize: bool,
topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None,
custom_routing_function: Optional[Callable] = None,
scoring_func: str = "softmax",
routed_scaling_factor: float = 1.0,
e_score_correction_bias: Optional[torch.Tensor] = None,
global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None,
apply_router_weight_on_input: bool = False,
activation: str = "silu",
enable_eplb: bool = False,
expert_load_view: Optional[torch.Tensor] = None,
logical_to_physical_map: Optional[torch.Tensor] = None,
logical_replica_count: Optional[torch.Tensor] = None) -> torch.Tensor:
topk_weights, topk_ids, row_idx = select_experts(
hidden_states=x,
router_logits=router_logits,
top_k=top_k,
use_grouped_topk=use_grouped_topk,
renormalize=renormalize,
topk_group=topk_group,
num_expert_group=num_expert_group,
custom_routing_function=custom_routing_function,
scoring_func=scoring_func,
routed_scaling_factor=routed_scaling_factor,
e_score_correction_bias=e_score_correction_bias,
global_num_experts=global_num_experts)
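    # Use the grouped-matmul MoGE path when fewer than top_k experts were selected per
    # token or on 310P; otherwise run the unified fused_experts of the communication
    # method chosen for this forward pass.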
if topk_ids.shape[1] < top_k or is_310p():
assert global_num_experts is not None
return fused_experts_moge(
hidden_states=x,
w1=layer.w13_weight,
w2=layer.w2_weight,
moe_parallel_config=self.moe.moe_parallel_config,
topk_weights=topk_weights,
topk_ids=topk_ids,
top_k=top_k,
global_num_experts=global_num_experts,
expert_map=expert_map,
apply_router_weight_on_input=apply_router_weight_on_input)
moe_comm_method = get_forward_context().moe_comm_method
return moe_comm_method.fused_experts(hidden_states=x,
w1=layer.w13_weight,
w2=layer.w2_weight,
topk_weights=topk_weights,
topk_ids=topk_ids,
row_idx=row_idx,
global_num_experts=global_num_experts,
expert_map=expert_map)
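

# Post-load weight processing for the unquantized path: pad if required, pre-transpose
# w13/w2 so quantized and unquantized models share the same grouped-matmul layout, and
# cast to FRACTAL_NZ format on SoCs other than 310P.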
def process_weights_after_loading(self, layer):
super(UnquantizedFusedMoEMethod, self).process_weights_after_loading(layer)
w13_data = self._maybe_pad_weight(layer.w13_weight.data).transpose(
1, 2).contiguous()
layer.w13_weight = torch.nn.Parameter(w13_data, requires_grad=False)
w2_data = self._maybe_pad_weight(layer.w2_weight.data).transpose(
1, 2).contiguous()
layer.w2_weight = torch.nn.Parameter(w2_data, requires_grad=False)
if not is_310p():
layer.w13_weight.data = torch_npu.npu_format_cast(
layer.w13_weight.data, ACL_FORMAT_FRACTAL_NZ)
layer.w2_weight.data = torch_npu.npu_format_cast(
layer.w2_weight.data, ACL_FORMAT_FRACTAL_NZ)
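

# FusedMoE subclass that wires the Ascend communication implementations (all-gather,
# all-to-all, MC2) into the layer; the one to use is chosen per forward pass via the
# forward context.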
class AscendFusedMoE(FusedMoE):
def __init__(
self,
num_experts,
top_k,
hidden_size,
intermediate_size,
params_dtype=None,
reduce_results=False,
renormalize=True,
use_grouped_topk=False,
num_expert_group=None,
topk_group=None,
quant_config=None,
tp_size=None,
ep_size=None,
dp_size=None,
prefix="",
custom_routing_function=None,
scoring_func="softmax",
        routed_scaling_factor: float = 1.0,
e_score_correction_bias=None,
apply_router_weight_on_input=False,
activation="silu",
enable_eplb=False,
num_redundant_experts=0,
has_bias=False,
):
if vllm_version_is("0.10.1.1") or vllm_version_is("0.10.1"):
super().__init__(
num_experts,
top_k,
hidden_size,
intermediate_size,
params_dtype,
reduce_results,
renormalize,
use_grouped_topk,
num_expert_group,
topk_group,
quant_config,
tp_size,
ep_size,
dp_size,
prefix,
custom_routing_function,
scoring_func,
e_score_correction_bias,
apply_router_weight_on_input,
activation,
enable_eplb,
num_redundant_experts,
has_bias,
)
else:
super().__init__(
num_experts,
top_k,
hidden_size,
intermediate_size,
params_dtype,
reduce_results,
renormalize,
use_grouped_topk,
num_expert_group,
topk_group,
quant_config,
tp_size,
ep_size,
dp_size,
prefix,
custom_routing_function,
scoring_func,
                routed_scaling_factor,
e_score_correction_bias,
apply_router_weight_on_input,
activation,
enable_eplb,
num_redundant_experts,
has_bias,
)
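
        # Set up token dispatchers for expert parallelism and record the TP/DP/EP/MC2
        # process groups on the MoE config for the communication backends.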
setup_token_dispatchers(self.moe_config.ep_size,
top_k=self.top_k,
num_experts=self.global_num_experts,
num_local_experts=self.local_num_experts)
self.moe_config.tp_group = get_tp_group()
self.moe_config.dp_group = get_dp_group()
self.moe_config.ep_group = get_ep_group()
self.moe_config.mc2_group = get_mc2_group()
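
        # Instantiate every communication implementation up front; forward_impl looks
        # one up by name from the forward context at runtime.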
for method in {AllGatherCommImpl, AlltoAllCommImpl, MC2CommImpl}:
setattr(
self, method.__name__.lower(),
method(moe_config=self.moe_config)) # type: ignore[abstract]
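
    # Forward pass: the communication method selected by the model runner prepares the
    # inputs, the quant method runs the expert computation, and finalize combines and
    # optionally reduces the outputs.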
def forward_impl(self, hidden_states: torch.Tensor,
router_logits: torch.Tensor):
assert self.quant_method is not None
forward_context = get_forward_context()
moe_comm_method_name = forward_context.moe_comm_method_name
forward_context.moe_comm_method = getattr(self, moe_comm_method_name)
hidden_states, router_logits = forward_context.moe_comm_method.prepare(
hidden_states=hidden_states, router_logits=router_logits)
# Matrix multiply.
final_hidden_states = self.quant_method.apply(
layer=self,
x=hidden_states,
router_logits=router_logits,
top_k=self.top_k,
renormalize=self.renormalize,
use_grouped_topk=self.use_grouped_topk,
global_num_experts=self.global_num_experts,
expert_map=self.expert_map,
topk_group=self.topk_group,
num_expert_group=self.num_expert_group,
custom_routing_function=self.custom_routing_function,
scoring_func=self.scoring_func,
e_score_correction_bias=self.e_score_correction_bias,
activation=self.activation,
apply_router_weight_on_input=self.apply_router_weight_on_input,
enable_eplb=self.enable_eplb,
expert_load_view=self.expert_load_view,
logical_to_physical_map=self.logical_to_physical_map,
logical_replica_count=self.logical_replica_count,
)
final_hidden_states = forward_context.moe_comm_method.finalize(
hidden_states=final_hidden_states,
reduce_results=self.reduce_results)
return final_hidden_states
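

# Patch vLLM's UnquantizedFusedMoEMethod in place so MoE layers pick up the Ascend
# implementations above; the forward_oot variant is selected by vLLM version.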
UnquantizedFusedMoEMethod.__init__ = unquantized_fused_moe_init_func
UnquantizedFusedMoEMethod.process_weights_after_loading = process_weights_after_loading
if vllm_version_is("0.10.1.1") or vllm_version_is("0.10.1"):
UnquantizedFusedMoEMethod.forward_oot = forward_oot_v01011
else:
UnquantizedFusedMoEMethod.forward_oot = forward_oot