Files
enginex-mlu590-vllm/vllm_mlu/utils.py
2026-04-24 09:58:03 +08:00

301 lines
10 KiB
Python

# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
from __future__ import annotations
import contextlib
import gc
import os
import time
import torch
from torch.library import Library
from dataclasses import dataclass, field
from functools import lru_cache
from typing import Optional, Callable, Tuple, Generator
import vllm.envs as envs
from vllm.platforms import current_platform
from vllm.ray.lazy_utils import is_in_ray_actor
from vllm.utils import (
torch_utils,
system_utils,
)
from vllm.utils.torch_utils import (
STR_DTYPE_TO_TORCH_DTYPE,
supports_custom_op,
vllm_lib,
)
from vllm.utils.mem_utils import GiB_bytes
from vllm.utils.platform_utils import (
cuda_is_initialized,
xpu_is_initialized,
)
from vllm.logger import init_logger
from vllm_mlu.mlu_hijack_utils import MluHijackObject
logger = init_logger(__name__)
# Register the "int8" dtype string in vLLM's shared string-to-dtype table.
# The imported mapping object is mutated in place, so other code that
# imported STR_DTYPE_TO_TORCH_DTYPE observes the new entry as well.
STR_DTYPE_TO_TORCH_DTYPE.update({"int8": torch.int8})
@dataclass
class MemorySnapshot:
    """Point-in-time view of MLU memory usage.

    Memory fields are byte counts; ``timestamp`` is wall-clock time taken
    from ``time.time()``. With the default ``auto_measure=True`` the counters
    are filled from the device as soon as the object is created. With
    ``auto_measure=False`` the snapshot simply holds the values passed in
    (zeros by default), which is how snapshot differences are built.
    """
    torch_peak: int = 0
    free_memory: int = 0
    total_memory: int = 0
    mlu_memory: int = 0
    torch_memory: int = 0
    non_torch_memory: int = 0
    timestamp: float = 0.0
    auto_measure: bool = True

    def __post_init__(self):
        # Populate the counters immediately unless explicitly disabled.
        if not self.auto_measure:
            return
        self.measure()

    def measure(self):
        """Refresh every counter from the current MLU allocator/device state."""
        # Peak usage is read from allocated_bytes rather than from
        # `torch.mlu.memory_reserved()`: after `torch.mlu.reset_peak_memory_stats()`
        # the reserved figure keeps growing and only shrinks on
        # `torch.mlu.empty_cache()` or an OOM.
        allocator_stats = torch.mlu.memory_stats()
        self.torch_peak = allocator_stats.get("allocated_bytes.all.peak", 0)

        free_bytes, total_bytes = torch.mlu.mem_get_info()
        self.free_memory = free_bytes
        self.total_memory = total_bytes
        # Everything currently occupied on the device, by any user.
        self.mlu_memory = total_bytes - free_bytes

        # Bytes PyTorch's allocator has obtained from the device; whatever
        # else is occupied is attributed to non-torch consumers.
        self.torch_memory = torch.mlu.memory_reserved()
        self.non_torch_memory = self.mlu_memory - self.torch_memory

        self.timestamp = time.time()

    def __sub__(self, other: MemorySnapshot) -> MemorySnapshot:
        """Return the field-wise difference ``self - other``.

        The result is built with ``auto_measure=False`` so it is never
        re-measured from the device.
        """
        delta_fields = (
            "torch_peak",
            "free_memory",
            "total_memory",
            "mlu_memory",
            "torch_memory",
            "non_torch_memory",
            "timestamp",
        )
        deltas = {
            name: getattr(self, name) - getattr(other, name)
            for name in delta_fields
        }
        return MemorySnapshot(**deltas, auto_measure=False)
@dataclass
class MemoryProfilingResult:
    """Outcome of a memory-profiling run.

    Memory quantities are byte counts; ``profile_time`` is in seconds.
    Each snapshot field defaults to a newly constructed ``MemorySnapshot``,
    which measures the device at construction time.
    """
    # Sum of the three components below that cannot be used for KV cache.
    non_kv_cache_memory: int = 0
    torch_peak_increase: int = 0
    non_torch_increase: int = 0
    weights_memory: float = 0
    before_create: MemorySnapshot = field(default_factory=MemorySnapshot)
    before_profile: MemorySnapshot = field(default_factory=MemorySnapshot)
    after_profile: MemorySnapshot = field(default_factory=MemorySnapshot)
    profile_time: float = 0.0

    def __repr__(self) -> str:
        def gib(num_bytes) -> str:
            # Byte count rendered as GiB with two decimals, e.g. "1.50GiB".
            return f"{(num_bytes / GiB_bytes):.2f}GiB"

        return (f"Memory profiling takes {self.profile_time:.2f} seconds. "
                f"Total non KV cache memory: {gib(self.non_kv_cache_memory)}; "
                f"torch peak memory increase: "
                f"{gib(self.torch_peak_increase)}; "
                f"non-torch forward increase memory: "
                f"{gib(self.non_torch_increase)}; "
                f"weights memory: {gib(self.weights_memory)}.")
