First commit
This commit is contained in:
660
pkgs/xformers/benchmarks/utils.py
Normal file
660
pkgs/xformers/benchmarks/utils.py
Normal file
@@ -0,0 +1,660 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the BSD license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import argparse
|
||||
import contextlib
|
||||
import copy
|
||||
import csv
|
||||
import glob
|
||||
import logging
|
||||
import math
|
||||
import os
|
||||
import tempfile
|
||||
from collections import defaultdict, namedtuple
|
||||
from dataclasses import replace
|
||||
from typing import Any, Dict, Generator, List, Set, Tuple
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import seaborn as sns
|
||||
import torch
|
||||
import tqdm
|
||||
from torch.utils import benchmark
|
||||
|
||||
# Apply seaborn's default styling to every matplotlib figure produced below.
sns.set()

# One benchmark case: `function` is the callable to time, `name` labels it
# in the printed tables and plot legends.
TestCase = namedtuple("TestCase", ["function", "name"])


# Triton benchmarking is only attempted when a CUDA device is present;
# degrade gracefully (warning, flag cleared) when the import fails so that
# CPU-only environments can still use the rest of this module.
_triton_is_available = torch.cuda.is_available()
if _triton_is_available:
    try:
        import triton
    except ImportError as e:
        logging.warning(f"Triton is not available: {e}.\nbench_functions")
        _triton_is_available = False
|
||||
|
||||
|
||||
def pretty_print(results, title, units):
|
||||
"""Printout the contents of a dict as a human-readable and Markdown compatible array"""
|
||||
print(title)
|
||||
header = " Units: {:<45}".format(units)
|
||||
print("| " + header + "|" + "".join("{0:<20}|".format(k) for k in results.keys()))
|
||||
|
||||
offset = len(header)
|
||||
print(
|
||||
"|-{}|".format("-" * offset)
|
||||
+ "".join("{}|".format("-" * 20) for _ in results.keys())
|
||||
)
|
||||
|
||||
workloads: Dict[str, Any] = {k: [] for v in results.values() for k in v.keys()}
|
||||
for v in results.values():
|
||||
for k in v.keys():
|
||||
workloads[k].append(v[k])
|
||||
|
||||
for k, w in workloads.items():
|
||||
print(
|
||||
"| {0:<{offset}}|".format(k, offset=offset)
|
||||
+ "".join("{:<20}|".format(v) for v in w)
|
||||
)
|
||||
|
||||
print("")
|
||||
|
||||
|
||||
def pretty_plot(
    results, title, units: str, filename=None, dash_key="", legend_loc="lower right"
):
    """Plot a dict-of-dicts as line series and save the figure to disk.

    Each inner-dict key becomes one line (x values are the outer keys).
    Series whose label contains `dash_key` (when non-empty) are drawn dashed.
    The figure is saved to `filename` (default: sanitized `title` + ".png").
    """
    if not filename:
        filename = title + ".png"

    # Strip characters that are awkward in file names.
    for old, new in ((" ", "_"), ("/", "_"), ("-", "_"), (":", "")):
        filename = filename.replace(old, new)

    # Transpose into one float series per label (first-seen order).
    series: Dict[str, Any] = defaultdict(list)
    for per_x in results.values():
        for label, value in per_x.items():
            series[label].append(float(value))

    # Make sure that the plot is big enough.
    fig = plt.figure()
    fig.set_figwidth(6)
    fig.set_figheight(6)

    x_values = list(results.keys())
    for label, values in series.items():
        if dash_key and dash_key in label:
            plt.plot(x_values, values, "--")
        else:
            plt.plot(x_values, values)

    plt.title(title)
    plt.legend(list(series.keys()), loc=legend_loc)
    plt.ylabel(units)
    plt.xticks(rotation=45)

    plt.savefig(filename, bbox_inches="tight")
    plt.close(fig)
