Files
enginex-bi_series-vllm/pkgs/xformers/benchmarks/LRA/code/model_wrapper.py
2025-08-05 19:02:46 +08:00

289 lines
9.5 KiB
Python

# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
# CREDITS: adapted from the Nystromformer repo
# https://github.com/mlpen/Nystromformer
from enum import Enum
from typing import Dict, Union
import pytorch_lightning as pl
import torch
import torch.nn as nn
from xformers.components import build_attention
from xformers.components.multi_head_dispatch import MultiHeadDispatchConfig
from xformers.factory import xFormer, xFormerConfig, xFormerEncoderConfig
from xformers.utils import generate_matching_config
# Value type of the metric dicts returned by the model's Lightning step hooks:
# metric name -> Python float or tensor.
PLOutput = Dict[
    str,
    Union[float, torch.Tensor],
]
class Pooling(str, Enum):
    """Sequence pooling strategies; values are the strings used in configs."""

    MEAN = "mean"
    CLS = "cls"
def pooling(mode: Pooling):
    """Return the pooling function for ``mode``.

    * ``Pooling.CLS``  -> selects position 0 along dim 1 (``inp[:, 0, :]``).
    * ``Pooling.MEAN`` -> averages over dim 1.

    Raises:
        KeyError: if ``mode`` is not one of the supported pooling modes.
    """

    def _take_first_position(inp):
        return inp[:, 0, :]

    def _average_positions(inp):
        return inp.mean(dim=1)

    dispatch = {
        Pooling.CLS: _take_first_position,
        Pooling.MEAN: _average_positions,
    }
    return dispatch[mode]
def append_cls(inp, mask, vocab_size):
    """Prepend a CLS token to each sequence while keeping the length fixed.

    The CLS token id is ``vocab_size - 1``. To make room for it, the last
    column of both ``inp`` and ``mask`` is dropped; the new first column of the
    mask is set to 1.

    Args:
        inp: 2-D token-id tensor (batch, seq_len).
        mask: 2-D mask tensor with the same leading dimensions as ``inp``.
        vocab_size: vocabulary size; its last id is used as the CLS id.

    Returns:
        ``(inp, mask)`` with the CLS column prepended and the last column
        removed. The CLS column is int64; the mask column is float32.
    """
    rows = inp.size(0)
    cls_column = torch.full(
        (rows, 1), vocab_size - 1, dtype=torch.long, device=inp.device
    )
    ones_column = torch.ones((rows, 1), dtype=torch.float, device=mask.device)

    shifted_inp = torch.cat((cls_column, inp[:, :-1]), dim=-1)
    shifted_mask = torch.cat((ones_column, mask[:, :-1]), dim=-1)
    return shifted_inp, shifted_mask
def patch_model_config(config, attention_name):
    """Specialise a generic model config for one attention mechanism, in place.

    For every block config in ``config["xformer"]`` this:

    * copies ``config["common"]`` into the block's position-encoding,
      feed-forward, multi-head and attention sub-configs;
    * selects ``attention_name`` as the attention and sets its ``dim_head``;
    * applies the optional per-attention overrides found at
      ``config["extra_settings"]["attention"][attention_name]``;
    * replaces ``multi_head_config`` with a ``MultiHeadDispatchConfig`` whose
      ``attention`` field is the module returned by ``build_attention``.

    Args:
        config: model config dict; it is mutated in place.
        attention_name: name of the attention mechanism to use.

    Returns:
        The same ``config`` object, after patching.
    """
    commons = config["common"]

    # Overrides are optional: a missing "extra_settings" section or a missing
    # entry for this attention both mean "no overrides".
    try:
        extra_attention_settings = config["extra_settings"]["attention"][attention_name]
    except KeyError:
        extra_attention_settings = None

    for bc in config["xformer"]:
        bc["dim_model"] = commons["dim_model"]
        bc["position_encoding_config"].update(commons)
        bc["feedforward_config"].update(commons)
        bc["multi_head_config"].update(commons)

        attention_config = bc["multi_head_config"]["attention"]
        attention_config.update(commons)
        attention_config["name"] = attention_name
        # Per-head width. Integer division keeps it an int; true division
        # would yield a float (e.g. 64.0), which is not a valid layer size.
        attention_config["dim_head"] = commons["dim_model"] // commons["num_heads"]
        if extra_attention_settings is not None:
            # Applied last so per-attention settings can override the commons.
            attention_config.update(extra_attention_settings)

        bc["multi_head_config"] = generate_matching_config(
            bc["multi_head_config"], MultiHeadDispatchConfig
        )
        bc["multi_head_config"].attention = build_attention(
            bc["multi_head_config"].attention
        )

        # NOTE(review): the converted block config is not stored anywhere;
        # config["xformer"] keeps the plain dicts. The call is kept only for
        # whatever checking it performs against xFormerEncoderConfig —
        # confirm whether it is still needed.
        generate_matching_config(bc, xFormerEncoderConfig)

    return config
