diff --git a/vllm_ascend/ops/common_fused_moe.py b/vllm_ascend/ops/common_fused_moe.py index 3142bc8..082e8a8 100644 --- a/vllm_ascend/ops/common_fused_moe.py +++ b/vllm_ascend/ops/common_fused_moe.py @@ -20,7 +20,8 @@ from typing import Callable, Optional import torch import torch_npu from vllm.config import CompilationLevel, get_current_vllm_config -from vllm.distributed import get_dp_group, get_ep_group, get_tp_group +from vllm.distributed import (get_dp_group, get_ep_group, get_tp_group, + tensor_model_parallel_all_reduce) from vllm.forward_context import get_forward_context from vllm.model_executor.layers.fused_moe.config import \ FusedMoEParallelConfig # isort: skip @@ -373,6 +374,21 @@ class AscendFusedMoE(FusedMoE): self, method.__name__.lower(), method(moe_config=self.moe_config)) # type: ignore[abstract] + def maybe_all_reduce_tensor_model_parallel( + self, final_hidden_states: torch.Tensor): + """NOTE(Yizhou): This is to override the parent class method. In `mc2commimpl` + and `alltoallcommimpl`, we do not need to all-reduce the final outputs since + the outputs are already aggregated across tensor parallel ranks in the + `finalize` function. In `allgathercommimpl`, we still need to all-reduce the + outputs since each rank only has partial outputs. 
+ """ + forward_context = get_forward_context() + moe_comm_method_name = forward_context.moe_comm_method_name + if moe_comm_method_name in {"alltoallcommimpl", "mc2commimpl"}: + return final_hidden_states + else: + return tensor_model_parallel_all_reduce(final_hidden_states) + def forward_impl(self, hidden_states: torch.Tensor, router_logits: torch.Tensor): assert self.quant_method is not None @@ -415,6 +431,38 @@ class AscendFusedMoE(FusedMoE): return final_hidden_states +class AscendSharedFusedMoE(AscendFusedMoE): + + def __init__( + self, + shared_experts: torch.nn.Module, + use_overlapped: bool = True, + **kwargs, + ): + super().__init__(**kwargs) + self._shared_experts = shared_experts + self.use_overlapped = use_overlapped + + def forward( + self, + hidden_states: torch.Tensor, + router_logits: torch.Tensor, + ) -> tuple[torch.Tensor, torch.Tensor]: + shared_out = self._shared_experts(hidden_states) + + # NOTE: This is exactly the opposite of `maybe_all_reduce_tensor_model_parallel` + forward_context = get_forward_context() + moe_comm_method_name = forward_context.moe_comm_method_name + if moe_comm_method_name in {"alltoallcommimpl", "mc2commimpl"}: + shared_out = tensor_model_parallel_all_reduce(shared_out) + + fused_out = super().forward( + hidden_states=hidden_states, + router_logits=router_logits, + ) + return shared_out, fused_out + + UnquantizedFusedMoEMethod.__init__ = unquantized_fused_moe_init_func UnquantizedFusedMoEMethod.process_weights_after_loading = process_weights_after_loading diff --git a/vllm_ascend/patch/platform/patch_common/__init__.py b/vllm_ascend/patch/platform/patch_common/__init__.py index f88f2a9..35ef149 100644 --- a/vllm_ascend/patch/platform/patch_common/__init__.py +++ b/vllm_ascend/patch/platform/patch_common/__init__.py @@ -16,3 +16,4 @@ # import vllm_ascend.patch.platform.patch_common.patch_distributed # noqa +import vllm_ascend.patch.platform.patch_common.patch_shared_fused_moe # noqa diff --git 
a/vllm_ascend/patch/platform/patch_common/patch_shared_fused_moe.py b/vllm_ascend/patch/platform/patch_common/patch_shared_fused_moe.py new file mode 100644 index 0000000..6b6dfd5 --- /dev/null +++ b/vllm_ascend/patch/platform/patch_common/patch_shared_fused_moe.py @@ -0,0 +1,21 @@ +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# Copyright 2023 The vLLM team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from vllm.model_executor.models import deepseek_v2, llama4 + +from vllm_ascend.ops.common_fused_moe import AscendSharedFusedMoE + +deepseek_v2.SharedFusedMoE = AscendSharedFusedMoE +llama4.SharedFusedMoE = AscendSharedFusedMoE