First commit

This commit is contained in:
2025-08-05 19:02:46 +08:00
parent 9efe891f99
commit 99fb9f5cb0
1412 changed files with 203615 additions and 0 deletions

View File

@@ -0,0 +1,531 @@
# Copyright (c) 2022, Tri Dao.
# This BERT implementation is based on our MLPerf 2.0 and MLPerf 2.1 BERT implementation.
# https://github.com/mlcommons/training_results_v2.0/blob/main/HazyResearch/benchmarks/bert/implementations/pytorch/modeling.py
# https://github.com/mlcommons/training_results_v2.1/blob/main/Azure-HazyResearch/benchmarks/bert/implementations/ND96amsr_A100_v4/modeling.py
# Inspired by https://github.com/huggingface/transformers/blob/main/src/transformers/models/bert/modeling_bert.py
import re
import logging
from functools import partial
from collections.abc import Sequence
from collections import OrderedDict
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import BertConfig
from transformers.models.bert.modeling_bert import BaseModelOutputWithPoolingAndCrossAttentions
from transformers.models.bert.modeling_bert import BertForPreTrainingOutput
from einops import rearrange
from flash_attn.modules.mha import MHA
from flash_attn.modules.mlp import Mlp, FusedMLP
from flash_attn.modules.block import Block
from flash_attn.modules.embedding import BertEmbeddings
from flash_attn.bert_padding import unpad_input, pad_input
from flash_attn.bert_padding import index_first_axis, index_first_axis_residual
from flash_attn.utils.pretrained import state_dict_from_pretrained
try:
from flash_attn.ops.fused_dense import FusedDense
except ImportError:
FusedDense = None
try:
from flash_attn.ops.layer_norm import dropout_add_layer_norm, layer_norm
except ImportError:
dropout_add_layer_norm, layer_norm = None, None
try:
from flash_attn.losses.cross_entropy import CrossEntropyLoss
except ImportError:
CrossEntropyLoss = None
logger = logging.getLogger(__name__)
def create_mixer_cls(config, cross_attn=False, return_residual=False):
use_flash_attn = getattr(config, 'use_flash_attn', False)
fused_bias_fc = getattr(config, 'fused_bias_fc', False)
rotary_kwargs = {}
if config.position_embedding_type == "rotary":
rotary_kwargs["rotary_emb_dim"] = getattr(config, "rotary_emb_dim", config.hidden_size)
rotary_kwargs["rotary_emb_base"] = getattr(config, "rotary_emb_base", 10000.0)
rotary_kwargs["rotary_emb_scale_base"] = getattr(config, "rotary_emb_scale_base", None)
rotary_kwargs["rotary_emb_interleaved"] = getattr(config, "rotary_emb_interleaved", False)
mixer_cls = partial(MHA, num_heads=config.num_attention_heads, cross_attn=cross_attn,
dropout=config.attention_probs_dropout_prob, causal=False,
fused_bias_fc=fused_bias_fc, use_flash_attn=use_flash_attn,
return_residual=return_residual, **rotary_kwargs)
return mixer_cls
def create_mlp_cls(config, layer_idx=None, return_residual=False):
inner_dim = config.intermediate_size
fused_mlp = getattr(config, 'fused_mlp', False)
if fused_mlp:
assert config.hidden_act in ['gelu_new', 'gelu_fast'], ('fused_mlp only '
'supports approximate gelu')
if not fused_mlp:
approximate = 'tanh' if config.hidden_act in ['gelu_new', 'gelu_fast'] else 'none'
mlp_cls = partial(Mlp, hidden_features=inner_dim,
activation=partial(F.gelu, approximate=approximate),
return_residual=return_residual)
else:
if FusedMLP is None:
raise ImportError('fused_dense is not installed')
mlp_checkpoint_lvl = getattr(config, 'mlp_checkpoint_lvl', 0)
# mlp_checkpoint_lvl could be a list, which contains the checkpoint_lvl for each layer
if isinstance(mlp_checkpoint_lvl, Sequence):
assert layer_idx is not None
mlp_checkpoint_lvl = mlp_checkpoint_lvl[layer_idx]
mlp_cls = partial(FusedMLP, hidden_features=inner_dim,
checkpoint_lvl=mlp_checkpoint_lvl, return_residual=return_residual)
return mlp_cls
def create_block(config, layer_idx=None):
last_layer_subset = getattr(config, 'last_layer_subset', False)
cross_attn=last_layer_subset and layer_idx == config.num_hidden_layers - 1
# TD [2022-12-19]: For cross attention (last layer), we actually want to return the
# residual x_kv, not residual x. But it's annoying to change the API (and it only affects
# one layer) so we just choose not to return residual in this case.
return_residual = not cross_attn
mixer_cls = create_mixer_cls(config, cross_attn, return_residual=return_residual)
mlp_cls = create_mlp_cls(config, layer_idx, return_residual=return_residual)
norm_cls = partial(nn.LayerNorm, eps=config.layer_norm_eps)
block = Block(config.hidden_size, mixer_cls, mlp_cls, norm_cls=norm_cls,
prenorm=False, resid_dropout1=config.hidden_dropout_prob,
resid_dropout2=config.hidden_dropout_prob,
fused_dropout_add_ln=getattr(config, 'fused_dropout_add_ln', False),
return_residual=return_residual)
return block
# https://github.com/huggingface/transformers/blob/7032e0203262ebb2ebf55da8d2e01f873973e835/src/transformers/models/bert/modeling_bert.py#L748
def _init_weights(module, initializer_range=0.02):
if isinstance(module, nn.Linear):
nn.init.normal_(module.weight, std=initializer_range)
if module.bias is not None:
nn.init.zeros_(module.bias)
elif isinstance(module, nn.Embedding):
nn.init.normal_(module.weight, std=initializer_range)
if module.padding_idx is not None:
nn.init.zeros_(module.weight[module.padding_idx])
class BertEncoder(nn.Module):
def __init__(self, config: BertConfig):
super().__init__()
self.use_flash_attn = getattr(config, 'use_flash_attn', False)
self.layers = nn.ModuleList([create_block(config, layer_idx=i)
for i in range(config.num_hidden_layers)])
def forward(self, hidden_states, key_padding_mask=None, subset_mask=None):
"""If subset_mask is not None, we only want output for the subset of the sequence.
This means that we only compute the last layer output for these tokens.
subset_mask: (batch, seqlen), dtype=torch.bool
"""
if key_padding_mask is None or not self.use_flash_attn:
mixer_kwargs = ({'key_padding_mask': key_padding_mask}
if key_padding_mask is not None else None)
for layer in self.layers:
hidden_states = layer(hidden_states, mixer_kwargs=mixer_kwargs)
if subset_mask is not None:
hidden_states = hidden_states[subset_mask]
else:
batch, seqlen = hidden_states.shape[:2]
hidden_states, indices, cu_seqlens, max_seqlen_in_batch = unpad_input(
hidden_states, key_padding_mask
)
mixer_kwargs = {'cu_seqlens': cu_seqlens, 'max_seqlen': max_seqlen_in_batch}
if subset_mask is None:
for layer in self.layers:
hidden_states = layer(hidden_states, mixer_kwargs=mixer_kwargs)
hidden_states = pad_input(hidden_states, indices, batch, seqlen)
else:
for layer in self.layers[:-1]:
hidden_states = layer(hidden_states, mixer_kwargs=mixer_kwargs)
if key_padding_mask is not None:
subset_idx = torch.nonzero(subset_mask[key_padding_mask], as_tuple=False).flatten()
subset_seqlens = (subset_mask & key_padding_mask).sum(dim=-1, dtype=torch.int32)
subset_cu_seqlens = F.pad(torch.cumsum(subset_seqlens, dim=0,
dtype=torch.torch.int32), (1, 0))
else:
subset_idx = torch.nonzero(subset_mask, as_tuple=False).flatten()
subset_seqlens = subset_mask.sum(dim=-1, dtype=torch.int32)
subset_cu_seqlens = F.pad(torch.cumsum(subset_seqlens, dim=0,
dtype=torch.torch.int32), (1, 0))
hidden_states_subset, hidden_states = index_first_axis_residual(
hidden_states, subset_idx
)
# It's ok to set max_seqlen_q to be much larger
mixer_kwargs = {'x_kv': hidden_states,
'cu_seqlens': subset_cu_seqlens, 'max_seqlen': max_seqlen_in_batch,
'cu_seqlens_k': cu_seqlens, 'max_seqlen_k': max_seqlen_in_batch}
hidden_states = self.layers[-1](hidden_states_subset, mixer_kwargs=mixer_kwargs)
return hidden_states
class BertPooler(nn.Module):
def __init__(self, config):
super().__init__()
fused_bias_fc = getattr(config, 'fused_bias_fc', False)
if fused_bias_fc and FusedDense is None:
raise ImportError('fused_dense is not installed')
linear_cls = nn.Linear if not fused_bias_fc else FusedDense
self.dense = linear_cls(config.hidden_size, config.hidden_size)
self.activation = nn.Tanh()
def forward(self, hidden_states, pool=True):
# We "pool" the model by simply taking the hidden state corresponding
# to the first token.
first_token_tensor = hidden_states[:, 0] if pool else hidden_states
pooled_output = self.dense(first_token_tensor)
pooled_output = self.activation(pooled_output)
return pooled_output
class BertPredictionHeadTransform(nn.Module):
def __init__(self, config):
super().__init__()
fused_bias_fc = getattr(config, 'fused_bias_fc', False)
if fused_bias_fc and FusedDense is None:
raise ImportError('fused_dense is not installed')
self.fused_dropout_add_ln = getattr(config, 'fused_dropout_add_ln', False)
if self.fused_dropout_add_ln and layer_norm is None:
raise ImportError('dropout_add_layer_norm is not installed')
linear_cls = nn.Linear if not fused_bias_fc else FusedDense
self.dense = linear_cls(config.hidden_size, config.hidden_size)
approximate = 'tanh' if config.hidden_act in ['gelu_new', 'gelu_fast'] else 'none'
self.transform_act_fn = nn.GELU(approximate=approximate)
self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states = self.dense(hidden_states)
hidden_states = self.transform_act_fn(hidden_states)
if not self.fused_dropout_add_ln:
hidden_states = self.layer_norm(hidden_states)
else:
hidden_states = layer_norm(hidden_states, self.layer_norm.weight, self.layer_norm.bias,
self.layer_norm.eps)
return hidden_states
class BertLMPredictionHead(nn.Module):
def __init__(self, config):
super().__init__()
fused_bias_fc = getattr(config, 'fused_bias_fc', False)
if fused_bias_fc and FusedDense is None:
raise ImportError('fused_dense is not installed')
linear_cls = nn.Linear if not fused_bias_fc else FusedDense
self.transform = BertPredictionHeadTransform(config)
# The output weights are the same as the input embeddings, but there is
# an output-only bias for each token.
self.decoder = linear_cls(config.hidden_size, config.vocab_size, bias=True)
def forward(self, hidden_states):
hidden_states = self.transform(hidden_states)
hidden_states = self.decoder(hidden_states)
return hidden_states
class BertPreTrainingHeads(nn.Module):
def __init__(self, config):
super().__init__()
self.predictions = BertLMPredictionHead(config)
self.seq_relationship = nn.Linear(config.hidden_size, 2)
def forward(self, sequence_output, pooled_output):
prediction_scores = self.predictions(sequence_output)
seq_relationship_score = self.seq_relationship(pooled_output)
return prediction_scores, seq_relationship_score
class BertPreTrainedModel(nn.Module):
""" An abstract class to handle weights initialization and
a simple interface for dowloading and loading pretrained models.
"""
def __init__(self, config, *inputs, **kwargs):
super().__init__()
if not isinstance(config, BertConfig):
raise ValueError(
"Parameter config in `{}(config)` should be an instance of class `BertConfig`. "
"To create a model from a Google pretrained model use "
"`model = {}.from_pretrained(PRETRAINED_MODEL_NAME)`".format(
self.__class__.__name__, self.__class__.__name__
))
self.config = config
@classmethod
def from_pretrained(cls, model_name, config, *inputs, **kwargs):
"""
Instantiate a BertPreTrainedModel from a pre-trained model file or a pytorch state dict.
Download and cache the pre-trained model file if needed.
Params:
pretrained_model_name_or_path: either:
- a path or url to a pretrained model archive containing:
. `bert_config.json` a configuration file for the model
. `pytorch_model.bin` a PyTorch dump of a BertForPretraining instance
- a path or url to a pretrained model archive containing:
. `bert_config.json` a configuration file for the model
. `model.chkpt` a TensorFlow checkpoint
*inputs, **kwargs: additional input for the specific Bert class
(ex: num_labels for BertForSequenceClassification)
"""
# Instantiate model.
model = cls(config, *inputs, **kwargs)
load_return = model.load_state_dict(remap_state_dict(state_dict_from_pretrained(model_name),
config), strict=False)
logger.info(load_return)
return model
class BertModel(BertPreTrainedModel):
def __init__(self, config: BertConfig, add_pooling_layer=True):
super().__init__(config)
self.pad_vocab_size_multiple = getattr(config, 'pad_vocab_size_multiple', 1)
if config.vocab_size % self.pad_vocab_size_multiple != 0:
config.vocab_size += (self.pad_vocab_size_multiple
- (config.vocab_size % self.pad_vocab_size_multiple))
self.fused_dropout_add_ln = getattr(config, 'fused_dropout_add_ln', False)
if self.fused_dropout_add_ln and layer_norm is None:
raise ImportError('dropout_add_layer_norm is not installed')
assert config.hidden_act in ['gelu', 'gelu_new', 'gelu_fast']
self.embeddings = BertEmbeddings(config.hidden_size, config.vocab_size,
config.max_position_embeddings, config.type_vocab_size,
padding_idx=config.pad_token_id)
self.emb_drop = nn.Dropout(config.hidden_dropout_prob)
self.emb_ln = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.encoder = BertEncoder(config)
self.pooler = BertPooler(config) if add_pooling_layer else None
self.apply(partial(_init_weights, initializer_range=config.initializer_range))
def forward(self, input_ids, position_ids=None, token_type_ids=None, attention_mask=None,
masked_tokens_mask=None):
"""If masked_tokens_mask is not None (i.e. last_layer_subset == True in BertForPreTraining),
we only want the output for the masked tokens. This means that we only compute the last
layer output for these tokens.
masked_tokens_mask: (batch, seqlen), dtype=torch.bool
"""
hidden_states = self.embeddings(input_ids, position_ids=position_ids,
token_type_ids=token_type_ids)
# TD [2022-12:18]: Don't need to force residual in fp32
# BERT puts embedding LayerNorm before embedding dropout.
if not self.fused_dropout_add_ln:
hidden_states = self.emb_ln(hidden_states)
else:
hidden_states = layer_norm(hidden_states, self.emb_ln.weight, self.emb_ln.bias,
self.emb_ln.eps)
hidden_states = self.emb_drop(hidden_states)
if masked_tokens_mask is not None:
batch_size, seqlen = input_ids.shape[:2]
# We also need the first column for the CLS token
first_col_mask = torch.zeros(batch_size, seqlen, dtype=torch.bool,
device=input_ids.device)
first_col_mask[:, 0] = True
subset_mask = masked_tokens_mask | first_col_mask
else:
subset_mask = None
sequence_output = self.encoder(hidden_states, key_padding_mask=attention_mask,
subset_mask=subset_mask)
if masked_tokens_mask is None:
pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
else:
# TD [2022-03-01]: the indexing here is very tricky.
if attention_mask is not None:
subset_idx = subset_mask[attention_mask]
pool_input = sequence_output[first_col_mask[attention_mask][subset_idx]]
sequence_output = sequence_output[masked_tokens_mask[attention_mask][subset_idx]]
else:
pool_input = sequence_output[first_col_mask[subset_mask]]
sequence_output = sequence_output[masked_tokens_mask[subset_mask]]
pooled_output = (self.pooler(pool_input, pool=False)
if self.pooler is not None else None)
return BaseModelOutputWithPoolingAndCrossAttentions(
last_hidden_state=sequence_output,
pooler_output=pooled_output,
)
class BertForPreTraining(BertPreTrainedModel):
def __init__(self, config: BertConfig):
super().__init__(config)
# If dense_seq_output, we only need to pass the hidden states for the masked out tokens
# (around 15%) to the classifier heads.
self.dense_seq_output = getattr(config, 'dense_seq_output', False)
# If last_layer_subset, we only need the compute the last layer for a subset of tokens
# (e.g., the tokens we need to compute the masked LM loss and the next-sentence prediction).
self.last_layer_subset = getattr(config, 'last_layer_subset', False)
if self.last_layer_subset:
assert self.dense_seq_output, 'last_layer_subset requires dense_seq_output'
use_xentropy = getattr(config, 'use_xentropy', False)
if use_xentropy and CrossEntropyLoss is None:
raise ImportError('xentropy_cuda is not installed')
loss_cls = (nn.CrossEntropyLoss if not use_xentropy
else partial(CrossEntropyLoss, inplace_backward=True))
self.bert = BertModel(config)
self.cls = BertPreTrainingHeads(config)
self.mlm_loss = loss_cls(ignore_index=0)
self.nsp_loss = loss_cls(ignore_index=-1)
# Initialize weights and apply final processing
self.apply(partial(_init_weights, initializer_range=config.initializer_range))
self.tie_weights()
def tie_weights(self):
self.cls.predictions.decoder.weight = self.bert.embeddings.word_embeddings.weight
def forward(self, input_ids, position_ids=None, token_type_ids=None, attention_mask=None,
labels=None, next_sentence_label=None):
"""
If labels are provided, they must be 0 for masked out tokens (as specified in the attention
mask).
Outputs:
if `labels` and `next_sentence_label` are not `None`:
Outputs the total_loss which is the sum of the masked language modeling loss and the next
sentence classification loss.
if `labels` or `next_sentence_label` is `None`:
Outputs a tuple comprising
- the masked language modeling logits of shape [batch_size, sequence_length, vocab_size], and
- the next sentence classification logits of shape [batch_size, 2].
"""
masked_tokens_mask = labels > 0 if (self.last_layer_subset and labels is not None) else None
outputs = self.bert(
input_ids, position_ids=position_ids, token_type_ids=token_type_ids,
attention_mask=attention_mask.bool() if attention_mask is not None else None,
masked_tokens_mask=masked_tokens_mask
)
sequence_output, pooled_output = outputs.last_hidden_state, outputs.pooler_output
if self.dense_seq_output and labels is not None:
masked_token_idx = torch.nonzero(labels.flatten() > 0, as_tuple=False).flatten()
if not self.last_layer_subset:
sequence_output = index_first_axis(rearrange(sequence_output, 'b s d -> (b s) d'),
masked_token_idx)
prediction_scores, seq_relationship_score = self.cls(sequence_output, pooled_output)
total_loss = None
if labels is not None and next_sentence_label is not None:
if self.dense_seq_output and labels is not None: # prediction_scores are already flattened
masked_lm_loss = self.mlm_loss(prediction_scores,
labels.flatten()[masked_token_idx])
else:
masked_lm_loss = self.mlm_loss(rearrange(prediction_scores, '... v -> (...) v'),
rearrange(labels, '... -> (...)'))
next_sentence_loss = self.nsp_loss(rearrange(seq_relationship_score, '... t -> (...) t'),
rearrange(next_sentence_label, '... -> (...)'))
total_loss = masked_lm_loss.float() + next_sentence_loss.float()
return BertForPreTrainingOutput(
loss=total_loss,
prediction_logits=prediction_scores,
seq_relationship_logits=seq_relationship_score,
)
def remap_state_dict(state_dict, config):
# LayerNorm
def key_mapping_ln_gamma_beta(key):
key = re.sub(r'LayerNorm.gamma$', 'LayerNorm.weight', key)
key = re.sub(r'LayerNorm.beta$', 'LayerNorm.bias', key)
return key
state_dict = OrderedDict((key_mapping_ln_gamma_beta(k), v) for k, v in state_dict.items())
# Layers
def key_mapping_layers(key):
return re.sub(r'^bert.encoder.layer.', 'bert.encoder.layers.', key)
state_dict = OrderedDict((key_mapping_layers(k), v) for k, v in state_dict.items())
# LayerNorm
def key_mapping_ln(key):
key = re.sub(r'^bert.embeddings.LayerNorm.', 'bert.emb_ln.', key)
key = re.sub(r'^bert.encoder.layers.(\d+).attention.output.LayerNorm.(weight|bias)',
r'bert.encoder.layers.\1.norm1.\2', key)
key = re.sub(r'^bert.encoder.layers.(\d+).output.LayerNorm.(weight|bias)',
r'bert.encoder.layers.\1.norm2.\2', key)
key = re.sub(r'^cls.predictions.transform.LayerNorm.(weight|bias)',
r'cls.predictions.transform.layer_norm.\1', key)
return key
state_dict = OrderedDict((key_mapping_ln(k), v) for k, v in state_dict.items())
# MLP
def key_mapping_mlp(key):
key = re.sub(r'^bert.encoder.layers.(\d+).intermediate.dense.(weight|bias)',
r'bert.encoder.layers.\1.mlp.fc1.\2', key)
key = re.sub(r'^bert.encoder.layers.(\d+).output.dense.(weight|bias)',
r'bert.encoder.layers.\1.mlp.fc2.\2', key)
return key
state_dict = OrderedDict((key_mapping_mlp(k), v) for k, v in state_dict.items())
# Attention
last_layer_subset = getattr(config, 'last_layer_subset', False)
for d in range(config.num_hidden_layers):
Wq = state_dict.pop(f'bert.encoder.layers.{d}.attention.self.query.weight')
Wk = state_dict.pop(f'bert.encoder.layers.{d}.attention.self.key.weight')
Wv = state_dict.pop(f'bert.encoder.layers.{d}.attention.self.value.weight')
bq = state_dict.pop(f'bert.encoder.layers.{d}.attention.self.query.bias')
bk = state_dict.pop(f'bert.encoder.layers.{d}.attention.self.key.bias')
bv = state_dict.pop(f'bert.encoder.layers.{d}.attention.self.value.bias')
if not (last_layer_subset and d == config.num_hidden_layers - 1):
state_dict[f'bert.encoder.layers.{d}.mixer.Wqkv.weight'] = torch.cat(
[Wq, Wk, Wv], dim=0
)
state_dict[f'bert.encoder.layers.{d}.mixer.Wqkv.bias'] = torch.cat([bq, bk, bv], dim=0)
else:
state_dict[f'bert.encoder.layers.{d}.mixer.Wq.weight'] = Wq
state_dict[f'bert.encoder.layers.{d}.mixer.Wkv.weight'] = torch.cat(
[Wk, Wv], dim=0
)
state_dict[f'bert.encoder.layers.{d}.mixer.Wq.bias'] = bq
state_dict[f'bert.encoder.layers.{d}.mixer.Wkv.bias'] = torch.cat([bk, bv], dim=0)
def key_mapping_attn(key):
return re.sub(r'^bert.encoder.layers.(\d+).attention.output.dense.(weight|bias)',
r'bert.encoder.layers.\1.mixer.out_proj.\2', key)
state_dict = OrderedDict((key_mapping_attn(k), v) for k, v in state_dict.items())
def key_mapping_decoder_bias(key):
return re.sub(r'^cls.predictions.bias', 'cls.predictions.decoder.bias', key)
state_dict = OrderedDict((key_mapping_decoder_bias(k), v) for k, v in state_dict.items())
# Word embedding
pad_vocab_size_multiple = getattr(config, 'pad_vocab_size_multiple', 1)
if pad_vocab_size_multiple > 1:
word_embeddings = state_dict['bert.embeddings.word_embeddings.weight']
state_dict['bert.embeddings.word_embeddings.weight'] = F.pad(
word_embeddings, (0, 0, 0, config.vocab_size - word_embeddings.shape[0])
)
decoder_weight = state_dict['cls.predictions.decoder.weight']
state_dict['cls.predictions.decoder.weight'] = F.pad(
decoder_weight, (0, 0, 0, config.vocab_size - decoder_weight.shape[0])
)
# If the vocab was padded, we want to set the decoder bias for those padded indices to be
# strongly negative (i.e. the decoder shouldn't predict those indices).
# TD [2022-05-09]: I don't think it affects the MLPerf training.
decoder_bias = state_dict['cls.predictions.decoder.bias']
state_dict['cls.predictions.decoder.bias'] = F.pad(
decoder_bias, (0, config.vocab_size - decoder_bias.shape[0]), value=-100.0
)
return state_dict

