First commit
This commit is contained in:
17
pkgs/xformers/profiler/__init__.py
Normal file
17
pkgs/xformers/profiler/__init__.py
Normal file
@@ -0,0 +1,17 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the BSD license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
from .api import profile, step
|
||||
from .profiler import MemSnapshotsProfiler, NsightProfiler, PyTorchProfiler
|
||||
from .slow_ops_profiler import DetectSlowOpsProfiler
|
||||
|
||||
__all__ = [
|
||||
"profile",
|
||||
"step",
|
||||
"MemSnapshotsProfiler",
|
||||
"PyTorchProfiler",
|
||||
"NsightProfiler",
|
||||
"DetectSlowOpsProfiler",
|
||||
]
|
||||
BIN
pkgs/xformers/profiler/__pycache__/__init__.cpython-310.pyc
Normal file
BIN
pkgs/xformers/profiler/__pycache__/__init__.cpython-310.pyc
Normal file
Binary file not shown.
BIN
pkgs/xformers/profiler/__pycache__/api.cpython-310.pyc
Normal file
BIN
pkgs/xformers/profiler/__pycache__/api.cpython-310.pyc
Normal file
Binary file not shown.
BIN
pkgs/xformers/profiler/__pycache__/device_limits.cpython-310.pyc
Normal file
BIN
pkgs/xformers/profiler/__pycache__/device_limits.cpython-310.pyc
Normal file
Binary file not shown.
BIN
pkgs/xformers/profiler/__pycache__/profiler.cpython-310.pyc
Normal file
BIN
pkgs/xformers/profiler/__pycache__/profiler.cpython-310.pyc
Normal file
Binary file not shown.
Binary file not shown.
94
pkgs/xformers/profiler/api.py
Normal file
94
pkgs/xformers/profiler/api.py
Normal file
@@ -0,0 +1,94 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the BSD license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
from typing import Any, Optional, Sequence, Tuple
|
||||
|
||||
import torch.nn as nn
|
||||
|
||||
from .profiler import (
|
||||
MemSnapshotsProfiler,
|
||||
NsightProfiler,
|
||||
PyTorchProfiler,
|
||||
PyTorchProfiler_CUDAOnly,
|
||||
_Profiler,
|
||||
)
|
||||
from .slow_ops_profiler import DetectSlowOpsProfiler # noqa: F401
|
||||
|
||||
# Schedule used by `profile` when the caller does not pass one.
# Every entry is a (profiler class, begin step, end step) triple.
DEFAULT_SCHEDULE = (
    (MemSnapshotsProfiler, 0, 2),
    (NsightProfiler, 4, 6),
    (PyTorchProfiler, 6, 7),
    (PyTorchProfiler_CUDAOnly, 7, 8),
    # TODO: There are some issues in PyTorch stable
    # which are now fixed on main, but might break this profiler
    # https://github.com/pytorch/pytorch/issues/94403
    # (DetectSlowOpsProfiler, 9, 10),
)
|
||||
|
||||
|
||||
def profile(
    output_dir: str,
    module: Optional[nn.Module] = None,
    schedule: Sequence[Tuple[Any, int, int]] = DEFAULT_SCHEDULE,
):
    """
    Create a pre-configured profiler that runs over the first steps of training.

    Several traces are produced and can be inspected afterwards.
    Wrap the training loop in the returned object (it is a context manager)
    and call `xformers.profiler.step` before starting each new iteration.

    Args:
        output_dir: directory handed to the underlying profiler for its outputs.
        module: optional module forwarded to the underlying profiler.
        schedule: sequence of ``(profiler class, begin step, end step)`` triples.

    Returns:
        The profiler object; use it with ``with`` or via ``.start()`` / ``.stop()``.

    :Examples:

    .. code-block:: python

        import torch
        import timm.models
        import xformers.profiler
        from xformers.profiler import (
            DetectSlowOpsProfiler,
            MemSnapshotsProfiler,
            NsightProfiler,
            PyTorchProfiler,
        )

        dtype = torch.bfloat16
        device = "cuda"
        model = timm.models.vit_large_patch16_224().to(device).to(dtype)
        inp = torch.zeros([64, 3, 224, 224], device=device, dtype=dtype)
        optim = torch.optim.Adam(model.parameters())

        with xformers.profiler.profile(
            output_dir="profile_data",
            module=model,
            schedule=[
                (MemSnapshotsProfiler, 0, 2),
                (DetectSlowOpsProfiler, 2, 4),
                (NsightProfiler, 4, 6),
                (PyTorchProfiler, 6, 20),
            ]
        ):
            for i in range(20):
                model(inp).sum().backward()
                optim.step()
                optim.zero_grad()
                xformers.profiler.step()

        # alternatively, use the profiler without a context manager, with
        # explicit ``.start()`` / ``.stop()`` calls.

        xprofiler = xformers.profiler.profile(...)
        xprofiler.start()

        for i in range(20):
            model(inp).sum().backward()
            optim.step()
            optim.zero_grad()
            xprofiler.step()

        xprofiler.stop()
    """
    profiler = _Profiler(output_dir=output_dir, schedule=schedule, module=module)
    return profiler
|
||||
|
||||
|
||||
def step() -> None:
    """Advance the active profiler by one step. See `xformers.profiler.profile`.

    Silently does nothing when no profiler is enabled.
    """
    active = _Profiler._CURRENT_PROFILER
    if active is not None:
        active.step()
|
||||
113
pkgs/xformers/profiler/device_limits.py
Normal file
113
pkgs/xformers/profiler/device_limits.py
Normal file
@@ -0,0 +1,113 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the BSD license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import math
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Mapping, Tuple
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
@dataclass
class DeviceLimit:
    """Peak hardware figures for one device model.

    The defaults describe an unknown device: infinite memory bandwidth and
    no per-dtype GEMM throughput entries.
    """

    # pattern to match from `torch.cuda.get_device_name()`
    name: str = "default"
    # where the figures below were taken from
    source: str = ""
    # compared against `torch.cuda.get_device_capability()`
    sm: Tuple[int, int] = (0, 0)
    # bytes/s
    gmem_bandwidth: float = math.inf
    # dtype -> TFlop/s
    gemm_tflops: Mapping[torch.dtype, float] = field(default_factory=dict)
