From 1fce70a2fb2602170781773104a69c69beb16161 Mon Sep 17 00:00:00 2001 From: wangxiyuan Date: Mon, 28 Apr 2025 21:57:01 +0800 Subject: [PATCH] [Model] Support common fused moe ops for moe model, such as Qwen3Moe (#709) vllm-ascend now only support moe for deepseek. We should add common moe support back Signed-off-by: wangxiyuan --- vllm_ascend/ops/__init__.py | 1 + vllm_ascend/ops/common_fused_moe.py | 67 +++++++++++++++++++++++++++++ 2 files changed, 68 insertions(+) create mode 100644 vllm_ascend/ops/common_fused_moe.py diff --git a/vllm_ascend/ops/__init__.py b/vllm_ascend/ops/__init__.py index 71d86a2..317024f 100644 --- a/vllm_ascend/ops/__init__.py +++ b/vllm_ascend/ops/__init__.py @@ -19,6 +19,7 @@ import torch import torch_npu # noqa: F401 import vllm_ascend.ops.activation # noqa +import vllm_ascend.ops.common_fused_moe # noqa import vllm_ascend.ops.fused_moe # noqa import vllm_ascend.ops.layernorm # noqa import vllm_ascend.ops.rotary_embedding # noqa diff --git a/vllm_ascend/ops/common_fused_moe.py b/vllm_ascend/ops/common_fused_moe.py new file mode 100644 index 0000000..1f2bf43 --- /dev/null +++ b/vllm_ascend/ops/common_fused_moe.py @@ -0,0 +1,67 @@ +# +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# This file is a part of the vllm-ascend project. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
#

from typing import Callable, Optional

import torch
from vllm.model_executor.layers.fused_moe.layer import \
    UnquantizedFusedMoEMethod

from vllm_ascend.ops.fused_moe import fused_experts, select_experts


def forward_oot(
    self,
    layer: torch.nn.Module,
    x: torch.Tensor,
    use_grouped_topk: bool,
    top_k: int,
    router_logits: torch.Tensor,
    renormalize: bool,
    topk_group: Optional[int] = None,
    num_expert_group: Optional[int] = None,
    custom_routing_function: Optional[Callable] = None,
    scoring_func: str = "softmax",
    e_score_correction_bias: Optional[torch.Tensor] = None,
    global_num_experts: Optional[int] = None,
    expert_map: Optional[torch.Tensor] = None,
    apply_router_weight_on_input: bool = False,
    activation: str = "silu",
) -> torch.Tensor:
    """Ascend out-of-tree override for ``UnquantizedFusedMoEMethod.forward``.

    First computes per-token expert routing with ``select_experts``, then
    dispatches the hidden states through the fused expert kernel using the
    layer's ``w13_weight`` / ``w2_weight`` expert tensors.

    NOTE(review): ``global_num_experts``, ``apply_router_weight_on_input``
    and ``activation`` are accepted to match the upstream vLLM method
    signature but are not forwarded to the Ascend kernels here — confirm
    against the upstream interface whether that is intentional.

    Args:
        layer: FusedMoE layer holding the expert weight tensors.
        x: Hidden states to route through the experts.
        use_grouped_topk: Whether grouped top-k routing is used.
        top_k: Number of experts selected per token.
        router_logits: Raw router scores used for expert selection.
        renormalize: Whether to renormalize the selected routing weights.

    Returns:
        The MoE output produced by ``fused_experts``.
    """
    # Routing step: per-token expert ids and their gating weights.
    routing_weights, routing_ids = select_experts(
        hidden_states=x,
        router_logits=router_logits,
        top_k=top_k,
        use_grouped_topk=use_grouped_topk,
        renormalize=renormalize,
        topk_group=topk_group,
        num_expert_group=num_expert_group,
        custom_routing_function=custom_routing_function,
        scoring_func=scoring_func,
        e_score_correction_bias=e_score_correction_bias,
    )

    # Dispatch step: run the fused expert computation on the NPU.
    return fused_experts(
        hidden_states=x,
        w1=layer.w13_weight,
        w2=layer.w2_weight,
        topk_weights=routing_weights,
        topk_ids=routing_ids,
        top_k=top_k,
        expert_map=expert_map,
    )


# Monkey-patch the upstream unquantized MoE method so models built on vLLM's
# common FusedMoE layer (e.g. Qwen3-MoE, per the patch description) pick up
# the Ascend implementation.
UnquantizedFusedMoEMethod.forward_oot = forward_oot