[Kernel] add triton kernels for sampling (#4550)
### What this PR does / why we need it? Replace pyorch implement of sampling with triton kernels ### Does this PR introduce _any_ user-facing change? No - vLLM version: v0.11.2 --------- Signed-off-by: Lord_of_Ironhill <suiweiyi@huawei.com> Signed-off-by: whx-sjtu <2952154980@qq.com> Co-authored-by: Lord_of_Ironhill <suiweiyi@huawei.com> Co-authored-by: whx-sjtu <2952154980@qq.com>
This commit is contained in:
@@ -4,6 +4,7 @@ from typing import Optional
|
|||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import vllm.v1.sample.rejection_sampler as rs
|
import vllm.v1.sample.rejection_sampler as rs
|
||||||
|
from vllm.triton_utils import HAS_TRITON, tl, triton
|
||||||
from vllm.v1.sample.metadata import SamplingMetadata
|
from vllm.v1.sample.metadata import SamplingMetadata
|
||||||
from vllm.v1.sample.rejection_sampler import (RejectionSampler,
|
from vllm.v1.sample.rejection_sampler import (RejectionSampler,
|
||||||
apply_sampling_constraints,
|
apply_sampling_constraints,
|
||||||
@@ -149,25 +150,36 @@ def rejection_sample(
|
|||||||
if not sampling_metadata.all_random:
|
if not sampling_metadata.all_random:
|
||||||
# Rejection sampling for greedy sampling requests.
|
# Rejection sampling for greedy sampling requests.
|
||||||
target_argmax = target_probs.argmax(dim=-1)
|
target_argmax = target_probs.argmax(dim=-1)
|
||||||
if min(num_draft_tokens) == 1 and max(
|
if HAS_TRITON:
|
||||||
num_draft_tokens) == 1 and sampling_metadata.all_greedy:
|
rejection_greedy_sample_kernel[(batch_size, )](
|
||||||
rejection_greedy_sample_spec_len_1_pytorch(
|
|
||||||
output_token_ids,
|
|
||||||
draft_token_ids,
|
|
||||||
target_argmax,
|
|
||||||
bonus_token_ids,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
rejection_greedy_sample_pytorch(
|
|
||||||
output_token_ids,
|
output_token_ids,
|
||||||
cu_num_draft_tokens,
|
cu_num_draft_tokens,
|
||||||
draft_token_ids,
|
draft_token_ids,
|
||||||
target_argmax,
|
target_argmax,
|
||||||
bonus_token_ids,
|
bonus_token_ids,
|
||||||
num_draft_tokens,
|
|
||||||
max_spec_len,
|
|
||||||
is_greedy,
|
is_greedy,
|
||||||
|
max_spec_len,
|
||||||
)
|
)
|
||||||
|
else:
|
||||||
|
if min(num_draft_tokens) == 1 and max(
|
||||||
|
num_draft_tokens) == 1 and sampling_metadata.all_greedy:
|
||||||
|
rejection_greedy_sample_spec_len_1_pytorch(
|
||||||
|
output_token_ids,
|
||||||
|
draft_token_ids,
|
||||||
|
target_argmax,
|
||||||
|
bonus_token_ids,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
rejection_greedy_sample_pytorch(
|
||||||
|
output_token_ids,
|
||||||
|
cu_num_draft_tokens,
|
||||||
|
draft_token_ids,
|
||||||
|
target_argmax,
|
||||||
|
bonus_token_ids,
|
||||||
|
num_draft_tokens,
|
||||||
|
max_spec_len,
|
||||||
|
is_greedy,
|
||||||
|
)
|
||||||
if sampling_metadata.all_greedy:
|
if sampling_metadata.all_greedy:
|
||||||
return output_token_ids
|
return output_token_ids
|
||||||
|
|
||||||
@@ -194,21 +206,37 @@ def rejection_sample(
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Rejection sampling for random sampling requests.
|
# Rejection sampling for random sampling requests.
|
||||||
rejection_random_sample_pytorch(
|
if HAS_TRITON:
|
||||||
output_token_ids,
|
rejection_random_sample_kernel[(batch_size, )](
|
||||||
cu_num_draft_tokens,
|
output_token_ids,
|
||||||
draft_token_ids,
|
cu_num_draft_tokens,
|
||||||
draft_probs,
|
draft_token_ids,
|
||||||
target_probs,
|
draft_probs,
|
||||||
bonus_token_ids,
|
target_probs,
|
||||||
recovered_token_ids,
|
bonus_token_ids,
|
||||||
uniform_probs,
|
recovered_token_ids,
|
||||||
is_greedy,
|
uniform_probs,
|
||||||
max_spec_len,
|
is_greedy,
|
||||||
vocab_size,
|
max_spec_len,
|
||||||
IS_NGRAM=draft_probs is None,
|
vocab_size,
|
||||||
# num_warps=1,
|
NO_DRAFT_PROBS=draft_probs is None,
|
||||||
)
|
)
|
||||||
|
else:
|
||||||
|
rejection_random_sample_pytorch(
|
||||||
|
output_token_ids,
|
||||||
|
cu_num_draft_tokens,
|
||||||
|
draft_token_ids,
|
||||||
|
draft_probs,
|
||||||
|
target_probs,
|
||||||
|
bonus_token_ids,
|
||||||
|
recovered_token_ids,
|
||||||
|
uniform_probs,
|
||||||
|
is_greedy,
|
||||||
|
max_spec_len,
|
||||||
|
vocab_size,
|
||||||
|
IS_NGRAM=draft_probs is None,
|
||||||
|
# num_warps=1,
|
||||||
|
)
|
||||||
return output_token_ids
|
return output_token_ids
|
||||||
|
|
||||||
|
|
||||||
@@ -241,14 +269,24 @@ def expand_batch_to_tokens(
|
|||||||
batch_size = x.shape[0]
|
batch_size = x.shape[0]
|
||||||
assert cu_num_tokens.shape[0] == batch_size
|
assert cu_num_tokens.shape[0] == batch_size
|
||||||
expanded_x = x.new_empty(num_tokens)
|
expanded_x = x.new_empty(num_tokens)
|
||||||
expand_pytorch(
|
if HAS_TRITON:
|
||||||
expanded_x,
|
expand_kernel[(batch_size, )](
|
||||||
x,
|
expanded_x,
|
||||||
cu_num_tokens,
|
x,
|
||||||
replace_from,
|
cu_num_tokens,
|
||||||
replace_to,
|
replace_from,
|
||||||
MAX_NUM_TOKENS=MAX_SPEC_LEN, # To avoid recompilation.
|
replace_to,
|
||||||
)
|
MAX_NUM_TOKENS=MAX_SPEC_LEN, # To avoid recompilation.
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
expand_pytorch(
|
||||||
|
expanded_x,
|
||||||
|
x,
|
||||||
|
cu_num_tokens,
|
||||||
|
replace_from,
|
||||||
|
replace_to,
|
||||||
|
MAX_NUM_TOKENS=MAX_SPEC_LEN, # To avoid recompilation.
|
||||||
|
)
|
||||||
return expanded_x
|
return expanded_x
|
||||||
|
|
||||||
|
|
||||||
@@ -282,16 +320,29 @@ def sample_recovered_tokens(
|
|||||||
q[i].exponential_(generator=generator)
|
q[i].exponential_(generator=generator)
|
||||||
|
|
||||||
recovered_token_ids = torch.empty_like(draft_token_ids)
|
recovered_token_ids = torch.empty_like(draft_token_ids)
|
||||||
sample_recovered_tokens_pytorch(
|
if HAS_TRITON:
|
||||||
recovered_token_ids,
|
sample_recovered_tokens_kernel[(batch_size, max_spec_len)](
|
||||||
cu_num_draft_tokens,
|
recovered_token_ids,
|
||||||
draft_token_ids,
|
cu_num_draft_tokens,
|
||||||
draft_probs,
|
draft_token_ids,
|
||||||
target_probs,
|
draft_probs,
|
||||||
q,
|
target_probs,
|
||||||
vocab_size,
|
q,
|
||||||
IS_NGRAM=draft_probs is None,
|
vocab_size,
|
||||||
)
|
triton.next_power_of_2(vocab_size),
|
||||||
|
NO_DRAFT_PROBS=draft_probs is None,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
sample_recovered_tokens_pytorch(
|
||||||
|
recovered_token_ids,
|
||||||
|
cu_num_draft_tokens,
|
||||||
|
draft_token_ids,
|
||||||
|
draft_probs,
|
||||||
|
target_probs,
|
||||||
|
q,
|
||||||
|
vocab_size,
|
||||||
|
IS_NGRAM=draft_probs is None,
|
||||||
|
)
|
||||||
return recovered_token_ids
|
return recovered_token_ids
|
||||||
|
|
||||||
|
|
||||||
@@ -504,4 +555,192 @@ def sample_recovered_tokens_pytorch(
|
|||||||
target_probs[token_idx, draft_token_id] = orig_prob
|
target_probs[token_idx, draft_token_id] = orig_prob
|
||||||
|
|
||||||
|
|
||||||
|
@triton.jit(do_not_specialize=["max_spec_len"])
|
||||||
|
def rejection_greedy_sample_kernel(
|
||||||
|
output_token_ids_ptr, # [batch_size, max_spec_len + 1]
|
||||||
|
cu_num_draft_tokens_ptr, # [batch_size]
|
||||||
|
draft_token_ids_ptr, # [num_tokens]
|
||||||
|
target_argmax_ptr, # [num_tokens]
|
||||||
|
bonus_token_ids_ptr, # [batch_size]
|
||||||
|
is_greedy_ptr, # [batch_size] or None
|
||||||
|
max_spec_len,
|
||||||
|
):
|
||||||
|
req_idx = tl.program_id(0)
|
||||||
|
# Because is_greedy_ptr is not Nonr at profiling run,
|
||||||
|
# re-comilation may happen during runtime when is_greedy_ptr is None.
|
||||||
|
is_greedy = True if is_greedy_ptr is None else tl.load(is_greedy_ptr +
|
||||||
|
req_idx)
|
||||||
|
if not is_greedy:
|
||||||
|
# Early exit for non-greedy sampling requests
|
||||||
|
return
|
||||||
|
|
||||||
|
start_idx = 0 if req_idx == 0 else tl.load(cu_num_draft_tokens_ptr +
|
||||||
|
req_idx - 1)
|
||||||
|
end_idx = tl.load(cu_num_draft_tokens_ptr + req_idx)
|
||||||
|
num_draft_tokens = end_idx - start_idx
|
||||||
|
|
||||||
|
rejected = False
|
||||||
|
for pos in range(num_draft_tokens):
|
||||||
|
if not rejected:
|
||||||
|
draft_token_id = tl.load(draft_token_ids_ptr + start_idx + pos)
|
||||||
|
target_argmax_id = tl.load(target_argmax_ptr + start_idx + pos)
|
||||||
|
tl.store(
|
||||||
|
output_token_ids_ptr + req_idx * (max_spec_len + 1) + pos,
|
||||||
|
target_argmax_id,
|
||||||
|
)
|
||||||
|
if draft_token_id != target_argmax_id:
|
||||||
|
# Reject
|
||||||
|
rejected = True
|
||||||
|
|
||||||
|
if not rejected:
|
||||||
|
# If all tokens are accepted, append the bonus token
|
||||||
|
bonus_token_id = tl.load(bonus_token_ids_ptr + req_idx)
|
||||||
|
tl.store(
|
||||||
|
output_token_ids_ptr + req_idx * (max_spec_len + 1) +
|
||||||
|
num_draft_tokens,
|
||||||
|
bonus_token_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@triton.jit(do_not_specialize=["max_spec_len"])
|
||||||
|
def rejection_random_sample_kernel(
|
||||||
|
output_token_ids_ptr, # [batch_size, max_spec_len + 1]
|
||||||
|
cu_num_draft_tokens_ptr, # [batch_size]
|
||||||
|
draft_token_ids_ptr, # [num_tokens]
|
||||||
|
draft_probs_ptr, # [num_tokens, vocab_size] or None
|
||||||
|
target_probs_ptr, # [num_tokens, vocab_size]
|
||||||
|
bonus_token_ids_ptr, # [batch_size]
|
||||||
|
recovered_token_ids_ptr, # [num_tokens]
|
||||||
|
uniform_probs_ptr, # [num_tokens]
|
||||||
|
is_greedy_ptr, # [batch_size]
|
||||||
|
max_spec_len,
|
||||||
|
vocab_size,
|
||||||
|
NO_DRAFT_PROBS: tl.constexpr,
|
||||||
|
):
|
||||||
|
req_idx = tl.program_id(0)
|
||||||
|
is_greedy = tl.load(is_greedy_ptr + req_idx)
|
||||||
|
if is_greedy:
|
||||||
|
# Early exost for greedy sampling requests
|
||||||
|
return
|
||||||
|
|
||||||
|
start_idx = 0 if req_idx == 0 else tl.load(cu_num_draft_tokens_ptr +
|
||||||
|
req_idx - 1)
|
||||||
|
end_idx = tl.load(cu_num_draft_tokens_ptr + req_idx)
|
||||||
|
num_draft_tokens = end_idx - start_idx
|
||||||
|
|
||||||
|
rejected = False
|
||||||
|
for pos in range(num_draft_tokens):
|
||||||
|
if not rejected:
|
||||||
|
draft_token_id = tl.load(draft_token_ids_ptr + start_idx + pos)
|
||||||
|
if NO_DRAFT_PROBS:
|
||||||
|
draft_prob = 1
|
||||||
|
else:
|
||||||
|
draft_prob = tl.load(draft_probs_ptr +
|
||||||
|
(start_idx + pos) * vocab_size +
|
||||||
|
draft_token_id)
|
||||||
|
target_prob = tl.load(target_probs_ptr +
|
||||||
|
(start_idx + pos) * vocab_size +
|
||||||
|
draft_token_id)
|
||||||
|
uniform_prob = tl.load(uniform_probs_ptr + start_idx + pos)
|
||||||
|
if draft_prob > 0 and target_prob / draft_prob >= uniform_prob:
|
||||||
|
# Accept
|
||||||
|
token_id = draft_token_id
|
||||||
|
else:
|
||||||
|
# Reject. Use recovered token
|
||||||
|
rejected = True
|
||||||
|
token_id = tl.load(recovered_token_ids_ptr + start_idx + pos)
|
||||||
|
tl.store(output_token_ids_ptr + req_idx * (max_spec_len + 1) + pos,
|
||||||
|
token_id)
|
||||||
|
|
||||||
|
if not rejected:
|
||||||
|
# If all tokens are accepted, append the bonus token
|
||||||
|
bonus_token_id = tl.load(bonus_token_ids_ptr + req_idx)
|
||||||
|
tl.store(
|
||||||
|
output_token_ids_ptr + req_idx * (max_spec_len + 1) +
|
||||||
|
num_draft_tokens,
|
||||||
|
bonus_token_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@triton.jit(do_not_specialize=["replace_from", "replace_to"])
|
||||||
|
def expand_kernel(
|
||||||
|
output_ptr, # [num_tokens]
|
||||||
|
input_ptr, # [batch_size]
|
||||||
|
cu_num_tokens_ptr, # [batch_size]
|
||||||
|
replace_from,
|
||||||
|
replace_to,
|
||||||
|
MAX_NUM_TOKENS: tl.constexpr,
|
||||||
|
):
|
||||||
|
req_idx = tl.program_id(0)
|
||||||
|
if req_idx == 0:
|
||||||
|
start_idx = 0
|
||||||
|
else:
|
||||||
|
start_idx = tl.load(cu_num_tokens_ptr + req_idx - 1)
|
||||||
|
end_idx = tl.load(cu_num_tokens_ptr + req_idx)
|
||||||
|
num_tokens = end_idx - start_idx
|
||||||
|
|
||||||
|
src_val = tl.load(input_ptr + req_idx)
|
||||||
|
src_val = tl.where(src_val == replace_from, replace_to, src_val)
|
||||||
|
offset = tl.arange(0, MAX_NUM_TOKENS)
|
||||||
|
tl.store(output_ptr + start_idx + offset,
|
||||||
|
src_val,
|
||||||
|
mask=offset < num_tokens)
|
||||||
|
|
||||||
|
|
||||||
|
@triton.jit
|
||||||
|
def sample_recovered_tokens_kernel(
|
||||||
|
output_token_ids_ptr, # [num_tokens]
|
||||||
|
cu_num_draft_tokens_ptr, # [batch_size]
|
||||||
|
draft_token_ids_ptr, # [num_tokens]
|
||||||
|
draft_probs_ptr, # [num_tokens, vocab_size] or None
|
||||||
|
target_probs_ptr, # [num_tokens, vocab_size]
|
||||||
|
q_ptr, # [batch_size, vocab_size]
|
||||||
|
vocab_size,
|
||||||
|
PADDED_VOCAB_SIZE: tl.constexpr,
|
||||||
|
NO_DRAFT_PROBS: tl.constexpr,
|
||||||
|
):
|
||||||
|
req_idx = tl.program_id(0)
|
||||||
|
start_idx = 0 if req_idx == 0 else tl.load(cu_num_draft_tokens_ptr +
|
||||||
|
req_idx - 1)
|
||||||
|
end_idx = tl.load(cu_num_draft_tokens_ptr + req_idx)
|
||||||
|
num_draft_tokens = end_idx - start_idx
|
||||||
|
|
||||||
|
# Early exit for out-of-range positions
|
||||||
|
pos = tl.program_id(1)
|
||||||
|
if pos >= num_draft_tokens:
|
||||||
|
return
|
||||||
|
|
||||||
|
vocab_offset = tl.arange(0, PADDED_VOCAB_SIZE)
|
||||||
|
if NO_DRAFT_PROBS:
|
||||||
|
draft_token_id = tl.load(draft_token_ids_ptr + start_idx + pos)
|
||||||
|
prob = tl.load(
|
||||||
|
target_probs_ptr + (start_idx + pos) * vocab_size + vocab_offset,
|
||||||
|
mask=((vocab_offset < vocab_size) &
|
||||||
|
(vocab_offset != draft_token_id)),
|
||||||
|
other=0,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
draft_prob = tl.load(
|
||||||
|
draft_probs_ptr + (start_idx + pos) * vocab_size + vocab_offset,
|
||||||
|
mask=vocab_offset < vocab_size,
|
||||||
|
other=0,
|
||||||
|
)
|
||||||
|
target_prob = tl.load(
|
||||||
|
target_probs_ptr + (start_idx + pos) * vocab_size + vocab_offset,
|
||||||
|
mask=vocab_offset < vocab_size,
|
||||||
|
other=0,
|
||||||
|
)
|
||||||
|
prob = tl.maximum(target_prob - draft_prob, 0)
|
||||||
|
# We don't need `prob = prob / tl.sum(prob)` here because
|
||||||
|
# `tl.argmax` will select the maximum value.
|
||||||
|
|
||||||
|
q = tl.load(
|
||||||
|
q_ptr + req_idx * vocab_size + vocab_offset,
|
||||||
|
mask=vocab_offset < vocab_size,
|
||||||
|
other=float("-inf"),
|
||||||
|
)
|
||||||
|
recovered_id = tl.argmax(prob / q, axis=-1)
|
||||||
|
tl.store(output_token_ids_ptr + start_idx + pos, recovered_id)
|
||||||
|
|
||||||
|
|
||||||
rs.expand_batch_to_tokens = expand_batch_to_tokens
|
rs.expand_batch_to_tokens = expand_batch_to_tokens
|
||||||
|
|||||||
Reference in New Issue
Block a user