diff --git a/vllm/_utils/__init__.py b/vllm/_utils/__init__.py new file mode 100644 index 0000000..2d64a71 --- /dev/null +++ b/vllm/_utils/__init__.py @@ -0,0 +1,84 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import uuid +import warnings +from typing import Any + +import torch + +from vllm.logger import init_logger + +from utils import * + +_DEPRECATED_MAPPINGS = { + "cprofile": "profiling", + "cprofile_context": "profiling", + # Used by lm-eval + "get_open_port": "network_utils", +} + + +def __getattr__(name: str) -> Any: # noqa: D401 - short deprecation docstring + """Module-level getattr to handle deprecated utilities.""" + if name in _DEPRECATED_MAPPINGS: + submodule_name = _DEPRECATED_MAPPINGS[name] + warnings.warn( + f"vllm.utils.{name} is deprecated and will be removed in a future version. " + f"Use vllm.utils.{submodule_name}.{name} instead.", + DeprecationWarning, + stacklevel=2, + ) + module = __import__(f"vllm.utils.{submodule_name}", fromlist=[submodule_name]) + return getattr(module, name) + raise AttributeError(f"module {__name__!r} has no attribute {name!r}") + + +def __dir__() -> list[str]: + # expose deprecated names in dir() for better UX/tab-completion + return sorted(list(globals().keys()) + list(_DEPRECATED_MAPPINGS.keys())) + + +logger = init_logger(__name__) + +# Constants related to forcing the attention backend selection + +# String name of register which may be set in order to +# force auto-selection of attention backend by Attention +# wrapper +STR_BACKEND_ENV_VAR: str = "VLLM_ATTENTION_BACKEND" + +# Possible string values of STR_BACKEND_ENV_VAR +# register, corresponding to possible backends +STR_FLASHINFER_ATTN_VAL: str = "FLASHINFER" +STR_XFORMERS_ATTN_VAL: str = "XFORMERS" +STR_FLASH_ATTN_VAL: str = "FLASH_ATTN" +STR_INVALID_VAL: str = "INVALID" + + +def random_uuid() -> str: + return str(uuid.uuid4().hex) + + +def length_from_prompt_token_ids_or_embeds( + 
prompt_token_ids: list[int] | None, + prompt_embeds: torch.Tensor | None, +) -> int: + """Calculate the request length (in number of tokens) give either + prompt_token_ids or prompt_embeds. + """ + prompt_token_len = None if prompt_token_ids is None else len(prompt_token_ids) + prompt_embeds_len = None if prompt_embeds is None else len(prompt_embeds) + + if prompt_token_len is None: + if prompt_embeds_len is None: + raise ValueError("Neither prompt_token_ids nor prompt_embeds were defined.") + return prompt_embeds_len + else: + if prompt_embeds_len is not None and prompt_embeds_len != prompt_token_len: + raise ValueError( + "Prompt token ids and prompt embeds had different lengths" + f" prompt_token_ids={prompt_token_len}" + f" prompt_embeds={prompt_embeds_len}" + ) + return prompt_token_len diff --git a/vllm/_utils/argparse_utils.py b/vllm/_utils/argparse_utils.py new file mode 100644 index 0000000..3d105a3 --- /dev/null +++ b/vllm/_utils/argparse_utils.py @@ -0,0 +1,487 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Argument parsing utilities for vLLM.""" + +import json +import sys +import textwrap +from argparse import ( + Action, + ArgumentDefaultsHelpFormatter, + ArgumentParser, + ArgumentTypeError, + Namespace, + RawDescriptionHelpFormatter, + _ArgumentGroup, +) +from collections import defaultdict +from typing import Any + +import regex as re +import yaml + +from vllm.logger import init_logger + +logger = init_logger(__name__) + + +class SortedHelpFormatter(ArgumentDefaultsHelpFormatter, RawDescriptionHelpFormatter): + """SortedHelpFormatter that sorts arguments by their option strings.""" + + def _split_lines(self, text, width): + """ + 1. Sentences split across lines have their single newlines removed. + 2. Paragraphs and explicit newlines are split into separate lines. + 3. Each line is wrapped to the specified width (width of terminal). 
def create_nested_dict(keys: list[str], value: str) -> dict[str, Any]:
    """Wrap ``value`` in single-key dicts, outermost key first.

    For example, `keys = ["a", "b", "c"]` and `value = 1` produces
    `{"a": {"b": {"c": 1}}}`; an empty key list returns ``value`` unchanged.
    """
    wrapped: Any = value
    idx = len(keys)
    while idx:
        idx -= 1
        wrapped = {keys[idx]: wrapped}
    return wrapped


def recursive_dict_update(
    original: dict[str, Any],
    update: dict[str, Any],
) -> set[str]:
    """Recursively merge ``update`` into ``original`` in place.

    Nested dicts are merged, lists are concatenated, and scalar values are
    overwritten. Returns the dotted paths of keys whose values were
    overwritten.
    """
    overwritten = set[str]()
    for key, new_value in update.items():
        current = original.get(key)
        if isinstance(new_value, dict) and isinstance(current, dict):
            inner = recursive_dict_update(current, new_value)
            overwritten.update(f"{key}.{name}" for name in inner)
            continue
        if isinstance(new_value, list) and isinstance(current, list):
            # += extends the existing list object in place.
            current += new_value
            continue
        if key in original:
            overwritten.add(key)
        original[key] = new_value
    return overwritten
    def check_port(self, value):
        """Argparse type function for TCP port arguments.

        Converts ``value`` to ``int`` and validates it.

        Raises:
            ArgumentTypeError: If the value is not an integer or is outside
                the unprivileged range 1024-65535.
        """
        try:
            value = int(value)
        except ValueError:
            msg = "Port must be an integer"
            raise ArgumentTypeError(msg) from None

        if not (1024 <= value <= 65535):
            raise ArgumentTypeError("Port must be between 1024 and 65535")

        return value

    def _pull_args_from_config(self, args: list[str]) -> list[str]:
        """Method to pull arguments specified in the config file
        into the command-line args variable.

        The arguments in config file will be inserted between
        the argument list.

        example:
        ```yaml
        port: 12323
        tensor-parallel-size: 4
        ```
        ```python
        $: vllm {serve,chat,complete} "facebook/opt-12B" \
            --config config.yaml -tp 2
        $: args = [
            "serve,chat,complete",
            "facebook/opt-12B",
            '--config', 'config.yaml',
            '-tp', '2'
        ]
        $: args = [
            "serve,chat,complete",
            "facebook/opt-12B",
            '--port', '12323',
            '--tensor-parallel-size', '4',
            '-tp', '2'
        ]
        ```

        Please note how the config args are inserted after the sub command.
        This way the order of priorities is maintained when these args are
        parsed by super().
        """
        # Caller guarantees "--config" is present; at most one is allowed.
        assert args.count("--config") <= 1, "More than one config file specified!"

        index = args.index("--config")
        if index == len(args) - 1:
            raise ValueError(
                "No config file specified! \
                Please check your command-line arguments."
            )

        file_path = args[index + 1]

        config_args = self.load_config_file(file_path)

        # 0th index might be the sub command {serve,chat,complete,...}
        # optionally followed by model_tag (only for serve)
        # followed by config args
        # followed by rest of cli args.
        # maintaining this order will enforce the precedence
        # of cli > config > defaults
        if args[0].startswith("-"):
            # No sub command (e.g., api_server entry point)
            args = config_args + args[0:index] + args[index + 2 :]
        elif args[0] == "serve":
            # "serve" may take the model either as a positional argument
            # or as "--model" inside the config file, but needs one of them.
            model_in_cli = len(args) > 1 and not args[1].startswith("-")
            model_in_config = any(arg == "--model" for arg in config_args)

            if not model_in_cli and not model_in_config:
                raise ValueError(
                    "No model specified! Please specify model either "
                    "as a positional argument or in a config file."
                )

            if model_in_cli:
                # Model specified as positional arg, keep CLI version
                args = (
                    [args[0]]
                    + [args[1]]
                    + config_args
                    + args[2:index]
                    + args[index + 2 :]
                )
            else:
                # No model in CLI, use config if available
                args = [args[0]] + config_args + args[1:index] + args[index + 2 :]
        else:
            # Other sub commands: insert config args right after the command.
            args = [args[0]] + config_args + args[1:index] + args[index + 2 :]

        return args

    def load_config_file(self, file_path: str) -> list[str]:
        """Loads a yaml file and returns the key value pairs as a
        flattened list with argparse like pattern
        ```yaml
        port: 12323
        tensor-parallel-size: 4
        ```
        returns:
            processed_args: list[str] = [
                '--port', '12323',
                '--tensor-parallel-size', '4'
            ]

        Raises:
            ValueError: If the file does not have a .yaml/.yml extension.
        """
        extension: str = file_path.split(".")[-1]
        if extension not in ("yaml", "yml"):
            raise ValueError(
                f"Config file must be of a yaml/yml type. {extension} supplied"
            )

        # only expecting a flat dictionary of atomic types
        processed_args: list[str] = []

        config: dict[str, int | str] = {}
        try:
            with open(file_path) as config_file:
                config = yaml.safe_load(config_file)
        except Exception as ex:
            logger.error(
                "Unable to read the config file at %s. Check path correctness",
                file_path,
            )
            raise ex

        for key, value in config.items():
            if isinstance(value, bool):
                # Booleans become store_true style flags: only emitted if True.
                if value:
                    processed_args.append("--" + key)
            elif isinstance(value, list):
                # Lists become one flag followed by each item; empty lists
                # are dropped entirely.
                if value:
                    processed_args.append("--" + key)
                    for item in value:
                        processed_args.append(str(item))
            else:
                processed_args.append("--" + key)
                processed_args.append(str(value))

        return processed_args
+""" + +import asyncio +import contextlib +from asyncio import FIRST_COMPLETED, AbstractEventLoop, Future, Task +from collections.abc import AsyncGenerator, Awaitable, Callable +from concurrent.futures import Executor, ThreadPoolExecutor +from functools import partial +from typing import TypeVar + +from transformers.tokenization_utils_base import BatchEncoding +from typing_extensions import ParamSpec + +P = ParamSpec("P") +T = TypeVar("T") + + +class AsyncMicrobatchTokenizer: + """Asynchronous tokenizer with micro-batching. + + Pulls pending encode/decode requests from a queue and batches them + up to reduce overhead. A single-thread ThreadPoolExecutor is used + so the event loop stays responsive. + """ + + def __init__( + self, + tokenizer, + max_batch_size: int = 32, + batch_wait_timeout_s: float = 0.002, + ) -> None: + self.tokenizer = tokenizer + self.max_batch_size = max_batch_size + self.batch_wait_timeout_s = batch_wait_timeout_s + + self._loop = asyncio.get_running_loop() + self._queues: dict[ + tuple, + asyncio.Queue[tuple[str, dict, Future] | tuple[list[int], Future]], + ] = {} + self._batcher_tasks: list[Task] = [] + + # Single-thread executor for blocking tokenizer calls. 
+ self._executor = ThreadPoolExecutor(max_workers=1) + + # === Public async API === + async def __call__(self, prompt, **kwargs): + result_future: Future = self._loop.create_future() + key = self._queue_key("encode", kwargs) + queue = self._get_queue(self._loop, key) + await queue.put((prompt, kwargs, result_future)) + return await result_future + + async def decode(self, token_ids, **kwargs): + result_future: Future = self._loop.create_future() + key = self._queue_key("decode", kwargs) + queue = self._get_queue(self._loop, key) + await queue.put((token_ids, result_future)) + return await result_future + + # === Internal helpers === + def _get_queue( + self, loop: asyncio.AbstractEventLoop, key: tuple + ) -> asyncio.Queue[tuple[str, dict, Future] | tuple[list[int], Future]]: + """Get the request queue for the given operation key, creating a new + queue and batcher task if needed.""" + queue = self._queues.get(key) + if queue is None: + self._queues[key] = queue = asyncio.Queue() + if key[0] == "encode": + can_batch = key[1] != "other" + coro = self._batch_encode_loop(queue, can_batch) + else: + assert key[0] == "decode", f"Unknown operation type: {key[0]}." 
+ coro = self._batch_decode_loop(queue) + self._batcher_tasks.append(loop.create_task(coro)) + return queue + + async def _batch_encode_loop(self, queue: asyncio.Queue, can_batch: bool): + """Batch incoming encode requests for efficiency.""" + while True: + prompt, kwargs, result_future = await queue.get() + prompts = [prompt] + kwargs_list = [kwargs] + result_futures = [result_future] + deadline = self._loop.time() + self.batch_wait_timeout_s + + while len(prompts) < self.max_batch_size: + timeout = deadline - self._loop.time() + if timeout <= 0: + break + try: + prompt, kwargs, result_future = await asyncio.wait_for( + queue.get(), timeout + ) + prompts.append(prompt) + result_futures.append(result_future) + if not can_batch: + kwargs_list.append(kwargs) + except asyncio.TimeoutError: + break + + try: + # If every request uses identical kwargs we can run a single + # batched tokenizer call for a big speed-up. + if can_batch and len(prompts) > 1: + batch_encode_fn = partial(self.tokenizer, prompts, **kwargs) + results = await self._loop.run_in_executor( + self._executor, batch_encode_fn + ) + + for i, fut in enumerate(result_futures): + if not fut.done(): + data = {k: v[i] for k, v in results.items()} + fut.set_result(BatchEncoding(data)) + else: + encode_fn = lambda prompts=prompts, kwargs=kwargs_list: [ + self.tokenizer(p, **kw) for p, kw in zip(prompts, kwargs) + ] + results = await self._loop.run_in_executor( + self._executor, encode_fn + ) + + for fut, res in zip(result_futures, results): + if not fut.done(): + fut.set_result(res) + except Exception as e: + for fut in result_futures: + if not fut.done(): + fut.set_exception(e) + + async def _batch_decode_loop(self, queue: asyncio.Queue): + """Batch incoming decode requests for efficiency.""" + while True: + token_ids, result_future = await queue.get() + token_ids_list = [token_ids] + result_futures = [result_future] + deadline = self._loop.time() + self.batch_wait_timeout_s + + while len(token_ids_list) < 
self.max_batch_size: + timeout = deadline - self._loop.time() + if timeout <= 0: + break + try: + token_ids, result_future = await asyncio.wait_for( + queue.get(), timeout + ) + token_ids_list.append(token_ids) + result_futures.append(result_future) + except asyncio.TimeoutError: + break + + try: + # Perform a single batched decode call for all requests + results = await self._loop.run_in_executor( + self._executor, self.tokenizer.batch_decode, token_ids_list + ) + for fut, res in zip(result_futures, results): + if not fut.done(): + fut.set_result(res) + except Exception as e: + for fut in result_futures: + if not fut.done(): + fut.set_exception(e) + + def _queue_key(self, op: str, kwargs: dict) -> tuple: + """ + Return a normalized key describing operation + kwargs. + + - `add_special_tokens`: {True/False} + - `truncation`: {True/False} + - If `truncation` is False (`max_length` is None), + returns a key for a can_batch queue. + - If `truncation` is True and `max_length` is None or equals + `tokenizer.model_max_length`, returns a key for a can_batch queue. + - Otherwise, returns a key for a cannot_batch queue. 
+ + Examples: + - Decode: ("decode",) + - Encode typical: + ("encode", add_special_tokens, bool_truncation, max_length_label) + - Fallback: ("encode", "other") + """ + + if op == "decode": + return ("decode",) + + add_special_tokens = kwargs.get("add_special_tokens", True) + truncation = kwargs.get("truncation", False) + max_length = kwargs.get("max_length") + + if not truncation: + return "encode", add_special_tokens, False, None + + model_max = getattr(self.tokenizer, "model_max_length", None) + if max_length is None or (model_max is not None and max_length == model_max): + return "encode", add_special_tokens, True, "model_max" + + return "encode", "other" + + def __del__(self): + if ( + (tasks := getattr(self, "_batcher_tasks", None)) + and (loop := getattr(self, "_loop", None)) + and not loop.is_closed() + ): + + def cancel_tasks(): + for task in tasks: + task.cancel() + + loop.call_soon_threadsafe(cancel_tasks) + + +def cancel_task_threadsafe(task: Task): + if task and not task.done(): + run_in_loop(task.get_loop(), task.cancel) + + +def make_async( + func: Callable[P, T], + executor: Executor | None = None, +) -> Callable[P, Awaitable[T]]: + """ + Take a blocking function, and run it on in an executor thread. + + This function prevents the blocking function from blocking the + asyncio event loop. + The code in this function needs to be thread safe. 
+ """ + + def _async_wrapper(*args: P.args, **kwargs: P.kwargs) -> Future[T]: + loop = asyncio.get_event_loop() + p_func = partial(func, *args, **kwargs) + return loop.run_in_executor(executor=executor, func=p_func) + + return _async_wrapper + + +def run_in_loop(loop: AbstractEventLoop, function: Callable, *args): + if in_loop(loop): + function(*args) + elif not loop.is_closed(): + loop.call_soon_threadsafe(function, *args) + + +def in_loop(event_loop: AbstractEventLoop) -> bool: + try: + return asyncio.get_running_loop() == event_loop + except RuntimeError: + return False + + +async def merge_async_iterators( + *iterators: AsyncGenerator[T, None], +) -> AsyncGenerator[tuple[int, T], None]: + """Merge multiple asynchronous iterators into a single iterator. + + This method handle the case where some iterators finish before others. + When it yields, it yields a tuple (i, item) where i is the index of the + iterator that yields the item. + """ + if len(iterators) == 1: + # Fast-path single iterator case. 
+ async for item in iterators[0]: + yield 0, item + return + + loop = asyncio.get_running_loop() + + awaits = {loop.create_task(anext(it)): (i, it) for i, it in enumerate(iterators)} + try: + while awaits: + done, _ = await asyncio.wait(awaits.keys(), return_when=FIRST_COMPLETED) + for d in done: + pair = awaits.pop(d) + try: + item = await d + i, it = pair + awaits[loop.create_task(anext(it))] = pair + yield i, item + except StopAsyncIteration: + pass + finally: + # Cancel any remaining iterators + for f, (_, it) in awaits.items(): + with contextlib.suppress(BaseException): + f.cancel() + await it.aclose() + + +async def collect_from_async_generator(iterator: AsyncGenerator[T, None]) -> list[T]: + """Collect all items from an async generator into a list.""" + items = [] + async for item in iterator: + items.append(item) + return items diff --git a/vllm/_utils/cache.py b/vllm/_utils/cache.py new file mode 100644 index 0000000..4338983 --- /dev/null +++ b/vllm/_utils/cache.py @@ -0,0 +1,214 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from collections import UserDict +from collections.abc import Callable, Hashable, Iterator, KeysView, Mapping +from types import MappingProxyType +from typing import NamedTuple, TypeVar, cast, overload + +import cachetools + +_K = TypeVar("_K", bound=Hashable) +_V = TypeVar("_V") +_T = TypeVar("_T") + + +class _Sentinel: ... 
ALL_PINNED_SENTINEL = _Sentinel()


class _MappingOrderCacheView(UserDict[_K, _V]):
    """Read-only dict view whose iteration order follows `ordered_keys`."""

    def __init__(self, data: Mapping[_K, _V], ordered_keys: Mapping[_K, None]):
        super().__init__(data)
        self.ordered_keys = ordered_keys

    def __iter__(self) -> Iterator[_K]:
        return iter(self.ordered_keys)

    def keys(self) -> KeysView[_K]:
        return KeysView(self.ordered_keys)


class CacheInfo(NamedTuple):
    """Cumulative hit/query counters for a cache."""

    hits: int
    total: int

    @property
    def hit_ratio(self) -> float:
        # Avoid division by zero before any query has been recorded.
        if self.total == 0:
            return 0

        return self.hits / self.total

    def __sub__(self, other: "CacheInfo"):
        # Component-wise difference, used for delta statistics in stat().
        return CacheInfo(
            hits=self.hits - other.hits,
            total=self.total - other.total,
        )


class LRUCache(cachetools.LRUCache[_K, _V]):
    """LRU cache extending `cachetools.LRUCache` with entry pinning and
    hit/query statistics.

    NOTE(review): several methods reach into cachetools' name-mangled
    internals (`_Cache__data`, `_LRUCache__order`) — these are private and
    version-sensitive; verify against the pinned cachetools version.
    """

    def __init__(self, capacity: float, getsizeof: Callable[[_V], float] | None = None):
        super().__init__(capacity, getsizeof)

        # Keys that must never be evicted by LRU order (see pin()).
        self.pinned_items = set[_K]()

        self._hits = 0
        self._total = 0
        self._last_info = CacheInfo(hits=0, total=0)

    def __getitem__(self, key: _K, *, update_info: bool = True) -> _V:
        value = super().__getitem__(key)

        # A successful subscript lookup counts as one hit and one query,
        # unless the caller opts out (internal accesses pass False).
        if update_info:
            self._hits += 1
            self._total += 1

        return value

    def __delitem__(self, key: _K) -> None:
        run_on_remove = key in self
        value = self.__getitem__(key, update_info=False)  # type: ignore[call-arg]
        super().__delitem__(key)
        if key in self.pinned_items:
            # Todo: add warning to inform that del pinned item
            self._unpin(key)
        if run_on_remove:
            self._on_remove(key, value)

    @property
    def cache(self) -> Mapping[_K, _V]:
        """Return the internal cache dictionary in order (read-only)."""
        return _MappingOrderCacheView(
            self._Cache__data,  # type: ignore
            self.order,
        )

    @property
    def order(self) -> Mapping[_K, None]:
        """Return the internal order dictionary (read-only)."""
        return MappingProxyType(self._LRUCache__order)  # type: ignore

    @property
    def capacity(self) -> float:
        return self.maxsize

    @property
    def usage(self) -> float:
        """Fraction of the capacity currently used (0.0-1.0)."""
        if self.maxsize == 0:
            return 0

        return self.currsize / self.maxsize

    def stat(self, *, delta: bool = False) -> CacheInfo:
        """
        Gets the cumulative number of hits and queries against this cache.

        If `delta=True`, instead gets these statistics
        since the last call that also passed `delta=True`.
        """
        info = CacheInfo(hits=self._hits, total=self._total)

        if delta:
            info_delta = info - self._last_info
            self._last_info = info
            info = info_delta

        return info

    def touch(self, key: _K) -> None:
        """Move `key` to the most-recently-used position, inserting an
        ordering entry if the key has none yet."""
        try:
            self._LRUCache__order.move_to_end(key)  # type: ignore
        except KeyError:
            self._LRUCache__order[key] = None  # type: ignore

    @overload
    def get(self, key: _K, /) -> _V | None: ...

    @overload
    def get(self, key: _K, /, default: _V | _T) -> _V | _T: ...

    def get(self, key: _K, /, default: _V | _T | None = None) -> _V | _T | None:
        # Unlike __getitem__, a miss does not raise; it returns `default`.
        # Every call counts as one query; only found keys count as hits.
        value: _V | _T | None
        if key in self:
            value = self.__getitem__(key, update_info=False)  # type: ignore[call-arg]

            self._hits += 1
        else:
            value = default

        self._total += 1
        return value

    @overload
    def pop(self, key: _K) -> _V: ...

    @overload
    def pop(self, key: _K, default: _V | _T) -> _V | _T: ...

    def pop(self, key: _K, default: _V | _T | None = None) -> _V | _T | None:
        # Missing keys return `default` rather than raising; stats untouched.
        value: _V | _T | None
        if key not in self:
            return default

        value = self.__getitem__(key, update_info=False)  # type: ignore[call-arg]
        self.__delitem__(key)
        return value

    def put(self, key: _K, value: _V) -> None:
        self.__setitem__(key, value)

    def pin(self, key: _K) -> None:
        """
        Pins a key in the cache preventing it from being
        evicted in the LRU order.
        """
        if key not in self:
            raise ValueError(f"Cannot pin key: {key} not in cache.")
        self.pinned_items.add(key)

    def _unpin(self, key: _K) -> None:
        """
        Unpins a key in the cache allowing it to be
        evicted in the LRU order.
        """
        self.pinned_items.remove(key)

    def _on_remove(self, key: _K, value: _V | None) -> None:
        # Hook for subclasses; called after a key is deleted.
        pass

    def remove_oldest(self, *, remove_pinned: bool = False) -> None:
        if len(self) == 0:
            return

        self.popitem(remove_pinned=remove_pinned)

    def _remove_old_if_needed(self) -> None:
        # Evict LRU entries until the cache fits within its capacity.
        while self.currsize > self.capacity:
            self.remove_oldest()

    def popitem(self, remove_pinned: bool = False):
        """Remove and return the `(key, value)` pair least recently used."""
        if not remove_pinned:
            # pop the oldest item in the cache that is not pinned
            lru_key = next(
                (key for key in self.order if key not in self.pinned_items),
                ALL_PINNED_SENTINEL,
            )
            if lru_key is ALL_PINNED_SENTINEL:
                raise RuntimeError(
                    "All items are pinned, cannot remove oldest from the cache."
                )
        else:
            lru_key = next(iter(self.order))
        value = self.pop(cast(_K, lru_key))
        return (lru_key, value)

    def clear(self) -> None:
        # Remove everything (including pinned entries) and reset statistics.
        while len(self) > 0:
            self.remove_oldest(remove_pinned=True)

        self._hits = 0
        self._total = 0
        self._last_info = CacheInfo(hits=0, total=0)


# --- collection_utils.py ---

"""
Contains helpers that are applied to collections.

