Sync from v0.13
@@ -1,320 +1,560 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 """Sampling parameters for text generation."""

-import copy
-from enum import IntEnum
-from functools import cached_property
-from typing import Any, Callable, Dict, List, Optional, Union
-
-import torch
-from pydantic import Field
-from typing_extensions import Annotated
+import copy
+from dataclasses import field
+from enum import Enum, IntEnum
+from functools import cached_property
+from typing import Annotated, Any
+
+import msgspec
+from pydantic.dataclasses import dataclass
+
+from vllm.logger import init_logger
+from vllm.logits_process import LogitsProcessor
+from vllm.tokenizers import TokenizerLike
+from vllm.v1.serial_utils import PydanticMsgspecMixin
+
+logger = init_logger(__name__)

 _SAMPLING_EPS = 1e-5
+_MAX_TEMP = 1e-2


 class SamplingType(IntEnum):
     GREEDY = 0
     RANDOM = 1
     RANDOM_SEED = 2
-    BEAM = 3


-LogitsProcessor = Callable[[List[int], torch.Tensor], torch.Tensor]
-"""LogitsProcessor is a function that takes a list of previously generated
-tokens and a tensor of the logits for the next token, and returns a modified
-tensor of logits to sample from."""
+# maybe make msgspec?
+@dataclass
+class StructuredOutputsParams:
+    # One of these fields will be used to build a logit processor.
+    json: str | dict | None = None
+    regex: str | None = None
+    choice: list[str] | None = None
+    grammar: str | None = None
+    json_object: bool | None = None
+    # These are other options that can be set.
+    disable_fallback: bool = False
+    disable_any_whitespace: bool = False
+    disable_additional_properties: bool = False
+    whitespace_pattern: str | None = None
+    structural_tag: str | None = None
+
+    _backend: str | None = field(default=None, init=False)
+    """CAUTION: Should only be set by Processor._validate_structured_output"""
+    _backend_was_auto: bool = field(default=False, init=False)
+    """CAUTION: Should only be set by Processor._validate_structured_output"""
+
+    def __post_init__(self):
+        """Validate that the constraint fields are mutually exclusive."""
+        count = sum(
+            [
+                self.json is not None,
+                self.regex is not None,
+                self.choice is not None,
+                self.grammar is not None,
+                self.json_object is not None,
+                self.structural_tag is not None,
+            ]
+        )
+        if count > 1:
+            raise ValueError(
+                "You can only use one kind of structured outputs constraint "
+                f"but multiple are specified: {self.__dict__}"
+            )
+
+    def all_constraints_none(self) -> bool:
+        """
+        Returns True if all structured-output constraint fields are None.
+        """
+        return all(
+            getattr(self, field) is None
+            for field in (
+                "json",
+                "regex",
+                "choice",
+                "grammar",
+                "json_object",
+                "structural_tag",
+            )
+        )
+
+    def all_non_structural_tag_constraints_none(self) -> bool:
+        """
+        Returns True if all structured-output constraint fields other than
+        structural_tag are None.
+        """
+        return all(
+            getattr(self, field) is None
+            for field in (
+                "json",
+                "regex",
+                "choice",
+                "grammar",
+                "json_object",
+            )
+        )

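The mutual-exclusion check in `__post_init__` means exactly one constraint kind may be set per request. A minimal sketch of how this surfaces to callers, assuming vLLM at this revision, where the class lives in `vllm.sampling_params`:

```python
# Sketch: exercising StructuredOutputsParams' mutual-exclusion check.
from vllm.sampling_params import StructuredOutputsParams

# One constraint kind is fine: constrain output to a JSON schema.
params = StructuredOutputsParams(
    json={"type": "object", "properties": {"name": {"type": "string"}}},
)
assert params.all_constraints_none() is False

# Two constraint kinds trigger the __post_init__ ValueError.
try:
    StructuredOutputsParams(regex=r"\d+", choice=["yes", "no"])
except ValueError as e:
    print(e)  # "You can only use one kind of structured outputs constraint ..."
```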
-class SamplingParams:
+class RequestOutputKind(Enum):
+    # Return entire output so far in every RequestOutput
+    CUMULATIVE = 0
+    # Return only deltas in each RequestOutput
+    DELTA = 1
+    # Do not return intermediate RequestOutput
+    FINAL_ONLY = 2
+
+
+class SamplingParams(
+    PydanticMsgspecMixin,
+    msgspec.Struct,
+    omit_defaults=True,  # type: ignore[call-arg]
+    # required for @cached_property.
+    dict=True,
+):  # type: ignore[call-arg]
"""Sampling parameters for text generation.
|
||||
|
||||
Overall, we follow the sampling parameters from the OpenAI text completion
|
||||
API (https://platform.openai.com/docs/api-reference/completions/create).
|
||||
In addition, we support beam search, which is not supported by OpenAI.
|
||||
|
||||
Args:
|
||||
n: Number of output sequences to return for the given prompt.
|
||||
best_of: Number of output sequences that are generated from the prompt.
|
||||
From these `best_of` sequences, the top `n` sequences are returned.
|
||||
`best_of` must be greater than or equal to `n`. This is treated as
|
||||
the beam width when `use_beam_search` is True. By default, `best_of`
|
||||
is set to `n`.
|
||||
presence_penalty: Float that penalizes new tokens based on whether they
|
||||
appear in the generated text so far. Values > 0 encourage the model
|
||||
to use new tokens, while values < 0 encourage the model to repeat
|
||||
tokens.
|
||||
frequency_penalty: Float that penalizes new tokens based on their
|
||||
frequency in the generated text so far. Values > 0 encourage the
|
||||
model to use new tokens, while values < 0 encourage the model to
|
||||
repeat tokens.
|
||||
repetition_penalty: Float that penalizes new tokens based on whether
|
||||
they appear in the prompt and the generated text so far. Values > 1
|
||||
encourage the model to use new tokens, while values < 1 encourage
|
||||
the model to repeat tokens.
|
||||
temperature: Float that controls the randomness of the sampling. Lower
|
||||
values make the model more deterministic, while higher values make
|
||||
the model more random. Zero means greedy sampling.
|
||||
top_p: Float that controls the cumulative probability of the top tokens
|
||||
to consider. Must be in (0, 1]. Set to 1 to consider all tokens.
|
||||
top_k: Integer that controls the number of top tokens to consider. Set
|
||||
to -1 to consider all tokens.
|
||||
min_p: Float that represents the minimum probability for a token to be
|
||||
considered, relative to the probability of the most likely token.
|
||||
Must be in [0, 1]. Set to 0 to disable this.
|
||||
seed: Random seed to use for the generation.
|
||||
use_beam_search: Whether to use beam search instead of sampling.
|
||||
length_penalty: Float that penalizes sequences based on their length.
|
||||
Used in beam search.
|
||||
early_stopping: Controls the stopping condition for beam search. It
|
||||
accepts the following values: `True`, where the generation stops as
|
||||
soon as there are `best_of` complete candidates; `False`, where an
|
||||
heuristic is applied and the generation stops when is it very
|
||||
unlikely to find better candidates; `"never"`, where the beam search
|
||||
procedure only stops when there cannot be better candidates
|
||||
(canonical beam search algorithm).
|
||||
stop: List of strings that stop the generation when they are generated.
|
||||
The returned output will not contain the stop strings.
|
||||
stop_token_ids: List of tokens that stop the generation when they are
|
||||
generated. The returned output will contain the stop tokens unless
|
||||
the stop tokens are special tokens.
|
||||
include_stop_str_in_output: Whether to include the stop strings in
|
||||
output text. Defaults to False.
|
||||
ignore_eos: Whether to ignore the EOS token and continue generating
|
||||
tokens after the EOS token is generated.
|
||||
max_tokens: Maximum number of tokens to generate per output sequence.
|
||||
min_tokens: Minimum number of tokens to generate per output sequence
|
||||
before EOS or stop_token_ids can be generated
|
||||
logprobs: Number of log probabilities to return per output token.
|
||||
Note that the implementation follows the OpenAI API: The return
|
||||
result includes the log probabilities on the `logprobs` most likely
|
||||
tokens, as well the chosen tokens. The API will always return the
|
||||
log probability of the sampled token, so there may be up to
|
||||
`logprobs+1` elements in the response.
|
||||
prompt_logprobs: Number of log probabilities to return per prompt token.
|
||||
detokenize: Whether to detokenize the output. Defaults to True.
|
||||
skip_special_tokens: Whether to skip special tokens in the output.
|
||||
spaces_between_special_tokens: Whether to add spaces between special
|
||||
tokens in the output. Defaults to True.
|
||||
logits_processors: List of functions that modify logits based on
|
||||
previously generated tokens.
|
||||
truncate_prompt_tokens: If set to an integer k, will use only the last k
|
||||
tokens from the prompt (i.e., left truncation). Defaults to None
|
||||
(i.e., no truncation).
|
||||
"""
|
||||
|
||||
-    def __init__(
-        self,
-        n: int = 1,
-        best_of: Optional[int] = None,
-        presence_penalty: float = 0.0,
-        frequency_penalty: float = 0.0,
-        repetition_penalty: float = 1.0,
-        temperature: float = 1.0,
-        top_p: float = 1.0,
-        top_k: int = -1,
+    n: int = 1
+    """Number of outputs to return for the given prompt request.
+
+    NOTE:
+        `AsyncLLM` streams outputs by default. When `n > 1`, all `n` outputs
+        are generated and streamed cumulatively per request. To see all `n`
+        outputs upon completion, use `output_kind=RequestOutputKind.FINAL_ONLY`
+        in `SamplingParams`."""
+    presence_penalty: float = 0.0
+    """Penalizes new tokens based on whether they appear in the generated text
+    so far. Values > 0 encourage the model to use new tokens, while values < 0
+    encourage the model to repeat tokens."""
+    frequency_penalty: float = 0.0
+    """Penalizes new tokens based on their frequency in the generated text so
+    far. Values > 0 encourage the model to use new tokens, while values < 0
+    encourage the model to repeat tokens."""
+    repetition_penalty: float = 1.0
+    """Penalizes new tokens based on whether they appear in the prompt and the
+    generated text so far. Values > 1 encourage the model to use new tokens,
+    while values < 1 encourage the model to repeat tokens."""
+    temperature: float = 1.0
+    """Controls the randomness of the sampling. Lower values make the model
+    more deterministic, while higher values make the model more random. Zero
+    means greedy sampling."""
+    top_p: float = 1.0
+    """Controls the cumulative probability of the top tokens to consider. Must
+    be in (0, 1]. Set to 1 to consider all tokens."""
+    top_k: int = 0
+    """Controls the number of top tokens to consider. Set to 0 (or -1) to
+    consider all tokens."""
+    min_p: float = 0.0
+    """Represents the minimum probability for a token to be considered,
+    relative to the probability of the most likely token. Must be in [0, 1].
+    Set to 0 to disable this."""
+    seed: int | None = None
+    """Random seed to use for the generation."""
+    stop: str | list[str] | None = None
+    """String(s) that stop the generation when they are generated. The returned
+    output will not contain the stop strings."""
+    stop_token_ids: list[int] | None = None
+    """Token IDs that stop the generation when they are generated. The returned
+    output will contain the stop tokens unless the stop tokens are special
+    tokens."""
+    ignore_eos: bool = False
+    """Whether to ignore the EOS token and continue generating
+    tokens after the EOS token is generated."""
+    max_tokens: int | None = 16
+    """Maximum number of tokens to generate per output sequence."""
+    min_tokens: int = 0
+    """Minimum number of tokens to generate per output sequence before EOS or
+    `stop_token_ids` can be generated."""
+    logprobs: int | None = None
+    """Number of log probabilities to return per output token. When set to
+    `None`, no probability is returned. If set to a non-`None` value, the
+    result includes the log probabilities of the specified number of most
+    likely tokens, as well as the chosen tokens. Note that the implementation
+    follows the OpenAI API: The API will always return the log probability of
+    the sampled token, so there may be up to `logprobs+1` elements in the
+    response. When set to -1, return all `vocab_size` log probabilities."""
+    prompt_logprobs: int | None = None
+    """Number of log probabilities to return per prompt token.
+    When set to -1, return all `vocab_size` log probabilities."""
+    flat_logprobs: bool = False
+    """Whether to return logprobs in a flattened format (i.e. `FlatLogprobs`)
+    for better performance.
+    NOTE: The GC cost of `FlatLogprobs` is significantly smaller than that of
+    `list[dict[int, Logprob]]`. Once enabled, `PromptLogprobs` and
+    `SampleLogprobs` are populated as `FlatLogprobs`."""
+    # NOTE: This parameter is only exposed at the engine level for now.
+    # It is not exposed in the OpenAI API server, as the OpenAI API does
+    # not support returning only a list of token IDs.
+    detokenize: bool = True
+    """Whether to detokenize the output."""
+    skip_special_tokens: bool = True
+    """Whether to skip special tokens in the output."""
+    spaces_between_special_tokens: bool = True
+    """Whether to add spaces between special tokens in the output."""
+    # We use Any here because the `list[LogitsProcessor] | None` type is
+    # not supported by msgspec.
+    logits_processors: Any | None = None
+    """Functions that modify logits based on previously generated tokens, and
+    optionally prompt tokens as a first argument."""
+    include_stop_str_in_output: bool = False
+    """Whether to include the stop strings in output text."""
+    truncate_prompt_tokens: Annotated[int, msgspec.Meta(ge=-1)] | None = None
+    """If set to -1, will use the truncation size supported by the model. If
+    set to an integer k, will use only the last k tokens from the prompt
+    (i.e., left truncation). If set to `None`, truncation is disabled."""
+    output_kind: RequestOutputKind = RequestOutputKind.CUMULATIVE
+
+    # The below fields are not supposed to be used as an input.
+    # They are set in post_init.
+    output_text_buffer_length: int = 0
+    _all_stop_token_ids: set[int] = msgspec.field(default_factory=set)
+
+    # Fields used to construct logits processors
+    structured_outputs: StructuredOutputsParams | None = None
+    """Parameters for configuring structured outputs."""
+    logit_bias: dict[int, float] | None = None
+    """If provided, the engine will construct a logits processor that applies
+    these logit biases."""
+    allowed_token_ids: list[int] | None = None
+    """If provided, the engine will construct a logits processor which only
+    retains scores for the given token ids."""
+    extra_args: dict[str, Any] | None = None
+    """Arbitrary additional args, that can be used by custom sampling
+    implementations, plugins, etc. Not used by any in-tree sampling
+    implementations."""
+
+    # Fields used for bad words
+    bad_words: list[str] | None = None
+    """Words that are not allowed to be generated. More precisely, only the
+    last token of a corresponding token sequence is not allowed when the next
+    generated token can complete the sequence."""
+    _bad_words_token_ids: list[list[int]] | None = None
+
+    skip_reading_prefix_cache: bool | None = None

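With the move from an `__init__`-heavy class to typed msgspec fields, construction is plain keyword assignment, and the streaming note on the `n` field above matters in practice. A hedged sketch; the `LLM` entrypoint and model name here are illustrative assumptions, not part of this diff:

```python
# Sketch: constructing the msgspec-based SamplingParams.
from vllm import LLM, SamplingParams
from vllm.sampling_params import RequestOutputKind

# Plain sampling request: nucleus sampling with a seeded RNG.
params = SamplingParams(temperature=0.8, top_p=0.95, seed=42, max_tokens=64)

# n > 1 with FINAL_ONLY suppresses intermediate streamed outputs, per the
# NOTE on the `n` field above.
multi = SamplingParams(n=4, output_kind=RequestOutputKind.FINAL_ONLY)

llm = LLM(model="facebook/opt-125m")  # illustrative model choice
outputs = llm.generate(["The capital of France is"], params)
print(outputs[0].outputs[0].text)
```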
+    @staticmethod
+    def from_optional(
+        n: int | None = 1,
+        presence_penalty: float | None = 0.0,
+        frequency_penalty: float | None = 0.0,
+        repetition_penalty: float | None = 1.0,
+        temperature: float | None = 1.0,
+        top_p: float | None = 1.0,
+        top_k: int = 0,
         min_p: float = 0.0,
-        seed: Optional[int] = None,
-        use_beam_search: bool = False,
-        length_penalty: float = 1.0,
-        early_stopping: Union[bool, str] = False,
-        stop: Optional[Union[str, List[str]]] = None,
-        stop_token_ids: Optional[List[int]] = None,
+        seed: int | None = None,
+        stop: str | list[str] | None = None,
+        stop_token_ids: list[int] | None = None,
+        bad_words: list[str] | None = None,
+        include_stop_str_in_output: bool = False,
         ignore_eos: bool = False,
-        max_tokens: Optional[int] = 16,
+        max_tokens: int | None = 16,
         min_tokens: int = 0,
-        logprobs: Optional[int] = None,
-        prompt_logprobs: Optional[int] = None,
+        logprobs: int | None = None,
+        prompt_logprobs: int | None = None,
         detokenize: bool = True,
         skip_special_tokens: bool = True,
         spaces_between_special_tokens: bool = True,
-        logits_processors: Optional[List[LogitsProcessor]] = None,
-        truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None,
-    ) -> None:
-        self.n = n
-        self.best_of = best_of if best_of is not None else n
-        self.presence_penalty = presence_penalty
-        self.frequency_penalty = frequency_penalty
-        self.repetition_penalty = repetition_penalty
-        self.temperature = temperature
-        self.top_p = top_p
-        self.top_k = top_k
-        self.min_p = min_p
-        if seed == -1:
+        logits_processors: list[LogitsProcessor] | None = None,
+        truncate_prompt_tokens: Annotated[int, msgspec.Meta(ge=-1)] | None = None,
+        output_kind: RequestOutputKind = RequestOutputKind.CUMULATIVE,
+        structured_outputs: StructuredOutputsParams | None = None,
+        logit_bias: dict[int, float] | dict[str, float] | None = None,
+        allowed_token_ids: list[int] | None = None,
+        extra_args: dict[str, Any] | None = None,
+    ) -> "SamplingParams":
+        if logit_bias is not None:
+            # Convert token_id keys to integers and clamp the bias to
+            # [-100, 100] per the OpenAI API spec.
+            logit_bias = {
+                int(token): min(100.0, max(-100.0, bias))
+                for token, bias in logit_bias.items()
+            }
+
+        return SamplingParams(
+            n=1 if n is None else n,
+            presence_penalty=0.0 if presence_penalty is None else presence_penalty,
+            frequency_penalty=0.0 if frequency_penalty is None else frequency_penalty,
+            repetition_penalty=1.0
+            if repetition_penalty is None
+            else repetition_penalty,
+            temperature=1.0 if temperature is None else temperature,
+            top_p=1.0 if top_p is None else top_p,
+            top_k=top_k,
+            min_p=min_p,
+            seed=seed,
+            stop=stop,
+            stop_token_ids=stop_token_ids,
+            bad_words=bad_words,
+            include_stop_str_in_output=include_stop_str_in_output,
+            ignore_eos=ignore_eos,
+            max_tokens=max_tokens,
+            min_tokens=min_tokens,
+            logprobs=logprobs,
+            prompt_logprobs=prompt_logprobs,
+            detokenize=detokenize,
+            skip_special_tokens=skip_special_tokens,
+            spaces_between_special_tokens=spaces_between_special_tokens,
+            logits_processors=logits_processors,
+            truncate_prompt_tokens=truncate_prompt_tokens,
+            output_kind=output_kind,
+            structured_outputs=structured_outputs,
+            logit_bias=logit_bias,
+            allowed_token_ids=allowed_token_ids,
+            extra_args=extra_args,
+        )

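`from_optional` maps `None` back to the documented defaults and normalizes `logit_bias` keys to `int` while clamping values to the OpenAI-spec [-100, 100] range. A small sketch of that behavior:

```python
# Sketch: from_optional fills None with defaults and clamps logit_bias
# to [-100, 100] with integer token keys, as the body above shows.
from vllm.sampling_params import SamplingParams

params = SamplingParams.from_optional(
    temperature=None,            # -> 1.0
    top_p=None,                  # -> 1.0
    logit_bias={"50256": -500},  # -> {50256: -100.0}
)
assert params.temperature == 1.0
assert params.logit_bias == {50256: -100.0}
```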
+    def __post_init__(self) -> None:
+        if 0 < self.temperature < _MAX_TEMP:
+            logger.warning(
+                "temperature %s is less than %s, which may cause numerical "
+                "errors (nan or inf) in tensors. We have raised it to %s.",
+                self.temperature,
+                _MAX_TEMP,
+                _MAX_TEMP,
+            )
+            self.temperature = max(self.temperature, _MAX_TEMP)
+
+        if self.seed == -1:
             self.seed = None
-        else:
-            self.seed = seed
-        self.use_beam_search = use_beam_search
-        self.length_penalty = length_penalty
-        self.early_stopping = early_stopping
-        if stop is None:
+
+        if self.stop is None:
             self.stop = []
-        elif isinstance(stop, str):
-            self.stop = [stop]
-        else:
-            self.stop = list(stop)
-        if stop_token_ids is None:
+        elif isinstance(self.stop, str):
+            self.stop = [self.stop]
+
+        if self.stop_token_ids is None:
             self.stop_token_ids = []
-        else:
-            self.stop_token_ids = list(stop_token_ids)
-        self.ignore_eos = ignore_eos
-        self.max_tokens = max_tokens
-        self.min_tokens = min_tokens
-        self.logprobs = logprobs
-        self.prompt_logprobs = prompt_logprobs
-        # NOTE: This parameter is only exposed at the engine level for now.
-        # It is not exposed in the OpenAI API server, as the OpenAI API does
-        # not support returning only a list of token IDs.
-        self.detokenize = detokenize
-        self.skip_special_tokens = skip_special_tokens
-        self.spaces_between_special_tokens = spaces_between_special_tokens
-        self.logits_processors = logits_processors
-        self.include_stop_str_in_output = include_stop_str_in_output
-        self.truncate_prompt_tokens = truncate_prompt_tokens
+
+        if self.bad_words is None:
+            self.bad_words = []
+
+        if self.logprobs is True:
+            self.logprobs = 1
+
+        if self.prompt_logprobs is True:
+            self.prompt_logprobs = 1
+
         # Number of characters to hold back for stop string evaluation
         # until sequence is finished.
-        if self.stop and not include_stop_str_in_output:
+        if self.stop and not self.include_stop_str_in_output:
             self.output_text_buffer_length = max(len(s) for s in self.stop) - 1
         else:
             self.output_text_buffer_length = 0

         self._verify_args()
-        if self.use_beam_search:
-            self._verify_beam_search()
-        else:
-            self._verify_non_beam_search()
-            if self.temperature < _SAMPLING_EPS:
-                # Zero temperature means greedy sampling.
-                self.top_p = 1.0
-                self.top_k = -1
-                self.min_p = 0.0
-                self._verify_greedy_sampling()
+
+        if self.temperature < _SAMPLING_EPS:
+            # Zero temperature means greedy sampling.
+            self.top_p = 1.0
+            self.top_k = 0
+            self.min_p = 0.0
+            self._verify_greedy_sampling()
+
         # eos_token_id is added to this by the engine
-        self.all_stop_token_ids = set(self.stop_token_ids)
+        self._all_stop_token_ids.update(self.stop_token_ids)
+
+        if self.skip_reading_prefix_cache is None:
+            # If prefix caching is enabled, the prompt logprobs output may
+            # contain fewer than n_prompt_tokens entries, so we need to skip
+            # reading the prefix cache for this request.
+            self.skip_reading_prefix_cache = self.prompt_logprobs is not None

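`__post_init__` now owns all the normalization the old `__init__` did by hand: `seed == -1` becomes `None`, a bare `stop` string becomes a list, and near-zero temperatures are routed to greedy overrides. A sketch of the observable effects:

```python
# Sketch: observable normalization done in __post_init__.
from vllm.sampling_params import SamplingParams, SamplingType

p = SamplingParams(seed=-1, stop="###", temperature=0.0)
assert p.seed is None      # -1 is normalized to None
assert p.stop == ["###"]   # bare string wrapped into a list
# Zero temperature forces the greedy overrides: top_p=1.0, top_k=0, min_p=0.0.
assert (p.top_p, p.top_k, p.min_p) == (1.0, 0, 0.0)
assert p.sampling_type == SamplingType.GREEDY
```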
     def _verify_args(self) -> None:
         if not isinstance(self.n, int):
             raise ValueError(f"n must be an int, but is of type {type(self.n)}")
         if self.n < 1:
             raise ValueError(f"n must be at least 1, got {self.n}.")
-        if self.best_of < self.n:
-            raise ValueError(f"best_of must be greater than or equal to n, "
-                             f"got n={self.n} and best_of={self.best_of}.")
         if not -2.0 <= self.presence_penalty <= 2.0:
-            raise ValueError("presence_penalty must be in [-2, 2], got "
-                             f"{self.presence_penalty}.")
+            raise ValueError(
+                f"presence_penalty must be in [-2, 2], got {self.presence_penalty}."
+            )
         if not -2.0 <= self.frequency_penalty <= 2.0:
-            raise ValueError("frequency_penalty must be in [-2, 2], got "
-                             f"{self.frequency_penalty}.")
-        if not 0.0 < self.repetition_penalty <= 2.0:
-            raise ValueError("repetition_penalty must be in (0, 2], got "
-                             f"{self.repetition_penalty}.")
+            raise ValueError(
+                f"frequency_penalty must be in [-2, 2], got {self.frequency_penalty}."
+            )
+        if self.repetition_penalty <= 0.0:
+            raise ValueError(
+                "repetition_penalty must be greater than zero, got "
+                f"{self.repetition_penalty}."
+            )
         if self.temperature < 0.0:
             raise ValueError(
-                f"temperature must be non-negative, got {self.temperature}.")
+                f"temperature must be non-negative, got {self.temperature}."
+            )
         if not 0.0 < self.top_p <= 1.0:
             raise ValueError(f"top_p must be in (0, 1], got {self.top_p}.")
-        if self.top_k < -1 or self.top_k == 0:
-            raise ValueError(f"top_k must be -1 (disable), or at least 1, "
-                             f"got {self.top_k}.")
-        if not 0.0 <= self.min_p <= 1.0:
-            raise ValueError("min_p must be in [0, 1], got "
-                             f"{self.min_p}.")
-        if self.max_tokens is not None and self.max_tokens < 1:
+        # quietly accept -1 as disabled, but prefer 0
+        if self.top_k < -1:
             raise ValueError(
-                f"max_tokens must be at least 1, got {self.max_tokens}.")
+                f"top_k must be 0 (disable), or at least 1, got {self.top_k}."
+            )
+        if not isinstance(self.top_k, int):
+            raise TypeError(
+                f"top_k must be an integer, got {type(self.top_k).__name__}"
+            )
+        if not 0.0 <= self.min_p <= 1.0:
+            raise ValueError(f"min_p must be in [0, 1], got {self.min_p}.")
+        if self.max_tokens is not None and self.max_tokens < 1:
+            raise ValueError(f"max_tokens must be at least 1, got {self.max_tokens}.")
         if self.min_tokens < 0:
-            raise ValueError(f"min_tokens must be greater than or equal to 0, "
-                             f"got {self.min_tokens}.")
+            raise ValueError(
+                f"min_tokens must be greater than or equal to 0, got {self.min_tokens}."
+            )
         if self.max_tokens is not None and self.min_tokens > self.max_tokens:
             raise ValueError(
                 f"min_tokens must be less than or equal to "
-                f"max_tokens={self.max_tokens}, got {self.min_tokens}.")
-        if self.logprobs is not None and self.logprobs < 0:
+                f"max_tokens={self.max_tokens}, got {self.min_tokens}."
+            )
+        if self.logprobs is not None and self.logprobs != -1 and self.logprobs < 0:
             raise ValueError(
-                f"logprobs must be non-negative, got {self.logprobs}.")
-        if self.prompt_logprobs is not None and self.prompt_logprobs < 0:
-            raise ValueError(f"prompt_logprobs must be non-negative, got "
-                             f"{self.prompt_logprobs}.")
-        if (self.truncate_prompt_tokens is not None
-                and self.truncate_prompt_tokens < 1):
-            raise ValueError(f"truncate_prompt_tokens must be >= 1, "
-                             f"got {self.truncate_prompt_tokens}")
+                f"logprobs must be non-negative or -1, got {self.logprobs}."
+            )
+        if (
+            self.prompt_logprobs is not None
+            and self.prompt_logprobs != -1
+            and self.prompt_logprobs < 0
+        ):
+            raise ValueError(
+                f"prompt_logprobs must be non-negative or -1, got "
+                f"{self.prompt_logprobs}."
+            )
+        if self.truncate_prompt_tokens is not None and (
+            self.truncate_prompt_tokens == 0 or self.truncate_prompt_tokens < -1
+        ):
+            raise ValueError(
+                f"truncate_prompt_tokens must be an integer >= 1 or -1, "
+                f"got {self.truncate_prompt_tokens}"
+            )
+        assert isinstance(self.stop_token_ids, list)
+        if not all(isinstance(st_id, int) for st_id in self.stop_token_ids):
+            raise ValueError(
+                f"stop_token_ids must contain only integers, got {self.stop_token_ids}."
+            )
+        assert isinstance(self.stop, list)
+        if any(not stop_str for stop_str in self.stop):
+            raise ValueError("stop cannot contain an empty string.")
         if self.stop and not self.detokenize:
             raise ValueError(
                 "stop strings are only supported when detokenize is True. "
-                "Set detokenize=True to use stop.")
-
-    def _verify_beam_search(self) -> None:
-        if self.best_of == 1:
-            raise ValueError("best_of must be greater than 1 when using beam "
-                             f"search. Got {self.best_of}.")
-        if self.temperature > _SAMPLING_EPS:
-            raise ValueError("temperature must be 0 when using beam search.")
-        if self.top_p < 1.0 - _SAMPLING_EPS:
-            raise ValueError("top_p must be 1 when using beam search.")
-        if self.top_k != -1:
-            raise ValueError("top_k must be -1 when using beam search.")
-        if self.early_stopping not in [True, False, "never"]:
-            raise ValueError(
-                f"early_stopping must be True, False, or 'never', "
-                f"got {self.early_stopping}.")
-
-    def _verify_non_beam_search(self) -> None:
-        if self.early_stopping is not False:
-            raise ValueError("early_stopping is not effective and must be "
-                             "False when not using beam search.")
-        if (self.length_penalty < 1.0 - _SAMPLING_EPS
-                or self.length_penalty > 1.0 + _SAMPLING_EPS):
-            raise ValueError(
-                "length_penalty is not effective and must be the "
-                "default value of 1.0 when not using beam search.")
+                "Set detokenize=True to use stop."
+            )

     def _verify_greedy_sampling(self) -> None:
-        if self.best_of > 1:
-            raise ValueError("best_of must be 1 when using greedy sampling. "
-                             f"Got {self.best_of}.")
+        if self.n > 1:
+            raise ValueError(f"n must be 1 when using greedy sampling, got {self.n}.")

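Validation got stricter and the messages more uniform; notably `top_k` now uses 0 as the documented "disabled" value while quietly accepting -1. A sketch of the new failure modes:

```python
# Sketch: the reworked validation in _verify_args.
from vllm.sampling_params import SamplingParams

for bad in (
    dict(top_k=-2),                  # top_k must be 0 (disable) or >= 1
    dict(logprobs=-5),               # logprobs must be non-negative or -1
    dict(truncate_prompt_tokens=0),  # must be >= 1 or -1
    dict(min_tokens=8, max_tokens=4),
):
    try:
        SamplingParams(**bad)
    except ValueError as e:
        print(type(e).__name__, "->", e)

SamplingParams(top_k=-1)  # accepted: -1 is quietly treated as "disabled"
```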
     def update_from_generation_config(
-            self, generation_config: Dict[str, Any]) -> None:
+        self,
+        generation_config: dict[str, Any],
+        model_eos_token_id: int | None = None,
+    ) -> None:
         """Update if there are non-default values from generation_config"""
+
+        if model_eos_token_id is not None:
+            # Add the eos token id into the sampling_params to support
+            # min_tokens processing.
+            self._all_stop_token_ids.add(model_eos_token_id)
+
         # Update eos_token_id for generation
-        if (not self.ignore_eos) and (eos_ids :=
-                                      generation_config.get("eos_token_id")):
+        if (eos_ids := generation_config.get("eos_token_id")) is not None:
             # it can be either int or list of int
-            if isinstance(eos_ids, int):
-                eos_ids = [eos_ids]
-            original_stop_token_ids = set(self.stop_token_ids)
-            original_stop_token_ids.update(eos_ids)
-            self.stop_token_ids = list(original_stop_token_ids)
+            eos_ids = {eos_ids} if isinstance(eos_ids, int) else set(eos_ids)
+            if model_eos_token_id is not None:
+                # We don't need to include the primary eos_token_id in
+                # stop_token_ids since it's handled separately for stopping
+                # purposes.
+                eos_ids.discard(model_eos_token_id)
+            if eos_ids:
+                self._all_stop_token_ids.update(eos_ids)
+                if not self.ignore_eos:
+                    eos_ids.update(self.stop_token_ids)
+                    self.stop_token_ids = list(eos_ids)

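The rewritten merge keeps the primary EOS id out of `stop_token_ids` (it is tracked in `_all_stop_token_ids` instead) and only folds the remaining EOS ids into `stop_token_ids` when `ignore_eos` is unset. A sketch, with a made-up generation config:

```python
# Sketch: merging eos_token_id values from a (made-up) generation_config.
from vllm.sampling_params import SamplingParams

p = SamplingParams(stop_token_ids=[7])
p.update_from_generation_config(
    {"eos_token_id": [2, 32000]},  # hypothetical ids
    model_eos_token_id=2,
)
# 2 is the primary EOS: excluded from stop_token_ids but kept for stopping.
assert sorted(p.stop_token_ids) == [7, 32000]
assert {2, 7, 32000} <= p.all_stop_token_ids
```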
+    def update_from_tokenizer(self, tokenizer: TokenizerLike) -> None:
+        if not self.bad_words:
+            return
+        self._bad_words_token_ids = []
+        for bad_word in self.bad_words:
+            # To prohibit words both at the beginning and in the middle of the
+            # text (related to the add_prefix_space tokenizer parameter).
+            for add_prefix_space in [False, True]:
+                prefix = " " if add_prefix_space else ""
+                prompt = prefix + bad_word.lstrip()
+                prompt_token_ids = tokenizer.encode(
+                    text=prompt, add_special_tokens=False
+                )
+
+                # Keep the encoding if there is no space at the beginning,
+                # or if the prefix space produces a new word token.
+                if (not add_prefix_space) or (
+                    add_prefix_space
+                    and prompt_token_ids[0] != self._bad_words_token_ids[-1][0]
+                    and len(prompt_token_ids) == len(self._bad_words_token_ids[-1])
+                ):
+                    self._bad_words_token_ids.append(prompt_token_ids)
+
+        invalid_token_ids = [
+            token_id
+            for bad_words_token_ids in self._bad_words_token_ids
+            for token_id in bad_words_token_ids
+            if token_id < 0 or token_id > tokenizer.max_token_id
+        ]
+        if len(invalid_token_ids) > 0:
+            raise ValueError(
+                f"The model vocabulary size is {tokenizer.max_token_id + 1},"
+                f" but the following tokens"
+                f" were specified as bad: {invalid_token_ids}."
+                f" All token id values should be integers satisfying:"
+                f" 0 <= token_id <= {tokenizer.max_token_id}."
+            )

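Each bad word is encoded twice, with and without a leading space, so it is banned both at the start and in the middle of the text; the second encoding is only kept when the prefix space does not simply re-tokenize the same sequence. A sketch with a hypothetical Hugging Face tokenizer standing in for `TokenizerLike` (the `max_token_id` patch is an assumption about what that protocol exposes):

```python
# Sketch: how bad_words become token-id sequences. The tokenizer wrapper
# here is a hypothetical stand-in for whatever implements TokenizerLike.
from transformers import AutoTokenizer

from vllm.sampling_params import SamplingParams

tok = AutoTokenizer.from_pretrained("gpt2")
tok.max_token_id = len(tok) - 1  # assumed TokenizerLike attribute

p = SamplingParams(bad_words=["matrix"])
p.update_from_tokenizer(tok)
# One id sequence per surface form: "matrix" and " matrix".
print(p.bad_words_token_ids)
```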
     @cached_property
     def sampling_type(self) -> SamplingType:
-        if self.use_beam_search:
-            return SamplingType.BEAM
         if self.temperature < _SAMPLING_EPS:
             return SamplingType.GREEDY
         if self.seed is not None:
             return SamplingType.RANDOM_SEED
         return SamplingType.RANDOM

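With `BEAM` gone, `sampling_type` reduces to a three-way decision, memoized via `cached_property` (hence the `dict=True` on the struct). A sketch:

```python
# Sketch: the three remaining sampling types after BEAM's removal.
from vllm.sampling_params import SamplingParams, SamplingType

assert SamplingParams(temperature=0.0).sampling_type == SamplingType.GREEDY
assert SamplingParams(temperature=0.7, seed=1).sampling_type == SamplingType.RANDOM_SEED
assert SamplingParams(temperature=0.7).sampling_type == SamplingType.RANDOM
```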
-    def clone(self) -> "SamplingParams":
-        """Deep copy excluding LogitsProcessor objects.
+    @property
+    def all_stop_token_ids(self) -> set[int]:
+        return self._all_stop_token_ids

-        LogitsProcessor objects are excluded because they may contain an
-        arbitrary, nontrivial amount of data.
+    @property
+    def bad_words_token_ids(self) -> list[list[int]] | None:
+        # For internal use only. Backward compatibility is not guaranteed.
+        return self._bad_words_token_ids
+
+    def clone(self) -> "SamplingParams":
+        """Deep copy, but maybe not the LogitsProcessor objects.

+        LogitsProcessor objects may contain an arbitrary, nontrivial amount of
+        data that is expensive to copy. However, if not copied, the processor
+        needs to support parallel decoding for multiple sequences.
         See https://github.com/vllm-project/vllm/issues/3087
         """

-        logit_processor_refs = None if self.logits_processors is None else {
-            id(lp): lp
-            for lp in self.logits_processors
-        }
+        logit_processor_refs = (
+            None
+            if self.logits_processors is None
+            else {
+                id(lp): lp.clone() if hasattr(lp, "clone") else lp
+                for lp in self.logits_processors
+            }
+        )
         return copy.deepcopy(self, memo=logit_processor_refs)

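`clone` seeds `copy.deepcopy`'s memo with the logits processors keyed by `id()`, so `deepcopy` treats each processor as already copied: processors with a `clone()` method are cloned through it, all others are shared by reference. A standalone sketch of the memo trick:

```python
# Sketch: pre-seeding deepcopy's memo makes it skip selected objects,
# which is exactly what clone() does for logits processors.
import copy


class Expensive:
    """Stand-in for a logits processor that is costly to deep-copy."""


original = {"params": [1, 2, 3], "processor": Expensive()}
memo = {id(original["processor"]): original["processor"]}

cloned = copy.deepcopy(original, memo=memo)
assert cloned["params"] is not original["params"]    # really copied
assert cloned["processor"] is original["processor"]  # shared, not copied
```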
     def __repr__(self) -> str:
         return (
             f"SamplingParams(n={self.n}, "
-            f"best_of={self.best_of}, "
             f"presence_penalty={self.presence_penalty}, "
             f"frequency_penalty={self.frequency_penalty}, "
             f"repetition_penalty={self.repetition_penalty}, "
@@ -323,11 +563,9 @@ class SamplingParams:
             f"top_k={self.top_k}, "
             f"min_p={self.min_p}, "
             f"seed={self.seed}, "
-            f"use_beam_search={self.use_beam_search}, "
-            f"length_penalty={self.length_penalty}, "
-            f"early_stopping={self.early_stopping}, "
             f"stop={self.stop}, "
             f"stop_token_ids={self.stop_token_ids}, "
+            f"bad_words={self.bad_words}, "
             f"include_stop_str_in_output={self.include_stop_str_in_output}, "
             f"ignore_eos={self.ignore_eos}, "
             f"max_tokens={self.max_tokens}, "
@@ -337,4 +575,23 @@ class SamplingParams:
             f"skip_special_tokens={self.skip_special_tokens}, "
             "spaces_between_special_tokens="
             f"{self.spaces_between_special_tokens}, "
-            f"truncate_prompt_tokens={self.truncate_prompt_tokens})")
+            f"truncate_prompt_tokens={self.truncate_prompt_tokens}, "
+            f"structured_outputs={self.structured_outputs}, "
+            f"extra_args={self.extra_args})"
+        )


+class BeamSearchParams(
+    msgspec.Struct,
+    omit_defaults=True,  # type: ignore[call-arg]
+    # required for @cached_property.
+    dict=True,
+):  # type: ignore[call-arg]
+    """Beam search parameters for text generation."""
+
+    beam_width: int
+    max_tokens: int
+    ignore_eos: bool = False
+    temperature: float = 0.0
+    length_penalty: float = 1.0
+    include_stop_str_in_output: bool = False

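Beam search, removed from `SamplingParams` above, now travels through this dedicated struct and vLLM's separate beam-search path. A hedged usage sketch; the `LLM.beam_search` entrypoint and the model name are assumptions, not shown in this diff:

```python
# Sketch: driving the dedicated beam search path with BeamSearchParams.
from vllm import LLM
from vllm.sampling_params import BeamSearchParams

llm = LLM(model="facebook/opt-125m")  # illustrative model choice
params = BeamSearchParams(beam_width=4, max_tokens=32, length_penalty=1.1)
outputs = llm.beam_search([{"prompt": "The capital of France is"}], params)
print(outputs[0].sequences[0].text)
```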