# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
|
#
|
|
# This source code is licensed under the BSD license found in the
|
|
# LICENSE file in the root directory of this source tree.
|
|
|
|
import itertools
|
|
import json
|
|
import math
|
|
from collections import defaultdict
|
|
from dataclasses import dataclass, field
|
|
from functools import partial
|
|
from typing import Any, Dict, List, Set, Tuple
|
|
|
|
import torch.cuda.memory
|
|
import torch.cuda.nvtx
|
|
import torch.profiler
|
|
import torch.utils.hooks
|
|
from torch.utils._python_dispatch import TorchDispatchMode, _pop_mode_temporarily
|
|
from torch.utils._pytree import tree_map
|
|
|
|
from ..ops.common import FUNC_TO_XFORMERS_OPERATOR
|
|
from .device_limits import get_device_limits
|
|
from .profiler import _Profiler
|
|
|
|
|
|
class TorchFuncMockNoDispatch:
    """
    Descriptor that wraps a `torch.Tensor` method so that, when called,
    the custom pytorch dispatch mode is temporarily popped off first.
    """

    def __init__(self, pt_impl):
        # Original (unpatched) implementation of the tensor method.
        self.pt_impl = pt_impl

    def __get__(self, obj, objtype):
        # Emulate bound-method behavior: pre-bind the instance.
        return partial(self, obj)

    def __call__(self, obj, *args, **kwargs):
        # Run the real implementation with the dispatcher disabled.
        with _pop_mode_temporarily():
            return self.pt_impl(obj, *args, **kwargs)
|
|
|
|
|
|
class DispatcherWithoutBrokenFuncs(TorchDispatchMode):
    """
    `TorchDispatchMode` that, while active, patches a few `torch.Tensor`
    methods which cannot go through the Python dispatcher so that they
    bypass it (via `TorchFuncMockNoDispatch`). The original methods are
    restored on exit.
    """

    TENSOR_FUNCS_NO_DISPATCH = [
        # Can't convert Stream argument to Python object
        # https://github.com/pytorch/pytorch/issues/94403
        "record_stream"
    ]

    # NOTE: fixed annotation — `TorchDispatchMode.__enter__` returns the
    # mode object (via `return super().__enter__()`), not None.
    def __enter__(self) -> "DispatcherWithoutBrokenFuncs":
        # Save the original implementations, then install the bypass wrappers.
        self._pt_impls = {}
        for k in self.TENSOR_FUNCS_NO_DISPATCH:
            impl = getattr(torch.Tensor, k)
            self._pt_impls[k] = impl
            setattr(torch.Tensor, k, TorchFuncMockNoDispatch(impl))
        return super().__enter__()

    def __exit__(self, exc_type, exc_val, exc_tb) -> None:
        # Restore the original `torch.Tensor` methods before leaving the mode.
        for k in self.TENSOR_FUNCS_NO_DISPATCH:
            setattr(torch.Tensor, k, self._pt_impls[k])
        return super().__exit__(exc_type, exc_val, exc_tb)
|
|
|
|
|
|
def get_shape(i):
    """Return the `.shape` attribute of *i* (tensor-like)."""
    return i.shape
|
|
|
|
|
|
def prod(x):
    """Product of the elements of iterable *x*; 1 for an empty iterable.

    Uses stdlib `math.prod` instead of a hand-rolled loop (same semantics,
    including the empty-input identity of 1).
    """
    return math.prod(x)
|
|
|
|
|
|
class GemmOpComputeFlops:
    """Flop formula for GEMM-like ops (`aten.mm` / `aten.matmul`): 2*m*n*k."""

    def _get_mnk(self, inputs: List[Any]) -> Tuple[int, int, int]:
        # m: rows of A (all leading dims collapsed), n: cols of B, k: inner dim.
        a, b = inputs[0], inputs[1]
        m = math.prod(a.shape[:-1])
        return (m, b.shape[1], a.shape[-1])

    def __call__(self, inputs: List[Any], outputs: List[Any]) -> float:
        # One multiply + one add per inner-product element.
        m, n, k = self._get_mnk(inputs)
        return 2 * m * n * k

    def op_suffix(self, inputs: List[Any]) -> str:
        # Suffix appended to the op name so different problem sizes
        # aggregate separately.
        m, n, k = self._get_mnk(inputs)
        return f"_{m}x{n}x{k}"