|
||||
|
||||
|
||||
if _triton_is_available:

    def bench_functions(
        test_cases: List[TestCase], shapes, metric_transform, unit, title=""
    ):
        """Benchmark `test_cases` over `shapes` with triton; print and plot.

        Args:
            test_cases: list of TestCase(function, name) to time on CUDA.
            shapes: iterable of (B, M, K) shapes for the random input tensor.
            metric_transform: callable (tensor, time) -> metric value, used
                to convert the raw timing into the reported metric.
            unit: unit label for the reported metric.
            title: prefix for the saved plot title / file name.
        """
        device = torch.device("cuda")

        # One table/plot per dtype.
        for dtype in [torch.bfloat16, torch.float16, torch.float32]:
            results: Dict[str, Any] = {}

            for B, M, K in shapes:
                a = torch.rand(B, M, K, device=device, dtype=dtype, requires_grad=True)

                for testcase in test_cases:
                    # First element of do_bench's return — presumably the mean
                    # timing; confirm against the installed triton version.
                    time = triton.testing.do_bench(lambda: testcase.function(a))[0]

                    metric = metric_transform(a, time)

                    key = f"B={B}, M={M}, K={K}"
                    if key not in results:
                        results[key] = {}

                    results[key][testcase.name] = f"{metric:.1f}"

            pretty_print(
                results,
                title=" ------------- Type: {} ------------- ".format(dtype),
                units=unit,
            )
            # Baseline ("pytorch") curves are rendered dashed.
            pretty_plot(results, title + str(dtype), unit, dash_key="pytorch")
|
||||
|
||||
|
||||
def pretty_barplot(results, title, units: str, filename=None, dash_key=""):
    """Render a dict-of-dicts as a grouped bar plot and save it to disk.

    Args:
        results: maps x-label -> {series name -> value}.
        title: figure title; also the default file name stem.
        units: y-axis label.
        filename: output path (default: sanitized `title` + ".png").
        dash_key: accepted for signature symmetry with `pretty_plot`;
            bars cannot be dashed so it is unused.

    Fix over the original: empty/valueless `results` now return early —
    previously `num_groups` was never assigned and `np.arange` raised
    NameError.
    """
    if not filename:
        filename = title + ".png"

    # Sanitize the filename
    filename = (
        filename.replace(" ", "_").replace("/", "_").replace("-", "_").replace(":", "")
    )

    xlabels = list(results.keys())
    # Gather all the results in "columns"
    workloads: Dict[str, Any] = {k: [] for v in results.values() for k in v.keys()}
    for v in results.values():
        for k in v.keys():
            workloads[k].append(float(v[k]))

    options = list(workloads.keys())
    group_len = len(options)
    if group_len == 0:
        # Nothing to plot; the original code crashed here with NameError.
        return
    # All series carry one value per x-label; size groups from the first one.
    num_groups = len(next(iter(workloads.values())))
    group_width = group_len + 1

    # Make sure that the plot is big enough
    f = plt.figure()
    f.set_figwidth(6)
    f.set_figheight(6)

    for idx in range(group_len):
        option = options[idx]
        values = workloads[option]
        # Offset each series by `idx` inside its group of width `group_width`.
        xloc = np.arange(1 + idx, group_width * num_groups, group_width)
        plt.bar(xloc, values, width=1, edgecolor="black")

    plt.title(title)
    plt.legend(list(workloads.keys()), loc="upper right")
    plt.ylabel(units)

    ax = plt.gca()
    # Center one tick under each group of bars.
    xticks_loc = np.arange(
        1 + (group_len - 1) / 2.0, group_width * num_groups, group_width
    )
    ax.set_xticks(xticks_loc, xlabels)
    plt.xticks(rotation=45)

    plt.setp(ax.xaxis.get_majorticklabels(), ha="right")
    ax.set_axisbelow(True)
    ax.yaxis.grid(color="gray", linestyle="dashed")
    ax.xaxis.grid(color="gray", linestyle="dashed")

    plt.savefig(filename, bbox_inches="tight")
    plt.close(f)
