First commit
This commit is contained in:
288
pkgs/xformers/benchmarks/LRA/code/model_wrapper.py
Normal file
288
pkgs/xformers/benchmarks/LRA/code/model_wrapper.py
Normal file
@@ -0,0 +1,288 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the BSD license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
|
||||
# CREDITS: adapted from the Nystromformer repo
|
||||
# https://github.com/mlpen/Nystromformer
|
||||
|
||||
from enum import Enum
|
||||
from typing import Dict, Union
|
||||
|
||||
import pytorch_lightning as pl
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from xformers.components import build_attention
|
||||
from xformers.components.multi_head_dispatch import MultiHeadDispatchConfig
|
||||
from xformers.factory import xFormer, xFormerConfig, xFormerEncoderConfig
|
||||
from xformers.utils import generate_matching_config
|
||||
|
||||
PLOutput = Dict[str, Union[float, torch.Tensor]]
|
||||
|
||||
|
||||
class Pooling(str, Enum):
|
||||
MEAN = "mean"
|
||||
CLS = "cls"
|
||||
|
||||
|
||||
def pooling(mode: Pooling):
|
||||
def pool_cls(inp):
|
||||
return inp[:, 0, :]
|
||||
|
||||
def pool_mean(inp):
|
||||
return inp.mean(dim=1)
|
||||
|
||||
return {Pooling.MEAN: pool_mean, Pooling.CLS: pool_cls}[mode]
|
||||
|
||||
|
||||
def append_cls(inp, mask, vocab_size):
|
||||
batch_size = inp.size(0)
|
||||
cls_id = (
|
||||
(vocab_size - 1) * torch.ones(batch_size, dtype=torch.long, device=inp.device)
|
||||
).long()
|
||||
cls_mask = torch.ones(batch_size, dtype=torch.float, device=mask.device)
|
||||
inp = torch.cat([cls_id[:, None], inp[:, :-1]], dim=-1)
|
||||
mask = torch.cat([cls_mask[:, None], mask[:, :-1]], dim=-1)
|
||||
return inp, mask
|
||||
|
||||
|
||||
def patch_model_config(config, attention_name):
|
||||
# Rebuild a specific config out of generic + extra params
|
||||
commons = config["common"]
|
||||
try:
|
||||
extra_attention_settings = config["extra_settings"]["attention"][attention_name]
|
||||
except KeyError:
|
||||
extra_attention_settings = None
|
||||
|
||||
for bc in config["xformer"]:
|
||||
bc["dim_model"] = commons["dim_model"]
|
||||
bc["position_encoding_config"].update(commons)
|
||||
bc["feedforward_config"].update(commons)
|
||||
bc["multi_head_config"].update(commons)
|
||||
bc["multi_head_config"]["attention"].update(commons)
|
||||
bc["multi_head_config"]["attention"]["name"] = attention_name
|
||||
bc["multi_head_config"]["attention"]["dim_head"] = (
|
||||
commons["dim_model"] / commons["num_heads"]
|
||||
)
|
||||
if extra_attention_settings is not None:
|
||||
bc["multi_head_config"]["attention"].update(extra_attention_settings)
|
||||
|
||||
bc["multi_head_config"] = generate_matching_config(
|
||||
bc["multi_head_config"], MultiHeadDispatchConfig
|
||||
)
|
||||
bc["multi_head_config"].attention = build_attention(
|
||||
bc["multi_head_config"].attention
|
||||
)
|
||||
bc = generate_matching_config(bc, xFormerEncoderConfig)
|
||||
|
||||
return config
|
||||
|
||||
|
||||
class SCHead(nn.Module):
|
||||
def __init__(self, config, dim_embedding, dim_mlp):
|
||||
super().__init__()
|
||||
self.pooling = pooling(Pooling(config["pooling_mode"]))
|
||||
|
||||
self.mlpblock = nn.Sequential(
|
||||
nn.Linear(dim_embedding, dim_mlp),
|
||||
nn.ReLU(),
|
||||
nn.Linear(dim_mlp, config["common"]["num_classes"]),
|
||||
)
|
||||
|
||||
def forward(self, inp: torch.Tensor):
|
||||
seq_score = self.mlpblock(self.pooling(inp))
|
||||
return seq_score
|
||||
|
||||
|
||||
class SCHeadDual(nn.Module):
|
||||
def __init__(self, config, dim_embedding, dim_mlp):
|
||||
super().__init__()
|
||||
self.pooling = pooling(Pooling(config["pooling_mode"]))
|
||||
|
||||
self.mlpblock = nn.Sequential(
|
||||
nn.Linear(
|
||||
dim_embedding * 4,
|
||||
dim_mlp,
|
||||
),
|
||||
nn.ReLU(),
|
||||
nn.Linear(dim_mlp, config["common"]["num_classes"]),
|
||||
)
|
||||
|
||||
def forward(self, inp_0: torch.Tensor, inp_1: torch.Tensor):
|
||||
X_0 = self.pooling(inp_0)
|
||||
X_1 = self.pooling(inp_1)
|
||||
seq_score = self.mlpblock(torch.cat([X_0, X_1, X_0 * X_1, X_0 - X_1], dim=-1))
|
||||
return seq_score
|
||||
|
||||
|
||||
class ModelTrunk(pl.LightningModule):
|
||||
def __init__(self, config, model_name):
|
||||
super().__init__()
|
||||
|
||||
config_model = config["model"]
|
||||
self.config_training = config["training"]
|
||||
|
||||
self.enable_amp = config["training"]["mixed_precision"]
|
||||
self.pooling_mode = Pooling(config_model["pooling_mode"])
|
||||
self.vocab_size = config_model["common"]["vocab_size"]
|
||||
|
||||
# Rebuild a specific config out of generic + extra params
|
||||
self.config_model = patch_model_config(config_model, model_name)
|
||||
self.model = xFormer.from_config(xFormerConfig(config_model["xformer"]))
|
||||
self.norm = nn.LayerNorm(self.config_model["common"]["dim_model"])
|
||||
|
||||
ff_config = self.config_model["xformer"][0]["feedforward_config"]
|
||||
self.dim_mlp = (
|
||||
self.config_model["common"]["dim_model"]
|
||||
* ff_config["hidden_layer_multiplier"]
|
||||
)
|
||||
|
||||
def training_step( # type: ignore
|
||||
self, batch: Dict[str, torch.Tensor], batch_idx: int
|
||||
) -> PLOutput:
|
||||
outputs = self(**batch)
|
||||
self.logger.log_metrics({f"train_{k}": v for k, v in outputs.items()}) # type: ignore
|
||||
self.log("train_accu", outputs["accu"], sync_dist=True)
|
||||
return outputs
|
||||
|
||||
def training_epoch_end(self, outputs):
|
||||
logs = self.eval_epoch_end(outputs)
|
||||
self.log("train_accu_mean", logs["accu"], sync_dist=True)
|
||||
|
||||
def configure_optimizers(self):
|
||||
optimizer = torch.optim.AdamW(
|
||||
self.parameters(),
|
||||
lr=self.config_training["learning_rate"],
|
||||
betas=(0.9, 0.999),
|
||||
eps=1e-6,
|
||||
weight_decay=self.config_training["weight_decay"],
|
||||
)
|
||||
|
||||
lr_scheduler = torch.optim.lr_scheduler.OneCycleLR(
|
||||
optimizer=optimizer,
|
||||
max_lr=self.config_training["learning_rate"],
|
||||
pct_start=self.config_training["warmup"]
|
||||
/ self.config_training["num_train_steps"],
|
||||
anneal_strategy=self.config_training["lr_decay"],
|
||||
total_steps=self.config_training["num_train_steps"],
|
||||
)
|
||||
|
||||
return [optimizer], [lr_scheduler]
|
||||
|
||||
def eval_step(self, batch: Dict[str, torch.Tensor], batch_idx: int) -> PLOutput:
|
||||
outputs = self(**batch)
|
||||
return outputs
|
||||
|
||||
def eval_epoch_end(self, outputs, prefix: str = "train"):
|
||||
logs = {}
|
||||
counts = torch.tensor([x["count"] for x in outputs]).float()
|
||||
logs["count"] = counts.sum()
|
||||
for k in ("accu", "loss"):
|
||||
logs[k] = (torch.tensor([x[k] for x in outputs]) * counts).sum() / logs[
|
||||
"count"
|
||||
]
|
||||
self.log(f"{prefix}_{k}_mean", logs[k], sync_dist=True)
|
||||
return logs
|
||||
|
||||
def validation_step( # type: ignore
|
||||
self, batch: Dict[str, torch.Tensor], batch_idx: int
|
||||
) -> PLOutput:
|
||||
outputs = self.eval_step(batch, batch_idx)
|
||||
self.logger.log_metrics({f"val_{k}": v for k, v in outputs.items()}) # type: ignore
|
||||
self.log("val_accu", outputs["accu"], sync_dist=True, prog_bar=True)
|
||||
return outputs
|
||||
|
||||
def validation_epoch_end(self, outputs):
|
||||
self.eval_epoch_end(outputs, prefix="val")
|
||||
|
||||
def test_step( # type: ignore
|
||||
self, batch: Dict[str, torch.Tensor], batch_idx: int
|
||||
) -> PLOutput:
|
||||
return self.eval_step(batch, batch_idx)
|
||||
|
||||
def test_epoch_end(self, outputs):
|
||||
self.eval_epoch_end(outputs, prefix="test")
|
||||
|
||||
|
||||
class ModelForSC(ModelTrunk):
|
||||
def __init__(self, config, model_name):
|
||||
# Setup trunk
|
||||
super().__init__(config, model_name)
|
||||
|
||||
self.seq_classifer = SCHead(
|
||||
self.config_model,
|
||||
dim_embedding=self.config_model["common"]["dim_model"],
|
||||
dim_mlp=self.dim_mlp,
|
||||
)
|
||||
|
||||
def forward( # type: ignore
|
||||
self, input_ids_0: torch.Tensor, mask_0: torch.Tensor, label: torch.Tensor
|
||||
):
|
||||
|
||||
if self.pooling_mode == Pooling.CLS:
|
||||
input_ids_0, mask_0 = append_cls(input_ids_0, mask_0, self.vocab_size)
|
||||
|
||||
token_out = self.norm(
|
||||
self.model(input_ids_0, encoder_input_mask=mask_0)
|
||||
) * mask_0.unsqueeze(-1)
|
||||
|
||||
seq_scores = self.seq_classifer(token_out)
|
||||
|
||||
seq_loss = torch.nn.CrossEntropyLoss(reduction="none")(seq_scores, label)
|
||||
seq_accu = (seq_scores.argmax(dim=-1) == label).to(torch.float32)
|
||||
outputs = {
|
||||
"loss": seq_loss.mean(),
|
||||
"accu": seq_accu.mean(),
|
||||
"count": label.size(0),
|
||||
}
|
||||
|
||||
return outputs
|
||||
|
||||
|
||||
class ModelForSCDual(ModelTrunk):
|
||||
def __init__(self, config, model_name):
|
||||
# Setup trunk
|
||||
super().__init__(config, model_name)
|
||||
|
||||
self.seq_classifer = SCHeadDual(
|
||||
self.config_model,
|
||||
dim_embedding=self.config_model["common"]["dim_model"],
|
||||
dim_mlp=self.dim_mlp,
|
||||
)
|
||||
|
||||
def forward( # type: ignore
|
||||
self,
|
||||
input_ids_0: torch.Tensor,
|
||||
input_ids_1: torch.Tensor,
|
||||
mask_0: torch.Tensor,
|
||||
mask_1: torch.Tensor,
|
||||
label: torch.Tensor,
|
||||
):
|
||||
|
||||
mask_0, mask_1 = mask_0.long(), mask_1.long()
|
||||
|
||||
if self.pooling_mode == Pooling.CLS:
|
||||
input_ids_0, mask_0 = append_cls(input_ids_0, mask_0, self.vocab_size)
|
||||
input_ids_1, mask_1 = append_cls(input_ids_1, mask_1, self.vocab_size)
|
||||
|
||||
# Concatenate the two inputs into one batch
|
||||
input_ids = torch.cat([input_ids_0, input_ids_1], dim=0)
|
||||
masks = torch.cat([mask_0, mask_1], dim=0)
|
||||
|
||||
tokens_out = self.norm(
|
||||
self.model(input_ids, encoder_input_mask=masks)
|
||||
) * masks.unsqueeze(-1)
|
||||
|
||||
seq_scores = self.seq_classifer(*torch.chunk(tokens_out, 2, dim=0))
|
||||
|
||||
seq_loss = torch.nn.CrossEntropyLoss(reduction="none")(seq_scores, label)
|
||||
seq_accu = (seq_scores.argmax(dim=-1) == label).to(torch.float32)
|
||||
outputs = {
|
||||
"loss": seq_loss.mean(),
|
||||
"accu": seq_accu.mean(),
|
||||
"count": label.size(0),
|
||||
}
|
||||
|
||||
return outputs
|
||||
Reference in New Issue
Block a user