# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import contextlib
import importlib.metadata
import os
import threading
from collections.abc import Callable, Collection
from functools import lru_cache
from typing import TYPE_CHECKING, Any, TypeVar

import numpy as np
import numpy.typing as npt
import torch
from packaging import version
from packaging.version import Version
from torch.library import Library

import vllm.envs as envs
import ixformer.inference.functions as ixfops

if TYPE_CHECKING:
    from vllm.config import ModelConfig
    from vllm.sequence import IntermediateTensors
else:
    ModelConfig = object
    IntermediateTensors = object

# Mapping from user-facing dtype strings (CLI/config values) to torch dtypes.
# The various fp8 formats are stored as uint8 containers except "fp8_inc".
STR_DTYPE_TO_TORCH_DTYPE = {
    "float32": torch.float32,
    "half": torch.half,
    "bfloat16": torch.bfloat16,
    "float": torch.float,
    "fp8": torch.uint8,
    "fp8_e4m3": torch.uint8,
    "fp8_e5m2": torch.uint8,
    "int8": torch.int8,
    "fp8_inc": torch.float8_e4m3fn,
    "fp8_ds_mla": torch.uint8,
}

# Torch dtypes that have a direct NumPy equivalent, used by
# make_tensor_with_pad to build the padded array in NumPy first.
TORCH_DTYPE_TO_NUMPY_DTYPE = {
    torch.float16: np.float16,
    torch.float32: np.float32,
    torch.float64: np.float64,
    torch.uint8: np.uint8,
    torch.int32: np.int32,
    torch.int64: np.int64,
}

T = TypeVar("T")


@contextlib.contextmanager
def set_default_torch_dtype(dtype: torch.dtype):
    """Sets the default torch dtype to the given dtype."""
    old_dtype = torch.get_default_dtype()
    torch.set_default_dtype(dtype)
    yield
    torch.set_default_dtype(old_dtype)


@contextlib.contextmanager
def set_default_torch_num_threads(num_threads: int):
    """Sets the default number of threads for PyTorch to the given value."""
    old_num_threads = torch.get_num_threads()
    torch.set_num_threads(num_threads)
    yield
    torch.set_num_threads(old_num_threads)


@contextlib.contextmanager
def guard_cuda_initialization():
    """Avoid unexpected CUDA initialization.

    On CUDA platforms, hides all GPUs via CUDA_VISIBLE_DEVICES for the
    duration of the context, so any code path that would lazily
    initialize CUDA fails loudly instead of silently grabbing a device.
    The previous environment value is restored on exit.
    """
    from vllm.platforms import current_platform

    if not current_platform.is_cuda():
        yield
        return

    had_key = "CUDA_VISIBLE_DEVICES" in os.environ
    old_value = os.environ.get("CUDA_VISIBLE_DEVICES")
    os.environ["CUDA_VISIBLE_DEVICES"] = ""
    try:
        yield
    except Exception as e:
        if "No CUDA GPUs are available" in str(e):
            err_msg = "CUDA initialization is blocked."
        else:
            err_msg = str(e)
        raise RuntimeError(err_msg) from e
    finally:
        if had_key:
            os.environ["CUDA_VISIBLE_DEVICES"] = old_value
        else:
            # Use a default so we don't raise KeyError if the guarded
            # code itself removed the variable.
            os.environ.pop("CUDA_VISIBLE_DEVICES", None)


def get_dtype_size(dtype: torch.dtype) -> int:
    """Get the size of the data type in bytes."""
    return torch.tensor([], dtype=dtype).element_size()


# bool = 0, int = 1, float = 2, complex = 3
def _get_precision_level(dtype: torch.dtype) -> int:
    # NOTE: Complex dtypes return `is_floating_point=False`
    return (dtype != torch.bool) + dtype.is_floating_point + dtype.is_complex * 2


