# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

import logging
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Union

import torch

from xformers._deprecation_warning import deprecated_function
from xformers.components import reversible as rv
from xformers.components.residual import ResidualNormStyle, get_deepnorm_coefficients
from xformers.factory.block_configs import (
    xFormerBlockConfig,
    xFormerDecoderConfig,
    xFormerEncoderConfig,
)
from xformers.factory.block_factory import xFormerDecoderBlock, xFormerEncoderBlock
from xformers.factory.weight_init import get_weight_init_fn, xFormerWeightInit

logger = logging.getLogger("xformers")

@dataclass(init=False)
class xFormerConfig:
    """
    The configuration structure to define a full Transformer.
    This can include a stack of encoder layers, and a stack of decoder layers.

    It is optionally possible to share the embedding weights in between
    the encoder and decoder positional encoding, as proposed for instance by
    `Using the Output Embedding to Improve Language Models`, Press et al.

    A full config example is for instance as follows:

    ::

        xformer_config = [
            {
                "reversible": False,  # Turn on to test the effect of using reversible layers
                "block_type": "encoder",
                "num_layers": LAYERS,
                "dim_model": EMB,
                "residual_norm_style": "pre",
                "position_encoding_config": {
                    "name": "vocab",
                    "seq_len": CONTEXT,
                    "vocab_size": VOCAB_SIZE,
                },
                "multi_head_config": {
                    "num_heads": NUM_HEADS,
                    "residual_dropout": RES_DROP,
                    "use_rotary_embeddings": True,
                    "attention": {
                        "name": ATTENTION_MECHANISM_STR,
                        "dropout": ATTN_DROP,
                        "causal": True,
                        "seq_len": CONTEXT,
                    },
                },
                "feedforward_config": {
                    "name": "FusedMLP",  # Use MLP if Triton is not available
                    "dropout": MLP_DROP,
                    "activation": "gelu",
                    "hidden_layer_multiplier": MLP_MULTIPLIER,
                },
            }
        ]
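
    Such a config can then be turned into a model through the helpers defined below,
    for instance (a minimal sketch; the upper-case names above are placeholders to be
    filled in by the user):

    ::

        config = xFormerConfig(xformer_config)
        model = xFormer.from_config(config)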

    .. _`Using the Output Embedding to Improve Language Models`: https://arxiv.org/pdf/1608.05859.pdf
    """

    stack_configs: Union[List[xFormerBlockConfig], Dict[str, xFormerBlockConfig]]
    tie_embedding_weights: bool = False
    weight_init: xFormerWeightInit = xFormerWeightInit.ViT

    def __init__(
        self,
        stack_configs: Union[List[Dict[str, Any]], Dict[str, Dict[str, Any]]],
        tie_embedding_weights: bool = False,
        weight_init: xFormerWeightInit = xFormerWeightInit.ViT,
    ):
        # Type all the configurations. Possible typos are caught here
        if isinstance(stack_configs, dict):
            self.stack_configs = {}
            for k, config in stack_configs.items():
                if config["block_type"] == "encoder":
                    self.stack_configs[k] = xFormerEncoderConfig(**config)
                else:
                    self.stack_configs[k] = xFormerDecoderConfig(**config)
        else:
            self.stack_configs = []
            for config in stack_configs:
                if config["block_type"] == "encoder":
                    self.stack_configs.append(xFormerEncoderConfig(**config))
                else:
                    self.stack_configs.append(xFormerDecoderConfig(**config))

        self.tie_embedding_weights = tie_embedding_weights
        self.weight_init = weight_init
        deprecated_function(self)


class xFormer(torch.nn.Module):
    def __init__(
        self,
        stack_configs: Union[
            xFormerBlockConfig, List[xFormerBlockConfig], Dict[str, xFormerBlockConfig]
        ],
        tie_embedding_weights: bool = False,
        weight_init: xFormerWeightInit = xFormerWeightInit.ViT,
    ):
        """
        Given a serialized configuration, generate the corresponding model.
        This is only a helper and can easily be bypassed.
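
        For instance, once built (a minimal sketch; ``model``, ``src`` and ``target``
        are placeholders, not symbols defined in this module):

        ::

            memory = model(src)           # encoder-only stack: returns the latent
            out = model(src, tgt=target)  # encoder + decoder stacks: returns the decoded target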
"""
|
||
|
|
super().__init__()
|
||
|
|
deprecated_function(self)
|
||
|
|
|
||
|
|
if isinstance(stack_configs, Dict):
|
||
|
|
stack_configs = list(stack_configs.values())
|
||
|
|
|
||
|
|
# Convenience, users can pass either a list of configs or a single one
|
||
|
|
if not isinstance(stack_configs, List):
|
||
|
|
stack_configs = [stack_configs]
|
||
|
|
|
||
|
|
# Sanity checks, some config combinations do not make sense
|
||
|
|
self._verify_reversible(stack_configs)
|
||
|
|
self._verify_deepnorm(stack_configs)
|
||
|
|
|
||
|
|
encoders: List[torch.nn.Module] = []
|
||
|
|
decoders: List[torch.nn.Module] = []
|
||
|
|
|
||
|
|
self.reversible_encoder = False
|
||
|
|
self.rev_enc_pose_encoding = None
|
||
|
|
|
||
|
|
# Unroll the configs and build the model
|
||
|
|
for config in stack_configs:
|
||
|
|
# Handle either Encoder or Decoder stacks
|
||
|
|
builder = (
|
||
|
|
xFormerEncoderBlock.from_config
|
||
|
|
if isinstance(config, xFormerEncoderConfig)
|
||
|
|
else xFormerDecoderBlock.from_config
|
||
|
|
)
|
||
|
|
recipient = (
|
||
|
|
encoders if isinstance(config, xFormerEncoderConfig) else decoders
|
||
|
|
)
|
||
|
|
|
||
|
|
# Build up the stack
|
||
|
|
for i in range(config.num_layers):
|
||
|
|
# Label where this layer is in the stack
|
||
|
|
# (for instance useful for the positional encoding, or late layer norm)
|
||
|
|
if len(recipient) > 0:
|
||
|
|
config.layer_position.mark_not_first()
|
||
|
|
|
||
|
|
if config != stack_configs[-1] or i < config.num_layers - 1:
|
||
|
|
config.layer_position.mark_not_last()
|
||
|
|
|
||
|
|
block = builder(config) # type: ignore
|
||
|
|
|
||
|
|
# If reversible: extract the reversible sub-parts, else append the block as-is
|
||
|
|
if config.reversible:
|
||
|
|
# WARNING: only one pose encoding is saved here (not Focal Transformer compatible for instance)
|
||
|
|
assert isinstance(config, xFormerEncoderConfig)
|
||
|
|
if block.pose_encoding is not None:
|
||
|
|
self.rev_enc_pose_encoding = block.pose_encoding
|
||
|
|
self.reversible_encoder = True
|
||
|
|
|
||
|
|
f, g = xFormerEncoderBlock.get_reversible_layer(config)
|
||
|
|
recipient.append(torch.nn.ModuleList([f, g]))
|
||
|
|
else:
|
||
|
|
recipient.append(block) # type: ignore
|
||
|
|
|
||
|
|
# Tie embedding weights, if requested and possible
|
||
|
|
assert (
|
||
|
|
not tie_embedding_weights or not self.reversible_encoder
|
||
|
|
), "Reversible layers and tied embeddings is not supported for now"
|
||
|
|
|
||
|
|
if (
|
||
|
|
tie_embedding_weights
|
||
|
|
and encoders
|
||
|
|
and encoders[0].pose_encoding
|
||
|
|
and decoders
|
||
|
|
and decoders[0].pose_encoding
|
||
|
|
and not config.reversible
|
||
|
|
):
|
||
|
|
logger.info("Tying encoder and decoder embeddings, as requested")
|
||
|
|
encoders[0].pose_encoding = decoders[0].pose_encoding
|
||
|
|
|
||
|
|
self.encoders: torch.nn.Module = (
|
||
|
|
rv.ReversibleSequence(torch.nn.ModuleList(encoders))
|
||
|
|
if self.reversible_encoder
|
||
|
|
else torch.nn.ModuleList(encoders)
|
||
|
|
)
|
||
|
|
self.decoders = torch.nn.ModuleList(decoders)
|
||
|
|
|
||
|
|
use_deepnorm = (
|
||
|
|
stack_configs[0].residual_norm_style == ResidualNormStyle.DeepNorm
|
||
|
|
)
|
||
|
|
|
||
|
|
assert (
|
||
|
|
not use_deepnorm or not self.reversible_encoder
|
||
|
|
), "Reversible layers and deepnorm is not supported for now"
|
||
|
|
|
||
|
|
self.init_weights(weight_init=weight_init, use_deep_norm=use_deepnorm)
|
||
|
|
|
||
|
|
    @classmethod
    def from_config(cls, config: xFormerConfig):
        return cls(
            config.stack_configs, config.tie_embedding_weights, config.weight_init
        )

    def _verify_reversible(self, stack_configs: List[xFormerBlockConfig]):
        reversible = [
            c.reversible
            for c in filter(lambda x: x.block_type == "encoder", stack_configs)
        ]

        assert all(reversible) or not any(reversible), (
            "All layers need to have the same reversibility setting. "
            + f"Currently {reversible}"
        )

    def _verify_deepnorm(self, stack_configs: List[xFormerBlockConfig]):
        deepnorm = [
            c.residual_norm_style == ResidualNormStyle.DeepNorm for c in stack_configs
        ]

        assert all(deepnorm) or not any(deepnorm), (
            "All layers need to have the same deepnorm setting. "
            + f"Currently {deepnorm}"
        )

    def init_weights(self, weight_init: xFormerWeightInit, use_deep_norm: bool):
        # The deepnorm weight initialization method requires different gain factors for the encoder
        # and decoder, depending on the general model structure (number of respective layers)
        if use_deep_norm:
            encoder_coefficients, decoder_coefficients = get_deepnorm_coefficients(
                encoder_layers=len(self.encoders), decoder_layers=len(self.decoders)  # type: ignore
            )
        else:
            encoder_coefficients, decoder_coefficients = None, None

        encoder_gain = (
            encoder_coefficients.beta if encoder_coefficients is not None else 1.0
        )
        decoder_gain = (
            decoder_coefficients.beta if decoder_coefficients is not None else 1.0
        )

        # Pick the desired init function
        init_fn = get_weight_init_fn(weight_init)

        # Initialize all the encoder weights
        for name, module in self.encoders.named_children():
            init_fn(module=module, name=name, gain=encoder_gain)

        for name, module in self.decoders.named_children():
            init_fn(module=module, name=name, gain=decoder_gain)

    def forward(
        self,
        src: torch.Tensor,
        tgt: Optional[torch.Tensor] = None,
        encoder_input_mask: Optional[torch.Tensor] = None,
        decoder_input_mask: Optional[torch.Tensor] = None,
    ) -> Optional[torch.Tensor]:

        # Encode to latent space if encoder is present
        if len(list(self.encoders.parameters())) > 0:
            encoders = self.encoders
            memory = src.clone()
            if isinstance(encoders, torch.nn.ModuleList):
                for encoder in encoders:
                    memory = encoder(memory, input_mask=encoder_input_mask)
            else:
                if self.rev_enc_pose_encoding:
                    memory = self.rev_enc_pose_encoding(src)

                # Reversible Encoder
                x = torch.cat([memory, memory], dim=-1)

                # Apply the optional input masking
                if encoder_input_mask is not None:
                    if x.dim() - encoder_input_mask.dim() > 1:
                        encoder_input_mask = encoder_input_mask.unsqueeze(0)
                    x += encoder_input_mask.unsqueeze(-1)

                x = encoders(x)
                memory = torch.stack(x.chunk(2, dim=-1)).mean(dim=0)

            if not self.decoders:
                return memory

        # If decoder: either use the encoder output, or just decode, both options are possible
        if len(self.decoders) > 0:
            tgt = src.clone() if tgt is None else tgt

            for decoder in self.decoders:
                tgt = decoder(
                    target=tgt,
                    # pyre-fixme[61]: `memory` is not always initialized here.
                    memory=memory,
                    input_mask=decoder_input_mask,
                )

            return tgt

        return None