Fix sampling for speculative decoding & simplify kernels (#7207)
This commit is contained in:
@@ -331,6 +331,7 @@ void tree_speculative_sampling_target_only(
|
||||
at::Tensor retrive_next_token,
|
||||
at::Tensor retrive_next_sibling,
|
||||
at::Tensor uniform_samples,
|
||||
at::Tensor uniform_samples_for_final_sampling,
|
||||
at::Tensor target_probs,
|
||||
at::Tensor draft_probs,
|
||||
double threshold_single = 1,
|
||||
@@ -363,7 +364,12 @@ void build_tree_kernel_efficient(
|
||||
int64_t draft_token_num);
|
||||
|
||||
void segment_packbits(
|
||||
at::Tensor x, at::Tensor input_indptr, at::Tensor output_indptr, at::Tensor y, int64_t cuda_stream);
|
||||
at::Tensor x,
|
||||
at::Tensor input_indptr,
|
||||
at::Tensor output_indptr,
|
||||
at::Tensor y,
|
||||
int64_t batch_size,
|
||||
int64_t cuda_stream = 0);
|
||||
|
||||
/*
|
||||
* From FlashInfer
|
||||
|
||||
Reference in New Issue
Block a user