First commit
4
pkgs/xformers/benchmarks/LRA/__init__.py
Normal file
@@ -0,0 +1,4 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
Binary file not shown.
96
pkgs/xformers/benchmarks/LRA/batch_fetch_results.py
Normal file
@@ -0,0 +1,96 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.


import argparse
import json
import logging
from pathlib import Path
from typing import Any, Dict

if __name__ == "__main__":
    # Get the user requests
    parser = argparse.ArgumentParser(
        "Collect results from a given batch of distributed results"
    )
    parser.add_argument("-ck", "--checkpoint_path", required=True)
    args = parser.parse_args()

    logging.getLogger().setLevel(logging.INFO)

    # Go through all the data in the given repo, try to find the end results
    root = Path(args.checkpoint_path)

    # - list all the mechanisms being benchmarked
    results: Dict[str, Any] = {}

    for attention in filter(lambda x: x.is_dir(), root.iterdir()):
        logging.info(f"\nFound results for {attention.stem}")
        task_jsons = attention.glob("*/test_eval_summary.json")
        results[attention.stem] = {}

        for task in task_jsons:
            # The enclosing directory is named "<task>__<attention>" by setup_log
            task_name = task.parent.stem.split("__")[0]
            logging.info(f"Logs found for task: {task_name}")
            results[attention.stem][task_name] = -1
            found_result = False

            # - collect the individual results
            with open(task, "r") as result_file:
                dct = json.load(result_file)
                if "test_accu_mean" in dct:
                    found_result = True
                    results[attention.stem][task_name] = dct["test_accu_mean"]

                    logging.info(
                        f"Final result found for {task_name} at step {dct['train_step_idx']}: "
                        f"{results[attention.stem][task_name]}"
                    )
                else:
                    break

            # - report an error if no result was found
            if not found_result:
                ERR_TAIL = 30

                logging.warning(
                    f"No result found for {task_name}, showing the error log in {task.parent}"
                )
                err_log = Path(task.parent).glob("*.err")
                print("*****************************************************")
                with open(next(err_log), "r") as err_file:
                    for i, line in enumerate(reversed(err_file.readlines())):
                        print(line, end="")
                        if i > ERR_TAIL:
                            break
                print("*****************************************************")

    logging.info(f"\nCollected results: {json.dumps(results, indent=2)}")

    # - reduction: compute the average
    tasks = set(t for v in results.values() for t in v.keys())
    # -- fill in the possible gaps
    for att in results.keys():
        for t in tasks:
            if t not in results[att].keys():
                results[att][t] = 0.0

    # -- add the average value
    for att in results.keys():
        results[att]["AVG"] = round(sum(results[att][t] for t in tasks) / len(tasks), 2)

    # - Format as an array, markdown style
    tasks_sort = sorted(
        set(t for v in results.values() for t in v.keys()), reverse=True
    )
    print(
        "{0:<20}".format("") + "".join("{0:<20} ".format(t[:10]) for t in tasks_sort)
    )

    for att in results.keys():
        print(
            "{0:<20}".format(att)
            + "".join("{0:<20} ".format(results[att][t]) for t in tasks_sort)
        )
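
For orientation, here is a minimal sketch of the on-disk layout the collection script above walks, using hypothetical names ("favor" as the attention mechanism, "text" as the task); the "<task>__<attention>" sub-directory is the naming produced by setup_log in run_tasks.py:

import json
from pathlib import Path

root = Path("/tmp/lra_ckpt")  # stand-in for --checkpoint_path
run_dir = root / "favor" / "text__favor"  # <attention>/<task>__<attention>
run_dir.mkdir(parents=True, exist_ok=True)
(run_dir / "test_eval_summary.json").write_text(
    json.dumps({"test_accu_mean": 0.642, "train_step_idx": 20000})
)

# Same glob + parsing steps as the loop above:
for task in (root / "favor").glob("*/test_eval_summary.json"):
    task_name = task.parent.stem.split("__")[0]  # -> "text"
    print(task_name, json.loads(task.read_text())["test_accu_mean"])
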
49
pkgs/xformers/benchmarks/LRA/batch_submit.py
Normal file
@@ -0,0 +1,49 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.


import argparse
import os
from pathlib import Path

from xformers.benchmarks.LRA.run_tasks import Task
from xformers.components.attention import ATTENTION_REGISTRY


def get_default_shared_folder() -> str:
    checkpoint_paths = ["/checkpoint", "/checkpoints"]
    for checkpoint_path in checkpoint_paths:
        if Path(checkpoint_path).is_dir():
            return checkpoint_path

    return "."


if __name__ == "__main__":
    default_checkpoint_path = get_default_shared_folder()

    # Get the user requests
    parser = argparse.ArgumentParser(
        "Benchmark different attention mechanisms on various sequence lengths"
    )
    parser.add_argument("-c", "--config_path", required=True)
    parser.add_argument("-ck", "--checkpoint_path", required=True)
    parser.add_argument(
        "-a", "--attentions", nargs="+", default=list(ATTENTION_REGISTRY.keys())
    )
    parser.add_argument("-t", "--tasks", nargs="+", default=[t.value for t in Task])
    parser.add_argument(
        "--partition", default="a100", type=str, help="Partition where to submit"
    )
    args = parser.parse_args()

    for attention in args.attentions:
        for task in args.tasks:
            os.system(
                "python3 run_with_submitit.py"
                + f" --attention {attention} --task {task} --config {args.config_path}"
                + f" --checkpoint_dir {args.checkpoint_path}/{attention}/{task}"
                + f" --partition {args.partition}"
            )
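
As a sketch of what the nested loop above shells out to, this is the expanded command for one (attention, task) pair; "favor", "text" and the paths are placeholder values standing in for entries of the attention registry and the Task enum:

attention, task = "favor", "text"
config_path, checkpoint_path, partition = "config.json", "/checkpoint/me/lra", "a100"

cmd = (
    "python3 run_with_submitit.py"
    + f" --attention {attention} --task {task} --config {config_path}"
    + f" --checkpoint_dir {checkpoint_path}/{attention}/{task}"
    + f" --partition {partition}"
)
print(cmd)
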
4
pkgs/xformers/benchmarks/LRA/code/__init__.py
Normal file
@@ -0,0 +1,4 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
Binary file not shown.
46
pkgs/xformers/benchmarks/LRA/code/dataset.py
Normal file
@@ -0,0 +1,46 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.


# CREDITS: Almost as-is from the Nystromformer repo
# https://github.com/mlpen/Nystromformer

import logging
import pickle

import torch
from torch.utils.data.dataset import Dataset

logging.getLogger().setLevel(logging.INFO)


class LRADataset(Dataset):
    def __init__(self, file_path, seq_len):
        with open(file_path, "rb") as f:
            self.examples = pickle.load(f)

        self.seq_len = seq_len
        logging.info(f"Loaded {file_path}... size={len(self.examples)}")

    def __len__(self):
        return len(self.examples)

    def __getitem__(self, i):
        return self.create_inst(self.examples[i], self.seq_len)

    @staticmethod
    def create_inst(inst, seq_len):
        output = {
            "input_ids_0": torch.tensor(inst["input_ids_0"], dtype=torch.long)[:seq_len]
        }
        output["mask_0"] = (output["input_ids_0"] != 0).float()

        if "input_ids_1" in inst:
            output["input_ids_1"] = torch.tensor(inst["input_ids_1"], dtype=torch.long)[
                :seq_len
            ]
            output["mask_1"] = (output["input_ids_1"] != 0).float()
        output["label"] = torch.tensor(inst["label"], dtype=torch.long)
        return output
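
A small worked example of create_inst on a dual-input ("retrieval"-style) instance; the token ids below are made up:

import torch

from xformers.benchmarks.LRA.code.dataset import LRADataset

inst = {"input_ids_0": [5, 9, 0, 0], "input_ids_1": [7, 0, 0, 0], "label": 1}
out = LRADataset.create_inst(inst, seq_len=3)
print(out["input_ids_0"])  # tensor([5, 9, 0])  -- truncated to seq_len
print(out["mask_0"])       # tensor([1., 1., 0.])  -- non-zero ids get mask 1
print(out["mask_1"])       # tensor([1., 0., 0.])
print(out["label"])        # tensor(1)
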
288
pkgs/xformers/benchmarks/LRA/code/model_wrapper.py
Normal file
@@ -0,0 +1,288 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.


# CREDITS: adapted from the Nystromformer repo
# https://github.com/mlpen/Nystromformer

from enum import Enum
from typing import Dict, Union

import pytorch_lightning as pl
import torch
import torch.nn as nn

from xformers.components import build_attention
from xformers.components.multi_head_dispatch import MultiHeadDispatchConfig
from xformers.factory import xFormer, xFormerConfig, xFormerEncoderConfig
from xformers.utils import generate_matching_config

PLOutput = Dict[str, Union[float, torch.Tensor]]


class Pooling(str, Enum):
    MEAN = "mean"
    CLS = "cls"


def pooling(mode: Pooling):
    def pool_cls(inp):
        return inp[:, 0, :]

    def pool_mean(inp):
        return inp.mean(dim=1)

    return {Pooling.MEAN: pool_mean, Pooling.CLS: pool_cls}[mode]


def append_cls(inp, mask, vocab_size):
    batch_size = inp.size(0)
    cls_id = (
        (vocab_size - 1) * torch.ones(batch_size, dtype=torch.long, device=inp.device)
    ).long()
    cls_mask = torch.ones(batch_size, dtype=torch.float, device=mask.device)
    inp = torch.cat([cls_id[:, None], inp[:, :-1]], dim=-1)
    mask = torch.cat([cls_mask[:, None], mask[:, :-1]], dim=-1)
    return inp, mask


def patch_model_config(config, attention_name):
    # Rebuild a specific config out of generic + extra params
    commons = config["common"]
    try:
        extra_attention_settings = config["extra_settings"]["attention"][attention_name]
    except KeyError:
        extra_attention_settings = None

    for bc in config["xformer"]:
        bc["dim_model"] = commons["dim_model"]
        bc["position_encoding_config"].update(commons)
        bc["feedforward_config"].update(commons)
        bc["multi_head_config"].update(commons)
        bc["multi_head_config"]["attention"].update(commons)
        bc["multi_head_config"]["attention"]["name"] = attention_name
        bc["multi_head_config"]["attention"]["dim_head"] = (
            commons["dim_model"] / commons["num_heads"]
        )
        if extra_attention_settings is not None:
            bc["multi_head_config"]["attention"].update(extra_attention_settings)

        bc["multi_head_config"] = generate_matching_config(
            bc["multi_head_config"], MultiHeadDispatchConfig
        )
        bc["multi_head_config"].attention = build_attention(
            bc["multi_head_config"].attention
        )
        bc = generate_matching_config(bc, xFormerEncoderConfig)

    return config


class SCHead(nn.Module):
    def __init__(self, config, dim_embedding, dim_mlp):
        super().__init__()
        self.pooling = pooling(Pooling(config["pooling_mode"]))

        self.mlpblock = nn.Sequential(
            nn.Linear(dim_embedding, dim_mlp),
            nn.ReLU(),
            nn.Linear(dim_mlp, config["common"]["num_classes"]),
        )

    def forward(self, inp: torch.Tensor):
        seq_score = self.mlpblock(self.pooling(inp))
        return seq_score


class SCHeadDual(nn.Module):
    def __init__(self, config, dim_embedding, dim_mlp):
        super().__init__()
        self.pooling = pooling(Pooling(config["pooling_mode"]))

        self.mlpblock = nn.Sequential(
            nn.Linear(
                dim_embedding * 4,
                dim_mlp,
            ),
            nn.ReLU(),
            nn.Linear(dim_mlp, config["common"]["num_classes"]),
        )

    def forward(self, inp_0: torch.Tensor, inp_1: torch.Tensor):
        X_0 = self.pooling(inp_0)
        X_1 = self.pooling(inp_1)
        seq_score = self.mlpblock(torch.cat([X_0, X_1, X_0 * X_1, X_0 - X_1], dim=-1))
        return seq_score


class ModelTrunk(pl.LightningModule):
    def __init__(self, config, model_name):
        super().__init__()

        config_model = config["model"]
        self.config_training = config["training"]

        self.enable_amp = config["training"]["mixed_precision"]
        self.pooling_mode = Pooling(config_model["pooling_mode"])
        self.vocab_size = config_model["common"]["vocab_size"]

        # Rebuild a specific config out of generic + extra params
        self.config_model = patch_model_config(config_model, model_name)
        self.model = xFormer.from_config(xFormerConfig(config_model["xformer"]))
        self.norm = nn.LayerNorm(self.config_model["common"]["dim_model"])

        ff_config = self.config_model["xformer"][0]["feedforward_config"]
        self.dim_mlp = (
            self.config_model["common"]["dim_model"]
            * ff_config["hidden_layer_multiplier"]
        )

    def training_step(  # type: ignore
        self, batch: Dict[str, torch.Tensor], batch_idx: int
    ) -> PLOutput:
        outputs = self(**batch)
        self.logger.log_metrics({f"train_{k}": v for k, v in outputs.items()})  # type: ignore
        self.log("train_accu", outputs["accu"], sync_dist=True)
        return outputs

    def training_epoch_end(self, outputs):
        logs = self.eval_epoch_end(outputs)
        self.log("train_accu_mean", logs["accu"], sync_dist=True)

    def configure_optimizers(self):
        optimizer = torch.optim.AdamW(
            self.parameters(),
            lr=self.config_training["learning_rate"],
            betas=(0.9, 0.999),
            eps=1e-6,
            weight_decay=self.config_training["weight_decay"],
        )

        lr_scheduler = torch.optim.lr_scheduler.OneCycleLR(
            optimizer=optimizer,
            max_lr=self.config_training["learning_rate"],
            pct_start=self.config_training["warmup"]
            / self.config_training["num_train_steps"],
            anneal_strategy=self.config_training["lr_decay"],
            total_steps=self.config_training["num_train_steps"],
        )

        return [optimizer], [lr_scheduler]

    def eval_step(self, batch: Dict[str, torch.Tensor], batch_idx: int) -> PLOutput:
        outputs = self(**batch)
        return outputs

    def eval_epoch_end(self, outputs, prefix: str = "train"):
        logs = {}
        counts = torch.tensor([x["count"] for x in outputs]).float()
        logs["count"] = counts.sum()
        for k in ("accu", "loss"):
            logs[k] = (torch.tensor([x[k] for x in outputs]) * counts).sum() / logs[
                "count"
            ]
            self.log(f"{prefix}_{k}_mean", logs[k], sync_dist=True)
        return logs

    def validation_step(  # type: ignore
        self, batch: Dict[str, torch.Tensor], batch_idx: int
    ) -> PLOutput:
        outputs = self.eval_step(batch, batch_idx)
        self.logger.log_metrics({f"val_{k}": v for k, v in outputs.items()})  # type: ignore
        self.log("val_accu", outputs["accu"], sync_dist=True, prog_bar=True)
        return outputs

    def validation_epoch_end(self, outputs):
        self.eval_epoch_end(outputs, prefix="val")

    def test_step(  # type: ignore
        self, batch: Dict[str, torch.Tensor], batch_idx: int
    ) -> PLOutput:
        return self.eval_step(batch, batch_idx)

    def test_epoch_end(self, outputs):
        self.eval_epoch_end(outputs, prefix="test")


class ModelForSC(ModelTrunk):
    def __init__(self, config, model_name):
        # Setup trunk
        super().__init__(config, model_name)

        self.seq_classifer = SCHead(
            self.config_model,
            dim_embedding=self.config_model["common"]["dim_model"],
            dim_mlp=self.dim_mlp,
        )

    def forward(  # type: ignore
        self, input_ids_0: torch.Tensor, mask_0: torch.Tensor, label: torch.Tensor
    ):

        if self.pooling_mode == Pooling.CLS:
            input_ids_0, mask_0 = append_cls(input_ids_0, mask_0, self.vocab_size)

        token_out = self.norm(
            self.model(input_ids_0, encoder_input_mask=mask_0)
        ) * mask_0.unsqueeze(-1)

        seq_scores = self.seq_classifer(token_out)

        seq_loss = torch.nn.CrossEntropyLoss(reduction="none")(seq_scores, label)
        seq_accu = (seq_scores.argmax(dim=-1) == label).to(torch.float32)
        outputs = {
            "loss": seq_loss.mean(),
            "accu": seq_accu.mean(),
            "count": label.size(0),
        }

        return outputs


class ModelForSCDual(ModelTrunk):
    def __init__(self, config, model_name):
        # Setup trunk
        super().__init__(config, model_name)

        self.seq_classifer = SCHeadDual(
            self.config_model,
            dim_embedding=self.config_model["common"]["dim_model"],
            dim_mlp=self.dim_mlp,
        )

    def forward(  # type: ignore
        self,
        input_ids_0: torch.Tensor,
        input_ids_1: torch.Tensor,
        mask_0: torch.Tensor,
        mask_1: torch.Tensor,
        label: torch.Tensor,
    ):

        mask_0, mask_1 = mask_0.long(), mask_1.long()

        if self.pooling_mode == Pooling.CLS:
            input_ids_0, mask_0 = append_cls(input_ids_0, mask_0, self.vocab_size)
            input_ids_1, mask_1 = append_cls(input_ids_1, mask_1, self.vocab_size)

        # Concatenate the two inputs into one batch
        input_ids = torch.cat([input_ids_0, input_ids_1], dim=0)
        masks = torch.cat([mask_0, mask_1], dim=0)

        tokens_out = self.norm(
            self.model(input_ids, encoder_input_mask=masks)
        ) * masks.unsqueeze(-1)

        seq_scores = self.seq_classifer(*torch.chunk(tokens_out, 2, dim=0))

        seq_loss = torch.nn.CrossEntropyLoss(reduction="none")(seq_scores, label)
        seq_accu = (seq_scores.argmax(dim=-1) == label).to(torch.float32)
        outputs = {
            "loss": seq_loss.mean(),
            "accu": seq_accu.mean(),
            "count": label.size(0),
        }

        return outputs
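
A quick check of append_cls on a toy batch: the CLS id (vocab_size - 1) is prepended and the last position is dropped, so the sequence length is preserved and the mask stays aligned:

import torch

from xformers.benchmarks.LRA.code.model_wrapper import append_cls

inp = torch.tensor([[11, 12, 13, 14]])
mask = torch.tensor([[1.0, 1.0, 0.0, 0.0]])
new_inp, new_mask = append_cls(inp, mask, vocab_size=100)
print(new_inp)   # tensor([[99, 11, 12, 13]])
print(new_mask)  # tensor([[1., 1., 1., 0.]])
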
148
pkgs/xformers/benchmarks/LRA/run_grid_search.py
Normal file
@@ -0,0 +1,148 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.


import itertools
import os
import uuid
from datetime import date
from pathlib import Path
from typing import Dict, Iterable

import submitit

from xformers.benchmarks.LRA.run_with_submitit import (
    Trainer,
    get_init_file,
    get_shared_folder,
    parse_args,
)


def grid_parameters(grid: Dict):
    """
    Yield all combinations of parameters in the grid (as a dict)
    """
    grid_copy = dict(grid)
    # Turn single values into iterables
    for k in grid_copy:
        if not isinstance(grid_copy[k], Iterable):
            grid_copy[k] = [grid_copy[k]]
    for p in itertools.product(*grid_copy.values()):
        yield dict(zip(grid.keys(), p))


def grid_search(args):
    if args.checkpoint_dir == "":
        args.checkpoint_dir = get_shared_folder() / "%j"

    date_curr = date.today().strftime("%m-%d-%Y")
    orig_check_dir = os.path.join(args.checkpoint_dir, date_curr)

    # Create the executor
    # Note that the folder will depend on the job_id, to easily track experiments
    executor = submitit.AutoExecutor(
        folder=get_shared_folder() / "%j", slurm_max_num_timeout=30
    )
    num_gpus_per_node = args.ngpus
    nodes = args.nodes
    args.world_size = args.nodes * args.ngpus
    partition = args.partition

    executor.update_parameters(
        gpus_per_node=num_gpus_per_node,
        tasks_per_node=num_gpus_per_node,  # one task per GPU
        cpus_per_task=10,
        nodes=nodes,
        timeout_min=60 * 72,
        slurm_signal_delay_s=120,
        slurm_partition=partition,
    )
    executor.update_parameters(name="lra")

    if args.task == "text":
        grid_meta = {
            "training:learning_rate": (
                [1e-4, 2e-4, 3e-4, 5e-5],
                lambda val: f"lr{val}",
            ),
            "training:warmup": ([3000, 8000], lambda val: f"warmup{val}"),
            "training:seed": ([1234, 32, 1994], lambda val: f"seed{val}"),
            "training:weight_decay": ([0.02, 0.05, 0.01], lambda val: f"wd{val}"),
            "model:pooling_model": (["cls"], lambda val: f"pool-{val}"),
            "model:common:dropout": ([0, 0.05], lambda val: f"drop{val}"),
        }
    elif args.task == "retrieval":
        grid_meta = {
            "training:learning_rate": ([1e-4, 3e-4], lambda val: f"lr{val}"),
            "training:warmup": ([2000, 8000], lambda val: f"warmup{val}"),
            "training:seed": ([4096, 1234, 3, 15, 5], lambda val: f"seed{val}"),
            "training:weight_decay": ([0.01, 0], lambda val: f"wd{val}"),
            "model:pooling_model": (["cls"], lambda val: f"pool-{val}"),
            "model:common:dropout": ([0], lambda val: f"drop{val}"),
        }
    elif args.task == "listops":
        grid_meta = {
            "training:learning_rate": (
                [1e-4, 2e-4, 3e-4, 5e-5],
                lambda val: f"lr{val}",
            ),
            "training:warmup": ([3000, 2000], lambda val: f"warmup{val}"),
            "training:seed": (
                [
                    1234,
                ],
                lambda val: f"seed{val}",
            ),
            "training:weight_decay": ([0.02, 0.05, 0, 1], lambda val: f"wd{val}"),
            "model:pooling_model": (["cls"], lambda val: f"pool-{val}"),
            "model:common:dropout": ([0], lambda val: f"drop{val}"),
        }
    else:
        grid_meta = {
            "training:learning_rate": ([1e-4, 5e-5], lambda val: f"lr{val}"),
            "training:warmup": ([8000], lambda val: f"warmup{val}"),
            "training:seed": ([1234, 4321, 3], lambda val: f"seed{val}"),
            "training:weight_decay": ([0.01], lambda val: f"wd{val}"),
            "model:pooling_model": (["cls"], lambda val: f"pool-{val}"),
            "model:common:dropout": ([0.1], lambda val: f"drop{val}"),
        }

    grid = {k: v[0] for k, v in grid_meta.items()}
    save_key = {k: v[1] for k, v in grid_meta.items()}

    hyper_parameters = list(grid_parameters(grid))
    jobs = []

    for i, grid_data in enumerate(hyper_parameters):

        args.sweep_parameters = grid_data
        run_name = f"{args.attention}"
        # run_name = "paper_config"
        for k, v in grid_data.items():
            run_name += "prenorm-" + save_key[k](v)
        args.checkpoint_dir = os.path.join(
            orig_check_dir, f"{args.task}", "logs", run_name
        )
        Path(args.checkpoint_dir).mkdir(parents=True, exist_ok=True)
        args.tb_dir = os.path.join(orig_check_dir, f"{args.task}", "tb", run_name)
        Path(args.tb_dir).mkdir(parents=True, exist_ok=True)

        # Chronos needs a different job name each time
        executor.update_parameters(name=f"lra_{args.task}_{i:02d}_{uuid.uuid4().hex}")

        args.dist_url = get_init_file().as_uri()
        args.temp_file = str(get_init_file())

        trainer = Trainer(args)
        job = executor.submit(trainer)
        jobs.append(job)
        print(f"Run {i:02d} submitted with train cfg: {args}")
    print(f"Submitted jobs ids: {','.join([str(job.job_id) for job in jobs])}")


if __name__ == "__main__":
    args = parse_args()
    grid_search(args)
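
The grid_parameters helper above accepts a mix of lists and scalars and yields the full cartesian product; a tiny illustration (assuming xformers and submitit are installed, since importing the module pulls them in):

from xformers.benchmarks.LRA.run_grid_search import grid_parameters

grid = {"training:learning_rate": [1e-4, 3e-4], "training:warmup": 8000}
print(list(grid_parameters(grid)))
# [{'training:learning_rate': 0.0001, 'training:warmup': 8000},
#  {'training:learning_rate': 0.0003, 'training:warmup': 8000}]
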
298
pkgs/xformers/benchmarks/LRA/run_tasks.py
Normal file
@@ -0,0 +1,298 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.


import argparse
import json
import logging
import os
from enum import Enum
from pathlib import Path
from typing import Dict, Tuple, cast

import pytorch_lightning as pl
import torch
import torch.nn as nn
from fvcore.nn import FlopCountAnalysis, flop_count_str
from pytorch_lightning.callbacks import ModelCheckpoint, TQDMProgressBar
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.strategies import DDPStrategy
from torch.utils.data import DataLoader

from xformers.benchmarks.LRA.code.dataset import LRADataset
from xformers.benchmarks.LRA.code.model_wrapper import ModelForSC, ModelForSCDual
from xformers.components.attention import ATTENTION_REGISTRY


class Task(str, Enum):
    Retrieval = "retrieval"
    ListOps = "listops"
    Image = "image"
    PathfinderBaseline = "pathfinder32-curv_baseline"
    PathfinderContour9 = "pathfinder32-curv_contour_length_9"
    PathfinderContour14 = "pathfinder32-curv_contour_length_14"
    Text = "text"


def load_config(path: str) -> Dict:
    with open(Path(path).absolute(), "r") as fileio:
        config = json.load(fileio)

    # Duplicate the pathfinder configs
    config["pathfinder32-curv_baseline"] = config["pathfinder32"]
    config["pathfinder32-curv_contour_length_9"] = config["pathfinder32"]
    config["pathfinder32-curv_contour_length_14"] = config["pathfinder32"]
    return config


def build_model(args: argparse.Namespace, config: Dict) -> nn.Module:
    task = args.task
    attention_name = args.attention

    model = cast(
        pl.LightningModule,
        ModelForSCDual(config[f"{task}"], attention_name)
        if task == Task.Retrieval
        else ModelForSC(config[f"{task}"], attention_name),
    )

    logging.info(model)
    summary = pl.utilities.model_summary.LayerSummary(model)
    logging.info(f"num_parameter: {summary.num_parameters // 1e3 / 1e3}M")

    with torch.no_grad():
        # Check the flops
        seq_len = config[f"{task}"]["model"]["common"]["seq_len"]
        x = torch.rand(1, seq_len).long()
        mask = torch.rand(1, seq_len).long()
        indices = torch.rand(1, seq_len).long()
        flops = FlopCountAnalysis(model.model, (x, mask, indices))
        logging.info(f"complexity: {round(flops.total()/1e9, 3)} GFlops")
        logging.info(flop_count_str(flops))

    return model


def get_arg_parser():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--attention",
        type=str,
        help=f"Attention mechanism to choose, among {list(ATTENTION_REGISTRY.keys())}. \
            A list can be passed to test several mechanisms in sequence",
        dest="attention",
        required=True,
    )
    parser.add_argument(
        "--task",
        type=Task,
        help=f"Task to choose, among {[t.value for t in Task]}.",
        dest="task",
        required=True,
    )
    parser.add_argument(
        "--skip_train",
        type=bool,
        help="Whether to skip training, and test an existing model",
        dest="skip_train",
        default=False,
    )
    parser.add_argument(
        "--config",
        type=str,
        help="Path to the config being used",
        dest="config",
        default="./config.json",
    )
    parser.add_argument(
        "--checkpoint_dir",
        type=str,
        help="Path to the checkpoint directory",
        dest="checkpoint_dir",
        default=f"/checkpoints/{os.getenv('USER')}/xformers",
    )
    parser.add_argument(
        "--checkpoint_path",
        type=str,
        help="Path to checkpoint",
    )
    parser.add_argument(
        "--debug",
        help="Make it easier to debug a possible issue",
        dest="debug",
        default=False,
        action="store_true",
    )
    parser.add_argument(
        "--world_size",
        help="Number of GPUs used",
        dest="world_size",
        type=int,
        default=1,
    )
    parser.add_argument(
        "--sweep_parameters",
        help="Rewrite some hyperparameters in the config",
        dest="sweep_parameters",
        type=dict,
        default=None,
    )
    return parser


def setup_log(args, attention_name, task) -> Tuple[str, TensorBoardLogger]:
    experiment_name = f"{task}__{attention_name}"
    logger = TensorBoardLogger(
        save_dir=args.checkpoint_dir,
        name="",  # remove lightning_logs subdirectory
        version=experiment_name,
    )
    log_dir = os.path.join(logger._save_dir, experiment_name)
    return log_dir, logger


def rewrite_hyper(config, rewrites):
    def replace(config_dict, k, v):
        if len(k.split(":")) == 1:
            config_dict[k] = v
            return
        first_key = k.split(":")[0]
        assert first_key in config_dict, first_key
        k = k[len(first_key) + 1 :]
        replace(config_dict[first_key], k, v)

    for k, v in rewrites.items():
        replace(config, k, v)
    return config


def build_dataloaders(
    args: argparse.Namespace,
    config_training: Dict,
    num_workers: int = 4,
) -> Dict[str, DataLoader]:
    datasets = {}
    for component in ("train", "dev", "test"):
        datasets[component] = LRADataset(
            file_path=f"datasets/{args.task}.{component}.pickle",
            seq_len=config_training["seq_len"],
        )

    # Gradient accumulation
    accumu_steps = config_training["gradient_accumulation"]
    logging.info(f"accumu_steps={accumu_steps}")

    # Batch size
    per_gpu_batch_size = (
        config_training["batch_size"] // args.world_size // accumu_steps
    )
    logging.warning(
        f"Requested batch size: {config_training['batch_size']}. Given world\
            size and grad accumulation, per-gpu batch is\
            {per_gpu_batch_size}"
    )

    dataloaders = {
        k: DataLoader(
            v,
            batch_size=per_gpu_batch_size,
            shuffle=False,
            pin_memory=True,
            num_workers=num_workers,
        )
        for k, v in datasets.items()
    }
    return dataloaders


def get_eval_summary(trainer: pl.Trainer) -> Dict[str, float]:
    eval_summary: Dict[str, float] = {"train_step_idx": trainer.global_step}
    for k, v in trainer.callback_metrics.items():
        eval_summary[k] = v.item()
    return eval_summary


class BasicProgressBar(TQDMProgressBar):
    def get_metrics(self, trainer, model):
        items = super().get_metrics(trainer, model)
        items.pop("v_num", None)
        return items


def benchmark(args):
    log_dir, logger = setup_log(args, f"{args.attention}", f"{args.task}")
    args.logger = logger

    config = load_config(args.config)

    config_task = config[f"{args.task}"]
    if args.sweep_parameters is not None:
        logging.info("Replacing hyperparameters")
        rewrite_hyper(config_task, args.sweep_parameters)

    config_training = config_task["training"]
    config_training["seq_len"] = config_task["model"]["common"]["seq_len"]
    logging.info(f"Learning rate: {config_training['learning_rate']}")

    pl.seed_everything(config_training.get("seed", 0))
    dataloaders = build_dataloaders(args, config_training)

    model = build_model(args, config)

    progress_bar = BasicProgressBar()
    checkpoint_callback = ModelCheckpoint(
        monitor="val_accu",
        mode="max",
        dirpath=args.checkpoint_dir,
        filename="{epoch}-{val_accu:.2f}",
        every_n_train_steps=config_training["eval_frequency"],
    )

    trainer = pl.Trainer(
        accelerator="gpu",
        strategy=DDPStrategy(find_unused_parameters=args.debug)
        if not args.skip_train
        else None,
        accumulate_grad_batches=config_training["gradient_accumulation"],
        callbacks=[progress_bar, checkpoint_callback],
        detect_anomaly=args.debug,
        deterministic=True,
        gpus=args.world_size,
        limit_val_batches=config_training["num_eval_steps"],
        logger=logger,
        max_steps=config_training["num_train_steps"],
        num_sanity_val_steps=int(not args.skip_train),
        precision=16 if config_training["mixed_precision"] else 32,
        val_check_interval=config_training["eval_frequency"]
        / float(len(dataloaders["train"])),
    )

    if not args.skip_train:
        trainer.fit(
            model,
            train_dataloaders=dataloaders["train"],
            val_dataloaders=dataloaders["dev"],
        )
        ckpt_path = checkpoint_callback.best_model_path
    else:
        ckpt_path = args.checkpoint_path

    trainer.test(
        model,
        dataloaders=dataloaders["test"],
        ckpt_path=ckpt_path,
    )
    eval_summary = get_eval_summary(trainer)
    with open(os.path.join(log_dir, "test_eval_summary.json"), "w") as f:
        logging.info(f"Saving test results at {f.name}")
        json.dump(eval_summary, f)


if __name__ == "__main__":
    parser = get_arg_parser()
    args = parser.parse_args()
    if args.skip_train and args.checkpoint_path is None:
        parser.error("Must provide --checkpoint_path if --skip_train=True")
    benchmark(args)
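
The rewrite_hyper helper above addresses nested config entries through colon-separated key paths, which is how the grid search overrides hyperparameters; a short illustration (assuming the module's dependencies are installed):

from xformers.benchmarks.LRA.run_tasks import rewrite_hyper

config = {"training": {"learning_rate": 1e-4}, "model": {"common": {"dropout": 0.1}}}
rewrite_hyper(config, {"training:learning_rate": 3e-4, "model:common:dropout": 0.0})
print(config["training"]["learning_rate"])   # 0.0003
print(config["model"]["common"]["dropout"])  # 0.0
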
153
pkgs/xformers/benchmarks/LRA/run_with_submitit.py
Normal file
@@ -0,0 +1,153 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.


"""
A script to run multinode training with submitit.
Almost copy-paste from https://github.com/facebookresearch/deit/blob/main/run_with_submitit.py
"""

import argparse
import os
import uuid
from pathlib import Path

import submitit

from xformers.benchmarks.LRA.run_tasks import benchmark, get_arg_parser


def parse_args():
    parser = argparse.ArgumentParser(
        "Submitit for LRA", parents=[get_arg_parser()], add_help=False
    )
    parser.add_argument(
        "--ngpus", default=1, type=int, help="Number of gpus to request on each node"
    )
    parser.add_argument(
        "--nodes", default=1, type=int, help="Number of nodes to request"
    )
    parser.add_argument("--timeout", default=2800, type=int, help="Duration of the job")

    parser.add_argument(
        "--partition", default="a100", type=str, help="Partition where to submit"
    )
    parser.add_argument(
        "--use_volta32", action="store_true", help="Big models? Use this"
    )
    parser.add_argument(
        "--enforce_host_memory", action="store_true", help="Use if the host OOMs"
    )

    parser.add_argument(
        "--comment",
        default="",
        type=str,
        help="Comment to pass to scheduler, e.g. priority message",
    )
    return parser.parse_args()


def get_shared_folder() -> Path:
    user = os.getenv("USER")
    checkpoint_paths = ["/checkpoint", "/checkpoints"]
    for checkpoint_path in checkpoint_paths:
        if Path(checkpoint_path).is_dir():
            p = Path(f"{checkpoint_path}/{user}/xformers/submitit")
            p.mkdir(exist_ok=True, parents=True)
            return p
    raise RuntimeError(f"No shared folder available - considering {checkpoint_paths}")


def get_init_file():
    # Init file must not exist, but its parent dir must exist.
    os.makedirs(str(get_shared_folder()), exist_ok=True)
    init_file = get_shared_folder() / f"{uuid.uuid4().hex}_init"
    if init_file.exists():
        os.remove(str(init_file))
    return init_file


class Trainer:
    def __init__(self, args):
        self.args = args

    def __call__(self):
        self._setup_gpu_args()
        benchmark(self.args)

    def checkpoint(self):
        self.args.dist_url = get_init_file().as_uri()
        print("Requeuing ", self.args)
        empty_trainer = type(self)(self.args)
        return submitit.helpers.DelayedSubmission(empty_trainer)

    def _setup_gpu_args(self):
        job_env = submitit.JobEnvironment()
        self.args.checkpoint_dir = Path(
            str(self.args.checkpoint_dir).replace("%j", str(job_env.job_id))
        )
        self.args.gpu = job_env.local_rank
        self.args.rank = job_env.global_rank
        self.args.world_size = job_env.num_tasks
        print(f"Process group: {job_env.num_tasks} tasks, rank: {job_env.global_rank}")


def main():
    args = parse_args()
    if args.checkpoint_dir == "":
        args.checkpoint_dir = get_shared_folder() / "%j"
    Path(args.checkpoint_dir).mkdir(parents=True, exist_ok=True)
    executor = submitit.AutoExecutor(
        folder=args.checkpoint_dir, slurm_max_num_timeout=30
    )

    num_gpus_per_node = args.ngpus
    nodes = args.nodes
    timeout_min = args.timeout
    args.world_size = args.nodes * args.ngpus

    partition = args.partition

    kwargs = {
        "gpus_per_node": num_gpus_per_node,
        "tasks_per_node": num_gpus_per_node,  # one task per GPU
        "cpus_per_task": 10,
        "nodes": nodes,
        "timeout_min": timeout_min,  # max is 60 * 72
        # Below are cluster dependent parameters
        "slurm_partition": partition,
        "slurm_signal_delay_s": 120,
    }

    if args.enforce_host_memory:
        kwargs["mem_gb"] = 40 * num_gpus_per_node

    if args.use_volta32:
        kwargs["slurm_constraint"] = "volta32gb"

    if args.comment:
        kwargs["slurm_comment"] = args.comment

    executor.update_parameters(
        **kwargs,
    )

    executor.update_parameters(name="lra")

    args.dist_url = get_init_file().as_uri()
    args.temp_file = str(get_init_file())

    trainer = Trainer(args)
    job = executor.submit(trainer)

    print(f"Submitted job_id: {job.job_id}")
    print(f"Logs and checkpoints will be saved at: {args.checkpoint_dir}")
    with open(Path(f"{args.checkpoint_dir}") / Path("jobs.txt"), "a") as jobfile:
        jobfile.write(f"{job.job_id}\n")


if __name__ == "__main__":
    main()
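
The Trainer.checkpoint method above implements submitit's preemption protocol: when SLURM preempts or times out the job, submitit calls checkpoint() and requeues the returned DelayedSubmission. A minimal, generic sketch of the same pattern (the Requeueable class is hypothetical, not part of this benchmark):

import submitit


class Requeueable:
    def __init__(self, step: int = 0):
        self.step = step  # any state to carry across requeues

    def __call__(self):
        pass  # long-running work, periodically persisting self.step

    def checkpoint(self):
        # Called by submitit on preemption: hand back a callable to resubmit
        return submitit.helpers.DelayedSubmission(Requeueable(self.step))
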
4
pkgs/xformers/benchmarks/__init__.py
Normal file
@@ -0,0 +1,4 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
BIN
pkgs/xformers/benchmarks/__pycache__/__init__.cpython-310.pyc
Normal file
Binary file not shown.
BIN
pkgs/xformers/benchmarks/__pycache__/utils.cpython-310.pyc
Normal file
Binary file not shown.
159
pkgs/xformers/benchmarks/benchmark_attn_decoding.py
Normal file
@@ -0,0 +1,159 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.


import itertools
from typing import Any

import torch
from torch.utils import benchmark
from utils import benchmark_main_helper

import xformers.ops as xops

min_run_time = 0.5
device = torch.device("cuda")


def product_dict(**kwargs):
    keys = kwargs.keys()
    vals = kwargs.values()
    for instance in itertools.product(*vals):
        yield dict(zip(keys, instance))


CASES = [
    dict(B=max(1, 2 ** (16 - i)), Mq=1, Mkv=2**i, Hq=16, Hkv=1, K=128)
    for i in range(8, 18)
] + [
    dict(B=max(1, 2 ** (16 - i)), Mq=1, Mkv=2**i, Hq=16, Hkv=2, K=128)
    for i in range(8, 18)
]


def _setup_test(
    functions, fw: bool = False, bw: bool = False, cuda_graph: bool = True, **kwargs
):
    for k, benchmark_cls in functions.items():
        benchmark_object = benchmark_cls(**kwargs, bw=bw)
        label = benchmark_object.label
        label += "fw" if fw else ""
        label += "bw" if bw else ""

        def run_one():
            if fw:
                benchmark_object.fw()
            if bw:
                benchmark_object.bw()

        if cuda_graph:
            run_one()
            g = torch.cuda.CUDAGraph()
            with torch.cuda.graph(g):
                run_one()

            def run_one():
                g.replay()

        yield benchmark.Timer(
            stmt="fn()",
            globals={
                "fn": run_one,
            },
            label=label,
            description=k,
            sub_label=benchmark_object.sub_label,
        )


class AttentionDecodingFlashDecoding:
    OP: Any = xops.fmha.flash.FwOp

    def __init__(
        self, B: int, Mq: int, Mkv: int, Hq: int, Hkv: int, K: int, bw: bool
    ) -> None:
        dtype = torch.float16
        self.sub_label = f"B={B} Mq={Mq} Mkv={Mkv} Hq={Hq} Hkv={Hkv} K={K}"
        self.label = "attn_decoding"
        self.shapes = (B, Mq, Mkv, Hq, Hkv, K)

        assert Hkv <= Hq
        assert Hq % Hkv == 0
        self.q = torch.randn(
            [B, Mq, Hkv, Hq // Hkv, K], device="cuda", dtype=dtype, requires_grad=bw
        )
        self.k = torch.randn(
            [B, Mkv, Hkv, 1, K], device="cuda", dtype=dtype, requires_grad=bw
        ).expand(-1, -1, -1, Hq // Hkv, -1)
        self.v = torch.randn(
            [B, Mkv, Hkv, 1, K], device="cuda", dtype=dtype, requires_grad=bw
        ).expand(-1, -1, -1, Hq // Hkv, -1)

        if Hq == Hkv:
            self.q = self.q[:, :, :, 0]
            self.k = self.k[:, :, :, 0]
            self.v = self.v[:, :, :, 0]
        if Hkv == 1:
            self.q = self.q[:, :, 0]
            self.k = self.k[:, :, 0]
            self.v = self.v[:, :, 0]

    def fw(self) -> None:
        xops.memory_efficient_attention_forward(self.q, self.k, self.v, op=self.OP)


class AttentionDecodingSplitKV(AttentionDecodingFlashDecoding):
    OP = xops.fmha.triton_splitk.FwOp


class AttentionDecodingPyTorchRepeat(AttentionDecodingFlashDecoding):
    def fw(self) -> None:
        B, Mq, Mkv, Hq, Hkv, K = self.shapes
        scale = 1 / K**0.5
        q = self.q.reshape([B, Mq, -1, K]).permute(0, 2, 1, 3)
        k = self.k.reshape([B, Mkv, -1, K]).permute(0, 2, 1, 3)
        v = self.v.reshape([B, Mkv, -1, K]).permute(0, 2, 1, 3)
        # Scale the logits before the softmax, as in standard attention
        attn = (q @ k.transpose(-1, -2) * scale).softmax(-1)
        return attn @ v


BENCHMARKS = {
    "pytorch": AttentionDecodingPyTorchRepeat,
    "flash-decoding": AttentionDecodingFlashDecoding,
    "triton_splitK": AttentionDecodingSplitKV,
}


try:
    import flash_attn

    class AttentionDecodingFlashAttention(AttentionDecodingFlashDecoding):
        def fw(self) -> None:
            q, k, v = self.q, self.k, self.v
            if q.ndim == 5:
                B, Mq, H1, H2, K = q.shape
                B, Mkv, H1, H2, K = k.shape
                q = q.reshape([B, Mq, H1 * H2, K])
                k = k[:, :, :, 0]
                v = v[:, :, :, 0]
            return flash_attn.flash_attn_func(q, k, v)

    BENCHMARKS[
        f"flash-attention@{flash_attn.__version__}"
    ] = AttentionDecodingFlashAttention
except ImportError:
    pass


def attn_decoding(**kwargs):
    yield from _setup_test(
        **kwargs,
        fw=True,
        cuda_graph=True,
        functions=BENCHMARKS,
    )


benchmark_main_helper(attn_decoding, CASES, min_run_time=min_run_time)
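
The 5D tensors above encode grouped-query attention: Hkv distinct KV heads, each shared by Hq // Hkv query heads, and the expand() makes the sharing a broadcast view rather than a copy. A quick check of that no-copy claim (uses the PyTorch 2.x storage API):

import torch

B, Mkv, Hkv, Hq, K = 2, 16, 2, 16, 128
k = torch.randn([B, Mkv, Hkv, 1, K]).expand(-1, -1, -1, Hq // Hkv, -1)
print(k.shape)  # torch.Size([2, 16, 2, 8, 128])
# Still backed by the un-expanded buffer (float32 = 4 bytes per element):
print(k.untyped_storage().nbytes() == B * Mkv * Hkv * K * 4)  # True
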
1065
pkgs/xformers/benchmarks/benchmark_blocksparse_transformers.py
Normal file
File diff suppressed because it is too large
137
pkgs/xformers/benchmarks/benchmark_causal_blocksparse.py
Normal file
@@ -0,0 +1,137 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.


import os
from typing import Any, Dict

import torch
import triton

from xformers.benchmarks.utils import TestCase, pretty_plot, pretty_print
from xformers.components.attention.attention_mask import AttentionMask
from xformers.components.attention.core import scaled_dot_product_attention

os.environ["CUDA_LAUNCH_BLOCKING"] = "1"

SHAPES = [
    (8, 128, 2096),
    (8, 1024, 256),
    (12, 512, 1024),
    (128, 128, 512),
    (8, 2048, 4096),
    (16, 1024, 5120),
    (512, 128, 2560),
]

BLOCK_SIZES = [128]
N_HEADS = [8, 32]


def bench_blocksparse_compare(backward: bool):
    device = torch.device("cuda")
    bw = "+bw" if backward else ""
    use_amp = True
    _use_cuda = True

    for dtype in [torch.float16, torch.float32]:
        datatype = "fp16" if dtype == torch.float16 else "fp32"
        results: Dict[str, Any] = {}
        results_mem: Dict[str, Any] = {}
        for BS in BLOCK_SIZES:
            for heads in N_HEADS:
                for B, M, K in SHAPES:
                    q = torch.randn(
                        (B, heads, M, K // heads),
                        requires_grad=backward,
                        device=device,
                        dtype=dtype,
                    )
                    k = q
                    v = q

                    # Mask with causal flag
                    m_att_mask = AttentionMask.make_causal(
                        M, M, device=device, dtype=dtype
                    )
                    # Custom causal tensor mask
                    m_custom = torch.triu(
                        torch.ones(M, M, device=device, dtype=dtype) * float("-inf"),
                        diagonal=1,
                    )

                    def blocksparse_attention():
                        with torch.cuda.amp.autocast(enabled=use_amp):
                            y = scaled_dot_product_attention(
                                q=q, k=k, v=v, att_mask=m_att_mask, block_size=BS
                            )
                            if backward:
                                torch.norm(y).backward()
                            return y

                    def sdp_attention():
                        with torch.cuda.amp.autocast(enabled=use_amp):
                            y = scaled_dot_product_attention(
                                q=q, k=k, v=v, att_mask=m_custom, block_size=BS
                            )
                            if backward:
                                torch.norm(y).backward()
                            return y

                    for testcase in [
                        TestCase(blocksparse_attention, f"blocksparse - fw{bw}"),
                        TestCase(sdp_attention, f"standard sdp - fw{bw}"),
                    ]:
                        if _use_cuda:
                            torch.cuda.empty_cache()
                            torch.cuda.reset_peak_memory_stats()
                            torch.cuda.synchronize()
                        time = triton.testing.do_bench(testcase.function)[0]

                        if _use_cuda:
                            torch.cuda.synchronize()
                            max_memory = torch.cuda.max_memory_allocated() / 2**20
                        else:
                            max_memory = -1

                        key = f"B={B},M={M},K={K},NH={heads}"

                        if key not in results_mem:
                            results_mem[key] = {}
                        results_mem[key][testcase.name] = f"{max_memory:.1f}"

                        if key not in results:
                            results[key] = {}
                        results[key][testcase.name] = f"{time:.2f}"

            pretty_print(
                results,
                title=f"\n --- Type: {datatype} Block Size: {BS} --- ",
                units="runtime in ms",
            )
            pretty_print(
                results_mem,
                title=f"\n --- Type: {datatype} Block Size: {BS} --- ",
                units="peak memory usage in MB",
            )

            pretty_plot(
                results,
                title=f"Causal Blocksparse Runtime FW{bw.upper()} {datatype} Blocksize:{BS}",
                units="runtime in ms",
                dash_key="torch",
                legend_loc="upper left",
            )
            pretty_plot(
                results_mem,
                title=f"Causal Blocksparse Memory FW{bw.upper()} {datatype} Blocksize:{BS}",
                units="peak memory usage in MB",
                dash_key="torch",
                legend_loc="upper left",
            )


for bw in [False, True]:
    bench_blocksparse_compare(bw)
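
Both mask variants above express the same causal structure: an additive bias that is 0 on and below the diagonal and -inf above it, so softmax assigns no weight to future positions. A toy check of the tensor variant with plain softmax attention:

import torch

M = 4
scores = torch.randn(M, M)
m_custom = torch.triu(torch.ones(M, M) * float("-inf"), diagonal=1)
attn = (scores + m_custom).softmax(-1)
print(torch.allclose(attn, attn.tril()))  # True: upper triangle is all zeros
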
258
pkgs/xformers/benchmarks/benchmark_core.py
Normal file
@@ -0,0 +1,258 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the BSD license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import itertools
|
||||
|
||||
import torch
|
||||
from torch.utils import benchmark
|
||||
|
||||
from xformers.components.attention.core import (
|
||||
SparseCS,
|
||||
_create_random_sparsity,
|
||||
_matmul_with_mask,
|
||||
_softmax,
|
||||
bmm,
|
||||
)
|
||||
|
||||
MIN_RUN_TIME = 1
|
||||
SHAPES = [[8, 8], [256, 1024], [128, 256]]
|
||||
SPARSITIES = [0.5, 0.8, 0.9, 0.95, 0.99]
|
||||
|
||||
|
||||
def bench_sddmm():
|
||||
min_run_time = MIN_RUN_TIME
|
||||
SPARSITIES = [0.95, 0.98, 0.99, 0.995, 0.999]
|
||||
|
||||
device = torch.device("cuda")
|
||||
results = []
|
||||
|
||||
for B, M, K in zip(*SHAPES):
|
||||
a = torch.rand(B, M, K, device=device)
|
||||
b = torch.rand(B, M, K, device=device)
|
||||
|
||||
for backend, prob in itertools.product(
|
||||
["coo_pytorch", "csr_sputnik", "csr_ge"], SPARSITIES
|
||||
):
|
||||
mask = _create_random_sparsity(torch.ones(B, M, M, dtype=torch.bool), prob)
|
||||
aa = a
|
||||
bb = b
|
||||
if "csr" in backend:
|
||||
mask = SparseCS(mask, device)
|
||||
aa = a
|
||||
bb = b
|
||||
row_indices = mask.row_indices
|
||||
row_offsets = mask.row_offsets
|
||||
column_indices = mask.column_indices
|
||||
if "_ge" in backend:
|
||||
fn = torch.ops.xformers.csr_sddmm
|
||||
else:
|
||||
fn = torch.ops.xformers.sddmm_sputnik
|
||||
fn_str = "fn(a, b, row_indices, row_offsets, column_indices)"
|
||||
else:
|
||||
mask = mask.to_sparse().to(device)
|
||||
_, row_offsets, column_indices = mask.indices().int().unbind()
|
||||
row_offsets = row_offsets.contiguous()
|
||||
column_indices = column_indices.contiguous()
|
||||
row_indices = row_offsets
|
||||
|
||||
bb = b.transpose(-2, -1)
|
||||
fn = _matmul_with_mask
|
||||
fn_str = "fn(a, b, mask)"
|
||||
|
||||
results.append(
|
||||
benchmark.Timer(
|
||||
stmt=fn_str,
|
||||
globals={
|
||||
"a": aa,
|
||||
"b": bb,
|
||||
"mask": mask,
|
||||
"row_indices": row_indices,
|
||||
"row_offsets": row_offsets,
|
||||
"column_indices": column_indices,
|
||||
"fn": fn,
|
||||
},
|
||||
label="sddmm",
|
||||
sub_label=f"sparsity {backend}: {prob:0.4f}",
|
||||
description=f"B={B}, M={M}, K={K}",
|
||||
).blocked_autorange(min_run_time=min_run_time)
|
||||
)
|
||||
|
||||
compare = benchmark.Compare(results)
|
||||
compare.print()
|
||||
|
||||
|
||||
def bench_matmul_with_mask():
|
||||
min_run_time = MIN_RUN_TIME
|
||||
prob = 0.9
|
||||
device = torch.device("cuda")
|
||||
results = []
|
||||
|
||||
for B, M, K in zip(*SHAPES):
|
||||
a = torch.rand(B, M, K, device=device)
|
||||
b = torch.rand(B, K, M, device=device)
|
||||
mask = torch.rand(B, M, M, device=device) > prob
|
||||
|
||||
results.extend(
|
||||
[
|
||||
benchmark.Timer(
|
||||
stmt="_matmul_with_mask(a, b, mask)",
|
||||
globals={
|
||||
"a": a,
|
||||
"b": b,
|
||||
"mask": None,
|
||||
"_matmul_with_mask": _matmul_with_mask,
|
||||
},
|
||||
label="matmul_with_mask",
|
||||
sub_label="dense",
|
||||
description=f"B={B}, M={M}, K={K}",
|
||||
).blocked_autorange(min_run_time=min_run_time),
|
||||
benchmark.Timer(
|
||||
stmt="_matmul_with_mask(a, b, mask)",
|
||||
globals={
|
||||
"a": a,
|
||||
"b": b,
|
||||
"mask": mask,
|
||||
"_matmul_with_mask": _matmul_with_mask,
|
||||
},
|
||||
label="matmul_with_mask",
|
||||
sub_label="dense with masking",
|
||||
description=f"B={B}, M={M}, K={K}",
|
||||
).blocked_autorange(min_run_time=min_run_time),
|
||||
]
|
||||
)
|
||||
for sputnik, prob in itertools.product([False, True], SPARSITIES):
|
||||
mask = _create_random_sparsity(
|
||||
torch.ones(B, M, M, dtype=torch.bool, device=device), prob
|
||||
)
|
||||
aa = a
|
||||
bb = b
|
||||
if sputnik:
|
||||
mask = SparseCS(mask, device)
|
||||
aa = a
|
||||
bb = b.transpose(-2, -1).contiguous().transpose(-2, -1)
|
||||
else:
|
||||
mask = mask.to_sparse()
|
||||
results.append(
|
||||
benchmark.Timer(
|
||||
stmt="_matmul_with_mask(a, b, mask)",
|
||||
globals={
|
||||
"a": aa,
|
||||
"b": bb,
|
||||
"mask": mask,
|
||||
"_matmul_with_mask": _matmul_with_mask,
|
||||
},
|
||||
label="matmul_with_mask",
|
||||
sub_label=f"sparsity {'sputnik' if sputnik else 'pytorch'}: {prob:0.2f}",
|
||||
description=f"B={B}, M={M}, K={K}",
|
||||
).blocked_autorange(min_run_time=min_run_time)
|
||||
)
|
||||
|
||||
compare = benchmark.Compare(results)
|
||||
compare.print()
|
||||
|
||||
|
||||
def bench_softmax():
    min_run_time = MIN_RUN_TIME
    prob = 0.9
    device = torch.device("cuda")
    results = []

    for B, M, K in zip(*SHAPES):
        a = torch.rand(B, M, M, device=device)
        a[a < prob] = 0

        results.extend(
            [
                benchmark.Timer(
                    stmt="_softmax(a)",
                    globals={
                        "a": a,
                        "_softmax": _softmax,
                    },
                    label="softmax",
                    sub_label="dense",
                    description=f"B={B}, M={M}, K={K}",
                ).blocked_autorange(min_run_time=min_run_time),
            ]
        )
        for sputnik, prob in itertools.product([False, True], SPARSITIES):
            a = _create_random_sparsity(torch.rand(B, M, M, device=device), prob)
            if sputnik:
                a = SparseCS(a, device)
            else:
                a = a.to_sparse()
            results.append(
                benchmark.Timer(
                    stmt="_softmax(a)",
                    globals={
                        "a": a,
                        "_softmax": _softmax,
                    },
                    label="softmax",
                    sub_label=f"sparsity {'sputnik' if sputnik else 'pytorch'}: {prob:0.2f}",
                    description=f"B={B}, M={M}, K={K}",
                ).blocked_autorange(min_run_time=min_run_time)
            )

    compare = benchmark.Compare(results)
    compare.print()


def bench_bmm():
    min_run_time = MIN_RUN_TIME
    prob = 0.9
    device = torch.device("cuda")
    results = []

    for B, M, K in zip(*SHAPES):
        a = torch.rand(B, M, M, device=device)
        a[a < prob] = 0
        b = torch.rand(B, M, K, device=device)

        results.extend(
            [
                benchmark.Timer(
                    stmt="bmm(a, b)",
                    globals={
                        "a": a,
                        "b": b,
                        "bmm": bmm,
                    },
                    label="bmm",
                    sub_label="dense",
                    description=f"B={B}, M={M}, K={K}",
                ).blocked_autorange(min_run_time=min_run_time),
            ]
        )
        for sputnik, prob in itertools.product([False, True], SPARSITIES):
            a = _create_random_sparsity(torch.rand(B, M, M, device=device), prob)
            bb = b
            if sputnik:
                a = SparseCS(a, device)
                bb = b
            else:
                a = a.to_sparse()
            results.append(
                benchmark.Timer(
                    stmt="bmm(a, b)",
                    globals={
                        "a": a,
                        "b": bb,
                        "bmm": bmm,
                    },
                    label="bmm",
                    sub_label=f"sparsity {'sputnik' if sputnik else 'pytorch'}: {prob:0.2f}",
                    description=f"B={B}, M={M}, K={K}",
                ).blocked_autorange(min_run_time=min_run_time)
            )

    compare = benchmark.Compare(results)
    compare.print()


bench_sddmm()
bench_matmul_with_mask()
bench_softmax()
bench_bmm()
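
All of the benchmarks in this file follow the same torch.utils.benchmark pattern: one Timer per configuration, timed with blocked_autorange, then aggregated with Compare. A minimal, self-contained sketch of that pattern (labels and shapes here are illustrative only):

import torch
from torch.utils import benchmark

results = []
for n in (256, 1024):
    x = torch.rand(n, n, device="cuda")
    results.append(
        benchmark.Timer(
            stmt="x @ x",                 # statement to time
            globals={"x": x},             # names visible to stmt
            label="matmul",               # groups rows in the report
            sub_label=f"n={n}",           # one row per configuration
            description="dense",          # one column per variant
        ).blocked_autorange(min_run_time=0.2)
    )
benchmark.Compare(results).print()        # pretty-printed comparison table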
241
pkgs/xformers/benchmarks/benchmark_indexing.py
Normal file
@@ -0,0 +1,241 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.


import itertools
import random

import torch
from torch.utils import benchmark
from utils import benchmark_main_helper

import xformers.ops as xops

min_run_time = 0.5
device = torch.device("cuda")


def product_dict(**kwargs):
    keys = kwargs.keys()
    vals = kwargs.values()
    for instance in itertools.product(*vals):
        yield dict(zip(keys, instance))
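
# product_dict enumerates the cartesian product of its keyword arguments as
# kwargs dicts, one per combination. Illustrative example:
#   list(product_dict(scaling=[False, True], dtype=[torch.half]))
#   -> [{"scaling": False, "dtype": torch.half},
#       {"scaling": True, "dtype": torch.half}]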

CASES_IADD = list(
    product_dict(
        shape=[
            (int(48 * 0.6), 48, 1, 257 * 1536),
            (int(48 * 0.6), 48, 257, 1536),
        ],
        scaling=[False, True],
        dtype=[torch.half],
    )
) + list(
    product_dict(
        shape=[
            # Format: [B_src, B_inp, M, D]
            (int(192 * 0.6), 192, 50, 1536),
            (int(48 * 257 * 0.6), 257 * 48, 1, 1536),
            (int(192 * 50 * 0.6), 192 * 50, 1, 1536),
            (int(16 * 257 * 0.6), 48 * 257, 1, 1536),
        ],
        scaling=[False],
        dtype=[torch.half],
    )
)

CASES_ISELECT = list(
    product_dict(
        batches=[((48, 257), (50, 192))],
        D=[1536],
        keep_ratio=[0.6],
        dtype=[torch.half],
    )
)

DTYPE2STR = {
    torch.bfloat16: "b16",
    torch.half: "f16",
    torch.float32: "f32",
}


def _setup_test(functions, fw: bool = False, bw: bool = False, **kwargs):
    for k, benchmark_cls in functions.items():
        benchmark_object = benchmark_cls(**kwargs, bw=bw)
        label = benchmark_object.label
        label += "fw" if fw else ""
        label += "bw" if bw else ""

        def run_one():
            if fw:
                benchmark_object.fw()
            if bw:
                benchmark_object.bw()

        yield benchmark.Timer(
            stmt="fn()",
            globals={
                "fn": run_one,
            },
            label=label,
            description=k,
            sub_label=benchmark_object.sub_label,
        )


class ScaledIndexAddBenchmark:
    def __init__(self, dtype, scaling: bool, shape, bw: bool) -> None:
        B_src, B_out, M, D = shape
        torch.manual_seed(B_out + B_src)
        dtype_str = DTYPE2STR.get(dtype, dtype)
        self.sub_label = f"{dtype_str} B_src={B_src}, B_out={B_out}, M={M}, D={D} s={'Y' if scaling else 'N'}"
        self.label = "scaled_index_add"
        self.alpha = 0.73

        self.inp = torch.randn(
            [B_out, M, D], device="cuda", dtype=dtype, requires_grad=bw
        )
        self.src = torch.randn(
            [B_src, M, D], device="cuda", dtype=dtype, requires_grad=bw
        )
        self.scaling = (
            torch.randn([D], device="cuda", dtype=dtype, requires_grad=bw)
            if scaling
            else None
        )
        self.index = torch.tensor(
            [i for i in range(self.src.shape[0])], dtype=torch.int64, device="cuda"
        )
        self.grad = torch.randn([B_out, M, D], device="cuda", dtype=dtype)
        self.out = torch.Tensor()

    def fw(self) -> None:
        self.out = xops.scaled_index_add(
            input=self.inp.clone(),
            index=self.index,
            source=self.src,
            scaling=self.scaling,
            alpha=self.alpha,
        )

    def bw(self):
        self.inp.grad = None
        self.src.grad = None
        if self.scaling is not None:
            self.scaling.grad = None
        self.out.backward(self.grad, retain_graph=True)


class ScaledIndexAddBenchmarkBaseline(ScaledIndexAddBenchmark):
    def fw(self) -> None:
        src_scaled = self.src
        if self.scaling is not None:
            # scale the source features before the index_add
            src_scaled = src_scaled * self.scaling.unsqueeze(0).unsqueeze(0)
        self.out = self.inp.index_add(
            dim=0,
            source=src_scaled,
            index=self.index,
            alpha=self.alpha,
        )


def scaled_index_add_fw(**kwargs):
    yield from _setup_test(
        **kwargs,
        fw=True,
        functions={
            "xformers": ScaledIndexAddBenchmark,
            "pytorch": ScaledIndexAddBenchmarkBaseline,
        },
    )


def scaled_index_add_fwbw(**kwargs):
    yield from _setup_test(
        **kwargs,
        fw=True,
        bw=True,
        functions={
            "xformers": ScaledIndexAddBenchmark,
            "pytorch": ScaledIndexAddBenchmarkBaseline,
        },
    )


class IndexSelectBenchmark:
    def __init__(self, dtype, batches, D, keep_ratio, bw: bool) -> None:
        dtype_str = DTYPE2STR.get(dtype, dtype)
        self.sub_label = f"{dtype_str} D={D} batches={batches} keep={keep_ratio}"
        self.label = "index_select"
        srcs = [torch.randn([B, seqlen * D]) for (B, seqlen) in batches]
        src = torch.cat([s.view([-1, D]) for s in srcs], dim=0).cuda().to(dtype)
        src.requires_grad_(True)

        indices = []
        sources = []
        elements_i = 0
        for source_i in srcs:
            index = [i for i in range(source_i.shape[0])]
            random.Random(source_i.shape[0]).shuffle(index)
            indices.append(
                torch.tensor(
                    index[: int(keep_ratio * source_i.shape[0])],
                    dtype=torch.int64,
                    device="cuda",
                )
            )
            sources.append(
                src[
                    elements_i : elements_i + source_i.shape[0] * source_i.shape[1] // D
                ].reshape(source_i.shape)
            )
            elements_i += source_i.shape[0] * source_i.shape[1] // D
        self.indices, self.sources, self.src = indices, sources, src
        self.out = torch.Tensor()

    def fw(self) -> None:
        self.out = xops.index_select_cat(self.sources, self.indices)

    def bw(self):
        self.src.grad = None
        self.out.backward(self.out, retain_graph=True)


class IndexSelectBenchmarkBaseline(IndexSelectBenchmark):
    def fw(self) -> None:
        self.out = torch.cat(
            [s[i].flatten() for s, i in zip(self.sources, self.indices)], dim=0
        )


def index_select_fw(**kwargs):
    yield from _setup_test(
        **kwargs,
        fw=True,
        functions={
            "xformers": IndexSelectBenchmark,
            "pytorch": IndexSelectBenchmarkBaseline,
        },
    )


def index_select_fwbw(**kwargs):
    yield from _setup_test(
        **kwargs,
        fw=True,
        bw=True,
        functions={
            "xformers": IndexSelectBenchmark,
            "pytorch": IndexSelectBenchmarkBaseline,
        },
    )


benchmark_main_helper(scaled_index_add_fw, CASES_IADD, min_run_time=min_run_time)
benchmark_main_helper(scaled_index_add_fwbw, CASES_IADD, min_run_time=min_run_time)
benchmark_main_helper(index_select_fw, CASES_ISELECT, min_run_time=min_run_time)
benchmark_main_helper(index_select_fwbw, CASES_ISELECT, min_run_time=min_run_time)
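
benchmark_main_helper is imported from the local utils module, so its exact behavior is defined there; judging only from the call sites, it plausibly expands each case dict into keyword arguments and aggregates the yielded Timers, along these hypothetical lines:

# Rough, hypothetical reduction of benchmark_main_helper (the real helper in
# utils.py also handles CLI flags, baselines and result handling):
# results = []
# for case in CASES_IADD:
#     for timer in scaled_index_add_fw(**case):
#         results.append(timer.blocked_autorange(min_run_time=min_run_time))
# benchmark.Compare(results).print()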
316
pkgs/xformers/benchmarks/benchmark_mem_eff_attention.py
Normal file
@@ -0,0 +1,316 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.


import itertools
import random
from functools import partial

import torch
from torch.utils import benchmark
from utils import benchmark_main_helper

import xformers.ops
import xformers.ops.fmha as fmha

torch.backends.cuda.matmul.allow_tf32 = False


def create_attn_bias(
    bias_type,
    batch_size: int,
    num_heads: int,
    q_len: int,
    kv_len: int,
    device,
    dtype,
    bias_requires_grad: bool = False,
):
    NoneType = type(None)
    if bias_type is NoneType:
        return None
    if bias_type is torch.Tensor:
        attn_bias = torch.randn((1, 1, q_len, kv_len), device=device, dtype=dtype)
        return attn_bias.expand(batch_size, num_heads, q_len, kv_len)
    if bias_type is xformers.ops.LowerTriangularMask:
        return bias_type()
    assert False, f"Unsupported bias type: {bias_type}"


def ref_attention_bmk(q, k, v, attn_bias=None, p=0.0):
    if isinstance(attn_bias, xformers.ops.AttentionMask):
        attn_bias = (
            attn_bias.materialize((q.shape[0], 1, q.shape[1], k.shape[1]))
            .to(q)
            .squeeze()
        )
    q = q * (1.0 / q.shape[-1] ** 0.5)
    if attn_bias is None:
        attn = q @ k.transpose(-2, -1)
    else:
        # equivalent to (q @ k.transpose(-2, -1) + m).softmax(-1) @ v
        # but faster, and is what is used in PyTorch now
        attn = torch.baddbmm(attn_bias, q, k.transpose(-2, -1))
    attn = attn.softmax(-1)
    if p > 0:
        attn = torch.nn.functional.dropout(attn, p=p)
    return attn @ v


def ref_attention(q, k, v, attn_bias, p=0.0):
    assert q.ndim == 4
    B, M, H, K = q.shape

    def T(t):
        return t.permute((0, 2, 1, 3)).reshape(
            [t.shape[0] * t.shape[2], t.shape[1], t.shape[3]]
        )

    if isinstance(attn_bias, torch.Tensor):
        attn_bias = attn_bias.reshape(B * H, M, M)
    out = ref_attention_bmk(T(q), T(k), T(v), attn_bias, p)
    out = out.reshape([q.shape[0], q.shape[2], q.shape[1], v.shape[3]])
    return out.permute((0, 2, 1, 3))
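
ref_attention and the fused ops both consume tensors in BMHK layout (batch, sequence, heads, head dim). For reference, a minimal invocation of the op under test (shapes illustrative):

q = torch.randn(2, 197, 12, 64, device="cuda", dtype=torch.half)  # [B, M, H, K]
k = torch.randn_like(q)
v = torch.randn_like(q)
out = xformers.ops.memory_efficient_attention(q, k, v)  # same BMHK shape as q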


min_run_time = 0.5
device = torch.device("cuda")

NUM_THREADS = [1] if device.type == "cuda" else [1, 40]
SHAPES = [
    # ViT
    (384, 197, 1, 88),
    (384, 197, 1, 80),
    (384, 197, 1, 64),
    (1024, 197, 1, 88),
    (1024, 197, 1, 80),
    (1024, 197, 1, 64),
    # ViT-Huge
    (32 * 16, 197, 1, 80),
    (32, 197, 16, 80),
    (32, 197, 16, 64),
    (32, 197, 16, 128),
    # ViT-Giant
    (16 * 16, 197, 1, 88),
    (16, 197, 16, 88),
    (16, 197, 16, 64),
    (16, 197, 16, 128),
    # FB models
    (1024, 82, 8, 64),
    (150, 256, 16, 64),
    (64, 256, 12, 64),
    # Stable diffusion (https://github.com/huggingface/diffusers/pull/532)
    (1, 4096, 16, 40),  # 512x512
    (1, 16384, 16, 40),  # 1024x1024
    (1, 4096, 16, 80),
    (1, 16384, 16, 80),
    # + bs4
    (4, 4096, 16, 40),
    (4, 16384, 16, 40),
    (4, 4096, 16, 80),
    (4, 16384, 16, 80),
    # ParlAI model
    (256, 4096, 16, 64),
    # Zetta B M H K
    (8, 2048, 20, 128),
    # LLaMa 70b - mp=8/16
    *sorted(itertools.product([1, 2], [2048, 4096, 8192], [4, 8], [128])),
    *sorted(
        itertools.product([16], [128, 512, 1024], [16], [16, 32, 64, 128, 160, 256])
    ),
]

OPS = [
    (xformers.ops.fmha.cutlass.FwOp, xformers.ops.fmha.cutlass.BwOp),
    (xformers.ops.fmha.flash.FwOp, xformers.ops.fmha.flash.BwOp),
    # TODO: Triton is not stable: it can trigger Illegal Memory Accesses
    # and its performance varies a lot between runs.
    # (xformers.ops.fmha.triton.FwOp, xformers.ops.fmha.triton.BwOp),
]


def product_dict(**kwargs):
    keys = kwargs.keys()
    vals = kwargs.values()
    for instance in itertools.product(*vals):
        yield dict(zip(keys, instance))


CASES = list(
    product_dict(
        shape=SHAPES,
        num_threads=NUM_THREADS,
        dropout_p=[0.0],
        attn_bias_cfg=[(type(None), False)],
        dtype=[torch.half],
    )
)

# Add more cases with some variations
for c in CASES.copy():
    c = c.copy()
    c.update(
        random.Random(str(c["shape"])).choice(
            [
                {"dropout_p": 0.3},
                {"attn_bias_cfg": (torch.Tensor, False)},
                {"attn_bias_cfg": (torch.Tensor, True)},
                {"attn_bias_cfg": (xformers.ops.LowerTriangularMask, False)},
                {"dtype": torch.bfloat16},
                {"dtype": torch.float},
            ]
        )
    )
    CASES.append(c)


def create_tensors(shape, dtype, requires_grad=False):
    B, M, H, K = shape
    qkv = torch.rand(
        [B, M, 3, H, K], device=device, dtype=dtype, requires_grad=requires_grad
    )
    q, k, v = xformers.ops.unbind(qkv, 2)
    return qkv, q, k, v


def mem_eff_attention_fw(shape, num_threads: int, attn_bias_cfg, dropout_p, dtype):
    B, M, H, K = shape
    _, q, k, v = create_tensors(shape, dtype)
    attn_bias_type, attn_bias_requires_grad = attn_bias_cfg
    if attn_bias_requires_grad:
        return
    bias = create_attn_bias(
        attn_bias_type,
        batch_size=B,
        num_heads=H,
        q_len=M,
        kv_len=M,
        device=device,
        dtype=dtype,
        bias_requires_grad=attn_bias_requires_grad,
    )
    inp = fmha.Inputs(query=q, key=k, value=v, attn_bias=bias, p=dropout_p)

    dtype_str = {
        torch.bfloat16: "b16",
        torch.half: "f16",
        torch.float: "f32",
    }[dtype]
    sub_label = (
        f"{dtype_str} {B}-{M}-{H}-{K}, p={dropout_p}, "
        f"BiasT={attn_bias_type.__name__}"
    )

    has_run = False
    for fw_op, bw_op in OPS:
        if not fw_op.supports(inp):
            continue

        yield benchmark.Timer(
            stmt="fn(q, k, v, attn_bias, p)",
            globals={
                "q": q,
                "k": k,
                "v": v,
                "attn_bias": inp.attn_bias,
                "p": dropout_p,
                "fn": partial(
                    xformers.ops.memory_efficient_attention, op=(fw_op, bw_op)
                ),
            },
            label=f"attention (attn_bias={attn_bias_type})",
            description=fw_op.NAME,
            sub_label=sub_label,
            num_threads=num_threads,
        )
        has_run = True

    if not has_run:
        return

    yield benchmark.Timer(
        stmt="fn(q, k, v, attn_bias, p)",
        globals={
            "q": q,
            "k": k,
            "v": v,
            "attn_bias": inp.attn_bias,
            "p": dropout_p,
            "fn": ref_attention,
        },
        label=f"attention (attn_bias={attn_bias_type})",
        description="eager",
        sub_label=sub_label,
        num_threads=num_threads,
    )


def mem_eff_attention_bw(shape, num_threads: int, attn_bias_cfg, dropout_p, dtype):
    B, M, H, K = shape
    qkv, q, k, v = create_tensors(shape, dtype, requires_grad=True)

    attn_bias_type, attn_bias_requires_grad = attn_bias_cfg
    bias = create_attn_bias(
        attn_bias_type,
        batch_size=B,
        num_heads=H,
        q_len=M,
        kv_len=M,
        device=device,
        dtype=dtype,
        bias_requires_grad=attn_bias_requires_grad,
    )
    inp = fmha.Inputs(query=q, key=k, value=v, attn_bias=bias, p=dropout_p)

    dtype_str = {
        torch.bfloat16: "b16",
        torch.half: "f16",
        torch.float: "f32",
    }[dtype]
    sub_label = (
        f"{dtype_str} {B}-{M}-{H}-{K}, p={dropout_p}, "
        f"BiasT={attn_bias_type.__name__}, BiasGrad={attn_bias_requires_grad}"
    )

    has_run = False
    for fw_op, bw_op in OPS:
        if not fw_op.supports(inp) or not bw_op.supports(inp):
            continue
        has_run = True
        out = xformers.ops.memory_efficient_attention(
            inp.query, inp.key, inp.value, inp.attn_bias, inp.p, op=(fw_op, bw_op)
        )
        grad_benchmark = torch.ones_like(q)

        yield benchmark.Timer(
            stmt="out.backward(grad, retain_graph=True)",
            globals={
                "out": out,
                "grad": grad_benchmark,
            },
            label=f"attention backward (attn_bias={attn_bias_type})",
            description=bw_op.NAME,
            sub_label=sub_label,
            num_threads=num_threads,
        )
        del out

    if not has_run:
        return
    yield benchmark.Timer(
        stmt="out.backward(grad, retain_graph=True)",
        globals={
            "out": ref_attention(q, k, v, inp.attn_bias, dropout_p),
            "grad": grad_benchmark,
        },
        label=f"attention backward (attn_bias={attn_bias_type})",
        description="vanilla",
        sub_label=sub_label,
        num_threads=num_threads,
    )


benchmark_main_helper(mem_eff_attention_fw, CASES, min_run_time=min_run_time)
benchmark_main_helper(mem_eff_attention_bw, CASES, min_run_time=min_run_time)
187
pkgs/xformers/benchmarks/benchmark_mem_eff_attn_decoder.py
Normal file
@@ -0,0 +1,187 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.


import itertools
from functools import partial

import torch
from torch.utils import benchmark
from utils import benchmark_main_helper

import xformers.ops
import xformers.ops.fmha as fmha

torch.backends.cuda.matmul.allow_tf32 = False

# Run with
# python xformers/benchmarks/benchmark_mem_eff_attn_decoder.py --omit-baselines --quiet
# The baselines for these benchmarks are really slow because there is
# so much padding in the inputs, so there is no point running them.


def ref_attention_bmk(q, k, v, attn_bias=None):
    if isinstance(attn_bias, xformers.ops.AttentionMask):
        attn_bias = (
            attn_bias.materialize((q.shape[0], 1, q.shape[1], k.shape[1]))
            .to(q)
            .squeeze()
        )
    q = q * (1.0 / q.shape[-1] ** 0.5)
    if attn_bias is None:
        attn = q @ k.transpose(-2, -1)
    else:
        # equivalent to (q @ k.transpose(-2, -1) + m).softmax(-1) @ v
        # but faster, and is what is used in PyTorch now
        attn = torch.baddbmm(attn_bias, q, k.transpose(-2, -1))
    attn = attn.softmax(-1)
    return attn @ v


def ref_attention(q, k, v, attn_bias):
    assert q.ndim == 4

    def T(t):
        return t.permute((0, 2, 1, 3)).reshape(
            [t.shape[0] * t.shape[2], t.shape[1], t.shape[3]]
        )

    out = ref_attention_bmk(T(q), T(k), T(v), attn_bias)
    out = out.reshape([q.shape[0], q.shape[2], q.shape[1], v.shape[3]])
    return out.permute((0, 2, 1, 3))


min_run_time = 0.5
device = torch.device("cuda")

NUM_THREADS = [1] if device.type == "cuda" else [1, 40]

OPS = [
    xformers.ops.fmha.cutlass.FwOp,
    xformers.ops.fmha.decoder.FwOp,
]

KV_SHAPES = [
    # list of n_keys, padding_length, batchsize
    (2, 64, 3),
    (32, 1024, 500),
    (1000, 1024, 2),
    (8000, 8192, 1),
    (240, 256, 32),
    (2048, 2 * 1024, 4),
    (4096 * 2, 8 * 1024, 1),
]

N_HEADS = [8, 16, 64]


def product_dict(**kwargs):
    keys = kwargs.keys()
    vals = kwargs.values()
    for instance in itertools.product(*vals):
        yield dict(zip(keys, instance))


CASES = list(
    product_dict(
        kv_shape=KV_SHAPES,
        n_heads=N_HEADS,
        num_threads=NUM_THREADS,
        multiquery=[True, False],
    )
)


def mem_eff_attention_decoder(
    kv_shape, n_heads: int, num_threads: int, multiquery: bool
):
    n_keys, padding, B = kv_shape
    torch.manual_seed(42)
    k_seqlen = torch.randint(1, n_keys + 1, (B,)).tolist()
    K = 128

    q = torch.rand(1, B, n_heads, K, device=device, dtype=torch.bfloat16)
    if multiquery:
        k = torch.rand(
            1, B * padding, 1, K, device=device, dtype=torch.bfloat16
        ).expand(1, B * padding, n_heads, K)
        v = torch.rand(
            1, B * padding, 1, K, device=device, dtype=torch.bfloat16
        ).expand(1, B * padding, n_heads, K)
    else:
        k = torch.rand(1, B * padding, n_heads, K, device=device, dtype=torch.bfloat16)
        v = torch.rand(1, B * padding, n_heads, K, device=device, dtype=torch.bfloat16)

    bias = fmha.attn_bias.BlockDiagonalCausalWithOffsetPaddedKeysMask.from_seqlens(
        q_seqlen=[1] * B,
        kv_seqlen=k_seqlen,
        kv_padding=padding,
    )

    sub_label = f"{B}batch-{k_seqlen[0]}keys-{n_heads}heads"
    if multiquery:
        sub_label += "-mq"

    has_run = False
    for fw_op in OPS:
        inp = fmha.Inputs(q, k, v, attn_bias=bias)
        if not fw_op.supports(inp):
            continue

        fn = partial(xformers.ops.memory_efficient_attention_forward, op=fw_op)

        yield benchmark.Timer(
            stmt="fn(q, k, v, attn_bias)",
            globals={
                "q": q,
                "k": k,
                "v": v,
                "attn_bias": bias,
                "fn": fn,
            },
            label="attention",
            description=fw_op.NAME,
            sub_label=sub_label,
            num_threads=num_threads,
        )

        graph = torch.cuda.CUDAGraph()
        with torch.cuda.graph(graph):
            fn(q, k, v, bias)
        yield benchmark.Timer(
            stmt="graph.replay()",
            globals={
                "graph": graph,
            },
            label="cuda graphed attention",
            description=fw_op.NAME,
            sub_label=sub_label,
            num_threads=num_threads,
        )

        has_run = True

    if not has_run:
        return

    RUN_BASELINES = False
    if RUN_BASELINES:
        yield benchmark.Timer(
            stmt="fn(q, k, v, attn_bias)",
            globals={
                "q": q,
                "k": k,
                "v": v,
                "attn_bias": bias,
                "fn": ref_attention,
            },
            label="attention",
            description="eager",
            sub_label=sub_label,
            num_threads=num_threads,
        )


benchmark_main_helper(mem_eff_attention_decoder, CASES, min_run_time=min_run_time)
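
The "cuda graphed attention" rows above replay a captured graph instead of re-launching kernels from Python. The capture idiom is the standard torch one; a minimal sketch with the warm-up stream the PyTorch docs recommend:

def capture(fn, *args):
    s = torch.cuda.Stream()
    s.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(s):
        fn(*args)  # warm up allocations outside the capture
    torch.cuda.current_stream().wait_stream(s)
    g = torch.cuda.CUDAGraph()
    with torch.cuda.graph(g):
        fn(*args)  # captured; inputs must stay at the same addresses
    return g       # g.replay() re-executes the captured kernels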
127
pkgs/xformers/benchmarks/benchmark_mlp.py
Normal file
@@ -0,0 +1,127 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.


import argparse
from typing import Any, Dict

import torch
import triton

from xformers.benchmarks.utils import TestCase, pretty_plot, pretty_print
from xformers.components import Activation
from xformers.components.feedforward import MLP, FusedMLP

SHAPES = [
    (8, 256, 512),
    (8, 512, 1024),
    (4, 1024, 1024),
    (2, 2048, 2048),
    (1, 2048, 4096),
    (1, 1024, 12288),
]

HIDDEN_LAYER_MULTIPLIER = [4]


def bench_MLP(backward: bool, bias: bool, dropout: float, activation: Activation):
    device = torch.device("cuda")
    bw = "+bw" if backward else ""

    for dtype in [torch.float16, torch.float32]:
        results: Dict[str, Any] = {}

        for B, M, K in SHAPES:
            for hlm in HIDDEN_LAYER_MULTIPLIER:
                fused_mlp = FusedMLP(
                    dim_model=K,
                    dropout=dropout,
                    activation=activation,
                    hidden_layer_multiplier=hlm,
                    bias=bias,
                ).to(device=device, dtype=dtype)

                standard_mlp = MLP(
                    dim_model=K,
                    dropout=dropout,
                    activation=activation,
                    hidden_layer_multiplier=hlm,
                    bias=bias,
                ).to(device=device, dtype=dtype)

                a = torch.randn(
                    (B, M, K), requires_grad=backward, device=device, dtype=dtype
                )

                def mlp_standard():
                    y = standard_mlp(a)
                    if backward:
                        torch.norm(y).backward()
                    return y

                def mlp_fused():
                    y = fused_mlp(a)
                    if backward:
                        torch.norm(y).backward()
                    return y

                for testcase in [
                    TestCase(
                        mlp_standard,
                        "standard - {} - {} bias - {} drop - fw{}".format(
                            activation,
                            "no" if not bias else "",
                            dropout,
                            "+bw" if backward else "",
                        ),
                    ),
                    TestCase(
                        mlp_fused,
                        "fused - {} - {} bias - {} drop - fw{}".format(
                            activation,
                            "no" if not bias else "",
                            dropout,
                            "+bw" if backward else "",
                        ),
                    ),
                ]:
                    time = triton.testing.do_bench(testcase.function)[0]
                    key = f"{B} x {M} x {K} - {hlm}"
                    if key not in results:
                        results[key] = {}

                    results[key][testcase.name] = f"{time:.2f}"

        pretty_print(
            results,
            title=f"\n --- Type: {dtype} --- ",
            units="runtime in ms, lower is better. BMK - mul: ",
        )
        pretty_plot(
            results,
            title=f"MLP-{activation}-FW{bw}-{dtype}",
            units="runtime in ms, lower is better",
            dash_key="torch",
        )
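
triton.testing.do_bench is the timing primitive here; in the triton versions this file targets it returns a tuple of runtimes in ms, hence the [0] indexing. A rough, hypothetical stand-in built on CUDA events, for reference only:

def do_bench_ref(fn, warmup=10, rep=50):
    for _ in range(warmup):
        fn()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(rep):
        fn()
    end.record()
    torch.cuda.synchronize()
    return (start.elapsed_time(end) / rep,)  # ms; tuple mirrors the [0] use above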


if __name__ == "__main__":
    # Get the user requests
    parser = argparse.ArgumentParser("Benchmark MLP")
    parser.add_argument("-act", "--activations", nargs="+", default=[Activation.GeLU])
    parser.add_argument("-bias", "--bias", nargs="+", default=[False, True])
    parser.add_argument("-dropout", "--dropout", nargs="+", default=[0.0, 0.1])
    args = parser.parse_args()

    for bw in [False, True]:
        for bias in args.bias:
            for dropout in args.dropout:
                for activation in args.activations:
                    bench_MLP(
                        backward=bw,
                        bias=bias,
                        dropout=float(dropout),
                        activation=activation,
                    )
105
pkgs/xformers/benchmarks/benchmark_multi_head_dispatch.py
Normal file
@@ -0,0 +1,105 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.


from typing import Any, Dict

import torch
import torch.nn as nn
import triton

from xformers.benchmarks.utils import TestCase, pretty_plot, pretty_print
from xformers.components import MultiHeadDispatch
from xformers.components.attention import ScaledDotProduct

SHAPES = [
    (8, 384, 128),
    (8, 784, 512),
    (4, 1024, 768),
    (4, 2048, 1024),
    (2, 2048, 2048),
    (2, 2048, 4096),
    (2, 4096, 4096),
    (1, 2048, 12288),
]

N_HEADS = [4]


def bench_multihead_dispatch(backward: bool, self_attention: bool):
    device = torch.device("cuda")
    bw = "+bw" if backward else ""
    sa = " (self_attn)" if self_attention else ""

    for dtype in [torch.float16, torch.float32]:
        results: Dict[str, Any] = {}

        for B, M, K in SHAPES:
            for heads in N_HEADS:
                xf_multi_head = MultiHeadDispatch(
                    dim_model=K,
                    residual_dropout=0.0,
                    num_heads=heads,
                    attention=ScaledDotProduct(),
                    bias=(True, True, True, True),
                ).to(device=device, dtype=dtype)
                torch_multi_head = nn.MultiheadAttention(
                    embed_dim=K, num_heads=heads, batch_first=True
                ).to(device=device, dtype=dtype)

                q = torch.randn(
                    (B, M, K), requires_grad=backward, device=device, dtype=dtype
                )

                if self_attention:
                    k = q
                    v = q
                else:
                    k = torch.randn(
                        (B, M, K), requires_grad=backward, device=device, dtype=dtype
                    )
                    v = torch.randn(
                        (B, M, K), requires_grad=backward, device=device, dtype=dtype
                    )

                def torch_mha():
                    y, _ = torch_multi_head(query=q, key=k, value=v)
                    if backward:
                        torch.norm(y).backward()
                    return y

                def xformers_mha():
                    y = xf_multi_head(query=q, key=k, value=v)
                    if backward:
                        torch.norm(y).backward()
                    return y

                for testcase in [
                    TestCase(torch_mha, f"torch - fw{bw}{sa}"),
                    TestCase(xformers_mha, f"xf - fw{bw}{sa}"),
                ]:
                    time = triton.testing.do_bench(testcase.function)[0]
                    key = f"B={B}, M={M}, K={K}, N_HEADS={heads}"
                    if key not in results:
                        results[key] = {}

                    results[key][testcase.name] = f"{time:.2f}"

        pretty_print(
            results,
            title=f"\n --- Type: {dtype} --- ",
            units="runtime in ms, lower is better",
        )
        pretty_plot(
            results,
            title=f"MHA-FW{bw}-{dtype}",
            units="runtime in ms, lower is better",
            dash_key="torch",
        )


for bw in [False, True]:
    for self_attention in [False, True]:
        bench_multihead_dispatch(bw, self_attention)
99
pkgs/xformers/benchmarks/benchmark_nystrom_utils.py
Normal file
@@ -0,0 +1,99 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

from typing import Callable

import torch
from torch.utils import benchmark

from xformers.components.attention.utils import iterative_pinv

MIN_RUN_TIME = 1
SHAPES = [[8, 8], [256, 1024], [128, 256]]
SPARSITIES = [0.5, 0.8, 0.9, 0.95, 0.99]


def bench_inverse(inverse_fn: Callable[[torch.Tensor], torch.Tensor]):
    min_run_time = MIN_RUN_TIME
    prob = 0.9
    device = torch.device("cuda")
    results = []

    for B, M, K in zip(*SHAPES):
        a = torch.rand(B, M, M, device=device)
        a[a < prob] = 0
        a = torch.softmax(a, dim=-1)

        results.extend(
            [
                benchmark.Timer(
                    stmt=f"{inverse_fn.__name__}(a)",
                    globals={
                        "a": a,
                        f"{inverse_fn.__name__}": inverse_fn,
                    },
                    label=f"{inverse_fn.__name__}",
                    sub_label="dense",
                    description=f"B={B}, M={M}, K={K}",
                ).blocked_autorange(min_run_time=min_run_time),
            ]
        )
        for prob in SPARSITIES:
            a = torch.rand(B, M, M, device=device)
            a[a < prob] = 0
            a = a.to_sparse()
            results.append(
                benchmark.Timer(
                    stmt=f"{inverse_fn.__name__}(a)",
                    globals={
                        "a": a,
                        f"{inverse_fn.__name__}": inverse_fn,
                    },
                    label=f"{inverse_fn.__name__}",
                    sub_label=f"sparsity: {prob:0.2f}",
                    description=f"B={B}, M={M}, K={K}",
                ).blocked_autorange(min_run_time=min_run_time)
            )

    compare = benchmark.Compare(results)
    compare.print()


def iterative_pinv_analysis(
    identity_tolerance: float = 1e-1,
    pinv_tolerance: float = 5e-1,
    max_iters: int = 30,
    plot: bool = True,
):
    for i in range(1, 10):
        B, M = 1, 2**i
        a = torch.rand(B, M, M)
        a = torch.softmax(a, dim=-1)

        for n_iter in range(1, max_iters + 1):
            result = iterative_pinv(a, n_iter=n_iter)
            expected = torch.linalg.pinv(a)

            result_identity = torch.matmul(a, result)
            identity = torch.eye(M)

            # Default is frobenius norm.
            identity_error = torch.linalg.norm(identity - result_identity, dim=(-2, -1))
            inverse_error = torch.linalg.norm(expected - result, dim=(-2, -1))

            if (identity_error < identity_tolerance).all() or n_iter == max_iters:
                print(
                    f"Size {M}, n_iters {n_iter}: \n\t \
                    Final Error from Identity: {identity_error.item()} \n\t \
                    Final Error from linalg.pinv {inverse_error.item()}"
                )
                break


if __name__ == "__main__":
    iterative_pinv_analysis()
    bench_inverse(iterative_pinv)
    bench_inverse(torch.linalg.pinv)
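
iterative_pinv approximates the Moore-Penrose pseudoinverse with a fixed-point iteration instead of an SVD, which is why it can beat torch.linalg.pinv on the GPU. A generic Newton-Schulz sketch of that idea (not necessarily the exact update rule xformers uses):

def newton_schulz_pinv(a: torch.Tensor, n_iter: int = 8) -> torch.Tensor:
    # v0 = a^T / (||a||_1 * ||a||_inf) guarantees convergence of the iteration
    norm_1 = a.abs().sum(dim=-2).max(dim=-1).values
    norm_inf = a.abs().sum(dim=-1).max(dim=-1).values
    v = a.transpose(-2, -1) / (norm_1 * norm_inf).unsqueeze(-1).unsqueeze(-1)
    for _ in range(n_iter):
        v = 2 * v - v @ a @ v  # second-order Newton-Schulz update
    return v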
83
pkgs/xformers/benchmarks/benchmark_revnet.py
Normal file
@@ -0,0 +1,83 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.


from typing import Any, Dict

import torch
import triton

from xformers.benchmarks.utils import TestCase, pretty_plot, pretty_print
from xformers.components.reversible import ReversibleSequence

SHAPES = [(16384, 32), (2048, 256), (128, 4096)]

DEPTH = [4, 32, 256]


def bench_revnet(backward: bool):
    device = torch.device("cuda")
    bw = "+bw" if backward else ""

    for dtype in [torch.float16, torch.float32]:
        results: Dict[str, Any] = {}

        for B, K in SHAPES:
            for depth in DEPTH:
                f = torch.nn.Linear(K, K).to(device=device, dtype=dtype)
                g = torch.nn.Linear(K, K).to(device=device, dtype=dtype)
                revseq = ReversibleSequence(
                    torch.nn.ModuleList([torch.nn.ModuleList([f, g])] * depth)
                )
                revseq = revseq.to(device=device, dtype=dtype)

                a = torch.rand(
                    1, B, K, device=device, dtype=dtype, requires_grad=backward
                )
                b = torch.rand(
                    1, B, K * 2, device=device, dtype=dtype, requires_grad=backward
                )

                def normal_step():
                    y = a
                    for _ in range(depth):
                        y = y + f(y)
                        y = y + g(y)
                    if backward:
                        torch.norm(y).backward()
                    return y

                def reversible_step():
                    y = revseq(b)
                    if backward:
                        torch.norm(y).backward()
                    return y

                for testcase in [
                    TestCase(normal_step, f"residual - fw{bw}"),
                    TestCase(reversible_step, f"reversible - fw{bw}"),
                ]:
                    time = triton.testing.do_bench(testcase.function)[0]
                    key = f"Batch={B}, Features={K}, Depth={depth}"
                    if key not in results:
                        results[key] = {}

                    results[key][testcase.name] = f"{time:.2f}"

        pretty_print(
            results,
            title=f"\n --- Type: {dtype} --- ",
            units="runtime in ms, lower is better",
        )
        pretty_plot(
            results,
            title=f"RevNet-FW{bw}-{dtype}",
            units="runtime in ms, lower is better",
            dash_key="torch",
        )


for bw in [False, True]:
    bench_revnet(bw)
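
The memory argument for reversible layers is that activations can be reconstructed instead of stored: the coupling being benchmarked computes y1 = x1 + f(x2) and y2 = x2 + g(y1), which is exactly invertible. A sketch of the forward/inverse pair:

def rev_forward(x1, x2, f, g):
    y1 = x1 + f(x2)
    y2 = x2 + g(y1)
    return y1, y2

def rev_inverse(y1, y2, f, g):
    x2 = y2 - g(y1)  # recover x2 from the outputs alone
    x1 = y1 - f(x2)  # then recover x1, so no activations need to be stored
    return x1, x2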
117
pkgs/xformers/benchmarks/benchmark_sddmm.py
Normal file
@@ -0,0 +1,117 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

import itertools

import torch
from torch.utils import benchmark

from xformers.components.attention._sputnik_sparse import _csr_to_coo
from xformers.components.attention.core import SparseCS, _create_random_sparsity

MIN_RUN_TIME = 0.2


def _get_fn(backend):
    if backend == "csr_ge":
        fn = torch.ops.xformers.csr_sddmm
    elif backend == "csr_sputnik":
        fn = torch.ops.xformers.sddmm_sputnik
    elif backend == "coo_ge":

        def fn(a, b, row_indices, row_offsets, column_indices):
            row_coo, _ = _csr_to_coo(
                a.shape[-2], b.shape[-2], row_offsets, column_indices
            )
            return torch.ops.xformers.coo_sddmm(
                a, b, row_indices, row_coo, column_indices
            )

    elif backend == "csr_to_coo":

        def fn(a, b, row_indices, row_offsets, column_indices):
            row_coo, _ = _csr_to_coo(
                a.shape[-2], b.shape[-2], row_offsets, column_indices
            )
            return row_coo

    return fn


def bench_sddmm(configs):
    min_run_time = MIN_RUN_TIME

    device = torch.device("cuda")
    results = []

    for (B, M, K), prob in configs:
        a = torch.rand(B, M, K, device=device)
        b = torch.rand(B, M, K, device=device)

        mask = _create_random_sparsity(
            torch.ones(1, M, M, dtype=torch.bool), prob, divisible_by=16
        )
        aa = a
        bb = b
        mask = SparseCS(mask, device)
        row_indices = mask.row_indices
        row_offsets = mask.row_offsets
        column_indices = mask.column_indices

        for backend in ["csr_sputnik", "csr_ge", "coo_ge", "csr_to_coo"]:
            fn_str = "fn(a, b, row_indices, row_offsets, column_indices)"
            fn = _get_fn(backend)

            results.append(
                benchmark.Timer(
                    stmt=fn_str,
                    globals={
                        "a": aa,
                        "b": bb,
                        "mask": mask,
                        "row_indices": row_indices,
                        "row_offsets": row_offsets,
                        "column_indices": column_indices,
                        "fn": fn,
                    },
                    label="sddmm",
                    sub_label=f"B={B:>4d}, M={M:>4d}, K={K:>3d}, prob={prob:0.4f}",
                    description=backend,
                ).blocked_autorange(min_run_time=min_run_time)
            )

    compare = benchmark.Compare(results)
    compare.print()
    return results


# batch size 32, for different layers
SWIN_T_SIZES = [(96, 3136, 32), (192, 784, 32), (384, 196, 32), (768, 49, 32)]
swin_t_config = list(zip(SWIN_T_SIZES, (0.9844, 0.9375, 0.75, 0.0)))

# some random values
BASIC_SIZES = [(32, 1024, 32), (32, 1024, 128), (8, 4096, 32), (8, 4096, 128)]
SPARSITIES = [0.90, 0.93, 0.95, 0.97, 0.98, 0.99, 0.995, 0.999]
basic_config = list(itertools.product(BASIC_SIZES, SPARSITIES))

# batch size 32 here
vit_sizes = [
    (192, 785, 64),  # deit_small_patch8_224
    (192, 197, 64),  # deit_small_patch16_224
    (384, 785, 64),  # deit_base_patch8_224
    (384, 197, 64),  # deit_base_patch16_224
]
SPARSITIES = [0.70, 0.80, 0.85, 0.90, 0.93, 0.95, 0.97]
vit_config = list(itertools.product(vit_sizes, SPARSITIES))

results = []

print("Swin Transformer")
results += bench_sddmm(swin_t_config)
print("ViT")
results += bench_sddmm(vit_config)
print("Basic cases")
results += bench_sddmm(basic_config)
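
For reference, all four backends compute the same SDDMM (sampled dense-dense matrix multiplication): a @ b^T evaluated only where the sparse mask is nonzero. A dense-semantics sketch:

def sddmm_ref(a, b, dense_mask):
    # dense_mask holds {0, 1}; real kernels never materialize the dense product
    return (a @ b.transpose(-2, -1)) * dense_mask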
160
pkgs/xformers/benchmarks/benchmark_swiglu.py
Normal file
@@ -0,0 +1,160 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.


import itertools
from contextlib import nullcontext
from functools import partial
from typing import Any

import torch
from torch.utils import benchmark
from utils import benchmark_main_helper

import xformers.ops.swiglu_op as xsw

min_run_time = 0.5
device = torch.device("cuda")

SHAPES = [
    # Format: [inp.shape[0], inp.shape[1], hidden.shape[1]]
    # ViT-Giant
    (9456, 1536, 2736),
    (4440, 1536, 2736),
    (4728, 1536, 2736),
    # Some smaller shapes as well
    (4728, 1536, 1024),
    # GPT-3 (small)
    (32768, 2048, 5632),
    # Chinchilla
    (32768, 8192, 22016),
]


# OP = xsw._SwiGLUDecomposedOp
# OP = xsw.SwiGLUFusedOp
OP = xsw.SwiGLUPackedFusedOp


def product_dict(**kwargs):
    keys = kwargs.keys()
    vals = kwargs.values()
    for instance in itertools.product(*vals):
        yield dict(zip(keys, instance))


CASES = list(
    product_dict(
        shape=SHAPES,
        dtype=[torch.bfloat16, torch.half, "autocast_half"],
        bias=[True, False],
    )
)

DTYPE2STR = {
    torch.bfloat16: "b16 ",
    torch.half: "f16 ",
    "autocast_half": "f16.ac",
}


def benchmark_swiglu(shape, dtype, bias: bool):
    if dtype == "autocast_half":
        inp_dtype, model_dtype, autocast = torch.float, torch.float, True
    else:
        inp_dtype, model_dtype, autocast = dtype, dtype, False

    x = torch.randn(shape[:2], device=device, dtype=inp_dtype)
    module = (
        xsw.SwiGLU(in_features=shape[1], hidden_features=shape[2], bias=bias)
        .to(device)
        .to(model_dtype)
    )

    dtype_str = DTYPE2STR.get(dtype, dtype)
    bstr = "bias" if bias else "nobi"
    sub_label = f"{dtype_str} B={shape[0]}, I={shape[1]}, H={shape[2]} {bstr}"

    params = module._ordered_params()

    PREFIX = 'with torch.autocast("cuda", dtype=torch.half):\n ' if autocast else ""
    yield benchmark.Timer(
        stmt=f"{PREFIX}fn(x, *args)",
        globals={
            "x": x,
            "args": params,
            "fn": partial(xsw.swiglu, op=OP),
        },
        label="swiglu_fw",
        description=OP.NAME,
        sub_label=sub_label,
    )
    yield benchmark.Timer(
        stmt=f"{PREFIX}fn(x, *args)",
        globals={
            "x": x,
            "args": params,
            "fn": partial(xsw.swiglu, op=xsw.SwiGLUEagerOp),
        },
        label="swiglu_fw",
        description="eager",
        sub_label=sub_label,
    )


def benchmark_swiglu_bw(shape, dtype, bias: bool):
    if dtype == "autocast_half":
        inp_dtype, model_dtype = torch.float, torch.float
        cm: Any = partial(torch.cuda.amp.autocast, enabled=True, dtype=torch.float16)
    else:
        inp_dtype, model_dtype = dtype, dtype
        cm = nullcontext

    x = torch.randn(shape[:2], device=device, dtype=inp_dtype)
    x.requires_grad_()
    module = (
        xsw.SwiGLU(in_features=shape[1], hidden_features=shape[2], bias=bias)
        .to(device)
        .to(model_dtype)
    )

    dtype_str = DTYPE2STR.get(dtype, dtype)
    bstr = "bias" if bias else "nobi"
    sub_label = f"{dtype_str} B={shape[0]}, I={shape[1]}, H={shape[2]} {bstr}"

    params = module._ordered_params()
    with cm():
        out = xsw.swiglu(x, *params, op=OP)
    grad = torch.zeros_like(out)

    yield benchmark.Timer(
        stmt="out.backward(grad, retain_graph=True)",
        globals={
            "out": out,
            "grad": grad,
        },
        label="swiglu_bw",
        description=OP.NAME,
        sub_label=sub_label,
    )
    del out

    with cm():
        out = xsw.swiglu(x, *params, op=xsw.SwiGLUEagerOp)

    yield benchmark.Timer(
        stmt="out.backward(grad, retain_graph=True)",
        globals={
            "out": out,
            "grad": grad,
        },
        label="swiglu_bw",
        description="eager",
        sub_label=sub_label,
    )


benchmark_main_helper(benchmark_swiglu, CASES, min_run_time=min_run_time)
benchmark_main_helper(benchmark_swiglu_bw, CASES, min_run_time=min_run_time)
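
xsw.swiglu consumes the parameter list returned by module._ordered_params(); assuming the usual (w1, b1, w2, b2, w3, b3) ordering, its eager semantics are the standard SwiGLU feed-forward. A reference sketch:

import torch.nn.functional as F

def swiglu_ref(x, w1, b1, w2, b2, w3, b3):
    gate = F.silu(F.linear(x, w1, b1))   # SiLU-activated branch
    value = F.linear(x, w2, b2)          # linear branch
    return F.linear(gate * value, w3, b3)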
155
pkgs/xformers/benchmarks/benchmark_transformer.py
Normal file
@@ -0,0 +1,155 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.


import itertools
from functools import partial, reduce

import timm
import torch
import torch.nn as nn
from timm.models.layers import Mlp as TimmMlp
from timm.models.vision_transformer import Attention as TimmAttention
from timm.models.vision_transformer import Block as TimmBlock
from torch.utils import benchmark
from utils import benchmark_main_helper

import xformers.ops as xops


def replace_module(module: nn.Module, replace_class, factory):
    if isinstance(module, replace_class):
        return factory(module)
    module_output = module
    for name, child in module.named_children():
        module_output.add_module(name, replace_module(child, replace_class, factory))
    del module
    return module_output


class TimmMemEffAttention(nn.Module):
    def __init__(self, attn: TimmAttention, op=None):
        super().__init__()
        self.op = op
        self.num_heads = attn.num_heads
        self.scale = attn.scale

        self.qkv = attn.qkv
        self.attn_drop = attn.attn_drop
        self.proj = attn.proj
        self.proj_drop = attn.proj_drop

    def forward(self, x):
        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads)
        q, k, v = xops.unbind(qkv, dim=2)

        x = xops.memory_efficient_attention(q, k, v, op=self.op).reshape(B, N, C)

        x = self.proj(x)
        x = self.proj_drop(x)
        return x


class TimmSwiGLU(nn.Module):
    def __init__(self, mlp: TimmMlp, op=None) -> None:
        super().__init__()
        self.fc1 = mlp.fc1
        self.swiglu = xops.SwiGLU(
            in_features=mlp.fc1.in_features,
            hidden_features=mlp.fc1.out_features,
            bias=True,
        )
        self.op = op

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.swiglu(x)


def mod_memeff_attn(model: nn.Module, op=None) -> nn.Module:
    return replace_module(model, TimmAttention, partial(TimmMemEffAttention, op=op))


def mod_mlp_to_swiglu(model: nn.Module, op=None) -> nn.Module:
    def _mlp_to_swiglu(block: TimmBlock):
        block.mlp = TimmSwiGLU(block.mlp, op=op)
        return block

    return replace_module(model, TimmBlock, _mlp_to_swiglu)


mod_mlp_to_eagr_swiglu = partial(mod_mlp_to_swiglu, op=xops.SwiGLUEagerOp)
mod_mlp_to_fast_swiglu = partial(mod_mlp_to_swiglu, op=None)


def compose(*fns):
    def compose2(f, g):
        return lambda *a, **kw: f(g(*a, **kw))

    return reduce(compose2, fns)
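
# compose chains right-to-left, i.e. compose(f, g)(x) == f(g(x)), so a
# composed modifier applies the right-most model transformation first.
# Small illustration:
#   square_then_negate = compose(lambda v: -v, lambda v: v * v)
#   square_then_negate(3)  # -> -9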


MODELS = [
    # model_name, model_factory, input_shape
    ("ViT-B/16", timm.models.vit_base_patch16_224, [512, 3, 224, 224]),
    ("ViT-B/8", timm.models.vit_base_patch8_224, [64, 3, 224, 224]),
    ("ViT-L/16", timm.models.vit_large_patch16_224, [128, 3, 224, 224]),
    ("ViT-g/14", timm.models.vit_giant_patch14_224, [32, 3, 224, 224]),
]

MODIFIERS = [
    ["mlp", lambda x: x],
    ["mlp+memeff", mod_memeff_attn],
    ["swiglu", mod_mlp_to_eagr_swiglu],
    ["swiglu+fast_swiglu", mod_mlp_to_fast_swiglu],
    ["swiglu+fast_swiglu+memeff", compose(mod_mlp_to_fast_swiglu, mod_memeff_attn)],
]


def product_dict(**kwargs):
    keys = kwargs.keys()
    vals = kwargs.values()
    for instance in itertools.product(*vals):
        yield dict(zip(keys, instance))


CASES = list(
    product_dict(
        model_info=MODELS,
        dtype=[torch.half],
    )
)


def benchmark_transformer(model_info, dtype):
    device = "cuda"

    model_name, model_factory, input_shape = model_info

    inp = torch.randn(input_shape, dtype=dtype, device=device)

    for mod_name, mod_apply in MODIFIERS:
        model: nn.Module = model_factory()
        model = mod_apply(model).to(device).to(dtype)

        # Make sure we don't have errors
        out = model(inp)
        grad = out.clone()
        out.backward(grad)

        yield benchmark.Timer(
            stmt="model(inp).backward(grad)",
            globals={
                "model": model,
                "inp": inp,
                "grad": grad,
            },
            label="fw+bw",
            description=mod_name,
            sub_label=model_name,
        )


benchmark_main_helper(benchmark_transformer, CASES)
150
pkgs/xformers/benchmarks/benchmark_triton_blocksparse.py
Normal file
@@ -0,0 +1,150 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.


# Benchmark the blocksparse operations:
# matrix multiply and softmax

# Matmul can be of three types:
# - Dense x Dense (COO) -> Sparse
# - Sparse x Dense -> Dense
# - Dense x Sparse -> Dense

from typing import Any, Dict

import torch
import triton
from triton.ops.blocksparse import matmul as blocksparse_matmul

from xformers.benchmarks.utils import TestCase, pretty_plot, pretty_print
from xformers.components.attention.core import SparseCS, _matmul_with_mask


def bench_matmul(dtype: torch.dtype, shapes):
    results: Dict[str, Any] = {}
    Z, H = 1, 1

    for M, N, K in shapes:
        modes = [(mode, block) for mode in ["sdd", "dsd"] for block in [16, 32, 64]]

        for mode, block in modes:
            # create inputs
            a = torch.randn((Z, H, M, K), dtype=dtype, device="cuda")
            b = torch.randn((Z, H, K, N), dtype=dtype, device="cuda")
            shape = {
                "sdd": (M, N),
                "dsd": (a.shape[2], a.shape[3]),
                "dds": (b.shape[2], b.shape[3]),
            }[mode]

            # Pre-sparsify everything
            _layout = torch.eye(shape[0] // block, shape[1] // block, dtype=torch.long)

            # - blocksparse
            layout = _layout.unsqueeze(0).expand(H, -1, -1)
            a_triton = (
                triton.testing.sparsify_tensor(a, layout, block) if mode == "dsd" else a
            )
            b_triton = (
                triton.testing.sparsify_tensor(b, layout, block) if mode == "dds" else b
            )
            bsmm = blocksparse_matmul(
                layout=layout,
                block=block,
                mode=mode,
                device=torch.device("cuda"),
                trans_a=False,
                trans_b=False,
            )

            # - dense
            ta = triton.testing.mask_tensor(a, layout, block) if mode == "dsd" else a
            tb = triton.testing.mask_tensor(b, layout, block) if mode == "dds" else b

            # - sparse / sputnik
            mask = torch.ones_like(a, dtype=torch.float, device="cuda")
            mask = triton.testing.mask_tensor(mask, layout, block, value=0.0)
            a_cs = a.flatten(start_dim=0, end_dim=1).to(
                torch.float32
            )  # Sputnik kernels only handle fp32
            b_cs = b.flatten(start_dim=0, end_dim=1).to(torch.float32)
            a_cs = a_cs.contiguous()
            b_cs = b_cs.transpose(-2, -1).contiguous()

            if mode == "sdd":
                b_cs = b_cs.transpose(-2, -1)

            # pyre-fixme[16]: TODO(T101400990): Pyre did not recognize the
            # `SparseCS` import.
            sparse_cs_mask = SparseCS(
                mask.flatten(start_dim=0, end_dim=1).contiguous(),
                device=torch.device("cuda"),
            )

            # The raw compute steps
            op_flops = {
                "sdd": 2 * Z * K * float(layout.sum()) * block * block,
                "dsd": 2 * Z * N * float(layout.sum()) * block * block,
                "dds": 2 * Z * M * float(layout.sum()) * block * block,
            }[mode] * 1e-12  # TFlops

            def torch_step():
                return torch.matmul(ta, tb)

            def triton_step():
                return bsmm(a_triton, b_triton)

            def sparse_step():
                if mode == "sdd":
                    return _matmul_with_mask(a_cs, b_cs, sparse_cs_mask)
                else:
                    return sparse_cs_mask.spmm(b_cs)

            # Run and measure, report perf in terms of TFlops
            for testcase in [
                TestCase(
                    torch_step,
                    f"pytorch - {mode} - {block}: ",
                ),
                TestCase(
                    sparse_step,
                    f"sparse - {mode} - {block}: ",
                ),
                TestCase(
                    triton_step,
                    f"triton - {mode} - {block}: ",
                ),
            ]:
                ms = triton.testing.do_bench(lambda: testcase.function())[0]
                key = f"M={M}, N={N}, K={K}"

                if key not in results:
                    results[key] = {}

                num_flops = op_flops / ms * 1e3  # Get to TFlop per second
                results[key][testcase.name] = f"{num_flops:.1f}"
                print(f"{key} - {testcase.name} - {num_flops:.2f}TFlops")

    pretty_print(
        results,
        title="\n ------------- Type: {} -------------".format(dtype),
        units="TFlops/s",
    )

    pretty_plot(
        results,
        title=f"Sparse/Blocksparse throughput - {dtype}",
        filename=f"blocksparse_{dtype}.png",
        dash_key="pytorch",
        units="TFlops/s",
    )


shapes = [(k, k, k) for k in [128, 512, 1024, 2048, 4096]]
bench_matmul(torch.float16, shapes)
bench_matmul(torch.float32, shapes)
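
A worked example of the op_flops bookkeeping above, for mode="sdd" with M = N = K = 1024, block = 32 and the identity block layout used here:

# layout.sum() = 1024 // 32 = 32 nonzero blocks
# op_flops     = 2 * Z * K * 32 * block * block * 1e-12
#              = 2 * 1 * 1024 * 32 * 32 * 32 * 1e-12 ≈ 6.7e-5 TFlop
# so a 0.01 ms kernel time is reported as 6.7e-5 / 0.01 * 1e3 ≈ 6.7 TFlops/s.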
115
pkgs/xformers/benchmarks/benchmark_triton_dropout.py
Normal file
@@ -0,0 +1,115 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.


from typing import Any, Dict, Optional

import torch
import triton

from xformers.benchmarks.utils import TestCase, pretty_plot, pretty_print
from xformers.components import Activation, build_activation
from xformers.triton import FusedDropoutBias

SHAPES = [
    (8, 256, 512),
    (8, 512, 1024),
    (4, 1024, 1024),
    (2, 2048, 2048),
    (1, 2048, 12288),
    (2, 4096, 4096),
]

P = 0.1


def to_gbs_fw(a, ms, bias):
    # Read and write the full array
    total = 2 * a.numel() * a.element_size()

    if bias:
        # Read the bias, ideally only once
        total += a.shape[-1] * a.element_size()

    return total * 1e-9 / (ms * 1e-3)
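
# Worked example: for a = torch.rand(8, 256, 512, dtype=torch.float16) with bias,
#   total = 2 * 1_048_576 * 2 + 512 * 2 = 4_195_328 bytes ≈ 4.20e-3 GB
# so a 0.1 ms kernel is reported as 4.20e-3 / 1e-4 ≈ 42 GB/s.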


def bench_dropout(bias: bool, backward: bool, activation: Optional[Activation]):
    device = torch.device("cuda")

    for dtype in [
        torch.float16,
        torch.float32,
    ]:
        results: Dict[str, Any] = {}

        for B, M, K in SHAPES:
            a = torch.rand(
                (B, M, K), device=device, dtype=dtype, requires_grad=backward
            )
            b = torch.rand(K, device=device, dtype=dtype, requires_grad=backward)
            torch_act = build_activation(activation)
            triton_dropout = FusedDropoutBias(
                P, bias_shape=K if bias else None, activation=activation
            )

            def torch_step(x):
                x_ = x + b if bias else x
                y = torch.nn.functional.dropout(x_, P)
                if activation:
                    y = torch_act(y)

                if backward:
                    y.grad = None
                    torch.norm(y).backward()
                return y

            def triton_step(x):
                y = triton_dropout(x)
                if backward:
                    y.grad = None
                    torch.norm(y).backward()
                return y

            for testcase in [
                TestCase(
                    torch_step,
                    "pytorch - bias: {} - fw{} - act: {}".format(
                        bias, "+bw" if backward else "", activation
                    ),
                ),
                TestCase(
                    triton_step,
                    "triton - bias: {} - fw{} - act: {}".format(
                        bias, "+bw" if backward else "", activation
                    ),
                ),
            ]:
                time = triton.testing.do_bench(
                    lambda: testcase.function(a), grad_to_none=[a, b]
                )[0]
                key = f"B={B}, M={M}, K={K}"
                if key not in results:
                    results[key] = {}

                # Record BW
                bandwidth = to_gbs_fw(a, time, bias)
                results[key][testcase.name] = f"{bandwidth:.1f}"

        pretty_print(results, title="\n --- Type: {} --- ".format(dtype), units="GB/s")
        pretty_plot(
            results,
            title="Dropout-Bias-{}-FW{}-{}-Act: {}".format(
                bias, "+BW" if backward else "", dtype, activation
            ),
            units="GB/s",
            dash_key="pytorch",
        )


for activation in [Activation.GeLU, None, Activation.SquaredReLU]:
    for bw in [True, False]:
        for bias in [True, False]:
            bench_dropout(bias, bw, activation)
|
||||
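
As a sanity check on the bandwidth numbers reported above, here is a small self-contained sketch of the same arithmetic; the 0.1 ms timing is a made-up placeholder rather than a measured value.

import torch


def to_gbs_fw(a, ms, bias):
    total = 2 * a.numel() * a.element_size()
    if bias:
        total += a.shape[-1] * a.element_size()
    return total * 1e-9 / (ms * 1e-3)


# (8, 256, 512) in fp16: 2 * 1_048_576 elements * 2 bytes, plus 512 * 2 bytes
# of bias, i.e. ~4.20 MB moved per call. At a hypothetical 0.1 ms per call
# this comes out to ~42.0 GB/s.
a = torch.empty(8, 256, 512, dtype=torch.float16)
print(f"{to_gbs_fw(a, 0.1, bias=True):.1f} GB/s")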
160
pkgs/xformers/benchmarks/benchmark_triton_fused_linear.py
Normal file
@@ -0,0 +1,160 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.


from typing import Any, Dict, List, Optional

import torch
import triton

from xformers.benchmarks.utils import TestCase, pretty_plot, pretty_print
from xformers.components import Activation, build_activation
from xformers.triton.fused_linear_layer import FusedLinear

SHAPES = [
    (8, 512, 256),  # Batch x Seq x Embedding
    (8, 512, 512),
    (4, 512, 1024),
    (2, 512, 2048),
    (2, 512, 4096),
    (2, 512, 8192),
]

# Switch PyTorch to TF32 accumulations, Triton does that also
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True


def get_metrics_transform(
    activation: Optional[Activation],
    a: torch.Tensor,
    w: torch.Tensor,
    b: Optional[torch.Tensor],
    backward: bool,
):
    # All operations will involve a * weight.
    flop = a.shape[0] * a.shape[1] * w.shape[1] * (2 * a.shape[2] - 1)

    # Optional activation on top
    if activation is not None:
        flop += a.numel()

    # Optionally * 2 (before the bias) if backward
    if backward:
        flop *= 2

        # The backward pass also outputs a gradient with respect to the bias,
        # which consolidates over all the activation gradients
        flop += a.shape[0] * a.shape[1] * w.shape[1]

        # The backward pass also outputs another gradient with respect to the weight,
        # which is another matmul, between the grad_out and the inputs this time
        flop += a.shape[0] * a.shape[1] * w.shape[1] * (2 * a.shape[2] - 1)

    # Optional bias on top
    if b is not None:
        flop += b.numel()

    def metric_conversion(ms):
        # Returns TFlops/second
        return flop * 1e-12 / (ms * 1e-3)

    return metric_conversion


def bench_linear(activations: List[Optional[Activation]]):
    device = torch.device("cuda")

    for dtype in [
        torch.float32,
        torch.float16,
    ]:
        for backward in [True, False]:

            for activation in activations:
                results: Dict[str, Any] = {}

                for bias in [False, True]:
                    for B, M, K in SHAPES:
                        a = torch.rand(
                            B, M, K, device=device, dtype=dtype, requires_grad=backward
                        )

                        # PyTorch linear layer + activation
                        torch_linear = torch.nn.Linear(K, 4 * K, bias=bias).to(
                            dtype=dtype, device=device
                        )
                        torch_activation = build_activation(activation)

                        # Fused layer equivalent
                        fused_linear = FusedLinear(
                            K, 4 * K, bias=bias, activation=activation
                        ).to(dtype=dtype, device=device)

                        def torch_step(x):
                            y = torch_activation(torch_linear(x))
                            if backward:
                                torch.norm(y).backward()
                            return y

                        def triton_step(x):
                            y = fused_linear(x)
                            if backward:
                                torch.norm(y).backward()
                            return y

                        metrics_transform = get_metrics_transform(
                            activation,
                            a,
                            torch_linear.weight,
                            torch_linear.bias,
                            backward,
                        )

                        for testcase in [
                            TestCase(
                                torch_step,
                                "pytorch - {} - {} bias - fw{}".format(
                                    activation,
                                    "no" if not bias else "",
                                    "+bw" if backward else "",
                                ),
                            ),
                            TestCase(
                                triton_step,
                                "triton - {} - {} bias - fw{}".format(
                                    activation,
                                    "no" if not bias else "",
                                    "+bw" if backward else "",
                                ),
                            ),
                        ]:
                            time = triton.testing.do_bench(
                                lambda: testcase.function(a)
                            )[0]
                            key = f"B={B}, M={M}, K={K}"
                            if key not in results:
                                results[key] = {}

                            metric = metrics_transform(time)
                            results[key][testcase.name] = f"{metric:.1f}"

                pretty_print(
                    results,
                    title="\n --- Type: {} ---".format(dtype),
                    units="TFlops/s",
                )

                _type = "_fp16" if dtype == torch.float16 else "_fp32"
                title = "FusedLinear" + _type + "_FW"
                if backward:
                    title += "_BW"
                title += "_" + activation.value if activation else "_none"
                pretty_plot(results, title, "TFlops/s", dash_key="pytorch")


activations = [ac for ac in Activation] + [None]  # type: ignore
bench_linear(activations)  # type: ignore
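
To make the FLOP accounting in get_metrics_transform concrete, here is a standalone sketch for the forward-only, no-activation, no-bias case. Note that torch.nn.Linear stores its weight as (out_features, in_features), so w.shape[1] equals the input dimension K here; the 1.0 ms runtime is purely illustrative.

import torch

B, M, K = 8, 512, 256
w = torch.nn.Linear(K, 4 * K).weight  # shape (4*K, K), so w.shape[1] == K

# Same count as get_metrics_transform for forward-only, no activation, no bias
flop = B * M * w.shape[1] * (2 * K - 1)

ms = 1.0  # hypothetical runtime, for illustration only
print(f"{flop * 1e-12 / (ms * 1e-3):.3f} TFlops/s")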
92
pkgs/xformers/benchmarks/benchmark_triton_layernorm.py
Normal file
@@ -0,0 +1,92 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.


from typing import Any, Dict

import torch
import triton

from xformers.benchmarks.utils import TestCase, pretty_plot, pretty_print
from xformers.triton import FusedLayerNorm

SHAPES = [
    (8, 256, 512),
    (8, 512, 1024),
    (4, 1024, 1024),
    (2, 2048, 2048),
    (2, 4096, 4096),
    (1, 2048, 12288),
]


def to_gbs_fw(a, ms):
    # Read and write the full array
    return (2 * a.numel() * a.element_size() * 1e-9) / (ms * 1e-3)


def bench_layernorm(backward: bool):
    device = torch.device("cuda")

    for dtype in [
        torch.float16,
        torch.bfloat16,
        torch.float32,
    ]:
        results: Dict[str, Any] = {}

        for B, M, K in SHAPES:
            a = torch.rand(B, M, K, device=device, dtype=dtype, requires_grad=backward)

            # PyTorch layer norm
            torch_layernorm = torch.nn.LayerNorm([K]).to(dtype=dtype, device=device)

            # pyre-ignore[16]: TODO(T101400990): Pyre did not recognize the
            # `FusedLayerNorm` import.
            # Fused layernorm equivalent
            fused_layernorm = FusedLayerNorm([K]).to(dtype=dtype, device=device)

            def torch_step(x):
                y = torch_layernorm(x)
                if backward:
                    torch.norm(y).backward()
                return y

            def triton_step(x):
                y = fused_layernorm(x)
                if backward:
                    torch.norm(y).backward()
                return y

            for testcase in [
                TestCase(
                    torch_step,
                    "pytorch - fw{}".format("+bw" if backward else ""),
                ),
                TestCase(
                    triton_step,
                    "triton - fw{}".format("+bw" if backward else ""),
                ),
            ]:
                time = triton.testing.do_bench(lambda: testcase.function(a))[0]
                key = f"B={B}, M={M}, K={K}"
                if key not in results:
                    results[key] = {}

                # Record BW
                bandwidth = to_gbs_fw(a, time)
                results[key][testcase.name] = f"{bandwidth:.1f}"

        pretty_print(results, title="\n --- Type: {} --- ".format(dtype), units="GB/s")
        pretty_plot(
            results,
            title="LayerNorm-FW{}-{}".format("+BW" if backward else "", dtype),
            units="GB/s",
            dash_key="pytorch",
        )


for bw in [False, True]:
    bench_layernorm(bw)
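
Beyond throughput, a quick correctness sketch for the two layers being compared; this assumes a CUDA device, xformers installed, and that FusedLayerNorm uses the same default affine initialization (weight=1, bias=0) as torch.nn.LayerNorm.

import torch
from xformers.triton import FusedLayerNorm

x = torch.randn(2, 64, 128, device="cuda")
ref = torch.nn.LayerNorm([128]).to("cuda")
fused = FusedLayerNorm([128]).to("cuda")

# With identical (default) affine parameters, the outputs should agree
# up to numerical tolerance.
assert torch.allclose(ref(x), fused(x), atol=1e-5)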
91
pkgs/xformers/benchmarks/benchmark_triton_softmax.py
Normal file
@@ -0,0 +1,91 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

import torch

from xformers.benchmarks.utils import TestCase, bench_functions
from xformers.triton.softmax import log_softmax as triton_log_softmax
from xformers.triton.softmax import softmax as triton_softmax

SHAPES = [
    (8, 384, 128),
    (8, 784, 512),
    (4, 1024, 768),
    (4, 2048, 1024),
    (2, 2048, 2048),
    (2, 2048, 4096),
    (2, 4096, 4096),
    (1, 2048, 12288),
]


def pytorch_fw_bw(x):
    y = torch.norm(torch.softmax(x, dim=-1))
    y.backward()


def triton_causal_fw(x):
    _ = triton_softmax(x, causal=True)


def triton_fw_bw(x):
    y = torch.norm(triton_softmax(x))
    y.backward()


def triton_causal_fw_bw(x):
    y = torch.norm(triton_softmax(x, causal=True))
    y.backward()


def pytorch_log_fw_bw(x):
    y = torch.norm(torch.log_softmax(x, dim=-1))
    y.backward()


def triton_log_fw_bw(x):
    y = torch.norm(triton_log_softmax(x))
    y.backward()


# Test FW
def to_gbs_fw(a, ms):
    # Read and write the full array
    return (2 * a.numel() * a.element_size() * 1e-9) / (ms * 1e-3)


def to_gbs_fwbw(a, ms):
    # Same as above, but we do it twice (FW and then the gradient)
    return 2 * to_gbs_fw(a, ms)


bench_functions(
    [
        TestCase(lambda x: torch.softmax(x, dim=-1), "pytorch - fw"),
        TestCase(triton_softmax, "triton - fw"),
        TestCase(triton_causal_fw, "triton - causal - fw"),
        TestCase(lambda x: torch.log_softmax(x, dim=-1), "pytorch - log - fw"),
        TestCase(triton_log_softmax, "triton - log - fw"),
    ],
    SHAPES,
    to_gbs_fw,
    "GB/s",
    "Softmax_Bandwidth_FW_",
)

# Test FW+BW
bench_functions(
    [
        TestCase(pytorch_fw_bw, "pytorch - fw+bw"),
        TestCase(triton_fw_bw, "triton - fw+bw"),
        TestCase(triton_causal_fw_bw, "triton - causal - fw+bw"),
        TestCase(pytorch_log_fw_bw, "pytorch - log - fw+bw"),
        TestCase(triton_log_fw_bw, "triton - log - fw+bw"),
    ],
    SHAPES,
    to_gbs_fwbw,
    "GB/s",
    "Softmax_Bandwidth_FW_BW_",
)
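
For context on the causal variants benchmarked above: a causal softmax masks out future positions before normalizing. A plain PyTorch reference of the expected semantics (this illustrates the behavior, it is not the Triton kernel itself):

import torch


def causal_softmax_reference(x):
    # Mask the strict upper triangle (future positions) of the last two
    # dims with -inf so those entries get zero probability.
    mask = torch.triu(torch.ones_like(x, dtype=torch.bool), diagonal=1)
    return torch.softmax(x.masked_fill(mask, float("-inf")), dim=-1)


x = torch.randn(2, 8, 8)
y = causal_softmax_reference(x)
print(y[0, 0])  # the first row puts all of its mass on the first column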
71
pkgs/xformers/benchmarks/benchmark_triton_stride_sum.py
Normal file
@@ -0,0 +1,71 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

from typing import Any, Dict, List

import torch
import triton

from xformers.benchmarks.utils import TestCase, pretty_plot, pretty_print
from xformers.triton.sum_strided import sum_2d_dim_0

SHAPES = [
    (128, 128),
    (384, 128),
    (784, 512),
    (1024, 768),
    (2048, 1024),
    (4096, 4096),
]


def to_gbs(a, ms):
    # Read the full array, write the non-reduced dimension
    return ((a.numel() + a.shape[1]) * a.element_size() * 1e-9) / (ms * 1e-3)


def bench_functions(
    test_cases: List[TestCase], shapes, metric_transform, unit, title=""
):
    device = torch.device("cuda")

    for dtype in [torch.float16, torch.float32]:
        results: Dict[str, Any] = {}

        for M, N in shapes:
            a = torch.rand(M, N, device=device, dtype=dtype, requires_grad=True)

            for testcase in test_cases:
                time = triton.testing.do_bench(lambda: testcase.function(a))[0]

                metric = metric_transform(a, time)

                key = f"M={M}, N={N}"
                if key not in results:
                    results[key] = {}

                results[key][testcase.name] = f"{metric:.1f}"

        _type = " fp16" if dtype == torch.float16 else " fp32"

        pretty_print(
            results,
            title=" ------------- Type: {} ------------- ".format(_type),
            units=unit,
        )

        pretty_plot(results, title + _type, unit, dash_key="pytorch")


bench_functions(
    [
        TestCase(lambda x: torch.sum(x, dim=0), "pytorch"),
        TestCase(sum_2d_dim_0, "triton"),
    ],
    SHAPES,
    to_gbs,
    "GB/s",
    "Strided_sum",
)
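
The to_gbs metric above counts one full read plus one row-sized write; a worked instance for the largest shape (the 0.2 ms timing is invented for the example):

import torch


def to_gbs(a, ms):
    # Read the full array, write the reduced row of length a.shape[1]
    return ((a.numel() + a.shape[1]) * a.element_size() * 1e-9) / (ms * 1e-3)


# (4096, 4096) in fp16: (16_777_216 + 4_096) * 2 bytes ≈ 33.6 MB moved;
# at a hypothetical 0.2 ms per call that is ~167.8 GB/s.
a = torch.empty(4096, 4096, dtype=torch.float16)
print(f"{to_gbs(a, 0.2):.1f} GB/s")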
660
pkgs/xformers/benchmarks/utils.py
Normal file
@@ -0,0 +1,660 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

import argparse
import contextlib
import copy
import csv
import glob
import logging
import math
import os
import tempfile
from collections import defaultdict, namedtuple
from dataclasses import replace
from typing import Any, Dict, Generator, List, Set, Tuple

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import torch
import tqdm
from torch.utils import benchmark

sns.set()

TestCase = namedtuple("TestCase", ["function", "name"])


_triton_is_available = torch.cuda.is_available()
if _triton_is_available:
    try:
        import triton
    except ImportError as e:
        logging.warning(f"Triton is not available: {e}.\nSome benchmarks will be skipped.")
        _triton_is_available = False


def pretty_print(results, title, units):
    """Print the contents of a dict as a human-readable, Markdown-compatible table"""
    print(title)
    header = " Units: {:<45}".format(units)
    print("| " + header + "|" + "".join("{0:<20}|".format(k) for k in results.keys()))

    offset = len(header)
    print(
        "|-{}|".format("-" * offset)
        + "".join("{}|".format("-" * 20) for _ in results.keys())
    )

    workloads: Dict[str, Any] = {k: [] for v in results.values() for k in v.keys()}
    for v in results.values():
        for k in v.keys():
            workloads[k].append(v[k])

    for k, w in workloads.items():
        print(
            "| {0:<{offset}}|".format(k, offset=offset)
            + "".join("{:<20}|".format(v) for v in w)
        )

    print("")
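
For reference, pretty_print expects a two-level dict keyed by configuration, then by method name, exactly as built by the benchmark files above; a minimal illustrative call (the values are made up):

results = {
    "B=8, M=256, K=512": {"pytorch - fw": "120.0", "triton - fw": "150.0"},
    "B=8, M=512, K=1024": {"pytorch - fw": "130.0", "triton - fw": "170.0"},
}
pretty_print(results, title=" --- Type: torch.float16 --- ", units="GB/s")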

def pretty_plot(
    results, title, units: str, filename=None, dash_key="", legend_loc="lower right"
):
    """Graph out the contents of a dict.
    If a result label contains `dash_key`, its curve is displayed with a dashed line.
    """

    if not filename:
        filename = title + ".png"

    # Sanitize the filename
    filename = (
        filename.replace(" ", "_").replace("/", "_").replace("-", "_").replace(":", "")
    )

    # Gather all the results in "columns"
    workloads: Dict[str, Any] = {k: [] for v in results.values() for k in v.keys()}
    for v in results.values():
        for k in v.keys():
            workloads[k].append(float(v[k]))

    # Make sure that the plot is big enough
    f = plt.figure()
    f.set_figwidth(6)
    f.set_figheight(6)

    # Display the collections
    for k, v in workloads.items():
        if dash_key and dash_key in k:
            plt.plot(list(results.keys()), v, "--")
        else:
            plt.plot(list(results.keys()), v)

    plt.title(title)
    plt.legend(list(workloads.keys()), loc=legend_loc)
    plt.ylabel(units)
    plt.xticks(rotation=45)

    plt.savefig(filename, bbox_inches="tight")
    plt.close(f)


if _triton_is_available:

    def bench_functions(
        test_cases: List[TestCase], shapes, metric_transform, unit, title=""
    ):
        device = torch.device("cuda")

        for dtype in [torch.bfloat16, torch.float16, torch.float32]:
            results: Dict[str, Any] = {}

            for B, M, K in shapes:
                a = torch.rand(B, M, K, device=device, dtype=dtype, requires_grad=True)

                for testcase in test_cases:
                    time = triton.testing.do_bench(lambda: testcase.function(a))[0]

                    metric = metric_transform(a, time)

                    key = f"B={B}, M={M}, K={K}"
                    if key not in results:
                        results[key] = {}

                    results[key][testcase.name] = f"{metric:.1f}"

            pretty_print(
                results,
                title=" ------------- Type: {} ------------- ".format(dtype),
                units=unit,
            )
            pretty_plot(results, title + str(dtype), unit, dash_key="pytorch")
def pretty_barplot(results, title, units: str, filename=None, dash_key=""):
|
||||
"""Graph out the contents of a dict.
|
||||
Dash key means that if the result label has this key, then it will be displayed with a dash
|
||||
"""
|
||||
|
||||
if not filename:
|
||||
filename = title + ".png"
|
||||
|
||||
# Sanitize the filename
|
||||
filename = (
|
||||
filename.replace(" ", "_").replace("/", "_").replace("-", "_").replace(":", "")
|
||||
)
|
||||
|
||||
xlabels = list(results.keys())
|
||||
# Gather all the results in "collumns"
|
||||
workloads: Dict[str, Any] = {k: [] for v in results.values() for k in v.keys()}
|
||||
for v in results.values():
|
||||
for k in v.keys():
|
||||
workloads[k].append(float(v[k]))
|
||||
|
||||
options = list(workloads.keys())
|
||||
group_len = len(options)
|
||||
for key in workloads.keys():
|
||||
num_groups = len(workloads[key])
|
||||
break
|
||||
group_width = group_len + 1
|
||||
|
||||
# Make sure that the plot is big enough
|
||||
f = plt.figure()
|
||||
f.set_figwidth(6)
|
||||
f.set_figheight(6)
|
||||
|
||||
for idx in range(group_len):
|
||||
option = options[idx]
|
||||
values = workloads[option]
|
||||
xloc = np.arange(1 + idx, group_width * num_groups, group_width)
|
||||
plt.bar(xloc, values, width=1, edgecolor="black")
|
||||
|
||||
plt.title(title)
|
||||
plt.legend(list(workloads.keys()), loc="upper right")
|
||||
plt.ylabel(units)
|
||||
|
||||
ax = plt.gca()
|
||||
xticks_loc = np.arange(
|
||||
1 + (group_len - 1) / 2.0, group_width * num_groups, group_width
|
||||
)
|
||||
ax.set_xticks(xticks_loc, xlabels)
|
||||
plt.xticks(rotation=45)
|
||||
|
||||
plt.setp(ax.xaxis.get_majorticklabels(), ha="right")
|
||||
ax.set_axisbelow(True)
|
||||
ax.yaxis.grid(color="gray", linestyle="dashed")
|
||||
ax.xaxis.grid(color="gray", linestyle="dashed")
|
||||
|
||||
plt.savefig(filename, bbox_inches="tight")
|
||||
plt.close(f)
|
||||
|
||||
|
||||

def rmf(filename: str) -> None:
    """Remove a file, like `rm -f`."""
    try:
        os.remove(filename)
    except FileNotFoundError:
        pass


@contextlib.contextmanager
def temp_files_ctx(num: int) -> Generator:
    """A context to get tempfiles and ensure they are cleaned up."""
    files = [tempfile.mkstemp()[1] for _ in range(num)]

    yield tuple(files)

    # temp files could have been removed, so we use rmf.
    for name in files:
        rmf(name)
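
A short usage sketch for the context manager above (the checkpoint payloads are hypothetical):

import torch

with temp_files_ctx(num=2) as (ckpt_a, ckpt_b):
    # Both paths exist inside the block and are removed on exit,
    # even if a file was already deleted within the block.
    torch.save({"step": 0}, ckpt_a)
    torch.save({"step": 1}, ckpt_b)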

META_ALGORITHM = "algorithm"
BASELINE_DESCRIPTIONS = ["eager", "vanilla", "pytorch"]


# Serialize/deserialize to CSV.
# We could use pkl, but resort to CSV for readability
def _benchmark_results_from_csv(filename: str) -> List[Tuple[Dict[str, Any], Any]]:
    parts = os.path.basename(filename).split(".")
    env = ""
    description = ""
    if len(parts) == 3:
        env = parts[1]
        description = parts[0]

    data = []
    with open(filename, "r") as csvfile:
        reader = csv.DictReader(csvfile)
        for row in reader:
            if description != "" and row["description"] not in BASELINE_DESCRIPTIONS:
                row["description"] = description
            task_spec = benchmark.utils.common.TaskSpec(
                stmt="",
                setup="",
                global_setup="",
                label=row["label"],
                sub_label=row["sub_label"],
                description=row["description"],
                env=env,
                num_threads=int(row["num_threads"]),
            )
            measurement = benchmark.utils.common.Measurement(
                number_per_run=1,
                raw_times=[float(row["runtime_us"]) / (1000.0 * 1000)],
                task_spec=task_spec,
            )
            measurement.mem_use = float(row["mem_use_mb"])  # type: ignore
            data.append(
                (
                    {
                        META_ALGORITHM: row["algorithm"]
                        if row["algorithm"] != ""
                        else None,
                    },
                    measurement,
                )
            )
    return data


def _benchmark_results_to_csv(
    filename: str, results: List[Tuple[Dict[str, Any], Any]]
) -> None:
    data = [
        {
            "sub_label": r.task_spec.sub_label,
            "label": r.task_spec.label,
            "num_threads": r.task_spec.num_threads,
            "algorithm": metadata.get(META_ALGORITHM, ""),
            "description": r.task_spec.description
            if r.task_spec.description in BASELINE_DESCRIPTIONS
            else "",
            "runtime_us": int(1000 * 1000 * r.mean),
            "mem_use_mb": r.mem_use,
        }
        for metadata, r in results
    ]
    with open(filename, "w+", newline="") as csvfile:
        writer = csv.DictWriter(csvfile, fieldnames=list(data[0].keys()))
        writer.writeheader()
        for d in data:
            writer.writerow(d)
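
The loader above infers the run label and environment from a `<label>.<env>.csv` file name, the same convention used when results are saved further down; a small sketch with a made-up file name:

import os

filename = "optimized.NVIDIA_A100_SXM4_80GB.csv"  # hypothetical example
parts = os.path.basename(filename).split(".")
if len(parts) == 3:
    description, env = parts[0], parts[1]
    print(description, env)  # -> optimized NVIDIA_A100_SXM4_80GB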

def _finalize_results(results: List[Tuple[Dict[str, Any], Any]]) -> List[Any]:
    """
    Returns the list of measurements to feed to `benchmark.Compare`, except that
    if we have runs with different algorithms, we also add the algorithm name
    in the column titles
    """
    all_algorithms: Set[str] = set()
    all_description: Set[str] = set()
    for metadata, r in results:
        algo = metadata.get(META_ALGORITHM, None)
        if algo is not None:
            all_algorithms.add(algo)
        all_description.add(r.task_spec.description)
    display_algo = len(all_algorithms) > 1
    display_descr = len(all_description) > 1

    display_results = []
    for metadata, r in results:
        algo = metadata.get(META_ALGORITHM, None)
        if algo is None:
            display_results.append(r)
        else:
            r = copy.copy(r)
            description = ""
            if display_descr:
                description = r.task_spec.description
            if display_algo:
                if display_descr:
                    description += "["
                description += algo
                if display_descr:
                    description += "]"
            r.task_spec = replace(r.task_spec, description=description)
            display_results.append(r)
    return display_results

def _render_bar_plot(results: List[Any], store_results_folder: str) -> None:
    if not results:
        return
    runtime: Dict[str, Dict[str, float]] = defaultdict(dict)
    memory_usage: Dict[str, Dict[str, float]] = defaultdict(dict)
    all_descriptions: List[str] = []
    for r in results:
        # Hacky: use a list to preserve order
        if r.task_spec.description not in all_descriptions:
            if r.task_spec.description in BASELINE_DESCRIPTIONS:
                all_descriptions.insert(0, r.task_spec.description)
            else:
                all_descriptions.append(r.task_spec.description)
        runtime[r.task_spec.sub_label][r.task_spec.description] = r.mean
        memory_usage[r.task_spec.sub_label][r.task_spec.description] = r.mem_use
    all_data_mem: List[Any] = []
    all_data_run: List[Any] = []
    for key, runtime_values in runtime.items():
        memory_values = memory_usage[key]
        denom = memory_values.get(all_descriptions[0], math.inf)
        if denom == 0:
            all_data_mem.append([key] + [0] * len(all_descriptions))
        else:
            all_data_mem.append(
                [key] + [memory_values.get(d, 0) / denom for d in all_descriptions]
            )
        all_data_run.append(
            [key]
            + [
                runtime_values.get(all_descriptions[0], 0)
                / runtime_values.get(d, math.inf)
                for d in all_descriptions
            ]
        )
    if all_descriptions[0] == "":
        all_descriptions[0] = "baseline"
    else:
        all_descriptions[0] = f"{all_descriptions[0]} (baseline)"

    for data, filename, title in [
        (all_data_mem, "mem.png", "Memory usage (vs baseline, lower is better)"),
        (
            all_data_run,
            "runtime.png",
            "Runtime speedup (vs baseline, higher is better)",
        ),
    ]:
        df = pd.DataFrame(data, columns=["Configuration"] + all_descriptions)
        df.plot(
            x="Configuration",
            kind="bar",
            stacked=False,
            title=title,
        )
        plt.tight_layout()
        filename_full = os.path.join(store_results_folder, filename)
        plt.savefig(filename_full)
        print(f"Saved plot: {filename_full}")

def benchmark_main_helper(benchmark_fn, cases: List[Dict[str, Any]], **kwargs) -> None:
    """
    Helper function to run benchmarks.
    Supports loading previous results for comparison, and saving current results to file.
    """

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--fn", default=None, type=str, help="Only benchmark this function"
    )
    parser.add_argument(
        "--label", default=None, type=str, help="Store results to a file"
    )
    parser.add_argument(
        "--fail_if_regression",
        action="store_true",
        help="Enabled in CI to check against performance regressions",
    )
    parser.add_argument(
        "--compare",
        default=None,
        type=str,
        help="Compare to previously stored benchmarks (comma separated)",
    )
    parser.add_argument(
        "--omit-baselines",
        action="store_true",
        help="Do not run the (potentially slow) baselines",
    )
    parser.add_argument(
        "--quiet",
        action="store_true",
        help="Skip intermediate results and progress bar",
    )
    args = parser.parse_args()

    if args.fn is not None and args.fn != benchmark_fn.__name__:
        print(f'Skipping benchmark "{benchmark_fn.__name__}"')
        return
    benchmark_run_and_compare(
        benchmark_fn=benchmark_fn,
        cases=cases,
        optimized_label="optimized" if args.label is None else args.label,
        fail_if_regression=args.fail_if_regression,
        compare=args.compare.split(",") if args.compare is not None else [],
        quiet=args.quiet,
        omit_baselines=args.omit_baselines,
        **kwargs,
    )
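
benchmark_fn is expected to be a generator that yields one torch.utils.benchmark.Timer per candidate implementation for a given case; descriptions in BASELINE_DESCRIPTIONS are treated as the baseline, everything else as an algorithm name. A minimal sketch (the function, labels and shapes are illustrative, not from this diff):

import torch
from torch.utils import benchmark


def my_benchmark(shape):  # hypothetical benchmark_fn
    x = torch.randn(shape, device="cuda")
    for description in ["eager", "my_algorithm"]:
        yield benchmark.Timer(
            stmt="torch.softmax(x, dim=-1)",
            globals={"torch": torch, "x": x},
            label="softmax",
            sub_label=f"shape={tuple(shape)}",
            description=description,
        )


cases = [{"shape": (8, 128, 128)}, {"shape": (8, 512, 512)}]
# benchmark_main_helper(my_benchmark, cases)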

def benchmark_run_and_compare(
    benchmark_fn,
    cases: List[Dict[str, Any]],
    compare: List[str],
    omit_baselines: bool = False,
    fail_if_regression: bool = False,
    quiet: bool = False,
    optimized_label: str = "optimized",
    *,
    min_run_time: int = 2,
    atol_s: float = 30e-6,
    rtol: float = 0.05,
) -> None:
    SKIP_VANILLA_TASKS_IF_ALREADY_DONE = True
    results_compare_to = []
    results = []

    store_results_folder = os.path.expanduser(
        os.path.join(
            os.environ.get(
                "XFORMERS_BENCHMARKS_CACHE",
                os.path.join("~", ".cache", "xformers", "benchmarks"),
            ),
            benchmark_fn.__name__,
        )
    )

    try:
        env = (
            torch.cuda.get_device_name(torch.cuda.current_device())
            .replace(" ", "_")
            .replace("-", "_")
            .replace(".", "_")
        )
    except (RuntimeError, AssertionError):  # No GPU
        env = "cpu"
    assert (
        "." not in optimized_label
    ), f"label=`{optimized_label}` should not contain dots"
    assert "." not in env, f"env=`{env}` should not contain dots"

    os.makedirs(store_results_folder, exist_ok=True)

    # Load runs that we want to compare to
    skip_vanilla_tasks = set()
    for cmp_name in compare:
        name_with_env = cmp_name if "." in cmp_name else f"{cmp_name}.*"
        for filename in glob.glob(
            os.path.join(store_results_folder, f"{name_with_env}.csv")
        ):
            loaded = _benchmark_results_from_csv(filename)
            for m, r in loaded:
                if r.task_spec.env == env and SKIP_VANILLA_TASKS_IF_ALREADY_DONE:
                    skip_vanilla_tasks.add(
                        (r.task_spec.sub_label, r.task_spec.num_threads)
                    )
            results_compare_to += loaded

    if not quiet:
        pbar = tqdm.tqdm(cases, leave=False)
        cases = pbar
    for case in cases:
        if quiet:
            print(str(case))
        else:
            pbar.write(f"====== {str(case)} ======")
        try:
            benchmarks_generator = benchmark_fn(**case)
        except NotImplementedError:
            # pbar.write("Skipped (NotImplementedError)")
            continue
        except RuntimeError as e:
            if "CUDA out of memory" not in str(e):
                raise
            if not quiet:
                pbar.write("Skipped (OOM)")
            continue

        name = None
        try:
            for benchmark_object in benchmarks_generator:
                is_optimized = (
                    benchmark_object._task_spec.description not in BASELINE_DESCRIPTIONS
                )
                metadata = {}
                if is_optimized:
                    metadata[META_ALGORITHM] = benchmark_object._task_spec.description
                    benchmark_object._task_spec = replace(
                        benchmark_object._task_spec, description=optimized_label
                    )
                elif (
                    omit_baselines
                    or (
                        benchmark_object._task_spec.sub_label,
                        benchmark_object._task_spec.num_threads,
                    )
                    in skip_vanilla_tasks
                ):
                    continue

                memory = math.inf
                try:
                    torch.cuda.synchronize()
                    torch.cuda.reset_peak_memory_stats()
                    mem_begin = torch.cuda.max_memory_allocated() / 2**20
                    benchmark_object._task_spec = replace(
                        benchmark_object._task_spec, env=env
                    )
                    measurement = benchmark_object.blocked_autorange(
                        min_run_time=min_run_time
                    )
                    torch.cuda.synchronize()
                    results.append((metadata, measurement))
                    name = measurement.task_spec.description
                    memory = torch.cuda.max_memory_allocated() / 2**20 - mem_begin
                    measurement.mem_use = memory
                except RuntimeError as e:
                    if "CUDA out of memory" not in str(e):
                        raise
                    if not quiet:
                        pbar.write("Skipped (OOM)")
                finally:
                    del benchmark_object
                if not quiet:
                    pbar.write(f"{name}: memory used: {memory} MB")
        except RuntimeError as e:
            if "CUDA out of memory" not in str(e):
                raise
            if not quiet:
                pbar.write("Skipped (OOM)")
        # Display results for benchmarks we just calculated
        if name is not None and not quiet:

            def matches_current(r):
                return (
                    r[1].task_spec.sub_label == results[-1][1].task_spec.sub_label
                    and r[1].task_spec.label == results[-1][1].task_spec.label
                )

            pbar.write(
                str(
                    benchmark.Compare(
                        _finalize_results(
                            list(filter(matches_current, results))
                            + list(filter(matches_current, results_compare_to))
                        )
                    )
                )
            )

    results_for_print = _finalize_results(results + results_compare_to)
    benchmark.Compare(results_for_print).print()
    _render_bar_plot(results_for_print, store_results_folder)

    # Save runs to a file
    if results and optimized_label is not None:
        write_to_path = os.path.join(
            store_results_folder, f"{optimized_label}.{env}.csv"
        )
        _benchmark_results_to_csv(write_to_path, results)
        print(f"Saved results to {write_to_path}")

    if fail_if_regression:
        _fail_if_regressions(
            results, reference=results_compare_to, atol_s=atol_s, rtol=rtol
        )

def _fail_if_regressions(
    results: List[Any], reference: List[Any], atol_s: float, rtol: float
) -> None:
    def get_measurement_id(r):
        return (
            r[0].get(META_ALGORITHM, ""),
            r[1].task_spec.label,
            r[1].task_spec.sub_label,
            r[1].task_spec.env,
        )

    id_to_result = {}
    for r in results:
        id_to_result[get_measurement_id(r)] = r[1]

    num_better = 0
    num_worse = 0
    num_nochange = 0
    num_unk = 0
    reference_set = set()
    for ref in reference:
        if ref[1].task_spec.description in BASELINE_DESCRIPTIONS:
            continue
        benchmark_id = get_measurement_id(ref)
        if benchmark_id in reference_set:
            raise ValueError(f"Duplicate benchmark in reference for {benchmark_id}")
        reference_set.add(benchmark_id)
        if benchmark_id not in id_to_result:
            num_unk += 1
            continue
        res = id_to_result[benchmark_id]
        # If the change is significant
        if abs(ref[1].mean - res.mean) - rtol * ref[1].mean > atol_s:
            is_now_better = res.mean < ref[1].mean
            if is_now_better:
                num_better += 1
            else:
                num_worse += 1
            cmp = "IMPROVED" if is_now_better else "REGRESS "
            print(cmp, benchmark_id, f"ref={ref[1].mean}", f"now={res.mean}")
        else:
            num_nochange += 1

    print("Regression test summary:")
    print(f" Better : {num_better}")
    print(f" No change: {num_nochange}")
    print(f" Worse : {num_worse}")
    if num_unk > 0:
        print(f" (no ref) : {num_unk}")
    # Fail on any regression, matching the error message below
    if num_worse > 0:
        raise RuntimeError("At least one benchmark regressed!")
    if num_nochange == 0:
        raise RuntimeError("No reference found")
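
A measurement is only flagged when |ref - now| - rtol * ref exceeds atol_s, so tiny absolute changes on very fast kernels are ignored. A worked instance with the defaults rtol=0.05 and atol_s=30e-6 (the timings are invented):

atol_s, rtol = 30e-6, 0.05

ref_mean, new_mean = 1.00e-3, 1.09e-3  # hypothetical mean runtimes, in seconds
# 0.09 ms - 0.05 ms = 0.04 ms, which is above the 30 us tolerance,
# so this 9% slowdown would be reported as a regression.
print(abs(ref_mean - new_mean) - rtol * ref_mean > atol_s)  # True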