diff --git a/vllm_ascend/ops/triton/reject_sample.py b/vllm_ascend/ops/triton/reject_sample.py index 4c85a430..5b8d3e2e 100644 --- a/vllm_ascend/ops/triton/reject_sample.py +++ b/vllm_ascend/ops/triton/reject_sample.py @@ -58,16 +58,19 @@ def rejection_greedy_sample_spec_len_1_triton( target_argmax_id = tl.load(target_argmax_ptr + offset, mask) tl.store(output_token_ids_ptr + offset * 2, target_argmax_id, mask) + # Only handle positions below vec_len so out-of-range slots in the block never reach bonus_renew_1. for pos in tl.range(0, BLOCK_SIZE): - draft_token_id1 = get_element(draft_token_id, (pos,)) - target_argmax1 = get_element(target_argmax_id, (pos,)) - position = block_idx * BLOCK_SIZE + pos - if draft_token_id1 == target_argmax1: - bonus_renew_1( - bonus_token_ids_ptr, - position, - output_token_ids_ptr, - ) + # Calculate the global position of the current token + global_pos = block_idx * BLOCK_SIZE + pos + if global_pos < vec_len: + draft_token_id1 = get_element(draft_token_id, (pos,)) + target_argmax1 = get_element(target_argmax_id, (pos,)) + if draft_token_id1 == target_argmax1: + bonus_renew_1( + bonus_token_ids_ptr, + global_pos, + output_token_ids_ptr, + ) @triton.jit(do_not_specialize=["max_spec_len"])