This commit is contained in:
root
2026-03-05 18:06:10 +08:00
commit 809cecae09
2569 changed files with 478204 additions and 0 deletions

82
utils/__init__.py Normal file
View File

@@ -0,0 +1,82 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import uuid
import warnings
from typing import Any
import torch
from vllm.logger import init_logger
_DEPRECATED_MAPPINGS = {
"cprofile": "profiling",
"cprofile_context": "profiling",
# Used by lm-eval
"get_open_port": "network_utils",
}
def __getattr__(name: str) -> Any: # noqa: D401 - short deprecation docstring
"""Module-level getattr to handle deprecated utilities."""
if name in _DEPRECATED_MAPPINGS:
submodule_name = _DEPRECATED_MAPPINGS[name]
warnings.warn(
f"vllm.utils.{name} is deprecated and will be removed in a future version. "
f"Use vllm.utils.{submodule_name}.{name} instead.",
DeprecationWarning,
stacklevel=2,
)
module = __import__(f"vllm.utils.{submodule_name}", fromlist=[submodule_name])
return getattr(module, name)
raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
def __dir__() -> list[str]:
# expose deprecated names in dir() for better UX/tab-completion
return sorted(list(globals().keys()) + list(_DEPRECATED_MAPPINGS.keys()))
logger = init_logger(__name__)
# Constants related to forcing the attention backend selection
# String name of register which may be set in order to
# force auto-selection of attention backend by Attention
# wrapper
STR_BACKEND_ENV_VAR: str = "VLLM_ATTENTION_BACKEND"
# Possible string values of STR_BACKEND_ENV_VAR
# register, corresponding to possible backends
STR_FLASHINFER_ATTN_VAL: str = "FLASHINFER"
STR_XFORMERS_ATTN_VAL: str = "XFORMERS"
STR_FLASH_ATTN_VAL: str = "FLASH_ATTN"
STR_INVALID_VAL: str = "INVALID"
def random_uuid() -> str:
return str(uuid.uuid4().hex)
def length_from_prompt_token_ids_or_embeds(
prompt_token_ids: list[int] | None,
prompt_embeds: torch.Tensor | None,
) -> int:
"""Calculate the request length (in number of tokens) give either
prompt_token_ids or prompt_embeds.
"""
prompt_token_len = None if prompt_token_ids is None else len(prompt_token_ids)
prompt_embeds_len = None if prompt_embeds is None else len(prompt_embeds)
if prompt_token_len is None:
if prompt_embeds_len is None:
raise ValueError("Neither prompt_token_ids nor prompt_embeds were defined.")
return prompt_embeds_len
else:
if prompt_embeds_len is not None and prompt_embeds_len != prompt_token_len:
raise ValueError(
"Prompt token ids and prompt embeds had different lengths"
f" prompt_token_ids={prompt_token_len}"
f" prompt_embeds={prompt_embeds_len}"
)
return prompt_token_len

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

487
utils/argparse_utils.py Normal file
View File

@@ -0,0 +1,487 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Argument parsing utilities for vLLM."""
import json
import sys
import textwrap
from argparse import (
Action,
ArgumentDefaultsHelpFormatter,
ArgumentParser,
ArgumentTypeError,
Namespace,
RawDescriptionHelpFormatter,
_ArgumentGroup,
)
from collections import defaultdict
from typing import Any
import regex as re
import yaml
from vllm.logger import init_logger
logger = init_logger(__name__)
class SortedHelpFormatter(ArgumentDefaultsHelpFormatter, RawDescriptionHelpFormatter):
"""SortedHelpFormatter that sorts arguments by their option strings."""
def _split_lines(self, text, width):
"""
1. Sentences split across lines have their single newlines removed.
2. Paragraphs and explicit newlines are split into separate lines.
3. Each line is wrapped to the specified width (width of terminal).
"""
# The patterns also include whitespace after the newline
single_newline = re.compile(r"(?<!\n)\n(?!\n)\s*")
multiple_newlines = re.compile(r"\n{2,}\s*")
text = single_newline.sub(" ", text)
lines = re.split(multiple_newlines, text)
return sum([textwrap.wrap(line, width) for line in lines], [])
def add_arguments(self, actions):
actions = sorted(actions, key=lambda x: x.option_strings)
super().add_arguments(actions)
class FlexibleArgumentParser(ArgumentParser):
"""ArgumentParser that allows both underscore and dash in names."""
_deprecated: set[Action] = set()
_json_tip: str = (
"When passing JSON CLI arguments, the following sets of arguments "
"are equivalent:\n"
' --json-arg \'{"key1": "value1", "key2": {"key3": "value2"}}\'\n'
" --json-arg.key1 value1 --json-arg.key2.key3 value2\n\n"
"Additionally, list elements can be passed individually using +:\n"
' --json-arg \'{"key4": ["value3", "value4", "value5"]}\'\n'
" --json-arg.key4+ value3 --json-arg.key4+='value4,value5'\n\n"
)
_search_keyword: str | None = None
def __init__(self, *args, **kwargs):
# Set the default "formatter_class" to SortedHelpFormatter
if "formatter_class" not in kwargs:
kwargs["formatter_class"] = SortedHelpFormatter
# Pop kwarg "add_json_tip" to control whether to add the JSON tip
self.add_json_tip = kwargs.pop("add_json_tip", True)
super().__init__(*args, **kwargs)
if sys.version_info < (3, 13):
# Enable the deprecated kwarg for Python 3.12 and below
def parse_known_args(self, args=None, namespace=None):
if args is not None and "--disable-log-requests" in args:
# Special case warning because the warning below won't trigger
# if -disable-log-requests because its value is default.
logger.warning_once(
"argument '--disable-log-requests' is deprecated and "
"replaced with '--enable-log-requests'. This will be "
"removed in v0.12.0."
)
namespace, args = super().parse_known_args(args, namespace)
for action in FlexibleArgumentParser._deprecated:
if (
hasattr(namespace, dest := action.dest)
and getattr(namespace, dest) != action.default
):
logger.warning_once("argument '%s' is deprecated", dest)
return namespace, args
def add_argument(self, *args, **kwargs):
deprecated = kwargs.pop("deprecated", False)
action = super().add_argument(*args, **kwargs)
if deprecated:
FlexibleArgumentParser._deprecated.add(action)
return action
class _FlexibleArgumentGroup(_ArgumentGroup):
def add_argument(self, *args, **kwargs):
deprecated = kwargs.pop("deprecated", False)
action = super().add_argument(*args, **kwargs)
if deprecated:
FlexibleArgumentParser._deprecated.add(action)
return action
def add_argument_group(self, *args, **kwargs):
group = self._FlexibleArgumentGroup(self, *args, **kwargs)
self._action_groups.append(group)
return group
def format_help(self):
# Only use custom help formatting for bottom level parsers
if self._subparsers is not None:
return super().format_help()
formatter = self._get_formatter()
# Handle keyword search of the args
if (search_keyword := self._search_keyword) is not None:
# Normalise the search keyword
search_keyword = search_keyword.lower().replace("_", "-")
# Return full help if searching for 'all'
if search_keyword == "all":
self.epilog = self._json_tip
return super().format_help()
# Return group help if searching for a group title
for group in self._action_groups:
if group.title and group.title.lower() == search_keyword:
formatter.start_section(group.title)
formatter.add_text(group.description)
formatter.add_arguments(group._group_actions)
formatter.end_section()
formatter.add_text(self._json_tip)
return formatter.format_help()
# Return matched args if searching for an arg name
matched_actions = []
for group in self._action_groups:
for action in group._group_actions:
# search option name
if any(
search_keyword in opt.lower() for opt in action.option_strings
):
matched_actions.append(action)
if matched_actions:
formatter.start_section(f"Arguments matching '{search_keyword}'")
formatter.add_arguments(matched_actions)
formatter.end_section()
formatter.add_text(self._json_tip)
return formatter.format_help()
# No match found
formatter.add_text(
f"No group or arguments matching '{search_keyword}'.\n"
"Use '--help' to see available groups or "
"'--help=all' to see all available parameters."
)
return formatter.format_help()
# usage
formatter.add_usage(self.usage, self._actions, self._mutually_exclusive_groups)
# description
formatter.add_text(self.description)
# positionals, optionals and user-defined groups
formatter.start_section("Config Groups")
config_groups = ""
for group in self._action_groups:
if not group._group_actions:
continue
title = group.title
description = group.description or ""
config_groups += f"{title: <24}{description}\n"
formatter.add_text(config_groups)
formatter.end_section()
# epilog
formatter.add_text(self.epilog)
# determine help from format above
return formatter.format_help()
def parse_args( # type: ignore[override]
self,
args: list[str] | None = None,
namespace: Namespace | None = None,
):
if args is None:
args = sys.argv[1:]
# Check for --model in command line arguments first
if args and args[0] == "serve":
try:
model_idx = next(
i
for i, arg in enumerate(args)
if arg == "--model" or arg.startswith("--model=")
)
logger.warning(
"With `vllm serve`, you should provide the model as a "
"positional argument or in a config file instead of via "
"the `--model` option. "
"The `--model` option will be removed in v0.13."
)
if args[model_idx] == "--model":
model_tag = args[model_idx + 1]
rest_start_idx = model_idx + 2
else:
model_tag = args[model_idx].removeprefix("--model=")
rest_start_idx = model_idx + 1
# Move <model> to the front, e,g:
# [Before]
# vllm serve -tp 2 --model <model> --enforce-eager --port 8001
# [After]
# vllm serve <model> -tp 2 --enforce-eager --port 8001
args = [
"serve",
model_tag,
*args[1:model_idx],
*args[rest_start_idx:],
]
except StopIteration:
pass
if "--config" in args:
args = self._pull_args_from_config(args)
def repl(match: re.Match) -> str:
"""Replaces underscores with dashes in the matched string."""
return match.group(0).replace("_", "-")
# Everything between the first -- and the first .
pattern = re.compile(r"(?<=--)[^\.]*")
# Convert underscores to dashes and vice versa in argument names
processed_args = list[str]()
for i, arg in enumerate(args):
if arg.startswith("--help="):
FlexibleArgumentParser._search_keyword = arg.split("=", 1)[-1].lower()
processed_args.append("--help")
elif arg.startswith("--"):
if "=" in arg:
key, value = arg.split("=", 1)
key = pattern.sub(repl, key, count=1)
processed_args.append(f"{key}={value}")
else:
key = pattern.sub(repl, arg, count=1)
processed_args.append(key)
elif arg.startswith("-O") and arg != "-O" and arg[2] != ".":
# allow -O flag to be used without space, e.g. -O3 or -Odecode
# -O.<...> handled later
# also handle -O=<mode> here
mode = arg[3:] if arg[2] == "=" else arg[2:]
processed_args.append(f"-O.mode={mode}")
elif (
arg == "-O"
and i + 1 < len(args)
and args[i + 1] in {"0", "1", "2", "3"}
):
# Convert -O <n> to -O.mode <n>
processed_args.append("-O.mode")
else:
processed_args.append(arg)
def create_nested_dict(keys: list[str], value: str) -> dict[str, Any]:
"""Creates a nested dictionary from a list of keys and a value.
For example, `keys = ["a", "b", "c"]` and `value = 1` will create:
`{"a": {"b": {"c": 1}}}`
"""
nested_dict: Any = value
for key in reversed(keys):
nested_dict = {key: nested_dict}
return nested_dict
def recursive_dict_update(
original: dict[str, Any],
update: dict[str, Any],
) -> set[str]:
"""Recursively updates a dictionary with another dictionary.
Returns a set of duplicate keys that were overwritten.
"""
duplicates = set[str]()
for k, v in update.items():
if isinstance(v, dict) and isinstance(original.get(k), dict):
nested_duplicates = recursive_dict_update(original[k], v)
duplicates |= {f"{k}.{d}" for d in nested_duplicates}
elif isinstance(v, list) and isinstance(original.get(k), list):
original[k] += v
else:
if k in original:
duplicates.add(k)
original[k] = v
return duplicates
delete = set[int]()
dict_args = defaultdict[str, dict[str, Any]](dict)
duplicates = set[str]()
for i, processed_arg in enumerate(processed_args):
if i in delete: # skip if value from previous arg
continue
if processed_arg.startswith("-") and "." in processed_arg:
if "=" in processed_arg:
processed_arg, value_str = processed_arg.split("=", 1)
if "." not in processed_arg:
# False positive, '.' was only in the value
continue
else:
value_str = processed_args[i + 1]
delete.add(i + 1)
if processed_arg.endswith("+"):
processed_arg = processed_arg[:-1]
value_str = json.dumps(list(value_str.split(",")))
key, *keys = processed_arg.split(".")
try:
value = json.loads(value_str)
except json.decoder.JSONDecodeError:
value = value_str
# Merge all values with the same key into a single dict
arg_dict = create_nested_dict(keys, value)
arg_duplicates = recursive_dict_update(dict_args[key], arg_dict)
duplicates |= {f"{key}.{d}" for d in arg_duplicates}
delete.add(i)
# Filter out the dict args we set to None
processed_args = [a for i, a in enumerate(processed_args) if i not in delete]
if duplicates:
logger.warning("Found duplicate keys %s", ", ".join(duplicates))
# Add the dict args back as if they were originally passed as JSON
for dict_arg, dict_value in dict_args.items():
processed_args.append(dict_arg)
processed_args.append(json.dumps(dict_value))
return super().parse_args(processed_args, namespace)
def check_port(self, value):
try:
value = int(value)
except ValueError:
msg = "Port must be an integer"
raise ArgumentTypeError(msg) from None
if not (1024 <= value <= 65535):
raise ArgumentTypeError("Port must be between 1024 and 65535")
return value
def _pull_args_from_config(self, args: list[str]) -> list[str]:
"""Method to pull arguments specified in the config file
into the command-line args variable.
The arguments in config file will be inserted between
the argument list.
example:
```yaml
port: 12323
tensor-parallel-size: 4
```
```python
$: vllm {serve,chat,complete} "facebook/opt-12B" \
--config config.yaml -tp 2
$: args = [
"serve,chat,complete",
"facebook/opt-12B",
'--config', 'config.yaml',
'-tp', '2'
]
$: args = [
"serve,chat,complete",
"facebook/opt-12B",
'--port', '12323',
'--tensor-parallel-size', '4',
'-tp', '2'
]
```
Please note how the config args are inserted after the sub command.
this way the order of priorities is maintained when these are args
parsed by super().
"""
assert args.count("--config") <= 1, "More than one config file specified!"
index = args.index("--config")
if index == len(args) - 1:
raise ValueError(
"No config file specified! \
Please check your command-line arguments."
)
file_path = args[index + 1]
config_args = self.load_config_file(file_path)
# 0th index might be the sub command {serve,chat,complete,...}
# optionally followed by model_tag (only for serve)
# followed by config args
# followed by rest of cli args.
# maintaining this order will enforce the precedence
# of cli > config > defaults
if args[0].startswith("-"):
# No sub command (e.g., api_server entry point)
args = config_args + args[0:index] + args[index + 2 :]
elif args[0] == "serve":
model_in_cli = len(args) > 1 and not args[1].startswith("-")
model_in_config = any(arg == "--model" for arg in config_args)
if not model_in_cli and not model_in_config:
raise ValueError(
"No model specified! Please specify model either "
"as a positional argument or in a config file."
)
if model_in_cli:
# Model specified as positional arg, keep CLI version
args = (
[args[0]]
+ [args[1]]
+ config_args
+ args[2:index]
+ args[index + 2 :]
)
else:
# No model in CLI, use config if available
args = [args[0]] + config_args + args[1:index] + args[index + 2 :]
else:
args = [args[0]] + config_args + args[1:index] + args[index + 2 :]
return args
def load_config_file(self, file_path: str) -> list[str]:
"""Loads a yaml file and returns the key value pairs as a
flattened list with argparse like pattern
```yaml
port: 12323
tensor-parallel-size: 4
```
returns:
processed_args: list[str] = [
'--port': '12323',
'--tensor-parallel-size': '4'
]
"""
extension: str = file_path.split(".")[-1]
if extension not in ("yaml", "yml"):
raise ValueError(
f"Config file must be of a yaml/yml type. {extension} supplied"
)
# only expecting a flat dictionary of atomic types
processed_args: list[str] = []
config: dict[str, int | str] = {}
try:
with open(file_path) as config_file:
config = yaml.safe_load(config_file)
except Exception as ex:
logger.error(
"Unable to read the config file at %s. Check path correctness",
file_path,
)
raise ex
for key, value in config.items():
if isinstance(value, bool):
if value:
processed_args.append("--" + key)
elif isinstance(value, list):
if value:
processed_args.append("--" + key)
for item in value:
processed_args.append(str(item))
else:
processed_args.append("--" + key)
processed_args.append(str(value))
return processed_args

303
utils/async_utils.py Normal file
View File

@@ -0,0 +1,303 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Contains helpers related to asynchronous code.
This is similar in concept to the `asyncio` module.
"""
import asyncio
import contextlib
from asyncio import FIRST_COMPLETED, AbstractEventLoop, Future, Task
from collections.abc import AsyncGenerator, Awaitable, Callable
from concurrent.futures import Executor, ThreadPoolExecutor
from functools import partial
from typing import TypeVar
from transformers.tokenization_utils_base import BatchEncoding
from typing_extensions import ParamSpec
P = ParamSpec("P")
T = TypeVar("T")
class AsyncMicrobatchTokenizer:
"""Asynchronous tokenizer with micro-batching.
Pulls pending encode/decode requests from a queue and batches them
up to reduce overhead. A single-thread ThreadPoolExecutor is used
so the event loop stays responsive.
"""
def __init__(
self,
tokenizer,
max_batch_size: int = 32,
batch_wait_timeout_s: float = 0.002,
) -> None:
self.tokenizer = tokenizer
self.max_batch_size = max_batch_size
self.batch_wait_timeout_s = batch_wait_timeout_s
self._loop = asyncio.get_running_loop()
self._queues: dict[
tuple,
asyncio.Queue[tuple[str, dict, Future] | tuple[list[int], Future]],
] = {}
self._batcher_tasks: list[Task] = []
# Single-thread executor for blocking tokenizer calls.
self._executor = ThreadPoolExecutor(max_workers=1)
# === Public async API ===
async def __call__(self, prompt, **kwargs):
result_future: Future = self._loop.create_future()
key = self._queue_key("encode", kwargs)
queue = self._get_queue(self._loop, key)
await queue.put((prompt, kwargs, result_future))
return await result_future
async def decode(self, token_ids, **kwargs):
result_future: Future = self._loop.create_future()
key = self._queue_key("decode", kwargs)
queue = self._get_queue(self._loop, key)
await queue.put((token_ids, result_future))
return await result_future
# === Internal helpers ===
def _get_queue(
self, loop: asyncio.AbstractEventLoop, key: tuple
) -> asyncio.Queue[tuple[str, dict, Future] | tuple[list[int], Future]]:
"""Get the request queue for the given operation key, creating a new
queue and batcher task if needed."""
queue = self._queues.get(key)
if queue is None:
self._queues[key] = queue = asyncio.Queue()
if key[0] == "encode":
can_batch = key[1] != "other"
coro = self._batch_encode_loop(queue, can_batch)
else:
assert key[0] == "decode", f"Unknown operation type: {key[0]}."
coro = self._batch_decode_loop(queue)
self._batcher_tasks.append(loop.create_task(coro))
return queue
async def _batch_encode_loop(self, queue: asyncio.Queue, can_batch: bool):
"""Batch incoming encode requests for efficiency."""
while True:
prompt, kwargs, result_future = await queue.get()
prompts = [prompt]
kwargs_list = [kwargs]
result_futures = [result_future]
deadline = self._loop.time() + self.batch_wait_timeout_s
while len(prompts) < self.max_batch_size:
timeout = deadline - self._loop.time()
if timeout <= 0:
break
try:
prompt, kwargs, result_future = await asyncio.wait_for(
queue.get(), timeout
)
prompts.append(prompt)
result_futures.append(result_future)
if not can_batch:
kwargs_list.append(kwargs)
except asyncio.TimeoutError:
break
try:
# If every request uses identical kwargs we can run a single
# batched tokenizer call for a big speed-up.
if can_batch and len(prompts) > 1:
batch_encode_fn = partial(self.tokenizer, prompts, **kwargs)
results = await self._loop.run_in_executor(
self._executor, batch_encode_fn
)
for i, fut in enumerate(result_futures):
if not fut.done():
data = {k: v[i] for k, v in results.items()}
fut.set_result(BatchEncoding(data))
else:
encode_fn = lambda prompts=prompts, kwargs=kwargs_list: [
self.tokenizer(p, **kw) for p, kw in zip(prompts, kwargs)
]
results = await self._loop.run_in_executor(
self._executor, encode_fn
)
for fut, res in zip(result_futures, results):
if not fut.done():
fut.set_result(res)
except Exception as e:
for fut in result_futures:
if not fut.done():
fut.set_exception(e)
async def _batch_decode_loop(self, queue: asyncio.Queue):
"""Batch incoming decode requests for efficiency."""
while True:
token_ids, result_future = await queue.get()
token_ids_list = [token_ids]
result_futures = [result_future]
deadline = self._loop.time() + self.batch_wait_timeout_s
while len(token_ids_list) < self.max_batch_size:
timeout = deadline - self._loop.time()
if timeout <= 0:
break
try:
token_ids, result_future = await asyncio.wait_for(
queue.get(), timeout
)
token_ids_list.append(token_ids)
result_futures.append(result_future)
except asyncio.TimeoutError:
break
try:
# Perform a single batched decode call for all requests
results = await self._loop.run_in_executor(
self._executor, self.tokenizer.batch_decode, token_ids_list
)
for fut, res in zip(result_futures, results):
if not fut.done():
fut.set_result(res)
except Exception as e:
for fut in result_futures:
if not fut.done():
fut.set_exception(e)
def _queue_key(self, op: str, kwargs: dict) -> tuple:
"""
Return a normalized key describing operation + kwargs.
- `add_special_tokens`: {True/False}
- `truncation`: {True/False}
- If `truncation` is False (`max_length` is None),
returns a key for a can_batch queue.
- If `truncation` is True and `max_length` is None or equals
`tokenizer.model_max_length`, returns a key for a can_batch queue.
- Otherwise, returns a key for a cannot_batch queue.
Examples:
- Decode: ("decode",)
- Encode typical:
("encode", add_special_tokens, bool_truncation, max_length_label)
- Fallback: ("encode", "other")
"""
if op == "decode":
return ("decode",)
add_special_tokens = kwargs.get("add_special_tokens", True)
truncation = kwargs.get("truncation", False)
max_length = kwargs.get("max_length")
if not truncation:
return "encode", add_special_tokens, False, None
model_max = getattr(self.tokenizer, "model_max_length", None)
if max_length is None or (model_max is not None and max_length == model_max):
return "encode", add_special_tokens, True, "model_max"
return "encode", "other"
def __del__(self):
if (
(tasks := getattr(self, "_batcher_tasks", None))
and (loop := getattr(self, "_loop", None))
and not loop.is_closed()
):
def cancel_tasks():
for task in tasks:
task.cancel()
loop.call_soon_threadsafe(cancel_tasks)
def cancel_task_threadsafe(task: Task):
if task and not task.done():
run_in_loop(task.get_loop(), task.cancel)
def make_async(
func: Callable[P, T],
executor: Executor | None = None,
) -> Callable[P, Awaitable[T]]:
"""
Take a blocking function, and run it on in an executor thread.
This function prevents the blocking function from blocking the
asyncio event loop.
The code in this function needs to be thread safe.
"""
def _async_wrapper(*args: P.args, **kwargs: P.kwargs) -> Future[T]:
loop = asyncio.get_event_loop()
p_func = partial(func, *args, **kwargs)
return loop.run_in_executor(executor=executor, func=p_func)
return _async_wrapper
def run_in_loop(loop: AbstractEventLoop, function: Callable, *args):
if in_loop(loop):
function(*args)
elif not loop.is_closed():
loop.call_soon_threadsafe(function, *args)
def in_loop(event_loop: AbstractEventLoop) -> bool:
try:
return asyncio.get_running_loop() == event_loop
except RuntimeError:
return False
async def merge_async_iterators(
*iterators: AsyncGenerator[T, None],
) -> AsyncGenerator[tuple[int, T], None]:
"""Merge multiple asynchronous iterators into a single iterator.
This method handle the case where some iterators finish before others.
When it yields, it yields a tuple (i, item) where i is the index of the
iterator that yields the item.
"""
if len(iterators) == 1:
# Fast-path single iterator case.
async for item in iterators[0]:
yield 0, item
return
loop = asyncio.get_running_loop()
awaits = {loop.create_task(anext(it)): (i, it) for i, it in enumerate(iterators)}
try:
while awaits:
done, _ = await asyncio.wait(awaits.keys(), return_when=FIRST_COMPLETED)
for d in done:
pair = awaits.pop(d)
try:
item = await d
i, it = pair
awaits[loop.create_task(anext(it))] = pair
yield i, item
except StopAsyncIteration:
pass
finally:
# Cancel any remaining iterators
for f, (_, it) in awaits.items():
with contextlib.suppress(BaseException):
f.cancel()
await it.aclose()
async def collect_from_async_generator(iterator: AsyncGenerator[T, None]) -> list[T]:
"""Collect all items from an async generator into a list."""
items = []
async for item in iterator:
items.append(item)
return items

