Add greedy verification kernel (#4383)
This commit is contained in:
@@ -49,7 +49,9 @@ __global__ void TreeSpeculativeSamplingTargetOnly(
|
||||
uint32_t batch_size,
|
||||
uint32_t num_speculative_tokens,
|
||||
uint32_t num_draft_tokens,
|
||||
uint32_t d) {
|
||||
uint32_t d,
|
||||
DType threshold_single,
|
||||
DType threshold_acc) {
|
||||
const uint32_t bx = blockIdx.x, tx = threadIdx.x;
|
||||
|
||||
extern __shared__ __align__(alignof(SamplingTempStorage<DType, BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM>))
|
||||
@@ -70,9 +72,10 @@ __global__ void TreeSpeculativeSamplingTargetOnly(
|
||||
while (cur_index != -1) {
|
||||
IdType draft_index = retrive_index[bx * num_draft_tokens + cur_index];
|
||||
IdType draft_token_id = candidates[bx * num_draft_tokens + cur_index];
|
||||
prob_acc += target_probs[cur_prob_offset + draft_token_id];
|
||||
DType target_prob_single = target_probs[cur_prob_offset + draft_token_id];
|
||||
prob_acc += target_prob_single;
|
||||
|
||||
if (coin < prob_acc) {
|
||||
if (coin <= prob_acc / threshold_acc || target_prob_single >= threshold_single) {
|
||||
// accept token
|
||||
prob_acc = 0.;
|
||||
cur_prob_offset = (bx * num_draft_tokens + cur_index) * d;
|
||||
@@ -169,7 +172,9 @@ cudaError_t TreeSpeculativeSamplingTargetOnly(
|
||||
uint32_t num_speculative_tokens,
|
||||
uint32_t num_draft_tokens,
|
||||
uint32_t d,
|
||||
bool deterministic,
|
||||
DType threshold_single = 1,
|
||||
DType threshold_acc = 1,
|
||||
bool deterministic = true,
|
||||
cudaStream_t stream = 0) {
|
||||
constexpr uint32_t BLOCK_THREADS = 1024;
|
||||
const uint32_t vec_size = std::gcd(16 / sizeof(DType), d);
|
||||
@@ -177,6 +182,7 @@ cudaError_t TreeSpeculativeSamplingTargetOnly(
|
||||
const uint32_t smem_size = sizeof(SamplingTempStorage<DType, BLOCK_THREADS, SCAN_ALGO, REDUCE_ALGO>);
|
||||
dim3 nblks(batch_size);
|
||||
dim3 nthrs(BLOCK_THREADS);
|
||||
float capped_threshold_acc = fmaxf(threshold_acc, 1e-9f);
|
||||
void* args[] = {
|
||||
&predicts,
|
||||
&output_token_ids,
|
||||
@@ -191,7 +197,9 @@ cudaError_t TreeSpeculativeSamplingTargetOnly(
|
||||
&batch_size,
|
||||
&num_speculative_tokens,
|
||||
&num_draft_tokens,
|
||||
&d};
|
||||
&d,
|
||||
&threshold_single,
|
||||
&capped_threshold_acc};
|
||||
DISPATCH_ALIGNED_VEC_SIZE(
|
||||
vec_size, VEC_SIZE, {DISPATCH_DETERMINISTIC(deterministic, DETERMINISTIC, {
|
||||
auto kernel = TreeSpeculativeSamplingTargetOnly<
|
||||
|
||||
Reference in New Issue
Block a user