# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

import argparse
import contextlib
import copy
import csv
import glob
import logging
import math
import os
import tempfile
from collections import defaultdict, namedtuple
from dataclasses import replace
from typing import Any, Dict, Generator, List, Set, Tuple

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import torch
import tqdm
from torch.utils import benchmark

sns.set()
|
|
|
|
TestCase = namedtuple("TestCase", ["function", "name"])
|
|
|
|
|
|
_triton_is_available = torch.cuda.is_available()
|
|
if _triton_is_available:
|
|
try:
|
|
import triton
|
|
except ImportError as e:
|
|
logging.warning(f"Triton is not available: {e}.\nbench_functions")
|
|
_triton_is_available = False
|
|
|
|
|
|
def pretty_print(results, title, units):
|
|
"""Printout the contents of a dict as a human-readable and Markdown compatible array"""
|
|
print(title)
|
|
header = " Units: {:<45}".format(units)
|
|
print("| " + header + "|" + "".join("{0:<20}|".format(k) for k in results.keys()))
|
|
|
|
offset = len(header)
|
|
print(
|
|
"|-{}|".format("-" * offset)
|
|
+ "".join("{}|".format("-" * 20) for _ in results.keys())
|
|
)
|
|
|
|
workloads: Dict[str, Any] = {k: [] for v in results.values() for k in v.keys()}
|
|
for v in results.values():
|
|
for k in v.keys():
|
|
workloads[k].append(v[k])
|
|
|
|
for k, w in workloads.items():
|
|
print(
|
|
"| {0:<{offset}}|".format(k, offset=offset)
|
|
+ "".join("{:<20}|".format(v) for v in w)
|
|
)
|
|
|
|
print("")
|
|
|
|
|
|
def pretty_plot(
|
|
results, title, units: str, filename=None, dash_key="", legend_loc="lower right"
|
|
):
|
|
"""Graph out the contents of a dict.
|
|
Dash key means that if the result label has this key, then it will be displayed with a dash
|
|
"""
|
|
|
|
if not filename:
|
|
filename = title + ".png"
|
|
|
|
# Sanitize the filename
|
|
filename = (
|
|
filename.replace(" ", "_").replace("/", "_").replace("-", "_").replace(":", "")
|
|
)
|
|
|
|
# Gather all the results in "collumns"
|
|
workloads: Dict[str, Any] = {k: [] for v in results.values() for k in v.keys()}
|
|
for v in results.values():
|
|
for k in v.keys():
|
|
workloads[k].append(float(v[k]))
|
|
|
|
# Make sure that the plot is big enough
|
|
f = plt.figure()
|
|
f.set_figwidth(6)
|
|
f.set_figheight(6)
|
|
|
|
# Display the collections
|
|
for k, v in workloads.items():
|
|
if dash_key and dash_key in k:
|
|
plt.plot(list(results.keys()), v, "--")
|
|
else:
|
|
plt.plot(list(results.keys()), v)
|
|
|
|
plt.title(title)
|
|
plt.legend(list(workloads.keys()), loc=legend_loc)
|
|
plt.ylabel(units)
|
|
plt.xticks(rotation=45)
|
|
|
|
plt.savefig(filename, bbox_inches="tight")
|
|
plt.close(f)
|
|
|
|
|
|
if _triton_is_available:
|
|
|
|
def bench_functions(
|
|
test_cases: List[TestCase], shapes, metric_transform, unit, title=""
|
|
):
|
|
device = torch.device("cuda")
|
|
|
|
for dtype in [torch.bfloat16, torch.float16, torch.float32]:
|
|
results: Dict[str, Any] = {}
|
|
|
|
for B, M, K in shapes:
|
|
a = torch.rand(B, M, K, device=device, dtype=dtype, requires_grad=True)
|
|
|
|
for testcase in test_cases:
|
|
time = triton.testing.do_bench(lambda: testcase.function(a))[0]
|
|
|
|
metric = metric_transform(a, time)
|
|
|
|
key = f"B={B}, M={M}, K={K}"
|
|
if key not in results:
|
|
results[key] = {}
|
|
|
|
results[key][testcase.name] = f"{metric:.1f}"
|
|
|
|
pretty_print(
|
|
results,
|
|
title=" ------------- Type: {} ------------- ".format(dtype),
|
|
units=unit,
|
|
)
|
|
pretty_plot(results, title + str(dtype), unit, dash_key="pytorch")
|
|
|
|
|
|
def pretty_barplot(results, title, units: str, filename=None, dash_key=""):
|
|
"""Graph out the contents of a dict.
|
|
Dash key means that if the result label has this key, then it will be displayed with a dash
|
|
"""
|
|
|
|
if not filename:
|
|
filename = title + ".png"
|
|
|
|
# Sanitize the filename
|
|
filename = (
|
|
filename.replace(" ", "_").replace("/", "_").replace("-", "_").replace(":", "")
|
|
)
|
|
|
|
xlabels = list(results.keys())
|
|
# Gather all the results in "collumns"
|
|
workloads: Dict[str, Any] = {k: [] for v in results.values() for k in v.keys()}
|
|
for v in results.values():
|
|
for k in v.keys():
|
|
workloads[k].append(float(v[k]))
|
|
|
|
options = list(workloads.keys())
|
|
group_len = len(options)
|
|
for key in workloads.keys():
|
|
num_groups = len(workloads[key])
|
|
break
|
|
group_width = group_len + 1
|
|
|
|
# Make sure that the plot is big enough
|
|
f = plt.figure()
|
|
f.set_figwidth(6)
|
|
f.set_figheight(6)
|
|
|
|
for idx in range(group_len):
|
|
option = options[idx]
|
|
values = workloads[option]
|
|
xloc = np.arange(1 + idx, group_width * num_groups, group_width)
|
|
plt.bar(xloc, values, width=1, edgecolor="black")
|
|
|
|
plt.title(title)
|
|
plt.legend(list(workloads.keys()), loc="upper right")
|
|
plt.ylabel(units)
|
|
|
|
ax = plt.gca()
|
|
xticks_loc = np.arange(
|
|
1 + (group_len - 1) / 2.0, group_width * num_groups, group_width
|
|
)
|
|
ax.set_xticks(xticks_loc, xlabels)
|
|
plt.xticks(rotation=45)
|
|
|
|
plt.setp(ax.xaxis.get_majorticklabels(), ha="right")
|
|
ax.set_axisbelow(True)
|
|
ax.yaxis.grid(color="gray", linestyle="dashed")
|
|
ax.xaxis.grid(color="gray", linestyle="dashed")
|
|
|
|
plt.savefig(filename, bbox_inches="tight")
|
|
plt.close(f)
|
|
|
|
|
|
def rmf(filename: str) -> None:
|
|
"""Remove a file like rm -f."""
|
|
try:
|
|
os.remove(filename)
|
|
except FileNotFoundError:
|
|
pass
|
|
|
|
|
|
@contextlib.contextmanager
|
|
def temp_files_ctx(num: int) -> Generator:
|
|
"""A context to get tempfiles and ensure they are cleaned up."""
|
|
files = [tempfile.mkstemp()[1] for _ in range(num)]
|
|
|
|
yield tuple(files)
|
|
|
|
# temp files could have been removed, so we use rmf.
|
|
for name in files:
|
|
rmf(name)
|
|
|
|
|
|
META_ALGORITHM = "algorithm"
|
|
BASELINE_DESCRIPTIONS = ["eager", "vanilla", "pytorch"]
|
|
|
|
|
|
# Serialize/unserialize to CSV
|
|
# We could use pkl, but resort to CSV for readability
|
|
def _benchmark_results_from_csv(filename: str) -> List[Tuple[Dict[str, Any], Any]]:
|
|
parts = os.path.basename(filename).split(".")
|
|
env = ""
|
|
description = ""
|
|
if len(parts) == 3:
|
|
env = parts[1]
|
|
description = parts[0]
|
|
|
|
data = []
|
|
with open(filename, "r") as csvfile:
|
|
reader = csv.DictReader(csvfile)
|
|
for row in reader:
|
|
if description != "" and row["description"] not in BASELINE_DESCRIPTIONS:
|
|
row["description"] = description
|
|
task_spec = benchmark.utils.common.TaskSpec(
|
|
stmt="",
|
|
setup="",
|
|
global_setup="",
|
|
label=row["label"],
|
|
sub_label=row["sub_label"],
|
|
description=row["description"],
|
|
env=env,
|
|
num_threads=int(row["num_threads"]),
|
|
)
|
|
measurement = benchmark.utils.common.Measurement(
|
|
number_per_run=1,
|
|
raw_times=[float(row["runtime_us"]) / (1000.0 * 1000)],
|
|
task_spec=task_spec,
|
|
)
|
|
measurement.mem_use = float(row["mem_use_mb"]) # type: ignore
|
|
data.append(
|
|
(
|
|
{
|
|
META_ALGORITHM: row["algorithm"]
|
|
if row["algorithm"] != ""
|
|
else None,
|
|
},
|
|
measurement,
|
|
)
|
|
)
|
|
return data
|
|
|
|
|
|
def _benchmark_results_to_csv(
|
|
filename: str, results: List[Tuple[Dict[str, Any], Any]]
|
|
) -> None:
|
|
data = [
|
|
{
|
|
"sub_label": r.task_spec.sub_label,
|
|
"label": r.task_spec.label,
|
|
"num_threads": r.task_spec.num_threads,
|
|
"algorithm": metadata.get(META_ALGORITHM, ""),
|
|
"description": r.task_spec.description
|
|
if r.task_spec.description in BASELINE_DESCRIPTIONS
|
|
else "",
|
|
"runtime_us": int(1000 * 1000 * r.mean),
|
|
"mem_use_mb": r.mem_use,
|
|
}
|
|
for metadata, r in results
|
|
]
|
|
with open(filename, "w+", newline="") as csvfile:
|
|
writer = csv.DictWriter(csvfile, fieldnames=list(data[0].keys()))
|
|
writer.writeheader()
|
|
for d in data:
|
|
writer.writerow(d)
|
|
|
|
|
|
def _finalize_results(results: List[Tuple[Dict[str, Any], Any]]) -> List[Any]:
|
|
"""
|
|
Returns a `benchmark.Compare` object, except that if we have runs
|
|
with different algorithms, we also add the algorithm name
|
|
in the column titles
|
|
"""
|
|
all_algorithms: Set[str] = set()
|
|
all_description: Set[str] = set()
|
|
for metadata, r in results:
|
|
algo = metadata.get(META_ALGORITHM, None)
|
|
if algo is not None:
|
|
all_algorithms.add(algo)
|
|
all_description.add(r.task_spec.description)
|
|
display_algo = len(all_algorithms) > 1
|
|
display_descr = len(all_description) > 1
|
|
|
|
display_results = []
|
|
for metadata, r in results:
|
|
algo = metadata.get(META_ALGORITHM, None)
|
|
if algo is None:
|
|
display_results.append(r)
|
|
else:
|
|
r = copy.copy(r)
|
|
description = ""
|
|
if display_descr:
|
|
description = r.task_spec.description
|
|
if display_algo:
|
|
if display_descr:
|
|
description += "["
|
|
description += algo
|
|
if display_descr:
|
|
description += "]"
|
|
r.task_spec = replace(r.task_spec, description=description)
|
|
display_results.append(r)
|
|
return display_results
|
|
|
|
|
|
def _render_bar_plot(results: List[Any], store_results_folder: str) -> None:
|
|
if not results:
|
|
return
|
|
runtime: Dict[str, Dict[str, float]] = defaultdict(dict)
|
|
memory_usage: Dict[str, Dict[str, float]] = defaultdict(dict)
|
|
all_descriptions: List[str] = []
|
|
for r in results:
|
|
# Hacky: use a list to preserve order
|
|
if r.task_spec.description not in all_descriptions:
|
|
if r.task_spec.description in BASELINE_DESCRIPTIONS:
|
|
all_descriptions.insert(0, r.task_spec.description)
|
|
else:
|
|
all_descriptions.append(r.task_spec.description)
|
|
runtime[r.task_spec.sub_label][r.task_spec.description] = r.mean
|
|
memory_usage[r.task_spec.sub_label][r.task_spec.description] = r.mem_use
|
|
all_data_mem: List[Any] = []
|
|
all_data_run: List[Any] = []
|
|
for key, runtime_values in runtime.items():
|
|
memory_values = memory_usage[key]
|
|
denom = memory_values.get(all_descriptions[0], math.inf)
|
|
if denom == 0:
|
|
all_data_mem.append([key] + [0] * len(all_descriptions))
|
|
else:
|
|
all_data_mem.append(
|
|
[key] + [memory_values.get(d, 0) / denom for d in all_descriptions]
|
|
)
|
|
all_data_run.append(
|
|
[key]
|
|
+ [
|
|
runtime_values.get(all_descriptions[0], 0)
|
|
/ runtime_values.get(d, math.inf)
|
|
for d in all_descriptions
|
|
]
|
|
)
|
|
if all_descriptions[0] == "":
|
|
all_descriptions[0] = "baseline"
|
|
else:
|
|
all_descriptions[0] = f"{all_descriptions[0]} (baseline)"
|
|
|
|
for data, filename, title in [
|
|
(all_data_mem, "mem.png", "Memory usage (vs baseline, lower is better)"),
|
|
(
|
|
all_data_run,
|
|
"runtime.png",
|
|
"Runtime speedup (vs baseline, higher is better)",
|
|
),
|
|
]:
|
|
df = pd.DataFrame(data, columns=["Configuration"] + all_descriptions)
|
|
df.plot(
|
|
x="Configuration",
|
|
kind="bar",
|
|
stacked=False,
|
|
title=title,
|
|
)
|
|
plt.tight_layout()
|
|
filename_full = os.path.join(store_results_folder, filename)
|
|
plt.savefig(filename_full)
|
|
print(f"Saved plot: {filename_full}")
|
|
|
|
|
|
def benchmark_main_helper(benchmark_fn, cases: List[Dict[str, Any]], **kwargs) -> None:
|
|
"""
|
|
Helper function to run benchmarks.
|
|
Supports loading previous results for comparison, and saving current results to file.
|
|
"""
|
|
|
|
parser = argparse.ArgumentParser()
|
|
parser.add_argument(
|
|
"--fn", default=None, type=str, help="Only benchmark this function"
|
|
)
|
|
parser.add_argument(
|
|
"--label", default=None, type=str, help="Store results to a file"
|
|
)
|
|
parser.add_argument(
|
|
"--fail_if_regression",
|
|
action="store_true",
|
|
help="Enabled in CI to check against performance regressions",
|
|
)
|
|
parser.add_argument(
|
|
"--compare",
|
|
default=None,
|
|
type=str,
|
|
help="Compare to previously stored benchmarks (coma separated)",
|
|
)
|
|
parser.add_argument(
|
|
"--omit-baselines",
|
|
action="store_true",
|
|
help="Do not run the (potentially slow) baselines",
|
|
)
|
|
parser.add_argument(
|
|
"--quiet",
|
|
action="store_true",
|
|
help="Skip intermediate results and progress bar",
|
|
)
|
|
args = parser.parse_args()
|
|
|
|
if args.fn is not None and args.fn != benchmark_fn.__name__:
|
|
print(f'Skipping benchmark "{benchmark_fn.__name__}"')
|
|
return
|
|
benchmark_run_and_compare(
|
|
benchmark_fn=benchmark_fn,
|
|
cases=cases,
|
|
optimized_label="optimized" if args.label is None else args.label,
|
|
fail_if_regression=args.fail_if_regression,
|
|
compare=args.compare.split(",") if args.compare is not None else [],
|
|
quiet=args.quiet,
|
|
omit_baselines=args.omit_baselines,
|
|
**kwargs,
|
|
)
|
|
|
|
|
|
def benchmark_run_and_compare(
|
|
benchmark_fn,
|
|
cases: List[Dict[str, Any]],
|
|
compare: List[str],
|
|
omit_baselines: bool = False,
|
|
fail_if_regression: bool = False,
|
|
quiet: bool = False,
|
|
optimized_label: str = "optimized",
|
|
*,
|
|
min_run_time: int = 2,
|
|
atol_s: float = 30e-6,
|
|
rtol: float = 0.05,
|
|
) -> None:
|
|
SKIP_VANILLA_TASKS_IF_ALREADY_DONE = True
|
|
results_compare_to = []
|
|
results = []
|
|
|
|
store_results_folder = os.path.expanduser(
|
|
os.path.join(
|
|
os.environ.get(
|
|
"XFORMERS_BENCHMARKS_CACHE",
|
|
os.path.join("~", ".cache", "xformers", "benchmarks"),
|
|
),
|
|
benchmark_fn.__name__,
|
|
)
|
|
)
|
|
|
|
try:
|
|
env = (
|
|
torch.cuda.get_device_name(torch.cuda.current_device())
|
|
.replace(" ", "_")
|
|
.replace("-", "_")
|
|
.replace(".", "_")
|
|
)
|
|
except (RuntimeError, AssertionError): # No GPU
|
|
env = "cpu"
|
|
assert (
|
|
"." not in optimized_label
|
|
), f"label=`{optimized_label}` should not contain dots"
|
|
assert "." not in env, f"env=`{env}` should not contain dots"
|
|
|
|
os.makedirs(store_results_folder, exist_ok=True)
|
|
|
|
# Load runs that we want to compare to
|
|
skip_vanilla_tasks = set()
|
|
for cmp_name in compare:
|
|
name_with_env = cmp_name if "." in cmp_name else f"{cmp_name}.*"
|
|
for filename in glob.glob(
|
|
os.path.join(store_results_folder, f"{name_with_env}.csv")
|
|
):
|
|
loaded = _benchmark_results_from_csv(filename)
|
|
for m, r in loaded:
|
|
if r.task_spec.env == env and SKIP_VANILLA_TASKS_IF_ALREADY_DONE:
|
|
skip_vanilla_tasks.add(
|
|
(r.task_spec.sub_label, r.task_spec.num_threads)
|
|
)
|
|
results_compare_to += loaded
|
|
|
|
if not quiet:
|
|
pbar = tqdm.tqdm(cases, leave=False)
|
|
cases = pbar
|
|
for case in cases:
|
|
if quiet:
|
|
print(str(case))
|
|
else:
|
|
pbar.write(f"====== {str(case)} ======")
|
|
try:
|
|
benchmarks_generator = benchmark_fn(**case)
|
|
except NotImplementedError:
|
|
# pbar.write(f"Skipped (NotImplementedError)")
|
|
continue
|
|
except RuntimeError as e:
|
|
if "CUDA out of memory" not in str(e):
|
|
raise
|
|
if not quiet:
|
|
pbar.write("Skipped (OOM)")
|
|
continue
|
|
|
|
name = None
|
|
try:
|
|
for benchmark_object in benchmarks_generator:
|
|
is_optimized = (
|
|
benchmark_object._task_spec.description not in BASELINE_DESCRIPTIONS
|
|
)
|
|
metadata = {}
|
|
if is_optimized:
|
|
metadata[META_ALGORITHM] = benchmark_object._task_spec.description
|
|
benchmark_object._task_spec = replace(
|
|
benchmark_object._task_spec, description=optimized_label
|
|
)
|
|
elif (
|
|
omit_baselines
|
|
or (
|
|
benchmark_object._task_spec.sub_label,
|
|
benchmark_object._task_spec.num_threads,
|
|
)
|
|
in skip_vanilla_tasks
|
|
):
|
|
continue
|
|
|
|
memory = math.inf
|
|
try:
|
|
torch.cuda.synchronize()
|
|
torch.cuda.reset_peak_memory_stats()
|
|
mem_begin = torch.cuda.max_memory_allocated() / 2**20
|
|
benchmark_object._task_spec = replace(
|
|
benchmark_object._task_spec, env=env
|
|
)
|
|
measurement = benchmark_object.blocked_autorange(
|
|
min_run_time=min_run_time
|
|
)
|
|
torch.cuda.synchronize()
|
|
results.append((metadata, measurement))
|
|
name = measurement.task_spec.description
|
|
memory = torch.cuda.max_memory_allocated() / 2**20 - mem_begin
|
|
measurement.mem_use = memory
|
|
except RuntimeError as e:
|
|
if "CUDA out of memory" not in str(e):
|
|
raise
|
|
if not quiet:
|
|
pbar.write("Skipped (OOM)")
|
|
finally:
|
|
del benchmark_object
|
|
if not quiet:
|
|
pbar.write(f"{name}: memory used: {memory} MB")
|
|
except RuntimeError as e:
|
|
if "CUDA out of memory" not in str(e):
|
|
raise
|
|
if not quiet:
|
|
pbar.write("Skipped (OOM)")
|
|
# Display results for benchmarks we just calculated
|
|
if name is not None and not quiet:
|
|
|
|
def matches_current(r):
|
|
return (
|
|
r[1].task_spec.sub_label == results[-1][1].task_spec.sub_label
|
|
and r[1].task_spec.label == results[-1][1].task_spec.label
|
|
)
|
|
|
|
pbar.write(
|
|
str(
|
|
benchmark.Compare(
|
|
_finalize_results(
|
|
list(filter(matches_current, results))
|
|
+ list(filter(matches_current, results_compare_to))
|
|
)
|
|
)
|
|
)
|
|
)
|
|
|
|
results_for_print = _finalize_results(results + results_compare_to)
|
|
benchmark.Compare(results_for_print).print()
|
|
_render_bar_plot(results_for_print, store_results_folder)
|
|
|
|
# Save runs to a file
|
|
if results and optimized_label is not None:
|
|
write_to_path = os.path.join(
|
|
store_results_folder, f"{optimized_label}.{env}.csv"
|
|
)
|
|
_benchmark_results_to_csv(write_to_path, results)
|
|
print(f"Saved results to {write_to_path}")
|
|
|
|
if fail_if_regression:
|
|
_fail_if_regressions(
|
|
results, reference=results_compare_to, atol_s=atol_s, rtol=rtol
|
|
)
|
|
|
|
|
|
def _fail_if_regressions(
|
|
results: List[Any], reference: List[Any], atol_s: float, rtol: float
|
|
) -> None:
|
|
def get_measurement_id(r):
|
|
return (
|
|
r[0].get(META_ALGORITHM, ""),
|
|
r[1].task_spec.label,
|
|
r[1].task_spec.sub_label,
|
|
r[1].task_spec.env,
|
|
)
|
|
|
|
id_to_result = {}
|
|
for r in results:
|
|
id_to_result[get_measurement_id(r)] = r[1]
|
|
|
|
num_better = 0
|
|
num_worse = 0
|
|
num_nochange = 0
|
|
num_unk = 0
|
|
reference_set = set()
|
|
for ref in reference:
|
|
if ref[1].task_spec.description in BASELINE_DESCRIPTIONS:
|
|
continue
|
|
benchmark_id = get_measurement_id(ref)
|
|
if benchmark_id in reference_set:
|
|
raise ValueError(f"Duplicate benchmark in reference for {benchmark_id}")
|
|
reference_set.add(benchmark_id)
|
|
if benchmark_id not in id_to_result:
|
|
num_unk += 1
|
|
continue
|
|
res = id_to_result[benchmark_id]
|
|
# If significative change
|
|
if abs(ref[1].mean - res.mean) - rtol * ref[1].mean > atol_s:
|
|
is_now_better = res.mean < ref[1].mean
|
|
if is_now_better:
|
|
num_better += 1
|
|
else:
|
|
num_worse += 1
|
|
cmp = "IMPROVED" if is_now_better else "REGRESS "
|
|
print(cmp, benchmark_id, f"ref={ref[1].mean}", f"now={res.mean}")
|
|
else:
|
|
num_nochange += 1
|
|
|
|
print("Regression test summary:")
|
|
print(f" Better : {num_better}")
|
|
print(f" No change: {num_nochange}")
|
|
print(f" Worse : {num_worse}")
|
|
if num_unk > 0:
|
|
print(f" (no ref) : {num_unk}")
|
|
if num_worse > 1:
|
|
raise RuntimeError("At least one benchmark regressed!")
|
|
if num_nochange == 0:
|
|
raise RuntimeError("No reference found")
|