[Feature] Apply Cublas Grouped Gemm kernel (#3629)
This commit is contained in:
49
sgl-kernel/tests/test_cublas_grouped_gemm.py
Normal file
49
sgl-kernel/tests/test_cublas_grouped_gemm.py
Normal file
@@ -0,0 +1,49 @@
|
||||
import unittest
|
||||
|
||||
import torch
|
||||
from sgl_kernel import cublas_grouped_gemm
|
||||
|
||||
|
||||
def torch_grouped_gemm(a_array, b_array, out_dtype):
    """Reference grouped GEMM computed with plain torch.matmul.

    For each group i, computes ``a_array[i] @ b_array[i].T`` and casts the
    result to ``out_dtype``.  Serves as the ground truth that the cuBLAS
    kernel is compared against.

    Args:
        a_array: list of left-hand tensors, one per group (shape (M, K)).
        b_array: list of right-hand tensors, one per group (shape (N, K));
            each is transposed before the multiply.
        out_dtype: torch dtype the per-group results are cast to.

    Returns:
        List of per-group result tensors (shape (M, N)), same length as
        the inputs.
    """
    return [
        torch.matmul(lhs, rhs.t()).to(out_dtype)
        for lhs, rhs in zip(a_array, b_array)
    ]
|
||||
|
||||
|
||||
class TestGroupedGemm(unittest.TestCase):
    """Accuracy tests for the cuBLAS grouped GEMM kernel.

    Each test builds a batch of independently shaped (M, K) x (K, N)
    problems, runs them through ``cublas_grouped_gemm``, and checks the
    outputs against a ``torch.matmul`` reference.
    """

    def _test_accuracy(self, Ms, Ns, Ks, out_dtype):
        """Run one grouped GEMM over the given (M, N, K) triples.

        Allocates random inputs on the GPU, computes both the torch
        reference and the cuBLAS kernel result, and asserts per-group
        closeness.
        """
        lhs_groups = []
        rhs_groups = []
        cublas_out = []
        for M, N, K in zip(Ms, Ns, Ks):
            # Scale by 5 so values exceed the unit range and exercise
            # rounding in the reduced-precision output dtypes.
            lhs_groups.append(torch.randn((M, K), device="cuda", dtype=out_dtype) * 5)
            rhs_groups.append(torch.randn((N, K), device="cuda", dtype=out_dtype) * 5)
            cublas_out.append(torch.empty((M, N), device="cuda", dtype=out_dtype))

        reference_out = torch_grouped_gemm(lhs_groups, rhs_groups, out_dtype)
        # Kernel writes its results into the preallocated cublas_out tensors.
        cublas_grouped_gemm(lhs_groups, rhs_groups, cublas_out, out_dtype)

        for i, (M, N, K) in enumerate(zip(Ms, Ns, Ks)):
            torch.testing.assert_close(reference_out[i], cublas_out[i])
            print(f"M={M}, N={N}, K={K}, out_dtype={out_dtype}: OK")

    def test_accuracy(self):
        """Sweep small-to-large group shapes for each supported output dtype."""
        Ms = [1, 16, 32, 256, 1024]
        Ns = [2, 16, 128, 256, 4096]
        Ks = [3, 16, 32, 512, 8192]
        for out_dtype in (torch.float16, torch.bfloat16):
            self._test_accuracy(Ms, Ns, Ks, out_dtype)
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Only attempt to run on a CUDA build; the grouped GEMM kernel is
    # gated on CUDA >= 12.5 (presumably the minimum for the underlying
    # cuBLAS grouped-batched API — confirm against the kernel's docs).
    if torch.cuda.is_available():
        cuda_version = tuple(int(part) for part in torch.version.cuda.split("."))
        if cuda_version >= (12, 5):
            unittest.main()
        else:
            print(f"Cuda version {cuda_version} lower than 12.5, not executing tests.")
|
||||
Reference in New Issue
Block a user