@contextlib.contextmanager
def memory_profiling(
        baseline_snapshot: MemorySnapshot,
        weights_memory: int) -> Generator[MemoryProfilingResult, None, None]:
    """Memory profiling context manager.

    baseline_snapshot: the memory snapshot before the current vLLM instance.
    weights_memory: memory used by PyTorch when loading the model weights.
    Note that, before loading the model weights, we also initialize the device
    and distributed environment, which may consume some memory. This part is not
    included in the weights_memory because PyTorch does not control it.

    Yields a MemoryProfilingResult. Its summary fields (torch_peak_increase,
    non_torch_increase, profile_time, non_kv_cache_memory) are filled in only
    after the `with` body exits normally; if the body raises, they keep their
    default values.

    The memory in one device can be classified into 3 categories:
    1. memory used by anything other than the current vLLM instance.
    2. memory used by torch in the current vLLM instance.
    3. memory used in the current vLLM instance, but not by torch.

    A quantitative example:

    Before creating the current vLLM instance:
        category 1: 1 GiB
        category 2: 0 GiB
        category 3: 0 GiB

    After creating the current vLLM instance and loading the model,
    (i.e. before profiling):
        category 1: 1 GiB
        category 2: 2 GiB (model weights take 2 GiB)
        category 3: 0.5 GiB (memory used by NCCL)

    During profiling (peak):
        category 1: 1 GiB
        category 2: 4 GiB (peak activation tensors take 2 GiB)
        category 3: 1 GiB (memory used by NCCL + buffers for some attention backends)

    After profiling:
        category 1: 1 GiB
        category 2: 3 GiB (after garbage-collecting activation tensors)
        category 3: 1 GiB (memory used by NCCL + buffers for some attention backends)

    In this case, non-kv cache takes 5 GiB in total, including:
    a. 2 GiB used by the model weights (category 2)
    b. 2 GiB reserved for the peak activation tensors (category 2)
    c. 1 GiB used by non-torch components (category 3)

    The memory used for loading weights (a.) is given directly by the argument `weights_memory`.
    The increase of `torch.mlu.memory_stats()["allocated_bytes.all.peak"]` during profiling gives (b.).
    The increase of `non_torch_memory` from creating the current vLLM instance until after profiling gives (c.).
    """ # noqa
    # Start from a clean allocator state so the peak counter read below
    # reflects only the region profiled inside the `with` body.
    gc.collect()
    torch.mlu.empty_cache()
    torch.mlu.reset_peak_memory_stats()

    # Constructing the result also builds default snapshots, which measure
    # the device immediately (MemorySnapshot auto-measures by default).
    result = MemoryProfilingResult()

    result.before_create = baseline_snapshot
    # the part of memory used for holding the model weights
    result.weights_memory = weights_memory

    # Re-measure now that the peak statistics have been reset.
    result.before_profile.measure()

    yield result

    # Not wrapped in try/finally: if the body raises, nothing below runs.
    # Drop garbage and cached blocks before the closing snapshot.
    gc.collect()
    torch.mlu.empty_cache()

    result.after_profile.measure()

    # (b.) comes from the profiled window only; (c.) is measured from before
    # the vLLM instance was created.
    diff_profile = result.after_profile - result.before_profile
    diff_from_create = result.after_profile - result.before_create
    result.torch_peak_increase = diff_profile.torch_peak
    result.non_torch_increase = diff_from_create.non_torch_memory
    result.profile_time = diff_profile.timestamp
    result.non_kv_cache_memory = result.non_torch_increase + result.torch_peak_increase + result.weights_memory # noqa
@lru_cache(maxsize=8)
def _mlu_device_count_stateless(
mlu_visible_devices: Optional[str] = None) -> int:
if mlu_visible_devices is None:
return torch.mlu.device_count()
if mlu_visible_devices == "":
return 0
if "," not in mlu_visible_devices:
return 1
return len(mlu_visible_devices.split(","))
def mlu_device_count_stateless() -> int:
    """Get number of MLU devices, caching based on the value of
    MLU_VISIBLE_DEVICES at the time of call.

    This should be used instead of ``torch.mlu.device_count()`` unless
    MLU_VISIBLE_DEVICES has already been set to the desired value.

    When MLU_VISIBLE_DEVICES is unset, ``None`` is forwarded so that the
    helper asks the runtime (``torch.mlu.device_count()``) for the count.
    """
    # The default must be None rather than a placeholder string. The helper
    # counts any non-empty string without commas as exactly one device, so a
    # placeholder would report 1 device on multi-MLU hosts whenever the
    # variable is unset, and the helper's runtime query would never run.
    return _mlu_device_count_stateless(os.environ.get("MLU_VISIBLE_DEVICES"))
