First commit

This commit is contained in:
2025-08-05 19:02:46 +08:00
parent 9efe891f99
commit 99fb9f5cb0
1412 changed files with 203615 additions and 0 deletions

View File

@@ -0,0 +1,17 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
from .api import profile, step
from .profiler import MemSnapshotsProfiler, NsightProfiler, PyTorchProfiler
from .slow_ops_profiler import DetectSlowOpsProfiler
# Names exported by `from <package> import *`: the two entry-point functions
# followed by the profiler classes imported from the submodules above.
__all__ = [
    "profile",
    "step",
    "MemSnapshotsProfiler",
    "PyTorchProfiler",
    "NsightProfiler",
    "DetectSlowOpsProfiler",
]

Binary file not shown.

View File

@@ -0,0 +1,94 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
from typing import Any, Optional, Sequence, Tuple
import torch.nn as nn
from .profiler import (
MemSnapshotsProfiler,
NsightProfiler,
PyTorchProfiler,
PyTorchProfiler_CUDAOnly,
_Profiler,
)
from .slow_ops_profiler import DetectSlowOpsProfiler # noqa: F401
# Schedule used by `profile` when the caller does not supply one.
# Each entry is `(profiler_class, begin_step, end_step)`.
DEFAULT_SCHEDULE = (
    (MemSnapshotsProfiler, 0, 2),
    (NsightProfiler, 4, 6),
    (PyTorchProfiler, 6, 7),
    (PyTorchProfiler_CUDAOnly, 7, 8),
    # TODO: There are some issues in PyTorch stable
    # which are now fixed on main, but might break this profiler
    # https://github.com/pytorch/pytorch/issues/94403
    # (DetectSlowOpsProfiler, 9, 10),
)
def profile(
    output_dir: str,
    module: Optional[nn.Module] = None,
    schedule: Sequence[Tuple[Any, int, int]] = DEFAULT_SCHEDULE,
):
    """
    Build a pre-configured profiler that runs over the first training steps.

    The returned object activates each profiler listed in ``schedule`` for
    its own range of steps and writes the resulting traces to
    ``output_dir``. Use it as a context manager around the training loop and
    call `xformers.profiler.step` at the end of every iteration, or drive it
    manually with ``.start()`` / ``.step()`` / ``.stop()``.

    Args:
        output_dir: directory in which the traces are written.
        module: optional model; when given, its sub-modules are hooked by
            the profilers that need per-module information.
        schedule: sequence of ``(profiler_class, begin_step, end_step)``
            entries. Defaults to ``DEFAULT_SCHEDULE``.

    Returns:
        The profiler object (not started yet).

    :Examples:

    .. code-block:: python

        import torch
        import timm.models
        import xformers.profiler
        from xformers.profiler import (
            DetectSlowOpsProfiler,
            MemSnapshotsProfiler,
            NsightProfiler,
            PyTorchProfiler,
        )

        dtype = torch.bfloat16
        device = "cuda"
        model = timm.models.vit_large_patch16_224().to(device).to(dtype)
        inp = torch.zeros([64, 3, 224, 224], device=device, dtype=dtype)
        optim = torch.optim.Adam(model.parameters())

        with xformers.profiler.profile(
            output_dir="profile_data",
            module=model,
            schedule=[
                (MemSnapshotsProfiler, 0, 2),
                (DetectSlowOpsProfiler, 2, 4),
                (NsightProfiler, 4, 6),
                (PyTorchProfiler, 6, 20),
            ]
        ):
            for i in range(20):
                model(inp).sum().backward()
                optim.step()
                optim.zero_grad()
                xformers.profiler.step()

        # alternatively, use the profiler without a context manager and
        # with explicit ``.start()`` / ``.stop()`` calls.
        xprofiler = xformers.profiler.profile(...)
        xprofiler.start()

        for i in range(20):
            model(inp).sum().backward()
            optim.step()
            optim.zero_grad()
            xprofiler.step()

        xprofiler.stop()
    """
    profiler = _Profiler(output_dir=output_dir, schedule=schedule, module=module)
    return profiler
def step() -> None:
    """Advance the currently active profiler by one step.

    See `xformers.profiler.profile`. Does nothing when no profiler is active.
    """
    active = _Profiler._CURRENT_PROFILER
    if active is not None:
        active.step()

View File

@@ -0,0 +1,113 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
import math
from dataclasses import dataclass, field
from typing import Mapping, Tuple
import torch
@dataclass
class DeviceLimit:
    """Peak hardware figures for one device model.

    The defaults describe an unknown device: no bandwidth limit and no
    per-dtype throughput entries.
    """

    # Substring looked for in `torch.cuda.get_device_name()`
    name: str = "default"
    # Where the figures come from
    source: str = ""
    # Compute capability as (major, minor)
    sm: Tuple[int, int] = (0, 0)
    # Global memory bandwidth, in bytes/s
    gmem_bandwidth: float = math.inf
    # GEMM throughput per dtype, in TFlop/s
    gemm_tflops: Mapping[torch.dtype, float] = field(default_factory=dict)
