Fix sampling for speculative decoding & simplify kernels (#7207)

This commit is contained in:
Lianmin Zheng
2025-06-16 03:28:30 -07:00
committed by GitHub
parent b1286a116a
commit cfceb83d05
11 changed files with 124 additions and 79 deletions

View File

@@ -34,16 +34,18 @@ template <
uint32_t VEC_SIZE,
bool DETERMINISTIC,
typename DType,
typename IdType>
typename IdType,
typename IdType2>
__global__ void TreeSpeculativeSamplingTargetOnly(
IdType* predicts, // mutable
IdType* accept_index, // mutable
IdType* accept_token_num, // mutable
IdType* candidates,
IdType* retrive_index,
IdType* retrive_next_token,
IdType* retrive_next_sibling,
IdType2* candidates,
IdType2* retrive_index,
IdType2* retrive_next_token,
IdType2* retrive_next_sibling,
DType* uniform_samples,
DType* uniform_samples_for_final_sampling,
DType* target_probs,
DType* draft_probs,
uint32_t batch_size,
@@ -62,16 +64,16 @@ __global__ void TreeSpeculativeSamplingTargetOnly(
DType prob_acc = 0.0;
uint32_t cur_prob_offset = bx * num_draft_tokens * d;
DType coin = uniform_samples[bx * num_draft_tokens];
IdType last_accepted_retrive_idx = retrive_index[bx * num_draft_tokens];
IdType2 last_accepted_retrive_idx = retrive_index[bx * num_draft_tokens];
accept_index[bx * num_speculative_tokens] = last_accepted_retrive_idx;
uint32_t num_accepted_tokens = 0;
IdType cur_index = 0;
IdType2 cur_index = 0;
for (uint32_t j = 1; j < num_speculative_tokens; ++j) {
cur_index = retrive_next_token[bx * num_draft_tokens + cur_index];
while (cur_index != -1) {
IdType draft_index = retrive_index[bx * num_draft_tokens + cur_index];
IdType draft_token_id = candidates[bx * num_draft_tokens + cur_index];
IdType2 draft_index = retrive_index[bx * num_draft_tokens + cur_index];
IdType2 draft_token_id = candidates[bx * num_draft_tokens + cur_index];
DType target_prob_single = target_probs[cur_prob_offset + draft_token_id];
prob_acc += target_prob_single;
@@ -95,6 +97,9 @@ __global__ void TreeSpeculativeSamplingTargetOnly(
}
accept_token_num[bx] = num_accepted_tokens;
// we need a different coin for the final sampling
coin = uniform_samples_for_final_sampling[bx];
// sample from relu(target_probs - draft_probs)
DType sum_relu_q_minus_p(0);
vec_t<DType, VEC_SIZE> q_vec, p_vec;
@@ -156,16 +161,17 @@ __global__ void TreeSpeculativeSamplingTargetOnly(
// value at not used indices are undefined
}
template <typename DType, typename IdType>
template <typename DType, typename IdType, typename IdType2>
cudaError_t TreeSpeculativeSamplingTargetOnly(
IdType* predicts, // mutable
IdType* output_token_ids, // mutable
IdType* output_accepted_token_num, // mutable
IdType* candidates,
IdType* retrive_index,
IdType* retrive_next_token,
IdType* retrive_next_sibling,
IdType2* candidates,
IdType2* retrive_index,
IdType2* retrive_next_token,
IdType2* retrive_next_sibling,
DType* uniform_samples,
DType* uniform_samples_for_final_sampling,
DType* target_probs,
DType* draft_probs,
uint32_t batch_size,
@@ -192,6 +198,7 @@ cudaError_t TreeSpeculativeSamplingTargetOnly(
&retrive_next_token,
&retrive_next_sibling,
&uniform_samples,
&uniform_samples_for_final_sampling,
&target_probs,
&draft_probs,
&batch_size,
@@ -209,7 +216,8 @@ cudaError_t TreeSpeculativeSamplingTargetOnly(
VEC_SIZE,
DETERMINISTIC,
DType,
IdType>;
IdType,
IdType2>;
FLASHINFER_CUDA_CALL(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
FLASHINFER_CUDA_CALL(cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream));
})});