Add greedy verification kernel (#4383)

This commit is contained in:
Ying Sheng
2025-03-16 00:58:26 -07:00
committed by GitHub
parent 06d12b39d3
commit 52a34d7448
11 changed files with 394 additions and 153 deletions

View File

@@ -49,7 +49,9 @@ __global__ void TreeSpeculativeSamplingTargetOnly(
uint32_t batch_size,
uint32_t num_speculative_tokens,
uint32_t num_draft_tokens,
uint32_t d) {
uint32_t d,
DType threshold_single,
DType threshold_acc) {
const uint32_t bx = blockIdx.x, tx = threadIdx.x;
extern __shared__ __align__(alignof(SamplingTempStorage<DType, BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM>))
@@ -70,9 +72,10 @@ __global__ void TreeSpeculativeSamplingTargetOnly(
while (cur_index != -1) {
IdType draft_index = retrive_index[bx * num_draft_tokens + cur_index];
IdType draft_token_id = candidates[bx * num_draft_tokens + cur_index];
prob_acc += target_probs[cur_prob_offset + draft_token_id];
DType target_prob_single = target_probs[cur_prob_offset + draft_token_id];
prob_acc += target_prob_single;
if (coin < prob_acc) {
if (coin <= prob_acc / threshold_acc || target_prob_single >= threshold_single) {
// accept token
prob_acc = 0.;
cur_prob_offset = (bx * num_draft_tokens + cur_index) * d;
@@ -169,7 +172,9 @@ cudaError_t TreeSpeculativeSamplingTargetOnly(
uint32_t num_speculative_tokens,
uint32_t num_draft_tokens,
uint32_t d,
bool deterministic,
DType threshold_single = 1,
DType threshold_acc = 1,
bool deterministic = true,
cudaStream_t stream = 0) {
constexpr uint32_t BLOCK_THREADS = 1024;
const uint32_t vec_size = std::gcd(16 / sizeof(DType), d);
@@ -177,6 +182,7 @@ cudaError_t TreeSpeculativeSamplingTargetOnly(
const uint32_t smem_size = sizeof(SamplingTempStorage<DType, BLOCK_THREADS, SCAN_ALGO, REDUCE_ALGO>);
dim3 nblks(batch_size);
dim3 nthrs(BLOCK_THREADS);
float capped_threshold_acc = fmaxf(threshold_acc, 1e-9f);
void* args[] = {
&predicts,
&output_token_ids,
@@ -191,7 +197,9 @@ cudaError_t TreeSpeculativeSamplingTargetOnly(
&batch_size,
&num_speculative_tokens,
&num_draft_tokens,
&d};
&d,
&threshold_single,
&capped_threshold_acc};
DISPATCH_ALIGNED_VEC_SIZE(
vec_size, VEC_SIZE, {DISPATCH_DETERMINISTIC(deterministic, DETERMINISTIC, {
auto kernel = TreeSpeculativeSamplingTargetOnly<