First commit
This commit is contained in:
358
pkgs/xformers/factory/block_factory.py
Normal file
358
pkgs/xformers/factory/block_factory.py
Normal file
@@ -0,0 +1,358 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the BSD license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
|
||||
import logging
|
||||
from dataclasses import asdict
|
||||
from typing import Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from xformers._deprecation_warning import deprecated_function
|
||||
from xformers.components import (
|
||||
PatchEmbeddingConfig,
|
||||
PostNorm,
|
||||
PreNorm,
|
||||
Residual,
|
||||
ResidualNormStyle,
|
||||
build_multi_head_attention,
|
||||
build_patch_embedding,
|
||||
)
|
||||
from xformers.components.attention import AttentionMask
|
||||
from xformers.components.feedforward import build_feedforward
|
||||
from xformers.components.positional_embedding import build_positional_embedding
|
||||
from xformers.components.residual import get_deepnorm_coefficients
|
||||
from xformers.components.simplicial_embedding import SimplicialEmbedding
|
||||
from xformers.factory.block_configs import (
|
||||
NormalizationType,
|
||||
xFormerDecoderConfig,
|
||||
xFormerEncoderConfig,
|
||||
)
|
||||
|
||||
logger = logging.getLogger("xformers")
|
||||
|
||||
|
||||
def _get_ln_factory(
|
||||
d_model: int,
|
||||
residual_norm_style: Optional[ResidualNormStyle],
|
||||
use_triton: bool,
|
||||
residual: bool,
|
||||
normalization: NormalizationType = NormalizationType.LayerNorm,
|
||||
residual_scale: float = 1.0,
|
||||
):
|
||||
"""
|
||||
Handle all the supported residual path configurations.
|
||||
|
||||
..Note: we return the appropriate constructor, not an actual layer
|
||||
"""
|
||||
|
||||
def get_layer_wrapper(
|
||||
d_model: int,
|
||||
sublayer: nn.Module,
|
||||
residual_norm_style: Optional[ResidualNormStyle],
|
||||
residual: bool,
|
||||
residual_scale: float,
|
||||
):
|
||||
if residual:
|
||||
if residual_norm_style == ResidualNormStyle.Pre:
|
||||
return Residual(
|
||||
layer=PreNorm(d_model, sublayer, normalization, use_triton),
|
||||
scale=None,
|
||||
)
|
||||
elif residual_norm_style == ResidualNormStyle.Post:
|
||||
return PostNorm(
|
||||
d_model,
|
||||
Residual(layer=sublayer, scale=None),
|
||||
normalization,
|
||||
use_triton,
|
||||
)
|
||||
elif residual_norm_style == ResidualNormStyle.DeepNorm:
|
||||
return PostNorm(
|
||||
d_model,
|
||||
Residual(layer=sublayer, scale=residual_scale),
|
||||
normalization,
|
||||
use_triton=use_triton,
|
||||
)
|
||||
else:
|
||||
raise ValueError
|
||||
|
||||
return (
|
||||
PreNorm(d_model, sublayer, normalization, use_triton)
|
||||
if residual_norm_style == ResidualNormStyle.Pre
|
||||
else PostNorm(d_model, sublayer, normalization, use_triton)
|
||||
)
|
||||
|
||||
def ln_factory(sublayer: nn.Module):
|
||||
return get_layer_wrapper(
|
||||
d_model, sublayer, residual_norm_style, residual, residual_scale
|
||||
)
|
||||
|
||||
return ln_factory
|
||||
|
||||
|
||||
class xFormerEncoderBlock(torch.nn.Module):
|
||||
r"""A vanilla Transformer Encoder block"""
|
||||
|
||||
def __init__(self, config: xFormerEncoderConfig, **kwargs):
|
||||
super().__init__()
|
||||
deprecated_function(self)
|
||||
|
||||
self.reversible_f = None
|
||||
self.reversible_g = None
|
||||
self.residual_norm_style = config.residual_norm_style
|
||||
self.dim_model = config.dim_model
|
||||
|
||||
# If this layer is the first one, and a pose encoding has been requested
|
||||
if (
|
||||
config.position_encoding_config is not None
|
||||
and config.layer_position.is_first()
|
||||
):
|
||||
self.pose_encoding = build_positional_embedding(
|
||||
asdict(config.position_encoding_config)
|
||||
)
|
||||
|
||||
pos_encoding_dim = config.position_encoding_config.dim_model
|
||||
mha_dim = config.multi_head_config["dim_model"]
|
||||
|
||||
if pos_encoding_dim != mha_dim:
|
||||
logger.warning(
|
||||
f"The embedding dim and model dim do not match ({pos_encoding_dim} vs {mha_dim}), adding a projector layer." # noqa
|
||||
)
|
||||
self.embedding_projector = nn.Linear(pos_encoding_dim, mha_dim)
|
||||
else:
|
||||
self.pose_encoding = None
|
||||
|
||||
if config.residual_norm_style == ResidualNormStyle.DeepNorm:
|
||||
# Just use the layer norm coefficient here,
|
||||
# the init will be handled at the xformers level (knows about encoder and decoder blocks)
|
||||
deep_norm_coefficients, _ = get_deepnorm_coefficients(
|
||||
encoder_layers=config.num_layers, decoder_layers=0
|
||||
)
|
||||
assert deep_norm_coefficients is not None
|
||||
residual_scale = deep_norm_coefficients.alpha
|
||||
else:
|
||||
residual_scale = 1.0
|
||||
|
||||
# mini helper, builds a normalization layer with the right Pre/Post config, residuals, and the right dimensions
|
||||
ln_factory = _get_ln_factory(
|
||||
config.dim_model,
|
||||
config.residual_norm_style,
|
||||
use_triton=config.use_triton,
|
||||
residual=True,
|
||||
residual_scale=residual_scale,
|
||||
normalization=config.normalization,
|
||||
)
|
||||
|
||||
mha = build_multi_head_attention(config.multi_head_config)
|
||||
feedforward = build_feedforward(asdict(config.feedforward_config))
|
||||
|
||||
# Expose attention specific capabilities
|
||||
self.supports_attention_mask = mha.attention.supports_attention_mask
|
||||
self.requires_same_k_q_dimensions = mha.attention.requires_same_k_q_dimensions
|
||||
self.causal = (
|
||||
mha.attention.causal if hasattr(mha.attention, "causal") else False
|
||||
)
|
||||
|
||||
# Wrappers handle the different layer norm styles (pre- and post-) and the residual path
|
||||
self.wrap_att = ln_factory(mha)
|
||||
self.wrap_ff: Union[Residual, PostNorm] = ln_factory(feedforward)
|
||||
if (
|
||||
config.residual_norm_style == ResidualNormStyle.Pre
|
||||
and config.layer_position.is_last()
|
||||
):
|
||||
self.wrap_ff = PostNorm(
|
||||
config.dim_model,
|
||||
self.wrap_ff,
|
||||
normalization=config.normalization,
|
||||
use_triton=config.use_triton,
|
||||
)
|
||||
|
||||
# Simplicial embeddings are only used if specified, and on the last layer
|
||||
self.simplicial_embedding: Optional[SimplicialEmbedding] = None
|
||||
if config.simplicial_embeddings is not None and config.layer_position.is_last():
|
||||
self.simplicial_embedding = SimplicialEmbedding(
|
||||
**config.simplicial_embeddings
|
||||
)
|
||||
|
||||
# Optional patch embedding
|
||||
self.patch_emb: Optional[nn.Module] = None
|
||||
|
||||
if config.patch_embedding_config is not None:
|
||||
self.patch_emb = build_patch_embedding(
|
||||
PatchEmbeddingConfig(**config.patch_embedding_config)
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: xFormerEncoderConfig):
|
||||
return cls(config)
|
||||
|
||||
@staticmethod
|
||||
def get_reversible_layer(config) -> Tuple[nn.Module, nn.Module]:
|
||||
ln_factory = _get_ln_factory(
|
||||
config.dim_model,
|
||||
config.residual_norm_style,
|
||||
residual=False,
|
||||
use_triton=config.use_triton,
|
||||
normalization=config.normalization,
|
||||
)
|
||||
|
||||
mha = build_multi_head_attention(config.multi_head_config)
|
||||
feedforward = build_feedforward(asdict(config.feedforward_config))
|
||||
|
||||
reversible_f = ln_factory(mha)
|
||||
reversible_g = ln_factory(feedforward)
|
||||
return reversible_f, reversible_g
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
att_mask: Optional[Union[torch.Tensor, AttentionMask]] = None,
|
||||
input_mask: Optional[torch.Tensor] = None,
|
||||
):
|
||||
if self.patch_emb is not None:
|
||||
x = self.patch_emb(x)
|
||||
|
||||
if self.pose_encoding is not None:
|
||||
x = self.pose_encoding(x)
|
||||
|
||||
if hasattr(self, "embedding_projector"):
|
||||
x = self.embedding_projector(x)
|
||||
|
||||
# Handle the optional input masking, differs on Q, K, V
|
||||
if input_mask is not None:
|
||||
q = x
|
||||
k = x * input_mask.unsqueeze(-1)
|
||||
v = k
|
||||
else:
|
||||
q, k, v = x, x, x
|
||||
|
||||
# Pre/Post norms and residual paths are already handled
|
||||
x = self.wrap_att(inputs=[q, k, v], att_mask=att_mask)
|
||||
x = self.wrap_ff(inputs=[x])
|
||||
|
||||
# Optional simplicial embeddings
|
||||
if self.simplicial_embedding is not None:
|
||||
x = self.simplicial_embedding(x)
|
||||
|
||||
return x
|
||||
|
||||
|
||||
class xFormerDecoderBlock(torch.nn.Module):
|
||||
r"""A vanilla Transformer Decoder block
|
||||
|
||||
... note: this implementation is not (yet ?) reversible"""
|
||||
|
||||
def __init__(self, config: xFormerDecoderConfig, **kwargs):
|
||||
super().__init__()
|
||||
deprecated_function(self)
|
||||
|
||||
# If this layer is the first one, and a pose encoding as been requested
|
||||
if (
|
||||
config.position_encoding_config is not None
|
||||
and config.layer_position.is_first()
|
||||
):
|
||||
self.pose_encoding = build_positional_embedding(
|
||||
config.position_encoding_config
|
||||
)
|
||||
|
||||
pos_encoding_dim = config.position_encoding_config.dim_model
|
||||
mha_dim = config.multi_head_config_masked["dim_model"]
|
||||
|
||||
if pos_encoding_dim != mha_dim:
|
||||
|
||||
logger.warning(
|
||||
f"The embedding dim and model dim do not match ({pos_encoding_dim} vs {mha_dim}), adding a projector layer." # noqa
|
||||
)
|
||||
|
||||
self.embedding_projector = nn.Linear(pos_encoding_dim, mha_dim)
|
||||
else:
|
||||
self.pose_encoding = None
|
||||
|
||||
if config.residual_norm_style == ResidualNormStyle.DeepNorm:
|
||||
# Just use the layer norm coefficient here,
|
||||
# the init will be handled at the xformers level (knows about encoder and decoder blocks)
|
||||
_, deep_norm_coefficients = get_deepnorm_coefficients(
|
||||
encoder_layers=0, decoder_layers=config.num_layers
|
||||
)
|
||||
assert deep_norm_coefficients is not None
|
||||
residual_scale = deep_norm_coefficients.alpha
|
||||
else:
|
||||
residual_scale = 1.0
|
||||
|
||||
# mini helper, builds a LayerNorm with the right Pre/Post config and the right dimensions
|
||||
ln_factory = _get_ln_factory(
|
||||
config.dim_model,
|
||||
config.residual_norm_style,
|
||||
use_triton=config.use_triton,
|
||||
residual=True,
|
||||
residual_scale=residual_scale,
|
||||
normalization=config.normalization,
|
||||
)
|
||||
|
||||
mha = build_multi_head_attention(config.multi_head_config_masked)
|
||||
cross_mha = build_multi_head_attention(config.multi_head_config_cross)
|
||||
feedforward = build_feedforward(config.feedforward_config)
|
||||
|
||||
# Expose attention or feedforward specific capabilities
|
||||
self.supports_attention_mask = mha.attention.supports_attention_mask
|
||||
self.requires_same_k_q_dimensions = mha.attention.requires_same_k_q_dimensions
|
||||
self.requires_squared_context_length = (
|
||||
feedforward.requires_squared_context
|
||||
or mha.attention.requires_squared_context
|
||||
)
|
||||
|
||||
self.causal_attention = (
|
||||
mha.attention.causal if hasattr(mha.attention, "causal") else False
|
||||
)
|
||||
|
||||
# Wrappers handle the different layer norm styles (pre- and post-) and the residual path
|
||||
self.wrap_att = ln_factory(mha)
|
||||
self.wrap_cross = ln_factory(cross_mha)
|
||||
self.wrap_ff: Union[Residual, PostNorm] = ln_factory(feedforward)
|
||||
|
||||
if (
|
||||
config.residual_norm_style == ResidualNormStyle.Pre
|
||||
and config.layer_position.is_last()
|
||||
):
|
||||
self.wrap_ff = PostNorm(
|
||||
config.dim_model,
|
||||
self.wrap_ff,
|
||||
normalization=NormalizationType.LayerNorm,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: xFormerDecoderConfig):
|
||||
return cls(config)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
target: torch.Tensor,
|
||||
memory: torch.Tensor,
|
||||
encoder_att_mask: Optional[Union[torch.Tensor, AttentionMask]] = None,
|
||||
decoder_att_mask: Optional[Union[torch.Tensor, AttentionMask]] = None,
|
||||
input_mask: Optional[torch.Tensor] = None,
|
||||
):
|
||||
if self.pose_encoding is not None:
|
||||
target = self.pose_encoding(target)
|
||||
|
||||
if hasattr(self, "embedding_projector"):
|
||||
target = self.embedding_projector(target)
|
||||
|
||||
# Handle the optional input masking, differs on Q, K, V
|
||||
if input_mask is not None:
|
||||
target_q = target
|
||||
target_k = target * input_mask.unsqueeze(-1)
|
||||
target_v = target_k
|
||||
else:
|
||||
target_q, target_k, target_v = target, target, target
|
||||
|
||||
x = self.wrap_att(
|
||||
inputs=[target_q, target_k, target_v], att_mask=decoder_att_mask
|
||||
)
|
||||
x = self.wrap_cross(inputs=[x, memory, memory], att_mask=encoder_att_mask)
|
||||
x = self.wrap_ff(inputs=[x])
|
||||
|
||||
return x
|
||||
Reference in New Issue
Block a user