# Table of known devices, scanned in order by `get_device_limits`.
# For f32, we assume we can use tf32
# NOTE(review): bandwidths are written with binary multipliers (1024**3,
# 1024**4); confirm against the linked datasheets whether decimal GB/TB
# were intended.
DEVICE_LIMITS: Tuple[DeviceLimit, ...] = (
    DeviceLimit(
        "H100",
        "https://resources.nvidia.com/en-us-tensor-core/nvidia-tensor-core-gpu-datasheet",  # noqa: E501
        sm=(9, 0),
        gmem_bandwidth=3.35 * (1024**4),  # NOTE: PCIe is 2 TB/s
        gemm_tflops={
            torch.float64: 67,
            # NOTE: NVIDIA gives all numbers "with 2:4 sparsity"
            # but we want the full GEMM numbers
            # (floor division: odd values lose their trailing .5)
            torch.float32: 989 // 2,
            torch.float16: 1979 // 2,
            torch.bfloat16: 1979 // 2,
            torch.int8: 3958 // 2,
        },
    ),
    DeviceLimit(
        "A100",
        "https://www.nvidia.com/content/dam/en-zz/Solutions/Data-Center/a100/pdf/nvidia-a100-datasheet-us-nvidia-1758950-r4-web.pdf",  # noqa: E501
        sm=(8, 0),
        gmem_bandwidth=2 * (1024**4),  # NOTE: PCIe is 1.5 TB/s
        gemm_tflops={
            torch.float64: 19.5,
            torch.float32: 156,
            torch.float16: 312,
            torch.bfloat16: 312,
            torch.int8: 624,
        },
    ),
    DeviceLimit(
        "A30",
        "https://www.nvidia.com/content/dam/en-zz/Solutions/data-center/products/a30-gpu/pdf/a30-datasheet.pdf",
        sm=(8, 0),
        gmem_bandwidth=933 * (1024**3),
        gemm_tflops={
            torch.float64: 10.3,
            torch.float32: 82,
            torch.float16: 165,
            torch.bfloat16: 165,
            torch.int8: 330,
        },
    ),
    DeviceLimit(
        "T4",
        "https://www.nvidia.com/content/dam/en-zz/Solutions/Data-Center/tesla-t4/t4-tensor-core-datasheet-951643.pdf",
        sm=(7, 5),
        gmem_bandwidth=300 * (1024**3),
        gemm_tflops={
            torch.float32: 8.1,
            torch.float16: 65,
            torch.int8: 130,
        },
    ),
    # Assuming SXM2
    DeviceLimit(
        "V100",
        "https://images.nvidia.com/content/technologies/volta/pdf/tesla-volta-v100-datasheet-letter-fnl-web.pdf",
        sm=(7, 0),
        gmem_bandwidth=900 * (1024**3),
        gemm_tflops={
            torch.float64: 7.8,
            torch.float32: 15.7,
            torch.float16: 125,
        },
    ),
    DeviceLimit(
        "P100",
        "https://images.nvidia.com/content/tesla/pdf/nvidia-tesla-p100-datasheet.pdf",
        sm=(6, 0),
        gmem_bandwidth=732 * (1024**3),
        gemm_tflops={
            torch.float64: 5.3,
            torch.float32: 10.6,
            torch.float16: 21.2,
        },
    ),
)
def get_device_limits(device) -> DeviceLimit:
    """Look up the hardware limits of ``device``.

    Currently only implemented for GPUs: any other device (or ``None``), and
    any GPU that is not listed in ``DEVICE_LIMITS``, gets a default
    ``DeviceLimit()``.
    """
    if device is None or device.type != "cuda":
        return DeviceLimit()

    capability = torch.cuda.get_device_capability(device)
    gpu_name = torch.cuda.get_device_name(device)
    for candidate in DEVICE_LIMITS:
        # Both the compute capability and the name pattern have to match
        if candidate.sm == capability and candidate.name in gpu_name:
            return candidate
    return DeviceLimit()

View File

@@ -0,0 +1,363 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
import logging
import os
import queue
import socket
import weakref
from dataclasses import dataclass
from pathlib import Path
from typing import Any, List, Optional, Sequence, Tuple
import torch.cuda.memory
import torch.cuda.nvtx
import torch.nn as nn
import torch.profiler
import torch.utils.hooks
logger = logging.getLogger(__name__)
def _normalize_tuple(x):
    """Return ``x`` itself when it is a tuple, else wrap it in a 1-tuple."""
    return x if isinstance(x, tuple) else (x,)
class NsightProfiler:
    """Profiler that starts and stops the NSight capture range.

    NOTE: the script running this code must itself be launched under
    ``nsys profile`` with the flag ``--capture-range=cudaProfilerApi``, so
    that the capture is driven by this profiler during the scheduled steps.
    """

    def __init__(self, main_profiler: "_Profiler") -> None:
        # TODO figure out if there is a way to know if nsys is launched at this point
        self.main_profiler = main_profiler

    def __enter__(self):
        # Hooks go in first so the module ranges exist while capturing
        owner = self.main_profiler
        owner._install_hooks()
        torch.cuda.profiler.start()

    def __exit__(self, exc_type, exc_val, exc_tb):
        torch.cuda.profiler.stop()
        owner = self.main_profiler
        owner._remove_hooks()

    def step(self) -> None:
        """Nothing to do between steps."""
        return None
class PyTorchProfiler:
    """Profiler built on the native PyTorch profiler.

    The current settings capture traces, memory footprint and other
    information that can be read via TensorBoard.
    """

    ACTIVITIES = [
        torch.profiler.ProfilerActivity.CPU,
        torch.profiler.ProfilerActivity.CUDA,
    ]

    def __init__(self, main_profiler: "_Profiler") -> None:
        self.main_profiler = main_profiler

        # One output folder per (activities, starting step) combination
        activity_tag = "_".join(activity.name for activity in self.ACTIVITIES)
        trace_dir = (
            main_profiler.output_dir
            / f"profile_{activity_tag}_{main_profiler.done_steps:06}"
        )
        on_ready = torch.profiler.tensorboard_trace_handler(
            dir_name=str(trace_dir),
            worker_name=main_profiler.worker_name,
            use_gzip=True,
        )
        self.hta = torch.profiler.profile(
            on_trace_ready=on_ready,
            profile_memory=True,
            record_shapes=True,
            with_stack=True,
            activities=self.ACTIVITIES,
        )
        self.done_steps = 0

    def __enter__(self):
        # Synchronize before starting the underlying profiler
        torch.cuda.synchronize()
        self.hta.__enter__()

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Synchronize before stopping the underlying profiler
        torch.cuda.synchronize()
        self.hta.__exit__(exc_type, exc_val, exc_tb)

    def step(self) -> None:
        """Forward the step to the PyTorch profiler and count it."""
        self.hta.step()
        self.done_steps += 1
