[MoE][Dist] Fix Qwen MoE accuracy bug in DP scenario (#1856)

### What this PR does / why we need it?
Fix Qwen MoE accuracy bug in DP scenario.

The current implementation of `FusedMoE` in vLLM uses an `All2AllManager` to
manage the different all2all algorithm branches. The default branch uses
`Multicast` in the `dispatch` phase and `all_reduce` in the `combine` phase,
neither of which is implemented in vLLM-Ascend. Execution therefore falls back
to the default implementation in `base_communicator`, whose `dispatch` and
`combine` are empty no-ops, which causes the accuracy issue.
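
To make the failure mode concrete, here is a minimal sketch (simplified stand-ins, not the actual vLLM classes) of what the no-op fallback does under expert parallelism:

```python
import torch


# Simplified stand-in for the fallback described above; the class name is
# illustrative, not the actual vLLM implementation.
class NoOpCommunicatorSketch:
    """Mirrors the empty fallback: dispatch/combine pass tensors through."""

    def dispatch(self, hidden_states: torch.Tensor,
                 router_logits: torch.Tensor):
        # No cross-rank exchange: each DP rank keeps only its local tokens,
        # even when the router picks experts hosted on another rank.
        return hidden_states, router_logits

    def combine(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # No reduction: partial expert outputs from other ranks are dropped.
        return hidden_states
```

Because the tensors flow through with valid shapes, the model runs to completion and simply produces wrong logits, which is why this surfaced as an accuracy regression rather than a crash.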

This PR is a temporary workaround; refactoring all2all in vLLM-Ascend would be
the better long-term fix.
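
For context, a rough sketch (assumed control flow, heavily simplified from the real `FusedMoE` forward) of where these two hooks sit:

```python
import torch


def moe_forward_sketch(communicator, experts, hidden_states: torch.Tensor,
                       router_logits: torch.Tensor) -> torch.Tensor:
    """Assumed data flow around dispatch/combine; `communicator` and
    `experts` are placeholders, not vLLM objects."""
    # 1. Ship each token to the rank hosting its routed experts.
    hidden_states, router_logits = communicator.dispatch(
        hidden_states, router_logits)
    # 2. Run the locally hosted experts on the received tokens.
    hidden_states = experts(hidden_states, router_logits)
    # 3. Return/reduce expert outputs back to the originating ranks.
    return communicator.combine(hidden_states)
```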


- vLLM version: v0.10.0
- vLLM main:
ad57f23f6a

---------

Signed-off-by: MengqingCao <cmq0113@163.com>
Mengqing Cao
2025-08-04 10:24:18 +08:00
committed by GitHub
parent f939381c6f
commit af04ee9e7a
3 changed files with 46 additions and 58 deletions


@@ -18,15 +18,11 @@
 #
 import gc
+import multiprocessing
-import signal
-import subprocess
 import sys
-import time
+from multiprocessing import Queue

 import lm_eval
 import pytest
-import requests
 import torch

 SERVER_HOST = "127.0.0.1"
@@ -36,7 +32,7 @@ COMPLETIONS_URL = f"http://{SERVER_HOST}:{SERVER_PORT}/v1/completions"

 # pre-trained model path on Hugging Face.
 # Qwen/Qwen2.5-0.5B-Instruct: accuracy test for DP.
-# Qwen/Qwen3-30B-A3B: accuracy test for EP.
+# Qwen/Qwen3-30B-A3B: accuracy test for EP and DP.
 # deepseek-ai/DeepSeek-V2-Lite: accuracy test for TP.
 MODEL_NAME = ["Qwen/Qwen3-30B-A3B", "deepseek-ai/DeepSeek-V2-Lite"]
@@ -145,58 +141,27 @@ def test_lm_eval_accuracy(monkeypatch: pytest.MonkeyPatch, model):
         f"Expected: {EXPECTED_VALUE[model]}±{RTOL} | Measured: {result}"


-@pytest.mark.parametrize("max_tokens", [10])
-@pytest.mark.parametrize("model", ["Qwen/Qwen2.5-0.5B-Instruct"])
-def test_lm_eval_accuracy_dp(model, max_tokens):
-    log_file = open("accuracy_pd.log", "a+")
-    cmd = [
-        "vllm", "serve", model, "--max_model_len", "4096",
-        "--tensor_parallel_size", "2", "--data_parallel_size", "2"
-    ]
-    server_proc = subprocess.Popen(cmd,
-                                   stdout=log_file,
-                                   stderr=subprocess.DEVNULL)
+DP_DENSE_MODEL = ["Qwen/Qwen2.5-0.5B-Instruct"]
+DP_MOE_MODEL = ["Qwen/Qwen3-30B-A3B"]
-    try:
-        for _ in range(300):
-            try:
-                r = requests.get(HEALTH_URL, timeout=1)
-                if r.status_code == 200:
-                    break
-            except requests.exceptions.RequestException:
-                pass
-            time.sleep(1)
-        else:
-            log_file.flush()
-            log_file.seek(0)
-            log_content = log_file.read()
-            pytest.fail(
-                f"vLLM serve did not become healthy after 300s: {HEALTH_URL}\n"
-                f"==== vLLM Serve Log Start ===\n{log_content}\n==== vLLM Serve Log End ==="
-            )
+DP_MORE_ARGS = {
+    "Qwen/Qwen2.5-0.5B-Instruct":
+    "tensor_parallel_size=2,data_parallel_size=2",
+    "Qwen/Qwen3-30B-A3B":
+    "tensor_parallel_size=2,data_parallel_size=2,enable_expert_parallel=True,max_model_len=1024,enforce_eager=True",
+}
-        prompt = "bejing is a"
-        payload = {
-            "prompt": prompt,
-            "max_tokens": max_tokens,
-            "sampling_params": {
-                "temperature": 0.0,
-                "top_p": 1.0,
-                "seed": 123
-            }
-        }
-        resp = requests.post(COMPLETIONS_URL, json=payload, timeout=30)
-        resp.raise_for_status()
-        data = resp.json()
-        generated = data["choices"][0]["text"].strip()
-        expected = "city in north china, it has many famous attractions"
-        assert generated == expected, f"Expected `{expected}`, got `{generated}`"
-    finally:
-        server_proc.send_signal(signal.SIGINT)
-        try:
-            server_proc.wait(timeout=10)
-        except subprocess.TimeoutExpired:
-            server_proc.kill()
-            server_proc.wait()
+@pytest.mark.parametrize("model", DP_DENSE_MODEL)
+def test_lm_eval_accuracy_dp(model):
+    result_queue: Queue[float] = multiprocessing.Queue()
+    p = multiprocessing.Process(target=run_test,
+                                args=(result_queue, model,
+                                      MAX_MODEL_LEN[model], MODEL_TYPE[model],
+                                      DP_MORE_ARGS[model]))
+    p.start()
+    p.join()
+    result = result_queue.get()
+    print(result)
+    assert (EXPECTED_VALUE[model] - RTOL < result < EXPECTED_VALUE[model] + RTOL), \
+        f"Expected: {EXPECTED_VALUE[model]}±{RTOL} | Measured: {result}"


@@ -27,7 +27,7 @@ from unittest.mock import patch

 import pytest

-MODELS = ["Qwen/Qwen2.5-0.5B-Instruct"]
+MODELS = ["Qwen/Qwen2.5-0.5B-Instruct", "Qwen/Qwen3-30B-A3B"]


 @pytest.mark.parametrize("model", MODELS)
@@ -54,6 +54,8 @@ def test_data_parallel_inference(model, max_tokens):
         "--trust-remote-code",
         "--enforce-eager",
     ]
+    if model == "Qwen/Qwen3-30B-A3B":
+        cmd.append("--enable-expert-parallel")

     print(f"Running subprocess: {' '.join(cmd)}")
     proc = subprocess.run(cmd,


@@ -20,6 +20,7 @@ import torch
 import torch.distributed as dist
 from vllm.distributed.device_communicators.base_device_communicator import \
     DeviceCommunicatorBase
+from vllm.utils import logger


 class NPUCommunicator(DeviceCommunicatorBase):
@@ -34,6 +35,12 @@ class NPUCommunicator(DeviceCommunicatorBase):
         # init device according to rank
         self.device = torch.npu.current_device()

+        if self.use_all2all:
+            from vllm.distributed.device_communicators.all2all import \
+                NaiveAll2AllManager
+            self.all2all_manager = NaiveAll2AllManager(self.cpu_group)
+            logger.info("Using naive all2all manager.")
+
     def all_to_all(self,
                    input_: torch.Tensor,
                    scatter_dim: int = 0,
@@ -73,3 +80,17 @@ class NPUCommunicator(DeviceCommunicatorBase):
         dist.all_to_all(output_list, input_list, group=self.device_group)
         output_tensor = torch.cat(output_list, dim=gather_dim).contiguous()
         return output_tensor
+
+    # TODO: Add ut for dispatch and combine
+    def dispatch(
+            self, hidden_states: torch.Tensor,
+            router_logits: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
+        assert self.all2all_manager is not None
+        hidden_states, router_logits = self.all2all_manager.dispatch(
+            hidden_states, router_logits)
+        return hidden_states, router_logits
+
+    def combine(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        assert self.all2all_manager is not None
+        hidden_states = self.all2all_manager.combine(hidden_states)
+        return hidden_states
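
Regarding the `TODO` above, a minimal single-process test sketch that checks only the delegation to the manager, with the manager mocked out (a real dispatch/combine test needs a multi-rank NPU setup; `make_npu_communicator` is a hypothetical fixture, not part of this change):

```python
from unittest.mock import MagicMock

import torch


def test_dispatch_and_combine_delegate_to_manager():
    comm = make_npu_communicator()  # hypothetical fixture, not shown here
    comm.all2all_manager = MagicMock()
    hidden = torch.randn(4, 8)
    logits = torch.randn(4, 2)
    comm.all2all_manager.dispatch.return_value = (hidden, logits)
    comm.all2all_manager.combine.return_value = hidden

    # dispatch should forward both tensors to the manager unchanged.
    out_hidden, out_logits = comm.dispatch(hidden, logits)
    comm.all2all_manager.dispatch.assert_called_once_with(hidden, logits)

    # combine should hand the hidden states back through the manager.
    combined = comm.combine(out_hidden)
    comm.all2all_manager.combine.assert_called_once_with(out_hidden)
    assert combined is hidden
```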