[CI] fix race condition problem (#353)
Fix race condition problem.

Signed-off-by: wangxiyuan <wangxiyuan1007@gmail.com>
This commit is contained in:
@@ -22,6 +22,7 @@ Run `pytest tests/ops/test_fused_moe.py`.
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
from vllm.config import VllmConfig, set_current_vllm_config
|
||||
from vllm.model_executor.layers.activation import SiluAndMul
|
||||
|
||||
from vllm_ascend.ops.fused_moe import fused_experts
|
||||
@@ -67,30 +68,35 @@ def test_fused_experts(
|
||||
dtype: torch.dtype,
|
||||
device: str,
|
||||
):
|
||||
a = torch.randn((m, k), device=device, dtype=dtype) / 10
|
||||
w1 = torch.randn((e, 2 * n, k), device=device, dtype=dtype) / 10
|
||||
w2 = torch.randn((e, k, n), device=device, dtype=dtype) / 10
|
||||
vllm_config = VllmConfig()
|
||||
with set_current_vllm_config(vllm_config):
|
||||
a = torch.randn((m, k), device=device, dtype=dtype) / 10
|
||||
w1 = torch.randn((e, 2 * n, k), device=device, dtype=dtype) / 10
|
||||
w2 = torch.randn((e, k, n), device=device, dtype=dtype) / 10
|
||||
|
||||
score = torch.randn((m, e), device=device, dtype=dtype)
|
||||
score = torch.randn((m, e), device=device, dtype=dtype)
|
||||
|
||||
if ep_size > 1:
|
||||
local_e = e // ep_size
|
||||
e_ids = torch.randint(0,
|
||||
e, (local_e, ),
|
||||
device=device,
|
||||
dtype=torch.int32)
|
||||
e_map = torch.full((e, ), -1, device=device, dtype=torch.int32)
|
||||
e_map[e_ids] = torch.arange(local_e, device=device, dtype=torch.int32)
|
||||
w1 = w1[e_ids]
|
||||
w2 = w2[e_ids]
|
||||
else:
|
||||
e_map = None
|
||||
if ep_size > 1:
|
||||
local_e = e // ep_size
|
||||
e_ids = torch.randint(0,
|
||||
e, (local_e, ),
|
||||
device=device,
|
||||
dtype=torch.int32)
|
||||
e_map = torch.full((e, ), -1, device=device, dtype=torch.int32)
|
||||
e_map[e_ids] = torch.arange(local_e,
|
||||
device=device,
|
||||
dtype=torch.int32)
|
||||
w1 = w1[e_ids]
|
||||
w2 = w2[e_ids]
|
||||
else:
|
||||
e_map = None
|
||||
|
||||
score = torch.softmax(score, dim=-1, dtype=dtype)
|
||||
topk_weights, topk_ids = torch.topk(score, topk)
|
||||
topk_ids = topk_ids.to(torch.int32)
|
||||
score = torch.softmax(score, dim=-1, dtype=dtype)
|
||||
topk_weights, topk_ids = torch.topk(score, topk)
|
||||
topk_ids = topk_ids.to(torch.int32)
|
||||
|
||||
output = fused_experts(a, w1, w2, topk_weights, topk_ids, topk, e_map)
|
||||
torch_output = torch_moe(a, w1, w2, topk_weights, topk_ids, topk, e_map)
|
||||
# TODO: The native params are: atol=2e-2, rtol=0, maybe related to the nan problem
|
||||
torch.testing.assert_close(output, torch_output, atol=4e-2, rtol=1)
|
||||
output = fused_experts(a, w1, w2, topk_weights, topk_ids, topk, e_map)
|
||||
torch_output = torch_moe(a, w1, w2, topk_weights, topk_ids, topk,
|
||||
e_map)
|
||||
# TODO: The native params are: atol=2e-2, rtol=0, maybe related to the nan problem
|
||||
torch.testing.assert_close(output, torch_output, atol=4e-2, rtol=1)
|
||||
|
||||
Reference in New Issue
Block a user