[Feature] use pytest for sgl-kernel (#4896)

This commit is contained in:
Adarsh Shirawalmath
2025-03-30 23:06:52 +05:30
committed by GitHub
parent 4ede6770cd
commit 9fccda3111
10 changed files with 263 additions and 290 deletions

View File

@@ -1,5 +1,4 @@
import unittest
import pytest
import torch
from sgl_kernel import int8_scaled_mm
from vllm._custom_ops import cutlass_scaled_mm as vllm_scaled_mm
@@ -18,39 +17,31 @@ def torch_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias):
return o.to(out_dtype)
class TestInt8Gemm(unittest.TestCase):
def _test_accuracy_once(self, M, N, K, with_bias, out_dtype, device):
a = to_int8(torch.randn((M, K), device=device) * 5)
b = to_int8(torch.randn((N, K), device=device).t() * 5)
scale_a = torch.randn((M,), device="cuda", dtype=torch.float32)
scale_b = torch.randn((N,), device="cuda", dtype=torch.float32)
if with_bias:
bias = torch.randn((N,), device="cuda", dtype=out_dtype) * 10
else:
bias = None
def _test_accuracy_once(M, N, K, with_bias, out_dtype, device):
a = to_int8(torch.randn((M, K), device=device) * 5)
b = to_int8(torch.randn((N, K), device=device).t() * 5)
scale_a = torch.randn((M,), device="cuda", dtype=torch.float32)
scale_b = torch.randn((N,), device="cuda", dtype=torch.float32)
if with_bias:
bias = torch.randn((N,), device="cuda", dtype=out_dtype) * 10
else:
bias = None
o = int8_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias)
o1 = torch_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias)
o2 = vllm_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias)
torch.testing.assert_close(o, o1)
torch.testing.assert_close(o, o2)
print(f"M={M}, N={N}, K={K}, with_bias={with_bias}, out_dtype={out_dtype}: OK")
o = int8_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias)
o1 = torch_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias)
o2 = vllm_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias)
torch.testing.assert_close(o, o1)
torch.testing.assert_close(o, o2)
print(f"M={M}, N={N}, K={K}, with_bias={with_bias}, out_dtype={out_dtype}: OK")
def test_accuracy(self):
Ms = [1, 16, 32, 64, 128, 512, 1024, 4096, 8192]
Ns = [16, 128, 512, 1024, 4096, 8192, 16384]
Ks = [512, 1024, 4096, 8192, 16384]
bias_opts = [True, False]
out_dtypes = [torch.float16, torch.bfloat16]
for M in Ms:
for N in Ns:
for K in Ks:
for with_bias in bias_opts:
for out_dtype in out_dtypes:
self._test_accuracy_once(
M, N, K, with_bias, out_dtype, "cuda"
)
@pytest.mark.parametrize("M", [1, 16, 32, 64, 128, 512, 1024, 4096, 8192])
@pytest.mark.parametrize("N", [16, 128, 512, 1024, 4096, 8192, 16384])
@pytest.mark.parametrize("K", [512, 1024, 4096, 8192, 16384])
@pytest.mark.parametrize("with_bias", [True, False])
@pytest.mark.parametrize("out_dtype", [torch.float16, torch.bfloat16])
def test_accuracy(M, N, K, with_bias, out_dtype):
_test_accuracy_once(M, N, K, with_bias, out_dtype, "cuda")
if __name__ == "__main__":
unittest.main()
pytest.main([__file__])