First commit
This commit is contained in:
298
pkgs/xformers/benchmarks/LRA/run_tasks.py
Normal file
298
pkgs/xformers/benchmarks/LRA/run_tasks.py
Normal file
@@ -0,0 +1,298 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the BSD license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
from enum import Enum
|
||||
from pathlib import Path
|
||||
from typing import Dict, Tuple, cast
|
||||
|
||||
import pytorch_lightning as pl
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from fvcore.nn import FlopCountAnalysis, flop_count_str
|
||||
from pytorch_lightning.callbacks import ModelCheckpoint, TQDMProgressBar
|
||||
from pytorch_lightning.loggers import TensorBoardLogger
|
||||
from pytorch_lightning.strategies import DDPStrategy
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
from xformers.benchmarks.LRA.code.dataset import LRADataset
|
||||
from xformers.benchmarks.LRA.code.model_wrapper import ModelForSC, ModelForSCDual
|
||||
from xformers.components.attention import ATTENTION_REGISTRY
|
||||
|
||||
|
||||
class Task(str, Enum):
|
||||
Retrieval = "retrieval"
|
||||
ListOps = "listops"
|
||||
Image = "image"
|
||||
PathfinderBaseline = "pathfinder32-curv_baseline"
|
||||
PathfinderContour9 = "pathfinder32-curv_contour_length_9"
|
||||
PathfinderContour14 = "pathfinder32-curv_contour_length_14"
|
||||
Text = "text"
|
||||
|
||||
|
||||
def load_config(path: str) -> Dict:
|
||||
with open(Path(path).absolute(), "r") as fileio:
|
||||
config = json.load(fileio)
|
||||
|
||||
# Duplicate the pathfinder configs
|
||||
config["pathfinder32-curv_baseline"] = config["pathfinder32"]
|
||||
config["pathfinder32-curv_contour_length_9"] = config["pathfinder32"]
|
||||
config["pathfinder32-curv_contour_length_14"] = config["pathfinder32"]
|
||||
return config
|
||||
|
||||
|
||||
def build_model(args: argparse.Namespace, config: Dict) -> nn.Module:
|
||||
task = args.task
|
||||
attention_name = args.attention
|
||||
|
||||
model = cast(
|
||||
pl.LightningModule,
|
||||
ModelForSCDual(config[f"{task}"], attention_name)
|
||||
if task == Task.Retrieval
|
||||
else ModelForSC(config[f"{task}"], attention_name),
|
||||
)
|
||||
|
||||
logging.info(model)
|
||||
summary = pl.utilities.model_summary.LayerSummary(model)
|
||||
logging.info(f"num_parameter: {summary.num_parameters // 1e3 / 1e3}M")
|
||||
|
||||
with torch.no_grad():
|
||||
# Check the flops
|
||||
seq_len = config[f"{task}"]["model"]["common"]["seq_len"]
|
||||
x = torch.rand(1, seq_len).long()
|
||||
mask = torch.rand(1, seq_len).long()
|
||||
indices = torch.rand(1, seq_len).long()
|
||||
flops = FlopCountAnalysis(model.model, (x, mask, indices))
|
||||
logging.info(f"complexity: {round(flops.total()/1e9, 3)} GFlops")
|
||||
logging.info(flop_count_str(flops))
|
||||
|
||||
return model
|
||||
|
||||
|
||||
def get_arg_parser():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--attention",
|
||||
type=str,
|
||||
help=f"Attention mechanism to chose, among {list(ATTENTION_REGISTRY.keys())}. \
|
||||
A list can be passed to test several mechanisms in sequence",
|
||||
dest="attention",
|
||||
required=True,
|
||||
)
|
||||
parser.add_argument(
|
||||
"--task",
|
||||
type=Task,
|
||||
help=f"Task to chose, among {[t.value for t in Task]}.",
|
||||
dest="task",
|
||||
required=True,
|
||||
)
|
||||
parser.add_argument(
|
||||
"--skip_train",
|
||||
type=bool,
|
||||
help="Whether to skip training, and test an existing model",
|
||||
dest="skip_train",
|
||||
default=False,
|
||||
)
|
||||
parser.add_argument(
|
||||
"--config",
|
||||
type=str,
|
||||
help="Path to the config being used",
|
||||
dest="config",
|
||||
default="./config.json",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--checkpoint_dir",
|
||||
type=str,
|
||||
help="Path to the checkpoint directory",
|
||||
dest="checkpoint_dir",
|
||||
default=f"/checkpoints/{os.getenv('USER')}/xformers",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--checkpoint_path",
|
||||
type=str,
|
||||
help="Path to checkpoint",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--debug",
|
||||
help="Make it easier to debug a possible issue",
|
||||
dest="debug",
|
||||
default=False,
|
||||
action="store_true",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--world_size",
|
||||
help="Number of GPUs used",
|
||||
dest="world_size",
|
||||
type=int,
|
||||
default=1,
|
||||
)
|
||||
parser.add_argument(
|
||||
"--sweep_parameters",
|
||||
help="Rewrite some hyperparameters in the config",
|
||||
dest="sweep_parameters",
|
||||
type=dict,
|
||||
default=None,
|
||||
)
|
||||
return parser
|
||||
|
||||
|
||||
def setup_log(args, attention_name, task) -> Tuple[str, TensorBoardLogger]:
|
||||
experiment_name = f"{task}__{attention_name}"
|
||||
logger = TensorBoardLogger(
|
||||
save_dir=args.checkpoint_dir,
|
||||
name="", # remove lightning_logs subdirectory
|
||||
version=experiment_name,
|
||||
)
|
||||
log_dir = os.path.join(logger._save_dir, experiment_name)
|
||||
return log_dir, logger
|
||||
|
||||
|
||||
def rewrite_hyper(config, rewrites):
|
||||
def replace(config_dict, k, v):
|
||||
if len(k.split(":")) == 1:
|
||||
config_dict[k] = v
|
||||
return
|
||||
first_key = k.split(":")[0]
|
||||
assert first_key in config_dict, first_key
|
||||
k = k[len(first_key) + 1 :]
|
||||
replace(config_dict[first_key], k, v)
|
||||
|
||||
for k, v in rewrites.items():
|
||||
replace(config, k, v)
|
||||
return config
|
||||
|
||||
|
||||
def build_dataloaders(
|
||||
args: argparse.Namespace,
|
||||
config_training: Dict,
|
||||
num_workers: int = 4,
|
||||
) -> Dict[str, DataLoader]:
|
||||
datasets = {}
|
||||
for component in ("train", "dev", "test"):
|
||||
datasets[component] = LRADataset(
|
||||
file_path=f"datasets/{args.task}.{component}.pickle",
|
||||
seq_len=config_training["seq_len"],
|
||||
)
|
||||
|
||||
# Gradient accumulation
|
||||
accumu_steps = config_training["gradient_accumulation"]
|
||||
logging.info(f"accumu_steps={accumu_steps}")
|
||||
|
||||
# Batch size
|
||||
per_gpu_batch_size = (
|
||||
config_training["batch_size"] // args.world_size // accumu_steps
|
||||
)
|
||||
logging.warning(
|
||||
f"Requested batch size: {config_training['batch_size']}. Given world\
|
||||
size and grad accumulation, per-gpu batch is\
|
||||
{per_gpu_batch_size}"
|
||||
)
|
||||
|
||||
dataloaders = {
|
||||
k: DataLoader(
|
||||
v,
|
||||
batch_size=per_gpu_batch_size,
|
||||
shuffle=False,
|
||||
pin_memory=True,
|
||||
num_workers=num_workers,
|
||||
)
|
||||
for k, v in datasets.items()
|
||||
}
|
||||
return dataloaders
|
||||
|
||||
|
||||
def get_eval_summary(trainer: pl.Trainer) -> Dict[str, float]:
|
||||
eval_summary: Dict[str, float] = {"train_step_idx": trainer.global_step}
|
||||
for k, v in trainer.callback_metrics.items():
|
||||
eval_summary[k] = v.item()
|
||||
return eval_summary
|
||||
|
||||
|
||||
class BasicProgressBar(TQDMProgressBar):
|
||||
def get_metrics(self, trainer, model):
|
||||
items = super().get_metrics(trainer, model)
|
||||
items.pop("v_num", None)
|
||||
return items
|
||||
|
||||
|
||||
def benchmark(args):
|
||||
log_dir, logger = setup_log(args, f"{args.attention}", f"{args.task}")
|
||||
args.logger = logger
|
||||
|
||||
config = load_config(args.config)
|
||||
|
||||
config_task = config[f"{args.task}"]
|
||||
if args.sweep_parameters is not None:
|
||||
logging.info("Replacing hyperparameters")
|
||||
rewrite_hyper(config_task, args.sweep_parameters)
|
||||
|
||||
config_training = config_task["training"]
|
||||
config_training["seq_len"] = config_task["model"]["common"]["seq_len"]
|
||||
logging.info(f"Learning rate: {config_training['learning_rate']}")
|
||||
|
||||
pl.seed_everything(config_training.get("seed", 0))
|
||||
dataloaders = build_dataloaders(args, config_training)
|
||||
|
||||
model = build_model(args, config)
|
||||
|
||||
progress_bar = BasicProgressBar()
|
||||
checkpoint_callback = ModelCheckpoint(
|
||||
monitor="val_accu",
|
||||
mode="max",
|
||||
dirpath=args.checkpoint_dir,
|
||||
filename="{epoch}-{val_accu:.2f}",
|
||||
every_n_train_steps=config_training["eval_frequency"],
|
||||
)
|
||||
|
||||
trainer = pl.Trainer(
|
||||
accelerator="gpu",
|
||||
strategy=DDPStrategy(find_unused_parameters=args.debug)
|
||||
if not args.skip_train
|
||||
else None,
|
||||
accumulate_grad_batches=config_training["gradient_accumulation"],
|
||||
callbacks=[progress_bar, checkpoint_callback],
|
||||
detect_anomaly=args.debug,
|
||||
deterministic=True,
|
||||
gpus=args.world_size,
|
||||
limit_val_batches=config_training["num_eval_steps"],
|
||||
logger=logger,
|
||||
max_steps=config_training["num_train_steps"],
|
||||
num_sanity_val_steps=int(not args.skip_train),
|
||||
precision=16 if config_training["mixed_precision"] else 32,
|
||||
val_check_interval=config_training["eval_frequency"]
|
||||
/ float(len(dataloaders["train"])),
|
||||
)
|
||||
|
||||
if not args.skip_train:
|
||||
trainer.fit(
|
||||
model,
|
||||
train_dataloaders=dataloaders["train"],
|
||||
val_dataloaders=dataloaders["dev"],
|
||||
)
|
||||
ckpt_path = checkpoint_callback.best_model_path
|
||||
else:
|
||||
ckpt_path = args.checkpoint_path
|
||||
|
||||
trainer.test(
|
||||
model,
|
||||
dataloaders=dataloaders["test"],
|
||||
ckpt_path=ckpt_path,
|
||||
)
|
||||
eval_summary = get_eval_summary(trainer)
|
||||
with open(os.path.join(log_dir, "test_eval_summary.json"), "w") as f:
|
||||
logging.info(f"Saving test results at {f.name}")
|
||||
json.dump(eval_summary, f)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = get_arg_parser()
|
||||
args = parser.parse_args()
|
||||
if args.skip_train and args.checkpoint_path is None:
|
||||
raise parser.error("Must provide --checkpoint_path if --skip_train=True")
|
||||
benchmark(args)
|
||||
Reference in New Issue
Block a user