def vllm__utils_system_utils___maybe_force_spawn():
    """Force the `spawn` multiprocessing start method when forking is unsafe.

    Does nothing if VLLM_WORKER_MULTIPROC_METHOD is already "spawn".
    Otherwise it collects reasons to avoid fork:
    - running inside a Ray actor;
    - an initialized CUDA or XPU runtime;
    - otherwise, an out-of-tree platform (the MLU case).
    If any reason applies, it logs a warning and sets
    VLLM_WORKER_MULTIPROC_METHOD to "spawn" in the environment.
    """
    if os.environ.get("VLLM_WORKER_MULTIPROC_METHOD") == "spawn":
        return

    spawn_reasons = []

    if is_in_ray_actor():
        # Spawned children still inherit environment variables, so export
        # the Ray GCS address for them to reconnect to the cluster.
        import ray
        gcs_address = ray.get_runtime_context().gcs_address
        os.environ["RAY_ADDRESS"] = gcs_address
        spawn_reasons.append("In a Ray actor and can only be spawned")

    '''
    =============================
    Modify by vllm_mlu
    =============================
    @brief: Force use spawn for MLU platform.
    '''
    # At most one accelerator reason is recorded; checks run in this order
    # and stop at the first match.
    if cuda_is_initialized():
        spawn_reasons.append("CUDA is initialized")
    elif xpu_is_initialized():
        spawn_reasons.append("XPU is initialized")
    elif current_platform.is_out_of_tree():
        spawn_reasons.append("MLU is initialized")
    '''
    ==================
    End of MLU Hijack
    ==================
    '''

    if not spawn_reasons:
        return

    logger.warning(
        "We must use the `spawn` multiprocessing start method. "
        "Overriding VLLM_WORKER_MULTIPROC_METHOD to 'spawn'. "
        "See https://docs.vllm.ai/en/latest/getting_started/"
        "troubleshooting.html#python-multiprocessing "
        "for more information. Reason: %s", spawn_reasons)
    os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
'''
=============================
Modify by vllm_mlu
=============================
@brief: change dispatch_key default value from 'CUDA' to 'MLU'
'''
vllm__utils__torch_utils__direct_register_custom_op_org = torch_utils.direct_register_custom_op


def vllm__utils__torch_utils__direct_register_custom_op(
    op_name: str,
    op_func: Callable,
    mutates_args: list[str] | None = None,
    fake_impl: Callable | None = None,
    target_lib: Library | None = None,
    dispatch_key: str = "MLU",
    tags: Tuple[torch.Tag, ...] = (),
):
    """Register a custom op through the upstream vLLM helper, defaulting to MLU.

    Thin wrapper around the original ``direct_register_custom_op`` captured
    above. The only difference in defaults is ``dispatch_key="MLU"``. All
    arguments are forwarded by keyword.

    Args:
        op_name: name of the custom op to register.
        op_func: Python callable implementing the op.
        mutates_args: names of the arguments ``op_func`` mutates. ``None``
            (the default) is normalized to an empty list.
        fake_impl: forwarded unchanged to the upstream helper.
        target_lib: forwarded unchanged to the upstream helper.
        dispatch_key: dispatch key to register ``op_func`` under.
        tags: torch op tags, forwarded unchanged.
    """
    # Build a fresh list per call instead of sharing one mutable default
    # list object across every registration.
    if mutates_args is None:
        mutates_args = []
    vllm__utils__torch_utils__direct_register_custom_op_org(
        op_name=op_name,
        op_func=op_func,
        mutates_args=mutates_args,
        fake_impl=fake_impl,
        target_lib=target_lib,
        dispatch_key=dispatch_key,
        tags=tags,
    )
'''
==================
End of MLU Hijack
==================
'''
# Install the MLU-specific replacements defined in this module over the
# upstream vLLM utilities. Each call passes the target module, the original
# function object, and its replacement. Presumably apply_hijack rebinds the
# module attribute to the replacement; TODO confirm against MluHijackObject.
# Order: custom-op registration first, then the spawn-method check.
MluHijackObject.apply_hijack(torch_utils,
                             torch_utils.direct_register_custom_op,
                             vllm__utils__torch_utils__direct_register_custom_op)
MluHijackObject.apply_hijack(system_utils,
                             system_utils._maybe_force_spawn,
                             vllm__utils_system_utils___maybe_force_spawn)