|
||||
|
||||
|
||||
# Known device limits, matched by `get_device_limits` on both `sm` and `name`.
# For f32, we assume we can use tf32
DEVICE_LIMITS: Tuple[DeviceLimit, ...] = (
    DeviceLimit(
        "H100",
        "https://resources.nvidia.com/en-us-tensor-core/nvidia-tensor-core-gpu-datasheet",  # noqa: E501
        sm=(9, 0),
        gmem_bandwidth=3.35 * (1024**4),  # NOTE: PCIe is 2 TB/s
        gemm_tflops={
            torch.float64: 67,
            # NOTE: NVIDIA gives all numbers "with 2:4 sparsity"
            # but we want the full GEMM numbers.
            # True division is used so that odd figures are halved exactly
            # (989 / 2 == 494.5) instead of being truncated by `//`.
            torch.float32: 989 / 2,
            torch.float16: 1979 / 2,
            torch.bfloat16: 1979 / 2,
            torch.int8: 3958 / 2,
        },
    ),
    DeviceLimit(
        "A100",
        "https://www.nvidia.com/content/dam/en-zz/Solutions/Data-Center/a100/pdf/nvidia-a100-datasheet-us-nvidia-1758950-r4-web.pdf",  # noqa: E501
        sm=(8, 0),
        gmem_bandwidth=2 * (1024**4),  # NOTE: PCIe is 1.5 TB/s
        gemm_tflops={
            torch.float64: 19.5,
            torch.float32: 156,
            torch.float16: 312,
            torch.bfloat16: 312,
            torch.int8: 624,
        },
    ),
    DeviceLimit(
        "A30",
        "https://www.nvidia.com/content/dam/en-zz/Solutions/data-center/products/a30-gpu/pdf/a30-datasheet.pdf",
        sm=(8, 0),
        gmem_bandwidth=933 * (1024**3),
        gemm_tflops={
            torch.float64: 10.3,
            torch.float32: 82,
            torch.float16: 165,
            torch.bfloat16: 165,
            torch.int8: 330,
        },
    ),
    DeviceLimit(
        "T4",
        "https://www.nvidia.com/content/dam/en-zz/Solutions/Data-Center/tesla-t4/t4-tensor-core-datasheet-951643.pdf",
        sm=(7, 5),
        gmem_bandwidth=300 * (1024**3),
        gemm_tflops={
            torch.float32: 8.1,
            torch.float16: 65,
            torch.int8: 130,
        },
    ),
    # Assuming SXM2
    DeviceLimit(
        "V100",
        "https://images.nvidia.com/content/technologies/volta/pdf/tesla-volta-v100-datasheet-letter-fnl-web.pdf",
        sm=(7, 0),
        gmem_bandwidth=900 * (1024**3),
        gemm_tflops={
            torch.float64: 7.8,
            torch.float32: 15.7,
            torch.float16: 125,
        },
    ),
    DeviceLimit(
        "P100",
        "https://images.nvidia.com/content/tesla/pdf/nvidia-tesla-p100-datasheet.pdf",
        sm=(6, 0),
        gmem_bandwidth=732 * (1024**3),
        gemm_tflops={
            torch.float64: 5.3,
            torch.float32: 10.6,
            torch.float16: 21.2,
        },
    ),
)
|
||||
|
||||
|
||||
def get_device_limits(device) -> DeviceLimit:
    """Look up the known limits for ``device``. Currently only implemented for GPUs.

    Returns the first `DEVICE_LIMITS` entry whose ``sm`` equals the device
    capability and whose ``name`` is contained in the device name. A default
    `DeviceLimit()` is returned when ``device`` is None, is not a CUDA
    device, or matches no entry.
    """
    if device is None or device.type != "cuda":
        return DeviceLimit()

    capability = torch.cuda.get_device_capability(device)
    gpu_name = torch.cuda.get_device_name(device)
    for candidate in DEVICE_LIMITS:
        if candidate.sm == capability and candidate.name in gpu_name:
            return candidate
    return DeviceLimit()
|
||||
363
pkgs/xformers/profiler/profiler.py
Normal file
363
pkgs/xformers/profiler/profiler.py
Normal file
@@ -0,0 +1,363 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the BSD license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
|
||||
import logging
|
||||
import os
|
||||
import queue
|
||||
import socket
|
||||
import weakref
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Any, List, Optional, Sequence, Tuple
|
||||
|
||||
import torch.cuda.memory
|
||||
import torch.cuda.nvtx
|
||||
import torch.nn as nn
|
||||
import torch.profiler
|
||||
import torch.utils.hooks
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _normalize_tuple(x):
    """Return ``x`` unchanged if it is a tuple, otherwise wrap it in a 1-tuple."""
    return x if isinstance(x, tuple) else (x,)
|
||||
|
||||
|
||||
class NsightProfiler:
    """Profiler that starts and stops the NSight capture.

    NOTE: the script running this code must itself be launched under
    ``nsys profile`` with the ``--capture-range=cudaProfilerApi`` flag, so
    that capturing is driven by this profiler during the scheduled steps.
    """

    def __init__(self, main_profiler: "_Profiler") -> None:
        # TODO figure out if there is a way to know if nsys is launched at this point
        self.main_profiler = main_profiler

    def __enter__(self):
        # Hooks are installed on the owning profiler before the capture starts.
        owner = self.main_profiler
        owner._install_hooks()
        torch.cuda.profiler.start()

    def __exit__(self, exc_type, exc_val, exc_tb):
        torch.cuda.profiler.stop()
        owner = self.main_profiler
        owner._remove_hooks()

    def step(self) -> None:
        """Nothing to do between steps."""
        return None
|
||||
|
||||
|
||||
class PyTorchProfiler:
    """Profiler built on the native PyTorch profiler.

    With the current settings it captures traces, memory footprint and other
    information that can be read via TensorBoard.
    """

    ACTIVITIES = [
        torch.profiler.ProfilerActivity.CPU,
        torch.profiler.ProfilerActivity.CUDA,
    ]

    def __init__(self, main_profiler: "_Profiler") -> None:
        self.main_profiler = main_profiler

        # The trace folder name encodes the profiled activities and the step
        # at which this profiler was created.
        activity_tag = "_".join(activity.name for activity in self.ACTIVITIES)
        trace_dir = (
            main_profiler.output_dir
            / f"profile_{activity_tag}_{main_profiler.done_steps:06}"
        )
        on_ready = torch.profiler.tensorboard_trace_handler(
            dir_name=str(trace_dir),
            worker_name=main_profiler.worker_name,
            use_gzip=True,
        )
        self.hta = torch.profiler.profile(
            on_trace_ready=on_ready,
            profile_memory=True,
            record_shapes=True,
            with_stack=True,
            activities=self.ACTIVITIES,
        )
        self.done_steps = 0

    def __enter__(self):
        # Synchronize first, then enter the wrapped PyTorch profiler.
        torch.cuda.synchronize()
        self.hta.__enter__()

    def __exit__(self, exc_type, exc_val, exc_tb):
        torch.cuda.synchronize()
        self.hta.__exit__(exc_type, exc_val, exc_tb)

    def step(self) -> None:
        """Forward the step to the wrapped profiler and count it."""
        self.hta.step()
        self.done_steps += 1
