support fp32 in sampling_scaling_penalties kernel (#3121)
This commit is contained in:
@@ -1,7 +1,6 @@
|
||||
#include <ATen/ATen.h>
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include <c10/cuda/CUDAGuard.h>
|
||||
#include <pytorch_extension_utils.h>
|
||||
|
||||
#include <THC/THCAtomics.cuh>
|
||||
#include <flashinfer/vec_dtypes.cuh>
|
||||
@@ -49,7 +48,7 @@ torch::Tensor sampling_scaling_penalties(const torch::Tensor& logits, const torc
|
||||
|
||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||
|
||||
DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(logits.scalar_type(), scalar_t, [&] {
|
||||
DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FLOAT_FP16(logits.scalar_type(), scalar_t, [&] {
|
||||
uint32_t vec_size = 16 / sizeof(scalar_t);
|
||||
const int blocks = (numel + threads * vec_size - 1) / (threads * vec_size);
|
||||
sampling_scaling_penalties_kernel<scalar_t><<<blocks, threads, 0, stream>>>(
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
#pragma once
|
||||
#include <pytorch_extension_utils.h>
|
||||
#include <torch/extension.h>
|
||||
|
||||
#include <sstream>
|
||||
@@ -44,3 +45,20 @@ inline int getSMVersion() {
|
||||
CHECK_CUDA_SUCCESS(cudaDeviceGetAttribute(&sm_minor, cudaDevAttrComputeCapabilityMinor, device));
|
||||
return sm_major * 10 + sm_minor;
|
||||
}
|
||||
|
||||
#define DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FLOAT_FP16(pytorch_dtype, c_type, ...) \
|
||||
[&]() -> bool { \
|
||||
switch (pytorch_dtype) { \
|
||||
case at::ScalarType::Float: { \
|
||||
using c_type = float; \
|
||||
return __VA_ARGS__(); \
|
||||
} \
|
||||
_DISPATCH_CASE_F16(c_type, __VA_ARGS__) \
|
||||
_DISPATCH_CASE_BF16(c_type, __VA_ARGS__) \
|
||||
default: \
|
||||
std::ostringstream oss; \
|
||||
oss << __PRETTY_FUNCTION__ << " failed to dispatch data type " << pytorch_dtype; \
|
||||
TORCH_CHECK(false, oss.str()); \
|
||||
return false; \
|
||||
} \
|
||||
}()
|
||||
|
||||
@@ -2,10 +2,14 @@ import pytest
|
||||
import torch
|
||||
from sgl_kernel import sampling_scaling_penalties
|
||||
|
||||
batch_sizes = [1, 2, 4, 8, 16, 32, 64, 65]
|
||||
vocab_sizes = [2048, 4096, 8192, 16384, 32768, 32767]
|
||||
dtypes = [torch.float32, torch.half, torch.bfloat16]
|
||||
|
||||
@pytest.mark.parametrize("batch_size", [1, 2, 4, 8, 16, 32, 64, 65])
|
||||
@pytest.mark.parametrize("vocab_size", [2048, 4096, 8192, 16384, 32768, 32767])
|
||||
@pytest.mark.parametrize("dtype", [torch.half, torch.bfloat16])
|
||||
|
||||
@pytest.mark.parametrize("batch_size", batch_sizes)
|
||||
@pytest.mark.parametrize("vocab_size", vocab_sizes)
|
||||
@pytest.mark.parametrize("dtype", dtypes)
|
||||
def test_sampling_scaling_penalties(batch_size, vocab_size, dtype):
|
||||
device = torch.device("cuda")
|
||||
rtol = 1e-3
|
||||
|
||||
Reference in New Issue
Block a user