def is_lossless_cast(src_dtype: torch.dtype, tgt_dtype: torch.dtype):
    """
    Test whether it is lossless to cast a tensor from
    `src_dtype` to `tgt_dtype`.
    """
    if src_dtype == tgt_dtype:
        return True

    src_level = _get_precision_level(src_dtype)
    tgt_level = _get_precision_level(tgt_dtype)

    if src_level < tgt_level:
        return True
    if src_level > tgt_level:
        return False

    # Compare integral types
    if not src_dtype.is_floating_point and not src_dtype.is_complex:
        src_info = torch.iinfo(src_dtype)
        tgt_info = torch.iinfo(tgt_dtype)
        return src_info.min >= tgt_info.min and src_info.max <= tgt_info.max

    # Compare floating-point types
    src_info = torch.finfo(src_dtype)
    tgt_info = torch.finfo(tgt_dtype)
    return (
        src_info.min >= tgt_info.min
        and src_info.max <= tgt_info.max
        and src_info.resolution >= tgt_info.resolution
    )


def common_broadcastable_dtype(dtypes: Collection[torch.dtype]):
    """
    Get the common `dtype` where all of the other `dtypes` can be
    cast to it without losing any information.
    """
    return max(
        dtypes,
        key=lambda dtype: sum(is_lossless_cast(dt, dtype) for dt in dtypes),
    )


def _generate_random_fp8(
    tensor: torch.Tensor,
    low: float,
    high: float,
) -> None:
    # NOTE(zhaoyang): Due to NaN and Inf representation for fp8 data type,
    # it may occur Inf or NaN if we directly use torch.randint
    # to generate random data for fp8 data.
    # For example, s.11111.00 in fp8e5m2 format represents Inf.
    #     | E4M3        | E5M2
    # -----|-------------|-------------------
    # Inf | N/A         | s.11111.00
    # NaN | s.1111.111  | s.11111.{01,10,11}
    from vllm import _custom_ops as ops

    tensor_tmp = torch.empty_like(tensor, dtype=torch.float16)
    tensor_tmp.uniform_(low, high)
    ops.convert_fp8(tensor, tensor_tmp)
    del tensor_tmp


def get_kv_cache_torch_dtype(
    cache_dtype: str | torch.dtype | None,
    model_dtype: str | torch.dtype | None = None,
) -> torch.dtype:
    """Resolve a KV-cache dtype spec (string or torch dtype) to a torch dtype.

    "auto" defers to `model_dtype`; otherwise the string is looked up in
    STR_DTYPE_TO_TORCH_DTYPE. Raises ValueError for anything unrecognized.
    """
    if isinstance(cache_dtype, str):
        if cache_dtype == "auto":
            if isinstance(model_dtype, str) and model_dtype in STR_DTYPE_TO_TORCH_DTYPE:
                torch_dtype = STR_DTYPE_TO_TORCH_DTYPE[model_dtype]
            elif isinstance(model_dtype, torch.dtype):
                torch_dtype = model_dtype
            else:
                raise ValueError(f"Invalid model dtype: {model_dtype}")
        elif cache_dtype in STR_DTYPE_TO_TORCH_DTYPE:
            torch_dtype = STR_DTYPE_TO_TORCH_DTYPE[cache_dtype]
        else:
            raise ValueError(f"Invalid kv cache dtype: {cache_dtype}")
    elif isinstance(cache_dtype, torch.dtype):
        torch_dtype = cache_dtype
    else:
        raise ValueError(f"Invalid kv cache dtype: {cache_dtype}")
    return torch_dtype


def kv_cache_dtype_str_to_dtype(
    kv_cache_dtype: str, model_config: ModelConfig
) -> torch.dtype:
    if kv_cache_dtype == "auto":
        # Model config may not be specified for unit tests, default to float16
        return model_config.dtype if model_config else torch.half
    return STR_DTYPE_TO_TORCH_DTYPE[kv_cache_dtype]