|
||||
|
||||
|
||||
class PyTorchProfiler_CUDAOnly(PyTorchProfiler):
    """Variant of `PyTorchProfiler` that records CUDA activity only.

    This profiler does not profile the CPU-side of things, so we expect it
    to have almost no overhead.
    """

    ACTIVITIES = [torch.profiler.ProfilerActivity.CUDA]
|
||||
|
||||
|
||||
class MemSnapshotsProfiler:
    """Profiler that captures memory traces for allocation and deallocation of
    memory for tensors.

    On exit the recorded snapshot is rendered with
    ``torch.cuda._memory_viz.trace_plot`` and written as an HTML file; an entry
    describing the outcome is appended to the main profiler's ``summary``.
    """

    def __init__(self, main_profiler: "_Profiler") -> None:
        self.main_profiler = main_profiler
        # Set to True once recording has actually been started in `__enter__`.
        self.enabled = False

    @property
    def _has_trace_plot(self) -> bool:
        """Whether this PyTorch build exposes ``torch.cuda._memory_viz.trace_plot``."""
        # `getattr` with a default: a missing `_memory_viz` attribute must make
        # this return False rather than raise AttributeError.
        memory_viz = getattr(torch.cuda, "_memory_viz", None)
        return hasattr(memory_viz, "trace_plot")

    def __enter__(self):
        if not self._has_trace_plot:
            return
        self.enabled = True
        # TODO: This does not show the previous memory allocations
        # We could at least have a placeholder with how much
        # memory was allocated before
        torch.cuda.memory._record_memory_history(
            True,
            # keep 100,000 alloc/free events from before the snapshot
            trace_alloc_max_entries=100000,
            # record stack information for the trace events
            trace_alloc_record_context=True,
        )

    def __exit__(self, exc_type, exc_val, exc_tb):
        if not self._has_trace_plot:
            self.main_profiler.summary.append(
                ("MemTrace", "(not available with your Pytorch version)")
            )
            return
        assert self.enabled
        snapshot = torch.cuda.memory._snapshot()
        torch.cuda.memory._record_memory_history(False)
        # No data was recorded - avoids a `ValueError` in `trace_plot`
        if all(len(t) == 0 for t in snapshot["device_traces"]):
            self.main_profiler.summary.append(("MemTrace", "(no allocation recorded)"))
            return
        # Dump to disk
        filename = self.main_profiler._create_output_filename("memory_trace_plot.html")
        self.main_profiler.summary.append(("MemTrace", filename))
        with open(filename, "w+") as fd:
            fd.write(
                torch.cuda._memory_viz.trace_plot(
                    snapshot, device=None, plot_segments=False
                )
            )

    def step(self) -> None:
        """Nothing to do between steps."""
        pass
|
||||
|
||||
|
||||
@dataclass
class _ProfilerState:
    """One schedule entry together with its live profiler instance, if any."""

    # Profiler class; instantiated as `cls(main_profiler)` when its window opens
    cls: Any
    # First step at which the profiler is active
    iter_begin: int
    # Step at which the profiler is shut down (exclusive bound)
    iter_end: int
    # Running instance, or None while the profiler is not active
    object: Any = None
