# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
import argparse
import contextlib
import copy
import csv
import glob
import logging
import math
import os
import tempfile
from collections import defaultdict, namedtuple
from dataclasses import replace
from typing import Any, Dict, Generator, List, Set, Tuple
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import torch
import tqdm
from torch.utils import benchmark
sns.set()
TestCase = namedtuple("TestCase", ["function", "name"])
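# A benchmark candidate: `function` runs the op under test on the input tensor,
# and `name` labels it in tables and plots, e.g. TestCase(torch.nn.functional.gelu, "pytorch").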
_triton_is_available = torch.cuda.is_available()
if _triton_is_available:
try:
import triton
except ImportError as e:
        logging.warning(f"Triton is not available: {e}")
_triton_is_available = False
def pretty_print(results, title, units):
"""Printout the contents of a dict as a human-readable and Markdown compatible array"""
print(title)
header = " Units: {:<45}".format(units)
print("| " + header + "|" + "".join("{0:<20}|".format(k) for k in results.keys()))
offset = len(header)
print(
"|-{}|".format("-" * offset)
+ "".join("{}|".format("-" * 20) for _ in results.keys())
)
workloads: Dict[str, Any] = {k: [] for v in results.values() for k in v.keys()}
for v in results.values():
for k in v.keys():
workloads[k].append(v[k])
for k, w in workloads.items():
print(
"| {0:<{offset}}|".format(k, offset=offset)
+ "".join("{:<20}|".format(v) for v in w)
)
print("")
def pretty_plot(
results, title, units: str, filename=None, dash_key="", legend_loc="lower right"
):
"""Graph out the contents of a dict.
Dash key means that if the result label has this key, then it will be displayed with a dash
"""
if not filename:
filename = title + ".png"
# Sanitize the filename
filename = (
filename.replace(" ", "_").replace("/", "_").replace("-", "_").replace(":", "")
)
    # Gather all the results in "columns"
workloads: Dict[str, Any] = {k: [] for v in results.values() for k in v.keys()}
for v in results.values():
for k in v.keys():
workloads[k].append(float(v[k]))
# Make sure that the plot is big enough
f = plt.figure()
f.set_figwidth(6)
f.set_figheight(6)
    # Plot one line per workload
for k, v in workloads.items():
if dash_key and dash_key in k:
plt.plot(list(results.keys()), v, "--")
else:
plt.plot(list(results.keys()), v)
plt.title(title)
plt.legend(list(workloads.keys()), loc=legend_loc)
plt.ylabel(units)
plt.xticks(rotation=45)
plt.savefig(filename, bbox_inches="tight")
plt.close(f)
if _triton_is_available:
def bench_functions(
test_cases: List[TestCase], shapes, metric_transform, unit, title=""
):
device = torch.device("cuda")
for dtype in [torch.bfloat16, torch.float16, torch.float32]:
results: Dict[str, Any] = {}
for B, M, K in shapes:
a = torch.rand(B, M, K, device=device, dtype=dtype, requires_grad=True)
for testcase in test_cases:
time = triton.testing.do_bench(lambda: testcase.function(a))[0]
metric = metric_transform(a, time)
key = f"B={B}, M={M}, K={K}"
if key not in results:
results[key] = {}
results[key][testcase.name] = f"{metric:.1f}"
pretty_print(
results,
title=" ------------- Type: {} ------------- ".format(dtype),
units=unit,
)
pretty_plot(results, title + str(dtype), unit, dash_key="pytorch")
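# Hypothetical usage of `bench_functions` (all names below are illustrative):
#   bench_functions(
#       [TestCase(torch.nn.functional.gelu, "pytorch"), TestCase(triton_gelu, "triton")],
#       shapes=[(8, 256, 512), (8, 512, 1024)],
#       # do_bench reports time in ms: convert bytes moved per call into GB/s
#       metric_transform=lambda a, ms: 2 * a.numel() * a.element_size() / (ms * 1e-3) / 1e9,
#       unit="GB/s",
#       title="gelu",
#   )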
def pretty_barplot(results, title, units: str, filename=None, dash_key=""):
"""Graph out the contents of a dict.
Dash key means that if the result label has this key, then it will be displayed with a dash
"""
if not filename:
filename = title + ".png"
# Sanitize the filename
filename = (
filename.replace(" ", "_").replace("/", "_").replace("-", "_").replace(":", "")
)
xlabels = list(results.keys())
    # Gather all the results in "columns"
workloads: Dict[str, Any] = {k: [] for v in results.values() for k in v.keys()}
for v in results.values():
for k in v.keys():
workloads[k].append(float(v[k]))
options = list(workloads.keys())
group_len = len(options)
    # All workloads are expected to have the same number of groups
    num_groups = len(next(iter(workloads.values())))
group_width = group_len + 1
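    # Each x-axis group holds one bar per option, plus one empty slot of padding.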
# Make sure that the plot is big enough
f = plt.figure()
f.set_figwidth(6)
f.set_figheight(6)
for idx in range(group_len):
option = options[idx]
values = workloads[option]
xloc = np.arange(1 + idx, group_width * num_groups, group_width)
plt.bar(xloc, values, width=1, edgecolor="black")
plt.title(title)
plt.legend(list(workloads.keys()), loc="upper right")
plt.ylabel(units)
ax = plt.gca()
xticks_loc = np.arange(
1 + (group_len - 1) / 2.0, group_width * num_groups, group_width
)
ax.set_xticks(xticks_loc, xlabels)
plt.xticks(rotation=45)
plt.setp(ax.xaxis.get_majorticklabels(), ha="right")
ax.set_axisbelow(True)
ax.yaxis.grid(color="gray", linestyle="dashed")
ax.xaxis.grid(color="gray", linestyle="dashed")
plt.savefig(filename, bbox_inches="tight")
plt.close(f)
def rmf(filename: str) -> None:
"""Remove a file like rm -f."""
try:
os.remove(filename)
except FileNotFoundError:
pass
@contextlib.contextmanager
def temp_files_ctx(num: int) -> Generator:
"""A context to get tempfiles and ensure they are cleaned up."""
    files = []
    for _ in range(num):
        fd, name = tempfile.mkstemp()
        os.close(fd)  # mkstemp returns an open fd; close it so we don't leak descriptors
        files.append(name)
yield tuple(files)
# temp files could have been removed, so we use rmf.
for name in files:
rmf(name)
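# Example:
#   with temp_files_ctx(2) as (f1, f2):
#       ...  # use the paths; both files are removed when the block exits

# Key of the per-measurement metadata dict holding the algorithm name.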
META_ALGORITHM = "algorithm"
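# Task descriptions that identify a baseline (non-optimized) run.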
BASELINE_DESCRIPTIONS = ["eager", "vanilla", "pytorch"]
# Serialize/deserialize benchmark results to CSV.
# We could use pickle, but CSV keeps the files human-readable.
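# Files are named `{label}.{env}.csv`; each row is one measurement with columns:
#   sub_label,label,num_threads,algorithm,description,runtime_us,mem_use_mb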
def _benchmark_results_from_csv(filename: str) -> List[Tuple[Dict[str, Any], Any]]:
parts = os.path.basename(filename).split(".")
env = ""
description = ""
if len(parts) == 3:
env = parts[1]
description = parts[0]
data = []
with open(filename, "r") as csvfile:
reader = csv.DictReader(csvfile)
for row in reader:
if description != "" and row["description"] not in BASELINE_DESCRIPTIONS:
row["description"] = description
task_spec = benchmark.utils.common.TaskSpec(
stmt="",
setup="",
global_setup="",
label=row["label"],
sub_label=row["sub_label"],
description=row["description"],
env=env,
num_threads=int(row["num_threads"]),
)
measurement = benchmark.utils.common.Measurement(
number_per_run=1,
raw_times=[float(row["runtime_us"]) / (1000.0 * 1000)],
task_spec=task_spec,
)
measurement.mem_use = float(row["mem_use_mb"]) # type: ignore
data.append(
(
{
META_ALGORITHM: row["algorithm"]
if row["algorithm"] != ""
else None,
},
measurement,
)
)
return data
def _benchmark_results_to_csv(
filename: str, results: List[Tuple[Dict[str, Any], Any]]
) -> None:
data = [
{
"sub_label": r.task_spec.sub_label,
"label": r.task_spec.label,
"num_threads": r.task_spec.num_threads,
"algorithm": metadata.get(META_ALGORITHM, ""),
"description": r.task_spec.description
if r.task_spec.description in BASELINE_DESCRIPTIONS
else "",
"runtime_us": int(1000 * 1000 * r.mean),
"mem_use_mb": r.mem_use,
}
for metadata, r in results
]
with open(filename, "w+", newline="") as csvfile:
writer = csv.DictWriter(csvfile, fieldnames=list(data[0].keys()))
writer.writeheader()
for d in data:
writer.writerow(d)
def _finalize_results(results: List[Tuple[Dict[str, Any], Any]]) -> List[Any]:
"""
Returns a `benchmark.Compare` object, except that if we have runs
with different algorithms, we also add the algorithm name
in the column titles
"""
all_algorithms: Set[str] = set()
all_description: Set[str] = set()
for metadata, r in results:
algo = metadata.get(META_ALGORITHM, None)
if algo is not None:
all_algorithms.add(algo)
all_description.add(r.task_spec.description)
display_algo = len(all_algorithms) > 1
display_descr = len(all_description) > 1
display_results = []
for metadata, r in results:
algo = metadata.get(META_ALGORITHM, None)
if algo is None:
display_results.append(r)
else:
r = copy.copy(r)
description = ""
if display_descr:
description = r.task_spec.description
if display_algo:
if display_descr:
description += "["
description += algo
if display_descr:
description += "]"
r.task_spec = replace(r.task_spec, description=description)
display_results.append(r)
return display_results
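# With multiple algorithms and multiple descriptions, a column title renders
# as e.g. "optimized[flashattn]"; with a single description, only the algorithm
# name is used.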
def _render_bar_plot(results: List[Any], store_results_folder: str) -> None:
if not results:
return
runtime: Dict[str, Dict[str, float]] = defaultdict(dict)
memory_usage: Dict[str, Dict[str, float]] = defaultdict(dict)
all_descriptions: List[str] = []
for r in results:
# Hacky: use a list to preserve order
if r.task_spec.description not in all_descriptions:
if r.task_spec.description in BASELINE_DESCRIPTIONS:
all_descriptions.insert(0, r.task_spec.description)
else:
all_descriptions.append(r.task_spec.description)
runtime[r.task_spec.sub_label][r.task_spec.description] = r.mean
memory_usage[r.task_spec.sub_label][r.task_spec.description] = r.mem_use
all_data_mem: List[Any] = []
all_data_run: List[Any] = []
for key, runtime_values in runtime.items():
memory_values = memory_usage[key]
denom = memory_values.get(all_descriptions[0], math.inf)
if denom == 0:
all_data_mem.append([key] + [0] * len(all_descriptions))
else:
all_data_mem.append(
[key] + [memory_values.get(d, 0) / denom for d in all_descriptions]
)
all_data_run.append(
[key]
+ [
runtime_values.get(all_descriptions[0], 0)
/ runtime_values.get(d, math.inf)
for d in all_descriptions
]
)
if all_descriptions[0] == "":
all_descriptions[0] = "baseline"
else:
all_descriptions[0] = f"{all_descriptions[0]} (baseline)"
for data, filename, title in [
(all_data_mem, "mem.png", "Memory usage (vs baseline, lower is better)"),
(
all_data_run,
"runtime.png",
"Runtime speedup (vs baseline, higher is better)",
),
]:
df = pd.DataFrame(data, columns=["Configuration"] + all_descriptions)
df.plot(
x="Configuration",
kind="bar",
stacked=False,
title=title,
)
plt.tight_layout()
filename_full = os.path.join(store_results_folder, filename)
plt.savefig(filename_full)
print(f"Saved plot: {filename_full}")
def benchmark_main_helper(benchmark_fn, cases: List[Dict[str, Any]], **kwargs) -> None:
"""
Helper function to run benchmarks.
Supports loading previous results for comparison, and saving current results to file.
"""
parser = argparse.ArgumentParser()
parser.add_argument(
"--fn", default=None, type=str, help="Only benchmark this function"
)
parser.add_argument(
"--label", default=None, type=str, help="Store results to a file"
)
parser.add_argument(
"--fail_if_regression",
action="store_true",
help="Enabled in CI to check against performance regressions",
)
parser.add_argument(
"--compare",
default=None,
type=str,
help="Compare to previously stored benchmarks (coma separated)",
)
parser.add_argument(
"--omit-baselines",
action="store_true",
help="Do not run the (potentially slow) baselines",
)
parser.add_argument(
"--quiet",
action="store_true",
help="Skip intermediate results and progress bar",
)
args = parser.parse_args()
if args.fn is not None and args.fn != benchmark_fn.__name__:
print(f'Skipping benchmark "{benchmark_fn.__name__}"')
return
benchmark_run_and_compare(
benchmark_fn=benchmark_fn,
cases=cases,
optimized_label="optimized" if args.label is None else args.label,
fail_if_regression=args.fail_if_regression,
compare=args.compare.split(",") if args.compare is not None else [],
quiet=args.quiet,
omit_baselines=args.omit_baselines,
**kwargs,
)
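# Typical use at the bottom of a benchmark script (illustrative names):
#   if __name__ == "__main__":
#       benchmark_main_helper(benchmark_my_op, CASES, min_run_time=2)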
def benchmark_run_and_compare(
benchmark_fn,
cases: List[Dict[str, Any]],
compare: List[str],
omit_baselines: bool = False,
fail_if_regression: bool = False,
quiet: bool = False,
optimized_label: str = "optimized",
*,
min_run_time: int = 2,
atol_s: float = 30e-6,
rtol: float = 0.05,
) -> None:
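    """Run `benchmark_fn` on every case, print a comparison table, and store results.

    Results are saved as CSV under `$XFORMERS_BENCHMARKS_CACHE/<benchmark_fn.__name__>/`
    (defaulting to `~/.cache/xformers/benchmarks`) in files named `{label}.{env}.csv`.
    Runs listed in `compare` are loaded from that folder for side-by-side comparison
    and, with `fail_if_regression`, checked against the new measurements.
    """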
SKIP_VANILLA_TASKS_IF_ALREADY_DONE = True
results_compare_to = []
results = []
store_results_folder = os.path.expanduser(
os.path.join(
os.environ.get(
"XFORMERS_BENCHMARKS_CACHE",
os.path.join("~", ".cache", "xformers", "benchmarks"),
),
benchmark_fn.__name__,
)
)
try:
env = (
torch.cuda.get_device_name(torch.cuda.current_device())
.replace(" ", "_")
.replace("-", "_")
.replace(".", "_")
)
except (RuntimeError, AssertionError): # No GPU
env = "cpu"
assert (
"." not in optimized_label
), f"label=`{optimized_label}` should not contain dots"
assert "." not in env, f"env=`{env}` should not contain dots"
os.makedirs(store_results_folder, exist_ok=True)
# Load runs that we want to compare to
skip_vanilla_tasks = set()
for cmp_name in compare:
name_with_env = cmp_name if "." in cmp_name else f"{cmp_name}.*"
for filename in glob.glob(
os.path.join(store_results_folder, f"{name_with_env}.csv")
):
loaded = _benchmark_results_from_csv(filename)
for m, r in loaded:
if r.task_spec.env == env and SKIP_VANILLA_TASKS_IF_ALREADY_DONE:
skip_vanilla_tasks.add(
(r.task_spec.sub_label, r.task_spec.num_threads)
)
results_compare_to += loaded
if not quiet:
pbar = tqdm.tqdm(cases, leave=False)
cases = pbar
for case in cases:
if quiet:
print(str(case))
else:
pbar.write(f"====== {str(case)} ======")
try:
benchmarks_generator = benchmark_fn(**case)
except NotImplementedError:
            # Skip cases the benchmark function does not support
continue
except RuntimeError as e:
if "CUDA out of memory" not in str(e):
raise
if not quiet:
pbar.write("Skipped (OOM)")
continue
name = None
try:
for benchmark_object in benchmarks_generator:
is_optimized = (
benchmark_object._task_spec.description not in BASELINE_DESCRIPTIONS
)
metadata = {}
if is_optimized:
metadata[META_ALGORITHM] = benchmark_object._task_spec.description
benchmark_object._task_spec = replace(
benchmark_object._task_spec, description=optimized_label
)
elif (
omit_baselines
or (
benchmark_object._task_spec.sub_label,
benchmark_object._task_spec.num_threads,
)
in skip_vanilla_tasks
):
continue
memory = math.inf
try:
torch.cuda.synchronize()
torch.cuda.reset_peak_memory_stats()
mem_begin = torch.cuda.max_memory_allocated() / 2**20
benchmark_object._task_spec = replace(
benchmark_object._task_spec, env=env
)
measurement = benchmark_object.blocked_autorange(
min_run_time=min_run_time
)
torch.cuda.synchronize()
results.append((metadata, measurement))
name = measurement.task_spec.description
memory = torch.cuda.max_memory_allocated() / 2**20 - mem_begin
measurement.mem_use = memory
except RuntimeError as e:
if "CUDA out of memory" not in str(e):
raise
if not quiet:
pbar.write("Skipped (OOM)")
finally:
del benchmark_object
if not quiet:
pbar.write(f"{name}: memory used: {memory} MB")
except RuntimeError as e:
if "CUDA out of memory" not in str(e):
raise
if not quiet:
pbar.write("Skipped (OOM)")
# Display results for benchmarks we just calculated
if name is not None and not quiet:
def matches_current(r):
return (
r[1].task_spec.sub_label == results[-1][1].task_spec.sub_label
and r[1].task_spec.label == results[-1][1].task_spec.label
)
pbar.write(
str(
benchmark.Compare(
_finalize_results(
list(filter(matches_current, results))
+ list(filter(matches_current, results_compare_to))
)
)
)
)
results_for_print = _finalize_results(results + results_compare_to)
benchmark.Compare(results_for_print).print()
_render_bar_plot(results_for_print, store_results_folder)
# Save runs to a file
if results and optimized_label is not None:
write_to_path = os.path.join(
store_results_folder, f"{optimized_label}.{env}.csv"
)
_benchmark_results_to_csv(write_to_path, results)
print(f"Saved results to {write_to_path}")
if fail_if_regression:
_fail_if_regressions(
results, reference=results_compare_to, atol_s=atol_s, rtol=rtol
)
def _fail_if_regressions(
results: List[Any], reference: List[Any], atol_s: float, rtol: float
) -> None:
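    """Compare `results` to `reference` and raise if any benchmark significantly regressed."""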
def get_measurement_id(r):
return (
r[0].get(META_ALGORITHM, ""),
r[1].task_spec.label,
r[1].task_spec.sub_label,
r[1].task_spec.env,
)
id_to_result = {}
for r in results:
id_to_result[get_measurement_id(r)] = r[1]
num_better = 0
num_worse = 0
num_nochange = 0
num_unk = 0
reference_set = set()
for ref in reference:
if ref[1].task_spec.description in BASELINE_DESCRIPTIONS:
continue
benchmark_id = get_measurement_id(ref)
if benchmark_id in reference_set:
raise ValueError(f"Duplicate benchmark in reference for {benchmark_id}")
reference_set.add(benchmark_id)
if benchmark_id not in id_to_result:
num_unk += 1
continue
res = id_to_result[benchmark_id]
        # The change is significant when |ref - now| > atol_s + rtol * ref
if abs(ref[1].mean - res.mean) - rtol * ref[1].mean > atol_s:
is_now_better = res.mean < ref[1].mean
if is_now_better:
num_better += 1
else:
num_worse += 1
cmp = "IMPROVED" if is_now_better else "REGRESS "
print(cmp, benchmark_id, f"ref={ref[1].mean}", f"now={res.mean}")
else:
num_nochange += 1
print("Regression test summary:")
print(f" Better : {num_better}")
print(f" No change: {num_nochange}")
print(f" Worse : {num_worse}")
if num_unk > 0:
print(f" (no ref) : {num_unk}")
    if num_worse > 0:
        raise RuntimeError("At least one benchmark regressed!")
if num_nochange == 0:
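        # Nothing matched within tolerance: the stored reference likely does not cover these runs.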
raise RuntimeError("No reference found")