# Approximate accelerator-memory accounting for PyTorch training loops.
from contextlib import contextmanager
|
|
from dataclasses import fields
|
|
from typing import Dict, Tuple, List, Optional
|
|
import torch
|
|
|
|
# Binary-unit conversion factors (bytes per MiB / GiB).
NUM_BYTES_IN_MB = 1024**2
NUM_BYTES_IN_GB = 1024**3
|
|
|
|
|
|
class MemoryAnalyzer:
    """Track approximate accelerator memory used by a model.

    Reports four buckets: weights, gradients, activations (tensors saved
    for backward), and optimizer state.

    This analyzer is mostly accurate only if you construct it at the start
    of training, run the forward pass inside ``record_activation()``, and
    call ``get_memory_usage_in_gb`` / ``get_memory_usage_in_mb`` at the end
    of the forward pass.

    NOTE: Misuse has a real cost — the pack hooks record bookkeeping for
    every tensor saved for backward, and rely on the backward pass (unpack
    hooks) to drain it after each step.

    Limitations:
        1. does not work with customized operators
        2. does not work with functional operators
        3. activations are approximated as the tensors saved for backward
           (when they require gradients), so totals may not be exact.
    """

    def __init__(
        self, model: torch.nn.Module, optimizer: Optional[torch.optim.Optimizer] = None
    ):
        """
        Args:
            model: module whose weights/gradients/activations are measured.
            optimizer: optional optimizer whose state is included in totals.
        """
        self.model = model
        self.optimizer = optimizer

        # Storage data_ptrs of tensors currently counted as live activations.
        self.activ_addrs = set()
        # Running byte total of the storages tracked in ``activ_addrs``.
        self.activ_memory = 0

    @staticmethod
    def _is_activation(x) -> bool:
        """True for a non-CPU tensor that participates in autograd.

        BUGFIX: compare ``x.device.type`` to ``"cpu"``. The original compared
        a ``torch.device`` object directly to a string, which is unreliable
        across PyTorch versions.
        """
        return torch.is_tensor(x) and x.requires_grad and x.device.type != "cpu"

    def _get_weight_grads_addrs(self) -> set:
        """Storage addresses of all parameters and their gradients.

        Used to avoid double counting parameter/gradient storages as
        activations in the pack/unpack hooks.
        """
        weights = {p.untyped_storage().data_ptr() for p in self.model.parameters()}
        grads = {
            p.grad.untyped_storage().data_ptr()
            for p in self.model.parameters()
            if p.grad is not None
        }
        return weights | grads

    def pack_hook(self):
        """Build a pack hook for ``torch.autograd.graph.saved_tensors_hooks``.

        The hook counts each newly seen activation storage exactly once,
        skipping storages that belong to parameters or gradients (those are
        accounted in their own buckets).
        """

        def _pack_hook(x):
            if self._is_activation(x):
                weight_grads = self._get_weight_grads_addrs()
                # NOTE: storage is more accurate than using
                # x.nelement() * x.element_size()
                data_ptr = x.untyped_storage().data_ptr()
                if data_ptr not in weight_grads and data_ptr not in self.activ_addrs:
                    self.activ_addrs.add(data_ptr)
                    self.activ_memory += x.untyped_storage().size()
            return x

        return _pack_hook

    def unpack_hook(self):
        """Build the matching unpack hook: release a storage's byte count
        once backward consumes the saved tensor."""

        def _unpack_hook(x):
            if self._is_activation(x):
                weight_grads = self._get_weight_grads_addrs()
                data_ptr = x.untyped_storage().data_ptr()
                if data_ptr not in weight_grads and data_ptr in self.activ_addrs:
                    self.activ_addrs.remove(data_ptr)
                    self.activ_memory -= x.untyped_storage().size()
            return x

        return _unpack_hook

    @contextmanager
    def record_activation(self):
        """Context manager that records activation memory for any forward
        pass executed inside it."""
        with torch.autograd.graph.saved_tensors_hooks(
            self.pack_hook(), self.unpack_hook()
        ):
            yield

    @staticmethod
    def get_weight_memory(model: torch.nn.Module) -> int:
        """Total bytes of non-CPU parameters of ``model``."""
        return sum(
            p.nelement() * p.element_size()
            for p in model.parameters()
            if p.device.type != "cpu"
        )

    @staticmethod
    def get_gradient_memory(model: torch.nn.Module) -> int:
        """Total bytes of materialized non-CPU gradients of ``model``."""
        return sum(
            p.grad.nelement() * p.grad.element_size()
            for p in model.parameters()
            if p.grad is not None and p.grad.device.type != "cpu"
        )

    def _sum_activation_memory(self) -> int:
        """Bytes currently tracked as live activations."""
        return self.activ_memory

    def get_optimizer_state_memory(self) -> int:
        """Bytes referenced by the optimizer: its params plus its state.

        Generalized from the original AdamW-only check — ``param_groups``
        and ``state`` belong to the base ``torch.optim.Optimizer`` API, so
        any optimizer is supported; ``None`` (or a non-optimizer) yields 0.

        NOTE(review): the param_groups term counts the model parameters
        themselves, which are also reported as weight memory — confirm this
        double counting is intended before relying on the combined total.
        """
        if not isinstance(self.optimizer, torch.optim.Optimizer):
            return 0
        total = sum(
            p.nelement() * p.element_size()
            for pg in self.optimizer.param_groups
            for p in pg["params"]
            if torch.is_tensor(p) and p.device.type != "cpu"
        )
        for state in self.optimizer.state.values():
            total += sum(
                v.nelement() * v.element_size()
                for v in state.values()
                if torch.is_tensor(v) and v.device.type != "cpu"
            )
        return total

    def _get_memory_usage(self) -> Tuple[int, int, int, int]:
        """Return (weight, gradient, activation, optimizer-state) bytes."""
        return (
            self.get_weight_memory(self.model),
            self.get_gradient_memory(self.model),
            self._sum_activation_memory(),
            self.get_optimizer_state_memory(),
        )

    def get_memory_usage_in_gb(self) -> str:
        """Human-readable per-bucket usage summary, in GB."""
        w, g, a, opt = self._get_memory_usage()

        return (
            f"Total: {(w + g + a + opt) / NUM_BYTES_IN_GB:.3f} GB, "
            f"weight: {w / NUM_BYTES_IN_GB:.3f} GB, "
            f"gradient: {g / NUM_BYTES_IN_GB:.3f} GB, "
            f"activation: {a / NUM_BYTES_IN_GB:.3f} GB, "
            f"optimizer states: {opt / NUM_BYTES_IN_GB:.3f} GB"
        )

    def get_memory_usage_in_mb(self) -> str:
        """Human-readable per-bucket usage summary, in MB."""
        w, g, a, opt = self._get_memory_usage()

        return (
            f"Total: {(w + g + a + opt) / NUM_BYTES_IN_MB:.2f} MB, "
            f"weight: {w / NUM_BYTES_IN_MB:.2f} MB, "
            f"gradient: {g / NUM_BYTES_IN_MB:.2f} MB, "
            f"activation: {a / NUM_BYTES_IN_MB:.2f} MB, "
            f"optimizer states: {opt / NUM_BYTES_IN_MB:.2f} MB"
        )