def create_kv_caches_with_random_flash(
    num_blocks: int,
    block_size: int,
    num_layers: int,
    num_heads: int,
    head_size: int,
    cache_dtype: str | torch.dtype | None,
    model_dtype: str | torch.dtype | None = None,
    seed: int | None = None,
    device: str | None = "cuda",
    cache_layout: str | None = "NHD",
) -> tuple[list[torch.Tensor], list[torch.Tensor]]:
    """Create randomly-initialized KV caches in the FlashAttention layout.

    Returns per-layer (key, value) views into a combined
    (num_blocks, 2, block_size, num_heads, head_size) allocation, with
    strides chosen according to `cache_layout` ("NHD" or "HND").
    """
    from vllm.platforms import current_platform

    current_platform.seed_everything(seed)

    dtype = get_kv_cache_torch_dtype(cache_dtype, model_dtype)
    generic_kv_cache_shape = (num_blocks, 2, block_size, num_heads, head_size)
    assert cache_layout in ("NHD", "HND")
    stride_order = (0, 1, 2, 3, 4) if cache_layout == "NHD" else (0, 1, 3, 2, 4)

    kv_cache_allocation_shape = tuple(generic_kv_cache_shape[i] for i in stride_order)
    scale = head_size**-0.5

    key_caches: list[torch.Tensor] = []
    value_caches: list[torch.Tensor] = []

    for _ in range(num_layers):
        key_value_cache = torch.empty(
            size=kv_cache_allocation_shape, dtype=dtype, device=device
        ).permute(*stride_order)
        if cache_dtype in ["auto", "half", "bfloat16", "float"]:
            key_value_cache.uniform_(-scale, scale)
        elif cache_dtype == "fp8":
            _generate_random_fp8(key_value_cache, -scale, scale)
        else:
            raise ValueError(f"Does not support key cache of type {cache_dtype}")
        key_caches.append(key_value_cache[:, 0])
        value_caches.append(key_value_cache[:, 1])
    return key_caches, value_caches


