Sync from v0.13
This commit is contained in:
236
vllm/utils/func_utils.py
Normal file
236
vllm/utils/func_utils.py
Normal file
@@ -0,0 +1,236 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""
|
||||
Contains helpers that are applied to functions.
|
||||
|
||||
This is similar in concept to the `functools` module.
|
||||
"""
|
||||
|
||||
import inspect
|
||||
import threading
|
||||
import warnings
|
||||
from collections.abc import Callable, Mapping
|
||||
from functools import lru_cache, partial, wraps
|
||||
from typing import Any, TypeVar
|
||||
|
||||
from typing_extensions import ParamSpec
|
||||
|
||||
from vllm.logger import init_logger
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
P = ParamSpec("P")
|
||||
T = TypeVar("T")
|
||||
F = TypeVar("F", bound=Callable[..., Any])
|
||||
|
||||
|
||||
def identity(value: T, **kwargs) -> T:
    """Return `value` unchanged.

    Any keyword arguments are accepted and ignored, which lets this function
    stand in for callables that take extra keyword options.
    """
    return value
|
||||
|
||||
|
||||
def run_once(f: Callable[P, None]) -> Callable[P, None]:
    """Wrap `f` so that its body is executed at most once.

    The first call forwards its arguments to `f` and returns whatever `f`
    returns; every later call does nothing and returns `None`.

    The returned wrapper exposes two attributes:

    - `has_run`: `False` until the first call starts, `True` afterwards.
    - `lock`: the `threading.Lock` guarding the first call.

    Note that `has_run` is set *before* `f` is invoked, so a first call that
    raises is not retried by later calls.

    Args:
        f: The function to guard.

    Returns:
        A wrapper carrying the metadata of `f` (name, docstring,
        `__wrapped__`) that invokes `f` at most once.
    """

    # `wraps` copies `f.__dict__` onto the wrapper; the state attributes are
    # assigned afterwards (below), so they always start from a clean slate.
    @wraps(f)
    def wrapper(*args: P.args, **kwargs: P.kwargs) -> None:
        # Fast path that avoids taking the lock once the flag is set.
        if wrapper.has_run:  # type: ignore[attr-defined]
            return

        with wrapper.lock:  # type: ignore[attr-defined]
            # Re-check under the lock: another caller may have set the flag
            # between the check above and acquiring the lock.
            if not wrapper.has_run:  # type: ignore[attr-defined]
                wrapper.has_run = True  # type: ignore[attr-defined]
                return f(*args, **kwargs)

    wrapper.has_run = False  # type: ignore[attr-defined]
    wrapper.lock = threading.Lock()  # type: ignore[attr-defined]
    return wrapper
|
||||
|
||||
|
||||
def deprecate_args(
    start_index: int,
    is_deprecated: bool | Callable[[], bool] = True,
    additional_message: str | None = None,
) -> Callable[[F], F]:
    """Build a decorator that warns when trailing arguments are passed
    positionally.

    The decorated function's positional-capable parameters (positional-only
    and positional-or-keyword) are listed in signature order. On each call,
    the parameters at indices `start_index` up to the number of positional
    arguments actually supplied are reported in a `DeprecationWarning`.
    The call is then forwarded to the original function unchanged.

    Args:
        start_index: Index of the first parameter whose positional use is
            deprecated.
        is_deprecated: Either a flag, or a zero-argument callable evaluated
            on every call, deciding whether the check is active.
        additional_message: Optional text appended to the warning message.

    Returns:
        A decorator to apply to the target function.
    """
    # Turn a plain flag into a zero-argument predicate so that `inner` can
    # treat both forms uniformly.
    if not callable(is_deprecated):
        is_deprecated = partial(identity, is_deprecated)

    positional_kinds = (
        inspect.Parameter.POSITIONAL_ONLY,
        inspect.Parameter.POSITIONAL_OR_KEYWORD,
    )

    def wrapper(fn: F) -> F:
        # Computed once at decoration time, in signature order.
        pos_kws = [
            name
            for name, param in inspect.signature(fn).parameters.items()
            if param.kind in positional_kinds
        ]

        @wraps(fn)
        def inner(*args, **kwargs):
            # The slice is only taken when the predicate says the
            # deprecation is currently in effect.
            flagged = pos_kws[start_index : len(args)] if is_deprecated() else []

            if flagged:
                msg = (
                    f"The positional arguments {flagged} are "
                    "deprecated and will be removed in a future update."
                )
                if additional_message is not None:
                    msg += f" {additional_message}"

                # `warn` must be called directly from `inner`: stacklevel is
                # counted from this frame.
                # NOTE(review): level 1 is `inner` itself, so 3 resolves to
                # the frame two above it - confirm this is the intended site.
                warnings.warn(DeprecationWarning(msg), stacklevel=3)

            return fn(*args, **kwargs)

        return inner  # type: ignore

    return wrapper
|
||||
|
||||
|
||||
def deprecate_kwargs(
    *kws: str,
    is_deprecated: bool | Callable[[], bool] = True,
    additional_message: str | None = None,
) -> Callable[[F], F]:
    """Build a decorator that warns when specific keyword arguments are used.

    On each call of the decorated function, the keyword arguments actually
    supplied are intersected with `kws`; if any match, a
    `DeprecationWarning` naming them is issued. The call is then forwarded
    to the original function unchanged.

    Args:
        *kws: Names of the deprecated keyword arguments.
        is_deprecated: Either a flag, or a zero-argument callable evaluated
            on every call, deciding whether the check is active.
        additional_message: Optional text appended to the warning message.

    Returns:
        A decorator to apply to the target function.
    """
    # Turn a plain flag into a zero-argument predicate so that `inner` can
    # treat both forms uniformly.
    if not callable(is_deprecated):
        is_deprecated = partial(identity, is_deprecated)

    deprecated_kws = set(kws)

    def wrapper(fn: F) -> F:
        @wraps(fn)
        def inner(*args, **kwargs):
            # Only look at the supplied keywords when the deprecation is
            # currently in effect.
            used = kwargs.keys() & deprecated_kws if is_deprecated() else None

            if used:
                msg = (
                    f"The keyword arguments {used} are "
                    "deprecated and will be removed in a future update."
                )
                if additional_message is not None:
                    msg += f" {additional_message}"

                # `warn` must be called directly from `inner`: stacklevel is
                # counted from this frame.
                # NOTE(review): level 1 is `inner` itself, so 3 resolves to
                # the frame two above it - confirm this is the intended site.
                warnings.warn(DeprecationWarning(msg), stacklevel=3)

            return fn(*args, **kwargs)

        return inner  # type: ignore

    return wrapper
|
||||
|
||||
|
||||
@lru_cache
def supports_kw(
    callable: Callable[..., object],
    kw_name: str,
    *,
    requires_kw_only: bool = False,
    allow_var_kwargs: bool = True,
) -> bool:
    """Check whether `kw_name` can be passed to `callable` as a keyword.

    Results are memoized with `lru_cache`, so every argument must be
    hashable.

    Args:
        callable: The callable whose signature is inspected.
        kw_name: The keyword name to test.
        requires_kw_only: If set, a parameter that can also be filled
            positionally (positional-or-keyword) does not count.
        allow_var_kwargs: If set, a name that is not directly nameable is
            still accepted when the signature ends with a `**kwargs`
            parameter (and `kw_name` is not the name of that parameter).

    Returns:
        `True` if `kw_name` may be supplied as a keyword argument under the
        given constraints, `False` otherwise. A callable without any
        parameters never supports a keyword.
    """
    params = inspect.signature(callable).parameters
    if not params:
        return False

    param_val = params.get(kw_name)
    if param_val is not None:
        # Keyword-only parameters satisfy both modes.
        if param_val.kind == inspect.Parameter.KEYWORD_ONLY:
            return True

        # Nameable, but can be filled positionally as well, which is not
        # acceptable when keyword-only is demanded.
        if param_val.kind == inspect.Parameter.POSITIONAL_OR_KEYWORD:
            return not requires_kw_only

        # Remaining kinds (positional-only, *args, **kwargs) cannot be
        # addressed by their own name; in particular, passing a
        # positional-only parameter by keyword raises TypeError unless a
        # **kwargs parameter absorbs it. Fall through to that check.

    # If we're okay with var-kwargs, it's supported as long as
    # the kw_name isn't something like *args, **kwargs
    if allow_var_kwargs:
        # Get the last param; type is ignored here because params is a proxy
        # mapping, but it wraps an ordered dict, and they appear in order.
        # Ref: https://docs.python.org/3/library/inspect.html#inspect.Signature.parameters
        last_param = params[next(reversed(params))]  # type: ignore
        return (
            last_param.kind == inspect.Parameter.VAR_KEYWORD
            and last_param.name != kw_name
        )

    return False
|
||||
|
||||
|
||||
def get_allowed_kwarg_only_overrides(
    callable: Callable[..., object],
    overrides: Mapping[str, object] | None,
    *,
    requires_kw_only: bool = True,
    allow_var_kwargs: bool = False,
) -> dict[str, Any]:
    """
    Filter `overrides` down to the entries that `callable` accepts as
    keyword arguments.

    Each candidate name is checked with `supports_kw`; entries that fail the
    check are dropped and reported in a single warning log. This is used in
    a few places to handle custom processor overrides for multimodal models,
    e.g., for profiling when processor options provided by the user
    may affect the number of mm tokens per instance.

    Args:
        callable: Callable whose signature decides which overrides survive.
        overrides: Potential overrides to be used when invoking the callable.
            `None` or an empty mapping yields an empty result.
        requires_kw_only: Only keep names that are keyword-only parameters
            (forwarded to `supports_kw`).
        allow_var_kwargs: Allows overrides that are expandable for var kwargs
            (forwarded to `supports_kw`).

    Returns:
        A new dictionary holding the accepted overrides, suitable for
        keyword expansion when invoking the callable.
    """
    if not overrides:
        return {}

    # Keep only what the callable can take as a keyword under the requested
    # constraints.
    accepted: dict[str, Any] = {}
    for name, value in overrides.items():
        if supports_kw(
            callable,
            name,
            requires_kw_only=requires_kw_only,
            allow_var_kwargs=allow_var_kwargs,
        ):
            accepted[name] = value

    # Anything that did not make it through is reported once.
    dropped_keys = overrides.keys() - accepted.keys()
    if dropped_keys:
        template = (
            (
                "The following intended overrides are not keyword-only args "
                "and will be dropped: %s"
            )
            if requires_kw_only
            else (
                "The following intended overrides are not keyword args "
                "and will be dropped: %s"
            )
        )
        logger.warning(template, dropped_keys)

    return accepted
|
||||
Reference in New Issue
Block a user