Sync from v0.13
This commit is contained in:
121
vllm/logits_process.py
Normal file
121
vllm/logits_process.py
Normal file
@@ -0,0 +1,121 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from collections.abc import Callable, Sequence
|
||||
from typing import TypeAlias
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.tokenizers import TokenizerLike
|
||||
|
||||
LogitsProcessor: TypeAlias = (
|
||||
Callable[[list[int], torch.Tensor], torch.Tensor]
|
||||
| Callable[[list[int], list[int], torch.Tensor], torch.Tensor]
|
||||
)
|
||||
"""LogitsProcessor is a function that takes a list
|
||||
of previously generated tokens, the logits tensor
|
||||
for the next token and, optionally, prompt tokens as a
|
||||
first argument, and returns a modified tensor of logits
|
||||
to sample from."""
|
||||
|
||||
|
||||
def get_bad_words_logits_processors(
|
||||
bad_words: list[str], tokenizer: TokenizerLike
|
||||
) -> list[LogitsProcessor]:
|
||||
bad_words_ids: list[list[int]] = list()
|
||||
|
||||
for bad_word in bad_words:
|
||||
# To prohibit words both at the beginning
|
||||
# and in the middle of text
|
||||
# (related to add_prefix_space tokenizer parameter)
|
||||
for add_prefix_space in [False, True]:
|
||||
prefix = " " if add_prefix_space else ""
|
||||
prompt = prefix + bad_word.lstrip()
|
||||
|
||||
prompt_token_ids = tokenizer.encode(text=prompt, add_special_tokens=False)
|
||||
|
||||
# If no space at the beginning
|
||||
# or if prefix space produces a new word token
|
||||
if (not add_prefix_space) or (
|
||||
add_prefix_space
|
||||
and prompt_token_ids[0] != bad_words_ids[-1][0]
|
||||
and len(prompt_token_ids) == len(bad_words_ids[-1])
|
||||
):
|
||||
bad_words_ids.append(prompt_token_ids)
|
||||
|
||||
return [NoBadWordsLogitsProcessor(bad_words_ids=bad_words_ids)]
|
||||
|
||||
|
||||
class NoBadWordsLogitsProcessor:
|
||||
_SMALLEST_LOGIT = float("-inf")
|
||||
_NEUTRAL_LOGIT = 0.0
|
||||
|
||||
def __init__(self, bad_words_ids: list[list[int]]):
|
||||
self.bad_words_ids = bad_words_ids
|
||||
self.word_bias: torch.FloatTensor = None
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
past_tokens_ids: Sequence[int],
|
||||
logits: torch.FloatTensor,
|
||||
) -> torch.Tensor:
|
||||
if self.word_bias is None:
|
||||
self._init_word_bias(logits=logits)
|
||||
|
||||
last_token_bias = torch.zeros_like(logits)
|
||||
|
||||
for bad_word_ids in self.bad_words_ids:
|
||||
if len(bad_word_ids) == 1: # 1-token words already processed
|
||||
continue
|
||||
|
||||
if len(bad_word_ids) > len(past_tokens_ids) + 1:
|
||||
continue
|
||||
|
||||
prefix_length = len(bad_word_ids) - 1
|
||||
last_token_id = bad_word_ids[-1]
|
||||
actual_prefix = past_tokens_ids[-prefix_length:]
|
||||
expected_prefix = bad_word_ids[:prefix_length]
|
||||
|
||||
assert len(actual_prefix) == len(expected_prefix)
|
||||
|
||||
is_match = tuple(actual_prefix) == tuple(expected_prefix)
|
||||
last_token_bias[last_token_id] += (
|
||||
self._SMALLEST_LOGIT if is_match else self._NEUTRAL_LOGIT
|
||||
)
|
||||
|
||||
logits = logits + self.word_bias + last_token_bias
|
||||
|
||||
return logits
|
||||
|
||||
def _init_word_bias(self, logits: torch.FloatTensor) -> None:
|
||||
# Code based on NoBadWordsLogitsProcessor and SequenceBiasLogitsProcessor # noqa: E501
|
||||
# from https://github.com/huggingface/transformers/blob/main/src/transformers/generation/logits_process.py
|
||||
|
||||
vocab_size = logits.shape[-1]
|
||||
|
||||
self._check_token_ids_bounds(vocab_size=vocab_size)
|
||||
|
||||
self.word_bias = torch.zeros(
|
||||
(vocab_size,), dtype=torch.float, device=logits.device
|
||||
)
|
||||
|
||||
for bad_word_ids in self.bad_words_ids:
|
||||
if len(bad_word_ids) == 1:
|
||||
bad_word_id = bad_word_ids[-1]
|
||||
self.word_bias[bad_word_id] = self._SMALLEST_LOGIT
|
||||
|
||||
def _check_token_ids_bounds(self, vocab_size: int) -> None:
|
||||
invalid_token_ids = []
|
||||
|
||||
for bad_word_ids in self.bad_words_ids:
|
||||
for token_id in bad_word_ids:
|
||||
if token_id < 0 or token_id >= vocab_size:
|
||||
invalid_token_ids.append(token_id)
|
||||
|
||||
if len(invalid_token_ids) > 0:
|
||||
raise ValueError(
|
||||
f"The model vocabulary size is {vocab_size},"
|
||||
f" but the following tokens"
|
||||
f" were specified as bad: {invalid_token_ids}."
|
||||
f" All token id values should be integers satisfying:"
|
||||
f" 0 <= token_id < {vocab_size}."
|
||||
)
|
||||
Reference in New Issue
Block a user