use flashinfer vec_dtypes in sgl_kernel (#3083)
This commit is contained in:
@@ -1,11 +1,12 @@
|
|||||||
#include <ATen/ATen.h>
|
#include <ATen/ATen.h>
|
||||||
#include <ATen/cuda/CUDAContext.h>
|
#include <ATen/cuda/CUDAContext.h>
|
||||||
#include <c10/cuda/CUDAGuard.h>
|
#include <c10/cuda/CUDAGuard.h>
|
||||||
|
#include <pytorch_extension_utils.h>
|
||||||
|
|
||||||
#include <THC/THCAtomics.cuh>
|
#include <THC/THCAtomics.cuh>
|
||||||
|
#include <flashinfer/vec_dtypes.cuh>
|
||||||
|
|
||||||
#include "utils.h"
|
#include "utils.h"
|
||||||
#include "vectorization.cuh"
|
|
||||||
|
|
||||||
template <typename scalar_t>
|
template <typename scalar_t>
|
||||||
__global__ void sampling_scaling_penalties_kernel(const scalar_t* logits, const scalar_t* scaling_penalties,
|
__global__ void sampling_scaling_penalties_kernel(const scalar_t* logits, const scalar_t* scaling_penalties,
|
||||||
@@ -13,31 +14,31 @@ __global__ void sampling_scaling_penalties_kernel(const scalar_t* logits, const
|
|||||||
const int32_t tid = blockIdx.x * blockDim.x + threadIdx.x;
|
const int32_t tid = blockIdx.x * blockDim.x + threadIdx.x;
|
||||||
const int32_t stride = blockDim.x * gridDim.x;
|
const int32_t stride = blockDim.x * gridDim.x;
|
||||||
|
|
||||||
auto const* vectorized_logits = reinterpret_cast<vec4_t<scalar_t> const*>(logits);
|
constexpr uint32_t vec_size = 16 / sizeof(scalar_t);
|
||||||
auto const* vectorized_penalties = reinterpret_cast<vec4_t<scalar_t> const*>(scaling_penalties);
|
using vec_t = flashinfer::vec_t<scalar_t, vec_size>;
|
||||||
auto* vectorized_output = reinterpret_cast<vec4_t<scalar_t>*>(output);
|
|
||||||
|
|
||||||
const int32_t num_vec_elems = numel >> 2;
|
const int32_t num_vec_elems = numel / vec_size;
|
||||||
|
|
||||||
#pragma unroll 4
|
#pragma unroll 1
|
||||||
for (int32_t i = tid; i < num_vec_elems; i += stride) {
|
for (int32_t i = tid; i < num_vec_elems; i += stride) {
|
||||||
vec4_t<scalar_t> logits_vec = vectorized_logits[i];
|
vec_t logits_vec, penalties_vec, out_vec;
|
||||||
vec4_t<scalar_t> penalties_vec = vectorized_penalties[i];
|
logits_vec.cast_load(logits + i * vec_size);
|
||||||
vec4_t<scalar_t> out_vec;
|
penalties_vec.cast_load(scaling_penalties + i * vec_size);
|
||||||
|
|
||||||
out_vec.x = logits_vec.x > 0 ? logits_vec.x / penalties_vec.x : logits_vec.x * penalties_vec.x;
|
#pragma unroll
|
||||||
out_vec.y = logits_vec.y > 0 ? logits_vec.y / penalties_vec.y : logits_vec.y * penalties_vec.y;
|
for (uint32_t j = 0; j < vec_size; ++j) {
|
||||||
out_vec.z = logits_vec.z > 0 ? logits_vec.z / penalties_vec.z : logits_vec.z * penalties_vec.z;
|
out_vec[j] = logits_vec[j] > scalar_t(0.0f) ? logits_vec[j] / penalties_vec[j] : logits_vec[j] * penalties_vec[j];
|
||||||
out_vec.w = logits_vec.w > 0 ? logits_vec.w / penalties_vec.w : logits_vec.w * penalties_vec.w;
|
}
|
||||||
|
|
||||||
vectorized_output[i] = out_vec;
|
out_vec.cast_store(output + i * vec_size);
|
||||||
}
|
}
|
||||||
|
|
||||||
const int32_t start_idx = num_vec_elems * 4;
|
// process the remaining elements
|
||||||
|
const int32_t start_idx = num_vec_elems * vec_size;
|
||||||
for (int32_t i = start_idx + tid; i < numel; i += stride) {
|
for (int32_t i = start_idx + tid; i < numel; i += stride) {
|
||||||
scalar_t logit = logits[i];
|
scalar_t logit = logits[i];
|
||||||
scalar_t penalty = scaling_penalties[i];
|
scalar_t penalty = scaling_penalties[i];
|
||||||
output[i] = logit > 0 ? logit / penalty : logit * penalty;
|
output[i] = logit > scalar_t(0.0f) ? logit / penalty : logit * penalty;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -48,12 +49,14 @@ torch::Tensor sampling_scaling_penalties(const torch::Tensor& logits, const torc
|
|||||||
|
|
||||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||||
|
|
||||||
AT_DISPATCH_FLOATING_TYPES_AND2(
|
DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(logits.scalar_type(), scalar_t, [&] {
|
||||||
at::ScalarType::Half, at::ScalarType::BFloat16, logits.scalar_type(), "sampling_scaling_penalties_kernel", ([&] {
|
uint32_t vec_size = 16 / sizeof(scalar_t);
|
||||||
const int blocks = (numel + threads * 4 - 1) / (threads * 4);
|
const int blocks = (numel + threads * vec_size - 1) / (threads * vec_size);
|
||||||
sampling_scaling_penalties_kernel<scalar_t><<<blocks, threads, 0, stream>>>(
|
sampling_scaling_penalties_kernel<scalar_t><<<blocks, threads, 0, stream>>>(
|
||||||
logits.data_ptr<scalar_t>(), scaling_penalties.data_ptr<scalar_t>(), output.data_ptr<scalar_t>(), numel);
|
static_cast<scalar_t*>(logits.data_ptr()), static_cast<scalar_t*>(scaling_penalties.data_ptr()),
|
||||||
}));
|
static_cast<scalar_t*>(output.data_ptr()), numel);
|
||||||
|
return true;
|
||||||
|
});
|
||||||
|
|
||||||
return output;
|
return output;
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,29 +0,0 @@
|
|||||||
// Adapted from https://github.com/vllm-project/vllm/blob/main/csrc/quantization/vectorization.cuh
|
|
||||||
#pragma once
|
|
||||||
/**
|
|
||||||
* __device__ datatypes vectorized by 4
|
|
||||||
*/
|
|
||||||
|
|
||||||
// Include both AMD and NVIDIA fp8 types to avoid circular import
|
|
||||||
// TODO(luka/varun) use FP8_TYPE instead after refactoring
|
|
||||||
#include <c10/util/Float8_e4m3fn.h>
|
|
||||||
#include <c10/util/Float8_e4m3fnuz.h>
|
|
||||||
|
|
||||||
// Vectorization containers
|
|
||||||
template <typename scalar_t>
|
|
||||||
struct __align__(8) vec4_t {
|
|
||||||
scalar_t x;
|
|
||||||
scalar_t y;
|
|
||||||
scalar_t z;
|
|
||||||
scalar_t w;
|
|
||||||
};
|
|
||||||
|
|
||||||
template <typename quant_type_t>
|
|
||||||
struct __align__(4) q8x4_t {
|
|
||||||
static_assert(std::is_same_v<quant_type_t, int8_t> || std::is_same_v<quant_type_t, c10::Float8_e4m3fn> ||
|
|
||||||
std::is_same_v<quant_type_t, c10::Float8_e4m3fnuz>);
|
|
||||||
quant_type_t x;
|
|
||||||
quant_type_t y;
|
|
||||||
quant_type_t z;
|
|
||||||
quant_type_t w;
|
|
||||||
};
|
|
||||||
@@ -1,37 +1,34 @@
|
|||||||
|
import pytest
|
||||||
import torch
|
import torch
|
||||||
from sgl_kernel import sampling_scaling_penalties
|
from sgl_kernel import sampling_scaling_penalties
|
||||||
|
|
||||||
|
|
||||||
def test_sampling_scaling_penalties():
|
@pytest.mark.parametrize("batch_size", [1, 2, 4, 8, 16, 32, 64, 65])
|
||||||
batch_sizes = [1, 2, 4, 8, 16, 32, 64, 65]
|
@pytest.mark.parametrize("vocab_size", [2048, 4096, 8192, 16384, 32768, 32767])
|
||||||
vocab_sizes = [2048, 4096, 8192, 16384, 32768, 32767]
|
@pytest.mark.parametrize("dtype", [torch.half, torch.bfloat16])
|
||||||
dtypes = [torch.float32, torch.half, torch.bfloat16]
|
def test_sampling_scaling_penalties(batch_size, vocab_size, dtype):
|
||||||
device = torch.device("cuda")
|
device = torch.device("cuda")
|
||||||
|
rtol = 1e-3
|
||||||
|
atol = 1e-3
|
||||||
|
|
||||||
for dtype in dtypes:
|
logits = torch.randn(batch_size, vocab_size, device=device, dtype=dtype)
|
||||||
rtol = 1e-3
|
scaling_penalties = (
|
||||||
atol = 1e-3
|
torch.rand(batch_size, vocab_size, device=device, dtype=dtype) + 0.5
|
||||||
|
)
|
||||||
|
|
||||||
for bs in batch_sizes:
|
ref_output = torch.where(
|
||||||
for vocab_size in vocab_sizes:
|
logits > 0, logits / scaling_penalties, logits * scaling_penalties
|
||||||
logits = torch.randn(bs, vocab_size, device=device, dtype=dtype)
|
)
|
||||||
scaling_penalties = (
|
|
||||||
torch.rand(bs, vocab_size, device=device, dtype=dtype) + 0.5
|
|
||||||
)
|
|
||||||
|
|
||||||
ref_output = torch.where(
|
kernel_output = sampling_scaling_penalties(logits, scaling_penalties)
|
||||||
logits > 0, logits / scaling_penalties, logits * scaling_penalties
|
|
||||||
)
|
|
||||||
|
|
||||||
kernel_output = sampling_scaling_penalties(logits, scaling_penalties)
|
torch.testing.assert_close(
|
||||||
|
kernel_output,
|
||||||
torch.testing.assert_close(
|
ref_output,
|
||||||
kernel_output,
|
rtol=rtol,
|
||||||
ref_output,
|
atol=atol,
|
||||||
rtol=rtol,
|
msg=f"Failed for batch_size={batch_size}, vocab_size={vocab_size}, dtype={dtype}",
|
||||||
atol=atol,
|
)
|
||||||
msg=f"Failed for batch_size={bs}, vocab_size={vocab_size}, dtype={dtype}",
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|||||||
Reference in New Issue
Block a user