First commit

This commit is contained in:
2025-08-05 19:02:46 +08:00
parent 9efe891f99
commit 99fb9f5cb0
1412 changed files with 203615 additions and 0 deletions

View File

@@ -0,0 +1,4 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

View File

@@ -0,0 +1,96 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
import argparse
import json
import logging
from pathlib import Path
from typing import Any, Dict
if __name__ == "__main__":
    # Collect the final test accuracy of every (attention, task) run found
    # below a checkpoint root, then print them as a table with an average.

    # Get the user requests
    parser = argparse.ArgumentParser(
        "Collect results from a given batch of distributed results"
    )
    parser.add_argument("-ck", "--checkpoint_path", required=True)
    args = parser.parse_args()

    logging.getLogger().setLevel(logging.INFO)

    # Number of trailing error-log lines shown for a run without a result
    ERR_TAIL = 30
    SEPARATOR = "*****************************************************"

    # Go through all the data in the given repo, try to find the end results
    root = Path(args.checkpoint_path)

    # - list all the mechanisms being benchmarked
    results: Dict[str, Any] = {}

    for attention in (p for p in root.iterdir() if p.is_dir()):
        logging.info(f"\nFound results for {attention.stem}")
        results[attention.stem] = {}

        for task in attention.glob("*/test_eval_summary.json"):
            # Every summary file has the same name, so the task is identified
            # by the directory holding it (either "<task>" or "<task>__<attention>").
            task_name = task.parent.name.split("__")[0]
            logging.info(f"Logs found for task: {task_name}")

            # -1 marks a run that left a summary file but no final accuracy
            results[attention.stem][task_name] = -1

            # - collect the individual results
            with open(task, "r") as result_file:
                dct = json.load(result_file)

            if "test_accu_mean" in dct:
                results[attention.stem][task_name] = dct["test_accu_mean"]
                logging.info(
                    f"Final result found for {task_name} at epoch {dct['train_step_idx']}: "
                    f"{results[attention.stem][task_name]}"
                )
                continue

            # - report an error if no result was found, then keep going with
            #   the remaining tasks
            logging.warning(
                f"No result found for {task_name}, showing the error log in {task.parent}"
            )
            err_log = next(task.parent.glob("*.err"), None)
            print(SEPARATOR)
            if err_log is None:
                print(f"No *.err file found in {task.parent}")
            else:
                with open(err_log, "r") as err_file:
                    # Tail of the log, in its original (chronological) order
                    print("".join(err_file.readlines()[-ERR_TAIL:]), end="")
            print(SEPARATOR)

    logging.info(f"\nCollected results: {json.dumps(results, indent=2)}")

    # - reduction: compute the average
    tasks = sorted({t for v in results.values() for t in v})

    if not tasks:
        # Nothing to average or print; avoids dividing by zero below
        logging.warning(f"No task result found under {root}")
    else:
        for att in results:
            # -- fill in the possible gaps
            for t in tasks:
                results[att].setdefault(t, 0.0)

            # -- add the average value
            results[att]["AVG"] = round(
                sum(results[att][t] for t in tasks) / len(tasks), 2
            )

        # - Format as an array, markdown style
        tasks_sort = sorted({*tasks, "AVG"}, reverse=True)
        print(
            "{0:<20}".format("")
            + "".join("{0:<20} ".format(t[:10]) for t in tasks_sort)
        )
        for att in results:
            print(
                "{0:<20}".format(att)
                + "".join("{0:<20} ".format(results[att][t]) for t in tasks_sort)
            )

View File

@@ -0,0 +1,49 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
import argparse
import os
from pathlib import Path
from xformers.benchmarks.LRA.run_tasks import Task
from xformers.components.attention import ATTENTION_REGISTRY
def get_default_shared_folder() -> str:
    """Return the first existing shared checkpoint root, or "." if there is none.

    "/checkpoint" is probed before "/checkpoints".
    """
    candidates = ("/checkpoint", "/checkpoints")
    return next((root for root in candidates if Path(root).is_dir()), ".")
if __name__ == "__main__":
    # Submit one run_with_submitit.py job per (attention, task) pair.
    import subprocess

    # Get the user requests
    parser = argparse.ArgumentParser(
        "Benchmark different attention mechanisms on various sequence lengths"
    )
    parser.add_argument("-c", "--config_path", required=True)
    parser.add_argument("-ck", "--checkpoint_path", required=True)
    parser.add_argument(
        "-a", "--attentions", nargs="+", default=list(ATTENTION_REGISTRY.keys())
    )
    parser.add_argument("-t", "--tasks", nargs="+", default=[t.value for t in Task])
    parser.add_argument(
        "--partition", default="a100", type=str, help="Partition where to submit"
    )
    args = parser.parse_args()

    for attention in args.attentions:
        for task in args.tasks:
            # An argument list (no shell) keeps paths containing spaces or
            # shell metacharacters intact.
            command = [
                "python3",
                "run_with_submitit.py",
                "--attention",
                f"{attention}",
                "--task",
                f"{task}",
                "--config",
                f"{args.config_path}",
                "--checkpoint_dir",
                f"{args.checkpoint_path}/{attention}/{task}",
                "--partition",
                f"{args.partition}",
            ]
            completed = subprocess.run(command, check=False)

            # Best effort: a failed submission is reported but does not stop
            # the remaining ones.
            if completed.returncode != 0:
                print(
                    f"Submission failed for {attention}/{task} "
                    f"(exit code {completed.returncode})"
                )

