# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project

from __future__ import annotations

import contextlib
import gc
import os
import time
from dataclasses import dataclass, field
from functools import lru_cache
from typing import Optional, Callable, Tuple, Generator

import torch
from torch.library import Library

import vllm.envs as envs
from vllm.platforms import current_platform
from vllm.ray.lazy_utils import is_in_ray_actor
from vllm.utils import (
    torch_utils,
    system_utils,
)
from vllm.utils.torch_utils import (
    STR_DTYPE_TO_TORCH_DTYPE,
    supports_custom_op,
    vllm_lib,
)
from vllm.utils.mem_utils import GiB_bytes
from vllm.utils.platform_utils import (
    cuda_is_initialized,
    xpu_is_initialized,
)
from vllm.logger import init_logger

from vllm_mlu.mlu_hijack_utils import MluHijackObject

logger = init_logger(__name__)

# Make the dtype string "int8" resolvable through vLLM's shared
# str -> torch.dtype lookup table.
STR_DTYPE_TO_TORCH_DTYPE["int8"] = torch.int8


@dataclass
class MemorySnapshot:
    """Point-in-time snapshot of MLU memory usage. All sizes are in bytes.

    When ``auto_measure`` is True (the default), :meth:`measure` runs
    immediately on construction. Snapshots produced by subtraction are
    created with ``auto_measure=False`` so they hold only the differences.
    """

    torch_peak: int = 0
    free_memory: int = 0
    total_memory: int = 0
    # Device-wide used memory (total - free), including other processes.
    mlu_memory: int = 0
    # Bytes currently reserved by the PyTorch caching allocator.
    torch_memory: int = 0
    # mlu_memory - torch_memory.
    non_torch_memory: int = 0
    timestamp: float = 0.0
    auto_measure: bool = True

    def __post_init__(self):
        if self.auto_measure:
            self.measure()

    def measure(self):
        """Refresh every field from the current MLU / allocator state."""
        # we measure the torch peak memory usage via allocated_bytes,
        # rather than `torch.mlu.memory_reserved()` .
        # After `torch.mlu.reset_peak_memory_stats()`,
        # `torch.mlu.memory_reserved()` will keep growing, and only shrink
        # when we call `torch.mlu.empty_cache()` or OOM happens.
        self.torch_peak = torch.mlu.memory_stats().get(
            "allocated_bytes.all.peak", 0)

        self.free_memory, self.total_memory = torch.mlu.mem_get_info()
        self.mlu_memory = self.total_memory - self.free_memory

        # torch.mlu.memory_reserved() is how many bytes
        # PyTorch gets from mlu (by calling mluMalloc, etc.)
        # this is used to measure the non-torch memory usage
        self.torch_memory = torch.mlu.memory_reserved()

        self.non_torch_memory = self.mlu_memory - self.torch_memory
        self.timestamp = time.time()

    def __sub__(self, other: MemorySnapshot) -> MemorySnapshot:
        """Return the field-wise difference ``self - other`` (no measuring)."""
        return MemorySnapshot(
            torch_peak=self.torch_peak - other.torch_peak,
            free_memory=self.free_memory - other.free_memory,
            total_memory=self.total_memory - other.total_memory,
            mlu_memory=self.mlu_memory - other.mlu_memory,
            torch_memory=self.torch_memory - other.torch_memory,
            non_torch_memory=self.non_torch_memory - other.non_torch_memory,
            timestamp=self.timestamp - other.timestamp,
            auto_measure=False,
        )


@dataclass
class MemoryProfilingResult:
    """Memory profiling result. All numbers are in bytes.

    Note: each default snapshot field is built with ``MemorySnapshot()``,
    which measures the device at construction time (``auto_measure=True``).
    """

    non_kv_cache_memory: int = 0
    torch_peak_increase: int = 0
    non_torch_increase: int = 0
    weights_memory: float = 0
    before_create: MemorySnapshot = field(default_factory=MemorySnapshot)
    before_profile: MemorySnapshot = field(default_factory=MemorySnapshot)
    after_profile: MemorySnapshot = field(default_factory=MemorySnapshot)
    profile_time: float = 0.0

    def __repr__(self) -> str:
        return (f"Memory profiling takes {self.profile_time:.2f} seconds. "
                f"Total non KV cache memory: "
                f"{(self.non_kv_cache_memory / GiB_bytes):.2f}GiB; "
                f"torch peak memory increase: "
                f"{(self.torch_peak_increase / GiB_bytes):.2f}GiB; "
                f"non-torch forward increase memory: "
                f"{(self.non_torch_increase / GiB_bytes):.2f}GiB; "
                f"weights memory: {(self.weights_memory / GiB_bytes):.2f}GiB.")