View File

@@ -0,0 +1,122 @@
# Copyright (c) 2023, Tri Dao.
import math
import re
from collections import OrderedDict
import torch
import torch.nn.functional as F
from einops import rearrange
from transformers import GPT2Config, FalconConfig
def remap_state_dict_hf_falcon(state_dict, config):
def key_mapping_layers(key):
return re.sub(r'^transformer.h.', 'transformer.layers.', key)
state_dict = OrderedDict((key_mapping_layers(k), v) for k, v in state_dict.items())
# Word embedding
def key_mapping_emb(key):
return re.sub(r'^transformer.word_embeddings.', 'transformer.embeddings.word_embeddings.', key)
state_dict = OrderedDict((key_mapping_emb(k), v) for k, v in state_dict.items())
word_embeddings = state_dict.pop('transformer.embeddings.word_embeddings.weight')
# It's possible that vocab_size is padded to be a multiple of 8, for example.
pad_vocab_size_multiple = getattr(config, 'pad_vocab_size_multiple', 1)
vocab_size = (math.ceil(config.vocab_size / pad_vocab_size_multiple) * pad_vocab_size_multiple)
state_dict['transformer.embeddings.word_embeddings.weight'] = F.pad(
word_embeddings, (0, 0, 0, vocab_size - word_embeddings.shape[0])
)
if getattr(config, 'tie_word_embeddings'):
state_dict['lm_head.weight'] = state_dict['transformer.embeddings.word_embeddings.weight']
else:
output_embeddings = state_dict.pop('lm_head.weight')
# It's possible that vocab_size is padded to be a multiple of 8, for example.
state_dict['lm_head.weight'] = F.pad(
output_embeddings, (0, 0, 0, vocab_size - output_embeddings.shape[0])
)
output_embeddings_bias = state_dict.pop('lm_head.bias')
state_dict['lm_head.bias'] = F.pad(
output_embeddings_bias, (0, vocab_size - output_embeddings_bias.shape[0])
)
# LayerNorm
def key_mapping_ln(key):
key = re.sub(r'^transformer.layers.(\d+).input_layernorm.',
r'transformer.layers.\1.norm1.', key)
key = re.sub(r'^transformer.layers.(\d+).post_attention_layernorm.',
r'transformer.layers.\1.norm2.', key)
key = re.sub(r'^transformer.layers.(\d+).ln_attn.', r'transformer.layers.\1.norm1.', key)
key = re.sub(r'^transformer.layers.(\d+).ln_mlp.', r'transformer.layers.\1.norm2.', key)
return key
state_dict = OrderedDict((key_mapping_ln(k), v) for k, v in state_dict.items())
# MLP
def key_mapping_mlp(key):
key = re.sub(r'^transformer.layers.(\d+).mlp.dense_h_to_4h.',
r'transformer.layers.\1.mlp.fc1.', key)
key = re.sub(r'^transformer.layers.(\d+).mlp.dense_4h_to_h.',
r'transformer.layers.\1.mlp.fc2.', key)
return key
state_dict = OrderedDict((key_mapping_mlp(k), v) for k, v in state_dict.items())
def key_mapping_attn(key):
key = re.sub(r'^transformer.layers.(\d+).self_attention.query_key_value.',
r'transformer.layers.\1.mixer.Wqkv.', key)
key = re.sub(r'^transformer.layers.(\d+).self_attention.dense.',
r'transformer.layers.\1.mixer.out_proj.', key)
return key
state_dict = OrderedDict((key_mapping_attn(k), v) for k, v in state_dict.items())
n_head = config.n_head
n_head_kv = getattr(config, "n_head_kv", 1)
headdim = config.hidden_size // n_head
for l in range(config.n_layer):
# The weights are stored in a different layout compared to our implementation
Wqkv = rearrange(state_dict.pop(f'transformer.layers.{l}.mixer.Wqkv.weight'),
"(group ratio headdim) ... -> group ratio headdim ...",
ratio=n_head // n_head_kv + 2, headdim=headdim)
Wq = rearrange(Wqkv[:, :-2], "group ratio headdim ... -> (group ratio headdim) ...")
Wk = rearrange(Wqkv[:, [-2]], "group ratio headdim ... -> (group ratio headdim) ...")
Wv = rearrange(Wqkv[:, [-1]], "group ratio headdim ... -> (group ratio headdim) ...")
state_dict[f'transformer.layers.{l}.mixer.Wqkv.weight'] = torch.cat([Wq, Wk, Wv], dim=0)
return state_dict
def falcon_config_to_gpt2_config(falcon_config: FalconConfig) -> GPT2Config:
# The 40b config uses "n_head_kv" instead of "num_kv_heads"
n_head_kv = getattr(falcon_config, "n_head_kv",
1 if getattr(falcon_config, "multi_query", False)
else falcon_config.n_head)
# HACK: the 40b config has 2 LN per layer instead of 1, but that's not reflected in the config.
# So we have to infer it from the number of heads in the key/value block
parallel_block_tied_norm = n_head_kv == 1
return GPT2Config(
vocab_size=falcon_config.vocab_size,
n_positions=0, # No absolute position embedding
n_embd=falcon_config.hidden_size,
n_layer=falcon_config.n_layer,
n_head=falcon_config.n_head,
n_inner=falcon_config.hidden_size * 4,
activation_function="gelu",
resid_pdrop=falcon_config.hidden_dropout,
embd_pdrop=0.0, # There doesn't seem to be any embedding dropout
attn_pdrop=falcon_config.attention_dropout,
layer_norm_epsilon=falcon_config.layer_norm_epsilon,
initializer_range=falcon_config.initializer_range,
bos_token_id=falcon_config.bos_token_id,
eos_token_id=falcon_config.eos_token_id,
# These are new arguments not in the original GPT2Config
parallel_block=falcon_config.parallel_attn,
n_head_kv=n_head_kv,
parallel_block_tied_norm=parallel_block_tied_norm,
rotary_emb_fraction=1.0,
rotary_emb_interleaved=False,
tie_word_embeddings=True,
qkv_proj_bias=falcon_config.bias,
out_proj_bias=falcon_config.bias,
mlp_fc1_bias=falcon_config.bias,
mlp_fc2_bias=falcon_config.bias,
lm_head_bias=False,
)

View File

@@ -0,0 +1,740 @@
# Copyright (c) 2023, Tri Dao.
import logging
import math
import re
from functools import partial
from collections import namedtuple, OrderedDict
from collections.abc import Sequence
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import GPT2Config
from einops import rearrange
from flash_attn.ops.activations import sqrelu_fwd
from flash_attn.modules.mha import MHA, ParallelMHA
from flash_attn.modules.mlp import Mlp, GatedMlp, ParallelMLP, FusedMLP, ParallelFusedMLP
from flash_attn.modules.block import Block, ParallelBlock
from flash_attn.modules.embedding import GPT2Embeddings, ParallelGPT2Embeddings
from flash_attn.utils.distributed import sync_shared_params, all_gather_raw
from flash_attn.utils.pretrained import state_dict_from_pretrained
from flash_attn.utils.generation import GenerationMixin
from flash_attn.models.opt import remap_state_dict_hf_opt
from flash_attn.models.gptj import remap_state_dict_hf_gptj
from flash_attn.models.gpt_neox import remap_state_dict_hf_gpt_neox
from flash_attn.models.falcon import remap_state_dict_hf_falcon
try:
from flash_attn.ops.fused_dense import ColumnParallelLinear
except ImportError:
ColumnParallelLinear = None
try:
from flash_attn.ops.layer_norm import dropout_add_layer_norm
except ImportError:
dropout_add_layer_norm = None
try:
from flash_attn.ops.layer_norm import dropout_add_layer_norm_parallel_residual
except ImportError:
dropout_add_layer_norm_parallel_residual = None
try:
from flash_attn.ops.rms_norm import RMSNorm, dropout_add_rms_norm
except ImportError:
RMSNorm, dropout_add_rms_norm = None, None
try:
from flash_attn.ops.rms_norm import dropout_add_rms_norm_parallel_residual
except ImportError:
dropout_add_rms_norm_parallel_residual = None
try:
from flash_attn.ops.triton.mlp import FusedDenseSqreluDense
except ImportError:
FusedDenseSqreluDense = None
logger = logging.getLogger(__name__)
def create_mixer_cls(config, layer_idx=None, process_group=None, device=None, dtype=None):
factory_kwargs = {'device': device, 'dtype': dtype}
head_dim = getattr(config, 'head_dim', config.hidden_size // config.num_attention_heads)
softmax_scale = 1.0 if not config.scale_attn_weights else head_dim ** (-0.5)
if config.scale_attn_by_inverse_layer_idx:
assert layer_idx is not None
softmax_scale /= float(layer_idx + 1)
dwconv = getattr(config, 'attn_dwconv', False)
if dwconv:
assert process_group is None, 'TensorParallel MHA does not support dwconv yet'
qkv_proj_bias = getattr(config, 'qkv_proj_bias', True)
out_proj_bias = getattr(config, 'out_proj_bias', True)
rotary_emb_dim = int(getattr(config, 'rotary_emb_fraction', 0.0) * head_dim)
rotary_emb_base = getattr(config, 'rotary_emb_base', 10000.0)
rotary_emb_scale_base = getattr(config, 'rotary_emb_scale_base', None)
rotary_emb_interleaved = getattr(config, 'rotary_emb_interleaved', False)
use_flash_attn = getattr(config, 'use_flash_attn', False)
fused_bias_fc = getattr(config, 'fused_bias_fc', False)
if not fused_bias_fc:
assert process_group is None, 'TensorParallel MHA requires fused_bias_fc'
mha_cls = MHA if process_group is None else ParallelMHA
serial_kwargs = ({'fused_bias_fc': fused_bias_fc, 'dwconv': dwconv}
if process_group is None else {})
parallel_kwargs = ({'process_group': process_group,
'sequence_parallel': getattr(config, 'sequence_parallel', True)}
if process_group is not None else {})
num_heads_kv = getattr(config, "n_head_kv", None)
mixer_cls = partial(mha_cls, num_heads=config.num_attention_heads,
num_heads_kv=num_heads_kv,
qkv_proj_bias=qkv_proj_bias, out_proj_bias=out_proj_bias,
dropout=config.attn_pdrop,
softmax_scale=softmax_scale, causal=True, layer_idx=layer_idx,
rotary_emb_dim=rotary_emb_dim, rotary_emb_base=rotary_emb_base,
rotary_emb_scale_base=rotary_emb_scale_base,
rotary_emb_interleaved=rotary_emb_interleaved,
use_flash_attn=use_flash_attn,
**serial_kwargs, **parallel_kwargs, **factory_kwargs)
return mixer_cls
def create_mlp_cls(config, layer_idx=None, process_group=None, device=None, dtype=None):
factory_kwargs = {'device': device, 'dtype': dtype}
mlp_fc1_bias = getattr(config, 'mlp_fc1_bias', True)
mlp_fc2_bias = getattr(config, 'mlp_fc2_bias', True)
fused_mlp = getattr(config, 'fused_mlp', False)
if fused_mlp:
assert config.activation_function in ['gelu_new', 'gelu_fast', 'gelu_approx', 'relu', 'sqrelu']
fused_dense_sqrelu_dense = getattr(config, 'fused_dense_sqrelu_dense', False)
if fused_dense_sqrelu_dense:
assert config.activation_function == 'sqrelu', ('fused_dense_sqrelu_dense only '
'supports approximate activation_function sqrelu')
assert not (fused_dense_sqrelu_dense and fused_mlp)
if not fused_mlp and not fused_dense_sqrelu_dense:
assert config.activation_function in ['gelu', 'gelu_new', 'gelu_fast', 'gelu_approx', 'relu',
'sqrelu', 'glu', 'swiglu', 'geglu']
if config.activation_function in ['glu', 'swiglu', 'geglu']:
activation = (F.sigmoid if config.activation_function == 'glu'
else (F.silu if config.activation_function == 'swiglu'
else F.gelu))
mlp_cls = partial(GatedMlp, hidden_features=config.n_inner, activation=activation,
bias1=mlp_fc1_bias, bias2=mlp_fc2_bias, **factory_kwargs)
else:
if config.activation_function == 'relu':
activation = partial(F.relu, inplace=True)
elif config.activation_function == 'sqrelu':
activation = sqrelu_fwd
else:
approximate = ('tanh' if config.activation_function
in ['gelu_new', 'gelu_fast', 'gelu_approx'] else 'none')
activation=partial(F.gelu, approximate=approximate)
mlp_cls = Mlp if process_group is None else ParallelMLP
parallel_kwargs = ({'process_group': process_group,
'sequence_parallel': getattr(config, 'sequence_parallel', True)}
if process_group is not None else {})
mlp_cls = partial(mlp_cls, hidden_features=config.n_inner, activation=activation,
bias1=mlp_fc1_bias, bias2=mlp_fc2_bias,
**parallel_kwargs, **factory_kwargs)
else:
mlp_checkpoint_lvl = getattr(config, 'mlp_checkpoint_lvl', 0)
# mlp_checkpoint_lvl could be a list, which contains the checkpoint_lvl for each layer
if isinstance(mlp_checkpoint_lvl, Sequence):
assert layer_idx is not None
mlp_checkpoint_lvl = mlp_checkpoint_lvl[layer_idx]
if fused_mlp:
if FusedMLP is None:
raise ImportError('fused_dense is not installed')
activation = ('gelu_approx' if config.activation_function
in ['gelu_new', 'gelu_fast', 'gelu_approx'] else config.activation_function)
mlp_cls = FusedMLP if process_group is None else ParallelFusedMLP
parallel_kwargs = ({'process_group': process_group,
'sequence_parallel': getattr(config, 'sequence_parallel', True)}
if process_group is not None else {})
mlp_cls = partial(mlp_cls, hidden_features=config.n_inner, activation=activation,
checkpoint_lvl=mlp_checkpoint_lvl,
bias1=mlp_fc1_bias, bias2=mlp_fc2_bias,
**parallel_kwargs, **factory_kwargs)
elif fused_dense_sqrelu_dense:
assert FusedDenseSqreluDense is not None
mlp_cls = partial(FusedDenseSqreluDense, hidden_features=config.n_inner,
checkpoint_lvl=mlp_checkpoint_lvl, **factory_kwargs)
else:
raise RuntimeError('MLP type not supported')
return mlp_cls
def create_block(config, layer_idx=None, process_group=None, device=None, dtype=None):
factory_kwargs = {'device': device, 'dtype': dtype}
sequence_parallel = getattr(config, 'sequence_parallel', True)
mixer_cls = create_mixer_cls(config, layer_idx, process_group=process_group, **factory_kwargs)
mlp_cls = create_mlp_cls(config, layer_idx, process_group=process_group, **factory_kwargs)
use_rms_norm = getattr(config, 'rms_norm', False)
norm_cls = partial(nn.LayerNorm if not use_rms_norm else RMSNorm,
eps=config.layer_norm_epsilon, **factory_kwargs)
# TD [2022-07-30]: Force residual in fp32, seems to make fp16 training more stable
residual_in_fp32 = getattr(config, 'residual_in_fp32', False)
resid_dropout1 = config.resid_pdrop if layer_idx is None or layer_idx > 0 else config.embd_pdrop
prenorm = getattr(config, 'prenorm', True)
parallel_block = getattr(config, 'parallel_block', False)
if not parallel_block:
block = Block(
config.hidden_size, mixer_cls, mlp_cls, norm_cls=norm_cls,
prenorm=prenorm, resid_dropout1=resid_dropout1, resid_dropout2=config.resid_pdrop,
fused_dropout_add_ln=getattr(config, 'fused_dropout_add_ln', False),
residual_in_fp32=residual_in_fp32,
sequence_parallel=sequence_parallel and process_group is not None,
mark_shared_params=process_group is not None
)
else:
assert prenorm
block = ParallelBlock(
config.hidden_size, mixer_cls, mlp_cls, norm_cls=norm_cls,
resid_dropout1=resid_dropout1, resid_dropout2=config.resid_pdrop,
tied_norm=getattr(config, 'parallel_block_tied_norm', False),
fused_dropout_add_ln=getattr(config, 'fused_dropout_add_ln', False),
residual_in_fp32=residual_in_fp32,
sequence_parallel=sequence_parallel and process_group is not None,
mark_shared_params=process_group is not None
)
block.layer_idx = layer_idx
return block
class GPTPreTrainedModel(nn.Module):
""" An abstract class to handle weights initialization and
a simple interface for dowloading and loading pretrained models.
"""
def __init__(self, config, *inputs, **kwargs):
super().__init__()
if not isinstance(config, GPT2Config):
raise ValueError(
"Parameter config in `{}(config)` should be an instance of class `GPT2Config`. "
"To create a model from a Google pretrained model use "
"`model = {}.from_pretrained(PRETRAINED_MODEL_NAME)`".format(
self.__class__.__name__, self.__class__.__name__
))
self.config = config
@classmethod
def from_pretrained(cls, model_name, config, *args, strict=True, device=None, dtype=None,
world_size=1, rank=0, **kwargs):
"""
Instantiate a GPTPreTrainedModel from a pre-trained model file or a pytorch state dict.
Download and cache the pre-trained model file if needed.
"""
# Instantiate model.
model = cls(config, *args, device=device, dtype=dtype, **kwargs)
# Load state_dict in cpu because we already initialized the model in GPU, and we don't
# want extra stuff taking up more GPU memory
state_dict = state_dict_from_pretrained(
model_name, device='cpu', dtype=dtype
)
if model_name.startswith('gpt2'):
state_dict = remap_state_dict_hf_gpt2(state_dict, config)
elif model_name.startswith('facebook/opt'):
state_dict = remap_state_dict_hf_opt(state_dict, config)
elif model_name.startswith('EleutherAI/gpt-j-'):
state_dict = remap_state_dict_hf_gptj(state_dict, config)
elif model_name.startswith('EleutherAI/gpt-neox-'):
state_dict = remap_state_dict_hf_gpt_neox(state_dict, config)
elif model_name.startswith('tiiuae/falcon-'):
state_dict = remap_state_dict_hf_falcon(state_dict, config)
else:
raise NotImplementedError(f'Model {model_name} not supported')
if world_size > 1:
state_dict = shard_state_dict_tp(state_dict, config, world_size, rank)
load_return = model.load_state_dict(state_dict, strict=strict)
logger.info(load_return)
return model
# https://github.com/huggingface/transformers/blob/c28d04e9e252a1a099944e325685f14d242ecdcd/src/transformers/models/gpt2/modeling_gpt2.py#L454
def _init_weights(module, n_layer, initializer_range=0.02, rescale_prenorm_residual=True):
if isinstance(module, nn.Linear):
nn.init.normal_(module.weight, std=initializer_range)
if module.bias is not None:
nn.init.zeros_(module.bias)
elif isinstance(module, nn.Embedding):
nn.init.normal_(module.weight, std=initializer_range)
if rescale_prenorm_residual:
# Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme:
# > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale
# > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers.
# > -- GPT-2 :: https://openai.com/blog/better-language-models/
#
# Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py
for name, p in module.named_parameters():
if name in ["out_proj.weight", "fc2.weight"]:
# Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block
nn.init.normal_(p, mean=0.0, std=initializer_range / math.sqrt(2 * n_layer))
class GPTModel(GPTPreTrainedModel):
def __init__(self, config: GPT2Config, process_group=None, device=None, dtype=None):
super().__init__(config)
factory_kwargs = {'device': device, 'dtype': dtype}
self.process_group = process_group
self.sequence_parallel = getattr(config, 'sequence_parallel', True)
assert config.activation_function in ['gelu', 'gelu_new', 'gelu_fast', 'gelu_approx',
'relu', 'sqrelu', 'glu', 'swiglu', 'geglu']
pad_vocab_size_multiple = getattr(config, 'pad_vocab_size_multiple', 1)
vocab_size = (math.ceil(config.vocab_size / pad_vocab_size_multiple)
* pad_vocab_size_multiple)
# TD [2022-07-30]: Force residual in fp32, seems to make fp16 training more stable
self.residual_in_fp32 = getattr(config, 'residual_in_fp32', False)
# These 2 options are for OPT-350m
self.prenorm = getattr(config, 'prenorm', True)
use_rms_norm = getattr(config, 'rms_norm', False)
word_embed_proj_dim = getattr(config, 'word_embed_proj_dim', None)
# For GPT-J, GPT-NeoX
self.parallel_block = getattr(config, 'parallel_block', False)
if process_group is None:
self.embeddings = GPT2Embeddings(
config.hidden_size, vocab_size, config.max_position_embeddings,
word_embed_proj_dim=word_embed_proj_dim, **factory_kwargs
)
else:
self.embeddings = ParallelGPT2Embeddings(
config.hidden_size, vocab_size, config.max_position_embeddings,
process_group=process_group, sequence_parallel=self.sequence_parallel,
**factory_kwargs
)
# We change the order of dropout, residual and layer norm:
# Instead of LN -> Attn / MLP -> Dropout -> Add, we do:
# Dropout -> Add -> LN -> Attn / MLP, returning both the residual branch (output of Add) and
# the main branch (output of MLP). The model definition is unchanged, but the mapping of the
# nn.Dropout probabilities are changed.
# This is for performance reason: we can fuse dropout + add + layer_norm.
self.layers = nn.ModuleList([create_block(config, layer_idx=i, process_group=process_group,
**factory_kwargs)
for i in range(config.num_hidden_layers)])
self.fused_dropout_add_ln = getattr(config, 'fused_dropout_add_ln', False)
if self.fused_dropout_add_ln:
if ((not self.parallel_block and dropout_add_layer_norm is None)
or (self.parallel_block and dropout_add_layer_norm_parallel_residual is None)):
raise ImportError('dropout_layer_norm is not installed')
if self.prenorm:
self.drop_f = nn.Dropout(config.resid_pdrop)
norm_cls = nn.LayerNorm if not use_rms_norm else RMSNorm
self.ln_f = norm_cls(config.hidden_size, eps=config.layer_norm_epsilon,
**factory_kwargs)
if process_group is not None:
for p in self.ln_f.parameters():
# Mark the norm parameters as "shared_params" so that we sync their values at init.
p._shared_params = True
# Mark the norm params as "sequence_parallel" so we run all-reduce on their grads.
if self.sequence_parallel:
p._sequence_parallel = True
self.apply(partial(_init_weights, n_layer=config.num_hidden_layers,
initializer_range=config.initializer_range))
self.tie_weights()
def tie_weights(self):
if self.process_group is not None:
sync_shared_params(self, self.process_group)
def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
return {i: layer.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs)
for i, layer in enumerate(self.layers)}
def forward(self, input_ids, position_ids=None, inference_params=None):
# If using Tensor Parallel with sequence parallel, we combine the batch and the seqlen
# dimensions so that we can split on it easily, in case of small batch size.
# Only the attention layers need to know the seqlen.
embedding_kwargs = ({'combine_batch_seqlen_dim': True}
if self.process_group is not None and self.sequence_parallel else {})
hidden_states = self.embeddings(input_ids, position_ids=position_ids, **embedding_kwargs)
if self.parallel_block:
hidden_states2 = None
residual = None
mixer_kwargs = ({'seqlen': input_ids.shape[1]}
if self.process_group is not None and self.sequence_parallel else {})
if inference_params is not None:
mixer_kwargs['inference_params'] = inference_params
for layer in self.layers:
if self.prenorm:
if not self.parallel_block:
hidden_states, residual = layer(hidden_states, residual,
mixer_kwargs=mixer_kwargs)
else:
hidden_states, hidden_states2, residual = layer(
hidden_states, hidden_states2, residual, mixer_kwargs=mixer_kwargs
)
else:
hidden_states = layer(hidden_states, mixer_kwargs=mixer_kwargs)
if self.prenorm:
if not self.fused_dropout_add_ln:
dropped = self.drop_f(hidden_states)
if not self.parallel_block:
residual = (dropped + residual) if residual is not None else dropped
else:
dropped2 = self.drop_f(hidden_states2)
residual = ((residual + dropped + dropped2)
if residual is not None else dropped + dropped2)
hidden_states = self.ln_f(residual.to(dtype=self.ln_f.weight.dtype))
else:
# Set prenorm=False here since we don't need the residual
if not self.parallel_block:
fused_add_norm_fn = (dropout_add_rms_norm if isinstance(self.ln_f, RMSNorm)
else dropout_add_layer_norm)
hidden_states = fused_add_norm_fn(
hidden_states, residual, self.ln_f.weight, self.ln_f.bias,
self.drop_f.p if self.training else 0.0, self.ln_f.eps, prenorm=False,
residual_in_fp32=self.residual_in_fp32
)
else:
fused_add_norm_fn = (dropout_add_rms_norm_parallel_residual
if isinstance(self.ln_f, RMSNorm)
else dropout_add_layer_norm_parallel_residual)
hidden_states, _ = fused_add_norm_fn(
hidden_states, hidden_states2, residual, self.ln_f.weight, self.ln_f.bias,
None, None, self.drop_f.p if self.training else 0.0, self.ln_f.eps,
prenorm=False, residual_in_fp32=self.residual_in_fp32
)
return hidden_states
class GPTLMHeadModel(GPTPreTrainedModel, GenerationMixin):
def __init__(self, config: GPT2Config, process_group=None, device=None, dtype=None):
factory_kwargs = {'device': device, 'dtype': dtype}
super().__init__(config)
self.process_group = process_group
self.transformer = GPTModel(config, process_group=process_group, **factory_kwargs)
self.tie_word_embeddings = getattr(config, 'tie_word_embeddings', True)
lm_head_bias = getattr(config, 'lm_head_bias', False)
pad_vocab_size_multiple = getattr(config, 'pad_vocab_size_multiple', 1)
vocab_size = (math.ceil(config.vocab_size / pad_vocab_size_multiple)
* pad_vocab_size_multiple)
# This option is for OPT-350m
word_embed_proj_dim = getattr(config, 'word_embed_proj_dim', None)
embed_dim = config.n_embd if word_embed_proj_dim is None else word_embed_proj_dim
if word_embed_proj_dim is not None:
self.project_out = nn.Linear(config.n_embd, embed_dim, bias=False, **factory_kwargs)
else:
self.project_out = None
if process_group is None:
self.lm_head = nn.Linear(embed_dim, vocab_size, bias=lm_head_bias, **factory_kwargs)
else:
if ColumnParallelLinear is None:
raise ImportError('fused_dense_lib is not installed')
self.lm_head = ColumnParallelLinear(
embed_dim, vocab_size, process_group, bias=lm_head_bias,
sequence_parallel=getattr(config, 'sequence_parallel', True), **factory_kwargs
)
# Initialize weights and apply final processing
self.apply(partial(_init_weights, n_layer=config.num_hidden_layers,
initializer_range=config.initializer_range))
self.tie_weights()
def tie_weights(self):
if self.tie_word_embeddings:
self.lm_head.weight = self.transformer.embeddings.word_embeddings.weight
if self.process_group is not None:
sync_shared_params(self, self.process_group)
def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
return self.transformer.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype,
**kwargs)
def forward(self, input_ids, position_ids=None, inference_params=None, last_token_only=False):
"""
inference_params: for generation. Adapted from Megatron-LM (and Apex)
https://github.com/NVIDIA/apex/blob/3ff1a10f72ec07067c4e44759442329804ac5162/apex/transformer/testing/standalone_transformer_lm.py#L470
last_token_only: whether to return the logit for the last token only,
of shape (batch_size, vocab_size)
"""
hidden_states = self.transformer(input_ids, position_ids=position_ids,
inference_params=inference_params)
if last_token_only:
hidden_states = hidden_states[:, -1]
if self.project_out is not None:
hidden_states = self.project_out(hidden_states)
lm_logits = self.lm_head(hidden_states)
# During inference, we want the full logit for sampling
if isinstance(self.lm_head, ColumnParallelLinear) and inference_params is not None:
lm_logits, _ = all_gather_raw(lm_logits, self.lm_head.process_group)
lm_logits = rearrange(lm_logits, '(n b) ... d -> b ... (n d)', b=hidden_states.shape[0])
CausalLMOutput = namedtuple('CausalLMOutput', ['logits'])
return CausalLMOutput(logits=lm_logits)
def load_state_dict(self, state_dict, strict=True):
# Remapping from our checkpoints that used a different ordering of layers in the block
# Previous: Attn / MLP -> Dropout -> Add -> LN
# Current: Dropout -> Add -> LN -> Attn / MLP
if 'transformer.ln_0.weight' in state_dict:
n_layers = len(self.transformer.layers)
ln_weight = state_dict.pop(f'transformer.layers.{n_layers - 1}.norm2.weight')
ln_bias = state_dict.pop(f'transformer.layers.{n_layers - 1}.norm2.bias')
state_dict['transformer.ln_f.weight'] = ln_weight
state_dict['transformer.ln_f.bias'] = ln_bias
for l in reversed(range(n_layers)):
ln_weight = state_dict.pop(f'transformer.layers.{l}.norm1.weight')
ln_bias = state_dict.pop(f'transformer.layers.{l}.norm1.bias')
state_dict[f'transformer.layers.{l}.norm2.weight'] = ln_weight
state_dict[f'transformer.layers.{l}.norm2.bias'] = ln_bias
if l > 0:
ln_weight = state_dict.pop(f'transformer.layers.{l - 1}.norm2.weight')
ln_bias = state_dict.pop(f'transformer.layers.{l - 1}.norm2.bias')
state_dict[f'transformer.layers.{l}.norm1.weight'] = ln_weight
state_dict[f'transformer.layers.{l}.norm1.bias'] = ln_bias
ln_weight = state_dict.pop('transformer.ln_0.weight')
ln_bias = state_dict.pop('transformer.ln_0.bias')
state_dict[f'transformer.layers.0.norm1.weight'] = ln_weight
state_dict[f'transformer.layers.0.norm1.bias'] = ln_bias
return super().load_state_dict(state_dict, strict=strict)
def shard_state_dict_tp(state_dict, config, world_size, rank):
"""Convert the state_dict of a standard GPT model to the state_dict of a GPT model
with tensor parallel.
"""
pad_vocab_size_multiple = getattr(config, 'pad_vocab_size_multiple', 1)
vocab_size = (math.ceil(config.vocab_size / pad_vocab_size_multiple) * pad_vocab_size_multiple)
assert vocab_size % world_size == 0
assert config.hidden_size % world_size == 0
inner_dim = config.n_inner if config.n_inner is not None else 4 * config.hidden_size
assert inner_dim % world_size == 0
def shard_first_dim(state_dict, key):
if key in state_dict:
x = state_dict[key]
dim = x.shape[0] // world_size
state_dict[key] = x[rank * dim:(rank + 1) * dim]
def shard_last_dim(state_dict, key):
if key in state_dict:
x = state_dict[key]
dim = x.shape[-1] // world_size
state_dict[key] = x[..., rank * dim:(rank + 1) * dim]
def shard_qkv_headdim(state_dict, key):
if key in state_dict:
n_head = config.n_head
n_head_kv = getattr(config, 'n_head_kv', n_head)
assert n_head % world_size == 0 and n_head_kv % world_size == 0
if n_head_kv == n_head:
x = rearrange(state_dict[key], '(three d) ... -> three d ...', three=3)
dim = x.shape[1] // world_size
state_dict[key] = rearrange(x[:, rank * dim:(rank + 1) * dim],
'three d ... -> (three d) ...')
else:
n_head_per_rank = n_head // world_size
n_head_kv_per_rank = n_head_kv // world_size
x = rearrange(state_dict[key], '(nheadqkv headdim) ... -> nheadqkv headdim ...',
nheadqkv=n_head + 2 * n_head_kv)
state_dict[key] = rearrange(torch.cat([
x[rank * n_head_per_rank:(rank + 1) * n_head_per_rank],
x[n_head + rank * n_head_kv_per_rank:n_head + (rank + 1) * n_head_kv_per_rank],
x[n_head + n_head_kv + rank * n_head_kv_per_rank:n_head + n_head_kv + (rank + 1) * n_head_kv_per_rank],
], dim=0), "nheadqkv headdim ... -> (nheadqkv headdim) ...")
shard_first_dim(state_dict, 'transformer.embeddings.word_embeddings.weight')
if 'lm_head.weight' in state_dict:
shard_first_dim(state_dict, 'lm_head.weight')
if 'transformer.embeddings.position_embeddings.weight' in state_dict:
shard_last_dim(state_dict, 'transformer.embeddings.position_embeddings.weight')
for i in range(config.num_hidden_layers):
shard_qkv_headdim(state_dict, f'transformer.layers.{i}.mixer.Wqkv.weight')
shard_qkv_headdim(state_dict, f'transformer.layers.{i}.mixer.Wqkv.bias')
shard_last_dim(state_dict, f'transformer.layers.{i}.mixer.out_proj.weight')
if rank != 0:
state_dict.pop(f'transformer.layers.{i}.mixer.out_proj.bias', None)
shard_first_dim(state_dict, f'transformer.layers.{i}.mlp.fc1.weight')
shard_first_dim(state_dict, f'transformer.layers.{i}.mlp.fc1.bias')
shard_last_dim(state_dict, f'transformer.layers.{i}.mlp.fc2.weight')
if rank != 0:
state_dict.pop(f'transformer.layers.{i}.mlp.fc2.bias', None)
return state_dict
def combine_state_dicts_tp(state_dicts, config):
"""Convert the state_dict of a standard GPT model to the state_dict of a GPT model
with tensor parallel.
"""
world_size = len(state_dicts)
keys = state_dicts[0].keys()
pad_vocab_size_multiple = getattr(config, 'pad_vocab_size_multiple', 1)
vocab_size = (math.ceil(config.vocab_size / pad_vocab_size_multiple) * pad_vocab_size_multiple)
assert vocab_size % world_size == 0
assert config.hidden_size % world_size == 0
inner_dim = config.n_inner if config.n_inner is not None else 4 * config.hidden_size
assert inner_dim % world_size == 0
# Sometimes the word embeddings are sharded on the 0th dim, sometimes on the 1st dim.
# vocab_size // world_size coordinates are nonzero.
def combine_word_embeddings(state_dicts, state_dict, key):
dim = 0 if state_dicts[0][key].shape[0] == vocab_size // world_size else 1
state_dict[key] = torch.cat([s[key] for s in state_dicts], dim=dim)
def combine_dim(state_dicts, state_dict, key, dim=-1):
if key in state_dict:
state_dict[key] = torch.cat([s[key] for s in state_dicts], dim=dim)
def combine_qkv_headdim(state_dicts, state_dict, key):
n_head = config.n_head
n_head_kv = getattr(config, 'n_head_kv', n_head)
assert n_head % world_size == 0 and n_head_kv % world_size == 0
n_head_per_rank = n_head // world_size
n_head_kv_per_rank = n_head_kv // world_size
if key in state_dict:
if n_head_kv == n_head:
xs = [rearrange(s[key], '(three d) ... -> three d ...', three=3) for s in state_dicts]
state_dict[key] = rearrange(torch.cat(xs, dim=1), 'three d ... -> (three d) ...')
else:
xs = [rearrange(s[key], '(nheadqkv headdim) ... -> nheadqkv headdim ...',
nheadqkv=n_head + 2 * n_head_kv) for s in state_dicts]
state_dict[key] = rearrange(torch.cat([
torch.cat([x[:n_head_per_rank] for x in xs], dim=0),
torch.cat([x[n_head_per_rank:n_head_per_rank + n_head_kv_per_rank] for x in xs], dim=0),
torch.cat([x[-n_head_kv_per_rank:] for x in xs], dim=0),
], dim=0), "nheadqkv headdim ... -> (nheadqkv headdim) ...")
def combine_gated_mlp(state_dicts, state_dict, key):
if key in state_dict:
xs = [rearrange(s[key], '(two d) ... -> two d ...', two=2) for s in state_dicts]
state_dict[key] = rearrange(torch.cat(xs, dim=1), 'two d ... -> (two d) ...')
state_dict = state_dicts[0].copy() # don't modify state_dict[0] inplace
combine_word_embeddings(state_dicts, state_dict, 'transformer.embeddings.word_embeddings.weight')
if 'lm_head.weight' in state_dict:
combine_word_embeddings(state_dicts, state_dict, 'lm_head.weight')
if 'transformer.embeddings.position_embeddings.weight' in state_dict:
combine_dim(state_dicts, state_dict, 'transformer.embeddings.position_embeddings.weight', -1)
mlp_combine_fn = (combine_gated_mlp if config.activation_function in ['glu', 'swiglu', 'geglu']
else partial(combine_dim, dim=0))
for i in range(config.num_hidden_layers):
combine_qkv_headdim(state_dicts, state_dict, f'transformer.layers.{i}.mixer.Wqkv.weight')
combine_qkv_headdim(state_dicts, state_dict, f'transformer.layers.{i}.mixer.Wqkv.bias')
combine_dim(state_dicts, state_dict, f'transformer.layers.{i}.mixer.out_proj.weight', -1)
mlp_combine_fn(state_dicts, state_dict, f'transformer.layers.{i}.mlp.fc1.weight')
combine_dim(state_dicts, state_dict, f'transformer.layers.{i}.mlp.fc1.bias', 0)
combine_dim(state_dicts, state_dict, f'transformer.layers.{i}.mlp.fc2.weight', -1)
return state_dict
def remap_state_dict_hf_gpt2(state_dict, config):
# Word embedding and position embedding
def key_mapping_pos_emb(key):
return re.sub(r'^wpe.', 'transformer.embeddings.position_embeddings.', key)
state_dict = OrderedDict((key_mapping_pos_emb(k), v) for k, v in state_dict.items())
word_embeddings = state_dict.pop('wte.weight')
# It's possible that vocab_size is padded to be a multiple of 8, for example.
pad_vocab_size_multiple = getattr(config, 'pad_vocab_size_multiple', 1)
vocab_size = (math.ceil(config.vocab_size / pad_vocab_size_multiple) * pad_vocab_size_multiple)
state_dict['transformer.embeddings.word_embeddings.weight'] = F.pad(
word_embeddings, (0, 0, 0, vocab_size - word_embeddings.shape[0])
)
state_dict['lm_head.weight'] = state_dict['transformer.embeddings.word_embeddings.weight']
# LayerNorm
def key_mapping_ln(key):
key = re.sub(r'^ln_f.(weight|bias)', r'transformer.ln_f.\1', key)
key = re.sub(r'^h.(\d+).ln_(1|2).(weight|bias)', r'transformer.layers.\1.norm\2.\3', key)
return key
state_dict = OrderedDict((key_mapping_ln(k), v) for k, v in state_dict.items())
# MLP
for d in range(config.num_hidden_layers):
W1 = state_dict.pop(f'h.{d}.mlp.c_fc.weight')
state_dict[f'transformer.layers.{d}.mlp.fc1.weight'] = W1.t()
W2 = state_dict.pop(f'h.{d}.mlp.c_proj.weight')
state_dict[f'transformer.layers.{d}.mlp.fc2.weight'] = W2.t()
def key_mapping_mlp(key):
key = re.sub(r'^h.(\d+).mlp.c_fc.bias', r'transformer.layers.\1.mlp.fc1.bias', key)
key = re.sub(r'^h.(\d+).mlp.c_proj.bias', r'transformer.layers.\1.mlp.fc2.bias', key)
return key
state_dict = OrderedDict((key_mapping_mlp(k), v) for k, v in state_dict.items())
# Attention
for d in range(config.num_hidden_layers):
state_dict.pop(f'h.{d}.attn.bias') # We don't store this bias
Wqkv = state_dict.pop(f'h.{d}.attn.c_attn.weight')
state_dict[f'transformer.layers.{d}.mixer.Wqkv.weight'] = Wqkv.t()
Wout = state_dict.pop(f'h.{d}.attn.c_proj.weight')
state_dict[f'transformer.layers.{d}.mixer.out_proj.weight'] = Wout.t()
def key_mapping_attn(key):
key = re.sub(r'^h.(\d+).attn.c_attn.bias', r'transformer.layers.\1.mixer.Wqkv.bias', key)
key = re.sub(r'^h.(\d+).attn.c_proj.bias', r'transformer.layers.\1.mixer.out_proj.bias', key)
return key
state_dict = OrderedDict((key_mapping_attn(k), v) for k, v in state_dict.items())
return state_dict
def remap_state_dict_megatron(state_dict, config):
def key_mapping_transformer(key):
key = re.sub(r'^language_model.encoder.', 'transformer.', key)
key = re.sub(r'^language_model.', 'transformer.', key)
return key
state_dict = OrderedDict((key_mapping_transformer(k), v) for k, v in state_dict.items())
# Word embedding and position embedding
def key_mapping_pos_emb(key):
return re.sub(r'^wpe.', 'transformer.embeddings.position_embeddings.', key)
state_dict = OrderedDict((key_mapping_pos_emb(k), v) for k, v in state_dict.items())
word_embeddings = state_dict.pop('transformer.embedding.word_embeddings.weight')
# It's possible that vocab_size is padded to be a multiple of 8, for example.
pad_vocab_size_multiple = getattr(config, 'pad_vocab_size_multiple', 1)
vocab_size = (math.ceil(word_embeddings.shape[0] / pad_vocab_size_multiple)
* pad_vocab_size_multiple)
state_dict['transformer.embeddings.word_embeddings.weight'] = F.pad(
word_embeddings, (0, 0, 0, vocab_size - word_embeddings.shape[0])
)
state_dict['lm_head.weight'] = state_dict['transformer.embeddings.word_embeddings.weight']
# LayerNorm
def key_mapping_ln(key):
key = re.sub(r'^transformer.final_layernorm.(weight|bias)', r'transformer.ln_f.\1', key)
key = re.sub(r'^transformer.layers.(\d+).input_layernorm.(weight|bias)',
r'transformer.layers.\1.norm1.\2', key)
key = re.sub(r'^transformer.layers.(\d+).post_attention_layernorm.(weight|bias)',
r'transformer.layers.\1.norm2.\2', key)
return key
state_dict = OrderedDict((key_mapping_ln(k), v) for k, v in state_dict.items())
# MLP
def key_mapping_mlp(key):
key = re.sub(r'^transformer.layers.(\d+).mlp.dense_h_to_4h.(weight|bias)',
r'transformer.layers.\1.mlp.fc1.\2', key)
key = re.sub(r'^transformer.layers.(\d+).mlp.dense_4h_to_h.(weight|bias)',
r'transformer.layers.\1.mlp.fc2.\2', key)
return key
state_dict = OrderedDict((key_mapping_mlp(k), v) for k, v in state_dict.items())
# Attention
def key_mapping_attn(key):
key = re.sub(r'^transformer.layers.(\d+).self_attention.rotary_emb.inv_freq',
r'transformer.layers.\1.mixer.rotary_emb.inv_freq', key)
key = re.sub(r'^transformer.layers.(\d+).self_attention.query_key_value.(weight|bias)',
r'transformer.layers.\1.mixer.Wqkv.\2', key)
key = re.sub(r'^transformer.layers.(\d+).self_attention.dense.(weight|bias)',
r'transformer.layers.\1.mixer.out_proj.\2', key)
return key
state_dict = OrderedDict((key_mapping_attn(k), v) for k, v in state_dict.items())
# Megatron stores Wqkv as ((nheads 3 headdim), hidden_dim)
# while we store Wqkv as ((3 nheads headdim), hidden_dim)
headdim = config.hidden_size // config.num_attention_heads
for d in range(config.num_hidden_layers):
Wqkv = state_dict.pop(f'transformer.layers.{d}.mixer.Wqkv.weight')
state_dict[f'transformer.layers.{d}.mixer.Wqkv.weight'] = rearrange(
Wqkv, '(nheads three headdim) ... -> (three nheads headdim) ...',
three=3, headdim=headdim
)
bqkv = state_dict.pop(f'transformer.layers.{d}.mixer.Wqkv.bias')
state_dict[f'transformer.layers.{d}.mixer.Wqkv.bias'] = rearrange(
bqkv, '(nheads three headdim) -> (three nheads headdim)',
three=3, headdim=headdim
)
return state_dict

