From 696dcc92651cbdbd4b7bb8ec480a89b2195f3307 Mon Sep 17 00:00:00 2001
From: pppeng <60355449+ppppeng@users.noreply.github.com>
Date: Thu, 23 Apr 2026 23:04:19 +0800
Subject: [PATCH] [Bugfix][0.18.0] fix kernels in sample when mask is not
 static or draft_token_id is invalid (#8531)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

### What this PR does / why we need it?
The Triton kernels used in sampling run into two problems:

1. `expand_kernel`, `rejection_random_sample_kernel`, and `prepare_inputs_padded_kernel` load an exclusive prefix sum with `tl.load(ptr + offsets - 1, mask)`. The Triton compiler reports that the masks in these scenarios are not static and contiguous, so it accesses the memory first and applies the mask afterwards, which produces a read at offset -1 for the first element. The loads are changed to `tl.load(ptr + tl.maximum(offsets - 1, 0), mask)` so that no negative offset is ever formed (see the first sketch after the diff).
2. `sample_recovered_tokens_kernel` and `rejection_random_sample_kernel` use `draft_token_id` as an address offset in load operations. In the PD (prefill/decode) separation scenario the pad token is -1, so these loads and the matching stores can touch illegal memory. The kernels are changed to handle a -1 token explicitly (see the second sketch after the diff).

### Does this PR introduce _any_ user-facing change?

### How was this patch tested?

Signed-off-by: ppppeng
Co-authored-by: zepengliu912@qq.com
---
 vllm_ascend/ops/triton/reject_sample.py     | 56 +++++++++++----------
 vllm_ascend/ops/triton/spec_decode/utils.py |  2 +-
 2 files changed, 31 insertions(+), 27 deletions(-)

diff --git a/vllm_ascend/ops/triton/reject_sample.py b/vllm_ascend/ops/triton/reject_sample.py
index 5b8d3e2e..728ef706 100644
--- a/vllm_ascend/ops/triton/reject_sample.py
+++ b/vllm_ascend/ops/triton/reject_sample.py
@@ -107,7 +107,7 @@ def rejection_greedy_sample_triton(
         is_greedy = tl.load(is_greedy_ptr + offset, mask=mask, other=0)
         is_greedy_mask = mask & (is_greedy != 0)
 
-        start_idx = tl.where(offset == 0, 0, tl.load(cu_num_draft_tokens_ptr + offset - 1, is_greedy_mask))
+        start_idx = tl.where(offset == 0, 0, tl.load(cu_num_draft_tokens_ptr + tl.maximum(offset - 1, 0), is_greedy_mask))
         end_idx = tl.load(cu_num_draft_tokens_ptr + offset, is_greedy_mask)
         num_draft_tokens = end_idx - start_idx
 
@@ -161,7 +161,9 @@ def rejection_random_sample_kernel(
     mask = offsets < vec_len
     is_greedy = tl.load(is_greedy_ptr + offsets, mask, other=1)
     not_greedy_mask = is_greedy == 0
-    start_idxs = tl.where(offsets == 0, 0, tl.load(cu_num_draft_tokens_ptr + offsets - 1, not_greedy_mask))
+    start_idxs = tl.where(
+        offsets == 0, 0, tl.load(cu_num_draft_tokens_ptr + tl.maximum(offsets - 1, 0), not_greedy_mask)
+    )
     end_idxs = tl.load(cu_num_draft_tokens_ptr + offsets, not_greedy_mask)
     n_num_draft_tokens = end_idxs - start_idxs
     for req_i in range(BLOCK_SIZE):
@@ -174,21 +176,26 @@
                 for pos in range(num_draft_tokens):
                     if not rejected:
                         draft_token_id = tl.load(draft_token_ids_ptr + start_idx + pos)
-                        if NO_DRAFT_PROBS:
-                            draft_prob = 1
-                        else:
-                            draft_prob = tl.load(draft_probs_ptr + (start_idx + pos) * vocab_size + draft_token_id)
-                        target_prob = tl.load(target_probs_ptr + (start_idx + pos) * vocab_size + draft_token_id)
-                        uniform_prob = tl.load(uniform_probs_ptr + start_idx + pos)
-                        # NOTE(woosuk): While the draft probability should never be 0,
-                        # we check it to avoid NaNs. If it happens to be 0, we reject.
-                        if draft_prob > 0 and target_prob / draft_prob >= uniform_prob:
-                            # Accept.
-                            token_id = draft_token_id
-                        else:
-                            # Reject. Use recovered token.
+                        if draft_token_id < 0:
+                            # Invalid draft (e.g., padded).
                             rejected = True
                             token_id = tl.load(recovered_token_ids_ptr + start_idx + pos)
+                        else:
+                            if NO_DRAFT_PROBS:
+                                draft_prob = 1
+                            else:
+                                draft_prob = tl.load(draft_probs_ptr + (start_idx + pos) * vocab_size + draft_token_id)
+                            target_prob = tl.load(target_probs_ptr + (start_idx + pos) * vocab_size + draft_token_id)
+                            uniform_prob = tl.load(uniform_probs_ptr + start_idx + pos)
+                            # NOTE(woosuk): While the draft probability should never be 0,
+                            # we check it to avoid NaNs. If it happens to be 0, we reject.
+                            if draft_prob > 0 and target_prob / draft_prob >= uniform_prob:
+                                # Accept.
+                                token_id = draft_token_id
+                            else:
+                                # Reject. Use recovered token.
+                                rejected = True
+                                token_id = tl.load(recovered_token_ids_ptr + start_idx + pos)
                     tl.store(output_token_ids_ptr + req_idx * (max_spec_len + 1) + pos, token_id)
                 if not rejected:
                     # If all tokens are accepted, append the bonus token.
@@ -214,7 +221,7 @@ def expand_kernel(
     offset = req_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
     len_mask = offset < vec_len
 
-    start_idx = tl.where(offset == 0, 0, tl.load(cu_num_tokens_ptr + offset - 1, len_mask))
+    start_idx = tl.where(offset == 0, 0, tl.load(cu_num_tokens_ptr + tl.maximum(offset - 1, 0), len_mask))
     end_idx = tl.load(cu_num_tokens_ptr + offset, len_mask)
     num_tokens = end_idx - start_idx
 
@@ -257,11 +264,6 @@ def sample_recovered_tokens_kernel(
         global_max_p = -1.0
         if NO_DRAFT_PROBS:
             draft_token_id = tl.load(draft_token_ids_ptr + start_idx + pos)
-            orig_prob = tl.load(target_probs_ptr + (start_idx + pos) * vocab_size + draft_token_id)
-            # Temporarily zero out the probability of the draft token.
-            # This is essentially the same as target_prob - draft_prob, except that
-            # n-gram does not have draft_prob. We regard it as 1.
-            tl.store(target_probs_ptr + (start_idx + pos) * vocab_size + draft_token_id, 0)
         for loop_i in range(loop):
             vocab_start = loop_i * SUB_BLOCK
             vocab_offset = vocab_start + tl.arange(0, SUB_BLOCK)
@@ -270,6 +272,10 @@ def sample_recovered_tokens_kernel(
             prob = tl.load(
                 target_probs_ptr + (start_idx + pos) * vocab_size + vocab_offset,
                 mask=vocab_offset < vocab_size,
                 other=0,
             )
+            # Temporarily zero out the probability of the draft token.
+            # This is essentially the same as target_prob - draft_prob, except that
+            # n-gram does not have draft_prob. We regard it as 1.
+            prob = tl.where(vocab_offset == draft_token_id, 0, prob)
             q = tl.load(
                 q_ptr + req_idx * vocab_size + vocab_offset, mask=vocab_offset < vocab_size, other=float("-inf")
             )
@@ -307,10 +313,6 @@ def sample_recovered_tokens_kernel(
 
         tl.store(output_token_ids_ptr + start_idx + pos, global_recovered_id)
 
-        if NO_DRAFT_PROBS:
-            # Restore the original probability.
-            tl.store(target_probs_ptr + (start_idx + pos) * vocab_size + draft_token_id, orig_prob)
-
 
 def rejection_greedy_sample_with_triton(
     output_token_ids,
@@ -387,7 +389,9 @@ def rejection_random_sample_block_verify_kernel(
     mask = offsets < vec_len
     is_greedy = tl.load(is_greedy_ptr + offsets, mask, other=1)
     not_greedy_mask = is_greedy == 0
-    start_idxs = tl.where(offsets == 0, 0, tl.load(cu_num_draft_tokens_ptr + offsets - 1, not_greedy_mask))
+    start_idxs = tl.where(
+        offsets == 0, 0, tl.load(cu_num_draft_tokens_ptr + tl.maximum(offsets - 1, 0), not_greedy_mask)
+    )
     end_idxs = tl.load(cu_num_draft_tokens_ptr + offsets, not_greedy_mask)
     n_num_draft_tokens = end_idxs - start_idxs
     for req_i in range(BLOCK_SIZE):
diff --git a/vllm_ascend/ops/triton/spec_decode/utils.py b/vllm_ascend/ops/triton/spec_decode/utils.py
index 3c7aa450..7ca40429 100644
--- a/vllm_ascend/ops/triton/spec_decode/utils.py
+++ b/vllm_ascend/ops/triton/spec_decode/utils.py
@@ -42,7 +42,7 @@ def prepare_inputs_padded_kernel(
     # cumulative sum (first entry is the first value, not zero).
     cu_draft_curr = tl.load(cu_num_draft_tokens_ptr + offsets, mask=mask)
 
-    prev_indices = offsets - 1
+    prev_indices = tl.maximum(offsets - 1, 0)
    has_prev = offsets > 0
    cu_draft_prev = tl.load(
        cu_num_draft_tokens_ptr + prev_indices,
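First sketch (fix 1): a minimal standalone Triton kernel, not taken from the patch, showing why the clamped offset is safe when computing exclusive prefix starts from a cumulative sum. The kernel name `exclusive_start_kernel`, the tiny launch, and the `"npu"` device are illustrative assumptions (use `"cuda"` on GPU Triton).

```python
import torch
import triton
import triton.language as tl


@triton.jit
def exclusive_start_kernel(cu_sums_ptr, starts_ptr, n, BLOCK: tl.constexpr):
    # starts[i] = cu_sums[i - 1] for i > 0, else 0 (exclusive prefix starts).
    offsets = tl.arange(0, BLOCK)
    mask = offsets < n
    # Unsafe variant: tl.load(cu_sums_ptr + offsets - 1, mask) can still form
    # the address cu_sums_ptr - 1 for lane 0 before the mask takes effect.
    # Safe variant: clamp the index so every generated address is in bounds,
    # then let tl.where supply the correct value (0) for the first lane.
    prev = tl.load(cu_sums_ptr + tl.maximum(offsets - 1, 0), mask)
    starts = tl.where(offsets == 0, 0, prev)
    tl.store(starts_ptr + offsets, starts, mask=mask)


cu_sums = torch.tensor([2, 5, 9], dtype=torch.int32, device="npu")
starts = torch.empty_like(cu_sums)
exclusive_start_kernel[(1,)](cu_sums, starts, cu_sums.numel(), BLOCK=4)
# starts == [0, 2, 5]
```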
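Second sketch (fix 2): the patch guards invalid ids with an explicit `if draft_token_id < 0:` branch; an equivalent branch-free variant keeps a -1 pad id out of the address computation by clamping it for the load and discarding the garbage value afterwards. The kernel name `gather_token_prob_kernel` is an illustrative assumption, not a function from the patch.

```python
import triton
import triton.language as tl


@triton.jit
def gather_token_prob_kernel(token_ids_ptr, probs_ptr, out_ptr, vocab_size):
    # One program per token slot; launch with grid = (num_tokens,).
    i = tl.program_id(0)
    tok = tl.load(token_ids_ptr + i)
    # A pad id of -1 must never reach the address computation: clamp it to a
    # valid vocab index for the load, then replace whatever that slot held.
    safe_tok = tl.maximum(tok, 0)
    p = tl.load(probs_ptr + i * vocab_size + safe_tok)
    p = tl.where(tok < 0, 0.0, p)
    tl.store(out_ptr + i, p)
```

Both forms touch only valid memory; the branchy form used in the patch additionally skips the accept/reject arithmetic for padded slots.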