# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
from dataclasses import dataclass
from enum import Enum
from typing import Any, Dict, Optional

from xformers.components import NormalizationType, ResidualNormStyle
from xformers.components.feedforward import FEEDFORWARD_REGISTRY, FeedforwardConfig
from xformers.components.positional_embedding import (
    POSITION_EMBEDDING_REGISTRY,
    PositionEmbeddingConfig,
)
from xformers.utils import generate_matching_config


class LayerPositionBitmask(int, Enum):
    First = 0b01
    Last = 0b10
    Default = 0b11


class LayerPosition:
    """Bitmask to mark this layer as first, last, nothing or both"""

    def __init__(self):
        self.bitmask = LayerPositionBitmask.Default

    def is_first(self):
        return bool(self.bitmask & LayerPositionBitmask.First)

    def is_last(self):
        return bool(self.bitmask & LayerPositionBitmask.Last)

    def mark_not_first(self):
        self.bitmask &= ~LayerPositionBitmask.First

    def mark_not_last(self):
        self.bitmask &= ~LayerPositionBitmask.Last
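
# Usage sketch (illustrative, not part of this file's API contract): a fresh
# LayerPosition answers True to both is_first() and is_last(); a stack builder
# would typically clear one bit for every layer that is not at that end, e.g.
#
#   position = LayerPosition()
#   position.mark_not_last()  # e.g. the first layer of a multi-layer stack
#   assert position.is_first() and not position.is_last()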


class BlockType(str, Enum):
    Encoder = "encoder"
    Decoder = "decoder"


@dataclass(init=False)  # handle constructors explicitly to force type changes
class xFormerBlockConfig:
    """
    The configuration structure to define a Transformer block.
    This base class is applicable to both encoder and decoder definitions.

    This completely defines each of the blocks, for instance in terms of dimensions,
    position encoding, pre or post layer norms or reversibility.
    """

    dim_model: int
    feedforward_config: FeedforwardConfig
    position_encoding_config: Optional[PositionEmbeddingConfig]
    block_type: BlockType
    residual_norm_style: ResidualNormStyle
    normalization: NormalizationType
    layer_position: LayerPosition
    use_triton: bool
    reversible: bool
    num_layers: int

    def __init__(
        self,
        dim_model: int,
        feedforward_config: Dict[str, Any],
        position_encoding_config: Optional[Dict[str, Any]],
        block_type: BlockType,
        residual_norm_style: ResidualNormStyle = ResidualNormStyle("post"),
        normalization: NormalizationType = NormalizationType.LayerNorm,
        reversible: bool = False,
        num_layers: int = 1,
        layer_position: Optional[LayerPosition] = None,
    ):
        self.dim_model = dim_model
        self.block_type = block_type
        self.residual_norm_style = residual_norm_style
        self.reversible = reversible
        self.num_layers = num_layers
        self.normalization = normalization

        # Fill in possible gaps in the config for subparts of the block
        self.feedforward_config = generate_matching_config(
            feedforward_config,
            FEEDFORWARD_REGISTRY[feedforward_config["name"]].config,
        )

        self.position_encoding_config = (
            generate_matching_config(
                position_encoding_config,
                POSITION_EMBEDDING_REGISTRY[position_encoding_config["name"]].config,
            )
            if position_encoding_config is not None
            else None
        )

        # Default is that this layer is the only one, so both first and last
        if layer_position:
            self.layer_position = layer_position
        else:
            self.layer_position = LayerPosition()
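
# Note on the gap-filling above, as a sketch: assuming an "MLP" feedforward is
# registered (the name and its exact fields are assumptions about the component
# registry, not something this file defines), a partial dict such as
# {"name": "MLP", "dim_model": 384} is matched by generate_matching_config
# against FEEDFORWARD_REGISTRY["MLP"].config, and the missing fields (dropout,
# activation, ...) are filled with that config's defaults, yielding a fully
# populated FeedforwardConfig instance.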


@dataclass(init=False)
class xFormerEncoderConfig(xFormerBlockConfig):
    """
    The configuration structure for an encoder block
    """

    multi_head_config: Dict[str, Any]
    use_triton: bool
    simplicial_embeddings: Optional[Dict[str, Any]]
    patch_embedding_config: Optional[Dict[str, Any]]

    def __init__(
        self,
        dim_model: int,
        feedforward_config: Dict[str, Any],
        multi_head_config: Dict[str, Any],
        position_encoding_config: Optional[Dict[str, Any]] = None,
        residual_norm_style: str = "post",
        normalization: NormalizationType = NormalizationType.LayerNorm,
        use_triton: bool = True,
        simplicial_embeddings: Optional[Dict[str, Any]] = None,
        patch_embedding_config: Optional[Dict[str, Any]] = None,
        **kwargs,
    ):
        # Convenience, fill in duplicated fields
        try:
            if "dim_model" not in multi_head_config.keys():
                multi_head_config["dim_model"] = dim_model

            if "dim_model" not in feedforward_config.keys():
                feedforward_config["dim_model"] = dim_model

            if (
                position_encoding_config is not None
                and "dim_model" not in position_encoding_config.keys()
            ):
                position_encoding_config["dim_model"] = dim_model

            if (
                patch_embedding_config is not None
                and "out_channels" not in patch_embedding_config.keys()
            ):
                patch_embedding_config["out_channels"] = dim_model

        except AttributeError:
            # A config instance was passed in, this is fine
            pass

        if "block_type" in kwargs:
            assert kwargs["block_type"] == "encoder"
        kwargs["block_type"] = BlockType("encoder")

        super().__init__(
            dim_model=dim_model,
            feedforward_config=feedforward_config,
            position_encoding_config=position_encoding_config,
            residual_norm_style=ResidualNormStyle(residual_norm_style),
            normalization=NormalizationType(normalization),
            **kwargs,
        )

        self.multi_head_config = multi_head_config
        self.use_triton = use_triton
        self.simplicial_embeddings = simplicial_embeddings
        self.patch_embedding_config = patch_embedding_config
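
# Minimal construction sketch (illustrative only; the registry names "MLP",
# "scaled_dot_product" and "sine", and the exact sub-config keys, are assumptions
# about the component registries rather than something this file defines):
#
#   encoder_cfg = xFormerEncoderConfig(
#       dim_model=384,
#       feedforward_config={"name": "MLP", "activation": "gelu", "dropout": 0.1, "hidden_layer_multiplier": 4},
#       multi_head_config={"num_heads": 6, "attention": {"name": "scaled_dot_product"}},
#       position_encoding_config={"name": "sine", "seq_len": 512},
#       residual_norm_style="pre",
#       num_layers=6,
#   )
#
# "dim_model" only needs to be given once; the constructor copies it into each
# sub-config that does not already set it.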


@dataclass(init=False)
class xFormerDecoderConfig(xFormerBlockConfig):
    """
    The configuration structure for a decoder block.

    This specifically defines the masked and cross attention mechanisms,
    on top of the settings defining all blocks.
    """

    multi_head_config_masked: Dict[str, Any]  # prior to encoder output
    multi_head_config_cross: Dict[str, Any]  # cross attention, takes encoder output

    def __init__(
        self,
        dim_model: int,
        feedforward_config: Dict[str, Any],
        multi_head_config_masked: Dict[str, Any],
        multi_head_config_cross: Dict[str, Any],
        position_encoding_config: Optional[Dict[str, Any]] = None,
        residual_norm_style: str = "post",
        normalization: NormalizationType = NormalizationType.LayerNorm,
        use_triton: bool = True,
        **kwargs,
    ):
        # Convenience, fill in duplicated fields
        try:
            if "dim_model" not in multi_head_config_masked.keys():
                multi_head_config_masked["dim_model"] = dim_model

            if "dim_model" not in multi_head_config_cross.keys():
                multi_head_config_cross["dim_model"] = dim_model

            if "dim_model" not in feedforward_config.keys():
                feedforward_config["dim_model"] = dim_model

            if (
                position_encoding_config is not None
                and "dim_model" not in position_encoding_config.keys()
            ):
                position_encoding_config["dim_model"] = dim_model

        except AttributeError:
            # A config instance was passed in, this is fine
            pass

        if "block_type" in kwargs.keys():
            assert kwargs["block_type"] == "decoder"
        kwargs["block_type"] = BlockType("decoder")

        super().__init__(
            dim_model=dim_model,
            feedforward_config=feedforward_config,
            position_encoding_config=position_encoding_config,
            residual_norm_style=ResidualNormStyle(residual_norm_style),
            normalization=NormalizationType(normalization),
            **kwargs,
        )

        self.multi_head_config_masked = multi_head_config_masked
        self.multi_head_config_cross = multi_head_config_cross
        self.use_triton = use_triton
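
# Minimal construction sketch for a decoder block (illustrative only; component
# names and sub-config keys are assumptions, as in the encoder example above):
#
#   decoder_cfg = xFormerDecoderConfig(
#       dim_model=384,
#       feedforward_config={"name": "MLP", "activation": "gelu", "dropout": 0.1, "hidden_layer_multiplier": 4},
#       multi_head_config_masked={"num_heads": 6, "attention": {"name": "scaled_dot_product", "causal": True}},
#       multi_head_config_cross={"num_heads": 6, "attention": {"name": "scaled_dot_product"}},
#       position_encoding_config={"name": "sine", "seq_len": 512},
#       residual_norm_style="post",
#       num_layers=6,
#   )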