This is similar in concept to the `collections` module.
"""

from collections import UserDict, defaultdict
from collections.abc import Callable, Generator, Hashable, Iterable, Mapping
from typing import Generic, Literal, TypeVar

from typing_extensions import TypeIs, assert_never

T = TypeVar("T")
U = TypeVar("U")

_K = TypeVar("_K", bound=Hashable)
_V = TypeVar("_V")
+ """ + + def __getitem__(self, key: type[T]) -> _V: + for cls in key.mro(): + if cls in self.data: + return self.data[cls] + + raise KeyError(key) + + def __contains__(self, key: object) -> bool: + return self.contains(key) + + def contains(self, key: object, *, strict: bool = False) -> bool: + if not isinstance(key, type): + return False + + if strict: + return key in self.data + + return any(cls in self.data for cls in key.mro()) + + +class LazyDict(Mapping[str, T], Generic[T]): + """ + Evaluates dictionary items only when they are accessed. + + Adapted from: https://stackoverflow.com/a/47212782/5082708 + """ + + def __init__(self, factory: dict[str, Callable[[], T]]): + self._factory = factory + self._dict: dict[str, T] = {} + + def __getitem__(self, key: str) -> T: + if key not in self._dict: + if key not in self._factory: + raise KeyError(key) + self._dict[key] = self._factory[key]() + return self._dict[key] + + def __setitem__(self, key: str, value: Callable[[], T]): + self._factory[key] = value + + def __iter__(self): + return iter(self._factory) + + def __len__(self): + return len(self._factory) + + +def as_list(maybe_list: Iterable[T]) -> list[T]: + """Convert iterable to list, unless it's already a list.""" + return maybe_list if isinstance(maybe_list, list) else list(maybe_list) + + +def as_iter(obj: T | Iterable[T]) -> Iterable[T]: + if isinstance(obj, str) or not isinstance(obj, Iterable): + return [obj] # type: ignore[list-item] + return obj + + +def is_list_of( + value: object, + typ: type[T] | tuple[type[T], ...], + *, + check: Literal["first", "all"] = "first", +) -> TypeIs[list[T]]: + if not isinstance(value, list): + return False + + if check == "first": + return len(value) == 0 or isinstance(value[0], typ) + elif check == "all": + return all(isinstance(v, typ) for v in value) + + assert_never(check) + + +def chunk_list(lst: list[T], chunk_size: int) -> Generator[list[T]]: + """Yield successive chunk_size chunks from lst.""" + for i in range(0, 
len(lst), chunk_size): + yield lst[i : i + chunk_size] + + +def flatten_2d_lists(lists: Iterable[Iterable[T]]) -> list[T]: + """Flatten a list of lists to a single list.""" + return [item for sublist in lists for item in sublist] + + +def full_groupby(values: Iterable[_V], *, key: Callable[[_V], _K]): + """ + Unlike [`itertools.groupby`][], groups are not broken by + non-contiguous data. + """ + groups = defaultdict[_K, list[_V]](list) + + for value in values: + groups[key(value)].append(value) + + return groups.items() + + +def swap_dict_values(obj: dict[_K, _V], key1: _K, key2: _K) -> None: + """Swap values between two keys.""" + v1 = obj.get(key1) + v2 = obj.get(key2) + if v1 is not None: + obj[key2] = v1 + else: + obj.pop(key2, None) + if v2 is not None: + obj[key1] = v2 + else: + obj.pop(key1, None) diff --git a/vllm/_utils/counter.py b/vllm/_utils/counter.py new file mode 100644 index 0000000..c2dce32 --- /dev/null +++ b/vllm/_utils/counter.py @@ -0,0 +1,45 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import threading + + +class Counter: + def __init__(self, start: int = 0) -> None: + super().__init__() + + self.counter = start + + def __next__(self) -> int: + i = self.counter + self.counter += 1 + return i + + def reset(self) -> None: + self.counter = 0 + + +class AtomicCounter: + """An atomic, thread-safe counter""" + + def __init__(self, initial: int = 0) -> None: + """Initialize a new atomic counter to given initial value""" + super().__init__() + + self._value = initial + self._lock = threading.Lock() + + @property + def value(self) -> int: + return self._value + + def inc(self, num: int = 1) -> int: + """Atomically increment the counter by num and return the new value""" + with self._lock: + self._value += num + return self._value + + def dec(self, num: int = 1) -> int: + """Atomically decrement the counter by num and return the new value""" + with self._lock: + self._value -= num + 
return self._value diff --git a/vllm/_utils/deep_gemm.py b/vllm/_utils/deep_gemm.py new file mode 100644 index 0000000..b5ab375 --- /dev/null +++ b/vllm/_utils/deep_gemm.py @@ -0,0 +1,391 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Compatibility wrapper for DeepGEMM API changes. + +Users of vLLM should always import **only** these wrappers. +""" + +import functools +import importlib +import os +from collections.abc import Callable +from enum import Enum +from typing import Any, NoReturn + +import torch + +import vllm.envs as envs +from vllm.logger import logger +from vllm.platforms import current_platform +from vllm.utils.import_utils import has_deep_gemm +from vllm.utils.math_utils import cdiv + + +class DeepGemmQuantScaleFMT(Enum): + # Float32 scales in Float32 tensor + FLOAT32 = 0 + # Compute float32 scales and ceil the scales to UE8M0. + # Keep the scales in Float32 tensor. + FLOAT32_CEIL_UE8M0 = 1 + # Compute float32 scales and ceil the scales to UE8M0. + # Pack the scales into a int32 tensor where each int32 + # element contains 4 scale values. + UE8M0 = 2 + + @staticmethod + def from_oracle() -> "DeepGemmQuantScaleFMT": + if not is_deep_gemm_e8m0_used(): + return DeepGemmQuantScaleFMT.FLOAT32 + return ( + DeepGemmQuantScaleFMT.UE8M0 + if current_platform.is_device_capability(100) + else DeepGemmQuantScaleFMT.FLOAT32_CEIL_UE8M0 + ) + + +@functools.cache +def is_deep_gemm_supported() -> bool: + """Return `True` if DeepGEMM is supported on the current platform. + Currently, only Hopper and Blackwell GPUs are supported. 
+ """ + is_supported_arch = current_platform.is_cuda() and ( + current_platform.is_device_capability(90) + or current_platform.is_device_capability(100) + ) + return envs.VLLM_USE_DEEP_GEMM and has_deep_gemm() and is_supported_arch + + +@functools.cache +def is_deep_gemm_e8m0_used() -> bool: + """Return `True` if vLLM is configured to use DeepGEMM " + "E8M0 scale on a Hopper or Blackwell-class GPU. + """ + if not is_deep_gemm_supported(): + logger.debug_once( + "DeepGEMM E8M0 disabled: DeepGEMM not supported on this system." + ) + return False + + _lazy_init() + + if _fp8_gemm_nt_impl is None: + logger.info_once("DeepGEMM E8M0 disabled: _fp8_gemm_nt_impl not found") + return False + + if envs.VLLM_USE_DEEP_GEMM_E8M0: + logger.info_once("DeepGEMM E8M0 enabled on current platform.") + return True + + logger.info_once("DeepGEMM E8M0 disabled on current configuration.") + return False + + +def _missing(*_: Any, **__: Any) -> NoReturn: + """Placeholder for unavailable DeepGEMM backend.""" + raise RuntimeError( + "DeepGEMM backend is not available or outdated. Please install or " + "update the `deep_gemm` to a newer version to enable FP8 kernels." 
+ ) + + +_fp8_gemm_nt_impl: Callable[..., Any] | None = None +_grouped_impl: Callable[..., Any] | None = None +_grouped_masked_impl: Callable[..., Any] | None = None +_fp8_mqa_logits_impl: Callable[..., Any] | None = None +_fp8_paged_mqa_logits_impl: Callable[..., Any] | None = None +_get_paged_mqa_logits_metadata_impl: Callable[..., Any] | None = None +_get_mn_major_tma_aligned_tensor_impl: Callable[..., Any] | None = None +_get_mk_alignment_for_contiguous_layout_impl: Callable[..., Any] | None = None +_transform_sf_into_required_layout_impl: Callable[..., Any] | None = None + + +def _lazy_init() -> None: + """Import deep_gemm and resolve symbols on first use.""" + global _fp8_gemm_nt_impl, _grouped_impl, _grouped_masked_impl + global _fp8_mqa_logits_impl, _fp8_paged_mqa_logits_impl + global _get_paged_mqa_logits_metadata_impl + global _get_mn_major_tma_aligned_tensor_impl + global _get_mk_alignment_for_contiguous_layout_impl + global _transform_sf_into_required_layout_impl + # fast path + if ( + _fp8_gemm_nt_impl is not None + or _grouped_impl is not None + or _grouped_masked_impl is not None + or _fp8_mqa_logits_impl is not None + or _fp8_paged_mqa_logits_impl is not None + or _get_paged_mqa_logits_metadata_impl is not None + or _get_mk_alignment_for_contiguous_layout_impl is not None + or _transform_sf_into_required_layout_impl is not None + ): + return + + if not has_deep_gemm(): + return + + # Set up deep_gemm cache path + DEEP_GEMM_JIT_CACHE_ENV_NAME = "DG_JIT_CACHE_DIR" + if not os.environ.get(DEEP_GEMM_JIT_CACHE_ENV_NAME, None): + os.environ[DEEP_GEMM_JIT_CACHE_ENV_NAME] = os.path.join( + envs.VLLM_CACHE_ROOT, "deep_gemm" + ) + + _dg = importlib.import_module("deep_gemm") + + _fp8_gemm_nt_impl = getattr(_dg, "fp8_gemm_nt", None) + _grouped_impl = getattr(_dg, "m_grouped_fp8_gemm_nt_contiguous", None) + _grouped_masked_impl = getattr(_dg, "fp8_m_grouped_gemm_nt_masked", None) + _fp8_mqa_logits_impl = getattr(_dg, "fp8_mqa_logits", None) + 
_fp8_paged_mqa_logits_impl = getattr(_dg, "fp8_paged_mqa_logits", None) + _get_paged_mqa_logits_metadata_impl = getattr( + _dg, "get_paged_mqa_logits_metadata", None + ) + _get_mn_major_tma_aligned_tensor_impl = getattr( + _dg, "get_mn_major_tma_aligned_tensor", None + ) + _get_mk_alignment_for_contiguous_layout_impl = getattr( + _dg, "get_mk_alignment_for_contiguous_layout", None + ) + _transform_sf_into_required_layout_impl = getattr( + _dg, "transform_sf_into_required_layout", None + ) + + +def get_num_sms() -> int: + _lazy_init() + _dg = importlib.import_module("deep_gemm") + return int(_dg.get_num_sms()) + + +@functools.cache +def get_mk_alignment_for_contiguous_layout() -> list[int]: + _lazy_init() + if _get_mk_alignment_for_contiguous_layout_impl is None: + return _missing() + mk_align_size = _get_mk_alignment_for_contiguous_layout_impl() + return [mk_align_size, mk_align_size] + + +def get_col_major_tma_aligned_tensor(x: torch.Tensor) -> torch.Tensor: + """Wrapper for DeepGEMM's get_mn_major_tma_aligned_tensor""" + _lazy_init() + if _get_mn_major_tma_aligned_tensor_impl is None: + return _missing() + return _get_mn_major_tma_aligned_tensor_impl(x) + + +def fp8_gemm_nt(*args, **kwargs): + _lazy_init() + if _fp8_gemm_nt_impl is None: + return _missing(*args, **kwargs) + if "is_deep_gemm_e8m0_used" in kwargs: + use_ue8m0 = kwargs["is_deep_gemm_e8m0_used"] + del kwargs["is_deep_gemm_e8m0_used"] + else: + use_ue8m0 = is_deep_gemm_e8m0_used() + return _fp8_gemm_nt_impl(*args, disable_ue8m0_cast=not use_ue8m0, **kwargs) + + +def m_grouped_fp8_gemm_nt_contiguous(*args, **kwargs): + _lazy_init() + if _grouped_impl is None: + return _missing(*args, **kwargs) + return _grouped_impl( + *args, disable_ue8m0_cast=not is_deep_gemm_e8m0_used(), **kwargs + ) + + +def fp8_m_grouped_gemm_nt_masked(*args, **kwargs): + _lazy_init() + if _grouped_masked_impl is None: + return _missing(*args, **kwargs) + return _grouped_masked_impl( + *args, disable_ue8m0_cast=not 
is_deep_gemm_e8m0_used(), **kwargs + ) + + +def transform_sf_into_required_layout(*args, **kwargs): + _lazy_init() + if _transform_sf_into_required_layout_impl is None: + return _missing(*args, **kwargs) + return _transform_sf_into_required_layout_impl( + *args, disable_ue8m0_cast=not is_deep_gemm_e8m0_used(), **kwargs + ) + + +def fp8_mqa_logits( + q: torch.Tensor, + kv: tuple[torch.Tensor, torch.Tensor], + weights: torch.Tensor, + cu_seqlen_ks: torch.Tensor, + cu_seqlen_ke: torch.Tensor, +) -> torch.Tensor: + """Compute FP8 MQA logits for a single sequence without KV paging. + + Args: + q: Query tensor of shape [M, H, D]. Casted to + `torch.float8_e4m3fn` by caller. + kv: Tuple `(k_fp8, k_scales)` where `k_fp8` has shape [N, D] with + dtype `torch.float8_e4m3fn` and `k_scales` has shape [N] (or + [N, 1]) with dtype `torch.float32`. + weights: weights of shape [M, H], dtype `torch.float32`. + cu_seqlen_ks: Start indices (inclusive) for valid K per query position, + shape [M], dtype int32. + cu_seqlen_ke: End indices (exclusive) for valid K per query position, + shape [M], dtype int32. + + Returns: + Logits tensor of shape [M, N], dtype `torch.float32`. + """ + _lazy_init() + if _fp8_mqa_logits_impl is None: + return _missing() + return _fp8_mqa_logits_impl(q, kv, weights, cu_seqlen_ks, cu_seqlen_ke) + + +def get_paged_mqa_logits_metadata( + context_lens: torch.Tensor, block_size: int, num_sms: int +) -> torch.Tensor: + """Build scheduling metadata for paged MQA logits. + + Args: + context_lens: Tensor of shape [B], dtype int32; effective context length + per batch element. + block_size: KV-cache block size in tokens (e.g., 64). + num_sms: Number of SMs available. 132 for Hopper + + Returns: + Backend-specific tensor consumed by `fp8_paged_mqa_logits` to + schedule work across SMs. 
+ """ + _lazy_init() + if _get_paged_mqa_logits_metadata_impl is None: + return _missing() + return _get_paged_mqa_logits_metadata_impl(context_lens, block_size, num_sms) + + +def fp8_paged_mqa_logits( + q_fp8: torch.Tensor, + kv_cache_fp8: torch.Tensor, + weights: torch.Tensor, + context_lens: torch.Tensor, + block_tables: torch.Tensor, + schedule_metadata: torch.Tensor, + max_model_len: int, +) -> torch.Tensor: + """Compute FP8 MQA logits using paged KV-cache. + + Args: + q_fp8: Query tensor of shape [B, next_n, H, D]. Casted to + `torch.float8_e4m3fn` by caller. + kv_cache_fp8: Paged KV-cache in packed FP8+scale layout with shape + [num_blocks, block_size, 1, D+4], dtype `torch.uint8`. The last + 4 bytes per (block,pos) store the `float` dequant scale. + weights: Tensor of shape [B * next_n, H], dtype `torch.float32`. + context_lens: Tensor of shape [B], dtype int32; effective context length + for each batch element. + block_tables: Tensor of shape [B, max_blocks], dtype int32; maps logical + block indices to physical blocks in the paged cache. + schedule_metadata: Returned by `get_paged_mqa_logits_metadata`; + used to distribute work across SMs. + max_model_len: Maximum sequence length used to size the logits output. + + Returns: + Logits tensor of shape [B * next_n, max_model_len], dtype + `torch.float32`. 
+ """ + _lazy_init() + if _fp8_paged_mqa_logits_impl is None: + return _missing() + return _fp8_paged_mqa_logits_impl( + q_fp8, + kv_cache_fp8, + weights, + context_lens, + block_tables, + schedule_metadata, + max_model_len, + clean_logits=True, + ) + + +def _ceil_to_ue8m0(x: torch.Tensor): + return torch.pow(2.0, torch.ceil(torch.log2(x.abs()))) + + +def _align(x: int, y: int) -> int: + return cdiv(x, y) * y + + +DEFAULT_BLOCK_SIZE = [128, 128] + + +# Taken from https://github.com/deepseek-ai/DeepGEMM/blob/dd6ed14acbc7445dcef224248a77ab4d22b5f240/deep_gemm/utils/math.py#L38 +@torch.compile(dynamic=True, backend=current_platform.simple_compile_backend) +def per_block_cast_to_fp8( + x: torch.Tensor, block_size: list[int] = DEFAULT_BLOCK_SIZE, use_ue8m0: bool = False +) -> tuple[torch.Tensor, torch.Tensor]: + assert x.dim() == 2 + m, n = x.shape + block_m, block_n = block_size + x_padded = torch.zeros( + (_align(m, block_m), _align(n, block_n)), dtype=x.dtype, device=x.device + ) + x_padded[:m, :n] = x + x_view = x_padded.view(-1, block_m, x_padded.size(1) // block_n, block_n) + x_amax = x_view.abs().float().amax(dim=(1, 3), keepdim=True).clamp(1e-4) + sf = x_amax / 448.0 + sf = _ceil_to_ue8m0(sf) if use_ue8m0 else sf + x_scaled = (x_view * (1.0 / sf)).to(torch.float8_e4m3fn) + return x_scaled.view_as(x_padded)[:m, :n].contiguous(), sf.view( + x_view.size(0), x_view.size(2) + ) + + +def calc_diff(x: torch.Tensor, y: torch.Tensor): + """Return a global difference metric for unit tests. + + DeepGEMM kernels on Blackwell/B200 currently exhibit noticeable per-element + error, causing `torch.testing.assert_close` to fail. Instead of checking + every element, we compute a cosine-style similarity over the whole tensor + and report `1 - sim`. Once kernel accuracy improves this helper can be + removed. 
+ """ + + x, y = x.double(), y.double() + denominator = (x * x + y * y).sum() + sim = 2 * (x * y).sum() / denominator + return 1 - sim + + +def should_use_deepgemm_for_fp8_linear( + output_dtype: torch.dtype, + weight: torch.Tensor, + supports_deep_gemm: bool | None = None, +): + if supports_deep_gemm is None: + supports_deep_gemm = is_deep_gemm_supported() + return ( + supports_deep_gemm + and output_dtype == torch.bfloat16 + and weight.shape[0] % 128 == 0 + and weight.shape[1] % 128 == 0 + ) + + +__all__ = [ + "calc_diff", + "fp8_gemm_nt", + "m_grouped_fp8_gemm_nt_contiguous", + "fp8_m_grouped_gemm_nt_masked", + "fp8_mqa_logits", + "fp8_paged_mqa_logits", + "get_paged_mqa_logits_metadata", + "per_block_cast_to_fp8", + "is_deep_gemm_e8m0_used", + "is_deep_gemm_supported", + "get_num_sms", + "should_use_deepgemm_for_fp8_linear", + "get_col_major_tma_aligned_tensor", + "get_mk_alignment_for_contiguous_layout", +] diff --git a/vllm/_utils/flashinfer.py b/vllm/_utils/flashinfer.py new file mode 100644 index 0000000..79e5a4c --- /dev/null +++ b/vllm/_utils/flashinfer.py @@ -0,0 +1,492 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Compatibility wrapper for FlashInfer API changes. + +Users of vLLM should always import **only** these wrappers. +""" + +import contextlib +import functools +import importlib +import importlib.util +import os +import shutil +from collections.abc import Callable +from typing import Any, NoReturn + +import requests +import torch + +import vllm.envs as envs +from vllm.logger import init_logger +from vllm.model_executor.layers.batch_invariant import ( + vllm_is_batch_invariant, +) +from vllm.platforms import current_platform + +logger = init_logger(__name__) + +# This is the storage path for the cubins, it can be replaced +# with a local path for testing. 
+# Referenced from https://github.com/flashinfer-ai/flashinfer/blob/0c9a92c3d9a7e043ab6f3f7b2273269caf6ab044/flashinfer/jit/cubin_loader.py#L35 # noqa: E501 +FLASHINFER_CUBINS_REPOSITORY = os.environ.get( + "FLASHINFER_CUBINS_REPOSITORY", + "https://edge.urm.nvidia.com/artifactory/sw-kernelinferencelibrary-public-generic-local/", # noqa: E501 +) + + +@functools.cache +def has_flashinfer_cubin() -> bool: + """Return `True` if flashinfer-cubin package is available.""" + if envs.VLLM_HAS_FLASHINFER_CUBIN: + return True + if importlib.util.find_spec("flashinfer_cubin") is not None: + return True + logger.debug_once("flashinfer-cubin package was not found") + return False + + +@functools.cache +def has_flashinfer() -> bool: + """Return `True` if flashinfer-python package is available.""" + # Use find_spec to check if the module exists without importing it + # This avoids potential CUDA initialization side effects + if importlib.util.find_spec("flashinfer") is None: + logger.debug_once("FlashInfer unavailable since package was not found") + return False + # When not using flashinfer cubin, + # Also check if nvcc is available since it's required to JIT compile flashinfer + if not has_flashinfer_cubin() and shutil.which("nvcc") is None: + logger.debug_once( + "FlashInfer unavailable since nvcc was not found " + "and not using pre-downloaded cubins" + ) + return False + return True + + +def _missing(*_: Any, **__: Any) -> NoReturn: + """Placeholder for unavailable FlashInfer backend.""" + raise RuntimeError( + "FlashInfer backend is not available. 
Please install the package " + "to enable FlashInfer kernels: " + "https://github.com/flashinfer-ai/flashinfer" + ) + + +def _get_submodule(module_name: str) -> Any | None: + """Safely import a submodule and return it, or None if not available.""" + try: + return importlib.import_module(module_name) + except (ImportError, ModuleNotFoundError): + return None + + +# General lazy import wrapper +def _lazy_import_wrapper( + module_name: str, attr_name: str, fallback_fn: Callable[..., Any] = _missing +): + """Create a lazy import wrapper for a specific function.""" + + @functools.cache + def _get_impl(): + if not has_flashinfer(): + return None + mod = _get_submodule(module_name) + return getattr(mod, attr_name, None) if mod else None + + def wrapper(*args, **kwargs): + impl = _get_impl() + if impl is None: + return fallback_fn(*args, **kwargs) + return impl(*args, **kwargs) + + return wrapper + + +# Create lazy wrappers for each function +flashinfer_trtllm_fp8_block_scale_moe = _lazy_import_wrapper( + "flashinfer.fused_moe", "trtllm_fp8_block_scale_moe" +) +flashinfer_trtllm_fp8_per_tensor_scale_moe = _lazy_import_wrapper( + "flashinfer.fused_moe", "trtllm_fp8_per_tensor_scale_moe" +) +flashinfer_cutlass_fused_moe = _lazy_import_wrapper( + "flashinfer.fused_moe", "cutlass_fused_moe" +) +flashinfer_fp4_quantize = _lazy_import_wrapper("flashinfer", "fp4_quantize") +nvfp4_block_scale_interleave = _lazy_import_wrapper( + "flashinfer", "nvfp4_block_scale_interleave" +) +trtllm_fp4_block_scale_moe = _lazy_import_wrapper( + "flashinfer", "trtllm_fp4_block_scale_moe" +) + +# Special case for autotune since it returns a context manager +autotune = _lazy_import_wrapper( + "flashinfer.autotuner", + "autotune", + fallback_fn=lambda *args, **kwargs: contextlib.nullcontext(), +) + + +@functools.cache +def has_flashinfer_comm() -> bool: + """Return `True` if FlashInfer comm module is available.""" + return has_flashinfer() and importlib.util.find_spec("flashinfer.comm") is not None + 
def _has_module_attrs(required: list[tuple[str, str]]) -> bool:
    """Return `True` iff every (module, attribute) pair resolves.

    Extracted helper: this probe loop was duplicated verbatim in
    `has_flashinfer_all2all` and `has_flashinfer_cutlass_fused_moe`.
    """
    for module_name, attr_name in required:
        mod = _get_submodule(module_name)
        if not mod or not hasattr(mod, attr_name):
            return False
    return True


@functools.cache
def has_flashinfer_all2all() -> bool:
    """Return `True` if FlashInfer mnnvl all2all is available."""
    if not has_flashinfer_comm():
        return False

    # Check if all required functions are available
    return _has_module_attrs(
        [
            ("flashinfer.comm", "Mapping"),
            ("flashinfer.comm.mnnvl", "MnnvlMemory"),
            ("flashinfer.comm.trtllm_alltoall", "MnnvlMoe"),
            ("flashinfer.comm.trtllm_alltoall", "MoEAlltoallInfo"),
        ]
    )


@functools.cache
def has_flashinfer_moe() -> bool:
    """Return `True` if FlashInfer MoE module is available."""
    return (
        has_flashinfer()
        and importlib.util.find_spec("flashinfer.fused_moe") is not None
    )


@functools.cache
def has_flashinfer_cutlass_fused_moe() -> bool:
    """Return `True` if FlashInfer CUTLASS fused MoE is available."""
    if not has_flashinfer_moe():
        return False

    # Check if all required functions are available
    return _has_module_attrs(
        [
            ("flashinfer.fused_moe", "cutlass_fused_moe"),
            ("flashinfer", "fp4_quantize"),
            ("flashinfer", "nvfp4_block_scale_interleave"),
            ("flashinfer.fused_moe", "trtllm_fp4_block_scale_moe"),
        ]
    )


@functools.cache
def has_nvidia_artifactory() -> bool:
    """Return `True` if NVIDIA's artifactory is accessible.

    This checks connectivity to the kernel inference library artifactory
    which is required for downloading certain cubin kernels like TRTLLM FHMA.
    """
    # If we have pre-downloaded cubins, we can assume the cubins are available.
    if has_flashinfer_cubin():
        return True

    try:
        # Use a short timeout to avoid blocking for too long
        response = requests.get(FLASHINFER_CUBINS_REPOSITORY, timeout=5)
        accessible = response.status_code == 200
        if accessible:
            logger.debug_once("NVIDIA artifactory is accessible")
        else:
            logger.warning_once(
                "NVIDIA artifactory returned failed status code: %d",
                response.status_code,
            )
        return accessible
    except Exception as e:
        logger.warning_once("Failed to connect to NVIDIA artifactory: %s", e)
        return False


@functools.cache
def supports_trtllm_attention() -> bool:
    """
    TRTLLM attention is supported if the platform is SM100,
    NVIDIA artifactory is accessible, and batch-invariant mode is not enabled.
    """
    # Batch-invariant mode disables TRTLLM attention
    if vllm_is_batch_invariant():
        return False

    # Requires SM100 and NVIDIA artifactory to be accessible to download cubins
    return current_platform.is_device_capability(100) and has_nvidia_artifactory()


@functools.cache
def _force_use_trtllm_attention(env_value: bool | None) -> bool | None:
    """Cache the env value for VLLM_USE_TRTLLM_ATTENTION"""
    if env_value is not None:
        logger.info_once("VLLM_USE_TRTLLM_ATTENTION is set to %s", env_value)
    return env_value


def force_use_trtllm_attention() -> bool | None:
    """
    Return `None` if VLLM_USE_TRTLLM_ATTENTION is not set,
    return `True` if TRTLLM attention is forced to be used,
    return `False` if TRTLLM attention is forced to be not used.
    """
    return _force_use_trtllm_attention(envs.VLLM_USE_TRTLLM_ATTENTION)


def can_use_trtllm_attention(num_qo_heads: int, num_kv_heads: int) -> bool:
    """Check if the current configuration supports TRTLLM attention."""
    if force_use_trtllm_attention() is False:
        return False
    has_trtllm = supports_trtllm_attention()
    return has_trtllm and (num_qo_heads % num_kv_heads == 0)


def use_trtllm_attention(
    num_qo_heads: int,
    num_kv_heads: int,
    num_tokens: int,
    max_seq_len: int,
    dcp_world_size: int,
    kv_cache_dtype: str,
    q_dtype: torch.dtype,
    is_prefill: bool,
    has_sinks: bool = False,
    has_spec: bool = False,
) -> bool:
    """Return `True` if TRTLLM attention is used."""
    force_use_trtllm = force_use_trtllm_attention()

    # Environment variable is set to 0 - respect it
    if force_use_trtllm is not None and not force_use_trtllm:
        return False

    # Decode context parallel is not supported
    if dcp_world_size > 1:
        logger.warning_once(
            "Trtllm does not support returning LSE and as a result "
            "does not support DCP, reverting to FlashInfer"
        )
        return False

    # The platform is not supported
    if not supports_trtllm_attention():
        if force_use_trtllm:
            logger.warning_once(
                "TRTLLM attention is not supported on this platform, "
                "but VLLM_USE_TRTLLM_ATTENTION is set to 1"
            )
        return False

    # The combination of query and key heads is not supported
    if num_qo_heads % num_kv_heads != 0:
        if force_use_trtllm:
            logger.warning_once(
                "TRTLLM attention is not supported for this combination of "
                "query and key heads, but VLLM_USE_TRTLLM_ATTENTION is set to 1"
            )
        return False

    if has_spec and not is_prefill:
        # Speculative decoding requires TRTLLM attention for decodes
        logger.info_once("Using TRTLLM attention (enabled for speculative decoding).")
        return True

    # Must use TRTLLM attention if query is FP8 quantized
    if q_dtype == current_platform.fp8_dtype():
        logger.info_once("Using TRTLLM attention (query is quantized).")
        return True

    # If sinks are being used, we must use TRTLLM attention as it's
    # the only backend that supports them
    if has_sinks:
        logger.info_once("Using TRTLLM attention (required for attention sinks).")
        return True

    if force_use_trtllm is None:
        # Environment variable not set - use auto-detection
        if is_prefill:
            # Prefill auto-detection
            use_trtllm = max_seq_len <= 131072 and kv_cache_dtype == "auto"
            if use_trtllm:
                logger.warning_once("Using TRTLLM prefill attention (auto-detected).")
        else:
            # Decode auto-detection
            use_trtllm = (
                num_tokens <= 256 and max_seq_len <= 131072 and kv_cache_dtype == "auto"
            )
            if use_trtllm:
                logger.warning_once("Using TRTLLM decode attention (auto-detected).")
        return use_trtllm

    # Environment variable is set to 1 - respect it
    logger.info_once("Using TRTLLM attention (VLLM_USE_TRTLLM_ATTENTION is set to 1)")
    return True


if has_flashinfer():
    # Register FlashInfer-backed custom ops (and fake impls for tracing)
    # only when the package is importable.

    @torch.library.custom_op(
        "vllm::flashinfer_mm_fp4",
        mutates_args=[],
        device_types="cuda",
    )
    def flashinfer_mm_fp4(
        A: torch.Tensor,
        B: torch.Tensor,
        A_scale: torch.Tensor,
        B_scale: torch.Tensor,
        g_scale: torch.Tensor,
        dtype: torch.dtype,
        backend: str,
    ) -> torch.Tensor:
        from flashinfer import mm_fp4 as flashinfer_mm_fp4_

        return flashinfer_mm_fp4_(
            A, B, A_scale, B_scale, g_scale, dtype, block_size=16, backend=backend
        )

    @torch.library.register_fake(
        "vllm::flashinfer_mm_fp4",
    )
    def flashinfer_mm_fp4_fake(
        A: torch.Tensor,
        B: torch.Tensor,
        A_scale: torch.Tensor,
        B_scale: torch.Tensor,
        g_scale: torch.Tensor,
        dtype: torch.dtype,
        backend: str,
    ) -> torch.Tensor:
        return torch.empty(A.shape[0], B.shape[1], dtype=dtype, device=A.device)

    @torch.library.custom_op(
        "vllm::bmm_fp8",
        mutates_args=[],
        device_types="cuda",
    )
    def bmm_fp8(
        A: torch.Tensor,
        B: torch.Tensor,
        A_scale: torch.Tensor,
        B_scale: torch.Tensor,
        dtype: torch.dtype,
        backend: str,
    ) -> torch.Tensor:
        from flashinfer import bmm_fp8 as bmm_fp8_

        return bmm_fp8_(A, B, A_scale, B_scale, dtype, None, backend)

    @torch.library.register_fake(
        "vllm::bmm_fp8",
    )
    def bmm_fp8_fake(
        A: torch.Tensor,
        B: torch.Tensor,
        A_scale: torch.Tensor,
        B_scale: torch.Tensor,
        dtype: torch.dtype,
        backend: str,
    ) -> torch.Tensor:
        return torch.empty(
            A.shape[0], A.shape[1], B.shape[2], dtype=dtype, device=A.device
        )


