496 lines
17 KiB
Plaintext
496 lines
17 KiB
Plaintext
/**
 * @NOTE: This file is adapted from
 * https://github.com/tile-ai/tilelang/blob/main/examples/deepseek_v32/topk_selector.py
 * Changes relative to the original:
 * 1. ported from TileLang to pure CUDA;
 * 2. minor performance optimizations;
 * 3. a fix for a potential illegal memory access.
 */
|
|
#include <ATen/core/TensorBase.h>
|
|
#include <ATen/core/TensorBody.h>
|
|
#include <c10/cuda/CUDAStream.h>
|
|
#include <c10/macros/Macros.h>
|
|
#include <c10/util/Exception.h>
|
|
#include <cuda.h>
|
|
#include <cuda_fp16.h>
|
|
|
|
#include <cstddef>
|
|
#include <cstdint>
|
|
#include <optional>
|
|
|
|
namespace {
|
|
|
|
// Number of indices selected per row by every kernel in this file.
constexpr int TopK = 2048;
// Threads per block; every kernel below is declared with
// __launch_bounds__(kThreadsPerBlock) and strides its loops by this value.
constexpr int kThreadsPerBlock = 1024;
// Dynamic shared memory requested per launch; fast_topk_cuda_tl splits it
// into two equally sized candidate-index buffers.
constexpr size_t kSmem = 32 * 1024 * sizeof(uint32_t); // 128KB
|
|
|
|
// Kernel arguments shared by all top-k kernels in this file (built by get_params).
struct FastTopKParams {
  // Scores; row b starts at input + b * input_stride and its columns are
  // contiguous (get_params requires stride(1) == 1).
  const float* __restrict__ input; // [B, input_stride]
  // Output indices; nullptr for the transform kernels, which write elsewhere.
  int32_t* __restrict__ indices; // [B, TopK]
  // Number of valid scores in each row.
  int32_t* __restrict__ lengths; // [B]
  // Distance, in elements, between consecutive rows of `input`.
  int64_t input_stride;
};
|
|
|
|
// Short-row path (length <= TopK): every valid position is part of the top-k,
// so slot i simply receives i for i < length and -1 otherwise. `score` is not
// read. Must be called by all threads of a kThreadsPerBlock-thread block, which
// cooperatively cover the TopK output slots.
__device__ void naive_topk_cuda(const float* __restrict__ score, int32_t* __restrict__ indice, int32_t length) {
  for (int slot = static_cast<int>(threadIdx.x); slot < TopK; slot += kThreadsPerBlock) {
    int32_t selected = -1;
    if (slot < length) {
      selected = slot;
    }
    indice[slot] = selected;
  }
}
|
|
|
|
// Short-row path of the page-table transform (length <= TopK): every valid
// position is selected, so dst[i] = src[i] for i < length and dst[i] = -1 for
// the remaining slots up to TopK. `score` is not read. Must be called by all
// threads of a kThreadsPerBlock-thread block.
__device__ void naive_topk_transform(
    const float* __restrict__ score,
    int32_t length,
    int32_t* __restrict__ dst_page_table,
    const int32_t* __restrict__ src_page_table) {
  // Signed index on purpose. With an unsigned index, `i < length` would convert
  // a negative length into a huge unsigned bound and read all TopK
  // src_page_table entries. The signed compare pads the row with -1 instead,
  // which matches naive_topk_cuda.
  for (int i = static_cast<int>(threadIdx.x); i < TopK; i += kThreadsPerBlock) {
    dst_page_table[i] = (i < length) ? src_page_table[i] : -1;
  }
}
|
|
|
|
// Short-row path of the ragged transform (length <= TopK): slot i receives
// i + offset for i < length and -1 for the remaining slots up to TopK.
// `score` is not read. Must be called by all threads of a
// kThreadsPerBlock-thread block.
__device__ void naive_topk_transform_ragged(
    const float* __restrict__ score, int32_t length, int32_t* __restrict__ topk_indices_ragged, int32_t offset) {
  // Signed index so that `i < length` is a signed comparison. This is
  // consistent with naive_topk_cuda, and a negative length yields an
  // all -1 row.
  for (int i = static_cast<int>(threadIdx.x); i < TopK; i += kThreadsPerBlock) {
    topk_indices_ragged[i] = (i < length) ? i + offset : -1;
  }
}
|
|
|
|
// Coarse 8-bit bucket of `x`, non-decreasing in x (NaNs aside).
// x is rounded to fp16, and its bit pattern is mapped to an order-preserving
// unsigned key: negatives have every bit flipped, and non-negatives get the
// sign bit set. The top 8 bits of that key are returned.
__device__ __forceinline__ auto convert_to_uint8(float x) -> uint8_t {
  const uint16_t half_bits = __half_as_ushort(__float2half_rn(x));
  const bool is_negative = (half_bits & 0x8000) != 0;
  uint16_t ordered_key;
  if (is_negative) {
    ordered_key = static_cast<uint16_t>(~half_bits);
  } else {
    ordered_key = static_cast<uint16_t>(half_bits | 0x8000);
  }
  return static_cast<uint8_t>(ordered_key >> 8);
}
|
|
|
|
// Order-preserving 32-bit key for a float. Larger floats map to larger
// unsigned values: flip every bit of negatives, and set the sign bit of
// non-negatives.
__device__ __forceinline__ auto convert_to_uint32(float x) -> uint32_t {
  constexpr uint32_t kSignBit = 0x80000000u;
  const uint32_t raw_bits = __float_as_uint(x);
  if (raw_bits & kSignBit) {
    return ~raw_bits;
  }
  return raw_bits | kSignBit;
}
|
|
|
|
// Top-TopK selection for a single row, executed cooperatively by one thread
// block. Writes the positions of the TopK largest entries of input[0, length)
// into index[0, TopK). Slots are claimed with atomics, so the output order is
// not deterministic.
//
// Algorithm: radix select over order-preserving unsigned keys.
//   * Stage 1 builds a 256-bin histogram of convert_to_uint8 (the top byte of
//     the fp16-rounded key) and finds the bin holding the TopK-th largest
//     value. Entries in higher bins are emitted directly. Entries in that
//     threshold bin are buffered in s_input_idx[0] for refinement.
//   * Stage 2 refines the buffered candidates with up to four 8-bit passes over
//     the fp32 key (bits 31..24, 23..16, 15..8, 7..0), ping-ponging between
//     s_input_idx[0] and s_input_idx[1]. Remaining ties in the last pass fill
//     the tail of `index`, so exactly TopK slots are written.
//
// Preconditions visible in the code:
//   * It is called by all threads of the block (it uses __syncthreads()).
//   * blockDim.x == BLOCK_SIZE (loops stride by BLOCK_SIZE, and the histogram
//     needs threads 0..RADIX).
//   * kSmem bytes of dynamic shared memory back s_input_idx.
//   * length > TopK. Otherwise no threshold bin is found and
//     s_threshold_bin_id is read uninitialized.
//
// NOTE: each pass buffers at most SMEM_INPUT_SIZE candidates. Further entries
// tied in the current bin are dropped and never selected. With that many ties,
// the output still holds TopK valid positions, but the tie-break among the bin
// is not exact.
__device__ void fast_topk_cuda_tl(const float* __restrict__ input, int* __restrict__ index, int length) {
  // An optimized topk kernel copied from tilelang kernel
  // We assume length > TopK here, or it will crash
  int topk = TopK;  // number of entries still to be selected
  constexpr auto BLOCK_SIZE = 1024;
  constexpr auto RADIX = 256;
  // Capacity (in ints) of each of the two candidate buffers carved out of kSmem.
  constexpr auto SMEM_INPUT_SIZE = kSmem / (2 * sizeof(int));

  // Two histogram buffers so run_cumsum can ping-pong between them. Index
  // RADIX is a zero sentinel, read via s_histogram[tx + 1].
  alignas(128) __shared__ int s_histogram_buf[2][RADIX + 128];
  alignas(128) __shared__ int s_counter;           // next free slot in `index`
  alignas(128) __shared__ int s_threshold_bin_id;  // bin holding the TopK-th largest key
  alignas(128) __shared__ int s_num_input[2];      // candidate count of each buffer

  auto& s_histogram = s_histogram_buf[0];
  // allocate for two rounds
  extern __shared__ int s_input_idx[][SMEM_INPUT_SIZE];

  const int tx = threadIdx.x;

  // stage 1: 8bit coarse histogram
  if (tx < RADIX + 1) s_histogram[tx] = 0;
  __syncthreads();

  for (int idx = tx; idx < length; idx += BLOCK_SIZE) {
    const auto bin = convert_to_uint8(input[idx]);
    ::atomicAdd(&s_histogram[bin], 1);
  }
  __syncthreads();

  // Suffix sum in 8 doubling steps that alternate between the two buffers.
  // Afterwards s_histogram[b] holds the number of entries with bin >= b.
  // The step count is even, so the result lands back in s_histogram_buf[0].
  const auto run_cumsum = [&] {
#pragma unroll 8
    for (int i = 0; i < 8; ++i) {
      static_assert(1 << 8 == RADIX);
      if (C10_LIKELY(tx < RADIX)) {
        const auto j = 1 << i;
        const auto k = i & 1;
        auto value = s_histogram_buf[k][tx];
        if (tx < RADIX - j) {
          value += s_histogram_buf[k][tx + j];
        }
        s_histogram_buf[k ^ 1][tx] = value;
      }
      __syncthreads();
    }
  };

  run_cumsum();
  // Threshold bin b: count(bin >= b) > topk and count(bin > b) <= topk.
  if (tx < RADIX && s_histogram[tx] > topk && s_histogram[tx + 1] <= topk) {
    s_threshold_bin_id = tx;
    s_num_input[0] = 0;
    s_counter = 0;
  }
  __syncthreads();

  const auto threshold_bin = s_threshold_bin_id;
  // Every entry above the threshold bin is selected. The remainder must come
  // from the threshold bin itself.
  topk -= s_histogram[threshold_bin + 1];

  if (topk == 0) {
    // Bins above the threshold hold exactly TopK entries: emit them and stop.
    for (int idx = tx; idx < length; idx += BLOCK_SIZE) {
      const auto bin = static_cast<int>(convert_to_uint8(input[idx]));
      if (bin > threshold_bin) {
        const auto pos = ::atomicAdd(&s_counter, 1);
        index[pos] = idx;
      }
    }
    __syncthreads();
    return;
  } else {
    // Every thread has read s_histogram above; only now may it be cleared.
    __syncthreads();
    if (tx < RADIX + 1) {
      s_histogram[tx] = 0;
    }
    __syncthreads();

    for (int idx = tx; idx < length; idx += BLOCK_SIZE) {
      const auto raw_input = input[idx];
      const auto bin = static_cast<int>(convert_to_uint8(raw_input));
      if (bin > threshold_bin) {
        const auto pos = ::atomicAdd(&s_counter, 1);
        index[pos] = idx;
      } else if (bin == threshold_bin) {
        const auto pos = ::atomicAdd(&s_num_input[0], 1);
        /// NOTE: (dark) fuse the histogram computation here
        // Candidates past the buffer capacity are dropped (never selected).
        if (C10_LIKELY(pos < SMEM_INPUT_SIZE)) {
          s_input_idx[0][pos] = idx;
          const auto bin = convert_to_uint32(raw_input);
          // Histogram of fp32 key bits 31..24 for the first refinement pass.
          const auto sub_bin = (bin >> 24) & 0xFF;
          ::atomicAdd(&s_histogram[sub_bin], 1);
        }
      }
    }
    __syncthreads();
  }

  // stage 2: refine with 8bit radix passes
  // Round r inspects fp32 key bits [24 - 8r, 31 - 8r]. It reads candidates from
  // s_input_idx[r % 2] and buffers threshold-bin survivors in the other half.
#pragma unroll 4
  for (int round = 0; round < 4; ++round) {
    __shared__ int s_last_remain;  // open tie slots; consumed in round 3 only
    const auto r_idx = round % 2;

    // clip here to prevent overflow
    const auto _raw_num_input = s_num_input[r_idx];
    const auto num_input = (_raw_num_input < int(SMEM_INPUT_SIZE)) ? _raw_num_input : int(SMEM_INPUT_SIZE);

    run_cumsum();
    if (tx < RADIX && s_histogram[tx] > topk && s_histogram[tx + 1] <= topk) {
      s_threshold_bin_id = tx;
      s_num_input[r_idx ^ 1] = 0;
      s_last_remain = topk - s_histogram[tx + 1];
    }
    __syncthreads();

    const auto threshold_bin = s_threshold_bin_id;
    topk -= s_histogram[threshold_bin + 1];

    if (topk == 0) {
      // Candidates above the threshold complete the selection.
      for (int i = tx; i < num_input; i += BLOCK_SIZE) {
        const auto idx = s_input_idx[r_idx][i];
        const auto offset = 24 - round * 8;
        const auto bin = (convert_to_uint32(input[idx]) >> offset) & 0xFF;
        if (bin > threshold_bin) {
          const auto pos = ::atomicAdd(&s_counter, 1);
          index[pos] = idx;
        }
      }
      __syncthreads();
      break;
    } else {
      // Wait for all reads of s_histogram before clearing it for the next pass.
      __syncthreads();
      if (tx < RADIX + 1) {
        s_histogram[tx] = 0;
      }
      __syncthreads();
      for (int i = tx; i < num_input; i += BLOCK_SIZE) {
        const auto idx = s_input_idx[r_idx][i];
        const auto raw_input = input[idx];
        const auto offset = 24 - round * 8;
        const auto bin = (convert_to_uint32(raw_input) >> offset) & 0xFF;
        if (bin > threshold_bin) {
          const auto pos = ::atomicAdd(&s_counter, 1);
          index[pos] = idx;
        } else if (bin == threshold_bin) {
          if (round == 3) {
            // Last pass: the remaining candidates have identical 32-bit keys.
            // Take as many as still needed, filling `index` from its tail.
            const auto pos = ::atomicAdd(&s_last_remain, -1);
            if (pos > 0) {
              index[TopK - pos] = idx;
            }
          } else {
            const auto pos = ::atomicAdd(&s_num_input[r_idx ^ 1], 1);
            if (C10_LIKELY(pos < SMEM_INPUT_SIZE)) {
              /// NOTE: (dark) fuse the histogram computation here
              s_input_idx[r_idx ^ 1][pos] = idx;
              const auto bin = convert_to_uint32(raw_input);
              // Next pass's byte of the fp32 key.
              const auto sub_bin = (bin >> (offset - 8)) & 0xFF;
              ::atomicAdd(&s_histogram[sub_bin], 1);
            }
          }
        }
      }
      __syncthreads();
    }
  }
}
|
|
|
|
// Plain top-k kernel: one block per row (blockIdx.x selects the row).
// Writes TopK indices into params.indices[row]. Short rows
// (length <= TopK) get 0..length-1 followed by -1 padding.
// Launch with kThreadsPerBlock threads and kSmem bytes of dynamic shared memory.
__global__ __launch_bounds__(kThreadsPerBlock) // topk
void topk_kernel(const FastTopKParams params) {
  const uint64_t row = blockIdx.x;
  const int32_t row_length = params.lengths[row];
  int32_t* const row_indices = params.indices + row * TopK;
  const float* const row_scores = params.input + row * params.input_stride;

  if (row_length > TopK) {
    fast_topk_cuda_tl(row_scores, row_indices, row_length);
    return;
  }
  naive_topk_cuda(row_scores, row_indices, row_length);
}
|
|
|
|
// Decode variant of the page-table transform: row `b` selects its top-k
// positions and writes src_page_table[b][pos] for each selected pos into
// dst_page_table[b] (TopK entries; -1 padding for short rows).
// Launch: one block per row, kThreadsPerBlock threads, kSmem bytes of
// dynamic shared memory.
__global__ __launch_bounds__(kThreadsPerBlock) // decode
void topk_transform_decode_kernel(
    const FastTopKParams params,
    int32_t* __restrict__ dst_page_table,
    const int32_t* __restrict__ src_page_table,
    const int64_t src_stride) {
  const uint64_t row = blockIdx.x;
  const int32_t row_length = params.lengths[row];
  const int32_t* const src_row = src_page_table + row * src_stride;
  int32_t* const dst_row = dst_page_table + row * TopK;
  const float* const row_scores = params.input + row * params.input_stride;

  if (row_length <= TopK) {
    naive_topk_transform(row_scores, row_length, dst_row, src_row);
    return;
  }

  __shared__ int s_indices[TopK];
  // Ends with a block-wide barrier, so s_indices is visible to every thread.
  fast_topk_cuda_tl(row_scores, s_indices, row_length);

  // Gather src[s_indices] into dst; each thread handles TopK / kThreadsPerBlock slots.
  static_assert(TopK % kThreadsPerBlock == 0);
  static_assert(TopK / kThreadsPerBlock == 2);
#pragma unroll
  for (int k = 0; k < TopK / kThreadsPerBlock; ++k) {
    const int slot = static_cast<int>(threadIdx.x) + k * kThreadsPerBlock;
    dst_row[slot] = src_row[s_indices[slot]];
  }
}
|
|
|
|
// Prefill variant of the page-table transform. Row `bid` belongs to request i
// when cu_seqlens_q[i] <= bid < cu_seqlens_q[i + 1]. The selected positions
// are looked up in row i of src_page_table and written to row `bid` of
// dst_page_table (TopK entries, -1 padded).
// If no request's range contains `bid`, the row is filled with -1.
// Launch: one block per row, kThreadsPerBlock threads, kSmem bytes of
// dynamic shared memory.
__global__ __launch_bounds__(kThreadsPerBlock) // prefill
void topk_transform_prefill_kernel(
    const FastTopKParams params,
    int32_t* __restrict__ dst_page_table,
    const int32_t* __restrict__ src_page_table,
    const int64_t src_stride,
    const int32_t* __restrict__ cu_seqlens_q,
    const int64_t prefill_bs) {
  const auto& [input, _, lengths, input_stride] = params;
  const auto bid = static_cast<uint64_t>(blockIdx.x);
  const auto tid = threadIdx.x;
  const auto length = lengths[bid];
  const auto dst_page_entry = dst_page_table + bid * TopK;
  const auto score = input + bid * input_stride;

  /// NOTE: prefill bs is usually small, we can just use a simple loop here
  /// We ensure that last cu_seqlens is equal to number of blocks launched
  // Start from a known value. If cu_seqlens_q does not cover `bid`, nobody
  // writes the pointer below, and it would otherwise be read uninitialized.
  __shared__ const int32_t* s_src_page_entry;
  if (tid == 0) {
    s_src_page_entry = nullptr;
  }
  __syncthreads();
  // One loop covers any prefill_bs. When prefill_bs <= kThreadsPerBlock, each
  // thread tests at most one request.
  for (int64_t i = tid; i < prefill_bs; i += kThreadsPerBlock) {
    if (bid >= cu_seqlens_q[i] && bid < cu_seqlens_q[i + 1]) {
      s_src_page_entry = src_page_table + i * src_stride;
    }
  }
  __syncthreads();
  const auto src_page_entry = s_src_page_entry;

  // Uncovered row: emit padding only. The test reads shared memory after a
  // barrier, so the whole block takes this early return together.
  if (src_page_entry == nullptr) {
    for (int i = static_cast<int>(tid); i < TopK; i += kThreadsPerBlock) {
      dst_page_entry[i] = -1;
    }
    return;
  }

  if (length <= TopK) {
    return naive_topk_transform(score, length, dst_page_entry, src_page_entry);
  } else {
    __shared__ int s_indices[TopK];
    // Ends with a block-wide barrier, so s_indices is visible to every thread.
    fast_topk_cuda_tl(score, s_indices, length);
    // copy src[s_indices] to dst, we manually unroll here
    static_assert(TopK % kThreadsPerBlock == 0);
    static_assert(TopK / kThreadsPerBlock == 2);
    const auto idx_0 = tid;
    const auto pos_0 = s_indices[idx_0];
    dst_page_entry[idx_0] = src_page_entry[pos_0];
    const auto idx_1 = tid + kThreadsPerBlock;
    const auto pos_1 = s_indices[idx_1];
    dst_page_entry[idx_1] = src_page_entry[pos_1];
  }
}
|
|
|
|
// Ragged-KV prefill variant. Row `b` writes its selected positions plus
// topk_indices_offset[b] into topk_indices_ragged[b] (TopK entries; short
// rows are -1 padded).
// Launch: one block per row, kThreadsPerBlock threads, kSmem bytes of
// dynamic shared memory.
__global__ __launch_bounds__(kThreadsPerBlock) // prefill, ragged kv
void topk_transform_prefill_ragged_kernel(
    const FastTopKParams params,
    int32_t* __restrict__ topk_indices_ragged,
    const int32_t* __restrict__ topk_indices_offset) {
  const uint64_t row = blockIdx.x;
  const int32_t row_length = params.lengths[row];
  const int32_t row_offset = topk_indices_offset[row];
  int32_t* const dst_row = topk_indices_ragged + row * TopK;
  const float* const row_scores = params.input + row * params.input_stride;

  if (row_length <= TopK) {
    naive_topk_transform_ragged(row_scores, row_length, dst_row, row_offset);
    return;
  }

  __shared__ int s_indices[TopK];
  // Ends with a block-wide barrier, so s_indices is visible to every thread.
  fast_topk_cuda_tl(row_scores, s_indices, row_length);

  // Shift each selected position by the row offset; TopK / kThreadsPerBlock slots per thread.
  static_assert(TopK % kThreadsPerBlock == 0);
  static_assert(TopK / kThreadsPerBlock == 2);
#pragma unroll
  for (int k = 0; k < TopK / kThreadsPerBlock; ++k) {
    const int slot = static_cast<int>(threadIdx.x) + k * kThreadsPerBlock;
    dst_row[slot] = s_indices[slot] + row_offset;
  }
}
|
|
|
|
// Validates the input tensors and packs their device pointers into FastTopKParams.
//   score       : 2D float32 tensor; rows may be strided, columns must be contiguous.
//   lengths     : contiguous 1D int32 tensor with score.size(0) entries.
//   indices_opt : optional contiguous [B, TopK] int32 output (nullptr when absent).
// Throws c10::Error (via TORCH_CHECK or data_ptr<T>'s dtype check) on any mismatch.
auto get_params(
    const at::Tensor& score,
    const at::Tensor& lengths,
    std::optional<at::Tensor> indices_opt = std::nullopt) -> FastTopKParams {
  // Check the rank before calling size(0), so a malformed `score` gets a
  // descriptive message instead of an index error.
  TORCH_CHECK(
      score.dim() == 2 && score.stride(1) == 1, "score must be a 2D tensor whose last dimension is contiguous");
  const auto B = score.size(0);
  TORCH_CHECK(lengths.dim() == 1 && lengths.is_contiguous(), "lengths must be a contiguous 1D tensor");
  TORCH_CHECK(lengths.size(0) == B, "lengths must have one entry per row of score");
  int32_t* indices_data_ptr = nullptr;
  if (indices_opt.has_value()) {
    const auto& indices = indices_opt.value();
    TORCH_CHECK(indices.dim() == 2 && indices.is_contiguous(), "indices must be a contiguous 2D tensor");
    TORCH_CHECK(indices.size(0) == B, "indices must have one row per row of score");
    TORCH_CHECK(indices.size(1) == TopK, "indices must have TopK columns");
    indices_data_ptr = indices.data_ptr<int32_t>();
  }

  return FastTopKParams{
      .input = score.data_ptr<float>(),
      .indices = indices_data_ptr,
      .lengths = lengths.data_ptr<int32_t>(),
      .input_stride = score.stride(0),
  };
}
|
|
|
|
// Raises kernel `f`'s dynamic shared-memory limit to `max_dynamic_smem` bytes.
// The function-local static runs cudaFuncSetAttribute only on the first call
// of each (kernel, size) instantiation. The cached status is re-checked on
// every call, so a failed setup keeps failing loudly.
// NOTE(review): the attribute is set once, while whichever device is current
// on that first call is active. Confirm this suffices for multi-GPU use.
template <auto* f, size_t max_dynamic_smem>
void setup_kernel_smem_once() {
  static const cudaError_t setup_status =
      ::cudaFuncSetAttribute(f, ::cudaFuncAttributeMaxDynamicSharedMemorySize, max_dynamic_smem);
  TORCH_CHECK(setup_status == cudaSuccess, "set_up_kernel_once failed:", ::cudaGetErrorString(setup_status));
}
|
|
|
|
} // namespace
|
|
|
|
// Raises a TORCH_CHECK error unless tensor `x` is on a CUDA device.
#define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor")
|
|
|
|
// Host entry point. For each row b of `score`, writes the indices of the TopK
// largest values in score[b, 0:lengths[b]] into indices[b]. When
// lengths[b] <= TopK the row is 0..lengths[b]-1 followed by -1 padding.
// Runs asynchronously on the current CUDA stream; launch errors raise c10::Error.
void fast_topk_interface(const at::Tensor& score, at::Tensor& indices, const at::Tensor& lengths) {
  CHECK_CUDA(score);
  CHECK_CUDA(indices);
  CHECK_CUDA(lengths);
  const auto params = get_params(score, lengths, indices);
  const auto B = score.size(0);
  // Nothing to compute for an empty batch. A zero-block grid would also be
  // rejected by the runtime (cudaErrorInvalidConfiguration).
  if (B == 0) {
    return;
  }
  const auto stream = at::cuda::getCurrentCUDAStream().stream();
  const auto grid = dim3{static_cast<uint32_t>(B)};
  const auto block = dim3{kThreadsPerBlock};
  setup_kernel_smem_once<topk_kernel, kSmem>();
  topk_kernel<<<grid, block, kSmem, stream>>>(params);
  const auto result = cudaGetLastError();
  TORCH_CHECK(result == cudaSuccess, "topk kernel failed:", ::cudaGetErrorString(result));
}
|
|
|
|
// Host entry point for the page-table transform. Each row of `score` selects
// its top-k positions and writes the matching src_page_table entries into
// dst_page_table ([B, TopK], -1 padded).
// Dispatch: if there is one request per row (prefill_bs == B), the decode
// kernel pairs row b with src_page_table[b]. Otherwise the prefill kernel maps
// rows to requests through cu_seqlens_q.
// Runs asynchronously on the current CUDA stream.
void fast_topk_transform_interface(
    const at::Tensor& score,
    const at::Tensor& lengths,
    at::Tensor& dst_page_table,
    const at::Tensor& src_page_table,
    const at::Tensor& cu_seqlens_q) {
  CHECK_CUDA(score);
  CHECK_CUDA(lengths);
  CHECK_CUDA(dst_page_table);
  CHECK_CUDA(src_page_table);
  CHECK_CUDA(cu_seqlens_q);
  const auto params = get_params(score, lengths);
  const auto B = score.size(0);
  TORCH_CHECK(dst_page_table.dim() == 2 && dst_page_table.is_contiguous());
  TORCH_CHECK(src_page_table.dim() == 2 && src_page_table.stride(1) == 1);
  TORCH_CHECK(cu_seqlens_q.dim() == 1 && cu_seqlens_q.is_contiguous());
  const auto prefill_bs = cu_seqlens_q.size(0) - 1;
  TORCH_CHECK(dst_page_table.size(0) == B);
  TORCH_CHECK(dst_page_table.size(1) == TopK);
  TORCH_CHECK(src_page_table.size(0) == prefill_bs);
  TORCH_CHECK(prefill_bs <= B); // prefill_bs should be smaller than expanded bs

  // Nothing to compute for an empty batch. A zero-block grid would also be
  // rejected by the runtime (cudaErrorInvalidConfiguration).
  if (B == 0) {
    return;
  }

  // launch kernel
  const auto stream = at::cuda::getCurrentCUDAStream().stream();
  const auto grid = dim3{static_cast<uint32_t>(B)};
  const auto block = dim3{kThreadsPerBlock};
  const auto src_stride = src_page_table.stride(0);

  // dispatch to decode or prefill
  const auto is_decode = (prefill_bs == B);
  if (is_decode) {
    setup_kernel_smem_once<topk_transform_decode_kernel, kSmem>();
    topk_transform_decode_kernel<<<grid, block, kSmem, stream>>>(
        params, dst_page_table.data_ptr<int32_t>(), src_page_table.data_ptr<int32_t>(), src_stride);
  } else {
    setup_kernel_smem_once<topk_transform_prefill_kernel, kSmem>();
    topk_transform_prefill_kernel<<<grid, block, kSmem, stream>>>(
        params,
        dst_page_table.data_ptr<int32_t>(),
        src_page_table.data_ptr<int32_t>(),
        src_stride,
        cu_seqlens_q.data_ptr<int32_t>(),
        prefill_bs);
  }

  const auto result = cudaGetLastError();
  TORCH_CHECK(result == cudaSuccess, "topk kernel failed:", ::cudaGetErrorString(result));
}
|
|
|
|
// Host entry point for the ragged-KV transform. Row b writes its selected
// positions shifted by topk_indices_offset[b] into topk_indices_ragged[b]
// ([B, TopK], -1 padded). Runs asynchronously on the current CUDA stream.
void fast_topk_transform_ragged_interface(
    const at::Tensor& score,
    const at::Tensor& lengths,
    at::Tensor& topk_indices_ragged,
    const at::Tensor& topk_indices_offset) {
  CHECK_CUDA(score);
  CHECK_CUDA(lengths);
  CHECK_CUDA(topk_indices_ragged);
  CHECK_CUDA(topk_indices_offset);

  const auto params = get_params(score, lengths);
  const auto B = score.size(0);
  TORCH_CHECK(topk_indices_ragged.dim() == 2 && topk_indices_ragged.is_contiguous());
  // The kernel reads topk_indices_offset[bid] from the raw pointer, so the
  // tensor must be contiguous. A strided view would yield wrong offsets.
  TORCH_CHECK(topk_indices_offset.dim() == 1 && topk_indices_offset.is_contiguous());

  TORCH_CHECK(topk_indices_ragged.size(0) == B);
  TORCH_CHECK(topk_indices_ragged.size(1) == TopK);
  TORCH_CHECK(topk_indices_offset.size(0) == B);

  // Nothing to compute for an empty batch. A zero-block grid would also be
  // rejected by the runtime (cudaErrorInvalidConfiguration).
  if (B == 0) {
    return;
  }

  // launch kernel
  const auto stream = at::cuda::getCurrentCUDAStream().stream();
  const auto grid = dim3{static_cast<uint32_t>(B)};
  const auto block = dim3{kThreadsPerBlock};

  setup_kernel_smem_once<topk_transform_prefill_ragged_kernel, kSmem>();
  topk_transform_prefill_ragged_kernel<<<grid, block, kSmem, stream>>>(
      params, topk_indices_ragged.data_ptr<int32_t>(), topk_indices_offset.data_ptr<int32_t>());

  const auto result = cudaGetLastError();
  TORCH_CHECK(result == cudaSuccess, "topk kernel failed:", ::cudaGetErrorString(result));
}
|