diff --git a/sgl-kernel/tests/test_deep_gemm.py b/sgl-kernel/tests/test_deep_gemm.py deleted file mode 100644 index bb5684935..000000000 --- a/sgl-kernel/tests/test_deep_gemm.py +++ /dev/null @@ -1,263 +0,0 @@ -import os -import random -import unittest -from typing import Any, Tuple - -import deep_gemm -import torch -from deep_gemm import calc_diff, ceil_div, get_col_major_tma_aligned_tensor, jit - -""" -fork deepgemm/tests/test_core.py -""" - - -def per_token_cast_to_fp8(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: - assert x.dim() == 2 and x.size(1) % 128 == 0 - m, n = x.shape - x_view = x.view(m, -1, 128) - x_amax = x_view.abs().float().amax(dim=2).view(m, -1).clamp(1e-4) - return (x_view * (448.0 / x_amax.unsqueeze(2))).to(torch.float8_e4m3fn).view( - m, n - ), (x_amax / 448.0).view(m, -1) - - -def per_block_cast_to_fp8(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: - assert x.dim() == 2 - m, n = x.shape - x_padded = torch.zeros( - (ceil_div(m, 128) * 128, ceil_div(n, 128) * 128), dtype=x.dtype, device=x.device - ) - x_padded[:m, :n] = x - x_view = x_padded.view(-1, 128, x_padded.size(1) // 128, 128) - x_amax = x_view.abs().float().amax(dim=(1, 3), keepdim=True).clamp(1e-4) - x_scaled = (x_view * (448.0 / x_amax)).to(torch.float8_e4m3fn) - return x_scaled.view_as(x_padded)[:m, :n].contiguous(), (x_amax / 448.0).view( - x_view.size(0), x_view.size(2) - ) - - -def construct(m: int, k: int, n: int) -> Tuple[ - Tuple[torch.Tensor, torch.Tensor], - Tuple[torch.Tensor, torch.Tensor], - torch.Tensor, - torch.Tensor, -]: - x = torch.randn((m, k), device="cuda", dtype=torch.bfloat16) - y = torch.randn((n, k), device="cuda", dtype=torch.bfloat16) - out = torch.empty((m, n), device="cuda", dtype=torch.bfloat16) - ref_out = x @ y.t() - - x_fp8, y_fp8 = per_token_cast_to_fp8(x), per_block_cast_to_fp8(y) - # Transpose earlier so that the testing will not trigger transposing kernels - x_fp8 = (x_fp8[0], get_col_major_tma_aligned_tensor(x_fp8[1])) - return x_fp8, y_fp8, out, ref_out - - -def construct_grouped( - num_groups: int, m: int, k: int, n: int, is_masked: bool -) -> Tuple[ - Tuple[torch.Tensor, torch.Tensor], - Tuple[torch.Tensor, torch.Tensor], - torch.Tensor, - torch.Tensor, -]: - x = torch.randn((num_groups, m, k), device="cuda", dtype=torch.bfloat16) - y = torch.randn((num_groups, n, k), device="cuda", dtype=torch.bfloat16) - out = torch.empty((num_groups, m, n), device="cuda", dtype=torch.bfloat16) - ref_out = torch.einsum("gmk,gnk->gmn", x, y) - - assert m % 4 == 0, f"TMA alignment error: {m}" - x_fp8 = ( - torch.empty_like(x, dtype=torch.float8_e4m3fn), - torch.empty((num_groups, m, k // 128), device="cuda", dtype=torch.float), - ) - y_fp8 = ( - torch.empty_like(y, dtype=torch.float8_e4m3fn), - torch.empty( - (num_groups, (n + 127) // 128, k // 128), device="cuda", dtype=torch.float - ), - ) - for i in range(num_groups): - x_fp8[0][i], x_fp8[1][i] = per_token_cast_to_fp8(x[i]) - y_fp8[0][i], y_fp8[1][i] = per_block_cast_to_fp8(y[i]) - - # For non-masked input, we must merge the group and M dims - if not is_masked: - x_fp8 = (x_fp8[0].view(-1, k), per_token_cast_to_fp8(x.view(-1, k))[1]) - out, ref_out = out.view(-1, n), ref_out.view(-1, n) - - # Transpose earlier so that the testing will not trigger transposing kernels - x_fp8 = (x_fp8[0], get_col_major_tma_aligned_tensor(x_fp8[1])) - return x_fp8, y_fp8, out, ref_out - - -class TestDeepGemmCore(unittest.TestCase): - @classmethod - def setUpClass(cls): - torch.backends.cuda.matmul.allow_tf32 = True - torch.backends.cudnn.allow_tf32 = True - torch.manual_seed(0) - random.seed(0) - - print("Library path:") - print(f" > {deep_gemm.__path__}\n") - - def test_gemm(self): - print("Testing GEMM:") - for m in (64, 128, 4096): - for k, n in [ - (7168, 2112), - (1536, 24576), - (512, 32768), - (16384, 7168), - (7168, 4096), - (2048, 7168), - ]: - x_fp8, y_fp8, out, ref_out = construct(m, k, n) - deep_gemm.gemm_fp8_fp8_bf16_nt(x_fp8, y_fp8, out) - diff = calc_diff(out, ref_out) - self.assertTrue(diff < 0.001, f"{m=}, {k=}, {n=}, {diff:.5f}") - - def test_m_grouped_gemm_contiguous(self): - print("Testing grouped contiguous GEMM:") - - for num_groups, m, k, n in ( - (4, 8192, 7168, 4096), - (4, 8192, 2048, 7168), - (8, 4096, 7168, 4096), - (8, 4096, 2048, 7168), - ): - # TODO: make a stronger test - x_fp8, y_fp8, out, ref_out = construct_grouped( - num_groups, m, k, n, is_masked=False - ) - m_indices = torch.arange(0, num_groups, device="cuda", dtype=torch.int) - m_indices = ( - m_indices.unsqueeze(-1).expand(num_groups, m).contiguous().view(-1) - ) - deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous( - x_fp8, y_fp8, out, m_indices - ) - diff = calc_diff(out, ref_out) - self.assertTrue(diff < 0.001, f"m={m * num_groups}, {k=}, {n=}, {diff:.5f}") - - def test_m_grouped_gemm_masked(self): - print("Testing grouped masked GEMM:") - - for num_groups, m in ((1, 1024), (2, 512), (4, 256)): - for k, n in ( - (7168, 4096), - (2048, 7168), - ): - # Test correctness - masked_m_candidates = list( - filter( - lambda candidate: candidate <= m, (64, 128, 192, 256, 320, 384) - ) - ) - for i in range(10): - x_fp8, y_fp8, out, ref_out = construct_grouped( - num_groups, m, k, n, is_masked=True - ) - masked_m = torch.empty( - (num_groups,), device="cuda", dtype=torch.int - ) - for j in range(num_groups): - masked_m[j] = random.choice(masked_m_candidates) - expected_m = min(int(masked_m.float().mean()) + 1, m) - deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_masked( - x_fp8, y_fp8, out, masked_m, expected_m - ) - for j in range(num_groups): - diff = calc_diff( - out[j, : masked_m[j].item()], - ref_out[j, : masked_m[j].item()], - ) - self.assertTrue( - diff < 0.001, - f"{m=}, {k=}, {n=}, {j=}, masked_m={masked_m[j]}, {num_groups=}, {diff:.5f}", - ) - - -""" -fork deepgemm/tests/test_jit.py -""" - - -class Capture: - def __init__(self) -> None: - self.read_fd = None - self.write_fd = None - self.saved_stdout = None - self.captured = None - - def __enter__(self) -> Any: - self.read_fd, self.write_fd = os.pipe() - self.saved_stdout = os.dup(1) - os.dup2(self.write_fd, 1) - return self - - def __exit__(self, exc_type, exc_val, exc_tb) -> None: - os.dup2(self.saved_stdout, 1) - os.close(self.write_fd) - with os.fdopen(self.read_fd, "r") as f: - self.captured = f.read() - - def capture(self) -> str: - return self.captured - - -class TestDeepGemmJIT(unittest.TestCase): - def test_jit(self): - # Runtime - print(f"NVCC compiler: {jit.get_nvcc_compiler()}\n") - - # Templates - print("Generated code:") - args = ( - ("lhs", torch.float8_e4m3fn), - ("rhs", torch.float8_e4m3fn), - ("scale", torch.float), - ("out", torch.bfloat16), - ("enable_double_streams", bool), - ("stream", torch.cuda.Stream), - ) - body = "\n" - body += "std::cout << reinterpret_cast(lhs) << std::endl;\n" - body += "std::cout << reinterpret_cast(rhs) << std::endl;\n" - body += "std::cout << reinterpret_cast(scale) << std::endl;\n" - body += "std::cout << reinterpret_cast(out) << std::endl;\n" - body += "std::cout << enable_double_streams << std::endl;\n" - body += "std::cout << reinterpret_cast(stream) << std::endl;\n" - code = jit.generate((), args, body) - print(code) - - # Build - print("Building ...") - func = jit.build("test_func", args, code) - - # Test correctness - print("Running ...") - fp8_tensor = torch.empty((1,), dtype=torch.float8_e4m3fn, device="cuda") - fp32_tensor = torch.empty((1,), dtype=torch.float, device="cuda") - bf16_tensor = torch.empty((1,), dtype=torch.bfloat16, device="cuda") - with Capture() as capture: - self.assertTrue( - func( - fp8_tensor, - fp8_tensor, - fp32_tensor, - bf16_tensor, - True, - torch.cuda.current_stream(), - ) - == 0 - ) - output = capture.capture() - ref_output = f"{fp8_tensor.data_ptr()}\n{fp8_tensor.data_ptr()}\n{fp32_tensor.data_ptr()}\n{bf16_tensor.data_ptr()}\n1\n{torch.cuda.current_stream().cuda_stream}\n" - self.assertTrue(output == ref_output, f"{output=}, {ref_output=}") - - -if __name__ == "__main__": - unittest.main()