From 7353fb9b97705c89d205aa3477b446759fcb86b7 Mon Sep 17 00:00:00 2001
From: Yineng Zhang
Date: Wed, 22 Jan 2025 21:32:48 +0800
Subject: [PATCH] feat: integrate norm kernels into sgl-kernel (#3052)

---
 sgl-kernel/src/sgl-kernel/__init__.py     |  16 ++-
 .../src/sgl-kernel/csrc/sgl_kernel_ops.cu |  16 +++
 sgl-kernel/src/sgl-kernel/ops/__init__.py |  45 +++++-
 sgl-kernel/tests/test_norm.py             | 129 ++++++++++++++++++
 sgl-kernel/tests/test_rmsnorm.py          |  31 -----
 5 files changed, 195 insertions(+), 42 deletions(-)
 create mode 100644 sgl-kernel/tests/test_norm.py
 delete mode 100644 sgl-kernel/tests/test_rmsnorm.py

diff --git a/sgl-kernel/src/sgl-kernel/__init__.py b/sgl-kernel/src/sgl-kernel/__init__.py
index 3352abeb5..bdbc0ce84 100644
--- a/sgl-kernel/src/sgl-kernel/__init__.py
+++ b/sgl-kernel/src/sgl-kernel/__init__.py
@@ -1,6 +1,9 @@
 from sgl_kernel.ops import (
     custom_dispose,
     custom_reduce,
+    fused_add_rmsnorm,
+    gemma_fused_add_rmsnorm,
+    gemma_rmsnorm,
     get_graph_buffer_ipc_meta,
     init_custom_reduce,
     int8_scaled_mm,
@@ -12,14 +15,17 @@ from sgl_kernel.ops import (
 )
 
 __all__ = [
-    "moe_align_block_size",
-    "init_custom_reduce",
     "custom_dispose",
     "custom_reduce",
-    "int8_scaled_mm",
-    "sampling_scaling_penalties",
+    "fused_add_rmsnorm",
+    "gemma_fused_add_rmsnorm",
+    "gemma_rmsnorm",
     "get_graph_buffer_ipc_meta",
+    "init_custom_reduce",
+    "int8_scaled_mm",
+    "moe_align_block_size",
     "register_graph_buffers",
-    "rotary_embedding",
     "rmsnorm",
+    "rotary_embedding",
+    "sampling_scaling_penalties",
 ]

diff --git a/sgl-kernel/src/sgl-kernel/csrc/sgl_kernel_ops.cu b/sgl-kernel/src/sgl-kernel/csrc/sgl_kernel_ops.cu
index ed359bfbb..8f9d1ae53 100644
--- a/sgl-kernel/src/sgl-kernel/csrc/sgl_kernel_ops.cu
+++ b/sgl-kernel/src/sgl-kernel/csrc/sgl_kernel_ops.cu
@@ -33,6 +33,16 @@ void rotary_embedding(torch::Tensor& positions, torch::Tensor& query, torch::Ten
 // rms norm
 void rmsnorm(at::Tensor& output, at::Tensor& input, at::Tensor& weight, double eps, int64_t cuda_stream);
 
+// fused rms norm
+void fused_add_rmsnorm(at::Tensor& input, at::Tensor& residual, at::Tensor& weight, double eps, int64_t cuda_stream);
+
+// gemma rms norm
+void gemma_rmsnorm(at::Tensor& output, at::Tensor& input, at::Tensor& weight, double eps, int64_t cuda_stream);
+
+// fused gemma rms norm
+void gemma_fused_add_rmsnorm(at::Tensor& input, at::Tensor& residual, at::Tensor& weight, double eps,
+                             int64_t cuda_stream);
+
 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
   // trt_reduce
   m.def("init_custom_ar", &init_custom_ar, "init custom allreduce meta (CUDA)");
@@ -50,4 +60,10 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
   m.def("rotary_embedding", &rotary_embedding, "Rotary Embedding (CUDA)");
   // rms norm
   m.def("rmsnorm", &rmsnorm, "RMSNorm (CUDA)");
+  // fused rms norm
+  m.def("fused_add_rmsnorm", &fused_add_rmsnorm, "Fused Add RMSNorm (CUDA)");
+  // gemma rms norm
+  m.def("gemma_rmsnorm", &gemma_rmsnorm, "Gemma RMSNorm (CUDA)");
+  // fused gemma rms norm
+  m.def("gemma_fused_add_rmsnorm", &gemma_fused_add_rmsnorm, "Gemma Fused Add RMSNorm (CUDA)");
 }

diff --git a/sgl-kernel/src/sgl-kernel/ops/__init__.py b/sgl-kernel/src/sgl-kernel/ops/__init__.py
index e9eadb759..bbfd76878 100644
--- a/sgl-kernel/src/sgl-kernel/ops/__init__.py
+++ b/sgl-kernel/src/sgl-kernel/ops/__init__.py
@@ -3,6 +3,9 @@ from typing import Optional
 import torch
 from sgl_kernel.ops._kernels import all_reduce as _all_reduce
 from sgl_kernel.ops._kernels import dispose as _dispose
+from sgl_kernel.ops._kernels import fused_add_rmsnorm as _fused_add_rmsnorm
+from sgl_kernel.ops._kernels import gemma_fused_add_rmsnorm as _gemma_fused_add_rmsnorm
+from sgl_kernel.ops._kernels import gemma_rmsnorm as _gemma_rmsnorm
 from sgl_kernel.ops._kernels import (
     get_graph_buffer_ipc_meta as _get_graph_buffer_ipc_meta,
 )
@@ -17,6 +20,10 @@ from sgl_kernel.ops._kernels import (
 )
 
 
+def get_cuda_stream(device: torch.device) -> int:
+    return torch.cuda.current_stream(device).cuda_stream
+
+
 def init_custom_reduce(
     rank_id, num_devices, rank_data, buffers, tmp_buffers, barrier_in, barrier_out
 ):
@@ -88,9 +95,35 @@ def rmsnorm(
     eps: float = 1e-6,
     out: Optional[torch.Tensor] = None,
 ) -> torch.Tensor:
-    if out is None:
-        out = torch.empty_like(input)
-    stream = torch.cuda.current_stream().cuda_stream
-    stream_int = int(stream)
-    _rmsnorm(out, input, weight, eps, stream_int)
-    return out
+    with input.device as device:
+        if out is None:
+            out = torch.empty_like(input)
+        _rmsnorm(out, input, weight, eps, get_cuda_stream(device))
+        return out
+
+
+def fused_add_rmsnorm(
+    input: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6
+) -> None:
+    with input.device as device:
+        _fused_add_rmsnorm(input, residual, weight, eps, get_cuda_stream(device))
+
+
+def gemma_rmsnorm(
+    input: torch.Tensor,
+    weight: torch.Tensor,
+    eps: float = 1e-6,
+    out: Optional[torch.Tensor] = None,
+) -> torch.Tensor:
+    with input.device as device:
+        if out is None:
+            out = torch.empty_like(input)
+        _gemma_rmsnorm(out, input, weight, eps, get_cuda_stream(device))
+        return out
+
+
+def gemma_fused_add_rmsnorm(
+    input: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6
+) -> None:
+    with input.device as device:
+        _gemma_fused_add_rmsnorm(input, residual, weight, eps, get_cuda_stream(device))

diff --git a/sgl-kernel/tests/test_norm.py b/sgl-kernel/tests/test_norm.py
new file mode 100644
index 000000000..32f8c25d9
--- /dev/null
+++ b/sgl-kernel/tests/test_norm.py
@@ -0,0 +1,129 @@
+# Adapted from https://github.com/flashinfer-ai/flashinfer/blob/4e8eb1879f9c3ba6d75511e5893183bf8f289a62/tests/test_norm.py
+
+import pytest
+import sgl_kernel
+import torch
+
+
+def llama_rms_norm(x, w, eps=1e-6):
+    orig_dtype = x.dtype
+    x = x.float()
+    variance = x.pow(2).mean(dim=-1, keepdim=True)
+    x = x * torch.rsqrt(variance + eps)
+    x = x * w.float()
+    x = x.to(orig_dtype)
+    return x
+
+
+def gemma_rms_norm(x, w, eps=1e-6):
+    orig_dtype = x.dtype
+    x = x.float()
+    variance = x.pow(2).mean(dim=-1, keepdim=True)
+    x = x * torch.rsqrt(variance + eps)
+    x = x * (1.0 + w.float())
+    x = x.to(orig_dtype)
+    return x
+
+
+def gemma_fused_add_rms_norm(x, residual, w, eps=1e-6):
+    orig_dtype = x.dtype
+    x = x + residual
+    residual = x
+    x = x.float()
+    variance = x.pow(2).mean(dim=-1, keepdim=True)
+    x = x * torch.rsqrt(variance + eps)
+    x = x * (1.0 + w.float())
+    x = x.to(orig_dtype)
+    return x, residual
+
+
+def fused_add_rms_norm(x, residual, weight, eps):
+    orig_dtype = x.dtype
+    x = x.to(torch.float32)
+    x = x + residual.to(torch.float32)
+    residual = x.to(orig_dtype)
+
+    variance = x.pow(2).mean(dim=-1, keepdim=True)
+    x = x * torch.rsqrt(variance + eps)
+    x = (x * weight.float()).to(orig_dtype)
+    return x, residual
+
+
+@pytest.mark.parametrize("batch_size", [1, 19, 99, 989])
+@pytest.mark.parametrize("hidden_size", [111, 500, 1024, 3072, 3584, 4096, 8192, 16384])
+@pytest.mark.parametrize("dtype", [torch.float16])
+@pytest.mark.parametrize("specify_out", [True, False])
+def test_norm(batch_size, hidden_size, dtype, specify_out):
+    x = torch.randn(batch_size, hidden_size).to(0).to(dtype)
+    w = torch.randn(hidden_size).to(0).to(dtype)
+
+    y_ref = llama_rms_norm(x, w)
+    if specify_out:
+        y = torch.empty_like(x)
+        sgl_kernel.rmsnorm(x, w, out=y)
+    else:
+        y = sgl_kernel.rmsnorm(x, w)
+
+    torch.testing.assert_close(y_ref, y, rtol=1e-3, atol=1e-3)
+
+
+@pytest.mark.parametrize("batch_size", [1, 19, 99, 989])
+@pytest.mark.parametrize("hidden_size", [111, 500, 1024, 3072, 3584, 4096, 8192, 16384])
+@pytest.mark.parametrize("dtype", [torch.float16])
+def test_fused_add_rmsnorm(batch_size, hidden_size, dtype):
+    eps = 1e-6
+
+    x = torch.randn(batch_size, hidden_size, dtype=dtype, device="cuda")
+    residual = torch.randn_like(x)
+    weight = torch.randn(hidden_size, dtype=dtype, device="cuda")
+
+    x_native, residual_native = fused_add_rms_norm(
+        x.clone(), residual.clone(), weight, eps
+    )
+
+    x_fused = x.clone()
+    residual_fused = residual.clone()
+    sgl_kernel.fused_add_rmsnorm(x_fused, residual_fused, weight, eps)
+
+    torch.testing.assert_close(x_fused, x_native, rtol=1e-3, atol=1e-3)
+    torch.testing.assert_close(residual_fused, residual_native, rtol=1e-3, atol=1e-3)
+
+
+@pytest.mark.parametrize("batch_size", [1, 19, 99, 989])
+@pytest.mark.parametrize("hidden_size", [111, 500, 1024, 3072, 3584, 4096, 8192, 16384])
+@pytest.mark.parametrize("dtype", [torch.float16])
+@pytest.mark.parametrize("specify_out", [True, False])
+def test_gemma_norm(batch_size, hidden_size, dtype, specify_out):
+    x = torch.randn(batch_size, hidden_size).to(0).to(dtype)
+    w = torch.randn(hidden_size).to(0).to(dtype)
+
+    y_ref = gemma_rms_norm(x, w)
+    if specify_out:
+        y = torch.empty_like(x)
+        sgl_kernel.gemma_rmsnorm(x, w, out=y)
+    else:
+        y = sgl_kernel.gemma_rmsnorm(x, w)
+
+    torch.testing.assert_close(y_ref, y, rtol=1e-3, atol=1e-3)
+
+
+@pytest.mark.parametrize("batch_size", [1, 19, 99, 989])
+@pytest.mark.parametrize("hidden_size", [111, 500, 1024, 3072, 3584, 4096, 8192, 16384])
+@pytest.mark.parametrize("dtype", [torch.float16])
+def test_gemma_fused_add_rmsnorm(batch_size, hidden_size, dtype):
+    eps = 1e-6
+
+    x = torch.randn(batch_size, hidden_size, dtype=dtype, device="cuda")
+    residual = torch.randn_like(x)
+    weight = torch.randn(hidden_size, dtype=dtype, device="cuda")
+
+    x_native, residual_native = gemma_fused_add_rms_norm(
+        x.clone(), residual.clone(), weight, eps
+    )
+
+    x_fused = x.clone()
+    residual_fused = residual.clone()
+    sgl_kernel.gemma_fused_add_rmsnorm(x_fused, residual_fused, weight, eps)
+
+    torch.testing.assert_close(x_fused, x_native, rtol=1e-3, atol=1e-3)
+    torch.testing.assert_close(residual_fused, residual_native, rtol=1e-3, atol=1e-3)

diff --git a/sgl-kernel/tests/test_rmsnorm.py b/sgl-kernel/tests/test_rmsnorm.py
deleted file mode 100644
index dda225de9..000000000
--- a/sgl-kernel/tests/test_rmsnorm.py
+++ /dev/null
@@ -1,31 +0,0 @@
-import pytest
-import torch
-from sgl_kernel import rmsnorm
-
-
-def llama_rms_norm(x, w, eps=1e-6):
-    orig_dtype = x.dtype
-    x = x.float()
-    variance = x.pow(2).mean(dim=-1, keepdim=True)
-    x = x * torch.rsqrt(variance + eps)
-    x = x * w.float()
-    x = x.to(orig_dtype)
-    return x
-
-
-@pytest.mark.parametrize("batch_size", [1, 19, 99, 989])
-@pytest.mark.parametrize("hidden_size", [111, 500, 1024, 3072, 3584, 4096, 8192, 16384])
-@pytest.mark.parametrize("dtype", [torch.float16])
-@pytest.mark.parametrize("specify_out", [True, False])
-def test_norm(batch_size, hidden_size, dtype, specify_out):
-    x = torch.randn(batch_size, hidden_size).to(0).to(dtype)
-    w = torch.randn(hidden_size).to(0).to(dtype)
-
-    y_ref = llama_rms_norm(x, w)
-    if specify_out:
-        y = torch.empty_like(x)
-        rmsnorm(x, w, out=y)
-    else:
-        y = rmsnorm(x, w)
-
-    torch.testing.assert_close(y_ref, y, rtol=1e-3, atol=1e-3)
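
A minimal usage sketch of the Python wrappers this patch exposes, assuming sgl-kernel is built with these kernels and a CUDA device is available; the tensor shapes and eps values below are illustrative only.

    import torch
    import sgl_kernel

    x = torch.randn(8, 4096, dtype=torch.float16, device="cuda")
    residual = torch.randn_like(x)
    weight = torch.randn(4096, dtype=torch.float16, device="cuda")

    # Out-of-place RMSNorm; pass out= to reuse a preallocated buffer.
    y = sgl_kernel.rmsnorm(x, weight, eps=1e-6)

    # Fused residual-add + RMSNorm updates x and residual in place and returns None.
    sgl_kernel.fused_add_rmsnorm(x, residual, weight, eps=1e-6)

    # Gemma variants scale by (1.0 + weight) instead of weight.
    y_gemma = sgl_kernel.gemma_rmsnorm(x, weight)
    sgl_kernel.gemma_fused_add_rmsnorm(x, residual, weight)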