[Fix] Fix SharedFusedMoE (#2817)

### What this PR does / why we need it?
It is rather strange that `register_oot` doesn't work with `SharedFusedMoE`,
so we have to add this patch for now.

### Does this PR introduce _any_ user-facing change?
None.

### How was this patch tested?
This PR won't have any effect on DeepSeek since we currently still stick
with the old `CustomDeepseekV2`.

- vLLM version: v0.10.1.1
- vLLM main:
0cdd213641

---------

Signed-off-by: Yizhou Liu <liu_yizhou@outlook.com>
This commit is contained in:
yiz-liu
2025-09-09 18:19:56 +08:00
committed by GitHub
parent 7a205dbaa8
commit e13c4ddb42
3 changed files with 71 additions and 1 deletions

View File

@@ -20,7 +20,8 @@ from typing import Callable, Optional
import torch
import torch_npu
from vllm.config import CompilationLevel, get_current_vllm_config
from vllm.distributed import get_dp_group, get_ep_group, get_tp_group
from vllm.distributed import (get_dp_group, get_ep_group, get_tp_group,
tensor_model_parallel_all_reduce)
from vllm.forward_context import get_forward_context
from vllm.model_executor.layers.fused_moe.config import \
FusedMoEParallelConfig # isort: skip
@@ -373,6 +374,21 @@ class AscendFusedMoE(FusedMoE):
self, method.__name__.lower(),
method(moe_config=self.moe_config)) # type: ignore[abstract]
def maybe_all_reduce_tensor_model_parallel(
        self, final_hidden_states: torch.Tensor):
    """Conditionally all-reduce the MoE output across tensor-parallel ranks.

    NOTE(Yizhou): Overrides the parent-class method. With the
    `mc2commimpl` and `alltoallcommimpl` communication methods, the
    outputs are already aggregated across tensor-parallel ranks inside
    their `finalize` step, so no extra all-reduce is needed. With
    `allgathercommimpl`, each rank only holds partial outputs, so the
    all-reduce must still happen here.
    """
    comm_name = get_forward_context().moe_comm_method_name
    if comm_name in ("alltoallcommimpl", "mc2commimpl"):
        # Already aggregated by the comm impl's finalize step.
        return final_hidden_states
    return tensor_model_parallel_all_reduce(final_hidden_states)
def forward_impl(self, hidden_states: torch.Tensor,
router_logits: torch.Tensor):
assert self.quant_method is not None
@@ -415,6 +431,38 @@ class AscendFusedMoE(FusedMoE):
return final_hidden_states
class AscendSharedFusedMoE(AscendFusedMoE):
    """Ascend fused-MoE layer that also evaluates a shared-experts module.

    `forward` returns both the shared-expert output and the routed
    (fused) expert output so the caller can combine them.
    """

    def __init__(
        self,
        shared_experts: torch.nn.Module,
        use_overlapped: bool = True,
        **kwargs,
    ):
        super().__init__(**kwargs)
        # Module applied to the raw hidden states alongside routed experts.
        self._shared_experts = shared_experts
        self.use_overlapped = use_overlapped

    def forward(
        self,
        hidden_states: torch.Tensor,
        router_logits: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        shared_out = self._shared_experts(hidden_states)

        # NOTE: This is exactly the opposite of
        # `maybe_all_reduce_tensor_model_parallel`: when the comm impl
        # aggregates the routed output itself, the shared-expert output
        # must be all-reduced here to match.
        comm_name = get_forward_context().moe_comm_method_name
        if comm_name in ("alltoallcommimpl", "mc2commimpl"):
            shared_out = tensor_model_parallel_all_reduce(shared_out)

        fused_out = super().forward(
            hidden_states=hidden_states,
            router_logits=router_logits,
        )
        return shared_out, fused_out
# Monkey-patch vLLM's UnquantizedFusedMoEMethod with the Ascend-specific
# initializer and weight post-processing functions defined in this module.
UnquantizedFusedMoEMethod.__init__ = unquantized_fused_moe_init_func
UnquantizedFusedMoEMethod.process_weights_after_loading = process_weights_after_loading

View File

@@ -16,3 +16,4 @@
#
import vllm_ascend.patch.platform.patch_common.patch_distributed # noqa
import vllm_ascend.patch.platform.patch_common.patch_shared_fused_moe # noqa

View File

@@ -0,0 +1,21 @@
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# Copyright 2023 The vLLM team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from vllm.model_executor.models import deepseek_v2, llama4
from vllm_ascend.ops.common_fused_moe import AscendSharedFusedMoE
# Patch the model modules directly because `register_oot` does not take
# effect for `SharedFusedMoE`; substitute the Ascend implementation in the
# models that use it.
deepseek_v2.SharedFusedMoE = AscendSharedFusedMoE
llama4.SharedFusedMoE = AscendSharedFusedMoE