View File

@@ -0,0 +1,107 @@
# Copyright (c) 2023, Tri Dao.
import math
import re
from collections import OrderedDict
import torch
import torch.nn.functional as F
from einops import rearrange
from transformers import GPT2Config, GPTNeoXConfig
def remap_state_dict_hf_gpt_neox(state_dict, config):
def key_mapping_layers(key):
return re.sub(r'^gpt_neox.', 'transformer.', key)
state_dict = OrderedDict((key_mapping_layers(k), v) for k, v in state_dict.items())
# Word embedding
def key_mapping_emb(key):
return re.sub(r'^transformer.embed_in.', 'transformer.embeddings.word_embeddings.', key)
state_dict = OrderedDict((key_mapping_emb(k), v) for k, v in state_dict.items())
word_embeddings = state_dict.pop('transformer.embeddings.word_embeddings.weight')
# It's possible that vocab_size is padded to be a multiple of 8, for example.
pad_vocab_size_multiple = getattr(config, 'pad_vocab_size_multiple', 1)
vocab_size = (math.ceil(config.vocab_size / pad_vocab_size_multiple) * pad_vocab_size_multiple)
state_dict['transformer.embeddings.word_embeddings.weight'] = F.pad(
word_embeddings, (0, 0, 0, vocab_size - word_embeddings.shape[0])
)
if getattr(config, 'tie_word_embeddings'):
state_dict['lm_head.weight'] = state_dict['transformer.embeddings.word_embeddings.weight']
else:
output_embeddings = state_dict.pop('embed_out.weight')
# It's possible that vocab_size is padded to be a multiple of 8, for example.
state_dict['lm_head.weight'] = F.pad(
output_embeddings, (0, 0, 0, vocab_size - output_embeddings.shape[0])
)
# LayerNorm
def key_mapping_ln(key):
key = re.sub(r'^transformer.final_layer_norm.', r'transformer.ln_f.', key)
key = re.sub(r'^transformer.layers.(\d+).input_layernorm.', r'transformer.layers.\1.norm1.', key)
key = re.sub(r'^transformer.layers.(\d+).post_attention_layernorm.', r'transformer.layers.\1.norm2.', key)
return key
state_dict = OrderedDict((key_mapping_ln(k), v) for k, v in state_dict.items())
# MLP
def key_mapping_mlp(key):
key = re.sub(r'^transformer.layers.(\d+).mlp.dense_h_to_4h.', r'transformer.layers.\1.mlp.fc1.', key)
key = re.sub(r'^transformer.layers.(\d+).mlp.dense_4h_to_h.', r'transformer.layers.\1.mlp.fc2.', key)
return key
state_dict = OrderedDict((key_mapping_mlp(k), v) for k, v in state_dict.items())
# Attention
for l in range(config.n_layer):
# We don't store these biases
state_dict.pop(f'transformer.layers.{l}.attention.bias')
state_dict.pop(f'transformer.layers.{l}.attention.masked_bias')
# GPT-NeoX stores Wqkv as ((nheads 3 headdim), hidden_dim)
# while we store Wqkv as ((3 nheads headdim), hidden_dim)
headdim = config.hidden_size // config.num_attention_heads
Wqkv = state_dict.pop(f'transformer.layers.{l}.attention.query_key_value.weight')
state_dict[f'transformer.layers.{l}.mixer.Wqkv.weight'] = rearrange(
Wqkv, '(nheads three headdim) ... -> (three nheads headdim) ...',
three=3, headdim=headdim
)
bqkv = state_dict.pop(f'transformer.layers.{l}.attention.query_key_value.bias')
state_dict[f'transformer.layers.{l}.mixer.Wqkv.bias'] = rearrange(
bqkv, '(nheads three headdim) -> (three nheads headdim)',
three=3, headdim=headdim
)
def key_mapping_attn(key):
key = re.sub(r'^transformer.layers.(\d+).attention.dense.',
r'transformer.layers.\1.mixer.out_proj.', key)
key = re.sub(r'^transformer.layers.(\d+).attention.rotary_emb.',
r'transformer.layers.\1.mixer.rotary_emb.', key)
return key
state_dict = OrderedDict((key_mapping_attn(k), v) for k, v in state_dict.items())
return state_dict
def gpt_neox_config_to_gpt2_config(gpt_neox_config: GPTNeoXConfig) -> GPT2Config:
assert gpt_neox_config.rotary_emb_base == 10000
return GPT2Config(
vocab_size=gpt_neox_config.vocab_size,
n_positions=0, # No absolute position embedding
n_embd=gpt_neox_config.hidden_size,
n_layer=gpt_neox_config.num_hidden_layers,
n_head=gpt_neox_config.num_attention_heads,
n_inner=gpt_neox_config.intermediate_size,
activation_function=gpt_neox_config.hidden_act,
resid_pdrop=0.0, # No dropout
embd_pdrop=0.0,
attn_pdrop=0.0,
layer_norm_epsilon=gpt_neox_config.layer_norm_eps,
initializer_range=gpt_neox_config.initializer_range,
bos_token_id=gpt_neox_config.bos_token_id,
eos_token_id=gpt_neox_config.eos_token_id,
# These are new arguments not in the original GPT2Config
prenorm=True,
parallel_block=gpt_neox_config.use_parallel_residual,
parallel_block_tied_norm=False,
rotary_emb_fraction=gpt_neox_config.rotary_pct,
tie_word_embeddings=gpt_neox_config.tie_word_embeddings,
)

