### What this PR does / why we need it?
This pull request integrates comprehensive support for Mixture of Experts (MoE) models on the Ascend 310P device within the vllm-ascend framework. It does so by introducing specialized modules for expert selection, fused MoE layers, and optimized all-gather communication. The changes also refine existing NPU operations, making them more consistent and efficient on 310P, which improves the performance and compatibility of MoE models on this hardware.

Highlights:
- 310P MoE support: introduces dedicated implementations for Mixture of Experts (MoE) models on Ascend 310P devices, including new modules for expert selection, fused MoE layers, and communication.
- All-gather communication: enforces ALLGATHER communication for MoE operations on 310P, optimizing data transfer and leveraging NPU-specific token dispatching.
- Simplified NPU operations: removes the conditional type casting for `npu_swiglu` and enables the custom rotary-embedding kernels unconditionally, reflecting improved native support on 310P.
- New MoE classes registered: registers `AscendFusedMoE310` and `AscendSharedFusedMoE310` so that the 310P-specific MoE layers are available in the custom-operation registry.

### Does this PR introduce _any_ user-facing change?
No.

### How was this patch tested?
Offline and server tests with qwen3-30b-a3b, TP/EP = 4, on 310P.
- vLLM version: v0.15.0
- vLLM main: https://github.com/vllm-project/vllm/commit/v0.15.0

---------

Signed-off-by: pu-zhe <zpuaa@outlook.com>
76 lines
2.9 KiB
Python
76 lines
2.9 KiB
Python
#
|
|
# Copyright (c) 2026 Huawei Technologies Co., Ltd. All Rights Reserved.
|
|
# This file is a part of the vllm-ascend project.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
#
|
|
from collections.abc import Callable
|
|
|
|
import torch
|
|
|
|
from vllm_ascend.ops.fused_moe.experts_selector import _native_select_experts
|
|
from vllm_ascend.utils import get_weight_prefetch_method
|
|
|
|
|
|
def select_experts(
    hidden_states: torch.Tensor,
    router_logits: torch.Tensor,
    top_k: int,
    use_grouped_topk: bool,
    renormalize: bool,
    topk_group: int | None = None,
    num_expert_group: int | None = None,
    custom_routing_function: Callable | None = None,
    scoring_func: str = "softmax",
    e_score_correction_bias: torch.Tensor | None = None,
    global_num_experts: int = -1,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Select the top-k experts for every token (Ascend 310P path).

    If a weight-prefetch method is configured, its MoE "gate_up" preprocess
    hook is invoked first. Routing itself is delegated unchanged to
    ``_native_select_experts``; every argument is forwarded as-is.

    Args:
        hidden_states: Hidden states of shape (num_tokens, hidden_size).
        router_logits: Router (gate) logits, one score per token per expert;
            expected shape (num_tokens, num_experts).
        top_k: Number of experts to select per token.
        use_grouped_topk: Whether to group experts before selecting top-k.
        renormalize: Whether to renormalize the selected routing weights.
        topk_group: Number of expert groups to keep when grouped top-k is used.
        num_expert_group: Number of groups the experts are partitioned into
            for grouped top-k.
            NOTE(review): an earlier docstring described this as "experts in
            each group"; vLLM's grouped top-k treats it as the group count.
            Confirm against ``_native_select_experts``.
        custom_routing_function: Optional custom routing function.
        scoring_func: Scoring function applied to the router logits.
        e_score_correction_bias: Optional correction bias added to expert scores.
        global_num_experts: Global number of experts (-1 by default).

    Returns:
        A ``(topk_weights, topk_ids)`` pair: the routing weights and the
        selected expert IDs, each of shape (num_tokens, top_k).
    """
    # Give the configured prefetch hook (if any) a chance to start loading the
    # gate_up (w1/w3) expert weights before routing runs.
    weight_prefetch_method = get_weight_prefetch_method()
    if weight_prefetch_method:
        weight_prefetch_method.maybe_prefetch_moe_weight_preprocess(hidden_states, "gate_up")

    # 310P uses the native (non-fused-kernel) expert selection implementation.
    topk_weights, topk_ids = _native_select_experts(
        hidden_states=hidden_states,
        router_logits=router_logits,
        top_k=top_k,
        use_grouped_topk=use_grouped_topk,
        renormalize=renormalize,
        topk_group=topk_group,
        num_expert_group=num_expert_group,
        custom_routing_function=custom_routing_function,
        scoring_func=scoring_func,
        e_score_correction_bias=e_score_correction_bias,
        global_num_experts=global_num_experts,
    )
    return topk_weights, topk_ids
|