[Feature] use pytest for sgl-kernel (#4896)

This commit is contained in:
Adarsh Shirawalmath
2025-03-30 23:06:52 +05:30
committed by GitHub
parent 4ede6770cd
commit 9fccda3111
10 changed files with 263 additions and 290 deletions

View File

@@ -1,49 +1,40 @@
import unittest
import pytest
import torch
from sgl_kernel import cublas_grouped_gemm
def torch_grouped_gemm(a_array, b_array, out_dtype):
    """Reference grouped GEMM implemented with plain torch.matmul.

    For each pair (a, b) computes ``a @ b.T`` and casts the product to
    ``out_dtype``.  Serves as the ground truth the cuBLAS kernel is
    compared against.

    Args:
        a_array: list of (M_i, K_i) tensors.
        b_array: list of (N_i, K_i) tensors; each b is transposed before
            the multiply, matching the cuBLAS grouped-GEMM convention.
        out_dtype: torch dtype the results are cast to.

    Returns:
        List of (M_i, N_i) tensors, one per input pair.  Empty inputs
        yield an empty list.
    """
    # NOTE: the mangled diff carried both a loop-and-append body and this
    # comprehension; they are equivalent, keep the idiomatic one.
    return [torch.matmul(a, b.t()).to(out_dtype) for a, b in zip(a_array, b_array)]
# Module-level gate evaluated once at import time: the parametrized tests
# below are skipped when no CUDA device is present or when the CUDA
# toolkit PyTorch was built against is older than 12.5 (the minimum
# version required by the cublas_grouped_gemm kernel, per the skipif
# reason string).  ``torch.version.cuda`` is None on CPU-only builds, so
# it must be checked before splitting the version string.
skip_condition = not torch.cuda.is_available() or (
    torch.version.cuda is None
    or tuple(map(int, torch.version.cuda.split("."))) < (12, 5)
)
@pytest.mark.skipif(
    skip_condition, reason="CUDA not available or CUDA version lower than 12.5"
)
@pytest.mark.parametrize("out_dtype", [torch.float16, torch.bfloat16])
@pytest.mark.parametrize("M", [1, 16, 32, 256, 1024])
@pytest.mark.parametrize("N", [2, 16, 128, 256, 4096])
@pytest.mark.parametrize("K", [3, 16, 32, 512, 8192])
def test_grouped_gemm_accuracy(out_dtype, M, N, K):
    """Check cublas_grouped_gemm against a torch.matmul reference.

    For one (M, N, K, dtype) combination, builds a single-group problem
    (a: (M, K), b: (N, K)) on the GPU, runs both the reference
    ``torch_grouped_gemm`` and the cuBLAS kernel, and asserts the results
    match ``a @ b.T`` within torch.testing's default tolerances for the
    dtype.
    """
    # Scale randn by 5 so accumulated products exercise a wider value
    # range than unit-variance inputs would.
    a = torch.randn((M, K), device="cuda", dtype=out_dtype) * 5
    b = torch.randn((N, K), device="cuda", dtype=out_dtype) * 5
    expected = torch.matmul(a, b.t()).to(out_dtype)

    # Single-group case; the kernel writes its output into c_array in place.
    a_array = [a]
    b_array = [b]
    c_array = [torch.empty((M, N), device="cuda", dtype=out_dtype)]

    result_torch = torch_grouped_gemm(a_array, b_array, out_dtype)[0]
    cublas_grouped_gemm(a_array, b_array, c_array, out_dtype)

    torch.testing.assert_close(result_torch, expected)
    torch.testing.assert_close(c_array[0], expected)
if __name__ == "__main__":
    # The manual CUDA-availability / version gating that the old unittest
    # runner performed here is now handled declaratively by the
    # @pytest.mark.skipif decorator via ``skip_condition``, so the entry
    # point just hands this file to pytest.
    pytest.main([__file__])