View File

@@ -0,0 +1,4 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

View File

@@ -0,0 +1,46 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
# CREDITS: Almost as-is from the Nystromformer repo
# https://github.com/mlpen/Nystromformer
import logging
import pickle
import torch
from torch.utils.data.dataset import Dataset
logging.getLogger().setLevel(logging.INFO)
class LRADataset(Dataset):
    """Dataset over a pickled collection of LRA examples.

    Each example is indexed by "input_ids_0", "label" and, optionally,
    "input_ids_1". Items are returned as dicts of tensors with the token ids
    truncated to ``seq_len``.
    """

    def __init__(self, file_path, seq_len):
        # NOTE(review): pickle.load can execute arbitrary code contained in
        # the file; only load dataset files from a trusted source.
        with open(file_path, "rb") as handle:
            self.examples = pickle.load(handle)

        self.seq_len = seq_len
        logging.info(f"Loaded {file_path}... size={len(self.examples)}")

    def __len__(self):
        return len(self.examples)

    def __getitem__(self, i):
        return self.create_inst(self.examples[i], self.seq_len)

    @staticmethod
    def create_inst(inst, seq_len):
        """Convert one raw example into tensors.

        Token ids are cut to the first ``seq_len`` entries; each mask is 1.0
        where the token id is non-zero and 0.0 elsewhere.
        """

        def encode(token_ids):
            ids = torch.tensor(token_ids, dtype=torch.long)[:seq_len]
            return ids, (ids != 0).float()

        output = {}
        output["input_ids_0"], output["mask_0"] = encode(inst["input_ids_0"])

        # The second sequence only exists for two-input examples
        if "input_ids_1" in inst:
            output["input_ids_1"], output["mask_1"] = encode(inst["input_ids_1"])

        output["label"] = torch.tensor(inst["label"], dtype=torch.long)
        return output

View File

@@ -0,0 +1,288 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
# CREDITS: adapted from the Nystromformer repo
# https://github.com/mlpen/Nystromformer
from enum import Enum
from typing import Dict, Union
import pytorch_lightning as pl
import torch
import torch.nn as nn
from xformers.components import build_attention
from xformers.components.multi_head_dispatch import MultiHeadDispatchConfig
from xformers.factory import xFormer, xFormerConfig, xFormerEncoderConfig
from xformers.utils import generate_matching_config
PLOutput = Dict[str, Union[float, torch.Tensor]]
class Pooling(str, Enum):
    """Names of the supported sequence pooling modes."""

    # Average over the sequence dimension
    MEAN = "mean"
    # Keep only the first (CLS) position
    CLS = "cls"
def pooling(mode: Pooling):
    """Return the pooling callable matching ``mode``.

    Both callables take a 3-dimensional tensor: MEAN averages over dim 1,
    CLS keeps index 0 of dim 1. An unknown ``mode`` raises ``KeyError``.
    """
    reducers = {
        Pooling.MEAN: lambda inp: inp.mean(dim=1),
        Pooling.CLS: lambda inp: inp[:, 0, :],
    }
    return reducers[mode]
def append_cls(inp, mask, vocab_size):
    """Insert a CLS token in front of every sequence, keeping the length fixed.

    The CLS id is ``vocab_size - 1``. The last position of ``inp`` and
    ``mask`` is dropped to make room, and the new first position is marked
    as valid (1.0) in the mask.

    Returns:
        The shifted ``(inp, mask)`` pair.
    """
    batch_size = inp.size(0)

    # One column holding the CLS id, one column of ones for the mask
    cls_column = (
        (vocab_size - 1)
        * torch.ones((batch_size, 1), dtype=torch.long, device=inp.device)
    ).long()
    mask_column = torch.ones((batch_size, 1), dtype=torch.float, device=mask.device)

    shifted_inp = torch.cat([cls_column, inp[:, :-1]], dim=-1)
    shifted_mask = torch.cat([mask_column, mask[:, :-1]], dim=-1)
    return shifted_inp, shifted_mask
def patch_model_config(config, attention_name):
    """Specialize the generic model config for one attention mechanism.

    The "common" settings are merged into every sub-config of each entry of
    ``config["xformer"]``, the attention name and head dimension are set, and
    the optional ``config["extra_settings"]["attention"][attention_name]``
    overrides are applied last. ``config`` is modified in place and returned.
    """
    commons = config["common"]

    # Per-attention overrides are optional
    try:
        extra_attention_settings = config["extra_settings"]["attention"][attention_name]
    except KeyError:
        extra_attention_settings = None

    for block in config["xformer"]:
        block["dim_model"] = commons["dim_model"]
        block["position_encoding_config"].update(commons)
        block["feedforward_config"].update(commons)

        multi_head = block["multi_head_config"]
        multi_head.update(commons)

        # Looked up only after the update above, as "common" is merged first
        attention = multi_head["attention"]
        attention.update(commons)
        attention["name"] = attention_name
        # NOTE(review): true division, so dim_head is a float — confirm the
        # attention builders accept that.
        attention["dim_head"] = commons["dim_model"] / commons["num_heads"]
        if extra_attention_settings is not None:
            attention.update(extra_attention_settings)

        # Swap the raw dict for the matching config object, then replace its
        # attention settings by the built attention.
        matched = generate_matching_config(multi_head, MultiHeadDispatchConfig)
        block["multi_head_config"] = matched
        matched.attention = build_attention(matched.attention)

        # The result is not stored back: config["xformer"] keeps the dict entry
        generate_matching_config(block, xFormerEncoderConfig)

    return config