|
||||
|
||||
|
||||
def rmf(filename: str) -> None:
    """Delete `filename`, silently ignoring a missing file (like rm -f)."""
    with contextlib.suppress(FileNotFoundError):
        os.remove(filename)
|
||||
|
||||
|
||||
@contextlib.contextmanager
def temp_files_ctx(num: int) -> Generator:
    """A context yielding `num` temp-file paths, all removed on exit.

    Fixes over the original:
    - the file descriptors returned by `tempfile.mkstemp` are closed
      immediately (they used to leak, one fd per file);
    - cleanup now runs even when the managed block raises, by wrapping
      the yield in try/finally.
    """
    files = []
    for _ in range(num):
        fd, name = tempfile.mkstemp()
        os.close(fd)  # only the path is handed out; keep no open descriptor
        files.append(name)

    try:
        yield tuple(files)
    finally:
        # Temp files could have been removed by the caller already, so
        # tolerate a missing file (rm -f semantics).
        for name in files:
            try:
                os.remove(name)
            except FileNotFoundError:
                pass
|
||||
|
||||
|
||||
# Metadata-dict key under which a measurement's algorithm name is stored.
META_ALGORITHM = "algorithm"
# Description labels that denote baseline (non-optimized) implementations.
BASELINE_DESCRIPTIONS = ["eager", "vanilla", "pytorch"]
|
||||
|
||||
|
||||
# Serialize/unserialize to CSV
# We could use pkl, but resort to CSV for readability
def _benchmark_results_from_csv(filename: str) -> List[Tuple[Dict[str, Any], Any]]:
    """Load measurements previously written by `_benchmark_results_to_csv`.

    The run label and environment are recovered from the file name when it
    has the `<description>.<env>.csv` shape; otherwise both stay empty.
    Returns a list of (metadata, Measurement) tuples, where metadata holds
    the algorithm name (or None) under META_ALGORITHM.
    """
    basename_parts = os.path.basename(filename).split(".")
    if len(basename_parts) == 3:
        description, env = basename_parts[0], basename_parts[1]
    else:
        description, env = "", ""

    data = []
    with open(filename, "r") as csvfile:
        for row in csv.DictReader(csvfile):
            # Non-baseline rows take the label encoded in the file name.
            if description != "" and row["description"] not in BASELINE_DESCRIPTIONS:
                row["description"] = description
            task_spec = benchmark.utils.common.TaskSpec(
                stmt="",
                setup="",
                global_setup="",
                label=row["label"],
                sub_label=row["sub_label"],
                description=row["description"],
                env=env,
                num_threads=int(row["num_threads"]),
            )
            # Stored runtimes are in microseconds; Measurement wants seconds.
            measurement = benchmark.utils.common.Measurement(
                number_per_run=1,
                raw_times=[float(row["runtime_us"]) / (1000.0 * 1000)],
                task_spec=task_spec,
            )
            measurement.mem_use = float(row["mem_use_mb"])  # type: ignore
            algorithm = row["algorithm"] if row["algorithm"] != "" else None
            data.append(({META_ALGORITHM: algorithm}, measurement))
    return data
|
||||
|
||||
|
||||
def _benchmark_results_to_csv(
    filename: str, results: List[Tuple[Dict[str, Any], Any]]
) -> None:
    """Write (metadata, Measurement) tuples to `filename` as CSV.

    Inverse of `_benchmark_results_from_csv`. Baseline descriptions are kept
    verbatim; optimized descriptions are blanked because the run label is
    encoded in the file name instead.

    Fix over the original: an empty `results` list now produces a file with
    just the header instead of raising IndexError on `data[0]`.
    """
    fieldnames = [
        "sub_label",
        "label",
        "num_threads",
        "algorithm",
        "description",
        "runtime_us",
        "mem_use_mb",
    ]
    data = [
        {
            "sub_label": r.task_spec.sub_label,
            "label": r.task_spec.label,
            "num_threads": r.task_spec.num_threads,
            "algorithm": metadata.get(META_ALGORITHM, ""),
            "description": r.task_spec.description
            if r.task_spec.description in BASELINE_DESCRIPTIONS
            else "",
            # Mean runtime is stored in integer microseconds.
            "runtime_us": int(1000 * 1000 * r.mean),
            "mem_use_mb": r.mem_use,
        }
        for metadata, r in results
    ]
    with open(filename, "w+", newline="") as csvfile:
        writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
        writer.writeheader()
        writer.writerows(data)
