### What this PR does / why we need it?
Implement `get_token_bin_counts_and_mask` and `apply_penalties` with Triton-Ascend kernels. This significantly reduces the latency of the sampling process when repetition/frequency/presence penalties are enabled.

Cherry-picked from main PR #7569.

### Does this PR introduce _any_ user-facing change?
No.

### How was this patch tested?
CI passed.

Signed-off-by: linfeng-yuan <1102311262@qq.com>
Co-authored-by: realliujiaxu <realliujiaxu@163.com>
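For context, the penalty math implemented by the kernel in this file is equivalent to the eager-PyTorch sketch below. The function name `apply_penalties_reference` and the explicit mask/count arguments are illustrative only; the real entry point takes raw token tensors and derives the masks and counts via `get_token_bin_counts_and_mask`.

```python
import torch

def apply_penalties_reference(logits, prompt_mask, output_mask, output_bin_counts,
                              repetition_penalties, frequency_penalties,
                              presence_penalties):
    # logits/masks/counts: [num_seqs, vocab_size]; penalty tensors: [num_seqs].
    rep = repetition_penalties.unsqueeze(1)
    scaling = torch.where(prompt_mask | output_mask, rep, torch.ones_like(rep))
    # Repetition penalty: divide positive logits, multiply non-positive ones.
    logits = torch.where(logits > 0, logits / scaling, logits * scaling)
    # Frequency penalty scales with how often a token was generated;
    # presence penalty applies once per generated token.
    logits -= frequency_penalties.unsqueeze(1) * output_bin_counts
    logits -= presence_penalties.unsqueeze(1) * output_mask.float()
    return logits
```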
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-ascend project.
#
# Triton-Ascend implementation of apply_penalties.
# Migrated from model_executor/layers/utils.apply_penalties.
# Reference: https://github.com/vllm-project/vllm-ascend/pull/6979

import torch

from vllm.triton_utils import tl, triton

from vllm_ascend.ops.triton.bincount import get_token_bin_counts_and_mask_triton
from vllm_ascend.ops.triton.triton_utils import get_vectorcore_num


@triton.jit
def apply_all_penalties_kernel(
    logits_ptr,
    prompt_mask_ptr,
    output_mask_ptr,
    output_bin_counts_ptr,
    repetition_penalties_ptr,
    frequency_penalties_ptr,
    presence_penalties_ptr,
    num_seqs,
    vocab_size,
    stride_logits_seq,
    stride_logits_vocab,
    stride_prompt_mask_seq,
    stride_prompt_mask_vocab,
    stride_output_mask_seq,
    stride_output_mask_vocab,
    stride_bin_counts_seq,
    stride_bin_counts_vocab,
    BLOCK_SIZE: tl.constexpr,
):
    """Apply repetition, frequency, and presence penalties to logits in place."""
    # Partition the sequences evenly across the launched programs.
    pid = tl.program_id(axis=0)
    num_programs = tl.num_programs(axis=0)
    seqs_per_program = (num_seqs + num_programs - 1) // num_programs

    start_seq = pid * seqs_per_program
    end_seq = tl.minimum(start_seq + seqs_per_program, num_seqs)

    for seq_idx in range(start_seq, end_seq):
        repetition_penalty = tl.load(repetition_penalties_ptr + seq_idx)
        frequency_penalty = tl.load(frequency_penalties_ptr + seq_idx)
        presence_penalty = tl.load(presence_penalties_ptr + seq_idx)

        # Walk the vocabulary dimension in BLOCK_SIZE tiles.
        for vocab_start in range(0, vocab_size, BLOCK_SIZE):
            vocab_offsets = vocab_start + tl.arange(0, BLOCK_SIZE)
            mask = vocab_offsets < vocab_size

            logits_offset = seq_idx * stride_logits_seq + vocab_offsets * stride_logits_vocab
            prompt_mask_offset = seq_idx * stride_prompt_mask_seq + vocab_offsets * stride_prompt_mask_vocab
            output_mask_offset = seq_idx * stride_output_mask_seq + vocab_offsets * stride_output_mask_vocab
            counts_offset = seq_idx * stride_bin_counts_seq + vocab_offsets * stride_bin_counts_vocab

            logits = tl.load(logits_ptr + logits_offset, mask=mask, other=0.0)
            prompt_mask_val = tl.load(prompt_mask_ptr + prompt_mask_offset, mask=mask, other=False)
            output_mask_val = tl.load(output_mask_ptr + output_mask_offset, mask=mask, other=False)
            output_bin_counts = tl.load(
                output_bin_counts_ptr + counts_offset,
                mask=mask,
                other=0,
            ).to(tl.float32)

            # Repetition penalty: for tokens seen in the prompt or the output,
            # divide positive logits by the penalty and multiply non-positive
            # logits by it.
            need_repetition_penalty = (prompt_mask_val | output_mask_val).to(tl.int1)
            penalty_factor = tl.where(need_repetition_penalty, repetition_penalty, 1.0)
            scaling = tl.where(
                (logits > 0.0).to(tl.int1),
                1.0 / penalty_factor,
                penalty_factor,
            )
            updated = logits * scaling

            # Frequency penalty scales with how often a token was generated;
            # presence penalty applies once per generated token.
            updated -= frequency_penalty * output_bin_counts
            updated -= presence_penalty * output_mask_val.to(tl.float32)
            tl.store(logits_ptr + logits_offset, updated, mask=mask)


def apply_penalties_triton(
    logits: torch.Tensor,
    prompt_tokens_tensor: torch.Tensor,
    output_tokens_tensor: torch.Tensor,
    presence_penalties: torch.Tensor,
    frequency_penalties: torch.Tensor,
    repetition_penalties: torch.Tensor,
) -> torch.Tensor:
    """Apply penalties to logits in place.

    Same interface as model_executor.layers.utils.apply_penalties.
    """
    num_seqs, vocab_size = logits.shape
    _, prompt_mask = get_token_bin_counts_and_mask_triton(prompt_tokens_tensor, vocab_size, num_seqs)
    output_bin_counts, output_mask = get_token_bin_counts_and_mask_triton(
        output_tokens_tensor,
        vocab_size,
        num_seqs,
    )
    _apply_all_penalties_triton(
        logits,
        prompt_mask,
        output_mask,
        output_bin_counts,
        repetition_penalties,
        frequency_penalties,
        presence_penalties,
    )
    return logits


def _apply_all_penalties_triton(
    logits: torch.Tensor,
    prompt_mask: torch.Tensor,
    output_mask: torch.Tensor,
    output_bin_counts: torch.Tensor,
    repetition_penalties: torch.Tensor,
    frequency_penalties: torch.Tensor,
    presence_penalties: torch.Tensor,
) -> None:
    """Apply all penalties given precomputed bin counts and masks."""
    num_seqs, vocab_size = logits.shape
    # One program per sequence, capped at the number of available vector cores.
    grid = (min(num_seqs, get_vectorcore_num()), 1, 1)

    apply_all_penalties_kernel[grid](
        logits,
        prompt_mask,
        output_mask,
        output_bin_counts,
        repetition_penalties,
        frequency_penalties,
        presence_penalties,
        num_seqs=num_seqs,
        vocab_size=vocab_size,
        stride_logits_seq=logits.stride(0),
        stride_logits_vocab=logits.stride(1),
        stride_prompt_mask_seq=prompt_mask.stride(0),
        stride_prompt_mask_vocab=prompt_mask.stride(1),
        stride_output_mask_seq=output_mask.stride(0),
        stride_output_mask_vocab=output_mask.stride(1),
        stride_bin_counts_seq=output_bin_counts.stride(0),
        stride_bin_counts_vocab=output_bin_counts.stride(1),
        BLOCK_SIZE=2048,
    )
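A hypothetical call to `apply_penalties_triton` might look like the sketch below. The import path, device string, tensor shapes, and the vocab-size padding convention are assumptions rather than details from this file; running it requires an Ascend NPU with torch_npu and the Triton-Ascend backend available.

```python
import torch

# Hypothetical import path; adjust to wherever this file lives in vllm-ascend.
from vllm_ascend.ops.triton.penalties import apply_penalties_triton

num_seqs, vocab_size = 4, 32000

logits = torch.randn(num_seqs, vocab_size, device="npu")

# Per-sequence token id tensors, padded to equal length with vocab_size
# (assumed padding convention of the upstream bin-count helper).
prompt_tokens = torch.randint(0, vocab_size, (num_seqs, 16), device="npu")
output_tokens = torch.randint(0, vocab_size, (num_seqs, 8), device="npu")

presence = torch.full((num_seqs,), 0.3, device="npu")
frequency = torch.full((num_seqs,), 0.5, device="npu")
repetition = torch.full((num_seqs,), 1.2, device="npu")

# Penalizes the logits in place and also returns them.
penalized = apply_penalties_triton(logits, prompt_tokens, output_tokens,
                                   presence, frequency, repetition)
```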