|
||||
|
||||
|
||||
class _Profiler:
    """Drives a schedule of sub-profilers over the steps of a loop.

    Each schedule entry is a ``(profiler class, begin, end)`` triple. The
    sub-profiler is created and entered once ``done_steps`` reaches ``begin``
    and exited once ``done_steps`` is no longer below ``end``. Only one
    `_Profiler` may be entered at a time; it is published in
    ``_CURRENT_PROFILER`` while active.
    """

    # The profiler that is currently entered, or None.
    _CURRENT_PROFILER = None

    def __init__(
        self,
        output_dir: str,
        schedule: Sequence[Tuple[Any, int, int]],
        module: Optional[nn.Module],
    ) -> None:
        """Validate ``schedule``, create ``output_dir`` and set up the state.

        Args:
            output_dir: directory for output files (created if missing).
            schedule: ``(profiler class, begin, end)`` triples.
            module: module whose sub-modules receive enter/exit hooks, or None.
        """
        self.check_schedule(schedule)
        self.done_steps = 0
        self.output_dir = Path(output_dir).absolute()
        self.output_dir.mkdir(exist_ok=True, parents=True)
        self.worker_name = ""
        # `is_initialized` is only safe to call when the distributed package
        # is available in this PyTorch build.
        if torch.distributed.is_available() and torch.distributed.is_initialized():
            self.worker_name = "{}_{}".format(socket.gethostname(), str(os.getpid()))

        # Only a weak reference is kept. When no module is given, nothing else
        # holds the placeholder module, and `_install_hooks` copes with a dead
        # reference by installing nothing.
        self.module = weakref.ref(module if module is not None else nn.Module())
        self.parents = ["Global"]
        self.hooks: List[torch.utils.hooks.RemovableHandle] = []
        self.hooks_refcount = 0
        self.profilers: List[_ProfilerState] = sorted(
            [_ProfilerState(cls, begin, end) for cls, begin, end in schedule],
            key=lambda x: x.iter_begin,
        )
        self.last_step = self.profilers[-1].iter_end if self.profilers else 0
        self.summary: List[Tuple[str, str]] = []

    def check_schedule(self, schedule: Sequence[Tuple[Any, int, int]]) -> None:
        """Check that every window is well-formed and that no two overlap.

        Raises:
            AssertionError: on a negative begin, a non-positive end, an empty
                window, or two windows that overlap.
        """
        if len(schedule) == 0:
            logger.warning(
                "You specified empty schedule for profiling. No data will be captured."
            )

        windows: List[Tuple[int, int]] = []
        for cls, begin, end in schedule:
            assert (
                begin >= 0
            ), f"Begin step of profiler must be non-negative, found: {begin}"
            assert end > 0, f"End step of profiler must be positive, found: {end}"
            assert (
                begin < end
            ), f"Start must be before the end, found: begin={begin} and end={end}"
            windows.append((begin, end))

        # The windows must be visited in ascending order for the overlap check
        # to be meaningful, whatever order the caller listed them in.
        prev_end = -1
        for begin, end in sorted(windows):
            assert begin >= prev_end, (
                "There is some overlapping in profiler scheduling. Please do not"
                + " overlap profilers by step as they may affect each other. Schedule:"
                + f" {schedule}"
            )
            prev_end = end

    def update_profilers_on_step(self) -> None:
        """Start, step or shut down each sub-profiler for the current step."""
        for p in self.profilers:
            if p.iter_begin <= self.done_steps and self.done_steps < p.iter_end:
                if p.object is None:
                    o = p.cls(self)
                    logger.info("Starting %s profiler...", p.cls.__name__)
                    o.__enter__()
                    p.object = o
                else:
                    p.object.step()
            else:
                if p.object is not None:
                    o = p.object
                    # Clear the slot first so the profiler is exited only once.
                    p.object = None
                    logger.info("Shutting down %s profiler...", p.cls.__name__)
                    o.__exit__(None, None, None)

    def _create_output_filename(self, filename: str) -> Path:
        """
        Returns where to write a file with desired filename.
        Handles the case where we are in distributed settings, or when
        we need to output the same file multiple times (eg if a profiler
        runs for several steps)
        """
        if self.worker_name != "":
            file = Path(filename)
            folder = self.output_dir / file.stem
            folder.mkdir(parents=True, exist_ok=True)
            return folder / f"{self.done_steps:06}_{self.worker_name}{file.suffix}"
        return self.output_dir / f"{self.done_steps:06}_{filename}"

    def _install_hooks(self) -> None:
        """Register enter/exit hooks on every sub-module (reference counted)."""
        self.hooks_refcount += 1
        # Already installed
        if self.hooks:
            return
        module = self.module()
        if module is None:
            return
        for name, sub_mod in module.named_modules():
            if name == "":
                continue
            # Only the last component of the dotted path is used as label
            name = name.split(".")[-1]
            self.hooks += [
                sub_mod.register_forward_pre_hook(self._enter_module_hook(name)),
                sub_mod.register_forward_hook(self._exit_module_hook(name)),
            ]

    def _remove_hooks(self) -> None:
        """Drop one reference to the hooks and remove them when none is left."""
        self.hooks_refcount -= 1
        if self.hooks_refcount == 0:
            for h in self.hooks:
                h.remove()
            # Forget the removed handles: `_install_hooks` treats a non-empty
            # list as "already installed" and would never register hooks again.
            self.hooks.clear()

    def _enter_module_hook(self, name):
        """Build the forward pre-hook that marks entering sub-module ``name``."""

        class PopState(torch.autograd.Function):
            # Identity in forward; in backward it closes the range opened by
            # the matching `PushState`.
            @staticmethod
            def forward(ctx, *args):
                if len(args) == 1:
                    return args[0]
                return args

            @staticmethod
            def backward(ctx, *grad_outs):
                self._exit_module(name)
                return grad_outs

        def f(module, inputs):
            self._enter_module(name)
            inputs = _normalize_tuple(inputs)
            out = PopState.apply(*inputs)
            return out

        return f

    def _exit_module_hook(self, name):
        """Build the forward hook that marks leaving sub-module ``name``."""

        class PushState(torch.autograd.Function):
            # Identity in forward; in backward it re-opens the range for
            # ``name``, which `PopState` closes again.
            @staticmethod
            def forward(ctx, *args):
                if len(args) == 1:
                    return args[0]
                return args

            @staticmethod
            def backward(ctx, *grad_outs):
                self._enter_module(name)
                return grad_outs

        def f(module, inputs, outputs):
            self._exit_module(name)
            outputs = _normalize_tuple(outputs)
            return PushState.apply(*outputs)

        return f

    def _enter_module(self, name) -> None:
        """Push ``name`` on the parent stack and open an NVTX range for it."""
        self.parents.append(name)
        torch.cuda.nvtx.range_push(name)

    def _exit_module(self, name) -> None:
        """Close the innermost NVTX range and pop ``name`` from the stack."""
        torch.cuda.nvtx.range_pop()
        assert self.parents[-1] == name
        self.parents.pop()

    def start(self):
        """Enter the profiler without a ``with`` statement."""
        self.__enter__()

    def stop(self, exc_type=None, exc_val=None, exc_tb=None):
        """Exit the profiler without a ``with`` statement."""
        self.__exit__(exc_type, exc_val, exc_tb)

    def __enter__(self):
        if _Profiler._CURRENT_PROFILER is not None:
            raise ValueError("Only one xformers profiler can be active at a time")
        _Profiler._CURRENT_PROFILER = self
        self.update_profilers_on_step()

        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        _Profiler._CURRENT_PROFILER = None

        # Shut down whatever sub-profiler is still running
        for p in self.profilers:
            if p.object is not None:
                p.object.__exit__(exc_type, exc_val, exc_tb)

    def step(self) -> None:
        """Signals the profiler that the next profiling step has started."""
        self.done_steps += 1

        if self.done_steps <= self.last_step:
            self.parents = ["Global"]
            self.update_profilers_on_step()
        if self.done_steps == self.last_step:
            logger.info("xFormers profiler done. %s", self.format_summary())

    def format_summary(self) -> str:
        """Render ``self.summary`` as aligned ``title: value`` lines."""
        if len(self.summary) == 0:
            return ""
        pad_titles = max(len(title) for title, value in self.summary)
        return "summary:\n" + "\n".join(
            [f" {title.ljust(pad_titles)}: {value}" for title, value in self.summary]
        )