def create_kv_caches_with_random(
    num_blocks: int,
    block_size: int,
    num_layers: int,
    num_heads: int,
    head_size: int,
    cache_dtype: str | torch.dtype | None,
    model_dtype: str | torch.dtype | None = None,
    seed: int | None = None,
    device: str | None = "cuda",
) -> tuple[list[torch.Tensor], list[torch.Tensor]]:
    """Create randomly-initialized KV caches in the paged-attention layout.

    Key caches are shaped (num_blocks, num_heads, head_size // x, block_size, x)
    where x packs 16 bytes per innermost slot; value caches are shaped
    (num_blocks, num_heads, head_size, block_size).
    """
    if cache_dtype == "fp8" and head_size % 16:
        raise ValueError(
            f"Does not support key cache of type fp8 with head_size {head_size}"
        )
    from vllm.platforms import current_platform

    current_platform.seed_everything(seed)

    dtype = get_kv_cache_torch_dtype(cache_dtype, model_dtype)

    scale = head_size**-0.5
    # Pack the innermost dimension so each slot spans 16 bytes.
    x = 16 // torch.tensor([], dtype=dtype).element_size()
    key_cache_shape = (num_blocks, num_heads, head_size // x, block_size, x)
    key_caches: list[torch.Tensor] = []
    for _ in range(num_layers):
        key_cache = torch.empty(size=key_cache_shape, dtype=dtype, device=device)
        if cache_dtype in ["auto", "half", "bfloat16", "float"]:
            key_cache.uniform_(-scale, scale)
        elif cache_dtype == "fp8":
            _generate_random_fp8(key_cache, -scale, scale)
        else:
            raise ValueError(f"Does not support key cache of type {cache_dtype}")
        key_caches.append(key_cache)

    value_cache_shape = (num_blocks, num_heads, head_size, block_size)
    value_caches: list[torch.Tensor] = []
    for _ in range(num_layers):
        value_cache = torch.empty(size=value_cache_shape, dtype=dtype, device=device)
        if cache_dtype in ["auto", "half", "bfloat16", "float"]:
            value_cache.uniform_(-scale, scale)
        elif cache_dtype == "fp8":
            _generate_random_fp8(value_cache, -scale, scale)
        else:
            raise ValueError(f"Does not support value cache of type {cache_dtype}")
        value_caches.append(value_cache)
    return key_caches, value_caches


def async_tensor_h2d(
    data: list,
    dtype: torch.dtype,
    target_device: str | torch.device,
    pin_memory: bool,
) -> torch.Tensor:
    """Asynchronously create a tensor and copy it from host to device."""
    t = torch.tensor(data, dtype=dtype, pin_memory=pin_memory, device="cpu")
    return t.to(device=target_device, non_blocking=True)


def make_ndarray_with_pad(
    x: list[list[T]],
    pad: T,
    dtype: npt.DTypeLike,
    *,
    max_len: int | None = None,
) -> npt.NDArray:
    """
    Make a padded array from 2D inputs.

    The padding is applied to the end of each inner list until it reaches
    `max_len`.
    """
    if max_len is None:
        # Unlike for most functions, map is faster than a genexpr over `len`
        max_len = max(map(len, x), default=0)

    padded_x = np.full((len(x), max_len), pad, dtype=dtype)
    for ind, blocktb in enumerate(x):
        assert len(blocktb) <= max_len
        padded_x[ind, : len(blocktb)] = blocktb

    return padded_x


def make_tensor_with_pad(
    x: list[list[T]],
    pad: T,
    dtype: torch.dtype,
    *,
    max_len: int | None = None,
    device: str | torch.device | None = None,
    pin_memory: bool = False,
) -> torch.Tensor:
    """
    Make a padded tensor from 2D inputs.

    The padding is applied to the end of each inner list until it reaches
    `max_len`.
    """
    np_dtype = TORCH_DTYPE_TO_NUMPY_DTYPE[dtype]
    padded_x = make_ndarray_with_pad(x, pad, np_dtype, max_len=max_len)

    tensor = torch.from_numpy(padded_x).to(device)
    if pin_memory:
        tensor = tensor.pin_memory()

    return tensor


prev_set_stream = torch.cuda.set_stream

_current_stream_tls = threading.local()


def _patched_set_stream(stream: torch.cuda.Stream) -> None:
    # Record the stream per-thread before delegating to the original
    # torch.cuda.set_stream, so current_stream() stays cheap.
    _current_stream_tls.value = stream
    prev_set_stream(stream)


torch.cuda.set_stream = _patched_set_stream


class _StreamPlaceholder:
    """Minimal stand-in for a CUDA stream on platforms without streams (CPU)."""

    def __init__(self):
        self.synchronize = lambda: None


def current_stream() -> torch.cuda.Stream:
    """
    replace `torch.cuda.current_stream()` with `vllm.utils.current_stream()`.
    it turns out that `torch.cuda.current_stream()` is quite expensive,
    as it will construct a new stream object at each call.
    here we patch `torch.cuda.set_stream` to keep track of the current stream
    directly, so that we can avoid calling `torch.cuda.current_stream()`.

    the underlying hypothesis is that we do not call `torch._C._cuda_setStream`
    from C/C++ code.
    """
    from vllm.platforms import current_platform

    if not hasattr(_current_stream_tls, "value") or _current_stream_tls.value is None:
        # when this function is called before any stream is set,
        # we return the default stream.
        # On ROCm using the default 0 stream in combination with RCCL
        # is hurting performance. Therefore creating a dedicated stream
        # per process
        if current_platform.is_rocm():
            # torch.cuda.set_stream here is the alias of _patched_set_stream
            torch.cuda.set_stream(torch.cuda.Stream())
        elif current_platform.is_cpu():
            _current_stream_tls.value = _StreamPlaceholder()
        else:
            current_stream = current_platform.current_stream
            if current_stream is not None:
                _current_stream_tls.value = current_stream()
            else:
                raise ValueError(
                    "Fail to set current stream, current platform "
                    "may not support current_stream with torch API"
                )
    return _current_stream_tls.value


# Global auxilary stream for running operations in background streams.
# We have single global auxilary stream to avoid an explosion of streams
# for every layer (and make profiling look sane).
#
# aux_stream() is currently used for:
#  - MoE shared_expert overlap with router
_aux_stream: torch.cuda.Stream | None = None


def aux_stream() -> torch.cuda.Stream | None:
    """
    Ensures aux_stream is initialized only once
    """
    global _aux_stream

    from vllm.platforms import current_platform

    # TODO: validate this works properly on ROCm platform.
    if _aux_stream is None and current_platform.is_cuda():
        _aux_stream = torch.cuda.Stream()

    return _aux_stream


@lru_cache(maxsize=8)
def _cuda_device_count_stateless(cuda_visible_devices: str | None = None) -> int:
    # Note: cuda_visible_devices is not used, but we keep it as an argument for
    # LRU Cache purposes.

    # Code below is based on
    # https://github.com/pytorch/pytorch/blob/
    # c1cd946818442aca8c7f812b16d187ce1586c3bc/
    # torch/cuda/__init__.py#L831C1-L831C17
    import torch.cuda
    import torch.version

    from vllm.platforms import current_platform

    if not torch.cuda._is_compiled():
        return 0
    if current_platform.is_rocm():
        # ROCm uses amdsmi instead of nvml for stateless device count
        # This requires a sufficiently modern version of Torch 2.4.0
        raw_count = (
            torch.cuda._device_count_amdsmi()
            if (hasattr(torch.cuda, "_device_count_amdsmi"))
            else -1
        )
    else:
        raw_count = torch.cuda._device_count_nvml()
    r = torch._C._cuda_getDeviceCount() if raw_count < 0 else raw_count
    return r


def cuda_device_count_stateless() -> int:
    """Get number of CUDA devices, caching based on the value of
    CUDA_VISIBLE_DEVICES at the time of call.

    This should be used instead of torch.cuda.device_count()
    unless CUDA_VISIBLE_DEVICES has already been set to the desired
    value."""

    # This can be removed and simply replaced with torch.cuda.get_device_count
    # after https://github.com/pytorch/pytorch/pull/122815 is released.
    return _cuda_device_count_stateless(envs.CUDA_VISIBLE_DEVICES)


def weak_ref_tensor(tensor: Any) -> Any:
    """
    Create a weak reference to a tensor.
    The new tensor will share the same data as the original tensor,
    but will not keep the original tensor alive.
    """
    if isinstance(tensor, torch.Tensor):
        return ixfops.weak_ref_tensor(tensor)
    else:
        return tensor


def weak_ref_tensors(
    tensors: torch.Tensor
    | list[torch.Tensor]
    | tuple[torch.Tensor]
    | IntermediateTensors,
) -> torch.Tensor | list[Any] | tuple[Any] | Any:
    """
    Convenience function to create weak references to tensors,
    for single tensor, list of tensors or tuple of tensors.
    """
    if isinstance(tensors, torch.Tensor):
        return weak_ref_tensor(tensors)
    if isinstance(tensors, list):
        return [weak_ref_tensor(t) for t in tensors]
    if isinstance(tensors, tuple):
        return tuple(weak_ref_tensor(t) for t in tensors)

    # For IntermediateTensors used in pipeline parallelism
    from vllm.sequence import IntermediateTensors

    if isinstance(tensors, IntermediateTensors):
        ret = IntermediateTensors(
            {key: weak_ref_tensor(val) for key, val in tensors.tensors.items()}
        )
        return ret
    raise ValueError("Invalid type for tensors")


def get_cuda_view_from_cpu_tensor(cpu_tensor: torch.Tensor) -> torch.Tensor:
    """
    Get a CUDA view of a CPU tensor using Unified Virtual Addressing (UVA).
    """
    assert cpu_tensor.is_pinned(), "CPU tensor must be pinned"
    return torch.ops._C.get_cuda_view_from_cpu_tensor(cpu_tensor)


# Helper function used in testing.
def _is_torch_equal_or_newer(torch_version: str, target: str) -> bool:
    # Parse into a local with a distinct name so the str parameter
    # is not shadowed by a Version object.
    installed = version.parse(torch_version)
    return installed >= version.parse(target)


def is_torch_equal_or_newer(target: str) -> bool:
    """Check if the installed torch version is >= the target version.

    Args:
        target: a version string, like "2.6.0".

    Returns:
        Whether the condition meets.
    """
    try:
        return _is_torch_equal_or_newer(str(torch.__version__), target)
    except Exception:
        # Fallback to PKG-INFO to load the package info, needed by the doc gen.
        return Version(importlib.metadata.version("torch")) >= Version(target)


def _is_torch_equal(target: str) -> bool:
    assert target.count(".") == 2
    torch_version = str(torch.__version__)
    torch_version = version.parse(torch_version)
    # torch version is like "2.6.0.dev20240101" or "2.6.0.dev20240101+cpu"
    # or "2.6.0+cu128" but never "2.6.0.1"
    return (
        torch_version >= version.parse(target)
        and version.parse(target + ".1") > torch_version
    )


def is_torch_equal(target: str) -> bool:
    """Check if the installed torch version is == the target version.

    Args:
        target: a version string, like "2.6.0".

    Returns:
        Whether the condition meets.
    """
    try:
        return _is_torch_equal(target)
    except Exception:
        return Version(importlib.metadata.version("torch")) == Version(target)


# Using dynamo with vLLM doesn't really work well with PyTorch versions < 2.4.0.
# In particular, the FakeScalarType is not supported for earlier versions of
# PyTorch which breaks dynamo for any ops registered using ScalarType.
def supports_dynamo() -> bool:
    return is_torch_equal_or_newer("2.4.0")


# Supports xccl with PyTorch versions >= 2.8.0.dev for XPU platform
def supports_xccl() -> bool:
    return (
        is_torch_equal_or_newer("2.8.0.dev") and torch.distributed.is_xccl_available()
    )


# Some backends use pytorch version < 2.4.0 which doesn't
# support `torch.library.custom_op`.
def supports_custom_op() -> bool:
    return hasattr(torch.library, "custom_op")


# create a library to hold the custom op
vllm_lib = Library("vllm", "FRAGMENT")  # noqa


def direct_register_custom_op(
    op_name: str,
    op_func: Callable,
    mutates_args: list[str] | None = None,
    fake_impl: Callable | None = None,
    target_lib: Library | None = None,
    dispatch_key: str | None = None,
    tags: tuple[torch.Tag, ...] = (),
):
    """
    `torch.library.custom_op` can have significant overhead because it
    needs to consider complicated dispatching logic. This function
    directly registers a custom op and dispatches it to the CUDA backend.
    See https://gist.github.com/youkaichao/ecbea9ec9fc79a45d2adce1784d7a9a5
    for more details.

    By default, the custom op is registered to the vLLM library. If you
    want to register it to a different library, you can pass the library
    object to the `target_lib` argument.

    IMPORTANT: the lifetime of the operator is tied to the lifetime of the
    library object. If you want to bind the operator to a different library,
    make sure the library object is alive when the operator is used.
    """
    if not supports_custom_op():
        from vllm.platforms import current_platform

        assert not current_platform.is_cuda_alike(), (
            "cuda platform needs torch>=2.4 to support custom op, "
            "chances are you are using an old version of pytorch "
            "or a custom build of pytorch. It is recommended to "
            "use vLLM in a fresh new environment and let it install "
            "the required dependencies."
        )
        return

    if mutates_args is None:
        mutates_args = []

    if dispatch_key is None:
        from vllm.platforms import current_platform

        dispatch_key = current_platform.dispatch_key

    import torch.library

    if hasattr(torch.library, "infer_schema"):
        schema_str = torch.library.infer_schema(op_func, mutates_args=mutates_args)
    else:
        # for pytorch 2.4
        import torch._custom_op.impl

        schema_str = torch._custom_op.impl.infer_schema(op_func, mutates_args)
    my_lib = target_lib or vllm_lib
    my_lib.define(op_name + schema_str, tags=tags)
    my_lib.impl(op_name, op_func, dispatch_key=dispatch_key)
    if fake_impl is not None:
        my_lib._register_fake(op_name, fake_impl)