From bf669606eb84e12dc1ecf15b23c1eedab204d660 Mon Sep 17 00:00:00 2001
From: Yineng Zhang
Date: Thu, 23 Jan 2025 00:39:38 +0800
Subject: [PATCH] feat: integrate bmm_fp8 kernel into sgl-kernel (#3056)

---
 sgl-kernel/setup.py                           | 12 +++-
 sgl-kernel/src/sgl-kernel/__init__.py         |  2 +
 .../src/sgl-kernel/csrc/sgl_kernel_ops.cu     |  6 ++
 sgl-kernel/src/sgl-kernel/ops/__init__.py     | 61 +++++++++++++++----
 sgl-kernel/src/sgl-kernel/ops/utils.py        | 19 ++++++
 sgl-kernel/tests/test_bmm_fp8.py              | 43 +++++++++++++
 6 files changed, 131 insertions(+), 12 deletions(-)
 create mode 100644 sgl-kernel/src/sgl-kernel/ops/utils.py
 create mode 100644 sgl-kernel/tests/test_bmm_fp8.py

diff --git a/sgl-kernel/setup.py b/sgl-kernel/setup.py
index b9324c355..81cd96e99 100644
--- a/sgl-kernel/setup.py
+++ b/sgl-kernel/setup.py
@@ -62,12 +62,22 @@ nvcc_flags = [
     "-std=c++17",
     "-use_fast_math",
     "-DFLASHINFER_ENABLE_F16",
-    "-DFLASHINFER_ENABLE_BF16",
 ]
 
 if cuda_version >= (12, 0) and sm_version >= 90:
     nvcc_flags.append("-gencode=arch=compute_90a,code=sm_90a")
 
+if sm_version >= 90:
+    nvcc_flags.extend(
+        [
+            "-DFLASHINFER_ENABLE_FP8",
+            "-DFLASHINFER_ENABLE_FP8_E4M3",
+            "-DFLASHINFER_ENABLE_FP8_E5M2",
+        ]
+    )
+if sm_version >= 80:
+    nvcc_flags.append("-DFLASHINFER_ENABLE_BF16")
+
 for flag in [
     "-D__CUDA_NO_HALF_OPERATORS__",
     "-D__CUDA_NO_HALF_CONVERSIONS__",
diff --git a/sgl-kernel/src/sgl-kernel/__init__.py b/sgl-kernel/src/sgl-kernel/__init__.py
index 0bcd77aad..86c4f34d3 100644
--- a/sgl-kernel/src/sgl-kernel/__init__.py
+++ b/sgl-kernel/src/sgl-kernel/__init__.py
@@ -1,4 +1,5 @@
 from sgl_kernel.ops import (
+    bmm_fp8,
     custom_dispose,
     custom_reduce,
     fused_add_rmsnorm,
@@ -18,6 +19,7 @@ from sgl_kernel.ops import (
 )
 
 __all__ = [
+    "bmm_fp8",
     "custom_dispose",
     "custom_reduce",
     "fused_add_rmsnorm",
diff --git a/sgl-kernel/src/sgl-kernel/csrc/sgl_kernel_ops.cu b/sgl-kernel/src/sgl-kernel/csrc/sgl_kernel_ops.cu
index 985cfa173..12df07471 100644
--- a/sgl-kernel/src/sgl-kernel/csrc/sgl_kernel_ops.cu
+++ b/sgl-kernel/src/sgl-kernel/csrc/sgl_kernel_ops.cu
@@ -52,6 +52,10 @@ void gelu_tanh_and_mul(at::Tensor& out, at::Tensor& input, int64_t cuda_stream);
 // gelu and mul
 void gelu_and_mul(at::Tensor& out, at::Tensor& input, int64_t cuda_stream);
 
+// bmm fp8
+void bmm_fp8(at::Tensor A, at::Tensor B, at::Tensor D, at::Tensor A_scale, at::Tensor B_scale,
+             at::Tensor workspace_buffer, int64_t cublas_handle, int64_t cuda_stream);
+
 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
   // trt_reduce
   m.def("init_custom_ar", &init_custom_ar, "init custom allreduce meta (CUDA)");
@@ -81,4 +85,6 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
   m.def("gelu_tanh_and_mul", &gelu_tanh_and_mul, "Gelu Tanh and Mul (CUDA)");
   // gelu and mul
   m.def("gelu_and_mul", &gelu_and_mul, "Gelu and Mul (CUDA)");
+  // bmm fp8
+  m.def("bmm_fp8", &bmm_fp8, "BMM FP8 (CUDA)");
 }
diff --git a/sgl-kernel/src/sgl-kernel/ops/__init__.py b/sgl-kernel/src/sgl-kernel/ops/__init__.py
index 5bfde5df2..cea3436b6 100644
--- a/sgl-kernel/src/sgl-kernel/ops/__init__.py
+++ b/sgl-kernel/src/sgl-kernel/ops/__init__.py
@@ -2,6 +2,7 @@ from typing import Optional
 
 import torch
 from sgl_kernel.ops._kernels import all_reduce as _all_reduce
+from sgl_kernel.ops._kernels import bmm_fp8 as _bmm_fp8
 from sgl_kernel.ops._kernels import dispose as _dispose
 from sgl_kernel.ops._kernels import fused_add_rmsnorm as _fused_add_rmsnorm
 from sgl_kernel.ops._kernels import gelu_and_mul as _gelu_and_mul
@@ -21,10 +22,7 @@ from sgl_kernel.ops._kernels import (
     sampling_scaling_penalties as _sampling_scaling_penalties,
 )
 from sgl_kernel.ops._kernels import silu_and_mul as _silu_and_mul
-
-
-def get_cuda_stream(device: torch.device) -> int:
-    return torch.cuda.current_stream(device).cuda_stream
+from sgl_kernel.ops.utils import _get_cache_buf, _get_cuda_stream
 
 
 def init_custom_reduce(
@@ -101,7 +99,7 @@ def rmsnorm(
     with input.device as device:
         if out is None:
             out = torch.empty_like(input)
-        _rmsnorm(out, input, weight, eps, get_cuda_stream(device))
+        _rmsnorm(out, input, weight, eps, _get_cuda_stream(device))
     return out
 
 
@@ -109,7 +107,7 @@ def fused_add_rmsnorm(
     input: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6
 ) -> None:
     with input.device as device:
-        _fused_add_rmsnorm(input, residual, weight, eps, get_cuda_stream(device))
+        _fused_add_rmsnorm(input, residual, weight, eps, _get_cuda_stream(device))
 
 
 def gemma_rmsnorm(
@@ -121,7 +119,7 @@ def gemma_rmsnorm(
     with input.device as device:
         if out is None:
             out = torch.empty_like(input)
-        _gemma_rmsnorm(out, input, weight, eps, get_cuda_stream(device))
+        _gemma_rmsnorm(out, input, weight, eps, _get_cuda_stream(device))
     return out
 
 
@@ -129,7 +127,7 @@ def gemma_fused_add_rmsnorm(
     input: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6
 ) -> None:
     with input.device as device:
-        _gemma_fused_add_rmsnorm(input, residual, weight, eps, get_cuda_stream(device))
+        _gemma_fused_add_rmsnorm(input, residual, weight, eps, _get_cuda_stream(device))
 
 
 def _check_shape(input: torch.Tensor, output: torch.Tensor) -> None:
@@ -154,7 +152,7 @@ def silu_and_mul(input: torch.Tensor, out: torch.Tensor = None) -> torch.Tensor:
             dtype=input.dtype,
         )
     with input.device as device:
-        _silu_and_mul(out, input, get_cuda_stream(device))
+        _silu_and_mul(out, input, _get_cuda_stream(device))
     return out
 
 
@@ -170,7 +168,7 @@ def gelu_tanh_and_mul(input: torch.Tensor, out: torch.Tensor = None) -> torch.Te
             dtype=input.dtype,
         )
     with input.device as device:
-        _gelu_tanh_and_mul(out, input, get_cuda_stream(device))
+        _gelu_tanh_and_mul(out, input, _get_cuda_stream(device))
     return out
 
 
@@ -186,5 +184,46 @@ def gelu_and_mul(input: torch.Tensor, out: torch.Tensor = None) -> torch.Tensor:
             dtype=input.dtype,
         )
     with input.device as device:
-        _gelu_and_mul(out, input, get_cuda_stream(device))
+        _gelu_and_mul(out, input, _get_cuda_stream(device))
     return out
+
+
+def _bmm_fp8_internal(
+    workspace_buffer: torch.Tensor,
+    A: torch.Tensor,
+    B: torch.Tensor,
+    D: torch.Tensor,
+    A_scale: torch.Tensor,
+    B_scale: torch.Tensor,
+) -> None:
+    with A.device as device:
+        cublas_handle = torch.cuda.current_blas_handle()
+        _bmm_fp8(
+            A,
+            B,
+            D,
+            A_scale,
+            B_scale,
+            workspace_buffer,
+            cublas_handle,
+            _get_cuda_stream(device),
+        )
+
+
+def bmm_fp8(
+    A: torch.Tensor,
+    B: torch.Tensor,
+    A_scale: torch.Tensor,
+    B_scale: torch.Tensor,
+    dtype: torch.dtype,
+    out: Optional[torch.Tensor] = None,
+) -> torch.Tensor:
+    if out is None:
+        out = torch.empty(
+            (A.shape[0], A.shape[1], B.shape[2]),
+            device=A.device,
+            dtype=dtype,
+        )
+    workspace_buffer = _get_cache_buf("bmm_fp8_workspace", 32 * 1024 * 1024, A.device)
+    _bmm_fp8_internal(workspace_buffer, A, B, out, A_scale, B_scale)
+    return out
diff --git a/sgl-kernel/src/sgl-kernel/ops/utils.py b/sgl-kernel/src/sgl-kernel/ops/utils.py
new file mode 100644
index 000000000..af5fccbb7
--- /dev/null
+++ b/sgl-kernel/src/sgl-kernel/ops/utils.py
@@ -0,0 +1,19 @@
+from typing import Dict, Tuple
+
+import torch
+
+
+def _get_cuda_stream(device: torch.device) -> int:
+    return torch.cuda.current_stream(device).cuda_stream
+
+
+_cache_buf: Dict[Tuple[str, torch.device], torch.Tensor] = {}
+
+
+def _get_cache_buf(name: str, bytes: int, device: torch.device) -> torch.Tensor:
+    key = (name, device)
+    buf = _cache_buf.get(key)
+    if buf is None:
+        buf = torch.empty(bytes, dtype=torch.uint8, device=device)
+        _cache_buf[key] = buf
+    return buf
diff --git a/sgl-kernel/tests/test_bmm_fp8.py b/sgl-kernel/tests/test_bmm_fp8.py
new file mode 100644
index 000000000..e0be92896
--- /dev/null
+++ b/sgl-kernel/tests/test_bmm_fp8.py
@@ -0,0 +1,43 @@
+# Adapted from https://github.com/flashinfer-ai/flashinfer/blob/4e8eb1879f9c3ba6d75511e5893183bf8f289a62/tests/test_bmm_fp8.py
+
+import pytest
+import torch
+import torch.nn.functional as F
+from sgl_kernel import bmm_fp8
+
+
+def to_float8(x, dtype=torch.float8_e4m3fn):
+    finfo = torch.finfo(dtype)
+    min_val, max_val = x.aminmax()
+    amax = torch.maximum(min_val.abs(), max_val.abs()).clamp(min=1e-12)
+    scale = finfo.max / amax
+    x_scl_sat = (x * scale).clamp(min=finfo.min, max=finfo.max)
+    return x_scl_sat.to(dtype), scale.float().reciprocal()
+
+
+@pytest.mark.parametrize("input_dtype", [torch.float8_e4m3fn, torch.float8_e5m2])
+@pytest.mark.parametrize("mat2_dtype", [torch.float8_e4m3fn, torch.float8_e5m2])
+@pytest.mark.parametrize("res_dtype", [torch.bfloat16, torch.float16])
+def test_bmm_fp8(input_dtype, mat2_dtype, res_dtype):
+    if input_dtype == torch.float8_e5m2 and mat2_dtype == torch.float8_e5m2:
+        pytest.skip("Invalid combination: both input and mat2 are e5m2")
+
+    input = torch.randn([16, 48, 64], device="cuda", dtype=torch.bfloat16)
+    input_fp8, input_inv_s = to_float8(input, dtype=input_dtype)
+
+    # mat2 row major -> column major
+    mat2 = torch.randn([16, 80, 64], device="cuda", dtype=torch.bfloat16).transpose(
+        -2, -1
+    )
+    mat2_fp8, mat2_inv_s = to_float8(mat2, dtype=mat2_dtype)
+
+    res = torch.empty([16, 48, 80], device="cuda", dtype=res_dtype)
+    bmm_fp8(input_fp8, mat2_fp8, input_inv_s, mat2_inv_s, res_dtype, res)
+
+    reference = torch.bmm(input, mat2)
+    cos_sim = F.cosine_similarity(reference.reshape(-1), res.reshape(-1), dim=0)
+    assert cos_sim > 0.99
+
+
+if __name__ == "__main__":
+    pytest.main([__file__])
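
A minimal usage sketch of the new op, distilled from tests/test_bmm_fp8.py above;
it assumes a CUDA device with FP8 support (sm90-class hardware). Both operands are
quantized per-tensor, the second operand is made column-major, and the inverse
scales are passed so the kernel can rescale the FP8 product into the requested
output dtype.

    import torch
    from sgl_kernel import bmm_fp8

    def to_float8(x, dtype=torch.float8_e4m3fn):
        # Same helper as in the test: per-tensor symmetric quantization that
        # returns the FP8 tensor plus the inverse scale for dequantization.
        finfo = torch.finfo(dtype)
        min_val, max_val = x.aminmax()
        amax = torch.maximum(min_val.abs(), max_val.abs()).clamp(min=1e-12)
        scale = finfo.max / amax
        x_scl_sat = (x * scale).clamp(min=finfo.min, max=finfo.max)
        return x_scl_sat.to(dtype), scale.float().reciprocal()

    a = torch.randn(16, 48, 64, device="cuda", dtype=torch.bfloat16)
    # The second operand must be column-major, hence the transpose.
    b = torch.randn(16, 80, 64, device="cuda", dtype=torch.bfloat16).transpose(-2, -1)

    a_fp8, a_inv_s = to_float8(a)
    b_fp8, b_inv_s = to_float8(b)

    # Returns a (16, 48, 80) bf16 tensor; pass `out=` to reuse a buffer instead.
    out = bmm_fp8(a_fp8, b_fp8, a_inv_s, b_inv_s, torch.bfloat16)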
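
Two implementation notes on the wrapper. First, the workspace is cached:
_get_cache_buf allocates a single 32 MiB uint8 buffer per ("bmm_fp8_workspace",
device) key and reuses it on every call, so repeated bmm_fp8 invocations do not
reallocate; the buffer is forwarded to the C++ binding together with the current
cuBLAS handle and CUDA stream. Second, up to FP8 rounding error the op computes a
scaled batched matmul, which is why the test checks cosine similarity rather than
exact equality. A plain PyTorch reference for the sketch above (a sketch of the
semantics only, not the kernel's actual path; the scales are 0-dim per-tensor
tensors):

    ref = torch.bmm(a_fp8.to(torch.float32) * a_inv_s,
                    b_fp8.to(torch.float32) * b_inv_s).to(torch.bfloat16)

The e5m2 x e5m2 pairing is skipped in the test because cuBLAS's FP8 GEMM does not
accept two e5m2 operands; at least one side must be e4m3.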