class SCHead(nn.Module):
    """Classification head for one sequence: pooling followed by a 2-layer MLP."""

    def __init__(self, config, dim_embedding, dim_mlp):
        super().__init__()
        self.pooling = pooling(Pooling(config["pooling_mode"]))

        num_classes = config["common"]["num_classes"]
        layers = [
            nn.Linear(dim_embedding, dim_mlp),
            nn.ReLU(),
            nn.Linear(dim_mlp, num_classes),
        ]
        self.mlpblock = nn.Sequential(*layers)

    def forward(self, inp: torch.Tensor):
        """Return the class scores of the pooled sequence."""
        pooled = self.pooling(inp)
        return self.mlpblock(pooled)
class SCHeadDual(nn.Module):
    """Classification head for a pair of sequences.

    Both sequences are pooled, then the two pooled vectors, their product and
    their difference are concatenated (hence ``dim_embedding * 4`` inputs)
    and fed to a 2-layer MLP.
    """

    def __init__(self, config, dim_embedding, dim_mlp):
        super().__init__()
        self.pooling = pooling(Pooling(config["pooling_mode"]))

        num_classes = config["common"]["num_classes"]
        layers = [
            nn.Linear(dim_embedding * 4, dim_mlp),
            nn.ReLU(),
            nn.Linear(dim_mlp, num_classes),
        ]
        self.mlpblock = nn.Sequential(*layers)

    def forward(self, inp_0: torch.Tensor, inp_1: torch.Tensor):
        """Return the class scores for the pair ``(inp_0, inp_1)``."""
        X_0 = self.pooling(inp_0)
        X_1 = self.pooling(inp_1)
        features = torch.cat([X_0, X_1, X_0 * X_1, X_0 - X_1], dim=-1)
        return self.mlpblock(features)
class ModelTrunk(pl.LightningModule):
    """Lightning module shared by the LRA classifiers.

    Builds the xFormer encoder and the final LayerNorm from the task config,
    and implements the optimizer setup plus the train/val/test logging hooks.
    The steps call ``self(**batch)`` and expect a dict with "loss", "accu"
    and "count"; ``forward`` itself is not defined here.
    """

    def __init__(self, config, model_name):
        super().__init__()

        config_model = config["model"]
        self.config_training = config["training"]
        self.enable_amp = config["training"]["mixed_precision"]
        self.pooling_mode = Pooling(config_model["pooling_mode"])
        self.vocab_size = config_model["common"]["vocab_size"]

        # Rebuild a specific config out of generic + extra params
        self.config_model = patch_model_config(config_model, model_name)
        self.model = xFormer.from_config(xFormerConfig(config_model["xformer"]))

        dim_model = self.config_model["common"]["dim_model"]
        self.norm = nn.LayerNorm(dim_model)

        # Width of the classifier MLP, derived from the first encoder block
        ff_config = self.config_model["xformer"][0]["feedforward_config"]
        self.dim_mlp = dim_model * ff_config["hidden_layer_multiplier"]

    def training_step(  # type: ignore
        self, batch: Dict[str, torch.Tensor], batch_idx: int
    ) -> PLOutput:
        outputs = self(**batch)
        train_metrics = {f"train_{k}": v for k, v in outputs.items()}
        self.logger.log_metrics(train_metrics)  # type: ignore
        self.log("train_accu", outputs["accu"], sync_dist=True)
        return outputs

    def training_epoch_end(self, outputs):
        logs = self.eval_epoch_end(outputs)
        self.log("train_accu_mean", logs["accu"], sync_dist=True)

    def configure_optimizers(self):
        """Return AdamW plus a one-cycle learning-rate schedule."""
        training = self.config_training

        optimizer = torch.optim.AdamW(
            self.parameters(),
            lr=training["learning_rate"],
            betas=(0.9, 0.999),
            eps=1e-6,
            weight_decay=training["weight_decay"],
        )

        # NOTE(review): total_steps is expressed in training steps, while a
        # scheduler returned bare like this is presumably stepped once per
        # epoch by Lightning — confirm the intended interval.
        lr_scheduler = torch.optim.lr_scheduler.OneCycleLR(
            optimizer=optimizer,
            max_lr=training["learning_rate"],
            pct_start=training["warmup"] / training["num_train_steps"],
            anneal_strategy=training["lr_decay"],
            total_steps=training["num_train_steps"],
        )

        return [optimizer], [lr_scheduler]

    def eval_step(self, batch: Dict[str, torch.Tensor], batch_idx: int) -> PLOutput:
        return self(**batch)

    def eval_epoch_end(self, outputs, prefix: str = "train"):
        """Log and return the count-weighted means of "accu" and "loss"."""
        counts = torch.tensor([out["count"] for out in outputs]).float()
        total = counts.sum()

        logs = {"count": total}
        for metric in ("accu", "loss"):
            per_batch = torch.tensor([out[metric] for out in outputs])
            logs[metric] = (per_batch * counts).sum() / total
            self.log(f"{prefix}_{metric}_mean", logs[metric], sync_dist=True)
        return logs

    def validation_step(  # type: ignore
        self, batch: Dict[str, torch.Tensor], batch_idx: int
    ) -> PLOutput:
        outputs = self.eval_step(batch, batch_idx)
        val_metrics = {f"val_{k}": v for k, v in outputs.items()}
        self.logger.log_metrics(val_metrics)  # type: ignore
        self.log("val_accu", outputs["accu"], sync_dist=True, prog_bar=True)
        return outputs

    def validation_epoch_end(self, outputs):
        self.eval_epoch_end(outputs, prefix="val")

    def test_step(  # type: ignore
        self, batch: Dict[str, torch.Tensor], batch_idx: int
    ) -> PLOutput:
        return self.eval_step(batch, batch_idx)

    def test_epoch_end(self, outputs):
        self.eval_epoch_end(outputs, prefix="test")
