[DeepseekV32] Add fast_topk_transform_ragged_fused kernel (#11815)

Signed-off-by: Hao Lu <14827759+hlu1@users.noreply.github.com>
This commit is contained in:
hlu1
2025-10-19 17:13:39 -07:00
committed by GitHub
parent 252dc4e112
commit 3b80232d06
6 changed files with 201 additions and 20 deletions

View File

@@ -51,6 +51,15 @@ __device__ void naive_topk_transform(
}
}
// keep the first `length` entries, set others to -1
__device__ void naive_topk_transform_ragged(
const float* __restrict__ score, int32_t length, int32_t* __restrict__ topk_indices_ragged, int32_t offset) {
const auto tid = threadIdx.x;
for (auto i = tid; i < TopK; i += kThreadsPerBlock) {
topk_indices_ragged[i] = (i < length) ? static_cast<int32_t>(i) + offset : -1;
}
}
__device__ __forceinline__ auto convert_to_uint8(float x) -> uint8_t {
__half h = __float2half_rn(x);
uint16_t bits = __half_as_ushort(h);
@@ -322,8 +331,40 @@ __global__ __launch_bounds__(kThreadsPerBlock) // prefill
}
}
auto get_params(at::Tensor score, at::Tensor lengths, std::optional<at::Tensor> indices_opt = std::nullopt)
-> FastTopKParams {
__global__ __launch_bounds__(kThreadsPerBlock) // prefill, ragged kv
void topk_transform_prefill_ragged_kernel(
const FastTopKParams params,
int32_t* __restrict__ topk_indices_ragged,
const int32_t* __restrict__ topk_indices_offset) {
const auto& [input, _, lengths, input_stride] = params;
const auto bid = static_cast<uint64_t>(blockIdx.x);
const auto tid = threadIdx.x;
const auto length = lengths[bid];
const auto dst_indices_entry = topk_indices_ragged + bid * TopK;
const auto score = input + bid * input_stride;
const auto offset = topk_indices_offset[bid];
if (length <= TopK) {
return naive_topk_transform_ragged(score, length, dst_indices_entry, offset);
} else {
__shared__ int s_indices[TopK];
fast_topk_cuda_tl(score, s_indices, length);
// copy src[s_indices] to dst, we manually unroll here
static_assert(TopK % kThreadsPerBlock == 0);
static_assert(TopK / kThreadsPerBlock == 2);
const auto idx_0 = tid;
const auto pos_0 = s_indices[idx_0];
dst_indices_entry[idx_0] = pos_0 + offset;
const auto idx_1 = tid + kThreadsPerBlock;
const auto pos_1 = s_indices[idx_1];
dst_indices_entry[idx_1] = pos_1 + offset;
}
}
auto get_params(
const at::Tensor& score,
const at::Tensor& lengths,
std::optional<at::Tensor> indices_opt = std::nullopt) -> FastTopKParams {
const auto B = score.size(0);
TORCH_CHECK(score.dim() == 2 && score.stride(1) == 1);
TORCH_CHECK(lengths.dim() == 1 && lengths.is_contiguous());
@@ -357,7 +398,7 @@ void setup_kernel_smem_once() {
#define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor")
void fast_topk_interface(at::Tensor score, at::Tensor indices, at::Tensor lengths) {
void fast_topk_interface(const at::Tensor& score, at::Tensor& indices, const at::Tensor& lengths) {
CHECK_CUDA(score);
CHECK_CUDA(indices);
CHECK_CUDA(lengths);
@@ -373,11 +414,11 @@ void fast_topk_interface(at::Tensor score, at::Tensor indices, at::Tensor length
}
void fast_topk_transform_interface(
at::Tensor score,
at::Tensor lengths,
at::Tensor dst_page_table,
at::Tensor src_page_table,
at::Tensor cu_seqlens_q) {
const at::Tensor& score,
const at::Tensor& lengths,
at::Tensor& dst_page_table,
const at::Tensor& src_page_table,
const at::Tensor& cu_seqlens_q) {
CHECK_CUDA(score);
CHECK_CUDA(lengths);
CHECK_CUDA(dst_page_table);
@@ -420,3 +461,35 @@ void fast_topk_transform_interface(
const auto result = cudaGetLastError();
TORCH_CHECK(result == cudaSuccess, "topk kernel failed:", ::cudaGetErrorString(result));
}
void fast_topk_transform_ragged_interface(
const at::Tensor& score,
const at::Tensor& lengths,
at::Tensor& topk_indices_ragged,
const at::Tensor& topk_indices_offset) {
CHECK_CUDA(score);
CHECK_CUDA(lengths);
CHECK_CUDA(topk_indices_ragged);
CHECK_CUDA(topk_indices_offset);
const auto params = get_params(score, lengths);
const auto B = score.size(0);
TORCH_CHECK(topk_indices_ragged.dim() == 2 && topk_indices_ragged.is_contiguous());
TORCH_CHECK(topk_indices_offset.dim() == 1);
TORCH_CHECK(topk_indices_ragged.size(0) == B);
TORCH_CHECK(topk_indices_ragged.size(1) == TopK);
TORCH_CHECK(topk_indices_offset.size(0) == B);
// launch kernel
const auto stream = at::cuda::getCurrentCUDAStream().stream();
const auto grid = dim3{static_cast<uint32_t>(B)};
const auto block = dim3{kThreadsPerBlock};
setup_kernel_smem_once<topk_transform_prefill_ragged_kernel, kSmem>();
topk_transform_prefill_ragged_kernel<<<grid, block, kSmem, stream>>>(
params, topk_indices_ragged.data_ptr<int32_t>(), topk_indices_offset.data_ptr<int32_t>());
const auto result = cudaGetLastError();
TORCH_CHECK(result == cudaSuccess, "topk kernel failed:", ::cudaGetErrorString(result));
}