Fix sampling for speculative decoding & simplify kernels (#7207)
This commit is contained in:
@@ -34,16 +34,18 @@ template <
|
||||
uint32_t VEC_SIZE,
|
||||
bool DETERMINISTIC,
|
||||
typename DType,
|
||||
typename IdType>
|
||||
typename IdType,
|
||||
typename IdType2>
|
||||
__global__ void TreeSpeculativeSamplingTargetOnly(
|
||||
IdType* predicts, // mutable
|
||||
IdType* accept_index, // mutable
|
||||
IdType* accept_token_num, // mutable
|
||||
IdType* candidates,
|
||||
IdType* retrive_index,
|
||||
IdType* retrive_next_token,
|
||||
IdType* retrive_next_sibling,
|
||||
IdType2* candidates,
|
||||
IdType2* retrive_index,
|
||||
IdType2* retrive_next_token,
|
||||
IdType2* retrive_next_sibling,
|
||||
DType* uniform_samples,
|
||||
DType* uniform_samples_for_final_sampling,
|
||||
DType* target_probs,
|
||||
DType* draft_probs,
|
||||
uint32_t batch_size,
|
||||
@@ -62,16 +64,16 @@ __global__ void TreeSpeculativeSamplingTargetOnly(
|
||||
DType prob_acc = 0.0;
|
||||
uint32_t cur_prob_offset = bx * num_draft_tokens * d;
|
||||
DType coin = uniform_samples[bx * num_draft_tokens];
|
||||
IdType last_accepted_retrive_idx = retrive_index[bx * num_draft_tokens];
|
||||
IdType2 last_accepted_retrive_idx = retrive_index[bx * num_draft_tokens];
|
||||
accept_index[bx * num_speculative_tokens] = last_accepted_retrive_idx;
|
||||
uint32_t num_accepted_tokens = 0;
|
||||
IdType cur_index = 0;
|
||||
IdType2 cur_index = 0;
|
||||
|
||||
for (uint32_t j = 1; j < num_speculative_tokens; ++j) {
|
||||
cur_index = retrive_next_token[bx * num_draft_tokens + cur_index];
|
||||
while (cur_index != -1) {
|
||||
IdType draft_index = retrive_index[bx * num_draft_tokens + cur_index];
|
||||
IdType draft_token_id = candidates[bx * num_draft_tokens + cur_index];
|
||||
IdType2 draft_index = retrive_index[bx * num_draft_tokens + cur_index];
|
||||
IdType2 draft_token_id = candidates[bx * num_draft_tokens + cur_index];
|
||||
DType target_prob_single = target_probs[cur_prob_offset + draft_token_id];
|
||||
prob_acc += target_prob_single;
|
||||
|
||||
@@ -95,6 +97,9 @@ __global__ void TreeSpeculativeSamplingTargetOnly(
|
||||
}
|
||||
accept_token_num[bx] = num_accepted_tokens;
|
||||
|
||||
// we need a different coin for the final sampling
|
||||
coin = uniform_samples_for_final_sampling[bx];
|
||||
|
||||
// sample from relu(target_probs - draft_probs)
|
||||
DType sum_relu_q_minus_p(0);
|
||||
vec_t<DType, VEC_SIZE> q_vec, p_vec;
|
||||
@@ -156,16 +161,17 @@ __global__ void TreeSpeculativeSamplingTargetOnly(
|
||||
// value at not used indices are undefined
|
||||
}
|
||||
|
||||
template <typename DType, typename IdType>
|
||||
template <typename DType, typename IdType, typename IdType2>
|
||||
cudaError_t TreeSpeculativeSamplingTargetOnly(
|
||||
IdType* predicts, // mutable
|
||||
IdType* output_token_ids, // mutable
|
||||
IdType* output_accepted_token_num, // mutable
|
||||
IdType* candidates,
|
||||
IdType* retrive_index,
|
||||
IdType* retrive_next_token,
|
||||
IdType* retrive_next_sibling,
|
||||
IdType2* candidates,
|
||||
IdType2* retrive_index,
|
||||
IdType2* retrive_next_token,
|
||||
IdType2* retrive_next_sibling,
|
||||
DType* uniform_samples,
|
||||
DType* uniform_samples_for_final_sampling,
|
||||
DType* target_probs,
|
||||
DType* draft_probs,
|
||||
uint32_t batch_size,
|
||||
@@ -192,6 +198,7 @@ cudaError_t TreeSpeculativeSamplingTargetOnly(
|
||||
&retrive_next_token,
|
||||
&retrive_next_sibling,
|
||||
&uniform_samples,
|
||||
&uniform_samples_for_final_sampling,
|
||||
&target_probs,
|
||||
&draft_probs,
|
||||
&batch_size,
|
||||
@@ -209,7 +216,8 @@ cudaError_t TreeSpeculativeSamplingTargetOnly(
|
||||
VEC_SIZE,
|
||||
DETERMINISTIC,
|
||||
DType,
|
||||
IdType>;
|
||||
IdType,
|
||||
IdType2>;
|
||||
FLASHINFER_CUDA_CALL(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
|
||||
FLASHINFER_CUDA_CALL(cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream));
|
||||
})});
|
||||
|
||||
Reference in New Issue
Block a user