First commit

2025-08-05 19:02:46 +08:00
parent 9efe891f99
commit 99fb9f5cb0
1412 changed files with 203615 additions and 0 deletions


@@ -0,0 +1,11 @@
from xformers.components import MultiHeadDispatchConfig # noqa
from xformers.components.attention import AttentionConfig # noqa
from xformers.components.feedforward import FeedforwardConfig # noqa
from xformers.components.positional_embedding import PositionEmbeddingConfig # noqa
from .block_factory import xFormerDecoderBlock # noqa
from .block_factory import xFormerDecoderConfig # noqa
from .block_factory import xFormerEncoderBlock # noqa
from .block_factory import xFormerEncoderConfig # noqa
from .model_factory import xFormer, xFormerConfig # noqa
from .weight_init import xFormerWeightInit # noqa
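
Taken together, these re-exports define the public surface of the factory. A typical import, consistent with the modules above, looks like::

    from xformers.factory import (
        xFormer,
        xFormerConfig,
        xFormerEncoderBlock,
        xFormerEncoderConfig,
    )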


@@ -0,0 +1,237 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
from dataclasses import dataclass
from enum import Enum
from typing import Any, Dict, Optional
from xformers.components import NormalizationType, ResidualNormStyle
from xformers.components.feedforward import FEEDFORWARD_REGISTRY, FeedforwardConfig
from xformers.components.positional_embedding import (
POSITION_EMBEDDING_REGISTRY,
PositionEmbeddingConfig,
)
from xformers.utils import generate_matching_config
class LayerPositionBitmask(int, Enum):
First = 0b01
Last = 0b10
Default = 0b11
class LayerPosition:
"""Bitmask to mark this layer as first, last, nothing or both"""
def __init__(self):
self.bitmask = LayerPositionBitmask.Default
def is_first(self):
return bool(self.bitmask & LayerPositionBitmask.First)
def is_last(self):
return bool(self.bitmask & LayerPositionBitmask.Last)
def mark_not_first(self):
self.bitmask &= ~LayerPositionBitmask.First
def mark_not_last(self):
self.bitmask &= ~LayerPositionBitmask.Last
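
A minimal sketch of the bitmask semantics, assuming the two classes above are in scope::

    position = LayerPosition()   # Default = 0b11: both first and last
    assert position.is_first() and position.is_last()

    position.mark_not_first()    # clear the First bit, leaving 0b10
    assert not position.is_first() and position.is_last()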
class BlockType(str, Enum):
Encoder = "encoder"
Decoder = "decoder"
@dataclass(init=False) # handle constructors explicitly to force type changes
class xFormerBlockConfig:
"""
The configuration structure to define a Transformer block.
This base class is applicable to both encoder and decoder definitions.
This completely defines each of the blocks, for instance in terms of dimensions,
position encoding, pre or post layer norms or reversibility.
"""
dim_model: int
feedforward_config: FeedforwardConfig
position_encoding_config: Optional[PositionEmbeddingConfig]
block_type: BlockType
residual_norm_style: ResidualNormStyle
normalization: NormalizationType
layer_position: LayerPosition
use_triton: bool
reversible: bool
num_layers: int
def __init__(
self,
dim_model: int,
feedforward_config: Dict[str, Any],
position_encoding_config: Optional[Dict[str, Any]],
block_type: BlockType,
residual_norm_style: ResidualNormStyle = ResidualNormStyle("post"),
normalization: NormalizationType = NormalizationType.LayerNorm,
reversible: bool = False,
num_layers: int = 1,
layer_position: Optional[LayerPosition] = None,
):
self.dim_model = dim_model
self.block_type = block_type
self.residual_norm_style = residual_norm_style
self.reversible = reversible
self.num_layers = num_layers
self.normalization = normalization
# Fill in possible gaps in the config for subparts of the block
self.feedforward_config = generate_matching_config(
feedforward_config,
FEEDFORWARD_REGISTRY[feedforward_config["name"]].config,
)
self.position_encoding_config = (
generate_matching_config(
position_encoding_config,
POSITION_EMBEDDING_REGISTRY[position_encoding_config["name"]].config,
)
if position_encoding_config is not None
else None
)
# Default is that this layer is the only one, so both first and last
if layer_position:
self.layer_position = layer_position
else:
self.layer_position = LayerPosition()
@dataclass(init=False)
class xFormerEncoderConfig(xFormerBlockConfig):
"""
The configuration structure for an encoder block
"""
multi_head_config: Dict[str, Any]
use_triton: bool
simplicial_embeddings: Optional[Dict[str, Any]]
patch_embedding_config: Optional[Dict[str, Any]]
def __init__(
self,
dim_model: int,
feedforward_config: Dict[str, Any],
multi_head_config: Dict[str, Any],
position_encoding_config: Optional[Dict[str, Any]] = None,
residual_norm_style: str = "post",
normalization: NormalizationType = NormalizationType.LayerNorm,
use_triton: bool = True,
simplicial_embeddings: Optional[Dict[str, Any]] = None,
patch_embedding_config: Optional[Dict[str, Any]] = None,
**kwargs,
):
# Convenience, fill in duplicated fields
try:
if "dim_model" not in multi_head_config.keys():
multi_head_config["dim_model"] = dim_model
if "dim_model" not in feedforward_config.keys():
feedforward_config["dim_model"] = dim_model
if (
position_encoding_config is not None
and "dim_model" not in position_encoding_config.keys()
):
position_encoding_config["dim_model"] = dim_model
if (
patch_embedding_config is not None
and "out_channels" not in patch_embedding_config.keys()
):
patch_embedding_config["out_channels"] = dim_model
except AttributeError:
# A config instance was passed in, this is fine
pass
if "block_type" in kwargs:
assert kwargs["block_type"] == "encoder"
kwargs["block_type"] = BlockType("encoder")
super().__init__(
dim_model=dim_model,
feedforward_config=feedforward_config,
position_encoding_config=position_encoding_config,
residual_norm_style=ResidualNormStyle(residual_norm_style),
normalization=NormalizationType(normalization),
**kwargs,
)
self.multi_head_config = multi_head_config
self.use_triton = use_triton
self.simplicial_embeddings = simplicial_embeddings
self.patch_embedding_config = patch_embedding_config
@dataclass(init=False)
class xFormerDecoderConfig(xFormerBlockConfig):
"""
The configuration structure for a decoder block.
This specifically defines the masked and cross attention mechanisms,
on top of the settings defining all blocks.
"""
multi_head_config_masked: Dict[str, Any] # prior to encoder output
multi_head_config_cross: Dict[str, Any] # cross attention, takes encoder output
def __init__(
self,
dim_model: int,
feedforward_config: Dict[str, Any],
multi_head_config_masked: Dict[str, Any],
multi_head_config_cross: Dict[str, Any],
position_encoding_config: Optional[Dict[str, Any]] = None,
residual_norm_style: str = "post",
normalization: NormalizationType = NormalizationType.LayerNorm,
use_triton: bool = True,
**kwargs,
):
# Convenience, fill in duplicated fields
try:
if "dim_model" not in multi_head_config_masked.keys():
multi_head_config_masked["dim_model"] = dim_model
if "dim_model" not in multi_head_config_cross.keys():
multi_head_config_cross["dim_model"] = dim_model
if "dim_model" not in feedforward_config.keys():
feedforward_config["dim_model"] = dim_model
if (
position_encoding_config is not None
and "dim_model" not in position_encoding_config.keys()
):
position_encoding_config["dim_model"] = dim_model
except AttributeError:
# A config instance was passed in, this is fine
pass
if "block_type" in kwargs.keys():
assert kwargs["block_type"] == "decoder"
kwargs["block_type"] = BlockType("decoder")
super().__init__(
dim_model=dim_model,
feedforward_config=feedforward_config,
position_encoding_config=position_encoding_config,
residual_norm_style=ResidualNormStyle(residual_norm_style),
normalization=NormalizationType(normalization),
**kwargs,
)
self.multi_head_config_masked = multi_head_config_masked
self.multi_head_config_cross = multi_head_config_cross
self.use_triton = use_triton
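
As a usage sketch, the constructors accept plain dicts and propagate the shared fields through generate_matching_config. The "vocab" and "MLP" names follow the registry entries used in the xFormerConfig docstring further down; "scaled_dot_product" is assumed to be a registered attention name::

    config = xFormerEncoderConfig(
        dim_model=384,
        position_encoding_config={"name": "vocab", "seq_len": 256, "vocab_size": 32000},
        multi_head_config={
            "num_heads": 6,
            "attention": {"name": "scaled_dot_product", "dropout": 0.1, "causal": False},
        },
        feedforward_config={
            "name": "MLP",
            "dropout": 0.1,
            "activation": "gelu",
            "hidden_layer_multiplier": 4,
        },
        residual_norm_style="pre",
    )
    # dim_model was propagated; feedforward and position configs are now typed dataclasses
    assert config.feedforward_config.dim_model == 384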


@@ -0,0 +1,358 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
import logging
from dataclasses import asdict
from typing import Optional, Tuple, Union
import torch
import torch.nn as nn
from xformers._deprecation_warning import deprecated_function
from xformers.components import (
PatchEmbeddingConfig,
PostNorm,
PreNorm,
Residual,
ResidualNormStyle,
build_multi_head_attention,
build_patch_embedding,
)
from xformers.components.attention import AttentionMask
from xformers.components.feedforward import build_feedforward
from xformers.components.positional_embedding import build_positional_embedding
from xformers.components.residual import get_deepnorm_coefficients
from xformers.components.simplicial_embedding import SimplicialEmbedding
from xformers.factory.block_configs import (
NormalizationType,
xFormerDecoderConfig,
xFormerEncoderConfig,
)
logger = logging.getLogger("xformers")
def _get_ln_factory(
d_model: int,
residual_norm_style: Optional[ResidualNormStyle],
use_triton: bool,
residual: bool,
normalization: NormalizationType = NormalizationType.LayerNorm,
residual_scale: float = 1.0,
):
"""
Handle all the supported residual path configurations.
.. note:: we return the appropriate constructor, not an actual layer
"""
def get_layer_wrapper(
d_model: int,
sublayer: nn.Module,
residual_norm_style: Optional[ResidualNormStyle],
residual: bool,
residual_scale: float,
):
if residual:
if residual_norm_style == ResidualNormStyle.Pre:
return Residual(
layer=PreNorm(d_model, sublayer, normalization, use_triton),
scale=None,
)
elif residual_norm_style == ResidualNormStyle.Post:
return PostNorm(
d_model,
Residual(layer=sublayer, scale=None),
normalization,
use_triton,
)
elif residual_norm_style == ResidualNormStyle.DeepNorm:
return PostNorm(
d_model,
Residual(layer=sublayer, scale=residual_scale),
normalization,
use_triton=use_triton,
)
else:
raise ValueError(f"Unsupported residual norm style: {residual_norm_style}")
return (
PreNorm(d_model, sublayer, normalization, use_triton)
if residual_norm_style == ResidualNormStyle.Pre
else PostNorm(d_model, sublayer, normalization, use_triton)
)
def ln_factory(sublayer: nn.Module):
return get_layer_wrapper(
d_model, sublayer, residual_norm_style, residual, residual_scale
)
return ln_factory
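
To make the supported styles concrete, a sketch of what the returned factory wraps around a sublayer (semantics per the Residual / PreNorm / PostNorm components imported above)::

    import torch.nn as nn

    sublayer = nn.Linear(512, 512)  # stand-in for an attention or feedforward block

    wrap_pre = _get_ln_factory(
        512, ResidualNormStyle.Pre, use_triton=False, residual=True
    )(sublayer)
    # -> Residual(PreNorm(sublayer)):  x + sublayer(norm(x))

    wrap_post = _get_ln_factory(
        512, ResidualNormStyle.Post, use_triton=False, residual=True
    )(sublayer)
    # -> PostNorm(Residual(sublayer)): norm(x + sublayer(x))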
class xFormerEncoderBlock(torch.nn.Module):
r"""A vanilla Transformer Encoder block"""
def __init__(self, config: xFormerEncoderConfig, **kwargs):
super().__init__()
deprecated_function(self)
self.reversible_f = None
self.reversible_g = None
self.residual_norm_style = config.residual_norm_style
self.dim_model = config.dim_model
# If this layer is the first one, and a pose encoding has been requested
if (
config.position_encoding_config is not None
and config.layer_position.is_first()
):
self.pose_encoding = build_positional_embedding(
asdict(config.position_encoding_config)
)
pos_encoding_dim = config.position_encoding_config.dim_model
mha_dim = config.multi_head_config["dim_model"]
if pos_encoding_dim != mha_dim:
logger.warning(
f"The embedding dim and model dim do not match ({pos_encoding_dim} vs {mha_dim}), adding a projector layer." # noqa
)
self.embedding_projector = nn.Linear(pos_encoding_dim, mha_dim)
else:
self.pose_encoding = None
if config.residual_norm_style == ResidualNormStyle.DeepNorm:
# Just use the layer norm coefficient here,
# the init will be handled at the xformers level (knows about encoder and decoder blocks)
deep_norm_coefficients, _ = get_deepnorm_coefficients(
encoder_layers=config.num_layers, decoder_layers=0
)
assert deep_norm_coefficients is not None
residual_scale = deep_norm_coefficients.alpha
else:
residual_scale = 1.0
# mini helper, builds a normalization layer with the right Pre/Post config, residuals, and the right dimensions
ln_factory = _get_ln_factory(
config.dim_model,
config.residual_norm_style,
use_triton=config.use_triton,
residual=True,
residual_scale=residual_scale,
normalization=config.normalization,
)
mha = build_multi_head_attention(config.multi_head_config)
feedforward = build_feedforward(asdict(config.feedforward_config))
# Expose attention specific capabilities
self.supports_attention_mask = mha.attention.supports_attention_mask
self.requires_same_k_q_dimensions = mha.attention.requires_same_k_q_dimensions
self.causal = (
mha.attention.causal if hasattr(mha.attention, "causal") else False
)
# Wrappers handle the different layer norm styles (pre- and post-) and the residual path
self.wrap_att = ln_factory(mha)
self.wrap_ff: Union[Residual, PostNorm] = ln_factory(feedforward)
if (
config.residual_norm_style == ResidualNormStyle.Pre
and config.layer_position.is_last()
):
self.wrap_ff = PostNorm(
config.dim_model,
self.wrap_ff,
normalization=config.normalization,
use_triton=config.use_triton,
)
# Simplicial embeddings are only used if specified, and on the last layer
self.simplicial_embedding: Optional[SimplicialEmbedding] = None
if config.simplicial_embeddings is not None and config.layer_position.is_last():
self.simplicial_embedding = SimplicialEmbedding(
**config.simplicial_embeddings
)
# Optional patch embedding
self.patch_emb: Optional[nn.Module] = None
if config.patch_embedding_config is not None:
self.patch_emb = build_patch_embedding(
PatchEmbeddingConfig(**config.patch_embedding_config)
)
@classmethod
def from_config(cls, config: xFormerEncoderConfig):
return cls(config)
@staticmethod
def get_reversible_layer(config) -> Tuple[nn.Module, nn.Module]:
ln_factory = _get_ln_factory(
config.dim_model,
config.residual_norm_style,
residual=False,
use_triton=config.use_triton,
normalization=config.normalization,
)
mha = build_multi_head_attention(config.multi_head_config)
feedforward = build_feedforward(asdict(config.feedforward_config))
reversible_f = ln_factory(mha)
reversible_g = ln_factory(feedforward)
return reversible_f, reversible_g
def forward(
self,
x: torch.Tensor,
att_mask: Optional[Union[torch.Tensor, AttentionMask]] = None,
input_mask: Optional[torch.Tensor] = None,
):
if self.patch_emb is not None:
x = self.patch_emb(x)
if self.pose_encoding is not None:
x = self.pose_encoding(x)
if hasattr(self, "embedding_projector"):
x = self.embedding_projector(x)
# Handle the optional input masking, differs on Q, K, V
if input_mask is not None:
q = x
k = x * input_mask.unsqueeze(-1)
v = k
else:
q, k, v = x, x, x
# Pre/Post norms and residual paths are already handled
x = self.wrap_att(inputs=[q, k, v], att_mask=att_mask)
x = self.wrap_ff(inputs=[x])
# Optional simplicial embeddings
if self.simplicial_embedding is not None:
x = self.simplicial_embedding(x)
return x
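
A single encoder block can be exercised on random activations. A sketch, mirroring the config example in block_configs.py ("scaled_dot_product" and "MLP" are assumed registry names; use_triton=False keeps the norms on the plain PyTorch path)::

    config = xFormerEncoderConfig(
        dim_model=384,
        multi_head_config={
            "num_heads": 6,
            "attention": {"name": "scaled_dot_product", "dropout": 0.0, "causal": False},
        },
        feedforward_config={
            "name": "MLP",
            "dropout": 0.0,
            "activation": "gelu",
            "hidden_layer_multiplier": 4,
        },
        use_triton=False,
    )
    block = xFormerEncoderBlock.from_config(config)

    x = torch.rand(2, 256, 384)                  # (batch, seq, dim_model)
    y = block(x, input_mask=torch.ones(2, 256))  # masked positions are zeroed in K and V
    assert y.shape == x.shape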
class xFormerDecoderBlock(torch.nn.Module):
r"""A vanilla Transformer Decoder block
.. note:: this implementation is not (yet?) reversible"""
def __init__(self, config: xFormerDecoderConfig, **kwargs):
super().__init__()
deprecated_function(self)
# If this layer is the first one, and a pose encoding has been requested
if (
config.position_encoding_config is not None
and config.layer_position.is_first()
):
self.pose_encoding = build_positional_embedding(
config.position_encoding_config
)
pos_encoding_dim = config.position_encoding_config.dim_model
mha_dim = config.multi_head_config_masked["dim_model"]
if pos_encoding_dim != mha_dim:
logger.warning(
f"The embedding dim and model dim do not match ({pos_encoding_dim} vs {mha_dim}), adding a projector layer." # noqa
)
self.embedding_projector = nn.Linear(pos_encoding_dim, mha_dim)
else:
self.pose_encoding = None
if config.residual_norm_style == ResidualNormStyle.DeepNorm:
# Just use the layer norm coefficient here,
# the init will be handled at the xformers level (knows about encoder and decoder blocks)
_, deep_norm_coefficients = get_deepnorm_coefficients(
encoder_layers=0, decoder_layers=config.num_layers
)
assert deep_norm_coefficients is not None
residual_scale = deep_norm_coefficients.alpha
else:
residual_scale = 1.0
# mini helper, builds a LayerNorm with the right Pre/Post config and the right dimensions
ln_factory = _get_ln_factory(
config.dim_model,
config.residual_norm_style,
use_triton=config.use_triton,
residual=True,
residual_scale=residual_scale,
normalization=config.normalization,
)
mha = build_multi_head_attention(config.multi_head_config_masked)
cross_mha = build_multi_head_attention(config.multi_head_config_cross)
feedforward = build_feedforward(config.feedforward_config)
# Expose attention or feedforward specific capabilities
self.supports_attention_mask = mha.attention.supports_attention_mask
self.requires_same_k_q_dimensions = mha.attention.requires_same_k_q_dimensions
self.requires_squared_context_length = (
feedforward.requires_squared_context
or mha.attention.requires_squared_context
)
self.causal_attention = (
mha.attention.causal if hasattr(mha.attention, "causal") else False
)
# Wrappers handle the different layer norm styles (pre- and post-) and the residual path
self.wrap_att = ln_factory(mha)
self.wrap_cross = ln_factory(cross_mha)
self.wrap_ff: Union[Residual, PostNorm] = ln_factory(feedforward)
if (
config.residual_norm_style == ResidualNormStyle.Pre
and config.layer_position.is_last()
):
self.wrap_ff = PostNorm(
config.dim_model,
self.wrap_ff,
normalization=NormalizationType.LayerNorm,
)
@classmethod
def from_config(cls, config: xFormerDecoderConfig):
return cls(config)
def forward(
self,
target: torch.Tensor,
memory: torch.Tensor,
encoder_att_mask: Optional[Union[torch.Tensor, AttentionMask]] = None,
decoder_att_mask: Optional[Union[torch.Tensor, AttentionMask]] = None,
input_mask: Optional[torch.Tensor] = None,
):
if self.pose_encoding is not None:
target = self.pose_encoding(target)
if hasattr(self, "embedding_projector"):
target = self.embedding_projector(target)
# Handle the optional input masking, differs on Q, K, V
if input_mask is not None:
target_q = target
target_k = target * input_mask.unsqueeze(-1)
target_v = target_k
else:
target_q, target_k, target_v = target, target, target
x = self.wrap_att(
inputs=[target_q, target_k, target_v], att_mask=decoder_att_mask
)
x = self.wrap_cross(inputs=[x, memory, memory], att_mask=encoder_att_mask)
x = self.wrap_ff(inputs=[x])
return x
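
The decoder block consumes both the target stream and the encoder memory. A sketch under the same registry-name assumptions as the encoder example above::

    config = xFormerDecoderConfig(
        dim_model=384,
        multi_head_config_masked={
            "num_heads": 6,
            "attention": {"name": "scaled_dot_product", "dropout": 0.0, "causal": True},
        },
        multi_head_config_cross={
            "num_heads": 6,
            "attention": {"name": "scaled_dot_product", "dropout": 0.0, "causal": False},
        },
        feedforward_config={
            "name": "MLP",
            "dropout": 0.0,
            "activation": "gelu",
            "hidden_layer_multiplier": 4,
        },
        use_triton=False,
    )
    block = xFormerDecoderBlock.from_config(config)

    target = torch.rand(2, 128, 384)
    memory = torch.rand(2, 256, 384)  # output of an encoder stack
    out = block(target=target, memory=memory)
    assert out.shape == target.shape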


@@ -0,0 +1,36 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
# Register component configs into the Hydra ConfigStore;
# the component config classes can then be used to validate configs.
import logging
from hydra.core.config_store import ConfigStore
from omegaconf.errors import ValidationError
from xformers.components.attention import ATTENTION_REGISTRY
from xformers.components.feedforward import FEEDFORWARD_REGISTRY
from xformers.components.positional_embedding import POSITION_EMBEDDING_REGISTRY
logger = logging.getLogger("xformers")
def import_xformer_config_schema():
"""
Best effort - OmegaConf supports limited typing, so we may fail to import
certain config classes. For example, PyTorch types are not supported.
"""
cs = ConfigStore.instance()
for k, v in {
"ff": FEEDFORWARD_REGISTRY,
"pe": POSITION_EMBEDDING_REGISTRY,
"attention": ATTENTION_REGISTRY,
}.items():
for kk in v.keys():
try:
cs.store(name=f"{kk}_schema", node=v[kk].config, group=f"xformers/{k}")
except ValidationError as e:
logger.debug(f"Error registering {kk}_schema, error: {e}")
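
Once registered, the schemas can back a structured config in any Hydra app. A sketch with the compose API ("MLP" follows the feedforward registry naming used elsewhere in this commit, and the schema name follows the f"{kk}_schema" pattern above)::

    from hydra import compose, initialize

    import_xformer_config_schema()  # populate the ConfigStore
    with initialize(version_base=None, config_path=None):
        # Pull the MLP feedforward schema into the composed config under key "ff"
        cfg = compose(overrides=["+xformers/ff@ff=MLP_schema"])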


@@ -0,0 +1,313 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
import logging
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Union
import torch
from xformers._deprecation_warning import deprecated_function
from xformers.components import reversible as rv
from xformers.components.residual import ResidualNormStyle, get_deepnorm_coefficients
from xformers.factory.block_configs import (
xFormerBlockConfig,
xFormerDecoderConfig,
xFormerEncoderConfig,
)
from xformers.factory.block_factory import xFormerDecoderBlock, xFormerEncoderBlock
from xformers.factory.weight_init import get_weight_init_fn, xFormerWeightInit
logger = logging.getLogger("xformers")
@dataclass(init=False)
class xFormerConfig:
"""
The configuration structure to define a full Transformer.
This can include a stack of encoder layers, and a stack of decoder layers.
It is optionally possible to share the embedding weights between
the encoder and decoder positional encodings, as proposed for instance in
`Using the Output Embedding to Improve Language Models`, Press et al.
A full config example:
::
xformer_config = [
{
"reversible": False, # Turn on to test the effect of using reversible layers
"block_type": "encoder",
"num_layers": LAYERS,
"dim_model": EMB,
"residual_norm_style": "pre",
"position_encoding_config": {
"name": "vocab",
"seq_len": CONTEXT,
"vocab_size": VOCAB_SIZE,
},
"multi_head_config": {
"num_heads": NUM_HEADS,
"residual_dropout": RES_DROP,
"use_rotary_embeddings": True,
"attention": {
"name": ATTENTION_MECHANISM_STR,
"dropout": ATTN_DROP,
"causal": True,
"seq_len": CONTEXT,
},
},
"feedforward_config": {
"name": "FusedMLP", # Use MLP if Triton is not available
"dropout": MLP_DROP,
"activation": "gelu",
"hidden_layer_multiplier": MLP_MULTIPLIER,
},
}
]
.. _`Using the Output Embedding to Improve Language Models`: https://arxiv.org/pdf/1608.05859.pdf
"""
stack_configs: Union[List[xFormerBlockConfig], Dict[str, xFormerBlockConfig]]
tie_embedding_weights: bool = False
weight_init: xFormerWeightInit = xFormerWeightInit.ViT
def __init__(
self,
stack_configs: Union[List[Dict[str, Any]], Dict[str, Dict[str, Any]]],
tie_embedding_weights: bool = False,
weight_init: xFormerWeightInit = xFormerWeightInit.ViT,
):
# Type all the configurations. Possible typos are caught here
if isinstance(stack_configs, dict):
self.stack_configs = {}
for k, config in stack_configs.items():
if config["block_type"] == "encoder":
self.stack_configs[k] = xFormerEncoderConfig(**config)
else:
self.stack_configs[k] = xFormerDecoderConfig(**config)
else:
self.stack_configs = []
for config in stack_configs:
if config["block_type"] == "encoder":
self.stack_configs.append(xFormerEncoderConfig(**config))
else:
self.stack_configs.append(xFormerDecoderConfig(**config))
self.tie_embedding_weights = tie_embedding_weights
self.weight_init = weight_init
deprecated_function(self)
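
With xformer_config defined as in the docstring above, the raw dicts are typed and validated on construction::

    config = xFormerConfig(xformer_config)
    assert isinstance(config.stack_configs[0], xFormerEncoderConfig)

    model = xFormer.from_config(config)  # builds the actual stack, see the class below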
class xFormer(torch.nn.Module):
def __init__(
self,
stack_configs: Union[
xFormerBlockConfig, List[xFormerBlockConfig], Dict[str, xFormerBlockConfig]
],
tie_embedding_weights: bool = False,
weight_init: xFormerWeightInit = xFormerWeightInit.ViT,
):
"""
Given a serialized configuration, generate the corresponding model.
This is only a helper and can easily be bypassed
"""
super().__init__()
deprecated_function(self)
if isinstance(stack_configs, Dict):
stack_configs = list(stack_configs.values())
# Convenience, users can pass either a list of configs or a single one
if not isinstance(stack_configs, List):
stack_configs = [stack_configs]
# Sanity checks, some config combinations do not make sense
self._verify_reversible(stack_configs)
self._verify_deepnorm(stack_configs)
encoders: List[torch.nn.Module] = []
decoders: List[torch.nn.Module] = []
self.reversible_encoder = False
self.rev_enc_pose_encoding = None
# Unroll the configs and build the model
for config in stack_configs:
# Handle either Encoder or Decoder stacks
builder = (
xFormerEncoderBlock.from_config
if isinstance(config, xFormerEncoderConfig)
else xFormerDecoderBlock.from_config
)
recipient = (
encoders if isinstance(config, xFormerEncoderConfig) else decoders
)
# Build up the stack
for i in range(config.num_layers):
# Label where this layer is in the stack
# (for instance useful for the positional encoding, or late layer norm)
if len(recipient) > 0:
config.layer_position.mark_not_first()
if config != stack_configs[-1] or i < config.num_layers - 1:
config.layer_position.mark_not_last()
block = builder(config) # type: ignore
# If reversible: extract the reversible sub-parts, else append the block as-is
if config.reversible:
# WARNING: only one pose encoding is saved here (not Focal Transformer compatible for instance)
assert isinstance(config, xFormerEncoderConfig)
if block.pose_encoding is not None:
self.rev_enc_pose_encoding = block.pose_encoding
self.reversible_encoder = True
f, g = xFormerEncoderBlock.get_reversible_layer(config)
recipient.append(torch.nn.ModuleList([f, g]))
else:
recipient.append(block) # type: ignore
# Tie embedding weights, if requested and possible
assert (
not tie_embedding_weights or not self.reversible_encoder
), "Reversible layers and tied embeddings is not supported for now"
if (
tie_embedding_weights
and encoders
and encoders[0].pose_encoding
and decoders
and decoders[0].pose_encoding
and not config.reversible
):
logger.info("Tying encoder and decoder embeddings, as requested")
encoders[0].pose_encoding = decoders[0].pose_encoding
self.encoders: torch.nn.Module = (
rv.ReversibleSequence(torch.nn.ModuleList(encoders))
if self.reversible_encoder
else torch.nn.ModuleList(encoders)
)
self.decoders = torch.nn.ModuleList(decoders)
use_deepnorm = (
stack_configs[0].residual_norm_style == ResidualNormStyle.DeepNorm
)
assert (
not use_deepnorm or not self.reversible_encoder
), "Reversible layers and deepnorm is not supported for now"
self.init_weights(weight_init=weight_init, use_deep_norm=use_deepnorm)
@classmethod
def from_config(cls, config: xFormerConfig):
return cls(
config.stack_configs, config.tie_embedding_weights, config.weight_init
)
def _verify_reversible(self, stack_configs: List[xFormerBlockConfig]):
reversible = [
c.reversible
for c in filter(lambda x: x.block_type == "encoder", stack_configs)
]
assert all(reversible) or not any(reversible), (
"All layers need to have the same reversibility setting. "
+ f"Currently {reversible}"
)
def _verify_deepnorm(self, stack_configs: List[xFormerBlockConfig]):
deepnorm = [
c.residual_norm_style == ResidualNormStyle.DeepNorm for c in stack_configs
]
assert all(deepnorm) or not any(deepnorm), (
"All layers need to have the same deepnorm setting. "
+ f"Currently {deepnorm}"
)
def init_weights(self, weight_init: xFormerWeightInit, use_deep_norm: bool):
# The deepnorm weight initialization method requires different gain factors for the encoder
# and decoder, depending on the general model structure (number of respective layers)
if use_deep_norm:
encoder_coefficients, decoder_coefficients = get_deepnorm_coefficients(
encoder_layers=len(self.encoders), decoder_layers=len(self.decoders) # type: ignore
)
else:
encoder_coefficients, decoder_coefficients = None, None
encoder_gain = (
encoder_coefficients.beta if encoder_coefficients is not None else 1.0
)
decoder_gain = (
decoder_coefficients.beta if decoder_coefficients is not None else 1.0
)
# Pick the desired init function
init_fn = get_weight_init_fn(weight_init)
# Initialize all the encoder weights
for name, module in self.encoders.named_children():
init_fn(module=module, name=name, gain=encoder_gain)
for name, module in self.decoders.named_children():
init_fn(module=module, name=name, gain=decoder_gain)
def forward(
self,
src: torch.Tensor,
tgt: Optional[torch.Tensor] = None,
encoder_input_mask: Optional[torch.Tensor] = None,
decoder_input_mask: Optional[torch.Tensor] = None,
) -> Optional[torch.Tensor]:
# Encode to latent space if encoder is present
if len(list(self.encoders.parameters())) > 0:
encoders = self.encoders
memory = src.clone()
if isinstance(encoders, torch.nn.ModuleList):
for encoder in encoders:
memory = encoder(memory, input_mask=encoder_input_mask)
else:
if self.rev_enc_pose_encoding:
memory = self.rev_enc_pose_encoding(src)
# Reversible Encoder
x = torch.cat([memory, memory], dim=-1)
# Apply the optional input masking
if encoder_input_mask is not None:
if x.dim() - encoder_input_mask.dim() > 1:
encoder_input_mask = encoder_input_mask.unsqueeze(0)
x += encoder_input_mask.unsqueeze(-1)
x = encoders(x)
memory = torch.stack(x.chunk(2, dim=-1)).mean(dim=0)
if not self.decoders:
return memory
# If a decoder is present: either use the encoder output, or decode directly; both options are possible
if len(self.decoders) > 0:
tgt = src.clone() if tgt is None else tgt
for decoder in self.decoders:
tgt = decoder(
target=tgt,
# pyre-fixme[61]: `memory` is not always initialized here.
memory=memory,
input_mask=decoder_input_mask,
)
return tgt
return None
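
Putting the pieces together, an end-to-end sketch reusing the placeholder names (CONTEXT, VOCAB_SIZE, ...) from the xFormerConfig docstring; BATCH is an arbitrary batch size::

    model = xFormer.from_config(xFormerConfig(xformer_config))

    # The "vocab" position encoding embeds token ids, so feed integer inputs
    src = torch.randint(0, VOCAB_SIZE, (BATCH, CONTEXT))
    out = model(src)  # encoder-only stack: returns the latent memory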


@@ -0,0 +1,293 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
# CREDITS: Reusing a lot of code from the Timm repo
# main differences are the handling of the deepnorm init and a few xformers-specific adaptations
# https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
import logging
import math
from enum import Enum
from typing import Callable
import torch
import torch.nn as nn
from torch.nn.init import (
_calculate_fan_in_and_fan_out,
_no_grad_trunc_normal_,
_no_grad_uniform_,
)
logger = logging.getLogger("xformers")
_assert_if_not_initialized = False
class xFormerWeightInit(str, Enum):
Timm = "timm"
ViT = "vit"
Moco = "moco"
Small = "small"
def get_weight_init_fn(init_choice: xFormerWeightInit):
"""
Provide the xFormers factory with weight init routines.
Supported initializations are:
- Small: follow the method outlined in `Transformer Without Tears`_
- ViT: follow the initialization in the reference ViT_ codebase
- Timm: follow the initialization in the reference Timm_ codebase
- Moco: follow the initialization in the reference MocoV3_ codebase
.. _ViT: https://github.com/google-research/vision_transformer
.. _Timm: https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
.. _MocoV3: https://github.com/facebookresearch/moco-v3
"""
return {
xFormerWeightInit.Timm: _init_weights_vit_timm,
xFormerWeightInit.ViT: _init_weights_vit_jax,
xFormerWeightInit.Moco: _init_weights_vit_moco,
xFormerWeightInit.Small: _init_weights_small,
}[init_choice]
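
The returned callable follows the (module, name, gain) contract used by xFormer.init_weights in the model factory. A standalone sketch on an arbitrary module stack::

    import torch.nn as nn

    model = nn.Sequential(nn.Linear(128, 128), nn.GELU())
    init_fn = get_weight_init_fn(xFormerWeightInit.Small)
    for name, module in model.named_children():
        init_fn(module=module, name=name, gain=1.0)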
# Define pattern matches
def is_ffn(n):
return "feedforward" in n or ("wrap_ff" in n and not n.endswith("norm"))
def is_mha_input_projection(n):
return "q_proj" in n or "k_proj" in n or "v_proj" in n
# Define distribution helpers
def _small_init_(tensor: torch.Tensor, gain: float = 1.0) -> torch.Tensor:
r"""Fills the input `Tensor` with values according to the method
described in `Transformer Without Tears`_, using a uniform distribution.
This is a variation of the Xavier init, also known as Glorot initialization.
The resulting tensor will have values sampled from
:math:`\mathcal{U}(-a, a)` where
.. math::
a = \text{gain} \times \sqrt{\frac{6}{\text{fan\_in} + 4 * \text{fan\_out}}}
Args:
tensor: an n-dimensional `torch.Tensor`
gain: an optional scaling factor
.. _`Transformer Without Tears`: https://arxiv.org/abs/1910.05895
"""
fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor)
std = gain * math.sqrt(2.0 / float(fan_in + 4 * fan_out))
a = math.sqrt(3.0) * std # Calculate uniform bounds from standard deviation
return _no_grad_uniform_(tensor, -a, a)
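
A quick numeric check of the bound (fan_in = fan_out = 1024 for a square weight)::

    w = torch.empty(1024, 1024)
    _small_init_(w)
    bound = math.sqrt(3.0) * math.sqrt(2.0 / (1024 + 4 * 1024))
    assert w.abs().max().item() <= bound + 1e-6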
def _lecun_normal(tensor, gain=1.0):
fan_in, _ = _calculate_fan_in_and_fan_out(tensor)
denom = fan_in
variance = gain / denom
# constant is stddev of standard normal truncated to (-2, 2)
_no_grad_trunc_normal_(
tensor,
mean=0.0,
std=math.sqrt(variance) / 0.87962566103423978,
a=-2.0,
b=2.0,
)
# Helpers to keep all the functions typesafe, and handle corner cases and common behaviours in one place
def _maybe_init_tensor(module: nn.Module, attr: str, distribution_: Callable, **kwargs):
# Small helper to catch all the corner cases, while staying type safe
if hasattr(module, attr):
maybe_tensor = getattr(module, attr)
if maybe_tensor is not None and isinstance(maybe_tensor, torch.Tensor):
distribution_(maybe_tensor, **kwargs)
def _maybe_report_no_init(module, name):
if len(list(module.named_children())) == 0 and (
hasattr(module, "weight") or hasattr(module, "bias")
):
# Skip layer norm, this is ok
if isinstance(module, torch.nn.LayerNorm):
return
# Skip nn.Embedding, we typically initialize it one level up; otherwise PyTorch has a valid default
if isinstance(module, torch.nn.Embedding):
return
# This is unexpected, warn about a possible unhandled weight
logger.warning(
f"Not initializing weights in {name}, this could be a mistake.\nModule {module}"
)
if _assert_if_not_initialized:
assert False, (
f"Uninitialized weight found in {module}."
+ " If you have a custom module, please provide a `init_weights()` method"
)
# Define the different initialization schemes
def _init_weights_vit_jax(
module: nn.Module,
name: str = "",
head_bias: float = 0.0,
gain: float = 1.0,
deepnorm_style: bool = False,
**kwargs,
):
"""ViT weight initialization, matching JAX (Flax) impl"""
if is_ffn(name):
_maybe_init_tensor(module, "bias", nn.init.normal_, std=1e-6)
_maybe_init_tensor(module, "weight", torch.nn.init.xavier_uniform_, gain=gain)
elif is_mha_input_projection(name) or isinstance(module, nn.Linear):
if deepnorm_style and (
"q_proj" in name.split(".") or "k_proj" in name.split(".")
):
gain = 1.0
_maybe_init_tensor(module, "weight", torch.nn.init.xavier_uniform_, gain=gain)
_maybe_init_tensor(module, "bias", nn.init.zeros_)
elif isinstance(module, nn.Conv2d):
_maybe_init_tensor(module, "weight", _lecun_normal, gain=gain)
_maybe_init_tensor(module, "bias", nn.init.zeros_)
elif hasattr(module, "init_weights"):
module.init_weights() # type: ignore
else:
_maybe_report_no_init(module, name)
# Recurse over the children, if the weight init is being handled here
if not hasattr(module, "init_weights"):
for child_name, child_module in module.named_children():
_init_weights_vit_jax(child_module, f"{name}.{child_name}", head_bias, gain)
def _init_weights_vit_moco(
module: nn.Module,
name: str = "",
gain: float = 1.0,
**kwargs,
):
"""ViT weight initialization, matching moco-v3 impl minus fixed PatchEmbed"""
assert (
"deepnorm_style" not in kwargs.keys()
), "This initialization method does not support deepnorm"
if is_ffn(name):
_maybe_init_tensor(module, "weight", torch.nn.init.xavier_uniform_, gain=gain)
_maybe_init_tensor(module, "bias", nn.init.zeros_)
elif is_mha_input_projection(name) or isinstance(module, nn.Linear):
if isinstance(module.weight, torch.Tensor):
val = (
math.sqrt(6.0 / float(module.weight.shape[0] + module.weight.shape[1]))
* gain
)
_maybe_init_tensor(module, "weight", nn.init.uniform_, a=-val, b=val)
_maybe_init_tensor(module, "bias", nn.init.zeros_)
elif hasattr(module, "init_weights"):
module.init_weights(gain=gain) # type: ignore
else:
_maybe_report_no_init(module, name)
# Recurse over the children, if the weight init is being handled here
if not hasattr(module, "init_weights"):
for child_name, child_module in module.named_children():
_init_weights_vit_moco(child_module, child_name, gain)
def _init_weights_small(
module: nn.Module,
name: str = "",
head_bias: float = 0.0,
gain: float = 1.0,
deepnorm_style: bool = False,
**kwargs,
):
"""Follow the `Transformer Without Tears`_ initialization for self-attention"""
if is_ffn(name):
_maybe_init_tensor(module, "weight", torch.nn.init.xavier_uniform_, gain=gain)
_maybe_init_tensor(module, "bias", nn.init.normal_, std=1e-6)
elif is_mha_input_projection(name) or isinstance(module, nn.Linear):
# "small init" only scales the attention layers init, not the FFN
if deepnorm_style and (
"q_proj" in name.split(".") or "k_proj" in name.split(".")
):
gain = 1.0
_maybe_init_tensor(module, "weight", _small_init_, gain=gain)
_maybe_init_tensor(module, "bias", nn.init.zeros_)
elif isinstance(module, nn.Conv2d):
_maybe_init_tensor(module, "weight", _lecun_normal)
_maybe_init_tensor(module, "bias", nn.init.zeros_)
elif hasattr(module, "init_weights"):
module.init_weights() # type: ignore
else:
_maybe_report_no_init(module, name)
# Recurse over the children, if the weight init is being handled here
if not hasattr(module, "init_weights"):
for child_name, child_module in module.named_children():
_init_weights_small(child_module, f"{name}.{child_name}", head_bias, gain)
def _init_weights_vit_timm(
module: nn.Module,
name: str = "",
gain: float = 1.0,
deepnorm_style: bool = False,
**kwargs,
):
"""
ViT weight initialization, original timm impl (for reproducibility).
See DeepNet_ for all the DeepNorm specific codepaths
"""
if isinstance(module, nn.Linear):
if deepnorm_style and (
"q_proj" in name.split(".") or "k_proj" in name.split(".")
):
gain = 1.0
std = 0.02 * gain
a = math.sqrt(3.0) * std # Calculate uniform bounds from standard deviation
_maybe_init_tensor(
module, "weight", _no_grad_trunc_normal_, mean=0.0, std=std, a=-a, b=a
)
_maybe_init_tensor(module, "bias", nn.init.zeros_)
elif hasattr(module, "init_weights"):
module.init_weights(gain=gain) # type: ignore
else:
_maybe_report_no_init(module, name)
# Recurse over the children, if the weight init is being handled here
if not hasattr(module, "init_weights"):
for child_name, child_module in module.named_children():
_init_weights_vit_timm(child_module, child_name, gain)