# SPDX-License-Identifier: Apache-2.0
|
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
|
|
|
|
from __future__ import annotations
|
|
|
|
import contextlib
|
|
import gc
|
|
import os
|
|
import time
|
|
import torch
|
|
from torch.library import Library
|
|
from dataclasses import dataclass, field
|
|
from functools import lru_cache
|
|
from typing import Optional, Callable, Tuple, Generator
|
|
|
|
import vllm.envs as envs
|
|
from vllm.platforms import current_platform
|
|
from vllm.ray.lazy_utils import is_in_ray_actor
|
|
from vllm.utils import (
|
|
torch_utils,
|
|
system_utils,
|
|
)
|
|
from vllm.utils.torch_utils import (
|
|
STR_DTYPE_TO_TORCH_DTYPE,
|
|
supports_custom_op,
|
|
vllm_lib,
|
|
)
|
|
from vllm.utils.mem_utils import GiB_bytes
|
|
from vllm.utils.platform_utils import (
|
|
cuda_is_initialized,
|
|
xpu_is_initialized,
|
|
)
|
|
from vllm.logger import init_logger
|
|
|
|
from vllm_mlu.mlu_hijack_utils import MluHijackObject
|
|
|
|
# Module-level logger for this vLLM-MLU utility module.
logger = init_logger(__name__)


# Extend vLLM's dtype-string table so "int8" is accepted as a cache dtype
# on MLU, mapping it to torch.int8.
STR_DTYPE_TO_TORCH_DTYPE["int8"] = torch.int8
|
|
|
|
|
|
@dataclass
class MemorySnapshot:
    """Point-in-time reading of MLU device memory counters (bytes)."""
    torch_peak: int = 0
    free_memory: int = 0
    total_memory: int = 0
    mlu_memory: int = 0
    torch_memory: int = 0
    non_torch_memory: int = 0
    timestamp: float = 0.0
    auto_measure: bool = True

    def __post_init__(self):
        # Take a reading right away unless the caller fills the fields
        # manually (e.g. when building a delta snapshot in __sub__).
        if self.auto_measure:
            self.measure()

    def measure(self):
        """Populate every counter from the current MLU runtime state."""
        # The torch peak is read from the allocator's "allocated_bytes"
        # statistic rather than `torch.mlu.memory_reserved()`: after
        # `torch.mlu.reset_peak_memory_stats()` the reserved counter only
        # keeps growing and shrinks just on `torch.mlu.empty_cache()` or
        # when an OOM happens.
        allocator_stats = torch.mlu.memory_stats()
        self.torch_peak = allocator_stats.get("allocated_bytes.all.peak", 0)

        free, total = torch.mlu.mem_get_info()
        self.free_memory = free
        self.total_memory = total
        self.mlu_memory = total - free

        # memory_reserved() counts the bytes PyTorch obtained from the MLU
        # (via mluMalloc, etc.); the gap between overall device usage and
        # this value is attributed to non-torch consumers.
        self.torch_memory = torch.mlu.memory_reserved()
        self.non_torch_memory = self.mlu_memory - self.torch_memory

        self.timestamp = time.time()

    def __sub__(self, other: MemorySnapshot) -> MemorySnapshot:
        """Return the field-wise difference of two snapshots."""
        numeric_fields = (
            "torch_peak", "free_memory", "total_memory", "mlu_memory",
            "torch_memory", "non_torch_memory", "timestamp",
        )
        deltas = {
            name: getattr(self, name) - getattr(other, name)
            for name in numeric_fields
        }
        # A delta snapshot must never re-measure on construction.
        return MemorySnapshot(auto_measure=False, **deltas)
|
|
|
|
|
|
@dataclass
class MemoryProfilingResult:
    """Memory profiling result. All numbers are in bytes."""
    non_kv_cache_memory: int = 0
    torch_peak_increase: int = 0
    non_torch_increase: int = 0
    weights_memory: float = 0
    # Snapshots taken at the three profiling milestones.
    before_create: MemorySnapshot = field(default_factory=MemorySnapshot)
    before_profile: MemorySnapshot = field(default_factory=MemorySnapshot)
    after_profile: MemorySnapshot = field(default_factory=MemorySnapshot)
    profile_time: float = 0.0

    def __repr__(self) -> str:
        # Assemble the summary from one segment per metric; joining with a
        # single space reproduces the historical one-string message exactly.
        segments = [
            f"Memory profiling takes {self.profile_time:.2f} seconds.",
            f"Total non KV cache memory: "
            f"{(self.non_kv_cache_memory / GiB_bytes):.2f}GiB;",
            f"torch peak memory increase: "
            f"{(self.torch_peak_increase / GiB_bytes):.2f}GiB;",
            f"non-torch forward increase memory: "
            f"{(self.non_torch_increase / GiB_bytes):.2f}GiB;",
            f"weights memory: {(self.weights_memory / GiB_bytes):.2f}GiB.",
        ]
        return " ".join(segments)
