[Feature] Add a fast-topk to sgl-kernel for DeepSeek v3.2 (#11194)
Co-authored-by: Baizhou Zhang <sobereddiezhang@gmail.com>
This commit is contained in:
@@ -268,6 +268,7 @@ set(SOURCES
|
||||
"csrc/elementwise/concat_mla.cu"
|
||||
"csrc/elementwise/fused_add_rms_norm_kernel.cu"
|
||||
"csrc/elementwise/rope.cu"
|
||||
"csrc/elementwise/topk.cu"
|
||||
"csrc/common_extension.cc"
|
||||
|
||||
"csrc/gemm/awq_kernel.cu"
|
||||
|
||||
@@ -107,6 +107,13 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
|
||||
m.def("concat_mla_absorb_q(Tensor a, Tensor b, Tensor! out) -> ()");
|
||||
m.impl("concat_mla_absorb_q", torch::kCUDA, &concat_mla_absorb_q);
|
||||
|
||||
m.def("fast_topk(Tensor score, Tensor indices, Tensor lengths) -> ()");
|
||||
m.impl("fast_topk", torch::kCUDA, &fast_topk_interface);
|
||||
m.def(
|
||||
"fast_topk_transform_fused(Tensor score, Tensor lengths, Tensor dst_page_table, Tensor src_page_table, Tensor "
|
||||
"cu_seqlens_q) -> ()");
|
||||
m.impl("fast_topk_transform_fused", torch::kCUDA, &fast_topk_transform_interface);
|
||||
|
||||
/*
|
||||
* From csrc/gemm
|
||||
*/
|
||||
|
||||
422
sgl-kernel/csrc/elementwise/topk.cu
Normal file
422
sgl-kernel/csrc/elementwise/topk.cu
Normal file
@@ -0,0 +1,422 @@
|
||||
/**
|
||||
* @NOTE: This file is adapted from
|
||||
* https://github.com/tile-ai/tilelang/blob/main/examples/deepseek_v32/topk_selector.py
|
||||
* We:
|
||||
* 1. adapt from tilelang to pure cuda
|
||||
* 2. optimize the performance a little
|
||||
* 3. fix the potential illegal memory access
|
||||
*/
|
||||
#include <ATen/core/TensorBase.h>
|
||||
#include <ATen/core/TensorBody.h>
|
||||
#include <c10/cuda/CUDAStream.h>
|
||||
#include <c10/macros/Macros.h>
|
||||
#include <c10/util/Exception.h>
|
||||
#include <cuda.h>
|
||||
#include <cuda_fp16.h>
|
||||
|
||||
#include <cstddef>
|
||||
#include <cstdint>
|
||||
#include <optional>
|
||||
|
||||
namespace {
|
||||
|
||||
constexpr int TopK = 2048;
|
||||
constexpr int kThreadsPerBlock = 1024;
|
||||
constexpr size_t kSmem = 32 * 1024 * sizeof(uint32_t); // 128KB
|
||||
|
||||
struct FastTopKParams {
|
||||
const float* __restrict__ input; // [B, input_stride]
|
||||
int32_t* __restrict__ indices; // [B, TopK]
|
||||
int32_t* __restrict__ lengths; // [B]
|
||||
int64_t input_stride;
|
||||
};
|
||||
|
||||
// when length <= TopK, we can directly write the indices
|
||||
__device__ void naive_topk_cuda(const float* __restrict__ score, int32_t* __restrict__ indice, int32_t length) {
|
||||
const auto tid = threadIdx.x;
|
||||
for (int i = tid; i < TopK; i += kThreadsPerBlock) {
|
||||
indice[i] = (i < length) ? i : -1;
|
||||
}
|
||||
}
|
||||
|
||||
// keep the first `length` entries, set others to -1
|
||||
__device__ void naive_topk_transform(
|
||||
const float* __restrict__ score,
|
||||
int32_t length,
|
||||
int32_t* __restrict__ dst_page_table,
|
||||
const int32_t* __restrict__ src_page_table) {
|
||||
const auto tid = threadIdx.x;
|
||||
for (auto i = tid; i < TopK; i += kThreadsPerBlock) {
|
||||
dst_page_table[i] = (i < length) ? src_page_table[i] : -1;
|
||||
}
|
||||
}
|
||||
|
||||
__device__ __forceinline__ auto convert_to_uint8(float x) -> uint8_t {
|
||||
__half h = __float2half_rn(x);
|
||||
uint16_t bits = __half_as_ushort(h);
|
||||
uint16_t key = (bits & 0x8000) ? static_cast<uint16_t>(~bits) : static_cast<uint16_t>(bits | 0x8000);
|
||||
return static_cast<uint8_t>(key >> 8);
|
||||
}
|
||||
|
||||
__device__ __forceinline__ auto convert_to_uint32(float x) -> uint32_t {
|
||||
uint32_t bits = __float_as_uint(x);
|
||||
return (bits & 0x80000000u) ? ~bits : (bits | 0x80000000u);
|
||||
}
|
||||
|
||||
__device__ void fast_topk_cuda_tl(const float* __restrict__ input, int* __restrict__ index, int length) {
|
||||
// An optimized topk kernel copied from tilelang kernel
|
||||
// We assume length > TopK here, or it will crash
|
||||
int topk = TopK;
|
||||
constexpr auto BLOCK_SIZE = 1024;
|
||||
constexpr auto RADIX = 256;
|
||||
constexpr auto SMEM_INPUT_SIZE = kSmem / (2 * sizeof(int));
|
||||
|
||||
alignas(128) __shared__ int s_histogram_buf[2][RADIX + 128];
|
||||
alignas(128) __shared__ int s_counter;
|
||||
alignas(128) __shared__ int s_threshold_bin_id;
|
||||
alignas(128) __shared__ int s_num_input[2];
|
||||
|
||||
auto& s_histogram = s_histogram_buf[0];
|
||||
// allocate for two rounds
|
||||
extern __shared__ int s_input_idx[][SMEM_INPUT_SIZE];
|
||||
|
||||
const int tx = threadIdx.x;
|
||||
|
||||
// stage 1: 8bit coarse histogram
|
||||
if (tx < RADIX + 1) s_histogram[tx] = 0;
|
||||
__syncthreads();
|
||||
|
||||
for (int idx = tx; idx < length; idx += BLOCK_SIZE) {
|
||||
const auto bin = convert_to_uint8(input[idx]);
|
||||
::atomicAdd(&s_histogram[bin], 1);
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
const auto run_cumsum = [&] {
|
||||
#pragma unroll 8
|
||||
for (int i = 0; i < 8; ++i) {
|
||||
static_assert(1 << 8 == RADIX);
|
||||
if (C10_LIKELY(tx < RADIX)) {
|
||||
const auto j = 1 << i;
|
||||
const auto k = i & 1;
|
||||
auto value = s_histogram_buf[k][tx];
|
||||
if (tx < RADIX - j) {
|
||||
value += s_histogram_buf[k][tx + j];
|
||||
}
|
||||
s_histogram_buf[k ^ 1][tx] = value;
|
||||
}
|
||||
__syncthreads();
|
||||
}
|
||||
};
|
||||
|
||||
run_cumsum();
|
||||
if (tx < RADIX && s_histogram[tx] > topk && s_histogram[tx + 1] <= topk) {
|
||||
s_threshold_bin_id = tx;
|
||||
s_num_input[0] = 0;
|
||||
s_counter = 0;
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
const auto threshold_bin = s_threshold_bin_id;
|
||||
topk -= s_histogram[threshold_bin + 1];
|
||||
|
||||
if (topk == 0) {
|
||||
for (int idx = tx; idx < length; idx += BLOCK_SIZE) {
|
||||
const auto bin = static_cast<int>(convert_to_uint8(input[idx]));
|
||||
if (bin > threshold_bin) {
|
||||
const auto pos = ::atomicAdd(&s_counter, 1);
|
||||
index[pos] = idx;
|
||||
}
|
||||
}
|
||||
__syncthreads();
|
||||
return;
|
||||
} else {
|
||||
__syncthreads();
|
||||
if (tx < RADIX + 1) {
|
||||
s_histogram[tx] = 0;
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
for (int idx = tx; idx < length; idx += BLOCK_SIZE) {
|
||||
const auto raw_input = input[idx];
|
||||
const auto bin = static_cast<int>(convert_to_uint8(raw_input));
|
||||
if (bin > threshold_bin) {
|
||||
const auto pos = ::atomicAdd(&s_counter, 1);
|
||||
index[pos] = idx;
|
||||
} else if (bin == threshold_bin) {
|
||||
const auto pos = ::atomicAdd(&s_num_input[0], 1);
|
||||
/// NOTE: (dark) fuse the histogram computation here
|
||||
if (C10_LIKELY(pos < SMEM_INPUT_SIZE)) {
|
||||
s_input_idx[0][pos] = idx;
|
||||
const auto bin = convert_to_uint32(raw_input);
|
||||
const auto sub_bin = (bin >> 24) & 0xFF;
|
||||
::atomicAdd(&s_histogram[sub_bin], 1);
|
||||
}
|
||||
}
|
||||
}
|
||||
__syncthreads();
|
||||
}
|
||||
|
||||
// stage 2: refine with 8bit radix passes
|
||||
#pragma unroll 4
|
||||
for (int round = 0; round < 4; ++round) {
|
||||
__shared__ int s_last_remain;
|
||||
const auto r_idx = round % 2;
|
||||
|
||||
// clip here to prevent overflow
|
||||
const auto _raw_num_input = s_num_input[r_idx];
|
||||
const auto num_input = (_raw_num_input < int(SMEM_INPUT_SIZE)) ? _raw_num_input : int(SMEM_INPUT_SIZE);
|
||||
|
||||
run_cumsum();
|
||||
if (tx < RADIX && s_histogram[tx] > topk && s_histogram[tx + 1] <= topk) {
|
||||
s_threshold_bin_id = tx;
|
||||
s_num_input[r_idx ^ 1] = 0;
|
||||
s_last_remain = topk - s_histogram[tx + 1];
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
const auto threshold_bin = s_threshold_bin_id;
|
||||
topk -= s_histogram[threshold_bin + 1];
|
||||
|
||||
if (topk == 0) {
|
||||
for (int i = tx; i < num_input; i += BLOCK_SIZE) {
|
||||
const auto idx = s_input_idx[r_idx][i];
|
||||
const auto offset = 24 - round * 8;
|
||||
const auto bin = (convert_to_uint32(input[idx]) >> offset) & 0xFF;
|
||||
if (bin > threshold_bin) {
|
||||
const auto pos = ::atomicAdd(&s_counter, 1);
|
||||
index[pos] = idx;
|
||||
}
|
||||
}
|
||||
__syncthreads();
|
||||
break;
|
||||
} else {
|
||||
__syncthreads();
|
||||
if (tx < RADIX + 1) {
|
||||
s_histogram[tx] = 0;
|
||||
}
|
||||
__syncthreads();
|
||||
for (int i = tx; i < num_input; i += BLOCK_SIZE) {
|
||||
const auto idx = s_input_idx[r_idx][i];
|
||||
const auto raw_input = input[idx];
|
||||
const auto offset = 24 - round * 8;
|
||||
const auto bin = (convert_to_uint32(raw_input) >> offset) & 0xFF;
|
||||
if (bin > threshold_bin) {
|
||||
const auto pos = ::atomicAdd(&s_counter, 1);
|
||||
index[pos] = idx;
|
||||
} else if (bin == threshold_bin) {
|
||||
if (round == 3) {
|
||||
const auto pos = ::atomicAdd(&s_last_remain, -1);
|
||||
if (pos > 0) {
|
||||
index[TopK - pos] = idx;
|
||||
}
|
||||
} else {
|
||||
const auto pos = ::atomicAdd(&s_num_input[r_idx ^ 1], 1);
|
||||
if (C10_LIKELY(pos < SMEM_INPUT_SIZE)) {
|
||||
/// NOTE: (dark) fuse the histogram computation here
|
||||
s_input_idx[r_idx ^ 1][pos] = idx;
|
||||
const auto bin = convert_to_uint32(raw_input);
|
||||
const auto sub_bin = (bin >> (offset - 8)) & 0xFF;
|
||||
::atomicAdd(&s_histogram[sub_bin], 1);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
__syncthreads();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
__global__ __launch_bounds__(kThreadsPerBlock) // topk
|
||||
void topk_kernel(const FastTopKParams params) {
|
||||
const auto& [input, indices, lengths, input_stride] = params;
|
||||
const auto bid = static_cast<uint64_t>(blockIdx.x);
|
||||
const auto length = lengths[bid];
|
||||
const auto indice = indices + bid * TopK;
|
||||
const auto score = input + bid * input_stride;
|
||||
if (length <= TopK) {
|
||||
return naive_topk_cuda(score, indice, length);
|
||||
} else {
|
||||
return fast_topk_cuda_tl(score, indice, length);
|
||||
}
|
||||
}
|
||||
|
||||
__global__ __launch_bounds__(kThreadsPerBlock) // decode
|
||||
void topk_transform_decode_kernel(
|
||||
const FastTopKParams params,
|
||||
int32_t* __restrict__ dst_page_table,
|
||||
const int32_t* __restrict__ src_page_table,
|
||||
const int64_t src_stride) {
|
||||
const auto& [input, _, lengths, input_stride] = params;
|
||||
const auto bid = static_cast<uint64_t>(blockIdx.x);
|
||||
const auto tid = threadIdx.x;
|
||||
const auto length = lengths[bid];
|
||||
const auto src_page_entry = src_page_table + bid * src_stride;
|
||||
const auto dst_page_entry = dst_page_table + bid * TopK;
|
||||
const auto score = input + bid * input_stride;
|
||||
if (length <= TopK) {
|
||||
return naive_topk_transform(score, length, dst_page_entry, src_page_entry);
|
||||
} else {
|
||||
__shared__ int s_indices[TopK];
|
||||
fast_topk_cuda_tl(score, s_indices, length);
|
||||
// copy src[s_indices] to dst, we manually unroll here
|
||||
static_assert(TopK % kThreadsPerBlock == 0);
|
||||
static_assert(TopK / kThreadsPerBlock == 2);
|
||||
const auto idx_0 = tid;
|
||||
const auto pos_0 = s_indices[idx_0];
|
||||
dst_page_entry[idx_0] = src_page_entry[pos_0];
|
||||
const auto idx_1 = tid + kThreadsPerBlock;
|
||||
const auto pos_1 = s_indices[idx_1];
|
||||
dst_page_entry[idx_1] = src_page_entry[pos_1];
|
||||
}
|
||||
}
|
||||
|
||||
__global__ __launch_bounds__(kThreadsPerBlock) // prefill
|
||||
void topk_transform_prefill_kernel(
|
||||
const FastTopKParams params,
|
||||
int32_t* __restrict__ dst_page_table,
|
||||
const int32_t* __restrict__ src_page_table,
|
||||
const int64_t src_stride,
|
||||
const int32_t* __restrict__ cu_seqlens_q,
|
||||
const int64_t prefill_bs) {
|
||||
const auto& [input, _, lengths, input_stride] = params;
|
||||
const auto bid = static_cast<uint64_t>(blockIdx.x);
|
||||
const auto tid = threadIdx.x;
|
||||
const auto length = lengths[bid];
|
||||
const auto dst_page_entry = dst_page_table + bid * TopK;
|
||||
const auto score = input + bid * input_stride;
|
||||
|
||||
/// NOTE: prefill bs is usually small, we can just use a simple loop here
|
||||
/// We ensure that last cu_seqlens is equal to number of blocks launched
|
||||
__shared__ const int32_t* s_src_page_entry;
|
||||
if (C10_LIKELY(prefill_bs <= kThreadsPerBlock)) {
|
||||
if (tid < prefill_bs) {
|
||||
if (bid >= cu_seqlens_q[tid] && bid < cu_seqlens_q[tid + 1]) {
|
||||
s_src_page_entry = src_page_table + tid * src_stride;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
for (int64_t i = tid; i < prefill_bs; i += kThreadsPerBlock) {
|
||||
if (bid >= cu_seqlens_q[i] && bid < cu_seqlens_q[i + 1]) {
|
||||
s_src_page_entry = src_page_table + i * src_stride;
|
||||
}
|
||||
}
|
||||
}
|
||||
__syncthreads();
|
||||
const auto src_page_entry = s_src_page_entry;
|
||||
|
||||
if (length <= TopK) {
|
||||
return naive_topk_transform(score, length, dst_page_entry, src_page_entry);
|
||||
} else {
|
||||
__shared__ int s_indices[TopK];
|
||||
fast_topk_cuda_tl(score, s_indices, length);
|
||||
// copy src[s_indices] to dst, we manually unroll here
|
||||
static_assert(TopK % kThreadsPerBlock == 0);
|
||||
static_assert(TopK / kThreadsPerBlock == 2);
|
||||
const auto idx_0 = tid;
|
||||
const auto pos_0 = s_indices[idx_0];
|
||||
dst_page_entry[idx_0] = src_page_entry[pos_0];
|
||||
const auto idx_1 = tid + kThreadsPerBlock;
|
||||
const auto pos_1 = s_indices[idx_1];
|
||||
dst_page_entry[idx_1] = src_page_entry[pos_1];
|
||||
}
|
||||
}
|
||||
|
||||
auto get_params(at::Tensor score, at::Tensor lengths, std::optional<at::Tensor> indices_opt = std::nullopt)
|
||||
-> FastTopKParams {
|
||||
const auto B = score.size(0);
|
||||
TORCH_CHECK(score.dim() == 2 && score.stride(1) == 1);
|
||||
TORCH_CHECK(lengths.dim() == 1 && lengths.is_contiguous());
|
||||
TORCH_CHECK(lengths.size(0) == B);
|
||||
int32_t* indices_data_ptr = nullptr;
|
||||
if (indices_opt.has_value()) {
|
||||
const auto& indices = indices_opt.value();
|
||||
TORCH_CHECK(indices.dim() == 2 && indices.is_contiguous());
|
||||
TORCH_CHECK(indices.size(0) == B);
|
||||
TORCH_CHECK(indices.size(1) == TopK);
|
||||
indices_data_ptr = indices.data_ptr<int32_t>();
|
||||
}
|
||||
|
||||
return FastTopKParams{
|
||||
.input = score.data_ptr<float>(),
|
||||
.indices = indices_data_ptr,
|
||||
.lengths = lengths.data_ptr<int32_t>(),
|
||||
.input_stride = score.stride(0),
|
||||
};
|
||||
}
|
||||
|
||||
template <auto* f, size_t max_dynamic_smem>
|
||||
void setup_kernel_smem_once() {
|
||||
[[maybe_unused]]
|
||||
static const auto result =
|
||||
[] { return ::cudaFuncSetAttribute(f, ::cudaFuncAttributeMaxDynamicSharedMemorySize, max_dynamic_smem); }();
|
||||
TORCH_CHECK(result == cudaSuccess, "set_up_kernel_once failed:", ::cudaGetErrorString(result));
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
#define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor")
|
||||
|
||||
void fast_topk_interface(at::Tensor score, at::Tensor indices, at::Tensor lengths) {
|
||||
CHECK_CUDA(score);
|
||||
CHECK_CUDA(indices);
|
||||
CHECK_CUDA(lengths);
|
||||
const auto params = get_params(score, lengths, indices);
|
||||
const auto B = score.size(0);
|
||||
const auto stream = at::cuda::getCurrentCUDAStream().stream();
|
||||
const auto grid = dim3{static_cast<uint32_t>(B)};
|
||||
const auto block = dim3{kThreadsPerBlock};
|
||||
setup_kernel_smem_once<topk_kernel, kSmem>();
|
||||
topk_kernel<<<grid, block, kSmem, stream>>>(params);
|
||||
const auto result = cudaGetLastError();
|
||||
TORCH_CHECK(result == cudaSuccess, "topk kernel failed:", ::cudaGetErrorString(result));
|
||||
}
|
||||
|
||||
void fast_topk_transform_interface(
|
||||
at::Tensor score,
|
||||
at::Tensor lengths,
|
||||
at::Tensor dst_page_table,
|
||||
at::Tensor src_page_table,
|
||||
at::Tensor cu_seqlens_q) {
|
||||
CHECK_CUDA(score);
|
||||
CHECK_CUDA(lengths);
|
||||
CHECK_CUDA(dst_page_table);
|
||||
CHECK_CUDA(src_page_table);
|
||||
CHECK_CUDA(cu_seqlens_q);
|
||||
const auto params = get_params(score, lengths);
|
||||
const auto B = score.size(0);
|
||||
TORCH_CHECK(dst_page_table.dim() == 2 && dst_page_table.is_contiguous());
|
||||
TORCH_CHECK(src_page_table.dim() == 2 && src_page_table.stride(1) == 1);
|
||||
TORCH_CHECK(cu_seqlens_q.dim() == 1 && cu_seqlens_q.is_contiguous());
|
||||
const auto prefill_bs = cu_seqlens_q.size(0) - 1;
|
||||
TORCH_CHECK(dst_page_table.size(0) == B);
|
||||
TORCH_CHECK(dst_page_table.size(1) == TopK);
|
||||
TORCH_CHECK(src_page_table.size(0) == prefill_bs);
|
||||
TORCH_CHECK(prefill_bs <= B); // prefill_bs should be smaller than expanded bs
|
||||
|
||||
// launch kernel
|
||||
const auto stream = at::cuda::getCurrentCUDAStream().stream();
|
||||
const auto grid = dim3{static_cast<uint32_t>(B)};
|
||||
const auto block = dim3{kThreadsPerBlock};
|
||||
const auto src_stride = src_page_table.stride(0);
|
||||
|
||||
// dispatch to decode or prefill
|
||||
const auto is_decode = (prefill_bs == B);
|
||||
if (is_decode) {
|
||||
setup_kernel_smem_once<topk_transform_decode_kernel, kSmem>();
|
||||
topk_transform_decode_kernel<<<grid, block, kSmem, stream>>>(
|
||||
params, dst_page_table.data_ptr<int32_t>(), src_page_table.data_ptr<int32_t>(), src_stride);
|
||||
} else {
|
||||
setup_kernel_smem_once<topk_transform_prefill_kernel, kSmem>();
|
||||
topk_transform_prefill_kernel<<<grid, block, kSmem, stream>>>(
|
||||
params,
|
||||
dst_page_table.data_ptr<int32_t>(),
|
||||
src_page_table.data_ptr<int32_t>(),
|
||||
src_stride,
|
||||
cu_seqlens_q.data_ptr<int32_t>(),
|
||||
prefill_bs);
|
||||
}
|
||||
|
||||
const auto result = cudaGetLastError();
|
||||
TORCH_CHECK(result == cudaSuccess, "topk kernel failed:", ::cudaGetErrorString(result));
|
||||
}
|
||||
@@ -174,6 +174,14 @@ void copy_to_gpu_no_ce(const at::Tensor& input, at::Tensor& output);
|
||||
void concat_mla_k(torch::Tensor k, torch::Tensor k_nope, torch::Tensor k_rope);
|
||||
void concat_mla_absorb_q(at::Tensor a, at::Tensor b, at::Tensor out);
|
||||
|
||||
void fast_topk_interface(at::Tensor score, at::Tensor indices, at::Tensor lengths);
|
||||
void fast_topk_transform_interface(
|
||||
at::Tensor score,
|
||||
at::Tensor lengths,
|
||||
at::Tensor dst_page_table,
|
||||
at::Tensor src_page_table,
|
||||
at::Tensor cu_seqlens_q);
|
||||
|
||||
#ifdef USE_ROCM
|
||||
void gelu_quick(at::Tensor& out, const at::Tensor& input);
|
||||
#endif
|
||||
|
||||
@@ -309,7 +309,7 @@ from sgl_kernel.speculative import (
|
||||
tree_speculative_sampling_target_only,
|
||||
verify_tree_greedy,
|
||||
)
|
||||
from sgl_kernel.top_k import fast_topk
|
||||
from sgl_kernel.top_k import fast_topk, fast_topk_transform_fused, fast_topk_v2
|
||||
from sgl_kernel.version import __version__
|
||||
|
||||
if torch.version.hip is not None:
|
||||
|
||||
@@ -9,3 +9,32 @@ def fast_topk(values, topk, dim):
|
||||
# Use topk for efficiency with larger k values
|
||||
# TODO: implement faster cuda kernels for large vocab sizes
|
||||
return torch.topk(values, topk, dim=dim)
|
||||
|
||||
|
||||
def fast_topk_v2(score: torch.Tensor, lengths: torch.Tensor, topk: int) -> torch.Tensor:
|
||||
assert (
|
||||
topk == 2048
|
||||
), "fast_topk_v2 is only optimized for deepseek v3.2 model, where topk=2048"
|
||||
assert score.dim() == 2
|
||||
topk_indices = score.new_empty((score.size(0), topk), dtype=torch.int32)
|
||||
torch.ops.sgl_kernel.fast_topk(score, topk_indices, lengths)
|
||||
return topk_indices
|
||||
|
||||
|
||||
def fast_topk_transform_fused(
|
||||
score: torch.Tensor,
|
||||
lengths: torch.Tensor,
|
||||
page_table_size_1: torch.Tensor, # NOTE: page size should be 1
|
||||
cu_seqlens_q: torch.Tensor,
|
||||
topk: int,
|
||||
) -> torch.Tensor:
|
||||
assert (
|
||||
topk == 2048
|
||||
), "fast_topk_transform_fused is only optimized for deepseek v3.2 model, where topk=2048"
|
||||
assert score.dim() == 2
|
||||
src_page_table = page_table_size_1
|
||||
dst_page_table = score.new_empty((score.size(0), topk), dtype=torch.int32)
|
||||
torch.ops.sgl_kernel.fast_topk_transform_fused(
|
||||
score, lengths, dst_page_table, src_page_table, cu_seqlens_q
|
||||
)
|
||||
return dst_page_table
|
||||
|
||||
120
sgl-kernel/tests/test_topk.py
Normal file
120
sgl-kernel/tests/test_topk.py
Normal file
@@ -0,0 +1,120 @@
|
||||
import pytest
|
||||
import torch
|
||||
from sgl_kernel import fast_topk_transform_fused, fast_topk_v2
|
||||
|
||||
|
||||
def _ref_torch_impl(score: torch.Tensor, seq_len: int, topk: int) -> torch.Tensor:
|
||||
assert score.dim() == 2
|
||||
return torch.topk(score[:, :seq_len], topk, dim=-1, sorted=False).indices
|
||||
|
||||
|
||||
def _ref_torch_transform_decode_impl(
|
||||
score: torch.Tensor,
|
||||
seq_len: int,
|
||||
src_page_table: torch.Tensor,
|
||||
topk: int,
|
||||
) -> torch.Tensor:
|
||||
batch_size, _ = score.shape
|
||||
assert score.shape[0] == src_page_table.shape[0]
|
||||
assert seq_len >= topk
|
||||
indices = _ref_torch_impl(score, seq_len, topk)
|
||||
topk_indices = torch.empty(
|
||||
(batch_size, topk), dtype=torch.int32, device=score.device
|
||||
)
|
||||
for i in range(batch_size):
|
||||
topk_indices[i] = src_page_table[i, indices[i]]
|
||||
return topk_indices
|
||||
|
||||
|
||||
MAX_SEQ_LEN = 131072
|
||||
MAX_PERMIT_ERROR = 0
|
||||
|
||||
|
||||
def assert_equal(
|
||||
score: torch.Tensor,
|
||||
indices_ref: torch.Tensor,
|
||||
indices_our: torch.Tensor,
|
||||
bs: int,
|
||||
k: int,
|
||||
seq_len: int,
|
||||
):
|
||||
indices_our_cpu = indices_our.cpu().tolist()
|
||||
indices_ref_cpu = indices_ref.cpu().tolist()
|
||||
for i in range(bs):
|
||||
indices_ref_set_i = set(indices_ref_cpu[i])
|
||||
indices_our_set_i = set(indices_our_cpu[i])
|
||||
more = indices_our_set_i - indices_ref_set_i
|
||||
less = indices_ref_set_i - indices_our_set_i
|
||||
if len(more) > MAX_PERMIT_ERROR or len(less) > MAX_PERMIT_ERROR:
|
||||
# check whether more values are the same with less values
|
||||
# if so, either one is acceptable, since their values are the same
|
||||
more_values = sorted(score[i, idx].item() for idx in more)
|
||||
less_values = sorted(score[i, idx].item() for idx in less)
|
||||
assert (
|
||||
more_values == less_values
|
||||
), f"{bs=}, {k=}, {seq_len=}, {i=}, {more=}, {less=} failed, with {more_values=}, {less_values=}"
|
||||
|
||||
|
||||
@pytest.mark.parametrize("bs", [1, 132, 256, 4096])
|
||||
@pytest.mark.parametrize("k", [2048]) # we only support 2048 now
|
||||
@pytest.mark.parametrize("seq_len", [2048, 4096, 16384, 65536])
|
||||
@torch.inference_mode()
|
||||
def test_topk_kernel(bs: int, k: int, seq_len: int) -> None:
|
||||
torch.manual_seed(42)
|
||||
|
||||
stream = torch.cuda.Stream()
|
||||
torch.cuda.set_stream(stream)
|
||||
score = torch.randn(bs, MAX_SEQ_LEN, dtype=torch.float32, device="cuda")
|
||||
lengths = torch.full((bs,), seq_len, dtype=torch.int32, device="cuda")
|
||||
|
||||
indices_ref = _ref_torch_impl(score, seq_len, k)
|
||||
indices_our = fast_topk_v2(score, lengths, k)
|
||||
|
||||
# sort and compare
|
||||
indices_ref = torch.sort(indices_ref, dim=-1).values
|
||||
indices_our = torch.sort(indices_our, dim=-1).values
|
||||
|
||||
assert_equal(score, indices_ref, indices_our, bs, k, seq_len)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("bs", [1, 132, 256, 4096])
|
||||
@pytest.mark.parametrize("k", [2048]) # we only support 2048 now
|
||||
@pytest.mark.parametrize("seq_len", [2048, 4096, 16384, 65536])
|
||||
@torch.inference_mode()
|
||||
def test_topk_transform_kernel(bs: int, k: int, seq_len: int) -> None:
|
||||
# TODO(dark): test prefill kernel, though nothing special
|
||||
MAX_PERMIT_ERROR = 1
|
||||
torch.manual_seed(42)
|
||||
|
||||
stream = torch.cuda.Stream()
|
||||
torch.cuda.set_stream(stream)
|
||||
score = torch.randn(bs, MAX_SEQ_LEN, dtype=torch.float32, device="cuda")
|
||||
lengths = torch.full((bs,), seq_len, dtype=torch.int32, device="cuda")
|
||||
src_page_table = torch.arange(0, seq_len, dtype=torch.int32, device="cuda")
|
||||
src_page_table = src_page_table.unsqueeze(0).expand(bs, -1)
|
||||
# NOTE: for decode, cumulative seqlens_q is just 0..=bs
|
||||
# NOTE: since page table is arange, they equal topk indices
|
||||
cu_seqlens_q = torch.arange(0, bs + 1, dtype=torch.int32, device="cuda")
|
||||
dst_page_table_ref = _ref_torch_transform_decode_impl(
|
||||
score=score,
|
||||
seq_len=seq_len,
|
||||
src_page_table=src_page_table,
|
||||
topk=k,
|
||||
)
|
||||
dst_page_table_our = fast_topk_transform_fused(
|
||||
score=score,
|
||||
lengths=lengths,
|
||||
page_table_size_1=src_page_table,
|
||||
cu_seqlens_q=cu_seqlens_q,
|
||||
topk=k,
|
||||
)
|
||||
|
||||
# sort and compare
|
||||
dst_page_table_our = torch.sort(dst_page_table_our, dim=-1).values
|
||||
dst_page_table_ref = torch.sort(dst_page_table_ref, dim=-1).values
|
||||
|
||||
assert_equal(score, dst_page_table_ref, dst_page_table_our, bs, k, seq_len)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__])
|
||||
Reference in New Issue
Block a user