First commit

This commit is contained in:
2025-08-05 19:02:46 +08:00
parent 9efe891f99
commit 99fb9f5cb0
1412 changed files with 203615 additions and 0 deletions

View File

@@ -0,0 +1,77 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
from dataclasses import fields
from pathlib import Path
from typing import Any, Dict, Union
from xformers.utils import import_all_modules
from .activations import Activation, build_activation # noqa
from .attention import Attention, build_attention # noqa
from .input_projection import InputProjection, InputProjectionConfig # noqa
from .multi_head_dispatch import MultiHeadDispatch # noqa
from .multi_head_dispatch import MultiHeadDispatchConfig
from .patch_embedding import PatchEmbeddingConfig # noqa
from .patch_embedding import build_patch_embedding # noqa
from .residual import NormalizationType # noqa
from .residual import PostNorm # noqa
from .residual import PreNorm # noqa
from .residual import RequiresWrappedInputs # noqa
from .residual import Residual # noqa
from .residual import ResidualNormStyle # noqa
# automatically import any Python files in the directory
import_all_modules(str(Path(__file__).parent), "xformers.components")
def build_multi_head_attention(
multi_head_config: Union[MultiHeadDispatchConfig, Dict[str, Any]],
):
"""Builds a multihead attention from a config.
This assumes a 'name' key in the config which is used to determine what
attention class to instantiate. For instance, a config `{"name": "my_attention",
"foo": "bar"}` will find a class that was registered as "my_attention"
(see :func:`register_attention`) and call .from_config on it."""
if not isinstance(multi_head_config, MultiHeadDispatchConfig):
# Extract the required fields
field_names = list(map(lambda x: x.name, fields(MultiHeadDispatchConfig)))
# The missing fields get Noned
for k in field_names:
if k not in multi_head_config.keys():
multi_head_config[k] = None
# Could be that the attention needs to be instantiated
if not isinstance(multi_head_config["attention"], Attention):
# Convenience: fill in possible missing fields
if "num_heads" not in multi_head_config["attention"]:
multi_head_config["attention"]["num_heads"] = multi_head_config[
"num_heads"
]
if "dim_model" not in multi_head_config["attention"]:
multi_head_config["attention"]["dim_model"] = multi_head_config[
"dim_model"
]
if (
"dim_features" not in multi_head_config["attention"]
or multi_head_config["attention"]["dim_features"] is None
):
multi_head_config["attention"]["dim_features"] = (
multi_head_config["dim_model"] // multi_head_config["num_heads"]
)
multi_head_config["attention"] = build_attention(
multi_head_config["attention"]
)
multi_head_config = MultiHeadDispatchConfig(**multi_head_config)
return MultiHeadDispatch.from_config(multi_head_config)

View File

@@ -0,0 +1,71 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
from enum import Enum
from typing import Optional
import torch
from torch import nn
class Activation(str, Enum):
SquaredReLU = "squared_relu"
GeLU = "gelu"
LeakyReLU = "leaky_relu"
ReLU = "relu"
SmeLU = "smelu"
StarReLU = "star_relu"
# For unit testing / parity comparisons, probably not the fastest way
class SquaredReLU(nn.Module):
def __init__(self) -> None:
super().__init__()
def forward(self, x: torch.Tensor) -> torch.Tensor:
x_ = torch.nn.functional.relu(x)
return x_ * x_
class StarReLU(nn.Module):
def __init__(self) -> None:
super().__init__()
def forward(self, x: torch.Tensor) -> torch.Tensor:
x_ = torch.nn.functional.relu(x)
return 0.8944 * x_ * x_ - 0.4472
class SmeLU(nn.Module):
def __init__(self, beta: float = 2.0) -> None:
super().__init__()
self.beta = beta
def forward(self, x: torch.Tensor) -> torch.Tensor:
relu = torch.where(
x >= self.beta,
x,
torch.tensor([0.0], device=x.device, dtype=x.dtype),
)
return torch.where(
torch.abs(x) <= self.beta,
((x + self.beta) ** 2).type_as(x) / (4.0 * self.beta),
relu,
)
def build_activation(activation: Optional[Activation]):
if not activation:
return nn.Identity()
return {
Activation.ReLU: nn.ReLU,
Activation.GeLU: nn.GELU,
Activation.LeakyReLU: nn.LeakyReLU,
Activation.SquaredReLU: SquaredReLU,
Activation.StarReLU: StarReLU,
Activation.SmeLU: SmeLU,
}[activation]()

View File

@@ -0,0 +1,133 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
import logging
from pathlib import Path
from typing import Any, Callable, Dict, Set, Union
import torch
from xformers.utils import (
generate_matching_config,
get_registry_decorator,
import_all_modules,
)
from ._sputnik_sparse import SparseCS
from .attention_mask import AttentionMask
from .base import Attention, AttentionConfig # noqa
logger = logging.getLogger("xformers")
# CREDITS: Classy Vision registry mechanism
ATTENTION_REGISTRY: Dict[str, Any] = {}
ATTENTION_CLASS_NAMES: Set[str] = set()
# Arbitrary threshold for now,
# in between dense and sparse matrix algorithms for the attention mechanism
_DENSITY_THRESHOLD = 0.30 # noqa # from the sputnik paper, vs.
_USE_SPUTNIK = True
def build_attention(config: Union[Dict[str, Any], AttentionConfig]):
"""Builds an attention from a config.
This assumes a 'name' key in the config which is used to determine what
attention class to instantiate. For instance, a config `{"name": "my_attention",
"foo": "bar"}` will find a class that was registered as "my_attention"
(see :func:`register_attention`) and call .from_config on it."""
if not isinstance(config, AttentionConfig):
try:
config_instance = generate_matching_config(
config, ATTENTION_REGISTRY[config["name"]].config
)
except KeyError as e:
name = config["name"]
logger.warning(f"{name} not available among {ATTENTION_REGISTRY.keys()}")
raise e
else:
config_instance = config
return ATTENTION_REGISTRY[config_instance.name].constructor.from_config(
config_instance
)
"""Registers an Attention subclass.
This decorator allows xFormers to instantiate a subclass of Attention
from a configuration file, even if the class itself is not part of the
xFormers library. To use it, apply this decorator to an Attention
subclass, like this:
.. code-block:: python
@dataclass
class MyConfig:
...
@register_attention('my_attention', MyConfig)
class MyAttention(Attention):
...
To instantiate an attention from a configuration file, see :func:`build_attention`."""
register_attention: Callable[[str, Any], Callable[[Any], Any]] = get_registry_decorator(
ATTENTION_REGISTRY, ATTENTION_CLASS_NAMES, Attention, AttentionConfig
)
def maybe_sparsify(matrix) -> Any:
# Sparsify if that makes sense
if torch.count_nonzero(matrix).item() / matrix.numel() > _DENSITY_THRESHOLD:
# If not sparse, then AttentionMask is the reference type
return AttentionMask.from_bool(matrix)
return sparsify(matrix)
def sparsify(matrix):
if _USE_SPUTNIK:
return SparseCS(matrix)
return matrix.to_sparse()
from .favor import FavorAttention # noqa
from .global_tokens import GlobalAttention # noqa
from .linformer import LinformerAttention # noqa
from .local import LocalAttention # noqa
from .nystrom import NystromAttention # noqa
from .ortho import OrthoFormerAttention # noqa
from .random import RandomAttention # noqa
from .scaled_dot_product import ScaledDotProduct # noqa
__all__ = [
"ScaledDotProduct",
"LocalAttention",
"LinformerAttention",
"NystromAttention",
"RandomAttention",
"OrthoFormerAttention",
"GlobalAttention",
"FavorAttention",
"Attention",
"AttentionMask",
"build_attention",
"register_attention",
]
# Optionally expose the BlockSparse attention
try:
from .blocksparse import BlockSparseAttention # noqa
__all__ += ["BlockSparseAttention"]
except ImportError:
pass
# automatically import any Python files in the directory
import_all_modules(str(Path(__file__).parent), "xformers.components.attention")

View File

@@ -0,0 +1,121 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
import torch
from xformers.ops import masked_matmul
from xformers.sparse import SparseCSRTensor
# TODO: this is here for BC
from xformers.sparse.utils import _csr_to_coo, _dense_to_sparse # noqa: F401
class SparseCS:
def __init__(self, matrix, device=None):
if device is None:
device = torch.device("cpu")
if matrix.ndim == 2:
matrix = matrix[None]
assert matrix.ndim == 3
self._mat = SparseCSRTensor.from_dense(matrix).to(device)
@property
def device(self):
return self._mat.device
@property
def ndim(self):
return self._mat.ndim
@property
def dtype(self):
return self._mat.dtype
@property
def is_sparse(self):
return True
@property
def shape(self):
return self._mat.shape[1:]
@property
def values(self):
return self._mat.values()
@property
def row_indices(self):
return self._mat._csr_row_indices
@property
def column_indices(self):
return self._mat._csr_column_indices
@property
def row_offsets(self):
return self._mat._csr_row_offsets
@property
def _transp_info(self):
return self._mat._csr_transp_info
@classmethod
def wrap(
cls, shape, values, row_indices, row_offsets, column_indices, _transp_info
):
matrix = cls.__new__(cls)
_shape = (values.shape[0],) + shape
csr_matrix = SparseCSRTensor._wrap(
_shape, values, row_indices, row_offsets, column_indices, _transp_info
)
matrix._mat = csr_matrix
return matrix
@classmethod
def _wrap(cls, csr_matrix):
assert isinstance(csr_matrix, SparseCSRTensor)
matrix = cls.__new__(cls)
matrix._mat = csr_matrix
return matrix
def __mul__(self, other):
assert isinstance(other, (int, float))
return type(self)._wrap(self._mat * other)
def __add__(self, other):
assert isinstance(other, type(self))
return type(self)._wrap(self._mat + other._mat)
def matmul_with_mask(self, a, b):
return type(self)._wrap(masked_matmul(a, b, self._mat))
def softmax(self):
out = torch.nn.functional.softmax(self._mat, -1)
return type(self)._wrap(out)
def spmm(self, b):
out = torch.bmm(self._mat, b)
return out
def transpose(self):
out = torch.transpose(self._mat, -2, -1)
return type(self)._wrap(out)
def to(self, device):
assert isinstance(device, torch.device)
out = self._mat.to(device)
return type(self)._wrap(out)
def to_dense(self):
return self._mat.to_dense()
def logical_and(self, other: torch.Tensor):
assert not isinstance(other, SparseCS)
out = torch.logical_and(self._mat, other)
return type(self)._wrap(out)
def __and__(self, other):
return self.logical_and(other)

View File

@@ -0,0 +1,143 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
from typing import Optional, Type, TypeVar
import torch
Self = TypeVar("Self", bound="AttentionMask")
class AttentionMask:
"""
Holds an attention mask, along with a couple of helpers and attributes.
.. note: this is an additive mask, meaning that coefficients which should be computed hold the '0.' value,
and coefficients which should be skipped hold the '-inf' value. Any other value is possible if the purpose
is to bias the attention computation for instance
.. note: the attention mask dimensions are expected to be `[batch, to_sequence, from_sequence]`,
`[to_sequence, from_sequence]`, or anything broadcastable in between
"""
def __init__(self, additive_mask: torch.Tensor, is_causal: bool = False):
assert additive_mask.is_floating_point(), additive_mask.dtype
assert not additive_mask.requires_grad
if additive_mask.ndim == 2:
additive_mask = additive_mask.unsqueeze(0)
self.values = additive_mask
self.is_causal = is_causal
self.seq_len = additive_mask.shape[1]
self.to_seq_len = additive_mask.shape[0]
def to_bool(self) -> torch.Tensor:
"""
.. warning: we assume here that True implies that the value should be computed
"""
return self.values != float("-inf")
@classmethod
def from_bool(cls: Type[Self], x: torch.Tensor) -> Self:
"""
Create an AttentionMask given a boolean pattern.
.. warning: we assume here that True implies that the value should be computed
"""
assert x.dtype == torch.bool
additive_mask = torch.empty_like(x, dtype=torch.float, device=x.device)
additive_mask.masked_fill_(x, 0.0)
additive_mask.masked_fill_(~x, float("-inf"))
return cls(additive_mask)
@classmethod
def from_multiplicative(cls: Type[Self], x: torch.Tensor) -> Self:
"""
Create an AttentionMask given a multiplicative attention mask.
"""
assert not x.dtype == torch.bool
additive_mask = torch.empty_like(x, dtype=torch.float, device=x.device)
x = x.bool()
additive_mask.masked_fill_(x, 0.0)
additive_mask.masked_fill_(~x, float("-inf"))
return cls(additive_mask)
@classmethod
def make_causal(
cls: Type[Self],
seq_len: int,
to_seq_len: Optional[int] = None,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
) -> Self:
if not to_seq_len:
to_seq_len = seq_len
additive_mask = torch.triu(
torch.ones(seq_len, to_seq_len, device=device, dtype=dtype) * float("-inf"),
diagonal=1,
)
return cls(additive_mask=additive_mask, is_causal=True)
def make_crop(
self, seq_len: int, to_seq_len: Optional[int] = None
) -> "AttentionMask":
"""
Return a cropped attention mask, whose underlying tensor is a view of this one
"""
if not to_seq_len:
to_seq_len = seq_len
return AttentionMask(
self.values[:, :seq_len, :to_seq_len], is_causal=self.is_causal
)
def __repr__(self):
return f"AttentionMask - causal {self.is_causal} - mask " + str(self.values)
@property
def device(self):
return self.values.device
@property
def is_sparse(self):
return False
@property
def ndim(self):
return len(self.values.shape)
@property
def dtype(self):
return self.values.dtype
@property
def shape(self):
return self.values.shape
def __add__(self, other):
return AttentionMask(self.values + other.values, is_causal=False)
def to(
self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None
) -> "AttentionMask":
assert device is None or isinstance(device, torch.device)
assert dtype is None or isinstance(dtype, torch.dtype)
assert device is not None or dtype is not None
# Noop if we don't need to create another instance
if ((device and device == self.device) or not device) and (
(dtype and dtype == self.dtype) or not dtype
):
return self
return AttentionMask(self.values.to(device=device, dtype=dtype), self.is_causal)

View File

@@ -0,0 +1,295 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
import math
from typing import List
import numpy as np
import torch
from xformers.components.attention.sparsity_config import (
BigBirdSparsityConfig,
BSLongformerSparsityConfig,
FixedSparsityConfig,
VariableSparsityConfig,
)
# generic nd cases
def _generate_nd_grid(*sizes):
coords = [torch.arange(s) for s in sizes]
return torch.meshgrid(*coords)
def local_nd_distance(*sizes, p=2.0, weights=None):
if weights is None:
weights = (1,) * len(sizes)
assert len(sizes) == len(weights)
grid = _generate_nd_grid(*sizes)
grid = [i.flatten() * w for i, w in zip(grid, weights)]
grid = torch.stack(grid, dim=1).float()
d = torch.cdist(grid, grid, p=p)
return d
def local_nd_gaussian_distribution(*sizes, sigma=1):
d = local_nd_distance(*sizes, p=2.0) ** 2
d = torch.exp(-0.5 * sigma ** (-2.0) * d)
return d
def local_nd_pattern(*sizes, distance, p=2.0):
d = local_nd_distance(*sizes, p=p)
return d < distance
def axial_nd_pattern(*sizes):
# axial is a special case with p=0 and distance=2
d = local_nd_distance(*sizes, p=0)
return d < 2
def random_pattern_from_probability_matrix(dist_matrix, nnz):
att = torch.zeros_like(dist_matrix, dtype=torch.bool)
# PyTorch multinomial wrongly doesn't support sampling when number of categories
# is > 2^24, arguing that it's because it's the max representable consecutive element
# in fp32 and that the kernels use float32. This is actually not true, and the kernels
# should work fine if double tensor is passed on CPU. This is a bug that was introduced
# in https://github.com/pytorch/pytorch/commit/bf04c2ca2f591d98ce57816f0ef0cd20a21bbf66
# when unifying the checks between CPU and CUDA. For now, just fall-back to numpy
if dist_matrix.numel() > 2**24:
dist_matrix = dist_matrix.double()
dist_matrix /= dist_matrix.sum()
idxs = np.random.choice(
dist_matrix.numel(), nnz, p=dist_matrix.flatten(), replace=False
)
idxs = torch.as_tensor(idxs)
else:
idxs = torch.multinomial(dist_matrix.flatten(), nnz, replacement=False)
att.view(-1)[idxs] = True
return att
def global_token_pattern(attention_query_mask: torch.Tensor) -> torch.Tensor:
assert attention_query_mask.ndim == 1
assert attention_query_mask.dtype == torch.bool
attention_query_mask = attention_query_mask[None, :]
mask = attention_query_mask | attention_query_mask.transpose(1, 0)
return mask
def random_pattern(attn_size: int, sparsity: float) -> torch.Tensor:
assert 0 < sparsity < 1
mask = torch.rand(attn_size, attn_size) > sparsity
return mask
# 1d-specific cases
def local_1d_pattern(attn_size: int, window_size: int) -> torch.Tensor:
assert (
window_size % 2 == 1
), "The window size is assumed to be odd (counts self-attention + 2 wings)"
h_win_size = window_size // 2 + 1
return local_nd_pattern(attn_size, distance=h_win_size, p=1.0)
def causal_1d_pattern(attn_size: int) -> torch.Tensor:
mask = torch.tril(torch.ones(attn_size, attn_size, dtype=torch.bool))
return mask
# 2d-specific cases
def horizontal_axial_2d_distance(H, W, p=2.0):
d = local_nd_distance(H, W, p=p, weights=(1, 0))
return d
def vertical_axial_2d_distance(H, W, p=2.0):
d = local_nd_distance(H, W, p=p, weights=(0, 1))
return d
def local_2d_distance(H, W, p=2.0):
return local_nd_distance(H, W, p=p)
def local_2d_gausian_distribution(H, W, sigma=1):
return local_nd_gaussian_distribution(H, W, sigma=sigma)
def local_2d_pattern(H, W, distance, p=2.0):
return local_nd_pattern(H, W, distance=distance, p=p)
def axial_2d_pattern(H, W):
return axial_nd_pattern(H, W)
def swin_attention_pattern(H, W, window_size, shift_size=0):
assert H % window_size == 0
assert W % window_size == 0
assert 0 <= shift_size < window_size, "shift_size must in 0-window_size"
# input grid
i, j = _generate_nd_grid(H, W)
i, j = i + 0.5, j + 0.5
# anchors grid
# if shift is present, add extra element to the grid
# to account for the uneven partitioning
extra = int(shift_size % window_size != 0)
grid_h = H // window_size + extra
grid_w = W // window_size + extra
ii, jj = _generate_nd_grid(grid_h, grid_w)
# convert shift to be compatible with the paper representation
s = (-shift_size) % window_size
offset = window_size / 2 - s
ii = ii * window_size + offset
jj = jj * window_size + offset
input_coords = torch.stack([i.flatten(), j.flatten()], 1).float()
anchors_coords = torch.stack([ii.flatten(), jj.flatten()], 1).float()
anchor_id = torch.cdist(input_coords, anchors_coords, p=2).argmin(1)
mask = anchor_id[:, None] == anchor_id[None, :]
return mask
def dilated_2d_pattern(H, W, k=2):
"""
Returns a 2d pattern that samples 1 every k elements in the attention mask.
Can be seen as a form of downsampling, where every pixel attends to a downsampled
version of the input.
"""
d_h = local_nd_distance(H, W, p=1, weights=(1, 0))
d_w = local_nd_distance(H, W, p=1, weights=(0, 1))
d = (d_h.floor() % k == 0) & (d_w.floor() % k == 0)
return d
# Block sparse utils
def block_sparsify_tensor(x, mask, block_size):
"""
Block sparsify a tensor, given a mask and block size
"""
ret = torch.empty(
(x.size(0), mask.sum(), block_size, block_size), dtype=x.dtype, device=x.device
)
for idx, (h, i, j) in enumerate(zip(*mask.nonzero(as_tuple=True))):
ret[:, idx, :, :] = x[
:,
h,
i * block_size : (i + 1) * block_size,
j * block_size : (j + 1) * block_size,
]
return ret
def pattern_to_layout(mask: torch.Tensor, block_size: int) -> torch.Tensor:
r"""
Given a mask pattern and blocksize, return the corresponding layout
which makes sure that all the positives in the mask are covered
"""
assert mask.ndim >= 2, "We're expecting [Heads, Seq, Seq] or [Seq, Seq]"
_should_squeeze = False
if mask.ndim == 2:
mask = mask.unsqueeze(0)
_should_squeeze = True
assert (
mask.shape[1] % block_size == 0 and mask.shape[2] % block_size == 0
), "We're only handling masks divisible by block_size"
# Now mark the mask
layout = torch.nn.functional.max_pool2d(
mask.to(torch.float), kernel_size=block_size, stride=block_size
)
layout = layout.to(torch.long)
if _should_squeeze:
layout.squeeze_(0)
return layout
def alibi_pattern(threshold: float, mask_shape: torch.Size) -> torch.Tensor:
r"""
Use the additive bias computation from ALiBi_ to generate a mask.
Note that this mask can in turn be used to generate a blocksparse attention computation layout
.. note: mask_shape is expected to hold the [heads, seq, seq] dimensions
.. _ALiBi: https://arxiv.org/pdf/2108.12409.pdf
"""
# CREDITS: code snippet from Ofir Press, one of the authors
def get_slopes(n: int):
def get_slopes_power_of_2(n: int) -> List[float]:
start = 2 ** (-(2 ** -(math.log2(n) - 3)))
ratio = start
return [start * ratio**i for i in range(n)]
# In the paper, we only train models that have 2^a heads for some a. This function has
# some good properties that only occur when the input is a power of 2. To maintain that even
# when the number of heads is not a power of 2, we use this workaround.
if math.log2(n).is_integer():
return get_slopes_power_of_2(n)
else:
closest_power_of_2 = 2 ** math.floor(math.log2(n))
return (
get_slopes_power_of_2(closest_power_of_2)
+ get_slopes(2 * closest_power_of_2)[0::2][: n - closest_power_of_2]
)
maxpos = mask_shape[1]
attn_heads = mask_shape[0]
slopes = torch.Tensor(get_slopes(attn_heads))
# In the next line, the part after the * is what constructs the diagonal matrix
# (right matrix in Figure 3 in the paper).
# If you run it you'll see that it doesn't exactly print out the same matrix as we have in Figure 3,
# but one where all rows are identical.
# This works because the softmax operation is invariant to translation,
# and our bias functions are always linear.
alibi = slopes.unsqueeze(1).unsqueeze(1) * torch.arange(maxpos).unsqueeze(
0
).unsqueeze(0).expand(attn_heads, -1, -1)
alibi = alibi.view(attn_heads, 1, maxpos)
# Now threshold arbitrarily, report the mask
return alibi < threshold
def quick_fixed_layout(num_heads: int, block_size: int, seq_len: int):
config = FixedSparsityConfig(num_heads=num_heads, block_size=block_size)
return config.make_layout(seq_len)
def quick_variable_layout(num_heads: int, block_size: int, seq_len: int):
config = VariableSparsityConfig(num_heads=num_heads, block_size=block_size)
return config.make_layout(seq_len)
def quick_bigbird_layout(num_heads: int, block_size: int, seq_len: int):
config = BigBirdSparsityConfig(num_heads=num_heads, block_size=block_size)
return config.make_layout(seq_len)
def quick_bslongformer_layout(num_heads: int, block_size: int, seq_len: int):
config = BSLongformerSparsityConfig(num_heads=num_heads, block_size=block_size)
return config.make_layout(seq_len)
def layout_to_pattern(layout: torch.Tensor, block_size: int):
r"""
create a pattern of shape [heads, seq, seq] out of a blocksparse
layout of shape [heads, seq/block_size, seq/block_size]
"""
return torch.kron(layout, torch.ones(block_size, block_size))

