# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Sampling parameters for text generation."""

import copy
import json as json_mod
from dataclasses import field
from enum import Enum, IntEnum
from functools import cached_property
from typing import Annotated, Any

import msgspec
from pydantic.dataclasses import dataclass

from vllm.config import ModelConfig, SpeculativeConfig, StructuredOutputsConfig
from vllm.exceptions import VLLMValidationError
from vllm.logger import init_logger
from vllm.tokenizers import TokenizerLike
from vllm.utils.mistral import is_mistral_tokenizer
from vllm.v1.serial_utils import PydanticMsgspecMixin

logger = init_logger(__name__)

_SAMPLING_EPS = 1e-5
_MAX_TEMP = 1e-2


class SamplingType(IntEnum):
    GREEDY = 0
    RANDOM = 1
    RANDOM_SEED = 2


# maybe make msgspec?
@dataclass
class StructuredOutputsParams:
    # One of these fields will be used to build a logit processor.
    json: str | dict | None = None
    regex: str | None = None
    choice: list[str] | None = None
    grammar: str | None = None
    json_object: bool | None = None
    # These are other options that can be set.
    disable_fallback: bool = False
    disable_any_whitespace: bool = False
    disable_additional_properties: bool = False
    whitespace_pattern: str | None = None
    structural_tag: str | None = None
    _backend: str | None = field(default=None, init=False)
    """CAUTION: Should only be set by Processor._validate_structured_output"""
    _backend_was_auto: bool = field(default=False, init=False)
    """CAUTION: Should only be set by Processor._validate_structured_output"""

    def __post_init__(self):
        """Validate that the constraint fields are mutually exclusive and
        that exactly one of them is set."""
        count = sum(
            [
                self.json is not None,
                self.regex is not None,
                self.choice is not None,
                self.grammar is not None,
                self.json_object is not None,
                self.structural_tag is not None,
            ]
        )
        if count > 1:
            raise ValueError(
                "You can only use one kind of structured outputs constraint "
                f"but multiple are specified: {self.__dict__}"
            )
        if count < 1:
            raise ValueError(
                "You must use one kind of structured outputs constraint "
                f"but none are specified: {self.__dict__}"
            )

    def all_constraints_none(self) -> bool:
        """
        Returns True if all structured-output constraint fields are None.
        """
        return all(
            getattr(self, field) is None
            for field in (
                "json",
                "regex",
                "choice",
                "grammar",
                "json_object",
                "structural_tag",
            )
        )

    def all_non_structural_tag_constraints_none(self) -> bool:
        """
        Returns True if all structured-output constraint fields other than
        `structural_tag` are None.
        """
        return all(
            getattr(self, field) is None
            for field in (
                "json",
                "regex",
                "choice",
                "grammar",
                "json_object",
            )
        )


class RequestOutputKind(Enum):
    # Return entire output so far in every RequestOutput
    CUMULATIVE = 0
    # Return only deltas in each RequestOutput
    DELTA = 1
    # Do not return intermediate RequestOutput
    FINAL_ONLY = 2


class SamplingParams(
    PydanticMsgspecMixin,
    msgspec.Struct,
    omit_defaults=True,  # type: ignore[call-arg]
    # required for @cached_property.
    dict=True,
):  # type: ignore[call-arg]
    """Sampling parameters for text generation.

    Overall, we follow the sampling parameters from the OpenAI text completion
    API (https://platform.openai.com/docs/api-reference/completions/create).
    In addition, we support beam search, which is not supported by OpenAI.
    """

    n: int = 1
    """Number of outputs to return for the given prompt request.

    NOTE: `AsyncLLM` streams outputs by default. When `n > 1`, all `n` outputs
    are generated and streamed cumulatively per request. To see all `n`
    outputs upon completion, use `output_kind=RequestOutputKind.FINAL_ONLY`
    in `SamplingParams`."""
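    # Illustrative sketch (not part of the API surface): requesting several
    # completions and reading them only once the request finishes.
    #
    #   params = SamplingParams(n=2, output_kind=RequestOutputKind.FINAL_ONLY)
    #   # With AsyncLLM, the single final RequestOutput then carries both
    #   # completions in its `outputs` list.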
    presence_penalty: float = 0.0
    """Penalizes new tokens based on whether they appear in the generated text
    so far. Values > 0 encourage the model to use new tokens, while values < 0
    encourage the model to repeat tokens."""
    frequency_penalty: float = 0.0
    """Penalizes new tokens based on their frequency in the generated text so
    far. Values > 0 encourage the model to use new tokens, while values < 0
    encourage the model to repeat tokens."""
    repetition_penalty: float = 1.0
    """Penalizes new tokens based on whether they appear in the prompt and the
    generated text so far. Values > 1 encourage the model to use new tokens,
    while values < 1 encourage the model to repeat tokens."""
    temperature: float = 1.0
    """Controls the randomness of the sampling. Lower values make the model
    more deterministic, while higher values make the model more random. Zero
    means greedy sampling."""
    top_p: float = 1.0
    """Controls the cumulative probability of the top tokens to consider. Must
    be in (0, 1]. Set to 1 to consider all tokens."""
    top_k: int = 0
    """Controls the number of top tokens to consider. Set to 0 (or -1) to
    consider all tokens."""
    min_p: float = 0.0
    """Represents the minimum probability for a token to be considered,
    relative to the probability of the most likely token. Must be in [0, 1].
    Set to 0 to disable this."""
    seed: int | None = None
    """Random seed to use for the generation."""
    stop: str | list[str] | None = None
    """String(s) that stop the generation when they are generated. The
    returned output will not contain the stop strings."""
    stop_token_ids: list[int] | None = None
    """Token IDs that stop the generation when they are generated. The
    returned output will contain the stop tokens unless the stop tokens are
    special tokens."""
    ignore_eos: bool = False
    """Whether to ignore the EOS token and continue generating tokens after
    the EOS token is generated."""
    max_tokens: int | None = 16
    """Maximum number of tokens to generate per output sequence."""
    min_tokens: int = 0
    """Minimum number of tokens to generate per output sequence before EOS or
    `stop_token_ids` can be generated."""
    logprobs: int | None = None
    """Number of log probabilities to return per output token. When set to
    `None`, no probability is returned. If set to a non-`None` value, the
    result includes the log probabilities of the specified number of most
    likely tokens, as well as the chosen tokens. Note that the implementation
    follows the OpenAI API: The API will always return the log probability of
    the sampled token, so there may be up to `logprobs+1` elements in the
    response. When set to -1, return all `vocab_size` log probabilities."""
    prompt_logprobs: int | None = None
    """Number of log probabilities to return per prompt token. When set to -1,
    return all `vocab_size` log probabilities."""
    flat_logprobs: bool = False
    """Whether to return logprobs in a flattened format (i.e. `FlatLogprobs`)
    for better performance.

    NOTE: The GC cost of `FlatLogprobs` is significantly smaller than that of
    `list[dict[int, Logprob]]`. When enabled, `PromptLogprobs` and
    `SampleLogprobs` will be populated as `FlatLogprobs`."""
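    # Illustrative sketch of the logprobs semantics (assumes a running `LLM`
    # instance named `llm`):
    #
    #   out = llm.generate("Hi", SamplingParams(logprobs=2, max_tokens=4))
    #   # Each generated position carries the 2 most likely tokens plus the
    #   # sampled token, i.e. up to logprobs + 1 = 3 entries per position.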
    # NOTE: This parameter is only exposed at the engine level for now.
    # It is not exposed in the OpenAI API server, as the OpenAI API does
    # not support returning only a list of token IDs.
    detokenize: bool = True
    """Whether to detokenize the output."""
    skip_special_tokens: bool = True
    """Whether to skip special tokens in the output."""
    spaces_between_special_tokens: bool = True
    """Whether to add spaces between special tokens in the output."""
    include_stop_str_in_output: bool = False
    """Whether to include the stop strings in output text."""
    truncate_prompt_tokens: Annotated[int, msgspec.Meta(ge=-1)] | None = None
    """If set to -1, will use the truncation size supported by the model. If
    set to an integer k, will use only the last k tokens from the prompt
    (i.e., left truncation). If set to `None`, truncation is disabled."""
    output_kind: RequestOutputKind = RequestOutputKind.CUMULATIVE

    skip_clone: bool = False
    """Internal flag indicating that this SamplingParams instance is safe to
    reuse without a deep copy. When True, clone() will return a shallow copy
    instead of a deep copy. This should only be set when the params object is
    guaranteed to be dedicated to a single request and won't be modified in
    ways that would affect other uses."""

    # The below fields are not supposed to be used as an input.
    # They are set in post_init.
    output_text_buffer_length: int = 0
    _eos_token_id: int | None = None
    _all_stop_token_ids: set[int] = msgspec.field(default_factory=set)

    # Fields used to construct logits processors
    structured_outputs: StructuredOutputsParams | None = None
    """Parameters for configuring structured outputs."""
    logit_bias: dict[int, float] | None = None
    """If provided, the engine will construct a logits processor that applies
    these logit biases."""
    allowed_token_ids: list[int] | None = None
    """If provided, the engine will construct a logits processor which only
    retains scores for the given token ids."""
    extra_args: dict[str, Any] | None = None
    """Arbitrary additional args that can be used by custom sampling
    implementations, plugins, etc. Not used by any in-tree sampling
    implementations."""

    # Fields used for bad words
    bad_words: list[str] | None = None
    """Words that are not allowed to be generated. More precisely, only the
    last token of a corresponding token sequence is not allowed when the next
    generated token can complete the sequence."""
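    # Illustrative sketch of the bad_words semantics (token ids are made up):
    #
    #   params = SamplingParams(bad_words=["foo bar"])
    #   # If "foo bar" tokenizes to [1001, 1002], generation of 1002 is only
    #   # blocked at positions where the previously generated token was 1001.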
    _bad_words_token_ids: list[list[int]] | None = None

    skip_reading_prefix_cache: bool | None = None

    @staticmethod
    def from_optional(
        n: int | None = 1,
        presence_penalty: float | None = 0.0,
        frequency_penalty: float | None = 0.0,
        repetition_penalty: float | None = 1.0,
        temperature: float | None = 1.0,
        top_p: float | None = 1.0,
        top_k: int = 0,
        min_p: float = 0.0,
        seed: int | None = None,
        stop: str | list[str] | None = None,
        stop_token_ids: list[int] | None = None,
        bad_words: list[str] | None = None,
        include_stop_str_in_output: bool = False,
        ignore_eos: bool = False,
        max_tokens: int | None = 16,
        min_tokens: int = 0,
        logprobs: int | None = None,
        prompt_logprobs: int | None = None,
        detokenize: bool = True,
        skip_special_tokens: bool = True,
        spaces_between_special_tokens: bool = True,
        truncate_prompt_tokens: Annotated[int, msgspec.Meta(ge=-1)] | None = None,
        output_kind: RequestOutputKind = RequestOutputKind.CUMULATIVE,
        structured_outputs: StructuredOutputsParams | None = None,
        logit_bias: dict[int, float] | dict[str, float] | None = None,
        allowed_token_ids: list[int] | None = None,
        extra_args: dict[str, Any] | None = None,
        skip_clone: bool = False,
    ) -> "SamplingParams":
        if logit_bias is not None:
            # Convert token_id to integer
            # Clamp the bias between -100 and 100 per OpenAI API spec
            logit_bias = {
                int(token): min(100.0, max(-100.0, bias))
                for token, bias in logit_bias.items()
            }

        return SamplingParams(
            n=1 if n is None else n,
            presence_penalty=0.0 if presence_penalty is None else presence_penalty,
            frequency_penalty=0.0 if frequency_penalty is None else frequency_penalty,
            repetition_penalty=1.0
            if repetition_penalty is None
            else repetition_penalty,
            temperature=1.0 if temperature is None else temperature,
            top_p=1.0 if top_p is None else top_p,
            top_k=top_k,
            min_p=min_p,
            seed=seed,
            stop=stop,
            stop_token_ids=stop_token_ids,
            bad_words=bad_words,
            include_stop_str_in_output=include_stop_str_in_output,
            ignore_eos=ignore_eos,
            max_tokens=max_tokens,
            min_tokens=min_tokens,
            logprobs=logprobs,
            prompt_logprobs=prompt_logprobs,
            detokenize=detokenize,
            skip_special_tokens=skip_special_tokens,
            spaces_between_special_tokens=spaces_between_special_tokens,
            truncate_prompt_tokens=truncate_prompt_tokens,
            output_kind=output_kind,
            structured_outputs=structured_outputs,
            logit_bias=logit_bias,
            allowed_token_ids=allowed_token_ids,
            extra_args=extra_args,
            skip_clone=skip_clone,
        )

    def __post_init__(self) -> None:
        if 0 < self.temperature < _MAX_TEMP:
            logger.warning(
                "temperature %s is less than %s, which may cause numerical "
                "errors (nan or inf) in tensors. Clamping it to %s.",
                self.temperature,
                _MAX_TEMP,
                _MAX_TEMP,
            )
            self.temperature = max(self.temperature, _MAX_TEMP)

        if self.seed == -1:
            self.seed = None

        if self.stop is None:
            self.stop = []
        elif isinstance(self.stop, str):
            self.stop = [self.stop]

        if self.stop_token_ids is None:
            self.stop_token_ids = []

        if self.bad_words is None:
            self.bad_words = []

        if self.logprobs is True:
            self.logprobs = 1

        if self.prompt_logprobs is True:
            self.prompt_logprobs = 1

        # Number of characters to hold back for stop string evaluation
        # until sequence is finished.
        if self.stop and not self.include_stop_str_in_output:
            self.output_text_buffer_length = max(len(s) for s in self.stop) - 1

        self._verify_args()

        if self.temperature < _SAMPLING_EPS:
            # Zero temperature means greedy sampling.
            self.top_p = 1.0
            self.top_k = 0
            self.min_p = 0.0
            self._verify_greedy_sampling()

        # eos_token_id is added to this by the engine
        self._all_stop_token_ids.update(self.stop_token_ids)

        if self.skip_reading_prefix_cache is None:
            # If prefix caching is enabled, prompt logprobs may be returned
            # for fewer than n_prompt_tokens tokens, so skip reading the
            # prefix cache for this request.
            self.skip_reading_prefix_cache = self.prompt_logprobs is not None
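    # Illustrative sketch of the normalization performed above:
    #
    #   p = SamplingParams(seed=-1, stop="###", temperature=1e-3)
    #   assert p.seed is None          # -1 is normalized to None
    #   assert p.stop == ["###"]       # a single string becomes a list
    #   assert p.temperature == 1e-2   # clamped up to _MAX_TEMP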
    def _verify_args(self) -> None:
        if not isinstance(self.n, int):
            raise ValueError(f"n must be an int, but is of type {type(self.n)}")
        if self.n < 1:
            raise ValueError(f"n must be at least 1, got {self.n}.")
        if not -2.0 <= self.presence_penalty <= 2.0:
            raise ValueError(
                f"presence_penalty must be in [-2, 2], got {self.presence_penalty}."
            )
        if not -2.0 <= self.frequency_penalty <= 2.0:
            raise ValueError(
                f"frequency_penalty must be in [-2, 2], got {self.frequency_penalty}."
            )
        if self.repetition_penalty <= 0.0:
            raise ValueError(
                "repetition_penalty must be greater than zero, got "
                f"{self.repetition_penalty}."
            )
        if self.temperature < 0.0:
            raise VLLMValidationError(
                f"temperature must be non-negative, got {self.temperature}.",
                parameter="temperature",
                value=self.temperature,
            )
        if not 0.0 < self.top_p <= 1.0:
            raise VLLMValidationError(
                f"top_p must be in (0, 1], got {self.top_p}.",
                parameter="top_p",
                value=self.top_p,
            )
        # Check the type first so that non-int values raise a TypeError
        # rather than failing in the comparison below.
        if not isinstance(self.top_k, int):
            raise TypeError(
                f"top_k must be an integer, got {type(self.top_k).__name__}"
            )
        # quietly accept -1 as disabled, but prefer 0
        if self.top_k < -1:
            raise ValueError(
                f"top_k must be 0 (disable), or at least 1, got {self.top_k}."
            )
        if not 0.0 <= self.min_p <= 1.0:
            raise ValueError(f"min_p must be in [0, 1], got {self.min_p}.")
        if self.max_tokens is not None and self.max_tokens < 1:
            raise VLLMValidationError(
                f"max_tokens must be at least 1, got {self.max_tokens}.",
                parameter="max_tokens",
                value=self.max_tokens,
            )
        if self.min_tokens < 0:
            raise ValueError(
                f"min_tokens must be greater than or equal to 0, got {self.min_tokens}."
            )
        if self.max_tokens is not None and self.min_tokens > self.max_tokens:
            raise ValueError(
                f"min_tokens must be less than or equal to "
                f"max_tokens={self.max_tokens}, got {self.min_tokens}."
            )
        if self.logprobs is not None and self.logprobs != -1 and self.logprobs < 0:
            raise VLLMValidationError(
                f"logprobs must be non-negative or -1, got {self.logprobs}.",
                parameter="logprobs",
                value=self.logprobs,
            )
        if (
            self.prompt_logprobs is not None
            and self.prompt_logprobs != -1
            and self.prompt_logprobs < 0
        ):
            raise VLLMValidationError(
                f"prompt_logprobs must be non-negative or -1, got "
                f"{self.prompt_logprobs}.",
                parameter="prompt_logprobs",
                value=self.prompt_logprobs,
            )
        if self.truncate_prompt_tokens is not None and (
            self.truncate_prompt_tokens == 0 or self.truncate_prompt_tokens < -1
        ):
            raise VLLMValidationError(
                f"truncate_prompt_tokens must be an integer >= 1 or -1, "
                f"got {self.truncate_prompt_tokens}",
                parameter="truncate_prompt_tokens",
                value=self.truncate_prompt_tokens,
            )
        assert isinstance(self.stop_token_ids, list)
        if not all(isinstance(st_id, int) for st_id in self.stop_token_ids):
            raise ValueError(
                f"stop_token_ids must contain only integers, got {self.stop_token_ids}."
            )
        assert isinstance(self.stop, list)
        if any(not stop_str for stop_str in self.stop):
            raise ValueError("stop cannot contain an empty string.")
        if self.stop and not self.detokenize:
            raise ValueError(
                "stop strings are only supported when detokenize is True. "
                "Set detokenize=True to use stop."
            )
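    # Illustrative sketch: invalid values are rejected at construction time,
    # since __post_init__ calls _verify_args.
    #
    #   SamplingParams(top_p=0.0)   # raises VLLMValidationError (top_p)
    #   SamplingParams(n=0)         # raises ValueError (n must be >= 1)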
    def _verify_greedy_sampling(self) -> None:
        if self.n > 1:
            raise ValueError(f"n must be 1 when using greedy sampling, got {self.n}.")

    def update_from_generation_config(
        self,
        generation_config: dict[str, Any],
        eos_token_id: int | None = None,
    ) -> None:
        """Update if there are non-default values from generation_config."""

        if not self.ignore_eos:
            self._eos_token_id = eos_token_id
            if eos_token_id is not None:
                # Add the eos token id into the sampling_params to support
                # min_tokens processing.
                self._all_stop_token_ids.add(eos_token_id)

        # Update eos_token_id for generation
        if (eos_ids := generation_config.get("eos_token_id")) is not None:
            # it can be either int or list of int
            eos_ids = {eos_ids} if isinstance(eos_ids, int) else set(eos_ids)
            if eos_token_id is not None:
                # We don't need to include the primary eos_token_id in
                # stop_token_ids since it's handled separately for stopping
                # purposes.
                eos_ids.discard(eos_token_id)
            if eos_ids:
                self._all_stop_token_ids.update(eos_ids)
                if not self.ignore_eos:
                    eos_ids.update(self.stop_token_ids)
                    self.stop_token_ids = list(eos_ids)

    def update_from_tokenizer(self, tokenizer: TokenizerLike) -> None:
        if not self.bad_words:
            return
        self._bad_words_token_ids = []
        for bad_word in self.bad_words:
            # To prohibit words both at the beginning
            # and in the middle of text
            # (related to add_prefix_space tokenizer parameter)
            for add_prefix_space in [False, True]:
                prefix = " " if add_prefix_space else ""
                prompt = prefix + bad_word.lstrip()
                prompt_token_ids = tokenizer.encode(
                    text=prompt, add_special_tokens=False
                )

                # If no space at the beginning
                # or if prefix space produces a new word token
                if not add_prefix_space or (
                    prompt_token_ids[0] != self._bad_words_token_ids[-1][0]
                    and len(prompt_token_ids) == len(self._bad_words_token_ids[-1])
                ):
                    self._bad_words_token_ids.append(prompt_token_ids)

        invalid_token_ids = [
            token_id
            for bad_words_token_ids in self._bad_words_token_ids
            for token_id in bad_words_token_ids
            if token_id < 0 or token_id > tokenizer.max_token_id
        ]
        if len(invalid_token_ids) > 0:
            raise VLLMValidationError(
                f"The model vocabulary size is {tokenizer.max_token_id + 1},"
                f" but the following tokens"
                f" were specified as bad: {invalid_token_ids}."
                f" All token id values should be integers satisfying:"
                f" 0 <= token_id <= {tokenizer.max_token_id}.",
                parameter="bad_words",
                value=self.bad_words,
            )

    @cached_property
    def sampling_type(self) -> SamplingType:
        if self.temperature < _SAMPLING_EPS:
            return SamplingType.GREEDY
        if self.seed is not None:
            return SamplingType.RANDOM_SEED
        return SamplingType.RANDOM

    @property
    def eos_token_id(self) -> int | None:
        return self._eos_token_id

    @property
    def all_stop_token_ids(self) -> set[int]:
        return self._all_stop_token_ids

    @property
    def bad_words_token_ids(self) -> list[list[int]] | None:
        # For internal use only. Backward compatibility is not guaranteed.
        return self._bad_words_token_ids
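    # Illustrative sketch of clone() semantics below:
    #
    #   shared = SamplingParams(temperature=0.7)
    #   per_request = shared.clone()        # deep copy; mutations stay local
    #   dedicated = SamplingParams(skip_clone=True)
    #   assert dedicated.clone() is not dedicated  # still a (shallow) copy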
    def clone(self) -> "SamplingParams":
        """If skip_clone is True, uses a shallow copy instead of a deep copy."""
        if self.skip_clone:
            return copy.copy(self)
        return copy.deepcopy(self)

    def verify(
        self,
        model_config: ModelConfig,
        speculative_config: SpeculativeConfig | None,
        structured_outputs_config: StructuredOutputsConfig | None,
        tokenizer: TokenizerLike | None,
    ) -> None:
        self._validate_logprobs(model_config)
        self._validate_logit_bias(model_config)
        self._validate_logits_processors(model_config)
        self._validate_allowed_token_ids(tokenizer)
        self._validate_spec_decode(speculative_config)
        self._validate_structured_outputs(structured_outputs_config, tokenizer)

    def _validate_logprobs(self, model_config: ModelConfig) -> None:
        max_logprobs = model_config.max_logprobs
        if max_logprobs == -1:
            max_logprobs = model_config.get_vocab_size()

        # Validate sample logprobs.
        if num_logprobs := self.logprobs:
            if num_logprobs == -1:
                num_logprobs = model_config.get_vocab_size()
            if num_logprobs > max_logprobs:
                raise VLLMValidationError(
                    f"Requested sample logprobs of {num_logprobs}, "
                    f"which is greater than max allowed: {max_logprobs}",
                    parameter="logprobs",
                    value=num_logprobs,
                )

        # Validate prompt logprobs.
        if num_prompt_logprobs := self.prompt_logprobs:
            if num_prompt_logprobs == -1:
                num_prompt_logprobs = model_config.get_vocab_size()
            if num_prompt_logprobs > max_logprobs:
                raise VLLMValidationError(
                    f"Requested prompt logprobs of {num_prompt_logprobs}, "
                    f"which is greater than max allowed: {max_logprobs}",
                    parameter="prompt_logprobs",
                    value=num_prompt_logprobs,
                )

    def _validate_logit_bias(self, model_config: ModelConfig) -> None:
        """Validate logit_bias token IDs are within vocabulary range."""
        if not self.logit_bias:
            return

        vocab_size = model_config.get_vocab_size()
        invalid_token_ids = [
            token_id
            for token_id in self.logit_bias
            if token_id < 0 or token_id >= vocab_size
        ]
        if invalid_token_ids:
            raise VLLMValidationError(
                f"token_id(s) {invalid_token_ids} in logit_bias contain "
                f"out-of-vocab token ids. Vocabulary size: {vocab_size}",
                parameter="logit_bias",
                value=invalid_token_ids,
            )

    def _validate_logits_processors(self, model_config: ModelConfig) -> None:
        from vllm.v1.sample.logits_processor import (
            validate_logits_processors_parameters,
        )

        validate_logits_processors_parameters(model_config.logits_processors, self)

    def _validate_allowed_token_ids(self, tokenizer: TokenizerLike | None) -> None:
        allowed_token_ids = self.allowed_token_ids
        if allowed_token_ids is None:
            return
        if len(allowed_token_ids) == 0:
            raise VLLMValidationError(
                "allowed_token_ids is provided but empty!",
                parameter="allowed_token_ids",
                value=allowed_token_ids,
            )
        if tokenizer is not None:
            vocab_size = len(tokenizer)
            invalid_token_ids = [
                token_id
                for token_id in allowed_token_ids
                if token_id < 0 or token_id >= vocab_size
            ]
            if invalid_token_ids:
                raise VLLMValidationError(
                    "allowed_token_ids contains out-of-vocab token id!",
                    parameter="allowed_token_ids",
                    value=invalid_token_ids,
                )

    def _validate_spec_decode(
        self,
        speculative_config: SpeculativeConfig | None,
    ) -> None:
        if speculative_config is None:
            return
        # Some sampling parameters are not yet compatible with spec decoding.
        if self.min_tokens > 1 or self.min_p > _SAMPLING_EPS or self.logit_bias:
            raise ValueError(
                "The min_tokens, min_p, and logit_bias sampling parameters "
                "are not yet supported with speculative decoding."
            )
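    # Illustrative sketch: exactly one structured-output constraint may be set
    # per StructuredOutputsParams (enforced in its __post_init__).
    #
    #   ok = StructuredOutputsParams(regex=r"\d+")
    #   StructuredOutputsParams(regex=r"\d+", choice=["a", "b"])  # ValueError
    #   SamplingParams(structured_outputs=ok)  # validated further in verify()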
    def _validate_structured_outputs(
        self,
        structured_outputs_config: StructuredOutputsConfig | None,
        tokenizer: TokenizerLike | None,
    ) -> None:
        if structured_outputs_config is None or self.structured_outputs is None:
            return

        if tokenizer is None:
            raise ValueError(
                "Structured outputs requires a tokenizer so it can't be used with 'skip_tokenizer_init'"  # noqa: E501
            )

        backend = structured_outputs_config.backend
        if _backend := self.structured_outputs._backend:
            # Request-level backend selection is not supported.
            # The values may differ if `params` is reused and was set
            # to a specific backend based on `auto` behavior in a previous
            # request. We remember that it was set as a result of `auto`
            # using the `_backend_was_auto` field set in the params.
            if backend != _backend and not (
                backend == "auto" and self.structured_outputs._backend_was_auto
            ):
                raise ValueError(
                    "Request-level structured output backend selection is not "
                    f"supported. The request specified '{_backend}', but vLLM "
                    f"was initialised with '{backend}'. This error can be "
                    "resolved by removing '_backend' from the request."
                )
        else:
            self.structured_outputs._backend = backend

        # Request content validation
        if (
            isinstance(self.structured_outputs.choice, list)
            and not self.structured_outputs.choice
        ):
            # It is invalid for choice to be an empty list
            raise ValueError(
                f"Choice '{self.structured_outputs.choice}' cannot be an empty list"  # noqa: E501
            )

        # Reject empty string grammar early to avoid engine-side crashes
        if (
            isinstance(self.structured_outputs.grammar, str)
            and self.structured_outputs.grammar.strip() == ""
        ):
            raise ValueError("structured_outputs.grammar cannot be an empty string")

        from vllm.v1.structured_output.backend_guidance import (
            has_guidance_unsupported_json_features,
            validate_guidance_grammar,
        )
        from vllm.v1.structured_output.backend_lm_format_enforcer import (
            validate_structured_output_request_lm_format_enforcer,
        )
        from vllm.v1.structured_output.backend_outlines import (
            validate_structured_output_request_outlines,
        )
        from vllm.v1.structured_output.backend_xgrammar import validate_xgrammar_grammar

        if backend.startswith("xgrammar"):
            # xgrammar with no fallback
            validate_xgrammar_grammar(self)
        elif backend.startswith("guidance"):
            # TODO: ideally we would have the LLTokenizer here as Lark syntax
            # allows <|special_token|> and similar, see
            # https://github.com/guidance-ai/llguidance/blob/main/docs/syntax.md#special-tokens
            # Without tokenizer these are disallowed in grammars.
            if is_mistral_tokenizer(tokenizer):
                raise ValueError(
                    "Mistral tokenizer is not supported for the 'guidance' "
                    "structured output backend. Please use ['xgrammar', 'outlines'] "
                    "backends or tokenizer_mode='hf' instead."
                )
            validate_guidance_grammar(self, tokenizer=None)
        elif backend == "outlines":
            # outlines backend
            validate_structured_output_request_outlines(self)
        elif backend == "lm-format-enforcer":
            # lm format enforcer backend
            if is_mistral_tokenizer(tokenizer):
                raise ValueError(
                    "Mistral tokenizer is not supported for the 'lm-format-enforcer' "
                    "structured output backend. Please use ['xgrammar', 'outlines'] "
                    "backends or tokenizer_mode='hf' instead."
                )
            validate_structured_output_request_lm_format_enforcer(self)
        else:
            # NOTE: backend must be "auto" here; the specific backends were
            # all handled above.
            # In this mode, we set opinionated defaults based on what we think
            # will satisfy the most use cases without having to worry about
            # this setting. We include fallback behavior here, but not with
            # any other setting where a specific backend was specified.
            try:
                validate_xgrammar_grammar(self)
                self.structured_outputs._backend = "xgrammar"
            except ValueError:
                # The request either failed validation
                # or includes some jsonschema feature(s) that
                # are not supported in xgrammar.
                # Check if schema has features unsupported by guidance
                so_params = self.structured_outputs
                skip_guidance = False
                if so_params.json:
                    if isinstance(so_params.json, str):
                        schema = json_mod.loads(so_params.json)
                    else:
                        schema = so_params.json
                    skip_guidance = has_guidance_unsupported_json_features(schema)

                if is_mistral_tokenizer(tokenizer) or skip_guidance:
                    # Fall back to outlines if the tokenizer is Mistral
                    # or if schema contains features unsupported by guidance
                    validate_structured_output_request_outlines(self)
                    self.structured_outputs._backend = "outlines"
                else:
                    # Fall back to guidance by default.
                    validate_guidance_grammar(self, tokenizer=None)
                    self.structured_outputs._backend = "guidance"
            # Remember that this backend was set automatically
            self.structured_outputs._backend_was_auto = True

        # Run post-init validation. This is also important to ensure subsequent
        # roundtrip serialization/deserialization won't fail.
        self.structured_outputs.__post_init__()

    def __repr__(self) -> str:
        return (
            f"SamplingParams(n={self.n}, "
            f"presence_penalty={self.presence_penalty}, "
            f"frequency_penalty={self.frequency_penalty}, "
            f"repetition_penalty={self.repetition_penalty}, "
            f"temperature={self.temperature}, "
            f"top_p={self.top_p}, "
            f"top_k={self.top_k}, "
            f"min_p={self.min_p}, "
            f"seed={self.seed}, "
            f"stop={self.stop}, "
            f"stop_token_ids={self.stop_token_ids}, "
            f"bad_words={self.bad_words}, "
            f"include_stop_str_in_output={self.include_stop_str_in_output}, "
            f"ignore_eos={self.ignore_eos}, "
            f"max_tokens={self.max_tokens}, "
            f"min_tokens={self.min_tokens}, "
            f"logprobs={self.logprobs}, "
            f"prompt_logprobs={self.prompt_logprobs}, "
            f"skip_special_tokens={self.skip_special_tokens}, "
            "spaces_between_special_tokens="
            f"{self.spaces_between_special_tokens}, "
            f"truncate_prompt_tokens={self.truncate_prompt_tokens}, "
            f"structured_outputs={self.structured_outputs}, "
            f"extra_args={self.extra_args})"
        )


class BeamSearchParams(
    msgspec.Struct,
    omit_defaults=True,  # type: ignore[call-arg]
    # required for @cached_property.
    dict=True,
):  # type: ignore[call-arg]
    """Beam search parameters for text generation."""

    beam_width: int
    max_tokens: int
    ignore_eos: bool = False
    temperature: float = 0.0
    length_penalty: float = 1.0
    include_stop_str_in_output: bool = False
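
# Minimal usage sketch (illustrative only, not part of the library API). It
# exercises just this module; no engine or model is required.
if __name__ == "__main__":
    params = SamplingParams(temperature=0.0, max_tokens=32, stop="###")
    # Zero temperature switches to greedy sampling; __post_init__ also
    # normalizes `stop` into a list.
    assert params.sampling_type is SamplingType.GREEDY
    assert params.stop == ["###"]

    seeded = SamplingParams(temperature=0.8, seed=42)
    assert seeded.sampling_type is SamplingType.RANDOM_SEED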