Files
enginex-bi_series-vllm/pkgs/xformers/profiler/slow_ops_profiler.py
2025-08-05 19:02:46 +08:00

514 lines
16 KiB
Python

# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
import itertools
import json
import math
from collections import defaultdict
from dataclasses import dataclass, field
from functools import partial
from typing import Any, Dict, List, Set, Tuple
import torch.cuda.memory
import torch.cuda.nvtx
import torch.profiler
import torch.utils.hooks
from torch.utils._python_dispatch import TorchDispatchMode, _pop_mode_temporarily
from torch.utils._pytree import tree_map
from ..ops.common import FUNC_TO_XFORMERS_OPERATOR
from .device_limits import get_device_limits
from .profiler import _Profiler
class TorchFuncMockNoDispatch:
    """
    Descriptor wrapping a method so that it is invoked with the
    currently active torch dispatch mode temporarily popped.
    """

    def __init__(self, pt_impl):
        # Implementation that the wrapper forwards to
        self.pt_impl = pt_impl

    def __get__(self, obj, c):
        # Attribute access yields a callable with `obj` pre-bound as the
        # first argument of `__call__`.
        bound = partial(self, obj)
        return bound

    def __call__(self, obj, *args, **kwargs):
        with _pop_mode_temporarily():
            result = self.pt_impl(obj, *args, **kwargs)
        return result
class DispatcherWithoutBrokenFuncs(TorchDispatchMode):
    """
    Dispatch mode that, while active, replaces the ``torch.Tensor``
    methods named in ``TENSOR_FUNCS_NO_DISPATCH`` with
    ``TorchFuncMockNoDispatch`` wrappers, and puts the original
    attributes back on exit.
    """

    TENSOR_FUNCS_NO_DISPATCH = [
        # Can't convert Stream argument to Python object
        # https://github.com/pytorch/pytorch/issues/94403
        "record_stream"
    ]

    def __enter__(self) -> None:
        # Remember the original attributes so that __exit__ can restore them
        self._pt_impls = {
            name: getattr(torch.Tensor, name)
            for name in self.TENSOR_FUNCS_NO_DISPATCH
        }
        for name, original in self._pt_impls.items():
            setattr(torch.Tensor, name, TorchFuncMockNoDispatch(original))
        return super().__enter__()

    def __exit__(self, exc_type, exc_val, exc_tb) -> None:
        # Undo the patching before leaving the dispatch mode
        for name in self.TENSOR_FUNCS_NO_DISPATCH:
            original = self._pt_impls[name]
            setattr(torch.Tensor, name, original)
        return super().__exit__(exc_type, exc_val, exc_tb)
def get_shape(i):
    """Return the ``shape`` attribute of ``i``."""
    return i.shape
def prod(x):
    """Return the product of all items of the iterable ``x`` (1 if empty)."""
    total = 1
    for factor in x:
        total *= factor
    return total
class GemmOpComputeFlops:
    """
    Flop estimator for a matrix multiplication whose operands are
    ``inputs[0]`` and ``inputs[1]``.
    """

    def _get_mnk(self, inputs: List[Any]) -> Tuple[int, int, int]:
        """Return the (M, N, K) problem size read from the operand shapes."""
        lhs_shape = inputs[0].shape
        # Every dimension of the left operand except the last is folded into M
        m = prod(lhs_shape[:-1])
        n = inputs[1].shape[1]
        k = lhs_shape[-1]
        return (m, n, k)

    def __call__(self, inputs: List[Any], outputs: List[Any]) -> float:
        """Return the flop estimate: 2 flops per (m, n, k) term."""
        mnk = self._get_mnk(inputs)
        return 2 * prod(mnk)

    def op_suffix(self, inputs: List[Any]) -> str:
        """Return a ``_MxNxK`` suffix describing the problem size."""
        dims = self._get_mnk(inputs)
        return f"_{dims[0]}x{dims[1]}x{dims[2]}"
class GemmOpComputeFlopsLinear(GemmOpComputeFlops):
    """GEMM estimator where N is read from dim 0 of the second operand."""

    def _get_mnk(self, inputs: List[Any]) -> Tuple[int, int, int]:
        x_shape = inputs[0].shape
        m = prod(x_shape[:-1])
        n = inputs[1].shape[0]
        k = x_shape[-1]
        return (m, n, k)
