[releases/v0.18.0][Triton][Sampler] Add penalty-related Triton kernel for better performance of penalties (#7794)

### What this PR does / why we need it?
Implement `get_token_bin_counts_and_mask` and `apply_penalties` with
Triton-Ascend kernels. This significantly reduces the latency of the
sampling process when repetition, frequency, or presence penalties are
enabled.
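
For context, the kernels reproduce the penalty semantics of vLLM's PyTorch reference
(`vllm.v1.sample.ops.penalties.apply_all_penalties`), which is what the new test below
checks against. The following is a minimal PyTorch sketch of those semantics only; the
function names (`_bin_counts_and_mask`, `apply_penalties_reference`) are illustrative
and this is not the Triton kernel code itself:

```python
import torch


def _bin_counts_and_mask(tokens: torch.Tensor, vocab_size: int):
    # tokens: [num_seqs, seq_len]; ids equal to vocab_size act as padding and
    # land in an extra bin that is sliced off below.
    num_seqs = tokens.shape[0]
    counts = torch.zeros((num_seqs, vocab_size + 1), dtype=torch.long,
                         device=tokens.device)
    counts.scatter_add_(1, tokens, torch.ones_like(tokens))
    counts = counts[:, :vocab_size]
    return counts, counts > 0


def apply_penalties_reference(logits, prompt_tokens, output_tokens,
                              presence_p, frequency_p, repetition_p):
    num_seqs, vocab_size = logits.shape
    _, prompt_mask = _bin_counts_and_mask(prompt_tokens, vocab_size)
    out_counts, out_mask = _bin_counts_and_mask(output_tokens, vocab_size)
    seen = prompt_mask | out_mask
    # Repetition penalty: shrink positive logits and grow negative ones for
    # every token id already present in the prompt or the generated output.
    rep = torch.where(seen, repetition_p.unsqueeze(1), 1.0)
    logits = torch.where(logits > 0, logits / rep, logits * rep)
    # Frequency penalty scales with how often a token was generated;
    # presence penalty is a flat cost per distinct generated token.
    logits = logits - frequency_p.unsqueeze(1) * out_counts
    logits = logits - presence_p.unsqueeze(1) * out_mask
    return logits
```

The PR moves this bin counting and penalty application into Triton-Ascend kernels instead
of the eager PyTorch ops above, which is where the reported latency reduction comes from.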

Cherry-picked from the main-branch PR #7569.

### Does this PR introduce _any_ user-facing change?
No.

### How was this patch tested?
CI passed.

Signed-off-by: linfeng-yuan <1102311262@qq.com>
Co-authored-by: realliujiaxu <realliujiaxu@163.com>
Commit ed4ef1f4e7 (parent 82e26b5a6e), authored by linfeng-yuan on 2026-03-31 19:01:51 +08:00 and committed via GitHub.
5 changed files with 477 additions and 0 deletions.

@@ -0,0 +1,110 @@
# SPDX-License-Identifier: Apache-2.0
# Compare vllm_ascend.sample.penalties.apply_all_penalties (Triton-Ascend) with
# vllm.v1.sample.ops.penalties.apply_all_penalties (PyTorch via model_executor).
# Requires NPU and Triton-Ascend.
import gc
import pytest
import torch
from vllm.v1.sample.ops.penalties import apply_all_penalties as v1_apply_all_penalties
from vllm_ascend.sample.penalties import apply_all_penalties as ascend_apply_all_penalties
# Same scenario grid as test_apply_penalties_model_executor (equivalence + boundaries).
APPLY_PENALTY_CASES = [
    pytest.param(0, 0, "mixed", id="empty-both"),
    pytest.param(0, 16, "mixed", id="empty-prompt"),
    pytest.param(32, 0, "mixed", id="empty-output"),
    pytest.param(1, 1, "mixed", id="single-token-each"),
    pytest.param(32, 16, "mixed", id="typical-small"),
    pytest.param(128, 64, "mixed", id="typical-large"),
    pytest.param(128, 64, "all_padding", id="all-padding"),
]


def _make_tokens(
    num_seqs: int,
    seq_len: int,
    vocab_size: int,
    mode: str,
    device: str,
) -> torch.Tensor:
    if mode == "all_padding":
        return torch.full(
            (num_seqs, seq_len), vocab_size, device=device, dtype=torch.int64
        )
    if seq_len == 0:
        return torch.empty((num_seqs, 0), device=device, dtype=torch.int64)
    tokens = torch.randint(
        0, vocab_size, (num_seqs, seq_len), device=device, dtype=torch.int64
    )
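    # Mark roughly 30% of positions as padding (token id == vocab_size) so the
    # padding-handling path is exercised alongside real token ids.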
    pad_mask = torch.rand(num_seqs, seq_len, device=device) > 0.7
    tokens[pad_mask] = vocab_size
    return tokens


@pytest.mark.parametrize("num_seqs", [1, 8, 32, 128])
@pytest.mark.parametrize("vocab_size", [5120, 151936])
@pytest.mark.parametrize(
    "max_prompt_len,max_output_len,token_mode",
    APPLY_PENALTY_CASES,
)
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
@torch.inference_mode()
def test_apply_all_penalties_v1_vs_ascend(
    num_seqs,
    vocab_size,
    max_prompt_len,
    max_output_len,
    token_mode,
    dtype,
    device="npu",
    seed=42,
):
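    # Register device properties for the Triton-Ascend backend before launching
    # kernels; importing lazily keeps module import free of the NPU backend.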
    from vllm_ascend.ops.triton.triton_utils import init_device_properties_triton

    init_device_properties_triton()
    torch.manual_seed(seed)

    logits_v1 = torch.randn(num_seqs, vocab_size, device=device, dtype=dtype)
    logits_ascend = logits_v1.clone()
    prompt_tokens = _make_tokens(
        num_seqs, max_prompt_len, vocab_size, token_mode, device
    )
    output_tokens = _make_tokens(
        num_seqs, max_output_len, vocab_size, token_mode, device
    )
    output_token_ids = [row.tolist() for row in output_tokens.cpu()]
    presence_penalties = torch.rand(num_seqs, device=device, dtype=torch.float32) * 0.2
    frequency_penalties = torch.rand(num_seqs, device=device, dtype=torch.float32) * 0.2
    repetition_penalties = torch.rand(num_seqs, device=device, dtype=torch.float32) * 0.4 + 1.0

    v1_apply_all_penalties(
        logits_v1,
        prompt_tokens,
        presence_penalties,
        frequency_penalties,
        repetition_penalties,
        output_token_ids,
    )
    ascend_apply_all_penalties(
        logits_ascend,
        prompt_tokens,
        presence_penalties,
        frequency_penalties,
        repetition_penalties,
        output_token_ids,
    )
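    # bfloat16 has fewer mantissa bits than float16, so compare with a looser
    # tolerance for that dtype.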
    atol = 1e-2 if dtype == torch.bfloat16 else 1e-3
    rtol = 1e-2 if dtype == torch.bfloat16 else 1e-3
    assert torch.allclose(
        logits_ascend.float(), logits_v1.float(), atol=atol, rtol=rtol
    ), (
        f"Max diff: {(logits_ascend.float() - logits_v1.float()).abs().max().item()}"
    )

    gc.collect()
    torch.npu.empty_cache()
    torch.npu.reset_peak_memory_stats()