|
||||
513
pkgs/xformers/profiler/slow_ops_profiler.py
Normal file
513
pkgs/xformers/profiler/slow_ops_profiler.py
Normal file
@@ -0,0 +1,513 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the BSD license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import itertools
|
||||
import json
|
||||
import math
|
||||
from collections import defaultdict
|
||||
from dataclasses import dataclass, field
|
||||
from functools import partial
|
||||
from typing import Any, Dict, List, Set, Tuple
|
||||
|
||||
import torch.cuda.memory
|
||||
import torch.cuda.nvtx
|
||||
import torch.profiler
|
||||
import torch.utils.hooks
|
||||
from torch.utils._python_dispatch import TorchDispatchMode, _pop_mode_temporarily
|
||||
from torch.utils._pytree import tree_map
|
||||
|
||||
from ..ops.common import FUNC_TO_XFORMERS_OPERATOR
|
||||
from .device_limits import get_device_limits
|
||||
from .profiler import _Profiler
|
||||
|
||||
|
||||
class TorchFuncMockNoDispatch:
    """
    Descriptor wrapping a method so that it is invoked with the custom
    pytorch dispatch mode temporarily popped.
    """

    def __init__(self, pt_impl):
        # The original implementation that calls are forwarded to
        self.pt_impl = pt_impl

    def __get__(self, obj, c):
        # Attribute access yields a callable with `obj` bound as first argument
        bound = partial(self, obj)
        return bound

    def __call__(self, obj, *args, **kwargs):
        with _pop_mode_temporarily():
            result = self.pt_impl(obj, *args, **kwargs)
        return result
|
||||
|
||||
|
||||
class DispatcherWithoutBrokenFuncs(TorchDispatchMode):
    """Dispatch mode that makes a few `torch.Tensor` methods skip dispatching.

    While the mode is entered, each method named in `TENSOR_FUNCS_NO_DISPATCH`
    is replaced on `torch.Tensor` by a `TorchFuncMockNoDispatch` wrapper; the
    original attributes are put back on exit.
    """

    TENSOR_FUNCS_NO_DISPATCH = [
        # Can't convert Stream argument to Python object
        # https://github.com/pytorch/pytorch/issues/94403
        "record_stream"
    ]

    def __enter__(self) -> None:
        # Remember the original implementations so `__exit__` can restore them
        self._pt_impls = {}
        for attr_name in self.TENSOR_FUNCS_NO_DISPATCH:
            original = getattr(torch.Tensor, attr_name)
            self._pt_impls[attr_name] = original
            setattr(torch.Tensor, attr_name, TorchFuncMockNoDispatch(original))
        return super().__enter__()

    def __exit__(self, exc_type, exc_val, exc_tb) -> None:
        for attr_name in self.TENSOR_FUNCS_NO_DISPATCH:
            original = self._pt_impls[attr_name]
            setattr(torch.Tensor, attr_name, original)
        return super().__exit__(exc_type, exc_val, exc_tb)
|
||||
|
||||
|
||||
def get_shape(i):
    """Return the ``shape`` attribute of ``i``."""
    shape = i.shape
    return shape
|
||||
|
||||
|
||||
def prod(x):
    """Return the product of all elements of the iterable ``x``.

    An empty iterable yields ``1``.
    """
    # Standard-library equivalent of the manual multiply loop
    return math.prod(x)
|
||||
|
||||
|
||||
class GemmOpComputeFlops:
    """FLOP counter for a matrix product of ``inputs[0]`` by ``inputs[1]``."""

    def _get_mnk(self, inputs: List[Any]) -> Tuple[int, int, int]:
        # m: all leading dims of the first operand, flattened
        # n: dim 1 of the second operand
        # k: last dim of the first operand
        m = prod(inputs[0].shape[:-1])
        n = inputs[1].shape[1]
        k = inputs[0].shape[-1]
        return (m, n, k)

    def __call__(self, inputs: List[Any], outputs: List[Any]) -> float:
        """Return ``2 * m * n * k`` for the given inputs."""
        mnk = self._get_mnk(inputs)
        return 2 * prod(mnk)

    def op_suffix(self, inputs: List[Any]) -> str:
        """Return a ``_{m}x{n}x{k}`` suffix describing the problem size."""
        m, n, k = self._get_mnk(inputs)
        return f"_{m}x{n}x{k}"
|
||||
|
||||
|
||||
class GemmOpComputeFlopsLinear(GemmOpComputeFlops):
    """GEMM sizes where ``n`` is read from dim 0 of the second operand."""

    def _get_mnk(self, inputs: List[Any]) -> Tuple[int, int, int]:
        m = prod(inputs[0].shape[:-1])
        n = inputs[1].shape[0]
        k = inputs[0].shape[-1]
        return (m, n, k)
|
||||
|
||||
|
||||
class GemmOpComputeFlopsMv(GemmOpComputeFlops):
    """GEMM sizes with ``n`` fixed to 1."""

    def _get_mnk(self, inputs: List[Any]) -> Tuple[int, int, int]:
        first_shape = inputs[0].shape
        return (prod(first_shape[:-1]), 1, first_shape[-1])
|
||||
|
||||
|
||||
class GemmOpComputeFlopsBmm(GemmOpComputeFlops):
    """GEMM sizes for two 3-dimensional operands; the batch is folded into ``m``."""

    def _get_mnk(self, inputs: List[Any]) -> Tuple[int, int, int]:
        lhs, rhs = inputs[0], inputs[1]
        assert lhs.ndim == 3
        assert rhs.ndim == 3
        # The larger of the two leading dims is used as batch size
        batch = max(lhs.shape[0], rhs.shape[0])
        m = batch * lhs.shape[1]
        return (m, rhs.shape[-1], rhs.shape[-2])
|
||||
|
||||
|
||||
class GemmOpComputeFlopsAddmm(GemmOpComputeFlops):
    """Same as `GemmOpComputeFlops`, ignoring the first input."""

    def _get_mnk(self, inputs: List[Any]) -> Tuple[int, int, int]:
        matmul_operands = inputs[1:]
        return super()._get_mnk(matmul_operands)
|
||||
|
||||
|
||||
class GemmOpComputeFlopsAddbmm(GemmOpComputeFlopsBmm):
    """Same as `GemmOpComputeFlopsBmm`, ignoring the first input."""

    def _get_mnk(self, inputs: List[Any]) -> Tuple[int, int, int]:
        matmul_operands = inputs[1:]
        return super()._get_mnk(matmul_operands)