|
||||
|
||||
|
||||
def _finalize_results(results: List[Tuple[Dict[str, Any], Any]]) -> List[Any]:
    """
    Returns the measurements ready for `benchmark.Compare`, folding the
    algorithm name into each measurement's description whenever the runs
    cover more than one algorithm (and/or more than one description), so
    it shows up in the column titles.
    """
    # First pass: count distinct algorithms and descriptions.
    algorithms_seen: Set[str] = set()
    descriptions_seen: Set[str] = set()
    for metadata, measurement in results:
        algorithm = metadata.get(META_ALGORITHM, None)
        if algorithm is not None:
            algorithms_seen.add(algorithm)
        descriptions_seen.add(measurement.task_spec.description)
    show_algorithm = len(algorithms_seen) > 1
    show_description = len(descriptions_seen) > 1

    # Second pass: rewrite descriptions where an algorithm is attached.
    finalized = []
    for metadata, measurement in results:
        algorithm = metadata.get(META_ALGORITHM, None)
        if algorithm is None:
            finalized.append(measurement)
            continue
        # Copy so the caller's measurement objects are left untouched.
        measurement = copy.copy(measurement)
        parts = []
        if show_description:
            parts.append(measurement.task_spec.description)
        if show_algorithm:
            # Bracket the algorithm only when a description is also shown.
            parts.append(f"[{algorithm}]" if show_description else algorithm)
        measurement.task_spec = replace(
            measurement.task_spec, description="".join(parts)
        )
        finalized.append(measurement)
    return finalized
|
||||
|
||||
|
||||
def _render_bar_plot(results: List[Any], store_results_folder: str) -> None:
    """Render relative memory-usage and runtime-speedup bar plots.

    `results` are finalized measurements (see `_finalize_results`): one bar
    group per `sub_label`, one bar per description, normalized against the
    first (baseline) description. Saves `mem.png` and `runtime.png` into
    `store_results_folder`.
    """
    if not results:
        return
    runtime: Dict[str, Dict[str, float]] = defaultdict(dict)
    memory_usage: Dict[str, Dict[str, float]] = defaultdict(dict)
    all_descriptions: List[str] = []
    for r in results:
        # Hacky: use a list to preserve order
        if r.task_spec.description not in all_descriptions:
            # Baselines go to the front so index 0 is the normalization reference.
            if r.task_spec.description in BASELINE_DESCRIPTIONS:
                all_descriptions.insert(0, r.task_spec.description)
            else:
                all_descriptions.append(r.task_spec.description)
        runtime[r.task_spec.sub_label][r.task_spec.description] = r.mean
        memory_usage[r.task_spec.sub_label][r.task_spec.description] = r.mem_use
    all_data_mem: List[Any] = []
    all_data_run: List[Any] = []
    for key, runtime_values in runtime.items():
        memory_values = memory_usage[key]
        # Memory ratio vs baseline; a missing baseline (inf denominator)
        # maps every ratio to ~0 rather than dividing by zero.
        denom = memory_values.get(all_descriptions[0], math.inf)
        if denom == 0:
            all_data_mem.append([key] + [0] * len(all_descriptions))
        else:
            all_data_mem.append(
                [key] + [memory_values.get(d, 0) / denom for d in all_descriptions]
            )
        # Runtime speedup vs baseline (baseline_time / this_time); entries
        # missing from this sub_label contribute a 0 via the inf default.
        all_data_run.append(
            [key]
            + [
                runtime_values.get(all_descriptions[0], 0)
                / runtime_values.get(d, math.inf)
                for d in all_descriptions
            ]
        )
    # Make the legend entry for the reference column explicit.
    if all_descriptions[0] == "":
        all_descriptions[0] = "baseline"
    else:
        all_descriptions[0] = f"{all_descriptions[0]} (baseline)"

    for data, filename, title in [
        (all_data_mem, "mem.png", "Memory usage (vs baseline, lower is better)"),
        (
            all_data_run,
            "runtime.png",
            "Runtime speedup (vs baseline, higher is better)",
        ),
    ]:
        df = pd.DataFrame(data, columns=["Configuration"] + all_descriptions)
        df.plot(
            x="Configuration",
            kind="bar",
            stacked=False,
            title=title,
        )
        plt.tight_layout()
        filename_full = os.path.join(store_results_folder, filename)
        plt.savefig(filename_full)
        print(f"Saved plot: {filename_full}")
