feat: integrate norm kernels into sgl-kernel (#3052)
This commit is contained in:
@@ -1,6 +1,9 @@
|
||||
from sgl_kernel.ops import (
|
||||
custom_dispose,
|
||||
custom_reduce,
|
||||
fused_add_rmsnorm,
|
||||
gemma_fused_add_rmsnorm,
|
||||
gemma_rmsnorm,
|
||||
get_graph_buffer_ipc_meta,
|
||||
init_custom_reduce,
|
||||
int8_scaled_mm,
|
||||
@@ -12,14 +15,17 @@ from sgl_kernel.ops import (
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"moe_align_block_size",
|
||||
"init_custom_reduce",
|
||||
"custom_dispose",
|
||||
"custom_reduce",
|
||||
"int8_scaled_mm",
|
||||
"sampling_scaling_penalties",
|
||||
"fused_add_rmsnorm",
|
||||
"gemma_fused_add_rmsnorm",
|
||||
"gemma_rmsnorm",
|
||||
"get_graph_buffer_ipc_meta",
|
||||
"init_custom_reduce",
|
||||
"int8_scaled_mm",
|
||||
"moe_align_block_size",
|
||||
"register_graph_buffers",
|
||||
"rotary_embedding",
|
||||
"rmsnorm",
|
||||
"rotary_embedding",
|
||||
"sampling_scaling_penalties",
|
||||
]
|
||||
|
||||
@@ -33,6 +33,16 @@ void rotary_embedding(torch::Tensor& positions, torch::Tensor& query, torch::Ten
|
||||
// rms norm
|
||||
void rmsnorm(at::Tensor& output, at::Tensor& input, at::Tensor& weight, double eps, int64_t cuda_stream);
|
||||
|
||||
// fused rms norm
|
||||
void fused_add_rmsnorm(at::Tensor& input, at::Tensor& residual, at::Tensor& weight, double eps, int64_t cuda_stream);
|
||||
|
||||
// gemma rms norm
|
||||
void gemma_rmsnorm(at::Tensor& output, at::Tensor& input, at::Tensor& weight, double eps, int64_t cuda_stream);
|
||||
|
||||
// fused gemma rms norm
|
||||
void gemma_fused_add_rmsnorm(at::Tensor& input, at::Tensor& residual, at::Tensor& weight, double eps,
|
||||
int64_t cuda_stream);
|
||||
|
||||
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
||||
// trt_reduce
|
||||
m.def("init_custom_ar", &init_custom_ar, "init custom allreduce meta (CUDA)");
|
||||
@@ -50,4 +60,10 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
||||
m.def("rotary_embedding", &rotary_embedding, "Rotary Embedding (CUDA)");
|
||||
// rms norm
|
||||
m.def("rmsnorm", &rmsnorm, "RMSNorm (CUDA)");
|
||||
// fused rms norm
|
||||
m.def("fused_add_rmsnorm", &fused_add_rmsnorm, "Fused Add RMSNorm (CUDA)");
|
||||
// gemma rms norm
|
||||
m.def("gemma_rmsnorm", &gemma_rmsnorm, "Gemma RMSNorm (CUDA)");
|
||||
// fused gemma rms norm
|
||||
m.def("gemma_fused_add_rmsnorm", &gemma_fused_add_rmsnorm, "Gemma Fused Add RMSNorm (CUDA)");
|
||||
}
|
||||
|
||||
@@ -3,6 +3,9 @@ from typing import Optional
|
||||
import torch
|
||||
from sgl_kernel.ops._kernels import all_reduce as _all_reduce
|
||||
from sgl_kernel.ops._kernels import dispose as _dispose
|
||||
from sgl_kernel.ops._kernels import fused_add_rmsnorm as _fused_add_rmsnorm
|
||||
from sgl_kernel.ops._kernels import gemma_fused_add_rmsnorm as _gemma_fused_add_rmsnorm
|
||||
from sgl_kernel.ops._kernels import gemma_rmsnorm as _gemma_rmsnorm
|
||||
from sgl_kernel.ops._kernels import (
|
||||
get_graph_buffer_ipc_meta as _get_graph_buffer_ipc_meta,
|
||||
)
|
||||
@@ -17,6 +20,10 @@ from sgl_kernel.ops._kernels import (
|
||||
)
|
||||
|
||||
|
||||
def get_cuda_stream(device: torch.device) -> int:
|
||||
return torch.cuda.current_stream(device).cuda_stream
|
||||
|
||||
|
||||
def init_custom_reduce(
|
||||
rank_id, num_devices, rank_data, buffers, tmp_buffers, barrier_in, barrier_out
|
||||
):
|
||||
@@ -88,9 +95,35 @@ def rmsnorm(
|
||||
eps: float = 1e-6,
|
||||
out: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
if out is None:
|
||||
out = torch.empty_like(input)
|
||||
stream = torch.cuda.current_stream().cuda_stream
|
||||
stream_int = int(stream)
|
||||
_rmsnorm(out, input, weight, eps, stream_int)
|
||||
return out
|
||||
with input.device as device:
|
||||
if out is None:
|
||||
out = torch.empty_like(input)
|
||||
_rmsnorm(out, input, weight, eps, get_cuda_stream(device))
|
||||
return out
|
||||
|
||||
|
||||
def fused_add_rmsnorm(
|
||||
input: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6
|
||||
) -> None:
|
||||
with input.device as device:
|
||||
_fused_add_rmsnorm(input, residual, weight, eps, get_cuda_stream(device))
|
||||
|
||||
|
||||
def gemma_rmsnorm(
|
||||
input: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
eps: float = 1e-6,
|
||||
out: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
with input.device as device:
|
||||
if out is None:
|
||||
out = torch.empty_like(input)
|
||||
_gemma_rmsnorm(out, input, weight, eps, get_cuda_stream(device))
|
||||
return out
|
||||
|
||||
|
||||
def gemma_fused_add_rmsnorm(
|
||||
input: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6
|
||||
) -> None:
|
||||
with input.device as device:
|
||||
_gemma_fused_add_rmsnorm(input, residual, weight, eps, get_cuda_stream(device))
|
||||
|
||||
129
sgl-kernel/tests/test_norm.py
Normal file
129
sgl-kernel/tests/test_norm.py
Normal file
@@ -0,0 +1,129 @@
|
||||
# Adapted from https://github.com/flashinfer-ai/flashinfer/blob/4e8eb1879f9c3ba6d75511e5893183bf8f289a62/tests/test_norm.py
|
||||
|
||||
import pytest
|
||||
import sgl_kernel
|
||||
import torch
|
||||
|
||||
|
||||
def llama_rms_norm(x, w, eps=1e-6):
|
||||
orig_dtype = x.dtype
|
||||
x = x.float()
|
||||
variance = x.pow(2).mean(dim=-1, keepdim=True)
|
||||
x = x * torch.rsqrt(variance + eps)
|
||||
x = x * w.float()
|
||||
x = x.to(orig_dtype)
|
||||
return x
|
||||
|
||||
|
||||
def gemma_rms_norm(x, w, eps=1e-6):
|
||||
orig_dtype = x.dtype
|
||||
x = x.float()
|
||||
variance = x.pow(2).mean(dim=-1, keepdim=True)
|
||||
x = x * torch.rsqrt(variance + eps)
|
||||
x = x * (1.0 + w.float())
|
||||
x = x.to(orig_dtype)
|
||||
return x
|
||||
|
||||
|
||||
def gemma_fused_add_rms_norm(x, residual, w, eps=1e-6):
|
||||
orig_dtype = x.dtype
|
||||
x = x + residual
|
||||
residual = x
|
||||
x = x.float()
|
||||
variance = x.pow(2).mean(dim=-1, keepdim=True)
|
||||
x = x * torch.rsqrt(variance + eps)
|
||||
x = x * (1.0 + w.float())
|
||||
x = x.to(orig_dtype)
|
||||
return x, residual
|
||||
|
||||
|
||||
def fused_add_rms_norm(x, residual, weight, eps):
|
||||
orig_dtype = x.dtype
|
||||
x = x.to(torch.float32)
|
||||
x = x + residual.to(torch.float32)
|
||||
residual = x.to(orig_dtype)
|
||||
|
||||
variance = x.pow(2).mean(dim=-1, keepdim=True)
|
||||
x = x * torch.rsqrt(variance + eps)
|
||||
x = (x * weight.float()).to(orig_dtype)
|
||||
return x, residual
|
||||
|
||||
|
||||
@pytest.mark.parametrize("batch_size", [1, 19, 99, 989])
|
||||
@pytest.mark.parametrize("hidden_size", [111, 500, 1024, 3072, 3584, 4096, 8192, 16384])
|
||||
@pytest.mark.parametrize("dtype", [torch.float16])
|
||||
@pytest.mark.parametrize("specify_out", [True, False])
|
||||
def test_norm(batch_size, hidden_size, dtype, specify_out):
|
||||
x = torch.randn(batch_size, hidden_size).to(0).to(dtype)
|
||||
w = torch.randn(hidden_size).to(0).to(dtype)
|
||||
|
||||
y_ref = llama_rms_norm(x, w)
|
||||
if specify_out:
|
||||
y = torch.empty_like(x)
|
||||
sgl_kernel.rmsnorm(x, w, out=y)
|
||||
else:
|
||||
y = sgl_kernel.rmsnorm(x, w)
|
||||
|
||||
torch.testing.assert_close(y_ref, y, rtol=1e-3, atol=1e-3)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("batch_size", [1, 19, 99, 989])
|
||||
@pytest.mark.parametrize("hidden_size", [111, 500, 1024, 3072, 3584, 4096, 8192, 16384])
|
||||
@pytest.mark.parametrize("dtype", [torch.float16])
|
||||
def test_fused_add_rmsnorm(batch_size, hidden_size, dtype):
|
||||
eps = 1e-6
|
||||
|
||||
x = torch.randn(batch_size, hidden_size, dtype=dtype, device="cuda")
|
||||
residual = torch.randn_like(x)
|
||||
weight = torch.randn(hidden_size, dtype=dtype, device="cuda")
|
||||
|
||||
x_native, residual_native = fused_add_rms_norm(
|
||||
x.clone(), residual.clone(), weight, eps
|
||||
)
|
||||
|
||||
x_fused = x.clone()
|
||||
residual_fused = residual.clone()
|
||||
sgl_kernel.fused_add_rmsnorm(x_fused, residual_fused, weight, eps)
|
||||
|
||||
torch.testing.assert_close(x_fused, x_native, rtol=1e-3, atol=1e-3)
|
||||
torch.testing.assert_close(residual_fused, residual_native, rtol=1e-3, atol=1e-3)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("batch_size", [1, 19, 99, 989])
|
||||
@pytest.mark.parametrize("hidden_size", [111, 500, 1024, 3072, 3584, 4096, 8192, 16384])
|
||||
@pytest.mark.parametrize("dtype", [torch.float16])
|
||||
@pytest.mark.parametrize("specify_out", [True, False])
|
||||
def test_gemma_norm(batch_size, hidden_size, dtype, specify_out):
|
||||
x = torch.randn(batch_size, hidden_size).to(0).to(dtype)
|
||||
w = torch.randn(hidden_size).to(0).to(dtype)
|
||||
|
||||
y_ref = gemma_rms_norm(x, w)
|
||||
if specify_out:
|
||||
y = torch.empty_like(x)
|
||||
sgl_kernel.gemma_rmsnorm(x, w, out=y)
|
||||
else:
|
||||
y = sgl_kernel.gemma_rmsnorm(x, w)
|
||||
|
||||
torch.testing.assert_close(y_ref, y, rtol=1e-3, atol=1e-3)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("batch_size", [1, 19, 99, 989])
|
||||
@pytest.mark.parametrize("hidden_size", [111, 500, 1024, 3072, 3584, 4096, 8192, 16384])
|
||||
@pytest.mark.parametrize("dtype", [torch.float16])
|
||||
def test_gemma_fused_add_rmsnorm(batch_size, hidden_size, dtype):
|
||||
eps = 1e-6
|
||||
|
||||
x = torch.randn(batch_size, hidden_size, dtype=dtype, device="cuda")
|
||||
residual = torch.randn_like(x)
|
||||
weight = torch.randn(hidden_size, dtype=dtype, device="cuda")
|
||||
|
||||
x_native, residual_native = gemma_fused_add_rms_norm(
|
||||
x.clone(), residual.clone(), weight, eps
|
||||
)
|
||||
|
||||
x_fused = x.clone()
|
||||
residual_fused = residual.clone()
|
||||
sgl_kernel.gemma_fused_add_rmsnorm(x_fused, residual_fused, weight, eps)
|
||||
|
||||
torch.testing.assert_close(x_fused, x_native, rtol=1e-3, atol=1e-3)
|
||||
torch.testing.assert_close(residual_fused, residual_native, rtol=1e-3, atol=1e-3)
|
||||
@@ -1,31 +0,0 @@
|
||||
import pytest
|
||||
import torch
|
||||
from sgl_kernel import rmsnorm
|
||||
|
||||
|
||||
def llama_rms_norm(x, w, eps=1e-6):
|
||||
orig_dtype = x.dtype
|
||||
x = x.float()
|
||||
variance = x.pow(2).mean(dim=-1, keepdim=True)
|
||||
x = x * torch.rsqrt(variance + eps)
|
||||
x = x * w.float()
|
||||
x = x.to(orig_dtype)
|
||||
return x
|
||||
|
||||
|
||||
@pytest.mark.parametrize("batch_size", [1, 19, 99, 989])
|
||||
@pytest.mark.parametrize("hidden_size", [111, 500, 1024, 3072, 3584, 4096, 8192, 16384])
|
||||
@pytest.mark.parametrize("dtype", [torch.float16])
|
||||
@pytest.mark.parametrize("specify_out", [True, False])
|
||||
def test_norm(batch_size, hidden_size, dtype, specify_out):
|
||||
x = torch.randn(batch_size, hidden_size).to(0).to(dtype)
|
||||
w = torch.randn(hidden_size).to(0).to(dtype)
|
||||
|
||||
y_ref = llama_rms_norm(x, w)
|
||||
if specify_out:
|
||||
y = torch.empty_like(x)
|
||||
rmsnorm(x, w, out=y)
|
||||
else:
|
||||
y = rmsnorm(x, w)
|
||||
|
||||
torch.testing.assert_close(y_ref, y, rtol=1e-3, atol=1e-3)
|
||||
Reference in New Issue
Block a user