First commit
This commit is contained in:
237
pkgs/xformers/factory/block_configs.py
Normal file
237
pkgs/xformers/factory/block_configs.py
Normal file
@@ -0,0 +1,237 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the BSD license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from xformers.components import NormalizationType, ResidualNormStyle
|
||||
from xformers.components.feedforward import FEEDFORWARD_REGISTRY, FeedforwardConfig
|
||||
from xformers.components.positional_embedding import (
|
||||
POSITION_EMBEDDING_REGISTRY,
|
||||
PositionEmbeddingConfig,
|
||||
)
|
||||
from xformers.utils import generate_matching_config
|
||||
|
||||
|
||||
class LayerPositionBitmask(int, Enum):
    # Bit flags used by LayerPosition to track whether a layer is the
    # first and/or the last one of a stack. Combined with & / ~ below.
    First = 0b01
    Last = 0b10
    Default = 0b11  # both bits set: a single layer is both first and last
|
||||
|
||||
|
||||
class LayerPosition:
    """Track whether a layer sits first and/or last in a stack.

    Backed by a two-bit mask (see ``LayerPositionBitmask``): a freshly
    built instance is marked as both first and last (a one-layer stack),
    and the ``mark_not_*`` methods clear the corresponding bit.
    """

    def __init__(self):
        # Default: single-layer stack, so both "first" and "last" are set.
        self.bitmask = LayerPositionBitmask.Default

    def is_first(self):
        # True while the "first" bit has not been cleared.
        return (self.bitmask & LayerPositionBitmask.First) != 0

    def is_last(self):
        # True while the "last" bit has not been cleared.
        return (self.bitmask & LayerPositionBitmask.Last) != 0

    def mark_not_first(self):
        # Clear the "first" bit, keep the rest of the mask untouched.
        self.bitmask = self.bitmask & ~LayerPositionBitmask.First

    def mark_not_last(self):
        # Clear the "last" bit, keep the rest of the mask untouched.
        self.bitmask = self.bitmask & ~LayerPositionBitmask.Last
|
||||
|
||||
|
||||
class BlockType(str, Enum):
    # Which side of the transformer a block belongs to; the str mixin lets
    # the members compare equal to the plain strings used in dict configs.
    Encoder = "encoder"
    Decoder = "decoder"
|
||||
|
||||
|
||||
@dataclass(init=False)  # constructors are hand-written so dict configs can be coerced
class xFormerBlockConfig:
    """
    The configuration structure to define a Transformer block.
    This base class is applicable to both encoder and decoder definitions.

    It completely defines each of the blocks: dimensions, position encoding,
    pre/post residual norm style, reversibility and the block's position
    within the stack.
    """

    dim_model: int
    feedforward_config: FeedforwardConfig
    position_encoding_config: Optional[PositionEmbeddingConfig]
    block_type: BlockType
    residual_norm_style: ResidualNormStyle
    normalization: NormalizationType
    layer_position: LayerPosition
    use_triton: bool
    reversible: bool
    num_layers: int

    def __init__(
        self,
        dim_model: int,
        feedforward_config: Dict[str, Any],
        position_encoding_config: Optional[Dict[str, Any]],
        block_type: BlockType,
        residual_norm_style: ResidualNormStyle = ResidualNormStyle("post"),
        normalization: NormalizationType = NormalizationType.LayerNorm,
        reversible: bool = False,
        num_layers: int = 1,
        layer_position: Optional[LayerPosition] = None,
    ):
        # Scalar settings are stored verbatim.
        self.dim_model = dim_model
        self.block_type = block_type
        self.residual_norm_style = residual_norm_style
        self.reversible = reversible
        self.num_layers = num_layers
        self.normalization = normalization

        # Sub-configs may come in as partial dicts: complete any missing
        # fields against the reference config registered under their "name".
        self.feedforward_config = generate_matching_config(
            feedforward_config,
            FEEDFORWARD_REGISTRY[feedforward_config["name"]].config,
        )

        if position_encoding_config is None:
            self.position_encoding_config = None
        else:
            self.position_encoding_config = generate_matching_config(
                position_encoding_config,
                POSITION_EMBEDDING_REGISTRY[position_encoding_config["name"]].config,
            )

        # Unless stated otherwise, assume a one-layer stack: the layer is
        # both first and last.
        self.layer_position = layer_position if layer_position else LayerPosition()
