First commit
11  pkgs/xformers/factory/__init__.py  Normal file
@@ -0,0 +1,11 @@
from xformers.components import MultiHeadDispatchConfig  # noqa
from xformers.components.attention import AttentionConfig  # noqa
from xformers.components.feedforward import FeedforwardConfig  # noqa
from xformers.components.positional_embedding import PositionEmbeddingConfig  # noqa

from .block_factory import xFormerDecoderBlock  # noqa
from .block_factory import xFormerDecoderConfig  # noqa
from .block_factory import xFormerEncoderBlock  # noqa
from .block_factory import xFormerEncoderConfig  # noqa
from .model_factory import xFormer, xFormerConfig  # noqa
from .weight_init import xFormerWeightInit  # noqa
BIN  pkgs/xformers/factory/__pycache__/__init__.cpython-310.pyc  Normal file (binary file not shown)
BIN  pkgs/xformers/factory/__pycache__/block_configs.cpython-310.pyc  Normal file (binary file not shown)
BIN  pkgs/xformers/factory/__pycache__/block_factory.cpython-310.pyc  Normal file (binary file not shown)
BIN  pkgs/xformers/factory/__pycache__/hydra_helper.cpython-310.pyc  Normal file (binary file not shown)
BIN  pkgs/xformers/factory/__pycache__/model_factory.cpython-310.pyc  Normal file (binary file not shown)
BIN  pkgs/xformers/factory/__pycache__/weight_init.cpython-310.pyc  Normal file (binary file not shown)
237  pkgs/xformers/factory/block_configs.py  Normal file
@@ -0,0 +1,237 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.


from dataclasses import dataclass
from enum import Enum
from typing import Any, Dict, Optional

from xformers.components import NormalizationType, ResidualNormStyle
from xformers.components.feedforward import FEEDFORWARD_REGISTRY, FeedforwardConfig
from xformers.components.positional_embedding import (
    POSITION_EMBEDDING_REGISTRY,
    PositionEmbeddingConfig,
)
from xformers.utils import generate_matching_config


class LayerPositionBitmask(int, Enum):
    First = 0b01
    Last = 0b10
    Default = 0b11


class LayerPosition:
    """Bitmask to mark this layer as first, last, nothing or both"""

    def __init__(self):
        self.bitmask = LayerPositionBitmask.Default

    def is_first(self):
        return bool(self.bitmask & LayerPositionBitmask.First)

    def is_last(self):
        return bool(self.bitmask & LayerPositionBitmask.Last)

    def mark_not_first(self):
        self.bitmask &= ~LayerPositionBitmask.First

    def mark_not_last(self):
        self.bitmask &= ~LayerPositionBitmask.Last


class BlockType(str, Enum):
    Encoder = "encoder"
    Decoder = "decoder"


@dataclass(init=False)  # handle constructors explicitly to force type changes
class xFormerBlockConfig:
    """
    The configuration structure to define a Transformer block.
    This base class is applicable to both encoder and decoder definitions.

    This completely defines each of the blocks, for instance in terms of dimensions,
    position encoding, pre or post layer norms or reversibility.
    """

    dim_model: int
    feedforward_config: FeedforwardConfig
    position_encoding_config: Optional[PositionEmbeddingConfig]
    block_type: BlockType
    residual_norm_style: ResidualNormStyle
    normalization: NormalizationType
    layer_position: LayerPosition
    use_triton: bool
    reversible: bool
    num_layers: int

    def __init__(
        self,
        dim_model: int,
        feedforward_config: Dict[str, Any],
        position_encoding_config: Optional[Dict[str, Any]],
        block_type: BlockType,
        residual_norm_style: ResidualNormStyle = ResidualNormStyle("post"),
        normalization: NormalizationType = NormalizationType.LayerNorm,
        reversible: bool = False,
        num_layers: int = 1,
        layer_position: Optional[LayerPosition] = None,
    ):

        self.dim_model = dim_model
        self.block_type = block_type
        self.residual_norm_style = residual_norm_style
        self.reversible = reversible
        self.num_layers = num_layers
        self.normalization = normalization

        # Fill in possible gaps in the config for subparts of the block
        self.feedforward_config = generate_matching_config(
            feedforward_config,
            FEEDFORWARD_REGISTRY[feedforward_config["name"]].config,
        )

        self.position_encoding_config = (
            generate_matching_config(
                position_encoding_config,
                POSITION_EMBEDDING_REGISTRY[position_encoding_config["name"]].config,
            )
            if position_encoding_config is not None
            else None
        )

        # Default is that this layer is the only one, so both first and last
        if layer_position:
            self.layer_position = layer_position
        else:
            self.layer_position = LayerPosition()


@dataclass(init=False)
class xFormerEncoderConfig(xFormerBlockConfig):
    """
    The configuration structure for an encoder block
    """

    multi_head_config: Dict[str, Any]
    use_triton: bool
    simplicial_embeddings: Optional[Dict[str, Any]]
    patch_embedding_config: Optional[Dict[str, Any]]

    def __init__(
        self,
        dim_model: int,
        feedforward_config: Dict[str, Any],
        multi_head_config: Dict[str, Any],
        position_encoding_config: Optional[Dict[str, Any]] = None,
        residual_norm_style: str = "post",
        normalization: NormalizationType = NormalizationType.LayerNorm,
        use_triton: bool = True,
        simplicial_embeddings: Optional[Dict[str, Any]] = None,
        patch_embedding_config: Optional[Dict[str, Any]] = None,
        **kwargs,
    ):
        # Convenience, fill in duplicated fields
        try:
            if "dim_model" not in multi_head_config.keys():
                multi_head_config["dim_model"] = dim_model

            if "dim_model" not in feedforward_config.keys():
                feedforward_config["dim_model"] = dim_model

            if (
                position_encoding_config is not None
                and "dim_model" not in position_encoding_config.keys()
            ):
                position_encoding_config["dim_model"] = dim_model

            if (
                patch_embedding_config is not None
                and "out_channels" not in patch_embedding_config.keys()
            ):
                patch_embedding_config["out_channels"] = dim_model

        except AttributeError:
            # A config instance was passed in, this is fine
            pass

        if "block_type" in kwargs:
            assert kwargs["block_type"] == "encoder"
        kwargs["block_type"] = BlockType("encoder")

        super().__init__(
            dim_model=dim_model,
            feedforward_config=feedforward_config,
            position_encoding_config=position_encoding_config,
            residual_norm_style=ResidualNormStyle(residual_norm_style),
            normalization=NormalizationType(normalization),
            **kwargs,
        )

        self.multi_head_config = multi_head_config
        self.use_triton = use_triton
        self.simplicial_embeddings = simplicial_embeddings
        self.patch_embedding_config = patch_embedding_config


@dataclass(init=False)
class xFormerDecoderConfig(xFormerBlockConfig):
    """
    The configuration structure for a decoder block.

    This specifically defines the masked and cross attention mechanisms,
    on top of the settings defining all blocks.
    """

    multi_head_config_masked: Dict[str, Any]  # prior to encoder output
    multi_head_config_cross: Dict[str, Any]  # cross attention, takes encoder output

    def __init__(
        self,
        dim_model: int,
        feedforward_config: Dict[str, Any],
        multi_head_config_masked: Dict[str, Any],
        multi_head_config_cross: Dict[str, Any],
        position_encoding_config: Optional[Dict[str, Any]] = None,
        residual_norm_style: str = "post",
        normalization: NormalizationType = NormalizationType.LayerNorm,
        use_triton: bool = True,
        **kwargs,
    ):

        # Convenience, fill in duplicated fields
        try:
            if "dim_model" not in multi_head_config_masked.keys():
                multi_head_config_masked["dim_model"] = dim_model

            if "dim_model" not in multi_head_config_cross.keys():
                multi_head_config_cross["dim_model"] = dim_model

            if "dim_model" not in feedforward_config.keys():
                feedforward_config["dim_model"] = dim_model

            if (
                position_encoding_config is not None
                and "dim_model" not in position_encoding_config.keys()
            ):
                position_encoding_config["dim_model"] = dim_model
        except AttributeError:
            # A config instance was passed in, this is fine
            pass

        if "block_type" in kwargs.keys():
            assert kwargs["block_type"] == "decoder"
        kwargs["block_type"] = BlockType("decoder")

        super().__init__(
            dim_model=dim_model,
            feedforward_config=feedforward_config,
            position_encoding_config=position_encoding_config,
            residual_norm_style=ResidualNormStyle(residual_norm_style),
            normalization=NormalizationType(normalization),
            **kwargs,
        )

        self.multi_head_config_masked = multi_head_config_masked
        self.multi_head_config_cross = multi_head_config_cross
        self.use_triton = use_triton
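For reference, a minimal sketch of how these config classes are typically filled from plain dicts (editor's illustration, not part of this commit). All sizes are placeholders, and "scaled_dot_product" / "MLP" are assumed to be names present in the attention and feedforward registries.

# Illustrative only: sub-config dicts may omit dim_model, which the constructor
# copies in before resolving the typed configs from the registries.
from xformers.factory.block_configs import xFormerEncoderConfig

cfg = xFormerEncoderConfig(
    dim_model=384,
    multi_head_config={
        "num_heads": 6,
        "residual_dropout": 0.0,
        "attention": {"name": "scaled_dot_product", "dropout": 0.0, "causal": False},
    },
    feedforward_config={
        "name": "MLP",
        "dropout": 0.0,
        "activation": "gelu",
        "hidden_layer_multiplier": 4,
    },
    residual_norm_style="pre",
)
assert cfg.feedforward_config.dim_model == 384  # filled in by the convenience code
assert cfg.layer_position.is_first() and cfg.layer_position.is_last()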
358  pkgs/xformers/factory/block_factory.py  Normal file
@@ -0,0 +1,358 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.


import logging
from dataclasses import asdict
from typing import Optional, Tuple, Union

import torch
import torch.nn as nn

from xformers._deprecation_warning import deprecated_function
from xformers.components import (
    PatchEmbeddingConfig,
    PostNorm,
    PreNorm,
    Residual,
    ResidualNormStyle,
    build_multi_head_attention,
    build_patch_embedding,
)
from xformers.components.attention import AttentionMask
from xformers.components.feedforward import build_feedforward
from xformers.components.positional_embedding import build_positional_embedding
from xformers.components.residual import get_deepnorm_coefficients
from xformers.components.simplicial_embedding import SimplicialEmbedding
from xformers.factory.block_configs import (
    NormalizationType,
    xFormerDecoderConfig,
    xFormerEncoderConfig,
)

logger = logging.getLogger("xformers")


def _get_ln_factory(
    d_model: int,
    residual_norm_style: Optional[ResidualNormStyle],
    use_triton: bool,
    residual: bool,
    normalization: NormalizationType = NormalizationType.LayerNorm,
    residual_scale: float = 1.0,
):
    """
    Handle all the supported residual path configurations.

    ..Note: we return the appropriate constructor, not an actual layer
    """

    def get_layer_wrapper(
        d_model: int,
        sublayer: nn.Module,
        residual_norm_style: Optional[ResidualNormStyle],
        residual: bool,
        residual_scale: float,
    ):
        if residual:
            if residual_norm_style == ResidualNormStyle.Pre:
                return Residual(
                    layer=PreNorm(d_model, sublayer, normalization, use_triton),
                    scale=None,
                )
            elif residual_norm_style == ResidualNormStyle.Post:
                return PostNorm(
                    d_model,
                    Residual(layer=sublayer, scale=None),
                    normalization,
                    use_triton,
                )
            elif residual_norm_style == ResidualNormStyle.DeepNorm:
                return PostNorm(
                    d_model,
                    Residual(layer=sublayer, scale=residual_scale),
                    normalization,
                    use_triton=use_triton,
                )
            else:
                raise ValueError

        return (
            PreNorm(d_model, sublayer, normalization, use_triton)
            if residual_norm_style == ResidualNormStyle.Pre
            else PostNorm(d_model, sublayer, normalization, use_triton)
        )

    def ln_factory(sublayer: nn.Module):
        return get_layer_wrapper(
            d_model, sublayer, residual_norm_style, residual, residual_scale
        )

    return ln_factory


class xFormerEncoderBlock(torch.nn.Module):
    r"""A vanilla Transformer Encoder block"""

    def __init__(self, config: xFormerEncoderConfig, **kwargs):
        super().__init__()
        deprecated_function(self)

        self.reversible_f = None
        self.reversible_g = None
        self.residual_norm_style = config.residual_norm_style
        self.dim_model = config.dim_model

        # If this layer is the first one, and a pose encoding has been requested
        if (
            config.position_encoding_config is not None
            and config.layer_position.is_first()
        ):
            self.pose_encoding = build_positional_embedding(
                asdict(config.position_encoding_config)
            )

            pos_encoding_dim = config.position_encoding_config.dim_model
            mha_dim = config.multi_head_config["dim_model"]

            if pos_encoding_dim != mha_dim:
                logger.warning(
                    f"The embedding dim and model dim do not match ({pos_encoding_dim} vs {mha_dim}), adding a projector layer."  # noqa
                )
                self.embedding_projector = nn.Linear(pos_encoding_dim, mha_dim)
        else:
            self.pose_encoding = None

        if config.residual_norm_style == ResidualNormStyle.DeepNorm:
            # Just use the layer norm coefficient here,
            # the init will be handled at the xformers level (knows about encoder and decoder blocks)
            deep_norm_coefficients, _ = get_deepnorm_coefficients(
                encoder_layers=config.num_layers, decoder_layers=0
            )
            assert deep_norm_coefficients is not None
            residual_scale = deep_norm_coefficients.alpha
        else:
            residual_scale = 1.0

        # mini helper, builds a normalization layer with the right Pre/Post config, residuals, and the right dimensions
        ln_factory = _get_ln_factory(
            config.dim_model,
            config.residual_norm_style,
            use_triton=config.use_triton,
            residual=True,
            residual_scale=residual_scale,
            normalization=config.normalization,
        )

        mha = build_multi_head_attention(config.multi_head_config)
        feedforward = build_feedforward(asdict(config.feedforward_config))

        # Expose attention specific capabilities
        self.supports_attention_mask = mha.attention.supports_attention_mask
        self.requires_same_k_q_dimensions = mha.attention.requires_same_k_q_dimensions
        self.causal = (
            mha.attention.causal if hasattr(mha.attention, "causal") else False
        )

        # Wrappers handle the different layer norm styles (pre- and post-) and the residual path
        self.wrap_att = ln_factory(mha)
        self.wrap_ff: Union[Residual, PostNorm] = ln_factory(feedforward)
        if (
            config.residual_norm_style == ResidualNormStyle.Pre
            and config.layer_position.is_last()
        ):
            self.wrap_ff = PostNorm(
                config.dim_model,
                self.wrap_ff,
                normalization=config.normalization,
                use_triton=config.use_triton,
            )

        # Simplicial embeddings are only used if specified, and on the last layer
        self.simplicial_embedding: Optional[SimplicialEmbedding] = None
        if config.simplicial_embeddings is not None and config.layer_position.is_last():
            self.simplicial_embedding = SimplicialEmbedding(
                **config.simplicial_embeddings
            )

        # Optional patch embedding
        self.patch_emb: Optional[nn.Module] = None

        if config.patch_embedding_config is not None:
            self.patch_emb = build_patch_embedding(
                PatchEmbeddingConfig(**config.patch_embedding_config)
            )

    @classmethod
    def from_config(cls, config: xFormerEncoderConfig):
        return cls(config)

    @staticmethod
    def get_reversible_layer(config) -> Tuple[nn.Module, nn.Module]:
        ln_factory = _get_ln_factory(
            config.dim_model,
            config.residual_norm_style,
            residual=False,
            use_triton=config.use_triton,
            normalization=config.normalization,
        )

        mha = build_multi_head_attention(config.multi_head_config)
        feedforward = build_feedforward(asdict(config.feedforward_config))

        reversible_f = ln_factory(mha)
        reversible_g = ln_factory(feedforward)
        return reversible_f, reversible_g

    def forward(
        self,
        x: torch.Tensor,
        att_mask: Optional[Union[torch.Tensor, AttentionMask]] = None,
        input_mask: Optional[torch.Tensor] = None,
    ):
        if self.patch_emb is not None:
            x = self.patch_emb(x)

        if self.pose_encoding is not None:
            x = self.pose_encoding(x)

        if hasattr(self, "embedding_projector"):
            x = self.embedding_projector(x)

        # Handle the optional input masking, differs on Q, K, V
        if input_mask is not None:
            q = x
            k = x * input_mask.unsqueeze(-1)
            v = k
        else:
            q, k, v = x, x, x

        # Pre/Post norms and residual paths are already handled
        x = self.wrap_att(inputs=[q, k, v], att_mask=att_mask)
        x = self.wrap_ff(inputs=[x])

        # Optional simplicial embeddings
        if self.simplicial_embedding is not None:
            x = self.simplicial_embedding(x)

        return x


class xFormerDecoderBlock(torch.nn.Module):
    r"""A vanilla Transformer Decoder block

    ... note: this implementation is not (yet ?) reversible"""

    def __init__(self, config: xFormerDecoderConfig, **kwargs):
        super().__init__()
        deprecated_function(self)

        # If this layer is the first one, and a pose encoding has been requested
        if (
            config.position_encoding_config is not None
            and config.layer_position.is_first()
        ):
            self.pose_encoding = build_positional_embedding(
                config.position_encoding_config
            )

            pos_encoding_dim = config.position_encoding_config.dim_model
            mha_dim = config.multi_head_config_masked["dim_model"]

            if pos_encoding_dim != mha_dim:

                logger.warning(
                    f"The embedding dim and model dim do not match ({pos_encoding_dim} vs {mha_dim}), adding a projector layer."  # noqa
                )

                self.embedding_projector = nn.Linear(pos_encoding_dim, mha_dim)
        else:
            self.pose_encoding = None

        if config.residual_norm_style == ResidualNormStyle.DeepNorm:
            # Just use the layer norm coefficient here,
            # the init will be handled at the xformers level (knows about encoder and decoder blocks)
            _, deep_norm_coefficients = get_deepnorm_coefficients(
                encoder_layers=0, decoder_layers=config.num_layers
            )
            assert deep_norm_coefficients is not None
            residual_scale = deep_norm_coefficients.alpha
        else:
            residual_scale = 1.0

        # mini helper, builds a LayerNorm with the right Pre/Post config and the right dimensions
        ln_factory = _get_ln_factory(
            config.dim_model,
            config.residual_norm_style,
            use_triton=config.use_triton,
            residual=True,
            residual_scale=residual_scale,
            normalization=config.normalization,
        )

        mha = build_multi_head_attention(config.multi_head_config_masked)
        cross_mha = build_multi_head_attention(config.multi_head_config_cross)
        feedforward = build_feedforward(config.feedforward_config)

        # Expose attention or feedforward specific capabilities
        self.supports_attention_mask = mha.attention.supports_attention_mask
        self.requires_same_k_q_dimensions = mha.attention.requires_same_k_q_dimensions
        self.requires_squared_context_length = (
            feedforward.requires_squared_context
            or mha.attention.requires_squared_context
        )

        self.causal_attention = (
            mha.attention.causal if hasattr(mha.attention, "causal") else False
        )

        # Wrappers handle the different layer norm styles (pre- and post-) and the residual path
        self.wrap_att = ln_factory(mha)
        self.wrap_cross = ln_factory(cross_mha)
        self.wrap_ff: Union[Residual, PostNorm] = ln_factory(feedforward)

        if (
            config.residual_norm_style == ResidualNormStyle.Pre
            and config.layer_position.is_last()
        ):
            self.wrap_ff = PostNorm(
                config.dim_model,
                self.wrap_ff,
                normalization=NormalizationType.LayerNorm,
            )

    @classmethod
    def from_config(cls, config: xFormerDecoderConfig):
        return cls(config)

    def forward(
        self,
        target: torch.Tensor,
        memory: torch.Tensor,
        encoder_att_mask: Optional[Union[torch.Tensor, AttentionMask]] = None,
        decoder_att_mask: Optional[Union[torch.Tensor, AttentionMask]] = None,
        input_mask: Optional[torch.Tensor] = None,
    ):
        if self.pose_encoding is not None:
            target = self.pose_encoding(target)

        if hasattr(self, "embedding_projector"):
            target = self.embedding_projector(target)

        # Handle the optional input masking, differs on Q, K, V
        if input_mask is not None:
            target_q = target
            target_k = target * input_mask.unsqueeze(-1)
            target_v = target_k
        else:
            target_q, target_k, target_v = target, target, target

        x = self.wrap_att(
            inputs=[target_q, target_k, target_v], att_mask=decoder_att_mask
        )
        x = self.wrap_cross(inputs=[x, memory, memory], att_mask=encoder_att_mask)
        x = self.wrap_ff(inputs=[x])

        return x
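A minimal sketch of building one encoder block from such a config and running a forward pass (editor's illustration, not part of the commit; the attention and feedforward names and the sizes are assumed placeholders).

# Illustrative only: one encoder block, pre-norm, no positional encoding.
import torch
from xformers.factory.block_configs import xFormerEncoderConfig
from xformers.factory.block_factory import xFormerEncoderBlock

cfg = xFormerEncoderConfig(
    dim_model=384,
    multi_head_config={
        "num_heads": 6,
        "residual_dropout": 0.0,
        "attention": {"name": "scaled_dot_product", "dropout": 0.0, "causal": False},
    },
    feedforward_config={
        "name": "MLP",
        "dropout": 0.0,
        "activation": "gelu",
        "hidden_layer_multiplier": 4,
    },
    residual_norm_style="pre",
)
block = xFormerEncoderBlock.from_config(cfg)
x = torch.randn(2, 128, 384)  # (batch, sequence, dim_model)
assert block(x).shape == x.shape  # norms and residual paths handled inside the block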
36  pkgs/xformers/factory/hydra_helper.py  Normal file
@@ -0,0 +1,36 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

# register component configs into the Hydra ConfigStore
# component config classes could be used to validate configs
import logging

from hydra.core.config_store import ConfigStore
from omegaconf.errors import ValidationError

from xformers.components.attention import ATTENTION_REGISTRY
from xformers.components.feedforward import FEEDFORWARD_REGISTRY
from xformers.components.positional_embedding import POSITION_EMBEDDING_REGISTRY

logger = logging.getLogger("xformers")


def import_xformer_config_schema():
    """
    Best effort - OmegaConf supports limited typing, so we may fail to import
    certain config classes. For example, pytorch typing is not supported.
    """
    cs = ConfigStore.instance()

    for k, v in {
        "ff": FEEDFORWARD_REGISTRY,
        "pe": POSITION_EMBEDDING_REGISTRY,
        "attention": ATTENTION_REGISTRY,
    }.items():
        for kk in v.keys():
            try:
                cs.store(name=f"{kk}_schema", node=v[kk].config, group=f"xformers/{k}")
            except ValidationError as e:
                logger.debug(f"Error registering {kk}_schema, error: {e}")
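A minimal sketch of registering the schemas and inspecting what landed in the store (editor's illustration, not part of the commit; assumes Hydra is installed and that ConfigStore.list behaves as in recent Hydra releases).

# Illustrative only: register the xformers schemas, then list one group.
from hydra.core.config_store import ConfigStore
from xformers.factory.hydra_helper import import_xformer_config_schema

import_xformer_config_schema()
cs = ConfigStore.instance()
# Schemas are grouped under "xformers/ff", "xformers/pe" and "xformers/attention";
# each entry can be used to validate the matching component config node.
print(cs.list("xformers/attention"))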
313  pkgs/xformers/factory/model_factory.py  Normal file
@@ -0,0 +1,313 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.


import logging
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Union

import torch

from xformers._deprecation_warning import deprecated_function
from xformers.components import reversible as rv
from xformers.components.residual import ResidualNormStyle, get_deepnorm_coefficients
from xformers.factory.block_configs import (
    xFormerBlockConfig,
    xFormerDecoderConfig,
    xFormerEncoderConfig,
)
from xformers.factory.block_factory import xFormerDecoderBlock, xFormerEncoderBlock
from xformers.factory.weight_init import get_weight_init_fn, xFormerWeightInit

logger = logging.getLogger("xformers")


@dataclass(init=False)
class xFormerConfig:
    """
    The configuration structure to define a full Transformer.
    This can include a stack of encoder layers, and a stack of decoder layers.

    It is optionally possible to share the embedding weights in between
    the encoder and decoder positional encoding, as proposed for instance by
    `Using the Output Embedding to Improve Language Models`, Press et al.

    A full config example is for instance as follows:

    ::

        xformer_config = [
            {
                "reversible": False,  # Turn on to test the effect of using reversible layers
                "block_type": "encoder",
                "num_layers": LAYERS,
                "dim_model": EMB,
                "residual_norm_style": "pre",
                "position_encoding_config": {
                    "name": "vocab",
                    "seq_len": CONTEXT,
                    "vocab_size": VOCAB_SIZE,
                },
                "multi_head_config": {
                    "num_heads": NUM_HEADS,
                    "residual_dropout": RES_DROP,
                    "use_rotary_embeddings": True,
                    "attention": {
                        "name": ATTENTION_MECHANISM_STR,
                        "dropout": ATTN_DROP,
                        "causal": True,
                        "seq_len": CONTEXT,
                    },
                },
                "feedforward_config": {
                    "name": "FusedMLP",  # Use MLP if Triton is not available
                    "dropout": MLP_DROP,
                    "activation": "gelu",
                    "hidden_layer_multiplier": MLP_MULTIPLIER,
                },
            }
        ]


    .. _`Using the Output Embedding to Improve Language Models`: https://arxiv.org/pdf/1608.05859.pdf
    """

    stack_configs: Union[List[xFormerBlockConfig], Dict[str, xFormerBlockConfig]]
    tie_embedding_weights: bool = False
    weight_init: xFormerWeightInit = xFormerWeightInit.ViT

    def __init__(
        self,
        stack_configs: Union[List[Dict[str, Any]], Dict[str, Dict[str, Any]]],
        tie_embedding_weights: bool = False,
        weight_init: xFormerWeightInit = xFormerWeightInit.ViT,
    ):
        # Type all the configurations. Possible typos are caught here
        if isinstance(stack_configs, dict):
            self.stack_configs = {}
            for k, config in stack_configs.items():
                if config["block_type"] == "encoder":
                    self.stack_configs[k] = xFormerEncoderConfig(**config)
                else:
                    self.stack_configs[k] = xFormerDecoderConfig(**config)
        else:
            self.stack_configs = []
            for config in stack_configs:
                if config["block_type"] == "encoder":
                    self.stack_configs.append(xFormerEncoderConfig(**config))
                else:
                    self.stack_configs.append(xFormerDecoderConfig(**config))

        self.tie_embedding_weights = tie_embedding_weights
        self.weight_init = weight_init
        deprecated_function(self)


class xFormer(torch.nn.Module):
    def __init__(
        self,
        stack_configs: Union[
            xFormerBlockConfig, List[xFormerBlockConfig], Dict[str, xFormerBlockConfig]
        ],
        tie_embedding_weights: bool = False,
        weight_init: xFormerWeightInit = xFormerWeightInit.ViT,
    ):
        """
        Given a serialized configuration, generate the corresponding model.
        This is only a helper and can easily be bypassed
        """
        super().__init__()
        deprecated_function(self)

        if isinstance(stack_configs, Dict):
            stack_configs = list(stack_configs.values())

        # Convenience, users can pass either a list of configs or a single one
        if not isinstance(stack_configs, List):
            stack_configs = [stack_configs]

        # Sanity checks, some config combinations do not make sense
        self._verify_reversible(stack_configs)
        self._verify_deepnorm(stack_configs)

        encoders: List[torch.nn.Module] = []
        decoders: List[torch.nn.Module] = []

        self.reversible_encoder = False
        self.rev_enc_pose_encoding = None

        # Unroll the configs and build the model
        for config in stack_configs:
            # Handle either Encoder or Decoder stacks
            builder = (
                xFormerEncoderBlock.from_config
                if isinstance(config, xFormerEncoderConfig)
                else xFormerDecoderBlock.from_config
            )
            recipient = (
                encoders if isinstance(config, xFormerEncoderConfig) else decoders
            )

            # Build up the stack
            for i in range(config.num_layers):
                # Label where this layer is in the stack
                # (for instance useful for the positional encoding, or late layer norm)
                if len(recipient) > 0:
                    config.layer_position.mark_not_first()

                if config != stack_configs[-1] or i < config.num_layers - 1:
                    config.layer_position.mark_not_last()

                block = builder(config)  # type: ignore

                # If reversible: extract the reversible sub-parts, else append the block as-is
                if config.reversible:
                    # WARNING: only one pose encoding is saved here (not Focal Transformer compatible for instance)
                    assert isinstance(config, xFormerEncoderConfig)
                    if block.pose_encoding is not None:
                        self.rev_enc_pose_encoding = block.pose_encoding
                    self.reversible_encoder = True

                    f, g = xFormerEncoderBlock.get_reversible_layer(config)
                    recipient.append(torch.nn.ModuleList([f, g]))
                else:
                    recipient.append(block)  # type: ignore

        # Tie embedding weights, if requested and possible
        assert (
            not tie_embedding_weights or not self.reversible_encoder
        ), "Reversible layers and tied embeddings is not supported for now"

        if (
            tie_embedding_weights
            and encoders
            and encoders[0].pose_encoding
            and decoders
            and decoders[0].pose_encoding
            and not config.reversible
        ):
            logger.info("Tying encoder and decoder embeddings, as requested")
            encoders[0].pose_encoding = decoders[0].pose_encoding

        self.encoders: torch.nn.Module = (
            rv.ReversibleSequence(torch.nn.ModuleList(encoders))
            if self.reversible_encoder
            else torch.nn.ModuleList(encoders)
        )
        self.decoders = torch.nn.ModuleList(decoders)

        use_deepnorm = (
            stack_configs[0].residual_norm_style == ResidualNormStyle.DeepNorm
        )

        assert (
            not use_deepnorm or not self.reversible_encoder
        ), "Reversible layers and deepnorm is not supported for now"

        self.init_weights(weight_init=weight_init, use_deep_norm=use_deepnorm)

    @classmethod
    def from_config(cls, config: xFormerConfig):
        return cls(
            config.stack_configs, config.tie_embedding_weights, config.weight_init
        )

    def _verify_reversible(self, stack_configs: List[xFormerBlockConfig]):
        reversible = [
            c.reversible
            for c in filter(lambda x: x.block_type == "encoder", stack_configs)
        ]

        assert all(reversible) or not any(reversible), (
            "All layers need to have the same reversibility setting. "
            + f"Currently {reversible}"
        )

    def _verify_deepnorm(self, stack_configs: List[xFormerBlockConfig]):
        deepnorm = [
            c.residual_norm_style == ResidualNormStyle.DeepNorm for c in stack_configs
        ]

        assert all(deepnorm) or not any(deepnorm), (
            "All layers need to have the same deepnorm setting. "
            + f"Currently {deepnorm}"
        )

    def init_weights(self, weight_init: xFormerWeightInit, use_deep_norm: bool):
        # The deepnorm weight initialization method requires different gain factors for the encoder
        # and decoder, depending on the general model structure (number of respective layers)
        if use_deep_norm:
            encoder_coefficients, decoder_coefficients = get_deepnorm_coefficients(
                encoder_layers=len(self.encoders), decoder_layers=len(self.decoders)  # type: ignore
            )
        else:
            encoder_coefficients, decoder_coefficients = None, None

        encoder_gain = (
            encoder_coefficients.beta if encoder_coefficients is not None else 1.0
        )
        decoder_gain = (
            decoder_coefficients.beta if decoder_coefficients is not None else 1.0
        )

        # Pick the desired init function
        init_fn = get_weight_init_fn(weight_init)

        # Initialize all the encoder weights
        for name, module in self.encoders.named_children():
            init_fn(module=module, name=name, gain=encoder_gain)

        for name, module in self.decoders.named_children():
            init_fn(module=module, name=name, gain=decoder_gain)

    def forward(
        self,
        src: torch.Tensor,
        tgt: Optional[torch.Tensor] = None,
        encoder_input_mask: Optional[torch.Tensor] = None,
        decoder_input_mask: Optional[torch.Tensor] = None,
    ) -> Optional[torch.Tensor]:

        # Encode to latent space if encoder is present
        if len(list(self.encoders.parameters())) > 0:
            encoders = self.encoders
            memory = src.clone()
            if isinstance(encoders, torch.nn.ModuleList):
                for encoder in encoders:
                    memory = encoder(memory, input_mask=encoder_input_mask)
            else:
                if self.rev_enc_pose_encoding:
                    memory = self.rev_enc_pose_encoding(src)

                # Reversible Encoder
                x = torch.cat([memory, memory], dim=-1)

                # Apply the optional input masking
                if encoder_input_mask is not None:
                    if x.dim() - encoder_input_mask.dim() > 1:
                        encoder_input_mask.unsqueeze(0)
                    x += encoder_input_mask.unsqueeze(-1)

                x = encoders(x)
                memory = torch.stack(x.chunk(2, dim=-1)).mean(dim=0)

            if not self.decoders:
                return memory

        # If decoder: either use the encoder output, or just decode, both options are possible
        if len(self.decoders) > 0:
            tgt = src.clone() if tgt is None else tgt

            for decoder in self.decoders:
                tgt = decoder(
                    target=tgt,
                    # pyre-fixme[61]: `memory` is not always initialized here.
                    memory=memory,
                    input_mask=decoder_input_mask,
                )

            return tgt

        return None
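A minimal end-to-end sketch mirroring the docstring example above, with placeholder hyper-parameters (editor's illustration, not part of the commit; "scaled_dot_product" and "MLP" are assumed to be registered attention and feedforward names).

# Illustrative only: encoder-only stack built from a serialized config.
import torch
from xformers.factory.model_factory import xFormer, xFormerConfig

my_config = [
    {
        "reversible": False,
        "block_type": "encoder",
        "num_layers": 2,
        "dim_model": 384,
        "residual_norm_style": "pre",
        "multi_head_config": {
            "num_heads": 6,
            "residual_dropout": 0.0,
            "attention": {"name": "scaled_dot_product", "dropout": 0.0, "causal": False},
        },
        "feedforward_config": {
            "name": "MLP",
            "dropout": 0.0,
            "activation": "gelu",
            "hidden_layer_multiplier": 4,
        },
    }
]

model = xFormer.from_config(xFormerConfig(my_config))
out = model(torch.randn(2, 128, 384))  # encoder-only: forward returns the encoded memory
assert out.shape == (2, 128, 384)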
293  pkgs/xformers/factory/weight_init.py  Normal file
@@ -0,0 +1,293 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

# CREDITS: Reusing a lot of code from the Timm repo
# main difference is probably the handling of deepnorm init, and adapting to some xformers specificities
# https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py

import logging
import math
from enum import Enum
from typing import Callable

import torch
import torch.nn as nn
from torch.nn.init import (
    _calculate_fan_in_and_fan_out,
    _no_grad_trunc_normal_,
    _no_grad_uniform_,
)

logger = logging.getLogger("xformers")


_assert_if_not_initialized = False


class xFormerWeightInit(str, Enum):
    Timm = "timm"
    ViT = "vit"
    Moco = "moco"
    Small = "small"


def get_weight_init_fn(init_choice: xFormerWeightInit):
    """
    Provide the xFormers factory with weight init routines.

    Supported initializations are:
    - Small: follow the method outlined in `Transformer Without Tears`_
    - ViT: follow the initialization in the reference ViT_ codebase
    - Timm: follow the initialization in the reference Timm_ codebase
    - Moco: follow the initialization in the reference MocoV3_ codebase

    .. _ViT: https://github.com/google-research/vision_transformer
    .. _Timm: https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
    .. _MocoV3: https://github.com/facebookresearch/moco-v3
    """
    return {
        xFormerWeightInit.Timm: _init_weights_vit_timm,
        xFormerWeightInit.ViT: _init_weights_vit_jax,
        xFormerWeightInit.Moco: _init_weights_vit_moco,
        xFormerWeightInit.Small: _init_weights_small,
    }[init_choice]


# Define pattern matches
def is_ffn(n):
    return "feedforward" in n or ("wrap_ff" in n and not n.endswith("norm"))


def is_mha_input_projection(n):
    return "q_proj" in n or "k_proj" in n or "v_proj" in n


# Define distribution helpers
def _small_init_(tensor: torch.Tensor, gain: float = 1.0) -> torch.Tensor:
    r"""Fills the input `Tensor` with values according to the method
    described in `Transformer Without Tears`_, using a uniform distribution.

    This is a variation of the Xavier init. The resulting tensor will have values sampled from
    :math:`\mathcal{U}(-a, a)` where

    .. math::
        a = \text{gain} \times \sqrt{\frac{6}{\text{fan\_in} + 4 * \text{fan\_out}}}

    Also known as Glorot initialization.

    Args:
        tensor: an n-dimensional `torch.Tensor`
        gain: an optional scaling factor

    .. _`Transformer Without Tears`: https://arxiv.org/abs/1910.05895

    """
    fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor)
    std = gain * math.sqrt(2.0 / float(fan_in + 4 * fan_out))
    a = math.sqrt(3.0) * std  # Calculate uniform bounds from standard deviation

    return _no_grad_uniform_(tensor, -a, a)


def _lecun_normal(tensor, gain=1.0):
    fan_in, _ = _calculate_fan_in_and_fan_out(tensor)
    denom = fan_in
    variance = gain / denom

    # constant is stddev of standard normal truncated to (-2, 2)
    _no_grad_trunc_normal_(
        tensor,
        mean=0.0,
        std=math.sqrt(variance) / 0.87962566103423978,
        a=-2.0,
        b=2.0,
    )


# Helpers to keep all the functions typesafe, and handle corner cases and common behaviours in one place
def _maybe_init_tensor(module: nn.Module, attr: str, distribution_: Callable, **kwargs):
    # Small helper to catch all the corner cases, while staying type safe
    if hasattr(module, attr):
        maybe_tensor = getattr(module, attr)
        if maybe_tensor is not None and isinstance(maybe_tensor, torch.Tensor):
            distribution_(maybe_tensor, **kwargs)


def _maybe_report_no_init(module, name):
    if len(list(module.named_children())) == 0 and (
        hasattr(module, "weight") or hasattr(module, "bias")
    ):
        # Skip layer norm, this is ok
        if isinstance(module, torch.nn.LayerNorm):
            return

        # Skip nn.Embedding, we typically initialize it one level up, else Pytorch has a valid default
        if isinstance(module, torch.nn.Embedding):
            return

        # This is unexpected, warn about a possible unhandled weight
        logger.warning(
            f"Not initializing weights in {name}, this could be a mistake.\nModule {module}"
        )

        if _assert_if_not_initialized:
            assert False, (
                f"Uninitialized weight found in {module}."
                + " If you have a custom module, please provide a `init_weights()` method"
            )


# Define the different initialization schemes
def _init_weights_vit_jax(
    module: nn.Module,
    name: str = "",
    head_bias: float = 0.0,
    gain: float = 1.0,
    deepnorm_style: bool = False,
    **kwargs,
):
    """ViT weight initialization, matching JAX (Flax) impl"""

    if is_ffn(name):
        _maybe_init_tensor(module, "bias", nn.init.normal_, std=1e-6)
        _maybe_init_tensor(module, "weight", torch.nn.init.xavier_uniform_, gain=gain)

    elif is_mha_input_projection(name) or isinstance(module, nn.Linear):
        if deepnorm_style and (
            "q_proj" in name.split(".") or "k_proj" in name.split(".")
        ):
            gain = 1.0

        _maybe_init_tensor(module, "weight", torch.nn.init.xavier_uniform_, gain=gain)
        _maybe_init_tensor(module, "bias", nn.init.zeros_)

    elif isinstance(module, nn.Conv2d):
        _maybe_init_tensor(module, "weight", _lecun_normal, gain=gain)
        _maybe_init_tensor(module, "bias", nn.init.zeros_)

    elif hasattr(module, "init_weights"):
        module.init_weights()  # type: ignore

    else:
        _maybe_report_no_init(module, name)

    # Recurse over the children, if the weight init is being handled here
    if not hasattr(module, "init_weights"):
        for child_name, child_module in module.named_children():
            _init_weights_vit_jax(child_module, f"{name}.{child_name}", head_bias, gain)


def _init_weights_vit_moco(
    module: nn.Module,
    name: str = "",
    gain: float = 1.0,
    **kwargs,
):
    """ViT weight initialization, matching moco-v3 impl minus fixed PatchEmbed"""

    assert (
        "deepnorm_style" not in kwargs.keys()
    ), "This initialization method does not support deepnorm"

    if is_ffn(name):
        _maybe_init_tensor(module, "weight", torch.nn.init.xavier_uniform_, gain=gain)
        _maybe_init_tensor(module, "bias", nn.init.zeros_)

    elif is_mha_input_projection(name) or isinstance(module, nn.Linear):
        if isinstance(module.weight, torch.Tensor):
            val = (
                math.sqrt(6.0 / float(module.weight.shape[0] + module.weight.shape[1]))
                * gain
            )
            _maybe_init_tensor(module, "weight", nn.init.uniform_, a=-val, b=val)

        _maybe_init_tensor(module, "bias", nn.init.zeros_)

    elif hasattr(module, "init_weights"):
        module.init_weights(gain=gain)  # type: ignore

    else:
        _maybe_report_no_init(module, name)

    # Recurse over the children, if the weight init is being handled here
    if not hasattr(module, "init_weights"):
        for child_name, child_module in module.named_children():
            _init_weights_vit_moco(child_module, child_name, gain)


def _init_weights_small(
    module: nn.Module,
    name: str = "",
    head_bias: float = 0.0,
    gain: float = 1.0,
    deepnorm_style: bool = False,
    **kwargs,
):
    """Follow the `Transformer Without Tears`_ initialization for self-attention"""

    if is_ffn(name):
        _maybe_init_tensor(module, "weight", torch.nn.init.xavier_uniform_, gain=gain)
        _maybe_init_tensor(module, "bias", nn.init.normal_, std=1e-6)

    elif is_mha_input_projection(name) or isinstance(module, nn.Linear):
        # "small init" only scales the attention layers init, not the FFN
        if deepnorm_style and (
            "q_proj" in name.split(".") or "k_proj" in name.split(".")
        ):
            gain = 1.0

        _maybe_init_tensor(module, "weight", _small_init_, gain=gain)
        _maybe_init_tensor(module, "bias", nn.init.zeros_)

    elif isinstance(module, nn.Conv2d):
        _maybe_init_tensor(module, "weight", _lecun_normal)
        _maybe_init_tensor(module, "bias", nn.init.zeros_)
    elif hasattr(module, "init_weights"):
        module.init_weights()  # type: ignore
    else:
        _maybe_report_no_init(module, name)

    # Recurse over the children, if the weight init is being handled here
    if not hasattr(module, "init_weights"):
        for child_name, child_module in module.named_children():
            _init_weights_small(child_module, f"{name}.{child_name}", head_bias, gain)


def _init_weights_vit_timm(
    module: nn.Module,
    name: str = "",
    gain: float = 1.0,
    deepnorm_style: bool = False,
    **kwargs,
):
    """
    ViT weight initialization, original timm impl (for reproducibility).

    See DeepNet_ for all the DeepNorm specific codepaths
    """

    if isinstance(module, nn.Linear):
        if deepnorm_style and (
            "q_proj" in name.split(".") or "k_proj" in name.split(".")
        ):
            gain = 1

        std = 0.02 * gain
        a = math.sqrt(3.0) * std  # Calculate uniform bounds from standard deviation

        _maybe_init_tensor(
            module, "weight", _no_grad_trunc_normal_, mean=0.0, std=std, a=-a, b=a
        )
        _maybe_init_tensor(module, "bias", nn.init.zeros_)

    elif hasattr(module, "init_weights"):
        module.init_weights(gain=gain)  # type: ignore
    else:
        _maybe_report_no_init(module, name)

    # Recurse over the children, if the weight init is being handled here
    if not hasattr(module, "init_weights"):
        for child_name, child_module in module.named_children():
            _init_weights_vit_timm(child_module, child_name, gain)
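A minimal sketch of picking one of these schemes and applying it to a hand-rolled module (editor's illustration, not part of the commit; the toy module and gain value are placeholders).

# Illustrative only: apply the "small" init scheme to a plain Sequential.
import torch.nn as nn
from xformers.factory.weight_init import get_weight_init_fn, xFormerWeightInit

init_fn = get_weight_init_fn(xFormerWeightInit.Small)
toy = nn.Sequential(nn.Linear(384, 1536), nn.GELU(), nn.Linear(1536, 384))
for name, module in toy.named_children():
    # Linear children hit the attention/linear branch; the GELU is silently skipped.
    init_fn(module=module, name=name, gain=1.0)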