################################################################################
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
################################################################################

import torch
import vllm
from vllm.model_executor.layers.utils import get_token_bin_counts_and_mask


def apply_penalties_fit(logits: torch.Tensor,
                        prompt_tokens_tensor: torch.Tensor,
                        output_tokens_tensor: torch.Tensor,
                        presence_penalties: torch.Tensor,
                        frequency_penalties: torch.Tensor,
                        repetition_penalties: torch.Tensor) -> torch.Tensor:
    """
    Applies penalties in place to the logits tensor
    logits : The input logits tensor of shape [num_seqs, vocab_size]
    prompt_tokens_tensor: A tensor containing the prompt tokens. The prompts
        are padded to the maximum prompt length within the batch using
        `vocab_size` as the padding value. The value `vocab_size` is used
        for padding because it does not correspond to any valid token ID
        in the vocabulary.
    output_tokens_tensor: The output tokens tensor.
    presence_penalties: The presence penalties of shape (num_seqs, )
    frequency_penalties: The frequency penalties of shape (num_seqs, )
    repetition_penalties: The repetition penalties of shape (num_seqs, )
    """
    num_seqs, vocab_size = logits.shape
    _, prompt_mask = get_token_bin_counts_and_mask(prompt_tokens_tensor,
                                                   vocab_size, num_seqs)
    output_bin_counts, output_mask = get_token_bin_counts_and_mask(
        output_tokens_tensor, vocab_size, num_seqs)

    # Shape (num_seqs, 1): rely on torch.where's broadcasting against the
    # (num_seqs, vocab_size) mask instead of materializing a full repeated
    # copy with .repeat(1, vocab_size) — same result, no extra allocation.
    repetition_penalties = repetition_penalties.unsqueeze(dim=1)

    # If token appears in prompt or output, apply, otherwise use 1.0 for no-op.
    penalties = torch.where(prompt_mask | output_mask, repetition_penalties,
                            1.0)

    # If logits are positive, divide by penalty, otherwise multiply by penalty.
    scaling = torch.where(logits > 0, 1.0 / penalties, penalties)
    logits *= scaling

    # We follow the definition in OpenAI API.
    # Refer to https://platform.openai.com/docs/api-reference/parameter-details
    logits -= frequency_penalties.unsqueeze(dim=1) * output_bin_counts
    logits -= presence_penalties.unsqueeze(dim=1) * output_mask
    return logits


# Monkey-patch: replace vllm's stock apply_penalties with the local variant
# defined in this module, so existing call sites pick it up unchanged.
vllm.model_executor.layers.utils.apply_penalties = apply_penalties_fit