[Refactor][MoE] Reuse vLLM's all_reduce logic (#5189)

### What this PR does / why we need it?
Move all_reduce logic to AscendFusedMoE.forward, reuse vLLM's logic.

### Does this PR introduce _any_ user-facing change?
No

### How was this patch tested?
e2e & ut
- vLLM version: v0.12.0
- vLLM main:
ad32e3e19c

Signed-off-by: weichen <calvin_zhu0210@outlook.com>
Co-authored-by: weijinqian0 <1184188277@qq.com>
This commit is contained in:
weichen
2025-12-23 18:53:48 +08:00
committed by GitHub
parent 8ae7fca947
commit ffe51eedd6
3 changed files with 1 additions and 38 deletions

View File

@@ -18,7 +18,6 @@ import os.path
from typing import Any, Callable, Optional
import torch
import torch.nn.functional as F
from vllm.config import get_current_vllm_config
from vllm.distributed import (get_dp_group, get_ep_group, get_tp_group,
tensor_model_parallel_all_reduce)
@@ -292,32 +291,6 @@ class AscendFusedMoE(FusedMoE):
return torch.ops.vllm.maybe_all_reduce_tensor_model_parallel(
final_hidden_states)
def forward(
self,
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
og_hidden_states = hidden_states.shape[-1]
if self.hidden_size != og_hidden_states:
hidden_states = F.pad(
hidden_states,
(0, self.hidden_size - og_hidden_states),
mode="constant",
value=0.0,
)
if self.shared_experts is None:
fused_output = torch.ops.vllm.moe_forward(hidden_states,
router_logits,
self.layer_name)
return fused_output[..., :og_hidden_states]
else:
shared_output, fused_output = torch.ops.vllm.moe_forward_shared(
hidden_states, router_logits, self.layer_name)
return (
shared_output[..., :og_hidden_states],
fused_output[..., :og_hidden_states],
)
def forward_impl(self, hidden_states: torch.Tensor,
router_logits: torch.Tensor):
assert self.quant_method is not None