# SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from dataclasses import dataclass, field from typing import TYPE_CHECKING, Any, TypeVar from vllm.exceptions import VLLMValidationError from vllm.inputs import EmbedsPrompt, TextPrompt, TokensPrompt from vllm.logger import init_logger from vllm.tokenizers import TokenizerLike from vllm.utils.import_utils import LazyLoader if TYPE_CHECKING: import torch from vllm.entrypoints.chat_utils import ChatTemplateContentFormatOption else: torch = LazyLoader("torch", globals(), "torch") ChatTemplateContentFormatOption = object logger = init_logger(__name__) _S = TypeVar("_S", list[int], "torch.Tensor") def merge_kwargs( defaults: dict[str, Any] | None, overrides: dict[str, Any] | None, /, *, unset_values: tuple[object, ...] = (None, "auto"), ) -> dict[str, Any]: if defaults is None: defaults = {} if overrides is None: overrides = {} return defaults | {k: v for k, v in overrides.items() if v not in unset_values} @dataclass(frozen=True) class ChatParams: """Configuration to control how to parse chat messages.""" chat_template: str | None = None """The chat template to apply.""" chat_template_content_format: "ChatTemplateContentFormatOption" = "auto" """The format of the chat template.""" chat_template_kwargs: dict[str, Any] = field(default_factory=dict) """The kwargs to pass to the chat template.""" def with_defaults(self, default_chat_template_kwargs: dict[str, Any] | None): if not default_chat_template_kwargs: return self return ChatParams( chat_template=self.chat_template, chat_template_content_format=self.chat_template_content_format, chat_template_kwargs=merge_kwargs( default_chat_template_kwargs, self.chat_template_kwargs, ), ) def get_apply_chat_template_kwargs(self) -> dict[str, Any]: """The arguments to pass to `tokenizer.apply_chat_template`.""" return merge_kwargs( self.chat_template_kwargs, dict(chat_template=self.chat_template, return_dict=False), ) @dataclass(frozen=True) class TokenizeParams: """Configuration to control how prompts are tokenized.""" max_total_tokens: int | None """ Maximum allowed number of input + output tokens. Usually, this refers to the model's context length. """ max_output_tokens: int = 0 """Maximum requested number of output tokens.""" pad_prompt_tokens: int | None = None """ Number of tokens to pad to: - `None` means no padding. - `-1` maps to `max_input_tokens`. """ truncate_prompt_tokens: int | None = None """ Number of tokens to keep: - `None` means no truncation. - `-1` maps to `max_input_tokens`. """ do_lower_case: bool = False """Whether to normalize text to lower case before tokenization.""" add_special_tokens: bool = True """Whether to add special tokens.""" needs_detokenization: bool = False """ Whether the tokenized prompt needs to contain the original text. Not to be confused with `SamplingParams.detokenize` which deals with the output generated by the model. """ max_total_tokens_param: str = "max_total_tokens" """Override this to edit the message for validation errors.""" max_output_tokens_param: str = "max_output_tokens" """Override this to edit the message for validation errors.""" truncate_prompt_tokens_param: str = "truncate_prompt_tokens" """Override this to edit the message for validation errors.""" @property def max_input_tokens(self) -> int | None: """Maximum allowed number of input tokens.""" if self.max_total_tokens is None: return None return self.max_total_tokens - self.max_output_tokens def __post_init__(self) -> None: max_total_tokens = self.max_total_tokens max_output_tokens = self.max_output_tokens max_input_tokens = self.max_input_tokens truncate_prompt_tokens = self.truncate_prompt_tokens if ( max_output_tokens is not None and max_total_tokens is not None and max_output_tokens > max_total_tokens ): raise VLLMValidationError( f"{self.max_output_tokens_param}={max_output_tokens}" f"cannot be greater than " f"{self.max_total_tokens_param}={max_total_tokens=}. " f"Please request fewer output tokens.", parameter=self.max_output_tokens_param, value=max_output_tokens, ) if ( max_input_tokens is not None and truncate_prompt_tokens is not None and truncate_prompt_tokens > max_input_tokens ): raise VLLMValidationError( f"{self.truncate_prompt_tokens_param}={truncate_prompt_tokens} " f"cannot be greater than {self.max_total_tokens_param} - " f"{self.max_output_tokens_param} = {max_input_tokens}. " f"Please request a smaller truncation size.", parameter=self.truncate_prompt_tokens_param, value=truncate_prompt_tokens, ) def with_kwargs(self, **tokenization_kwargs: Any): max_length = tokenization_kwargs.pop("max_length", self.max_input_tokens) pad_prompt_tokens = tokenization_kwargs.pop( "pad_prompt_tokens", self.pad_prompt_tokens ) truncate_prompt_tokens = tokenization_kwargs.pop( "truncate_prompt_tokens", self.truncate_prompt_tokens ) do_lower_case = tokenization_kwargs.pop("do_lower_case", self.do_lower_case) add_special_tokens = tokenization_kwargs.pop( "add_special_tokens", self.add_special_tokens ) needs_detokenization = tokenization_kwargs.pop( "needs_detokenization", self.needs_detokenization ) # https://huggingface.co/docs/transformers/en/pad_truncation if padding := tokenization_kwargs.pop("padding", None): if padding == "max_length": pad_prompt_tokens = max_length elif padding in (False, "do_not_pad"): pad_prompt_tokens = None else: # To emit the below warning tokenization_kwargs["padding"] = padding if truncation := tokenization_kwargs.pop("truncation", None): if truncation in (True, "longest_first"): truncate_prompt_tokens = max_length elif truncation in (False, "do_not_truncate"): truncate_prompt_tokens = None else: # To emit the below warning tokenization_kwargs["truncation"] = truncation if tokenization_kwargs: logger.warning( "The following tokenization arguments are not supported " "by vLLM Renderer and will be ignored: %s", tokenization_kwargs, ) max_total_tokens = self.max_total_tokens return TokenizeParams( max_total_tokens=max_total_tokens, max_output_tokens=( 0 if max_total_tokens is None or max_length is None else max_total_tokens - max_length ), pad_prompt_tokens=pad_prompt_tokens, truncate_prompt_tokens=truncate_prompt_tokens, do_lower_case=do_lower_case, add_special_tokens=add_special_tokens, needs_detokenization=needs_detokenization, ) def get_encode_kwargs(self) -> dict[str, Any]: """The arguments to pass to `tokenizer.encode`.""" max_length = self.truncate_prompt_tokens if max_length is not None and max_length < 0: max_length = self.max_input_tokens elif max_length is None and self.max_input_tokens is not None: # This prevents tokenization from taking up more resources than necessary # while still failing `self._token_len_check` as expected by users max_length = self.max_input_tokens + 1 return dict( truncation=max_length is not None, max_length=max_length, add_special_tokens=self.add_special_tokens, ) def _text_len_check(self, tokenizer: TokenizerLike | None, text: str) -> str: """Apply length checks to prompt text if necessary.""" max_input_tokens = self.max_input_tokens if max_input_tokens is None: return text if self.truncate_prompt_tokens is None and tokenizer is not None: max_input_chars = max_input_tokens * tokenizer.max_chars_per_token if len(text) > max_input_chars: # To save resources, fail the request outright without even # attempting tokenization raise VLLMValidationError( f"You passed {len(text)} input characters " f"and requested {self.max_output_tokens} output tokens. " f"However, the model's context length is only " f"{self.max_total_tokens} tokens, resulting in a maximum " f"input length of {max_input_tokens} tokens " f"(at most {max_input_chars} characters). " f"Please reduce the length of the input prompt.", parameter="input_text", value=len(text), ) return text def _text_lowercase(self, tokenizer: TokenizerLike | None, text: str) -> str: """Apply lowercase to prompt text if necessary.""" return text.lower() if self.do_lower_case else text def _validate_text(self, tokenizer: TokenizerLike | None, text: str) -> str: """Apply all validators to prompt text.""" for validator in ( self._text_len_check, self._text_lowercase, ): text = validator(tokenizer, text) return text def apply_pre_tokenization( self, tokenizer: TokenizerLike | None, prompt: TextPrompt, ) -> TextPrompt: """ Ensure that the prompt meets the requirements set out by this config. If that is not possible, raise a `VLLMValidationError`. This method is run before tokenization occurs. """ prompt["prompt"] = self._validate_text(tokenizer, prompt["prompt"]) return prompt def _token_padding(self, tokenizer: TokenizerLike | None, tokens: _S) -> _S: """Apply padding to prompt tokens if necessary.""" pad_length = self.pad_prompt_tokens if pad_length is not None and pad_length < 0: pad_length = self.max_input_tokens if pad_length is None or pad_length <= len(tokens): return tokens if tokenizer is None: raise ValueError("Cannot pad tokens when `skip_tokenizer_init=True`") if not isinstance(tokens, list): raise ValueError("Cannot pad tokens for embedding inputs") return tokens + [tokenizer.pad_token_id] * (pad_length - len(tokens)) def _token_truncation(self, tokenizer: TokenizerLike | None, tokens: _S) -> _S: """Apply truncation to prompt tokens if necessary.""" max_length = self.truncate_prompt_tokens if max_length is not None and max_length < 0: max_length = self.max_input_tokens if max_length is None or max_length >= len(tokens): return tokens if max_length == 0: return tokens[:0] if getattr(tokenizer, "truncation_side", "left") == "left": return tokens[-max_length:] return tokens[:max_length] def _token_len_check(self, tokenizer: TokenizerLike | None, tokens: _S) -> _S: """Apply length checks to prompt tokens if necessary.""" max_input_tokens = self.max_input_tokens if max_input_tokens is None: return tokens if len(tokens) > max_input_tokens: raise VLLMValidationError( f"You passed {len(tokens)} input tokens " f"and requested {self.max_output_tokens} output tokens. " f"However, the model's context length is only " f"{self.max_total_tokens} tokens, resulting in a maximum " f"input length of {max_input_tokens} tokens. " f"Please reduce the length of the input prompt.", parameter="input_tokens", value=len(tokens), ) return tokens def _validate_tokens(self, tokenizer: TokenizerLike | None, tokens: _S) -> _S: """Apply all validators to a token sequence.""" for validator in ( self._token_padding, self._token_truncation, self._token_len_check, ): tokens = validator(tokenizer, tokens) return tokens def apply_post_tokenization( self, tokenizer: TokenizerLike | None, prompt: TokensPrompt | EmbedsPrompt, ) -> TokensPrompt | EmbedsPrompt: """ Ensure that the prompt meets the requirements set out by this config. If that is not possible, raise a `VLLMValidationError`. This method is run after tokenization occurs. """ if "prompt_token_ids" in prompt: prompt["prompt_token_ids"] = self._validate_tokens( # type: ignore[typeddict-unknown-key] tokenizer, prompt["prompt_token_ids"], # type: ignore[typeddict-item] ) if "prompt_embeds" in prompt: prompt["prompt_embeds"] = self._validate_tokens( # type: ignore[typeddict-unknown-key] tokenizer, prompt["prompt_embeds"], # type: ignore[typeddict-item] ) return prompt