# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. # # This source code is licensed under the BSD license found in the # LICENSE file in the root directory of this source tree. import logging import os import queue import socket import weakref from dataclasses import dataclass from pathlib import Path from typing import Any, List, Optional, Sequence, Tuple import torch.cuda.memory import torch.cuda.nvtx import torch.nn as nn import torch.profiler import torch.utils.hooks logger = logging.getLogger(__name__) def _normalize_tuple(x): if not isinstance(x, tuple): return (x,) return x class NsightProfiler: """Profiler that triggers start of NSight profiler. NOTE: you need to ensure that the script running this code actually is running with ``nsys profile`` and also has a flag ``--capture-range=cudaProfilerApi`` so the capturing is performed by this profiler during certain steps. """ def __init__(self, main_profiler: "_Profiler") -> None: self.main_profiler = main_profiler # TODO figure out if there is a way to know if nsys is launched at this point def __enter__(self): self.main_profiler._install_hooks() torch.cuda.profiler.start() def __exit__(self, exc_type, exc_val, exc_tb): torch.cuda.profiler.stop() self.main_profiler._remove_hooks() def step(self) -> None: pass class PyTorchProfiler: """Profiler which relies on native Pytorch profiling. Current setting of the profiler captures traces, memory footprint and other info that could be read via TensorBoard. """ ACTIVITIES = [ torch.profiler.ProfilerActivity.CPU, torch.profiler.ProfilerActivity.CUDA, ] def __init__(self, main_profiler: "_Profiler") -> None: self.main_profiler = main_profiler activities_str = "_".join(a.name for a in self.ACTIVITIES) trace_handler = torch.profiler.tensorboard_trace_handler( dir_name=str( main_profiler.output_dir / f"profile_{activities_str}_{main_profiler.done_steps:06}" ), worker_name=main_profiler.worker_name, use_gzip=True, ) self.hta = torch.profiler.profile( on_trace_ready=trace_handler, profile_memory=True, record_shapes=True, with_stack=True, activities=self.ACTIVITIES, ) self.done_steps = 0 def __enter__(self): torch.cuda.synchronize() self.hta.__enter__() def __exit__(self, exc_type, exc_val, exc_tb): torch.cuda.synchronize() self.hta.__exit__(exc_type, exc_val, exc_tb) def step(self) -> None: self.hta.step() self.done_steps += 1 class PyTorchProfiler_CUDAOnly(PyTorchProfiler): # This profiler does not profile the CPU-side of things # so we expect it to have almost no overhead ACTIVITIES = [torch.profiler.ProfilerActivity.CUDA] class MemSnapshotsProfiler: """Profiler that captures memory traces for allocation and deallocation of memory for tensors. """ def __init__(self, main_profiler: "_Profiler") -> None: self.main_profiler = main_profiler self.enabled = False @property def _has_trace_plot(self) -> bool: return hasattr(torch.cuda._memory_viz, "trace_plot") def __enter__(self): if not self._has_trace_plot: return self.enabled = True # TODO: This does not show the previous memory allocations # We could at least have a placeholder with how much # memory was allocated before torch.cuda.memory._record_memory_history( True, # keep 100,000 alloc/free events from before the snapshot trace_alloc_max_entries=100000, # record stack information for the trace events trace_alloc_record_context=True, ) def __exit__(self, exc_type, exc_val, exc_tb): if not self._has_trace_plot: self.main_profiler.summary.append( ("MemTrace", "(not available with your Pytorch version)") ) return assert self.enabled snapshot = torch.cuda.memory._snapshot() torch.cuda.memory._record_memory_history(False) # No data was recorded - avoids a `ValueError` in `trace_plot` if all(len(t) == 0 for t in snapshot["device_traces"]): self.main_profiler.summary.append(("MemTrace", "(no allocation recorded)")) return # Dump to disk filename = self.main_profiler._create_output_filename("memory_trace_plot.html") self.main_profiler.summary.append(("MemTrace", filename)) with open(filename, "w+") as fd: fd.write( torch.cuda._memory_viz.trace_plot( snapshot, device=None, plot_segments=False ) ) def step(self) -> None: pass @dataclass class _ProfilerState: cls: Any iter_begin: int iter_end: int object: Any = None class _Profiler: _CURRENT_PROFILER = None def __init__( self, output_dir: str, schedule: Sequence[Tuple[Any, int, int]], module: Optional[nn.Module], ) -> None: self.check_schedule(schedule) self.done_steps = 0 self.output_dir = Path(output_dir).absolute() self.output_dir.mkdir(exist_ok=True, parents=True) self.worker_name = "" if torch.distributed.is_initialized(): self.worker_name = "{}_{}".format(socket.gethostname(), str(os.getpid())) self.module = weakref.ref(module if module is not None else nn.Module()) self.parents = ["Global"] self.hooks: List[torch.utils.hooks.RemovableHandle] = [] self.hooks_refcount = 0 self.profilers: List[_ProfilerState] = sorted( [_ProfilerState(cls, begin, end) for cls, begin, end in schedule], key=lambda x: x.iter_begin, ) self.last_step = self.profilers[-1].iter_end if self.profilers else 0 self.summary: List[Tuple[str, str]] = [] def check_schedule(self, schedule: Sequence[Tuple[Any, int, int]]) -> None: if len(schedule) == 0: logger.warning( "You specified empty schedule for profiling. No data will be captured." ) pq: Any = queue.PriorityQueue() for cls, begin, end in schedule: assert ( begin >= 0 ), f"Begin step of profiler must be non-negative, found: {begin}" assert end > 0, f"End step of profiler must be positive, found: {end}" assert ( begin < end ), f"Start must be before the end, found: begin={begin} and end={end}" pq.put((begin, end)) prev_end = -1 for begin, end in pq.queue: assert begin >= prev_end, ( "There is some overlapping in profiler scheduling. Please do not" + " overlap profilers by step as they may affect each other. Schedule:" + f" {schedule}" ) prev_end = end def update_profilers_on_step(self) -> None: for p in self.profilers: if p.iter_begin <= self.done_steps and self.done_steps < p.iter_end: if p.object is None: o = p.cls(self) logging.info(f"Starting {p.cls.__name__} profiler...") o.__enter__() p.object = o else: p.object.step() else: if p.object is not None: o = p.object p.object = None logging.info(f"Shutting down {p.cls.__name__} profiler...") o.__exit__(None, None, None) def _create_output_filename(self, filename: str) -> Path: """ Returns where to write a file with desired filename. Handles the case where we are in distributed settings, or when we need to output the same file multiple times (eg if a profiler runs for several steps) """ if self.worker_name != "": file = Path(filename) folder = self.output_dir / file.stem folder.mkdir(parents=True, exist_ok=True) return folder / f"{self.done_steps:06}_{self.worker_name}{file.suffix}" return self.output_dir / f"{self.done_steps:06}_{filename}" def _install_hooks(self) -> None: self.hooks_refcount += 1 # Already installed if self.hooks: return module = self.module() if module is None: return for name, sub_mod in module.named_modules(): if name == "": continue name = name.split(".")[-1] self.hooks += [ sub_mod.register_forward_pre_hook(self._enter_module_hook(name)), sub_mod.register_forward_hook(self._exit_module_hook(name)), ] def _remove_hooks(self) -> None: self.hooks_refcount -= 1 if self.hooks_refcount == 0: for h in self.hooks: h.remove() def _enter_module_hook(self, name): class PopState(torch.autograd.Function): @staticmethod def forward(ctx, *args): if len(args) == 1: return args[0] return args @staticmethod def backward(ctx, *grad_outs): self._exit_module(name) return grad_outs def f(module, inputs): self._enter_module(name) inputs = _normalize_tuple(inputs) out = PopState.apply(*inputs) return out return f def _exit_module_hook(self, name): class PushState(torch.autograd.Function): @staticmethod def forward(ctx, *args): if len(args) == 1: return args[0] return args @staticmethod def backward(ctx, *grad_outs): self._enter_module(name) return grad_outs def f(module, inputs, outputs): self._exit_module(name) outputs = _normalize_tuple(outputs) return PushState.apply(*outputs) return f def _enter_module(self, name) -> None: self.parents.append(name) torch.cuda.nvtx.range_push(name) def _exit_module(self, name) -> None: torch.cuda.nvtx.range_pop() assert self.parents[-1] == name self.parents.pop() def start(self): self.__enter__() def stop(self, exc_type=None, exc_val=None, exc_tb=None): self.__exit__(exc_type, exc_val, exc_tb) def __enter__(self): if _Profiler._CURRENT_PROFILER is not None: raise ValueError("Only one xformers profiler can be active at a time") _Profiler._CURRENT_PROFILER = self self.update_profilers_on_step() return self def __exit__(self, exc_type, exc_val, exc_tb): _Profiler._CURRENT_PROFILER = None for p in self.profilers: if p.object is not None: p.object.__exit__(exc_type, exc_val, exc_tb) def step(self) -> None: """Signals the profiler that the next profiling step has started.""" self.done_steps += 1 if self.done_steps <= self.last_step: self.parents = ["Global"] self.update_profilers_on_step() if self.done_steps == self.last_step: logger.info("xFormers profiler done. %s", self.format_summary()) def format_summary(self) -> str: if len(self.summary) == 0: return "" pad_titles = max(len(title) for title, value in self.summary) return "summary:\n" + "\n".join( [f" {title.ljust(pad_titles)}: {value}" for title, value in self.summary] )