def flashinfer_scaled_fp4_mm(
    a: torch.Tensor,
    b: torch.Tensor,
    block_scale_a: torch.Tensor,
    block_scale_b: torch.Tensor,
    alpha: torch.Tensor,
    out_dtype: torch.dtype,
    backend: str,
) -> torch.Tensor:
    """FP4 block-scaled matmul via the `vllm::flashinfer_mm_fp4` op."""
    assert a.ndim == 2 and b.ndim == 2
    assert block_scale_a.ndim == 2 and block_scale_b.ndim == 2
    assert a.stride(-1) == 1 and b.stride(-1) == 1
    assert a.shape[1] == b.shape[1]

    if backend == "cutlass":
        # cutlass backend expects raw byte views of the scales
        block_scale_a = block_scale_a.view(torch.uint8)
        block_scale_b = block_scale_b.view(torch.uint8)

    return flashinfer_mm_fp4(
        a,
        b.t(),
        block_scale_a,
        block_scale_b.t(),
        alpha,
        out_dtype,
        backend=backend,
    )


def flashinfer_scaled_fp8_mm(
    a: torch.Tensor,
    b: torch.Tensor,
    scale_a: torch.Tensor,
    scale_b: torch.Tensor,
    out_dtype: torch.dtype,
    bias: torch.Tensor | None = None,
) -> torch.Tensor:
    """FP8 scaled matmul via the `vllm::bmm_fp8` op (batch dim of 1)."""
    assert a.ndim == 2 and b.ndim == 2
    assert a.shape[1] == b.shape[0]
    assert scale_a.numel() == 1 and scale_b.numel() == 1
    assert a.dtype == torch.float8_e4m3fn and b.dtype == torch.float8_e4m3fn
    assert a.device.type == "cuda" and b.device.type == "cuda"
    assert scale_a.dtype == torch.float32 and scale_b.dtype == torch.float32
    assert scale_a.device.type == "cuda" and scale_b.device.type == "cuda"

    output = bmm_fp8(
        a.unsqueeze(0),
        b.unsqueeze(0),
        scale_a,
        scale_b,
        out_dtype,
        "auto",
    ).view(a.shape[0], b.shape[1])

    if bias is not None:
        output = output + bias
    return output


@functools.cache
def flashinfer_disable_q_quantization() -> bool:
    """Cache result which only depends on the environment"""
    return envs.VLLM_FLASHINFER_DISABLE_Q_QUANTIZATION


__all__ = [
    "has_flashinfer",
    "flashinfer_trtllm_fp8_block_scale_moe",
    "flashinfer_cutlass_fused_moe",
    "flashinfer_fp4_quantize",
    "nvfp4_block_scale_interleave",
    "trtllm_fp4_block_scale_moe",
    "autotune",
    "has_flashinfer_moe",
    "has_flashinfer_comm",
    "has_flashinfer_all2all",
    "has_flashinfer_cutlass_fused_moe",
    "has_nvidia_artifactory",
    "supports_trtllm_attention",
    "can_use_trtllm_attention",
    "use_trtllm_attention",
    "flashinfer_disable_q_quantization",
    "flashinfer_scaled_fp4_mm",
    "flashinfer_scaled_fp8_mm",
]


# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Contains helpers that are applied to functions.

