From 52a34d7448bf8a90ede346a701bf2061e5d25bca Mon Sep 17 00:00:00 2001 From: Ying Sheng Date: Sun, 16 Mar 2025 00:58:26 -0700 Subject: [PATCH] Add greedy verification kernel (#4383) --- sgl-kernel/csrc/speculative/eagle_utils.cu | 229 ++++++++++-------- sgl-kernel/csrc/speculative/packbit.cu | 47 ++++ .../csrc/speculative/speculative_sampling.cu | 11 +- .../csrc/speculative/speculative_sampling.cuh | 18 +- sgl-kernel/csrc/torch_extension.cc | 21 +- sgl-kernel/include/sgl_kernel_ops.h | 29 ++- sgl-kernel/python/sgl_kernel/__init__.py | 7 +- sgl-kernel/python/sgl_kernel/speculative.py | 58 +++-- sgl-kernel/setup.py | 1 + .../tests/speculative/test_eagle_utils.py | 98 ++++++++ .../test_speculative_sampling.py | 28 ++- 11 files changed, 394 insertions(+), 153 deletions(-) create mode 100644 sgl-kernel/csrc/speculative/packbit.cu create mode 100644 sgl-kernel/tests/speculative/test_eagle_utils.py rename sgl-kernel/tests/{ => speculative}/test_speculative_sampling.py (76%) diff --git a/sgl-kernel/csrc/speculative/eagle_utils.cu b/sgl-kernel/csrc/speculative/eagle_utils.cu index 1bfd6fd84..968a8a264 100644 --- a/sgl-kernel/csrc/speculative/eagle_utils.cu +++ b/sgl-kernel/csrc/speculative/eagle_utils.cu @@ -17,6 +17,8 @@ #include #include +#include "pytorch_extension_utils.h" + // parent_list [bs, topk * (depth - 1) + 1)] // selected_index [bs, draft_token_num - 1] // verified_seq_len [bs] @@ -72,8 +74,8 @@ __global__ void build_tree_efficient( } if (parent_position == draft_token_num) { printf( - "ERROR: invalid eagle tree!!! Detected a token with no parent token selected. Check the logprob. The token " - "will be dropped."); + "WARNING: invalid eagle tree!!! Detected a token with no parent token selected. " + "Please check if the logprob has nan. The token will be ignored to keep proceeding.\n"); continue; } @@ -140,112 +142,141 @@ void build_tree_kernel_efficient( int32_t(draft_token_num)); } -// parent_list [bs, topk * (depth - 1) + 1)] -// selected_index [bs, draft_token_num - 1] -// verified_seq_len [bs] -// tree_mask [draft_token*(seq_len[0]+draft_token) | draft_token*(seq_len[1]+draft_token) | ..] = -// [sum(verified_seq_len)*draft_token+bs*draft_token*draft_token] positions [bs * draft_token] retrive_index [b, -// draft_token, depth + 2] -__global__ void build_tree( - int64_t* parent_list, - int64_t* selected_index, - int32_t* verified_seq_len, - bool* tree_mask, - int64_t* positions, - int64_t* retrive_index, - int topk, - int depth, - int draft_token_num) { - int bid = blockIdx.x; - int tid = threadIdx.x; +template +__global__ void VerifyTreeGreedy( + IdType* predicts, + IdType* accept_index, + IdType* accept_token_num, // mutable + IdType* candidates, + IdType* retrive_index, + IdType* retrive_next_token, + IdType* retrive_next_sibling, + IdType* target_predict, + uint32_t batch_size, + uint32_t num_speculative_tokens, + uint32_t num_draft_tokens) { + uint32_t bx = blockIdx.x; - if (tid >= draft_token_num) { - return; - } - int seq_tree_idx = draft_token_num * draft_token_num * bid; - for (int i = 0; i < bid; i++) { - seq_tree_idx += verified_seq_len[i] * draft_token_num; - } - int seq_len = verified_seq_len[bid]; - int token_tree_idx = seq_tree_idx + (seq_len + draft_token_num) * tid + seq_len + 1; - for (int i = 0; i < draft_token_num - 1; i++) { - tree_mask[token_tree_idx + i] = false; - } + IdType last_accepted_retrive_idx = retrive_index[bx * num_draft_tokens]; + accept_index[bx * num_speculative_tokens] = last_accepted_retrive_idx; + uint32_t num_accepted_tokens = 0; + IdType cur_index = 0; - int position = 0; - if (tid == 0) { - positions[bid * draft_token_num] = seq_len; - retrive_index[bid * draft_token_num * (depth + 2)] = bid * draft_token_num; - return; - } + for (uint32_t j = 1; j < num_speculative_tokens; ++j) { + cur_index = retrive_next_token[bx * num_draft_tokens + cur_index]; + while (cur_index != -1) { + IdType draft_index = retrive_index[bx * num_draft_tokens + cur_index]; + IdType draft_token_id = candidates[bx * num_draft_tokens + cur_index]; + IdType target_token_id = target_predict[last_accepted_retrive_idx]; - int depends_order[10]; - - int cur_position = tid - 1; - while (true) { - depends_order[position] = cur_position + 1; - position += 1; - tree_mask[token_tree_idx + cur_position] = true; - int parent_tb_idx = selected_index[bid * (draft_token_num - 1) + cur_position] / topk; - if (parent_tb_idx == 0) { - break; - } - - int token_idx = parent_list[bid * (topk * (depth - 1) + 1) + parent_tb_idx]; - for (cur_position = 0; cur_position < draft_token_num; cur_position++) { - if (selected_index[bid * (draft_token_num - 1) + cur_position] == token_idx) { + if (draft_token_id == target_token_id) { + // accept token + predicts[last_accepted_retrive_idx] = target_token_id; + ++num_accepted_tokens; + accept_index[bx * num_speculative_tokens + num_accepted_tokens] = draft_index; + last_accepted_retrive_idx = draft_index; break; + } else { + cur_index = retrive_next_sibling[bx * num_draft_tokens + cur_index]; } } - if (cur_position == draft_token_num) { - printf( - "ERROR: invalid eagle tree!!! Detected a token with no parent token selected. Check the logprob. The token " - "will be dropped."); - break; - } - } - positions[bid * draft_token_num + tid] = position + seq_len; - - int is_leaf = 0; - for (int i = 1; i < draft_token_num; i++) { - if (tree_mask[seq_tree_idx + i * (draft_token_num + seq_len) + seq_len + tid]) { - is_leaf++; - } - } - if (is_leaf == 1) { - for (int i = 0; i < position; i++) { - retrive_index[(bid * (draft_token_num) + tid) * (depth + 2) + position - i] = - depends_order[i] + bid * draft_token_num; - } - retrive_index[(bid * (draft_token_num) + tid) * (depth + 2)] = bid * draft_token_num; + if (cur_index == -1) break; } + accept_token_num[bx] = num_accepted_tokens; + predicts[last_accepted_retrive_idx] = target_predict[last_accepted_retrive_idx]; } -void build_tree_kernel( - at::Tensor parent_list, - at::Tensor selected_index, - at::Tensor verified_seq_len, - at::Tensor tree_mask, - at::Tensor positions, +// predicts: [tot_num_draft_tokens] +// accept_index: [bs, num_spec_step] +// accept_token_num: [bs] +// candidates: [bs, num_draft_tokens] +// retrive_index: [bs, num_draft_tokens] +// retrive_next_token: [bs, num_draft_tokens] +// retrive_next_sibling: [bs, num_draft_tokens] +// target_predict: [bs, num_draft_tokens] +void verify_tree_greedy( + at::Tensor predicts, + at::Tensor accept_index, + at::Tensor accept_token_num, // mutable + at::Tensor candidates, at::Tensor retrive_index, - int64_t topk, - int64_t depth, - int64_t draft_token_num) { - // TODO (ying) check shape - // TODO (ying) check type - int bs = parent_list.size(0); - dim3 grid(bs); - dim3 block(draft_token_num); - const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + at::Tensor retrive_next_token, + at::Tensor retrive_next_sibling, + at::Tensor target_predict, + int64_t cuda_stream = 0) { + CHECK_INPUT(candidates); + CHECK_INPUT(retrive_index); + CHECK_INPUT(retrive_next_token); + CHECK_INPUT(retrive_next_sibling); + CHECK_INPUT(target_predict); + auto device = target_predict.device(); + CHECK_EQ(candidates.device(), device); + CHECK_EQ(retrive_index.device(), device); + CHECK_EQ(retrive_next_token.device(), device); + CHECK_EQ(retrive_next_sibling.device(), device); + CHECK_EQ(target_predict.device(), device); + CHECK_DIM(1, predicts); + CHECK_DIM(2, accept_index); + CHECK_DIM(1, accept_token_num); + CHECK_DIM(2, candidates); + CHECK_DIM(2, retrive_index); + CHECK_DIM(2, retrive_next_token); + CHECK_DIM(2, retrive_next_sibling); + CHECK_DIM(2, target_predict); + unsigned int batch_size = candidates.size(0); + unsigned int num_spec_step = accept_index.size(1); + unsigned int num_draft_tokens = candidates.size(1); + CHECK_EQ(batch_size, accept_index.size(0)); + CHECK_EQ(batch_size, accept_token_num.size(0)); + CHECK_EQ(batch_size, retrive_index.size(0)); + CHECK_EQ(batch_size, retrive_next_token.size(0)); + CHECK_EQ(batch_size, retrive_next_sibling.size(0)); + CHECK_EQ(batch_size, target_predict.size(0)); + CHECK_EQ(num_draft_tokens, retrive_index.size(1)); + CHECK_EQ(num_draft_tokens, retrive_next_token.size(1)); + CHECK_EQ(num_draft_tokens, retrive_next_sibling.size(1)); + CHECK_EQ(num_draft_tokens, target_predict.size(1)); + CHECK_EQ(batch_size, accept_index.size(0)); + CHECK_EQ(batch_size, accept_token_num.size(0)); + if (predicts.scalar_type() != at::kInt) { + throw std::runtime_error("Expected 'predicts' to be of type int (torch.int32)."); + } + if (accept_index.scalar_type() != at::kInt) { + throw std::runtime_error("Expected 'accept_index' to be of type int (torch.int32)."); + } + if (accept_token_num.scalar_type() != at::kInt) { + throw std::runtime_error("Expected 'accept_token_num' to be of type int (torch.int32)."); + } + if (candidates.scalar_type() != at::kInt) { + throw std::runtime_error("Expected 'candidates' to be of type int (torch.int32)."); + } + if (retrive_index.scalar_type() != at::kInt) { + throw std::runtime_error("Expected 'retrive_index' to be of type int (torch.int32)."); + } + if (retrive_next_token.scalar_type() != at::kInt) { + throw std::runtime_error("Expected 'retrive_next_token' to be of type int (torch.int32)."); + } + if (retrive_next_sibling.scalar_type() != at::kInt) { + throw std::runtime_error("Expected 'retrive_next_sibling' to be of type int (torch.int32)."); + } + if (target_predict.scalar_type() != at::kInt) { + throw std::runtime_error("Expected 'target_predict' to be of type int (torch.int32)."); + } - build_tree<<>>( - static_cast(parent_list.data_ptr()), - static_cast(selected_index.data_ptr()), - static_cast(verified_seq_len.data_ptr()), - static_cast(tree_mask.data_ptr()), - static_cast(positions.data_ptr()), - static_cast(retrive_index.data_ptr()), - int32_t(topk), - int32_t(depth), - int32_t(draft_token_num)); + cudaStream_t stream = reinterpret_cast(cuda_stream); + dim3 grid(batch_size); + dim3 block(1); + + VerifyTreeGreedy<<>>( + static_cast(predicts.data_ptr()), + static_cast(accept_index.data_ptr()), + static_cast(accept_token_num.data_ptr()), + static_cast(candidates.data_ptr()), + static_cast(retrive_index.data_ptr()), + static_cast(retrive_next_token.data_ptr()), + static_cast(retrive_next_sibling.data_ptr()), + static_cast(target_predict.data_ptr()), + batch_size, + num_spec_step, + num_draft_tokens); } diff --git a/sgl-kernel/csrc/speculative/packbit.cu b/sgl-kernel/csrc/speculative/packbit.cu new file mode 100644 index 000000000..c65ba4518 --- /dev/null +++ b/sgl-kernel/csrc/speculative/packbit.cu @@ -0,0 +1,47 @@ +// This is only a pluggin used for flashinfer 0.1.6. The new version does not need it. +/* + * Copyright (c) 2025 by SGLang team. + * Copyright (c) 2025 by FlashInfer team. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include + +#include "pytorch_extension_utils.h" + +using namespace flashinfer; + +// bitorder = "little" +void segment_packbits( + at::Tensor x, at::Tensor input_indptr, at::Tensor output_indptr, at::Tensor y, int64_t cuda_stream) { + CHECK_INPUT(x); + CHECK_INPUT(input_indptr); + CHECK_INPUT(output_indptr); + auto device = x.device(); + CHECK_EQ(input_indptr.device(), device); + CHECK_EQ(output_indptr.device(), device); + CHECK_EQ(y.device(), device); + unsigned int batch_size = input_indptr.size(0) - 1; + CHECK_EQ(output_indptr.size(0), batch_size + 1); + + cudaStream_t stream = reinterpret_cast(cuda_stream); + cudaError_t status = quantization::SegmentPackBits( + static_cast(x.data_ptr()), + static_cast(y.data_ptr()), + static_cast(input_indptr.data_ptr()), + static_cast(output_indptr.data_ptr()), + batch_size, + quantization::BitOrder::kLittle, + stream); +} diff --git a/sgl-kernel/csrc/speculative/speculative_sampling.cu b/sgl-kernel/csrc/speculative/speculative_sampling.cu index 6eaafdb5b..c03e1d772 100644 --- a/sgl-kernel/csrc/speculative/speculative_sampling.cu +++ b/sgl-kernel/csrc/speculative/speculative_sampling.cu @@ -14,7 +14,6 @@ * See the License for the specific language governing permissions and * limitations under the License. */ - #include "pytorch_extension_utils.h" #include "speculative_sampling.cuh" @@ -40,7 +39,9 @@ void tree_speculative_sampling_target_only( at::Tensor uniform_samples, at::Tensor target_probs, at::Tensor draft_probs, - bool deterministic, + double threshold_single, + double threshold_acc, + bool deterministic = true, int64_t cuda_stream = 0) { CHECK_INPUT(candidates); CHECK_INPUT(retrive_index); @@ -112,6 +113,10 @@ void tree_speculative_sampling_target_only( if (draft_probs.scalar_type() != at::kFloat) { throw std::runtime_error("Expected 'target_probs' to be of type float (torch.float32)."); } + CHECK_GE(threshold_single, 0); + CHECK_GE(1, threshold_single); + CHECK_GE(threshold_acc, 0); + CHECK_GE(1, threshold_acc); cudaStream_t stream = reinterpret_cast(cuda_stream); cudaError_t status = sampling::TreeSpeculativeSamplingTargetOnly( @@ -129,6 +134,8 @@ void tree_speculative_sampling_target_only( num_spec_step, num_draft_tokens, vocab_size, + static_cast(threshold_single), + static_cast(threshold_acc), deterministic, stream); diff --git a/sgl-kernel/csrc/speculative/speculative_sampling.cuh b/sgl-kernel/csrc/speculative/speculative_sampling.cuh index bf7099231..10b14713f 100644 --- a/sgl-kernel/csrc/speculative/speculative_sampling.cuh +++ b/sgl-kernel/csrc/speculative/speculative_sampling.cuh @@ -49,7 +49,9 @@ __global__ void TreeSpeculativeSamplingTargetOnly( uint32_t batch_size, uint32_t num_speculative_tokens, uint32_t num_draft_tokens, - uint32_t d) { + uint32_t d, + DType threshold_single, + DType threshold_acc) { const uint32_t bx = blockIdx.x, tx = threadIdx.x; extern __shared__ __align__(alignof(SamplingTempStorage)) @@ -70,9 +72,10 @@ __global__ void TreeSpeculativeSamplingTargetOnly( while (cur_index != -1) { IdType draft_index = retrive_index[bx * num_draft_tokens + cur_index]; IdType draft_token_id = candidates[bx * num_draft_tokens + cur_index]; - prob_acc += target_probs[cur_prob_offset + draft_token_id]; + DType target_prob_single = target_probs[cur_prob_offset + draft_token_id]; + prob_acc += target_prob_single; - if (coin < prob_acc) { + if (coin <= prob_acc / threshold_acc || target_prob_single >= threshold_single) { // accept token prob_acc = 0.; cur_prob_offset = (bx * num_draft_tokens + cur_index) * d; @@ -169,7 +172,9 @@ cudaError_t TreeSpeculativeSamplingTargetOnly( uint32_t num_speculative_tokens, uint32_t num_draft_tokens, uint32_t d, - bool deterministic, + DType threshold_single = 1, + DType threshold_acc = 1, + bool deterministic = true, cudaStream_t stream = 0) { constexpr uint32_t BLOCK_THREADS = 1024; const uint32_t vec_size = std::gcd(16 / sizeof(DType), d); @@ -177,6 +182,7 @@ cudaError_t TreeSpeculativeSamplingTargetOnly( const uint32_t smem_size = sizeof(SamplingTempStorage); dim3 nblks(batch_size); dim3 nthrs(BLOCK_THREADS); + float capped_threshold_acc = fmaxf(threshold_acc, 1e-9f); void* args[] = { &predicts, &output_token_ids, @@ -191,7 +197,9 @@ cudaError_t TreeSpeculativeSamplingTargetOnly( &batch_size, &num_speculative_tokens, &num_draft_tokens, - &d}; + &d, + &threshold_single, + &capped_threshold_acc}; DISPATCH_ALIGNED_VEC_SIZE( vec_size, VEC_SIZE, {DISPATCH_DETERMINISTIC(deterministic, DETERMINISTIC, { auto kernel = TreeSpeculativeSamplingTargetOnly< diff --git a/sgl-kernel/csrc/torch_extension.cc b/sgl-kernel/csrc/torch_extension.cc index 5962ac857..29bf9427b 100644 --- a/sgl-kernel/csrc/torch_extension.cc +++ b/sgl-kernel/csrc/torch_extension.cc @@ -129,21 +129,24 @@ TORCH_LIBRARY_EXPAND(sgl_kernel, m) { "tree_speculative_sampling_target_only(Tensor! predicts, Tensor! accept_index, Tensor! accept_token_num, " "Tensor candidates, Tensor retrive_index, Tensor retrive_next_token, Tensor retrive_next_sibling, " "Tensor uniform_samples, Tensor target_probs, Tensor draft_probs, " + "float threshold_single, float threshold_acc, " "bool deterministic, int cuda_stream) -> ()"); m.impl("tree_speculative_sampling_target_only", torch::kCUDA, &tree_speculative_sampling_target_only); m.def( - "build_tree_kernel_efficient(Tensor parent_list, Tensor selected_index, Tensor verified_seq_len, " - "Tensor! tree_mask, Tensor! positions, Tensor! retrive_index, Tensor! retrive_next_token, Tensor! " - "retrive_next_sibling, " - "int topk, int depth, int draft_token_num) -> ()"); - m.impl("build_tree_kernel_efficient", torch::kCUDA, &build_tree_kernel_efficient); + "verify_tree_greedy(Tensor! predicts, Tensor! accept_index, Tensor! accept_token_num, " + "Tensor candidates, Tensor retrive_index, Tensor retrive_next_token, Tensor retrive_next_sibling, " + "Tensor target_predict, int cuda_stream) -> ()"); + m.impl("verify_tree_greedy", torch::kCUDA, &verify_tree_greedy); m.def( - "build_tree_kernel(Tensor parent_list, Tensor selected_index, Tensor verified_seq_len, " - "Tensor! tree_mask, Tensor! positions, Tensor! retrive_index, " - "int topk, int depth, int draft_token_num) -> ()"); - m.impl("build_tree_kernel", torch::kCUDA, &build_tree_kernel); + "build_tree_kernel_efficient(Tensor parent_list, Tensor selected_index, Tensor verified_seq_len, " + "Tensor! tree_mask, Tensor! positions, Tensor! retrive_index, Tensor! retrive_next_token, " + "Tensor! retrive_next_sibling, int topk, int depth, int draft_token_num) -> ()"); + m.impl("build_tree_kernel_efficient", torch::kCUDA, &build_tree_kernel_efficient); + + m.def("segment_packbits(Tensor x, Tensor input_indptr, Tensor output_indptr, Tensor! y, int cuda_stream) -> ()"); + m.impl("segment_packbits", torch::kCUDA, &segment_packbits); /* * From FlashInfer diff --git a/sgl-kernel/include/sgl_kernel_ops.h b/sgl-kernel/include/sgl_kernel_ops.h index 1198c101a..cfb39e9d4 100644 --- a/sgl-kernel/include/sgl_kernel_ops.h +++ b/sgl-kernel/include/sgl_kernel_ops.h @@ -183,8 +183,8 @@ void topk_softmax( * From csrc/speculative */ void tree_speculative_sampling_target_only( - at::Tensor predicts, - at::Tensor accept_index, + at::Tensor predicts, // mutable + at::Tensor accept_index, // mutable at::Tensor accept_token_num, // mutable at::Tensor candidates, at::Tensor retrive_index, @@ -193,9 +193,22 @@ void tree_speculative_sampling_target_only( at::Tensor uniform_samples, at::Tensor target_probs, at::Tensor draft_probs, + double threshold_single = 1, + double threshold_acc = 1, bool deterministic = true, int64_t cuda_stream = 0); +void verify_tree_greedy( + at::Tensor predicts, // mutable + at::Tensor accept_index, // mutable + at::Tensor accept_token_num, // mutable + at::Tensor candidates, + at::Tensor retrive_index, + at::Tensor retrive_next_token, + at::Tensor retrive_next_sibling, + at::Tensor target_predict, + int64_t cuda_stream = 0); + void build_tree_kernel_efficient( at::Tensor parent_list, at::Tensor selected_index, @@ -209,16 +222,8 @@ void build_tree_kernel_efficient( int64_t depth, int64_t draft_token_num); -void build_tree_kernel( - at::Tensor parent_list, - at::Tensor selected_index, - at::Tensor verified_seq_len, - at::Tensor tree_mask, - at::Tensor positions, - at::Tensor retrive_index, - int64_t topk, - int64_t depth, - int64_t draft_token_num); +void segment_packbits( + at::Tensor x, at::Tensor input_indptr, at::Tensor output_indptr, at::Tensor y, int64_t cuda_stream); /* * From FlashInfer diff --git a/sgl-kernel/python/sgl_kernel/__init__.py b/sgl-kernel/python/sgl_kernel/__init__.py index 7c05b50a9..36160f0e9 100644 --- a/sgl-kernel/python/sgl_kernel/__init__.py +++ b/sgl-kernel/python/sgl_kernel/__init__.py @@ -42,8 +42,13 @@ from sgl_kernel.sampling import ( top_p_sampling_from_probs, ) from sgl_kernel.speculative import ( - build_tree_kernel, build_tree_kernel_efficient, + segment_packbits, tree_speculative_sampling_target_only, + verify_tree_greedy, ) from sgl_kernel.version import __version__ + +build_tree_kernel = ( + None # TODO(ying): remove this after updating the sglang python code. +) diff --git a/sgl-kernel/python/sgl_kernel/speculative.py b/sgl-kernel/python/sgl_kernel/speculative.py index 53acb1d95..ebec2a5a9 100644 --- a/sgl-kernel/python/sgl_kernel/speculative.py +++ b/sgl-kernel/python/sgl_kernel/speculative.py @@ -13,6 +13,8 @@ def tree_speculative_sampling_target_only( uniform_samples: torch.Tensor, target_probs: torch.Tensor, draft_probs: torch.Tensor, + threshold_single: float = 1.0, + threshold_acc: float = 1.0, deterministic: bool = True, ) -> None: torch.ops.sgl_kernel.tree_speculative_sampling_target_only( @@ -26,11 +28,36 @@ def tree_speculative_sampling_target_only( uniform_samples, target_probs, draft_probs, + threshold_single, + threshold_acc, deterministic, get_cuda_stream(), ) +def verify_tree_greedy( + predicts: torch.Tensor, # mutable + accept_index: torch.Tensor, # mutable + accept_token_num: torch.Tensor, # mutable + candidates: torch.Tensor, + retrive_index: torch.Tensor, + retrive_next_token: torch.Tensor, + retrive_next_sibling: torch.Tensor, + target_predict: torch.Tensor, +) -> None: + torch.ops.sgl_kernel.verify_tree_greedy( + predicts, + accept_index, + accept_token_num, + candidates, + retrive_index, + retrive_next_token, + retrive_next_sibling, + target_predict, + get_cuda_stream(), + ) + + def build_tree_kernel_efficient( parent_list: torch.Tensor, selected_index: torch.Tensor, @@ -59,25 +86,16 @@ def build_tree_kernel_efficient( ) -def build_tree_kernel( - parent_list: torch.Tensor, - selected_index: torch.Tensor, - verified_seq_len: torch.Tensor, - tree_mask: torch.Tensor, - positions: torch.Tensor, - retrive_index: torch.Tensor, - topk: int, - depth: int, - draft_token_num: int, +def segment_packbits( + x: torch.Tensor, + input_indptr: torch.Tensor, + output_indptr: torch.Tensor, + y: torch.Tensor, ) -> None: - torch.ops.sgl_kernel.build_tree_kernel( - parent_list, - selected_index, - verified_seq_len, - tree_mask, - positions, - retrive_index, - topk, - depth, - draft_token_num, + torch.ops.sgl_kernel.segment_packbits( + x, + input_indptr, + output_indptr, + y, + torch.cuda.current_stream().cuda_stream, ) diff --git a/sgl-kernel/setup.py b/sgl-kernel/setup.py index 37555ab55..26c6d0f70 100644 --- a/sgl-kernel/setup.py +++ b/sgl-kernel/setup.py @@ -209,6 +209,7 @@ sources = [ "csrc/moe/moe_topk_softmax_kernels.cu", "csrc/speculative/eagle_utils.cu", "csrc/speculative/speculative_sampling.cu", + "csrc/speculative/packbit.cu", "csrc/torch_extension.cc", "3rdparty/flashinfer/csrc/norm.cu", "3rdparty/flashinfer/csrc/renorm.cu", diff --git a/sgl-kernel/tests/speculative/test_eagle_utils.py b/sgl-kernel/tests/speculative/test_eagle_utils.py new file mode 100644 index 000000000..1514029ec --- /dev/null +++ b/sgl-kernel/tests/speculative/test_eagle_utils.py @@ -0,0 +1,98 @@ +import torch +import torch.nn.functional as F +from sgl_kernel import verify_tree_greedy + + +def test_verify_tree_greedy(): + candidates = torch.tensor( + [ + [0, 1, 2, 3, 4, 5], + [7, 8, 9, 10, 11, 12], + ], + dtype=torch.int32, + device="cuda", + ) + retrive_index = torch.tensor( + [ + [0, 1, 2, 3, 4, 5], + [6, 7, 8, 9, 10, 11], + ], + dtype=torch.int32, + device="cuda", + ) + retrive_next_token = torch.tensor( + [ + [1, 2, -1, 4, 5, -1], + [4, 2, 3, -1, 5, -1], + ], + dtype=torch.int32, + device="cuda", + ) + retrive_next_sibling = torch.tensor( + [ + [-1, 3, -1, -1, -1, -1], + [-1, -1, -1, -1, 1, -1], + ], + dtype=torch.int32, + device="cuda", + ) + + target_logits = torch.full((2, 6, 20), 1, dtype=torch.float32, device="cuda") + target_logits[0, 0, 3] = 10 + target_logits[0, 3, 4] = 10 + target_logits[0, 4, 5] = 10 + target_logits[1, 0, 11] = 10 + target_logits[1, 4, 12] = 10 + for i in range(target_logits.shape[0]): + for j in range(target_logits.shape[1]): + if torch.max(target_logits[i][j]) < 10: + target_logits[i][j][18] = 10 + + print(f"{target_logits=}") + target_predict = torch.argmax(target_logits, dim=-1).to(torch.int32) + predict_shape = (12,) + + bs = candidates.shape[0] + num_spec_step = 4 + num_draft_tokens = candidates.shape[1] + + predicts = torch.full( + predict_shape, -1, dtype=torch.int32, device="cuda" + ) # mutable + accept_index = torch.full( + (bs, num_spec_step), -1, dtype=torch.int32, device="cuda" + ) # mutable + accept_token_num = torch.full((bs,), 0, dtype=torch.int32, device="cuda") # mutable + + print(f"{candidates=}") + print(f"{retrive_index=}") + print(f"{retrive_next_token=}") + print(f"{retrive_next_sibling=}") + print(f"{target_predict=}") + + verify_tree_greedy( + predicts=predicts, + accept_index=accept_index, + accept_token_num=accept_token_num, + candidates=candidates, + retrive_index=retrive_index, + retrive_next_token=retrive_next_token, + retrive_next_sibling=retrive_next_sibling, + target_predict=target_predict, + ) + + print(f"{predicts=}") + print(f"{accept_index=}") + print(f"{accept_token_num=}") + + return predicts, accept_index, accept_token_num + + +if __name__ == "__main__": + predicts, accept_index, accept_token_num = test_verify_tree_greedy() + assert predicts.tolist() == [3, -1, -1, 4, 5, 18, 11, -1, -1, -1, 12, 18] + assert accept_index.tolist() == [ + [0, 3, 4, 5], + [6, 10, 11, -1], + ] + assert accept_token_num.tolist() == [3, 2] diff --git a/sgl-kernel/tests/test_speculative_sampling.py b/sgl-kernel/tests/speculative/test_speculative_sampling.py similarity index 76% rename from sgl-kernel/tests/test_speculative_sampling.py rename to sgl-kernel/tests/speculative/test_speculative_sampling.py index 545c3725a..2d45db2d0 100644 --- a/sgl-kernel/tests/test_speculative_sampling.py +++ b/sgl-kernel/tests/speculative/test_speculative_sampling.py @@ -3,7 +3,10 @@ import torch.nn.functional as F from sgl_kernel import tree_speculative_sampling_target_only -def test_tree_speculative_sampling_target_only(): +def test_tree_speculative_sampling_target_only(threshold_single=1, threshold_acc=1): + print( + f"\n============= run test: {threshold_single=} {threshold_acc=} ==============\n" + ) candidates = torch.tensor( [ [0, 1, 2, 3, 4, 5], @@ -37,7 +40,7 @@ def test_tree_speculative_sampling_target_only(): device="cuda", ) - target_logits = torch.zeros((2, 6, 20), dtype=torch.float32, device="cuda") + target_logits = torch.full((2, 6, 20), 1, dtype=torch.float32, device="cuda") target_logits[0, 0, 3] = 10 target_logits[0, 3, 4] = 10 target_logits[0, 4, 5] = 10 @@ -85,6 +88,8 @@ def test_tree_speculative_sampling_target_only(): uniform_samples=coins, target_probs=target_probs, draft_probs=draft_probs, + threshold_single=threshold_single, + threshold_acc=threshold_acc, deterministic=True, ) @@ -92,6 +97,13 @@ def test_tree_speculative_sampling_target_only(): print(f"{accept_index=}") print(f"{accept_token_num=}") + return predicts, accept_index, accept_token_num + + +if __name__ == "__main__": + predicts, accept_index, accept_token_num = ( + test_tree_speculative_sampling_target_only(threshold_single=1, threshold_acc=1) + ) assert predicts.tolist() == [3, -1, -1, 4, 5, 18, 11, -1, -1, -1, 12, 18] assert accept_index.tolist() == [ [0, 3, 4, 5], @@ -99,6 +111,12 @@ def test_tree_speculative_sampling_target_only(): ] assert accept_token_num.tolist() == [3, 2] - -if __name__ == "__main__": - test_tree_speculative_sampling_target_only() + predicts, accept_index, accept_token_num = ( + test_tree_speculative_sampling_target_only(threshold_single=0, threshold_acc=0) + ) + assert predicts.tolist() == [1, 2, 18, -1, -1, -1, 11, -1, -1, -1, 12, 18] + assert accept_index.tolist() == [ + [0, 1, 2, -1], + [6, 10, 11, -1], + ] + assert accept_token_num.tolist() == [2, 2]