From 5d9d15e70f7e73223a3d2baf3851b95a9d5356f0 Mon Sep 17 00:00:00 2001 From: Xiaoyu Zhang <35585791+BBuf@users.noreply.github.com> Date: Sat, 25 Jan 2025 16:52:17 +0800 Subject: [PATCH] support fp32 in sampling_scaling_penalties kernel (#3121) --- .../csrc/sampling_scaling_penalties.cu | 3 +-- sgl-kernel/src/sgl-kernel/csrc/utils.h | 18 ++++++++++++++++++ .../tests/test_sampling_scaling_penalties.py | 10 +++++++--- 3 files changed, 26 insertions(+), 5 deletions(-) diff --git a/sgl-kernel/src/sgl-kernel/csrc/sampling_scaling_penalties.cu b/sgl-kernel/src/sgl-kernel/csrc/sampling_scaling_penalties.cu index 2a9de4d9f..18beb8644 100644 --- a/sgl-kernel/src/sgl-kernel/csrc/sampling_scaling_penalties.cu +++ b/sgl-kernel/src/sgl-kernel/csrc/sampling_scaling_penalties.cu @@ -1,7 +1,6 @@ #include #include #include -#include #include #include @@ -49,7 +48,7 @@ torch::Tensor sampling_scaling_penalties(const torch::Tensor& logits, const torc const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(logits.scalar_type(), scalar_t, [&] { + DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FLOAT_FP16(logits.scalar_type(), scalar_t, [&] { uint32_t vec_size = 16 / sizeof(scalar_t); const int blocks = (numel + threads * vec_size - 1) / (threads * vec_size); sampling_scaling_penalties_kernel<<>>( diff --git a/sgl-kernel/src/sgl-kernel/csrc/utils.h b/sgl-kernel/src/sgl-kernel/csrc/utils.h index 2fed2d60c..ed802d4fd 100644 --- a/sgl-kernel/src/sgl-kernel/csrc/utils.h +++ b/sgl-kernel/src/sgl-kernel/csrc/utils.h @@ -1,4 +1,5 @@ #pragma once +#include <sstream> #include #include @@ -44,3 +45,20 @@ inline int getSMVersion() { CHECK_CUDA_SUCCESS(cudaDeviceGetAttribute(&sm_minor, cudaDevAttrComputeCapabilityMinor, device)); return sm_major * 10 + sm_minor; } + +#define DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FLOAT_FP16(pytorch_dtype, c_type, ...) \
+ [&]() -> bool { \ + switch (pytorch_dtype) { \ + case at::ScalarType::Float: { \ + using c_type = float; \ + return __VA_ARGS__(); \ + } \ + _DISPATCH_CASE_F16(c_type, __VA_ARGS__) \ + _DISPATCH_CASE_BF16(c_type, __VA_ARGS__) \ + default: \ + std::ostringstream oss; \ + oss << __PRETTY_FUNCTION__ << " failed to dispatch data type " << pytorch_dtype; \ + TORCH_CHECK(false, oss.str()); \ + return false; \ + } \ + }() diff --git a/sgl-kernel/tests/test_sampling_scaling_penalties.py b/sgl-kernel/tests/test_sampling_scaling_penalties.py index 6194c7617..a56eca866 100644 --- a/sgl-kernel/tests/test_sampling_scaling_penalties.py +++ b/sgl-kernel/tests/test_sampling_scaling_penalties.py @@ -2,10 +2,14 @@ import pytest import torch from sgl_kernel import sampling_scaling_penalties +batch_sizes = [1, 2, 4, 8, 16, 32, 64, 65] +vocab_sizes = [2048, 4096, 8192, 16384, 32768, 32767] +dtypes = [torch.float32, torch.half, torch.bfloat16] -@pytest.mark.parametrize("batch_size", [1, 2, 4, 8, 16, 32, 64, 65]) -@pytest.mark.parametrize("vocab_size", [2048, 4096, 8192, 16384, 32768, 32767]) -@pytest.mark.parametrize("dtype", [torch.half, torch.bfloat16]) + +@pytest.mark.parametrize("batch_size", batch_sizes) +@pytest.mark.parametrize("vocab_size", vocab_sizes) +@pytest.mark.parametrize("dtype", dtypes) def test_sampling_scaling_penalties(batch_size, vocab_size, dtype): device = torch.device("cuda") rtol = 1e-3