[v0.18.0][Bugfix] Fix MTP1 crash in multi-concurrency scenarios (#7699)
### What this PR does / why we need it? The Triton operator does not bounds-check the global position inside the loop, which leads to a memory overflow in multi-concurrency scenarios with a 1-step MTP launch. Solution: add a check that `global_pos < vec_len`, and strictly limit the boundaries of all memory accesses to avoid out-of-bounds writes. Backport of #7459. Signed-off-by: Bowen-Leee <caoshankuangren@gmail.com>
This commit is contained in:
@@ -58,16 +58,19 @@ def rejection_greedy_sample_spec_len_1_triton(
|
||||
target_argmax_id = tl.load(target_argmax_ptr + offset, mask)
|
||||
tl.store(output_token_ids_ptr + offset * 2, target_argmax_id, mask)
|
||||
|
||||
# Add validity check for pos within the loop
|
||||
for pos in tl.range(0, BLOCK_SIZE):
|
||||
draft_token_id1 = get_element(draft_token_id, (pos,))
|
||||
target_argmax1 = get_element(target_argmax_id, (pos,))
|
||||
position = block_idx * BLOCK_SIZE + pos
|
||||
if draft_token_id1 == target_argmax1:
|
||||
bonus_renew_1(
|
||||
bonus_token_ids_ptr,
|
||||
position,
|
||||
output_token_ids_ptr,
|
||||
)
|
||||
# Calculate the global position of the current token
|
||||
global_pos = block_idx * BLOCK_SIZE + pos
|
||||
if global_pos < vec_len:
|
||||
draft_token_id1 = get_element(draft_token_id, (pos,))
|
||||
target_argmax1 = get_element(target_argmax_id, (pos,))
|
||||
if draft_token_id1 == target_argmax1:
|
||||
bonus_renew_1(
|
||||
bonus_token_ids_ptr,
|
||||
global_pos,
|
||||
output_token_ids_ptr,
|
||||
)
|
||||
|
||||
|
||||
@triton.jit(do_not_specialize=["max_spec_len"])
|
||||
|
||||
Reference in New Issue
Block a user