View File

@@ -0,0 +1,93 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
from abc import ABCMeta, abstractmethod
from dataclasses import asdict, dataclass
from typing import Optional, Type, TypeVar
import torch
import torch.nn as nn
from xformers.components.attention import AttentionMask
@dataclass
class AttentionConfig:
"""Parameters required for all Attentions.
Can accept and store extra parameters.
"""
name: str # the registered name for this attention mechanism
dropout: float # dropout probability
Self = TypeVar("Self", bound="Attention")
# Define the common interface, every attention block needs to derive from it
class Attention(nn.Module, metaclass=ABCMeta):
r"""The base Attention mechanism, which is typically a sub-part of the multi-head attention"""
_causal_mask: Optional[AttentionMask] = None
@abstractmethod
def __init__(self, dropout: Optional[float] = None, *args, **kwargs):
super().__init__()
# Requires the inputs to be projected
self.requires_input_projection = True
# Whether the head dimension needs to be present (if not it can be folded into the batch dimension)
self.requires_head_dimension = False
# key padding mask and attention mask must be passed in as separate arguments instead of a merged attention mask
self.requires_separate_masks = False
# Requires that K and Q have the same sequence length
self.requires_same_k_q_dimensions = False
# Whether the attention owns the single head/multihead mechanism
# so that the MHA wrapper should skip it
self.requires_skip_multi_head = False
# This attention requires a context length which is squared, often due to 2D pooling
self.requires_squared_context = False
# Whether this attention mechanism supports attention masks
self.supports_attention_mask = True
self.supports_key_padding_mask = False
@classmethod
def from_config(cls: Type[Self], config: AttentionConfig) -> Self:
# Generate the class inputs from the config
fields = asdict(config)
# Skip all Nones so that default values are used
fields = {k: v for k, v in fields.items() if v is not None}
return cls(**fields)
@abstractmethod
def forward(
self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, *args, **kwargs
) -> torch.Tensor:
raise NotImplementedError
@staticmethod
def _maybe_pad_sequence(x: torch.Tensor, mask: torch.Tensor):
"""
If the sequence is shorter than the mask, return a padded view
"""
if x.shape[-2] != mask.shape[-1]:
assert x.shape[-2] < mask.shape[-1], (
"Sequence is bigger than the provided mask, cannot infer what to do with it."
" Please update your attention mask"
)
pad_size = (0, 0, 0, mask.shape[-1] - x.shape[-2], 0, 0)
return torch.nn.functional.pad(x, pad_size, mode="constant", value=0.0)
return x

View File

@@ -0,0 +1,190 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
import logging
import math
from dataclasses import dataclass
import torch
from xformers import _is_triton_available
from xformers.components.attention import Attention, AttentionConfig, register_attention
logger = logging.getLogger("xformers")
_is_blocksparse_available = _is_triton_available()
if _is_blocksparse_available:
from triton.ops.blocksparse import matmul as blocksparse_matmul # type: ignore
from triton.ops.blocksparse import softmax as blocksparse_softmax # type: ignore
from xformers.triton.utils import gpu_capabilities_older_than_70
# Blocksparse requires Tensor cores
if gpu_capabilities_older_than_70():
logger.warning(
"Blocksparse is not available: the current GPU does not expose Tensor cores"
)
_is_blocksparse_available = False
if _is_blocksparse_available:
@dataclass
class BlockSparseAttentionConfig(AttentionConfig):
layout: torch.Tensor # The dimensions of the random features
block_size: int
dropout: float
num_heads: int
@register_attention("blocksparse", BlockSparseAttentionConfig)
class BlockSparseAttention(Attention):
r"""
Thin wrap over the Triton blocksparse computations. The sparsity pattern is determined through the layout.
.. warning: the layout is assumed to have the dimensions [heads, seq, seq].
If some dimensions are missing, we assume that the same layout is to be used across heads.
.. warning: for now, the sequence (context) length has to be a power of two. This constraint could
be relaxed in the future.
.. warning: the block size has to be picked from [16, 32, 64]. Some speed is gained from bigger blocks.
It is of course possible to reproduce coarser patterns given these primitives, as the user sees fit.
"""
def __init__(
self,
layout: torch.Tensor,
block_size: int = 16,
dropout: float = 0.0,
num_heads: int = 1, # optional, used to adapt the layout if in need
causal: bool = False,
*args,
**kwargs,
):
if layout.dim() == 2:
logger.warning(
"The layout passed is lacking a head dimension and a batch dimension"
)
logger.warning(
"Now assuming that the same layout is to be used across all heads"
)
layout = layout.unsqueeze(0).expand(num_heads, -1, -1)
logger.warning(f"New layout dimensions: {layout.shape}")
assert block_size in (
16,
32,
64,
128,
), "Only block sizes in [16, 32, 64, 128] are supported"
super().__init__()
self.causal = causal
self.attn_drop = torch.nn.Dropout(dropout, inplace=False)
# Pure blocksparse data
self.layout = layout
self.block_size = block_size
# make sure that the head dimension is not folded down with the batch
self.requires_head_dimension = True
# key padding mask and attention mask must be passed in separately
self.requires_same_k_q_dimensions = True
# The underlying triton op does not support per element attention mask
self.supports_attention_mask = False
self.supports_key_padding_mask = False
def create_triton_kernels(self, device):
# blocksparse operators
self.sparse_dot_sdd = blocksparse_matmul(
self.layout,
self.block_size,
"sdd",
trans_a=False,
trans_b=True,
device=device,
)
self.sparse_dot_dsd = blocksparse_matmul(
self.layout,
self.block_size,
"dsd",
trans_a=False,
trans_b=False,
device=device,
)
self.sparse_softmax = blocksparse_softmax(
self.layout,
self.block_size,
device=device,
)
def forward(
self,
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
scale: float = 1.0,
*args,
**kwargs,
) -> torch.Tensor:
assert (
"att_mask" not in kwargs.keys() and "att_mask" not in args
), "This attention does not support an attention mask, but you can specify causality."
r"""
A thin wrap around the Triton blockparse attention operation
.. note: Per element attention mask is not supported, but you can specify causality
"""
# Delayed triton init, to make sure that we get the right device
# Infer device from query
if not hasattr(self, "sparse_dot_sdd"):
self.create_triton_kernels(q.device)
assert (
q.shape[-2] == k.shape[-2]
), "Blocksparse requires the same dimensions for K and Q for now"
assert (
q.shape[-2] == self.layout.shape[-2] * self.block_size
), "Actual sequence size and layout are inconsistent"
assert (
k.shape[-2] == self.layout.shape[-2] * self.block_size
), "Actual sequence size and layout are inconsistent"
assert (
q.shape[-2] % self.block_size
) == 0, "Sequence length {} must be a multiple of block size {}".format(
q.shape[-2], self.block_size
)
# Self-attend: (B, nh, S, hs) x (B, nh, hs, S) -> (B, nh, S, S)
# When the computations are block sparse, the matrix types change along the way:
# - (sparse) attention matrix = (dense) Kt * (dense) Q
q = q / math.sqrt(q.size(-1))
sparse_att_mat = self.sparse_dot_sdd(q, k)
# - softmax on the sparse attention matrix
sparse_att_mat = self.sparse_softmax(
sparse_att_mat, scale=scale, is_causal=self.causal
)
sparse_att_mat = self.attn_drop(sparse_att_mat)
# - then (dense) attention is (sparse) attention matrix * dense (value)
a = self.sparse_dot_dsd(sparse_att_mat, v)
return a

View File