|
||||
|
||||
|
||||
def benchmark_main_helper(benchmark_fn, cases: List[Dict[str, Any]], **kwargs) -> None:
    """
    Helper function to run benchmarks.
    Supports loading previous results for comparison, and saving current results to file.

    Parses the command line (run filter, result label, comparison targets,
    CI regression check, baseline/progress toggles) and forwards everything
    to `benchmark_run_and_compare`. Extra `kwargs` are passed through.

    Fix over the original: corrected the user-facing help-text typo
    ("coma separated" -> "comma separated").
    """

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--fn", default=None, type=str, help="Only benchmark this function"
    )
    parser.add_argument(
        "--label", default=None, type=str, help="Store results to a file"
    )
    parser.add_argument(
        "--fail_if_regression",
        action="store_true",
        help="Enabled in CI to check against performance regressions",
    )
    parser.add_argument(
        "--compare",
        default=None,
        type=str,
        help="Compare to previously stored benchmarks (comma separated)",
    )
    parser.add_argument(
        "--omit-baselines",
        action="store_true",
        help="Do not run the (potentially slow) baselines",
    )
    parser.add_argument(
        "--quiet",
        action="store_true",
        help="Skip intermediate results and progress bar",
    )
    args = parser.parse_args()

    # Allow running a single benchmark function out of a larger suite.
    if args.fn is not None and args.fn != benchmark_fn.__name__:
        print(f'Skipping benchmark "{benchmark_fn.__name__}"')
        return
    benchmark_run_and_compare(
        benchmark_fn=benchmark_fn,
        cases=cases,
        optimized_label="optimized" if args.label is None else args.label,
        fail_if_regression=args.fail_if_regression,
        compare=args.compare.split(",") if args.compare is not None else [],
        quiet=args.quiet,
        omit_baselines=args.omit_baselines,
        **kwargs,
    )
