feat: integrate norm kernels into sgl-kernel (#3052)
This commit is contained in:
@@ -1,6 +1,9 @@
|
|||||||
from sgl_kernel.ops import (
|
from sgl_kernel.ops import (
|
||||||
custom_dispose,
|
custom_dispose,
|
||||||
custom_reduce,
|
custom_reduce,
|
||||||
|
fused_add_rmsnorm,
|
||||||
|
gemma_fused_add_rmsnorm,
|
||||||
|
gemma_rmsnorm,
|
||||||
get_graph_buffer_ipc_meta,
|
get_graph_buffer_ipc_meta,
|
||||||
init_custom_reduce,
|
init_custom_reduce,
|
||||||
int8_scaled_mm,
|
int8_scaled_mm,
|
||||||
@@ -12,14 +15,17 @@ from sgl_kernel.ops import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"moe_align_block_size",
|
|
||||||
"init_custom_reduce",
|
|
||||||
"custom_dispose",
|
"custom_dispose",
|
||||||
"custom_reduce",
|
"custom_reduce",
|
||||||
"int8_scaled_mm",
|
"fused_add_rmsnorm",
|
||||||
"sampling_scaling_penalties",
|
"gemma_fused_add_rmsnorm",
|
||||||
|
"gemma_rmsnorm",
|
||||||
"get_graph_buffer_ipc_meta",
|
"get_graph_buffer_ipc_meta",
|
||||||
|
"init_custom_reduce",
|
||||||
|
"int8_scaled_mm",
|
||||||
|
"moe_align_block_size",
|
||||||
"register_graph_buffers",
|
"register_graph_buffers",
|
||||||
"rotary_embedding",
|
|
||||||
"rmsnorm",
|
"rmsnorm",
|
||||||
|
"rotary_embedding",
|
||||||
|
"sampling_scaling_penalties",
|
||||||
]
|
]
|
||||||
|
|||||||
@@ -33,6 +33,16 @@ void rotary_embedding(torch::Tensor& positions, torch::Tensor& query, torch::Ten
|
|||||||
// rms norm
|
// rms norm
|
||||||
void rmsnorm(at::Tensor& output, at::Tensor& input, at::Tensor& weight, double eps, int64_t cuda_stream);
|
void rmsnorm(at::Tensor& output, at::Tensor& input, at::Tensor& weight, double eps, int64_t cuda_stream);
|
||||||
|
|
||||||
|
// fused rms norm
|
||||||
|
void fused_add_rmsnorm(at::Tensor& input, at::Tensor& residual, at::Tensor& weight, double eps, int64_t cuda_stream);
|
||||||
|
|
||||||
|
// gemma rms norm
|
||||||
|
void gemma_rmsnorm(at::Tensor& output, at::Tensor& input, at::Tensor& weight, double eps, int64_t cuda_stream);
|
||||||
|
|
||||||
|
// fused gemma rms norm
|
||||||
|
void gemma_fused_add_rmsnorm(at::Tensor& input, at::Tensor& residual, at::Tensor& weight, double eps,
|
||||||
|
int64_t cuda_stream);
|
||||||
|
|
||||||
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
||||||
// trt_reduce
|
// trt_reduce
|
||||||
m.def("init_custom_ar", &init_custom_ar, "init custom allreduce meta (CUDA)");
|
m.def("init_custom_ar", &init_custom_ar, "init custom allreduce meta (CUDA)");
|
||||||
@@ -50,4 +60,10 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
|||||||
m.def("rotary_embedding", &rotary_embedding, "Rotary Embedding (CUDA)");
|
m.def("rotary_embedding", &rotary_embedding, "Rotary Embedding (CUDA)");
|
||||||
// rms norm
|
// rms norm
|
||||||
m.def("rmsnorm", &rmsnorm, "RMSNorm (CUDA)");
|
m.def("rmsnorm", &rmsnorm, "RMSNorm (CUDA)");
|
||||||
|
// fused rms norm
|
||||||
|
m.def("fused_add_rmsnorm", &fused_add_rmsnorm, "Fused Add RMSNorm (CUDA)");
|
||||||
|
// gemma rms norm
|
||||||
|
m.def("gemma_rmsnorm", &gemma_rmsnorm, "Gemma RMSNorm (CUDA)");
|
||||||
|
// fused gemma rms norm
|
||||||
|
m.def("gemma_fused_add_rmsnorm", &gemma_fused_add_rmsnorm, "Gemma Fused Add RMSNorm (CUDA)");
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -3,6 +3,9 @@ from typing import Optional
|
|||||||
import torch
|
import torch
|
||||||
from sgl_kernel.ops._kernels import all_reduce as _all_reduce
|
from sgl_kernel.ops._kernels import all_reduce as _all_reduce
|
||||||
from sgl_kernel.ops._kernels import dispose as _dispose
|
from sgl_kernel.ops._kernels import dispose as _dispose
|
||||||
|
from sgl_kernel.ops._kernels import fused_add_rmsnorm as _fused_add_rmsnorm
|
||||||
|
from sgl_kernel.ops._kernels import gemma_fused_add_rmsnorm as _gemma_fused_add_rmsnorm
|
||||||
|
from sgl_kernel.ops._kernels import gemma_rmsnorm as _gemma_rmsnorm
|
||||||
from sgl_kernel.ops._kernels import (
|
from sgl_kernel.ops._kernels import (
|
||||||
get_graph_buffer_ipc_meta as _get_graph_buffer_ipc_meta,
|
get_graph_buffer_ipc_meta as _get_graph_buffer_ipc_meta,
|
||||||
)
|
)
|
||||||
@@ -17,6 +20,10 @@ from sgl_kernel.ops._kernels import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def get_cuda_stream(device: torch.device) -> int:
    """Return the raw handle of the CUDA stream currently active on ``device``.

    Args:
        device: Device whose current stream is looked up.

    Returns:
        The ``cuda_stream`` attribute of ``torch.cuda.current_stream(device)``,
        in the form the kernel bindings take as their last argument.
    """
    return torch.cuda.current_stream(device).cuda_stream
|
||||||
|
|
||||||
|
|
||||||
def init_custom_reduce(
|
def init_custom_reduce(
|
||||||
rank_id, num_devices, rank_data, buffers, tmp_buffers, barrier_in, barrier_out
|
rank_id, num_devices, rank_data, buffers, tmp_buffers, barrier_in, barrier_out
|
||||||
):
|
):
|
||||||
@@ -88,9 +95,35 @@ def rmsnorm(
|
|||||||
eps: float = 1e-6,
|
eps: float = 1e-6,
|
||||||
out: Optional[torch.Tensor] = None,
|
out: Optional[torch.Tensor] = None,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
|
with input.device as device:
|
||||||
if out is None:
|
if out is None:
|
||||||
out = torch.empty_like(input)
|
out = torch.empty_like(input)
|
||||||
stream = torch.cuda.current_stream().cuda_stream
|
_rmsnorm(out, input, weight, eps, get_cuda_stream(device))
|
||||||
stream_int = int(stream)
|
|
||||||
_rmsnorm(out, input, weight, eps, stream_int)
|
|
||||||
return out
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
def fused_add_rmsnorm(
    input: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6
) -> None:
    """Launch the fused residual-add + RMSNorm kernel.

    Nothing is returned; the kernel writes its results into the tensors it is
    given.
    NOTE(review): presumably ``residual`` receives ``input + residual`` and
    ``input`` receives the normalized result -- confirm against the kernel.

    Args:
        input: Tensor handed to the kernel as its first argument.
        residual: Tensor handed to the kernel as its second argument.
        weight: Normalization weight.
        eps: Epsilon forwarded to the kernel.
    """
    # Bind the tensor's device so the launch uses that device's current stream.
    with input.device as device:
        _fused_add_rmsnorm(input, residual, weight, eps, get_cuda_stream(device))
|
||||||
|
|
||||||
|
|
||||||
|
def gemma_rmsnorm(
    input: torch.Tensor,
    weight: torch.Tensor,
    eps: float = 1e-6,
    out: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Launch the Gemma RMSNorm kernel and return its output tensor.

    Args:
        input: Tensor to normalize.
        weight: Normalization weight.
        eps: Epsilon forwarded to the kernel.
        out: Optional destination tensor. When omitted, a new tensor is
            allocated with ``torch.empty_like(input)``.

    Returns:
        The tensor passed to the kernel as its output (``out`` if supplied,
        otherwise the freshly allocated one).
    """
    # Bind the tensor's device so the launch uses that device's current stream.
    with input.device as device:
        if out is None:
            out = torch.empty_like(input)
        _gemma_rmsnorm(out, input, weight, eps, get_cuda_stream(device))
        return out
|
||||||
|
|
||||||
|
|
||||||
|
def gemma_fused_add_rmsnorm(
    input: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6
) -> None:
    """Launch the Gemma fused residual-add + RMSNorm kernel.

    Nothing is returned; the kernel writes its results into the tensors it is
    given.
    NOTE(review): presumably ``residual`` receives ``input + residual`` and
    ``input`` receives the normalized result -- confirm against the kernel.

    Args:
        input: Tensor handed to the kernel as its first argument.
        residual: Tensor handed to the kernel as its second argument.
        weight: Normalization weight.
        eps: Epsilon forwarded to the kernel.
    """
    # Bind the tensor's device so the launch uses that device's current stream.
    with input.device as device:
        _gemma_fused_add_rmsnorm(input, residual, weight, eps, get_cuda_stream(device))
|
||||||
|
|||||||
129
sgl-kernel/tests/test_norm.py
Normal file
129
sgl-kernel/tests/test_norm.py
Normal file
@@ -0,0 +1,129 @@
|
|||||||
|
# Adapted from https://github.com/flashinfer-ai/flashinfer/blob/4e8eb1879f9c3ba6d75511e5893183bf8f289a62/tests/test_norm.py
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import sgl_kernel
|
||||||
|
import torch
|
||||||
|
|
||||||
|
|
||||||
|
def llama_rms_norm(x, w, eps=1e-6):
    """Reference RMSNorm computed in float32 and cast back to ``x``'s dtype.

    Args:
        x: Input tensor; normalization runs over its last dimension.
        w: Weight tensor multiplied into the normalized values.
        eps: Value added to the mean square before the reciprocal square root.

    Returns:
        Tensor with the same dtype as ``x``.
    """
    orig_dtype = x.dtype
    # Do the arithmetic in float32 regardless of the input precision.
    x = x.float()
    variance = x.pow(2).mean(dim=-1, keepdim=True)
    x = x * torch.rsqrt(variance + eps)
    x = x * w.float()
    x = x.to(orig_dtype)
    return x
|
||||||
|
|
||||||
|
|
||||||
|
def gemma_rms_norm(x, w, eps=1e-6):
    """Reference Gemma-style RMSNorm: like RMSNorm but scaled by ``1 + w``.

    Args:
        x: Input tensor; normalization runs over its last dimension.
        w: Weight tensor; ``1.0 + w`` is multiplied into the normalized values.
        eps: Value added to the mean square before the reciprocal square root.

    Returns:
        Tensor with the same dtype as ``x``.
    """
    orig_dtype = x.dtype
    # Do the arithmetic in float32 regardless of the input precision.
    x = x.float()
    variance = x.pow(2).mean(dim=-1, keepdim=True)
    x = x * torch.rsqrt(variance + eps)
    x = x * (1.0 + w.float())
    x = x.to(orig_dtype)
    return x
|
||||||
|
|
||||||
|
|
||||||
|
def gemma_fused_add_rms_norm(x, residual, w, eps=1e-6):
    """Reference for residual-add followed by Gemma-style RMSNorm.

    The sum ``x + residual`` is formed in the input dtype and returned as the
    new residual; the normalization itself runs in float32 and is scaled by
    ``1 + w``.

    Args:
        x: Input tensor.
        residual: Tensor added to ``x`` before normalizing.
        w: Weight tensor; ``1.0 + w`` is multiplied into the normalized values.
        eps: Value added to the mean square before the reciprocal square root.

    Returns:
        ``(normalized, new_residual)`` where ``normalized`` has ``x``'s
        original dtype and ``new_residual`` is ``x + residual``.
    """
    orig_dtype = x.dtype
    x = x + residual
    # Keep the pre-normalization sum: it is the updated residual.
    residual = x
    x = x.float()
    variance = x.pow(2).mean(dim=-1, keepdim=True)
    x = x * torch.rsqrt(variance + eps)
    x = x * (1.0 + w.float())
    x = x.to(orig_dtype)
    return x, residual
|
||||||
|
|
||||||
|
|
||||||
|
def fused_add_rms_norm(x, residual, weight, eps=1e-6):
    """Reference for residual-add followed by RMSNorm.

    Both the addition and the normalization run in float32; the sum is cast
    back to the input dtype to form the new residual.

    Args:
        x: Input tensor.
        residual: Tensor added to ``x`` before normalizing.
        weight: Weight tensor multiplied into the normalized values.
        eps: Value added to the mean square before the reciprocal square root.
            Defaults to ``1e-6`` like the other reference helpers.

    Returns:
        ``(normalized, new_residual)``, both in ``x``'s original dtype.
    """
    orig_dtype = x.dtype
    x = x.to(torch.float32)
    x = x + residual.to(torch.float32)
    # The updated residual is the float32 sum rounded back to the input dtype.
    residual = x.to(orig_dtype)

    variance = x.pow(2).mean(dim=-1, keepdim=True)
    x = x * torch.rsqrt(variance + eps)
    x = (x * weight.float()).to(orig_dtype)
    return x, residual
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("batch_size", [1, 19, 99, 989])
@pytest.mark.parametrize("hidden_size", [111, 500, 1024, 3072, 3584, 4096, 8192, 16384])
@pytest.mark.parametrize("dtype", [torch.float16])
@pytest.mark.parametrize("specify_out", [True, False])
def test_norm(batch_size, hidden_size, dtype, specify_out):
    """Check ``sgl_kernel.rmsnorm`` against ``llama_rms_norm``.

    Covers both call styles: writing into a caller-supplied ``out`` tensor and
    letting the wrapper allocate the result.
    """
    x = torch.randn(batch_size, hidden_size).to(0).to(dtype)
    w = torch.randn(hidden_size).to(0).to(dtype)

    y_ref = llama_rms_norm(x, w)
    if specify_out:
        y = torch.empty_like(x)
        sgl_kernel.rmsnorm(x, w, out=y)
    else:
        y = sgl_kernel.rmsnorm(x, w)

    torch.testing.assert_close(y_ref, y, rtol=1e-3, atol=1e-3)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("batch_size", [1, 19, 99, 989])
@pytest.mark.parametrize("hidden_size", [111, 500, 1024, 3072, 3584, 4096, 8192, 16384])
@pytest.mark.parametrize("dtype", [torch.float16])
def test_fused_add_rmsnorm(batch_size, hidden_size, dtype):
    """Check ``sgl_kernel.fused_add_rmsnorm`` against ``fused_add_rms_norm``.

    The kernel is run on clones and both tensors it was given are compared
    with the reference outputs afterwards.
    """
    eps = 1e-6

    x = torch.randn(batch_size, hidden_size, dtype=dtype, device="cuda")
    residual = torch.randn_like(x)
    weight = torch.randn(hidden_size, dtype=dtype, device="cuda")

    x_native, residual_native = fused_add_rms_norm(
        x.clone(), residual.clone(), weight, eps
    )

    # Clone so the kernel's writes do not touch the tensors used above.
    x_fused = x.clone()
    residual_fused = residual.clone()
    sgl_kernel.fused_add_rmsnorm(x_fused, residual_fused, weight, eps)

    torch.testing.assert_close(x_fused, x_native, rtol=1e-3, atol=1e-3)
    torch.testing.assert_close(residual_fused, residual_native, rtol=1e-3, atol=1e-3)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("batch_size", [1, 19, 99, 989])
@pytest.mark.parametrize("hidden_size", [111, 500, 1024, 3072, 3584, 4096, 8192, 16384])
@pytest.mark.parametrize("dtype", [torch.float16])
@pytest.mark.parametrize("specify_out", [True, False])
def test_gemma_norm(batch_size, hidden_size, dtype, specify_out):
    """Check ``sgl_kernel.gemma_rmsnorm`` against ``gemma_rms_norm``.

    Covers both call styles: writing into a caller-supplied ``out`` tensor and
    letting the wrapper allocate the result.
    """
    x = torch.randn(batch_size, hidden_size).to(0).to(dtype)
    w = torch.randn(hidden_size).to(0).to(dtype)

    y_ref = gemma_rms_norm(x, w)
    if specify_out:
        y = torch.empty_like(x)
        sgl_kernel.gemma_rmsnorm(x, w, out=y)
    else:
        y = sgl_kernel.gemma_rmsnorm(x, w)

    torch.testing.assert_close(y_ref, y, rtol=1e-3, atol=1e-3)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("batch_size", [1, 19, 99, 989])
@pytest.mark.parametrize("hidden_size", [111, 500, 1024, 3072, 3584, 4096, 8192, 16384])
@pytest.mark.parametrize("dtype", [torch.float16])
def test_gemma_fused_add_rmsnorm(batch_size, hidden_size, dtype):
    """Check ``sgl_kernel.gemma_fused_add_rmsnorm`` against its reference.

    The kernel is run on clones and both tensors it was given are compared
    with the outputs of ``gemma_fused_add_rms_norm`` afterwards.
    """
    eps = 1e-6

    x = torch.randn(batch_size, hidden_size, dtype=dtype, device="cuda")
    residual = torch.randn_like(x)
    weight = torch.randn(hidden_size, dtype=dtype, device="cuda")

    x_native, residual_native = gemma_fused_add_rms_norm(
        x.clone(), residual.clone(), weight, eps
    )

    # Clone so the kernel's writes do not touch the tensors used above.
    x_fused = x.clone()
    residual_fused = residual.clone()
    sgl_kernel.gemma_fused_add_rmsnorm(x_fused, residual_fused, weight, eps)

    torch.testing.assert_close(x_fused, x_native, rtol=1e-3, atol=1e-3)
    torch.testing.assert_close(residual_fused, residual_native, rtol=1e-3, atol=1e-3)
|
||||||
@@ -1,31 +0,0 @@
|
|||||||
import pytest
|
|
||||||
import torch
|
|
||||||
from sgl_kernel import rmsnorm
|
|
||||||
|
|
||||||
|
|
||||||
def llama_rms_norm(x, w, eps=1e-6):
    """Reference RMSNorm computed in float32 and cast back to ``x``'s dtype.

    Args:
        x: Input tensor; normalization runs over its last dimension.
        w: Weight tensor multiplied into the normalized values.
        eps: Value added to the mean square before the reciprocal square root.

    Returns:
        Tensor with the same dtype as ``x``.
    """
    orig_dtype = x.dtype
    # Do the arithmetic in float32 regardless of the input precision.
    x = x.float()
    variance = x.pow(2).mean(dim=-1, keepdim=True)
    x = x * torch.rsqrt(variance + eps)
    x = x * w.float()
    x = x.to(orig_dtype)
    return x
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("batch_size", [1, 19, 99, 989])
@pytest.mark.parametrize("hidden_size", [111, 500, 1024, 3072, 3584, 4096, 8192, 16384])
@pytest.mark.parametrize("dtype", [torch.float16])
@pytest.mark.parametrize("specify_out", [True, False])
def test_norm(batch_size, hidden_size, dtype, specify_out):
    """Check ``rmsnorm`` against ``llama_rms_norm``.

    Covers both call styles: writing into a caller-supplied ``out`` tensor and
    letting the wrapper allocate the result.
    """
    x = torch.randn(batch_size, hidden_size).to(0).to(dtype)
    w = torch.randn(hidden_size).to(0).to(dtype)

    y_ref = llama_rms_norm(x, w)
    if specify_out:
        y = torch.empty_like(x)
        rmsnorm(x, w, out=y)
    else:
        y = rmsnorm(x, w)

    torch.testing.assert_close(y_ref, y, rtol=1e-3, atol=1e-3)
|
|
||||||
Reference in New Issue
Block a user