|
|
|
|
|
|
class GemmOpComputeFlopsLinear(GemmOpComputeFlops):
    """Flops for `aten.linear`: the weight is stored as (out_features, in_features)."""

    def _get_mnk(self, inputs: List[Any]) -> Tuple[int, int, int]:
        x, weight = inputs[0], inputs[1]
        # n comes from weight.shape[0] because the weight is transposed
        # relative to a plain mm.
        return (math.prod(x.shape[:-1]), weight.shape[0], x.shape[-1])
|
|
|
|
|
|
class GemmOpComputeFlopsMv(GemmOpComputeFlops):
    """Flops for `aten.mv` (matrix-vector): n is always 1."""

    def _get_mnk(self, inputs: List[Any]) -> Tuple[int, int, int]:
        mat = inputs[0]
        return (math.prod(mat.shape[:-1]), 1, mat.shape[-1])
|
|
|
|
|
|
class GemmOpComputeFlopsBmm(GemmOpComputeFlops):
    """Flops for `aten.bmm` (batched matmul on rank-3 tensors)."""

    def _get_mnk(self, inputs: List[Any]) -> Tuple[int, int, int]:
        mat_a, mat_b = inputs[0], inputs[1]
        assert mat_a.ndim == 3
        assert mat_b.ndim == 3
        # max() to tolerate a broadcast batch dim of 1 on either side.
        batch = max(mat_a.shape[0], mat_b.shape[0])
        return (batch * mat_a.shape[1], mat_b.shape[-1], mat_b.shape[-2])
|
|
|
|
|
|
class GemmOpComputeFlopsAddmm(GemmOpComputeFlops):
    """Flops for `aten.addmm`: skip the bias input, then count as mm."""

    def _get_mnk(self, inputs: List[Any]) -> Tuple[int, int, int]:
        # inputs[0] is the additive bias; the add itself is not counted.
        gemm_inputs = inputs[1:]
        return super()._get_mnk(gemm_inputs)
|
|
|
|
|
|
class GemmOpComputeFlopsAddbmm(GemmOpComputeFlopsBmm):
    """Flops for `aten.addbmm`: skip the bias input, then count as bmm."""

    def _get_mnk(self, inputs: List[Any]) -> Tuple[int, int, int]:
        # inputs[0] is the additive bias; the add itself is not counted.
        return super()._get_mnk(inputs[1:])
|
|
|
|
|
|
def conv_flop_count(
    x_shape: List[int],
    w_shape: List[int],
    out_shape: List[int],
    transposed: bool = False,
) -> float:
    """
    Count flops for convolution. Note only multiplication is
    counted. Computation for addition and bias is ignored.
    Flops for a transposed convolution are calculated as
    flops = (x_shape[2:] * prod(w_shape) * batch_size).
    Args:
        x_shape (list(int)): The input shape before convolution.
        w_shape (list(int)): The filter shape.
        out_shape (list(int)): The output shape after convolution.
        transposed (bool): is the convolution transposed
    Returns:
        int: the number of flops
    """
    batch_size = x_shape[0]
    # Each spatial output (or input, for transposed conv) position costs
    # one multiply per filter element.
    if transposed:
        spatial_shape = x_shape[2:]
    else:
        spatial_shape = out_shape[2:]
    return batch_size * math.prod(w_shape) * math.prod(spatial_shape)
|
|
|
|
|
|
def conv_flop(inputs: List[Any], outputs: List[Any]):
    """
    Count flops for convolution (`aten.convolution` / `aten._convolution`).
    """
    x, w = inputs[0], inputs[1]
    # Positional arg 6 of aten.convolution is the `transposed` flag.
    transposed = inputs[6]

    return conv_flop_count(
        get_shape(x),
        get_shape(w),
        get_shape(outputs[0]),
        transposed=transposed,
    )
|
|
|
|
|
|
def transpose_shape(shape):
    """Return *shape* as a list with its first two dims swapped."""
    trailing = list(shape[2:])
    return [shape[1], shape[0]] + trailing
