First commit
This commit is contained in:
7
pkgs/xformers/helpers/__init__.py
Normal file
7
pkgs/xformers/helpers/__init__.py
Normal file
@@ -0,0 +1,7 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the BSD license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
|
||||
from .timm_sparse_attention import TimmSparseAttention # noqa
|
||||
BIN
pkgs/xformers/helpers/__pycache__/__init__.cpython-310.pyc
Normal file
BIN
pkgs/xformers/helpers/__pycache__/__init__.cpython-310.pyc
Normal file
Binary file not shown.
Binary file not shown.
BIN
pkgs/xformers/helpers/__pycache__/test_utils.cpython-310.pyc
Normal file
BIN
pkgs/xformers/helpers/__pycache__/test_utils.cpython-310.pyc
Normal file
Binary file not shown.
Binary file not shown.
124
pkgs/xformers/helpers/hierarchical_configs.py
Normal file
124
pkgs/xformers/helpers/hierarchical_configs.py
Normal file
@@ -0,0 +1,124 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the BSD license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
|
||||
import copy
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from xformers._deprecation_warning import deprecated_function
|
||||
from xformers.components.residual import ResidualNormStyle
|
||||
|
||||
|
||||
@dataclass
class BasicLayerConfig:
    """
    Per-stage settings consumed by `get_hierarchical_configuration`.

    One instance describes one stage of the hierarchy; the helper turns it
    into one block configuration (plus optional repeats of that block).
    """

    # Copied into the block's "dim_model"; also becomes the "in_channels"
    # of the next stage's patch embedding.
    embedding: int
    # Name written to multi_head_config["attention"]["name"].
    attention_mechanism: str
    # Written to patch_embedding_config["kernel_size"].
    patch_size: int
    # Written to patch_embedding_config["stride"].
    stride: int
    # Written to patch_embedding_config["padding"].
    padding: int
    # Written to position_encoding_config["seq_len"].
    seq_len: int
    # Name written to feedforward_config["name"].
    feedforward: str
    # Written to the block's "normalization" entry.
    normalization: str = "layernorm"
    # Total number of times the block appears; copies after the first one
    # are emitted without a patch embedding.
    repeat_layer: int = 1
|
||||
|
||||
|
||||
def get_hierarchical_configuration(
    layer_base_configs: List[BasicLayerConfig],
    residual_norm_style: ResidualNormStyle = ResidualNormStyle.Pre,
    use_rotary_embeddings: bool = True,
    mlp_multiplier: int = 4,
    in_channels: int = 3,
    dim_head: Optional[int] = None,
) -> List[Dict[str, Any]]:
    """
    A small helper to generate hierarchical xformers configurations,
    which correspond for instance to poolformer or swin architectures.

    Contrary to more "classical" Transformer architectures, which conserve the sequence/context
    length across layers, hierarchical Transformers trade the sequence length for the embedding dimension.

    Args:
        layer_base_configs: one `BasicLayerConfig` per stage, in order.
        residual_norm_style: stored (as `str(...)`) under "residual_norm_style"
            in every block.
        use_rotary_embeddings: stored in every block's "multi_head_config".
        mlp_multiplier: stored as the feedforward "hidden_layer_multiplier".
        in_channels: input channels of the first stage's patch embedding;
            later stages use the previous stage's embedding size.
        dim_head: when given, the number of heads of each stage is
            `embedding // dim_head`; otherwise every stage uses 1 head.

    Returns:
        A list of block configuration dicts. Every entry is an independent
        object, so callers may safely mutate one without affecting the others.

    Raises:
        AssertionError: if `dim_head` is given and does not evenly divide a
            stage's embedding size.
    """
    # NOTE(review): presumably emits a deprecation warning; the helper is
    # defined in xformers._deprecation_warning - confirm there.
    deprecated_function(get_hierarchical_configuration)

    # Template shared by all stages. "TBD" and 0 are placeholders that are
    # overwritten per stage in the loop below.
    base_config: Dict[str, Any] = {
        "block_type": "encoder",
        "dim_model": 0,
        "use_triton": False,
        "residual_norm_style": str(residual_norm_style),
        "multi_head_config": {
            "num_heads": 1,
            "use_rotary_embeddings": use_rotary_embeddings,
            "attention": {
                "name": "TBD",
            },
        },
        "feedforward_config": {
            "name": "TBD",
            "activation": "gelu",
            "hidden_layer_multiplier": mlp_multiplier,
            "dropout": 0.0,
        },
        "position_encoding_config": {
            "name": "learnable",
            "seq_len": 0,
            "add_class_token": False,
        },
        "patch_embedding_config": {
            "in_channels": in_channels,
            "kernel_size": 0,
            "stride": 0,
            "padding": 0,
        },
    }

    xformers_config: List[Dict[str, Any]] = []

    for layer_base_config in layer_base_configs:
        # Deep copy so that the nested dicts are not shared between stages
        lc = copy.deepcopy(base_config)

        lc["normalization"] = layer_base_config.normalization

        # Fill in the changing model dimensions
        lc["dim_model"] = layer_base_config.embedding

        # Update the patches
        lc["patch_embedding_config"] = {
            "in_channels": in_channels,
            "kernel_size": layer_base_config.patch_size,
            "stride": layer_base_config.stride,
            "padding": layer_base_config.padding,
        }

        # Update the number of channels for the next layer
        in_channels = lc["dim_model"]

        lc["position_encoding_config"]["seq_len"] = layer_base_config.seq_len

        # Fill in the number of heads (defaults to 1)
        if dim_head is not None:
            assert layer_base_config.embedding % dim_head == 0, (
                f"embedding ({layer_base_config.embedding}) must be divisible "
                f"by dim_head ({dim_head})"
            )
            lc["multi_head_config"]["num_heads"] = (
                layer_base_config.embedding // dim_head
            )

        # Fill in the attention mechanism
        lc["multi_head_config"]["attention"][
            "name"
        ] = layer_base_config.attention_mechanism

        # Fill in the feedforward
        lc["feedforward_config"]["name"] = layer_base_config.feedforward

        xformers_config.append(lc)

        # Handle repeated layers (without the patch embeddings)
        if layer_base_config.repeat_layer > 1:
            lc_repeat = copy.deepcopy(lc)
            lc_repeat.pop("patch_embedding_config")
            # One fresh copy per repeat: `[lc_repeat] * n` would put the very
            # same dict object in the list n times.
            xformers_config.extend(
                copy.deepcopy(lc_repeat)
                for _ in range(layer_base_config.repeat_layer - 1)
            )

    return xformers_config
