Files
2025-08-05 19:02:46 +08:00

364 lines
12 KiB
Python

# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
import logging
import os
import queue
import socket
import weakref
from dataclasses import dataclass
from pathlib import Path
from typing import Any, List, Optional, Sequence, Tuple
import torch.cuda.memory
import torch.cuda.nvtx
import torch.nn as nn
import torch.profiler
import torch.utils.hooks
logger = logging.getLogger(__name__)
def _normalize_tuple(x):
if not isinstance(x, tuple):
return (x,)
return x
class NsightProfiler:
    """Profiler that triggers start of NSight profiler.

    NOTE: you need to ensure that the script running this code actually is running with
    ``nsys profile`` and also has a flag ``--capture-range=cudaProfilerApi`` so the
    capturing is performed by this profiler during certain steps.

    While active, it also keeps the main profiler's module hooks installed (those
    hooks push/pop NVTX ranges named after the submodules).
    """

    def __init__(self, main_profiler: "_Profiler") -> None:
        self.main_profiler = main_profiler
        # TODO figure out if there is a way to know if nsys is launched at this point

    def __enter__(self):
        self.main_profiler._install_hooks()
        try:
            torch.cuda.profiler.start()
        except BaseException:
            # Undo the hook installation so a failed start does not leave the
            # hooks (and their reference count) behind.
            self.main_profiler._remove_hooks()
            raise

    def __exit__(self, exc_type, exc_val, exc_tb):
        try:
            torch.cuda.profiler.stop()
        finally:
            # Release the hooks even if stopping the CUDA profiler fails.
            self.main_profiler._remove_hooks()

    def step(self) -> None:
        # Nothing to do per step: capture runs from __enter__ until __exit__.
        pass
class PyTorchProfiler:
    """Profiler which relies on native Pytorch profiling. Current setting of the profiler
    captures traces, memory footprint and other info that could be read via TensorBoard.
    """

    ACTIVITIES = [
        torch.profiler.ProfilerActivity.CPU,
        torch.profiler.ProfilerActivity.CUDA,
    ]

    def __init__(self, main_profiler: "_Profiler") -> None:
        self.main_profiler = main_profiler
        # Traces of this profiling window go to a directory named after the
        # recorded activities and the main profiler's step at creation time.
        joined_activities = "_".join(activity.name for activity in self.ACTIVITIES)
        trace_dir = (
            main_profiler.output_dir
            / f"profile_{joined_activities}_{main_profiler.done_steps:06}"
        )
        self.hta = torch.profiler.profile(
            on_trace_ready=torch.profiler.tensorboard_trace_handler(
                dir_name=str(trace_dir),
                worker_name=main_profiler.worker_name,
                use_gzip=True,
            ),
            profile_memory=True,
            record_shapes=True,
            with_stack=True,
            activities=self.ACTIVITIES,
        )
        # Number of step() calls made while this profiler was active.
        self.done_steps = 0

    def __enter__(self):
        # Wait for all queued CUDA work before the trace starts.
        torch.cuda.synchronize()
        self.hta.__enter__()

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Wait for all queued CUDA work before the trace is closed.
        torch.cuda.synchronize()
        self.hta.__exit__(exc_type, exc_val, exc_tb)

    def step(self) -> None:
        self.hta.step()
        self.done_steps += 1
class PyTorchProfiler_CUDAOnly(PyTorchProfiler):
    """Variant of ``PyTorchProfiler`` that records CUDA activity only.

    The CPU side is not profiled, so its overhead is expected to be almost zero.
    """

    ACTIVITIES = [torch.profiler.ProfilerActivity.CUDA]
class MemSnapshotsProfiler:
    """Profiler that captures memory traces for allocation and deallocation of memory for
    tensors.

    On exit, the recorded CUDA allocator history is rendered with
    ``torch.cuda._memory_viz.trace_plot`` into an HTML file, and the file path is
    appended to the main profiler's ``summary``.
    """

    def __init__(self, main_profiler: "_Profiler") -> None:
        self.main_profiler = main_profiler
        # True only while memory-history recording started by __enter__ is on.
        self.enabled = False

    @property
    def _has_trace_plot(self) -> bool:
        # ``trace_plot`` is missing from some PyTorch versions. Probe with getattr
        # so a missing ``_memory_viz`` module reads as "unavailable" instead of
        # raising AttributeError.
        memory_viz = getattr(torch.cuda, "_memory_viz", None)
        return memory_viz is not None and hasattr(memory_viz, "trace_plot")

    def __enter__(self):
        if not self._has_trace_plot:
            return
        # TODO: This does not show the previous memory allocations
        # We could at least have a placeholder with how much
        # memory was allocated before
        torch.cuda.memory._record_memory_history(
            True,
            # keep 100,000 alloc/free events from before the snapshot
            trace_alloc_max_entries=100000,
            # record stack information for the trace events
            trace_alloc_record_context=True,
        )
        # Mark as enabled only once recording has actually started.
        self.enabled = True

    def __exit__(self, exc_type, exc_val, exc_tb):
        if not self._has_trace_plot:
            self.main_profiler.summary.append(
                ("MemTrace", "(not available with your Pytorch version)")
            )
            return
        assert self.enabled
        try:
            snapshot = torch.cuda.memory._snapshot()
        finally:
            # Stop recording even if taking the snapshot failed, so the
            # allocator does not keep collecting history after this profiler.
            torch.cuda.memory._record_memory_history(False)
            self.enabled = False
        # No data was recorded - avoids a `ValueError` in `trace_plot`
        if all(len(t) == 0 for t in snapshot["device_traces"]):
            self.main_profiler.summary.append(("MemTrace", "(no allocation recorded)"))
            return
        # Render before opening the file so a rendering failure leaves no
        # empty/partial file behind.
        html = torch.cuda._memory_viz.trace_plot(
            snapshot, device=None, plot_segments=False
        )
        # Dump to disk
        filename = self.main_profiler._create_output_filename("memory_trace_plot.html")
        self.main_profiler.summary.append(("MemTrace", str(filename)))
        with open(filename, "w", encoding="utf-8") as fd:
            fd.write(html)

    def step(self) -> None:
        pass
