First commit

2025-08-05 19:02:46 +08:00
parent 9efe891f99
commit 99fb9f5cb0
1412 changed files with 203615 additions and 0 deletions

View File

@@ -0,0 +1,4 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

View File

@@ -0,0 +1,96 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
import argparse
import json
import logging
from pathlib import Path
from typing import Any, Dict
if __name__ == "__main__":
# Get the user requests
parser = argparse.ArgumentParser(
"Collect results from a given batch of distributed results"
)
parser.add_argument("-ck", "--checkpoint_path", required=True)
args = parser.parse_args()
logging.getLogger().setLevel(logging.INFO)
# Go through all the data in the given checkpoint directory and try to find the final results
root = Path(args.checkpoint_path)
# - list all the mechanisms being benchmarked
results: Dict[str, Any] = {}
for attention in filter(lambda x: x.is_dir(), root.iterdir()):
logging.info(f"\nFound results for {attention.stem}")
task_jsons = attention.glob("*/test_eval_summary.json")
results[attention.stem] = {}
for task in task_jsons:
task_name = task.stem.split("__")[0]
logging.info(f"Logs found for task: {task_name}")
results[attention.stem][task_name] = -1
found_result = False
# - collect the individual results
with open(task, "r") as result_file:
dct = json.load(result_file)
if "test_accu_mean" in dct:
found_result = True
results[attention.stem][task_name] = dct["test_accu_mean"]
logging.info(
f"Final result found for {task_name} at epoch {dct['train_step_idx']}: "
f"{results[attention.stem][task_name]}"
)
else:
break
# - report an error if no result was found
if not found_result:
ERR_TAIL = 30
logging.warning(
f"No result found for {task_name}, showing the error log in {task.parent}"
)
err_log = Path(task.parent).glob("*.err")
print("*****************************************************")
with open(next(err_log), "r") as err_file:
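# Show the last ERR_TAIL lines of the error log, iterated most recent first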
for i, line in enumerate(reversed(err_file.readlines())):
print(line, end="")
if i > ERR_TAIL:
break
print("*****************************************************")
logging.info(f"\nCollected results: {json.dumps(results, indent=2)}")
# - reduction: compute the average
tasks = set(t for v in results.values() for t in v.keys())
# -- fill in the possible gaps
for att in results.keys():
for t in tasks:
if t not in results[att].keys():
results[att][t] = 0.0
# -- add the average value
for att in results.keys():
results[att]["AVG"] = round(sum(results[att][t] for t in tasks) / len(tasks), 2)
# - Format as an array, markdown style
tasks_sort = sorted(
set(t for v in results.values() for t in v.keys()), reverse=True
)
print(
"{0:<20}".format("") + "".join("{0:<20} ".format(t[:10]) for t in tasks_sort)
)
for att in results.keys():
print(
"{0:<20}".format(att)
+ "".join("{0:<20} ".format(results[att][t]) for t in tasks_sort)
)
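A quick sanity check of the gap-filling, averaging and table layout above, run on hypothetical results (the attention and task names below are made up):

results = {
    "nystrom": {"text": 0.61, "listops": 0.37},
    "favor": {"text": 0.64},
}
tasks = set(t for v in results.values() for t in v.keys())
for att in results:
    for t in tasks:
        results[att].setdefault(t, 0.0)  # fill the possible gaps
    results[att]["AVG"] = round(sum(results[att][t] for t in tasks) / len(tasks), 2)
tasks_sort = sorted(set(t for v in results.values() for t in v.keys()), reverse=True)
print("{0:<20}".format("") + "".join("{0:<20} ".format(t[:10]) for t in tasks_sort))
for att in results:
    print("{0:<20}".format(att) + "".join("{0:<20} ".format(results[att][t]) for t in tasks_sort))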

View File

@@ -0,0 +1,49 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
import argparse
import os
from pathlib import Path
from xformers.benchmarks.LRA.run_tasks import Task
from xformers.components.attention import ATTENTION_REGISTRY
def get_default_shared_folder() -> str:
checkpoint_paths = ["/checkpoint", "/checkpoints"]
for checkpoint_path in checkpoint_paths:
if Path(checkpoint_path).is_dir():
return checkpoint_path
return "."
if __name__ == "__main__":
default_checkpoint_path = get_default_shared_folder()
# Get the user requests
parser = argparse.ArgumentParser(
"Benchmark different attention mechanisms on various sequence lengths"
)
parser.add_argument("-c", "--config_path", required=True)
parser.add_argument("-ck", "--checkpoint_path", required=True)
parser.add_argument(
"-a", "--attentions", nargs="+", default=list(ATTENTION_REGISTRY.keys())
)
parser.add_argument("-t", "--tasks", nargs="+", default=[t.value for t in Task])
parser.add_argument(
"--partition", default="a100", type=str, help="Partition where to submit"
)
args = parser.parse_args()
for attention in args.attentions:
for task in args.tasks:
os.system(
"python3 run_with_submitit.py"
+ f" --attention {attention} --task {task} --config {args.config_path}"
+ f" --checkpoint_dir {args.checkpoint_path}/{attention}/{task}"
+ f" --partition {args.partition}"
)
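For a hypothetical invocation with --attentions nystrom --tasks text, a config at config.json and a checkpoint path of /checkpoint/me (partition left at its default), each inner iteration shells out a command of the form:

python3 run_with_submitit.py --attention nystrom --task text --config config.json --checkpoint_dir /checkpoint/me/nystrom/text --partition a100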

View File

@@ -0,0 +1,4 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

View File

@@ -0,0 +1,46 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
# CREDITS: Almost as-is from the Nystromformer repo
# https://github.com/mlpen/Nystromformer
import logging
import pickle
import torch
from torch.utils.data.dataset import Dataset
logging.getLogger().setLevel(logging.INFO)
class LRADataset(Dataset):
def __init__(self, file_path, seq_len):
with open(file_path, "rb") as f:
self.examples = pickle.load(f)
self.seq_len = seq_len
logging.info(f"Loaded {file_path}... size={len(self.examples)}")
def __len__(self):
return len(self.examples)
def __getitem__(self, i):
return self.create_inst(self.examples[i], self.seq_len)
@staticmethod
def create_inst(inst, seq_len):
output = {
"input_ids_0": torch.tensor(inst["input_ids_0"], dtype=torch.long)[:seq_len]
}
output["mask_0"] = (output["input_ids_0"] != 0).float()
if "input_ids_1" in inst:
output["input_ids_1"] = torch.tensor(inst["input_ids_1"], dtype=torch.long)[
:seq_len
]
output["mask_1"] = (output["input_ids_1"] != 0).float()
output["label"] = torch.tensor(inst["label"], dtype=torch.long)
return output
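A minimal sketch of what create_inst produces for a single-input task; the instance below is made up (real ones come from the pickled LRA datasets), with 0 acting as the padding id:

import torch

inst = {"input_ids_0": [5, 9, 2, 0, 0, 0], "label": 1}
out = LRADataset.create_inst(inst, seq_len=4)
assert out["input_ids_0"].tolist() == [5, 9, 2, 0]  # truncated to seq_len
assert out["mask_0"].tolist() == [1.0, 1.0, 1.0, 0.0]  # non-padding positions
assert out["label"].item() == 1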

View File

@@ -0,0 +1,288 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
# CREDITS: adapted from the Nystromformer repo
# https://github.com/mlpen/Nystromformer
from enum import Enum
from typing import Dict, Union
import pytorch_lightning as pl
import torch
import torch.nn as nn
from xformers.components import build_attention
from xformers.components.multi_head_dispatch import MultiHeadDispatchConfig
from xformers.factory import xFormer, xFormerConfig, xFormerEncoderConfig
from xformers.utils import generate_matching_config
PLOutput = Dict[str, Union[float, torch.Tensor]]
class Pooling(str, Enum):
MEAN = "mean"
CLS = "cls"
def pooling(mode: Pooling):
def pool_cls(inp):
return inp[:, 0, :]
def pool_mean(inp):
return inp.mean(dim=1)
return {Pooling.MEAN: pool_mean, Pooling.CLS: pool_cls}[mode]
def append_cls(inp, mask, vocab_size):
batch_size = inp.size(0)
cls_id = (
(vocab_size - 1) * torch.ones(batch_size, dtype=torch.long, device=inp.device)
).long()
cls_mask = torch.ones(batch_size, dtype=torch.float, device=mask.device)
inp = torch.cat([cls_id[:, None], inp[:, :-1]], dim=-1)
mask = torch.cat([cls_mask[:, None], mask[:, :-1]], dim=-1)
return inp, mask
def patch_model_config(config, attention_name):
# Rebuild a specific config out of generic + extra params
commons = config["common"]
try:
extra_attention_settings = config["extra_settings"]["attention"][attention_name]
except KeyError:
extra_attention_settings = None
for bc in config["xformer"]:
bc["dim_model"] = commons["dim_model"]
bc["position_encoding_config"].update(commons)
bc["feedforward_config"].update(commons)
bc["multi_head_config"].update(commons)
bc["multi_head_config"]["attention"].update(commons)
bc["multi_head_config"]["attention"]["name"] = attention_name
bc["multi_head_config"]["attention"]["dim_head"] = (
commons["dim_model"] / commons["num_heads"]
)
if extra_attention_settings is not None:
bc["multi_head_config"]["attention"].update(extra_attention_settings)
bc["multi_head_config"] = generate_matching_config(
bc["multi_head_config"], MultiHeadDispatchConfig
)
bc["multi_head_config"].attention = build_attention(
bc["multi_head_config"].attention
)
bc = generate_matching_config(bc, xFormerEncoderConfig)
return config
class SCHead(nn.Module):
def __init__(self, config, dim_embedding, dim_mlp):
super().__init__()
self.pooling = pooling(Pooling(config["pooling_mode"]))
self.mlpblock = nn.Sequential(
nn.Linear(dim_embedding, dim_mlp),
nn.ReLU(),
nn.Linear(dim_mlp, config["common"]["num_classes"]),
)
def forward(self, inp: torch.Tensor):
seq_score = self.mlpblock(self.pooling(inp))
return seq_score
class SCHeadDual(nn.Module):
def __init__(self, config, dim_embedding, dim_mlp):
super().__init__()
self.pooling = pooling(Pooling(config["pooling_mode"]))
self.mlpblock = nn.Sequential(
nn.Linear(
dim_embedding * 4,
dim_mlp,
),
nn.ReLU(),
nn.Linear(dim_mlp, config["common"]["num_classes"]),
)
def forward(self, inp_0: torch.Tensor, inp_1: torch.Tensor):
X_0 = self.pooling(inp_0)
X_1 = self.pooling(inp_1)
seq_score = self.mlpblock(torch.cat([X_0, X_1, X_0 * X_1, X_0 - X_1], dim=-1))
return seq_score
class ModelTrunk(pl.LightningModule):
def __init__(self, config, model_name):
super().__init__()
config_model = config["model"]
self.config_training = config["training"]
self.enable_amp = config["training"]["mixed_precision"]
self.pooling_mode = Pooling(config_model["pooling_mode"])
self.vocab_size = config_model["common"]["vocab_size"]
# Rebuild a specific config out of generic + extra params
self.config_model = patch_model_config(config_model, model_name)
self.model = xFormer.from_config(xFormerConfig(config_model["xformer"]))
self.norm = nn.LayerNorm(self.config_model["common"]["dim_model"])
ff_config = self.config_model["xformer"][0]["feedforward_config"]
self.dim_mlp = (
self.config_model["common"]["dim_model"]
* ff_config["hidden_layer_multiplier"]
)
def training_step( # type: ignore
self, batch: Dict[str, torch.Tensor], batch_idx: int
) -> PLOutput:
outputs = self(**batch)
self.logger.log_metrics({f"train_{k}": v for k, v in outputs.items()}) # type: ignore
self.log("train_accu", outputs["accu"], sync_dist=True)
return outputs
def training_epoch_end(self, outputs):
logs = self.eval_epoch_end(outputs)
self.log("train_accu_mean", logs["accu"], sync_dist=True)
def configure_optimizers(self):
optimizer = torch.optim.AdamW(
self.parameters(),
lr=self.config_training["learning_rate"],
betas=(0.9, 0.999),
eps=1e-6,
weight_decay=self.config_training["weight_decay"],
)
lr_scheduler = torch.optim.lr_scheduler.OneCycleLR(
optimizer=optimizer,
max_lr=self.config_training["learning_rate"],
pct_start=self.config_training["warmup"]
/ self.config_training["num_train_steps"],
anneal_strategy=self.config_training["lr_decay"],
total_steps=self.config_training["num_train_steps"],
)
return [optimizer], [lr_scheduler]
def eval_step(self, batch: Dict[str, torch.Tensor], batch_idx: int) -> PLOutput:
outputs = self(**batch)
return outputs
def eval_epoch_end(self, outputs, prefix: str = "train"):
logs = {}
counts = torch.tensor([x["count"] for x in outputs]).float()
logs["count"] = counts.sum()
for k in ("accu", "loss"):
logs[k] = (torch.tensor([x[k] for x in outputs]) * counts).sum() / logs[
"count"
]
self.log(f"{prefix}_{k}_mean", logs[k], sync_dist=True)
return logs
def validation_step( # type: ignore
self, batch: Dict[str, torch.Tensor], batch_idx: int
) -> PLOutput:
outputs = self.eval_step(batch, batch_idx)
self.logger.log_metrics({f"val_{k}": v for k, v in outputs.items()}) # type: ignore
self.log("val_accu", outputs["accu"], sync_dist=True, prog_bar=True)
return outputs
def validation_epoch_end(self, outputs):
self.eval_epoch_end(outputs, prefix="val")
def test_step( # type: ignore
self, batch: Dict[str, torch.Tensor], batch_idx: int
) -> PLOutput:
return self.eval_step(batch, batch_idx)
def test_epoch_end(self, outputs):
self.eval_epoch_end(outputs, prefix="test")
class ModelForSC(ModelTrunk):
def __init__(self, config, model_name):
# Setup trunk
super().__init__(config, model_name)
self.seq_classifer = SCHead(
self.config_model,
dim_embedding=self.config_model["common"]["dim_model"],
dim_mlp=self.dim_mlp,
)
def forward( # type: ignore
self, input_ids_0: torch.Tensor, mask_0: torch.Tensor, label: torch.Tensor
):
if self.pooling_mode == Pooling.CLS:
input_ids_0, mask_0 = append_cls(input_ids_0, mask_0, self.vocab_size)
token_out = self.norm(
self.model(input_ids_0, encoder_input_mask=mask_0)
) * mask_0.unsqueeze(-1)
seq_scores = self.seq_classifer(token_out)
seq_loss = torch.nn.CrossEntropyLoss(reduction="none")(seq_scores, label)
seq_accu = (seq_scores.argmax(dim=-1) == label).to(torch.float32)
outputs = {
"loss": seq_loss.mean(),
"accu": seq_accu.mean(),
"count": label.size(0),
}
return outputs
class ModelForSCDual(ModelTrunk):
def __init__(self, config, model_name):
# Setup trunk
super().__init__(config, model_name)
self.seq_classifer = SCHeadDual(
self.config_model,
dim_embedding=self.config_model["common"]["dim_model"],
dim_mlp=self.dim_mlp,
)
def forward( # type: ignore
self,
input_ids_0: torch.Tensor,
input_ids_1: torch.Tensor,
mask_0: torch.Tensor,
mask_1: torch.Tensor,
label: torch.Tensor,
):
mask_0, mask_1 = mask_0.long(), mask_1.long()
if self.pooling_mode == Pooling.CLS:
input_ids_0, mask_0 = append_cls(input_ids_0, mask_0, self.vocab_size)
input_ids_1, mask_1 = append_cls(input_ids_1, mask_1, self.vocab_size)
# Concatenate the two inputs into one batch
input_ids = torch.cat([input_ids_0, input_ids_1], dim=0)
masks = torch.cat([mask_0, mask_1], dim=0)
tokens_out = self.norm(
self.model(input_ids, encoder_input_mask=masks)
) * masks.unsqueeze(-1)
seq_scores = self.seq_classifer(*torch.chunk(tokens_out, 2, dim=0))
seq_loss = torch.nn.CrossEntropyLoss(reduction="none")(seq_scores, label)
seq_accu = (seq_scores.argmax(dim=-1) == label).to(torch.float32)
outputs = {
"loss": seq_loss.mean(),
"accu": seq_accu.mean(),
"count": label.size(0),
}
return outputs
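For intuition, append_cls above reserves the last vocabulary id as the CLS token, prepends it, and drops the final position so the sequence length is unchanged. A small sketch with a hypothetical vocab_size of 10:

import torch

inp = torch.tensor([[3, 4, 5, 6]])
mask = torch.tensor([[1.0, 1.0, 0.0, 0.0]])
inp2, mask2 = append_cls(inp, mask, vocab_size=10)
assert inp2.tolist() == [[9, 3, 4, 5]]  # CLS id 9 prepended, last token dropped
assert mask2.tolist() == [[1.0, 1.0, 1.0, 0.0]]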

View File

@@ -0,0 +1,148 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
import itertools
import os
import uuid
from datetime import date
from pathlib import Path
from typing import Dict, Iterable
import submitit
from xformers.benchmarks.LRA.run_with_submitit import (
Trainer,
get_init_file,
get_shared_folder,
parse_args,
)
def grid_parameters(grid: Dict):
"""
Yield all combinations of parameters in the grid (as a dict)
"""
grid_copy = dict(grid)
# Turn single values into Iterables
for k in grid_copy:
if not isinstance(grid_copy[k], Iterable):
grid_copy[k] = [grid_copy[k]]
for p in itertools.product(*grid_copy.values()):
yield dict(zip(grid.keys(), p))
def grid_search(args):
if args.checkpoint_dir == "":
args.checkpoint_dir = get_shared_folder() / "%j"
date_curr = date.today().strftime("%m-%d-%Y")
orig_check_dir = os.path.join(args.checkpoint_dir, date_curr)
# Create the executor
# Note that the folder will depend on the job_id, to easily track experiments
executor = submitit.AutoExecutor(
folder=get_shared_folder() / "%j", slurm_max_num_timeout=30
)
num_gpus_per_node = args.ngpus
nodes = args.nodes
args.world_size = args.nodes * args.ngpus
partition = args.partition
executor.update_parameters(
gpus_per_node=num_gpus_per_node,
tasks_per_node=num_gpus_per_node, # one task per GPU
cpus_per_task=10,
nodes=nodes,
timeout_min=60 * 72,
slurm_signal_delay_s=120,
slurm_partition=partition,
)
executor.update_parameters(name="lra")
if args.task == "text":
grid_meta = {
"training:learning_rate": (
[1e-4, 2e-4, 3e-4, 5e-5],
lambda val: f"lr{val}",
),
"training:warmup": ([3000, 8000], lambda val: f"warmup{val}"),
"training:seed": ([1234, 32, 1994], lambda val: f"seed{val}"),
"training:weight_decay": ([0.02, 0.05, 0.01], lambda val: f"wd{val}"),
"model:pooling_model": (["cls"], lambda val: f"pool-{val}"),
"model:common:dropout": ([0, 0.05], lambda val: f"drop{val}"),
}
elif args.task == "retrieval":
grid_meta = {
"training:learning_rate": ([1e-4, 3e-4], lambda val: f"lr{val}"),
"training:warmup": ([2000, 8000], lambda val: f"warmup{val}"),
"training:seed": ([4096, 1234, 3, 15, 5], lambda val: f"seed{val}"),
"training:weight_decay": ([0.01, 0], lambda val: f"wd{val}"),
"model:pooling_model": (["cls"], lambda val: f"pool-{val}"),
"model:common:dropout": ([0], lambda val: f"drop{val}"),
}
elif args.task == "listops":
grid_meta = {
"training:learning_rate": (
[1e-4, 2e-4, 3e-4, 5e-5],
lambda val: f"lr{val}",
),
"training:warmup": ([3000, 2000], lambda val: f"warmup{val}"),
"training:seed": (
[
1234,
],
lambda val: f"seed{val}",
),
"training:weight_decay": ([0.02, 0.05, 0, 1], lambda val: f"wd{val}"),
"model:pooling_model": (["cls"], lambda val: f"pool-{val}"),
"model:common:dropout": ([0], lambda val: f"drop{val}"),
}
else:
grid_meta = {
"training:learning_rate": ([1e-4, 5e-5], lambda val: f"lr{val}"),
"training:warmup": ([8000], lambda val: f"warmup{val}"),
"training:seed": ([1234, 4321, 3], lambda val: f"seed{val}"),
"training:weight_decay": ([0.01], lambda val: f"wd{val}"),
"model:pooling_model": (["cls"], lambda val: f"pool-{val}"),
"model:common:dropout": ([0.1], lambda val: f"drop{val}"),
}
grid = {k: v[0] for k, v in grid_meta.items()}
save_key = {k: v[1] for k, v in grid_meta.items()}
hyper_parameters = list(grid_parameters(grid))
jobs = []
for i, grid_data in enumerate(hyper_parameters):
args.sweep_parameters = grid_data
run_name = f"{args.attention}"
# run_name = "paper_config"
for k, v in grid_data.items():
run_name += "prenorm-" + save_key[k](v)
args.checkpoint_dir = os.path.join(
orig_check_dir, f"{args.task}", "logs", run_name
)
Path(args.checkpoint_dir).mkdir(parents=True, exist_ok=True)
args.tb_dir = os.path.join(orig_check_dir, f"{args.task}", "tb", run_name)
Path(args.tb_dir).mkdir(parents=True, exist_ok=True)
# Chronos needs a different job name each time
executor.update_parameters(name=f"lra_{args.task}_{i:02d}_{uuid.uuid4().hex}")
args.dist_url = get_init_file().as_uri()
args.temp_file = str(get_init_file())
trainer = Trainer(args)
job = executor.submit(trainer)
jobs.append(job)
print(f"Run {i:02d} submitted with train cfg: {args}")
print(f"Submitted jobs ids: {','.join([str(job.job_id) for job in jobs])}")
if __name__ == "__main__":
args = parse_args()
grid_search(args)
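For reference, grid_parameters above performs a plain Cartesian expansion of the value lists; a small sketch with hypothetical sweep values:

grid = {"training:learning_rate": [1e-4, 2e-4], "training:seed": [1234]}
assert list(grid_parameters(grid)) == [
    {"training:learning_rate": 1e-4, "training:seed": 1234},
    {"training:learning_rate": 2e-4, "training:seed": 1234},
]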

View File

@@ -0,0 +1,298 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
import argparse
import json
import logging
import os
from enum import Enum
from pathlib import Path
from typing import Dict, Tuple, cast
import pytorch_lightning as pl
import torch
import torch.nn as nn
from fvcore.nn import FlopCountAnalysis, flop_count_str
from pytorch_lightning.callbacks import ModelCheckpoint, TQDMProgressBar
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.strategies import DDPStrategy
from torch.utils.data import DataLoader
from xformers.benchmarks.LRA.code.dataset import LRADataset
from xformers.benchmarks.LRA.code.model_wrapper import ModelForSC, ModelForSCDual
from xformers.components.attention import ATTENTION_REGISTRY
class Task(str, Enum):
Retrieval = "retrieval"
ListOps = "listops"
Image = "image"
PathfinderBaseline = "pathfinder32-curv_baseline"
PathfinderContour9 = "pathfinder32-curv_contour_length_9"
PathfinderContour14 = "pathfinder32-curv_contour_length_14"
Text = "text"
def load_config(path: str) -> Dict:
with open(Path(path).absolute(), "r") as fileio:
config = json.load(fileio)
# Duplicate the pathfinder configs
config["pathfinder32-curv_baseline"] = config["pathfinder32"]
config["pathfinder32-curv_contour_length_9"] = config["pathfinder32"]
config["pathfinder32-curv_contour_length_14"] = config["pathfinder32"]
return config
def build_model(args: argparse.Namespace, config: Dict) -> nn.Module:
task = args.task
attention_name = args.attention
model = cast(
pl.LightningModule,
ModelForSCDual(config[f"{task}"], attention_name)
if task == Task.Retrieval
else ModelForSC(config[f"{task}"], attention_name),
)
logging.info(model)
summary = pl.utilities.model_summary.LayerSummary(model)
logging.info(f"num_parameter: {summary.num_parameters // 1e3 / 1e3}M")
with torch.no_grad():
# Check the flops
seq_len = config[f"{task}"]["model"]["common"]["seq_len"]
x = torch.rand(1, seq_len).long()
mask = torch.rand(1, seq_len).long()
indices = torch.rand(1, seq_len).long()
flops = FlopCountAnalysis(model.model, (x, mask, indices))
logging.info(f"complexity: {round(flops.total()/1e9, 3)} GFlops")
logging.info(flop_count_str(flops))
return model
def get_arg_parser():
parser = argparse.ArgumentParser()
parser.add_argument(
"--attention",
type=str,
help=f"Attention mechanism to chose, among {list(ATTENTION_REGISTRY.keys())}. \
A list can be passed to test several mechanisms in sequence",
dest="attention",
required=True,
)
parser.add_argument(
"--task",
type=Task,
help=f"Task to chose, among {[t.value for t in Task]}.",
dest="task",
required=True,
)
parser.add_argument(
"--skip_train",
type=bool,
help="Whether to skip training, and test an existing model",
dest="skip_train",
default=False,
)
parser.add_argument(
"--config",
type=str,
help="Path to the config being used",
dest="config",
default="./config.json",
)
parser.add_argument(
"--checkpoint_dir",
type=str,
help="Path to the checkpoint directory",
dest="checkpoint_dir",
default=f"/checkpoints/{os.getenv('USER')}/xformers",
)
parser.add_argument(
"--checkpoint_path",
type=str,
help="Path to checkpoint",
)
parser.add_argument(
"--debug",
help="Make it easier to debug a possible issue",
dest="debug",
default=False,
action="store_true",
)
parser.add_argument(
"--world_size",
help="Number of GPUs used",
dest="world_size",
type=int,
default=1,
)
parser.add_argument(
"--sweep_parameters",
help="Rewrite some hyperparameters in the config",
dest="sweep_parameters",
type=dict,
default=None,
)
return parser
def setup_log(args, attention_name, task) -> Tuple[str, TensorBoardLogger]:
experiment_name = f"{task}__{attention_name}"
logger = TensorBoardLogger(
save_dir=args.checkpoint_dir,
name="", # remove lightning_logs subdirectory
version=experiment_name,
)
log_dir = os.path.join(logger._save_dir, experiment_name)
return log_dir, logger
def rewrite_hyper(config, rewrites):
def replace(config_dict, k, v):
if len(k.split(":")) == 1:
config_dict[k] = v
return
first_key = k.split(":")[0]
assert first_key in config_dict, first_key
k = k[len(first_key) + 1 :]
replace(config_dict[first_key], k, v)
for k, v in rewrites.items():
replace(config, k, v)
return config
def build_dataloaders(
args: argparse.Namespace,
config_training: Dict,
num_workers: int = 4,
) -> Dict[str, DataLoader]:
datasets = {}
for component in ("train", "dev", "test"):
datasets[component] = LRADataset(
file_path=f"datasets/{args.task}.{component}.pickle",
seq_len=config_training["seq_len"],
)
# Gradient accumulation
accumu_steps = config_training["gradient_accumulation"]
logging.info(f"accumu_steps={accumu_steps}")
# Batch size
per_gpu_batch_size = (
config_training["batch_size"] // args.world_size // accumu_steps
)
logging.warning(
f"Requested batch size: {config_training['batch_size']}. Given world\
size and grad accumulation, per-gpu batch is\
{per_gpu_batch_size}"
)
dataloaders = {
k: DataLoader(
v,
batch_size=per_gpu_batch_size,
shuffle=False,
pin_memory=True,
num_workers=num_workers,
)
for k, v in datasets.items()
}
return dataloaders
def get_eval_summary(trainer: pl.Trainer) -> Dict[str, float]:
eval_summary: Dict[str, float] = {"train_step_idx": trainer.global_step}
for k, v in trainer.callback_metrics.items():
eval_summary[k] = v.item()
return eval_summary
class BasicProgressBar(TQDMProgressBar):
def get_metrics(self, trainer, model):
items = super().get_metrics(trainer, model)
items.pop("v_num", None)
return items
def benchmark(args):
log_dir, logger = setup_log(args, f"{args.attention}", f"{args.task}")
args.logger = logger
config = load_config(args.config)
config_task = config[f"{args.task}"]
if args.sweep_parameters is not None:
logging.info("Replacing hyperparameters")
rewrite_hyper(config_task, args.sweep_parameters)
config_training = config_task["training"]
config_training["seq_len"] = config_task["model"]["common"]["seq_len"]
logging.info(f"Learning rate: {config_training['learning_rate']}")
pl.seed_everything(config_training.get("seed", 0))
dataloaders = build_dataloaders(args, config_training)
model = build_model(args, config)
progress_bar = BasicProgressBar()
checkpoint_callback = ModelCheckpoint(
monitor="val_accu",
mode="max",
dirpath=args.checkpoint_dir,
filename="{epoch}-{val_accu:.2f}",
every_n_train_steps=config_training["eval_frequency"],
)
trainer = pl.Trainer(
accelerator="gpu",
strategy=DDPStrategy(find_unused_parameters=args.debug)
if not args.skip_train
else None,
accumulate_grad_batches=config_training["gradient_accumulation"],
callbacks=[progress_bar, checkpoint_callback],
detect_anomaly=args.debug,
deterministic=True,
gpus=args.world_size,
limit_val_batches=config_training["num_eval_steps"],
logger=logger,
max_steps=config_training["num_train_steps"],
num_sanity_val_steps=int(not args.skip_train),
precision=16 if config_training["mixed_precision"] else 32,
val_check_interval=config_training["eval_frequency"]
/ float(len(dataloaders["train"])),
)
if not args.skip_train:
trainer.fit(
model,
train_dataloaders=dataloaders["train"],
val_dataloaders=dataloaders["dev"],
)
ckpt_path = checkpoint_callback.best_model_path
else:
ckpt_path = args.checkpoint_path
trainer.test(
model,
dataloaders=dataloaders["test"],
ckpt_path=ckpt_path,
)
eval_summary = get_eval_summary(trainer)
with open(os.path.join(log_dir, "test_eval_summary.json"), "w") as f:
logging.info(f"Saving test results at {f.name}")
json.dump(eval_summary, f)
if __name__ == "__main__":
parser = get_arg_parser()
args = parser.parse_args()
if args.skip_train and args.checkpoint_path is None:
parser.error("Must provide --checkpoint_path if --skip_train=True")
benchmark(args)
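The sweep entries consumed by rewrite_hyper use colon-separated keys to address nested config fields, and the effective per-GPU batch size is batch_size // world_size // gradient_accumulation (e.g. 256 // 8 // 2 = 16). A minimal sketch of the key rewriting:

config = {"training": {"learning_rate": 1e-4}, "model": {"common": {"dropout": 0.1}}}
rewrite_hyper(config, {"training:learning_rate": 3e-4, "model:common:dropout": 0.0})
assert config["training"]["learning_rate"] == 3e-4
assert config["model"]["common"]["dropout"] == 0.0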

View File

@@ -0,0 +1,153 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
"""
A script to run multinode training with submitit.
Almost copy-paste from https://github.com/facebookresearch/deit/blob/main/run_with_submitit.py
"""
import argparse
import os
import uuid
from pathlib import Path
import submitit
from xformers.benchmarks.LRA.run_tasks import benchmark, get_arg_parser
def parse_args():
parser = argparse.ArgumentParser(
"Submitit for LRA", parents=[get_arg_parser()], add_help=False
)
parser.add_argument(
"--ngpus", default=1, type=int, help="Number of gpus to request on each node"
)
parser.add_argument(
"--nodes", default=1, type=int, help="Number of nodes to request"
)
parser.add_argument("--timeout", default=2800, type=int, help="Duration of the job")
parser.add_argument(
"--partition", default="a100", type=str, help="Partition where to submit"
)
parser.add_argument(
"--use_volta32", action="store_true", help="Big models? Use this"
)
parser.add_argument(
"--enforce_host_memory", action="store_true", help="Use if the host OOMs"
)
parser.add_argument(
"--comment",
default="",
type=str,
help="Comment to pass to scheduler, e.g. priority message",
)
return parser.parse_args()
def get_shared_folder() -> Path:
user = os.getenv("USER")
checkpoint_paths = ["/checkpoint", "/checkpoints"]
for checkpoint_path in checkpoint_paths:
if Path(checkpoint_path).is_dir():
p = Path(f"{checkpoint_path}/{user}/xformers/submitit")
p.mkdir(exist_ok=True, parents=True)
return p
raise RuntimeError(f"No shared folder available - considering {checkpoint_paths}")
def get_init_file():
# The init file must not exist, but its parent directory must exist.
os.makedirs(str(get_shared_folder()), exist_ok=True)
init_file = get_shared_folder() / f"{uuid.uuid4().hex}_init"
if init_file.exists():
os.remove(str(init_file))
return init_file
class Trainer:
def __init__(self, args):
self.args = args
def __call__(self):
self._setup_gpu_args()
benchmark(self.args)
def checkpoint(self):
self.args.dist_url = get_init_file().as_uri()
print("Requeuing ", self.args)
empty_trainer = type(self)(self.args)
return submitit.helpers.DelayedSubmission(empty_trainer)
def _setup_gpu_args(self):
job_env = submitit.JobEnvironment()
self.args.checkpoint_dir = Path(
str(self.args.checkpoint_dir).replace("%j", str(job_env.job_id))
)
self.args.gpu = job_env.local_rank
self.args.rank = job_env.global_rank
self.args.world_size = job_env.num_tasks
print(f"Process group: {job_env.num_tasks} tasks, rank: {job_env.global_rank}")
def main():
args = parse_args()
if args.checkpoint_dir == "":
args.checkpoint_dir = get_shared_folder() / "%j"
Path(args.checkpoint_dir).mkdir(parents=True, exist_ok=True)
executor = submitit.AutoExecutor(
folder=args.checkpoint_dir, slurm_max_num_timeout=30
)
num_gpus_per_node = args.ngpus
nodes = args.nodes
timeout_min = args.timeout
args.world_size = args.nodes * args.ngpus
partition = args.partition
kwargs = {
"gpus_per_node": num_gpus_per_node,
"tasks_per_node": num_gpus_per_node, # one task per GPU
"cpus_per_task": 10,
"nodes": nodes,
"timeout_min": timeout_min, # max is 60 * 72
# Below are cluster dependent parameters
"slurm_partition": partition,
"slurm_signal_delay_s": 120,
}
if args.enforce_host_memory:
kwargs["mem_gb"] = (40 * num_gpus_per_node,)
if args.use_volta32:
kwargs["slurm_constraint"] = "volta32gb"
if args.comment:
kwargs["slurm_comment"] = args.comment
executor.update_parameters(
**kwargs,
)
executor.update_parameters(name="lra")
args.dist_url = get_init_file().as_uri()
args.temp_file = str(get_init_file())
trainer = Trainer(args)
job = executor.submit(trainer)
print(f"Submitted job_id: {job.job_id}")
print(f"Logs and checkpoints will be saved at: {args.checkpoint_dir}")
with open(Path(f"{args.checkpoint_dir}") / Path("jobs.txt"), "a") as jobfile:
jobfile.write(f"{job.job_id}\n")
if __name__ == "__main__":
main()
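The submission flow above is standard submitit usage; a stripped-down sketch of the same pattern (the partition name is a placeholder):

import submitit

def work(x: int) -> int:
    return x * 2

executor = submitit.AutoExecutor(folder="submitit_logs/%j")  # %j becomes the job id
executor.update_parameters(timeout_min=10, slurm_partition="dev", gpus_per_node=1)
job = executor.submit(work, 21)
print(job.job_id)    # assigned by the scheduler
print(job.result())  # blocks until the job finishes -> 42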

View File

@@ -0,0 +1,4 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

View File

@@ -0,0 +1,159 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
import itertools
from typing import Any
import torch
from torch.utils import benchmark
from utils import benchmark_main_helper
import xformers.ops as xops
min_run_time = 0.5
device = torch.device("cuda")
def product_dict(**kwargs):
keys = kwargs.keys()
vals = kwargs.values()
for instance in itertools.product(*vals):
yield dict(zip(keys, instance))
CASES = [
dict(B=max(1, 2 ** (16 - i)), Mq=1, Mkv=2**i, Hq=16, Hkv=1, K=128)
for i in range(8, 18)
] + [
dict(B=max(1, 2 ** (16 - i)), Mq=1, Mkv=2**i, Hq=16, Hkv=2, K=128)
for i in range(8, 18)
]
def _setup_test(
functions, fw: bool = False, bw: bool = False, cuda_graph: bool = True, **kwargs
):
for k, benchmark_cls in functions.items():
benchmark_object = benchmark_cls(**kwargs, bw=bw)
label = benchmark_object.label
label += "fw" if fw else ""
label += "bw" if bw else ""
def run_one():
if fw:
benchmark_object.fw()
if bw:
benchmark_object.bw()
if cuda_graph:
run_one()
g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g):
run_one()
def run_one():
g.replay()
yield benchmark.Timer(
stmt="fn()",
globals={
"fn": run_one,
},
label=label,
description=k,
sub_label=benchmark_object.sub_label,
)
class AttentionDecodingFlashDecoding:
OP: Any = xops.fmha.flash.FwOp
def __init__(
self, B: int, Mq: int, Mkv: int, Hq: int, Hkv: int, K: int, bw: bool
) -> None:
dtype = torch.float16
self.sub_label = f"B={B} Mq={Mq} Mkv={Mkv} Hq={Hq} Hkv={Hkv} K={K}"
self.label = "attn_decoding"
self.shapes = (B, Mq, Mkv, Hq, Hkv, K)
assert Hkv <= Hq
assert Hq % Hkv == 0
self.q = torch.randn(
[B, Mq, Hkv, Hq // Hkv, K], device="cuda", dtype=dtype, requires_grad=bw
)
self.k = torch.randn(
[B, Mkv, Hkv, 1, K], device="cuda", dtype=dtype, requires_grad=bw
).expand(-1, -1, -1, Hq // Hkv, -1)
self.v = torch.randn(
[B, Mkv, Hkv, 1, K], device="cuda", dtype=dtype, requires_grad=bw
).expand(-1, -1, -1, Hq // Hkv, -1)
if Hq == Hkv:
self.q = self.q[:, :, :, 0]
self.k = self.k[:, :, :, 0]
self.v = self.v[:, :, :, 0]
if Hkv == 1:
self.q = self.q[:, :, 0]
self.k = self.k[:, :, 0]
self.v = self.v[:, :, 0]
def fw(self) -> None:
xops.memory_efficient_attention_forward(self.q, self.k, self.v, op=self.OP)
class AttentionDecodingSplitKV(AttentionDecodingFlashDecoding):
OP = xops.fmha.triton_splitk.FwOp
class AttentionDecodingPyTorchRepeat(AttentionDecodingFlashDecoding):
def fw(self) -> None:
B, Mq, Mkv, Hq, Hkv, K = self.shapes
scale = 1 / K**0.5
q = self.q.reshape([B, Mq, -1, K]).permute(0, 2, 1, 3)
k = self.k.reshape([B, Mkv, -1, K]).permute(0, 2, 1, 3)
v = self.v.reshape([B, Mkv, -1, K]).permute(0, 2, 1, 3)
attn = (q @ k.transpose(-1, -2) * scale).softmax(-1)  # scale the logits before the softmax
return attn @ v
BENCHMARKS = {
"pytorch": AttentionDecodingPyTorchRepeat,
"flash-decoding": AttentionDecodingFlashDecoding,
"triton_splitK": AttentionDecodingSplitKV,
}
try:
import flash_attn
class AttentionDecodingFlashAttention(AttentionDecodingFlashDecoding):
def fw(self) -> None:
q, k, v = self.q, self.k, self.v
if q.ndim == 5:
B, Mq, H1, H2, K = q.shape
B, Mkv, H1, H2, K = k.shape
q = q.reshape([B, Mq, H1 * H2, K])
k = k[:, :, :, 0]
v = v[:, :, :, 0]
return flash_attn.flash_attn_func(q, k, v)
BENCHMARKS[
f"flash-attention@{flash_attn.__version__}"
] = AttentionDecodingFlashAttention
except ImportError:
pass
def attn_decoding(**kwargs):
yield from _setup_test(
**kwargs,
fw=True,
cuda_graph=True,
functions=BENCHMARKS,
)
benchmark_main_helper(attn_decoding, CASES, min_run_time=min_run_time)
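The grouped-query layout above relies on expand so that several query heads share one KV head without copying; a quick sketch of that property:

import torch

B, Mkv, Hkv, K = 2, 16, 2, 8
groups = 4  # Hq // Hkv query heads per KV head
k = torch.randn(B, Mkv, Hkv, 1, K)
k_exp = k.expand(-1, -1, -1, groups, -1)
assert k_exp.shape == (B, Mkv, Hkv, groups, K)
assert k_exp.data_ptr() == k.data_ptr()  # a view: no extra memory for shared heads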

File diff suppressed because it is too large

View File

@@ -0,0 +1,137 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
import os
from typing import Any, Dict
import torch
import triton
from xformers.benchmarks.utils import TestCase, pretty_plot, pretty_print
from xformers.components.attention.attention_mask import AttentionMask
from xformers.components.attention.core import scaled_dot_product_attention
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
SHAPES = [
(8, 128, 2096),
(8, 1024, 256),
(12, 512, 1024),
(128, 128, 512),
(8, 2048, 4096),
(16, 1024, 5120),
(512, 128, 2560),
]
BLOCK_SIZES = [128]
N_HEADS = [8, 32]
def bench_blocksparse_compare(backward: bool):
device = torch.device("cuda")
bw = "+bw" if backward else ""
use_amp = True
_use_cuda = True
for dtype in [torch.float16, torch.float32]:
datatype = "fp16" if dtype == torch.float16 else "fp32"
results: Dict[str, Any] = {}
results_mem: Dict[str, Any] = {}
for BS in BLOCK_SIZES:
for heads in N_HEADS:
for B, M, K in SHAPES:
q = torch.randn(
(B, heads, M, K // heads),
requires_grad=backward,
device=device,
dtype=dtype,
)
k = q
v = q
# Mask with causal flag
m_att_mask = AttentionMask.make_causal(
M, M, device=device, dtype=dtype
)
# Custom causal tensor mask
m_custom = torch.triu(
torch.ones(M, M, device=device, dtype=dtype) * float("-inf"),
diagonal=1,
)
def blocksparse_attention():
with torch.cuda.amp.autocast(enabled=use_amp):
y = scaled_dot_product_attention(
q=q, k=k, v=v, att_mask=m_att_mask, block_size=BS
)
if backward:
torch.norm(y).backward()
return y
def sdp_attention():
with torch.cuda.amp.autocast(enabled=use_amp):
y = scaled_dot_product_attention(
q=q, k=k, v=v, att_mask=m_custom, block_size=BS
)
if backward:
torch.norm(y).backward()
return y
for testcase in [
TestCase(blocksparse_attention, f"blocksparse - fw{bw}"),
TestCase(sdp_attention, f"standard sdp - fw{bw}"),
]:
if _use_cuda:
torch.cuda.empty_cache()
torch.cuda.reset_peak_memory_stats()
torch.cuda.synchronize()
time = triton.testing.do_bench(testcase.function)[0]
if _use_cuda:
torch.cuda.synchronize()
max_memory = torch.cuda.max_memory_allocated() / 2**20
else:
max_memory = -1
key = f"B={B},M={M},K={K},NH={heads}"
if key not in results_mem:
results_mem[key] = {}
results_mem[key][testcase.name] = f"{max_memory:.1f}"
if key not in results:
results[key] = {}
results[key][testcase.name] = f"{time:.2f}"
pretty_print(
results,
title=f"\n --- Type: {datatype} Block Size: {BS} --- ",
units="runtime in ms",
)
pretty_print(
results_mem,
title=f"\n --- Type: {datatype} Block Size: {BS} --- ",
units="peak memory usage in MB",
)
pretty_plot(
results,
title=f"Causal Blocksparse Runtime FW{bw.upper()} {datatype} Blocksize:{BS}",
units="runtime in ms",
dash_key="torch",
legend_loc="upper left",
)
pretty_plot(
results_mem,
title=f"Causal Blocksparse Memory FW{bw.upper()} {datatype} Blocksize:{BS}",
units="peak memory usage in MB",
dash_key="torch",
legend_loc="upper left",
)
for bw in [False, True]:
bench_blocksparse_compare(bw)
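For reference, the custom tensor mask used above encodes causality with -inf above the diagonal, which the softmax turns into exactly zero attention weight:

import torch

M = 4
m = torch.triu(torch.ones(M, M) * float("-inf"), diagonal=1)
weights = (torch.zeros(M, M) + m).softmax(-1)
assert torch.allclose(weights[1], torch.tensor([0.5, 0.5, 0.0, 0.0]))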

View File

@@ -0,0 +1,258 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
import itertools
import torch
from torch.utils import benchmark
from xformers.components.attention.core import (
SparseCS,
_create_random_sparsity,
_matmul_with_mask,
_softmax,
bmm,
)
MIN_RUN_TIME = 1
SHAPES = [[8, 8], [256, 1024], [128, 256]]
SPARSITIES = [0.5, 0.8, 0.9, 0.95, 0.99]
def bench_sddmm():
min_run_time = MIN_RUN_TIME
SPARSITIES = [0.95, 0.98, 0.99, 0.995, 0.999]
device = torch.device("cuda")
results = []
for B, M, K in zip(*SHAPES):
a = torch.rand(B, M, K, device=device)
b = torch.rand(B, M, K, device=device)
for backend, prob in itertools.product(
["coo_pytorch", "csr_sputnik", "csr_ge"], SPARSITIES
):
mask = _create_random_sparsity(torch.ones(B, M, M, dtype=torch.bool), prob)
aa = a
bb = b
if "csr" in backend:
mask = SparseCS(mask, device)
aa = a
bb = b
row_indices = mask.row_indices
row_offsets = mask.row_offsets
column_indices = mask.column_indices
if "_ge" in backend:
fn = torch.ops.xformers.csr_sddmm
else:
fn = torch.ops.xformers.sddmm_sputnik
fn_str = "fn(a, b, row_indices, row_offsets, column_indices)"
else:
mask = mask.to_sparse().to(device)
_, row_offsets, column_indices = mask.indices().int().unbind()
row_offsets = row_offsets.contiguous()
column_indices = column_indices.contiguous()
row_indices = row_offsets
bb = b.transpose(-2, -1)
fn = _matmul_with_mask
fn_str = "fn(a, b, mask)"
results.append(
benchmark.Timer(
stmt=fn_str,
globals={
"a": aa,
"b": bb,
"mask": mask,
"row_indices": row_indices,
"row_offsets": row_offsets,
"column_indices": column_indices,
"fn": fn,
},
label="sddmm",
sub_label=f"sparsity {backend}: {prob:0.4f}",
description=f"B={B}, M={M}, K={K}",
).blocked_autorange(min_run_time=min_run_time)
)
compare = benchmark.Compare(results)
compare.print()
def bench_matmul_with_mask():
min_run_time = MIN_RUN_TIME
prob = 0.9
device = torch.device("cuda")
results = []
for B, M, K in zip(*SHAPES):
a = torch.rand(B, M, K, device=device)
b = torch.rand(B, K, M, device=device)
mask = torch.rand(B, M, M, device=device) > prob
results.extend(
[
benchmark.Timer(
stmt="_matmul_with_mask(a, b, mask)",
globals={
"a": a,
"b": b,
"mask": None,
"_matmul_with_mask": _matmul_with_mask,
},
label="matmul_with_mask",
sub_label="dense",
description=f"B={B}, M={M}, K={K}",
).blocked_autorange(min_run_time=min_run_time),
benchmark.Timer(
stmt="_matmul_with_mask(a, b, mask)",
globals={
"a": a,
"b": b,
"mask": mask,
"_matmul_with_mask": _matmul_with_mask,
},
label="matmul_with_mask",
sub_label="dense with masking",
description=f"B={B}, M={M}, K={K}",
).blocked_autorange(min_run_time=min_run_time),
]
)
for sputnik, prob in itertools.product([False, True], SPARSITIES):
mask = _create_random_sparsity(
torch.ones(B, M, M, dtype=torch.bool, device=device), prob
)
aa = a
bb = b
if sputnik:
mask = SparseCS(mask, device)
aa = a
bb = b.transpose(-2, -1).contiguous().transpose(-2, -1)
else:
mask = mask.to_sparse()
results.append(
benchmark.Timer(
stmt="_matmul_with_mask(a, b, mask)",
globals={
"a": aa,
"b": bb,
"mask": mask,
"_matmul_with_mask": _matmul_with_mask,
},
label="matmul_with_mask",
sub_label=f"sparsity {'sputnik' if sputnik else 'pytorch'}: {prob:0.2f}",
description=f"B={B}, M={M}, K={K}",
).blocked_autorange(min_run_time=min_run_time)
)
compare = benchmark.Compare(results)
compare.print()
def bench_softmax():
min_run_time = MIN_RUN_TIME
prob = 0.9
device = torch.device("cuda")
results = []
for B, M, K in zip(*SHAPES):
a = torch.rand(B, M, M, device=device)
a[a < prob] = 0
results.extend(
[
benchmark.Timer(
stmt="_softmax(a)",
globals={
"a": a,
"_softmax": _softmax,
},
label="softmax",
sub_label="dense",
description=f"B={B}, M={M}, K={K}",
).blocked_autorange(min_run_time=min_run_time),
]
)
for sputnik, prob in itertools.product([False, True], SPARSITIES):
a = _create_random_sparsity(torch.rand(B, M, M, device=device), prob)
if sputnik:
a = SparseCS(a, device)
else:
a = a.to_sparse()
results.append(
benchmark.Timer(
stmt="_softmax(a)",
globals={
"a": a,
"_softmax": _softmax,
},
label="softmax",
sub_label=f"sparsity {'sputnik' if sputnik else 'pytorch'}: {prob:0.2f}",
description=f"B={B}, M={M}, K={K}",
).blocked_autorange(min_run_time=min_run_time)
)
compare = benchmark.Compare(results)
compare.print()
def bench_bmm():
min_run_time = MIN_RUN_TIME
prob = 0.9
device = torch.device("cuda")
results = []
for B, M, K in zip(*SHAPES):
a = torch.rand(B, M, M, device=device)
a[a < prob] = 0
b = torch.rand(B, M, K, device=device)
results.extend(
[
benchmark.Timer(
stmt="bmm(a, b)",
globals={
"a": a,
"b": b,
"bmm": bmm,
},
label="bmm",
sub_label="dense",
description=f"B={B}, M={M}, K={K}",
).blocked_autorange(min_run_time=min_run_time),
]
)
for sputnik, prob in itertools.product([False, True], SPARSITIES):
a = _create_random_sparsity(torch.rand(B, M, M, device=device), prob)
bb = b
if sputnik:
a = SparseCS(a, device)
bb = b
else:
a = a.to_sparse()
results.append(
benchmark.Timer(
stmt="bmm(a, b)",
globals={
"a": a,
"b": bb,
"bmm": bmm,
},
label="bmm",
sub_label=f"sparsity {'sputnik' if sputnik else 'pytorch'}: {prob:0.2f}",
description=f"B={B}, M={M}, K={K}",
).blocked_autorange(min_run_time=min_run_time)
)
compare = benchmark.Compare(results)
compare.print()
bench_sddmm()
bench_matmul_with_mask()
bench_softmax()
bench_bmm()
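The csr_* backends above work on a compressed-sparse-row layout (row offsets plus column indices). For intuition, the same layout can be inspected on a dense tensor in plain PyTorch (assuming a PyTorch recent enough to ship to_sparse_csr):

import torch

dense = torch.tensor([[0.0, 2.0, 0.0], [3.0, 0.0, 4.0]])
sp = dense.to_sparse_csr()
assert sp.crow_indices().tolist() == [0, 1, 3]  # row offsets
assert sp.col_indices().tolist() == [1, 0, 2]   # column indices
assert sp.values().tolist() == [2.0, 3.0, 4.0]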

View File

@@ -0,0 +1,241 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
import itertools
import random
import torch
from torch.utils import benchmark
from utils import benchmark_main_helper
import xformers.ops as xops
min_run_time = 0.5
device = torch.device("cuda")
def product_dict(**kwargs):
keys = kwargs.keys()
vals = kwargs.values()
for instance in itertools.product(*vals):
yield dict(zip(keys, instance))
CASES_IADD = list(
product_dict(
shape=[
(int(48 * 0.6), 48, 1, 257 * 1536),
(int(48 * 0.6), 48, 257, 1536),
],
scaling=[False, True],
dtype=[torch.half],
)
) + list(
product_dict(
shape=[
# Format: [B_src, B_inp, M, D]
(int(192 * 0.6), 192, 50, 1536),
(int(48 * 257 * 0.6), 257 * 48, 1, 1536),
(int(192 * 50 * 0.6), 192 * 50, 1, 1536),
(int(16 * 257 * 0.6), 48 * 257, 1, 1536),
],
scaling=[False],
dtype=[torch.half],
)
)
CASES_ISELECT = list(
product_dict(
batches=[((48, 257), (50, 192))],
D=[1536],
keep_ratio=[0.6],
dtype=[torch.half],
)
)
DTYPE2STR = {
torch.bfloat16: "b16",
torch.half: "f16",
torch.float32: "f32",
}
def _setup_test(functions, fw: bool = False, bw: bool = False, **kwargs):
for k, benchmark_cls in functions.items():
benchmark_object = benchmark_cls(**kwargs, bw=bw)
label = benchmark_object.label
label += "fw" if fw else ""
label += "bw" if bw else ""
def run_one():
if fw:
benchmark_object.fw()
if bw:
benchmark_object.bw()
yield benchmark.Timer(
stmt="fn()",
globals={
"fn": run_one,
},
label=label,
description=k,
sub_label=benchmark_object.sub_label,
)
class ScaledIndexAddBenchmark:
def __init__(self, dtype, scaling: bool, shape, bw: bool) -> None:
B_src, B_out, M, D = shape
torch.manual_seed(B_out + B_src)
dtype_str = DTYPE2STR.get(dtype, dtype)
self.sub_label = f"{dtype_str} B_src={B_src}, B_out={B_out}, M={M}, D={D} s={'Y' if scaling else 'N'}"
self.label = "scaled_index_add"
self.alpha = 0.73
self.inp = torch.randn(
[B_out, M, D], device="cuda", dtype=dtype, requires_grad=bw
)
self.src = torch.randn(
[B_src, M, D], device="cuda", dtype=dtype, requires_grad=bw
)
self.scaling = (
torch.randn([D], device="cuda", dtype=dtype, requires_grad=bw)
if scaling
else None
)
self.index = torch.tensor(
[i for i in range(self.src.shape[0])], dtype=torch.int64, device="cuda"
)
self.grad = torch.randn([B_out, M, D], device="cuda", dtype=dtype)
self.out = torch.Tensor()
def fw(self) -> None:
self.out = xops.scaled_index_add(
input=self.inp.clone(),
index=self.index,
source=self.src,
scaling=self.scaling,
alpha=self.alpha,
)
def bw(self):
self.inp.grad = None
self.src.grad = None
if self.scaling is not None:
self.scaling.grad = None
self.out.backward(self.grad, retain_graph=True)
class ScaledIndexAddBenchmarkBaseline(ScaledIndexAddBenchmark):
def fw(self) -> None:
src_scaled = self.src
if self.scaling is not None:
src_scaled = src_scaled * self.scaling.unsqueeze(0).unsqueeze(0)
self.out = self.inp.index_add(
dim=0,
source=src_scaled,
index=self.index,
alpha=self.alpha,
)
def scaled_index_add_fw(**kwargs):
yield from _setup_test(
**kwargs,
fw=True,
functions={
"xformers": ScaledIndexAddBenchmark,
"pytorch": ScaledIndexAddBenchmarkBaseline,
},
)
def scaled_index_add_fwbw(**kwargs):
yield from _setup_test(
**kwargs,
fw=True,
bw=True,
functions={
"xformers": ScaledIndexAddBenchmark,
"pytorch": ScaledIndexAddBenchmarkBaseline,
},
)
class IndexSelectBenchmark:
def __init__(self, dtype, batches, D, keep_ratio, bw: bool) -> None:
dtype_str = DTYPE2STR.get(dtype, dtype)
self.sub_label = f"{dtype_str} D={D} batches={batches} keep={keep_ratio}"
self.label = "index_select"
srcs = [torch.randn([B, seqlen * D]) for (B, seqlen) in batches]
src = torch.cat([s.view([-1, D]) for s in srcs], dim=0).cuda().to(dtype)
src.requires_grad_(True)
indices = []
sources = []
elements_i = 0
for source_i in srcs:
index = [i for i in range(source_i.shape[0])]
random.Random(source_i.shape[0]).shuffle(index)
indices.append(
torch.tensor(
index[: int(keep_ratio * source_i.shape[0])],
dtype=torch.int64,
device="cuda",
)
)
sources.append(
src[
elements_i : elements_i + source_i.shape[0] * source_i.shape[1] // D
].reshape(source_i.shape)
)
elements_i += source_i.shape[0] * source_i.shape[1] // D
self.indices, self.sources, self.src = indices, sources, src
self.out = torch.Tensor()
def fw(self) -> None:
self.out = xops.index_select_cat(self.sources, self.indices)
def bw(self):
self.src.grad = None
self.out.backward(self.out, retain_graph=True)
class IndexSelectBenchmarkBaseline(IndexSelectBenchmark):
def fw(self) -> None:
self.out = torch.cat(
[s[i].flatten() for s, i in zip(self.sources, self.indices)], dim=0
)
def index_select_fw(**kwargs):
yield from _setup_test(
**kwargs,
fw=True,
functions={
"xformers": IndexSelectBenchmark,
"pytorch": IndexSelectBenchmarkBaseline,
},
)
def index_select_fwbw(**kwargs):
yield from _setup_test(
**kwargs,
fw=True,
bw=True,
functions={
"xformers": IndexSelectBenchmark,
"pytorch": IndexSelectBenchmarkBaseline,
},
)
benchmark_main_helper(scaled_index_add_fw, CASES_IADD, min_run_time=min_run_time)
benchmark_main_helper(scaled_index_add_fwbw, CASES_IADD, min_run_time=min_run_time)
benchmark_main_helper(index_select_fw, CASES_ISELECT, min_run_time=min_run_time)
benchmark_main_helper(index_select_fwbw, CASES_ISELECT, min_run_time=min_run_time)
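Semantically, the benchmarked op matches the PyTorch baseline above: alpha * (source * scaling) accumulated into the rows of input selected by index. A minimal worked example with plain index_add:

import torch

inp = torch.zeros(4, 3)
src = torch.ones(2, 3)
index = torch.tensor([0, 3])
out = inp.index_add(0, index, src, alpha=0.5)
assert out[0].tolist() == [0.5, 0.5, 0.5]
assert out[1].tolist() == [0.0, 0.0, 0.0]
assert out[3].tolist() == [0.5, 0.5, 0.5]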

View File

@@ -0,0 +1,316 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
import itertools
import random
from functools import partial
import torch
from torch.utils import benchmark
from utils import benchmark_main_helper
import xformers.ops
import xformers.ops.fmha as fmha
torch.backends.cuda.matmul.allow_tf32 = False
def create_attn_bias(
bias_type,
batch_size: int,
num_heads: int,
q_len: int,
kv_len: int,
device,
dtype,
bias_requires_grad: bool = False,
):
NoneType = type(None)
if bias_type is NoneType:
return None
if bias_type is torch.Tensor:
attn_bias = torch.randn((1, 1, q_len, kv_len), device=device, dtype=dtype)
return attn_bias.expand(batch_size, num_heads, q_len, kv_len)
if bias_type is xformers.ops.LowerTriangularMask:
return bias_type()
assert False, f"Unsupported bias type: {bias_type}"
def ref_attention_bmk(q, k, v, attn_bias=None, p=0.0):
if isinstance(attn_bias, xformers.ops.AttentionMask):
attn_bias = (
attn_bias.materialize((q.shape[0], 1, q.shape[1], k.shape[1]))
.to(q)
.squeeze()
)
q = q * (1.0 / q.shape[-1] ** 0.5)
if attn_bias is None:
attn = q @ k.transpose(-2, -1)
else:
# equivalent to (q @ k.transpose(-2, -1) + m).softmax(-1) @ v
# but faster, and is what is used in PyTorch now
attn = torch.baddbmm(attn_bias, q, k.transpose(-2, -1))
attn = attn.softmax(-1)
if p > 0:
attn = torch.nn.functional.dropout(attn, p=p)
return attn @ v
def ref_attention(q, k, v, attn_bias, p=0.0):
assert q.ndim == 4
B, M, H, K = q.shape
def T(t):
return t.permute((0, 2, 1, 3)).reshape(
[t.shape[0] * t.shape[2], t.shape[1], t.shape[3]]
)
if isinstance(attn_bias, torch.Tensor):
attn_bias = attn_bias.reshape(B * H, M, M)
out = ref_attention_bmk(T(q), T(k), T(v), attn_bias, p)
out = out.reshape([q.shape[0], q.shape[2], q.shape[1], v.shape[3]])
return out.permute((0, 2, 1, 3))
min_run_time = 0.5
device = torch.device("cuda")
NUM_THREADS = [1] if device.type == "cuda" else [1, 40]
SHAPES = [
# ViT
(384, 197, 1, 88),
(384, 197, 1, 80),
(384, 197, 1, 64),
(1024, 197, 1, 88),
(1024, 197, 1, 80),
(1024, 197, 1, 64),
# ViT-Huge
(32 * 16, 197, 1, 80),
(32, 197, 16, 80),
(32, 197, 16, 64),
(32, 197, 16, 128),
# ViT-Giant
(16 * 16, 197, 1, 88),
(16, 197, 16, 88),
(16, 197, 16, 64),
(16, 197, 16, 128),
# FB models
(1024, 82, 8, 64),
(150, 256, 16, 64),
(64, 256, 12, 64),
# Stable diffusion (https://github.com/huggingface/diffusers/pull/532)
(1, 4096, 16, 40), # 512x512
(1, 16384, 16, 40), # 1024x1024
(1, 4096, 16, 80),
(1, 16384, 16, 80),
# + bs4
(4, 4096, 16, 40),
(4, 16384, 16, 40),
(4, 4096, 16, 80),
(4, 16384, 16, 80),
# ParlAI model
(256, 4096, 16, 64),
# Zetta B M H K
(8, 2048, 20, 128),
# LLaMa 70b - mp=8/16
*sorted(itertools.product([1, 2], [2048, 4096, 8192], [4, 8], [128])),
*sorted(
itertools.product([16], [128, 512, 1024], [16], [16, 32, 64, 128, 160, 256])
),
]
OPS = [
(xformers.ops.fmha.cutlass.FwOp, xformers.ops.fmha.cutlass.BwOp),
(xformers.ops.fmha.flash.FwOp, xformers.ops.fmha.flash.BwOp),
# TODO: Triton is not stable: it can trigger Illegal Memory Accesses
# and its performance varies a lot between runs.
# (xformers.ops.fmha.triton.FwOp, xformers.ops.fmha.triton.BwOp),
]
def product_dict(**kwargs):
keys = kwargs.keys()
vals = kwargs.values()
for instance in itertools.product(*vals):
yield dict(zip(keys, instance))
CASES = list(
product_dict(
shape=SHAPES,
num_threads=NUM_THREADS,
dropout_p=[0.0],
attn_bias_cfg=[(type(None), False)],
dtype=[torch.half],
)
)
# Add more cases with some variations
for c in CASES.copy():
c = c.copy()
c.update(
random.Random(str(c["shape"])).choice(
[
{"dropout_p": 0.3},
{"attn_bias_cfg": (torch.Tensor, False)},
{"attn_bias_cfg": (torch.Tensor, True)},
{"attn_bias_cfg": (xformers.ops.LowerTriangularMask, False)},
{"dtype": torch.bfloat16},
{"dtype": torch.float},
]
)
)
CASES.append(c)
def create_tensors(shape, dtype, requires_grad=False):
B, M, H, K = shape
qkv = torch.rand(
[B, M, 3, H, K], device=device, dtype=dtype, requires_grad=requires_grad
)
q, k, v = xformers.ops.unbind(qkv, 2)
return qkv, q, k, v
def mem_eff_attention_fw(shape, num_threads: int, attn_bias_cfg, dropout_p, dtype):
B, M, H, K = shape
_, q, k, v = create_tensors(shape, dtype)
attn_bias_type, attn_bias_requires_grad = attn_bias_cfg
if attn_bias_requires_grad:
return
bias = create_attn_bias(
attn_bias_type,
batch_size=B,
num_heads=H,
q_len=M,
kv_len=M,
device=device,
dtype=dtype,
bias_requires_grad=attn_bias_requires_grad,
)
inp = fmha.Inputs(query=q, key=k, value=v, attn_bias=bias, p=dropout_p)
dtype_str = {
torch.bfloat16: "b16",
torch.half: "f16",
torch.float: "f32",
}[dtype]
sub_label = (
f"{dtype_str} {B}-{M}-{H}-{K}, p={dropout_p}, "
f"BiasT={attn_bias_type.__name__}"
)
has_run = False
for fw_op, bw_op in OPS:
if not fw_op.supports(inp):
continue
yield benchmark.Timer(
stmt="fn(q, k, v, attn_bias, p)",
globals={
"q": q,
"k": k,
"v": v,
"attn_bias": inp.attn_bias,
"p": dropout_p,
"fn": partial(
xformers.ops.memory_efficient_attention, op=(fw_op, bw_op)
),
},
label=f"attention (attn_bias={attn_bias_type})",
description=fw_op.NAME,
sub_label=sub_label,
num_threads=num_threads,
)
has_run = True
if not has_run:
return
yield benchmark.Timer(
stmt="fn(q, k, v, attn_bias, p)",
globals={
"q": q,
"k": k,
"v": v,
"attn_bias": inp.attn_bias,
"p": dropout_p,
"fn": ref_attention,
},
label=f"attention (attn_bias={attn_bias_type})",
description="eager",
sub_label=sub_label,
num_threads=num_threads,
)
def mem_eff_attention_bw(shape, num_threads: int, attn_bias_cfg, dropout_p, dtype):
B, M, H, K = shape
qkv, q, k, v = create_tensors(shape, dtype, requires_grad=True)
attn_bias_type, attn_bias_requires_grad = attn_bias_cfg
bias = create_attn_bias(
attn_bias_type,
batch_size=B,
num_heads=H,
q_len=M,
kv_len=M,
device=device,
dtype=dtype,
bias_requires_grad=attn_bias_requires_grad,
)
inp = fmha.Inputs(query=q, key=k, value=v, attn_bias=bias, p=dropout_p)
dtype_str = {
torch.bfloat16: "b16",
torch.half: "f16",
torch.float: "f32",
}[dtype]
sub_label = (
f"{dtype_str} {B}-{M}-{H}-{K}, p={dropout_p}, "
f"BiasT={attn_bias_type.__name__}, BiasGrad={attn_bias_requires_grad}"
)
has_run = False
for fw_op, bw_op in OPS:
if not fw_op.supports(inp) or not bw_op.supports(inp):
continue
has_run = True
out = xformers.ops.memory_efficient_attention(
inp.query, inp.key, inp.value, inp.attn_bias, inp.p, op=(fw_op, bw_op)
)
grad_benchmark = torch.ones_like(q)
yield benchmark.Timer(
stmt="out.backward(grad, retain_graph=True)",
globals={
"out": out,
"grad": grad_benchmark,
},
label=f"attention backward (attn_bias={attn_bias_type})",
description=bw_op.NAME,
sub_label=sub_label,
num_threads=num_threads,
)
del out
if not has_run:
return
yield benchmark.Timer(
stmt="out.backward(grad, retain_graph=True)",
globals={
"out": ref_attention(q, k, v, inp.attn_bias, dropout_p),
"grad": grad_benchmark,
},
label=f"attention backward (attn_bias={attn_bias_type})",
description="vanilla",
sub_label=sub_label,
num_threads=num_threads,
)
benchmark_main_helper(mem_eff_attention_fw, CASES, min_run_time=min_run_time)
benchmark_main_helper(mem_eff_attention_bw, CASES, min_run_time=min_run_time)
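create_tensors above packs Q, K and V into one [B, M, 3, H, K] tensor and unbinds along dim 2; the same idea can be sketched with plain torch.unbind (xformers.ops.unbind plays the same role in the benchmark, with extra care for gradients of the packed tensor):

import torch

B, M, H, K = 2, 16, 4, 8
qkv = torch.rand(B, M, 3, H, K)
q, k, v = torch.unbind(qkv, dim=2)  # three views, each [B, M, H, K]
assert q.shape == (B, M, H, K)
assert q.data_ptr() == qkv.data_ptr()  # q aliases the packed buffer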

View File

@@ -0,0 +1,187 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
import itertools
from functools import partial
import torch
from torch.utils import benchmark
from utils import benchmark_main_helper
import xformers.ops
import xformers.ops.fmha as fmha
torch.backends.cuda.matmul.allow_tf32 = False
# Run with
# python xformers/benchmarks/benchmark_mem_eff_attn_decoder.py --omit-baselines --quiet
# The baselines for these benchmarks are really slow because there is
# so much padding in the inputs, so there is no point running them.
def ref_attention_bmk(q, k, v, attn_bias=None):
if isinstance(attn_bias, xformers.ops.AttentionMask):
attn_bias = (
attn_bias.materialize((q.shape[0], 1, q.shape[1], k.shape[1]))
.to(q)
.squeeze()
)
q = q * (1.0 / q.shape[-1] ** 0.5)
if attn_bias is None:
attn = q @ k.transpose(-2, -1)
else:
# equivalent to (q @ k.transpose(-2, -1) + m).softmax(-1) @ v
# but faster, and is what is used in PyTorch now
attn = torch.baddbmm(attn_bias, q, k.transpose(-2, -1))
attn = attn.softmax(-1)
return attn @ v
def ref_attention(q, k, v, attn_bias):
assert q.ndim == 4
def T(t):
return t.permute((0, 2, 1, 3)).reshape(
[t.shape[0] * t.shape[2], t.shape[1], t.shape[3]]
)
out = ref_attention_bmk(T(q), T(k), T(v), attn_bias)
out = out.reshape([q.shape[0], q.shape[2], q.shape[1], v.shape[3]])
return out.permute((0, 2, 1, 3))
min_run_time = 0.5
device = torch.device("cuda")
NUM_THREADS = [1] if device.type == "cuda" else [1, 40]
OPS = [
xformers.ops.fmha.cutlass.FwOp,
xformers.ops.fmha.decoder.FwOp,
]
KV_SHAPES = [
# list of n_keys, padding_length, batchsize
(2, 64, 3),
(32, 1024, 500),
(1000, 1024, 2),
(8000, 8192, 1),
(240, 256, 32),
(2048, 2 * 1024, 4),
(4096 * 2, 8 * 1024, 1),
]
N_HEADS = [8, 16, 64]
def product_dict(**kwargs):
keys = kwargs.keys()
vals = kwargs.values()
for instance in itertools.product(*vals):
yield dict(zip(keys, instance))
CASES = list(
product_dict(
kv_shape=KV_SHAPES,
n_heads=N_HEADS,
num_threads=NUM_THREADS,
multiquery=[True, False],
)
)
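# For illustration, product_dict expands keyword lists into their cross
# product of configurations (hypothetical arguments):
#   list(product_dict(a=[1, 2], b=["x"]))
#   -> [{"a": 1, "b": "x"}, {"a": 2, "b": "x"}]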
def mem_eff_attention_decoder(
kv_shape, n_heads: int, num_threads: int, multiquery: bool
):
n_keys, padding, B = kv_shape
torch.manual_seed(42)
k_seqlen = torch.randint(1, n_keys + 1, (B,)).tolist()
K = 128
q = torch.rand(1, B, n_heads, K, device=device, dtype=torch.bfloat16)
if multiquery:
k = torch.rand(
1, B * padding, 1, K, device=device, dtype=torch.bfloat16
).expand(1, B * padding, n_heads, K)
v = torch.rand(
1, B * padding, 1, K, device=device, dtype=torch.bfloat16
).expand(1, B * padding, n_heads, K)
else:
k = torch.rand(1, B * padding, n_heads, K, device=device, dtype=torch.bfloat16)
v = torch.rand(1, B * padding, n_heads, K, device=device, dtype=torch.bfloat16)
bias = fmha.attn_bias.BlockDiagonalCausalWithOffsetPaddedKeysMask.from_seqlens(
q_seqlen=[1] * B,
kv_seqlen=k_seqlen,
kv_padding=padding,
)
sub_label = f"{B}batch-{k_seqlen[0]}keys-{n_heads}heads"
if multiquery:
sub_label += "-mq"
has_run = False
for fw_op in OPS:
inp = fmha.Inputs(q, k, v, attn_bias=bias)
if not fw_op.supports(inp):
continue
fn = partial(xformers.ops.memory_efficient_attention_forward, op=fw_op)
yield benchmark.Timer(
stmt="fn(q, k, v, attn_bias)",
globals={
"q": q,
"k": k,
"v": v,
"attn_bias": bias,
"fn": fn,
},
label="attention",
description=fw_op.NAME,
sub_label=sub_label,
num_threads=num_threads,
)
graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph):
fn(q, k, v, bias)
yield benchmark.Timer(
stmt="graph.replay()",
globals={
"graph": graph,
},
label="cuda graphed attention",
description=fw_op.NAME,
sub_label=sub_label,
num_threads=num_threads,
)
has_run = True
if not has_run:
return
RUN_BASELINES = False
if RUN_BASELINES:
yield benchmark.Timer(
stmt="fn(q, k, v, attn_bias)",
globals={
"q": q,
"k": k,
"v": v,
"attn_bias": bias,
"fn": ref_attention,
},
label="attention",
description="eager",
sub_label=sub_label,
num_threads=num_threads,
)
benchmark_main_helper(mem_eff_attention_decoder, CASES, min_run_time=min_run_time)
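# Editor's sketch: the CUDA-graph pattern used above, factored out. This is
# an illustrative addition, not used by the benchmark; it assumes `fn` has
# already been warmed up (as above, where fn was just timed) and that the
# static input tensors keep their addresses between capture and replay.
def _graph_and_replay(fn, *static_args):
    g = torch.cuda.CUDAGraph()
    # Capture one invocation of fn into the graph.
    with torch.cuda.graph(g):
        out = fn(*static_args)
    # Replay re-launches the captured kernels with minimal CPU overhead.
    g.replay()
    return out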

View File

@@ -0,0 +1,127 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
import argparse
from typing import Any, Dict
import torch
import triton
from xformers.benchmarks.utils import TestCase, pretty_plot, pretty_print
from xformers.components import Activation
from xformers.components.feedforward import MLP, FusedMLP
SHAPES = [
(8, 256, 512),
(8, 512, 1024),
(4, 1024, 1024),
(2, 2048, 2048),
(1, 2048, 4096),
(1, 1024, 12288),
]
HIDDEN_LAYER_MULTIPLIER = [4]
def bench_MLP(backward: bool, bias: bool, dropout: float, activation: Activation):
device = torch.device("cuda")
bw = "+bw" if backward else ""
for dtype in [torch.float16, torch.float32]:
results: Dict[str, Any] = {}
for B, M, K in SHAPES:
for hlm in HIDDEN_LAYER_MULTIPLIER:
fused_mlp = FusedMLP(
dim_model=K,
dropout=dropout,
activation=activation,
hidden_layer_multiplier=hlm,
bias=bias,
).to(device=device, dtype=dtype)
standard_mlp = MLP(
dim_model=K,
dropout=dropout,
activation=activation,
hidden_layer_multiplier=hlm,
bias=bias,
).to(device=device, dtype=dtype)
a = torch.randn(
(B, M, K), requires_grad=backward, device=device, dtype=dtype
)
def mlp_standard():
y = standard_mlp(a)
if backward:
torch.norm(y).backward()
return y
def mlp_fused():
y = fused_mlp(a)
if backward:
torch.norm(y).backward()
return y
for testcase in [
TestCase(
mlp_standard,
"standard - {} - {} bias - {} drop - fw{}".format(
activation,
"no" if not bias else "",
dropout,
"+bw" if backward else "",
),
),
TestCase(
mlp_fused,
"fused - {} - {} bias - {} drop - fw{}".format(
activation,
"no" if not bias else "",
dropout,
"+bw" if backward else "",
),
),
]:
time = triton.testing.do_bench(testcase.function)[0]
key = f"{B} x {M} x {K} - {hlm}"
if key not in results:
results[key] = {}
results[key][testcase.name] = f"{time:.2f}"
pretty_print(
results,
title=f"\n --- Type: {dtype} --- ",
units="runtime in ms, lower is better. BMK - mul: ",
)
pretty_plot(
results,
title=f"MLP-{activation}-FW{bw}-{dtype}",
units="runtime in ms, lower is better",
dash_key="torch",
)
if __name__ == "__main__":
# Get the user requests
parser = argparse.ArgumentParser("Benchmark MLP")
parser.add_argument("-act", "--activations", nargs="+", default=[Activation.GeLU])
parser.add_argument("-bias", "--bias", nargs="+", default=[False, True])
parser.add_argument("-dropout", "--dropout", nargs="+", default=[0.0, 0.1])
args = parser.parse_args()
for bw in [False, True]:
for bias in args.bias:
for dropout in args.dropout:
for activation in args.activations:
bench_MLP(
backward=bw,
bias=bias,
dropout=float(dropout),
activation=activation,
)

View File

@@ -0,0 +1,105 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
from typing import Any, Dict
import torch
import torch.nn as nn
import triton
from xformers.benchmarks.utils import TestCase, pretty_plot, pretty_print
from xformers.components import MultiHeadDispatch
from xformers.components.attention import ScaledDotProduct
SHAPES = [
(8, 384, 128),
(8, 784, 512),
(4, 1024, 768),
(4, 2048, 1024),
(2, 2048, 2048),
(2, 2048, 4096),
(2, 4096, 4096),
(1, 2048, 12288),
]
N_HEADS = [4]
def bench_multihead_dispatch(backward: bool, self_attention: bool):
device = torch.device("cuda")
bw = "+bw" if backward else ""
sa = " (self_attn)" if self_attention else ""
for dtype in [torch.float16, torch.float32]:
results: Dict[str, Any] = {}
for B, M, K in SHAPES:
for heads in N_HEADS:
xf_multi_head = MultiHeadDispatch(
dim_model=K,
residual_dropout=0.0,
num_heads=heads,
attention=ScaledDotProduct(),
bias=(True, True, True, True),
).to(device=device, dtype=dtype)
torch_multi_head = nn.MultiheadAttention(
embed_dim=K, num_heads=heads, batch_first=True
).to(device=device, dtype=dtype)
q = torch.randn(
(B, M, K), requires_grad=backward, device=device, dtype=dtype
)
if self_attention:
k = q
v = q
else:
k = torch.randn(
(B, M, K), requires_grad=backward, device=device, dtype=dtype
)
v = torch.randn(
(B, M, K), requires_grad=backward, device=device, dtype=dtype
)
def torch_mha():
y, _ = torch_multi_head(query=q, key=k, value=v)
if backward:
torch.norm(y).backward()
return y
def xformers_mha():
y = xf_multi_head(query=q, key=k, value=v)
if backward:
torch.norm(y).backward()
return y
for testcase in [
TestCase(torch_mha, f"torch - fw{bw}{sa}"),
TestCase(xformers_mha, f"xf - fw{bw}{sa}"),
]:
time = triton.testing.do_bench(testcase.function)[0]
key = f"B={B}, M={M}, K={K}, N_HEADS={heads}"
if key not in results:
results[key] = {}
results[key][testcase.name] = f"{time:.2f}"
pretty_print(
results,
title=f"\n --- Type: {dtype} --- ",
units="runtime in ms, lower is better",
)
pretty_plot(
results,
title=f"MHA-FW{bw}-{dtype}",
units="runtime in ms, lower is better",
dash_key="torch",
)
for bw in [False, True]:
for self_attention in [False, True]:
bench_multihead_dispatch(bw, self_attention)

View File

@@ -0,0 +1,99 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
from typing import Callable
import torch
from torch.utils import benchmark
from xformers.components.attention.utils import iterative_pinv
MIN_RUN_TIME = 1
SHAPES = [[8, 8], [256, 1024], [128, 256]]
SPARSITIES = [0.5, 0.8, 0.9, 0.95, 0.99]
def bench_inverse(inverse_fn: Callable[[torch.Tensor], torch.Tensor]):
min_run_time = MIN_RUN_TIME
prob = 0.9
device = torch.device("cuda")
results = []
for B, M, K in zip(*SHAPES):
a = torch.rand(B, M, M, device=device)
a[a < prob] = 0
a = torch.softmax(a, dim=-1)
results.extend(
[
benchmark.Timer(
stmt=f"{inverse_fn.__name__}(a)",
globals={
"a": a,
f"{inverse_fn.__name__}": inverse_fn,
},
label=f"{inverse_fn.__name__}",
sub_label="dense",
description=f"B={B}, M={M}, K={K}",
).blocked_autorange(min_run_time=min_run_time),
]
)
for prob in SPARSITIES:
a = torch.rand(B, M, M, device=device)
a[a < prob] = 0
a = a.to_sparse()
results.append(
benchmark.Timer(
stmt=f"{inverse_fn.__name__}(a)",
globals={
"a": a,
f"{inverse_fn.__name__}": inverse_fn,
},
label=f"{inverse_fn.__name__}",
sub_label=f"sparsity: {prob:0.2f}",
description=f"B={B}, M={M}, K={K}",
).blocked_autorange(min_run_time=min_run_time)
)
compare = benchmark.Compare(results)
compare.print()
def iterative_pinv_analysis(
identity_tolerance: float = 1e-1,
pinv_tolerance: float = 5e-1,
max_iters: int = 30,
plot: bool = True,
):
for i in range(1, 10):
B, M = 1, 2**i
a = torch.rand(B, M, M)
a = torch.softmax(a, dim=-1)
for n_iter in range(1, max_iters + 1):
result = iterative_pinv(a, n_iter=n_iter)
expected = torch.linalg.pinv(a)
result_identity = torch.matmul(a, result)
identity = torch.eye(M)
# Default is frobenius norm.
identity_error = torch.linalg.norm(identity - result_identity, dim=(-2, -1))
inverse_error = torch.linalg.norm(expected - result, dim=(-2, -1))
if (identity_error < identity_tolerance).all() or n_iter == max_iters:
print(
f"Size {M}, n_iters {n_iter}:\n"
f"\tFinal error vs identity: {identity_error.item()}\n"
f"\tFinal error vs linalg.pinv: {inverse_error.item()}"
)
break
if __name__ == "__main__":
iterative_pinv_analysis()
bench_inverse(iterative_pinv)
bench_inverse(torch.linalg.pinv)
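# Editor's sketch: the kind of Newton-Schulz iteration that iterative_pinv
# benchmarks (coefficients as in the Nystromformer paper). Illustrative only,
# not the xformers implementation:
#
#   def newton_schulz_pinv(a, n_iter=6):
#       z = a.transpose(-2, -1) / (a.abs().sum(-2).max() * a.abs().sum(-1).max())
#       eye = torch.eye(a.shape[-1], device=a.device, dtype=a.dtype)
#       for _ in range(n_iter):
#           az = a @ z
#           z = 0.25 * z @ (13 * eye - az @ (15 * eye - az @ (7 * eye - az)))
#       return z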

View File

@@ -0,0 +1,83 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
from typing import Any, Dict
import torch
import triton
from xformers.benchmarks.utils import TestCase, pretty_plot, pretty_print
from xformers.components.reversible import ReversibleSequence
SHAPES = [(16384, 32), (2048, 256), (128, 4096)]
DEPTH = [4, 32, 256]
def bench_revnet(backward: bool):
device = torch.device("cuda")
bw = "+bw" if backward else ""
for dtype in [torch.float16, torch.float32]:
results: Dict[str, Any] = {}
for B, K in SHAPES:
for depth in DEPTH:
f = torch.nn.Linear(K, K).to(device=device, dtype=dtype)
g = torch.nn.Linear(K, K).to(device=device, dtype=dtype)
revseq = ReversibleSequence(
torch.nn.ModuleList([torch.nn.ModuleList([f, g])] * depth)
)
revseq = revseq.to(device=device, dtype=dtype)
a = torch.rand(
1, B, K, device=device, dtype=dtype, requires_grad=backward
)
b = torch.rand(
1, B, K * 2, device=device, dtype=dtype, requires_grad=backward
)
def normal_step():
y = a
for _ in range(depth):
y = y + f(y)
y = y + g(y)
if backward:
torch.norm(y).backward()
return y
def reversible_step():
y = revseq(b)
if backward:
torch.norm(y).backward()
return y
for testcase in [
TestCase(normal_step, f"residual - fw{bw}"),
TestCase(reversible_step, f"reversible - fw{bw}"),
]:
time = triton.testing.do_bench(testcase.function)[0]
key = f"Batch={B}, Features={K}, Depth={depth}"
if key not in results:
results[key] = {}
results[key][testcase.name] = f"{time:.2f}"
pretty_print(
results,
title=f"\n --- Type: {dtype} --- ",
units="runtime in ms, lower is better",
)
pretty_plot(
results,
title=f"RevNet-FW{bw}-{dtype}",
units="runtime in ms, lower is better",
dash_key="torch",
)
for bw in [False, True]:
bench_revnet(bw)

View File

@@ -0,0 +1,117 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
import itertools
import torch
from torch.utils import benchmark
from xformers.components.attention._sputnik_sparse import _csr_to_coo
from xformers.components.attention.core import SparseCS, _create_random_sparsity
MIN_RUN_TIME = 0.2
def _get_fn(backend):
if backend == "csr_ge":
fn = torch.ops.xformers.csr_sddmm
elif backend == "csr_sputnik":
fn = torch.ops.xformers.sddmm_sputnik
elif backend == "coo_ge":
def fn(a, b, row_indices, row_offsets, column_indices):
row_coo, _ = _csr_to_coo(
a.shape[-2], b.shape[-2], row_offsets, column_indices
)
return torch.ops.xformers.coo_sddmm(
a, b, row_indices, row_coo, column_indices
)
elif backend == "csr_to_coo":
def fn(a, b, row_indices, row_offsets, column_indices):
row_coo, _ = _csr_to_coo(
a.shape[-2], b.shape[-2], row_offsets, column_indices
)
return row_coo
return fn
def bench_sddmm(configs):
min_run_time = MIN_RUN_TIME
device = torch.device("cuda")
results = []
for (B, M, K), prob in configs:
a = torch.rand(B, M, K, device=device)
b = torch.rand(B, M, K, device=device)
mask = _create_random_sparsity(
torch.ones(1, M, M, dtype=torch.bool), prob, divisible_by=16
)
aa = a
bb = b
mask = SparseCS(mask, device)
row_indices = mask.row_indices
row_offsets = mask.row_offsets
column_indices = mask.column_indices
for backend in ["csr_sputnik", "csr_ge", "coo_ge", "csr_to_coo"]:
fn_str = "fn(a, b, row_indices, row_offsets, column_indices)"
fn = _get_fn(backend)
results.append(
benchmark.Timer(
stmt=fn_str,
globals={
"a": aa,
"b": bb,
"mask": mask,
"row_indices": row_indices,
"row_offsets": row_offsets,
"column_indices": column_indices,
"fn": fn,
},
label="sddmm",
sub_label=f"B={B:>4d}, M={M:>4d}, K={K:>3d}, prob={prob:0.4f}",
description=backend,
).blocked_autorange(min_run_time=min_run_time)
)
compare = benchmark.Compare(results)
compare.print()
return results
# batch size 32, for different layers
SWIN_T_SIZES = [(96, 3136, 32), (192, 784, 32), (384, 196, 32), (768, 49, 32)]
swin_t_config = list(zip(SWIN_T_SIZES, (0.9844, 0.9375, 0.75, 0.0)))
# some random values
BASIC_SIZES = [(32, 1024, 32), (32, 1024, 128), (8, 4096, 32), (8, 4096, 128)]
SPARSITIES = [0.90, 0.93, 0.95, 0.97, 0.98, 0.99, 0.995, 0.999]
basic_config = list(itertools.product(BASIC_SIZES, SPARSITIES))
# batch size 32 here
vit_sizes = [
(192, 785, 64), # deit_small_patch8_224
(192, 197, 64), # deit_small_patch16_224
(384, 785, 64), # deit_base_patch8_224
(384, 197, 64), # deit_base_patch16_224
]
SPARSITIES = [0.70, 0.80, 0.85, 0.90, 0.93, 0.95, 0.97]
vit_config = list(itertools.product(vit_sizes, SPARSITIES))
results = []
print("Swin Transformer")
results += bench_sddmm(swin_t_config)
print("ViT")
results += bench_sddmm(vit_config)
print("Basic cases")
results += bench_sddmm(basic_config)
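# Editor's note: SDDMM (sampled dense-dense matrix multiplication) computes
# a @ b^T only at the positions kept by the sparsity pattern. A dense sketch
# of the semantics the csr/coo kernels above implement (illustrative, unused):
def _sddmm_dense_reference(a, b, mask):
    # a: (B, M, K), b: (B, N, K), mask: (B, M, N); zero out unsampled entries.
    return (a @ b.transpose(-2, -1)) * mask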

View File

@@ -0,0 +1,160 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
import itertools
from contextlib import nullcontext
from functools import partial
from typing import Any
import torch
from torch.utils import benchmark
from utils import benchmark_main_helper
import xformers.ops.swiglu_op as xsw
min_run_time = 0.5
device = torch.device("cuda")
SHAPES = [
# Format: [inp.shape[0], inp.shape[1], hidden.shape[1]]
# ViT-Giant
(9456, 1536, 2736),
(4440, 1536, 2736),
(4728, 1536, 2736),
# Some smaller shapes as well
(4728, 1536, 1024),
# GPT-3 (small)
(32768, 2048, 5632),
# Chinchilla
(32768, 8192, 22016),
]
# OP = xsw._SwiGLUDecomposedOp
# OP = xsw.SwiGLUFusedOp
OP = xsw.SwiGLUPackedFusedOp
def product_dict(**kwargs):
keys = kwargs.keys()
vals = kwargs.values()
for instance in itertools.product(*vals):
yield dict(zip(keys, instance))
CASES = list(
product_dict(
shape=SHAPES,
dtype=[torch.bfloat16, torch.half, "autocast_half"],
bias=[True, False],
)
)
DTYPE2STR = {
torch.bfloat16: "b16 ",
torch.half: "f16 ",
"autocast_half": "f16.ac",
}
def benchmark_swiglu(shape, dtype, bias: bool):
if dtype == "autocast_half":
inp_dtype, model_dtype, autocast = torch.float, torch.float, True
else:
inp_dtype, model_dtype, autocast = dtype, dtype, False
x = torch.randn(shape[:2], device=device, dtype=inp_dtype)
module = (
xsw.SwiGLU(in_features=shape[1], hidden_features=shape[2], bias=bias)
.to(device)
.to(model_dtype)
)
dtype_str = DTYPE2STR.get(dtype, dtype)
bstr = "bias" if bias else "nobi"
sub_label = f"{dtype_str} B={shape[0]}, I={shape[1]}, H={shape[2]} {bstr}"
params = module._ordered_params()
PREFIX = 'with torch.autocast("cuda", dtype=torch.half):\n ' if autocast else ""
yield benchmark.Timer(
stmt=f"{PREFIX}fn(x, *args)",
globals={
"x": x,
"args": params,
"fn": partial(xsw.swiglu, op=OP),
},
label="swiglu_fw",
description=OP.NAME,
sub_label=sub_label,
)
yield benchmark.Timer(
stmt=f"{PREFIX}fn(x, *args)",
globals={
"x": x,
"args": params,
"fn": partial(xsw.swiglu, op=xsw.SwiGLUEagerOp),
},
label="swiglu_fw",
description="eager",
sub_label=sub_label,
)
def benchmark_swiglu_bw(shape, dtype, bias: bool):
if dtype == "autocast_half":
inp_dtype, model_dtype = torch.float, torch.float
cm: Any = partial(torch.cuda.amp.autocast, enabled=True, dtype=torch.float16)
else:
inp_dtype, model_dtype = dtype, dtype
cm = nullcontext
x = torch.randn(shape[:2], device=device, dtype=inp_dtype)
x.requires_grad_()
module = (
xsw.SwiGLU(in_features=shape[1], hidden_features=shape[2], bias=bias)
.to(device)
.to(model_dtype)
)
dtype_str = DTYPE2STR.get(dtype, dtype)
bstr = "bias" if bias else "nobi"
sub_label = f"{dtype_str} B={shape[0]}, I={shape[1]}, H={shape[2]} {bstr}"
params = module._ordered_params()
with cm():
out = xsw.swiglu(x, *params, op=OP)
grad = torch.zeros_like(out)
yield benchmark.Timer(
stmt="out.backward(grad, retain_graph=True)",
globals={
"out": out,
"grad": grad,
},
label="swiglu_bw",
description=OP.NAME,
sub_label=sub_label,
)
del out
with cm():
out = xsw.swiglu(x, *params, op=xsw.SwiGLUEagerOp)
yield benchmark.Timer(
stmt="out.backward(grad, retain_graph=True)",
globals={
"out": out,
"grad": grad,
},
label="swiglu_bw",
description="eager",
sub_label=sub_label,
)
benchmark_main_helper(benchmark_swiglu, CASES, min_run_time=min_run_time)
benchmark_main_helper(benchmark_swiglu_bw, CASES, min_run_time=min_run_time)
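# Editor's sketch of the SwiGLU math being benchmarked: two input projections,
# a SiLU gate, then a down projection. The parameter ordering is assumed to
# mirror module._ordered_params() above; this reference is illustrative, not
# the xformers implementation.
def _swiglu_reference(x, w1, b1, w2, b2, w3, b3):
    import torch.nn.functional as F
    gate = F.silu(F.linear(x, w1, b1))  # (B, hidden)
    value = F.linear(x, w2, b2)  # (B, hidden)
    return F.linear(gate * value, w3, b3)  # (B, out)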

View File

@@ -0,0 +1,155 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
import itertools
from functools import partial, reduce
import timm
import torch
import torch.nn as nn
from timm.models.layers import Mlp as TimmMlp
from timm.models.vision_transformer import Attention as TimmAttention
from timm.models.vision_transformer import Block as TimmBlock
from torch.utils import benchmark
from utils import benchmark_main_helper
import xformers.ops as xops
def replace_module(module: nn.Module, replace_class, factory):
if isinstance(module, replace_class):
return factory(module)
module_output = module
for name, child in module.named_children():
module_output.add_module(name, replace_module(child, replace_class, factory))
del module
return module_output
class TimmMemEffAttention(nn.Module):
def __init__(self, attn: TimmAttention, op=None):
super().__init__()
self.op = op
self.num_heads = attn.num_heads
self.scale = attn.scale
self.qkv = attn.qkv
self.attn_drop = attn.attn_drop
self.proj = attn.proj
self.proj_drop = attn.proj_drop
def forward(self, x):
B, N, C = x.shape
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads)
q, k, v = xops.unbind(qkv, dim=2)
x = xops.memory_efficient_attention(q, k, v, op=self.op).reshape(B, N, C)
x = self.proj(x)
x = self.proj_drop(x)
return x
class TimmSwiGLU(nn.Module):
def __init__(self, mlp: TimmMlp, op=None) -> None:
super().__init__()
self.fc1 = mlp.fc1
self.swiglu = xops.SwiGLU(
in_features=mlp.fc1.in_features,
hidden_features=mlp.fc1.out_features,
bias=True,
)
self.op = op
def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.swiglu(x)
def mod_memeff_attn(model: nn.Module, op=None) -> nn.Module:
return replace_module(model, TimmAttention, partial(TimmMemEffAttention, op=op))
def mod_mlp_to_swiglu(model: nn.Module, op=None) -> nn.Module:
def _mlp_to_swiglu(block: TimmBlock):
block.mlp = TimmSwiGLU(block.mlp, op=op)
return block
return replace_module(model, TimmBlock, _mlp_to_swiglu)
mod_mlp_to_eager_swiglu = partial(mod_mlp_to_swiglu, op=xops.SwiGLUEagerOp)
mod_mlp_to_fast_swiglu = partial(mod_mlp_to_swiglu, op=None)
def compose(*fns):
def compose2(f, g):
return lambda *a, **kw: f(g(*a, **kw))
return reduce(compose2, fns)
MODELS = [
# model_name, model_factory, input_shape
("ViT-B/16", timm.models.vit_base_patch16_224, [512, 3, 224, 224]),
("ViT-B/8", timm.models.vit_base_patch8_224, [64, 3, 224, 224]),
("ViT-L/16", timm.models.vit_large_patch16_224, [128, 3, 224, 224]),
("ViT-g/14", timm.models.vit_giant_patch14_224, [32, 3, 224, 224]),
]
MODIFIERS = [
["mlp", lambda x: x],
["mlp+memeff", compose(mod_mlp_to_fast_swiglu, mod_memeff_attn)],
["swiglu", mod_mlp_to_eagr_swiglu],
["swiglu+fast_swiglu", mod_mlp_to_fast_swiglu],
["swiglu+fast_swiglu+memeff", compose(mod_mlp_to_fast_swiglu, mod_memeff_attn)],
]
def product_dict(**kwargs):
keys = kwargs.keys()
vals = kwargs.values()
for instance in itertools.product(*vals):
yield dict(zip(keys, instance))
CASES = list(
product_dict(
model_info=MODELS,
dtype=[torch.half],
)
)
def benchmark_transformer(model_info, dtype):
device = "cuda"
model_name, model_factory, input_shape = model_info
inp = torch.randn(input_shape, dtype=dtype, device=device)
for mod_name, mod_apply in MODIFIERS:
model: nn.Module = model_factory()
model = mod_apply(model).to(device).to(dtype)
# Make sure we don't have errors
out = model(inp)
grad = out.clone()
out.backward(grad)
yield benchmark.Timer(
stmt="model(inp).backward(grad)",
globals={
"model": model,
"inp": inp,
"grad": grad,
},
label="fw+bw",
description=mod_name,
sub_label=model_name,
)
benchmark_main_helper(benchmark_transformer, CASES)
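# Editor's note: usage sketch for the module-rewriting helpers above. To try
# a single modified model outside the benchmark loop (hypothetical):
#
#   model = timm.models.vit_base_patch16_224()
#   model = compose(mod_mlp_to_fast_swiglu, mod_memeff_attn)(model)
#
# compose applies right-to-left, so the attention swap runs first, then the
# MLP-to-SwiGLU swap.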

View File

@@ -0,0 +1,150 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
# Benchmark the blocksparse operations:
# matrix multiply and softmax
# Matmul can be of three types:
# - Dense x Dense (COO) -> Sparse
# - Sparse x Dense -> Dense
# - Dense x Sparse -> Dense
from typing import Any, Dict
import torch
import triton
from triton.ops.blocksparse import matmul as blocksparse_matmul
from xformers.benchmarks.utils import TestCase, pretty_plot, pretty_print
from xformers.components.attention.core import SparseCS, _matmul_with_mask
def bench_matmul(dtype: torch.dtype, shapes):
results: Dict[str, Any] = {}
Z, H = 1, 1
for M, N, K in shapes:
modes = [(mode, block) for mode in ["sdd", "dsd"] for block in [16, 32, 64]]
for mode, block in modes:
# create inputs
a = torch.randn((Z, H, M, K), dtype=dtype, device="cuda")
b = torch.randn((Z, H, K, N), dtype=dtype, device="cuda")
shape = {
"sdd": (M, N),
"dsd": (a.shape[2], a.shape[3]),
"dds": (b.shape[2], b.shape[3]),
}[mode]
# Pre-sparsify everything
_layout = torch.eye(shape[0] // block, shape[1] // block, dtype=torch.long)
# - blocksparse
layout = _layout.unsqueeze(0).expand(H, -1, -1)
a_triton = (
triton.testing.sparsify_tensor(a, layout, block) if mode == "dsd" else a
)
b_triton = (
triton.testing.sparsify_tensor(b, layout, block) if mode == "dds" else b
)
bsmm = blocksparse_matmul(
layout=layout,
block=block,
mode=mode,
device=torch.device("cuda"),
trans_a=False,
trans_b=False,
)
# - dense
ta = triton.testing.mask_tensor(a, layout, block) if mode == "dsd" else a
tb = triton.testing.mask_tensor(b, layout, block) if mode == "dds" else b
# - sparse / sputnik
mask = torch.ones_like(a, dtype=torch.float, device="cuda")
mask = triton.testing.mask_tensor(mask, layout, block, value=0.0)
a_cs = a.flatten(start_dim=0, end_dim=1).to(
torch.float32
) # Sputnik kernels only handle fp32
b_cs = b.flatten(start_dim=0, end_dim=1).to(torch.float32)
a_cs = a_cs.contiguous()
b_cs = b_cs.transpose(-2, -1).contiguous()
if mode == "sdd":
b_cs = b_cs.transpose(-2, -1)
# pyre-fixme[16]: TODO(T101400990): Pyre did not recognize the
# `SparseCS` import.
sparse_cs_mask = SparseCS(
mask.flatten(start_dim=0, end_dim=1).contiguous(),
device=torch.device("cuda"),
)
# The raw compute steps
op_flops = {
"sdd": 2 * Z * K * float(layout.sum()) * block * block,
"dsd": 2 * Z * N * float(layout.sum()) * block * block,
"dds": 2 * Z * M * float(layout.sum()) * block * block,
}[
mode
] * 1e-12 # TFlops
def torch_step():
return torch.matmul(ta, tb)
def triton_step():
return bsmm(a_triton, b_triton)
def sparse_step():
if mode == "sdd":
return _matmul_with_mask(a_cs, b_cs, sparse_cs_mask)
else:
return sparse_cs_mask.spmm(b_cs)
# Run and measure, report perf in terms of TFlops
for testcase in [
TestCase(
torch_step,
f"pytorch - {mode} - {block}: ",
),
TestCase(
sparse_step,
f"sparse - {mode} - {block}: ",
),
TestCase(
triton_step,
f"triton - {mode} - {block}: ",
),
]:
ms = triton.testing.do_bench(lambda: testcase.function())[0]
key = f"M={M}, N={N}, K={K}"
if key not in results:
results[key] = {}
num_flops = op_flops / ms * 1e3 # Get to TFlop per second
results[key][testcase.name] = f"{num_flops:.1f}"
print(f"{key} - {testcase.name} - {num_flops:.2f}TFlops")
pretty_print(
results,
title="\n ------------- Type: {} -------------".format(dtype),
units="TFlops/s",
)
pretty_plot(
results,
title=f"Sparse/Blocksparse throughput - {dtype}",
filename=f"blocksparse_{dtype}.png",
dash_key="pytorch",
units="TFlops/s",
)
shapes = [(k, k, k) for k in [128, 512, 1024, 2048, 4096]]
bench_matmul(torch.float16, shapes)
bench_matmul(torch.float32, shapes)
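# Editor's note: worked example of the op_flops accounting above. In "sdd"
# mode with Z=1, K=1024, block=32 and the diagonal layout (layout.sum() == 32
# blocks), flops = 2 * 1 * 1024 * 32 * 32 * 32 ~= 6.7e7, i.e. ~6.7e-5 TFlop
# per call; dividing by the measured runtime then gives TFlops/s.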

View File

@@ -0,0 +1,115 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
from typing import Any, Dict, Optional
import torch
import triton
from xformers.benchmarks.utils import TestCase, pretty_plot, pretty_print
from xformers.components import Activation, build_activation
from xformers.triton import FusedDropoutBias
SHAPES = [
(8, 256, 512),
(8, 512, 1024),
(4, 1024, 1024),
(2, 2048, 2048),
(1, 2048, 12288),
(2, 4096, 4096),
]
P = 0.1
def to_gbs_fw(a, ms, bias):
# Read and write the full array
total = 2 * a.numel() * a.element_size()
if bias:
# Read the bias, ideally only once
total += a.shape[-1] * a.element_size()
return total * 1e-9 / (ms * 1e-3)
def bench_dropout(bias: bool, backward: bool, activation: Optional[Activation]):
device = torch.device("cuda")
for dtype in [
torch.float16,
torch.float32,
]:
results: Dict[str, Any] = {}
for B, M, K in SHAPES:
a = torch.rand(
(B, M, K), device=device, dtype=dtype, requires_grad=backward
)
b = torch.rand(K, device=device, dtype=dtype, requires_grad=backward)
torch_act = build_activation(activation)
triton_dropout = FusedDropoutBias(
P, bias_shape=K if bias else None, activation=activation
)
def torch_step(x):
x_ = x + b if bias else x
y = torch.nn.functional.dropout(x_, P)
if activation:
y = torch_act(y)
if backward:
y.grad = None
torch.norm(y).backward()
return y
def triton_step(x):
y = triton_dropout(x)
if backward:
y.grad = None
torch.norm(y).backward()
return y
for testcase in [
TestCase(
torch_step,
"pytorch - bias: {} - fw{} - act: {}".format(
bias, "+bw" if backward else "", activation
),
),
TestCase(
triton_step,
"triton - bias: {} - fw{} - act: {}".format(
bias, "+bw" if backward else "", activation
),
),
]:
time = triton.testing.do_bench(
lambda: testcase.function(a), grad_to_none=[a, b]
)[0]
key = f"B={B}, M={M}, K={K}"
if key not in results:
results[key] = {}
# Record BW
bandwidth = to_gbs_fw(a, time, bias)
results[key][testcase.name] = f"{bandwidth:.1f}"
pretty_print(results, title="\n --- Type: {} --- ".format(dtype), units="GB/s")
pretty_plot(
results,
title="Dropout-Bias-{}-FW{}-{}-Act: {}".format(
bias, "+BW" if backward else "", dtype, activation
),
units="GB/s",
dash_key="pytorch",
)
for activation in [Activation.GeLU, None, Activation.SquaredReLU]:
for bw in [True, False]:
for bias in [True, False]:
bench_dropout(bias, bw, activation)
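# Editor's note: worked example of the to_gbs_fw bandwidth model above. A fp16
# tensor of shape (2, 2048, 2048) has 8,388,608 elements; one read plus one
# write moves 2 * 8,388,608 * 2 bytes ~= 33.6 MB, so a 0.1 ms kernel sustains
# roughly 0.0336 GB / 1e-4 s ~= 336 GB/s (the bias adds one extra row read).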

View File

@@ -0,0 +1,160 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
from typing import Any, Dict, List, Optional
import torch
import triton
from xformers.benchmarks.utils import TestCase, pretty_plot, pretty_print
from xformers.components import Activation, build_activation
from xformers.triton.fused_linear_layer import FusedLinear
SHAPES = [
(8, 512, 256), # Batch x Seq x Embedding
(8, 512, 512),
(4, 512, 1024),
(2, 512, 2048),
(2, 512, 4096),
(2, 512, 8192),
]
# Switch PyTorch to TF32 accumulations, Triton does that also
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
def get_metrics_transform(
activation: Optional[Activation],
a: torch.Tensor,
w: torch.Tensor,
b: Optional[torch.Tensor],
backward: bool,
):
# all operations will involve a * weight.
flop = a.shape[0] * a.shape[1] * w.shape[1] * (2 * a.shape[2] - 1)
# optional activation on top
if activation is not None:
flop += a.numel()
# optionally * 2 (before the bias) if backward
if backward:
flop *= 2
# backward also outputs a gradient with respect to the bias,
# which reduces over all the activation gradients
flop += a.shape[0] * a.shape[1] * w.shape[1]
# backward also outputs another gradient with respect to the weight,
# which is another matmul, this time between the grad_out and the inputs
flop += a.shape[0] * a.shape[1] * w.shape[1] * (2 * a.shape[2] - 1)
# optional bias on top
if b is not None:
flop += b.numel()
def metric_conversion(ms):
# Returns TFlops/second
return flop * 1e-12 / (ms * 1e-3)
return metric_conversion
def bench_linear(activations: List[Optional[Activation]]):
device = torch.device("cuda")
for dtype in [
torch.float32,
torch.float16,
]:
for backward in [True, False]:
for activation in activations:
results: Dict[str, Any] = {}
for bias in [False, True]:
for B, M, K in SHAPES:
a = torch.rand(
B, M, K, device=device, dtype=dtype, requires_grad=backward
)
# Pytorch linear layer + activation
torch_linear = torch.nn.Linear(K, 4 * K, bias=bias).to(
dtype=dtype, device=device
)
torch_activation = build_activation(activation)
# Fused layer equivalent
fused_linear = FusedLinear(
K, 4 * K, bias=bias, activation=activation
).to(dtype=dtype, device=device)
def torch_step(x):
y = torch_activation(torch_linear(x))
if backward:
torch.norm(y).backward()
return y
def triton_step(x):
y = fused_linear(x)
if backward:
torch.norm(y).backward()
return y
metrics_transform = get_metrics_transform(
activation,
a,
torch_linear.weight,
torch_linear.bias,
backward,
)
for testcase in [
TestCase(
torch_step,
"pytorch - {} - {} bias - fw{}".format(
activation,
"no" if not bias else "",
"+bw" if backward else "",
),
),
TestCase(
triton_step,
"triton - {} - {} bias - fw{}".format(
activation,
"no" if not bias else "",
"+bw" if backward else "",
),
),
]:
time = triton.testing.do_bench(
lambda: testcase.function(a)
)[0]
key = f"B={B}, M={M}, K={K}"
if key not in results:
results[key] = {}
metric = metrics_transform(time)
results[key][testcase.name] = f"{metric:.1f}"
pretty_print(
results,
title="\n --- Type: {} ---".format(dtype),
units="TFlops/s",
)
_type = "_fp16" if dtype == torch.float16 else "_fp32"
title = "FusedLinear" + _type + "_FW"
if backward:
title += "_BW"
title += "_" + activation.value if activation else "_none"
pretty_plot(results, title, "TFlops/s", dash_key="pytorch")
activations = [ac for ac in Activation] + [None] # type: ignore
bench_linear(activations) # type: ignore

View File

@@ -0,0 +1,92 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
from typing import Any, Dict
import torch
import triton
from xformers.benchmarks.utils import TestCase, pretty_plot, pretty_print
from xformers.triton import FusedLayerNorm
SHAPES = [
(8, 256, 512),
(8, 512, 1024),
(4, 1024, 1024),
(2, 2048, 2048),
(2, 4096, 4096),
(1, 2048, 12288),
]
def to_gbs_fw(a, ms):
# Read and write the full array
return (2 * a.numel() * a.element_size() * 1e-9) / (ms * 1e-3)
def bench_layernorm(backward: bool):
device = torch.device("cuda")
for dtype in [
torch.float16,
torch.bfloat16,
torch.float32,
]:
results: Dict[str, Any] = {}
for B, M, K in SHAPES:
a = torch.rand(B, M, K, device=device, dtype=dtype, requires_grad=backward)
# PyTorch layer norm
torch_layernorm = torch.nn.LayerNorm([K]).to(dtype=dtype, device=device)
# pyre-ignore[16]: TODO(T101400990): Pyre did not recognize the
# `FusedLayerNorm` import.
# Fused layernorm equivalent
fused_layernorm = FusedLayerNorm([K]).to(dtype=dtype, device=device)
def torch_step(x):
y = torch_layernorm(x)
if backward:
torch.norm(y).backward()
return y
def triton_step(x):
y = fused_layernorm(x)
if backward:
torch.norm(y).backward()
return y
for testcase in [
TestCase(
torch_step,
"pytorch - fw{}".format("+bw" if backward else ""),
),
TestCase(
triton_step,
"triton - fw{}".format("+bw" if backward else ""),
),
]:
time = triton.testing.do_bench(lambda: testcase.function(a))[0]
key = f"B={B}, M={M}, K={K}"
if key not in results:
results[key] = {}
# Record BW
bandwidth = to_gbs_fw(a, time)
results[key][testcase.name] = f"{bandwidth:.1f}"
pretty_print(results, title="\n --- Type: {} --- ".format(dtype), units="GB/s")
pretty_plot(
results,
title="LayerNorm-FW{}-{}".format("+BW" if backward else "", dtype),
units="GB/s",
dash_key="pytorch",
)
for bw in [False, True]:
bench_layernorm(bw)

View File

@@ -0,0 +1,91 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
import torch
from xformers.benchmarks.utils import TestCase, bench_functions
from xformers.triton.softmax import log_softmax as triton_log_softmax
from xformers.triton.softmax import softmax as triton_softmax
SHAPES = [
(8, 384, 128),
(8, 784, 512),
(4, 1024, 768),
(4, 2048, 1024),
(2, 2048, 2048),
(2, 2048, 4096),
(2, 4096, 4096),
(1, 2048, 12288),
]
def pytorch_fw_bw(x):
y = torch.norm(torch.softmax(x, dim=-1))
y.backward()
def triton_causal_fw(x):
_ = triton_softmax(x, causal=True)
def triton_fw_bw(x):
y = torch.norm(triton_softmax(x))
y.backward()
def triton_causal_fw_bw(x):
y = torch.norm(triton_softmax(x, causal=True))
y.backward()
def pytorch_log_fw_bw(x):
y = torch.norm(torch.log_softmax(x, dim=-1))
y.backward()
def triton_log_fw_bw(x):
y = torch.norm(triton_log_softmax(x))
y.backward()
# Test FW
def to_gbs_fw(a, ms):
# Read and write the full array
return (2 * a.numel() * a.element_size() * 1e-9) / (ms * 1e-3)
def to_gbs_fwbw(a, ms):
# same as above, but we do it twice (FW and then gradient)
return 2 * to_gbs_fw(a, ms)
bench_functions(
[
TestCase(lambda x: torch.softmax(x, dim=-1), "pytorch - fw"),
TestCase(triton_softmax, "triton - fw"),
TestCase(triton_causal_fw, "triton - causal - fw"),
TestCase(lambda x: torch.log_softmax(x, dim=-1), "pytorch - log - fw"),
TestCase(triton_log_softmax, "triton - log - fw"),
],
SHAPES,
to_gbs_fw,
"GB/s",
"Softmax_Bandwidth_FW_",
)
# Test FW+BW
bench_functions(
[
TestCase(pytorch_fw_bw, "pytorch - fw+bw"),
TestCase(triton_fw_bw, "triton - fw+bw"),
TestCase(triton_causal_fw_bw, "triton - causal - fw+bw"),
TestCase(pytorch_log_fw_bw, "pytorch - log - fw+bw"),
TestCase(triton_log_fw_bw, "triton - log - fw+bw"),
],
SHAPES,
to_gbs_fwbw,
"GB/s",
"Softmax_Bandwidth_FW_BW_",
)

View File

@@ -0,0 +1,71 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
from typing import Any, Dict, List
import torch
import triton
from xformers.benchmarks.utils import TestCase, pretty_plot, pretty_print
from xformers.triton.sum_strided import sum_2d_dim_0
SHAPES = [
(128, 128),
(384, 128),
(784, 512),
(1024, 768),
(2048, 1024),
(4096, 4096),
]
def to_gbs(a, ms):
# Read the full array, write the non-reduced dimension
return ((a.numel() + a.shape[1]) * a.element_size() * 1e-9) / (ms * 1e-3)
def bench_functions(
test_cases: List[TestCase], shapes, metric_transform, unit, title=""
):
device = torch.device("cuda")
for dtype in [torch.float16, torch.float32]:
results: Dict[str, Any] = {}
for M, N in shapes:
a = torch.rand(M, N, device=device, dtype=dtype, requires_grad=True)
for testcase in test_cases:
time = triton.testing.do_bench(lambda: testcase.function(a))[0]
metric = metric_transform(a, time)
key = f"M={M}, N={N}"
if key not in results:
results[key] = {}
results[key][testcase.name] = f"{metric:.1f}"
_type = " fp16" if dtype == torch.float16 else " fp32"
pretty_print(
results,
title=" ------------- Type: {} ------------- ".format(_type),
units=unit,
)
pretty_plot(results, title + _type, unit, dash_key="pytorch")
bench_functions(
[
TestCase(lambda x: torch.sum(x, dim=0), "pytorch"),
TestCase(sum_2d_dim_0, "triton"),
],
SHAPES,
to_gbs,
"GB/s",
"Strided_sum",
)

View File

@@ -0,0 +1,660 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
import argparse
import contextlib
import copy
import csv
import glob
import logging
import math
import os
import tempfile
from collections import defaultdict, namedtuple
from dataclasses import replace
from typing import Any, Dict, Generator, List, Set, Tuple
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import torch
import tqdm
from torch.utils import benchmark
sns.set()
TestCase = namedtuple("TestCase", ["function", "name"])
_triton_is_available = torch.cuda.is_available()
if _triton_is_available:
try:
import triton
except ImportError as e:
logging.warning(f"Triton is not available: {e}.\nbench_functions")
_triton_is_available = False
def pretty_print(results, title, units):
"""Printout the contents of a dict as a human-readable and Markdown compatible array"""
print(title)
header = " Units: {:<45}".format(units)
print("| " + header + "|" + "".join("{0:<20}|".format(k) for k in results.keys()))
offset = len(header)
print(
"|-{}|".format("-" * offset)
+ "".join("{}|".format("-" * 20) for _ in results.keys())
)
workloads: Dict[str, Any] = {k: [] for v in results.values() for k in v.keys()}
for v in results.values():
for k in v.keys():
workloads[k].append(v[k])
for k, w in workloads.items():
print(
"| {0:<{offset}}|".format(k, offset=offset)
+ "".join("{:<20}|".format(v) for v in w)
)
print("")
def pretty_plot(
results, title, units: str, filename=None, dash_key="", legend_loc="lower right"
):
"""Graph out the contents of a dict.
Dash key means that if the result label has this key, then it will be displayed with a dash
"""
if not filename:
filename = title + ".png"
# Sanitize the filename
filename = (
filename.replace(" ", "_").replace("/", "_").replace("-", "_").replace(":", "")
)
# Gather all the results in "columns"
workloads: Dict[str, Any] = {k: [] for v in results.values() for k in v.keys()}
for v in results.values():
for k in v.keys():
workloads[k].append(float(v[k]))
# Make sure that the plot is big enough
f = plt.figure()
f.set_figwidth(6)
f.set_figheight(6)
# Display the collections
for k, v in workloads.items():
if dash_key and dash_key in k:
plt.plot(list(results.keys()), v, "--")
else:
plt.plot(list(results.keys()), v)
plt.title(title)
plt.legend(list(workloads.keys()), loc=legend_loc)
plt.ylabel(units)
plt.xticks(rotation=45)
plt.savefig(filename, bbox_inches="tight")
plt.close(f)
if _triton_is_available:
def bench_functions(
test_cases: List[TestCase], shapes, metric_transform, unit, title=""
):
device = torch.device("cuda")
for dtype in [torch.bfloat16, torch.float16, torch.float32]:
results: Dict[str, Any] = {}
for B, M, K in shapes:
a = torch.rand(B, M, K, device=device, dtype=dtype, requires_grad=True)
for testcase in test_cases:
time = triton.testing.do_bench(lambda: testcase.function(a))[0]
metric = metric_transform(a, time)
key = f"B={B}, M={M}, K={K}"
if key not in results:
results[key] = {}
results[key][testcase.name] = f"{metric:.1f}"
pretty_print(
results,
title=" ------------- Type: {} ------------- ".format(dtype),
units=unit,
)
pretty_plot(results, title + str(dtype), unit, dash_key="pytorch")
def pretty_barplot(results, title, units: str, filename=None, dash_key=""):
"""Graph out the contents of a dict.
Dash key means that if the result label has this key, then it will be displayed with a dash
"""
if not filename:
filename = title + ".png"
# Sanitize the filename
filename = (
filename.replace(" ", "_").replace("/", "_").replace("-", "_").replace(":", "")
)
xlabels = list(results.keys())
# Gather all the results in "columns"
workloads: Dict[str, Any] = {k: [] for v in results.values() for k in v.keys()}
for v in results.values():
for k in v.keys():
workloads[k].append(float(v[k]))
options = list(workloads.keys())
group_len = len(options)
num_groups = len(next(iter(workloads.values())))
group_width = group_len + 1
# Make sure that the plot is big enough
f = plt.figure()
f.set_figwidth(6)
f.set_figheight(6)
for idx in range(group_len):
option = options[idx]
values = workloads[option]
xloc = np.arange(1 + idx, group_width * num_groups, group_width)
plt.bar(xloc, values, width=1, edgecolor="black")
plt.title(title)
plt.legend(list(workloads.keys()), loc="upper right")
plt.ylabel(units)
ax = plt.gca()
xticks_loc = np.arange(
1 + (group_len - 1) / 2.0, group_width * num_groups, group_width
)
ax.set_xticks(xticks_loc, xlabels)
plt.xticks(rotation=45)
plt.setp(ax.xaxis.get_majorticklabels(), ha="right")
ax.set_axisbelow(True)
ax.yaxis.grid(color="gray", linestyle="dashed")
ax.xaxis.grid(color="gray", linestyle="dashed")
plt.savefig(filename, bbox_inches="tight")
plt.close(f)
def rmf(filename: str) -> None:
"""Remove a file like rm -f."""
try:
os.remove(filename)
except FileNotFoundError:
pass
@contextlib.contextmanager
def temp_files_ctx(num: int) -> Generator:
"""A context to get tempfiles and ensure they are cleaned up."""
files = [tempfile.mkstemp()[1] for _ in range(num)]
yield tuple(files)
# temp files could have been removed, so we use rmf.
for name in files:
rmf(name)
META_ALGORITHM = "algorithm"
BASELINE_DESCRIPTIONS = ["eager", "vanilla", "pytorch"]
# Serialize/deserialize to CSV
# We could use pkl, but resort to CSV for readability
def _benchmark_results_from_csv(filename: str) -> List[Tuple[Dict[str, Any], Any]]:
parts = os.path.basename(filename).split(".")
env = ""
description = ""
if len(parts) == 3:
env = parts[1]
description = parts[0]
data = []
with open(filename, "r") as csvfile:
reader = csv.DictReader(csvfile)
for row in reader:
if description != "" and row["description"] not in BASELINE_DESCRIPTIONS:
row["description"] = description
task_spec = benchmark.utils.common.TaskSpec(
stmt="",
setup="",
global_setup="",
label=row["label"],
sub_label=row["sub_label"],
description=row["description"],
env=env,
num_threads=int(row["num_threads"]),
)
measurement = benchmark.utils.common.Measurement(
number_per_run=1,
raw_times=[float(row["runtime_us"]) / (1000.0 * 1000)],
task_spec=task_spec,
)
measurement.mem_use = float(row["mem_use_mb"]) # type: ignore
data.append(
(
{
META_ALGORITHM: row["algorithm"]
if row["algorithm"] != ""
else None,
},
measurement,
)
)
return data
def _benchmark_results_to_csv(
filename: str, results: List[Tuple[Dict[str, Any], Any]]
) -> None:
data = [
{
"sub_label": r.task_spec.sub_label,
"label": r.task_spec.label,
"num_threads": r.task_spec.num_threads,
"algorithm": metadata.get(META_ALGORITHM, ""),
"description": r.task_spec.description
if r.task_spec.description in BASELINE_DESCRIPTIONS
else "",
"runtime_us": int(1000 * 1000 * r.mean),
"mem_use_mb": r.mem_use,
}
for metadata, r in results
]
with open(filename, "w+", newline="") as csvfile:
writer = csv.DictWriter(csvfile, fieldnames=list(data[0].keys()))
writer.writeheader()
for d in data:
writer.writerow(d)
def _finalize_results(results: List[Tuple[Dict[str, Any], Any]]) -> List[Any]:
"""
Returns the list of measurements to feed into `benchmark.Compare`; if we
have runs with different algorithms, we also add the algorithm name to
the column titles
"""
all_algorithms: Set[str] = set()
all_description: Set[str] = set()
for metadata, r in results:
algo = metadata.get(META_ALGORITHM, None)
if algo is not None:
all_algorithms.add(algo)
all_description.add(r.task_spec.description)
display_algo = len(all_algorithms) > 1
display_descr = len(all_description) > 1
display_results = []
for metadata, r in results:
algo = metadata.get(META_ALGORITHM, None)
if algo is None:
display_results.append(r)
else:
r = copy.copy(r)
description = ""
if display_descr:
description = r.task_spec.description
if display_algo:
if display_descr:
description += "["
description += algo
if display_descr:
description += "]"
r.task_spec = replace(r.task_spec, description=description)
display_results.append(r)
return display_results
def _render_bar_plot(results: List[Any], store_results_folder: str) -> None:
if not results:
return
runtime: Dict[str, Dict[str, float]] = defaultdict(dict)
memory_usage: Dict[str, Dict[str, float]] = defaultdict(dict)
all_descriptions: List[str] = []
for r in results:
# Hacky: use a list to preserve order
if r.task_spec.description not in all_descriptions:
if r.task_spec.description in BASELINE_DESCRIPTIONS:
all_descriptions.insert(0, r.task_spec.description)
else:
all_descriptions.append(r.task_spec.description)
runtime[r.task_spec.sub_label][r.task_spec.description] = r.mean
memory_usage[r.task_spec.sub_label][r.task_spec.description] = r.mem_use
all_data_mem: List[Any] = []
all_data_run: List[Any] = []
for key, runtime_values in runtime.items():
memory_values = memory_usage[key]
denom = memory_values.get(all_descriptions[0], math.inf)
if denom == 0:
all_data_mem.append([key] + [0] * len(all_descriptions))
else:
all_data_mem.append(
[key] + [memory_values.get(d, 0) / denom for d in all_descriptions]
)
all_data_run.append(
[key]
+ [
runtime_values.get(all_descriptions[0], 0)
/ runtime_values.get(d, math.inf)
for d in all_descriptions
]
)
if all_descriptions[0] == "":
all_descriptions[0] = "baseline"
else:
all_descriptions[0] = f"{all_descriptions[0]} (baseline)"
for data, filename, title in [
(all_data_mem, "mem.png", "Memory usage (vs baseline, lower is better)"),
(
all_data_run,
"runtime.png",
"Runtime speedup (vs baseline, higher is better)",
),
]:
df = pd.DataFrame(data, columns=["Configuration"] + all_descriptions)
df.plot(
x="Configuration",
kind="bar",
stacked=False,
title=title,
)
plt.tight_layout()
filename_full = os.path.join(store_results_folder, filename)
plt.savefig(filename_full)
print(f"Saved plot: {filename_full}")
def benchmark_main_helper(benchmark_fn, cases: List[Dict[str, Any]], **kwargs) -> None:
"""
Helper function to run benchmarks.
Supports loading previous results for comparison, and saving current results to file.
"""
parser = argparse.ArgumentParser()
parser.add_argument(
"--fn", default=None, type=str, help="Only benchmark this function"
)
parser.add_argument(
"--label", default=None, type=str, help="Store results to a file"
)
parser.add_argument(
"--fail_if_regression",
action="store_true",
help="Enabled in CI to check against performance regressions",
)
parser.add_argument(
"--compare",
default=None,
type=str,
help="Compare to previously stored benchmarks (coma separated)",
)
parser.add_argument(
"--omit-baselines",
action="store_true",
help="Do not run the (potentially slow) baselines",
)
parser.add_argument(
"--quiet",
action="store_true",
help="Skip intermediate results and progress bar",
)
args = parser.parse_args()
if args.fn is not None and args.fn != benchmark_fn.__name__:
print(f'Skipping benchmark "{benchmark_fn.__name__}"')
return
benchmark_run_and_compare(
benchmark_fn=benchmark_fn,
cases=cases,
optimized_label="optimized" if args.label is None else args.label,
fail_if_regression=args.fail_if_regression,
compare=args.compare.split(",") if args.compare is not None else [],
quiet=args.quiet,
omit_baselines=args.omit_baselines,
**kwargs,
)
def benchmark_run_and_compare(
benchmark_fn,
cases: List[Dict[str, Any]],
compare: List[str],
omit_baselines: bool = False,
fail_if_regression: bool = False,
quiet: bool = False,
optimized_label: str = "optimized",
*,
min_run_time: int = 2,
atol_s: float = 30e-6,
rtol: float = 0.05,
) -> None:
SKIP_VANILLA_TASKS_IF_ALREADY_DONE = True
results_compare_to = []
results = []
store_results_folder = os.path.expanduser(
os.path.join(
os.environ.get(
"XFORMERS_BENCHMARKS_CACHE",
os.path.join("~", ".cache", "xformers", "benchmarks"),
),
benchmark_fn.__name__,
)
)
try:
env = (
torch.cuda.get_device_name(torch.cuda.current_device())
.replace(" ", "_")
.replace("-", "_")
.replace(".", "_")
)
except (RuntimeError, AssertionError): # No GPU
env = "cpu"
assert (
"." not in optimized_label
), f"label=`{optimized_label}` should not contain dots"
assert "." not in env, f"env=`{env}` should not contain dots"
os.makedirs(store_results_folder, exist_ok=True)
# Load runs that we want to compare to
skip_vanilla_tasks = set()
for cmp_name in compare:
name_with_env = cmp_name if "." in cmp_name else f"{cmp_name}.*"
for filename in glob.glob(
os.path.join(store_results_folder, f"{name_with_env}.csv")
):
loaded = _benchmark_results_from_csv(filename)
for m, r in loaded:
if r.task_spec.env == env and SKIP_VANILLA_TASKS_IF_ALREADY_DONE:
skip_vanilla_tasks.add(
(r.task_spec.sub_label, r.task_spec.num_threads)
)
results_compare_to += loaded
if not quiet:
pbar = tqdm.tqdm(cases, leave=False)
cases = pbar
for case in cases:
if quiet:
print(str(case))
else:
pbar.write(f"====== {str(case)} ======")
try:
benchmarks_generator = benchmark_fn(**case)
except NotImplementedError:
# pbar.write(f"Skipped (NotImplementedError)")
continue
except RuntimeError as e:
if "CUDA out of memory" not in str(e):
raise
if not quiet:
pbar.write("Skipped (OOM)")
continue
name = None
try:
for benchmark_object in benchmarks_generator:
is_optimized = (
benchmark_object._task_spec.description not in BASELINE_DESCRIPTIONS
)
metadata = {}
if is_optimized:
metadata[META_ALGORITHM] = benchmark_object._task_spec.description
benchmark_object._task_spec = replace(
benchmark_object._task_spec, description=optimized_label
)
elif (
omit_baselines
or (
benchmark_object._task_spec.sub_label,
benchmark_object._task_spec.num_threads,
)
in skip_vanilla_tasks
):
continue
memory = math.inf
try:
torch.cuda.synchronize()
torch.cuda.reset_peak_memory_stats()
mem_begin = torch.cuda.max_memory_allocated() / 2**20
benchmark_object._task_spec = replace(
benchmark_object._task_spec, env=env
)
measurement = benchmark_object.blocked_autorange(
min_run_time=min_run_time
)
torch.cuda.synchronize()
results.append((metadata, measurement))
name = measurement.task_spec.description
memory = torch.cuda.max_memory_allocated() / 2**20 - mem_begin
measurement.mem_use = memory
except RuntimeError as e:
if "CUDA out of memory" not in str(e):
raise
if not quiet:
pbar.write("Skipped (OOM)")
finally:
del benchmark_object
if not quiet:
pbar.write(f"{name}: memory used: {memory} MB")
except RuntimeError as e:
if "CUDA out of memory" not in str(e):
raise
if not quiet:
pbar.write("Skipped (OOM)")
# Display results for benchmarks we just calculated
if name is not None and not quiet:
def matches_current(r):
return (
r[1].task_spec.sub_label == results[-1][1].task_spec.sub_label
and r[1].task_spec.label == results[-1][1].task_spec.label
)
pbar.write(
str(
benchmark.Compare(
_finalize_results(
list(filter(matches_current, results))
+ list(filter(matches_current, results_compare_to))
)
)
)
)
results_for_print = _finalize_results(results + results_compare_to)
benchmark.Compare(results_for_print).print()
_render_bar_plot(results_for_print, store_results_folder)
# Save runs to a file
if results and optimized_label is not None:
write_to_path = os.path.join(
store_results_folder, f"{optimized_label}.{env}.csv"
)
_benchmark_results_to_csv(write_to_path, results)
print(f"Saved results to {write_to_path}")
if fail_if_regression:
_fail_if_regressions(
results, reference=results_compare_to, atol_s=atol_s, rtol=rtol
)
def _fail_if_regressions(
results: List[Any], reference: List[Any], atol_s: float, rtol: float
) -> None:
def get_measurement_id(r):
return (
r[0].get(META_ALGORITHM, ""),
r[1].task_spec.label,
r[1].task_spec.sub_label,
r[1].task_spec.env,
)
id_to_result = {}
for r in results:
id_to_result[get_measurement_id(r)] = r[1]
num_better = 0
num_worse = 0
num_nochange = 0
num_unk = 0
reference_set = set()
for ref in reference:
if ref[1].task_spec.description in BASELINE_DESCRIPTIONS:
continue
benchmark_id = get_measurement_id(ref)
if benchmark_id in reference_set:
raise ValueError(f"Duplicate benchmark in reference for {benchmark_id}")
reference_set.add(benchmark_id)
if benchmark_id not in id_to_result:
num_unk += 1
continue
res = id_to_result[benchmark_id]
# If the change is significant
if abs(ref[1].mean - res.mean) - rtol * ref[1].mean > atol_s:
is_now_better = res.mean < ref[1].mean
if is_now_better:
num_better += 1
else:
num_worse += 1
cmp = "IMPROVED" if is_now_better else "REGRESS "
print(cmp, benchmark_id, f"ref={ref[1].mean}", f"now={res.mean}")
else:
num_nochange += 1
print("Regression test summary:")
print(f" Better : {num_better}")
print(f" No change: {num_nochange}")
print(f" Worse : {num_worse}")
if num_unk > 0:
print(f" (no ref) : {num_unk}")
if num_worse > 0:
raise RuntimeError("At least one benchmark regressed!")
if num_nochange == 0:
raise RuntimeError("No reference found")