First commit
This commit is contained in:
313
pkgs/xformers/factory/model_factory.py
Normal file
313
pkgs/xformers/factory/model_factory.py
Normal file
@@ -0,0 +1,313 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the BSD license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
import torch
|
||||
|
||||
from xformers._deprecation_warning import deprecated_function
|
||||
from xformers.components import reversible as rv
|
||||
from xformers.components.residual import ResidualNormStyle, get_deepnorm_coefficients
|
||||
from xformers.factory.block_configs import (
|
||||
xFormerBlockConfig,
|
||||
xFormerDecoderConfig,
|
||||
xFormerEncoderConfig,
|
||||
)
|
||||
from xformers.factory.block_factory import xFormerDecoderBlock, xFormerEncoderBlock
|
||||
from xformers.factory.weight_init import get_weight_init_fn, xFormerWeightInit
|
||||
|
||||
logger = logging.getLogger("xformers")
|
||||
|
||||
|
||||
@dataclass(init=False)
|
||||
class xFormerConfig:
|
||||
"""
|
||||
The configuration structure to define a full Transformer.
|
||||
This can include a stack of encoder layers, and a stack of decoder layers.
|
||||
|
||||
It is optionally possible to share the embedding weights in between
|
||||
the encoder and decoder positional encoding, as proposed for instance by
|
||||
`Using the Output Embedding to Improve Language Models`, Press et al.
|
||||
|
||||
A full config example is for instance as follows:
|
||||
|
||||
::
|
||||
|
||||
xformer_config = [
|
||||
{
|
||||
"reversible": False, # Turn on to test the effect of using reversible layers
|
||||
"block_type": "encoder",
|
||||
"num_layers": LAYERS,
|
||||
"dim_model": EMB,
|
||||
"residual_norm_style": "pre",
|
||||
"position_encoding_config": {
|
||||
"name": "vocab",
|
||||
"seq_len": CONTEXT,
|
||||
"vocab_size": VOCAB_SIZE,
|
||||
},
|
||||
"multi_head_config": {
|
||||
"num_heads": NUM_HEADS,
|
||||
"residual_dropout": RES_DROP,
|
||||
"use_rotary_embeddings": True,
|
||||
"attention": {
|
||||
"name": ATTENTION_MECHANISM_STR,
|
||||
"dropout": ATTN_DROP,
|
||||
"causal": True,
|
||||
"seq_len": CONTEXT,
|
||||
},
|
||||
},
|
||||
"feedforward_config": {
|
||||
"name": "FusedMLP", # Use MLP if Triton is not available
|
||||
"dropout": MLP_DROP,
|
||||
"activation": "gelu",
|
||||
"hidden_layer_multiplier": MLP_MULTIPLIER,
|
||||
},
|
||||
}
|
||||
]
|
||||
|
||||
|
||||
.. _`Using the Output Embedding to Improve Language Models`: https://arxiv.org/pdf/1608.05859.pdf
|
||||
"""
|
||||
|
||||
stack_configs: Union[List[xFormerBlockConfig], Dict[str, xFormerBlockConfig]]
|
||||
tie_embedding_weights: bool = False
|
||||
weight_init: xFormerWeightInit = xFormerWeightInit.ViT
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
stack_configs: Union[List[Dict[str, Any]], Dict[str, Dict[str, Any]]],
|
||||
tie_embedding_weights: bool = False,
|
||||
weight_init: xFormerWeightInit = xFormerWeightInit.ViT,
|
||||
):
|
||||
# Type all the configurations. Possible typos are caught here
|
||||
if isinstance(stack_configs, dict):
|
||||
self.stack_configs = {}
|
||||
for k, config in stack_configs.items():
|
||||
if config["block_type"] == "encoder":
|
||||
self.stack_configs[k] = xFormerEncoderConfig(**config)
|
||||
else:
|
||||
self.stack_configs[k] = xFormerDecoderConfig(**config)
|
||||
else:
|
||||
self.stack_configs = []
|
||||
for config in stack_configs:
|
||||
if config["block_type"] == "encoder":
|
||||
self.stack_configs.append(xFormerEncoderConfig(**config))
|
||||
else:
|
||||
self.stack_configs.append(xFormerDecoderConfig(**config))
|
||||
|
||||
self.tie_embedding_weights = tie_embedding_weights
|
||||
self.weight_init = weight_init
|
||||
deprecated_function(self)
|
||||
|
||||
|
||||
class xFormer(torch.nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
stack_configs: Union[
|
||||
xFormerBlockConfig, List[xFormerBlockConfig], Dict[str, xFormerBlockConfig]
|
||||
],
|
||||
tie_embedding_weights: bool = False,
|
||||
weight_init: xFormerWeightInit = xFormerWeightInit.ViT,
|
||||
):
|
||||
"""
|
||||
Given a serialized configuration, generate the corresponding model.
|
||||
This is only a helper and can easily be bypassed
|
||||
"""
|
||||
super().__init__()
|
||||
deprecated_function(self)
|
||||
|
||||
if isinstance(stack_configs, Dict):
|
||||
stack_configs = list(stack_configs.values())
|
||||
|
||||
# Convenience, users can pass either a list of configs or a single one
|
||||
if not isinstance(stack_configs, List):
|
||||
stack_configs = [stack_configs]
|
||||
|
||||
# Sanity checks, some config combinations do not make sense
|
||||
self._verify_reversible(stack_configs)
|
||||
self._verify_deepnorm(stack_configs)
|
||||
|
||||
encoders: List[torch.nn.Module] = []
|
||||
decoders: List[torch.nn.Module] = []
|
||||
|
||||
self.reversible_encoder = False
|
||||
self.rev_enc_pose_encoding = None
|
||||
|
||||
# Unroll the configs and build the model
|
||||
for config in stack_configs:
|
||||
# Handle either Encoder or Decoder stacks
|
||||
builder = (
|
||||
xFormerEncoderBlock.from_config
|
||||
if isinstance(config, xFormerEncoderConfig)
|
||||
else xFormerDecoderBlock.from_config
|
||||
)
|
||||
recipient = (
|
||||
encoders if isinstance(config, xFormerEncoderConfig) else decoders
|
||||
)
|
||||
|
||||
# Build up the stack
|
||||
for i in range(config.num_layers):
|
||||
# Label where this layer is in the stack
|
||||
# (for instance useful for the positional encoding, or late layer norm)
|
||||
if len(recipient) > 0:
|
||||
config.layer_position.mark_not_first()
|
||||
|
||||
if config != stack_configs[-1] or i < config.num_layers - 1:
|
||||
config.layer_position.mark_not_last()
|
||||
|
||||
block = builder(config) # type: ignore
|
||||
|
||||
# If reversible: extract the reversible sub-parts, else append the block as-is
|
||||
if config.reversible:
|
||||
# WARNING: only one pose encoding is saved here (not Focal Transformer compatible for instance)
|
||||
assert isinstance(config, xFormerEncoderConfig)
|
||||
if block.pose_encoding is not None:
|
||||
self.rev_enc_pose_encoding = block.pose_encoding
|
||||
self.reversible_encoder = True
|
||||
|
||||
f, g = xFormerEncoderBlock.get_reversible_layer(config)
|
||||
recipient.append(torch.nn.ModuleList([f, g]))
|
||||
else:
|
||||
recipient.append(block) # type: ignore
|
||||
|
||||
# Tie embedding weights, if requested and possible
|
||||
assert (
|
||||
not tie_embedding_weights or not self.reversible_encoder
|
||||
), "Reversible layers and tied embeddings is not supported for now"
|
||||
|
||||
if (
|
||||
tie_embedding_weights
|
||||
and encoders
|
||||
and encoders[0].pose_encoding
|
||||
and decoders
|
||||
and decoders[0].pose_encoding
|
||||
and not config.reversible
|
||||
):
|
||||
logger.info("Tying encoder and decoder embeddings, as requested")
|
||||
encoders[0].pose_encoding = decoders[0].pose_encoding
|
||||
|
||||
self.encoders: torch.nn.Module = (
|
||||
rv.ReversibleSequence(torch.nn.ModuleList(encoders))
|
||||
if self.reversible_encoder
|
||||
else torch.nn.ModuleList(encoders)
|
||||
)
|
||||
self.decoders = torch.nn.ModuleList(decoders)
|
||||
|
||||
use_deepnorm = (
|
||||
stack_configs[0].residual_norm_style == ResidualNormStyle.DeepNorm
|
||||
)
|
||||
|
||||
assert (
|
||||
not use_deepnorm or not self.reversible_encoder
|
||||
), "Reversible layers and deepnorm is not supported for now"
|
||||
|
||||
self.init_weights(weight_init=weight_init, use_deep_norm=use_deepnorm)
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: xFormerConfig):
|
||||
return cls(
|
||||
config.stack_configs, config.tie_embedding_weights, config.weight_init
|
||||
)
|
||||
|
||||
def _verify_reversible(self, stack_configs: List[xFormerBlockConfig]):
|
||||
reversible = [
|
||||
c.reversible
|
||||
for c in filter(lambda x: x.block_type == "encoder", stack_configs)
|
||||
]
|
||||
|
||||
assert all(reversible) or not any(reversible), (
|
||||
"All layers need to have the same reversibility setting. "
|
||||
+ f"Currently {reversible}"
|
||||
)
|
||||
|
||||
def _verify_deepnorm(self, stack_configs: List[xFormerBlockConfig]):
|
||||
deepnorm = [
|
||||
c.residual_norm_style == ResidualNormStyle.DeepNorm for c in stack_configs
|
||||
]
|
||||
|
||||
assert all(deepnorm) or not any(deepnorm), (
|
||||
"All layers need to have the same deepnorm setting. "
|
||||
+ f"Currently {deepnorm}"
|
||||
)
|
||||
|
||||
def init_weights(self, weight_init: xFormerWeightInit, use_deep_norm: bool):
|
||||
# The deepnorm weight initialization method requires different gain factors for the encoder
|
||||
# and decoder, depending on the general model structure (number of respective layers)
|
||||
if use_deep_norm:
|
||||
encoder_coefficients, decoder_coefficients = get_deepnorm_coefficients(
|
||||
encoder_layers=len(self.encoders), decoder_layers=len(self.decoders) # type: ignore
|
||||
)
|
||||
else:
|
||||
encoder_coefficients, decoder_coefficients = None, None
|
||||
|
||||
encoder_gain = (
|
||||
encoder_coefficients.beta if encoder_coefficients is not None else 1.0
|
||||
)
|
||||
decoder_gain = (
|
||||
decoder_coefficients.beta if decoder_coefficients is not None else 1.0
|
||||
)
|
||||
|
||||
# Pick the desired init function
|
||||
init_fn = get_weight_init_fn(weight_init)
|
||||
|
||||
# Initialize all the encoder weights
|
||||
for name, module in self.encoders.named_children():
|
||||
init_fn(module=module, name=name, gain=encoder_gain)
|
||||
|
||||
for name, module in self.decoders.named_children():
|
||||
init_fn(module=module, name=name, gain=decoder_gain)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
src: torch.Tensor,
|
||||
tgt: Optional[torch.Tensor] = None,
|
||||
encoder_input_mask: Optional[torch.Tensor] = None,
|
||||
decoder_input_mask: Optional[torch.Tensor] = None,
|
||||
) -> Optional[torch.Tensor]:
|
||||
|
||||
# Encode to latent space if encoder is present
|
||||
if len(list(self.encoders.parameters())) > 0:
|
||||
encoders = self.encoders
|
||||
memory = src.clone()
|
||||
if isinstance(encoders, torch.nn.ModuleList):
|
||||
for encoder in encoders:
|
||||
memory = encoder(memory, input_mask=encoder_input_mask)
|
||||
else:
|
||||
if self.rev_enc_pose_encoding:
|
||||
memory = self.rev_enc_pose_encoding(src)
|
||||
|
||||
# Reversible Encoder
|
||||
x = torch.cat([memory, memory], dim=-1)
|
||||
|
||||
# Apply the optional input masking
|
||||
if encoder_input_mask is not None:
|
||||
if x.dim() - encoder_input_mask.dim() > 1:
|
||||
encoder_input_mask.unsqueeze(0)
|
||||
x += encoder_input_mask.unsqueeze(-1)
|
||||
|
||||
x = encoders(x)
|
||||
memory = torch.stack(x.chunk(2, dim=-1)).mean(dim=0)
|
||||
|
||||
if not self.decoders:
|
||||
return memory
|
||||
|
||||
# If decoder: either use the encoder ouput, or just decode, both options are possible
|
||||
if len(self.decoders) > 0:
|
||||
tgt = src.clone() if tgt is None else tgt
|
||||
|
||||
for decoder in self.decoders:
|
||||
tgt = decoder(
|
||||
target=tgt,
|
||||
# pyre-fixme[61]: `memory` is not always initialized here.
|
||||
memory=memory,
|
||||
input_mask=decoder_input_mask,
|
||||
)
|
||||
|
||||
return tgt
|
||||
|
||||
return None
|
||||
Reference in New Issue
Block a user