class GemmOpComputeFlopsMv(GemmOpComputeFlops):
    """GEMM estimator for matrix-vector products (N is fixed to 1)."""

    def _get_mnk(self, inputs: List[Any]) -> Tuple[int, int, int]:
        mat_shape = inputs[0].shape
        m = prod(mat_shape[:-1])
        k = mat_shape[-1]
        return (m, 1, k)
class GemmOpComputeFlopsBmm(GemmOpComputeFlops):
    """GEMM estimator for batched products of two 3-dimensional operands."""

    def _get_mnk(self, inputs: List[Any]) -> Tuple[int, int, int]:
        a, b = inputs[0], inputs[1]
        assert a.ndim == 3
        assert b.ndim == 3
        # The larger of the two leading dims is used as the batch size
        # and folded into M.
        batch = max(a.shape[0], b.shape[0])
        m = batch * a.shape[1]
        n = b.shape[-1]
        k = b.shape[-2]
        return (m, n, k)
class GemmOpComputeFlopsAddmm(GemmOpComputeFlops):
    """GEMM estimator that ignores the first input and uses the next two."""

    def _get_mnk(self, inputs: List[Any]) -> Tuple[int, int, int]:
        # Skip inputs[0]; the matrix operands follow it
        operands = inputs[1:]
        return super()._get_mnk(operands)
class GemmOpComputeFlopsAddbmm(GemmOpComputeFlopsBmm):
    """Batched GEMM estimator that ignores the first input."""

    def _get_mnk(self, inputs: List[Any]) -> Tuple[int, int, int]:
        # Skip inputs[0]; the batched matrix operands follow it
        operands = inputs[1:]
        return super()._get_mnk(operands)
def conv_flop_count(
    x_shape: List[int],
    w_shape: List[int],
    out_shape: List[int],
    transposed: bool = False,
) -> float:
    """
    Estimate the flops of a convolution from its shapes.

    The estimate is ``batch_size * prod(w_shape) * prod(spatial)``, where
    ``batch_size`` is ``x_shape[0]`` and ``spatial`` is ``x_shape[2:]`` for
    a transposed convolution and ``out_shape[2:]`` otherwise. Additions
    and bias are not part of the formula.

    Args:
        x_shape (list(int)): The input shape before convolution.
        w_shape (list(int)): The filter shape.
        out_shape (list(int)): The output shape after convolution.
        transposed (bool): is the convolution transposed
    Returns:
        the estimated number of flops
    """
    if transposed:
        spatial_dims = x_shape[2:]
    else:
        spatial_dims = out_shape[2:]
    return x_shape[0] * prod(w_shape) * prod(spatial_dims)
def conv_flop(inputs: List[Any], outputs: List[Any]):
    """
    Estimate the flops of a convolution call from its arguments and outputs.

    The first two inputs are used as input and weight, ``outputs[0]`` as
    the result, and ``inputs[6]`` as the ``transposed`` flag.
    """
    x, w = inputs[:2]
    x_shape = get_shape(x)
    w_shape = get_shape(w)
    out_shape = get_shape(outputs[0])
    return conv_flop_count(x_shape, w_shape, out_shape, transposed=inputs[6])
def transpose_shape(shape):
    """Return ``shape`` as a list with its first two entries swapped."""
    first, second = shape[0], shape[1]
    return [second, first, *shape[2:]]
def conv_backward_flop(inputs: List[Any], outputs: List[Any]):
    """
    Estimate the flops of a convolution backward call.

    The first three inputs are used as grad_output, input and weight,
    ``inputs[7]`` as the forward ``transposed`` flag and ``inputs[-1]`` as
    the output mask. Only the gradients enabled in positions 0 and 1 of
    the mask contribute to the estimate.
    """
    grad_out_shape, x_shape, w_shape = (get_shape(t) for t in inputs[:3])
    output_mask = inputs[-1]
    fwd_transposed = inputs[7]

    total = 0.0
    if output_mask[0]:
        # Gradient w.r.t. the input: counted as a convolution with the
        # opposite "transposed" setting of the forward pass.
        total += conv_flop_count(
            grad_out_shape, w_shape, get_shape(outputs[0]), not fwd_transposed
        )
    if output_mask[1]:
        # Gradient w.r.t. the weight
        total += conv_flop_count(
            transpose_shape(x_shape),
            grad_out_shape,
            get_shape(outputs[1]),
            fwd_transposed,
        )
    return total
