[releases/v0.18.0][Triton][Sampler] Add penalty-related Triton kernel for better performance of penalties (#7794)

### What this PR does / why we need it?
Implement get_token_bin_counts_and_mask and apply_penalties with
Triton-Ascend kernels. This significantly reduces latency of the
sampling process when repetition/frequency/presence penalties are
enabled.

Cherry-pick from main PR #7569 
### Does this PR introduce _any_ user-facing change?
No.

### How was this patch tested?
CI passed.

Signed-off-by: linfeng-yuan <1102311262@qq.com>
Co-authored-by: realliujiaxu <realliujiaxu@163.com>
This commit is contained in:
linfeng-yuan
2026-03-31 19:01:51 +08:00
committed by GitHub
parent 82e26b5a6e
commit ed4ef1f4e7
5 changed files with 477 additions and 0 deletions

View File

@@ -0,0 +1,45 @@
# SPDX-License-Identifier: Apache-2.0
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
#
# apply_all_penalties for AscendSampler - uses Triton-Ascend kernels.
import torch
from vllm.utils.platform_utils import is_pin_memory_available
from vllm.utils.torch_utils import make_tensor_with_pad
from vllm_ascend.ops.triton.penalty import apply_penalties_triton
def _convert_to_tensors(output_token_ids: list[list[int]], vocab_size: int, device: torch.device) -> torch.Tensor:
"""Convert output_token_ids (list of lists) to padded tensor."""
output_tokens_tensor = make_tensor_with_pad(
output_token_ids,
pad=vocab_size,
device="cpu",
dtype=torch.int64,
pin_memory=is_pin_memory_available(),
)
return output_tokens_tensor.to(device, non_blocking=True)
def apply_all_penalties(
logits: torch.Tensor,
prompt_token_ids: torch.Tensor,
presence_penalties: torch.Tensor,
frequency_penalties: torch.Tensor,
repetition_penalties: torch.Tensor,
output_token_ids: list[list[int]],
) -> torch.Tensor:
"""Apply penalties to logits via Triton-Ascend."""
_, vocab_size = logits.shape
output_tokens_t = _convert_to_tensors(output_token_ids, vocab_size, logits.device)
output_tokens_t.masked_fill_(output_tokens_t == -1, vocab_size)
return apply_penalties_triton(
logits,
prompt_token_ids,
output_tokens_t,
presence_penalties,
frequency_penalties,
repetition_penalties,
)

View File

@@ -1,9 +1,12 @@
import torch
from vllm.model_executor.layers.batch_invariant import vllm_is_batch_invariant
from vllm.triton_utils import HAS_TRITON
from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.sample.ops.topk_topp_sampler import TopKTopPSampler
from vllm.v1.sample.sampler import Sampler
from vllm_ascend.ascend_config import get_ascend_config
from vllm_ascend.sample.penalties import apply_all_penalties
from vllm_ascend.utils import AscendDeviceType, get_ascend_device_type, global_stream, npu_stream_switch
DEFAULT_LOGPROBS_MODE = "raw_logprobs"
@@ -36,6 +39,28 @@ def random_sample(
class AscendSampler(Sampler):
@staticmethod
def apply_penalties(
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
output_token_ids: list[list[int]],
) -> torch.Tensor:
"""Use Triton-Ascend penalties on NPU when Triton is available; else vLLM default."""
if not HAS_TRITON:
return Sampler.apply_penalties(logits, sampling_metadata, output_token_ids)
if sampling_metadata.no_penalties:
return logits
assert sampling_metadata.prompt_token_ids is not None
return apply_all_penalties(
logits,
sampling_metadata.prompt_token_ids,
sampling_metadata.presence_penalties,
sampling_metadata.frequency_penalties,
sampling_metadata.repetition_penalties,
output_token_ids,
)
def __init__(self, logprobs_mode=DEFAULT_LOGPROBS_MODE):
# TODO: support logprobs_mode in vllm-ascend
super().__init__(logprobs_mode=logprobs_mode)