Files
enginex-bi_series-vllm/pkgs/xformers/helpers/hierarchical_configs.py
2025-08-05 19:02:46 +08:00

125 lines
3.8 KiB
Python

# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
import copy
from dataclasses import dataclass
from typing import Any, Dict, List, Optional
from xformers._deprecation_warning import deprecated_function
from xformers.components.residual import ResidualNormStyle
@dataclass
class BasicLayerConfig:
    """Per-stage settings consumed by ``get_hierarchical_configuration``.

    Each instance describes one stage of a hierarchical (poolformer/swin-like)
    transformer: its embedding width, attention/feedforward choices, and the
    conv-style patch-embedding parameters used to downsample between stages.
    """

    embedding: int  # model dimension (dim_model) for this stage
    attention_mechanism: str  # registered attention name, e.g. "scaled_dot_product"
    patch_size: int  # patch-embedding kernel size
    stride: int  # patch-embedding stride (controls sequence-length reduction)
    padding: int  # patch-embedding padding
    seq_len: int  # sequence length at this stage (feeds position encoding)
    feedforward: str  # registered feedforward name, e.g. "MLP"
    normalization: str = "layernorm"  # normalization layer name
    repeat_layer: int = 1  # number of times this layer repeats (repeats skip patch embedding)
def get_hierarchical_configuration(
    layer_base_configs: List[BasicLayerConfig],
    residual_norm_style: ResidualNormStyle = ResidualNormStyle.Pre,
    use_rotary_embeddings: bool = True,
    mlp_multiplier: int = 4,
    in_channels: int = 3,
    dim_head: Optional[int] = None,
) -> List[Dict[str, Any]]:
    """
    A small helper to generate hierarchical xformers configurations,
    which correspond for instance to poolformer or swin architectures.

    Contrary to more "classical" Transformer architectures, which conserve the
    sequence/context length across layers, hierarchical Transformers trade the
    sequence length for the embedding dimension.

    Args:
        layer_base_configs: one ``BasicLayerConfig`` per stage, in order.
        residual_norm_style: pre/post/deepnorm residual path, applied to all layers.
        use_rotary_embeddings: whether attention uses rotary position embeddings.
        mlp_multiplier: hidden-layer multiplier for every feedforward block.
        in_channels: number of input channels for the first patch embedding
            (subsequent stages take the previous stage's embedding dim).
        dim_head: if set, each stage gets ``embedding // dim_head`` heads
            (embedding must be divisible by dim_head); otherwise 1 head.

    Returns:
        A list of per-layer config dicts consumable by the xformers factory.
        Repeated layers (``repeat_layer > 1``) omit ``patch_embedding_config``.
    """
    deprecated_function(get_hierarchical_configuration)

    # Template shared by all layers; the "TBD"/0 placeholders are filled per stage.
    base_config: Dict[str, Any] = {
        "block_type": "encoder",
        "dim_model": 0,
        "use_triton": False,
        "residual_norm_style": str(residual_norm_style),
        "multi_head_config": {
            "num_heads": 1,
            "use_rotary_embeddings": use_rotary_embeddings,
            "attention": {
                "name": "TBD",
            },
        },
        "feedforward_config": {
            "name": "TBD",
            "activation": "gelu",
            "hidden_layer_multiplier": mlp_multiplier,
            "dropout": 0.0,
        },
        "position_encoding_config": {
            "name": "learnable",
            "seq_len": 0,
            "add_class_token": False,
        },
        "patch_embedding_config": {
            "in_channels": in_channels,
            "kernel_size": 0,
            "stride": 0,
            "padding": 0,
        },
    }

    xformers_config: List[Dict[str, Any]] = []

    for layer_base_config in layer_base_configs:
        lc = copy.deepcopy(base_config)
        lc["normalization"] = layer_base_config.normalization

        # Fill in the changing model dimensions
        lc["dim_model"] = layer_base_config.embedding

        # Update the patches
        lc["patch_embedding_config"] = {
            "in_channels": in_channels,
            "kernel_size": layer_base_config.patch_size,
            "stride": layer_base_config.stride,
            "padding": layer_base_config.padding,
        }

        # The next stage's patch embedding consumes this stage's output channels
        in_channels = layer_base_config.embedding

        lc["position_encoding_config"]["seq_len"] = layer_base_config.seq_len

        # Fill in the number of heads (defaults to 1)
        if dim_head is not None:
            assert layer_base_config.embedding % dim_head == 0
            lc["multi_head_config"]["num_heads"] = (
                layer_base_config.embedding // dim_head
            )

        # Fill in the attention mechanism
        lc["multi_head_config"]["attention"][
            "name"
        ] = layer_base_config.attention_mechanism

        # Fill in the feedforward
        lc["feedforward_config"]["name"] = layer_base_config.feedforward

        xformers_config.append(lc)

        # Handle repeated layers (without the patch embeddings).
        # Deep-copy each repeat so the returned dicts do not alias one another —
        # `[lc_repeat] * n` would append the same object n times, and mutating
        # one repeated layer would silently mutate them all.
        for _ in range(layer_base_config.repeat_layer - 1):
            lc_repeat = copy.deepcopy(lc)
            lc_repeat.pop("patch_embedding_config")
            xformers_config.append(lc_repeat)

    return xformers_config