Add greedy verification kernel (#4383)
This commit is contained in:
@@ -17,6 +17,8 @@
|
|||||||
#include <ATen/ATen.h>
|
#include <ATen/ATen.h>
|
||||||
#include <ATen/cuda/CUDAContext.h>
|
#include <ATen/cuda/CUDAContext.h>
|
||||||
|
|
||||||
|
#include "pytorch_extension_utils.h"
|
||||||
|
|
||||||
// parent_list [bs, topk * (depth - 1) + 1)]
|
// parent_list [bs, topk * (depth - 1) + 1)]
|
||||||
// selected_index [bs, draft_token_num - 1]
|
// selected_index [bs, draft_token_num - 1]
|
||||||
// verified_seq_len [bs]
|
// verified_seq_len [bs]
|
||||||
@@ -72,8 +74,8 @@ __global__ void build_tree_efficient(
|
|||||||
}
|
}
|
||||||
if (parent_position == draft_token_num) {
|
if (parent_position == draft_token_num) {
|
||||||
printf(
|
printf(
|
||||||
"ERROR: invalid eagle tree!!! Detected a token with no parent token selected. Check the logprob. The token "
|
"WARNING: invalid eagle tree!!! Detected a token with no parent token selected. "
|
||||||
"will be dropped.");
|
"Please check if the logprob has nan. The token will be ignored to keep proceeding.\n");
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -140,112 +142,141 @@ void build_tree_kernel_efficient(
|
|||||||
int32_t(draft_token_num));
|
int32_t(draft_token_num));
|
||||||
}
|
}
|
||||||
|
|
||||||
// parent_list [bs, topk * (depth - 1) + 1)]
|
template <typename IdType>
|
||||||
// selected_index [bs, draft_token_num - 1]
|
__global__ void VerifyTreeGreedy(
|
||||||
// verified_seq_len [bs]
|
IdType* predicts,
|
||||||
// tree_mask [draft_token*(seq_len[0]+draft_token) | draft_token*(seq_len[1]+draft_token) | ..] =
|
IdType* accept_index,
|
||||||
// [sum(verified_seq_len)*draft_token+bs*draft_token*draft_token] positions [bs * draft_token] retrive_index [b,
|
IdType* accept_token_num, // mutable
|
||||||
// draft_token, depth + 2]
|
IdType* candidates,
|
||||||
__global__ void build_tree(
|
IdType* retrive_index,
|
||||||
int64_t* parent_list,
|
IdType* retrive_next_token,
|
||||||
int64_t* selected_index,
|
IdType* retrive_next_sibling,
|
||||||
int32_t* verified_seq_len,
|
IdType* target_predict,
|
||||||
bool* tree_mask,
|
uint32_t batch_size,
|
||||||
int64_t* positions,
|
uint32_t num_speculative_tokens,
|
||||||
int64_t* retrive_index,
|
uint32_t num_draft_tokens) {
|
||||||
int topk,
|
uint32_t bx = blockIdx.x;
|
||||||
int depth,
|
|
||||||
int draft_token_num) {
|
|
||||||
int bid = blockIdx.x;
|
|
||||||
int tid = threadIdx.x;
|
|
||||||
|
|
||||||
if (tid >= draft_token_num) {
|
IdType last_accepted_retrive_idx = retrive_index[bx * num_draft_tokens];
|
||||||
return;
|
accept_index[bx * num_speculative_tokens] = last_accepted_retrive_idx;
|
||||||
}
|
uint32_t num_accepted_tokens = 0;
|
||||||
int seq_tree_idx = draft_token_num * draft_token_num * bid;
|
IdType cur_index = 0;
|
||||||
for (int i = 0; i < bid; i++) {
|
|
||||||
seq_tree_idx += verified_seq_len[i] * draft_token_num;
|
|
||||||
}
|
|
||||||
int seq_len = verified_seq_len[bid];
|
|
||||||
int token_tree_idx = seq_tree_idx + (seq_len + draft_token_num) * tid + seq_len + 1;
|
|
||||||
for (int i = 0; i < draft_token_num - 1; i++) {
|
|
||||||
tree_mask[token_tree_idx + i] = false;
|
|
||||||
}
|
|
||||||
|
|
||||||
int position = 0;
|
for (uint32_t j = 1; j < num_speculative_tokens; ++j) {
|
||||||
if (tid == 0) {
|
cur_index = retrive_next_token[bx * num_draft_tokens + cur_index];
|
||||||
positions[bid * draft_token_num] = seq_len;
|
while (cur_index != -1) {
|
||||||
retrive_index[bid * draft_token_num * (depth + 2)] = bid * draft_token_num;
|
IdType draft_index = retrive_index[bx * num_draft_tokens + cur_index];
|
||||||
return;
|
IdType draft_token_id = candidates[bx * num_draft_tokens + cur_index];
|
||||||
}
|
IdType target_token_id = target_predict[last_accepted_retrive_idx];
|
||||||
|
|
||||||
int depends_order[10];
|
if (draft_token_id == target_token_id) {
|
||||||
|
// accept token
|
||||||
int cur_position = tid - 1;
|
predicts[last_accepted_retrive_idx] = target_token_id;
|
||||||
while (true) {
|
++num_accepted_tokens;
|
||||||
depends_order[position] = cur_position + 1;
|
accept_index[bx * num_speculative_tokens + num_accepted_tokens] = draft_index;
|
||||||
position += 1;
|
last_accepted_retrive_idx = draft_index;
|
||||||
tree_mask[token_tree_idx + cur_position] = true;
|
|
||||||
int parent_tb_idx = selected_index[bid * (draft_token_num - 1) + cur_position] / topk;
|
|
||||||
if (parent_tb_idx == 0) {
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
|
|
||||||
int token_idx = parent_list[bid * (topk * (depth - 1) + 1) + parent_tb_idx];
|
|
||||||
for (cur_position = 0; cur_position < draft_token_num; cur_position++) {
|
|
||||||
if (selected_index[bid * (draft_token_num - 1) + cur_position] == token_idx) {
|
|
||||||
break;
|
break;
|
||||||
|
} else {
|
||||||
|
cur_index = retrive_next_sibling[bx * num_draft_tokens + cur_index];
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if (cur_position == draft_token_num) {
|
if (cur_index == -1) break;
|
||||||
printf(
|
|
||||||
"ERROR: invalid eagle tree!!! Detected a token with no parent token selected. Check the logprob. The token "
|
|
||||||
"will be dropped.");
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
positions[bid * draft_token_num + tid] = position + seq_len;
|
|
||||||
|
|
||||||
int is_leaf = 0;
|
|
||||||
for (int i = 1; i < draft_token_num; i++) {
|
|
||||||
if (tree_mask[seq_tree_idx + i * (draft_token_num + seq_len) + seq_len + tid]) {
|
|
||||||
is_leaf++;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if (is_leaf == 1) {
|
|
||||||
for (int i = 0; i < position; i++) {
|
|
||||||
retrive_index[(bid * (draft_token_num) + tid) * (depth + 2) + position - i] =
|
|
||||||
depends_order[i] + bid * draft_token_num;
|
|
||||||
}
|
|
||||||
retrive_index[(bid * (draft_token_num) + tid) * (depth + 2)] = bid * draft_token_num;
|
|
||||||
}
|
}
|
||||||
|
accept_token_num[bx] = num_accepted_tokens;
|
||||||
|
predicts[last_accepted_retrive_idx] = target_predict[last_accepted_retrive_idx];
|
||||||
}
|
}
|
||||||
|
|
||||||
void build_tree_kernel(
|
// predicts: [tot_num_draft_tokens]
|
||||||
at::Tensor parent_list,
|
// accept_index: [bs, num_spec_step]
|
||||||
at::Tensor selected_index,
|
// accept_token_num: [bs]
|
||||||
at::Tensor verified_seq_len,
|
// candidates: [bs, num_draft_tokens]
|
||||||
at::Tensor tree_mask,
|
// retrive_index: [bs, num_draft_tokens]
|
||||||
at::Tensor positions,
|
// retrive_next_token: [bs, num_draft_tokens]
|
||||||
|
// retrive_next_sibling: [bs, num_draft_tokens]
|
||||||
|
// target_predict: [bs, num_draft_tokens]
|
||||||
|
void verify_tree_greedy(
|
||||||
|
at::Tensor predicts,
|
||||||
|
at::Tensor accept_index,
|
||||||
|
at::Tensor accept_token_num, // mutable
|
||||||
|
at::Tensor candidates,
|
||||||
at::Tensor retrive_index,
|
at::Tensor retrive_index,
|
||||||
int64_t topk,
|
at::Tensor retrive_next_token,
|
||||||
int64_t depth,
|
at::Tensor retrive_next_sibling,
|
||||||
int64_t draft_token_num) {
|
at::Tensor target_predict,
|
||||||
// TODO (ying) check shape
|
int64_t cuda_stream = 0) {
|
||||||
// TODO (ying) check type
|
CHECK_INPUT(candidates);
|
||||||
int bs = parent_list.size(0);
|
CHECK_INPUT(retrive_index);
|
||||||
dim3 grid(bs);
|
CHECK_INPUT(retrive_next_token);
|
||||||
dim3 block(draft_token_num);
|
CHECK_INPUT(retrive_next_sibling);
|
||||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
CHECK_INPUT(target_predict);
|
||||||
|
auto device = target_predict.device();
|
||||||
|
CHECK_EQ(candidates.device(), device);
|
||||||
|
CHECK_EQ(retrive_index.device(), device);
|
||||||
|
CHECK_EQ(retrive_next_token.device(), device);
|
||||||
|
CHECK_EQ(retrive_next_sibling.device(), device);
|
||||||
|
CHECK_EQ(target_predict.device(), device);
|
||||||
|
CHECK_DIM(1, predicts);
|
||||||
|
CHECK_DIM(2, accept_index);
|
||||||
|
CHECK_DIM(1, accept_token_num);
|
||||||
|
CHECK_DIM(2, candidates);
|
||||||
|
CHECK_DIM(2, retrive_index);
|
||||||
|
CHECK_DIM(2, retrive_next_token);
|
||||||
|
CHECK_DIM(2, retrive_next_sibling);
|
||||||
|
CHECK_DIM(2, target_predict);
|
||||||
|
unsigned int batch_size = candidates.size(0);
|
||||||
|
unsigned int num_spec_step = accept_index.size(1);
|
||||||
|
unsigned int num_draft_tokens = candidates.size(1);
|
||||||
|
CHECK_EQ(batch_size, accept_index.size(0));
|
||||||
|
CHECK_EQ(batch_size, accept_token_num.size(0));
|
||||||
|
CHECK_EQ(batch_size, retrive_index.size(0));
|
||||||
|
CHECK_EQ(batch_size, retrive_next_token.size(0));
|
||||||
|
CHECK_EQ(batch_size, retrive_next_sibling.size(0));
|
||||||
|
CHECK_EQ(batch_size, target_predict.size(0));
|
||||||
|
CHECK_EQ(num_draft_tokens, retrive_index.size(1));
|
||||||
|
CHECK_EQ(num_draft_tokens, retrive_next_token.size(1));
|
||||||
|
CHECK_EQ(num_draft_tokens, retrive_next_sibling.size(1));
|
||||||
|
CHECK_EQ(num_draft_tokens, target_predict.size(1));
|
||||||
|
CHECK_EQ(batch_size, accept_index.size(0));
|
||||||
|
CHECK_EQ(batch_size, accept_token_num.size(0));
|
||||||
|
if (predicts.scalar_type() != at::kInt) {
|
||||||
|
throw std::runtime_error("Expected 'predicts' to be of type int (torch.int32).");
|
||||||
|
}
|
||||||
|
if (accept_index.scalar_type() != at::kInt) {
|
||||||
|
throw std::runtime_error("Expected 'accept_index' to be of type int (torch.int32).");
|
||||||
|
}
|
||||||
|
if (accept_token_num.scalar_type() != at::kInt) {
|
||||||
|
throw std::runtime_error("Expected 'accept_token_num' to be of type int (torch.int32).");
|
||||||
|
}
|
||||||
|
if (candidates.scalar_type() != at::kInt) {
|
||||||
|
throw std::runtime_error("Expected 'candidates' to be of type int (torch.int32).");
|
||||||
|
}
|
||||||
|
if (retrive_index.scalar_type() != at::kInt) {
|
||||||
|
throw std::runtime_error("Expected 'retrive_index' to be of type int (torch.int32).");
|
||||||
|
}
|
||||||
|
if (retrive_next_token.scalar_type() != at::kInt) {
|
||||||
|
throw std::runtime_error("Expected 'retrive_next_token' to be of type int (torch.int32).");
|
||||||
|
}
|
||||||
|
if (retrive_next_sibling.scalar_type() != at::kInt) {
|
||||||
|
throw std::runtime_error("Expected 'retrive_next_sibling' to be of type int (torch.int32).");
|
||||||
|
}
|
||||||
|
if (target_predict.scalar_type() != at::kInt) {
|
||||||
|
throw std::runtime_error("Expected 'target_predict' to be of type int (torch.int32).");
|
||||||
|
}
|
||||||
|
|
||||||
build_tree<<<grid, block, 0, stream>>>(
|
cudaStream_t stream = reinterpret_cast<cudaStream_t>(cuda_stream);
|
||||||
static_cast<int64_t*>(parent_list.data_ptr()),
|
dim3 grid(batch_size);
|
||||||
static_cast<int64_t*>(selected_index.data_ptr()),
|
dim3 block(1);
|
||||||
static_cast<int32_t*>(verified_seq_len.data_ptr()),
|
|
||||||
static_cast<bool*>(tree_mask.data_ptr()),
|
VerifyTreeGreedy<int><<<grid, block, 0, stream>>>(
|
||||||
static_cast<int64_t*>(positions.data_ptr()),
|
static_cast<int*>(predicts.data_ptr()),
|
||||||
static_cast<int64_t*>(retrive_index.data_ptr()),
|
static_cast<int*>(accept_index.data_ptr()),
|
||||||
int32_t(topk),
|
static_cast<int*>(accept_token_num.data_ptr()),
|
||||||
int32_t(depth),
|
static_cast<int*>(candidates.data_ptr()),
|
||||||
int32_t(draft_token_num));
|
static_cast<int*>(retrive_index.data_ptr()),
|
||||||
|
static_cast<int*>(retrive_next_token.data_ptr()),
|
||||||
|
static_cast<int*>(retrive_next_sibling.data_ptr()),
|
||||||
|
static_cast<int*>(target_predict.data_ptr()),
|
||||||
|
batch_size,
|
||||||
|
num_spec_step,
|
||||||
|
num_draft_tokens);
|
||||||
}
|
}
|
||||||
|
|||||||
47
sgl-kernel/csrc/speculative/packbit.cu
Normal file
47
sgl-kernel/csrc/speculative/packbit.cu
Normal file
@@ -0,0 +1,47 @@
|
|||||||
|
// This is only a pluggin used for flashinfer 0.1.6. The new version does not need it.
|
||||||
|
/*
|
||||||
|
* Copyright (c) 2025 by SGLang team.
|
||||||
|
* Copyright (c) 2025 by FlashInfer team.
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#include <flashinfer/quantization.cuh>
|
||||||
|
|
||||||
|
#include "pytorch_extension_utils.h"
|
||||||
|
|
||||||
|
using namespace flashinfer;
|
||||||
|
|
||||||
|
// bitorder = "little"
|
||||||
|
void segment_packbits(
|
||||||
|
at::Tensor x, at::Tensor input_indptr, at::Tensor output_indptr, at::Tensor y, int64_t cuda_stream) {
|
||||||
|
CHECK_INPUT(x);
|
||||||
|
CHECK_INPUT(input_indptr);
|
||||||
|
CHECK_INPUT(output_indptr);
|
||||||
|
auto device = x.device();
|
||||||
|
CHECK_EQ(input_indptr.device(), device);
|
||||||
|
CHECK_EQ(output_indptr.device(), device);
|
||||||
|
CHECK_EQ(y.device(), device);
|
||||||
|
unsigned int batch_size = input_indptr.size(0) - 1;
|
||||||
|
CHECK_EQ(output_indptr.size(0), batch_size + 1);
|
||||||
|
|
||||||
|
cudaStream_t stream = reinterpret_cast<cudaStream_t>(cuda_stream);
|
||||||
|
cudaError_t status = quantization::SegmentPackBits(
|
||||||
|
static_cast<bool*>(x.data_ptr()),
|
||||||
|
static_cast<uint8_t*>(y.data_ptr()),
|
||||||
|
static_cast<int32_t*>(input_indptr.data_ptr()),
|
||||||
|
static_cast<int32_t*>(output_indptr.data_ptr()),
|
||||||
|
batch_size,
|
||||||
|
quantization::BitOrder::kLittle,
|
||||||
|
stream);
|
||||||
|
}
|
||||||
@@ -14,7 +14,6 @@
|
|||||||
* See the License for the specific language governing permissions and
|
* See the License for the specific language governing permissions and
|
||||||
* limitations under the License.
|
* limitations under the License.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
#include "pytorch_extension_utils.h"
|
#include "pytorch_extension_utils.h"
|
||||||
#include "speculative_sampling.cuh"
|
#include "speculative_sampling.cuh"
|
||||||
|
|
||||||
@@ -40,7 +39,9 @@ void tree_speculative_sampling_target_only(
|
|||||||
at::Tensor uniform_samples,
|
at::Tensor uniform_samples,
|
||||||
at::Tensor target_probs,
|
at::Tensor target_probs,
|
||||||
at::Tensor draft_probs,
|
at::Tensor draft_probs,
|
||||||
bool deterministic,
|
double threshold_single,
|
||||||
|
double threshold_acc,
|
||||||
|
bool deterministic = true,
|
||||||
int64_t cuda_stream = 0) {
|
int64_t cuda_stream = 0) {
|
||||||
CHECK_INPUT(candidates);
|
CHECK_INPUT(candidates);
|
||||||
CHECK_INPUT(retrive_index);
|
CHECK_INPUT(retrive_index);
|
||||||
@@ -112,6 +113,10 @@ void tree_speculative_sampling_target_only(
|
|||||||
if (draft_probs.scalar_type() != at::kFloat) {
|
if (draft_probs.scalar_type() != at::kFloat) {
|
||||||
throw std::runtime_error("Expected 'target_probs' to be of type float (torch.float32).");
|
throw std::runtime_error("Expected 'target_probs' to be of type float (torch.float32).");
|
||||||
}
|
}
|
||||||
|
CHECK_GE(threshold_single, 0);
|
||||||
|
CHECK_GE(1, threshold_single);
|
||||||
|
CHECK_GE(threshold_acc, 0);
|
||||||
|
CHECK_GE(1, threshold_acc);
|
||||||
|
|
||||||
cudaStream_t stream = reinterpret_cast<cudaStream_t>(cuda_stream);
|
cudaStream_t stream = reinterpret_cast<cudaStream_t>(cuda_stream);
|
||||||
cudaError_t status = sampling::TreeSpeculativeSamplingTargetOnly<float, int>(
|
cudaError_t status = sampling::TreeSpeculativeSamplingTargetOnly<float, int>(
|
||||||
@@ -129,6 +134,8 @@ void tree_speculative_sampling_target_only(
|
|||||||
num_spec_step,
|
num_spec_step,
|
||||||
num_draft_tokens,
|
num_draft_tokens,
|
||||||
vocab_size,
|
vocab_size,
|
||||||
|
static_cast<float>(threshold_single),
|
||||||
|
static_cast<float>(threshold_acc),
|
||||||
deterministic,
|
deterministic,
|
||||||
stream);
|
stream);
|
||||||
|
|
||||||
|
|||||||
@@ -49,7 +49,9 @@ __global__ void TreeSpeculativeSamplingTargetOnly(
|
|||||||
uint32_t batch_size,
|
uint32_t batch_size,
|
||||||
uint32_t num_speculative_tokens,
|
uint32_t num_speculative_tokens,
|
||||||
uint32_t num_draft_tokens,
|
uint32_t num_draft_tokens,
|
||||||
uint32_t d) {
|
uint32_t d,
|
||||||
|
DType threshold_single,
|
||||||
|
DType threshold_acc) {
|
||||||
const uint32_t bx = blockIdx.x, tx = threadIdx.x;
|
const uint32_t bx = blockIdx.x, tx = threadIdx.x;
|
||||||
|
|
||||||
extern __shared__ __align__(alignof(SamplingTempStorage<DType, BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM>))
|
extern __shared__ __align__(alignof(SamplingTempStorage<DType, BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM>))
|
||||||
@@ -70,9 +72,10 @@ __global__ void TreeSpeculativeSamplingTargetOnly(
|
|||||||
while (cur_index != -1) {
|
while (cur_index != -1) {
|
||||||
IdType draft_index = retrive_index[bx * num_draft_tokens + cur_index];
|
IdType draft_index = retrive_index[bx * num_draft_tokens + cur_index];
|
||||||
IdType draft_token_id = candidates[bx * num_draft_tokens + cur_index];
|
IdType draft_token_id = candidates[bx * num_draft_tokens + cur_index];
|
||||||
prob_acc += target_probs[cur_prob_offset + draft_token_id];
|
DType target_prob_single = target_probs[cur_prob_offset + draft_token_id];
|
||||||
|
prob_acc += target_prob_single;
|
||||||
|
|
||||||
if (coin < prob_acc) {
|
if (coin <= prob_acc / threshold_acc || target_prob_single >= threshold_single) {
|
||||||
// accept token
|
// accept token
|
||||||
prob_acc = 0.;
|
prob_acc = 0.;
|
||||||
cur_prob_offset = (bx * num_draft_tokens + cur_index) * d;
|
cur_prob_offset = (bx * num_draft_tokens + cur_index) * d;
|
||||||
@@ -169,7 +172,9 @@ cudaError_t TreeSpeculativeSamplingTargetOnly(
|
|||||||
uint32_t num_speculative_tokens,
|
uint32_t num_speculative_tokens,
|
||||||
uint32_t num_draft_tokens,
|
uint32_t num_draft_tokens,
|
||||||
uint32_t d,
|
uint32_t d,
|
||||||
bool deterministic,
|
DType threshold_single = 1,
|
||||||
|
DType threshold_acc = 1,
|
||||||
|
bool deterministic = true,
|
||||||
cudaStream_t stream = 0) {
|
cudaStream_t stream = 0) {
|
||||||
constexpr uint32_t BLOCK_THREADS = 1024;
|
constexpr uint32_t BLOCK_THREADS = 1024;
|
||||||
const uint32_t vec_size = std::gcd(16 / sizeof(DType), d);
|
const uint32_t vec_size = std::gcd(16 / sizeof(DType), d);
|
||||||
@@ -177,6 +182,7 @@ cudaError_t TreeSpeculativeSamplingTargetOnly(
|
|||||||
const uint32_t smem_size = sizeof(SamplingTempStorage<DType, BLOCK_THREADS, SCAN_ALGO, REDUCE_ALGO>);
|
const uint32_t smem_size = sizeof(SamplingTempStorage<DType, BLOCK_THREADS, SCAN_ALGO, REDUCE_ALGO>);
|
||||||
dim3 nblks(batch_size);
|
dim3 nblks(batch_size);
|
||||||
dim3 nthrs(BLOCK_THREADS);
|
dim3 nthrs(BLOCK_THREADS);
|
||||||
|
float capped_threshold_acc = fmaxf(threshold_acc, 1e-9f);
|
||||||
void* args[] = {
|
void* args[] = {
|
||||||
&predicts,
|
&predicts,
|
||||||
&output_token_ids,
|
&output_token_ids,
|
||||||
@@ -191,7 +197,9 @@ cudaError_t TreeSpeculativeSamplingTargetOnly(
|
|||||||
&batch_size,
|
&batch_size,
|
||||||
&num_speculative_tokens,
|
&num_speculative_tokens,
|
||||||
&num_draft_tokens,
|
&num_draft_tokens,
|
||||||
&d};
|
&d,
|
||||||
|
&threshold_single,
|
||||||
|
&capped_threshold_acc};
|
||||||
DISPATCH_ALIGNED_VEC_SIZE(
|
DISPATCH_ALIGNED_VEC_SIZE(
|
||||||
vec_size, VEC_SIZE, {DISPATCH_DETERMINISTIC(deterministic, DETERMINISTIC, {
|
vec_size, VEC_SIZE, {DISPATCH_DETERMINISTIC(deterministic, DETERMINISTIC, {
|
||||||
auto kernel = TreeSpeculativeSamplingTargetOnly<
|
auto kernel = TreeSpeculativeSamplingTargetOnly<
|
||||||
|
|||||||
@@ -129,21 +129,24 @@ TORCH_LIBRARY_EXPAND(sgl_kernel, m) {
|
|||||||
"tree_speculative_sampling_target_only(Tensor! predicts, Tensor! accept_index, Tensor! accept_token_num, "
|
"tree_speculative_sampling_target_only(Tensor! predicts, Tensor! accept_index, Tensor! accept_token_num, "
|
||||||
"Tensor candidates, Tensor retrive_index, Tensor retrive_next_token, Tensor retrive_next_sibling, "
|
"Tensor candidates, Tensor retrive_index, Tensor retrive_next_token, Tensor retrive_next_sibling, "
|
||||||
"Tensor uniform_samples, Tensor target_probs, Tensor draft_probs, "
|
"Tensor uniform_samples, Tensor target_probs, Tensor draft_probs, "
|
||||||
|
"float threshold_single, float threshold_acc, "
|
||||||
"bool deterministic, int cuda_stream) -> ()");
|
"bool deterministic, int cuda_stream) -> ()");
|
||||||
m.impl("tree_speculative_sampling_target_only", torch::kCUDA, &tree_speculative_sampling_target_only);
|
m.impl("tree_speculative_sampling_target_only", torch::kCUDA, &tree_speculative_sampling_target_only);
|
||||||
|
|
||||||
m.def(
|
m.def(
|
||||||
"build_tree_kernel_efficient(Tensor parent_list, Tensor selected_index, Tensor verified_seq_len, "
|
"verify_tree_greedy(Tensor! predicts, Tensor! accept_index, Tensor! accept_token_num, "
|
||||||
"Tensor! tree_mask, Tensor! positions, Tensor! retrive_index, Tensor! retrive_next_token, Tensor! "
|
"Tensor candidates, Tensor retrive_index, Tensor retrive_next_token, Tensor retrive_next_sibling, "
|
||||||
"retrive_next_sibling, "
|
"Tensor target_predict, int cuda_stream) -> ()");
|
||||||
"int topk, int depth, int draft_token_num) -> ()");
|
m.impl("verify_tree_greedy", torch::kCUDA, &verify_tree_greedy);
|
||||||
m.impl("build_tree_kernel_efficient", torch::kCUDA, &build_tree_kernel_efficient);
|
|
||||||
|
|
||||||
m.def(
|
m.def(
|
||||||
"build_tree_kernel(Tensor parent_list, Tensor selected_index, Tensor verified_seq_len, "
|
"build_tree_kernel_efficient(Tensor parent_list, Tensor selected_index, Tensor verified_seq_len, "
|
||||||
"Tensor! tree_mask, Tensor! positions, Tensor! retrive_index, "
|
"Tensor! tree_mask, Tensor! positions, Tensor! retrive_index, Tensor! retrive_next_token, "
|
||||||
"int topk, int depth, int draft_token_num) -> ()");
|
"Tensor! retrive_next_sibling, int topk, int depth, int draft_token_num) -> ()");
|
||||||
m.impl("build_tree_kernel", torch::kCUDA, &build_tree_kernel);
|
m.impl("build_tree_kernel_efficient", torch::kCUDA, &build_tree_kernel_efficient);
|
||||||
|
|
||||||
|
m.def("segment_packbits(Tensor x, Tensor input_indptr, Tensor output_indptr, Tensor! y, int cuda_stream) -> ()");
|
||||||
|
m.impl("segment_packbits", torch::kCUDA, &segment_packbits);
|
||||||
|
|
||||||
/*
|
/*
|
||||||
* From FlashInfer
|
* From FlashInfer
|
||||||
|
|||||||
@@ -183,8 +183,8 @@ void topk_softmax(
|
|||||||
* From csrc/speculative
|
* From csrc/speculative
|
||||||
*/
|
*/
|
||||||
void tree_speculative_sampling_target_only(
|
void tree_speculative_sampling_target_only(
|
||||||
at::Tensor predicts,
|
at::Tensor predicts, // mutable
|
||||||
at::Tensor accept_index,
|
at::Tensor accept_index, // mutable
|
||||||
at::Tensor accept_token_num, // mutable
|
at::Tensor accept_token_num, // mutable
|
||||||
at::Tensor candidates,
|
at::Tensor candidates,
|
||||||
at::Tensor retrive_index,
|
at::Tensor retrive_index,
|
||||||
@@ -193,9 +193,22 @@ void tree_speculative_sampling_target_only(
|
|||||||
at::Tensor uniform_samples,
|
at::Tensor uniform_samples,
|
||||||
at::Tensor target_probs,
|
at::Tensor target_probs,
|
||||||
at::Tensor draft_probs,
|
at::Tensor draft_probs,
|
||||||
|
double threshold_single = 1,
|
||||||
|
double threshold_acc = 1,
|
||||||
bool deterministic = true,
|
bool deterministic = true,
|
||||||
int64_t cuda_stream = 0);
|
int64_t cuda_stream = 0);
|
||||||
|
|
||||||
|
void verify_tree_greedy(
|
||||||
|
at::Tensor predicts, // mutable
|
||||||
|
at::Tensor accept_index, // mutable
|
||||||
|
at::Tensor accept_token_num, // mutable
|
||||||
|
at::Tensor candidates,
|
||||||
|
at::Tensor retrive_index,
|
||||||
|
at::Tensor retrive_next_token,
|
||||||
|
at::Tensor retrive_next_sibling,
|
||||||
|
at::Tensor target_predict,
|
||||||
|
int64_t cuda_stream = 0);
|
||||||
|
|
||||||
void build_tree_kernel_efficient(
|
void build_tree_kernel_efficient(
|
||||||
at::Tensor parent_list,
|
at::Tensor parent_list,
|
||||||
at::Tensor selected_index,
|
at::Tensor selected_index,
|
||||||
@@ -209,16 +222,8 @@ void build_tree_kernel_efficient(
|
|||||||
int64_t depth,
|
int64_t depth,
|
||||||
int64_t draft_token_num);
|
int64_t draft_token_num);
|
||||||
|
|
||||||
void build_tree_kernel(
|
void segment_packbits(
|
||||||
at::Tensor parent_list,
|
at::Tensor x, at::Tensor input_indptr, at::Tensor output_indptr, at::Tensor y, int64_t cuda_stream);
|
||||||
at::Tensor selected_index,
|
|
||||||
at::Tensor verified_seq_len,
|
|
||||||
at::Tensor tree_mask,
|
|
||||||
at::Tensor positions,
|
|
||||||
at::Tensor retrive_index,
|
|
||||||
int64_t topk,
|
|
||||||
int64_t depth,
|
|
||||||
int64_t draft_token_num);
|
|
||||||
|
|
||||||
/*
|
/*
|
||||||
* From FlashInfer
|
* From FlashInfer
|
||||||
|
|||||||
@@ -42,8 +42,13 @@ from sgl_kernel.sampling import (
|
|||||||
top_p_sampling_from_probs,
|
top_p_sampling_from_probs,
|
||||||
)
|
)
|
||||||
from sgl_kernel.speculative import (
|
from sgl_kernel.speculative import (
|
||||||
build_tree_kernel,
|
|
||||||
build_tree_kernel_efficient,
|
build_tree_kernel_efficient,
|
||||||
|
segment_packbits,
|
||||||
tree_speculative_sampling_target_only,
|
tree_speculative_sampling_target_only,
|
||||||
|
verify_tree_greedy,
|
||||||
)
|
)
|
||||||
from sgl_kernel.version import __version__
|
from sgl_kernel.version import __version__
|
||||||
|
|
||||||
|
build_tree_kernel = (
|
||||||
|
None # TODO(ying): remove this after updating the sglang python code.
|
||||||
|
)
|
||||||
|
|||||||
@@ -13,6 +13,8 @@ def tree_speculative_sampling_target_only(
|
|||||||
uniform_samples: torch.Tensor,
|
uniform_samples: torch.Tensor,
|
||||||
target_probs: torch.Tensor,
|
target_probs: torch.Tensor,
|
||||||
draft_probs: torch.Tensor,
|
draft_probs: torch.Tensor,
|
||||||
|
threshold_single: float = 1.0,
|
||||||
|
threshold_acc: float = 1.0,
|
||||||
deterministic: bool = True,
|
deterministic: bool = True,
|
||||||
) -> None:
|
) -> None:
|
||||||
torch.ops.sgl_kernel.tree_speculative_sampling_target_only(
|
torch.ops.sgl_kernel.tree_speculative_sampling_target_only(
|
||||||
@@ -26,11 +28,36 @@ def tree_speculative_sampling_target_only(
|
|||||||
uniform_samples,
|
uniform_samples,
|
||||||
target_probs,
|
target_probs,
|
||||||
draft_probs,
|
draft_probs,
|
||||||
|
threshold_single,
|
||||||
|
threshold_acc,
|
||||||
deterministic,
|
deterministic,
|
||||||
get_cuda_stream(),
|
get_cuda_stream(),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def verify_tree_greedy(
|
||||||
|
predicts: torch.Tensor, # mutable
|
||||||
|
accept_index: torch.Tensor, # mutable
|
||||||
|
accept_token_num: torch.Tensor, # mutable
|
||||||
|
candidates: torch.Tensor,
|
||||||
|
retrive_index: torch.Tensor,
|
||||||
|
retrive_next_token: torch.Tensor,
|
||||||
|
retrive_next_sibling: torch.Tensor,
|
||||||
|
target_predict: torch.Tensor,
|
||||||
|
) -> None:
|
||||||
|
torch.ops.sgl_kernel.verify_tree_greedy(
|
||||||
|
predicts,
|
||||||
|
accept_index,
|
||||||
|
accept_token_num,
|
||||||
|
candidates,
|
||||||
|
retrive_index,
|
||||||
|
retrive_next_token,
|
||||||
|
retrive_next_sibling,
|
||||||
|
target_predict,
|
||||||
|
get_cuda_stream(),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def build_tree_kernel_efficient(
|
def build_tree_kernel_efficient(
|
||||||
parent_list: torch.Tensor,
|
parent_list: torch.Tensor,
|
||||||
selected_index: torch.Tensor,
|
selected_index: torch.Tensor,
|
||||||
@@ -59,25 +86,16 @@ def build_tree_kernel_efficient(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def build_tree_kernel(
|
def segment_packbits(
|
||||||
parent_list: torch.Tensor,
|
x: torch.Tensor,
|
||||||
selected_index: torch.Tensor,
|
input_indptr: torch.Tensor,
|
||||||
verified_seq_len: torch.Tensor,
|
output_indptr: torch.Tensor,
|
||||||
tree_mask: torch.Tensor,
|
y: torch.Tensor,
|
||||||
positions: torch.Tensor,
|
|
||||||
retrive_index: torch.Tensor,
|
|
||||||
topk: int,
|
|
||||||
depth: int,
|
|
||||||
draft_token_num: int,
|
|
||||||
) -> None:
|
) -> None:
|
||||||
torch.ops.sgl_kernel.build_tree_kernel(
|
torch.ops.sgl_kernel.segment_packbits(
|
||||||
parent_list,
|
x,
|
||||||
selected_index,
|
input_indptr,
|
||||||
verified_seq_len,
|
output_indptr,
|
||||||
tree_mask,
|
y,
|
||||||
positions,
|
torch.cuda.current_stream().cuda_stream,
|
||||||
retrive_index,
|
|
||||||
topk,
|
|
||||||
depth,
|
|
||||||
draft_token_num,
|
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -209,6 +209,7 @@ sources = [
|
|||||||
"csrc/moe/moe_topk_softmax_kernels.cu",
|
"csrc/moe/moe_topk_softmax_kernels.cu",
|
||||||
"csrc/speculative/eagle_utils.cu",
|
"csrc/speculative/eagle_utils.cu",
|
||||||
"csrc/speculative/speculative_sampling.cu",
|
"csrc/speculative/speculative_sampling.cu",
|
||||||
|
"csrc/speculative/packbit.cu",
|
||||||
"csrc/torch_extension.cc",
|
"csrc/torch_extension.cc",
|
||||||
"3rdparty/flashinfer/csrc/norm.cu",
|
"3rdparty/flashinfer/csrc/norm.cu",
|
||||||
"3rdparty/flashinfer/csrc/renorm.cu",
|
"3rdparty/flashinfer/csrc/renorm.cu",
|
||||||
|
|||||||
98
sgl-kernel/tests/speculative/test_eagle_utils.py
Normal file
98
sgl-kernel/tests/speculative/test_eagle_utils.py
Normal file
@@ -0,0 +1,98 @@
|
|||||||
|
import torch
|
||||||
|
import torch.nn.functional as F
|
||||||
|
from sgl_kernel import verify_tree_greedy
|
||||||
|
|
||||||
|
|
||||||
|
def test_verify_tree_greedy():
|
||||||
|
candidates = torch.tensor(
|
||||||
|
[
|
||||||
|
[0, 1, 2, 3, 4, 5],
|
||||||
|
[7, 8, 9, 10, 11, 12],
|
||||||
|
],
|
||||||
|
dtype=torch.int32,
|
||||||
|
device="cuda",
|
||||||
|
)
|
||||||
|
retrive_index = torch.tensor(
|
||||||
|
[
|
||||||
|
[0, 1, 2, 3, 4, 5],
|
||||||
|
[6, 7, 8, 9, 10, 11],
|
||||||
|
],
|
||||||
|
dtype=torch.int32,
|
||||||
|
device="cuda",
|
||||||
|
)
|
||||||
|
retrive_next_token = torch.tensor(
|
||||||
|
[
|
||||||
|
[1, 2, -1, 4, 5, -1],
|
||||||
|
[4, 2, 3, -1, 5, -1],
|
||||||
|
],
|
||||||
|
dtype=torch.int32,
|
||||||
|
device="cuda",
|
||||||
|
)
|
||||||
|
retrive_next_sibling = torch.tensor(
|
||||||
|
[
|
||||||
|
[-1, 3, -1, -1, -1, -1],
|
||||||
|
[-1, -1, -1, -1, 1, -1],
|
||||||
|
],
|
||||||
|
dtype=torch.int32,
|
||||||
|
device="cuda",
|
||||||
|
)
|
||||||
|
|
||||||
|
target_logits = torch.full((2, 6, 20), 1, dtype=torch.float32, device="cuda")
|
||||||
|
target_logits[0, 0, 3] = 10
|
||||||
|
target_logits[0, 3, 4] = 10
|
||||||
|
target_logits[0, 4, 5] = 10
|
||||||
|
target_logits[1, 0, 11] = 10
|
||||||
|
target_logits[1, 4, 12] = 10
|
||||||
|
for i in range(target_logits.shape[0]):
|
||||||
|
for j in range(target_logits.shape[1]):
|
||||||
|
if torch.max(target_logits[i][j]) < 10:
|
||||||
|
target_logits[i][j][18] = 10
|
||||||
|
|
||||||
|
print(f"{target_logits=}")
|
||||||
|
target_predict = torch.argmax(target_logits, dim=-1).to(torch.int32)
|
||||||
|
predict_shape = (12,)
|
||||||
|
|
||||||
|
bs = candidates.shape[0]
|
||||||
|
num_spec_step = 4
|
||||||
|
num_draft_tokens = candidates.shape[1]
|
||||||
|
|
||||||
|
predicts = torch.full(
|
||||||
|
predict_shape, -1, dtype=torch.int32, device="cuda"
|
||||||
|
) # mutable
|
||||||
|
accept_index = torch.full(
|
||||||
|
(bs, num_spec_step), -1, dtype=torch.int32, device="cuda"
|
||||||
|
) # mutable
|
||||||
|
accept_token_num = torch.full((bs,), 0, dtype=torch.int32, device="cuda") # mutable
|
||||||
|
|
||||||
|
print(f"{candidates=}")
|
||||||
|
print(f"{retrive_index=}")
|
||||||
|
print(f"{retrive_next_token=}")
|
||||||
|
print(f"{retrive_next_sibling=}")
|
||||||
|
print(f"{target_predict=}")
|
||||||
|
|
||||||
|
verify_tree_greedy(
|
||||||
|
predicts=predicts,
|
||||||
|
accept_index=accept_index,
|
||||||
|
accept_token_num=accept_token_num,
|
||||||
|
candidates=candidates,
|
||||||
|
retrive_index=retrive_index,
|
||||||
|
retrive_next_token=retrive_next_token,
|
||||||
|
retrive_next_sibling=retrive_next_sibling,
|
||||||
|
target_predict=target_predict,
|
||||||
|
)
|
||||||
|
|
||||||
|
print(f"{predicts=}")
|
||||||
|
print(f"{accept_index=}")
|
||||||
|
print(f"{accept_token_num=}")
|
||||||
|
|
||||||
|
return predicts, accept_index, accept_token_num
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
predicts, accept_index, accept_token_num = test_verify_tree_greedy()
|
||||||
|
assert predicts.tolist() == [3, -1, -1, 4, 5, 18, 11, -1, -1, -1, 12, 18]
|
||||||
|
assert accept_index.tolist() == [
|
||||||
|
[0, 3, 4, 5],
|
||||||
|
[6, 10, 11, -1],
|
||||||
|
]
|
||||||
|
assert accept_token_num.tolist() == [3, 2]
|
||||||
@@ -3,7 +3,10 @@ import torch.nn.functional as F
|
|||||||
from sgl_kernel import tree_speculative_sampling_target_only
|
from sgl_kernel import tree_speculative_sampling_target_only
|
||||||
|
|
||||||
|
|
||||||
def test_tree_speculative_sampling_target_only():
|
def test_tree_speculative_sampling_target_only(threshold_single=1, threshold_acc=1):
|
||||||
|
print(
|
||||||
|
f"\n============= run test: {threshold_single=} {threshold_acc=} ==============\n"
|
||||||
|
)
|
||||||
candidates = torch.tensor(
|
candidates = torch.tensor(
|
||||||
[
|
[
|
||||||
[0, 1, 2, 3, 4, 5],
|
[0, 1, 2, 3, 4, 5],
|
||||||
@@ -37,7 +40,7 @@ def test_tree_speculative_sampling_target_only():
|
|||||||
device="cuda",
|
device="cuda",
|
||||||
)
|
)
|
||||||
|
|
||||||
target_logits = torch.zeros((2, 6, 20), dtype=torch.float32, device="cuda")
|
target_logits = torch.full((2, 6, 20), 1, dtype=torch.float32, device="cuda")
|
||||||
target_logits[0, 0, 3] = 10
|
target_logits[0, 0, 3] = 10
|
||||||
target_logits[0, 3, 4] = 10
|
target_logits[0, 3, 4] = 10
|
||||||
target_logits[0, 4, 5] = 10
|
target_logits[0, 4, 5] = 10
|
||||||
@@ -85,6 +88,8 @@ def test_tree_speculative_sampling_target_only():
|
|||||||
uniform_samples=coins,
|
uniform_samples=coins,
|
||||||
target_probs=target_probs,
|
target_probs=target_probs,
|
||||||
draft_probs=draft_probs,
|
draft_probs=draft_probs,
|
||||||
|
threshold_single=threshold_single,
|
||||||
|
threshold_acc=threshold_acc,
|
||||||
deterministic=True,
|
deterministic=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -92,6 +97,13 @@ def test_tree_speculative_sampling_target_only():
|
|||||||
print(f"{accept_index=}")
|
print(f"{accept_index=}")
|
||||||
print(f"{accept_token_num=}")
|
print(f"{accept_token_num=}")
|
||||||
|
|
||||||
|
return predicts, accept_index, accept_token_num
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
predicts, accept_index, accept_token_num = (
|
||||||
|
test_tree_speculative_sampling_target_only(threshold_single=1, threshold_acc=1)
|
||||||
|
)
|
||||||
assert predicts.tolist() == [3, -1, -1, 4, 5, 18, 11, -1, -1, -1, 12, 18]
|
assert predicts.tolist() == [3, -1, -1, 4, 5, 18, 11, -1, -1, -1, 12, 18]
|
||||||
assert accept_index.tolist() == [
|
assert accept_index.tolist() == [
|
||||||
[0, 3, 4, 5],
|
[0, 3, 4, 5],
|
||||||
@@ -99,6 +111,12 @@ def test_tree_speculative_sampling_target_only():
|
|||||||
]
|
]
|
||||||
assert accept_token_num.tolist() == [3, 2]
|
assert accept_token_num.tolist() == [3, 2]
|
||||||
|
|
||||||
|
predicts, accept_index, accept_token_num = (
|
||||||
if __name__ == "__main__":
|
test_tree_speculative_sampling_target_only(threshold_single=0, threshold_acc=0)
|
||||||
test_tree_speculative_sampling_target_only()
|
)
|
||||||
|
assert predicts.tolist() == [1, 2, 18, -1, -1, -1, 11, -1, -1, -1, 12, 18]
|
||||||
|
assert accept_index.tolist() == [
|
||||||
|
[0, 1, 2, -1],
|
||||||
|
[6, 10, 11, -1],
|
||||||
|
]
|
||||||
|
assert accept_token_num.tolist() == [2, 2]
|
||||||
Reference in New Issue
Block a user