[Feat.]: Support MoE models on 310P (#6530)

### What this PR does / why we need it?
This pull request adds support for Mixture of Experts (MoE) models on the
Ascend 310P device within the vllm-ascend framework. It introduces
dedicated modules for expert selection, fused MoE layers, and all-gather
communication, and simplifies several existing NPU operations so they
behave consistently and efficiently on 310P, improving the performance and
compatibility of MoE models on this hardware.

### Highlights
- **310P MoE support:** introduces dedicated implementations for Mixture of
Experts (MoE) models on Ascend 310P devices, including new modules for
expert selection, fused MoE layers, and communication.
- **All-gather communication:** enforces the ALLGATHER communication method
for MoE operations on 310P, with a 310P-specific token dispatcher.
- **Simplified NPU operations:** removes the conditional type casting around
`npu_swiglu` and enables the custom rotary-embedding kernel unconditionally,
suggesting improved native support on 310P.
- **New MoE classes registered:** registers `AscendFusedMoE310` and
`AscendSharedFusedMoE310` in the custom-op registry so the 310P-specific MoE
layers are used.

### Does this PR introduce _any_ user-facing change?
No

### How was this patch tested?
Offline inference test and online serving test with Qwen3-30B-A3B, TP/EP = 4, on 310P.
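
The offline test corresponds roughly to the following invocation (a sketch, not the exact script used; the model path, dtype, prompt, and `max_model_len` are illustrative, and `enable_expert_parallel` assumes the standard vLLM engine argument):

```python
from vllm import LLM, SamplingParams

# TP = 4 with expert parallelism enabled on the same 4 NPUs (the "TP/EP = 4" setup above).
llm = LLM(
    model="Qwen/Qwen3-30B-A3B",
    tensor_parallel_size=4,
    enable_expert_parallel=True,
    dtype="float16",          # illustrative; 310P typically runs fp16
    max_model_len=4096,
)
outputs = llm.generate(["The capital of France is"], SamplingParams(max_tokens=32))
print(outputs[0].outputs[0].text)
```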

- vLLM version: v0.15.0
- vLLM main: https://github.com/vllm-project/vllm/commit/v0.15.0

---------

Signed-off-by: pu-zhe <zpuaa@outlook.com>


@@ -0,0 +1,75 @@
#
# Copyright (c) 2026 Huawei Technologies Co., Ltd. All Rights Reserved.
# This file is a part of the vllm-ascend project.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from collections.abc import Callable
import torch
from vllm_ascend.ops.fused_moe.experts_selector import _native_select_experts
from vllm_ascend.utils import get_weight_prefetch_method
def select_experts(
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
top_k: int,
use_grouped_topk: bool,
renormalize: bool,
topk_group: int | None = None,
num_expert_group: int | None = None,
custom_routing_function: Callable | None = None,
scoring_func: str = "softmax",
e_score_correction_bias: torch.Tensor | None = None,
global_num_experts: int = -1,
):
"""
    Select the top-k experts and their routing weights for each token.
Args:
router_logits: router logits of shape (num_tokens, hidden_size).
hidden_states: Hidden states of shape (num_tokens, hidden_size).
top_k: number of top k experts.
use_grouped_topk: Whether to group experts before selecting top-k.
renormalize: Whether to renormalize the routing weights.
        topk_group: Number of expert groups to select when grouped top-k is used.
        num_expert_group: Total number of expert groups.
custom_routing_function: Custom routing function.
scoring_func: Scoring function to use.
e_score_correction_bias: Correction bias to apply to expert scores.
global_num_experts: Global number of experts.
Returns:
topk_weights: router weights of shape (num_tokens, top_k).
topk_ids: selected expert IDs of shape (num_tokens, top_k).
"""
    # Prefetch the fused gate_up (w1/w3) projection weights ahead of the expert computation,
    # if a weight prefetch method is configured.
weight_prefetch_method = get_weight_prefetch_method()
if weight_prefetch_method:
weight_prefetch_method.maybe_prefetch_moe_weight_preprocess(hidden_states, "gate_up")
topk_weights, topk_ids = _native_select_experts(
hidden_states=hidden_states,
router_logits=router_logits,
top_k=top_k,
use_grouped_topk=use_grouped_topk,
renormalize=renormalize,
topk_group=topk_group,
num_expert_group=num_expert_group,
custom_routing_function=custom_routing_function,
scoring_func=scoring_func,
e_score_correction_bias=e_score_correction_bias,
global_num_experts=global_num_experts,
)
return topk_weights, topk_ids
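
For reference, a minimal call sketch for this selector (the module path is assumed from the relative imports in the fused MoE file below; shapes follow the docstring and the tensor contents are random):

```python
import torch
from vllm_ascend._310p.fused_moe.experts_selector import select_experts  # assumed path

num_tokens, hidden_size, num_experts, top_k = 4, 64, 8, 2
hidden_states = torch.randn(num_tokens, hidden_size)
router_logits = torch.randn(num_tokens, num_experts)

topk_weights, topk_ids = select_experts(
    hidden_states=hidden_states,
    router_logits=router_logits,
    top_k=top_k,
    use_grouped_topk=False,
    renormalize=True,
    global_num_experts=num_experts,
)
# topk_weights: (num_tokens, top_k) routing weights; topk_ids: (num_tokens, top_k) expert IDs
```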


@@ -0,0 +1,300 @@
#
# Copyright (c) 2026 Huawei Technologies Co., Ltd. All Rights Reserved.
# This file is a part of the vllm-ascend project.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from collections.abc import Callable
import torch
from vllm.distributed import get_dp_group, get_ep_group, get_tp_group
from vllm.forward_context import get_forward_context
from vllm.model_executor.layers.fused_moe.config import FusedMoEConfig
from vllm.model_executor.layers.fused_moe.layer import FusedMoE, UnquantizedFusedMoEMethod
from vllm.model_executor.layers.fused_moe.shared_fused_moe import SharedFusedMoE
from vllm_ascend.ascend_forward_context import MoECommType
from vllm_ascend.ops.fused_moe.experts_selector import zero_experts_compute
from vllm_ascend.ops.fused_moe.moe_comm_method import FusedExpertsResult, _MoECommMethods
from vllm_ascend.quantization.methods.base import QuantType
from .experts_selector import select_experts
from .moe_comm_method import AllGatherCommImpl310
class AscendUnquantizedFusedMoEMethod310(UnquantizedFusedMoEMethod):
def __init__(self, moe: FusedMoEConfig = None):
super().__init__(moe=moe)
def process_weights_after_loading(self, layer):
super().process_weights_after_loading(layer)
# Fused gate_up_proj (column parallel)
w13_data = self._maybe_pad_weight(layer.w13_weight.data).transpose(1, 2).contiguous()
layer.w13_weight = torch.nn.Parameter(w13_data, requires_grad=False)
# down_proj (row parallel)
w2_data = self._maybe_pad_weight(layer.w2_weight.data).transpose(1, 2).contiguous()
layer.w2_weight = torch.nn.Parameter(w2_data, requires_grad=False)
def apply(
self,
layer: torch.nn.Module,
x: torch.Tensor,
use_grouped_topk: bool,
top_k: int,
router_logits: torch.Tensor,
renormalize: bool,
topk_group: int | None = None,
num_expert_group: int | None = None,
custom_routing_function: Callable | None = None,
scoring_func: str = "softmax",
routed_scaling_factor: float = 1.0,
e_score_correction_bias: torch.Tensor | None = None,
global_num_experts: int = -1,
expert_map: torch.Tensor | None = None,
apply_router_weight_on_input: bool = False,
**kwargs,
) -> torch.Tensor:
zero_expert_num = getattr(layer, "zero_expert_num", 0)
zero_expert_type = getattr(layer, "zero_expert_type", None)
assert routed_scaling_factor == 1.0
topk_weights, topk_ids = select_experts(
hidden_states=x,
router_logits=router_logits,
top_k=top_k,
use_grouped_topk=use_grouped_topk,
renormalize=renormalize,
topk_group=topk_group,
num_expert_group=num_expert_group,
custom_routing_function=custom_routing_function,
scoring_func=scoring_func,
e_score_correction_bias=e_score_correction_bias,
global_num_experts=global_num_experts,
)
if zero_expert_num > 0 and zero_expert_type is not None:
topk_ids, topk_weights, zero_expert_result = zero_experts_compute(
expert_indices=topk_ids,
expert_scales=topk_weights,
num_experts=global_num_experts,
zero_expert_type=zero_expert_type,
hidden_states=x,
)
topk_weights = topk_weights.to(x.dtype)
moe_comm_method = get_forward_context().moe_comm_method
final_hidden_states = moe_comm_method.fused_experts(
hidden_states=x,
w1=layer.w13_weight,
w2=layer.w2_weight,
topk_weights=topk_weights,
topk_ids=topk_ids,
expert_map=expert_map,
apply_router_weight_on_input=apply_router_weight_on_input,
)
if zero_expert_num > 0 and zero_expert_type is not None:
final_hidden_states += zero_expert_result
return final_hidden_states
class AscendFusedMoE310(FusedMoE):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.global_num_experts = kwargs["num_experts"]
if self.quant_config is None:
self.quant_method = AscendUnquantizedFusedMoEMethod310(self.moe_config)
else:
self.quant_method = self.quant_config.get_quant_method(self, self.layer_name)
assert self.quant_method is not None
self.moe_config.tp_group = get_tp_group()
self.moe_config.dp_group = get_dp_group()
self.moe_config.ep_group = get_ep_group()
self.moe_config.supports_eplb = False
# init moe
self.global_expert_map = None
self.local_expert_map = None
if self.moe_config.ep_size > 1:
self.global_expert_map, self.local_expert_map = self.init_experts_map(self.moe_config)
self.local_num_experts = (
torch.sum(self.local_expert_map != -1).item()
if self.local_expert_map is not None
else self.global_num_experts
)
self.moe_config.num_experts = self.global_num_experts
self.moe_config.num_local_experts = self.local_num_experts
self.moe_config.global_redundant_expert_num = 0
moe_quant_params = {
"num_experts": self.local_num_experts,
"hidden_size": self.hidden_size,
"intermediate_size_per_partition": self.intermediate_size_per_partition,
"params_dtype": self.params_dtype,
"weight_loader": self.weight_loader,
}
self.quant_method.create_weights(layer=self, **moe_quant_params)
self.quant_type = self.get_quant_type()
_MoECommMethods[MoECommType.ALLGATHER] = AllGatherCommImpl310(self.moe_config)
def init_experts_map(self, moe_config):
"""
Initialize expert mapping for MoE (Mixture of Experts) model.
This function creates mappings between global expert indices and local expert indices
for each rank in the expert parallel group. It divides the total experts among
different ranks and creates both global and local expert maps that are used
during MoE computation to determine which experts are handled by which rank.
Args:
moe_config: Configuration object containing MoE parameters including
number of experts, expert parallel size, and expert parallel rank.
Returns:
tuple: A tuple containing:
- global_expert_map: Stack of expert maps for all ranks
- local_expert_map: Expert map for the current rank (transferred to NPU)
"""
n_experts = moe_config.num_experts
ep_size = moe_config.ep_size
all_experts = torch.arange(n_experts, dtype=torch.int32)
experts_groups = all_experts.chunk(ep_size)
global_expert_map = []
local_expert_map = None
for rankid in range(ep_size):
expert_map = torch.full((n_experts,), -1, dtype=torch.int32)
local_experts = experts_groups[rankid]
expert_map[local_experts] = torch.arange(local_experts.shape[0], dtype=torch.int32)
global_expert_map.append(expert_map)
if rankid == moe_config.ep_rank:
local_expert_map = expert_map.npu()
return torch.stack(global_expert_map), local_expert_map
def get_quant_type(self) -> QuantType:
quant_method = self.quant_method
if not hasattr(quant_method, "quant_method") or quant_method.quant_method is None:
return QuantType.NONE
method = quant_method.quant_method
quant_type = getattr(method, "quant_type", QuantType.NONE)
if quant_type != QuantType.NONE:
            # TODO: w8a8 quantization will be supported soon; once it lands, only w4a8 should be rejected here.
raise RuntimeError("W8A8 is not supported currently.")
return QuantType.NONE
def forward_impl( # type: ignore[override]
self, hidden_states: torch.Tensor, router_logits: torch.Tensor
) -> torch.Tensor:
assert self.quant_method is not None
forward_context = get_forward_context()
hidden_states, router_logits, _, context_metadata = forward_context.moe_comm_method.prepare(
hidden_states=hidden_states, router_logits=router_logits, quant_type=self.quant_type
)
if isinstance(hidden_states, tuple):
hidden_states, pertoken_scale = hidden_states
else:
pertoken_scale = None
# Matrix multiply.
fused_experts_results: FusedExpertsResult = self.quant_method.apply(
layer=self,
x=hidden_states,
router_logits=router_logits,
pertoken_scale=pertoken_scale,
top_k=self.top_k,
renormalize=self.renormalize,
use_grouped_topk=self.use_grouped_topk,
global_num_experts=self.global_num_experts,
expert_map=self.local_expert_map,
topk_group=self.topk_group,
num_expert_group=self.num_expert_group,
custom_routing_function=self.custom_routing_function,
scoring_func=self.scoring_func,
routed_scaling_factor=self.routed_scaling_factor,
e_score_correction_bias=self.e_score_correction_bias,
activation=self.activation,
apply_router_weight_on_input=self.apply_router_weight_on_input,
)
routed_out = forward_context.moe_comm_method.finalize(
hidden_states=fused_experts_results.routed_out,
reduce_results=self.reduce_results,
context_metadata=context_metadata,
)
return routed_out
class AscendSharedFusedMoE310(SharedFusedMoE, AscendFusedMoE310):
def __init__(
self,
shared_experts: torch.nn.Module,
gate: torch.nn.Module | None = None,
use_overlapped: bool = True,
**kwargs,
):
AscendFusedMoE310.__init__(self, **kwargs)
self._shared_experts = shared_experts
self.use_overlapped = use_overlapped
self.shared_expert_stream = None
self._gate = gate
def forward(
self,
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
if self._shared_experts is None:
fused_out = AscendFusedMoE310.forward(
self,
hidden_states=hidden_states,
router_logits=router_logits,
)
shared_out = None
return shared_out, fused_out
shared_out, fused_out = AscendFusedMoE310.forward(
self,
hidden_states=hidden_states,
router_logits=router_logits,
)
return shared_out, fused_out
def _forward_shared_experts(self, hidden_states: torch.Tensor):
if self._shared_experts is None:
return None
part1_out = self._shared_experts_part1(hidden_states)
shared_out = self._shared_experts_part2(hidden_states, part1_out)
return shared_out
def forward_impl( # type: ignore[override]
self, hidden_states: torch.Tensor, router_logits: torch.Tensor
):
routed_out = AscendFusedMoE310.forward_impl(
self,
hidden_states=hidden_states,
router_logits=router_logits,
)
if self._shared_experts is None:
return routed_out
shared_out = self._forward_shared_experts(hidden_states)
return shared_out, routed_out
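
To make the expert map layout concrete, this rederives what `init_experts_map` builds for 8 experts over `ep_size = 2` (a standalone pure-torch version so it runs on CPU; the real method additionally moves the current rank's map to NPU):

```python
import torch

n_experts, ep_size = 8, 2
groups = torch.arange(n_experts, dtype=torch.int32).chunk(ep_size)

global_expert_map = []
for local in groups:  # one map per EP rank: global expert id -> local slot, or -1 if not owned
    m = torch.full((n_experts,), -1, dtype=torch.int32)
    m[local.long()] = torch.arange(local.shape[0], dtype=torch.int32)
    global_expert_map.append(m)

print(global_expert_map[0].tolist())  # [0, 1, 2, 3, -1, -1, -1, -1]  rank 0 owns global experts 0-3
print(global_expert_map[1].tolist())  # [-1, -1, -1, -1, 0, 1, 2, 3]  rank 1 owns global experts 4-7
```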


@@ -0,0 +1,39 @@
# Copyright (c) 2026 Huawei Technologies Co., Ltd. All Rights Reserved.
# Copyright 2023 The vLLM team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-ascend project.
from __future__ import annotations
from vllm_ascend.ops.fused_moe.moe_comm_method import AllGatherCommImpl
from .token_dispatcher import TokenDispatcherWithAllGather310
class AllGatherCommImpl310(AllGatherCommImpl):
"""This implementation is the same as NativeAllGatherCommImpl,
but uses NPU-specific ops for better performance.
This implementation should be compatible with all scenarios, and
thus it is the default implementation for MoE communication methods.
It uses `torch_npu.npu_moe_init_routing_v2` for pre-processing
and `torch_npu.npu_moe_token_unpermute` for post-processing
to handle the token-to-expert mapping and communication efficiently.
"""
def _get_token_dispatcher(self):
return TokenDispatcherWithAllGather310(
top_k=self.moe_config.experts_per_token,
num_experts=self.moe_config.num_experts,
num_local_experts=self.moe_config.num_local_experts,
)


@@ -0,0 +1,126 @@
# SPDX-License-Identifier: Apache-2.0
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2026 Huawei Technologies Co., Ltd. All Rights Reserved.
# Copyright 2023 The vLLM team.
# Copyright 2023 DeepSeek-AI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from vllm.distributed.parallel_state import get_ep_group
from vllm_ascend.ops.fused_moe.token_dispatcher import TokenDispatcherWithAllGather, TokenDispatchResult
class TokenDispatcherWithAllGather310(TokenDispatcherWithAllGather):
def __init__(self, **kwargs):
super().__init__(**kwargs)
def token_dispatch(
self,
hidden_states: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
expert_map: torch.Tensor | None = None,
global_redundant_expert_num: int = 0,
mc2_mask: torch.Tensor | None = None,
apply_router_weight_on_input: bool = False,
with_quant: bool = False,
dynamic_eplb: bool = False,
pertoken_scale: torch.Tensor | None = None,
):
if with_quant:
raise RuntimeError("Quant is not supported for 310P currently.")
self.original_shape = hidden_states.shape
num_tokens = hidden_states.shape[:-1].numel()
self.apply_router_weight_on_input = apply_router_weight_on_input
if self.apply_router_weight_on_input:
assert topk_weights.dim() == 2, "`topk_weights` should be in shape (num_tokens, topk)"
_, topk = topk_weights.shape
assert topk == 1, "Only support topk=1 when `apply_router_weight_on_input` is True"
hidden_states = hidden_states * topk_weights.to(hidden_states.dtype)
if expert_map is not None:
mask = expert_map[topk_ids] != -1
topk_weights = topk_weights * mask
first_expert_idx = get_ep_group().rank_in_group * self.num_experts_local
last_expert_idx = first_expert_idx + self.num_experts_local
else:
first_expert_idx = 0
last_expert_idx = self.num_experts_local
sorted_hidden_states, expanded_row_idx, expert_tokens = self.moe_init_routing(
hidden_states,
topk_ids,
active_num=num_tokens * self.top_k,
active_expert_range=[first_expert_idx, last_expert_idx],
)
expert_tokens = expert_tokens.to(torch.int64)
group_list_type = 1 # `count` mode
context_metadata = {"topk_weights": topk_weights, "expanded_row_idx": expanded_row_idx}
return TokenDispatchResult(
hidden_states=sorted_hidden_states,
dynamic_scale=None,
group_list=expert_tokens,
group_list_type=group_list_type,
context_metadata=context_metadata,
)
def moe_init_routing(self, x, expert_idx, active_num, active_expert_range):
"""
Initialize routing for Mixture of Experts (MoE) model by organizing tokens
according to their assigned experts and preparing data structures for
efficient expert computation.
Args:
x (torch.Tensor): Input tensor containing token representations
expert_idx (torch.Tensor): Tensor containing expert indices for each token
            active_num (int | None): Upper bound on the number of token-expert slots to gather;
                capped at the number of slots routed to local experts (all of them when None)
active_expert_range (tuple): Range (start, end) of active experts
Returns:
tuple: A tuple containing:
                - expanded_x: Token rows gathered in expert-sorted order (only slots routed to local experts)
                - expanded_row_idx: For each original (token, k) slot, its position in the expert-sorted
                  order, or -1 if the slot is not handled locally
                - expert_tokens_count: Number of token slots assigned to each local expert
"""
MAX_INT32 = torch.iinfo(torch.int32).max
expert_start, expert_end = active_expert_range
num_rows = x.shape[0]
k = expert_idx.shape[-1]
expert_idx_flat = expert_idx.flatten()
mask = (expert_idx_flat >= expert_start) & (expert_idx_flat < expert_end)
actual_expert_total_num = mask.sum().item()
expert_idx_flat = torch.where(
~mask, torch.full_like(expert_idx_flat, MAX_INT32, dtype=torch.int32), expert_idx_flat
)
sorted_idx = torch.argsort(expert_idx_flat, stable=True)
sorted_expert_idx = expert_idx_flat[sorted_idx]
expanded_row_idx = torch.full((num_rows * k,), -1, dtype=torch.int32, device=expert_idx.device)
expanded_row_idx[sorted_idx[:actual_expert_total_num]] = torch.arange(
actual_expert_total_num, dtype=torch.int32, device=expert_idx.device
)
expert_tokens_count = torch.bincount(
sorted_expert_idx[:actual_expert_total_num] - expert_start, minlength=expert_end - expert_start
)
active_num = min(active_num or actual_expert_total_num, actual_expert_total_num)
expanded_x = x[sorted_idx[:active_num] // k]
return expanded_x, expanded_row_idx, expert_tokens_count
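
To see what `moe_init_routing` computes, here is a standalone rederivation on a toy batch (pure torch on CPU; values, shapes, and variable names are illustrative):

```python
import torch

# 3 tokens, top_k = 2, 4 global experts; this rank owns experts [0, 2).
x = torch.arange(6, dtype=torch.float32).reshape(3, 2)   # (num_tokens, hidden)
topk_ids = torch.tensor([[0, 3], [1, 0], [2, 1]])         # (num_tokens, top_k)
expert_start, expert_end, k = 0, 2, 2

flat = topk_ids.flatten()
local = (flat >= expert_start) & (flat < expert_end)      # slots routed to this rank's experts
flat = torch.where(~local, torch.full_like(flat, torch.iinfo(torch.int32).max), flat)
order = torch.argsort(flat, stable=True)                  # sort slots so same-expert tokens are adjacent
n_local = int(local.sum())                                # 4 of the 6 slots are local

expanded_x = x[order[:n_local] // k]                      # token rows repeated, grouped by expert
expert_tokens = torch.bincount(flat[order[:n_local]] - expert_start,
                               minlength=expert_end - expert_start)

print(order[:n_local].tolist())  # [0, 3, 2, 5] -> expert 0 gets tokens 0 and 1, expert 1 gets tokens 1 and 2
print(expert_tokens.tolist())    # [2, 2] tokens per local expert (becomes the grouped-GEMM group list)
```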


@@ -208,6 +208,7 @@ def select_moe_comm_method(num_tokens: int, vllm_config: VllmConfig, is_draft_mo
4. On A3 with expert parallel, prefer fused MC2 when using w8a8_dynamic
quantization with small EP size, no dynamic_eplb, and not in MTP
mode; otherwise use MC2 within capacity or all-to-all.
5. On 310P, always use all-gather.
Args:
num_tokens (int): The number of tokens in the current batch.
@@ -262,7 +263,8 @@ def select_moe_comm_method(num_tokens: int, vllm_config: VllmConfig, is_draft_mo
elif envs_ascend.VLLM_ASCEND_ENABLE_FUSED_MC2 == 2:
fused_prefill_enable = False
moe_comm_type = MoECommType.FUSED_MC2 if fused_prefill_enable else MoECommType.ALLTOALL
elif soc_version in {AscendDeviceType._310P}:
moe_comm_type = MoECommType.ALLGATHER
else:
raise ValueError(f"Unsupported soc_version: {soc_version}")
return moe_comm_type


@@ -82,9 +82,8 @@ class AscendUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod):
1, 2).contiguous()
layer.w2_weight = torch.nn.Parameter(w2_data, requires_grad=False)
if get_ascend_device_type() != AscendDeviceType._310P:
layer.w13_weight.data = maybe_trans_nz(layer.w13_weight.data)
layer.w2_weight.data = maybe_trans_nz(layer.w2_weight.data)
layer.w13_weight.data = maybe_trans_nz(layer.w13_weight.data)
layer.w2_weight.data = maybe_trans_nz(layer.w2_weight.data)
def apply(self,
layer: torch.nn.Module,


@@ -291,11 +291,7 @@ def unquant_apply_mlp(hidden_states: torch.Tensor,
group_type=0,
group_list=group_list,
)[0]
if get_ascend_device_type() == AscendDeviceType._310P:
gate_up_out = torch_npu.npu_swiglu(gate_up_out.to(torch.float32)).to(
torch.float16)
else:
gate_up_out = torch_npu.npu_swiglu(gate_up_out)
gate_up_out = torch_npu.npu_swiglu(gate_up_out)
if topk_scales is not None:
gate_up_out *= topk_scales


@@ -190,8 +190,7 @@ def _rope_forward_oot(
cos, sin = get_cos_and_sin_slice()
# adopt custom kernel path for rotary_embedding
if _custom_rotary_embedding_enabled(
query, is_neox_style, self.head_size) and get_ascend_device_type(
) != AscendDeviceType._310P:
query, is_neox_style, self.head_size):
query, key = torch.ops._C_ascend.rotary_embedding(
positions,
query,


@@ -625,6 +625,7 @@ def register_ascend_customop(vllm_config: VllmConfig | None = None):
# 310P: override selected ops with 310P implementations (keep minimal changes outside _310p)
if is_310p():
from vllm_ascend._310p.fused_moe.fused_moe import AscendFusedMoE310, AscendSharedFusedMoE310
from vllm_ascend._310p.ops.activation import AscendSiluAndMul310
from vllm_ascend._310p.ops.layernorm import AscendGemmaRMSNorm310, AscendRMSNorm310
from vllm_ascend._310p.ops.mm_encoder_attention import AscendMMEncoderAttention310
@@ -637,6 +638,8 @@ def register_ascend_customop(vllm_config: VllmConfig | None = None):
"RotaryEmbedding": AscendRotaryEmbedding310,
"RMSNorm": AscendRMSNorm310,
"GemmaRMSNorm": AscendGemmaRMSNorm310,
"FusedMoE": AscendFusedMoE310,
"SharedFusedMoE": AscendSharedFusedMoE310,
}
)