feat: implement high-performance Triton kernels for rejection sampling: optimization for rejection_random_sample_kernel (#5259)
### What this PR does / why we need it?
This PR introduces optimized Triton implementations for the
rejection_random_sample_kernel delivering superior performance compared
to the existing Triton implementations. The new Triton kernels maintain
full functional accuracy while delivering significant performance
improvements across various batch sizes and MTP configurations.
### Does this PR introduce _any_ user-facing change?
Yes, this PR modifies rejection_sampler.py to use optimized Triton
kernels:
rejection_random_sample_kernel is modified and optimized
### How was this patch tested?
performance benchmark results:
<html xmlns:v="urn:schemas-microsoft-com:vml"
xmlns:o="urn:schemas-microsoft-com:office:office"
xmlns:x="urn:schemas-microsoft-com:office:excel"
xmlns="http://www.w3.org/TR/REC-html40">
<head>
<meta name=Generator content="Microsoft Excel">
<!--[if !mso]>
</head>
<body>
<!--StartFragment-->
Batch Size | MTP | origin implementation(us) | optimized version(us)
-- | -- | -- | --
1 | 1 | 2.934 | 3.64
8 | 1 | 4.467 | 4
32 | 1 | 6.98 | 4.54
64 | 1 | 11.087 | 6.42
128 | 1 | 13.414 | 7.84
256 | 1 | 19.66 | 8.487
512 | 1 | 39.908 | 11.62
1024 | 1 | 81.781 | 18.16
2048 | 1 | 137.923 | 32.934
1 | 2 | 3.4 | 4.02
8 | 2 | 3.74 | 4.24
32 | 2 | 6.373 | 7.394
64 | 2 | 9.747 | 6.46
128 | 2 | 12.98 | 7.76
256 | 2 | 20.834 | 9.787
512 | 2 | 39.314 | 13.56
1024 | 2 | 83.135 | 22.387
2048 | 2 | 157.563 | 40.607
<!--EndFragment-->
</body>
</html>
- vLLM version: release/v0.13.0
- vLLM main:
ad32e3e19c
Signed-off-by: 1024daniel <xxltju324@gmail.com>
This commit is contained in:
95
tests/e2e/nightly/ops/triton/test_rejection_sample.py
Normal file
95
tests/e2e/nightly/ops/triton/test_rejection_sample.py
Normal file
@@ -0,0 +1,95 @@
|
|||||||
|
import pytest
|
||||||
|
import torch
|
||||||
|
from vllm.v1.sample.rejection_sampler import \
|
||||||
|
rejection_random_sample_kernel as original_rejection_random_sample_kernel
|
||||||
|
|
||||||
|
from vllm_ascend.ops.triton.reject_sample import (
|
||||||
|
cal_grid_and_block_size, rejection_random_sample_kernel)
|
||||||
|
from vllm_ascend.ops.triton.triton_utils import init_device_properties_triton
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="function", autouse=True)
|
||||||
|
def setup_device_properties():
|
||||||
|
init_device_properties_triton()
|
||||||
|
yield
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("max_spec_len", [1, 2, 3])
|
||||||
|
@pytest.mark.parametrize("vocab_size", [151_936])
|
||||||
|
@pytest.mark.parametrize("batch_size", [1, 8, 32, 64, 128, 256, 512, 1024])
|
||||||
|
@torch.inference_mode()
|
||||||
|
def test_rejection_random_sample(max_spec_len, vocab_size, batch_size):
|
||||||
|
device = 'npu'
|
||||||
|
torch.manual_seed(0)
|
||||||
|
draft_probs = torch.rand(batch_size * max_spec_len,
|
||||||
|
vocab_size,
|
||||||
|
dtype=torch.float32,
|
||||||
|
device=device)
|
||||||
|
target_probs = torch.rand(batch_size * max_spec_len,
|
||||||
|
vocab_size,
|
||||||
|
dtype=torch.float32,
|
||||||
|
device=device)
|
||||||
|
bonus_token_ids = torch.randint(low=0,
|
||||||
|
high=vocab_size,
|
||||||
|
size=(batch_size, 1),
|
||||||
|
dtype=torch.int64,
|
||||||
|
device=device)
|
||||||
|
draft_token_ids = torch.randint(low=0,
|
||||||
|
high=vocab_size,
|
||||||
|
size=(batch_size * max_spec_len, ),
|
||||||
|
dtype=torch.int64,
|
||||||
|
device=device)
|
||||||
|
output_token_ids = torch.empty((batch_size, max_spec_len + 1),
|
||||||
|
dtype=torch.int64,
|
||||||
|
device=device)
|
||||||
|
original_output_token_ids = output_token_ids.clone()
|
||||||
|
num_tokens = draft_token_ids.shape[0]
|
||||||
|
uniform_probs = torch.rand((num_tokens, ),
|
||||||
|
dtype=torch.float32,
|
||||||
|
device=device)
|
||||||
|
num_draft_tokens = [max_spec_len] * batch_size
|
||||||
|
num_draft_tokens = torch.tensor(num_draft_tokens,
|
||||||
|
dtype=torch.int32,
|
||||||
|
device=device)
|
||||||
|
cu_num_draft_tokens = torch.cumsum(num_draft_tokens,
|
||||||
|
dim=0,
|
||||||
|
dtype=torch.int32)
|
||||||
|
is_greedy_ptr = torch.full((batch_size, ),
|
||||||
|
False,
|
||||||
|
dtype=torch.bool,
|
||||||
|
device=device)
|
||||||
|
recovered_ids = torch.zeros_like(draft_token_ids,
|
||||||
|
dtype=torch.int64,
|
||||||
|
device=device)
|
||||||
|
grid, block_size = cal_grid_and_block_size(batch_size)
|
||||||
|
original_rejection_random_sample_kernel[(batch_size, )](
|
||||||
|
original_output_token_ids,
|
||||||
|
cu_num_draft_tokens,
|
||||||
|
draft_token_ids,
|
||||||
|
draft_probs,
|
||||||
|
target_probs,
|
||||||
|
bonus_token_ids,
|
||||||
|
recovered_ids,
|
||||||
|
uniform_probs,
|
||||||
|
is_greedy_ptr,
|
||||||
|
max_spec_len,
|
||||||
|
vocab_size,
|
||||||
|
NO_DRAFT_PROBS=draft_probs is None,
|
||||||
|
)
|
||||||
|
rejection_random_sample_kernel[(grid, )](output_token_ids,
|
||||||
|
cu_num_draft_tokens,
|
||||||
|
draft_token_ids,
|
||||||
|
draft_probs,
|
||||||
|
target_probs,
|
||||||
|
bonus_token_ids,
|
||||||
|
recovered_ids,
|
||||||
|
uniform_probs,
|
||||||
|
is_greedy_ptr,
|
||||||
|
max_spec_len,
|
||||||
|
vocab_size,
|
||||||
|
batch_size,
|
||||||
|
NO_DRAFT_PROBS=draft_probs
|
||||||
|
is None,
|
||||||
|
BLOCK_SIZE=block_size)
|
||||||
|
torch.npu.synchronize()
|
||||||
|
assert torch.equal(original_output_token_ids, output_token_ids)
|
||||||
@@ -20,6 +20,17 @@ from vllm.triton_utils import tl, triton
|
|||||||
from vllm_ascend.ops.triton.triton_utils import get_vectorcore_num
|
from vllm_ascend.ops.triton.triton_utils import get_vectorcore_num
|
||||||
|
|
||||||
|
|
||||||
|
def cal_grid_and_block_size(batch_size: int):
|
||||||
|
vectorcore_num = get_vectorcore_num()
|
||||||
|
if batch_size <= vectorcore_num:
|
||||||
|
grid = batch_size
|
||||||
|
block_size = 1
|
||||||
|
else:
|
||||||
|
grid = vectorcore_num
|
||||||
|
block_size = triton.next_power_of_2(triton.cdiv(batch_size, grid))
|
||||||
|
return grid, block_size
|
||||||
|
|
||||||
|
|
||||||
@triton.jit(do_not_specialize=["max_spec_len"])
|
@triton.jit(do_not_specialize=["max_spec_len"])
|
||||||
def bonus_renew_1(
|
def bonus_renew_1(
|
||||||
bonus_token_ids_ptr,
|
bonus_token_ids_ptr,
|
||||||
@@ -131,62 +142,72 @@ def rejection_greedy_sample_triton(
|
|||||||
|
|
||||||
@triton.jit(do_not_specialize=["max_spec_len"])
|
@triton.jit(do_not_specialize=["max_spec_len"])
|
||||||
def rejection_random_sample_kernel(
|
def rejection_random_sample_kernel(
|
||||||
output_token_ids_ptr, # [batch_size, max_spec_len + 1]
|
output_token_ids_ptr, # [batch_size, max_spec_len + 1]
|
||||||
cu_num_draft_tokens_ptr, # [batch_size]
|
cu_num_draft_tokens_ptr, # [batch_size]
|
||||||
draft_token_ids_ptr, # [num_tokens]
|
draft_token_ids_ptr, # [num_tokens]
|
||||||
draft_probs_ptr, # [num_tokens, vocab_size] or None
|
draft_probs_ptr, # [num_tokens, vocab_size] or None
|
||||||
target_probs_ptr, # [num_tokens, vocab_size]
|
target_probs_ptr, # [num_tokens, vocab_size]
|
||||||
bonus_token_ids_ptr, # [batch_size]
|
bonus_token_ids_ptr, # [batch_size]
|
||||||
recovered_token_ids_ptr, # [num_tokens]
|
recovered_token_ids_ptr, # [num_tokens]
|
||||||
uniform_probs_ptr, # [num_tokens]
|
uniform_probs_ptr, # [num_tokens]
|
||||||
is_greedy_ptr, # [batch_size]
|
is_greedy_ptr, # [batch_size]
|
||||||
max_spec_len,
|
max_spec_len,
|
||||||
vocab_size,
|
vocab_size,
|
||||||
NO_DRAFT_PROBS: tl.constexpr,
|
vec_len,
|
||||||
):
|
NO_DRAFT_PROBS: tl.constexpr,
|
||||||
req_idx = tl.program_id(0)
|
BLOCK_SIZE: tl.constexpr):
|
||||||
is_greedy = tl.load(is_greedy_ptr + req_idx)
|
block_idx = tl.program_id(0)
|
||||||
if is_greedy:
|
offsets = block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
|
||||||
# Early exost for greedy sampling requests
|
mask = offsets < vec_len
|
||||||
return
|
is_greedy = tl.load(is_greedy_ptr + offsets, mask, other=1)
|
||||||
|
not_greedy_mask = is_greedy == 0
|
||||||
start_idx = 0 if req_idx == 0 else tl.load(cu_num_draft_tokens_ptr +
|
start_idxs = tl.where(
|
||||||
req_idx - 1)
|
offsets == 0, 0,
|
||||||
end_idx = tl.load(cu_num_draft_tokens_ptr + req_idx)
|
tl.load(cu_num_draft_tokens_ptr + offsets - 1, not_greedy_mask))
|
||||||
num_draft_tokens = end_idx - start_idx
|
end_idxs = tl.load(cu_num_draft_tokens_ptr + offsets, not_greedy_mask)
|
||||||
|
n_num_draft_tokens = end_idxs - start_idxs
|
||||||
rejected = False
|
for req_i in range(BLOCK_SIZE):
|
||||||
for pos in range(num_draft_tokens):
|
not_greedy = tl.get_element(not_greedy_mask, (req_i, ))
|
||||||
if not rejected:
|
if not_greedy:
|
||||||
draft_token_id = tl.load(draft_token_ids_ptr + start_idx + pos)
|
rejected = False
|
||||||
if NO_DRAFT_PROBS:
|
start_idx = tl.get_element(start_idxs, (req_i, ))
|
||||||
draft_prob = 1
|
req_idx = block_idx * BLOCK_SIZE + req_i
|
||||||
else:
|
num_draft_tokens = tl.get_element(n_num_draft_tokens, (req_i, ))
|
||||||
draft_prob = tl.load(draft_probs_ptr +
|
for pos in range(num_draft_tokens):
|
||||||
(start_idx + pos) * vocab_size +
|
if not rejected:
|
||||||
draft_token_id)
|
draft_token_id = tl.load(draft_token_ids_ptr + start_idx +
|
||||||
target_prob = tl.load(target_probs_ptr +
|
pos)
|
||||||
(start_idx + pos) * vocab_size +
|
if NO_DRAFT_PROBS:
|
||||||
draft_token_id)
|
draft_prob = 1
|
||||||
uniform_prob = tl.load(uniform_probs_ptr + start_idx + pos)
|
else:
|
||||||
if draft_prob > 0 and target_prob / draft_prob >= uniform_prob:
|
draft_prob = tl.load(draft_probs_ptr +
|
||||||
# Accept
|
(start_idx + pos) * vocab_size +
|
||||||
token_id = draft_token_id
|
draft_token_id)
|
||||||
else:
|
target_prob = tl.load(target_probs_ptr +
|
||||||
# Reject. Use recovered token
|
(start_idx + pos) * vocab_size +
|
||||||
rejected = True
|
draft_token_id)
|
||||||
token_id = tl.load(recovered_token_ids_ptr + start_idx + pos)
|
uniform_prob = tl.load(uniform_probs_ptr + start_idx + pos)
|
||||||
tl.store(output_token_ids_ptr + req_idx * (max_spec_len + 1) + pos,
|
# NOTE(woosuk): While the draft probability should never be 0,
|
||||||
token_id)
|
# we check it to avoid NaNs. If it happens to be 0, we reject.
|
||||||
|
if draft_prob > 0 and target_prob / draft_prob >= uniform_prob:
|
||||||
if not rejected:
|
# Accept.
|
||||||
# If all tokens are accepted, append the bonus token
|
token_id = draft_token_id
|
||||||
bonus_token_id = tl.load(bonus_token_ids_ptr + req_idx)
|
else:
|
||||||
tl.store(
|
# Reject. Use recovered token.
|
||||||
output_token_ids_ptr + req_idx * (max_spec_len + 1) +
|
rejected = True
|
||||||
num_draft_tokens,
|
token_id = tl.load(recovered_token_ids_ptr +
|
||||||
bonus_token_id,
|
start_idx + pos)
|
||||||
)
|
tl.store(
|
||||||
|
output_token_ids_ptr + req_idx * (max_spec_len + 1) +
|
||||||
|
pos, token_id)
|
||||||
|
if not rejected:
|
||||||
|
# If all tokens are accepted, append the bonus token.
|
||||||
|
bonus_token_id = tl.load(bonus_token_ids_ptr + req_idx)
|
||||||
|
tl.store(
|
||||||
|
output_token_ids_ptr + req_idx * (max_spec_len + 1) +
|
||||||
|
num_draft_tokens,
|
||||||
|
bonus_token_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@triton.jit(do_not_specialize=["replace_from", "replace_to"])
|
@triton.jit(do_not_specialize=["replace_from", "replace_to"])
|
||||||
@@ -311,24 +332,12 @@ def sample_recovered_tokens_kernel(
|
|||||||
orig_prob)
|
orig_prob)
|
||||||
|
|
||||||
|
|
||||||
def rejection_greedy_sample_with_triton(
|
def rejection_greedy_sample_with_triton(output_token_ids, num_draft_tokens,
|
||||||
output_token_ids,
|
cu_num_draft_tokens, draft_token_ids,
|
||||||
num_draft_tokens,
|
target_argmax, bonus_token_ids,
|
||||||
cu_num_draft_tokens,
|
is_greedy, max_spec_len, grid,
|
||||||
draft_token_ids,
|
block_size):
|
||||||
target_argmax,
|
|
||||||
bonus_token_ids,
|
|
||||||
is_greedy,
|
|
||||||
max_spec_len,
|
|
||||||
):
|
|
||||||
vec_len = output_token_ids.shape[0]
|
vec_len = output_token_ids.shape[0]
|
||||||
n = cu_num_draft_tokens.numel()
|
|
||||||
BLOCK_SIZE = 2
|
|
||||||
grid = triton.cdiv(n, BLOCK_SIZE)
|
|
||||||
vectorcore_num = get_vectorcore_num()
|
|
||||||
if n >= vectorcore_num:
|
|
||||||
grid = vectorcore_num # Empirically tuned value
|
|
||||||
BLOCK_SIZE = triton.next_power_of_2(triton.cdiv(n, grid))
|
|
||||||
|
|
||||||
if min(num_draft_tokens) == 1 and max(
|
if min(num_draft_tokens) == 1 and max(
|
||||||
num_draft_tokens) == 1 and is_greedy is None:
|
num_draft_tokens) == 1 and is_greedy is None:
|
||||||
@@ -338,7 +347,7 @@ def rejection_greedy_sample_with_triton(
|
|||||||
target_argmax,
|
target_argmax,
|
||||||
bonus_token_ids,
|
bonus_token_ids,
|
||||||
vec_len,
|
vec_len,
|
||||||
BLOCK_SIZE=BLOCK_SIZE,
|
BLOCK_SIZE=block_size,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
rejection_greedy_sample_triton[(grid, )](
|
rejection_greedy_sample_triton[(grid, )](
|
||||||
@@ -350,20 +359,14 @@ def rejection_greedy_sample_with_triton(
|
|||||||
is_greedy,
|
is_greedy,
|
||||||
vec_len,
|
vec_len,
|
||||||
max_spec_len,
|
max_spec_len,
|
||||||
BLOCK_SIZE=BLOCK_SIZE,
|
BLOCK_SIZE=block_size,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def expand_triton(batch_size, expanded_x, x, cu_num_tokens, replace_from,
|
def expand_triton(batch_size, expanded_x, x, cu_num_tokens, replace_from,
|
||||||
replace_to, max_num_tokens):
|
replace_to, max_num_tokens):
|
||||||
vec_len = batch_size
|
vec_len = batch_size
|
||||||
n = cu_num_tokens.numel()
|
grid, block_size = cal_grid_and_block_size(batch_size)
|
||||||
BLOCK_SIZE = 2
|
|
||||||
grid = triton.cdiv(n, BLOCK_SIZE)
|
|
||||||
vectorcore_num = get_vectorcore_num()
|
|
||||||
if n >= vectorcore_num:
|
|
||||||
grid = vectorcore_num
|
|
||||||
BLOCK_SIZE = triton.next_power_of_2(triton.cdiv(n, grid))
|
|
||||||
|
|
||||||
expand_kernel[(grid, )](
|
expand_kernel[(grid, )](
|
||||||
expanded_x,
|
expanded_x,
|
||||||
@@ -373,5 +376,5 @@ def expand_triton(batch_size, expanded_x, x, cu_num_tokens, replace_from,
|
|||||||
replace_to,
|
replace_to,
|
||||||
vec_len,
|
vec_len,
|
||||||
MAX_NUM_TOKENS=max_num_tokens, # To avoid recompilation.
|
MAX_NUM_TOKENS=max_num_tokens, # To avoid recompilation.
|
||||||
BLOCK_SIZE=BLOCK_SIZE,
|
BLOCK_SIZE=block_size,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -8,8 +8,9 @@ from vllm.v1.sample.rejection_sampler import (GREEDY_TEMPERATURE,
|
|||||||
generate_uniform_probs)
|
generate_uniform_probs)
|
||||||
|
|
||||||
from vllm_ascend.ops.triton.reject_sample import (
|
from vllm_ascend.ops.triton.reject_sample import (
|
||||||
expand_triton, rejection_greedy_sample_with_triton,
|
cal_grid_and_block_size, expand_triton,
|
||||||
rejection_random_sample_kernel, sample_recovered_tokens_kernel)
|
rejection_greedy_sample_with_triton, rejection_random_sample_kernel,
|
||||||
|
sample_recovered_tokens_kernel)
|
||||||
from vllm_ascend.sample.sampler import apply_top_k_top_p
|
from vllm_ascend.sample.sampler import apply_top_k_top_p
|
||||||
|
|
||||||
PLACEHOLDER_TOKEN_ID = -1
|
PLACEHOLDER_TOKEN_ID = -1
|
||||||
@@ -119,20 +120,18 @@ def rejection_sample(
|
|||||||
is_greedy = None
|
is_greedy = None
|
||||||
else:
|
else:
|
||||||
is_greedy = sampling_metadata.temperature == GREEDY_TEMPERATURE
|
is_greedy = sampling_metadata.temperature == GREEDY_TEMPERATURE
|
||||||
|
if HAS_TRITON:
|
||||||
|
grid, block_size = cal_grid_and_block_size(batch_size)
|
||||||
if not sampling_metadata.all_random:
|
if not sampling_metadata.all_random:
|
||||||
# Rejection sampling for greedy sampling requests.
|
# Rejection sampling for greedy sampling requests.
|
||||||
target_argmax = target_probs.argmax(dim=-1)
|
target_argmax = target_probs.argmax(dim=-1)
|
||||||
if HAS_TRITON:
|
if HAS_TRITON:
|
||||||
rejection_greedy_sample_with_triton(
|
rejection_greedy_sample_with_triton(output_token_ids,
|
||||||
output_token_ids,
|
num_draft_tokens,
|
||||||
num_draft_tokens,
|
cu_num_draft_tokens,
|
||||||
cu_num_draft_tokens,
|
draft_token_ids, target_argmax,
|
||||||
draft_token_ids,
|
bonus_token_ids, is_greedy,
|
||||||
target_argmax,
|
max_spec_len, grid, block_size)
|
||||||
bonus_token_ids,
|
|
||||||
is_greedy,
|
|
||||||
max_spec_len,
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
if min(num_draft_tokens) == 1 and max(
|
if min(num_draft_tokens) == 1 and max(
|
||||||
num_draft_tokens) == 1 and sampling_metadata.all_greedy:
|
num_draft_tokens) == 1 and sampling_metadata.all_greedy:
|
||||||
@@ -180,7 +179,7 @@ def rejection_sample(
|
|||||||
|
|
||||||
# Rejection sampling for random sampling requests.
|
# Rejection sampling for random sampling requests.
|
||||||
if HAS_TRITON:
|
if HAS_TRITON:
|
||||||
rejection_random_sample_kernel[(batch_size, )](
|
rejection_random_sample_kernel[(grid, )](
|
||||||
output_token_ids,
|
output_token_ids,
|
||||||
cu_num_draft_tokens,
|
cu_num_draft_tokens,
|
||||||
draft_token_ids,
|
draft_token_ids,
|
||||||
@@ -192,7 +191,9 @@ def rejection_sample(
|
|||||||
is_greedy,
|
is_greedy,
|
||||||
max_spec_len,
|
max_spec_len,
|
||||||
vocab_size,
|
vocab_size,
|
||||||
|
batch_size,
|
||||||
NO_DRAFT_PROBS=draft_probs is None,
|
NO_DRAFT_PROBS=draft_probs is None,
|
||||||
|
BLOCK_SIZE=block_size,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
rejection_random_sample_pytorch(
|
rejection_random_sample_pytorch(
|
||||||
|
|||||||
Reference in New Issue
Block a user