class ModelForSC(ModelTrunk):
    """Single-sequence classifier: trunk encoder followed by an SCHead."""

    def __init__(self, config, model_name):
        # Setup trunk
        super().__init__(config, model_name)

        # The attribute name (typo included) is kept as is: it is part of
        # the parameter names of the module.
        self.seq_classifer = SCHead(
            self.config_model,
            dim_embedding=self.config_model["common"]["dim_model"],
            dim_mlp=self.dim_mlp,
        )

    def forward(  # type: ignore
        self, input_ids_0: torch.Tensor, mask_0: torch.Tensor, label: torch.Tensor
    ):
        """Return the mean loss, mean accuracy and sample count of the batch."""
        if self.pooling_mode == Pooling.CLS:
            input_ids_0, mask_0 = append_cls(input_ids_0, mask_0, self.vocab_size)

        # Encode, normalize, then zero the masked-out positions
        encoded = self.model(input_ids_0, encoder_input_mask=mask_0)
        token_out = self.norm(encoded) * mask_0.unsqueeze(-1)

        seq_scores = self.seq_classifer(token_out)

        criterion = torch.nn.CrossEntropyLoss(reduction="none")
        seq_loss = criterion(seq_scores, label)
        seq_accu = (seq_scores.argmax(dim=-1) == label).to(torch.float32)

        return {
            "loss": seq_loss.mean(),
            "accu": seq_accu.mean(),
            "count": label.size(0),
        }
class ModelForSCDual(ModelTrunk):
    """Two-sequence classifier: shared trunk encoder followed by an SCHeadDual."""

    def __init__(self, config, model_name):
        # Setup trunk
        super().__init__(config, model_name)

        # The attribute name (typo included) is kept as is: it is part of
        # the parameter names of the module.
        self.seq_classifer = SCHeadDual(
            self.config_model,
            dim_embedding=self.config_model["common"]["dim_model"],
            dim_mlp=self.dim_mlp,
        )

    def forward(  # type: ignore
        self,
        input_ids_0: torch.Tensor,
        input_ids_1: torch.Tensor,
        mask_0: torch.Tensor,
        mask_1: torch.Tensor,
        label: torch.Tensor,
    ):
        """Return the mean loss, mean accuracy and sample count of the batch."""
        mask_0 = mask_0.long()
        mask_1 = mask_1.long()

        if self.pooling_mode == Pooling.CLS:
            input_ids_0, mask_0 = append_cls(input_ids_0, mask_0, self.vocab_size)
            input_ids_1, mask_1 = append_cls(input_ids_1, mask_1, self.vocab_size)

        # Concatenate the two inputs into one batch
        input_ids = torch.cat([input_ids_0, input_ids_1], dim=0)
        masks = torch.cat([mask_0, mask_1], dim=0)

        # Encode, normalize, then zero the masked-out positions
        encoded = self.model(input_ids, encoder_input_mask=masks)
        tokens_out = self.norm(encoded) * masks.unsqueeze(-1)

        # Split the joint batch back into its two halves for the head
        halves = torch.chunk(tokens_out, 2, dim=0)
        seq_scores = self.seq_classifer(*halves)

        criterion = torch.nn.CrossEntropyLoss(reduction="none")
        seq_loss = criterion(seq_scores, label)
        seq_accu = (seq_scores.argmax(dim=-1) == label).to(torch.float32)

        return {
            "loss": seq_loss.mean(),
            "accu": seq_accu.mean(),
            "count": label.size(0),
        }

View File

@@ -0,0 +1,148 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
import itertools
import os
import uuid
from datetime import date
from pathlib import Path
from typing import Dict, Iterable
import submitit
from xformers.benchmarks.LRA.run_with_submitit import (
Trainer,
get_init_file,
get_shared_folder,
parse_args,
)
def grid_parameters(grid: Dict):
    """
    Yield all combinations of parameters in the grid (as a dict)

    Each value of ``grid`` is either an iterable of candidates or a single
    value, which is treated as a one-element list. Strings and bytes count
    as single values rather than as iterables of characters. ``grid`` itself
    is left untouched.
    """
    axes = {}
    for key, value in grid.items():
        # Turn single value in an Iterable
        is_single = isinstance(value, (str, bytes)) or not isinstance(value, Iterable)
        axes[key] = [value] if is_single else value

    for combination in itertools.product(*axes.values()):
        yield dict(zip(axes, combination))
