sglangv0.5.2 & support Qwen3-Next-80B-A3B-Instruct

This commit is contained in:
maxiao1
2025-09-13 17:00:20 +08:00
commit 118f1fc726
2037 changed files with 515371 additions and 0 deletions

View File

@@ -0,0 +1,408 @@
/*
* Copyright (c) 2025 by SGLang team.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#ifndef USE_ROCM
#include "pytorch_extension_utils.h"
#else
#include "pytorch_extension_utils_rocm.h"
#endif
typedef enum { FULL_MASK = 0, QLEN_ONLY = 1, QLEN_ONLY_BITPACKING = 2 } TreeMaskMode;
// parent_list [bs, topk * (depth - 1) + 1)]
// selected_index [bs, draft_token_num - 1]
// verified_seq_len [bs]
// tree_mask [draft_token*(seq_len[0]+draft_token) | draft_token*(seq_len[1]+draft_token) | ..] =
// [sum(verified_seq_len)*draft_token+bs*draft_token*draft_token] positions [bs * draft_token] retrive_index [b,
// draft_token] retrive_next_token [b, draft_token] retrive_next_sibling [b, draft_token]
__global__ void build_tree_efficient(
int64_t* parent_list,
int64_t* selected_index,
int64_t* verified_seq_len,
bool* tree_mask,
int64_t* positions,
int64_t* retrive_index,
int64_t* retrive_next_token,
int64_t* retrive_next_sibling,
int topk,
int depth,
int draft_token_num,
int tree_mask_mode) {
int bid = blockIdx.x;
int tid = threadIdx.x;
if (tid >= draft_token_num) {
return;
}
int seq_tree_idx = draft_token_num * draft_token_num * bid;
for (int i = 0; i < bid; i++) {
seq_tree_idx += verified_seq_len[i] * draft_token_num;
}
int seq_len = verified_seq_len[bid];
int token_tree_idx;
if (tree_mask_mode == FULL_MASK) {
token_tree_idx = seq_tree_idx + (seq_len + draft_token_num) * tid + seq_len + 1;
} else {
token_tree_idx = draft_token_num * draft_token_num * bid + draft_token_num * tid + 1;
}
tree_mask[token_tree_idx - 1] = true;
for (int i = 0; i < draft_token_num - 1; i++) {
tree_mask[token_tree_idx + i] = false;
}
int position = 0;
if (tid == 0) {
positions[bid * draft_token_num] = seq_len;
int retrive_index_offset = bid * draft_token_num;
for (int i = draft_token_num - 1; i > 0; --i) {
int current_token_idx = retrive_index_offset + i;
retrive_index[bid * draft_token_num + i] = current_token_idx;
int parent_tb_idx = selected_index[bid * (draft_token_num - 1) + i - 1] / topk;
int parent_position = 0;
if (parent_tb_idx > 0) {
int parent_token_idx = parent_list[bid * (topk * (depth - 1) + 1) + parent_tb_idx];
for (; parent_position < draft_token_num; ++parent_position) {
if (selected_index[bid * (draft_token_num - 1) + parent_position] == parent_token_idx) {
++parent_position;
break;
}
}
}
if (parent_position == draft_token_num) {
printf(
"WARNING: invalid eagle tree!!! Detected a token with no parent token selected. "
"Please check if the logprob has nan. The token will be ignored to keep proceeding.\n");
continue;
}
if (retrive_next_token[bid * draft_token_num + parent_position] == -1) {
retrive_next_token[bid * draft_token_num + parent_position] = i;
} else {
int origin_next_token = retrive_next_token[bid * draft_token_num + parent_position];
retrive_next_token[bid * draft_token_num + parent_position] = i;
retrive_next_sibling[bid * draft_token_num + i] = origin_next_token;
}
}
retrive_index[bid * draft_token_num] = bid * draft_token_num;
} else {
int cur_position = tid - 1;
while (true) {
position += 1;
tree_mask[token_tree_idx + cur_position] = true;
int parent_tb_idx = selected_index[bid * (draft_token_num - 1) + cur_position] / topk;
if (parent_tb_idx == 0) {
break;
}
int token_idx = parent_list[bid * (topk * (depth - 1) + 1) + parent_tb_idx];
for (cur_position = 0; cur_position < draft_token_num; ++cur_position) {
if (selected_index[bid * (draft_token_num - 1) + cur_position] == token_idx) {
break;
}
}
}
positions[bid * draft_token_num + tid] = position + seq_len;
}
}
// parent_list [bs, topk * (depth - 1) + 1)]
// selected_index [bs, draft_token_num - 1]
// verified_seq_len [bs]
// tree_mask: [draft_token*num_bytes_per_item | .. ] = [bs*draft_token*num_bytes_per_item]
// positions [bs * draft_token]
// retrive_index [bs, draft_token]
// retrive_next_token [bs, draft_token]
// retrive_next_sibling [bs, draft_token]
__global__ void build_tree_efficient_partial_packed(
int64_t* parent_list,
int64_t* selected_index,
int64_t* verified_seq_len,
uint8_t* tree_mask,
int64_t* positions,
int64_t* retrive_index,
int64_t* retrive_next_token,
int64_t* retrive_next_sibling,
int topk,
int depth,
int draft_token_num,
size_t num_bytes_per_item) {
int bid = blockIdx.x;
int tid = threadIdx.x;
if (tid >= draft_token_num) {
return;
}
int seq_len = verified_seq_len[bid];
int token_tree_idx = (bid * draft_token_num + tid) * num_bytes_per_item;
tree_mask[token_tree_idx] = 1; // little endian
int position = 0;
if (tid == 0) {
positions[bid * draft_token_num] = seq_len;
int retrive_index_offset = bid * draft_token_num;
for (int i = draft_token_num - 1; i > 0; --i) {
int current_token_idx = retrive_index_offset + i;
retrive_index[bid * draft_token_num + i] = current_token_idx;
int parent_tb_idx = selected_index[bid * (draft_token_num - 1) + i - 1] / topk;
int parent_position = 0;
if (parent_tb_idx > 0) {
int parent_token_idx = parent_list[bid * (topk * (depth - 1) + 1) + parent_tb_idx];
for (; parent_position < draft_token_num; ++parent_position) {
if (selected_index[bid * (draft_token_num - 1) + parent_position] == parent_token_idx) {
++parent_position;
break;
}
}
}
if (parent_position == draft_token_num) {
printf(
"WARNING: invalid eagle tree!!! Detected a token with no parent token selected. "
"Please check if the logprob has nan. The token will be ignored to keep proceeding.\n");
continue;
}
if (retrive_next_token[bid * draft_token_num + parent_position] == -1) {
retrive_next_token[bid * draft_token_num + parent_position] = i;
} else {
int origin_next_token = retrive_next_token[bid * draft_token_num + parent_position];
retrive_next_token[bid * draft_token_num + parent_position] = i;
retrive_next_sibling[bid * draft_token_num + i] = origin_next_token;
}
}
retrive_index[bid * draft_token_num] = bid * draft_token_num;
} else {
int cur_position = tid - 1;
while (true) {
position += 1;
int byte_idx = (cur_position + 1) / 8;
int bit_idx = (cur_position + 1) % 8;
tree_mask[token_tree_idx + byte_idx] |= (1 << bit_idx);
int parent_tb_idx = selected_index[bid * (draft_token_num - 1) + cur_position] / topk;
if (parent_tb_idx == 0) {
break;
}
int token_idx = parent_list[bid * (topk * (depth - 1) + 1) + parent_tb_idx];
for (cur_position = 0; cur_position < draft_token_num; ++cur_position) {
if (selected_index[bid * (draft_token_num - 1) + cur_position] == token_idx) {
break;
}
}
}
positions[bid * draft_token_num + tid] = position + seq_len;
}
}
void build_tree_kernel_efficient(
at::Tensor parent_list,
at::Tensor selected_index,
at::Tensor verified_seq_len,
at::Tensor tree_mask,
at::Tensor positions,
at::Tensor retrive_index,
at::Tensor retrive_next_token,
at::Tensor retrive_next_sibling,
int64_t topk,
int64_t depth,
int64_t draft_token_num,
int64_t tree_mask_mode) {
// TODO (ying) check shape
// TODO (ying) check type
int bs = parent_list.size(0);
dim3 grid(bs);
dim3 block(draft_token_num);
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
if (tree_mask_mode == QLEN_ONLY_BITPACKING) {
size_t num_bytes_per_item = 1;
if (draft_token_num > 16) {
num_bytes_per_item = 4;
} else if (draft_token_num > 8) {
num_bytes_per_item = 2;
}
build_tree_efficient_partial_packed<<<grid, block, 0, stream>>>(
static_cast<int64_t*>(parent_list.data_ptr()),
static_cast<int64_t*>(selected_index.data_ptr()),
static_cast<int64_t*>(verified_seq_len.data_ptr()),
static_cast<uint8_t*>(tree_mask.data_ptr()),
static_cast<int64_t*>(positions.data_ptr()),
static_cast<int64_t*>(retrive_index.data_ptr()),
static_cast<int64_t*>(retrive_next_token.data_ptr()),
static_cast<int64_t*>(retrive_next_sibling.data_ptr()),
int32_t(topk),
int32_t(depth),
int32_t(draft_token_num),
num_bytes_per_item);
} else {
build_tree_efficient<<<grid, block, 0, stream>>>(
static_cast<int64_t*>(parent_list.data_ptr()),
static_cast<int64_t*>(selected_index.data_ptr()),
static_cast<int64_t*>(verified_seq_len.data_ptr()),
static_cast<bool*>(tree_mask.data_ptr()),
static_cast<int64_t*>(positions.data_ptr()),
static_cast<int64_t*>(retrive_index.data_ptr()),
static_cast<int64_t*>(retrive_next_token.data_ptr()),
static_cast<int64_t*>(retrive_next_sibling.data_ptr()),
int32_t(topk),
int32_t(depth),
int32_t(draft_token_num),
int32_t(tree_mask_mode));
}
}
template <typename IdType, typename IdType2>
__global__ void VerifyTreeGreedy(
IdType* predicts,
IdType* accept_index,
IdType* accept_token_num, // mutable
IdType2* candidates,
IdType2* retrive_index,
IdType2* retrive_next_token,
IdType2* retrive_next_sibling,
IdType2* target_predict,
uint32_t batch_size,
uint32_t num_speculative_tokens,
uint32_t num_draft_tokens) {
uint32_t bx = blockIdx.x;
IdType2 last_accepted_retrive_idx = retrive_index[bx * num_draft_tokens];
accept_index[bx * num_speculative_tokens] = last_accepted_retrive_idx;
uint32_t num_accepted_tokens = 0;
IdType2 cur_index = 0;
for (uint32_t j = 1; j < num_speculative_tokens; ++j) {
cur_index = retrive_next_token[bx * num_draft_tokens + cur_index];
while (cur_index != -1) {
IdType2 draft_index = retrive_index[bx * num_draft_tokens + cur_index];
IdType2 draft_token_id = candidates[bx * num_draft_tokens + cur_index];
IdType2 target_token_id = target_predict[last_accepted_retrive_idx];
if (draft_token_id == target_token_id) {
// accept token
predicts[last_accepted_retrive_idx] = target_token_id;
++num_accepted_tokens;
accept_index[bx * num_speculative_tokens + num_accepted_tokens] = draft_index;
last_accepted_retrive_idx = draft_index;
break;
} else {
cur_index = retrive_next_sibling[bx * num_draft_tokens + cur_index];
}
}
if (cur_index == -1) break;
}
accept_token_num[bx] = num_accepted_tokens;
predicts[last_accepted_retrive_idx] = target_predict[last_accepted_retrive_idx];
}
// predicts: [tot_num_draft_tokens]
// accept_index: [bs, num_spec_step]
// accept_token_num: [bs]
// candidates: [bs, num_draft_tokens]
// retrive_index: [bs, num_draft_tokens]
// retrive_next_token: [bs, num_draft_tokens]
// retrive_next_sibling: [bs, num_draft_tokens]
// target_predict: [bs, num_draft_tokens]
void verify_tree_greedy(
at::Tensor predicts,
at::Tensor accept_index,
at::Tensor accept_token_num, // mutable
at::Tensor candidates,
at::Tensor retrive_index,
at::Tensor retrive_next_token,
at::Tensor retrive_next_sibling,
at::Tensor target_predict,
int64_t cuda_stream = 0) {
CHECK_INPUT(candidates);
CHECK_INPUT(retrive_index);
CHECK_INPUT(retrive_next_token);
CHECK_INPUT(retrive_next_sibling);
CHECK_INPUT(target_predict);
auto device = target_predict.device();
CHECK_EQ(candidates.device(), device);
CHECK_EQ(retrive_index.device(), device);
CHECK_EQ(retrive_next_token.device(), device);
CHECK_EQ(retrive_next_sibling.device(), device);
CHECK_EQ(target_predict.device(), device);
CHECK_DIM(1, predicts);
CHECK_DIM(2, accept_index);
CHECK_DIM(1, accept_token_num);
CHECK_DIM(2, candidates);
CHECK_DIM(2, retrive_index);
CHECK_DIM(2, retrive_next_token);
CHECK_DIM(2, retrive_next_sibling);
CHECK_DIM(2, target_predict);
unsigned int batch_size = candidates.size(0);
unsigned int num_spec_step = accept_index.size(1);
unsigned int num_draft_tokens = candidates.size(1);
CHECK_EQ(batch_size, accept_index.size(0));
CHECK_EQ(batch_size, accept_token_num.size(0));
CHECK_EQ(batch_size, retrive_index.size(0));
CHECK_EQ(batch_size, retrive_next_token.size(0));
CHECK_EQ(batch_size, retrive_next_sibling.size(0));
CHECK_EQ(batch_size, target_predict.size(0));
CHECK_EQ(num_draft_tokens, retrive_index.size(1));
CHECK_EQ(num_draft_tokens, retrive_next_token.size(1));
CHECK_EQ(num_draft_tokens, retrive_next_sibling.size(1));
CHECK_EQ(num_draft_tokens, target_predict.size(1));
CHECK_EQ(batch_size, accept_index.size(0));
CHECK_EQ(batch_size, accept_token_num.size(0));
if (predicts.scalar_type() != at::kInt) {
throw std::runtime_error("Expected 'predicts' to be of type int (torch.int32).");
}
if (accept_index.scalar_type() != at::kInt) {
throw std::runtime_error("Expected 'accept_index' to be of type int (torch.int32).");
}
if (accept_token_num.scalar_type() != at::kInt) {
throw std::runtime_error("Expected 'accept_token_num' to be of type int (torch.int32).");
}
if (candidates.scalar_type() != at::kLong) {
throw std::runtime_error("Expected 'candidates' to be of type long (torch.int64).");
}
if (retrive_index.scalar_type() != at::kLong) {
throw std::runtime_error("Expected 'retrive_index' to be of type long (torch.int64).");
}
if (retrive_next_token.scalar_type() != at::kLong) {
throw std::runtime_error("Expected 'retrive_next_token' to be of type long (torch.int64).");
}
if (retrive_next_sibling.scalar_type() != at::kLong) {
throw std::runtime_error("Expected 'retrive_next_sibling' to be of type long (torch.int64).");
}
if (target_predict.scalar_type() != at::kLong) {
throw std::runtime_error("Expected 'target_predict' to be of type long (torch.int64).");
}
cudaStream_t stream = reinterpret_cast<cudaStream_t>(cuda_stream);
dim3 grid(batch_size);
dim3 block(1);
VerifyTreeGreedy<int32_t, int64_t><<<grid, block, 0, stream>>>(
static_cast<int32_t*>(predicts.data_ptr()),
static_cast<int32_t*>(accept_index.data_ptr()),
static_cast<int32_t*>(accept_token_num.data_ptr()),
static_cast<int64_t*>(candidates.data_ptr()),
static_cast<int64_t*>(retrive_index.data_ptr()),
static_cast<int64_t*>(retrive_next_token.data_ptr()),
static_cast<int64_t*>(retrive_next_sibling.data_ptr()),
static_cast<int64_t*>(target_predict.data_ptr()),
batch_size,
num_spec_step,
num_draft_tokens);
}

View File

@@ -0,0 +1,411 @@
// !!! This is a file automatically generated by hipify!!!
#include <ATen/dtk_macros.h>
#include "hip/hip_runtime.h"
/*
* Copyright (c) 2025 by SGLang team.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <ATen/ATen.h>
#include <ATen/hip/HIPContext.h>
#ifndef USE_ROCM
#include "pytorch_extension_utils.h"
#else
#include "pytorch_extension_utils_rocm.h"
#endif
typedef enum { FULL_MASK = 0, QLEN_ONLY = 1, QLEN_ONLY_BITPACKING = 2 } TreeMaskMode;
// parent_list [bs, topk * (depth - 1) + 1)]
// selected_index [bs, draft_token_num - 1]
// verified_seq_len [bs]
// tree_mask [draft_token*(seq_len[0]+draft_token) | draft_token*(seq_len[1]+draft_token) | ..] =
// [sum(verified_seq_len)*draft_token+bs*draft_token*draft_token] positions [bs * draft_token] retrive_index [b,
// draft_token] retrive_next_token [b, draft_token] retrive_next_sibling [b, draft_token]
__global__ void build_tree_efficient(
int64_t* parent_list,
int64_t* selected_index,
int64_t* verified_seq_len,
bool* tree_mask,
int64_t* positions,
int64_t* retrive_index,
int64_t* retrive_next_token,
int64_t* retrive_next_sibling,
int topk,
int depth,
int draft_token_num,
int tree_mask_mode) {
int bid = blockIdx.x;
int tid = threadIdx.x;
if (tid >= draft_token_num) {
return;
}
int seq_tree_idx = draft_token_num * draft_token_num * bid;
for (int i = 0; i < bid; i++) {
seq_tree_idx += verified_seq_len[i] * draft_token_num;
}
int seq_len = verified_seq_len[bid];
int token_tree_idx;
if (tree_mask_mode == FULL_MASK) {
token_tree_idx = seq_tree_idx + (seq_len + draft_token_num) * tid + seq_len + 1;
} else {
token_tree_idx = draft_token_num * draft_token_num * bid + draft_token_num * tid + 1;
}
tree_mask[token_tree_idx - 1] = true;
for (int i = 0; i < draft_token_num - 1; i++) {
tree_mask[token_tree_idx + i] = false;
}
int position = 0;
if (tid == 0) {
positions[bid * draft_token_num] = seq_len;
int retrive_index_offset = bid * draft_token_num;
for (int i = draft_token_num - 1; i > 0; --i) {
int current_token_idx = retrive_index_offset + i;
retrive_index[bid * draft_token_num + i] = current_token_idx;
int parent_tb_idx = selected_index[bid * (draft_token_num - 1) + i - 1] / topk;
int parent_position = 0;
if (parent_tb_idx > 0) {
int parent_token_idx = parent_list[bid * (topk * (depth - 1) + 1) + parent_tb_idx];
for (; parent_position < draft_token_num; ++parent_position) {
if (selected_index[bid * (draft_token_num - 1) + parent_position] == parent_token_idx) {
++parent_position;
break;
}
}
}
if (parent_position == draft_token_num) {
printf(
"WARNING: invalid eagle tree!!! Detected a token with no parent token selected. "
"Please check if the logprob has nan. The token will be ignored to keep proceeding.\n");
continue;
}
if (retrive_next_token[bid * draft_token_num + parent_position] == -1) {
retrive_next_token[bid * draft_token_num + parent_position] = i;
} else {
int origin_next_token = retrive_next_token[bid * draft_token_num + parent_position];
retrive_next_token[bid * draft_token_num + parent_position] = i;
retrive_next_sibling[bid * draft_token_num + i] = origin_next_token;
}
}
retrive_index[bid * draft_token_num] = bid * draft_token_num;
} else {
int cur_position = tid - 1;
while (true) {
position += 1;
tree_mask[token_tree_idx + cur_position] = true;
int parent_tb_idx = selected_index[bid * (draft_token_num - 1) + cur_position] / topk;
if (parent_tb_idx == 0) {
break;
}
int token_idx = parent_list[bid * (topk * (depth - 1) + 1) + parent_tb_idx];
for (cur_position = 0; cur_position < draft_token_num; ++cur_position) {
if (selected_index[bid * (draft_token_num - 1) + cur_position] == token_idx) {
break;
}
}
}
positions[bid * draft_token_num + tid] = position + seq_len;
}
}
// parent_list [bs, topk * (depth - 1) + 1)]
// selected_index [bs, draft_token_num - 1]
// verified_seq_len [bs]
// tree_mask: [draft_token*num_bytes_per_item | .. ] = [bs*draft_token*num_bytes_per_item]
// positions [bs * draft_token]
// retrive_index [bs, draft_token]
// retrive_next_token [bs, draft_token]
// retrive_next_sibling [bs, draft_token]
__global__ void build_tree_efficient_partial_packed(
int64_t* parent_list,
int64_t* selected_index,
int64_t* verified_seq_len,
uint8_t* tree_mask,
int64_t* positions,
int64_t* retrive_index,
int64_t* retrive_next_token,
int64_t* retrive_next_sibling,
int topk,
int depth,
int draft_token_num,
size_t num_bytes_per_item) {
int bid = blockIdx.x;
int tid = threadIdx.x;
if (tid >= draft_token_num) {
return;
}
int seq_len = verified_seq_len[bid];
int token_tree_idx = (bid * draft_token_num + tid) * num_bytes_per_item;
tree_mask[token_tree_idx] = 1; // little endian
int position = 0;
if (tid == 0) {
positions[bid * draft_token_num] = seq_len;
int retrive_index_offset = bid * draft_token_num;
for (int i = draft_token_num - 1; i > 0; --i) {
int current_token_idx = retrive_index_offset + i;
retrive_index[bid * draft_token_num + i] = current_token_idx;
int parent_tb_idx = selected_index[bid * (draft_token_num - 1) + i - 1] / topk;
int parent_position = 0;
if (parent_tb_idx > 0) {
int parent_token_idx = parent_list[bid * (topk * (depth - 1) + 1) + parent_tb_idx];
for (; parent_position < draft_token_num; ++parent_position) {
if (selected_index[bid * (draft_token_num - 1) + parent_position] == parent_token_idx) {
++parent_position;
break;
}
}
}
if (parent_position == draft_token_num) {
printf(
"WARNING: invalid eagle tree!!! Detected a token with no parent token selected. "
"Please check if the logprob has nan. The token will be ignored to keep proceeding.\n");
continue;
}
if (retrive_next_token[bid * draft_token_num + parent_position] == -1) {
retrive_next_token[bid * draft_token_num + parent_position] = i;
} else {
int origin_next_token = retrive_next_token[bid * draft_token_num + parent_position];
retrive_next_token[bid * draft_token_num + parent_position] = i;
retrive_next_sibling[bid * draft_token_num + i] = origin_next_token;
}
}
retrive_index[bid * draft_token_num] = bid * draft_token_num;
} else {
int cur_position = tid - 1;
while (true) {
position += 1;
int byte_idx = (cur_position + 1) / 8;
int bit_idx = (cur_position + 1) % 8;
tree_mask[token_tree_idx + byte_idx] |= (1 << bit_idx);
int parent_tb_idx = selected_index[bid * (draft_token_num - 1) + cur_position] / topk;
if (parent_tb_idx == 0) {
break;
}
int token_idx = parent_list[bid * (topk * (depth - 1) + 1) + parent_tb_idx];
for (cur_position = 0; cur_position < draft_token_num; ++cur_position) {
if (selected_index[bid * (draft_token_num - 1) + cur_position] == token_idx) {
break;
}
}
}
positions[bid * draft_token_num + tid] = position + seq_len;
}
}
void build_tree_kernel_efficient(
at::Tensor parent_list,
at::Tensor selected_index,
at::Tensor verified_seq_len,
at::Tensor tree_mask,
at::Tensor positions,
at::Tensor retrive_index,
at::Tensor retrive_next_token,
at::Tensor retrive_next_sibling,
int64_t topk,
int64_t depth,
int64_t draft_token_num,
int64_t tree_mask_mode) {
// TODO (ying) check shape
// TODO (ying) check type
int bs = parent_list.size(0);
dim3 grid(bs);
dim3 block(draft_token_num);
const hipStream_t stream = at::hip::getCurrentHIPStreamMasqueradingAsCUDA();
if (tree_mask_mode == QLEN_ONLY_BITPACKING) {
size_t num_bytes_per_item = 1;
if (draft_token_num > 16) {
num_bytes_per_item = 4;
} else if (draft_token_num > 8) {
num_bytes_per_item = 2;
}
hipLaunchKernelGGL(( build_tree_efficient_partial_packed), dim3(grid), dim3(block), 0, stream,
static_cast<int64_t*>(parent_list.data_ptr()),
static_cast<int64_t*>(selected_index.data_ptr()),
static_cast<int64_t*>(verified_seq_len.data_ptr()),
static_cast<uint8_t*>(tree_mask.data_ptr()),
static_cast<int64_t*>(positions.data_ptr()),
static_cast<int64_t*>(retrive_index.data_ptr()),
static_cast<int64_t*>(retrive_next_token.data_ptr()),
static_cast<int64_t*>(retrive_next_sibling.data_ptr()),
int32_t(topk),
int32_t(depth),
int32_t(draft_token_num),
num_bytes_per_item);
} else {
hipLaunchKernelGGL(( build_tree_efficient), dim3(grid), dim3(block), 0, stream,
static_cast<int64_t*>(parent_list.data_ptr()),
static_cast<int64_t*>(selected_index.data_ptr()),
static_cast<int64_t*>(verified_seq_len.data_ptr()),
static_cast<bool*>(tree_mask.data_ptr()),
static_cast<int64_t*>(positions.data_ptr()),
static_cast<int64_t*>(retrive_index.data_ptr()),
static_cast<int64_t*>(retrive_next_token.data_ptr()),
static_cast<int64_t*>(retrive_next_sibling.data_ptr()),
int32_t(topk),
int32_t(depth),
int32_t(draft_token_num),
int32_t(tree_mask_mode));
}
}
template <typename IdType, typename IdType2>
__global__ void VerifyTreeGreedy(
IdType* predicts,
IdType* accept_index,
IdType* accept_token_num, // mutable
IdType2* candidates,
IdType2* retrive_index,
IdType2* retrive_next_token,
IdType2* retrive_next_sibling,
IdType2* target_predict,
uint32_t batch_size,
uint32_t num_speculative_tokens,
uint32_t num_draft_tokens) {
uint32_t bx = blockIdx.x;
IdType2 last_accepted_retrive_idx = retrive_index[bx * num_draft_tokens];
accept_index[bx * num_speculative_tokens] = last_accepted_retrive_idx;
uint32_t num_accepted_tokens = 0;
IdType2 cur_index = 0;
for (uint32_t j = 1; j < num_speculative_tokens; ++j) {
cur_index = retrive_next_token[bx * num_draft_tokens + cur_index];
while (cur_index != -1) {
IdType2 draft_index = retrive_index[bx * num_draft_tokens + cur_index];
IdType2 draft_token_id = candidates[bx * num_draft_tokens + cur_index];
IdType2 target_token_id = target_predict[last_accepted_retrive_idx];
if (draft_token_id == target_token_id) {
// accept token
predicts[last_accepted_retrive_idx] = target_token_id;
++num_accepted_tokens;
accept_index[bx * num_speculative_tokens + num_accepted_tokens] = draft_index;
last_accepted_retrive_idx = draft_index;
break;
} else {
cur_index = retrive_next_sibling[bx * num_draft_tokens + cur_index];
}
}
if (cur_index == -1) break;
}
accept_token_num[bx] = num_accepted_tokens;
predicts[last_accepted_retrive_idx] = target_predict[last_accepted_retrive_idx];
}
// predicts: [tot_num_draft_tokens]
// accept_index: [bs, num_spec_step]
// accept_token_num: [bs]
// candidates: [bs, num_draft_tokens]
// retrive_index: [bs, num_draft_tokens]
// retrive_next_token: [bs, num_draft_tokens]
// retrive_next_sibling: [bs, num_draft_tokens]
// target_predict: [bs, num_draft_tokens]
void verify_tree_greedy(
at::Tensor predicts,
at::Tensor accept_index,
at::Tensor accept_token_num, // mutable
at::Tensor candidates,
at::Tensor retrive_index,
at::Tensor retrive_next_token,
at::Tensor retrive_next_sibling,
at::Tensor target_predict,
int64_t cuda_stream = 0) {
CHECK_INPUT(candidates);
CHECK_INPUT(retrive_index);
CHECK_INPUT(retrive_next_token);
CHECK_INPUT(retrive_next_sibling);
CHECK_INPUT(target_predict);
auto device = target_predict.device();
CHECK_EQ(candidates.device(), device);
CHECK_EQ(retrive_index.device(), device);
CHECK_EQ(retrive_next_token.device(), device);
CHECK_EQ(retrive_next_sibling.device(), device);
CHECK_EQ(target_predict.device(), device);
CHECK_DIM(1, predicts);
CHECK_DIM(2, accept_index);
CHECK_DIM(1, accept_token_num);
CHECK_DIM(2, candidates);
CHECK_DIM(2, retrive_index);
CHECK_DIM(2, retrive_next_token);
CHECK_DIM(2, retrive_next_sibling);
CHECK_DIM(2, target_predict);
unsigned int batch_size = candidates.size(0);
unsigned int num_spec_step = accept_index.size(1);
unsigned int num_draft_tokens = candidates.size(1);
CHECK_EQ(batch_size, accept_index.size(0));
CHECK_EQ(batch_size, accept_token_num.size(0));
CHECK_EQ(batch_size, retrive_index.size(0));
CHECK_EQ(batch_size, retrive_next_token.size(0));
CHECK_EQ(batch_size, retrive_next_sibling.size(0));
CHECK_EQ(batch_size, target_predict.size(0));
CHECK_EQ(num_draft_tokens, retrive_index.size(1));
CHECK_EQ(num_draft_tokens, retrive_next_token.size(1));
CHECK_EQ(num_draft_tokens, retrive_next_sibling.size(1));
CHECK_EQ(num_draft_tokens, target_predict.size(1));
CHECK_EQ(batch_size, accept_index.size(0));
CHECK_EQ(batch_size, accept_token_num.size(0));
if (predicts.scalar_type() != at::kInt) {
throw std::runtime_error("Expected 'predicts' to be of type int (torch.int32).");
}
if (accept_index.scalar_type() != at::kInt) {
throw std::runtime_error("Expected 'accept_index' to be of type int (torch.int32).");
}
if (accept_token_num.scalar_type() != at::kInt) {
throw std::runtime_error("Expected 'accept_token_num' to be of type int (torch.int32).");
}
if (candidates.scalar_type() != at::kLong) {
throw std::runtime_error("Expected 'candidates' to be of type long (torch.int64).");
}
if (retrive_index.scalar_type() != at::kLong) {
throw std::runtime_error("Expected 'retrive_index' to be of type long (torch.int64).");
}
if (retrive_next_token.scalar_type() != at::kLong) {
throw std::runtime_error("Expected 'retrive_next_token' to be of type long (torch.int64).");
}
if (retrive_next_sibling.scalar_type() != at::kLong) {
throw std::runtime_error("Expected 'retrive_next_sibling' to be of type long (torch.int64).");
}
if (target_predict.scalar_type() != at::kLong) {
throw std::runtime_error("Expected 'target_predict' to be of type long (torch.int64).");
}
hipStream_t stream = reinterpret_cast<hipStream_t>(cuda_stream);
dim3 grid(batch_size);
dim3 block(1);
hipLaunchKernelGGL(( VerifyTreeGreedy<int32_t, int64_t>), dim3(grid), dim3(block), 0, stream,
static_cast<int32_t*>(predicts.data_ptr()),
static_cast<int32_t*>(accept_index.data_ptr()),
static_cast<int32_t*>(accept_token_num.data_ptr()),
static_cast<int64_t*>(candidates.data_ptr()),
static_cast<int64_t*>(retrive_index.data_ptr()),
static_cast<int64_t*>(retrive_next_token.data_ptr()),
static_cast<int64_t*>(retrive_next_sibling.data_ptr()),
static_cast<int64_t*>(target_predict.data_ptr()),
batch_size,
num_spec_step,
num_draft_tokens);
}

View File

@@ -0,0 +1,51 @@
// This is only a plugin used for flashinfer 0.1.6. The new version does not need it.
/*
* Copyright (c) 2025 by SGLang team.
* Copyright (c) 2025 by FlashInfer team.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <flashinfer/quantization.cuh>
#include "pytorch_extension_utils.h"
using namespace flashinfer;
// bitorder = "little"
void segment_packbits(
at::Tensor x,
at::Tensor input_indptr,
at::Tensor output_indptr,
at::Tensor y,
int64_t batch_size,
int64_t cuda_stream) {
CHECK_INPUT(x);
CHECK_INPUT(input_indptr);
CHECK_INPUT(output_indptr);
auto device = x.device();
CHECK_EQ(input_indptr.device(), device);
CHECK_EQ(output_indptr.device(), device);
CHECK_EQ(y.device(), device);
CHECK_GE(output_indptr.size(0), batch_size + 1);
cudaStream_t stream = reinterpret_cast<cudaStream_t>(cuda_stream);
cudaError_t status = quantization::SegmentPackBits(
static_cast<bool*>(x.data_ptr()),
static_cast<uint8_t*>(y.data_ptr()),
static_cast<int32_t*>(input_indptr.data_ptr()),
static_cast<int32_t*>(output_indptr.data_ptr()),
batch_size,
quantization::BitOrder::kLittle,
stream);
}

View File

@@ -0,0 +1,152 @@
/*
* Copyright (c) 2025 by SGLang team.
* Copyright (c) 2025 by FlashInfer team.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "pytorch_extension_utils.h"
#include "speculative_sampling.cuh"
using namespace flashinfer;
// predicts: [tot_num_draft_tokens]
// accept_index: [bs, num_spec_step]
// accept_token_num: [bs]
// candidates: [bs, num_draft_tokens]
// retrive_index: [bs, num_draft_tokens]
// retrive_next_token: [bs, num_draft_tokens]
// retrive_next_sibling: [bs, num_draft_tokens]
// uniform_samples: [bs, num_draft_tokens]
// target_probs: [bs, num_draft_tokens, vocab_size]
void tree_speculative_sampling_target_only(
at::Tensor predicts,
at::Tensor accept_index,
at::Tensor accept_token_num, // mutable
at::Tensor candidates,
at::Tensor retrive_index,
at::Tensor retrive_next_token,
at::Tensor retrive_next_sibling,
at::Tensor uniform_samples,
at::Tensor uniform_samples_for_final_sampling,
at::Tensor target_probs,
at::Tensor draft_probs,
double threshold_single,
double threshold_acc,
bool deterministic = true,
int64_t cuda_stream = 0) {
CHECK_INPUT(candidates);
CHECK_INPUT(retrive_index);
CHECK_INPUT(retrive_next_token);
CHECK_INPUT(retrive_next_sibling);
CHECK_INPUT(uniform_samples);
CHECK_INPUT(uniform_samples_for_final_sampling);
CHECK_INPUT(target_probs);
auto device = target_probs.device();
CHECK_EQ(candidates.device(), device);
CHECK_EQ(retrive_index.device(), device);
CHECK_EQ(retrive_next_token.device(), device);
CHECK_EQ(retrive_next_sibling.device(), device);
CHECK_EQ(uniform_samples.device(), device);
CHECK_EQ(uniform_samples_for_final_sampling.device(), device);
CHECK_EQ(target_probs.device(), device);
CHECK_DIM(1, predicts);
CHECK_DIM(2, accept_index);
CHECK_DIM(1, accept_token_num);
CHECK_DIM(2, candidates);
CHECK_DIM(2, retrive_index);
CHECK_DIM(2, retrive_next_token);
CHECK_DIM(2, retrive_next_sibling);
CHECK_DIM(2, uniform_samples);
CHECK_DIM(3, target_probs);
CHECK_DIM(3, draft_probs);
unsigned int batch_size = uniform_samples.size(0);
unsigned int num_spec_step = accept_index.size(1);
unsigned int num_draft_tokens = candidates.size(1);
unsigned int vocab_size = target_probs.size(2);
CHECK_EQ(batch_size, candidates.size(0));
CHECK_EQ(batch_size, retrive_index.size(0));
CHECK_EQ(batch_size, retrive_next_token.size(0));
CHECK_EQ(batch_size, retrive_next_sibling.size(0));
CHECK_EQ(batch_size, target_probs.size(0));
CHECK_EQ(num_draft_tokens, retrive_index.size(1));
CHECK_EQ(num_draft_tokens, retrive_next_token.size(1));
CHECK_EQ(num_draft_tokens, retrive_next_sibling.size(1));
CHECK_EQ(num_draft_tokens, uniform_samples.size(1));
CHECK_EQ(num_draft_tokens, target_probs.size(1));
CHECK_EQ(vocab_size, target_probs.size(2));
CHECK_EQ(batch_size, accept_index.size(0));
CHECK_EQ(batch_size, accept_token_num.size(0));
if (predicts.scalar_type() != at::kInt) {
throw std::runtime_error("Expected 'predicts' to be of type int (torch.int32).");
}
if (accept_index.scalar_type() != at::kInt) {
throw std::runtime_error("Expected 'accept_index' to be of type int (torch.int32).");
}
if (accept_token_num.scalar_type() != at::kInt) {
throw std::runtime_error("Expected 'accept_token_num' to be of type int (torch.int32).");
}
if (candidates.scalar_type() != at::kLong) {
throw std::runtime_error("Expected 'candidates' to be of type long (torch.int64).");
}
if (retrive_index.scalar_type() != at::kLong) {
throw std::runtime_error("Expected 'retrive_index' to be of type long (torch.int64).");
}
if (retrive_next_token.scalar_type() != at::kLong) {
throw std::runtime_error("Expected 'retrive_next_token' to be of type long (torch.int64).");
}
if (retrive_next_sibling.scalar_type() != at::kLong) {
throw std::runtime_error("Expected 'retrive_next_sibling' to be of type long (torch.int64).");
}
if (uniform_samples.scalar_type() != at::kFloat) {
throw std::runtime_error("Expected 'uniform_samples' to be of type float (torch.float32).");
}
if (uniform_samples_for_final_sampling.scalar_type() != at::kFloat) {
throw std::runtime_error("Expected 'uniform_samples_for_final_sampling' to be of type float (torch.float32).");
}
if (target_probs.scalar_type() != at::kFloat) {
throw std::runtime_error("Expected 'target_probs' to be of type float (torch.float32).");
}
if (draft_probs.scalar_type() != at::kFloat) {
throw std::runtime_error("Expected 'target_probs' to be of type float (torch.float32).");
}
CHECK_GE(threshold_single, 0);
CHECK_GE(1, threshold_single);
CHECK_GE(threshold_acc, 0);
CHECK_GE(1, threshold_acc);
cudaStream_t stream = reinterpret_cast<cudaStream_t>(cuda_stream);
cudaError_t status = sampling::TreeSpeculativeSamplingTargetOnly<float, int32_t, int64_t>(
static_cast<int32_t*>(predicts.data_ptr()),
static_cast<int32_t*>(accept_index.data_ptr()),
static_cast<int32_t*>(accept_token_num.data_ptr()),
static_cast<int64_t*>(candidates.data_ptr()),
static_cast<int64_t*>(retrive_index.data_ptr()),
static_cast<int64_t*>(retrive_next_token.data_ptr()),
static_cast<int64_t*>(retrive_next_sibling.data_ptr()),
static_cast<float*>(uniform_samples.data_ptr()),
static_cast<float*>(uniform_samples_for_final_sampling.data_ptr()),
static_cast<float*>(target_probs.data_ptr()),
static_cast<float*>(draft_probs.data_ptr()),
batch_size,
num_spec_step,
num_draft_tokens,
vocab_size,
static_cast<float>(threshold_single),
static_cast<float>(threshold_acc),
deterministic,
stream);
TORCH_CHECK(
status == cudaSuccess,
"TreeSpeculativeSamplingTargetOnly failed with error code " + std::string(cudaGetErrorString(status)));
}

View File

@@ -0,0 +1,231 @@
/*
* Copyright (c) 2025 by SGLang team.
* Copyright (c) 2024-2025 by FlashInfer team.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef SPECULATIVE_SAMPLING_CUH_
#define SPECULATIVE_SAMPLING_CUH_
#include <assert.h>
#include <flashinfer/sampling.cuh>
namespace flashinfer {
namespace sampling {
using namespace cub;
template <
uint32_t BLOCK_THREADS,
BlockScanAlgorithm SCAN_ALGORITHM,
BlockReduceAlgorithm REDUCE_ALGORITHM,
uint32_t VEC_SIZE,
bool DETERMINISTIC,
typename DType,
typename IdType,
typename IdType2>
__global__ void TreeSpeculativeSamplingTargetOnly(
IdType* predicts, // mutable
IdType* accept_index, // mutable
IdType* accept_token_num, // mutable
IdType2* candidates,
IdType2* retrive_index,
IdType2* retrive_next_token,
IdType2* retrive_next_sibling,
DType* uniform_samples,
DType* uniform_samples_for_final_sampling,
DType* target_probs,
DType* draft_probs,
uint32_t batch_size,
uint32_t num_speculative_tokens,
uint32_t num_draft_tokens,
uint32_t d,
DType threshold_single,
DType threshold_acc) {
const uint32_t bx = blockIdx.x, tx = threadIdx.x;
extern __shared__ __align__(alignof(SamplingTempStorage<BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM>))
uint8_t smem_sampling[];
auto& temp_storage =
reinterpret_cast<SamplingTempStorage<BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM>&>(smem_sampling);
DType prob_acc = 0.0;
uint32_t cur_prob_offset = bx * num_draft_tokens * d;
DType coin = uniform_samples[bx * num_draft_tokens];
IdType2 last_accepted_retrive_idx = retrive_index[bx * num_draft_tokens];
accept_index[bx * num_speculative_tokens] = last_accepted_retrive_idx;
uint32_t num_accepted_tokens = 0;
IdType2 cur_index = 0;
for (uint32_t j = 1; j < num_speculative_tokens; ++j) {
cur_index = retrive_next_token[bx * num_draft_tokens + cur_index];
while (cur_index != -1) {
IdType2 draft_index = retrive_index[bx * num_draft_tokens + cur_index];
IdType2 draft_token_id = candidates[bx * num_draft_tokens + cur_index];
DType target_prob_single = target_probs[cur_prob_offset + draft_token_id];
prob_acc += target_prob_single;
if (coin <= prob_acc / threshold_acc || target_prob_single >= threshold_single) {
// accept token
prob_acc = 0.;
cur_prob_offset = (bx * num_draft_tokens + cur_index) * d;
coin = uniform_samples[bx * num_draft_tokens + cur_index];
predicts[last_accepted_retrive_idx] = draft_token_id;
++num_accepted_tokens;
accept_index[bx * num_speculative_tokens + num_accepted_tokens] = draft_index;
last_accepted_retrive_idx = draft_index;
break;
} else {
// FIXME: leverage draft probs
draft_probs[cur_prob_offset + draft_token_id] = target_probs[cur_prob_offset + draft_token_id];
cur_index = retrive_next_sibling[bx * num_draft_tokens + cur_index];
}
}
if (cur_index == -1) break;
}
accept_token_num[bx] = num_accepted_tokens;
// we need a different coin for the final sampling
coin = uniform_samples_for_final_sampling[bx];
// sample from relu(target_probs - draft_probs)
DType sum_relu_q_minus_p(0);
vec_t<DType, VEC_SIZE> q_vec, p_vec;
DType relu_q_minus_p[VEC_SIZE];
for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) {
q_vec.fill(DType(0));
p_vec.fill(DType(0));
if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
q_vec.load(target_probs + cur_prob_offset + i * BLOCK_THREADS * VEC_SIZE + tx * VEC_SIZE);
if (num_accepted_tokens != num_speculative_tokens - 1) {
// there is no draft_probs for the bonus token
p_vec.load(draft_probs + cur_prob_offset + i * BLOCK_THREADS * VEC_SIZE + tx * VEC_SIZE);
}
}
#pragma unroll
for (uint32_t j = 0; j < VEC_SIZE; ++j) {
relu_q_minus_p[j] = max(q_vec[j] - p_vec[j], DType(0));
}
sum_relu_q_minus_p += BlockReduce<DType, BLOCK_THREADS, REDUCE_ALGORITHM>(temp_storage.block_prim.reduce)
.Sum<VEC_SIZE>(relu_q_minus_p);
__syncthreads();
}
if (tx == 0) {
temp_storage.block_aggregate.value = sum_relu_q_minus_p;
}
// init the first rejected token to (d - 1)
temp_storage.sampled_id = d - 1;
__syncthreads();
sum_relu_q_minus_p = temp_storage.block_aggregate.value;
DType u = coin * sum_relu_q_minus_p;
DType aggregate_relu_q_minus_p(0);
for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) {
q_vec.fill(DType(0));
p_vec.fill(DType(0));
if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
q_vec.load(target_probs + cur_prob_offset + i * BLOCK_THREADS * VEC_SIZE + tx * VEC_SIZE);
if (num_accepted_tokens != num_speculative_tokens - 1) {
// there is no draft_probs for the bonus token
p_vec.load(draft_probs + cur_prob_offset + i * BLOCK_THREADS * VEC_SIZE + tx * VEC_SIZE);
}
}
vec_t<DType, VEC_SIZE> relu_q_minus_p_vec;
#pragma unroll
for (uint32_t j = 0; j < VEC_SIZE; ++j) {
relu_q_minus_p_vec[j] = max(q_vec[j] - p_vec[j], DType(0));
}
DeviceSamplingFromProb<VEC_SIZE, BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM, DETERMINISTIC>(
i, d, [&](DType x) { return x > 0; }, u, relu_q_minus_p_vec, aggregate_relu_q_minus_p, &temp_storage);
if (aggregate_relu_q_minus_p > u) {
break;
}
}
__syncthreads();
// set the first rejected token
predicts[last_accepted_retrive_idx] = temp_storage.sampled_id;
// value at not used indices are undefined
}
template <typename DType, typename IdType, typename IdType2>
cudaError_t TreeSpeculativeSamplingTargetOnly(
IdType* predicts, // mutable
IdType* output_token_ids, // mutable
IdType* output_accepted_token_num, // mutable
IdType2* candidates,
IdType2* retrive_index,
IdType2* retrive_next_token,
IdType2* retrive_next_sibling,
DType* uniform_samples,
DType* uniform_samples_for_final_sampling,
DType* target_probs,
DType* draft_probs,
uint32_t batch_size,
uint32_t num_speculative_tokens,
uint32_t num_draft_tokens,
uint32_t d,
DType threshold_single = 1,
DType threshold_acc = 1,
bool deterministic = true,
cudaStream_t stream = 0) {
constexpr uint32_t BLOCK_THREADS = 1024;
const uint32_t vec_size = std::gcd(16 / sizeof(DType), d);
const uint32_t smem_size = sizeof(SamplingTempStorage<BLOCK_THREADS, SCAN_ALGO, REDUCE_ALGO>);
dim3 nblks(batch_size);
dim3 nthrs(BLOCK_THREADS);
float capped_threshold_acc = fmaxf(threshold_acc, 1e-9f);
void* args[] = {
&predicts,
&output_token_ids,
&output_accepted_token_num,
&candidates,
&retrive_index,
&retrive_next_token,
&retrive_next_sibling,
&uniform_samples,
&uniform_samples_for_final_sampling,
&target_probs,
&draft_probs,
&batch_size,
&num_speculative_tokens,
&num_draft_tokens,
&d,
&threshold_single,
&capped_threshold_acc};
DISPATCH_ALIGNED_VEC_SIZE(
vec_size, VEC_SIZE, {DISPATCH_DETERMINISTIC(deterministic, DETERMINISTIC, {
auto kernel = TreeSpeculativeSamplingTargetOnly<
BLOCK_THREADS,
SCAN_ALGO,
REDUCE_ALGO,
VEC_SIZE,
DETERMINISTIC,
DType,
IdType,
IdType2>;
FLASHINFER_CUDA_CALL(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
FLASHINFER_CUDA_CALL(cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream));
})});
return cudaSuccess;
}
} // namespace sampling
} // namespace flashinfer
#endif // SPECULATIVE_SAMPLING_CUH_