First commit
This commit is contained in:
0
pkgs/xformers/_flash_attn/models/__init__.py
Normal file
0
pkgs/xformers/_flash_attn/models/__init__.py
Normal file
Binary file not shown.
Binary file not shown.
Binary file not shown.
BIN
pkgs/xformers/_flash_attn/models/__pycache__/gpt.cpython-310.pyc
Normal file
BIN
pkgs/xformers/_flash_attn/models/__pycache__/gpt.cpython-310.pyc
Normal file
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
BIN
pkgs/xformers/_flash_attn/models/__pycache__/opt.cpython-310.pyc
Normal file
BIN
pkgs/xformers/_flash_attn/models/__pycache__/opt.cpython-310.pyc
Normal file
Binary file not shown.
BIN
pkgs/xformers/_flash_attn/models/__pycache__/vit.cpython-310.pyc
Normal file
BIN
pkgs/xformers/_flash_attn/models/__pycache__/vit.cpython-310.pyc
Normal file
Binary file not shown.
531
pkgs/xformers/_flash_attn/models/bert.py
Normal file
531
pkgs/xformers/_flash_attn/models/bert.py
Normal file
@@ -0,0 +1,531 @@
|
||||
# Copyright (c) 2022, Tri Dao.
|
||||
# This BERT implementation is based on our MLPerf 2.0 and MLPerf 2.1 BERT implementation.
|
||||
# https://github.com/mlcommons/training_results_v2.0/blob/main/HazyResearch/benchmarks/bert/implementations/pytorch/modeling.py
|
||||
# https://github.com/mlcommons/training_results_v2.1/blob/main/Azure-HazyResearch/benchmarks/bert/implementations/ND96amsr_A100_v4/modeling.py
|
||||
|
||||
# Inspired by https://github.com/huggingface/transformers/blob/main/src/transformers/models/bert/modeling_bert.py
|
||||
|
||||
import re
|
||||
import logging
|
||||
from functools import partial
|
||||
|
||||
from collections.abc import Sequence
|
||||
from collections import OrderedDict
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from transformers import BertConfig
|
||||
from transformers.models.bert.modeling_bert import BaseModelOutputWithPoolingAndCrossAttentions
|
||||
from transformers.models.bert.modeling_bert import BertForPreTrainingOutput
|
||||
|
||||
from einops import rearrange
|
||||
|
||||
from flash_attn.modules.mha import MHA
|
||||
from flash_attn.modules.mlp import Mlp, FusedMLP
|
||||
from flash_attn.modules.block import Block
|
||||
from flash_attn.modules.embedding import BertEmbeddings
|
||||
from flash_attn.bert_padding import unpad_input, pad_input
|
||||
from flash_attn.bert_padding import index_first_axis, index_first_axis_residual
|
||||
from flash_attn.utils.pretrained import state_dict_from_pretrained
|
||||
|
||||
try:
|
||||
from flash_attn.ops.fused_dense import FusedDense
|
||||
except ImportError:
|
||||
FusedDense = None
|
||||
|
||||
try:
|
||||
from flash_attn.ops.layer_norm import dropout_add_layer_norm, layer_norm
|
||||
except ImportError:
|
||||
dropout_add_layer_norm, layer_norm = None, None
|
||||
|
||||
try:
|
||||
from flash_attn.losses.cross_entropy import CrossEntropyLoss
|
||||
except ImportError:
|
||||
CrossEntropyLoss = None
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def create_mixer_cls(config, cross_attn=False, return_residual=False):
|
||||
use_flash_attn = getattr(config, 'use_flash_attn', False)
|
||||
fused_bias_fc = getattr(config, 'fused_bias_fc', False)
|
||||
rotary_kwargs = {}
|
||||
if config.position_embedding_type == "rotary":
|
||||
rotary_kwargs["rotary_emb_dim"] = getattr(config, "rotary_emb_dim", config.hidden_size)
|
||||
rotary_kwargs["rotary_emb_base"] = getattr(config, "rotary_emb_base", 10000.0)
|
||||
rotary_kwargs["rotary_emb_scale_base"] = getattr(config, "rotary_emb_scale_base", None)
|
||||
rotary_kwargs["rotary_emb_interleaved"] = getattr(config, "rotary_emb_interleaved", False)
|
||||
mixer_cls = partial(MHA, num_heads=config.num_attention_heads, cross_attn=cross_attn,
|
||||
dropout=config.attention_probs_dropout_prob, causal=False,
|
||||
fused_bias_fc=fused_bias_fc, use_flash_attn=use_flash_attn,
|
||||
return_residual=return_residual, **rotary_kwargs)
|
||||
return mixer_cls
|
||||
|
||||
|
||||
def create_mlp_cls(config, layer_idx=None, return_residual=False):
|
||||
inner_dim = config.intermediate_size
|
||||
fused_mlp = getattr(config, 'fused_mlp', False)
|
||||
if fused_mlp:
|
||||
assert config.hidden_act in ['gelu_new', 'gelu_fast'], ('fused_mlp only '
|
||||
'supports approximate gelu')
|
||||
if not fused_mlp:
|
||||
approximate = 'tanh' if config.hidden_act in ['gelu_new', 'gelu_fast'] else 'none'
|
||||
mlp_cls = partial(Mlp, hidden_features=inner_dim,
|
||||
activation=partial(F.gelu, approximate=approximate),
|
||||
return_residual=return_residual)
|
||||
else:
|
||||
if FusedMLP is None:
|
||||
raise ImportError('fused_dense is not installed')
|
||||
mlp_checkpoint_lvl = getattr(config, 'mlp_checkpoint_lvl', 0)
|
||||
# mlp_checkpoint_lvl could be a list, which contains the checkpoint_lvl for each layer
|
||||
if isinstance(mlp_checkpoint_lvl, Sequence):
|
||||
assert layer_idx is not None
|
||||
mlp_checkpoint_lvl = mlp_checkpoint_lvl[layer_idx]
|
||||
mlp_cls = partial(FusedMLP, hidden_features=inner_dim,
|
||||
checkpoint_lvl=mlp_checkpoint_lvl, return_residual=return_residual)
|
||||
return mlp_cls
|
||||
|
||||
|
||||
def create_block(config, layer_idx=None):
|
||||
last_layer_subset = getattr(config, 'last_layer_subset', False)
|
||||
cross_attn=last_layer_subset and layer_idx == config.num_hidden_layers - 1
|
||||
# TD [2022-12-19]: For cross attention (last layer), we actually want to return the
|
||||
# residual x_kv, not residual x. But it's annoying to change the API (and it only affects
|
||||
# one layer) so we just choose not to return residual in this case.
|
||||
return_residual = not cross_attn
|
||||
mixer_cls = create_mixer_cls(config, cross_attn, return_residual=return_residual)
|
||||
mlp_cls = create_mlp_cls(config, layer_idx, return_residual=return_residual)
|
||||
norm_cls = partial(nn.LayerNorm, eps=config.layer_norm_eps)
|
||||
block = Block(config.hidden_size, mixer_cls, mlp_cls, norm_cls=norm_cls,
|
||||
prenorm=False, resid_dropout1=config.hidden_dropout_prob,
|
||||
resid_dropout2=config.hidden_dropout_prob,
|
||||
fused_dropout_add_ln=getattr(config, 'fused_dropout_add_ln', False),
|
||||
return_residual=return_residual)
|
||||
return block
|
||||
|
||||
|
||||
# https://github.com/huggingface/transformers/blob/7032e0203262ebb2ebf55da8d2e01f873973e835/src/transformers/models/bert/modeling_bert.py#L748
|
||||
def _init_weights(module, initializer_range=0.02):
|
||||
if isinstance(module, nn.Linear):
|
||||
nn.init.normal_(module.weight, std=initializer_range)
|
||||
if module.bias is not None:
|
||||
nn.init.zeros_(module.bias)
|
||||
elif isinstance(module, nn.Embedding):
|
||||
nn.init.normal_(module.weight, std=initializer_range)
|
||||
if module.padding_idx is not None:
|
||||
nn.init.zeros_(module.weight[module.padding_idx])
|
||||
|
||||
|
||||
class BertEncoder(nn.Module):
|
||||
|
||||
def __init__(self, config: BertConfig):
|
||||
super().__init__()
|
||||
self.use_flash_attn = getattr(config, 'use_flash_attn', False)
|
||||
self.layers = nn.ModuleList([create_block(config, layer_idx=i)
|
||||
for i in range(config.num_hidden_layers)])
|
||||
|
||||
def forward(self, hidden_states, key_padding_mask=None, subset_mask=None):
|
||||
"""If subset_mask is not None, we only want output for the subset of the sequence.
|
||||
This means that we only compute the last layer output for these tokens.
|
||||
subset_mask: (batch, seqlen), dtype=torch.bool
|
||||
"""
|
||||
if key_padding_mask is None or not self.use_flash_attn:
|
||||
mixer_kwargs = ({'key_padding_mask': key_padding_mask}
|
||||
if key_padding_mask is not None else None)
|
||||
for layer in self.layers:
|
||||
hidden_states = layer(hidden_states, mixer_kwargs=mixer_kwargs)
|
||||
if subset_mask is not None:
|
||||
hidden_states = hidden_states[subset_mask]
|
||||
else:
|
||||
batch, seqlen = hidden_states.shape[:2]
|
||||
hidden_states, indices, cu_seqlens, max_seqlen_in_batch = unpad_input(
|
||||
hidden_states, key_padding_mask
|
||||
)
|
||||
mixer_kwargs = {'cu_seqlens': cu_seqlens, 'max_seqlen': max_seqlen_in_batch}
|
||||
if subset_mask is None:
|
||||
for layer in self.layers:
|
||||
hidden_states = layer(hidden_states, mixer_kwargs=mixer_kwargs)
|
||||
hidden_states = pad_input(hidden_states, indices, batch, seqlen)
|
||||
else:
|
||||
for layer in self.layers[:-1]:
|
||||
hidden_states = layer(hidden_states, mixer_kwargs=mixer_kwargs)
|
||||
if key_padding_mask is not None:
|
||||
subset_idx = torch.nonzero(subset_mask[key_padding_mask], as_tuple=False).flatten()
|
||||
subset_seqlens = (subset_mask & key_padding_mask).sum(dim=-1, dtype=torch.int32)
|
||||
subset_cu_seqlens = F.pad(torch.cumsum(subset_seqlens, dim=0,
|
||||
dtype=torch.torch.int32), (1, 0))
|
||||
else:
|
||||
subset_idx = torch.nonzero(subset_mask, as_tuple=False).flatten()
|
||||
subset_seqlens = subset_mask.sum(dim=-1, dtype=torch.int32)
|
||||
subset_cu_seqlens = F.pad(torch.cumsum(subset_seqlens, dim=0,
|
||||
dtype=torch.torch.int32), (1, 0))
|
||||
hidden_states_subset, hidden_states = index_first_axis_residual(
|
||||
hidden_states, subset_idx
|
||||
)
|
||||
# It's ok to set max_seqlen_q to be much larger
|
||||
mixer_kwargs = {'x_kv': hidden_states,
|
||||
'cu_seqlens': subset_cu_seqlens, 'max_seqlen': max_seqlen_in_batch,
|
||||
'cu_seqlens_k': cu_seqlens, 'max_seqlen_k': max_seqlen_in_batch}
|
||||
hidden_states = self.layers[-1](hidden_states_subset, mixer_kwargs=mixer_kwargs)
|
||||
return hidden_states
|
||||
|
||||
|
||||
class BertPooler(nn.Module):
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
fused_bias_fc = getattr(config, 'fused_bias_fc', False)
|
||||
if fused_bias_fc and FusedDense is None:
|
||||
raise ImportError('fused_dense is not installed')
|
||||
linear_cls = nn.Linear if not fused_bias_fc else FusedDense
|
||||
self.dense = linear_cls(config.hidden_size, config.hidden_size)
|
||||
self.activation = nn.Tanh()
|
||||
|
||||
def forward(self, hidden_states, pool=True):
|
||||
# We "pool" the model by simply taking the hidden state corresponding
|
||||
# to the first token.
|
||||
first_token_tensor = hidden_states[:, 0] if pool else hidden_states
|
||||
pooled_output = self.dense(first_token_tensor)
|
||||
pooled_output = self.activation(pooled_output)
|
||||
return pooled_output
|
||||
|
||||
|
||||
class BertPredictionHeadTransform(nn.Module):
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
fused_bias_fc = getattr(config, 'fused_bias_fc', False)
|
||||
if fused_bias_fc and FusedDense is None:
|
||||
raise ImportError('fused_dense is not installed')
|
||||
self.fused_dropout_add_ln = getattr(config, 'fused_dropout_add_ln', False)
|
||||
if self.fused_dropout_add_ln and layer_norm is None:
|
||||
raise ImportError('dropout_add_layer_norm is not installed')
|
||||
linear_cls = nn.Linear if not fused_bias_fc else FusedDense
|
||||
self.dense = linear_cls(config.hidden_size, config.hidden_size)
|
||||
approximate = 'tanh' if config.hidden_act in ['gelu_new', 'gelu_fast'] else 'none'
|
||||
self.transform_act_fn = nn.GELU(approximate=approximate)
|
||||
self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
hidden_states = self.dense(hidden_states)
|
||||
hidden_states = self.transform_act_fn(hidden_states)
|
||||
if not self.fused_dropout_add_ln:
|
||||
hidden_states = self.layer_norm(hidden_states)
|
||||
else:
|
||||
hidden_states = layer_norm(hidden_states, self.layer_norm.weight, self.layer_norm.bias,
|
||||
self.layer_norm.eps)
|
||||
return hidden_states
|
||||
|
||||
|
||||
class BertLMPredictionHead(nn.Module):
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
fused_bias_fc = getattr(config, 'fused_bias_fc', False)
|
||||
if fused_bias_fc and FusedDense is None:
|
||||
raise ImportError('fused_dense is not installed')
|
||||
linear_cls = nn.Linear if not fused_bias_fc else FusedDense
|
||||
|
||||
self.transform = BertPredictionHeadTransform(config)
|
||||
|
||||
# The output weights are the same as the input embeddings, but there is
|
||||
# an output-only bias for each token.
|
||||
self.decoder = linear_cls(config.hidden_size, config.vocab_size, bias=True)
|
||||
|
||||
def forward(self, hidden_states):
|
||||
hidden_states = self.transform(hidden_states)
|
||||
hidden_states = self.decoder(hidden_states)
|
||||
return hidden_states
|
||||
|
||||
|
||||
class BertPreTrainingHeads(nn.Module):
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
self.predictions = BertLMPredictionHead(config)
|
||||
self.seq_relationship = nn.Linear(config.hidden_size, 2)
|
||||
|
||||
def forward(self, sequence_output, pooled_output):
|
||||
prediction_scores = self.predictions(sequence_output)
|
||||
seq_relationship_score = self.seq_relationship(pooled_output)
|
||||
return prediction_scores, seq_relationship_score
|
||||
|
||||
|
||||
class BertPreTrainedModel(nn.Module):
|
||||
""" An abstract class to handle weights initialization and
|
||||
a simple interface for dowloading and loading pretrained models.
|
||||
"""
|
||||
def __init__(self, config, *inputs, **kwargs):
|
||||
super().__init__()
|
||||
if not isinstance(config, BertConfig):
|
||||
raise ValueError(
|
||||
"Parameter config in `{}(config)` should be an instance of class `BertConfig`. "
|
||||
"To create a model from a Google pretrained model use "
|
||||
"`model = {}.from_pretrained(PRETRAINED_MODEL_NAME)`".format(
|
||||
self.__class__.__name__, self.__class__.__name__
|
||||
))
|
||||
self.config = config
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, model_name, config, *inputs, **kwargs):
|
||||
"""
|
||||
Instantiate a BertPreTrainedModel from a pre-trained model file or a pytorch state dict.
|
||||
Download and cache the pre-trained model file if needed.
|
||||
|
||||
Params:
|
||||
pretrained_model_name_or_path: either:
|
||||
- a path or url to a pretrained model archive containing:
|
||||
. `bert_config.json` a configuration file for the model
|
||||
. `pytorch_model.bin` a PyTorch dump of a BertForPretraining instance
|
||||
- a path or url to a pretrained model archive containing:
|
||||
. `bert_config.json` a configuration file for the model
|
||||
. `model.chkpt` a TensorFlow checkpoint
|
||||
*inputs, **kwargs: additional input for the specific Bert class
|
||||
(ex: num_labels for BertForSequenceClassification)
|
||||
"""
|
||||
# Instantiate model.
|
||||
model = cls(config, *inputs, **kwargs)
|
||||
load_return = model.load_state_dict(remap_state_dict(state_dict_from_pretrained(model_name),
|
||||
config), strict=False)
|
||||
logger.info(load_return)
|
||||
return model
|
||||
|
||||
|
||||
class BertModel(BertPreTrainedModel):
|
||||
|
||||
def __init__(self, config: BertConfig, add_pooling_layer=True):
|
||||
super().__init__(config)
|
||||
self.pad_vocab_size_multiple = getattr(config, 'pad_vocab_size_multiple', 1)
|
||||
if config.vocab_size % self.pad_vocab_size_multiple != 0:
|
||||
config.vocab_size += (self.pad_vocab_size_multiple
|
||||
- (config.vocab_size % self.pad_vocab_size_multiple))
|
||||
self.fused_dropout_add_ln = getattr(config, 'fused_dropout_add_ln', False)
|
||||
if self.fused_dropout_add_ln and layer_norm is None:
|
||||
raise ImportError('dropout_add_layer_norm is not installed')
|
||||
assert config.hidden_act in ['gelu', 'gelu_new', 'gelu_fast']
|
||||
|
||||
self.embeddings = BertEmbeddings(config.hidden_size, config.vocab_size,
|
||||
config.max_position_embeddings, config.type_vocab_size,
|
||||
padding_idx=config.pad_token_id)
|
||||
self.emb_drop = nn.Dropout(config.hidden_dropout_prob)
|
||||
self.emb_ln = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
||||
self.encoder = BertEncoder(config)
|
||||
self.pooler = BertPooler(config) if add_pooling_layer else None
|
||||
|
||||
self.apply(partial(_init_weights, initializer_range=config.initializer_range))
|
||||
|
||||
def forward(self, input_ids, position_ids=None, token_type_ids=None, attention_mask=None,
|
||||
masked_tokens_mask=None):
|
||||
"""If masked_tokens_mask is not None (i.e. last_layer_subset == True in BertForPreTraining),
|
||||
we only want the output for the masked tokens. This means that we only compute the last
|
||||
layer output for these tokens.
|
||||
masked_tokens_mask: (batch, seqlen), dtype=torch.bool
|
||||
"""
|
||||
hidden_states = self.embeddings(input_ids, position_ids=position_ids,
|
||||
token_type_ids=token_type_ids)
|
||||
# TD [2022-12:18]: Don't need to force residual in fp32
|
||||
# BERT puts embedding LayerNorm before embedding dropout.
|
||||
if not self.fused_dropout_add_ln:
|
||||
hidden_states = self.emb_ln(hidden_states)
|
||||
else:
|
||||
hidden_states = layer_norm(hidden_states, self.emb_ln.weight, self.emb_ln.bias,
|
||||
self.emb_ln.eps)
|
||||
hidden_states = self.emb_drop(hidden_states)
|
||||
|
||||
if masked_tokens_mask is not None:
|
||||
batch_size, seqlen = input_ids.shape[:2]
|
||||
# We also need the first column for the CLS token
|
||||
first_col_mask = torch.zeros(batch_size, seqlen, dtype=torch.bool,
|
||||
device=input_ids.device)
|
||||
first_col_mask[:, 0] = True
|
||||
subset_mask = masked_tokens_mask | first_col_mask
|
||||
else:
|
||||
subset_mask = None
|
||||
|
||||
sequence_output = self.encoder(hidden_states, key_padding_mask=attention_mask,
|
||||
subset_mask=subset_mask)
|
||||
|
||||
if masked_tokens_mask is None:
|
||||
pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
|
||||
else:
|
||||
# TD [2022-03-01]: the indexing here is very tricky.
|
||||
if attention_mask is not None:
|
||||
subset_idx = subset_mask[attention_mask]
|
||||
pool_input = sequence_output[first_col_mask[attention_mask][subset_idx]]
|
||||
sequence_output = sequence_output[masked_tokens_mask[attention_mask][subset_idx]]
|
||||
else:
|
||||
pool_input = sequence_output[first_col_mask[subset_mask]]
|
||||
sequence_output = sequence_output[masked_tokens_mask[subset_mask]]
|
||||
pooled_output = (self.pooler(pool_input, pool=False)
|
||||
if self.pooler is not None else None)
|
||||
|
||||
return BaseModelOutputWithPoolingAndCrossAttentions(
|
||||
last_hidden_state=sequence_output,
|
||||
pooler_output=pooled_output,
|
||||
)
|
||||
|
||||
|
||||
class BertForPreTraining(BertPreTrainedModel):
|
||||
|
||||
def __init__(self, config: BertConfig):
|
||||
super().__init__(config)
|
||||
# If dense_seq_output, we only need to pass the hidden states for the masked out tokens
|
||||
# (around 15%) to the classifier heads.
|
||||
self.dense_seq_output = getattr(config, 'dense_seq_output', False)
|
||||
# If last_layer_subset, we only need the compute the last layer for a subset of tokens
|
||||
# (e.g., the tokens we need to compute the masked LM loss and the next-sentence prediction).
|
||||
self.last_layer_subset = getattr(config, 'last_layer_subset', False)
|
||||
if self.last_layer_subset:
|
||||
assert self.dense_seq_output, 'last_layer_subset requires dense_seq_output'
|
||||
use_xentropy = getattr(config, 'use_xentropy', False)
|
||||
if use_xentropy and CrossEntropyLoss is None:
|
||||
raise ImportError('xentropy_cuda is not installed')
|
||||
loss_cls = (nn.CrossEntropyLoss if not use_xentropy
|
||||
else partial(CrossEntropyLoss, inplace_backward=True))
|
||||
|
||||
self.bert = BertModel(config)
|
||||
self.cls = BertPreTrainingHeads(config)
|
||||
self.mlm_loss = loss_cls(ignore_index=0)
|
||||
self.nsp_loss = loss_cls(ignore_index=-1)
|
||||
|
||||
# Initialize weights and apply final processing
|
||||
self.apply(partial(_init_weights, initializer_range=config.initializer_range))
|
||||
self.tie_weights()
|
||||
|
||||
def tie_weights(self):
|
||||
self.cls.predictions.decoder.weight = self.bert.embeddings.word_embeddings.weight
|
||||
|
||||
def forward(self, input_ids, position_ids=None, token_type_ids=None, attention_mask=None,
|
||||
labels=None, next_sentence_label=None):
|
||||
"""
|
||||
If labels are provided, they must be 0 for masked out tokens (as specified in the attention
|
||||
mask).
|
||||
Outputs:
|
||||
if `labels` and `next_sentence_label` are not `None`:
|
||||
Outputs the total_loss which is the sum of the masked language modeling loss and the next
|
||||
sentence classification loss.
|
||||
if `labels` or `next_sentence_label` is `None`:
|
||||
Outputs a tuple comprising
|
||||
- the masked language modeling logits of shape [batch_size, sequence_length, vocab_size], and
|
||||
- the next sentence classification logits of shape [batch_size, 2].
|
||||
|
||||
"""
|
||||
masked_tokens_mask = labels > 0 if (self.last_layer_subset and labels is not None) else None
|
||||
outputs = self.bert(
|
||||
input_ids, position_ids=position_ids, token_type_ids=token_type_ids,
|
||||
attention_mask=attention_mask.bool() if attention_mask is not None else None,
|
||||
masked_tokens_mask=masked_tokens_mask
|
||||
)
|
||||
sequence_output, pooled_output = outputs.last_hidden_state, outputs.pooler_output
|
||||
if self.dense_seq_output and labels is not None:
|
||||
masked_token_idx = torch.nonzero(labels.flatten() > 0, as_tuple=False).flatten()
|
||||
if not self.last_layer_subset:
|
||||
sequence_output = index_first_axis(rearrange(sequence_output, 'b s d -> (b s) d'),
|
||||
masked_token_idx)
|
||||
prediction_scores, seq_relationship_score = self.cls(sequence_output, pooled_output)
|
||||
|
||||
total_loss = None
|
||||
if labels is not None and next_sentence_label is not None:
|
||||
if self.dense_seq_output and labels is not None: # prediction_scores are already flattened
|
||||
masked_lm_loss = self.mlm_loss(prediction_scores,
|
||||
labels.flatten()[masked_token_idx])
|
||||
else:
|
||||
masked_lm_loss = self.mlm_loss(rearrange(prediction_scores, '... v -> (...) v'),
|
||||
rearrange(labels, '... -> (...)'))
|
||||
next_sentence_loss = self.nsp_loss(rearrange(seq_relationship_score, '... t -> (...) t'),
|
||||
rearrange(next_sentence_label, '... -> (...)'))
|
||||
total_loss = masked_lm_loss.float() + next_sentence_loss.float()
|
||||
|
||||
return BertForPreTrainingOutput(
|
||||
loss=total_loss,
|
||||
prediction_logits=prediction_scores,
|
||||
seq_relationship_logits=seq_relationship_score,
|
||||
)
|
||||
|
||||
|
||||
def remap_state_dict(state_dict, config):
|
||||
# LayerNorm
|
||||
def key_mapping_ln_gamma_beta(key):
|
||||
key = re.sub(r'LayerNorm.gamma$', 'LayerNorm.weight', key)
|
||||
key = re.sub(r'LayerNorm.beta$', 'LayerNorm.bias', key)
|
||||
return key
|
||||
state_dict = OrderedDict((key_mapping_ln_gamma_beta(k), v) for k, v in state_dict.items())
|
||||
|
||||
# Layers
|
||||
def key_mapping_layers(key):
|
||||
return re.sub(r'^bert.encoder.layer.', 'bert.encoder.layers.', key)
|
||||
state_dict = OrderedDict((key_mapping_layers(k), v) for k, v in state_dict.items())
|
||||
|
||||
# LayerNorm
|
||||
def key_mapping_ln(key):
|
||||
key = re.sub(r'^bert.embeddings.LayerNorm.', 'bert.emb_ln.', key)
|
||||
key = re.sub(r'^bert.encoder.layers.(\d+).attention.output.LayerNorm.(weight|bias)',
|
||||
r'bert.encoder.layers.\1.norm1.\2', key)
|
||||
key = re.sub(r'^bert.encoder.layers.(\d+).output.LayerNorm.(weight|bias)',
|
||||
r'bert.encoder.layers.\1.norm2.\2', key)
|
||||
key = re.sub(r'^cls.predictions.transform.LayerNorm.(weight|bias)',
|
||||
r'cls.predictions.transform.layer_norm.\1', key)
|
||||
return key
|
||||
state_dict = OrderedDict((key_mapping_ln(k), v) for k, v in state_dict.items())
|
||||
|
||||
# MLP
|
||||
def key_mapping_mlp(key):
|
||||
key = re.sub(r'^bert.encoder.layers.(\d+).intermediate.dense.(weight|bias)',
|
||||
r'bert.encoder.layers.\1.mlp.fc1.\2', key)
|
||||
key = re.sub(r'^bert.encoder.layers.(\d+).output.dense.(weight|bias)',
|
||||
r'bert.encoder.layers.\1.mlp.fc2.\2', key)
|
||||
return key
|
||||
state_dict = OrderedDict((key_mapping_mlp(k), v) for k, v in state_dict.items())
|
||||
|
||||
# Attention
|
||||
last_layer_subset = getattr(config, 'last_layer_subset', False)
|
||||
for d in range(config.num_hidden_layers):
|
||||
Wq = state_dict.pop(f'bert.encoder.layers.{d}.attention.self.query.weight')
|
||||
Wk = state_dict.pop(f'bert.encoder.layers.{d}.attention.self.key.weight')
|
||||
Wv = state_dict.pop(f'bert.encoder.layers.{d}.attention.self.value.weight')
|
||||
bq = state_dict.pop(f'bert.encoder.layers.{d}.attention.self.query.bias')
|
||||
bk = state_dict.pop(f'bert.encoder.layers.{d}.attention.self.key.bias')
|
||||
bv = state_dict.pop(f'bert.encoder.layers.{d}.attention.self.value.bias')
|
||||
if not (last_layer_subset and d == config.num_hidden_layers - 1):
|
||||
state_dict[f'bert.encoder.layers.{d}.mixer.Wqkv.weight'] = torch.cat(
|
||||
[Wq, Wk, Wv], dim=0
|
||||
)
|
||||
state_dict[f'bert.encoder.layers.{d}.mixer.Wqkv.bias'] = torch.cat([bq, bk, bv], dim=0)
|
||||
else:
|
||||
state_dict[f'bert.encoder.layers.{d}.mixer.Wq.weight'] = Wq
|
||||
state_dict[f'bert.encoder.layers.{d}.mixer.Wkv.weight'] = torch.cat(
|
||||
[Wk, Wv], dim=0
|
||||
)
|
||||
state_dict[f'bert.encoder.layers.{d}.mixer.Wq.bias'] = bq
|
||||
state_dict[f'bert.encoder.layers.{d}.mixer.Wkv.bias'] = torch.cat([bk, bv], dim=0)
|
||||
def key_mapping_attn(key):
|
||||
return re.sub(r'^bert.encoder.layers.(\d+).attention.output.dense.(weight|bias)',
|
||||
r'bert.encoder.layers.\1.mixer.out_proj.\2', key)
|
||||
state_dict = OrderedDict((key_mapping_attn(k), v) for k, v in state_dict.items())
|
||||
|
||||
def key_mapping_decoder_bias(key):
|
||||
return re.sub(r'^cls.predictions.bias', 'cls.predictions.decoder.bias', key)
|
||||
state_dict = OrderedDict((key_mapping_decoder_bias(k), v) for k, v in state_dict.items())
|
||||
|
||||
# Word embedding
|
||||
pad_vocab_size_multiple = getattr(config, 'pad_vocab_size_multiple', 1)
|
||||
if pad_vocab_size_multiple > 1:
|
||||
word_embeddings = state_dict['bert.embeddings.word_embeddings.weight']
|
||||
state_dict['bert.embeddings.word_embeddings.weight'] = F.pad(
|
||||
word_embeddings, (0, 0, 0, config.vocab_size - word_embeddings.shape[0])
|
||||
)
|
||||
decoder_weight = state_dict['cls.predictions.decoder.weight']
|
||||
state_dict['cls.predictions.decoder.weight'] = F.pad(
|
||||
decoder_weight, (0, 0, 0, config.vocab_size - decoder_weight.shape[0])
|
||||
)
|
||||
# If the vocab was padded, we want to set the decoder bias for those padded indices to be
|
||||
# strongly negative (i.e. the decoder shouldn't predict those indices).
|
||||
# TD [2022-05-09]: I don't think it affects the MLPerf training.
|
||||
decoder_bias = state_dict['cls.predictions.decoder.bias']
|
||||
state_dict['cls.predictions.decoder.bias'] = F.pad(
|
||||
decoder_bias, (0, config.vocab_size - decoder_bias.shape[0]), value=-100.0
|
||||
)
|
||||
|
||||
return state_dict
|
||||
122
pkgs/xformers/_flash_attn/models/falcon.py
Normal file
122
pkgs/xformers/_flash_attn/models/falcon.py
Normal file
@@ -0,0 +1,122 @@
|
||||
# Copyright (c) 2023, Tri Dao.
|
||||
|
||||
import math
|
||||
import re
|
||||
|
||||
from collections import OrderedDict
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
from einops import rearrange
|
||||
|
||||
from transformers import GPT2Config, FalconConfig
|
||||
|
||||
|
||||
def remap_state_dict_hf_falcon(state_dict, config):
|
||||
def key_mapping_layers(key):
|
||||
return re.sub(r'^transformer.h.', 'transformer.layers.', key)
|
||||
state_dict = OrderedDict((key_mapping_layers(k), v) for k, v in state_dict.items())
|
||||
# Word embedding
|
||||
def key_mapping_emb(key):
|
||||
return re.sub(r'^transformer.word_embeddings.', 'transformer.embeddings.word_embeddings.', key)
|
||||
state_dict = OrderedDict((key_mapping_emb(k), v) for k, v in state_dict.items())
|
||||
word_embeddings = state_dict.pop('transformer.embeddings.word_embeddings.weight')
|
||||
# It's possible that vocab_size is padded to be a multiple of 8, for example.
|
||||
pad_vocab_size_multiple = getattr(config, 'pad_vocab_size_multiple', 1)
|
||||
vocab_size = (math.ceil(config.vocab_size / pad_vocab_size_multiple) * pad_vocab_size_multiple)
|
||||
state_dict['transformer.embeddings.word_embeddings.weight'] = F.pad(
|
||||
word_embeddings, (0, 0, 0, vocab_size - word_embeddings.shape[0])
|
||||
)
|
||||
if getattr(config, 'tie_word_embeddings'):
|
||||
state_dict['lm_head.weight'] = state_dict['transformer.embeddings.word_embeddings.weight']
|
||||
else:
|
||||
output_embeddings = state_dict.pop('lm_head.weight')
|
||||
# It's possible that vocab_size is padded to be a multiple of 8, for example.
|
||||
state_dict['lm_head.weight'] = F.pad(
|
||||
output_embeddings, (0, 0, 0, vocab_size - output_embeddings.shape[0])
|
||||
)
|
||||
output_embeddings_bias = state_dict.pop('lm_head.bias')
|
||||
state_dict['lm_head.bias'] = F.pad(
|
||||
output_embeddings_bias, (0, vocab_size - output_embeddings_bias.shape[0])
|
||||
)
|
||||
|
||||
# LayerNorm
|
||||
def key_mapping_ln(key):
|
||||
key = re.sub(r'^transformer.layers.(\d+).input_layernorm.',
|
||||
r'transformer.layers.\1.norm1.', key)
|
||||
key = re.sub(r'^transformer.layers.(\d+).post_attention_layernorm.',
|
||||
r'transformer.layers.\1.norm2.', key)
|
||||
key = re.sub(r'^transformer.layers.(\d+).ln_attn.', r'transformer.layers.\1.norm1.', key)
|
||||
key = re.sub(r'^transformer.layers.(\d+).ln_mlp.', r'transformer.layers.\1.norm2.', key)
|
||||
return key
|
||||
state_dict = OrderedDict((key_mapping_ln(k), v) for k, v in state_dict.items())
|
||||
|
||||
# MLP
|
||||
def key_mapping_mlp(key):
|
||||
key = re.sub(r'^transformer.layers.(\d+).mlp.dense_h_to_4h.',
|
||||
r'transformer.layers.\1.mlp.fc1.', key)
|
||||
key = re.sub(r'^transformer.layers.(\d+).mlp.dense_4h_to_h.',
|
||||
r'transformer.layers.\1.mlp.fc2.', key)
|
||||
return key
|
||||
state_dict = OrderedDict((key_mapping_mlp(k), v) for k, v in state_dict.items())
|
||||
|
||||
def key_mapping_attn(key):
|
||||
key = re.sub(r'^transformer.layers.(\d+).self_attention.query_key_value.',
|
||||
r'transformer.layers.\1.mixer.Wqkv.', key)
|
||||
key = re.sub(r'^transformer.layers.(\d+).self_attention.dense.',
|
||||
r'transformer.layers.\1.mixer.out_proj.', key)
|
||||
return key
|
||||
state_dict = OrderedDict((key_mapping_attn(k), v) for k, v in state_dict.items())
|
||||
n_head = config.n_head
|
||||
n_head_kv = getattr(config, "n_head_kv", 1)
|
||||
headdim = config.hidden_size // n_head
|
||||
for l in range(config.n_layer):
|
||||
# The weights are stored in a different layout compared to our implementation
|
||||
Wqkv = rearrange(state_dict.pop(f'transformer.layers.{l}.mixer.Wqkv.weight'),
|
||||
"(group ratio headdim) ... -> group ratio headdim ...",
|
||||
ratio=n_head // n_head_kv + 2, headdim=headdim)
|
||||
Wq = rearrange(Wqkv[:, :-2], "group ratio headdim ... -> (group ratio headdim) ...")
|
||||
Wk = rearrange(Wqkv[:, [-2]], "group ratio headdim ... -> (group ratio headdim) ...")
|
||||
Wv = rearrange(Wqkv[:, [-1]], "group ratio headdim ... -> (group ratio headdim) ...")
|
||||
state_dict[f'transformer.layers.{l}.mixer.Wqkv.weight'] = torch.cat([Wq, Wk, Wv], dim=0)
|
||||
|
||||
return state_dict
|
||||
|
||||
|
||||
def falcon_config_to_gpt2_config(falcon_config: FalconConfig) -> GPT2Config:
|
||||
# The 40b config uses "n_head_kv" instead of "num_kv_heads"
|
||||
n_head_kv = getattr(falcon_config, "n_head_kv",
|
||||
1 if getattr(falcon_config, "multi_query", False)
|
||||
else falcon_config.n_head)
|
||||
# HACK: the 40b config has 2 LN per layer instead of 1, but that's not reflected in the config.
|
||||
# So we have to infer it from the number of heads in the key/value block
|
||||
parallel_block_tied_norm = n_head_kv == 1
|
||||
return GPT2Config(
|
||||
vocab_size=falcon_config.vocab_size,
|
||||
n_positions=0, # No absolute position embedding
|
||||
n_embd=falcon_config.hidden_size,
|
||||
n_layer=falcon_config.n_layer,
|
||||
n_head=falcon_config.n_head,
|
||||
n_inner=falcon_config.hidden_size * 4,
|
||||
activation_function="gelu",
|
||||
resid_pdrop=falcon_config.hidden_dropout,
|
||||
embd_pdrop=0.0, # There doesn't seem to be any embedding dropout
|
||||
attn_pdrop=falcon_config.attention_dropout,
|
||||
layer_norm_epsilon=falcon_config.layer_norm_epsilon,
|
||||
initializer_range=falcon_config.initializer_range,
|
||||
bos_token_id=falcon_config.bos_token_id,
|
||||
eos_token_id=falcon_config.eos_token_id,
|
||||
# These are new arguments not in the original GPT2Config
|
||||
parallel_block=falcon_config.parallel_attn,
|
||||
n_head_kv=n_head_kv,
|
||||
parallel_block_tied_norm=parallel_block_tied_norm,
|
||||
rotary_emb_fraction=1.0,
|
||||
rotary_emb_interleaved=False,
|
||||
tie_word_embeddings=True,
|
||||
qkv_proj_bias=falcon_config.bias,
|
||||
out_proj_bias=falcon_config.bias,
|
||||
mlp_fc1_bias=falcon_config.bias,
|
||||
mlp_fc2_bias=falcon_config.bias,
|
||||
lm_head_bias=False,
|
||||
)
|
||||
740
pkgs/xformers/_flash_attn/models/gpt.py
Normal file
740
pkgs/xformers/_flash_attn/models/gpt.py
Normal file
@@ -0,0 +1,740 @@
|
||||
# Copyright (c) 2023, Tri Dao.
|
||||
|
||||
import logging
|
||||
import math
|
||||
import re
|
||||
from functools import partial
|
||||
|
||||
from collections import namedtuple, OrderedDict
|
||||
from collections.abc import Sequence
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from transformers import GPT2Config
|
||||
|
||||
from einops import rearrange
|
||||
|
||||
from flash_attn.ops.activations import sqrelu_fwd
|
||||
from flash_attn.modules.mha import MHA, ParallelMHA
|
||||
from flash_attn.modules.mlp import Mlp, GatedMlp, ParallelMLP, FusedMLP, ParallelFusedMLP
|
||||
from flash_attn.modules.block import Block, ParallelBlock
|
||||
from flash_attn.modules.embedding import GPT2Embeddings, ParallelGPT2Embeddings
|
||||
from flash_attn.utils.distributed import sync_shared_params, all_gather_raw
|
||||
from flash_attn.utils.pretrained import state_dict_from_pretrained
|
||||
from flash_attn.utils.generation import GenerationMixin
|
||||
from flash_attn.models.opt import remap_state_dict_hf_opt
|
||||
from flash_attn.models.gptj import remap_state_dict_hf_gptj
|
||||
from flash_attn.models.gpt_neox import remap_state_dict_hf_gpt_neox
|
||||
from flash_attn.models.falcon import remap_state_dict_hf_falcon
|
||||
|
||||
try:
|
||||
from flash_attn.ops.fused_dense import ColumnParallelLinear
|
||||
except ImportError:
|
||||
ColumnParallelLinear = None
|
||||
|
||||
try:
|
||||
from flash_attn.ops.layer_norm import dropout_add_layer_norm
|
||||
except ImportError:
|
||||
dropout_add_layer_norm = None
|
||||
|
||||
try:
|
||||
from flash_attn.ops.layer_norm import dropout_add_layer_norm_parallel_residual
|
||||
except ImportError:
|
||||
dropout_add_layer_norm_parallel_residual = None
|
||||
|
||||
try:
|
||||
from flash_attn.ops.rms_norm import RMSNorm, dropout_add_rms_norm
|
||||
except ImportError:
|
||||
RMSNorm, dropout_add_rms_norm = None, None
|
||||
|
||||
try:
|
||||
from flash_attn.ops.rms_norm import dropout_add_rms_norm_parallel_residual
|
||||
except ImportError:
|
||||
dropout_add_rms_norm_parallel_residual = None
|
||||
|
||||
try:
|
||||
from flash_attn.ops.triton.mlp import FusedDenseSqreluDense
|
||||
except ImportError:
|
||||
FusedDenseSqreluDense = None
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def create_mixer_cls(config, layer_idx=None, process_group=None, device=None, dtype=None):
|
||||
factory_kwargs = {'device': device, 'dtype': dtype}
|
||||
head_dim = getattr(config, 'head_dim', config.hidden_size // config.num_attention_heads)
|
||||
softmax_scale = 1.0 if not config.scale_attn_weights else head_dim ** (-0.5)
|
||||
if config.scale_attn_by_inverse_layer_idx:
|
||||
assert layer_idx is not None
|
||||
softmax_scale /= float(layer_idx + 1)
|
||||
dwconv = getattr(config, 'attn_dwconv', False)
|
||||
if dwconv:
|
||||
assert process_group is None, 'TensorParallel MHA does not support dwconv yet'
|
||||
qkv_proj_bias = getattr(config, 'qkv_proj_bias', True)
|
||||
out_proj_bias = getattr(config, 'out_proj_bias', True)
|
||||
rotary_emb_dim = int(getattr(config, 'rotary_emb_fraction', 0.0) * head_dim)
|
||||
rotary_emb_base = getattr(config, 'rotary_emb_base', 10000.0)
|
||||
rotary_emb_scale_base = getattr(config, 'rotary_emb_scale_base', None)
|
||||
rotary_emb_interleaved = getattr(config, 'rotary_emb_interleaved', False)
|
||||
use_flash_attn = getattr(config, 'use_flash_attn', False)
|
||||
fused_bias_fc = getattr(config, 'fused_bias_fc', False)
|
||||
if not fused_bias_fc:
|
||||
assert process_group is None, 'TensorParallel MHA requires fused_bias_fc'
|
||||
mha_cls = MHA if process_group is None else ParallelMHA
|
||||
serial_kwargs = ({'fused_bias_fc': fused_bias_fc, 'dwconv': dwconv}
|
||||
if process_group is None else {})
|
||||
parallel_kwargs = ({'process_group': process_group,
|
||||
'sequence_parallel': getattr(config, 'sequence_parallel', True)}
|
||||
if process_group is not None else {})
|
||||
num_heads_kv = getattr(config, "n_head_kv", None)
|
||||
mixer_cls = partial(mha_cls, num_heads=config.num_attention_heads,
|
||||
num_heads_kv=num_heads_kv,
|
||||
qkv_proj_bias=qkv_proj_bias, out_proj_bias=out_proj_bias,
|
||||
dropout=config.attn_pdrop,
|
||||
softmax_scale=softmax_scale, causal=True, layer_idx=layer_idx,
|
||||
rotary_emb_dim=rotary_emb_dim, rotary_emb_base=rotary_emb_base,
|
||||
rotary_emb_scale_base=rotary_emb_scale_base,
|
||||
rotary_emb_interleaved=rotary_emb_interleaved,
|
||||
use_flash_attn=use_flash_attn,
|
||||
**serial_kwargs, **parallel_kwargs, **factory_kwargs)
|
||||
return mixer_cls
|
||||
|
||||
|
||||
def create_mlp_cls(config, layer_idx=None, process_group=None, device=None, dtype=None):
|
||||
factory_kwargs = {'device': device, 'dtype': dtype}
|
||||
mlp_fc1_bias = getattr(config, 'mlp_fc1_bias', True)
|
||||
mlp_fc2_bias = getattr(config, 'mlp_fc2_bias', True)
|
||||
fused_mlp = getattr(config, 'fused_mlp', False)
|
||||
if fused_mlp:
|
||||
assert config.activation_function in ['gelu_new', 'gelu_fast', 'gelu_approx', 'relu', 'sqrelu']
|
||||
fused_dense_sqrelu_dense = getattr(config, 'fused_dense_sqrelu_dense', False)
|
||||
if fused_dense_sqrelu_dense:
|
||||
assert config.activation_function == 'sqrelu', ('fused_dense_sqrelu_dense only '
|
||||
'supports approximate activation_function sqrelu')
|
||||
assert not (fused_dense_sqrelu_dense and fused_mlp)
|
||||
if not fused_mlp and not fused_dense_sqrelu_dense:
|
||||
assert config.activation_function in ['gelu', 'gelu_new', 'gelu_fast', 'gelu_approx', 'relu',
|
||||
'sqrelu', 'glu', 'swiglu', 'geglu']
|
||||
if config.activation_function in ['glu', 'swiglu', 'geglu']:
|
||||
activation = (F.sigmoid if config.activation_function == 'glu'
|
||||
else (F.silu if config.activation_function == 'swiglu'
|
||||
else F.gelu))
|
||||
mlp_cls = partial(GatedMlp, hidden_features=config.n_inner, activation=activation,
|
||||
bias1=mlp_fc1_bias, bias2=mlp_fc2_bias, **factory_kwargs)
|
||||
else:
|
||||
if config.activation_function == 'relu':
|
||||
activation = partial(F.relu, inplace=True)
|
||||
elif config.activation_function == 'sqrelu':
|
||||
activation = sqrelu_fwd
|
||||
else:
|
||||
approximate = ('tanh' if config.activation_function
|
||||
in ['gelu_new', 'gelu_fast', 'gelu_approx'] else 'none')
|
||||
activation=partial(F.gelu, approximate=approximate)
|
||||
mlp_cls = Mlp if process_group is None else ParallelMLP
|
||||
parallel_kwargs = ({'process_group': process_group,
|
||||
'sequence_parallel': getattr(config, 'sequence_parallel', True)}
|
||||
if process_group is not None else {})
|
||||
mlp_cls = partial(mlp_cls, hidden_features=config.n_inner, activation=activation,
|
||||
bias1=mlp_fc1_bias, bias2=mlp_fc2_bias,
|
||||
**parallel_kwargs, **factory_kwargs)
|
||||
else:
|
||||
mlp_checkpoint_lvl = getattr(config, 'mlp_checkpoint_lvl', 0)
|
||||
# mlp_checkpoint_lvl could be a list, which contains the checkpoint_lvl for each layer
|
||||
if isinstance(mlp_checkpoint_lvl, Sequence):
|
||||
assert layer_idx is not None
|
||||
mlp_checkpoint_lvl = mlp_checkpoint_lvl[layer_idx]
|
||||
if fused_mlp:
|
||||
if FusedMLP is None:
|
||||
raise ImportError('fused_dense is not installed')
|
||||
activation = ('gelu_approx' if config.activation_function
|
||||
in ['gelu_new', 'gelu_fast', 'gelu_approx'] else config.activation_function)
|
||||
mlp_cls = FusedMLP if process_group is None else ParallelFusedMLP
|
||||
parallel_kwargs = ({'process_group': process_group,
|
||||
'sequence_parallel': getattr(config, 'sequence_parallel', True)}
|
||||
if process_group is not None else {})
|
||||
mlp_cls = partial(mlp_cls, hidden_features=config.n_inner, activation=activation,
|
||||
checkpoint_lvl=mlp_checkpoint_lvl,
|
||||
bias1=mlp_fc1_bias, bias2=mlp_fc2_bias,
|
||||
**parallel_kwargs, **factory_kwargs)
|
||||
elif fused_dense_sqrelu_dense:
|
||||
assert FusedDenseSqreluDense is not None
|
||||
mlp_cls = partial(FusedDenseSqreluDense, hidden_features=config.n_inner,
|
||||
checkpoint_lvl=mlp_checkpoint_lvl, **factory_kwargs)
|
||||
else:
|
||||
raise RuntimeError('MLP type not supported')
|
||||
return mlp_cls
|
||||
|
||||
|
||||
def create_block(config, layer_idx=None, process_group=None, device=None, dtype=None):
|
||||
factory_kwargs = {'device': device, 'dtype': dtype}
|
||||
sequence_parallel = getattr(config, 'sequence_parallel', True)
|
||||
mixer_cls = create_mixer_cls(config, layer_idx, process_group=process_group, **factory_kwargs)
|
||||
mlp_cls = create_mlp_cls(config, layer_idx, process_group=process_group, **factory_kwargs)
|
||||
use_rms_norm = getattr(config, 'rms_norm', False)
|
||||
norm_cls = partial(nn.LayerNorm if not use_rms_norm else RMSNorm,
|
||||
eps=config.layer_norm_epsilon, **factory_kwargs)
|
||||
# TD [2022-07-30]: Force residual in fp32, seems to make fp16 training more stable
|
||||
residual_in_fp32 = getattr(config, 'residual_in_fp32', False)
|
||||
resid_dropout1 = config.resid_pdrop if layer_idx is None or layer_idx > 0 else config.embd_pdrop
|
||||
prenorm = getattr(config, 'prenorm', True)
|
||||
parallel_block = getattr(config, 'parallel_block', False)
|
||||
if not parallel_block:
|
||||
block = Block(
|
||||
config.hidden_size, mixer_cls, mlp_cls, norm_cls=norm_cls,
|
||||
prenorm=prenorm, resid_dropout1=resid_dropout1, resid_dropout2=config.resid_pdrop,
|
||||
fused_dropout_add_ln=getattr(config, 'fused_dropout_add_ln', False),
|
||||
residual_in_fp32=residual_in_fp32,
|
||||
sequence_parallel=sequence_parallel and process_group is not None,
|
||||
mark_shared_params=process_group is not None
|
||||
)
|
||||
else:
|
||||
assert prenorm
|
||||
block = ParallelBlock(
|
||||
config.hidden_size, mixer_cls, mlp_cls, norm_cls=norm_cls,
|
||||
resid_dropout1=resid_dropout1, resid_dropout2=config.resid_pdrop,
|
||||
tied_norm=getattr(config, 'parallel_block_tied_norm', False),
|
||||
fused_dropout_add_ln=getattr(config, 'fused_dropout_add_ln', False),
|
||||
residual_in_fp32=residual_in_fp32,
|
||||
sequence_parallel=sequence_parallel and process_group is not None,
|
||||
mark_shared_params=process_group is not None
|
||||
)
|
||||
block.layer_idx = layer_idx
|
||||
return block
|
||||
|
||||
|
||||
class GPTPreTrainedModel(nn.Module):
|
||||
""" An abstract class to handle weights initialization and
|
||||
a simple interface for dowloading and loading pretrained models.
|
||||
"""
|
||||
def __init__(self, config, *inputs, **kwargs):
|
||||
super().__init__()
|
||||
if not isinstance(config, GPT2Config):
|
||||
raise ValueError(
|
||||
"Parameter config in `{}(config)` should be an instance of class `GPT2Config`. "
|
||||
"To create a model from a Google pretrained model use "
|
||||
"`model = {}.from_pretrained(PRETRAINED_MODEL_NAME)`".format(
|
||||
self.__class__.__name__, self.__class__.__name__
|
||||
))
|
||||
self.config = config
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, model_name, config, *args, strict=True, device=None, dtype=None,
|
||||
world_size=1, rank=0, **kwargs):
|
||||
"""
|
||||
Instantiate a GPTPreTrainedModel from a pre-trained model file or a pytorch state dict.
|
||||
Download and cache the pre-trained model file if needed.
|
||||
"""
|
||||
# Instantiate model.
|
||||
model = cls(config, *args, device=device, dtype=dtype, **kwargs)
|
||||
# Load state_dict in cpu because we already initialized the model in GPU, and we don't
|
||||
# want extra stuff taking up more GPU memory
|
||||
state_dict = state_dict_from_pretrained(
|
||||
model_name, device='cpu', dtype=dtype
|
||||
)
|
||||
if model_name.startswith('gpt2'):
|
||||
state_dict = remap_state_dict_hf_gpt2(state_dict, config)
|
||||
elif model_name.startswith('facebook/opt'):
|
||||
state_dict = remap_state_dict_hf_opt(state_dict, config)
|
||||
elif model_name.startswith('EleutherAI/gpt-j-'):
|
||||
state_dict = remap_state_dict_hf_gptj(state_dict, config)
|
||||
elif model_name.startswith('EleutherAI/gpt-neox-'):
|
||||
state_dict = remap_state_dict_hf_gpt_neox(state_dict, config)
|
||||
elif model_name.startswith('tiiuae/falcon-'):
|
||||
state_dict = remap_state_dict_hf_falcon(state_dict, config)
|
||||
else:
|
||||
raise NotImplementedError(f'Model {model_name} not supported')
|
||||
if world_size > 1:
|
||||
state_dict = shard_state_dict_tp(state_dict, config, world_size, rank)
|
||||
load_return = model.load_state_dict(state_dict, strict=strict)
|
||||
logger.info(load_return)
|
||||
return model
|
||||
|
||||
|
||||
# https://github.com/huggingface/transformers/blob/c28d04e9e252a1a099944e325685f14d242ecdcd/src/transformers/models/gpt2/modeling_gpt2.py#L454
|
||||
def _init_weights(module, n_layer, initializer_range=0.02, rescale_prenorm_residual=True):
|
||||
if isinstance(module, nn.Linear):
|
||||
nn.init.normal_(module.weight, std=initializer_range)
|
||||
if module.bias is not None:
|
||||
nn.init.zeros_(module.bias)
|
||||
elif isinstance(module, nn.Embedding):
|
||||
nn.init.normal_(module.weight, std=initializer_range)
|
||||
|
||||
if rescale_prenorm_residual:
|
||||
# Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme:
|
||||
# > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale
|
||||
# > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers.
|
||||
# > -- GPT-2 :: https://openai.com/blog/better-language-models/
|
||||
#
|
||||
# Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py
|
||||
for name, p in module.named_parameters():
|
||||
if name in ["out_proj.weight", "fc2.weight"]:
|
||||
# Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block
|
||||
nn.init.normal_(p, mean=0.0, std=initializer_range / math.sqrt(2 * n_layer))
|
||||
|
||||
|
||||
class GPTModel(GPTPreTrainedModel):
|
||||
|
||||
def __init__(self, config: GPT2Config, process_group=None, device=None, dtype=None):
|
||||
super().__init__(config)
|
||||
factory_kwargs = {'device': device, 'dtype': dtype}
|
||||
self.process_group = process_group
|
||||
self.sequence_parallel = getattr(config, 'sequence_parallel', True)
|
||||
assert config.activation_function in ['gelu', 'gelu_new', 'gelu_fast', 'gelu_approx',
|
||||
'relu', 'sqrelu', 'glu', 'swiglu', 'geglu']
|
||||
pad_vocab_size_multiple = getattr(config, 'pad_vocab_size_multiple', 1)
|
||||
vocab_size = (math.ceil(config.vocab_size / pad_vocab_size_multiple)
|
||||
* pad_vocab_size_multiple)
|
||||
# TD [2022-07-30]: Force residual in fp32, seems to make fp16 training more stable
|
||||
self.residual_in_fp32 = getattr(config, 'residual_in_fp32', False)
|
||||
# These 2 options are for OPT-350m
|
||||
self.prenorm = getattr(config, 'prenorm', True)
|
||||
use_rms_norm = getattr(config, 'rms_norm', False)
|
||||
word_embed_proj_dim = getattr(config, 'word_embed_proj_dim', None)
|
||||
# For GPT-J, GPT-NeoX
|
||||
self.parallel_block = getattr(config, 'parallel_block', False)
|
||||
|
||||
if process_group is None:
|
||||
self.embeddings = GPT2Embeddings(
|
||||
config.hidden_size, vocab_size, config.max_position_embeddings,
|
||||
word_embed_proj_dim=word_embed_proj_dim, **factory_kwargs
|
||||
)
|
||||
else:
|
||||
self.embeddings = ParallelGPT2Embeddings(
|
||||
config.hidden_size, vocab_size, config.max_position_embeddings,
|
||||
process_group=process_group, sequence_parallel=self.sequence_parallel,
|
||||
**factory_kwargs
|
||||
)
|
||||
|
||||
# We change the order of dropout, residual and layer norm:
|
||||
# Instead of LN -> Attn / MLP -> Dropout -> Add, we do:
|
||||
# Dropout -> Add -> LN -> Attn / MLP, returning both the residual branch (output of Add) and
|
||||
# the main branch (output of MLP). The model definition is unchanged, but the mapping of the
|
||||
# nn.Dropout probabilities are changed.
|
||||
# This is for performance reason: we can fuse dropout + add + layer_norm.
|
||||
self.layers = nn.ModuleList([create_block(config, layer_idx=i, process_group=process_group,
|
||||
**factory_kwargs)
|
||||
for i in range(config.num_hidden_layers)])
|
||||
|
||||
self.fused_dropout_add_ln = getattr(config, 'fused_dropout_add_ln', False)
|
||||
if self.fused_dropout_add_ln:
|
||||
if ((not self.parallel_block and dropout_add_layer_norm is None)
|
||||
or (self.parallel_block and dropout_add_layer_norm_parallel_residual is None)):
|
||||
raise ImportError('dropout_layer_norm is not installed')
|
||||
if self.prenorm:
|
||||
self.drop_f = nn.Dropout(config.resid_pdrop)
|
||||
norm_cls = nn.LayerNorm if not use_rms_norm else RMSNorm
|
||||
self.ln_f = norm_cls(config.hidden_size, eps=config.layer_norm_epsilon,
|
||||
**factory_kwargs)
|
||||
if process_group is not None:
|
||||
for p in self.ln_f.parameters():
|
||||
# Mark the norm parameters as "shared_params" so that we sync their values at init.
|
||||
p._shared_params = True
|
||||
# Mark the norm params as "sequence_parallel" so we run all-reduce on their grads.
|
||||
if self.sequence_parallel:
|
||||
p._sequence_parallel = True
|
||||
|
||||
self.apply(partial(_init_weights, n_layer=config.num_hidden_layers,
|
||||
initializer_range=config.initializer_range))
|
||||
self.tie_weights()
|
||||
|
||||
def tie_weights(self):
|
||||
if self.process_group is not None:
|
||||
sync_shared_params(self, self.process_group)
|
||||
|
||||
def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
|
||||
return {i: layer.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs)
|
||||
for i, layer in enumerate(self.layers)}
|
||||
|
||||
def forward(self, input_ids, position_ids=None, inference_params=None):
|
||||
# If using Tensor Parallel with sequence parallel, we combine the batch and the seqlen
|
||||
# dimensions so that we can split on it easily, in case of small batch size.
|
||||
# Only the attention layers need to know the seqlen.
|
||||
embedding_kwargs = ({'combine_batch_seqlen_dim': True}
|
||||
if self.process_group is not None and self.sequence_parallel else {})
|
||||
hidden_states = self.embeddings(input_ids, position_ids=position_ids, **embedding_kwargs)
|
||||
if self.parallel_block:
|
||||
hidden_states2 = None
|
||||
residual = None
|
||||
mixer_kwargs = ({'seqlen': input_ids.shape[1]}
|
||||
if self.process_group is not None and self.sequence_parallel else {})
|
||||
if inference_params is not None:
|
||||
mixer_kwargs['inference_params'] = inference_params
|
||||
for layer in self.layers:
|
||||
if self.prenorm:
|
||||
if not self.parallel_block:
|
||||
hidden_states, residual = layer(hidden_states, residual,
|
||||
mixer_kwargs=mixer_kwargs)
|
||||
else:
|
||||
hidden_states, hidden_states2, residual = layer(
|
||||
hidden_states, hidden_states2, residual, mixer_kwargs=mixer_kwargs
|
||||
)
|
||||
else:
|
||||
hidden_states = layer(hidden_states, mixer_kwargs=mixer_kwargs)
|
||||
if self.prenorm:
|
||||
if not self.fused_dropout_add_ln:
|
||||
dropped = self.drop_f(hidden_states)
|
||||
if not self.parallel_block:
|
||||
residual = (dropped + residual) if residual is not None else dropped
|
||||
else:
|
||||
dropped2 = self.drop_f(hidden_states2)
|
||||
residual = ((residual + dropped + dropped2)
|
||||
if residual is not None else dropped + dropped2)
|
||||
hidden_states = self.ln_f(residual.to(dtype=self.ln_f.weight.dtype))
|
||||
else:
|
||||
# Set prenorm=False here since we don't need the residual
|
||||
if not self.parallel_block:
|
||||
fused_add_norm_fn = (dropout_add_rms_norm if isinstance(self.ln_f, RMSNorm)
|
||||
else dropout_add_layer_norm)
|
||||
hidden_states = fused_add_norm_fn(
|
||||
hidden_states, residual, self.ln_f.weight, self.ln_f.bias,
|
||||
self.drop_f.p if self.training else 0.0, self.ln_f.eps, prenorm=False,
|
||||
residual_in_fp32=self.residual_in_fp32
|
||||
)
|
||||
else:
|
||||
fused_add_norm_fn = (dropout_add_rms_norm_parallel_residual
|
||||
if isinstance(self.ln_f, RMSNorm)
|
||||
else dropout_add_layer_norm_parallel_residual)
|
||||
hidden_states, _ = fused_add_norm_fn(
|
||||
hidden_states, hidden_states2, residual, self.ln_f.weight, self.ln_f.bias,
|
||||
None, None, self.drop_f.p if self.training else 0.0, self.ln_f.eps,
|
||||
prenorm=False, residual_in_fp32=self.residual_in_fp32
|
||||
)
|
||||
return hidden_states
|
||||
|
||||
|
||||
class GPTLMHeadModel(GPTPreTrainedModel, GenerationMixin):
|
||||
|
||||
def __init__(self, config: GPT2Config, process_group=None, device=None, dtype=None):
|
||||
factory_kwargs = {'device': device, 'dtype': dtype}
|
||||
super().__init__(config)
|
||||
self.process_group = process_group
|
||||
self.transformer = GPTModel(config, process_group=process_group, **factory_kwargs)
|
||||
self.tie_word_embeddings = getattr(config, 'tie_word_embeddings', True)
|
||||
lm_head_bias = getattr(config, 'lm_head_bias', False)
|
||||
pad_vocab_size_multiple = getattr(config, 'pad_vocab_size_multiple', 1)
|
||||
vocab_size = (math.ceil(config.vocab_size / pad_vocab_size_multiple)
|
||||
* pad_vocab_size_multiple)
|
||||
# This option is for OPT-350m
|
||||
word_embed_proj_dim = getattr(config, 'word_embed_proj_dim', None)
|
||||
embed_dim = config.n_embd if word_embed_proj_dim is None else word_embed_proj_dim
|
||||
if word_embed_proj_dim is not None:
|
||||
self.project_out = nn.Linear(config.n_embd, embed_dim, bias=False, **factory_kwargs)
|
||||
else:
|
||||
self.project_out = None
|
||||
if process_group is None:
|
||||
self.lm_head = nn.Linear(embed_dim, vocab_size, bias=lm_head_bias, **factory_kwargs)
|
||||
else:
|
||||
if ColumnParallelLinear is None:
|
||||
raise ImportError('fused_dense_lib is not installed')
|
||||
self.lm_head = ColumnParallelLinear(
|
||||
embed_dim, vocab_size, process_group, bias=lm_head_bias,
|
||||
sequence_parallel=getattr(config, 'sequence_parallel', True), **factory_kwargs
|
||||
)
|
||||
# Initialize weights and apply final processing
|
||||
self.apply(partial(_init_weights, n_layer=config.num_hidden_layers,
|
||||
initializer_range=config.initializer_range))
|
||||
self.tie_weights()
|
||||
|
||||
def tie_weights(self):
|
||||
if self.tie_word_embeddings:
|
||||
self.lm_head.weight = self.transformer.embeddings.word_embeddings.weight
|
||||
if self.process_group is not None:
|
||||
sync_shared_params(self, self.process_group)
|
||||
|
||||
def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
|
||||
return self.transformer.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype,
|
||||
**kwargs)
|
||||
|
||||
def forward(self, input_ids, position_ids=None, inference_params=None, last_token_only=False):
|
||||
"""
|
||||
inference_params: for generation. Adapted from Megatron-LM (and Apex)
|
||||
https://github.com/NVIDIA/apex/blob/3ff1a10f72ec07067c4e44759442329804ac5162/apex/transformer/testing/standalone_transformer_lm.py#L470
|
||||
last_token_only: whether to return the logit for the last token only,
|
||||
of shape (batch_size, vocab_size)
|
||||
"""
|
||||
hidden_states = self.transformer(input_ids, position_ids=position_ids,
|
||||
inference_params=inference_params)
|
||||
if last_token_only:
|
||||
hidden_states = hidden_states[:, -1]
|
||||
if self.project_out is not None:
|
||||
hidden_states = self.project_out(hidden_states)
|
||||
lm_logits = self.lm_head(hidden_states)
|
||||
# During inference, we want the full logit for sampling
|
||||
if isinstance(self.lm_head, ColumnParallelLinear) and inference_params is not None:
|
||||
lm_logits, _ = all_gather_raw(lm_logits, self.lm_head.process_group)
|
||||
lm_logits = rearrange(lm_logits, '(n b) ... d -> b ... (n d)', b=hidden_states.shape[0])
|
||||
CausalLMOutput = namedtuple('CausalLMOutput', ['logits'])
|
||||
return CausalLMOutput(logits=lm_logits)
|
||||
|
||||
def load_state_dict(self, state_dict, strict=True):
|
||||
# Remapping from our checkpoints that used a different ordering of layers in the block
|
||||
# Previous: Attn / MLP -> Dropout -> Add -> LN
|
||||
# Current: Dropout -> Add -> LN -> Attn / MLP
|
||||
if 'transformer.ln_0.weight' in state_dict:
|
||||
n_layers = len(self.transformer.layers)
|
||||
ln_weight = state_dict.pop(f'transformer.layers.{n_layers - 1}.norm2.weight')
|
||||
ln_bias = state_dict.pop(f'transformer.layers.{n_layers - 1}.norm2.bias')
|
||||
state_dict['transformer.ln_f.weight'] = ln_weight
|
||||
state_dict['transformer.ln_f.bias'] = ln_bias
|
||||
for l in reversed(range(n_layers)):
|
||||
ln_weight = state_dict.pop(f'transformer.layers.{l}.norm1.weight')
|
||||
ln_bias = state_dict.pop(f'transformer.layers.{l}.norm1.bias')
|
||||
state_dict[f'transformer.layers.{l}.norm2.weight'] = ln_weight
|
||||
state_dict[f'transformer.layers.{l}.norm2.bias'] = ln_bias
|
||||
if l > 0:
|
||||
ln_weight = state_dict.pop(f'transformer.layers.{l - 1}.norm2.weight')
|
||||
ln_bias = state_dict.pop(f'transformer.layers.{l - 1}.norm2.bias')
|
||||
state_dict[f'transformer.layers.{l}.norm1.weight'] = ln_weight
|
||||
state_dict[f'transformer.layers.{l}.norm1.bias'] = ln_bias
|
||||
ln_weight = state_dict.pop('transformer.ln_0.weight')
|
||||
ln_bias = state_dict.pop('transformer.ln_0.bias')
|
||||
state_dict[f'transformer.layers.0.norm1.weight'] = ln_weight
|
||||
state_dict[f'transformer.layers.0.norm1.bias'] = ln_bias
|
||||
return super().load_state_dict(state_dict, strict=strict)
|
||||
|
||||
|
||||
def shard_state_dict_tp(state_dict, config, world_size, rank):
|
||||
"""Convert the state_dict of a standard GPT model to the state_dict of a GPT model
|
||||
with tensor parallel.
|
||||
"""
|
||||
pad_vocab_size_multiple = getattr(config, 'pad_vocab_size_multiple', 1)
|
||||
vocab_size = (math.ceil(config.vocab_size / pad_vocab_size_multiple) * pad_vocab_size_multiple)
|
||||
assert vocab_size % world_size == 0
|
||||
assert config.hidden_size % world_size == 0
|
||||
inner_dim = config.n_inner if config.n_inner is not None else 4 * config.hidden_size
|
||||
assert inner_dim % world_size == 0
|
||||
|
||||
def shard_first_dim(state_dict, key):
|
||||
if key in state_dict:
|
||||
x = state_dict[key]
|
||||
dim = x.shape[0] // world_size
|
||||
state_dict[key] = x[rank * dim:(rank + 1) * dim]
|
||||
|
||||
def shard_last_dim(state_dict, key):
|
||||
if key in state_dict:
|
||||
x = state_dict[key]
|
||||
dim = x.shape[-1] // world_size
|
||||
state_dict[key] = x[..., rank * dim:(rank + 1) * dim]
|
||||
|
||||
def shard_qkv_headdim(state_dict, key):
|
||||
if key in state_dict:
|
||||
n_head = config.n_head
|
||||
n_head_kv = getattr(config, 'n_head_kv', n_head)
|
||||
assert n_head % world_size == 0 and n_head_kv % world_size == 0
|
||||
if n_head_kv == n_head:
|
||||
x = rearrange(state_dict[key], '(three d) ... -> three d ...', three=3)
|
||||
dim = x.shape[1] // world_size
|
||||
state_dict[key] = rearrange(x[:, rank * dim:(rank + 1) * dim],
|
||||
'three d ... -> (three d) ...')
|
||||
else:
|
||||
n_head_per_rank = n_head // world_size
|
||||
n_head_kv_per_rank = n_head_kv // world_size
|
||||
x = rearrange(state_dict[key], '(nheadqkv headdim) ... -> nheadqkv headdim ...',
|
||||
nheadqkv=n_head + 2 * n_head_kv)
|
||||
state_dict[key] = rearrange(torch.cat([
|
||||
x[rank * n_head_per_rank:(rank + 1) * n_head_per_rank],
|
||||
x[n_head + rank * n_head_kv_per_rank:n_head + (rank + 1) * n_head_kv_per_rank],
|
||||
x[n_head + n_head_kv + rank * n_head_kv_per_rank:n_head + n_head_kv + (rank + 1) * n_head_kv_per_rank],
|
||||
], dim=0), "nheadqkv headdim ... -> (nheadqkv headdim) ...")
|
||||
|
||||
shard_first_dim(state_dict, 'transformer.embeddings.word_embeddings.weight')
|
||||
if 'lm_head.weight' in state_dict:
|
||||
shard_first_dim(state_dict, 'lm_head.weight')
|
||||
if 'transformer.embeddings.position_embeddings.weight' in state_dict:
|
||||
shard_last_dim(state_dict, 'transformer.embeddings.position_embeddings.weight')
|
||||
for i in range(config.num_hidden_layers):
|
||||
shard_qkv_headdim(state_dict, f'transformer.layers.{i}.mixer.Wqkv.weight')
|
||||
shard_qkv_headdim(state_dict, f'transformer.layers.{i}.mixer.Wqkv.bias')
|
||||
shard_last_dim(state_dict, f'transformer.layers.{i}.mixer.out_proj.weight')
|
||||
if rank != 0:
|
||||
state_dict.pop(f'transformer.layers.{i}.mixer.out_proj.bias', None)
|
||||
shard_first_dim(state_dict, f'transformer.layers.{i}.mlp.fc1.weight')
|
||||
shard_first_dim(state_dict, f'transformer.layers.{i}.mlp.fc1.bias')
|
||||
shard_last_dim(state_dict, f'transformer.layers.{i}.mlp.fc2.weight')
|
||||
if rank != 0:
|
||||
state_dict.pop(f'transformer.layers.{i}.mlp.fc2.bias', None)
|
||||
return state_dict
|
||||
|
||||
|
||||
def combine_state_dicts_tp(state_dicts, config):
|
||||
"""Convert the state_dict of a standard GPT model to the state_dict of a GPT model
|
||||
with tensor parallel.
|
||||
"""
|
||||
world_size = len(state_dicts)
|
||||
keys = state_dicts[0].keys()
|
||||
pad_vocab_size_multiple = getattr(config, 'pad_vocab_size_multiple', 1)
|
||||
vocab_size = (math.ceil(config.vocab_size / pad_vocab_size_multiple) * pad_vocab_size_multiple)
|
||||
assert vocab_size % world_size == 0
|
||||
assert config.hidden_size % world_size == 0
|
||||
inner_dim = config.n_inner if config.n_inner is not None else 4 * config.hidden_size
|
||||
assert inner_dim % world_size == 0
|
||||
|
||||
# Sometimes the word embeddings are sharded on the 0th dim, sometimes on the 1st dim.
|
||||
# vocab_size // world_size coordinates are nonzero.
|
||||
def combine_word_embeddings(state_dicts, state_dict, key):
|
||||
dim = 0 if state_dicts[0][key].shape[0] == vocab_size // world_size else 1
|
||||
state_dict[key] = torch.cat([s[key] for s in state_dicts], dim=dim)
|
||||
|
||||
def combine_dim(state_dicts, state_dict, key, dim=-1):
|
||||
if key in state_dict:
|
||||
state_dict[key] = torch.cat([s[key] for s in state_dicts], dim=dim)
|
||||
|
||||
def combine_qkv_headdim(state_dicts, state_dict, key):
|
||||
n_head = config.n_head
|
||||
n_head_kv = getattr(config, 'n_head_kv', n_head)
|
||||
assert n_head % world_size == 0 and n_head_kv % world_size == 0
|
||||
n_head_per_rank = n_head // world_size
|
||||
n_head_kv_per_rank = n_head_kv // world_size
|
||||
if key in state_dict:
|
||||
if n_head_kv == n_head:
|
||||
xs = [rearrange(s[key], '(three d) ... -> three d ...', three=3) for s in state_dicts]
|
||||
state_dict[key] = rearrange(torch.cat(xs, dim=1), 'three d ... -> (three d) ...')
|
||||
else:
|
||||
xs = [rearrange(s[key], '(nheadqkv headdim) ... -> nheadqkv headdim ...',
|
||||
nheadqkv=n_head + 2 * n_head_kv) for s in state_dicts]
|
||||
state_dict[key] = rearrange(torch.cat([
|
||||
torch.cat([x[:n_head_per_rank] for x in xs], dim=0),
|
||||
torch.cat([x[n_head_per_rank:n_head_per_rank + n_head_kv_per_rank] for x in xs], dim=0),
|
||||
torch.cat([x[-n_head_kv_per_rank:] for x in xs], dim=0),
|
||||
], dim=0), "nheadqkv headdim ... -> (nheadqkv headdim) ...")
|
||||
|
||||
def combine_gated_mlp(state_dicts, state_dict, key):
|
||||
if key in state_dict:
|
||||
xs = [rearrange(s[key], '(two d) ... -> two d ...', two=2) for s in state_dicts]
|
||||
state_dict[key] = rearrange(torch.cat(xs, dim=1), 'two d ... -> (two d) ...')
|
||||
|
||||
state_dict = state_dicts[0].copy() # don't modify state_dict[0] inplace
|
||||
combine_word_embeddings(state_dicts, state_dict, 'transformer.embeddings.word_embeddings.weight')
|
||||
if 'lm_head.weight' in state_dict:
|
||||
combine_word_embeddings(state_dicts, state_dict, 'lm_head.weight')
|
||||
if 'transformer.embeddings.position_embeddings.weight' in state_dict:
|
||||
combine_dim(state_dicts, state_dict, 'transformer.embeddings.position_embeddings.weight', -1)
|
||||
mlp_combine_fn = (combine_gated_mlp if config.activation_function in ['glu', 'swiglu', 'geglu']
|
||||
else partial(combine_dim, dim=0))
|
||||
for i in range(config.num_hidden_layers):
|
||||
combine_qkv_headdim(state_dicts, state_dict, f'transformer.layers.{i}.mixer.Wqkv.weight')
|
||||
combine_qkv_headdim(state_dicts, state_dict, f'transformer.layers.{i}.mixer.Wqkv.bias')
|
||||
combine_dim(state_dicts, state_dict, f'transformer.layers.{i}.mixer.out_proj.weight', -1)
|
||||
mlp_combine_fn(state_dicts, state_dict, f'transformer.layers.{i}.mlp.fc1.weight')
|
||||
combine_dim(state_dicts, state_dict, f'transformer.layers.{i}.mlp.fc1.bias', 0)
|
||||
combine_dim(state_dicts, state_dict, f'transformer.layers.{i}.mlp.fc2.weight', -1)
|
||||
return state_dict
|
||||
|
||||
|
||||
def remap_state_dict_hf_gpt2(state_dict, config):
|
||||
# Word embedding and position embedding
|
||||
def key_mapping_pos_emb(key):
|
||||
return re.sub(r'^wpe.', 'transformer.embeddings.position_embeddings.', key)
|
||||
state_dict = OrderedDict((key_mapping_pos_emb(k), v) for k, v in state_dict.items())
|
||||
word_embeddings = state_dict.pop('wte.weight')
|
||||
# It's possible that vocab_size is padded to be a multiple of 8, for example.
|
||||
pad_vocab_size_multiple = getattr(config, 'pad_vocab_size_multiple', 1)
|
||||
vocab_size = (math.ceil(config.vocab_size / pad_vocab_size_multiple) * pad_vocab_size_multiple)
|
||||
state_dict['transformer.embeddings.word_embeddings.weight'] = F.pad(
|
||||
word_embeddings, (0, 0, 0, vocab_size - word_embeddings.shape[0])
|
||||
)
|
||||
state_dict['lm_head.weight'] = state_dict['transformer.embeddings.word_embeddings.weight']
|
||||
|
||||
# LayerNorm
|
||||
def key_mapping_ln(key):
|
||||
key = re.sub(r'^ln_f.(weight|bias)', r'transformer.ln_f.\1', key)
|
||||
key = re.sub(r'^h.(\d+).ln_(1|2).(weight|bias)', r'transformer.layers.\1.norm\2.\3', key)
|
||||
return key
|
||||
state_dict = OrderedDict((key_mapping_ln(k), v) for k, v in state_dict.items())
|
||||
|
||||
# MLP
|
||||
for d in range(config.num_hidden_layers):
|
||||
W1 = state_dict.pop(f'h.{d}.mlp.c_fc.weight')
|
||||
state_dict[f'transformer.layers.{d}.mlp.fc1.weight'] = W1.t()
|
||||
W2 = state_dict.pop(f'h.{d}.mlp.c_proj.weight')
|
||||
state_dict[f'transformer.layers.{d}.mlp.fc2.weight'] = W2.t()
|
||||
def key_mapping_mlp(key):
|
||||
key = re.sub(r'^h.(\d+).mlp.c_fc.bias', r'transformer.layers.\1.mlp.fc1.bias', key)
|
||||
key = re.sub(r'^h.(\d+).mlp.c_proj.bias', r'transformer.layers.\1.mlp.fc2.bias', key)
|
||||
return key
|
||||
state_dict = OrderedDict((key_mapping_mlp(k), v) for k, v in state_dict.items())
|
||||
|
||||
# Attention
|
||||
for d in range(config.num_hidden_layers):
|
||||
state_dict.pop(f'h.{d}.attn.bias') # We don't store this bias
|
||||
Wqkv = state_dict.pop(f'h.{d}.attn.c_attn.weight')
|
||||
state_dict[f'transformer.layers.{d}.mixer.Wqkv.weight'] = Wqkv.t()
|
||||
Wout = state_dict.pop(f'h.{d}.attn.c_proj.weight')
|
||||
state_dict[f'transformer.layers.{d}.mixer.out_proj.weight'] = Wout.t()
|
||||
def key_mapping_attn(key):
|
||||
key = re.sub(r'^h.(\d+).attn.c_attn.bias', r'transformer.layers.\1.mixer.Wqkv.bias', key)
|
||||
key = re.sub(r'^h.(\d+).attn.c_proj.bias', r'transformer.layers.\1.mixer.out_proj.bias', key)
|
||||
return key
|
||||
state_dict = OrderedDict((key_mapping_attn(k), v) for k, v in state_dict.items())
|
||||
|
||||
return state_dict
|
||||
|
||||
|
||||
def remap_state_dict_megatron(state_dict, config):
|
||||
def key_mapping_transformer(key):
|
||||
key = re.sub(r'^language_model.encoder.', 'transformer.', key)
|
||||
key = re.sub(r'^language_model.', 'transformer.', key)
|
||||
return key
|
||||
state_dict = OrderedDict((key_mapping_transformer(k), v) for k, v in state_dict.items())
|
||||
# Word embedding and position embedding
|
||||
def key_mapping_pos_emb(key):
|
||||
return re.sub(r'^wpe.', 'transformer.embeddings.position_embeddings.', key)
|
||||
state_dict = OrderedDict((key_mapping_pos_emb(k), v) for k, v in state_dict.items())
|
||||
word_embeddings = state_dict.pop('transformer.embedding.word_embeddings.weight')
|
||||
# It's possible that vocab_size is padded to be a multiple of 8, for example.
|
||||
pad_vocab_size_multiple = getattr(config, 'pad_vocab_size_multiple', 1)
|
||||
vocab_size = (math.ceil(word_embeddings.shape[0] / pad_vocab_size_multiple)
|
||||
* pad_vocab_size_multiple)
|
||||
state_dict['transformer.embeddings.word_embeddings.weight'] = F.pad(
|
||||
word_embeddings, (0, 0, 0, vocab_size - word_embeddings.shape[0])
|
||||
)
|
||||
state_dict['lm_head.weight'] = state_dict['transformer.embeddings.word_embeddings.weight']
|
||||
|
||||
# LayerNorm
|
||||
def key_mapping_ln(key):
|
||||
key = re.sub(r'^transformer.final_layernorm.(weight|bias)', r'transformer.ln_f.\1', key)
|
||||
key = re.sub(r'^transformer.layers.(\d+).input_layernorm.(weight|bias)',
|
||||
r'transformer.layers.\1.norm1.\2', key)
|
||||
key = re.sub(r'^transformer.layers.(\d+).post_attention_layernorm.(weight|bias)',
|
||||
r'transformer.layers.\1.norm2.\2', key)
|
||||
return key
|
||||
state_dict = OrderedDict((key_mapping_ln(k), v) for k, v in state_dict.items())
|
||||
|
||||
# MLP
|
||||
def key_mapping_mlp(key):
|
||||
key = re.sub(r'^transformer.layers.(\d+).mlp.dense_h_to_4h.(weight|bias)',
|
||||
r'transformer.layers.\1.mlp.fc1.\2', key)
|
||||
key = re.sub(r'^transformer.layers.(\d+).mlp.dense_4h_to_h.(weight|bias)',
|
||||
r'transformer.layers.\1.mlp.fc2.\2', key)
|
||||
return key
|
||||
state_dict = OrderedDict((key_mapping_mlp(k), v) for k, v in state_dict.items())
|
||||
|
||||
# Attention
|
||||
def key_mapping_attn(key):
|
||||
key = re.sub(r'^transformer.layers.(\d+).self_attention.rotary_emb.inv_freq',
|
||||
r'transformer.layers.\1.mixer.rotary_emb.inv_freq', key)
|
||||
key = re.sub(r'^transformer.layers.(\d+).self_attention.query_key_value.(weight|bias)',
|
||||
r'transformer.layers.\1.mixer.Wqkv.\2', key)
|
||||
key = re.sub(r'^transformer.layers.(\d+).self_attention.dense.(weight|bias)',
|
||||
r'transformer.layers.\1.mixer.out_proj.\2', key)
|
||||
return key
|
||||
state_dict = OrderedDict((key_mapping_attn(k), v) for k, v in state_dict.items())
|
||||
# Megatron stores Wqkv as ((nheads 3 headdim), hidden_dim)
|
||||
# while we store Wqkv as ((3 nheads headdim), hidden_dim)
|
||||
headdim = config.hidden_size // config.num_attention_heads
|
||||
for d in range(config.num_hidden_layers):
|
||||
Wqkv = state_dict.pop(f'transformer.layers.{d}.mixer.Wqkv.weight')
|
||||
state_dict[f'transformer.layers.{d}.mixer.Wqkv.weight'] = rearrange(
|
||||
Wqkv, '(nheads three headdim) ... -> (three nheads headdim) ...',
|
||||
three=3, headdim=headdim
|
||||
)
|
||||
bqkv = state_dict.pop(f'transformer.layers.{d}.mixer.Wqkv.bias')
|
||||
state_dict[f'transformer.layers.{d}.mixer.Wqkv.bias'] = rearrange(
|
||||
bqkv, '(nheads three headdim) -> (three nheads headdim)',
|
||||
three=3, headdim=headdim
|
||||
)
|
||||
|
||||
return state_dict
|
||||
107
pkgs/xformers/_flash_attn/models/gpt_neox.py
Normal file
107
pkgs/xformers/_flash_attn/models/gpt_neox.py
Normal file
@@ -0,0 +1,107 @@
|
||||
# Copyright (c) 2023, Tri Dao.
|
||||
|
||||
import math
|
||||
import re
|
||||
|
||||
from collections import OrderedDict
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
from einops import rearrange
|
||||
|
||||
from transformers import GPT2Config, GPTNeoXConfig
|
||||
|
||||
|
||||
def remap_state_dict_hf_gpt_neox(state_dict, config):
|
||||
def key_mapping_layers(key):
|
||||
return re.sub(r'^gpt_neox.', 'transformer.', key)
|
||||
state_dict = OrderedDict((key_mapping_layers(k), v) for k, v in state_dict.items())
|
||||
# Word embedding
|
||||
def key_mapping_emb(key):
|
||||
return re.sub(r'^transformer.embed_in.', 'transformer.embeddings.word_embeddings.', key)
|
||||
state_dict = OrderedDict((key_mapping_emb(k), v) for k, v in state_dict.items())
|
||||
word_embeddings = state_dict.pop('transformer.embeddings.word_embeddings.weight')
|
||||
# It's possible that vocab_size is padded to be a multiple of 8, for example.
|
||||
pad_vocab_size_multiple = getattr(config, 'pad_vocab_size_multiple', 1)
|
||||
vocab_size = (math.ceil(config.vocab_size / pad_vocab_size_multiple) * pad_vocab_size_multiple)
|
||||
state_dict['transformer.embeddings.word_embeddings.weight'] = F.pad(
|
||||
word_embeddings, (0, 0, 0, vocab_size - word_embeddings.shape[0])
|
||||
)
|
||||
if getattr(config, 'tie_word_embeddings'):
|
||||
state_dict['lm_head.weight'] = state_dict['transformer.embeddings.word_embeddings.weight']
|
||||
else:
|
||||
output_embeddings = state_dict.pop('embed_out.weight')
|
||||
# It's possible that vocab_size is padded to be a multiple of 8, for example.
|
||||
state_dict['lm_head.weight'] = F.pad(
|
||||
output_embeddings, (0, 0, 0, vocab_size - output_embeddings.shape[0])
|
||||
)
|
||||
|
||||
# LayerNorm
|
||||
def key_mapping_ln(key):
|
||||
key = re.sub(r'^transformer.final_layer_norm.', r'transformer.ln_f.', key)
|
||||
key = re.sub(r'^transformer.layers.(\d+).input_layernorm.', r'transformer.layers.\1.norm1.', key)
|
||||
key = re.sub(r'^transformer.layers.(\d+).post_attention_layernorm.', r'transformer.layers.\1.norm2.', key)
|
||||
return key
|
||||
state_dict = OrderedDict((key_mapping_ln(k), v) for k, v in state_dict.items())
|
||||
|
||||
# MLP
|
||||
def key_mapping_mlp(key):
|
||||
key = re.sub(r'^transformer.layers.(\d+).mlp.dense_h_to_4h.', r'transformer.layers.\1.mlp.fc1.', key)
|
||||
key = re.sub(r'^transformer.layers.(\d+).mlp.dense_4h_to_h.', r'transformer.layers.\1.mlp.fc2.', key)
|
||||
return key
|
||||
state_dict = OrderedDict((key_mapping_mlp(k), v) for k, v in state_dict.items())
|
||||
|
||||
# Attention
|
||||
for l in range(config.n_layer):
|
||||
# We don't store these biases
|
||||
state_dict.pop(f'transformer.layers.{l}.attention.bias')
|
||||
state_dict.pop(f'transformer.layers.{l}.attention.masked_bias')
|
||||
# GPT-NeoX stores Wqkv as ((nheads 3 headdim), hidden_dim)
|
||||
# while we store Wqkv as ((3 nheads headdim), hidden_dim)
|
||||
headdim = config.hidden_size // config.num_attention_heads
|
||||
Wqkv = state_dict.pop(f'transformer.layers.{l}.attention.query_key_value.weight')
|
||||
state_dict[f'transformer.layers.{l}.mixer.Wqkv.weight'] = rearrange(
|
||||
Wqkv, '(nheads three headdim) ... -> (three nheads headdim) ...',
|
||||
three=3, headdim=headdim
|
||||
)
|
||||
bqkv = state_dict.pop(f'transformer.layers.{l}.attention.query_key_value.bias')
|
||||
state_dict[f'transformer.layers.{l}.mixer.Wqkv.bias'] = rearrange(
|
||||
bqkv, '(nheads three headdim) -> (three nheads headdim)',
|
||||
three=3, headdim=headdim
|
||||
)
|
||||
def key_mapping_attn(key):
|
||||
key = re.sub(r'^transformer.layers.(\d+).attention.dense.',
|
||||
r'transformer.layers.\1.mixer.out_proj.', key)
|
||||
key = re.sub(r'^transformer.layers.(\d+).attention.rotary_emb.',
|
||||
r'transformer.layers.\1.mixer.rotary_emb.', key)
|
||||
return key
|
||||
state_dict = OrderedDict((key_mapping_attn(k), v) for k, v in state_dict.items())
|
||||
|
||||
return state_dict
|
||||
|
||||
|
||||
def gpt_neox_config_to_gpt2_config(gpt_neox_config: GPTNeoXConfig) -> GPT2Config:
|
||||
assert gpt_neox_config.rotary_emb_base == 10000
|
||||
return GPT2Config(
|
||||
vocab_size=gpt_neox_config.vocab_size,
|
||||
n_positions=0, # No absolute position embedding
|
||||
n_embd=gpt_neox_config.hidden_size,
|
||||
n_layer=gpt_neox_config.num_hidden_layers,
|
||||
n_head=gpt_neox_config.num_attention_heads,
|
||||
n_inner=gpt_neox_config.intermediate_size,
|
||||
activation_function=gpt_neox_config.hidden_act,
|
||||
resid_pdrop=0.0, # No dropout
|
||||
embd_pdrop=0.0,
|
||||
attn_pdrop=0.0,
|
||||
layer_norm_epsilon=gpt_neox_config.layer_norm_eps,
|
||||
initializer_range=gpt_neox_config.initializer_range,
|
||||
bos_token_id=gpt_neox_config.bos_token_id,
|
||||
eos_token_id=gpt_neox_config.eos_token_id,
|
||||
# These are new arguments not in the original GPT2Config
|
||||
prenorm=True,
|
||||
parallel_block=gpt_neox_config.use_parallel_residual,
|
||||
parallel_block_tied_norm=False,
|
||||
rotary_emb_fraction=gpt_neox_config.rotary_pct,
|
||||
tie_word_embeddings=gpt_neox_config.tie_word_embeddings,
|
||||
)
|
||||
98
pkgs/xformers/_flash_attn/models/gptj.py
Normal file
98
pkgs/xformers/_flash_attn/models/gptj.py
Normal file
@@ -0,0 +1,98 @@
|
||||
# Copyright (c) 2023, Tri Dao.
|
||||
|
||||
import math
|
||||
import re
|
||||
|
||||
from collections import OrderedDict
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
from transformers import GPT2Config, GPTJConfig
|
||||
|
||||
|
||||
def remap_state_dict_hf_gptj(state_dict, config):
|
||||
def key_mapping_layers(key):
|
||||
return re.sub(r'^transformer.h.', 'transformer.layers.', key)
|
||||
state_dict = OrderedDict((key_mapping_layers(k), v) for k, v in state_dict.items())
|
||||
# Word embedding
|
||||
def key_mapping_emb(key):
|
||||
return re.sub(r'^transformer.wte.', 'transformer.embeddings.word_embeddings.', key)
|
||||
state_dict = OrderedDict((key_mapping_emb(k), v) for k, v in state_dict.items())
|
||||
word_embeddings = state_dict.pop('transformer.embeddings.word_embeddings.weight')
|
||||
# It's possible that vocab_size is padded to be a multiple of 8, for example.
|
||||
pad_vocab_size_multiple = getattr(config, 'pad_vocab_size_multiple', 1)
|
||||
vocab_size = (math.ceil(config.vocab_size / pad_vocab_size_multiple) * pad_vocab_size_multiple)
|
||||
state_dict['transformer.embeddings.word_embeddings.weight'] = F.pad(
|
||||
word_embeddings, (0, 0, 0, vocab_size - word_embeddings.shape[0])
|
||||
)
|
||||
if getattr(config, 'tie_word_embeddings'):
|
||||
state_dict['lm_head.weight'] = state_dict['transformer.embeddings.word_embeddings.weight']
|
||||
else:
|
||||
output_embeddings = state_dict.pop('lm_head.weight')
|
||||
# It's possible that vocab_size is padded to be a multiple of 8, for example.
|
||||
state_dict['lm_head.weight'] = F.pad(
|
||||
output_embeddings, (0, 0, 0, vocab_size - output_embeddings.shape[0])
|
||||
)
|
||||
output_embeddings_bias = state_dict.pop('lm_head.bias')
|
||||
state_dict['lm_head.bias'] = F.pad(
|
||||
output_embeddings_bias, (0, vocab_size - output_embeddings_bias.shape[0])
|
||||
)
|
||||
|
||||
# LayerNorm
|
||||
def key_mapping_ln(key):
|
||||
return re.sub(r'^transformer.layers.(\d+).ln_1.', r'transformer.layers.\1.norm1.', key)
|
||||
state_dict = OrderedDict((key_mapping_ln(k), v) for k, v in state_dict.items())
|
||||
|
||||
# MLP
|
||||
def key_mapping_mlp(key):
|
||||
key = re.sub(r'^transformer.layers.(\d+).mlp.fc_in.', r'transformer.layers.\1.mlp.fc1.', key)
|
||||
key = re.sub(r'^transformer.layers.(\d+).mlp.fc_out.', r'transformer.layers.\1.mlp.fc2.', key)
|
||||
return key
|
||||
state_dict = OrderedDict((key_mapping_mlp(k), v) for k, v in state_dict.items())
|
||||
|
||||
# Attention
|
||||
for l in range(config.n_layer):
|
||||
Wq = state_dict.pop(f'transformer.layers.{l}.attn.q_proj.weight')
|
||||
Wk = state_dict.pop(f'transformer.layers.{l}.attn.k_proj.weight')
|
||||
Wv = state_dict.pop(f'transformer.layers.{l}.attn.v_proj.weight')
|
||||
state_dict[f'transformer.layers.{l}.mixer.Wqkv.weight'] = torch.cat([Wq, Wk, Wv], dim=0)
|
||||
# We don't store these biases
|
||||
state_dict.pop(f'transformer.layers.{l}.attn.bias')
|
||||
state_dict.pop(f'transformer.layers.{l}.attn.masked_bias')
|
||||
def key_mapping_attn(key):
|
||||
return re.sub(r'^transformer.layers.(\d+).attn.out_proj.',
|
||||
r'transformer.layers.\1.mixer.out_proj.', key)
|
||||
state_dict = OrderedDict((key_mapping_attn(k), v) for k, v in state_dict.items())
|
||||
|
||||
return state_dict
|
||||
|
||||
|
||||
def gptj_config_to_gpt2_config(gptj_config: GPTJConfig) -> GPT2Config:
|
||||
headdim = gptj_config.n_embd // gptj_config.n_head
|
||||
return GPT2Config(
|
||||
vocab_size=gptj_config.vocab_size,
|
||||
n_positions=0, # No absolute position embedding
|
||||
n_embd=gptj_config.n_embd,
|
||||
n_layer=gptj_config.n_layer,
|
||||
n_head=gptj_config.n_head,
|
||||
n_inner=gptj_config.n_inner,
|
||||
activation_function=gptj_config.activation_function,
|
||||
resid_pdrop=gptj_config.resid_pdrop,
|
||||
embd_pdrop=gptj_config.embd_pdrop,
|
||||
attn_pdrop=gptj_config.attn_pdrop,
|
||||
layer_norm_epsilon=gptj_config.layer_norm_epsilon,
|
||||
initializer_range=gptj_config.initializer_range,
|
||||
bos_token_id=gptj_config.bos_token_id,
|
||||
eos_token_id=gptj_config.eos_token_id,
|
||||
# These are new arguments not in the original GPT2Config
|
||||
prenorm=True,
|
||||
parallel_block=True,
|
||||
parallel_block_tied_norm=True,
|
||||
rotary_emb_fraction=gptj_config.rotary_dim / headdim,
|
||||
rotary_emb_interleaved=True,
|
||||
tie_word_embeddings=False,
|
||||
qkv_proj_bias=False,
|
||||
out_proj_bias=False,
|
||||
lm_head_bias=True,
|
||||
)
|
||||
124
pkgs/xformers/_flash_attn/models/llama.py
Normal file
124
pkgs/xformers/_flash_attn/models/llama.py
Normal file
@@ -0,0 +1,124 @@
|
||||
# Copyright (c) 2023, Tri Dao.
|
||||
|
||||
import math
|
||||
import json
|
||||
import re
|
||||
from pathlib import Path
|
||||
|
||||
from collections import OrderedDict
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
from transformers import GPT2Config, LlamaConfig
|
||||
|
||||
|
||||
def remap_state_dict_meta_llama(state_dict, config):
|
||||
def key_mapping_layers(key):
|
||||
return f'transformer.{key}' if not key.startswith('output.') else key
|
||||
state_dict = OrderedDict((key_mapping_layers(k), v) for k, v in state_dict.items())
|
||||
# Word embedding
|
||||
def key_mapping_emb(key):
|
||||
return re.sub(r'^transformer.tok_embeddings.', 'transformer.embeddings.word_embeddings.', key)
|
||||
state_dict = OrderedDict((key_mapping_emb(k), v) for k, v in state_dict.items())
|
||||
word_embeddings = state_dict.pop('transformer.embeddings.word_embeddings.weight')
|
||||
# It's possible that vocab_size is padded to be a multiple of 8, for example.
|
||||
pad_vocab_size_multiple = getattr(config, 'pad_vocab_size_multiple', 1)
|
||||
vocab_size = (math.ceil(word_embeddings.shape[0] / pad_vocab_size_multiple)
|
||||
* pad_vocab_size_multiple)
|
||||
state_dict['transformer.embeddings.word_embeddings.weight'] = F.pad(
|
||||
word_embeddings, (0, 0, 0, vocab_size - word_embeddings.shape[0])
|
||||
)
|
||||
if getattr(config, 'tie_word_embeddings'):
|
||||
state_dict['lm_head.weight'] = state_dict['transformer.embeddings.word_embeddings.weight']
|
||||
else:
|
||||
output_embeddings = state_dict.pop('output.weight')
|
||||
# Need to recompute vocab_size since LLaMa shards the word embeddings and output embeddings
|
||||
# differently.
|
||||
vocab_size = (math.ceil(output_embeddings.shape[0] / pad_vocab_size_multiple)
|
||||
* pad_vocab_size_multiple)
|
||||
# It's possible that vocab_size is padded to be a multiple of 8, for example.
|
||||
state_dict['lm_head.weight'] = F.pad(
|
||||
output_embeddings, (0, 0, 0, vocab_size - output_embeddings.shape[0])
|
||||
)
|
||||
|
||||
# LayerNorm
|
||||
def key_mapping_ln(key):
|
||||
key = re.sub(r'^transformer.norm.', r'transformer.ln_f.', key)
|
||||
key = re.sub(r'^transformer.layers.(\d+).attention_norm.', r'transformer.layers.\1.norm1.', key)
|
||||
key = re.sub(r'^transformer.layers.(\d+).ffn_norm.', r'transformer.layers.\1.norm2.', key)
|
||||
return key
|
||||
state_dict = OrderedDict((key_mapping_ln(k), v) for k, v in state_dict.items())
|
||||
|
||||
# MLP
|
||||
for l in range(config.n_layer):
|
||||
w1 = state_dict.pop(f'transformer.layers.{l}.feed_forward.w1.weight')
|
||||
w3 = state_dict.pop(f'transformer.layers.{l}.feed_forward.w3.weight')
|
||||
# Our ordering is different
|
||||
state_dict[f'transformer.layers.{l}.mlp.fc1.weight'] = torch.cat([w3, w1], dim=0)
|
||||
def key_mapping_mlp(key):
|
||||
return re.sub(r'^transformer.layers.(\d+).feed_forward.w2.',
|
||||
r'transformer.layers.\1.mlp.fc2.', key)
|
||||
state_dict = OrderedDict((key_mapping_mlp(k), v) for k, v in state_dict.items())
|
||||
|
||||
# Attention
|
||||
for l in range(config.n_layer):
|
||||
Wq = state_dict.pop(f'transformer.layers.{l}.attention.wq.weight')
|
||||
Wk = state_dict.pop(f'transformer.layers.{l}.attention.wk.weight')
|
||||
Wv = state_dict.pop(f'transformer.layers.{l}.attention.wv.weight')
|
||||
state_dict[f'transformer.layers.{l}.mixer.Wqkv.weight'] = torch.cat([Wq, Wk, Wv], dim=0)
|
||||
# We don't store these
|
||||
state_dict.pop(f'transformer.layers.{l}.attention.inner_attention.rope.freqs', None)
|
||||
def key_mapping_attn(key):
|
||||
return re.sub(r'^transformer.layers.(\d+).attention.wo.',
|
||||
r'transformer.layers.\1.mixer.out_proj.', key)
|
||||
state_dict = OrderedDict((key_mapping_attn(k), v) for k, v in state_dict.items())
|
||||
|
||||
return state_dict
|
||||
|
||||
|
||||
def config_from_checkpoint(checkpoint_path: str, model_name: str) -> LlamaConfig:
|
||||
"""Load a LlamaConfig from a checkpoint path."""
|
||||
with open(Path(checkpoint_path) / model_name / 'params.json') as f:
|
||||
params = json.load(f)
|
||||
config = LlamaConfig(hidden_size=params['dim'], intermediate_size=None,
|
||||
num_attention_heads=params['n_heads'],
|
||||
num_hidden_layers=params['n_layers'],
|
||||
rms_norm_eps=params['norm_eps'])
|
||||
return config
|
||||
|
||||
|
||||
def state_dicts_from_checkpoint(checkpoint_path: str, model_name: str) -> dict:
|
||||
# Need to sort, otherwise we mess up the ordering and the weights are wrong
|
||||
return [torch.load(path, map_location='cpu')
|
||||
for path in sorted((Path(checkpoint_path) / model_name).glob('consolidated.*.pth'))]
|
||||
|
||||
|
||||
def llama_config_to_gpt2_config(llama_config: LlamaConfig) -> GPT2Config:
|
||||
return GPT2Config(
|
||||
vocab_size=llama_config.vocab_size,
|
||||
n_positions=0, # No absolute position embedding
|
||||
n_embd=llama_config.hidden_size,
|
||||
n_layer=llama_config.num_hidden_layers,
|
||||
n_head=llama_config.num_attention_heads,
|
||||
n_inner=llama_config.intermediate_size,
|
||||
activation_function='swiglu', # Hardcode since HF calls it 'silu'
|
||||
# Llama doesn't have dropout, idk if it's because they only release the inference code
|
||||
resid_pdrop=0.0,
|
||||
embd_pdrop=0.0,
|
||||
attn_pdrop=0.0,
|
||||
layer_norm_epsilon=llama_config.rms_norm_eps,
|
||||
initializer_range=llama_config.initializer_range,
|
||||
bos_token_id=llama_config.bos_token_id,
|
||||
eos_token_id=llama_config.eos_token_id,
|
||||
# These are new arguments not in the original GPT2Config
|
||||
pad_token_id=llama_config.pad_token_id, # Idk if this does anything
|
||||
rms_norm=True,
|
||||
rotary_emb_fraction=1.0,
|
||||
rotary_emb_interleaved=True,
|
||||
tie_word_embeddings=False,
|
||||
qkv_proj_bias=False,
|
||||
out_proj_bias=False,
|
||||
mlp_fc1_bias=False,
|
||||
mlp_fc2_bias=False,
|
||||
)
|
||||
102
pkgs/xformers/_flash_attn/models/opt.py
Normal file
102
pkgs/xformers/_flash_attn/models/opt.py
Normal file
@@ -0,0 +1,102 @@
|
||||
# Copyright (c) 2023, Tri Dao.
|
||||
|
||||
import math
|
||||
import re
|
||||
|
||||
from collections import OrderedDict
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
from transformers import GPT2Config, OPTConfig
|
||||
|
||||
|
||||
def remap_state_dict_hf_opt(state_dict, config):
|
||||
def key_mapping_model(key):
|
||||
key = re.sub(r'^model.decoder.', 'transformer.', key)
|
||||
# The OPT-350m model uses '^decoder' instead of '^model.decoder'
|
||||
key = re.sub(r'^decoder.', 'transformer.', key)
|
||||
return key
|
||||
state_dict = OrderedDict((key_mapping_model(k), v) for k, v in state_dict.items())
|
||||
# Word embedding and position embedding
|
||||
def key_mapping_emb(key):
|
||||
key = re.sub(r'^transformer.embed_tokens.', 'transformer.embeddings.word_embeddings.', key)
|
||||
# The OPT-350m model uses has project_in and project_out
|
||||
key = re.sub(r'^transformer.project_in.', 'transformer.embeddings.project_in.', key)
|
||||
key = re.sub(r'^transformer.project_out.', 'project_out.', key)
|
||||
key = re.sub(r'^transformer.embed_positions.',
|
||||
'transformer.embeddings.position_embeddings.', key)
|
||||
return key
|
||||
state_dict = OrderedDict((key_mapping_emb(k), v) for k, v in state_dict.items())
|
||||
# OPT uses the first 2 indices of pos_emb for padding tokens
|
||||
pos_embeddings = state_dict.pop('transformer.embeddings.position_embeddings.weight')
|
||||
state_dict['transformer.embeddings.position_embeddings.weight'] = pos_embeddings[2:]
|
||||
word_embeddings = state_dict.pop('transformer.embeddings.word_embeddings.weight')
|
||||
# It's possible that vocab_size is padded to be a multiple of 8, for example.
|
||||
pad_vocab_size_multiple = getattr(config, 'pad_vocab_size_multiple', 1)
|
||||
vocab_size = (math.ceil(config.vocab_size / pad_vocab_size_multiple) * pad_vocab_size_multiple)
|
||||
state_dict['transformer.embeddings.word_embeddings.weight'] = F.pad(
|
||||
word_embeddings, (0, 0, 0, vocab_size - word_embeddings.shape[0])
|
||||
)
|
||||
state_dict['lm_head.weight'] = state_dict['transformer.embeddings.word_embeddings.weight']
|
||||
|
||||
# LayerNorm
|
||||
def key_mapping_ln(key):
|
||||
key = re.sub(r'^transformer.final_layer_norm.', r'transformer.ln_f.', key)
|
||||
# The OPT-175B checkpoint calls this 'decoder.layer_norm' instead of 'decoder.final_layer_norm'
|
||||
key = re.sub(r'^transformer.layer_norm.', r'transformer.ln_f.', key)
|
||||
key = re.sub(r'^transformer.layers.(\d+).self_attn_layer_norm.',
|
||||
r'transformer.layers.\1.norm1.', key)
|
||||
key = re.sub(r'^transformer.layers.(\d+).final_layer_norm.',
|
||||
r'transformer.layers.\1.norm2.', key)
|
||||
return key
|
||||
state_dict = OrderedDict((key_mapping_ln(k), v) for k, v in state_dict.items())
|
||||
|
||||
# MLP
|
||||
def key_mapping_mlp(key):
|
||||
return re.sub(r'^transformer.layers.(\d+).fc(1|2).',
|
||||
r'transformer.layers.\1.mlp.fc\2.', key)
|
||||
state_dict = OrderedDict((key_mapping_mlp(k), v) for k, v in state_dict.items())
|
||||
|
||||
# Attention
|
||||
for l in range(config.n_layer):
|
||||
Wq = state_dict.pop(f'transformer.layers.{l}.self_attn.q_proj.weight')
|
||||
Wk = state_dict.pop(f'transformer.layers.{l}.self_attn.k_proj.weight')
|
||||
Wv = state_dict.pop(f'transformer.layers.{l}.self_attn.v_proj.weight')
|
||||
bq = state_dict.pop(f'transformer.layers.{l}.self_attn.q_proj.bias')
|
||||
bk = state_dict.pop(f'transformer.layers.{l}.self_attn.k_proj.bias')
|
||||
bv = state_dict.pop(f'transformer.layers.{l}.self_attn.v_proj.bias')
|
||||
state_dict[f'transformer.layers.{l}.mixer.Wqkv.weight'] = torch.cat([Wq, Wk, Wv], dim=0)
|
||||
state_dict[f'transformer.layers.{l}.mixer.Wqkv.bias'] = torch.cat([bq, bk, bv], dim=0)
|
||||
def key_mapping_attn(key):
|
||||
return re.sub(r'^transformer.layers.(\d+).self_attn.out_proj.',
|
||||
r'transformer.layers.\1.mixer.out_proj.', key)
|
||||
state_dict = OrderedDict((key_mapping_attn(k), v) for k, v in state_dict.items())
|
||||
|
||||
return state_dict
|
||||
|
||||
|
||||
def opt_config_to_gpt2_config(opt_config: OPTConfig) -> GPT2Config:
|
||||
assert opt_config.layerdrop == 0.0
|
||||
assert opt_config.layer_norm_elementwise_affine
|
||||
word_embed_proj_dim = (None if opt_config.word_embed_proj_dim == opt_config.hidden_size
|
||||
else opt_config.word_embed_proj_dim)
|
||||
return GPT2Config(
|
||||
vocab_size=opt_config.vocab_size,
|
||||
n_positions=opt_config.max_position_embeddings,
|
||||
n_embd=opt_config.hidden_size,
|
||||
n_layer=opt_config.num_hidden_layers,
|
||||
n_head=opt_config.num_attention_heads,
|
||||
n_inner=opt_config.ffn_dim,
|
||||
activation_function=opt_config.activation_function,
|
||||
resid_pdrop=opt_config.dropout,
|
||||
# HF's implementation of OPT doesn't seem to have embedding dropout
|
||||
embd_pdrop=opt_config.dropout,
|
||||
attn_pdrop=opt_config.attention_dropout,
|
||||
initializer_range=opt_config.init_std,
|
||||
bos_token_id=opt_config.bos_token_id,
|
||||
eos_token_id=opt_config.eos_token_id,
|
||||
# These are new arguments not in the original GPT2Config
|
||||
prenorm=opt_config.do_layer_norm_before,
|
||||
word_embed_proj_dim=word_embed_proj_dim
|
||||
)
|
||||
304
pkgs/xformers/_flash_attn/models/vit.py
Normal file
304
pkgs/xformers/_flash_attn/models/vit.py
Normal file
@@ -0,0 +1,304 @@
|
||||
# Copyright (c) 2022, Tri Dao.
|
||||
# Inspired by / adapted from https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
|
||||
import math
|
||||
import re
|
||||
from functools import partial
|
||||
from copy import deepcopy
|
||||
|
||||
from collections import OrderedDict
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from torch.nn.init import trunc_normal_
|
||||
|
||||
from torchvision.ops import StochasticDepth
|
||||
|
||||
from einops import rearrange
|
||||
|
||||
from timm.models.helpers import named_apply
|
||||
from flash_attn.layers.patch_embed import PatchEmbed
|
||||
|
||||
from flash_attn.modules.mha import MHA
|
||||
from flash_attn.modules.mlp import Mlp, FusedMLP
|
||||
from flash_attn.modules.block import Block
|
||||
|
||||
try:
|
||||
from flash_attn.ops.layer_norm import dropout_add_layer_norm
|
||||
except ImportError:
|
||||
dropout_add_layer_norm = None
|
||||
|
||||
|
||||
def create_mixer_cls(num_heads, qkv_bias, attn_drop, use_flash_attn, fused_bias_fc,
|
||||
cross_attn=False):
|
||||
mixer_cls = partial(MHA, num_heads=num_heads, cross_attn=cross_attn, bias=qkv_bias,
|
||||
dropout=attn_drop, fused_bias_fc=fused_bias_fc,
|
||||
use_flash_attn=use_flash_attn)
|
||||
return mixer_cls
|
||||
|
||||
|
||||
def create_mlp_cls(embed_dim, mlp_ratio, act_layer, fused_mlp):
|
||||
inner_dim = int(embed_dim * mlp_ratio)
|
||||
if not fused_mlp:
|
||||
mlp_cls = partial(Mlp, hidden_features=inner_dim, activation=act_layer())
|
||||
else:
|
||||
mlp_cls = partial(FusedMLP, hidden_features=inner_dim)
|
||||
return mlp_cls
|
||||
|
||||
|
||||
def create_block(embed_dim, num_heads, mlp_ratio, qkv_bias, drop_rate, attn_drop_rate,
|
||||
drop_path1, drop_path2, norm_layer, act_layer, use_flash_attn, fused_bias_fc,
|
||||
fused_mlp, fused_dropout_add_ln, layer_idx=None, n_layer=None,
|
||||
last_layer_subset=False):
|
||||
mixer_cls = create_mixer_cls(num_heads, qkv_bias, attn_drop_rate, use_flash_attn, fused_bias_fc,
|
||||
cross_attn=(last_layer_subset and layer_idx == n_layer - 1))
|
||||
mlp_cls = create_mlp_cls(embed_dim, mlp_ratio, act_layer, fused_mlp)
|
||||
# TD [2022-10-15]: Force residual in fp32 in case of DeepSpeed
|
||||
block = Block(embed_dim, mixer_cls, mlp_cls, norm_cls=norm_layer,
|
||||
prenorm=True, resid_dropout1=drop_rate, resid_dropout2=drop_rate,
|
||||
drop_path1=drop_path1, drop_path2=drop_path2,
|
||||
fused_dropout_add_ln=fused_dropout_add_ln, residual_in_fp32=True)
|
||||
return block
|
||||
|
||||
|
||||
class VisionTransformer(nn.Module):
|
||||
""" Vision Transformer
|
||||
A PyTorch impl of : `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale`
|
||||
- https://arxiv.org/abs/2010.11929
|
||||
"""
|
||||
def __init__(
|
||||
self,
|
||||
img_size=224,
|
||||
patch_size=16,
|
||||
in_chans=3,
|
||||
num_classes=1000,
|
||||
global_pool='token',
|
||||
embed_dim=768,
|
||||
depth=12,
|
||||
num_heads=12,
|
||||
mlp_ratio=4.,
|
||||
qkv_bias=True,
|
||||
init_values=None,
|
||||
class_token=True,
|
||||
no_embed_class=False,
|
||||
pre_norm=False,
|
||||
fc_norm=None,
|
||||
drop_rate=0.,
|
||||
attn_drop_rate=0.,
|
||||
drop_path_rate=0.,
|
||||
weight_init='',
|
||||
embed_layer=PatchEmbed,
|
||||
norm_layer=None,
|
||||
act_layer=None,
|
||||
use_flash_attn=False,
|
||||
fused_bias_fc=False,
|
||||
fused_mlp=False,
|
||||
fused_dropout_add_ln=False,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
img_size (int, tuple): input image size
|
||||
patch_size (int, tuple): patch size
|
||||
in_chans (int): number of input channels
|
||||
num_classes (int): number of classes for classification head
|
||||
global_pool (str): type of global pooling for final sequence (default: 'token')
|
||||
embed_dim (int): embedding dimension
|
||||
depth (int): depth of transformer
|
||||
num_heads (int): number of attention heads
|
||||
mlp_ratio (int): ratio of mlp hidden dim to embedding dim
|
||||
qkv_bias (bool): enable bias for qkv if True
|
||||
init_values: (float): layer-scale init values
|
||||
class_token (bool): use class token
|
||||
fc_norm (Optional[bool]): pre-fc norm after pool, set if global_pool == 'avg' if None (default: None)
|
||||
drop_rate (float): dropout rate
|
||||
attn_drop_rate (float): attention dropout rate
|
||||
drop_path_rate (float): stochastic depth rate
|
||||
weight_init (str): weight init scheme
|
||||
embed_layer (nn.Module): patch embedding layer
|
||||
norm_layer: (nn.Module): normalization layer
|
||||
act_layer: (nn.Module): MLP activation layer
|
||||
"""
|
||||
super().__init__()
|
||||
assert global_pool == 'token', 'Only support pooling with CLS token'
|
||||
assert class_token
|
||||
assert init_values is None, 'LayerScale is not supported yet'
|
||||
assert weight_init == ''
|
||||
assert fc_norm is None
|
||||
# pre_norm seems redundant, as there's a LayerNorm right at the start of each block, idk
|
||||
assert not pre_norm
|
||||
use_fc_norm = global_pool == 'avg' if fc_norm is None else fc_norm
|
||||
norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
|
||||
act_layer = act_layer or nn.GELU
|
||||
|
||||
self.num_classes = num_classes
|
||||
self.global_pool = global_pool
|
||||
self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
|
||||
self.num_prefix_tokens = 1 if class_token else 0
|
||||
self.no_embed_class = no_embed_class
|
||||
|
||||
patch_embed_extra_kwargs = ({'fused_bias_fc': fused_bias_fc} if embed_layer is PatchEmbed
|
||||
else {})
|
||||
self.patch_embed = embed_layer(
|
||||
img_size=img_size,
|
||||
patch_size=patch_size,
|
||||
in_chans=in_chans,
|
||||
embed_dim=embed_dim,
|
||||
bias=not pre_norm, # disable bias if pre-norm is used (e.g. CLIP)
|
||||
**patch_embed_extra_kwargs
|
||||
)
|
||||
num_patches = self.patch_embed.num_patches
|
||||
|
||||
self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) if class_token else None
|
||||
embed_len = num_patches if no_embed_class else num_patches + self.num_prefix_tokens
|
||||
self.pos_embed = nn.Parameter(torch.randn(1, embed_len, embed_dim) * .02)
|
||||
|
||||
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
|
||||
|
||||
# We change the order of dropout, residual and layer norm:
|
||||
# Instead of LN -> Attn / MLP -> Dropout -> Add, we do:
|
||||
# Dropout -> Add -> LN -> Attn / MLP, returning both the residual branch (output of Add) and
|
||||
# the main branch (output of MLP). The model definition is unchanged, but the mapping of the
|
||||
# nn.Dropout probabilities are changed.
|
||||
# This is for performance reason: we can fuse dropout + add + layer_norm.
|
||||
self.blocks = nn.ModuleList([create_block(
|
||||
embed_dim, num_heads, mlp_ratio, qkv_bias, drop_rate, attn_drop_rate,
|
||||
drop_path1=dpr[i-1] if i > 0 else 0., drop_path2=dpr[i],
|
||||
norm_layer=norm_layer, act_layer=act_layer, use_flash_attn=use_flash_attn,
|
||||
fused_bias_fc=fused_bias_fc, fused_mlp=fused_mlp,
|
||||
fused_dropout_add_ln=fused_dropout_add_ln, layer_idx=i, n_layer=depth,
|
||||
last_layer_subset=(global_pool == 'token')
|
||||
) for i in range(depth)])
|
||||
|
||||
self.dropout = nn.Dropout(p=drop_rate)
|
||||
self.drop_path = StochasticDepth(p=dpr[-1], mode='row')
|
||||
self.norm = norm_layer(embed_dim)
|
||||
|
||||
self.fused_dropout_add_ln = fused_dropout_add_ln
|
||||
if self.fused_dropout_add_ln and dropout_add_layer_norm is None:
|
||||
raise ImportError('dropout_add_layer_norm is not installed')
|
||||
|
||||
# Classifier Head
|
||||
self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
|
||||
|
||||
self.init_weights(weight_init)
|
||||
|
||||
def init_weights(self, mode=''):
|
||||
assert mode == ''
|
||||
trunc_normal_(self.pos_embed, std=.02)
|
||||
if self.cls_token is not None:
|
||||
nn.init.normal_(self.cls_token, std=1e-6)
|
||||
named_apply(init_weights_vit_timm, self)
|
||||
|
||||
def _init_weights(self, m):
|
||||
# this fn left here for compat with downstream users
|
||||
init_weights_vit_timm(m)
|
||||
|
||||
@torch.jit.ignore
|
||||
def no_weight_decay(self):
|
||||
return {'pos_embed', 'cls_token'}
|
||||
|
||||
def _pos_embed(self, x):
|
||||
if self.no_embed_class:
|
||||
# deit-3, updated JAX (big vision)
|
||||
# position embedding does not overlap with class token, add then concat
|
||||
x = x + self.pos_embed
|
||||
if self.cls_token is not None:
|
||||
x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
|
||||
else:
|
||||
# original timm, JAX, and deit vit impl
|
||||
# pos_embed has entry for class token, concat then add
|
||||
if self.cls_token is not None:
|
||||
x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
|
||||
x = x + self.pos_embed
|
||||
return x
|
||||
|
||||
def forward_features(self, x, all_tokens=True):
|
||||
"""
|
||||
If all_tokens==False and self.global_pool == 'token', we only return the features for the
|
||||
cls token.
|
||||
"""
|
||||
x = self.patch_embed(x)
|
||||
hidden_states = self._pos_embed(x)
|
||||
residual = None
|
||||
if self.global_pool != 'token' or all_tokens:
|
||||
# if True:
|
||||
for block in self.blocks:
|
||||
hidden_states, residual = block(hidden_states, residual)
|
||||
else:
|
||||
for block in self.blocks[:-1]:
|
||||
hidden_states, residual = block(hidden_states, residual)
|
||||
# For the last layer, we only want the 1st token of the output. So we do cross-attention
|
||||
# where the query is the 1st token and the key/value is the whole sequence.
|
||||
hidden_states, residual = self.blocks[-1](hidden_states, residual,
|
||||
mixer_subset=slice(0, 1))
|
||||
if not self.fused_dropout_add_ln:
|
||||
residual = self.drop_path(self.dropout(hidden_states)) + residual
|
||||
hidden_states = self.norm(residual.to(dtype=self.norm.weight.dtype))
|
||||
else:
|
||||
if self.drop_path.p == 0 or not self.training:
|
||||
rowscale = None
|
||||
else:
|
||||
rowscale = self.drop_path(torch.ones(
|
||||
hidden_states.shape[:-1], device=hidden_states.device,
|
||||
dtype=hidden_states.dtype)
|
||||
)
|
||||
# Set prenorm=False here since we don't need to the residual
|
||||
hidden_states = dropout_add_layer_norm(
|
||||
hidden_states, residual, self.norm.weight, self.norm.bias,
|
||||
self.dropout.p if self.training else 0.0, self.norm.eps, rowscale=rowscale,
|
||||
prenorm=False, residual_in_fp32=True
|
||||
)
|
||||
return hidden_states
|
||||
|
||||
def forward_head(self, x, pre_logits: bool = False):
|
||||
if self.global_pool:
|
||||
x = x[:, self.num_prefix_tokens:].mean(dim=1) if self.global_pool == 'avg' else x[:, 0]
|
||||
return x if pre_logits else self.head(x)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.forward_features(x, all_tokens=False)
|
||||
x = self.forward_head(x)
|
||||
return x
|
||||
|
||||
def load_state_dict(self, state_dict, strict=True):
|
||||
patch_embed_weight = state_dict['patch_embed.proj.weight']
|
||||
if patch_embed_weight.dim() == 4:
|
||||
# convert from Conv2d to Linear
|
||||
state_dict['patch_embed.proj.weight'] = rearrange(patch_embed_weight,
|
||||
'o c h w -> o (c h w)')
|
||||
def key_mapping_attn(key):
|
||||
key = re.sub(r'^blocks.(\d+).attn.qkv.', r'blocks.\1.mixer.Wqkv.', key)
|
||||
key = re.sub(r'^blocks.(\d+).attn.proj.', r'blocks.\1.mixer.out_proj.', key)
|
||||
return key
|
||||
state_dict = OrderedDict((key_mapping_attn(k), v) for k, v in state_dict.items())
|
||||
n_layer = len(self.blocks)
|
||||
# Convert from Wqkv to Wq and Wkv for cross attention (last layer)
|
||||
if (self.blocks[-1].mixer.cross_attn
|
||||
and f'blocks.{n_layer - 1}.mixer.Wqkv.weight' in state_dict):
|
||||
Wqkv = state_dict.pop(f'blocks.{n_layer - 1}.mixer.Wqkv.weight')
|
||||
bqkv = state_dict.pop(f'blocks.{n_layer - 1}.mixer.Wqkv.bias')
|
||||
state_dict[f'blocks.{n_layer - 1}.mixer.Wq.weight'] = Wqkv[:self.embed_dim]
|
||||
state_dict[f'blocks.{n_layer - 1}.mixer.Wkv.weight'] = Wqkv[self.embed_dim:]
|
||||
state_dict[f'blocks.{n_layer - 1}.mixer.Wq.bias'] = bqkv[:self.embed_dim]
|
||||
state_dict[f'blocks.{n_layer - 1}.mixer.Wkv.bias'] = bqkv[self.embed_dim:]
|
||||
return super().load_state_dict(state_dict, strict=strict)
|
||||
|
||||
|
||||
def init_weights_vit_timm(module: nn.Module, name: str = ''):
|
||||
""" ViT weight initialization, original timm impl (for reproducibility) """
|
||||
if isinstance(module, nn.Linear):
|
||||
trunc_normal_(module.weight, std=.02)
|
||||
if module.bias is not None:
|
||||
nn.init.zeros_(module.bias)
|
||||
elif hasattr(module, 'init_weights'):
|
||||
module.init_weights()
|
||||
|
||||
|
||||
def vit_base_patch16_224(pretrained=False, **kwargs):
|
||||
""" ViT-Base (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929).
|
||||
ImageNet-1k weights fine-tuned from in21k @ 224x224, source https://github.com/google-research/vision_transformer.
|
||||
"""
|
||||
assert not pretrained
|
||||
model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs)
|
||||
model = VisionTransformer(**model_kwargs)
|
||||
return model
|
||||
Reference in New Issue
Block a user