@@ -0,0 +1,341 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
# Credits: this is heavily inspired by the official implementation, present in
# https://github.com/sarthmit/Compositional-Attention
# Original author: Sarthak Mittal
# This is a simplified version, for the sake of clarity, and because some features could be exposed later
# via the library directly.
# In particular, code paths for TPUs, quantization and gumbel softmax have been removed
# We're also following the same dimension ordering as in the rest of the xformers library
# which is to say [Batch, Sequence, Embedding] wherever possible
import math
from dataclasses import dataclass
from typing import Optional
import torch
import torch.nn.functional as F
from torch import Tensor, nn
from xformers.components.attention import (
Attention,
AttentionConfig,
AttentionMask,
register_attention,
)
from xformers.components.attention.core import _softmax
from xformers.components.input_projection import InputProjection, InputProjectionConfig
def _either_or(a: Optional[int], b: int) -> int:
return a if a is not None else b
@dataclass
class CompositionalAttentionConfig(AttentionConfig):
dim_model: int
num_heads: int
dim_attn: Optional[int] = None
num_rules: Optional[int] = None
dim_key: Optional[int] = None
dim_value: Optional[int] = None
dim_selection: Optional[int] = None
dropout: float
qk_rule: bool = False
nonlinear: bool = False
q_compose: bool = False
bias: bool = True
causal: Optional[bool] = False
in_proj_container: Optional[InputProjection] = None
use_separate_proj_weight: Optional[bool] = False
@register_attention("compositional", CompositionalAttentionConfig)
class CompositionalAttention(Attention):
"""Compositional Attention, as proposed in
"Compositional Attention: Disentangling search and retrieval"_, S. Mittal et al.
A key insight from this proposal is that the attention mechanism can be conceived as two steps:
a search and a retrieval operation. When queried, the model can search for the most relevant information
(Softmax(QKt)), then retrieve information given the Value.
Contrary to the original attention proposal, which does not consider interactions in between heads,
the compositional attention will consider all possible interactions and softmax over that dimension,
so that the information retrieved covers the most relevant dimensions. The number of heads and rules to
use is thus typically smaller than for a comparable traditional Transformer, and asking for the same number of heads
may not fit in memory.
Args:
dim_model: dimension of the incoming latent space
num_heads: number of heads *for the search operation*
dim_attn: dimension (embedding) of the attention
num_rules: number of rules to consider *for the retrieval operation*
dim_selection: dimension of the scoring/selection space for the retrievals
dim_key, dim_value: dimensions of K and V, if different from Q
dropout: attention dropout probability
qk_rule: QK product will drive the retrieval process
nonlinear: use a non linear method to score the retrievals
bias: use bias in the initial projection step
causal: causal computations (attend to the past only)
_"Compositional Attention: Disentangling search and retrieval": https://arxiv.org/pdf/2110.09419v1.pdf
"""
def __init__(
self,
dim_model: int,
num_heads: int,
dim_attn: Optional[int] = None,
num_rules: Optional[int] = None,
dim_selection: Optional[int] = None,
dim_key: Optional[int] = None,
dim_value: Optional[int] = None,
dropout=0.0,
qk_rule=False,
nonlinear=False,
q_compose=False,
in_proj_container: Optional[InputProjection] = None,
use_separate_proj_weight: Optional[bool] = False,
bias=True,
causal=False,
*_,
**__,
):
super().__init__()
# Define the inherited flags
self.requires_skip_multi_head = (
True # This attention owns the multi-head mechanism
)
# Handle defaults / undefined values
self.dim_model = dim_model
num_rules = _either_or(num_rules, num_heads)
dim_selection = _either_or(dim_selection, dim_model // num_heads)
# All the initial definition plumbing
dim_attn = _either_or(dim_attn, dim_model)
dim_key = _either_or(dim_key, dim_model)
dim_value = _either_or(dim_value, dim_model)
self.in_proj_container = (
in_proj_container
if in_proj_container is not None
else InputProjection(
query_proj_params=InputProjectionConfig(dim_model, dim_key, bias=bias),
key_proj_params=InputProjectionConfig(dim_model, dim_key, bias=bias)
if use_separate_proj_weight
else None,
value_proj_params=InputProjectionConfig(dim_model, dim_value, bias=bias)
if use_separate_proj_weight
else None,
)
)
self.num_heads = num_heads
self.num_rules = num_rules
self.qk_rule = qk_rule
self.dim_selection = dim_selection
self.nonlinear = nonlinear
self.q_compose = q_compose
self.dropout_module = nn.Dropout(dropout)
self.dim_head = dim_model // num_heads
self.value_dim = dim_attn // num_rules
assert (
self.value_dim * num_rules == dim_attn
), "value_dim must be divisible by num_rules"
self.scaling = self.dim_head**-0.5
self.scaling_values = self.dim_selection**-0.5
self.out_proj = nn.Linear(self.num_heads * self.value_dim, dim_model, bias=bias)
if self.qk_rule:
self.value_k = nn.Linear(self.value_dim, self.dim_selection, bias=bias)
if self.q_compose:
self.value_q = nn.Linear(self.dim_head, self.dim_selection, bias=bias)
else:
self.value_q = nn.Linear(
dim_model, self.dim_selection * self.num_heads, bias=bias
)
else:
if self.q_compose:
self.value_q = nn.Linear(self.dim_head, self.dim_selection, bias=bias)
else:
self.value_q = nn.Linear(
dim_model, self.dim_selection * self.num_heads, bias=bias
)
if self.nonlinear:
self.score_network: nn.Module = nn.Sequential(
nn.Linear(
self.dim_selection + self.value_dim,
self.dim_selection,
bias=bias,
),
nn.ReLU(),
nn.Linear(self.dim_selection, 1, bias=bias),
)
else:
self.score_network = nn.Linear(
self.dim_selection + self.value_dim, 1, bias=bias
)
self.causal = causal
# Properties specific to this attention mechanism
self.supports_attention_mask = True
self.supports_key_padding_mask = False
self._reset_parameters()
def _reset_parameters(self):
# NOTE: in_proj_container is already initialized
if self.qk_rule:
nn.init.xavier_uniform_(self.value_k.weight, gain=1 / math.sqrt(2))
nn.init.xavier_uniform_(self.value_q.weight, gain=1 / math.sqrt(2))
else:
nn.init.xavier_uniform_(self.value_q.weight)
if self.nonlinear:
nn.init.xavier_uniform_(self.score_network[0].weight)
nn.init.xavier_uniform_(self.score_network[2].weight)
else:
nn.init.xavier_uniform_(self.score_network.weight)
nn.init.xavier_uniform_(self.out_proj.weight)
if self.out_proj.bias is not None:
nn.init.constant_(self.out_proj.bias, 0.0)
def forward(
self,
q: Tensor,
k: Tensor,
v: Tensor,
att_mask: Optional[Tensor] = None,
*args,
**kwargs,
) -> Tensor:
"""
Input shape: Time x Batch x Channel
Args:
att_mask (ByteTensor, optional): typically used to
implement causal attention, where the mask prevents the
attention from looking forward in time (default: None).
"""
B, Sq, E = q.shape
_, Sk, _ = k.shape
assert E == self.dim_model
# First define projected query/key/values
# We keep the projected and original tensors in flight,
# depending on the options the original values could be reused
q_unprojected = q
q, k, v = self.in_proj_container(query=q, key=k, value=v)
q *= self.scaling
# Init causal mask if needed, now that we know the context length
if self.causal and (
self._causal_mask is None or self._causal_mask.shape[0] != Sk
):
self._causal_mask = AttentionMask.make_causal(Sq, Sq, device=q.device)
# Convenience, create an attention mask if a tensor was passed
# This sanitizes different mask types being passed, from now on it's additive
if isinstance(att_mask, torch.Tensor):
# By default we don't know of the causality, and a check would be expensive
att_mask_additive: Optional[AttentionMask] = (
AttentionMask.from_bool(att_mask)
if att_mask.dtype == torch.bool
else AttentionMask(att_mask, is_causal=False)
)
else:
att_mask_additive = None
# Handle the attention and key padding masks
if self._causal_mask is not None:
# Optionally add the causal mask
if att_mask_additive is not None:
att_mask_additive += self._causal_mask
else:
att_mask_additive = self._causal_mask
# Flatten the heads or the rules
q = (
q.view(B, Sq, self.num_heads, self.dim_head)
.movedim(2, 1)
.flatten(0, 1) # [B * num_heads, Sq, dim_head]
)
k = (
k.view(B, Sk, self.num_heads, self.dim_head).movedim(2, 1).flatten(0, 1)
) # [B * num_heads, Sk, dim_head]
v = v.view(B, -1, self.num_rules, self.value_dim).movedim(2, 1).flatten(0, 1)
# Compute the search: Softmax(QKt)
attn_weights = torch.bmm(q, k.transpose(1, 2)) # [B * self.num_heads, Sq, Sk]
if att_mask_additive is not None:
attn_weights += att_mask_additive.values
attn_weights = _softmax(attn_weights, causal=self.causal)
attn_weights = attn_weights.view(B, self.num_heads, Sq, Sk)
attn_probs = self.dropout_module(attn_weights)
# Now compute the information retrieval
# keep all the heads in flight, we'll score the different possibilities
# - compute all the possible retrievals
v = v.view(B, 1, self.num_rules, Sk, self.value_dim)
attn_probs = attn_probs.unsqueeze(2)
attn = torch.matmul(attn_probs, v).view(
B, self.num_heads, self.num_rules, Sq, self.value_dim
)
attn = attn.movedim(3, 1) # [B, Sq, H, Rules, Values]
# - search the most appropriate retrieval among all the values
if self.q_compose:
v_q = self.value_q(q.transpose(0, 1)).view(
B, Sq, self.num_heads, 1, self.dim_selection
)
else:
v_q = self.value_q(q_unprojected).view(
B, Sq, self.num_heads, 1, self.dim_selection
)
if self.qk_rule:
v_q *= self.scaling_values
v_k = (
self.value_k(attn)
.view(B, Sq, self.num_heads, self.num_rules, self.dim_selection)
.transpose(4, 3)
.contiguous()
)
v_score = torch.matmul(v_q, v_k).view(
B, Sq, self.num_heads, self.num_rules, 1
)
else:
v_q = v_q.expand(-1, -1, -1, self.num_rules, -1)
v_in = torch.cat([attn, v_q], dim=-1)
v_score = self.score_network(v_in).view(
B, Sq, self.num_heads, self.num_rules, 1
)
v_score = F.softmax(v_score, dim=3)
# - extracted values are the original attention (inc. all the values) weighted by value score
attn = (attn * v_score).sum(dim=3).view(B, Sq, self.num_heads * self.value_dim)
# Final attention projection, same as other mechanisms
attn = self.out_proj(attn)
return attn

View File

@@ -0,0 +1,342 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
import logging
import math
from contextlib import nullcontext
from functools import lru_cache
from typing import Optional, Union
import torch
from xformers import _has_cpp_library, _is_triton_available
from xformers.components.attention.attention_mask import AttentionMask
if _has_cpp_library:
from ._sputnik_sparse import SparseCS
if _is_triton_available():
from xformers.triton.softmax import softmax as triton_softmax
from xformers.triton.utils import gpu_capabilities_older_than_70
_is_blocksparse_available = (
_is_triton_available() and not gpu_capabilities_older_than_70()
)
if _is_blocksparse_available:
from xformers.components.attention.blocksparse import BlockSparseAttention
logger = logging.getLogger("xformers")
def _create_random_sparsity(matrix, sparsity, divisible_by=4):
assert matrix.ndim == 3
keep = torch.rand_like(matrix[0], dtype=torch.float32) > sparsity
nonzero = torch.nonzero(keep)
nnz = nonzero.shape[0]
# NOTE: need to make it a multiple of 4 for sputnik
nonzero = nonzero[: (nnz - nnz % divisible_by)]
i, j = nonzero.unbind(1)
output = torch.zeros_like(matrix)
bdim = torch.arange(matrix.shape[0], device=matrix.device)[:, None]
output[bdim, i, j] = matrix[bdim, i, j]
return output
def _broadcast_batch(mask, batch_size):
if mask.ndim == 3:
return mask
assert mask.ndim == 2
mask = mask.coalesce()
values = mask.values()
indices = mask.indices()
nnz = len(values)
# strategy: repeat the indices and append the extra batch dimension to the indices
indices = indices.repeat(1, batch_size)
# now create the batch indices
batch_indices = torch.arange(batch_size, device=indices.device)
batch_indices = batch_indices[:, None].expand(batch_size, nnz).flatten()
# put them together
indices = torch.cat([batch_indices[None, :], indices], dim=0)
# now repeat the values
values = values.repeat(batch_size)
size = (batch_size,) + mask.shape
return torch.sparse_coo_tensor(indices, values, size)
def _matmul_with_mask(
a: torch.Tensor,
b: torch.Tensor,
mask: Optional[Union[torch.Tensor, "SparseCS"]],
) -> torch.Tensor:
if mask is None:
return a @ b
if _has_cpp_library and mask.dtype == torch.bool:
if isinstance(mask, SparseCS):
return mask.matmul_with_mask(a, b)
if mask.is_sparse:
# perform broadcasting if needed
mask = _broadcast_batch(mask, a.shape[0])
# coalesced is not implemented for bool tensors, so need to cast
mask = mask.to(dtype=a.dtype) # type: ignore # mypy is missing the catch above
return torch.ops.xformers.matmul_with_mask(a, b, mask)
# Non optimized codepath
if _has_cpp_library:
assert not isinstance(mask, SparseCS)
att = a @ b
if mask.dtype == torch.bool:
assert not isinstance(mask, SparseCS)
if mask.ndim == 2:
mask = mask.unsqueeze(0).expand(att.shape[0], -1, -1)
# mask is presumed false == ignore
att[~mask] = float("-inf")
else:
# mask is presumed additive
# repeat if batch sizes don't match
if (
not isinstance(mask, SparseCS)
and mask.ndim == 3
and mask.shape[0] != att.shape[0]
and (att.shape[0] % mask.shape[0]) == 0
):
repeat_factor = att.shape[0] // mask.shape[0]
mask = mask.repeat([repeat_factor, 1, 1])
logger.info("Mismatched batch dimensions for mask, repeating mask.")
att += mask
return att
def _softmax(a: torch.Tensor, causal: bool = False) -> torch.Tensor:
if _has_cpp_library and isinstance(a, SparseCS):
return a.softmax()
if a.is_sparse:
return torch.sparse.softmax(a, dim=a.ndim - 1)
if _is_triton_available():
return triton_softmax(a, mask=None, causal=causal)
else:
return torch.softmax(a, dim=a.ndim - 1)
if _has_cpp_library:
class SparseBMM(torch.autograd.Function):
@staticmethod
def forward(ctx, a, b):
a = a.coalesce()
r = torch.bmm(a, b)
ctx.save_for_backward(a, b)
return r
@staticmethod
def backward(ctx, grad):
a, b = ctx.saved_tensors
# gradients w.r.t. a
ga = None
if ctx.needs_input_grad[0]:
ga = torch.ops.xformers.matmul_with_mask(grad, b.transpose(-2, -1), a)
# gradients w.r.t. b
gb = None
if ctx.needs_input_grad[1]:
gb = a.transpose(1, 2).bmm(grad)
return ga, gb
def _sparse_bmm(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
"""
Batch matrix multiply between a sparse matrix and a dense matrix
"""
assert a.ndim == b.ndim == 3
assert a.shape[0] == b.shape[0]
assert a.shape[2] == b.shape[1]
return SparseBMM.apply(a, b)
def bmm(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
if _has_cpp_library:
if isinstance(a, SparseCS):
return a.spmm(b)
if a.is_sparse:
return _sparse_bmm(a, b)
return a @ b
def _apply_dropout(att, dropout):
if dropout is None:
return att
# Dropout chokes on sparse tensors
if _has_cpp_library:
if isinstance(att, SparseCS):
values = att.values.clone()
values = dropout(values)
att = SparseCS.wrap(
att.shape,
values,
att.row_indices,
att.row_offsets,
att.column_indices,
att._transp_info,
)
elif att.is_sparse:
att = att.coalesce()
values = att.values().clone() # protect against in-place dropout
values = dropout(values)
att = torch.sparse_coo_tensor(att.indices(), values, att.shape)
else:
# Simple dense case
att = dropout(att)
return att
# Non optimized vanilla dropout
att = dropout(att)
return att
def scaled_query_key_softmax(
q: torch.Tensor,
k: torch.Tensor,
att_mask: Optional[Union[AttentionMask, "SparseCS", torch.Tensor]],
) -> torch.Tensor:
# TODO assume we have (N, S, hs) instead of (B, nh, S, hs), with N = B x nh
# this is needed due to limitations in sparse_bmm for now
# Self-attend: (N, S, hs) x (N, hs, S) -> (N, S, S)
q = q / math.sqrt(k.size(-1))
# Matmul with mask
if att_mask is not None and isinstance(att_mask, AttentionMask):
# Additive mask
mask: Optional[Union[SparseCS, torch.Tensor]] = att_mask.values
else:
mask = att_mask
att = _matmul_with_mask(q, k.transpose(-2, -1), mask)
# Softmax to get the attention probabilities
is_causal = isinstance(att_mask, AttentionMask) and att_mask.is_causal
att = _softmax(att, causal=is_causal)
return att
if _is_blocksparse_available:
# 128 is default maxsize
@lru_cache(maxsize=128)
def _retrieve_blocksparse(
num_heads: int, seq_len: int, block_size: int
) -> BlockSparseAttention:
# Checks if blocksparse object exists in cache
blocks = seq_len // block_size
layout_fill = torch.ones((num_heads, blocks, blocks), dtype=torch.long)
return BlockSparseAttention(
layout=layout_fill, block_size=block_size, causal=True
)
def blocksparse_attention(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
dropout: Optional[torch.nn.Module] = None,
block_size: int = 128,
) -> torch.Tensor:
orig_dim = q.dim()
seq_len = q.shape[-2]
# Layout head dimension: 1 or batch size (q.shape[0])
layout_heads = 1
# TODO perhaps add functionality to pad qkv if sequence length is not divisible by block size?
assert seq_len % block_size == 0, "Sequence length must be divisible by block size"
if orig_dim == 3:
# Reshape from (N, S, hs) to (B, nh, S, hs) where N = B x nh, hs = D / nh
# Assuming num_heads = 1, (N, S, hs) to (B, 1, S, hs)
if layout_heads == 1:
q = q.unsqueeze(1)
k = k.unsqueeze(1)
v = v.unsqueeze(1)
else:
q = q.unsqueeze(0)
k = k.unsqueeze(0)
v = v.unsqueeze(0)
blocksparse_attention = _retrieve_blocksparse(layout_heads, seq_len, block_size)
# Dropout is a no-op in evaluation mode
if isinstance(dropout, torch.nn.Dropout):
blocksparse_attention.attn_drop = dropout
else:
blocksparse_attention.attn_drop = torch.nn.Dropout(0.0)
att = blocksparse_attention(q, k, v)
# Reshape attention (B, nh, S, hs) back to (N, S, hs)
if orig_dim == 3:
return att.flatten(0, 1)
return att
def scaled_dot_product_attention(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
att_mask: Optional[Union[AttentionMask, "SparseCS", torch.Tensor]],
dropout: Optional[torch.nn.Module] = None,
block_size: int = 128,
) -> torch.Tensor:
autocast_disabled = (
_has_cpp_library
and isinstance(att_mask, SparseCS)
or (att_mask is not None and att_mask.is_sparse)
)
seq_len = q.shape[-2]
# switch if:
# causal is required but mask is not sparse
# fp16 or under amp context
# sequence length is divisible by block size
# same seq len for K and Q
switch_to_blocksparse = (
_is_blocksparse_available
and (att_mask is not None and not att_mask.is_sparse)
and (isinstance(att_mask, AttentionMask) and att_mask.is_causal)
and (q.dtype == torch.float16 or torch.is_autocast_enabled())
and not seq_len % block_size
and q.shape[-2] == k.shape[-2]
)
if switch_to_blocksparse:
logger.info("Switching causal attention to Triton blocksparse...")
return blocksparse_attention(q, k, v, dropout, block_size)
with torch.cuda.amp.autocast(enabled=False) if autocast_disabled else nullcontext(): # type: ignore
if autocast_disabled:
q, k, v = q.float(), k.float(), v.float()
att = scaled_query_key_softmax(q, k, att_mask=att_mask)
# Optional dropout, could be part of the masking in the future
att = _apply_dropout(att, dropout)
# Get to the predicted values, for all heads
# y = att @ v # (N, S, S) x (N, S, hs) -> (N, S, hs)
y = bmm(att, v)
return y

View File

@@ -0,0 +1,173 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
import logging
import math
from dataclasses import dataclass
from typing import Optional, Tuple
import torch
import torch.nn as nn
from torch.cuda.amp import autocast
from xformers.components.attention import Attention, AttentionConfig, register_attention
from xformers.components.attention.feature_maps import (
FeatureMap,
FeatureMapType,
SMHyperbolic,
SMOrf,
SMReg,
)
logger = logging.getLogger("xformers")
@dataclass
class FavorAttentionConfig(AttentionConfig):
causal: Optional[bool]
dim_features: Optional[int] = None # The dimensions of the random features
dim_head: Optional[
int
] = None # The embedding dimension of the inputs. Only useful to get a dim_features estimate
iter_before_redraw: Optional[
int
] = None # The number of iterations before the random features are re-drawn from scratch
feature_map: Optional[FeatureMapType] = None
@register_attention("favor", FavorAttentionConfig)
class FavorAttention(Attention):
def __init__(
self,
causal: bool = False,
dropout: float = 0.0,
dim_features: Optional[int] = None,
dim_head: Optional[int] = None,
iter_before_redraw: Optional[int] = None,
feature_map_type: FeatureMapType = FeatureMapType.SMReg,
normalize_inputs: bool = False,
*_,
**__,
):
r"""
Kernelized attention, as proposed in Performers_
("Rethinking attention with performers." K. Choromanski et al. (2020).).
FAVOR stands for "Fast Attention Via positive Orthogonal Random features"
Args:
dropout (float): the probability of an output to be randomly dropped at training time
dim_features (int): the dimension of the random features space
iter_before_redraw (int): the number of steps (forward calls) before a redraw of the features
feature_map_type (FeatureMapType): the type of feature map being used,
for instance orthogonal random features.
.. _Performers: https://arxiv.org/pdf/2009.14794v1.pdf
"""
super().__init__()
self.causal = causal
self.iter_before_redraw = (
(2 * iter_before_redraw)
if iter_before_redraw is not None
else iter_before_redraw
) # This will be used for both key and query
self.normalize_inputs = normalize_inputs
self.feature_map_type = feature_map_type
self.attn_drop = nn.Dropout(dropout, inplace=True)
# Setup dimension-dependent variables
# Reasonable dimension default
if dim_features is None:
assert dim_head is not None, "dim_features or dim_head needs to be passed"
self.dim_features = math.ceil(dim_head * (1 + math.log2(dim_head)))
self.dim_features = 2 * (
self.dim_features // 2
) # needs to be even for some variants
logger.info(
f"FAVOR: Automatically setting the random mapping dimension to {self.dim_features} from {dim_head}"
)
else:
self.dim_features = dim_features
feature_map_constructor = {
FeatureMapType.SMHyp: SMHyperbolic,
FeatureMapType.SMReg: SMReg,
FeatureMapType.SMOrf: SMOrf,
}[self.feature_map_type]
feature_settings = {
"dim_features": self.dim_features,
"iter_before_redraw": self.iter_before_redraw,
"normalize_inputs": self.normalize_inputs,
}
self.feature_map: FeatureMap = feature_map_constructor(**feature_settings) # type: ignore
# Properties specific to this attention mechanism
self.supports_attention_mask = False
self.supports_key_padding_mask = False
@staticmethod
def _maybe_promote(x: torch.Tensor) -> torch.Tensor:
# Only promote fp16 buffers, bfloat16 would be fine for instance
return x.float() if x.dtype == torch.float16 else x
@staticmethod
def _causal_attention(
k_prime: torch.Tensor, q_prime: torch.Tensor, v: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
# Algorithm 1 in the paper
ref_v = torch.ones_like(v.unsqueeze(2)) # BATCH x SEQ x 1 x EMB
Gps = k_prime.unsqueeze(3) * v.unsqueeze(2)
Grenorm = k_prime.unsqueeze(3) * ref_v
# Consolidate against the feature dimension
att_raw = torch.einsum("bcfe,bcf->bce", Gps, q_prime)
att_norm = torch.einsum("bcfe,bcf->bce", Grenorm, q_prime)
# Cumulative sum over the sequence
att_raw = att_raw.cumsum(2)
att_norm = att_norm.cumsum(2)
return att_raw, att_norm
def forward(
self,
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
*_,
**__,
):
# Project key and queries onto the feature map space
k_prime = self.feature_map(k)
q_prime = self.feature_map(q)
with autocast(enabled=False):
# The softmax kernel approximation for Favor will easily overflow
# Force the computations here to stay in fp32 for numerical stability
# Note that the dimensions are vastly reduced when compared to scaled_dot_product
k_prime = self._maybe_promote(k_prime)
q_prime = self._maybe_promote(q_prime)
v = self._maybe_promote(v)
if not self.causal:
att_normalization = q_prime @ (
k_prime.transpose(-2, -1) @ torch.ones_like(v)
)
att_raw = q_prime @ (k_prime.transpose(-2, -1) @ v)
else:
# Actually compute attention
att_raw, att_normalization = self._causal_attention(k_prime, q_prime, v)
# Normalize
att = att_raw / att_normalization
if self.attn_drop is not None:
att = self.attn_drop(att)
return att

View File

@@ -0,0 +1,26 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
from enum import Enum
from .base import FeatureMap, FeatureMapConfig
from .softmax import NormDistribution, SMHyperbolic, SMOrf, SMReg
class FeatureMapType(str, Enum):
SMOrf = "sm_orf"
SMHyp = "sm_hyp"
SMReg = "sm_reg" # regularized softmax kernel
__all__ = [
"SMOrf",
"SMReg",
"SMHyperbolic",
"NormDistribution",
"FeatureMapConfig",
"FeatureMap",
]

View File

@@ -0,0 +1,61 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
from abc import abstractmethod
from dataclasses import asdict, dataclass
from typing import Optional, Type, TypeVar
import torch
"""
Feature maps allow for a given query or key to be encoded in a different space.
"""
Self = TypeVar("Self", bound="FeatureMap")
@dataclass
class FeatureMapConfig:
name: str
dim_features: int
iter_before_redraw: Optional[int]
normalize_inputs: Optional[bool]
epsilon: Optional[float]
class FeatureMap(torch.nn.Module):
def __init__(
self,
dim_features: int,
iter_before_redraw: Optional[int] = None,
normalize_inputs: bool = False,
epsilon: float = 1e-6,
):
super().__init__()
self.dim_features = dim_features
self.dim_feature_map = dim_features
self.iter_before_redraw = iter_before_redraw
self.features: Optional[torch.Tensor] = None
self.epsilon = epsilon
self.normalize_inputs = normalize_inputs
self._iter_counter = 0
@abstractmethod
def _get_feature_map(self, dim_input: int, dim_features: int, device: torch.device):
raise NotImplementedError()
@classmethod
def from_config(cls: Type[Self], config: FeatureMapConfig) -> Self:
# Generate the class inputs from the config
fields = asdict(config)
# Skip all Nones so that default values are used
fields = {k: v for k, v in fields.items() if v is not None}
return cls(**fields)

View File

@@ -0,0 +1,288 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
import math
from enum import Enum, auto
from typing import Optional
import torch
from torch.autograd.profiler import record_function
from .base import FeatureMap
"""
A set of feature maps which approximate the softmax kernel, as per the Performers_ paper.
_Performers: "Rethinking attention with performers." K. Choromanski et al. (2020).
https://arxiv.org/pdf/2009.14794v1.pdf
"""
class NormDistribution(Enum):
Xi = auto()
Uniform = auto()
class SoftMaxPositiveEstimators(FeatureMap):
def __init__(
self,
dim_features: int,
iter_before_redraw: Optional[int],
normalize_inputs: bool = False,
epsilon: float = 1e-6,
softmax_temp: float = -1,
):
super().__init__(dim_features, iter_before_redraw, normalize_inputs, epsilon)
self.softmax_temp = softmax_temp
# Handle the scaling from all kernels by √m.
# This normalizes for all the feature maps involved
self.h_scale = math.log(math.sqrt(self.dim_features))
def pre_scale(self, x: torch.Tensor) -> torch.Tensor:
with record_function("feature_map::pre_scale"):
# Re-draw counting logic
if (
(
self.iter_before_redraw is not None
and self._iter_counter > self.iter_before_redraw
)
or self.features is None
or self.features.device != x.device
):
# The feature map is actually using half the dimension, we'll concatenate + and - features
self._iter_counter = 1
self.features = self._get_feature_map(
x.shape[-1], self.dim_feature_map, x.device
)
features = self.features
assert features is not None
if features.dtype != x.dtype:
self.features = features.to(x.dtype)
self._iter_counter += 1
# Normalization / softmax
if self.softmax_temp < 0:
# A = exp(QK.t/√d), so each input will be scaled by √√d
self.softmax_temp = x.shape[-1] ** -0.25
x_scaled = x * self.softmax_temp
# Compute the scaling factors in logspace, applied from within the exponential
# - dimnish possible exponential overflow
# - remove a multiply across the batch, replace by an addition
norm_x_2 = torch.einsum("...d,...d->...", x_scaled, x_scaled).unsqueeze(-1)
self.offset = -0.5 * norm_x_2 - self.h_scale + self.epsilon
if self.normalize_inputs:
# L0 normalize the exponential term, can be useful for numerical stability
# This ensures that features +- offset is below 1
self.offset -= norm_x_2.max(1, keepdim=True)[0]
# Return the scaled inputs, the rest depends on the kernel being used
return x_scaled
@staticmethod
@torch.no_grad()
def _get_random_ortho_matrix(
blocks: int,
dim: int,
device: torch.device,
norm_distribution: NormDistribution = NormDistribution.Uniform,
) -> torch.Tensor:
r"""
Generate a random matrix whose rows are exactly orthonormal
"How to generate random matrices from the classical compact groups", Mezzadri, 2007
https://arxiv.org/pdf/math-ph/0609050v2.pdf
.. note: the typical qr decomposition does not give uniform results, qr decomposition is not
unique and the qr decomposition routines are biased towards numerical stability. See the above
paper for more information.
.. note: this does not follow the original implementation from the Performers authors.
see docs/assets/kde plots to visualize the impact of using the R signs to correct Q
"""
H = torch.randn((blocks, dim, dim), device=device, requires_grad=False)
# Randomly scale the norms of the features, Xi distributed
if norm_distribution == NormDistribution.Xi:
# NOTE: This averages to sqrt(d)
norms = torch.sqrt(torch.einsum("...d,...d->...", H, H))
Q, R = torch.linalg.qr(H)
Q = torch.diag_embed(torch.sign(torch.diagonal(R, dim1=1, dim2=2))) @ Q
# Normalize if need be. Uniform NormDistribution does nothing, Q is already orthonormal
if norm_distribution == NormDistribution.Xi:
return torch.diag_embed(norms) @ Q
return Q
class SMOrf(SoftMaxPositiveEstimators):
"""
"Positive random orthogonal features" softmax estimator,
SM_ort^m+, as proposed in the Performers_ paper, Lemma 1.
_Performers: "Rethinking attention with performers." K. Choromanski et al. (2020).
https://arxiv.org/pdf/2009.14794v1.pdf
"""
@torch.no_grad()
def _get_feature_map(self, dim_input: int, dim_features: int, device: torch.device):
"""
Generate the projection matrix onto the random features
.. note: The heads dimension needs to be taken into account, hence the per-block random matrix
and not uniformally random.
"""
# Get per block random unitary matrices.
# We need enough of them to project the whole input dimension, regardless of the
# requested dimension of the features
features = self._get_random_ortho_matrix(
math.ceil(dim_input / dim_features),
dim_features,
norm_distribution=NormDistribution.Xi,
device=device,
)
return features.flatten(0, 1)[:dim_input]
def forward(self, x: torch.Tensor) -> torch.Tensor:
# Softmax-dimension related scaling, shared for all kernels
x_scaled = super().pre_scale(x)
assert self.features is not None
# Project onto the random feature map.
x_scaled = x_scaled @ self.features
return torch.exp(x_scaled + self.offset)
class SMHyperbolic(SoftMaxPositiveEstimators):
"""
"Positive random features hyperbolic" estimator, SMHyp+,
as proposed in the Performers_ paper, Lemma 1.
_Performers: "Rethinking attention with performers." K. Choromanski et al. (2020).
https://arxiv.org/pdf/2009.14794v1.pdf
"""
def __init__(
self,
dim_features: int,
iter_before_redraw: Optional[int],
normalize_inputs: bool = False,
epsilon: float = 1e-6,
softmax_temp: float = -1,
):
super().__init__(
dim_features, iter_before_redraw, normalize_inputs, epsilon, softmax_temp
)
assert (
dim_features % 2 == 0
), "The feature dimension needs to be even with this kernel"
self.dim_feature_map = self.dim_features // 2
@torch.no_grad()
def _get_feature_map(self, dim_input: int, dim_features: int, device: torch.device):
"""
Generate the projection matrix onto the random features
.. note: The heads dimension needs to be taken into account, hence the per-block random matrix
and not uniformally random.
"""
# Get per block random unitary matrices.
# We need enough of them to project the whole input dimension, regardless of the
# requested dimension of the features
features = self._get_random_ortho_matrix(
math.ceil(dim_input / dim_features),
dim_features,
norm_distribution=NormDistribution.Xi,
device=device,
)
return features.flatten(0, 1)[:dim_input]
def forward(self, x: torch.Tensor) -> torch.Tensor:
# Softmax-dimension related scaling, shared for all kernels
x_scaled = super().pre_scale(x)
# Project onto the random feature map, concatenate both + and - results
# This follows Lemma 1 in the original Performers Paper to best approximate a
# softmax kernel (cosh representation)
x_scaled = x_scaled @ self.features
return torch.cat(
[torch.exp(x_scaled + self.offset), torch.exp(-x_scaled + self.offset)],
dim=-1,
)
class SMReg(SoftMaxPositiveEstimators):
"""
"Regularized softmax kernel" estimator, SMREG+, as proposed in the Performers_ paper.
_Performers: "Rethinking attention with performers." K. Choromanski et al. (2020).
https://arxiv.org/pdf/2009.14794v1.pdf
"""
def __init__(
self,
dim_features: int,
iter_before_redraw: Optional[int],
normalize_inputs: bool = False,
epsilon: float = 1e-6,
softmax_temp: float = -1,
):
super().__init__(
dim_features, iter_before_redraw, normalize_inputs, epsilon, softmax_temp
)
assert (
dim_features % 2 == 0
), "The feature dimension needs to be even with this kernel"
self.dim_feature_map = self.dim_features // 2
@torch.no_grad()
def _get_feature_map(self, dim_input: int, dim_features: int, device: torch.device):
"""
Generate the projection matrix onto the random features
.. note: The heads dimension needs to be taken into account, hence the per-block random matrix
and not uniformally random.
"""
# Get per block random unitary matrices.
# We need enough of them to project the whole input dimension, regardless of the
# requested dimension of the features
features = self._get_random_ortho_matrix(
math.ceil(dim_input / dim_features),
dim_features,
norm_distribution=NormDistribution.Uniform,
device=device,
).flatten(0, 1)
norms = math.sqrt(dim_input) * torch.ones(features.shape[0], device=device)
return (torch.diag(norms) @ features)[:dim_input]
def forward(self, x: torch.Tensor) -> torch.Tensor:
# Softmax-dimension related scaling, shared for all kernels
x_scaled = super().pre_scale(x)
# Project onto the random feature map, concatenate both + and - results
# This follows Lemma 1 in the original Performers Paper to best approximate a
# softmax kernel (cosh representation + sample regularization)
x_scaled = x_scaled @ self.features
return torch.cat(
[torch.exp(x_scaled + self.offset), torch.exp(-x_scaled + self.offset)],
dim=-1,
)

View File

@@ -0,0 +1,35 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
import torch
from torch.cuda.amp import autocast
from xformers.components.attention import Attention, AttentionConfig, register_attention
@register_attention("fourier_mix", AttentionConfig)
class FourierMix(Attention):
def __init__(self, dropout: float, *_, **__):
"""
FFT-based pseudo-attention mechanism, from
"
"FNet: Mixing Tokens with Fourier Transforms"
Lee-Thorp et al., 2021, https://arxiv.org/pdf/2105.03824.pdf
"""
super().__init__()
self.attn_drop = torch.nn.Dropout(dropout, inplace=False)
# Properties specific to this attention mechanism
self.supports_attention_mask = False
self.requires_input_projection = False
def forward(self, q: torch.Tensor, *_, **__):
# Guard against autocast / fp16, not supported by torch.fft.fft2
with autocast(enabled=False):
att = torch.fft.fft2(q).real
att = self.attn_drop(att)
return att

View File

@@ -0,0 +1,122 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
from dataclasses import dataclass
from typing import Optional, Union
import torch
import torch.nn as nn
from xformers.components.attention import (
Attention,
AttentionConfig,
AttentionMask,
maybe_sparsify,
register_attention,
sparsify,
)
from xformers.components.attention.attention_patterns import (
causal_1d_pattern,
global_token_pattern,
)
from xformers.components.attention.core import scaled_dot_product_attention
@dataclass
class GlobalAttentionConfig(AttentionConfig):
attention_query_mask: torch.Tensor # Mark the queries which have global attention
causal: Optional[bool]
force_sparsity: Optional[bool]
@register_attention("global", GlobalAttentionConfig)
class GlobalAttention(Attention):
def __init__(
self,
dropout: float,
attention_query_mask: torch.Tensor,
causal: bool = False,
force_sparsity: bool = False,
*_,
**__,
):
r"""
Global attention, as proposed for instance in BigBird_ or Longformer_.
Global means in that case that the queries positively labelled in the ```attention_query_mask``` can attend
to all the other queries. The queries negatively labelled in the ```attention_query_mask``` cannot attend to
any other query.
This implementation is sparse-aware, meaning that the empty attention parts will not be represented in memory.
Args:
dropout (float): probability of an element to be zeroed
attention_query_mask (torch.Tensor): if true, this query can attend to all the others
"""
super().__init__()
assert attention_query_mask.dtype == torch.bool, "A boolean mask is expected"
assert (
attention_query_mask.shape[1] == 1
and attention_query_mask.shape[0] > attention_query_mask.shape[1]
), "A N x 1 query mask is expected"
self.attn_drop = nn.Dropout(dropout, inplace=False)
self.attention_mask = global_token_pattern(attention_query_mask[:, 0])
self.force_sparsity = force_sparsity
if causal:
self.attention_mask &= causal_1d_pattern(attention_query_mask.shape[1])
self.attention_mask = (
sparsify(self.attention_mask)
if self.force_sparsity
else maybe_sparsify(self.attention_mask)
)
# Properties specific to this attention mechanism
self.requires_same_k_q_dimensions = True
self.supports_attention_mask = False
self.supports_key_padding_mask = False
def forward(
self,
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
att_mask: Optional[Union[torch.Tensor, AttentionMask]] = None,
*_,
**__,
):
# Make sure that the mask is on the right device
if self.attention_mask.device != q.device:
self.attention_mask = self.attention_mask.to(q.device)
# Mask-aware attention
if att_mask is not None:
if att_mask.dtype == torch.bool and isinstance(
self.attention_mask, AttentionMask
):
if not isinstance(att_mask, AttentionMask):
att_mask = AttentionMask.from_bool(att_mask)
mask = self.attention_mask + att_mask
else:
mask = self.attention_mask & att_mask
else:
mask = self.attention_mask
# Handle q/k/v which would not fit the mask
seq_len = q.shape[-2]
q_, k_, v_ = map(lambda x: self._maybe_pad_sequence(x, mask), (q, k, v))
# Normal attention with the global tokens mask
att = scaled_dot_product_attention(
q=q_, k=k_, v=v_, att_mask=mask, dropout=self.attn_drop
)
# Take into account an hypothetical padding
return att[:, :seq_len, :]

View File

@@ -0,0 +1,78 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
from dataclasses import dataclass
import torch
from xformers.components.attention import Attention, AttentionConfig, register_attention
def calc_rel_pos(n: int):
# Adapted from LucidRains
# https://github.com/lucidrains/lambda-networks/blob/main/lambda_networks/lambda_networks.py
rel_pos = torch.arange(n)[None, :] - torch.arange(n)[:, None] # [n, n]
rel_pos += n - 1 # shift value range from [-n+1, n-1] to [0, 2n-2]
return rel_pos
@dataclass
class LambdaLayerConfig(AttentionConfig):
seq_len: int # dimension of the input sequence
dim_head: int
@register_attention("lambda", LambdaLayerConfig)
class LambdaLayer(Attention):
def __init__(self, dropout: float, seq_len: int, dim_head: int, *_, **__):
"""
Attention approximation using Lambda layers, from
"Lambda networks: modeling long-range interactions without attention.", Bello, I. (2021).
"""
super().__init__()
# Possible extensions:
# - support different dimensions for key and queries
# - support varying dimensions in between inputs and outputs
# - support u hyperparam
self.rel_pos_emb = torch.nn.Parameter(
torch.randn(2 * seq_len - 1, int(dim_head))
)
self.rel_pos = calc_rel_pos(seq_len)
self.attn_drop = torch.nn.Dropout(dropout, inplace=True)
# Properties specific to this attention mechanism
self.requires_same_k_q_dimensions = True
self.supports_attention_mask = False
self.supports_key_padding_mask = False
def forward(
self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, *args, **kwargs
):
"""..NOTE: We're reusing the einsum notation suggested by the paper, changed in that
heads are folded in the batch dimension"""
content_lambda = torch.einsum("bnk,bnv->bkv", torch.softmax(k, dim=-1), v)
content_output = torch.einsum("bnk,bkv->bnv", q, content_lambda)
rel_pos_emb = self.rel_pos_emb[self.rel_pos]
# Handle real sequence length being possibly smaller
seq_len = q.shape[1]
rel_pos_emb = rel_pos_emb[:seq_len, :seq_len, :]
# Compute the position lambda for every possible combination in one go, then compute the
# position related contribution
position_lambdas = torch.einsum(
"mnk,bnv->bnkv", rel_pos_emb, v
) # one lambda per position
position_output = (q.unsqueeze(2) @ position_lambdas).squeeze()
att = content_output + position_output
att = self.attn_drop(att)
return att

View File

@@ -0,0 +1,74 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
from dataclasses import dataclass
from typing import Optional
import torch
import torch.nn as nn
from xformers.components.attention import Attention, AttentionConfig, register_attention
from xformers.components.attention.core import scaled_dot_product_attention
@dataclass
class LinformerSelfAttentionConfig(AttentionConfig):
seq_len: int # dimension of the input sequence
k: Optional[int] # dimension of the internal space
@register_attention("linformer", LinformerSelfAttentionConfig)
class LinformerAttention(Attention):
def __init__(
self, dropout: float, seq_len: int, k: Optional[int] = None, *args, **kwargs
):
"""
Linformer attention mechanism,
from `Linformer: Self-Attention with Linear Complexity`_, Wang et al (2020).
The original notation is kept as is.
.. _`Linformer: Self-Attention with Linear Complexity` : https://arxiv.org/abs/2006.04768v2
"""
super().__init__()
if k is None:
k = seq_len // 4
self.k = k
self.E = nn.Linear(seq_len, k, bias=False)
self.F = nn.Linear(seq_len, k, bias=False)
self.attn_drop = nn.Dropout(dropout, inplace=False)
self.seq_len = seq_len
# MHA related flags:
# kq need to have the same dimension
self.requires_same_k_q_dimensions = True
# This attention does not support attention masks
self.supports_attention_mask = False
def forward(
self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, *args, **kwargs
):
# Handle a smaller dimension than expected
padding = 0
if q.shape[1] < self.seq_len:
padding = self.seq_len - q.shape[1]
pad_dims = (0, 0, 0, padding)
q = torch.nn.functional.pad(q, pad_dims)
k = torch.nn.functional.pad(k, pad_dims)
v = torch.nn.functional.pad(v, pad_dims)
k_projected = self.E(k.transpose(-2, -1)).transpose(-2, -1)
v_projected = self.F(v.transpose(-2, -1)).transpose(-2, -1)
y = scaled_dot_product_attention(
q=q, k=k_projected, v=v_projected, att_mask=None, dropout=self.attn_drop
)
y = self.attn_drop(y)
return y[:, :-padding, :] if padding > 0 else y

View File

@@ -0,0 +1,120 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
from dataclasses import dataclass
from typing import Optional, Union
import torch
import torch.nn as nn
from xformers.components.attention import (
Attention,
AttentionConfig,
AttentionMask,
maybe_sparsify,
register_attention,
sparsify,
)
from xformers.components.attention.attention_patterns import (
causal_1d_pattern,
local_1d_pattern,
)
from xformers.components.attention.core import scaled_dot_product_attention
@dataclass
class LocalAttentionConfig(AttentionConfig):
causal: Optional[bool] = None
window_size: Optional[int] = None
force_sparsity: Optional[bool] = None
@register_attention("local", LocalAttentionConfig)
class LocalAttention(Attention):
def __init__(
self,
dropout: float = 0.0,
causal: bool = False,
window_size: int = 5,
force_sparsity: bool = False,
*args,
**kwargs,
):
r"""
An implementation of a sliding window attention, as proposed in RoutingTransformer_, LongFormer_ or BigBird_
Args:
dropout (float): the probability of an output to be randomly dropped at training time
causal (bool): apply a causal mask, in that the attention cannot be applied to the future
window_size (int): the overall window size for local attention.
Odd number is expected if the mask is not causal, as the window size will be evenly
distributed on both sides of each query
.. _RoutingTransformer: https://arxiv.org/pdf/2003.05997.pdf
.. _BigBird: https://arxiv.org/pdf/2007.14062.pdf
.. _Longformer: https://arxiv.org/pdf/2004.05150.pdf
"""
super().__init__()
self.attn_drop = nn.Dropout(dropout, inplace=False)
self.causal = causal
self.force_sparsity = force_sparsity
if not self.causal:
assert (
window_size % 2 == 1
), "The window size is assumed to be odd (counts self-attention + 2 wings)"
self.window_size = window_size
self.attention_mask: Optional[torch.Tensor] = None
self.requires_same_k_q_dimensions = True
# Properties specific to this attention mechanism
self.supports_attention_mask = True
self.supports_key_padding_mask = False
def _get_local_mask(self, shape: torch.Size) -> torch.Tensor:
window_size = self.window_size * 2 + 1 if self.causal else self.window_size
mask = local_1d_pattern(shape[1], window_size)
if self.causal:
mask &= causal_1d_pattern(shape[1])
mask = sparsify(mask) if self.force_sparsity else maybe_sparsify(mask)
return mask
def forward(
self,
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
att_mask: Optional[Union[torch.Tensor, AttentionMask]] = None,
*args,
**kwargs,
):
# Local window attention masking
if self.attention_mask is None or self.attention_mask.shape[1] != q.shape[1]:
self.attention_mask = self._get_local_mask(q.shape).to(q.device)
# Take into account the optional user mask
if att_mask is None:
mask = self.attention_mask
else:
if isinstance(att_mask, AttentionMask):
# Needed because & op not defined for SparseCS with AttentionMask
att_mask = att_mask.to_bool()
mask = self.attention_mask & att_mask
return scaled_dot_product_attention(
q=q, k=k, v=v, att_mask=mask, dropout=self.attn_drop
)

View File

@@ -0,0 +1,295 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
import logging
from dataclasses import dataclass
from typing import Optional
import torch
import torch.nn as nn
from xformers.components.attention import Attention, AttentionConfig, register_attention
from xformers.components.attention.core import (
scaled_dot_product_attention,
scaled_query_key_softmax,
)
from xformers.components.attention.utils import (
bool_mask_to_additive,
iterative_pinv,
reshape_key_padding_mask,
)
logger = logging.getLogger("xformers")
@dataclass
class NystromSelfAttentionConfig(AttentionConfig):
"""
num_heads Number of heads.
num_landmarks Number of landmarks to use for softmax approximation. 64 often sufficient for a good
approximation according to https://arxiv.org/pdf/2102.03902.pdf.
causal Apply a causal mask, in that the attention cannot be applied to the future.
use_razavi_pinverse If true, use iterative method from (Razavi et al. 2014) to approximate the Moore-Penrose
inverse, otherwise use standard torch inverse.
pinverse_original_init True if using original initialization when calculating Moore-Penrose pseudo inverse using
method from (Razavi et al. 2014).
False if using exact coefficient computation (leads to faster convergence).
inv_iterations Number of iterations for calculating the Moore-Penrose pseudo inverse.
v_skip_connection A module that will take V as input and will be added as a skip connection to the
softmax approximation. A skip connection is added in the paper to help with training.
conv_kernel_size Kernel size for convolution optionally added to help in training.
If v_skip_connection is not specified, this will be used to define the default
depth wise convolution used as a skip connection.
If both conv_kernel_size and v_skip_connection are None, no skip connection will
be added.
landmark_pooling Which module to use when computing landmarks. Default is AdaptiveAvgPool2d.
"""
num_heads: int
num_landmarks: Optional[int]
landmark_pooling: Optional[nn.Module]
causal: Optional[bool]
pinverse_original_init: Optional[bool]
inv_iterations: Optional[int]
v_skip_connection: Optional[nn.Module]
conv_kernel_size: Optional[int]
use_razavi_pinverse: Optional[bool]
class AvgPool(nn.Module):
def __init__(self, n: int):
super().__init__()
self.n = n
def forward(self, x: torch.Tensor):
# Average independently for every segment in the sequence dimension
seq_len = x.shape[1]
head_dim = x.shape[2]
segments = seq_len // self.n
assert segments > 0, "num_landmarks should be smaller than the sequence length"
# Dimensions are a match
if seq_len % self.n == 0:
return x.reshape(
-1,
self.n,
segments,
head_dim,
).mean(dim=-2)
# Handle the last segment boundary being off
n_round = self.n - seq_len % self.n
x_avg_round = (
x[:, : n_round * segments, :]
.reshape(-1, n_round, segments, head_dim)
.mean(dim=-2)
)
x_avg_off = (
x[:, n_round * segments :, :]
.reshape(-1, self.n - n_round, segments + 1, head_dim)
.mean(dim=-2)
)
return torch.cat((x_avg_round, x_avg_off), dim=-2)
@register_attention("nystrom", NystromSelfAttentionConfig)
class NystromAttention(Attention):
# TODO: update defaults for use_razavi_pinverse and inv_iterations
def __init__(
self,
dropout: float,
num_heads: int,
num_landmarks: int = 64,
landmark_pooling: Optional[nn.Module] = None,
causal: bool = False,
use_razavi_pinverse: bool = True,
pinverse_original_init: bool = False,
inv_iterations: int = 6, # recommended default in paper was 6.
v_skip_connection: Optional[nn.Module] = None,
conv_kernel_size: Optional[int] = None,
*args,
**kwargs,
):
"""
Nystrom attention mechanism, from Nystromformer_.
::
"A Nystrom-based Algorithm for Approximating Self-Attention."
Xiong, Y., Zeng, Z., Chakraborty, R., Tan, M., Fung, G., Li, Y., Singh, V. (2021)
Reference codebase: https://github.com/mlpen/Nystromformer
.. _Nystromformer: https://arxiv.org/pdf/2102.03902.pdf
"""
super().__init__()
# merged key padding mask and attention mask is not accepted
self.requires_separate_masks = True
self.num_landmarks = num_landmarks
# TODO: should be able to not have to pass in num_heads
self.num_heads = num_heads
self.use_razavi_pinverse = use_razavi_pinverse
self.pinverse_original_init = pinverse_original_init
self.inv_iterations = inv_iterations
self.attn_drop = nn.Dropout(dropout)
self.skip_connection = v_skip_connection
self.causal = causal
if self.skip_connection is None and conv_kernel_size is not None:
self.skip_connection = nn.Conv2d(
in_channels=self.num_heads,
out_channels=self.num_heads,
kernel_size=(conv_kernel_size, 1),
padding=(conv_kernel_size // 2, 0),
bias=False,
groups=self.num_heads,
)
if landmark_pooling is not None:
self.landmark_pooling = landmark_pooling
else:
self.landmark_pooling = AvgPool(n=self.num_landmarks)
# Optional lower triangular masks for causal attention
self.causal_mask_1: Optional[torch.Tensor] = None
self.causal_mask_2: Optional[torch.Tensor] = None
self.causal_mask_3: Optional[torch.Tensor] = None
# This attention does not support attention masks
self.supports_attention_mask = False
self.supports_key_padding_mask = True
def forward(
self,
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
key_padding_mask: Optional[torch.Tensor] = None,
*args,
**kwargs,
):
r"""
key_padding_mask Only a key padding mask is accepted here. The size must be (batch size, sequence length) or
(batch size * num_heads, 1, sequence length). If dimensions are not correct, the mask will
be ignored. An additive mask is expected, meaning float values using "-inf" to mask values
"""
batched_dim = k.size(0)
seq_len = k.size(-2)
tt = {"dtype": q.dtype, "device": q.device}
if key_padding_mask is not None:
if key_padding_mask.dtype == torch.bool:
logger.warning(
"Bool mask found, but an additive mask is expected. Converting but this is slow"
)
key_padding_mask = bool_mask_to_additive(key_padding_mask)
if key_padding_mask.ndim == 2:
key_padding_mask = reshape_key_padding_mask(
key_padding_mask, batched_dim
)
zeros = torch.zeros_like(key_padding_mask)
ones = torch.ones_like(key_padding_mask)
is_masked = torch.isinf(-key_padding_mask)
# _mask takes 1 if the token is not padded, otherwise 0.
_mask = torch.where(is_masked, zeros, ones)
_mask = _mask.transpose(2, 1)
assert _mask.shape == (batched_dim, q.shape[1], 1)
# Mask q and k before pooling
# https://github.com/mlpen/Nystromformer/blob/main/code/attention_nystrom.py#L31
q = q * _mask
k = k * _mask
assert key_padding_mask.size() == (batched_dim, 1, seq_len), (
f"key_padding_mask has invalid dimensions {key_padding_mask.size()}."
f" Must have dimensions {batched_dim, 1, seq_len} or (batch_size, {seq_len})."
)
if self.num_landmarks >= seq_len:
mask: Optional[torch.Tensor] = None
if self.causal:
mask = self._triu_mask(batched_dim, seq_len, seq_len, **tt)
if key_padding_mask is not None:
mask = key_padding_mask if mask is None else mask + key_padding_mask
x = scaled_dot_product_attention(q=q, k=k, v=v, att_mask=mask)
else:
q_landmarks = self.landmark_pooling(q)
k_landmarks = self.landmark_pooling(k)
if self.causal and (
self.causal_mask_1 is None
or (batched_dim, seq_len, self.num_landmarks)
!= self.causal_mask_1.size()
):
self.causal_mask_1 = self._triu_mask(
batched_dim, seq_len, self.num_landmarks, **tt
)
self.causal_mask_2 = self._triu_mask(
batched_dim, self.num_landmarks, self.num_landmarks, **tt
)
self.causal_mask_3 = self._triu_mask(
batched_dim, self.num_landmarks, seq_len, **tt
)
mask_3: Optional[torch.Tensor] = self.causal_mask_3
if key_padding_mask is not None:
mask_3 = (
key_padding_mask if mask_3 is None else mask_3 + key_padding_mask
)
kernel_1 = scaled_query_key_softmax(q=q, k=k_landmarks, att_mask=None)
kernel_2 = scaled_query_key_softmax(
q=q_landmarks, k=k_landmarks, att_mask=None
)
kernel_3 = scaled_dot_product_attention(
q=q_landmarks, k=k, v=v, att_mask=mask_3
)
kernel_2_inv = (
iterative_pinv(
kernel_2, self.inv_iterations, self.pinverse_original_init
)
if self.use_razavi_pinverse
else torch.linalg.pinv(kernel_2)
)
x = torch.matmul(
torch.matmul(
kernel_1,
kernel_2_inv,
),
kernel_3,
)
if self.skip_connection:
# Assumption here is that v is 3D.
v_conv = self.skip_connection(
v.reshape(-1, self.num_heads, v.size(-2), v.size(-1))
)
x += v_conv.reshape(-1, v_conv.size(-2), v_conv.size(-1))
x = self.attn_drop(x)
return x
def _triu_mask(self, dim_1: int, dim_2: int, dim_3: int, **kwargs) -> torch.Tensor:
device = kwargs["device"]
dtype = kwargs["dtype"]
return torch.triu(
torch.ones(dim_2, dim_3, dtype=dtype, device=device) * float("-inf"),
diagonal=1,
).expand(
dim_1, -1, -1
) # micro optim, save memory on the batch dimension

View File

@@ -0,0 +1,324 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
import logging
from dataclasses import dataclass
from enum import Enum
from typing import Optional, Union
import torch
import torch.autograd.profiler as profiler
import torch.nn as nn
import torch.nn.functional as Fn
from xformers.components.attention import (
Attention,
AttentionConfig,
AttentionMask,
register_attention,
)
from xformers.components.attention.core import (
scaled_dot_product_attention,
scaled_query_key_softmax,
)
logger = logging.getLogger("xformers")
class LandmarkSelection(str, Enum):
Orthogonal = "orthogonal"
KMeans = "kmeans"
KMeans_Spherical = "kmeans_spherical"
Random = "random"
@dataclass
class OrthoformerAttentionConfig(AttentionConfig):
"""
num_landmarks Number of landmarks to use for softmax approximation.
subsample_fraction Percentage of q_samples matrix to sample per iteration
landmark_selection Landmark selection strategy
"""
num_landmarks: Optional[int]
subsample_fraction: Optional[float]
landmark_selection: Optional[LandmarkSelection]
@register_attention("orthoformer", OrthoformerAttentionConfig)
class OrthoFormerAttention(Attention):
def __init__(
self,
dropout: float,
num_landmarks: int = 32,
subsample_fraction: float = 1.0,
landmark_selection: LandmarkSelection = LandmarkSelection.Orthogonal,
*args,
**kwargs,
):
"""
Orthoformer_ attention mechanism.
::
"Keeping Your Eye on the Ball: Trajectory Attention in Video Transformers"
Patrick, M., Campbell, D., Asano, Y., Misra, I., Metze, F., Feichtenhofer,
C., Vedaldi, A., Henriques, J. (2021)
Reference codebase: https://github.com/facebookresearch/Motionformer
.. _Orthoformer: https://arxiv.org/abs/2106.05392
"""
super().__init__()
self.num_landmarks = num_landmarks
self.attn_drop = nn.Dropout(dropout)
self.subsample_fraction = subsample_fraction
self.landmark_selection = landmark_selection
# Properties specific to this attention mechanism
self.supports_attention_mask = True
self.supports_key_padding_mask = False
def forward(
self,
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
att_mask: Optional[Union[AttentionMask, torch.Tensor]] = None,
*args,
**kwargs,
):
N = k.shape[1]
if self.num_landmarks == N:
# Default attention
x = scaled_dot_product_attention(q, k, v, att_mask)
else:
with torch.no_grad(), profiler.record_function("select landmarks"):
if self.landmark_selection == LandmarkSelection.Orthogonal:
landmarks = self._compute_orthogonal_landmarks(q)
elif self.landmark_selection == LandmarkSelection.Random:
half_L = self.num_landmarks // 2
landmarks_q = q[:, torch.randint(q.size(1), (half_L,)), :]
landmarks_k = k[:, torch.randint(k.size(1), (half_L,)), :]
landmarks = torch.cat((landmarks_q, landmarks_k), dim=-2)
elif self.landmark_selection == LandmarkSelection.KMeans:
landmarks = self._cluster_landmarks(q)
elif self.landmark_selection == LandmarkSelection.KMeans_Spherical:
landmarks = self._cluster_landmarks(q, spherical=True)
if att_mask is not None:
logger.warning(
"Orthoformer: attention mask passed alongside with using landmarks to reduce dimensions. \
The two are typically not compatible"
)
# FIXME: Should we still accept a mask in that case ?
att_mask = None
# pyre-ignore[61]: TODO(T103337542): `landmarks` mistakenly seems
# like it could be uninitialized.
kernel_1 = scaled_query_key_softmax(q, landmarks, att_mask)
# pyre-ignore[61]: TODO(T103337542): `landmarks` mistakenly seems
# like it could be uninitialized.
kernel_2 = scaled_query_key_softmax(landmarks, k, att_mask)
x = torch.matmul(kernel_1, torch.matmul(kernel_2, v))
x = self.attn_drop(x)
return x
def _cluster_landmarks(
self,
q: torch.Tensor,
spherical: bool = False,
num_iters: int = 6,
) -> torch.Tensor:
"""
Construct set of landmarks by recursively selecting new landmarks
that are maximally orthogonal to the existing set.
Returns near orthogonal landmarks with shape (B, M, D).
"""
num_landmarks = min(self.num_landmarks, q.shape[1])
if self.subsample_fraction < 1.0:
num_samples = max(
int(self.subsample_fraction * q.size(-2)), num_landmarks
) # Need at least M/2 samples of queries and keys
q_samples = q[:, torch.randint(q.size(-2), (num_samples,)), :] # (B, N, D)
else:
q_samples = q # (B, N, D)
if spherical:
q_samples_normalized = Fn.normalize(
q_samples, p=2, dim=-1
) # may need to change default eps to eps=1e-8 for mixed precision compatibility
landmarks = self._kmeans_spherical(
q_samples_normalized, num_landmarks, num_iters
)
else:
landmarks = self._kmeans(q_samples, num_landmarks, num_iters)
return landmarks # (B, M, D)
def _kmeans(self, x: torch.Tensor, K: int, num_iters: int = 10):
"""
Arguments:
x: (B, N, D)
K: number of clusters
num_iters: the number of kmeans updates
"""
B, N, D = x.size()
assert K <= N, f"{K} > {N}"
c = x[
:, torch.randperm(N, device=x.device)[:K], :
].clone() # initialisation for the centroids
with profiler.record_function("kmeans"):
x_i = x.view(B, N, 1, D)
c_j = c.view(B, 1, K, D)
counts = c.new_zeros(B, K)
ones = x.new_ones((B, N))
for _ in range(num_iters):
# E step: assign points to the nearest cluster
D_ij = ((x_i - c_j) ** 2).sum(-1) # (B, N, K) squared distances
cl = D_ij.argmin(
dim=-1, keepdim=True
).long() # (B, N, 1) index of point to nearest cluster
# M step: update the centroids
c.zero_()
c.scatter_add_(-2, cl.repeat(1, 1, D), x) # sum of points per cluster
counts.fill_(1e-6) # avoid div0
counts.scatter_add_(
-1, cl.squeeze(-1), ones
) # number of points per cluster
c.divide_(counts.unsqueeze(-1)) # compute the average
return c
def _kmeans_spherical(self, x: torch.Tensor, K: int, num_iters=10):
"""
Arguments:
x: (B, N, D)
"""
B, N, D = x.size()
assert K <= N, f"{K} > {N}"
# initialisation for the centroids
c = x[:, torch.randperm(N, device=x.device)[:K], :].clone()
with profiler.record_function("kmeans_spherical"):
counts = c.new_zeros(B, K)
ones = x.new_ones((B, N))
for _ in range(num_iters):
# E step: assign points to the nearest cluster
D_ij = torch.matmul(
x, c.transpose(-2, -1)
) # (B, N, K) cosine similarity
cl = D_ij.argmax(
dim=-1, keepdim=True
).long() # (B, N, 1) index of point to nearest cluster
# M step: update the centroids
c.zero_()
c.scatter_add_(-2, cl.repeat(1, 1, D), x) # sum of points per cluster
counts.fill_(1e-6) # avoid div0
counts.scatter_add_(
-1, cl.squeeze(-1), ones
) # number of points per cluster
c.divide_(counts.unsqueeze(-1)) # compute the average
c = Fn.normalize(c, p=2, dim=-1) # renormalise
return c
def _compute_orthogonal_landmarks(self, q: torch.Tensor) -> torch.Tensor:
"""
Construct set of landmarks by recursively selecting new landmarks
that are maximally orthogonal to the existing set.
Returns near orthogonal landmarks with shape (B, M, D).
"""
if self.subsample_fraction < 1.0:
# Need at least M samples of queries
num_samples = max(
int(self.subsample_fraction * q.size(-2)), self.num_landmarks
)
q_samples = q[
:, torch.randint(q.size(-2), (num_samples,), device=q.device), :
]
else:
# (B, N, D)
q_samples = q
# may need to change default eps to eps=1e-8 for mixed precision compatibility
q_samples_normalized = Fn.normalize(q_samples, p=2, dim=-1)
B, N, D = q_samples_normalized.shape
selected_mask = torch.zeros((B, N, 1), device=q_samples_normalized.device)
landmark_mask = torch.ones(
(B, 1, 1), dtype=selected_mask.dtype, device=q_samples_normalized.device
)
#  Get initial random landmark
random_idx = torch.randint(
q_samples_normalized.size(-2), (B, 1, 1), device=q_samples_normalized.device
)
selected_mask.scatter_(-2, random_idx, landmark_mask)
#  Selected landmarks
selected_landmarks = torch.empty(
(B, self.num_landmarks, D),
device=q_samples_normalized.device,
dtype=q_samples_normalized.dtype,
)
selected_landmarks[:, 0, :] = q_samples_normalized[
torch.arange(q_samples_normalized.size(0)), random_idx.view(-1), :
].view(B, D)
# Store computed cosine similarities
cos_sims = torch.empty(
(B, N, self.num_landmarks),
device=q_samples_normalized.device,
dtype=q_samples_normalized.dtype,
)
for M in range(1, self.num_landmarks):
with profiler.record_function("find new landmark"):
#  Calculate absolute cosine similarity between selected and unselected landmarks
# (B, N, D) * (B, D) -> (B, N)
cos_sims[:, :, M - 1] = torch.einsum(
"b n d, b d -> b n",
q_samples_normalized,
selected_landmarks[:, M - 1, :],
).abs()
# (B, N, M) cosine similarities of current set of landmarks wrt all queries and keys
cos_sim_set = cos_sims[:, :, :M]
#  Get orthogonal landmark: landmark with smallest absolute cosine similarity:
# set cosine similarity for already selected landmarks to > 1
cos_sim_set.view(-1, M)[selected_mask.flatten().bool(), :] = 10
# (B,) - want max for non
selected_landmark_idx = cos_sim_set.amax(-1).argmin(-1)
#  Add most orthogonal landmark to selected landmarks:
selected_landmarks[:, M, :] = q_samples_normalized[
torch.arange(q_samples_normalized.size(0)), selected_landmark_idx, :
].view(B, D)
#  Removed selected indices from non-selected mask:
selected_mask.scatter_(
-2, selected_landmark_idx.unsqueeze(-1).unsqueeze(-1), landmark_mask
)
# (B, M, D)
landmarks = torch.masked_select(q_samples, selected_mask.bool()).reshape(
B, -1, D
)
return landmarks # (B, M, D)

View File

@@ -0,0 +1,82 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
import math
from dataclasses import dataclass
from typing import Optional
import torch
import torch.nn as nn
from xformers.components.attention import Attention, AttentionConfig, register_attention
@dataclass
class PoolingAttentionConfig(AttentionConfig):
pool_size: int # dimension of the input sequence
stride: Optional[int] # dimension of the internal space
padding: Optional[int]
@register_attention("pooling", PoolingAttentionConfig)
class Pooling(Attention):
def __init__(
self,
pool_size: int = 3,
stride: int = 1,
padding: Optional[int] = None,
*_,
**__,
):
"""
Pooling token mixing mechanism, as proposed in
`Metaformer is actually what you need for vision`_, Yu et al (2021).
The original notation is kept as is.
.. _`Metaformer is actually what you need for vision` : https://arxiv.org/pdf/2111.11418v1.pdf
"""
super().__init__()
padding = padding if padding is not None else pool_size // 2
self.pool = nn.AvgPool2d(
pool_size,
stride=stride,
padding=pool_size // 2,
count_include_pad=False,
)
# MHA related flags:
# kq need to have the same dimension
self.requires_same_k_q_dimensions = False
# This attention does not support attention masks
self.supports_attention_mask = False
# This "attention" (token mixing) skips the multihead attention altogether
self.requires_skip_multi_head = True
self.requires_input_projection = False
# This operator does not really handle q,k,v
self.requires_same_k_q_dimensions = True
# This attention requires the 2d structure out of the context,
# implictly assumed to be a squared length
self.requires_squared_context = True
def forward(self, q: torch.Tensor, *_, **__):
# Expose the 2D token structure
B, HW, C = q.shape
H = int(math.sqrt(HW))
assert H * H == HW
q = q.transpose(-2, -1).reshape(B, C, H, H)
# 2D pool
x_pool = self.pool(q) - q # compensate for the residual path
# Get back to B HW C
return x_pool.flatten(2, 3).transpose(-2, -1)

View File

@@ -0,0 +1,126 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
from dataclasses import dataclass
from typing import Optional, Union
import torch
import torch.nn as nn
from xformers.components.attention import (
Attention,
AttentionConfig,
AttentionMask,
maybe_sparsify,
register_attention,
sparsify,
)
from xformers.components.attention.attention_patterns import (
causal_1d_pattern,
random_pattern,
)
from xformers.components.attention.core import scaled_dot_product_attention
@dataclass
class RandomAttentionConfig(AttentionConfig):
r: Optional[
float
] # the ratio of keys that the query can attend to. 1.0 means dense attention
constant_masking: Optional[
bool
] # whether the randomness is per query or defined at construction time
force_sparsity: Optional[bool] # use sparsity in any case (potentially slower)
@register_attention("random", RandomAttentionConfig)
class RandomAttention(Attention):
def __init__(
self,
dropout: float,
causal: bool = False,
r: float = 0.01,
constant_masking: bool = True,
force_sparsity: bool = False,
*args,
**kwargs,
):
"""
"Random" attention, as proposed for instance in BigBird_.
Random means in that case that each query can attend to a random set of keys.
This implementation is sparse-aware, meaning that the empty attention parts will not be represented in memory.
Args:
r (float): the ratio in [0,1] of keys that the query can attend to
constant_masking (bool): if true, keep the same random set for all queries.
.. _BigBird: https://arxiv.org/pdf/2007.14062.pdf
"""
super().__init__()
self.attn_drop = nn.Dropout(dropout, inplace=False)
self.causal = causal
self.r = r
self.rand_attention_mask: Optional[torch.Tensor] = None
self.constant_masking = constant_masking
self.force_sparsity = force_sparsity
# Properties specific to this attention mechanism
self.supports_attention_mask = True
self.supports_key_padding_mask = False
self.requires_same_k_q_dimensions = True
def _get_rand_mask(self, shape: torch.Size) -> torch.Tensor:
sparsity = 1 - self.r
mask = random_pattern(shape[1], sparsity=sparsity)
if self.causal:
mask &= causal_1d_pattern(shape[1])
mask = sparsify(mask) if self.force_sparsity else maybe_sparsify(mask)
return mask
def forward(
self,
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
att_mask: Optional[Union[torch.Tensor, AttentionMask]] = None,
*args,
**kwargs,
):
# Rand masking
if not self.constant_masking or self.rand_attention_mask is None:
self.rand_attention_mask = self._get_rand_mask(q.shape).to(q.device)
# Mask-aware attention
if att_mask is not None:
if att_mask.dtype == torch.bool and isinstance(
self.rand_attention_mask, AttentionMask
):
mask = self.rand_attention_mask + AttentionMask.from_bool(att_mask)
else:
if isinstance(att_mask, AttentionMask):
# Needed because & op not defined for SparseCS with AttentionMask
att_mask = att_mask.to_bool()
mask = self.rand_attention_mask & att_mask
else:
mask = self.rand_attention_mask
# Handle q/k/v which would not fit the mask
seq_len = q.shape[-2]
q_, k_, v_ = map(lambda x: self._maybe_pad_sequence(x, mask), (q, k, v))
# Normal attention with the random mask
att = scaled_dot_product_attention(
q=q_, k=k_, v=v_, att_mask=mask, dropout=self.attn_drop
)
# Take into account an hypothetical padding
return att[:, :seq_len, :]

View File

@@ -0,0 +1,134 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
import logging
from dataclasses import dataclass
from typing import Optional, Union
import torch
from torch import nn
from xformers.components.attention import (
Attention,
AttentionConfig,
AttentionMask,
register_attention,
)
from xformers.components.attention.core import scaled_dot_product_attention
logger = logging.getLogger("xformers")
@dataclass
class ScaledDotProductConfig(AttentionConfig):
causal: Optional[bool]
seq_len: Optional[int]
to_seq_len: Optional[int]
@register_attention("scaled_dot_product", ScaledDotProductConfig)
class ScaledDotProduct(Attention):
r"""
Implementing the Scaled Dot-Product attention proposed in
`Attention is all you need`_, Vaswani et al.
.. _`Attention is all you need`: https://arxiv.org/abs/1706.03762v5
"""
mask: Optional[AttentionMask]
def __init__(
self,
dropout: float = 0.0,
causal: bool = False,
seq_len: Optional[int] = None,
to_seq_len: Optional[int] = None,
*args,
**kwargs,
):
super().__init__()
self.attn_drop = nn.Dropout(dropout, inplace=False)
self.causal = causal
self.seq_len = seq_len
if causal and seq_len is not None:
self.mask = AttentionMask.make_causal(seq_len, to_seq_len)
else:
self.mask = None
# Properties specific to this attention mechanism
self.supports_attention_mask = True
self.supports_key_padding_mask = False
def forward(
self,
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
att_mask: Optional[Union[AttentionMask, torch.Tensor]] = None,
*args,
**kwargs,
) -> torch.Tensor:
r"""
att_mask A 2D or 3D mask which ignores attention at certain positions.
- If the mask is boolean, a value of True will keep the value,
while a value of False will mask the value.
Key padding masks (dimension: batch x sequence length) and attention masks
(dimension: sequence length x sequence length OR batch x sequence length x sequence length)
can be combined and passed in here. Method maybe_merge_masks provided in the utils can be
used for that merging.
- If the mask has the float type, then an additive mask is expected (masked values are -inf)
"""
# Convenience, create an attention mask if a tensor was passed
if att_mask is not None and isinstance(att_mask, torch.Tensor):
# By default we don't know of the causality, and a check would be expensive
att_mask = (
AttentionMask.from_bool(att_mask)
if att_mask.dtype == torch.bool
else AttentionMask(att_mask, is_causal=False)
)
# Handle a possibly deferred causal mask handling
mask = self.mask
if self.causal and self.mask is None:
mask = AttentionMask.make_causal(
seq_len=q.shape[-2],
to_seq_len=q.shape[-2],
device=q.device,
dtype=q.dtype,
)
# Merge the optional causal mask and the user-provided mask
if mask is not None:
mask = mask.to(dtype=q.dtype, device=q.device)
att_mask = att_mask + mask if att_mask is not None else mask
# Try to handle a case where the sequence is smaller than the mask
if (
att_mask is not None
and q.shape[-2] == k.shape[-2]
and q.shape[-2] < att_mask.shape[1]
):
if isinstance(att_mask, AttentionMask):
att_mask = att_mask.make_crop(seq_len=q.shape[-2])
else:
logger.error(
"Mismatching sparse attention mask and sequence length."
+ " Please pad the inputs or adjust the attention mask"
)
raise NotImplementedError
# Attend: (B x nh, S, hs) x (B x nh, hs, S) -> (B x nh, S, S)
y = scaled_dot_product_attention(
q=q, k=k, v=v, att_mask=att_mask, dropout=self.attn_drop
)
return y

View File

@@ -0,0 +1,812 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
"""
The code has been adopted from DeepSpeed
(https://github.com/microsoft/DeepSpeed/blob/master/deepspeed/ops/sparse_attention/sparsity_config.py)
"""
import random
import torch
class SparsityConfig:
"""Abstract Configuration class to store `sparsity configuration of a self attention layer`.
It contains shared property of different block-sparse sparsity patterns. However, each class
needs to extend it based on required property and functionality.
"""
def __init__(self, num_heads, block_size=16, different_layout_per_head=False):
"""Initialize the Sparsity Pattern Config.
Arguments:
num_heads: required: an integer determining number of attention heads of the layer.
block_size: optional: an integer determining the block size. Current implementation of
sparse self-attention is based on blocked sparse matrices. In which this parameter
defines size of such blocks, `Block X Block`.
different_layout_per_head: optional: a boolean determining if each head should be
assigned a different sparsity layout; default is false and this will be satisfied
based on availability.
"""
self.num_heads = num_heads
self.block_size = block_size
self.different_layout_per_head = different_layout_per_head
self.num_layout_heads = num_heads if different_layout_per_head else 1
def setup_layout(self, seq_len):
"""Create layout tensor for the given sequence length
Arguments:
seq_len: required: an integer determining number of attention heads of the layer.
Return:
layout: a tensor of dimension (num_heads, num_blocks, num_blocks) for sparsity layout
of all head; initialized with zero
"""
if seq_len % self.block_size != 0:
raise ValueError(
f"Sequence Length, {seq_len}, needs to be dividable by Block size {self.block_size}!"
)
num_blocks = seq_len // self.block_size
# TODO Currently we allocate layout per head; needs to be updated if heads share a single layout.
layout = torch.zeros(
(self.num_heads, num_blocks, num_blocks), dtype=torch.int64
)
return layout
def check_and_propagate_first_head_layout(self, layout):
"""If all heads require same sparsity layout, it propagate first head layout to all heads
Arguments:
layout: required: a tensor of dimension (num_heads, num_blocks, num_blocks) containing
sparsity layout of all head; may not be completely set at this step
Return:
layout: a tensor of dimension (num_heads, num_blocks, num_blocks) containing sparsity
layout of all head
"""
if not self.different_layout_per_head:
layout[1 : self.num_heads, :, :] = layout[0, :, :]
return layout
class DenseSparsityConfig(SparsityConfig):
"""Configuration class to store `Dense` configuration.
In reality, this is not sparse and all blocks are used. We keep it for the sake of comparison and
comprehension.
"""
def __init__(self, num_heads, block_size=16, different_layout_per_head=False):
"""Initialize the Dense Sparsity Pattern Config.
In reality, this is not sparse and all blocks are used. We keep it for the sake of comparison
and comprehension.
Arguments:
num_heads: required: an integer determining number of attention heads of the layer.
block_size: optional: an integer determining the block size. Current implementation of
sparse self-attention is based on blocked sparse matrices. In which this parameter
defines size of such blocks, `Block X Block`.
different_layout_per_head: optional: this is just for the sake of consistency with
other sparsity formats; can ignore it for DenseSparsityConfig
"""
super().__init__(num_heads, block_size, different_layout_per_head)
def make_layout(self, seq_len):
"""Set 1 to all blocks of the layout meanins the pattern is dense; not sparse.
Arguments:
seq_len: required: an integer determining the underling sequence length;
must be <= max sequence length
Return:
layout: a tensor of dimension (num_heads, num_blocks, num_blocks) containing sparsity
layout of all head; for dense everything is 1
"""
layout = self.setup_layout(seq_len)
layout[:, :, :] = 1
return layout
class FixedSparsityConfig(SparsityConfig):
"""Configuration class to store `Fixed` sparsity configuration.
For more details about this sparsity config, please see `Generative Modeling with
Sparse Transformers`: https://arxiv.org/abs/1904.10509; this has been customized.
This class extends parent class of `SparsityConfig` and customizes it for `Fixed` sparsity.
"""
def __init__(
self,
num_heads,
block_size=16,
different_layout_per_head=False,
num_local_blocks=4,
num_global_blocks=1,
attention="bidirectional",
horizontal_global_attention=False,
num_different_global_patterns=1,
):
"""Initialize `Fixed` Sparsity Pattern Config.
For usage example please see, TODO DeepSpeed Sparse Transformer Tutorial
Arguments:
num_heads: required: an integer determining number of attention heads of the layer.
block_size: optional: an integer determining the block size. Current implementation of
sparse self-attention is based on blocked sparse matrices. In which this parameter
defines size of such blocks, `Block X Block`.
different_layout_per_head: optional: a boolean determining if each head should be
assigned a different sparsity layout; default is false and this will be satisfied
based on availability.
num_local_blocks: optional: an integer determining the number of blocks in local attention
window.
num_global_blocks: optional: an integer determining how many consecutive blocks in a local
window is used as the representative of the window for global attention.
attention: optional: a string determining attention type. Attention can be `unidirectional`,
such as autoregressive models, in which tokens attend only to tokens appear before them
in the context. Considering that, the upper triangular of attention matrix is empty as
above figure. Or it can be `bidirectional`, such as BERT, in which tokens can attend to
any other tokens before or after them. Then, the upper triangular part of the attention
matrix is mirror of the lower triangular in the above figure.
horizontal_global_attention: optional: a boolean determining if blocks that are global
representative of a local window, also attend to all other blocks. This is valid only if
attention type is `bidirectional`. Looking at the attention matrix, that means global
attention not only includes the vertical blocks, but also horizontal blocks.
num_different_global_patterns: optional: an integer determining number of different global
attentions layouts. While global attention can be fixed by which block/s are representative
of any local window, since there are multi-heads, each head can use a different global representative.
For example, with 4 blocks local window and global attention size of 1 block, we can have 4 different
versions in which the first, Second, third, or forth block of each local window can be global
representative of that window. This parameter determines how many of such patterns we want.
Of course, there is a limitation based on num_local_blocks and num_global_blocks.
"""
super().__init__(num_heads, block_size, different_layout_per_head)
self.num_local_blocks = num_local_blocks
if num_local_blocks % num_global_blocks != 0:
raise ValueError(
f"""Number of blocks in a local window, {num_local_blocks},
must be dividable by number of global blocks, {num_global_blocks}!"""
)
self.num_global_blocks = num_global_blocks
if attention != "unidirectional" and attention != "bidirectional":
raise NotImplementedError(
'only "uni/bi-directional" attentions are supported for now!'
)
self.attention = attention
if attention != "bidirectional" and horizontal_global_attention:
raise ValueError(
'only "bi-directional" attentions can support horizontal global attention!'
)
self.horizontal_global_attention = horizontal_global_attention
if num_different_global_patterns > 1 and not different_layout_per_head:
raise ValueError(
"""Number of different layouts cannot be more than one when you have set a single layout
for all heads! Set different_layout_per_head to True."""
)
if num_different_global_patterns > (num_local_blocks // num_global_blocks):
raise ValueError(
f"""Number of layout versions (num_different_global_patterns), {num_different_global_patterns},
cannot be larger than number of local window blocks divided by number of global blocks,
{num_local_blocks} / {num_global_blocks} = {num_local_blocks//num_global_blocks}!"""
)
self.num_different_global_patterns = num_different_global_patterns
def set_local_layout(self, h, layout):
"""Sets local attention layout used by the given head in the sparse attention.
Arguments:
h: required: an integer determining head index
layout: required: a tensor of dimension (num_heads, num_blocks, num_blocks) containing
sparsity layout of all head; may not be completely set at this step
Return:
layout: a tensor of dimension (num_heads, num_blocks, num_blocks) containing sparsity
layout of all head in which local layout is set
"""
num_blocks = layout.shape[1]
for i in range(0, num_blocks, self.num_local_blocks):
end = min(i + self.num_local_blocks, num_blocks)
for row in range(i, end):
for col in range(
i, (row + 1 if self.attention == "unidirectional" else end)
):
layout[h, row, col] = 1
return layout
def set_global_layout(self, h, layout):
"""Sets global attention layout used by the given head in the sparse attention.
Currently we set global blocks starting from the last block of a local window to the first one.
That means if a local window consists of 4 blocks and global attention size is one block, we use
block #4 in each local window as global. If we have different layout per head, then other heads
will get #3, #2, and #1. And if we have more heads (and different layout has set) than num of global
attentions, multiple head may have same global attentions.
Note) if horizontal_global_attention is set, global blocks will be set both horizontally and
vertically.
Arguments:
h: required: an integer determining head index
layout: required: a tensor of dimension (num_heads, num_blocks, num_blocks) containing
sparsity layout of all head; may not be completely set at this step
Return:
layout: a tensor of dimension (num_heads, num_blocks, num_blocks) containing sparsity
layout of all head in which global layout is set
"""
num_blocks = layout.shape[1]
first_global_block_idx = (
self.num_local_blocks
- (1 + h % self.num_different_global_patterns) * self.num_global_blocks
)
# set all global blocks except the last one if (in last local window)
end = num_blocks - (num_blocks % self.num_local_blocks)
for i in range(first_global_block_idx, end, self.num_local_blocks):
# vertical global attention
first_row = 0 if self.attention == "bidirectional" else i
# (((i // self.num_local_blocks) + 1) * self.num_local_blocks)
# if (first_row < num_blocks):
layout[h, first_row:, i : i + self.num_global_blocks] = 1
# horizontal global attention; only in bidirectional attention
if self.horizontal_global_attention:
layout[h, i : i + self.num_global_blocks, :] = 1
# set last global blocks; handle possible short last local window
if end < num_blocks:
start = min(
end + first_global_block_idx, num_blocks - self.num_global_blocks
)
end = start + self.num_global_blocks
# vertical global attention
first_row = 0 if self.attention == "bidirectional" else start
# (((start // self.num_local_blocks) + 1) * self.num_local_blocks)
# if (first_row < num_blocks):
layout[h, first_row:, start:end] = 1
# horizontal global attention
if self.horizontal_global_attention:
layout[h, start:end, :] = 1
return layout
def make_layout(self, seq_len):
"""Generates `Fixed` sparsity layout used by each head in the sparse attention.
Arguments:
seq_len: required: an integer determining number of attention heads of the layer.
Return:
layout: a tensor of dimension (num_heads, num_blocks, num_blocks) containing `Fixed`
sparsity layout of all head
"""
layout = self.setup_layout(seq_len)
for h in range(0, self.num_layout_heads):
layout = self.set_local_layout(h, layout)
layout = self.set_global_layout(h, layout)
layout = self.check_and_propagate_first_head_layout(layout)
return layout
class VariableSparsityConfig(SparsityConfig):
"""Configuration class to store `Variable` sparsity configuration.
This layout is an extension of FixedSparsityConfig in which:
- user can set random layout; default value is zero means no random block
- user can provide a list of local block sizes
- user can provide a list of global block indices.
For more details about `Fixed` sparsity config, please see `Generative Modeling with
Sparse Transformers`: https://arxiv.org/abs/1904.10509; this has been customized.
This class extends parent class of `SparsityConfig` and customizes it for `Fixed` sparsity.
"""
def __init__(
self,
num_heads,
block_size=16,
different_layout_per_head=False,
num_random_blocks=0,
local_window_blocks=[4],
global_block_indices=[0],
global_block_end_indices=None,
attention="bidirectional",
horizontal_global_attention=False,
):
"""Initialize `Variable` Sparsity Pattern Config.
For usage example please see, TODO DeepSpeed Sparse Transformer Tutorial
Arguments:
num_heads: required: an integer determining number of attention heads of the layer.
block_size: optional: an integer determining the block size. Current implementation of sparse
self-attention is based on blocked sparse matrices. In which this parameter defines
size of such blocks, `Block X Block`.
different_layout_per_head: optional: a boolean determining if each head should be assigned a
different sparsity layout; default is false and this will be satisfied based on
availability. Currently this sparsity config can only assign single layout to all heads;
needs to be extended for different layout per head.
num_random_blocks: optional: an integer determining the number of random blocks in each block row.
local_window_blocks: optional: a list of integers determining the number of blocks in each
local attention window. It assumes first number determines # of blocks in the first local
window, second the second window, ..., and the last number determines the number of blocks
in the remaining local windows.
global_block_indices: optional: a list of integers determining which blocks are considered
as global attention. Given indices, determine the blocks that all other token blocks
attend to and they attend to all other token blocks. Default value is only index 0.
Notice that if global_block_end_indices parameter is set, this parameter is used as
starting index of each global window.
global_block_end_indices: optional: a list of integers determining end indices of global
window blocks. By default this is not used. But if it is set, it must have the same size
of global_block_indices parameter, and combining this two parameters, for each index i,
blocks from global_block_indices[i] to global_block_end_indices[i] (exclusive) are
considered as global attention.
attention: optional: a string determining attention type. Attention can be `unidirectional`,
such as autoregressive models, in which tokens attend only to tokens appear before them
in the context. Considering that, the upper triangular of attention matrix is empty as
above figure. Or it can be `bidirectional`, such as BERT, in which tokens can attend to
any other tokens before or after them. Then, the upper triangular part of the attention
matrix is mirror of the lower triangular in the above figure.
horizontal_global_attention: optional: a boolean determining if blocks that are global
representative of a local window, also attend to all other blocks. This is valid only if
attention type is `bidirectional`. Looking at the attention matrix, that means global
attention not only includes the vertical blocks, but also horizontal blocks.
"""
super().__init__(num_heads, block_size, different_layout_per_head)
self.num_random_blocks = num_random_blocks
self.local_window_blocks = local_window_blocks
self.global_block_indices = global_block_indices
if global_block_end_indices is not None:
if len(global_block_indices) != len(global_block_end_indices):
raise ValueError(
f"""Global block start indices length, {len(global_block_indices)}, must be same as
global block end indices length, {len(global_block_end_indices)}!"""
)
for _, (start_idx, end_idx) in enumerate(
zip(global_block_indices, global_block_end_indices)
):
if start_idx >= end_idx:
raise ValueError(
f"""Global block start index, {start_idx}, must be smaller than global block end
index, {end_idx}!"""
)
self.global_block_end_indices = global_block_end_indices
if attention != "unidirectional" and attention != "bidirectional":
raise NotImplementedError(
'only "uni/bi-directional" attentions are supported for now!'
)
self.attention = attention
if attention != "bidirectional" and horizontal_global_attention:
raise ValueError(
'only "bi-directional" attentions can support horizontal global attention!'
)
self.horizontal_global_attention = horizontal_global_attention
def set_random_layout(self, h, layout):
"""Sets random attention layout used by the given head in the sparse attention.
Note) By default, it assumes there will be a unique random block layout for all heads; unless
`different_layout_per_head` parameter is set in which each head can have a different random
layout.
Arguments:
h: required: an integer determining head index
layout: required: a tensor of dimension (num_heads, num_blocks, num_blocks) containing
sparsity layout of all head; may not be completely set at this step
Return:
layout: a tensor of dimension (num_heads, num_blocks, num_blocks) containing sparsity
layout of all head in which random layout is set
"""
num_blocks = layout.shape[1]
if num_blocks < self.num_random_blocks:
raise ValueError(
f"""Number of random blocks, {self.num_random_blocks}, must be smaller than overall number
of blocks in a row, {num_blocks}!"""
)
for row in range(0, num_blocks):
rnd_cols = random.sample(range(0, num_blocks), self.num_random_blocks)
layout[h, row, rnd_cols] = 1
return layout
def set_local_layout(self, h, layout):
"""Sets local attention layout used by the given head in the sparse attention.
Arguments:
h: required: an integer determining head index
layout: required: a tensor of dimension (num_heads, num_blocks, num_blocks) containing
sparsity layout of all head; may not be completely set at this step
Return:
layout: a tensor of dimension (num_heads, num_blocks, num_blocks) containing sparsity
layout of all head in which local layout is set
"""
num_blocks = layout.shape[1]
start_block_idx = 0
end_block_idx = 0
for block_size in self.local_window_blocks:
end_block_idx += block_size
end_block_idx = min(end_block_idx, num_blocks)
for row in range(start_block_idx, end_block_idx):
for col in range(
start_block_idx,
(row + 1 if self.attention == "unidirectional" else end_block_idx),
):
layout[h, row, col] = 1
start_block_idx += block_size
# if there is any remaining not attended part, use the lats local window block size as local
# window for the remaining applicable local windows
for i in range(start_block_idx, num_blocks, block_size):
end_block_idx = min(i + block_size, num_blocks)
for row in range(i, end_block_idx):
for col in range(
i,
(row + 1 if self.attention == "unidirectional" else end_block_idx),
):
layout[h, row, col] = 1
return layout
def set_global_layout(self, h, layout):
"""Sets global attention layout used by the given head in the sparse attention.
Arguments:
h: required: an integer determining head index
layout: required: a tensor of dimension (num_heads, num_blocks, num_blocks) containing
sparsity layout of all head; may not be completely set at this step
Return:
layout: a tensor of dimension (num_heads, num_blocks, num_blocks) containing sparsity
layout of all head in which global layout is set
"""
num_blocks = layout.shape[1]
if self.global_block_end_indices is None:
for idx in self.global_block_indices:
# if global block idx is in the range of the sequence blocks
if idx < num_blocks:
# global rows
if self.horizontal_global_attention:
layout[h, idx, :] = 1
# global columns
first_row = 0 if self.attention == "bidirectional" else idx
layout[h, first_row:, idx] = 1
else:
for _, (start_idx, end_idx) in enumerate(
zip(self.global_block_indices, self.global_block_end_indices)
):
# if global block idx is in the range of the sequence blocks
if start_idx < num_blocks:
end_idx = min(end_idx, num_blocks)
# global rows
if self.horizontal_global_attention:
layout[h, start_idx:end_idx, :] = 1
# global columns
first_row = 0 if self.attention == "bidirectional" else start_idx
layout[h, first_row:, start_idx:end_idx] = 1
return layout
def make_layout(self, seq_len):
"""Generates `Variable` sparsity layout used by each head in the sparse attention.
Arguments:
seq_len: required: an integer determining number of attention heads of the layer.
Return:
layout: a tensor of dimension (num_heads, num_blocks, num_blocks) containing `Variable`
sparsity layout of all head
"""
layout = self.setup_layout(seq_len)
for h in range(0, self.num_layout_heads):
layout = self.set_random_layout(h, layout)
layout = self.set_local_layout(h, layout)
layout = self.set_global_layout(h, layout)
layout = self.check_and_propagate_first_head_layout(layout)
return layout
class BigBirdSparsityConfig(SparsityConfig):
"""Configuration class to store `BigBird` sparsity configuration.
For more details about this sparsity config, please see `Big Bird: Transformers for
Longer Sequences`: https://arxiv.org/pdf/2007.14062.pdf
This class extends parent class of `SparsityConfig` and customizes it for `BigBird` sparsity.
"""
def __init__(
self,
num_heads,
block_size=16,
different_layout_per_head=False,
num_random_blocks=1,
num_sliding_window_blocks=3,
num_global_blocks=1,
attention="bidirectional",
):
"""Initialize the BigBird Sparsity Pattern Config.
For usage example please see, TODO DeepSpeed Sparse Transformer Tutorial
Arguments:
num_heads: required: an integer determining number of attention heads of the layer.
block_size: optional: an integer determining the block size. Current implementation of
sparse self-attention is based on blocked sparse matrices. In which this parameter
defines size of such blocks, `Block X Block`.
different_layout_per_head: optional: a boolean determining if each head should be assigned
a different sparsity layout; default is false and this will be satisfied based on
availability.
num_random_blocks: optional: an integer determining the number of random blocks in each
block row.
num_sliding_window_blocks: optional: an integer determining the number of blocks in sliding
local attention window.
num_global_blocks: optional: an integer determining how many consecutive blocks, starting
from index 0, are considered as global attention. Global block tokens will be attended
by all other block tokens and will attend to all other block tokens as well.
attention: optional: a string determining attention type. Attention can be `unidirectional`,
such as autoregressive models, in which tokens attend only to tokens appear before them
in the context. Considering that, the upper triangular of attention matrix is empty as
above figure. Or it can be `bidirectional`, such as BERT, in which tokens can attend to
any other tokens before or after them. Then, the upper triangular part of the attention
matrix is mirror of the lower triangular in the above figure.
"""
super().__init__(num_heads, block_size, different_layout_per_head)
self.num_random_blocks = num_random_blocks
self.num_sliding_window_blocks = num_sliding_window_blocks
self.num_global_blocks = num_global_blocks
if attention != "unidirectional" and attention != "bidirectional":
raise NotImplementedError(
'only "uni/bi-directional" attentions are supported for now!'
)
self.attention = attention
def set_random_layout(self, h, layout):
"""Sets random attention layout used by the given head in the sparse attention.
Note) By default, it assumes there will be a unique random block layout for all heads; unless
`different_layout_per_head` parameter is set in which each head can have a different random layout.
Arguments:
h: required: an integer determining head index
layout: required: a tensor of dimension (num_heads, num_blocks, num_blocks) containing
sparsity layout of all head; may not be completely set at this step
Return:
layout: a tensor of dimension (num_heads, num_blocks, num_blocks) containing sparsity
layout of all head in which random layout is set
"""
num_blocks = layout.shape[1]
if num_blocks < self.num_random_blocks:
raise ValueError(
f"""Number of random blocks, {self.num_random_blocks}, must be smaller than overall number
of blocks in a row, {num_blocks}!"""
)
for row in range(0, num_blocks):
sample_range = (
range(0, num_blocks)
if self.attention == "bidirectional"
else range(0, row + 1)
)
rnd_cols = random.sample(sample_range, self.num_random_blocks)
layout[h, row, rnd_cols] = 1
return layout
def set_sliding_window_layout(self, h, layout):
"""Sets sliding local attention layout used by the given head in the sparse attention.
Arguments:
h: required: an integer determining head index
layout: required: a tensor of dimension (num_heads, num_blocks, num_blocks) containing
sparsity layout of all head; may not be completely set at this step
Return:
layout: a tensor of dimension (num_heads, num_blocks, num_blocks) containing sparsity
layout of all head in which local sliding window layout is set
"""
num_blocks = layout.shape[1]
if num_blocks < self.num_sliding_window_blocks:
raise ValueError(
f"""Number of sliding window blocks, {self.num_sliding_window_blocks}, must be smaller than
overall number of blocks in a row, {num_blocks}!"""
)
w = self.num_sliding_window_blocks // 2
for row in range(0, num_blocks):
start = max(0, row - w)
end = min(row + w + 1, num_blocks)
layout[h, row, start:end] = 1
return layout
def set_global_layout_itc(self, h, layout):
"""Sets global attention layout used by the given head in the sparse attention.
Arguments:
h: required: an integer determining head index
layout: required: a tensor of dimension (num_heads, num_blocks, num_blocks) containing
sparsity layout of all head; may not be completely set at this step
Return:
layout: a tensor of dimension (num_heads, num_blocks, num_blocks) containing sparsity layout
of all head in which global layout is set
"""
num_blocks = layout.shape[1]
if num_blocks < self.num_global_blocks:
raise ValueError(
f"""Number of global blocks, {self.num_global_blocks}, must be smaller than overall number
of blocks in a row, {num_blocks}!"""
)
# global rows
layout[h, 0 : self.num_global_blocks, :] = 1
# global columns
layout[h, :, 0 : self.num_global_blocks] = 1
if self.attention == "unidirectional":
# zero out anything attending to the future
layout = torch.tril(layout)
return layout
def make_layout(self, seq_len):
"""Generates `BigBird` sparsity layout used by each head in the sparse attention.
Arguments:
seq_len: required: an integer determining number of attention heads of the layer.
Return:
layout: a tensor of dimension (num_heads, num_blocks, num_blocks) containing `BigBird`
sparsity layout of all head
"""
layout = self.setup_layout(seq_len)
for h in range(0, self.num_layout_heads):
layout = self.set_random_layout(h, layout)
layout = self.set_sliding_window_layout(h, layout)
layout = self.set_global_layout_itc(h, layout)
layout = self.check_and_propagate_first_head_layout(layout)
return layout
class BSLongformerSparsityConfig(SparsityConfig):
"""Configuration class to store edited `Longformer` sparsity configuration.
Note) this is a block-sparse version of the Longformer which is slightly different than original
Longformer; which is element-wise sparsity.
For more details about this sparsity config, please see `Longformer:
The Long-Document Transformer`: https://arxiv.org/pdf/2004.05150.pdf
This class extends parent class of `SparsityConfig` and customizes it for `Longformer` sparsity.
"""
def __init__(
self,
num_heads,
block_size=16,
different_layout_per_head=False,
num_sliding_window_blocks=3,
global_block_indices=[0],
global_block_end_indices=None,
attention="bidirectional",
):
"""Initialize the edited `Longformer` Sparsity Pattern Config.
For usage example please see, TODO DeepSpeed Sparse Transformer Tutorial
Arguments:
num_heads: required: an integer determining number of attention heads of the layer.
block_size: optional: an integer determining the block size. Current implementation of sparse
self-attention is based on blocked sparse matrices. In which this parameter defines size
of such blocks, `Block X Block`.
different_layout_per_head: optional: a boolean determining if each head should be assigned a
different sparsity layout; default is false and this will be satisfied based on
availability.
num_sliding_window_blocks: optional: an integer determining the number of blocks in sliding
local attention window.
global_block_indices: optional: a list of integers determining which blocks are considered
as global attention. Given indices, determine the blocks that all other token blocks
attend to and they attend to all other token blocks. Default value is only index 0.
Notice that if global_block_end_indices parameter is set, this parameter is used as
starting index of each global window.
global_block_end_indices: optional: a list of integers determining end indices of global
window blocks. By default this is not used. But if it is set, it must have the same size
of global_block_indices parameter, and combining this two parameters, for each index i,
blocks from global_block_indices[i] to global_block_end_indices[i] (exclusive) are
considered as global attention.
attention: optional: a string determining attention type. Attention can be `unidirectional`,
such as autoregressive models, in which tokens attend only to tokens appear before them
in the context. Considering that, the upper triangular of attention matrix is empty as
above figure. Or it can be `bidirectional`, such as BERT, in which tokens can attend to
any other tokens before or after them. Then, the upper triangular part of the attention
matrix is mirror of the lower triangular in the above figure.
"""
super().__init__(num_heads, block_size, different_layout_per_head)
self.num_sliding_window_blocks = num_sliding_window_blocks
self.global_block_indices = global_block_indices
self.attention = attention
if global_block_end_indices is not None:
if len(global_block_indices) != len(global_block_end_indices):
raise ValueError(
f"""Global block start indices length, {len(global_block_indices)}, must be same as
global block end indices length, {len(global_block_end_indices)}!"""
)
for _, (start_idx, end_idx) in enumerate(
zip(global_block_indices, global_block_end_indices)
):
if start_idx >= end_idx:
raise ValueError(
f"""Global block start index, {start_idx}, must be smaller than global block end
index, {end_idx}!"""
)
self.global_block_end_indices = global_block_end_indices
def set_sliding_window_layout(self, h, layout):
"""Sets sliding local attention layout used by the given head in the sparse attention.
Arguments:
h: required: an integer determining head index
layout: required: a tensor of dimension (num_heads, num_blocks, num_blocks) containing
sparsity layout of all head; may not be completely set at this step
Return:
layout: a tensor of dimension (num_heads, num_blocks, num_blocks) containing sparsity layout
of all head in which local sliding window layout is set
"""
num_blocks = layout.shape[1]
if num_blocks < self.num_sliding_window_blocks:
raise ValueError(
f"""Number of sliding window blocks, {self.num_sliding_window_blocks}, must be smaller
than overall number of blocks in a row, {num_blocks}!"""
)
w = self.num_sliding_window_blocks // 2
for row in range(0, num_blocks):
start = max(0, row - w)
end = min(row + w + 1, num_blocks)
layout[h, row, start:end] = 1
return layout
def set_global_layout(self, h, layout):
"""Sets global attention layout used by the given head in the sparse attention.
Arguments:
h: required: an integer determining head index
layout: required: a tensor of dimension (num_heads, num_blocks, num_blocks) containing
sparsity layout of all head; may not be completely set at this step
Return:
layout: a tensor of dimension (num_heads, num_blocks, num_blocks) containing sparsity
layout of all head in which global layout is set
"""
num_blocks = layout.shape[1]
if self.global_block_end_indices is None:
for idx in self.global_block_indices:
# if global block idx is in the range of the sequence blocks
if idx < num_blocks:
# global rows
layout[h, idx, :] = 1
# global columns
layout[h, :, idx] = 1
else:
for _, (start_idx, end_idx) in enumerate(
zip(self.global_block_indices, self.global_block_end_indices)
):
# if global block idx is in the range of the sequence blocks
if start_idx < num_blocks:
end_idx = min(end_idx, num_blocks)
# global rows
layout[h, start_idx:end_idx, :] = 1
# global columns
layout[h, :, start_idx:end_idx] = 1
if self.attention == "unidirectional":
layout = torch.tril(layout)
return layout
def make_layout(self, seq_len):
"""Generates edited `Longformer` sparsity layout used by each head in the sparse attention.
Arguments:
seq_len: required: an integer determining number of attention heads of the layer.
Return:
layout: a tensor of dimension (num_heads, num_blocks, num_blocks) containing `BSLongformer`
sparsity layout of all head
"""
layout = self.setup_layout(seq_len)
for h in range(0, self.num_layout_heads):
layout = self.set_sliding_window_layout(h, layout)
layout = self.set_global_layout(h, layout)
layout = self.check_and_propagate_first_head_layout(layout)
return layout

View File

@@ -0,0 +1,108 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
from typing import Optional
import torch
# Reshapes key padding mask from (batch_size, src_len) -> (batch_size * num_heads 1, src_len)
def reshape_key_padding_mask(
key_padding_mask: torch.Tensor, batched_dim: int
) -> torch.Tensor:
assert key_padding_mask.ndim == 2
batch_size, src_len = key_padding_mask.size()
num_heads = batched_dim // batch_size
return _reshape_key_padding_mask(key_padding_mask, batch_size, src_len, num_heads)
def _reshape_key_padding_mask(
key_padding_mask: torch.Tensor, batch_size: int, src_len: int, num_heads: int
) -> torch.Tensor:
assert key_padding_mask.shape == (batch_size, src_len)
key_padding_mask = (
key_padding_mask.view(batch_size, 1, 1, src_len)
.expand(-1, num_heads, -1, -1)
.reshape(batch_size * num_heads, 1, src_len)
)
return key_padding_mask
# Combine the attention mask and key padding mask into a single mask
# Taken from https://github.com/pytorch/pytorch/blob/master/torch/nn/functional.py
# Additive masking not yet supported
def maybe_merge_masks(
att_mask: Optional[torch.Tensor],
key_padding_mask: Optional[torch.Tensor],
batch_size: int,
src_len: int,
num_heads: int,
tgt_len: Optional[int] = None,
) -> Optional[torch.Tensor]:
if tgt_len is None:
tgt_len = src_len
if key_padding_mask is not None:
assert key_padding_mask.shape == (batch_size, src_len)
key_padding_mask = _reshape_key_padding_mask(
key_padding_mask, batch_size, src_len, num_heads
)
if att_mask is None:
# make sure dimensions of key padding mask are the same as those expected for att_mask
att_mask = key_padding_mask.expand(-1, tgt_len, -1)
# Assumption is that False means to mask.
elif att_mask.dtype == torch.bool:
att_mask = att_mask.logical_and(key_padding_mask)
else:
att_mask = att_mask.masked_fill(~key_padding_mask, float("-inf"))
return att_mask
# Assumes that matrix passed in has had softmax applied to it.
def iterative_pinv(softmax_mat: torch.Tensor, n_iter=6, pinverse_original_init=False):
"""
Computing the Moore-Penrose inverse.
Use an iterative method from (Razavi et al. 2014) to approximate the Moore-Penrose inverse via efficient
matrix-matrix multiplications.
"""
i = torch.eye(
softmax_mat.size(-1), device=softmax_mat.device, dtype=softmax_mat.dtype
)
k = softmax_mat
# The entries of K are positive and ||K||_{\infty} = 1 due to softmax
if pinverse_original_init:
# This original implementation is more conservative to compute coefficient of Z_0.
v = 1 / torch.max(torch.sum(k, dim=-2)) * k.transpose(-1, -2)
else:
# This is the exact coefficient computation, 1 / ||K||_1, of initialization of Z_0, leading to faster
# convergence.
v = (
1
/ torch.max(torch.sum(k, dim=-2), dim=-1).values[:, None, None]
* k.transpose(-1, -2)
)
for _ in range(n_iter):
kv = torch.matmul(k, v)
v = torch.matmul(
0.25 * v,
13 * i - torch.matmul(kv, 15 * i - torch.matmul(kv, 7 * i - kv)),
)
return v
def bool_mask_to_additive(
mask: torch.Tensor, dtype: Optional[torch.dtype] = torch.float32
) -> torch.Tensor:
assert (
mask.dtype == torch.bool
), "This util is meant to convert in between bool masks and additive ones"
mask_ = torch.zeros_like(mask, dtype=dtype)
mask_[~mask] = float("-inf")
return mask_

View File

@@ -0,0 +1,96 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
import math
from dataclasses import dataclass
import torch
import torch.nn as nn
from xformers.components.attention import Attention, AttentionConfig, register_attention
@dataclass
class VisualAttentionConfig(AttentionConfig):
dim_model: int # dimension of the input sequence
class LKA(nn.Module):
def __init__(self, dim: int):
super().__init__()
self.conv0 = nn.Conv2d(dim, dim, 5, padding=2, groups=dim)
self.conv_spatial = nn.Conv2d(
dim, dim, 7, stride=1, padding=9, groups=dim, dilation=3
)
self.conv1 = nn.Conv2d(dim, dim, 1)
def forward(self, x: torch.Tensor):
u = x.clone()
attn = self.conv0(x)
attn = self.conv_spatial(attn)
attn = self.conv1(attn)
return u * attn
@register_attention("visual", VisualAttentionConfig)
class Visual(Attention):
def __init__(
self,
dim_model: int,
*_,
**__,
):
"""
Large kernel attention mechanism, as proposed in `Visual Attention Network`_, Guo et al (2022).
The original notation is tentatively kept as is. See https://github.com/Visual-Attention-Network
for the reference implementation
.. Note: compared to the paper, this block contains the LKA (Large Kernel Attention)
and the prior and posterior transformations (Conv2d and activation)
.. _`Visual Attention Network` : https://arxiv.org/pdf/2202.09741.pdf
"""
super().__init__()
self.block = nn.Sequential(
nn.Conv2d(dim_model, dim_model, 1),
nn.GELU(),
LKA(dim_model),
nn.Conv2d(dim_model, dim_model, 1),
)
# MHA related flags:
self.requires_same_k_q_dimensions = (
True # This mechanism only really supports self attention
)
self.supports_attention_mask = False
self.requires_skip_multi_head = (
True # This mechanism skips the multihead attention altogether
)
self.requires_squared_context = (
True # Recovering the 2D structure from context assumes squared content
)
self.requires_input_projection = (
False # This mechanism does not require that the MHA projects inputs
)
def forward(self, q: torch.Tensor, *_, **__):
# Expose the 2D token structure
B, HW, C = q.shape
H = int(math.sqrt(HW))
assert H * H == HW
x = q.transpose(-2, -1).reshape(B, C, H, H)
# Large kernel attention
residual = x.clone()
x = self.block(x)
x = x + residual
# Get back to B HW C
return x.flatten(2, 3).transpose(-2, -1)

View File

@@ -0,0 +1,87 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
from pathlib import Path
from typing import Any, Callable, Dict, Set, Union
from xformers.utils import (
generate_matching_config,
get_registry_decorator,
import_all_modules,
)
from .base import Feedforward, FeedforwardConfig # noqa
# CREDITS: Classy Vision registry mechanism
FEEDFORWARD_REGISTRY: Dict[str, Any] = {}
FEEDFORWARD_CLASS_NAMES: Set[str] = set()
def build_feedforward(config: Union[Dict[str, Any], FeedforwardConfig]):
"""Builds a feedforward from a config.
This assumes a 'name' key in the config which is used to determine what
attention class to instantiate. For instance, a config `{"name": "my_feedforward",
"foo": "bar"}` will find a class that was registered as "my_feedforward"
(see :func:`register_feedforward`) and call .from_config on it."""
if not isinstance(config, FeedforwardConfig):
config_instance = generate_matching_config(
config, FEEDFORWARD_REGISTRY[config["name"]].config
)
else:
config_instance = config
return FEEDFORWARD_REGISTRY[config_instance.name].constructor.from_config(
config_instance
)
"""Registers a Feedforward subclass.
This decorator allows xFormers to instantiate a subclass of Feedforward
from a configuration file, even if the class itself is not part of the
xFormers framework. To use it, apply this decorator to a Feedforward
subclass, like this:
.. code-block:: python
@dataclass
class MyConfig:
...
@register_feedforward('my_ff', MyConfig)
class MyFeedforward(Feedforward):
...
To instantiate a feedforward from a configuration file, see :func:`build_feedforward`."""
register_feedforward: Callable[
[str, Any], Callable[[Any], Any]
] = get_registry_decorator(
FEEDFORWARD_REGISTRY, FEEDFORWARD_CLASS_NAMES, Feedforward, FeedforwardConfig
)
try:
from .fused_mlp import FusedMLP # noqa
_fused_mlp_available = True
except ImportError:
_fused_mlp_available = False
from .mlp import MLP # noqa
__all__ = [
"MLP",
"Feedforward",
"build_feedforward",
"register_feedforward",
]
if _fused_mlp_available:
__all__ += ["FusedMLP"]
# automatically import any Python files in the directory
import_all_modules(str(Path(__file__).parent), "xformers.components.feedforward")

View File

@@ -0,0 +1,53 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
from abc import ABCMeta, abstractmethod
from dataclasses import asdict, dataclass
from typing import Optional, Type, TypeVar
import torch.nn as nn
from xformers.components import Activation
Self = TypeVar("Self", bound="Feedforward")
@dataclass
class FeedforwardConfig:
name: str
dim_model: int
dropout: float
activation: Activation
# Define the common interface, every feedforward block needs to derive from it
class Feedforward(nn.Module, metaclass=ABCMeta):
@abstractmethod
def __init__(
self,
dim_model: Optional[int] = None,
dropout: Optional[float] = None,
activation: Optional[Activation] = None,
*args,
**kwargs,
):
super().__init__()
# This feedforward requires a CUDA accelerator
self.requires_cuda = False
# This feedforward requires a context length which is squared, often due to 2D pooling
self.requires_squared_context = False
@classmethod
def from_config(cls: Type[Self], config: FeedforwardConfig) -> Self:
# Generate the class inputs from the config
fields = asdict(config)
# Skip all Nones so that default values are used
fields = {k: v for k, v in fields.items() if v is not None}
return cls(**fields)

View File

@@ -0,0 +1,97 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
# CREDITS: Largely reusing the code from the reference VAN implementation
# see https://github.com/Visual-Attention-Network
import math
from dataclasses import dataclass
from typing import Optional
import torch.nn as nn
from xformers.components import Activation, build_activation
from xformers.components.feedforward import Feedforward, FeedforwardConfig
from . import register_feedforward
@dataclass
class ConvMlpConfig(FeedforwardConfig):
hidden_layer_multiplier: int
dim_model: int
dim_model_out: Optional[int]
act_layer: Activation
dropout: float
@register_feedforward("Conv2DFeedforward", ConvMlpConfig)
class Conv2DFeedforward(Feedforward):
"""
A Convolutional feed-forward network, as proposed in VAN_ (Vision Attention Network, Guo et al.)
.. _VAN: https://arxiv.org/pdf/2202.09741.pdf
"""
def __init__(
self,
dim_model: int,
hidden_layer_multiplier: int = 1,
dim_model_out: Optional[int] = None,
activation: Activation = Activation.GeLU,
dropout=0.0,
*args,
**kwargs,
):
super().__init__()
out_features = dim_model_out or dim_model
hidden_features = hidden_layer_multiplier * dim_model
self.conv_mlp = nn.Sequential(
nn.Conv2d(dim_model, hidden_features, 1),
nn.Conv2d(
hidden_features,
hidden_features,
3,
1,
1,
bias=True,
groups=hidden_features,
),
build_activation(activation),
nn.Conv2d(hidden_features, out_features, 1),
nn.Dropout(dropout),
)
# This feedforward requires a context length which is squared, often due to 2D pooling
self.requires_squared_context = True
def init_weights(self, **kwargs):
# Follow the original init, but also make it possible to initialize from the outside
def init_module(m: nn.Module):
if isinstance(m, nn.Conv2d):
fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
fan_out //= m.groups
m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
if m.bias is not None:
m.bias.data.zero_()
self.apply(init_module)
def forward(self, x):
# The conv layers expect NCHW, we have NLC by default
B, L, C = x.shape
HW = int(math.sqrt(x.shape[-2]))
assert HW**2 == L, "Conv2DFeedforward requires squared context lengths"
x = x.reshape((B, HW, HW, C)).swapdims(1, -1)
# The actual FW, including the 2d convolutions
x = self.conv_mlp(x)
# back to NLC
x = x.transpose(1, -1)
return x.flatten(1, 2)

View File

@@ -0,0 +1,79 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
import logging
from dataclasses import dataclass
import torch
import torch.nn as nn
from xformers.components import Activation
from xformers.components.feedforward import (
Feedforward,
FeedforwardConfig,
register_feedforward,
)
logger = logging.getLogger("xformers")
if torch.cuda.is_available():
try:
from xformers.triton import FusedDropoutBias
@dataclass
class FusedMlpConfig(FeedforwardConfig):
hidden_layer_multiplier: int
@register_feedforward("FusedMLP", FusedMlpConfig)
class FusedMLP(Feedforward):
"""
A MLP using fused linear layers.
"""
def __init__(
self,
dim_model: int,
dropout: float,
activation: Activation,
hidden_layer_multiplier: int,
bias: bool = True,
*args,
**kwargs,
):
super().__init__()
dim_mlp = hidden_layer_multiplier * dim_model
self.mlp = nn.Sequential(
nn.Linear(
in_features=dim_model, out_features=dim_mlp, bias=False
), # bias is handled in the next layer
# pyre-ignore[16]: TODO(T101400990): Pyre did not recognize
# the `FusedLinear` import.
FusedDropoutBias(
p=dropout,
bias_shape=dim_mlp if bias else None,
activation=activation,
),
nn.Linear(
in_features=dim_mlp, out_features=dim_model, bias=False
), # bias is handled in the next layer
# pyre-ignore[16]: TODO(T101400990): Pyre did not recognize
# the `FusedLinear` import.
FusedDropoutBias(
p=dropout,
bias_shape=dim_model if bias else None,
activation=None,
),
)
self.requires_cuda = True
def forward(self, inputs: torch.Tensor) -> torch.Tensor:
return self.mlp(inputs)
except ImportError:
logger.warning("Triton is not available, FusedMLP will not be enabled.")

View File

@@ -0,0 +1,153 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
import logging
from dataclasses import dataclass
from enum import Enum
from typing import Any, Callable, Optional, Union
import torch
from xformers.components import Activation
from xformers.components.feedforward import (
Feedforward,
FeedforwardConfig,
register_feedforward,
)
logger = logging.getLogger("xformers")
_is_fairscale_available = True
try:
import torch.distributed as dist
from fairscale.nn import MOELayer, Top2Gate # type: ignore
from xformers.components.feedforward import MLP
except ImportError:
logger.warning(
"Either FairScale or torch distributed is not available, MixtureOfExperts will not be exposed."
" Please install them if you would like to use MoE"
)
_is_fairscale_available = False
if _is_fairscale_available:
# Credits: initially implemented in FairScale for sanity checking
class RoundRobinGate(torch.nn.Module):
def __init__(self, model_dim, num_experts):
super().__init__()
self.model_dim = model_dim
self.num_experts = num_experts
def forward(self, input):
s = input.shape[0]
assert s % self.num_experts == 0, f"{s} % {self.num_experts} != 0"
capacity = 2 * s // self.num_experts
output = torch.zeros(
s, self.num_experts, capacity, dtype=input.dtype, device=input.device
)
for i in range(s):
output[i, i % self.num_experts, i // self.num_experts] = 1.0
return 0.0, output, output.bool()
class GateConfig(str, Enum):
RoundRobin = "round_robin"
Top2 = "top_2"
# Other gating techniques could be exposed here
@dataclass
class MoEConfig(FeedforwardConfig):
number_of_experts: int
gate: GateConfig
number_of_local_experts: Optional[int] = None
expert_constructor: Optional[Any] = None
hidden_layer_multiplier: Optional[int] = None
group: Optional[Any] = None
@register_feedforward("MixtureOfExperts", MoEConfig)
class MixtureOfExperts(Feedforward):
"""
A MLP variant which uses the "Mixture of Experts" paradigm, as described in Gshard_.
xFormers uses the FairScale_ implementation under the hood.
.. warning: Please note that most of the benefits of MoE are present in a distributed training environmentt
.. _Gshard: https://arxiv.org/pdf/2006.16668.pdf
.. _FairScale: https://github.com/facebookresearch/fairscale/
"""
def __init__(
self,
dim_model: int,
dropout: float,
activation: Activation,
number_of_experts: int,
gate: Union[GateConfig, torch.nn.Module],
number_of_local_experts: Optional[int] = None,
expert_constructor: Optional[Callable[[], torch.nn.Module]] = None,
hidden_layer_multiplier: Optional[int] = None,
group: Optional[Any] = None,
*_,
**__,
):
super().__init__()
# Handle a possibly uninitialized process group
assert (
dist.is_initialized()
), "Mixture of Experts require torch distributed to be initialized"
if number_of_local_experts is not None:
assert number_of_experts >= number_of_local_experts
else:
if dist.get_world_size() == 1:
logger.warning("Local experts no specified but world size of 1")
logger.warning("Assuming that all experts are local")
number_of_local_experts = number_of_experts
else:
number_of_local_experts = 1
# Programatically handle the gating technique
if not isinstance(gate, torch.nn.Module):
gate_constructor = {
GateConfig.RoundRobin: RoundRobinGate,
GateConfig.Top2: Top2Gate,
}[gate]
self.gate = gate_constructor(dim_model, number_of_experts)
else:
self.gate = gate
# Programatically handle the experts
if expert_constructor is None:
multiplier = (
hidden_layer_multiplier
if hidden_layer_multiplier is not None
else 4
)
def expert_constructor() -> torch.nn.Module:
return MLP(dim_model, dropout, activation, multiplier)
assert expert_constructor is not None
local_experts = torch.nn.ModuleList(
[expert_constructor() for _ in range(number_of_local_experts)]
)
self.moe = MOELayer(gate=self.gate, experts=local_experts, group=group)
self.requires_cuda = True
def forward(self, inputs: torch.Tensor) -> torch.Tensor:
# FairScale MoE assumes that the dimensions are [S, B, E]
# xFormers assumes [B, S, E]
return self.moe(inputs.movedim(0, 1)).movedim(0, 1)

View File

@@ -0,0 +1,47 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
from dataclasses import dataclass
import torch
import torch.nn as nn
from xformers.components import Activation, build_activation
from xformers.components.feedforward import Feedforward, FeedforwardConfig
from . import register_feedforward
@dataclass
class MlpConfig(FeedforwardConfig):
hidden_layer_multiplier: int
bias: bool
@register_feedforward("MLP", MlpConfig)
class MLP(Feedforward):
def __init__(
self,
dim_model: int,
dropout: float,
activation: Activation,
hidden_layer_multiplier: int,
bias: bool = True,
*args,
**kwargs,
):
super().__init__()
dim_mlp = hidden_layer_multiplier * dim_model
self.mlp = nn.Sequential(
nn.Linear(in_features=dim_model, out_features=dim_mlp, bias=bias),
build_activation(activation),
nn.Dropout(dropout),
nn.Linear(in_features=dim_mlp, out_features=dim_model, bias=bias),
nn.Dropout(dropout),
)
def forward(self, inputs: torch.Tensor) -> torch.Tensor:
return self.mlp(inputs)

View File

@@ -0,0 +1,99 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
# CREDITS: Inspired by https://github.com/pytorch/text/blob/master/torchtext/nn/modules/multiheadattention.py
# and the MultiHeadAttention implementation from PyTorch
import logging
from dataclasses import dataclass
from typing import Optional, Tuple
import torch
from torch import nn
logger = logging.getLogger("xformers")
@dataclass
class InputProjectionConfig:
in_features: int
out_features: int
bias: bool
class InputProjection(nn.Module):
"""
Handle all the input projections in one go, opportunistically fuse some operations.
"""
def __init__(
self,
query_proj_params: InputProjectionConfig,
key_proj_params: Optional[InputProjectionConfig],
value_proj_params: Optional[InputProjectionConfig],
use_separate_proj_weight: bool = True,
):
super().__init__()
self.out_features = query_proj_params.out_features
# Each input gets a separate projection
self.q_proj = nn.Linear(
query_proj_params.in_features,
query_proj_params.out_features,
query_proj_params.bias,
)
if key_proj_params is not None:
self.k_proj = nn.Linear(
key_proj_params.in_features,
key_proj_params.out_features,
key_proj_params.bias,
)
else:
logger.info(
"No Key projection parameters were passed, assuming that the weights"
+ " are shared with the query projection"
)
self.k_proj = self.q_proj
if value_proj_params is not None:
self.v_proj = nn.Linear(
value_proj_params.in_features,
value_proj_params.out_features,
value_proj_params.bias,
)
else:
logger.info(
"No Value projection parameters were passed, assuming that the weights"
+ " are shared with the query projection"
)
self.v_proj = self.q_proj
if not use_separate_proj_weight:
# Compute optimization used at times, share the parameters in between Q/K/V
with torch.no_grad():
self.k_proj.weight = self.q_proj.weight
self.v_proj.weight = self.q_proj.weight
def forward(
self,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
# One projection per input tensor
# NOTE: Would it make sense to catch self attention + shared weights, to skip a projection step ?
q, k, v = map(
lambda fn, x: fn(x),
[self.q_proj, self.k_proj, self.v_proj],
[query, key, value],
)
return q, k, v

View File

@@ -0,0 +1,269 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
import logging
from dataclasses import asdict, dataclass
from typing import Optional, Tuple
import torch
import torch.nn as nn
from torch.nn.init import constant_
from xformers.components.attention import Attention
from xformers.components.input_projection import InputProjection, InputProjectionConfig
from xformers.components.positional_embedding import RotaryEmbedding
logger = logging.getLogger("xformers")
@dataclass
class MultiHeadDispatchConfig:
dim_model: int
num_heads: int
attention: Attention
bias: bool
residual_dropout: float
dim_key: Optional[int]
dim_value: Optional[int]
in_proj_container: Optional[InputProjection]
use_separate_proj_weight: Optional[bool]
use_rotary_embeddings: Optional[bool]
out_proj: Optional[nn.Module]
def __getitem__(self, item):
return getattr(self, item)
# Move head forward and fold into batch dim. dimensions become (B * nh, S, hs)
def _fold_heads(t: torch.Tensor, B: int, S: int, H: int, Hs: int):
return t.view(B, S, H, Hs).transpose(1, 2).flatten(start_dim=0, end_dim=1)
# Move head forward and fold into batch dim. dimensions become (B, nh, S, hs)
def _split_heads(t: torch.Tensor, B: int, S: int, H: int, Hs: int):
return t.view(B, S, H, Hs).transpose(1, 2)
class MultiHeadDispatch(nn.Module):
"""
A multi-head masked self-attention dispatch mechanism, with a projection at the end,
following the architecture proposed in `Attention is all you need`_, Vaswani et al.
The actual attention mechanism can vary, as well as the projections.
This can be used to wrap the proposed attention mechanisms and make them multi-head aware,
but it is optional.
Args:
dim_model: The model/embedding dimension
num_heads: The number of heads being used
attention: The attention mechanism (needs to be registered to the xformers library)
bias: Whether to use bias for the projections : (Q, K, V, Output)
residual_dropout: Amount of dropout on the residual path
use_separate_proj_weight: Use different weights for the Q, K, V projections
dim_key: Optionally use a different dimension for the key
dim_value: Optionally use a different dimension for the value
in_proj_container: Optionally provide the input projection module
use_rotary_embeddings: Use rotary embeddings
out_proj: Optionally provide the output projection module
.. _`Attention is all you need`: https://arxiv.org/abs/1706.03762v5
"""
def __init__(
self,
dim_model: int,
num_heads: int,
attention: Attention,
bias: Tuple[bool, bool, bool, bool] = (True, True, True, True),
residual_dropout: float = 0.0,
use_separate_proj_weight: bool = True,
dim_key: Optional[int] = None,
dim_value: Optional[int] = None,
in_proj_container: Optional[InputProjection] = None,
use_rotary_embeddings: Optional[bool] = False,
out_proj: Optional[nn.Module] = None,
*args,
**kwargs,
):
super().__init__()
if isinstance(bias, bool):
logger.warning(
"Single bias value provided for the MHA projections."
+ f" Assuming the same parameter ({bias}) is to be used everywhere"
)
bias = (bias, bias, bias, bias)
assert (
dim_model % num_heads == 0
) # static preset for now, each head works on 1/d the embeddings, could be relaxed
assert num_heads > 0
# Popular default is that all latent dimensions are the same
dim_key, dim_value = map(lambda x: x if x else dim_model, (dim_key, dim_value))
self.num_heads = num_heads
self.dim_key_head = dim_key // num_heads
self.dim_value_head = dim_value // num_heads
self.dim_model = dim_model
self.attention = attention
# key, query, value projections for all heads
# critical options are
# - are we sharing weights ?
# - are we adding biases ?
if attention.requires_input_projection:
self.in_proj_container = (
in_proj_container
if in_proj_container is not None
else InputProjection(
query_proj_params=InputProjectionConfig(
dim_model, dim_key, bias=bias[0]
),
key_proj_params=InputProjectionConfig(
dim_model, dim_key, bias=bias[1]
),
value_proj_params=InputProjectionConfig(
dim_model, dim_value, bias=bias[2]
),
use_separate_proj_weight=use_separate_proj_weight,
)
)
# Optional rotary embeddings
self.rotary_embeddings = (
RotaryEmbedding(self.dim_key_head) if use_rotary_embeddings else None
)
# Regularization
self.resid_drop = nn.Dropout(residual_dropout, inplace=False)
# Output projection
self.proj = (
out_proj if out_proj else nn.Linear(dim_model, dim_model, bias=bias[3])
)
if isinstance(self.proj, nn.Linear) and self.proj.bias is not None:
constant_(self.proj.bias, 0.0)
def forward(
self,
query: torch.Tensor,
key: Optional[torch.Tensor] = None,
value: Optional[torch.Tensor] = None,
att_mask: Optional[torch.Tensor] = None,
key_padding_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""
Expected input dimensions are [batch size, sequence length, embed dim]
Output dimensions are [batch size, sequence length, embed dim]
"""
if key is None:
key = query
if value is None:
value = query
if query.shape[0] != key.shape[0] or query.shape[0] != value.shape[0]:
max_batch = max((query.shape[0], key.shape[0], value.shape[0]))
query, key, value = map(
lambda x: x.expand(max_batch, -1, -1), [query, key, value]
)
B, S_Q, _ = query.size() # Batch x Sequence x Embedding (latent)
_, S_K, _ = key.size() # K, Q's sequence length could differ
# Catch different query and key length but a causal attention
if S_Q != S_K:
assert (
not self.attention.requires_same_k_q_dimensions
), "This attention mechanism requires query and key to have the same sequence (context) lengths"
if hasattr(self.attention, "causal"):
assert not self.attention.causal, (
"Causal attention is not supported when key and query have different sequence lengths.\n"
+ "In that case causality is ill-determined. Please pad your sequences accordingly"
)
kw_mask_args = {}
if att_mask is not None:
assert (
self.attention.supports_attention_mask
), "This attention does not support attention masks"
kw_mask_args["att_mask"] = att_mask
if key_padding_mask is not None:
assert (
self.attention.supports_key_padding_mask
), "This attention does not support key padding masks"
kw_mask_args["key_padding_mask"] = key_padding_mask
if self.attention.requires_skip_multi_head:
return self.attention(query, key, value, **kw_mask_args)
# Calculate query, key, values for all heads in batch
if self.attention.requires_input_projection:
q, k, v = self.in_proj_container(query=query, key=key, value=value)
else:
k, q, v = key, query, value
# Check the dimensions properly
def check(t, name):
assert (
t.shape[2] % self.num_heads == 0
), f"the {name} embeddings need to be divisible by the number of heads"
check(q, "projected query")
check(v, "projected value")
check(k, "projected key")
# Optional: rotary embedding, add relative positioning information
if self.rotary_embeddings:
# rotary requires the head dimension
q = _split_heads(q, B, S_Q, self.num_heads, self.dim_key_head)
k = _split_heads(k, B, S_K, self.num_heads, self.dim_key_head)
v = _split_heads(v, B, S_K, self.num_heads, self.dim_value_head)
q, k = self.rotary_embeddings(q=q, k=k)
if not self.attention.requires_head_dimension:
q, k, v = q.flatten(0, 1), k.flatten(0, 1), v.flatten(0, 1)
else:
# Reshape k/q/v to either expose the heads, or fold the head dimension into the batch
reshape_fn = (
_split_heads if self.attention.requires_head_dimension else _fold_heads
)
q = reshape_fn(q, B, S_Q, self.num_heads, self.dim_key_head)
k = reshape_fn(k, B, S_K, self.num_heads, self.dim_key_head)
v = reshape_fn(v, B, S_K, self.num_heads, self.dim_value_head)
# Self-attend
y = self.attention(q, k, v, **kw_mask_args)
# Re-assemble all head outputs side by side
y = (
y.view(B, self.num_heads, S_Q, self.dim_value_head)
.transpose(1, 2)
.flatten(start_dim=2, end_dim=3)
)
# Output projection, dropout and good to go
y = self.resid_drop(self.proj(y))
# Return the same sequence size as the input
return y
@classmethod
def from_config(cls, config: MultiHeadDispatchConfig):
# Generate the class inputs from the config
fields = asdict(config)
# Skip all Nones so that default values are used
fields = {k: v for k, v in fields.items() if v is not None}
return cls(**fields)

View File

@@ -0,0 +1,79 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
import math
from dataclasses import dataclass
from enum import Enum
import torch
class PoolType(str, Enum):
Conv2D = "CONV_2D"
# ...
# TODO: Support more cases ?
@dataclass
class PatchEmbeddingConfig:
"""
The configuration for the patch embedding layer, which takes the raw token passed in
and returns a pooled representation along a given embedding dimension.
This typically trades the spatial (context length) representation with the embedding size
This is canonicaly used by ViT, but other papers (like MetaFormer or other hierarchical transformers)
propose a more general use case for this
"""
in_channels: int
out_channels: int
kernel_size: int
stride: int
padding: int = 0
pool_type: PoolType = PoolType.Conv2D
class ConditionalReshape(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
def forward(self, x):
if x.ndim == 3:
B, HW, C = x.shape
# NOTE: We're assuming a square sample here
H = int(math.sqrt(HW))
assert H * H == HW, f"{H, HW}"
x = x.transpose(1, 2).reshape(B, C, H, H)
return x
class PatchToSequence(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
def forward(self, x):
return x.flatten(2, 3).transpose(1, 2).contiguous() # B HW C
def build_patch_embedding(config: PatchEmbeddingConfig):
if not isinstance(config, PatchEmbeddingConfig):
config = PatchEmbeddingConfig(**config)
if config.pool_type == PoolType.Conv2D:
pool = torch.nn.Conv2d(
config.in_channels,
config.out_channels,
kernel_size=config.kernel_size,
stride=config.stride,
padding=config.padding,
)
else:
raise NotImplementedError
# The patch embedding supposes that the input really is 2D in essence
# If this block is in the middle of a stack, we need to reshape
return torch.nn.Sequential(ConditionalReshape(), pool, PatchToSequence())

View File

@@ -0,0 +1,87 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
from pathlib import Path
from typing import Any, Callable, Dict, Set, Union
from xformers.utils import (
generate_matching_config,
get_registry_decorator,
import_all_modules,
)
from .base import PositionEmbedding, PositionEmbeddingConfig # noqa
# CREDITS: Classy Vision registry mechanism
POSITION_EMBEDDING_REGISTRY: Dict[str, Any] = {}
POSITION_EMBEDDING_CLASS_NAMES: Set[str] = set()
def build_positional_embedding(config: Union[Dict[str, Any], PositionEmbeddingConfig]):
"""Builds a position encoding from a config.
This assumes a 'name' key in the config which is used to determine what
attention class to instantiate. For instance, a config `{"name": "my_position_encoding",
"foo": "bar"}` will find a class that was registered as "my_position_encoding"
(see :func:`register_positional_embedding`) and call .from_config on it."""
if not isinstance(config, PositionEmbeddingConfig):
config_instance = generate_matching_config(
config, POSITION_EMBEDDING_REGISTRY[config["name"]].config
)
else:
config_instance = config
return POSITION_EMBEDDING_REGISTRY[config_instance.name].constructor.from_config(
config_instance
)
"""Registers a PositionEncoding subclass.
This decorator allows xFormers to instantiate a subclass of PositionEncoding
from a configuration file, even if the class itself is not part of the
xFormers framework. To use it, apply this decorator to a `PositionEncoding`
subclass, like this:
.. code-block:: python
@dataclass
class MyConfig:
...
@register_positional_embedding('my_encoding', MyConfig)
class MyEncoding(PositionEncoding):
...
To instantiate a position encoding from a configuration file, see :func:`build_positional_embedding`."""
register_positional_embedding: Callable[
[str, Any], Callable[[Any], Any]
] = get_registry_decorator(
POSITION_EMBEDDING_REGISTRY,
POSITION_EMBEDDING_CLASS_NAMES,
PositionEmbedding,
PositionEmbeddingConfig,
)
from .rotary import RotaryEmbedding # noqa
from .sine import SinePositionalEmbedding # type: ignore # noqa
from .vocab import VocabEmbedding # noqa
__all__ = [
"RotaryEmbedding",
"SinePositionalEmbedding",
"VocabEmbedding",
"build_positional_embedding",
"register_positional_embedding",
]
# automatically import any Python files in the directory
import_all_modules(
str(Path(__file__).parent), "xformers.components.positional_embedding"
)

View File

@@ -0,0 +1,35 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
from abc import ABCMeta, abstractmethod
from dataclasses import asdict, dataclass
from typing import Type, TypeVar
import torch.nn as nn
Self = TypeVar("Self", bound="PositionEmbedding")
@dataclass
class PositionEmbeddingConfig:
name: str
dim_model: int
seq_len: int
class PositionEmbedding(nn.Module, metaclass=ABCMeta):
@abstractmethod
def __init__(self, *args, **kwargs) -> None:
super().__init__()
@classmethod
def from_config(cls: Type[Self], config: PositionEmbeddingConfig) -> Self:
# Generate the class inputs from the config
fields = asdict(config)
# Skip all Nones so that default values are used
fields = {k: v for k, v in fields.items() if v is not None}
return cls(**fields)

View File

@@ -0,0 +1,54 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
from dataclasses import dataclass
import torch
from xformers.components.positional_embedding import (
PositionEmbedding,
PositionEmbeddingConfig,
register_positional_embedding,
)
@dataclass
class LearnablePositionalEmbeddingConfig(PositionEmbeddingConfig):
name: str
seq_len: int
dim_model: int
add_class_token: bool
@register_positional_embedding("learnable", LearnablePositionalEmbeddingConfig)
class LearnablePositionalEmbedding(PositionEmbedding):
def __init__(
self, seq_len: int, dim_model: int, add_class_token: bool = False, *_, **__
):
super().__init__()
# 0.02 is BERT initialization
self.pos_emb = torch.nn.Parameter(
torch.randn(1, seq_len + int(add_class_token), dim_model) * 0.02
)
self.class_token = (
torch.nn.Parameter(torch.zeros(dim_model)) if add_class_token else None
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
if self.class_token is not None:
# Prepend class token
clf_token = (
torch.ones(x.shape[0], 1, self.pos_emb.shape[-1], device=x.device)
* self.class_token
)
x = torch.cat([clf_token, x], dim=1)
if x.ndim == 2:
x = x.unsqueeze(-1)
return x + self.pos_emb

View File

@@ -0,0 +1,91 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
# CREDITS: This implementation is inspired by GPT-NeoX https://github.com/EleutherAI/gpt-neox
# NOTE: Almost the same right now, moving parts to Triton is the next step
from typing import Tuple
import torch
def rotate_half(x):
x1, x2 = x.chunk(2, dim=-1)
return torch.cat((-x2, x1), dim=-1)
@torch.jit.script
def apply_rotary_pos_emb(x, cos, sin):
# NOTE: This could probably be moved to Triton
# Handle a possible sequence length mismatch in between q and k
cos = cos[:, :, : x.shape[-2], :]
sin = sin[:, :, : x.shape[-2], :]
return (x * cos) + (rotate_half(x) * sin)
class RotaryEmbedding(torch.nn.Module):
"""
The rotary position embeddings from RoFormer_ (Su et. al).
A crucial insight from the method is that the query and keys are
transformed by rotation matrices which depend on the relative positions.
Other implementations are available in the Rotary Transformer repo_ and in
GPT-NeoX_, GPT-NeoX was an inspiration
.. _RoFormer: https://arxiv.org/abs/2104.09864
.. _repo: https://github.com/ZhuiyiTechnology/roformer
.. _GPT-NeoX: https://github.com/EleutherAI/gpt-neox
.. warning: Please note that this embedding is not registered on purpose, as it is transformative
(it does not create the embedding dimension) and will likely be picked up (imported) on a ad-hoc basis
"""
def __init__(self, dim_model: int, *_, **__):
super().__init__()
# Generate and save the inverse frequency buffer (non trainable)
inv_freq = 1.0 / (10000 ** (torch.arange(0, dim_model, 2).float() / dim_model))
self.register_buffer("inv_freq", inv_freq)
self._seq_len_cached = None
self._cos_cached = None
self._sin_cached = None
def _update_cos_sin_tables(self, x, seq_dimension=1):
seq_len = x.shape[seq_dimension]
# Reset the tables if the sequence length has changed,
# or if we're on a new device (possibly due to tracing for instance)
if (
seq_len != self._seq_len_cached
or self._cos_cached.device != x.device
or self._cos_cached.dtype != x.dtype
):
self._seq_len_cached = seq_len
t = torch.arange(
x.shape[seq_dimension], device=x.device, dtype=torch.float32
)
freqs = torch.einsum("i,j->ij", t, self.inv_freq.to(x.dtype))
emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
self._cos_cached = emb.cos()[None, None, :, :].to(x.dtype)
self._sin_cached = emb.sin()[None, None, :, :].to(x.dtype)
return self._cos_cached, self._sin_cached
def forward(
self, q: torch.Tensor, k: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
self._cos_cached, self._sin_cached = self._update_cos_sin_tables(
k, seq_dimension=-2
)
return (
apply_rotary_pos_emb(q, self._cos_cached, self._sin_cached),
apply_rotary_pos_emb(k, self._cos_cached, self._sin_cached),
)

View File

@@ -0,0 +1,46 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
# Silence Mypy errors in this file.
# type: ignore
import math
import torch
from xformers.components.positional_embedding import (
PositionEmbedding,
PositionEmbeddingConfig,
register_positional_embedding,
)
@register_positional_embedding("sine", PositionEmbeddingConfig)
class SinePositionalEmbedding(PositionEmbedding):
def __init__(self, dim_model: int, *args, **kwargs):
super().__init__()
self.dim_model = dim_model
def forward(self, x: torch.Tensor) -> torch.Tensor:
seq_len = x.shape[1]
pos = (
torch.arange(0, seq_len, device=x.device, dtype=torch.float32)
.unsqueeze(1)
.repeat(1, self.dim_model)
)
dim = (
torch.arange(0, self.dim_model, device=x.device, dtype=torch.float32)
.unsqueeze(0)
.repeat(seq_len, 1)
)
div = torch.exp(-math.log(10000) * (2 * (dim // 2) / self.dim_model))
pos *= div
pos[:, 0::2] = torch.sin(pos[:, 0::2])
pos[:, 1::2] = torch.cos(pos[:, 1::2])
output = x.unsqueeze(-1) if x.ndim == 2 else x
return output + pos.unsqueeze(0)

View File

@@ -0,0 +1,65 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
from dataclasses import dataclass
from typing import Optional
import torch
import torch.nn as nn
from xformers.components.positional_embedding import (
PositionEmbedding,
PositionEmbeddingConfig,
register_positional_embedding,
)
@dataclass
class VocabEmbeddingConfig(PositionEmbeddingConfig):
vocab_size: int
dropout: float
@register_positional_embedding("vocab", VocabEmbeddingConfig)
class VocabEmbedding(PositionEmbedding):
def __init__(
self,
dim_model: int,
seq_len: int,
vocab_size: int,
dropout: float = 0.0,
*args,
**kwargs
):
super().__init__()
self.vocab_size = vocab_size
self.dim_model = dim_model
self.dropout = torch.nn.Dropout(p=dropout)
self.position_embeddings = nn.Embedding(seq_len, self.dim_model)
self.word_embeddings = nn.Embedding(self.vocab_size, self.dim_model)
self.position_ids: Optional[torch.Tensor] = None
self.init_weights()
def init_weights(self, gain: float = 1.0):
torch.nn.init.normal_(self.position_embeddings.weight, std=0.02 * gain)
torch.nn.init.normal_(self.word_embeddings.weight, std=0.02 * gain)
def forward(self, x: torch.Tensor):
position_ids = torch.arange(x.shape[1], dtype=torch.long, device=x.device)[
None, :
].repeat(x.shape[0], 1)
X_token = self.word_embeddings(x)
X_pos = self.position_embeddings(position_ids)
X = X_token + X_pos
X = self.dropout(X)
return X

View File

@@ -0,0 +1,206 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
from enum import Enum
from typing import List, Optional, Tuple, Union
import torch
import torch.nn as nn
from xformers import _is_triton_available
if _is_triton_available():
from xformers.triton.layer_norm import FusedLayerNorm
from collections import namedtuple
class ResidualNormStyle(str, Enum):
"""Support different residual path and norm styles.
See "On Layer Normalization in the Transformer Architecture",
Xiong et al., https://arxiv.org/pdf/2002.04745v1.pdf
"""
Pre = "pre"
Post = "post"
DeepNorm = "deepnorm"
class NormalizationType(str, Enum):
LayerNorm = "layernorm"
Skip = "skip"
# TODO: BatchNorm = "batchnorm"
# TODO: GroupNorm = "groupnorm"
def get_normalization_layer(normalization_type: NormalizationType):
class Skip(nn.Module):
def __init__(self, *_, **__) -> None:
super().__init__()
def forward(self, x: torch.Tensor, **_):
return x
return {
NormalizationType.LayerNorm: nn.LayerNorm,
NormalizationType.Skip: Skip,
}[normalization_type]
class RequiresWrappedInputs:
"""Used to mark, through inheritance,
the fact that this class will require inputs to be passed as a single list"""
pass
# CREDITS: the following is inspired by FastAI's Transformer implementation
class Residual(nn.Module, RequiresWrappedInputs):
"""
Object-oriented handling of the residual path
This supports scaling of the residual path, as proposed by DeepNet_
.. _DeepNet: https://arxiv.org/pdf/2203.00555v1.pdf
.. Note: the wrapped layers must accept all the inputs as a single list
"""
def __init__(self, layer: nn.Module, scale: Optional[float] = None):
super().__init__()
self.layer = layer
self.scale = scale
# PreNorm and PostNorm require all the tensors to be passed as a list
self.wrap_inputs = isinstance(layer, RequiresWrappedInputs)
def forward(self, inputs: List[torch.Tensor], **kwargs):
if self.scale is not None:
residue = inputs[0] * self.scale
else:
residue = inputs[0]
if self.wrap_inputs:
return residue + self.layer(inputs=inputs, **kwargs)
else:
return residue + self.layer(*inputs, **kwargs)
class PreNorm(nn.Module, RequiresWrappedInputs):
"""Adds a normalization before computing attention
..Note: If a list of inputs is passed, all of them get normalized"""
def __init__(
self,
d_norm: int,
sublayer: nn.Module,
normalization: NormalizationType,
use_triton: bool = True,
):
super().__init__()
if (
_is_triton_available()
and use_triton
and normalization == NormalizationType.LayerNorm
):
self.norm: Union[nn.LayerNorm, FusedLayerNorm] = FusedLayerNorm(d_norm)
else:
self.norm = get_normalization_layer(normalization)(d_norm)
self.sublayer = sublayer
self.wrap_inputs = isinstance(sublayer, RequiresWrappedInputs)
def forward(self, inputs: List[torch.Tensor], **kwargs):
assert len(inputs) > 0
# Perf improvement: if the inputs are all the same, only norm once
ids = [id(x) for x in inputs]
if ids.count(ids[0]) == len(ids):
# The same tensor is passed multiple times
x_norm = self.norm(inputs[0])
inputs_normed = [x_norm for _ in inputs]
else:
# The inputs differ, norm them all
inputs_normed = [self.norm(x_) for x_ in inputs]
if self.wrap_inputs:
return self.sublayer(inputs=inputs_normed, **kwargs)
else:
return self.sublayer(*inputs_normed, **kwargs)
class PostNorm(nn.Module, RequiresWrappedInputs):
"""Adds LayerNorm after computing attention"""
def __init__(
self,
d_norm: int,
sublayer: nn.Module,
normalization: NormalizationType,
use_triton: bool = True,
):
super().__init__()
if (
_is_triton_available()
and use_triton
and normalization == NormalizationType.LayerNorm
):
self.norm: Union[nn.LayerNorm, FusedLayerNorm] = FusedLayerNorm(d_norm)
else:
self.norm = get_normalization_layer(normalization)(d_norm)
self.sublayer = sublayer
self.wrap_inputs = isinstance(sublayer, RequiresWrappedInputs)
def forward(self, inputs: List[torch.Tensor], **kwargs):
if self.wrap_inputs:
x = self.sublayer(inputs=inputs, **kwargs)
else:
x = self.sublayer(*inputs, **kwargs)
return self.norm(x)
DeepNormCoefficients = namedtuple("DeepNormCoefficients", ["alpha", "beta"])
def get_deepnorm_coefficients(
encoder_layers: int, decoder_layers: int
) -> Tuple[Optional[DeepNormCoefficients], Optional[DeepNormCoefficients]]:
"""
See DeepNet_.
Returns alpha and beta depending on the number of encoder and decoder layers,
first tuple is for the encoder and second for the decoder
.. _DeepNet: https://arxiv.org/pdf/2203.00555v1.pdf
"""
N = encoder_layers
M = decoder_layers
if decoder_layers == 0:
# Encoder only
return (
DeepNormCoefficients(alpha=(2 * N) ** 0.25, beta=(8 * N) ** -0.25),
None,
)
elif encoder_layers == 0:
# Decoder only
return None, DeepNormCoefficients(alpha=(2 * M) ** 0.25, beta=(8 * M) ** -0.25)
else:
# Encoder/decoder
encoder_coeffs = DeepNormCoefficients(
alpha=0.81 * ((N**4) * M) ** 0.0625, beta=0.87 * ((N**4) * M) ** -0.0625
)
decoder_coeffs = DeepNormCoefficients(
alpha=(3 * M) ** 0.25, beta=(12 * M) ** -0.25
)
return (encoder_coeffs, decoder_coeffs)

View File

@@ -0,0 +1,157 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
from typing import List
import torch
import torch.nn as nn
from torch.autograd.function import Function
from torch.utils.checkpoint import get_device_states, set_device_states
from xformers.components import RequiresWrappedInputs
# CREDITS: Code adapted from
# https://github.com/lucidrains/reformer-pytorch/blob/master/reformer_pytorch/reversible.py
# https://github.com/RobinBruegger/RevTorch/blob/master/revtorch/revtorch.py,
# https://pytorch.org/docs/stable/_modules/torch/utils/checkpoint.html
# pyre-fixme[13]: `cpu_state` is not initialized in the constructor.
class Deterministic(nn.Module):
def __init__(self, net: nn.Module):
super().__init__()
self.net = net
self.cpu_state: torch.Tensor = torch.get_rng_state()
self.cuda_in_fwd: bool = False
self.gpu_devices: List[int] = []
self.gpu_states: List[torch.Tensor] = []
self.wrap_inputs = isinstance(net, RequiresWrappedInputs)
def record_rng(self, *args):
self.cpu_state = torch.get_rng_state()
if torch.cuda._initialized:
self.cuda_in_fwd = True
self.gpu_devices, self.gpu_states = get_device_states(*args)
def forward(self, *args, record_rng: bool = False, set_rng: bool = False, **kwargs):
if record_rng:
self.record_rng(*args)
if not set_rng:
# Normal FW run
if self.wrap_inputs:
return self.net(inputs=args, **kwargs)
else:
return self.net(*args, **kwargs)
else: # pragma: no cover # this is called in the backward pass, not picked up
# This is analogous to checkpointing, reset the original random state
rng_devices: List[int] = []
if self.cuda_in_fwd:
rng_devices = self.gpu_devices
with torch.random.fork_rng(devices=rng_devices, enabled=True):
torch.set_rng_state(self.cpu_state)
if self.cuda_in_fwd:
set_device_states(self.gpu_devices, self.gpu_states)
if self.wrap_inputs:
return self.net(inputs=args, **kwargs)
else:
return self.net(*args, **kwargs)
class ReversibleBlock(nn.Module):
def __init__(self, f: nn.Module, g: nn.Module, split_dim: int = -1):
super().__init__()
self.f = Deterministic(f)
self.g = Deterministic(g)
self.split_dim = split_dim
def forward(self, x: torch.Tensor, f_args={}, g_args={}):
x1, x2 = torch.chunk(x, 2, dim=-1)
y1, y2 = None, None
with torch.no_grad():
y1 = x1 + self.f(x2, record_rng=self.training, **f_args)
y2 = x2 + self.g(y1, record_rng=self.training, **g_args)
return torch.cat([y1, y2], dim=self.split_dim)
def backward_pass(
self, y: torch.Tensor, dy: torch.Tensor, f_args={}, g_args={}
): # pragma: no cover # this is covered, but called directly from C++
y1, y2 = torch.chunk(y, 2, dim=self.split_dim)
del y
dy1, dy2 = torch.chunk(dy, 2, dim=self.split_dim)
del dy
with torch.enable_grad():
y1.requires_grad = True
gy1 = self.g(y1, set_rng=True, **g_args)
torch.autograd.backward(gy1, dy2)
with torch.no_grad():
x2 = y2 - gy1
del y2, gy1
dx1 = dy1 + y1.grad
del dy1
y1.grad = None
with torch.enable_grad():
x2.requires_grad = True
fx2 = self.f(x2, set_rng=True, **f_args)
torch.autograd.backward(fx2, dx1)
with torch.no_grad():
x1 = y1 - fx2
del y1, fx2
dx2 = dy2 + x2.grad
del dy2
x2.grad = None
x = torch.cat([x1, x2.detach()], dim=self.split_dim)
dx = torch.cat([dx1, dx2], dim=self.split_dim)
return x, dx
class _ReversibleFunction(Function):
@staticmethod
def forward(ctx, x, blocks, kwargs):
ctx.kwargs = kwargs
for block in blocks:
x = block(x, **kwargs)
ctx.y = x.detach()
ctx.blocks = blocks
return x
@staticmethod
def backward(
ctx, dy
): # pragma: no cover # this is covered, but called directly from C++
y = ctx.y
kwargs = ctx.kwargs
for block in ctx.blocks[::-1]:
y, dy = block.backward_pass(y, dy, **kwargs)
return dy, None, None
class ReversibleSequence(nn.Module):
def __init__(self, blocks: nn.ModuleList):
super().__init__()
# pyre-fixme[23]: Unable to unpack `torch.nn.Module` into 2 values.
self.blocks = nn.ModuleList([ReversibleBlock(f, g) for f, g in blocks])
def forward(self, x, arg_route=(True, False), **kwargs):
f_args, g_args = map(lambda route: kwargs if route else {}, arg_route)
block_kwargs = {"f_args": f_args, "g_args": g_args}
return _ReversibleFunction.apply(x, self.blocks, block_kwargs)

View File

@@ -0,0 +1,73 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
from dataclasses import asdict, dataclass
from typing import Optional, Type, TypeVar
import torch
from xformers import _is_triton_available
Self = TypeVar("Self", bound="SimplicialEmbedding")
@dataclass
class SimplicialEmbeddingConfig:
L: int
temperature: float
class SimplicialEmbedding(torch.nn.Module):
"""
An implementation of the "Simplicial Embeddings"_, as proposed by Lavoie et. al
Arguments:
- L: the number of embedding chunks
- temperature: optional scaling parameter for the softmax operation.
A small (<1.) temperature will lead to a sparse representation (up to one-hot),
while a large (>1.) temperature will make the vector more uniform
_"Simplicial Embeddings": https://arxiv.org/pdf/2204.00616.pdf
"""
def __init__(self, L: int, temperature: Optional[float] = None) -> None:
super().__init__()
self.L = L
self.temperature = temperature
def forward(self, x: torch.Tensor) -> torch.Tensor:
assert (
x.shape[-1] % self.L == 0
), f"The embedding dimension {x.shape[-1]} is not divisible by the chosen L parameter {self.L}"
# Seperate the input tensor into V chunks
B, C, E = x.shape
V = E // self.L
Vs = x.reshape(B, C, self.L, V)
# Softmax normalize them, with the proposed temperature
# This is done over the last dimension, so only within Vs
if self.temperature is not None:
Vs /= self.temperature
if _is_triton_available():
from xformers.triton.softmax import softmax as triton_softmax
Vs = triton_softmax(
Vs, mask=None, causal=False
) # the softmax is on the last dimension
else:
Vs = torch.nn.functional.softmax(Vs, dim=-1)
# Concatenate back and return
return Vs.reshape(B, C, E)
@classmethod
def from_config(cls: Type[Self], config: SimplicialEmbeddingConfig) -> Self:
# Generate the class inputs from the config
fields = asdict(config)
return cls(**fields)