[Fix] Fix SharedFusedMoE (#2817)
### What this PR does / why we need it?
Really strange that `register_oot` doesn't work with `SharedFusedMoE`,
so we have to add this patch, for now.
### Does this PR introduce _any_ user-facing change?
None.
### How was this patch tested?
This PR won't have any effect in DeepSeek since we currently still stick
with the old `CustomDeepseekV2`.
- vLLM version: v0.10.1.1
- vLLM main:
0cdd213641
---------
Signed-off-by: Yizhou Liu <liu_yizhou@outlook.com>
This commit is contained in:
@@ -20,7 +20,8 @@ from typing import Callable, Optional
|
||||
import torch
|
||||
import torch_npu
|
||||
from vllm.config import CompilationLevel, get_current_vllm_config
|
||||
from vllm.distributed import get_dp_group, get_ep_group, get_tp_group
|
||||
from vllm.distributed import (get_dp_group, get_ep_group, get_tp_group,
|
||||
tensor_model_parallel_all_reduce)
|
||||
from vllm.forward_context import get_forward_context
|
||||
from vllm.model_executor.layers.fused_moe.config import \
|
||||
FusedMoEParallelConfig # isort: skip
|
||||
@@ -373,6 +374,21 @@ class AscendFusedMoE(FusedMoE):
|
||||
self, method.__name__.lower(),
|
||||
method(moe_config=self.moe_config)) # type: ignore[abstract]
|
||||
|
||||
def maybe_all_reduce_tensor_model_parallel(
|
||||
self, final_hidden_states: torch.Tensor):
|
||||
"""NOTE(Yizhou): This is to override the parent class method. In `mc2commimpl`,
|
||||
and `alltoallcommimpl`, we do not need to all-reduce the final outputs since
|
||||
the outputs are already aggregated across tensor parallel ranks in the
|
||||
`finalize` function. In `allgathercommimpl`, we still need to all-reduce the
|
||||
outputs since each rank only has partial outputs.
|
||||
"""
|
||||
forward_context = get_forward_context()
|
||||
moe_comm_method_name = forward_context.moe_comm_method_name
|
||||
if moe_comm_method_name in {"alltoallcommimpl", "mc2commimpl"}:
|
||||
return final_hidden_states
|
||||
else:
|
||||
return tensor_model_parallel_all_reduce(final_hidden_states)
|
||||
|
||||
def forward_impl(self, hidden_states: torch.Tensor,
|
||||
router_logits: torch.Tensor):
|
||||
assert self.quant_method is not None
|
||||
@@ -415,6 +431,38 @@ class AscendFusedMoE(FusedMoE):
|
||||
return final_hidden_states
|
||||
|
||||
|
||||
class AscendSharedFusedMoE(AscendFusedMoE):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
shared_experts: torch.nn.Module,
|
||||
use_overlapped: bool = True,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(**kwargs)
|
||||
self._shared_experts = shared_experts
|
||||
self.use_overlapped = use_overlapped
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
shared_out = self._shared_experts(hidden_states)
|
||||
|
||||
# NOTE: This is exactly the opposite of `maybe_all_reduce_tensor_model_parallel`
|
||||
forward_context = get_forward_context()
|
||||
moe_comm_method_name = forward_context.moe_comm_method_name
|
||||
if moe_comm_method_name in {"alltoallcommimpl", "mc2commimpl"}:
|
||||
shared_out = tensor_model_parallel_all_reduce(shared_out)
|
||||
|
||||
fused_out = super().forward(
|
||||
hidden_states=hidden_states,
|
||||
router_logits=router_logits,
|
||||
)
|
||||
return shared_out, fused_out
|
||||
|
||||
|
||||
UnquantizedFusedMoEMethod.__init__ = unquantized_fused_moe_init_func
|
||||
UnquantizedFusedMoEMethod.process_weights_after_loading = process_weights_after_loading
|
||||
|
||||
|
||||
@@ -16,3 +16,4 @@
|
||||
#
|
||||
|
||||
import vllm_ascend.patch.platform.patch_common.patch_distributed # noqa
|
||||
import vllm_ascend.patch.platform.patch_common.patch_shared_fused_moe # noqa
|
||||
|
||||
@@ -0,0 +1,21 @@
|
||||
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||||
# Copyright 2023 The vLLM team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from vllm.model_executor.models import deepseek_v2, llama4
|
||||
|
||||
from vllm_ascend.ops.common_fused_moe import AscendSharedFusedMoE
|
||||
|
||||
deepseek_v2.SharedFusedMoE = AscendSharedFusedMoE
|
||||
llama4.SharedFusedMoE = AscendSharedFusedMoE
|
||||
Reference in New Issue
Block a user