class SCHead(nn.Module):
    """Classification head for one encoded sequence.

    Pools the token representations using ``config["pooling_mode"]``, then
    applies a Linear -> ReLU -> Linear MLP that outputs
    ``config["common"]["num_classes"]`` scores.
    """

    def __init__(self, config, dim_embedding, dim_mlp):
        super().__init__()
        self.pooling = pooling(Pooling(config["pooling_mode"]))

        # Layer order fixes the parameter names (mlpblock.0.*, mlpblock.2.*).
        hidden_layer = nn.Linear(dim_embedding, dim_mlp)
        output_layer = nn.Linear(dim_mlp, config["common"]["num_classes"])
        self.mlpblock = nn.Sequential(hidden_layer, nn.ReLU(), output_layer)

    def forward(self, inp: torch.Tensor):
        pooled = self.pooling(inp)
        return self.mlpblock(pooled)
class SCHeadDual(nn.Module):
    """Classification head for a pair of encoded sequences.

    Each sequence is pooled with ``config["pooling_mode"]``. The pooled
    vectors ``a`` and ``b`` are combined as ``[a, b, a * b, a - b]`` and fed
    through a Linear -> ReLU -> Linear MLP that outputs
    ``config["common"]["num_classes"]`` scores.
    """

    def __init__(self, config, dim_embedding, dim_mlp):
        super().__init__()
        self.pooling = pooling(Pooling(config["pooling_mode"]))

        # Four pooled vectors are concatenated, hence 4 * dim_embedding inputs.
        fused_width = 4 * dim_embedding
        self.mlpblock = nn.Sequential(
            nn.Linear(fused_width, dim_mlp),
            nn.ReLU(),
            nn.Linear(dim_mlp, config["common"]["num_classes"]),
        )

    def forward(self, inp_0: torch.Tensor, inp_1: torch.Tensor):
        left = self.pooling(inp_0)
        right = self.pooling(inp_1)
        features = torch.cat((left, right, left * right, left - right), dim=-1)
        return self.mlpblock(features)
class ModelTrunk(pl.LightningModule):
    """Shared LightningModule trunk for the classification models.

    Builds an xFormer encoder from ``config["model"]`` (patched for
    ``model_name`` by ``patch_model_config``) plus a final LayerNorm. It also
    provides the optimizer / LR schedule and the per-epoch metric aggregation.

    Subclasses add a task head and implement ``forward``. ``forward`` must
    return a dict with ``"loss"``, ``"accu"`` and ``"count"`` entries; these are
    read by the step and epoch-end hooks below.
    """

    def __init__(self, config, model_name):
        super().__init__()
        config_model = config["model"]
        self.config_training = config["training"]
        self.enable_amp = config["training"]["mixed_precision"]
        self.pooling_mode = Pooling(config_model["pooling_mode"])
        self.vocab_size = config_model["common"]["vocab_size"]

        # Rebuild a specific config out of generic + extra params.
        # patch_model_config mutates config_model and returns the same object.
        self.config_model = patch_model_config(config_model, model_name)
        self.model = xFormer.from_config(xFormerConfig(config_model["xformer"]))
        self.norm = nn.LayerNorm(self.config_model["common"]["dim_model"])

        # Hidden width for the subclasses' head MLP, taken from the first
        # block's feed-forward multiplier.
        ff_config = self.config_model["xformer"][0]["feedforward_config"]
        self.dim_mlp = (
            self.config_model["common"]["dim_model"]
            * ff_config["hidden_layer_multiplier"]
        )

    def training_step(  # type: ignore
        self, batch: Dict[str, torch.Tensor], batch_idx: int
    ) -> PLOutput:
        outputs = self(**batch)
        self.logger.log_metrics({f"train_{k}": v for k, v in outputs.items()})  # type: ignore
        self.log("train_accu", outputs["accu"], sync_dist=True)
        return outputs

    def training_epoch_end(self, outputs):
        # eval_epoch_end logs "train_accu_mean" and "train_loss_mean".
        self.eval_epoch_end(outputs, prefix="train")

    def configure_optimizers(self):
        optimizer = torch.optim.AdamW(
            self.parameters(),
            lr=self.config_training["learning_rate"],
            betas=(0.9, 0.999),
            eps=1e-6,
            weight_decay=self.config_training["weight_decay"],
        )

        lr_scheduler = torch.optim.lr_scheduler.OneCycleLR(
            optimizer=optimizer,
            max_lr=self.config_training["learning_rate"],
            pct_start=self.config_training["warmup"]
            / self.config_training["num_train_steps"],
            anneal_strategy=self.config_training["lr_decay"],
            total_steps=self.config_training["num_train_steps"],
        )

        # OneCycleLR spans `total_steps` optimizer steps and must be stepped
        # after every batch. Lightning steps a bare scheduler once per *epoch*,
        # which would keep the LR near its warm-up start for the whole run.
        # NOTE: OneCycleLR raises if stepped more than `total_steps` times, so
        # training must not run more than num_train_steps optimizer steps.
        return [optimizer], [{"scheduler": lr_scheduler, "interval": "step"}]

    def eval_step(self, batch: Dict[str, torch.Tensor], batch_idx: int) -> PLOutput:
        outputs = self(**batch)
        return outputs

    def eval_epoch_end(self, outputs, prefix: str = "train"):
        """Log count-weighted epoch means of "accu" and "loss".

        Args:
            outputs: per-batch dicts with "count", "accu" and "loss" entries.
            prefix: metric-name prefix ("train", "val", "test").

        Returns:
            dict with the total "count" and the weighted "accu"/"loss" means.
        """
        logs = {}
        counts = torch.tensor([x["count"] for x in outputs]).float()
        logs["count"] = counts.sum()
        for k in ("accu", "loss"):
            # Per-batch values are batch means; weighting by the batch size
            # yields the mean over all samples.
            logs[k] = (torch.tensor([x[k] for x in outputs]) * counts).sum() / logs[
                "count"
            ]
            self.log(f"{prefix}_{k}_mean", logs[k], sync_dist=True)
        return logs

    def validation_step(  # type: ignore
        self, batch: Dict[str, torch.Tensor], batch_idx: int
    ) -> PLOutput:
        outputs = self.eval_step(batch, batch_idx)
        self.logger.log_metrics({f"val_{k}": v for k, v in outputs.items()})  # type: ignore
        self.log("val_accu", outputs["accu"], sync_dist=True, prog_bar=True)
        return outputs

    def validation_epoch_end(self, outputs):
        self.eval_epoch_end(outputs, prefix="val")

    def test_step(  # type: ignore
        self, batch: Dict[str, torch.Tensor], batch_idx: int
    ) -> PLOutput:
        return self.eval_step(batch, batch_idx)

    def test_epoch_end(self, outputs):
        self.eval_epoch_end(outputs, prefix="test")