|
|
|
|
|
|
@contextlib.contextmanager
def memory_profiling(
        baseline_snapshot: MemorySnapshot,
        weights_memory: int) -> Generator[MemoryProfilingResult, None, None]:
    """Memory profiling context manager.
    baseline_snapshot: the memory snapshot before the current vLLM instance.
    weights_memory: memory used by PyTorch when loading the model weights.
    Note that, before loading the model weights, we also initialize the device
    and distributed environment, which may consume some memory. This part is not
    included in the weights_memory because PyTorch does not control it.

    The memory in one GPU can be classified into 3 categories:
    1. memory used by anything other than the current vLLM instance.
    2. memory used by torch in the current vLLM instance.
    3. memory used in the current vLLM instance, but not by torch.

    A quantitive example:

    Before creating the current vLLM instance:
        category 1: 1 GiB
        category 2: 0 GiB
        category 3: 0 GiB

    After creating the current vLLM instance and loading the model,
    (i.e. before profiling):
        category 1: 1 GiB
        category 2: 2 GiB (model weights take 2 GiB)
        category 3: 0.5 GiB (memory used by NCCL)

    During profiling (peak):
        category 1: 1 GiB
        category 2: 4 GiB (peak activation tensors take 2 GiB)
        category 3: 1 GiB (memory used by NCCL + buffers for some attention backends)

    After profiling:
        category 1: 1 GiB
        category 2: 3 GiB (after garbage-collecting activation tensors)
        category 3: 1 GiB (memory used by NCCL + buffers for some attention backends)

    In this case, non-kv cache takes 5 GiB in total, including:
    a. 2 GiB used by the model weights (category 2)
    b. 2 GiB reserved for the peak activation tensors (category 2)
    c. 1 GiB used by non-torch components (category 3)

    The memory used for loading weights (a.) is directly given from the argument `weights_memory`.

    The increase of `torch.mlu.memory_stats()["allocated_bytes.all.peak"]` during profiling gives (b.).

    The increase of `non_torch_memory` from creating the current vLLM instance until after profiling to get (c.).
    """ # noqa
    # Start from a clean allocator state so the peak statistic measured
    # below reflects only the profiled region.
    gc.collect()
    torch.mlu.empty_cache()
    torch.mlu.reset_peak_memory_stats()

    result = MemoryProfilingResult()

    result.before_create = baseline_snapshot
    # the part of memory used for holding the model weights
    result.weights_memory = weights_memory

    result.before_profile.measure()

    # The caller runs its profiling workload while this generator is
    # suspended here.
    yield result

    # Drop freed activation tensors before the post-profile snapshot so it
    # reflects the steady state rather than leftover cached blocks.
    gc.collect()
    torch.mlu.empty_cache()

    result.after_profile.measure()

    # (b.) peak torch allocation growth during profiling; (c.) non-torch
    # growth since before this vLLM instance was created (see docstring).
    diff_profile = result.after_profile - result.before_profile
    diff_from_create = result.after_profile - result.before_create
    result.torch_peak_increase = diff_profile.torch_peak
    result.non_torch_increase = diff_from_create.non_torch_memory
    result.profile_time = diff_profile.timestamp
    result.non_kv_cache_memory = result.non_torch_increase + result.torch_peak_increase + result.weights_memory # noqa
|
|
|
|
|
|
@lru_cache(maxsize=8)
def _mlu_device_count_stateless(
        mlu_visible_devices: Optional[str] = None) -> int:
    """Count the devices selected by an ``MLU_VISIBLE_DEVICES`` value.

    ``None`` means the variable is unset (no restriction), so the runtime
    is queried directly; an empty string hides every device; otherwise the
    value is a comma-separated device list ("0" -> 1, "0,1,2" -> 3).
    """
    if mlu_visible_devices is None:
        return torch.mlu.device_count()
    if mlu_visible_devices == "":
        return 0
    # len(split) covers both the single-device ("0" -> 1) and the
    # multi-device ("0,1,2" -> 3) cases.
    return len(mlu_visible_devices.split(","))


