# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.


import argparse
import json
import logging
import os
from enum import Enum
from pathlib import Path
from typing import Dict, Tuple, cast

import pytorch_lightning as pl
import torch
import torch.nn as nn
from fvcore.nn import FlopCountAnalysis, flop_count_str
from pytorch_lightning.callbacks import ModelCheckpoint, TQDMProgressBar
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.strategies import DDPStrategy
from torch.utils.data import DataLoader

from xformers.benchmarks.LRA.code.dataset import LRADataset
from xformers.benchmarks.LRA.code.model_wrapper import ModelForSC, ModelForSCDual
from xformers.components.attention import ATTENTION_REGISTRY


class Task(str, Enum):
    Retrieval = "retrieval"
    ListOps = "listops"
    Image = "image"
    PathfinderBaseline = "pathfinder32-curv_baseline"
    PathfinderContour9 = "pathfinder32-curv_contour_length_9"
    PathfinderContour14 = "pathfinder32-curv_contour_length_14"
    Text = "text"


def load_config(path: str) -> Dict:
    with open(Path(path).absolute(), "r") as fileio:
        config = json.load(fileio)

    # Duplicate the pathfinder configs: the three variants share the same settings
    config["pathfinder32-curv_baseline"] = config["pathfinder32"]
    config["pathfinder32-curv_contour_length_9"] = config["pathfinder32"]
    config["pathfinder32-curv_contour_length_14"] = config["pathfinder32"]

    return config
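

# Illustrative sketch of the config.json shape this script assumes. Only the keys
# read further down (seq_len, batch_size, gradient_accumulation, learning_rate,
# seed, eval_frequency, num_eval_steps, num_train_steps, mixed_precision) are
# listed; the attention-specific model settings are omitted and the values are
# placeholders, not tuned hyperparameters.
_EXAMPLE_CONFIG = {
    "text": {
        "model": {"common": {"seq_len": 4096}},
        "training": {
            "batch_size": 32,
            "gradient_accumulation": 1,
            "learning_rate": 1e-4,
            "seed": 0,  # optional, defaults to 0 in benchmark()
            "eval_frequency": 500,
            "num_eval_steps": 100,
            "num_train_steps": 20000,
            "mixed_precision": True,
        },
    },
    # "listops", "image", "retrieval", "pathfinder32", ... follow the same layout
}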


def build_model(args: argparse.Namespace, config: Dict) -> nn.Module:
    task = args.task
    attention_name = args.attention

    # The retrieval task feeds two sequences per sample, hence the dedicated wrapper
    model = cast(
        pl.LightningModule,
        ModelForSCDual(config[f"{task}"], attention_name)
        if task == Task.Retrieval
        else ModelForSC(config[f"{task}"], attention_name),
    )

    logging.info(model)

    summary = pl.utilities.model_summary.LayerSummary(model)
    logging.info(f"num_parameter: {summary.num_parameters // 1e3 / 1e3}M")

    with torch.no_grad():
        # Check the flops
        seq_len = config[f"{task}"]["model"]["common"]["seq_len"]
        x = torch.rand(1, seq_len).long()
        mask = torch.rand(1, seq_len).long()
        indices = torch.rand(1, seq_len).long()

        flops = FlopCountAnalysis(model.model, (x, mask, indices))
        logging.info(f"complexity: {round(flops.total()/1e9, 3)} GFlops")
        logging.info(flop_count_str(flops))

    return model


def get_arg_parser():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--attention",
        type=str,
        help=(
            f"Attention mechanism to choose, among {list(ATTENTION_REGISTRY.keys())}. "
            "A list can be passed to test several mechanisms in sequence"
        ),
        dest="attention",
        required=True,
    )
    parser.add_argument(
        "--task",
        type=Task,
        help=f"Task to choose, among {[t.value for t in Task]}.",
        dest="task",
        required=True,
    )
    parser.add_argument(
        "--skip_train",
        type=bool,
        help="Whether to skip training, and test an existing model",
        dest="skip_train",
        default=False,
    )
    parser.add_argument(
        "--config",
        type=str,
        help="Path to the config being used",
        dest="config",
        default="./config.json",
    )
    parser.add_argument(
        "--checkpoint_dir",
        type=str,
        help="Path to the checkpoint directory",
        dest="checkpoint_dir",
        default=f"/checkpoints/{os.getenv('USER')}/xformers",
    )
    parser.add_argument(
        "--checkpoint_path",
        type=str,
        help="Path to checkpoint",
    )
    parser.add_argument(
        "--debug",
        help="Make it easier to debug a possible issue",
        dest="debug",
        default=False,
        action="store_true",
    )
    parser.add_argument(
        "--world_size",
        help="Number of GPUs used",
        dest="world_size",
        type=int,
        default=1,
    )
    parser.add_argument(
        "--sweep_parameters",
        help="Rewrite some hyperparameters in the config",
        dest="sweep_parameters",
        type=dict,
        default=None,
    )
    return parser


def setup_log(args, attention_name, task) -> Tuple[str, TensorBoardLogger]:
    experiment_name = f"{task}__{attention_name}"
    logger = TensorBoardLogger(
        save_dir=args.checkpoint_dir,
        name="",  # remove lightning_logs subdirectory
        version=experiment_name,
    )
    log_dir = os.path.join(logger._save_dir, experiment_name)
    return log_dir, logger


def rewrite_hyper(config, rewrites):
    # Rewrite keys use ":" as a separator to address nested config entries,
    # e.g. "training:learning_rate" targets config["training"]["learning_rate"]
    def replace(config_dict, k, v):
        if len(k.split(":")) == 1:
            config_dict[k] = v
            return

        first_key = k.split(":")[0]
        assert first_key in config_dict, first_key
        k = k[len(first_key) + 1 :]
        replace(config_dict[first_key], k, v)

    for k, v in rewrites.items():
        replace(config, k, v)
    return config
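

# Illustrative example of the rewrite syntax expected via --sweep_parameters:
# intermediate keys must already exist (checked by the assert above), while the
# leaf value is set unconditionally. The dict below is a placeholder, not a real
# task config.
#
#   config_task = {"training": {"learning_rate": 1e-4, "batch_size": 32}}
#   rewrite_hyper(config_task, {"training:learning_rate": 1e-3})
#   # -> {"training": {"learning_rate": 1e-3, "batch_size": 32}}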


def build_dataloaders(
    args: argparse.Namespace,
    config_training: Dict,
    num_workers: int = 4,
) -> Dict[str, DataLoader]:
    datasets = {}
    for component in ("train", "dev", "test"):
        datasets[component] = LRADataset(
            file_path=f"datasets/{args.task}.{component}.pickle",
            seq_len=config_training["seq_len"],
        )

    # Gradient accumulation
    accumu_steps = config_training["gradient_accumulation"]
    logging.info(f"accumu_steps={accumu_steps}")

    # Batch size
    per_gpu_batch_size = (
        config_training["batch_size"] // args.world_size // accumu_steps
    )
    logging.warning(
        f"Requested batch size: {config_training['batch_size']}. "
        f"Given world size and grad accumulation, per-gpu batch is {per_gpu_batch_size}"
    )

    dataloaders = {
        k: DataLoader(
            v,
            batch_size=per_gpu_batch_size,
            shuffle=False,
            pin_memory=True,
            num_workers=num_workers,
        )
        for k, v in datasets.items()
    }

    return dataloaders


def get_eval_summary(trainer: pl.Trainer) -> Dict[str, float]:
    eval_summary: Dict[str, float] = {"train_step_idx": trainer.global_step}
    for k, v in trainer.callback_metrics.items():
        eval_summary[k] = v.item()
    return eval_summary


class BasicProgressBar(TQDMProgressBar):
    def get_metrics(self, trainer, model):
        items = super().get_metrics(trainer, model)
        items.pop("v_num", None)
        return items


def benchmark(args):
    log_dir, logger = setup_log(args, f"{args.attention}", f"{args.task}")
    args.logger = logger

    config = load_config(args.config)
    config_task = config[f"{args.task}"]

    if args.sweep_parameters is not None:
        logging.info("Replacing hyperparameters")
        rewrite_hyper(config_task, args.sweep_parameters)

    config_training = config_task["training"]
    config_training["seq_len"] = config_task["model"]["common"]["seq_len"]
    logging.info(f"Learning rate: {config_training['learning_rate']}")

    pl.seed_everything(config_training.get("seed", 0))
    dataloaders = build_dataloaders(args, config_training)
    model = build_model(args, config)

    progress_bar = BasicProgressBar()
    checkpoint_callback = ModelCheckpoint(
        monitor="val_accu",
        mode="max",
        dirpath=args.checkpoint_dir,
        filename="{epoch}-{val_accu:.2f}",
        every_n_train_steps=config_training["eval_frequency"],
    )

    trainer = pl.Trainer(
        accelerator="gpu",
        strategy=DDPStrategy(find_unused_parameters=args.debug)
        if not args.skip_train
        else None,
        accumulate_grad_batches=config_training["gradient_accumulation"],
        callbacks=[progress_bar, checkpoint_callback],
        detect_anomaly=args.debug,
        deterministic=True,
        gpus=args.world_size,
        limit_val_batches=config_training["num_eval_steps"],
        logger=logger,
        max_steps=config_training["num_train_steps"],
        num_sanity_val_steps=int(not args.skip_train),
        precision=16 if config_training["mixed_precision"] else 32,
        # eval_frequency is expressed in steps; Lightning expects a fraction of an epoch
        val_check_interval=config_training["eval_frequency"]
        / float(len(dataloaders["train"])),
    )

    if not args.skip_train:
        trainer.fit(
            model,
            train_dataloaders=dataloaders["train"],
            val_dataloaders=dataloaders["dev"],
        )
        ckpt_path = checkpoint_callback.best_model_path
    else:
        ckpt_path = args.checkpoint_path

    trainer.test(
        model,
        dataloaders=dataloaders["test"],
        ckpt_path=ckpt_path,
    )

    eval_summary = get_eval_summary(trainer)
    with open(os.path.join(log_dir, "test_eval_summary.json"), "w") as f:
        logging.info(f"Saving test results at {f.name}")
        json.dump(eval_summary, f)


if __name__ == "__main__":
    parser = get_arg_parser()
    args = parser.parse_args()

    if args.skip_train and args.checkpoint_path is None:
        parser.error("Must provide --checkpoint_path if --skip_train=True")

    benchmark(args)
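
# Example invocations (illustrative). The script name and the "nystrom" attention
# are placeholders: any key from ATTENTION_REGISTRY works, and the LRA pickles are
# expected under ./datasets/ as built by the dataset preparation step.
#
#   python run_tasks.py --attention nystrom --task text --config ./config.json
#
# Testing an existing checkpoint only, without retraining:
#
#   python run_tasks.py --attention nystrom --task text \
#       --skip_train True --checkpoint_path /path/to/checkpoint.ckpt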