diff --git a/sgl-kernel/CMakeLists.txt b/sgl-kernel/CMakeLists.txt
index 3c267a4de..15818d289 100644
--- a/sgl-kernel/CMakeLists.txt
+++ b/sgl-kernel/CMakeLists.txt
@@ -32,6 +32,7 @@ add_library(_kernels SHARED
     src/sgl-kernel/csrc/trt_reduce_kernel.cu
     src/sgl-kernel/csrc/moe_align_kernel.cu
     src/sgl-kernel/csrc/int8_gemm_kernel.cu
+    src/sgl-kernel/csrc/sampling_scaling_penalties.cu
     src/sgl-kernel/csrc/sgl_kernel_ops.cu
 )
 
diff --git a/sgl-kernel/pyproject.toml b/sgl-kernel/pyproject.toml
index 359ffafd7..b03b4c02b 100644
--- a/sgl-kernel/pyproject.toml
+++ b/sgl-kernel/pyproject.toml
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
 
 [project]
 name = "sgl-kernel"
-version = "0.0.2.post11"
+version = "0.0.2.post12"
 description = "Kernel Library for SGLang"
 readme = "README.md"
 requires-python = ">=3.8"
diff --git a/sgl-kernel/setup.py b/sgl-kernel/setup.py
index c93e87f6b..83025d6d6 100644
--- a/sgl-kernel/setup.py
+++ b/sgl-kernel/setup.py
@@ -50,6 +50,7 @@ ext_modules = [
             "src/sgl-kernel/csrc/trt_reduce_kernel.cu",
             "src/sgl-kernel/csrc/moe_align_kernel.cu",
             "src/sgl-kernel/csrc/int8_gemm_kernel.cu",
+            "src/sgl-kernel/csrc/sampling_scaling_penalties.cu",
             "src/sgl-kernel/csrc/sgl_kernel_ops.cu",
         ],
         include_dirs=include_dirs,
diff --git a/sgl-kernel/src/sgl-kernel/__init__.py b/sgl-kernel/src/sgl-kernel/__init__.py
index 892808f1e..62c366731 100644
--- a/sgl-kernel/src/sgl-kernel/__init__.py
+++ b/sgl-kernel/src/sgl-kernel/__init__.py
@@ -4,6 +4,7 @@ from sgl_kernel.ops import (
     init_custom_reduce,
     int8_scaled_mm,
     moe_align_block_size,
+    sampling_scaling_penalties,
 )
 
 __all__ = [
@@ -12,4 +13,5 @@ __all__ = [
     "custom_dispose",
     "custom_reduce",
     "int8_scaled_mm",
+    "sampling_scaling_penalties",
 ]
diff --git a/sgl-kernel/src/sgl-kernel/csrc/sampling_scaling_penalties.cu b/sgl-kernel/src/sgl-kernel/csrc/sampling_scaling_penalties.cu
new file mode 100644
index 000000000..30264caa3
--- /dev/null
+++ b/sgl-kernel/src/sgl-kernel/csrc/sampling_scaling_penalties.cu
@@ -0,0 +1,64 @@
+#include <ATen/cuda/CUDAContext.h>
+#include <cuda_bf16.h>
+#include <cuda_fp16.h>
+#include <torch/extension.h>
+
+#include "utils.hpp"
+#include "vectorization.cuh"
+
+template <typename scalar_t>
+__global__ void sampling_scaling_penalties_kernel(
+    const scalar_t* logits,
+    const scalar_t* scaling_penalties,
+    scalar_t* output,
+    const int32_t numel) {
+
+  const int32_t tid = blockIdx.x * blockDim.x + threadIdx.x;
+  const int32_t stride = blockDim.x * gridDim.x;
+
+  auto const* vectorized_logits = reinterpret_cast<vec4_t<scalar_t> const*>(logits);
+  auto const* vectorized_penalties = reinterpret_cast<vec4_t<scalar_t> const*>(scaling_penalties);
+  auto* vectorized_output = reinterpret_cast<vec4_t<scalar_t>*>(output);
+
+  const int32_t num_vec_elems = numel >> 2;
+
+#pragma unroll 4
+  for (int32_t i = tid; i < num_vec_elems; i += stride) {
+    vec4_t<scalar_t> logits_vec = vectorized_logits[i];
+    vec4_t<scalar_t> penalties_vec = vectorized_penalties[i];
+    vec4_t<scalar_t> out_vec;
+
+    out_vec.x = logits_vec.x > 0 ? logits_vec.x / penalties_vec.x : logits_vec.x * penalties_vec.x;
+    out_vec.y = logits_vec.y > 0 ? logits_vec.y / penalties_vec.y : logits_vec.y * penalties_vec.y;
+    out_vec.z = logits_vec.z > 0 ? logits_vec.z / penalties_vec.z : logits_vec.z * penalties_vec.z;
+    out_vec.w = logits_vec.w > 0 ? logits_vec.w / penalties_vec.w : logits_vec.w * penalties_vec.w;
+
+    vectorized_output[i] = out_vec;
+  }
+
+  const int32_t start_idx = num_vec_elems * 4;
+  for (int32_t i = start_idx + tid; i < numel; i += stride) {
+    scalar_t logit = logits[i];
+    scalar_t penalty = scaling_penalties[i];
+    output[i] = logit > 0 ? logit / penalty : logit * penalty;
+  }
+}
+
+torch::Tensor sampling_scaling_penalties(const torch::Tensor& logits, const torch::Tensor& scaling_penalties) {
+  auto output = torch::empty_like(logits);
+  const auto numel = logits.numel();
+  const int threads = 512;
+
+  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+
+  AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16,
+      logits.scalar_type(), "sampling_scaling_penalties_kernel", ([&] {
+        const int blocks = (numel + threads * 4 - 1) / (threads * 4);
+        sampling_scaling_penalties_kernel<scalar_t><<<blocks, threads, 0, stream>>>(
+            logits.data_ptr<scalar_t>(),
+            scaling_penalties.data_ptr<scalar_t>(),
+            output.data_ptr<scalar_t>(),
+            numel);
+      }));
+
+  return output;
+}
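The new kernel applies the standard repetition-penalty update elementwise: a positive logit is divided by its penalty and a non-positive logit is multiplied by it, so a penalty greater than 1 always pushes the token's score down. The main loop consumes four elements per thread through `vec4_t`, the trailing scalar loop covers the remainder when `numel` is not a multiple of 4, and the launch sizes the grid as `ceil(numel / (threads * 4))` so one grid-stride pass covers every vector. A minimal PyTorch sketch of the same semantics (the function name here is illustrative; the expression matches the `torch.where` reference used by the new test at the end of this diff):

```python
import torch

def sampling_scaling_penalties_ref(logits: torch.Tensor, penalties: torch.Tensor) -> torch.Tensor:
    # Positive logits shrink toward zero via division; non-positive logits
    # are pushed further down via multiplication.
    return torch.where(logits > 0, logits / penalties, logits * penalties)
```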
diff --git a/sgl-kernel/src/sgl-kernel/csrc/sgl_kernel_ops.cu b/sgl-kernel/src/sgl-kernel/csrc/sgl_kernel_ops.cu
index 6ed543e6c..fbfe51442 100644
--- a/sgl-kernel/src/sgl-kernel/csrc/sgl_kernel_ops.cu
+++ b/sgl-kernel/src/sgl-kernel/csrc/sgl_kernel_ops.cu
@@ -12,6 +12,9 @@ void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts, int64_t b
                           torch::Tensor sorted_token_ids, torch::Tensor experts_ids, torch::Tensor num_tokens_post_pad,
                           torch::Tensor token_cnts_buffer, torch::Tensor cumsum_buffer);
 
+// sampling_scaling_penalties
+torch::Tensor sampling_scaling_penalties(const torch::Tensor& logits, const torch::Tensor& scaling_penalties);
+
 // int8_scaled_mm
 torch::Tensor int8_scaled_mm(const torch::Tensor& mat_a, const torch::Tensor& mat_b, const torch::Tensor& scales_a,
                              const torch::Tensor& scales_b, const torch::Dtype& out_dtype,
@@ -24,6 +27,8 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
   m.def("all_reduce", &all_reduce, "custom all reduce (CUDA)");
   // moe_align_block_size
   m.def("moe_align_block_size", &moe_align_block_size, "MOE Align Block Size (CUDA)");
+  // sampling_scaling_penalties
+  m.def("sampling_scaling_penalties", &sampling_scaling_penalties, "Sampling scaling penalties (CUDA)");
   // int8_scaled_mm
   m.def("int8_scaled_mm", &int8_scaled_mm, "INT8 scaled matmul (CUDA)");
 }
diff --git a/sgl-kernel/src/sgl-kernel/csrc/vectorization.cuh b/sgl-kernel/src/sgl-kernel/csrc/vectorization.cuh
new file mode 100644
index 000000000..cb36d0e7a
--- /dev/null
+++ b/sgl-kernel/src/sgl-kernel/csrc/vectorization.cuh
@@ -0,0 +1,30 @@
+// Adapted from https://github.com/vllm-project/vllm/blob/main/csrc/quantization/vectorization.cuh
+#pragma once
+/**
+ * __device__ datatypes vectorized by 4
+ */
+
+// Include both AMD and NVIDIA fp8 types to avoid circular import
+// TODO(luka/varun) use FP8_TYPE instead after refactoring
+#include <c10/util/Float8_e4m3fn.h>
+#include <c10/util/Float8_e4m3fnuz.h>
+
+// Vectorization containers
+template <typename scalar_t>
+struct __align__(8) vec4_t {
+  scalar_t x;
+  scalar_t y;
+  scalar_t z;
+  scalar_t w;
+};
+
+template <typename quant_type_t>
+struct __align__(4) q8x4_t {
+  static_assert(std::is_same_v<quant_type_t, int8_t> ||
+                std::is_same_v<quant_type_t, c10::Float8_e4m3fn> ||
+                std::is_same_v<quant_type_t, c10::Float8_e4m3fnuz>);
+  quant_type_t x;
+  quant_type_t y;
+  quant_type_t z;
+  quant_type_t w;
+};
diff --git a/sgl-kernel/src/sgl-kernel/ops/__init__.py b/sgl-kernel/src/sgl-kernel/ops/__init__.py
index e388ae356..03a8db80f 100644
--- a/sgl-kernel/src/sgl-kernel/ops/__init__.py
+++ b/sgl-kernel/src/sgl-kernel/ops/__init__.py
@@ -3,6 +3,9 @@ from sgl_kernel.ops._kernels import dispose as _dispose
 from sgl_kernel.ops._kernels import init_custom_ar as _init_custom_ar
 from sgl_kernel.ops._kernels import int8_scaled_mm as _int8_scaled_mm
 from sgl_kernel.ops._kernels import moe_align_block_size as _moe_align_block_size
+from sgl_kernel.ops._kernels import (
+    sampling_scaling_penalties as _sampling_scaling_penalties,
+)
 
 
 def init_custom_reduce(rank_id, num_devices, buffers, barrier_in, barrier_out):
@@ -39,6 +42,10 @@ def moe_align_block_size(
     )
 
 
+def sampling_scaling_penalties(logits, scaling_penalties):
+    return _sampling_scaling_penalties(logits, scaling_penalties)
+
+
 def int8_scaled_mm(mat_a, mat_b, scales_a, scales_b, out_dtype, bias=None):
     return _int8_scaled_mm(
         mat_a,
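With the pybind registration and the thin Python wrapper above, the op is callable directly on CUDA tensors. A usage sketch, assuming a build of `sgl-kernel` with this diff applied (the shapes and the uniform 1.2 penalty are made-up illustration values):

```python
import torch
from sgl_kernel import sampling_scaling_penalties

# One row of logits per running request; both tensors share shape and dtype.
logits = torch.randn(8, 32768, device="cuda", dtype=torch.float16)
penalties = torch.full_like(logits, 1.2)  # uniform repetition penalty

out = sampling_scaling_penalties(logits, penalties)
print(out.shape)  # torch.Size([8, 32768])
```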
diff --git a/sgl-kernel/tests/test_sampling_scaling_penalties.py b/sgl-kernel/tests/test_sampling_scaling_penalties.py
new file mode 100644
index 000000000..4b9746fd7
--- /dev/null
+++ b/sgl-kernel/tests/test_sampling_scaling_penalties.py
@@ -0,0 +1,39 @@
+import torch
+from sgl_kernel import sampling_scaling_penalties
+
+
+def test_sampling_scaling_penalties():
+    batch_sizes = [1, 2, 4, 8, 16, 32, 64, 65]
+    vocab_sizes = [2048, 4096, 8192, 16384, 32768, 32767]
+    dtypes = [torch.float32, torch.half, torch.bfloat16]
+    device = torch.device("cuda")
+
+    for dtype in dtypes:
+        rtol = 1e-3
+        atol = 1e-3
+
+        for bs in batch_sizes:
+            for vocab_size in vocab_sizes:
+                logits = torch.randn(bs, vocab_size, device=device, dtype=dtype)
+                scaling_penalties = (
+                    torch.rand(bs, vocab_size, device=device, dtype=dtype) + 0.5
+                )
+
+                ref_output = torch.where(
+                    logits > 0, logits / scaling_penalties, logits * scaling_penalties
+                )
+
+                kernel_output = sampling_scaling_penalties(logits, scaling_penalties)
+
+                torch.testing.assert_close(
+                    kernel_output,
+                    ref_output,
+                    rtol=rtol,
+                    atol=atol,
+                    msg=f"Failed for batch_size={bs}, vocab_size={vocab_size}, dtype={dtype}",
+                )
+
+
+if __name__ == "__main__":
+    test_sampling_scaling_penalties()
+    print("All tests passed!")
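The test can be run directly with `python sgl-kernel/tests/test_sampling_scaling_penalties.py` or through pytest; the odd sizes (65, 32767) exercise the kernel's scalar tail path. To measure what the fused kernel buys over the eager `torch.where` expression it replaces, a benchmark along these lines could be used (not part of this diff; shape, dtype, and iteration counts are arbitrary choices):

```python
import torch
from sgl_kernel import sampling_scaling_penalties

logits = torch.randn(64, 32768, device="cuda", dtype=torch.float16)
penalties = torch.rand_like(logits) + 0.5

def eager() -> torch.Tensor:
    return torch.where(logits > 0, logits / penalties, logits * penalties)

def fused() -> torch.Tensor:
    return sampling_scaling_penalties(logits, penalties)

def time_ms(fn, iters: int = 100) -> float:
    # CUDA-event timing: warm up, then average over `iters` launches.
    for _ in range(10):
        fn()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(iters):
        fn()
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / iters

print(f"torch.where : {time_ms(eager):.3f} ms")
print(f"fused kernel: {time_ms(fused):.3f} ms")
```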