support fp32 in sampling_scaling_penalties kernel (#3121)

This commit is contained in:
Xiaoyu Zhang
2025-01-25 16:52:17 +08:00
committed by GitHub
parent 665e5e85f6
commit 5d9d15e70f
3 changed files with 26 additions and 5 deletions

View File

@@ -1,7 +1,6 @@
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <pytorch_extension_utils.h>
#include <THC/THCAtomics.cuh>
#include <flashinfer/vec_dtypes.cuh>
@@ -49,7 +48,7 @@ torch::Tensor sampling_scaling_penalties(const torch::Tensor& logits, const torc
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(logits.scalar_type(), scalar_t, [&] {
DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FLOAT_FP16(logits.scalar_type(), scalar_t, [&] {
uint32_t vec_size = 16 / sizeof(scalar_t);
const int blocks = (numel + threads * vec_size - 1) / (threads * vec_size);
sampling_scaling_penalties_kernel<scalar_t><<<blocks, threads, 0, stream>>>(

View File

@@ -1,4 +1,5 @@
#pragma once
#include <pytorch_extension_utils.h>
#include <torch/extension.h>
#include <sstream>
@@ -44,3 +45,20 @@ inline int getSMVersion() {
CHECK_CUDA_SUCCESS(cudaDeviceGetAttribute(&sm_minor, cudaDevAttrComputeCapabilityMinor, device));
return sm_major * 10 + sm_minor;
}
// Dispatch a PyTorch scalar dtype to a concrete C++ type and invoke the given
// lambda with that type bound to the alias name `c_type`.
//
// Extends the FP16-only dispatcher to also accept float32 (the point of this
// commit): Float is handled explicitly here, while half/bfloat16 cases come
// from the _DISPATCH_CASE_F16 / _DISPATCH_CASE_BF16 helper macros defined
// elsewhere (presumably each emits a `case` that defines `c_type` and invokes
// __VA_ARGS__ — confirm against their definitions).
//
// Returns the bool produced by the dispatched lambda; any unsupported dtype
// hits `default:` and aborts via TORCH_CHECK(false, ...) with a message naming
// the enclosing function and the offending dtype (the trailing `return false`
// only exists to satisfy the compiler, since TORCH_CHECK(false, ...) throws).
//
// NOTE: no comments can appear inside the macro body itself — they would
// swallow the backslash line-continuations.
#define DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FLOAT_FP16(pytorch_dtype, c_type, ...) \
  [&]() -> bool { \
    switch (pytorch_dtype) { \
      case at::ScalarType::Float: { \
        using c_type = float; \
        return __VA_ARGS__(); \
      } \
      _DISPATCH_CASE_F16(c_type, __VA_ARGS__) \
      _DISPATCH_CASE_BF16(c_type, __VA_ARGS__) \
      default: \
        std::ostringstream oss; \
        oss << __PRETTY_FUNCTION__ << " failed to dispatch data type " << pytorch_dtype; \
        TORCH_CHECK(false, oss.str()); \
        return false; \
    } \
  }()

View File

@@ -2,10 +2,14 @@ import pytest
import torch
from sgl_kernel import sampling_scaling_penalties
batch_sizes = [1, 2, 4, 8, 16, 32, 64, 65]
vocab_sizes = [2048, 4096, 8192, 16384, 32768, 32767]
dtypes = [torch.float32, torch.half, torch.bfloat16]
@pytest.mark.parametrize("batch_size", [1, 2, 4, 8, 16, 32, 64, 65])
@pytest.mark.parametrize("vocab_size", [2048, 4096, 8192, 16384, 32768, 32767])
@pytest.mark.parametrize("dtype", [torch.half, torch.bfloat16])
@pytest.mark.parametrize("batch_size", batch_sizes)
@pytest.mark.parametrize("vocab_size", vocab_sizes)
@pytest.mark.parametrize("dtype", dtypes)
def test_sampling_scaling_penalties(batch_size, vocab_size, dtype):
device = torch.device("cuda")
rtol = 1e-3