Fix sampling for speculative decoding & simplify kernels (#7207)
This commit is contained in:
sgl-kernel/csrc/common_extension.cc — 7 changed lines (5 added, 2 removed); file mode: Executable file → Normal file
@@ -201,13 +201,14 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
|
||||
m.impl("shuffle_rows", torch::kCUDA, &shuffle_rows);
|
||||
m.def("apply_shuffle_mul_sum(Tensor input, Tensor output, Tensor permutation, Tensor? factors) -> ()");
|
||||
m.impl("apply_shuffle_mul_sum", torch::kCUDA, &apply_shuffle_mul_sum);
|
||||
|
||||
/*
|
||||
* From csrc/speculative
|
||||
*/
|
||||
m.def(
|
||||
"tree_speculative_sampling_target_only(Tensor! predicts, Tensor! accept_index, Tensor! accept_token_num, "
|
||||
"Tensor candidates, Tensor retrive_index, Tensor retrive_next_token, Tensor retrive_next_sibling, "
|
||||
- "Tensor uniform_samples, Tensor target_probs, Tensor draft_probs, "
|
||||
+ "Tensor uniform_samples, Tensor uniform_samples_for_final_sampling, Tensor target_probs, Tensor draft_probs, "
|
||||
"float threshold_single, float threshold_acc, "
|
||||
"bool deterministic, int cuda_stream) -> ()");
|
||||
m.impl("tree_speculative_sampling_target_only", torch::kCUDA, &tree_speculative_sampling_target_only);
|
||||
@@ -224,7 +225,9 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
|
||||
"Tensor! retrive_next_sibling, int topk, int depth, int draft_token_num) -> ()");
|
||||
m.impl("build_tree_kernel_efficient", torch::kCUDA, &build_tree_kernel_efficient);
|
||||
|
||||
- m.def("segment_packbits(Tensor x, Tensor input_indptr, Tensor output_indptr, Tensor! y, int cuda_stream) -> ()");
|
||||
+ m.def(
|
||||
+ "segment_packbits(Tensor x, Tensor input_indptr, Tensor output_indptr, Tensor! y, int batch_size, "
|
||||
+ "int cuda_stream) -> ()");
|
||||
m.impl("segment_packbits", torch::kCUDA, &segment_packbits);
|
||||
|
||||
/*
|
||||
|
||||
Reference in New Issue
Block a user