Files
xc-llm-ascend/tests/e2e/nightly/ops/triton/test_rejection_sample.py

96 lines
4.2 KiB
Python
Raw Normal View History

feat: implement high-performance Triton kernels for rejection sampling: optimization for rejection_random_sample_kernel (#5259)

### What this PR does / why we need it?
This PR introduces an optimized Triton implementation of rejection_random_sample_kernel. It is faster than the existing Triton implementation for batch sizes of 64 and above. At very small batch sizes it is slightly slower: batch size 1, and batch sizes 8 and 32 with MTP=2 (see the table below). The new kernel maintains full functional accuracy across the tested batch sizes and MTP configurations.

### Does this PR introduce _any_ user-facing change?
Yes. This PR modifies rejection_sampler.py to use the optimized Triton kernel: rejection_random_sample_kernel is modified and optimized.

### How was this patch tested?
Performance benchmark results:

| Batch Size | MTP | original implementation (µs) | optimized version (µs) |
| -- | -- | -- | -- |
| 1 | 1 | 2.934 | 3.64 |
| 8 | 1 | 4.467 | 4 |
| 32 | 1 | 6.98 | 4.54 |
| 64 | 1 | 11.087 | 6.42 |
| 128 | 1 | 13.414 | 7.84 |
| 256 | 1 | 19.66 | 8.487 |
| 512 | 1 | 39.908 | 11.62 |
| 1024 | 1 | 81.781 | 18.16 |
| 2048 | 1 | 137.923 | 32.934 |
| 1 | 2 | 3.4 | 4.02 |
| 8 | 2 | 3.74 | 4.24 |
| 32 | 2 | 6.373 | 7.394 |
| 64 | 2 | 9.747 | 6.46 |
| 128 | 2 | 12.98 | 7.76 |
| 256 | 2 | 20.834 | 9.787 |
| 512 | 2 | 39.314 | 13.56 |
| 1024 | 2 | 83.135 | 22.387 |
| 2048 | 2 | 157.563 | 40.607 |

- vLLM version: release/v0.13.0
- vLLM main: https://github.com/vllm-project/vllm/commit/ad32e3e19ccf0526cb6744a5fed09a138a5fb2f9

Signed-off-by: 1024daniel <xxltju324@gmail.com>
2026-01-05 16:03:02 +08:00
import pytest
import torch
from vllm.v1.sample.rejection_sampler import \
rejection_random_sample_kernel as original_rejection_random_sample_kernel
from vllm_ascend.ops.triton.reject_sample import (
cal_grid_and_block_size, rejection_random_sample_kernel)
from vllm_ascend.ops.triton.triton_utils import init_device_properties_triton
@pytest.fixture(autouse=True, scope="function")
def setup_device_properties():
    """Initialize Triton device properties before each test in this module.

    The fixture is autouse and function-scoped, so
    ``init_device_properties_triton`` runs ahead of every test.
    It yields no value and performs no teardown afterwards.
    """
    init_device_properties_triton()
    yield None
@pytest.mark.parametrize("max_spec_len", [1, 2, 3])
@pytest.mark.parametrize("vocab_size", [151_936])
@pytest.mark.parametrize("batch_size", [1, 8, 32, 64, 128, 256, 512, 1024])
@torch.inference_mode()
def test_rejection_random_sample(max_spec_len, vocab_size, batch_size):
    """Compare the Ascend rejection_random_sample_kernel against vLLM's.

    Both kernels receive identical random inputs for a batch of non-greedy
    requests, each carrying exactly ``max_spec_len`` draft tokens.
    Their output buffers are pre-filled identically, and after the launches
    both kernels must have written exactly the same token ids.
    """
    device = 'npu'
    # Value used to pre-fill the output buffers. Presumably this matches
    # vLLM's PLACEHOLDER_TOKEN_ID (-1) -- TODO confirm.
    placeholder_token_id = -1
    torch.manual_seed(0)
    num_tokens = batch_size * max_spec_len

    draft_probs = torch.rand(num_tokens,
                             vocab_size,
                             dtype=torch.float32,
                             device=device)
    target_probs = torch.rand(num_tokens,
                              vocab_size,
                              dtype=torch.float32,
                              device=device)
    bonus_token_ids = torch.randint(low=0,
                                    high=vocab_size,
                                    size=(batch_size, 1),
                                    dtype=torch.int64,
                                    device=device)
    draft_token_ids = torch.randint(low=0,
                                    high=vocab_size,
                                    size=(num_tokens, ),
                                    dtype=torch.int64,
                                    device=device)
    uniform_probs = torch.rand((num_tokens, ),
                               dtype=torch.float32,
                               device=device)
    # Random recovered ids (instead of all zeros): a kernel that reads the
    # wrong recovered slot on rejection now produces a visible mismatch.
    recovered_ids = torch.randint(low=0,
                                  high=vocab_size,
                                  size=(num_tokens, ),
                                  dtype=torch.int64,
                                  device=device)
    # Every request carries exactly max_spec_len draft tokens.
    num_draft_tokens = torch.full((batch_size, ),
                                  max_spec_len,
                                  dtype=torch.int32,
                                  device=device)
    cu_num_draft_tokens = torch.cumsum(num_draft_tokens,
                                       dim=0,
                                       dtype=torch.int32)
    # No greedy requests, so only the random-sampling path is exercised.
    is_greedy_ptr = torch.zeros((batch_size, ),
                                dtype=torch.bool,
                                device=device)
    no_draft_probs = draft_probs is None

    # Pre-fill rather than torch.empty: slots a kernel leaves untouched then
    # hold a known value instead of uninitialized device memory.
    output_token_ids = torch.full((batch_size, max_spec_len + 1),
                                  placeholder_token_id,
                                  dtype=torch.int64,
                                  device=device)
    original_output_token_ids = output_token_ids.clone()

    # Reference: vLLM's kernel, launched with one program per request.
    original_rejection_random_sample_kernel[(batch_size, )](
        original_output_token_ids,
        cu_num_draft_tokens,
        draft_token_ids,
        draft_probs,
        target_probs,
        bonus_token_ids,
        recovered_ids,
        uniform_probs,
        is_greedy_ptr,
        max_spec_len,
        vocab_size,
        NO_DRAFT_PROBS=no_draft_probs,
    )

    # Candidate: the Ascend kernel. Its launch grid and BLOCK_SIZE come from
    # cal_grid_and_block_size for this batch size.
    grid, block_size = cal_grid_and_block_size(batch_size)
    rejection_random_sample_kernel[(grid, )](
        output_token_ids,
        cu_num_draft_tokens,
        draft_token_ids,
        draft_probs,
        target_probs,
        bonus_token_ids,
        recovered_ids,
        uniform_probs,
        is_greedy_ptr,
        max_spec_len,
        vocab_size,
        batch_size,
        NO_DRAFT_PROBS=no_draft_probs,
        BLOCK_SIZE=block_size,
    )
    torch.npu.synchronize()

    # Every request has at least one draft token and none is greedy, so the
    # reference kernel is expected to write column 0 of every row. This
    # guards against the comparison passing trivially on untouched buffers.
    assert (original_output_token_ids[:, 0] != placeholder_token_id).all()
    # Exact match (rtol=atol=0); assert_close reports where the outputs differ.
    torch.testing.assert_close(output_token_ids,
                               original_output_token_ids,
                               rtol=0,
                               atol=0)