# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Utility functions for vLLM config dataclasses."""

import ast
import inspect
import textwrap
from collections.abc import Iterable
from dataclasses import MISSING, Field, field, fields, is_dataclass, replace
from itertools import pairwise
from typing import TYPE_CHECKING, Any, Protocol, TypeVar

import regex as re
from pydantic.fields import FieldInfo
from typing_extensions import runtime_checkable

if TYPE_CHECKING:
    from _typeshed import DataclassInstance
else:
    DataclassInstance = Any

ConfigType = type[DataclassInstance]
ConfigT = TypeVar("ConfigT", bound=ConfigType)


def config(cls: ConfigT) -> ConfigT:
    """
    A decorator that ensures all fields in a dataclass have default values
    and that each field has a docstring.

    If a `ConfigT` is used as a CLI argument itself, the `type` keyword argument
    provided by `get_kwargs` will be
    `pydantic.TypeAdapter(ConfigT).validate_json(cli_arg)` which treats the
    `cli_arg` as a JSON string which gets validated by `pydantic`.

    Config validation is performed by the tools/pre_commit/validate_config.py
    script, which is invoked during the pre-commit checks. At runtime this
    decorator is a no-op: it returns `cls` unchanged.
    """
    return cls


def get_field(cls: ConfigType, name: str) -> Field:
    """Get a fresh `dataclasses.field` mirroring the default of a field.

    Used for getting default factory fields in `EngineArgs`.

    The returned field carries either the original field's
    `default_factory` or its `default`. If the default is a pydantic
    `FieldInfo` (the attribute was declared with `pydantic.Field(...)`),
    its `default_factory` (if set) or its `default` is unwrapped first.

    Args:
        cls: The dataclass to inspect.
        name: The name of the field on `cls`.

    Returns:
        A new `dataclasses.Field` with the same default or default factory.

    Raises:
        TypeError: If `cls` is not a dataclass.
        ValueError: If `cls` has no field called `name`, or if that field has
            neither a default value nor a default factory.
    """
    if not is_dataclass(cls):
        raise TypeError("The given class is not a dataclass.")
    cls_fields = {f.name: f for f in fields(cls)}
    if name not in cls_fields:
        raise ValueError(f"Field '{name}' not found in {cls.__name__}.")
    named_field: Field = cls_fields[name]
    if (default_factory := named_field.default_factory) is not MISSING:
        return field(default_factory=default_factory)
    if (default := named_field.default) is not MISSING:
        if isinstance(default, FieldInfo):
            # Handle pydantic.Field defaults
            if default.default_factory is not None:
                return field(default_factory=default.default_factory)
            else:
                default = default.default
        return field(default=default)

    raise ValueError(
        f"{cls.__name__}.{name} must have a default value or default factory."
    )


def getattr_iter(object: object, names: Iterable[str], default: Any) -> Any:
    """
    A helper function that retrieves an attribute from an object which may
    have multiple possible names. This is useful when fetching attributes from
    arbitrary `transformers.PretrainedConfig` instances.

    Args:
        object: The object to read the attribute from.
        names: Candidate attribute names, tried in order.
        default: Value returned when none of `names` is present.

    Returns:
        The value of the first attribute in `names` that exists on `object`,
        or `default` if none exist.
    """
    for name in names:
        if hasattr(object, name):
            return getattr(object, name)
    return default


def contains_object_print(text: str) -> bool:
    """
    Check if the text looks like a printed Python object, e.g.
    contains any substring matching the pattern: "at 0xFFFFFFF>"
    We match against 0x followed by 2-16 hex chars (there's
    a max of 16 on a 64-bit system).

    Args:
        text (str): The text to check

    Returns:
        result (bool): `True` if a match is found, `False` otherwise.
    """
    pattern = r"at 0x[a-fA-F0-9]{2,16}>"
    match = re.search(pattern, text)
    return match is not None


def assert_hashable(text: str) -> bool:
    """Check that `text` contains no printed Python object addresses.

    Returns:
        `True` if `text` does not look like it contains an object repr
        (see `contains_object_print`).

    Raises:
        AssertionError: If `text` appears to contain an object repr such as
            "<... at 0x7f...>".
    """
    if not contains_object_print(text):
        return True
    raise AssertionError(
        f"vLLM tried to hash some configs that may have Python objects ids "
        f"in them. This is a bug, please file an issue. "
        f"Text being hashed: {text}"
    )


