diff --git a/vllm_ascend/ops/triton/reject_sample.py b/vllm_ascend/ops/triton/reject_sample.py
index 5b8d3e2e..728ef706 100644
--- a/vllm_ascend/ops/triton/reject_sample.py
+++ b/vllm_ascend/ops/triton/reject_sample.py
@@ -107,7 +107,7 @@ def rejection_greedy_sample_triton(
 
     is_greedy = tl.load(is_greedy_ptr + offset, mask=mask, other=0)
     is_greedy_mask = mask & (is_greedy != 0)
-    start_idx = tl.where(offset == 0, 0, tl.load(cu_num_draft_tokens_ptr + offset - 1, is_greedy_mask))
+    start_idx = tl.where(offset == 0, 0, tl.load(cu_num_draft_tokens_ptr + tl.maximum(offset - 1, 0), is_greedy_mask))
     end_idx = tl.load(cu_num_draft_tokens_ptr + offset, is_greedy_mask)
     num_draft_tokens = end_idx - start_idx
 
@@ -161,7 +161,9 @@ def rejection_random_sample_kernel(
     mask = offsets < vec_len
     is_greedy = tl.load(is_greedy_ptr + offsets, mask, other=1)
    not_greedy_mask = is_greedy == 0
-    start_idxs = tl.where(offsets == 0, 0, tl.load(cu_num_draft_tokens_ptr + offsets - 1, not_greedy_mask))
+    start_idxs = tl.where(
+        offsets == 0, 0, tl.load(cu_num_draft_tokens_ptr + tl.maximum(offsets - 1, 0), not_greedy_mask)
+    )
     end_idxs = tl.load(cu_num_draft_tokens_ptr + offsets, not_greedy_mask)
     n_num_draft_tokens = end_idxs - start_idxs
     for req_i in range(BLOCK_SIZE):
@@ -174,21 +176,26 @@
         for pos in range(num_draft_tokens):
             if not rejected:
                 draft_token_id = tl.load(draft_token_ids_ptr + start_idx + pos)
-                if NO_DRAFT_PROBS:
-                    draft_prob = 1
-                else:
-                    draft_prob = tl.load(draft_probs_ptr + (start_idx + pos) * vocab_size + draft_token_id)
-                target_prob = tl.load(target_probs_ptr + (start_idx + pos) * vocab_size + draft_token_id)
-                uniform_prob = tl.load(uniform_probs_ptr + start_idx + pos)
-                # NOTE(woosuk): While the draft probability should never be 0,
-                # we check it to avoid NaNs. If it happens to be 0, we reject.
-                if draft_prob > 0 and target_prob / draft_prob >= uniform_prob:
-                    # Accept.
-                    token_id = draft_token_id
-                else:
-                    # Reject. Use recovered token.
+                if draft_token_id < 0:
+                    # Invalid draft (e.g., padded).
                     rejected = True
                     token_id = tl.load(recovered_token_ids_ptr + start_idx + pos)
+                else:
+                    if NO_DRAFT_PROBS:
+                        draft_prob = 1
+                    else:
+                        draft_prob = tl.load(draft_probs_ptr + (start_idx + pos) * vocab_size + draft_token_id)
+                    target_prob = tl.load(target_probs_ptr + (start_idx + pos) * vocab_size + draft_token_id)
+                    uniform_prob = tl.load(uniform_probs_ptr + start_idx + pos)
+                    # NOTE(woosuk): While the draft probability should never be 0,
+                    # we check it to avoid NaNs. If it happens to be 0, we reject.
+                    if draft_prob > 0 and target_prob / draft_prob >= uniform_prob:
+                        # Accept.
+                        token_id = draft_token_id
+                    else:
+                        # Reject. Use recovered token.
+                        rejected = True
+                        token_id = tl.load(recovered_token_ids_ptr + start_idx + pos)
                 tl.store(output_token_ids_ptr + req_idx * (max_spec_len + 1) + pos, token_id)
         if not rejected:
             # If all tokens are accepted, append the bonus token.
@@ -214,7 +221,7 @@ def expand_kernel(
 
     offset = req_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
     len_mask = offset < vec_len
-    start_idx = tl.where(offset == 0, 0, tl.load(cu_num_tokens_ptr + offset - 1, len_mask))
+    start_idx = tl.where(offset == 0, 0, tl.load(cu_num_tokens_ptr + tl.maximum(offset - 1, 0), len_mask))
     end_idx = tl.load(cu_num_tokens_ptr + offset, len_mask)
     num_tokens = end_idx - start_idx
 
@@ -257,11 +264,6 @@ def sample_recovered_tokens_kernel(
        global_max_p = -1.0
        if NO_DRAFT_PROBS:
            draft_token_id = tl.load(draft_token_ids_ptr + start_idx + pos)
-            orig_prob = tl.load(target_probs_ptr + (start_idx + pos) * vocab_size + draft_token_id)
-            # Temporarily zero out the probability of the draft token.
-            # This is essentially the same as target_prob - draft_prob, except that
-            # n-gram does not have draft_prob. We regard it as 1.
-            tl.store(target_probs_ptr + (start_idx + pos) * vocab_size + draft_token_id, 0)
        for loop_i in range(loop):
            vocab_start = loop_i * SUB_BLOCK
            vocab_offset = vocab_start + tl.arange(0, SUB_BLOCK)
@@ -270,6 +272,10 @@
                mask=vocab_offset < vocab_size,
                other=0,
            )
+            # Temporarily zero out the probability of the draft token.
+            # This is essentially the same as target_prob - draft_prob, except that
+            # n-gram does not have draft_prob. We regard it as 1.
+            prob = tl.where(vocab_offset == draft_token_id, 0, prob)
            q = tl.load(
                q_ptr + req_idx * vocab_size + vocab_offset, mask=vocab_offset < vocab_size, other=float("-inf")
            )
@@ -307,10 +313,6 @@
 
        tl.store(output_token_ids_ptr + start_idx + pos, global_recovered_id)
-
-        if NO_DRAFT_PROBS:
-            # Restore the original probability.
-            tl.store(target_probs_ptr + (start_idx + pos) * vocab_size + draft_token_id, orig_prob)
 
 
def rejection_greedy_sample_with_triton(
    output_token_ids,
@@ -387,7 +389,9 @@ def rejection_random_sample_block_verify_kernel(
     mask = offsets < vec_len
     is_greedy = tl.load(is_greedy_ptr + offsets, mask, other=1)
     not_greedy_mask = is_greedy == 0
-    start_idxs = tl.where(offsets == 0, 0, tl.load(cu_num_draft_tokens_ptr + offsets - 1, not_greedy_mask))
+    start_idxs = tl.where(
+        offsets == 0, 0, tl.load(cu_num_draft_tokens_ptr + tl.maximum(offsets - 1, 0), not_greedy_mask)
+    )
     end_idxs = tl.load(cu_num_draft_tokens_ptr + offsets, not_greedy_mask)
     n_num_draft_tokens = end_idxs - start_idxs
     for req_i in range(BLOCK_SIZE):
diff --git a/vllm_ascend/ops/triton/spec_decode/utils.py b/vllm_ascend/ops/triton/spec_decode/utils.py
index 3c7aa450..7ca40429 100644
--- a/vllm_ascend/ops/triton/spec_decode/utils.py
+++ b/vllm_ascend/ops/triton/spec_decode/utils.py
@@ -42,7 +42,7 @@ def prepare_inputs_padded_kernel(
 
    # cumulative sum (first entry is the first value, not zero).
    cu_draft_curr = tl.load(cu_num_draft_tokens_ptr + offsets, mask=mask)
-    prev_indices = offsets - 1
+    prev_indices = tl.maximum(offsets - 1, 0)
    has_prev = offsets > 0
    cu_draft_prev = tl.load(
        cu_num_draft_tokens_ptr + prev_indices,