# Correctness tests for sgl_kernel.dsv3_router_gemm against a torch.nn.functional.linear reference.
import pytest
import torch
import torch.nn.functional as F
from sgl_kernel import dsv3_router_gemm


@pytest.mark.parametrize("num_tokens", list(range(1, 17)))
def test_dsv3_router_gemm(num_tokens):
    """Check ``dsv3_router_gemm`` against a ``F.linear`` reference on CUDA.

    Builds random bf16 CUDA inputs ``mat_a`` of shape
    ``(num_tokens, hidden_dim)`` and ``mat_b`` of shape
    ``(num_experts, hidden_dim)``. It computes the reference
    ``F.linear(mat_a, mat_b)`` (i.e. ``mat_a @ mat_b.T``) and upcasts it to
    float32. The kernel output must then match that reference in shape,
    dtype and device, and in value within ``rtol=1e-2`` / ``atol=1e-3``.

    Args:
        num_tokens: Number of rows of ``mat_a``; parametrized over 1..16.
    """
    num_experts = 256
    hidden_dim = 7168

    # Fixed seed (applies to all devices) so a failing case can be
    # reproduced with identical inputs.
    torch.manual_seed(0)
    # torch.randn already returns contiguous tensors; no .contiguous() needed.
    mat_a = torch.randn((num_tokens, hidden_dim), dtype=torch.bfloat16, device="cuda")
    mat_b = torch.randn((num_experts, hidden_dim), dtype=torch.bfloat16, device="cuda")

    # Reference result in bf16, upcast to float32 so it is compared against
    # the kernel's expected float32 output.
    ref = F.linear(mat_a, mat_b).to(torch.float32)

    output = dsv3_router_gemm(mat_a, mat_b)

    # Unlike torch.allclose, assert_close also checks shape, dtype and
    # device, and appends mismatch statistics to the failure message.
    torch.testing.assert_close(
        output,
        ref,
        rtol=1e-2,
        atol=1e-3,
        msg=lambda details: (
            "Router GEMM output mismatch with torch.nn.functional.linear "
            f"reference\n{details}"
        ),
    )


if __name__ == "__main__":
    # pytest.main only *returns* the session exit code. Raise SystemExit so
    # running this file directly exits non-zero when any test fails.
    raise SystemExit(pytest.main([__file__]))