[Bugfix][0.18.0] fix kernels in sample when mask is not static or draft_token_id is invalid (#8531)
<!-- Thanks for sending a pull request! BEFORE SUBMITTING, PLEASE READ https://docs.vllm.ai/en/latest/contributing/overview.html --> ### What this PR does / why we need it? <!-- - Please clarify what changes you are proposing. The purpose of this section is to outline the changes and how this PR fixes the issue. If possible, please consider writing useful notes for better and faster reviews in your PR. - Please clarify why the changes are needed. For instance, the use case and bug description. - Fixes # --> The triton kernels in sample encounter some problems, scenarios are shown below: 1. 【expand_kernel/ rejection_random_sample_kernel/ prepare_inputs_padded_kernel】, these three operations will use ‘tl.load(prt + offsets -1, mask)’ in their implementations, but triton compiler reports that the masks in these scenarios are not static and contiguous. As a result, compiler will first access this memory and apply the mask. Therefore, I modified the code to ‘tl.load(prt +tl.maximum(offsets - 1, 0), mask)’ to ensure no -1 reads. 2. 【sample_recovered_tokens_kernel/ rejection_random_sample_kernel】, this kernel uses draft_token_id as an address offset for the load operation. In the PD separation scenario, if the pad token is -1, illegal memory reads and writes can occur. Therefore, i modified the kernel and so they can do well with -1 token. ### Does this PR introduce _any_ user-facing change? <!-- Note that it means *any* user-facing change including all aspects such as API, interface or other behavior changes. Documentation-only updates are not considered user-facing changes. --> ### How was this patch tested? <!-- CI passed with new added/existing test. If it was tested in a way different from regular unit tests, please clarify how you tested step by step, ideally copy and paste-able, so that other reviewers can test and check, and descendants can verify in the future. If tests were not added, please describe why they were not added and/or why it was difficult to add. --> Signed-off-by: ppppeng <zepengliu912@qq.com> Co-authored-by: zepengliu912@qq.com <root@localhost.localdomain>
This commit is contained in:
@@ -107,7 +107,7 @@ def rejection_greedy_sample_triton(
|
||||
is_greedy = tl.load(is_greedy_ptr + offset, mask=mask, other=0)
|
||||
is_greedy_mask = mask & (is_greedy != 0)
|
||||
|
||||
start_idx = tl.where(offset == 0, 0, tl.load(cu_num_draft_tokens_ptr + offset - 1, is_greedy_mask))
|
||||
start_idx = tl.where(offset == 0, 0, tl.load(cu_num_draft_tokens_ptr + tl.maximum(offset - 1, 0), is_greedy_mask))
|
||||
end_idx = tl.load(cu_num_draft_tokens_ptr + offset, is_greedy_mask)
|
||||
num_draft_tokens = end_idx - start_idx
|
||||
|
||||
@@ -161,7 +161,9 @@ def rejection_random_sample_kernel(
|
||||
mask = offsets < vec_len
|
||||
is_greedy = tl.load(is_greedy_ptr + offsets, mask, other=1)
|
||||
not_greedy_mask = is_greedy == 0
|
||||
start_idxs = tl.where(offsets == 0, 0, tl.load(cu_num_draft_tokens_ptr + offsets - 1, not_greedy_mask))
|
||||
start_idxs = tl.where(
|
||||
offsets == 0, 0, tl.load(cu_num_draft_tokens_ptr + tl.maximum(offsets - 1, 0), not_greedy_mask)
|
||||
)
|
||||
end_idxs = tl.load(cu_num_draft_tokens_ptr + offsets, not_greedy_mask)
|
||||
n_num_draft_tokens = end_idxs - start_idxs
|
||||
for req_i in range(BLOCK_SIZE):
|
||||
@@ -174,21 +176,26 @@ def rejection_random_sample_kernel(
|
||||
for pos in range(num_draft_tokens):
|
||||
if not rejected:
|
||||
draft_token_id = tl.load(draft_token_ids_ptr + start_idx + pos)
|
||||
if NO_DRAFT_PROBS:
|
||||
draft_prob = 1
|
||||
else:
|
||||
draft_prob = tl.load(draft_probs_ptr + (start_idx + pos) * vocab_size + draft_token_id)
|
||||
target_prob = tl.load(target_probs_ptr + (start_idx + pos) * vocab_size + draft_token_id)
|
||||
uniform_prob = tl.load(uniform_probs_ptr + start_idx + pos)
|
||||
# NOTE(woosuk): While the draft probability should never be 0,
|
||||
# we check it to avoid NaNs. If it happens to be 0, we reject.
|
||||
if draft_prob > 0 and target_prob / draft_prob >= uniform_prob:
|
||||
# Accept.
|
||||
token_id = draft_token_id
|
||||
else:
|
||||
# Reject. Use recovered token.
|
||||
if draft_token_id < 0:
|
||||
# Invalid draft (e.g., padded).
|
||||
rejected = True
|
||||
token_id = tl.load(recovered_token_ids_ptr + start_idx + pos)
|
||||
else:
|
||||
if NO_DRAFT_PROBS:
|
||||
draft_prob = 1
|
||||
else:
|
||||
draft_prob = tl.load(draft_probs_ptr + (start_idx + pos) * vocab_size + draft_token_id)
|
||||
target_prob = tl.load(target_probs_ptr + (start_idx + pos) * vocab_size + draft_token_id)
|
||||
uniform_prob = tl.load(uniform_probs_ptr + start_idx + pos)
|
||||
# NOTE(woosuk): While the draft probability should never be 0,
|
||||
# we check it to avoid NaNs. If it happens to be 0, we reject.
|
||||
if draft_prob > 0 and target_prob / draft_prob >= uniform_prob:
|
||||
# Accept.
|
||||
token_id = draft_token_id
|
||||
else:
|
||||
# Reject. Use recovered token.
|
||||
rejected = True
|
||||
token_id = tl.load(recovered_token_ids_ptr + start_idx + pos)
|
||||
tl.store(output_token_ids_ptr + req_idx * (max_spec_len + 1) + pos, token_id)
|
||||
if not rejected:
|
||||
# If all tokens are accepted, append the bonus token.
|
||||
@@ -214,7 +221,7 @@ def expand_kernel(
|
||||
offset = req_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
|
||||
len_mask = offset < vec_len
|
||||
|
||||
start_idx = tl.where(offset == 0, 0, tl.load(cu_num_tokens_ptr + offset - 1, len_mask))
|
||||
start_idx = tl.where(offset == 0, 0, tl.load(cu_num_tokens_ptr + tl.maximum(offset - 1, 0), len_mask))
|
||||
end_idx = tl.load(cu_num_tokens_ptr + offset, len_mask)
|
||||
num_tokens = end_idx - start_idx
|
||||
|
||||
@@ -257,11 +264,6 @@ def sample_recovered_tokens_kernel(
|
||||
global_max_p = -1.0
|
||||
if NO_DRAFT_PROBS:
|
||||
draft_token_id = tl.load(draft_token_ids_ptr + start_idx + pos)
|
||||
orig_prob = tl.load(target_probs_ptr + (start_idx + pos) * vocab_size + draft_token_id)
|
||||
# Temporarily zero out the probability of the draft token.
|
||||
# This is essentially the same as target_prob - draft_prob, except that
|
||||
# n-gram does not have draft_prob. We regard it as 1.
|
||||
tl.store(target_probs_ptr + (start_idx + pos) * vocab_size + draft_token_id, 0)
|
||||
for loop_i in range(loop):
|
||||
vocab_start = loop_i * SUB_BLOCK
|
||||
vocab_offset = vocab_start + tl.arange(0, SUB_BLOCK)
|
||||
@@ -270,6 +272,10 @@ def sample_recovered_tokens_kernel(
|
||||
mask=vocab_offset < vocab_size,
|
||||
other=0,
|
||||
)
|
||||
# Temporarily zero out the probability of the draft token.
|
||||
# This is essentially the same as target_prob - draft_prob, except that
|
||||
# n-gram does not have draft_prob. We regard it as 1.
|
||||
prob = tl.where(vocab_offset == draft_token_id, 0, prob)
|
||||
q = tl.load(
|
||||
q_ptr + req_idx * vocab_size + vocab_offset, mask=vocab_offset < vocab_size, other=float("-inf")
|
||||
)
|
||||
@@ -307,10 +313,6 @@ def sample_recovered_tokens_kernel(
|
||||
|
||||
tl.store(output_token_ids_ptr + start_idx + pos, global_recovered_id)
|
||||
|
||||
if NO_DRAFT_PROBS:
|
||||
# Restore the original probability.
|
||||
tl.store(target_probs_ptr + (start_idx + pos) * vocab_size + draft_token_id, orig_prob)
|
||||
|
||||
|
||||
def rejection_greedy_sample_with_triton(
|
||||
output_token_ids,
|
||||
@@ -387,7 +389,9 @@ def rejection_random_sample_block_verify_kernel(
|
||||
mask = offsets < vec_len
|
||||
is_greedy = tl.load(is_greedy_ptr + offsets, mask, other=1)
|
||||
not_greedy_mask = is_greedy == 0
|
||||
start_idxs = tl.where(offsets == 0, 0, tl.load(cu_num_draft_tokens_ptr + offsets - 1, not_greedy_mask))
|
||||
start_idxs = tl.where(
|
||||
offsets == 0, 0, tl.load(cu_num_draft_tokens_ptr + tl.maximum(offsets - 1, 0), not_greedy_mask)
|
||||
)
|
||||
end_idxs = tl.load(cu_num_draft_tokens_ptr + offsets, not_greedy_mask)
|
||||
n_num_draft_tokens = end_idxs - start_idxs
|
||||
for req_i in range(BLOCK_SIZE):
|
||||
|
||||
@@ -42,7 +42,7 @@ def prepare_inputs_padded_kernel(
|
||||
# cumulative sum (first entry is the first value, not zero).
|
||||
cu_draft_curr = tl.load(cu_num_draft_tokens_ptr + offsets, mask=mask)
|
||||
|
||||
prev_indices = offsets - 1
|
||||
prev_indices = tl.maximum(offsets - 1, 0)
|
||||
has_prev = offsets > 0
|
||||
cu_draft_prev = tl.load(
|
||||
cu_num_draft_tokens_ptr + prev_indices,
|
||||
|
||||
Reference in New Issue
Block a user