View File

@@ -0,0 +1,98 @@
# Copyright (c) 2023, Tri Dao.
import math
import re
from collections import OrderedDict
import torch
import torch.nn.functional as F
from transformers import GPT2Config, GPTJConfig
def remap_state_dict_hf_gptj(state_dict, config):
def key_mapping_layers(key):
return re.sub(r'^transformer.h.', 'transformer.layers.', key)
state_dict = OrderedDict((key_mapping_layers(k), v) for k, v in state_dict.items())
# Word embedding
def key_mapping_emb(key):
return re.sub(r'^transformer.wte.', 'transformer.embeddings.word_embeddings.', key)
state_dict = OrderedDict((key_mapping_emb(k), v) for k, v in state_dict.items())
word_embeddings = state_dict.pop('transformer.embeddings.word_embeddings.weight')
# It's possible that vocab_size is padded to be a multiple of 8, for example.
pad_vocab_size_multiple = getattr(config, 'pad_vocab_size_multiple', 1)
vocab_size = (math.ceil(config.vocab_size / pad_vocab_size_multiple) * pad_vocab_size_multiple)
state_dict['transformer.embeddings.word_embeddings.weight'] = F.pad(
word_embeddings, (0, 0, 0, vocab_size - word_embeddings.shape[0])
)
if getattr(config, 'tie_word_embeddings'):
state_dict['lm_head.weight'] = state_dict['transformer.embeddings.word_embeddings.weight']
else:
output_embeddings = state_dict.pop('lm_head.weight')
# It's possible that vocab_size is padded to be a multiple of 8, for example.
state_dict['lm_head.weight'] = F.pad(
output_embeddings, (0, 0, 0, vocab_size - output_embeddings.shape[0])
)
output_embeddings_bias = state_dict.pop('lm_head.bias')
state_dict['lm_head.bias'] = F.pad(
output_embeddings_bias, (0, vocab_size - output_embeddings_bias.shape[0])
)
# LayerNorm
def key_mapping_ln(key):
return re.sub(r'^transformer.layers.(\d+).ln_1.', r'transformer.layers.\1.norm1.', key)
state_dict = OrderedDict((key_mapping_ln(k), v) for k, v in state_dict.items())
# MLP
def key_mapping_mlp(key):
key = re.sub(r'^transformer.layers.(\d+).mlp.fc_in.', r'transformer.layers.\1.mlp.fc1.', key)
key = re.sub(r'^transformer.layers.(\d+).mlp.fc_out.', r'transformer.layers.\1.mlp.fc2.', key)
return key
state_dict = OrderedDict((key_mapping_mlp(k), v) for k, v in state_dict.items())
# Attention
for l in range(config.n_layer):
Wq = state_dict.pop(f'transformer.layers.{l}.attn.q_proj.weight')
Wk = state_dict.pop(f'transformer.layers.{l}.attn.k_proj.weight')
Wv = state_dict.pop(f'transformer.layers.{l}.attn.v_proj.weight')
state_dict[f'transformer.layers.{l}.mixer.Wqkv.weight'] = torch.cat([Wq, Wk, Wv], dim=0)
# We don't store these biases
state_dict.pop(f'transformer.layers.{l}.attn.bias')
state_dict.pop(f'transformer.layers.{l}.attn.masked_bias')
def key_mapping_attn(key):
return re.sub(r'^transformer.layers.(\d+).attn.out_proj.',
r'transformer.layers.\1.mixer.out_proj.', key)
state_dict = OrderedDict((key_mapping_attn(k), v) for k, v in state_dict.items())
return state_dict
def gptj_config_to_gpt2_config(gptj_config: GPTJConfig) -> GPT2Config:
headdim = gptj_config.n_embd // gptj_config.n_head
return GPT2Config(
vocab_size=gptj_config.vocab_size,
n_positions=0, # No absolute position embedding
n_embd=gptj_config.n_embd,
n_layer=gptj_config.n_layer,
n_head=gptj_config.n_head,
n_inner=gptj_config.n_inner,
activation_function=gptj_config.activation_function,
resid_pdrop=gptj_config.resid_pdrop,
embd_pdrop=gptj_config.embd_pdrop,
attn_pdrop=gptj_config.attn_pdrop,
layer_norm_epsilon=gptj_config.layer_norm_epsilon,
initializer_range=gptj_config.initializer_range,
bos_token_id=gptj_config.bos_token_id,
eos_token_id=gptj_config.eos_token_id,
# These are new arguments not in the original GPT2Config
prenorm=True,
parallel_block=True,
parallel_block_tied_norm=True,
rotary_emb_fraction=gptj_config.rotary_dim / headdim,
rotary_emb_interleaved=True,
tie_word_embeddings=False,
qkv_proj_bias=False,
out_proj_bias=False,
lm_head_bias=True,
)