def _sweep_grid(task):
    """Return the hyper-parameter sweep of ``task``.

    The result maps a ":"-separated config path to a pair
    ``(values to try, prefix used for that value in the run name)``.
    """
    # Plain equality tests (not a dict lookup) so that both strings and
    # str-based enum members select the right grid.
    if task == "text":
        return {
            "training:learning_rate": ([1e-4, 2e-4, 3e-4, 5e-5], "lr"),
            "training:warmup": ([3000, 8000], "warmup"),
            "training:seed": ([1234, 32, 1994], "seed"),
            "training:weight_decay": ([0.02, 0.05, 0.01], "wd"),
            "model:pooling_mode": (["cls"], "pool-"),
            "model:common:dropout": ([0, 0.05], "drop"),
        }
    if task == "retrieval":
        return {
            "training:learning_rate": ([1e-4, 3e-4], "lr"),
            "training:warmup": ([2000, 8000], "warmup"),
            "training:seed": ([4096, 1234, 3, 15, 5], "seed"),
            "training:weight_decay": ([0.01, 0], "wd"),
            "model:pooling_mode": (["cls"], "pool-"),
            "model:common:dropout": ([0], "drop"),
        }
    if task == "listops":
        return {
            "training:learning_rate": ([1e-4, 2e-4, 3e-4, 5e-5], "lr"),
            "training:warmup": ([3000, 2000], "warmup"),
            "training:seed": ([1234], "seed"),
            # NOTE(review): "0, 1" may be a typo for 0.01 — confirm before
            # relying on the weight decay of 1.
            "training:weight_decay": ([0.02, 0.05, 0, 1], "wd"),
            "model:pooling_mode": (["cls"], "pool-"),
            "model:common:dropout": ([0], "drop"),
        }
    return {
        "training:learning_rate": ([1e-4, 5e-5], "lr"),
        "training:warmup": ([8000], "warmup"),
        "training:seed": ([1234, 4321, 3], "seed"),
        "training:weight_decay": ([0.01], "wd"),
        "model:pooling_mode": (["cls"], "pool-"),
        "model:common:dropout": ([0.1], "drop"),
    }


def grid_search(args):
    """Submit one job per hyper-parameter combination of the task's sweep.

    ``args`` is the namespace returned by ``parse_args``. It is updated in
    place for every run (sweep parameters, checkpoint and tensorboard
    directories, init file) right before that run is submitted.
    """
    if args.checkpoint_dir == "":
        args.checkpoint_dir = get_shared_folder() / "%j"

    date_curr = date.today().strftime("%m-%d-%Y")
    orig_check_dir = os.path.join(args.checkpoint_dir, date_curr)

    # Create the executor
    # Note that the folder will depend on the job_id, to easily track experiments
    executor = submitit.AutoExecutor(
        folder=get_shared_folder() / "%j", slurm_max_num_timeout=30
    )
    args.world_size = args.nodes * args.ngpus
    executor.update_parameters(
        gpus_per_node=args.ngpus,
        tasks_per_node=args.ngpus,  # one task per GPU
        cpus_per_task=10,
        nodes=args.nodes,
        timeout_min=60 * 72,
        slurm_signal_delay_s=120,
        slurm_partition=args.partition,
    )
    executor.update_parameters(name="lra")

    grid_meta = _sweep_grid(args.task)
    grid = {key: values for key, (values, _) in grid_meta.items()}
    hyper_parameters = list(grid_parameters(grid))

    jobs = []
    for i, grid_data in enumerate(hyper_parameters):
        args.sweep_parameters = grid_data

        # The run name encodes every swept value
        run_name = f"{args.attention}"
        for k, v in grid_data.items():
            run_name += "prenorm-" + f"{grid_meta[k][1]}{v}"

        args.checkpoint_dir = os.path.join(
            orig_check_dir, f"{args.task}", "logs", run_name
        )
        Path(args.checkpoint_dir).mkdir(parents=True, exist_ok=True)
        args.tb_dir = os.path.join(orig_check_dir, f"{args.task}", "tb", run_name)
        Path(args.tb_dir).mkdir(parents=True, exist_ok=True)

        # Chronos needs a different job name each time
        executor.update_parameters(name=f"lra_{args.task}_{i:02d}_{uuid.uuid4().hex}")

        args.dist_url = get_init_file().as_uri()
        args.temp_file = str(get_init_file())

        trainer = Trainer(args)
        job = executor.submit(trainer)
        jobs.append(job)
        print(f"Run {i:02d} submitted with train cfg: {args}")

    print(f"Submitted jobs ids: {','.join([str(job.job_id) for job in jobs])}")


if __name__ == "__main__":
    args = parse_args()
    grid_search(args)

View File

@@ -0,0 +1,298 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
import argparse
import json
import logging
import os
from enum import Enum
from pathlib import Path
from typing import Dict, Tuple, cast
import pytorch_lightning as pl
import torch
import torch.nn as nn
from fvcore.nn import FlopCountAnalysis, flop_count_str
from pytorch_lightning.callbacks import ModelCheckpoint, TQDMProgressBar
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.strategies import DDPStrategy
from torch.utils.data import DataLoader
from xformers.benchmarks.LRA.code.dataset import LRADataset
from xformers.benchmarks.LRA.code.model_wrapper import ModelForSC, ModelForSCDual
from xformers.components.attention import ATTENTION_REGISTRY
class Task(str, Enum):
    """Names of the benchmark tasks; each value is also a config key."""

    Retrieval = "retrieval"
    ListOps = "listops"
    Image = "image"
    # The three pathfinder variants
    PathfinderBaseline = "pathfinder32-curv_baseline"
    PathfinderContour9 = "pathfinder32-curv_contour_length_9"
    PathfinderContour14 = "pathfinder32-curv_contour_length_14"
    Text = "text"