|
|
|
|
|
|
def conv_backward_flop(inputs: List[Any], outputs: List[Any]):
    """
    Count flops for `aten.convolution_backward`: each requested gradient
    (per `output_mask`) is itself a convolution.
    """
    grad_out_shape = get_shape(inputs[0])
    x_shape = get_shape(inputs[1])
    w_shape = get_shape(inputs[2])
    output_mask = inputs[-1]
    fwd_transposed = inputs[7]
    flop_count = 0.0

    # grad w.r.t. input: conv of grad_out with the filter, with
    # "transposedness" inverted relative to the forward pass.
    if output_mask[0]:
        flop_count += conv_flop_count(
            grad_out_shape,
            w_shape,
            get_shape(outputs[0]),
            not fwd_transposed,
        )
    # grad w.r.t. weight: conv of (transposed) input with grad_out.
    if output_mask[1]:
        flop_count += conv_flop_count(
            transpose_shape(x_shape),
            grad_out_shape,
            get_shape(outputs[1]),
            fwd_transposed,
        )

    return flop_count
|
|
|
|
|
|
def tensor_storage_size_in_mem(x: torch.Tensor):
    """
    Approximate element count of the memory actually backing *x*.
    Dimensions with stride 0 (broadcast/expanded views) share storage,
    so they are not multiplied in.
    """
    count = 1
    for size, stride in zip(x.shape, x.stride()):
        if stride < 1:
            continue
        count *= size
    return count
|
|
|
|
|
|
def get_size(inputs: List[Any]):
    """Total bytes backing all tensors found in the (possibly nested) *inputs*."""
    sizes: List[int] = []

    def accumulate(x) -> None:
        if isinstance(x, torch.Tensor):
            sizes.append(tensor_storage_size_in_mem(x) * x.element_size())

    # tree_map walks arbitrarily nested lists/tuples/dicts of tensors.
    tree_map(accumulate, inputs)
    return sum(sizes)
|
|
|
|
|
|
def operation_memory_rw_bytes(inputs: List[Any], outputs: List[Any]):
    """Default io model: read every input byte once, write every output byte once."""
    return get_size(inputs) + get_size(outputs)
|
|
|
|
|
|
def output_read_from_input(inputs: List[Any], outputs: List[Any]):
    """
    Io model for copy-like ops (`clone`, `cat`, `slice`): the output is
    written in full, and at most that many bytes are read from the inputs.
    """
    in_bytes = get_size(inputs)
    out_bytes = get_size(outputs)
    return out_bytes + min(in_bytes, out_bytes)
|
|
|
|
|
|
def output_total_size(inputs: List[Any], outputs: List[Any]):
    """Io model for fill-like ops (`ones_like`, ...): only the output is written."""
    return get_size(outputs)
|
|
|
|
|
|
def input_total_size(inputs: List[Any], outputs: List[Any]):
    """Io model for in-place fills (`zero_`): only the input is touched."""
    return get_size(inputs)
|
|
|
|
|
|
def guess_flops_unknown_op(inputs: List[Any], outputs: List[Any]):
    """
    Rough flop estimate for ops with no dedicated formula:
    half of the total element count across inputs and outputs.
    """
    counts: List[int] = []

    def tally(x) -> None:
        if isinstance(x, torch.Tensor):
            counts.append(x.numel())

    tree_map(tally, inputs)
    tree_map(tally, outputs)
    return sum(counts) / 2
|
|
|
|
|
|
def no_flop(inputs: List[Any], outputs: List[Any]):
    """Flop formula for ops that perform no arithmetic."""
    return 0
|
|
|
|
|
|
def no_io(inputs: List[Any], outputs: List[Any]):
    """Io formula for metadata-only ops that move no memory."""
    return 0
|
|
|
|
|
|
aten = torch.ops.aten

