v1.0
This commit is contained in:
82
utils/__init__.py
Normal file
82
utils/__init__.py
Normal file
@@ -0,0 +1,82 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import uuid
|
||||
import warnings
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.logger import init_logger
|
||||
|
||||
_DEPRECATED_MAPPINGS = {
|
||||
"cprofile": "profiling",
|
||||
"cprofile_context": "profiling",
|
||||
# Used by lm-eval
|
||||
"get_open_port": "network_utils",
|
||||
}
|
||||
|
||||
|
||||
def __getattr__(name: str) -> Any: # noqa: D401 - short deprecation docstring
|
||||
"""Module-level getattr to handle deprecated utilities."""
|
||||
if name in _DEPRECATED_MAPPINGS:
|
||||
submodule_name = _DEPRECATED_MAPPINGS[name]
|
||||
warnings.warn(
|
||||
f"vllm.utils.{name} is deprecated and will be removed in a future version. "
|
||||
f"Use vllm.utils.{submodule_name}.{name} instead.",
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
module = __import__(f"vllm.utils.{submodule_name}", fromlist=[submodule_name])
|
||||
return getattr(module, name)
|
||||
raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
|
||||
|
||||
|
||||
def __dir__() -> list[str]:
    """List module attributes plus deprecated names, for dir()/tab-completion."""
    # Concatenate (rather than union) so behavior matches a plain key merge.
    return sorted([*globals(), *_DEPRECATED_MAPPINGS])
|
||||
|
||||
|
||||
# Module-level logger for vllm.utils.
logger = init_logger(__name__)

# Constants related to forcing the attention backend selection

# String name of register which may be set in order to
# force auto-selection of attention backend by Attention
# wrapper
STR_BACKEND_ENV_VAR: str = "VLLM_ATTENTION_BACKEND"

# Possible string values of STR_BACKEND_ENV_VAR
# register, corresponding to possible backends
STR_FLASHINFER_ATTN_VAL: str = "FLASHINFER"
STR_XFORMERS_ATTN_VAL: str = "XFORMERS"
STR_FLASH_ATTN_VAL: str = "FLASH_ATTN"
# Sentinel value representing an unrecognized/invalid backend selection.
STR_INVALID_VAL: str = "INVALID"
|
||||
|
||||
|
||||
def random_uuid() -> str:
    """Return a random 32-character lowercase hex string (UUID4 without dashes).

    `uuid.UUID.hex` is already a `str`, so the previous `str(...)` wrapper
    was redundant and has been removed; behavior is unchanged.
    """
    return uuid.uuid4().hex
|
||||
|
||||
|
||||
def length_from_prompt_token_ids_or_embeds(
|
||||
prompt_token_ids: list[int] | None,
|
||||
prompt_embeds: torch.Tensor | None,
|
||||
) -> int:
|
||||
"""Calculate the request length (in number of tokens) give either
|
||||
prompt_token_ids or prompt_embeds.
|
||||
"""
|
||||
prompt_token_len = None if prompt_token_ids is None else len(prompt_token_ids)
|
||||
prompt_embeds_len = None if prompt_embeds is None else len(prompt_embeds)
|
||||
|
||||
if prompt_token_len is None:
|
||||
if prompt_embeds_len is None:
|
||||
raise ValueError("Neither prompt_token_ids nor prompt_embeds were defined.")
|
||||
return prompt_embeds_len
|
||||
else:
|
||||
if prompt_embeds_len is not None and prompt_embeds_len != prompt_token_len:
|
||||
raise ValueError(
|
||||
"Prompt token ids and prompt embeds had different lengths"
|
||||
f" prompt_token_ids={prompt_token_len}"
|
||||
f" prompt_embeds={prompt_embeds_len}"
|
||||
)
|
||||
return prompt_token_len
|
||||
BIN
utils/__pycache__/__init__.cpython-312.pyc
Normal file
BIN
utils/__pycache__/__init__.cpython-312.pyc
Normal file
Binary file not shown.
BIN
utils/__pycache__/argparse_utils.cpython-312.pyc
Normal file
BIN
utils/__pycache__/argparse_utils.cpython-312.pyc
Normal file
Binary file not shown.
BIN
utils/__pycache__/async_utils.cpython-312.pyc
Normal file
BIN
utils/__pycache__/async_utils.cpython-312.pyc
Normal file
Binary file not shown.
BIN
utils/__pycache__/cache.cpython-312.pyc
Normal file
BIN
utils/__pycache__/cache.cpython-312.pyc
Normal file
Binary file not shown.
BIN
utils/__pycache__/collection_utils.cpython-312.pyc
Normal file
BIN
utils/__pycache__/collection_utils.cpython-312.pyc
Normal file
Binary file not shown.
BIN
utils/__pycache__/counter.cpython-312.pyc
Normal file
BIN
utils/__pycache__/counter.cpython-312.pyc
Normal file
Binary file not shown.
BIN
utils/__pycache__/deep_gemm.cpython-312.pyc
Normal file
BIN
utils/__pycache__/deep_gemm.cpython-312.pyc
Normal file
Binary file not shown.
BIN
utils/__pycache__/flashinfer.cpython-312.pyc
Normal file
BIN
utils/__pycache__/flashinfer.cpython-312.pyc
Normal file
Binary file not shown.
BIN
utils/__pycache__/func_utils.cpython-312.pyc
Normal file
BIN
utils/__pycache__/func_utils.cpython-312.pyc
Normal file
Binary file not shown.
BIN
utils/__pycache__/gc_utils.cpython-312.pyc
Normal file
BIN
utils/__pycache__/gc_utils.cpython-312.pyc
Normal file
Binary file not shown.
BIN
utils/__pycache__/hashing.cpython-312.pyc
Normal file
BIN
utils/__pycache__/hashing.cpython-312.pyc
Normal file
Binary file not shown.
BIN
utils/__pycache__/import_utils.cpython-312.pyc
Normal file
BIN
utils/__pycache__/import_utils.cpython-312.pyc
Normal file
Binary file not shown.
BIN
utils/__pycache__/jsontree.cpython-312.pyc
Normal file
BIN
utils/__pycache__/jsontree.cpython-312.pyc
Normal file
Binary file not shown.
BIN
utils/__pycache__/math_utils.cpython-312.pyc
Normal file
BIN
utils/__pycache__/math_utils.cpython-312.pyc
Normal file
Binary file not shown.
BIN
utils/__pycache__/mem_constants.cpython-312.pyc
Normal file
BIN
utils/__pycache__/mem_constants.cpython-312.pyc
Normal file
Binary file not shown.
BIN
utils/__pycache__/mem_utils.cpython-312.pyc
Normal file
BIN
utils/__pycache__/mem_utils.cpython-312.pyc
Normal file
Binary file not shown.
BIN
utils/__pycache__/nccl.cpython-312.pyc
Normal file
BIN
utils/__pycache__/nccl.cpython-312.pyc
Normal file
Binary file not shown.
BIN
utils/__pycache__/network_utils.cpython-312.pyc
Normal file
BIN
utils/__pycache__/network_utils.cpython-312.pyc
Normal file
Binary file not shown.
BIN
utils/__pycache__/platform_utils.cpython-312.pyc
Normal file
BIN
utils/__pycache__/platform_utils.cpython-312.pyc
Normal file
Binary file not shown.
BIN
utils/__pycache__/profiling.cpython-312.pyc
Normal file
BIN
utils/__pycache__/profiling.cpython-312.pyc
Normal file
Binary file not shown.
BIN
utils/__pycache__/registry.cpython-312.pyc
Normal file
BIN
utils/__pycache__/registry.cpython-312.pyc
Normal file
Binary file not shown.
BIN
utils/__pycache__/serial_utils.cpython-312.pyc
Normal file
BIN
utils/__pycache__/serial_utils.cpython-312.pyc
Normal file
Binary file not shown.
BIN
utils/__pycache__/system_utils.cpython-312.pyc
Normal file
BIN
utils/__pycache__/system_utils.cpython-312.pyc
Normal file
Binary file not shown.
BIN
utils/__pycache__/tensor_schema.cpython-312.pyc
Normal file
BIN
utils/__pycache__/tensor_schema.cpython-312.pyc
Normal file
Binary file not shown.
BIN
utils/__pycache__/torch_utils.cpython-312.pyc
Normal file
BIN
utils/__pycache__/torch_utils.cpython-312.pyc
Normal file
Binary file not shown.
487
utils/argparse_utils.py
Normal file
487
utils/argparse_utils.py
Normal file
@@ -0,0 +1,487 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""Argument parsing utilities for vLLM."""
|
||||
|
||||
import json
|
||||
import sys
|
||||
import textwrap
|
||||
from argparse import (
|
||||
Action,
|
||||
ArgumentDefaultsHelpFormatter,
|
||||
ArgumentParser,
|
||||
ArgumentTypeError,
|
||||
Namespace,
|
||||
RawDescriptionHelpFormatter,
|
||||
_ArgumentGroup,
|
||||
)
|
||||
from collections import defaultdict
|
||||
from typing import Any
|
||||
|
||||
import regex as re
|
||||
import yaml
|
||||
|
||||
from vllm.logger import init_logger
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class SortedHelpFormatter(ArgumentDefaultsHelpFormatter, RawDescriptionHelpFormatter):
    """Help formatter that lists arguments sorted by their option strings.

    Help text is also reflowed: line breaks inside a sentence are joined,
    blank lines (paragraph breaks) start new lines, and every resulting
    line is wrapped to the available width.
    """

    def _split_lines(self, text, width):
        """Reflow *text* into a list of lines no wider than *width*.

        1. A lone newline (not part of a blank line) plus any following
           whitespace is collapsed to a single space.
        2. Runs of two or more newlines delimit separate paragraphs.
        3. Each paragraph is word-wrapped to *width*.
        """
        joined = re.sub(r"(?<!\n)\n(?!\n)\s*", " ", text)
        paragraphs = re.split(r"\n{2,}\s*", joined)
        wrapped: list[str] = []
        for paragraph in paragraphs:
            wrapped.extend(textwrap.wrap(paragraph, width))
        return wrapped

    def add_arguments(self, actions):
        """Emit *actions* into the help output in option-string order."""
        super().add_arguments(sorted(actions, key=lambda a: a.option_strings))
|
||||
|
||||
|
||||
class FlexibleArgumentParser(ArgumentParser):
    """ArgumentParser that allows both underscore and dash in names."""

    # Actions registered with deprecated=True (pre-3.13 emulation below).
    _deprecated: set[Action] = set()
    # Epilog text explaining the dotted/JSON CLI argument syntax.
    _json_tip: str = (
        "When passing JSON CLI arguments, the following sets of arguments "
        "are equivalent:\n"
        ' --json-arg \'{"key1": "value1", "key2": {"key3": "value2"}}\'\n'
        " --json-arg.key1 value1 --json-arg.key2.key3 value2\n\n"
        "Additionally, list elements can be passed individually using +:\n"
        ' --json-arg \'{"key4": ["value3", "value4", "value5"]}\'\n'
        " --json-arg.key4+ value3 --json-arg.key4+='value4,value5'\n\n"
    )
    # Keyword captured from a "--help=<keyword>" argument; drives the
    # filtered help output in format_help(). Class-level (shared) state.
    _search_keyword: str | None = None

    def __init__(self, *args, **kwargs):
        """Create the parser, defaulting to SortedHelpFormatter for help."""
        # Set the default "formatter_class" to SortedHelpFormatter
        if "formatter_class" not in kwargs:
            kwargs["formatter_class"] = SortedHelpFormatter
        # Pop kwarg "add_json_tip" to control whether to add the JSON tip
        self.add_json_tip = kwargs.pop("add_json_tip", True)
        super().__init__(*args, **kwargs)

    if sys.version_info < (3, 13):
        # Enable the deprecated kwarg for Python 3.12 and below
        # (argparse gained native `deprecated=` support in 3.13).

        def parse_known_args(self, args=None, namespace=None):
            """Parse args, warning once for each deprecated argument used."""
            if args is not None and "--disable-log-requests" in args:
                # Special case warning because the warning below won't trigger
                # if --disable-log-requests because its value is default.
                logger.warning_once(
                    "argument '--disable-log-requests' is deprecated and "
                    "replaced with '--enable-log-requests'. This will be "
                    "removed in v0.12.0."
                )
            namespace, args = super().parse_known_args(args, namespace)
            # A deprecated arg was "used" if its parsed value differs from
            # its default.
            for action in FlexibleArgumentParser._deprecated:
                if (
                    hasattr(namespace, dest := action.dest)
                    and getattr(namespace, dest) != action.default
                ):
                    logger.warning_once("argument '%s' is deprecated", dest)
            return namespace, args

        def add_argument(self, *args, **kwargs):
            """add_argument that accepts and records a `deprecated` kwarg."""
            deprecated = kwargs.pop("deprecated", False)
            action = super().add_argument(*args, **kwargs)
            if deprecated:
                FlexibleArgumentParser._deprecated.add(action)
            return action

        class _FlexibleArgumentGroup(_ArgumentGroup):
            # Argument group mirroring the `deprecated` kwarg handling above.
            def add_argument(self, *args, **kwargs):
                deprecated = kwargs.pop("deprecated", False)
                action = super().add_argument(*args, **kwargs)
                if deprecated:
                    FlexibleArgumentParser._deprecated.add(action)
                return action

        def add_argument_group(self, *args, **kwargs):
            """Create an argument group that understands `deprecated`."""
            group = self._FlexibleArgumentGroup(self, *args, **kwargs)
            self._action_groups.append(group)
            return group

    def format_help(self):
        """Render help, optionally filtered by a '--help=<keyword>' search."""
        # Only use custom help formatting for bottom level parsers
        if self._subparsers is not None:
            return super().format_help()

        formatter = self._get_formatter()

        # Handle keyword search of the args
        if (search_keyword := self._search_keyword) is not None:
            # Normalise the search keyword
            search_keyword = search_keyword.lower().replace("_", "-")
            # Return full help if searching for 'all'
            if search_keyword == "all":
                self.epilog = self._json_tip
                return super().format_help()

            # Return group help if searching for a group title
            for group in self._action_groups:
                if group.title and group.title.lower() == search_keyword:
                    formatter.start_section(group.title)
                    formatter.add_text(group.description)
                    formatter.add_arguments(group._group_actions)
                    formatter.end_section()
                    formatter.add_text(self._json_tip)
                    return formatter.format_help()

            # Return matched args if searching for an arg name
            matched_actions = []
            for group in self._action_groups:
                for action in group._group_actions:
                    # search option name
                    if any(
                        search_keyword in opt.lower() for opt in action.option_strings
                    ):
                        matched_actions.append(action)
            if matched_actions:
                formatter.start_section(f"Arguments matching '{search_keyword}'")
                formatter.add_arguments(matched_actions)
                formatter.end_section()
                formatter.add_text(self._json_tip)
                return formatter.format_help()

            # No match found
            formatter.add_text(
                f"No group or arguments matching '{search_keyword}'.\n"
                "Use '--help' to see available groups or "
                "'--help=all' to see all available parameters."
            )
            return formatter.format_help()

        # usage
        formatter.add_usage(self.usage, self._actions, self._mutually_exclusive_groups)

        # description
        formatter.add_text(self.description)

        # positionals, optionals and user-defined groups
        formatter.start_section("Config Groups")
        config_groups = ""
        for group in self._action_groups:
            if not group._group_actions:
                continue
            title = group.title
            description = group.description or ""
            config_groups += f"{title: <24}{description}\n"
        formatter.add_text(config_groups)
        formatter.end_section()

        # epilog
        formatter.add_text(self.epilog)

        # determine help from format above
        return formatter.format_help()

    def parse_args(  # type: ignore[override]
        self,
        args: list[str] | None = None,
        namespace: Namespace | None = None,
    ):
        """Parse args after vLLM-specific preprocessing.

        Preprocessing steps, in order: relocate a deprecated `--model`
        option for `vllm serve`; splice in `--config` file args; normalize
        underscores to dashes in option names; expand `-O` shorthands; and
        merge dotted args (`--a.b.c v`) into single JSON-valued args.
        """
        if args is None:
            args = sys.argv[1:]

        # Check for --model in command line arguments first
        if args and args[0] == "serve":
            try:
                model_idx = next(
                    i
                    for i, arg in enumerate(args)
                    if arg == "--model" or arg.startswith("--model=")
                )
                logger.warning(
                    "With `vllm serve`, you should provide the model as a "
                    "positional argument or in a config file instead of via "
                    "the `--model` option. "
                    "The `--model` option will be removed in v0.13."
                )

                if args[model_idx] == "--model":
                    model_tag = args[model_idx + 1]
                    rest_start_idx = model_idx + 2
                else:
                    model_tag = args[model_idx].removeprefix("--model=")
                    rest_start_idx = model_idx + 1

                # Move <model> to the front, e,g:
                # [Before]
                # vllm serve -tp 2 --model <model> --enforce-eager --port 8001
                # [After]
                # vllm serve <model> -tp 2 --enforce-eager --port 8001
                args = [
                    "serve",
                    model_tag,
                    *args[1:model_idx],
                    *args[rest_start_idx:],
                ]
            except StopIteration:
                # No --model option present; nothing to relocate.
                pass

        if "--config" in args:
            args = self._pull_args_from_config(args)

        def repl(match: re.Match) -> str:
            """Replaces underscores with dashes in the matched string."""
            return match.group(0).replace("_", "-")

        # Everything between the first -- and the first .
        pattern = re.compile(r"(?<=--)[^\.]*")

        # Convert underscores to dashes and vice versa in argument names
        processed_args = list[str]()
        for i, arg in enumerate(args):
            if arg.startswith("--help="):
                # Stash the search keyword and fall back to plain --help.
                FlexibleArgumentParser._search_keyword = arg.split("=", 1)[-1].lower()
                processed_args.append("--help")
            elif arg.startswith("--"):
                if "=" in arg:
                    key, value = arg.split("=", 1)
                    key = pattern.sub(repl, key, count=1)
                    processed_args.append(f"{key}={value}")
                else:
                    key = pattern.sub(repl, arg, count=1)
                    processed_args.append(key)
            elif arg.startswith("-O") and arg != "-O" and arg[2] != ".":
                # allow -O flag to be used without space, e.g. -O3 or -Odecode
                # -O.<...> handled later
                # also handle -O=<mode> here
                mode = arg[3:] if arg[2] == "=" else arg[2:]
                processed_args.append(f"-O.mode={mode}")
            elif (
                arg == "-O"
                and i + 1 < len(args)
                and args[i + 1] in {"0", "1", "2", "3"}
            ):
                # Convert -O <n> to -O.mode <n>
                processed_args.append("-O.mode")
            else:
                processed_args.append(arg)

        def create_nested_dict(keys: list[str], value: str) -> dict[str, Any]:
            """Creates a nested dictionary from a list of keys and a value.

            For example, `keys = ["a", "b", "c"]` and `value = 1` will create:
            `{"a": {"b": {"c": 1}}}`
            """
            nested_dict: Any = value
            for key in reversed(keys):
                nested_dict = {key: nested_dict}
            return nested_dict

        def recursive_dict_update(
            original: dict[str, Any],
            update: dict[str, Any],
        ) -> set[str]:
            """Recursively updates a dictionary with another dictionary.
            Returns a set of duplicate keys that were overwritten.
            """
            duplicates = set[str]()
            for k, v in update.items():
                if isinstance(v, dict) and isinstance(original.get(k), dict):
                    nested_duplicates = recursive_dict_update(original[k], v)
                    duplicates |= {f"{k}.{d}" for d in nested_duplicates}
                elif isinstance(v, list) and isinstance(original.get(k), list):
                    # Lists accumulate rather than overwrite.
                    original[k] += v
                else:
                    if k in original:
                        duplicates.add(k)
                    original[k] = v
            return duplicates

        # Collect dotted args into per-root-key dicts, marking consumed
        # argv positions for deletion.
        delete = set[int]()
        dict_args = defaultdict[str, dict[str, Any]](dict)
        duplicates = set[str]()
        for i, processed_arg in enumerate(processed_args):
            if i in delete:  # skip if value from previous arg
                continue

            if processed_arg.startswith("-") and "." in processed_arg:
                if "=" in processed_arg:
                    processed_arg, value_str = processed_arg.split("=", 1)
                    if "." not in processed_arg:
                        # False positive, '.' was only in the value
                        continue
                else:
                    value_str = processed_args[i + 1]
                    delete.add(i + 1)

                if processed_arg.endswith("+"):
                    # "+" suffix: treat the value as comma-separated list items.
                    processed_arg = processed_arg[:-1]
                    value_str = json.dumps(list(value_str.split(",")))

                key, *keys = processed_arg.split(".")
                try:
                    value = json.loads(value_str)
                except json.decoder.JSONDecodeError:
                    # Not valid JSON; keep the raw string value.
                    value = value_str

                # Merge all values with the same key into a single dict
                arg_dict = create_nested_dict(keys, value)
                arg_duplicates = recursive_dict_update(dict_args[key], arg_dict)
                duplicates |= {f"{key}.{d}" for d in arg_duplicates}
                delete.add(i)
        # Filter out the dict args we set to None
        processed_args = [a for i, a in enumerate(processed_args) if i not in delete]
        if duplicates:
            logger.warning("Found duplicate keys %s", ", ".join(duplicates))

        # Add the dict args back as if they were originally passed as JSON
        for dict_arg, dict_value in dict_args.items():
            processed_args.append(dict_arg)
            processed_args.append(json.dumps(dict_value))

        return super().parse_args(processed_args, namespace)

    def check_port(self, value):
        """Validate a CLI port value: an int in the non-privileged range.

        Raises ArgumentTypeError (so argparse reports a clean usage error)
        for non-integers or ports outside 1024-65535.
        """
        try:
            value = int(value)
        except ValueError:
            msg = "Port must be an integer"
            raise ArgumentTypeError(msg) from None

        if not (1024 <= value <= 65535):
            raise ArgumentTypeError("Port must be between 1024 and 65535")

        return value

    def _pull_args_from_config(self, args: list[str]) -> list[str]:
        """Method to pull arguments specified in the config file
        into the command-line args variable.

        The arguments in config file will be inserted between
        the argument list.

        example:
        ```yaml
        port: 12323
        tensor-parallel-size: 4
        ```
        ```python
        $: vllm {serve,chat,complete} "facebook/opt-12B" \
            --config config.yaml -tp 2
        $: args = [
            "serve,chat,complete",
            "facebook/opt-12B",
            '--config', 'config.yaml',
            '-tp', '2'
        ]
        $: args = [
            "serve,chat,complete",
            "facebook/opt-12B",
            '--port', '12323',
            '--tensor-parallel-size', '4',
            '-tp', '2'
        ]
        ```

        Please note how the config args are inserted after the sub command.
        this way the order of priorities is maintained when these are args
        parsed by super().
        """
        assert args.count("--config") <= 1, "More than one config file specified!"

        index = args.index("--config")
        if index == len(args) - 1:
            raise ValueError(
                "No config file specified! \
                Please check your command-line arguments."
            )

        file_path = args[index + 1]

        config_args = self.load_config_file(file_path)

        # 0th index might be the sub command {serve,chat,complete,...}
        # optionally followed by model_tag (only for serve)
        # followed by config args
        # followed by rest of cli args.
        # maintaining this order will enforce the precedence
        # of cli > config > defaults
        if args[0].startswith("-"):
            # No sub command (e.g., api_server entry point)
            args = config_args + args[0:index] + args[index + 2 :]
        elif args[0] == "serve":
            model_in_cli = len(args) > 1 and not args[1].startswith("-")
            model_in_config = any(arg == "--model" for arg in config_args)

            if not model_in_cli and not model_in_config:
                raise ValueError(
                    "No model specified! Please specify model either "
                    "as a positional argument or in a config file."
                )

            if model_in_cli:
                # Model specified as positional arg, keep CLI version
                args = (
                    [args[0]]
                    + [args[1]]
                    + config_args
                    + args[2:index]
                    + args[index + 2 :]
                )
            else:
                # No model in CLI, use config if available
                args = [args[0]] + config_args + args[1:index] + args[index + 2 :]
        else:
            args = [args[0]] + config_args + args[1:index] + args[index + 2 :]

        return args

    def load_config_file(self, file_path: str) -> list[str]:
        """Loads a yaml file and returns the key value pairs as a
        flattened list with argparse like pattern
        ```yaml
        port: 12323
        tensor-parallel-size: 4
        ```
        returns:
            processed_args: list[str] = [
                '--port': '12323',
                '--tensor-parallel-size': '4'
            ]
        """
        extension: str = file_path.split(".")[-1]
        if extension not in ("yaml", "yml"):
            raise ValueError(
                f"Config file must be of a yaml/yml type. {extension} supplied"
            )

        # only expecting a flat dictionary of atomic types
        processed_args: list[str] = []

        config: dict[str, int | str] = {}
        try:
            with open(file_path) as config_file:
                config = yaml.safe_load(config_file)
        except Exception as ex:
            logger.error(
                "Unable to read the config file at %s. Check path correctness",
                file_path,
            )
            raise ex

        for key, value in config.items():
            if isinstance(value, bool):
                # Booleans become bare flags; False values are omitted.
                if value:
                    processed_args.append("--" + key)
            elif isinstance(value, list):
                # Lists become one flag followed by each element.
                if value:
                    processed_args.append("--" + key)
                    for item in value:
                        processed_args.append(str(item))
            else:
                processed_args.append("--" + key)
                processed_args.append(str(value))

        return processed_args
|
||||
303
utils/async_utils.py
Normal file
303
utils/async_utils.py
Normal file
@@ -0,0 +1,303 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""
|
||||
Contains helpers related to asynchronous code.
|
||||
|
||||
This is similar in concept to the `asyncio` module.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import contextlib
|
||||
from asyncio import FIRST_COMPLETED, AbstractEventLoop, Future, Task
|
||||
from collections.abc import AsyncGenerator, Awaitable, Callable
|
||||
from concurrent.futures import Executor, ThreadPoolExecutor
|
||||
from functools import partial
|
||||
from typing import TypeVar
|
||||
|
||||
from transformers.tokenization_utils_base import BatchEncoding
|
||||
from typing_extensions import ParamSpec
|
||||
|
||||
P = ParamSpec("P")
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
class AsyncMicrobatchTokenizer:
    """Asynchronous tokenizer with micro-batching.

    Pulls pending encode/decode requests from a queue and batches them
    up to reduce overhead. A single-thread ThreadPoolExecutor is used
    so the event loop stays responsive.
    """

    def __init__(
        self,
        tokenizer,
        max_batch_size: int = 32,
        batch_wait_timeout_s: float = 0.002,
    ) -> None:
        # NOTE(review): must be constructed inside a running event loop
        # (get_running_loop() below raises otherwise).
        self.tokenizer = tokenizer
        # Max requests folded into one tokenizer call.
        self.max_batch_size = max_batch_size
        # How long to wait for more requests before dispatching a batch.
        self.batch_wait_timeout_s = batch_wait_timeout_s

        self._loop = asyncio.get_running_loop()
        # One queue (and batcher task) per normalized operation key; see
        # _queue_key for the key format.
        self._queues: dict[
            tuple,
            asyncio.Queue[tuple[str, dict, Future] | tuple[list[int], Future]],
        ] = {}
        self._batcher_tasks: list[Task] = []

        # Single-thread executor for blocking tokenizer calls.
        self._executor = ThreadPoolExecutor(max_workers=1)

    # === Public async API ===
    async def __call__(self, prompt, **kwargs):
        """Encode `prompt`, micro-batched with concurrent encode requests."""
        result_future: Future = self._loop.create_future()
        key = self._queue_key("encode", kwargs)
        queue = self._get_queue(self._loop, key)
        await queue.put((prompt, kwargs, result_future))
        return await result_future

    async def decode(self, token_ids, **kwargs):
        """Decode `token_ids`, micro-batched with concurrent decode requests."""
        result_future: Future = self._loop.create_future()
        key = self._queue_key("decode", kwargs)
        queue = self._get_queue(self._loop, key)
        await queue.put((token_ids, result_future))
        return await result_future

    # === Internal helpers ===
    def _get_queue(
        self, loop: asyncio.AbstractEventLoop, key: tuple
    ) -> asyncio.Queue[tuple[str, dict, Future] | tuple[list[int], Future]]:
        """Get the request queue for the given operation key, creating a new
        queue and batcher task if needed."""
        queue = self._queues.get(key)
        if queue is None:
            self._queues[key] = queue = asyncio.Queue()
            if key[0] == "encode":
                # "other" marks kwargs that cannot be batched in one call.
                can_batch = key[1] != "other"
                coro = self._batch_encode_loop(queue, can_batch)
            else:
                assert key[0] == "decode", f"Unknown operation type: {key[0]}."
                coro = self._batch_decode_loop(queue)
            self._batcher_tasks.append(loop.create_task(coro))
        return queue

    async def _batch_encode_loop(self, queue: asyncio.Queue, can_batch: bool):
        """Batch incoming encode requests for efficiency."""
        while True:
            # Block until the first request arrives, then collect more
            # until the batch is full or the wait deadline passes.
            prompt, kwargs, result_future = await queue.get()
            prompts = [prompt]
            kwargs_list = [kwargs]
            result_futures = [result_future]
            deadline = self._loop.time() + self.batch_wait_timeout_s

            while len(prompts) < self.max_batch_size:
                timeout = deadline - self._loop.time()
                if timeout <= 0:
                    break
                try:
                    prompt, kwargs, result_future = await asyncio.wait_for(
                        queue.get(), timeout
                    )
                    prompts.append(prompt)
                    result_futures.append(result_future)
                    # When batching, all requests on this queue share kwargs,
                    # so only track per-request kwargs in the non-batch case.
                    if not can_batch:
                        kwargs_list.append(kwargs)
                except asyncio.TimeoutError:
                    break

            try:
                # If every request uses identical kwargs we can run a single
                # batched tokenizer call for a big speed-up.
                if can_batch and len(prompts) > 1:
                    batch_encode_fn = partial(self.tokenizer, prompts, **kwargs)
                    results = await self._loop.run_in_executor(
                        self._executor, batch_encode_fn
                    )

                    # Split the batched result back into per-request encodings.
                    for i, fut in enumerate(result_futures):
                        if not fut.done():
                            data = {k: v[i] for k, v in results.items()}
                            fut.set_result(BatchEncoding(data))
                else:
                    # One tokenizer call per request, still off-loop in the
                    # executor; defaults bind current values into the lambda.
                    encode_fn = lambda prompts=prompts, kwargs=kwargs_list: [
                        self.tokenizer(p, **kw) for p, kw in zip(prompts, kwargs)
                    ]
                    results = await self._loop.run_in_executor(
                        self._executor, encode_fn
                    )

                    for fut, res in zip(result_futures, results):
                        if not fut.done():
                            fut.set_result(res)
            except Exception as e:
                # Propagate the failure to every request in the batch.
                for fut in result_futures:
                    if not fut.done():
                        fut.set_exception(e)

    async def _batch_decode_loop(self, queue: asyncio.Queue):
        """Batch incoming decode requests for efficiency."""
        while True:
            # Block for the first request, then gather more until the batch
            # is full or the wait deadline passes.
            token_ids, result_future = await queue.get()
            token_ids_list = [token_ids]
            result_futures = [result_future]
            deadline = self._loop.time() + self.batch_wait_timeout_s

            while len(token_ids_list) < self.max_batch_size:
                timeout = deadline - self._loop.time()
                if timeout <= 0:
                    break
                try:
                    token_ids, result_future = await asyncio.wait_for(
                        queue.get(), timeout
                    )
                    token_ids_list.append(token_ids)
                    result_futures.append(result_future)
                except asyncio.TimeoutError:
                    break

            try:
                # Perform a single batched decode call for all requests
                results = await self._loop.run_in_executor(
                    self._executor, self.tokenizer.batch_decode, token_ids_list
                )
                for fut, res in zip(result_futures, results):
                    if not fut.done():
                        fut.set_result(res)
            except Exception as e:
                # Propagate the failure to every request in the batch.
                for fut in result_futures:
                    if not fut.done():
                        fut.set_exception(e)

    def _queue_key(self, op: str, kwargs: dict) -> tuple:
        """
        Return a normalized key describing operation + kwargs.

        - `add_special_tokens`: {True/False}
        - `truncation`: {True/False}
        - If `truncation` is False (`max_length` is None),
          returns a key for a can_batch queue.
        - If `truncation` is True and `max_length` is None or equals
          `tokenizer.model_max_length`, returns a key for a can_batch queue.
        - Otherwise, returns a key for a cannot_batch queue.

        Examples:
            - Decode: ("decode",)
            - Encode typical:
              ("encode", add_special_tokens, bool_truncation, max_length_label)
            - Fallback: ("encode", "other")
        """

        if op == "decode":
            return ("decode",)

        add_special_tokens = kwargs.get("add_special_tokens", True)
        truncation = kwargs.get("truncation", False)
        max_length = kwargs.get("max_length")

        if not truncation:
            return "encode", add_special_tokens, False, None

        model_max = getattr(self.tokenizer, "model_max_length", None)
        if max_length is None or (model_max is not None and max_length == model_max):
            return "encode", add_special_tokens, True, "model_max"

        return "encode", "other"

    def __del__(self):
        # Best-effort cleanup: cancel batcher tasks on their own loop.
        # getattr guards against partially-constructed instances.
        if (
            (tasks := getattr(self, "_batcher_tasks", None))
            and (loop := getattr(self, "_loop", None))
            and not loop.is_closed()
        ):

            def cancel_tasks():
                for task in tasks:
                    task.cancel()

            # __del__ may run on any thread, so hop onto the loop safely.
            loop.call_soon_threadsafe(cancel_tasks)
|
||||
|
||||
|
||||
def cancel_task_threadsafe(task: Task):
    """Request cancellation of `task` from any thread; no-op if absent/done."""
    if not task or task.done():
        return
    run_in_loop(task.get_loop(), task.cancel)
|
||||
|
||||
|
||||
def make_async(
|
||||
func: Callable[P, T],
|
||||
executor: Executor | None = None,
|
||||
) -> Callable[P, Awaitable[T]]:
|
||||
"""
|
||||
Take a blocking function, and run it on in an executor thread.
|
||||
|
||||
This function prevents the blocking function from blocking the
|
||||
asyncio event loop.
|
||||
The code in this function needs to be thread safe.
|
||||
"""
|
||||
|
||||
def _async_wrapper(*args: P.args, **kwargs: P.kwargs) -> Future[T]:
|
||||
loop = asyncio.get_event_loop()
|
||||
p_func = partial(func, *args, **kwargs)
|
||||
return loop.run_in_executor(executor=executor, func=p_func)
|
||||
|
||||
return _async_wrapper
|
||||
|
||||
|
||||
def run_in_loop(loop: AbstractEventLoop, function: Callable, *args):
    """Invoke `function` on `loop`, regardless of the calling thread."""
    if in_loop(loop):
        # Already on the loop's thread: call synchronously.
        function(*args)
        return
    # Cross-thread: schedule on the loop unless it is already closed.
    if not loop.is_closed():
        loop.call_soon_threadsafe(function, *args)
|
||||
|
||||
|
||||
def in_loop(event_loop: AbstractEventLoop) -> bool:
    """Return whether the caller is executing inside `event_loop`."""
    try:
        running = asyncio.get_running_loop()
    except RuntimeError:
        # No loop is running in this thread at all.
        return False
    return running == event_loop
|
||||
|
||||
|
||||
async def merge_async_iterators(
    *iterators: AsyncGenerator[T, None],
) -> AsyncGenerator[tuple[int, T], None]:
    """Merge multiple asynchronous iterators into a single iterator.

    This method handles the case where some iterators finish before others.
    When it yields, it yields a tuple (i, item) where i is the index of the
    iterator that yielded the item.
    """
    if len(iterators) == 1:
        # Fast-path single iterator case: no task bookkeeping needed.
        async for item in iterators[0]:
            yield 0, item
        return

    loop = asyncio.get_running_loop()

    # One in-flight `anext` task per iterator, mapped back to (index, iterator).
    awaits = {loop.create_task(anext(it)): (i, it) for i, it in enumerate(iterators)}
    try:
        while awaits:
            # Wake as soon as any iterator produces its next item.
            done, _ = await asyncio.wait(awaits.keys(), return_when=FIRST_COMPLETED)
            for d in done:
                pair = awaits.pop(d)
                try:
                    item = await d
                    i, it = pair
                    # Re-arm the iterator before yielding so it keeps producing
                    # while the consumer processes this item.
                    awaits[loop.create_task(anext(it))] = pair
                    yield i, item
                except StopAsyncIteration:
                    # This iterator is exhausted; simply stop tracking it.
                    pass
    finally:
        # Cancel any remaining iterators (e.g. the consumer stopped early).
        for f, (_, it) in awaits.items():
            with contextlib.suppress(BaseException):
                f.cancel()
                await it.aclose()
|
||||
|
||||
|
||||
async def collect_from_async_generator(iterator: AsyncGenerator[T, None]) -> list[T]:
    """Drain an async generator and return every yielded item as a list."""
    return [item async for item in iterator]
|
||||
214
utils/cache.py
Normal file
214
utils/cache.py
Normal file
@@ -0,0 +1,214 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from collections import UserDict
|
||||
from collections.abc import Callable, Hashable, Iterator, KeysView, Mapping
|
||||
from types import MappingProxyType
|
||||
from typing import NamedTuple, TypeVar, cast, overload
|
||||
|
||||
import cachetools
|
||||
|
||||
_K = TypeVar("_K", bound=Hashable)
|
||||
_V = TypeVar("_V")
|
||||
_T = TypeVar("_T")
|
||||
|
||||
|
||||
class _Sentinel: ...
|
||||
|
||||
|
||||
ALL_PINNED_SENTINEL = _Sentinel()
|
||||
|
||||
|
||||
class _MappingOrderCacheView(UserDict[_K, _V]):
|
||||
def __init__(self, data: Mapping[_K, _V], ordered_keys: Mapping[_K, None]):
|
||||
super().__init__(data)
|
||||
self.ordered_keys = ordered_keys
|
||||
|
||||
def __iter__(self) -> Iterator[_K]:
|
||||
return iter(self.ordered_keys)
|
||||
|
||||
def keys(self) -> KeysView[_K]:
|
||||
return KeysView(self.ordered_keys)
|
||||
|
||||
|
||||
class CacheInfo(NamedTuple):
    """Immutable snapshot of cache statistics."""

    # Number of queries that found their key.
    hits: int
    # Total number of queries made.
    total: int

    @property
    def hit_ratio(self) -> float:
        """Fraction of queries that hit; 0 when nothing was queried yet."""
        return 0 if self.total == 0 else self.hits / self.total

    def __sub__(self, other: "CacheInfo"):
        # Element-wise difference, used for delta statistics.
        return CacheInfo(
            hits=self.hits - other.hits,
            total=self.total - other.total,
        )
|
||||
|
||||
|
||||
class LRUCache(cachetools.LRUCache[_K, _V]):
    """LRU cache with hit/total statistics, pinning, and a removal hook.

    Extends `cachetools.LRUCache` with:
    - hit/query counters exposed via `stat()` (cumulative or delta);
    - `pin()`/`_unpin()` to protect entries from LRU eviction;
    - an `_on_remove()` hook called whenever an entry is deleted.

    NOTE: several members reach into cachetools' name-mangled internals
    (`_Cache__data`, `_LRUCache__order`), so this class is coupled to the
    installed cachetools version.
    """

    def __init__(self, capacity: float, getsizeof: Callable[[_V], float] | None = None):
        super().__init__(capacity, getsizeof)

        # Keys protected from LRU eviction until unpinned.
        self.pinned_items = set[_K]()

        # Statistics counters; `_last_info` is the baseline for delta stats.
        self._hits = 0
        self._total = 0
        self._last_info = CacheInfo(hits=0, total=0)

    def __getitem__(self, key: _K, *, update_info: bool = True) -> _V:
        """Look up `key`; counts as a hit unless `update_info=False`.

        Raises KeyError (from the superclass) when the key is absent.
        """
        value = super().__getitem__(key)

        if update_info:
            self._hits += 1
            self._total += 1

        return value

    def __delitem__(self, key: _K) -> None:
        """Delete `key`, unpinning it if needed and firing `_on_remove`."""
        run_on_remove = key in self
        # Read without touching the statistics counters.
        value = self.__getitem__(key, update_info=False)  # type: ignore[call-arg]
        super().__delitem__(key)
        if key in self.pinned_items:
            # TODO: emit a warning when a pinned item is deleted explicitly.
            self._unpin(key)
        if run_on_remove:
            self._on_remove(key, value)

    @property
    def cache(self) -> Mapping[_K, _V]:
        """Return the internal cache dictionary in order (read-only)."""
        return _MappingOrderCacheView(
            self._Cache__data,  # type: ignore
            self.order,
        )

    @property
    def order(self) -> Mapping[_K, None]:
        """Return the internal order dictionary (read-only)."""
        return MappingProxyType(self._LRUCache__order)  # type: ignore

    @property
    def capacity(self) -> float:
        # Alias for cachetools' `maxsize`.
        return self.maxsize

    @property
    def usage(self) -> float:
        """Current fill ratio in [0, 1]; 0 for a zero-capacity cache."""
        if self.maxsize == 0:
            return 0

        return self.currsize / self.maxsize

    def stat(self, *, delta: bool = False) -> CacheInfo:
        """
        Gets the cumulative number of hits and queries against this cache.

        If `delta=True`, instead gets these statistics
        since the last call that also passed `delta=True`.
        """
        info = CacheInfo(hits=self._hits, total=self._total)

        if delta:
            info_delta = info - self._last_info
            self._last_info = info
            info = info_delta

        return info

    def touch(self, key: _K) -> None:
        """Mark `key` as most recently used (inserting it into the order)."""
        try:
            self._LRUCache__order.move_to_end(key)  # type: ignore
        except KeyError:
            self._LRUCache__order[key] = None  # type: ignore

    @overload
    def get(self, key: _K, /) -> _V | None: ...

    @overload
    def get(self, key: _K, /, default: _V | _T) -> _V | _T: ...

    def get(self, key: _K, /, default: _V | _T | None = None) -> _V | _T | None:
        """Like dict.get, but every call counts as a query (hit or miss)."""
        value: _V | _T | None
        if key in self:
            value = self.__getitem__(key, update_info=False)  # type: ignore[call-arg]

            self._hits += 1
        else:
            value = default

        self._total += 1
        return value

    @overload
    def pop(self, key: _K) -> _V: ...

    @overload
    def pop(self, key: _K, default: _V | _T) -> _V | _T: ...

    def pop(self, key: _K, default: _V | _T | None = None) -> _V | _T | None:
        """Remove and return `key`'s value, or `default` when absent.

        Does not update the hit/total statistics.
        """
        value: _V | _T | None
        if key not in self:
            return default

        value = self.__getitem__(key, update_info=False)  # type: ignore[call-arg]
        self.__delitem__(key)
        return value

    def put(self, key: _K, value: _V) -> None:
        # Convenience alias for item assignment.
        self.__setitem__(key, value)

    def pin(self, key: _K) -> None:
        """
        Pins a key in the cache preventing it from being
        evicted in the LRU order.
        """
        if key not in self:
            raise ValueError(f"Cannot pin key: {key} not in cache.")
        self.pinned_items.add(key)

    def _unpin(self, key: _K) -> None:
        """
        Unpins a key in the cache allowing it to be
        evicted in the LRU order.
        """
        self.pinned_items.remove(key)

    def _on_remove(self, key: _K, value: _V | None) -> None:
        # Hook for subclasses; called after an entry is deleted.
        pass

    def remove_oldest(self, *, remove_pinned: bool = False) -> None:
        """Evict the LRU entry; skips pinned entries unless `remove_pinned`."""
        if len(self) == 0:
            return

        self.popitem(remove_pinned=remove_pinned)

    def _remove_old_if_needed(self) -> None:
        # Evict until the cache fits within its capacity again.
        while self.currsize > self.capacity:
            self.remove_oldest()

    def popitem(self, remove_pinned: bool = False):
        """Remove and return the `(key, value)` pair least recently used.

        Raises RuntimeError if everything is pinned and `remove_pinned` is
        False.
        """
        if not remove_pinned:
            # pop the oldest item in the cache that is not pinned
            lru_key = next(
                (key for key in self.order if key not in self.pinned_items),
                ALL_PINNED_SENTINEL,
            )
            if lru_key is ALL_PINNED_SENTINEL:
                raise RuntimeError(
                    "All items are pinned, cannot remove oldest from the cache."
                )
        else:
            lru_key = next(iter(self.order))
        value = self.pop(cast(_K, lru_key))
        return (lru_key, value)

    def clear(self) -> None:
        """Evict everything (including pinned entries) and reset statistics."""
        while len(self) > 0:
            self.remove_oldest(remove_pinned=True)

        self._hits = 0
        self._total = 0
        self._last_info = CacheInfo(hits=0, total=0)
|
||||
139
utils/collection_utils.py
Normal file
139
utils/collection_utils.py
Normal file
@@ -0,0 +1,139 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""
|
||||
Contains helpers that are applied to collections.
|
||||
|
||||
This is similar in concept to the `collections` module.
|
||||
"""
|
||||
|
||||
from collections import UserDict, defaultdict
|
||||
from collections.abc import Callable, Generator, Hashable, Iterable, Mapping
|
||||
from typing import Generic, Literal, TypeVar
|
||||
|
||||
from typing_extensions import TypeIs, assert_never
|
||||
|
||||
T = TypeVar("T")
|
||||
U = TypeVar("U")
|
||||
|
||||
_K = TypeVar("_K", bound=Hashable)
|
||||
_V = TypeVar("_V")
|
||||
|
||||
|
||||
class ClassRegistry(UserDict[type[T], _V]):
    """
    A registry that acts like a dictionary but searches for other classes
    in the MRO if the original class is not found.
    """

    def __getitem__(self, key: type[T]) -> _V:
        # Walk the MRO and return the first registered ancestor.
        match = next((cls for cls in key.mro() if cls in self.data), None)
        if match is None:
            raise KeyError(key)
        return self.data[match]

    def __contains__(self, key: object) -> bool:
        return self.contains(key)

    def contains(self, key: object, *, strict: bool = False) -> bool:
        """Membership test; `strict=True` disables the MRO fallback."""
        if not isinstance(key, type):
            return False

        if strict:
            return key in self.data

        return any(cls in self.data for cls in key.mro())
|
||||
|
||||
|
||||
class LazyDict(Mapping[str, T], Generic[T]):
    """
    Mapping whose values are produced by factories on first access and then
    memoized.

    Adapted from: https://stackoverflow.com/a/47212782/5082708
    """

    def __init__(self, factory: dict[str, Callable[[], T]]):
        self._factory = factory
        # Memoized results, populated lazily.
        self._dict: dict[str, T] = {}

    def __getitem__(self, key: str) -> T:
        if key in self._dict:
            return self._dict[key]
        if key not in self._factory:
            raise KeyError(key)
        produced = self._factory[key]()
        self._dict[key] = produced
        return produced

    def __setitem__(self, key: str, value: Callable[[], T]):
        # Note: stores a factory, not a value; any memoized value is NOT
        # invalidated here (matches the original behavior).
        self._factory[key] = value

    def __iter__(self):
        return iter(self._factory)

    def __len__(self):
        return len(self._factory)
|
||||
|
||||
|
||||
def as_list(maybe_list: Iterable[T]) -> list[T]:
    """Return `maybe_list` itself when it is a list, else materialize it."""
    if isinstance(maybe_list, list):
        return maybe_list
    return list(maybe_list)
|
||||
|
||||
|
||||
def as_iter(obj: T | Iterable[T]) -> Iterable[T]:
|
||||
if isinstance(obj, str) or not isinstance(obj, Iterable):
|
||||
return [obj] # type: ignore[list-item]
|
||||
return obj
|
||||
|
||||
|
||||
def is_list_of(
    value: object,
    typ: type[T] | tuple[type[T], ...],
    *,
    check: Literal["first", "all"] = "first",
) -> TypeIs[list[T]]:
    """Narrow `value` to `list[T]`.

    `check="first"` inspects only the first element (empty lists pass);
    `check="all"` verifies every element.
    """
    if not isinstance(value, list):
        return False

    if check == "all":
        return all(isinstance(elem, typ) for elem in value)
    if check == "first":
        return not value or isinstance(value[0], typ)

    assert_never(check)
|
||||
|
||||
|
||||
def chunk_list(lst: list[T], chunk_size: int) -> Generator[list[T]]:
    """Yield consecutive slices of `lst`, each at most `chunk_size` long."""
    start = 0
    total = len(lst)
    while start < total:
        yield lst[start : start + chunk_size]
        start += chunk_size
|
||||
|
||||
|
||||
def flatten_2d_lists(lists: Iterable[Iterable[T]]) -> list[T]:
    """Flatten a list of lists to a single list."""
    flat: list[T] = []
    for sub in lists:
        flat.extend(sub)
    return flat
|
||||
|
||||
|
||||
def full_groupby(values: Iterable[_V], *, key: Callable[[_V], _K]):
    """
    Group `values` by `key`. Unlike [`itertools.groupby`][], items with the
    same key end up together even when they are not contiguous.
    """
    grouped = defaultdict[_K, list[_V]](list)
    for item in values:
        grouped[key(item)].append(item)
    return grouped.items()
|
||||
|
||||
|
||||
def swap_dict_values(obj: dict[_K, _V], key1: _K, key2: _K) -> None:
    """Swap the values stored under `key1` and `key2` in place.

    A key that is absent (or maps to None) removes its counterpart instead
    of assigning None.
    """
    first = obj.get(key1)
    second = obj.get(key2)
    if first is None:
        obj.pop(key2, None)
    else:
        obj[key2] = first
    if second is None:
        obj.pop(key1, None)
    else:
        obj[key1] = second
|
||||
45
utils/counter.py
Normal file
45
utils/counter.py
Normal file
@@ -0,0 +1,45 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import threading
|
||||
|
||||
|
||||
class Counter:
    """Monotonically increasing integer generator (not thread-safe)."""

    def __init__(self, start: int = 0) -> None:
        super().__init__()

        # Next value to be handed out.
        self.counter = start

    def __next__(self) -> int:
        current = self.counter
        self.counter = current + 1
        return current

    def reset(self) -> None:
        # Note: resets to 0, not to the original `start`.
        self.counter = 0
|
||||
|
||||
|
||||
class AtomicCounter:
    """Thread-safe integer counter guarded by a lock."""

    def __init__(self, initial: int = 0) -> None:
        """Create a counter starting at `initial`."""
        super().__init__()

        self._lock = threading.Lock()
        self._value = initial

    @property
    def value(self) -> int:
        # Plain read; int loads are atomic enough for a snapshot.
        return self._value

    def inc(self, num: int = 1) -> int:
        """Add `num` under the lock and return the updated value."""
        with self._lock:
            self._value += num
            return self._value

    def dec(self, num: int = 1) -> int:
        """Subtract `num` under the lock and return the updated value."""
        with self._lock:
            self._value -= num
            return self._value
|
||||
391
utils/deep_gemm.py
Normal file
391
utils/deep_gemm.py
Normal file
@@ -0,0 +1,391 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""Compatibility wrapper for DeepGEMM API changes.
|
||||
|
||||
Users of vLLM should always import **only** these wrappers.
|
||||
"""
|
||||
|
||||
import functools
|
||||
import importlib
|
||||
import os
|
||||
from collections.abc import Callable
|
||||
from enum import Enum
|
||||
from typing import Any, NoReturn
|
||||
|
||||
import torch
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.logger import logger
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils.import_utils import has_deep_gemm
|
||||
from vllm.utils.math_utils import cdiv
|
||||
|
||||
|
||||
class DeepGemmQuantScaleFMT(Enum):
    """Storage formats for DeepGEMM quantization scales."""

    # Plain float32 scales kept in a float32 tensor.
    FLOAT32 = 0
    # Float32 scales ceiled to the UE8M0 grid, still stored as float32.
    FLOAT32_CEIL_UE8M0 = 1
    # UE8M0 scales packed four-per-element into an int32 tensor.
    UE8M0 = 2

    @staticmethod
    def from_oracle() -> "DeepGemmQuantScaleFMT":
        """Pick the scale format matching the current configuration."""
        if not is_deep_gemm_e8m0_used():
            return DeepGemmQuantScaleFMT.FLOAT32
        if current_platform.is_device_capability(100):
            return DeepGemmQuantScaleFMT.UE8M0
        return DeepGemmQuantScaleFMT.FLOAT32_CEIL_UE8M0
|
||||
|
||||
|
||||
@functools.cache
def is_deep_gemm_supported() -> bool:
    """Return whether DeepGEMM can be used here.

    Requires the env flag, an importable `deep_gemm`, and a CUDA device of
    capability 9.0 (Hopper) or 10.0 (Blackwell).
    """
    if not (envs.VLLM_USE_DEEP_GEMM and has_deep_gemm()):
        return False
    return current_platform.is_cuda() and (
        current_platform.is_device_capability(90)
        or current_platform.is_device_capability(100)
    )
|
||||
|
||||
|
||||
@functools.cache
def is_deep_gemm_e8m0_used() -> bool:
    """Return `True` if vLLM is configured to use DeepGEMM
    E8M0 scale on a Hopper or Blackwell-class GPU.
    """
    if not is_deep_gemm_supported():
        logger.debug_once(
            "DeepGEMM E8M0 disabled: DeepGEMM not supported on this system."
        )
        return False

    # Resolve the backend symbols before probing them.
    _lazy_init()

    if _fp8_gemm_nt_impl is None:
        logger.info_once("DeepGEMM E8M0 disabled: _fp8_gemm_nt_impl not found")
        return False

    if envs.VLLM_USE_DEEP_GEMM_E8M0:
        logger.info_once("DeepGEMM E8M0 enabled on current platform.")
        return True

    logger.info_once("DeepGEMM E8M0 disabled on current configuration.")
    return False
|
||||
|
||||
|
||||
def _missing(*_: Any, **__: Any) -> NoReturn:
|
||||
"""Placeholder for unavailable DeepGEMM backend."""
|
||||
raise RuntimeError(
|
||||
"DeepGEMM backend is not available or outdated. Please install or "
|
||||
"update the `deep_gemm` to a newer version to enable FP8 kernels."
|
||||
)
|
||||
|
||||
|
||||
# DeepGEMM entry points resolved by `_lazy_init()`. Each stays `None` until
# `deep_gemm` has been imported successfully; wrapper functions fall back to
# `_missing()` while a symbol is unresolved.
_fp8_gemm_nt_impl: Callable[..., Any] | None = None
_grouped_impl: Callable[..., Any] | None = None
_grouped_masked_impl: Callable[..., Any] | None = None
_fp8_mqa_logits_impl: Callable[..., Any] | None = None
_fp8_paged_mqa_logits_impl: Callable[..., Any] | None = None
_get_paged_mqa_logits_metadata_impl: Callable[..., Any] | None = None
_get_mn_major_tma_aligned_tensor_impl: Callable[..., Any] | None = None
_get_mk_alignment_for_contiguous_layout_impl: Callable[..., Any] | None = None
_transform_sf_into_required_layout_impl: Callable[..., Any] | None = None
|
||||
|
||||
|
||||
def _lazy_init() -> None:
    """Import deep_gemm and resolve symbols on first use."""
    global _fp8_gemm_nt_impl, _grouped_impl, _grouped_masked_impl
    global _fp8_mqa_logits_impl, _fp8_paged_mqa_logits_impl
    global _get_paged_mqa_logits_metadata_impl
    global _get_mn_major_tma_aligned_tensor_impl
    global _get_mk_alignment_for_contiguous_layout_impl
    global _transform_sf_into_required_layout_impl
    # Fast path: any already-resolved symbol means initialization ran before.
    if (
        _fp8_gemm_nt_impl is not None
        or _grouped_impl is not None
        or _grouped_masked_impl is not None
        or _fp8_mqa_logits_impl is not None
        or _fp8_paged_mqa_logits_impl is not None
        or _get_paged_mqa_logits_metadata_impl is not None
        or _get_mk_alignment_for_contiguous_layout_impl is not None
        or _transform_sf_into_required_layout_impl is not None
    ):
        return

    if not has_deep_gemm():
        return

    # Point DeepGEMM's JIT cache at vLLM's cache root unless already set.
    DEEP_GEMM_JIT_CACHE_ENV_NAME = "DG_JIT_CACHE_DIR"
    if not os.environ.get(DEEP_GEMM_JIT_CACHE_ENV_NAME, None):
        os.environ[DEEP_GEMM_JIT_CACHE_ENV_NAME] = os.path.join(
            envs.VLLM_CACHE_ROOT, "deep_gemm"
        )

    _dg = importlib.import_module("deep_gemm")

    # Each symbol may be absent on older deep_gemm versions; callers check
    # for None and raise via `_missing()`.
    _fp8_gemm_nt_impl = getattr(_dg, "fp8_gemm_nt", None)
    _grouped_impl = getattr(_dg, "m_grouped_fp8_gemm_nt_contiguous", None)
    _grouped_masked_impl = getattr(_dg, "fp8_m_grouped_gemm_nt_masked", None)
    _fp8_mqa_logits_impl = getattr(_dg, "fp8_mqa_logits", None)
    _fp8_paged_mqa_logits_impl = getattr(_dg, "fp8_paged_mqa_logits", None)
    _get_paged_mqa_logits_metadata_impl = getattr(
        _dg, "get_paged_mqa_logits_metadata", None
    )
    _get_mn_major_tma_aligned_tensor_impl = getattr(
        _dg, "get_mn_major_tma_aligned_tensor", None
    )
    _get_mk_alignment_for_contiguous_layout_impl = getattr(
        _dg, "get_mk_alignment_for_contiguous_layout", None
    )
    _transform_sf_into_required_layout_impl = getattr(
        _dg, "transform_sf_into_required_layout", None
    )
|
||||
|
||||
|
||||
def get_num_sms() -> int:
    """Return the SM count reported by DeepGEMM."""
    _lazy_init()
    module = importlib.import_module("deep_gemm")
    return int(module.get_num_sms())
|
||||
|
||||
|
||||
@functools.cache
def get_mk_alignment_for_contiguous_layout() -> list[int]:
    """Return the [M, K] alignment DeepGEMM requires for contiguous layout."""
    _lazy_init()
    impl = _get_mk_alignment_for_contiguous_layout_impl
    if impl is None:
        return _missing()
    alignment = impl()
    # Same alignment applies to both the M and K dimensions.
    return [alignment, alignment]
|
||||
|
||||
|
||||
def get_col_major_tma_aligned_tensor(x: torch.Tensor) -> torch.Tensor:
    """Wrapper for DeepGEMM's get_mn_major_tma_aligned_tensor."""
    _lazy_init()
    impl = _get_mn_major_tma_aligned_tensor_impl
    if impl is None:
        return _missing()
    return impl(x)
|
||||
|
||||
|
||||
def fp8_gemm_nt(*args, **kwargs):
    """Dense FP8 GEMM (NT layout); forwards to DeepGEMM when available."""
    _lazy_init()
    if _fp8_gemm_nt_impl is None:
        return _missing(*args, **kwargs)
    # Callers may force the E8M0 decision; otherwise consult the oracle.
    if "is_deep_gemm_e8m0_used" in kwargs:
        use_ue8m0 = kwargs.pop("is_deep_gemm_e8m0_used")
    else:
        use_ue8m0 = is_deep_gemm_e8m0_used()
    return _fp8_gemm_nt_impl(*args, disable_ue8m0_cast=not use_ue8m0, **kwargs)
|
||||
|
||||
|
||||
def m_grouped_fp8_gemm_nt_contiguous(*args, **kwargs):
    """Grouped contiguous FP8 GEMM; forwards to DeepGEMM when available."""
    _lazy_init()
    impl = _grouped_impl
    if impl is None:
        return _missing(*args, **kwargs)
    return impl(*args, disable_ue8m0_cast=not is_deep_gemm_e8m0_used(), **kwargs)
|
||||
|
||||
|
||||
def fp8_m_grouped_gemm_nt_masked(*args, **kwargs):
    """Grouped masked FP8 GEMM; forwards to DeepGEMM when available."""
    _lazy_init()
    impl = _grouped_masked_impl
    if impl is None:
        return _missing(*args, **kwargs)
    return impl(*args, disable_ue8m0_cast=not is_deep_gemm_e8m0_used(), **kwargs)
|
||||
|
||||
|
||||
def transform_sf_into_required_layout(*args, **kwargs):
    """Convert scale factors into DeepGEMM's required layout."""
    _lazy_init()
    impl = _transform_sf_into_required_layout_impl
    if impl is None:
        return _missing(*args, **kwargs)
    return impl(*args, disable_ue8m0_cast=not is_deep_gemm_e8m0_used(), **kwargs)
|
||||
|
||||
|
||||
def fp8_mqa_logits(
    q: torch.Tensor,
    kv: tuple[torch.Tensor, torch.Tensor],
    weights: torch.Tensor,
    cu_seqlen_ks: torch.Tensor,
    cu_seqlen_ke: torch.Tensor,
) -> torch.Tensor:
    """Compute FP8 MQA logits for a single sequence without KV paging.

    Args:
        q: [M, H, D] query tensor, already cast to `torch.float8_e4m3fn`
            by the caller.
        kv: `(k_fp8, k_scales)` pair where `k_fp8` is [N, D]
            `torch.float8_e4m3fn` and `k_scales` is [N] (or [N, 1])
            `torch.float32`.
        weights: [M, H] `torch.float32` weights.
        cu_seqlen_ks: [M] int32 inclusive start index of valid K per query
            position.
        cu_seqlen_ke: [M] int32 exclusive end index of valid K per query
            position.

    Returns:
        [M, N] `torch.float32` logits.
    """
    _lazy_init()
    impl = _fp8_mqa_logits_impl
    if impl is None:
        return _missing()
    return impl(q, kv, weights, cu_seqlen_ks, cu_seqlen_ke)
|
||||
|
||||
|
||||
def get_paged_mqa_logits_metadata(
    context_lens: torch.Tensor, block_size: int, num_sms: int
) -> torch.Tensor:
    """Build scheduling metadata for paged MQA logits.

    Args:
        context_lens: [B] int32 effective context length per batch element.
        block_size: KV-cache block size in tokens (e.g., 64).
        num_sms: Number of SMs available. 132 for Hopper.

    Returns:
        Backend-specific tensor consumed by `fp8_paged_mqa_logits` to
        schedule work across SMs.
    """
    _lazy_init()
    impl = _get_paged_mqa_logits_metadata_impl
    if impl is None:
        return _missing()
    return impl(context_lens, block_size, num_sms)
|
||||
|
||||
|
||||
def fp8_paged_mqa_logits(
    q_fp8: torch.Tensor,
    kv_cache_fp8: torch.Tensor,
    weights: torch.Tensor,
    context_lens: torch.Tensor,
    block_tables: torch.Tensor,
    schedule_metadata: torch.Tensor,
    max_model_len: int,
) -> torch.Tensor:
    """Compute FP8 MQA logits using a paged KV-cache.

    Args:
        q_fp8: [B, next_n, H, D] query tensor, already cast to
            `torch.float8_e4m3fn` by the caller.
        kv_cache_fp8: Paged KV-cache in packed FP8+scale layout,
            [num_blocks, block_size, 1, D+4] with dtype `torch.uint8`;
            the trailing 4 bytes per (block, pos) hold the float dequant
            scale.
        weights: [B * next_n, H] `torch.float32` tensor.
        context_lens: [B] int32 effective context length per batch element.
        block_tables: [B, max_blocks] int32 mapping of logical block indices
            to physical blocks in the paged cache.
        schedule_metadata: Output of `get_paged_mqa_logits_metadata`, used
            to distribute work across SMs.
        max_model_len: Maximum sequence length used to size the output.

    Returns:
        [B * next_n, max_model_len] `torch.float32` logits.
    """
    _lazy_init()
    impl = _fp8_paged_mqa_logits_impl
    if impl is None:
        return _missing()
    return impl(
        q_fp8,
        kv_cache_fp8,
        weights,
        context_lens,
        block_tables,
        schedule_metadata,
        max_model_len,
        clean_logits=True,
    )
|
||||
|
||||
|
||||
def _ceil_to_ue8m0(x: torch.Tensor):
|
||||
return torch.pow(2.0, torch.ceil(torch.log2(x.abs())))
|
||||
|
||||
|
||||
def _align(x: int, y: int) -> int:
    """Round `x` up to the nearest multiple of `y`."""
    return y * cdiv(x, y)
|
||||
|
||||
|
||||
# Default [block_m, block_n] quantization tile.
DEFAULT_BLOCK_SIZE = [128, 128]


# Taken from https://github.com/deepseek-ai/DeepGEMM/blob/dd6ed14acbc7445dcef224248a77ab4d22b5f240/deep_gemm/utils/math.py#L38
@torch.compile(dynamic=True, backend=current_platform.simple_compile_backend)
def per_block_cast_to_fp8(
    x: torch.Tensor, block_size: list[int] = DEFAULT_BLOCK_SIZE, use_ue8m0: bool = False
) -> tuple[torch.Tensor, torch.Tensor]:
    """Quantize a 2D tensor to FP8 (e4m3) with one scale per block.

    Returns `(x_fp8, scales)` where `x_fp8` has `x`'s shape and dtype
    `torch.float8_e4m3fn`, and `scales` is float32 with one entry per
    (block_m, block_n) tile. NOTE(review): the mutable default list is
    shared across calls but only unpacked here, never mutated.
    """
    assert x.dim() == 2
    m, n = x.shape
    block_m, block_n = block_size
    # Zero-pad up to whole blocks so the 4D view below is exact.
    x_padded = torch.zeros(
        (_align(m, block_m), _align(n, block_n)), dtype=x.dtype, device=x.device
    )
    x_padded[:m, :n] = x
    x_view = x_padded.view(-1, block_m, x_padded.size(1) // block_n, block_n)
    # Per-block max magnitude; clamp avoids dividing by ~0 for all-zero blocks.
    x_amax = x_view.abs().float().amax(dim=(1, 3), keepdim=True).clamp(1e-4)
    # 448 is the max normal value representable in float8_e4m3fn.
    sf = x_amax / 448.0
    # Optionally snap scales up to the UE8M0 (power-of-two) grid.
    sf = _ceil_to_ue8m0(sf) if use_ue8m0 else sf
    x_scaled = (x_view * (1.0 / sf)).to(torch.float8_e4m3fn)
    return x_scaled.view_as(x_padded)[:m, :n].contiguous(), sf.view(
        x_view.size(0), x_view.size(2)
    )
|
||||
|
||||
|
||||
def calc_diff(x: torch.Tensor, y: torch.Tensor):
    """Return a global difference metric for unit tests.

    DeepGEMM kernels on Blackwell/B200 currently show noticeable per-element
    error, so `torch.testing.assert_close` fails. Instead of elementwise
    checks we compute a cosine-style similarity over the full tensors and
    report `1 - sim` (0 means identical). Remove once kernel accuracy
    improves.
    """
    xd = x.double()
    yd = y.double()
    denominator = (xd * xd + yd * yd).sum()
    similarity = 2 * (xd * yd).sum() / denominator
    return 1 - similarity
|
||||
|
||||
|
||||
def should_use_deepgemm_for_fp8_linear(
|
||||
output_dtype: torch.dtype,
|
||||
weight: torch.Tensor,
|
||||
supports_deep_gemm: bool | None = None,
|
||||
):
|
||||
if supports_deep_gemm is None:
|
||||
supports_deep_gemm = is_deep_gemm_supported()
|
||||
return (
|
||||
supports_deep_gemm
|
||||
and output_dtype == torch.bfloat16
|
||||
and weight.shape[0] % 128 == 0
|
||||
and weight.shape[1] % 128 == 0
|
||||
)
|
||||
|
||||
|
||||
# Public DeepGEMM compatibility API re-exported by this module.
__all__ = [
    "calc_diff",
    "fp8_gemm_nt",
    "m_grouped_fp8_gemm_nt_contiguous",
    "fp8_m_grouped_gemm_nt_masked",
    "fp8_mqa_logits",
    "fp8_paged_mqa_logits",
    "get_paged_mqa_logits_metadata",
    "per_block_cast_to_fp8",
    "is_deep_gemm_e8m0_used",
    "is_deep_gemm_supported",
    "get_num_sms",
    "should_use_deepgemm_for_fp8_linear",
    "get_col_major_tma_aligned_tensor",
    "get_mk_alignment_for_contiguous_layout",
]
|
||||
490
utils/flashinfer.py
Normal file
490
utils/flashinfer.py
Normal file
@@ -0,0 +1,490 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""Compatibility wrapper for FlashInfer API changes.
|
||||
|
||||
Users of vLLM should always import **only** these wrappers.
|
||||
"""
|
||||
|
||||
import contextlib
|
||||
import functools
|
||||
import importlib
|
||||
import importlib.util
|
||||
import os
|
||||
import shutil
|
||||
from collections.abc import Callable
|
||||
from typing import Any, NoReturn
|
||||
|
||||
import requests
|
||||
import torch
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.batch_invariant import (
|
||||
vllm_is_batch_invariant,
|
||||
)
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
# This is the storage path for the cubins; it can be replaced
# with a local path for testing. The env var overrides the default NVIDIA
# artifactory URL.
# Referenced from https://github.com/flashinfer-ai/flashinfer/blob/0c9a92c3d9a7e043ab6f3f7b2273269caf6ab044/flashinfer/jit/cubin_loader.py#L35 # noqa: E501
FLASHINFER_CUBINS_REPOSITORY = os.environ.get(
    "FLASHINFER_CUBINS_REPOSITORY",
    "https://edge.urm.nvidia.com/artifactory/sw-kernelinferencelibrary-public-generic-local/",  # noqa: E501
)
|
||||
|
||||
|
||||
@functools.cache
def has_flashinfer_cubin() -> bool:
    """Return `True` if flashinfer-cubin package is available."""
    if envs.VLLM_HAS_FLASHINFER_CUBIN:
        return True
    found = importlib.util.find_spec("flashinfer_cubin") is not None
    if not found:
        logger.debug_once("flashinfer-cubin package was not found")
    return found
|
||||
|
||||
|
||||
@functools.cache
def has_flashinfer() -> bool:
    """Return `True` if flashinfer-python package is available."""
    # find_spec avoids importing the package, sidestepping possible CUDA
    # initialization side effects.
    spec = importlib.util.find_spec("flashinfer")
    if spec is None:
        logger.debug_once("FlashInfer unavailable since package was not found")
        return False
    # Without pre-downloaded cubins, FlashInfer JIT-compiles kernels and
    # therefore needs nvcc on the PATH.
    if not has_flashinfer_cubin() and shutil.which("nvcc") is None:
        logger.debug_once(
            "FlashInfer unavailable since nvcc was not found "
            "and not using pre-downloaded cubins"
        )
        return False
    return True
|
||||
|
||||
|
||||
def _missing(*_: Any, **__: Any) -> NoReturn:
|
||||
"""Placeholder for unavailable FlashInfer backend."""
|
||||
raise RuntimeError(
|
||||
"FlashInfer backend is not available. Please install the package "
|
||||
"to enable FlashInfer kernels: "
|
||||
"https://github.com/flashinfer-ai/flashinfer"
|
||||
)
|
||||
|
||||
|
||||
def _get_submodule(module_name: str) -> Any | None:
|
||||
"""Safely import a submodule and return it, or None if not available."""
|
||||
try:
|
||||
return importlib.import_module(module_name)
|
||||
except (ImportError, ModuleNotFoundError):
|
||||
return None
|
||||
|
||||
|
||||
# General lazy import wrapper
def _lazy_import_wrapper(
    module_name: str, attr_name: str, fallback_fn: Callable[..., Any] = _missing
):
    """Create a lazy import wrapper for a specific function.

    The target attribute is resolved on first call and cached; when
    FlashInfer or the attribute is unavailable, `fallback_fn` is invoked
    with the same arguments instead.
    """

    @functools.cache
    def _resolve():
        # Resolution happens once; None means "use the fallback".
        if not has_flashinfer():
            return None
        module = _get_submodule(module_name)
        if module is None:
            return None
        return getattr(module, attr_name, None)

    def wrapper(*args, **kwargs):
        target = _resolve()
        fn = fallback_fn if target is None else target
        return fn(*args, **kwargs)

    return wrapper
|
||||
|
||||
|
||||
# Lazy wrappers for each FlashInfer entry point: the real function is only
# imported on first call, and a RuntimeError-raising fallback is used when
# FlashInfer is not installed.
flashinfer_trtllm_fp8_block_scale_moe = _lazy_import_wrapper(
    "flashinfer.fused_moe", "trtllm_fp8_block_scale_moe"
)
flashinfer_trtllm_fp8_per_tensor_scale_moe = _lazy_import_wrapper(
    "flashinfer.fused_moe", "trtllm_fp8_per_tensor_scale_moe"
)
flashinfer_cutlass_fused_moe = _lazy_import_wrapper(
    "flashinfer.fused_moe", "cutlass_fused_moe"
)
flashinfer_fp4_quantize = _lazy_import_wrapper("flashinfer", "fp4_quantize")
nvfp4_block_scale_interleave = _lazy_import_wrapper(
    "flashinfer", "nvfp4_block_scale_interleave"
)
trtllm_fp4_block_scale_moe = _lazy_import_wrapper(
    "flashinfer", "trtllm_fp4_block_scale_moe"
)

# Special case for autotune since it returns a context manager: the fallback
# must also be a context manager, hence nullcontext instead of _missing.
autotune = _lazy_import_wrapper(
    "flashinfer.autotuner",
    "autotune",
    fallback_fn=lambda *args, **kwargs: contextlib.nullcontext(),
)
|
||||
|
||||
|
||||
@functools.cache
def has_flashinfer_comm() -> bool:
    """Return `True` if FlashInfer comm module is available."""
    # find_spec avoids importing flashinfer.comm (and its side effects).
    return has_flashinfer() and importlib.util.find_spec("flashinfer.comm") is not None
|
||||
|
||||
|
||||
@functools.cache
def has_flashinfer_all2all() -> bool:
    """Return `True` if FlashInfer mnnvl all2all is available."""
    if not has_flashinfer_comm():
        return False

    # Every symbol below must resolve for mnnvl all2all to be usable.
    required_symbols = [
        ("flashinfer.comm", "Mapping"),
        ("flashinfer.comm.mnnvl", "MnnvlMemory"),
        ("flashinfer.comm.trtllm_alltoall", "MnnvlMoe"),
        ("flashinfer.comm.trtllm_alltoall", "MoEAlltoallInfo"),
    ]
    return all(
        (mod := _get_submodule(module_name)) is not None and hasattr(mod, attr_name)
        for module_name, attr_name in required_symbols
    )
|
||||
|
||||
|
||||
@functools.cache
def has_flashinfer_moe() -> bool:
    """Return `True` if FlashInfer MoE module is available."""
    # find_spec avoids importing flashinfer.fused_moe (and its side effects).
    return (
        has_flashinfer()
        and importlib.util.find_spec("flashinfer.fused_moe") is not None
    )
|
||||
|
||||
|
||||
@functools.cache
def has_flashinfer_cutlass_fused_moe() -> bool:
    """Return `True` if FlashInfer CUTLASS fused MoE is available."""
    if not has_flashinfer_moe():
        return False

    # All of these symbols must be present for the CUTLASS MoE path to work.
    required_symbols = [
        ("flashinfer.fused_moe", "cutlass_fused_moe"),
        ("flashinfer", "fp4_quantize"),
        ("flashinfer", "nvfp4_block_scale_interleave"),
        ("flashinfer.fused_moe", "trtllm_fp4_block_scale_moe"),
    ]
    return all(
        (mod := _get_submodule(module_name)) is not None and hasattr(mod, attr_name)
        for module_name, attr_name in required_symbols
    )
|
||||
|
||||
|
||||
@functools.cache
def has_nvidia_artifactory() -> bool:
    """Return `True` if NVIDIA's artifactory is accessible.

    This checks connectivity to the kernel inference library artifactory
    which is required for downloading certain cubin kernels like TRTLLM FHMA.
    """
    # If we have pre-downloaded cubins, we can assume the cubins are available
    # without touching the network.
    if has_flashinfer_cubin():
        return True

    try:
        # Use a short timeout to avoid blocking for too long
        response = requests.get(FLASHINFER_CUBINS_REPOSITORY, timeout=5)
        accessible = response.status_code == 200
        if accessible:
            logger.debug_once("NVIDIA artifactory is accessible")
        else:
            logger.warning_once(
                "NVIDIA artifactory returned failed status code: %d",
                response.status_code,
            )
        return accessible
    except Exception as e:
        # Any network failure (DNS, timeout, proxy, TLS) is treated as
        # "not accessible" rather than propagated to the caller.
        logger.warning_once("Failed to connect to NVIDIA artifactory: %s", e)
        return False
|
||||
|
||||
|
||||
@functools.cache
def supports_trtllm_attention() -> bool:
    """
    TRTLLM attention is supported if the platform is SM100,
    NVIDIA artifactory is accessible, and batch-invariant mode is not enabled.
    """
    # Batch-invariant mode disables TRTLLM attention
    if vllm_is_batch_invariant():
        return False

    # Requires SM100 and NVIDIA artifactory to be accessible to download cubins
    return current_platform.is_device_capability(100) and has_nvidia_artifactory()
|
||||
|
||||
|
||||
@functools.cache
def _force_use_trtllm_attention(env_value: bool | None) -> bool | None:
    """Cache the env value for VLLM_USE_TRTLLM_ATTENTION.

    Caching on the value (instead of reading the env var here) means the
    informational log below is emitted at most once per distinct value.
    """
    if env_value is not None:
        logger.info_once("VLLM_USE_TRTLLM_ATTENTION is set to %s", env_value)
    return env_value
|
||||
|
||||
|
||||
def force_use_trtllm_attention() -> bool | None:
    """
    Return `None` if VLLM_USE_TRTLLM_ATTENTION is not set,
    return `True` if TRTLLM attention is forced to be used,
    return `False` if TRTLLM attention is forced to be not used.
    """
    # Delegate to the cached helper so logging happens at most once per value.
    return _force_use_trtllm_attention(envs.VLLM_USE_TRTLLM_ATTENTION)
|
||||
|
||||
|
||||
def can_use_trtllm_attention(num_qo_heads: int, num_kv_heads: int) -> bool:
    """Check if the current configuration supports TRTLLM attention."""
    # Explicitly disabled via VLLM_USE_TRTLLM_ATTENTION=0.
    if force_use_trtllm_attention() is False:
        return False
    # Otherwise the platform must support TRTLLM and the query heads must be
    # an exact multiple of the KV heads.
    return supports_trtllm_attention() and num_qo_heads % num_kv_heads == 0
|
||||
|
||||
|
||||
def use_trtllm_attention(
    num_qo_heads: int,
    num_kv_heads: int,
    num_tokens: int,
    max_seq_len: int,
    dcp_world_size: int,
    kv_cache_dtype: str,
    q_dtype: torch.dtype,
    is_prefill: bool,
    has_sinks: bool = False,
    has_spec: bool = False,
) -> bool:
    """Return `True` if TRTLLM attention is used.

    Decision order: env override (0) > DCP restriction > platform/head-ratio
    support > features that require TRTLLM (spec decode, FP8 query, sinks) >
    auto-detection > env override (1).

    Args:
        num_qo_heads: Number of query heads; must be a multiple of
            `num_kv_heads` for TRTLLM.
        num_kv_heads: Number of key/value heads.
        num_tokens: Batch token count; decode auto-detection only picks
            TRTLLM for small batches (<= 256 tokens).
        max_seq_len: NOTE(review): currently unused by this heuristic; kept
            for interface stability with callers.
        dcp_world_size: Decode-context-parallel world size; TRTLLM is
            disabled whenever it is > 1.
        kv_cache_dtype: KV cache dtype string; auto-detection requires
            "auto".
        q_dtype: Query dtype; an FP8 query forces TRTLLM.
        is_prefill: Whether this call is for the prefill phase.
        has_sinks: Attention sinks require TRTLLM.
        has_spec: Speculative decoding requires TRTLLM for decodes.
    """
    force_use_trtllm = force_use_trtllm_attention()

    # Environment variable is set to 0 - respect it
    if force_use_trtllm is not None and not force_use_trtllm:
        return False

    # Decode context parallel is not supported
    if dcp_world_size > 1:
        logger.warning_once(
            "Trtllm does not support returning LSE and as a result "
            "does not support DCP, reverting to FlashInfer"
        )
        return False

    # The platform is not supported
    if not supports_trtllm_attention():
        if force_use_trtllm:
            logger.warning_once(
                "TRTLLM attention is not supported on this platform, "
                "but VLLM_USE_TRTLLM_ATTENTION is set to 1"
            )
        return False

    # The combination of query and key heads is not supported
    if num_qo_heads % num_kv_heads != 0:
        if force_use_trtllm:
            logger.warning_once(
                "TRTLLM attention is not supported for this combination of "
                "query and key heads, but VLLM_USE_TRTLLM_ATTENTION is set to 1"
            )
        return False

    if has_spec and not is_prefill:
        # Speculative decoding requires TRTLLM attention for decodes
        logger.info_once("Using TRTLLM attention (enabled for speculative decoding).")
        return True

    # Must use TRTLLM attention if query is FP8 quantized
    if q_dtype == current_platform.fp8_dtype():
        logger.info_once("Using TRTLLM attention (query is quantized).")
        return True

    # If sinks are being used, we must use TRTLLM attention as it's
    # the only backend that supports them
    if has_sinks:
        logger.info_once("Using TRTLLM attention (required for attention sinks).")
        return True

    if force_use_trtllm is None:
        # Environment variable not set - use auto-detection
        if is_prefill:
            # Prefill auto-detection
            use_trtllm = kv_cache_dtype == "auto"
            if use_trtllm:
                logger.warning_once("Using TRTLLM prefill attention (auto-detected).")
        else:
            # Decode auto-detection
            use_trtllm = num_tokens <= 256 and kv_cache_dtype == "auto"
            if use_trtllm:
                logger.warning_once("Using TRTLLM decode attention (auto-detected).")
        return use_trtllm

    # Environment variable is set to 1 - respect it
    logger.info_once("Using TRTLLM attention (VLLM_USE_TRTLLM_ATTENTION is set to 1)")
    return True
|
||||
|
||||
|
||||
if has_flashinfer():
    # Register FlashInfer GEMMs as torch custom ops so they compose with
    # torch.compile. Each op gets a "fake" (meta) implementation that only
    # computes output shapes, used during tracing.

    @torch.library.custom_op(
        "vllm::flashinfer_mm_fp4",
        mutates_args=[],
        device_types="cuda",
    )
    def flashinfer_mm_fp4(
        A: torch.Tensor,
        B: torch.Tensor,
        A_scale: torch.Tensor,
        B_scale: torch.Tensor,
        g_scale: torch.Tensor,
        dtype: torch.dtype,
        backend: str,
    ) -> torch.Tensor:
        # Imported lazily: this branch only runs when FlashInfer is present.
        from flashinfer import mm_fp4 as flashinfer_mm_fp4_

        return flashinfer_mm_fp4_(
            A, B, A_scale, B_scale, g_scale, dtype, block_size=16, backend=backend
        )

    @torch.library.register_fake(
        "vllm::flashinfer_mm_fp4",
    )
    def flashinfer_mm_fp4_fake(
        A: torch.Tensor,
        B: torch.Tensor,
        A_scale: torch.Tensor,
        B_scale: torch.Tensor,
        g_scale: torch.Tensor,
        dtype: torch.dtype,
        backend: str,
    ) -> torch.Tensor:
        # Shape-only implementation: (M, K) x (K, N) -> (M, N).
        return torch.empty(A.shape[0], B.shape[1], dtype=dtype, device=A.device)

    @torch.library.custom_op(
        "vllm::bmm_fp8",
        mutates_args=[],
        device_types="cuda",
    )
    def bmm_fp8(
        A: torch.Tensor,
        B: torch.Tensor,
        A_scale: torch.Tensor,
        B_scale: torch.Tensor,
        dtype: torch.dtype,
        backend: str,
    ) -> torch.Tensor:
        from flashinfer import bmm_fp8 as bmm_fp8_

        return bmm_fp8_(A, B, A_scale, B_scale, dtype, None, backend)

    @torch.library.register_fake(
        "vllm::bmm_fp8",
    )
    def bmm_fp8_fake(
        A: torch.Tensor,
        B: torch.Tensor,
        A_scale: torch.Tensor,
        B_scale: torch.Tensor,
        dtype: torch.dtype,
        backend: str,
    ) -> torch.Tensor:
        # Shape-only implementation: (B, M, K) x (B, K, N) -> (B, M, N).
        return torch.empty(
            A.shape[0], A.shape[1], B.shape[2], dtype=dtype, device=A.device
        )
|
||||
|
||||
|
||||
def flashinfer_scaled_fp4_mm(
    a: torch.Tensor,
    b: torch.Tensor,
    block_scale_a: torch.Tensor,
    block_scale_b: torch.Tensor,
    alpha: torch.Tensor,
    out_dtype: torch.dtype,
    backend: str,
) -> torch.Tensor:
    """FP4 GEMM via the FlashInfer `mm_fp4` custom op.

    Inputs are 2-D with matching inner dimension and contiguous rows; for the
    cutlass backend the block scales are reinterpreted as raw uint8 bytes.
    """
    assert a.ndim == 2 and b.ndim == 2
    assert block_scale_a.ndim == 2 and block_scale_b.ndim == 2
    assert a.stride(-1) == 1 and b.stride(-1) == 1
    assert a.shape[1] == b.shape[1]

    scale_a, scale_b = block_scale_a, block_scale_b
    if backend == "cutlass":
        # cutlass expects the scale tensors as raw bytes.
        scale_a = scale_a.view(torch.uint8)
        scale_b = scale_b.view(torch.uint8)

    return flashinfer_mm_fp4(
        a, b.t(), scale_a, scale_b.t(), alpha, out_dtype, backend=backend
    )
|
||||
|
||||
|
||||
def flashinfer_scaled_fp8_mm(
    a: torch.Tensor,
    b: torch.Tensor,
    scale_a: torch.Tensor,
    scale_b: torch.Tensor,
    out_dtype: torch.dtype,
    bias: torch.Tensor | None = None,
) -> torch.Tensor:
    """FP8 GEMM via the FlashInfer `bmm_fp8` custom op, with optional bias.

    Expects 2-D float8_e4m3fn operands on CUDA with scalar float32 scales.
    """
    assert a.ndim == 2 and b.ndim == 2
    assert a.shape[1] == b.shape[0]
    assert scale_a.numel() == 1 and scale_b.numel() == 1
    assert a.dtype == torch.float8_e4m3fn and b.dtype == torch.float8_e4m3fn
    assert a.device.type == "cuda" and b.device.type == "cuda"
    assert scale_a.dtype == torch.float32 and scale_b.dtype == torch.float32
    assert scale_a.device.type == "cuda" and scale_b.device.type == "cuda"

    # bmm_fp8 is batched: add a unit batch dim, then strip it from the result.
    result = bmm_fp8(
        a.unsqueeze(0),
        b.unsqueeze(0),
        scale_a,
        scale_b,
        out_dtype,
        "auto",
    ).view(a.shape[0], b.shape[1])

    return result if bias is None else result + bias
|
||||
|
||||
|
||||
@functools.cache
def flashinfer_disable_q_quantization() -> bool:
    """Cache result which only depends on the environment"""
    # Cached because the env value cannot change within a process lifetime.
    return envs.VLLM_FLASHINFER_DISABLE_Q_QUANTIZATION
|
||||
|
||||
|
||||
# Public API of this module. Kept in sync with the lazy wrappers and helper
# functions defined above.
__all__ = [
    "has_flashinfer",
    "flashinfer_trtllm_fp8_block_scale_moe",
    # Previously defined above but accidentally omitted from the exports.
    "flashinfer_trtllm_fp8_per_tensor_scale_moe",
    "flashinfer_cutlass_fused_moe",
    "flashinfer_fp4_quantize",
    "nvfp4_block_scale_interleave",
    "trtllm_fp4_block_scale_moe",
    "autotune",
    "has_flashinfer_moe",
    "has_flashinfer_comm",
    "has_flashinfer_all2all",
    "has_flashinfer_cutlass_fused_moe",
    "has_nvidia_artifactory",
    "supports_trtllm_attention",
    "can_use_trtllm_attention",
    "use_trtllm_attention",
    "flashinfer_disable_q_quantization",
    "flashinfer_scaled_fp4_mm",
    "flashinfer_scaled_fp8_mm",
]
|
||||
236
utils/func_utils.py
Normal file
236
utils/func_utils.py
Normal file
@@ -0,0 +1,236 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""
|
||||
Contains helpers that are applied to functions.
|
||||
|
||||
This is similar in concept to the `functools` module.
|
||||
"""
|
||||
|
||||
import inspect
|
||||
import threading
|
||||
import warnings
|
||||
from collections.abc import Callable, Mapping
|
||||
from functools import lru_cache, partial, wraps
|
||||
from typing import Any, TypeVar
|
||||
|
||||
from typing_extensions import ParamSpec
|
||||
|
||||
from vllm.logger import init_logger
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
# Shared typing helpers for the decorators below.
P = ParamSpec("P")  # parameters of a wrapped callable
T = TypeVar("T")  # generic value type
F = TypeVar("F", bound=Callable[..., Any])  # decorated-function type
|
||||
|
||||
|
||||
def identity(value: T, **kwargs) -> T:
    """Returns the first provided value.

    Extra keyword arguments are accepted and ignored so this can stand in
    for callables with richer signatures (e.g. as a constant predicate).
    """
    return value
|
||||
|
||||
|
||||
def run_once(f: Callable[P, None]) -> Callable[P, None]:
    """Wrap `f` so that only the first call executes; later calls are no-ops.

    Thread-safe via double-checked locking: the `has_run` flag is re-checked
    under a lock so concurrent first calls run `f` exactly once. State lives
    as attributes on the wrapper, which callers can inspect (or reset).
    """

    def wrapper(*args: P.args, **kwargs: P.kwargs) -> None:
        # Fast path: avoid the lock once the function has already run.
        if wrapper.has_run:  # type: ignore[attr-defined]
            return

        with wrapper.lock:  # type: ignore[attr-defined]
            # Re-check under the lock: another thread may have won the race.
            if not wrapper.has_run:  # type: ignore[attr-defined]
                wrapper.has_run = True  # type: ignore[attr-defined]
                return f(*args, **kwargs)

    wrapper.has_run = False  # type: ignore[attr-defined]
    wrapper.lock = threading.Lock()  # type: ignore[attr-defined]
    return wrapper
|
||||
|
||||
|
||||
def deprecate_args(
    start_index: int,
    is_deprecated: bool | Callable[[], bool] = True,
    additional_message: str | None = None,
) -> Callable[[F], F]:
    """Decorator factory that warns when parameters at or after
    ``start_index`` are supplied positionally.

    `is_deprecated` may be a bool or a zero-arg callable evaluated per call.
    """
    deprecation_check = (
        is_deprecated if callable(is_deprecated) else partial(identity, is_deprecated)
    )

    def wrapper(fn: F) -> F:
        sig_params = inspect.signature(fn).parameters
        positional_kinds = (
            inspect.Parameter.POSITIONAL_ONLY,
            inspect.Parameter.POSITIONAL_OR_KEYWORD,
        )
        positional_names = [
            name for name, param in sig_params.items() if param.kind in positional_kinds
        ]

        @wraps(fn)
        def inner(*args, **kwargs):
            if deprecation_check():
                # Names of deprecated parameters actually passed positionally.
                affected = positional_names[start_index : len(args)]
                if affected:
                    msg = (
                        f"The positional arguments {affected} are "
                        "deprecated and will be removed in a future update."
                    )
                    if additional_message is not None:
                        msg += f" {additional_message}"

                    warnings.warn(
                        DeprecationWarning(msg),
                        stacklevel=3,  # The inner function takes up one level
                    )

            return fn(*args, **kwargs)

        return inner  # type: ignore

    return wrapper
|
||||
|
||||
|
||||
def deprecate_kwargs(
    *kws: str,
    is_deprecated: bool | Callable[[], bool] = True,
    additional_message: str | None = None,
) -> Callable[[F], F]:
    """Decorator factory that warns when any of the named keyword arguments
    is passed while `is_deprecated` evaluates truthy.
    """
    deprecated_kws = set(kws)

    deprecation_check = (
        is_deprecated if callable(is_deprecated) else partial(identity, is_deprecated)
    )

    def wrapper(fn: F) -> F:
        @wraps(fn)
        def inner(*args, **kwargs):
            if deprecation_check():
                # Intersection of passed kwargs with the deprecated set.
                affected = kwargs.keys() & deprecated_kws
                if affected:
                    msg = (
                        f"The keyword arguments {affected} are "
                        "deprecated and will be removed in a future update."
                    )
                    if additional_message is not None:
                        msg += f" {additional_message}"

                    warnings.warn(
                        DeprecationWarning(msg),
                        stacklevel=3,  # The inner function takes up one level
                    )

            return fn(*args, **kwargs)

        return inner  # type: ignore

    return wrapper
|
||||
|
||||
|
||||
@lru_cache
def supports_kw(
    callable: Callable[..., object],
    kw_name: str,
    *,
    requires_kw_only: bool = False,
    allow_var_kwargs: bool = True,
) -> bool:
    """Check if a keyword is a valid kwarg for a callable; if requires_kw_only
    disallows kwargs names that can also be positional arguments.
    """
    params = inspect.signature(callable).parameters
    if not params:
        return False

    param = params.get(kw_name)
    if param is not None:
        # Explicitly declared, non-variadic parameter kinds.
        named_kinds = {
            inspect.Parameter.POSITIONAL_ONLY,
            inspect.Parameter.POSITIONAL_OR_KEYWORD,
            inspect.Parameter.KEYWORD_ONLY,
        }
        if param.kind in named_kinds:
            # When keyword-only is required, reject names that could also be
            # given positionally.
            if requires_kw_only:
                return param.kind == inspect.Parameter.KEYWORD_ONLY
            return True
        # kw_name names a *args/**kwargs parameter itself: fall through.

    # A trailing **kwargs can absorb the keyword, provided kw_name is not the
    # name of the variadic parameter itself. params preserves declaration
    # order (it wraps an ordered mapping), so the last entry is the candidate.
    # Ref: https://docs.python.org/3/library/inspect.html#inspect.Signature.parameters
    if allow_var_kwargs:
        last_param = params[next(reversed(params))]  # type: ignore
        return (
            last_param.kind == inspect.Parameter.VAR_KEYWORD
            and last_param.name != kw_name
        )

    return False
|
||||
|
||||
|
||||
def get_allowed_kwarg_only_overrides(
    callable: Callable[..., object],
    overrides: Mapping[str, object] | None,
    *,
    requires_kw_only: bool = True,
    allow_var_kwargs: bool = False,
) -> dict[str, Any]:
    """
    Given a callable which has one or more keyword only params and a dict
    mapping param names to values, drop values that can be not be kwarg
    expanded to overwrite one or more keyword-only args. This is used in a
    few places to handle custom processor overrides for multimodal models,
    e.g., for profiling when processor options provided by the user
    may affect the number of mm tokens per instance.

    Args:
        callable: Callable which takes 0 or more keyword only arguments.
            If None is provided, all overrides names are allowed.
        overrides: Potential overrides to be used when invoking the callable.
        allow_var_kwargs: Allows overrides that are expandable for var kwargs.

    Returns:
        Dictionary containing the kwargs to be leveraged which may be used
        to overwrite one or more keyword only arguments when invoking the
        callable.
    """
    if not overrides:
        return {}

    # Keep only overrides that the callable can actually accept as kwargs.
    filtered_overrides = {
        name: value
        for name, value in overrides.items()
        if supports_kw(
            callable,
            name,
            requires_kw_only=requires_kw_only,
            allow_var_kwargs=allow_var_kwargs,
        )
    }

    # Anything dropped is surfaced with a warning so users notice typos.
    dropped_keys = overrides.keys() - filtered_overrides.keys()
    if dropped_keys:
        if requires_kw_only:
            logger.warning(
                "The following intended overrides are not keyword-only args "
                "and will be dropped: %s",
                dropped_keys,
            )
        else:
            logger.warning(
                "The following intended overrides are not keyword args "
                "and will be dropped: %s",
                dropped_keys,
            )

    return filtered_overrides
|
||||
147
utils/gc_utils.py
Normal file
147
utils/gc_utils.py
Normal file
@@ -0,0 +1,147 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import gc
|
||||
import json
|
||||
import time
|
||||
from collections import Counter
|
||||
from contextlib import suppress
|
||||
from typing import Any
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.logger import init_logger
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class GCDebugConfig:
    """
    Config for GC Debugger.
    - 0: disable GC debugger
    - 1: enable GC debugger with gc.collect elapsed times
    - '{"top_objects":5}': enable GC debugger with top 5 collected objects
    """

    def __init__(self, gc_debug_conf: str | None = None) -> None:
        # Whether GC debug callbacks should be installed at all.
        self.enabled: bool = False
        # How many collected object types to report per GC (-1 disables).
        self.top_objects: int = -1

        if not gc_debug_conf or gc_debug_conf == "0":
            pass
        elif gc_debug_conf == "1":
            self.enabled = True
        else:
            try:
                json_conf = json.loads(gc_debug_conf)
                self.enabled = True
                self.top_objects = json_conf.get("top_objects", -1)
            except Exception:
                # Malformed config disables debugging rather than crashing.
                self.enabled = False
                # Log the value that actually failed to parse (the argument),
                # not a fresh read of the environment variable, which may
                # differ from what the caller passed in.
                logger.error("Failed to parse VLLM_GC_DEBUG(%s)", gc_debug_conf)
        logger.debug("GC Debug Config. %s", str(self))

    def __repr__(self) -> str:
        return f"enabled:{self.enabled},top_objects:{self.top_objects}"
|
||||
|
||||
|
||||
class GCDebugger:
    """
    Debugger for GC which logs helpful information for GC understanding.
    To enable, you should call maybe_attach_gc_debug_callback in the process.
    """

    def __init__(self, config: GCDebugConfig) -> None:
        self.config = config
        # Monotonic timestamp (nanoseconds) of the start of the current GC
        # cycle.
        self.start_time_ns: int = time.monotonic_ns()
        # Summary of the top collected object types for the current cycle
        # (empty when config.top_objects <= 0).
        self.gc_top_collected_objects: str = ""

    def handle(self, phase: str, info: dict[str, int]) -> None:
        """
        Handles a GC event (e.g. GC start or GC finish)
        """
        generation = info.get("generation")
        if generation is None:
            return
        if phase == "start":
            # Record when the cycle began and, optionally, which object types
            # are collection candidates in this generation.
            self.start_time_ns = time.monotonic_ns()
            self.gc_top_collected_objects = _compute_top_gc_collected_objects(
                gc.get_objects(generation), self.config.top_objects
            )
        elif phase == "stop":
            # Cycle finished: report its duration and what was collected.
            elapsed_ms = (time.monotonic_ns() - self.start_time_ns) / 1e6
            top_objects_suffix = (
                f" Top collected objects: \n{self.gc_top_collected_objects}"
                if self.gc_top_collected_objects
                else ""
            )
            logger.info(
                "GC took %.3fms to complete. "
                "Collected %s objects in GC generation %d.%s",
                elapsed_ms,
                str(info.get("collected", "?")),
                generation,
                top_objects_suffix,
            )
|
||||
|
||||
|
||||
def freeze_gc_heap() -> None:
    """
    Freeze all objects tracked by the garbage collector. It should be invoked
    after server init / warmup, to reduce GC overhead from static objects
    during serving time.
    """
    # Collect each generation in order so surviving static objects all end up
    # in the oldest generation before freezing.
    for generation in range(3):
        gc.collect(generation)
    # Move everything currently tracked into the permanent (frozen) set.
    gc.freeze()
|
||||
|
||||
|
||||
def maybe_attach_gc_debug_callback() -> None:
    """
    Attach a callback for GC debug when VLLM_GC_DEBUG is enabled.
    """
    config = GCDebugConfig(envs.VLLM_GC_DEBUG)
    if config.enabled:
        debugger: GCDebugger = GCDebugger(config)

        # The closure keeps the debugger alive for the life of the process.
        def gc_callback(phase: str, info: dict[str, int]) -> None:
            debugger.handle(phase, info)

        gc.callbacks.append(gc_callback)
|
||||
|
||||
|
||||
def _compute_detailed_type(o: Any) -> str:
|
||||
"""
|
||||
Detailed object type.
|
||||
|
||||
TODO(Jialin): Further enhance the detailed type with element types for
|
||||
easier debugging. We tried but occasionally it would run into signals
|
||||
which kills the engine.
|
||||
"""
|
||||
size_str: str = ""
|
||||
# Object doesn't support len() - this can happen with type objects
|
||||
# or other objects that don't implement __len__ properly
|
||||
with suppress(Exception):
|
||||
size_str = f"(size:{len(o)})"
|
||||
return f"{str(type(o))}{size_str}"
|
||||
|
||||
|
||||
def _compute_top_gc_collected_objects(objects: list[Any], top: int) -> str:
    """
    Group collected objects by types.
    """
    if top <= 0:
        # Reporting disabled: skip the (potentially large) type scan.
        return ""
    type_counts = Counter(_compute_detailed_type(o) for o in objects)
    lines = [
        f"{count:>5}:{object_type}"
        for object_type, count in type_counts.most_common(top)
    ]
    return "\n".join(lines)
|
||||
63
utils/hashing.py
Normal file
63
utils/hashing.py
Normal file
@@ -0,0 +1,63 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import hashlib
|
||||
import pickle
|
||||
from collections.abc import Callable
|
||||
from typing import Any
|
||||
|
||||
import cbor2
|
||||
|
||||
|
||||
def sha256(input: Any) -> bytes:
    """Hash any picklable Python object using SHA-256.

    The object is serialized with pickle (highest protocol) first, so any
    picklable value is accepted. No hash seed is applied; callers needing one
    should prepend it explicitly to the input.

    Args:
        input: Any picklable Python object.

    Returns:
        Bytes representing the SHA-256 hash of the serialized input.
    """
    serialized = pickle.dumps(input, protocol=pickle.HIGHEST_PROTOCOL)
    return hashlib.sha256(serialized).digest()
|
||||
|
||||
|
||||
def sha256_cbor(input: Any) -> bytes:
    """Hash objects using CBOR serialization and SHA-256.

    This option is useful for non-Python-dependent serialization and hashing
    (canonical CBOR encoding is deterministic across languages).

    Args:
        input: Object to be serialized and hashed. Supported types include
            basic Python types and complex structures like lists, tuples, and
            dictionaries.
            Custom classes must implement CBOR serialization methods.

    Returns:
        Bytes representing the SHA-256 hash of the CBOR serialized input.
    """
    # canonical=True fixes map key ordering so equal objects hash equally.
    input_bytes = cbor2.dumps(input, canonical=True)
    return hashlib.sha256(input_bytes).digest()
|
||||
|
||||
|
||||
def get_hash_fn_by_name(hash_fn_name: str) -> Callable[[Any], bytes]:
    """Get a hash function by name, or raise an error if the function is not found.

    Args:
        hash_fn_name: Name of the hash function ("sha256" or "sha256_cbor").

    Returns:
        A hash function.
    """
    hash_fn = {
        "sha256": sha256,
        "sha256_cbor": sha256_cbor,
    }.get(hash_fn_name)
    if hash_fn is None:
        raise ValueError(f"Unsupported hash function: {hash_fn_name}")
    return hash_fn
|
||||
411
utils/import_utils.py
Normal file
411
utils/import_utils.py
Normal file
@@ -0,0 +1,411 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""
|
||||
Contains helpers related to importing modules.
|
||||
|
||||
This is similar in concept to the `importlib` module.
|
||||
"""
|
||||
|
||||
import importlib.metadata
|
||||
import importlib.util
|
||||
import os
|
||||
import sys
|
||||
from functools import cache
|
||||
from types import ModuleType
|
||||
from typing import Any
|
||||
|
||||
import regex as re
|
||||
from typing_extensions import Never
|
||||
|
||||
|
||||
# TODO: This function can be removed if transformer_modules classes are
|
||||
# serialized by value when communicating between processes
|
||||
# TODO: This function can be removed if transformer_modules classes are
# serialized by value when communicating between processes
def init_cached_hf_modules() -> None:
    """
    Lazy initialization of the Hugging Face modules.
    """
    # Imported lazily so transformers is only touched when actually needed.
    from transformers.dynamic_module_utils import init_hf_modules

    init_hf_modules()
|
||||
|
||||
|
||||
def import_pynvml():
    """Return vLLM's vendored copy of the official `pynvml` module.

    Historical comments:

    libnvml.so is the library behind nvidia-smi, and
    pynvml is a Python wrapper around it. We use it to get GPU
    status without initializing CUDA context in the current process.
    Historically, there are two packages that provide pynvml:
    - `nvidia-ml-py` (https://pypi.org/project/nvidia-ml-py/): The official
        wrapper. It is a dependency of vLLM, and is installed when users
        install vLLM. It provides a Python module named `pynvml`.
    - `pynvml` (https://pypi.org/project/pynvml/): An unofficial wrapper.
        Prior to version 12.0, it also provides a Python module `pynvml`,
        and therefore conflicts with the official one. What's worse,
        the module is a Python package, and has higher priority than
        the official one which is a standalone Python file.
        This causes errors when both of them are installed.
        Starting from version 12.0, it migrates to a new module
        named `pynvml_utils` to avoid the conflict.
    It is so confusing that many packages in the community use the
    unofficial one by mistake, and we have to handle this case.
    For example, `nvcr.io/nvidia/pytorch:24.12-py3` uses the unofficial
    one, and it will cause errors, see the issue
    https://github.com/vllm-project/vllm/issues/12847 for example.
    After all the troubles, we decide to copy the official `pynvml`
    module to our codebase, and use it directly.
    """
    # The vendored copy sidesteps any conflicting `pynvml` package on the
    # user's system.
    import vllm.third_party.pynvml as pynvml

    return pynvml
|
||||
|
||||
|
||||
def import_from_path(module_name: str, file_path: str | os.PathLike):
|
||||
"""
|
||||
Import a Python file according to its file path.
|
||||
|
||||
Based on the official recipe:
|
||||
https://docs.python.org/3/library/importlib.html#importing-a-source-file-directly
|
||||
"""
|
||||
spec = importlib.util.spec_from_file_location(module_name, file_path)
|
||||
if spec is None:
|
||||
raise ModuleNotFoundError(f"No module named {module_name!r}")
|
||||
|
||||
assert spec.loader is not None
|
||||
|
||||
module = importlib.util.module_from_spec(spec)
|
||||
sys.modules[module_name] = module
|
||||
spec.loader.exec_module(module)
|
||||
return module
|
||||
|
||||
|
||||
def resolve_obj_by_qualname(qualname: str) -> Any:
    """Resolve an object from its fully-qualified name, e.g. ``"os.path.join"``."""
    # Split only on the last dot: everything before it is the module path,
    # everything after it is the attribute on that module.
    mod_path, attr_name = qualname.rsplit(".", 1)
    return getattr(importlib.import_module(mod_path), attr_name)
|
||||
|
||||
|
||||
@cache
def get_vllm_optional_dependencies():
    """Map each optional-dependency extra of vLLM to its bare package names.

    Parsed once from the installed distribution metadata and memoized.
    """
    pkg_metadata = importlib.metadata.metadata("vllm")
    requirements = pkg_metadata.get_all("Requires-Dist", [])
    extras = pkg_metadata.get_all("Provides-Extra", [])

    deps_by_extra = {}
    for extra in extras:
        marker = f'extra == "{extra}"'
        # Strip environment markers / version pins to keep only the name.
        deps_by_extra[extra] = [
            re.split(r";|>=|<=|==", req)[0]
            for req in requirements
            if req.endswith(marker)
        ]
    return deps_by_extra
|
||||
|
||||
|
||||
class _PlaceholderBase:
    """
    Disallows downstream usage of placeholder modules.

    We need to explicitly override each dunder method because
    [`__getattr__`][vllm.utils.import_utils._PlaceholderBase.__getattr__]
    is not called when they are accessed.

    Info:
        [Special method lookup](https://docs.python.org/3/reference/datamodel.html#special-lookup)
    """

    def __getattr__(self, key: str) -> Never:
        """
        The main class should implement this to throw an error
        for attribute accesses representing downstream usage.
        """
        raise NotImplementedError

    # Every dunder below funnels into __getattr__ so that subclasses
    # raise their informative error instead of silently "working".

    # [Basic customization]

    def __lt__(self, other: object):
        return self.__getattr__("__lt__")

    def __le__(self, other: object):
        return self.__getattr__("__le__")

    def __eq__(self, other: object):
        return self.__getattr__("__eq__")

    def __ne__(self, other: object):
        return self.__getattr__("__ne__")

    def __gt__(self, other: object):
        return self.__getattr__("__gt__")

    def __ge__(self, other: object):
        return self.__getattr__("__ge__")

    def __hash__(self):
        return self.__getattr__("__hash__")

    def __bool__(self):
        return self.__getattr__("__bool__")

    # [Callable objects]

    def __call__(self, *args: object, **kwargs: object):
        return self.__getattr__("__call__")

    # [Container types]

    def __len__(self):
        return self.__getattr__("__len__")

    def __getitem__(self, key: object):
        return self.__getattr__("__getitem__")

    def __setitem__(self, key: object, value: object):
        return self.__getattr__("__setitem__")

    def __delitem__(self, key: object):
        return self.__getattr__("__delitem__")

    # __missing__ is optional according to __getitem__ specification,
    # so it is skipped

    # __iter__ and __reversed__ have a default implementation
    # based on __len__ and __getitem__, so they are skipped.

    # [Numeric Types]

    def __add__(self, other: object):
        return self.__getattr__("__add__")

    def __sub__(self, other: object):
        return self.__getattr__("__sub__")

    def __mul__(self, other: object):
        return self.__getattr__("__mul__")

    def __matmul__(self, other: object):
        return self.__getattr__("__matmul__")

    def __truediv__(self, other: object):
        return self.__getattr__("__truediv__")

    def __floordiv__(self, other: object):
        return self.__getattr__("__floordiv__")

    def __mod__(self, other: object):
        return self.__getattr__("__mod__")

    def __divmod__(self, other: object):
        return self.__getattr__("__divmod__")

    def __pow__(self, other: object, modulo: object = ...):
        return self.__getattr__("__pow__")

    def __lshift__(self, other: object):
        return self.__getattr__("__lshift__")

    def __rshift__(self, other: object):
        return self.__getattr__("__rshift__")

    def __and__(self, other: object):
        return self.__getattr__("__and__")

    def __xor__(self, other: object):
        return self.__getattr__("__xor__")

    def __or__(self, other: object):
        return self.__getattr__("__or__")

    # r* and i* methods have lower priority than
    # the methods for left operand so they are skipped

    def __neg__(self):
        return self.__getattr__("__neg__")

    def __pos__(self):
        return self.__getattr__("__pos__")

    def __abs__(self):
        return self.__getattr__("__abs__")

    def __invert__(self):
        return self.__getattr__("__invert__")

    # __complex__, __int__ and __float__ have a default implementation
    # based on __index__, so they are skipped.

    def __index__(self):
        return self.__getattr__("__index__")

    def __round__(self, ndigits: object = ...):
        return self.__getattr__("__round__")

    def __trunc__(self):
        return self.__getattr__("__trunc__")

    def __floor__(self):
        return self.__getattr__("__floor__")

    def __ceil__(self):
        return self.__getattr__("__ceil__")

    # [Context managers]

    def __enter__(self):
        return self.__getattr__("__enter__")

    def __exit__(self, *args: object, **kwargs: object):
        return self.__getattr__("__exit__")
|
||||
|
||||
|
||||
class PlaceholderModule(_PlaceholderBase):
    """
    A placeholder object to use when a module does not exist.

    This enables more informative errors when trying to access attributes
    of a module that does not exist.
    """

    def __init__(self, name: str) -> None:
        super().__init__()

        # Apply name mangling to avoid conflicting with module attributes
        self.__name = name

    def placeholder_attr(self, attr_path: str):
        # Defer a dotted attribute path as well (e.g. ``placeholder.foo.bar``).
        return _PlaceholderModuleAttr(self, attr_path)

    def __getattr__(self, key: str) -> Never:
        # NOTE: `self.__name` is safe here (no recursion) because the mangled
        # attribute was stored in the instance dict by __init__.
        name = self.__name

        # Re-attempt the real import so that the raised error explains *why*
        # the module is missing, and which vllm extra would provide it.
        try:
            importlib.import_module(name)
        except ImportError as exc:
            for extra, names in get_vllm_optional_dependencies().items():
                if name in names:
                    msg = f"Please install vllm[{extra}] for {extra} support"
                    raise ImportError(msg) from exc

            raise exc

        raise AssertionError(
            "PlaceholderModule should not be used "
            "when the original module can be imported"
        )
|
||||
|
||||
|
||||
class _PlaceholderModuleAttr(_PlaceholderBase):
    """A dotted attribute path hanging off a [`PlaceholderModule`][].

    Any real access is forwarded back to the owning module, which raises
    the informative ImportError from `PlaceholderModule.__getattr__`.
    """

    def __init__(self, module: PlaceholderModule, attr_path: str) -> None:
        super().__init__()

        # Apply name mangling to avoid conflicting with module attributes
        self.__module = module
        self.__attr_path = attr_path

    def placeholder_attr(self, attr_path: str):
        # Chain deeper paths lazily (e.g. "a.b" -> "a.b.c").
        return _PlaceholderModuleAttr(self.__module, f"{self.__attr_path}.{attr_path}")

    def __getattr__(self, key: str) -> Never:
        # This getattr always raises (PlaceholderModule.__getattr__ never
        # returns); the AssertionError below is an unreachable safety net.
        getattr(self.__module, f"{self.__attr_path}.{key}")

        raise AssertionError(
            "PlaceholderModule should not be used "
            "when the original module can be imported"
        )
|
||||
|
||||
|
||||
class LazyLoader(ModuleType):
    """
    `LazyLoader` module borrowed from [Tensorflow]
    (https://github.com/tensorflow/tensorflow/blob/main/tensorflow/python/util/lazy_loader.py)
    with an addition of "module caching".

    Lazily import a module, mainly to avoid pulling in large dependencies.
    Modules such as `xgrammar` might do additional side effects, so we
    only want to use this when it is needed, delaying all eager effects.
    """

    def __init__(
        self,
        local_name: str,
        parent_module_globals: dict[str, Any],
        name: str,
    ):
        # Stash bookkeeping *before* ModuleType.__init__ so __getattr__
        # can rely on these being present in the instance dict.
        self._local_name = local_name
        self._parent_module_globals = parent_module_globals
        self._module: ModuleType | None = None

        super().__init__(str(name))

    def _load(self) -> ModuleType:
        """Import the real module and patch it into the parent namespace."""
        try:
            real_module = importlib.import_module(self.__name__)
        except ModuleNotFoundError as err:
            raise err from None

        self._parent_module_globals[self._local_name] = real_module
        # Also register under the local alias so the import system
        # considers the library actually loaded.
        sys.modules[self._local_name] = real_module

        # Copy the real module's namespace onto this proxy so that a
        # retained reference to the LazyLoader resolves attributes straight
        # from __dict__ (__getattr__ only fires on lookups that fail).
        self.__dict__.update(real_module.__dict__)
        return real_module

    def __getattr__(self, item: Any) -> Any:
        if self._module is None:
            self._module = self._load()
        return getattr(self._module, item)

    def __dir__(self) -> list[str]:
        if self._module is None:
            self._module = self._load()
        return dir(self._module)
|
||||
|
||||
|
||||
# Optional dependency detection utilities
|
||||
@cache
|
||||
def _has_module(module_name: str) -> bool:
|
||||
"""Return True if *module_name* can be found in the current environment.
|
||||
|
||||
The result is cached so that subsequent queries for the same module incur
|
||||
no additional overhead.
|
||||
"""
|
||||
return importlib.util.find_spec(module_name) is not None
|
||||
|
||||
|
||||
def has_pplx() -> bool:
    """Whether the optional `pplx_kernels` package is available."""
    return _has_module("pplx_kernels")


def has_deep_ep() -> bool:
    """Whether the optional `deep_ep` package is available."""
    return _has_module("deep_ep")


def has_deep_gemm() -> bool:
    """Whether the optional `deep_gemm` package is available."""
    return _has_module("deep_gemm")


def has_triton_kernels() -> bool:
    """Whether the optional `triton_kernels` package is available."""
    return _has_module("triton_kernels")


def has_tilelang() -> bool:
    """Whether the optional `tilelang` package is available."""
    return _has_module("tilelang")


def has_arctic_inference() -> bool:
    """Whether the optional `arctic_inference` package is available."""
    # Consistency fix: dropped the stray blank line the original had here,
    # matching the layout of the five sibling probes above.
    return _has_module("arctic_inference")
|
||||
165
utils/jsontree.py
Normal file
165
utils/jsontree.py
Normal file
@@ -0,0 +1,165 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""Helper functions to work with nested JSON structures."""
|
||||
|
||||
from collections.abc import Callable, Iterable
|
||||
from functools import reduce
|
||||
from typing import TYPE_CHECKING, TypeAlias, TypeVar, cast, overload
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import torch
|
||||
|
||||
from vllm.multimodal.inputs import BatchedTensorInputs
|
||||
|
||||
# Generic type variables: _T is the leaf type, _U the mapped/reduced result.
_T = TypeVar("_T")
_U = TypeVar("_U")

JSONTree: TypeAlias = (
    dict[str, "JSONTree[_T]"] | list["JSONTree[_T]"] | tuple["JSONTree[_T]", ...] | _T
)
"""A nested JSON structure where the leaves need not be JSON-serializable."""

# NOTE: the extra flat dict/list/tuple members exist purely so the
# `json_map_leaves` / `json_reduce_leaves` @overload signatures type-check.
_JSONTree: TypeAlias = (
    dict[str, "JSONTree[_T]"]
    | list["JSONTree[_T]"]
    | tuple["JSONTree[_T]", ...]
    | dict[str, _T]
    | list[_T]
    | tuple[_T, ...]
    | _T
)
"""
Same as `JSONTree` but with additional `Union` members to satisfy overloads.
"""
|
||||
|
||||
|
||||
def json_iter_leaves(value: JSONTree[_T]) -> Iterable[_T]:
    """Iterate through each leaf in a nested JSON structure."""
    if isinstance(value, dict):
        children = value.values()
    elif isinstance(value, (list, tuple)):
        children = value
    else:
        # Not a container: the value itself is a leaf.
        yield value
        return

    for child in children:
        yield from json_iter_leaves(child)
|
||||
|
||||
|
||||
@overload
def json_map_leaves(
    func: Callable[["torch.Tensor"], "torch.Tensor"],
    value: "BatchedTensorInputs",
) -> "BatchedTensorInputs": ...


@overload
def json_map_leaves(
    func: Callable[[_T], _U],
    value: _T | dict[str, _T],
) -> _U | dict[str, _U]: ...


@overload
def json_map_leaves(
    func: Callable[[_T], _U],
    value: _T | list[_T],
) -> _U | list[_U]: ...


@overload
def json_map_leaves(
    func: Callable[[_T], _U],
    value: _T | tuple[_T, ...],
) -> _U | tuple[_U, ...]: ...


@overload
def json_map_leaves(
    func: Callable[[_T], _U],
    value: JSONTree[_T],
) -> JSONTree[_U]: ...


def json_map_leaves(
    func: Callable[[_T], _U],
    value: "BatchedTensorInputs" | _JSONTree[_T],
) -> "BatchedTensorInputs" | _JSONTree[_U]:
    """Apply a function to each leaf in a nested JSON structure."""
    # Guard-clause style: each container type rebuilds itself recursively;
    # anything else is a leaf and is passed straight to `func`.
    if isinstance(value, dict):
        return {
            k: json_map_leaves(func, v)  # type: ignore[arg-type]
            for k, v in value.items()
        }
    if isinstance(value, list):
        return [json_map_leaves(func, v) for v in value]
    if isinstance(value, tuple):
        return tuple(json_map_leaves(func, v) for v in value)
    return func(value)
|
||||
|
||||
|
||||
@overload
def json_reduce_leaves(
    func: Callable[[_T, _T], _T],
    value: _T | dict[str, _T],
    /,
) -> _T: ...


@overload
def json_reduce_leaves(
    func: Callable[[_T, _T], _T],
    value: _T | list[_T],
    /,
) -> _T: ...


@overload
def json_reduce_leaves(
    func: Callable[[_T, _T], _T],
    value: _T | tuple[_T, ...],
    /,
) -> _T: ...


@overload
def json_reduce_leaves(
    func: Callable[[_T, _T], _T],
    value: JSONTree[_T],
    /,
) -> _T: ...


@overload
def json_reduce_leaves(
    func: Callable[[_U, _T], _U],
    value: JSONTree[_T],
    initial: _U,
    /,
) -> _U: ...


def json_reduce_leaves(
    func: Callable[..., _T | _U],
    value: _JSONTree[_T],
    initial: _U = cast(_U, ...),  # noqa: B008
    /,
) -> _T | _U:
    """
    Apply a function of two arguments cumulatively to each leaf in a
    nested JSON structure, from left to right, so as to reduce the
    sequence to a single value.
    """
    leaves = json_iter_leaves(value)
    # `Ellipsis` is the "no initial value supplied" sentinel.
    if initial is ...:
        return reduce(func, leaves)  # type: ignore[arg-type]
    return reduce(func, leaves, initial)  # type: ignore[arg-type]
|
||||
|
||||
|
||||
def json_count_leaves(value: JSONTree[_T]) -> int:
    """Count the number of leaves in a nested JSON structure."""
    total = 0
    for _ in json_iter_leaves(value):
        total += 1
    return total
|
||||
32
utils/math_utils.py
Normal file
32
utils/math_utils.py
Normal file
@@ -0,0 +1,32 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""Math utility functions for vLLM."""
|
||||
|
||||
|
||||
def cdiv(a: int, b: int) -> int:
    """Ceiling division: smallest integer >= a / b, exact for any sign of b."""
    # Floor-dividing by the negated divisor rounds toward -inf; negating that
    # result yields the ceiling without any floating-point round-off.
    floor_of_negated = a // -b
    return -floor_of_negated
|
||||
|
||||
|
||||
def next_power_of_2(n: int) -> int:
    """The next power of 2 (inclusive): smallest power of two >= n (1 if n < 1)."""
    # (n - 1).bit_length() gives the exponent; n itself is returned unchanged
    # when it is already a power of two.
    return 1 if n < 1 else 1 << (n - 1).bit_length()
|
||||
|
||||
|
||||
def prev_power_of_2(n: int) -> int:
    """The previous power of 2 (inclusive): largest power of two <= n (0 if n <= 0)."""
    return 0 if n <= 0 else 1 << (n.bit_length() - 1)
|
||||
|
||||
|
||||
def round_up(x: int, y: int) -> int:
    """Round up x to the nearest multiple of y."""
    num_chunks = (x + y - 1) // y
    return num_chunks * y
|
||||
|
||||
|
||||
def round_down(x: int, y: int) -> int:
    """Round down x to the nearest multiple of y."""
    # x - (x % y) equals (x // y) * y by Python's divmod invariant.
    return x - (x % y)
|
||||
13
utils/mem_constants.py
Normal file
13
utils/mem_constants.py
Normal file
@@ -0,0 +1,13 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
# Decimal (SI) units are powers of 10; binary (IEC) units are powers of 2.
MB_bytes = 1_000_000
"""The number of bytes in one megabyte (MB)."""

MiB_bytes = 1 << 20
"""The number of bytes in one mebibyte (MiB)."""

GB_bytes = 1_000_000_000
"""The number of bytes in one gigabyte (GB)."""

GiB_bytes = 1 << 30
"""The number of bytes in one gibibyte (GiB)."""
|
||||
232
utils/mem_utils.py
Normal file
232
utils/mem_utils.py
Normal file
@@ -0,0 +1,232 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import contextlib
|
||||
import gc
|
||||
import time
|
||||
from collections.abc import Generator
|
||||
from dataclasses import dataclass, field
|
||||
from functools import cache
|
||||
|
||||
import psutil
|
||||
import torch
|
||||
import torch.types
|
||||
|
||||
from .mem_constants import GiB_bytes
|
||||
|
||||
|
||||
@cache
def get_max_shared_memory_bytes(gpu: int = 0) -> int:
    """Returns the maximum shared memory per thread block in bytes.

    Queried once per GPU index and memoized via ``@cache``.
    """
    from vllm import _custom_ops as ops

    shared_mem = ops.get_max_shared_memory_per_block_device_attribute(gpu)
    # value 0 will cause MAX_SEQ_LEN become negative and test_attention.py
    # will fail
    assert shared_mem > 0, "max_shared_mem can not be zero"
    return int(shared_mem)
|
||||
|
||||
|
||||
def get_cpu_memory() -> int:
    """Returns the total CPU memory of the node in bytes."""
    mem_stats = psutil.virtual_memory()
    return mem_stats.total
|
||||
|
||||
|
||||
class DeviceMemoryProfiler:
    """Context manager measuring device memory consumed inside a ``with``
    block: ``consumed_memory = final_memory - initial_memory`` (bytes)."""

    def __init__(self, device: torch.types.Device | None = None):
        # Device whose usage is queried; None lets the platform pick a default.
        self.device = device

    def current_memory_usage(self) -> float:
        # Return the memory usage in bytes.
        from vllm.platforms import current_platform

        # Collect garbage first so freed-but-uncollected tensors do not
        # inflate the reading.
        gc.collect()
        return current_platform.get_current_memory_usage(self.device)

    def __enter__(self):
        self.initial_memory = self.current_memory_usage()
        # This allows us to call methods of the context manager if needed
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.final_memory = self.current_memory_usage()
        self.consumed_memory = self.final_memory - self.initial_memory

        # Force garbage collection
        gc.collect()
|
||||
|
||||
|
||||
@dataclass
class MemorySnapshot:
    """Memory snapshot of one CUDA device (all values in bytes).

    With ``auto_measure=True`` (the default) the snapshot is taken in
    ``__post_init__``; differences between snapshots are computed with ``-``.
    """

    torch_peak: int = 0  # peak bytes allocated by the torch allocator
    free_memory: int = 0  # free device memory reported by the driver
    total_memory: int = 0  # total device memory
    cuda_memory: int = 0  # total - free, i.e. all memory in use on the device
    torch_memory: int = 0  # bytes reserved by torch (cudaMalloc'd)
    non_torch_memory: int = 0  # cuda_memory - torch_memory
    timestamp: float = 0.0  # time.time() when measured
    auto_measure: bool = True  # measure immediately on construction

    def __post_init__(self):
        if self.auto_measure:
            self.measure()

    def measure(self):
        from vllm.platforms import current_platform

        # we measure the torch peak memory usage via allocated_bytes,
        # rather than `torch.cuda.memory_reserved()` .
        # After `torch.cuda.reset_peak_memory_stats()`,
        # `torch.cuda.memory_reserved()` will keep growing, and only shrink
        # when we call `torch.cuda.empty_cache()` or OOM happens.
        self.torch_peak = torch.cuda.memory_stats().get("allocated_bytes.all.peak", 0)

        self.free_memory, self.total_memory = torch.cuda.mem_get_info()
        shared_sysmem_device_mem_sms = ((8, 7), (11, 0), (12, 1))  # Orin, Thor, Spark
        if (
            current_platform.is_cuda()
            and current_platform.get_device_capability() in shared_sysmem_device_mem_sms
        ):
            # On UMA (Orin, Thor and Spark) platform,
            # where both CPU and GPU rely on system memory,
            # the cudaMemGetInfo function shows the amount of free system memory
            # rather than what’s actually available.
            # In the case,
            # torch.cuda.mem_get_info() only reports "free" memory,
            # which can be lower than what is actually
            # available due to not including cache memory.
            # There’s also a comprehensive reference page
            # that explains how you can compute the proper value yourself.
            # https://docs.nvidia.com/cuda/cuda-for-tegra-appnote/#estimating-total-allocatable-device-memory-on-an-integrated-gpu-device
            self.free_memory = psutil.virtual_memory().available

        self.cuda_memory = self.total_memory - self.free_memory

        # torch.cuda.memory_reserved() is how many bytes
        # PyTorch gets from cuda (by calling cudaMalloc, etc.)
        # this is used to measure the non-torch memory usage
        self.torch_memory = torch.cuda.memory_reserved()

        self.non_torch_memory = self.cuda_memory - self.torch_memory
        self.timestamp = time.time()

    def __sub__(self, other: "MemorySnapshot") -> "MemorySnapshot":
        # Field-wise difference; auto_measure=False so the result is a pure
        # delta rather than a fresh measurement.
        return MemorySnapshot(
            torch_peak=self.torch_peak - other.torch_peak,
            free_memory=self.free_memory - other.free_memory,
            total_memory=self.total_memory - other.total_memory,
            cuda_memory=self.cuda_memory - other.cuda_memory,
            torch_memory=self.torch_memory - other.torch_memory,
            non_torch_memory=self.non_torch_memory - other.non_torch_memory,
            timestamp=self.timestamp - other.timestamp,
            auto_measure=False,
        )
|
||||
|
||||
|
||||
@dataclass
class MemoryProfilingResult:
    """Memory profiling result. All numbers are in bytes."""

    non_kv_cache_memory: int = 0  # weights + peak activations + non-torch
    torch_peak_increase: int = 0  # torch allocator peak delta during profiling
    non_torch_increase: int = 0  # non-torch memory delta since instance creation
    weights_memory: float = 0  # bytes used to hold the model weights
    before_create: MemorySnapshot = field(default_factory=MemorySnapshot)
    before_profile: MemorySnapshot = field(default_factory=MemorySnapshot)
    after_profile: MemorySnapshot = field(default_factory=MemorySnapshot)
    profile_time: float = 0.0  # wall-clock seconds spent profiling

    def __repr__(self) -> str:
        return (
            f"Memory profiling takes {self.profile_time:.2f} seconds. "
            f"Total non KV cache memory: "
            f"{(self.non_kv_cache_memory / GiB_bytes):.2f}GiB; "
            f"torch peak memory increase: "
            f"{(self.torch_peak_increase / GiB_bytes):.2f}GiB; "
            f"non-torch forward increase memory: "
            f"{(self.non_torch_increase / GiB_bytes):.2f}GiB; "
            f"weights memory: {(self.weights_memory / GiB_bytes):.2f}GiB."
        )
|
||||
|
||||
|
||||
@contextlib.contextmanager
def memory_profiling(
    baseline_snapshot: MemorySnapshot, weights_memory: int
) -> Generator[MemoryProfilingResult, None, None]:
    """Memory profiling context manager.
    baseline_snapshot: the memory snapshot before the current vLLM instance.
    weights_memory: memory used by PyTorch when loading the model weights.
    Note that, before loading the model weights, we also initialize the device
    and distributed environment, which may consume some memory. This part is not
    included in the weights_memory because PyTorch does not control it.

    The memory in one GPU can be classified into 3 categories:
    1. memory used by anything other than the current vLLM instance.
    2. memory used by torch in the current vLLM instance.
    3. memory used in the current vLLM instance, but not by torch.

    A quantitative example:

    Before creating the current vLLM instance:
    category 1: 1 GiB
    category 2: 0 GiB
    category 3: 0 GiB

    After creating the current vLLM instance and loading the model,
    (i.e. before profiling):
    category 1: 1 GiB
    category 2: 2 GiB (model weights take 2 GiB)
    category 3: 0.5 GiB (memory used by NCCL)

    During profiling (peak):
    category 1: 1 GiB
    category 2: 4 GiB (peak activation tensors take 2 GiB)
    category 3: 1 GiB (memory used by NCCL + buffers for some attention backends)

    After profiling:
    category 1: 1 GiB
    category 2: 3 GiB (after garbage-collecting activation tensors)
    category 3: 1 GiB (memory used by NCCL + buffers for some attention backends)

    In this case, non-kv cache takes 5 GiB in total, including:
    a. 2 GiB used by the model weights (category 2)
    b. 2 GiB reserved for the peak activation tensors (category 2)
    c. 1 GiB used by non-torch components (category 3)

    The memory used for loading weights (a.) is directly given from the argument `weights_memory`.

    The increase of `torch.cuda.memory_stats()["allocated_bytes.all.peak"]` during profiling gives (b.).

    The increase of `non_torch_memory` from creating the current vLLM instance until after profiling to get (c.).
    """  # noqa
    # Start from a clean slate so the peak statistic reflects only the
    # profiled region.
    gc.collect()
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()

    result = MemoryProfilingResult()

    result.before_create = baseline_snapshot
    # the part of memory used for holding the model weights
    result.weights_memory = weights_memory

    result.before_profile.measure()

    yield result

    # Drop profiling-time activation tensors before the final measurement.
    gc.collect()
    torch.cuda.empty_cache()

    result.after_profile.measure()

    diff_profile = result.after_profile - result.before_profile
    diff_from_create = result.after_profile - result.before_create
    result.torch_peak_increase = diff_profile.torch_peak
    result.non_torch_increase = diff_from_create.non_torch_memory
    result.profile_time = diff_profile.timestamp

    non_torch_memory = result.non_torch_increase
    peak_activation_memory = result.torch_peak_increase
    result.non_kv_cache_memory = (
        non_torch_memory + peak_activation_memory + result.weights_memory
    )  # noqa
|
||||
64
utils/nccl.py
Normal file
64
utils/nccl.py
Normal file
@@ -0,0 +1,64 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import importlib
|
||||
import os
|
||||
|
||||
import torch
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.logger import init_logger
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
def find_nccl_library() -> str:
    """Return NCCL/RCCL shared library name to load.

    Uses `VLLM_NCCL_SO_PATH` if set; otherwise chooses by torch backend.
    """
    so_file = envs.VLLM_NCCL_SO_PATH
    if so_file:
        logger.info(
            "Found nccl from environment variable VLLM_NCCL_SO_PATH=%s", so_file
        )
    elif torch.version.cuda is not None:
        so_file = "libnccl.so.2"
    elif torch.version.hip is not None:
        so_file = "librccl.so.1"
    else:
        raise ValueError("NCCL only supports CUDA and ROCm backends.")
    logger.debug_once("Found nccl from library %s", so_file)
    return so_file
|
||||
|
||||
|
||||
def find_nccl_include_paths() -> list[str] | None:
    """Return possible include paths containing `nccl.h`.

    Considers `VLLM_NCCL_INCLUDE_PATH` and the `nvidia-nccl-cuXX` package.
    Returns None when no candidate directory is found.
    """
    # `importlib.util` is a submodule and is NOT guaranteed to be bound by a
    # bare `import importlib`; import it explicitly so the find_spec call
    # below cannot fail with AttributeError — a failure the broad `except`
    # would otherwise silently swallow, dropping the package include path.
    import importlib.util

    paths: list[str] = []
    inc = envs.VLLM_NCCL_INCLUDE_PATH
    if inc and os.path.isdir(inc):
        paths.append(inc)

    try:
        spec = importlib.util.find_spec("nvidia.nccl")
        if spec and getattr(spec, "submodule_search_locations", None):
            for loc in spec.submodule_search_locations:
                inc_dir = os.path.join(loc, "include")
                if os.path.exists(os.path.join(inc_dir, "nccl.h")):
                    paths.append(inc_dir)
    except Exception as e:
        logger.debug("Failed to find nccl include path from nvidia.nccl package: %s", e)

    # De-duplicate while preserving discovery order.
    seen: set[str] = set()
    out: list[str] = []
    for p in paths:
        if p and p not in seen:
            out.append(p)
            seen.add(p)
    return out or None
|
||||
331
utils/network_utils.py
Normal file
331
utils/network_utils.py
Normal file
@@ -0,0 +1,331 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import contextlib
|
||||
import ipaddress
|
||||
import os
|
||||
import socket
|
||||
import sys
|
||||
import warnings
|
||||
from collections.abc import (
|
||||
Iterator,
|
||||
Sequence,
|
||||
)
|
||||
from typing import Any
|
||||
from urllib.parse import urlparse
|
||||
from uuid import uuid4
|
||||
|
||||
import psutil
|
||||
import zmq
|
||||
import zmq.asyncio
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.logger import init_logger
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
def close_sockets(sockets: Sequence[zmq.Socket | zmq.asyncio.Socket]):
    """Close every non-None socket immediately, discarding unsent messages."""
    for candidate in sockets:
        if candidate is not None:
            # linger=0 drops queued messages instead of blocking on delivery.
            candidate.close(linger=0)
|
||||
|
||||
|
||||
def get_ip() -> str:
    """Return the host IP for inter-process communication.

    Prefers VLLM_HOST_IP, then probes IPv4/IPv6 routing, else "0.0.0.0".
    """
    host_ip = envs.VLLM_HOST_IP
    if "HOST_IP" in os.environ and "VLLM_HOST_IP" not in os.environ:
        logger.warning(
            "The environment variable HOST_IP is deprecated and ignored, as"
            " it is often used by Docker and other software to"
            " interact with the container's network stack. Please "
            "use VLLM_HOST_IP instead to set the IP address for vLLM processes"
            " to communicate with each other."
        )
    if host_ip:
        return host_ip

    # IP is not set, try to get it from the network interface

    # A UDP "connect" sends no packets; it only asks the OS which local
    # address would be used to route to the target.
    probes = (
        (socket.AF_INET, ("8.8.8.8", 80)),
        # Google's public DNS server, see
        # https://developers.google.com/speed/public-dns/docs/using#addresses
        (socket.AF_INET6, ("2001:4860:4860::8888", 80)),
    )
    for family, target in probes:
        try:
            with socket.socket(family, socket.SOCK_DGRAM) as s:
                s.connect(target)  # Doesn't need to be reachable
                return s.getsockname()[0]
        except Exception:
            continue

    warnings.warn(
        "Failed to get the IP address, using 0.0.0.0 by default."
        "The value can be set by the environment variable"
        " VLLM_HOST_IP or HOST_IP.",
        stacklevel=2,
    )
    return "0.0.0.0"
|
||||
|
||||
|
||||
def test_loopback_bind(address, family):
|
||||
try:
|
||||
s = socket.socket(family, socket.SOCK_DGRAM)
|
||||
s.bind((address, 0)) # Port 0 = auto assign
|
||||
s.close()
|
||||
return True
|
||||
except OSError:
|
||||
return False
|
||||
|
||||
|
||||
def get_loopback_ip() -> str:
    """Return a usable loopback address, honoring VLLM_LOOPBACK_IP.

    Raises:
        RuntimeError: if neither 127.0.0.1 nor ::1 can be bound locally.
    """
    override = envs.VLLM_LOOPBACK_IP
    if override:
        return override

    # VLLM_LOOPBACK_IP is not set, try to get it based on network interface
    candidates = (
        ("127.0.0.1", socket.AF_INET),
        ("::1", socket.AF_INET6),
    )
    for address, family in candidates:
        if test_loopback_bind(address, family):
            return address

    raise RuntimeError(
        "Neither 127.0.0.1 nor ::1 are bound to a local interface. "
        "Set the VLLM_LOOPBACK_IP environment variable explicitly."
    )
|
||||
|
||||
|
||||
def is_valid_ipv6_address(address: str) -> bool:
    """Return True iff `address` parses as an IPv6 literal."""
    try:
        ipaddress.IPv6Address(address)
    except ValueError:
        return False
    return True
|
||||
|
||||
|
||||
def split_host_port(host_port: str) -> tuple[str, int]:
    """Split "host:port" (or bracketed "[v6]:port") into (host, port)."""
    if not host_port.startswith("["):
        # ipv4 / hostname form: a single "host:port" pair.
        name, port_text = host_port.split(":")
        return name, int(port_text)
    # ipv6 form: strip the brackets, then take the text after "]:".
    bracketed, port_text = host_port.rsplit("]", 1)
    return bracketed[1:], int(port_text.split(":")[1])
|
||||
|
||||
|
||||
def join_host_port(host: str, port: int) -> str:
    """Format host and port, bracketing IPv6 hosts per URI convention."""
    template = "[{}]:{}" if is_valid_ipv6_address(host) else "{}:{}"
    return template.format(host, port)
|
||||
|
||||
|
||||
def get_distributed_init_method(ip: str, port: int) -> str:
    """Return a torch.distributed init-method URI (tcp://...) for ip:port."""
    return get_tcp_uri(ip, port)
|
||||
|
||||
|
||||
def get_tcp_uri(ip: str, port: int) -> str:
    """Build a tcp:// URI, bracketing IPv6 literals."""
    host_part = f"[{ip}]" if is_valid_ipv6_address(ip) else ip
    return f"tcp://{host_part}:{port}"
|
||||
|
||||
|
||||
def get_open_zmq_ipc_path() -> str:
    """Return a unique ipc:// endpoint under VLLM_RPC_BASE_PATH."""
    return f"ipc://{envs.VLLM_RPC_BASE_PATH}/{uuid4()}"
|
||||
|
||||
|
||||
def get_open_zmq_inproc_path() -> str:
    """Return a unique in-process ZMQ endpoint name."""
    return "inproc://" + str(uuid4())
|
||||
|
||||
|
||||
def get_open_port() -> int:
    """
    Get an open port for the vLLM process to listen on.

    An edge case to handle, is when we run data parallel,
    we need to avoid ports that are potentially used by
    the data parallel master process.
    Right now we reserve 10 ports for the data parallel master
    process. Currently it uses 2 ports.
    """
    if "VLLM_DP_MASTER_PORT" not in os.environ:
        return _get_open_port()

    # Skip any candidate inside the DP master's reserved 10-port window.
    dp_master_port = envs.VLLM_DP_MASTER_PORT
    reserved = range(dp_master_port, dp_master_port + 10)
    while (candidate := _get_open_port()) in reserved:
        pass
    return candidate
|
||||
|
||||
|
||||
def get_open_ports_list(count: int = 5) -> list[int]:
    """Collect `count` distinct open ports."""
    found: set[int] = set()
    while len(found) < count:
        found.add(get_open_port())
    return list(found)
|
||||
|
||||
|
||||
def _get_open_port() -> int:
    """Find a free TCP port.

    Honors VLLM_PORT by scanning upward from it; otherwise asks the OS
    for an ephemeral port (IPv4 first, IPv6 as a fallback).
    """
    port = envs.VLLM_PORT
    if port is not None:
        # Scan upward from the configured port until one binds.
        while True:
            try:
                with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
                    s.bind(("", port))
                    return port
            except OSError:
                port += 1  # Increment port number if already in use
                logger.info("Port %d is already in use, trying port %d", port - 1, port)

    # No configured port: let the OS pick one. Try IPv4 first, then IPv6
    # (the IPv6 failure, if any, propagates to the caller).
    for family in (socket.AF_INET, socket.AF_INET6):
        try:
            with socket.socket(family, socket.SOCK_STREAM) as s:
                s.bind(("", 0))
                return s.getsockname()[1]
        except OSError:
            if family == socket.AF_INET6:
                raise
|
||||
|
||||
|
||||
def find_process_using_port(port: int) -> psutil.Process | None:
    """Locate another process that has a connection on ``port``.

    Returns None when no such process exists, when the candidate exits
    before it can be inspected, or unconditionally on macOS.

    # TODO: We can not check for running processes with network
    # port on macOS. Therefore, we can not have a full graceful shutdown
    # of vLLM. For now, let's not look for processes in this case.
    # Ref: https://www.florianreinhard.de/accessdenied-in-psutil/
    """
    if sys.platform.startswith("darwin"):
        return None

    self_pid = os.getpid()
    for conn in psutil.net_connections():
        owner = conn.pid
        if conn.laddr.port != port or owner is None or owner == self_pid:
            continue
        try:
            return psutil.Process(owner)
        except psutil.NoSuchProcess:
            # Process exited between enumeration and lookup.
            return None
    return None
|
||||
|
||||
|
||||
def split_zmq_path(path: str) -> tuple[str, str, str]:
    """Split a zmq path into its (scheme, host, port) string parts.

    Raises:
        ValueError: missing scheme, tcp without host+port, or a port on
            a non-tcp scheme.
    """
    parsed = urlparse(path)
    scheme = parsed.scheme
    host = parsed.hostname or ""
    port = "" if parsed.port is None else str(parsed.port)

    # tcp requires both host and port; any other scheme must carry no port.
    is_valid = bool(scheme) and (
        bool(host) and bool(port) if scheme == "tcp" else not port
    )
    if not is_valid:
        raise ValueError(f"Invalid zmq path: {path}")

    return scheme, host, port
|
||||
|
||||
|
||||
def make_zmq_path(scheme: str, host: str, port: int | None = None) -> str:
    """Make a ZMQ path from its parts.

    Args:
        scheme: The ZMQ transport scheme (e.g. tcp, ipc, inproc).
        host: The host - can be an IPv4 address, IPv6 address, or hostname.
        port: Optional port number, only used for TCP sockets.

    Returns:
        A properly formatted ZMQ path string.
    """
    if port is None:
        return f"{scheme}://{host}"
    # IPv6 literals must be bracketed when followed by a port.
    host_part = f"[{host}]" if is_valid_ipv6_address(host) else host
    return f"{scheme}://{host_part}:{port}"
|
||||
|
||||
|
||||
# Adapted from: https://github.com/sgl-project/sglang/blob/v0.4.1/python/sglang/srt/utils.py#L783 # noqa: E501
def make_zmq_socket(
    ctx: zmq.asyncio.Context | zmq.Context,  # type: ignore[name-defined]
    path: str,
    socket_type: Any,
    bind: bool | None = None,
    identity: bytes | None = None,
    linger: int | None = None,
) -> zmq.Socket | zmq.asyncio.Socket:  # type: ignore[name-defined]
    """Make a ZMQ socket with the proper bind/connect semantics.

    Args:
        ctx: ZMQ (or asyncio) context to create the socket from.
        path: Endpoint to bind or connect to.
        socket_type: ZMQ socket type constant.
        bind: Bind if True, connect if False; if None, bind for every
            type except PUSH/SUB/XSUB (which connect).
        identity: Optional ZMQ identity to set before bind/connect.
        linger: Optional ZMQ_LINGER value (ms).
    """

    mem = psutil.virtual_memory()
    # Fix: the local was named `socket`, shadowing the stdlib `socket`
    # module this file imports; renamed to `sock`.
    sock = ctx.socket(socket_type)

    # Calculate buffer size based on system memory
    total_mem = mem.total / 1024**3
    available_mem = mem.available / 1024**3
    # For systems with substantial memory (>32GB total, >16GB available):
    # - Set a large 0.5GB buffer to improve throughput
    # For systems with less memory:
    # - Use system default (-1) to avoid excessive memory consumption
    buf_size = int(0.5 * 1024**3) if total_mem > 32 and available_mem > 16 else -1

    if bind is None:
        bind = socket_type not in (zmq.PUSH, zmq.SUB, zmq.XSUB)

    if socket_type in (zmq.PULL, zmq.DEALER, zmq.ROUTER):
        sock.setsockopt(zmq.RCVHWM, 0)
        sock.setsockopt(zmq.RCVBUF, buf_size)

    if socket_type in (zmq.PUSH, zmq.DEALER, zmq.ROUTER):
        sock.setsockopt(zmq.SNDHWM, 0)
        sock.setsockopt(zmq.SNDBUF, buf_size)

    if identity is not None:
        sock.setsockopt(zmq.IDENTITY, identity)

    if linger is not None:
        sock.setsockopt(zmq.LINGER, linger)

    if socket_type == zmq.XPUB:
        sock.setsockopt(zmq.XPUB_VERBOSE, True)

    # Determine if the path is a TCP socket with an IPv6 address.
    # Enable IPv6 on the zmq socket if so.
    scheme, host, _ = split_zmq_path(path)
    if scheme == "tcp" and is_valid_ipv6_address(host):
        sock.setsockopt(zmq.IPV6, 1)

    if bind:
        sock.bind(path)
    else:
        sock.connect(path)

    return sock
|
||||
|
||||
|
||||
@contextlib.contextmanager
def zmq_socket_ctx(
    path: str,
    socket_type: Any,
    bind: bool | None = None,
    linger: int = 0,
    identity: bytes | None = None,
) -> Iterator[zmq.Socket]:
    """Context manager for a ZMQ socket.

    Creates a dedicated ZMQ context, yields a socket built by
    ``make_zmq_socket``, and always destroys the context on exit.

    Args:
        path: ZMQ endpoint to bind/connect to.
        socket_type: ZMQ socket type constant (e.g. zmq.PULL).
        bind: Force bind (True) / connect (False); None lets
            ``make_zmq_socket`` decide based on the socket type.
        linger: Linger (ms) passed to ``ctx.destroy`` on teardown.
        identity: Optional ZMQ identity for the socket.
    """

    ctx = zmq.Context()  # type: ignore[attr-defined]
    try:
        yield make_zmq_socket(ctx, path, socket_type, bind=bind, identity=identity)
    except KeyboardInterrupt:
        # Swallow Ctrl-C here so the `finally` teardown still runs cleanly.
        logger.debug("Got Keyboard Interrupt.")

    finally:
        ctx.destroy(linger=linger)
|
||||
59
utils/platform_utils.py
Normal file
59
utils/platform_utils.py
Normal file
@@ -0,0 +1,59 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import multiprocessing
|
||||
from collections.abc import Sequence
|
||||
from concurrent.futures.process import ProcessPoolExecutor
|
||||
from functools import cache
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
def cuda_is_initialized() -> bool:
    """Check if CUDA is initialized (False when torch lacks CUDA support)."""
    return torch.cuda._is_compiled() and torch.cuda.is_initialized()
|
||||
|
||||
|
||||
def xpu_is_initialized() -> bool:
|
||||
"""Check if XPU is initialized."""
|
||||
if not torch.xpu._is_compiled():
|
||||
return False
|
||||
return torch.xpu.is_initialized()
|
||||
|
||||
|
||||
def get_cu_count(device_id: int = 0) -> int:
    """Returns the total number of compute units (CU) on single GPU."""
    props = torch.cuda.get_device_properties(device_id)
    return props.multi_processor_count
|
||||
|
||||
|
||||
def cuda_get_device_properties(
    device, names: Sequence[str], init_cuda=False
) -> tuple[Any, ...]:
    """Get specified CUDA device property values without initializing CUDA in
    the current process.

    Args:
        device: CUDA device index (or device object torch accepts).
        names: Attribute names to read from the device properties.
        init_cuda: If True, read in-process (initializes CUDA here).

    Returns:
        The requested property values, in the order given by ``names``.
    """
    if init_cuda or cuda_is_initialized():
        # CUDA is (or may be) already initialized here, so reading the
        # properties in-process costs nothing extra.
        props = torch.cuda.get_device_properties(device)
        return tuple(getattr(props, name) for name in names)

    # Run in subprocess to avoid initializing CUDA as a side effect.
    # The recursive call passes init_cuda=True: initialization is harmless
    # in the short-lived forked child.
    mp_ctx = multiprocessing.get_context("fork")
    with ProcessPoolExecutor(max_workers=1, mp_context=mp_ctx) as executor:
        return executor.submit(cuda_get_device_properties, device, names, True).result()
|
||||
|
||||
|
||||
@cache
def is_pin_memory_available() -> bool:
    """Whether the current platform supports pinned (page-locked) host memory.

    Result is cached after the first call; the platform module is imported
    lazily to avoid circular imports at module load time.
    """
    from vllm.platforms import current_platform

    return current_platform.is_pin_memory_available()
|
||||
|
||||
|
||||
@cache
def is_uva_available() -> bool:
    """Check if Unified Virtual Addressing (UVA) is available.

    Cached after the first call.
    """
    # UVA requires pinned memory.
    # TODO: Add more requirements for UVA if needed.
    return is_pin_memory_available()
|
||||
56
utils/profiling.py
Normal file
56
utils/profiling.py
Normal file
@@ -0,0 +1,56 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import contextlib
|
||||
from collections.abc import Callable
|
||||
from functools import wraps
|
||||
from typing import Any
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def cprofile_context(save_file: str | None = None):
|
||||
"""Run a cprofile
|
||||
|
||||
Args:
|
||||
save_file: path to save the profile result. "1" or
|
||||
None will result in printing to stdout.
|
||||
"""
|
||||
import cProfile
|
||||
|
||||
prof = cProfile.Profile()
|
||||
prof.enable()
|
||||
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
prof.disable()
|
||||
if save_file and save_file != "1":
|
||||
prof.dump_stats(save_file)
|
||||
else:
|
||||
prof.print_stats(sort="cumtime")
|
||||
|
||||
|
||||
def cprofile(save_file: str | None = None, enabled: bool = True):
|
||||
"""Decorator to profile a Python method using cProfile.
|
||||
|
||||
Args:
|
||||
save_file: Path to save the profile result.
|
||||
If "1", None, or "", results will be printed to stdout.
|
||||
enabled: Set to false to turn this into a no-op
|
||||
"""
|
||||
|
||||
def decorator(func: Callable):
|
||||
@wraps(func)
|
||||
def wrapper(*args: Any, **kwargs: Any):
|
||||
if not enabled:
|
||||
# If profiling is disabled, just call the function directly.
|
||||
return func(*args, **kwargs)
|
||||
|
||||
with cprofile_context(save_file):
|
||||
return func(*args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
|
||||
return decorator
|
||||
49
utils/registry.py
Normal file
49
utils/registry.py
Normal file
@@ -0,0 +1,49 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from typing import Any
|
||||
|
||||
|
||||
class ExtensionManager:
    """
    A registry for managing pluggable extension classes.

    Extension classes are registered under a string name via the
    ``register`` decorator and later instantiated by name via ``load``.
    This is the usual building block for plugin systems where concrete
    implementations are selected at runtime.

    Examples:
        Basic usage with a registry instance:

        >>> FOO_REGISTRY = ExtensionManager()
        >>> @FOO_REGISTRY.register("my_foo_impl")
        ... class MyFooImpl(Foo):
        ...     def __init__(self, value):
        ...         self.value = value
        >>> foo_impl = FOO_REGISTRY.load("my_foo_impl", value=123)

    """

    def __init__(self) -> None:
        """Create a registry with no extensions registered."""
        self.name2class: dict[str, type] = {}

    def register(self, name: str):
        """Return a class decorator that registers the class under `name`."""

        def _record(extension_cls):
            self.name2class[name] = extension_cls
            return extension_cls

        return _record

    def load(self, cls_name: str, *args, **kwargs) -> Any:
        """Instantiate and return a registered extension class by name.

        NOTE(review): a missing name is reported via `assert`, which is
        stripped under `python -O`; kept as-is to preserve the exception
        type callers may expect.
        """
        cls = self.name2class.get(cls_name)
        assert cls is not None, f"Extension class {cls_name} not found"
        return cls(*args, **kwargs)
|
||||
169
utils/serial_utils.py
Normal file
169
utils/serial_utils.py
Normal file
@@ -0,0 +1,169 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import base64
|
||||
import sys
|
||||
from dataclasses import dataclass
|
||||
from typing import Literal
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from typing_extensions import assert_never
|
||||
|
||||
from vllm import PoolingRequestOutput
|
||||
|
||||
# Host byte order ("little" or "big"), cached at import time.
sys_byteorder = sys.byteorder


# Logical embedding dtype name -> real torch dtype used for the cast.
EMBED_DTYPE_TO_TORCH_DTYPE = {
    "float32": torch.float32,
    "float16": torch.float16,
    "bfloat16": torch.bfloat16,
    # I'm not sure if other platforms' CPUs support the fp8 data format.
    # EMBED_DTYPE only uses the fp8 data representation,
    # does not use fp8 computation, and only occurs on the CPU.
    # Apologize for any possible break.
    "fp8_e4m3": torch.float8_e4m3fn,
    "fp8_e5m2": torch.float8_e5m2,
}


# torch dtype used only to REINTERPRET raw bytes (same element size),
# so the buffer can round-trip through numpy.
EMBED_DTYPE_TO_TORCH_DTYPE_VIEW = {
    "float32": torch.float32,
    "float16": torch.float16,
    # numpy does not support bfloat16 and fp8
    "bfloat16": torch.float16,
    "fp8_e4m3": torch.uint8,
    "fp8_e5m2": torch.uint8,
}

# numpy dtype used to reinterpret raw bytes (mirrors the torch view table).
EMBED_DTYPE_TO_NUMPY_DTYPE_VIEW = {
    "float32": np.float32,
    "float16": np.float16,
    # numpy does not support bfloat16 and fp8
    "bfloat16": np.float16,
    "fp8_e4m3": np.uint8,
    "fp8_e5m2": np.uint8,
}

# Runtime-checkable list mirroring the Endianness Literal below.
ENDIANNESS = ["native", "big", "little"]

EmbedDType = Literal["float32", "float16", "bfloat16", "fp8_e4m3", "fp8_e5m2"]
Endianness = Literal["native", "big", "little"]
EncodingFormat = Literal["float", "base64", "bytes"]
|
||||
|
||||
|
||||
def tensor2binary(
    tensor: torch.Tensor, embed_dtype: EmbedDType, endianness: Endianness
) -> bytes:
    """Serialize a tensor to raw bytes in the requested dtype/endianness."""
    assert isinstance(tensor, torch.Tensor)
    assert embed_dtype in EMBED_DTYPE_TO_TORCH_DTYPE
    assert endianness in ENDIANNESS

    target_dtype = EMBED_DTYPE_TO_TORCH_DTYPE[embed_dtype]
    view_dtype = EMBED_DTYPE_TO_TORCH_DTYPE_VIEW[embed_dtype]

    # Cast and flatten, then reinterpret through a numpy-representable
    # dtype (numpy has no native bfloat16/fp8).
    flat = tensor.to(target_dtype).flatten().contiguous()
    np_array = flat.view(view_dtype).numpy()

    if endianness not in ("native", sys_byteorder):
        np_array = np_array.byteswap()

    return np_array.tobytes()
|
||||
|
||||
|
||||
def binary2tensor(
    binary: bytes,
    shape: tuple[int, ...],
    embed_dtype: EmbedDType,
    endianness: Endianness,
) -> torch.Tensor:
    """Deserialize raw bytes back into a tensor (inverse of tensor2binary)."""
    assert embed_dtype in EMBED_DTYPE_TO_TORCH_DTYPE
    assert embed_dtype in EMBED_DTYPE_TO_NUMPY_DTYPE_VIEW
    assert endianness in ENDIANNESS

    target_dtype = EMBED_DTYPE_TO_TORCH_DTYPE[embed_dtype]
    view_dtype = EMBED_DTYPE_TO_NUMPY_DTYPE_VIEW[embed_dtype]

    np_array = np.frombuffer(binary, dtype=view_dtype).reshape(shape)
    if endianness not in ("native", sys_byteorder):
        np_array = np_array.byteswap()

    # Reinterpret the numpy-compatible view dtype back to the true dtype.
    return torch.from_numpy(np_array).view(target_dtype)
|
||||
|
||||
|
||||
def encode_pooling_output(
    output: PoolingRequestOutput,
    encoding_format: EncodingFormat,
    embed_dtype: EmbedDType,
    endianness: Endianness,
) -> list[float] | str | bytes:
    """Render one pooling output as floats, base64 text, or raw bytes."""
    data = output.outputs.data
    if encoding_format == "float":
        return data.tolist()
    if encoding_format in ("base64", "bytes"):
        payload = tensor2binary(data, embed_dtype, endianness)
        if encoding_format == "bytes":
            return payload
        return base64.b64encode(payload).decode("utf-8")
    assert_never(encoding_format)
|
||||
|
||||
|
||||
@dataclass
class MetadataItem:
    """Location/decoding metadata for one pooled tensor within a byte stream."""

    # Position of this output in the original request batch.
    index: int
    # Serialization dtype/endianness used by tensor2binary.
    embed_dtype: EmbedDType
    endianness: Endianness
    # Byte offsets [start, end) into the concatenated payload.
    start: int
    end: int
    # Original tensor shape, used to reshape on decode.
    shape: tuple[int, ...]
|
||||
|
||||
|
||||
def encode_pooling_bytes(
    pooling_outputs: list[PoolingRequestOutput],
    embed_dtype: EmbedDType,
    endianness: Endianness,
):
    """Serialize a batch of pooling outputs into a binary payload.

    Returns:
        body: list of per-output byte blobs (concatenate to form the payload).
        items: per-output metadata dicts whose keys mirror MetadataItem fields.
        usage: token-accounting dict (prompt_tokens / total_tokens).
    """
    num_prompt_tokens = 0
    # Each entry is a plain metadata dict (field names/values mirror
    # MetadataItem); the previous dict[str, MetadataItem] annotation was
    # inaccurate.
    items: list[dict[str, object]] = []
    body = []
    offset = 0
    for idx, output in enumerate(pooling_outputs):
        binary = tensor2binary(
            tensor=output.outputs.data,
            embed_dtype=embed_dtype,
            endianness=endianness,
        )
        size = len(binary)

        item = {
            "index": idx,
            "embed_dtype": embed_dtype,
            "endianness": endianness,
            # Byte span [start, end) of this output within the payload.
            "start": offset,
            "end": offset + size,
            "shape": output.outputs.data.shape,
        }

        body.append(binary)
        items.append(item)
        prompt_token_ids = output.prompt_token_ids
        num_prompt_tokens += len(prompt_token_ids)
        offset += size

    usage = {
        "prompt_tokens": num_prompt_tokens,
        "total_tokens": num_prompt_tokens,
    }
    return body, items, usage
|
||||
|
||||
|
||||
def decode_pooling_output(items: list[MetadataItem], body: bytes) -> list[torch.Tensor]:
    """Decode tensors from a concatenated payload using their metadata.

    NOTE: sorts ``items`` by index IN PLACE (mutates the caller's list).
    """
    items.sort(key=lambda x: x.index)

    tensor_list: list[torch.Tensor] = []
    for item in items:
        # Each item records its [start, end) byte span within `body`.
        binary = body[item.start : item.end]
        tensor = binary2tensor(binary, item.shape, item.embed_dtype, item.endianness)
        tensor_list.append(tensor)
    return tensor_list
|
||||
229
utils/system_utils.py
Normal file
229
utils/system_utils.py
Normal file
@@ -0,0 +1,229 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import contextlib
|
||||
import multiprocessing
|
||||
import os
|
||||
import signal
|
||||
import sys
|
||||
from collections.abc import Callable, Iterator
|
||||
from pathlib import Path
|
||||
from typing import TextIO
|
||||
|
||||
import psutil
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.logger import init_logger
|
||||
from vllm.ray.lazy_utils import is_in_ray_actor
|
||||
|
||||
from .platform_utils import cuda_is_initialized, xpu_is_initialized
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
CYAN = "\033[1;36m"
|
||||
RESET = "\033[0;0m"
|
||||
|
||||
|
||||
# Environment variable utilities
|
||||
|
||||
|
||||
def update_environment_variables(envs_dict: dict[str, str]):
    """Update multiple environment variables, warning on overwrites."""
    for name, new_value in envs_dict.items():
        previous = os.environ.get(name)
        if previous is not None and previous != new_value:
            logger.warning(
                "Overwriting environment variable %s from '%s' to '%s'",
                name,
                previous,
                new_value,
            )
        os.environ[name] = new_value
|
||||
|
||||
|
||||
@contextlib.contextmanager
def set_env_var(key: str, value: str) -> Iterator[None]:
    """Temporarily set an environment variable, restoring it on exit."""
    saved = os.environ.get(key)
    os.environ[key] = value
    try:
        yield
    finally:
        # Restore the pre-existing value, or remove the key entirely.
        if saved is not None:
            os.environ[key] = saved
        else:
            os.environ.pop(key, None)
|
||||
|
||||
|
||||
# File path utilities
|
||||
|
||||
|
||||
def unique_filepath(fn: Callable[[int], Path]) -> Path:
    """Generate a unique file path by trying incrementing integers.

    Note: This function has a TOCTOU race condition.
    Caller should use atomic operations (e.g., open with 'x' mode)
    when creating the file to ensure thread safety.
    """
    counter = 0
    while True:
        candidate = fn(counter)
        if not candidate.exists():
            return candidate
        counter += 1
|
||||
|
||||
|
||||
# Process management utilities
|
||||
|
||||
|
||||
def _maybe_force_spawn():
    """Check if we need to force the use of the `spawn` multiprocessing start
    method.
    """
    if os.environ.get("VLLM_WORKER_MULTIPROC_METHOD") == "spawn":
        return  # Already spawn; nothing to do.

    reasons: list[str] = []
    if is_in_ray_actor():
        # even if we choose to spawn, we need to pass the ray address
        # to the subprocess so that it knows how to connect to the ray cluster.
        # env vars are inherited by subprocesses, even if we use spawn.
        import ray

        os.environ["RAY_ADDRESS"] = ray.get_runtime_context().gcs_address
        reasons.append("In a Ray actor and can only be spawned")

    if cuda_is_initialized():
        reasons.append("CUDA is initialized")
    elif xpu_is_initialized():
        reasons.append("XPU is initialized")

    if not reasons:
        return
    logger.warning(
        "We must use the `spawn` multiprocessing start method. "
        "Overriding VLLM_WORKER_MULTIPROC_METHOD to 'spawn'. "
        "See https://docs.vllm.ai/en/latest/usage/"
        "troubleshooting.html#python-multiprocessing "
        "for more information. Reasons: %s",
        "; ".join(reasons),
    )
    os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
|
||||
|
||||
|
||||
def get_mp_context():
    """Get a multiprocessing context with a particular method (spawn or fork).

    By default we follow the value of the VLLM_WORKER_MULTIPROC_METHOD to
    determine the multiprocessing method (default is fork). However, under
    certain conditions, we may enforce spawn and override the value of
    VLLM_WORKER_MULTIPROC_METHOD.
    """
    _maybe_force_spawn()
    return multiprocessing.get_context(envs.VLLM_WORKER_MULTIPROC_METHOD)
|
||||
|
||||
|
||||
def set_process_title(
    name: str,
    suffix: str = "",
    prefix: str = envs.VLLM_PROCESS_NAME_PREFIX,
) -> None:
    """Set the current process title to "prefix::name[_suffix]"."""
    try:
        import setproctitle
    except ImportError:
        # setproctitle is optional; silently skip when it is absent.
        return

    full_name = f"{name}_{suffix}" if suffix else name
    setproctitle.setproctitle(f"{prefix}::{full_name}")
|
||||
|
||||
|
||||
def _add_prefix(file: TextIO, worker_name: str, pid: int) -> None:
    """Add colored prefix to file output for log decoration.

    Monkey-patches ``file.write`` so every output line is prefixed with
    "(worker_name pid=NNN)" in cyan. Line state is tracked via the ad-hoc
    ``file.start_new_line`` attribute stored on the stream object itself.
    """
    prefix = f"{CYAN}({worker_name} pid={pid}){RESET} "
    file_write = file.write  # keep a reference to the original write

    def write_with_prefix(s: str):
        # Empty writes carry no line-state information; ignore them.
        if not s:
            return
        # Previous write ended with a newline -> this chunk starts a
        # fresh line and needs the prefix up front.
        if file.start_new_line:  # type: ignore[attr-defined]
            file_write(prefix)
        idx = 0
        # Emit the chunk newline-by-newline, inserting the prefix after
        # every newline that is not the chunk's final character.
        while (next_idx := s.find("\n", idx)) != -1:
            next_idx += 1
            file_write(s[idx:next_idx])
            if next_idx == len(s):
                # Chunk ends exactly on a newline: next write starts a line.
                file.start_new_line = True  # type: ignore[attr-defined]
                return
            file_write(prefix)
            idx = next_idx
        # Trailing partial line (no final newline).
        file_write(s[idx:])
        file.start_new_line = False  # type: ignore[attr-defined]

    file.start_new_line = True  # type: ignore[attr-defined]
    file.write = write_with_prefix  # type: ignore[method-assign]
|
||||
|
||||
|
||||
def decorate_logs(process_name: str | None = None) -> None:
    """Decorate stdout/stderr with process name and PID prefix."""
    name = process_name
    if name is None:
        # Fall back to the multiprocessing-assigned process name.
        name = get_mp_context().current_process().name

    pid = os.getpid()
    for stream in (sys.stdout, sys.stderr):
        _add_prefix(stream, name, pid)
|
||||
|
||||
|
||||
def kill_process_tree(pid: int):
    """
    Kills all descendant processes of the given pid by sending SIGKILL.

    Args:
        pid (int): Process ID of the parent process
    """
    try:
        descendants = psutil.Process(pid).children(recursive=True)
    except psutil.NoSuchProcess:
        return

    # Children first, parent last; tolerate races where a target exits
    # before the signal lands.
    targets = [child.pid for child in descendants] + [pid]
    for target in targets:
        with contextlib.suppress(ProcessLookupError):
            os.kill(target, signal.SIGKILL)
|
||||
|
||||
|
||||
# Resource utilities
|
||||
|
||||
|
||||
# Adapted from: https://github.com/sgl-project/sglang/blob/v0.4.1/python/sglang/srt/utils.py#L630
|
||||
def set_ulimit(target_soft_limit: int = 65535):
    """Raise the soft RLIMIT_NOFILE to `target_soft_limit` if it is lower.

    No-op on Windows (no `resource` module) and when the current soft
    limit already meets the target. Failures are logged, not raised.
    """
    if sys.platform.startswith("win"):
        logger.info("Windows detected, skipping ulimit adjustment.")
        return

    import resource

    resource_type = resource.RLIMIT_NOFILE
    current_soft, current_hard = resource.getrlimit(resource_type)
    if current_soft >= target_soft_limit:
        return

    try:
        resource.setrlimit(resource_type, (target_soft_limit, current_hard))
    except ValueError as e:
        logger.warning(
            "Found ulimit of %s and failed to automatically increase "
            "with error %s. This can cause fd limit errors like "
            "`OSError: [Errno 24] Too many open files`. Consider "
            "increasing with ulimit -n",
            current_soft,
            e,
        )
|
||||
255
utils/tensor_schema.py
Normal file
255
utils/tensor_schema.py
Normal file
@@ -0,0 +1,255 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from types import UnionType
|
||||
from typing import Annotated, Any, Union, get_args, get_origin, get_type_hints
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.logger import init_logger
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class TensorShape:
|
||||
def __init__(
|
||||
self,
|
||||
*dims: int | str,
|
||||
dynamic_dims: set[str] | None = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.dims = dims
|
||||
self.dynamic_dims = dynamic_dims if dynamic_dims else set()
|
||||
|
||||
def resolve(self, **bindings: int) -> tuple[int | str, ...]:
|
||||
resolved = list[int | str]()
|
||||
for dim in self.dims:
|
||||
if isinstance(dim, str) and dim in bindings:
|
||||
resolved.append(bindings[dim])
|
||||
else:
|
||||
resolved.append(dim)
|
||||
return tuple(resolved)
|
||||
|
||||
def __str__(self) -> str:
|
||||
"""Return a string representation of the tensor shape."""
|
||||
dim_strs = []
|
||||
for dim in self.dims:
|
||||
if isinstance(dim, str):
|
||||
if dim in self.dynamic_dims:
|
||||
dim_strs.append(f"{dim}*") # Mark dynamic dimensions with *
|
||||
else:
|
||||
dim_strs.append(dim)
|
||||
else:
|
||||
dim_strs.append(str(dim))
|
||||
return f"({', '.join(dim_strs)})"
|
||||
|
||||
|
||||
class TensorSchema:
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
validate: bool = True,
|
||||
resolve_bindings: dict[str, int] | None = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
self._resolve_bindings = resolve_bindings if resolve_bindings else {}
|
||||
|
||||
for key, value in kwargs.items():
|
||||
setattr(self, key, value)
|
||||
|
||||
if validate:
|
||||
self.validate()
|
||||
|
||||
def __getitem__(self, key: str) -> Any:
|
||||
return getattr(self, key)
|
||||
|
||||
def get(self, key: str, default: Any = None) -> Any:
|
||||
return getattr(self, key, default)
|
||||
|
||||
def _match_shape_with_dynamic(
|
||||
self,
|
||||
actual: tuple[int, ...],
|
||||
reference: tuple[int, ...],
|
||||
expected_shape: tuple[int | str, ...],
|
||||
dynamic_dims: set[str],
|
||||
) -> bool:
|
||||
if len(actual) != len(reference) or len(actual) > len(expected_shape):
|
||||
return False
|
||||
|
||||
for i, (a, r) in enumerate(zip(actual, reference)):
|
||||
# When validating list inputs, we match shape suffixes only
|
||||
# (e.g. "p", 3, "h", "w"), assuming the list length corresponds
|
||||
# to the leading symbolic dim (e.g. "bn"). This allows comparing
|
||||
# only the trailing dimensions of each element in the list.
|
||||
dim = expected_shape[-len(actual) + i]
|
||||
# Skip this dimension if it's marked dynamic
|
||||
if dim in dynamic_dims:
|
||||
continue
|
||||
if a != r:
|
||||
return False
|
||||
return True
|
||||
|
||||
def _fmt_indexer(self, idxs: tuple[int, ...]) -> str:
|
||||
if not idxs:
|
||||
return ""
|
||||
|
||||
return str(list(idxs))
|
||||
|
||||
def _validate_field(
|
||||
self,
|
||||
value: object,
|
||||
field_name: str,
|
||||
expected_shape: tuple[int | str, ...],
|
||||
dynamic_dims: set[str],
|
||||
leading_idxs: tuple[int, ...] = (),
|
||||
) -> tuple[int, ...]:
|
||||
"""Validate a field and return the actual shape."""
|
||||
if isinstance(value, (int, float)):
|
||||
return () # Scalar
|
||||
if isinstance(value, torch.Tensor):
|
||||
return value.shape
|
||||
|
||||
if not isinstance(value, (list, tuple)):
|
||||
raise TypeError(
|
||||
f"{field_name}{self._fmt_indexer(leading_idxs)} is not "
|
||||
f"one of the expected types: int, float, Tensor, list, tuple. "
|
||||
f"Got: {type(value)}"
|
||||
)
|
||||
|
||||
if len(value) == 0:
|
||||
raise ValueError(
|
||||
f"{field_name}{self._fmt_indexer(leading_idxs)} is an empty sequence"
|
||||
)
|
||||
|
||||
# Ensure all tensors in the list have the same
|
||||
# shape, besides dynamic dimensions
|
||||
for i, v in enumerate(value):
|
||||
shape = self._validate_field(
|
||||
v,
|
||||
field_name,
|
||||
expected_shape[1:],
|
||||
dynamic_dims,
|
||||
leading_idxs=leading_idxs + (i,),
|
||||
)
|
||||
|
||||
if i == 0:
|
||||
first_shape = shape
|
||||
elif not self._match_shape_with_dynamic(
|
||||
shape,
|
||||
first_shape,
|
||||
expected_shape,
|
||||
dynamic_dims,
|
||||
):
|
||||
raise ValueError(
|
||||
f"{field_name}{self._fmt_indexer(leading_idxs)} "
|
||||
f"contains inconsistent shapes: {first_shape} "
|
||||
f"(index 0) vs {shape} (index {i})"
|
||||
)
|
||||
|
||||
# Treat the list as a stacked tensor:
|
||||
# shape = (len(list), *tensor.shape)
|
||||
return (len(value),) + first_shape
|
||||
|
||||
    def _validate_tensor_shape_expected(
        self,
        actual_shape: tuple[int, ...],
        expected_shape: tuple[int | str, ...],
        field_name: str,
        shape_env: dict[str, int],
        dynamic_dims: set[str],
    ) -> None:
        """Validate that the actual tensor shape matches the expected shape.

        Int dims must match exactly. String dims are symbolic and unified via
        ``shape_env``: the first occurrence binds the name to the observed
        size, later occurrences must equal that binding. Names listed in
        ``dynamic_dims`` are skipped entirely.

        Args:
            actual_shape: Concrete shape derived from the field value.
            expected_shape: Mix of fixed int sizes and named string dims.
            field_name: Field name, used only in error messages.
            shape_env: Mutable mapping of symbolic name -> bound size; shared
                across fields so named dims must agree between them.
            dynamic_dims: Symbolic names exempt from all consistency checks.

        Raises:
            ValueError: On rank mismatch or any dimension-size mismatch.
            TypeError: If an expected dim is neither int nor str.
        """

        if len(actual_shape) != len(expected_shape):
            raise ValueError(
                f"{field_name} has rank {len(actual_shape)} "
                f"but expected {len(expected_shape)}. "
                f"Expected shape: {expected_shape}, "
                f"but got {actual_shape}"
            )

        for i, dim in enumerate(expected_shape):
            if dim in dynamic_dims:
                # Dynamic dims may differ anywhere; nothing to check.
                continue
            elif isinstance(dim, int):
                if actual_shape[i] != dim:
                    raise ValueError(
                        f"{field_name} dim[{i}] expected "
                        f"{dim}, got {actual_shape[i]}. "
                        f"Expected shape: {expected_shape}, "
                        f"but got {actual_shape}"
                    )
            elif isinstance(dim, str):
                if dim in shape_env:
                    # Name already bound: must match the earlier binding.
                    if actual_shape[i] != shape_env[dim]:
                        raise ValueError(
                            f"{field_name} dim[{i}] expected "
                            f"'{dim}'={shape_env[dim]}, got "
                            f"{actual_shape[i]}"
                        )
                else:
                    # First occurrence binds the symbolic name.
                    shape_env[dim] = actual_shape[i]
            else:
                raise TypeError(
                    f"{field_name} dim[{i}] has unsupported type: {type(dim)}"
                )
|
||||
|
||||
    def validate(self) -> None:
        """Validate all annotated fields against their TensorShape metadata.

        Walks every type hint on the class (including ``Annotated`` extras),
        verifies that non-optional fields are present, and checks each field
        carrying a ``TensorShape`` annotation. Symbolic dimension names are
        unified across fields via a shared shape environment, so two fields
        annotated with the same name must have equal sizes there.

        Raises:
            ValueError: If a required field is missing or a shape mismatches.
        """
        type_hints = get_type_hints(self.__class__, include_extras=True)
        # Shared symbolic-dim bindings for this validation pass.
        shape_env = dict[str, int]()

        for field_name, field_type in type_hints.items():
            # Check if field is missing
            if not hasattr(self, field_name) or getattr(self, field_name) is None:
                # Check if field is marked as optional
                actual_type = field_type
                if get_origin(field_type) is Annotated:
                    # Unwrap Annotated[X, ...] to inspect the underlying X.
                    args = get_args(field_type)
                    actual_type = args[0]

                # Check arg was provided as Union
                if get_origin(actual_type) in {Union, UnionType}:
                    # Union for Union[X, Y] and UnionType for X | Y
                    args = get_args(actual_type)
                    # Skip validation when Union contains None
                    if type(None) in args:
                        continue
                # Otherwise field is required, raise error
                raise ValueError(f"Required field '{field_name}' is missing")

            # Field exists, proceed with validation
            value = getattr(self, field_name)
            if get_origin(field_type) is not None:
                args = get_args(field_type)

                for arg in args:
                    if isinstance(arg, TensorShape):
                        # Resolve instance-level bindings into the expected
                        # shape, then compute and check the actual shape.
                        expected_shape = arg.resolve(**self._resolve_bindings)
                        actual_shape = self._validate_field(
                            value,
                            field_name,
                            expected_shape,
                            arg.dynamic_dims,
                        )

                        self._validate_tensor_shape_expected(
                            actual_shape,
                            expected_shape,
                            field_name,
                            shape_env,
                            arg.dynamic_dims,
                        )
|
||||
|
||||
def print_shapes(self) -> None:
|
||||
"""Print TensorShape annotations for debugging."""
|
||||
logger.debug("Shapes in %s:", self.__class__.__name__)
|
||||
type_hints = get_type_hints(self.__class__, include_extras=True)
|
||||
|
||||
for field_name, field_type in type_hints.items():
|
||||
if get_origin(field_type) is not None:
|
||||
args = get_args(field_type)
|
||||
for arg in args:
|
||||
if isinstance(arg, TensorShape):
|
||||
logger.debug(" %s: %s", field_name, str(arg))
|
||||
658
utils/torch_utils.py
Normal file
658
utils/torch_utils.py
Normal file
@@ -0,0 +1,658 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import contextlib
|
||||
import importlib.metadata
|
||||
import os
|
||||
import threading
|
||||
from collections.abc import Callable, Collection
|
||||
from functools import lru_cache
|
||||
from typing import TYPE_CHECKING, Any, TypeVar
|
||||
|
||||
import numpy as np
|
||||
import numpy.typing as npt
|
||||
import torch
|
||||
from packaging import version
|
||||
from packaging.version import Version
|
||||
from torch.library import Library
|
||||
|
||||
import vllm.envs as envs
|
||||
import ixformer.inference.functions as ixfops
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.config import ModelConfig
|
||||
from vllm.sequence import IntermediateTensors
|
||||
else:
|
||||
ModelConfig = object
|
||||
IntermediateTensors = object
|
||||
|
||||
|
||||
# Mapping from user-facing kv-cache dtype strings to the torch dtype used for
# storage. The fp8 variants are stored as raw bytes (torch.uint8), except
# "fp8_inc" which uses torch's native float8_e4m3fn dtype.
STR_DTYPE_TO_TORCH_DTYPE = {
    "float32": torch.float32,
    "half": torch.half,
    "bfloat16": torch.bfloat16,
    "float": torch.float,
    "fp8": torch.uint8,
    "fp8_e4m3": torch.uint8,
    "fp8_e5m2": torch.uint8,
    "int8": torch.int8,
    "fp8_inc": torch.float8_e4m3fn,
    "fp8_ds_mla": torch.uint8,
}

# Torch dtypes with a direct numpy equivalent (used when round-tripping data
# through numpy, e.g. in make_tensor_with_pad below).
TORCH_DTYPE_TO_NUMPY_DTYPE = {
    torch.float16: np.float16,
    torch.float32: np.float32,
    torch.float64: np.float64,
    torch.uint8: np.uint8,
    torch.int32: np.int32,
    torch.int64: np.int64,
}


# Generic element type used by the padding helpers below.
T = TypeVar("T")
|
||||
|
||||
|
||||
@contextlib.contextmanager
def set_default_torch_dtype(dtype: torch.dtype):
    """Sets the default torch dtype to the given dtype.

    The previous default dtype is restored when the context exits, whether
    the wrapped block completes normally or raises.
    """
    old_dtype = torch.get_default_dtype()
    torch.set_default_dtype(dtype)
    try:
        yield
    finally:
        # Without try/finally, an exception inside the block would leak the
        # dtype override to all subsequent code.
        torch.set_default_dtype(old_dtype)
|
||||
|
||||
|
||||
@contextlib.contextmanager
def set_default_torch_num_threads(num_threads: int):
    """Sets the default number of threads for PyTorch to the given value.

    The previous thread count is restored when the context exits, whether
    the wrapped block completes normally or raises.
    """
    old_num_threads = torch.get_num_threads()
    torch.set_num_threads(num_threads)
    try:
        yield
    finally:
        # Without try/finally, an exception inside the block would leave the
        # global thread count permanently overridden.
        torch.set_num_threads(old_num_threads)
|
||||
|
||||
|
||||
@contextlib.contextmanager
def guard_cuda_initialization():
    """Avoid unexpected CUDA initialization.

    On CUDA platforms, temporarily sets CUDA_VISIBLE_DEVICES="" so that any
    attempt to initialize CUDA inside the block fails fast. The original
    environment value is restored on exit, and exceptions raised inside the
    block are re-raised as RuntimeError.
    """
    from vllm.platforms import current_platform

    # Nothing to guard on non-CUDA platforms.
    if not current_platform.is_cuda():
        yield
        return

    had_key = "CUDA_VISIBLE_DEVICES" in os.environ
    old_value = os.environ.get("CUDA_VISIBLE_DEVICES")
    # Hide every device: CUDA initialization under the block sees no GPUs.
    os.environ["CUDA_VISIBLE_DEVICES"] = ""
    try:
        yield
    except Exception as e:
        if "No CUDA GPUs are available" in str(e):
            # NOTE(review): presumably the message torch raises when no
            # devices are visible — mapped to a clearer diagnostic here.
            err_msg = "CUDA initialization is blocked."
        else:
            err_msg = str(e)
        raise RuntimeError(err_msg) from e
    finally:
        # Restore the caller's environment exactly (value, or absence).
        if had_key:
            os.environ["CUDA_VISIBLE_DEVICES"] = old_value
        else:
            os.environ.pop("CUDA_VISIBLE_DEVICES")
|
||||
|
||||
|
||||
def get_dtype_size(dtype: torch.dtype) -> int:
    """Get the size of the data type in bytes."""
    probe = torch.empty((), dtype=dtype)
    return probe.element_size()
|
||||
|
||||
|
||||
# bool = 0, int = 1, float = 2, complex = 3
|
||||
def _get_precision_level(dtype: torch.dtype) -> int:
|
||||
# NOTE: Complex dtypes return `is_floating_point=False`
|
||||
return (dtype != torch.bool) + dtype.is_floating_point + dtype.is_complex * 2
|
||||
|
||||
|
||||
def is_lossless_cast(src_dtype: torch.dtype, tgt_dtype: torch.dtype):
    """
    Test whether it is lossless to cast a tensor from
    `src_dtype` to `tgt_dtype`.
    """
    if src_dtype == tgt_dtype:
        return True

    # Different kind (bool < int < float < complex): widening the kind is
    # lossless, narrowing it is not.
    src_level = _get_precision_level(src_dtype)
    tgt_level = _get_precision_level(tgt_dtype)
    if src_level != tgt_level:
        return src_level < tgt_level

    # Same kind: compare representable ranges.
    is_integral = not (src_dtype.is_floating_point or src_dtype.is_complex)
    if is_integral:
        src_info = torch.iinfo(src_dtype)
        tgt_info = torch.iinfo(tgt_dtype)
        return tgt_info.min <= src_info.min and src_info.max <= tgt_info.max

    # Floating-point (or complex) types: range must widen and the target's
    # resolution must be at least as fine.
    src_info = torch.finfo(src_dtype)
    tgt_info = torch.finfo(tgt_dtype)
    return (
        tgt_info.min <= src_info.min
        and src_info.max <= tgt_info.max
        and tgt_info.resolution <= src_info.resolution
    )
|
||||
|
||||
|
||||
def common_broadcastable_dtype(dtypes: Collection[torch.dtype]):
    """
    Get the common `dtype` where all of the other `dtypes` can be
    cast to it without losing any information.
    """

    def num_lossless_sources(candidate: torch.dtype) -> int:
        # How many of the given dtypes cast losslessly into `candidate`.
        return sum(1 for dt in dtypes if is_lossless_cast(dt, candidate))

    return max(dtypes, key=num_lossless_sources)
|
||||
|
||||
|
||||
def _generate_random_fp8(
    tensor: torch.Tensor,
    low: float,
    high: float,
) -> None:
    """Fill `tensor` in place with random fp8 data drawn from [low, high].

    Values are first sampled uniformly in float16 and then converted via the
    custom `convert_fp8` op, which avoids producing Inf/NaN bit patterns.
    """
    # NOTE(zhaoyang): Due to NaN and Inf representation for fp8 data type,
    # it may occur Inf or NaN if we directly use torch.randint
    # to generate random data for fp8 data.
    # For example, s.11111.00 in fp8e5m2 format represents Inf.
    #     | E4M3        | E5M2
    # ----|-------------|-------------------
    # Inf | N/A         | s.11111.00
    # NaN | s.1111.111  | s.11111.{01,10,11}
    from vllm import _custom_ops as ops

    tensor_tmp = torch.empty_like(tensor, dtype=torch.float16)
    tensor_tmp.uniform_(low, high)
    ops.convert_fp8(tensor, tensor_tmp)
    # Release the float16 staging buffer promptly.
    del tensor_tmp
|
||||
|
||||
|
||||
def get_kv_cache_torch_dtype(
    cache_dtype: str | torch.dtype | None,
    model_dtype: str | torch.dtype | None = None,
) -> torch.dtype:
    """Resolve the torch dtype of the KV cache.

    A concrete torch.dtype is returned as-is; the string "auto" inherits the
    model dtype; any other string is looked up in STR_DTYPE_TO_TORCH_DTYPE.

    Raises:
        ValueError: If either dtype specification cannot be resolved.
    """
    if isinstance(cache_dtype, torch.dtype):
        return cache_dtype

    if isinstance(cache_dtype, str):
        if cache_dtype == "auto":
            # Inherit from the model dtype (string or torch.dtype form).
            if isinstance(model_dtype, str) and model_dtype in STR_DTYPE_TO_TORCH_DTYPE:
                return STR_DTYPE_TO_TORCH_DTYPE[model_dtype]
            if isinstance(model_dtype, torch.dtype):
                return model_dtype
            raise ValueError(f"Invalid model dtype: {model_dtype}")
        if cache_dtype in STR_DTYPE_TO_TORCH_DTYPE:
            return STR_DTYPE_TO_TORCH_DTYPE[cache_dtype]
        raise ValueError(f"Invalid kv cache dtype: {cache_dtype}")

    raise ValueError(f"Invalid kv cache dtype: {cache_dtype}")
|
||||
|
||||
|
||||
def kv_cache_dtype_str_to_dtype(
    kv_cache_dtype: str, model_config: ModelConfig
) -> torch.dtype:
    """Map a kv-cache dtype string to a torch dtype.

    "auto" inherits the model dtype; anything else is looked up in
    STR_DTYPE_TO_TORCH_DTYPE.
    """
    if kv_cache_dtype != "auto":
        return STR_DTYPE_TO_TORCH_DTYPE[kv_cache_dtype]
    # Model config may not be specified for unit tests, default to float16
    return model_config.dtype if model_config else torch.half
|
||||
|
||||
|
||||
def create_kv_caches_with_random_flash(
    num_blocks: int,
    block_size: int,
    num_layers: int,
    num_heads: int,
    head_size: int,
    cache_dtype: str | torch.dtype | None,
    model_dtype: str | torch.dtype | None = None,
    seed: int | None = None,
    device: str | None = "cuda",
    cache_layout: str | None = "NHD",
) -> tuple[list[torch.Tensor], list[torch.Tensor]]:
    """Create randomly initialized KV caches in the flash-attention layout.

    One combined key/value tensor per layer is allocated with logical shape
    (num_blocks, 2, block_size, num_heads, head_size); per-layer views
    key = [:, 0] and value = [:, 1] are returned. `cache_layout` selects the
    physical memory order: "NHD" allocates in logical order, "HND" swaps the
    block_size and num_heads axes in the allocation before permuting back.

    Returns:
        (key_caches, value_caches), one tensor view per layer.

    Raises:
        ValueError: If `cache_dtype` is not a supported random-fill type.
    """
    from vllm.platforms import current_platform

    current_platform.seed_everything(seed)

    dtype = get_kv_cache_torch_dtype(cache_dtype, model_dtype)
    generic_kv_cache_shape = (num_blocks, 2, block_size, num_heads, head_size)
    assert cache_layout in ("NHD", "HND")
    stride_order = (0, 1, 2, 3, 4) if cache_layout == "NHD" else (0, 1, 3, 2, 4)

    # Allocate in the physical order, then permute to the logical order so the
    # returned views are logically shaped but physically laid out per layout.
    kv_cache_allocation_shape = tuple(generic_kv_cache_shape[i] for i in stride_order)
    # Random values scaled by 1/sqrt(head_size).
    scale = head_size**-0.5

    key_caches: list[torch.Tensor] = []
    value_caches: list[torch.Tensor] = []

    for _ in range(num_layers):
        key_value_cache = torch.empty(
            size=kv_cache_allocation_shape, dtype=dtype, device=device
        ).permute(*stride_order)
        if cache_dtype in ["auto", "half", "bfloat16", "float"]:
            key_value_cache.uniform_(-scale, scale)
        elif cache_dtype == "fp8":
            _generate_random_fp8(key_value_cache, -scale, scale)
        else:
            raise ValueError(f"Does not support key cache of type {cache_dtype}")
        key_caches.append(key_value_cache[:, 0])
        value_caches.append(key_value_cache[:, 1])
    return key_caches, value_caches
|
||||
|
||||
|
||||
def _fill_cache_random(
    cache: torch.Tensor,
    cache_dtype: str | torch.dtype | None,
    scale: float,
    kind: str,
) -> None:
    """Fill `cache` in place with random data scaled to [-scale, scale].

    `kind` ("key" or "value") only labels the error message.
    """
    if cache_dtype in ["auto", "half", "bfloat16", "float"]:
        cache.uniform_(-scale, scale)
    elif cache_dtype == "fp8":
        cache_fill_fn = _generate_random_fp8
        cache_fill_fn(cache, -scale, scale)
    else:
        raise ValueError(f"Does not support {kind} cache of type {cache_dtype}")


def create_kv_caches_with_random(
    num_blocks: int,
    block_size: int,
    num_layers: int,
    num_heads: int,
    head_size: int,
    cache_dtype: str | torch.dtype | None,
    model_dtype: str | torch.dtype | None = None,
    seed: int | None = None,
    device: str | None = "cuda",
) -> tuple[list[torch.Tensor], list[torch.Tensor]]:
    """Create randomly initialized KV caches in the paged-attention layout.

    Keys are stored as (num_blocks, num_heads, head_size // x, block_size, x)
    where x packs 16 bytes of contiguous key data; values are stored as
    (num_blocks, num_heads, head_size, block_size).

    Returns:
        (key_caches, value_caches), one tensor per layer.

    Raises:
        ValueError: For fp8 with a head_size not divisible by 16, or for an
            unsupported `cache_dtype`.
    """
    if cache_dtype == "fp8" and head_size % 16:
        raise ValueError(
            f"Does not support key cache of type fp8 with head_size {head_size}"
        )
    from vllm.platforms import current_platform

    current_platform.seed_everything(seed)

    dtype = get_kv_cache_torch_dtype(cache_dtype, model_dtype)

    # Random values scaled by 1/sqrt(head_size).
    scale = head_size**-0.5
    # x = number of elements that fit in 16 bytes; the key layout blocks the
    # head dimension by x.
    x = 16 // torch.tensor([], dtype=dtype).element_size()
    key_cache_shape = (num_blocks, num_heads, head_size // x, block_size, x)
    key_caches: list[torch.Tensor] = []
    for _ in range(num_layers):
        key_cache = torch.empty(size=key_cache_shape, dtype=dtype, device=device)
        _fill_cache_random(key_cache, cache_dtype, scale, "key")
        key_caches.append(key_cache)

    value_cache_shape = (num_blocks, num_heads, head_size, block_size)
    value_caches: list[torch.Tensor] = []
    for _ in range(num_layers):
        value_cache = torch.empty(size=value_cache_shape, dtype=dtype, device=device)
        _fill_cache_random(value_cache, cache_dtype, scale, "value")
        value_caches.append(value_cache)
    return key_caches, value_caches
|
||||
|
||||
|
||||
def async_tensor_h2d(
|
||||
data: list,
|
||||
dtype: torch.dtype,
|
||||
target_device: str | torch.device,
|
||||
pin_memory: bool,
|
||||
) -> torch.Tensor:
|
||||
"""Asynchronously create a tensor and copy it from host to device."""
|
||||
t = torch.tensor(data, dtype=dtype, pin_memory=pin_memory, device="cpu")
|
||||
return t.to(device=target_device, non_blocking=True)
|
||||
|
||||
|
||||
def make_ndarray_with_pad(
|
||||
x: list[list[T]],
|
||||
pad: T,
|
||||
dtype: npt.DTypeLike,
|
||||
*,
|
||||
max_len: int | None = None,
|
||||
) -> npt.NDArray:
|
||||
"""
|
||||
Make a padded array from 2D inputs.
|
||||
|
||||
The padding is applied to the end of each inner list until it reaches
|
||||
`max_len`.
|
||||
"""
|
||||
if max_len is None:
|
||||
# Unlike for most functions, map is faster than a genexpr over `len`
|
||||
max_len = max(map(len, x), default=0)
|
||||
|
||||
padded_x = np.full((len(x), max_len), pad, dtype=dtype)
|
||||
for ind, blocktb in enumerate(x):
|
||||
assert len(blocktb) <= max_len
|
||||
padded_x[ind, : len(blocktb)] = blocktb
|
||||
|
||||
return padded_x
|
||||
|
||||
|
||||
def make_tensor_with_pad(
    x: list[list[T]],
    pad: T,
    dtype: torch.dtype,
    *,
    max_len: int | None = None,
    device: str | torch.device | None = None,
    pin_memory: bool = False,
) -> torch.Tensor:
    """
    Make a padded tensor from 2D inputs.

    The padding is applied to the end of each inner list until it reaches
    `max_len`. The result is built through numpy, moved to `device`, and
    optionally pinned.
    """
    numpy_dtype = TORCH_DTYPE_TO_NUMPY_DTYPE[dtype]
    padded = make_ndarray_with_pad(x, pad, numpy_dtype, max_len=max_len)

    result = torch.from_numpy(padded).to(device)
    return result.pin_memory() if pin_memory else result
|
||||
|
||||
|
||||
# Save the original setter before monkeypatching it below, so the patched
# version can still delegate to torch.
prev_set_stream = torch.cuda.set_stream

# Thread-local record of the stream most recently passed to set_stream; read
# by current_stream() below to avoid the cost of torch.cuda.current_stream().
_current_stream_tls = threading.local()


def _patched_set_stream(stream: torch.cuda.Stream) -> None:
    """Record `stream` for this thread, then call the original setter."""
    _current_stream_tls.value = stream
    prev_set_stream(stream)


# Module-import side effect: globally replace torch.cuda.set_stream with the
# recording wrapper.
torch.cuda.set_stream = _patched_set_stream
|
||||
|
||||
|
||||
class _StreamPlaceholder:
|
||||
def __init__(self):
|
||||
self.synchronize = lambda: None
|
||||
|
||||
|
||||
def current_stream() -> torch.cuda.Stream:
    """
    replace `torch.cuda.current_stream()` with `vllm.utils.current_stream()`.
    it turns out that `torch.cuda.current_stream()` is quite expensive,
    as it will construct a new stream object at each call.
    here we patch `torch.cuda.set_stream` to keep track of the current stream
    directly, so that we can avoid calling `torch.cuda.current_stream()`.

    the underlying hypothesis is that we do not call `torch._C._cuda_setStream`
    from C/C++ code.
    """
    from vllm.platforms import current_platform

    if not hasattr(_current_stream_tls, "value") or _current_stream_tls.value is None:
        # when this function is called before any stream is set,
        # we return the default stream.
        # On ROCm using the default 0 stream in combination with RCCL
        # is hurting performance. Therefore creating a dedicated stream
        # per process
        if current_platform.is_rocm():
            # torch.cuda.set_stream here is the alias of _patched_set_stream,
            # so it also records the new stream in _current_stream_tls.
            torch.cuda.set_stream(torch.cuda.Stream())
        elif current_platform.is_cpu():
            # CPU has no streams; install a no-op placeholder.
            _current_stream_tls.value = _StreamPlaceholder()
        else:
            # NOTE(review): this local deliberately shadows the module-level
            # function name; harmless, but confirm it is intentional.
            current_stream = current_platform.current_stream
            if current_stream is not None:
                _current_stream_tls.value = current_stream()
            else:
                raise ValueError(
                    "Fail to set current stream, current platform "
                    "may not support current_stream with torch API"
                )
    return _current_stream_tls.value
|
||||
|
||||
|
||||
# Global auxiliary stream for running operations in background streams.
# We have a single global auxiliary stream to avoid an explosion of streams
# for every layer (and make profiling look sane).
#
# aux_stream() is currently used for:
# - MoE shared_expert overlap with router
_aux_stream: torch.cuda.Stream | None = None


def aux_stream() -> torch.cuda.Stream | None:
    """
    Ensures aux_stream is initialized only once.

    Returns the shared auxiliary stream on CUDA platforms, None elsewhere.
    """
    global _aux_stream

    from vllm.platforms import current_platform

    # TODO: validate this works properly on ROCm platform.
    if _aux_stream is None and current_platform.is_cuda():
        _aux_stream = torch.cuda.Stream()

    return _aux_stream
|
||||
|
||||
|
||||
@lru_cache(maxsize=8)
def _cuda_device_count_stateless(cuda_visible_devices: str | None = None) -> int:
    """Return the CUDA/ROCm device count without initializing a CUDA context.

    The result is cached per distinct `cuda_visible_devices` value.
    """
    # Note: cuda_visible_devices is not used, but we keep it as an argument for
    # LRU Cache purposes.

    # Code below is based on
    # https://github.com/pytorch/pytorch/blob/
    # c1cd946818442aca8c7f812b16d187ce1586c3bc/
    # torch/cuda/__init__.py#L831C1-L831C17
    import torch.cuda
    import torch.version

    from vllm.platforms import current_platform

    if not torch.cuda._is_compiled():
        return 0
    if current_platform.is_rocm():
        # ROCm uses amdsmi instead of nvml for stateless device count
        # This requires a sufficiently modern version of Torch 2.4.0
        raw_count = (
            torch.cuda._device_count_amdsmi()
            if (hasattr(torch.cuda, "_device_count_amdsmi"))
            else -1
        )
    else:
        raw_count = torch.cuda._device_count_nvml()
    # A negative count means the SMI/NVML path failed; fall back to the
    # C API (which does initialize the context).
    r = torch._C._cuda_getDeviceCount() if raw_count < 0 else raw_count
    return r
|
||||
|
||||
|
||||
def cuda_device_count_stateless() -> int:
    """Get number of CUDA devices, caching based on the value of
    CUDA_VISIBLE_DEVICES at the time of call.

    This should be used instead of torch.cuda.device_count()
    unless CUDA_VISIBLE_DEVICES has already been set to the desired
    value."""

    # This can be removed and simply replaced with torch.cuda.get_device_count
    # after https://github.com/pytorch/pytorch/pull/122815 is released.
    # Passing the env value as an argument makes the lru_cache key track
    # changes to CUDA_VISIBLE_DEVICES between calls.
    return _cuda_device_count_stateless(envs.CUDA_VISIBLE_DEVICES)
|
||||
|
||||
|
||||
def weak_ref_tensor(tensor: Any) -> Any:
    """
    Create a weak reference to a tensor.
    The new tensor will share the same data as the original tensor,
    but will not keep the original tensor alive.

    Non-tensor inputs are returned unchanged.
    """
    if isinstance(tensor, torch.Tensor):
        # NOTE(review): delegates to the ixformer backend op rather than a
        # vLLM C-extension op — presumably a vendor-specific port; confirm
        # ixfops.weak_ref_tensor matches the upstream op's semantics.
        return ixfops.weak_ref_tensor(tensor)
    else:
        return tensor
|
||||
|
||||
|
||||
def weak_ref_tensors(
    tensors: torch.Tensor
    | list[torch.Tensor]
    | tuple[torch.Tensor]
    | IntermediateTensors,
) -> torch.Tensor | list[Any] | tuple[Any] | Any:
    """
    Convenience function to create weak references to tensors,
    for single tensor, list of tensors or tuple of tensors.

    Container structure (list/tuple/IntermediateTensors) is preserved; each
    contained tensor is replaced by its weak-reference view.

    Raises:
        ValueError: If `tensors` is none of the supported types.
    """
    if isinstance(tensors, torch.Tensor):
        return weak_ref_tensor(tensors)
    if isinstance(tensors, list):
        return [weak_ref_tensor(t) for t in tensors]
    if isinstance(tensors, tuple):
        return tuple(weak_ref_tensor(t) for t in tensors)

    # For IntermediateTensors used in pipeline parallelism
    # (imported lazily here — presumably to avoid an import cycle; confirm).
    from vllm.sequence import IntermediateTensors

    if isinstance(tensors, IntermediateTensors):
        ret = IntermediateTensors(
            {key: weak_ref_tensor(val) for key, val in tensors.tensors.items()}
        )
        return ret
    raise ValueError("Invalid type for tensors")
|
||||
|
||||
|
||||
def get_cuda_view_from_cpu_tensor(cpu_tensor: torch.Tensor) -> torch.Tensor:
    """
    Get a CUDA view of a CPU tensor using Unified Virtual Addressing (UVA).

    The input must be allocated in pinned (page-locked) host memory.
    """
    assert cpu_tensor.is_pinned(), "CPU tensor must be pinned"
    return torch.ops._C.get_cuda_view_from_cpu_tensor(cpu_tensor)
|
||||
|
||||
|
||||
# Helper function used in testing.
def _is_torch_equal_or_newer(torch_version: str, target: str) -> bool:
    """Return whether version string `torch_version` is >= `target`."""
    # Parse into distinct locals instead of rebinding the str parameter
    # with a Version object.
    installed = version.parse(torch_version)
    required = version.parse(target)
    return installed >= required
|
||||
|
||||
|
||||
def is_torch_equal_or_newer(target: str) -> bool:
    """Check if the installed torch version is >= the target version.

    Args:
        target: a version string, like "2.6.0".

    Returns:
        Whether the condition meets.
    """
    try:
        return _is_torch_equal_or_newer(str(torch.__version__), target)
    except Exception:
        # Fallback to PKG-INFO to load the package info, needed by the doc gen.
        # (Queries the installed distribution metadata instead of the module.)
        return Version(importlib.metadata.version("torch")) >= Version(target)
|
||||
|
||||
|
||||
def _is_torch_equal(target: str) -> bool:
    """Return whether the installed torch is the `target` release.

    Matches `target` itself as well as post/local-suffixed builds of it
    (e.g. "+cu128"), by bracketing with the half-open interval
    [target, target + ".1").
    """
    # `target` must be a plain three-component release like "2.6.0".
    assert target.count(".") == 2
    torch_version = str(torch.__version__)
    torch_version = version.parse(torch_version)
    # torch version is like "2.6.0.dev20240101" or "2.6.0.dev20240101+cpu"
    # or "2.6.0+cu128" but never "2.6.0.1", so the ".1" upper bound can
    # never collide with a real torch release.
    return (
        torch_version >= version.parse(target)
        and version.parse(target + ".1") > torch_version
    )
|
||||
|
||||
|
||||
def is_torch_equal(target: str) -> bool:
    """Check if the installed torch version is == the target version.

    Args:
        target: a version string, like "2.6.0".

    Returns:
        Whether the condition meets.
    """
    try:
        return _is_torch_equal(target)
    except Exception:
        # Metadata fallback, mirroring is_torch_equal_or_newer (needed when
        # torch.__version__ is unusable, e.g. during doc generation).
        return Version(importlib.metadata.version("torch")) == Version(target)
|
||||
|
||||
|
||||
# Using dynamo with vLLM doesn't really work well with PyTorch versions < 2.4.0.
# In particular, the FakeScalarType is not supported for earlier versions of
# PyTorch which breaks dynamo for any ops registered using ScalarType.
def supports_dynamo() -> bool:
    """Whether the installed torch (>= 2.4.0) is usable with dynamo here."""
    return is_torch_equal_or_newer("2.4.0")
|
||||
|
||||
|
||||
# Supports xccl with PyTorch versions >= 2.8.0.dev for XPU platform
def supports_xccl() -> bool:
    """Whether torch is new enough and built with the XCCL distributed backend."""
    return (
        is_torch_equal_or_newer("2.8.0.dev") and torch.distributed.is_xccl_available()
    )
|
||||
|
||||
|
||||
# Some backends use pytorch version < 2.4.0 which doesn't
# support `torch.library.custom_op`.
def supports_custom_op() -> bool:
    """Whether this torch build exposes `torch.library.custom_op`."""
    return getattr(torch.library, "custom_op", None) is not None
|
||||
|
||||
|
||||
# create a library to hold the custom op
# ("FRAGMENT" lets multiple modules contribute ops to the shared "vllm"
# namespace; see torch.library.Library docs)
vllm_lib = Library("vllm", "FRAGMENT")  # noqa
|
||||
|
||||
|
||||
def direct_register_custom_op(
    op_name: str,
    op_func: Callable,
    mutates_args: list[str] | None = None,
    fake_impl: Callable | None = None,
    target_lib: Library | None = None,
    dispatch_key: str | None = None,
    tags: tuple[torch.Tag, ...] = (),
):
    """
    `torch.library.custom_op` can have significant overhead because it
    needs to consider complicated dispatching logic. This function
    directly registers a custom op and dispatches it to the CUDA backend.
    See https://gist.github.com/youkaichao/ecbea9ec9fc79a45d2adce1784d7a9a5
    for more details.

    By default, the custom op is registered to the vLLM library. If you
    want to register it to a different library, you can pass the library
    object to the `target_lib` argument.

    IMPORTANT: the lifetime of the operator is tied to the lifetime of the
    library object. If you want to bind the operator to a different library,
    make sure the library object is alive when the operator is used.

    Args:
        op_name: Name under which the op is registered.
        op_func: The real (eager) implementation.
        mutates_args: Names of arguments the op mutates in place (for schema
            inference); defaults to none.
        fake_impl: Optional fake/meta implementation for tracing.
        target_lib: Library to register into; defaults to `vllm_lib`.
        dispatch_key: Backend dispatch key; defaults to the current
            platform's key.
        tags: Optional `torch.Tag`s attached to the op definition.
    """
    if not supports_custom_op():
        # Old torch without torch.library.custom_op: tolerated only on
        # non-CUDA-alike platforms, where registration is silently skipped.
        from vllm.platforms import current_platform

        assert not current_platform.is_cuda_alike(), (
            "cuda platform needs torch>=2.4 to support custom op, "
            "chances are you are using an old version of pytorch "
            "or a custom build of pytorch. It is recommended to "
            "use vLLM in a fresh new environment and let it install "
            "the required dependencies."
        )
        return

    if mutates_args is None:
        mutates_args = []

    if dispatch_key is None:
        from vllm.platforms import current_platform

        dispatch_key = current_platform.dispatch_key

    import torch.library

    if hasattr(torch.library, "infer_schema"):
        # Derive the op schema string from the Python signature.
        schema_str = torch.library.infer_schema(op_func, mutates_args=mutates_args)
    else:
        # for pytorch 2.4
        import torch._custom_op.impl

        schema_str = torch._custom_op.impl.infer_schema(op_func, mutates_args)
    my_lib = target_lib or vllm_lib
    my_lib.define(op_name + schema_str, tags=tags)
    my_lib.impl(op_name, op_func, dispatch_key=dispatch_key)
    if fake_impl is not None:
        my_lib._register_fake(op_name, fake_impl)
|
||||
Reference in New Issue
Block a user