def tensor_storage_size_in_mem(x: torch.Tensor):
    """
    Return the number of elements of ``x`` counted over the dimensions
    whose stride is at least 1.

    Dimensions with a stride below 1 (e.g. stride 0) are left out of the
    product.
    """
    num_elements = 1
    for size, step in zip(x.shape, x.stride()):
        if step < 1:
            continue
        num_elements *= size
    return num_elements
def get_size(inputs: List[Any]):
    """
    Return the total size in bytes of every tensor found in the
    (possibly nested) structure ``inputs``. Non-tensor leaves are ignored.
    """
    sizes: List[int] = []

    def _collect(leaf) -> None:
        if isinstance(leaf, torch.Tensor):
            sizes.append(tensor_storage_size_in_mem(leaf) * leaf.element_size())

    # tree_map is only used to visit the leaves; its result is discarded
    tree_map(_collect, inputs)
    return sum(sizes)
def operation_memory_rw_bytes(inputs: List[Any], outputs: List[Any]):
    """Return the combined byte size of all input and output tensors."""
    bytes_in = get_size(inputs)
    bytes_out = get_size(outputs)
    return bytes_in + bytes_out
def output_read_from_input(inputs: List[Any], outputs: List[Any]):
    """
    Return the output byte size plus the smaller of the input and
    output byte sizes.
    """
    bytes_in = get_size(inputs)
    bytes_out = get_size(outputs)
    bytes_read = min(bytes_in, bytes_out)
    return bytes_out + bytes_read
def output_total_size(inputs: List[Any], outputs: List[Any]):
    """Return the byte size of the output tensors; ``inputs`` is ignored."""
    bytes_out = get_size(outputs)
    return bytes_out
def input_total_size(inputs: List[Any], outputs: List[Any]):
    """Return the byte size of the input tensors; ``outputs`` is ignored."""
    bytes_in = get_size(inputs)
    return bytes_in
def guess_flops_unknown_op(inputs: List[Any], outputs: List[Any]):
    """
    Rough flop estimate for ops without a dedicated formula: half the
    total number of elements across all input and output tensors.
    """
    # Approximation that isn't too bad
    element_counts: List[int] = []

    def _record(leaf) -> None:
        if isinstance(leaf, torch.Tensor):
            element_counts.append(leaf.numel())

    for tree in (inputs, outputs):
        tree_map(_record, tree)
    return sum(element_counts) / 2
def no_flop(inputs: List[Any], outputs: List[Any]):
    """Flop estimator for ops counted as free: always returns 0."""
    return 0
def no_io(inputs: List[Any], outputs: List[Any]):
    """IO estimator for ops counted as moving no data: always returns 0."""
    return 0
aten = torch.ops.aten

# Ops accounted as costing neither flops nor memory traffic
NO_FLOPS_NO_IO_OPS = [
    aten.permute,
    aten.view,
    aten.view_as,
    aten.detach,
    aten.t,
    aten.transpose,
    aten.expand,
    aten._unsafe_view,
    aten.select,
    aten.split,
    aten.split_with_sizes,
    aten.empty,
    aten.empty_strided,
    aten.empty_like,
    aten.is_same_size,
]

# Ops accounted as zero flops (they are not given a zero-IO entry here)
NO_FLOPS_OPS = [
    aten._reshape_alias,
    aten.reshape,
    aten.clone,
    aten.cat,
    aten.select_backward,
    aten.slice,
    aten.slice_backward,
    aten.ones,
    aten.ones_like,
    aten.zeros_like,
    aten.zero_,
    aten.zeros,
    aten.masked_fill,
    aten.masked_fill_,
]

