[CI] fix race condition problem (#353)

fix race condition problem

Signed-off-by: wangxiyuan <wangxiyuan1007@gmail.com>
This commit is contained in:
wangxiyuan
2025-03-19 17:04:36 +08:00
committed by GitHub
parent 441a62e937
commit 663dca7578
7 changed files with 49 additions and 88 deletions

View File

@@ -22,6 +22,7 @@ Run `pytest tests/ops/test_fused_moe.py`.
import pytest
import torch
from vllm.config import VllmConfig, set_current_vllm_config
from vllm.model_executor.layers.activation import SiluAndMul
from vllm_ascend.ops.fused_moe import fused_experts
@@ -67,30 +68,35 @@ def test_fused_experts(
dtype: torch.dtype,
device: str,
):
a = torch.randn((m, k), device=device, dtype=dtype) / 10
w1 = torch.randn((e, 2 * n, k), device=device, dtype=dtype) / 10
w2 = torch.randn((e, k, n), device=device, dtype=dtype) / 10
vllm_config = VllmConfig()
with set_current_vllm_config(vllm_config):
a = torch.randn((m, k), device=device, dtype=dtype) / 10
w1 = torch.randn((e, 2 * n, k), device=device, dtype=dtype) / 10
w2 = torch.randn((e, k, n), device=device, dtype=dtype) / 10
score = torch.randn((m, e), device=device, dtype=dtype)
score = torch.randn((m, e), device=device, dtype=dtype)
if ep_size > 1:
local_e = e // ep_size
e_ids = torch.randint(0,
e, (local_e, ),
device=device,
dtype=torch.int32)
e_map = torch.full((e, ), -1, device=device, dtype=torch.int32)
e_map[e_ids] = torch.arange(local_e, device=device, dtype=torch.int32)
w1 = w1[e_ids]
w2 = w2[e_ids]
else:
e_map = None
if ep_size > 1:
local_e = e // ep_size
e_ids = torch.randint(0,
e, (local_e, ),
device=device,
dtype=torch.int32)
e_map = torch.full((e, ), -1, device=device, dtype=torch.int32)
e_map[e_ids] = torch.arange(local_e,
device=device,
dtype=torch.int32)
w1 = w1[e_ids]
w2 = w2[e_ids]
else:
e_map = None
score = torch.softmax(score, dim=-1, dtype=dtype)
topk_weights, topk_ids = torch.topk(score, topk)
topk_ids = topk_ids.to(torch.int32)
score = torch.softmax(score, dim=-1, dtype=dtype)
topk_weights, topk_ids = torch.topk(score, topk)
topk_ids = topk_ids.to(torch.int32)
output = fused_experts(a, w1, w2, topk_weights, topk_ids, topk, e_map)
torch_output = torch_moe(a, w1, w2, topk_weights, topk_ids, topk, e_map)
# TODO: The native params are: atol=2e-2, rtol=0, maybe related to the nan problem
torch.testing.assert_close(output, torch_output, atol=4e-2, rtol=1)
output = fused_experts(a, w1, w2, topk_weights, topk_ids, topk, e_map)
torch_output = torch_moe(a, w1, w2, topk_weights, topk_ids, topk,
e_map)
# TODO: The native params are: atol=2e-2, rtol=0, maybe related to the nan problem
torch.testing.assert_close(output, torch_output, atol=4e-2, rtol=1)