View File

@@ -0,0 +1,124 @@
# Copyright (c) 2023, Tri Dao.
import math
import json
import re
from pathlib import Path
from collections import OrderedDict
import torch
import torch.nn.functional as F
from transformers import GPT2Config, LlamaConfig
def remap_state_dict_meta_llama(state_dict, config):
def key_mapping_layers(key):
return f'transformer.{key}' if not key.startswith('output.') else key
state_dict = OrderedDict((key_mapping_layers(k), v) for k, v in state_dict.items())
# Word embedding
def key_mapping_emb(key):
return re.sub(r'^transformer.tok_embeddings.', 'transformer.embeddings.word_embeddings.', key)
state_dict = OrderedDict((key_mapping_emb(k), v) for k, v in state_dict.items())
word_embeddings = state_dict.pop('transformer.embeddings.word_embeddings.weight')
# It's possible that vocab_size is padded to be a multiple of 8, for example.
pad_vocab_size_multiple = getattr(config, 'pad_vocab_size_multiple', 1)
vocab_size = (math.ceil(word_embeddings.shape[0] / pad_vocab_size_multiple)
* pad_vocab_size_multiple)
state_dict['transformer.embeddings.word_embeddings.weight'] = F.pad(
word_embeddings, (0, 0, 0, vocab_size - word_embeddings.shape[0])
)
if getattr(config, 'tie_word_embeddings'):
state_dict['lm_head.weight'] = state_dict['transformer.embeddings.word_embeddings.weight']
else:
output_embeddings = state_dict.pop('output.weight')
# Need to recompute vocab_size since LLaMa shards the word embeddings and output embeddings
# differently.
vocab_size = (math.ceil(output_embeddings.shape[0] / pad_vocab_size_multiple)
* pad_vocab_size_multiple)
# It's possible that vocab_size is padded to be a multiple of 8, for example.
state_dict['lm_head.weight'] = F.pad(
output_embeddings, (0, 0, 0, vocab_size - output_embeddings.shape[0])
)
# LayerNorm
def key_mapping_ln(key):
key = re.sub(r'^transformer.norm.', r'transformer.ln_f.', key)
key = re.sub(r'^transformer.layers.(\d+).attention_norm.', r'transformer.layers.\1.norm1.', key)
key = re.sub(r'^transformer.layers.(\d+).ffn_norm.', r'transformer.layers.\1.norm2.', key)
return key
state_dict = OrderedDict((key_mapping_ln(k), v) for k, v in state_dict.items())
# MLP
for l in range(config.n_layer):
w1 = state_dict.pop(f'transformer.layers.{l}.feed_forward.w1.weight')
w3 = state_dict.pop(f'transformer.layers.{l}.feed_forward.w3.weight')
# Our ordering is different
state_dict[f'transformer.layers.{l}.mlp.fc1.weight'] = torch.cat([w3, w1], dim=0)
def key_mapping_mlp(key):
return re.sub(r'^transformer.layers.(\d+).feed_forward.w2.',
r'transformer.layers.\1.mlp.fc2.', key)
state_dict = OrderedDict((key_mapping_mlp(k), v) for k, v in state_dict.items())
# Attention
for l in range(config.n_layer):
Wq = state_dict.pop(f'transformer.layers.{l}.attention.wq.weight')
Wk = state_dict.pop(f'transformer.layers.{l}.attention.wk.weight')
Wv = state_dict.pop(f'transformer.layers.{l}.attention.wv.weight')
state_dict[f'transformer.layers.{l}.mixer.Wqkv.weight'] = torch.cat([Wq, Wk, Wv], dim=0)
# We don't store these
state_dict.pop(f'transformer.layers.{l}.attention.inner_attention.rope.freqs', None)
def key_mapping_attn(key):
return re.sub(r'^transformer.layers.(\d+).attention.wo.',
r'transformer.layers.\1.mixer.out_proj.', key)
state_dict = OrderedDict((key_mapping_attn(k), v) for k, v in state_dict.items())
return state_dict
def config_from_checkpoint(checkpoint_path: str, model_name: str) -> LlamaConfig:
"""Load a LlamaConfig from a checkpoint path."""
with open(Path(checkpoint_path) / model_name / 'params.json') as f:
params = json.load(f)
config = LlamaConfig(hidden_size=params['dim'], intermediate_size=None,
num_attention_heads=params['n_heads'],
num_hidden_layers=params['n_layers'],
rms_norm_eps=params['norm_eps'])
return config
def state_dicts_from_checkpoint(checkpoint_path: str, model_name: str) -> dict:
# Need to sort, otherwise we mess up the ordering and the weights are wrong
return [torch.load(path, map_location='cpu')
for path in sorted((Path(checkpoint_path) / model_name).glob('consolidated.*.pth'))]
def llama_config_to_gpt2_config(llama_config: LlamaConfig) -> GPT2Config:
return GPT2Config(
vocab_size=llama_config.vocab_size,
n_positions=0, # No absolute position embedding
n_embd=llama_config.hidden_size,
n_layer=llama_config.num_hidden_layers,
n_head=llama_config.num_attention_heads,
n_inner=llama_config.intermediate_size,
activation_function='swiglu', # Hardcode since HF calls it 'silu'
# Llama doesn't have dropout, idk if it's because they only release the inference code
resid_pdrop=0.0,
embd_pdrop=0.0,
attn_pdrop=0.0,
layer_norm_epsilon=llama_config.rms_norm_eps,
initializer_range=llama_config.initializer_range,
bos_token_id=llama_config.bos_token_id,
eos_token_id=llama_config.eos_token_id,
# These are new arguments not in the original GPT2Config
pad_token_id=llama_config.pad_token_id, # Idk if this does anything
rms_norm=True,
rotary_emb_fraction=1.0,
rotary_emb_interleaved=True,
tie_word_embeddings=False,
qkv_proj_bias=False,
out_proj_bias=False,
mlp_fc1_bias=False,
mlp_fc2_bias=False,
)

