v1.0
This commit is contained in:
178
config/utils.py
Normal file
178
config/utils.py
Normal file
@@ -0,0 +1,178 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""Utility functions for vLLM config dataclasses."""
|
||||
|
||||
import ast
|
||||
import inspect
|
||||
import textwrap
|
||||
from collections.abc import Iterable
|
||||
from dataclasses import MISSING, Field, field, fields, is_dataclass, replace
|
||||
from itertools import pairwise
|
||||
from typing import TYPE_CHECKING, Any, Protocol, TypeVar
|
||||
|
||||
import regex as re
|
||||
from pydantic.fields import FieldInfo
|
||||
from typing_extensions import runtime_checkable
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from _typeshed import DataclassInstance
|
||||
else:
|
||||
DataclassInstance = Any
|
||||
|
||||
ConfigType = type[DataclassInstance]
|
||||
ConfigT = TypeVar("ConfigT", bound=ConfigType)
|
||||
|
||||
|
||||
def config(cls: ConfigT) -> ConfigT:
|
||||
"""
|
||||
A decorator that ensures all fields in a dataclass have default values
|
||||
and that each field has a docstring.
|
||||
|
||||
If a `ConfigT` is used as a CLI argument itself, the `type` keyword argument
|
||||
provided by `get_kwargs` will be
|
||||
`pydantic.TypeAdapter(ConfigT).validate_json(cli_arg)` which treats the
|
||||
`cli_arg` as a JSON string which gets validated by `pydantic`.
|
||||
|
||||
Config validation is performed by the tools/pre_commit/validate_config.py
|
||||
script, which is invoked during the pre-commit checks.
|
||||
"""
|
||||
return cls
|
||||
|
||||
|
||||
def get_field(cls: ConfigType, name: str) -> Field:
|
||||
"""Get the default factory field of a dataclass by name. Used for getting
|
||||
default factory fields in `EngineArgs`."""
|
||||
if not is_dataclass(cls):
|
||||
raise TypeError("The given class is not a dataclass.")
|
||||
cls_fields = {f.name: f for f in fields(cls)}
|
||||
if name not in cls_fields:
|
||||
raise ValueError(f"Field '{name}' not found in {cls.__name__}.")
|
||||
named_field: Field = cls_fields[name]
|
||||
if (default_factory := named_field.default_factory) is not MISSING:
|
||||
return field(default_factory=default_factory)
|
||||
if (default := named_field.default) is not MISSING:
|
||||
if isinstance(default, FieldInfo):
|
||||
# Handle pydantic.Field defaults
|
||||
if default.default_factory is not None:
|
||||
return field(default_factory=default.default_factory)
|
||||
else:
|
||||
default = default.default
|
||||
return field(default=default)
|
||||
|
||||
raise ValueError(
|
||||
f"{cls.__name__}.{name} must have a default value or default factory."
|
||||
)
|
||||
|
||||
|
||||
def getattr_iter(object: object, names: Iterable[str], default: Any) -> Any:
|
||||
"""
|
||||
A helper function that retrieves an attribute from an object which may
|
||||
have multiple possible names. This is useful when fetching attributes from
|
||||
arbitrary `transformers.PretrainedConfig` instances.
|
||||
"""
|
||||
for name in names:
|
||||
if hasattr(object, name):
|
||||
return getattr(object, name)
|
||||
return default
|
||||
|
||||
|
||||
def contains_object_print(text: str) -> bool:
|
||||
"""
|
||||
Check if the text looks like a printed Python object, e.g.
|
||||
contains any substring matching the pattern: "at 0xFFFFFFF>"
|
||||
We match against 0x followed by 2-16 hex chars (there's
|
||||
a max of 16 on a 64-bit system).
|
||||
|
||||
Args:
|
||||
text (str): The text to check
|
||||
|
||||
Returns:
|
||||
result (bool): `True` if a match is found, `False` otherwise.
|
||||
"""
|
||||
pattern = r"at 0x[a-fA-F0-9]{2,16}>"
|
||||
match = re.search(pattern, text)
|
||||
return match is not None
|
||||
|
||||
|
||||
def assert_hashable(text: str) -> bool:
|
||||
if not contains_object_print(text):
|
||||
return True
|
||||
raise AssertionError(
|
||||
f"vLLM tried to hash some configs that may have Python objects ids "
|
||||
f"in them. This is a bug, please file an issue. "
|
||||
f"Text being hashed: {text}"
|
||||
)
|
||||
|
||||
|
||||
def get_attr_docs(cls: type[Any]) -> dict[str, str]:
|
||||
"""
|
||||
Get any docstrings placed after attribute assignments in a class body.
|
||||
|
||||
https://davidism.com/mit-license/
|
||||
"""
|
||||
|
||||
cls_node = ast.parse(textwrap.dedent(inspect.getsource(cls))).body[0]
|
||||
|
||||
if not isinstance(cls_node, ast.ClassDef):
|
||||
raise TypeError("Given object was not a class.")
|
||||
|
||||
out = {}
|
||||
|
||||
# Consider each pair of nodes.
|
||||
for a, b in pairwise(cls_node.body):
|
||||
# Must be an assignment then a constant string.
|
||||
if (
|
||||
not isinstance(a, (ast.Assign, ast.AnnAssign))
|
||||
or not isinstance(b, ast.Expr)
|
||||
or not isinstance(b.value, ast.Constant)
|
||||
or not isinstance(b.value.value, str)
|
||||
):
|
||||
continue
|
||||
|
||||
doc = inspect.cleandoc(b.value.value)
|
||||
|
||||
# An assignment can have multiple targets (a = b = v), but an
|
||||
# annotated assignment only has one target.
|
||||
targets = a.targets if isinstance(a, ast.Assign) else [a.target]
|
||||
|
||||
for target in targets:
|
||||
# Must be assigning to a plain name.
|
||||
if not isinstance(target, ast.Name):
|
||||
continue
|
||||
|
||||
out[target.id] = doc
|
||||
|
||||
return out
|
||||
|
||||
|
||||
def is_init_field(cls: ConfigType, name: str) -> bool:
|
||||
return next(f for f in fields(cls) if f.name == name).init
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class SupportsHash(Protocol):
|
||||
def compute_hash(self) -> str: ...
|
||||
|
||||
|
||||
class SupportsMetricsInfo(Protocol):
|
||||
def metrics_info(self) -> dict[str, str]: ...
|
||||
|
||||
|
||||
def update_config(config: ConfigT, overrides: dict[str, Any]) -> ConfigT:
|
||||
processed_overrides = {}
|
||||
for field_name, value in overrides.items():
|
||||
assert hasattr(config, field_name), (
|
||||
f"{type(config)} has no field `{field_name}`"
|
||||
)
|
||||
current_value = getattr(config, field_name)
|
||||
if is_dataclass(current_value) and not is_dataclass(value):
|
||||
assert isinstance(value, dict), (
|
||||
f"Overrides to {type(config)}.{field_name} must be a dict"
|
||||
f" or {type(current_value)}, but got {type(value)}"
|
||||
)
|
||||
value = update_config(
|
||||
current_value, # type: ignore[type-var]
|
||||
value,
|
||||
)
|
||||
processed_overrides[field_name] = value
|
||||
return replace(config, **processed_overrides)
|
||||
Reference in New Issue
Block a user