[v0.18.0][Bugfix] Fix MTP1 crash in high-concurrency scenarios. (#7699)

### What this PR does / why we need it?
The Triton kernel does not bounds-check the token's global position
inside its loop, causing out-of-bounds memory writes under high
concurrency with 1-step MTP enabled.

Solution: add a `global_pos < vec_len` check inside the loop so that
every memory access is strictly bounded, preventing out-of-bounds
writes.

Backport of #7459.
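The failure mode and the guard can be illustrated with a plain-Python sketch of the per-block loop. This is a simplified model, not the actual Triton kernel: `out`, `process_block_unsafe`, and `process_block_safe` are hypothetical stand-ins, with `BLOCK_SIZE` and `vec_len` mirroring the kernel's parameters.

```python
# Simplified model of the block loop fixed in this PR.
BLOCK_SIZE = 4

def process_block_unsafe(block_idx, out):
    # Pre-fix: every lane in the block writes, even lanes whose global
    # position falls past the end of the vector (silent memory
    # corruption on the device; an IndexError in this Python model).
    for pos in range(BLOCK_SIZE):
        position = block_idx * BLOCK_SIZE + pos
        out[position] = 1

def process_block_safe(block_idx, vec_len, out):
    # Post-fix: bound every access with the global-position check.
    for pos in range(BLOCK_SIZE):
        global_pos = block_idx * BLOCK_SIZE + pos
        if global_pos < vec_len:
            out[global_pos] = 1

vec_len = 6                              # not a multiple of BLOCK_SIZE
num_blocks = -(-vec_len // BLOCK_SIZE)   # ceil(6 / 4) == 2 blocks

out = [0] * vec_len
for b in range(num_blocks):
    process_block_safe(b, vec_len, out)
print(out)  # every element written exactly once, none out of range

try:
    for b in range(num_blocks):
        process_block_unsafe(b, [0] * vec_len)
except IndexError as e:
    print("unsafe version overflows:", e)
```

The bug only manifests when `vec_len` is not a multiple of `BLOCK_SIZE`, which is why it surfaced under concurrent requests rather than in single-request tests.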

Signed-off-by: Bowen-Leee <caoshankuangren@gmail.com>
Author: bowenli (committed via GitHub)
Date: 2026-03-27 14:13:12 +08:00
Parent: 6ce1dc162a
Commit: 048c8d1afe


```diff
@@ -58,16 +58,19 @@ def rejection_greedy_sample_spec_len_1_triton(
     target_argmax_id = tl.load(target_argmax_ptr + offset, mask)
     tl.store(output_token_ids_ptr + offset * 2, target_argmax_id, mask)
+    # Add validity check for pos within the loop
     for pos in tl.range(0, BLOCK_SIZE):
-        draft_token_id1 = get_element(draft_token_id, (pos,))
-        target_argmax1 = get_element(target_argmax_id, (pos,))
-        position = block_idx * BLOCK_SIZE + pos
-        if draft_token_id1 == target_argmax1:
-            bonus_renew_1(
-                bonus_token_ids_ptr,
-                position,
-                output_token_ids_ptr,
-            )
+        # Calculate the global position of the current token
+        global_pos = block_idx * BLOCK_SIZE + pos
+        if global_pos < vec_len:
+            draft_token_id1 = get_element(draft_token_id, (pos,))
+            target_argmax1 = get_element(target_argmax_id, (pos,))
+            if draft_token_id1 == target_argmax1:
+                bonus_renew_1(
+                    bonus_token_ids_ptr,
+                    global_pos,
+                    output_token_ids_ptr,
+                )
 @triton.jit(do_not_specialize=["max_spec_len"])
```