@dataclass
class _ProfilerState:
cls: Any
iter_begin: int
iter_end: int
object: Any = None
class _Profiler:
    """Drives a schedule of sub-profilers over successive steps.

    ``schedule`` is a sequence of ``(profiler_cls, begin, end)`` triples. Each
    ``profiler_cls`` is instantiated with this object and kept active for the
    steps in ``[begin, end)``. Call :meth:`step` once per step. Only one
    ``_Profiler`` may be active (entered) at a time.
    """

    # The profiler currently between __enter__ and __exit__, if any.
    _CURRENT_PROFILER = None

    def __init__(
        self,
        output_dir: str,
        schedule: Sequence[Tuple[Any, int, int]],
        module: Optional[nn.Module],
    ) -> None:
        self.check_schedule(schedule)
        self.done_steps = 0
        self.output_dir = Path(output_dir).absolute()
        self.output_dir.mkdir(exist_ok=True, parents=True)
        # In distributed runs each worker tags its output with host name and PID.
        self.worker_name = ""
        if torch.distributed.is_initialized():
            self.worker_name = "{}_{}".format(socket.gethostname(), str(os.getpid()))
        # Weak reference so the profiler does not keep the model alive. With no
        # module, the placeholder ``nn.Module()`` has no other reference and is
        # normally freed right away (CPython refcounting), so ``self.module()``
        # returns None and _install_hooks becomes a no-op.
        self.module = weakref.ref(module if module is not None else nn.Module())
        # Stack of module names currently being executed.
        self.parents = ["Global"]
        self.hooks: List[torch.utils.hooks.RemovableHandle] = []
        # Number of active users of the hooks (_install_hooks / _remove_hooks).
        self.hooks_refcount = 0
        self.profilers: List[_ProfilerState] = sorted(
            [_ProfilerState(cls, begin, end) for cls, begin, end in schedule],
            key=lambda x: x.iter_begin,
        )
        self.last_step = self.profilers[-1].iter_end if self.profilers else 0
        self.summary: List[Tuple[str, str]] = []

    def check_schedule(self, schedule: Sequence[Tuple[Any, int, int]]) -> None:
        """Validate ``schedule``, raising AssertionError on invalid entries.

        Every entry must satisfy ``0 <= begin < end``, and the half-open ranges
        ``[begin, end)`` must not overlap (touching ranges are allowed). An
        empty schedule only logs a warning.
        """
        if len(schedule) == 0:
            logger.warning(
                "You specified empty schedule for profiling. No data will be captured."
            )
        ranges = []
        for _cls, begin, end in schedule:
            assert (
                begin >= 0
            ), f"Begin step of profiler must be non-negative, found: {begin}"
            assert end > 0, f"End step of profiler must be positive, found: {end}"
            assert (
                begin < end
            ), f"Start must be before the end, found: begin={begin} and end={end}"
            ranges.append((begin, end))
        # Ranges must be visited in ascending order so each one only needs to be
        # compared with its predecessor.
        prev_end = -1
        for begin, end in sorted(ranges):
            assert begin >= prev_end, (
                "There is some overlapping in profiler scheduling. Please do not"
                + " overlap profilers by step as they may affect each other. Schedule:"
                + f" {schedule}"
            )
            prev_end = end

    def update_profilers_on_step(self) -> None:
        """Start, step, or shut down each sub-profiler for ``self.done_steps``."""
        for p in self.profilers:
            if p.iter_begin <= self.done_steps < p.iter_end:
                if p.object is None:
                    o = p.cls(self)
                    logger.info("Starting %s profiler...", p.cls.__name__)
                    o.__enter__()
                    p.object = o
                else:
                    p.object.step()
            elif p.object is not None:
                o = p.object
                # Clear before exiting so a failing __exit__ is not retried.
                p.object = None
                logger.info("Shutting down %s profiler...", p.cls.__name__)
                o.__exit__(None, None, None)

    def _create_output_filename(self, filename: str) -> Path:
        """
        Returns where to write a file with desired filename.
        Handles the case where we are in distributed settings, or when
        we need to output the same file multiple times (eg if a profiler
        runs for several steps)
        """
        if self.worker_name != "":
            file = Path(filename)
            folder = self.output_dir / file.stem
            folder.mkdir(parents=True, exist_ok=True)
            return folder / f"{self.done_steps:06}_{self.worker_name}{file.suffix}"
        return self.output_dir / f"{self.done_steps:06}_{filename}"

    def _install_hooks(self) -> None:
        """Register enter/exit hooks on every submodule (reference-counted)."""
        self.hooks_refcount += 1
        # Already installed
        if self.hooks:
            return
        module = self.module()
        if module is None:
            return
        for name, sub_mod in module.named_modules():
            if name == "":
                # Skip the root module itself.
                continue
            # Only the last component of the dotted path is used as the name.
            name = name.split(".")[-1]
            self.hooks += [
                sub_mod.register_forward_pre_hook(self._enter_module_hook(name)),
                sub_mod.register_forward_hook(self._exit_module_hook(name)),
            ]

    def _remove_hooks(self) -> None:
        """Drop one reference to the hooks; remove them once none remain."""
        self.hooks_refcount -= 1
        if self.hooks_refcount == 0:
            for h in self.hooks:
                h.remove()
            # Forget the removed handles so a later _install_hooks() registers
            # fresh hooks instead of returning early on a stale list.
            self.hooks = []

    def _enter_module_hook(self, name):
        """Build a forward pre-hook that enters module ``name``.

        Inputs are routed through ``PopState``, whose backward exits ``name``
        when gradients flow back through the module's inputs.
        """

        class PopState(torch.autograd.Function):
            @staticmethod
            def forward(ctx, *args):
                # Identity; a single argument is returned unwrapped.
                if len(args) == 1:
                    return args[0]
                return args

            @staticmethod
            def backward(ctx, *grad_outs):
                self._exit_module(name)
                return grad_outs

        def f(module, inputs):
            self._enter_module(name)
            inputs = _normalize_tuple(inputs)
            out = PopState.apply(*inputs)
            return out

        return f

    def _exit_module_hook(self, name):
        """Build a forward hook that exits module ``name``.

        Outputs are routed through ``PushState``, whose backward re-enters
        ``name`` when gradients flow back into the module's outputs.
        """

        class PushState(torch.autograd.Function):
            @staticmethod
            def forward(ctx, *args):
                # Identity; a single argument is returned unwrapped.
                if len(args) == 1:
                    return args[0]
                return args

            @staticmethod
            def backward(ctx, *grad_outs):
                self._enter_module(name)
                return grad_outs

        def f(module, inputs, outputs):
            self._exit_module(name)
            if not isinstance(outputs, tuple):
                return PushState.apply(outputs)
            out = PushState.apply(*outputs)
            # PushState unwraps a single argument; re-wrap so a module that
            # returns a 1-tuple keeps returning a tuple.
            return out if isinstance(out, tuple) else (out,)

        return f

    def _enter_module(self, name) -> None:
        """Push ``name`` on the parents stack and open an NVTX range."""
        self.parents.append(name)
        torch.cuda.nvtx.range_push(name)

    def _exit_module(self, name) -> None:
        """Close the innermost NVTX range and pop ``name`` from the parents stack."""
        torch.cuda.nvtx.range_pop()
        assert self.parents[-1] == name
        self.parents.pop()

    def start(self):
        """Non-context-manager equivalent of ``__enter__``."""
        self.__enter__()

    def stop(self, exc_type=None, exc_val=None, exc_tb=None):
        """Non-context-manager equivalent of ``__exit__``."""
        self.__exit__(exc_type, exc_val, exc_tb)

    def __enter__(self):
        if _Profiler._CURRENT_PROFILER is not None:
            raise ValueError("Only one xformers profiler can be active at a time")
        _Profiler._CURRENT_PROFILER = self
        self.update_profilers_on_step()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        _Profiler._CURRENT_PROFILER = None
        for p in self.profilers:
            if p.object is not None:
                o = p.object
                # Clear first (as update_profilers_on_step does) so stopping
                # twice does not exit a sub-profiler twice.
                p.object = None
                o.__exit__(exc_type, exc_val, exc_tb)

    def step(self) -> None:
        """Signals the profiler that the next profiling step has started."""
        self.done_steps += 1
        if self.done_steps <= self.last_step:
            self.parents = ["Global"]
            self.update_profilers_on_step()
        if self.done_steps == self.last_step:
            logger.info("xFormers profiler done. %s", self.format_summary())

    def format_summary(self) -> str:
        """Return ``self.summary`` as aligned ``title: value`` lines, or "" if empty."""
        if len(self.summary) == 0:
            return ""
        pad_titles = max(len(title) for title, value in self.summary)
        return "summary:\n" + "\n".join(
            [f" {title.ljust(pad_titles)}: {value}" for title, value in self.summary]
        )