|
||||
|
||||
|
||||
def conv_flop_count(
    x_shape: List[int],
    w_shape: List[int],
    out_shape: List[int],
    transposed: bool = False,
) -> float:
    """
    Count flops for convolution. Only multiplications are counted;
    additions and bias are ignored.

    The count is ``batch_size * prod(w_shape) * prod(conv_shape)`` where
    ``conv_shape`` is ``out_shape[2:]``, or ``x_shape[2:]`` for a transposed
    convolution.

    Args:
        x_shape (list(int)): The input shape before convolution.
        w_shape (list(int)): The filter shape.
        out_shape (list(int)): The output shape after convolution.
        transposed (bool): is the convolution transposed
    Returns:
        the number of flops
    """
    spatial_source = x_shape if transposed else out_shape
    return x_shape[0] * prod(w_shape) * prod(spatial_source[2:])
|
||||
|
||||
|
||||
def conv_flop(inputs: List[Any], outputs: List[Any]):
    """
    Count flops for convolution.

    Reads the input and weight from ``inputs[:2]``, the ``transposed`` flag
    from ``inputs[6]`` and the result from ``outputs[0]``.
    """
    x, w = inputs[:2]
    x_shape = get_shape(x)
    w_shape = get_shape(w)
    out_shape = get_shape(outputs[0])
    is_transposed = inputs[6]
    return conv_flop_count(x_shape, w_shape, out_shape, transposed=is_transposed)
|
||||
|
||||
|
||||
def transpose_shape(shape):
    """Return ``shape`` as a list with its first two entries swapped."""
    return [shape[1], shape[0], *shape[2:]]
|
||||
|
||||
|
||||
def conv_backward_flop(inputs: List[Any], outputs: List[Any]):
    """Count flops for a convolution backward pass.

    ``inputs[:3]`` are (grad_out, x, w), ``inputs[7]`` says whether the
    forward convolution was transposed and ``inputs[-1]`` is the output mask.
    Only the gradients enabled by the first two mask entries are counted.
    """
    grad_out_shape, x_shape, w_shape = map(get_shape, inputs[:3])
    output_mask = inputs[-1]
    fwd_transposed = inputs[7]

    total = 0.0
    # Gradient w.r.t. the input
    if output_mask[0]:
        total += conv_flop_count(
            grad_out_shape, w_shape, get_shape(outputs[0]), not fwd_transposed
        )
    # Gradient w.r.t. the weight
    if output_mask[1]:
        total += conv_flop_count(
            transpose_shape(x_shape),
            grad_out_shape,
            get_shape(outputs[1]),
            fwd_transposed,
        )
    return total
|
||||
|
||||
|
||||
def tensor_storage_size_in_mem(x: torch.Tensor):
    """Return the product of the sizes of all dims of ``x`` whose stride is >= 1.

    Dims with a stride below 1 do not contribute to the count.
    """
    numel = 1
    for extent, step in zip(x.shape, x.stride()):
        if step < 1:
            continue
        numel *= extent
    return numel
|
||||
|
||||
|
||||
def get_size(inputs: List[Any]):
    """Return the summed byte size of every tensor found in ``inputs``.

    ``inputs`` is walked with ``tree_map``; non-tensor leaves are ignored.
    """
    byte_sizes = []

    def record(leaf) -> None:
        if isinstance(leaf, torch.Tensor):
            byte_sizes.append(tensor_storage_size_in_mem(leaf) * leaf.element_size())

    tree_map(record, inputs)
    return sum(byte_sizes)
|
||||
|
||||
|
||||
def operation_memory_rw_bytes(inputs: List[Any], outputs: List[Any]):
    """Return the byte size of all inputs plus the byte size of all outputs."""
    read_bytes = get_size(inputs)
    written_bytes = get_size(outputs)
    return read_bytes + written_bytes
|
||||
|
||||
|
||||
def output_read_from_input(inputs: List[Any], outputs: List[Any]):
    """Return the output byte size plus the smaller of input and output sizes."""
    in_bytes = get_size(inputs)
    out_bytes = get_size(outputs)
    read_bytes = min(in_bytes, out_bytes)
    return out_bytes + read_bytes
|
||||
|
||||
|
||||
def output_total_size(inputs: List[Any], outputs: List[Any]):
    """Return the byte size of the outputs only; ``inputs`` is ignored."""
    out_bytes = get_size(outputs)
    return out_bytes
|
||||
|
||||
|
||||
def input_total_size(inputs: List[Any], outputs: List[Any]):
    """Return the byte size of the inputs only; ``outputs`` is ignored."""
    in_bytes = get_size(inputs)
    return in_bytes
|
||||
|
||||
|
||||
def guess_flops_unknown_op(inputs: List[Any], outputs: List[Any]):
    """Estimate flops as half the number of tensor elements in inputs and outputs.

    Approximation that isn't too bad.
    """
    element_counts = []

    def record(leaf) -> None:
        if isinstance(leaf, torch.Tensor):
            element_counts.append(leaf.numel())

    for tree in (inputs, outputs):
        tree_map(record, tree)
    return sum(element_counts) / 2
|
||||
|
||||
|
||||
def no_flop(inputs: List[Any], outputs: List[Any]):
    """Flop counter for operations counted as doing no computation: always 0."""
    return 0
|
||||
|
||||
|
||||
def no_io(inputs: List[Any], outputs: List[Any]):
    """IO counter for operations counted as moving no memory: always 0."""
    return 0
|
||||
|
||||
|
||||
# Shorthand for the aten operator namespace
aten = torch.ops.aten

# Operators counted with zero flops AND zero memory traffic
NO_FLOPS_NO_IO_OPS = [
    aten.permute,
    aten.view,
    aten.view_as,
    aten.detach,
    aten.t,
    aten.transpose,
    aten.expand,
    aten._unsafe_view,
    aten.select,
    aten.split,
    aten.split_with_sizes,
    aten.empty,
    aten.empty_strided,
    aten.empty_like,
    aten.is_same_size,
]

