# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import torch

from vllm.utils import is_pin_memory_available, make_tensor_with_pad


def get_token_bin_counts_and_mask(
    tokens: torch.Tensor,
    vocab_size: int,
    num_seqs: int,
) -> tuple[torch.Tensor, torch.Tensor]:
    # Compute the bin counts for the tokens.
    # vocab_size + 1 for padding.
    bin_counts = torch.zeros((num_seqs, vocab_size + 1),
                             dtype=torch.long,
                             device=tokens.device)
    bin_counts.scatter_add_(1, tokens, torch.ones_like(tokens))
    bin_counts = bin_counts[:, :vocab_size]
    mask = bin_counts > 0

    return bin_counts, mask
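

# A toy illustration of get_token_bin_counts_and_mask (not part of the
# module's API): with vocab_size=4 and padding value 4, pad tokens land in
# the extra bin, which is then sliced away, so padding never appears in the
# counts or the mask.
#
#   tokens = torch.tensor([[2, 2, 4], [0, 3, 4]])
#   counts, mask = get_token_bin_counts_and_mask(tokens, vocab_size=4,
#                                                num_seqs=2)
#   counts -> [[0, 0, 2, 0],    mask -> [[False, False, True, False],
#              [1, 0, 0, 1]]             [True, False, False, True]]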


def apply_penalties(logits: torch.Tensor, prompt_tokens_tensor: torch.Tensor,
                    output_tokens_tensor: torch.Tensor,
                    presence_penalties: torch.Tensor,
                    frequency_penalties: torch.Tensor,
                    repetition_penalties: torch.Tensor) -> torch.Tensor:
    """
    Applies penalties in place to the logits tensor.

    logits: The input logits tensor of shape [num_seqs, vocab_size].
    prompt_tokens_tensor: A tensor containing the prompt tokens. The prompts
        are padded to the maximum prompt length within the batch using
        `vocab_size` as the padding value. The value `vocab_size` is used
        for padding because it does not correspond to any valid token ID
        in the vocabulary.
    output_tokens_tensor: The output tokens tensor.
    presence_penalties: The presence penalties of shape (num_seqs, ).
    frequency_penalties: The frequency penalties of shape (num_seqs, ).
    repetition_penalties: The repetition penalties of shape (num_seqs, ).
    """
    num_seqs, vocab_size = logits.shape
    _, prompt_mask = get_token_bin_counts_and_mask(prompt_tokens_tensor,
                                                   vocab_size, num_seqs)
    output_bin_counts, output_mask = get_token_bin_counts_and_mask(
        output_tokens_tensor, vocab_size, num_seqs)

    # Apply repetition penalties as a custom op.
    from vllm._custom_ops import apply_repetition_penalties_torch
    apply_repetition_penalties_torch(logits, prompt_mask, output_mask,
                                     repetition_penalties)

    # We follow the definition in the OpenAI API.
    # Refer to https://platform.openai.com/docs/api-reference/parameter-details
    logits -= frequency_penalties.unsqueeze(dim=1) * output_bin_counts
    logits -= presence_penalties.unsqueeze(dim=1) * output_mask
    return logits
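

# For reference, under the standard repetition-penalty definition (divide
# positive logits and multiply negative ones for any token seen in the
# prompt or output), the op called above behaves like this plain-PyTorch
# sketch (an illustration, not the op's actual implementation):
#
#   seen = prompt_mask | output_mask
#   penalties = torch.where(seen, repetition_penalties.unsqueeze(1), 1.0)
#   logits[:] = torch.where(logits > 0, logits / penalties,
#                           logits * penalties)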


def apply_all_penalties(
    logits: torch.Tensor,
    prompt_token_ids: torch.Tensor,
    presence_penalties: torch.Tensor,
    frequency_penalties: torch.Tensor,
    repetition_penalties: torch.Tensor,
    output_token_ids: list[list[int]],
) -> torch.Tensor:
    """
    Applies presence, frequency and repetition penalties to the logits.
    """
    _, vocab_size = logits.shape
    output_tokens_t = _convert_to_tensors(output_token_ids, vocab_size,
                                          logits.device)
    return apply_penalties(logits, prompt_token_ids, output_tokens_t,
                           presence_penalties, frequency_penalties,
                           repetition_penalties)
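

# A hypothetical end-to-end call with toy shapes (not vLLM's actual sampler
# wiring); presence/frequency penalties of 0.0 and a repetition penalty of
# 1.0 leave the second sequence's logits unchanged:
#
#   logits = torch.randn(2, 4)
#   logits = apply_all_penalties(
#       logits,
#       prompt_token_ids=torch.tensor([[0, 1, 4], [2, 4, 4]]),  # 4 == pad
#       presence_penalties=torch.tensor([0.5, 0.0]),
#       frequency_penalties=torch.tensor([0.2, 0.0]),
#       repetition_penalties=torch.tensor([1.2, 1.0]),
#       output_token_ids=[[1, 1], [3]],
#   )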


def _convert_to_tensors(output_token_ids: list[list[int]], vocab_size: int,
                        device: torch.device) -> torch.Tensor:
    """
    Convert the different list data structures to tensors.
    """
    output_tokens_tensor = make_tensor_with_pad(
        output_token_ids,
        # Use the value of vocab_size as a pad since we don't have a
        # token_id of this value.
        pad=vocab_size,
        device="cpu",
        dtype=torch.int64,
        pin_memory=is_pin_memory_available(),
    )
    return output_tokens_tensor.to(device, non_blocking=True)
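

# For illustration (assuming make_tensor_with_pad right-pads each row with
# `pad` to the batch's maximum length, as its name suggests), with
# vocab_size=4:
#
#   _convert_to_tensors([[1, 1], [3]], 4, torch.device("cpu"))
#   -> tensor([[1, 1],
#              [3, 4]])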