|
||||
32
pkgs/xformers/helpers/test_utils.py
Normal file
32
pkgs/xformers/helpers/test_utils.py
Normal file
@@ -0,0 +1,32 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the BSD license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
|
||||
import sys
|
||||
import tempfile
|
||||
|
||||
import torch
|
||||
|
||||
# PyTorch on Windows uses gloo, not NCCL, so remember whether we run there.
is_windows = sys.platform == "win32"
|
||||
|
||||
|
||||
def init_torch_distributed_local():
    """
    Initialise a single-process default process group (rank 0, world size 1).

    Does nothing if `torch.distributed` is already initialised. The rendezvous
    goes through a `file://` URL pointing at a freshly created temporary file.
    The backend is NCCL when CUDA is available and the platform is not
    Windows, and GLOO otherwise.
    """
    import os  # local import: only needed to release the mkstemp() descriptor

    if torch.distributed.is_initialized():
        return

    # mkstemp() returns an *open* OS-level descriptor together with the path.
    # Only the path is needed here, so close the descriptor right away instead
    # of leaking it for the lifetime of the process.
    fd, init_path = tempfile.mkstemp()
    os.close(fd)
    init_url = "file://" + init_path

    backend = (
        torch.distributed.Backend.NCCL
        if torch.cuda.is_available() and not is_windows
        else torch.distributed.Backend.GLOO
    )
    torch.distributed.init_process_group(
        backend=backend,
        rank=0,
        world_size=1,
        init_method=init_url,
    )
|
||||
55
pkgs/xformers/helpers/timm_sparse_attention.py
Normal file
55
pkgs/xformers/helpers/timm_sparse_attention.py
Normal file
@@ -0,0 +1,55 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the BSD license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
|
||||
import torch
|
||||
|
||||
from xformers.components.attention.core import scaled_dot_product_attention
|
||||
|
||||
|
||||
class TimmSparseAttention(torch.nn.Module):
    """
    Almost drop-in replacement for timm attention
    but using the sparsity-aware scaled_dot_product_attention from xformers
    """

    def __init__(
        self,
        dim,
        num_heads=8,
        qkv_bias=False,
        attn_drop=0.0,
        proj_drop=0.0,
        attn_mask=None,
    ):
        """
        Build the projections and dropouts.

        Args:
            dim: size of the last dimension of the input and of the output.
            num_heads: number of heads the `dim` channels are split into.
            qkv_bias: whether the fused query/key/value projection has a bias.
            attn_drop: dropout probability of the module handed to the
                attention kernel.
            proj_drop: dropout probability applied after the output projection.
            attn_mask: stored as-is and forwarded to
                `scaled_dot_product_attention` on every call.
        """
        super().__init__()
        self.num_heads = num_heads

        # A single linear layer produces queries, keys and values at once
        self.qkv = torch.nn.Linear(dim, 3 * dim, bias=qkv_bias)
        self.attn_drop = torch.nn.Dropout(attn_drop)
        self.proj = torch.nn.Linear(dim, dim)
        self.proj_drop = torch.nn.Dropout(proj_drop)
        self.attn_mask = attn_mask

    def forward(self, x):
        """
        Apply multi-head self-attention to a 3-dimensional input `x` and
        return a tensor with the same three dimensions.
        """
        batch, seq_len, channels = x.shape
        heads = self.num_heads
        head_dim = channels // heads

        # (batch, seq, 3 * channels) -> (3, batch * heads, seq, head_dim)
        projected = self.qkv(x).reshape(batch, seq_len, 3, heads, head_dim)
        stacked = projected.permute(2, 0, 3, 1, 4).flatten(1, 2)
        q, k, v = stacked.unbind()

        attended = scaled_dot_product_attention(
            q, k, v, self.attn_mask, dropout=self.attn_drop
        )

        # Un-fold the heads from the batch dimension and merge them back
        # into the channel dimension
        per_head = attended.reshape(batch, heads, seq_len, head_dim)
        merged = per_head.transpose(1, 2).reshape(batch, seq_len, channels)

        return self.proj_drop(self.proj(merged))
|
||||
Reference in New Issue
Block a user