# Ops that neither compute (0 flop) nor move memory (0 io bytes):
# metadata-only views, shape queries and fresh allocations.
NO_FLOPS_NO_IO_OPS = [
    aten.permute,
    aten.view,
    aten.view_as,
    aten.detach,
    aten.t,
    aten.transpose,
    aten.expand,
    aten._unsafe_view,
    aten.select,
    aten.split,
    aten.split_with_sizes,
    aten.empty,
    aten.empty_strided,
    aten.empty_like,
    aten.is_same_size,
]
# Ops that compute nothing (0 flop) but DO move memory — their io cost
# comes from `io_mapping` below (or the default read+write model).
NO_FLOPS_OPS = [
    aten._reshape_alias,
    aten.reshape,
    aten.clone,
    aten.cat,
    aten.select_backward,
    aten.slice,
    aten.slice_backward,
    aten.ones,
    aten.ones_like,
    aten.zeros_like,
    aten.zero_,
    aten.zeros,
    aten.masked_fill,
    aten.masked_fill_,
]

# OverloadPacket -> callable(inputs, outputs) returning the flop count.
# Ops absent from this table fall back to `guess_flops_unknown_op`.
flop_mapping = {
    aten.mv: GemmOpComputeFlopsMv(),  # mat-vec
    aten.mm: GemmOpComputeFlops(),
    aten.matmul: GemmOpComputeFlops(),
    aten.addmm: GemmOpComputeFlopsAddmm(),
    aten.bmm: GemmOpComputeFlopsBmm(),
    aten.addbmm: GemmOpComputeFlopsAddbmm(),
    aten.linear: GemmOpComputeFlopsLinear(),
    aten.convolution: conv_flop,
    aten._convolution: conv_flop,
    aten.convolution_backward: conv_backward_flop,
    # Operations with 0 flop
    **{op: no_flop for op in NO_FLOPS_OPS},
    **{op: no_flop for op in NO_FLOPS_NO_IO_OPS},
}
# OverloadPacket -> callable(inputs, outputs) returning io bytes.
# Ops absent from this table fall back to `operation_memory_rw_bytes`.
io_mapping = {
    aten.clone: output_read_from_input,
    aten.cat: output_read_from_input,
    aten.slice: output_read_from_input,
    aten.ones_like: output_total_size,
    aten.zeros_like: output_total_size,
    aten.zero_: input_total_size,
    **{op: no_io for op in NO_FLOPS_NO_IO_OPS}
    # TODO: Check how this is implemented in PT
    # aten.slice_backward: no_flop,
    # aten.select_backward: no_flop,
}
|
|
|
|
|
|
@dataclass
class _OpInfo:
    """Measurements collected for a single dispatched operator call."""

    # Flop count for the call (exact or estimated — see `is_exact_flop`).
    flop_count: float = 0.0
    # Wall-clock duration; filled in by `finalize()` after synchronization.
    time_ms: float = 0.0
    # Bytes read + written according to the io model chosen for this op.
    io_bytes: int = 0
    # False when `flop_count` came from `guess_flops_unknown_op`.
    is_exact_flop: bool = True
    op_name: str = ""
    op_suffix: str = ""
    # Labels from the main profiler's `parents` stack at dispatch time.
    stacktrace: Tuple[str, ...] = field(default_factory=tuple)
    # CUDA events recorded immediately before/after the op, used for timing.
    ev_start: torch.cuda.Event = field(
        default_factory=lambda: torch.cuda.Event(enable_timing=True)
    )
    ev_end: torch.cuda.Event = field(
        default_factory=lambda: torch.cuda.Event(enable_timing=True)
    )

    # Hardware limits for this operation (inf if unknown)
    hardware_tflops_limit: float = math.inf
    hardware_membw_limit: float = math.inf

    @property
    def time_membound_ms(self) -> float:
        # Time the op would take if limited only by memory bandwidth,
        # capped at the measured runtime.
        assert self.time_ms > 0.0
        if self.io_bytes == 0:
            return 0.0
        return min(self.time_ms, 1000 * self.io_bytes / self.hardware_membw_limit)

    @property
    def time_computebound_ms(self) -> float:
        # Time the op would take if limited only by peak compute,
        # capped at the measured runtime.
        assert self.time_ms > 0.0
        tflop = self.flop_count / (1000**4)
        if tflop == 0.0:
            return 0.0
        return min(self.time_ms, 1000 * tflop / self.hardware_tflops_limit)

    def finalize(self) -> None:
        # Requires the events to be complete (caller synchronizes first).
        self.time_ms = self.ev_start.elapsed_time(self.ev_end)
