[Refactor][Triton] Move reject sample triton kernels into ops/triton (#5324)

### What this PR does / why we need it?
This PR moves reject sample related triton kernels into `ops/triton`.

### Does this PR introduce _any_ user-facing change?
No

### How was this patch tested?
CI passed with existing test.


- vLLM version: release/v0.13.0
- vLLM main:
5fbfa8d9ef

---------

Signed-off-by: whx-sjtu <2952154980@qq.com>
This commit is contained in:
whx
2025-12-29 16:15:41 +08:00
committed by GitHub
parent e7e1a7dc05
commit 28b7614322
3 changed files with 403 additions and 356 deletions

View File

@@ -165,14 +165,13 @@ class TestAscendRejectionSampler(TestBase):
# Test Triton kernel path
with patch("vllm_ascend.sample.rejection_sampler.HAS_TRITON", True):
with patch("vllm_ascend.sample.rejection_sampler.expand_kernel"
with patch("vllm_ascend.sample.rejection_sampler.expand_triton"
) as mock_triton:
expand_batch_to_tokens(x, cu_num_tokens, num_tokens)
# grid = triton.cdiv(n, BLOCK_SIZE) = triton.cdiv(3, 2) = 2
mock_triton.__getitem__.assert_called_once_with((2, ))
call_args = mock_triton.__getitem__.return_value.call_args[0]
assert (call_args[1] == x).all()
assert (call_args[2] == cu_num_tokens).all()
mock_triton.assert_called_once()
call_args = mock_triton.call_args[0]
assert (call_args[2] == x).all()
assert (call_args[3] == cu_num_tokens).all()
# Run actual function
with patch("vllm_ascend.sample.rejection_sampler.HAS_TRITON", False):

View File

@@ -0,0 +1,377 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# This file is a part of the vllm-ascend project.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from vllm.triton_utils import tl, triton
from vllm_ascend.ops.triton.triton_utils import get_vectorcore_num
@triton.jit(do_not_specialize=["max_spec_len"])
def bonus_renew_1(
bonus_token_ids_ptr,
position,
output_token_ids_ptr,
):
bonus_token_id = tl.load(bonus_token_ids_ptr + position)
tl.store(output_token_ids_ptr + position * 2 + 1, bonus_token_id)
@triton.jit(do_not_specialize=["max_spec_len"])
def rejection_greedy_sample_spec_len_1_triton(
output_token_ids_ptr, # [batch_size, 2]
draft_token_ids_ptr, # [num_tokens]
target_argmax_ptr, # [num_tokens]
bonus_token_ids_ptr,
vec_len,
BLOCK_SIZE: tl.constexpr,
):
block_idx = tl.program_id(0)
offset = block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
mask = offset < vec_len
draft_token_id = tl.load(draft_token_ids_ptr + offset, mask)
target_argmax_id = tl.load(target_argmax_ptr + offset, mask)
tl.store(output_token_ids_ptr + offset * 2, target_argmax_id, mask)
for pos in tl.range(0, BLOCK_SIZE):
draft_token_id1 = tl.get_element(draft_token_id, (pos, ))
target_argmax1 = tl.get_element(target_argmax_id, (pos, ))
position = block_idx * BLOCK_SIZE + pos
if draft_token_id1 == target_argmax1:
bonus_renew_1(
bonus_token_ids_ptr,
position,
output_token_ids_ptr,
)
@triton.jit(do_not_specialize=["max_spec_len"])
def bonus_renew(
bonus_token_ids_ptr,
position,
output_token_ids_ptr,
max_spec_len,
num_tokens1,
):
bonus_token_id = tl.load(bonus_token_ids_ptr + position)
tl.store(
output_token_ids_ptr + position * (max_spec_len + 1) + num_tokens1,
bonus_token_id)
@triton.jit(do_not_specialize=["max_spec_len"])
def rejection_greedy_sample_triton(
output_token_ids_ptr, # [batch_size, max_spec_len + 1]
cu_num_draft_tokens_ptr, # [batch_size]
draft_token_ids_ptr, # [num_tokens]
target_argmax_ptr, # [num_tokens]
bonus_token_ids_ptr, # [batch_size]
is_greedy_ptr, # [batch_size] or None
vec_len,
max_spec_len,
BLOCK_SIZE: tl.constexpr,
):
block_idx = tl.program_id(0)
offset = block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
mask = offset < vec_len
if is_greedy_ptr is None:
is_greedy_mask = mask
else:
is_greedy = tl.load(is_greedy_ptr + offset, mask=mask, other=0)
is_greedy_mask = mask & (is_greedy != 0)
start_idx = tl.where(
offset == 0, 0,
tl.load(cu_num_draft_tokens_ptr + offset - 1, is_greedy_mask))
end_idx = tl.load(cu_num_draft_tokens_ptr + offset, is_greedy_mask)
num_draft_tokens = end_idx - start_idx
for pos in tl.range(0, BLOCK_SIZE):
num_tokens1 = tl.get_element(num_draft_tokens, (pos, ))
rejected = False
start_idx1 = tl.get_element(start_idx, (pos, ))
is_greedy_mask1 = tl.get_element(is_greedy_mask, (pos, ))
position = block_idx * BLOCK_SIZE + pos
for i in range(num_tokens1):
if not rejected:
draft_token_id = tl.load(draft_token_ids_ptr + start_idx1 + i)
target_argmax_id = tl.load(target_argmax_ptr + start_idx1 + i)
tl.store(
output_token_ids_ptr + position * (max_spec_len + 1) + i,
target_argmax_id,
)
if draft_token_id != target_argmax_id:
# Reject.
rejected = True
if not rejected and is_greedy_mask1:
bonus_renew(
bonus_token_ids_ptr,
position,
output_token_ids_ptr,
max_spec_len,
num_tokens1,
)
@triton.jit(do_not_specialize=["max_spec_len"])
def rejection_random_sample_kernel(
output_token_ids_ptr, # [batch_size, max_spec_len + 1]
cu_num_draft_tokens_ptr, # [batch_size]
draft_token_ids_ptr, # [num_tokens]
draft_probs_ptr, # [num_tokens, vocab_size] or None
target_probs_ptr, # [num_tokens, vocab_size]
bonus_token_ids_ptr, # [batch_size]
recovered_token_ids_ptr, # [num_tokens]
uniform_probs_ptr, # [num_tokens]
is_greedy_ptr, # [batch_size]
max_spec_len,
vocab_size,
NO_DRAFT_PROBS: tl.constexpr,
):
req_idx = tl.program_id(0)
is_greedy = tl.load(is_greedy_ptr + req_idx)
if is_greedy:
# Early exost for greedy sampling requests
return
start_idx = 0 if req_idx == 0 else tl.load(cu_num_draft_tokens_ptr +
req_idx - 1)
end_idx = tl.load(cu_num_draft_tokens_ptr + req_idx)
num_draft_tokens = end_idx - start_idx
rejected = False
for pos in range(num_draft_tokens):
if not rejected:
draft_token_id = tl.load(draft_token_ids_ptr + start_idx + pos)
if NO_DRAFT_PROBS:
draft_prob = 1
else:
draft_prob = tl.load(draft_probs_ptr +
(start_idx + pos) * vocab_size +
draft_token_id)
target_prob = tl.load(target_probs_ptr +
(start_idx + pos) * vocab_size +
draft_token_id)
uniform_prob = tl.load(uniform_probs_ptr + start_idx + pos)
if draft_prob > 0 and target_prob / draft_prob >= uniform_prob:
# Accept
token_id = draft_token_id
else:
# Reject. Use recovered token
rejected = True
token_id = tl.load(recovered_token_ids_ptr + start_idx + pos)
tl.store(output_token_ids_ptr + req_idx * (max_spec_len + 1) + pos,
token_id)
if not rejected:
# If all tokens are accepted, append the bonus token
bonus_token_id = tl.load(bonus_token_ids_ptr + req_idx)
tl.store(
output_token_ids_ptr + req_idx * (max_spec_len + 1) +
num_draft_tokens,
bonus_token_id,
)
@triton.jit(do_not_specialize=["replace_from", "replace_to"])
def expand_kernel(
output_ptr, # [num_tokens]
input_ptr, # [batch_size]
cu_num_tokens_ptr, # [batch_size]
replace_from,
replace_to,
vec_len,
MAX_NUM_TOKENS: tl.constexpr,
BLOCK_SIZE: tl.constexpr,
):
req_idx = tl.program_id(0)
offset = req_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
len_mask = offset < vec_len
start_idx = tl.where(offset == 0, 0,
tl.load(cu_num_tokens_ptr + offset - 1, len_mask))
end_idx = tl.load(cu_num_tokens_ptr + offset, len_mask)
num_tokens = end_idx - start_idx
src_val = tl.load(input_ptr + offset, len_mask)
src_val = tl.where(src_val == replace_from, replace_to, src_val)
for i in tl.range(0, BLOCK_SIZE):
num_tokens1 = tl.get_element(num_tokens, (i, ))
start_idx1 = tl.get_element(start_idx, (i, ))
src_val1 = tl.get_element(src_val, (i, ))
offset1 = tl.arange(0, MAX_NUM_TOKENS)
tl.store(output_ptr + start_idx1 + offset1,
src_val1,
mask=offset1 < num_tokens1)
@triton.jit
def sample_recovered_tokens_kernel(
output_token_ids_ptr, # [num_tokens]
cu_num_draft_tokens_ptr, # [batch_size]
draft_token_ids_ptr, # [num_tokens]
draft_probs_ptr, # [num_tokens, vocab_size] or None
target_probs_ptr, # [num_tokens, vocab_size]
q_ptr, # [batch_size, vocab_size]
vocab_size,
PADDED_VOCAB_SIZE: tl.constexpr,
NO_DRAFT_PROBS: tl.constexpr,
SUB_BLOCK: tl.constexpr,
):
req_idx = tl.program_id(0)
start_idx = 0 if req_idx == 0 else tl.load(cu_num_draft_tokens_ptr +
req_idx - 1)
end_idx = tl.load(cu_num_draft_tokens_ptr + req_idx)
num_draft_tokens = end_idx - start_idx
# Early exit for out-of-range positions.
pos = tl.program_id(1)
if pos >= num_draft_tokens:
return
loop = (vocab_size + SUB_BLOCK - 1) // SUB_BLOCK
global_recovered_id = -1
global_max_p = -1.0
if NO_DRAFT_PROBS:
draft_token_id = tl.load(draft_token_ids_ptr + start_idx + pos)
orig_prob = tl.load(target_probs_ptr + (start_idx + pos) * vocab_size +
draft_token_id)
# Temporarily zero out the probability of the draft token.
# This is essentially the same as target_prob - draft_prob, except that
# n-gram does not have draft_prob. We regard it as 1.
tl.store(
target_probs_ptr + (start_idx + pos) * vocab_size + draft_token_id,
0)
for loop_i in range(loop):
vocab_start = loop_i * SUB_BLOCK
vocab_offset = vocab_start + tl.arange(0, SUB_BLOCK)
prob = tl.load(target_probs_ptr + (start_idx + pos) * vocab_size +
vocab_offset,
mask=vocab_offset < vocab_size,
other=0)
q = tl.load(q_ptr + req_idx * vocab_size + vocab_offset,
mask=vocab_offset < vocab_size,
other=float("-inf"))
new_p = prob / q
recovered_id = tl.argmax(new_p, axis=-1)
max_p = tl.get_element(new_p, (recovered_id, ))
if max_p > global_max_p:
global_max_p = max_p
global_recovered_id = vocab_start + recovered_id
else:
for loop_i in range(loop):
vocab_start = loop_i * SUB_BLOCK
vocab_offset = vocab_start + tl.arange(0, SUB_BLOCK)
draft_prob = tl.load(draft_probs_ptr +
(start_idx + pos) * vocab_size + vocab_offset,
mask=vocab_offset < vocab_size,
other=0)
target_prob = tl.load(target_probs_ptr +
(start_idx + pos) * vocab_size +
vocab_offset,
mask=vocab_offset < vocab_size,
other=0)
prob = tl.maximum(target_prob - draft_prob, 0)
# NOTE(woosuk): We don't need `prob = prob / tl.sum(prob)` here because
# `tl.argmax` will select the maximum value.
q = tl.load(q_ptr + req_idx * vocab_size + vocab_offset,
mask=vocab_offset < vocab_size,
other=float("-inf"))
new_p = prob / q
recovered_id = tl.argmax(new_p, axis=-1)
max_p = tl.get_element(new_p, (recovered_id, ))
if max_p > global_max_p:
global_max_p = max_p
global_recovered_id = vocab_start + recovered_id
tl.store(output_token_ids_ptr + start_idx + pos, global_recovered_id)
if NO_DRAFT_PROBS:
# Restore the original probability.
tl.store(
target_probs_ptr + (start_idx + pos) * vocab_size + draft_token_id,
orig_prob)
def rejection_greedy_sample_with_triton(
output_token_ids,
num_draft_tokens,
cu_num_draft_tokens,
draft_token_ids,
target_argmax,
bonus_token_ids,
is_greedy,
max_spec_len,
):
vec_len = output_token_ids.shape[0]
n = cu_num_draft_tokens.numel()
BLOCK_SIZE = 2
grid = triton.cdiv(n, BLOCK_SIZE)
vectorcore_num = get_vectorcore_num()
if n >= vectorcore_num:
grid = vectorcore_num # Empirically tuned value
BLOCK_SIZE = triton.next_power_of_2(triton.cdiv(n, grid))
if min(num_draft_tokens) == 1 and max(
num_draft_tokens) == 1 and is_greedy is None:
rejection_greedy_sample_spec_len_1_triton[(grid, )](
output_token_ids,
draft_token_ids,
target_argmax,
bonus_token_ids,
vec_len,
BLOCK_SIZE=BLOCK_SIZE,
)
else:
rejection_greedy_sample_triton[(grid, )](
output_token_ids,
cu_num_draft_tokens,
draft_token_ids,
target_argmax,
bonus_token_ids,
is_greedy,
vec_len,
max_spec_len,
BLOCK_SIZE=BLOCK_SIZE,
)
def expand_triton(batch_size, expanded_x, x, cu_num_tokens, replace_from,
replace_to, max_num_tokens):
vec_len = batch_size
n = cu_num_tokens.numel()
BLOCK_SIZE = 2
grid = triton.cdiv(n, BLOCK_SIZE)
vectorcore_num = get_vectorcore_num()
if n >= vectorcore_num:
grid = vectorcore_num
BLOCK_SIZE = triton.next_power_of_2(triton.cdiv(n, grid))
expand_kernel[(grid, )](
expanded_x,
x,
cu_num_tokens,
replace_from,
replace_to,
vec_len,
MAX_NUM_TOKENS=max_num_tokens, # To avoid recompilation.
BLOCK_SIZE=BLOCK_SIZE,
)

View File

@@ -2,11 +2,14 @@
from typing import Optional
import torch
from vllm.triton_utils import HAS_TRITON, tl, triton
from vllm.triton_utils import HAS_TRITON, triton
from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.sample.rejection_sampler import (GREEDY_TEMPERATURE,
generate_uniform_probs)
from vllm_ascend.ops.triton.reject_sample import (
expand_triton, rejection_greedy_sample_with_triton,
rejection_random_sample_kernel, sample_recovered_tokens_kernel)
from vllm_ascend.sample.sampler import apply_top_k_top_p
PLACEHOLDER_TOKEN_ID = -1
@@ -14,16 +17,6 @@ PLACEHOLDER_TOKEN_ID = -1
# step. This value is chosen to be large enough to handle typical use cases.
MAX_SPEC_LEN = 32
vectorcore_num = None
device_properties = None
if HAS_TRITON:
from triton.runtime import driver # type: ignore
device_properties = driver.active.utils.get_device_properties(
torch.npu.current_device())
vectorcore_num = device_properties['num_vectorcore']
#get vector core number in order for later tiling
def apply_sampling_constraints(
logits: torch.Tensor, # [num_tokens, vocab_size]
@@ -130,36 +123,16 @@ def rejection_sample(
# Rejection sampling for greedy sampling requests.
target_argmax = target_probs.argmax(dim=-1)
if HAS_TRITON:
vec_len = batch_size
n = cu_num_draft_tokens.numel()
BLOCK_SIZE = 2
grid = triton.cdiv(n, BLOCK_SIZE)
if n >= vectorcore_num:
grid = vectorcore_num # Empirically tuned value
BLOCK_SIZE = triton.next_power_of_2(triton.cdiv(n, grid))
if min(num_draft_tokens) == 1 and max(
num_draft_tokens) == 1 and sampling_metadata.all_greedy:
rejection_greedy_sample_spec_len_1_triton[(grid, )](
output_token_ids,
draft_token_ids,
target_argmax,
bonus_token_ids,
vec_len,
BLOCK_SIZE=BLOCK_SIZE,
)
else:
rejection_greedy_sample_triton[(grid, )](
output_token_ids,
cu_num_draft_tokens,
draft_token_ids,
target_argmax,
bonus_token_ids,
is_greedy,
vec_len,
max_spec_len,
BLOCK_SIZE=BLOCK_SIZE,
)
rejection_greedy_sample_with_triton(
output_token_ids,
num_draft_tokens,
cu_num_draft_tokens,
draft_token_ids,
target_argmax,
bonus_token_ids,
is_greedy,
max_spec_len,
)
else:
if min(num_draft_tokens) == 1 and max(
num_draft_tokens) == 1 and sampling_metadata.all_greedy:
@@ -270,24 +243,13 @@ def expand_batch_to_tokens(
assert cu_num_tokens.shape[0] == batch_size
expanded_x = x.new_empty(num_tokens)
if HAS_TRITON:
vec_len = batch_size
n = cu_num_tokens.numel()
BLOCK_SIZE = 2
grid = triton.cdiv(n, BLOCK_SIZE)
if n >= vectorcore_num:
grid = vectorcore_num
BLOCK_SIZE = triton.next_power_of_2(triton.cdiv(n, grid))
expand_kernel[(grid, )](
expanded_x,
x,
cu_num_tokens,
replace_from,
replace_to,
vec_len,
MAX_NUM_TOKENS=MAX_SPEC_LEN, # To avoid recompilation.
BLOCK_SIZE=BLOCK_SIZE,
)
expand_triton(batch_size,
expanded_x,
x,
cu_num_tokens,
replace_from,
replace_to,
max_num_tokens=MAX_SPEC_LEN)
else:
expand_pytorch(
expanded_x,
@@ -717,294 +679,3 @@ def sample_recovered_tokens_pytorch(
recovered_ids = torch.argmax(prob_over_q, dim=1)
output_token_ids[:] = recovered_ids
@triton.jit(do_not_specialize=["max_spec_len"])
def bonus_renew_1(
bonus_token_ids_ptr,
position,
output_token_ids_ptr,
):
bonus_token_id = tl.load(bonus_token_ids_ptr + position)
tl.store(output_token_ids_ptr + position * 2 + 1, bonus_token_id)
@triton.jit(do_not_specialize=["max_spec_len"])
def rejection_greedy_sample_spec_len_1_triton(
output_token_ids_ptr, # [batch_size, 2]
draft_token_ids_ptr, # [num_tokens]
target_argmax_ptr, # [num_tokens]
bonus_token_ids_ptr,
vec_len,
BLOCK_SIZE: tl.constexpr,
):
block_idx = tl.program_id(0)
offset = block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
mask = offset < vec_len
draft_token_id = tl.load(draft_token_ids_ptr + offset, mask)
target_argmax_id = tl.load(target_argmax_ptr + offset, mask)
tl.store(output_token_ids_ptr + offset * 2, target_argmax_id, mask)
for pos in tl.range(0, BLOCK_SIZE):
draft_token_id1 = tl.get_element(draft_token_id, (pos, ))
target_argmax1 = tl.get_element(target_argmax_id, (pos, ))
position = block_idx * BLOCK_SIZE + pos
if draft_token_id1 == target_argmax1:
bonus_renew_1(
bonus_token_ids_ptr,
position,
output_token_ids_ptr,
)
@triton.jit(do_not_specialize=["max_spec_len"])
def bonus_renew(
bonus_token_ids_ptr,
position,
output_token_ids_ptr,
max_spec_len,
num_tokens1,
):
bonus_token_id = tl.load(bonus_token_ids_ptr + position)
tl.store(
output_token_ids_ptr + position * (max_spec_len + 1) + num_tokens1,
bonus_token_id)
@triton.jit(do_not_specialize=["max_spec_len"])
def rejection_greedy_sample_triton(
output_token_ids_ptr, # [batch_size, max_spec_len + 1]
cu_num_draft_tokens_ptr, # [batch_size]
draft_token_ids_ptr, # [num_tokens]
target_argmax_ptr, # [num_tokens]
bonus_token_ids_ptr, # [batch_size]
is_greedy_ptr, # [batch_size] or None
vec_len,
max_spec_len,
BLOCK_SIZE: tl.constexpr,
):
block_idx = tl.program_id(0)
offset = block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
mask = offset < vec_len
if is_greedy_ptr is None:
is_greedy_mask = mask
else:
is_greedy = tl.load(is_greedy_ptr + offset, mask=mask, other=0)
is_greedy_mask = mask & (is_greedy != 0)
start_idx = tl.where(
offset == 0, 0,
tl.load(cu_num_draft_tokens_ptr + offset - 1, is_greedy_mask))
end_idx = tl.load(cu_num_draft_tokens_ptr + offset, is_greedy_mask)
num_draft_tokens = end_idx - start_idx
for pos in tl.range(0, BLOCK_SIZE):
num_tokens1 = tl.get_element(num_draft_tokens, (pos, ))
rejected = False
start_idx1 = tl.get_element(start_idx, (pos, ))
is_greedy_mask1 = tl.get_element(is_greedy_mask, (pos, ))
position = block_idx * BLOCK_SIZE + pos
for i in range(num_tokens1):
if not rejected:
draft_token_id = tl.load(draft_token_ids_ptr + start_idx1 + i)
target_argmax_id = tl.load(target_argmax_ptr + start_idx1 + i)
tl.store(
output_token_ids_ptr + position * (max_spec_len + 1) + i,
target_argmax_id,
)
if draft_token_id != target_argmax_id:
# Reject.
rejected = True
if not rejected and is_greedy_mask1:
bonus_renew(
bonus_token_ids_ptr,
position,
output_token_ids_ptr,
max_spec_len,
num_tokens1,
)
@triton.jit(do_not_specialize=["max_spec_len"])
def rejection_random_sample_kernel(
output_token_ids_ptr, # [batch_size, max_spec_len + 1]
cu_num_draft_tokens_ptr, # [batch_size]
draft_token_ids_ptr, # [num_tokens]
draft_probs_ptr, # [num_tokens, vocab_size] or None
target_probs_ptr, # [num_tokens, vocab_size]
bonus_token_ids_ptr, # [batch_size]
recovered_token_ids_ptr, # [num_tokens]
uniform_probs_ptr, # [num_tokens]
is_greedy_ptr, # [batch_size]
max_spec_len,
vocab_size,
NO_DRAFT_PROBS: tl.constexpr,
):
req_idx = tl.program_id(0)
is_greedy = tl.load(is_greedy_ptr + req_idx)
if is_greedy:
# Early exost for greedy sampling requests
return
start_idx = 0 if req_idx == 0 else tl.load(cu_num_draft_tokens_ptr +
req_idx - 1)
end_idx = tl.load(cu_num_draft_tokens_ptr + req_idx)
num_draft_tokens = end_idx - start_idx
rejected = False
for pos in range(num_draft_tokens):
if not rejected:
draft_token_id = tl.load(draft_token_ids_ptr + start_idx + pos)
if NO_DRAFT_PROBS:
draft_prob = 1
else:
draft_prob = tl.load(draft_probs_ptr +
(start_idx + pos) * vocab_size +
draft_token_id)
target_prob = tl.load(target_probs_ptr +
(start_idx + pos) * vocab_size +
draft_token_id)
uniform_prob = tl.load(uniform_probs_ptr + start_idx + pos)
if draft_prob > 0 and target_prob / draft_prob >= uniform_prob:
# Accept
token_id = draft_token_id
else:
# Reject. Use recovered token
rejected = True
token_id = tl.load(recovered_token_ids_ptr + start_idx + pos)
tl.store(output_token_ids_ptr + req_idx * (max_spec_len + 1) + pos,
token_id)
if not rejected:
# If all tokens are accepted, append the bonus token
bonus_token_id = tl.load(bonus_token_ids_ptr + req_idx)
tl.store(
output_token_ids_ptr + req_idx * (max_spec_len + 1) +
num_draft_tokens,
bonus_token_id,
)
@triton.jit(do_not_specialize=["replace_from", "replace_to"])
def expand_kernel(
output_ptr, # [num_tokens]
input_ptr, # [batch_size]
cu_num_tokens_ptr, # [batch_size]
replace_from,
replace_to,
vec_len,
MAX_NUM_TOKENS: tl.constexpr,
BLOCK_SIZE: tl.constexpr,
):
req_idx = tl.program_id(0)
offset = req_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
len_mask = offset < vec_len
start_idx = tl.where(offset == 0, 0,
tl.load(cu_num_tokens_ptr + offset - 1, len_mask))
end_idx = tl.load(cu_num_tokens_ptr + offset, len_mask)
num_tokens = end_idx - start_idx
src_val = tl.load(input_ptr + offset, len_mask)
src_val = tl.where(src_val == replace_from, replace_to, src_val)
for i in tl.range(0, BLOCK_SIZE):
num_tokens1 = tl.get_element(num_tokens, (i, ))
start_idx1 = tl.get_element(start_idx, (i, ))
src_val1 = tl.get_element(src_val, (i, ))
offset1 = tl.arange(0, MAX_NUM_TOKENS)
tl.store(output_ptr + start_idx1 + offset1,
src_val1,
mask=offset1 < num_tokens1)
@triton.jit
def sample_recovered_tokens_kernel(
output_token_ids_ptr, # [num_tokens]
cu_num_draft_tokens_ptr, # [batch_size]
draft_token_ids_ptr, # [num_tokens]
draft_probs_ptr, # [num_tokens, vocab_size] or None
target_probs_ptr, # [num_tokens, vocab_size]
q_ptr, # [batch_size, vocab_size]
vocab_size,
PADDED_VOCAB_SIZE: tl.constexpr,
NO_DRAFT_PROBS: tl.constexpr,
SUB_BLOCK: tl.constexpr,
):
req_idx = tl.program_id(0)
start_idx = 0 if req_idx == 0 else tl.load(cu_num_draft_tokens_ptr +
req_idx - 1)
end_idx = tl.load(cu_num_draft_tokens_ptr + req_idx)
num_draft_tokens = end_idx - start_idx
# Early exit for out-of-range positions.
pos = tl.program_id(1)
if pos >= num_draft_tokens:
return
loop = (vocab_size + SUB_BLOCK - 1) // SUB_BLOCK
global_recovered_id = -1
global_max_p = -1.0
if NO_DRAFT_PROBS:
draft_token_id = tl.load(draft_token_ids_ptr + start_idx + pos)
orig_prob = tl.load(target_probs_ptr + (start_idx + pos) * vocab_size +
draft_token_id)
# Temporarily zero out the probability of the draft token.
# This is essentially the same as target_prob - draft_prob, except that
# n-gram does not have draft_prob. We regard it as 1.
tl.store(
target_probs_ptr + (start_idx + pos) * vocab_size + draft_token_id,
0)
for loop_i in range(loop):
vocab_start = loop_i * SUB_BLOCK
vocab_offset = vocab_start + tl.arange(0, SUB_BLOCK)
prob = tl.load(target_probs_ptr + (start_idx + pos) * vocab_size +
vocab_offset,
mask=vocab_offset < vocab_size,
other=0)
q = tl.load(q_ptr + req_idx * vocab_size + vocab_offset,
mask=vocab_offset < vocab_size,
other=float("-inf"))
new_p = prob / q
recovered_id = tl.argmax(new_p, axis=-1)
max_p = tl.get_element(new_p, (recovered_id, ))
if max_p > global_max_p:
global_max_p = max_p
global_recovered_id = vocab_start + recovered_id
else:
for loop_i in range(loop):
vocab_start = loop_i * SUB_BLOCK
vocab_offset = vocab_start + tl.arange(0, SUB_BLOCK)
draft_prob = tl.load(draft_probs_ptr +
(start_idx + pos) * vocab_size + vocab_offset,
mask=vocab_offset < vocab_size,
other=0)
target_prob = tl.load(target_probs_ptr +
(start_idx + pos) * vocab_size +
vocab_offset,
mask=vocab_offset < vocab_size,
other=0)
prob = tl.maximum(target_prob - draft_prob, 0)
# NOTE(woosuk): We don't need `prob = prob / tl.sum(prob)` here because
# `tl.argmax` will select the maximum value.
q = tl.load(q_ptr + req_idx * vocab_size + vocab_offset,
mask=vocab_offset < vocab_size,
other=float("-inf"))
new_p = prob / q
recovered_id = tl.argmax(new_p, axis=-1)
max_p = tl.get_element(new_p, (recovered_id, ))
if max_p > global_max_p:
global_max_p = max_p
global_recovered_id = vocab_start + recovered_id
tl.store(output_token_ids_ptr + start_idx + pos, global_recovered_id)
if NO_DRAFT_PROBS:
# Restore the original probability.
tl.store(
target_probs_ptr + (start_idx + pos) * vocab_size + draft_token_id,
orig_prob)