From ed4ef1f4e7ca0cce827025ad566aafed07f7bcf8 Mon Sep 17 00:00:00 2001 From: linfeng-yuan <1102311262@qq.com> Date: Tue, 31 Mar 2026 19:01:51 +0800 Subject: [PATCH] [releases/v0.18.0][Triton][Sampler] Add penalty-related Triton kernel for better performance of penalties (#7794) ### What this PR does / why we need it? Implement get_token_bin_counts_and_mask and apply_penalties with Triton-Ascend kernels. This significantly reduces latency of the sampling process when repetition/frequency/presence penalties are enabled. Cherry-pick from main PR #7569 ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? CI passed. Signed-off-by: linfeng-yuan <1102311262@qq.com> Co-authored-by: realliujiaxu --- .../triton/test_apply_penalties_triton.py | 110 ++++++++++++ vllm_ascend/ops/triton/bincount.py | 139 +++++++++++++++ vllm_ascend/ops/triton/penalty.py | 158 ++++++++++++++++++ vllm_ascend/sample/penalties.py | 45 +++++ vllm_ascend/sample/sampler.py | 25 +++ 5 files changed, 477 insertions(+) create mode 100644 tests/e2e/nightly/single_node/ops/singlecard_ops/triton/test_apply_penalties_triton.py create mode 100644 vllm_ascend/ops/triton/bincount.py create mode 100644 vllm_ascend/ops/triton/penalty.py create mode 100644 vllm_ascend/sample/penalties.py diff --git a/tests/e2e/nightly/single_node/ops/singlecard_ops/triton/test_apply_penalties_triton.py b/tests/e2e/nightly/single_node/ops/singlecard_ops/triton/test_apply_penalties_triton.py new file mode 100644 index 00000000..625f2bce --- /dev/null +++ b/tests/e2e/nightly/single_node/ops/singlecard_ops/triton/test_apply_penalties_triton.py @@ -0,0 +1,110 @@ +# SPDX-License-Identifier: Apache-2.0 +# Compare vllm_ascend.sample.penalties.apply_all_penalties (Triton-Ascend) with +# vllm.v1.sample.ops.penalties.apply_all_penalties (PyTorch via model_executor). +# Requires NPU and Triton-Ascend. + +import gc +import pytest +import torch + +from vllm.v1.sample.ops.penalties import apply_all_penalties as v1_apply_all_penalties +from vllm_ascend.sample.penalties import apply_all_penalties as ascend_apply_all_penalties + +# Same scenario grid as test_apply_penalties_model_executor (equivalence + boundaries). 
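+# Each case is (max_prompt_len, max_output_len, token_mode). In "mixed" mode,
+# roughly 30% of slots are overwritten with vocab_size (the padding id); in
+# "all_padding" mode every slot is vocab_size, so penalties must be a no-op.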
+APPLY_PENALTY_CASES = [ + pytest.param(0, 0, "mixed", id="empty-both"), + pytest.param(0, 16, "mixed", id="empty-prompt"), + pytest.param(32, 0, "mixed", id="empty-output"), + pytest.param(1, 1, "mixed", id="single-token-each"), + pytest.param(32, 16, "mixed", id="typical-small"), + pytest.param(128, 64, "mixed", id="typical-large"), + pytest.param(128, 64, "all_padding", id="all-padding"), +] + + +def _make_tokens( + num_seqs: int, + seq_len: int, + vocab_size: int, + mode: str, + device: str, +) -> torch.Tensor: + if mode == "all_padding": + return torch.full( + (num_seqs, seq_len), vocab_size, device=device, dtype=torch.int64 + ) + if seq_len == 0: + return torch.empty((num_seqs, 0), device=device, dtype=torch.int64) + tokens = torch.randint( + 0, vocab_size, (num_seqs, seq_len), device=device, dtype=torch.int64 + ) + pad_mask = torch.rand(num_seqs, seq_len, device=device) > 0.7 + tokens[pad_mask] = vocab_size + return tokens + + +@pytest.mark.parametrize("num_seqs", [1, 8, 32, 128]) +@pytest.mark.parametrize("vocab_size", [5120, 151936]) +@pytest.mark.parametrize( + "max_prompt_len,max_output_len,token_mode", + APPLY_PENALTY_CASES, +) +@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) +@torch.inference_mode() +def test_apply_all_penalties_v1_vs_ascend( + num_seqs, + vocab_size, + max_prompt_len, + max_output_len, + token_mode, + dtype, + device="npu", + seed=42, +): + from vllm_ascend.ops.triton.triton_utils import init_device_properties_triton + + init_device_properties_triton() + torch.manual_seed(seed) + + logits_v1 = torch.randn(num_seqs, vocab_size, device=device, dtype=dtype) + logits_ascend = logits_v1.clone() + + prompt_tokens = _make_tokens( + num_seqs, max_prompt_len, vocab_size, token_mode, device + ) + output_tokens = _make_tokens( + num_seqs, max_output_len, vocab_size, token_mode, device + ) + output_token_ids = [row.tolist() for row in output_tokens.cpu()] + + presence_penalties = torch.rand(num_seqs, device=device, dtype=torch.float32) * 0.2 + frequency_penalties = torch.rand(num_seqs, device=device, dtype=torch.float32) * 0.2 + repetition_penalties = torch.rand(num_seqs, device=device, dtype=torch.float32) * 0.4 + 1.0 + + v1_apply_all_penalties( + logits_v1, + prompt_tokens, + presence_penalties, + frequency_penalties, + repetition_penalties, + output_token_ids, + ) + ascend_apply_all_penalties( + logits_ascend, + prompt_tokens, + presence_penalties, + frequency_penalties, + repetition_penalties, + output_token_ids, + ) + + atol = 1e-2 if dtype == torch.bfloat16 else 1e-3 + rtol = 1e-2 if dtype == torch.bfloat16 else 1e-3 + assert torch.allclose( + logits_ascend.float(), logits_v1.float(), atol=atol, rtol=rtol + ), ( + f"Max diff: {(logits_ascend.float() - logits_v1.float()).abs().max().item()}" + ) + gc.collect() + torch.npu.empty_cache() + torch.npu.reset_peak_memory_stats() diff --git a/vllm_ascend/ops/triton/bincount.py b/vllm_ascend/ops/triton/bincount.py new file mode 100644 index 00000000..20c586ab --- /dev/null +++ b/vllm_ascend/ops/triton/bincount.py @@ -0,0 +1,139 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# This file is a part of the vllm-ascend project. +# +# Triton-Ascend implementation of get_token_bin_counts_and_mask. +# Migrated from model_executor/layers/utils.get_token_bin_counts_and_mask. +# Reference: https://github.com/vllm-project/vllm-ascend/pull/6979 + +import torch +from vllm.triton_utils import tl, triton + +from vllm_ascend.ops.triton.triton_utils import get_vectorcore_num + + +@triton.jit +def token_bin_counts_and_mask_kernel( + tokens_ptr, + tokens_batch_stride, + tokens_seq_stride, + batch_size, + seq_len, + vocab_size, + bin_counts_ptr, + counts_batch_stride, + counts_vocab_stride, + SEQ_BLOCK: tl.constexpr, +): + """Count token occurrences per batch row. + + 2D tiling: + - axis=0: core/program group dimension + - axis=1: block id dimension + + We linearize (batch_idx, seq_block_id) into a single global block id and + distribute blocks across all programs to improve utilization when + batch_size is small but seq_len is large (typical prefill). + + Tokens with value >= vocab_size (e.g. padding) are skipped. + """ + pid0 = tl.program_id(axis=0) + pid1 = tl.program_id(axis=1) + progs = tl.num_programs(axis=0) + + n_seq_blocks = tl.cdiv(seq_len, SEQ_BLOCK) + linear_block = pid1 * progs + pid0 + total_blocks = batch_size * n_seq_blocks + if linear_block >= total_blocks: + return + + batch_idx = linear_block // n_seq_blocks + seq_block_id = linear_block - batch_idx * n_seq_blocks + seq_start = seq_block_id * SEQ_BLOCK + + batch_tokens_start = tokens_ptr + batch_idx * tokens_batch_stride + batch_counts_start = bin_counts_ptr + batch_idx * counts_batch_stride + + pos_offsets = seq_start + tl.arange(0, SEQ_BLOCK) + pos_mask = pos_offsets < seq_len + token = tl.load( + batch_tokens_start + pos_offsets * tokens_seq_stride, + mask=pos_mask, + other=vocab_size, # force invalid + ) + # Only count valid token ids in [0, vocab_size). Padding must use id >= vocab_size + # (see vLLM apply_penalties contract); those positions are masked out here. + token_in_range = (token >= 0) & (token < vocab_size) & pos_mask + count_ptr = batch_counts_start + token * counts_vocab_stride + tl.atomic_add(count_ptr, 1, mask=token_in_range) + + +def get_token_bin_counts_and_mask_triton( + tokens: torch.Tensor, + vocab_size: int, + num_seqs: int | None = None, +) -> tuple[torch.Tensor, torch.Tensor]: + """Triton-Ascend implementation of token bin counting. + + Args: + tokens: [num_seqs, seq_len] tensor of token IDs. Padding value + should be vocab_size and will be ignored. + vocab_size: Vocabulary size. + num_seqs: If provided, asserts tokens.shape[0] == num_seqs. + + Returns: + bin_counts: [num_seqs, vocab_size] int32 counts. + mask: [num_seqs, vocab_size] bool, True where count > 0. + """ + n_rows, n_cols = tokens.shape + if num_seqs is not None and num_seqs > 0: + assert n_rows == num_seqs, f"tokens rows must match num_seqs: tokens.shape[0]={n_rows}, num_seqs={num_seqs}" + n_rows = num_seqs if num_seqs is not None else n_rows + + # seq_len == 0 is valid for empty decode history; return directly. 
+    if n_rows == 0 or n_cols == 0:  # an empty batch would yield a zero-size grid below
+        bin_counts = torch.zeros((n_rows, vocab_size), dtype=torch.int32, device=tokens.device)
+        return bin_counts, bin_counts > 0
+
+    core_num = get_vectorcore_num()
+
+    bin_counts = torch.zeros((n_rows, vocab_size), dtype=torch.int32, device=tokens.device)
+    if not tokens.is_contiguous():
+        tokens = tokens.contiguous()
+
+    # 2D grid: (progs, blocks_per_prog_group)
+    # Keep axis-0 bounded by vector core count, and distribute (batch, seq_block)
+    # blocks across all programs to increase utilization when n_rows is small.
+    SEQ_BLOCK = 256
+    n_seq_blocks = triton.cdiv(n_cols, SEQ_BLOCK)
+    total_blocks = n_rows * n_seq_blocks
+    progs = min(core_num, total_blocks)
+    grid = (progs, triton.cdiv(total_blocks, progs))
+
+    token_bin_counts_and_mask_kernel[grid](
+        tokens,
+        tokens.stride(0),
+        tokens.stride(1),
+        n_rows,
+        n_cols,
+        vocab_size,
+        bin_counts,
+        bin_counts.stride(0),
+        bin_counts.stride(1),
+        SEQ_BLOCK=SEQ_BLOCK,
+    )
+    return bin_counts, bin_counts > 0
diff --git a/vllm_ascend/ops/triton/penalty.py b/vllm_ascend/ops/triton/penalty.py
new file mode 100644
index 00000000..2db719fa
--- /dev/null
+++ b/vllm_ascend/ops/triton/penalty.py
@@ -0,0 +1,158 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+#
+# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# This file is a part of the vllm-ascend project.
+#
+# Triton-Ascend implementation of apply_penalties.
+# Migrated from model_executor/layers/utils.apply_penalties.
+# Reference: https://github.com/vllm-project/vllm-ascend/pull/6979 + +import torch +from vllm.triton_utils import tl, triton + +from vllm_ascend.ops.triton.bincount import get_token_bin_counts_and_mask_triton +from vllm_ascend.ops.triton.triton_utils import get_vectorcore_num + + +@triton.jit +def apply_all_penalties_kernel( + logits_ptr, + prompt_mask_ptr, + output_mask_ptr, + output_bin_counts_ptr, + repetition_penalties_ptr, + frequency_penalties_ptr, + presence_penalties_ptr, + num_seqs, + vocab_size, + stride_logits_seq, + stride_logits_vocab, + stride_prompt_mask_seq, + stride_prompt_mask_vocab, + stride_output_mask_seq, + stride_output_mask_vocab, + stride_bin_counts_seq, + stride_bin_counts_vocab, + BLOCK_SIZE: tl.constexpr, +): + """Apply repetition, frequency, and presence penalties to logits in place.""" + pid = tl.program_id(axis=0) + num_programs = tl.num_programs(axis=0) + seqs_per_program = (num_seqs + num_programs - 1) // num_programs + + start_seq = pid * seqs_per_program + end_seq = tl.minimum(start_seq + seqs_per_program, num_seqs) + + for seq_idx in range(start_seq, end_seq): + repetition_penalty = tl.load(repetition_penalties_ptr + seq_idx) + frequency_penalty = tl.load(frequency_penalties_ptr + seq_idx) + presence_penalty = tl.load(presence_penalties_ptr + seq_idx) + + for vocab_start in range(0, vocab_size, BLOCK_SIZE): + vocab_offsets = vocab_start + tl.arange(0, BLOCK_SIZE) + mask = vocab_offsets < vocab_size + + logits_offset = seq_idx * stride_logits_seq + vocab_offsets * stride_logits_vocab + prompt_mask_offset = seq_idx * stride_prompt_mask_seq + vocab_offsets * stride_prompt_mask_vocab + output_mask_offset = seq_idx * stride_output_mask_seq + vocab_offsets * stride_output_mask_vocab + counts_offset = seq_idx * stride_bin_counts_seq + vocab_offsets * stride_bin_counts_vocab + + logits = tl.load(logits_ptr + logits_offset, mask=mask, other=0.0) + prompt_mask_val = tl.load(prompt_mask_ptr + prompt_mask_offset, mask=mask, other=False) + output_mask_val = tl.load(output_mask_ptr + output_mask_offset, mask=mask, other=False) + output_bin_counts = tl.load( + output_bin_counts_ptr + counts_offset, + mask=mask, + other=0, + ).to(tl.float32) + + need_repetition_penalty = (prompt_mask_val | output_mask_val).to(tl.int1) + penalty_factor = tl.where(need_repetition_penalty, repetition_penalty, 1.0) + scaling = tl.where( + (logits > 0.0).to(tl.int1), + 1.0 / penalty_factor, + penalty_factor, + ) + updated = logits * scaling + + updated -= frequency_penalty * output_bin_counts + updated -= presence_penalty * output_mask_val.to(tl.float32) + tl.store(logits_ptr + logits_offset, updated, mask=mask) + + +def apply_penalties_triton( + logits: torch.Tensor, + prompt_tokens_tensor: torch.Tensor, + output_tokens_tensor: torch.Tensor, + presence_penalties: torch.Tensor, + frequency_penalties: torch.Tensor, + repetition_penalties: torch.Tensor, +) -> torch.Tensor: + """Apply penalties to logits in place. Same interface as + model_executor.layers.utils.apply_penalties. 
+ """ + num_seqs, vocab_size = logits.shape + _, prompt_mask = get_token_bin_counts_and_mask_triton(prompt_tokens_tensor, vocab_size, num_seqs) + output_bin_counts, output_mask = get_token_bin_counts_and_mask_triton( + output_tokens_tensor, + vocab_size, + num_seqs, + ) + _apply_all_penalties_triton( + logits, + prompt_mask, + output_mask, + output_bin_counts, + repetition_penalties, + frequency_penalties, + presence_penalties, + ) + return logits + + +def _apply_all_penalties_triton( + logits: torch.Tensor, + prompt_mask: torch.Tensor, + output_mask: torch.Tensor, + output_bin_counts: torch.Tensor, + repetition_penalties: torch.Tensor, + frequency_penalties: torch.Tensor, + presence_penalties: torch.Tensor, +) -> None: + """Apply all penalties given precomputed bin counts and masks.""" + num_seqs, vocab_size = logits.shape + grid = (min(num_seqs, get_vectorcore_num()), 1, 1) + + apply_all_penalties_kernel[grid]( + logits, + prompt_mask, + output_mask, + output_bin_counts, + repetition_penalties, + frequency_penalties, + presence_penalties, + num_seqs=num_seqs, + vocab_size=vocab_size, + stride_logits_seq=logits.stride(0), + stride_logits_vocab=logits.stride(1), + stride_prompt_mask_seq=prompt_mask.stride(0), + stride_prompt_mask_vocab=prompt_mask.stride(1), + stride_output_mask_seq=output_mask.stride(0), + stride_output_mask_vocab=output_mask.stride(1), + stride_bin_counts_seq=output_bin_counts.stride(0), + stride_bin_counts_vocab=output_bin_counts.stride(1), + BLOCK_SIZE=2048, + ) diff --git a/vllm_ascend/sample/penalties.py b/vllm_ascend/sample/penalties.py new file mode 100644 index 00000000..18d176b0 --- /dev/null +++ b/vllm_ascend/sample/penalties.py @@ -0,0 +1,45 @@ +# SPDX-License-Identifier: Apache-2.0 +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# +# apply_all_penalties for AscendSampler - uses Triton-Ascend kernels. 
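+#
+# Contract note: output token lists are padded into a rectangular tensor with
+# vocab_size as the pad id, and the bincount kernel ignores any id outside
+# [0, vocab_size), so padded slots never contribute to a penalty.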
+ +import torch +from vllm.utils.platform_utils import is_pin_memory_available +from vllm.utils.torch_utils import make_tensor_with_pad + +from vllm_ascend.ops.triton.penalty import apply_penalties_triton + + +def _convert_to_tensors(output_token_ids: list[list[int]], vocab_size: int, device: torch.device) -> torch.Tensor: + """Convert output_token_ids (list of lists) to padded tensor.""" + output_tokens_tensor = make_tensor_with_pad( + output_token_ids, + pad=vocab_size, + device="cpu", + dtype=torch.int64, + pin_memory=is_pin_memory_available(), + ) + return output_tokens_tensor.to(device, non_blocking=True) + + +def apply_all_penalties( + logits: torch.Tensor, + prompt_token_ids: torch.Tensor, + presence_penalties: torch.Tensor, + frequency_penalties: torch.Tensor, + repetition_penalties: torch.Tensor, + output_token_ids: list[list[int]], +) -> torch.Tensor: + """Apply penalties to logits via Triton-Ascend.""" + _, vocab_size = logits.shape + output_tokens_t = _convert_to_tensors(output_token_ids, vocab_size, logits.device) + output_tokens_t.masked_fill_(output_tokens_t == -1, vocab_size) + + return apply_penalties_triton( + logits, + prompt_token_ids, + output_tokens_t, + presence_penalties, + frequency_penalties, + repetition_penalties, + ) diff --git a/vllm_ascend/sample/sampler.py b/vllm_ascend/sample/sampler.py index 5a73f951..68082152 100644 --- a/vllm_ascend/sample/sampler.py +++ b/vllm_ascend/sample/sampler.py @@ -1,9 +1,12 @@ import torch from vllm.model_executor.layers.batch_invariant import vllm_is_batch_invariant +from vllm.triton_utils import HAS_TRITON +from vllm.v1.sample.metadata import SamplingMetadata from vllm.v1.sample.ops.topk_topp_sampler import TopKTopPSampler from vllm.v1.sample.sampler import Sampler from vllm_ascend.ascend_config import get_ascend_config +from vllm_ascend.sample.penalties import apply_all_penalties from vllm_ascend.utils import AscendDeviceType, get_ascend_device_type, global_stream, npu_stream_switch DEFAULT_LOGPROBS_MODE = "raw_logprobs" @@ -36,6 +39,28 @@ def random_sample( class AscendSampler(Sampler): + @staticmethod + def apply_penalties( + logits: torch.Tensor, + sampling_metadata: SamplingMetadata, + output_token_ids: list[list[int]], + ) -> torch.Tensor: + """Use Triton-Ascend penalties on NPU when Triton is available; else vLLM default.""" + if not HAS_TRITON: + return Sampler.apply_penalties(logits, sampling_metadata, output_token_ids) + + if sampling_metadata.no_penalties: + return logits + assert sampling_metadata.prompt_token_ids is not None + return apply_all_penalties( + logits, + sampling_metadata.prompt_token_ids, + sampling_metadata.presence_penalties, + sampling_metadata.frequency_penalties, + sampling_metadata.repetition_penalties, + output_token_ids, + ) + def __init__(self, logprobs_mode=DEFAULT_LOGPROBS_MODE): # TODO: support logprobs_mode in vllm-ascend super().__init__(logprobs_mode=logprobs_mode)
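
### Usage sketch (reviewer note)

A minimal sketch of how the new path is exercised, mirroring the e2e test above. Illustrative only: it assumes an NPU device with Triton-Ascend installed; the shapes, penalty values, and token ids are made up.

```python
import torch

from vllm_ascend.ops.triton.triton_utils import init_device_properties_triton
from vllm_ascend.sample.penalties import apply_all_penalties

init_device_properties_triton()  # once before the first launch, as in the test

num_seqs, vocab_size = 4, 5120
logits = torch.randn(num_seqs, vocab_size, device="npu", dtype=torch.float16)
# Prompt tokens: any id >= vocab_size is treated as padding and skipped.
prompt_token_ids = torch.randint(
    0, vocab_size, (num_seqs, 8), device="npu", dtype=torch.int64
)
# Ragged per-sequence decode histories; padded internally with vocab_size.
output_token_ids = [[1, 2, 2], [3], [], [7, 7, 7, 7]]

# Penalties are applied to `logits` in place; the tensor is also returned.
logits = apply_all_penalties(
    logits,
    prompt_token_ids,
    presence_penalties=torch.full((num_seqs,), 0.1, device="npu"),
    frequency_penalties=torch.full((num_seqs,), 0.1, device="npu"),
    repetition_penalties=torch.full((num_seqs,), 1.2, device="npu"),
    output_token_ids=output_token_ids,
)
```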