From 048c8d1afe5efd98a158c0879b61a14235f69265 Mon Sep 17 00:00:00 2001 From: bowenli <125331496+Bowen-Leee@users.noreply.github.com> Date: Fri, 27 Mar 2026 14:13:12 +0800 Subject: [PATCH] [v0.18.0][Bugfix] Fix the bug of MTP1 crashing in multiple concurrent scenarios. (#7699) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ### What this PR does / why we need it? The Triton operator does not bounds-check the global position inside the loop, which causes out-of-bounds memory accesses in multi-concurrency scenarios with 1-step MTP enabled. Solution: add a check that global_pos < vec_len inside the loop, so that every memory access stays within bounds and out-of-bounds writes are avoided. backport:#7459 Signed-off-by: Bowen-Leee --- vllm_ascend/ops/triton/reject_sample.py | 21 ++++++++++++--------- 1 file changed, 12 insertions(+), 9 deletions(-) diff --git a/vllm_ascend/ops/triton/reject_sample.py b/vllm_ascend/ops/triton/reject_sample.py index 4c85a430..5b8d3e2e 100644 --- a/vllm_ascend/ops/triton/reject_sample.py +++ b/vllm_ascend/ops/triton/reject_sample.py @@ -58,16 +58,19 @@ def rejection_greedy_sample_spec_len_1_triton( target_argmax_id = tl.load(target_argmax_ptr + offset, mask) tl.store(output_token_ids_ptr + offset * 2, target_argmax_id, mask) + # Add validity check for pos within the loop for pos in tl.range(0, BLOCK_SIZE): - draft_token_id1 = get_element(draft_token_id, (pos,)) - target_argmax1 = get_element(target_argmax_id, (pos,)) - position = block_idx * BLOCK_SIZE + pos - if draft_token_id1 == target_argmax1: - bonus_renew_1( - bonus_token_ids_ptr, - position, - output_token_ids_ptr, - ) + # Calculate the global position of the current token + global_pos = block_idx * BLOCK_SIZE + pos + if global_pos < vec_len: + draft_token_id1 = get_element(draft_token_id, (pos,)) + target_argmax1 = get_element(target_argmax_id, (pos,)) + if draft_token_id1 == target_argmax1: + bonus_renew_1( + bonus_token_ids_ptr, + global_pos, 
output_token_ids_ptr, + ) @triton.jit(do_not_specialize=["max_spec_len"])