Fix sampling for speculative decoding & simplify kernels (#7207)
This commit is contained in:
@@ -32,7 +32,7 @@
|
||||
__global__ void build_tree_efficient(
|
||||
int64_t* parent_list,
|
||||
int64_t* selected_index,
|
||||
int32_t* verified_seq_len,
|
||||
int64_t* verified_seq_len,
|
||||
bool* tree_mask,
|
||||
int64_t* positions,
|
||||
int64_t* retrive_index,
|
||||
@@ -135,7 +135,7 @@ void build_tree_kernel_efficient(
|
||||
build_tree_efficient<<<grid, block, 0, stream>>>(
|
||||
static_cast<int64_t*>(parent_list.data_ptr()),
|
||||
static_cast<int64_t*>(selected_index.data_ptr()),
|
||||
static_cast<int32_t*>(verified_seq_len.data_ptr()),
|
||||
static_cast<int64_t*>(verified_seq_len.data_ptr()),
|
||||
static_cast<bool*>(tree_mask.data_ptr()),
|
||||
static_cast<int64_t*>(positions.data_ptr()),
|
||||
static_cast<int64_t*>(retrive_index.data_ptr()),
|
||||
@@ -146,32 +146,32 @@ void build_tree_kernel_efficient(
|
||||
int32_t(draft_token_num));
|
||||
}
|
||||
|
||||
template <typename IdType>
|
||||
template <typename IdType, typename IdType2>
|
||||
__global__ void VerifyTreeGreedy(
|
||||
IdType* predicts,
|
||||
IdType* accept_index,
|
||||
IdType* accept_token_num, // mutable
|
||||
IdType* candidates,
|
||||
IdType* retrive_index,
|
||||
IdType* retrive_next_token,
|
||||
IdType* retrive_next_sibling,
|
||||
IdType* target_predict,
|
||||
IdType2* candidates,
|
||||
IdType2* retrive_index,
|
||||
IdType2* retrive_next_token,
|
||||
IdType2* retrive_next_sibling,
|
||||
IdType2* target_predict,
|
||||
uint32_t batch_size,
|
||||
uint32_t num_speculative_tokens,
|
||||
uint32_t num_draft_tokens) {
|
||||
uint32_t bx = blockIdx.x;
|
||||
|
||||
IdType last_accepted_retrive_idx = retrive_index[bx * num_draft_tokens];
|
||||
IdType2 last_accepted_retrive_idx = retrive_index[bx * num_draft_tokens];
|
||||
accept_index[bx * num_speculative_tokens] = last_accepted_retrive_idx;
|
||||
uint32_t num_accepted_tokens = 0;
|
||||
IdType cur_index = 0;
|
||||
IdType2 cur_index = 0;
|
||||
|
||||
for (uint32_t j = 1; j < num_speculative_tokens; ++j) {
|
||||
cur_index = retrive_next_token[bx * num_draft_tokens + cur_index];
|
||||
while (cur_index != -1) {
|
||||
IdType draft_index = retrive_index[bx * num_draft_tokens + cur_index];
|
||||
IdType draft_token_id = candidates[bx * num_draft_tokens + cur_index];
|
||||
IdType target_token_id = target_predict[last_accepted_retrive_idx];
|
||||
IdType2 draft_index = retrive_index[bx * num_draft_tokens + cur_index];
|
||||
IdType2 draft_token_id = candidates[bx * num_draft_tokens + cur_index];
|
||||
IdType2 target_token_id = target_predict[last_accepted_retrive_idx];
|
||||
|
||||
if (draft_token_id == target_token_id) {
|
||||
// accept token
|
||||
@@ -251,35 +251,35 @@ void verify_tree_greedy(
|
||||
if (accept_token_num.scalar_type() != at::kInt) {
|
||||
throw std::runtime_error("Expected 'accept_token_num' to be of type int (torch.int32).");
|
||||
}
|
||||
if (candidates.scalar_type() != at::kInt) {
|
||||
throw std::runtime_error("Expected 'candidates' to be of type int (torch.int32).");
|
||||
if (candidates.scalar_type() != at::kLong) {
|
||||
throw std::runtime_error("Expected 'candidates' to be of type long (torch.int64).");
|
||||
}
|
||||
if (retrive_index.scalar_type() != at::kInt) {
|
||||
throw std::runtime_error("Expected 'retrive_index' to be of type int (torch.int32).");
|
||||
if (retrive_index.scalar_type() != at::kLong) {
|
||||
throw std::runtime_error("Expected 'retrive_index' to be of type long (torch.int64).");
|
||||
}
|
||||
if (retrive_next_token.scalar_type() != at::kInt) {
|
||||
throw std::runtime_error("Expected 'retrive_next_token' to be of type int (torch.int32).");
|
||||
if (retrive_next_token.scalar_type() != at::kLong) {
|
||||
throw std::runtime_error("Expected 'retrive_next_token' to be of type long (torch.int64).");
|
||||
}
|
||||
if (retrive_next_sibling.scalar_type() != at::kInt) {
|
||||
throw std::runtime_error("Expected 'retrive_next_sibling' to be of type int (torch.int32).");
|
||||
if (retrive_next_sibling.scalar_type() != at::kLong) {
|
||||
throw std::runtime_error("Expected 'retrive_next_sibling' to be of type long (torch.int64).");
|
||||
}
|
||||
if (target_predict.scalar_type() != at::kInt) {
|
||||
throw std::runtime_error("Expected 'target_predict' to be of type int (torch.int32).");
|
||||
if (target_predict.scalar_type() != at::kLong) {
|
||||
throw std::runtime_error("Expected 'target_predict' to be of type long (torch.int64).");
|
||||
}
|
||||
|
||||
cudaStream_t stream = reinterpret_cast<cudaStream_t>(cuda_stream);
|
||||
dim3 grid(batch_size);
|
||||
dim3 block(1);
|
||||
|
||||
VerifyTreeGreedy<int><<<grid, block, 0, stream>>>(
|
||||
static_cast<int*>(predicts.data_ptr()),
|
||||
static_cast<int*>(accept_index.data_ptr()),
|
||||
static_cast<int*>(accept_token_num.data_ptr()),
|
||||
static_cast<int*>(candidates.data_ptr()),
|
||||
static_cast<int*>(retrive_index.data_ptr()),
|
||||
static_cast<int*>(retrive_next_token.data_ptr()),
|
||||
static_cast<int*>(retrive_next_sibling.data_ptr()),
|
||||
static_cast<int*>(target_predict.data_ptr()),
|
||||
VerifyTreeGreedy<int32_t, int64_t><<<grid, block, 0, stream>>>(
|
||||
static_cast<int32_t*>(predicts.data_ptr()),
|
||||
static_cast<int32_t*>(accept_index.data_ptr()),
|
||||
static_cast<int32_t*>(accept_token_num.data_ptr()),
|
||||
static_cast<int64_t*>(candidates.data_ptr()),
|
||||
static_cast<int64_t*>(retrive_index.data_ptr()),
|
||||
static_cast<int64_t*>(retrive_next_token.data_ptr()),
|
||||
static_cast<int64_t*>(retrive_next_sibling.data_ptr()),
|
||||
static_cast<int64_t*>(target_predict.data_ptr()),
|
||||
batch_size,
|
||||
num_spec_step,
|
||||
num_draft_tokens);
|
||||
|
||||
Reference in New Issue
Block a user