[Feature] use pytest for sgl-kernel (#4896)

This commit is contained in:
Adarsh Shirawalmath
2025-03-30 23:06:52 +05:30
committed by GitHub
parent 4ede6770cd
commit 9fccda3111
10 changed files with 263 additions and 290 deletions

View File

@@ -1,12 +1,13 @@
import unittest
import os
import random
from typing import Optional, Type
import pytest
import torch
from sgl_kernel import fp8_blockwise_scaled_mm
def cdiv(a: int, b: int) -> int:
    """Ceiling division: smallest integer >= a / b (floor-division semantics)."""
    quotient, remainder = divmod(a, b)
    # Round toward +infinity: bump the floored quotient whenever there is a remainder.
    return quotient + 1 if remainder else quotient
@@ -23,7 +24,6 @@ def baseline_scaled_mm(
out_dtype: Type[torch.dtype],
bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
# We treat N-dimensional group scaling as extended numpy-style broadcasting
# in numpy simply stretches dimensions with an extent of 1 to match the
# the target shape by repeating the data along that dimension (broadcasting)
@@ -51,62 +51,44 @@ def baseline_scaled_mm(
scale_a = group_broadcast(scale_a, a.shape)
scale_b = group_broadcast(scale_b, b.shape)
output = torch.mm(
(scale_a * a.to(dtype=torch.float32)), (scale_b * b.to(dtype=torch.float32))
).to(out_dtype)
if bias is not None:
output = output + bias
return output
def _test_accuracy_once(M, N, K, out_dtype, device):
    """Compare fp8_blockwise_scaled_mm against a float32 reference for one shape.

    Builds random fp8 operands a (M, K) and b (K, N) with blockwise scales,
    computes the reference product via baseline_scaled_mm, and asserts the
    kernel output matches within loose tolerances (fp8 accumulation error).

    Args:
        M, N, K: GEMM dimensions.
        out_dtype: output dtype (e.g. torch.bfloat16 or torch.float16).
        device: torch device string; expected to be a CUDA device.
    """
    fp8_info = torch.finfo(torch.float8_e4m3fn)
    fp8_max, fp8_min = fp8_info.max, fp8_info.min

    # Random fp32 data spanning the full fp8 range, then clamped and cast to fp8.
    a_fp32 = (
        (torch.rand(M, K, dtype=torch.float32, device=device) - 0.5) * 2 * fp8_max
    )
    a_fp8 = a_fp32.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn)
    b_fp32 = (
        (torch.rand(N, K, dtype=torch.float32, device=device) - 0.5) * 2 * fp8_max
    )
    # b is generated (N, K) and transposed so the GEMM sees a (K, N) operand.
    b_fp8 = b_fp32.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn).t()

    # Blockwise scaling: per-(1, 128) groups for a, per-(128, 128) tiles for b.
    scale_a_group_shape = (1, 128)
    scale_b_group_shape = (128, 128)
    scale_a_shape = scale_shape(a_fp8.shape, scale_a_group_shape)
    scale_b_shape = scale_shape(b_fp8.shape, scale_b_group_shape)

    # Small scale magnitudes keep the products within out_dtype range.
    scale_a = torch.randn(scale_a_shape, device=device, dtype=torch.float32) * 0.001
    scale_b = torch.randn(scale_b_shape, device=device, dtype=torch.float32) * 0.001
    # t().contiguous().t() forces a column-major layout — presumably required
    # by the kernel's expected scale layout; TODO confirm against the kernel API.
    scale_a = scale_a.t().contiguous().t()
    scale_b = scale_b.t().contiguous().t()

    o = baseline_scaled_mm(a_fp8, b_fp8, scale_a, scale_b, out_dtype)
    o1 = fp8_blockwise_scaled_mm(a_fp8, b_fp8, scale_a, scale_b, out_dtype)

    # Loose tolerances: fp8 inputs accumulate noticeable rounding error.
    rtol = 0.02
    atol = 1
    torch.testing.assert_close(o, o1, rtol=rtol, atol=atol)
    print(f"M={M}, N={N}, K={K}, out_dtype={out_dtype}: OK")
# Sweep the full (M, N, K, out_dtype) grid; pytest expands the product of the
# parametrize decorators into one test case per combination.
@pytest.mark.parametrize("M", [1, 128, 512, 1024, 4096])
@pytest.mark.parametrize("N", [128, 512, 1024, 4096])
@pytest.mark.parametrize("K", [512, 1024, 4096, 8192, 16384])
@pytest.mark.parametrize("out_dtype", [torch.bfloat16, torch.float16])
def test_accuracy(M, N, K, out_dtype):
    """Check fp8 blockwise GEMM accuracy for one (M, N, K, out_dtype) point on CUDA."""
    _test_accuracy_once(M, N, K, out_dtype, "cuda")
if __name__ == "__main__":
    # Run this file's tests directly through pytest when executed as a script.
    pytest.main([__file__])