Upgrade to vllm 0.17.0 corex v4.1 overlay
```diff
@@ -7,7 +7,7 @@ import json as json_mod
 from dataclasses import field
 from enum import Enum, IntEnum
 from functools import cached_property
-from typing import Annotated, Any
+from typing import Any
 
 import msgspec
 from pydantic.dataclasses import dataclass
```
```diff
@@ -107,6 +107,43 @@ class StructuredOutputsParams:
         )
 
 
+@dataclass
+class RepetitionDetectionParams:
+    """Parameters for detecting repetitive N-gram patterns in output tokens."""
+
+    max_pattern_size: int = 0
+    """Maximum size of N-gram pattern to detect for sequence repetition.
+    Set to 0 to disable. Must be used together with min_count."""
+
+    min_pattern_size: int = 0
+    """Minimum N-gram pattern size to check for sequence repetition.
+    If set to 0, it defaults to 1.
+    Must be <= max_pattern_size."""
+
+    min_count: int = 0
+    """Minimum number of times an N-gram pattern must repeat to trigger
+    detection. Must be >= 2. Example: 3 for detecting a phrase repeated
+    3 times. Must be used together with max_pattern_size."""
+
+    def __post_init__(self):
+        if (
+            self.max_pattern_size < 0
+            or self.min_pattern_size < 0
+            or self.min_pattern_size > self.max_pattern_size
+        ):
+            raise ValueError(
+                "max_pattern_size, min_pattern_size must be >=0, "
+                "with min_pattern_size <= max_pattern_size. "
+                "Set both to 0 to disable repetitive pattern detection."
+            )
+        if self.max_pattern_size > 0 and self.min_count < 2:
+            raise ValueError(
+                "min_count must be >= 2 to detect repetitive patterns "
+                "in engine output. If you do not wish to detect repetitive "
+                "patterns, set max_pattern_size to 0."
+            )
+
+
 class RequestOutputKind(Enum):
     # Return entire output so far in every RequestOutput
     CUMULATIVE = 0
```
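As a usage sketch: the validation above admits either the all-zero disabled state or an enabled window with `min_count >= 2`. The import path below is an assumption based on the file being patched (vLLM's `sampling_params` module):

```python
from vllm.sampling_params import RepetitionDetectionParams

# Enabled: flag any pattern of 1-8 tokens that repeats at least 3 times.
detect = RepetitionDetectionParams(
    max_pattern_size=8,
    min_pattern_size=1,
    min_count=3,
)

# Disabled: the all-zero defaults skip detection entirely.
disabled = RepetitionDetectionParams()

# Invalid: max_pattern_size > 0 with min_count < 2 is rejected by
# __post_init__ (ValueError), as is min_pattern_size > max_pattern_size.
# RepetitionDetectionParams(max_pattern_size=4)
```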
```diff
@@ -209,10 +246,6 @@ class SamplingParams(
     """Whether to add spaces between special tokens in the output."""
     include_stop_str_in_output: bool = False
     """Whether to include the stop strings in output text."""
-    truncate_prompt_tokens: Annotated[int, msgspec.Meta(ge=-1)] | None = None
-    """If set to -1, will use the truncation size supported by the model. If
-    set to an integer k, will use only the last k tokens from the prompt
-    (i.e., left truncation). If set to `None`, truncation is disabled."""
     output_kind: RequestOutputKind = RequestOutputKind.CUMULATIVE
     skip_clone: bool = False
     """Internal flag indicating that this SamplingParams instance is safe to
```
```diff
@@ -250,6 +283,14 @@ class SamplingParams(
 
     skip_reading_prefix_cache: bool | None = None
 
+    repetition_detection: RepetitionDetectionParams | None = None
+    """Parameters for detecting repetitive N-gram patterns in output tokens.
+    If such repetition is detected, generation will be ended early. LLMs can
+    sometimes generate repetitive, unhelpful token patterns, stopping only
+    when they hit the maximum output length (e.g. 'abcdabcdabcd...' or
+    '\\emoji \\emoji \\emoji ...'). This feature can detect such behavior
+    and terminate early, saving time and tokens."""
+
     @staticmethod
     def from_optional(
         n: int | None = 1,
```
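A hedged sketch of how a request opts in, per the docstring above; the import path is assumed, while the field names come from this diff:

```python
from vllm.sampling_params import RepetitionDetectionParams, SamplingParams

# End generation early if any pattern of up to 6 tokens repeats 4 or more
# times, rather than looping until max_tokens is exhausted.
params = SamplingParams(
    max_tokens=512,
    temperature=0.7,
    repetition_detection=RepetitionDetectionParams(
        max_pattern_size=6,
        min_pattern_size=1,
        min_count=4,
    ),
)
```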
```diff
@@ -273,13 +314,13 @@ class SamplingParams(
         detokenize: bool = True,
         skip_special_tokens: bool = True,
         spaces_between_special_tokens: bool = True,
-        truncate_prompt_tokens: Annotated[int, msgspec.Meta(ge=-1)] | None = None,
         output_kind: RequestOutputKind = RequestOutputKind.CUMULATIVE,
         structured_outputs: StructuredOutputsParams | None = None,
         logit_bias: dict[int, float] | dict[str, float] | None = None,
         allowed_token_ids: list[int] | None = None,
         extra_args: dict[str, Any] | None = None,
         skip_clone: bool = False,
+        repetition_detection: RepetitionDetectionParams | None = None,
     ) -> "SamplingParams":
         if logit_bias is not None:
             # Convert token_id to integer
```
```diff
@@ -313,13 +354,13 @@ class SamplingParams(
             detokenize=detokenize,
             skip_special_tokens=skip_special_tokens,
             spaces_between_special_tokens=spaces_between_special_tokens,
-            truncate_prompt_tokens=truncate_prompt_tokens,
             output_kind=output_kind,
             structured_outputs=structured_outputs,
             logit_bias=logit_bias,
             allowed_token_ids=allowed_token_ids,
             extra_args=extra_args,
             skip_clone=skip_clone,
+            repetition_detection=repetition_detection,
         )
 
     def __post_init__(self) -> None:
```
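`from_optional` maps `None` arguments to the field defaults and, as the `# Convert token_id to integer` branch above shows, coerces string `logit_bias` keys to integer token ids; after this change it also threads `repetition_detection` through, while `truncate_prompt_tokens` is gone from the signature. A small sketch under the same import-path assumption:

```python
from vllm.sampling_params import RepetitionDetectionParams, SamplingParams

# None falls back to each field's default; string keys in logit_bias are
# converted to integer token ids before SamplingParams is constructed.
params = SamplingParams.from_optional(
    temperature=None,
    logit_bias={"42": -2.0},
    repetition_detection=RepetitionDetectionParams(
        max_pattern_size=4, min_count=2
    ),
)
```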
```diff
@@ -449,15 +490,6 @@ class SamplingParams(
                 parameter="prompt_logprobs",
                 value=self.prompt_logprobs,
             )
-        if self.truncate_prompt_tokens is not None and (
-            self.truncate_prompt_tokens == 0 or self.truncate_prompt_tokens < -1
-        ):
-            raise VLLMValidationError(
-                f"truncate_prompt_tokens must be an integer >= 1 or -1, "
-                f"got {self.truncate_prompt_tokens}",
-                parameter="truncate_prompt_tokens",
-                value=self.truncate_prompt_tokens,
-            )
         assert isinstance(self.stop_token_ids, list)
         if not all(isinstance(st_id, int) for st_id in self.stop_token_ids):
             raise ValueError(
```
```diff
@@ -678,9 +710,9 @@ class SamplingParams(
             return
 
         # Some sampling parameters are not yet compatible with spec decoding.
-        if self.min_tokens > 1 or self.min_p > _SAMPLING_EPS or self.logit_bias:
+        if self.min_p > _SAMPLING_EPS or self.logit_bias:
             raise ValueError(
-                "The min_tokens, min_p, and logit_bias sampling parameters "
+                "The min_p and logit_bias sampling parameters "
                 "are not yet supported with speculative decoding."
             )
 
```
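The net change: `min_tokens` no longer disqualifies a request from speculative decoding, while `min_p` and `logit_bias` still do. Restated as a hypothetical standalone predicate (the function name and the epsilon value are assumptions, not code from this file):

```python
_SAMPLING_EPS = 1e-5  # assumed; the module-level constant is not shown in this diff

def compatible_with_spec_decoding(params) -> bool:
    # Mirrors the relaxed check: only min_p and logit_bias are rejected now;
    # min_tokens > 1 no longer blocks speculative decoding.
    return not (params.min_p > _SAMPLING_EPS or params.logit_bias)
```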
```diff
@@ -835,11 +867,28 @@ class SamplingParams(
             f"skip_special_tokens={self.skip_special_tokens}, "
             "spaces_between_special_tokens="
             f"{self.spaces_between_special_tokens}, "
-            f"truncate_prompt_tokens={self.truncate_prompt_tokens}, "
             f"structured_outputs={self.structured_outputs}, "
             f"extra_args={self.extra_args})"
         )
 
+    @staticmethod
+    def for_sampler_warmup() -> "SamplingParams":
+        """Set parameters to exercise all sampler logic."""
+        return SamplingParams(
+            temperature=0.9,
+            top_p=0.9,
+            top_k=50,
+            min_p=0.1,
+            frequency_penalty=0.5,
+            presence_penalty=0.5,
+            repetition_penalty=1.2,
+            min_tokens=2,
+            logit_bias={0: -1.0, 1: 0.5},
+            _bad_words_token_ids=[[0], [1, 2]],
+            logprobs=5,
+            prompt_logprobs=1,
+        )
+
 
 class BeamSearchParams(
     msgspec.Struct,
```
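`for_sampler_warmup` packs penalties, min_p, logit bias, bad words, and logprobs into a single request, presumably so one warmup pass exercises every sampler branch at once. Calling it is straightforward (import path assumed as above):

```python
from vllm.sampling_params import SamplingParams

warmup = SamplingParams.for_sampler_warmup()
assert warmup.repetition_penalty == 1.2 and warmup.logprobs == 5
```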