def get_attr_docs(cls: type[Any]) -> dict[str, str]:
    """
    Get any docstrings placed after attribute assignments in a class body.

    Only assignments to plain names are considered (`x = ...` or
    `x: T = ...`). When one assignment has several targets (`a = b = v`),
    every target receives the same docstring. The class source must be
    retrievable with `inspect.getsource`.

    https://davidism.com/mit-license/

    Returns:
        A mapping from attribute name to its cleaned docstring.

    Raises:
        TypeError: If the parsed source does not start with a class
            definition.
    """

    cls_node = ast.parse(textwrap.dedent(inspect.getsource(cls))).body[0]

    if not isinstance(cls_node, ast.ClassDef):
        raise TypeError("Given object was not a class.")

    out = {}

    # Consider each pair of nodes.
    for a, b in pairwise(cls_node.body):
        # Must be an assignment then a constant string.
        if (
            not isinstance(a, (ast.Assign, ast.AnnAssign))
            or not isinstance(b, ast.Expr)
            or not isinstance(b.value, ast.Constant)
            or not isinstance(b.value.value, str)
        ):
            continue

        doc = inspect.cleandoc(b.value.value)

        # An assignment can have multiple targets (a = b = v), but an
        # annotated assignment only has one target.
        targets = a.targets if isinstance(a, ast.Assign) else [a.target]

        for target in targets:
            # Must be assigning to a plain name.
            if not isinstance(target, ast.Name):
                continue

            out[target.id] = doc

    return out


def is_init_field(cls: ConfigType, name: str) -> bool:
    """Return whether the dataclass field `name` is an `__init__` parameter.

    Args:
        cls: The dataclass to inspect.
        name: The name of the field on `cls`.

    Returns:
        The `init` flag of the matching `dataclasses.Field`.

    Raises:
        TypeError: If `cls` is not a dataclass (from `dataclasses.fields`).
        ValueError: If `cls` has no field called `name`.
    """
    # Use a sentinel default so an unknown name does not leak StopIteration,
    # which would turn into a RuntimeError if raised inside a generator.
    named_field = next((f for f in fields(cls) if f.name == name), None)
    if named_field is None:
        raise ValueError(f"Field '{name}' not found in {cls.__name__}.")
    return named_field.init


@runtime_checkable
class SupportsHash(Protocol):
    """Structural type for objects that expose `compute_hash() -> str`.

    Decorated with `runtime_checkable`, so it can be used with `isinstance`.
    """

    def compute_hash(self) -> str: ...


class SupportsMetricsInfo(Protocol):
    """Structural type for objects that expose `metrics_info()`.

    The method is annotated to return a `dict[str, str]`.
    """

    def metrics_info(self) -> dict[str, str]: ...


def update_config(config: ConfigT, overrides: dict[str, Any]) -> ConfigT:
    """Return a copy of `config` with `overrides` applied.

    The copy is built with `dataclasses.replace`. When the current value of an
    overridden field is itself a dataclass and the override is not, the
    override must be a dict. It is then applied recursively to that nested
    dataclass.

    Args:
        config: The dataclass instance to copy.
        overrides: Mapping from field name to new value (or to a dict of
            nested overrides for dataclass-valued fields).

    Returns:
        A new instance of the same type as `config`.

    Raises:
        AssertionError: If an override names an attribute `config` does not
            have, or if a nested-dataclass override is neither a dataclass
            nor a dict.
    """
    processed_overrides = {}
    for field_name, value in overrides.items():
        # Explicit raises (instead of `assert`) so validation still happens
        # under `python -O`; the exception type matches the previous asserts.
        if not hasattr(config, field_name):
            raise AssertionError(f"{type(config)} has no field `{field_name}`")
        current_value = getattr(config, field_name)
        if is_dataclass(current_value) and not is_dataclass(value):
            if not isinstance(value, dict):
                raise AssertionError(
                    f"Overrides to {type(config)}.{field_name} must be a dict"
                    f" or {type(current_value)}, but got {type(value)}"
                )
            value = update_config(
                current_value,  # type: ignore[type-var]
                value,
            )
        processed_overrides[field_name] = value
    return replace(config, **processed_overrides)