init
This commit is contained in:
151
vacc_tools/memory_analyzer.py
Normal file
151
vacc_tools/memory_analyzer.py
Normal file
@@ -0,0 +1,151 @@
from contextlib import contextmanager
from dataclasses import fields
from typing import Dict, Tuple, List, Optional

import torch

NUM_BYTES_IN_MB = 1024**2
NUM_BYTES_IN_GB = 1024**3

class MemoryAnalyzer:
    def __init__(
        self, model: torch.nn.Module, optimizer: Optional[torch.optim.Optimizer] = None
    ):
        """This memory usage analyzer will be mostly accurate only if you initialize
        at the beginning and insert `get_memory_usage_in_gb` at the end of your
        forward pass.

        NOTE: It will have a negative impact if not properly used, as it stores
        activations of every nn.Module's forward function and relies on the user
        to call `reset` every time the forward pass ends.

        Limitations:
        1. does not work with customized operators
        2. does not work with functional operators
        3. it approximates activation as nn.Module.forward's output (if it's
        inside the graph requires gradients), so it may not be exactly accurate.

        Args:
            model: module whose weights/grads/activations are measured.
            optimizer: optional optimizer; only AdamW state is counted by
                `get_optimizer_state_memory` (anything else reports 0).
        """
        self.model = model
        self.optimizer = optimizer

        # Storage addresses currently attributed to activations, and the total
        # number of storage bytes they account for. Maintained by the
        # pack/unpack hooks installed by `record_activation`.
        self.activ_addrs = set()
        self.activ_memory = 0

    def reset(self):
        """Forget all recorded activation state.

        Call this once a forward/backward pass has fully completed; otherwise
        stale storage addresses accumulate and the activation figure drifts.
        """
        self.activ_addrs.clear()
        self.activ_memory = 0

    @staticmethod
    def _is_activation(x):
        """Heuristic: a tensor counts as an activation if it is in the autograd
        graph (requires grad) and does not live on the CPU.

        Compares `device.type` rather than the device object against a string so
        indexed devices such as "cuda:1" are handled uniformly across versions.
        """
        return torch.is_tensor(x) and x.requires_grad and x.device.type != "cpu"

    def _get_weight_grads_addrs(self):
        """Return the set of storage addresses of all parameters and their grads."""
        params = list(self.model.parameters())
        weights = {p.untyped_storage().data_ptr() for p in params}
        grads = {
            p.grad.untyped_storage().data_ptr() for p in params if p.grad is not None
        }
        return weights | grads

    def pack_hook(self):
        """Build the saved-tensors pack hook that records activation storages."""

        def _pack_hook(x):
            if self._is_activation(x):
                weight_grads = self._get_weight_grads_addrs()
                # NOTE: storage is more accurate than using x.nelement() * x.element_size()
                data_ptr = x.untyped_storage().data_ptr()
                # Count each storage once, and never count weights/grads that
                # autograd happens to save for backward.
                if data_ptr not in weight_grads and data_ptr not in self.activ_addrs:
                    self.activ_addrs.add(data_ptr)
                    self.activ_memory += x.untyped_storage().size()
            return x

        return _pack_hook

    def unpack_hook(self):
        """Build the saved-tensors unpack hook that releases recorded storages."""

        def _unpack_hook(x):
            if self._is_activation(x):
                weight_grads = self._get_weight_grads_addrs()
                data_ptr = x.untyped_storage().data_ptr()
                if data_ptr not in weight_grads and data_ptr in self.activ_addrs:
                    self.activ_addrs.remove(data_ptr)
                    self.activ_memory -= x.untyped_storage().size()
            return x

        return _unpack_hook

    @contextmanager
    def record_activation(self):
        """Context manager: within it, tensors saved for backward are tracked."""
        with torch.autograd.graph.saved_tensors_hooks(
            self.pack_hook(), self.unpack_hook()
        ):
            yield

    @staticmethod
    def get_weight_memory(model: torch.nn.Module):
        """Total bytes of `model`'s non-CPU parameters."""
        return sum(
            p.nelement() * p.element_size()
            for p in model.parameters()
            if p.device.type != "cpu"
        )

    @staticmethod
    def get_gradient_memory(model: torch.nn.Module):
        """Total bytes of non-CPU gradients currently materialized on `model`."""
        return sum(
            p.grad.nelement() * p.grad.element_size()
            for p in model.parameters()
            if p.grad is not None and p.grad.device.type != "cpu"
        )

    def _sum_activation_memory(self):
        # Running total maintained by the pack/unpack hooks.
        return self.activ_memory

    def get_optimizer_state_memory(self):
        """Bytes held by the optimizer: tensors referenced from `param_groups`
        plus all per-parameter state tensors, counting non-CPU tensors only.

        NOTE: only AdamW is recognized; any other (or missing) optimizer
        yields 0.
        """
        if not isinstance(self.optimizer, torch.optim.AdamW):
            return 0
        total = sum(
            p.nelement() * p.element_size()
            for pg in self.optimizer.param_groups
            for p in pg["params"]
            if torch.is_tensor(p) and p.device.type != "cpu"
        )
        for state in self.optimizer.state.values():
            total += sum(
                v.nelement() * v.element_size()
                for v in state.values()
                if torch.is_tensor(v) and v.device.type != "cpu"
            )
        return total

    def _get_memory_usage(self) -> Tuple[int, int, int, int]:
        """Return (weight, gradient, activation, optimizer-state) bytes."""
        return (
            self.get_weight_memory(self.model),
            self.get_gradient_memory(self.model),
            self._sum_activation_memory(),
            self.get_optimizer_state_memory(),
        )

    def get_memory_usage_in_gb(self) -> str:
        """Human-readable memory breakdown in GB."""
        w, g, a, opt = self._get_memory_usage()

        return (
            f"Total: {(w + g + a + opt) / NUM_BYTES_IN_GB:.3f} GB, "
            f"weight: {w / NUM_BYTES_IN_GB:.3f} GB, "
            f"gradient: {g / NUM_BYTES_IN_GB:.3f} GB, "
            f"activation: {a / NUM_BYTES_IN_GB:.3f} GB, "
            f"optimizer states: {opt / NUM_BYTES_IN_GB:.3f} GB"
        )

    def get_memory_usage_in_mb(self) -> str:
        """Human-readable memory breakdown in MB."""
        w, g, a, opt = self._get_memory_usage()

        return (
            f"Total: {(w + g + a + opt) / NUM_BYTES_IN_MB:.2f} MB, "
            f"weight: {w / NUM_BYTES_IN_MB:.2f} MB, "
            f"gradient: {g / NUM_BYTES_IN_MB:.2f} MB, "
            f"activation: {a / NUM_BYTES_IN_MB:.2f} MB, "
            f"optimizer states: {opt / NUM_BYTES_IN_MB:.2f} MB"
        )
|
||||
Reference in New Issue
Block a user