class PyTorchProfiler_CUDAOnly(PyTorchProfiler):
    """`PyTorchProfiler` variant that records CUDA activity only.

    The CPU side of things is not profiled, so this profiler is expected to
    have almost no overhead.
    """

    ACTIVITIES = [torch.profiler.ProfilerActivity.CUDA]
class MemSnapshotsProfiler:
    """Profiler that records the allocation / deallocation history of tensor
    memory and renders it as an HTML trace plot.
    """

    def __init__(self, main_profiler: "_Profiler") -> None:
        self.main_profiler = main_profiler
        self.enabled = False

    @property
    def _has_trace_plot(self) -> bool:
        """Whether this PyTorch build provides `_memory_viz.trace_plot`."""
        return hasattr(torch.cuda._memory_viz, "trace_plot")

    def __enter__(self):
        if self._has_trace_plot:
            self.enabled = True
            # TODO: This does not show the previous memory allocations
            # We could at least have a placeholder with how much
            # memory was allocated before
            torch.cuda.memory._record_memory_history(
                True,
                # keep 100,000 alloc/free events from before the snapshot
                trace_alloc_max_entries=100000,
                # record stack information for the trace events
                trace_alloc_record_context=True,
            )

    def __exit__(self, exc_type, exc_val, exc_tb):
        if not self._has_trace_plot:
            self.main_profiler.summary.append(
                ("MemTrace", "(not available with your Pytorch version)")
            )
            return
        assert self.enabled

        # Take the snapshot first, then stop recording
        snapshot = torch.cuda.memory._snapshot()
        torch.cuda.memory._record_memory_history(False)

        # No data was recorded - avoids a `ValueError` in `trace_plot`
        has_events = any(len(trace) != 0 for trace in snapshot["device_traces"])
        if not has_events:
            self.main_profiler.summary.append(("MemTrace", "(no allocation recorded)"))
            return

        # Dump to disk
        filename = self.main_profiler._create_output_filename("memory_trace_plot.html")
        self.main_profiler.summary.append(("MemTrace", filename))
        with open(filename, "w+") as fd:
            html = torch.cuda._memory_viz.trace_plot(
                snapshot, device=None, plot_segments=False
            )
            fd.write(html)

    def step(self) -> None:
        """Nothing to do between steps."""
        return None
@dataclass
class _ProfilerState:
    """Scheduling record for one profiler: its class, its step range and the
    live instance while it is running.
    """

    # Profiler class, instantiated with the main profiler as only argument
    cls: Any
    # First step for which the profiler is active
    iter_begin: int
    # Step at which the profiler stops being active
    iter_end: int
    # Running instance, or None while the profiler is not active
    object: Any = None