214
utils/cache.py Normal file
View File

@@ -0,0 +1,214 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections import UserDict
from collections.abc import Callable, Hashable, Iterator, KeysView, Mapping
from types import MappingProxyType
from typing import NamedTuple, TypeVar, cast, overload
import cachetools
_K = TypeVar("_K", bound=Hashable)
_V = TypeVar("_V")
_T = TypeVar("_T")
class _Sentinel: ...
ALL_PINNED_SENTINEL = _Sentinel()
class _MappingOrderCacheView(UserDict[_K, _V]):
def __init__(self, data: Mapping[_K, _V], ordered_keys: Mapping[_K, None]):
super().__init__(data)
self.ordered_keys = ordered_keys
def __iter__(self) -> Iterator[_K]:
return iter(self.ordered_keys)
def keys(self) -> KeysView[_K]:
return KeysView(self.ordered_keys)
class CacheInfo(NamedTuple):
hits: int
total: int
@property
def hit_ratio(self) -> float:
if self.total == 0:
return 0
return self.hits / self.total
def __sub__(self, other: "CacheInfo"):
return CacheInfo(
hits=self.hits - other.hits,
total=self.total - other.total,
)
class LRUCache(cachetools.LRUCache[_K, _V]):
def __init__(self, capacity: float, getsizeof: Callable[[_V], float] | None = None):
super().__init__(capacity, getsizeof)
self.pinned_items = set[_K]()
self._hits = 0
self._total = 0
self._last_info = CacheInfo(hits=0, total=0)
def __getitem__(self, key: _K, *, update_info: bool = True) -> _V:
value = super().__getitem__(key)
if update_info:
self._hits += 1
self._total += 1
return value
def __delitem__(self, key: _K) -> None:
run_on_remove = key in self
value = self.__getitem__(key, update_info=False) # type: ignore[call-arg]
super().__delitem__(key)
if key in self.pinned_items:
# Todo: add warning to inform that del pinned item
self._unpin(key)
if run_on_remove:
self._on_remove(key, value)
@property
def cache(self) -> Mapping[_K, _V]:
"""Return the internal cache dictionary in order (read-only)."""
return _MappingOrderCacheView(
self._Cache__data, # type: ignore
self.order,
)
@property
def order(self) -> Mapping[_K, None]:
"""Return the internal order dictionary (read-only)."""
return MappingProxyType(self._LRUCache__order) # type: ignore
@property
def capacity(self) -> float:
return self.maxsize
@property
def usage(self) -> float:
if self.maxsize == 0:
return 0
return self.currsize / self.maxsize
def stat(self, *, delta: bool = False) -> CacheInfo:
"""
Gets the cumulative number of hits and queries against this cache.
If `delta=True`, instead gets these statistics
since the last call that also passed `delta=True`.
"""
info = CacheInfo(hits=self._hits, total=self._total)
if delta:
info_delta = info - self._last_info
self._last_info = info
info = info_delta
return info
def touch(self, key: _K) -> None:
try:
self._LRUCache__order.move_to_end(key) # type: ignore
except KeyError:
self._LRUCache__order[key] = None # type: ignore
@overload
def get(self, key: _K, /) -> _V | None: ...
@overload
def get(self, key: _K, /, default: _V | _T) -> _V | _T: ...
def get(self, key: _K, /, default: _V | _T | None = None) -> _V | _T | None:
value: _V | _T | None
if key in self:
value = self.__getitem__(key, update_info=False) # type: ignore[call-arg]
self._hits += 1
else:
value = default
self._total += 1
return value
@overload
def pop(self, key: _K) -> _V: ...
@overload
def pop(self, key: _K, default: _V | _T) -> _V | _T: ...
def pop(self, key: _K, default: _V | _T | None = None) -> _V | _T | None:
value: _V | _T | None
if key not in self:
return default
value = self.__getitem__(key, update_info=False) # type: ignore[call-arg]
self.__delitem__(key)
return value
def put(self, key: _K, value: _V) -> None:
self.__setitem__(key, value)
def pin(self, key: _K) -> None:
"""
Pins a key in the cache preventing it from being
evicted in the LRU order.
"""
if key not in self:
raise ValueError(f"Cannot pin key: {key} not in cache.")
self.pinned_items.add(key)
def _unpin(self, key: _K) -> None:
"""
Unpins a key in the cache allowing it to be
evicted in the LRU order.
"""
self.pinned_items.remove(key)
def _on_remove(self, key: _K, value: _V | None) -> None:
pass
def remove_oldest(self, *, remove_pinned: bool = False) -> None:
if len(self) == 0:
return
self.popitem(remove_pinned=remove_pinned)
def _remove_old_if_needed(self) -> None:
while self.currsize > self.capacity:
self.remove_oldest()
def popitem(self, remove_pinned: bool = False):
"""Remove and return the `(key, value)` pair least recently used."""
if not remove_pinned:
# pop the oldest item in the cache that is not pinned
lru_key = next(
(key for key in self.order if key not in self.pinned_items),
ALL_PINNED_SENTINEL,
)
if lru_key is ALL_PINNED_SENTINEL:
raise RuntimeError(
"All items are pinned, cannot remove oldest from the cache."
)
else:
lru_key = next(iter(self.order))
value = self.pop(cast(_K, lru_key))
return (lru_key, value)
def clear(self) -> None:
while len(self) > 0:
self.remove_oldest(remove_pinned=True)
self._hits = 0
self._total = 0
self._last_info = CacheInfo(hits=0, total=0)

139
utils/collection_utils.py Normal file
View File

@@ -0,0 +1,139 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Contains helpers that are applied to collections.
This is similar in concept to the `collections` module.
"""
from collections import UserDict, defaultdict
from collections.abc import Callable, Generator, Hashable, Iterable, Mapping
from typing import Generic, Literal, TypeVar
from typing_extensions import TypeIs, assert_never
T = TypeVar("T")
U = TypeVar("U")
_K = TypeVar("_K", bound=Hashable)
_V = TypeVar("_V")
class ClassRegistry(UserDict[type[T], _V]):
"""
A registry that acts like a dictionary but searches for other classes
in the MRO if the original class is not found.
"""
def __getitem__(self, key: type[T]) -> _V:
for cls in key.mro():
if cls in self.data:
return self.data[cls]
raise KeyError(key)
def __contains__(self, key: object) -> bool:
return self.contains(key)
def contains(self, key: object, *, strict: bool = False) -> bool:
if not isinstance(key, type):
return False
if strict:
return key in self.data
return any(cls in self.data for cls in key.mro())
class LazyDict(Mapping[str, T], Generic[T]):
"""
Evaluates dictionary items only when they are accessed.
Adapted from: https://stackoverflow.com/a/47212782/5082708
"""
def __init__(self, factory: dict[str, Callable[[], T]]):
self._factory = factory
self._dict: dict[str, T] = {}
def __getitem__(self, key: str) -> T:
if key not in self._dict:
if key not in self._factory:
raise KeyError(key)
self._dict[key] = self._factory[key]()
return self._dict[key]
def __setitem__(self, key: str, value: Callable[[], T]):
self._factory[key] = value
def __iter__(self):
return iter(self._factory)
def __len__(self):
return len(self._factory)
def as_list(maybe_list: Iterable[T]) -> list[T]:
"""Convert iterable to list, unless it's already a list."""
return maybe_list if isinstance(maybe_list, list) else list(maybe_list)
def as_iter(obj: T | Iterable[T]) -> Iterable[T]:
if isinstance(obj, str) or not isinstance(obj, Iterable):
return [obj] # type: ignore[list-item]
return obj
def is_list_of(
value: object,
typ: type[T] | tuple[type[T], ...],
*,
check: Literal["first", "all"] = "first",
) -> TypeIs[list[T]]:
if not isinstance(value, list):
return False
if check == "first":
return len(value) == 0 or isinstance(value[0], typ)
elif check == "all":
return all(isinstance(v, typ) for v in value)
assert_never(check)
def chunk_list(lst: list[T], chunk_size: int) -> Generator[list[T]]:
"""Yield successive chunk_size chunks from lst."""
for i in range(0, len(lst), chunk_size):
yield lst[i : i + chunk_size]
def flatten_2d_lists(lists: Iterable[Iterable[T]]) -> list[T]:
"""Flatten a list of lists to a single list."""
return [item for sublist in lists for item in sublist]
def full_groupby(values: Iterable[_V], *, key: Callable[[_V], _K]):
"""
Unlike [`itertools.groupby`][], groups are not broken by
non-contiguous data.
"""
groups = defaultdict[_K, list[_V]](list)
for value in values:
groups[key(value)].append(value)
return groups.items()
def swap_dict_values(obj: dict[_K, _V], key1: _K, key2: _K) -> None:
"""Swap values between two keys."""
v1 = obj.get(key1)
v2 = obj.get(key2)
if v1 is not None:
obj[key2] = v1
else:
obj.pop(key2, None)
if v2 is not None:
obj[key1] = v2
else:
obj.pop(key1, None)

45
utils/counter.py Normal file
View File

@@ -0,0 +1,45 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import threading
class Counter:
def __init__(self, start: int = 0) -> None:
super().__init__()
self.counter = start
def __next__(self) -> int:
i = self.counter
self.counter += 1
return i
def reset(self) -> None:
self.counter = 0
class AtomicCounter:
"""An atomic, thread-safe counter"""
def __init__(self, initial: int = 0) -> None:
"""Initialize a new atomic counter to given initial value"""
super().__init__()
self._value = initial
self._lock = threading.Lock()
@property
def value(self) -> int:
return self._value
def inc(self, num: int = 1) -> int:
"""Atomically increment the counter by num and return the new value"""
with self._lock:
self._value += num
return self._value
def dec(self, num: int = 1) -> int:
"""Atomically decrement the counter by num and return the new value"""
with self._lock:
self._value -= num
return self._value

391
utils/deep_gemm.py Normal file
View File