def load_config(path: str) -> Dict:
    """Load the JSON config at ``path`` and add the pathfinder aliases.

    The three "pathfinder32-curv_*" keys all refer to the very same object as
    the "pathfinder32" entry (it is shared, not copied).
    """
    config_file = Path(path).absolute()
    with open(config_file, "r") as fileio:
        config = json.load(fileio)

    # Duplicate the pathfinder configs
    pathfinder = config["pathfinder32"]
    for variant in ("curv_baseline", "curv_contour_length_9", "curv_contour_length_14"):
        config[f"pathfinder32-{variant}"] = pathfinder

    return config
def build_model(args: argparse.Namespace, config: Dict) -> nn.Module:
    """Instantiate the model for ``args.task`` and log its size and FLOPs.

    The retrieval task gets the two-input model, every other task the
    single-input one.
    """
    task = args.task
    attention_name = args.attention
    task_config = config[f"{task}"]

    if task == Task.Retrieval:
        model_cls = ModelForSCDual
    else:
        model_cls = ModelForSC
    model = cast(pl.LightningModule, model_cls(task_config, attention_name))

    logging.info(model)
    summary = pl.utilities.model_summary.LayerSummary(model)
    logging.info(f"num_parameter: {summary.num_parameters // 1e3 / 1e3}M")

    with torch.no_grad():
        # Check the flops
        seq_len = task_config["model"]["common"]["seq_len"]

        # Dummy inputs: only their shape matters here
        x = torch.rand(1, seq_len).long()
        mask = torch.rand(1, seq_len).long()
        indices = torch.rand(1, seq_len).long()

        flops = FlopCountAnalysis(model.model, (x, mask, indices))
        logging.info(f"complexity: {round(flops.total()/1e9, 3)} GFlops")
        logging.info(flop_count_str(flops))

    return model
def _parse_bool(value):
    """Interpret a command-line string as a boolean.

    Accepts true/t/yes/y/1 and false/f/no/n/0 (case-insensitive); an empty
    string is False. Anything else is rejected with an argparse error.
    """
    if isinstance(value, bool):
        return value

    lowered = value.strip().lower()
    if lowered in ("true", "t", "yes", "y", "1"):
        return True
    if lowered in ("false", "f", "no", "n", "0", ""):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean, got {value!r}")


def get_arg_parser():
    """Build the command-line parser shared by the LRA benchmark scripts."""
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--attention",
        type=str,
        help=f"Attention mechanism to chose, among {list(ATTENTION_REGISTRY.keys())}. "
        "A list can be passed to test several mechanisms in sequence",
        dest="attention",
        required=True,
    )
    parser.add_argument(
        "--task",
        type=Task,
        help=f"Task to chose, among {[t.value for t in Task]}.",
        dest="task",
        required=True,
    )
    parser.add_argument(
        "--skip_train",
        # type=bool would turn any non-empty string, "False" included, into True
        type=_parse_bool,
        help="Whether to skip training, and test an existing model",
        dest="skip_train",
        default=False,
    )
    parser.add_argument(
        "--config",
        type=str,
        help="Path to the config being used",
        dest="config",
        default="./config.json",
    )
    parser.add_argument(
        "--checkpoint_dir",
        type=str,
        help="Path to the checkpoint directory",
        dest="checkpoint_dir",
        default=f"/checkpoints/{os.getenv('USER')}/xformers",
    )
    parser.add_argument(
        "--checkpoint_path",
        type=str,
        help="Path to checkpoint",
    )
    parser.add_argument(
        "--debug",
        help="Make it easier to debug a possible issue",
        dest="debug",
        default=False,
        action="store_true",
    )
    parser.add_argument(
        "--world_size",
        help="Number of GPUs used",
        dest="world_size",
        type=int,
        default=1,
    )
    parser.add_argument(
        "--sweep_parameters",
        help="Rewrite some hyperparameters in the config, given as a JSON object "
        "mapping ':'-separated config paths to values",
        dest="sweep_parameters",
        # dict() cannot parse a command-line string; JSON can
        type=json.loads,
        default=None,
    )
    return parser
def setup_log(args, attention_name, task) -> Tuple[str, TensorBoardLogger]:
    """Create the TensorBoard logger of one run.

    Returns:
        The log directory (``<save dir>/<task>__<attention_name>``) and the
        logger writing under ``args.checkpoint_dir``.
    """
    experiment_name = "__".join((f"{task}", f"{attention_name}"))

    logger = TensorBoardLogger(
        save_dir=args.checkpoint_dir,
        name="",  # remove lightning_logs subdirectory
        version=experiment_name,
    )

    return os.path.join(logger._save_dir, experiment_name), logger
def rewrite_hyper(config, rewrites):
    """Apply ``rewrites`` to the nested ``config`` in place and return it.

    Each key of ``rewrites`` is a ":"-separated path. Every path element but
    the last one must already exist (``AssertionError`` otherwise); the last
    element is created or overwritten.
    """
    for path, value in rewrites.items():
        *parents, leaf = path.split(":")

        # Walk down to the dict holding the final key
        node = config
        for key in parents:
            assert key in node, key
            node = node[key]

        node[leaf] = value

    return config
