[Feature] use pytest for sgl-kernel (#4896)
This commit is contained in:
committed by
GitHub
parent
4ede6770cd
commit
9fccda3111
@@ -1,12 +1,13 @@
|
||||
import unittest
|
||||
import os
|
||||
import random
|
||||
from typing import Optional, Type
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
from sgl_kernel import fp8_blockwise_scaled_mm
|
||||
|
||||
|
||||
def cdiv(a: int, b: int) -> int:
    """Return ``a`` divided by ``b``, rounded up to the nearest integer.

    The computation stays in exact integer arithmetic (no float division).
    Raises ``ZeroDivisionError`` when ``b`` is 0.
    """
    quotient, remainder = divmod(a, b)
    # A non-zero remainder means the floored quotient is one short of the ceiling.
    return quotient + (1 if remainder else 0)
|
||||
|
||||
|
||||
@@ -23,7 +24,6 @@ def baseline_scaled_mm(
|
||||
out_dtype: Type[torch.dtype],
|
||||
bias: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
|
||||
# We treat N-dimensional group scaling as extended numpy-style broadcasting
|
||||
# in numpy simply stretches dimensions with an extent of 1 to match the
|
||||
# target shape by repeating the data along that dimension (broadcasting)
|
||||
@@ -51,62 +51,44 @@ def baseline_scaled_mm(
|
||||
|
||||
scale_a = group_broadcast(scale_a, a.shape)
|
||||
scale_b = group_broadcast(scale_b, b.shape)
|
||||
|
||||
output = torch.mm(
|
||||
(scale_a * a.to(dtype=torch.float32)), (scale_b * b.to(dtype=torch.float32))
|
||||
).to(out_dtype)
|
||||
|
||||
if bias is not None:
|
||||
output = output + bias
|
||||
|
||||
return output
|
||||
|
||||
|
||||
class TestFp8Gemm(unittest.TestCase):
|
||||
def _test_accuracy_once(self, M, N, K, out_dtype, device):
|
||||
fp8_info = torch.finfo(torch.float8_e4m3fn)
|
||||
fp8_max, fp8_min = fp8_info.max, fp8_info.min
|
||||
def _test_accuracy_once(M, N, K, out_dtype, device):
    """Compare ``fp8_blockwise_scaled_mm`` with ``baseline_scaled_mm`` once.

    Builds a random (M, K) FP8 matrix and a random (N, K) FP8 matrix that is
    transposed to (K, N), draws random float32 scales for both, feeds the
    same inputs to the reference and to the kernel, and asserts that the two
    results are close.

    Args:
        M: Number of rows of the left operand.
        N: Number of rows of the right operand before it is transposed.
        K: Shared inner dimension of both operands.
        out_dtype: Output dtype forwarded to both implementations.
        device: Device on which every tensor is allocated.

    Raises:
        AssertionError: If ``torch.testing.assert_close`` finds a mismatch
            beyond ``rtol=0.02`` / ``atol=1``.
    """
    fp8_dtype = torch.float8_e4m3fn
    limits = torch.finfo(fp8_dtype)

    def random_fp8(rows, cols):
        # Uniform samples scaled to [-fp8 max, fp8 max), clamped to the
        # representable range before the cast to FP8.
        values = torch.rand(rows, cols, dtype=torch.float32, device=device)
        values = (values - 0.5) * 2 * limits.max
        return values.clamp(min=limits.min, max=limits.max).to(fp8_dtype)

    def random_scale(shape):
        scale = torch.randn(shape, device=device, dtype=torch.float32) * 0.001
        # Same shape, but stored so that the transposed tensor is contiguous.
        # NOTE(review): presumably the layout the kernel expects - confirm.
        return scale.t().contiguous().t()

    # Operands are drawn in the order a, then b, to keep the RNG stream stable.
    lhs_fp8 = random_fp8(M, K)
    rhs_fp8 = random_fp8(N, K).t()

    # Group shapes handed to scale_shape (defined elsewhere in this file).
    lhs_scale_shape = scale_shape(lhs_fp8.shape, (1, 128))
    rhs_scale_shape = scale_shape(rhs_fp8.shape, (128, 128))

    lhs_scale = random_scale(lhs_scale_shape)
    rhs_scale = random_scale(rhs_scale_shape)

    expected = baseline_scaled_mm(lhs_fp8, rhs_fp8, lhs_scale, rhs_scale, out_dtype)
    actual = fp8_blockwise_scaled_mm(
        lhs_fp8, rhs_fp8, lhs_scale, rhs_scale, out_dtype
    )

    torch.testing.assert_close(expected, actual, rtol=0.02, atol=1)
    print(f"M={M}, N={N}, K={K}, out_dtype={out_dtype}: OK")
|
||||
|
||||
a_fp32 = (
|
||||
(torch.rand(M, K, dtype=torch.float32, device=device) - 0.5) * 2 * fp8_max
|
||||
)
|
||||
a_fp8 = a_fp32.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn)
|
||||
|
||||
b_fp32 = (
|
||||
(torch.rand(N, K, dtype=torch.float32, device=device) - 0.5) * 2 * fp8_max
|
||||
)
|
||||
b_fp8 = b_fp32.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn).t()
|
||||
|
||||
scale_a_group_shape = (1, 128)
|
||||
scale_b_group_shape = (128, 128)
|
||||
scale_a_shape = scale_shape(a_fp8.shape, scale_a_group_shape)
|
||||
scale_b_shape = scale_shape(b_fp8.shape, scale_b_group_shape)
|
||||
|
||||
scale_a = torch.randn(scale_a_shape, device=device, dtype=torch.float32) * 0.001
|
||||
scale_b = torch.randn(scale_b_shape, device=device, dtype=torch.float32) * 0.001
|
||||
scale_a = scale_a.t().contiguous().t()
|
||||
scale_b = scale_b.t().contiguous().t()
|
||||
|
||||
o1 = torch.empty((M, N), device=device, dtype=torch.bfloat16)
|
||||
o = baseline_scaled_mm(a_fp8, b_fp8, scale_a, scale_b, out_dtype)
|
||||
o1 = fp8_blockwise_scaled_mm(a_fp8, b_fp8, scale_a, scale_b, out_dtype)
|
||||
|
||||
rtol = 0.02
|
||||
atol = 1
|
||||
torch.testing.assert_close(o, o1, rtol=rtol, atol=atol)
|
||||
print(f"M={M}, N={N}, K={K}, out_dtype={out_dtype}: OK")
|
||||
|
||||
def test_accuracy(self):
|
||||
Ms = [1, 128, 512, 1024, 4096]
|
||||
Ns = [128, 512, 1024, 4096]
|
||||
Ks = [512, 1024, 4096, 8192, 16384]
|
||||
out_dtypes = [torch.bfloat16, torch.float16]
|
||||
for M in Ms:
|
||||
for N in Ns:
|
||||
for K in Ks:
|
||||
for out_dtype in out_dtypes:
|
||||
self._test_accuracy_once(M, N, K, out_dtype, "cuda")
|
||||
@pytest.mark.parametrize("M", [1, 128, 512, 1024, 4096])
@pytest.mark.parametrize("N", [128, 512, 1024, 4096])
@pytest.mark.parametrize("K", [512, 1024, 4096, 8192, 16384])
@pytest.mark.parametrize("out_dtype", [torch.bfloat16, torch.float16])
def test_accuracy(M, N, K, out_dtype):
    """Run the accuracy check on CUDA for one (M, N, K, out_dtype) combination.

    pytest generates one test case per element of the cross product of the
    four parameter lists above.
    """
    _test_accuracy_once(M=M, N=N, K=K, out_dtype=out_dtype, device="cuda")
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # The tests in this file are plain pytest functions, so pytest is the only
    # runner. A preceding unittest.main() would call sys.exit() before pytest
    # ever started and would not discover these tests anyway.
    # Propagate pytest's exit code so a direct run fails visibly on test failure.
    raise SystemExit(pytest.main([__file__]))
|
||||
|
||||
Reference in New Issue
Block a user