add sampling_scaling_penalties kernel (#2846)
This commit is contained in:
@@ -32,6 +32,7 @@ add_library(_kernels SHARED
|
||||
src/sgl-kernel/csrc/trt_reduce_kernel.cu
|
||||
src/sgl-kernel/csrc/moe_align_kernel.cu
|
||||
src/sgl-kernel/csrc/int8_gemm_kernel.cu
|
||||
src/sgl-kernel/csrc/sampling_scaling_penalties.cu
|
||||
src/sgl-kernel/csrc/sgl_kernel_ops.cu
|
||||
)
|
||||
|
||||
|
||||
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
|
||||
|
||||
[project]
|
||||
name = "sgl-kernel"
|
||||
version = "0.0.2.post11"
|
||||
version = "0.0.2.post12"
|
||||
description = "Kernel Library for SGLang"
|
||||
readme = "README.md"
|
||||
requires-python = ">=3.8"
|
||||
|
||||
@@ -50,6 +50,7 @@ ext_modules = [
|
||||
"src/sgl-kernel/csrc/trt_reduce_kernel.cu",
|
||||
"src/sgl-kernel/csrc/moe_align_kernel.cu",
|
||||
"src/sgl-kernel/csrc/int8_gemm_kernel.cu",
|
||||
"src/sgl-kernel/csrc/sampling_scaling_penalties.cu",
|
||||
"src/sgl-kernel/csrc/sgl_kernel_ops.cu",
|
||||
],
|
||||
include_dirs=include_dirs,
|
||||
|
||||
@@ -4,6 +4,7 @@ from sgl_kernel.ops import (
|
||||
init_custom_reduce,
|
||||
int8_scaled_mm,
|
||||
moe_align_block_size,
|
||||
sampling_scaling_penalties,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
@@ -12,4 +13,5 @@ __all__ = [
|
||||
"custom_dispose",
|
||||
"custom_reduce",
|
||||
"int8_scaled_mm",
|
||||
"sampling_scaling_penalties",
|
||||
]
|
||||
|
||||
64
sgl-kernel/src/sgl-kernel/csrc/sampling_scaling_penalties.cu
Normal file
64
sgl-kernel/src/sgl-kernel/csrc/sampling_scaling_penalties.cu
Normal file
@@ -0,0 +1,64 @@
|
||||
#include <ATen/ATen.h>
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include <c10/cuda/CUDAGuard.h>
|
||||
#include <THC/THCAtomics.cuh>
|
||||
#include "utils.hpp"
|
||||
#include "vectorization.cuh"
|
||||
|
||||
template <typename scalar_t>
|
||||
__global__ void sampling_scaling_penalties_kernel(
|
||||
const scalar_t* logits,
|
||||
const scalar_t* scaling_penalties,
|
||||
scalar_t* output,
|
||||
const int32_t numel) {
|
||||
|
||||
const int32_t tid = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
const int32_t stride = blockDim.x * gridDim.x;
|
||||
|
||||
auto const* vectorized_logits = reinterpret_cast<vec4_t<scalar_t> const*>(logits);
|
||||
auto const* vectorized_penalties = reinterpret_cast<vec4_t<scalar_t> const*>(scaling_penalties);
|
||||
auto* vectorized_output = reinterpret_cast<vec4_t<scalar_t>*>(output);
|
||||
|
||||
const int32_t num_vec_elems = numel >> 2;
|
||||
|
||||
#pragma unroll 4
|
||||
for (int32_t i = tid; i < num_vec_elems; i += stride) {
|
||||
vec4_t<scalar_t> logits_vec = vectorized_logits[i];
|
||||
vec4_t<scalar_t> penalties_vec = vectorized_penalties[i];
|
||||
vec4_t<scalar_t> out_vec;
|
||||
|
||||
out_vec.x = logits_vec.x > 0 ? logits_vec.x / penalties_vec.x : logits_vec.x * penalties_vec.x;
|
||||
out_vec.y = logits_vec.y > 0 ? logits_vec.y / penalties_vec.y : logits_vec.y * penalties_vec.y;
|
||||
out_vec.z = logits_vec.z > 0 ? logits_vec.z / penalties_vec.z : logits_vec.z * penalties_vec.z;
|
||||
out_vec.w = logits_vec.w > 0 ? logits_vec.w / penalties_vec.w : logits_vec.w * penalties_vec.w;
|
||||
|
||||
vectorized_output[i] = out_vec;
|
||||
}
|
||||
|
||||
const int32_t start_idx = num_vec_elems * 4;
|
||||
for (int32_t i = start_idx + tid; i < numel; i += stride) {
|
||||
scalar_t logit = logits[i];
|
||||
scalar_t penalty = scaling_penalties[i];
|
||||
output[i] = logit > 0 ? logit / penalty : logit * penalty;
|
||||
}
|
||||
}
|
||||
|
||||
torch::Tensor sampling_scaling_penalties(const torch::Tensor& logits, const torch::Tensor& scaling_penalties) {
|
||||
auto output = torch::empty_like(logits);
|
||||
const auto numel = logits.numel();
|
||||
const int threads = 512;
|
||||
|
||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||
|
||||
AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16,
|
||||
logits.scalar_type(), "sampling_scaling_penalties_kernel", ([&] {
|
||||
const int blocks = (numel + threads * 4 - 1) / (threads * 4);
|
||||
sampling_scaling_penalties_kernel<scalar_t><<<blocks, threads, 0, stream>>>(
|
||||
logits.data_ptr<scalar_t>(),
|
||||
scaling_penalties.data_ptr<scalar_t>(),
|
||||
output.data_ptr<scalar_t>(),
|
||||
numel);
|
||||
}));
|
||||
|
||||
return output;
|
||||
}
|
||||
@@ -12,6 +12,9 @@ void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts, int64_t b
|
||||
torch::Tensor sorted_token_ids, torch::Tensor experts_ids, torch::Tensor num_tokens_post_pad,
|
||||
torch::Tensor token_cnts_buffer, torch::Tensor cumsum_buffer);
|
||||
|
||||
// sampling_scaling_penalties
|
||||
torch::Tensor sampling_scaling_penalties(const torch::Tensor& logits, const torch::Tensor& scaling_penalties);
|
||||
|
||||
// int8_scaled_mm
|
||||
torch::Tensor int8_scaled_mm(const torch::Tensor& mat_a, const torch::Tensor& mat_b, const torch::Tensor& scales_a,
|
||||
const torch::Tensor& scales_b, const torch::Dtype& out_dtype,
|
||||
@@ -24,6 +27,8 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
||||
m.def("all_reduce", &all_reduce, "custom all reduce (CUDA)");
|
||||
// moe_align_block_size
|
||||
m.def("moe_align_block_size", &moe_align_block_size, "MOE Align Block Size (CUDA)");
|
||||
// sampling_scaling_penalties
|
||||
m.def("sampling_scaling_penalties", &sampling_scaling_penalties, "Sampling scaling penalties (CUDA)");
|
||||
// int8_scaled_mm
|
||||
m.def("int8_scaled_mm", &int8_scaled_mm, "INT8 scaled matmul (CUDA)");
|
||||
}
|
||||
|
||||
30
sgl-kernel/src/sgl-kernel/csrc/vectorization.cuh
Normal file
30
sgl-kernel/src/sgl-kernel/csrc/vectorization.cuh
Normal file
@@ -0,0 +1,30 @@
|
||||
// Adapted from https://github.com/vllm-project/vllm/blob/main/csrc/quantization/vectorization.cuh
|
||||
#pragma once
|
||||
/**
|
||||
* __device__ datatypes vectorized by 4
|
||||
*/
|
||||
|
||||
// Include both AMD and NVIDIA fp8 types to avoid circular import
|
||||
// TODO(luka/varun) use FP8_TYPE instead after refactoring
|
||||
#include <c10/util/Float8_e4m3fnuz.h>
|
||||
#include <c10/util/Float8_e4m3fn.h>
|
||||
|
||||
// Vectorization containers
|
||||
template <typename scalar_t>
|
||||
struct __align__(8) vec4_t {
|
||||
scalar_t x;
|
||||
scalar_t y;
|
||||
scalar_t z;
|
||||
scalar_t w;
|
||||
};
|
||||
|
||||
template <typename quant_type_t>
|
||||
struct __align__(4) q8x4_t {
|
||||
static_assert(std::is_same_v<quant_type_t, int8_t> ||
|
||||
std::is_same_v<quant_type_t, c10::Float8_e4m3fn> ||
|
||||
std::is_same_v<quant_type_t, c10::Float8_e4m3fnuz>);
|
||||
quant_type_t x;
|
||||
quant_type_t y;
|
||||
quant_type_t z;
|
||||
quant_type_t w;
|
||||
};
|
||||
@@ -3,6 +3,9 @@ from sgl_kernel.ops._kernels import dispose as _dispose
|
||||
from sgl_kernel.ops._kernels import init_custom_ar as _init_custom_ar
|
||||
from sgl_kernel.ops._kernels import int8_scaled_mm as _int8_scaled_mm
|
||||
from sgl_kernel.ops._kernels import moe_align_block_size as _moe_align_block_size
|
||||
from sgl_kernel.ops._kernels import (
|
||||
sampling_scaling_penalties as _sampling_scaling_penalties,
|
||||
)
|
||||
|
||||
|
||||
def init_custom_reduce(rank_id, num_devices, buffers, barrier_in, barrier_out):
|
||||
@@ -39,6 +42,10 @@ def moe_align_block_size(
|
||||
)
|
||||
|
||||
|
||||
def sampling_scaling_penalties(logits, scaling_penalties):
|
||||
return _sampling_scaling_penalties(logits, scaling_penalties)
|
||||
|
||||
|
||||
def int8_scaled_mm(mat_a, mat_b, scales_a, scales_b, out_dtype, bias=None):
|
||||
return _int8_scaled_mm(
|
||||
mat_a,
|
||||
|
||||
39
sgl-kernel/tests/test_sampling_scaling_penalties.py
Normal file
39
sgl-kernel/tests/test_sampling_scaling_penalties.py
Normal file
@@ -0,0 +1,39 @@
|
||||
import torch
|
||||
from sgl_kernel import sampling_scaling_penalties
|
||||
|
||||
|
||||
def test_sampling_scaling_penalties():
|
||||
batch_sizes = [1, 2, 4, 8, 16, 32, 64, 65]
|
||||
vocab_sizes = [2048, 4096, 8192, 16384, 32768, 32767]
|
||||
dtypes = [torch.float32, torch.half, torch.bfloat16]
|
||||
device = torch.device("cuda")
|
||||
|
||||
for dtype in dtypes:
|
||||
rtol = 1e-3
|
||||
atol = 1e-3
|
||||
|
||||
for bs in batch_sizes:
|
||||
for vocab_size in vocab_sizes:
|
||||
logits = torch.randn(bs, vocab_size, device=device, dtype=dtype)
|
||||
scaling_penalties = (
|
||||
torch.rand(bs, vocab_size, device=device, dtype=dtype) + 0.5
|
||||
)
|
||||
|
||||
ref_output = torch.where(
|
||||
logits > 0, logits / scaling_penalties, logits * scaling_penalties
|
||||
)
|
||||
|
||||
kernel_output = sampling_scaling_penalties(logits, scaling_penalties)
|
||||
|
||||
torch.testing.assert_close(
|
||||
kernel_output,
|
||||
ref_output,
|
||||
rtol=rtol,
|
||||
atol=atol,
|
||||
msg=f"Failed for batch_size={bs}, vocab_size={vocab_size}, dtype={dtype}",
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_sampling_scaling_penalties()
|
||||
print("All tests passed!")
|
||||
Reference in New Issue
Block a user