From 827aa8730b7c3965a01d55b72b66d244a0d20ddd Mon Sep 17 00:00:00 2001
From: Yineng Zhang
Date: Mon, 27 Jan 2025 19:11:01 +0800
Subject: [PATCH] cleanup sgl-kernel kernels (#3175)

---
Notes (free-form area; not part of the commit message or the diff):

What the patch does
 * Adds csrc/fused_add_rms_norm_kernel.cu with FusedAddRMSNormKernel, the
   FusedAddRMSNorm launcher and the sgl_fused_add_rmsnorm binding, and
   registers that binding as the `fused_add_rmsnorm` op. The op no longer
   takes `cuda_stream`; the binding reads at::cuda::getCurrentCUDAStream().
 * Removes the rotary_embedding op from the setup.py sources, the ops header,
   the Python wrappers/exports and the op registry.
 * Sets `fail-fast: false` on the CI matrix and adds torch.float32 to
   test_fused_add_rmsnorm.

Review remarks
 * sgl_fused_add_rmsnorm selects c_type from input.scalar_type() and casts the
   residual and weight pointers to it. No dtype equality check is visible in
   this patch; unless CHECK_INPUT (utils.h, not shown) covers it, consider
   adding one.
 * The Python wrapper still binds `device` in `with input.device as device:`,
   although the name is no longer used after this change.
 * In the second loop of the kernel every element of input_vec is overwritten
   after being loaded; the load looks redundant (TODO confirm before removing).

Restoration caveat
 * This copy of the patch had its line breaks collapsed and every `<...>` span
   dropped. Template parameter lists, vec_t arguments, static_cast targets,
   the std::min type argument and the #include targets below were
   reconstructed from how the names are used, and indentation was re-created.
   The author address was not restored. Verify against the upstream commit
   before applying.

 .github/workflows/pr-test.yml                 |   1 +
 sgl-kernel/setup.py                           |   2 +-
 sgl-kernel/src/sgl-kernel/__init__.py         |   2 -
 .../csrc/fused_add_rms_norm_kernel.cu         | 140 ++++++++++++++++++
 .../src/sgl-kernel/include/sgl_kernels_ops.h  |   6 +-
 sgl-kernel/src/sgl-kernel/ops/__init__.py     |  10 +-
 sgl-kernel/src/sgl-kernel/torch_extension.cc  |  10 +-
 sgl-kernel/tests/test_norm.py                 |   2 +-
 8 files changed, 147 insertions(+), 26 deletions(-)
 create mode 100644 sgl-kernel/src/sgl-kernel/csrc/fused_add_rms_norm_kernel.cu

diff --git a/.github/workflows/pr-test.yml b/.github/workflows/pr-test.yml
index 28fbec903..6ed6046ee 100644
--- a/.github/workflows/pr-test.yml
+++ b/.github/workflows/pr-test.yml
@@ -51,6 +51,7 @@ jobs:
     if: (github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request') && github.event.pull_request.draft == false
     runs-on: 1-gpu-runner
     strategy:
+      fail-fast: false
       matrix:
         range: [0-6, 6-15, 15-22, 22-32, 32-40, 40-100]
     steps:
diff --git a/sgl-kernel/setup.py b/sgl-kernel/setup.py
index 645d8070d..f887f5c19 100644
--- a/sgl-kernel/setup.py
+++ b/sgl-kernel/setup.py
@@ -88,7 +88,7 @@ sources = [
     "src/sgl-kernel/csrc/int8_gemm_kernel.cu",
     "src/sgl-kernel/csrc/fp8_gemm_kernel.cu",
     "src/sgl-kernel/csrc/lightning_attention_decode_kernel.cu",
-    "src/sgl-kernel/csrc/rotary_embedding.cu",
+    "src/sgl-kernel/csrc/fused_add_rms_norm_kernel.cu",
     "3rdparty/flashinfer/csrc/activation.cu",
     "3rdparty/flashinfer/csrc/bmm_fp8.cu",
     "3rdparty/flashinfer/csrc/norm.cu",
diff --git a/sgl-kernel/src/sgl-kernel/__init__.py b/sgl-kernel/src/sgl-kernel/__init__.py
index e82eece48..a3d35072d 100644
--- a/sgl-kernel/src/sgl-kernel/__init__.py
+++ b/sgl-kernel/src/sgl-kernel/__init__.py
@@ -17,7 +17,6 @@ from sgl_kernel.ops import (
     moe_align_block_size,
     register_graph_buffers,
     rmsnorm,
-    rotary_embedding,
     sampling_scaling_penalties,
     silu_and_mul,
     top_k_renorm_prob,
@@ -44,7 +43,6 @@ __all__ = [
     "moe_align_block_size",
     "register_graph_buffers",
     "rmsnorm",
-    "rotary_embedding",
     "sampling_scaling_penalties",
     "silu_and_mul",
     "top_k_renorm_prob",
diff --git a/sgl-kernel/src/sgl-kernel/csrc/fused_add_rms_norm_kernel.cu b/sgl-kernel/src/sgl-kernel/csrc/fused_add_rms_norm_kernel.cu
new file mode 100644
index 000000000..4c4ecb966
--- /dev/null
+++ b/sgl-kernel/src/sgl-kernel/csrc/fused_add_rms_norm_kernel.cu
@@ -0,0 +1,140 @@
+// Adapted from https://github.com/flashinfer-ai/flashinfer/blob/v0.1.6/include/flashinfer/norm.cuh
+// and https://github.com/flashinfer-ai/flashinfer/blob/v0.1.6/python/csrc/norm.cu
+// TODO(zhyncs): tmp fix, v0.1.6 enables SGLang e2e to pass CIs unlike v0.2.0
+
+#include <ATen/cuda/CUDAContext.h>
+
+#include <flashinfer/math.cuh>
+#include <flashinfer/utils.cuh>
+#include <flashinfer/vec_dtypes.cuh>
+#include <numeric>
+
+#include "utils.h"
+
+using namespace flashinfer;
+
+template <uint32_t VEC_SIZE, typename T>
+__global__ void FusedAddRMSNormKernel(T* __restrict__ input, T* __restrict__ residual, T* __restrict__ weight,
+                                      const uint32_t d, float eps) {
+  const uint32_t bx = blockIdx.x;
+  const uint32_t tx = threadIdx.x, ty = threadIdx.y;
+  constexpr uint32_t warp_size = 32;
+  const uint32_t num_warps = blockDim.y;
+  const uint32_t thread_id = tx + ty * warp_size;
+  const uint32_t num_threads = num_warps * warp_size;
+  const uint32_t rounds = ceil_div(d, VEC_SIZE * num_threads);
+  extern __shared__ float smem[];
+
+  float sum_sq = 0.f;
+
+  for (uint32_t i = 0; i < rounds; i++) {
+    vec_t<T, VEC_SIZE> input_vec;
+    input_vec.fill(0.f);
+    vec_t<T, VEC_SIZE> residual_vec;
+    residual_vec.fill(0.f);
+    if ((i * num_threads + thread_id) * VEC_SIZE < d) {
+      input_vec.load(input + bx * d + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE);
+      residual_vec.load(residual + bx * d + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE);
+    }
+#pragma unroll
+    for (uint32_t j = 0; j < VEC_SIZE; j++) {
+      float x = float(input_vec[j]);
+      x += float(residual_vec[j]);
+      sum_sq += x * x;
+      residual_vec[j] = (T)x;
+    }
+    if ((i * num_threads + thread_id) * VEC_SIZE < d) {
+      residual_vec.store(residual + bx * d + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE);
+    }
+  }
+
+  // first, warp reduce sum
+#pragma unroll
+  for (uint32_t offset = warp_size / 2; offset > 0; offset /= 2) {
+    sum_sq += math::shfl_xor_sync(sum_sq, offset);
+  }
+
+  smem[ty] = sum_sq;
+  __syncthreads();
+  // then, cross warp reduce sum using only the first warp
+  if (ty == 0) {
+    sum_sq = (tx < num_warps) ? smem[tx] : 0.f;
+#pragma unroll
+    for (uint32_t offset = warp_size / 2; offset > 0; offset /= 2) {
+      sum_sq += math::shfl_xor_sync(sum_sq, offset);
+    }
+    smem[0] = sum_sq;
+  }
+  __syncthreads();
+
+  float rms_rcp = math::rsqrt(smem[0] / float(d) + eps);
+
+  for (uint32_t i = 0; i < rounds; i++) {
+    vec_t<T, VEC_SIZE> input_vec;
+    vec_t<T, VEC_SIZE> weight_vec;
+    vec_t<T, VEC_SIZE> residual_vec;
+    input_vec.fill(0.f);
+    weight_vec.fill(0.f);
+    residual_vec.fill(0.f);
+    if ((i * num_threads + thread_id) * VEC_SIZE < d) {
+      input_vec.load(input + bx * d + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE);
+      weight_vec.load(weight + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE);
+      residual_vec.load(residual + bx * d + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE);
+    }
+#pragma unroll
+    for (uint32_t j = 0; j < VEC_SIZE; j++) {
+      input_vec[j] = float(residual_vec[j]) * rms_rcp * float(weight_vec[j]);
+    }
+    if ((i * num_threads + thread_id) * VEC_SIZE < d) {
+      input_vec.store(input + bx * d + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE);
+    }
+  }
+}
+
+template <typename T>
+cudaError_t FusedAddRMSNorm(T* input, T* residual, T* weight, uint32_t batch_size, uint32_t d, float eps = 1e-5,
+                            cudaStream_t stream = 0) {
+  const uint32_t vec_size = std::gcd(16 / sizeof(T), d);
+
+  const uint32_t block_size = std::min<uint32_t>(1024, d / vec_size);
+  const uint32_t num_warps = ceil_div(block_size, 32);
+  dim3 nblks(batch_size);
+  dim3 nthrs(32, num_warps);
+  const uint32_t smem_size = num_warps * sizeof(float);
+  void* args[] = {&input, &residual, &weight, &d, &eps};
+
+  DISPATCH_ALIGNED_VEC_SIZE(vec_size, VEC_SIZE, {
+    auto kernel = FusedAddRMSNormKernel<VEC_SIZE, T>;
+    FLASHINFER_CUDA_CALL(cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream));
+  });
+
+  return cudaSuccess;
+}
+
+void sgl_fused_add_rmsnorm(torch::Tensor input, torch::Tensor residual, torch::Tensor weight, double eps) {
+  CHECK_INPUT(input);
+  CHECK_INPUT(residual);
+  CHECK_INPUT(weight);
+  auto device = input.device();
+  CHECK_EQ(residual.device(), device);
+  CHECK_EQ(weight.device(), device);
+  CHECK_DIM(2, input);     // input: (batch_size, hidden_size)
+  CHECK_DIM(2, residual);  // residual: (batch_size, hidden_size)
+  CHECK_DIM(1, weight);    // weight: (hidden_size)
+  CHECK_EQ(input.size(0), residual.size(0));
+  CHECK_EQ(input.size(1), residual.size(1));
+  CHECK_EQ(input.size(1), weight.size(0));
+  unsigned int batch_size = input.size(0);
+  unsigned int hidden_size = input.size(1);
+
+  cudaStream_t torch_current_stream = at::cuda::getCurrentCUDAStream();
+  // support float16, bfloat16 and float32
+  DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FLOAT_FP16(input.scalar_type(), c_type, [&] {
+    cudaError_t status =
+        FusedAddRMSNorm(static_cast<c_type*>(input.data_ptr()), static_cast<c_type*>(residual.data_ptr()),
+                        static_cast<c_type*>(weight.data_ptr()), batch_size, hidden_size, eps, torch_current_stream);
+    TORCH_CHECK(status == cudaSuccess,
+                "FusedAddRMSNorm failed with error code " + std::string(cudaGetErrorString(status)));
+    return true;
+  });
+}
diff --git a/sgl-kernel/src/sgl-kernel/include/sgl_kernels_ops.h b/sgl-kernel/src/sgl-kernel/include/sgl_kernels_ops.h
index f03a09364..c5cc30c18 100644
--- a/sgl-kernel/src/sgl-kernel/include/sgl_kernels_ops.h
+++ b/sgl-kernel/src/sgl-kernel/include/sgl_kernels_ops.h
@@ -50,15 +50,11 @@ void lightning_attention_decode(const torch::Tensor& q, const torch::Tensor& k,
                                 const torch::Tensor& past_kv, const torch::Tensor& slope, torch::Tensor output,
                                 torch::Tensor new_kv);
 
-// rotary embedding
-void rotary_embedding(torch::Tensor& positions, torch::Tensor& query, torch::Tensor& key, int64_t head_size,
-                      torch::Tensor& cos_sin_cache, bool is_neox);
-
 // rms norm
 void rmsnorm(at::Tensor& output, at::Tensor& input, at::Tensor& weight, double eps, int64_t cuda_stream);
 
 // fused rms norm
-void fused_add_rmsnorm(at::Tensor& input, at::Tensor& residual, at::Tensor& weight, double eps, int64_t cuda_stream);
+void sgl_fused_add_rmsnorm(torch::Tensor input, torch::Tensor residual, torch::Tensor weight, double eps);
 
 // gemma rms norm
 void gemma_rmsnorm(at::Tensor& output, at::Tensor& input, at::Tensor& weight, double eps, int64_t cuda_stream);
diff --git a/sgl-kernel/src/sgl-kernel/ops/__init__.py b/sgl-kernel/src/sgl-kernel/ops/__init__.py
index 2fa1d9579..5aa484ff5 100644
--- a/sgl-kernel/src/sgl-kernel/ops/__init__.py
+++ b/sgl-kernel/src/sgl-kernel/ops/__init__.py
@@ -142,12 +142,6 @@ def lightning_attention_decode(q, k, v, past_kv, slope, output, new_kv):
     )
 
 
-def rotary_embedding(positions, query, key, head_size, cos_sin_cache, is_neox):
-    return torch.ops.sgl_kernels.rotary_embedding(
-        positions, query, key, head_size, cos_sin_cache, is_neox
-    )
-
-
 # These implementations extensively draw from and build upon the FlashInfer project https://github.com/flashinfer-ai/flashinfer
 # Kudos to @yzh119
 def rmsnorm(
@@ -167,9 +161,7 @@ def fused_add_rmsnorm(
     input: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6
 ) -> None:
     with input.device as device:
-        torch.ops.sgl_kernels.fused_add_rmsnorm(
-            input, residual, weight, eps, _get_cuda_stream(device)
-        )
+        torch.ops.sgl_kernels.fused_add_rmsnorm(input, residual, weight, eps)
 
 
 def gemma_rmsnorm(
diff --git a/sgl-kernel/src/sgl-kernel/torch_extension.cc b/sgl-kernel/src/sgl-kernel/torch_extension.cc
index 70cdde9d8..01f93199c 100644
--- a/sgl-kernel/src/sgl-kernel/torch_extension.cc
+++ b/sgl-kernel/src/sgl-kernel/torch_extension.cc
@@ -45,19 +45,13 @@ TORCH_LIBRARY_EXPAND(sgl_kernels, m) {
       "new_kv) -> ()");
   m.impl("lightning_attention_decode", torch::kCUDA, &lightning_attention_decode);
 
-  // rotary embedding
-  m.def(
-      "rotary_embedding(Tensor positions, Tensor! query, Tensor! key, int head_size, Tensor cos_sin_cache, bool "
-      "is_neox) -> ()");
-  m.impl("rotary_embedding", torch::kCUDA, &rotary_embedding);
-
   // rms norm
   m.def("rmsnorm(Tensor! output, Tensor input, Tensor weight, float eps, int cuda_stream) -> ()");
   m.impl("rmsnorm", torch::kCUDA, &rmsnorm);
 
   // fused rms norm
-  m.def("fused_add_rmsnorm(Tensor! input, Tensor! residual, Tensor weight, float eps, int cuda_stream) -> ()");
-  m.impl("fused_add_rmsnorm", torch::kCUDA, &fused_add_rmsnorm);
+  m.def("fused_add_rmsnorm(Tensor! input, Tensor! residual, Tensor weight, float eps) -> ()");
+  m.impl("fused_add_rmsnorm", torch::kCUDA, &sgl_fused_add_rmsnorm);
 
   // gemma rms norm
   m.def("gemma_rmsnorm(Tensor! output, Tensor input, Tensor weight, float eps, int cuda_stream) -> ()");
diff --git a/sgl-kernel/tests/test_norm.py b/sgl-kernel/tests/test_norm.py
index 7b38dba72..d22da931f 100644
--- a/sgl-kernel/tests/test_norm.py
+++ b/sgl-kernel/tests/test_norm.py
@@ -69,7 +69,7 @@ def test_norm(batch_size, hidden_size, dtype, specify_out):
 
 @pytest.mark.parametrize("batch_size", [1, 19, 99, 989])
 @pytest.mark.parametrize("hidden_size", [111, 500, 1024, 3072, 3584, 4096, 8192, 16384])
-@pytest.mark.parametrize("dtype", [torch.float16])
+@pytest.mark.parametrize("dtype", [torch.float16, torch.float32])
 def test_fused_add_rmsnorm(batch_size, hidden_size, dtype):
     eps = 1e-6
 