class _Profiler:
    """Main profiler: runs the scheduled sub-profilers step by step.

    Each schedule entry ``(cls, begin, end)`` is turned into a
    `_ProfilerState`. On every step, the sub-profilers whose range contains
    the current step are started (or stepped), and the others are shut down.
    The object also owns the module hooks that maintain the ``parents``
    stack and the matching NVTX ranges.
    """

    # Profiler currently entered through `__enter__` / `start` (at most one)
    _CURRENT_PROFILER = None

    def __init__(
        self,
        output_dir: str,
        schedule: Sequence[Tuple[Any, int, int]],
        module: Optional[nn.Module],
    ) -> None:
        """
        Args:
            output_dir: directory receiving the outputs (created if missing).
            schedule: ``(profiler_class, begin_step, end_step)`` entries.
            module: optional model whose sub-modules get hooked.
        """
        self.check_schedule(schedule)
        self.done_steps = 0
        self.output_dir = Path(output_dir).absolute()
        self.output_dir.mkdir(exist_ok=True, parents=True)
        self.worker_name = ""
        # `is_initialized` only exists when the distributed package is available
        if torch.distributed.is_available() and torch.distributed.is_initialized():
            self.worker_name = "{}_{}".format(socket.gethostname(), str(os.getpid()))

        # Only a weak reference is kept on the module. Without a module, the
        # placeholder is not referenced anywhere else, so `_install_hooks`
        # is expected to find a dead reference and install nothing.
        self.module = weakref.ref(module if module is not None else nn.Module())
        self.parents = ["Global"]
        self.hooks: List[torch.utils.hooks.RemovableHandle] = []
        self.hooks_refcount = 0
        self.profilers: List[_ProfilerState] = sorted(
            [_ProfilerState(cls, begin, end) for cls, begin, end in schedule],
            key=lambda x: x.iter_begin,
        )
        self.last_step = self.profilers[-1].iter_end if self.profilers else 0
        self.summary: List[Tuple[str, str]] = []

    def check_schedule(self, schedule: Sequence[Tuple[Any, int, int]]) -> None:
        """Validate ``schedule``.

        Logs a warning for an empty schedule.

        Raises:
            AssertionError: if a range is malformed or two ranges overlap.
        """
        if len(schedule) == 0:
            logger.warning(
                "You specified empty schedule for profiling. No data will be captured."
            )

        intervals = []
        for cls, begin, end in schedule:
            assert (
                begin >= 0
            ), f"Begin step of profiler must be non-negative, found: {begin}"
            assert end > 0, f"End step of profiler must be positive, found: {end}"
            assert (
                begin < end
            ), f"Start must be before the end, found: begin={begin} and end={end}"
            intervals.append((begin, end))

        # The ranges must be visited in sorted order for the overlap check to
        # be valid, whatever the order in which the caller listed them.
        prev_end = -1
        for begin, end in sorted(intervals):
            assert begin >= prev_end, (
                "There is some overlapping in profiler scheduling. Please do not"
                + " overlap profilers by step as they may affect each other. Schedule:"
                + f" {schedule}"
            )
            prev_end = end

    def update_profilers_on_step(self) -> None:
        """Start, step or shut down each sub-profiler for the current step."""
        for p in self.profilers:
            if p.iter_begin <= self.done_steps and self.done_steps < p.iter_end:
                if p.object is None:
                    o = p.cls(self)
                    logger.info("Starting %s profiler...", p.cls.__name__)
                    o.__enter__()
                    p.object = o
                else:
                    p.object.step()
            else:
                if p.object is not None:
                    o = p.object
                    # Cleared first so the profiler is never exited twice
                    p.object = None
                    logger.info("Shutting down %s profiler...", p.cls.__name__)
                    o.__exit__(None, None, None)

    def _create_output_filename(self, filename: str) -> Path:
        """
        Returns where to write a file with desired filename.
        Handles the case where we are in distributed settings, or when
        we need to output the same file multiple times (eg if a profiler
        runs for several steps)
        """
        if self.worker_name != "":
            file = Path(filename)
            folder = self.output_dir / file.stem
            folder.mkdir(parents=True, exist_ok=True)
            return folder / f"{self.done_steps:06}_{self.worker_name}{file.suffix}"
        return self.output_dir / f"{self.done_steps:06}_{filename}"

    def _install_hooks(self) -> None:
        """Hook every sub-module of the profiled module (reference counted).

        Each call must be balanced by a call to `_remove_hooks`.
        """
        self.hooks_refcount += 1
        # Already installed
        if self.hooks:
            return
        module = self.module()
        if module is None:
            return
        for name, sub_mod in module.named_modules():
            # The root module itself is not hooked
            if name == "":
                continue
            # Only the last component of the dotted path names the range
            name = name.split(".")[-1]
            self.hooks += [
                sub_mod.register_forward_pre_hook(self._enter_module_hook(name)),
                sub_mod.register_forward_hook(self._exit_module_hook(name)),
            ]

    def _remove_hooks(self) -> None:
        """Release one reference on the hooks; remove them with the last one."""
        self.hooks_refcount -= 1
        if self.hooks_refcount == 0:
            for h in self.hooks:
                h.remove()
            # Forget the stale handles, otherwise the next `_install_hooks`
            # would believe that the hooks are still installed.
            self.hooks.clear()

    def _enter_module_hook(self, name):
        """Build the forward pre-hook that opens the range for ``name``.

        The inputs go through an identity autograd function whose backward
        closes the range during the backward pass.
        """

        class PopState(torch.autograd.Function):
            @staticmethod
            def forward(ctx, *args):
                if len(args) == 1:
                    return args[0]
                return args

            @staticmethod
            def backward(ctx, *grad_outs):
                self._exit_module(name)
                return grad_outs

        def f(module, inputs):
            self._enter_module(name)
            inputs = _normalize_tuple(inputs)
            out = PopState.apply(*inputs)
            return out

        return f

    def _exit_module_hook(self, name):
        """Build the forward hook that closes the range for ``name``.

        The outputs go through an identity autograd function whose backward
        re-opens the range during the backward pass.
        """

        class PushState(torch.autograd.Function):
            @staticmethod
            def forward(ctx, *args):
                if len(args) == 1:
                    return args[0]
                return args

            @staticmethod
            def backward(ctx, *grad_outs):
                self._enter_module(name)
                return grad_outs

        def f(module, inputs, outputs):
            self._exit_module(name)
            outputs = _normalize_tuple(outputs)
            return PushState.apply(*outputs)

        return f

    def _enter_module(self, name) -> None:
        """Push ``name`` on the parents stack and open an NVTX range."""
        self.parents.append(name)
        torch.cuda.nvtx.range_push(name)

    def _exit_module(self, name) -> None:
        """Close the NVTX range and pop ``name`` from the parents stack."""
        torch.cuda.nvtx.range_pop()
        assert self.parents[-1] == name
        self.parents.pop()

    def start(self):
        """Same as entering the context manager."""
        self.__enter__()

    def stop(self, exc_type=None, exc_val=None, exc_tb=None):
        """Same as leaving the context manager."""
        self.__exit__(exc_type, exc_val, exc_tb)

    def __enter__(self):
        """Register as the current profiler and start the step-0 profilers.

        Raises:
            ValueError: if another profiler is already active.
        """
        if _Profiler._CURRENT_PROFILER is not None:
            raise ValueError("Only one xformers profiler can be active at a time")
        _Profiler._CURRENT_PROFILER = self
        self.update_profilers_on_step()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Unregister and shut down the sub-profilers that are still running."""
        _Profiler._CURRENT_PROFILER = None

        for p in self.profilers:
            if p.object is not None:
                o = p.object
                # Cleared first so a second `stop()` does not exit it again
                p.object = None
                o.__exit__(exc_type, exc_val, exc_tb)

    def step(self) -> None:
        """Signals the profiler that the next profiling step has started."""
        self.done_steps += 1

        if self.done_steps <= self.last_step:
            # A new step starts with a fresh module stack
            self.parents = ["Global"]
            self.update_profilers_on_step()
        if self.done_steps == self.last_step:
            logger.info("xFormers profiler done. %s", self.format_summary())

    def format_summary(self) -> str:
        """Render ``self.summary`` as aligned ``title: value`` lines."""
        if len(self.summary) == 0:
            return ""
        pad_titles = max(len(title) for title, value in self.summary)
        return "summary:\n" + "\n".join(
            [f" {title.ljust(pad_titles)}: {value}" for title, value in self.summary]
        )

View File

@@ -0,0 +1,513 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
import itertools
import json
import math
from collections import defaultdict
from dataclasses import dataclass, field
from functools import partial
from typing import Any, Dict, List, Set, Tuple
import torch.cuda.memory
import torch.cuda.nvtx
import torch.profiler
import torch.utils.hooks
from torch.utils._python_dispatch import TorchDispatchMode, _pop_mode_temporarily
from torch.utils._pytree import tree_map
from ..ops.common import FUNC_TO_XFORMERS_OPERATOR
from .device_limits import get_device_limits
from .profiler import _Profiler
class TorchFuncMockNoDispatch:
    """
    Descriptor wrapping a method so that it gets called without the custom
    pytorch dispatcher
    """

    def __init__(self, pt_impl):
        self.pt_impl = pt_impl

    def __get__(self, obj, c):
        # Bind the instance, like a regular method would be bound
        return partial(self, obj)

    def __call__(self, obj, *args, **kwargs):
        # The dispatch mode is only popped for the duration of the call
        with _pop_mode_temporarily():
            result = self.pt_impl(obj, *args, **kwargs)
        return result
class DispatcherWithoutBrokenFuncs(TorchDispatchMode):
    """Dispatch mode that, while active, swaps the `torch.Tensor` methods
    listed in ``TENSOR_FUNCS_NO_DISPATCH`` for wrappers bypassing the mode.
    """

    TENSOR_FUNCS_NO_DISPATCH = [
        # Can't convert Stream argument to Python object
        # https://github.com/pytorch/pytorch/issues/94403
        "record_stream"
    ]

    def __enter__(self) -> None:
        # Remember the original implementations to restore them on exit
        self._pt_impls = {}
        for method_name in self.TENSOR_FUNCS_NO_DISPATCH:
            original = getattr(torch.Tensor, method_name)
            self._pt_impls[method_name] = original
            setattr(torch.Tensor, method_name, TorchFuncMockNoDispatch(original))
        return super().__enter__()

    def __exit__(self, exc_type, exc_val, exc_tb) -> None:
        # Put the original implementations back before leaving the mode
        for method_name in self.TENSOR_FUNCS_NO_DISPATCH:
            original = self._pt_impls[method_name]
            setattr(torch.Tensor, method_name, original)
        return super().__exit__(exc_type, exc_val, exc_tb)
def get_shape(i):
    """Return the ``shape`` attribute of ``i``."""
    shape = i.shape
    return shape
def prod(x):
    """Return the product of the elements of the iterable ``x``.

    An empty iterable gives 1.
    """
    # Standard-library equivalent of the manual accumulation loop
    return math.prod(x)
class GemmOpComputeFlops:
    """FLOP estimator for a matrix multiplication of ``inputs[0]`` by
    ``inputs[1]``. Subclasses only redefine how (M, N, K) is derived.
    """

    def _get_mnk(self, inputs: List[Any]) -> Tuple[int, int, int]:
        # All leading dimensions of the first operand are folded into M
        lhs, rhs = inputs[0], inputs[1]
        m = prod(lhs.shape[:-1])
        n = rhs.shape[1]
        k = lhs.shape[-1]
        return (m, n, k)

    def __call__(self, inputs: List[Any], outputs: List[Any]) -> float:
        # 2 * M * N * K
        mnk = self._get_mnk(inputs)
        return 2 * prod(mnk)

    def op_suffix(self, inputs: List[Any]) -> str:
        """Return the ``_MxNxK`` suffix appended to the op name."""
        m, n, k = self._get_mnk(inputs)
        return "_" + "x".join(format(dim) for dim in (m, n, k))
class GemmOpComputeFlopsLinear(GemmOpComputeFlops):
    """GEMM estimator reading N from the first dimension of ``inputs[1]``."""

    def _get_mnk(self, inputs: List[Any]) -> Tuple[int, int, int]:
        activation, weight = inputs[0], inputs[1]
        m = prod(activation.shape[:-1])
        n = weight.shape[0]
        k = activation.shape[-1]
        return (m, n, k)
class GemmOpComputeFlopsMv(GemmOpComputeFlops):
    """GEMM estimator for a matrix-vector product: N is always 1."""

    def _get_mnk(self, inputs: List[Any]) -> Tuple[int, int, int]:
        matrix = inputs[0]
        m = prod(matrix.shape[:-1])
        k = matrix.shape[-1]
        return (m, 1, k)
class GemmOpComputeFlopsBmm(GemmOpComputeFlops):
    """GEMM estimator for a batched product of two 3-dimensional operands.

    The batch size is folded into M.
    """

    def _get_mnk(self, inputs: List[Any]) -> Tuple[int, int, int]:
        lhs, rhs = inputs[0], inputs[1]
        assert lhs.ndim == 3
        assert rhs.ndim == 3
        batch = max(lhs.shape[0], rhs.shape[0])
        m = batch * lhs.shape[1]
        n = rhs.shape[-1]
        k = rhs.shape[-2]
        return (m, n, k)
class GemmOpComputeFlopsAddmm(GemmOpComputeFlops):
    """GEMM estimator skipping the first input: the product operands are
    ``inputs[1]`` and ``inputs[2]``.
    """

    def _get_mnk(self, inputs: List[Any]) -> Tuple[int, int, int]:
        operands = inputs[1:]
        return super()._get_mnk(operands)
class GemmOpComputeFlopsAddbmm(GemmOpComputeFlopsBmm):
    """Batched GEMM estimator skipping the first input: the product operands
    are ``inputs[1]`` and ``inputs[2]``.
    """

    def _get_mnk(self, inputs: List[Any]) -> Tuple[int, int, int]:
        operands = inputs[1:]
        return super()._get_mnk(operands)
def conv_flop_count(
    x_shape: List[int],
    w_shape: List[int],
    out_shape: List[int],
    transposed: bool = False,
) -> float:
    """
    Compute ``batch_size * prod(w_shape) * prod(spatial dims)`` for a
    convolution.

    The spatial dimensions (every dimension after the first two) are taken
    from the output shape, or from the input shape for a transposed
    convolution.

    Args:
        x_shape (list(int)): The input shape before convolution.
        w_shape (list(int)): The filter shape.
        out_shape (list(int)): The output shape after convolution.
        transposed (bool): is the convolution transposed
    Returns:
        the number of flops
    """
    batch_size = x_shape[0]
    if transposed:
        spatial_dims = x_shape[2:]
    else:
        spatial_dims = out_shape[2:]
    return batch_size * prod(w_shape) * prod(spatial_dims)
def conv_flop(inputs: List[Any], outputs: List[Any]):
    """
    Count flops for a convolution call.

    The input and the weight are the first two inputs; the ``transposed``
    flag is read from ``inputs[6]``.
    """
    x, w = inputs[:2]
    x_shape = get_shape(x)
    w_shape = get_shape(w)
    out_shape = get_shape(outputs[0])
    return conv_flop_count(x_shape, w_shape, out_shape, transposed=inputs[6])
def transpose_shape(shape):
    """Return ``shape`` as a list with its first two entries swapped."""
    return [shape[1], shape[0], *shape[2:]]
def conv_backward_flop(inputs: List[Any], outputs: List[Any]):
    """Count flops for the backward pass of a convolution.

    ``inputs[:3]`` are read as (grad_out, x, w), ``inputs[7]`` as the
    ``transposed`` flag of the forward pass, and ``inputs[-1]`` as the mask
    telling which gradients are computed. Only the input gradient (mask
    entry 0) and the weight gradient (mask entry 1) are counted.
    """
    grad_out_shape, x_shape, w_shape = (get_shape(t) for t in inputs[:3])
    output_mask = inputs[-1]
    fwd_transposed = inputs[7]

    total = 0.0
    if output_mask[0]:
        # Input gradient: a convolution of the opposite kind
        total += conv_flop_count(
            grad_out_shape, w_shape, get_shape(outputs[0]), not fwd_transposed
        )
    if output_mask[1]:
        # Weight gradient: x (first two dims swapped) against grad_out
        total += conv_flop_count(
            transpose_shape(x_shape),
            grad_out_shape,
            get_shape(outputs[1]),
            fwd_transposed,
        )
    return total
def tensor_storage_size_in_mem(x: torch.Tensor):
    """Return the product of the sizes of the dimensions of ``x`` whose
    stride is at least 1.

    Dimensions with stride 0 (e.g. expanded ones) are not counted.
    """
    counted_dims = [size for size, stride in zip(x.shape, x.stride()) if stride >= 1]
    total = 1
    for size in counted_dims:
        total *= size
    return total
def get_size(inputs: List[Any]):
    """Return the summed size in bytes of every tensor found in ``inputs``.

    ``inputs`` is walked as a pytree; non-tensor leaves are ignored.
    """
    sizes = []

    def visit(leaf) -> None:
        if isinstance(leaf, torch.Tensor):
            sizes.append(tensor_storage_size_in_mem(leaf) * leaf.element_size())

    tree_map(visit, inputs)
    return sum(sizes)
def operation_memory_rw_bytes(inputs: List[Any], outputs: List[Any]):
    """IO estimate: size of all the inputs plus size of all the outputs."""
    read_bytes = get_size(inputs)
    written_bytes = get_size(outputs)
    return read_bytes + written_bytes
def output_read_from_input(inputs: List[Any], outputs: List[Any]):
    """IO estimate: the outputs are written, and the amount read is capped
    to the size of the outputs.
    """
    size_input = get_size(inputs)
    size_output = get_size(outputs)
    read_bytes = min(size_input, size_output)
    return size_output + read_bytes
def output_total_size(inputs: List[Any], outputs: List[Any]):
    """IO estimate: only the size of the outputs."""
    written_bytes = get_size(outputs)
    return written_bytes
def input_total_size(inputs: List[Any], outputs: List[Any]):
    """IO estimate: only the size of the inputs."""
    read_bytes = get_size(inputs)
    return read_bytes
def guess_flops_unknown_op(inputs: List[Any], outputs: List[Any]):
    """Fallback FLOP estimate for ops without a dedicated formula.

    Returns half the total number of elements of all input and output
    tensors - an approximation that isn't too bad.
    """
    element_counts = []

    def visit(leaf) -> None:
        if isinstance(leaf, torch.Tensor):
            element_counts.append(leaf.numel())

    tree_map(visit, inputs)
    tree_map(visit, outputs)
    return sum(element_counts) / 2
def no_flop(inputs: List[Any], outputs: List[Any]):
    """FLOP estimate of the ops that do no computation: always 0."""
    flops = 0
    return flops
def no_io(inputs: List[Any], outputs: List[Any]):
    """IO estimate of the ops that move no data: always 0."""
    io_bytes = 0
    return io_bytes
# Shorthand for the aten operator namespace
aten = torch.ops.aten
# Ops mapped to `no_flop` in `flop_mapping` and to `no_io` in `io_mapping`
NO_FLOPS_NO_IO_OPS = [
    aten.permute,
    aten.view,
    aten.view_as,
    aten.detach,
    aten.t,
    aten.transpose,
    aten.expand,
    aten._unsafe_view,
    aten.select,
    aten.split,
    aten.split_with_sizes,
    aten.empty,
    aten.empty_strided,
    aten.empty_like,
    aten.is_same_size,
]
# Ops mapped to `no_flop` in `flop_mapping`; their IO is still estimated
NO_FLOPS_OPS = [
    aten._reshape_alias,
    aten.reshape,
    aten.clone,
    aten.cat,
    aten.select_backward,
    aten.slice,
    aten.slice_backward,
    aten.ones,
    aten.ones_like,
    aten.zeros_like,
    aten.zero_,
    aten.zeros,
    aten.masked_fill,
    aten.masked_fill_,
]
# Operator packet -> callable `(inputs, outputs) -> flop count`.
# Ops missing from this table fall back to `guess_flops_unknown_op`.
flop_mapping = {
    aten.mv: GemmOpComputeFlopsMv(),  # mat-vec
    aten.mm: GemmOpComputeFlops(),
    aten.matmul: GemmOpComputeFlops(),
    aten.addmm: GemmOpComputeFlopsAddmm(),
    aten.bmm: GemmOpComputeFlopsBmm(),
    aten.addbmm: GemmOpComputeFlopsAddbmm(),
    aten.linear: GemmOpComputeFlopsLinear(),
    aten.convolution: conv_flop,
    aten._convolution: conv_flop,
    aten.convolution_backward: conv_backward_flop,
    # Operations with 0 flop
    **{op: no_flop for op in NO_FLOPS_OPS},
    **{op: no_flop for op in NO_FLOPS_NO_IO_OPS},
}
# Operator packet -> callable `(inputs, outputs) -> bytes read and written`.
# Ops missing from this table fall back to `operation_memory_rw_bytes`.
io_mapping = {
    aten.clone: output_read_from_input,
    aten.cat: output_read_from_input,
    aten.slice: output_read_from_input,
    aten.ones_like: output_total_size,
    aten.zeros_like: output_total_size,
    aten.zero_: input_total_size,
    **{op: no_io for op in NO_FLOPS_NO_IO_OPS}
    # TODO: Check how this is implemented in PT
    # aten.slice_backward: no_flop,
    # aten.select_backward: no_flop,
}
@dataclass
class _OpInfo:
    """Measurements and estimates collected for one dispatched operation."""

    flop_count: float = 0.0
    time_ms: float = 0.0
    io_bytes: int = 0
    is_exact_flop: bool = True
    op_name: str = ""
    op_suffix: str = ""
    stacktrace: Tuple[str, ...] = field(default_factory=tuple)
    # Timing events recorded right before / right after the operation
    ev_start: torch.cuda.Event = field(
        default_factory=lambda: torch.cuda.Event(enable_timing=True)
    )
    ev_end: torch.cuda.Event = field(
        default_factory=lambda: torch.cuda.Event(enable_timing=True)
    )

    # Hardware limits for this operation (inf if unknown)
    hardware_tflops_limit: float = math.inf
    hardware_membw_limit: float = math.inf

    @property
    def time_membound_ms(self) -> float:
        """Time needed to move ``io_bytes`` at the bandwidth limit, capped
        to the measured time."""
        assert self.time_ms > 0.0
        if self.io_bytes == 0:
            return 0.0
        transfer_ms = 1000 * self.io_bytes / self.hardware_membw_limit
        return min(self.time_ms, transfer_ms)

    @property
    def time_computebound_ms(self) -> float:
        """Time needed to run ``flop_count`` at the throughput limit, capped
        to the measured time."""
        assert self.time_ms > 0.0
        tflop = self.flop_count / (1000**4)
        if tflop == 0.0:
            return 0.0
        compute_ms = 1000 * tflop / self.hardware_tflops_limit
        return min(self.time_ms, compute_ms)

    def finalize(self) -> None:
        """Read the elapsed time between the two events into ``time_ms``."""
        elapsed_ms = self.ev_start.elapsed_time(self.ev_end)
        self.time_ms = elapsed_ms
@dataclass
class _OpInfoAggregated:
    """Running totals over a group of `_OpInfo` (one module or one op name)."""

    # True only as long as every aggregated op has an exact FLOP count
    is_exact_flop: bool = True
    total_flop_count: float = 0.0
    total_io_bytes: int = 0
    total_time_ms: float = 0.0
    total_time_membound_ms: float = 0.0
    total_time_computebound_ms: float = 0.0
    num: int = 0
    # Stacktrace of every aggregated op, in insertion order
    stacktraces: List[Tuple[str, ...]] = field(default_factory=list)

    def add(self, op: _OpInfo) -> None:
        """Accumulate the figures of ``op`` into the totals."""
        self.total_flop_count += op.flop_count
        self.total_time_ms += op.time_ms
        self.total_io_bytes += op.io_bytes
        self.total_time_membound_ms += op.time_membound_ms
        self.total_time_computebound_ms += op.time_computebound_ms
        self.num += 1
        # A single guessed FLOP count makes the whole total inexact; the flag
        # must not be reset by a later op with an exact count.
        self.is_exact_flop = self.is_exact_flop and op.is_exact_flop
        self.stacktraces.append(op.stacktrace)

    def as_dict(self, **kwargs) -> Dict[str, Any]:
        """Return the totals and derived ratios as a dict.

        ``mem_bound`` and ``compute_bound`` are fractions of the total time,
        capped to 1. Extra keyword arguments are merged into the result.
        Must only be called once at least one op with a non-zero time has
        been added (the total time is used as a divisor).
        """
        mem_bound = min(1, self.total_time_membound_ms / self.total_time_ms)
        tflops = self.total_flop_count / (self.total_time_ms / 1000) / (1000**4)
        compute_bound = min(1, self.total_time_computebound_ms / self.total_time_ms)
        return {
            "is_exact_flop": self.is_exact_flop,
            "total_flop_count": self.total_flop_count,
            "total_time_ms": self.total_time_ms,
            "total_io_bytes": self.total_io_bytes,
            "num": self.num,
            "Tflops": tflops,
            "mem_bound": mem_bound,
            "compute_bound": compute_bound,
            **kwargs,
        }
class DetectSlowOpsProfiler(DispatcherWithoutBrokenFuncs):
    """
    Inspired from https://fb.workplace.com/groups/pytorch.dev/permalink/1054537595124720/

    Dispatch mode that times every dispatched operation with CUDA events,
    estimates its FLOP count and memory traffic, and writes the figures
    aggregated per module and per op name to ``ops.json`` on exit.
    """

    def __init__(self, main_profiler: _Profiler) -> None:
        # Let the dispatch-mode base class set up its own state
        super().__init__()
        self.main_profiler = main_profiler
        self.trace: List[_OpInfo] = []
        # True while FLOP/IO estimators run, so the ops they call are skipped
        self.temp_disabled = False

    def _hardware_tflops_membw_limit(
        self, args: Tuple[Any, ...], outputs: Tuple[Any, ...]
    ) -> Tuple[float, float]:
        """Return ``(TFlop/s limit, memory bandwidth limit)`` for an op.

        The device is the one of the first tensor found in the outputs, then
        in the arguments; the dtype is the first one known to the device
        table. Both limits are ``inf`` when nothing is known.
        """
        device = None
        dtypes: List[torch.dtype] = []
        for a in itertools.chain(outputs, args):
            if isinstance(a, torch.Tensor):
                if device is None:
                    device = a.device
                dtypes.append(a.dtype)
        limits = get_device_limits(device)
        dtypes = [dt for dt in dtypes if dt in limits.gemm_tflops]
        if not dtypes or device is None:
            return (math.inf, math.inf)
        dtype = dtypes[0]
        if torch.is_autocast_enabled() and dtype is torch.float32:
            dtype = torch.get_autocast_gpu_dtype()
        # The autocast dtype may be missing from the table of this device
        return limits.gemm_tflops.get(dtype, math.inf), limits.gmem_bandwidth

    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        """Run ``func`` between two timing events and record its estimates."""
        if kwargs is None:
            kwargs = {}
        func_packet = func._overloadpacket
        if self.temp_disabled or func_packet.__name__ in [
            "_record_function_exit",
            "_record_function_enter_new",
        ]:
            return func(*args, **kwargs)

        op = _OpInfo()
        op.ev_start.record()
        out = func(*args, **kwargs)
        op.ev_end.record()
        # Estimators always receive the outputs as a tuple
        outputs = out if isinstance(out, tuple) else (out,)
        (
            op.hardware_tflops_limit,
            op.hardware_membw_limit,
        ) = self._hardware_tflops_membw_limit(args, outputs)
        op.op_name = func_packet.__name__
        # Prevent functions called by flop counting ops to be recorded
        self.temp_disabled = True
        try:
            flop_count = -1
            compute_flops = None
            if func_packet in FUNC_TO_XFORMERS_OPERATOR:
                flop_count = FUNC_TO_XFORMERS_OPERATOR[func_packet].operator_flop(
                    *args, **kwargs
                )
            # -1 means that no xformers operator provided a count
            if flop_count == -1:
                compute_flops = flop_mapping.get(func_packet, guess_flops_unknown_op)
                flop_count = compute_flops(args, outputs)
                if isinstance(compute_flops, GemmOpComputeFlops):
                    op.op_name += compute_flops.op_suffix(args)

            compute_io = io_mapping.get(func_packet, operation_memory_rw_bytes)
            op.io_bytes = compute_io(args, outputs)
        finally:
            # Recording must resume even when an estimator fails
            self.temp_disabled = False

        op.stacktrace = tuple(self.main_profiler.parents)
        op.flop_count = flop_count
        op.is_exact_flop = compute_flops is not guess_flops_unknown_op
        self.trace.append(op)

        return out

    def __enter__(self):
        self.main_profiler._install_hooks()
        super().__enter__()

    def __exit__(self, exc_type, exc_val, exc_tb):
        super().__exit__(exc_type, exc_val, exc_tb)
        self.main_profiler._remove_hooks()
        torch.cuda.synchronize()  # Wait for the events to be recorded
        for op in self.trace:
            op.finalize()
        self.save_json()

    def step(self) -> None:
        """Nothing to do between steps."""
        pass

    def save_json(self) -> None:
        """Aggregate the trace per module and per op name, and dump it as
        JSON in the ``ops.json`` output file of the main profiler."""
        # Aggregate data at the module + op level
        all_paths: Set[Tuple[str, ...]] = set()
        per_module_data: Dict[Tuple[str, ...], _OpInfoAggregated] = defaultdict(
            _OpInfoAggregated
        )
        per_op_data: Dict[str, _OpInfoAggregated] = defaultdict(_OpInfoAggregated)
        for op in self.trace:
            all_paths.add(op.stacktrace)
        for op in self.trace:
            # An op counts for each prefix of its stacktrace that is itself
            # the stacktrace of a recorded op
            for depth in range(1, len(op.stacktrace) + 1):
                prefix = op.stacktrace[:depth]
                if prefix in all_paths:
                    per_module_data[prefix].add(op)
            per_op_data[op.op_name].add(op)

        # Generate JSON
        all_data = []
        for stacktrace, agg_info in per_module_data.items():
            all_data.append(
                agg_info.as_dict(
                    agg="module", path=stacktrace, name=stacktrace[-1], op=""
                )
            )
        for op_name, agg_info in per_op_data.items():
            # Find the most common path
            paths_count: Dict[Tuple[str, ...], int] = defaultdict(int)
            agg_info.stacktraces.sort()  # In case of a draw, let's always return the same
            for p in agg_info.stacktraces:
                paths_count[p] += 1
            maxp = agg_info.stacktraces[0]
            for p, count in paths_count.items():
                if count > paths_count[maxp]:
                    maxp = p
            all_data.append(
                agg_info.as_dict(
                    agg="opname",
                    path=f"{'.'.join(maxp)} (x{paths_count[maxp]})",
                    name="",
                    op=op_name,
                )
            )

        filename = self.main_profiler._create_output_filename("ops.json")
        self.main_profiler.summary.append(("OpsSummary", str(filename)))
        with open(filename, "w+") as f:
            json.dump(all_data, f)