[DeepseekV32] Add fast_topk_transform_ragged_fused kernel (#11815)

Signed-off-by: Hao Lu <14827759+hlu1@users.noreply.github.com>
This commit is contained in:
hlu1
2025-10-19 17:13:39 -07:00
committed by GitHub
parent 252dc4e112
commit 3b80232d06
6 changed files with 201 additions and 20 deletions

View File

@@ -174,13 +174,18 @@ void copy_to_gpu_no_ce(const at::Tensor& input, at::Tensor& output);
void concat_mla_k(torch::Tensor k, torch::Tensor k_nope, torch::Tensor k_rope);
void concat_mla_absorb_q(at::Tensor a, at::Tensor b, at::Tensor out);
void fast_topk_interface(at::Tensor score, at::Tensor indices, at::Tensor lengths);
void fast_topk_interface(const at::Tensor& score, at::Tensor& indices, const at::Tensor& lengths);
void fast_topk_transform_interface(
at::Tensor score,
at::Tensor lengths,
at::Tensor dst_page_table,
at::Tensor src_page_table,
at::Tensor cu_seqlens_q);
const at::Tensor& score,
const at::Tensor& lengths,
at::Tensor& dst_page_table,
const at::Tensor& src_page_table,
const at::Tensor& cu_seqlens_q);
void fast_topk_transform_ragged_interface(
const at::Tensor& score,
const at::Tensor& lengths,
at::Tensor& topk_indices_ragged,
const at::Tensor& topk_indices_offset);
#ifdef USE_ROCM
void gelu_quick(at::Tensor& out, const at::Tensor& input);