diff --git a/vllm_ascend/ops/fused_moe/fused_moe.py b/vllm_ascend/ops/fused_moe/fused_moe.py
index f3e3d156..a181f2cb 100644
--- a/vllm_ascend/ops/fused_moe/fused_moe.py
+++ b/vllm_ascend/ops/fused_moe/fused_moe.py
@@ -18,6 +18,7 @@ import os.path
 from typing import Any, Callable, Optional
 
 import torch
+import torch.nn.functional as F
 import torch_npu
 from vllm.config import get_current_vllm_config
 from vllm.distributed import (get_dp_group, get_ep_group, get_tp_group,
@@ -292,6 +293,32 @@ class AscendFusedMoE(FusedMoE):
         return torch.ops.vllm.maybe_all_reduce_tensor_model_parallel(
             final_hidden_states)
 
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        router_logits: torch.Tensor,
+    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
+        og_hidden_states = hidden_states.shape[-1]
+        if self.hidden_size != og_hidden_states:
+            hidden_states = F.pad(
+                hidden_states,
+                (0, self.hidden_size - og_hidden_states),
+                mode="constant",
+                value=0.0,
+            )
+        if self.shared_experts is None:
+            fused_output = torch.ops.vllm.moe_forward(hidden_states,
+                                                      router_logits,
+                                                      self.layer_name)
+            return fused_output[..., :og_hidden_states]
+        else:
+            shared_output, fused_output = torch.ops.vllm.moe_forward_shared(
+                hidden_states, router_logits, self.layer_name)
+            return (
+                shared_output[..., :og_hidden_states],
+                fused_output[..., :og_hidden_states],
+            )
+
     def forward_impl(self, hidden_states: torch.Tensor,
                      router_logits: torch.Tensor):
         assert self.quant_method is not None