@contextlib.contextmanager
def memory_profiling(
        baseline_snapshot: MemorySnapshot,
        weights_memory: int) -> Generator[MemoryProfilingResult, None, None]:
    """Memory profiling context manager.

    baseline_snapshot: the memory snapshot before the current vLLM instance.
    weights_memory: memory used by PyTorch when loading the model weights.
        Note that, before loading the model weights, we also initialize the device
        and distributed environment, which may consume some memory. This part is not
        included in the weights_memory because PyTorch does not control it.

    The memory in one MLU device can be classified into 3 categories:
    1. memory used by anything other than the current vLLM instance.
    2. memory used by torch in the current vLLM instance.
    3. memory used in the current vLLM instance, but not by torch.

    A quantitative example:

    Before creating the current vLLM instance:
        category 1: 1 GiB
        category 2: 0 GiB
        category 3: 0 GiB

    After creating the current vLLM instance and loading the model,
    (i.e. before profiling):
        category 1: 1 GiB
        category 2: 2 GiB (model weights take 2 GiB)
        category 3: 0.5 GiB (memory used by the collective-communication library, e.g. NCCL)

    During profiling (peak):
        category 1: 1 GiB
        category 2: 4 GiB (peak activation tensors take 2 GiB)
        category 3: 1 GiB (communication library + buffers for some attention backends)

    After profiling:
        category 1: 1 GiB
        category 2: 3 GiB (after garbage-collecting activation tensors)
        category 3: 1 GiB (communication library + buffers for some attention backends)

    In this case, non-kv cache takes 5 GiB in total, including:
    a. 2 GiB used by the model weights (category 2)
    b. 2 GiB reserved for the peak activation tensors (category 2)
    c. 1 GiB used by non-torch components (category 3)

    The memory used for loading weights (a.) is directly given from the argument `weights_memory`.

    The increase of `torch.mlu.memory_stats()["allocated_bytes.all.peak"]` during profiling gives (b.).

    The increase of `non_torch_memory` from creating the current vLLM instance until after profiling gives (c.).

    Yields:
        A MemoryProfilingResult. Its measured fields (increases, profile
        time, non_kv_cache_memory) are filled in only after the ``with``
        body exits normally; if the body raises, they are left unset.
    """  # noqa
    gc.collect()
    torch.mlu.empty_cache()
    # Reset so that the peak measured after profiling reflects only the
    # profiled workload.
    torch.mlu.reset_peak_memory_stats()

    result = MemoryProfilingResult()

    result.before_create = baseline_snapshot
    # the part of memory used for holding the model weights
    result.weights_memory = weights_memory

    result.before_profile.measure()

    yield result

    # Drop collectable activations and cached blocks before the final read.
    gc.collect()
    torch.mlu.empty_cache()

    result.after_profile.measure()

    diff_profile = result.after_profile - result.before_profile
    diff_from_create = result.after_profile - result.before_create
    result.torch_peak_increase = diff_profile.torch_peak
    result.non_torch_increase = diff_from_create.non_torch_memory
    result.profile_time = diff_profile.timestamp
    result.non_kv_cache_memory = result.non_torch_increase + result.torch_peak_increase + result.weights_memory  # noqa


@lru_cache(maxsize=8)
def _mlu_device_count_stateless(
        mlu_visible_devices: Optional[str] = None) -> int:
    """Return the number of MLU devices implied by ``mlu_visible_devices``.

    ``mlu_visible_devices`` is the raw value of MLU_VISIBLE_DEVICES and also
    acts as the ``lru_cache`` key. ``None`` (variable unset) queries
    ``torch.mlu.device_count()``; otherwise the comma-separated entries are
    counted, ignoring blank ones (e.g. produced by a trailing comma).
    """
    if mlu_visible_devices is None:
        return torch.mlu.device_count()
    # "" -> 0, "0" -> 1, "0,1" -> 2, "0,1," -> 2.
    return sum(1 for dev in mlu_visible_devices.split(",") if dev.strip())


