diff --git a/.gitmodules b/.gitmodules index ed7603bfd..265d8d989 100644 --- a/.gitmodules +++ b/.gitmodules @@ -7,3 +7,6 @@ [submodule "sgl-kernel/3rdparty/flashinfer"] path = sgl-kernel/3rdparty/flashinfer url = https://github.com/flashinfer-ai/flashinfer.git +[submodule "sgl-kernel/3rdparty/deepgemm"] + path = sgl-kernel/3rdparty/deepgemm + url = https://github.com/deepseek-ai/DeepGEMM diff --git a/sgl-kernel/3rdparty/deepgemm b/sgl-kernel/3rdparty/deepgemm new file mode 160000 index 000000000..5e4badc57 --- /dev/null +++ b/sgl-kernel/3rdparty/deepgemm @@ -0,0 +1 @@ +Subproject commit 5e4badc5777ff1371358ec73e0aac91229181fec diff --git a/sgl-kernel/build.sh b/sgl-kernel/build.sh index ffa798d14..a9eb9c5cc 100755 --- a/sgl-kernel/build.sh +++ b/sgl-kernel/build.sh @@ -11,11 +11,11 @@ else fi docker run --rm \ - -v "$(pwd)":/sgl-kernel \ + -v $(pwd):/sgl-kernel \ pytorch/manylinux-builder:cuda${CUDA_VERSION} \ bash -c " ${PYTHON_ROOT_PATH}/bin/pip install --no-cache-dir torch==2.5.1 --index-url https://download.pytorch.org/whl/cu${CUDA_VERSION//.} && \ - ${PYTHON_ROOT_PATH}/bin/pip install --no-cache-dir ninja && \ + ${PYTHON_ROOT_PATH}/bin/pip install --no-cache-dir ninja setuptools==75.0.0 wheel==0.41.0 numpy && \ export TORCH_CUDA_ARCH_LIST='7.5 8.0 8.9 9.0+PTX' && \ export CUDA_VERSION=${CUDA_VERSION} && \ export SGL_KERNEL_ENABLE_BF16=1 && \ @@ -24,5 +24,6 @@ docker run --rm \ mkdir -p /usr/lib/x86_64-linux-gnu/ && \ ln -s /usr/local/cuda-${CUDA_VERSION}/targets/x86_64-linux/lib/stubs/libcuda.so /usr/lib/x86_64-linux-gnu/libcuda.so && \ cd /sgl-kernel && \ - ${PYTHON_ROOT_PATH}/bin/python setup.py bdist_wheel + ls -la ${PYTHON_ROOT_PATH}/lib/python${PYTHON_VERSION}/site-packages/wheel/ && \ + PYTHONPATH=${PYTHON_ROOT_PATH}/lib/python${PYTHON_VERSION}/site-packages ${PYTHON_ROOT_PATH}/bin/python setup.py bdist_wheel " diff --git a/sgl-kernel/pyproject.toml b/sgl-kernel/pyproject.toml index 712ed36cf..da3caaf5f 100644 --- a/sgl-kernel/pyproject.toml +++ b/sgl-kernel/pyproject.toml @@ -1,6 +1,6 @@ [build-system] requires = [ - "setuptools>=61.0", + "setuptools>=75.0", "scikit-build-core>=0.10", "torch==2.5.1", "wheel", diff --git a/sgl-kernel/setup.py b/sgl-kernel/setup.py index 47c1531a4..0cf88ff06 100644 --- a/sgl-kernel/setup.py +++ b/sgl-kernel/setup.py @@ -14,11 +14,13 @@ # ============================================================================== import os +import shutil import sys from pathlib import Path import torch from setuptools import find_packages, setup +from setuptools.command.build_py import build_py from torch.utils.cpp_extension import BuildExtension, CUDAExtension root = Path(__file__).parent.resolve() @@ -52,6 +54,7 @@ operator_namespace = "sgl_kernel" cutlass_default = root / "3rdparty" / "cutlass" cutlass = Path(os.environ.get("CUSTOM_CUTLASS_SRC_DIR", default=cutlass_default)) flashinfer = root / "3rdparty" / "flashinfer" +deepgemm = root / "3rdparty" / "deepgemm" include_dirs = [ root / "include", root / "csrc", @@ -63,6 +66,51 @@ include_dirs = [ "cublas", ] + +class CustomBuildPy(build_py): + def run(self): + self.copy_deepgemm_to_build_lib() + self.make_jit_include_symlinks() + build_py.run(self) + + def make_jit_include_symlinks(self): + # Make symbolic links of third-party include directories + build_include_dir = os.path.join(self.build_lib, "deep_gemm/include") + os.makedirs(build_include_dir, exist_ok=True) + + third_party_include_dirs = [ + cutlass.resolve() / "include" / "cute", + cutlass.resolve() / "include" / "cutlass", + ] + + for d in third_party_include_dirs: + dirname = str(d).split("/")[-1] + src_dir = d + dst_dir = f"{build_include_dir}/{dirname}" + assert os.path.exists(src_dir) + if os.path.exists(dst_dir): + assert os.path.islink(dst_dir) + os.unlink(dst_dir) + os.symlink(src_dir, dst_dir, target_is_directory=True) + + def copy_deepgemm_to_build_lib(self): + """ + This function copies DeepGemm to python's site-packages + """ + dst_dir = os.path.join(self.build_lib, "deep_gemm") + os.makedirs(dst_dir, exist_ok=True) + + # Copy deepgemm/deep_gemm to the build directory + src_dir = os.path.join(str(deepgemm.resolve()), "deep_gemm") + + # Remove existing directory if it exists + if os.path.exists(dst_dir): + shutil.rmtree(dst_dir) + + # Copy the directory + shutil.copytree(src_dir, dst_dir) + + nvcc_flags = [ "-DNDEBUG", f"-DOPERATOR_NAMESPACE={operator_namespace}", @@ -175,6 +223,9 @@ setup( packages=find_packages(where="python"), package_dir={"": "python"}, ext_modules=ext_modules, - cmdclass={"build_ext": BuildExtension.with_options(use_ninja=True)}, + cmdclass={ + "build_ext": BuildExtension.with_options(use_ninja=True), + "build_py": CustomBuildPy, + }, options={"bdist_wheel": {"py_limited_api": "cp39"}}, ) diff --git a/sgl-kernel/tests/test_deep_gemm.py b/sgl-kernel/tests/test_deep_gemm.py new file mode 100644 index 000000000..bb5684935 --- /dev/null +++ b/sgl-kernel/tests/test_deep_gemm.py @@ -0,0 +1,263 @@ +import os +import random +import unittest +from typing import Any, Tuple + +import deep_gemm +import torch +from deep_gemm import calc_diff, ceil_div, get_col_major_tma_aligned_tensor, jit + +""" +fork deepgemm/tests/test_core.py +""" + + +def per_token_cast_to_fp8(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + assert x.dim() == 2 and x.size(1) % 128 == 0 + m, n = x.shape + x_view = x.view(m, -1, 128) + x_amax = x_view.abs().float().amax(dim=2).view(m, -1).clamp(1e-4) + return (x_view * (448.0 / x_amax.unsqueeze(2))).to(torch.float8_e4m3fn).view( + m, n + ), (x_amax / 448.0).view(m, -1) + + +def per_block_cast_to_fp8(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + assert x.dim() == 2 + m, n = x.shape + x_padded = torch.zeros( + (ceil_div(m, 128) * 128, ceil_div(n, 128) * 128), dtype=x.dtype, device=x.device + ) + x_padded[:m, :n] = x + x_view = x_padded.view(-1, 128, x_padded.size(1) // 128, 128) + x_amax = x_view.abs().float().amax(dim=(1, 3), keepdim=True).clamp(1e-4) + x_scaled = (x_view * (448.0 / x_amax)).to(torch.float8_e4m3fn) + return x_scaled.view_as(x_padded)[:m, :n].contiguous(), (x_amax / 448.0).view( + x_view.size(0), x_view.size(2) + ) + + +def construct(m: int, k: int, n: int) -> Tuple[ + Tuple[torch.Tensor, torch.Tensor], + Tuple[torch.Tensor, torch.Tensor], + torch.Tensor, + torch.Tensor, +]: + x = torch.randn((m, k), device="cuda", dtype=torch.bfloat16) + y = torch.randn((n, k), device="cuda", dtype=torch.bfloat16) + out = torch.empty((m, n), device="cuda", dtype=torch.bfloat16) + ref_out = x @ y.t() + + x_fp8, y_fp8 = per_token_cast_to_fp8(x), per_block_cast_to_fp8(y) + # Transpose earlier so that the testing will not trigger transposing kernels + x_fp8 = (x_fp8[0], get_col_major_tma_aligned_tensor(x_fp8[1])) + return x_fp8, y_fp8, out, ref_out + + +def construct_grouped( + num_groups: int, m: int, k: int, n: int, is_masked: bool +) -> Tuple[ + Tuple[torch.Tensor, torch.Tensor], + Tuple[torch.Tensor, torch.Tensor], + torch.Tensor, + torch.Tensor, +]: + x = torch.randn((num_groups, m, k), device="cuda", dtype=torch.bfloat16) + y = torch.randn((num_groups, n, k), device="cuda", dtype=torch.bfloat16) + out = torch.empty((num_groups, m, n), device="cuda", dtype=torch.bfloat16) + ref_out = torch.einsum("gmk,gnk->gmn", x, y) + + assert m % 4 == 0, f"TMA alignment error: {m}" + x_fp8 = ( + torch.empty_like(x, dtype=torch.float8_e4m3fn), + torch.empty((num_groups, m, k // 128), device="cuda", dtype=torch.float), + ) + y_fp8 = ( + torch.empty_like(y, dtype=torch.float8_e4m3fn), + torch.empty( + (num_groups, (n + 127) // 128, k // 128), device="cuda", dtype=torch.float + ), + ) + for i in range(num_groups): + x_fp8[0][i], x_fp8[1][i] = per_token_cast_to_fp8(x[i]) + y_fp8[0][i], y_fp8[1][i] = per_block_cast_to_fp8(y[i]) + + # For non-masked input, we must merge the group and M dims + if not is_masked: + x_fp8 = (x_fp8[0].view(-1, k), per_token_cast_to_fp8(x.view(-1, k))[1]) + out, ref_out = out.view(-1, n), ref_out.view(-1, n) + + # Transpose earlier so that the testing will not trigger transposing kernels + x_fp8 = (x_fp8[0], get_col_major_tma_aligned_tensor(x_fp8[1])) + return x_fp8, y_fp8, out, ref_out + + +class TestDeepGemmCore(unittest.TestCase): + @classmethod + def setUpClass(cls): + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + torch.manual_seed(0) + random.seed(0) + + print("Library path:") + print(f" > {deep_gemm.__path__}\n") + + def test_gemm(self): + print("Testing GEMM:") + for m in (64, 128, 4096): + for k, n in [ + (7168, 2112), + (1536, 24576), + (512, 32768), + (16384, 7168), + (7168, 4096), + (2048, 7168), + ]: + x_fp8, y_fp8, out, ref_out = construct(m, k, n) + deep_gemm.gemm_fp8_fp8_bf16_nt(x_fp8, y_fp8, out) + diff = calc_diff(out, ref_out) + self.assertTrue(diff < 0.001, f"{m=}, {k=}, {n=}, {diff:.5f}") + + def test_m_grouped_gemm_contiguous(self): + print("Testing grouped contiguous GEMM:") + + for num_groups, m, k, n in ( + (4, 8192, 7168, 4096), + (4, 8192, 2048, 7168), + (8, 4096, 7168, 4096), + (8, 4096, 2048, 7168), + ): + # TODO: make a stronger test + x_fp8, y_fp8, out, ref_out = construct_grouped( + num_groups, m, k, n, is_masked=False + ) + m_indices = torch.arange(0, num_groups, device="cuda", dtype=torch.int) + m_indices = ( + m_indices.unsqueeze(-1).expand(num_groups, m).contiguous().view(-1) + ) + deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous( + x_fp8, y_fp8, out, m_indices + ) + diff = calc_diff(out, ref_out) + self.assertTrue(diff < 0.001, f"m={m * num_groups}, {k=}, {n=}, {diff:.5f}") + + def test_m_grouped_gemm_masked(self): + print("Testing grouped masked GEMM:") + + for num_groups, m in ((1, 1024), (2, 512), (4, 256)): + for k, n in ( + (7168, 4096), + (2048, 7168), + ): + # Test correctness + masked_m_candidates = list( + filter( + lambda candidate: candidate <= m, (64, 128, 192, 256, 320, 384) + ) + ) + for i in range(10): + x_fp8, y_fp8, out, ref_out = construct_grouped( + num_groups, m, k, n, is_masked=True + ) + masked_m = torch.empty( + (num_groups,), device="cuda", dtype=torch.int + ) + for j in range(num_groups): + masked_m[j] = random.choice(masked_m_candidates) + expected_m = min(int(masked_m.float().mean()) + 1, m) + deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_masked( + x_fp8, y_fp8, out, masked_m, expected_m + ) + for j in range(num_groups): + diff = calc_diff( + out[j, : masked_m[j].item()], + ref_out[j, : masked_m[j].item()], + ) + self.assertTrue( + diff < 0.001, + f"{m=}, {k=}, {n=}, {j=}, masked_m={masked_m[j]}, {num_groups=}, {diff:.5f}", + ) + + +""" +fork deepgemm/tests/test_jit.py +""" + + +class Capture: + def __init__(self) -> None: + self.read_fd = None + self.write_fd = None + self.saved_stdout = None + self.captured = None + + def __enter__(self) -> Any: + self.read_fd, self.write_fd = os.pipe() + self.saved_stdout = os.dup(1) + os.dup2(self.write_fd, 1) + return self + + def __exit__(self, exc_type, exc_val, exc_tb) -> None: + os.dup2(self.saved_stdout, 1) + os.close(self.write_fd) + with os.fdopen(self.read_fd, "r") as f: + self.captured = f.read() + + def capture(self) -> str: + return self.captured + + +class TestDeepGemmJIT(unittest.TestCase): + def test_jit(self): + # Runtime + print(f"NVCC compiler: {jit.get_nvcc_compiler()}\n") + + # Templates + print("Generated code:") + args = ( + ("lhs", torch.float8_e4m3fn), + ("rhs", torch.float8_e4m3fn), + ("scale", torch.float), + ("out", torch.bfloat16), + ("enable_double_streams", bool), + ("stream", torch.cuda.Stream), + ) + body = "\n" + body += "std::cout << reinterpret_cast(lhs) << std::endl;\n" + body += "std::cout << reinterpret_cast(rhs) << std::endl;\n" + body += "std::cout << reinterpret_cast(scale) << std::endl;\n" + body += "std::cout << reinterpret_cast(out) << std::endl;\n" + body += "std::cout << enable_double_streams << std::endl;\n" + body += "std::cout << reinterpret_cast(stream) << std::endl;\n" + code = jit.generate((), args, body) + print(code) + + # Build + print("Building ...") + func = jit.build("test_func", args, code) + + # Test correctness + print("Running ...") + fp8_tensor = torch.empty((1,), dtype=torch.float8_e4m3fn, device="cuda") + fp32_tensor = torch.empty((1,), dtype=torch.float, device="cuda") + bf16_tensor = torch.empty((1,), dtype=torch.bfloat16, device="cuda") + with Capture() as capture: + self.assertTrue( + func( + fp8_tensor, + fp8_tensor, + fp32_tensor, + bf16_tensor, + True, + torch.cuda.current_stream(), + ) + == 0 + ) + output = capture.capture() + ref_output = f"{fp8_tensor.data_ptr()}\n{fp8_tensor.data_ptr()}\n{fp32_tensor.data_ptr()}\n{bf16_tensor.data_ptr()}\n1\n{torch.cuda.current_stream().cuda_stream}\n" + self.assertTrue(output == ref_output, f"{output=}, {ref_output=}") + + +if __name__ == "__main__": + unittest.main()