View File

@@ -0,0 +1,102 @@
# Copyright (c) 2023, Tri Dao.
import math
import re
from collections import OrderedDict
import torch
import torch.nn.functional as F
from transformers import GPT2Config, OPTConfig
def remap_state_dict_hf_opt(state_dict, config):
def key_mapping_model(key):
key = re.sub(r'^model.decoder.', 'transformer.', key)
# The OPT-350m model uses '^decoder' instead of '^model.decoder'
key = re.sub(r'^decoder.', 'transformer.', key)
return key
state_dict = OrderedDict((key_mapping_model(k), v) for k, v in state_dict.items())
# Word embedding and position embedding
def key_mapping_emb(key):
key = re.sub(r'^transformer.embed_tokens.', 'transformer.embeddings.word_embeddings.', key)
# The OPT-350m model uses has project_in and project_out
key = re.sub(r'^transformer.project_in.', 'transformer.embeddings.project_in.', key)
key = re.sub(r'^transformer.project_out.', 'project_out.', key)
key = re.sub(r'^transformer.embed_positions.',
'transformer.embeddings.position_embeddings.', key)
return key
state_dict = OrderedDict((key_mapping_emb(k), v) for k, v in state_dict.items())
# OPT uses the first 2 indices of pos_emb for padding tokens
pos_embeddings = state_dict.pop('transformer.embeddings.position_embeddings.weight')
state_dict['transformer.embeddings.position_embeddings.weight'] = pos_embeddings[2:]
word_embeddings = state_dict.pop('transformer.embeddings.word_embeddings.weight')
# It's possible that vocab_size is padded to be a multiple of 8, for example.
pad_vocab_size_multiple = getattr(config, 'pad_vocab_size_multiple', 1)
vocab_size = (math.ceil(config.vocab_size / pad_vocab_size_multiple) * pad_vocab_size_multiple)
state_dict['transformer.embeddings.word_embeddings.weight'] = F.pad(
word_embeddings, (0, 0, 0, vocab_size - word_embeddings.shape[0])
)
state_dict['lm_head.weight'] = state_dict['transformer.embeddings.word_embeddings.weight']
# LayerNorm
def key_mapping_ln(key):
key = re.sub(r'^transformer.final_layer_norm.', r'transformer.ln_f.', key)
# The OPT-175B checkpoint calls this 'decoder.layer_norm' instead of 'decoder.final_layer_norm'
key = re.sub(r'^transformer.layer_norm.', r'transformer.ln_f.', key)
key = re.sub(r'^transformer.layers.(\d+).self_attn_layer_norm.',
r'transformer.layers.\1.norm1.', key)
key = re.sub(r'^transformer.layers.(\d+).final_layer_norm.',
r'transformer.layers.\1.norm2.', key)
return key
state_dict = OrderedDict((key_mapping_ln(k), v) for k, v in state_dict.items())
# MLP
def key_mapping_mlp(key):
return re.sub(r'^transformer.layers.(\d+).fc(1|2).',
r'transformer.layers.\1.mlp.fc\2.', key)
state_dict = OrderedDict((key_mapping_mlp(k), v) for k, v in state_dict.items())
# Attention
for l in range(config.n_layer):
Wq = state_dict.pop(f'transformer.layers.{l}.self_attn.q_proj.weight')
Wk = state_dict.pop(f'transformer.layers.{l}.self_attn.k_proj.weight')
Wv = state_dict.pop(f'transformer.layers.{l}.self_attn.v_proj.weight')
bq = state_dict.pop(f'transformer.layers.{l}.self_attn.q_proj.bias')
bk = state_dict.pop(f'transformer.layers.{l}.self_attn.k_proj.bias')
bv = state_dict.pop(f'transformer.layers.{l}.self_attn.v_proj.bias')
state_dict[f'transformer.layers.{l}.mixer.Wqkv.weight'] = torch.cat([Wq, Wk, Wv], dim=0)
state_dict[f'transformer.layers.{l}.mixer.Wqkv.bias'] = torch.cat([bq, bk, bv], dim=0)
def key_mapping_attn(key):
return re.sub(r'^transformer.layers.(\d+).self_attn.out_proj.',
r'transformer.layers.\1.mixer.out_proj.', key)
state_dict = OrderedDict((key_mapping_attn(k), v) for k, v in state_dict.items())
return state_dict
def opt_config_to_gpt2_config(opt_config: OPTConfig) -> GPT2Config:
assert opt_config.layerdrop == 0.0
assert opt_config.layer_norm_elementwise_affine
word_embed_proj_dim = (None if opt_config.word_embed_proj_dim == opt_config.hidden_size
else opt_config.word_embed_proj_dim)
return GPT2Config(
vocab_size=opt_config.vocab_size,
n_positions=opt_config.max_position_embeddings,
n_embd=opt_config.hidden_size,
n_layer=opt_config.num_hidden_layers,
n_head=opt_config.num_attention_heads,
n_inner=opt_config.ffn_dim,
activation_function=opt_config.activation_function,
resid_pdrop=opt_config.dropout,
# HF's implementation of OPT doesn't seem to have embedding dropout
embd_pdrop=opt_config.dropout,
attn_pdrop=opt_config.attention_dropout,
initializer_range=opt_config.init_std,
bos_token_id=opt_config.bos_token_id,
eos_token_id=opt_config.eos_token_id,
# These are new arguments not in the original GPT2Config
prenorm=opt_config.do_layer_norm_before,
word_embed_proj_dim=word_embed_proj_dim
)