|
|
|
|
|
|
@dataclass
|
|
class _OpInfoAggregated:
|
|
is_exact_flop: bool = True
|
|
total_flop_count: float = 0.0
|
|
total_io_bytes: int = 0
|
|
total_time_ms: float = 0.0
|
|
total_time_membound_ms: float = 0.0
|
|
total_time_computebound_ms: float = 0.0
|
|
num: int = 0
|
|
stacktraces: List[Tuple[str, ...]] = field(default_factory=list)
|
|
|
|
def add(self, op: _OpInfo) -> None:
|
|
self.total_flop_count += op.flop_count
|
|
self.total_time_ms += op.time_ms
|
|
self.total_io_bytes += op.io_bytes
|
|
self.total_time_membound_ms += op.time_membound_ms
|
|
self.total_time_computebound_ms += op.time_computebound_ms
|
|
self.num += 1
|
|
self.is_exact_flop = op.is_exact_flop
|
|
self.stacktraces.append(op.stacktrace)
|
|
|
|
def as_dict(self, **kwargs) -> Dict[str, Any]:
|
|
mem_bound = min(1, self.total_time_membound_ms / self.total_time_ms)
|
|
tflops = self.total_flop_count / (self.total_time_ms / 1000) / (1000**4)
|
|
compute_bound = min(1, self.total_time_computebound_ms / self.total_time_ms)
|
|
return {
|
|
"is_exact_flop": self.is_exact_flop,
|
|
"total_flop_count": self.total_flop_count,
|
|
"total_time_ms": self.total_time_ms,
|
|
"total_io_bytes": self.total_io_bytes,
|
|
"num": self.num,
|
|
"Tflops": tflops,
|
|
"mem_bound": mem_bound,
|
|
"compute_bound": compute_bound,
|
|
**kwargs,
|
|
}
|
|
|
|
|
|
class DetectSlowOpsProfiler(DispatcherWithoutBrokenFuncs):
    """
    Dispatch-mode profiler: times every dispatched op with CUDA events,
    attaches a flop count and an io-bytes estimate, then writes a
    per-module / per-op-name summary to an `ops.json` file on exit.

    Inspired from https://fb.workplace.com/groups/pytorch.dev/permalink/1054537595124720/
    """

    def __init__(self, main_profiler: _Profiler) -> None:
        self.main_profiler = main_profiler
        # One `_OpInfo` per recorded dispatch, in call order.
        self.trace: List[_OpInfo] = []
        # Re-entrancy guard: set while the flop/io formulas themselves run.
        self.temp_disabled = False

    def _hardware_tflops_membw_limit(
        self, args: Tuple[Any, ...], outputs: Tuple[Any, ...]
    ) -> Tuple[float, float]:
        """Look up (peak TFlops, peak memory bandwidth) for the tensors involved.

        Returns (inf, inf) when no tensor/dtype can be matched to device limits.
        """
        device = None
        dtypes: List[torch.dtype] = []
        # First tensor seen (outputs take precedence) determines the device.
        for a in itertools.chain(outputs, args):
            if isinstance(a, torch.Tensor):
                if device is None:
                    device = a.device
                dtypes.append(a.dtype)
        # NOTE(review): called with `device=None` when no tensor was found —
        # presumably `get_device_limits` handles that; confirm.
        limits = get_device_limits(device)
        # Keep only dtypes the device has published GEMM numbers for.
        dtypes = [dt for dt in dtypes if dt in limits.gemm_tflops]
        if not dtypes or device is None:
            return (math.inf, math.inf)
        dtype = dtypes[0]
        # Under autocast, fp32 GEMMs actually run in the autocast dtype.
        if torch.is_autocast_enabled() and dtype is torch.float32:
            dtype = torch.get_autocast_gpu_dtype()
        return limits.gemm_tflops[dtype], limits.gmem_bandwidth

    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        """Run `func`, timing it and recording flop/io info into `self.trace`."""
        if kwargs is None:
            kwargs = {}
        func_packet = func._overloadpacket
        # Skip re-entrant calls and the profiler's own record_function ops.
        if self.temp_disabled or func_packet.__name__ in [
            "_record_function_exit",
            "_record_function_enter_new",
        ]:
            return func(*args, **kwargs)

        # Bracket the real op with CUDA events; elapsed time is read later
        # in `finalize()` after a synchronize.
        op = _OpInfo()
        op.ev_start.record()
        out = func(*args, **kwargs)
        op.ev_end.record()

        (
            op.hardware_tflops_limit,
            op.hardware_membw_limit,
        ) = self._hardware_tflops_membw_limit(
            args, out if isinstance(out, tuple) else (out,)
        )
        op.op_name = func_packet.__name__
        # Prevent functions called by flop counting ops to be recorded
        self.temp_disabled = True
        flop_count = -1
        compute_flops = None
        # xFormers operators can report their own flop count (-1 = unknown).
        if func_packet in FUNC_TO_XFORMERS_OPERATOR:
            flop_count = FUNC_TO_XFORMERS_OPERATOR[func_packet].operator_flop(
                *args, **kwargs
            )
        if flop_count == -1:
            # Fall back to the per-op formula table, then to the estimate.
            compute_flops = flop_mapping.get(func_packet, guess_flops_unknown_op)
            flop_count = compute_flops(args, out if isinstance(out, tuple) else (out,))
            if isinstance(compute_flops, GemmOpComputeFlops):
                # Tag GEMMs with their problem size (e.g. "mm_128x64x256").
                op.op_name += compute_flops.op_suffix(args)

        compute_io = io_mapping.get(func_packet, operation_memory_rw_bytes)
        op.io_bytes = compute_io(args, out if isinstance(out, tuple) else (out,))
        self.temp_disabled = False

        op.stacktrace = tuple(self.main_profiler.parents)
        op.flop_count = flop_count
        op.is_exact_flop = compute_flops is not guess_flops_unknown_op
        self.trace.append(op)

        return out

    def __enter__(self):
        # NOTE: does not return self (callers must not rely on `with ... as x`).
        self.main_profiler._install_hooks()
        super().__enter__()

    def __exit__(self, exc_type, exc_val, exc_tb):
        super().__exit__(exc_type, exc_val, exc_tb)
        self.main_profiler._remove_hooks()
        torch.cuda.synchronize()  # Wait for the events to be recorded
        for op in self.trace:
            op.finalize()
        self.save_json()

    def step(self) -> None:
        # Nothing to do per-step; data is aggregated once on exit.
        pass

    def save_json(self) -> None:
        """Aggregate the trace per module path and per op name; dump to JSON."""
        # Aggregate data at the module + op level
        all_paths: Set[Tuple[str, ...]] = set()
        per_module_data: Dict[Tuple[str, ...], _OpInfoAggregated] = defaultdict(
            _OpInfoAggregated
        )
        per_op_data: Dict[str, _OpInfoAggregated] = defaultdict(_OpInfoAggregated)
        for op in self.trace:
            all_paths.add(op.stacktrace)
        for op in self.trace:
            # Credit the op to every ancestor path that is itself a full
            # stacktrace of some recorded op.
            for i in range(len(op.stacktrace)):
                if op.stacktrace[: i + 1] in all_paths:
                    per_module_data[op.stacktrace[: i + 1]].add(op)
            per_op_data[op.op_name].add(op)

        # Generate JSON
        all_data = []
        for stacktrace, agg_info in per_module_data.items():
            all_data.append(
                agg_info.as_dict(
                    agg="module", path=stacktrace, name=stacktrace[-1], op=""
                )
            )
        for op_name, agg_info in per_op_data.items():
            # Find the most common path
            paths_count: Dict[Tuple[str, ...], int] = defaultdict(int)
            agg_info.stacktraces.sort()  # In case of a draw, let's always return the same
            for p in agg_info.stacktraces:
                paths_count[p] += 1
            maxp = agg_info.stacktraces[0]
            for p, count in paths_count.items():
                if count > paths_count[maxp]:
                    maxp = p
            all_data.append(
                agg_info.as_dict(
                    agg="opname",
                    path=f"{'.'.join(maxp)} (x{paths_count[maxp]})",
                    name="",
                    op=op_name,
                )
            )

        filename = self.main_profiler._create_output_filename("ops.json")
        self.main_profiler.summary.append(("OpsSummary", str(filename)))
        with open(filename, "w+") as f:
            json.dump(all_data, f)
|