class ModelForSC(ModelTrunk):
    """Single-sequence classifier: the shared xFormer trunk plus an SCHead."""

    def __init__(self, config, model_name):
        super().__init__(config, model_name)
        # The attribute name (including its spelling) determines the head's
        # state_dict keys, so it is kept as-is.
        self.seq_classifer = SCHead(
            self.config_model,
            dim_embedding=self.config_model["common"]["dim_model"],
            dim_mlp=self.dim_mlp,
        )

    def forward(  # type: ignore
        self, input_ids_0: torch.Tensor, mask_0: torch.Tensor, label: torch.Tensor
    ):
        """Classify ``input_ids_0`` and return batch loss/accuracy/size."""
        if self.pooling_mode == Pooling.CLS:
            input_ids_0, mask_0 = append_cls(input_ids_0, mask_0, self.vocab_size)

        encoded = self.norm(self.model(input_ids_0, encoder_input_mask=mask_0))
        # Positions whose mask is 0 (presumably padding) are zeroed before pooling.
        token_out = encoded * mask_0.unsqueeze(-1)

        seq_scores = self.seq_classifer(token_out)
        per_sample_loss = nn.functional.cross_entropy(
            seq_scores, label, reduction="none"
        )
        per_sample_hit = (seq_scores.argmax(dim=-1) == label).to(torch.float32)

        return {
            "loss": per_sample_loss.mean(),
            "accu": per_sample_hit.mean(),
            "count": label.size(0),
        }
class ModelForSCDual(ModelTrunk):
    """Sequence-pair classifier: the shared xFormer trunk plus an SCHeadDual."""

    def __init__(self, config, model_name):
        super().__init__(config, model_name)
        # The attribute name (including its spelling) determines the head's
        # state_dict keys, so it is kept as-is.
        self.seq_classifer = SCHeadDual(
            self.config_model,
            dim_embedding=self.config_model["common"]["dim_model"],
            dim_mlp=self.dim_mlp,
        )

    def forward(  # type: ignore
        self,
        input_ids_0: torch.Tensor,
        input_ids_1: torch.Tensor,
        mask_0: torch.Tensor,
        mask_1: torch.Tensor,
        label: torch.Tensor,
    ):
        """Classify the pair (``input_ids_0``, ``input_ids_1``).

        Returns batch loss/accuracy/size.
        """
        mask_0 = mask_0.long()
        mask_1 = mask_1.long()
        if self.pooling_mode == Pooling.CLS:
            input_ids_0, mask_0 = append_cls(input_ids_0, mask_0, self.vocab_size)
            input_ids_1, mask_1 = append_cls(input_ids_1, mask_1, self.vocab_size)

        # Encode both sides in one pass by stacking them along the batch dim.
        joint_ids = torch.cat((input_ids_0, input_ids_1), dim=0)
        joint_mask = torch.cat((mask_0, mask_1), dim=0)
        encoded = self.norm(self.model(joint_ids, encoder_input_mask=joint_mask))
        encoded = encoded * joint_mask.unsqueeze(-1)

        # Split back into the two halves and score them jointly.
        halves = torch.chunk(encoded, 2, dim=0)
        seq_scores = self.seq_classifer(*halves)

        per_sample_loss = nn.functional.cross_entropy(
            seq_scores, label, reduction="none"
        )
        per_sample_hit = (seq_scores.argmax(dim=-1) == label).to(torch.float32)

        return {
            "loss": per_sample_loss.mean(),
            "accu": per_sample_hit.mean(),
            "count": label.size(0),
        }