From 2b3bfe432e886b4773ef5cfa33a0e69b2c7d5b6d Mon Sep 17 00:00:00 2001
From: weijinqian0 <1184188277@qq.com>
Date: Sun, 30 Nov 2025 06:12:39 +0800
Subject: [PATCH] [bugfix] Repair the problem of moe model accuracy caused by version upgrade. (#4562)

Repair the problem of moe model accuracy caused by version upgrade.

Reason:
The new version adds the "reduce_output" operation after "forward_impl".
Then we have fully taken over the implementation of the FusedMoe module.

- vLLM version: v0.11.2
- vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.2

---------

Signed-off-by: weijinqian_v1
Co-authored-by: weijinqian_v1
---
 vllm_ascend/ops/fused_moe/fused_moe.py | 27 ++++++++++++++++++++++++++
 1 file changed, 27 insertions(+)

diff --git a/vllm_ascend/ops/fused_moe/fused_moe.py b/vllm_ascend/ops/fused_moe/fused_moe.py
index f3e3d156..a181f2cb 100644
--- a/vllm_ascend/ops/fused_moe/fused_moe.py
+++ b/vllm_ascend/ops/fused_moe/fused_moe.py
@@ -18,6 +18,7 @@ import os.path
 from typing import Any, Callable, Optional
 
 import torch
+import torch.nn.functional as F
 import torch_npu
 from vllm.config import get_current_vllm_config
 from vllm.distributed import (get_dp_group, get_ep_group, get_tp_group,
@@ -292,6 +293,32 @@ class AscendFusedMoE(FusedMoE):
         return torch.ops.vllm.maybe_all_reduce_tensor_model_parallel(
             final_hidden_states)
 
+    def forward(  # Pad input to self.hidden_size, run the MoE custom op, slice outputs back to the input width.
+        self,  # NOTE: per the commit message, this takes over forward because upstream now runs "reduce_output" after forward_impl; no such step is applied here.
+        hidden_states: torch.Tensor,  # input activations; last dim may differ from self.hidden_size
+        router_logits: torch.Tensor,  # passed through unchanged to the custom op
+    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:  # one tensor without shared experts, else (shared_output, fused_output)
+        og_hidden_states = hidden_states.shape[-1]  # original size of the last dimension (an int, not a tensor)
+        if self.hidden_size != og_hidden_states:  # presumably self.hidden_size >= input width (a negative pad would crop) - TODO confirm
+            hidden_states = F.pad(
+                hidden_states,
+                (0, self.hidden_size - og_hidden_states),  # (left, right) padding applied to the last dimension only
+                mode="constant",
+                value=0.0,  # fill the added columns with zeros
+            )
+        if self.shared_experts is None:  # no shared experts: the op result is handled as a single tensor
+            fused_output = torch.ops.vllm.moe_forward(hidden_states,  # custom op; presumably resolves this layer via layer_name and runs forward_impl - verify against the op registration
+                                                      router_logits,
+                                                      self.layer_name)
+            return fused_output[..., :og_hidden_states]  # drop the columns beyond the original width
+        else:
+            shared_output, fused_output = torch.ops.vllm.moe_forward_shared(  # shared-expert variant: result is unpacked as a pair
+                hidden_states, router_logits, self.layer_name)
+            return (
+                shared_output[..., :og_hidden_states],  # both outputs are sliced back to the original width
+                fused_output[..., :og_hidden_states],
+            )
+
     def forward_impl(self, hidden_states: torch.Tensor,
                      router_logits: torch.Tensor):
         assert self.quant_method is not None