# Op packet -> callable(inputs, outputs) returning a flop estimate
flop_mapping = {
    aten.mv: GemmOpComputeFlopsMv(),  # mat-vec
    aten.mm: GemmOpComputeFlops(),
    aten.matmul: GemmOpComputeFlops(),
    aten.addmm: GemmOpComputeFlopsAddmm(),
    aten.bmm: GemmOpComputeFlopsBmm(),
    aten.addbmm: GemmOpComputeFlopsAddbmm(),
    aten.linear: GemmOpComputeFlopsLinear(),
    aten.convolution: conv_flop,
    aten._convolution: conv_flop,
    aten.convolution_backward: conv_backward_flop,
    # Operations with 0 flop
    **dict.fromkeys(NO_FLOPS_OPS, no_flop),
    **dict.fromkeys(NO_FLOPS_NO_IO_OPS, no_flop),
}

# Op packet -> callable(inputs, outputs) returning a byte-count estimate
io_mapping = {
    aten.clone: output_read_from_input,
    aten.cat: output_read_from_input,
    aten.slice: output_read_from_input,
    aten.ones_like: output_total_size,
    aten.zeros_like: output_total_size,
    aten.zero_: input_total_size,
    **dict.fromkeys(NO_FLOPS_NO_IO_OPS, no_io),
    # TODO: Check how this is implemented in PT
    # aten.slice_backward: no_flop,
    # aten.select_backward: no_flop,
}
@dataclass
class _OpInfo:
    """
    Measurements and metadata recorded for one dispatched operation.

    ``time_ms`` is filled in by ``finalize`` from the two CUDA events;
    the ``time_*bound_ms`` properties assert that it is positive.
    """

    flop_count: float = 0.0
    time_ms: float = 0.0
    io_bytes: int = 0
    is_exact_flop: bool = True
    op_name: str = ""
    op_suffix: str = ""
    stacktrace: Tuple[str, ...] = field(default_factory=tuple)
    # Events bracketing the operation; timing is enabled so the elapsed
    # time between them can be queried.
    ev_start: torch.cuda.Event = field(
        default_factory=partial(torch.cuda.Event, enable_timing=True)
    )
    ev_end: torch.cuda.Event = field(
        default_factory=partial(torch.cuda.Event, enable_timing=True)
    )

    # Hardware limits for this operation (inf if unknown)
    hardware_tflops_limit: float = math.inf
    hardware_membw_limit: float = math.inf

    @property
    def time_membound_ms(self) -> float:
        """
        ``1000 * io_bytes / hardware_membw_limit`` capped at ``time_ms``;
        0.0 when the operation moved no bytes.
        """
        assert self.time_ms > 0.0
        if self.io_bytes == 0:
            return 0.0
        ideal_ms = 1000 * self.io_bytes / self.hardware_membw_limit
        return min(self.time_ms, ideal_ms)

    @property
    def time_computebound_ms(self) -> float:
        """
        ``1000 * tflop / hardware_tflops_limit`` capped at ``time_ms``;
        0.0 when the operation has no flops.
        """
        assert self.time_ms > 0.0
        tflop = self.flop_count / (1000**4)
        if tflop == 0.0:
            return 0.0
        ideal_ms = 1000 * tflop / self.hardware_tflops_limit
        return min(self.time_ms, ideal_ms)

    def finalize(self) -> None:
        """Store the elapsed time between the two events in ``time_ms``."""
        elapsed = self.ev_start.elapsed_time(self.ev_end)
        self.time_ms = elapsed
