from contextlib import contextmanager
from typing import Optional, Tuple

import torch

NUM_BYTES_IN_MB = 1024**2
NUM_BYTES_IN_GB = 1024**3


class MemoryAnalyzer:
    def __init__(
        self,
        model: torch.nn.Module,
        optimizer: Optional[torch.optim.Optimizer] = None,
    ):
        """Approximate memory usage analyzer for model training.

        The report is mostly accurate only if you construct the analyzer
        before training starts and call ``get_memory_usage_in_gb`` at the end
        of each forward pass.

        NOTE: Misuse can skew the numbers: the analyzer records the storage
        address of every tensor saved for backward, and it relies on the user
        to reset it (clear ``activ_addrs`` and zero ``activ_memory``) whenever
        a forward pass ends without a matching backward pass.

        Limitations:
        1. Does not work with custom operators.
        2. Does not work with functional operators.
        3. Activations are approximated as the tensors autograd saves for
           backward (i.e. tensors inside the graph that require gradients),
           so the result may not be exact.
        """
        self.model = model
        self.optimizer = optimizer
        self.activ_addrs = set()
        self.activ_memory = 0

    @staticmethod
    def _is_activation(x):
        # Compare the device *type*; comparing a torch.device object against
        # the string "cpu" is unreliable across PyTorch versions.
        return torch.is_tensor(x) and x.requires_grad and x.device.type != "cpu"

    def _get_weight_grads_addrs(self):
        weights = {p.untyped_storage().data_ptr() for p in self.model.parameters()}
        grads = {
            p.grad.untyped_storage().data_ptr()
            for p in self.model.parameters()
            if p.grad is not None
        }
        return weights.union(grads)

    def pack_hook(self):
        def _pack_hook(x):
            if self._is_activation(x):
                weight_grads = self._get_weight_grads_addrs()
                # NOTE: storage size is more accurate than
                # x.nelement() * x.element_size(), since views share storage.
                data_ptr = x.untyped_storage().data_ptr()
                if data_ptr not in weight_grads and data_ptr not in self.activ_addrs:
                    self.activ_addrs.add(data_ptr)
                    self.activ_memory += x.untyped_storage().size()
            return x

        return _pack_hook

    def unpack_hook(self):
        def _unpack_hook(x):
            if self._is_activation(x):
                weight_grads = self._get_weight_grads_addrs()
                data_ptr = x.untyped_storage().data_ptr()
                if data_ptr not in weight_grads and data_ptr in self.activ_addrs:
                    self.activ_addrs.remove(data_ptr)
                    self.activ_memory -= x.untyped_storage().size()
            return x

        return _unpack_hook

    @contextmanager
    def record_activation(self):
        with torch.autograd.graph.saved_tensors_hooks(
            self.pack_hook(), self.unpack_hook()
        ):
            yield

    @staticmethod
    def get_weight_memory(model: torch.nn.Module):
        weights = [
            p.nelement() * p.element_size()
            for p in model.parameters()
            if p.device.type != "cpu"
        ]
        return sum(weights)

    @staticmethod
    def get_gradient_memory(model: torch.nn.Module):
        grads = [
            p.grad.nelement() * p.grad.element_size()
            for p in model.parameters()
            if p.grad is not None and p.grad.device.type != "cpu"
        ]
        return sum(grads)

    def _sum_activation_memory(self):
        return self.activ_memory

    def get_optimizer_state_memory(self):
        if isinstance(self.optimizer, torch.optim.AdamW):
            # The tensors in param_groups alias the model weights, which are
            # already counted by get_weight_memory, so only sum the optimizer
            # state tensors (exp_avg, exp_avg_sq, ...) to avoid double counting.
            state_bytes = 0
            for state in self.optimizer.state.values():
                state_bytes += sum(
                    v.nelement() * v.element_size()
                    for v in state.values()
                    if torch.is_tensor(v) and v.device.type != "cpu"
                )
            return state_bytes
        return 0

    def _get_memory_usage(self) -> Tuple[int, int, int, int]:
        return (
            self.get_weight_memory(self.model),
            self.get_gradient_memory(self.model),
            self._sum_activation_memory(),
            self.get_optimizer_state_memory(),
        )

    def get_memory_usage_in_gb(self) -> str:
        w, g, a, opt = self._get_memory_usage()
        return (
            f"Total: {(w + g + a + opt) / NUM_BYTES_IN_GB:.3f} GB, "
            f"weight: {w / NUM_BYTES_IN_GB:.3f} GB, "
            f"gradient: {g / NUM_BYTES_IN_GB:.3f} GB, "
            f"activation: {a / NUM_BYTES_IN_GB:.3f} GB, "
            f"optimizer states: {opt / NUM_BYTES_IN_GB:.3f} GB"
        )

    def get_memory_usage_in_mb(self) -> str:
        w, g, a, opt = self._get_memory_usage()
        return (
            f"Total: {(w + g + a + opt) / NUM_BYTES_IN_MB:.2f} MB, "
            f"weight: {w / NUM_BYTES_IN_MB:.2f} MB, "
            f"gradient: {g / NUM_BYTES_IN_MB:.2f} MB, "
            f"activation: {a / NUM_BYTES_IN_MB:.2f} MB, "
            f"optimizer states: {opt / NUM_BYTES_IN_MB:.2f} MB"
        )