Files
sglang/sgl-kernel/tests/test_ep_moe_post_reorder_kernel.py
xutizhou 506c4928f5 feat: integrate deepgemm into EPMoE (#6821)
Co-authored-by: tianqilin.99 <tianqilin.99@bytedance.com>
Co-authored-by: TianQiLin666666 <1834987979@qq.com>
Co-authored-by: Cheng Wan <54331508+ch-wan@users.noreply.github.com>
2025-06-23 01:38:58 -07:00

165 lines
3.7 KiB
Python

import itertools
import pytest
import torch
from sgl_kernel import ep_moe_post_reorder
from sglang.srt.layers.moe.ep_moe.kernels import post_reorder_triton_kernel
def create_test_tensors(
    batch_size: int,
    hidden_size: int,
    topk: int,
    start_expert_id: int,
    end_expert_id: int,
    dtype: torch.dtype,
    device: torch.device,
):
    """Build random inputs for the EP-MoE post-reorder kernels.

    Returns ``(down_output, src2dst, topk_ids, topk_weights)``:

    * ``down_output`` -- ``(batch_size * topk, hidden_size)`` standard-normal rows.
    * ``src2dst`` -- ``(batch_size, topk)`` int32 row indices into
      ``down_output``; a random permutation, so no two entries coincide.
    * ``topk_ids`` -- ``(batch_size, topk)`` int32 ids drawn uniformly from the
      inclusive range ``[start_expert_id, end_expert_id]``.
    * ``topk_weights`` -- ``(batch_size, topk)`` uniform ``[0, 1)`` values.
    """
    num_rows = batch_size * topk
    pair_shape = (batch_size, topk)

    down_output = torch.randn(num_rows, hidden_size, dtype=dtype, device=device)

    # A permutation gives every (token, slot) pair its own distinct
    # destination row, avoiding the duplicate-destination case.
    src2dst = torch.randperm(num_rows, device=device, dtype=torch.int32).view(
        *pair_shape
    )

    # randint's upper bound is exclusive, hence the +1 to include end_expert_id.
    topk_ids = torch.randint(
        start_expert_id,
        end_expert_id + 1,
        pair_shape,
        dtype=torch.int32,
        device=device,
    )

    topk_weights = torch.rand(*pair_shape, dtype=dtype, device=device)
    return down_output, src2dst, topk_ids, topk_weights
def run_cuda_kernel(
    down_output: torch.Tensor,
    output: torch.Tensor,
    src2dst: torch.Tensor,
    topk_ids: torch.Tensor,
    topk_weights: torch.Tensor,
    start_expert_id: int,
    end_expert_id: int,
    topk: int,
):
    """Invoke the sgl-kernel CUDA ``ep_moe_post_reorder`` op.

    The caller-provided ``output`` tensor is handed to the op and then
    returned as-is, so the result is whatever the op wrote into it.
    """
    kernel_args = (
        down_output,
        output,
        src2dst,
        topk_ids,
        topk_weights,
        start_expert_id,
        end_expert_id,
        topk,
    )
    ep_moe_post_reorder(*kernel_args)
    return output
def run_triton_kernel(
    down_output: torch.Tensor,
    output: torch.Tensor,
    src2dst: torch.Tensor,
    topk_ids: torch.Tensor,
    topk_weights: torch.Tensor,
    start_expert_id: int,
    end_expert_id: int,
    topk: int,
    hidden_size: int,
):
    """Run sglang's Triton ``post_reorder_triton_kernel`` as the reference.

    The kernel is launched with one program per token and writes into
    ``output``, which is returned.

    Args:
        down_output: ``(num_tokens * topk, hidden_size)`` expert outputs.
        output: ``(num_tokens, hidden_size)`` destination buffer.
        src2dst: ``(num_tokens, topk)`` row indices into ``down_output``.
        topk_ids: ``(num_tokens, topk)`` expert ids.
        topk_weights: ``(num_tokens, topk)`` combine weights.
        start_expert_id / end_expert_id / topk / hidden_size: forwarded
            unchanged to the kernel.
    """
    # The grid must cover tokens, i.e. the rows of topk_ids/src2dst/output.
    # down_output has num_tokens * topk rows; sizing the grid from it (as the
    # previous version did) launched topk times too many programs, which
    # indexed past the end of src2dst/topk_ids/topk_weights/output.
    num_tokens = topk_ids.size(0)
    block_size = 512
    # Presumably the kernel's dst_start offset (subtracted from src2dst
    # entries); 0 because src2dst indexes down_output directly here.
    dst_start = 0
    post_reorder_triton_kernel[(num_tokens,)](
        down_output,
        output,
        src2dst,
        topk_ids,
        topk_weights,
        start_expert_id,
        end_expert_id,
        topk,
        hidden_size,
        dst_start,
        block_size,
    )
    return output
def assert_close(a, b):
    """Assert that tensors ``a`` and ``b`` agree within a dtype-based tolerance.

    Both tensors are upcast to float32 before comparing. The tolerance is
    picked from ``a.dtype``:

    * float16: rtol=1e-5, atol=1e-2
    * bfloat16: rtol=1e-4, atol=1e-1
    * any other dtype: rtol=1e-5, atol=1e-5

    Raises:
        AssertionError: propagated from ``torch.testing.assert_close``.
    """
    # Half-precision kernels may accumulate differently, so they get a much
    # looser absolute tolerance than full precision.
    tolerance_by_dtype = {
        torch.float16: (1e-5, 1e-2),
        torch.bfloat16: (1e-4, 1e-1),
    }
    rtol, atol = tolerance_by_dtype.get(a.dtype, (1e-5, 1e-5))
    torch.testing.assert_close(a.float(), b.float(), rtol=rtol, atol=atol)
@pytest.mark.skipif(not torch.cuda.is_available(), reason="requires a CUDA device")
@pytest.mark.parametrize(
    "batch_size,hidden_size,topk",
    list(itertools.product([32, 64], [128, 256, 512], [2, 4, 8])),
)
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32])
def test_ep_moe_post_reorder_vs_triton(
    batch_size: int,
    hidden_size: int,
    topk: int,
    dtype: torch.dtype,
):
    """Compare the CUDA ``ep_moe_post_reorder`` op against the Triton reference.

    Both kernels receive identical random inputs and write into separate
    ``(batch_size, hidden_size)`` buffers. Those buffers must agree within
    ``assert_close``'s dtype-dependent tolerance.

    Expert ids are drawn from the same ``[start_expert_id, end_expert_id]``
    window that is passed to both kernels. Presumably, then, no (token, slot)
    pair is filtered out by either kernel.
    """
    # Fixed seed so a failing parametrization can be reproduced exactly.
    torch.manual_seed(0)
    device = torch.device("cuda")
    start_expert_id = 0
    end_expert_id = 15

    down_output, src2dst, topk_ids, topk_weights = create_test_tensors(
        batch_size,
        hidden_size,
        topk,
        start_expert_id,
        end_expert_id,
        dtype,
        device,
    )

    output_cuda = torch.empty(batch_size, hidden_size, dtype=dtype, device=device)
    output_triton = torch.empty(batch_size, hidden_size, dtype=dtype, device=device)

    cuda_output = run_cuda_kernel(
        down_output,
        output_cuda,
        src2dst,
        topk_ids,
        topk_weights,
        start_expert_id,
        end_expert_id,
        topk,
    )
    triton_output = run_triton_kernel(
        down_output,
        output_triton,
        src2dst,
        topk_ids,
        topk_weights,
        start_expert_id,
        end_expert_id,
        topk,
        hidden_size,
    )

    assert_close(cuda_output, triton_output)
if __name__ == "__main__":
    # Propagate pytest's exit status. Calling pytest.main() without using its
    # return value makes a script run exit 0 even when tests fail.
    raise SystemExit(pytest.main([__file__]))