def build_dataloaders(
    args: argparse.Namespace,
    config_training: Dict,
    num_workers: int = 4,
) -> Dict[str, DataLoader]:
    """Build the "train", "dev" and "test" dataloaders of ``args.task``.

    The datasets are read from ``datasets/<task>.<split>.pickle``. The
    requested batch size is divided by the world size and by the number of
    gradient accumulation steps to obtain the per-GPU batch size.

    Raises:
        ValueError: if the resulting per-GPU batch size is smaller than 1.
    """
    datasets = {
        component: LRADataset(
            file_path=f"datasets/{args.task}.{component}.pickle",
            seq_len=config_training["seq_len"],
        )
        for component in ("train", "dev", "test")
    }

    # Gradient accumulation
    accumu_steps = config_training["gradient_accumulation"]
    logging.info(f"accumu_steps={accumu_steps}")

    # Batch size
    per_gpu_batch_size = (
        config_training["batch_size"] // args.world_size // accumu_steps
    )
    logging.warning(
        f"Requested batch size: {config_training['batch_size']}. "
        "Given world size and grad accumulation, per-gpu batch is "
        f"{per_gpu_batch_size}"
    )
    if per_gpu_batch_size < 1:
        # Fail with the actual cause instead of a generic DataLoader error
        raise ValueError(
            f"Batch size {config_training['batch_size']} is too small for "
            f"world size {args.world_size} and {accumu_steps} accumulation steps"
        )

    # NOTE(review): the training split is not shuffled either — confirm this
    # is intended.
    dataloaders = {
        k: DataLoader(
            v,
            batch_size=per_gpu_batch_size,
            shuffle=False,
            pin_memory=True,
            num_workers=num_workers,
        )
        for k, v in datasets.items()
    }
    return dataloaders
def get_eval_summary(trainer: pl.Trainer) -> Dict[str, float]:
    """Return the trainer's callback metrics as plain numbers.

    The summary also holds the global step under "train_step_idx".
    """
    eval_summary: Dict[str, float] = {
        "train_step_idx": trainer.global_step,
        **{name: value.item() for name, value in trainer.callback_metrics.items()},
    }
    return eval_summary
class BasicProgressBar(TQDMProgressBar):
    """Progress bar whose metrics do not include the "v_num" entry."""

    def get_metrics(self, trainer, model):
        metrics = super().get_metrics(trainer, model)
        # Drop the entry if it is there, ignore it otherwise
        metrics.pop("v_num", None)
        return metrics
def benchmark(args):
    """Train (unless ``args.skip_train``) and test one attention on one task.

    The test metrics are written to ``test_eval_summary.json`` in the log
    directory of the run. ``args.logger`` is set as a side effect.
    """
    log_dir, logger = setup_log(args, f"{args.attention}", f"{args.task}")
    args.logger = logger

    config = load_config(args.config)
    config_task = config[f"{args.task}"]
    if args.sweep_parameters is not None:
        logging.info("Replacing hyperparameters")
        rewrite_hyper(config_task, args.sweep_parameters)

    config_training = config_task["training"]
    config_training["seq_len"] = config_task["model"]["common"]["seq_len"]
    logging.info(f"Learning rate: {config_training['learning_rate']}")

    # Seed before anything random gets built
    pl.seed_everything(config_training.get("seed", 0))
    dataloaders = build_dataloaders(args, config_training)
    model = build_model(args, config)

    progress_bar = BasicProgressBar()
    checkpoint_callback = ModelCheckpoint(
        monitor="val_accu",
        mode="max",
        dirpath=args.checkpoint_dir,
        filename="{epoch}-{val_accu:.2f}",
        every_n_train_steps=config_training["eval_frequency"],
    )

    training = not args.skip_train
    strategy = DDPStrategy(find_unused_parameters=args.debug) if training else None

    trainer_kwargs = dict(
        accelerator="gpu",
        strategy=strategy,
        accumulate_grad_batches=config_training["gradient_accumulation"],
        callbacks=[progress_bar, checkpoint_callback],
        detect_anomaly=args.debug,
        deterministic=True,
        gpus=args.world_size,
        limit_val_batches=config_training["num_eval_steps"],
        logger=logger,
        max_steps=config_training["num_train_steps"],
        num_sanity_val_steps=int(training),
        precision=16 if config_training["mixed_precision"] else 32,
        # Evaluation frequency expressed as a fraction of the training loader
        val_check_interval=config_training["eval_frequency"]
        / float(len(dataloaders["train"])),
    )
    trainer = pl.Trainer(**trainer_kwargs)

    if training:
        trainer.fit(
            model,
            train_dataloaders=dataloaders["train"],
            val_dataloaders=dataloaders["dev"],
        )
        # Test the best checkpoint found during training
        ckpt_path = checkpoint_callback.best_model_path
    else:
        ckpt_path = args.checkpoint_path

    trainer.test(
        model,
        dataloaders=dataloaders["test"],
        ckpt_path=ckpt_path,
    )

    eval_summary = get_eval_summary(trainer)
    summary_path = os.path.join(log_dir, "test_eval_summary.json")
    with open(summary_path, "w") as f:
        logging.info(f"Saving test results at {f.name}")
        json.dump(eval_summary, f)
if __name__ == "__main__":
    # Run a single benchmark from the command line.
    parser = get_arg_parser()
    args = parser.parse_args()

    if args.skip_train and args.checkpoint_path is None:
        # parser.error() prints the message and exits by itself
        parser.error("Must provide --checkpoint_path if --skip_train=True")

    benchmark(args)

