diff --git a/vllm_ascend/ops/common_fused_moe.py b/vllm_ascend/ops/common_fused_moe.py index 6b673c4..ff301bd 100644 --- a/vllm_ascend/ops/common_fused_moe.py +++ b/vllm_ascend/ops/common_fused_moe.py @@ -27,6 +27,7 @@ from vllm.model_executor.layers.fused_moe.config import \ FusedMoEParallelConfig # isort: skip from vllm.model_executor.layers.fused_moe.layer import ( FusedMoE, UnquantizedFusedMoEMethod, determine_expert_map) +from vllm.model_executor.layers.shared_fused_moe import SharedFusedMoE from vllm_ascend.ascend_config import get_ascend_config from vllm_ascend.distributed.parallel_state import get_mc2_group @@ -415,7 +416,7 @@ class AscendFusedMoE(FusedMoE): expert_data.copy_(loaded_weight) -class AscendSharedFusedMoE(AscendFusedMoE): +class AscendSharedFusedMoE(SharedFusedMoE, AscendFusedMoE): def __init__( self, @@ -423,7 +424,7 @@ class AscendSharedFusedMoE(AscendFusedMoE): use_overlapped: bool = True, **kwargs, ): - super().__init__(**kwargs) + AscendFusedMoE.__init__(self, **kwargs) self._shared_experts = shared_experts self.use_overlapped = use_overlapped self.shared_expert_stream = None @@ -452,7 +453,8 @@ class AscendSharedFusedMoE(AscendFusedMoE): if moe_comm_method_name in {"alltoallcommimpl", "mc2commimpl"}: shared_out = tensor_model_parallel_all_reduce(shared_out) - fused_out = super().forward( + _, fused_out = AscendFusedMoE.forward( + self, hidden_states=hidden_states, router_logits=router_logits, ) @@ -461,6 +463,16 @@ class AscendSharedFusedMoE(AscendFusedMoE): torch.npu.current_stream().wait_stream(self.shared_expert_stream) return shared_out, fused_out + def forward_impl(self, hidden_states: torch.Tensor, + router_logits: torch.Tensor): + shared_output = torch.empty(1) + fused_output = AscendFusedMoE.forward_impl( + self, + hidden_states=hidden_states, + router_logits=router_logits, + ) + return shared_output, fused_output + UnquantizedFusedMoEMethod.__init__ = unquantized_fused_moe_init_func UnquantizedFusedMoEMethod.process_weights_after_loading = process_weights_after_loading diff --git a/vllm_ascend/patch/worker/patch_common/__init__.py b/vllm_ascend/patch/worker/patch_common/__init__.py index c8a72e2..a723072 100644 --- a/vllm_ascend/patch/worker/patch_common/__init__.py +++ b/vllm_ascend/patch/worker/patch_common/__init__.py @@ -18,4 +18,3 @@ import vllm_ascend.patch.worker.patch_common.patch_distributed # noqa import vllm_ascend.patch.worker.patch_common.patch_logits # noqa import vllm_ascend.patch.worker.patch_common.patch_minicpm # noqa -import vllm_ascend.patch.worker.patch_common.patch_shared_fused_moe # noqa diff --git a/vllm_ascend/patch/worker/patch_common/patch_shared_fused_moe.py b/vllm_ascend/patch/worker/patch_common/patch_shared_fused_moe.py deleted file mode 100644 index 6b6dfd5..0000000 --- a/vllm_ascend/patch/worker/patch_common/patch_shared_fused_moe.py +++ /dev/null @@ -1,21 +0,0 @@ -# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. -# Copyright 2023 The vLLM team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from vllm.model_executor.models import deepseek_v2, llama4 - -from vllm_ascend.ops.common_fused_moe import AscendSharedFusedMoE - -deepseek_v2.SharedFusedMoE = AscendSharedFusedMoE -llama4.SharedFusedMoE = AscendSharedFusedMoE \ No newline at end of file diff --git a/vllm_ascend/utils.py b/vllm_ascend/utils.py index 33d1699..06e1a2b 100644 --- a/vllm_ascend/utils.py +++ b/vllm_ascend/utils.py @@ -498,7 +498,8 @@ def register_ascend_customop(vllm_config: Optional[VllmConfig] = None): from vllm_ascend.models.layers.mla import AscendMultiHeadLatentAttention from vllm_ascend.ops.activation import AscendQuickGELU, AscendSiluAndMul - from vllm_ascend.ops.common_fused_moe import AscendFusedMoE + from vllm_ascend.ops.common_fused_moe import (AscendFusedMoE, + AscendSharedFusedMoE) from vllm_ascend.ops.layernorm import AscendQuantRMSNorm, AscendRMSNorm from vllm_ascend.ops.linear import (AscendColumnParallelLinear, AscendMergedColumnParallelLinear, @@ -525,6 +526,7 @@ def register_ascend_customop(vllm_config: Optional[VllmConfig] = None): "LogitsProcessor": AscendLogitsProcessor, "RMSNorm": AscendRMSNorm, "FusedMoE": AscendFusedMoE, + "SharedFusedMoE": AscendSharedFusedMoE, "MultiHeadLatentAttention": AscendMultiHeadLatentAttention, }