# Operators counted with zero flops (their memory traffic is still counted)
NO_FLOPS_OPS = [
    aten._reshape_alias,
    aten.reshape,
    aten.clone,
    aten.cat,
    aten.select_backward,
    aten.slice,
    aten.slice_backward,
    aten.ones,
    aten.ones_like,
    aten.zeros_like,
    aten.zero_,
    aten.zeros,
    aten.masked_fill,
    aten.masked_fill_,
]
|
||||
|
||||
# Operator packet -> callable(inputs, outputs) returning its flop count
flop_mapping = {
    # Matrix products
    aten.mv: GemmOpComputeFlopsMv(),  # mat-vec
    aten.mm: GemmOpComputeFlops(),
    aten.matmul: GemmOpComputeFlops(),
    aten.addmm: GemmOpComputeFlopsAddmm(),
    aten.bmm: GemmOpComputeFlopsBmm(),
    aten.addbmm: GemmOpComputeFlopsAddbmm(),
    aten.linear: GemmOpComputeFlopsLinear(),
    # Convolutions
    aten.convolution: conv_flop,
    aten._convolution: conv_flop,
    aten.convolution_backward: conv_backward_flop,
    # Operations with 0 flop
    **dict.fromkeys(NO_FLOPS_OPS, no_flop),
    **dict.fromkeys(NO_FLOPS_NO_IO_OPS, no_flop),
}
|
||||
# Operator packet -> callable(inputs, outputs) returning the bytes it moves
io_mapping = {
    aten.clone: output_read_from_input,
    aten.cat: output_read_from_input,
    aten.slice: output_read_from_input,
    aten.ones_like: output_total_size,
    aten.zeros_like: output_total_size,
    aten.zero_: input_total_size,
    **dict.fromkeys(NO_FLOPS_NO_IO_OPS, no_io),
    # TODO: Check how this is implemented in PT
    # aten.slice_backward: no_flop,
    # aten.select_backward: no_flop,
}
|
||||
|
||||
|
||||
@dataclass
class _OpInfo:
    """Measurements recorded for one executed operation."""

    flop_count: float = 0.0
    time_ms: float = 0.0
    io_bytes: int = 0
    is_exact_flop: bool = True
    op_name: str = ""
    op_suffix: str = ""
    stacktrace: Tuple[str, ...] = field(default_factory=tuple)
    # Timing events; `finalize` measures the time elapsed between them
    ev_start: torch.cuda.Event = field(
        default_factory=lambda: torch.cuda.Event(enable_timing=True)
    )
    ev_end: torch.cuda.Event = field(
        default_factory=lambda: torch.cuda.Event(enable_timing=True)
    )

    # Hardware limits for this operation (inf if unknown)
    hardware_tflops_limit: float = math.inf
    hardware_membw_limit: float = math.inf

    @property
    def time_membound_ms(self) -> float:
        """Time to move `io_bytes` at the bandwidth limit, capped at `time_ms`."""
        assert self.time_ms > 0.0
        if self.io_bytes == 0:
            return 0.0
        ideal_ms = 1000 * self.io_bytes / self.hardware_membw_limit
        return min(self.time_ms, ideal_ms)

    @property
    def time_computebound_ms(self) -> float:
        """Time to run `flop_count` at the TFlop/s limit, capped at `time_ms`."""
        assert self.time_ms > 0.0
        tflop = self.flop_count / (1000**4)
        if tflop == 0.0:
            return 0.0
        ideal_ms = 1000 * tflop / self.hardware_tflops_limit
        return min(self.time_ms, ideal_ms)

    def finalize(self) -> None:
        """Store in `time_ms` the time elapsed between the two events."""
        elapsed = self.ev_start.elapsed_time(self.ev_end)
        self.time_ms = elapsed
|
||||
|
||||
|
||||
@dataclass
class _OpInfoAggregated:
    """Running totals over several `_OpInfo` records."""

    # True only while every aggregated op reported an exact flop count
    is_exact_flop: bool = True
    total_flop_count: float = 0.0
    total_io_bytes: int = 0
    total_time_ms: float = 0.0
    total_time_membound_ms: float = 0.0
    total_time_computebound_ms: float = 0.0
    # Number of ops aggregated so far
    num: int = 0
    stacktraces: List[Tuple[str, ...]] = field(default_factory=list)

    def add(self, op: "_OpInfo") -> None:
        """Accumulate the measurements of ``op`` into the totals."""
        self.total_flop_count += op.flop_count
        self.total_time_ms += op.time_ms
        self.total_io_bytes += op.io_bytes
        self.total_time_membound_ms += op.time_membound_ms
        self.total_time_computebound_ms += op.time_computebound_ms
        self.num += 1
        # A single inexact op makes the whole aggregate inexact; the flag must
        # not be overwritten by whichever op happens to be added last.
        self.is_exact_flop = self.is_exact_flop and op.is_exact_flop
        self.stacktraces.append(op.stacktrace)

    def as_dict(self, **kwargs) -> Dict[str, Any]:
        """Return the totals and derived ratios as a dict.

        ``mem_bound`` and ``compute_bound`` are fractions of the total time,
        capped at 1. Extra ``kwargs`` are merged into the result. When no
        time has been accumulated, the derived values are reported as 0.0
        instead of dividing by zero.
        """
        if self.total_time_ms > 0:
            mem_bound = min(1, self.total_time_membound_ms / self.total_time_ms)
            tflops = self.total_flop_count / (self.total_time_ms / 1000) / (1000**4)
            compute_bound = min(
                1, self.total_time_computebound_ms / self.total_time_ms
            )
        else:
            mem_bound = tflops = compute_bound = 0.0
        return {
            "is_exact_flop": self.is_exact_flop,
            "total_flop_count": self.total_flop_count,
            "total_time_ms": self.total_time_ms,
            "total_io_bytes": self.total_io_bytes,
            "num": self.num,
            "Tflops": tflops,
            "mem_bound": mem_bound,
            "compute_bound": compute_bound,
            **kwargs,
        }