@dataclass
class _OpInfoAggregated:
    """
    Running totals over a group of recorded operations.

    ``is_exact_flop`` stays True only while every added operation has an
    exact flop count.
    """

    is_exact_flop: bool = True
    total_flop_count: float = 0.0
    total_io_bytes: int = 0
    total_time_ms: float = 0.0
    total_time_membound_ms: float = 0.0
    total_time_computebound_ms: float = 0.0
    num: int = 0
    stacktraces: List[Tuple[str, ...]] = field(default_factory=list)

    def add(self, op: "_OpInfo") -> None:
        """Accumulate the measurements of ``op`` into the totals."""
        self.total_flop_count += op.flop_count
        self.total_time_ms += op.time_ms
        self.total_io_bytes += op.io_bytes
        self.total_time_membound_ms += op.time_membound_ms
        self.total_time_computebound_ms += op.time_computebound_ms
        self.num += 1
        # A single approximated op makes the aggregated flop count
        # approximate, regardless of the order in which ops are added.
        self.is_exact_flop = self.is_exact_flop and op.is_exact_flop
        self.stacktraces.append(op.stacktrace)

    def as_dict(self, **kwargs) -> Dict[str, Any]:
        """
        Return the totals and derived ratios as a plain dictionary.

        ``mem_bound`` and ``compute_bound`` are the fractions of the total
        time attributed to memory and compute limits, capped at 1.
        ``Tflops`` is the achieved throughput. All three are 0.0 when no
        time has been accumulated. Extra keyword arguments are merged into
        the result and override the computed keys on conflict.
        """
        if self.total_time_ms > 0:
            mem_bound = min(1, self.total_time_membound_ms / self.total_time_ms)
            tflops = self.total_flop_count / (self.total_time_ms / 1000) / (1000**4)
            compute_bound = min(
                1, self.total_time_computebound_ms / self.total_time_ms
            )
        else:
            # Nothing was measured: avoid dividing by zero
            mem_bound = tflops = compute_bound = 0.0
        return {
            "is_exact_flop": self.is_exact_flop,
            "total_flop_count": self.total_flop_count,
            "total_time_ms": self.total_time_ms,
            "total_io_bytes": self.total_io_bytes,
            "num": self.num,
            "Tflops": tflops,
            "mem_bound": mem_bound,
            "compute_bound": compute_bound,
            **kwargs,
        }
