diff --git a/sgl-kernel/csrc/common_extension.cc b/sgl-kernel/csrc/common_extension.cc old mode 100755 new mode 100644 index 68424f07c..ed9f406e6 --- a/sgl-kernel/csrc/common_extension.cc +++ b/sgl-kernel/csrc/common_extension.cc @@ -201,13 +201,14 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) { m.impl("shuffle_rows", torch::kCUDA, &shuffle_rows); m.def("apply_shuffle_mul_sum(Tensor input, Tensor output, Tensor permutation, Tensor? factors) -> ()"); m.impl("apply_shuffle_mul_sum", torch::kCUDA, &apply_shuffle_mul_sum); + /* * From csrc/speculative */ m.def( "tree_speculative_sampling_target_only(Tensor! predicts, Tensor! accept_index, Tensor! accept_token_num, " "Tensor candidates, Tensor retrive_index, Tensor retrive_next_token, Tensor retrive_next_sibling, " - "Tensor uniform_samples, Tensor target_probs, Tensor draft_probs, " + "Tensor uniform_samples, Tensor uniform_samples_for_final_sampling, Tensor target_probs, Tensor draft_probs, " "float threshold_single, float threshold_acc, " "bool deterministic, int cuda_stream) -> ()"); m.impl("tree_speculative_sampling_target_only", torch::kCUDA, &tree_speculative_sampling_target_only); @@ -224,7 +225,9 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) { "Tensor! retrive_next_sibling, int topk, int depth, int draft_token_num) -> ()"); m.impl("build_tree_kernel_efficient", torch::kCUDA, &build_tree_kernel_efficient); - m.def("segment_packbits(Tensor x, Tensor input_indptr, Tensor output_indptr, Tensor! y, int cuda_stream) -> ()"); + m.def( + "segment_packbits(Tensor x, Tensor input_indptr, Tensor output_indptr, Tensor! y, int batch_size, " + "int cuda_stream) -> ()"); m.impl("segment_packbits", torch::kCUDA, &segment_packbits); /* diff --git a/sgl-kernel/csrc/speculative/eagle_utils.cu b/sgl-kernel/csrc/speculative/eagle_utils.cu index aeb6b8421..8b0759765 100644 --- a/sgl-kernel/csrc/speculative/eagle_utils.cu +++ b/sgl-kernel/csrc/speculative/eagle_utils.cu @@ -32,7 +32,7 @@ __global__ void build_tree_efficient( int64_t* parent_list, int64_t* selected_index, - int32_t* verified_seq_len, + int64_t* verified_seq_len, bool* tree_mask, int64_t* positions, int64_t* retrive_index, @@ -135,7 +135,7 @@ void build_tree_kernel_efficient( build_tree_efficient<<>>( static_cast(parent_list.data_ptr()), static_cast(selected_index.data_ptr()), - static_cast(verified_seq_len.data_ptr()), + static_cast(verified_seq_len.data_ptr()), static_cast(tree_mask.data_ptr()), static_cast(positions.data_ptr()), static_cast(retrive_index.data_ptr()), @@ -146,32 +146,32 @@ void build_tree_kernel_efficient( int32_t(draft_token_num)); } -template +template __global__ void VerifyTreeGreedy( IdType* predicts, IdType* accept_index, IdType* accept_token_num, // mutable - IdType* candidates, - IdType* retrive_index, - IdType* retrive_next_token, - IdType* retrive_next_sibling, - IdType* target_predict, + IdType2* candidates, + IdType2* retrive_index, + IdType2* retrive_next_token, + IdType2* retrive_next_sibling, + IdType2* target_predict, uint32_t batch_size, uint32_t num_speculative_tokens, uint32_t num_draft_tokens) { uint32_t bx = blockIdx.x; - IdType last_accepted_retrive_idx = retrive_index[bx * num_draft_tokens]; + IdType2 last_accepted_retrive_idx = retrive_index[bx * num_draft_tokens]; accept_index[bx * num_speculative_tokens] = last_accepted_retrive_idx; uint32_t num_accepted_tokens = 0; - IdType cur_index = 0; + IdType2 cur_index = 0; for (uint32_t j = 1; j < num_speculative_tokens; ++j) { cur_index = retrive_next_token[bx * num_draft_tokens + cur_index]; while (cur_index != -1) { - IdType draft_index = retrive_index[bx * num_draft_tokens + cur_index]; - IdType draft_token_id = candidates[bx * num_draft_tokens + cur_index]; - IdType target_token_id = target_predict[last_accepted_retrive_idx]; + IdType2 draft_index = retrive_index[bx * num_draft_tokens + cur_index]; + IdType2 draft_token_id = candidates[bx * num_draft_tokens + cur_index]; + IdType2 target_token_id = target_predict[last_accepted_retrive_idx]; if (draft_token_id == target_token_id) { // accept token @@ -251,35 +251,35 @@ void verify_tree_greedy( if (accept_token_num.scalar_type() != at::kInt) { throw std::runtime_error("Expected 'accept_token_num' to be of type int (torch.int32)."); } - if (candidates.scalar_type() != at::kInt) { - throw std::runtime_error("Expected 'candidates' to be of type int (torch.int32)."); + if (candidates.scalar_type() != at::kLong) { + throw std::runtime_error("Expected 'candidates' to be of type long (torch.int64)."); } - if (retrive_index.scalar_type() != at::kInt) { - throw std::runtime_error("Expected 'retrive_index' to be of type int (torch.int32)."); + if (retrive_index.scalar_type() != at::kLong) { + throw std::runtime_error("Expected 'retrive_index' to be of type long (torch.int64)."); } - if (retrive_next_token.scalar_type() != at::kInt) { - throw std::runtime_error("Expected 'retrive_next_token' to be of type int (torch.int32)."); + if (retrive_next_token.scalar_type() != at::kLong) { + throw std::runtime_error("Expected 'retrive_next_token' to be of type long (torch.int64)."); } - if (retrive_next_sibling.scalar_type() != at::kInt) { - throw std::runtime_error("Expected 'retrive_next_sibling' to be of type int (torch.int32)."); + if (retrive_next_sibling.scalar_type() != at::kLong) { + throw std::runtime_error("Expected 'retrive_next_sibling' to be of type long (torch.int64)."); } - if (target_predict.scalar_type() != at::kInt) { - throw std::runtime_error("Expected 'target_predict' to be of type int (torch.int32)."); + if (target_predict.scalar_type() != at::kLong) { + throw std::runtime_error("Expected 'target_predict' to be of type long (torch.int64)."); } cudaStream_t stream = reinterpret_cast(cuda_stream); dim3 grid(batch_size); dim3 block(1); - VerifyTreeGreedy<<>>( - static_cast(predicts.data_ptr()), - static_cast(accept_index.data_ptr()), - static_cast(accept_token_num.data_ptr()), - static_cast(candidates.data_ptr()), - static_cast(retrive_index.data_ptr()), - static_cast(retrive_next_token.data_ptr()), - static_cast(retrive_next_sibling.data_ptr()), - static_cast(target_predict.data_ptr()), + VerifyTreeGreedy<<>>( + static_cast(predicts.data_ptr()), + static_cast(accept_index.data_ptr()), + static_cast(accept_token_num.data_ptr()), + static_cast(candidates.data_ptr()), + static_cast(retrive_index.data_ptr()), + static_cast(retrive_next_token.data_ptr()), + static_cast(retrive_next_sibling.data_ptr()), + static_cast(target_predict.data_ptr()), batch_size, num_spec_step, num_draft_tokens); diff --git a/sgl-kernel/csrc/speculative/packbit.cu b/sgl-kernel/csrc/speculative/packbit.cu index 687dbfa2b..1decc3ded 100644 --- a/sgl-kernel/csrc/speculative/packbit.cu +++ b/sgl-kernel/csrc/speculative/packbit.cu @@ -24,7 +24,12 @@ using namespace flashinfer; // bitorder = "little" void segment_packbits( - at::Tensor x, at::Tensor input_indptr, at::Tensor output_indptr, at::Tensor y, int64_t cuda_stream) { + at::Tensor x, + at::Tensor input_indptr, + at::Tensor output_indptr, + at::Tensor y, + int64_t batch_size, + int64_t cuda_stream) { CHECK_INPUT(x); CHECK_INPUT(input_indptr); CHECK_INPUT(output_indptr); @@ -32,8 +37,7 @@ void segment_packbits( CHECK_EQ(input_indptr.device(), device); CHECK_EQ(output_indptr.device(), device); CHECK_EQ(y.device(), device); - unsigned int batch_size = input_indptr.size(0) - 1; - CHECK_EQ(output_indptr.size(0), batch_size + 1); + CHECK_GE(output_indptr.size(0), batch_size + 1); cudaStream_t stream = reinterpret_cast(cuda_stream); cudaError_t status = quantization::SegmentPackBits( diff --git a/sgl-kernel/csrc/speculative/speculative_sampling.cu b/sgl-kernel/csrc/speculative/speculative_sampling.cu index c03e1d772..ca545e99e 100644 --- a/sgl-kernel/csrc/speculative/speculative_sampling.cu +++ b/sgl-kernel/csrc/speculative/speculative_sampling.cu @@ -37,6 +37,7 @@ void tree_speculative_sampling_target_only( at::Tensor retrive_next_token, at::Tensor retrive_next_sibling, at::Tensor uniform_samples, + at::Tensor uniform_samples_for_final_sampling, at::Tensor target_probs, at::Tensor draft_probs, double threshold_single, @@ -48,6 +49,7 @@ void tree_speculative_sampling_target_only( CHECK_INPUT(retrive_next_token); CHECK_INPUT(retrive_next_sibling); CHECK_INPUT(uniform_samples); + CHECK_INPUT(uniform_samples_for_final_sampling); CHECK_INPUT(target_probs); auto device = target_probs.device(); CHECK_EQ(candidates.device(), device); @@ -55,6 +57,7 @@ void tree_speculative_sampling_target_only( CHECK_EQ(retrive_next_token.device(), device); CHECK_EQ(retrive_next_sibling.device(), device); CHECK_EQ(uniform_samples.device(), device); + CHECK_EQ(uniform_samples_for_final_sampling.device(), device); CHECK_EQ(target_probs.device(), device); CHECK_DIM(1, predicts); CHECK_DIM(2, accept_index); @@ -92,21 +95,24 @@ void tree_speculative_sampling_target_only( if (accept_token_num.scalar_type() != at::kInt) { throw std::runtime_error("Expected 'accept_token_num' to be of type int (torch.int32)."); } - if (candidates.scalar_type() != at::kInt) { - throw std::runtime_error("Expected 'candidates' to be of type int (torch.int32)."); + if (candidates.scalar_type() != at::kLong) { + throw std::runtime_error("Expected 'candidates' to be of type long (torch.int64)."); } - if (retrive_index.scalar_type() != at::kInt) { - throw std::runtime_error("Expected 'retrive_index' to be of type int (torch.int32)."); + if (retrive_index.scalar_type() != at::kLong) { + throw std::runtime_error("Expected 'retrive_index' to be of type long (torch.int64)."); } - if (retrive_next_token.scalar_type() != at::kInt) { - throw std::runtime_error("Expected 'retrive_next_token' to be of type int (torch.int32)."); + if (retrive_next_token.scalar_type() != at::kLong) { + throw std::runtime_error("Expected 'retrive_next_token' to be of type long (torch.int64)."); } - if (retrive_next_sibling.scalar_type() != at::kInt) { - throw std::runtime_error("Expected 'retrive_next_sibling' to be of type int (torch.int32)."); + if (retrive_next_sibling.scalar_type() != at::kLong) { + throw std::runtime_error("Expected 'retrive_next_sibling' to be of type long (torch.int64)."); } if (uniform_samples.scalar_type() != at::kFloat) { throw std::runtime_error("Expected 'uniform_samples' to be of type float (torch.float32)."); } + if (uniform_samples_for_final_sampling.scalar_type() != at::kFloat) { + throw std::runtime_error("Expected 'uniform_samples_for_final_sampling' to be of type float (torch.float32)."); + } if (target_probs.scalar_type() != at::kFloat) { throw std::runtime_error("Expected 'target_probs' to be of type float (torch.float32)."); } @@ -119,15 +125,16 @@ void tree_speculative_sampling_target_only( CHECK_GE(1, threshold_acc); cudaStream_t stream = reinterpret_cast(cuda_stream); - cudaError_t status = sampling::TreeSpeculativeSamplingTargetOnly( - static_cast(predicts.data_ptr()), - static_cast(accept_index.data_ptr()), - static_cast(accept_token_num.data_ptr()), - static_cast(candidates.data_ptr()), - static_cast(retrive_index.data_ptr()), - static_cast(retrive_next_token.data_ptr()), - static_cast(retrive_next_sibling.data_ptr()), + cudaError_t status = sampling::TreeSpeculativeSamplingTargetOnly( + static_cast(predicts.data_ptr()), + static_cast(accept_index.data_ptr()), + static_cast(accept_token_num.data_ptr()), + static_cast(candidates.data_ptr()), + static_cast(retrive_index.data_ptr()), + static_cast(retrive_next_token.data_ptr()), + static_cast(retrive_next_sibling.data_ptr()), static_cast(uniform_samples.data_ptr()), + static_cast(uniform_samples_for_final_sampling.data_ptr()), static_cast(target_probs.data_ptr()), static_cast(draft_probs.data_ptr()), batch_size, diff --git a/sgl-kernel/csrc/speculative/speculative_sampling.cuh b/sgl-kernel/csrc/speculative/speculative_sampling.cuh index a773c0e27..59f18bc2f 100644 --- a/sgl-kernel/csrc/speculative/speculative_sampling.cuh +++ b/sgl-kernel/csrc/speculative/speculative_sampling.cuh @@ -34,16 +34,18 @@ template < uint32_t VEC_SIZE, bool DETERMINISTIC, typename DType, - typename IdType> + typename IdType, + typename IdType2> __global__ void TreeSpeculativeSamplingTargetOnly( IdType* predicts, // mutable IdType* accept_index, // mutable IdType* accept_token_num, // mutable - IdType* candidates, - IdType* retrive_index, - IdType* retrive_next_token, - IdType* retrive_next_sibling, + IdType2* candidates, + IdType2* retrive_index, + IdType2* retrive_next_token, + IdType2* retrive_next_sibling, DType* uniform_samples, + DType* uniform_samples_for_final_sampling, DType* target_probs, DType* draft_probs, uint32_t batch_size, @@ -62,16 +64,16 @@ __global__ void TreeSpeculativeSamplingTargetOnly( DType prob_acc = 0.0; uint32_t cur_prob_offset = bx * num_draft_tokens * d; DType coin = uniform_samples[bx * num_draft_tokens]; - IdType last_accepted_retrive_idx = retrive_index[bx * num_draft_tokens]; + IdType2 last_accepted_retrive_idx = retrive_index[bx * num_draft_tokens]; accept_index[bx * num_speculative_tokens] = last_accepted_retrive_idx; uint32_t num_accepted_tokens = 0; - IdType cur_index = 0; + IdType2 cur_index = 0; for (uint32_t j = 1; j < num_speculative_tokens; ++j) { cur_index = retrive_next_token[bx * num_draft_tokens + cur_index]; while (cur_index != -1) { - IdType draft_index = retrive_index[bx * num_draft_tokens + cur_index]; - IdType draft_token_id = candidates[bx * num_draft_tokens + cur_index]; + IdType2 draft_index = retrive_index[bx * num_draft_tokens + cur_index]; + IdType2 draft_token_id = candidates[bx * num_draft_tokens + cur_index]; DType target_prob_single = target_probs[cur_prob_offset + draft_token_id]; prob_acc += target_prob_single; @@ -95,6 +97,9 @@ __global__ void TreeSpeculativeSamplingTargetOnly( } accept_token_num[bx] = num_accepted_tokens; + // we need a different coin for the final sampling + coin = uniform_samples_for_final_sampling[bx]; + // sample from relu(target_probs - draft_probs) DType sum_relu_q_minus_p(0); vec_t q_vec, p_vec; @@ -156,16 +161,17 @@ __global__ void TreeSpeculativeSamplingTargetOnly( // value at not used indices are undefined } -template +template cudaError_t TreeSpeculativeSamplingTargetOnly( IdType* predicts, // mutable IdType* output_token_ids, // mutable IdType* output_accepted_token_num, // mutable - IdType* candidates, - IdType* retrive_index, - IdType* retrive_next_token, - IdType* retrive_next_sibling, + IdType2* candidates, + IdType2* retrive_index, + IdType2* retrive_next_token, + IdType2* retrive_next_sibling, DType* uniform_samples, + DType* uniform_samples_for_final_sampling, DType* target_probs, DType* draft_probs, uint32_t batch_size, @@ -192,6 +198,7 @@ cudaError_t TreeSpeculativeSamplingTargetOnly( &retrive_next_token, &retrive_next_sibling, &uniform_samples, + &uniform_samples_for_final_sampling, &target_probs, &draft_probs, &batch_size, @@ -209,7 +216,8 @@ cudaError_t TreeSpeculativeSamplingTargetOnly( VEC_SIZE, DETERMINISTIC, DType, - IdType>; + IdType, + IdType2>; FLASHINFER_CUDA_CALL(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); FLASHINFER_CUDA_CALL(cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream)); })}); diff --git a/sgl-kernel/include/sgl_kernel_ops.h b/sgl-kernel/include/sgl_kernel_ops.h index bb267735b..9588bc736 100644 --- a/sgl-kernel/include/sgl_kernel_ops.h +++ b/sgl-kernel/include/sgl_kernel_ops.h @@ -331,6 +331,7 @@ void tree_speculative_sampling_target_only( at::Tensor retrive_next_token, at::Tensor retrive_next_sibling, at::Tensor uniform_samples, + at::Tensor uniform_samples_for_final_sampling, at::Tensor target_probs, at::Tensor draft_probs, double threshold_single = 1, @@ -363,7 +364,12 @@ void build_tree_kernel_efficient( int64_t draft_token_num); void segment_packbits( - at::Tensor x, at::Tensor input_indptr, at::Tensor output_indptr, at::Tensor y, int64_t cuda_stream); + at::Tensor x, + at::Tensor input_indptr, + at::Tensor output_indptr, + at::Tensor y, + int64_t batch_size, + int64_t cuda_stream = 0); /* * From FlashInfer diff --git a/sgl-kernel/python/sgl_kernel/__init__.py b/sgl-kernel/python/sgl_kernel/__init__.py index 4d5065bd4..d9ce1ff5a 100755 --- a/sgl-kernel/python/sgl_kernel/__init__.py +++ b/sgl-kernel/python/sgl_kernel/__init__.py @@ -72,6 +72,7 @@ from sgl_kernel.speculative import ( tree_speculative_sampling_target_only, verify_tree_greedy, ) +from sgl_kernel.top_k import fast_topk from sgl_kernel.version import __version__ build_tree_kernel = ( diff --git a/sgl-kernel/python/sgl_kernel/speculative.py b/sgl-kernel/python/sgl_kernel/speculative.py index 6eee58394..0ff46148a 100644 --- a/sgl-kernel/python/sgl_kernel/speculative.py +++ b/sgl-kernel/python/sgl_kernel/speculative.py @@ -11,6 +11,7 @@ def tree_speculative_sampling_target_only( retrive_next_token: torch.Tensor, retrive_next_sibling: torch.Tensor, uniform_samples: torch.Tensor, + uniform_samples_for_final_sampling: torch.Tensor, target_probs: torch.Tensor, draft_probs: torch.Tensor, threshold_single: float = 1.0, @@ -26,6 +27,7 @@ def tree_speculative_sampling_target_only( retrive_next_token, retrive_next_sibling, uniform_samples, + uniform_samples_for_final_sampling, target_probs, draft_probs, threshold_single, @@ -91,11 +93,13 @@ def segment_packbits( input_indptr: torch.Tensor, output_indptr: torch.Tensor, y: torch.Tensor, + batch_size: int, ) -> None: torch.ops.sgl_kernel.segment_packbits.default( x, input_indptr, output_indptr, y, + batch_size, torch.cuda.current_stream().cuda_stream, ) diff --git a/sgl-kernel/python/sgl_kernel/top_k.py b/sgl-kernel/python/sgl_kernel/top_k.py new file mode 100644 index 000000000..fc29a6db8 --- /dev/null +++ b/sgl-kernel/python/sgl_kernel/top_k.py @@ -0,0 +1,11 @@ +import torch + + +def fast_topk(values, topk, dim): + if topk == 1: + # Use max along the specified dimension to get both value and index + return torch.max(values, dim=dim, keepdim=True) + else: + # Use topk for efficiency with larger k values + # TODO: implement faster cuda kernels for large vocab sizes + return torch.topk(values, topk, dim=dim) diff --git a/sgl-kernel/tests/speculative/test_eagle_utils.py b/sgl-kernel/tests/speculative/test_eagle_utils.py index 03e6825de..503355387 100644 --- a/sgl-kernel/tests/speculative/test_eagle_utils.py +++ b/sgl-kernel/tests/speculative/test_eagle_utils.py @@ -10,7 +10,7 @@ def test_verify_tree_greedy(): [0, 1, 2, 3, 4, 5], [7, 8, 9, 10, 11, 12], ], - dtype=torch.int32, + dtype=torch.int64, device="cuda", ) retrive_index = torch.tensor( @@ -18,7 +18,7 @@ def test_verify_tree_greedy(): [0, 1, 2, 3, 4, 5], [6, 7, 8, 9, 10, 11], ], - dtype=torch.int32, + dtype=torch.int64, device="cuda", ) retrive_next_token = torch.tensor( @@ -26,7 +26,7 @@ def test_verify_tree_greedy(): [1, 2, -1, 4, 5, -1], [4, 2, 3, -1, 5, -1], ], - dtype=torch.int32, + dtype=torch.int64, device="cuda", ) retrive_next_sibling = torch.tensor( @@ -34,7 +34,7 @@ def test_verify_tree_greedy(): [-1, 3, -1, -1, -1, -1], [-1, -1, -1, -1, 1, -1], ], - dtype=torch.int32, + dtype=torch.int64, device="cuda", ) @@ -49,12 +49,11 @@ def test_verify_tree_greedy(): if torch.max(target_logits[i][j]) < 10: target_logits[i][j][18] = 10 - target_predict = torch.argmax(target_logits, dim=-1).to(torch.int32) + target_predict = torch.argmax(target_logits, dim=-1) predict_shape = (12,) bs = candidates.shape[0] num_spec_step = 4 - num_draft_tokens = candidates.shape[1] predicts = torch.full( predict_shape, -1, dtype=torch.int32, device="cuda" diff --git a/sgl-kernel/tests/speculative/test_speculative_sampling.py b/sgl-kernel/tests/speculative/test_speculative_sampling.py index 56dd02b84..a9b59bb2e 100644 --- a/sgl-kernel/tests/speculative/test_speculative_sampling.py +++ b/sgl-kernel/tests/speculative/test_speculative_sampling.py @@ -42,7 +42,7 @@ def test_tree_speculative_sampling_target_only( [0, 1, 2, 3, 4, 5], [7, 8, 9, 10, 11, 12], ], - dtype=torch.int32, + dtype=torch.int64, device=device, ) retrive_index = torch.tensor( @@ -50,7 +50,7 @@ def test_tree_speculative_sampling_target_only( [0, 1, 2, 3, 4, 5], [6, 7, 8, 9, 10, 11], ], - dtype=torch.int32, + dtype=torch.int64, device=device, ) retrive_next_token = torch.tensor( @@ -58,7 +58,7 @@ def test_tree_speculative_sampling_target_only( [1, 2, -1, 4, 5, -1], [4, 2, 3, -1, 5, -1], ], - dtype=torch.int32, + dtype=torch.int64, device=device, ) retrive_next_sibling = torch.tensor( @@ -66,7 +66,7 @@ def test_tree_speculative_sampling_target_only( [-1, 3, -1, -1, -1, -1], [-1, -1, -1, -1, 1, -1], ], - dtype=torch.int32, + dtype=torch.int64, device=device, ) @@ -95,6 +95,7 @@ def test_tree_speculative_sampling_target_only( target_probs = F.softmax(target_logits / expanded_temperature, dim=-1) draft_probs = torch.full_like(target_probs, 0, dtype=torch.float32, device=device) coins = torch.rand(bs, num_draft_tokens, device=device, dtype=torch.float32) + coins_for_final_sampling = torch.rand(bs, device=device).to(torch.float32) tree_speculative_sampling_target_only( predicts=predicts, @@ -105,6 +106,7 @@ def test_tree_speculative_sampling_target_only( retrive_next_token=retrive_next_token, retrive_next_sibling=retrive_next_sibling, uniform_samples=coins, + uniform_samples_for_final_sampling=coins_for_final_sampling, target_probs=target_probs, draft_probs=draft_probs, threshold_single=threshold_single,