[MoE][Dist] Fix Qwen MoE accuracy bug in DP scenario (#1856)

### What this PR does / why we need it?
Fix Qwen MoE accuracy bug in DP scenario.

Now the implentment of `FusedMoE` in vLLM use `All2AllManager` to
manager different all2all algorithm branch. And the default branch use
`Multicast` in `dispatch` phase and `all_reduce` in `combine` phase,
which are not implented in vLLM-Ascend. This leading to invoking into a
default implentment in `base_communicator`, with empty `dispatch` and
`combine` operations, thus causing the accuracy issue on it.

This pr is a temporary workaround, refacting all2all in vLLM-Ascend
could be a better way.


- vLLM version: v0.10.0
- vLLM main:
ad57f23f6a

---------

Signed-off-by: MengqingCao <cmq0113@163.com>
This commit is contained in:
Mengqing Cao
2025-08-04 10:24:18 +08:00
committed by GitHub
parent f939381c6f
commit af04ee9e7a
3 changed files with 46 additions and 58 deletions

View File

@@ -20,6 +20,7 @@ import torch
import torch.distributed as dist
from vllm.distributed.device_communicators.base_device_communicator import \
DeviceCommunicatorBase
from vllm.utils import logger
class NPUCommunicator(DeviceCommunicatorBase):
@@ -34,6 +35,12 @@ class NPUCommunicator(DeviceCommunicatorBase):
# init device according to rank
self.device = torch.npu.current_device()
if self.use_all2all:
from vllm.distributed.device_communicators.all2all import \
NaiveAll2AllManager
self.all2all_manager = NaiveAll2AllManager(self.cpu_group)
logger.info("Using naive all2all manager.")
def all_to_all(self,
input_: torch.Tensor,
scatter_dim: int = 0,
@@ -73,3 +80,17 @@ class NPUCommunicator(DeviceCommunicatorBase):
dist.all_to_all(output_list, input_list, group=self.device_group)
output_tensor = torch.cat(output_list, dim=gather_dim).contiguous()
return output_tensor
# TODO: Add ut for dispatch and combine
def dispatch(
self, hidden_states: torch.Tensor,
router_logits: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
assert self.all2all_manager is not None
hidden_states, router_logits = self.all2all_manager.dispatch(
hidden_states, router_logits)
return hidden_states, router_logits
def combine(self, hidden_states: torch.Tensor) -> torch.Tensor:
assert self.all2all_manager is not None
hidden_states = self.all2all_manager.combine(hidden_states)
return hidden_states