Fix sampling for speculative decoding & simplify kernels (#7207)
This commit is contained in:
7
sgl-kernel/csrc/common_extension.cc
Executable file → Normal file
7
sgl-kernel/csrc/common_extension.cc
Executable file → Normal file
@@ -201,13 +201,14 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
|
|||||||
m.impl("shuffle_rows", torch::kCUDA, &shuffle_rows);
|
m.impl("shuffle_rows", torch::kCUDA, &shuffle_rows);
|
||||||
m.def("apply_shuffle_mul_sum(Tensor input, Tensor output, Tensor permutation, Tensor? factors) -> ()");
|
m.def("apply_shuffle_mul_sum(Tensor input, Tensor output, Tensor permutation, Tensor? factors) -> ()");
|
||||||
m.impl("apply_shuffle_mul_sum", torch::kCUDA, &apply_shuffle_mul_sum);
|
m.impl("apply_shuffle_mul_sum", torch::kCUDA, &apply_shuffle_mul_sum);
|
||||||
|
|
||||||
/*
|
/*
|
||||||
* From csrc/speculative
|
* From csrc/speculative
|
||||||
*/
|
*/
|
||||||
m.def(
|
m.def(
|
||||||
"tree_speculative_sampling_target_only(Tensor! predicts, Tensor! accept_index, Tensor! accept_token_num, "
|
"tree_speculative_sampling_target_only(Tensor! predicts, Tensor! accept_index, Tensor! accept_token_num, "
|
||||||
"Tensor candidates, Tensor retrive_index, Tensor retrive_next_token, Tensor retrive_next_sibling, "
|
"Tensor candidates, Tensor retrive_index, Tensor retrive_next_token, Tensor retrive_next_sibling, "
|
||||||
"Tensor uniform_samples, Tensor target_probs, Tensor draft_probs, "
|
"Tensor uniform_samples, Tensor uniform_samples_for_final_sampling, Tensor target_probs, Tensor draft_probs, "
|
||||||
"float threshold_single, float threshold_acc, "
|
"float threshold_single, float threshold_acc, "
|
||||||
"bool deterministic, int cuda_stream) -> ()");
|
"bool deterministic, int cuda_stream) -> ()");
|
||||||
m.impl("tree_speculative_sampling_target_only", torch::kCUDA, &tree_speculative_sampling_target_only);
|
m.impl("tree_speculative_sampling_target_only", torch::kCUDA, &tree_speculative_sampling_target_only);
|
||||||
@@ -224,7 +225,9 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
|
|||||||
"Tensor! retrive_next_sibling, int topk, int depth, int draft_token_num) -> ()");
|
"Tensor! retrive_next_sibling, int topk, int depth, int draft_token_num) -> ()");
|
||||||
m.impl("build_tree_kernel_efficient", torch::kCUDA, &build_tree_kernel_efficient);
|
m.impl("build_tree_kernel_efficient", torch::kCUDA, &build_tree_kernel_efficient);
|
||||||
|
|
||||||
m.def("segment_packbits(Tensor x, Tensor input_indptr, Tensor output_indptr, Tensor! y, int cuda_stream) -> ()");
|
m.def(
|
||||||
|
"segment_packbits(Tensor x, Tensor input_indptr, Tensor output_indptr, Tensor! y, int batch_size, "
|
||||||
|
"int cuda_stream) -> ()");
|
||||||
m.impl("segment_packbits", torch::kCUDA, &segment_packbits);
|
m.impl("segment_packbits", torch::kCUDA, &segment_packbits);
|
||||||
|
|
||||||
/*
|
/*
|
||||||
|
|||||||
@@ -32,7 +32,7 @@
|
|||||||
__global__ void build_tree_efficient(
|
__global__ void build_tree_efficient(
|
||||||
int64_t* parent_list,
|
int64_t* parent_list,
|
||||||
int64_t* selected_index,
|
int64_t* selected_index,
|
||||||
int32_t* verified_seq_len,
|
int64_t* verified_seq_len,
|
||||||
bool* tree_mask,
|
bool* tree_mask,
|
||||||
int64_t* positions,
|
int64_t* positions,
|
||||||
int64_t* retrive_index,
|
int64_t* retrive_index,
|
||||||
@@ -135,7 +135,7 @@ void build_tree_kernel_efficient(
|
|||||||
build_tree_efficient<<<grid, block, 0, stream>>>(
|
build_tree_efficient<<<grid, block, 0, stream>>>(
|
||||||
static_cast<int64_t*>(parent_list.data_ptr()),
|
static_cast<int64_t*>(parent_list.data_ptr()),
|
||||||
static_cast<int64_t*>(selected_index.data_ptr()),
|
static_cast<int64_t*>(selected_index.data_ptr()),
|
||||||
static_cast<int32_t*>(verified_seq_len.data_ptr()),
|
static_cast<int64_t*>(verified_seq_len.data_ptr()),
|
||||||
static_cast<bool*>(tree_mask.data_ptr()),
|
static_cast<bool*>(tree_mask.data_ptr()),
|
||||||
static_cast<int64_t*>(positions.data_ptr()),
|
static_cast<int64_t*>(positions.data_ptr()),
|
||||||
static_cast<int64_t*>(retrive_index.data_ptr()),
|
static_cast<int64_t*>(retrive_index.data_ptr()),
|
||||||
@@ -146,32 +146,32 @@ void build_tree_kernel_efficient(
|
|||||||
int32_t(draft_token_num));
|
int32_t(draft_token_num));
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename IdType>
|
template <typename IdType, typename IdType2>
|
||||||
__global__ void VerifyTreeGreedy(
|
__global__ void VerifyTreeGreedy(
|
||||||
IdType* predicts,
|
IdType* predicts,
|
||||||
IdType* accept_index,
|
IdType* accept_index,
|
||||||
IdType* accept_token_num, // mutable
|
IdType* accept_token_num, // mutable
|
||||||
IdType* candidates,
|
IdType2* candidates,
|
||||||
IdType* retrive_index,
|
IdType2* retrive_index,
|
||||||
IdType* retrive_next_token,
|
IdType2* retrive_next_token,
|
||||||
IdType* retrive_next_sibling,
|
IdType2* retrive_next_sibling,
|
||||||
IdType* target_predict,
|
IdType2* target_predict,
|
||||||
uint32_t batch_size,
|
uint32_t batch_size,
|
||||||
uint32_t num_speculative_tokens,
|
uint32_t num_speculative_tokens,
|
||||||
uint32_t num_draft_tokens) {
|
uint32_t num_draft_tokens) {
|
||||||
uint32_t bx = blockIdx.x;
|
uint32_t bx = blockIdx.x;
|
||||||
|
|
||||||
IdType last_accepted_retrive_idx = retrive_index[bx * num_draft_tokens];
|
IdType2 last_accepted_retrive_idx = retrive_index[bx * num_draft_tokens];
|
||||||
accept_index[bx * num_speculative_tokens] = last_accepted_retrive_idx;
|
accept_index[bx * num_speculative_tokens] = last_accepted_retrive_idx;
|
||||||
uint32_t num_accepted_tokens = 0;
|
uint32_t num_accepted_tokens = 0;
|
||||||
IdType cur_index = 0;
|
IdType2 cur_index = 0;
|
||||||
|
|
||||||
for (uint32_t j = 1; j < num_speculative_tokens; ++j) {
|
for (uint32_t j = 1; j < num_speculative_tokens; ++j) {
|
||||||
cur_index = retrive_next_token[bx * num_draft_tokens + cur_index];
|
cur_index = retrive_next_token[bx * num_draft_tokens + cur_index];
|
||||||
while (cur_index != -1) {
|
while (cur_index != -1) {
|
||||||
IdType draft_index = retrive_index[bx * num_draft_tokens + cur_index];
|
IdType2 draft_index = retrive_index[bx * num_draft_tokens + cur_index];
|
||||||
IdType draft_token_id = candidates[bx * num_draft_tokens + cur_index];
|
IdType2 draft_token_id = candidates[bx * num_draft_tokens + cur_index];
|
||||||
IdType target_token_id = target_predict[last_accepted_retrive_idx];
|
IdType2 target_token_id = target_predict[last_accepted_retrive_idx];
|
||||||
|
|
||||||
if (draft_token_id == target_token_id) {
|
if (draft_token_id == target_token_id) {
|
||||||
// accept token
|
// accept token
|
||||||
@@ -251,35 +251,35 @@ void verify_tree_greedy(
|
|||||||
if (accept_token_num.scalar_type() != at::kInt) {
|
if (accept_token_num.scalar_type() != at::kInt) {
|
||||||
throw std::runtime_error("Expected 'accept_token_num' to be of type int (torch.int32).");
|
throw std::runtime_error("Expected 'accept_token_num' to be of type int (torch.int32).");
|
||||||
}
|
}
|
||||||
if (candidates.scalar_type() != at::kInt) {
|
if (candidates.scalar_type() != at::kLong) {
|
||||||
throw std::runtime_error("Expected 'candidates' to be of type int (torch.int32).");
|
throw std::runtime_error("Expected 'candidates' to be of type long (torch.int64).");
|
||||||
}
|
}
|
||||||
if (retrive_index.scalar_type() != at::kInt) {
|
if (retrive_index.scalar_type() != at::kLong) {
|
||||||
throw std::runtime_error("Expected 'retrive_index' to be of type int (torch.int32).");
|
throw std::runtime_error("Expected 'retrive_index' to be of type long (torch.int64).");
|
||||||
}
|
}
|
||||||
if (retrive_next_token.scalar_type() != at::kInt) {
|
if (retrive_next_token.scalar_type() != at::kLong) {
|
||||||
throw std::runtime_error("Expected 'retrive_next_token' to be of type int (torch.int32).");
|
throw std::runtime_error("Expected 'retrive_next_token' to be of type long (torch.int64).");
|
||||||
}
|
}
|
||||||
if (retrive_next_sibling.scalar_type() != at::kInt) {
|
if (retrive_next_sibling.scalar_type() != at::kLong) {
|
||||||
throw std::runtime_error("Expected 'retrive_next_sibling' to be of type int (torch.int32).");
|
throw std::runtime_error("Expected 'retrive_next_sibling' to be of type long (torch.int64).");
|
||||||
}
|
}
|
||||||
if (target_predict.scalar_type() != at::kInt) {
|
if (target_predict.scalar_type() != at::kLong) {
|
||||||
throw std::runtime_error("Expected 'target_predict' to be of type int (torch.int32).");
|
throw std::runtime_error("Expected 'target_predict' to be of type long (torch.int64).");
|
||||||
}
|
}
|
||||||
|
|
||||||
cudaStream_t stream = reinterpret_cast<cudaStream_t>(cuda_stream);
|
cudaStream_t stream = reinterpret_cast<cudaStream_t>(cuda_stream);
|
||||||
dim3 grid(batch_size);
|
dim3 grid(batch_size);
|
||||||
dim3 block(1);
|
dim3 block(1);
|
||||||
|
|
||||||
VerifyTreeGreedy<int><<<grid, block, 0, stream>>>(
|
VerifyTreeGreedy<int32_t, int64_t><<<grid, block, 0, stream>>>(
|
||||||
static_cast<int*>(predicts.data_ptr()),
|
static_cast<int32_t*>(predicts.data_ptr()),
|
||||||
static_cast<int*>(accept_index.data_ptr()),
|
static_cast<int32_t*>(accept_index.data_ptr()),
|
||||||
static_cast<int*>(accept_token_num.data_ptr()),
|
static_cast<int32_t*>(accept_token_num.data_ptr()),
|
||||||
static_cast<int*>(candidates.data_ptr()),
|
static_cast<int64_t*>(candidates.data_ptr()),
|
||||||
static_cast<int*>(retrive_index.data_ptr()),
|
static_cast<int64_t*>(retrive_index.data_ptr()),
|
||||||
static_cast<int*>(retrive_next_token.data_ptr()),
|
static_cast<int64_t*>(retrive_next_token.data_ptr()),
|
||||||
static_cast<int*>(retrive_next_sibling.data_ptr()),
|
static_cast<int64_t*>(retrive_next_sibling.data_ptr()),
|
||||||
static_cast<int*>(target_predict.data_ptr()),
|
static_cast<int64_t*>(target_predict.data_ptr()),
|
||||||
batch_size,
|
batch_size,
|
||||||
num_spec_step,
|
num_spec_step,
|
||||||
num_draft_tokens);
|
num_draft_tokens);
|
||||||
|
|||||||
@@ -24,7 +24,12 @@ using namespace flashinfer;
|
|||||||
|
|
||||||
// bitorder = "little"
|
// bitorder = "little"
|
||||||
void segment_packbits(
|
void segment_packbits(
|
||||||
at::Tensor x, at::Tensor input_indptr, at::Tensor output_indptr, at::Tensor y, int64_t cuda_stream) {
|
at::Tensor x,
|
||||||
|
at::Tensor input_indptr,
|
||||||
|
at::Tensor output_indptr,
|
||||||
|
at::Tensor y,
|
||||||
|
int64_t batch_size,
|
||||||
|
int64_t cuda_stream) {
|
||||||
CHECK_INPUT(x);
|
CHECK_INPUT(x);
|
||||||
CHECK_INPUT(input_indptr);
|
CHECK_INPUT(input_indptr);
|
||||||
CHECK_INPUT(output_indptr);
|
CHECK_INPUT(output_indptr);
|
||||||
@@ -32,8 +37,7 @@ void segment_packbits(
|
|||||||
CHECK_EQ(input_indptr.device(), device);
|
CHECK_EQ(input_indptr.device(), device);
|
||||||
CHECK_EQ(output_indptr.device(), device);
|
CHECK_EQ(output_indptr.device(), device);
|
||||||
CHECK_EQ(y.device(), device);
|
CHECK_EQ(y.device(), device);
|
||||||
unsigned int batch_size = input_indptr.size(0) - 1;
|
CHECK_GE(output_indptr.size(0), batch_size + 1);
|
||||||
CHECK_EQ(output_indptr.size(0), batch_size + 1);
|
|
||||||
|
|
||||||
cudaStream_t stream = reinterpret_cast<cudaStream_t>(cuda_stream);
|
cudaStream_t stream = reinterpret_cast<cudaStream_t>(cuda_stream);
|
||||||
cudaError_t status = quantization::SegmentPackBits(
|
cudaError_t status = quantization::SegmentPackBits(
|
||||||
|
|||||||
@@ -37,6 +37,7 @@ void tree_speculative_sampling_target_only(
|
|||||||
at::Tensor retrive_next_token,
|
at::Tensor retrive_next_token,
|
||||||
at::Tensor retrive_next_sibling,
|
at::Tensor retrive_next_sibling,
|
||||||
at::Tensor uniform_samples,
|
at::Tensor uniform_samples,
|
||||||
|
at::Tensor uniform_samples_for_final_sampling,
|
||||||
at::Tensor target_probs,
|
at::Tensor target_probs,
|
||||||
at::Tensor draft_probs,
|
at::Tensor draft_probs,
|
||||||
double threshold_single,
|
double threshold_single,
|
||||||
@@ -48,6 +49,7 @@ void tree_speculative_sampling_target_only(
|
|||||||
CHECK_INPUT(retrive_next_token);
|
CHECK_INPUT(retrive_next_token);
|
||||||
CHECK_INPUT(retrive_next_sibling);
|
CHECK_INPUT(retrive_next_sibling);
|
||||||
CHECK_INPUT(uniform_samples);
|
CHECK_INPUT(uniform_samples);
|
||||||
|
CHECK_INPUT(uniform_samples_for_final_sampling);
|
||||||
CHECK_INPUT(target_probs);
|
CHECK_INPUT(target_probs);
|
||||||
auto device = target_probs.device();
|
auto device = target_probs.device();
|
||||||
CHECK_EQ(candidates.device(), device);
|
CHECK_EQ(candidates.device(), device);
|
||||||
@@ -55,6 +57,7 @@ void tree_speculative_sampling_target_only(
|
|||||||
CHECK_EQ(retrive_next_token.device(), device);
|
CHECK_EQ(retrive_next_token.device(), device);
|
||||||
CHECK_EQ(retrive_next_sibling.device(), device);
|
CHECK_EQ(retrive_next_sibling.device(), device);
|
||||||
CHECK_EQ(uniform_samples.device(), device);
|
CHECK_EQ(uniform_samples.device(), device);
|
||||||
|
CHECK_EQ(uniform_samples_for_final_sampling.device(), device);
|
||||||
CHECK_EQ(target_probs.device(), device);
|
CHECK_EQ(target_probs.device(), device);
|
||||||
CHECK_DIM(1, predicts);
|
CHECK_DIM(1, predicts);
|
||||||
CHECK_DIM(2, accept_index);
|
CHECK_DIM(2, accept_index);
|
||||||
@@ -92,21 +95,24 @@ void tree_speculative_sampling_target_only(
|
|||||||
if (accept_token_num.scalar_type() != at::kInt) {
|
if (accept_token_num.scalar_type() != at::kInt) {
|
||||||
throw std::runtime_error("Expected 'accept_token_num' to be of type int (torch.int32).");
|
throw std::runtime_error("Expected 'accept_token_num' to be of type int (torch.int32).");
|
||||||
}
|
}
|
||||||
if (candidates.scalar_type() != at::kInt) {
|
if (candidates.scalar_type() != at::kLong) {
|
||||||
throw std::runtime_error("Expected 'candidates' to be of type int (torch.int32).");
|
throw std::runtime_error("Expected 'candidates' to be of type long (torch.int64).");
|
||||||
}
|
}
|
||||||
if (retrive_index.scalar_type() != at::kInt) {
|
if (retrive_index.scalar_type() != at::kLong) {
|
||||||
throw std::runtime_error("Expected 'retrive_index' to be of type int (torch.int32).");
|
throw std::runtime_error("Expected 'retrive_index' to be of type long (torch.int64).");
|
||||||
}
|
}
|
||||||
if (retrive_next_token.scalar_type() != at::kInt) {
|
if (retrive_next_token.scalar_type() != at::kLong) {
|
||||||
throw std::runtime_error("Expected 'retrive_next_token' to be of type int (torch.int32).");
|
throw std::runtime_error("Expected 'retrive_next_token' to be of type long (torch.int64).");
|
||||||
}
|
}
|
||||||
if (retrive_next_sibling.scalar_type() != at::kInt) {
|
if (retrive_next_sibling.scalar_type() != at::kLong) {
|
||||||
throw std::runtime_error("Expected 'retrive_next_sibling' to be of type int (torch.int32).");
|
throw std::runtime_error("Expected 'retrive_next_sibling' to be of type long (torch.int64).");
|
||||||
}
|
}
|
||||||
if (uniform_samples.scalar_type() != at::kFloat) {
|
if (uniform_samples.scalar_type() != at::kFloat) {
|
||||||
throw std::runtime_error("Expected 'uniform_samples' to be of type float (torch.float32).");
|
throw std::runtime_error("Expected 'uniform_samples' to be of type float (torch.float32).");
|
||||||
}
|
}
|
||||||
|
if (uniform_samples_for_final_sampling.scalar_type() != at::kFloat) {
|
||||||
|
throw std::runtime_error("Expected 'uniform_samples_for_final_sampling' to be of type float (torch.float32).");
|
||||||
|
}
|
||||||
if (target_probs.scalar_type() != at::kFloat) {
|
if (target_probs.scalar_type() != at::kFloat) {
|
||||||
throw std::runtime_error("Expected 'target_probs' to be of type float (torch.float32).");
|
throw std::runtime_error("Expected 'target_probs' to be of type float (torch.float32).");
|
||||||
}
|
}
|
||||||
@@ -119,15 +125,16 @@ void tree_speculative_sampling_target_only(
|
|||||||
CHECK_GE(1, threshold_acc);
|
CHECK_GE(1, threshold_acc);
|
||||||
|
|
||||||
cudaStream_t stream = reinterpret_cast<cudaStream_t>(cuda_stream);
|
cudaStream_t stream = reinterpret_cast<cudaStream_t>(cuda_stream);
|
||||||
cudaError_t status = sampling::TreeSpeculativeSamplingTargetOnly<float, int>(
|
cudaError_t status = sampling::TreeSpeculativeSamplingTargetOnly<float, int32_t, int64_t>(
|
||||||
static_cast<int*>(predicts.data_ptr()),
|
static_cast<int32_t*>(predicts.data_ptr()),
|
||||||
static_cast<int*>(accept_index.data_ptr()),
|
static_cast<int32_t*>(accept_index.data_ptr()),
|
||||||
static_cast<int*>(accept_token_num.data_ptr()),
|
static_cast<int32_t*>(accept_token_num.data_ptr()),
|
||||||
static_cast<int*>(candidates.data_ptr()),
|
static_cast<int64_t*>(candidates.data_ptr()),
|
||||||
static_cast<int*>(retrive_index.data_ptr()),
|
static_cast<int64_t*>(retrive_index.data_ptr()),
|
||||||
static_cast<int*>(retrive_next_token.data_ptr()),
|
static_cast<int64_t*>(retrive_next_token.data_ptr()),
|
||||||
static_cast<int*>(retrive_next_sibling.data_ptr()),
|
static_cast<int64_t*>(retrive_next_sibling.data_ptr()),
|
||||||
static_cast<float*>(uniform_samples.data_ptr()),
|
static_cast<float*>(uniform_samples.data_ptr()),
|
||||||
|
static_cast<float*>(uniform_samples_for_final_sampling.data_ptr()),
|
||||||
static_cast<float*>(target_probs.data_ptr()),
|
static_cast<float*>(target_probs.data_ptr()),
|
||||||
static_cast<float*>(draft_probs.data_ptr()),
|
static_cast<float*>(draft_probs.data_ptr()),
|
||||||
batch_size,
|
batch_size,
|
||||||
|
|||||||
@@ -34,16 +34,18 @@ template <
|
|||||||
uint32_t VEC_SIZE,
|
uint32_t VEC_SIZE,
|
||||||
bool DETERMINISTIC,
|
bool DETERMINISTIC,
|
||||||
typename DType,
|
typename DType,
|
||||||
typename IdType>
|
typename IdType,
|
||||||
|
typename IdType2>
|
||||||
__global__ void TreeSpeculativeSamplingTargetOnly(
|
__global__ void TreeSpeculativeSamplingTargetOnly(
|
||||||
IdType* predicts, // mutable
|
IdType* predicts, // mutable
|
||||||
IdType* accept_index, // mutable
|
IdType* accept_index, // mutable
|
||||||
IdType* accept_token_num, // mutable
|
IdType* accept_token_num, // mutable
|
||||||
IdType* candidates,
|
IdType2* candidates,
|
||||||
IdType* retrive_index,
|
IdType2* retrive_index,
|
||||||
IdType* retrive_next_token,
|
IdType2* retrive_next_token,
|
||||||
IdType* retrive_next_sibling,
|
IdType2* retrive_next_sibling,
|
||||||
DType* uniform_samples,
|
DType* uniform_samples,
|
||||||
|
DType* uniform_samples_for_final_sampling,
|
||||||
DType* target_probs,
|
DType* target_probs,
|
||||||
DType* draft_probs,
|
DType* draft_probs,
|
||||||
uint32_t batch_size,
|
uint32_t batch_size,
|
||||||
@@ -62,16 +64,16 @@ __global__ void TreeSpeculativeSamplingTargetOnly(
|
|||||||
DType prob_acc = 0.0;
|
DType prob_acc = 0.0;
|
||||||
uint32_t cur_prob_offset = bx * num_draft_tokens * d;
|
uint32_t cur_prob_offset = bx * num_draft_tokens * d;
|
||||||
DType coin = uniform_samples[bx * num_draft_tokens];
|
DType coin = uniform_samples[bx * num_draft_tokens];
|
||||||
IdType last_accepted_retrive_idx = retrive_index[bx * num_draft_tokens];
|
IdType2 last_accepted_retrive_idx = retrive_index[bx * num_draft_tokens];
|
||||||
accept_index[bx * num_speculative_tokens] = last_accepted_retrive_idx;
|
accept_index[bx * num_speculative_tokens] = last_accepted_retrive_idx;
|
||||||
uint32_t num_accepted_tokens = 0;
|
uint32_t num_accepted_tokens = 0;
|
||||||
IdType cur_index = 0;
|
IdType2 cur_index = 0;
|
||||||
|
|
||||||
for (uint32_t j = 1; j < num_speculative_tokens; ++j) {
|
for (uint32_t j = 1; j < num_speculative_tokens; ++j) {
|
||||||
cur_index = retrive_next_token[bx * num_draft_tokens + cur_index];
|
cur_index = retrive_next_token[bx * num_draft_tokens + cur_index];
|
||||||
while (cur_index != -1) {
|
while (cur_index != -1) {
|
||||||
IdType draft_index = retrive_index[bx * num_draft_tokens + cur_index];
|
IdType2 draft_index = retrive_index[bx * num_draft_tokens + cur_index];
|
||||||
IdType draft_token_id = candidates[bx * num_draft_tokens + cur_index];
|
IdType2 draft_token_id = candidates[bx * num_draft_tokens + cur_index];
|
||||||
DType target_prob_single = target_probs[cur_prob_offset + draft_token_id];
|
DType target_prob_single = target_probs[cur_prob_offset + draft_token_id];
|
||||||
prob_acc += target_prob_single;
|
prob_acc += target_prob_single;
|
||||||
|
|
||||||
@@ -95,6 +97,9 @@ __global__ void TreeSpeculativeSamplingTargetOnly(
|
|||||||
}
|
}
|
||||||
accept_token_num[bx] = num_accepted_tokens;
|
accept_token_num[bx] = num_accepted_tokens;
|
||||||
|
|
||||||
|
// we need a different coin for the final sampling
|
||||||
|
coin = uniform_samples_for_final_sampling[bx];
|
||||||
|
|
||||||
// sample from relu(target_probs - draft_probs)
|
// sample from relu(target_probs - draft_probs)
|
||||||
DType sum_relu_q_minus_p(0);
|
DType sum_relu_q_minus_p(0);
|
||||||
vec_t<DType, VEC_SIZE> q_vec, p_vec;
|
vec_t<DType, VEC_SIZE> q_vec, p_vec;
|
||||||
@@ -156,16 +161,17 @@ __global__ void TreeSpeculativeSamplingTargetOnly(
|
|||||||
// value at not used indices are undefined
|
// value at not used indices are undefined
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename DType, typename IdType>
|
template <typename DType, typename IdType, typename IdType2>
|
||||||
cudaError_t TreeSpeculativeSamplingTargetOnly(
|
cudaError_t TreeSpeculativeSamplingTargetOnly(
|
||||||
IdType* predicts, // mutable
|
IdType* predicts, // mutable
|
||||||
IdType* output_token_ids, // mutable
|
IdType* output_token_ids, // mutable
|
||||||
IdType* output_accepted_token_num, // mutable
|
IdType* output_accepted_token_num, // mutable
|
||||||
IdType* candidates,
|
IdType2* candidates,
|
||||||
IdType* retrive_index,
|
IdType2* retrive_index,
|
||||||
IdType* retrive_next_token,
|
IdType2* retrive_next_token,
|
||||||
IdType* retrive_next_sibling,
|
IdType2* retrive_next_sibling,
|
||||||
DType* uniform_samples,
|
DType* uniform_samples,
|
||||||
|
DType* uniform_samples_for_final_sampling,
|
||||||
DType* target_probs,
|
DType* target_probs,
|
||||||
DType* draft_probs,
|
DType* draft_probs,
|
||||||
uint32_t batch_size,
|
uint32_t batch_size,
|
||||||
@@ -192,6 +198,7 @@ cudaError_t TreeSpeculativeSamplingTargetOnly(
|
|||||||
&retrive_next_token,
|
&retrive_next_token,
|
||||||
&retrive_next_sibling,
|
&retrive_next_sibling,
|
||||||
&uniform_samples,
|
&uniform_samples,
|
||||||
|
&uniform_samples_for_final_sampling,
|
||||||
&target_probs,
|
&target_probs,
|
||||||
&draft_probs,
|
&draft_probs,
|
||||||
&batch_size,
|
&batch_size,
|
||||||
@@ -209,7 +216,8 @@ cudaError_t TreeSpeculativeSamplingTargetOnly(
|
|||||||
VEC_SIZE,
|
VEC_SIZE,
|
||||||
DETERMINISTIC,
|
DETERMINISTIC,
|
||||||
DType,
|
DType,
|
||||||
IdType>;
|
IdType,
|
||||||
|
IdType2>;
|
||||||
FLASHINFER_CUDA_CALL(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
|
FLASHINFER_CUDA_CALL(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
|
||||||
FLASHINFER_CUDA_CALL(cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream));
|
FLASHINFER_CUDA_CALL(cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream));
|
||||||
})});
|
})});
|
||||||
|
|||||||
@@ -331,6 +331,7 @@ void tree_speculative_sampling_target_only(
|
|||||||
at::Tensor retrive_next_token,
|
at::Tensor retrive_next_token,
|
||||||
at::Tensor retrive_next_sibling,
|
at::Tensor retrive_next_sibling,
|
||||||
at::Tensor uniform_samples,
|
at::Tensor uniform_samples,
|
||||||
|
at::Tensor uniform_samples_for_final_sampling,
|
||||||
at::Tensor target_probs,
|
at::Tensor target_probs,
|
||||||
at::Tensor draft_probs,
|
at::Tensor draft_probs,
|
||||||
double threshold_single = 1,
|
double threshold_single = 1,
|
||||||
@@ -363,7 +364,12 @@ void build_tree_kernel_efficient(
|
|||||||
int64_t draft_token_num);
|
int64_t draft_token_num);
|
||||||
|
|
||||||
void segment_packbits(
|
void segment_packbits(
|
||||||
at::Tensor x, at::Tensor input_indptr, at::Tensor output_indptr, at::Tensor y, int64_t cuda_stream);
|
at::Tensor x,
|
||||||
|
at::Tensor input_indptr,
|
||||||
|
at::Tensor output_indptr,
|
||||||
|
at::Tensor y,
|
||||||
|
int64_t batch_size,
|
||||||
|
int64_t cuda_stream = 0);
|
||||||
|
|
||||||
/*
|
/*
|
||||||
* From FlashInfer
|
* From FlashInfer
|
||||||
|
|||||||
@@ -72,6 +72,7 @@ from sgl_kernel.speculative import (
|
|||||||
tree_speculative_sampling_target_only,
|
tree_speculative_sampling_target_only,
|
||||||
verify_tree_greedy,
|
verify_tree_greedy,
|
||||||
)
|
)
|
||||||
|
from sgl_kernel.top_k import fast_topk
|
||||||
from sgl_kernel.version import __version__
|
from sgl_kernel.version import __version__
|
||||||
|
|
||||||
build_tree_kernel = (
|
build_tree_kernel = (
|
||||||
|
|||||||
@@ -11,6 +11,7 @@ def tree_speculative_sampling_target_only(
|
|||||||
retrive_next_token: torch.Tensor,
|
retrive_next_token: torch.Tensor,
|
||||||
retrive_next_sibling: torch.Tensor,
|
retrive_next_sibling: torch.Tensor,
|
||||||
uniform_samples: torch.Tensor,
|
uniform_samples: torch.Tensor,
|
||||||
|
uniform_samples_for_final_sampling: torch.Tensor,
|
||||||
target_probs: torch.Tensor,
|
target_probs: torch.Tensor,
|
||||||
draft_probs: torch.Tensor,
|
draft_probs: torch.Tensor,
|
||||||
threshold_single: float = 1.0,
|
threshold_single: float = 1.0,
|
||||||
@@ -26,6 +27,7 @@ def tree_speculative_sampling_target_only(
|
|||||||
retrive_next_token,
|
retrive_next_token,
|
||||||
retrive_next_sibling,
|
retrive_next_sibling,
|
||||||
uniform_samples,
|
uniform_samples,
|
||||||
|
uniform_samples_for_final_sampling,
|
||||||
target_probs,
|
target_probs,
|
||||||
draft_probs,
|
draft_probs,
|
||||||
threshold_single,
|
threshold_single,
|
||||||
@@ -91,11 +93,13 @@ def segment_packbits(
|
|||||||
input_indptr: torch.Tensor,
|
input_indptr: torch.Tensor,
|
||||||
output_indptr: torch.Tensor,
|
output_indptr: torch.Tensor,
|
||||||
y: torch.Tensor,
|
y: torch.Tensor,
|
||||||
|
batch_size: int,
|
||||||
) -> None:
|
) -> None:
|
||||||
torch.ops.sgl_kernel.segment_packbits.default(
|
torch.ops.sgl_kernel.segment_packbits.default(
|
||||||
x,
|
x,
|
||||||
input_indptr,
|
input_indptr,
|
||||||
output_indptr,
|
output_indptr,
|
||||||
y,
|
y,
|
||||||
|
batch_size,
|
||||||
torch.cuda.current_stream().cuda_stream,
|
torch.cuda.current_stream().cuda_stream,
|
||||||
)
|
)
|
||||||
|
|||||||
11
sgl-kernel/python/sgl_kernel/top_k.py
Normal file
11
sgl-kernel/python/sgl_kernel/top_k.py
Normal file
@@ -0,0 +1,11 @@
|
|||||||
|
import torch
|
||||||
|
|
||||||
|
|
||||||
|
def fast_topk(values, topk, dim):
|
||||||
|
if topk == 1:
|
||||||
|
# Use max along the specified dimension to get both value and index
|
||||||
|
return torch.max(values, dim=dim, keepdim=True)
|
||||||
|
else:
|
||||||
|
# Use topk for efficiency with larger k values
|
||||||
|
# TODO: implement faster cuda kernels for large vocab sizes
|
||||||
|
return torch.topk(values, topk, dim=dim)
|
||||||
@@ -10,7 +10,7 @@ def test_verify_tree_greedy():
|
|||||||
[0, 1, 2, 3, 4, 5],
|
[0, 1, 2, 3, 4, 5],
|
||||||
[7, 8, 9, 10, 11, 12],
|
[7, 8, 9, 10, 11, 12],
|
||||||
],
|
],
|
||||||
dtype=torch.int32,
|
dtype=torch.int64,
|
||||||
device="cuda",
|
device="cuda",
|
||||||
)
|
)
|
||||||
retrive_index = torch.tensor(
|
retrive_index = torch.tensor(
|
||||||
@@ -18,7 +18,7 @@ def test_verify_tree_greedy():
|
|||||||
[0, 1, 2, 3, 4, 5],
|
[0, 1, 2, 3, 4, 5],
|
||||||
[6, 7, 8, 9, 10, 11],
|
[6, 7, 8, 9, 10, 11],
|
||||||
],
|
],
|
||||||
dtype=torch.int32,
|
dtype=torch.int64,
|
||||||
device="cuda",
|
device="cuda",
|
||||||
)
|
)
|
||||||
retrive_next_token = torch.tensor(
|
retrive_next_token = torch.tensor(
|
||||||
@@ -26,7 +26,7 @@ def test_verify_tree_greedy():
|
|||||||
[1, 2, -1, 4, 5, -1],
|
[1, 2, -1, 4, 5, -1],
|
||||||
[4, 2, 3, -1, 5, -1],
|
[4, 2, 3, -1, 5, -1],
|
||||||
],
|
],
|
||||||
dtype=torch.int32,
|
dtype=torch.int64,
|
||||||
device="cuda",
|
device="cuda",
|
||||||
)
|
)
|
||||||
retrive_next_sibling = torch.tensor(
|
retrive_next_sibling = torch.tensor(
|
||||||
@@ -34,7 +34,7 @@ def test_verify_tree_greedy():
|
|||||||
[-1, 3, -1, -1, -1, -1],
|
[-1, 3, -1, -1, -1, -1],
|
||||||
[-1, -1, -1, -1, 1, -1],
|
[-1, -1, -1, -1, 1, -1],
|
||||||
],
|
],
|
||||||
dtype=torch.int32,
|
dtype=torch.int64,
|
||||||
device="cuda",
|
device="cuda",
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -49,12 +49,11 @@ def test_verify_tree_greedy():
|
|||||||
if torch.max(target_logits[i][j]) < 10:
|
if torch.max(target_logits[i][j]) < 10:
|
||||||
target_logits[i][j][18] = 10
|
target_logits[i][j][18] = 10
|
||||||
|
|
||||||
target_predict = torch.argmax(target_logits, dim=-1).to(torch.int32)
|
target_predict = torch.argmax(target_logits, dim=-1)
|
||||||
predict_shape = (12,)
|
predict_shape = (12,)
|
||||||
|
|
||||||
bs = candidates.shape[0]
|
bs = candidates.shape[0]
|
||||||
num_spec_step = 4
|
num_spec_step = 4
|
||||||
num_draft_tokens = candidates.shape[1]
|
|
||||||
|
|
||||||
predicts = torch.full(
|
predicts = torch.full(
|
||||||
predict_shape, -1, dtype=torch.int32, device="cuda"
|
predict_shape, -1, dtype=torch.int32, device="cuda"
|
||||||
|
|||||||
@@ -42,7 +42,7 @@ def test_tree_speculative_sampling_target_only(
|
|||||||
[0, 1, 2, 3, 4, 5],
|
[0, 1, 2, 3, 4, 5],
|
||||||
[7, 8, 9, 10, 11, 12],
|
[7, 8, 9, 10, 11, 12],
|
||||||
],
|
],
|
||||||
dtype=torch.int32,
|
dtype=torch.int64,
|
||||||
device=device,
|
device=device,
|
||||||
)
|
)
|
||||||
retrive_index = torch.tensor(
|
retrive_index = torch.tensor(
|
||||||
@@ -50,7 +50,7 @@ def test_tree_speculative_sampling_target_only(
|
|||||||
[0, 1, 2, 3, 4, 5],
|
[0, 1, 2, 3, 4, 5],
|
||||||
[6, 7, 8, 9, 10, 11],
|
[6, 7, 8, 9, 10, 11],
|
||||||
],
|
],
|
||||||
dtype=torch.int32,
|
dtype=torch.int64,
|
||||||
device=device,
|
device=device,
|
||||||
)
|
)
|
||||||
retrive_next_token = torch.tensor(
|
retrive_next_token = torch.tensor(
|
||||||
@@ -58,7 +58,7 @@ def test_tree_speculative_sampling_target_only(
|
|||||||
[1, 2, -1, 4, 5, -1],
|
[1, 2, -1, 4, 5, -1],
|
||||||
[4, 2, 3, -1, 5, -1],
|
[4, 2, 3, -1, 5, -1],
|
||||||
],
|
],
|
||||||
dtype=torch.int32,
|
dtype=torch.int64,
|
||||||
device=device,
|
device=device,
|
||||||
)
|
)
|
||||||
retrive_next_sibling = torch.tensor(
|
retrive_next_sibling = torch.tensor(
|
||||||
@@ -66,7 +66,7 @@ def test_tree_speculative_sampling_target_only(
|
|||||||
[-1, 3, -1, -1, -1, -1],
|
[-1, 3, -1, -1, -1, -1],
|
||||||
[-1, -1, -1, -1, 1, -1],
|
[-1, -1, -1, -1, 1, -1],
|
||||||
],
|
],
|
||||||
dtype=torch.int32,
|
dtype=torch.int64,
|
||||||
device=device,
|
device=device,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -95,6 +95,7 @@ def test_tree_speculative_sampling_target_only(
|
|||||||
target_probs = F.softmax(target_logits / expanded_temperature, dim=-1)
|
target_probs = F.softmax(target_logits / expanded_temperature, dim=-1)
|
||||||
draft_probs = torch.full_like(target_probs, 0, dtype=torch.float32, device=device)
|
draft_probs = torch.full_like(target_probs, 0, dtype=torch.float32, device=device)
|
||||||
coins = torch.rand(bs, num_draft_tokens, device=device, dtype=torch.float32)
|
coins = torch.rand(bs, num_draft_tokens, device=device, dtype=torch.float32)
|
||||||
|
coins_for_final_sampling = torch.rand(bs, device=device).to(torch.float32)
|
||||||
|
|
||||||
tree_speculative_sampling_target_only(
|
tree_speculative_sampling_target_only(
|
||||||
predicts=predicts,
|
predicts=predicts,
|
||||||
@@ -105,6 +106,7 @@ def test_tree_speculative_sampling_target_only(
|
|||||||
retrive_next_token=retrive_next_token,
|
retrive_next_token=retrive_next_token,
|
||||||
retrive_next_sibling=retrive_next_sibling,
|
retrive_next_sibling=retrive_next_sibling,
|
||||||
uniform_samples=coins,
|
uniform_samples=coins,
|
||||||
|
uniform_samples_for_final_sampling=coins_for_final_sampling,
|
||||||
target_probs=target_probs,
|
target_probs=target_probs,
|
||||||
draft_probs=draft_probs,
|
draft_probs=draft_probs,
|
||||||
threshold_single=threshold_single,
|
threshold_single=threshold_single,
|
||||||
|
|||||||
Reference in New Issue
Block a user