From 85e33941e848e5a273b243fad0eb685fbfad2a02 Mon Sep 17 00:00:00 2001
From: pu-zhe
Date: Fri, 6 Feb 2026 10:30:56 +0800
Subject: [PATCH] [Feat.]: 310p support MOE models (#6530)

### What this PR does / why we need it?
This pull request adds support for Mixture of Experts (MoE) models on the Ascend 310P device within the vllm-ascend framework. It introduces dedicated modules for expert selection, fused MoE layers, and all-gather communication, and it cleans up existing NPU operations so they run consistently and efficiently on 310P, improving the performance and compatibility of MoE models on this hardware.

Highlights:
- 310P MoE Support: Introduces dedicated implementations for Mixture of Experts (MoE) models on Ascend 310P devices, including new modules for expert selection, fused MoE layers, and communication.
- All-Gather Communication: Enforces the use of ALLGATHER communication for MoE operations on 310P, optimizing data transfer and leveraging NPU-specific token dispatching.
- Simplified NPU Operations: Removes the conditional type casting around npu_swiglu and enables the custom rotary embedding kernel unconditionally, reflecting improved native support on 310P.
- New MoE Classes Registered: Registers AscendFusedMoE310 and AscendSharedFusedMoE310 so that the 310P-specific MoE layers are picked up by the custom operation registry.

### Does this PR introduce _any_ user-facing change?
No

### How was this patch tested?
Offline test and server test with qwen3-30b-a3b, tp/ep 4, on 310P.

- vLLM version: v0.15.0
- vLLM main: https://github.com/vllm-project/vllm/commit/v0.15.0

---------

Signed-off-by: pu-zhe
---
 vllm_ascend/_310p/fused_moe/__init__.py       |   0
 .../_310p/fused_moe/experts_selector.py       |  75 +++++
 vllm_ascend/_310p/fused_moe/fused_moe.py      | 300 ++++++++++++++++++
 .../_310p/fused_moe/moe_comm_method.py        |  39 +++
 .../_310p/fused_moe/token_dispatcher.py       | 126 ++++++++
 vllm_ascend/ascend_forward_context.py         |   4 +-
 vllm_ascend/ops/fused_moe/fused_moe.py        |   5 +-
 vllm_ascend/ops/fused_moe/moe_mlp.py          |   6 +-
 vllm_ascend/ops/rotary_embedding.py           |   3 +-
 vllm_ascend/utils.py                          |   3 +
 10 files changed, 550 insertions(+), 11 deletions(-)
 create mode 100644 vllm_ascend/_310p/fused_moe/__init__.py
 create mode 100644 vllm_ascend/_310p/fused_moe/experts_selector.py
 create mode 100644 vllm_ascend/_310p/fused_moe/fused_moe.py
 create mode 100644 vllm_ascend/_310p/fused_moe/moe_comm_method.py
 create mode 100644 vllm_ascend/_310p/fused_moe/token_dispatcher.py

diff --git a/vllm_ascend/_310p/fused_moe/__init__.py b/vllm_ascend/_310p/fused_moe/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/vllm_ascend/_310p/fused_moe/experts_selector.py b/vllm_ascend/_310p/fused_moe/experts_selector.py
new file mode 100644
index 00000000..71200c99
--- /dev/null
+++ b/vllm_ascend/_310p/fused_moe/experts_selector.py
@@ -0,0 +1,75 @@
+#
+# Copyright (c) 2026 Huawei Technologies Co., Ltd. All Rights Reserved.
+# This file is a part of the vllm-ascend project.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and +# limitations under the License. +# +from collections.abc import Callable + +import torch + +from vllm_ascend.ops.fused_moe.experts_selector import _native_select_experts +from vllm_ascend.utils import get_weight_prefetch_method + + +def select_experts( + hidden_states: torch.Tensor, + router_logits: torch.Tensor, + top_k: int, + use_grouped_topk: bool, + renormalize: bool, + topk_group: int | None = None, + num_expert_group: int | None = None, + custom_routing_function: Callable | None = None, + scoring_func: str = "softmax", + e_score_correction_bias: torch.Tensor | None = None, + global_num_experts: int = -1, +): + """ + Fused experts with select experts. + + Args: + router_logits: router logits of shape (num_tokens, hidden_size). + hidden_states: Hidden states of shape (num_tokens, hidden_size). + top_k: number of top k experts. + use_grouped_topk: Whether to group experts before selecting top-k. + renormalize: Whether to renormalize the routing weights. + topk_group: Number of expert groups to select from. + num_expert_group: Number of experts in each group. + custom_routing_function: Custom routing function. + scoring_func: Scoring function to use. + e_score_correction_bias: Correction bias to apply to expert scores. + global_num_experts: Global number of experts. + + Returns: + topk_weights: router weights of shape (num_tokens, top_k). + topk_ids: selected expert IDs of shape (num_tokens, top_k). + """ + # prefetch w1_w3_proj.weight preprocess + weight_prefetch_method = get_weight_prefetch_method() + if weight_prefetch_method: + weight_prefetch_method.maybe_prefetch_moe_weight_preprocess(hidden_states, "gate_up") + topk_weights, topk_ids = _native_select_experts( + hidden_states=hidden_states, + router_logits=router_logits, + top_k=top_k, + use_grouped_topk=use_grouped_topk, + renormalize=renormalize, + topk_group=topk_group, + num_expert_group=num_expert_group, + custom_routing_function=custom_routing_function, + scoring_func=scoring_func, + e_score_correction_bias=e_score_correction_bias, + global_num_experts=global_num_experts, + ) + return topk_weights, topk_ids diff --git a/vllm_ascend/_310p/fused_moe/fused_moe.py b/vllm_ascend/_310p/fused_moe/fused_moe.py new file mode 100644 index 00000000..5cca5036 --- /dev/null +++ b/vllm_ascend/_310p/fused_moe/fused_moe.py @@ -0,0 +1,300 @@ +# +# Copyright (c) 2026 Huawei Technologies Co., Ltd. All Rights Reserved. +# This file is a part of the vllm-ascend project. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +from collections.abc import Callable + +import torch +from vllm.distributed import get_dp_group, get_ep_group, get_tp_group +from vllm.forward_context import get_forward_context +from vllm.model_executor.layers.fused_moe.config import FusedMoEConfig +from vllm.model_executor.layers.fused_moe.layer import FusedMoE, UnquantizedFusedMoEMethod +from vllm.model_executor.layers.fused_moe.shared_fused_moe import SharedFusedMoE + +from vllm_ascend.ascend_forward_context import MoECommType +from vllm_ascend.ops.fused_moe.experts_selector import zero_experts_compute +from vllm_ascend.ops.fused_moe.moe_comm_method import FusedExpertsResult, _MoECommMethods +from vllm_ascend.quantization.methods.base import QuantType + +from .experts_selector import select_experts +from .moe_comm_method import AllGatherCommImpl310 + + +class AscendUnquantizedFusedMoEMethod310(UnquantizedFusedMoEMethod): + def __init__(self, moe: FusedMoEConfig = None): + super().__init__(moe=moe) + + def process_weights_after_loading(self, layer): + super().process_weights_after_loading(layer) + + # Fused gate_up_proj (column parallel) + w13_data = self._maybe_pad_weight(layer.w13_weight.data).transpose(1, 2).contiguous() + layer.w13_weight = torch.nn.Parameter(w13_data, requires_grad=False) + # down_proj (row parallel) + w2_data = self._maybe_pad_weight(layer.w2_weight.data).transpose(1, 2).contiguous() + layer.w2_weight = torch.nn.Parameter(w2_data, requires_grad=False) + + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + use_grouped_topk: bool, + top_k: int, + router_logits: torch.Tensor, + renormalize: bool, + topk_group: int | None = None, + num_expert_group: int | None = None, + custom_routing_function: Callable | None = None, + scoring_func: str = "softmax", + routed_scaling_factor: float = 1.0, + e_score_correction_bias: torch.Tensor | None = None, + global_num_experts: int = -1, + expert_map: torch.Tensor | None = None, + apply_router_weight_on_input: bool = False, + **kwargs, + ) -> torch.Tensor: + zero_expert_num = getattr(layer, "zero_expert_num", 0) + zero_expert_type = getattr(layer, "zero_expert_type", None) + assert routed_scaling_factor == 1.0 + + topk_weights, topk_ids = select_experts( + hidden_states=x, + router_logits=router_logits, + top_k=top_k, + use_grouped_topk=use_grouped_topk, + renormalize=renormalize, + topk_group=topk_group, + num_expert_group=num_expert_group, + custom_routing_function=custom_routing_function, + scoring_func=scoring_func, + e_score_correction_bias=e_score_correction_bias, + global_num_experts=global_num_experts, + ) + + if zero_expert_num > 0 and zero_expert_type is not None: + topk_ids, topk_weights, zero_expert_result = zero_experts_compute( + expert_indices=topk_ids, + expert_scales=topk_weights, + num_experts=global_num_experts, + zero_expert_type=zero_expert_type, + hidden_states=x, + ) + + topk_weights = topk_weights.to(x.dtype) + + moe_comm_method = get_forward_context().moe_comm_method + final_hidden_states = moe_comm_method.fused_experts( + hidden_states=x, + w1=layer.w13_weight, + w2=layer.w2_weight, + topk_weights=topk_weights, + topk_ids=topk_ids, + expert_map=expert_map, + apply_router_weight_on_input=apply_router_weight_on_input, + ) + if zero_expert_num > 0 and zero_expert_type is not None: + final_hidden_states += zero_expert_result + return final_hidden_states + + +class AscendFusedMoE310(FusedMoE): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + self.global_num_experts = kwargs["num_experts"] + + if self.quant_config 
is None: + self.quant_method = AscendUnquantizedFusedMoEMethod310(self.moe_config) + else: + self.quant_method = self.quant_config.get_quant_method(self, self.layer_name) + + assert self.quant_method is not None + + self.moe_config.tp_group = get_tp_group() + self.moe_config.dp_group = get_dp_group() + self.moe_config.ep_group = get_ep_group() + self.moe_config.supports_eplb = False + + # init moe + self.global_expert_map = None + self.local_expert_map = None + if self.moe_config.ep_size > 1: + self.global_expert_map, self.local_expert_map = self.init_experts_map(self.moe_config) + self.local_num_experts = ( + torch.sum(self.local_expert_map != -1).item() + if self.local_expert_map is not None + else self.global_num_experts + ) + + self.moe_config.num_experts = self.global_num_experts + self.moe_config.num_local_experts = self.local_num_experts + self.moe_config.global_redundant_expert_num = 0 + + moe_quant_params = { + "num_experts": self.local_num_experts, + "hidden_size": self.hidden_size, + "intermediate_size_per_partition": self.intermediate_size_per_partition, + "params_dtype": self.params_dtype, + "weight_loader": self.weight_loader, + } + + self.quant_method.create_weights(layer=self, **moe_quant_params) + self.quant_type = self.get_quant_type() + + _MoECommMethods[MoECommType.ALLGATHER] = AllGatherCommImpl310(self.moe_config) + + def init_experts_map(self, moe_config): + """ + Initialize expert mapping for MoE (Mixture of Experts) model. + + This function creates mappings between global expert indices and local expert indices + for each rank in the expert parallel group. It divides the total experts among + different ranks and creates both global and local expert maps that are used + during MoE computation to determine which experts are handled by which rank. + + Args: + moe_config: Configuration object containing MoE parameters including + number of experts, expert parallel size, and expert parallel rank. + + Returns: + tuple: A tuple containing: + - global_expert_map: Stack of expert maps for all ranks + - local_expert_map: Expert map for the current rank (transferred to NPU) + """ + n_experts = moe_config.num_experts + ep_size = moe_config.ep_size + all_experts = torch.arange(n_experts, dtype=torch.int32) + experts_groups = all_experts.chunk(ep_size) + global_expert_map = [] + local_expert_map = None + for rankid in range(ep_size): + expert_map = torch.full((n_experts,), -1, dtype=torch.int32) + local_experts = experts_groups[rankid] + expert_map[local_experts] = torch.arange(local_experts.shape[0], dtype=torch.int32) + global_expert_map.append(expert_map) + if rankid == moe_config.ep_rank: + local_expert_map = expert_map.npu() + return torch.stack(global_expert_map), local_expert_map + + def get_quant_type(self) -> QuantType: + quant_method = self.quant_method + if not hasattr(quant_method, "quant_method") or quant_method.quant_method is None: + return QuantType.NONE + + method = quant_method.quant_method + quant_type = getattr(method, "quant_type", QuantType.NONE) + if quant_type != QuantType.NONE: + # TODO: w8a8 quantization will be supported soon, and only reject w4a8 here. 
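+            # In practice any non-NONE quant type reaches this branch today, since
+            # quantized MoE (w8a8 included) has not yet landed for 310P.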
+ raise RuntimeError("W8A8 is not supported currently.") + return QuantType.NONE + + def forward_impl( # type: ignore[override] + self, hidden_states: torch.Tensor, router_logits: torch.Tensor + ) -> torch.Tensor: + assert self.quant_method is not None + forward_context = get_forward_context() + + hidden_states, router_logits, _, context_metadata = forward_context.moe_comm_method.prepare( + hidden_states=hidden_states, router_logits=router_logits, quant_type=self.quant_type + ) + + if isinstance(hidden_states, tuple): + hidden_states, pertoken_scale = hidden_states + else: + pertoken_scale = None + + # Matrix multiply. + fused_experts_results: FusedExpertsResult = self.quant_method.apply( + layer=self, + x=hidden_states, + router_logits=router_logits, + pertoken_scale=pertoken_scale, + top_k=self.top_k, + renormalize=self.renormalize, + use_grouped_topk=self.use_grouped_topk, + global_num_experts=self.global_num_experts, + expert_map=self.local_expert_map, + topk_group=self.topk_group, + num_expert_group=self.num_expert_group, + custom_routing_function=self.custom_routing_function, + scoring_func=self.scoring_func, + routed_scaling_factor=self.routed_scaling_factor, + e_score_correction_bias=self.e_score_correction_bias, + activation=self.activation, + apply_router_weight_on_input=self.apply_router_weight_on_input, + ) + + routed_out = forward_context.moe_comm_method.finalize( + hidden_states=fused_experts_results.routed_out, + reduce_results=self.reduce_results, + context_metadata=context_metadata, + ) + + return routed_out + + +class AscendSharedFusedMoE310(SharedFusedMoE, AscendFusedMoE310): + def __init__( + self, + shared_experts: torch.nn.Module, + gate: torch.nn.Module | None = None, + use_overlapped: bool = True, + **kwargs, + ): + AscendFusedMoE310.__init__(self, **kwargs) + self._shared_experts = shared_experts + self.use_overlapped = use_overlapped + self.shared_expert_stream = None + self._gate = gate + + def forward( + self, + hidden_states: torch.Tensor, + router_logits: torch.Tensor, + ) -> tuple[torch.Tensor, torch.Tensor]: + if self._shared_experts is None: + fused_out = AscendFusedMoE310.forward( + self, + hidden_states=hidden_states, + router_logits=router_logits, + ) + shared_out = None + return shared_out, fused_out + shared_out, fused_out = AscendFusedMoE310.forward( + self, + hidden_states=hidden_states, + router_logits=router_logits, + ) + return shared_out, fused_out + + def _forward_shared_experts(self, hidden_states: torch.Tensor): + if self._shared_experts is None: + return None + part1_out = self._shared_experts_part1(hidden_states) + shared_out = self._shared_experts_part2(hidden_states, part1_out) + return shared_out + + def forward_impl( # type: ignore[override] + self, hidden_states: torch.Tensor, router_logits: torch.Tensor + ): + routed_out = AscendFusedMoE310.forward_impl( + self, + hidden_states=hidden_states, + router_logits=router_logits, + ) + if self._shared_experts is None: + return routed_out + shared_out = self._forward_shared_experts(hidden_states) + return shared_out, routed_out diff --git a/vllm_ascend/_310p/fused_moe/moe_comm_method.py b/vllm_ascend/_310p/fused_moe/moe_comm_method.py new file mode 100644 index 00000000..36fadf27 --- /dev/null +++ b/vllm_ascend/_310p/fused_moe/moe_comm_method.py @@ -0,0 +1,39 @@ +# Copyright (c) 2026 Huawei Technologies Co., Ltd. All Rights Reserved. +# Copyright 2023 The vLLM team. 
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# This file is a part of the vllm-ascend project.
+from __future__ import annotations
+
+from vllm_ascend.ops.fused_moe.moe_comm_method import AllGatherCommImpl
+
+from .token_dispatcher import TokenDispatcherWithAllGather310
+
+
+class AllGatherCommImpl310(AllGatherCommImpl):
+    """All-gather based MoE communication for Ascend 310P.
+
+    This variant keeps the prepare/finalize flow of `AllGatherCommImpl` but
+    swaps in `TokenDispatcherWithAllGather310`, which performs the
+    token-to-expert routing (permutation and per-expert token counts) with
+    native PyTorch ops rather than `torch_npu.npu_moe_init_routing_v2`.
+    It is the communication method used for MoE layers on 310P, where
+    `select_moe_comm_method` always chooses ALLGATHER.
+    """
+
+    def _get_token_dispatcher(self):
+        return TokenDispatcherWithAllGather310(
+            top_k=self.moe_config.experts_per_token,
+            num_experts=self.moe_config.num_experts,
+            num_local_experts=self.moe_config.num_local_experts,
+        )
diff --git a/vllm_ascend/_310p/fused_moe/token_dispatcher.py b/vllm_ascend/_310p/fused_moe/token_dispatcher.py
new file mode 100644
index 00000000..00c611cf
--- /dev/null
+++ b/vllm_ascend/_310p/fused_moe/token_dispatcher.py
@@ -0,0 +1,126 @@
+# SPDX-License-Identifier: Apache-2.0
+# Copyright (c) 2024; NVIDIA CORPORATION. All rights reserved.
+# Copyright (c) 2026 Huawei Technologies Co., Ltd. All Rights Reserved.
+# Copyright 2023 The vLLM team.
+# Copyright 2023 DeepSeek-AI and the HuggingFace Inc. team. All rights reserved.
+#
+# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
+# and OPT implementations in this library. It has been modified from its
+# original forms to accommodate minor architectural differences compared
+# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+ + +import torch +from vllm.distributed.parallel_state import get_ep_group + +from vllm_ascend.ops.fused_moe.token_dispatcher import TokenDispatcherWithAllGather, TokenDispatchResult + + +class TokenDispatcherWithAllGather310(TokenDispatcherWithAllGather): + def __init__(self, **kwargs): + super().__init__(**kwargs) + + def token_dispatch( + self, + hidden_states: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + expert_map: torch.Tensor | None = None, + global_redundant_expert_num: int = 0, + mc2_mask: torch.Tensor | None = None, + apply_router_weight_on_input: bool = False, + with_quant: bool = False, + dynamic_eplb: bool = False, + pertoken_scale: torch.Tensor | None = None, + ): + if with_quant: + raise RuntimeError("Quant is not supported for 310P currently.") + self.original_shape = hidden_states.shape + + num_tokens = hidden_states.shape[:-1].numel() + self.apply_router_weight_on_input = apply_router_weight_on_input + if self.apply_router_weight_on_input: + assert topk_weights.dim() == 2, "`topk_weights` should be in shape (num_tokens, topk)" + _, topk = topk_weights.shape + assert topk == 1, "Only support topk=1 when `apply_router_weight_on_input` is True" + hidden_states = hidden_states * topk_weights.to(hidden_states.dtype) + if expert_map is not None: + mask = expert_map[topk_ids] != -1 + topk_weights = topk_weights * mask + first_expert_idx = get_ep_group().rank_in_group * self.num_experts_local + last_expert_idx = first_expert_idx + self.num_experts_local + else: + first_expert_idx = 0 + last_expert_idx = self.num_experts_local + + sorted_hidden_states, expanded_row_idx, expert_tokens = self.moe_init_routing( + hidden_states, + topk_ids, + active_num=num_tokens * self.top_k, + active_expert_range=[first_expert_idx, last_expert_idx], + ) + expert_tokens = expert_tokens.to(torch.int64) + group_list_type = 1 # `count` mode + context_metadata = {"topk_weights": topk_weights, "expanded_row_idx": expanded_row_idx} + + return TokenDispatchResult( + hidden_states=sorted_hidden_states, + dynamic_scale=None, + group_list=expert_tokens, + group_list_type=group_list_type, + context_metadata=context_metadata, + ) + + def moe_init_routing(self, x, expert_idx, active_num, active_expert_range): + """ + Initialize routing for Mixture of Experts (MoE) model by organizing tokens + according to their assigned experts and preparing data structures for + efficient expert computation. 
+ + Args: + x (torch.Tensor): Input tensor containing token representations + expert_idx (torch.Tensor): Tensor containing expert indices for each token + active_num (int): Number of active experts or None + active_expert_range (tuple): Range (start, end) of active experts + + Returns: + tuple: A tuple containing: + - expanded_x: Subset of input tensor for active experts + - expanded_row_idx: Mapping indices for token positions + - expert_tokens_count: Count of tokens assigned to each expert + """ + MAX_INT32 = torch.iinfo(torch.int32).max + expert_start, expert_end = active_expert_range + num_rows = x.shape[0] + k = expert_idx.shape[-1] + expert_idx_flat = expert_idx.flatten() + mask = (expert_idx_flat >= expert_start) & (expert_idx_flat < expert_end) + actual_expert_total_num = mask.sum().item() + expert_idx_flat = torch.where( + ~mask, torch.full_like(expert_idx_flat, MAX_INT32, dtype=torch.int32), expert_idx_flat + ) + sorted_idx = torch.argsort(expert_idx_flat, stable=True) + sorted_expert_idx = expert_idx_flat[sorted_idx] + expanded_row_idx = torch.full((num_rows * k,), -1, dtype=torch.int32, device=expert_idx.device) + expanded_row_idx[sorted_idx[:actual_expert_total_num]] = torch.arange( + actual_expert_total_num, dtype=torch.int32, device=expert_idx.device + ) + expert_tokens_count = torch.bincount( + sorted_expert_idx[:actual_expert_total_num] - expert_start, minlength=expert_end - expert_start + ) + active_num = min(active_num or actual_expert_total_num, actual_expert_total_num) + expanded_x = x[sorted_idx[:active_num] // k] + + return expanded_x, expanded_row_idx, expert_tokens_count diff --git a/vllm_ascend/ascend_forward_context.py b/vllm_ascend/ascend_forward_context.py index 7f56165e..8e14b4d2 100644 --- a/vllm_ascend/ascend_forward_context.py +++ b/vllm_ascend/ascend_forward_context.py @@ -208,6 +208,7 @@ def select_moe_comm_method(num_tokens: int, vllm_config: VllmConfig, is_draft_mo 4. On A3 with expert parallel, prefer fused MC2 when using w8a8_dynamic quantization with small EP size, no dynamic_eplb, and not in MTP mode; otherwise use MC2 within capacity or all-to-all. + 5. On 310P, always use all-gather. Args: num_tokens (int): The number of tokens in the current batch. 
@@ -262,7 +263,8 @@ def select_moe_comm_method(num_tokens: int, vllm_config: VllmConfig, is_draft_mo elif envs_ascend.VLLM_ASCEND_ENABLE_FUSED_MC2 == 2: fused_prefill_enable = False moe_comm_type = MoECommType.FUSED_MC2 if fused_prefill_enable else MoECommType.ALLTOALL - + elif soc_version in {AscendDeviceType._310P}: + moe_comm_type = MoECommType.ALLGATHER else: raise ValueError(f"Unsupported soc_version: {soc_version}") return moe_comm_type diff --git a/vllm_ascend/ops/fused_moe/fused_moe.py b/vllm_ascend/ops/fused_moe/fused_moe.py index 86be6d99..f67cc1a3 100644 --- a/vllm_ascend/ops/fused_moe/fused_moe.py +++ b/vllm_ascend/ops/fused_moe/fused_moe.py @@ -82,9 +82,8 @@ class AscendUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod): 1, 2).contiguous() layer.w2_weight = torch.nn.Parameter(w2_data, requires_grad=False) - if get_ascend_device_type() != AscendDeviceType._310P: - layer.w13_weight.data = maybe_trans_nz(layer.w13_weight.data) - layer.w2_weight.data = maybe_trans_nz(layer.w2_weight.data) + layer.w13_weight.data = maybe_trans_nz(layer.w13_weight.data) + layer.w2_weight.data = maybe_trans_nz(layer.w2_weight.data) def apply(self, layer: torch.nn.Module, diff --git a/vllm_ascend/ops/fused_moe/moe_mlp.py b/vllm_ascend/ops/fused_moe/moe_mlp.py index d102a1d5..e29945ea 100644 --- a/vllm_ascend/ops/fused_moe/moe_mlp.py +++ b/vllm_ascend/ops/fused_moe/moe_mlp.py @@ -291,11 +291,7 @@ def unquant_apply_mlp(hidden_states: torch.Tensor, group_type=0, group_list=group_list, )[0] - if get_ascend_device_type() == AscendDeviceType._310P: - gate_up_out = torch_npu.npu_swiglu(gate_up_out.to(torch.float32)).to( - torch.float16) - else: - gate_up_out = torch_npu.npu_swiglu(gate_up_out) + gate_up_out = torch_npu.npu_swiglu(gate_up_out) if topk_scales is not None: gate_up_out *= topk_scales diff --git a/vllm_ascend/ops/rotary_embedding.py b/vllm_ascend/ops/rotary_embedding.py index b4da71f3..31f1a8da 100644 --- a/vllm_ascend/ops/rotary_embedding.py +++ b/vllm_ascend/ops/rotary_embedding.py @@ -190,8 +190,7 @@ def _rope_forward_oot( cos, sin = get_cos_and_sin_slice() # adopt custom kernel path for rotary_embedding if _custom_rotary_embedding_enabled( - query, is_neox_style, self.head_size) and get_ascend_device_type( - ) != AscendDeviceType._310P: + query, is_neox_style, self.head_size): query, key = torch.ops._C_ascend.rotary_embedding( positions, query, diff --git a/vllm_ascend/utils.py b/vllm_ascend/utils.py index 9aadfb66..2f150160 100644 --- a/vllm_ascend/utils.py +++ b/vllm_ascend/utils.py @@ -625,6 +625,7 @@ def register_ascend_customop(vllm_config: VllmConfig | None = None): # 310P: override selected ops with 310P implementations (keep minimal changes outside _310p) if is_310p(): + from vllm_ascend._310p.fused_moe.fused_moe import AscendFusedMoE310, AscendSharedFusedMoE310 from vllm_ascend._310p.ops.activation import AscendSiluAndMul310 from vllm_ascend._310p.ops.layernorm import AscendGemmaRMSNorm310, AscendRMSNorm310 from vllm_ascend._310p.ops.mm_encoder_attention import AscendMMEncoderAttention310 @@ -637,6 +638,8 @@ def register_ascend_customop(vllm_config: VllmConfig | None = None): "RotaryEmbedding": AscendRotaryEmbedding310, "RMSNorm": AscendRMSNorm310, "GemmaRMSNorm": AscendGemmaRMSNorm310, + "FusedMoE": AscendFusedMoE310, + "SharedFusedMoE": AscendSharedFusedMoE310, } )
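
For reference, the expert-map layout built by `init_experts_map` in `AscendFusedMoE310` can be reproduced in isolation. The snippet below is a minimal CPU-only sketch with toy numbers (8 experts, `ep_size=4`); it mirrors the logic in the patch but is illustrative only and not part of the change.

```python
import torch

# Toy configuration: 8 global experts split evenly across 4 EP ranks.
n_experts, ep_size = 8, 4
all_experts = torch.arange(n_experts, dtype=torch.int32)
experts_groups = all_experts.chunk(ep_size)  # rank r owns experts_groups[r]

global_expert_map = []
for rankid in range(ep_size):
    expert_map = torch.full((n_experts,), -1, dtype=torch.int32)
    local_experts = experts_groups[rankid]
    # global expert id -> local slot on this rank, -1 for experts owned elsewhere
    expert_map[local_experts] = torch.arange(local_experts.shape[0], dtype=torch.int32)
    global_expert_map.append(expert_map)

print(torch.stack(global_expert_map))
# tensor([[ 0,  1, -1, -1, -1, -1, -1, -1],
#         [-1, -1,  0,  1, -1, -1, -1, -1],
#         [-1, -1, -1, -1,  0,  1, -1, -1],
#         [-1, -1, -1, -1, -1, -1,  0,  1]], dtype=torch.int32)
```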
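The native routing step in `TokenDispatcherWithAllGather310.moe_init_routing` replaces `npu_moe_init_routing_v2` with plain tensor ops: token/expert pairs outside the rank's expert range are pushed to the end with an `INT32_MAX` sentinel, the remaining pairs are stably sorted by expert id, and a `bincount` produces the per-expert token counts returned as `group_list`. Below is a minimal CPU sketch with made-up shapes (4 tokens, top-2 routing, a rank owning experts 0-1); it only illustrates the routing math.

```python
import torch

MAX_INT32 = torch.iinfo(torch.int32).max

# Made-up inputs: 4 tokens, hidden size 8, top-2 routing; this rank owns experts [0, 2).
x = torch.randn(4, 8)
expert_idx = torch.tensor([[0, 2], [1, 3], [2, 0], [3, 1]], dtype=torch.int32)
expert_start, expert_end = 0, 2

k = expert_idx.shape[-1]
flat = expert_idx.flatten()
mask = (flat >= expert_start) & (flat < expert_end)
num_local = int(mask.sum())

# Push non-local experts to the end, then sort token/expert pairs by expert id.
flat = torch.where(~mask, torch.full_like(flat, MAX_INT32), flat)
sorted_idx = torch.argsort(flat, stable=True)

expanded_x = x[sorted_idx[:num_local] // k]  # token rows grouped by local expert
expert_tokens = torch.bincount(flat[sorted_idx[:num_local]] - expert_start,
                               minlength=expert_end - expert_start)
print(expert_tokens)  # tensor([2, 2]) -> 2 tokens for expert 0, 2 for expert 1
```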
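Finally, the offline test mentioned under "How was this patch tested" can be approximated with the standard vLLM offline API. This is only a sketch of such a run, assuming a host with four 310P NPUs and vllm-ascend installed; the model id, dtype, and max_model_len below are illustrative and not taken from the patch.

```python
from vllm import LLM, SamplingParams

# Illustrative offline check: Qwen3-30B-A3B with TP=4 and expert parallelism.
# On 310P, register_ascend_customop (utils.py above) registers
# AscendFusedMoE310 / AscendSharedFusedMoE310 for the MoE layers.
llm = LLM(
    model="Qwen/Qwen3-30B-A3B",   # assumed HF model id
    tensor_parallel_size=4,
    enable_expert_parallel=True,
    dtype="float16",              # 310P typically runs fp16
    max_model_len=4096,
)

outputs = llm.generate(
    ["Give a one-sentence summary of mixture-of-experts models."],
    SamplingParams(temperature=0.0, max_tokens=64),
)
print(outputs[0].outputs[0].text)
```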