From f1b68618281d680add95b9c30635ef644f1f6f25 Mon Sep 17 00:00:00 2001 From: Xiaoyu Zhang <35585791+BBuf@users.noreply.github.com> Date: Thu, 23 Jan 2025 22:19:04 +0800 Subject: [PATCH] use flashinfer vec_dtypes in sgl_kernel (#3083) --- .../csrc/sampling_scaling_penalties.cu | 47 ++++++++++--------- .../src/sgl-kernel/csrc/vectorization.cuh | 29 ------------ .../tests/test_sampling_scaling_penalties.py | 47 +++++++++---------- 3 files changed, 47 insertions(+), 76 deletions(-) delete mode 100644 sgl-kernel/src/sgl-kernel/csrc/vectorization.cuh diff --git a/sgl-kernel/src/sgl-kernel/csrc/sampling_scaling_penalties.cu b/sgl-kernel/src/sgl-kernel/csrc/sampling_scaling_penalties.cu index 2f53bb1a9..2a9de4d9f 100644 --- a/sgl-kernel/src/sgl-kernel/csrc/sampling_scaling_penalties.cu +++ b/sgl-kernel/src/sgl-kernel/csrc/sampling_scaling_penalties.cu @@ -1,11 +1,12 @@ #include #include #include +#include #include +#include #include "utils.h" -#include "vectorization.cuh" template __global__ void sampling_scaling_penalties_kernel(const scalar_t* logits, const scalar_t* scaling_penalties, @@ -13,31 +14,31 @@ __global__ void sampling_scaling_penalties_kernel(const scalar_t* logits, const const int32_t tid = blockIdx.x * blockDim.x + threadIdx.x; const int32_t stride = blockDim.x * gridDim.x; - auto const* vectorized_logits = reinterpret_cast const*>(logits); - auto const* vectorized_penalties = reinterpret_cast const*>(scaling_penalties); - auto* vectorized_output = reinterpret_cast*>(output); + constexpr uint32_t vec_size = 16 / sizeof(scalar_t); + using vec_t = flashinfer::vec_t; - const int32_t num_vec_elems = numel >> 2; + const int32_t num_vec_elems = numel / vec_size; -#pragma unroll 4 +#pragma unroll 1 for (int32_t i = tid; i < num_vec_elems; i += stride) { - vec4_t logits_vec = vectorized_logits[i]; - vec4_t penalties_vec = vectorized_penalties[i]; - vec4_t out_vec; + vec_t logits_vec, penalties_vec, out_vec; + logits_vec.cast_load(logits + i * vec_size); + penalties_vec.cast_load(scaling_penalties + i * vec_size); - out_vec.x = logits_vec.x > 0 ? logits_vec.x / penalties_vec.x : logits_vec.x * penalties_vec.x; - out_vec.y = logits_vec.y > 0 ? logits_vec.y / penalties_vec.y : logits_vec.y * penalties_vec.y; - out_vec.z = logits_vec.z > 0 ? logits_vec.z / penalties_vec.z : logits_vec.z * penalties_vec.z; - out_vec.w = logits_vec.w > 0 ? logits_vec.w / penalties_vec.w : logits_vec.w * penalties_vec.w; +#pragma unroll + for (uint32_t j = 0; j < vec_size; ++j) { + out_vec[j] = logits_vec[j] > scalar_t(0.0f) ? logits_vec[j] / penalties_vec[j] : logits_vec[j] * penalties_vec[j]; + } - vectorized_output[i] = out_vec; + out_vec.cast_store(output + i * vec_size); } - const int32_t start_idx = num_vec_elems * 4; + // process the remaining elements + const int32_t start_idx = num_vec_elems * vec_size; for (int32_t i = start_idx + tid; i < numel; i += stride) { scalar_t logit = logits[i]; scalar_t penalty = scaling_penalties[i]; - output[i] = logit > 0 ? logit / penalty : logit * penalty; + output[i] = logit > scalar_t(0.0f) ? logit / penalty : logit * penalty; } } @@ -48,12 +49,14 @@ torch::Tensor sampling_scaling_penalties(const torch::Tensor& logits, const torc const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - AT_DISPATCH_FLOATING_TYPES_AND2( - at::ScalarType::Half, at::ScalarType::BFloat16, logits.scalar_type(), "sampling_scaling_penalties_kernel", ([&] { - const int blocks = (numel + threads * 4 - 1) / (threads * 4); - sampling_scaling_penalties_kernel<<>>( - logits.data_ptr(), scaling_penalties.data_ptr(), output.data_ptr(), numel); - })); + DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(logits.scalar_type(), scalar_t, [&] { + uint32_t vec_size = 16 / sizeof(scalar_t); + const int blocks = (numel + threads * vec_size - 1) / (threads * vec_size); + sampling_scaling_penalties_kernel<<>>( + static_cast(logits.data_ptr()), static_cast(scaling_penalties.data_ptr()), + static_cast(output.data_ptr()), numel); + return true; + }); return output; } diff --git a/sgl-kernel/src/sgl-kernel/csrc/vectorization.cuh b/sgl-kernel/src/sgl-kernel/csrc/vectorization.cuh deleted file mode 100644 index 2bfb71018..000000000 --- a/sgl-kernel/src/sgl-kernel/csrc/vectorization.cuh +++ /dev/null @@ -1,29 +0,0 @@ -// Adapted from https://github.com/vllm-project/vllm/blob/main/csrc/quantization/vectorization.cuh -#pragma once -/** - * __device__ datatypes vectorized by 4 - */ - -// Include both AMD and NVIDIA fp8 types to avoid circular import -// TODO(luka/varun) use FP8_TYPE instead after refactoring -#include -#include - -// Vectorization containers -template -struct __align__(8) vec4_t { - scalar_t x; - scalar_t y; - scalar_t z; - scalar_t w; -}; - -template -struct __align__(4) q8x4_t { - static_assert(std::is_same_v || std::is_same_v || - std::is_same_v); - quant_type_t x; - quant_type_t y; - quant_type_t z; - quant_type_t w; -}; diff --git a/sgl-kernel/tests/test_sampling_scaling_penalties.py b/sgl-kernel/tests/test_sampling_scaling_penalties.py index 4b9746fd7..00f12bfbe 100644 --- a/sgl-kernel/tests/test_sampling_scaling_penalties.py +++ b/sgl-kernel/tests/test_sampling_scaling_penalties.py @@ -1,37 +1,34 @@ +import pytest import torch from sgl_kernel import sampling_scaling_penalties -def test_sampling_scaling_penalties(): - batch_sizes = [1, 2, 4, 8, 16, 32, 64, 65] - vocab_sizes = [2048, 4096, 8192, 16384, 32768, 32767] - dtypes = [torch.float32, torch.half, torch.bfloat16] +@pytest.mark.parametrize("batch_size", [1, 2, 4, 8, 16, 32, 64, 65]) +@pytest.mark.parametrize("vocab_size", [2048, 4096, 8192, 16384, 32768, 32767]) +@pytest.mark.parametrize("dtype", [torch.half, torch.bfloat16]) +def test_sampling_scaling_penalties(batch_size, vocab_size, dtype): device = torch.device("cuda") + rtol = 1e-3 + atol = 1e-3 - for dtype in dtypes: - rtol = 1e-3 - atol = 1e-3 + logits = torch.randn(batch_size, vocab_size, device=device, dtype=dtype) + scaling_penalties = ( + torch.rand(batch_size, vocab_size, device=device, dtype=dtype) + 0.5 + ) - for bs in batch_sizes: - for vocab_size in vocab_sizes: - logits = torch.randn(bs, vocab_size, device=device, dtype=dtype) - scaling_penalties = ( - torch.rand(bs, vocab_size, device=device, dtype=dtype) + 0.5 - ) + ref_output = torch.where( + logits > 0, logits / scaling_penalties, logits * scaling_penalties + ) - ref_output = torch.where( - logits > 0, logits / scaling_penalties, logits * scaling_penalties - ) + kernel_output = sampling_scaling_penalties(logits, scaling_penalties) - kernel_output = sampling_scaling_penalties(logits, scaling_penalties) - - torch.testing.assert_close( - kernel_output, - ref_output, - rtol=rtol, - atol=atol, - msg=f"Failed for batch_size={bs}, vocab_size={vocab_size}, dtype={dtype}", - ) + torch.testing.assert_close( + kernel_output, + ref_output, + rtol=rtol, + atol=atol, + msg=f"Failed for batch_size={batch_size}, vocab_size={vocab_size}, dtype={dtype}", + ) if __name__ == "__main__":