First commit

This commit is contained in:
2025-08-05 19:02:46 +08:00
parent 9efe891f99
commit 99fb9f5cb0
1412 changed files with 203615 additions and 0 deletions

View File

@@ -0,0 +1,7 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
from .timm_sparse_attention import TimmSparseAttention # noqa

View File

@@ -0,0 +1,124 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
import copy
from dataclasses import dataclass
from typing import Any, Dict, List, Optional
from xformers._deprecation_warning import deprecated_function
from xformers.components.residual import ResidualNormStyle
@dataclass
class BasicLayerConfig:
embedding: int
attention_mechanism: str
patch_size: int
stride: int
padding: int
seq_len: int
feedforward: str
normalization: str = "layernorm"
repeat_layer: int = 1
def get_hierarchical_configuration(
layer_base_configs: List[BasicLayerConfig],
residual_norm_style: ResidualNormStyle = ResidualNormStyle.Pre,
use_rotary_embeddings: bool = True,
mlp_multiplier: int = 4,
in_channels: int = 3,
dim_head: Optional[int] = None,
):
"""
A small helper to generate hierarchical xformers configurations,
which correspond for instance to poolformer or swin architectures.
Contrary to more "classical" Transformer architectures, which conserve the sequence/context
length across layers, hierarchical Transformers trade the sequence length for the embedding dimension
"""
deprecated_function(get_hierarchical_configuration)
base_config: Dict[str, Any] = {
"block_type": "encoder",
"dim_model": 0,
"use_triton": False,
"residual_norm_style": str(residual_norm_style),
"multi_head_config": {
"num_heads": 1,
"use_rotary_embeddings": use_rotary_embeddings,
"attention": {
"name": "TBD",
},
},
"feedforward_config": {
"name": "TBD",
"activation": "gelu",
"hidden_layer_multiplier": mlp_multiplier,
"dropout": 0.0,
},
"position_encoding_config": {
"name": "learnable",
"seq_len": 0,
"add_class_token": False,
},
"patch_embedding_config": {
"in_channels": in_channels,
"kernel_size": 0,
"stride": 0,
"padding": 0,
},
}
xformers_config = []
in_channels = in_channels
for layer_base_config in layer_base_configs:
lc = copy.deepcopy(base_config)
lc["normalization"] = layer_base_config.normalization
# Fill in the changing model dimensions
lc["dim_model"] = layer_base_config.embedding
# Update the patches
lc["patch_embedding_config"] = {
"in_channels": in_channels,
"kernel_size": layer_base_config.patch_size,
"stride": layer_base_config.stride,
"padding": layer_base_config.padding,
}
# Update the number of channels for the next layer
in_channels = lc["dim_model"] * 1
lc["position_encoding_config"]["seq_len"] = layer_base_config.seq_len
# Fill in the number of heads (defaults to 1)
if dim_head is not None:
lc["multi_head_config"]["num_heads"] = (
layer_base_config.embedding // dim_head
)
assert layer_base_config.embedding % dim_head == 0
# Fill in the attention mechanism
lc["multi_head_config"]["attention"][
"name"
] = layer_base_config.attention_mechanism
# FIll in the feedforward
lc["feedforward_config"]["name"] = layer_base_config.feedforward
print(lc)
xformers_config.append(lc)
# Handle repeated layers (without the patch embeddings)
if layer_base_config.repeat_layer > 1:
lc_repeat = copy.deepcopy(lc)
lc_repeat.pop("patch_embedding_config")
xformers_config += [lc_repeat] * (layer_base_config.repeat_layer - 1)
return xformers_config

View File

@@ -0,0 +1,32 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
import sys
import tempfile
import torch
is_windows = False
if sys.platform == "win32": # pytorch on windows uses gloo not ncll
is_windows = True
def init_torch_distributed_local():
if torch.distributed.is_initialized():
return
init_url = "file://" + tempfile.mkstemp()[1]
backend = (
torch.distributed.Backend.NCCL
if torch.cuda.is_available() and not is_windows
else torch.distributed.Backend.GLOO
)
torch.distributed.init_process_group(
backend=backend,
rank=0,
world_size=1,
init_method=init_url,
)

View File

@@ -0,0 +1,55 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
import torch
from xformers.components.attention.core import scaled_dot_product_attention
class TimmSparseAttention(torch.nn.Module):
"""
Almost drop-in replacement for timm attention
but using the sparsity-aware scaled_dot_product_attention from xformers
"""
def __init__(
self,
dim,
num_heads=8,
qkv_bias=False,
attn_drop=0.0,
proj_drop=0.0,
attn_mask=None,
):
super().__init__()
self.num_heads = num_heads
self.qkv = torch.nn.Linear(dim, dim * 3, bias=qkv_bias)
self.attn_drop = torch.nn.Dropout(attn_drop)
self.proj = torch.nn.Linear(dim, dim)
self.proj_drop = torch.nn.Dropout(proj_drop)
self.attn_mask = attn_mask
def forward(self, x):
B, N, C = x.shape
qkv = (
self.qkv(x)
.reshape(B, N, 3, self.num_heads, C // self.num_heads)
.permute(2, 0, 3, 1, 4)
)
qkv = qkv.flatten(1, 2)
q, k, v = qkv.unbind()
x = scaled_dot_product_attention(
q, k, v, self.attn_mask, dropout=self.attn_drop
)
x = x.reshape(B, self.num_heads, N, C // self.num_heads)
x = x.transpose(1, 2).reshape(B, N, C)
x = self.proj(x)
x = self.proj_drop(x)
return x