Add greedy verification kernel (#4383)

This commit is contained in:
Ying Sheng
2025-03-16 00:58:26 -07:00
committed by GitHub
parent 06d12b39d3
commit 52a34d7448
11 changed files with 394 additions and 153 deletions

View File

@@ -183,8 +183,8 @@ void topk_softmax(
* From csrc/speculative
*/
void tree_speculative_sampling_target_only(
at::Tensor predicts,
at::Tensor accept_index,
at::Tensor predicts, // mutable
at::Tensor accept_index, // mutable
at::Tensor accept_token_num, // mutable
at::Tensor candidates,
at::Tensor retrive_index,
@@ -193,9 +193,22 @@ void tree_speculative_sampling_target_only(
at::Tensor uniform_samples,
at::Tensor target_probs,
at::Tensor draft_probs,
double threshold_single = 1,
double threshold_acc = 1,
bool deterministic = true,
int64_t cuda_stream = 0);
/*
 * Declaration only; the definition is not visible from this header.
 * predicts, accept_index and accept_token_num are tagged "mutable", i.e. the
 * callee is expected to write into them. The remaining tensors carry no such tag.
 * NOTE(review): judging by the name, this is the greedy tree-verification entry
 * point -- confirm semantics, tensor shapes and dtypes against the implementation.
 */
void verify_tree_greedy(
at::Tensor predicts, // mutable
at::Tensor accept_index, // mutable
at::Tensor accept_token_num, // mutable
at::Tensor candidates,
at::Tensor retrive_index,
at::Tensor retrive_next_token,
at::Tensor retrive_next_sibling,
at::Tensor target_predict,
// Optional: callers may omit it, in which case 0 is passed.
// NOTE(review): presumably a CUDA stream handle carried as an integer -- confirm.
int64_t cuda_stream = 0);
void build_tree_kernel_efficient(
at::Tensor parent_list,
at::Tensor selected_index,
@@ -209,16 +222,8 @@ void build_tree_kernel_efficient(
int64_t depth,
int64_t draft_token_num);
/*
 * Declaration only; the definition is not visible from this header.
 * Takes six tensors followed by three integer sizes (topk, depth,
 * draft_token_num). Unlike the neighbouring declarations there is no
 * cuda_stream parameter.
 * NOTE(review): the surrounding diff hunk shrinks this region, so this
 * declaration appears to be the one being removed in favour of
 * build_tree_kernel_efficient -- confirm before relying on it.
 * NOTE(review): which tensors are inputs and which are written by the callee
 * is not marked here -- verify against the implementation.
 */
void build_tree_kernel(
at::Tensor parent_list,
at::Tensor selected_index,
at::Tensor verified_seq_len,
at::Tensor tree_mask,
at::Tensor positions,
at::Tensor retrive_index,
int64_t topk,
int64_t depth,
int64_t draft_token_num);
/*
 * Declaration only; the definition is not visible from this header.
 * cuda_stream has no default value here, so every caller must pass it explicitly.
 * NOTE(review): the names suggest x is packed into y, with input_indptr and
 * output_indptr delimiting the segments of each -- confirm layout and dtypes
 * against the implementation.
 */
void segment_packbits(
at::Tensor x, at::Tensor input_indptr, at::Tensor output_indptr, at::Tensor y, int64_t cuda_stream);
/*
* From FlashInfer