Add greedy verification kernel (#4383)
This commit is contained in:
@@ -183,8 +183,8 @@ void topk_softmax(
|
||||
* From csrc/speculative
|
||||
*/
|
||||
void tree_speculative_sampling_target_only(
|
||||
at::Tensor predicts,
|
||||
at::Tensor accept_index,
|
||||
at::Tensor predicts, // mutable
|
||||
at::Tensor accept_index, // mutable
|
||||
at::Tensor accept_token_num, // mutable
|
||||
at::Tensor candidates,
|
||||
at::Tensor retrive_index,
|
||||
@@ -193,9 +193,22 @@ void tree_speculative_sampling_target_only(
|
||||
at::Tensor uniform_samples,
|
||||
at::Tensor target_probs,
|
||||
at::Tensor draft_probs,
|
||||
double threshold_single = 1,
|
||||
double threshold_acc = 1,
|
||||
bool deterministic = true,
|
||||
int64_t cuda_stream = 0);
|
||||
|
||||
// Declaration of the greedy tree-verification entry point. Only the prototype
// is visible here; the enclosing section is labelled "From csrc/speculative".
// The first three tensors are flagged "mutable" by the declaration itself; the
// remaining tensors carry no such marker (presumably inputs only — confirm
// against the definition).
// NOTE(review): the "retrive_*" spelling is part of the parameter names and is
// shared with the neighbouring declarations; do not "fix" it in one place only.
void verify_tree_greedy(
|
||||
at::Tensor predicts, // mutable
|
||||
at::Tensor accept_index, // mutable
|
||||
at::Tensor accept_token_num, // mutable
|
||||
at::Tensor candidates,
|
||||
at::Tensor retrive_index,
|
||||
at::Tensor retrive_next_token,
|
||||
at::Tensor retrive_next_sibling,
|
||||
at::Tensor target_predict,
|
||||
// Stream is passed as a plain integer and defaults to 0.
// NOTE(review): presumably a raw CUDA stream handle — confirm at the definition.
int64_t cuda_stream = 0);
|
||||
|
||||
void build_tree_kernel_efficient(
|
||||
at::Tensor parent_list,
|
||||
at::Tensor selected_index,
|
||||
@@ -209,16 +222,8 @@ void build_tree_kernel_efficient(
|
||||
int64_t depth,
|
||||
int64_t draft_token_num);
|
||||
|
||||
// Declaration of build_tree_kernel: six tensors followed by three int64_t
// scalars (topk, depth, draft_token_num). Returns void. Unlike the
// verification declarations, no tensor is annotated as mutable here, so which
// arguments are written cannot be determined from this prototype alone.
// NOTE(review): the enclosing diff hunk shrinks from 16 to 8 lines, which
// suggests this declaration may be on the removed side of the change — confirm
// it still exists before depending on it.
void build_tree_kernel(
|
||||
at::Tensor parent_list,
|
||||
at::Tensor selected_index,
|
||||
at::Tensor verified_seq_len,
|
||||
at::Tensor tree_mask,
|
||||
at::Tensor positions,
|
||||
at::Tensor retrive_index,
|
||||
int64_t topk,
|
||||
int64_t depth,
|
||||
int64_t draft_token_num);
|
||||
// Declaration of segment_packbits: takes four tensors (x, input_indptr,
// output_indptr, y) and an integer cuda_stream, and returns void.
// Note that cuda_stream has no default value here, unlike verify_tree_greedy,
// so every caller must pass it explicitly.
// NOTE(review): x/y look like input/output and the *_indptr tensors like
// segment offsets, but nothing in this prototype establishes that — confirm
// against the definition.
void segment_packbits(
|
||||
at::Tensor x, at::Tensor input_indptr, at::Tensor output_indptr, at::Tensor y, int64_t cuda_stream);
|
||||
|
||||
/*
|
||||
* From FlashInfer
|
||||
|
||||
Reference in New Issue
Block a user