feat: integrate bmm_fp8 kernel into sgl-kernel (#3056)
@@ -62,12 +62,22 @@ nvcc_flags = [
     "-std=c++17",
     "-use_fast_math",
     "-DFLASHINFER_ENABLE_F16",
     "-DFLASHINFER_ENABLE_BF16",
 ]
 
 if cuda_version >= (12, 0) and sm_version >= 90:
     nvcc_flags.append("-gencode=arch=compute_90a,code=sm_90a")
 
+if sm_version >= 90:
+    nvcc_flags.extend(
+        [
+            "-DFLASHINFER_ENABLE_FP8",
+            "-DFLASHINFER_ENABLE_FP8_E4M3",
+            "-DFLASHINFER_ENABLE_FP8_E5M2",
+        ]
+    )
+if sm_version >= 80:
+    nvcc_flags.append("-DFLASHINFER_ENABLE_BF16")
+
 for flag in [
     "-D__CUDA_NO_HALF_OPERATORS__",
     "-D__CUDA_NO_HALF_CONVERSIONS__",
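The FP8 defines are compiled in only when the target GPU is Hopper-class (sm_90). The `cuda_version` and `sm_version` values gating these flags are defined outside this hunk; as a hedged sketch (function name and exact mechanism are assumptions, not the repo's actual code), they are typically derived like this:

```python
# Sketch only: the real definitions live elsewhere in the setup script.
import torch

def detect_build_targets():
    major, minor = torch.cuda.get_device_capability()  # e.g. (9, 0) on Hopper
    sm_version = major * 10 + minor                    # 90 enables the FP8 defines
    cuda_version = tuple(map(int, torch.version.cuda.split(".")))  # e.g. (12, 4)
    return cuda_version, sm_version
```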
@@ -1,4 +1,5 @@
 from sgl_kernel.ops import (
+    bmm_fp8,
     custom_dispose,
     custom_reduce,
     fused_add_rmsnorm,
@@ -18,6 +19,7 @@ from sgl_kernel.ops import (
 )
 
 __all__ = [
+    "bmm_fp8",
     "custom_dispose",
     "custom_reduce",
     "fused_add_rmsnorm",

@@ -52,6 +52,10 @@ void gelu_tanh_and_mul(at::Tensor& out, at::Tensor& input, int64_t cuda_stream);
 // gelu and mul
 void gelu_and_mul(at::Tensor& out, at::Tensor& input, int64_t cuda_stream);
 
+// bmm fp8
+void bmm_fp8(at::Tensor A, at::Tensor B, at::Tensor D, at::Tensor A_scale, at::Tensor B_scale,
+             at::Tensor workspace_buffer, int64_t cublas_handle, int64_t cuda_stream);
+
 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
   // trt_reduce
   m.def("init_custom_ar", &init_custom_ar, "init custom allreduce meta (CUDA)");
@@ -81,4 +85,6 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
   m.def("gelu_tanh_and_mul", &gelu_tanh_and_mul, "Gelu Tanh and Mul (CUDA)");
   // gelu and mul
   m.def("gelu_and_mul", &gelu_and_mul, "Gelu and Mul (CUDA)");
+  // bmm fp8
+  m.def("bmm_fp8", &bmm_fp8, "BMM FP8 (CUDA)");
 }
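The new binding follows the file's existing convention of passing raw handles across the C++/Python boundary: the cuBLAS handle and CUDA stream arrive as plain `int64_t` values. On the Python side those integers come straight from PyTorch, as the wrapper further down does:

```python
# How the Python wrapper obtains the two int64 arguments for bmm_fp8.
import torch

cublas_handle = torch.cuda.current_blas_handle()       # cublasHandle_t as an int
cuda_stream = torch.cuda.current_stream().cuda_stream  # cudaStream_t as an int
```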
@@ -2,6 +2,7 @@ from typing import Optional
 
 import torch
 from sgl_kernel.ops._kernels import all_reduce as _all_reduce
+from sgl_kernel.ops._kernels import bmm_fp8 as _bmm_fp8
 from sgl_kernel.ops._kernels import dispose as _dispose
 from sgl_kernel.ops._kernels import fused_add_rmsnorm as _fused_add_rmsnorm
 from sgl_kernel.ops._kernels import gelu_and_mul as _gelu_and_mul
@@ -21,10 +22,7 @@ from sgl_kernel.ops._kernels import (
     sampling_scaling_penalties as _sampling_scaling_penalties,
 )
 from sgl_kernel.ops._kernels import silu_and_mul as _silu_and_mul
-
-
-def get_cuda_stream(device: torch.device) -> int:
-    return torch.cuda.current_stream(device).cuda_stream
+from sgl_kernel.ops.utils import _get_cache_buf, _get_cuda_stream
 
 
 def init_custom_reduce(
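The hunks that follow are mechanical: each wrapper switches from the module-local `get_cuda_stream` to the shared `_get_cuda_stream` helper while keeping the same shape, i.e. enter the tensor's device context, then hand the kernel the raw stream. A minimal sketch of that pattern, with `_fake_kernel` as a labeled stand-in for a real C++ binding:

```python
import torch
from sgl_kernel.ops.utils import _get_cuda_stream

def _fake_kernel(out, x, stream: int) -> None:
    # Stand-in for an extension binding that takes an int64 stream.
    out.copy_(x)

def my_op(x: torch.Tensor) -> torch.Tensor:
    out = torch.empty_like(x)
    with x.device as device:  # make x's device the current device
        _fake_kernel(out, x, _get_cuda_stream(device))
    return out
```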
@@ -101,7 +99,7 @@ def rmsnorm(
     with input.device as device:
         if out is None:
             out = torch.empty_like(input)
-        _rmsnorm(out, input, weight, eps, get_cuda_stream(device))
+        _rmsnorm(out, input, weight, eps, _get_cuda_stream(device))
         return out
@@ -109,7 +107,7 @@ def fused_add_rmsnorm(
     input: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6
 ) -> None:
     with input.device as device:
-        _fused_add_rmsnorm(input, residual, weight, eps, get_cuda_stream(device))
+        _fused_add_rmsnorm(input, residual, weight, eps, _get_cuda_stream(device))
@@ -121,7 +119,7 @@ def gemma_rmsnorm(
     with input.device as device:
         if out is None:
             out = torch.empty_like(input)
-        _gemma_rmsnorm(out, input, weight, eps, get_cuda_stream(device))
+        _gemma_rmsnorm(out, input, weight, eps, _get_cuda_stream(device))
         return out
@@ -129,7 +127,7 @@ def gemma_fused_add_rmsnorm(
     input: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6
 ) -> None:
     with input.device as device:
-        _gemma_fused_add_rmsnorm(input, residual, weight, eps, get_cuda_stream(device))
+        _gemma_fused_add_rmsnorm(input, residual, weight, eps, _get_cuda_stream(device))
 
 
 def _check_shape(input: torch.Tensor, output: torch.Tensor) -> None:
@@ -154,7 +152,7 @@ def silu_and_mul(input: torch.Tensor, out: torch.Tensor = None) -> torch.Tensor:
             dtype=input.dtype,
         )
     with input.device as device:
-        _silu_and_mul(out, input, get_cuda_stream(device))
+        _silu_and_mul(out, input, _get_cuda_stream(device))
         return out
@@ -170,7 +168,7 @@ def gelu_tanh_and_mul(input: torch.Tensor, out: torch.Tensor = None) -> torch.Tensor:
             dtype=input.dtype,
         )
     with input.device as device:
-        _gelu_tanh_and_mul(out, input, get_cuda_stream(device))
+        _gelu_tanh_and_mul(out, input, _get_cuda_stream(device))
         return out
@@ -186,5 +184,46 @@ def gelu_and_mul(input: torch.Tensor, out: torch.Tensor = None) -> torch.Tensor:
             dtype=input.dtype,
         )
     with input.device as device:
-        _gelu_and_mul(out, input, get_cuda_stream(device))
+        _gelu_and_mul(out, input, _get_cuda_stream(device))
         return out
+
+
+def _bmm_fp8_internal(
+    workspace_buffer: torch.Tensor,
+    A: torch.Tensor,
+    B: torch.Tensor,
+    D: torch.Tensor,
+    A_scale: torch.Tensor,
+    B_scale: torch.Tensor,
+) -> None:
+    with A.device as device:
+        cublas_handle = torch.cuda.current_blas_handle()
+        _bmm_fp8(
+            A,
+            B,
+            D,
+            A_scale,
+            B_scale,
+            workspace_buffer,
+            cublas_handle,
+            _get_cuda_stream(device),
+        )
+
+
+def bmm_fp8(
+    A: torch.Tensor,
+    B: torch.Tensor,
+    A_scale: torch.Tensor,
+    B_scale: torch.Tensor,
+    dtype: torch.dtype,
+    out: Optional[torch.Tensor] = None,
+) -> torch.Tensor:
+    if out is None:
+        out = torch.empty(
+            (A.shape[0], A.shape[1], B.shape[2]),
+            device=A.device,
+            dtype=dtype,
+        )
+    workspace_buffer = _get_cache_buf("bmm_fp8_workspace", 32 * 1024 * 1024, A.device)
+    _bmm_fp8_internal(workspace_buffer, A, B, out, A_scale, B_scale)
+    return out
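Taken together, a typical call quantizes both operands to FP8 with per-tensor scales and lets `bmm_fp8` allocate the output. A short example mirroring the new test file below (the `quantize` helper here is a simplified variant of the test's `to_float8`, not part of the package):

```python
import torch
from sgl_kernel import bmm_fp8

def quantize(x, dtype=torch.float8_e4m3fn):
    # Per-tensor scale; the returned scale is the dequantization factor.
    finfo = torch.finfo(dtype)
    scale = finfo.max / x.abs().max().clamp(min=1e-12)
    return (x * scale).clamp(finfo.min, finfo.max).to(dtype), scale.float().reciprocal()

A = torch.randn(16, 48, 64, device="cuda", dtype=torch.bfloat16)
B = torch.randn(16, 80, 64, device="cuda", dtype=torch.bfloat16).transpose(-2, -1)

A_fp8, A_scale = quantize(A)
B_fp8, B_scale = quantize(B)
out = bmm_fp8(A_fp8, B_fp8, A_scale, B_scale, dtype=torch.bfloat16)  # [16, 48, 80]
```

Note that `B` is passed column-major (via the transpose), matching the layout the test uses.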
sgl-kernel/src/sgl-kernel/ops/utils.py (new file, 19 lines)
@@ -0,0 +1,19 @@
+from typing import Dict, Tuple
+
+import torch
+
+
+def _get_cuda_stream(device: torch.device) -> int:
+    return torch.cuda.current_stream(device).cuda_stream
+
+
+_cache_buf: Dict[Tuple[str, torch.device], torch.Tensor] = {}
+
+
+def _get_cache_buf(name: str, bytes: int, device: torch.device) -> torch.Tensor:
+    key = (name, device)
+    buf = _cache_buf.get(key)
+    if buf is None:
+        buf = torch.empty(bytes, dtype=torch.uint8, device=device)
+        _cache_buf[key] = buf
+    return buf
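Because `_get_cache_buf` is keyed on `(name, device)`, the 32 MB `bmm_fp8_workspace` requested above is allocated once per device and then reused across calls:

```python
import torch
from sgl_kernel.ops.utils import _get_cache_buf

dev = torch.device("cuda")
a = _get_cache_buf("bmm_fp8_workspace", 32 * 1024 * 1024, dev)
b = _get_cache_buf("bmm_fp8_workspace", 32 * 1024 * 1024, dev)
assert a.data_ptr() == b.data_ptr()  # same cached buffer, no reallocation
```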
sgl-kernel/tests/test_bmm_fp8.py (new file, 43 lines)
@@ -0,0 +1,43 @@
+# Adapted from https://github.com/flashinfer-ai/flashinfer/blob/4e8eb1879f9c3ba6d75511e5893183bf8f289a62/tests/test_bmm_fp8.py
+
+import pytest
+import torch
+import torch.nn.functional as F
+from sgl_kernel import bmm_fp8
+
+
+def to_float8(x, dtype=torch.float8_e4m3fn):
+    finfo = torch.finfo(dtype)
+    min_val, max_val = x.aminmax()
+    amax = torch.maximum(min_val.abs(), max_val.abs()).clamp(min=1e-12)
+    scale = finfo.max / amax
+    x_scl_sat = (x * scale).clamp(min=finfo.min, max=finfo.max)
+    return x_scl_sat.to(dtype), scale.float().reciprocal()
+
+
+@pytest.mark.parametrize("input_dtype", [torch.float8_e4m3fn, torch.float8_e5m2])
+@pytest.mark.parametrize("mat2_dtype", [torch.float8_e4m3fn, torch.float8_e5m2])
+@pytest.mark.parametrize("res_dtype", [torch.bfloat16, torch.float16])
+def test_bmm_fp8(input_dtype, mat2_dtype, res_dtype):
+    if input_dtype == torch.float8_e5m2 and mat2_dtype == torch.float8_e5m2:
+        pytest.skip("Invalid combination: both input and mat2 are e5m2")
+
+    input = torch.randn([16, 48, 64], device="cuda", dtype=torch.bfloat16)
+    input_fp8, input_inv_s = to_float8(input, dtype=input_dtype)
+
+    # mat2 row major -> column major
+    mat2 = torch.randn([16, 80, 64], device="cuda", dtype=torch.bfloat16).transpose(
+        -2, -1
+    )
+    mat2_fp8, mat2_inv_s = to_float8(mat2, dtype=mat2_dtype)
+
+    res = torch.empty([16, 48, 80], device="cuda", dtype=res_dtype)
+    bmm_fp8(input_fp8, mat2_fp8, input_inv_s, mat2_inv_s, res_dtype, res)
+
+    reference = torch.bmm(input, mat2)
+    cos_sim = F.cosine_similarity(reference.reshape(-1), res.reshape(-1), dim=0)
+    assert cos_sim > 0.99
+
+
+if __name__ == "__main__":
+    pytest.main([__file__])
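The skipped e5m2 × e5m2 pairing reflects a cuBLAS FP8 restriction: at least one matmul operand must be e4m3. The test also compares against a bf16 reference by cosine similarity rather than elementwise tolerance, since FP8 quantization error makes exact bounds impractical. The file can be run directly (`python test_bmm_fp8.py`) or via `pytest sgl-kernel/tests/test_bmm_fp8.py`.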