|
||||
|
||||
|
||||
class DetectSlowOpsProfiler(DispatcherWithoutBrokenFuncs):
    """
    Inspired from https://fb.workplace.com/groups/pytorch.dev/permalink/1054537595124720/

    Dispatch-level profiler: every dispatched operator is bracketed by a pair
    of CUDA events, annotated with an estimated FLOP count and memory traffic,
    and stored in ``self.trace``. On context exit the trace is aggregated per
    module path and per operator name and written to an ``ops.json`` file.
    """

    def __init__(self, main_profiler: _Profiler) -> None:
        """Bind the profiler to ``main_profiler`` and start with an empty trace."""
        self.main_profiler = main_profiler
        self.trace: List[_OpInfo] = []
        # Set while flop/IO estimators run, so their own ops are not recorded
        self.temp_disabled = False

    def _hardware_tflops_membw_limit(
        self, args: Tuple[Any, ...], outputs: Tuple[Any, ...]
    ) -> Tuple[float, float]:
        """Look up the hardware limits that apply to one operator call.

        The device is taken from the first tensor found among ``outputs`` then
        ``args``; the dtype is the first tensor dtype known to the device's
        ``gemm_tflops`` table. With autocast enabled, a float32 dtype is
        replaced by the autocast GPU dtype.

        Returns:
            ``(tflops_limit, membw_limit)``. Both are ``math.inf`` when no
            tensor or no known dtype is found; the TFLOPS limit alone is
            ``math.inf`` when the autocast dtype is missing from the table.
        """
        device = None
        dtypes: List[torch.dtype] = []
        for a in itertools.chain(outputs, args):
            if isinstance(a, torch.Tensor):
                if device is None:
                    device = a.device
                dtypes.append(a.dtype)
        # No tensor at all: nothing to look up, avoid querying limits for None
        if device is None:
            return (math.inf, math.inf)
        limits = get_device_limits(device)
        dtypes = [dt for dt in dtypes if dt in limits.gemm_tflops]
        if not dtypes:
            return (math.inf, math.inf)
        dtype = dtypes[0]
        if torch.is_autocast_enabled() and dtype is torch.float32:
            dtype = torch.get_autocast_gpu_dtype()
        # The autocast dtype was not filtered against the table above, so it
        # may be absent; report an unknown compute limit rather than KeyError.
        if dtype in limits.gemm_tflops:
            tflops_limit = limits.gemm_tflops[dtype]
        else:
            tflops_limit = math.inf
        return tflops_limit, limits.gmem_bandwidth

    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        """Run ``func`` and record timing events and cost estimates for it.

        Calls made while ``temp_disabled`` is set, and the record-function
        enter/exit ops, are forwarded without being recorded.
        """
        if kwargs is None:
            kwargs = {}
        func_packet = func._overloadpacket
        if self.temp_disabled or func_packet.__name__ in (
            "_record_function_exit",
            "_record_function_enter_new",
        ):
            return func(*args, **kwargs)

        op = _OpInfo()
        op.ev_start.record()
        out = func(*args, **kwargs)
        op.ev_end.record()

        # The estimators all expect the outputs as a tuple
        outputs = out if isinstance(out, tuple) else (out,)
        (
            op.hardware_tflops_limit,
            op.hardware_membw_limit,
        ) = self._hardware_tflops_membw_limit(args, outputs)
        op.op_name = func_packet.__name__
        # Prevent functions called by flop counting ops to be recorded
        self.temp_disabled = True
        try:
            flop_count = -1
            compute_flops = None
            if func_packet in FUNC_TO_XFORMERS_OPERATOR:
                flop_count = FUNC_TO_XFORMERS_OPERATOR[func_packet].operator_flop(
                    *args, **kwargs
                )
            # -1 means no xformers operator supplied a count: fall back to the
            # generic mapping, or to a guess for unknown ops
            if flop_count == -1:
                compute_flops = flop_mapping.get(func_packet, guess_flops_unknown_op)
                flop_count = compute_flops(args, outputs)
                if isinstance(compute_flops, GemmOpComputeFlops):
                    op.op_name += compute_flops.op_suffix(args)

            compute_io = io_mapping.get(func_packet, operation_memory_rw_bytes)
            op.io_bytes = compute_io(args, outputs)
        finally:
            # Always re-enable recording, even if an estimator raised
            self.temp_disabled = False

        op.stacktrace = tuple(self.main_profiler.parents)
        op.flop_count = flop_count
        op.is_exact_flop = compute_flops is not guess_flops_unknown_op
        self.trace.append(op)

        return out

    def __enter__(self):
        """Install the main profiler's hooks and activate the dispatch mode."""
        self.main_profiler._install_hooks()
        super().__enter__()
        # Return the profiler so that ``with ... as prof`` binds it
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Deactivate the mode, resolve all timings and write the summary."""
        try:
            super().__exit__(exc_type, exc_val, exc_tb)
        finally:
            # Hooks must go away even if leaving the dispatch mode failed
            self.main_profiler._remove_hooks()
        torch.cuda.synchronize()  # Wait for the events to be recorded
        for op in self.trace:
            op.finalize()
        self.save_json()

    def step(self) -> None:
        """No per-step work is needed for this profiler."""
        pass

    def save_json(self) -> None:
        """Aggregate the trace and dump it as JSON.

        Two kinds of records are produced: one per module path (``agg`` is
        ``"module"``) and one per operator name (``agg`` is ``"opname"``).
        The output file name comes from the main profiler, and the file is
        registered in the main profiler's ``summary``.
        """
        # Aggregate data at the module + op level
        all_paths: Set[Tuple[str, ...]] = {op.stacktrace for op in self.trace}
        per_module_data: Dict[Tuple[str, ...], _OpInfoAggregated] = defaultdict(
            _OpInfoAggregated
        )
        per_op_data: Dict[str, _OpInfoAggregated] = defaultdict(_OpInfoAggregated)
        for op in self.trace:
            # Charge the op to every ancestor path that itself recorded an op
            for depth in range(1, len(op.stacktrace) + 1):
                prefix = op.stacktrace[:depth]
                if prefix in all_paths:
                    per_module_data[prefix].add(op)
            per_op_data[op.op_name].add(op)

        # Generate JSON
        all_data = []
        for stacktrace, agg_info in per_module_data.items():
            all_data.append(
                agg_info.as_dict(
                    agg="module", path=stacktrace, name=stacktrace[-1], op=""
                )
            )
        for op_name, agg_info in per_op_data.items():
            # Find the most common path
            paths_count: Dict[Tuple[str, ...], int] = defaultdict(int)
            for p in agg_info.stacktraces:
                paths_count[p] += 1
            # In case of a draw, let's always return the same: ``max`` keeps
            # the first maximum, i.e. the smallest path in sorted order
            maxp = max(sorted(paths_count), key=paths_count.__getitem__)
            all_data.append(
                agg_info.as_dict(
                    agg="opname",
                    path=f"{'.'.join(maxp)} (x{paths_count[maxp]})",
                    name="",
                    op=op_name,
                )
            )

        filename = self.main_profiler._create_output_filename("ops.json")
        self.main_profiler.summary.append(("OpsSummary", str(filename)))
        with open(filename, "w") as f:
            json.dump(all_data, f)
|
||||
Reference in New Issue
Block a user