From 9d9b482a392598fc342ee449835af5535ccc772f Mon Sep 17 00:00:00 2001 From: Yineng Zhang Date: Wed, 22 Jan 2025 23:25:45 +0800 Subject: [PATCH] feat: integrate activation kernels into sgl-kernel (#3053) --- sgl-kernel/src/sgl-kernel/__init__.py | 6 ++ .../src/sgl-kernel/csrc/sgl_kernel_ops.cu | 15 +++++ sgl-kernel/src/sgl-kernel/ops/__init__.py | 61 +++++++++++++++++++ sgl-kernel/tests/test_activation.py | 38 ++++++++++++ 4 files changed, 120 insertions(+) create mode 100644 sgl-kernel/tests/test_activation.py diff --git a/sgl-kernel/src/sgl-kernel/__init__.py b/sgl-kernel/src/sgl-kernel/__init__.py index bdbc0ce84..0bcd77aad 100644 --- a/sgl-kernel/src/sgl-kernel/__init__.py +++ b/sgl-kernel/src/sgl-kernel/__init__.py @@ -2,6 +2,8 @@ from sgl_kernel.ops import ( custom_dispose, custom_reduce, fused_add_rmsnorm, + gelu_and_mul, + gelu_tanh_and_mul, gemma_fused_add_rmsnorm, gemma_rmsnorm, get_graph_buffer_ipc_meta, @@ -12,12 +14,15 @@ from sgl_kernel.ops import ( rmsnorm, rotary_embedding, sampling_scaling_penalties, + silu_and_mul, ) __all__ = [ "custom_dispose", "custom_reduce", "fused_add_rmsnorm", + "gelu_and_mul", + "gelu_tanh_and_mul", "gemma_fused_add_rmsnorm", "gemma_rmsnorm", "get_graph_buffer_ipc_meta", @@ -28,4 +33,5 @@ __all__ = [ "rmsnorm", "rotary_embedding", "sampling_scaling_penalties", + "silu_and_mul", ] diff --git a/sgl-kernel/src/sgl-kernel/csrc/sgl_kernel_ops.cu b/sgl-kernel/src/sgl-kernel/csrc/sgl_kernel_ops.cu index 8f9d1ae53..d9aaa41b8 100644 --- a/sgl-kernel/src/sgl-kernel/csrc/sgl_kernel_ops.cu +++ b/sgl-kernel/src/sgl-kernel/csrc/sgl_kernel_ops.cu @@ -43,6 +43,15 @@ void gemma_rmsnorm(at::Tensor& output, at::Tensor& input, at::Tensor& weight, do void gemma_fused_add_rmsnorm(at::Tensor& input, at::Tensor& residual, at::Tensor& weight, double eps, int64_t cuda_stream); +// silu and mul +void silu_and_mul(at::Tensor& out, at::Tensor& input, int64_t cuda_stream); + +// gelu tanh and mul +void gelu_tanh_and_mul(at::Tensor& out, at::Tensor& 
input, int64_t cuda_stream); + +// gelu and mul +void gelu_and_mul(at::Tensor& out, at::Tensor& input, int64_t cuda_stream); + PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { // trt_reduce m.def("init_custom_ar", &init_custom_ar, "init custom allreduce meta (CUDA)"); @@ -66,4 +75,10 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("gemma_rmsnorm", &gemma_rmsnorm, "Gemma RMSNorm (CUDA)"); // fused gemma rms norm m.def("gemma_fused_add_rmsnorm", &gemma_fused_add_rmsnorm, "Gemma Fused Add RMSNorm (CUDA)"); + // silu and mul + m.def("silu_and_mul", &silu_and_mul, "Silu and Mul (CUDA)"); + // gelu tanh and mul + m.def("gelu_tanh_and_mul", &gelu_tanh_and_mul, "Gelu Tanh and Mul (CUDA)"); + // gelu and mul + m.def("gelu_and_mul", &gelu_and_mul, "Gelu and Mul (CUDA)"); } diff --git a/sgl-kernel/src/sgl-kernel/ops/__init__.py b/sgl-kernel/src/sgl-kernel/ops/__init__.py index bbfd76878..5bfde5df2 100644 --- a/sgl-kernel/src/sgl-kernel/ops/__init__.py +++ b/sgl-kernel/src/sgl-kernel/ops/__init__.py @@ -4,6 +4,8 @@ import torch from sgl_kernel.ops._kernels import all_reduce as _all_reduce from sgl_kernel.ops._kernels import dispose as _dispose from sgl_kernel.ops._kernels import fused_add_rmsnorm as _fused_add_rmsnorm +from sgl_kernel.ops._kernels import gelu_and_mul as _gelu_and_mul +from sgl_kernel.ops._kernels import gelu_tanh_and_mul as _gelu_tanh_and_mul from sgl_kernel.ops._kernels import gemma_fused_add_rmsnorm as _gemma_fused_add_rmsnorm from sgl_kernel.ops._kernels import gemma_rmsnorm as _gemma_rmsnorm from sgl_kernel.ops._kernels import ( @@ -18,6 +20,7 @@ from sgl_kernel.ops._kernels import rotary_embedding as _rotary_embedding from sgl_kernel.ops._kernels import ( sampling_scaling_penalties as _sampling_scaling_penalties, ) +from sgl_kernel.ops._kernels import silu_and_mul as _silu_and_mul def get_cuda_stream(device: torch.device) -> int: @@ -127,3 +130,61 @@ def gemma_fused_add_rmsnorm( ) -> None: with input.device as device: _gemma_fused_add_rmsnorm(input, 
residual, weight, eps, get_cuda_stream(device)) + + +def _check_shape(input: torch.Tensor, output: torch.Tensor) -> None: + assert input.ndim == output.ndim, f"{input.ndim} != {output.ndim}" + assert ( + input.shape[:-1] == output.shape[:-1] + ), f"{input.shape[:-1]} != {output.shape[:-1]}" + assert ( + input.shape[-1] == 2 * output.shape[-1] + ), f"{input.shape[-1]} != {2 * output.shape[-1]}" + + +def silu_and_mul(input: torch.Tensor, out: torch.Tensor = None) -> torch.Tensor: + if input.shape[-1] * input.dtype.itemsize % 16 != 0: + raise ValueError("The pointers must be multiple of 16 bytes.") + if out is not None: + _check_shape(input, out) + else: + out = torch.empty( + input.shape[:-1] + (input.shape[-1] // 2,), + device=input.device, + dtype=input.dtype, + ) + with input.device as device: + _silu_and_mul(out, input, get_cuda_stream(device)) + return out + + +def gelu_tanh_and_mul(input: torch.Tensor, out: torch.Tensor = None) -> torch.Tensor: + if input.shape[-1] * input.dtype.itemsize % 16 != 0: + raise ValueError("The pointers must be multiple of 16 bytes.") + if out is not None: + _check_shape(input, out) + else: + out = torch.empty( + input.shape[:-1] + (input.shape[-1] // 2,), + device=input.device, + dtype=input.dtype, + ) + with input.device as device: + _gelu_tanh_and_mul(out, input, get_cuda_stream(device)) + return out + + +def gelu_and_mul(input: torch.Tensor, out: torch.Tensor = None) -> torch.Tensor: + if input.shape[-1] * input.dtype.itemsize % 16 != 0: + raise ValueError("The pointers must be multiple of 16 bytes.") + if out is not None: + _check_shape(input, out) + else: + out = torch.empty( + input.shape[:-1] + (input.shape[-1] // 2,), + device=input.device, + dtype=input.dtype, + ) + with input.device as device: + _gelu_and_mul(out, input, get_cuda_stream(device)) + return out diff --git a/sgl-kernel/tests/test_activation.py b/sgl-kernel/tests/test_activation.py new file mode 100644 index 000000000..f71f36b51 --- /dev/null +++ 
b/sgl-kernel/tests/test_activation.py @@ -0,0 +1,38 @@ +# Adapted from https://github.com/flashinfer-ai/flashinfer/blob/4e8eb1879f9c3ba6d75511e5893183bf8f289a62/tests/test_activation.py + +import pytest +import sgl_kernel +import torch + + +@pytest.mark.parametrize("dim", [128, 256, 512, 2048, 4096, 11008, 16384]) +@pytest.mark.parametrize("batch_size", [1, 2, 4, 8, 16]) +@pytest.mark.parametrize("seq_len", [1, 2, 4, 8, 16, 32, 64, 128, 512]) +def test_fused_silu_mul(dim, batch_size, seq_len): + x = torch.randn(batch_size, seq_len, 2 * dim).to(0).to(torch.float16) + y_ref = x[..., dim:] * torch.nn.functional.silu(x[..., :dim]) + y = sgl_kernel.silu_and_mul(x) + torch.testing.assert_close(y_ref, y, rtol=1e-3, atol=1e-3) + + +@pytest.mark.parametrize("dim", [128, 256, 512, 2048, 4096, 11008, 16384]) +@pytest.mark.parametrize("batch_size", [1, 2, 4, 8, 16]) +@pytest.mark.parametrize("seq_len", [1, 2, 4, 8, 16, 32, 64, 128, 512]) +def test_fused_gelu_tanh_mul(dim, batch_size, seq_len): + x = torch.randn(batch_size, seq_len, 2 * dim).to(0).to(torch.float16) + y_ref = x[..., dim:] * torch.nn.functional.gelu(x[..., :dim], approximate="tanh") + y = sgl_kernel.gelu_tanh_and_mul(x) + torch.testing.assert_close(y_ref, y, rtol=1e-3, atol=1e-3) + + +@pytest.mark.parametrize("dim", [128, 256, 512, 2048, 4096, 11008, 16384]) +@pytest.mark.parametrize("batch_size", [1, 2, 4, 8, 16]) +@pytest.mark.parametrize("seq_len", [1, 2, 4, 8, 16, 32, 64, 128, 512]) +def test_fused_gelu_mul(dim, batch_size, seq_len): + x = torch.randn(batch_size, seq_len, 2 * dim).to(0).to(torch.float16) + y_ref = x[..., dim:] * torch.nn.functional.gelu(x[..., :dim], approximate="none") + y = sgl_kernel.gelu_and_mul(x) + torch.testing.assert_close(y_ref, y, rtol=1e-3, atol=1e-3) + + ++# NOTE: no module-level test invocation here; pytest collects the tests above.