[bugfix] Fix the MoE model accuracy problem caused by the version upgrade. (#4562)
Fix the MoE model accuracy problem caused by the version upgrade.

Reason: the new vLLM version adds a "reduce_output" operation after "forward_impl". Because of this, we fully take over the implementation of the FusedMoE module's forward path.

- vLLM version: v0.11.2
- vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.2

---------

Signed-off-by: weijinqian_v1 <weijinqian@huawei.com>
Co-authored-by: weijinqian_v1 <weijinqian@huawei.com>
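For context, a minimal, self-contained sketch of the failure mode described above. It does not use the real vLLM or vllm-ascend APIs; tp_all_reduce, ascend_forward_impl, and upstream_forward below are hypothetical stand-ins that only illustrate why applying a further reduction after an already-reduced forward_impl result skews the activations, which is why the commit below overrides forward() to dispatch directly to the custom moe_forward ops.

import torch

# Toy illustration only (hypothetical helpers, not the real vLLM code):
# if forward_impl() already sums the partial expert outputs across
# tensor-parallel ranks, a second reduce_output() in the base-class
# forward() sums them again and doubles the activations.

def tp_all_reduce(partials: list[torch.Tensor]) -> list[torch.Tensor]:
    # Stand-in for an all-reduce: every "rank" receives the sum of all partials.
    total = sum(partials)
    return [total.clone() for _ in partials]

def ascend_forward_impl(partials: list[torch.Tensor]) -> list[torch.Tensor]:
    # The Ascend path already reduces across TP ranks inside forward_impl.
    return tp_all_reduce(partials)

def upstream_forward(partials: list[torch.Tensor]) -> list[torch.Tensor]:
    # The new upstream forward() applies reduce_output after forward_impl,
    # reducing the already-reduced result a second time.
    return tp_all_reduce(ascend_forward_impl(partials))

partials = [torch.ones(4), torch.ones(4)]      # per-rank partial outputs
print(ascend_forward_impl(partials)[0])        # tensor([2., 2., 2., 2.])  -> correct
print(upstream_forward(partials)[0])           # tensor([4., 4., 4., 4.])  -> doubled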
@@ -18,6 +18,7 @@ import os.path

from typing import Any, Callable, Optional

import torch
import torch.nn.functional as F
import torch_npu
from vllm.config import get_current_vllm_config
from vllm.distributed import (get_dp_group, get_ep_group, get_tp_group,
@@ -292,6 +293,32 @@ class AscendFusedMoE(FusedMoE):

        return torch.ops.vllm.maybe_all_reduce_tensor_model_parallel(
            final_hidden_states)

    def forward(
        self,
        hidden_states: torch.Tensor,
        router_logits: torch.Tensor,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        og_hidden_states = hidden_states.shape[-1]
        if self.hidden_size != og_hidden_states:
            hidden_states = F.pad(
                hidden_states,
                (0, self.hidden_size - og_hidden_states),
                mode="constant",
                value=0.0,
            )
        if self.shared_experts is None:
            fused_output = torch.ops.vllm.moe_forward(hidden_states,
                                                      router_logits,
                                                      self.layer_name)
            return fused_output[..., :og_hidden_states]
        else:
            shared_output, fused_output = torch.ops.vllm.moe_forward_shared(
                hidden_states, router_logits, self.layer_name)
            return (
                shared_output[..., :og_hidden_states],
                fused_output[..., :og_hidden_states],
            )

    def forward_impl(self, hidden_states: torch.Tensor,
                     router_logits: torch.Tensor):
        assert self.quant_method is not None