# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

from dataclasses import dataclass
from enum import Enum
from typing import Any, Dict, Optional

from xformers.components import NormalizationType, ResidualNormStyle
from xformers.components.feedforward import FEEDFORWARD_REGISTRY, FeedforwardConfig
from xformers.components.positional_embedding import (
    POSITION_EMBEDDING_REGISTRY,
    PositionEmbeddingConfig,
)
from xformers.utils import generate_matching_config


class LayerPositionBitmask(int, Enum):
    First = 0b01
    Last = 0b10
    Default = 0b11


class LayerPosition:
    """Bitmask to mark this layer as first, last, nothing or both"""

    def __init__(self):
        self.bitmask = LayerPositionBitmask.Default

    def is_first(self):
        return bool(self.bitmask & LayerPositionBitmask.First)

    def is_last(self):
        return bool(self.bitmask & LayerPositionBitmask.Last)

    def mark_not_first(self):
        self.bitmask &= ~LayerPositionBitmask.First

    def mark_not_last(self):
        self.bitmask &= ~LayerPositionBitmask.Last


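# A minimal usage sketch of the bitmask above (illustrative comment, not part
# of the API). A freshly constructed layer is marked both first and last; when
# several blocks are stacked, the factory can clear the marks on the inner
# layers so that, for instance, position encodings are only applied on the
# first one.
#
#   position = LayerPosition()
#   assert position.is_first() and position.is_last()  # default: only layer
#   position.mark_not_first()
#   assert not position.is_first() and position.is_last()

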
class BlockType(str, Enum):
    Encoder = "encoder"
    Decoder = "decoder"


@dataclass(init=False)  # handle constructors explicitly to force type changes
class xFormerBlockConfig:
    """
    The configuration structure to define a Transformer block.
    This base class is applicable to both encoder and decoder definitions.

    This completely defines each of the blocks, for instance in terms of dimensions,
    position encoding, pre or post layer norms or reversibility.
    """

    dim_model: int
    feedforward_config: FeedforwardConfig
    position_encoding_config: Optional[PositionEmbeddingConfig]
    block_type: BlockType
    residual_norm_style: ResidualNormStyle
    normalization: NormalizationType
    layer_position: LayerPosition
    use_triton: bool
    reversible: bool
    num_layers: int

    def __init__(
        self,
        dim_model: int,
        feedforward_config: Dict[str, Any],
        position_encoding_config: Optional[Dict[str, Any]],
        block_type: BlockType,
        residual_norm_style: ResidualNormStyle = ResidualNormStyle("post"),
        normalization: NormalizationType = NormalizationType.LayerNorm,
        reversible: bool = False,
        num_layers: int = 1,
        layer_position: Optional[LayerPosition] = None,
    ):
        self.dim_model = dim_model
        self.block_type = block_type
        self.residual_norm_style = residual_norm_style
        self.reversible = reversible
        self.num_layers = num_layers
        self.normalization = normalization

        # Fill in possible gaps in the config for subparts of the block
        self.feedforward_config = generate_matching_config(
            feedforward_config,
            FEEDFORWARD_REGISTRY[feedforward_config["name"]].config,
        )

        self.position_encoding_config = (
            generate_matching_config(
                position_encoding_config,
                POSITION_EMBEDDING_REGISTRY[position_encoding_config["name"]].config,
            )
            if position_encoding_config is not None
            else None
        )

        # Default is that this layer is the only one, so both first and last
        if layer_position:
            self.layer_position = layer_position
        else:
            self.layer_position = LayerPosition()


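# A hedged construction sketch for the base config (the "MLP" feedforward
# name and its extra fields are assumptions about what is registered in
# FEEDFORWARD_REGISTRY and may differ across versions):
#
#   block_config = xFormerBlockConfig(
#       dim_model=384,
#       feedforward_config={
#           "name": "MLP",
#           "dim_model": 384,
#           "dropout": 0.1,
#           "activation": "relu",
#           "hidden_layer_multiplier": 4,
#       },
#       position_encoding_config=None,
#       block_type=BlockType("encoder"),
#   )
#
# generate_matching_config completes the partial dict against the registered
# dataclass, so block_config.feedforward_config ends up a typed
# FeedforwardConfig instance rather than a raw dict.

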
@dataclass(init=False)
class xFormerEncoderConfig(xFormerBlockConfig):
    """
    The configuration structure for an encoder block.
    """

    multi_head_config: Dict[str, Any]
    use_triton: bool
    simplicial_embeddings: Optional[Dict[str, Any]]
    patch_embedding_config: Optional[Dict[str, Any]]

    def __init__(
        self,
        dim_model: int,
        feedforward_config: Dict[str, Any],
        multi_head_config: Dict[str, Any],
        position_encoding_config: Optional[Dict[str, Any]] = None,
        residual_norm_style: str = "post",
        normalization: NormalizationType = NormalizationType.LayerNorm,
        use_triton: bool = True,
        simplicial_embeddings: Optional[Dict[str, Any]] = None,
        patch_embedding_config: Optional[Dict[str, Any]] = None,
        **kwargs,
    ):
        # Convenience, fill in duplicated fields
        try:
            if "dim_model" not in multi_head_config.keys():
                multi_head_config["dim_model"] = dim_model

            if "dim_model" not in feedforward_config.keys():
                feedforward_config["dim_model"] = dim_model

            if (
                position_encoding_config is not None
                and "dim_model" not in position_encoding_config.keys()
            ):
                position_encoding_config["dim_model"] = dim_model

            if (
                patch_embedding_config is not None
                and "out_channels" not in patch_embedding_config.keys()
            ):
                patch_embedding_config["out_channels"] = dim_model

        except AttributeError:
            # A config instance was passed in, this is fine
            pass

        if "block_type" in kwargs:
            assert kwargs["block_type"] == "encoder"
        kwargs["block_type"] = BlockType("encoder")

        super().__init__(
            dim_model=dim_model,
            feedforward_config=feedforward_config,
            position_encoding_config=position_encoding_config,
            residual_norm_style=ResidualNormStyle(residual_norm_style),
            normalization=NormalizationType(normalization),
            **kwargs,
        )

        self.multi_head_config = multi_head_config
        self.use_triton = use_triton
        self.simplicial_embeddings = simplicial_embeddings
        self.patch_embedding_config = patch_embedding_config


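# A hedged usage sketch (the attention, feedforward and position embedding
# names plus their sub-fields are assumptions about what the registries
# expose, not guaranteed by this module):
#
#   encoder_config = xFormerEncoderConfig(
#       dim_model=384,
#       position_encoding_config={"name": "vocab", "seq_len": 1024, "vocab_size": 64},
#       multi_head_config={
#           "num_heads": 6,
#           "residual_dropout": 0.1,
#           "attention": {"name": "scaled_dot_product", "dropout": 0.1, "causal": False},
#       },
#       feedforward_config={
#           "name": "MLP",
#           "dropout": 0.1,
#           "activation": "gelu",
#           "hidden_layer_multiplier": 4,
#       },
#   )
#
# "dim_model" is propagated into the sub-configs by the constructor, so the
# nested dicts only need their own specific fields.

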
@dataclass(init=False)
class xFormerDecoderConfig(xFormerBlockConfig):
    """
    The configuration structure for a decoder block.

    This specifically defines the masked and cross attention mechanisms,
    on top of the settings defining all blocks.
    """

    multi_head_config_masked: Dict[str, Any]  # prior to encoder output
    multi_head_config_cross: Dict[str, Any]  # cross attention, takes encoder output

    def __init__(
        self,
        dim_model: int,
        feedforward_config: Dict[str, Any],
        multi_head_config_masked: Dict[str, Any],
        multi_head_config_cross: Dict[str, Any],
        position_encoding_config: Optional[Dict[str, Any]] = None,
        residual_norm_style: str = "post",
        normalization: NormalizationType = NormalizationType.LayerNorm,
        use_triton: bool = True,
        **kwargs,
    ):
        # Convenience, fill in duplicated fields
        try:
            if "dim_model" not in multi_head_config_masked.keys():
                multi_head_config_masked["dim_model"] = dim_model

            if "dim_model" not in multi_head_config_cross.keys():
                multi_head_config_cross["dim_model"] = dim_model

            if "dim_model" not in feedforward_config.keys():
                feedforward_config["dim_model"] = dim_model

            if (
                position_encoding_config is not None
                and "dim_model" not in position_encoding_config.keys()
            ):
                position_encoding_config["dim_model"] = dim_model

        except AttributeError:
            # A config instance was passed in, this is fine
            pass

        if "block_type" in kwargs:
            assert kwargs["block_type"] == "decoder"
        kwargs["block_type"] = BlockType("decoder")

        super().__init__(
            dim_model=dim_model,
            feedforward_config=feedforward_config,
            position_encoding_config=position_encoding_config,
            residual_norm_style=ResidualNormStyle(residual_norm_style),
            normalization=NormalizationType(normalization),
            **kwargs,
        )

        self.multi_head_config_masked = multi_head_config_masked
        self.multi_head_config_cross = multi_head_config_cross
        self.use_triton = use_triton
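

# A hedged usage sketch mirroring the encoder one (same caveats: the
# registered names and sub-fields are assumptions and may vary by version):
#
#   decoder_config = xFormerDecoderConfig(
#       dim_model=384,
#       multi_head_config_masked={
#           "num_heads": 6,
#           "residual_dropout": 0.1,
#           "attention": {"name": "scaled_dot_product", "dropout": 0.1, "causal": True},
#       },
#       multi_head_config_cross={
#           "num_heads": 6,
#           "residual_dropout": 0.1,
#           "attention": {"name": "scaled_dot_product", "dropout": 0.1, "causal": False},
#       },
#       feedforward_config={
#           "name": "MLP",
#           "dropout": 0.1,
#           "activation": "gelu",
#           "hidden_layer_multiplier": 4,
#       },
#       residual_norm_style="pre",
#   )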