diff --git a/vllm_ascend/sample/rejection_sampler.py b/vllm_ascend/sample/rejection_sampler.py
index 9bd941fc..b7905373 100644
--- a/vllm_ascend/sample/rejection_sampler.py
+++ b/vllm_ascend/sample/rejection_sampler.py
@@ -215,7 +215,7 @@ def rejection_sample(
         target_probs,
         bonus_token_ids,
         recovered_token_ids,
-        uniform_probs,
+        uniform_probs.to(torch.float32),
         is_greedy,
         max_spec_len,
         vocab_size,
@@ -331,6 +331,9 @@ def sample_recovered_tokens(
             vocab_size,
             triton.next_power_of_2(vocab_size),
             NO_DRAFT_PROBS=draft_probs is None,
+            SUB_BLOCK=4 * 1024,
+            # TODO: enable multibuffer when accuracy problem is solved.
+            multibuffer=False,
         )
     else:
         sample_recovered_tokens_pytorch(
@@ -698,6 +701,7 @@ def sample_recovered_tokens_kernel(
     vocab_size,
     PADDED_VOCAB_SIZE: tl.constexpr,
     NO_DRAFT_PROBS: tl.constexpr,
+    SUB_BLOCK: tl.constexpr,
 ):
     req_idx = tl.program_id(0)
     start_idx = 0 if req_idx == 0 else tl.load(cu_num_draft_tokens_ptr +
@@ -705,42 +709,74 @@ def sample_recovered_tokens_kernel(
     end_idx = tl.load(cu_num_draft_tokens_ptr + req_idx)
     num_draft_tokens = end_idx - start_idx
 
-    # Early exit for out-of-range positions
+    # Early exit for out-of-range positions.
     pos = tl.program_id(1)
    if pos >= num_draft_tokens:
        return
 
-    vocab_offset = tl.arange(0, PADDED_VOCAB_SIZE)
+    loop = (vocab_size + SUB_BLOCK - 1) // SUB_BLOCK
+    global_recovered_id = -1
+    global_max_p = -1.0
     if NO_DRAFT_PROBS:
         draft_token_id = tl.load(draft_token_ids_ptr + start_idx + pos)
-        prob = tl.load(
-            target_probs_ptr + (start_idx + pos) * vocab_size + vocab_offset,
-            mask=((vocab_offset < vocab_size) &
-                  (vocab_offset != draft_token_id)),
-            other=0,
-        )
+        orig_prob = tl.load(target_probs_ptr + (start_idx + pos) * vocab_size +
+                            draft_token_id)
+        # Temporarily zero out the probability of the draft token.
+        # This is essentially the same as target_prob - draft_prob, except that
+        # n-gram does not have draft_prob. We regard it as 1.
+        tl.store(
+            target_probs_ptr + (start_idx + pos) * vocab_size + draft_token_id,
+            0)
+        for loop_i in range(loop):
+            vocab_start = loop_i * SUB_BLOCK
+            vocab_offset = vocab_start + tl.arange(0, SUB_BLOCK)
+            prob = tl.load(target_probs_ptr + (start_idx + pos) * vocab_size +
+                           vocab_offset,
+                           mask=vocab_offset < vocab_size,
+                           other=0)
+            q = tl.load(q_ptr + req_idx * vocab_size + vocab_offset,
+                        mask=vocab_offset < vocab_size,
+                        other=float("-inf"))
+            new_p = prob / q
+            recovered_id = tl.argmax(new_p, axis=-1)
+            max_p = tl.get_element(new_p, (recovered_id, ))
+            if max_p > global_max_p:
+                global_max_p = max_p
+                global_recovered_id = vocab_start + recovered_id
     else:
-        draft_prob = tl.load(
-            draft_probs_ptr + (start_idx + pos) * vocab_size + vocab_offset,
-            mask=vocab_offset < vocab_size,
-            other=0,
-        )
-        target_prob = tl.load(
-            target_probs_ptr + (start_idx + pos) * vocab_size + vocab_offset,
-            mask=vocab_offset < vocab_size,
-            other=0,
-        )
-        prob = tl.maximum(target_prob - draft_prob, 0)
-        # We don't need `prob = prob / tl.sum(prob)` here because
-        # `tl.argmax` will select the maximum value.
+        for loop_i in range(loop):
+            vocab_start = loop_i * SUB_BLOCK
+            vocab_offset = vocab_start + tl.arange(0, SUB_BLOCK)
+            draft_prob = tl.load(draft_probs_ptr +
+                                 (start_idx + pos) * vocab_size + vocab_offset,
+                                 mask=vocab_offset < vocab_size,
+                                 other=0)
+            target_prob = tl.load(target_probs_ptr +
+                                  (start_idx + pos) * vocab_size +
+                                  vocab_offset,
+                                  mask=vocab_offset < vocab_size,
+                                  other=0)
+            prob = tl.maximum(target_prob - draft_prob, 0)
+            # NOTE(woosuk): We don't need `prob = prob / tl.sum(prob)` here because
+            # `tl.argmax` will select the maximum value.
-
-    q = tl.load(
-        q_ptr + req_idx * vocab_size + vocab_offset,
-        mask=vocab_offset < vocab_size,
-        other=float("-inf"),
-    )
-    recovered_id = tl.argmax(prob / q, axis=-1)
-    tl.store(output_token_ids_ptr + start_idx + pos, recovered_id)
+            q = tl.load(q_ptr + req_idx * vocab_size + vocab_offset,
+                        mask=vocab_offset < vocab_size,
+                        other=float("-inf"))
+            new_p = prob / q
+            recovered_id = tl.argmax(new_p, axis=-1)
+            max_p = tl.get_element(new_p, (recovered_id, ))
+            if max_p > global_max_p:
+                global_max_p = max_p
+                global_recovered_id = vocab_start + recovered_id
+
+    tl.store(output_token_ids_ptr + start_idx + pos, global_recovered_id)
+
+    if NO_DRAFT_PROBS:
+        # Restore the original probability.
+        tl.store(
+            target_probs_ptr + (start_idx + pos) * vocab_size + draft_token_id,
+            orig_prob)
 
 
 rs.expand_batch_to_tokens = expand_batch_to_tokens