# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.


# CREDITS: adapted from the Nystromformer repo
# https://github.com/mlpen/Nystromformer

from enum import Enum
from typing import Dict, Union

import pytorch_lightning as pl
import torch
import torch.nn as nn

from xformers.components import build_attention
from xformers.components.multi_head_dispatch import MultiHeadDispatchConfig
from xformers.factory import xFormer, xFormerConfig, xFormerEncoderConfig
from xformers.utils import generate_matching_config

PLOutput = Dict[str, Union[float, torch.Tensor]]


class Pooling(str, Enum):
    MEAN = "mean"
    CLS = "cls"


def pooling(mode: Pooling):
    # Return the pooling callable for the requested mode: either the first
    # (CLS) token, or the mean over the sequence dimension.
    def pool_cls(inp):
        return inp[:, 0, :]

    def pool_mean(inp):
        return inp.mean(dim=1)

    return {Pooling.MEAN: pool_mean, Pooling.CLS: pool_cls}[mode]


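# A minimal shape check (illustrative sketch, not part of the training code):
# both poolers reduce a (batch, seq, dim) activation to (batch, dim).
# `_check_pooling_shapes` is a hypothetical helper, unused elsewhere.
def _check_pooling_shapes():
    x = torch.randn(2, 16, 64)  # (batch, seq, dim)
    assert pooling(Pooling.CLS)(x).shape == (2, 64)
    assert pooling(Pooling.MEAN)(x).shape == (2, 64)

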
def append_cls(inp, mask, vocab_size):
    # Prepend a CLS token (reusing the last vocabulary id) and drop the last
    # position, so that the sequence length stays constant.
    batch_size = inp.size(0)
    cls_id = (
        (vocab_size - 1) * torch.ones(batch_size, dtype=torch.long, device=inp.device)
    ).long()
    cls_mask = torch.ones(batch_size, dtype=torch.float, device=mask.device)
    inp = torch.cat([cls_id[:, None], inp[:, :-1]], dim=-1)
    mask = torch.cat([cls_mask[:, None], mask[:, :-1]], dim=-1)
    return inp, mask


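# A minimal sanity check (illustrative sketch): with vocab_size=8 the CLS id
# is 7, and both shapes are preserved. `_check_append_cls` is a hypothetical
# helper, unused elsewhere.
def _check_append_cls():
    inp = torch.randint(0, 8, (2, 16))
    mask = torch.ones(2, 16)
    new_inp, new_mask = append_cls(inp, mask, vocab_size=8)
    assert new_inp.shape == (2, 16) and new_mask.shape == (2, 16)
    assert (new_inp[:, 0] == 7).all()  # CLS id is vocab_size - 1

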
def patch_model_config(config, attention_name):
    # Rebuild a specific config out of generic + extra params
    commons = config["common"]
    try:
        extra_attention_settings = config["extra_settings"]["attention"][attention_name]
    except KeyError:
        extra_attention_settings = None

    for bc in config["xformer"]:
        bc["dim_model"] = commons["dim_model"]
        bc["position_encoding_config"].update(commons)
        bc["feedforward_config"].update(commons)
        bc["multi_head_config"].update(commons)
        bc["multi_head_config"]["attention"].update(commons)
        bc["multi_head_config"]["attention"]["name"] = attention_name
        # Integer division: the per-head dimension must be an int
        bc["multi_head_config"]["attention"]["dim_head"] = (
            commons["dim_model"] // commons["num_heads"]
        )
        if extra_attention_settings is not None:
            bc["multi_head_config"]["attention"].update(extra_attention_settings)

        bc["multi_head_config"] = generate_matching_config(
            bc["multi_head_config"], MultiHeadDispatchConfig
        )
        bc["multi_head_config"].attention = build_attention(
            bc["multi_head_config"].attention
        )
        # NOTE: the result is discarded, so this call acts only as a schema
        # check of the block config against xFormerEncoderConfig
        bc = generate_matching_config(bc, xFormerEncoderConfig)

    return config


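# A minimal sketch (hypothetical values) of the dict shape that
# `patch_model_config` expects, inferred only from the keys read above:
#
#   config = {
#       "common": {"dim_model": 64, "num_heads": 2, ...},
#       "extra_settings": {"attention": {attention_name: {...}}},  # optional
#       "xformer": [
#           {
#               "position_encoding_config": {...},
#               "feedforward_config": {...},
#               "multi_head_config": {"attention": {...}},
#           },
#       ],
#   }

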
class SCHead(nn.Module):
    def __init__(self, config, dim_embedding, dim_mlp):
        super().__init__()
        self.pooling = pooling(Pooling(config["pooling_mode"]))

        self.mlpblock = nn.Sequential(
            nn.Linear(dim_embedding, dim_mlp),
            nn.ReLU(),
            nn.Linear(dim_mlp, config["common"]["num_classes"]),
        )

    def forward(self, inp: torch.Tensor):
        seq_score = self.mlpblock(self.pooling(inp))
        return seq_score


class SCHeadDual(nn.Module):
    def __init__(self, config, dim_embedding, dim_mlp):
        super().__init__()
        self.pooling = pooling(Pooling(config["pooling_mode"]))

        self.mlpblock = nn.Sequential(
            # 4x width: the forward pass concatenates
            # [X_0, X_1, X_0 * X_1, X_0 - X_1]
            nn.Linear(
                dim_embedding * 4,
                dim_mlp,
            ),
            nn.ReLU(),
            nn.Linear(dim_mlp, config["common"]["num_classes"]),
        )

    def forward(self, inp_0: torch.Tensor, inp_1: torch.Tensor):
        X_0 = self.pooling(inp_0)
        X_1 = self.pooling(inp_1)
        seq_score = self.mlpblock(torch.cat([X_0, X_1, X_0 * X_1, X_0 - X_1], dim=-1))
        return seq_score


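# A minimal width check (illustrative sketch): the pairwise interaction
# features explain the `dim_embedding * 4` input width of SCHeadDual.
# `_check_dual_head_width` is a hypothetical helper, unused elsewhere.
def _check_dual_head_width(dim_embedding: int = 64):
    X_0 = torch.randn(2, dim_embedding)
    X_1 = torch.randn(2, dim_embedding)
    feats = torch.cat([X_0, X_1, X_0 * X_1, X_0 - X_1], dim=-1)
    assert feats.shape == (2, dim_embedding * 4)

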
class ModelTrunk(pl.LightningModule):
    def __init__(self, config, model_name):
        super().__init__()

        config_model = config["model"]
        self.config_training = config["training"]

        self.enable_amp = config["training"]["mixed_precision"]
        self.pooling_mode = Pooling(config_model["pooling_mode"])
        self.vocab_size = config_model["common"]["vocab_size"]

        # Rebuild a specific config out of generic + extra params
        self.config_model = patch_model_config(config_model, model_name)
        self.model = xFormer.from_config(xFormerConfig(config_model["xformer"]))
        self.norm = nn.LayerNorm(self.config_model["common"]["dim_model"])

        ff_config = self.config_model["xformer"][0]["feedforward_config"]
        self.dim_mlp = (
            self.config_model["common"]["dim_model"]
            * ff_config["hidden_layer_multiplier"]
        )

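    # Example (illustrative numbers): with dim_model=64 and
    # hidden_layer_multiplier=4, the classification head gets
    # dim_mlp = 64 * 4 = 256, matching the trunk's feedforward width.
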
    def training_step(  # type: ignore
        self, batch: Dict[str, torch.Tensor], batch_idx: int
    ) -> PLOutput:
        outputs = self(**batch)
        self.logger.log_metrics({f"train_{k}": v for k, v in outputs.items()})  # type: ignore
        self.log("train_accu", outputs["accu"], sync_dist=True)
        return outputs

    def training_epoch_end(self, outputs):
        logs = self.eval_epoch_end(outputs)
        self.log("train_accu_mean", logs["accu"], sync_dist=True)

    def configure_optimizers(self):
        optimizer = torch.optim.AdamW(
            self.parameters(),
            lr=self.config_training["learning_rate"],
            betas=(0.9, 0.999),
            eps=1e-6,
            weight_decay=self.config_training["weight_decay"],
        )

        lr_scheduler = torch.optim.lr_scheduler.OneCycleLR(
            optimizer=optimizer,
            max_lr=self.config_training["learning_rate"],
            # OneCycleLR takes the warmup as a fraction of the total steps
            pct_start=self.config_training["warmup"]
            / self.config_training["num_train_steps"],
            anneal_strategy=self.config_training["lr_decay"],
            total_steps=self.config_training["num_train_steps"],
        )

        return [optimizer], [lr_scheduler]

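    # Example (illustrative numbers): with warmup=1000 and
    # num_train_steps=10000, pct_start = 1000 / 10000 = 0.1, i.e. the LR
    # ramps up for the first 10% of training and anneals over the rest.
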
    def eval_step(self, batch: Dict[str, torch.Tensor], batch_idx: int) -> PLOutput:
        outputs = self(**batch)
        return outputs

    def eval_epoch_end(self, outputs, prefix: str = "train"):
        logs = {}
        counts = torch.tensor([x["count"] for x in outputs]).float()
        logs["count"] = counts.sum()
        # Per-batch means are combined into an epoch mean weighted by batch size
        for k in ("accu", "loss"):
            logs[k] = (torch.tensor([x[k] for x in outputs]) * counts).sum() / logs[
                "count"
            ]
            self.log(f"{prefix}_{k}_mean", logs[k], sync_dist=True)
        return logs

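    # A minimal aggregation check (illustrative sketch, hypothetical helper):
    # batches of size 2 and 6 with accuracies 0.5 and 1.0 combine to
    # (2 * 0.5 + 6 * 1.0) / 8 = 0.875, not the unweighted mean 0.75.
    @staticmethod
    def _check_weighted_mean():
        outputs = [
            {"count": 2, "accu": 0.5, "loss": 1.0},
            {"count": 6, "accu": 1.0, "loss": 0.0},
        ]
        counts = torch.tensor([x["count"] for x in outputs]).float()
        accu = (
            torch.tensor([x["accu"] for x in outputs]) * counts
        ).sum() / counts.sum()
        assert abs(accu.item() - 0.875) < 1e-6
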
    def validation_step(  # type: ignore
        self, batch: Dict[str, torch.Tensor], batch_idx: int
    ) -> PLOutput:
        outputs = self.eval_step(batch, batch_idx)
        self.logger.log_metrics({f"val_{k}": v for k, v in outputs.items()})  # type: ignore
        self.log("val_accu", outputs["accu"], sync_dist=True, prog_bar=True)
        return outputs

    def validation_epoch_end(self, outputs):
        self.eval_epoch_end(outputs, prefix="val")

    def test_step(  # type: ignore
        self, batch: Dict[str, torch.Tensor], batch_idx: int
    ) -> PLOutput:
        return self.eval_step(batch, batch_idx)

    def test_epoch_end(self, outputs):
        self.eval_epoch_end(outputs, prefix="test")


class ModelForSC(ModelTrunk):
    def __init__(self, config, model_name):
        # Setup trunk
        super().__init__(config, model_name)

        self.seq_classifer = SCHead(
            self.config_model,
            dim_embedding=self.config_model["common"]["dim_model"],
            dim_mlp=self.dim_mlp,
        )

    def forward(  # type: ignore
        self, input_ids_0: torch.Tensor, mask_0: torch.Tensor, label: torch.Tensor
    ):
        if self.pooling_mode == Pooling.CLS:
            input_ids_0, mask_0 = append_cls(input_ids_0, mask_0, self.vocab_size)

        # Zero out the padded positions after the trunk + layernorm
        token_out = self.norm(
            self.model(input_ids_0, encoder_input_mask=mask_0)
        ) * mask_0.unsqueeze(-1)

        seq_scores = self.seq_classifer(token_out)

        seq_loss = torch.nn.CrossEntropyLoss(reduction="none")(seq_scores, label)
        seq_accu = (seq_scores.argmax(dim=-1) == label).to(torch.float32)
        outputs = {
            "loss": seq_loss.mean(),
            "accu": seq_accu.mean(),
            "count": label.size(0),
        }

        return outputs


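# Example (illustrative sketch): `training_step` calls `self(**batch)`, so a
# dataloader feeding ModelForSC must yield dicts keyed to match the forward
# signature (hypothetical shapes, batch size B and sequence length S):
#
#   batch = {
#       "input_ids_0": torch.randint(0, vocab_size, (B, S)),
#       "mask_0": torch.ones(B, S),
#       "label": torch.randint(0, num_classes, (B,)),
#   }

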
class ModelForSCDual(ModelTrunk):
    def __init__(self, config, model_name):
        # Setup trunk
        super().__init__(config, model_name)

        self.seq_classifer = SCHeadDual(
            self.config_model,
            dim_embedding=self.config_model["common"]["dim_model"],
            dim_mlp=self.dim_mlp,
        )

    def forward(  # type: ignore
        self,
        input_ids_0: torch.Tensor,
        input_ids_1: torch.Tensor,
        mask_0: torch.Tensor,
        mask_1: torch.Tensor,
        label: torch.Tensor,
    ):
        mask_0, mask_1 = mask_0.long(), mask_1.long()

        if self.pooling_mode == Pooling.CLS:
            input_ids_0, mask_0 = append_cls(input_ids_0, mask_0, self.vocab_size)
            input_ids_1, mask_1 = append_cls(input_ids_1, mask_1, self.vocab_size)

        # Concatenate the two inputs into one batch, so that both sequences
        # go through the shared trunk in a single pass
        input_ids = torch.cat([input_ids_0, input_ids_1], dim=0)
        masks = torch.cat([mask_0, mask_1], dim=0)

        # Zero out the padded positions after the trunk + layernorm
        tokens_out = self.norm(
            self.model(input_ids, encoder_input_mask=masks)
        ) * masks.unsqueeze(-1)

        # Split the doubled batch back into its two original halves
        seq_scores = self.seq_classifer(*torch.chunk(tokens_out, 2, dim=0))

        seq_loss = torch.nn.CrossEntropyLoss(reduction="none")(seq_scores, label)
        seq_accu = (seq_scores.argmax(dim=-1) == label).to(torch.float32)
        outputs = {
            "loss": seq_loss.mean(),
            "accu": seq_accu.mean(),
            "count": label.size(0),
        }

        return outputs
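

# Example (illustrative sketch): ModelForSCDual expects paired-sequence
# batches keyed to match its forward signature (hypothetical shapes):
#
#   batch = {
#       "input_ids_0": torch.randint(0, vocab_size, (B, S)),
#       "input_ids_1": torch.randint(0, vocab_size, (B, S)),
#       "mask_0": torch.ones(B, S),
#       "mask_1": torch.ones(B, S),
#       "label": torch.randint(0, num_classes, (B,)),
#   }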