View File

@@ -0,0 +1,153 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
"""
A script to run multinode training with submitit.
Almost copy-paste from https://github.com/facebookresearch/deit/blob/main/run_with_submitit.py
"""
import argparse
import os
import uuid
from pathlib import Path
import submitit
from xformers.benchmarks.LRA.run_tasks import benchmark, get_arg_parser
def parse_args():
    """Parse the command line: benchmark options plus the submitit options."""
    parser = argparse.ArgumentParser(
        "Submitit for LRA", parents=[get_arg_parser()], add_help=False
    )
    add = parser.add_argument

    # Resources to request
    add("--ngpus", type=int, default=1, help="Number of gpus to request on each node")
    add("--nodes", type=int, default=1, help="Number of nodes to request")
    add("--timeout", type=int, default=2800, help="Duration of the job")
    add("--partition", type=str, default="a100", help="Partition where to submit")

    # Optional scheduler settings
    add("--use_volta32", action="store_true", help="Big models? Use this")
    add("--enforce_host_memory", action="store_true", help="Use if the host OOMs")
    add(
        "--comment",
        type=str,
        default="",
        help="Comment to pass to scheduler, e.g. priority message",
    )

    return parser.parse_args()
def get_shared_folder() -> Path:
    """Return ``<root>/<USER>/xformers/submitit``, creating it if needed.

    ``<root>`` is the first of "/checkpoint" and "/checkpoints" that exists.

    Raises:
        RuntimeError: if neither root directory exists.
    """
    user = os.getenv("USER")
    checkpoint_paths = ["/checkpoint", "/checkpoints"]

    root = next((c for c in checkpoint_paths if Path(c).is_dir()), None)
    if root is None:
        raise RuntimeError(
            f"No shared folder available - considering {checkpoint_paths}"
        )

    shared = Path(f"{root}/{user}/xformers/submitit")
    shared.mkdir(exist_ok=True, parents=True)
    return shared
def get_init_file():
    """Return a fresh, uniquely named init-file path in the shared folder."""
    # The init file must not exist, but its parent directory must.
    shared_folder = get_shared_folder()
    shared_folder.mkdir(parents=True, exist_ok=True)

    init_file = shared_folder / f"{uuid.uuid4().hex}_init"
    if init_file.exists():
        init_file.unlink()
    return init_file
class Trainer:
    """Callable wrapping ``benchmark`` for submission through submitit."""

    def __init__(self, args):
        self.args = args

    def __call__(self):
        self._setup_gpu_args()
        benchmark(self.args)

    def checkpoint(self):
        """Return a new submission of the same job, with a fresh init file."""
        self.args.dist_url = get_init_file().as_uri()
        print("Requeuing ", self.args)
        fresh_trainer = type(self)(self.args)
        return submitit.helpers.DelayedSubmission(fresh_trainer)

    def _setup_gpu_args(self):
        """Fill in the job-dependent arguments from the submitit environment."""
        job_env = submitit.JobEnvironment()

        # Replace the "%j" placeholder by the actual job id
        resolved_dir = str(self.args.checkpoint_dir).replace("%j", str(job_env.job_id))
        self.args.checkpoint_dir = Path(resolved_dir)

        self.args.gpu = job_env.local_rank
        self.args.rank = job_env.global_rank
        self.args.world_size = job_env.num_tasks
        print(f"Process group: {job_env.num_tasks} tasks, rank: {job_env.global_rank}")
def main():
    """Submit a single LRA benchmark job through submitit.

    The job id is printed and appended to ``jobs.txt`` in the checkpoint
    directory.
    """
    args = parse_args()
    if args.checkpoint_dir == "":
        args.checkpoint_dir = get_shared_folder() / "%j"
    Path(args.checkpoint_dir).mkdir(parents=True, exist_ok=True)

    executor = submitit.AutoExecutor(
        folder=args.checkpoint_dir, slurm_max_num_timeout=30
    )

    num_gpus_per_node = args.ngpus
    args.world_size = args.nodes * args.ngpus

    kwargs = {
        "gpus_per_node": num_gpus_per_node,
        "tasks_per_node": num_gpus_per_node,  # one task per GPU
        "cpus_per_task": 10,
        "nodes": args.nodes,
        "timeout_min": args.timeout,  # max is 60 * 72
        # Below are cluster dependent parameters
        "slurm_partition": args.partition,
        "slurm_signal_delay_s": 120,
    }

    if args.enforce_host_memory:
        # A plain number: the former trailing comma made this a 1-tuple
        kwargs["mem_gb"] = 40 * num_gpus_per_node
    if args.use_volta32:
        kwargs["slurm_constraint"] = "volta32gb"
    if args.comment:
        kwargs["slurm_comment"] = args.comment

    executor.update_parameters(**kwargs)
    executor.update_parameters(name="lra")

    args.dist_url = get_init_file().as_uri()
    args.temp_file = str(get_init_file())

    trainer = Trainer(args)
    job = executor.submit(trainer)

    print(f"Submitted job_id: {job.job_id}")
    print(f"Logs and checkpoints will be saved at: {args.checkpoint_dir}")

    with open(Path(f"{args.checkpoint_dir}") / Path("jobs.txt"), "a") as jobfile:
        jobfile.write(f"{job.job_id}\n")


if __name__ == "__main__":
    main()