This is similar in concept to the `functools` module.
"""
+""" + +import inspect +import threading +import warnings +from collections.abc import Callable, Mapping +from functools import lru_cache, partial, wraps +from typing import Any, TypeVar + +from typing_extensions import ParamSpec + +from vllm.logger import init_logger + +logger = init_logger(__name__) + + +P = ParamSpec("P") +T = TypeVar("T") +F = TypeVar("F", bound=Callable[..., Any]) + + +def identity(value: T, **kwargs) -> T: + """Returns the first provided value.""" + return value + + +def run_once(f: Callable[P, None]) -> Callable[P, None]: + def wrapper(*args: P.args, **kwargs: P.kwargs) -> None: + if wrapper.has_run: # type: ignore[attr-defined] + return + + with wrapper.lock: # type: ignore[attr-defined] + if not wrapper.has_run: # type: ignore[attr-defined] + wrapper.has_run = True # type: ignore[attr-defined] + return f(*args, **kwargs) + + wrapper.has_run = False # type: ignore[attr-defined] + wrapper.lock = threading.Lock() # type: ignore[attr-defined] + return wrapper + + +def deprecate_args( + start_index: int, + is_deprecated: bool | Callable[[], bool] = True, + additional_message: str | None = None, +) -> Callable[[F], F]: + if not callable(is_deprecated): + is_deprecated = partial(identity, is_deprecated) + + def wrapper(fn: F) -> F: + params = inspect.signature(fn).parameters + pos_types = ( + inspect.Parameter.POSITIONAL_ONLY, + inspect.Parameter.POSITIONAL_OR_KEYWORD, + ) + pos_kws = [kw for kw, param in params.items() if param.kind in pos_types] + + @wraps(fn) + def inner(*args, **kwargs): + if is_deprecated(): + deprecated_args = pos_kws[start_index : len(args)] + if deprecated_args: + msg = ( + f"The positional arguments {deprecated_args} are " + "deprecated and will be removed in a future update." 
+ ) + if additional_message is not None: + msg += f" {additional_message}" + + warnings.warn( + DeprecationWarning(msg), + stacklevel=3, # The inner function takes up one level + ) + + return fn(*args, **kwargs) + + return inner # type: ignore + + return wrapper + + +def deprecate_kwargs( + *kws: str, + is_deprecated: bool | Callable[[], bool] = True, + additional_message: str | None = None, +) -> Callable[[F], F]: + deprecated_kws = set(kws) + + if not callable(is_deprecated): + is_deprecated = partial(identity, is_deprecated) + + def wrapper(fn: F) -> F: + @wraps(fn) + def inner(*args, **kwargs): + if is_deprecated(): + deprecated_kwargs = kwargs.keys() & deprecated_kws + if deprecated_kwargs: + msg = ( + f"The keyword arguments {deprecated_kwargs} are " + "deprecated and will be removed in a future update." + ) + if additional_message is not None: + msg += f" {additional_message}" + + warnings.warn( + DeprecationWarning(msg), + stacklevel=3, # The inner function takes up one level + ) + + return fn(*args, **kwargs) + + return inner # type: ignore + + return wrapper + + +@lru_cache +def supports_kw( + callable: Callable[..., object], + kw_name: str, + *, + requires_kw_only: bool = False, + allow_var_kwargs: bool = True, +) -> bool: + """Check if a keyword is a valid kwarg for a callable; if requires_kw_only + disallows kwargs names that can also be positional arguments. 
+ """ + params = inspect.signature(callable).parameters + if not params: + return False + + param_val = params.get(kw_name) + + # Types where the it may be valid, i.e., explicitly defined & nonvariadic + passable_kw_types = set( + ( + inspect.Parameter.POSITIONAL_ONLY, + inspect.Parameter.POSITIONAL_OR_KEYWORD, + inspect.Parameter.KEYWORD_ONLY, + ) + ) + + if param_val: + is_sig_param = param_val.kind in passable_kw_types + # We want kwargs only, but this is passable as a positional arg + if ( + requires_kw_only + and is_sig_param + and param_val.kind != inspect.Parameter.KEYWORD_ONLY + ): + return False + if (requires_kw_only and param_val.kind == inspect.Parameter.KEYWORD_ONLY) or ( + not requires_kw_only and is_sig_param + ): + return True + + # If we're okay with var-kwargs, it's supported as long as + # the kw_name isn't something like *args, **kwargs + if allow_var_kwargs: + # Get the last param; type is ignored here because params is a proxy + # mapping, but it wraps an ordered dict, and they appear in order. + # Ref: https://docs.python.org/3/library/inspect.html#inspect.Signature.parameters + last_param = params[next(reversed(params))] # type: ignore + return ( + last_param.kind == inspect.Parameter.VAR_KEYWORD + and last_param.name != kw_name + ) + + return False + + +def get_allowed_kwarg_only_overrides( + callable: Callable[..., object], + overrides: Mapping[str, object] | None, + *, + requires_kw_only: bool = True, + allow_var_kwargs: bool = False, +) -> dict[str, Any]: + """ + Given a callable which has one or more keyword only params and a dict + mapping param names to values, drop values that can be not be kwarg + expanded to overwrite one or more keyword-only args. This is used in a + few places to handle custom processor overrides for multimodal models, + e.g., for profiling when processor options provided by the user + may affect the number of mm tokens per instance. + + Args: + callable: Callable which takes 0 or more keyword only arguments. 
+ If None is provided, all overrides names are allowed. + overrides: Potential overrides to be used when invoking the callable. + allow_var_kwargs: Allows overrides that are expandable for var kwargs. + + Returns: + Dictionary containing the kwargs to be leveraged which may be used + to overwrite one or more keyword only arguments when invoking the + callable. + """ + if not overrides: + return {} + + # Drop any mm_processor_kwargs provided by the user that + # are not kwargs, unless it can fit it var_kwargs param + filtered_overrides = { + kwarg_name: val + for kwarg_name, val in overrides.items() + if supports_kw( + callable, + kwarg_name, + requires_kw_only=requires_kw_only, + allow_var_kwargs=allow_var_kwargs, + ) + } + + # If anything is dropped, log a warning + dropped_keys = overrides.keys() - filtered_overrides.keys() + if dropped_keys: + if requires_kw_only: + logger.warning( + "The following intended overrides are not keyword-only args " + "and will be dropped: %s", + dropped_keys, + ) + else: + logger.warning( + "The following intended overrides are not keyword args " + "and will be dropped: %s", + dropped_keys, + ) + + return filtered_overrides diff --git a/vllm/_utils/gc_utils.py b/vllm/_utils/gc_utils.py new file mode 100644 index 0000000..160ac9a --- /dev/null +++ b/vllm/_utils/gc_utils.py @@ -0,0 +1,147 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import gc +import json +import time +from collections import Counter +from contextlib import suppress +from typing import Any + +import vllm.envs as envs +from vllm.logger import init_logger + +logger = init_logger(__name__) + + +class GCDebugConfig: + """ + Config for GC Debugger. 
+ - 0: disable GC debugger + - 1: enable GC debugger with gc.collect elpased times + - '{"top_objects":5}': enable GC debugger with top 5 collected objects + """ + + def __init__(self, gc_debug_conf: str | None = None) -> None: + self.enabled: bool = False + self.top_objects: int = -1 + + if not gc_debug_conf or gc_debug_conf == "0": + pass + elif gc_debug_conf == "1": + self.enabled = True + else: + try: + json_conf = json.loads(gc_debug_conf) + self.enabled = True + self.top_objects = json_conf.get("top_objects", -1) + except Exception: + self.enabled = False + logger.error("Failed to parse VLLM_GC_DEBUG(%s)", envs.VLLM_GC_DEBUG) + logger.debug("GC Debug Config. %s", str(self)) + + def __repr__(self) -> str: + return f"enabled:{self.enabled},top_objects:{self.top_objects}" + + +class GCDebugger: + """ + Debugger for GC which logs helpful information for GC understanding. + To enable, you should call maybe_attach_gc_debug_callback in the process. + """ + + def __init__(self, config: GCDebugConfig) -> None: + self.config = config + # Start time in micro second of this GC cycle + self.start_time_ns: int = time.monotonic_ns() + # If config.top_objects is positive, + # compute top collected objects by object types + self.gc_top_collected_objects: str = "" + + def handle(self, phase: str, info: dict[str, int]) -> None: + """ + Handles a GC event (e.g. GC start or GC finish) + """ + generation = info.get("generation") + if generation is None: + return + if phase == "start": + # Before GC started, record GC start time + # and top collected objects + self.start_time_ns = time.monotonic_ns() + self.gc_top_collected_objects = _compute_top_gc_collected_objects( + gc.get_objects(generation), self.config.top_objects + ) + elif phase == "stop": + # After GC finished, Record GC elapsed time and + # optionally top collected objects + elpased_ms = (time.monotonic_ns() - self.start_time_ns) / 1e6 + logger.info( + "GC took %.3fms to complete. 
" + "Collected %s objects in GC generation %d.%s", + elpased_ms, + str(info.get("collected", "?")), + generation, + ( + f" Top collected objects: \n{self.gc_top_collected_objects}" + if self.gc_top_collected_objects + else "" + ), + ) + + +def freeze_gc_heap() -> None: + """ + Freeze all objects tracked by the garbage collector. It should be invoked + after server init / warmup, to reduce GC overhead from static objects + during serving time. + """ + # Ensure all static objects are pushed down to the oldest generation for + # freeze + gc.collect(0) + gc.collect(1) + gc.collect(2) + # Freeze all GC tracked objects + gc.freeze() + + +def maybe_attach_gc_debug_callback() -> None: + """ + Attached a callback for GC debug when VLLM_GC_DEBUG is enabled. + """ + config = GCDebugConfig(envs.VLLM_GC_DEBUG) + if config.enabled: + debugger: GCDebugger = GCDebugger(config) + + def gc_callback(phase: str, info: dict[str, int]) -> None: + debugger.handle(phase, info) + + gc.callbacks.append(gc_callback) + + +def _compute_detailed_type(o: Any) -> str: + """ + Detailed object type. + + TODO(Jialin): Further enhance the detailed type with element types for + easier debugging. We tried but occasionally it would run into signals + which kills the engine. + """ + size_str: str = "" + # Object doesn't support len() - this can happen with type objects + # or other objects that don't implement __len__ properly + with suppress(Exception): + size_str = f"(size:{len(o)})" + return f"{str(type(o))}{size_str}" + + +def _compute_top_gc_collected_objects(objects: list[Any], top: int) -> str: + """ + Group collected objects by types. 
+ """ + if top <= 0: + return "" + object_types = [_compute_detailed_type(o) for o in objects] + return "\n".join( + f"{count:>5}:{object_type}" + for object_type, count in Counter(object_types).most_common(top) + ) diff --git a/vllm/_utils/hashing.py b/vllm/_utils/hashing.py new file mode 100644 index 0000000..49f4f13 --- /dev/null +++ b/vllm/_utils/hashing.py @@ -0,0 +1,63 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from __future__ import annotations + +import hashlib +import pickle +from collections.abc import Callable +from typing import Any + +import cbor2 + + +def sha256(input: Any) -> bytes: + """Hash any picklable Python object using SHA-256. + + The input is serialized using pickle before hashing, which allows + arbitrary Python objects to be used. Note that this function does + not use a hash seed—if you need one, prepend it explicitly to the input. + + Args: + input: Any picklable Python object. + + Returns: + Bytes representing the SHA-256 hash of the serialized input. + """ + input_bytes = pickle.dumps(input, protocol=pickle.HIGHEST_PROTOCOL) + return hashlib.sha256(input_bytes).digest() + + +def sha256_cbor(input: Any) -> bytes: + """Hash objects using CBOR serialization and SHA-256. + + This option is useful for non-Python-dependent serialization and hashing. + + Args: + input: Object to be serialized and hashed. Supported types include + basic Python types and complex structures like lists, tuples, and + dictionaries. + Custom classes must implement CBOR serialization methods. + + Returns: + Bytes representing the SHA-256 hash of the CBOR serialized input. + """ + input_bytes = cbor2.dumps(input, canonical=True) + return hashlib.sha256(input_bytes).digest() + + +def get_hash_fn_by_name(hash_fn_name: str) -> Callable[[Any], bytes]: + """Get a hash function by name, or raise an error if the function is not found. + + Args: + hash_fn_name: Name of the hash function. 
+ + Returns: + A hash function. + """ + if hash_fn_name == "sha256": + return sha256 + if hash_fn_name == "sha256_cbor": + return sha256_cbor + + raise ValueError(f"Unsupported hash function: {hash_fn_name}") diff --git a/vllm/_utils/import_utils.py b/vllm/_utils/import_utils.py new file mode 100644 index 0000000..f01d2c7 --- /dev/null +++ b/vllm/_utils/import_utils.py @@ -0,0 +1,411 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +""" +Contains helpers related to importing modules. + +This is similar in concept to the `importlib` module. +""" + +import importlib.metadata +import importlib.util +import os +import sys +from functools import cache +from types import ModuleType +from typing import Any + +import regex as re +from typing_extensions import Never + + +# TODO: This function can be removed if transformer_modules classes are +# serialized by value when communicating between processes +def init_cached_hf_modules() -> None: + """ + Lazy initialization of the Hugging Face modules. + """ + from transformers.dynamic_module_utils import init_hf_modules + + init_hf_modules() + + +def import_pynvml(): + """ + Historical comments: + + libnvml.so is the library behind nvidia-smi, and + pynvml is a Python wrapper around it. We use it to get GPU + status without initializing CUDA context in the current process. + Historically, there are two packages that provide pynvml: + - `nvidia-ml-py` (https://pypi.org/project/nvidia-ml-py/): The official + wrapper. It is a dependency of vLLM, and is installed when users + install vLLM. It provides a Python module named `pynvml`. + - `pynvml` (https://pypi.org/project/pynvml/): An unofficial wrapper. + Prior to version 12.0, it also provides a Python module `pynvml`, + and therefore conflicts with the official one. 
What's worse, + the module is a Python package, and has higher priority than + the official one which is a standalone Python file. + This causes errors when both of them are installed. + Starting from version 12.0, it migrates to a new module + named `pynvml_utils` to avoid the conflict. + It is so confusing that many packages in the community use the + unofficial one by mistake, and we have to handle this case. + For example, `nvcr.io/nvidia/pytorch:24.12-py3` uses the unofficial + one, and it will cause errors, see the issue + https://github.com/vllm-project/vllm/issues/12847 for example. + After all the troubles, we decide to copy the official `pynvml` + module to our codebase, and use it directly. + """ + import vllm.third_party.pynvml as pynvml + + return pynvml + + +def import_from_path(module_name: str, file_path: str | os.PathLike): + """ + Import a Python file according to its file path. + + Based on the official recipe: + https://docs.python.org/3/library/importlib.html#importing-a-source-file-directly + """ + spec = importlib.util.spec_from_file_location(module_name, file_path) + if spec is None: + raise ModuleNotFoundError(f"No module named {module_name!r}") + + assert spec.loader is not None + + module = importlib.util.module_from_spec(spec) + sys.modules[module_name] = module + spec.loader.exec_module(module) + return module + + +def resolve_obj_by_qualname(qualname: str) -> Any: + """ + Resolve an object by its fully-qualified class name. 
+ """ + module_name, obj_name = qualname.rsplit(".", 1) + module = importlib.import_module(module_name) + return getattr(module, obj_name) + + +@cache +def get_vllm_optional_dependencies(): + metadata = importlib.metadata.metadata("vllm") + requirements = metadata.get_all("Requires-Dist", []) + extras = metadata.get_all("Provides-Extra", []) + + return { + extra: [ + re.split(r";|>=|<=|==", req)[0] + for req in requirements + if req.endswith(f'extra == "{extra}"') + ] + for extra in extras + } + + +class _PlaceholderBase: + """ + Disallows downstream usage of placeholder modules. + + We need to explicitly override each dunder method because + [`__getattr__`][vllm.utils.import_utils._PlaceholderBase.__getattr__] + is not called when they are accessed. + + Info: + [Special method lookup](https://docs.python.org/3/reference/datamodel.html#special-lookup) + """ + + def __getattr__(self, key: str) -> Never: + """ + The main class should implement this to throw an error + for attribute accesses representing downstream usage. 
+ """ + raise NotImplementedError + + # [Basic customization] + + def __lt__(self, other: object): + return self.__getattr__("__lt__") + + def __le__(self, other: object): + return self.__getattr__("__le__") + + def __eq__(self, other: object): + return self.__getattr__("__eq__") + + def __ne__(self, other: object): + return self.__getattr__("__ne__") + + def __gt__(self, other: object): + return self.__getattr__("__gt__") + + def __ge__(self, other: object): + return self.__getattr__("__ge__") + + def __hash__(self): + return self.__getattr__("__hash__") + + def __bool__(self): + return self.__getattr__("__bool__") + + # [Callable objects] + + def __call__(self, *args: object, **kwargs: object): + return self.__getattr__("__call__") + + # [Container types] + + def __len__(self): + return self.__getattr__("__len__") + + def __getitem__(self, key: object): + return self.__getattr__("__getitem__") + + def __setitem__(self, key: object, value: object): + return self.__getattr__("__setitem__") + + def __delitem__(self, key: object): + return self.__getattr__("__delitem__") + + # __missing__ is optional according to __getitem__ specification, + # so it is skipped + + # __iter__ and __reversed__ have a default implementation + # based on __len__ and __getitem__, so they are skipped. 
+ + # [Numeric Types] + + def __add__(self, other: object): + return self.__getattr__("__add__") + + def __sub__(self, other: object): + return self.__getattr__("__sub__") + + def __mul__(self, other: object): + return self.__getattr__("__mul__") + + def __matmul__(self, other: object): + return self.__getattr__("__matmul__") + + def __truediv__(self, other: object): + return self.__getattr__("__truediv__") + + def __floordiv__(self, other: object): + return self.__getattr__("__floordiv__") + + def __mod__(self, other: object): + return self.__getattr__("__mod__") + + def __divmod__(self, other: object): + return self.__getattr__("__divmod__") + + def __pow__(self, other: object, modulo: object = ...): + return self.__getattr__("__pow__") + + def __lshift__(self, other: object): + return self.__getattr__("__lshift__") + + def __rshift__(self, other: object): + return self.__getattr__("__rshift__") + + def __and__(self, other: object): + return self.__getattr__("__and__") + + def __xor__(self, other: object): + return self.__getattr__("__xor__") + + def __or__(self, other: object): + return self.__getattr__("__or__") + + # r* and i* methods have lower priority than + # the methods for left operand so they are skipped + + def __neg__(self): + return self.__getattr__("__neg__") + + def __pos__(self): + return self.__getattr__("__pos__") + + def __abs__(self): + return self.__getattr__("__abs__") + + def __invert__(self): + return self.__getattr__("__invert__") + + # __complex__, __int__ and __float__ have a default implementation + # based on __index__, so they are skipped. 
+ + def __index__(self): + return self.__getattr__("__index__") + + def __round__(self, ndigits: object = ...): + return self.__getattr__("__round__") + + def __trunc__(self): + return self.__getattr__("__trunc__") + + def __floor__(self): + return self.__getattr__("__floor__") + + def __ceil__(self): + return self.__getattr__("__ceil__") + + # [Context managers] + + def __enter__(self): + return self.__getattr__("__enter__") + + def __exit__(self, *args: object, **kwargs: object): + return self.__getattr__("__exit__") + + +class PlaceholderModule(_PlaceholderBase): + """ + A placeholder object to use when a module does not exist. + + This enables more informative errors when trying to access attributes + of a module that does not exist. + """ + + def __init__(self, name: str) -> None: + super().__init__() + + # Apply name mangling to avoid conflicting with module attributes + self.__name = name + + def placeholder_attr(self, attr_path: str): + return _PlaceholderModuleAttr(self, attr_path) + + def __getattr__(self, key: str) -> Never: + name = self.__name + + try: + importlib.import_module(name) + except ImportError as exc: + for extra, names in get_vllm_optional_dependencies().items(): + if name in names: + msg = f"Please install vllm[{extra}] for {extra} support" + raise ImportError(msg) from exc + + raise exc + + raise AssertionError( + "PlaceholderModule should not be used " + "when the original module can be imported" + ) + + +class _PlaceholderModuleAttr(_PlaceholderBase): + def __init__(self, module: PlaceholderModule, attr_path: str) -> None: + super().__init__() + + # Apply name mangling to avoid conflicting with module attributes + self.__module = module + self.__attr_path = attr_path + + def placeholder_attr(self, attr_path: str): + return _PlaceholderModuleAttr(self.__module, f"{self.__attr_path}.{attr_path}") + + def __getattr__(self, key: str) -> Never: + getattr(self.__module, f"{self.__attr_path}.{key}") + + raise AssertionError( + 
"PlaceholderModule should not be used " + "when the original module can be imported" + ) + + +class LazyLoader(ModuleType): + """ + `LazyLoader` module borrowed from [Tensorflow] + (https://github.com/tensorflow/tensorflow/blob/main/tensorflow/python/util/lazy_loader.py) + with an addition of "module caching". + + Lazily import a module, mainly to avoid pulling in large dependencies. + Modules such as `xgrammar` might do additional side effects, so we + only want to use this when it is needed, delaying all eager effects. + """ + + def __init__( + self, + local_name: str, + parent_module_globals: dict[str, Any], + name: str, + ): + self._local_name = local_name + self._parent_module_globals = parent_module_globals + self._module: ModuleType | None = None + + super().__init__(str(name)) + + def _load(self) -> ModuleType: + # Import the target module and insert it into the parent's namespace + try: + module = importlib.import_module(self.__name__) + self._parent_module_globals[self._local_name] = module + # The additional add to sys.modules + # ensures library is actually loaded. + sys.modules[self._local_name] = module + except ModuleNotFoundError as err: + raise err from None + + # Update this object's dict so that if someone keeps a + # reference to the LazyLoader, lookups are efficient + # (__getattr__ is only called on lookups that fail). + self.__dict__.update(module.__dict__) + return module + + def __getattr__(self, item: Any) -> Any: + if self._module is None: + self._module = self._load() + return getattr(self._module, item) + + def __dir__(self) -> list[str]: + if self._module is None: + self._module = self._load() + return dir(self._module) + + +# Optional dependency detection utilities +@cache +def _has_module(module_name: str) -> bool: + """Return True if *module_name* can be found in the current environment. + + The result is cached so that subsequent queries for the same module incur + no additional overhead. 
+ """ + return importlib.util.find_spec(module_name) is not None + + +def has_pplx() -> bool: + """Whether the optional `pplx_kernels` package is available.""" + return _has_module("pplx_kernels") + + +def has_deep_ep() -> bool: + """Whether the optional `deep_ep` package is available.""" + return _has_module("deep_ep") + + +def has_deep_gemm() -> bool: + """Whether the optional `deep_gemm` package is available.""" + return _has_module("deep_gemm") + + +def has_triton_kernels() -> bool: + """Whether the optional `triton_kernels` package is available.""" + return _has_module("triton_kernels") + + +def has_tilelang() -> bool: + """Whether the optional `tilelang` package is available.""" + return _has_module("tilelang") + + +def has_arctic_inference() -> bool: + """Whether the optional `arctic_inference` package is available.""" + + return _has_module("arctic_inference") diff --git a/vllm/_utils/jsontree.py b/vllm/_utils/jsontree.py new file mode 100644 index 0000000..cde9aa6 --- /dev/null +++ b/vllm/_utils/jsontree.py @@ -0,0 +1,165 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Helper functions to work with nested JSON structures.""" + +from collections.abc import Callable, Iterable +from functools import reduce +from typing import TYPE_CHECKING, TypeAlias, TypeVar, cast, overload + +if TYPE_CHECKING: + import torch + + from vllm.multimodal.inputs import BatchedTensorInputs + +_T = TypeVar("_T") +_U = TypeVar("_U") + +JSONTree: TypeAlias = ( + dict[str, "JSONTree[_T]"] | list["JSONTree[_T]"] | tuple["JSONTree[_T]", ...] | _T +) +"""A nested JSON structure where the leaves need not be JSON-serializable.""" + +_JSONTree: TypeAlias = ( + dict[str, "JSONTree[_T]"] + | list["JSONTree[_T]"] + | tuple["JSONTree[_T]", ...] + | dict[str, _T] + | list[_T] + | tuple[_T, ...] + | _T +) +""" +Same as `JSONTree` but with additional `Union` members to satisfy overloads. 
+""" + + +def json_iter_leaves(value: JSONTree[_T]) -> Iterable[_T]: + """Iterate through each leaf in a nested JSON structure.""" + if isinstance(value, dict): + for v in value.values(): + yield from json_iter_leaves(v) + elif isinstance(value, (list, tuple)): + for v in value: + yield from json_iter_leaves(v) + else: + yield value + + +@overload +def json_map_leaves( + func: Callable[["torch.Tensor"], "torch.Tensor"], + value: "BatchedTensorInputs", +) -> "BatchedTensorInputs": ... + + +@overload +def json_map_leaves( + func: Callable[[_T], _U], + value: _T | dict[str, _T], +) -> _U | dict[str, _U]: ... + + +@overload +def json_map_leaves( + func: Callable[[_T], _U], + value: _T | list[_T], +) -> _U | list[_U]: ... + + +@overload +def json_map_leaves( + func: Callable[[_T], _U], + value: _T | tuple[_T, ...], +) -> _U | tuple[_U, ...]: ... + + +@overload +def json_map_leaves( + func: Callable[[_T], _U], + value: JSONTree[_T], +) -> JSONTree[_U]: ... + + +def json_map_leaves( + func: Callable[[_T], _U], + value: "BatchedTensorInputs" | _JSONTree[_T], +) -> "BatchedTensorInputs" | _JSONTree[_U]: + """Apply a function to each leaf in a nested JSON structure.""" + if isinstance(value, dict): + return { + k: json_map_leaves(func, v) # type: ignore[arg-type] + for k, v in value.items() + } + elif isinstance(value, list): + return [json_map_leaves(func, v) for v in value] + elif isinstance(value, tuple): + return tuple(json_map_leaves(func, v) for v in value) + else: + return func(value) + + +@overload +def json_reduce_leaves( + func: Callable[[_T, _T], _T], + value: _T | dict[str, _T], + /, +) -> _T: ... + + +@overload +def json_reduce_leaves( + func: Callable[[_T, _T], _T], + value: _T | list[_T], + /, +) -> _T: ... + + +@overload +def json_reduce_leaves( + func: Callable[[_T, _T], _T], + value: _T | tuple[_T, ...], + /, +) -> _T: ... + + +@overload +def json_reduce_leaves( + func: Callable[[_T, _T], _T], + value: JSONTree[_T], + /, +) -> _T: ... 
+ + +@overload +def json_reduce_leaves( + func: Callable[[_U, _T], _U], + value: JSONTree[_T], + initial: _U, + /, +) -> _U: ... + + +def json_reduce_leaves( + func: Callable[..., _T | _U], + value: _JSONTree[_T], + initial: _U = cast(_U, ...), # noqa: B008 + /, +) -> _T | _U: + """ + Apply a function of two arguments cumulatively to each leaf in a + nested JSON structure, from left to right, so as to reduce the + sequence to a single value. + """ + if initial is ...: + return reduce(func, json_iter_leaves(value)) # type: ignore[arg-type] + + return reduce( + func, # type: ignore[arg-type] + json_iter_leaves(value), + initial, + ) + + +def json_count_leaves(value: JSONTree[_T]) -> int: + """Count the number of leaves in a nested JSON structure.""" + return sum(1 for _ in json_iter_leaves(value)) diff --git a/vllm/_utils/math_utils.py b/vllm/_utils/math_utils.py new file mode 100644 index 0000000..bdfa5fd --- /dev/null +++ b/vllm/_utils/math_utils.py @@ -0,0 +1,32 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Math utility functions for vLLM.""" + + +def cdiv(a: int, b: int) -> int: + """Ceiling division.""" + return -(a // -b) + + +def next_power_of_2(n: int) -> int: + """The next power of 2 (inclusive)""" + if n < 1: + return 1 + return 1 << (n - 1).bit_length() + + +def prev_power_of_2(n: int) -> int: + """The previous power of 2 (inclusive)""" + if n <= 0: + return 0 + return 1 << (n.bit_length() - 1) + + +def round_up(x: int, y: int) -> int: + """Round up x to the nearest multiple of y.""" + return ((x + y - 1) // y) * y + + +def round_down(x: int, y: int) -> int: + """Round down x to the nearest multiple of y.""" + return (x // y) * y diff --git a/vllm/_utils/mem_constants.py b/vllm/_utils/mem_constants.py new file mode 100644 index 0000000..62b725f --- /dev/null +++ b/vllm/_utils/mem_constants.py @@ -0,0 +1,13 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright 
contributors to the vLLM project +MB_bytes = 1_000_000 +"""The number of bytes in one megabyte (MB).""" + +MiB_bytes = 1 << 20 +"""The number of bytes in one mebibyte (MiB).""" + +GB_bytes = 1_000_000_000 +"""The number of bytes in one gigabyte (GB).""" + +GiB_bytes = 1 << 30 +"""The number of bytes in one gibibyte (GiB).""" diff --git a/vllm/_utils/mem_utils.py b/vllm/_utils/mem_utils.py new file mode 100644 index 0000000..c6a6757 --- /dev/null +++ b/vllm/_utils/mem_utils.py @@ -0,0 +1,232 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import contextlib +import gc +import time +from collections.abc import Generator +from dataclasses import dataclass, field +from functools import cache + +import psutil +import torch +import torch.types + +from .mem_constants import GiB_bytes + + +@cache +def get_max_shared_memory_bytes(gpu: int = 0) -> int: + """Returns the maximum shared memory per thread block in bytes.""" + from vllm import _custom_ops as ops + + max_shared_mem = ops.get_max_shared_memory_per_block_device_attribute(gpu) + # value 0 will cause MAX_SEQ_LEN become negative and test_attention.py + # will fail + assert max_shared_mem > 0, "max_shared_mem can not be zero" + return int(max_shared_mem) + + +def get_cpu_memory() -> int: + """Returns the total CPU memory of the node in bytes.""" + return psutil.virtual_memory().total + + +class DeviceMemoryProfiler: + def __init__(self, device: torch.types.Device | None = None): + self.device = device + + def current_memory_usage(self) -> float: + # Return the memory usage in bytes. 
+ from vllm.platforms import current_platform + + gc.collect() + return current_platform.get_current_memory_usage(self.device) + + def __enter__(self): + self.initial_memory = self.current_memory_usage() + # This allows us to call methods of the context manager if needed + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + self.final_memory = self.current_memory_usage() + self.consumed_memory = self.final_memory - self.initial_memory + + # Force garbage collection + gc.collect() + + +@dataclass +class MemorySnapshot: + """Memory snapshot.""" + + torch_peak: int = 0 + free_memory: int = 0 + total_memory: int = 0 + cuda_memory: int = 0 + torch_memory: int = 0 + non_torch_memory: int = 0 + timestamp: float = 0.0 + auto_measure: bool = True + + def __post_init__(self): + if self.auto_measure: + self.measure() + + def measure(self): + from vllm.platforms import current_platform + + # we measure the torch peak memory usage via allocated_bytes, + # rather than `torch.cuda.memory_reserved()` . + # After `torch.cuda.reset_peak_memory_stats()`, + # `torch.cuda.memory_reserved()` will keep growing, and only shrink + # when we call `torch.cuda.empty_cache()` or OOM happens. + self.torch_peak = torch.cuda.memory_stats().get("allocated_bytes.all.peak", 0) + + self.free_memory, self.total_memory = torch.cuda.mem_get_info() + shared_sysmem_device_mem_sms = ((8, 7), (11, 0), (12, 1)) # Orin, Thor, Spark + if ( + current_platform.is_cuda() + and current_platform.get_device_capability() in shared_sysmem_device_mem_sms + ): + # On UMA (Orin, Thor and Spark) platform, + # where both CPU and GPU rely on system memory, + # the cudaMemGetInfo function shows the amount of free system memory + # rather than what’s actually available. + # In the case, + # torch.cuda.mem_get_info() only reports "free" memory, + # which can be lower than what is actually + # available due to not including cache memory. 
+ # There’s also a comprehensive reference page + # that explains how you can compute the proper value yourself. + # https://docs.nvidia.com/cuda/cuda-for-tegra-appnote/#estimating-total-allocatable-device-memory-on-an-integrated-gpu-device + self.free_memory = psutil.virtual_memory().available + + self.cuda_memory = self.total_memory - self.free_memory + + # torch.cuda.memory_reserved() is how many bytes + # PyTorch gets from cuda (by calling cudaMalloc, etc.) + # this is used to measure the non-torch memory usage + self.torch_memory = torch.cuda.memory_reserved() + + self.non_torch_memory = self.cuda_memory - self.torch_memory + self.timestamp = time.time() + + def __sub__(self, other: "MemorySnapshot") -> "MemorySnapshot": + return MemorySnapshot( + torch_peak=self.torch_peak - other.torch_peak, + free_memory=self.free_memory - other.free_memory, + total_memory=self.total_memory - other.total_memory, + cuda_memory=self.cuda_memory - other.cuda_memory, + torch_memory=self.torch_memory - other.torch_memory, + non_torch_memory=self.non_torch_memory - other.non_torch_memory, + timestamp=self.timestamp - other.timestamp, + auto_measure=False, + ) + + +@dataclass +class MemoryProfilingResult: + """Memory profiling result. All numbers are in bytes.""" + + non_kv_cache_memory: int = 0 + torch_peak_increase: int = 0 + non_torch_increase: int = 0 + weights_memory: float = 0 + before_create: MemorySnapshot = field(default_factory=MemorySnapshot) + before_profile: MemorySnapshot = field(default_factory=MemorySnapshot) + after_profile: MemorySnapshot = field(default_factory=MemorySnapshot) + profile_time: float = 0.0 + + def __repr__(self) -> str: + return ( + f"Memory profiling takes {self.profile_time:.2f} seconds. 
" + f"Total non KV cache memory: " + f"{(self.non_kv_cache_memory / GiB_bytes):.2f}GiB; " + f"torch peak memory increase: " + f"{(self.torch_peak_increase / GiB_bytes):.2f}GiB; " + f"non-torch forward increase memory: " + f"{(self.non_torch_increase / GiB_bytes):.2f}GiB; " + f"weights memory: {(self.weights_memory / GiB_bytes):.2f}GiB." + ) + + +@contextlib.contextmanager +def memory_profiling( + baseline_snapshot: MemorySnapshot, weights_memory: int +) -> Generator[MemoryProfilingResult, None, None]: + """Memory profiling context manager. + baseline_snapshot: the memory snapshot before the current vLLM instance. + weights_memory: memory used by PyTorch when loading the model weights. + Note that, before loading the model weights, we also initialize the device + and distributed environment, which may consume some memory. This part is not + included in the weights_memory because PyTorch does not control it. + + The memory in one GPU can be classified into 3 categories: + 1. memory used by anything other than the current vLLM instance. + 2. memory used by torch in the current vLLM instance. + 3. memory used in the current vLLM instance, but not by torch. + + A quantitive example: + + Before creating the current vLLM instance: + category 1: 1 GiB + category 2: 0 GiB + category 3: 0 GiB + + After creating the current vLLM instance and loading the model, + (i.e. before profiling): + category 1: 1 GiB + category 2: 2 GiB (model weights take 2 GiB) + category 3: 0.5 GiB (memory used by NCCL) + + During profiling (peak): + category 1: 1 GiB + category 2: 4 GiB (peak activation tensors take 2 GiB) + category 3: 1 GiB (memory used by NCCL + buffers for some attention backends) + + After profiling: + category 1: 1 GiB + category 2: 3 GiB (after garbage-collecting activation tensors) + category 3: 1 GiB (memory used by NCCL + buffers for some attention backends) + + In this case, non-kv cache takes 5 GiB in total, including: + a. 
2 GiB used by the model weights (category 2) + b. 2 GiB reserved for the peak activation tensors (category 2) + c. 1 GiB used by non-torch components (category 3) + + The memory used for loading weights (a.) is directly given from the argument `weights_memory`. + + The increase of `torch.cuda.memory_stats()["allocated_bytes.all.peak"]` during profiling gives (b.). + + The increase of `non_torch_memory` from creating the current vLLM instance until after profiling to get (c.). + """ # noqa + gc.collect() + torch.cuda.empty_cache() + torch.cuda.reset_peak_memory_stats() + + result = MemoryProfilingResult() + + result.before_create = baseline_snapshot + # the part of memory used for holding the model weights + result.weights_memory = weights_memory + + result.before_profile.measure() + + yield result + + gc.collect() + torch.cuda.empty_cache() + + result.after_profile.measure() + + diff_profile = result.after_profile - result.before_profile + diff_from_create = result.after_profile - result.before_create + result.torch_peak_increase = diff_profile.torch_peak + result.non_torch_increase = diff_from_create.non_torch_memory + result.profile_time = diff_profile.timestamp + + non_torch_memory = result.non_torch_increase + peak_activation_memory = result.torch_peak_increase + result.non_kv_cache_memory = ( + non_torch_memory + peak_activation_memory + result.weights_memory + ) # noqa diff --git a/vllm/_utils/nccl.py b/vllm/_utils/nccl.py new file mode 100644 index 0000000..b1459fc --- /dev/null +++ b/vllm/_utils/nccl.py @@ -0,0 +1,64 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from __future__ import annotations + +import importlib +import os + +import torch + +import vllm.envs as envs +from vllm.logger import init_logger + +logger = init_logger(__name__) + + +def find_nccl_library() -> str: + """Return NCCL/RCCL shared library name to load. 
+ + Uses `VLLM_NCCL_SO_PATH` if set; otherwise chooses by torch backend. + """ + so_file = envs.VLLM_NCCL_SO_PATH + if so_file: + logger.info( + "Found nccl from environment variable VLLM_NCCL_SO_PATH=%s", so_file + ) + else: + if torch.version.cuda is not None: + so_file = "libnccl.so.2" + elif torch.version.hip is not None: + so_file = "librccl.so.1" + else: + raise ValueError("NCCL only supports CUDA and ROCm backends.") + logger.debug_once("Found nccl from library %s", so_file) + return so_file + + +def find_nccl_include_paths() -> list[str] | None: + """Return possible include paths containing `nccl.h`. + + Considers `VLLM_NCCL_INCLUDE_PATH` and the `nvidia-nccl-cuXX` package. + """ + paths: list[str] = [] + inc = envs.VLLM_NCCL_INCLUDE_PATH + if inc and os.path.isdir(inc): + paths.append(inc) + + try: + spec = importlib.util.find_spec("nvidia.nccl") + if spec and getattr(spec, "submodule_search_locations", None): + for loc in spec.submodule_search_locations: + inc_dir = os.path.join(loc, "include") + if os.path.exists(os.path.join(inc_dir, "nccl.h")): + paths.append(inc_dir) + except Exception as e: + logger.debug("Failed to find nccl include path from nvidia.nccl package: %s", e) + + seen: set[str] = set() + out: list[str] = [] + for p in paths: + if p and p not in seen: + out.append(p) + seen.add(p) + return out or None diff --git a/vllm/_utils/network_utils.py b/vllm/_utils/network_utils.py new file mode 100644 index 0000000..0a68e48 --- /dev/null +++ b/vllm/_utils/network_utils.py @@ -0,0 +1,331 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import contextlib +import ipaddress +import os +import socket +import sys +import warnings +from collections.abc import ( + Iterator, + Sequence, +) +from typing import Any +from urllib.parse import urlparse +from uuid import uuid4 + +import psutil +import zmq +import zmq.asyncio + +import vllm.envs as envs +from vllm.logger import init_logger + +logger 
= init_logger(__name__) + + +def close_sockets(sockets: Sequence[zmq.Socket | zmq.asyncio.Socket]): + for sock in sockets: + if sock is not None: + sock.close(linger=0) + + +def get_ip() -> str: + host_ip = envs.VLLM_HOST_IP + if "HOST_IP" in os.environ and "VLLM_HOST_IP" not in os.environ: + logger.warning( + "The environment variable HOST_IP is deprecated and ignored, as" + " it is often used by Docker and other software to" + " interact with the container's network stack. Please " + "use VLLM_HOST_IP instead to set the IP address for vLLM processes" + " to communicate with each other." + ) + if host_ip: + return host_ip + + # IP is not set, try to get it from the network interface + + # try ipv4 + try: + with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as s: + s.connect(("8.8.8.8", 80)) # Doesn't need to be reachable + return s.getsockname()[0] + except Exception: + pass + + # try ipv6 + try: + with socket.socket(socket.AF_INET6, socket.SOCK_DGRAM) as s: + # Google's public DNS server, see + # https://developers.google.com/speed/public-dns/docs/using#addresses + s.connect(("2001:4860:4860::8888", 80)) # Doesn't need to be reachable + return s.getsockname()[0] + except Exception: + pass + + warnings.warn( + "Failed to get the IP address, using 0.0.0.0 by default." 
+ "The value can be set by the environment variable" + " VLLM_HOST_IP or HOST_IP.", + stacklevel=2, + ) + return "0.0.0.0" + + +def test_loopback_bind(address, family): + try: + s = socket.socket(family, socket.SOCK_DGRAM) + s.bind((address, 0)) # Port 0 = auto assign + s.close() + return True + except OSError: + return False + + +def get_loopback_ip() -> str: + loopback_ip = envs.VLLM_LOOPBACK_IP + if loopback_ip: + return loopback_ip + + # VLLM_LOOPBACK_IP is not set, try to get it based on network interface + + if test_loopback_bind("127.0.0.1", socket.AF_INET): + return "127.0.0.1" + elif test_loopback_bind("::1", socket.AF_INET6): + return "::1" + else: + raise RuntimeError( + "Neither 127.0.0.1 nor ::1 are bound to a local interface. " + "Set the VLLM_LOOPBACK_IP environment variable explicitly." + ) + + +def is_valid_ipv6_address(address: str) -> bool: + try: + ipaddress.IPv6Address(address) + return True + except ValueError: + return False + + +def split_host_port(host_port: str) -> tuple[str, int]: + # ipv6 + if host_port.startswith("["): + host, port = host_port.rsplit("]", 1) + host = host[1:] + port = port.split(":")[1] + return host, int(port) + else: + host, port = host_port.split(":") + return host, int(port) + + +def join_host_port(host: str, port: int) -> str: + if is_valid_ipv6_address(host): + return f"[{host}]:{port}" + else: + return f"{host}:{port}" + + +def get_distributed_init_method(ip: str, port: int) -> str: + return get_tcp_uri(ip, port) + + +def get_tcp_uri(ip: str, port: int) -> str: + if is_valid_ipv6_address(ip): + return f"tcp://[{ip}]:{port}" + else: + return f"tcp://{ip}:{port}" + + +def get_open_zmq_ipc_path() -> str: + base_rpc_path = envs.VLLM_RPC_BASE_PATH + return f"ipc://{base_rpc_path}/{uuid4()}" + + +def get_open_zmq_inproc_path() -> str: + return f"inproc://{uuid4()}" + + +def get_open_port() -> int: + """ + Get an open port for the vLLM process to listen on. 
+ An edge case to handle, is when we run data parallel, + we need to avoid ports that are potentially used by + the data parallel master process. + Right now we reserve 10 ports for the data parallel master + process. Currently it uses 2 ports. + """ + if "VLLM_DP_MASTER_PORT" in os.environ: + dp_master_port = envs.VLLM_DP_MASTER_PORT + reserved_port_range = range(dp_master_port, dp_master_port + 10) + while True: + candidate_port = _get_open_port() + if candidate_port not in reserved_port_range: + return candidate_port + return _get_open_port() + + +def get_open_ports_list(count: int = 5) -> list[int]: + """Get a list of open ports.""" + ports = set[int]() + while len(ports) < count: + ports.add(get_open_port()) + return list(ports) + + +def _get_open_port() -> int: + port = envs.VLLM_PORT + if port is not None: + while True: + try: + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + s.bind(("", port)) + return port + except OSError: + port += 1 # Increment port number if already in use + logger.info("Port %d is already in use, trying port %d", port - 1, port) + # try ipv4 + try: + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + s.bind(("", 0)) + return s.getsockname()[1] + except OSError: + # try ipv6 + with socket.socket(socket.AF_INET6, socket.SOCK_STREAM) as s: + s.bind(("", 0)) + return s.getsockname()[1] + + +def find_process_using_port(port: int) -> psutil.Process | None: + # TODO: We can not check for running processes with network + # port on macOS. Therefore, we can not have a full graceful shutdown + # of vLLM. For now, let's not look for processes in this case. 
+ # Ref: https://www.florianreinhard.de/accessdenied-in-psutil/ + if sys.platform.startswith("darwin"): + return None + + our_pid = os.getpid() + for conn in psutil.net_connections(): + if conn.laddr.port == port and (conn.pid is not None and conn.pid != our_pid): + try: + return psutil.Process(conn.pid) + except psutil.NoSuchProcess: + return None + return None + + +def split_zmq_path(path: str) -> tuple[str, str, str]: + """Split a zmq path into its parts.""" + parsed = urlparse(path) + if not parsed.scheme: + raise ValueError(f"Invalid zmq path: {path}") + + scheme = parsed.scheme + host = parsed.hostname or "" + port = str(parsed.port or "") + + if scheme == "tcp" and not all((host, port)): + # The host and port fields are required for tcp + raise ValueError(f"Invalid zmq path: {path}") + + if scheme != "tcp" and port: + # port only makes sense with tcp + raise ValueError(f"Invalid zmq path: {path}") + + return scheme, host, port + + +def make_zmq_path(scheme: str, host: str, port: int | None = None) -> str: + """Make a ZMQ path from its parts. + + Args: + scheme: The ZMQ transport scheme (e.g. tcp, ipc, inproc). + host: The host - can be an IPv4 address, IPv6 address, or hostname. + port: Optional port number, only used for TCP sockets. + + Returns: + A properly formatted ZMQ path string. 
+ """ + if port is None: + return f"{scheme}://{host}" + if is_valid_ipv6_address(host): + return f"{scheme}://[{host}]:{port}" + return f"{scheme}://{host}:{port}" + + +# Adapted from: https://github.com/sgl-project/sglang/blob/v0.4.1/python/sglang/srt/utils.py#L783 # noqa: E501 +def make_zmq_socket( + ctx: zmq.asyncio.Context | zmq.Context, # type: ignore[name-defined] + path: str, + socket_type: Any, + bind: bool | None = None, + identity: bytes | None = None, + linger: int | None = None, +) -> zmq.Socket | zmq.asyncio.Socket: # type: ignore[name-defined] + """Make a ZMQ socket with the proper bind/connect semantics.""" + + mem = psutil.virtual_memory() + socket = ctx.socket(socket_type) + + # Calculate buffer size based on system memory + total_mem = mem.total / 1024**3 + available_mem = mem.available / 1024**3 + # For systems with substantial memory (>32GB total, >16GB available): + # - Set a large 0.5GB buffer to improve throughput + # For systems with less memory: + # - Use system default (-1) to avoid excessive memory consumption + buf_size = int(0.5 * 1024**3) if total_mem > 32 and available_mem > 16 else -1 + + if bind is None: + bind = socket_type not in (zmq.PUSH, zmq.SUB, zmq.XSUB) + + if socket_type in (zmq.PULL, zmq.DEALER, zmq.ROUTER): + socket.setsockopt(zmq.RCVHWM, 0) + socket.setsockopt(zmq.RCVBUF, buf_size) + + if socket_type in (zmq.PUSH, zmq.DEALER, zmq.ROUTER): + socket.setsockopt(zmq.SNDHWM, 0) + socket.setsockopt(zmq.SNDBUF, buf_size) + + if identity is not None: + socket.setsockopt(zmq.IDENTITY, identity) + + if linger is not None: + socket.setsockopt(zmq.LINGER, linger) + + if socket_type == zmq.XPUB: + socket.setsockopt(zmq.XPUB_VERBOSE, True) + + # Determine if the path is a TCP socket with an IPv6 address. + # Enable IPv6 on the zmq socket if so. 
+ scheme, host, _ = split_zmq_path(path) + if scheme == "tcp" and is_valid_ipv6_address(host): + socket.setsockopt(zmq.IPV6, 1) + + if bind: + socket.bind(path) + else: + socket.connect(path) + + return socket + + +@contextlib.contextmanager +def zmq_socket_ctx( + path: str, + socket_type: Any, + bind: bool | None = None, + linger: int = 0, + identity: bytes | None = None, +) -> Iterator[zmq.Socket]: + """Context manager for a ZMQ socket""" + + ctx = zmq.Context() # type: ignore[attr-defined] + try: + yield make_zmq_socket(ctx, path, socket_type, bind=bind, identity=identity) + except KeyboardInterrupt: + logger.debug("Got Keyboard Interrupt.") + + finally: + ctx.destroy(linger=linger) diff --git a/vllm/_utils/platform_utils.py b/vllm/_utils/platform_utils.py new file mode 100644 index 0000000..433c673 --- /dev/null +++ b/vllm/_utils/platform_utils.py @@ -0,0 +1,59 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import multiprocessing +from collections.abc import Sequence +from concurrent.futures.process import ProcessPoolExecutor +from functools import cache +from typing import Any + +import torch + + +def cuda_is_initialized() -> bool: + """Check if CUDA is initialized.""" + if not torch.cuda._is_compiled(): + return False + return torch.cuda.is_initialized() + + +def xpu_is_initialized() -> bool: + """Check if XPU is initialized.""" + if not torch.xpu._is_compiled(): + return False + return torch.xpu.is_initialized() + + +def get_cu_count(device_id: int = 0) -> int: + """Returns the total number of compute units (CU) on single GPU.""" + return torch.cuda.get_device_properties(device_id).multi_processor_count + + +def cuda_get_device_properties( + device, names: Sequence[str], init_cuda=False +) -> tuple[Any, ...]: + """Get specified CUDA device property values without initializing CUDA in + the current process.""" + if init_cuda or cuda_is_initialized(): + props = 
torch.cuda.get_device_properties(device) + return tuple(getattr(props, name) for name in names) + + # Run in subprocess to avoid initializing CUDA as a side effect. + mp_ctx = multiprocessing.get_context("fork") + with ProcessPoolExecutor(max_workers=1, mp_context=mp_ctx) as executor: + return executor.submit(cuda_get_device_properties, device, names, True).result() + + +@cache +def is_pin_memory_available() -> bool: + from vllm.platforms import current_platform + + return current_platform.is_pin_memory_available() + + +@cache +def is_uva_available() -> bool: + """Check if Unified Virtual Addressing (UVA) is available.""" + # UVA requires pinned memory. + # TODO: Add more requirements for UVA if needed. + return is_pin_memory_available() diff --git a/vllm/_utils/profiling.py b/vllm/_utils/profiling.py new file mode 100644 index 0000000..b669106 --- /dev/null +++ b/vllm/_utils/profiling.py @@ -0,0 +1,56 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from __future__ import annotations + +import contextlib +from collections.abc import Callable +from functools import wraps +from typing import Any + + +@contextlib.contextmanager +def cprofile_context(save_file: str | None = None): + """Run a cprofile + + Args: + save_file: path to save the profile result. "1" or + None will result in printing to stdout. + """ + import cProfile + + prof = cProfile.Profile() + prof.enable() + + try: + yield + finally: + prof.disable() + if save_file and save_file != "1": + prof.dump_stats(save_file) + else: + prof.print_stats(sort="cumtime") + + +def cprofile(save_file: str | None = None, enabled: bool = True): + """Decorator to profile a Python method using cProfile. + + Args: + save_file: Path to save the profile result. + If "1", None, or "", results will be printed to stdout. 
+ enabled: Set to false to turn this into a no-op + """ + + def decorator(func: Callable): + @wraps(func) + def wrapper(*args: Any, **kwargs: Any): + if not enabled: + # If profiling is disabled, just call the function directly. + return func(*args, **kwargs) + + with cprofile_context(save_file): + return func(*args, **kwargs) + + return wrapper + + return decorator diff --git a/vllm/_utils/registry.py b/vllm/_utils/registry.py new file mode 100644 index 0000000..ac9b859 --- /dev/null +++ b/vllm/_utils/registry.py @@ -0,0 +1,49 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from typing import Any + + +class ExtensionManager: + """ + A registry for managing pluggable extension classes. + + This class provides a simple mechanism to register and instantiate + extension classes by name. It is commonly used to implement plugin + systems where different implementations can be swapped at runtime. + + Examples: + Basic usage with a registry instance: + + >>> FOO_REGISTRY = ExtensionManager() + >>> @FOO_REGISTRY.register("my_foo_impl") + ... class MyFooImpl(Foo): + ... def __init__(self, value): + ... self.value = value + >>> foo_impl = FOO_REGISTRY.load("my_foo_impl", value=123) + + """ + + def __init__(self) -> None: + """ + Initialize an empty extension registry. + """ + self.name2class: dict[str, type] = {} + + def register(self, name: str): + """ + Decorator to register a class with the given name. + """ + + def wrap(cls_to_register): + self.name2class[name] = cls_to_register + return cls_to_register + + return wrap + + def load(self, cls_name: str, *args, **kwargs) -> Any: + """ + Instantiate and return a registered extension class by name. 
+ """ + cls = self.name2class.get(cls_name) + assert cls is not None, f"Extension class {cls_name} not found" + return cls(*args, **kwargs) diff --git a/vllm/_utils/serial_utils.py b/vllm/_utils/serial_utils.py new file mode 100644 index 0000000..b89fa6c --- /dev/null +++ b/vllm/_utils/serial_utils.py @@ -0,0 +1,169 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import base64 +import sys +from dataclasses import dataclass +from typing import Literal + +import numpy as np +import torch +from typing_extensions import assert_never + +from vllm import PoolingRequestOutput + +sys_byteorder = sys.byteorder + + +EMBED_DTYPE_TO_TORCH_DTYPE = { + "float32": torch.float32, + "float16": torch.float16, + "bfloat16": torch.bfloat16, + # I'm not sure if other platforms' CPUs support the fp8 data format. + # EMBED_DTYPE only uses the fp8 data representation, + # does not use fp8 computation, and only occurs on the CPU. + # Apologize for any possible break. 
+ "fp8_e4m3": torch.float8_e4m3fn, + "fp8_e5m2": torch.float8_e5m2, +} + + +EMBED_DTYPE_TO_TORCH_DTYPE_VIEW = { + "float32": torch.float32, + "float16": torch.float16, + # numpy does not support bfloat16 and fp8 + "bfloat16": torch.float16, + "fp8_e4m3": torch.uint8, + "fp8_e5m2": torch.uint8, +} + +EMBED_DTYPE_TO_NUMPY_DTYPE_VIEW = { + "float32": np.float32, + "float16": np.float16, + # numpy does not support bfloat16 and fp8 + "bfloat16": np.float16, + "fp8_e4m3": np.uint8, + "fp8_e5m2": np.uint8, +} + +ENDIANNESS = ["native", "big", "little"] + +EmbedDType = Literal["float32", "float16", "bfloat16", "fp8_e4m3", "fp8_e5m2"] +Endianness = Literal["native", "big", "little"] +EncodingFormat = Literal["float", "base64", "bytes"] + + +def tensor2binary( + tensor: torch.Tensor, embed_dtype: EmbedDType, endianness: Endianness +) -> bytes: + assert isinstance(tensor, torch.Tensor) + assert embed_dtype in EMBED_DTYPE_TO_TORCH_DTYPE + assert endianness in ENDIANNESS + + torch_dtype = EMBED_DTYPE_TO_TORCH_DTYPE[embed_dtype] + torch_view_dtype = EMBED_DTYPE_TO_TORCH_DTYPE_VIEW[embed_dtype] + + np_array = ( + tensor.to(torch_dtype).flatten().contiguous().view(torch_view_dtype).numpy() + ) + + if endianness != "native" and endianness != sys_byteorder: + np_array = np_array.byteswap() + + return np_array.tobytes() + + +def binary2tensor( + binary: bytes, + shape: tuple[int, ...], + embed_dtype: EmbedDType, + endianness: Endianness, +) -> torch.Tensor: + assert embed_dtype in EMBED_DTYPE_TO_TORCH_DTYPE + assert embed_dtype in EMBED_DTYPE_TO_NUMPY_DTYPE_VIEW + assert endianness in ENDIANNESS + + torch_dtype = EMBED_DTYPE_TO_TORCH_DTYPE[embed_dtype] + np_dtype = EMBED_DTYPE_TO_NUMPY_DTYPE_VIEW[embed_dtype] + + np_array = np.frombuffer(binary, dtype=np_dtype).reshape(shape) + + if endianness != "native" and endianness != sys_byteorder: + np_array = np_array.byteswap() + + return torch.from_numpy(np_array).view(torch_dtype) + + +def encode_pooling_output( + output: 
PoolingRequestOutput, + encoding_format: EncodingFormat, + embed_dtype: EmbedDType, + endianness: Endianness, +) -> list[float] | str | bytes: + if encoding_format == "float": + return output.outputs.data.tolist() + elif encoding_format == "base64": + embedding_bytes = tensor2binary(output.outputs.data, embed_dtype, endianness) + return base64.b64encode(embedding_bytes).decode("utf-8") + elif encoding_format == "bytes": + return tensor2binary(output.outputs.data, embed_dtype, endianness) + assert_never(encoding_format) + + +@dataclass +class MetadataItem: + index: int + embed_dtype: EmbedDType + endianness: Endianness + start: int + end: int + shape: tuple[int, ...] + + +def encode_pooling_bytes( + pooling_outputs: list[PoolingRequestOutput], + embed_dtype: EmbedDType, + endianness: Endianness, +): + num_prompt_tokens = 0 + items: list[dict[str, MetadataItem]] = [] + body = [] + offset = 0 + for idx, output in enumerate(pooling_outputs): + binary = tensor2binary( + tensor=output.outputs.data, + embed_dtype=embed_dtype, + endianness=endianness, + ) + size = len(binary) + + item = { + "index": idx, + "embed_dtype": embed_dtype, + "endianness": endianness, + "start": offset, + "end": offset + size, + "shape": output.outputs.data.shape, + } + + body.append(binary) + items.append(item) + prompt_token_ids = output.prompt_token_ids + num_prompt_tokens += len(prompt_token_ids) + offset += size + + usage = { + "prompt_tokens": num_prompt_tokens, + "total_tokens": num_prompt_tokens, + } + return body, items, usage + + +def decode_pooling_output(items: list[MetadataItem], body: bytes) -> list[torch.Tensor]: + items.sort(key=lambda x: x.index) + + tensor_list: list[torch.Tensor] = [] + for item in items: + binary = body[item.start : item.end] + tensor = binary2tensor(binary, item.shape, item.embed_dtype, item.endianness) + tensor_list.append(tensor) + return tensor_list diff --git a/vllm/_utils/system_utils.py b/vllm/_utils/system_utils.py new file mode 100644 index 
# --- vllm/_utils/system_utils.py ---
# NOTE: system_utils.py begins with `from __future__ import annotations`;
# that directive must sit at the top of the split file and cannot be
# re-emitted mid-stream here.

import contextlib
import multiprocessing
import os
import signal
import sys
from collections.abc import Callable, Iterator
from pathlib import Path
from typing import TextIO

import psutil

import vllm.envs as envs
from vllm.logger import init_logger
from vllm.ray.lazy_utils import is_in_ray_actor

from .platform_utils import cuda_is_initialized, xpu_is_initialized

logger = init_logger(__name__)

# ANSI escape codes used to colorize per-worker log prefixes.
CYAN = "\033[1;36m"
RESET = "\033[0;0m"


# Environment variable utilities


def update_environment_variables(envs_dict: dict[str, str]):
    """Set each key/value into os.environ, warning when overwriting."""
    for key, value in envs_dict.items():
        if key in os.environ and os.environ[key] != value:
            logger.warning(
                "Overwriting environment variable %s from '%s' to '%s'",
                key,
                os.environ[key],
                value,
            )
        os.environ[key] = value


@contextlib.contextmanager
def set_env_var(key: str, value: str) -> Iterator[None]:
    """Temporarily set an environment variable, restoring it on exit."""
    saved = os.environ.get(key)
    os.environ[key] = value
    try:
        yield
    finally:
        if saved is None:
            os.environ.pop(key, None)
        else:
            os.environ[key] = saved


# File path utilities


def unique_filepath(fn: Callable[[int], Path]) -> Path:
    """Return the first fn(i) for i = 0, 1, 2, ... that does not exist.

    Note: This function has a TOCTOU race condition.
    Caller should use atomic operations (e.g., open with 'x' mode)
    when creating the file to ensure thread safety.
    """
    counter = 0
    while True:
        candidate = fn(counter)
        if not candidate.exists():
            return candidate
        counter += 1


# Process management utilities


def _maybe_force_spawn():
    """Force the `spawn` multiprocessing start method when `fork` is unsafe.

    Triggers when running inside a Ray actor, or when CUDA/XPU has already
    been initialized in this process.
    """
    if os.environ.get("VLLM_WORKER_MULTIPROC_METHOD") == "spawn":
        return

    reasons = []
    if is_in_ray_actor():
        # even if we choose to spawn, we need to pass the ray address
        # to the subprocess so that it knows how to connect to the ray cluster.
        # env vars are inherited by subprocesses, even if we use spawn.
        import ray

        os.environ["RAY_ADDRESS"] = ray.get_runtime_context().gcs_address
        reasons.append("In a Ray actor and can only be spawned")

    if cuda_is_initialized():
        reasons.append("CUDA is initialized")
    elif xpu_is_initialized():
        reasons.append("XPU is initialized")

    if reasons:
        logger.warning(
            "We must use the `spawn` multiprocessing start method. "
            "Overriding VLLM_WORKER_MULTIPROC_METHOD to 'spawn'. "
            "See https://docs.vllm.ai/en/latest/usage/"
            "troubleshooting.html#python-multiprocessing "
            "for more information. Reasons: %s",
            "; ".join(reasons),
        )
        os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
def get_mp_context():
    """Get a multiprocessing context with a particular method (spawn or fork).

    Follows VLLM_WORKER_MULTIPROC_METHOD (default fork), except that certain
    conditions (see _maybe_force_spawn) override the value to spawn first.
    """
    _maybe_force_spawn()
    return multiprocessing.get_context(envs.VLLM_WORKER_MULTIPROC_METHOD)


def set_process_title(
    name: str,
    suffix: str = "",
    prefix: str = envs.VLLM_PROCESS_NAME_PREFIX,
) -> None:
    """Set the process title to "<prefix>::<name>[_<suffix>]" (best effort).

    Silently does nothing when setproctitle is not installed.
    """
    try:
        import setproctitle
    except ImportError:
        return

    full_name = f"{name}_{suffix}" if suffix else name
    setproctitle.setproctitle(f"{prefix}::{full_name}")


def _add_prefix(file: TextIO, worker_name: str, pid: int) -> None:
    """Monkey-patch file.write so every new line gets a colored worker prefix."""
    prefix = f"{CYAN}({worker_name} pid={pid}){RESET} "
    raw_write = file.write

    def write_with_prefix(s: str):
        if not s:
            return
        if file.start_new_line:  # type: ignore[attr-defined]
            raw_write(prefix)
        idx = 0
        # Emit the prefix after every newline; a trailing newline instead
        # flags the next write to start with the prefix.
        while (next_idx := s.find("\n", idx)) != -1:
            next_idx += 1
            raw_write(s[idx:next_idx])
            if next_idx == len(s):
                file.start_new_line = True  # type: ignore[attr-defined]
                return
            raw_write(prefix)
            idx = next_idx
        raw_write(s[idx:])
        file.start_new_line = False  # type: ignore[attr-defined]

    file.start_new_line = True  # type: ignore[attr-defined]
    file.write = write_with_prefix  # type: ignore[method-assign]


def decorate_logs(process_name: str | None = None) -> None:
    """Decorate stdout/stderr with process name and PID prefix."""
    if process_name is None:
        process_name = get_mp_context().current_process().name

    pid = os.getpid()
    _add_prefix(sys.stdout, process_name, pid)
    _add_prefix(sys.stderr, process_name, pid)


def kill_process_tree(pid: int):
    """
    Kills all descendant processes of the given pid by sending SIGKILL.

    Args:
        pid (int): Process ID of the parent process
    """
    try:
        parent = psutil.Process(pid)
    except psutil.NoSuchProcess:
        return

    # Kill the (recursively gathered) children before the parent.
    for child in parent.children(recursive=True):
        with contextlib.suppress(ProcessLookupError):
            os.kill(child.pid, signal.SIGKILL)

    with contextlib.suppress(ProcessLookupError):
        os.kill(pid, signal.SIGKILL)


# Resource utilities


# Adapted from: https://github.com/sgl-project/sglang/blob/v0.4.1/python/sglang/srt/utils.py#L630
def set_ulimit(target_soft_limit: int = 65535):
    """Raise the soft RLIMIT_NOFILE to *target_soft_limit* (no-op on Windows)."""
    if sys.platform.startswith("win"):
        logger.info("Windows detected, skipping ulimit adjustment.")
        return

    import resource

    resource_type = resource.RLIMIT_NOFILE
    current_soft, current_hard = resource.getrlimit(resource_type)

    if current_soft >= target_soft_limit:
        return
    try:
        resource.setrlimit(resource_type, (target_soft_limit, current_hard))
    except ValueError as e:
        logger.warning(
            "Found ulimit of %s and failed to automatically increase "
            "with error %s. This can cause fd limit errors like "
            "`OSError: [Errno 24] Too many open files`. Consider "
            "increasing with ulimit -n",
            current_soft,
            e,
        )
# --- vllm/_utils/tensor_schema.py ---

from types import UnionType
from typing import Annotated, Any, Union, get_args, get_origin, get_type_hints

import torch

from vllm.logger import init_logger

logger = init_logger(__name__)


class TensorShape:
    """Symbolic tensor shape: ints are fixed sizes, strings are named dims.

    Names listed in *dynamic_dims* may vary between items and are skipped
    during validation.
    """

    def __init__(
        self,
        *dims: int | str,
        dynamic_dims: set[str] | None = None,
    ) -> None:
        super().__init__()

        self.dims = dims
        self.dynamic_dims = dynamic_dims if dynamic_dims else set()

    def resolve(self, **bindings: int) -> tuple[int | str, ...]:
        """Substitute named dims with concrete sizes taken from *bindings*."""
        return tuple(
            bindings[dim] if isinstance(dim, str) and dim in bindings else dim
            for dim in self.dims
        )

    def __str__(self) -> str:
        """Return a string representation of the tensor shape."""
        rendered = []
        for dim in self.dims:
            if isinstance(dim, str):
                # Mark dynamic dimensions with *
                rendered.append(f"{dim}*" if dim in self.dynamic_dims else dim)
            else:
                rendered.append(str(dim))
        return f"({', '.join(rendered)})"


class TensorSchema:
    """Base class for kwargs-backed schemas with declarative shape checks.

    Subclasses annotate fields with `Annotated[..., TensorShape(...)]`; on
    construction each field is validated against its declared shape, and
    named dims are unified across fields via a shared shape environment.
    """

    def __init__(
        self,
        *,
        validate: bool = True,
        resolve_bindings: dict[str, int] | None = None,
        **kwargs: Any,
    ) -> None:
        super().__init__()

        # Concrete sizes for named dims, substituted before validation.
        self._resolve_bindings = resolve_bindings if resolve_bindings else {}

        for key, value in kwargs.items():
            setattr(self, key, value)

        if validate:
            self.validate()

    def __getitem__(self, key: str) -> Any:
        return getattr(self, key)

    def get(self, key: str, default: Any = None) -> Any:
        return getattr(self, key, default)

    def _match_shape_with_dynamic(
        self,
        actual: tuple[int, ...],
        reference: tuple[int, ...],
        expected_shape: tuple[int | str, ...],
        dynamic_dims: set[str],
    ) -> bool:
        """Compare two element shapes, ignoring dims marked dynamic."""
        if len(actual) != len(reference) or len(actual) > len(expected_shape):
            return False

        for i, (a, r) in enumerate(zip(actual, reference)):
            # When validating list inputs, we match shape suffixes only
            # (e.g. "p", 3, "h", "w"), assuming the list length corresponds
            # to the leading symbolic dim (e.g. "bn"). This allows comparing
            # only the trailing dimensions of each element in the list.
            dim = expected_shape[-len(actual) + i]
            # Skip this dimension if it's marked dynamic
            if dim in dynamic_dims:
                continue
            if a != r:
                return False
        return True

    def _fmt_indexer(self, idxs: tuple[int, ...]) -> str:
        """Render nested list indices for error messages ("" when empty)."""
        return str(list(idxs)) if idxs else ""

    def _validate_field(
        self,
        value: object,
        field_name: str,
        expected_shape: tuple[int | str, ...],
        dynamic_dims: set[str],
        leading_idxs: tuple[int, ...] = (),
    ) -> tuple[int, ...]:
        """Validate a field and return the actual shape."""
        if isinstance(value, (int, float)):
            return ()  # Scalar
        if isinstance(value, torch.Tensor):
            return value.shape

        if not isinstance(value, (list, tuple)):
            raise TypeError(
                f"{field_name}{self._fmt_indexer(leading_idxs)} is not "
                f"one of the expected types: int, float, Tensor, list, tuple. "
                f"Got: {type(value)}"
            )

        if len(value) == 0:
            raise ValueError(
                f"{field_name}{self._fmt_indexer(leading_idxs)} is an empty sequence"
            )

        # Ensure all tensors in the list have the same
        # shape, besides dynamic dimensions
        for i, v in enumerate(value):
            shape = self._validate_field(
                v,
                field_name,
                expected_shape[1:],
                dynamic_dims,
                leading_idxs=leading_idxs + (i,),
            )

            if i == 0:
                first_shape = shape
            elif not self._match_shape_with_dynamic(
                shape,
                first_shape,
                expected_shape,
                dynamic_dims,
            ):
                raise ValueError(
                    f"{field_name}{self._fmt_indexer(leading_idxs)} "
                    f"contains inconsistent shapes: {first_shape} "
                    f"(index 0) vs {shape} (index {i})"
                )

        # Treat the list as a stacked tensor:
        # shape = (len(list), *tensor.shape)
        return (len(value),) + first_shape

    def _validate_tensor_shape_expected(
        self,
        actual_shape: tuple[int, ...],
        expected_shape: tuple[int | str, ...],
        field_name: str,
        shape_env: dict[str, int],
        dynamic_dims: set[str],
    ) -> None:
        """Validate that the actual tensor shape matches the expected shape."""

        if len(actual_shape) != len(expected_shape):
            raise ValueError(
                f"{field_name} has rank {len(actual_shape)} "
                f"but expected {len(expected_shape)}. "
                f"Expected shape: {expected_shape}, "
                f"but got {actual_shape}"
            )

        for i, dim in enumerate(expected_shape):
            if dim in dynamic_dims:
                continue
            elif isinstance(dim, int):
                if actual_shape[i] != dim:
                    raise ValueError(
                        f"{field_name} dim[{i}] expected "
                        f"{dim}, got {actual_shape[i]}. "
                        f"Expected shape: {expected_shape}, "
                        f"but got {actual_shape}"
                    )
            elif isinstance(dim, str):
                # Named dim: either check against the previously bound size
                # or bind it now so later fields must agree.
                if dim in shape_env:
                    if actual_shape[i] != shape_env[dim]:
                        raise ValueError(
                            f"{field_name} dim[{i}] expected "
                            f"'{dim}'={shape_env[dim]}, got "
                            f"{actual_shape[i]}"
                        )
                else:
                    shape_env[dim] = actual_shape[i]
            else:
                raise TypeError(
                    f"{field_name} dim[{i}] has unsupported type: {type(dim)}"
                )

    def validate(self) -> None:
        """Check every annotated field against its TensorShape, if any."""
        type_hints = get_type_hints(self.__class__, include_extras=True)
        shape_env = dict[str, int]()

        for field_name, field_type in type_hints.items():
            # Check if field is missing
            if not hasattr(self, field_name) or getattr(self, field_name) is None:
                # Check if field is marked as optional
                actual_type = field_type
                if get_origin(field_type) is Annotated:
                    args = get_args(field_type)
                    actual_type = args[0]

                # Check arg was provided as Union
                if get_origin(actual_type) in {Union, UnionType}:
                    # Union for Union[X, Y] and UnionType for X | Y
                    args = get_args(actual_type)
                    # Skip validation when Union contains None
                    if type(None) in args:
                        continue
                # Otherwise field is required, raise error
                raise ValueError(f"Required field '{field_name}' is missing")

            # Field exists, proceed with validation
            value = getattr(self, field_name)
            if get_origin(field_type) is not None:
                args = get_args(field_type)

                for arg in args:
                    if isinstance(arg, TensorShape):
                        expected_shape = arg.resolve(**self._resolve_bindings)
                        actual_shape = self._validate_field(
                            value,
                            field_name,
                            expected_shape,
                            arg.dynamic_dims,
                        )

                        self._validate_tensor_shape_expected(
                            actual_shape,
                            expected_shape,
                            field_name,
                            shape_env,
                            arg.dynamic_dims,
                        )

    def print_shapes(self) -> None:
        """Print TensorShape annotations for debugging."""
        logger.debug("Shapes in %s:", self.__class__.__name__)
        type_hints = get_type_hints(self.__class__, include_extras=True)

        for field_name, field_type in type_hints.items():
            if get_origin(field_type) is not None:
                for arg in get_args(field_type):
                    if isinstance(arg, TensorShape):
                        logger.debug("  %s: %s", field_name, str(arg))
get_origin(field_type) is not None: + args = get_args(field_type) + for arg in args: + if isinstance(arg, TensorShape): + logger.debug(" %s: %s", field_name, str(arg)) diff --git a/vllm/_utils/torch_utils.py b/vllm/_utils/torch_utils.py new file mode 100644 index 0000000..7c094e1 --- /dev/null +++ b/vllm/_utils/torch_utils.py @@ -0,0 +1,657 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import contextlib +import importlib.metadata +import os +import threading +from collections.abc import Callable, Collection +from functools import lru_cache +from typing import TYPE_CHECKING, Any, TypeVar + +import numpy as np +import numpy.typing as npt +import torch +from packaging import version +from packaging.version import Version +from torch.library import Library + +import vllm.envs as envs + +if TYPE_CHECKING: + from vllm.config import ModelConfig + from vllm.sequence import IntermediateTensors +else: + ModelConfig = object + IntermediateTensors = object + + +STR_DTYPE_TO_TORCH_DTYPE = { + "float32": torch.float32, + "half": torch.half, + "bfloat16": torch.bfloat16, + "float": torch.float, + "fp8": torch.uint8, + "fp8_e4m3": torch.uint8, + "fp8_e5m2": torch.uint8, + "int8": torch.int8, + "fp8_inc": torch.float8_e4m3fn, + "fp8_ds_mla": torch.uint8, +} + +TORCH_DTYPE_TO_NUMPY_DTYPE = { + torch.float16: np.float16, + torch.float32: np.float32, + torch.float64: np.float64, + torch.uint8: np.uint8, + torch.int32: np.int32, + torch.int64: np.int64, +} + + +T = TypeVar("T") + + +@contextlib.contextmanager +def set_default_torch_dtype(dtype: torch.dtype): + """Sets the default torch dtype to the given dtype.""" + old_dtype = torch.get_default_dtype() + torch.set_default_dtype(dtype) + yield + torch.set_default_dtype(old_dtype) + + +@contextlib.contextmanager +def set_default_torch_num_threads(num_threads: int): + """Sets the default number of threads for PyTorch to the given value.""" + old_num_threads = 
torch.get_num_threads() + torch.set_num_threads(num_threads) + yield + torch.set_num_threads(old_num_threads) + + +@contextlib.contextmanager +def guard_cuda_initialization(): + """Avoid unexpected CUDA initialization.""" + from vllm.platforms import current_platform + + if not current_platform.is_cuda(): + yield + return + + had_key = "CUDA_VISIBLE_DEVICES" in os.environ + old_value = os.environ.get("CUDA_VISIBLE_DEVICES") + os.environ["CUDA_VISIBLE_DEVICES"] = "" + try: + yield + except Exception as e: + if "No CUDA GPUs are available" in str(e): + err_msg = "CUDA initialization is blocked." + else: + err_msg = str(e) + raise RuntimeError(err_msg) from e + finally: + if had_key: + os.environ["CUDA_VISIBLE_DEVICES"] = old_value + else: + os.environ.pop("CUDA_VISIBLE_DEVICES") + + +def get_dtype_size(dtype: torch.dtype) -> int: + """Get the size of the data type in bytes.""" + return torch.tensor([], dtype=dtype).element_size() + + +# bool = 0, int = 1, float = 2, complex = 3 +def _get_precision_level(dtype: torch.dtype) -> int: + # NOTE: Complex dtypes return `is_floating_point=False` + return (dtype != torch.bool) + dtype.is_floating_point + dtype.is_complex * 2 + + +def is_lossless_cast(src_dtype: torch.dtype, tgt_dtype: torch.dtype): + """ + Test whether it is lossless to cast a tensor from + `src_dtype` to `tgt_dtype`. 
+ """ + if src_dtype == tgt_dtype: + return True + + src_level = _get_precision_level(src_dtype) + tgt_level = _get_precision_level(tgt_dtype) + + if src_level < tgt_level: + return True + if src_level > tgt_level: + return False + + # Compare integral types + if not src_dtype.is_floating_point and not src_dtype.is_complex: + src_info = torch.iinfo(src_dtype) + tgt_info = torch.iinfo(tgt_dtype) + return src_info.min >= tgt_info.min and src_info.max <= tgt_info.max + + # Compare floating-point types + src_info = torch.finfo(src_dtype) + tgt_info = torch.finfo(tgt_dtype) + return ( + src_info.min >= tgt_info.min + and src_info.max <= tgt_info.max + and src_info.resolution >= tgt_info.resolution + ) + + +def common_broadcastable_dtype(dtypes: Collection[torch.dtype]): + """ + Get the common `dtype` where all of the other `dtypes` can be + cast to it without losing any information. + """ + return max( + dtypes, + key=lambda dtype: sum(is_lossless_cast(dt, dtype) for dt in dtypes), + ) + + +def _generate_random_fp8( + tensor: torch.Tensor, + low: float, + high: float, +) -> None: + # NOTE(zhaoyang): Due to NaN and Inf representation for fp8 data type, + # it may occur Inf or NaN if we directly use torch.randint + # to generate random data for fp8 data. + # For example, s.11111.00 in fp8e5m2 format represents Inf. 
+ # | E4M3 | E5M2 + # -----|-------------|------------------- + # Inf | N/A | s.11111.00 + # NaN | s.1111.111 | s.11111.{01,10,11} + from vllm import _custom_ops as ops + + tensor_tmp = torch.empty_like(tensor, dtype=torch.float16) + tensor_tmp.uniform_(low, high) + ops.convert_fp8(tensor, tensor_tmp) + del tensor_tmp + + +def get_kv_cache_torch_dtype( + cache_dtype: str | torch.dtype | None, + model_dtype: str | torch.dtype | None = None, +) -> torch.dtype: + if isinstance(cache_dtype, str): + if cache_dtype == "auto": + if isinstance(model_dtype, str) and model_dtype in STR_DTYPE_TO_TORCH_DTYPE: + torch_dtype = STR_DTYPE_TO_TORCH_DTYPE[model_dtype] + elif isinstance(model_dtype, torch.dtype): + torch_dtype = model_dtype + else: + raise ValueError(f"Invalid model dtype: {model_dtype}") + elif cache_dtype in STR_DTYPE_TO_TORCH_DTYPE: + torch_dtype = STR_DTYPE_TO_TORCH_DTYPE[cache_dtype] + else: + raise ValueError(f"Invalid kv cache dtype: {cache_dtype}") + elif isinstance(cache_dtype, torch.dtype): + torch_dtype = cache_dtype + else: + raise ValueError(f"Invalid kv cache dtype: {cache_dtype}") + return torch_dtype + + +def kv_cache_dtype_str_to_dtype( + kv_cache_dtype: str, model_config: ModelConfig +) -> torch.dtype: + if kv_cache_dtype == "auto": + # Model config may not be specified for unit tests, default to float16 + return model_config.dtype if model_config else torch.half + return STR_DTYPE_TO_TORCH_DTYPE[kv_cache_dtype] + + +def create_kv_caches_with_random_flash( + num_blocks: int, + block_size: int, + num_layers: int, + num_heads: int, + head_size: int, + cache_dtype: str | torch.dtype | None, + model_dtype: str | torch.dtype | None = None, + seed: int | None = None, + device: str | None = "cuda", + cache_layout: str | None = "NHD", +) -> tuple[list[torch.Tensor], list[torch.Tensor]]: + from vllm.platforms import current_platform + + current_platform.seed_everything(seed) + + dtype = get_kv_cache_torch_dtype(cache_dtype, model_dtype) + 
generic_kv_cache_shape = (num_blocks, 2, block_size, num_heads, head_size) + assert cache_layout in ("NHD", "HND") + stride_order = (0, 1, 2, 3, 4) if cache_layout == "NHD" else (0, 1, 3, 2, 4) + + kv_cache_allocation_shape = tuple(generic_kv_cache_shape[i] for i in stride_order) + scale = head_size**-0.5 + + key_caches: list[torch.Tensor] = [] + value_caches: list[torch.Tensor] = [] + + for _ in range(num_layers): + key_value_cache = torch.empty( + size=kv_cache_allocation_shape, dtype=dtype, device=device + ).permute(*stride_order) + if cache_dtype in ["auto", "half", "bfloat16", "float"]: + key_value_cache.uniform_(-scale, scale) + elif cache_dtype == "fp8": + _generate_random_fp8(key_value_cache, -scale, scale) + else: + raise ValueError(f"Does not support key cache of type {cache_dtype}") + key_caches.append(key_value_cache[:, 0]) + value_caches.append(key_value_cache[:, 1]) + return key_caches, value_caches + + +def create_kv_caches_with_random( + num_blocks: int, + block_size: int, + num_layers: int, + num_heads: int, + head_size: int, + cache_dtype: str | torch.dtype | None, + model_dtype: str | torch.dtype | None = None, + seed: int | None = None, + device: str | None = "cuda", +) -> tuple[list[torch.Tensor], list[torch.Tensor]]: + if cache_dtype == "fp8" and head_size % 16: + raise ValueError( + f"Does not support key cache of type fp8 with head_size {head_size}" + ) + from vllm.platforms import current_platform + + current_platform.seed_everything(seed) + + dtype = get_kv_cache_torch_dtype(cache_dtype, model_dtype) + + scale = head_size**-0.5 + x = 16 // torch.tensor([], dtype=dtype).element_size() + key_cache_shape = (num_blocks, num_heads, head_size // x, block_size, x) + key_caches: list[torch.Tensor] = [] + for _ in range(num_layers): + key_cache = torch.empty(size=key_cache_shape, dtype=dtype, device=device) + if cache_dtype in ["auto", "half", "bfloat16", "float"]: + key_cache.uniform_(-scale, scale) + elif cache_dtype == "fp8": + 
_generate_random_fp8(key_cache, -scale, scale) + else: + raise ValueError(f"Does not support key cache of type {cache_dtype}") + key_caches.append(key_cache) + + value_cache_shape = (num_blocks, num_heads, head_size, block_size) + value_caches: list[torch.Tensor] = [] + for _ in range(num_layers): + value_cache = torch.empty(size=value_cache_shape, dtype=dtype, device=device) + if cache_dtype in ["auto", "half", "bfloat16", "float"]: + value_cache.uniform_(-scale, scale) + elif cache_dtype == "fp8": + _generate_random_fp8(value_cache, -scale, scale) + else: + raise ValueError(f"Does not support value cache of type {cache_dtype}") + value_caches.append(value_cache) + return key_caches, value_caches + + +def async_tensor_h2d( + data: list, + dtype: torch.dtype, + target_device: str | torch.device, + pin_memory: bool, +) -> torch.Tensor: + """Asynchronously create a tensor and copy it from host to device.""" + t = torch.tensor(data, dtype=dtype, pin_memory=pin_memory, device="cpu") + return t.to(device=target_device, non_blocking=True) + + +def make_ndarray_with_pad( + x: list[list[T]], + pad: T, + dtype: npt.DTypeLike, + *, + max_len: int | None = None, +) -> npt.NDArray: + """ + Make a padded array from 2D inputs. + + The padding is applied to the end of each inner list until it reaches + `max_len`. + """ + if max_len is None: + # Unlike for most functions, map is faster than a genexpr over `len` + max_len = max(map(len, x), default=0) + + padded_x = np.full((len(x), max_len), pad, dtype=dtype) + for ind, blocktb in enumerate(x): + assert len(blocktb) <= max_len + padded_x[ind, : len(blocktb)] = blocktb + + return padded_x + + +def make_tensor_with_pad( + x: list[list[T]], + pad: T, + dtype: torch.dtype, + *, + max_len: int | None = None, + device: str | torch.device | None = None, + pin_memory: bool = False, +) -> torch.Tensor: + """ + Make a padded tensor from 2D inputs. + + The padding is applied to the end of each inner list until it reaches + `max_len`. 
+ """ + np_dtype = TORCH_DTYPE_TO_NUMPY_DTYPE[dtype] + padded_x = make_ndarray_with_pad(x, pad, np_dtype, max_len=max_len) + + tensor = torch.from_numpy(padded_x).to(device) + if pin_memory: + tensor = tensor.pin_memory() + + return tensor + + +prev_set_stream = torch.cuda.set_stream + +_current_stream_tls = threading.local() + + +def _patched_set_stream(stream: torch.cuda.Stream) -> None: + _current_stream_tls.value = stream + prev_set_stream(stream) + + +torch.cuda.set_stream = _patched_set_stream + + +class _StreamPlaceholder: + def __init__(self): + self.synchronize = lambda: None + + +def current_stream() -> torch.cuda.Stream: + """ + replace `torch.cuda.current_stream()` with `vllm.utils.current_stream()`. + it turns out that `torch.cuda.current_stream()` is quite expensive, + as it will construct a new stream object at each call. + here we patch `torch.cuda.set_stream` to keep track of the current stream + directly, so that we can avoid calling `torch.cuda.current_stream()`. + + the underlying hypothesis is that we do not call `torch._C._cuda_setStream` + from C/C++ code. + """ + from vllm.platforms import current_platform + + if not hasattr(_current_stream_tls, "value") or _current_stream_tls.value is None: + # when this function is called before any stream is set, + # we return the default stream. + # On ROCm using the default 0 stream in combination with RCCL + # is hurting performance. 
Therefore creating a dedicated stream + # per process + if current_platform.is_rocm(): + # torch.cuda.set_stream here is the alias of _pathed_set_stream + torch.cuda.set_stream(torch.cuda.Stream()) + elif current_platform.is_cpu(): + _current_stream_tls.value = _StreamPlaceholder() + else: + current_stream = current_platform.current_stream + if current_stream is not None: + _current_stream_tls.value = current_stream() + else: + raise ValueError( + "Fail to set current stream, current platform " + "may not support current_stream with torch API" + ) + return _current_stream_tls.value + + +# Global auxilary stream for running operations in background streams. +# We have single global auxilary stream to avoid an explosion of streams +# for every layer (and make profiling look sane). +# +# aux_stream() is currently used for: +# - MoE shared_expert overlap with router +_aux_stream: torch.cuda.Stream | None = None + + +def aux_stream() -> torch.cuda.Stream | None: + """ + Ensures aux_stream is initialized only once + """ + global _aux_stream + + from vllm.platforms import current_platform + + # TODO: validate this works properly on ROCm platform. + if _aux_stream is None and current_platform.is_cuda(): + _aux_stream = torch.cuda.Stream() + + return _aux_stream + + +@lru_cache(maxsize=8) +def _cuda_device_count_stateless(cuda_visible_devices: str | None = None) -> int: + # Note: cuda_visible_devices is not used, but we keep it as an argument for + # LRU Cache purposes. 
+ + # Code below is based on + # https://github.com/pytorch/pytorch/blob/ + # c1cd946818442aca8c7f812b16d187ce1586c3bc/ + # torch/cuda/__init__.py#L831C1-L831C17 + import torch.cuda + import torch.version + + from vllm.platforms import current_platform + + if not torch.cuda._is_compiled(): + return 0 + if current_platform.is_rocm(): + # ROCm uses amdsmi instead of nvml for stateless device count + # This requires a sufficiently modern version of Torch 2.4.0 + raw_count = ( + torch.cuda._device_count_amdsmi() + if (hasattr(torch.cuda, "_device_count_amdsmi")) + else -1 + ) + else: + raw_count = torch.cuda._device_count_nvml() + r = torch._C._cuda_getDeviceCount() if raw_count < 0 else raw_count + return r + + +def cuda_device_count_stateless() -> int: + """Get number of CUDA devices, caching based on the value of + CUDA_VISIBLE_DEVICES at the time of call. + + This should be used instead of torch.cuda.device_count() + unless CUDA_VISIBLE_DEVICES has already been set to the desired + value.""" + + # This can be removed and simply replaced with torch.cuda.get_device_count + # after https://github.com/pytorch/pytorch/pull/122815 is released. + return _cuda_device_count_stateless(envs.CUDA_VISIBLE_DEVICES) + + +def weak_ref_tensor(tensor: Any) -> Any: + """ + Create a weak reference to a tensor. + The new tensor will share the same data as the original tensor, + but will not keep the original tensor alive. + """ + if isinstance(tensor, torch.Tensor): + return torch.ops._C.weak_ref_tensor(tensor) + else: + return tensor + + +def weak_ref_tensors( + tensors: torch.Tensor + | list[torch.Tensor] + | tuple[torch.Tensor] + | IntermediateTensors, +) -> torch.Tensor | list[Any] | tuple[Any] | Any: + """ + Convenience function to create weak references to tensors, + for single tensor, list of tensors or tuple of tensors. 
+ """ + if isinstance(tensors, torch.Tensor): + return weak_ref_tensor(tensors) + if isinstance(tensors, list): + return [weak_ref_tensor(t) for t in tensors] + if isinstance(tensors, tuple): + return tuple(weak_ref_tensor(t) for t in tensors) + + # For IntermediateTensors used in pipeline parallelism + from vllm.sequence import IntermediateTensors + + if isinstance(tensors, IntermediateTensors): + ret = IntermediateTensors( + {key: weak_ref_tensor(val) for key, val in tensors.tensors.items()} + ) + return ret + raise ValueError("Invalid type for tensors") + + +def get_cuda_view_from_cpu_tensor(cpu_tensor: torch.Tensor) -> torch.Tensor: + """ + Get a CUDA view of a CPU tensor using Unified Virtual Addressing (UVA). + """ + assert cpu_tensor.is_pinned(), "CPU tensor must be pinned" + return torch.ops._C.get_cuda_view_from_cpu_tensor(cpu_tensor) + + +# Helper function used in testing. +def _is_torch_equal_or_newer(torch_version: str, target: str) -> bool: + torch_version = version.parse(torch_version) + return torch_version >= version.parse(target) + + +def is_torch_equal_or_newer(target: str) -> bool: + """Check if the installed torch version is >= the target version. + + Args: + target: a version string, like "2.6.0". + + Returns: + Whether the condition meets. + """ + try: + return _is_torch_equal_or_newer(str(torch.__version__), target) + except Exception: + # Fallback to PKG-INFO to load the package info, needed by the doc gen. 
+ return Version(importlib.metadata.version("torch")) >= Version(target) + + +def _is_torch_equal(target: str) -> bool: + assert target.count(".") == 2 + torch_version = str(torch.__version__) + torch_version = version.parse(torch_version) + # torch version is like "2.6.0.dev20240101" or "2.6.0.dev20240101+cpu" + # or "2.6.0+cu128" but never "2.6.0.1" + return ( + torch_version >= version.parse(target) + and version.parse(target + ".1") > torch_version + ) + + +def is_torch_equal(target: str) -> bool: + """Check if the installed torch version is == the target version. + + Args: + target: a version string, like "2.6.0". + + Returns: + Whether the condition meets. + """ + try: + return _is_torch_equal(target) + except Exception: + return Version(importlib.metadata.version("torch")) == Version(target) + + +# Using dynamo with vLLM doesn't really work well with PyTorch versions < 2.4.0. +# In particular, the FakeScalarType is not supported for earlier versions of +# PyTorch which breaks dynamo for any ops registered using ScalarType. +def supports_dynamo() -> bool: + return is_torch_equal_or_newer("2.4.0") + + +# Supports xccl with PyTorch versions >= 2.8.0.dev for XPU platform +def supports_xccl() -> bool: + return ( + is_torch_equal_or_newer("2.8.0.dev") and torch.distributed.is_xccl_available() + ) + + +# Some backends use pytorch version < 2.4.0 which doesn't +# support `torch.library.custom_op`. +def supports_custom_op() -> bool: + return hasattr(torch.library, "custom_op") + + +# create a library to hold the custom op +vllm_lib = Library("vllm", "FRAGMENT") # noqa + + +def direct_register_custom_op( + op_name: str, + op_func: Callable, + mutates_args: list[str] | None = None, + fake_impl: Callable | None = None, + target_lib: Library | None = None, + dispatch_key: str | None = None, + tags: tuple[torch.Tag, ...] = (), +): + """ + `torch.library.custom_op` can have significant overhead because it + needs to consider complicated dispatching logic. 
This function + directly registers a custom op and dispatches it to the CUDA backend. + See https://gist.github.com/youkaichao/ecbea9ec9fc79a45d2adce1784d7a9a5 + for more details. + + By default, the custom op is registered to the vLLM library. If you + want to register it to a different library, you can pass the library + object to the `target_lib` argument. + + IMPORTANT: the lifetime of the operator is tied to the lifetime of the + library object. If you want to bind the operator to a different library, + make sure the library object is alive when the operator is used. + """ + if not supports_custom_op(): + from vllm.platforms import current_platform + + assert not current_platform.is_cuda_alike(), ( + "cuda platform needs torch>=2.4 to support custom op, " + "chances are you are using an old version of pytorch " + "or a custom build of pytorch. It is recommended to " + "use vLLM in a fresh new environment and let it install " + "the required dependencies." + ) + return + + if mutates_args is None: + mutates_args = [] + + if dispatch_key is None: + from vllm.platforms import current_platform + + dispatch_key = current_platform.dispatch_key + + import torch.library + + if hasattr(torch.library, "infer_schema"): + schema_str = torch.library.infer_schema(op_func, mutates_args=mutates_args) + else: + # for pytorch 2.4 + import torch._custom_op.impl + + schema_str = torch._custom_op.impl.infer_schema(op_func, mutates_args) + my_lib = target_lib or vllm_lib + my_lib.define(op_name + schema_str, tags=tags) + my_lib.impl(op_name, op_func, dispatch_key=dispatch_key) + if fake_impl is not None: + my_lib._register_fake(op_name, fake_impl) diff --git a/vllm/compilation/__init__.py b/vllm/compilation/__init__.py index e69de29..0d1135c 100644 --- a/vllm/compilation/__init__.py +++ b/vllm/compilation/__init__.py @@ -0,0 +1,91 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import uuid +import warnings +from 
typing import Any, TypeVar + +import torch + +from vllm.logger import init_logger + +_DEPRECATED_MAPPINGS = { + "cprofile": "profiling", + "cprofile_context": "profiling", + # Used by lm-eval + "get_open_port": "network_utils", +} + + +def __getattr__(name: str) -> Any: # noqa: D401 - short deprecation docstring + """Module-level getattr to handle deprecated utilities.""" + if name in _DEPRECATED_MAPPINGS: + submodule_name = _DEPRECATED_MAPPINGS[name] + warnings.warn( + f"vllm.utils.{name} is deprecated and will be removed in a future version. " + f"Use vllm.utils.{submodule_name}.{name} instead.", + DeprecationWarning, + stacklevel=2, + ) + module = __import__(f"vllm.utils.{submodule_name}", fromlist=[submodule_name]) + return getattr(module, name) + raise AttributeError(f"module {__name__!r} has no attribute {name!r}") + + +def __dir__() -> list[str]: + # expose deprecated names in dir() for better UX/tab-completion + return sorted(list(globals().keys()) + list(_DEPRECATED_MAPPINGS.keys())) + + +logger = init_logger(__name__) + +# This value is chosen to have a balance between ITL and TTFT. Note it is +# not optimized for throughput. 
+DEFAULT_MAX_NUM_BATCHED_TOKENS = 2048 +POOLING_MODEL_MAX_NUM_BATCHED_TOKENS = 32768 +MULTIMODAL_MODEL_MAX_NUM_BATCHED_TOKENS = 5120 + +# Constants related to forcing the attention backend selection + +# String name of register which may be set in order to +# force auto-selection of attention backend by Attention +# wrapper +STR_BACKEND_ENV_VAR: str = "VLLM_ATTENTION_BACKEND" + +# Possible string values of STR_BACKEND_ENV_VAR +# register, corresponding to possible backends +STR_FLASHINFER_ATTN_VAL: str = "FLASHINFER" +STR_XFORMERS_ATTN_VAL: str = "XFORMERS" +STR_FLASH_ATTN_VAL: str = "FLASH_ATTN" +STR_INVALID_VAL: str = "INVALID" + + +T = TypeVar("T") + + +def random_uuid() -> str: + return str(uuid.uuid4().hex) + + +def length_from_prompt_token_ids_or_embeds( + prompt_token_ids: list[int] | None, + prompt_embeds: torch.Tensor | None, +) -> int: + """Calculate the request length (in number of tokens) give either + prompt_token_ids or prompt_embeds. + """ + prompt_token_len = None if prompt_token_ids is None else len(prompt_token_ids) + prompt_embeds_len = None if prompt_embeds is None else len(prompt_embeds) + + if prompt_token_len is None: + if prompt_embeds_len is None: + raise ValueError("Neither prompt_token_ids nor prompt_embeds were defined.") + return prompt_embeds_len + else: + if prompt_embeds_len is not None and prompt_embeds_len != prompt_token_len: + raise ValueError( + "Prompt token ids and prompt embeds had different lengths" + f" prompt_token_ids={prompt_token_len}" + f" prompt_embeds={prompt_embeds_len}" + ) + return prompt_token_len \ No newline at end of file diff --git a/vllm/config.py b/vllm/config.py index aa13f22..330f4db 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -358,7 +358,7 @@ class ModelConfig: for multimodal models.""" use_async_output_proc: bool = True """Whether to use async output processor.""" - config_format: Union[str, ConfigFormat] = ConfigFormat.AUTO.value + config_format: Union[str, ConfigFormat] = "auto" """The format 
of the model config to load:\n - "auto" will try to load the config in hf format if available else it will try to load in mistral format.\n @@ -522,8 +522,8 @@ class ModelConfig: raise ValueError( "Sleep mode is not supported on current platform.") - if isinstance(self.config_format, str): - self.config_format = ConfigFormat(self.config_format) + # if isinstance(self.config_format, str): + # self.config_format = ConfigFormat(self.config_format) hf_config = get_config(self.hf_config_path or self.model, self.trust_remote_code, self.revision, diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 4ce1b41..6f3f106 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -522,7 +522,6 @@ class EngineArgs: help="Disable async output processing. This may result in " "lower performance.") model_group.add_argument("--config-format", - choices=[f.value for f in ConfigFormat], **model_kwargs["config_format"]) # This one is a special case because it can bool # or str. 
TODO: Handle this in get_kwargs diff --git a/vllm/model_executor/layers/quantization/gptq.py b/vllm/model_executor/layers/quantization/gptq.py index 492a9c9..5166c42 100644 --- a/vllm/model_executor/layers/quantization/gptq.py +++ b/vllm/model_executor/layers/quantization/gptq.py @@ -4,25 +4,43 @@ import enum from enum import Enum from fractions import Fraction -from typing import Any, Optional, Union +from typing import TYPE_CHECKING, Any, Union import torch +from safetensors.torch import _TYPES as _SAFETENSORS_TO_TORCH_DTYPE from torch.nn.parameter import Parameter from vllm import _custom_ops as ops +from vllm.logger import init_logger +from vllm.model_executor.layers.fused_moe.layer import FusedMoE from vllm.model_executor.layers.linear import LinearMethodBase -from vllm.model_executor.layers.quantization import QuantizationMethods from vllm.model_executor.layers.quantization.base_config import ( - QuantizationConfig) + QuantizationConfig, + QuantizeMethodBase, +) from vllm.model_executor.layers.quantization.utils.gptq_utils import ( - get_linear_quant_method) -from vllm.model_executor.parameter import (ChannelQuantScaleParameter, - GroupQuantScaleParameter, - PackedColumnParameter, - PackedvLLMParameter, - RowvLLMParameter) + get_linear_quant_method, +) +from vllm.model_executor.parameter import ( + ChannelQuantScaleParameter, + GroupQuantScaleParameter, + PackedColumnParameter, + PackedvLLMParameter, + RowvLLMParameter, +) +from vllm.transformers_utils.config import get_safetensors_params_metadata +from vllm.utils import is_list_of + +if TYPE_CHECKING: + from vllm.model_executor.layers.quantization import QuantizationMethods + from vllm.model_executor.models.utils import WeightsMapper +else: + QuantizationMethods = str + +logger = init_logger(__name__) + class GPTQConfig(QuantizationConfig): """Config class for GPTQ. 
@@ -35,7 +53,10 @@ class GPTQConfig(QuantizationConfig): group_size: int, desc_act: bool, lm_head_quantized: bool, - dynamic: dict[str, dict[str, Union[int, bool]]], + dynamic: dict[str, dict[str, int | bool]], + autoround_version: str = "", + modules_in_block_to_quantize: list[str] | None = None, + checkpoint_format: str = "", ) -> None: # GPTQModel use `dynamic` config property to allow per module # quantization config so each module can be individually optimized. @@ -71,23 +92,44 @@ class GPTQConfig(QuantizationConfig): if self.weight_bits not in [2, 3, 4, 8]: raise ValueError( "Currently, only 2/3/4/8-bit weight quantization is " - f"supported for GPTQ, but got {self.weight_bits} bits.") + f"supported for GPTQ, but got {self.weight_bits} bits." + ) + # Somehow gptq_gemm 4-bit is buggy, maybe fix it in the future. + # For now, show a warning, since gptq_marlin will be used by default. + if self.weight_bits == 4: + logger.warning_once( + "Currently, the 4-bit gptq_gemm kernel for GPTQ is buggy. " + "Please switch to gptq_marlin or gptq_bitblas." + ) + + self.modules_in_block_to_quantize = modules_in_block_to_quantize or [] + + # used to identify GPTQ model quantized by autoround + self.autoround_version = autoround_version + + # GPTQ v1 and v2 format deals with zero points differently. + # Currently GPTQModel stores v1 format checkpoints by default, + # but provides the option to set `format="gptq_v2"` in `QuantizeConfig`. 
+ self.checkpoint_format = checkpoint_format def __repr__(self) -> str: - return (f"GPTQConfig(weight_bits={self.weight_bits}, " - f"group_size={self.group_size}, " - f"desc_act={self.desc_act}), " - f"lm_head_quantized={self.lm_head_quantized}), " - f"dynamic={self.dynamic}") + return ( + f"GPTQConfig(weight_bits={self.weight_bits}, " + f"group_size={self.group_size}, " + f"desc_act={self.desc_act}), " + f"lm_head_quantized={self.lm_head_quantized}, " + f"dynamic={self.dynamic}, " + f"modules_in_block_to_quantize={self.modules_in_block_to_quantize}), " + f"checkpoint_format={self.checkpoint_format})" + ) @classmethod def get_name(cls) -> QuantizationMethods: return "gptq" @classmethod - def get_supported_act_dtypes(cls) -> list[torch.dtype]: - return [torch.half, torch.bfloat16] + return [torch.half] @classmethod # Need to figure it out @@ -106,18 +148,77 @@ class GPTQConfig(QuantizationConfig): weight_bits = cls.get_from_keys(config, ["bits"]) group_size = cls.get_from_keys(config, ["group_size"]) desc_act = cls.get_from_keys(config, ["desc_act"]) - lm_head_quantized = cls.get_from_keys_or(config, ["lm_head"], - default=False) - return cls(weight_bits, group_size, desc_act, lm_head_quantized, - dynamic) + lm_head_quantized = cls.get_from_keys_or(config, ["lm_head"], default=False) + autoround_version = cls.get_from_keys_or( + config, ["autoround_version"], default="" + ) + modules_in_block_to_quantize = cls.get_from_keys_or( + config, ["modules_in_block_to_quantize"], default=None + ) + checkpoint_format = cls.get_from_keys_or( + config, ["checkpoint_format"], default="" + ) + return cls( + weight_bits, + group_size, + desc_act, + lm_head_quantized, + dynamic, + autoround_version, + modules_in_block_to_quantize, + checkpoint_format, + ) + + def get_quant_method( + self, layer: torch.nn.Module, prefix: str + ) -> Union["GPTQLinearMethod", "QuantizeMethodBase"] | None: + if isinstance(layer, FusedMoE): + # GPTQ MoE support: fall back to MoeWNA16 for broad 
compatibility + from .moe_wna16 import MoeWNA16Config + + print("Using MoeWNA16Config for GPTQ MoE layer quantization.") + # TODO: maybe update this for GPTQv2 format checkpoints + config = { + "quant_method": "gptq", + "bits": self.weight_bits, + "group_size": self.group_size, + "sym": True, # GPTQ typically uses symmetric quantization + "lm_head": False, + } + return MoeWNA16Config.from_config(config).get_quant_method(layer, prefix) - def get_quant_method(self, layer: torch.nn.Module, - prefix: str) -> Optional["GPTQLinearMethod"]: return get_linear_quant_method(self, layer, prefix, GPTQLinearMethod) + def apply_vllm_mapper(self, hf_to_vllm_mapper: "WeightsMapper"): + if self.modules_in_block_to_quantize is not None: + self.modules_in_block_to_quantize = hf_to_vllm_mapper.apply_list( + self.modules_in_block_to_quantize + ) + + def maybe_update_config(self, model_name: str, revision: str | None = None): + if self.modules_in_block_to_quantize: + if is_list_of(self.modules_in_block_to_quantize, list): + # original modules_in_block_to_quantize: list[list[str]] + # flatten original modules_in_block_to_quantize + self.modules_in_block_to_quantize = [ + item + for sublist in self.modules_in_block_to_quantize + for item in sublist + ] + return + + unquant_dtypes = [torch.float16, torch.bfloat16, torch.float32] + metadata = get_safetensors_params_metadata(model_name, revision=revision) + quant_layers: set[str] = { + param_name.rsplit(".", 1)[0] + for param_name, info in metadata.items() + if (dtype := info.get("dtype", None)) + and _SAFETENSORS_TO_TORCH_DTYPE[dtype] not in unquant_dtypes + } + self.modules_in_block_to_quantize = list(quant_layers) + class ExllamaState(Enum): - UNUSED = enum.auto() UNINITIALIZED = enum.auto() READY = enum.auto() @@ -133,6 +234,9 @@ class GPTQLinearMethod(LinearMethodBase): def __init__(self, quant_config: GPTQConfig): self.quant_config = quant_config + # GPTQ v1 and v2 format deals with zero points differently + self.use_v2_format = 
quant_config.checkpoint_format == "gptq_v2" + def create_weights( self, layer: torch.nn.Module, @@ -149,14 +253,15 @@ class GPTQLinearMethod(LinearMethodBase): raise ValueError( "The input size is not aligned with the quantized " "weight shape. This can be caused by too large " - "tensor parallel size.") + "tensor parallel size." + ) output_size_per_partition = sum(output_partition_sizes) - if (output_size_per_partition % self.quant_config.pack_factor.numerator - != 0): + if output_size_per_partition % self.quant_config.pack_factor.numerator != 0: raise ValueError( "The output size is not aligned with the quantized " "weight shape. This can be caused by too large " - "tensor parallel size.") + "tensor parallel size." + ) if self.quant_config.group_size != -1: group_size = self.quant_config.group_size @@ -165,8 +270,10 @@ class GPTQLinearMethod(LinearMethodBase): exllama_state = ExllamaState.UNINITIALIZED scale_and_zero_size = input_size // group_size scale_and_zero_input_dim = None - if (input_size != input_size_per_partition - and self.quant_config.group_size != -1): + if ( + input_size != input_size_per_partition + and self.quant_config.group_size != -1 + ): # For act-order models, we cannot use Exllama for row parallel layer if self.quant_config.desc_act: exllama_state = ExllamaState.UNUSED @@ -185,56 +292,56 @@ class GPTQLinearMethod(LinearMethodBase): output_dim=1, packed_dim=0, packed_factor=self.quant_config.pack_factor, - weight_loader=weight_loader) + weight_loader=weight_loader, + ) - g_idx = RowvLLMParameter(data=torch.tensor( - [ - i // self.quant_config.group_size - for i in range(input_size_per_partition) - ], - dtype=torch.int32, - ), - input_dim=0, - weight_loader=weight_loader) + g_idx = RowvLLMParameter( + data=torch.tensor( + [ + i // self.quant_config.group_size + for i in range(input_size_per_partition) + ], + dtype=torch.int32, + ), + input_dim=0, + weight_loader=weight_loader, + ) qzeros_args = { - "data": - torch.empty( + "data": 
torch.empty( scale_and_zero_size, output_size_per_partition // self.quant_config.pack_factor, dtype=torch.int32, ), - "weight_loader": - weight_loader + "weight_loader": weight_loader, } weight_scale_args = { - "data": - torch.empty( + "data": torch.empty( scale_and_zero_size, output_size_per_partition, dtype=params_dtype, ), - "weight_loader": - weight_loader + "weight_loader": weight_loader, } if scale_and_zero_input_dim is None: - scales = ChannelQuantScaleParameter(output_dim=1, - **weight_scale_args) + scales = ChannelQuantScaleParameter(output_dim=1, **weight_scale_args) qzeros = PackedColumnParameter( output_dim=1, packed_dim=1, packed_factor=self.quant_config.pack_factor, - **qzeros_args) + **qzeros_args, + ) else: - scales = GroupQuantScaleParameter(output_dim=1, - input_dim=0, - **weight_scale_args) + scales = GroupQuantScaleParameter( + output_dim=1, input_dim=0, **weight_scale_args + ) qzeros = PackedvLLMParameter( input_dim=0, output_dim=1, packed_dim=1, packed_factor=self.quant_config.pack_factor, - **qzeros_args) + **qzeros_args, + ) layer.register_parameter("qweight", qweight) layer.register_parameter("g_idx", g_idx) @@ -252,79 +359,23 @@ class GPTQLinearMethod(LinearMethodBase): # exllama needs to shuffle the weight after the weight is loaded # here we do the shuffle on first forward pass - if self.quant_config.group_size == 128 or self.quant_config.group_size == 64: + if layer.exllama_state == ExllamaState.UNINITIALIZED: if self.quant_config.desc_act: layer.g_idx.data = torch.argsort(layer.g_idx).to(torch.int) else: - layer.g_idx.data = torch.empty((0, ), - dtype=torch.int, - device=layer.g_idx.device) + layer.g_idx.data = torch.empty( + (0,), dtype=torch.int, device=layer.g_idx.device + ) layer.exllama_state = ExllamaState.READY - ops.gptq_shuffle(layer.qweight, layer.g_idx, - self.quant_config.weight_bits) - - if layer.scales.dtype != torch.bfloat16: - perm_space = torch.empty(0) - temp_space = torch.empty(0) - if self.quant_config.weight_bits 
== 4: - # warmup - reshaped_x = torch.randn(1, layer.qweight.shape[0]*8, dtype=layer.scales.dtype, device="cuda") - _ = ops.gptq_gemm(reshaped_x, layer.qweight, layer.qzeros, - layer.scales, layer.g_idx, - layer.exllama_state == ExllamaState.READY, - self.quant_config.weight_bits, - self.quant_config.group_size, - perm_space, temp_space, - False) - if self.quant_config.weight_bits == 8: - # warmup - reshaped_x = torch.randn(1, layer.qweight.shape[0]*4, dtype=layer.scales.dtype, device="cuda") - _ = ops.gptq_gemm(reshaped_x, layer.qweight, layer.qzeros, - layer.scales, layer.g_idx, - layer.exllama_state == ExllamaState.READY, - self.quant_config.weight_bits, - self.quant_config.group_size, - perm_space, temp_space, - False) - else: - if layer.exllama_state == ExllamaState.UNINITIALIZED: - if self.quant_config.desc_act: - layer.g_idx.data = torch.argsort(layer.g_idx).to(torch.int) - else: - layer.g_idx.data = torch.empty((0, ), - dtype=torch.int, - device=layer.g_idx.device) - layer.exllama_state = ExllamaState.READY - ops.gptq_shuffle(layer.qweight, layer.g_idx, - self.quant_config.weight_bits) + ops.gptq_shuffle(layer.qweight, layer.g_idx, self.quant_config.weight_bits) - """ - perm_space = torch.empty(0) - if self.quant_config.weight_bits == 4: - # warmup - reshaped_x = torch.randn(1, layer.qweight.shape[0]*8, dtype=layer.scales.dtype, device="cuda") - _ = ops.gptq_gemm(reshaped_x, layer.qweight, layer.qzeros, - layer.scales, layer.g_idx, - layer.exllama_state == ExllamaState.READY, - self.quant_config.weight_bits, - self.quant_config.group_size, - perm_space) - if self.quant_config.weight_bits == 8: - # warmup - reshaped_x = torch.randn(1, layer.qweight.shape[0]*4, dtype=layer.scales.dtype, device="cuda") - _ = ops.gptq_gemm(reshaped_x, layer.qweight, layer.qzeros, - layer.scales, layer.g_idx, - layer.exllama_state == ExllamaState.READY, - self.quant_config.weight_bits, - self.quant_config.group_size, - perm_space) - """ - - def apply(self, - layer: 
torch.nn.Module, - x: torch.Tensor, - bias: Optional[torch.Tensor] = None) -> torch.Tensor: - out_shape = x.shape[:-1] + (layer.qweight.shape[-1], ) + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: torch.Tensor | None = None, + ) -> torch.Tensor: + out_shape = x.shape[:-1] + (layer.qweight.shape[-1],) reshaped_x = x.reshape(-1, x.shape[-1]) perm_space = torch.empty(0) @@ -334,11 +385,12 @@ class GPTQLinearMethod(LinearMethodBase): if self.quant_config.desc_act: perm_space = torch.empty(reshaped_x.shape[0], reshaped_x.shape[1], dtype=torch.float16, device="cuda") - + if reshaped_x.dtype == torch.bfloat16: temp_space = torch.zeros(reshaped_x.shape[0], layer.qweight.shape[1], dtype=torch.float32, device="cuda") - + # GPTQ v1 and v2 format checkpoints deals with zero points differently, + # and require different gemm kernels. output = ops.gptq_gemm(reshaped_x, layer.qweight, layer.qzeros, layer.scales, layer.g_idx, layer.exllama_state == ExllamaState.READY, @@ -348,4 +400,4 @@ class GPTQLinearMethod(LinearMethodBase): True if reshaped_x.dtype == torch.bfloat16 else False) if bias is not None: output.add_(bias) - return output.reshape(out_shape) + return output.reshape(out_shape) \ No newline at end of file diff --git a/vllm/model_executor/layers/quantization/moe_wna16.py b/vllm/model_executor/layers/quantization/moe_wna16.py index daae03c..e6322ef 100644 --- a/vllm/model_executor/layers/quantization/moe_wna16.py +++ b/vllm/model_executor/layers/quantization/moe_wna16.py @@ -298,6 +298,10 @@ class MoeWNA16Method(FusedMoEMethodBase): e_score_correction_bias: Optional[torch.Tensor] = None, apply_router_weight_on_input: bool = False, activation: str = "silu", + enable_eplb: bool = False, + expert_load_view: torch.Tensor | None = None, + logical_to_physical_map: torch.Tensor | None = None, + logical_replica_count: torch.Tensor | None = None, ) -> torch.Tensor: from vllm.model_executor.layers.fused_moe import fused_experts assert activation == 
"silu", "Only SiLU activation is supported." diff --git a/vllm/model_executor/models/qwen3_moe.py b/vllm/model_executor/models/qwen3_moe.py index 823197f..22c09cd 100644 --- a/vllm/model_executor/models/qwen3_moe.py +++ b/vllm/model_executor/models/qwen3_moe.py @@ -122,7 +122,7 @@ class Qwen3MoeSparseMoeBlock(nn.Module): self.gate = ReplicatedLinear(config.hidden_size, config.num_experts, bias=False, - quant_config=None, + quant_config=quant_config, prefix=f"{prefix}.gate") def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: @@ -294,7 +294,7 @@ class Qwen3MoeDecoderLayer(nn.Module): positions: torch.Tensor, hidden_states: torch.Tensor, residual: Optional[torch.Tensor], - ) -> torch.Tensor: + ) -> tuple[torch.Tensor, torch.Tensor]: # Self Attention if residual is None: residual = hidden_states @@ -532,4 +532,4 @@ class Qwen3MoeForCausalLM(nn.Module, SupportsPP): def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: loader = AutoWeightsLoader(self) - return loader.load_weights(weights) + return loader.load_weights(weights) \ No newline at end of file diff --git a/vllm/transformers_utils/config.py b/vllm/transformers_utils/config.py index 52a7a90..b7418cf 100644 --- a/vllm/transformers_utils/config.py +++ b/vllm/transformers_utils/config.py @@ -1,51 +1,45 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -import enum import json import os import time +from collections.abc import Callable +from dataclasses import asdict from functools import cache, partial from pathlib import Path -from typing import Any, Callable, Literal, Optional, TypeVar, Union +from typing import Any, Literal, TypeVar import huggingface_hub -from huggingface_hub import get_safetensors_metadata, hf_hub_download +from huggingface_hub import ( + get_safetensors_metadata, + hf_hub_download, + try_to_load_from_cache, +) from huggingface_hub import list_repo_files as hf_list_repo_files -from 
huggingface_hub import try_to_load_from_cache -from huggingface_hub.utils import (EntryNotFoundError, HfHubHTTPError, - HFValidationError, LocalEntryNotFoundError, - RepositoryNotFoundError, - RevisionNotFoundError) -from torch import nn -from transformers import GenerationConfig, PretrainedConfig -from transformers.models.auto.image_processing_auto import ( - get_image_processor_config) +from huggingface_hub.utils import ( + EntryNotFoundError, + HfHubHTTPError, + LocalEntryNotFoundError, + RepositoryNotFoundError, + RevisionNotFoundError, +) +from transformers import DeepseekV3Config, GenerationConfig, PretrainedConfig +from transformers.models.auto.image_processing_auto import get_image_processor_config from transformers.models.auto.modeling_auto import ( - MODEL_FOR_CAUSAL_LM_MAPPING_NAMES) + MODEL_FOR_CAUSAL_LM_MAPPING_NAMES, + MODEL_MAPPING_NAMES, +) from transformers.models.auto.tokenization_auto import get_tokenizer_config from transformers.utils import CONFIG_NAME as HF_CONFIG_NAME from vllm import envs from vllm.logger import init_logger -# yapf conflicts with isort for this block -# yapf: disable -from vllm.transformers_utils.configs import (ChatGLMConfig, Cohere2Config, - DbrxConfig, DeepseekVLV2Config, - EAGLEConfig, ExaoneConfig, - H2OVLChatConfig, - InternVLChatConfig, JAISConfig, - KimiVLConfig, MedusaConfig, - MiniMaxText01Config, - MiniMaxVL01Config, MllamaConfig, - MLPSpeculatorConfig, MPTConfig, - NemotronConfig, NVLM_D_Config, - OvisConfig, RWConfig, - SkyworkR1VChatConfig, SolarConfig, - Telechat2Config, UltravoxConfig) -# yapf: enable -from vllm.transformers_utils.utils import check_gguf_file -from vllm.utils import resolve_obj_by_qualname +from vllm.transformers_utils.config_parser_base import ConfigParserBase +from vllm.transformers_utils.utils import ( + check_gguf_file, + parse_safetensors_file_metadata, +) if envs.VLLM_USE_MODELSCOPE: from modelscope import AutoConfig @@ -56,43 +50,241 @@ MISTRAL_CONFIG_NAME = "params.json" logger = 
init_logger(__name__) -_CONFIG_REGISTRY_OVERRIDE_HF: dict[str, type[PretrainedConfig]] = { - "mllama": MllamaConfig + +def _get_hf_token() -> str | None: + """ + Get the HuggingFace token from environment variable. + + Returns None if the token is not set, is an empty string, + or contains only whitespace. + This follows the same pattern as huggingface_hub library which + treats empty string tokens as None to avoid authentication errors. + """ + token = os.getenv("HF_TOKEN") + if token and token.strip(): + return token + return None + + +class LazyConfigDict(dict): + def __getitem__(self, key): + if isinstance(value := super().__getitem__(key), type): + return value + + import vllm.transformers_utils.configs as configs + + return getattr(configs, value) + + +_CONFIG_REGISTRY: dict[str, type[PretrainedConfig]] = LazyConfigDict( + chatglm="ChatGLMConfig", + deepseek_vl_v2="DeepseekVLV2Config", + deepseek_v32=DeepseekV3Config, + flex_olmo="FlexOlmoConfig", + kimi_linear="KimiLinearConfig", + kimi_vl="KimiVLConfig", + RefinedWeb="RWConfig", # For tiiuae/falcon-40b(-instruct) + RefinedWebModel="RWConfig", # For tiiuae/falcon-7b(-instruct) + jais="JAISConfig", + mlp_speculator="MLPSpeculatorConfig", + medusa="MedusaConfig", + midashenglm="MiDashengLMConfig", + eagle="EAGLEConfig", + speculators="SpeculatorsConfig", + nemotron="NemotronConfig", + olmo3="Olmo3Config", + ovis="OvisConfig", + ultravox="UltravoxConfig", + step3_vl="Step3VLConfig", + step3_text="Step3TextConfig", + qwen3_next="Qwen3NextConfig", + lfm2_moe="Lfm2MoeConfig", +) + +_CONFIG_ATTRS_MAPPING: dict[str, str] = { + "llm_config": "text_config", } -_CONFIG_REGISTRY: dict[str, type[PretrainedConfig]] = { - "chatglm": ChatGLMConfig, - "cohere2": Cohere2Config, - "dbrx": DbrxConfig, - "deepseek_vl_v2": DeepseekVLV2Config, - "kimi_vl": KimiVLConfig, - "mpt": MPTConfig, - "RefinedWeb": RWConfig, # For tiiuae/falcon-40b(-instruct) - "RefinedWebModel": RWConfig, # For tiiuae/falcon-7b(-instruct) - "jais": 
JAISConfig, - "mlp_speculator": MLPSpeculatorConfig, - "medusa": MedusaConfig, - "eagle": EAGLEConfig, - "exaone": ExaoneConfig, - "h2ovl_chat": H2OVLChatConfig, - "internvl_chat": InternVLChatConfig, - "minimax_text_01": MiniMaxText01Config, - "minimax_vl_01": MiniMaxVL01Config, - "nemotron": NemotronConfig, - "NVLM_D": NVLM_D_Config, - "ovis": OvisConfig, - "solar": SolarConfig, - "skywork_chat": SkyworkR1VChatConfig, - "telechat": Telechat2Config, - "ultravox": UltravoxConfig, - **_CONFIG_REGISTRY_OVERRIDE_HF +_AUTO_CONFIG_KWARGS_OVERRIDES: dict[str, dict[str, Any]] = { + "internvl_chat": {"has_no_defaults_at_init": True}, + "Llama_Nemotron_Nano_VL": {"attn_implementation": "eager"}, + "NVLM_D": {"has_no_defaults_at_init": True}, } -class ConfigFormat(str, enum.Enum): - AUTO = "auto" - HF = "hf" - MISTRAL = "mistral" +class HFConfigParser(ConfigParserBase): + def parse( + self, + model: str | Path, + trust_remote_code: bool, + revision: str | None = None, + code_revision: str | None = None, + **kwargs, + ) -> tuple[dict, PretrainedConfig]: + kwargs["local_files_only"] = huggingface_hub.constants.HF_HUB_OFFLINE + config_dict, _ = PretrainedConfig.get_config_dict( + model, + revision=revision, + code_revision=code_revision, + token=_get_hf_token(), + **kwargs, + ) + # Use custom model class if it's in our registry + model_type = config_dict.get("model_type") + if model_type is None: + model_type = ( + "speculators" + if config_dict.get("speculators_config") is not None + else model_type + ) + + if model_type in _CONFIG_REGISTRY: + config_class = _CONFIG_REGISTRY[model_type] + config = config_class.from_pretrained( + model, + revision=revision, + code_revision=code_revision, + token=_get_hf_token(), + **kwargs, + ) + else: + try: + kwargs = _maybe_update_auto_config_kwargs(kwargs, model_type=model_type) + config = AutoConfig.from_pretrained( + model, + trust_remote_code=trust_remote_code, + revision=revision, + code_revision=code_revision, + token=_get_hf_token(), 
+ **kwargs, + ) + except ValueError as e: + if ( + not trust_remote_code + and "requires you to execute the configuration file" in str(e) + ): + err_msg = ( + "Failed to load the model config. If the model " + "is a custom model not yet available in the " + "HuggingFace transformers library, consider setting " + "`trust_remote_code=True` in LLM or using the " + "`--trust-remote-code` flag in the CLI." + ) + raise RuntimeError(err_msg) from e + else: + raise e + config = _maybe_remap_hf_config_attrs(config) + return config_dict, config + + +class MistralConfigParser(ConfigParserBase): + def parse( + self, + model: str | Path, + trust_remote_code: bool, + revision: str | None = None, + code_revision: str | None = None, + **kwargs, + ) -> tuple[dict, PretrainedConfig]: + # This function loads a params.json config which + # should be used when loading models in mistral format + config_dict = _download_mistral_config_file(model, revision) + if ( + max_position_embeddings := config_dict.get("max_position_embeddings") + ) is None: + max_position_embeddings = _maybe_retrieve_max_pos_from_hf( + model, revision, **kwargs + ) + config_dict["max_position_embeddings"] = max_position_embeddings + + from vllm.transformers_utils.configs.mistral import adapt_config_dict + + config = adapt_config_dict(config_dict) + + # Mistral configs may define sliding_window as list[int]. 
Convert it + # to int and add the layer_types list[str] to make it HF compatible + if (sliding_window := getattr(config, "sliding_window", None)) and isinstance( + sliding_window, list + ): + pattern_repeats = config.num_hidden_layers // len(sliding_window) + layer_types = sliding_window * pattern_repeats + config.layer_types = [ + "full_attention" if layer_type is None else "sliding_attention" + for layer_type in layer_types + ] + config.sliding_window = next(filter(None, sliding_window), None) + + return config_dict, config + + +_CONFIG_FORMAT_TO_CONFIG_PARSER: dict[str, type[ConfigParserBase]] = { + "hf": HFConfigParser, + "mistral": MistralConfigParser, +} + +ConfigFormat = Literal[ + "auto", + "hf", + "mistral", +] + + +def get_config_parser(config_format: str) -> ConfigParserBase: + """Get the config parser for a given config format.""" + if config_format not in _CONFIG_FORMAT_TO_CONFIG_PARSER: + raise ValueError(f"Unknown config format `{config_format}`.") + return _CONFIG_FORMAT_TO_CONFIG_PARSER[config_format]() + + +def register_config_parser(config_format: str): + """Register a customized vllm config parser. + When a config format is not supported by vllm, you can register a customized + config parser to support it. + Args: + config_format (str): The config parser format name. + Examples: + + >>> from vllm.transformers_utils.config import (get_config_parser, + register_config_parser) + >>> from vllm.transformers_utils.config_parser_base import ConfigParserBase + >>> + >>> @register_config_parser("custom_config_parser") + ... class CustomConfigParser(ConfigParserBase): + ... def parse( + ... self, + ... model: Union[str, Path], + ... trust_remote_code: bool, + ... revision: str | None = None, + ... code_revision: str | None = None, + ... **kwargs, + ... ) -> tuple[dict, PretrainedConfig]: + ... 
raise NotImplementedError + >>> + >>> type(get_config_parser("custom_config_parser")) + + """ # noqa: E501 + + def _wrapper(config_parser_cls): + if config_format in _CONFIG_FORMAT_TO_CONFIG_PARSER: + logger.warning( + "Config format `%s` is already registered, and will be " + "overwritten by the new parser class `%s`.", + config_format, + config_parser_cls, + ) + if not issubclass(config_parser_cls, ConfigParserBase): + raise ValueError( + "The config parser must be a subclass of `ConfigParserBase`." + ) + _CONFIG_FORMAT_TO_CONFIG_PARSER[config_format] = config_parser_cls + logger.info( + "Registered config parser `%s` with config format `%s`", + config_parser_cls, + config_format, + ) + return config_parser_cls + + return _wrapper _R = TypeVar("_R") @@ -111,8 +303,9 @@ def with_retry( if attempt == max_retries - 1: logger.error("%s: %s", log_msg, e) raise - logger.error("%s: %s, retrying %d of %d", log_msg, e, attempt + 1, - max_retries) + logger.error( + "%s: %s, retrying %d of %d", log_msg, e, attempt + 1, max_retries + ) time.sleep(retry_delay) retry_delay *= 2 @@ -124,32 +317,31 @@ def with_retry( def list_repo_files( repo_id: str, *, - revision: Optional[str] = None, - repo_type: Optional[str] = None, - token: Union[str, bool, None] = None, + revision: str | None = None, + repo_type: str | None = None, + token: str | bool | None = None, ) -> list[str]: - def lookup_files() -> list[str]: # directly list files if model is local if (local_path := Path(repo_id)).exists(): return [ str(file.relative_to(local_path)) - for file in local_path.rglob('*') if file.is_file() + for file in local_path.rglob("*") + if file.is_file() ] # if model is remote, use hf_hub api to list files try: if envs.VLLM_USE_MODELSCOPE: - from vllm.transformers_utils.utils import ( - modelscope_list_repo_files) - return modelscope_list_repo_files(repo_id, - revision=revision, - token=os.getenv( - "MODELSCOPE_API_TOKEN", - None)) - return hf_list_repo_files(repo_id, - revision=revision, - 
repo_type=repo_type, - token=token) + from vllm.transformers_utils.utils import modelscope_list_repo_files + + return modelscope_list_repo_files( + repo_id, + revision=revision, + token=os.getenv("MODELSCOPE_API_TOKEN", None), + ) + return hf_list_repo_files( + repo_id, revision=revision, repo_type=repo_type, token=token + ) except huggingface_hub.errors.OfflineModeIsEnabled: # Don't raise in offline mode, # all we know is that we don't have this @@ -163,27 +355,27 @@ def file_exists( repo_id: str, file_name: str, *, - repo_type: Optional[str] = None, - revision: Optional[str] = None, - token: Union[str, bool, None] = None, + repo_type: str | None = None, + revision: str | None = None, + token: str | bool | None = None, ) -> bool: - file_list = list_repo_files(repo_id, - repo_type=repo_type, - revision=revision, - token=token) + file_list = list_repo_files( + repo_id, repo_type=repo_type, revision=revision, token=token + ) return file_name in file_list # In offline mode the result can be a false negative -def file_or_path_exists(model: Union[str, Path], config_name: str, - revision: Optional[str]) -> bool: +def file_or_path_exists( + model: str | Path, config_name: str, revision: str | None +) -> bool: if (local_path := Path(model)).exists(): return (local_path / config_name).is_file() # Offline mode support: Check if config file is cached already - cached_filepath = try_to_load_from_cache(repo_id=model, - filename=config_name, - revision=revision) + cached_filepath = try_to_load_from_cache( + repo_id=model, filename=config_name, revision=revision + ) if isinstance(cached_filepath, str): # The config file exists in cache- we can continue trying to load return True @@ -192,10 +384,9 @@ def file_or_path_exists(model: Union[str, Path], config_name: str, # hf_hub. This will fail in offline mode. 
# Call HF to check if the file exists - return file_exists(str(model), - config_name, - revision=revision, - token=os.getenv('HF_TOKEN', None)) + return file_exists( + str(model), config_name, revision=revision, token=_get_hf_token() + ) def patch_rope_scaling(config: PretrainedConfig) -> None: @@ -217,7 +408,8 @@ def patch_rope_scaling_dict(rope_scaling: dict[str, Any]) -> None: raise ValueError( f"Found conflicts between 'rope_type={rope_type}' (modern " f"field) and 'type={rope_type_legacy}' (legacy field). " - "You should only specify one of them.") + "You should only specify one of them." + ) if "rope_type" not in rope_scaling and "type" in rope_scaling: rope_scaling["rope_type"] = rope_scaling["type"] @@ -245,7 +437,11 @@ def _uses_mrope(config: PretrainedConfig) -> bool: def uses_mrope(config: PretrainedConfig) -> bool: """Detect if the model with this config uses M-ROPE.""" - return _uses_mrope(config) or thinker_uses_mrope(config) + return ( + _uses_mrope(config) + or _uses_mrope(config.get_text_config()) + or thinker_uses_mrope(config) + ) def thinker_uses_mrope(config: PretrainedConfig) -> bool: @@ -263,19 +459,111 @@ def thinker_uses_mrope(config: PretrainedConfig) -> bool: def is_encoder_decoder(config: PretrainedConfig) -> bool: """Detect if the model with this config is used as an encoder/decoder.""" - text_config = getattr(config, "text_config", None) - if text_config is not None: - return is_encoder_decoder(text_config) - return getattr(config, "is_encoder_decoder", False) + def _is_encoder_decoder(config: PretrainedConfig) -> bool: + return getattr(config, "is_encoder_decoder", False) + + return _is_encoder_decoder(config) or _is_encoder_decoder(config.get_text_config()) + + +def is_interleaved(config: PretrainedConfig) -> bool: + """ + Detect if the model with this config is used with interleaved attention. 
+ """ + text_config = config.get_text_config() + if layer_types := getattr(text_config, "layer_types", None): + return len(set(layer_types)) > 1 + return False + + +def _maybe_update_auto_config_kwargs(kwargs: dict[str, Any], model_type: str): + """ + Update kwargs for AutoConfig initialization based on model_type + """ + if model_type in _AUTO_CONFIG_KWARGS_OVERRIDES: + kwargs.update(_AUTO_CONFIG_KWARGS_OVERRIDES[model_type]) + return kwargs + + +def _maybe_remap_hf_config_attrs(config: PretrainedConfig) -> PretrainedConfig: + """Remap config attributes to match the expected names.""" + for old_attr, new_attr in _CONFIG_ATTRS_MAPPING.items(): + if hasattr(config, old_attr): + if not hasattr(config, new_attr): + config.update({new_attr: getattr(config, old_attr)}) + logger.debug("Remapped config attribute '%s' to '%s'", old_attr, new_attr) + return config + + +def maybe_override_with_speculators( + model: str, + tokenizer: str, + trust_remote_code: bool, + revision: str | None = None, + vllm_speculative_config: dict[str, Any] | None = None, + **kwargs, +) -> tuple[str, str, dict[str, Any] | None]: + """ + Resolve model configuration when speculators are detected. + + Checks if the provided model is a speculators model and if so, extracts + the target model configuration and builds the speculative config. 
+ + Args: + model: Model name or path + tokenizer: Tokenizer name or path + trust_remote_code: Whether to trust remote code + revision: Model revision + vllm_speculative_config: Existing vLLM speculative config + + Returns: + Tuple of (resolved_model, resolved_tokenizer, speculative_config) + """ + is_gguf = check_gguf_file(model) + if is_gguf: + kwargs["gguf_file"] = Path(model).name + gguf_model_repo = Path(model).parent + else: + gguf_model_repo = None + kwargs["local_files_only"] = huggingface_hub.constants.HF_HUB_OFFLINE + config_dict, _ = PretrainedConfig.get_config_dict( + model if gguf_model_repo is None else gguf_model_repo, + revision=revision, + trust_remote_code=trust_remote_code, + token=_get_hf_token(), + **kwargs, + ) + speculators_config = config_dict.get("speculators_config") + + if speculators_config is None: + # No speculators config found, return original values + return model, tokenizer, vllm_speculative_config + + # Speculators format detected - process overrides + from vllm.transformers_utils.configs.speculators.base import SpeculatorsConfig + + speculative_config = SpeculatorsConfig.extract_vllm_speculative_config( + config_dict=config_dict + ) + + # Set the draft model to the speculators model + speculative_config["model"] = model + + # Override model and tokenizer with the verifier model from config + verifier_model = speculators_config["verifier"]["name_or_path"] + model = tokenizer = verifier_model + + return model, tokenizer, speculative_config def get_config( - model: Union[str, Path], + model: str | Path, trust_remote_code: bool, - revision: Optional[str] = None, - code_revision: Optional[str] = None, - config_format: ConfigFormat = ConfigFormat.AUTO, + revision: str | None = None, + code_revision: str | None = None, + config_format: str | ConfigFormat = "auto", + hf_overrides_kw: dict[str, Any] | None = None, + hf_overrides_fn: Callable[[PretrainedConfig], PretrainedConfig] | None = None, **kwargs, ) -> PretrainedConfig: # Separate 
model folder from file path for GGUF models @@ -285,20 +573,20 @@ def get_config( kwargs["gguf_file"] = Path(model).name model = Path(model).parent - if config_format == ConfigFormat.AUTO: + if config_format == "auto": try: - if is_gguf or file_or_path_exists( - model, HF_CONFIG_NAME, revision=revision): - config_format = ConfigFormat.HF - elif file_or_path_exists(model, - MISTRAL_CONFIG_NAME, - revision=revision): - config_format = ConfigFormat.MISTRAL + if is_gguf or file_or_path_exists(model, HF_CONFIG_NAME, revision=revision): + config_format = "hf" + elif file_or_path_exists(model, MISTRAL_CONFIG_NAME, revision=revision): + config_format = "mistral" else: raise ValueError( "Could not detect config format for no config file found. " - "Ensure your model has either config.json (HF format) " - "or params.json (Mistral format).") + "With config_format 'auto', ensure your model has either " + "config.json (HF format) or params.json (Mistral format). " + "Otherwise please specify your_custom_config_format " + "in engine args for customized config parser." + ) except Exception as e: error_message = ( @@ -313,74 +601,83 @@ def get_config( "'params.json'.\n" "3. 
For GGUF: pass the local path of the GGUF checkpoint.\n" " Loading GGUF from a remote repo directly is not yet " - "supported.\n").format(model=model) + "supported.\n" + ).format(model=model) raise ValueError(error_message) from e - if config_format == ConfigFormat.HF: - config_dict, _ = PretrainedConfig.get_config_dict( - model, - revision=revision, - code_revision=code_revision, - token=os.getenv('HF_TOKEN', None), - **kwargs, - ) - - # Use custom model class if it's in our registry - model_type = config_dict.get("model_type") - if model_type in _CONFIG_REGISTRY: - config_class = _CONFIG_REGISTRY[model_type] - config = config_class.from_pretrained( - model, - revision=revision, - code_revision=code_revision, - token=os.getenv('HF_TOKEN', None), - **kwargs, - ) - else: - try: - config = AutoConfig.from_pretrained( - model, - trust_remote_code=trust_remote_code, - revision=revision, - code_revision=code_revision, - token=os.getenv('HF_TOKEN', None), - **kwargs, - ) - except ValueError as e: - if (not trust_remote_code - and "requires you to execute the configuration file" - in str(e)): - err_msg = ( - "Failed to load the model config. If the model " - "is a custom model not yet available in the " - "HuggingFace transformers library, consider setting " - "`trust_remote_code=True` in LLM or using the " - "`--trust-remote-code` flag in the CLI.") - raise RuntimeError(err_msg) from e - else: - raise e - - elif config_format == ConfigFormat.MISTRAL: - config = load_params_config(model, revision, **kwargs) - else: - supported_formats = [ - fmt.value for fmt in ConfigFormat if fmt != ConfigFormat.AUTO - ] - raise ValueError( - f"Unsupported config format: {config_format}. " - f"Supported formats are: {', '.join(supported_formats)}. 
" - f"Ensure your model uses one of these configuration formats " - f"or specify the correct format explicitly.") - + config_parser = get_config_parser(config_format) + config_dict, config = config_parser.parse( + model, + trust_remote_code=trust_remote_code, + revision=revision, + code_revision=code_revision, + **kwargs, + ) # Special architecture mapping check for GGUF models if is_gguf: if config.model_type not in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES: - raise RuntimeError( - f"Can't get gguf config for {config.model_type}.") + raise RuntimeError(f"Can't get gguf config for {config.model_type}.") model_type = MODEL_FOR_CAUSAL_LM_MAPPING_NAMES[config.model_type] config.update({"architectures": [model_type]}) + # Architecture mapping for models without explicit architectures field + if not config.architectures: + if config.model_type not in MODEL_MAPPING_NAMES: + logger.warning( + "Model config does not have a top-level 'architectures' field: " + "expecting `hf_overrides={'architectures': ['...']}` to be passed " + "in engine args." + ) + else: + model_type = MODEL_MAPPING_NAMES[config.model_type] + config.update({"architectures": [model_type]}) + + # ModelOpt 0.31.0 and after saves the quantization config in the model + # config file. + quantization_config = config_dict.get("quantization_config", None) + + # ModelOpt 0.29.0 and before saves the quantization config in a separate + # "hf_quant_config.json" in the same directory as the model config file. 
+ if quantization_config is None and file_or_path_exists( + model, "hf_quant_config.json", revision + ): + quantization_config = get_hf_file_to_dict( + "hf_quant_config.json", model, revision + ) + + if quantization_config is not None: + config.quantization_config = quantization_config + # auto-enable DeepGEMM UE8M0 if model config requests it + scale_fmt = quantization_config.get("scale_fmt", None) + if scale_fmt in ("ue8m0",): + if not envs.is_set("VLLM_USE_DEEP_GEMM_E8M0"): + os.environ["VLLM_USE_DEEP_GEMM_E8M0"] = "1" + logger.info_once( + ( + "Detected quantization_config.scale_fmt=%s; " + "enabling UE8M0 for DeepGEMM." + ), + scale_fmt, + ) + elif not envs.VLLM_USE_DEEP_GEMM_E8M0: + logger.warning_once( + ( + "Model config requests UE8M0 " + "(quantization_config.scale_fmt=%s), but " + "VLLM_USE_DEEP_GEMM_E8M0=0 is set; " + "UE8M0 for DeepGEMM disabled." + ), + scale_fmt, + ) + + if hf_overrides_kw: + logger.debug("Overriding HF config with %s", hf_overrides_kw) + config.update(hf_overrides_kw) + if hf_overrides_fn: + logger.debug("Overriding HF config with %s", hf_overrides_fn) + config = hf_overrides_fn(config) + patch_rope_scaling(config) if trust_remote_code: @@ -389,27 +686,27 @@ def get_config( return config -def try_get_local_file(model: Union[str, Path], - file_name: str, - revision: Optional[str] = 'main') -> Optional[Path]: +def try_get_local_file( + model: str | Path, file_name: str, revision: str | None = "main" +) -> Path | None: file_path = Path(model) / file_name if file_path.is_file(): return file_path else: try: - cached_filepath = try_to_load_from_cache(repo_id=model, - filename=file_name, - revision=revision) + cached_filepath = try_to_load_from_cache( + repo_id=model, filename=file_name, revision=revision + ) if isinstance(cached_filepath, str): return Path(cached_filepath) - except HFValidationError: + except ValueError: ... 
return None -def get_hf_file_to_dict(file_name: str, - model: Union[str, Path], - revision: Optional[str] = 'main'): +def get_hf_file_to_dict( + file_name: str, model: str | Path, revision: str | None = "main" +): """ Downloads a file from the Hugging Face Hub and returns its contents as a dictionary. @@ -424,25 +721,27 @@ def get_hf_file_to_dict(file_name: str, the contents of the downloaded file. """ - file_path = try_get_local_file(model=model, - file_name=file_name, - revision=revision) + file_path = try_get_local_file(model=model, file_name=file_name, revision=revision) if file_path is None: try: hf_hub_file = hf_hub_download(model, file_name, revision=revision) except huggingface_hub.errors.OfflineModeIsEnabled: return None - except (RepositoryNotFoundError, RevisionNotFoundError, - EntryNotFoundError, LocalEntryNotFoundError) as e: + except ( + RepositoryNotFoundError, + RevisionNotFoundError, + EntryNotFoundError, + LocalEntryNotFoundError, + ) as e: logger.debug("File or repository not found in hf_hub_download", e) return None except HfHubHTTPError as e: logger.warning( - "Cannot connect to Hugging Face Hub. Skipping file " - "download for '%s':", + "Cannot connect to Hugging Face Hub. Skipping file download for '%s':", file_name, - exc_info=e) + exc_info=e, + ) return None file_path = Path(hf_hub_file) @@ -454,28 +753,28 @@ def get_hf_file_to_dict(file_name: str, @cache -def get_pooling_config(model: str, revision: Optional[str] = 'main'): +def get_pooling_config(model: str, revision: str | None = "main") -> dict | None: """ This function gets the pooling and normalize config from the model - only applies to sentence-transformers models. Args: - model (str): The name of the Hugging Face model. - revision (str, optional): The specific version - of the model to use. Defaults to 'main'. + model: The name of the Hugging Face model. + revision: The specific version of the model to use. + Defaults to 'main'. 
 Returns: - dict: A dictionary containing the pooling - type and whether normalization is used. + A dictionary containing the pooling type and whether + normalization is used, or None if no pooling configuration is found. """ modules_file_name = "modules.json" modules_dict = None - if file_or_path_exists(model=model, - config_name=modules_file_name, - revision=revision): + if file_or_path_exists( + model=model, config_name=modules_file_name, revision=revision + ): modules_dict = get_hf_file_to_dict(modules_file_name, model, revision) if modules_dict is None: @@ -483,20 +782,31 @@ def get_pooling_config(model: str, revision: Optional[str] = 'main'): logger.info("Found sentence-transformers modules configuration.") - pooling = next((item for item in modules_dict - if item["type"] == "sentence_transformers.models.Pooling"), - None) + pooling = next( + ( + item + for item in modules_dict + if item["type"] == "sentence_transformers.models.Pooling" + ), + None, + ) normalize = bool( - next((item for item in modules_dict - if item["type"] == "sentence_transformers.models.Normalize"), - False)) + next( + ( + item + for item in modules_dict + if item["type"] == "sentence_transformers.models.Normalize" + ), + False, + ) + ) if pooling: - pooling_file_name = "{}/config.json".format(pooling["path"]) + pooling_file_name = f"{pooling['path']}/config.json" pooling_dict = get_hf_file_to_dict(pooling_file_name, model, revision) pooling_type_name = next( - (item for item, val in pooling_dict.items() if val is True), None) + (item for item, val in pooling_dict.items() if val is True), None + ) if pooling_type_name is not None: pooling_type_name = get_pooling_config_name(pooling_type_name) @@ -507,7 +817,7 @@ def get_pooling_config(model: str, revision: Optional[str] = 'main'): return None -def get_pooling_config_name(pooling_name: str) -> Union[str, None]: +def get_pooling_config_name(pooling_name: str) -> str | None: if "pooling_mode_" in pooling_name: pooling_name = pooling_name.replace("pooling_mode_", "") @@ -517,28 +827,25 @@ def 
get_pooling_config_name(pooling_name: str) -> Union[str, None]: if "lasttoken" in pooling_name: pooling_name = "last" - supported_pooling_types = ['LAST', 'ALL', 'CLS', 'STEP', 'MEAN'] + supported_pooling_types = ["LAST", "ALL", "CLS", "STEP", "MEAN"] pooling_type_name = pooling_name.upper() - try: - if pooling_type_name in supported_pooling_types: - return pooling_type_name - except NotImplementedError as e: - logger.debug("Pooling type not supported", e) - return None - return None + if pooling_type_name in supported_pooling_types: + return pooling_type_name + + raise NotImplementedError(f"Pooling type {pooling_type_name} not supported") @cache -def get_sentence_transformer_tokenizer_config(model: str, - revision: Optional[str] = 'main' - ): +def get_sentence_transformer_tokenizer_config( + model: str | Path, revision: str | None = "main" +): """ Returns the tokenization configuration dictionary for a given Sentence Transformer BERT model. Parameters: - - model (str): The name of the Sentence Transformer + - model (str|Path): The name of the Sentence Transformer BERT model. - revision (str, optional): The revision of the m odel to use. Defaults to 'main'. 
@@ -559,26 +866,26 @@ def get_sentence_transformer_tokenizer_config(model: str, encoder_dict = None for config_file in sentence_transformer_config_files: - if try_get_local_file(model=model, - file_name=config_file, - revision=revision) is not None: + if ( + try_get_local_file(model=model, file_name=config_file, revision=revision) + is not None + ): encoder_dict = get_hf_file_to_dict(config_file, model, revision) if encoder_dict: break - if not encoder_dict and not model.startswith("/"): + if not encoder_dict and not Path(model).is_absolute(): try: # If model is on HuggingfaceHub, get the repo files - repo_files = list_repo_files(model, - revision=revision, - token=os.getenv('HF_TOKEN', None)) + repo_files = list_repo_files( + model, revision=revision, token=_get_hf_token() + ) except Exception: repo_files = [] for config_name in sentence_transformer_config_files: if config_name in repo_files: - encoder_dict = get_hf_file_to_dict(config_name, model, - revision) + encoder_dict = get_hf_file_to_dict(config_name, model, revision) if encoder_dict: break @@ -595,186 +902,83 @@ def get_sentence_transformer_tokenizer_config(model: str, def maybe_register_config_serialize_by_value() -> None: """Try to register HF model configuration class to serialize by value - If trust_remote_code is set, and the model's config file specifies an - `AutoConfig` class, then the config class is typically an instance of - a custom class imported from the HF modules cache. + If trust_remote_code is set, and the model's config file specifies an + `AutoConfig` class, then the config class is typically an instance of + a custom class imported from the HF modules cache. 
- Examples: + Examples: - >>> from transformers import AutoConfig - >>> klass = AutoConfig.from_pretrained('meta-llama/Meta-Llama-3-8B', trust_remote_code=True) - >>> klass.__class__ # transformers.models.llama.configuration_llama.LlamaConfig - >>> import transformers_modules # error, not initialized - >>> klass = AutoConfig.from_pretrained('deepseek-ai/DeepSeek-V2.5', trust_remote_code=True) - >>> import transformers_modules # success, initialized - >>> klass.__class__ # transformers_modules.deepseek-ai.DeepSeek-V2.5.98b11844770b2c3ffc18b175c758a803640f4e77.configuration_deepseek.DeepseekV2Config + >>> from transformers import AutoConfig + >>> klass = AutoConfig.from_pretrained( + ... "meta-llama/Meta-Llama-3-8B", trust_remote_code=True + ... ) + >>> klass.__class__ # transformers.models.llama.configuration_llama.LlamaConfig + >>> import transformers_modules # error, not initialized + >>> klass = AutoConfig.from_pretrained( + ... "deepseek-ai/DeepSeek-V2.5", trust_remote_code=True + ... ) + >>> import transformers_modules # success, initialized + >>> klass.__class__ # transformers_modules.deepseek-ai.DeepSeek-V2.5.98b11844770b2c3ffc18b175c758a803640f4e77.configuration_deepseek.DeepseekV2Config - In the DeepSeek example, the config class is an instance of a custom - class that is not serializable by default. This class will not be - importable in spawned workers, and won't exist at all on - other nodes, which breaks serialization of the config. + In the DeepSeek example, the config class is an instance of a custom + class that is not serializable by default. This class will not be + importable in spawned workers, and won't exist at all on + other nodes, which breaks serialization of the config. - In this function we tell the cloudpickle serialization library to pass - instances of these generated classes by value instead of by reference, - i.e. 
the class definition is serialized along with its data so that the - class module does not need to be importable on the receiving end. + In this function we tell the cloudpickle serialization library to pass + instances of these generated classes by value instead of by reference, + i.e. the class definition is serialized along with its data so that the + class module does not need to be importable on the receiving end. - See: https://github.com/cloudpipe/cloudpickle?tab=readme-ov-file#overriding-pickles-serialization-mechanism-for-importable-constructs - """ # noqa + See: https://github.com/cloudpipe/cloudpickle?tab=readme-ov-file#overriding-pickles-serialization-mechanism-for-importable-constructs + """ # noqa try: import transformers_modules + + transformers_modules_available = True except ImportError: - # the config does not need trust_remote_code - return + transformers_modules_available = False try: - import cloudpickle - cloudpickle.register_pickle_by_value(transformers_modules) - - # ray vendors its own version of cloudpickle - from vllm.executor.ray_utils import ray - if ray: - ray.cloudpickle.register_pickle_by_value(transformers_modules) - - # multiprocessing uses pickle to serialize arguments when using spawn - # Here we get pickle to use cloudpickle to serialize config objects - # that contain instances of the custom config class to avoid - # serialization problems if the generated module (and model) has a `.` - # in its name import multiprocessing import pickle + import cloudpickle + from vllm.config import VllmConfig + # Register multiprocessing reducers to handle cross-process + # serialization of VllmConfig objects that may contain custom configs + # from transformers_modules def _reduce_config(config: VllmConfig): - return (pickle.loads, (cloudpickle.dumps(config), )) + return (pickle.loads, (cloudpickle.dumps(config),)) multiprocessing.reducer.register(VllmConfig, _reduce_config) + # Register transformers_modules with cloudpickle if available + if 
transformers_modules_available: + cloudpickle.register_pickle_by_value(transformers_modules) + + # ray vendors its own version of cloudpickle + from vllm.v1.executor.ray_utils import ray + + if ray: + ray.cloudpickle.register_pickle_by_value(transformers_modules) + except Exception as e: logger.warning( "Unable to register remote classes used by" " trust_remote_code with by-value serialization. This may" " lead to a later error. If remote code is not needed" " remove `--trust-remote-code`", - exc_info=e) - - -def load_params_config(model: Union[str, Path], revision: Optional[str], - **kwargs) -> PretrainedConfig: - # This function loads a params.json config which - # should be used when loading models in mistral format - - config_file_name = "params.json" - - config_dict = get_hf_file_to_dict(config_file_name, model, revision) - if config_dict is None: - raise ValueError( - f"Failed to load mistral '{config_file_name}' config for model " - f"{model}. Please check if the model is a mistral-format model " - f"and if the config file exists.") - assert isinstance(config_dict, dict) - - config_mapping = { - "dim": "hidden_size", - "norm_eps": "rms_norm_eps", - "n_kv_heads": "num_key_value_heads", - "n_layers": "num_hidden_layers", - "n_heads": "num_attention_heads", - "hidden_dim": "intermediate_size", - } - - def recurse_elems(elem: Any): - if isinstance(elem, dict): - config_dict = {} - for key, value in elem.items(): - key = config_mapping.get(key, key) - config_dict[key] = recurse_elems(value) - - return config_dict - else: - return elem - - config_dict["model_type"] = config_dict.get("model_type", "transformer") - config_dict["hidden_act"] = config_dict.get("activation", "silu") - config_dict["tie_word_embeddings"] = config_dict.get( - "tie_embeddings", False) - - if config_dict.get("max_position_embeddings") is None: - max_position_embeddings = 128_000 - try: - trust_remote_code_val = kwargs.get("trust_remote_code", False) - hf_config = get_config(model=model, - 
trust_remote_code=trust_remote_code_val, - revision=revision, - config_format=ConfigFormat.HF) - if hf_value := hf_config.get_text_config().max_position_embeddings: - max_position_embeddings = hf_value - except Exception as e: - logger.warning( - "The params.json file is missing 'max_position_embeddings'" - " and could not get a value from the HF config." - " Defaulting to 128000", - exc_info=e) - config_dict["max_position_embeddings"] = max_position_embeddings - - if config_dict.get("quantization") is not None: - quantization = config_dict.get("quantization", {}) - if quantization.get("qformat_weight") == "fp8_e4m3": - # This maps to the FP8 static per-tensor quantization scheme - quantization_config = { - "quant_method": "fp8", - "activation_scheme": "static" - } - elif quantization.get("quant_method") == "compressed-tensors": - # Pass through the quantization config to compressed-tensors - quantization_config = quantization - else: - raise ValueError( - f"Found unknown quantization='{quantization}' in config") - - config_dict["quantization_config"] = quantization_config - - config_type: Literal["text", - "multimodal"] = "multimodal" if config_dict.get( - "vision_encoder") is not None else "text" - - if config_dict.get("moe") is not None: - config_dict["architectures"] = ["MixtralForCausalLM"] - else: - config_dict["architectures"] = ["MistralForCausalLM"] - - if config_type == "multimodal": - multimodal_config = config_dict.pop("vision_encoder") - quantization_config = config_dict.get("quantization_config", {}) - - config_dict = { - "text_config": config_dict, - "vision_config": multimodal_config - } - config_dict["architectures"] = ["PixtralForConditionalGeneration"] - config_dict["model_type"] = "pixtral" - if quantization_config: - config_dict["quantization_config"] = quantization_config - - config_dict.update(kwargs) - - config_dict = recurse_elems(config_dict) - - # transform to HF config format - if config_type == "multimodal": - config_dict["text_config"] 
= PretrainedConfig( - **config_dict["text_config"]) - config_dict["vision_config"] = PretrainedConfig( - **config_dict["vision_config"]) - - return PretrainedConfig(**config_dict) + exc_info=e, + ) def get_hf_image_processor_config( - model: Union[str, Path], - hf_token: Optional[Union[bool, str]] = None, - revision: Optional[str] = None, + model: str | Path, + hf_token: bool | str | None = None, + revision: str | None = None, **kwargs, ) -> dict[str, Any]: # ModelScope does not provide an interface for image_processor @@ -783,23 +987,15 @@ def get_hf_image_processor_config( # Separate model folder from file path for GGUF models if check_gguf_file(model): model = Path(model).parent - return get_image_processor_config(model, - token=hf_token, - revision=revision, - **kwargs) + return get_image_processor_config( + model, token=hf_token, revision=revision, **kwargs + ) def get_hf_text_config(config: PretrainedConfig): """Get the "sub" config relevant to llm for multi modal models. No op for pure text models. """ - # This block should be unnecessary after https://github.com/huggingface/transformers/pull/37517 - if hasattr(config, "thinker_config"): - # TODO(suyang.fy): Refactor code. - # For Qwen2.5-Omni, change hf_text_config to - # thinker_config.text_config. 
- return config.thinker_config.text_config - text_config = config.get_text_config() if text_config is not config: @@ -814,8 +1010,9 @@ def get_hf_text_config(config: PretrainedConfig): def try_get_generation_config( model: str, trust_remote_code: bool, - revision: Optional[str] = None, -) -> Optional[GenerationConfig]: + revision: str | None = None, + config_format: str | ConfigFormat = "auto", +) -> GenerationConfig | None: try: return GenerationConfig.from_pretrained( model, @@ -827,56 +1024,38 @@ def try_get_generation_config( model, trust_remote_code=trust_remote_code, revision=revision, + config_format=config_format, ) return GenerationConfig.from_model_config(config) except OSError: # Not found return None -def get_cross_encoder_activation_function(config: PretrainedConfig): - - function_name: Optional[str] = None - if hasattr(config, "sentence_transformers") and "activation_fn" in \ - config.sentence_transformers: - function_name = config.sentence_transformers["activation_fn"] - - elif (hasattr(config, "sbert_ce_default_activation_function") - and config.sbert_ce_default_activation_function is not None): - function_name = config.sbert_ce_default_activation_function - - if function_name is not None: - assert function_name.startswith("torch.nn.modules."), \ - "Loading of activation functions is restricted to " \ - "torch.nn.modules for security reasons" - return resolve_obj_by_qualname(function_name)() - else: - return nn.Sigmoid() if config.num_labels == 1 else nn.Identity() - - def try_get_safetensors_metadata( model: str, *, - revision: Optional[str] = None, + revision: str | None = None, ): get_safetensors_metadata_partial = partial( get_safetensors_metadata, model, revision=revision, - token=os.getenv('HF_TOKEN', None), + token=_get_hf_token(), ) try: - return with_retry(get_safetensors_metadata_partial, - "Error retrieving safetensors") + return with_retry( + get_safetensors_metadata_partial, "Error retrieving safetensors" + ) except Exception: return 
None def try_get_tokenizer_config( - pretrained_model_name_or_path: Union[str, os.PathLike], + pretrained_model_name_or_path: str | os.PathLike, trust_remote_code: bool, - revision: Optional[str] = None, -) -> Optional[dict[str, Any]]: + revision: str | None = None, +) -> dict[str, Any] | None: try: return get_tokenizer_config( pretrained_model_name_or_path, @@ -885,3 +1064,139 @@ def try_get_tokenizer_config( ) except Exception: return None + + +@cache +def try_get_dense_modules( + model: str | Path, + revision: str | None = None, +) -> list[dict[str, Any]] | None: + try: + modules = get_hf_file_to_dict("modules.json", model, revision) + if not modules: + return None + + if isinstance(modules, dict): + modules = modules.get("modules", []) + + dense_modules = [ + m for m in modules if m.get("type") == "sentence_transformers.models.Dense" + ] + if not dense_modules: + return None + + layer_configs = [] + for module in dense_modules: + folder = module.get("path", "") + + config_path = f"{folder}/config.json" if folder else "config.json" + layer_config = get_hf_file_to_dict(config_path, model, revision) + if not layer_config: + continue + layer_config["folder"] = folder + layer_configs.append(layer_config) + return layer_configs + except Exception: + return None + + +def get_safetensors_params_metadata( + model: str, + *, + revision: str | None = None, +) -> dict[str, Any]: + """ + Get the safetensors metadata for remote model repository. 
+ """ + full_metadata = {} + if (model_path := Path(model)).exists(): + safetensors_to_check = model_path.glob("*.safetensors") + full_metadata = { + param_name: info + for file_path in safetensors_to_check + if file_path.is_file() + for param_name, info in parse_safetensors_file_metadata(file_path).items() + } + else: + repo_mt = try_get_safetensors_metadata(model, revision=revision) + if repo_mt and (files_mt := repo_mt.files_metadata): + full_metadata = { + param_name: asdict(info) + for file_mt in files_mt.values() + for param_name, info in file_mt.tensors.items() + } + return full_metadata + + +def _download_mistral_config_file(model, revision) -> dict: + config_file_name = "params.json" + config_dict = get_hf_file_to_dict(config_file_name, model, revision) + if config_dict is None: + raise ValueError( + f"Failed to load mistral '{config_file_name}' config for model " + f"{model}. Please check if the model is a mistral-format model " + f"and if the config file exists." + ) + assert isinstance(config_dict, dict) + return config_dict + + +def _maybe_retrieve_max_pos_from_hf(model, revision, **kwargs) -> int: + max_position_embeddings = 128_000 + try: + trust_remote_code_val = kwargs.get("trust_remote_code", False) + hf_config = get_config( + model=model, + trust_remote_code=trust_remote_code_val, + revision=revision, + config_format="hf", + ) + if hf_value := hf_config.get_text_config().max_position_embeddings: + max_position_embeddings = hf_value + except Exception as e: + logger.warning( + "The params.json file is missing 'max_position_embeddings'" + " and could not get a value from the HF config." 
+ " Defaulting to 128000", + exc_info=e, + ) + + return max_position_embeddings + + +def get_model_path(model: str | Path, revision: str | None = None): + if os.path.exists(model): + return model + assert huggingface_hub.constants.HF_HUB_OFFLINE + common_kwargs = { + "local_files_only": huggingface_hub.constants.HF_HUB_OFFLINE, + "revision": revision, + } + + if envs.VLLM_USE_MODELSCOPE: + from modelscope.hub.snapshot_download import snapshot_download + + return snapshot_download(model_id=model, **common_kwargs) + + from huggingface_hub import snapshot_download + + return snapshot_download(repo_id=model, **common_kwargs) + + +def get_hf_file_bytes( + file_name: str, model: str | Path, revision: str | None = "main" +) -> bytes | None: + """Get file contents from HuggingFace repository as bytes.""" + file_path = try_get_local_file(model=model, file_name=file_name, revision=revision) + + if file_path is None: + hf_hub_file = hf_hub_download( + model, file_name, revision=revision, token=_get_hf_token() + ) + file_path = Path(hf_hub_file) + + if file_path is not None and file_path.is_file(): + with open(file_path, "rb") as file: + return file.read() + + return None diff --git a/vllm/transformers_utils/config_parser_base.py b/vllm/transformers_utils/config_parser_base.py new file mode 100644 index 0000000..79d47ff --- /dev/null +++ b/vllm/transformers_utils/config_parser_base.py @@ -0,0 +1,20 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from abc import ABC, abstractmethod +from pathlib import Path + +from transformers import PretrainedConfig + + +class ConfigParserBase(ABC): + @abstractmethod + def parse( + self, + model: str | Path, + trust_remote_code: bool, + revision: str | None = None, + code_revision: str | None = None, + **kwargs, + ) -> tuple[dict, PretrainedConfig]: + raise NotImplementedError diff --git a/vllm/transformers_utils/dynamic_module.py b/vllm/transformers_utils/dynamic_module.py 
new file mode 100644 index 0000000..24ead83 --- /dev/null +++ b/vllm/transformers_utils/dynamic_module.py @@ -0,0 +1,59 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import os + +from transformers.dynamic_module_utils import get_class_from_dynamic_module + +import vllm.envs as envs +from vllm.logger import init_logger + +logger = init_logger(__name__) + + +def try_get_class_from_dynamic_module( + class_reference: str, + pretrained_model_name_or_path: str, + cache_dir: str | os.PathLike | None = None, + force_download: bool = False, + resume_download: bool | None = None, + proxies: dict[str, str] | None = None, + token: bool | str | None = None, + revision: str | None = None, + local_files_only: bool = False, + repo_type: str | None = None, + code_revision: str | None = None, + warn_on_fail: bool = True, + **kwargs, +) -> type | None: + """ + As `transformers.dynamic_module_utils.get_class_from_dynamic_module`, + but ignoring any errors. 
+ """ + try: + return get_class_from_dynamic_module( + class_reference, + pretrained_model_name_or_path, + cache_dir=cache_dir, + force_download=force_download, + resume_download=resume_download, + proxies=proxies, + token=token, + revision=revision, + local_files_only=local_files_only, + repo_type=repo_type, + code_revision=code_revision, + **kwargs, + ) + except Exception: + location = "ModelScope" if envs.VLLM_USE_MODELSCOPE else "HF Hub" + + if warn_on_fail: + logger.warning( + "Unable to load %s from %s on %s.", + class_reference, + pretrained_model_name_or_path, + location, + exc_info=True, + ) + + return None diff --git a/vllm/transformers_utils/utils.py b/vllm/transformers_utils/utils.py index 66c8fb7..1ae42ba 100644 --- a/vllm/transformers_utils/utils.py +++ b/vllm/transformers_utils/utils.py @@ -2,22 +2,32 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import json +import os +import struct from functools import cache from os import PathLike from pathlib import Path -from typing import Optional, Union +from typing import Any -from vllm.envs import VLLM_MODEL_REDIRECT_PATH +import vllm.envs as envs from vllm.logger import init_logger logger = init_logger(__name__) def is_s3(model_or_path: str) -> bool: - return model_or_path.lower().startswith('s3://') + return model_or_path.lower().startswith("s3://") -def check_gguf_file(model: Union[str, PathLike]) -> bool: +def is_gcs(model_or_path: str) -> bool: + return model_or_path.lower().startswith("gs://") + + +def is_cloud_storage(model_or_path: str) -> bool: + return is_s3(model_or_path) or is_gcs(model_or_path) + + +def check_gguf_file(model: str | PathLike) -> bool: """Check if the file is a GGUF model.""" model = Path(model) if not model.is_file(): @@ -37,23 +47,26 @@ def check_gguf_file(model: Union[str, PathLike]) -> bool: def modelscope_list_repo_files( repo_id: str, - revision: Optional[str] = None, - token: Union[str, bool, None] = None, + revision: str | None = None, + token: 
 str | bool | None = None, ) -> list[str]: """List files in a modelscope repo.""" from modelscope.hub.api import HubApi + api = HubApi() api.login(token) # same as huggingface_hub.list_repo_files files = [ - file['Path'] for file in api.get_model_files( - model_id=repo_id, revision=revision, recursive=True) - if file['Type'] == 'blob' + file["Path"] + for file in api.get_model_files( + model_id=repo_id, revision=revision, recursive=True + ) + if file["Type"] == "blob" ] return files -def _maybe_json_dict(path: Union[str, PathLike]) -> dict[str, str]: +def _maybe_json_dict(path: str | PathLike) -> dict[str, str]: with open(path) as f: try: return json.loads(f.read()) @@ -61,7 +74,7 @@ def _maybe_json_dict(path: Union[str, PathLike]) -> dict[str, str]: return dict[str, str]() -def _maybe_space_split_dict(path: Union[str, PathLike]) -> dict[str, str]: +def _maybe_space_split_dict(path: str | PathLike) -> dict[str, str]: parsed_dict = dict[str, str]() with open(path) as f: for line in f.readlines(): @@ -82,7 +95,7 @@ def maybe_model_redirect(model: str) -> str: :return: maybe redirect to a local folder """ - model_redirect_path = VLLM_MODEL_REDIRECT_PATH + model_redirect_path = envs.VLLM_MODEL_REDIRECT_PATH if not model_redirect_path: return model @@ -90,10 +103,28 @@ def maybe_model_redirect(model: str) -> str: if not Path(model_redirect_path).exists(): return model - redirect_dict = (_maybe_json_dict(model_redirect_path) - or _maybe_space_split_dict(model_redirect_path)) - if (redirect_model := redirect_dict.get(model)): + redirect_dict = _maybe_json_dict(model_redirect_path) or _maybe_space_split_dict( + model_redirect_path + ) + if redirect_model := redirect_dict.get(model): logger.info("model redirect: [ %s ] -> [ %s ]", model, redirect_model) return redirect_model return model + + +def parse_safetensors_file_metadata(path: str | PathLike) -> dict[str, Any]: + with open(path, "rb") as f: + length_of_metadata = struct.unpack("<Q", f.read(8))[0] + metadata = json.loads(f.read(length_of_metadata)) + metadata.pop("__metadata__", None) + return metadata + + +def convert_model_repo_to_path(model_repo: str) -> str: + """When VLLM_USE_MODELSCOPE is
True convert a model + repository string to a Path str.""" + if not envs.VLLM_USE_MODELSCOPE or Path(model_repo).exists(): + return model_repo + from modelscope.utils.file_utils import get_model_cache_root + + return os.path.join(get_model_cache_root(), model_repo)