feat: integrate activation kernels into sgl-kernel (#3053)
@@ -2,6 +2,8 @@ from sgl_kernel.ops import (
     custom_dispose,
     custom_reduce,
     fused_add_rmsnorm,
+    gelu_and_mul,
+    gelu_tanh_and_mul,
     gemma_fused_add_rmsnorm,
     gemma_rmsnorm,
     get_graph_buffer_ipc_meta,
@@ -12,12 +14,15 @@ from sgl_kernel.ops import (
     rmsnorm,
     rotary_embedding,
     sampling_scaling_penalties,
+    silu_and_mul,
 )

 __all__ = [
     "custom_dispose",
     "custom_reduce",
     "fused_add_rmsnorm",
+    "gelu_and_mul",
+    "gelu_tanh_and_mul",
     "gemma_fused_add_rmsnorm",
     "gemma_rmsnorm",
     "get_graph_buffer_ipc_meta",
@@ -28,4 +33,5 @@ __all__ = [
     "rmsnorm",
     "rotary_embedding",
     "sampling_scaling_penalties",
+    "silu_and_mul",
 ]

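With these exports in place, the three fused activation ops become importable from the package top level alongside the existing norm and allreduce ops. A minimal import sketch (assuming the sgl-kernel extension wheel is built and installed):

    from sgl_kernel import gelu_and_mul, gelu_tanh_and_mul, silu_and_mul
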
@@ -43,6 +43,15 @@ void gemma_rmsnorm(at::Tensor& output, at::Tensor& input, at::Tensor& weight, do
 void gemma_fused_add_rmsnorm(at::Tensor& input, at::Tensor& residual, at::Tensor& weight, double eps,
                              int64_t cuda_stream);

+// silu and mul
+void silu_and_mul(at::Tensor& out, at::Tensor& input, int64_t cuda_stream);
+
+// gelu tanh and mul
+void gelu_tanh_and_mul(at::Tensor& out, at::Tensor& input, int64_t cuda_stream);
+
+// gelu and mul
+void gelu_and_mul(at::Tensor& out, at::Tensor& input, int64_t cuda_stream);
+
 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
   // trt_reduce
   m.def("init_custom_ar", &init_custom_ar, "init custom allreduce meta (CUDA)");
@@ -66,4 +75,10 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
   m.def("gemma_rmsnorm", &gemma_rmsnorm, "Gemma RMSNorm (CUDA)");
   // fused gemma rms norm
   m.def("gemma_fused_add_rmsnorm", &gemma_fused_add_rmsnorm, "Gemma Fused Add RMSNorm (CUDA)");
+  // silu and mul
+  m.def("silu_and_mul", &silu_and_mul, "Silu and Mul (CUDA)");
+  // gelu tanh and mul
+  m.def("gelu_tanh_and_mul", &gelu_tanh_and_mul, "Gelu Tanh and Mul (CUDA)");
+  // gelu and mul
+  m.def("gelu_and_mul", &gelu_and_mul, "Gelu and Mul (CUDA)");
 }

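The new declarations and bindings follow the same calling convention as the existing norm kernels: each op takes the output tensor, the input tensor (whose last dimension is twice the output's), and a raw CUDA stream handle. In plain PyTorch terms the fused ops compute the gated activations sketched below; this mirrors the y_ref expressions in the new test file and the *_ref names are illustrative only, not part of the commit:

    import torch

    def silu_and_mul_ref(x: torch.Tensor) -> torch.Tensor:
        # Split the last dimension in half: activate the first half, gate with the second.
        d = x.shape[-1] // 2
        return x[..., d:] * torch.nn.functional.silu(x[..., :d])

    def gelu_tanh_and_mul_ref(x: torch.Tensor) -> torch.Tensor:
        d = x.shape[-1] // 2
        return x[..., d:] * torch.nn.functional.gelu(x[..., :d], approximate="tanh")

    def gelu_and_mul_ref(x: torch.Tensor) -> torch.Tensor:
        d = x.shape[-1] // 2
        return x[..., d:] * torch.nn.functional.gelu(x[..., :d], approximate="none")
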
@@ -4,6 +4,8 @@ import torch
 from sgl_kernel.ops._kernels import all_reduce as _all_reduce
 from sgl_kernel.ops._kernels import dispose as _dispose
 from sgl_kernel.ops._kernels import fused_add_rmsnorm as _fused_add_rmsnorm
+from sgl_kernel.ops._kernels import gelu_and_mul as _gelu_and_mul
+from sgl_kernel.ops._kernels import gelu_tanh_and_mul as _gelu_tanh_and_mul
 from sgl_kernel.ops._kernels import gemma_fused_add_rmsnorm as _gemma_fused_add_rmsnorm
 from sgl_kernel.ops._kernels import gemma_rmsnorm as _gemma_rmsnorm
 from sgl_kernel.ops._kernels import (
@@ -18,6 +20,7 @@ from sgl_kernel.ops._kernels import rotary_embedding as _rotary_embedding
 from sgl_kernel.ops._kernels import (
     sampling_scaling_penalties as _sampling_scaling_penalties,
 )
+from sgl_kernel.ops._kernels import silu_and_mul as _silu_and_mul


 def get_cuda_stream(device: torch.device) -> int:
@@ -127,3 +130,61 @@ def gemma_fused_add_rmsnorm(
 ) -> None:
     with input.device as device:
         _gemma_fused_add_rmsnorm(input, residual, weight, eps, get_cuda_stream(device))
+
+
+def _check_shape(input: torch.Tensor, output: torch.Tensor) -> None:
+    assert input.ndim == output.ndim, f"{input.ndim} != {output.ndim}"
+    assert (
+        input.shape[:-1] == output.shape[:-1]
+    ), f"{input.shape[:-1]} != {output.shape[:-1]}"
+    assert (
+        input.shape[-1] == 2 * output.shape[-1]
+    ), f"{input.shape[-1]} != {2 * output.shape[-1]}"
+
+
+def silu_and_mul(input: torch.Tensor, out: torch.Tensor = None) -> torch.Tensor:
+    if input.shape[-1] * input.dtype.itemsize % 16 != 0:
+        raise ValueError("The pointers must be multiple of 16 bytes.")
+    if out is not None:
+        _check_shape(input, out)
+    else:
+        out = torch.empty(
+            input.shape[:-1] + (input.shape[-1] // 2,),
+            device=input.device,
+            dtype=input.dtype,
+        )
+    with input.device as device:
+        _silu_and_mul(out, input, get_cuda_stream(device))
+    return out
+
+
+def gelu_tanh_and_mul(input: torch.Tensor, out: torch.Tensor = None) -> torch.Tensor:
+    if input.shape[-1] * input.dtype.itemsize % 16 != 0:
+        raise ValueError("The pointers must be multiple of 16 bytes.")
+    if out is not None:
+        _check_shape(input, out)
+    else:
+        out = torch.empty(
+            input.shape[:-1] + (input.shape[-1] // 2,),
+            device=input.device,
+            dtype=input.dtype,
+        )
+    with input.device as device:
+        _gelu_tanh_and_mul(out, input, get_cuda_stream(device))
+    return out
+
+
+def gelu_and_mul(input: torch.Tensor, out: torch.Tensor = None) -> torch.Tensor:
+    if input.shape[-1] * input.dtype.itemsize % 16 != 0:
+        raise ValueError("The pointers must be multiple of 16 bytes.")
+    if out is not None:
+        _check_shape(input, out)
+    else:
+        out = torch.empty(
+            input.shape[:-1] + (input.shape[-1] // 2,),
+            device=input.device,
+            dtype=input.dtype,
+        )
+    with input.device as device:
+        _gelu_and_mul(out, input, get_cuda_stream(device))
+    return out

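Taken together, the wrappers allocate the output when none is supplied (same leading shape, last dimension halved), require the input's last-dimension byte size to be a multiple of 16 for the vectorized kernels, and launch on the current CUDA stream of the input's device. A usage sketch (assuming a CUDA build of sgl-kernel is installed; shapes and dtypes are illustrative):

    import torch
    import sgl_kernel

    # Last dim is 2 * dim; at fp16, 4096 * 2 bytes satisfies the 16-byte requirement.
    x = torch.randn(8, 4096, device="cuda", dtype=torch.float16)

    y = sgl_kernel.silu_and_mul(x)            # shape (8, 2048), allocated by the wrapper
    out = torch.empty(8, 2048, device="cuda", dtype=torch.float16)
    sgl_kernel.gelu_tanh_and_mul(x, out=out)  # write into a preallocated buffer instead
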
sgl-kernel/tests/test_activation.py (new file, 38 lines)
@@ -0,0 +1,38 @@
+# Adapted from https://github.com/flashinfer-ai/flashinfer/blob/4e8eb1879f9c3ba6d75511e5893183bf8f289a62/tests/test_activation.py
+
+import pytest
+import sgl_kernel
+import torch
+
+
+@pytest.mark.parametrize("dim", [128, 256, 512, 2048, 4096, 11008, 16384])
+@pytest.mark.parametrize("batch_size", [1, 2, 4, 8, 16])
+@pytest.mark.parametrize("seq_len", [1, 2, 4, 8, 16, 32, 64, 128, 512])
+def test_fused_silu_mul(dim, batch_size, seq_len):
+    x = torch.randn(batch_size, seq_len, 2 * dim).to(0).to(torch.float16)
+    y_ref = x[..., dim:] * torch.nn.functional.silu(x[..., :dim])
+    y = sgl_kernel.silu_and_mul(x)
+    torch.testing.assert_close(y_ref, y, rtol=1e-3, atol=1e-3)
+
+
+@pytest.mark.parametrize("dim", [128, 256, 512, 2048, 4096, 11008, 16384])
+@pytest.mark.parametrize("batch_size", [1, 2, 4, 8, 16])
+@pytest.mark.parametrize("seq_len", [1, 2, 4, 8, 16, 32, 64, 128, 512])
+def test_fused_gelu_tanh_mul(dim, batch_size, seq_len):
+    x = torch.randn(batch_size, seq_len, 2 * dim).to(0).to(torch.float16)
+    y_ref = x[..., dim:] * torch.nn.functional.gelu(x[..., :dim], approximate="tanh")
+    y = sgl_kernel.gelu_tanh_and_mul(x)
+    torch.testing.assert_close(y_ref, y, rtol=1e-3, atol=1e-3)
+
+
+@pytest.mark.parametrize("dim", [128, 256, 512, 2048, 4096, 11008, 16384])
+@pytest.mark.parametrize("batch_size", [1, 2, 4, 8, 16])
+@pytest.mark.parametrize("seq_len", [1, 2, 4, 8, 16, 32, 64, 128, 512])
+def test_fused_gelu_mul(dim, batch_size, seq_len):
+    x = torch.randn(batch_size, seq_len, 2 * dim).to(0).to(torch.float16)
+    y_ref = x[..., dim:] * torch.nn.functional.gelu(x[..., :dim], approximate="none")
+    y = sgl_kernel.gelu_and_mul(x)
+    torch.testing.assert_close(y_ref, y, rtol=1e-3, atol=1e-3)
+
+
+test_fused_silu_mul(128, 1, 1)
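The tests mirror the FlashInfer activation tests and compare each fused kernel against its eager PyTorch reference at fp16 tolerance; they need a CUDA-capable GPU. A typical invocation from the repository root would be:

    pytest sgl-kernel/tests/test_activation.py

The bare test_fused_silu_mul(128, 1, 1) call at the end of the file also lets it run directly as a quick smoke check.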