def mlu_device_count_stateless() -> int:
    """Get number of MLU devices, caching based on the value of
    MLU_VISIBLE_DEVICES at the time of call.

    This should be used instead of torch.mlu.device_count()
    unless MLU_VISIBLE_DEVICES has already been set to the desired
    value."""

    # Pass None when the variable is unset so the helper queries the real
    # device count instead of parsing a placeholder string as one device.
    return _mlu_device_count_stateless(os.environ.get("MLU_VISIBLE_DEVICES"))


def vllm__utils_system_utils___maybe_force_spawn():
    """Check if we need to force the use of the `spawn` multiprocessing start
    method.

    MLU replacement for ``vllm.utils.system_utils._maybe_force_spawn``
    (installed via ``MluHijackObject.apply_hijack`` at the bottom of this
    module). Besides the upstream CUDA/XPU checks, it forces `spawn` whenever
    the current platform is out-of-tree.
    """
    if os.environ.get("VLLM_WORKER_MULTIPROC_METHOD") == "spawn":
        return

    reasons = []
    if is_in_ray_actor():
        # even if we choose to spawn, we need to pass the ray address
        # to the subprocess so that it knows how to connect to the ray cluster.
        # env vars are inherited by subprocesses, even if we use spawn.
        import ray
        os.environ["RAY_ADDRESS"] = ray.get_runtime_context().gcs_address
        reasons.append("In a Ray actor and can only be spawned")

    '''
    =============================
    Modify by vllm_mlu
    =============================
    @brief: Force use spawn for MLU platform.
    '''
    if cuda_is_initialized():
        reasons.append("CUDA is initialized")
    elif xpu_is_initialized():
        reasons.append("XPU is initialized")
    elif current_platform.is_out_of_tree():
        reasons.append("MLU is initialized")
    '''
    ==================
    End of MLU Hijack
    ==================
    '''

    if reasons:
        logger.warning(
            "We must use the `spawn` multiprocessing start method. "
            "Overriding VLLM_WORKER_MULTIPROC_METHOD to 'spawn'. "
            "See https://docs.vllm.ai/en/latest/getting_started/"
            "troubleshooting.html#python-multiprocessing "
            "for more information. Reason: %s", reasons)
        os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"


'''
=============================
Modify by vllm_mlu
=============================
@brief: change dispatch_key default value from 'CUDA' to 'MLU'
'''
vllm__utils__torch_utils__direct_register_custom_op_org = torch_utils.direct_register_custom_op


def vllm__utils__torch_utils__direct_register_custom_op(
    op_name: str,
    op_func: Callable,
    mutates_args: list[str] | None = None,
    fake_impl: Callable | None = None,
    target_lib: Library | None = None,
    dispatch_key: str = "MLU",
    tags: Tuple[torch.Tag, ...] = (),
):
    """Forward to the original ``direct_register_custom_op`` with
    ``dispatch_key`` defaulting to "MLU".

    ``mutates_args`` defaults to ``None`` and is normalized to a fresh empty
    list per call, so no list object is shared between registrations.
    """
    if mutates_args is None:
        mutates_args = []
    vllm__utils__torch_utils__direct_register_custom_op_org(
        op_name=op_name,
        op_func=op_func,
        mutates_args=mutates_args,
        fake_impl=fake_impl,
        target_lib=target_lib,
        dispatch_key=dispatch_key,
        tags=tags,
    )


'''
==================
End of MLU Hijack
==================
'''

MluHijackObject.apply_hijack(torch_utils,
                             torch_utils.direct_register_custom_op,
                             vllm__utils__torch_utils__direct_register_custom_op)
MluHijackObject.apply_hijack(system_utils,
                             system_utils._maybe_force_spawn,
                             vllm__utils_system_utils___maybe_force_spawn)