def mlu_device_count_stateless() -> int:
    """Get number of MLU devices, caching based on the value of
    MLU_VISIBLE_DEVICES at the time of call.

    This should be used instead of torch.cuda.device_count()
    unless MLU_VISIBLE_DEVICES has already been set to the desired
    value."""

    # This can be removed and simply replaced with torch.cuda.get_device_count
    # after https://github.com/pytorch/pytorch/pull/122815 is released.
    #
    # BUGFIX: an unset MLU_VISIBLE_DEVICES must map to None so the runtime
    # is queried for the real device count. The previous default of "mlu"
    # fell into the "single entry, no comma" branch and always returned 1.
    return _mlu_device_count_stateless(os.environ.get("MLU_VISIBLE_DEVICES"))
|
|
|
|
|
|
def vllm__utils_system_utils___maybe_force_spawn():
    """Check if we need to force the use of the `spawn` multiprocessing start
    method.

    Replacement for vLLM's ``system_utils._maybe_force_spawn`` that also
    forces `spawn` when this out-of-tree (MLU) platform is active. If any
    reason applies, VLLM_WORKER_MULTIPROC_METHOD is overridden to "spawn".
    """
    # Nothing to do if the environment already requests `spawn`.
    if os.environ.get("VLLM_WORKER_MULTIPROC_METHOD") == "spawn":
        return

    # Collect human-readable reasons; a non-empty list triggers the override.
    reasons = []
    if is_in_ray_actor():
        # even if we choose to spawn, we need to pass the ray address
        # to the subprocess so that it knows how to connect to the ray cluster.
        # env vars are inherited by subprocesses, even if we use spawn.
        import ray

        os.environ["RAY_ADDRESS"] = ray.get_runtime_context().gcs_address
        reasons.append("In a Ray actor and can only be spawned")

    '''
    =============================
    Modify by vllm_mlu
    =============================
    @brief: Force use spawn for MLU platform.
    '''
    # Presumably `fork` is unsafe once a device runtime has been
    # initialized in this process, hence the checks below — see the linked
    # troubleshooting doc in the warning message.
    if cuda_is_initialized():
        reasons.append("CUDA is initialized")
    elif xpu_is_initialized():
        reasons.append("XPU is initialized")
    elif current_platform.is_out_of_tree():
        # NOTE(review): this keys on "platform is out-of-tree", which this
        # package assumes means MLU — confirm if other out-of-tree
        # platforms could be co-loaded with vllm_mlu.
        reasons.append("MLU is initialized")
    '''
    ==================
    End of MLU Hijack
    ==================
    '''

    if reasons:
        logger.warning(
            "We must use the `spawn` multiprocessing start method. "
            "Overriding VLLM_WORKER_MULTIPROC_METHOD to 'spawn'. "
            "See https://docs.vllm.ai/en/latest/getting_started/"
            "troubleshooting.html#python-multiprocessing "
            "for more information. Reason: %s", reasons)
        os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
|
|
|
|
|
|
'''
=============================
Modify by vllm_mlu
=============================
@brief: change dispatch_key default value from 'CUDA' to 'MLU'
'''
# Keep a handle to the original implementation before it is hijacked below.
vllm__utils__torch_utils__direct_register_custom_op_org = torch_utils.direct_register_custom_op


def vllm__utils__torch_utils__direct_register_custom_op(
    op_name: str,
    op_func: Callable,
    # NOTE: the `[]` default mirrors the hijacked signature exactly; the
    # list is only forwarded, never mutated, so sharing it is harmless.
    mutates_args: list[str] | None = [],
    fake_impl: Callable | None = None,
    target_lib: Library | None = None,
    dispatch_key: str = "MLU",
    tags: Tuple[torch.Tag, ...] = (),
):
    """Wrapper around vLLM's ``direct_register_custom_op`` whose single
    difference is the ``dispatch_key`` default of ``"MLU"`` in place of
    ``"CUDA"``; every argument is forwarded untouched.
    """
    forwarded = dict(
        op_name=op_name,
        op_func=op_func,
        mutates_args=mutates_args,
        fake_impl=fake_impl,
        target_lib=target_lib,
        dispatch_key=dispatch_key,
        tags=tags,
    )
    vllm__utils__torch_utils__direct_register_custom_op_org(**forwarded)
'''
==================
End of MLU Hijack
==================
'''
|
|
|
|
|
|
# Install the MLU overrides into the already-imported vLLM modules at import
# time, so every subsequent caller of these helpers gets the MLU behavior:
# the MLU-defaulted custom-op registration and the spawn-forcing check.
MluHijackObject.apply_hijack(torch_utils,
                             torch_utils.direct_register_custom_op,
                             vllm__utils__torch_utils__direct_register_custom_op)
MluHijackObject.apply_hijack(system_utils,
                             system_utils._maybe_force_spawn,
                             vllm__utils_system_utils___maybe_force_spawn)