View File

@@ -0,0 +1,304 @@
# Copyright (c) 2022, Tri Dao.
# Inspired by / adapted from https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
import math
import re
from functools import partial
from copy import deepcopy
from collections import OrderedDict
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.init import trunc_normal_
from torchvision.ops import StochasticDepth
from einops import rearrange
from timm.models.helpers import named_apply
from flash_attn.layers.patch_embed import PatchEmbed
from flash_attn.modules.mha import MHA
from flash_attn.modules.mlp import Mlp, FusedMLP
from flash_attn.modules.block import Block
try:
from flash_attn.ops.layer_norm import dropout_add_layer_norm
except ImportError:
dropout_add_layer_norm = None
def create_mixer_cls(num_heads, qkv_bias, attn_drop, use_flash_attn, fused_bias_fc,
cross_attn=False):
mixer_cls = partial(MHA, num_heads=num_heads, cross_attn=cross_attn, bias=qkv_bias,
dropout=attn_drop, fused_bias_fc=fused_bias_fc,
use_flash_attn=use_flash_attn)
return mixer_cls
def create_mlp_cls(embed_dim, mlp_ratio, act_layer, fused_mlp):
inner_dim = int(embed_dim * mlp_ratio)
if not fused_mlp:
mlp_cls = partial(Mlp, hidden_features=inner_dim, activation=act_layer())
else:
mlp_cls = partial(FusedMLP, hidden_features=inner_dim)
return mlp_cls
def create_block(embed_dim, num_heads, mlp_ratio, qkv_bias, drop_rate, attn_drop_rate,
drop_path1, drop_path2, norm_layer, act_layer, use_flash_attn, fused_bias_fc,
fused_mlp, fused_dropout_add_ln, layer_idx=None, n_layer=None,
last_layer_subset=False):
mixer_cls = create_mixer_cls(num_heads, qkv_bias, attn_drop_rate, use_flash_attn, fused_bias_fc,
cross_attn=(last_layer_subset and layer_idx == n_layer - 1))
mlp_cls = create_mlp_cls(embed_dim, mlp_ratio, act_layer, fused_mlp)
# TD [2022-10-15]: Force residual in fp32 in case of DeepSpeed
block = Block(embed_dim, mixer_cls, mlp_cls, norm_cls=norm_layer,
prenorm=True, resid_dropout1=drop_rate, resid_dropout2=drop_rate,
drop_path1=drop_path1, drop_path2=drop_path2,
fused_dropout_add_ln=fused_dropout_add_ln, residual_in_fp32=True)
return block
class VisionTransformer(nn.Module):
""" Vision Transformer
A PyTorch impl of : `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale`
- https://arxiv.org/abs/2010.11929
"""
def __init__(
self,
img_size=224,
patch_size=16,
in_chans=3,
num_classes=1000,
global_pool='token',
embed_dim=768,
depth=12,
num_heads=12,
mlp_ratio=4.,
qkv_bias=True,
init_values=None,
class_token=True,
no_embed_class=False,
pre_norm=False,
fc_norm=None,
drop_rate=0.,
attn_drop_rate=0.,
drop_path_rate=0.,
weight_init='',
embed_layer=PatchEmbed,
norm_layer=None,
act_layer=None,
use_flash_attn=False,
fused_bias_fc=False,
fused_mlp=False,
fused_dropout_add_ln=False,
):
"""
Args:
img_size (int, tuple): input image size
patch_size (int, tuple): patch size
in_chans (int): number of input channels
num_classes (int): number of classes for classification head
global_pool (str): type of global pooling for final sequence (default: 'token')
embed_dim (int): embedding dimension
depth (int): depth of transformer
num_heads (int): number of attention heads
mlp_ratio (int): ratio of mlp hidden dim to embedding dim
qkv_bias (bool): enable bias for qkv if True
init_values: (float): layer-scale init values
class_token (bool): use class token
fc_norm (Optional[bool]): pre-fc norm after pool, set if global_pool == 'avg' if None (default: None)
drop_rate (float): dropout rate
attn_drop_rate (float): attention dropout rate
drop_path_rate (float): stochastic depth rate
weight_init (str): weight init scheme
embed_layer (nn.Module): patch embedding layer
norm_layer: (nn.Module): normalization layer
act_layer: (nn.Module): MLP activation layer
"""
super().__init__()
assert global_pool == 'token', 'Only support pooling with CLS token'
assert class_token
assert init_values is None, 'LayerScale is not supported yet'
assert weight_init == ''
assert fc_norm is None
# pre_norm seems redundant, as there's a LayerNorm right at the start of each block, idk
assert not pre_norm
use_fc_norm = global_pool == 'avg' if fc_norm is None else fc_norm
norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
act_layer = act_layer or nn.GELU
self.num_classes = num_classes
self.global_pool = global_pool
self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
self.num_prefix_tokens = 1 if class_token else 0
self.no_embed_class = no_embed_class
patch_embed_extra_kwargs = ({'fused_bias_fc': fused_bias_fc} if embed_layer is PatchEmbed
else {})
self.patch_embed = embed_layer(
img_size=img_size,
patch_size=patch_size,
in_chans=in_chans,
embed_dim=embed_dim,
bias=not pre_norm, # disable bias if pre-norm is used (e.g. CLIP)
**patch_embed_extra_kwargs
)
num_patches = self.patch_embed.num_patches
self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) if class_token else None
embed_len = num_patches if no_embed_class else num_patches + self.num_prefix_tokens
self.pos_embed = nn.Parameter(torch.randn(1, embed_len, embed_dim) * .02)
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
# We change the order of dropout, residual and layer norm:
# Instead of LN -> Attn / MLP -> Dropout -> Add, we do:
# Dropout -> Add -> LN -> Attn / MLP, returning both the residual branch (output of Add) and
# the main branch (output of MLP). The model definition is unchanged, but the mapping of the
# nn.Dropout probabilities are changed.
# This is for performance reason: we can fuse dropout + add + layer_norm.
self.blocks = nn.ModuleList([create_block(
embed_dim, num_heads, mlp_ratio, qkv_bias, drop_rate, attn_drop_rate,
drop_path1=dpr[i-1] if i > 0 else 0., drop_path2=dpr[i],
norm_layer=norm_layer, act_layer=act_layer, use_flash_attn=use_flash_attn,
fused_bias_fc=fused_bias_fc, fused_mlp=fused_mlp,
fused_dropout_add_ln=fused_dropout_add_ln, layer_idx=i, n_layer=depth,
last_layer_subset=(global_pool == 'token')
) for i in range(depth)])
self.dropout = nn.Dropout(p=drop_rate)
self.drop_path = StochasticDepth(p=dpr[-1], mode='row')
self.norm = norm_layer(embed_dim)
self.fused_dropout_add_ln = fused_dropout_add_ln
if self.fused_dropout_add_ln and dropout_add_layer_norm is None:
raise ImportError('dropout_add_layer_norm is not installed')
# Classifier Head
self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
self.init_weights(weight_init)
def init_weights(self, mode=''):
assert mode == ''
trunc_normal_(self.pos_embed, std=.02)
if self.cls_token is not None:
nn.init.normal_(self.cls_token, std=1e-6)
named_apply(init_weights_vit_timm, self)
def _init_weights(self, m):
# this fn left here for compat with downstream users
init_weights_vit_timm(m)
@torch.jit.ignore
def no_weight_decay(self):
return {'pos_embed', 'cls_token'}
def _pos_embed(self, x):
if self.no_embed_class:
# deit-3, updated JAX (big vision)
# position embedding does not overlap with class token, add then concat
x = x + self.pos_embed
if self.cls_token is not None:
x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
else:
# original timm, JAX, and deit vit impl
# pos_embed has entry for class token, concat then add
if self.cls_token is not None:
x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
x = x + self.pos_embed
return x
def forward_features(self, x, all_tokens=True):
"""
If all_tokens==False and self.global_pool == 'token', we only return the features for the
cls token.
"""
x = self.patch_embed(x)
hidden_states = self._pos_embed(x)
residual = None
if self.global_pool != 'token' or all_tokens:
# if True:
for block in self.blocks:
hidden_states, residual = block(hidden_states, residual)
else:
for block in self.blocks[:-1]:
hidden_states, residual = block(hidden_states, residual)
# For the last layer, we only want the 1st token of the output. So we do cross-attention
# where the query is the 1st token and the key/value is the whole sequence.
hidden_states, residual = self.blocks[-1](hidden_states, residual,
mixer_subset=slice(0, 1))
if not self.fused_dropout_add_ln:
residual = self.drop_path(self.dropout(hidden_states)) + residual
hidden_states = self.norm(residual.to(dtype=self.norm.weight.dtype))
else:
if self.drop_path.p == 0 or not self.training:
rowscale = None
else:
rowscale = self.drop_path(torch.ones(
hidden_states.shape[:-1], device=hidden_states.device,
dtype=hidden_states.dtype)
)
# Set prenorm=False here since we don't need to the residual
hidden_states = dropout_add_layer_norm(
hidden_states, residual, self.norm.weight, self.norm.bias,
self.dropout.p if self.training else 0.0, self.norm.eps, rowscale=rowscale,
prenorm=False, residual_in_fp32=True
)
return hidden_states
def forward_head(self, x, pre_logits: bool = False):
if self.global_pool:
x = x[:, self.num_prefix_tokens:].mean(dim=1) if self.global_pool == 'avg' else x[:, 0]
return x if pre_logits else self.head(x)
def forward(self, x):
x = self.forward_features(x, all_tokens=False)
x = self.forward_head(x)
return x
def load_state_dict(self, state_dict, strict=True):
patch_embed_weight = state_dict['patch_embed.proj.weight']
if patch_embed_weight.dim() == 4:
# convert from Conv2d to Linear
state_dict['patch_embed.proj.weight'] = rearrange(patch_embed_weight,
'o c h w -> o (c h w)')
def key_mapping_attn(key):
key = re.sub(r'^blocks.(\d+).attn.qkv.', r'blocks.\1.mixer.Wqkv.', key)
key = re.sub(r'^blocks.(\d+).attn.proj.', r'blocks.\1.mixer.out_proj.', key)
return key
state_dict = OrderedDict((key_mapping_attn(k), v) for k, v in state_dict.items())
n_layer = len(self.blocks)
# Convert from Wqkv to Wq and Wkv for cross attention (last layer)
if (self.blocks[-1].mixer.cross_attn
and f'blocks.{n_layer - 1}.mixer.Wqkv.weight' in state_dict):
Wqkv = state_dict.pop(f'blocks.{n_layer - 1}.mixer.Wqkv.weight')
bqkv = state_dict.pop(f'blocks.{n_layer - 1}.mixer.Wqkv.bias')
state_dict[f'blocks.{n_layer - 1}.mixer.Wq.weight'] = Wqkv[:self.embed_dim]
state_dict[f'blocks.{n_layer - 1}.mixer.Wkv.weight'] = Wqkv[self.embed_dim:]
state_dict[f'blocks.{n_layer - 1}.mixer.Wq.bias'] = bqkv[:self.embed_dim]
state_dict[f'blocks.{n_layer - 1}.mixer.Wkv.bias'] = bqkv[self.embed_dim:]
return super().load_state_dict(state_dict, strict=strict)
def init_weights_vit_timm(module: nn.Module, name: str = ''):
""" ViT weight initialization, original timm impl (for reproducibility) """
if isinstance(module, nn.Linear):
trunc_normal_(module.weight, std=.02)
if module.bias is not None:
nn.init.zeros_(module.bias)
elif hasattr(module, 'init_weights'):
module.init_weights()
def vit_base_patch16_224(pretrained=False, **kwargs):
""" ViT-Base (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929).
ImageNet-1k weights fine-tuned from in21k @ 224x224, source https://github.com/google-research/vision_transformer.
"""
assert not pretrained
model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs)
model = VisionTransformer(**model_kwargs)
return model