class DetectSlowOpsProfiler(DispatcherWithoutBrokenFuncs):
    """
    Dispatch mode that brackets every dispatched operation with CUDA
    events, estimates its flop count and memory traffic, and on exit
    writes a per-module and per-op summary to a JSON file.

    Inspired from https://fb.workplace.com/groups/pytorch.dev/permalink/1054537595124720/
    """

    def __init__(self, main_profiler: _Profiler) -> None:
        super().__init__()
        self.main_profiler = main_profiler
        self.trace: List[_OpInfo] = []
        # When True, dispatched ops are executed but not recorded
        self.temp_disabled = False

    def _hardware_tflops_membw_limit(
        self, args: Tuple[Any, ...], outputs: Tuple[Any, ...]
    ) -> Tuple[float, float]:
        """
        Return ``(tflops_limit, membw_limit)`` for the device and dtype of
        the first suitable tensor among ``outputs`` then ``args``, or
        ``(inf, inf)`` when no limit can be determined.
        """
        unknown = (math.inf, math.inf)
        device = None
        dtypes: List[torch.dtype] = []
        for a in itertools.chain(outputs, args):
            if isinstance(a, torch.Tensor):
                if device is None:
                    device = a.device
                dtypes.append(a.dtype)
        if device is None:
            # No tensor involved: nothing to look up
            return unknown
        limits = get_device_limits(device)
        # NOTE(review): get_device_limits is defined elsewhere; guard in
        # case it returns None for devices it does not know about.
        if limits is None:
            return unknown
        dtypes = [dt for dt in dtypes if dt in limits.gemm_tflops]
        if not dtypes:
            return unknown
        dtype = dtypes[0]
        if torch.is_autocast_enabled() and dtype is torch.float32:
            dtype = torch.get_autocast_gpu_dtype()
        if dtype not in limits.gemm_tflops:
            # The autocast dtype may be missing from the limits table
            return math.inf, limits.gmem_bandwidth
        return limits.gemm_tflops[dtype], limits.gmem_bandwidth

    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        """
        Run ``func`` and, unless recording is disabled or ``func`` is one
        of the record_function markers, append an ``_OpInfo`` describing
        the call to ``self.trace``.
        """
        if kwargs is None:
            kwargs = {}
        func_packet = func._overloadpacket
        if self.temp_disabled or func_packet.__name__ in (
            "_record_function_exit",
            "_record_function_enter_new",
        ):
            return func(*args, **kwargs)

        op = _OpInfo()
        op.ev_start.record()
        out = func(*args, **kwargs)
        op.ev_end.record()
        outputs = out if isinstance(out, tuple) else (out,)
        (
            op.hardware_tflops_limit,
            op.hardware_membw_limit,
        ) = self._hardware_tflops_membw_limit(args, outputs)
        op.op_name = func_packet.__name__

        flop_count = -1
        compute_flops = None
        # Prevent functions called by flop counting ops to be recorded
        self.temp_disabled = True
        try:
            if func_packet in FUNC_TO_XFORMERS_OPERATOR:
                flop_count = FUNC_TO_XFORMERS_OPERATOR[func_packet].operator_flop(
                    *args, **kwargs
                )
            # -1 is treated as "no flop count available": fall back to the
            # generic estimators.
            if flop_count == -1:
                compute_flops = flop_mapping.get(func_packet, guess_flops_unknown_op)
                flop_count = compute_flops(args, outputs)
                if isinstance(compute_flops, GemmOpComputeFlops):
                    op.op_name += compute_flops.op_suffix(args)
            compute_io = io_mapping.get(func_packet, operation_memory_rw_bytes)
            op.io_bytes = compute_io(args, outputs)
        finally:
            # Always re-enable recording, even if an estimator raised
            self.temp_disabled = False

        op.stacktrace = tuple(self.main_profiler.parents)
        op.flop_count = flop_count
        op.is_exact_flop = compute_flops is not guess_flops_unknown_op
        self.trace.append(op)
        return out

    def __enter__(self):
        self.main_profiler._install_hooks()
        try:
            return super().__enter__()
        except BaseException:
            # Do not leave the hooks installed if the mode failed to start
            self.main_profiler._remove_hooks()
            raise

    def __exit__(self, exc_type, exc_val, exc_tb):
        try:
            super().__exit__(exc_type, exc_val, exc_tb)
        finally:
            self.main_profiler._remove_hooks()
        torch.cuda.synchronize()  # Wait for the events to be recorded
        for op in self.trace:
            op.finalize()
        self.save_json()

    def step(self) -> None:
        pass

    def save_json(self) -> None:
        """
        Aggregate the recorded ops per module path and per op name, write
        the result as JSON, and register the file in the main profiler's
        summary.
        """
        # Aggregate data at the module + op level
        all_paths: Set[Tuple[str, ...]] = {op.stacktrace for op in self.trace}
        per_module_data: Dict[Tuple[str, ...], _OpInfoAggregated] = defaultdict(
            _OpInfoAggregated
        )
        per_op_data: Dict[str, _OpInfoAggregated] = defaultdict(_OpInfoAggregated)
        for op in self.trace:
            # An op counts towards every prefix of its path that is itself
            # the full path of some recorded op.
            for depth in range(1, len(op.stacktrace) + 1):
                prefix = op.stacktrace[:depth]
                if prefix in all_paths:
                    per_module_data[prefix].add(op)
            per_op_data[op.op_name].add(op)

        # Generate JSON
        all_data = []
        for stacktrace, agg_info in per_module_data.items():
            all_data.append(
                agg_info.as_dict(
                    agg="module", path=stacktrace, name=stacktrace[-1], op=""
                )
            )
        for op_name, agg_info in per_op_data.items():
            # Find the most common path. Counting in sorted order makes the
            # smallest path win in case of a draw, since `max` keeps the
            # first maximum it encounters.
            paths_count: Dict[Tuple[str, ...], int] = defaultdict(int)
            for p in sorted(agg_info.stacktraces):
                paths_count[p] += 1
            maxp = max(paths_count, key=paths_count.__getitem__)
            all_data.append(
                agg_info.as_dict(
                    agg="opname",
                    path=f"{'.'.join(maxp)} (x{paths_count[maxp]})",
                    name="",
                    op=op_name,
                )
            )
        filename = self.main_profiler._create_output_filename("ops.json")
        self.main_profiler.summary.append(("OpsSummary", str(filename)))
        with open(filename, "w+") as f:
            json.dump(all_data, f)