|
||||
|
||||
|
||||
def benchmark_run_and_compare(
    benchmark_fn,
    cases: List[Dict[str, Any]],
    compare: List[str],
    omit_baselines: bool = False,
    fail_if_regression: bool = False,
    quiet: bool = False,
    optimized_label: str = "optimized",
    *,
    min_run_time: int = 2,
    atol_s: float = 30e-6,
    rtol: float = 0.05,
) -> None:
    """Run `benchmark_fn` over `cases`, compare against stored runs, and save.

    `benchmark_fn(**case)` must yield timer-like objects (they expose
    `_task_spec` and `blocked_autorange`, as torch.utils.benchmark Timers
    do). Results are printed via `benchmark.Compare`, plotted, written to
    a per-GPU CSV cache under XFORMERS_BENCHMARKS_CACHE (default
    ~/.cache/xformers/benchmarks/<fn>), and optionally regression-checked
    against the runs named in `compare` (raises on regression when
    `fail_if_regression` is set).
    """
    SKIP_VANILLA_TASKS_IF_ALREADY_DONE = True
    results_compare_to = []
    results = []

    store_results_folder = os.path.expanduser(
        os.path.join(
            os.environ.get(
                "XFORMERS_BENCHMARKS_CACHE",
                os.path.join("~", ".cache", "xformers", "benchmarks"),
            ),
            benchmark_fn.__name__,
        )
    )

    # The environment is the GPU name (sanitized for file names), or "cpu".
    try:
        env = (
            torch.cuda.get_device_name(torch.cuda.current_device())
            .replace(" ", "_")
            .replace("-", "_")
            .replace(".", "_")
        )
    except (RuntimeError, AssertionError):  # No GPU
        env = "cpu"
    # Dots are the field separator in "<label>.<env>.csv" cache file names.
    assert (
        "." not in optimized_label
    ), f"label=`{optimized_label}` should not contain dots"
    assert "." not in env, f"env=`{env}` should not contain dots"

    os.makedirs(store_results_folder, exist_ok=True)

    # Load runs that we want to compare to
    skip_vanilla_tasks = set()
    for cmp_name in compare:
        # A bare name matches any environment via a glob wildcard.
        name_with_env = cmp_name if "." in cmp_name else f"{cmp_name}.*"
        for filename in glob.glob(
            os.path.join(store_results_folder, f"{name_with_env}.csv")
        ):
            loaded = _benchmark_results_from_csv(filename)
            # Tasks already measured in this same env need no fresh baseline.
            for m, r in loaded:
                if r.task_spec.env == env and SKIP_VANILLA_TASKS_IF_ALREADY_DONE:
                    skip_vanilla_tasks.add(
                        (r.task_spec.sub_label, r.task_spec.num_threads)
                    )
            results_compare_to += loaded

    if not quiet:
        pbar = tqdm.tqdm(cases, leave=False)
        cases = pbar
    for case in cases:
        if quiet:
            print(str(case))
        else:
            pbar.write(f"====== {str(case)} ======")
        try:
            benchmarks_generator = benchmark_fn(**case)
        except NotImplementedError:
            # pbar.write(f"Skipped (NotImplementedError)")
            continue
        except RuntimeError as e:
            # OOM while building the benchmark is a skip, not a failure.
            if "CUDA out of memory" not in str(e):
                raise
            if not quiet:
                pbar.write("Skipped (OOM)")
            continue

        name = None
        try:
            for benchmark_object in benchmarks_generator:
                is_optimized = (
                    benchmark_object._task_spec.description not in BASELINE_DESCRIPTIONS
                )
                metadata = {}
                if is_optimized:
                    # Remember the algorithm name, then relabel the run with
                    # the user-provided label so variants group together.
                    metadata[META_ALGORITHM] = benchmark_object._task_spec.description
                    benchmark_object._task_spec = replace(
                        benchmark_object._task_spec, description=optimized_label
                    )
                elif (
                    omit_baselines
                    or (
                        benchmark_object._task_spec.sub_label,
                        benchmark_object._task_spec.num_threads,
                    )
                    in skip_vanilla_tasks
                ):
                    # Baseline already covered (or explicitly omitted).
                    continue

                memory = math.inf
                try:
                    # Measure peak CUDA memory around the timed run (MB).
                    torch.cuda.synchronize()
                    torch.cuda.reset_peak_memory_stats()
                    mem_begin = torch.cuda.max_memory_allocated() / 2**20
                    benchmark_object._task_spec = replace(
                        benchmark_object._task_spec, env=env
                    )
                    measurement = benchmark_object.blocked_autorange(
                        min_run_time=min_run_time
                    )
                    torch.cuda.synchronize()
                    results.append((metadata, measurement))
                    name = measurement.task_spec.description
                    memory = torch.cuda.max_memory_allocated() / 2**20 - mem_begin
                    measurement.mem_use = memory
                except RuntimeError as e:
                    if "CUDA out of memory" not in str(e):
                        raise
                    if not quiet:
                        pbar.write("Skipped (OOM)")
                finally:
                    # Free the timer (and whatever tensors it holds) promptly.
                    del benchmark_object
                if not quiet:
                    pbar.write(f"{name}: memory used: {memory} MB")
        except RuntimeError as e:
            # OOM raised by the generator itself while producing timers.
            if "CUDA out of memory" not in str(e):
                raise
            if not quiet:
                pbar.write("Skipped (OOM)")
        # Display results for benchmarks we just calculated
        if name is not None and not quiet:

            def matches_current(r):
                # Same (sub_label, label) as the most recent measurement.
                return (
                    r[1].task_spec.sub_label == results[-1][1].task_spec.sub_label
                    and r[1].task_spec.label == results[-1][1].task_spec.label
                )

            pbar.write(
                str(
                    benchmark.Compare(
                        _finalize_results(
                            list(filter(matches_current, results))
                            + list(filter(matches_current, results_compare_to))
                        )
                    )
                )
            )

    results_for_print = _finalize_results(results + results_compare_to)
    benchmark.Compare(results_for_print).print()
    _render_bar_plot(results_for_print, store_results_folder)

    # Save runs to a file
    if results and optimized_label is not None:
        write_to_path = os.path.join(
            store_results_folder, f"{optimized_label}.{env}.csv"
        )
        _benchmark_results_to_csv(write_to_path, results)
        print(f"Saved results to {write_to_path}")

    if fail_if_regression:
        _fail_if_regressions(
            results, reference=results_compare_to, atol_s=atol_s, rtol=rtol
        )
