146 lines
4.7 KiB
Python
146 lines
4.7 KiB
Python
# SPDX-License-Identifier: Apache-2.0
|
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
|
|
|
import ast
|
|
import inspect
|
|
import textwrap
|
|
from dataclasses import MISSING, Field, field, fields, is_dataclass
|
|
from typing import TYPE_CHECKING, Any, TypeVar
|
|
|
|
import regex as re
|
|
|
|
if TYPE_CHECKING:
    # Imported only under TYPE_CHECKING, so this import never runs at
    # runtime.
    from _typeshed import DataclassInstance

    # For static analysis: a class object whose instances are dataclasses.
    ConfigType = type[DataclassInstance]
else:
    # At runtime the alias is simply `type`.
    ConfigType = type

# Type variable bound to `ConfigType`; used by `config` so that the decorated
# class keeps its own type.
ConfigT = TypeVar("ConfigT", bound=ConfigType)
|
|
|
|
|
|
def config(cls: ConfigT) -> ConfigT:
    """
    Mark a dataclass as a config class.

    Config classes are expected to give every field a default value and a
    docstring. The decorator itself returns `cls` unchanged: those
    requirements are checked by the tools/validate_config.py script, which is
    invoked during the pre-commit checks.

    If a `ConfigT` is used as a CLI argument itself, the `type` keyword
    argument provided by `get_kwargs` will be
    `pydantic.TypeAdapter(ConfigT).validate_json(cli_arg)`, i.e. the `cli_arg`
    is treated as a JSON string which gets validated by `pydantic`.

    Args:
        cls: The dataclass to mark.

    Returns:
        The very same class object that was passed in.
    """
    return cls
|
|
|
|
|
|
def get_field(cls: ConfigType, name: str) -> Field:
    """
    Build a fresh dataclass `field()` that carries over the default of the
    field `name` of `cls`. Used for getting default factory fields in
    `EngineArgs`.

    Args:
        cls: The dataclass that declares the field.
        name: The name of the field to look up.

    Returns:
        A new `Field` created with the same `default_factory` as the named
        field if it has one, otherwise with the same `default`.

    Raises:
        TypeError: If `cls` is not a dataclass.
        ValueError: If `cls` has no field called `name`, or if that field has
            neither a default value nor a default factory.
    """
    if not is_dataclass(cls):
        raise TypeError("The given class is not a dataclass.")

    # Field names are unique within a dataclass, so the first hit is the hit.
    found = next((f for f in fields(cls) if f.name == name), None)
    if found is None:
        raise ValueError(f"Field '{name}' not found in {cls.__name__}.")

    # A default factory takes precedence over a plain default value.
    factory = found.default_factory
    if factory is not MISSING:
        return field(default_factory=factory)

    value = found.default
    if value is not MISSING:
        return field(default=value)

    raise ValueError(
        f"{cls.__name__}.{name} must have a default value or default factory.")
|
|
|
|
|
|
def contains_object_print(text: str) -> bool:
    """
    Tell whether `text` appears to contain the default repr of a Python
    object, i.e. a substring shaped like "at 0xFFFFFFF>".

    The address part is matched as 0x followed by 2-16 hex chars (there's
    a max of 16 on a 64-bit system).

    Args:
        text (str): The text to check

    Returns:
        result (bool): `True` if a match is found, `False` otherwise.
    """
    return re.search(r'at 0x[a-fA-F0-9]{2,16}>', text) is not None
|
|
|
|
|
|
def assert_hashable(text: str) -> bool:
    """
    Check that text being hashed does not look like it embeds a printed
    Python object (see `contains_object_print`).

    Args:
        text (str): The text being hashed.

    Returns:
        result (bool): `True` when no printed object is detected.

    Raises:
        AssertionError: If `text` looks like it contains a printed object.
    """
    if contains_object_print(text):
        raise AssertionError(
            f"vLLM tried to hash some configs that may have Python objects ids "
            f"in them. This is a bug, please file an issue. "
            f"Text being hashed: {text}")
    return True
|
|
|
|
|
|
def get_attr_docs(cls: type[Any]) -> dict[str, str]:
    """
    Get any docstrings placed after attribute assignments in a class body.

    The source of `cls` is parsed with `ast`. Every string-literal expression
    statement that directly follows an assignment (`a = 1`, `a = b = 1` or
    `a: int = 1`) is taken as the docstring of the assigned name(s).

    https://davidism.com/mit-license/

    Args:
        cls: The class whose source should be inspected.

    Returns:
        A mapping from attribute name to its docstring, cleaned up with
        `inspect.cleandoc`. Attributes without a docstring are omitted.

    Raises:
        TypeError: If the first node parsed from the retrieved source is not
            a class definition.
    """
    try:
        source = inspect.getsource(cls)
    except (OSError, KeyError, TypeError):
        # HACK: Python 3.13+ workaround - set missing __firstlineno__
        # Workaround can be removed after we upgrade to pydantic==2.12.0
        import tokenize

        # Anchor on the `class` keyword and require a word boundary after the
        # name, so that looking for `Foo` does not stop at `class FooBar:`.
        header = re.compile(rf"\s*class\s+{re.escape(cls.__name__)}\b")
        # `tokenize.open` decodes the file the way Python itself would
        # (encoding cookie / UTF-8) instead of using the locale encoding.
        with tokenize.open(inspect.getfile(cls)) as f:
            for lineno, line in enumerate(f, start=1):
                if header.match(line) and ":" in line:
                    cls.__firstlineno__ = lineno
                    break
        source = inspect.getsource(cls)

    cls_node = ast.parse(textwrap.dedent(source)).body[0]

    if not isinstance(cls_node, ast.ClassDef):
        raise TypeError("Given object was not a class.")

    out: dict[str, str] = {}

    # Walk over each pair of consecutive statements in the class body.
    body = cls_node.body
    for stmt, following in zip(body, body[1:]):
        # Must be an assignment then a constant string.
        if not isinstance(stmt, (ast.Assign, ast.AnnAssign)):
            continue
        if not (isinstance(following, ast.Expr)
                and isinstance(following.value, ast.Constant)
                and isinstance(following.value.value, str)):
            continue

        doc = inspect.cleandoc(following.value.value)

        # An assignment can have multiple targets (a = b = v), but an
        # annotated assignment only has one target.
        if isinstance(stmt, ast.Assign):
            targets = stmt.targets
        else:
            targets = [stmt.target]

        for target in targets:
            # Only plain names are documented (not `a.b = v` or `a[0] = v`).
            if isinstance(target, ast.Name):
                out[target.id] = doc

    return out
|
|
|
|
|
|
def is_init_field(cls: ConfigType, name: str) -> bool:
    """
    Report whether the dataclass field `name` is a parameter of `__init__`.

    Args:
        cls: The dataclass to inspect.
        name: The name of the field to look up.

    Returns:
        The `init` flag of the matching field.

    Raises:
        ValueError: If `cls` has no field called `name`.
    """
    for f in fields(cls):
        if f.name == name:
            return f.init
    # Fail with an explicit error (same message as `get_field`) rather than
    # letting a bare `next()` leak StopIteration to the caller.
    raise ValueError(f"Field '{name}' not found in {cls.__name__}.")
|