@@ -0,0 +1,391 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Compatibility wrapper for DeepGEMM API changes.
Users of vLLM should always import **only** these wrappers.
"""
import functools
import importlib
import os
from collections.abc import Callable
from enum import Enum
from typing import Any, NoReturn
import torch
import vllm.envs as envs
from vllm.logger import logger
from vllm.platforms import current_platform
from vllm.utils.import_utils import has_deep_gemm
from vllm.utils.math_utils import cdiv
class DeepGemmQuantScaleFMT(Enum):
# Float32 scales in Float32 tensor
FLOAT32 = 0
# Compute float32 scales and ceil the scales to UE8M0.
# Keep the scales in Float32 tensor.
FLOAT32_CEIL_UE8M0 = 1
# Compute float32 scales and ceil the scales to UE8M0.
# Pack the scales into a int32 tensor where each int32
# element contains 4 scale values.
UE8M0 = 2
@staticmethod
def from_oracle() -> "DeepGemmQuantScaleFMT":
if not is_deep_gemm_e8m0_used():
return DeepGemmQuantScaleFMT.FLOAT32
return (
DeepGemmQuantScaleFMT.UE8M0
if current_platform.is_device_capability(100)
else DeepGemmQuantScaleFMT.FLOAT32_CEIL_UE8M0
)
@functools.cache
def is_deep_gemm_supported() -> bool:
"""Return `True` if DeepGEMM is supported on the current platform.
Currently, only Hopper and Blackwell GPUs are supported.
"""
is_supported_arch = current_platform.is_cuda() and (
current_platform.is_device_capability(90)
or current_platform.is_device_capability(100)
)
return envs.VLLM_USE_DEEP_GEMM and has_deep_gemm() and is_supported_arch
@functools.cache
def is_deep_gemm_e8m0_used() -> bool:
"""Return `True` if vLLM is configured to use DeepGEMM "
"E8M0 scale on a Hopper or Blackwell-class GPU.
"""
if not is_deep_gemm_supported():
logger.debug_once(
"DeepGEMM E8M0 disabled: DeepGEMM not supported on this system."
)
return False
_lazy_init()
if _fp8_gemm_nt_impl is None:
logger.info_once("DeepGEMM E8M0 disabled: _fp8_gemm_nt_impl not found")
return False
if envs.VLLM_USE_DEEP_GEMM_E8M0:
logger.info_once("DeepGEMM E8M0 enabled on current platform.")
return True
logger.info_once("DeepGEMM E8M0 disabled on current configuration.")
return False
def _missing(*_: Any, **__: Any) -> NoReturn:
"""Placeholder for unavailable DeepGEMM backend."""
raise RuntimeError(
"DeepGEMM backend is not available or outdated. Please install or "
"update the `deep_gemm` to a newer version to enable FP8 kernels."
)
_fp8_gemm_nt_impl: Callable[..., Any] | None = None
_grouped_impl: Callable[..., Any] | None = None
_grouped_masked_impl: Callable[..., Any] | None = None
_fp8_mqa_logits_impl: Callable[..., Any] | None = None
_fp8_paged_mqa_logits_impl: Callable[..., Any] | None = None
_get_paged_mqa_logits_metadata_impl: Callable[..., Any] | None = None
_get_mn_major_tma_aligned_tensor_impl: Callable[..., Any] | None = None
_get_mk_alignment_for_contiguous_layout_impl: Callable[..., Any] | None = None
_transform_sf_into_required_layout_impl: Callable[..., Any] | None = None
def _lazy_init() -> None:
"""Import deep_gemm and resolve symbols on first use."""
global _fp8_gemm_nt_impl, _grouped_impl, _grouped_masked_impl
global _fp8_mqa_logits_impl, _fp8_paged_mqa_logits_impl
global _get_paged_mqa_logits_metadata_impl
global _get_mn_major_tma_aligned_tensor_impl
global _get_mk_alignment_for_contiguous_layout_impl
global _transform_sf_into_required_layout_impl
# fast path
if (
_fp8_gemm_nt_impl is not None
or _grouped_impl is not None
or _grouped_masked_impl is not None
or _fp8_mqa_logits_impl is not None
or _fp8_paged_mqa_logits_impl is not None
or _get_paged_mqa_logits_metadata_impl is not None
or _get_mk_alignment_for_contiguous_layout_impl is not None
or _transform_sf_into_required_layout_impl is not None
):
return
if not has_deep_gemm():
return
# Set up deep_gemm cache path
DEEP_GEMM_JIT_CACHE_ENV_NAME = "DG_JIT_CACHE_DIR"
if not os.environ.get(DEEP_GEMM_JIT_CACHE_ENV_NAME, None):
os.environ[DEEP_GEMM_JIT_CACHE_ENV_NAME] = os.path.join(
envs.VLLM_CACHE_ROOT, "deep_gemm"
)
_dg = importlib.import_module("deep_gemm")
_fp8_gemm_nt_impl = getattr(_dg, "fp8_gemm_nt", None)
_grouped_impl = getattr(_dg, "m_grouped_fp8_gemm_nt_contiguous", None)
_grouped_masked_impl = getattr(_dg, "fp8_m_grouped_gemm_nt_masked", None)
_fp8_mqa_logits_impl = getattr(_dg, "fp8_mqa_logits", None)
_fp8_paged_mqa_logits_impl = getattr(_dg, "fp8_paged_mqa_logits", None)
_get_paged_mqa_logits_metadata_impl = getattr(
_dg, "get_paged_mqa_logits_metadata", None
)
_get_mn_major_tma_aligned_tensor_impl = getattr(
_dg, "get_mn_major_tma_aligned_tensor", None
)
_get_mk_alignment_for_contiguous_layout_impl = getattr(
_dg, "get_mk_alignment_for_contiguous_layout", None
)
_transform_sf_into_required_layout_impl = getattr(
_dg, "transform_sf_into_required_layout", None
)
def get_num_sms() -> int:
_lazy_init()
_dg = importlib.import_module("deep_gemm")
return int(_dg.get_num_sms())
@functools.cache
def get_mk_alignment_for_contiguous_layout() -> list[int]:
_lazy_init()
if _get_mk_alignment_for_contiguous_layout_impl is None:
return _missing()
mk_align_size = _get_mk_alignment_for_contiguous_layout_impl()
return [mk_align_size, mk_align_size]
def get_col_major_tma_aligned_tensor(x: torch.Tensor) -> torch.Tensor:
"""Wrapper for DeepGEMM's get_mn_major_tma_aligned_tensor"""
_lazy_init()
if _get_mn_major_tma_aligned_tensor_impl is None:
return _missing()
return _get_mn_major_tma_aligned_tensor_impl(x)
def fp8_gemm_nt(*args, **kwargs):
_lazy_init()
if _fp8_gemm_nt_impl is None:
return _missing(*args, **kwargs)
if "is_deep_gemm_e8m0_used" in kwargs:
use_ue8m0 = kwargs["is_deep_gemm_e8m0_used"]
del kwargs["is_deep_gemm_e8m0_used"]
else:
use_ue8m0 = is_deep_gemm_e8m0_used()
return _fp8_gemm_nt_impl(*args, disable_ue8m0_cast=not use_ue8m0, **kwargs)
def m_grouped_fp8_gemm_nt_contiguous(*args, **kwargs):
_lazy_init()
if _grouped_impl is None:
return _missing(*args, **kwargs)
return _grouped_impl(
*args, disable_ue8m0_cast=not is_deep_gemm_e8m0_used(), **kwargs
)
def fp8_m_grouped_gemm_nt_masked(*args, **kwargs):
_lazy_init()
if _grouped_masked_impl is None:
return _missing(*args, **kwargs)
return _grouped_masked_impl(
*args, disable_ue8m0_cast=not is_deep_gemm_e8m0_used(), **kwargs
)
def transform_sf_into_required_layout(*args, **kwargs):
_lazy_init()
if _transform_sf_into_required_layout_impl is None:
return _missing(*args, **kwargs)
return _transform_sf_into_required_layout_impl(
*args, disable_ue8m0_cast=not is_deep_gemm_e8m0_used(), **kwargs
)
def fp8_mqa_logits(
q: torch.Tensor,
kv: tuple[torch.Tensor, torch.Tensor],
weights: torch.Tensor,
cu_seqlen_ks: torch.Tensor,
cu_seqlen_ke: torch.Tensor,
) -> torch.Tensor:
"""Compute FP8 MQA logits for a single sequence without KV paging.
Args:
q: Query tensor of shape [M, H, D]. Casted to
`torch.float8_e4m3fn` by caller.
kv: Tuple `(k_fp8, k_scales)` where `k_fp8` has shape [N, D] with
dtype `torch.float8_e4m3fn` and `k_scales` has shape [N] (or
[N, 1]) with dtype `torch.float32`.
weights: weights of shape [M, H], dtype `torch.float32`.
cu_seqlen_ks: Start indices (inclusive) for valid K per query position,
shape [M], dtype int32.
cu_seqlen_ke: End indices (exclusive) for valid K per query position,
shape [M], dtype int32.
Returns:
Logits tensor of shape [M, N], dtype `torch.float32`.
"""
_lazy_init()
if _fp8_mqa_logits_impl is None:
return _missing()
return _fp8_mqa_logits_impl(q, kv, weights, cu_seqlen_ks, cu_seqlen_ke)
def get_paged_mqa_logits_metadata(
context_lens: torch.Tensor, block_size: int, num_sms: int
) -> torch.Tensor:
"""Build scheduling metadata for paged MQA logits.
Args:
context_lens: Tensor of shape [B], dtype int32; effective context length
per batch element.
block_size: KV-cache block size in tokens (e.g., 64).
num_sms: Number of SMs available. 132 for Hopper
Returns:
Backend-specific tensor consumed by `fp8_paged_mqa_logits` to
schedule work across SMs.
"""
_lazy_init()
if _get_paged_mqa_logits_metadata_impl is None:
return _missing()
return _get_paged_mqa_logits_metadata_impl(context_lens, block_size, num_sms)
def fp8_paged_mqa_logits(
q_fp8: torch.Tensor,
kv_cache_fp8: torch.Tensor,
weights: torch.Tensor,
context_lens: torch.Tensor,
block_tables: torch.Tensor,
schedule_metadata: torch.Tensor,
max_model_len: int,
) -> torch.Tensor:
"""Compute FP8 MQA logits using paged KV-cache.
Args:
q_fp8: Query tensor of shape [B, next_n, H, D]. Casted to
`torch.float8_e4m3fn` by caller.
kv_cache_fp8: Paged KV-cache in packed FP8+scale layout with shape
[num_blocks, block_size, 1, D+4], dtype `torch.uint8`. The last
4 bytes per (block,pos) store the `float` dequant scale.
weights: Tensor of shape [B * next_n, H], dtype `torch.float32`.
context_lens: Tensor of shape [B], dtype int32; effective context length
for each batch element.
block_tables: Tensor of shape [B, max_blocks], dtype int32; maps logical
block indices to physical blocks in the paged cache.
schedule_metadata: Returned by `get_paged_mqa_logits_metadata`;
used to distribute work across SMs.
max_model_len: Maximum sequence length used to size the logits output.
Returns:
Logits tensor of shape [B * next_n, max_model_len], dtype
`torch.float32`.
"""
_lazy_init()
if _fp8_paged_mqa_logits_impl is None:
return _missing()
return _fp8_paged_mqa_logits_impl(
q_fp8,
kv_cache_fp8,
weights,
context_lens,
block_tables,
schedule_metadata,
max_model_len,
clean_logits=True,
)
def _ceil_to_ue8m0(x: torch.Tensor):
return torch.pow(2.0, torch.ceil(torch.log2(x.abs())))
def _align(x: int, y: int) -> int:
return cdiv(x, y) * y
DEFAULT_BLOCK_SIZE = [128, 128]
# Taken from https://github.com/deepseek-ai/DeepGEMM/blob/dd6ed14acbc7445dcef224248a77ab4d22b5f240/deep_gemm/utils/math.py#L38
@torch.compile(dynamic=True, backend=current_platform.simple_compile_backend)
def per_block_cast_to_fp8(
x: torch.Tensor, block_size: list[int] = DEFAULT_BLOCK_SIZE, use_ue8m0: bool = False
) -> tuple[torch.Tensor, torch.Tensor]:
assert x.dim() == 2
m, n = x.shape
block_m, block_n = block_size
x_padded = torch.zeros(
(_align(m, block_m), _align(n, block_n)), dtype=x.dtype, device=x.device
)
x_padded[:m, :n] = x
x_view = x_padded.view(-1, block_m, x_padded.size(1) // block_n, block_n)
x_amax = x_view.abs().float().amax(dim=(1, 3), keepdim=True).clamp(1e-4)
sf = x_amax / 448.0
sf = _ceil_to_ue8m0(sf) if use_ue8m0 else sf
x_scaled = (x_view * (1.0 / sf)).to(torch.float8_e4m3fn)
return x_scaled.view_as(x_padded)[:m, :n].contiguous(), sf.view(
x_view.size(0), x_view.size(2)
)
def calc_diff(x: torch.Tensor, y: torch.Tensor):
"""Return a global difference metric for unit tests.
DeepGEMM kernels on Blackwell/B200 currently exhibit noticeable per-element
error, causing `torch.testing.assert_close` to fail. Instead of checking
every element, we compute a cosine-style similarity over the whole tensor
and report `1 - sim`. Once kernel accuracy improves this helper can be
removed.
"""
x, y = x.double(), y.double()
denominator = (x * x + y * y).sum()
sim = 2 * (x * y).sum() / denominator
return 1 - sim
def should_use_deepgemm_for_fp8_linear(
output_dtype: torch.dtype,
weight: torch.Tensor,
supports_deep_gemm: bool | None = None,
):
if supports_deep_gemm is None:
supports_deep_gemm = is_deep_gemm_supported()
return (
supports_deep_gemm
and output_dtype == torch.bfloat16
and weight.shape[0] % 128 == 0
and weight.shape[1] % 128 == 0
)
__all__ = [
"calc_diff",
"fp8_gemm_nt",
"m_grouped_fp8_gemm_nt_contiguous",
"fp8_m_grouped_gemm_nt_masked",
"fp8_mqa_logits",
"fp8_paged_mqa_logits",
"get_paged_mqa_logits_metadata",
"per_block_cast_to_fp8",
"is_deep_gemm_e8m0_used",
"is_deep_gemm_supported",
"get_num_sms",
"should_use_deepgemm_for_fp8_linear",
"get_col_major_tma_aligned_tensor",
"get_mk_alignment_for_contiguous_layout",
]

490
utils/flashinfer.py Normal file
View File

@@ -0,0 +1,490 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Compatibility wrapper for FlashInfer API changes.
Users of vLLM should always import **only** these wrappers.
"""
import contextlib
import functools
import importlib
import importlib.util
import os
import shutil
from collections.abc import Callable
from typing import Any, NoReturn
import requests
import torch
import vllm.envs as envs
from vllm.logger import init_logger
from vllm.model_executor.layers.batch_invariant import (
vllm_is_batch_invariant,
)
from vllm.platforms import current_platform
logger = init_logger(__name__)
# This is the storage path for the cubins, it can be replaced
# with a local path for testing.
# Referenced from https://github.com/flashinfer-ai/flashinfer/blob/0c9a92c3d9a7e043ab6f3f7b2273269caf6ab044/flashinfer/jit/cubin_loader.py#L35 # noqa: E501
FLASHINFER_CUBINS_REPOSITORY = os.environ.get(
"FLASHINFER_CUBINS_REPOSITORY",
"https://edge.urm.nvidia.com/artifactory/sw-kernelinferencelibrary-public-generic-local/", # noqa: E501
)
@functools.cache
def has_flashinfer_cubin() -> bool:
"""Return `True` if flashinfer-cubin package is available."""
if envs.VLLM_HAS_FLASHINFER_CUBIN:
return True
if importlib.util.find_spec("flashinfer_cubin") is not None:
return True
logger.debug_once("flashinfer-cubin package was not found")
return False
@functools.cache
def has_flashinfer() -> bool:
"""Return `True` if flashinfer-python package is available."""
# Use find_spec to check if the module exists without importing it
# This avoids potential CUDA initialization side effects
if importlib.util.find_spec("flashinfer") is None:
logger.debug_once("FlashInfer unavailable since package was not found")
return False
# When not using flashinfer cubin,
# Also check if nvcc is available since it's required to JIT compile flashinfer
if not has_flashinfer_cubin() and shutil.which("nvcc") is None:
logger.debug_once(
"FlashInfer unavailable since nvcc was not found "
"and not using pre-downloaded cubins"
)
return False
return True
def _missing(*_: Any, **__: Any) -> NoReturn:
"""Placeholder for unavailable FlashInfer backend."""
raise RuntimeError(
"FlashInfer backend is not available. Please install the package "
"to enable FlashInfer kernels: "
"https://github.com/flashinfer-ai/flashinfer"
)
def _get_submodule(module_name: str) -> Any | None:
"""Safely import a submodule and return it, or None if not available."""
try:
return importlib.import_module(module_name)
except (ImportError, ModuleNotFoundError):
return None
# General lazy import wrapper
def _lazy_import_wrapper(
module_name: str, attr_name: str, fallback_fn: Callable[..., Any] = _missing
):
"""Create a lazy import wrapper for a specific function."""
@functools.cache
def _get_impl():
if not has_flashinfer():
return None
mod = _get_submodule(module_name)
return getattr(mod, attr_name, None) if mod else None
def wrapper(*args, **kwargs):
impl = _get_impl()
if impl is None:
return fallback_fn(*args, **kwargs)
return impl(*args, **kwargs)
return wrapper
# Create lazy wrappers for each function
flashinfer_trtllm_fp8_block_scale_moe = _lazy_import_wrapper(
"flashinfer.fused_moe", "trtllm_fp8_block_scale_moe"
)
flashinfer_trtllm_fp8_per_tensor_scale_moe = _lazy_import_wrapper(
"flashinfer.fused_moe", "trtllm_fp8_per_tensor_scale_moe"
)
flashinfer_cutlass_fused_moe = _lazy_import_wrapper(
"flashinfer.fused_moe", "cutlass_fused_moe"
)
flashinfer_fp4_quantize = _lazy_import_wrapper("flashinfer", "fp4_quantize")
nvfp4_block_scale_interleave = _lazy_import_wrapper(
"flashinfer", "nvfp4_block_scale_interleave"
)
trtllm_fp4_block_scale_moe = _lazy_import_wrapper(
"flashinfer", "trtllm_fp4_block_scale_moe"
)
# Special case for autotune since it returns a context manager
autotune = _lazy_import_wrapper(
"flashinfer.autotuner",
"autotune",
fallback_fn=lambda *args, **kwargs: contextlib.nullcontext(),
)
@functools.cache
def has_flashinfer_comm() -> bool:
"""Return `True` if FlashInfer comm module is available."""
return has_flashinfer() and importlib.util.find_spec("flashinfer.comm") is not None
@functools.cache
def has_flashinfer_all2all() -> bool:
"""Return `True` if FlashInfer mnnvl all2all is available."""
if not has_flashinfer_comm():
return False
# Check if all required functions are available
required_functions = [
("flashinfer.comm", "Mapping"),
("flashinfer.comm.mnnvl", "MnnvlMemory"),
("flashinfer.comm.trtllm_alltoall", "MnnvlMoe"),
("flashinfer.comm.trtllm_alltoall", "MoEAlltoallInfo"),
]
for module_name, attr_name in required_functions:
mod = _get_submodule(module_name)
if not mod or not hasattr(mod, attr_name):
return False
return True
@functools.cache
def has_flashinfer_moe() -> bool:
"""Return `True` if FlashInfer MoE module is available."""
return (
has_flashinfer()
and importlib.util.find_spec("flashinfer.fused_moe") is not None
)
@functools.cache
def has_flashinfer_cutlass_fused_moe() -> bool:
"""Return `True` if FlashInfer CUTLASS fused MoE is available."""
if not has_flashinfer_moe():
return False
# Check if all required functions are available
required_functions = [
("flashinfer.fused_moe", "cutlass_fused_moe"),
("flashinfer", "fp4_quantize"),
("flashinfer", "nvfp4_block_scale_interleave"),
("flashinfer.fused_moe", "trtllm_fp4_block_scale_moe"),
]
for module_name, attr_name in required_functions:
mod = _get_submodule(module_name)
if not mod or not hasattr(mod, attr_name):
return False
return True
@functools.cache
def has_nvidia_artifactory() -> bool:
"""Return `True` if NVIDIA's artifactory is accessible.
This checks connectivity to the kernel inference library artifactory
which is required for downloading certain cubin kernels like TRTLLM FHMA.
"""
# If we have pre-downloaded cubins, we can assume the cubins are available.
if has_flashinfer_cubin():
return True
try:
# Use a short timeout to avoid blocking for too long
response = requests.get(FLASHINFER_CUBINS_REPOSITORY, timeout=5)
accessible = response.status_code == 200
if accessible:
logger.debug_once("NVIDIA artifactory is accessible")
else:
logger.warning_once(
"NVIDIA artifactory returned failed status code: %d",
response.status_code,
)
return accessible
except Exception as e:
logger.warning_once("Failed to connect to NVIDIA artifactory: %s", e)
return False
@functools.cache
def supports_trtllm_attention() -> bool:
"""
TRTLLM attention is supported if the platform is SM100,
NVIDIA artifactory is accessible, and batch-invariant mode is not enabled.
"""
# Batch-invariant mode disables TRTLLM attention
if vllm_is_batch_invariant():
return False
# Requires SM100 and NVIDIA artifactory to be accessible to download cubins
return current_platform.is_device_capability(100) and has_nvidia_artifactory()
@functools.cache
def _force_use_trtllm_attention(env_value: bool | None) -> bool | None:
"""Cache the env value for VLLM_USE_TRTLLM_ATTENTION"""
if env_value is not None:
logger.info_once("VLLM_USE_TRTLLM_ATTENTION is set to %s", env_value)
return env_value
def force_use_trtllm_attention() -> bool | None:
"""
Return `None` if VLLM_USE_TRTLLM_ATTENTION is not set,
return `True` if TRTLLM attention is forced to be used,
return `False` if TRTLLM attention is forced to be not used.
"""
return _force_use_trtllm_attention(envs.VLLM_USE_TRTLLM_ATTENTION)
def can_use_trtllm_attention(num_qo_heads: int, num_kv_heads: int) -> bool:
"""Check if the current configuration supports TRTLLM attention."""
if force_use_trtllm_attention() is False:
return False
has_trtllm = supports_trtllm_attention()
return has_trtllm and (num_qo_heads % num_kv_heads == 0)
def use_trtllm_attention(
num_qo_heads: int,
num_kv_heads: int,
num_tokens: int,
max_seq_len: int,
dcp_world_size: int,
kv_cache_dtype: str,
q_dtype: torch.dtype,
is_prefill: bool,
has_sinks: bool = False,
has_spec: bool = False,
) -> bool:
"""Return `True` if TRTLLM attention is used."""
force_use_trtllm = force_use_trtllm_attention()
# Environment variable is set to 0 - respect it
if force_use_trtllm is not None and not force_use_trtllm:
return False
# Decode context parallel is not supported
if dcp_world_size > 1:
logger.warning_once(
"Trtllm does not support returning LSE and as a result "
"does not support DCP, reverting to FlashInfer"
)
return False
# The platform is not supported
if not supports_trtllm_attention():
if force_use_trtllm:
logger.warning_once(
"TRTLLM attention is not supported on this platform, "
"but VLLM_USE_TRTLLM_ATTENTION is set to 1"
)
return False
# The combination of query and key heads is not supported
if num_qo_heads % num_kv_heads != 0:
if force_use_trtllm:
logger.warning_once(
"TRTLLM attention is not supported for this combination of "
"query and key heads, but VLLM_USE_TRTLLM_ATTENTION is set to 1"
)
return False
if has_spec and not is_prefill:
# Speculative decoding requires TRTLLM attention for decodes
logger.info_once("Using TRTLLM attention (enabled for speculative decoding).")
return True
# Must use TRTLLM attention if query is FP8 quantized
if q_dtype == current_platform.fp8_dtype():
logger.info_once("Using TRTLLM attention (query is quantized).")
return True
# If sinks are being used, we must use TRTLLM attention as it's
# the only backend that supports them
if has_sinks:
logger.info_once("Using TRTLLM attention (required for attention sinks).")
return True
if force_use_trtllm is None:
# Environment variable not set - use auto-detection
if is_prefill:
# Prefill auto-detection
use_trtllm = kv_cache_dtype == "auto"
if use_trtllm:
logger.warning_once("Using TRTLLM prefill attention (auto-detected).")
else:
# Decode auto-detection
use_trtllm = num_tokens <= 256 and kv_cache_dtype == "auto"
if use_trtllm:
logger.warning_once("Using TRTLLM decode attention (auto-detected).")
return use_trtllm
# Environment variable is set to 1 - respect it
logger.info_once("Using TRTLLM attention (VLLM_USE_TRTLLM_ATTENTION is set to 1)")
return True
if has_flashinfer():
@torch.library.custom_op(
"vllm::flashinfer_mm_fp4",
mutates_args=[],
device_types="cuda",
)
def flashinfer_mm_fp4(
A: torch.Tensor,
B: torch.Tensor,
A_scale: torch.Tensor,
B_scale: torch.Tensor,
g_scale: torch.Tensor,
dtype: torch.dtype,
backend: str,
) -> torch.Tensor:
from flashinfer import mm_fp4 as flashinfer_mm_fp4_
return flashinfer_mm_fp4_(
A, B, A_scale, B_scale, g_scale, dtype, block_size=16, backend=backend
)
@torch.library.register_fake(
"vllm::flashinfer_mm_fp4",
)
def flashinfer_mm_fp4_fake(
A: torch.Tensor,
B: torch.Tensor,
A_scale: torch.Tensor,
B_scale: torch.Tensor,
g_scale: torch.Tensor,
dtype: torch.dtype,
backend: str,
) -> torch.Tensor:
return torch.empty(A.shape[0], B.shape[1], dtype=dtype, device=A.device)
@torch.library.custom_op(
"vllm::bmm_fp8",
mutates_args=[],
device_types="cuda",
)
def bmm_fp8(
A: torch.Tensor,
B: torch.Tensor,
A_scale: torch.Tensor,
B_scale: torch.Tensor,
dtype: torch.dtype,
backend: str,
) -> torch.Tensor:
from flashinfer import bmm_fp8 as bmm_fp8_
return bmm_fp8_(A, B, A_scale, B_scale, dtype, None, backend)
@torch.library.register_fake(
"vllm::bmm_fp8",
)
def bmm_fp8_fake(
A: torch.Tensor,
B: torch.Tensor,
A_scale: torch.Tensor,
B_scale: torch.Tensor,
dtype: torch.dtype,
backend: str,
) -> torch.Tensor:
return torch.empty(
A.shape[0], A.shape[1], B.shape[2], dtype=dtype, device=A.device
)
def flashinfer_scaled_fp4_mm(
a: torch.Tensor,
b: torch.Tensor,
block_scale_a: torch.Tensor,
block_scale_b: torch.Tensor,
alpha: torch.Tensor,
out_dtype: torch.dtype,
backend: str,
) -> torch.Tensor:
assert a.ndim == 2 and b.ndim == 2
assert block_scale_a.ndim == 2 and block_scale_b.ndim == 2
assert a.stride(-1) == 1 and b.stride(-1) == 1
assert a.shape[1] == b.shape[1]
if backend == "cutlass":
block_scale_a = block_scale_a.view(torch.uint8)
block_scale_b = block_scale_b.view(torch.uint8)
return flashinfer_mm_fp4(
a,
b.t(),
block_scale_a,
block_scale_b.t(),
alpha,
out_dtype,
backend=backend,
)
def flashinfer_scaled_fp8_mm(
a: torch.Tensor,
b: torch.Tensor,
scale_a: torch.Tensor,
scale_b: torch.Tensor,
out_dtype: torch.dtype,
bias: torch.Tensor | None = None,
) -> torch.Tensor:
assert a.ndim == 2 and b.ndim == 2
assert a.shape[1] == b.shape[0]
assert scale_a.numel() == 1 and scale_b.numel() == 1
assert a.dtype == torch.float8_e4m3fn and b.dtype == torch.float8_e4m3fn
assert a.device.type == "cuda" and b.device.type == "cuda"
assert scale_a.dtype == torch.float32 and scale_b.dtype == torch.float32
assert scale_a.device.type == "cuda" and scale_b.device.type == "cuda"
output = bmm_fp8(
a.unsqueeze(0),
b.unsqueeze(0),
scale_a,
scale_b,
out_dtype,
"auto",
).view(a.shape[0], b.shape[1])
if bias is not None:
output = output + bias
return output
@functools.cache
def flashinfer_disable_q_quantization() -> bool:
"""Cache result which only depends on the environment"""
return envs.VLLM_FLASHINFER_DISABLE_Q_QUANTIZATION
__all__ = [
"has_flashinfer",
"flashinfer_trtllm_fp8_block_scale_moe",
"flashinfer_cutlass_fused_moe",
"flashinfer_fp4_quantize",
"nvfp4_block_scale_interleave",
"trtllm_fp4_block_scale_moe",
"autotune",
"has_flashinfer_moe",
"has_flashinfer_comm",
"has_flashinfer_all2all",
"has_flashinfer_cutlass_fused_moe",
"has_nvidia_artifactory",
"supports_trtllm_attention",
"can_use_trtllm_attention",
"use_trtllm_attention",
"flashinfer_disable_q_quantization",
"flashinfer_scaled_fp4_mm",
"flashinfer_scaled_fp8_mm",
]

236
utils/func_utils.py Normal file
View File

@@ -0,0 +1,236 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Contains helpers that are applied to functions.
This is similar in concept to the `functools` module.
"""
import inspect
import threading
import warnings
from collections.abc import Callable, Mapping
from functools import lru_cache, partial, wraps
from typing import Any, TypeVar
from typing_extensions import ParamSpec
from vllm.logger import init_logger
logger = init_logger(__name__)
P = ParamSpec("P")
T = TypeVar("T")
F = TypeVar("F", bound=Callable[..., Any])
def identity(value: T, **kwargs) -> T:
"""Returns the first provided value."""
return value
def run_once(f: Callable[P, None]) -> Callable[P, None]:
def wrapper(*args: P.args, **kwargs: P.kwargs) -> None:
if wrapper.has_run: # type: ignore[attr-defined]
return
with wrapper.lock: # type: ignore[attr-defined]
if not wrapper.has_run: # type: ignore[attr-defined]
wrapper.has_run = True # type: ignore[attr-defined]
return f(*args, **kwargs)
wrapper.has_run = False # type: ignore[attr-defined]
wrapper.lock = threading.Lock() # type: ignore[attr-defined]
return wrapper
def deprecate_args(
start_index: int,
is_deprecated: bool | Callable[[], bool] = True,
additional_message: str | None = None,
) -> Callable[[F], F]:
if not callable(is_deprecated):
is_deprecated = partial(identity, is_deprecated)
def wrapper(fn: F) -> F:
params = inspect.signature(fn).parameters
pos_types = (
inspect.Parameter.POSITIONAL_ONLY,
inspect.Parameter.POSITIONAL_OR_KEYWORD,
)
pos_kws = [kw for kw, param in params.items() if param.kind in pos_types]
@wraps(fn)
def inner(*args, **kwargs):
if is_deprecated():
deprecated_args = pos_kws[start_index : len(args)]
if deprecated_args:
msg = (
f"The positional arguments {deprecated_args} are "
"deprecated and will be removed in a future update."
)
if additional_message is not None:
msg += f" {additional_message}"
warnings.warn(
DeprecationWarning(msg),
stacklevel=3, # The inner function takes up one level
)
return fn(*args, **kwargs)
return inner # type: ignore
return wrapper
def deprecate_kwargs(
*kws: str,
is_deprecated: bool | Callable[[], bool] = True,
additional_message: str | None = None,
) -> Callable[[F], F]:
deprecated_kws = set(kws)
if not callable(is_deprecated):
is_deprecated = partial(identity, is_deprecated)
def wrapper(fn: F) -> F:
@wraps(fn)
def inner(*args, **kwargs):
if is_deprecated():
deprecated_kwargs = kwargs.keys() & deprecated_kws
if deprecated_kwargs:
msg = (
f"The keyword arguments {deprecated_kwargs} are "
"deprecated and will be removed in a future update."
)
if additional_message is not None:
msg += f" {additional_message}"
warnings.warn(
DeprecationWarning(msg),
stacklevel=3, # The inner function takes up one level
)
return fn(*args, **kwargs)
return inner # type: ignore
return wrapper
@lru_cache
def supports_kw(
callable: Callable[..., object],
kw_name: str,
*,
requires_kw_only: bool = False,
allow_var_kwargs: bool = True,
) -> bool:
"""Check if a keyword is a valid kwarg for a callable; if requires_kw_only
disallows kwargs names that can also be positional arguments.
"""
params = inspect.signature(callable).parameters
if not params:
return False
param_val = params.get(kw_name)
# Types where the it may be valid, i.e., explicitly defined & nonvariadic
passable_kw_types = set(
(
inspect.Parameter.POSITIONAL_ONLY,
inspect.Parameter.POSITIONAL_OR_KEYWORD,
inspect.Parameter.KEYWORD_ONLY,
)
)
if param_val:
is_sig_param = param_val.kind in passable_kw_types
# We want kwargs only, but this is passable as a positional arg
if (
requires_kw_only
and is_sig_param
and param_val.kind != inspect.Parameter.KEYWORD_ONLY
):
return False
if (requires_kw_only and param_val.kind == inspect.Parameter.KEYWORD_ONLY) or (
not requires_kw_only and is_sig_param
):
return True
# If we're okay with var-kwargs, it's supported as long as
# the kw_name isn't something like *args, **kwargs
if allow_var_kwargs:
# Get the last param; type is ignored here because params is a proxy
# mapping, but it wraps an ordered dict, and they appear in order.
# Ref: https://docs.python.org/3/library/inspect.html#inspect.Signature.parameters
last_param = params[next(reversed(params))] # type: ignore
return (
last_param.kind == inspect.Parameter.VAR_KEYWORD
and last_param.name != kw_name
)
return False
def get_allowed_kwarg_only_overrides(
callable: Callable[..., object],
overrides: Mapping[str, object] | None,
*,
requires_kw_only: bool = True,
allow_var_kwargs: bool = False,
) -> dict[str, Any]:
"""
Given a callable which has one or more keyword only params and a dict
mapping param names to values, drop values that can be not be kwarg
expanded to overwrite one or more keyword-only args. This is used in a
few places to handle custom processor overrides for multimodal models,
e.g., for profiling when processor options provided by the user
may affect the number of mm tokens per instance.
Args:
callable: Callable which takes 0 or more keyword only arguments.
If None is provided, all overrides names are allowed.
overrides: Potential overrides to be used when invoking the callable.
allow_var_kwargs: Allows overrides that are expandable for var kwargs.
Returns:
Dictionary containing the kwargs to be leveraged which may be used
to overwrite one or more keyword only arguments when invoking the
callable.
"""
if not overrides:
return {}
# Drop any mm_processor_kwargs provided by the user that
# are not kwargs, unless it can fit it var_kwargs param
filtered_overrides = {
kwarg_name: val
for kwarg_name, val in overrides.items()
if supports_kw(
callable,
kwarg_name,
requires_kw_only=requires_kw_only,
allow_var_kwargs=allow_var_kwargs,
)
}
# If anything is dropped, log a warning
dropped_keys = overrides.keys() - filtered_overrides.keys()
if dropped_keys:
if requires_kw_only:
logger.warning(
"The following intended overrides are not keyword-only args "
"and will be dropped: %s",
dropped_keys,
)
else:
logger.warning(
"The following intended overrides are not keyword args "
"and will be dropped: %s",
dropped_keys,
)
return filtered_overrides

147
utils/gc_utils.py Normal file
View File

@@ -0,0 +1,147 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import gc
import json
import time
from collections import Counter
from contextlib import suppress
from typing import Any
import vllm.envs as envs
from vllm.logger import init_logger
logger = init_logger(__name__)
class GCDebugConfig:
"""
Config for GC Debugger.
- 0: disable GC debugger
- 1: enable GC debugger with gc.collect elpased times
- '{"top_objects":5}': enable GC debugger with top 5 collected objects
"""
def __init__(self, gc_debug_conf: str | None = None) -> None:
self.enabled: bool = False
self.top_objects: int = -1
if not gc_debug_conf or gc_debug_conf == "0":
pass
elif gc_debug_conf == "1":
self.enabled = True
else:
try:
json_conf = json.loads(gc_debug_conf)
self.enabled = True
self.top_objects = json_conf.get("top_objects", -1)
except Exception:
self.enabled = False
logger.error("Failed to parse VLLM_GC_DEBUG(%s)", envs.VLLM_GC_DEBUG)
logger.debug("GC Debug Config. %s", str(self))
def __repr__(self) -> str:
return f"enabled:{self.enabled},top_objects:{self.top_objects}"
class GCDebugger:
"""
Debugger for GC which logs helpful information for GC understanding.
To enable, you should call maybe_attach_gc_debug_callback in the process.
"""
def __init__(self, config: GCDebugConfig) -> None:
self.config = config
# Start time in micro second of this GC cycle
self.start_time_ns: int = time.monotonic_ns()
# If config.top_objects is positive,
# compute top collected objects by object types
self.gc_top_collected_objects: str = ""
def handle(self, phase: str, info: dict[str, int]) -> None:
"""
Handles a GC event (e.g. GC start or GC finish)
"""
generation = info.get("generation")
if generation is None:
return
if phase == "start":
# Before GC started, record GC start time
# and top collected objects
self.start_time_ns = time.monotonic_ns()
self.gc_top_collected_objects = _compute_top_gc_collected_objects(
gc.get_objects(generation), self.config.top_objects
)
elif phase == "stop":
# After GC finished, Record GC elapsed time and
# optionally top collected objects
elpased_ms = (time.monotonic_ns() - self.start_time_ns) / 1e6
logger.info(
"GC took %.3fms to complete. "
"Collected %s objects in GC generation %d.%s",
elpased_ms,
str(info.get("collected", "?")),
generation,
(
f" Top collected objects: \n{self.gc_top_collected_objects}"
if self.gc_top_collected_objects
else ""
),
)
def freeze_gc_heap() -> None:
"""
Freeze all objects tracked by the garbage collector. It should be invoked
after server init / warmup, to reduce GC overhead from static objects
during serving time.
"""
# Ensure all static objects are pushed down to the oldest generation for
# freeze
gc.collect(0)
gc.collect(1)
gc.collect(2)
# Freeze all GC tracked objects
gc.freeze()
def maybe_attach_gc_debug_callback() -> None:
"""
Attached a callback for GC debug when VLLM_GC_DEBUG is enabled.
"""
config = GCDebugConfig(envs.VLLM_GC_DEBUG)
if config.enabled:
debugger: GCDebugger = GCDebugger(config)
def gc_callback(phase: str, info: dict[str, int]) -> None:
debugger.handle(phase, info)
gc.callbacks.append(gc_callback)
def _compute_detailed_type(o: Any) -> str:
"""
Detailed object type.
TODO(Jialin): Further enhance the detailed type with element types for
easier debugging. We tried but occasionally it would run into signals
which kills the engine.
"""
size_str: str = ""
# Object doesn't support len() - this can happen with type objects
# or other objects that don't implement __len__ properly
with suppress(Exception):
size_str = f"(size:{len(o)})"
return f"{str(type(o))}{size_str}"
def _compute_top_gc_collected_objects(objects: list[Any], top: int) -> str:
"""
Group collected objects by types.
"""
if top <= 0:
return ""
object_types = [_compute_detailed_type(o) for o in objects]
return "\n".join(
f"{count:>5}:{object_type}"
for object_type, count in Counter(object_types).most_common(top)
)

63
utils/hashing.py Normal file
View File

@@ -0,0 +1,63 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from __future__ import annotations
import hashlib
import pickle
from collections.abc import Callable
from typing import Any
import cbor2
def sha256(input: Any) -> bytes:
"""Hash any picklable Python object using SHA-256.
The input is serialized using pickle before hashing, which allows
arbitrary Python objects to be used. Note that this function does
not use a hash seed—if you need one, prepend it explicitly to the input.
Args:
input: Any picklable Python object.
Returns:
Bytes representing the SHA-256 hash of the serialized input.
"""
input_bytes = pickle.dumps(input, protocol=pickle.HIGHEST_PROTOCOL)
return hashlib.sha256(input_bytes).digest()
def sha256_cbor(input: Any) -> bytes:
"""Hash objects using CBOR serialization and SHA-256.
This option is useful for non-Python-dependent serialization and hashing.
Args:
input: Object to be serialized and hashed. Supported types include
basic Python types and complex structures like lists, tuples, and
dictionaries.
Custom classes must implement CBOR serialization methods.
Returns:
Bytes representing the SHA-256 hash of the CBOR serialized input.
"""
input_bytes = cbor2.dumps(input, canonical=True)
return hashlib.sha256(input_bytes).digest()
def get_hash_fn_by_name(hash_fn_name: str) -> Callable[[Any], bytes]:
"""Get a hash function by name, or raise an error if the function is not found.
Args:
hash_fn_name: Name of the hash function.
Returns:
A hash function.
"""
if hash_fn_name == "sha256":
return sha256
if hash_fn_name == "sha256_cbor":
return sha256_cbor
raise ValueError(f"Unsupported hash function: {hash_fn_name}")

411
utils/import_utils.py Normal file
View File

@@ -0,0 +1,411 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Contains helpers related to importing modules.
This is similar in concept to the `importlib` module.
"""
import importlib.metadata
import importlib.util
import os
import sys
from functools import cache
from types import ModuleType
from typing import Any
import regex as re
from typing_extensions import Never
# TODO: This function can be removed if transformer_modules classes are
# serialized by value when communicating between processes
def init_cached_hf_modules() -> None:
"""
Lazy initialization of the Hugging Face modules.
"""
from transformers.dynamic_module_utils import init_hf_modules
init_hf_modules()
def import_pynvml():
"""
Historical comments:
libnvml.so is the library behind nvidia-smi, and
pynvml is a Python wrapper around it. We use it to get GPU
status without initializing CUDA context in the current process.
Historically, there are two packages that provide pynvml:
- `nvidia-ml-py` (https://pypi.org/project/nvidia-ml-py/): The official
wrapper. It is a dependency of vLLM, and is installed when users
install vLLM. It provides a Python module named `pynvml`.
- `pynvml` (https://pypi.org/project/pynvml/): An unofficial wrapper.
Prior to version 12.0, it also provides a Python module `pynvml`,
and therefore conflicts with the official one. What's worse,
the module is a Python package, and has higher priority than
the official one which is a standalone Python file.
This causes errors when both of them are installed.
Starting from version 12.0, it migrates to a new module
named `pynvml_utils` to avoid the conflict.
It is so confusing that many packages in the community use the
unofficial one by mistake, and we have to handle this case.
For example, `nvcr.io/nvidia/pytorch:24.12-py3` uses the unofficial
one, and it will cause errors, see the issue
https://github.com/vllm-project/vllm/issues/12847 for example.
After all the troubles, we decide to copy the official `pynvml`
module to our codebase, and use it directly.
"""
import vllm.third_party.pynvml as pynvml
return pynvml
def import_from_path(module_name: str, file_path: str | os.PathLike):
"""
Import a Python file according to its file path.
Based on the official recipe:
https://docs.python.org/3/library/importlib.html#importing-a-source-file-directly
"""
spec = importlib.util.spec_from_file_location(module_name, file_path)
if spec is None:
raise ModuleNotFoundError(f"No module named {module_name!r}")
assert spec.loader is not None
module = importlib.util.module_from_spec(spec)
sys.modules[module_name] = module
spec.loader.exec_module(module)
return module
def resolve_obj_by_qualname(qualname: str) -> Any:
"""
Resolve an object by its fully-qualified class name.
"""
module_name, obj_name = qualname.rsplit(".", 1)
module = importlib.import_module(module_name)
return getattr(module, obj_name)
@cache
def get_vllm_optional_dependencies():
metadata = importlib.metadata.metadata("vllm")
requirements = metadata.get_all("Requires-Dist", [])
extras = metadata.get_all("Provides-Extra", [])
return {
extra: [
re.split(r";|>=|<=|==", req)[0]
for req in requirements
if req.endswith(f'extra == "{extra}"')
]
for extra in extras
}
class _PlaceholderBase:
"""
Disallows downstream usage of placeholder modules.
We need to explicitly override each dunder method because
[`__getattr__`][vllm.utils.import_utils._PlaceholderBase.__getattr__]
is not called when they are accessed.
Info:
[Special method lookup](https://docs.python.org/3/reference/datamodel.html#special-lookup)
"""
def __getattr__(self, key: str) -> Never:
"""
The main class should implement this to throw an error
for attribute accesses representing downstream usage.
"""
raise NotImplementedError
# [Basic customization]
def __lt__(self, other: object):
return self.__getattr__("__lt__")
def __le__(self, other: object):
return self.__getattr__("__le__")
def __eq__(self, other: object):
return self.__getattr__("__eq__")
def __ne__(self, other: object):
return self.__getattr__("__ne__")
def __gt__(self, other: object):
return self.__getattr__("__gt__")
def __ge__(self, other: object):
return self.__getattr__("__ge__")
def __hash__(self):
return self.__getattr__("__hash__")
def __bool__(self):
return self.__getattr__("__bool__")
# [Callable objects]
def __call__(self, *args: object, **kwargs: object):
return self.__getattr__("__call__")
# [Container types]
def __len__(self):
return self.__getattr__("__len__")
def __getitem__(self, key: object):
return self.__getattr__("__getitem__")
def __setitem__(self, key: object, value: object):
return self.__getattr__("__setitem__")
def __delitem__(self, key: object):
return self.__getattr__("__delitem__")
# __missing__ is optional according to __getitem__ specification,
# so it is skipped
# __iter__ and __reversed__ have a default implementation
# based on __len__ and __getitem__, so they are skipped.
# [Numeric Types]
def __add__(self, other: object):
return self.__getattr__("__add__")
def __sub__(self, other: object):
return self.__getattr__("__sub__")
def __mul__(self, other: object):
return self.__getattr__("__mul__")
def __matmul__(self, other: object):
return self.__getattr__("__matmul__")
def __truediv__(self, other: object):
return self.__getattr__("__truediv__")
def __floordiv__(self, other: object):
return self.__getattr__("__floordiv__")
def __mod__(self, other: object):
return self.__getattr__("__mod__")
def __divmod__(self, other: object):
return self.__getattr__("__divmod__")
def __pow__(self, other: object, modulo: object = ...):
return self.__getattr__("__pow__")
def __lshift__(self, other: object):
return self.__getattr__("__lshift__")
def __rshift__(self, other: object):
return self.__getattr__("__rshift__")
def __and__(self, other: object):
return self.__getattr__("__and__")
def __xor__(self, other: object):
return self.__getattr__("__xor__")
def __or__(self, other: object):
return self.__getattr__("__or__")
# r* and i* methods have lower priority than
# the methods for left operand so they are skipped
def __neg__(self):
return self.__getattr__("__neg__")
def __pos__(self):
return self.__getattr__("__pos__")
def __abs__(self):
return self.__getattr__("__abs__")
def __invert__(self):
return self.__getattr__("__invert__")
# __complex__, __int__ and __float__ have a default implementation
# based on __index__, so they are skipped.
def __index__(self):
return self.__getattr__("__index__")
def __round__(self, ndigits: object = ...):
return self.__getattr__("__round__")
def __trunc__(self):
return self.__getattr__("__trunc__")
def __floor__(self):
return self.__getattr__("__floor__")
def __ceil__(self):
return self.__getattr__("__ceil__")
# [Context managers]
def __enter__(self):
return self.__getattr__("__enter__")
def __exit__(self, *args: object, **kwargs: object):
return self.__getattr__("__exit__")
class PlaceholderModule(_PlaceholderBase):
"""
A placeholder object to use when a module does not exist.
This enables more informative errors when trying to access attributes
of a module that does not exist.
"""
def __init__(self, name: str) -> None:
super().__init__()
# Apply name mangling to avoid conflicting with module attributes
self.__name = name
def placeholder_attr(self, attr_path: str):
return _PlaceholderModuleAttr(self, attr_path)
def __getattr__(self, key: str) -> Never:
name = self.__name
try:
importlib.import_module(name)
except ImportError as exc:
for extra, names in get_vllm_optional_dependencies().items():
if name in names:
msg = f"Please install vllm[{extra}] for {extra} support"
raise ImportError(msg) from exc
raise exc
raise AssertionError(
"PlaceholderModule should not be used "
"when the original module can be imported"
)
class _PlaceholderModuleAttr(_PlaceholderBase):
def __init__(self, module: PlaceholderModule, attr_path: str) -> None:
super().__init__()
# Apply name mangling to avoid conflicting with module attributes
self.__module = module
self.__attr_path = attr_path
def placeholder_attr(self, attr_path: str):
return _PlaceholderModuleAttr(self.__module, f"{self.__attr_path}.{attr_path}")
def __getattr__(self, key: str) -> Never:
getattr(self.__module, f"{self.__attr_path}.{key}")
raise AssertionError(
"PlaceholderModule should not be used "
"when the original module can be imported"
)
class LazyLoader(ModuleType):
"""
`LazyLoader` module borrowed from [Tensorflow]
(https://github.com/tensorflow/tensorflow/blob/main/tensorflow/python/util/lazy_loader.py)
with an addition of "module caching".
Lazily import a module, mainly to avoid pulling in large dependencies.
Modules such as `xgrammar` might do additional side effects, so we
only want to use this when it is needed, delaying all eager effects.
"""
def __init__(
self,
local_name: str,
parent_module_globals: dict[str, Any],
name: str,
):
self._local_name = local_name
self._parent_module_globals = parent_module_globals
self._module: ModuleType | None = None
super().__init__(str(name))
def _load(self) -> ModuleType:
# Import the target module and insert it into the parent's namespace
try:
module = importlib.import_module(self.__name__)
self._parent_module_globals[self._local_name] = module
# The additional add to sys.modules
# ensures library is actually loaded.
sys.modules[self._local_name] = module
except ModuleNotFoundError as err:
raise err from None
# Update this object's dict so that if someone keeps a
# reference to the LazyLoader, lookups are efficient
# (__getattr__ is only called on lookups that fail).
self.__dict__.update(module.__dict__)
return module
def __getattr__(self, item: Any) -> Any:
if self._module is None:
self._module = self._load()
return getattr(self._module, item)
def __dir__(self) -> list[str]:
if self._module is None:
self._module = self._load()
return dir(self._module)
# Optional dependency detection utilities
@cache
def _has_module(module_name: str) -> bool:
"""Return True if *module_name* can be found in the current environment.
The result is cached so that subsequent queries for the same module incur
no additional overhead.
"""
return importlib.util.find_spec(module_name) is not None
def has_pplx() -> bool:
"""Whether the optional `pplx_kernels` package is available."""
return _has_module("pplx_kernels")
def has_deep_ep() -> bool:
"""Whether the optional `deep_ep` package is available."""
return _has_module("deep_ep")
def has_deep_gemm() -> bool:
"""Whether the optional `deep_gemm` package is available."""
return _has_module("deep_gemm")
def has_triton_kernels() -> bool:
"""Whether the optional `triton_kernels` package is available."""
return _has_module("triton_kernels")
def has_tilelang() -> bool:
"""Whether the optional `tilelang` package is available."""
return _has_module("tilelang")
def has_arctic_inference() -> bool:
"""Whether the optional `arctic_inference` package is available."""
return _has_module("arctic_inference")

165
utils/jsontree.py Normal file
View File

@@ -0,0 +1,165 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Helper functions to work with nested JSON structures."""
from collections.abc import Callable, Iterable
from functools import reduce
from typing import TYPE_CHECKING, TypeAlias, TypeVar, cast, overload
if TYPE_CHECKING:
import torch
from vllm.multimodal.inputs import BatchedTensorInputs
_T = TypeVar("_T")
_U = TypeVar("_U")
JSONTree: TypeAlias = (
dict[str, "JSONTree[_T]"] | list["JSONTree[_T]"] | tuple["JSONTree[_T]", ...] | _T
)
"""A nested JSON structure where the leaves need not be JSON-serializable."""
_JSONTree: TypeAlias = (
dict[str, "JSONTree[_T]"]
| list["JSONTree[_T]"]
| tuple["JSONTree[_T]", ...]
| dict[str, _T]
| list[_T]
| tuple[_T, ...]
| _T
)
"""
Same as `JSONTree` but with additional `Union` members to satisfy overloads.
"""
def json_iter_leaves(value: JSONTree[_T]) -> Iterable[_T]:
"""Iterate through each leaf in a nested JSON structure."""
if isinstance(value, dict):
for v in value.values():
yield from json_iter_leaves(v)
elif isinstance(value, (list, tuple)):
for v in value:
yield from json_iter_leaves(v)
else:
yield value
@overload
def json_map_leaves(
func: Callable[["torch.Tensor"], "torch.Tensor"],
value: "BatchedTensorInputs",
) -> "BatchedTensorInputs": ...
@overload
def json_map_leaves(
func: Callable[[_T], _U],
value: _T | dict[str, _T],
) -> _U | dict[str, _U]: ...
@overload
def json_map_leaves(
func: Callable[[_T], _U],
value: _T | list[_T],
) -> _U | list[_U]: ...
@overload
def json_map_leaves(
func: Callable[[_T], _U],
value: _T | tuple[_T, ...],
) -> _U | tuple[_U, ...]: ...
@overload
def json_map_leaves(
func: Callable[[_T], _U],
value: JSONTree[_T],
) -> JSONTree[_U]: ...
def json_map_leaves(
func: Callable[[_T], _U],
value: "BatchedTensorInputs" | _JSONTree[_T],
) -> "BatchedTensorInputs" | _JSONTree[_U]:
"""Apply a function to each leaf in a nested JSON structure."""
if isinstance(value, dict):
return {
k: json_map_leaves(func, v) # type: ignore[arg-type]
for k, v in value.items()
}
elif isinstance(value, list):
return [json_map_leaves(func, v) for v in value]
elif isinstance(value, tuple):
return tuple(json_map_leaves(func, v) for v in value)
else:
return func(value)
@overload
def json_reduce_leaves(
func: Callable[[_T, _T], _T],
value: _T | dict[str, _T],
/,
) -> _T: ...
@overload
def json_reduce_leaves(
func: Callable[[_T, _T], _T],
value: _T | list[_T],
/,
) -> _T: ...
@overload
def json_reduce_leaves(
func: Callable[[_T, _T], _T],
value: _T | tuple[_T, ...],
/,
) -> _T: ...
@overload
def json_reduce_leaves(
func: Callable[[_T, _T], _T],
value: JSONTree[_T],
/,
) -> _T: ...
@overload
def json_reduce_leaves(
func: Callable[[_U, _T], _U],
value: JSONTree[_T],
initial: _U,
/,
) -> _U: ...
def json_reduce_leaves(
func: Callable[..., _T | _U],
value: _JSONTree[_T],
initial: _U = cast(_U, ...), # noqa: B008
/,
) -> _T | _U:
"""
Apply a function of two arguments cumulatively to each leaf in a
nested JSON structure, from left to right, so as to reduce the
sequence to a single value.
"""
if initial is ...:
return reduce(func, json_iter_leaves(value)) # type: ignore[arg-type]
return reduce(
func, # type: ignore[arg-type]
json_iter_leaves(value),
initial,
)
def json_count_leaves(value: JSONTree[_T]) -> int:
"""Count the number of leaves in a nested JSON structure."""
return sum(1 for _ in json_iter_leaves(value))

32
utils/math_utils.py Normal file
View File

@@ -0,0 +1,32 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Math utility functions for vLLM."""
def cdiv(a: int, b: int) -> int:
"""Ceiling division."""
return -(a // -b)
def next_power_of_2(n: int) -> int:
"""The next power of 2 (inclusive)"""
if n < 1:
return 1
return 1 << (n - 1).bit_length()
def prev_power_of_2(n: int) -> int:
"""The previous power of 2 (inclusive)"""
if n <= 0:
return 0
return 1 << (n.bit_length() - 1)
def round_up(x: int, y: int) -> int:
"""Round up x to the nearest multiple of y."""
return ((x + y - 1) // y) * y
def round_down(x: int, y: int) -> int:
"""Round down x to the nearest multiple of y."""
return (x // y) * y

13
utils/mem_constants.py Normal file
View File

@@ -0,0 +1,13 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
MB_bytes = 1_000_000
"""The number of bytes in one megabyte (MB)."""
MiB_bytes = 1 << 20
"""The number of bytes in one mebibyte (MiB)."""
GB_bytes = 1_000_000_000
"""The number of bytes in one gigabyte (GB)."""
GiB_bytes = 1 << 30
"""The number of bytes in one gibibyte (GiB)."""

232
utils/mem_utils.py Normal file
View File

@@ -0,0 +1,232 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import contextlib
import gc
import time
from collections.abc import Generator
from dataclasses import dataclass, field
from functools import cache
import psutil
import torch
import torch.types
from .mem_constants import GiB_bytes
@cache
def get_max_shared_memory_bytes(gpu: int = 0) -> int:
"""Returns the maximum shared memory per thread block in bytes."""
from vllm import _custom_ops as ops
max_shared_mem = ops.get_max_shared_memory_per_block_device_attribute(gpu)
# value 0 will cause MAX_SEQ_LEN become negative and test_attention.py
# will fail
assert max_shared_mem > 0, "max_shared_mem can not be zero"
return int(max_shared_mem)
def get_cpu_memory() -> int:
"""Returns the total CPU memory of the node in bytes."""
return psutil.virtual_memory().total
class DeviceMemoryProfiler:
def __init__(self, device: torch.types.Device | None = None):
self.device = device
def current_memory_usage(self) -> float:
# Return the memory usage in bytes.
from vllm.platforms import current_platform
gc.collect()
return current_platform.get_current_memory_usage(self.device)
def __enter__(self):
self.initial_memory = self.current_memory_usage()
# This allows us to call methods of the context manager if needed
return self
def __exit__(self, exc_type, exc_val, exc_tb):
self.final_memory = self.current_memory_usage()
self.consumed_memory = self.final_memory - self.initial_memory
# Force garbage collection
gc.collect()
@dataclass
class MemorySnapshot:
"""Memory snapshot."""
torch_peak: int = 0
free_memory: int = 0
total_memory: int = 0
cuda_memory: int = 0
torch_memory: int = 0
non_torch_memory: int = 0
timestamp: float = 0.0
auto_measure: bool = True
def __post_init__(self):
if self.auto_measure:
self.measure()
def measure(self):
from vllm.platforms import current_platform
# we measure the torch peak memory usage via allocated_bytes,
# rather than `torch.cuda.memory_reserved()` .
# After `torch.cuda.reset_peak_memory_stats()`,
# `torch.cuda.memory_reserved()` will keep growing, and only shrink
# when we call `torch.cuda.empty_cache()` or OOM happens.
self.torch_peak = torch.cuda.memory_stats().get("allocated_bytes.all.peak", 0)
self.free_memory, self.total_memory = torch.cuda.mem_get_info()
shared_sysmem_device_mem_sms = ((8, 7), (11, 0), (12, 1)) # Orin, Thor, Spark
if (
current_platform.is_cuda()
and current_platform.get_device_capability() in shared_sysmem_device_mem_sms
):
# On UMA (Orin, Thor and Spark) platform,
# where both CPU and GPU rely on system memory,
# the cudaMemGetInfo function shows the amount of free system memory
# rather than whats actually available.
# In the case,
# torch.cuda.mem_get_info() only reports "free" memory,
# which can be lower than what is actually
# available due to not including cache memory.
# Theres also a comprehensive reference page
# that explains how you can compute the proper value yourself.
# https://docs.nvidia.com/cuda/cuda-for-tegra-appnote/#estimating-total-allocatable-device-memory-on-an-integrated-gpu-device
self.free_memory = psutil.virtual_memory().available
self.cuda_memory = self.total_memory - self.free_memory
# torch.cuda.memory_reserved() is how many bytes
# PyTorch gets from cuda (by calling cudaMalloc, etc.)
# this is used to measure the non-torch memory usage
self.torch_memory = torch.cuda.memory_reserved()
self.non_torch_memory = self.cuda_memory - self.torch_memory
self.timestamp = time.time()
def __sub__(self, other: "MemorySnapshot") -> "MemorySnapshot":
return MemorySnapshot(
torch_peak=self.torch_peak - other.torch_peak,
free_memory=self.free_memory - other.free_memory,
total_memory=self.total_memory - other.total_memory,
cuda_memory=self.cuda_memory - other.cuda_memory,
torch_memory=self.torch_memory - other.torch_memory,
non_torch_memory=self.non_torch_memory - other.non_torch_memory,
timestamp=self.timestamp - other.timestamp,
auto_measure=False,
)
@dataclass
class MemoryProfilingResult:
"""Memory profiling result. All numbers are in bytes."""
non_kv_cache_memory: int = 0
torch_peak_increase: int = 0
non_torch_increase: int = 0
weights_memory: float = 0
before_create: MemorySnapshot = field(default_factory=MemorySnapshot)
before_profile: MemorySnapshot = field(default_factory=MemorySnapshot)
after_profile: MemorySnapshot = field(default_factory=MemorySnapshot)
profile_time: float = 0.0
def __repr__(self) -> str:
return (
f"Memory profiling takes {self.profile_time:.2f} seconds. "
f"Total non KV cache memory: "
f"{(self.non_kv_cache_memory / GiB_bytes):.2f}GiB; "
f"torch peak memory increase: "
f"{(self.torch_peak_increase / GiB_bytes):.2f}GiB; "
f"non-torch forward increase memory: "
f"{(self.non_torch_increase / GiB_bytes):.2f}GiB; "
f"weights memory: {(self.weights_memory / GiB_bytes):.2f}GiB."
)
@contextlib.contextmanager
def memory_profiling(
baseline_snapshot: MemorySnapshot, weights_memory: int
) -> Generator[MemoryProfilingResult, None, None]:
"""Memory profiling context manager.
baseline_snapshot: the memory snapshot before the current vLLM instance.
weights_memory: memory used by PyTorch when loading the model weights.
Note that, before loading the model weights, we also initialize the device
and distributed environment, which may consume some memory. This part is not
included in the weights_memory because PyTorch does not control it.
The memory in one GPU can be classified into 3 categories:
1. memory used by anything other than the current vLLM instance.
2. memory used by torch in the current vLLM instance.
3. memory used in the current vLLM instance, but not by torch.
A quantitive example:
Before creating the current vLLM instance:
category 1: 1 GiB
category 2: 0 GiB
category 3: 0 GiB
After creating the current vLLM instance and loading the model,
(i.e. before profiling):
category 1: 1 GiB
category 2: 2 GiB (model weights take 2 GiB)
category 3: 0.5 GiB (memory used by NCCL)
During profiling (peak):
category 1: 1 GiB
category 2: 4 GiB (peak activation tensors take 2 GiB)
category 3: 1 GiB (memory used by NCCL + buffers for some attention backends)
After profiling:
category 1: 1 GiB
category 2: 3 GiB (after garbage-collecting activation tensors)
category 3: 1 GiB (memory used by NCCL + buffers for some attention backends)
In this case, non-kv cache takes 5 GiB in total, including:
a. 2 GiB used by the model weights (category 2)
b. 2 GiB reserved for the peak activation tensors (category 2)
c. 1 GiB used by non-torch components (category 3)
The memory used for loading weights (a.) is directly given from the argument `weights_memory`.
The increase of `torch.cuda.memory_stats()["allocated_bytes.all.peak"]` during profiling gives (b.).
The increase of `non_torch_memory` from creating the current vLLM instance until after profiling to get (c.).
""" # noqa
gc.collect()
torch.cuda.empty_cache()
torch.cuda.reset_peak_memory_stats()
result = MemoryProfilingResult()
result.before_create = baseline_snapshot
# the part of memory used for holding the model weights
result.weights_memory = weights_memory
result.before_profile.measure()
yield result
gc.collect()
torch.cuda.empty_cache()
result.after_profile.measure()
diff_profile = result.after_profile - result.before_profile
diff_from_create = result.after_profile - result.before_create
result.torch_peak_increase = diff_profile.torch_peak
result.non_torch_increase = diff_from_create.non_torch_memory
result.profile_time = diff_profile.timestamp
non_torch_memory = result.non_torch_increase
peak_activation_memory = result.torch_peak_increase
result.non_kv_cache_memory = (
non_torch_memory + peak_activation_memory + result.weights_memory
) # noqa

64
utils/nccl.py Normal file
View File

@@ -0,0 +1,64 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from __future__ import annotations
import importlib
import os
import torch
import vllm.envs as envs
from vllm.logger import init_logger
logger = init_logger(__name__)
def find_nccl_library() -> str:
"""Return NCCL/RCCL shared library name to load.
Uses `VLLM_NCCL_SO_PATH` if set; otherwise chooses by torch backend.
"""
so_file = envs.VLLM_NCCL_SO_PATH
if so_file:
logger.info(
"Found nccl from environment variable VLLM_NCCL_SO_PATH=%s", so_file
)
else:
if torch.version.cuda is not None:
so_file = "libnccl.so.2"
elif torch.version.hip is not None:
so_file = "librccl.so.1"
else:
raise ValueError("NCCL only supports CUDA and ROCm backends.")
logger.debug_once("Found nccl from library %s", so_file)
return so_file
def find_nccl_include_paths() -> list[str] | None:
"""Return possible include paths containing `nccl.h`.
Considers `VLLM_NCCL_INCLUDE_PATH` and the `nvidia-nccl-cuXX` package.
"""
paths: list[str] = []
inc = envs.VLLM_NCCL_INCLUDE_PATH
if inc and os.path.isdir(inc):
paths.append(inc)
try:
spec = importlib.util.find_spec("nvidia.nccl")
if spec and getattr(spec, "submodule_search_locations", None):
for loc in spec.submodule_search_locations:
inc_dir = os.path.join(loc, "include")
if os.path.exists(os.path.join(inc_dir, "nccl.h")):
paths.append(inc_dir)
except Exception as e:
logger.debug("Failed to find nccl include path from nvidia.nccl package: %s", e)
seen: set[str] = set()
out: list[str] = []
for p in paths:
if p and p not in seen:
out.append(p)
seen.add(p)
return out or None

331
utils/network_utils.py Normal file
View File

@@ -0,0 +1,331 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import contextlib
import ipaddress
import os
import socket
import sys
import warnings
from collections.abc import (
Iterator,
Sequence,
)
from typing import Any
from urllib.parse import urlparse
from uuid import uuid4
import psutil
import zmq
import zmq.asyncio
import vllm.envs as envs
from vllm.logger import init_logger
logger = init_logger(__name__)
def close_sockets(sockets: Sequence[zmq.Socket | zmq.asyncio.Socket]):
for sock in sockets:
if sock is not None:
sock.close(linger=0)
def get_ip() -> str:
host_ip = envs.VLLM_HOST_IP
if "HOST_IP" in os.environ and "VLLM_HOST_IP" not in os.environ:
logger.warning(
"The environment variable HOST_IP is deprecated and ignored, as"
" it is often used by Docker and other software to"
" interact with the container's network stack. Please "
"use VLLM_HOST_IP instead to set the IP address for vLLM processes"
" to communicate with each other."
)
if host_ip:
return host_ip
# IP is not set, try to get it from the network interface
# try ipv4
try:
with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as s:
s.connect(("8.8.8.8", 80)) # Doesn't need to be reachable
return s.getsockname()[0]
except Exception:
pass
# try ipv6
try:
with socket.socket(socket.AF_INET6, socket.SOCK_DGRAM) as s:
# Google's public DNS server, see
# https://developers.google.com/speed/public-dns/docs/using#addresses
s.connect(("2001:4860:4860::8888", 80)) # Doesn't need to be reachable
return s.getsockname()[0]
except Exception:
pass
warnings.warn(
"Failed to get the IP address, using 0.0.0.0 by default."
"The value can be set by the environment variable"
" VLLM_HOST_IP or HOST_IP.",
stacklevel=2,
)
return "0.0.0.0"
def test_loopback_bind(address, family):
try:
s = socket.socket(family, socket.SOCK_DGRAM)
s.bind((address, 0)) # Port 0 = auto assign
s.close()
return True
except OSError:
return False
def get_loopback_ip() -> str:
loopback_ip = envs.VLLM_LOOPBACK_IP
if loopback_ip:
return loopback_ip
# VLLM_LOOPBACK_IP is not set, try to get it based on network interface
if test_loopback_bind("127.0.0.1", socket.AF_INET):
return "127.0.0.1"
elif test_loopback_bind("::1", socket.AF_INET6):
return "::1"
else:
raise RuntimeError(
"Neither 127.0.0.1 nor ::1 are bound to a local interface. "
"Set the VLLM_LOOPBACK_IP environment variable explicitly."
)
def is_valid_ipv6_address(address: str) -> bool:
try:
ipaddress.IPv6Address(address)
return True
except ValueError:
return False
def split_host_port(host_port: str) -> tuple[str, int]:
# ipv6
if host_port.startswith("["):
host, port = host_port.rsplit("]", 1)
host = host[1:]
port = port.split(":")[1]
return host, int(port)
else:
host, port = host_port.split(":")
return host, int(port)
def join_host_port(host: str, port: int) -> str:
if is_valid_ipv6_address(host):
return f"[{host}]:{port}"
else:
return f"{host}:{port}"
def get_distributed_init_method(ip: str, port: int) -> str:
return get_tcp_uri(ip, port)
def get_tcp_uri(ip: str, port: int) -> str:
if is_valid_ipv6_address(ip):
return f"tcp://[{ip}]:{port}"
else:
return f"tcp://{ip}:{port}"
def get_open_zmq_ipc_path() -> str:
base_rpc_path = envs.VLLM_RPC_BASE_PATH
return f"ipc://{base_rpc_path}/{uuid4()}"
def get_open_zmq_inproc_path() -> str:
return f"inproc://{uuid4()}"
def get_open_port() -> int:
"""
Get an open port for the vLLM process to listen on.
An edge case to handle, is when we run data parallel,
we need to avoid ports that are potentially used by
the data parallel master process.
Right now we reserve 10 ports for the data parallel master
process. Currently it uses 2 ports.
"""
if "VLLM_DP_MASTER_PORT" in os.environ:
dp_master_port = envs.VLLM_DP_MASTER_PORT
reserved_port_range = range(dp_master_port, dp_master_port + 10)
while True:
candidate_port = _get_open_port()
if candidate_port not in reserved_port_range:
return candidate_port
return _get_open_port()
def get_open_ports_list(count: int = 5) -> list[int]:
"""Get a list of open ports."""
ports = set[int]()
while len(ports) < count:
ports.add(get_open_port())
return list(ports)
def _get_open_port() -> int:
port = envs.VLLM_PORT
if port is not None:
while True:
try:
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
s.bind(("", port))
return port
except OSError:
port += 1 # Increment port number if already in use
logger.info("Port %d is already in use, trying port %d", port - 1, port)
# try ipv4
try:
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
s.bind(("", 0))
return s.getsockname()[1]
except OSError:
# try ipv6
with socket.socket(socket.AF_INET6, socket.SOCK_STREAM) as s:
s.bind(("", 0))
return s.getsockname()[1]
def find_process_using_port(port: int) -> psutil.Process | None:
# TODO: We can not check for running processes with network
# port on macOS. Therefore, we can not have a full graceful shutdown
# of vLLM. For now, let's not look for processes in this case.
# Ref: https://www.florianreinhard.de/accessdenied-in-psutil/
if sys.platform.startswith("darwin"):
return None
our_pid = os.getpid()
for conn in psutil.net_connections():
if conn.laddr.port == port and (conn.pid is not None and conn.pid != our_pid):
try:
return psutil.Process(conn.pid)
except psutil.NoSuchProcess:
return None
return None
def split_zmq_path(path: str) -> tuple[str, str, str]:
"""Split a zmq path into its parts."""
parsed = urlparse(path)
if not parsed.scheme:
raise ValueError(f"Invalid zmq path: {path}")
scheme = parsed.scheme
host = parsed.hostname or ""
port = str(parsed.port or "")
if scheme == "tcp" and not all((host, port)):
# The host and port fields are required for tcp
raise ValueError(f"Invalid zmq path: {path}")
if scheme != "tcp" and port:
# port only makes sense with tcp
raise ValueError(f"Invalid zmq path: {path}")
return scheme, host, port
def make_zmq_path(scheme: str, host: str, port: int | None = None) -> str:
"""Make a ZMQ path from its parts.
Args:
scheme: The ZMQ transport scheme (e.g. tcp, ipc, inproc).
host: The host - can be an IPv4 address, IPv6 address, or hostname.
port: Optional port number, only used for TCP sockets.
Returns:
A properly formatted ZMQ path string.
"""
if port is None:
return f"{scheme}://{host}"
if is_valid_ipv6_address(host):
return f"{scheme}://[{host}]:{port}"
return f"{scheme}://{host}:{port}"
# Adapted from: https://github.com/sgl-project/sglang/blob/v0.4.1/python/sglang/srt/utils.py#L783 # noqa: E501
def make_zmq_socket(
ctx: zmq.asyncio.Context | zmq.Context, # type: ignore[name-defined]
path: str,
socket_type: Any,
bind: bool | None = None,
identity: bytes | None = None,
linger: int | None = None,
) -> zmq.Socket | zmq.asyncio.Socket: # type: ignore[name-defined]
"""Make a ZMQ socket with the proper bind/connect semantics."""
mem = psutil.virtual_memory()
socket = ctx.socket(socket_type)
# Calculate buffer size based on system memory
total_mem = mem.total / 1024**3
available_mem = mem.available / 1024**3
# For systems with substantial memory (>32GB total, >16GB available):
# - Set a large 0.5GB buffer to improve throughput
# For systems with less memory:
# - Use system default (-1) to avoid excessive memory consumption
buf_size = int(0.5 * 1024**3) if total_mem > 32 and available_mem > 16 else -1
if bind is None:
bind = socket_type not in (zmq.PUSH, zmq.SUB, zmq.XSUB)
if socket_type in (zmq.PULL, zmq.DEALER, zmq.ROUTER):
socket.setsockopt(zmq.RCVHWM, 0)
socket.setsockopt(zmq.RCVBUF, buf_size)
if socket_type in (zmq.PUSH, zmq.DEALER, zmq.ROUTER):
socket.setsockopt(zmq.SNDHWM, 0)
socket.setsockopt(zmq.SNDBUF, buf_size)
if identity is not None:
socket.setsockopt(zmq.IDENTITY, identity)
if linger is not None:
socket.setsockopt(zmq.LINGER, linger)
if socket_type == zmq.XPUB:
socket.setsockopt(zmq.XPUB_VERBOSE, True)
# Determine if the path is a TCP socket with an IPv6 address.
# Enable IPv6 on the zmq socket if so.
scheme, host, _ = split_zmq_path(path)
if scheme == "tcp" and is_valid_ipv6_address(host):
socket.setsockopt(zmq.IPV6, 1)
if bind:
socket.bind(path)
else:
socket.connect(path)
return socket
@contextlib.contextmanager
def zmq_socket_ctx(
path: str,
socket_type: Any,
bind: bool | None = None,
linger: int = 0,
identity: bytes | None = None,
) -> Iterator[zmq.Socket]:
"""Context manager for a ZMQ socket"""
ctx = zmq.Context() # type: ignore[attr-defined]
try:
yield make_zmq_socket(ctx, path, socket_type, bind=bind, identity=identity)
except KeyboardInterrupt:
logger.debug("Got Keyboard Interrupt.")
finally:
ctx.destroy(linger=linger)

59
utils/platform_utils.py Normal file
View File

@@ -0,0 +1,59 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import multiprocessing
from collections.abc import Sequence
from concurrent.futures.process import ProcessPoolExecutor
from functools import cache
from typing import Any
import torch
def cuda_is_initialized() -> bool:
"""Check if CUDA is initialized."""
if not torch.cuda._is_compiled():
return False
return torch.cuda.is_initialized()
def xpu_is_initialized() -> bool:
"""Check if XPU is initialized."""
if not torch.xpu._is_compiled():
return False
return torch.xpu.is_initialized()
def get_cu_count(device_id: int = 0) -> int:
"""Returns the total number of compute units (CU) on single GPU."""
return torch.cuda.get_device_properties(device_id).multi_processor_count
def cuda_get_device_properties(
device, names: Sequence[str], init_cuda=False
) -> tuple[Any, ...]:
"""Get specified CUDA device property values without initializing CUDA in
the current process."""
if init_cuda or cuda_is_initialized():
props = torch.cuda.get_device_properties(device)
return tuple(getattr(props, name) for name in names)
# Run in subprocess to avoid initializing CUDA as a side effect.
mp_ctx = multiprocessing.get_context("fork")
with ProcessPoolExecutor(max_workers=1, mp_context=mp_ctx) as executor:
return executor.submit(cuda_get_device_properties, device, names, True).result()
@cache
def is_pin_memory_available() -> bool:
from vllm.platforms import current_platform
return current_platform.is_pin_memory_available()
@cache
def is_uva_available() -> bool:
"""Check if Unified Virtual Addressing (UVA) is available."""
# UVA requires pinned memory.
# TODO: Add more requirements for UVA if needed.
return is_pin_memory_available()

56
utils/profiling.py Normal file
View File

@@ -0,0 +1,56 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from __future__ import annotations
import contextlib
from collections.abc import Callable
from functools import wraps
from typing import Any
@contextlib.contextmanager
def cprofile_context(save_file: str | None = None):
"""Run a cprofile
Args:
save_file: path to save the profile result. "1" or
None will result in printing to stdout.
"""
import cProfile
prof = cProfile.Profile()
prof.enable()
try:
yield
finally:
prof.disable()
if save_file and save_file != "1":
prof.dump_stats(save_file)
else:
prof.print_stats(sort="cumtime")
def cprofile(save_file: str | None = None, enabled: bool = True):
"""Decorator to profile a Python method using cProfile.
Args:
save_file: Path to save the profile result.
If "1", None, or "", results will be printed to stdout.
enabled: Set to false to turn this into a no-op
"""
def decorator(func: Callable):
@wraps(func)
def wrapper(*args: Any, **kwargs: Any):
if not enabled:
# If profiling is disabled, just call the function directly.
return func(*args, **kwargs)
with cprofile_context(save_file):
return func(*args, **kwargs)
return wrapper
return decorator

49
utils/registry.py Normal file
View File

@@ -0,0 +1,49 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Any
class ExtensionManager:
"""
A registry for managing pluggable extension classes.
This class provides a simple mechanism to register and instantiate
extension classes by name. It is commonly used to implement plugin
systems where different implementations can be swapped at runtime.
Examples:
Basic usage with a registry instance:
>>> FOO_REGISTRY = ExtensionManager()
>>> @FOO_REGISTRY.register("my_foo_impl")
... class MyFooImpl(Foo):
... def __init__(self, value):
... self.value = value
>>> foo_impl = FOO_REGISTRY.load("my_foo_impl", value=123)
"""
def __init__(self) -> None:
"""
Initialize an empty extension registry.
"""
self.name2class: dict[str, type] = {}
def register(self, name: str):
"""
Decorator to register a class with the given name.
"""
def wrap(cls_to_register):
self.name2class[name] = cls_to_register
return cls_to_register
return wrap
def load(self, cls_name: str, *args, **kwargs) -> Any:
"""
Instantiate and return a registered extension class by name.
"""
cls = self.name2class.get(cls_name)
assert cls is not None, f"Extension class {cls_name} not found"
return cls(*args, **kwargs)

169
utils/serial_utils.py Normal file
View File

@@ -0,0 +1,169 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import base64
import sys
from dataclasses import dataclass
from typing import Literal
import numpy as np
import torch
from typing_extensions import assert_never
from vllm import PoolingRequestOutput
sys_byteorder = sys.byteorder
EMBED_DTYPE_TO_TORCH_DTYPE = {
"float32": torch.float32,
"float16": torch.float16,
"bfloat16": torch.bfloat16,
# I'm not sure if other platforms' CPUs support the fp8 data format.
# EMBED_DTYPE only uses the fp8 data representation,
# does not use fp8 computation, and only occurs on the CPU.
# Apologize for any possible break.
"fp8_e4m3": torch.float8_e4m3fn,
"fp8_e5m2": torch.float8_e5m2,
}
EMBED_DTYPE_TO_TORCH_DTYPE_VIEW = {
"float32": torch.float32,
"float16": torch.float16,
# numpy does not support bfloat16 and fp8
"bfloat16": torch.float16,
"fp8_e4m3": torch.uint8,
"fp8_e5m2": torch.uint8,
}
EMBED_DTYPE_TO_NUMPY_DTYPE_VIEW = {
"float32": np.float32,
"float16": np.float16,
# numpy does not support bfloat16 and fp8
"bfloat16": np.float16,
"fp8_e4m3": np.uint8,
"fp8_e5m2": np.uint8,
}
ENDIANNESS = ["native", "big", "little"]
EmbedDType = Literal["float32", "float16", "bfloat16", "fp8_e4m3", "fp8_e5m2"]
Endianness = Literal["native", "big", "little"]
EncodingFormat = Literal["float", "base64", "bytes"]
def tensor2binary(
tensor: torch.Tensor, embed_dtype: EmbedDType, endianness: Endianness
) -> bytes:
assert isinstance(tensor, torch.Tensor)
assert embed_dtype in EMBED_DTYPE_TO_TORCH_DTYPE
assert endianness in ENDIANNESS
torch_dtype = EMBED_DTYPE_TO_TORCH_DTYPE[embed_dtype]
torch_view_dtype = EMBED_DTYPE_TO_TORCH_DTYPE_VIEW[embed_dtype]
np_array = (
tensor.to(torch_dtype).flatten().contiguous().view(torch_view_dtype).numpy()
)
if endianness != "native" and endianness != sys_byteorder:
np_array = np_array.byteswap()
return np_array.tobytes()
def binary2tensor(
binary: bytes,
shape: tuple[int, ...],
embed_dtype: EmbedDType,
endianness: Endianness,
) -> torch.Tensor:
assert embed_dtype in EMBED_DTYPE_TO_TORCH_DTYPE
assert embed_dtype in EMBED_DTYPE_TO_NUMPY_DTYPE_VIEW
assert endianness in ENDIANNESS
torch_dtype = EMBED_DTYPE_TO_TORCH_DTYPE[embed_dtype]
np_dtype = EMBED_DTYPE_TO_NUMPY_DTYPE_VIEW[embed_dtype]
np_array = np.frombuffer(binary, dtype=np_dtype).reshape(shape)
if endianness != "native" and endianness != sys_byteorder:
np_array = np_array.byteswap()
return torch.from_numpy(np_array).view(torch_dtype)
def encode_pooling_output(
output: PoolingRequestOutput,
encoding_format: EncodingFormat,
embed_dtype: EmbedDType,
endianness: Endianness,
) -> list[float] | str | bytes:
if encoding_format == "float":
return output.outputs.data.tolist()
elif encoding_format == "base64":
embedding_bytes = tensor2binary(output.outputs.data, embed_dtype, endianness)
return base64.b64encode(embedding_bytes).decode("utf-8")
elif encoding_format == "bytes":
return tensor2binary(output.outputs.data, embed_dtype, endianness)
assert_never(encoding_format)
@dataclass
class MetadataItem:
index: int
embed_dtype: EmbedDType
endianness: Endianness
start: int
end: int
shape: tuple[int, ...]
def encode_pooling_bytes(
pooling_outputs: list[PoolingRequestOutput],
embed_dtype: EmbedDType,
endianness: Endianness,
):
num_prompt_tokens = 0
items: list[dict[str, MetadataItem]] = []
body = []
offset = 0
for idx, output in enumerate(pooling_outputs):
binary = tensor2binary(
tensor=output.outputs.data,
embed_dtype=embed_dtype,
endianness=endianness,
)
size = len(binary)
item = {
"index": idx,
"embed_dtype": embed_dtype,
"endianness": endianness,
"start": offset,
"end": offset + size,
"shape": output.outputs.data.shape,
}
body.append(binary)
items.append(item)
prompt_token_ids = output.prompt_token_ids
num_prompt_tokens += len(prompt_token_ids)
offset += size
usage = {
"prompt_tokens": num_prompt_tokens,
"total_tokens": num_prompt_tokens,
}
return body, items, usage
def decode_pooling_output(items: list[MetadataItem], body: bytes) -> list[torch.Tensor]:
items.sort(key=lambda x: x.index)
tensor_list: list[torch.Tensor] = []
for item in items:
binary = body[item.start : item.end]
tensor = binary2tensor(binary, item.shape, item.embed_dtype, item.endianness)
tensor_list.append(tensor)
return tensor_list

229
utils/system_utils.py Normal file
View File

@@ -0,0 +1,229 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from __future__ import annotations
import contextlib
import multiprocessing
import os
import signal
import sys
from collections.abc import Callable, Iterator
from pathlib import Path
from typing import TextIO
import psutil
import vllm.envs as envs
from vllm.logger import init_logger
from vllm.ray.lazy_utils import is_in_ray_actor
from .platform_utils import cuda_is_initialized, xpu_is_initialized
logger = init_logger(__name__)
CYAN = "\033[1;36m"
RESET = "\033[0;0m"
# Environment variable utilities
def update_environment_variables(envs_dict: dict[str, str]):
"""Update multiple environment variables with logging."""
for k, v in envs_dict.items():
if k in os.environ and os.environ[k] != v:
logger.warning(
"Overwriting environment variable %s from '%s' to '%s'",
k,
os.environ[k],
v,
)
os.environ[k] = v
@contextlib.contextmanager
def set_env_var(key: str, value: str) -> Iterator[None]:
"""Temporarily set an environment variable."""
old = os.environ.get(key)
os.environ[key] = value
try:
yield
finally:
if old is None:
os.environ.pop(key, None)
else:
os.environ[key] = old
# File path utilities
def unique_filepath(fn: Callable[[int], Path]) -> Path:
"""Generate a unique file path by trying incrementing integers.
Note: This function has a TOCTOU race condition.
Caller should use atomic operations (e.g., open with 'x' mode)
when creating the file to ensure thread safety.
"""
i = 0
while True:
p = fn(i)
if not p.exists():
return p
i += 1
# Process management utilities
def _maybe_force_spawn():
"""Check if we need to force the use of the `spawn` multiprocessing start
method.
"""
if os.environ.get("VLLM_WORKER_MULTIPROC_METHOD") == "spawn":
return
reasons = []
if is_in_ray_actor():
# even if we choose to spawn, we need to pass the ray address
# to the subprocess so that it knows how to connect to the ray cluster.
# env vars are inherited by subprocesses, even if we use spawn.
import ray
os.environ["RAY_ADDRESS"] = ray.get_runtime_context().gcs_address
reasons.append("In a Ray actor and can only be spawned")
if cuda_is_initialized():
reasons.append("CUDA is initialized")
elif xpu_is_initialized():
reasons.append("XPU is initialized")
if reasons:
logger.warning(
"We must use the `spawn` multiprocessing start method. "
"Overriding VLLM_WORKER_MULTIPROC_METHOD to 'spawn'. "
"See https://docs.vllm.ai/en/latest/usage/"
"troubleshooting.html#python-multiprocessing "
"for more information. Reasons: %s",
"; ".join(reasons),
)
os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
def get_mp_context():
"""Get a multiprocessing context with a particular method (spawn or fork).
By default we follow the value of the VLLM_WORKER_MULTIPROC_METHOD to
determine the multiprocessing method (default is fork). However, under
certain conditions, we may enforce spawn and override the value of
VLLM_WORKER_MULTIPROC_METHOD.
"""
_maybe_force_spawn()
mp_method = envs.VLLM_WORKER_MULTIPROC_METHOD
return multiprocessing.get_context(mp_method)
def set_process_title(
name: str,
suffix: str = "",
prefix: str = envs.VLLM_PROCESS_NAME_PREFIX,
) -> None:
"""Set the current process title with optional suffix."""
try:
import setproctitle
except ImportError:
return
if suffix:
name = f"{name}_{suffix}"
setproctitle.setproctitle(f"{prefix}::{name}")
def _add_prefix(file: TextIO, worker_name: str, pid: int) -> None:
"""Add colored prefix to file output for log decoration."""
prefix = f"{CYAN}({worker_name} pid={pid}){RESET} "
file_write = file.write
def write_with_prefix(s: str):
if not s:
return
if file.start_new_line: # type: ignore[attr-defined]
file_write(prefix)
idx = 0
while (next_idx := s.find("\n", idx)) != -1:
next_idx += 1
file_write(s[idx:next_idx])
if next_idx == len(s):
file.start_new_line = True # type: ignore[attr-defined]
return
file_write(prefix)
idx = next_idx
file_write(s[idx:])
file.start_new_line = False # type: ignore[attr-defined]
file.start_new_line = True # type: ignore[attr-defined]
file.write = write_with_prefix # type: ignore[method-assign]
def decorate_logs(process_name: str | None = None) -> None:
"""Decorate stdout/stderr with process name and PID prefix."""
if process_name is None:
process_name = get_mp_context().current_process().name
pid = os.getpid()
_add_prefix(sys.stdout, process_name, pid)
_add_prefix(sys.stderr, process_name, pid)
def kill_process_tree(pid: int):
"""
Kills all descendant processes of the given pid by sending SIGKILL.
Args:
pid (int): Process ID of the parent process
"""
try:
parent = psutil.Process(pid)
except psutil.NoSuchProcess:
return
# Get all children recursively
children = parent.children(recursive=True)
# Send SIGKILL to all children first
for child in children:
with contextlib.suppress(ProcessLookupError):
os.kill(child.pid, signal.SIGKILL)
# Finally kill the parent
with contextlib.suppress(ProcessLookupError):
os.kill(pid, signal.SIGKILL)
# Resource utilities
# Adapted from: https://github.com/sgl-project/sglang/blob/v0.4.1/python/sglang/srt/utils.py#L630
def set_ulimit(target_soft_limit: int = 65535):
if sys.platform.startswith("win"):
logger.info("Windows detected, skipping ulimit adjustment.")
return
import resource
resource_type = resource.RLIMIT_NOFILE
current_soft, current_hard = resource.getrlimit(resource_type)
if current_soft < target_soft_limit:
try:
resource.setrlimit(resource_type, (target_soft_limit, current_hard))
except ValueError as e:
logger.warning(
"Found ulimit of %s and failed to automatically increase "
"with error %s. This can cause fd limit errors like "
"`OSError: [Errno 24] Too many open files`. Consider "
"increasing with ulimit -n",
current_soft,
e,
)

255
utils/tensor_schema.py Normal file
View File

@@ -0,0 +1,255 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from types import UnionType
from typing import Annotated, Any, Union, get_args, get_origin, get_type_hints
import torch
from vllm.logger import init_logger
logger = init_logger(__name__)
class TensorShape:
def __init__(
self,
*dims: int | str,
dynamic_dims: set[str] | None = None,
) -> None:
super().__init__()
self.dims = dims
self.dynamic_dims = dynamic_dims if dynamic_dims else set()
def resolve(self, **bindings: int) -> tuple[int | str, ...]:
resolved = list[int | str]()
for dim in self.dims:
if isinstance(dim, str) and dim in bindings:
resolved.append(bindings[dim])
else:
resolved.append(dim)
return tuple(resolved)
def __str__(self) -> str:
"""Return a string representation of the tensor shape."""
dim_strs = []
for dim in self.dims:
if isinstance(dim, str):
if dim in self.dynamic_dims:
dim_strs.append(f"{dim}*") # Mark dynamic dimensions with *
else:
dim_strs.append(dim)
else:
dim_strs.append(str(dim))
return f"({', '.join(dim_strs)})"
class TensorSchema:
def __init__(
self,
*,
validate: bool = True,
resolve_bindings: dict[str, int] | None = None,
**kwargs: Any,
) -> None:
super().__init__()
self._resolve_bindings = resolve_bindings if resolve_bindings else {}
for key, value in kwargs.items():
setattr(self, key, value)
if validate:
self.validate()
def __getitem__(self, key: str) -> Any:
return getattr(self, key)
def get(self, key: str, default: Any = None) -> Any:
return getattr(self, key, default)
def _match_shape_with_dynamic(
self,
actual: tuple[int, ...],
reference: tuple[int, ...],
expected_shape: tuple[int | str, ...],
dynamic_dims: set[str],
) -> bool:
if len(actual) != len(reference) or len(actual) > len(expected_shape):
return False
for i, (a, r) in enumerate(zip(actual, reference)):
# When validating list inputs, we match shape suffixes only
# (e.g. "p", 3, "h", "w"), assuming the list length corresponds
# to the leading symbolic dim (e.g. "bn"). This allows comparing
# only the trailing dimensions of each element in the list.
dim = expected_shape[-len(actual) + i]
# Skip this dimension if it's marked dynamic
if dim in dynamic_dims:
continue
if a != r:
return False
return True
def _fmt_indexer(self, idxs: tuple[int, ...]) -> str:
if not idxs:
return ""
return str(list(idxs))
def _validate_field(
self,
value: object,
field_name: str,
expected_shape: tuple[int | str, ...],
dynamic_dims: set[str],
leading_idxs: tuple[int, ...] = (),
) -> tuple[int, ...]:
"""Validate a field and return the actual shape."""
if isinstance(value, (int, float)):
return () # Scalar
if isinstance(value, torch.Tensor):
return value.shape
if not isinstance(value, (list, tuple)):
raise TypeError(
f"{field_name}{self._fmt_indexer(leading_idxs)} is not "
f"one of the expected types: int, float, Tensor, list, tuple. "
f"Got: {type(value)}"
)
if len(value) == 0:
raise ValueError(
f"{field_name}{self._fmt_indexer(leading_idxs)} is an empty sequence"
)
# Ensure all tensors in the list have the same
# shape, besides dynamic dimensions
for i, v in enumerate(value):
shape = self._validate_field(
v,
field_name,
expected_shape[1:],
dynamic_dims,
leading_idxs=leading_idxs + (i,),
)
if i == 0:
first_shape = shape
elif not self._match_shape_with_dynamic(
shape,
first_shape,
expected_shape,
dynamic_dims,
):
raise ValueError(
f"{field_name}{self._fmt_indexer(leading_idxs)} "
f"contains inconsistent shapes: {first_shape} "
f"(index 0) vs {shape} (index {i})"
)
# Treat the list as a stacked tensor:
# shape = (len(list), *tensor.shape)
return (len(value),) + first_shape
def _validate_tensor_shape_expected(
self,
actual_shape: tuple[int, ...],
expected_shape: tuple[int | str, ...],
field_name: str,
shape_env: dict[str, int],
dynamic_dims: set[str],
) -> None:
"""Validate that the actual tensor shape matches the expected shape."""
if len(actual_shape) != len(expected_shape):
raise ValueError(
f"{field_name} has rank {len(actual_shape)} "
f"but expected {len(expected_shape)}. "
f"Expected shape: {expected_shape}, "
f"but got {actual_shape}"
)
for i, dim in enumerate(expected_shape):
if dim in dynamic_dims:
continue
elif isinstance(dim, int):
if actual_shape[i] != dim:
raise ValueError(
f"{field_name} dim[{i}] expected "
f"{dim}, got {actual_shape[i]}. "
f"Expected shape: {expected_shape}, "
f"but got {actual_shape}"
)
elif isinstance(dim, str):
if dim in shape_env:
if actual_shape[i] != shape_env[dim]:
raise ValueError(
f"{field_name} dim[{i}] expected "
f"'{dim}'={shape_env[dim]}, got "
f"{actual_shape[i]}"
)
else:
shape_env[dim] = actual_shape[i]
else:
raise TypeError(
f"{field_name} dim[{i}] has unsupported type: {type(dim)}"
)
def validate(self) -> None:
type_hints = get_type_hints(self.__class__, include_extras=True)
shape_env = dict[str, int]()
for field_name, field_type in type_hints.items():
# Check if field is missing
if not hasattr(self, field_name) or getattr(self, field_name) is None:
# Check if field is marked as optional
actual_type = field_type
if get_origin(field_type) is Annotated:
args = get_args(field_type)
actual_type = args[0]
# Check arg was provided as Union
if get_origin(actual_type) in {Union, UnionType}:
# Union for Union[X, Y] and UnionType for X | Y
args = get_args(actual_type)
# Skip validation when Union contains None
if type(None) in args:
continue
# Otherwise field is required, raise error
raise ValueError(f"Required field '{field_name}' is missing")
# Field exists, proceed with validation
value = getattr(self, field_name)
if get_origin(field_type) is not None:
args = get_args(field_type)
for arg in args:
if isinstance(arg, TensorShape):
expected_shape = arg.resolve(**self._resolve_bindings)
actual_shape = self._validate_field(
value,
field_name,
expected_shape,
arg.dynamic_dims,
)
self._validate_tensor_shape_expected(
actual_shape,
expected_shape,
field_name,
shape_env,
arg.dynamic_dims,
)
def print_shapes(self) -> None:
"""Print TensorShape annotations for debugging."""
logger.debug("Shapes in %s:", self.__class__.__name__)
type_hints = get_type_hints(self.__class__, include_extras=True)
for field_name, field_type in type_hints.items():
if get_origin(field_type) is not None:
args = get_args(field_type)
for arg in args:
if isinstance(arg, TensorShape):
logger.debug(" %s: %s", field_name, str(arg))

658
utils/torch_utils.py Normal file
View File

@@ -0,0 +1,658 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import contextlib
import importlib.metadata
import os
import threading
from collections.abc import Callable, Collection
from functools import lru_cache
from typing import TYPE_CHECKING, Any, TypeVar
import numpy as np
import numpy.typing as npt
import torch
from packaging import version
from packaging.version import Version
from torch.library import Library
import vllm.envs as envs
import ixformer.inference.functions as ixfops
if TYPE_CHECKING:
from vllm.config import ModelConfig
from vllm.sequence import IntermediateTensors
else:
ModelConfig = object
IntermediateTensors = object
STR_DTYPE_TO_TORCH_DTYPE = {
"float32": torch.float32,
"half": torch.half,
"bfloat16": torch.bfloat16,
"float": torch.float,
"fp8": torch.uint8,
"fp8_e4m3": torch.uint8,
"fp8_e5m2": torch.uint8,
"int8": torch.int8,
"fp8_inc": torch.float8_e4m3fn,
"fp8_ds_mla": torch.uint8,
}
TORCH_DTYPE_TO_NUMPY_DTYPE = {
torch.float16: np.float16,
torch.float32: np.float32,
torch.float64: np.float64,
torch.uint8: np.uint8,
torch.int32: np.int32,
torch.int64: np.int64,
}
T = TypeVar("T")
@contextlib.contextmanager
def set_default_torch_dtype(dtype: torch.dtype):
"""Sets the default torch dtype to the given dtype."""
old_dtype = torch.get_default_dtype()
torch.set_default_dtype(dtype)
yield
torch.set_default_dtype(old_dtype)
@contextlib.contextmanager
def set_default_torch_num_threads(num_threads: int):
"""Sets the default number of threads for PyTorch to the given value."""
old_num_threads = torch.get_num_threads()
torch.set_num_threads(num_threads)
yield
torch.set_num_threads(old_num_threads)
@contextlib.contextmanager
def guard_cuda_initialization():
"""Avoid unexpected CUDA initialization."""
from vllm.platforms import current_platform
if not current_platform.is_cuda():
yield
return
had_key = "CUDA_VISIBLE_DEVICES" in os.environ
old_value = os.environ.get("CUDA_VISIBLE_DEVICES")
os.environ["CUDA_VISIBLE_DEVICES"] = ""
try:
yield
except Exception as e:
if "No CUDA GPUs are available" in str(e):
err_msg = "CUDA initialization is blocked."
else:
err_msg = str(e)
raise RuntimeError(err_msg) from e
finally:
if had_key:
os.environ["CUDA_VISIBLE_DEVICES"] = old_value
else:
os.environ.pop("CUDA_VISIBLE_DEVICES")
def get_dtype_size(dtype: torch.dtype) -> int:
"""Get the size of the data type in bytes."""
return torch.tensor([], dtype=dtype).element_size()
# bool = 0, int = 1, float = 2, complex = 3
def _get_precision_level(dtype: torch.dtype) -> int:
# NOTE: Complex dtypes return `is_floating_point=False`
return (dtype != torch.bool) + dtype.is_floating_point + dtype.is_complex * 2
def is_lossless_cast(src_dtype: torch.dtype, tgt_dtype: torch.dtype):
"""
Test whether it is lossless to cast a tensor from
`src_dtype` to `tgt_dtype`.
"""
if src_dtype == tgt_dtype:
return True
src_level = _get_precision_level(src_dtype)
tgt_level = _get_precision_level(tgt_dtype)
if src_level < tgt_level:
return True
if src_level > tgt_level:
return False
# Compare integral types
if not src_dtype.is_floating_point and not src_dtype.is_complex:
src_info = torch.iinfo(src_dtype)
tgt_info = torch.iinfo(tgt_dtype)
return src_info.min >= tgt_info.min and src_info.max <= tgt_info.max
# Compare floating-point types
src_info = torch.finfo(src_dtype)
tgt_info = torch.finfo(tgt_dtype)
return (
src_info.min >= tgt_info.min
and src_info.max <= tgt_info.max
and src_info.resolution >= tgt_info.resolution
)
def common_broadcastable_dtype(dtypes: Collection[torch.dtype]):
"""
Get the common `dtype` where all of the other `dtypes` can be
cast to it without losing any information.
"""
return max(
dtypes,
key=lambda dtype: sum(is_lossless_cast(dt, dtype) for dt in dtypes),
)
def _generate_random_fp8(
tensor: torch.Tensor,
low: float,
high: float,
) -> None:
# NOTE(zhaoyang): Due to NaN and Inf representation for fp8 data type,
# it may occur Inf or NaN if we directly use torch.randint
# to generate random data for fp8 data.
# For example, s.11111.00 in fp8e5m2 format represents Inf.
# | E4M3 | E5M2
# -----|-------------|-------------------
# Inf | N/A | s.11111.00
# NaN | s.1111.111 | s.11111.{01,10,11}
from vllm import _custom_ops as ops
tensor_tmp = torch.empty_like(tensor, dtype=torch.float16)
tensor_tmp.uniform_(low, high)
ops.convert_fp8(tensor, tensor_tmp)
del tensor_tmp
def get_kv_cache_torch_dtype(
cache_dtype: str | torch.dtype | None,
model_dtype: str | torch.dtype | None = None,
) -> torch.dtype:
if isinstance(cache_dtype, str):
if cache_dtype == "auto":
if isinstance(model_dtype, str) and model_dtype in STR_DTYPE_TO_TORCH_DTYPE:
torch_dtype = STR_DTYPE_TO_TORCH_DTYPE[model_dtype]
elif isinstance(model_dtype, torch.dtype):
torch_dtype = model_dtype
else:
raise ValueError(f"Invalid model dtype: {model_dtype}")
elif cache_dtype in STR_DTYPE_TO_TORCH_DTYPE:
torch_dtype = STR_DTYPE_TO_TORCH_DTYPE[cache_dtype]
else:
raise ValueError(f"Invalid kv cache dtype: {cache_dtype}")
elif isinstance(cache_dtype, torch.dtype):
torch_dtype = cache_dtype
else:
raise ValueError(f"Invalid kv cache dtype: {cache_dtype}")
return torch_dtype
def kv_cache_dtype_str_to_dtype(
kv_cache_dtype: str, model_config: ModelConfig
) -> torch.dtype:
if kv_cache_dtype == "auto":
# Model config may not be specified for unit tests, default to float16
return model_config.dtype if model_config else torch.half
return STR_DTYPE_TO_TORCH_DTYPE[kv_cache_dtype]
def create_kv_caches_with_random_flash(
num_blocks: int,
block_size: int,
num_layers: int,
num_heads: int,
head_size: int,
cache_dtype: str | torch.dtype | None,
model_dtype: str | torch.dtype | None = None,
seed: int | None = None,
device: str | None = "cuda",
cache_layout: str | None = "NHD",
) -> tuple[list[torch.Tensor], list[torch.Tensor]]:
from vllm.platforms import current_platform
current_platform.seed_everything(seed)
dtype = get_kv_cache_torch_dtype(cache_dtype, model_dtype)
generic_kv_cache_shape = (num_blocks, 2, block_size, num_heads, head_size)
assert cache_layout in ("NHD", "HND")
stride_order = (0, 1, 2, 3, 4) if cache_layout == "NHD" else (0, 1, 3, 2, 4)
kv_cache_allocation_shape = tuple(generic_kv_cache_shape[i] for i in stride_order)
scale = head_size**-0.5
key_caches: list[torch.Tensor] = []
value_caches: list[torch.Tensor] = []
for _ in range(num_layers):
key_value_cache = torch.empty(
size=kv_cache_allocation_shape, dtype=dtype, device=device
).permute(*stride_order)
if cache_dtype in ["auto", "half", "bfloat16", "float"]:
key_value_cache.uniform_(-scale, scale)
elif cache_dtype == "fp8":
_generate_random_fp8(key_value_cache, -scale, scale)
else:
raise ValueError(f"Does not support key cache of type {cache_dtype}")
key_caches.append(key_value_cache[:, 0])
value_caches.append(key_value_cache[:, 1])
return key_caches, value_caches
def create_kv_caches_with_random(
num_blocks: int,
block_size: int,
num_layers: int,
num_heads: int,
head_size: int,
cache_dtype: str | torch.dtype | None,
model_dtype: str | torch.dtype | None = None,
seed: int | None = None,
device: str | None = "cuda",
) -> tuple[list[torch.Tensor], list[torch.Tensor]]:
if cache_dtype == "fp8" and head_size % 16:
raise ValueError(
f"Does not support key cache of type fp8 with head_size {head_size}"
)
from vllm.platforms import current_platform
current_platform.seed_everything(seed)
dtype = get_kv_cache_torch_dtype(cache_dtype, model_dtype)
scale = head_size**-0.5
x = 16 // torch.tensor([], dtype=dtype).element_size()
key_cache_shape = (num_blocks, num_heads, head_size // x, block_size, x)
key_caches: list[torch.Tensor] = []
for _ in range(num_layers):
key_cache = torch.empty(size=key_cache_shape, dtype=dtype, device=device)
if cache_dtype in ["auto", "half", "bfloat16", "float"]:
key_cache.uniform_(-scale, scale)
elif cache_dtype == "fp8":
_generate_random_fp8(key_cache, -scale, scale)
else:
raise ValueError(f"Does not support key cache of type {cache_dtype}")
key_caches.append(key_cache)
value_cache_shape = (num_blocks, num_heads, head_size, block_size)
value_caches: list[torch.Tensor] = []
for _ in range(num_layers):
value_cache = torch.empty(size=value_cache_shape, dtype=dtype, device=device)
if cache_dtype in ["auto", "half", "bfloat16", "float"]:
value_cache.uniform_(-scale, scale)
elif cache_dtype == "fp8":
_generate_random_fp8(value_cache, -scale, scale)
else:
raise ValueError(f"Does not support value cache of type {cache_dtype}")
value_caches.append(value_cache)
return key_caches, value_caches
def async_tensor_h2d(
data: list,
dtype: torch.dtype,
target_device: str | torch.device,
pin_memory: bool,
) -> torch.Tensor:
"""Asynchronously create a tensor and copy it from host to device."""
t = torch.tensor(data, dtype=dtype, pin_memory=pin_memory, device="cpu")
return t.to(device=target_device, non_blocking=True)
def make_ndarray_with_pad(
x: list[list[T]],
pad: T,
dtype: npt.DTypeLike,
*,
max_len: int | None = None,
) -> npt.NDArray:
"""
Make a padded array from 2D inputs.
The padding is applied to the end of each inner list until it reaches
`max_len`.
"""
if max_len is None:
# Unlike for most functions, map is faster than a genexpr over `len`
max_len = max(map(len, x), default=0)
padded_x = np.full((len(x), max_len), pad, dtype=dtype)
for ind, blocktb in enumerate(x):
assert len(blocktb) <= max_len
padded_x[ind, : len(blocktb)] = blocktb
return padded_x
def make_tensor_with_pad(
x: list[list[T]],
pad: T,
dtype: torch.dtype,
*,
max_len: int | None = None,
device: str | torch.device | None = None,
pin_memory: bool = False,
) -> torch.Tensor:
"""
Make a padded tensor from 2D inputs.
The padding is applied to the end of each inner list until it reaches
`max_len`.
"""
np_dtype = TORCH_DTYPE_TO_NUMPY_DTYPE[dtype]
padded_x = make_ndarray_with_pad(x, pad, np_dtype, max_len=max_len)
tensor = torch.from_numpy(padded_x).to(device)
if pin_memory:
tensor = tensor.pin_memory()
return tensor
prev_set_stream = torch.cuda.set_stream
_current_stream_tls = threading.local()
def _patched_set_stream(stream: torch.cuda.Stream) -> None:
_current_stream_tls.value = stream
prev_set_stream(stream)
torch.cuda.set_stream = _patched_set_stream
class _StreamPlaceholder:
def __init__(self):
self.synchronize = lambda: None
def current_stream() -> torch.cuda.Stream:
"""
replace `torch.cuda.current_stream()` with `vllm.utils.current_stream()`.
it turns out that `torch.cuda.current_stream()` is quite expensive,
as it will construct a new stream object at each call.
here we patch `torch.cuda.set_stream` to keep track of the current stream
directly, so that we can avoid calling `torch.cuda.current_stream()`.
the underlying hypothesis is that we do not call `torch._C._cuda_setStream`
from C/C++ code.
"""
from vllm.platforms import current_platform
if not hasattr(_current_stream_tls, "value") or _current_stream_tls.value is None:
# when this function is called before any stream is set,
# we return the default stream.
# On ROCm using the default 0 stream in combination with RCCL
# is hurting performance. Therefore creating a dedicated stream
# per process
if current_platform.is_rocm():
# torch.cuda.set_stream here is the alias of _pathed_set_stream
torch.cuda.set_stream(torch.cuda.Stream())
elif current_platform.is_cpu():
_current_stream_tls.value = _StreamPlaceholder()
else:
current_stream = current_platform.current_stream
if current_stream is not None:
_current_stream_tls.value = current_stream()
else:
raise ValueError(
"Fail to set current stream, current platform "
"may not support current_stream with torch API"
)
return _current_stream_tls.value
# Global auxilary stream for running operations in background streams.
# We have single global auxilary stream to avoid an explosion of streams
# for every layer (and make profiling look sane).
#
# aux_stream() is currently used for:
# - MoE shared_expert overlap with router
_aux_stream: torch.cuda.Stream | None = None
def aux_stream() -> torch.cuda.Stream | None:
"""
Ensures aux_stream is initialized only once
"""
global _aux_stream
from vllm.platforms import current_platform
# TODO: validate this works properly on ROCm platform.
if _aux_stream is None and current_platform.is_cuda():
_aux_stream = torch.cuda.Stream()
return _aux_stream
@lru_cache(maxsize=8)
def _cuda_device_count_stateless(cuda_visible_devices: str | None = None) -> int:
# Note: cuda_visible_devices is not used, but we keep it as an argument for
# LRU Cache purposes.
# Code below is based on
# https://github.com/pytorch/pytorch/blob/
# c1cd946818442aca8c7f812b16d187ce1586c3bc/
# torch/cuda/__init__.py#L831C1-L831C17
import torch.cuda
import torch.version
from vllm.platforms import current_platform
if not torch.cuda._is_compiled():
return 0
if current_platform.is_rocm():
# ROCm uses amdsmi instead of nvml for stateless device count
# This requires a sufficiently modern version of Torch 2.4.0
raw_count = (
torch.cuda._device_count_amdsmi()
if (hasattr(torch.cuda, "_device_count_amdsmi"))
else -1
)
else:
raw_count = torch.cuda._device_count_nvml()
r = torch._C._cuda_getDeviceCount() if raw_count < 0 else raw_count
return r
def cuda_device_count_stateless() -> int:
"""Get number of CUDA devices, caching based on the value of
CUDA_VISIBLE_DEVICES at the time of call.
This should be used instead of torch.cuda.device_count()
unless CUDA_VISIBLE_DEVICES has already been set to the desired
value."""
# This can be removed and simply replaced with torch.cuda.get_device_count
# after https://github.com/pytorch/pytorch/pull/122815 is released.
return _cuda_device_count_stateless(envs.CUDA_VISIBLE_DEVICES)
def weak_ref_tensor(tensor: Any) -> Any:
"""
Create a weak reference to a tensor.
The new tensor will share the same data as the original tensor,
but will not keep the original tensor alive.
"""
if isinstance(tensor, torch.Tensor):
return ixfops.weak_ref_tensor(tensor)
else:
return tensor
def weak_ref_tensors(
tensors: torch.Tensor
| list[torch.Tensor]
| tuple[torch.Tensor]
| IntermediateTensors,
) -> torch.Tensor | list[Any] | tuple[Any] | Any:
"""
Convenience function to create weak references to tensors,
for single tensor, list of tensors or tuple of tensors.
"""
if isinstance(tensors, torch.Tensor):
return weak_ref_tensor(tensors)
if isinstance(tensors, list):
return [weak_ref_tensor(t) for t in tensors]
if isinstance(tensors, tuple):
return tuple(weak_ref_tensor(t) for t in tensors)
# For IntermediateTensors used in pipeline parallelism
from vllm.sequence import IntermediateTensors
if isinstance(tensors, IntermediateTensors):
ret = IntermediateTensors(
{key: weak_ref_tensor(val) for key, val in tensors.tensors.items()}
)
return ret
raise ValueError("Invalid type for tensors")
def get_cuda_view_from_cpu_tensor(cpu_tensor: torch.Tensor) -> torch.Tensor:
"""
Get a CUDA view of a CPU tensor using Unified Virtual Addressing (UVA).
"""
assert cpu_tensor.is_pinned(), "CPU tensor must be pinned"
return torch.ops._C.get_cuda_view_from_cpu_tensor(cpu_tensor)
# Helper function used in testing.
def _is_torch_equal_or_newer(torch_version: str, target: str) -> bool:
torch_version = version.parse(torch_version)
return torch_version >= version.parse(target)
def is_torch_equal_or_newer(target: str) -> bool:
"""Check if the installed torch version is >= the target version.
Args:
target: a version string, like "2.6.0".
Returns:
Whether the condition meets.
"""
try:
return _is_torch_equal_or_newer(str(torch.__version__), target)
except Exception:
# Fallback to PKG-INFO to load the package info, needed by the doc gen.
return Version(importlib.metadata.version("torch")) >= Version(target)
def _is_torch_equal(target: str) -> bool:
assert target.count(".") == 2
torch_version = str(torch.__version__)
torch_version = version.parse(torch_version)
# torch version is like "2.6.0.dev20240101" or "2.6.0.dev20240101+cpu"
# or "2.6.0+cu128" but never "2.6.0.1"
return (
torch_version >= version.parse(target)
and version.parse(target + ".1") > torch_version
)
def is_torch_equal(target: str) -> bool:
"""Check if the installed torch version is == the target version.
Args:
target: a version string, like "2.6.0".
Returns:
Whether the condition meets.
"""
try:
return _is_torch_equal(target)
except Exception:
return Version(importlib.metadata.version("torch")) == Version(target)
# Using dynamo with vLLM doesn't really work well with PyTorch versions < 2.4.0.
# In particular, the FakeScalarType is not supported for earlier versions of
# PyTorch which breaks dynamo for any ops registered using ScalarType.
def supports_dynamo() -> bool:
return is_torch_equal_or_newer("2.4.0")
# Supports xccl with PyTorch versions >= 2.8.0.dev for XPU platform
def supports_xccl() -> bool:
return (
is_torch_equal_or_newer("2.8.0.dev") and torch.distributed.is_xccl_available()
)
# Some backends use pytorch version < 2.4.0 which doesn't
# support `torch.library.custom_op`.
def supports_custom_op() -> bool:
return hasattr(torch.library, "custom_op")
# create a library to hold the custom op
vllm_lib = Library("vllm", "FRAGMENT") # noqa
def direct_register_custom_op(
op_name: str,
op_func: Callable,
mutates_args: list[str] | None = None,
fake_impl: Callable | None = None,
target_lib: Library | None = None,
dispatch_key: str | None = None,
tags: tuple[torch.Tag, ...] = (),
):
"""
`torch.library.custom_op` can have significant overhead because it
needs to consider complicated dispatching logic. This function
directly registers a custom op and dispatches it to the CUDA backend.
See https://gist.github.com/youkaichao/ecbea9ec9fc79a45d2adce1784d7a9a5
for more details.
By default, the custom op is registered to the vLLM library. If you
want to register it to a different library, you can pass the library
object to the `target_lib` argument.
IMPORTANT: the lifetime of the operator is tied to the lifetime of the
library object. If you want to bind the operator to a different library,
make sure the library object is alive when the operator is used.
"""
if not supports_custom_op():
from vllm.platforms import current_platform
assert not current_platform.is_cuda_alike(), (
"cuda platform needs torch>=2.4 to support custom op, "
"chances are you are using an old version of pytorch "
"or a custom build of pytorch. It is recommended to "
"use vLLM in a fresh new environment and let it install "
"the required dependencies."
)
return
if mutates_args is None:
mutates_args = []
if dispatch_key is None:
from vllm.platforms import current_platform
dispatch_key = current_platform.dispatch_key
import torch.library
if hasattr(torch.library, "infer_schema"):
schema_str = torch.library.infer_schema(op_func, mutates_args=mutates_args)
else:
# for pytorch 2.4
import torch._custom_op.impl
schema_str = torch._custom_op.impl.infer_schema(op_func, mutates_args)
my_lib = target_lib or vllm_lib
my_lib.define(op_name + schema_str, tags=tags)
my_lib.impl(op_name, op_func, dispatch_key=dispatch_key)
if fake_impl is not None:
my_lib._register_fake(op_name, fake_impl)