|
||||
|
||||
|
||||
def _fail_if_regressions(
    results: List[Any], reference: List[Any], atol_s: float, rtol: float
) -> None:
    """Compare `results` against `reference` and raise on regressions.

    A change counts as significant when |ref_mean - new_mean| exceeds
    rtol * ref_mean + atol_s. Raises ValueError on duplicate reference
    entries, and RuntimeError on regressions or a fully-missing reference.
    """

    def get_measurement_id(r):
        # A measurement's identity: (algorithm, label, sub_label, env).
        return (
            r[0].get(META_ALGORITHM, ""),
            r[1].task_spec.label,
            r[1].task_spec.sub_label,
            r[1].task_spec.env,
        )

    id_to_result = {}
    for r in results:
        id_to_result[get_measurement_id(r)] = r[1]

    num_better = 0
    num_worse = 0
    num_nochange = 0
    num_unk = 0  # reference entries with no matching current result
    reference_set = set()
    for ref in reference:
        # Only optimized runs are regression-tested, never the baselines.
        if ref[1].task_spec.description in BASELINE_DESCRIPTIONS:
            continue
        benchmark_id = get_measurement_id(ref)
        if benchmark_id in reference_set:
            raise ValueError(f"Duplicate benchmark in reference for {benchmark_id}")
        reference_set.add(benchmark_id)
        if benchmark_id not in id_to_result:
            num_unk += 1
            continue
        res = id_to_result[benchmark_id]
        # If significative change
        if abs(ref[1].mean - res.mean) - rtol * ref[1].mean > atol_s:
            is_now_better = res.mean < ref[1].mean
            if is_now_better:
                num_better += 1
            else:
                num_worse += 1
            cmp = "IMPROVED" if is_now_better else "REGRESS "
            print(cmp, benchmark_id, f"ref={ref[1].mean}", f"now={res.mean}")
        else:
            num_nochange += 1

    print("Regression test summary:")
    print(f" Better : {num_better}")
    print(f" No change: {num_nochange}")
    print(f" Worse : {num_worse}")
    if num_unk > 0:
        print(f" (no ref) : {num_unk}")
    # NOTE(review): this only fires when num_worse >= 2, although the message
    # says "at least one" — possibly deliberate tolerance for a single flaky
    # regression; confirm the threshold is intentional.
    if num_worse > 1:
        raise RuntimeError("At least one benchmark regressed!")
    if num_nochange == 0:
        raise RuntimeError("No reference found")
|
||||
Reference in New Issue
Block a user