|
||||
|
||||
|
||||
@dataclass(init=False)
class xFormerEncoderConfig(xFormerBlockConfig):
    """
    The configuration structure for an encoder block.

    On top of the generic block settings, this defines the self-attention
    configuration plus optional simplicial embeddings and an optional patch
    embedding front-end.
    """

    multi_head_config: Dict[str, Any]
    use_triton: bool
    simplicial_embeddings: Optional[Dict[str, Any]]
    patch_embedding_config: Optional[Dict[str, Any]]

    def __init__(
        self,
        dim_model: int,
        feedforward_config: Dict[str, Any],
        multi_head_config: Dict[str, Any],
        position_encoding_config: Optional[Dict[str, Any]] = None,
        residual_norm_style: str = "post",
        normalization: NormalizationType = NormalizationType.LayerNorm,
        use_triton: bool = True,
        simplicial_embeddings: Optional[Dict[str, Any]] = None,
        patch_embedding_config: Optional[Dict[str, Any]] = None,
        **kwargs,
    ):
        # Convenience, fill in duplicated fields: propagate dim_model into
        # the dict sub-configs when they do not set it themselves.
        try:
            if "dim_model" not in multi_head_config.keys():
                multi_head_config["dim_model"] = dim_model

            if "dim_model" not in feedforward_config.keys():
                feedforward_config["dim_model"] = dim_model

            if (
                position_encoding_config is not None
                and "dim_model" not in position_encoding_config.keys()
            ):
                position_encoding_config["dim_model"] = dim_model

            if (
                patch_embedding_config is not None
                and "out_channels" not in patch_embedding_config.keys()
            ):
                patch_embedding_config["out_channels"] = dim_model

        except AttributeError:
            # A config instance was passed in, this is fine
            pass

        # If a block_type was passed in explicitly it must match. Raise a
        # real exception rather than assert so the check survives `python -O`.
        if "block_type" in kwargs.keys():
            if kwargs["block_type"] != "encoder":
                raise ValueError(
                    f"Invalid block_type {kwargs['block_type']!r} for an encoder config"
                )
        kwargs["block_type"] = BlockType("encoder")

        super().__init__(
            dim_model=dim_model,
            feedforward_config=feedforward_config,
            position_encoding_config=position_encoding_config,
            residual_norm_style=ResidualNormStyle(residual_norm_style),
            normalization=NormalizationType(normalization),
            **kwargs,
        )

        # Encoder-specific settings, stored after the shared base init.
        self.multi_head_config = multi_head_config
        self.use_triton = use_triton
        self.simplicial_embeddings = simplicial_embeddings
        self.patch_embedding_config = patch_embedding_config
|
||||
|
||||
|
||||
@dataclass(init=False)
class xFormerDecoderConfig(xFormerBlockConfig):
    """
    The configuration structure for a decoder block.

    This specifically defines the masked and cross attention mechanisms,
    on top of the settings defining all blocks.
    """

    multi_head_config_masked: Dict[str, Any]  # prior to encoder output
    multi_head_config_cross: Dict[str, Any]  # cross attention, takes encoder output

    def __init__(
        self,
        dim_model: int,
        feedforward_config: Dict[str, Any],
        multi_head_config_masked: Dict[str, Any],
        multi_head_config_cross: Dict[str, Any],
        position_encoding_config: Optional[Dict[str, Any]] = None,
        residual_norm_style: str = "post",
        normalization: NormalizationType = NormalizationType.LayerNorm,
        use_triton: bool = True,
        **kwargs,
    ):
        # Convenience, fill in duplicated fields: propagate dim_model into
        # the dict sub-configs when they do not set it themselves.
        try:
            if "dim_model" not in multi_head_config_masked.keys():
                multi_head_config_masked["dim_model"] = dim_model

            if "dim_model" not in multi_head_config_cross.keys():
                multi_head_config_cross["dim_model"] = dim_model

            if "dim_model" not in feedforward_config.keys():
                feedforward_config["dim_model"] = dim_model

            if (
                position_encoding_config is not None
                and "dim_model" not in position_encoding_config.keys()
            ):
                position_encoding_config["dim_model"] = dim_model

        except AttributeError:
            # A config instance was passed in, this is fine
            pass

        # If a block_type was passed in explicitly it must match. Raise a
        # real exception rather than assert so the check survives `python -O`.
        if "block_type" in kwargs.keys():
            if kwargs["block_type"] != "decoder":
                raise ValueError(
                    f"Invalid block_type {kwargs['block_type']!r} for a decoder config"
                )
        kwargs["block_type"] = BlockType("decoder")

        super().__init__(
            dim_model=dim_model,
            feedforward_config=feedforward_config,
            position_encoding_config=position_encoding_config,
            residual_norm_style=ResidualNormStyle(residual_norm_style),
            normalization=NormalizationType(normalization),
            **kwargs,
        )

        # Decoder-specific settings, stored after the shared base init.
        self.multi_head_config_masked = multi_head_config_masked
        self.multi_head_config_cross = multi_head_config_cross
        self.use_triton = use_triton
|
||||
self.use_triton = use_triton
|
||||
Reference in New Issue
Block a user