// Source: sglang/sgl-kernel/csrc/elementwise/topk.cu
/**
* @NOTE: This file is adapted from
* https://github.com/tile-ai/tilelang/blob/main/examples/deepseek_v32/topk_selector.py
* We:
* 1. adapt from tilelang to pure cuda
* 2. optimize the performance a little
* 3. fix the potential illegal memory access
*/
#include <ATen/core/TensorBase.h>
#include <ATen/core/TensorBody.h>
#include <c10/cuda/CUDAStream.h>
#include <c10/macros/Macros.h>
#include <c10/util/Exception.h>
#include <cuda.h>
#include <cuda_fp16.h>
#include <cstddef>
#include <cstdint>
#include <optional>
namespace {
// Number of indices selected per row; every output buffer here is [B, TopK].
constexpr int TopK = 2048;
// Thread count used by every kernel in this file (matches __launch_bounds__).
constexpr int kThreadsPerBlock = 1024;
// Dynamic shared memory requested per block (two candidate-index buffers for
// the radix select; see fast_topk_cuda_tl and setup_kernel_smem_once).
constexpr size_t kSmem = 32 * 1024 * sizeof(uint32_t);  // 128KB
// Per-launch parameter pack shared by all top-k kernels in this file.
// One CUDA block processes one batch row.
struct FastTopKParams {
  const float* __restrict__ input;  // [B, input_stride] scores; row b has lengths[b] valid entries
  int32_t* __restrict__ indices;    // [B, TopK] output indices; null for the transform kernels (see get_params)
  int32_t* __restrict__ lengths;    // [B] number of valid scores per row
  int64_t input_stride;             // row stride of `input`, in elements
};
// when length <= TopK, we can directly write the indices
// Trivial top-k for rows with length <= TopK: every element is selected, so
// the output is the identity indices [0, length) padded with -1 up to TopK.
// `score` is unused; it mirrors the fast path's signature.
__device__ void naive_topk_cuda(const float* __restrict__ score, int32_t* __restrict__ indice, int32_t length) {
  for (int pos = static_cast<int>(threadIdx.x); pos < TopK; pos += kThreadsPerBlock) {
    if (pos < length) {
      indice[pos] = pos;
    } else {
      indice[pos] = -1;
    }
  }
}
// keep the first `length` entries, set others to -1
__device__ void naive_topk_transform(
const float* __restrict__ score,
int32_t length,
int32_t* __restrict__ dst_page_table,
const int32_t* __restrict__ src_page_table) {
const auto tid = threadIdx.x;
for (auto i = tid; i < TopK; i += kThreadsPerBlock) {
dst_page_table[i] = (i < length) ? src_page_table[i] : -1;
}
}
// keep the first `length` entries, set others to -1
// For length <= TopK all positions are kept: entry i becomes the global index
// i + offset; the tail of the row is padded with -1.
// `score` is unused; it mirrors the fast path's signature.
__device__ void naive_topk_transform_ragged(
    const float* __restrict__ score, int32_t length, int32_t* __restrict__ topk_indices_ragged, int32_t offset) {
  for (int pos = static_cast<int>(threadIdx.x); pos < TopK; pos += kThreadsPerBlock) {
    if (pos < length) {
      topk_indices_ragged[pos] = static_cast<int32_t>(pos) + offset;
    } else {
      topk_indices_ragged[pos] = -1;
    }
  }
}
// Map a float to an 8-bit order-preserving radix key: round to half precision,
// then apply the usual sign-flip trick (negatives are bitwise-inverted,
// non-negatives get the sign bit set) so unsigned comparison of keys matches
// float ordering, and keep the top byte.
__device__ __forceinline__ uint8_t convert_to_uint8(float x) {
  const uint16_t raw = __half_as_ushort(__float2half_rn(x));
  const uint16_t ordered = (raw & 0x8000) ? static_cast<uint16_t>(~raw) : static_cast<uint16_t>(raw | 0x8000);
  return static_cast<uint8_t>(ordered >> 8);
}
// 32-bit order-preserving radix key for a float: invert negatives, set the
// sign bit on non-negatives, so unsigned key order equals float order.
__device__ __forceinline__ uint32_t convert_to_uint32(float x) {
  const uint32_t raw = __float_as_uint(x);
  if (raw & 0x80000000u) {
    return ~raw;
  }
  return raw | 0x80000000u;
}
// Radix-select top-k (k = TopK) over `input[0..length)`, writing the selected
// element indices (in no particular order) to `index[0..TopK)`.
// Precondition: length > TopK — callers dispatch to naive_topk_cuda otherwise.
// Must run with kThreadsPerBlock (1024) threads and kSmem bytes of dynamic
// shared memory (raised via setup_kernel_smem_once on the host side).
//
// Outline:
//   stage 1: histogram all elements by an 8-bit half-precision-derived key,
//            emit every element strictly above a threshold bucket, and stash
//            the threshold bucket's elements as candidates;
//   stage 2: refine the candidates with up to four 8-bit passes over the
//            32-bit order-preserving float key, resolving ties in the last.
__device__ void fast_topk_cuda_tl(const float* __restrict__ input, int* __restrict__ index, int length) {
  // An optimized topk kernel copied from tilelang kernel
  // We assume length > TopK here, or it will crash
  int topk = TopK;  // slots still to fill; shrinks as whole bins are emitted
  constexpr auto BLOCK_SIZE = 1024;
  constexpr auto RADIX = 256;  // 8-bit digit -> 256 bins per pass
  // capacity (in indices) of one candidate buffer; kSmem holds two of them
  constexpr auto SMEM_INPUT_SIZE = kSmem / (2 * sizeof(int));
  // double-buffered histogram, reused for the cumsum ping-pong
  // NOTE(review): the +128 padding is not read by the visible code
  // (cumsum guards tx < RADIX - j) — presumably padding/alignment; confirm
  alignas(128) __shared__ int s_histogram_buf[2][RADIX + 128];
  alignas(128) __shared__ int s_counter;           // next free slot in `index`
  alignas(128) __shared__ int s_threshold_bin_id;  // bin chosen by the current pass
  alignas(128) __shared__ int s_num_input[2];      // candidate counts for the two ping-pong buffers
  auto& s_histogram = s_histogram_buf[0];
  // allocate for two rounds
  extern __shared__ int s_input_idx[][SMEM_INPUT_SIZE];  // ping-pong candidate index buffers (dynamic smem)
  const int tx = threadIdx.x;
  // stage 1: 8bit coarse histogram
  // +1 so entry [RADIX] is zero and s_histogram[t + 1] is valid for t = RADIX - 1
  if (tx < RADIX + 1) s_histogram[tx] = 0;
  __syncthreads();
  for (int idx = tx; idx < length; idx += BLOCK_SIZE) {
    const auto bin = convert_to_uint8(input[idx]);
    ::atomicAdd(&s_histogram[bin], 1);
  }
  __syncthreads();
  // Inclusive suffix-sum of s_histogram in log2(RADIX) = 8 ping-pong steps;
  // afterwards s_histogram[t] == count of elements with key >= t (the result
  // lands back in buffer 0 because the step count is even).
  const auto run_cumsum = [&] {
#pragma unroll 8
    for (int i = 0; i < 8; ++i) {
      static_assert(1 << 8 == RADIX);
      if (C10_LIKELY(tx < RADIX)) {
        const auto j = 1 << i;
        const auto k = i & 1;
        auto value = s_histogram_buf[k][tx];
        if (tx < RADIX - j) {
          value += s_histogram_buf[k][tx + j];
        }
        s_histogram_buf[k ^ 1][tx] = value;
      }
      __syncthreads();
    }
  };
  run_cumsum();
  // Exactly one thread matches: the bin where the suffix count first drops to
  // <= topk. Bins above it are all winners; the bin itself is the boundary.
  if (tx < RADIX && s_histogram[tx] > topk && s_histogram[tx + 1] <= topk) {
    s_threshold_bin_id = tx;
    s_num_input[0] = 0;
    s_counter = 0;
  }
  __syncthreads();
  const auto threshold_bin = s_threshold_bin_id;
  topk -= s_histogram[threshold_bin + 1];  // slots left after emitting every bin above the threshold
  if (topk == 0) {
    // Bins above the threshold hold exactly TopK elements: emit them and finish.
    for (int idx = tx; idx < length; idx += BLOCK_SIZE) {
      const auto bin = static_cast<int>(convert_to_uint8(input[idx]));
      if (bin > threshold_bin) {
        const auto pos = ::atomicAdd(&s_counter, 1);
        index[pos] = idx;
      }
    }
    __syncthreads();
    return;
  } else {
    __syncthreads();
    if (tx < RADIX + 1) {
      s_histogram[tx] = 0;  // reset for the first refinement pass
    }
    __syncthreads();
    for (int idx = tx; idx < length; idx += BLOCK_SIZE) {
      const auto raw_input = input[idx];
      const auto bin = static_cast<int>(convert_to_uint8(raw_input));
      if (bin > threshold_bin) {
        // clear winner: strictly above the threshold bin
        const auto pos = ::atomicAdd(&s_counter, 1);
        index[pos] = idx;
      } else if (bin == threshold_bin) {
        // boundary element: stash as a candidate for stage 2
        const auto pos = ::atomicAdd(&s_num_input[0], 1);
        /// NOTE: (dark) fuse the histogram computation here
        // Bounds check drops candidates past the buffer — presumably the
        // "potential illegal memory access" fix mentioned in the file header.
        if (C10_LIKELY(pos < SMEM_INPUT_SIZE)) {
          s_input_idx[0][pos] = idx;
          const auto bin = convert_to_uint32(raw_input);
          const auto sub_bin = (bin >> 24) & 0xFF;  // top byte of the 32-bit key for the next pass
          ::atomicAdd(&s_histogram[sub_bin], 1);
        }
      }
    }
    __syncthreads();
  }
  // stage 2: refine with 8bit radix passes
#pragma unroll 4
  for (int round = 0; round < 4; ++round) {
    __shared__ int s_last_remain;  // final round only: number of tied slots still to fill
    const auto r_idx = round % 2;  // which ping-pong buffer holds this round's candidates
    // clip here to prevent overflow
    const auto _raw_num_input = s_num_input[r_idx];
    const auto num_input = (_raw_num_input < int(SMEM_INPUT_SIZE)) ? _raw_num_input : int(SMEM_INPUT_SIZE);
    run_cumsum();  // suffix-sum the sub-bin histogram built by the previous pass
    if (tx < RADIX && s_histogram[tx] > topk && s_histogram[tx + 1] <= topk) {
      s_threshold_bin_id = tx;
      s_num_input[r_idx ^ 1] = 0;
      s_last_remain = topk - s_histogram[tx + 1];  // ties to accept if this turns out to be the last pass
    }
    __syncthreads();
    const auto threshold_bin = s_threshold_bin_id;
    topk -= s_histogram[threshold_bin + 1];
    if (topk == 0) {
      // Bins above the threshold complete the selection: emit and stop refining.
      for (int i = tx; i < num_input; i += BLOCK_SIZE) {
        const auto idx = s_input_idx[r_idx][i];
        const auto offset = 24 - round * 8;  // byte of the 32-bit key examined this round
        const auto bin = (convert_to_uint32(input[idx]) >> offset) & 0xFF;
        if (bin > threshold_bin) {
          const auto pos = ::atomicAdd(&s_counter, 1);
          index[pos] = idx;
        }
      }
      __syncthreads();
      break;
    } else {
      __syncthreads();
      if (tx < RADIX + 1) {
        s_histogram[tx] = 0;  // reset before the fused histogram of the next pass
      }
      __syncthreads();
      for (int i = tx; i < num_input; i += BLOCK_SIZE) {
        const auto idx = s_input_idx[r_idx][i];
        const auto raw_input = input[idx];
        const auto offset = 24 - round * 8;
        const auto bin = (convert_to_uint32(raw_input) >> offset) & 0xFF;
        if (bin > threshold_bin) {
          const auto pos = ::atomicAdd(&s_counter, 1);
          index[pos] = idx;
        } else if (bin == threshold_bin) {
          if (round == 3) {
            // Lowest byte examined: remaining candidates are exact key ties;
            // accept an arbitrary s_last_remain of them into the tail slots.
            const auto pos = ::atomicAdd(&s_last_remain, -1);
            if (pos > 0) {
              index[TopK - pos] = idx;
            }
          } else {
            const auto pos = ::atomicAdd(&s_num_input[r_idx ^ 1], 1);
            if (C10_LIKELY(pos < SMEM_INPUT_SIZE)) {
              /// NOTE: (dark) fuse the histogram computation here
              s_input_idx[r_idx ^ 1][pos] = idx;
              const auto bin = convert_to_uint32(raw_input);
              const auto sub_bin = (bin >> (offset - 8)) & 0xFF;  // next (lower) byte for the following pass
              ::atomicAdd(&s_histogram[sub_bin], 1);
            }
          }
        }
      }
      __syncthreads();
    }
  }
}
// One block per batch row: dispatch between the trivial path (the whole row
// fits in TopK) and the radix-select path.
__global__ __launch_bounds__(kThreadsPerBlock) // topk
void topk_kernel(const FastTopKParams params) {
  const auto& [input, indices, lengths, input_stride] = params;
  const auto row = static_cast<uint64_t>(blockIdx.x);
  const auto row_length = lengths[row];
  int32_t* row_indices = indices + row * TopK;
  const float* row_score = input + row * input_stride;
  if (row_length > TopK) {
    fast_topk_cuda_tl(row_score, row_indices, row_length);
  } else {
    naive_topk_cuda(row_score, row_indices, row_length);
  }
}
// Decode-phase transform: select top-k per row, then gather the selected
// entries of src_page_table into dst_page_table (row b of src maps to row b
// of dst).
__global__ __launch_bounds__(kThreadsPerBlock) // decode
void topk_transform_decode_kernel(
    const FastTopKParams params,
    int32_t* __restrict__ dst_page_table,
    const int32_t* __restrict__ src_page_table,
    const int64_t src_stride) {
  const auto& [input, _, lengths, input_stride] = params;
  const auto row = static_cast<uint64_t>(blockIdx.x);
  const auto tid = threadIdx.x;
  const auto row_length = lengths[row];
  const int32_t* src_entry = src_page_table + row * src_stride;
  int32_t* dst_entry = dst_page_table + row * TopK;
  const float* row_score = input + row * input_stride;
  if (row_length <= TopK) {
    naive_topk_transform(row_score, row_length, dst_entry, src_entry);
    return;
  }
  __shared__ int s_indices[TopK];
  fast_topk_cuda_tl(row_score, s_indices, row_length);
  // Gather src[s_indices] into dst; each thread handles exactly two slots.
  static_assert(TopK % kThreadsPerBlock == 0);
  static_assert(TopK / kThreadsPerBlock == 2);
#pragma unroll
  for (int slot = 0; slot < 2; ++slot) {
    const auto i = tid + slot * kThreadsPerBlock;
    dst_entry[i] = src_entry[s_indices[i]];
  }
}
// Prefill-phase transform. Query rows are expanded (B >= prefill_bs), so the
// block first locates its owning prefill request via cu_seqlens_q, then does
// the same select+gather as the decode kernel.
__global__ __launch_bounds__(kThreadsPerBlock) // prefill
void topk_transform_prefill_kernel(
    const FastTopKParams params,
    int32_t* __restrict__ dst_page_table,
    const int32_t* __restrict__ src_page_table,
    const int64_t src_stride,
    const int32_t* __restrict__ cu_seqlens_q,
    const int64_t prefill_bs) {
  const auto& [input, _, lengths, input_stride] = params;
  const auto row = static_cast<uint64_t>(blockIdx.x);
  const auto tid = threadIdx.x;
  const auto row_length = lengths[row];
  int32_t* dst_entry = dst_page_table + row * TopK;
  const float* row_score = input + row * input_stride;
  /// NOTE: prefill bs is usually small, we can just use a simple loop here
  /// We ensure that last cu_seqlens is equal to number of blocks launched
  // Exactly one thread finds the request r with cu_seqlens_q[r] <= row <
  // cu_seqlens_q[r + 1] and publishes that request's page-table row.
  __shared__ const int32_t* s_src_page_entry;
  if (C10_LIKELY(prefill_bs <= kThreadsPerBlock)) {
    if (tid < prefill_bs && row >= cu_seqlens_q[tid] && row < cu_seqlens_q[tid + 1]) {
      s_src_page_entry = src_page_table + tid * src_stride;
    }
  } else {
    for (int64_t r = tid; r < prefill_bs; r += kThreadsPerBlock) {
      if (row >= cu_seqlens_q[r] && row < cu_seqlens_q[r + 1]) {
        s_src_page_entry = src_page_table + r * src_stride;
      }
    }
  }
  __syncthreads();
  const int32_t* src_entry = s_src_page_entry;
  if (row_length <= TopK) {
    naive_topk_transform(row_score, row_length, dst_entry, src_entry);
    return;
  }
  __shared__ int s_indices[TopK];
  fast_topk_cuda_tl(row_score, s_indices, row_length);
  // Gather src[s_indices] into dst; each thread handles exactly two slots.
  static_assert(TopK % kThreadsPerBlock == 0);
  static_assert(TopK / kThreadsPerBlock == 2);
#pragma unroll
  for (int slot = 0; slot < 2; ++slot) {
    const auto i = tid + slot * kThreadsPerBlock;
    dst_entry[i] = src_entry[s_indices[i]];
  }
}
// Prefill transform with ragged KV layout: instead of gathering a page table,
// write the selected in-row indices shifted by a per-row global offset.
__global__ __launch_bounds__(kThreadsPerBlock) // prefill, ragged kv
void topk_transform_prefill_ragged_kernel(
    const FastTopKParams params,
    int32_t* __restrict__ topk_indices_ragged,
    const int32_t* __restrict__ topk_indices_offset) {
  const auto& [input, _, lengths, input_stride] = params;
  const auto row = static_cast<uint64_t>(blockIdx.x);
  const auto tid = threadIdx.x;
  const auto row_length = lengths[row];
  int32_t* dst_entry = topk_indices_ragged + row * TopK;
  const float* row_score = input + row * input_stride;
  const auto row_offset = topk_indices_offset[row];
  if (row_length <= TopK) {
    naive_topk_transform_ragged(row_score, row_length, dst_entry, row_offset);
    return;
  }
  __shared__ int s_indices[TopK];
  fast_topk_cuda_tl(row_score, s_indices, row_length);
  // Write offset-shifted indices; each thread handles exactly two slots.
  static_assert(TopK % kThreadsPerBlock == 0);
  static_assert(TopK / kThreadsPerBlock == 2);
#pragma unroll
  for (int slot = 0; slot < 2; ++slot) {
    const auto i = tid + slot * kThreadsPerBlock;
    dst_entry[i] = s_indices[i] + row_offset;
  }
}
// Validate score/lengths (and the optional indices tensor) and pack raw device
// pointers into a FastTopKParams. `indices` stays null when not supplied —
// the transform kernels produce their own outputs instead.
auto get_params(
    const at::Tensor& score,
    const at::Tensor& lengths,
    std::optional<at::Tensor> indices_opt = std::nullopt) -> FastTopKParams {
  const auto B = score.size(0);
  // score must be 2-D with a contiguous innermost dimension; lengths is [B].
  TORCH_CHECK(score.dim() == 2 && score.stride(1) == 1);
  TORCH_CHECK(lengths.dim() == 1 && lengths.is_contiguous());
  TORCH_CHECK(lengths.size(0) == B);
  int32_t* indices_data_ptr = nullptr;
  if (indices_opt) {
    const auto& indices = *indices_opt;
    TORCH_CHECK(indices.dim() == 2 && indices.is_contiguous());
    TORCH_CHECK(indices.size(0) == B);
    TORCH_CHECK(indices.size(1) == TopK);
    indices_data_ptr = indices.data_ptr<int32_t>();
  }
  FastTopKParams params;
  params.input = score.data_ptr<float>();
  params.indices = indices_data_ptr;
  params.lengths = lengths.data_ptr<int32_t>();
  params.input_stride = score.stride(0);
  return params;
}
// Raise kernel `f`'s dynamic shared-memory limit to `max_dynamic_smem` bytes.
// cudaFuncSetAttribute is required before launching with more than the
// default 48KB of dynamic shared memory. The attribute is set exactly once
// per kernel (function-local static); the cached status is re-checked on
// every call so a one-time failure keeps surfacing.
// Fix: the error message previously said "set_up_kernel_once", which matches
// no symbol in this file and made failures hard to trace.
template <auto* f, size_t max_dynamic_smem>
void setup_kernel_smem_once() {
  static const auto result =
      ::cudaFuncSetAttribute(f, ::cudaFuncAttributeMaxDynamicSharedMemorySize, max_dynamic_smem);
  TORCH_CHECK(result == cudaSuccess, "setup_kernel_smem_once failed: ", ::cudaGetErrorString(result));
}
} // namespace
// Guard: a tensor must live on a CUDA device before its data pointer is passed to a kernel.
#define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor")
// Host entry: compute top-k indices for each row of `score` into `indices`.
// Launches one block of kThreadsPerBlock threads per batch row with kSmem
// bytes of dynamic shared memory on the current PyTorch CUDA stream.
void fast_topk_interface(const at::Tensor& score, at::Tensor& indices, const at::Tensor& lengths) {
  CHECK_CUDA(score);
  CHECK_CUDA(indices);
  CHECK_CUDA(lengths);
  const auto params = get_params(score, lengths, indices);
  const auto batch = score.size(0);
  const auto stream = at::cuda::getCurrentCUDAStream().stream();
  const dim3 grid{static_cast<uint32_t>(batch)};
  const dim3 block{kThreadsPerBlock};
  setup_kernel_smem_once<topk_kernel, kSmem>();
  topk_kernel<<<grid, block, kSmem, stream>>>(params);
  const auto result = cudaGetLastError();
  TORCH_CHECK(result == cudaSuccess, "topk kernel failed:", ::cudaGetErrorString(result));
}
// Host entry for the page-table transform. Dispatches between the decode
// kernel (prefill_bs == B: row i maps directly to request i) and the prefill
// kernel (expanded rows must first locate their request via cu_seqlens_q).
void fast_topk_transform_interface(
    const at::Tensor& score,
    const at::Tensor& lengths,
    at::Tensor& dst_page_table,
    const at::Tensor& src_page_table,
    const at::Tensor& cu_seqlens_q) {
  CHECK_CUDA(score);
  CHECK_CUDA(lengths);
  CHECK_CUDA(dst_page_table);
  CHECK_CUDA(src_page_table);
  CHECK_CUDA(cu_seqlens_q);
  const auto params = get_params(score, lengths);
  const auto B = score.size(0);
  TORCH_CHECK(dst_page_table.dim() == 2 && dst_page_table.is_contiguous());
  TORCH_CHECK(src_page_table.dim() == 2 && src_page_table.stride(1) == 1);
  TORCH_CHECK(cu_seqlens_q.dim() == 1 && cu_seqlens_q.is_contiguous());
  const auto prefill_bs = cu_seqlens_q.size(0) - 1;
  TORCH_CHECK(dst_page_table.size(0) == B);
  TORCH_CHECK(dst_page_table.size(1) == TopK);
  TORCH_CHECK(src_page_table.size(0) == prefill_bs);
  TORCH_CHECK(prefill_bs <= B); // prefill_bs should be smaller than expanded bs
  // Launch configuration: one block per expanded row.
  const auto stream = at::cuda::getCurrentCUDAStream().stream();
  const dim3 grid{static_cast<uint32_t>(B)};
  const dim3 block{kThreadsPerBlock};
  const auto src_stride = src_page_table.stride(0);
  if (prefill_bs == B) {
    // decode: the row -> request mapping is the identity
    setup_kernel_smem_once<topk_transform_decode_kernel, kSmem>();
    topk_transform_decode_kernel<<<grid, block, kSmem, stream>>>(
        params, dst_page_table.data_ptr<int32_t>(), src_page_table.data_ptr<int32_t>(), src_stride);
  } else {
    setup_kernel_smem_once<topk_transform_prefill_kernel, kSmem>();
    topk_transform_prefill_kernel<<<grid, block, kSmem, stream>>>(
        params,
        dst_page_table.data_ptr<int32_t>(),
        src_page_table.data_ptr<int32_t>(),
        src_stride,
        cu_seqlens_q.data_ptr<int32_t>(),
        prefill_bs);
  }
  const auto result = cudaGetLastError();
  TORCH_CHECK(result == cudaSuccess, "topk kernel failed:", ::cudaGetErrorString(result));
}
// Host entry for the ragged-KV transform: per-row top-k indices shifted by
// topk_indices_offset are written into topk_indices_ragged.
void fast_topk_transform_ragged_interface(
    const at::Tensor& score,
    const at::Tensor& lengths,
    at::Tensor& topk_indices_ragged,
    const at::Tensor& topk_indices_offset) {
  CHECK_CUDA(score);
  CHECK_CUDA(lengths);
  CHECK_CUDA(topk_indices_ragged);
  CHECK_CUDA(topk_indices_offset);
  const auto params = get_params(score, lengths);
  const auto B = score.size(0);
  TORCH_CHECK(topk_indices_ragged.dim() == 2 && topk_indices_ragged.is_contiguous());
  TORCH_CHECK(topk_indices_offset.dim() == 1);
  TORCH_CHECK(topk_indices_ragged.size(0) == B);
  TORCH_CHECK(topk_indices_ragged.size(1) == TopK);
  TORCH_CHECK(topk_indices_offset.size(0) == B);
  // Launch configuration: one block per row.
  const auto stream = at::cuda::getCurrentCUDAStream().stream();
  const dim3 grid{static_cast<uint32_t>(B)};
  const dim3 block{kThreadsPerBlock};
  setup_kernel_smem_once<topk_transform_prefill_ragged_kernel, kSmem>();
  topk_transform_prefill_ragged_kernel<<<grid, block, kSmem, stream>>>(
      params, topk_indices_ragged.data_ptr<int32_t>(), topk_indices_offset.data_ptr<int32_t>());
  const auto result = cudaGetLastError();
  TORCH_CHECK(result == cudaSuccess, "topk kernel failed:", ::cudaGetErrorString(result));
}