# bi_150-vllm/vllm/renderers/params.py

# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any, TypeVar

from vllm.exceptions import VLLMValidationError
from vllm.inputs import EmbedsPrompt, TextPrompt, TokensPrompt
from vllm.logger import init_logger
from vllm.tokenizers import TokenizerLike
from vllm.utils.import_utils import LazyLoader

if TYPE_CHECKING:
    import torch

    from vllm.entrypoints.chat_utils import ChatTemplateContentFormatOption
else:
    torch = LazyLoader("torch", globals(), "torch")
    ChatTemplateContentFormatOption = object

logger = init_logger(__name__)

_S = TypeVar("_S", list[int], "torch.Tensor")


def merge_kwargs(
    defaults: dict[str, Any] | None,
    overrides: dict[str, Any] | None,
    /,
    *,
    unset_values: tuple[object, ...] = (None, "auto"),
) -> dict[str, Any]:
    """Merge `overrides` into `defaults`, dropping override values that are unset."""
    if defaults is None:
        defaults = {}
    if overrides is None:
        overrides = {}
    return defaults | {k: v for k, v in overrides.items() if v not in unset_values}
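
# Example: override values that match an "unset" sentinel fall back to the
# defaults, e.g.
#   merge_kwargs({"a": 1, "b": 2}, {"b": None, "c": 3})
#   == {"a": 1, "b": 2, "c": 3}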


@dataclass(frozen=True)
class ChatParams:
    """Configuration to control how to parse chat messages."""

    chat_template: str | None = None
    """The chat template to apply."""

    chat_template_content_format: "ChatTemplateContentFormatOption" = "auto"
    """The format in which message content is rendered for the chat template."""

    chat_template_kwargs: dict[str, Any] = field(default_factory=dict)
    """The kwargs to pass to the chat template."""

    def with_defaults(
        self, default_chat_template_kwargs: dict[str, Any] | None
    ) -> "ChatParams":
        if not default_chat_template_kwargs:
            return self
        return ChatParams(
            chat_template=self.chat_template,
            chat_template_content_format=self.chat_template_content_format,
            chat_template_kwargs=merge_kwargs(
                default_chat_template_kwargs,
                self.chat_template_kwargs,
            ),
        )

    def get_apply_chat_template_kwargs(self) -> dict[str, Any]:
        """The arguments to pass to `tokenizer.apply_chat_template`."""
        return merge_kwargs(
            self.chat_template_kwargs,
            dict(chat_template=self.chat_template, return_dict=False),
        )
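
    # Example: an explicit `chat_template` overrides one passed through
    # `chat_template_kwargs`, but `chat_template=None` counts as unset, so a
    # template supplied via `chat_template_kwargs` still takes effect.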


@dataclass(frozen=True)
class TokenizeParams:
    """Configuration to control how prompts are tokenized."""

    max_total_tokens: int | None
    """
    Maximum allowed number of input + output tokens.
    Usually, this refers to the model's context length.
    """

    max_output_tokens: int = 0
    """Maximum requested number of output tokens."""

    pad_prompt_tokens: int | None = None
    """
    Number of tokens to pad to:
    - `None` means no padding.
    - `-1` maps to `max_input_tokens`.
    """

    truncate_prompt_tokens: int | None = None
    """
    Number of tokens to keep:
    - `None` means no truncation.
    - `-1` maps to `max_input_tokens`.
    """

    do_lower_case: bool = False
    """Whether to normalize text to lower case before tokenization."""

    add_special_tokens: bool = True
    """Whether to add special tokens."""

    needs_detokenization: bool = False
    """
    Whether the tokenized prompt needs to contain the original text.
    Not to be confused with `SamplingParams.detokenize` which deals
    with the output generated by the model.
    """

    max_total_tokens_param: str = "max_total_tokens"
    """The parameter name to display in validation error messages."""

    max_output_tokens_param: str = "max_output_tokens"
    """The parameter name to display in validation error messages."""

    truncate_prompt_tokens_param: str = "truncate_prompt_tokens"
    """The parameter name to display in validation error messages."""

    @property
    def max_input_tokens(self) -> int | None:
        """Maximum allowed number of input tokens."""
        if self.max_total_tokens is None:
            return None
        return self.max_total_tokens - self.max_output_tokens
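
    # Example: with max_total_tokens=4096 and max_output_tokens=256,
    # max_input_tokens == 4096 - 256 == 3840.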

    def __post_init__(self) -> None:
        max_total_tokens = self.max_total_tokens
        max_output_tokens = self.max_output_tokens
        max_input_tokens = self.max_input_tokens
        truncate_prompt_tokens = self.truncate_prompt_tokens

        if (
            max_output_tokens is not None
            and max_total_tokens is not None
            and max_output_tokens > max_total_tokens
        ):
            raise VLLMValidationError(
                f"{self.max_output_tokens_param}={max_output_tokens} "
                f"cannot be greater than "
                f"{self.max_total_tokens_param}={max_total_tokens}. "
                f"Please request fewer output tokens.",
                parameter=self.max_output_tokens_param,
                value=max_output_tokens,
            )

        if (
            max_input_tokens is not None
            and truncate_prompt_tokens is not None
            and truncate_prompt_tokens > max_input_tokens
        ):
            raise VLLMValidationError(
                f"{self.truncate_prompt_tokens_param}={truncate_prompt_tokens} "
                f"cannot be greater than {self.max_total_tokens_param} - "
                f"{self.max_output_tokens_param} = {max_input_tokens}. "
                f"Please request a smaller truncation size.",
                parameter=self.truncate_prompt_tokens_param,
                value=truncate_prompt_tokens,
            )
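
    # Example: TokenizeParams(max_total_tokens=100, max_output_tokens=200)
    # raises VLLMValidationError, since more output tokens are requested than
    # the context length allows.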

    def with_kwargs(self, **tokenization_kwargs: Any) -> "TokenizeParams":
        max_length = tokenization_kwargs.pop("max_length", self.max_input_tokens)
        pad_prompt_tokens = tokenization_kwargs.pop(
            "pad_prompt_tokens", self.pad_prompt_tokens
        )
        truncate_prompt_tokens = tokenization_kwargs.pop(
            "truncate_prompt_tokens", self.truncate_prompt_tokens
        )
        do_lower_case = tokenization_kwargs.pop("do_lower_case", self.do_lower_case)
        add_special_tokens = tokenization_kwargs.pop(
            "add_special_tokens", self.add_special_tokens
        )
        needs_detokenization = tokenization_kwargs.pop(
            "needs_detokenization", self.needs_detokenization
        )

        # https://huggingface.co/docs/transformers/en/pad_truncation
        if (padding := tokenization_kwargs.pop("padding", None)) is not None:
            if padding == "max_length":
                pad_prompt_tokens = max_length
            elif padding in (False, "do_not_pad"):
                pad_prompt_tokens = None
            else:
                # Put it back so the unsupported-kwargs warning below fires
                tokenization_kwargs["padding"] = padding

        if (truncation := tokenization_kwargs.pop("truncation", None)) is not None:
            if truncation in (True, "longest_first"):
                truncate_prompt_tokens = max_length
            elif truncation in (False, "do_not_truncate"):
                truncate_prompt_tokens = None
            else:
                # Put it back so the unsupported-kwargs warning below fires
                tokenization_kwargs["truncation"] = truncation

        if tokenization_kwargs:
            logger.warning(
                "The following tokenization arguments are not supported "
                "by vLLM Renderer and will be ignored: %s",
                tokenization_kwargs,
            )

        max_total_tokens = self.max_total_tokens
        return TokenizeParams(
            max_total_tokens=max_total_tokens,
            max_output_tokens=(
                0
                if max_total_tokens is None or max_length is None
                else max_total_tokens - max_length
            ),
            pad_prompt_tokens=pad_prompt_tokens,
            truncate_prompt_tokens=truncate_prompt_tokens,
            do_lower_case=do_lower_case,
            add_special_tokens=add_special_tokens,
            needs_detokenization=needs_detokenization,
        )
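
    # Example: HF-style kwargs are translated into renderer fields, e.g.
    #   params.with_kwargs(truncation=True, max_length=512)
    # returns a copy with truncate_prompt_tokens=512 and max_output_tokens
    # recomputed so that the new max_input_tokens is 512 (when
    # max_total_tokens is set).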

    def get_encode_kwargs(self) -> dict[str, Any]:
        """The arguments to pass to `tokenizer.encode`."""
        max_length = self.truncate_prompt_tokens
        if max_length is not None and max_length < 0:
            max_length = self.max_input_tokens
        elif max_length is None and self.max_input_tokens is not None:
            # This prevents tokenization from taking up more resources than necessary
            # while still failing `self._token_len_check` as expected by users
            max_length = self.max_input_tokens + 1

        return dict(
            truncation=max_length is not None,
            max_length=max_length,
            add_special_tokens=self.add_special_tokens,
        )
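
    # Example: with truncate_prompt_tokens=None and max_input_tokens=3840,
    # this returns dict(truncation=True, max_length=3841, ...): the tokenizer
    # stops early, yet over-length prompts still fail `_token_len_check`.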

    def _text_len_check(self, tokenizer: TokenizerLike | None, text: str) -> str:
        """Apply length checks to prompt text if necessary."""
        max_input_tokens = self.max_input_tokens
        if max_input_tokens is None:
            return text

        if self.truncate_prompt_tokens is None and tokenizer is not None:
            max_input_chars = max_input_tokens * tokenizer.max_chars_per_token
            if len(text) > max_input_chars:
                # To save resources, fail the request outright without even
                # attempting tokenization
                raise VLLMValidationError(
                    f"You passed {len(text)} input characters "
                    f"and requested {self.max_output_tokens} output tokens. "
                    f"However, the model's context length is only "
                    f"{self.max_total_tokens} tokens, resulting in a maximum "
                    f"input length of {max_input_tokens} tokens "
                    f"(at most {max_input_chars} characters). "
                    f"Please reduce the length of the input prompt.",
                    parameter="input_text",
                    value=len(text),
                )

        return text
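
    # Example (hypothetical figures): with max_input_tokens=3840 and a
    # tokenizer reporting max_chars_per_token=16, any text longer than
    # 61440 characters is rejected before tokenization is even attempted.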

    def _text_lowercase(self, tokenizer: TokenizerLike | None, text: str) -> str:
        """Apply lowercase to prompt text if necessary."""
        return text.lower() if self.do_lower_case else text

    def _validate_text(self, tokenizer: TokenizerLike | None, text: str) -> str:
        """Apply all validators to prompt text."""
        for validator in (
            self._text_len_check,
            self._text_lowercase,
        ):
            text = validator(tokenizer, text)
        return text

    def apply_pre_tokenization(
        self,
        tokenizer: TokenizerLike | None,
        prompt: TextPrompt,
    ) -> TextPrompt:
        """
        Ensure that the prompt meets the requirements set out by this config.
        If that is not possible, raise a `VLLMValidationError`.
        This method is run before tokenization occurs.
        """
        prompt["prompt"] = self._validate_text(tokenizer, prompt["prompt"])
        return prompt

    def _token_padding(self, tokenizer: TokenizerLike | None, tokens: _S) -> _S:
        """Apply padding to prompt tokens if necessary."""
        pad_length = self.pad_prompt_tokens
        if pad_length is not None and pad_length < 0:
            pad_length = self.max_input_tokens
        if pad_length is None or pad_length <= len(tokens):
            return tokens
        if tokenizer is None:
            raise ValueError("Cannot pad tokens when `skip_tokenizer_init=True`")
        if not isinstance(tokens, list):
            raise ValueError("Cannot pad tokens for embedding inputs")
        return tokens + [tokenizer.pad_token_id] * (pad_length - len(tokens))
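
    # Example: with pad_prompt_tokens=5 and a (hypothetical) pad_token_id of 0,
    # [1, 2, 3] becomes [1, 2, 3, 0, 0].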

    def _token_truncation(self, tokenizer: TokenizerLike | None, tokens: _S) -> _S:
        """Apply truncation to prompt tokens if necessary."""
        max_length = self.truncate_prompt_tokens
        if max_length is not None and max_length < 0:
            max_length = self.max_input_tokens
        if max_length is None or max_length >= len(tokens):
            return tokens
        if max_length == 0:
            return tokens[:0]
        if getattr(tokenizer, "truncation_side", "left") == "left":
            return tokens[-max_length:]
        return tokens[:max_length]
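
    # Example: with truncate_prompt_tokens=2 and truncation_side="left",
    # [1, 2, 3, 4] is truncated to [3, 4] (the end of the prompt is kept).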

    def _token_len_check(self, tokenizer: TokenizerLike | None, tokens: _S) -> _S:
        """Apply length checks to prompt tokens if necessary."""
        max_input_tokens = self.max_input_tokens
        if max_input_tokens is None:
            return tokens

        if len(tokens) > max_input_tokens:
            raise VLLMValidationError(
                f"You passed {len(tokens)} input tokens "
                f"and requested {self.max_output_tokens} output tokens. "
                f"However, the model's context length is only "
                f"{self.max_total_tokens} tokens, resulting in a maximum "
                f"input length of {max_input_tokens} tokens. "
                f"Please reduce the length of the input prompt.",
                parameter="input_tokens",
                value=len(tokens),
            )

        return tokens

    def _validate_tokens(self, tokenizer: TokenizerLike | None, tokens: _S) -> _S:
        """Apply all validators to a token sequence."""
        for validator in (
            self._token_padding,
            self._token_truncation,
            self._token_len_check,
        ):
            tokens = validator(tokenizer, tokens)
        return tokens

    def apply_post_tokenization(
        self,
        tokenizer: TokenizerLike | None,
        prompt: TokensPrompt | EmbedsPrompt,
    ) -> TokensPrompt | EmbedsPrompt:
        """
        Ensure that the prompt meets the requirements set out by this config.
        If that is not possible, raise a `VLLMValidationError`.
        This method is run after tokenization occurs.
        """
        if "prompt_token_ids" in prompt:
            prompt["prompt_token_ids"] = self._validate_tokens(  # type: ignore[typeddict-unknown-key]
                tokenizer,
                prompt["prompt_token_ids"],  # type: ignore[typeddict-item]
            )
        if "prompt_embeds" in prompt:
            prompt["prompt_embeds"] = self._validate_tokens(  # type: ignore[typeddict-unknown-key]
                tokenizer,
                prompt["prompt_embeds"],  # type: ignore[typeddict-item]
            )
        return prompt
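

# Usage sketch (hypothetical values; `tok` stands for any TokenizerLike):
#
#   params = TokenizeParams(max_total_tokens=4096, max_output_tokens=256)
#   text_prompt = params.apply_pre_tokenization(tok, TextPrompt(prompt="Hello"))
#   ids = tok.encode(text_prompt["prompt"], **params.get_encode_kwargs())
#   prompt = params.apply_post_tokenization(tok, TokensPrompt(prompt_token_ids=ids))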