Files
enginex-vastai-va16-vllm/vacc_tools/memory_analyzer.py
2026-04-02 04:55:00 +00:00

152 lines
5.2 KiB
Python

from contextlib import contextmanager
from dataclasses import fields
from typing import Dict, Tuple, List, Optional
import torch
# Binary-unit conversion factors for the human-readable reports below.
NUM_BYTES_IN_MB = 1 << 20  # bytes per mebibyte (1024**2)
NUM_BYTES_IN_GB = 1 << 30  # bytes per gibibyte (1024**3)
class MemoryAnalyzer:
    def __init__(
        self, model: torch.nn.Module, optimizer: Optional[torch.optim.Optimizer] = None
    ):
        """This memory usage analyzer will be mostly accurate only if you initialize
        at the beginning and insert `get_memory_usage_in_gb` at the end of your
        forward pass.

        NOTE: It will have a negative impact if not properly used, as it stores
        activations of every nn.Module's forward function and relies on the user
        to reset it every time the forward pass ends.

        Limitations:
            1. does not work with customized operators
            2. does not work with functional operators
            3. it approximates activation as nn.Module.forward's output (if it's
               inside the graph and requires gradients), so it may not be exactly
               accurate.

        Args:
            model: the module whose parameters/gradients/activations are tracked.
            optimizer: optional optimizer; only AdamW state is measured
                (see `get_optimizer_state_memory`).
        """
        self.model = model
        self.optimizer = optimizer
        # Storage data_ptrs of tensors currently counted as activations, plus
        # their running total size in bytes.
        self.activ_addrs = set()
        self.activ_memory = 0

    @staticmethod
    def _is_activation(x) -> bool:
        """True if `x` is an autograd-tracked tensor living off the CPU."""
        # FIX: compare `device.type` to the string, not the torch.device object
        # itself -- device/str equality is unreliable across PyTorch versions,
        # so `x.device != "cpu"` could be True even for CPU tensors.
        return torch.is_tensor(x) and x.requires_grad and x.device.type != "cpu"

    def _get_weight_grads_addrs(self):
        """Storage addresses of all parameters and their gradients.

        Used to avoid counting weights/grads as activations when autograd
        saves them for backward.
        """
        weights = {p.untyped_storage().data_ptr() for p in self.model.parameters()}
        grads = {
            p.grad.untyped_storage().data_ptr()
            for p in self.model.parameters()
            if p.grad is not None
        }
        return weights | grads

    def pack_hook(self):
        """Build the saved-tensor pack hook: counts each new activation storage
        exactly once (dedup by storage data_ptr) when autograd saves it."""

        def _pack_hook(x):
            if self._is_activation(x):
                weight_grads = self._get_weight_grads_addrs()
                # NOTE: storage is more accurate than using x.nelement() * x.element_size()
                data_ptr = x.untyped_storage().data_ptr()
                if data_ptr not in weight_grads and data_ptr not in self.activ_addrs:
                    self.activ_addrs.add(data_ptr)
                    self.activ_memory += x.untyped_storage().size()
            return x

        return _pack_hook

    def unpack_hook(self):
        """Build the saved-tensor unpack hook: un-counts an activation storage
        when backward retrieves it (its memory can then be freed)."""

        def _unpack_hook(x):
            if self._is_activation(x):
                weight_grads = self._get_weight_grads_addrs()
                data_ptr = x.untyped_storage().data_ptr()
                if data_ptr not in weight_grads and data_ptr in self.activ_addrs:
                    self.activ_addrs.remove(data_ptr)
                    self.activ_memory -= x.untyped_storage().size()
            return x

        return _unpack_hook

    @contextmanager
    def record_activation(self):
        """Context manager that tracks activation memory via
        torch.autograd.graph.saved_tensors_hooks for code run inside it."""
        with torch.autograd.graph.saved_tensors_hooks(
            self.pack_hook(), self.unpack_hook()
        ):
            yield

    @staticmethod
    def get_weight_memory(model: torch.nn.Module) -> int:
        """Total bytes of non-CPU parameters of `model`."""
        return sum(
            p.nelement() * p.element_size()
            for p in model.parameters()
            if p.device.type != "cpu"  # FIX: was `p.device != "cpu"` (device vs str)
        )

    @staticmethod
    def get_gradient_memory(model: torch.nn.Module) -> int:
        """Total bytes of non-CPU parameter gradients of `model`."""
        return sum(
            p.grad.nelement() * p.grad.element_size()
            for p in model.parameters()
            if p.grad is not None and p.grad.device.type != "cpu"
        )

    def _sum_activation_memory(self) -> int:
        """Bytes currently attributed to saved activations."""
        return self.activ_memory

    def get_optimizer_state_memory(self) -> int:
        """Bytes held by the optimizer on non-CPU devices.

        Only AdamW is supported; any other (or no) optimizer reports 0.
        Counts the param_group tensors plus every tensor in the per-param
        state (exp_avg, exp_avg_sq, step, ...).
        """
        if isinstance(self.optimizer, torch.optim.AdamW):
            params = sum(
                p.nelement() * p.element_size()
                for pg in self.optimizer.param_groups
                for p in pg["params"]
                if torch.is_tensor(p) and p.device.type != "cpu"
            )
            for state in self.optimizer.state.values():
                params += sum(
                    v.nelement() * v.element_size()
                    for v in state.values()
                    if torch.is_tensor(v) and v.device.type != "cpu"
                )
            return params
        return 0

    def _get_memory_usage(self) -> Tuple[int, int, int, int]:
        """(weight, gradient, activation, optimizer-state) byte counts."""
        return (
            self.get_weight_memory(self.model),
            self.get_gradient_memory(self.model),
            self._sum_activation_memory(),
            self.get_optimizer_state_memory(),
        )

    def get_memory_usage_in_gb(self) -> str:
        """Human-readable memory breakdown in GiB."""
        w, g, a, opt = self._get_memory_usage()
        return (
            f"Total: {(w + g + a + opt) / NUM_BYTES_IN_GB:.3f} GB, "
            f"weight: {w / NUM_BYTES_IN_GB:.3f} GB, "
            f"gradient: {g / NUM_BYTES_IN_GB:.3f} GB, "
            f"activation: {a / NUM_BYTES_IN_GB:.3f} GB, "
            f"optimizer states: {opt / NUM_BYTES_IN_GB:.3f} GB"
        )

    def get_memory_usage_in_mb(self) -> str:
        """Human-readable memory breakdown in MiB."""
        w, g, a, opt = self._get_memory_usage()
        return (
            f"Total: {(w + g + a + opt) / NUM_BYTES_IN_MB:.2f} MB, "
            f"weight: {w / NUM_BYTES_IN_MB:.2f} MB, "
            f"gradient: {g / NUM_BYTES_IN_MB:.2f} MB, "
            f"activation: {a / NUM_BYTES_IN_MB:.2f} MB, "
            f"optimizer states: {opt / NUM_BYTES_IN_MB:.2f} MB"
        )