feat: implement high-performance Triton kernels for rejection sampling: optimization for rejection_random_sample_kernel (#5259)

### What this PR does / why we need it?

This PR introduces an optimized Triton implementation of
`rejection_random_sample_kernel` that outperforms the existing Triton
implementation. The new kernel maintains full functional accuracy while
delivering significant speedups across a range of batch sizes and MTP
configurations.
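For context, the kernel implements the standard speculative-decoding acceptance rule: each draft token is accepted with probability min(1, p_target / p_draft), and the first rejected position is replaced by a token recovered from the residual distribution. A minimal NumPy sketch of that rule (illustrative only; `rejection_sample_ref` and its arguments are hypothetical names, not the PR's code):

```python
import numpy as np

def rejection_sample_ref(draft_tokens, draft_probs, target_probs,
                         uniform, recovered):
    """Reference acceptance rule for speculative decoding (sketch).

    draft_tokens: proposed token ids, shape [k]
    draft_probs / target_probs: per-position distributions, shape [k, vocab]
    uniform: pre-drawn U(0, 1) samples, shape [k]
    recovered: fallback tokens from the residual distribution, shape [k]
    Returns the emitted tokens; stops at the first rejection.
    """
    out = []
    for i, tok in enumerate(draft_tokens):
        p_t = target_probs[i, tok]
        p_d = draft_probs[i, tok]
        # Accept draft token with probability min(1, p_target / p_draft).
        if uniform[i] < min(1.0, p_t / p_d):
            out.append(int(tok))
        else:
            # First rejection: emit the recovered token and stop.
            out.append(int(recovered[i]))
            break
    return out
```

The real kernels additionally handle bonus tokens and per-request draft lengths; this sketch only shows the accept/reject core.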

### Does this PR introduce _any_ user-facing change?

Yes. This PR modifies rejection_sampler.py to use the optimized Triton
kernel: `rejection_random_sample_kernel` now launches on a computed grid
with a `BLOCK_SIZE` of requests per program instead of one program per
request.

### How was this patch tested?
Performance benchmark results (times in microseconds; lower is better):
| Batch size | MTP | Original implementation (µs) | Optimized version (µs) |
| -- | -- | -- | -- |
| 1 | 1 | 2.934 | 3.64 |
| 8 | 1 | 4.467 | 4 |
| 32 | 1 | 6.98 | 4.54 |
| 64 | 1 | 11.087 | 6.42 |
| 128 | 1 | 13.414 | 7.84 |
| 256 | 1 | 19.66 | 8.487 |
| 512 | 1 | 39.908 | 11.62 |
| 1024 | 1 | 81.781 | 18.16 |
| 2048 | 1 | 137.923 | 32.934 |
| 1 | 2 | 3.4 | 4.02 |
| 8 | 2 | 3.74 | 4.24 |
| 32 | 2 | 6.373 | 7.394 |
| 64 | 2 | 9.747 | 6.46 |
| 128 | 2 | 12.98 | 7.76 |
| 256 | 2 | 20.834 | 9.787 |
| 512 | 2 | 39.314 | 13.56 |
| 1024 | 2 | 83.135 | 22.387 |
| 2048 | 2 | 157.563 | 40.607 |


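Numbers like these are typically collected with a small timing harness along the following lines (a sketch, not the actual benchmark script; the `sync` hook stands in for a device synchronize such as `torch.npu.synchronize`, which is needed so asynchronous kernel launches are fully counted):

```python
import time

def bench(fn, warmup=10, iters=100, sync=lambda: None):
    """Return the mean wall-clock time per call of fn, in microseconds.

    sync: a device synchronization callable (e.g. torch.npu.synchronize
    on Ascend); defaults to a no-op, which is only correct for
    synchronous (CPU-side) work.
    """
    for _ in range(warmup):          # warm caches / trigger JIT compile
        fn()
    sync()                           # drain pending async work
    start = time.perf_counter()
    for _ in range(iters):
        fn()
    sync()                           # make sure all iterations finished
    return (time.perf_counter() - start) / iters * 1e6
```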
- vLLM version: release/v0.13.0
- vLLM main: ad32e3e19c

Signed-off-by: 1024daniel <xxltju324@gmail.com>
Commit 8ffe3f5d78 (parent 91bf524364), authored by daniel on
2026-01-05 16:03:02 +08:00, committed by GitHub.
3 changed files with 195 additions and 96 deletions.


@@ -8,8 +8,9 @@ from vllm.v1.sample.rejection_sampler import (GREEDY_TEMPERATURE,
                                               generate_uniform_probs)
 from vllm_ascend.ops.triton.reject_sample import (
-    expand_triton, rejection_greedy_sample_with_triton,
-    rejection_random_sample_kernel, sample_recovered_tokens_kernel)
+    cal_grid_and_block_size, expand_triton,
+    rejection_greedy_sample_with_triton, rejection_random_sample_kernel,
+    sample_recovered_tokens_kernel)
 from vllm_ascend.sample.sampler import apply_top_k_top_p

 PLACEHOLDER_TOKEN_ID = -1
@@ -119,20 +120,18 @@ def rejection_sample(
         is_greedy = None
     else:
         is_greedy = sampling_metadata.temperature == GREEDY_TEMPERATURE
+    if HAS_TRITON:
+        grid, block_size = cal_grid_and_block_size(batch_size)
     if not sampling_metadata.all_random:
         # Rejection sampling for greedy sampling requests.
         target_argmax = target_probs.argmax(dim=-1)
         if HAS_TRITON:
-            rejection_greedy_sample_with_triton(
-                output_token_ids,
-                num_draft_tokens,
-                cu_num_draft_tokens,
-                draft_token_ids,
-                target_argmax,
-                bonus_token_ids,
-                is_greedy,
-                max_spec_len,
-            )
+            rejection_greedy_sample_with_triton(output_token_ids,
+                                                num_draft_tokens,
+                                                cu_num_draft_tokens,
+                                                draft_token_ids, target_argmax,
+                                                bonus_token_ids, is_greedy,
+                                                max_spec_len, grid, block_size)
     else:
         if min(num_draft_tokens) == 1 and max(
                 num_draft_tokens) == 1 and sampling_metadata.all_greedy:
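The body of `cal_grid_and_block_size` is not shown in this excerpt. A plausible shape for such a helper — purely an assumption about the heuristic, not the actual code in `vllm_ascend/ops/triton/reject_sample.py` — is to cap the launch grid at a fixed number of programs (e.g. the device's core count) and size each program's block so the whole batch is covered:

```python
def cal_grid_and_block_size(batch_size, max_grid=40):
    """Hypothetical sketch: cap the number of launched programs at
    max_grid (e.g. the NPU's vector-core count, assumed here to be 40)
    and ceil-divide the batch so grid * block_size >= batch_size."""
    grid = min(batch_size, max_grid)
    block_size = (batch_size + grid - 1) // grid  # ceil division
    return grid, block_size
```

This would explain the scaling in the benchmark table: the old version launches one program per request (grid = batch_size), while the new one amortizes launch and scheduling overhead by keeping the grid bounded.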
@@ -180,7 +179,7 @@ def rejection_sample(
         # Rejection sampling for random sampling requests.
         if HAS_TRITON:
-            rejection_random_sample_kernel[(batch_size, )](
+            rejection_random_sample_kernel[(grid, )](
                 output_token_ids,
                 cu_num_draft_tokens,
                 draft_token_ids,
@@ -192,7 +191,9 @@
                 is_greedy,
                 max_spec_len,
                 vocab_size,
+                batch_size,
                 NO_DRAFT_PROBS=draft_probs is None,
+                BLOCK_SIZE=block_size,
             )
         else:
             rejection_random_sample_pytorch(
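Under that launch scheme each program handles a contiguous block of requests instead of exactly one, which is what the new `BLOCK_SIZE` constexpr and the `batch_size` argument imply. A pure-Python model of the partitioning (illustrative, not the Triton source; function names are made up for this sketch):

```python
def program_requests(pid, block_size, batch_size):
    """Requests handled by program `pid` under a contiguous-block
    partition: [pid * block_size, (pid + 1) * block_size), clipped
    to the end of the batch."""
    start = pid * block_size
    return list(range(start, min(start + block_size, batch_size)))

def covered(grid, block_size, batch_size):
    """Concatenate every program's slice; with a correct grid/block
    choice each request appears exactly once across the grid."""
    seen = []
    for pid in range(grid):
        seen += program_requests(pid, block_size, batch_size)
    return seen
```

The clipping is why the kernel needs `batch_size` at runtime: the last program's block may extend past the end of the batch and must mask those lanes out.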