First commit

This commit is contained in:
2025-08-05 19:02:46 +08:00
parent 9efe891f99
commit 99fb9f5cb0
1412 changed files with 203615 additions and 0 deletions

View File

@@ -0,0 +1,133 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
import logging
from pathlib import Path
from typing import Any, Callable, Dict, Set, Union
import torch
from xformers.utils import (
generate_matching_config,
get_registry_decorator,
import_all_modules,
)
from ._sputnik_sparse import SparseCS
from .attention_mask import AttentionMask
from .base import Attention, AttentionConfig # noqa
logger = logging.getLogger("xformers")
# CREDITS: Classy Vision registry mechanism
ATTENTION_REGISTRY: Dict[str, Any] = {}
ATTENTION_CLASS_NAMES: Set[str] = set()
# Arbitrary threshold for now,
# in between dense and sparse matrix algorithms for the attention mechanism
_DENSITY_THRESHOLD = 0.30 # noqa # from the sputnik paper, vs.
_USE_SPUTNIK = True
def build_attention(config: Union[Dict[str, Any], AttentionConfig]):
"""Builds an attention from a config.
This assumes a 'name' key in the config which is used to determine what
attention class to instantiate. For instance, a config `{"name": "my_attention",
"foo": "bar"}` will find a class that was registered as "my_attention"
(see :func:`register_attention`) and call .from_config on it."""
if not isinstance(config, AttentionConfig):
try:
config_instance = generate_matching_config(
config, ATTENTION_REGISTRY[config["name"]].config
)
except KeyError as e:
name = config["name"]
logger.warning(f"{name} not available among {ATTENTION_REGISTRY.keys()}")
raise e
else:
config_instance = config
return ATTENTION_REGISTRY[config_instance.name].constructor.from_config(
config_instance
)
"""Registers an Attention subclass.
This decorator allows xFormers to instantiate a subclass of Attention
from a configuration file, even if the class itself is not part of the
xFormers library. To use it, apply this decorator to an Attention
subclass, like this:
.. code-block:: python
@dataclass
class MyConfig:
...
@register_attention('my_attention', MyConfig)
class MyAttention(Attention):
...
To instantiate an attention from a configuration file, see :func:`build_attention`."""
register_attention: Callable[[str, Any], Callable[[Any], Any]] = get_registry_decorator(
ATTENTION_REGISTRY, ATTENTION_CLASS_NAMES, Attention, AttentionConfig
)
def maybe_sparsify(matrix) -> Any:
# Sparsify if that makes sense
if torch.count_nonzero(matrix).item() / matrix.numel() > _DENSITY_THRESHOLD:
# If not sparse, then AttentionMask is the reference type
return AttentionMask.from_bool(matrix)
return sparsify(matrix)
def sparsify(matrix):
if _USE_SPUTNIK:
return SparseCS(matrix)
return matrix.to_sparse()
from .favor import FavorAttention # noqa
from .global_tokens import GlobalAttention # noqa
from .linformer import LinformerAttention # noqa
from .local import LocalAttention # noqa
from .nystrom import NystromAttention # noqa
from .ortho import OrthoFormerAttention # noqa
from .random import RandomAttention # noqa
from .scaled_dot_product import ScaledDotProduct # noqa
__all__ = [
"ScaledDotProduct",
"LocalAttention",
"LinformerAttention",
"NystromAttention",
"RandomAttention",
"OrthoFormerAttention",
"GlobalAttention",
"FavorAttention",
"Attention",
"AttentionMask",
"build_attention",
"register_attention",
]
# Optionally expose the BlockSparse attention
try:
from .blocksparse import BlockSparseAttention # noqa
__all__ += ["BlockSparseAttention"]
except ImportError:
pass
# automatically import any Python files in the directory
import_all_modules(str(Path(__file__).parent), "xformers.components.attention")

View File

@@ -0,0 +1,121 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
import torch
from xformers.ops import masked_matmul
from xformers.sparse import SparseCSRTensor
# TODO: this is here for BC
from xformers.sparse.utils import _csr_to_coo, _dense_to_sparse # noqa: F401
class SparseCS:
def __init__(self, matrix, device=None):
if device is None:
device = torch.device("cpu")
if matrix.ndim == 2:
matrix = matrix[None]
assert matrix.ndim == 3
self._mat = SparseCSRTensor.from_dense(matrix).to(device)
@property
def device(self):
return self._mat.device
@property
def ndim(self):
return self._mat.ndim
@property
def dtype(self):
return self._mat.dtype
@property
def is_sparse(self):
return True
@property
def shape(self):
return self._mat.shape[1:]
@property
def values(self):
return self._mat.values()
@property
def row_indices(self):
return self._mat._csr_row_indices
@property
def column_indices(self):
return self._mat._csr_column_indices
@property
def row_offsets(self):
return self._mat._csr_row_offsets
@property
def _transp_info(self):
return self._mat._csr_transp_info
@classmethod
def wrap(
cls, shape, values, row_indices, row_offsets, column_indices, _transp_info
):
matrix = cls.__new__(cls)
_shape = (values.shape[0],) + shape
csr_matrix = SparseCSRTensor._wrap(
_shape, values, row_indices, row_offsets, column_indices, _transp_info
)
matrix._mat = csr_matrix
return matrix
@classmethod
def _wrap(cls, csr_matrix):
assert isinstance(csr_matrix, SparseCSRTensor)
matrix = cls.__new__(cls)
matrix._mat = csr_matrix
return matrix
def __mul__(self, other):
assert isinstance(other, (int, float))
return type(self)._wrap(self._mat * other)
def __add__(self, other):
assert isinstance(other, type(self))
return type(self)._wrap(self._mat + other._mat)
def matmul_with_mask(self, a, b):
return type(self)._wrap(masked_matmul(a, b, self._mat))
def softmax(self):
out = torch.nn.functional.softmax(self._mat, -1)
return type(self)._wrap(out)
def spmm(self, b):
out = torch.bmm(self._mat, b)
return out
def transpose(self):
out = torch.transpose(self._mat, -2, -1)
return type(self)._wrap(out)
def to(self, device):
assert isinstance(device, torch.device)
out = self._mat.to(device)
return type(self)._wrap(out)
def to_dense(self):
return self._mat.to_dense()
def logical_and(self, other: torch.Tensor):
assert not isinstance(other, SparseCS)
out = torch.logical_and(self._mat, other)
return type(self)._wrap(out)
def __and__(self, other):
return self.logical_and(other)

View File

@@ -0,0 +1,143 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
from typing import Optional, Type, TypeVar
import torch
Self = TypeVar("Self", bound="AttentionMask")
class AttentionMask:
"""
Holds an attention mask, along with a couple of helpers and attributes.
.. note: this is an additive mask, meaning that coefficients which should be computed hold the '0.' value,
and coefficients which should be skipped hold the '-inf' value. Any other value is possible if the purpose
is to bias the attention computation for instance
.. note: the attention mask dimensions are expected to be `[batch, to_sequence, from_sequence]`,
`[to_sequence, from_sequence]`, or anything broadcastable in between
"""
def __init__(self, additive_mask: torch.Tensor, is_causal: bool = False):
assert additive_mask.is_floating_point(), additive_mask.dtype
assert not additive_mask.requires_grad
if additive_mask.ndim == 2:
additive_mask = additive_mask.unsqueeze(0)
self.values = additive_mask
self.is_causal = is_causal
self.seq_len = additive_mask.shape[1]
self.to_seq_len = additive_mask.shape[0]
def to_bool(self) -> torch.Tensor:
"""
.. warning: we assume here that True implies that the value should be computed
"""
return self.values != float("-inf")
@classmethod
def from_bool(cls: Type[Self], x: torch.Tensor) -> Self:
"""
Create an AttentionMask given a boolean pattern.
.. warning: we assume here that True implies that the value should be computed
"""
assert x.dtype == torch.bool
additive_mask = torch.empty_like(x, dtype=torch.float, device=x.device)
additive_mask.masked_fill_(x, 0.0)
additive_mask.masked_fill_(~x, float("-inf"))
return cls(additive_mask)
@classmethod
def from_multiplicative(cls: Type[Self], x: torch.Tensor) -> Self:
"""
Create an AttentionMask given a multiplicative attention mask.
"""
assert not x.dtype == torch.bool
additive_mask = torch.empty_like(x, dtype=torch.float, device=x.device)
x = x.bool()
additive_mask.masked_fill_(x, 0.0)
additive_mask.masked_fill_(~x, float("-inf"))
return cls(additive_mask)
@classmethod
def make_causal(
cls: Type[Self],
seq_len: int,
to_seq_len: Optional[int] = None,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
) -> Self:
if not to_seq_len:
to_seq_len = seq_len
additive_mask = torch.triu(
torch.ones(seq_len, to_seq_len, device=device, dtype=dtype) * float("-inf"),
diagonal=1,
)
return cls(additive_mask=additive_mask, is_causal=True)
def make_crop(
self, seq_len: int, to_seq_len: Optional[int] = None
) -> "AttentionMask":
"""
Return a cropped attention mask, whose underlying tensor is a view of this one
"""
if not to_seq_len:
to_seq_len = seq_len
return AttentionMask(
self.values[:, :seq_len, :to_seq_len], is_causal=self.is_causal
)
def __repr__(self):
return f"AttentionMask - causal {self.is_causal} - mask " + str(self.values)
@property
def device(self):
return self.values.device
@property
def is_sparse(self):
return False
@property
def ndim(self):
return len(self.values.shape)
@property
def dtype(self):
return self.values.dtype
@property
def shape(self):
return self.values.shape
def __add__(self, other):
return AttentionMask(self.values + other.values, is_causal=False)
def to(
self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None
) -> "AttentionMask":
assert device is None or isinstance(device, torch.device)
assert dtype is None or isinstance(dtype, torch.dtype)
assert device is not None or dtype is not None
# Noop if we don't need to create another instance
if ((device and device == self.device) or not device) and (
(dtype and dtype == self.dtype) or not dtype
):
return self
return AttentionMask(self.values.to(device=device, dtype=dtype), self.is_causal)

View File

@@ -0,0 +1,295 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
import math
from typing import List
import numpy as np
import torch
from xformers.components.attention.sparsity_config import (
BigBirdSparsityConfig,
BSLongformerSparsityConfig,
FixedSparsityConfig,
VariableSparsityConfig,
)
# generic nd cases
def _generate_nd_grid(*sizes):
coords = [torch.arange(s) for s in sizes]
return torch.meshgrid(*coords)
def local_nd_distance(*sizes, p=2.0, weights=None):
if weights is None:
weights = (1,) * len(sizes)
assert len(sizes) == len(weights)
grid = _generate_nd_grid(*sizes)
grid = [i.flatten() * w for i, w in zip(grid, weights)]
grid = torch.stack(grid, dim=1).float()
d = torch.cdist(grid, grid, p=p)
return d
def local_nd_gaussian_distribution(*sizes, sigma=1):
d = local_nd_distance(*sizes, p=2.0) ** 2
d = torch.exp(-0.5 * sigma ** (-2.0) * d)
return d
def local_nd_pattern(*sizes, distance, p=2.0):
d = local_nd_distance(*sizes, p=p)
return d < distance
def axial_nd_pattern(*sizes):
# axial is a special case with p=0 and distance=2
d = local_nd_distance(*sizes, p=0)
return d < 2
def random_pattern_from_probability_matrix(dist_matrix, nnz):
att = torch.zeros_like(dist_matrix, dtype=torch.bool)
# PyTorch multinomial wrongly doesn't support sampling when number of categories
# is > 2^24, arguing that it's because it's the max representable consecutive element
# in fp32 and that the kernels use float32. This is actually not true, and the kernels
# should work fine if double tensor is passed on CPU. This is a bug that was introduced
# in https://github.com/pytorch/pytorch/commit/bf04c2ca2f591d98ce57816f0ef0cd20a21bbf66
# when unifying the checks between CPU and CUDA. For now, just fall-back to numpy
if dist_matrix.numel() > 2**24:
dist_matrix = dist_matrix.double()
dist_matrix /= dist_matrix.sum()
idxs = np.random.choice(
dist_matrix.numel(), nnz, p=dist_matrix.flatten(), replace=False
)
idxs = torch.as_tensor(idxs)
else:
idxs = torch.multinomial(dist_matrix.flatten(), nnz, replacement=False)
att.view(-1)[idxs] = True
return att
def global_token_pattern(attention_query_mask: torch.Tensor) -> torch.Tensor:
assert attention_query_mask.ndim == 1
assert attention_query_mask.dtype == torch.bool
attention_query_mask = attention_query_mask[None, :]
mask = attention_query_mask | attention_query_mask.transpose(1, 0)
return mask
def random_pattern(attn_size: int, sparsity: float) -> torch.Tensor:
assert 0 < sparsity < 1
mask = torch.rand(attn_size, attn_size) > sparsity
return mask
# 1d-specific cases
def local_1d_pattern(attn_size: int, window_size: int) -> torch.Tensor:
assert (
window_size % 2 == 1
), "The window size is assumed to be odd (counts self-attention + 2 wings)"
h_win_size = window_size // 2 + 1
return local_nd_pattern(attn_size, distance=h_win_size, p=1.0)
def causal_1d_pattern(attn_size: int) -> torch.Tensor:
mask = torch.tril(torch.ones(attn_size, attn_size, dtype=torch.bool))
return mask
# 2d-specific cases
def horizontal_axial_2d_distance(H, W, p=2.0):
d = local_nd_distance(H, W, p=p, weights=(1, 0))
return d
def vertical_axial_2d_distance(H, W, p=2.0):
d = local_nd_distance(H, W, p=p, weights=(0, 1))
return d
def local_2d_distance(H, W, p=2.0):
return local_nd_distance(H, W, p=p)
def local_2d_gausian_distribution(H, W, sigma=1):
return local_nd_gaussian_distribution(H, W, sigma=sigma)
def local_2d_pattern(H, W, distance, p=2.0):
return local_nd_pattern(H, W, distance=distance, p=p)
def axial_2d_pattern(H, W):
return axial_nd_pattern(H, W)
def swin_attention_pattern(H, W, window_size, shift_size=0):
assert H % window_size == 0
assert W % window_size == 0
assert 0 <= shift_size < window_size, "shift_size must in 0-window_size"
# input grid
i, j = _generate_nd_grid(H, W)
i, j = i + 0.5, j + 0.5
# anchors grid
# if shift is present, add extra element to the grid
# to account for the uneven partitioning
extra = int(shift_size % window_size != 0)
grid_h = H // window_size + extra
grid_w = W // window_size + extra
ii, jj = _generate_nd_grid(grid_h, grid_w)
# convert shift to be compatible with the paper representation
s = (-shift_size) % window_size
offset = window_size / 2 - s
ii = ii * window_size + offset
jj = jj * window_size + offset
input_coords = torch.stack([i.flatten(), j.flatten()], 1).float()
anchors_coords = torch.stack([ii.flatten(), jj.flatten()], 1).float()
anchor_id = torch.cdist(input_coords, anchors_coords, p=2).argmin(1)
mask = anchor_id[:, None] == anchor_id[None, :]
return mask
def dilated_2d_pattern(H, W, k=2):
"""
Returns a 2d pattern that samples 1 every k elements in the attention mask.
Can be seen as a form of downsampling, where every pixel attends to a downsampled
version of the input.
"""
d_h = local_nd_distance(H, W, p=1, weights=(1, 0))
d_w = local_nd_distance(H, W, p=1, weights=(0, 1))
d = (d_h.floor() % k == 0) & (d_w.floor() % k == 0)
return d
# Block sparse utils
def block_sparsify_tensor(x, mask, block_size):
"""
Block sparsify a tensor, given a mask and block size
"""
ret = torch.empty(
(x.size(0), mask.sum(), block_size, block_size), dtype=x.dtype, device=x.device
)
for idx, (h, i, j) in enumerate(zip(*mask.nonzero(as_tuple=True))):
ret[:, idx, :, :] = x[
:,
h,
i * block_size : (i + 1) * block_size,
j * block_size : (j + 1) * block_size,
]
return ret
def pattern_to_layout(mask: torch.Tensor, block_size: int) -> torch.Tensor:
r"""
Given a mask pattern and blocksize, return the corresponding layout
which makes sure that all the positives in the mask are covered
"""
assert mask.ndim >= 2, "We're expecting [Heads, Seq, Seq] or [Seq, Seq]"
_should_squeeze = False
if mask.ndim == 2:
mask = mask.unsqueeze(0)
_should_squeeze = True
assert (
mask.shape[1] % block_size == 0 and mask.shape[2] % block_size == 0
), "We're only handling masks divisible by block_size"
# Now mark the mask
layout = torch.nn.functional.max_pool2d(
mask.to(torch.float), kernel_size=block_size, stride=block_size
)
layout = layout.to(torch.long)
if _should_squeeze:
layout.squeeze_(0)
return layout
def alibi_pattern(threshold: float, mask_shape: torch.Size) -> torch.Tensor:
r"""
Use the additive bias computation from ALiBi_ to generate a mask.
Note that this mask can in turn be used to generate a blocksparse attention computation layout
.. note: mask_shape is expected to hold the [heads, seq, seq] dimensions
.. _ALiBi: https://arxiv.org/pdf/2108.12409.pdf
"""
# CREDITS: code snippet from Ofir Press, one of the authors
def get_slopes(n: int):
def get_slopes_power_of_2(n: int) -> List[float]:
start = 2 ** (-(2 ** -(math.log2(n) - 3)))
ratio = start
return [start * ratio**i for i in range(n)]
# In the paper, we only train models that have 2^a heads for some a. This function has
# some good properties that only occur when the input is a power of 2. To maintain that even
# when the number of heads is not a power of 2, we use this workaround.
if math.log2(n).is_integer():
return get_slopes_power_of_2(n)
else:
closest_power_of_2 = 2 ** math.floor(math.log2(n))
return (
get_slopes_power_of_2(closest_power_of_2)
+ get_slopes(2 * closest_power_of_2)[0::2][: n - closest_power_of_2]
)
maxpos = mask_shape[1]
attn_heads = mask_shape[0]
slopes = torch.Tensor(get_slopes(attn_heads))
# In the next line, the part after the * is what constructs the diagonal matrix
# (right matrix in Figure 3 in the paper).
# If you run it you'll see that it doesn't exactly print out the same matrix as we have in Figure 3,
# but one where all rows are identical.
# This works because the softmax operation is invariant to translation,
# and our bias functions are always linear.
alibi = slopes.unsqueeze(1).unsqueeze(1) * torch.arange(maxpos).unsqueeze(
0
).unsqueeze(0).expand(attn_heads, -1, -1)
alibi = alibi.view(attn_heads, 1, maxpos)
# Now threshold arbitrarily, report the mask
return alibi < threshold
def quick_fixed_layout(num_heads: int, block_size: int, seq_len: int):
config = FixedSparsityConfig(num_heads=num_heads, block_size=block_size)
return config.make_layout(seq_len)
def quick_variable_layout(num_heads: int, block_size: int, seq_len: int):
config = VariableSparsityConfig(num_heads=num_heads, block_size=block_size)
return config.make_layout(seq_len)
def quick_bigbird_layout(num_heads: int, block_size: int, seq_len: int):
config = BigBirdSparsityConfig(num_heads=num_heads, block_size=block_size)
return config.make_layout(seq_len)
def quick_bslongformer_layout(num_heads: int, block_size: int, seq_len: int):
config = BSLongformerSparsityConfig(num_heads=num_heads, block_size=block_size)
return config.make_layout(seq_len)
def layout_to_pattern(layout: torch.Tensor, block_size: int):
r"""
create a pattern of shape [heads, seq, seq] out of a blocksparse
layout of shape [heads, seq/block_size, seq/block_size]
"""
return torch.kron(layout, torch.ones(block_size, block_size))

View File

@@ -0,0 +1,93 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
from abc import ABCMeta, abstractmethod
from dataclasses import asdict, dataclass
from typing import Optional, Type, TypeVar
import torch
import torch.nn as nn
from xformers.components.attention import AttentionMask
@dataclass
class AttentionConfig:
"""Parameters required for all Attentions.
Can accept and store extra parameters.
"""
name: str # the registered name for this attention mechanism
dropout: float # dropout probability
Self = TypeVar("Self", bound="Attention")
# Define the common interface, every attention block needs to derive from it
class Attention(nn.Module, metaclass=ABCMeta):
r"""The base Attention mechanism, which is typically a sub-part of the multi-head attention"""
_causal_mask: Optional[AttentionMask] = None
@abstractmethod
def __init__(self, dropout: Optional[float] = None, *args, **kwargs):
super().__init__()
# Requires the inputs to be projected
self.requires_input_projection = True
# Whether the head dimension needs to be present (if not it can be folded into the batch dimension)
self.requires_head_dimension = False
# key padding mask and attention mask must be passed in as separate arguments instead of a merged attention mask
self.requires_separate_masks = False
# Requires that K and Q have the same sequence length
self.requires_same_k_q_dimensions = False
# Whether the attention owns the single head/multihead mechanism
# so that the MHA wrapper should skip it
self.requires_skip_multi_head = False
# This attention requires a context length which is squared, often due to 2D pooling
self.requires_squared_context = False
# Whether this attention mechanism supports attention masks
self.supports_attention_mask = True
self.supports_key_padding_mask = False
@classmethod
def from_config(cls: Type[Self], config: AttentionConfig) -> Self:
# Generate the class inputs from the config
fields = asdict(config)
# Skip all Nones so that default values are used
fields = {k: v for k, v in fields.items() if v is not None}
return cls(**fields)
@abstractmethod
def forward(
self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, *args, **kwargs
) -> torch.Tensor:
raise NotImplementedError
@staticmethod
def _maybe_pad_sequence(x: torch.Tensor, mask: torch.Tensor):
"""
If the sequence is shorter than the mask, return a padded view
"""
if x.shape[-2] != mask.shape[-1]:
assert x.shape[-2] < mask.shape[-1], (
"Sequence is bigger than the provided mask, cannot infer what to do with it."
" Please update your attention mask"
)
pad_size = (0, 0, 0, mask.shape[-1] - x.shape[-2], 0, 0)
return torch.nn.functional.pad(x, pad_size, mode="constant", value=0.0)
return x

View File

@@ -0,0 +1,190 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
import logging
import math
from dataclasses import dataclass
import torch
from xformers import _is_triton_available
from xformers.components.attention import Attention, AttentionConfig, register_attention
logger = logging.getLogger("xformers")
_is_blocksparse_available = _is_triton_available()
if _is_blocksparse_available:
from triton.ops.blocksparse import matmul as blocksparse_matmul # type: ignore
from triton.ops.blocksparse import softmax as blocksparse_softmax # type: ignore
from xformers.triton.utils import gpu_capabilities_older_than_70
# Blocksparse requires Tensor cores
if gpu_capabilities_older_than_70():
logger.warning(
"Blocksparse is not available: the current GPU does not expose Tensor cores"
)
_is_blocksparse_available = False
if _is_blocksparse_available:
@dataclass
class BlockSparseAttentionConfig(AttentionConfig):
layout: torch.Tensor # The dimensions of the random features
block_size: int
dropout: float
num_heads: int
@register_attention("blocksparse", BlockSparseAttentionConfig)
class BlockSparseAttention(Attention):
r"""
Thin wrap over the Triton blocksparse computations. The sparsity pattern is determined through the layout.
.. warning: the layout is assumed to have the dimensions [heads, seq, seq].
If some dimensions are missing, we assume that the same layout is to be used across heads.
.. warning: for now, the sequence (context) length has to be a power of two. This constraint could
be relaxed in the future.
.. warning: the block size has to be picked from [16, 32, 64]. Some speed is gained from bigger blocks.
It is of course possible to reproduce coarser patterns given these primitives, as the user sees fit.
"""
def __init__(
self,
layout: torch.Tensor,
block_size: int = 16,
dropout: float = 0.0,
num_heads: int = 1, # optional, used to adapt the layout if in need
causal: bool = False,
*args,
**kwargs,
):
if layout.dim() == 2:
logger.warning(
"The layout passed is lacking a head dimension and a batch dimension"
)
logger.warning(
"Now assuming that the same layout is to be used across all heads"
)
layout = layout.unsqueeze(0).expand(num_heads, -1, -1)
logger.warning(f"New layout dimensions: {layout.shape}")
assert block_size in (
16,
32,
64,
128,
), "Only block sizes in [16, 32, 64, 128] are supported"
super().__init__()
self.causal = causal
self.attn_drop = torch.nn.Dropout(dropout, inplace=False)
# Pure blocksparse data
self.layout = layout
self.block_size = block_size
# make sure that the head dimension is not folded down with the batch
self.requires_head_dimension = True
# key padding mask and attention mask must be passed in separately
self.requires_same_k_q_dimensions = True
# The underlying triton op does not support per element attention mask
self.supports_attention_mask = False
self.supports_key_padding_mask = False
def create_triton_kernels(self, device):
# blocksparse operators
self.sparse_dot_sdd = blocksparse_matmul(
self.layout,
self.block_size,
"sdd",
trans_a=False,
trans_b=True,
device=device,
)
self.sparse_dot_dsd = blocksparse_matmul(
self.layout,
self.block_size,
"dsd",
trans_a=False,
trans_b=False,
device=device,
)
self.sparse_softmax = blocksparse_softmax(
self.layout,
self.block_size,
device=device,
)
def forward(
self,
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
scale: float = 1.0,
*args,
**kwargs,
) -> torch.Tensor:
assert (
"att_mask" not in kwargs.keys() and "att_mask" not in args
), "This attention does not support an attention mask, but you can specify causality."
r"""
A thin wrap around the Triton blockparse attention operation
.. note: Per element attention mask is not supported, but you can specify causality
"""
# Delayed triton init, to make sure that we get the right device
# Infer device from query
if not hasattr(self, "sparse_dot_sdd"):
self.create_triton_kernels(q.device)
assert (
q.shape[-2] == k.shape[-2]
), "Blocksparse requires the same dimensions for K and Q for now"
assert (
q.shape[-2] == self.layout.shape[-2] * self.block_size
), "Actual sequence size and layout are inconsistent"
assert (
k.shape[-2] == self.layout.shape[-2] * self.block_size
), "Actual sequence size and layout are inconsistent"
assert (
q.shape[-2] % self.block_size
) == 0, "Sequence length {} must be a multiple of block size {}".format(
q.shape[-2], self.block_size
)
# Self-attend: (B, nh, S, hs) x (B, nh, hs, S) -> (B, nh, S, S)
# When the computations are block sparse, the matrix types change along the way:
# - (sparse) attention matrix = (dense) Kt * (dense) Q
q = q / math.sqrt(q.size(-1))
sparse_att_mat = self.sparse_dot_sdd(q, k)
# - softmax on the sparse attention matrix
sparse_att_mat = self.sparse_softmax(
sparse_att_mat, scale=scale, is_causal=self.causal
)
sparse_att_mat = self.attn_drop(sparse_att_mat)
# - then (dense) attention is (sparse) attention matrix * dense (value)
a = self.sparse_dot_dsd(sparse_att_mat, v)
return a

View File

@@ -0,0 +1,341 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
# Credits: this is heavily inspired by the official implementation, present in
# https://github.com/sarthmit/Compositional-Attention
# Original author: Sarthak Mittal
# This is a simplified version, for the sake of clarity, and because some features could be exposed later
# via the library directly.
# In particular, code paths for TPUs, quantization and gumbel softmax have been removed
# We're also following the same dimension ordering as in the rest of the xformers library
# which is to say [Batch, Sequence, Embedding] wherever possible
import math
from dataclasses import dataclass
from typing import Optional
import torch
import torch.nn.functional as F
from torch import Tensor, nn
from xformers.components.attention import (
Attention,
AttentionConfig,
AttentionMask,
register_attention,
)
from xformers.components.attention.core import _softmax
from xformers.components.input_projection import InputProjection, InputProjectionConfig
def _either_or(a: Optional[int], b: int) -> int:
return a if a is not None else b
@dataclass
class CompositionalAttentionConfig(AttentionConfig):
dim_model: int
num_heads: int
dim_attn: Optional[int] = None
num_rules: Optional[int] = None
dim_key: Optional[int] = None
dim_value: Optional[int] = None
dim_selection: Optional[int] = None
dropout: float
qk_rule: bool = False
nonlinear: bool = False
q_compose: bool = False
bias: bool = True
causal: Optional[bool] = False
in_proj_container: Optional[InputProjection] = None
use_separate_proj_weight: Optional[bool] = False
@register_attention("compositional", CompositionalAttentionConfig)
class CompositionalAttention(Attention):
"""Compositional Attention, as proposed in
"Compositional Attention: Disentangling search and retrieval"_, S. Mittal et al.
A key insight from this proposal is that the attention mechanism can be conceived as two steps:
a search and a retrieval operation. When queried, the model can search for the most relevant information
(Softmax(QKt)), then retrieve information given the Value.
Contrary to the original attention proposal, which does not consider interactions in between heads,
the compositional attention will consider all possible interactions and softmax over that dimension,
so that the information retrieved covers the most relevant dimensions. The number of heads and rules to
use is thus typically smaller than for a comparable traditional Transformer, and asking for the same number of heads
may not fit in memory.
Args:
dim_model: dimension of the incoming latent space
num_heads: number of heads *for the search operation*
dim_attn: dimension (embedding) of the attention
num_rules: number of rules to consider *for the retrieval operation*
dim_selection: dimension of the scoring/selection space for the retrievals
dim_key, dim_value: dimensions of K and V, if different from Q
dropout: attention dropout probability
qk_rule: QK product will drive the retrieval process
nonlinear: use a non linear method to score the retrievals
bias: use bias in the initial projection step
causal: causal computations (attend to the past only)
_"Compositional Attention: Disentangling search and retrieval": https://arxiv.org/pdf/2110.09419v1.pdf
"""
def __init__(
self,
dim_model: int,
num_heads: int,
dim_attn: Optional[int] = None,
num_rules: Optional[int] = None,
dim_selection: Optional[int] = None,
dim_key: Optional[int] = None,
dim_value: Optional[int] = None,
dropout=0.0,
qk_rule=False,
nonlinear=False,
q_compose=False,
in_proj_container: Optional[InputProjection] = None,
use_separate_proj_weight: Optional[bool] = False,
bias=True,
causal=False,
*_,
**__,
):
super().__init__()
# Define the inherited flags
self.requires_skip_multi_head = (
True # This attention owns the multi-head mechanism
)
# Handle defaults / undefined values
self.dim_model = dim_model
num_rules = _either_or(num_rules, num_heads)
dim_selection = _either_or(dim_selection, dim_model // num_heads)
# All the initial definition plumbing
dim_attn = _either_or(dim_attn, dim_model)
dim_key = _either_or(dim_key, dim_model)
dim_value = _either_or(dim_value, dim_model)
self.in_proj_container = (
in_proj_container
if in_proj_container is not None
else InputProjection(
query_proj_params=InputProjectionConfig(dim_model, dim_key, bias=bias),
key_proj_params=InputProjectionConfig(dim_model, dim_key, bias=bias)
if use_separate_proj_weight
else None,
value_proj_params=InputProjectionConfig(dim_model, dim_value, bias=bias)
if use_separate_proj_weight
else None,
)
)
self.num_heads = num_heads
self.num_rules = num_rules
self.qk_rule = qk_rule
self.dim_selection = dim_selection
self.nonlinear = nonlinear
self.q_compose = q_compose
self.dropout_module = nn.Dropout(dropout)
self.dim_head = dim_model // num_heads
self.value_dim = dim_attn // num_rules
assert (
self.value_dim * num_rules == dim_attn
), "value_dim must be divisible by num_rules"
self.scaling = self.dim_head**-0.5
self.scaling_values = self.dim_selection**-0.5
self.out_proj = nn.Linear(self.num_heads * self.value_dim, dim_model, bias=bias)
if self.qk_rule:
self.value_k = nn.Linear(self.value_dim, self.dim_selection, bias=bias)
if self.q_compose:
self.value_q = nn.Linear(self.dim_head, self.dim_selection, bias=bias)
else:
self.value_q = nn.Linear(
dim_model, self.dim_selection * self.num_heads, bias=bias
)
else:
if self.q_compose:
self.value_q = nn.Linear(self.dim_head, self.dim_selection, bias=bias)
else:
self.value_q = nn.Linear(
dim_model, self.dim_selection * self.num_heads, bias=bias
)
if self.nonlinear:
self.score_network: nn.Module = nn.Sequential(
nn.Linear(
self.dim_selection + self.value_dim,
self.dim_selection,
bias=bias,
),
nn.ReLU(),
nn.Linear(self.dim_selection, 1, bias=bias),
)
else:
self.score_network = nn.Linear(
self.dim_selection + self.value_dim, 1, bias=bias
)
self.causal = causal
# Properties specific to this attention mechanism
self.supports_attention_mask = True
self.supports_key_padding_mask = False
self._reset_parameters()
def _reset_parameters(self):
# NOTE: in_proj_container is already initialized
if self.qk_rule:
nn.init.xavier_uniform_(self.value_k.weight, gain=1 / math.sqrt(2))
nn.init.xavier_uniform_(self.value_q.weight, gain=1 / math.sqrt(2))
else:
nn.init.xavier_uniform_(self.value_q.weight)
if self.nonlinear:
nn.init.xavier_uniform_(self.score_network[0].weight)
nn.init.xavier_uniform_(self.score_network[2].weight)
else:
nn.init.xavier_uniform_(self.score_network.weight)
nn.init.xavier_uniform_(self.out_proj.weight)
if self.out_proj.bias is not None:
nn.init.constant_(self.out_proj.bias, 0.0)
def forward(
self,
q: Tensor,
k: Tensor,
v: Tensor,
att_mask: Optional[Tensor] = None,
*args,
**kwargs,
) -> Tensor:
"""
Input shape: Time x Batch x Channel
Args:
att_mask (ByteTensor, optional): typically used to
implement causal attention, where the mask prevents the
attention from looking forward in time (default: None).
"""
B, Sq, E = q.shape
_, Sk, _ = k.shape
assert E == self.dim_model
# First define projected query/key/values
# We keep the projected and original tensors in flight,
# depending on the options the original values could be reused
q_unprojected = q
q, k, v = self.in_proj_container(query=q, key=k, value=v)
q *= self.scaling
# Init causal mask if needed, now that we know the context length
if self.causal and (
self._causal_mask is None or self._causal_mask.shape[0] != Sk
):
self._causal_mask = AttentionMask.make_causal(Sq, Sq, device=q.device)
# Convenience, create an attention mask if a tensor was passed
# This sanitizes different mask types being passed, from now on it's additive
if isinstance(att_mask, torch.Tensor):
# By default we don't know of the causality, and a check would be expensive
att_mask_additive: Optional[AttentionMask] = (
AttentionMask.from_bool(att_mask)
if att_mask.dtype == torch.bool
else AttentionMask(att_mask, is_causal=False)
)
else:
att_mask_additive = None
# Handle the attention and key padding masks
if self._causal_mask is not None:
# Optionally add the causal mask
if att_mask_additive is not None:
att_mask_additive += self._causal_mask
else:
att_mask_additive = self._causal_mask
# Flatten the heads or the rules
q = (
q.view(B, Sq, self.num_heads, self.dim_head)
.movedim(2, 1)
.flatten(0, 1) # [B * num_heads, Sq, dim_head]
)
k = (
k.view(B, Sk, self.num_heads, self.dim_head).movedim(2, 1).flatten(0, 1)
) # [B * num_heads, Sk, dim_head]
v = v.view(B, -1, self.num_rules, self.value_dim).movedim(2, 1).flatten(0, 1)
# Compute the search: Softmax(QKt)
attn_weights = torch.bmm(q, k.transpose(1, 2)) # [B * self.num_heads, Sq, Sk]
if att_mask_additive is not None:
attn_weights += att_mask_additive.values
attn_weights = _softmax(attn_weights, causal=self.causal)
attn_weights = attn_weights.view(B, self.num_heads, Sq, Sk)
attn_probs = self.dropout_module(attn_weights)
# Now compute the information retrieval
# keep all the heads in flight, we'll score the different possibilities
# - compute all the possible retrievals
v = v.view(B, 1, self.num_rules, Sk, self.value_dim)
attn_probs = attn_probs.unsqueeze(2)
attn = torch.matmul(attn_probs, v).view(
B, self.num_heads, self.num_rules, Sq, self.value_dim
)
attn = attn.movedim(3, 1) # [B, Sq, H, Rules, Values]
# - search the most appropriate retrieval among all the values
if self.q_compose:
v_q = self.value_q(q.transpose(0, 1)).view(
B, Sq, self.num_heads, 1, self.dim_selection
)
else:
v_q = self.value_q(q_unprojected).view(
B, Sq, self.num_heads, 1, self.dim_selection
)
if self.qk_rule:
v_q *= self.scaling_values
v_k = (
self.value_k(attn)
.view(B, Sq, self.num_heads, self.num_rules, self.dim_selection)
.transpose(4, 3)
.contiguous()
)
v_score = torch.matmul(v_q, v_k).view(
B, Sq, self.num_heads, self.num_rules, 1
)
else:
v_q = v_q.expand(-1, -1, -1, self.num_rules, -1)
v_in = torch.cat([attn, v_q], dim=-1)
v_score = self.score_network(v_in).view(
B, Sq, self.num_heads, self.num_rules, 1
)
v_score = F.softmax(v_score, dim=3)
# - extracted values are the original attention (inc. all the values) weighted by value score
attn = (attn * v_score).sum(dim=3).view(B, Sq, self.num_heads * self.value_dim)
# Final attention projection, same as other mechanisms
attn = self.out_proj(attn)
return attn

View File

@@ -0,0 +1,342 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
import logging
import math
from contextlib import nullcontext
from functools import lru_cache
from typing import Optional, Union
import torch
from xformers import _has_cpp_library, _is_triton_available
from xformers.components.attention.attention_mask import AttentionMask
if _has_cpp_library:
from ._sputnik_sparse import SparseCS
if _is_triton_available():
from xformers.triton.softmax import softmax as triton_softmax
from xformers.triton.utils import gpu_capabilities_older_than_70
_is_blocksparse_available = (
_is_triton_available() and not gpu_capabilities_older_than_70()
)
if _is_blocksparse_available:
from xformers.components.attention.blocksparse import BlockSparseAttention
logger = logging.getLogger("xformers")
def _create_random_sparsity(matrix, sparsity, divisible_by=4):
assert matrix.ndim == 3
keep = torch.rand_like(matrix[0], dtype=torch.float32) > sparsity
nonzero = torch.nonzero(keep)
nnz = nonzero.shape[0]
# NOTE: need to make it a multiple of 4 for sputnik
nonzero = nonzero[: (nnz - nnz % divisible_by)]
i, j = nonzero.unbind(1)
output = torch.zeros_like(matrix)
bdim = torch.arange(matrix.shape[0], device=matrix.device)[:, None]
output[bdim, i, j] = matrix[bdim, i, j]
return output
def _broadcast_batch(mask, batch_size):
if mask.ndim == 3:
return mask
assert mask.ndim == 2
mask = mask.coalesce()
values = mask.values()
indices = mask.indices()
nnz = len(values)
# strategy: repeat the indices and append the extra batch dimension to the indices
indices = indices.repeat(1, batch_size)
# now create the batch indices
batch_indices = torch.arange(batch_size, device=indices.device)
batch_indices = batch_indices[:, None].expand(batch_size, nnz).flatten()
# put them together
indices = torch.cat([batch_indices[None, :], indices], dim=0)
# now repeat the values
values = values.repeat(batch_size)
size = (batch_size,) + mask.shape
return torch.sparse_coo_tensor(indices, values, size)
def _matmul_with_mask(
a: torch.Tensor,
b: torch.Tensor,
mask: Optional[Union[torch.Tensor, "SparseCS"]],
) -> torch.Tensor:
if mask is None:
return a @ b
if _has_cpp_library and mask.dtype == torch.bool:
if isinstance(mask, SparseCS):
return mask.matmul_with_mask(a, b)
if mask.is_sparse:
# perform broadcasting if needed
mask = _broadcast_batch(mask, a.shape[0])
# coalesced is not implemented for bool tensors, so need to cast
mask = mask.to(dtype=a.dtype) # type: ignore # mypy is missing the catch above
return torch.ops.xformers.matmul_with_mask(a, b, mask)
# Non optimized codepath
if _has_cpp_library:
assert not isinstance(mask, SparseCS)
att = a @ b
if mask.dtype == torch.bool:
assert not isinstance(mask, SparseCS)
if mask.ndim == 2:
mask = mask.unsqueeze(0).expand(att.shape[0], -1, -1)
# mask is presumed false == ignore
att[~mask] = float("-inf")
else:
# mask is presumed additive
# repeat if batch sizes don't match
if (
not isinstance(mask, SparseCS)
and mask.ndim == 3
and mask.shape[0] != att.shape[0]
and (att.shape[0] % mask.shape[0]) == 0
):
repeat_factor = att.shape[0] // mask.shape[0]
mask = mask.repeat([repeat_factor, 1, 1])
logger.info("Mismatched batch dimensions for mask, repeating mask.")
att += mask
return att
def _softmax(a: torch.Tensor, causal: bool = False) -> torch.Tensor:
if _has_cpp_library and isinstance(a, SparseCS):
return a.softmax()
if a.is_sparse:
return torch.sparse.softmax(a, dim=a.ndim - 1)
if _is_triton_available():
return triton_softmax(a, mask=None, causal=causal)
else:
return torch.softmax(a, dim=a.ndim - 1)
if _has_cpp_library:
class SparseBMM(torch.autograd.Function):
@staticmethod
def forward(ctx, a, b):
a = a.coalesce()
r = torch.bmm(a, b)
ctx.save_for_backward(a, b)
return r
@staticmethod
def backward(ctx, grad):
a, b = ctx.saved_tensors
# gradients w.r.t. a
ga = None
if ctx.needs_input_grad[0]:
ga = torch.ops.xformers.matmul_with_mask(grad, b.transpose(-2, -1), a)
# gradients w.r.t. b
gb = None
if ctx.needs_input_grad[1]:
gb = a.transpose(1, 2).bmm(grad)
return ga, gb
def _sparse_bmm(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
"""
Batch matrix multiply between a sparse matrix and a dense matrix
"""
assert a.ndim == b.ndim == 3
assert a.shape[0] == b.shape[0]
assert a.shape[2] == b.shape[1]
return SparseBMM.apply(a, b)
def bmm(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
if _has_cpp_library:
if isinstance(a, SparseCS):
return a.spmm(b)
if a.is_sparse:
return _sparse_bmm(a, b)
return a @ b
def _apply_dropout(att, dropout):
if dropout is None:
return att
# Dropout chokes on sparse tensors
if _has_cpp_library:
if isinstance(att, SparseCS):
values = att.values.clone()
values = dropout(values)
att = SparseCS.wrap(
att.shape,
values,
att.row_indices,
att.row_offsets,
att.column_indices,
att._transp_info,
)
elif att.is_sparse:
att = att.coalesce()
values = att.values().clone() # protect against in-place dropout
values = dropout(values)
att = torch.sparse_coo_tensor(att.indices(), values, att.shape)
else:
# Simple dense case
att = dropout(att)
return att
# Non optimized vanilla dropout
att = dropout(att)
return att
def scaled_query_key_softmax(
q: torch.Tensor,
k: torch.Tensor,
att_mask: Optional[Union[AttentionMask, "SparseCS", torch.Tensor]],
) -> torch.Tensor:
# TODO assume we have (N, S, hs) instead of (B, nh, S, hs), with N = B x nh
# this is needed due to limitations in sparse_bmm for now
# Self-attend: (N, S, hs) x (N, hs, S) -> (N, S, S)
q = q / math.sqrt(k.size(-1))
# Matmul with mask
if att_mask is not None and isinstance(att_mask, AttentionMask):
# Additive mask
mask: Optional[Union[SparseCS, torch.Tensor]] = att_mask.values
else:
mask = att_mask
att = _matmul_with_mask(q, k.transpose(-2, -1), mask)
# Softmax to get the attention probabilities
is_causal = isinstance(att_mask, AttentionMask) and att_mask.is_causal
att = _softmax(att, causal=is_causal)
return att
if _is_blocksparse_available:
# 128 is default maxsize
@lru_cache(maxsize=128)
def _retrieve_blocksparse(
num_heads: int, seq_len: int, block_size: int
) -> BlockSparseAttention:
# Checks if blocksparse object exists in cache
blocks = seq_len // block_size
layout_fill = torch.ones((num_heads, blocks, blocks), dtype=torch.long)
return BlockSparseAttention(
layout=layout_fill, block_size=block_size, causal=True
)
def blocksparse_attention(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
dropout: Optional[torch.nn.Module] = None,
block_size: int = 128,
) -> torch.Tensor:
orig_dim = q.dim()
seq_len = q.shape[-2]
# Layout head dimension: 1 or batch size (q.shape[0])
layout_heads = 1
# TODO perhaps add functionality to pad qkv if sequence length is not divisible by block size?
assert seq_len % block_size == 0, "Sequence length must be divisible by block size"
if orig_dim == 3:
# Reshape from (N, S, hs) to (B, nh, S, hs) where N = B x nh, hs = D / nh
# Assuming num_heads = 1, (N, S, hs) to (B, 1, S, hs)
if layout_heads == 1:
q = q.unsqueeze(1)
k = k.unsqueeze(1)
v = v.unsqueeze(1)
else:
q = q.unsqueeze(0)
k = k.unsqueeze(0)
v = v.unsqueeze(0)
blocksparse_attention = _retrieve_blocksparse(layout_heads, seq_len, block_size)
# Dropout is a no-op in evaluation mode
if isinstance(dropout, torch.nn.Dropout):
blocksparse_attention.attn_drop = dropout
else:
blocksparse_attention.attn_drop = torch.nn.Dropout(0.0)
att = blocksparse_attention(q, k, v)
# Reshape attention (B, nh, S, hs) back to (N, S, hs)
if orig_dim == 3:
return att.flatten(0, 1)
return att
def scaled_dot_product_attention(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
att_mask: Optional[Union[AttentionMask, "SparseCS", torch.Tensor]],
dropout: Optional[torch.nn.Module] = None,
block_size: int = 128,
) -> torch.Tensor:
autocast_disabled = (
_has_cpp_library
and isinstance(att_mask, SparseCS)
or (att_mask is not None and att_mask.is_sparse)
)
seq_len = q.shape[-2]
# switch if:
# causal is required but mask is not sparse
# fp16 or under amp context
# sequence length is divisible by block size
# same seq len for K and Q
switch_to_blocksparse = (
_is_blocksparse_available
and (att_mask is not None and not att_mask.is_sparse)
and (isinstance(att_mask, AttentionMask) and att_mask.is_causal)
and (q.dtype == torch.float16 or torch.is_autocast_enabled())
and not seq_len % block_size
and q.shape[-2] == k.shape[-2]
)
if switch_to_blocksparse:
logger.info("Switching causal attention to Triton blocksparse...")
return blocksparse_attention(q, k, v, dropout, block_size)
with torch.cuda.amp.autocast(enabled=False) if autocast_disabled else nullcontext(): # type: ignore
if autocast_disabled:
q, k, v = q.float(), k.float(), v.float()
att = scaled_query_key_softmax(q, k, att_mask=att_mask)
# Optional dropout, could be part of the masking in the future
att = _apply_dropout(att, dropout)
# Get to the predicted values, for all heads
# y = att @ v # (N, S, S) x (N, S, hs) -> (N, S, hs)
y = bmm(att, v)
return y

View File

@@ -0,0 +1,173 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
import logging
import math
from dataclasses import dataclass
from typing import Optional, Tuple
import torch
import torch.nn as nn
from torch.cuda.amp import autocast
from xformers.components.attention import Attention, AttentionConfig, register_attention
from xformers.components.attention.feature_maps import (
FeatureMap,
FeatureMapType,
SMHyperbolic,
SMOrf,
SMReg,
)
logger = logging.getLogger("xformers")
@dataclass
class FavorAttentionConfig(AttentionConfig):
causal: Optional[bool]
dim_features: Optional[int] = None # The dimensions of the random features
dim_head: Optional[
int
] = None # The embedding dimension of the inputs. Only useful to get a dim_features estimate
iter_before_redraw: Optional[
int
] = None # The number of iterations before the random features are re-drawn from scratch
feature_map: Optional[FeatureMapType] = None
@register_attention("favor", FavorAttentionConfig)
class FavorAttention(Attention):
def __init__(
self,
causal: bool = False,
dropout: float = 0.0,
dim_features: Optional[int] = None,
dim_head: Optional[int] = None,
iter_before_redraw: Optional[int] = None,
feature_map_type: FeatureMapType = FeatureMapType.SMReg,
normalize_inputs: bool = False,
*_,
**__,
):
r"""
Kernelized attention, as proposed in Performers_
("Rethinking attention with performers." K. Choromanski et al. (2020).).
FAVOR stands for "Fast Attention Via positive Orthogonal Random features"
Args:
dropout (float): the probability of an output to be randomly dropped at training time
dim_features (int): the dimension of the random features space
iter_before_redraw (int): the number of steps (forward calls) before a redraw of the features
feature_map_type (FeatureMapType): the type of feature map being used,
for instance orthogonal random features.
.. _Performers: https://arxiv.org/pdf/2009.14794v1.pdf
"""
super().__init__()
self.causal = causal
self.iter_before_redraw = (
(2 * iter_before_redraw)
if iter_before_redraw is not None
else iter_before_redraw
) # This will be used for both key and query
self.normalize_inputs = normalize_inputs
self.feature_map_type = feature_map_type
self.attn_drop = nn.Dropout(dropout, inplace=True)
# Setup dimension-dependent variables
# Reasonable dimension default
if dim_features is None:
assert dim_head is not None, "dim_features or dim_head needs to be passed"
self.dim_features = math.ceil(dim_head * (1 + math.log2(dim_head)))
self.dim_features = 2 * (
self.dim_features // 2
) # needs to be even for some variants
logger.info(
f"FAVOR: Automatically setting the random mapping dimension to {self.dim_features} from {dim_head}"
)
else:
self.dim_features = dim_features
feature_map_constructor = {
FeatureMapType.SMHyp: SMHyperbolic,
FeatureMapType.SMReg: SMReg,
FeatureMapType.SMOrf: SMOrf,
}[self.feature_map_type]
feature_settings = {
"dim_features": self.dim_features,
"iter_before_redraw": self.iter_before_redraw,
"normalize_inputs": self.normalize_inputs,
}
self.feature_map: FeatureMap = feature_map_constructor(**feature_settings) # type: ignore
# Properties specific to this attention mechanism
self.supports_attention_mask = False
self.supports_key_padding_mask = False
@staticmethod
def _maybe_promote(x: torch.Tensor) -> torch.Tensor:
# Only promote fp16 buffers, bfloat16 would be fine for instance
return x.float() if x.dtype == torch.float16 else x
@staticmethod
def _causal_attention(
k_prime: torch.Tensor, q_prime: torch.Tensor, v: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
# Algorithm 1 in the paper
ref_v = torch.ones_like(v.unsqueeze(2)) # BATCH x SEQ x 1 x EMB
Gps = k_prime.unsqueeze(3) * v.unsqueeze(2)
Grenorm = k_prime.unsqueeze(3) * ref_v
# Consolidate against the feature dimension
att_raw = torch.einsum("bcfe,bcf->bce", Gps, q_prime)
att_norm = torch.einsum("bcfe,bcf->bce", Grenorm, q_prime)
# Cumulative sum over the sequence
att_raw = att_raw.cumsum(2)
att_norm = att_norm.cumsum(2)
return att_raw, att_norm
def forward(
self,
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
*_,
**__,
):
# Project key and queries onto the feature map space
k_prime = self.feature_map(k)
q_prime = self.feature_map(q)
with autocast(enabled=False):
# The softmax kernel approximation for Favor will easily overflow
# Force the computations here to stay in fp32 for numerical stability
# Note that the dimensions are vastly reduced when compared to scaled_dot_product
k_prime = self._maybe_promote(k_prime)
q_prime = self._maybe_promote(q_prime)
v = self._maybe_promote(v)
if not self.causal:
att_normalization = q_prime @ (
k_prime.transpose(-2, -1) @ torch.ones_like(v)
)
att_raw = q_prime @ (k_prime.transpose(-2, -1) @ v)
else:
# Actually compute attention
att_raw, att_normalization = self._causal_attention(k_prime, q_prime, v)
# Normalize
att = att_raw / att_normalization
if self.attn_drop is not None:
att = self.attn_drop(att)
return att

View File

@@ -0,0 +1,26 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
from enum import Enum
from .base import FeatureMap, FeatureMapConfig
from .softmax import NormDistribution, SMHyperbolic, SMOrf, SMReg
class FeatureMapType(str, Enum):
SMOrf = "sm_orf"
SMHyp = "sm_hyp"
SMReg = "sm_reg" # regularized softmax kernel
__all__ = [
"SMOrf",
"SMReg",
"SMHyperbolic",
"NormDistribution",
"FeatureMapConfig",
"FeatureMap",
]

View File

@@ -0,0 +1,61 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
from abc import abstractmethod
from dataclasses import asdict, dataclass
from typing import Optional, Type, TypeVar
import torch
"""
Feature maps allow for a given query or key to be encoded in a different space.
"""
Self = TypeVar("Self", bound="FeatureMap")
@dataclass
class FeatureMapConfig:
name: str
dim_features: int
iter_before_redraw: Optional[int]
normalize_inputs: Optional[bool]
epsilon: Optional[float]
class FeatureMap(torch.nn.Module):
def __init__(
self,
dim_features: int,
iter_before_redraw: Optional[int] = None,
normalize_inputs: bool = False,
epsilon: float = 1e-6,
):
super().__init__()
self.dim_features = dim_features
self.dim_feature_map = dim_features
self.iter_before_redraw = iter_before_redraw
self.features: Optional[torch.Tensor] = None
self.epsilon = epsilon
self.normalize_inputs = normalize_inputs
self._iter_counter = 0
@abstractmethod
def _get_feature_map(self, dim_input: int, dim_features: int, device: torch.device):
raise NotImplementedError()
@classmethod
def from_config(cls: Type[Self], config: FeatureMapConfig) -> Self:
# Generate the class inputs from the config
fields = asdict(config)
# Skip all Nones so that default values are used
fields = {k: v for k, v in fields.items() if v is not None}
return cls(**fields)

View File

@@ -0,0 +1,288 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
import math
from enum import Enum, auto
from typing import Optional
import torch
from torch.autograd.profiler import record_function
from .base import FeatureMap
"""
A set of feature maps which approximate the softmax kernel, as per the Performers_ paper.
_Performers: "Rethinking attention with performers." K. Choromanski et al. (2020).
https://arxiv.org/pdf/2009.14794v1.pdf
"""
class NormDistribution(Enum):
Xi = auto()
Uniform = auto()
class SoftMaxPositiveEstimators(FeatureMap):
def __init__(
self,
dim_features: int,
iter_before_redraw: Optional[int],
normalize_inputs: bool = False,
epsilon: float = 1e-6,
softmax_temp: float = -1,
):
super().__init__(dim_features, iter_before_redraw, normalize_inputs, epsilon)
self.softmax_temp = softmax_temp
# Handle the scaling from all kernels by √m.
# This normalizes for all the feature maps involved
self.h_scale = math.log(math.sqrt(self.dim_features))
def pre_scale(self, x: torch.Tensor) -> torch.Tensor:
with record_function("feature_map::pre_scale"):
# Re-draw counting logic
if (
(
self.iter_before_redraw is not None
and self._iter_counter > self.iter_before_redraw
)
or self.features is None
or self.features.device != x.device
):
# The feature map is actually using half the dimension, we'll concatenate + and - features
self._iter_counter = 1
self.features = self._get_feature_map(
x.shape[-1], self.dim_feature_map, x.device
)
features = self.features
assert features is not None
if features.dtype != x.dtype:
self.features = features.to(x.dtype)
self._iter_counter += 1
# Normalization / softmax
if self.softmax_temp < 0:
# A = exp(QK.t/√d), so each input will be scaled by √√d
self.softmax_temp = x.shape[-1] ** -0.25
x_scaled = x * self.softmax_temp
# Compute the scaling factors in logspace, applied from within the exponential
# - dimnish possible exponential overflow
# - remove a multiply across the batch, replace by an addition
norm_x_2 = torch.einsum("...d,...d->...", x_scaled, x_scaled).unsqueeze(-1)
self.offset = -0.5 * norm_x_2 - self.h_scale + self.epsilon
if self.normalize_inputs:
# L0 normalize the exponential term, can be useful for numerical stability
# This ensures that features +- offset is below 1
self.offset -= norm_x_2.max(1, keepdim=True)[0]
# Return the scaled inputs, the rest depends on the kernel being used
return x_scaled
@staticmethod
@torch.no_grad()
def _get_random_ortho_matrix(
blocks: int,
dim: int,
device: torch.device,
norm_distribution: NormDistribution = NormDistribution.Uniform,
) -> torch.Tensor:
r"""
Generate a random matrix whose rows are exactly orthonormal
"How to generate random matrices from the classical compact groups", Mezzadri, 2007
https://arxiv.org/pdf/math-ph/0609050v2.pdf
.. note: the typical qr decomposition does not give uniform results, qr decomposition is not
unique and the qr decomposition routines are biased towards numerical stability. See the above
paper for more information.
.. note: this does not follow the original implementation from the Performers authors.
see docs/assets/kde plots to visualize the impact of using the R signs to correct Q
"""
H = torch.randn((blocks, dim, dim), device=device, requires_grad=False)
# Randomly scale the norms of the features, Xi distributed
if norm_distribution == NormDistribution.Xi:
# NOTE: This averages to sqrt(d)
norms = torch.sqrt(torch.einsum("...d,...d->...", H, H))
Q, R = torch.linalg.qr(H)
Q = torch.diag_embed(torch.sign(torch.diagonal(R, dim1=1, dim2=2))) @ Q
# Normalize if need be. Uniform NormDistribution does nothing, Q is already orthonormal
if norm_distribution == NormDistribution.Xi:
return torch.diag_embed(norms) @ Q
return Q
class SMOrf(SoftMaxPositiveEstimators):
"""
"Positive random orthogonal features" softmax estimator,
SM_ort^m+, as proposed in the Performers_ paper, Lemma 1.
_Performers: "Rethinking attention with performers." K. Choromanski et al. (2020).
https://arxiv.org/pdf/2009.14794v1.pdf
"""
@torch.no_grad()
def _get_feature_map(self, dim_input: int, dim_features: int, device: torch.device):
"""
Generate the projection matrix onto the random features
.. note: The heads dimension needs to be taken into account, hence the per-block random matrix
and not uniformally random.
"""
# Get per block random unitary matrices.
# We need enough of them to project the whole input dimension, regardless of the
# requested dimension of the features
features = self._get_random_ortho_matrix(
math.ceil(dim_input / dim_features),
dim_features,
norm_distribution=NormDistribution.Xi,
device=device,
)
return features.flatten(0, 1)[:dim_input]
def forward(self, x: torch.Tensor) -> torch.Tensor:
# Softmax-dimension related scaling, shared for all kernels
x_scaled = super().pre_scale(x)
assert self.features is not None
# Project onto the random feature map.
x_scaled = x_scaled @ self.features
return torch.exp(x_scaled + self.offset)
class SMHyperbolic(SoftMaxPositiveEstimators):
"""
"Positive random features hyperbolic" estimator, SMHyp+,
as proposed in the Performers_ paper, Lemma 1.
_Performers: "Rethinking attention with performers." K. Choromanski et al. (2020).
https://arxiv.org/pdf/2009.14794v1.pdf
"""
def __init__(
self,
dim_features: int,
iter_before_redraw: Optional[int],
normalize_inputs: bool = False,
epsilon: float = 1e-6,
softmax_temp: float = -1,
):
super().__init__(
dim_features, iter_before_redraw, normalize_inputs, epsilon, softmax_temp
)
assert (
dim_features % 2 == 0
), "The feature dimension needs to be even with this kernel"
self.dim_feature_map = self.dim_features // 2
@torch.no_grad()
def _get_feature_map(self, dim_input: int, dim_features: int, device: torch.device):
"""
Generate the projection matrix onto the random features
.. note: The heads dimension needs to be taken into account, hence the per-block random matrix
and not uniformally random.
"""
# Get per block random unitary matrices.
# We need enough of them to project the whole input dimension, regardless of the
# requested dimension of the features
features = self._get_random_ortho_matrix(
math.ceil(dim_input / dim_features),
dim_features,
norm_distribution=NormDistribution.Xi,
device=device,
)
return features.flatten(0, 1)[:dim_input]
def forward(self, x: torch.Tensor) -> torch.Tensor:
# Softmax-dimension related scaling, shared for all kernels
x_scaled = super().pre_scale(x)
# Project onto the random feature map, concatenate both + and - results
# This follows Lemma 1 in the original Performers Paper to best approximate a
# softmax kernel (cosh representation)
x_scaled = x_scaled @ self.features
return torch.cat(
[torch.exp(x_scaled + self.offset), torch.exp(-x_scaled + self.offset)],
dim=-1,
)
class SMReg(SoftMaxPositiveEstimators):
"""
"Regularized softmax kernel" estimator, SMREG+, as proposed in the Performers_ paper.
_Performers: "Rethinking attention with performers." K. Choromanski et al. (2020).
https://arxiv.org/pdf/2009.14794v1.pdf
"""
def __init__(
self,
dim_features: int,
iter_before_redraw: Optional[int],
normalize_inputs: bool = False,
epsilon: float = 1e-6,
softmax_temp: float = -1,
):
super().__init__(
dim_features, iter_before_redraw, normalize_inputs, epsilon, softmax_temp
)
assert (
dim_features % 2 == 0
), "The feature dimension needs to be even with this kernel"
self.dim_feature_map = self.dim_features // 2
@torch.no_grad()
def _get_feature_map(self, dim_input: int, dim_features: int, device: torch.device):
"""
Generate the projection matrix onto the random features
.. note: The heads dimension needs to be taken into account, hence the per-block random matrix
and not uniformally random.
"""
# Get per block random unitary matrices.
# We need enough of them to project the whole input dimension, regardless of the
# requested dimension of the features
features = self._get_random_ortho_matrix(
math.ceil(dim_input / dim_features),
dim_features,
norm_distribution=NormDistribution.Uniform,
device=device,
).flatten(0, 1)
norms = math.sqrt(dim_input) * torch.ones(features.shape[0], device=device)
return (torch.diag(norms) @ features)[:dim_input]
def forward(self, x: torch.Tensor) -> torch.Tensor:
# Softmax-dimension related scaling, shared for all kernels
x_scaled = super().pre_scale(x)
# Project onto the random feature map, concatenate both + and - results
# This follows Lemma 1 in the original Performers Paper to best approximate a
# softmax kernel (cosh representation + sample regularization)
x_scaled = x_scaled @ self.features
return torch.cat(
[torch.exp(x_scaled + self.offset), torch.exp(-x_scaled + self.offset)],
dim=-1,
)

View File

@@ -0,0 +1,35 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
import torch
from torch.cuda.amp import autocast
from xformers.components.attention import Attention, AttentionConfig, register_attention
@register_attention("fourier_mix", AttentionConfig)
class FourierMix(Attention):
def __init__(self, dropout: float, *_, **__):
"""
FFT-based pseudo-attention mechanism, from
"
"FNet: Mixing Tokens with Fourier Transforms"
Lee-Thorp et al., 2021, https://arxiv.org/pdf/2105.03824.pdf
"""
super().__init__()
self.attn_drop = torch.nn.Dropout(dropout, inplace=False)
# Properties specific to this attention mechanism
self.supports_attention_mask = False
self.requires_input_projection = False
def forward(self, q: torch.Tensor, *_, **__):
# Guard against autocast / fp16, not supported by torch.fft.fft2
with autocast(enabled=False):
att = torch.fft.fft2(q).real
att = self.attn_drop(att)
return att

View File

@@ -0,0 +1,122 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
from dataclasses import dataclass
from typing import Optional, Union
import torch
import torch.nn as nn
from xformers.components.attention import (
Attention,
AttentionConfig,
AttentionMask,
maybe_sparsify,
register_attention,
sparsify,
)
from xformers.components.attention.attention_patterns import (
causal_1d_pattern,
global_token_pattern,
)
from xformers.components.attention.core import scaled_dot_product_attention
@dataclass
class GlobalAttentionConfig(AttentionConfig):
attention_query_mask: torch.Tensor # Mark the queries which have global attention
causal: Optional[bool]
force_sparsity: Optional[bool]
@register_attention("global", GlobalAttentionConfig)
class GlobalAttention(Attention):
def __init__(
self,
dropout: float,
attention_query_mask: torch.Tensor,
causal: bool = False,
force_sparsity: bool = False,
*_,
**__,
):
r"""
Global attention, as proposed for instance in BigBird_ or Longformer_.
Global means in that case that the queries positively labelled in the ```attention_query_mask``` can attend
to all the other queries. The queries negatively labelled in the ```attention_query_mask``` cannot attend to
any other query.
This implementation is sparse-aware, meaning that the empty attention parts will not be represented in memory.
Args:
dropout (float): probability of an element to be zeroed
attention_query_mask (torch.Tensor): if true, this query can attend to all the others
"""
super().__init__()
assert attention_query_mask.dtype == torch.bool, "A boolean mask is expected"
assert (
attention_query_mask.shape[1] == 1
and attention_query_mask.shape[0] > attention_query_mask.shape[1]
), "A N x 1 query mask is expected"
self.attn_drop = nn.Dropout(dropout, inplace=False)
self.attention_mask = global_token_pattern(attention_query_mask[:, 0])
self.force_sparsity = force_sparsity
if causal:
self.attention_mask &= causal_1d_pattern(attention_query_mask.shape[1])
self.attention_mask = (
sparsify(self.attention_mask)
if self.force_sparsity
else maybe_sparsify(self.attention_mask)
)
# Properties specific to this attention mechanism
self.requires_same_k_q_dimensions = True
self.supports_attention_mask = False
self.supports_key_padding_mask = False
def forward(
self,
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
att_mask: Optional[Union[torch.Tensor, AttentionMask]] = None,
*_,
**__,
):
# Make sure that the mask is on the right device
if self.attention_mask.device != q.device:
self.attention_mask = self.attention_mask.to(q.device)
# Mask-aware attention
if att_mask is not None:
if att_mask.dtype == torch.bool and isinstance(
self.attention_mask, AttentionMask
):
if not isinstance(att_mask, AttentionMask):
att_mask = AttentionMask.from_bool(att_mask)
mask = self.attention_mask + att_mask
else:
mask = self.attention_mask & att_mask
else:
mask = self.attention_mask
# Handle q/k/v which would not fit the mask
seq_len = q.shape[-2]
q_, k_, v_ = map(lambda x: self._maybe_pad_sequence(x, mask), (q, k, v))
# Normal attention with the global tokens mask
att = scaled_dot_product_attention(
q=q_, k=k_, v=v_, att_mask=mask, dropout=self.attn_drop
)
# Take into account an hypothetical padding
return att[:, :seq_len, :]

View File

@@ -0,0 +1,78 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
from dataclasses import dataclass
import torch
from xformers.components.attention import Attention, AttentionConfig, register_attention
def calc_rel_pos(n: int):
# Adapted from LucidRains
# https://github.com/lucidrains/lambda-networks/blob/main/lambda_networks/lambda_networks.py
rel_pos = torch.arange(n)[None, :] - torch.arange(n)[:, None] # [n, n]
rel_pos += n - 1 # shift value range from [-n+1, n-1] to [0, 2n-2]
return rel_pos
@dataclass
class LambdaLayerConfig(AttentionConfig):
seq_len: int # dimension of the input sequence
dim_head: int
@register_attention("lambda", LambdaLayerConfig)
class LambdaLayer(Attention):
def __init__(self, dropout: float, seq_len: int, dim_head: int, *_, **__):
"""
Attention approximation using Lambda layers, from
"Lambda networks: modeling long-range interactions without attention.", Bello, I. (2021).
"""
super().__init__()
# Possible extensions:
# - support different dimensions for key and queries
# - support varying dimensions in between inputs and outputs
# - support u hyperparam
self.rel_pos_emb = torch.nn.Parameter(
torch.randn(2 * seq_len - 1, int(dim_head))
)
self.rel_pos = calc_rel_pos(seq_len)
self.attn_drop = torch.nn.Dropout(dropout, inplace=True)
# Properties specific to this attention mechanism
self.requires_same_k_q_dimensions = True
self.supports_attention_mask = False
self.supports_key_padding_mask = False
def forward(
self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, *args, **kwargs
):
"""..NOTE: We're reusing the einsum notation suggested by the paper, changed in that
heads are folded in the batch dimension"""
content_lambda = torch.einsum("bnk,bnv->bkv", torch.softmax(k, dim=-1), v)
content_output = torch.einsum("bnk,bkv->bnv", q, content_lambda)
rel_pos_emb = self.rel_pos_emb[self.rel_pos]
# Handle real sequence length being possibly smaller
seq_len = q.shape[1]
rel_pos_emb = rel_pos_emb[:seq_len, :seq_len, :]
# Compute the position lambda for every possible combination in one go, then compute the
# position related contribution
position_lambdas = torch.einsum(
"mnk,bnv->bnkv", rel_pos_emb, v
) # one lambda per position
position_output = (q.unsqueeze(2) @ position_lambdas).squeeze()
att = content_output + position_output
att = self.attn_drop(att)
return att

View File

@@ -0,0 +1,74 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
from dataclasses import dataclass
from typing import Optional
import torch
import torch.nn as nn
from xformers.components.attention import Attention, AttentionConfig, register_attention
from xformers.components.attention.core import scaled_dot_product_attention
@dataclass
class LinformerSelfAttentionConfig(AttentionConfig):
seq_len: int # dimension of the input sequence
k: Optional[int] # dimension of the internal space
@register_attention("linformer", LinformerSelfAttentionConfig)
class LinformerAttention(Attention):
def __init__(
self, dropout: float, seq_len: int, k: Optional[int] = None, *args, **kwargs
):
"""
Linformer attention mechanism,
from `Linformer: Self-Attention with Linear Complexity`_, Wang et al (2020).
The original notation is kept as is.
.. _`Linformer: Self-Attention with Linear Complexity` : https://arxiv.org/abs/2006.04768v2
"""
super().__init__()
if k is None:
k = seq_len // 4
self.k = k
self.E = nn.Linear(seq_len, k, bias=False)
self.F = nn.Linear(seq_len, k, bias=False)
self.attn_drop = nn.Dropout(dropout, inplace=False)
self.seq_len = seq_len
# MHA related flags:
# kq need to have the same dimension
self.requires_same_k_q_dimensions = True
# This attention does not support attention masks
self.supports_attention_mask = False
def forward(
self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, *args, **kwargs
):
# Handle a smaller dimension than expected
padding = 0
if q.shape[1] < self.seq_len:
padding = self.seq_len - q.shape[1]
pad_dims = (0, 0, 0, padding)
q = torch.nn.functional.pad(q, pad_dims)
k = torch.nn.functional.pad(k, pad_dims)
v = torch.nn.functional.pad(v, pad_dims)
k_projected = self.E(k.transpose(-2, -1)).transpose(-2, -1)
v_projected = self.F(v.transpose(-2, -1)).transpose(-2, -1)
y = scaled_dot_product_attention(
q=q, k=k_projected, v=v_projected, att_mask=None, dropout=self.attn_drop
)
y = self.attn_drop(y)
return y[:, :-padding, :] if padding > 0 else y

View File

@@ -0,0 +1,120 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
from dataclasses import dataclass
from typing import Optional, Union
import torch
import torch.nn as nn
from xformers.components.attention import (
Attention,
AttentionConfig,
AttentionMask,
maybe_sparsify,
register_attention,
sparsify,
)
from xformers.components.attention.attention_patterns import (
causal_1d_pattern,
local_1d_pattern,
)
from xformers.components.attention.core import scaled_dot_product_attention
@dataclass
class LocalAttentionConfig(AttentionConfig):
causal: Optional[bool] = None
window_size: Optional[int] = None
force_sparsity: Optional[bool] = None
@register_attention("local", LocalAttentionConfig)
class LocalAttention(Attention):
def __init__(
self,
dropout: float = 0.0,
causal: bool = False,
window_size: int = 5,
force_sparsity: bool = False,
*args,
**kwargs,
):
r"""
An implementation of a sliding window attention, as proposed in RoutingTransformer_, LongFormer_ or BigBird_
Args:
dropout (float): the probability of an output to be randomly dropped at training time
causal (bool): apply a causal mask, in that the attention cannot be applied to the future
window_size (int): the overall window size for local attention.
Odd number is expected if the mask is not causal, as the window size will be evenly
distributed on both sides of each query
.. _RoutingTransformer: https://arxiv.org/pdf/2003.05997.pdf
.. _BigBird: https://arxiv.org/pdf/2007.14062.pdf
.. _Longformer: https://arxiv.org/pdf/2004.05150.pdf
"""
super().__init__()
self.attn_drop = nn.Dropout(dropout, inplace=False)
self.causal = causal
self.force_sparsity = force_sparsity
if not self.causal:
assert (
window_size % 2 == 1
), "The window size is assumed to be odd (counts self-attention + 2 wings)"
self.window_size = window_size
self.attention_mask: Optional[torch.Tensor] = None
self.requires_same_k_q_dimensions = True
# Properties specific to this attention mechanism
self.supports_attention_mask = True
self.supports_key_padding_mask = False
def _get_local_mask(self, shape: torch.Size) -> torch.Tensor:
window_size = self.window_size * 2 + 1 if self.causal else self.window_size
mask = local_1d_pattern(shape[1], window_size)
if self.causal:
mask &= causal_1d_pattern(shape[1])
mask = sparsify(mask) if self.force_sparsity else maybe_sparsify(mask)
return mask
def forward(
self,
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
att_mask: Optional[Union[torch.Tensor, AttentionMask]] = None,
*args,
**kwargs,
):
# Local window attention masking
if self.attention_mask is None or self.attention_mask.shape[1] != q.shape[1]:
self.attention_mask = self._get_local_mask(q.shape).to(q.device)
# Take into account the optional user mask
if att_mask is None:
mask = self.attention_mask
else:
if isinstance(att_mask, AttentionMask):
# Needed because & op not defined for SparseCS with AttentionMask
att_mask = att_mask.to_bool()
mask = self.attention_mask & att_mask
return scaled_dot_product_attention(
q=q, k=k, v=v, att_mask=mask, dropout=self.attn_drop
)

View File

@@ -0,0 +1,295 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
import logging
from dataclasses import dataclass
from typing import Optional
import torch
import torch.nn as nn
from xformers.components.attention import Attention, AttentionConfig, register_attention
from xformers.components.attention.core import (
scaled_dot_product_attention,
scaled_query_key_softmax,
)
from xformers.components.attention.utils import (
bool_mask_to_additive,
iterative_pinv,
reshape_key_padding_mask,
)
logger = logging.getLogger("xformers")
@dataclass
class NystromSelfAttentionConfig(AttentionConfig):
"""
num_heads Number of heads.
num_landmarks Number of landmarks to use for softmax approximation. 64 often sufficient for a good
approximation according to https://arxiv.org/pdf/2102.03902.pdf.
causal Apply a causal mask, in that the attention cannot be applied to the future.
use_razavi_pinverse If true, use iterative method from (Razavi et al. 2014) to approximate the Moore-Penrose
inverse, otherwise use standard torch inverse.
pinverse_original_init True if using original initialization when calculating Moore-Penrose pseudo inverse using
method from (Razavi et al. 2014).
False if using exact coefficient computation (leads to faster convergence).
inv_iterations Number of iterations for calculating the Moore-Penrose pseudo inverse.
v_skip_connection A module that will take V as input and will be added as a skip connection to the
softmax approximation. A skip connection is added in the paper to help with training.
conv_kernel_size Kernel size for convolution optionally added to help in training.
If v_skip_connection is not specified, this will be used to define the default
depth wise convolution used as a skip connection.
If both conv_kernel_size and v_skip_connection are None, no skip connection will
be added.
landmark_pooling Which module to use when computing landmarks. Default is AdaptiveAvgPool2d.
"""
num_heads: int
num_landmarks: Optional[int]
landmark_pooling: Optional[nn.Module]
causal: Optional[bool]
pinverse_original_init: Optional[bool]
inv_iterations: Optional[int]
v_skip_connection: Optional[nn.Module]
conv_kernel_size: Optional[int]
use_razavi_pinverse: Optional[bool]
class AvgPool(nn.Module):
def __init__(self, n: int):
super().__init__()
self.n = n
def forward(self, x: torch.Tensor):
# Average independently for every segment in the sequence dimension
seq_len = x.shape[1]
head_dim = x.shape[2]
segments = seq_len // self.n
assert segments > 0, "num_landmarks should be smaller than the sequence length"
# Dimensions are a match
if seq_len % self.n == 0:
return x.reshape(
-1,
self.n,
segments,
head_dim,
).mean(dim=-2)
# Handle the last segment boundary being off
n_round = self.n - seq_len % self.n
x_avg_round = (
x[:, : n_round * segments, :]
.reshape(-1, n_round, segments, head_dim)
.mean(dim=-2)
)
x_avg_off = (
x[:, n_round * segments :, :]
.reshape(-1, self.n - n_round, segments + 1, head_dim)
.mean(dim=-2)
)
return torch.cat((x_avg_round, x_avg_off), dim=-2)
@register_attention("nystrom", NystromSelfAttentionConfig)
class NystromAttention(Attention):
# TODO: update defaults for use_razavi_pinverse and inv_iterations
def __init__(
self,
dropout: float,
num_heads: int,
num_landmarks: int = 64,
landmark_pooling: Optional[nn.Module] = None,
causal: bool = False,
use_razavi_pinverse: bool = True,
pinverse_original_init: bool = False,
inv_iterations: int = 6, # recommended default in paper was 6.
v_skip_connection: Optional[nn.Module] = None,
conv_kernel_size: Optional[int] = None,
*args,
**kwargs,
):
"""
Nystrom attention mechanism, from Nystromformer_.
::
"A Nystrom-based Algorithm for Approximating Self-Attention."
Xiong, Y., Zeng, Z., Chakraborty, R., Tan, M., Fung, G., Li, Y., Singh, V. (2021)
Reference codebase: https://github.com/mlpen/Nystromformer
.. _Nystromformer: https://arxiv.org/pdf/2102.03902.pdf
"""
super().__init__()
# merged key padding mask and attention mask is not accepted
self.requires_separate_masks = True
self.num_landmarks = num_landmarks
# TODO: should be able to not have to pass in num_heads
self.num_heads = num_heads
self.use_razavi_pinverse = use_razavi_pinverse
self.pinverse_original_init = pinverse_original_init
self.inv_iterations = inv_iterations
self.attn_drop = nn.Dropout(dropout)
self.skip_connection = v_skip_connection
self.causal = causal
if self.skip_connection is None and conv_kernel_size is not None:
self.skip_connection = nn.Conv2d(
in_channels=self.num_heads,
out_channels=self.num_heads,
kernel_size=(conv_kernel_size, 1),
padding=(conv_kernel_size // 2, 0),
bias=False,
groups=self.num_heads,
)
if landmark_pooling is not None:
self.landmark_pooling = landmark_pooling
else:
self.landmark_pooling = AvgPool(n=self.num_landmarks)
# Optional lower triangular masks for causal attention
self.causal_mask_1: Optional[torch.Tensor] = None
self.causal_mask_2: Optional[torch.Tensor] = None
self.causal_mask_3: Optional[torch.Tensor] = None
# This attention does not support attention masks
self.supports_attention_mask = False
self.supports_key_padding_mask = True
def forward(
self,
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
key_padding_mask: Optional[torch.Tensor] = None,
*args,
**kwargs,
):
r"""
key_padding_mask Only a key padding mask is accepted here. The size must be (batch size, sequence length) or
(batch size * num_heads, 1, sequence length). If dimensions are not correct, the mask will
be ignored. An additive mask is expected, meaning float values using "-inf" to mask values
"""
batched_dim = k.size(0)
seq_len = k.size(-2)
tt = {"dtype": q.dtype, "device": q.device}
if key_padding_mask is not None:
if key_padding_mask.dtype == torch.bool:
logger.warning(
"Bool mask found, but an additive mask is expected. Converting but this is slow"
)
key_padding_mask = bool_mask_to_additive(key_padding_mask)
if key_padding_mask.ndim == 2:
key_padding_mask = reshape_key_padding_mask(
key_padding_mask, batched_dim
)
zeros = torch.zeros_like(key_padding_mask)
ones = torch.ones_like(key_padding_mask)
is_masked = torch.isinf(-key_padding_mask)
# _mask takes 1 if the token is not padded, otherwise 0.
_mask = torch.where(is_masked, zeros, ones)
_mask = _mask.transpose(2, 1)
assert _mask.shape == (batched_dim, q.shape[1], 1)
# Mask q and k before pooling
# https://github.com/mlpen/Nystromformer/blob/main/code/attention_nystrom.py#L31
q = q * _mask
k = k * _mask
assert key_padding_mask.size() == (batched_dim, 1, seq_len), (
f"key_padding_mask has invalid dimensions {key_padding_mask.size()}."
f" Must have dimensions {batched_dim, 1, seq_len} or (batch_size, {seq_len})."
)
if self.num_landmarks >= seq_len:
mask: Optional[torch.Tensor] = None
if self.causal:
mask = self._triu_mask(batched_dim, seq_len, seq_len, **tt)
if key_padding_mask is not None:
mask = key_padding_mask if mask is None else mask + key_padding_mask
x = scaled_dot_product_attention(q=q, k=k, v=v, att_mask=mask)
else:
q_landmarks = self.landmark_pooling(q)
k_landmarks = self.landmark_pooling(k)
if self.causal and (
self.causal_mask_1 is None
or (batched_dim, seq_len, self.num_landmarks)
!= self.causal_mask_1.size()
):
self.causal_mask_1 = self._triu_mask(
batched_dim, seq_len, self.num_landmarks, **tt
)
self.causal_mask_2 = self._triu_mask(
batched_dim, self.num_landmarks, self.num_landmarks, **tt
)
self.causal_mask_3 = self._triu_mask(
batched_dim, self.num_landmarks, seq_len, **tt
)
mask_3: Optional[torch.Tensor] = self.causal_mask_3
if key_padding_mask is not None:
mask_3 = (
key_padding_mask if mask_3 is None else mask_3 + key_padding_mask
)
kernel_1 = scaled_query_key_softmax(q=q, k=k_landmarks, att_mask=None)
kernel_2 = scaled_query_key_softmax(
q=q_landmarks, k=k_landmarks, att_mask=None
)
kernel_3 = scaled_dot_product_attention(
q=q_landmarks, k=k, v=v, att_mask=mask_3
)
kernel_2_inv = (
iterative_pinv(
kernel_2, self.inv_iterations, self.pinverse_original_init
)
if self.use_razavi_pinverse
else torch.linalg.pinv(kernel_2)
)
x = torch.matmul(
torch.matmul(
kernel_1,
kernel_2_inv,
),
kernel_3,
)
if self.skip_connection:
# Assumption here is that v is 3D.
v_conv = self.skip_connection(
v.reshape(-1, self.num_heads, v.size(-2), v.size(-1))
)
x += v_conv.reshape(-1, v_conv.size(-2), v_conv.size(-1))
x = self.attn_drop(x)
return x
def _triu_mask(self, dim_1: int, dim_2: int, dim_3: int, **kwargs) -> torch.Tensor:
device = kwargs["device"]
dtype = kwargs["dtype"]
return torch.triu(
torch.ones(dim_2, dim_3, dtype=dtype, device=device) * float("-inf"),
diagonal=1,
).expand(
dim_1, -1, -1
) # micro optim, save memory on the batch dimension

View File

@@ -0,0 +1,324 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
import logging
from dataclasses import dataclass
from enum import Enum
from typing import Optional, Union
import torch
import torch.autograd.profiler as profiler
import torch.nn as nn
import torch.nn.functional as Fn
from xformers.components.attention import (
Attention,
AttentionConfig,
AttentionMask,
register_attention,
)
from xformers.components.attention.core import (
scaled_dot_product_attention,
scaled_query_key_softmax,
)
logger = logging.getLogger("xformers")
class LandmarkSelection(str, Enum):
Orthogonal = "orthogonal"
KMeans = "kmeans"
KMeans_Spherical = "kmeans_spherical"
Random = "random"
@dataclass
class OrthoformerAttentionConfig(AttentionConfig):
"""
num_landmarks Number of landmarks to use for softmax approximation.
subsample_fraction Percentage of q_samples matrix to sample per iteration
landmark_selection Landmark selection strategy
"""
num_landmarks: Optional[int]
subsample_fraction: Optional[float]
landmark_selection: Optional[LandmarkSelection]
@register_attention("orthoformer", OrthoformerAttentionConfig)
class OrthoFormerAttention(Attention):
def __init__(
self,
dropout: float,
num_landmarks: int = 32,
subsample_fraction: float = 1.0,
landmark_selection: LandmarkSelection = LandmarkSelection.Orthogonal,
*args,
**kwargs,
):
"""
Orthoformer_ attention mechanism.
::
"Keeping Your Eye on the Ball: Trajectory Attention in Video Transformers"
Patrick, M., Campbell, D., Asano, Y., Misra, I., Metze, F., Feichtenhofer,
C., Vedaldi, A., Henriques, J. (2021)
Reference codebase: https://github.com/facebookresearch/Motionformer
.. _Orthoformer: https://arxiv.org/abs/2106.05392
"""
super().__init__()
self.num_landmarks = num_landmarks
self.attn_drop = nn.Dropout(dropout)
self.subsample_fraction = subsample_fraction
self.landmark_selection = landmark_selection
# Properties specific to this attention mechanism
self.supports_attention_mask = True
self.supports_key_padding_mask = False
def forward(
self,
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
att_mask: Optional[Union[AttentionMask, torch.Tensor]] = None,
*args,
**kwargs,
):
N = k.shape[1]
if self.num_landmarks == N:
# Default attention
x = scaled_dot_product_attention(q, k, v, att_mask)
else:
with torch.no_grad(), profiler.record_function("select landmarks"):
if self.landmark_selection == LandmarkSelection.Orthogonal:
landmarks = self._compute_orthogonal_landmarks(q)
elif self.landmark_selection == LandmarkSelection.Random:
half_L = self.num_landmarks // 2
landmarks_q = q[:, torch.randint(q.size(1), (half_L,)), :]
landmarks_k = k[:, torch.randint(k.size(1), (half_L,)), :]
landmarks = torch.cat((landmarks_q, landmarks_k), dim=-2)
elif self.landmark_selection == LandmarkSelection.KMeans:
landmarks = self._cluster_landmarks(q)
elif self.landmark_selection == LandmarkSelection.KMeans_Spherical:
landmarks = self._cluster_landmarks(q, spherical=True)
if att_mask is not None:
logger.warning(
"Orthoformer: attention mask passed alongside with using landmarks to reduce dimensions. \
The two are typically not compatible"
)
# FIXME: Should we still accept a mask in that case ?
att_mask = None
# pyre-ignore[61]: TODO(T103337542): `landmarks` mistakenly seems
# like it could be uninitialized.
kernel_1 = scaled_query_key_softmax(q, landmarks, att_mask)
# pyre-ignore[61]: TODO(T103337542): `landmarks` mistakenly seems
# like it could be uninitialized.
kernel_2 = scaled_query_key_softmax(landmarks, k, att_mask)
x = torch.matmul(kernel_1, torch.matmul(kernel_2, v))
x = self.attn_drop(x)
return x
def _cluster_landmarks(
self,
q: torch.Tensor,
spherical: bool = False,
num_iters: int = 6,
) -> torch.Tensor:
"""
Construct set of landmarks by recursively selecting new landmarks
that are maximally orthogonal to the existing set.
Returns near orthogonal landmarks with shape (B, M, D).
"""
num_landmarks = min(self.num_landmarks, q.shape[1])
if self.subsample_fraction < 1.0:
num_samples = max(
int(self.subsample_fraction * q.size(-2)), num_landmarks
) # Need at least M/2 samples of queries and keys
q_samples = q[:, torch.randint(q.size(-2), (num_samples,)), :] # (B, N, D)
else:
q_samples = q # (B, N, D)
if spherical:
q_samples_normalized = Fn.normalize(
q_samples, p=2, dim=-1
) # may need to change default eps to eps=1e-8 for mixed precision compatibility
landmarks = self._kmeans_spherical(
q_samples_normalized, num_landmarks, num_iters
)
else:
landmarks = self._kmeans(q_samples, num_landmarks, num_iters)
return landmarks # (B, M, D)
def _kmeans(self, x: torch.Tensor, K: int, num_iters: int = 10):
"""
Arguments:
x: (B, N, D)
K: number of clusters
num_iters: the number of kmeans updates
"""
B, N, D = x.size()
assert K <= N, f"{K} > {N}"
c = x[
:, torch.randperm(N, device=x.device)[:K], :
].clone() # initialisation for the centroids
with profiler.record_function("kmeans"):
x_i = x.view(B, N, 1, D)
c_j = c.view(B, 1, K, D)
counts = c.new_zeros(B, K)
ones = x.new_ones((B, N))
for _ in range(num_iters):
# E step: assign points to the nearest cluster
D_ij = ((x_i - c_j) ** 2).sum(-1) # (B, N, K) squared distances
cl = D_ij.argmin(
dim=-1, keepdim=True
).long() # (B, N, 1) index of point to nearest cluster
# M step: update the centroids
c.zero_()
c.scatter_add_(-2, cl.repeat(1, 1, D), x) # sum of points per cluster
counts.fill_(1e-6) # avoid div0
counts.scatter_add_(
-1, cl.squeeze(-1), ones
) # number of points per cluster
c.divide_(counts.unsqueeze(-1)) # compute the average
return c
def _kmeans_spherical(self, x: torch.Tensor, K: int, num_iters=10):
"""
Arguments:
x: (B, N, D)
"""
B, N, D = x.size()
assert K <= N, f"{K} > {N}"
# initialisation for the centroids
c = x[:, torch.randperm(N, device=x.device)[:K], :].clone()
with profiler.record_function("kmeans_spherical"):
counts = c.new_zeros(B, K)
ones = x.new_ones((B, N))
for _ in range(num_iters):
# E step: assign points to the nearest cluster
D_ij = torch.matmul(
x, c.transpose(-2, -1)
) # (B, N, K) cosine similarity
cl = D_ij.argmax(
dim=-1, keepdim=True
).long() # (B, N, 1) index of point to nearest cluster
# M step: update the centroids
c.zero_()
c.scatter_add_(-2, cl.repeat(1, 1, D), x) # sum of points per cluster
counts.fill_(1e-6) # avoid div0
counts.scatter_add_(
-1, cl.squeeze(-1), ones
) # number of points per cluster
c.divide_(counts.unsqueeze(-1)) # compute the average
c = Fn.normalize(c, p=2, dim=-1) # renormalise
return c
def _compute_orthogonal_landmarks(self, q: torch.Tensor) -> torch.Tensor:
"""
Construct set of landmarks by recursively selecting new landmarks
that are maximally orthogonal to the existing set.
Returns near orthogonal landmarks with shape (B, M, D).
"""
if self.subsample_fraction < 1.0:
# Need at least M samples of queries
num_samples = max(
int(self.subsample_fraction * q.size(-2)), self.num_landmarks
)
q_samples = q[
:, torch.randint(q.size(-2), (num_samples,), device=q.device), :
]
else:
# (B, N, D)
q_samples = q
# may need to change default eps to eps=1e-8 for mixed precision compatibility
q_samples_normalized = Fn.normalize(q_samples, p=2, dim=-1)
B, N, D = q_samples_normalized.shape
selected_mask = torch.zeros((B, N, 1), device=q_samples_normalized.device)
landmark_mask = torch.ones(
(B, 1, 1), dtype=selected_mask.dtype, device=q_samples_normalized.device
)
#  Get initial random landmark
random_idx = torch.randint(
q_samples_normalized.size(-2), (B, 1, 1), device=q_samples_normalized.device
)
selected_mask.scatter_(-2, random_idx, landmark_mask)
#  Selected landmarks
selected_landmarks = torch.empty(
(B, self.num_landmarks, D),
device=q_samples_normalized.device,
dtype=q_samples_normalized.dtype,
)
selected_landmarks[:, 0, :] = q_samples_normalized[
torch.arange(q_samples_normalized.size(0)), random_idx.view(-1), :
].view(B, D)
# Store computed cosine similarities
cos_sims = torch.empty(
(B, N, self.num_landmarks),
device=q_samples_normalized.device,
dtype=q_samples_normalized.dtype,
)
for M in range(1, self.num_landmarks):
with profiler.record_function("find new landmark"):
#  Calculate absolute cosine similarity between selected and unselected landmarks
# (B, N, D) * (B, D) -> (B, N)
cos_sims[:, :, M - 1] = torch.einsum(
"b n d, b d -> b n",
q_samples_normalized,
selected_landmarks[:, M - 1, :],
).abs()
# (B, N, M) cosine similarities of current set of landmarks wrt all queries and keys
cos_sim_set = cos_sims[:, :, :M]
#  Get orthogonal landmark: landmark with smallest absolute cosine similarity:
# set cosine similarity for already selected landmarks to > 1
cos_sim_set.view(-1, M)[selected_mask.flatten().bool(), :] = 10
# (B,) - want max for non
selected_landmark_idx = cos_sim_set.amax(-1).argmin(-1)
#  Add most orthogonal landmark to selected landmarks:
selected_landmarks[:, M, :] = q_samples_normalized[
torch.arange(q_samples_normalized.size(0)), selected_landmark_idx, :
].view(B, D)
#  Removed selected indices from non-selected mask:
selected_mask.scatter_(
-2, selected_landmark_idx.unsqueeze(-1).unsqueeze(-1), landmark_mask
)
# (B, M, D)
landmarks = torch.masked_select(q_samples, selected_mask.bool()).reshape(
B, -1, D
)
return landmarks # (B, M, D)

View File

@@ -0,0 +1,82 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
import math
from dataclasses import dataclass
from typing import Optional
import torch
import torch.nn as nn
from xformers.components.attention import Attention, AttentionConfig, register_attention
@dataclass
class PoolingAttentionConfig(AttentionConfig):
pool_size: int # dimension of the input sequence
stride: Optional[int] # dimension of the internal space
padding: Optional[int]
@register_attention("pooling", PoolingAttentionConfig)
class Pooling(Attention):
def __init__(
self,
pool_size: int = 3,
stride: int = 1,
padding: Optional[int] = None,
*_,
**__,
):
"""
Pooling token mixing mechanism, as proposed in
`Metaformer is actually what you need for vision`_, Yu et al (2021).
The original notation is kept as is.
.. _`Metaformer is actually what you need for vision` : https://arxiv.org/pdf/2111.11418v1.pdf
"""
super().__init__()
padding = padding if padding is not None else pool_size // 2
self.pool = nn.AvgPool2d(
pool_size,
stride=stride,
padding=pool_size // 2,
count_include_pad=False,
)
# MHA related flags:
# kq need to have the same dimension
self.requires_same_k_q_dimensions = False
# This attention does not support attention masks
self.supports_attention_mask = False
# This "attention" (token mixing) skips the multihead attention altogether
self.requires_skip_multi_head = True
self.requires_input_projection = False
# This operator does not really handle q,k,v
self.requires_same_k_q_dimensions = True
# This attention requires the 2d structure out of the context,
# implictly assumed to be a squared length
self.requires_squared_context = True
def forward(self, q: torch.Tensor, *_, **__):
# Expose the 2D token structure
B, HW, C = q.shape
H = int(math.sqrt(HW))
assert H * H == HW
q = q.transpose(-2, -1).reshape(B, C, H, H)
# 2D pool
x_pool = self.pool(q) - q # compensate for the residual path
# Get back to B HW C
return x_pool.flatten(2, 3).transpose(-2, -1)

View File

@@ -0,0 +1,126 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
from dataclasses import dataclass
from typing import Optional, Union
import torch
import torch.nn as nn
from xformers.components.attention import (
Attention,
AttentionConfig,
AttentionMask,
maybe_sparsify,
register_attention,
sparsify,
)
from xformers.components.attention.attention_patterns import (
causal_1d_pattern,
random_pattern,
)
from xformers.components.attention.core import scaled_dot_product_attention
@dataclass
class RandomAttentionConfig(AttentionConfig):
r: Optional[
float
] # the ratio of keys that the query can attend to. 1.0 means dense attention
constant_masking: Optional[
bool
] # whether the randomness is per query or defined at construction time
force_sparsity: Optional[bool] # use sparsity in any case (potentially slower)
@register_attention("random", RandomAttentionConfig)
class RandomAttention(Attention):
def __init__(
self,
dropout: float,
causal: bool = False,
r: float = 0.01,
constant_masking: bool = True,
force_sparsity: bool = False,
*args,
**kwargs,
):
"""
"Random" attention, as proposed for instance in BigBird_.
Random means in that case that each query can attend to a random set of keys.
This implementation is sparse-aware, meaning that the empty attention parts will not be represented in memory.
Args:
r (float): the ratio in [0,1] of keys that the query can attend to
constant_masking (bool): if true, keep the same random set for all queries.
.. _BigBird: https://arxiv.org/pdf/2007.14062.pdf
"""
super().__init__()
self.attn_drop = nn.Dropout(dropout, inplace=False)
self.causal = causal
self.r = r
self.rand_attention_mask: Optional[torch.Tensor] = None
self.constant_masking = constant_masking
self.force_sparsity = force_sparsity
# Properties specific to this attention mechanism
self.supports_attention_mask = True
self.supports_key_padding_mask = False
self.requires_same_k_q_dimensions = True
def _get_rand_mask(self, shape: torch.Size) -> torch.Tensor:
sparsity = 1 - self.r
mask = random_pattern(shape[1], sparsity=sparsity)
if self.causal:
mask &= causal_1d_pattern(shape[1])
mask = sparsify(mask) if self.force_sparsity else maybe_sparsify(mask)
return mask
def forward(
self,
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
att_mask: Optional[Union[torch.Tensor, AttentionMask]] = None,
*args,
**kwargs,
):
# Rand masking
if not self.constant_masking or self.rand_attention_mask is None:
self.rand_attention_mask = self._get_rand_mask(q.shape).to(q.device)
# Mask-aware attention
if att_mask is not None:
if att_mask.dtype == torch.bool and isinstance(
self.rand_attention_mask, AttentionMask
):
mask = self.rand_attention_mask + AttentionMask.from_bool(att_mask)
else:
if isinstance(att_mask, AttentionMask):
# Needed because & op not defined for SparseCS with AttentionMask
att_mask = att_mask.to_bool()
mask = self.rand_attention_mask & att_mask
else:
mask = self.rand_attention_mask
# Handle q/k/v which would not fit the mask
seq_len = q.shape[-2]
q_, k_, v_ = map(lambda x: self._maybe_pad_sequence(x, mask), (q, k, v))
# Normal attention with the random mask
att = scaled_dot_product_attention(
q=q_, k=k_, v=v_, att_mask=mask, dropout=self.attn_drop
)
# Take into account an hypothetical padding
return att[:, :seq_len, :]

View File

@@ -0,0 +1,134 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
import logging
from dataclasses import dataclass
from typing import Optional, Union
import torch
from torch import nn
from xformers.components.attention import (
Attention,
AttentionConfig,
AttentionMask,
register_attention,
)
from xformers.components.attention.core import scaled_dot_product_attention
logger = logging.getLogger("xformers")
@dataclass
class ScaledDotProductConfig(AttentionConfig):
causal: Optional[bool]
seq_len: Optional[int]
to_seq_len: Optional[int]
@register_attention("scaled_dot_product", ScaledDotProductConfig)
class ScaledDotProduct(Attention):
r"""
Implementing the Scaled Dot-Product attention proposed in
`Attention is all you need`_, Vaswani et al.
.. _`Attention is all you need`: https://arxiv.org/abs/1706.03762v5
"""
mask: Optional[AttentionMask]
def __init__(
self,
dropout: float = 0.0,
causal: bool = False,
seq_len: Optional[int] = None,
to_seq_len: Optional[int] = None,
*args,
**kwargs,
):
super().__init__()
self.attn_drop = nn.Dropout(dropout, inplace=False)
self.causal = causal
self.seq_len = seq_len
if causal and seq_len is not None:
self.mask = AttentionMask.make_causal(seq_len, to_seq_len)
else:
self.mask = None
# Properties specific to this attention mechanism
self.supports_attention_mask = True
self.supports_key_padding_mask = False
def forward(
self,
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
att_mask: Optional[Union[AttentionMask, torch.Tensor]] = None,
*args,
**kwargs,
) -> torch.Tensor:
r"""
att_mask A 2D or 3D mask which ignores attention at certain positions.
- If the mask is boolean, a value of True will keep the value,
while a value of False will mask the value.
Key padding masks (dimension: batch x sequence length) and attention masks
(dimension: sequence length x sequence length OR batch x sequence length x sequence length)
can be combined and passed in here. Method maybe_merge_masks provided in the utils can be
used for that merging.
- If the mask has the float type, then an additive mask is expected (masked values are -inf)
"""
# Convenience, create an attention mask if a tensor was passed
if att_mask is not None and isinstance(att_mask, torch.Tensor):
# By default we don't know of the causality, and a check would be expensive
att_mask = (
AttentionMask.from_bool(att_mask)
if att_mask.dtype == torch.bool
else AttentionMask(att_mask, is_causal=False)
)
# Handle a possibly deferred causal mask handling
mask = self.mask
if self.causal and self.mask is None:
mask = AttentionMask.make_causal(
seq_len=q.shape[-2],
to_seq_len=q.shape[-2],
device=q.device,
dtype=q.dtype,
)
# Merge the optional causal mask and the user-provided mask
if mask is not None:
mask = mask.to(dtype=q.dtype, device=q.device)
att_mask = att_mask + mask if att_mask is not None else mask
# Try to handle a case where the sequence is smaller than the mask
if (
att_mask is not None
and q.shape[-2] == k.shape[-2]
and q.shape[-2] < att_mask.shape[1]
):
if isinstance(att_mask, AttentionMask):
att_mask = att_mask.make_crop(seq_len=q.shape[-2])
else:
logger.error(
"Mismatching sparse attention mask and sequence length."
+ " Please pad the inputs or adjust the attention mask"
)
raise NotImplementedError
# Attend: (B x nh, S, hs) x (B x nh, hs, S) -> (B x nh, S, S)
y = scaled_dot_product_attention(
q=q, k=k, v=v, att_mask=att_mask, dropout=self.attn_drop
)
return y

View File

@@ -0,0 +1,812 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
"""
The code has been adopted from DeepSpeed
(https://github.com/microsoft/DeepSpeed/blob/master/deepspeed/ops/sparse_attention/sparsity_config.py)
"""
import random
import torch
class SparsityConfig:
"""Abstract Configuration class to store `sparsity configuration of a self attention layer`.
It contains shared property of different block-sparse sparsity patterns. However, each class
needs to extend it based on required property and functionality.
"""
def __init__(self, num_heads, block_size=16, different_layout_per_head=False):
"""Initialize the Sparsity Pattern Config.
Arguments:
num_heads: required: an integer determining number of attention heads of the layer.
block_size: optional: an integer determining the block size. Current implementation of
sparse self-attention is based on blocked sparse matrices. In which this parameter
defines size of such blocks, `Block X Block`.
different_layout_per_head: optional: a boolean determining if each head should be
assigned a different sparsity layout; default is false and this will be satisfied
based on availability.
"""
self.num_heads = num_heads
self.block_size = block_size
self.different_layout_per_head = different_layout_per_head
self.num_layout_heads = num_heads if different_layout_per_head else 1
def setup_layout(self, seq_len):
"""Create layout tensor for the given sequence length
Arguments:
seq_len: required: an integer determining number of attention heads of the layer.
Return:
layout: a tensor of dimension (num_heads, num_blocks, num_blocks) for sparsity layout
of all head; initialized with zero
"""
if seq_len % self.block_size != 0:
raise ValueError(
f"Sequence Length, {seq_len}, needs to be dividable by Block size {self.block_size}!"
)
num_blocks = seq_len // self.block_size
# TODO Currently we allocate layout per head; needs to be updated if heads share a single layout.
layout = torch.zeros(
(self.num_heads, num_blocks, num_blocks), dtype=torch.int64
)
return layout
def check_and_propagate_first_head_layout(self, layout):
"""If all heads require same sparsity layout, it propagate first head layout to all heads
Arguments:
layout: required: a tensor of dimension (num_heads, num_blocks, num_blocks) containing
sparsity layout of all head; may not be completely set at this step
Return:
layout: a tensor of dimension (num_heads, num_blocks, num_blocks) containing sparsity
layout of all head
"""
if not self.different_layout_per_head:
layout[1 : self.num_heads, :, :] = layout[0, :, :]
return layout
class DenseSparsityConfig(SparsityConfig):
"""Configuration class to store `Dense` configuration.
In reality, this is not sparse and all blocks are used. We keep it for the sake of comparison and
comprehension.
"""
def __init__(self, num_heads, block_size=16, different_layout_per_head=False):
"""Initialize the Dense Sparsity Pattern Config.
In reality, this is not sparse and all blocks are used. We keep it for the sake of comparison
and comprehension.
Arguments:
num_heads: required: an integer determining number of attention heads of the layer.
block_size: optional: an integer determining the block size. Current implementation of
sparse self-attention is based on blocked sparse matrices. In which this parameter
defines size of such blocks, `Block X Block`.
different_layout_per_head: optional: this is just for the sake of consistency with
other sparsity formats; can ignore it for DenseSparsityConfig
"""
super().__init__(num_heads, block_size, different_layout_per_head)
def make_layout(self, seq_len):
"""Set 1 to all blocks of the layout meanins the pattern is dense; not sparse.
Arguments:
seq_len: required: an integer determining the underling sequence length;
must be <= max sequence length
Return:
layout: a tensor of dimension (num_heads, num_blocks, num_blocks) containing sparsity
layout of all head; for dense everything is 1
"""
layout = self.setup_layout(seq_len)
layout[:, :, :] = 1
return layout
class FixedSparsityConfig(SparsityConfig):
"""Configuration class to store `Fixed` sparsity configuration.
For more details about this sparsity config, please see `Generative Modeling with
Sparse Transformers`: https://arxiv.org/abs/1904.10509; this has been customized.
This class extends parent class of `SparsityConfig` and customizes it for `Fixed` sparsity.
"""
def __init__(
self,
num_heads,
block_size=16,
different_layout_per_head=False,
num_local_blocks=4,
num_global_blocks=1,
attention="bidirectional",
horizontal_global_attention=False,
num_different_global_patterns=1,
):
"""Initialize `Fixed` Sparsity Pattern Config.
For usage example please see, TODO DeepSpeed Sparse Transformer Tutorial
Arguments:
num_heads: required: an integer determining number of attention heads of the layer.
block_size: optional: an integer determining the block size. Current implementation of
sparse self-attention is based on blocked sparse matrices. In which this parameter
defines size of such blocks, `Block X Block`.
different_layout_per_head: optional: a boolean determining if each head should be
assigned a different sparsity layout; default is false and this will be satisfied
based on availability.
num_local_blocks: optional: an integer determining the number of blocks in local attention
window.
num_global_blocks: optional: an integer determining how many consecutive blocks in a local
window is used as the representative of the window for global attention.
attention: optional: a string determining attention type. Attention can be `unidirectional`,
such as autoregressive models, in which tokens attend only to tokens appear before them
in the context. Considering that, the upper triangular of attention matrix is empty as
above figure. Or it can be `bidirectional`, such as BERT, in which tokens can attend to
any other tokens before or after them. Then, the upper triangular part of the attention
matrix is mirror of the lower triangular in the above figure.
horizontal_global_attention: optional: a boolean determining if blocks that are global
representative of a local window, also attend to all other blocks. This is valid only if
attention type is `bidirectional`. Looking at the attention matrix, that means global
attention not only includes the vertical blocks, but also horizontal blocks.
num_different_global_patterns: optional: an integer determining number of different global
attentions layouts. While global attention can be fixed by which block/s are representative
of any local window, since there are multi-heads, each head can use a different global representative.
For example, with 4 blocks local window and global attention size of 1 block, we can have 4 different
versions in which the first, Second, third, or forth block of each local window can be global
representative of that window. This parameter determines how many of such patterns we want.
Of course, there is a limitation based on num_local_blocks and num_global_blocks.
"""
super().__init__(num_heads, block_size, different_layout_per_head)
self.num_local_blocks = num_local_blocks
if num_local_blocks % num_global_blocks != 0:
raise ValueError(
f"""Number of blocks in a local window, {num_local_blocks},
must be dividable by number of global blocks, {num_global_blocks}!"""
)
self.num_global_blocks = num_global_blocks
if attention != "unidirectional" and attention != "bidirectional":
raise NotImplementedError(
'only "uni/bi-directional" attentions are supported for now!'
)
self.attention = attention
if attention != "bidirectional" and horizontal_global_attention:
raise ValueError(
'only "bi-directional" attentions can support horizontal global attention!'
)
self.horizontal_global_attention = horizontal_global_attention
if num_different_global_patterns > 1 and not different_layout_per_head:
raise ValueError(
"""Number of different layouts cannot be more than one when you have set a single layout
for all heads! Set different_layout_per_head to True."""
)
if num_different_global_patterns > (num_local_blocks // num_global_blocks):
raise ValueError(
f"""Number of layout versions (num_different_global_patterns), {num_different_global_patterns},
cannot be larger than number of local window blocks divided by number of global blocks,
{num_local_blocks} / {num_global_blocks} = {num_local_blocks//num_global_blocks}!"""
)
self.num_different_global_patterns = num_different_global_patterns
def set_local_layout(self, h, layout):
"""Sets local attention layout used by the given head in the sparse attention.
Arguments:
h: required: an integer determining head index
layout: required: a tensor of dimension (num_heads, num_blocks, num_blocks) containing
sparsity layout of all head; may not be completely set at this step
Return:
layout: a tensor of dimension (num_heads, num_blocks, num_blocks) containing sparsity
layout of all head in which local layout is set
"""
num_blocks = layout.shape[1]
for i in range(0, num_blocks, self.num_local_blocks):
end = min(i + self.num_local_blocks, num_blocks)
for row in range(i, end):
for col in range(
i, (row + 1 if self.attention == "unidirectional" else end)
):
layout[h, row, col] = 1
return layout
def set_global_layout(self, h, layout):
"""Sets global attention layout used by the given head in the sparse attention.
Currently we set global blocks starting from the last block of a local window to the first one.
That means if a local window consists of 4 blocks and global attention size is one block, we use
block #4 in each local window as global. If we have different layout per head, then other heads
will get #3, #2, and #1. And if we have more heads (and different layout has set) than num of global
attentions, multiple head may have same global attentions.
Note) if horizontal_global_attention is set, global blocks will be set both horizontally and
vertically.
Arguments:
h: required: an integer determining head index
layout: required: a tensor of dimension (num_heads, num_blocks, num_blocks) containing
sparsity layout of all head; may not be completely set at this step
Return:
layout: a tensor of dimension (num_heads, num_blocks, num_blocks) containing sparsity
layout of all head in which global layout is set
"""
num_blocks = layout.shape[1]
first_global_block_idx = (
self.num_local_blocks
- (1 + h % self.num_different_global_patterns) * self.num_global_blocks
)
# set all global blocks except the last one if (in last local window)
end = num_blocks - (num_blocks % self.num_local_blocks)
for i in range(first_global_block_idx, end, self.num_local_blocks):
# vertical global attention
first_row = 0 if self.attention == "bidirectional" else i
# (((i // self.num_local_blocks) + 1) * self.num_local_blocks)
# if (first_row < num_blocks):
layout[h, first_row:, i : i + self.num_global_blocks] = 1
# horizontal global attention; only in bidirectional attention
if self.horizontal_global_attention:
layout[h, i : i + self.num_global_blocks, :] = 1
# set last global blocks; handle possible short last local window
if end < num_blocks:
start = min(
end + first_global_block_idx, num_blocks - self.num_global_blocks
)
end = start + self.num_global_blocks
# vertical global attention
first_row = 0 if self.attention == "bidirectional" else start
# (((start // self.num_local_blocks) + 1) * self.num_local_blocks)
# if (first_row < num_blocks):
layout[h, first_row:, start:end] = 1
# horizontal global attention
if self.horizontal_global_attention:
layout[h, start:end, :] = 1
return layout
def make_layout(self, seq_len):
"""Generates `Fixed` sparsity layout used by each head in the sparse attention.
Arguments:
seq_len: required: an integer determining number of attention heads of the layer.
Return:
layout: a tensor of dimension (num_heads, num_blocks, num_blocks) containing `Fixed`
sparsity layout of all head
"""
layout = self.setup_layout(seq_len)
for h in range(0, self.num_layout_heads):
layout = self.set_local_layout(h, layout)
layout = self.set_global_layout(h, layout)
layout = self.check_and_propagate_first_head_layout(layout)
return layout
class VariableSparsityConfig(SparsityConfig):
"""Configuration class to store `Variable` sparsity configuration.
This layout is an extension of FixedSparsityConfig in which:
- user can set random layout; default value is zero means no random block
- user can provide a list of local block sizes
- user can provide a list of global block indices.
For more details about `Fixed` sparsity config, please see `Generative Modeling with
Sparse Transformers`: https://arxiv.org/abs/1904.10509; this has been customized.
This class extends parent class of `SparsityConfig` and customizes it for `Fixed` sparsity.
"""
def __init__(
self,
num_heads,
block_size=16,
different_layout_per_head=False,
num_random_blocks=0,
local_window_blocks=[4],
global_block_indices=[0],
global_block_end_indices=None,
attention="bidirectional",
horizontal_global_attention=False,
):
"""Initialize `Variable` Sparsity Pattern Config.
For usage example please see, TODO DeepSpeed Sparse Transformer Tutorial
Arguments:
num_heads: required: an integer determining number of attention heads of the layer.
block_size: optional: an integer determining the block size. Current implementation of sparse
self-attention is based on blocked sparse matrices. In which this parameter defines
size of such blocks, `Block X Block`.
different_layout_per_head: optional: a boolean determining if each head should be assigned a
different sparsity layout; default is false and this will be satisfied based on
availability. Currently this sparsity config can only assign single layout to all heads;
needs to be extended for different layout per head.
num_random_blocks: optional: an integer determining the number of random blocks in each block row.
local_window_blocks: optional: a list of integers determining the number of blocks in each
local attention window. It assumes first number determines # of blocks in the first local
window, second the second window, ..., and the last number determines the number of blocks
in the remaining local windows.
global_block_indices: optional: a list of integers determining which blocks are considered
as global attention. Given indices, determine the blocks that all other token blocks
attend to and they attend to all other token blocks. Default value is only index 0.
Notice that if global_block_end_indices parameter is set, this parameter is used as
starting index of each global window.
global_block_end_indices: optional: a list of integers determining end indices of global
window blocks. By default this is not used. But if it is set, it must have the same size
of global_block_indices parameter, and combining this two parameters, for each index i,
blocks from global_block_indices[i] to global_block_end_indices[i] (exclusive) are
considered as global attention.
attention: optional: a string determining attention type. Attention can be `unidirectional`,
such as autoregressive models, in which tokens attend only to tokens appear before them
in the context. Considering that, the upper triangular of attention matrix is empty as
above figure. Or it can be `bidirectional`, such as BERT, in which tokens can attend to
any other tokens before or after them. Then, the upper triangular part of the attention
matrix is mirror of the lower triangular in the above figure.
horizontal_global_attention: optional: a boolean determining if blocks that are global
representative of a local window, also attend to all other blocks. This is valid only if
attention type is `bidirectional`. Looking at the attention matrix, that means global
attention not only includes the vertical blocks, but also horizontal blocks.
"""
super().__init__(num_heads, block_size, different_layout_per_head)
self.num_random_blocks = num_random_blocks
self.local_window_blocks = local_window_blocks
self.global_block_indices = global_block_indices
if global_block_end_indices is not None:
if len(global_block_indices) != len(global_block_end_indices):
raise ValueError(
f"""Global block start indices length, {len(global_block_indices)}, must be same as
global block end indices length, {len(global_block_end_indices)}!"""
)
for _, (start_idx, end_idx) in enumerate(
zip(global_block_indices, global_block_end_indices)
):
if start_idx >= end_idx:
raise ValueError(
f"""Global block start index, {start_idx}, must be smaller than global block end
index, {end_idx}!"""
)
self.global_block_end_indices = global_block_end_indices
if attention != "unidirectional" and attention != "bidirectional":
raise NotImplementedError(
'only "uni/bi-directional" attentions are supported for now!'
)
self.attention = attention
if attention != "bidirectional" and horizontal_global_attention:
raise ValueError(
'only "bi-directional" attentions can support horizontal global attention!'
)
self.horizontal_global_attention = horizontal_global_attention
def set_random_layout(self, h, layout):
"""Sets random attention layout used by the given head in the sparse attention.
Note) By default, it assumes there will be a unique random block layout for all heads; unless
`different_layout_per_head` parameter is set in which each head can have a different random
layout.
Arguments:
h: required: an integer determining head index
layout: required: a tensor of dimension (num_heads, num_blocks, num_blocks) containing
sparsity layout of all head; may not be completely set at this step
Return:
layout: a tensor of dimension (num_heads, num_blocks, num_blocks) containing sparsity
layout of all head in which random layout is set
"""
num_blocks = layout.shape[1]
if num_blocks < self.num_random_blocks:
raise ValueError(
f"""Number of random blocks, {self.num_random_blocks}, must be smaller than overall number
of blocks in a row, {num_blocks}!"""
)
for row in range(0, num_blocks):
rnd_cols = random.sample(range(0, num_blocks), self.num_random_blocks)
layout[h, row, rnd_cols] = 1
return layout
def set_local_layout(self, h, layout):
"""Sets local attention layout used by the given head in the sparse attention.
Arguments:
h: required: an integer determining head index
layout: required: a tensor of dimension (num_heads, num_blocks, num_blocks) containing
sparsity layout of all head; may not be completely set at this step
Return:
layout: a tensor of dimension (num_heads, num_blocks, num_blocks) containing sparsity
layout of all head in which local layout is set
"""
num_blocks = layout.shape[1]
start_block_idx = 0
end_block_idx = 0
for block_size in self.local_window_blocks:
end_block_idx += block_size
end_block_idx = min(end_block_idx, num_blocks)
for row in range(start_block_idx, end_block_idx):
for col in range(
start_block_idx,
(row + 1 if self.attention == "unidirectional" else end_block_idx),
):
layout[h, row, col] = 1
start_block_idx += block_size
# if there is any remaining not attended part, use the lats local window block size as local
# window for the remaining applicable local windows
for i in range(start_block_idx, num_blocks, block_size):
end_block_idx = min(i + block_size, num_blocks)
for row in range(i, end_block_idx):
for col in range(
i,
(row + 1 if self.attention == "unidirectional" else end_block_idx),
):
layout[h, row, col] = 1
return layout
def set_global_layout(self, h, layout):
"""Sets global attention layout used by the given head in the sparse attention.
Arguments:
h: required: an integer determining head index
layout: required: a tensor of dimension (num_heads, num_blocks, num_blocks) containing
sparsity layout of all head; may not be completely set at this step
Return:
layout: a tensor of dimension (num_heads, num_blocks, num_blocks) containing sparsity
layout of all head in which global layout is set
"""
num_blocks = layout.shape[1]
if self.global_block_end_indices is None:
for idx in self.global_block_indices:
# if global block idx is in the range of the sequence blocks
if idx < num_blocks:
# global rows
if self.horizontal_global_attention:
layout[h, idx, :] = 1
# global columns
first_row = 0 if self.attention == "bidirectional" else idx
layout[h, first_row:, idx] = 1
else:
for _, (start_idx, end_idx) in enumerate(
zip(self.global_block_indices, self.global_block_end_indices)
):
# if global block idx is in the range of the sequence blocks
if start_idx < num_blocks:
end_idx = min(end_idx, num_blocks)
# global rows
if self.horizontal_global_attention:
layout[h, start_idx:end_idx, :] = 1
# global columns
first_row = 0 if self.attention == "bidirectional" else start_idx
layout[h, first_row:, start_idx:end_idx] = 1
return layout
def make_layout(self, seq_len):
"""Generates `Variable` sparsity layout used by each head in the sparse attention.
Arguments:
seq_len: required: an integer determining number of attention heads of the layer.
Return:
layout: a tensor of dimension (num_heads, num_blocks, num_blocks) containing `Variable`
sparsity layout of all head
"""
layout = self.setup_layout(seq_len)
for h in range(0, self.num_layout_heads):
layout = self.set_random_layout(h, layout)
layout = self.set_local_layout(h, layout)
layout = self.set_global_layout(h, layout)
layout = self.check_and_propagate_first_head_layout(layout)
return layout
class BigBirdSparsityConfig(SparsityConfig):
"""Configuration class to store `BigBird` sparsity configuration.
For more details about this sparsity config, please see `Big Bird: Transformers for
Longer Sequences`: https://arxiv.org/pdf/2007.14062.pdf
This class extends parent class of `SparsityConfig` and customizes it for `BigBird` sparsity.
"""
def __init__(
self,
num_heads,
block_size=16,
different_layout_per_head=False,
num_random_blocks=1,
num_sliding_window_blocks=3,
num_global_blocks=1,
attention="bidirectional",
):
"""Initialize the BigBird Sparsity Pattern Config.
For usage example please see, TODO DeepSpeed Sparse Transformer Tutorial
Arguments:
num_heads: required: an integer determining number of attention heads of the layer.
block_size: optional: an integer determining the block size. Current implementation of
sparse self-attention is based on blocked sparse matrices. In which this parameter
defines size of such blocks, `Block X Block`.
different_layout_per_head: optional: a boolean determining if each head should be assigned
a different sparsity layout; default is false and this will be satisfied based on
availability.
num_random_blocks: optional: an integer determining the number of random blocks in each
block row.
num_sliding_window_blocks: optional: an integer determining the number of blocks in sliding
local attention window.
num_global_blocks: optional: an integer determining how many consecutive blocks, starting
from index 0, are considered as global attention. Global block tokens will be attended
by all other block tokens and will attend to all other block tokens as well.
attention: optional: a string determining attention type. Attention can be `unidirectional`,
such as autoregressive models, in which tokens attend only to tokens appear before them
in the context. Considering that, the upper triangular of attention matrix is empty as
above figure. Or it can be `bidirectional`, such as BERT, in which tokens can attend to
any other tokens before or after them. Then, the upper triangular part of the attention
matrix is mirror of the lower triangular in the above figure.
"""
super().__init__(num_heads, block_size, different_layout_per_head)
self.num_random_blocks = num_random_blocks
self.num_sliding_window_blocks = num_sliding_window_blocks
self.num_global_blocks = num_global_blocks
if attention != "unidirectional" and attention != "bidirectional":
raise NotImplementedError(
'only "uni/bi-directional" attentions are supported for now!'
)
self.attention = attention
def set_random_layout(self, h, layout):
"""Sets random attention layout used by the given head in the sparse attention.
Note) By default, it assumes there will be a unique random block layout for all heads; unless
`different_layout_per_head` parameter is set in which each head can have a different random layout.
Arguments:
h: required: an integer determining head index
layout: required: a tensor of dimension (num_heads, num_blocks, num_blocks) containing
sparsity layout of all head; may not be completely set at this step
Return:
layout: a tensor of dimension (num_heads, num_blocks, num_blocks) containing sparsity
layout of all head in which random layout is set
"""
num_blocks = layout.shape[1]
if num_blocks < self.num_random_blocks:
raise ValueError(
f"""Number of random blocks, {self.num_random_blocks}, must be smaller than overall number
of blocks in a row, {num_blocks}!"""
)
for row in range(0, num_blocks):
sample_range = (
range(0, num_blocks)
if self.attention == "bidirectional"
else range(0, row + 1)
)
rnd_cols = random.sample(sample_range, self.num_random_blocks)
layout[h, row, rnd_cols] = 1
return layout
def set_sliding_window_layout(self, h, layout):
"""Sets sliding local attention layout used by the given head in the sparse attention.
Arguments:
h: required: an integer determining head index
layout: required: a tensor of dimension (num_heads, num_blocks, num_blocks) containing
sparsity layout of all head; may not be completely set at this step
Return:
layout: a tensor of dimension (num_heads, num_blocks, num_blocks) containing sparsity
layout of all head in which local sliding window layout is set
"""
num_blocks = layout.shape[1]
if num_blocks < self.num_sliding_window_blocks:
raise ValueError(
f"""Number of sliding window blocks, {self.num_sliding_window_blocks}, must be smaller than
overall number of blocks in a row, {num_blocks}!"""
)
w = self.num_sliding_window_blocks // 2
for row in range(0, num_blocks):
start = max(0, row - w)
end = min(row + w + 1, num_blocks)
layout[h, row, start:end] = 1
return layout
def set_global_layout_itc(self, h, layout):
"""Sets global attention layout used by the given head in the sparse attention.
Arguments:
h: required: an integer determining head index
layout: required: a tensor of dimension (num_heads, num_blocks, num_blocks) containing
sparsity layout of all head; may not be completely set at this step
Return:
layout: a tensor of dimension (num_heads, num_blocks, num_blocks) containing sparsity layout
of all head in which global layout is set
"""
num_blocks = layout.shape[1]
if num_blocks < self.num_global_blocks:
raise ValueError(
f"""Number of global blocks, {self.num_global_blocks}, must be smaller than overall number
of blocks in a row, {num_blocks}!"""
)
# global rows
layout[h, 0 : self.num_global_blocks, :] = 1
# global columns
layout[h, :, 0 : self.num_global_blocks] = 1
if self.attention == "unidirectional":
# zero out anything attending to the future
layout = torch.tril(layout)
return layout
def make_layout(self, seq_len):
"""Generates `BigBird` sparsity layout used by each head in the sparse attention.
Arguments:
seq_len: required: an integer determining number of attention heads of the layer.
Return:
layout: a tensor of dimension (num_heads, num_blocks, num_blocks) containing `BigBird`
sparsity layout of all head
"""
layout = self.setup_layout(seq_len)
for h in range(0, self.num_layout_heads):
layout = self.set_random_layout(h, layout)
layout = self.set_sliding_window_layout(h, layout)
layout = self.set_global_layout_itc(h, layout)
layout = self.check_and_propagate_first_head_layout(layout)
return layout
class BSLongformerSparsityConfig(SparsityConfig):
"""Configuration class to store edited `Longformer` sparsity configuration.
Note) this is a block-sparse version of the Longformer which is slightly different than original
Longformer; which is element-wise sparsity.
For more details about this sparsity config, please see `Longformer:
The Long-Document Transformer`: https://arxiv.org/pdf/2004.05150.pdf
This class extends parent class of `SparsityConfig` and customizes it for `Longformer` sparsity.
"""
def __init__(
self,
num_heads,
block_size=16,
different_layout_per_head=False,
num_sliding_window_blocks=3,
global_block_indices=[0],
global_block_end_indices=None,
attention="bidirectional",
):
"""Initialize the edited `Longformer` Sparsity Pattern Config.
For usage example please see, TODO DeepSpeed Sparse Transformer Tutorial
Arguments:
num_heads: required: an integer determining number of attention heads of the layer.
block_size: optional: an integer determining the block size. Current implementation of sparse
self-attention is based on blocked sparse matrices. In which this parameter defines size
of such blocks, `Block X Block`.
different_layout_per_head: optional: a boolean determining if each head should be assigned a
different sparsity layout; default is false and this will be satisfied based on
availability.
num_sliding_window_blocks: optional: an integer determining the number of blocks in sliding
local attention window.
global_block_indices: optional: a list of integers determining which blocks are considered
as global attention. Given indices, determine the blocks that all other token blocks
attend to and they attend to all other token blocks. Default value is only index 0.
Notice that if global_block_end_indices parameter is set, this parameter is used as
starting index of each global window.
global_block_end_indices: optional: a list of integers determining end indices of global
window blocks. By default this is not used. But if it is set, it must have the same size
of global_block_indices parameter, and combining this two parameters, for each index i,
blocks from global_block_indices[i] to global_block_end_indices[i] (exclusive) are
considered as global attention.
attention: optional: a string determining attention type. Attention can be `unidirectional`,
such as autoregressive models, in which tokens attend only to tokens appear before them
in the context. Considering that, the upper triangular of attention matrix is empty as
above figure. Or it can be `bidirectional`, such as BERT, in which tokens can attend to
any other tokens before or after them. Then, the upper triangular part of the attention
matrix is mirror of the lower triangular in the above figure.
"""
super().__init__(num_heads, block_size, different_layout_per_head)
self.num_sliding_window_blocks = num_sliding_window_blocks
self.global_block_indices = global_block_indices
self.attention = attention
if global_block_end_indices is not None:
if len(global_block_indices) != len(global_block_end_indices):
raise ValueError(
f"""Global block start indices length, {len(global_block_indices)}, must be same as
global block end indices length, {len(global_block_end_indices)}!"""
)
for _, (start_idx, end_idx) in enumerate(
zip(global_block_indices, global_block_end_indices)
):
if start_idx >= end_idx:
raise ValueError(
f"""Global block start index, {start_idx}, must be smaller than global block end
index, {end_idx}!"""
)
self.global_block_end_indices = global_block_end_indices
def set_sliding_window_layout(self, h, layout):
"""Sets sliding local attention layout used by the given head in the sparse attention.
Arguments:
h: required: an integer determining head index
layout: required: a tensor of dimension (num_heads, num_blocks, num_blocks) containing
sparsity layout of all head; may not be completely set at this step
Return:
layout: a tensor of dimension (num_heads, num_blocks, num_blocks) containing sparsity layout
of all head in which local sliding window layout is set
"""
num_blocks = layout.shape[1]
if num_blocks < self.num_sliding_window_blocks:
raise ValueError(
f"""Number of sliding window blocks, {self.num_sliding_window_blocks}, must be smaller
than overall number of blocks in a row, {num_blocks}!"""
)
w = self.num_sliding_window_blocks // 2
for row in range(0, num_blocks):
start = max(0, row - w)
end = min(row + w + 1, num_blocks)
layout[h, row, start:end] = 1
return layout
def set_global_layout(self, h, layout):
"""Sets global attention layout used by the given head in the sparse attention.
Arguments:
h: required: an integer determining head index
layout: required: a tensor of dimension (num_heads, num_blocks, num_blocks) containing
sparsity layout of all head; may not be completely set at this step
Return:
layout: a tensor of dimension (num_heads, num_blocks, num_blocks) containing sparsity
layout of all head in which global layout is set
"""
num_blocks = layout.shape[1]
if self.global_block_end_indices is None:
for idx in self.global_block_indices:
# if global block idx is in the range of the sequence blocks
if idx < num_blocks:
# global rows
layout[h, idx, :] = 1
# global columns
layout[h, :, idx] = 1
else:
for _, (start_idx, end_idx) in enumerate(
zip(self.global_block_indices, self.global_block_end_indices)
):
# if global block idx is in the range of the sequence blocks
if start_idx < num_blocks:
end_idx = min(end_idx, num_blocks)
# global rows
layout[h, start_idx:end_idx, :] = 1
# global columns
layout[h, :, start_idx:end_idx] = 1
if self.attention == "unidirectional":
layout = torch.tril(layout)
return layout
def make_layout(self, seq_len):
"""Generates edited `Longformer` sparsity layout used by each head in the sparse attention.
Arguments:
seq_len: required: an integer determining number of attention heads of the layer.
Return:
layout: a tensor of dimension (num_heads, num_blocks, num_blocks) containing `BSLongformer`
sparsity layout of all head
"""
layout = self.setup_layout(seq_len)
for h in range(0, self.num_layout_heads):
layout = self.set_sliding_window_layout(h, layout)
layout = self.set_global_layout(h, layout)
layout = self.check_and_propagate_first_head_layout(layout)
return layout

View File

@@ -0,0 +1,108 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
from typing import Optional
import torch
# Reshapes key padding mask from (batch_size, src_len) -> (batch_size * num_heads 1, src_len)
def reshape_key_padding_mask(
key_padding_mask: torch.Tensor, batched_dim: int
) -> torch.Tensor:
assert key_padding_mask.ndim == 2
batch_size, src_len = key_padding_mask.size()
num_heads = batched_dim // batch_size
return _reshape_key_padding_mask(key_padding_mask, batch_size, src_len, num_heads)
def _reshape_key_padding_mask(
key_padding_mask: torch.Tensor, batch_size: int, src_len: int, num_heads: int
) -> torch.Tensor:
assert key_padding_mask.shape == (batch_size, src_len)
key_padding_mask = (
key_padding_mask.view(batch_size, 1, 1, src_len)
.expand(-1, num_heads, -1, -1)
.reshape(batch_size * num_heads, 1, src_len)
)
return key_padding_mask
# Combine the attention mask and key padding mask into a single mask
# Taken from https://github.com/pytorch/pytorch/blob/master/torch/nn/functional.py
# Additive masking not yet supported
def maybe_merge_masks(
att_mask: Optional[torch.Tensor],
key_padding_mask: Optional[torch.Tensor],
batch_size: int,
src_len: int,
num_heads: int,
tgt_len: Optional[int] = None,
) -> Optional[torch.Tensor]:
if tgt_len is None:
tgt_len = src_len
if key_padding_mask is not None:
assert key_padding_mask.shape == (batch_size, src_len)
key_padding_mask = _reshape_key_padding_mask(
key_padding_mask, batch_size, src_len, num_heads
)
if att_mask is None:
# make sure dimensions of key padding mask are the same as those expected for att_mask
att_mask = key_padding_mask.expand(-1, tgt_len, -1)
# Assumption is that False means to mask.
elif att_mask.dtype == torch.bool:
att_mask = att_mask.logical_and(key_padding_mask)
else:
att_mask = att_mask.masked_fill(~key_padding_mask, float("-inf"))
return att_mask
# Assumes that matrix passed in has had softmax applied to it.
def iterative_pinv(softmax_mat: torch.Tensor, n_iter=6, pinverse_original_init=False):
"""
Computing the Moore-Penrose inverse.
Use an iterative method from (Razavi et al. 2014) to approximate the Moore-Penrose inverse via efficient
matrix-matrix multiplications.
"""
i = torch.eye(
softmax_mat.size(-1), device=softmax_mat.device, dtype=softmax_mat.dtype
)
k = softmax_mat
# The entries of K are positive and ||K||_{\infty} = 1 due to softmax
if pinverse_original_init:
# This original implementation is more conservative to compute coefficient of Z_0.
v = 1 / torch.max(torch.sum(k, dim=-2)) * k.transpose(-1, -2)
else:
# This is the exact coefficient computation, 1 / ||K||_1, of initialization of Z_0, leading to faster
# convergence.
v = (
1
/ torch.max(torch.sum(k, dim=-2), dim=-1).values[:, None, None]
* k.transpose(-1, -2)
)
for _ in range(n_iter):
kv = torch.matmul(k, v)
v = torch.matmul(
0.25 * v,
13 * i - torch.matmul(kv, 15 * i - torch.matmul(kv, 7 * i - kv)),
)
return v
def bool_mask_to_additive(
mask: torch.Tensor, dtype: Optional[torch.dtype] = torch.float32
) -> torch.Tensor:
assert (
mask.dtype == torch.bool
), "This util is meant to convert in between bool masks and additive ones"
mask_ = torch.zeros_like(mask, dtype=dtype)
mask_[~mask] = float("-inf")
return mask_

View File

@@ -0,0 +1,96 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
import math
from dataclasses import dataclass
import torch
import torch.nn as nn
from xformers.components.attention import Attention, AttentionConfig, register_attention
@dataclass
class VisualAttentionConfig(AttentionConfig):
dim_model: int # dimension of the input sequence
class LKA(nn.Module):
def __init__(self, dim: int):
super().__init__()
self.conv0 = nn.Conv2d(dim, dim, 5, padding=2, groups=dim)
self.conv_spatial = nn.Conv2d(
dim, dim, 7, stride=1, padding=9, groups=dim, dilation=3
)
self.conv1 = nn.Conv2d(dim, dim, 1)
def forward(self, x: torch.Tensor):
u = x.clone()
attn = self.conv0(x)
attn = self.conv_spatial(attn)
attn = self.conv1(attn)
return u * attn
@register_attention("visual", VisualAttentionConfig)
class Visual(Attention):
def __init__(
self,
dim_model: int,
*_,
**__,
):
"""
Large kernel attention mechanism, as proposed in `Visual Attention Network`_, Guo et al (2022).
The original notation is tentatively kept as is. See https://github.com/Visual-Attention-Network
for the reference implementation
.. Note: compared to the paper, this block contains the LKA (Large Kernel Attention)
and the prior and posterior transformations (Conv2d and activation)
.. _`Visual Attention Network` : https://arxiv.org/pdf/2202.09741.pdf
"""
super().__init__()
self.block = nn.Sequential(
nn.Conv2d(dim_model, dim_model, 1),
nn.GELU(),
LKA(dim_model),
nn.Conv2d(dim_model, dim_model, 1),
)
# MHA related flags:
self.requires_same_k_q_dimensions = (
True # This mechanism only really supports self attention
)
self.supports_attention_mask = False
self.requires_skip_multi_head = (
True # This mechanism skips the multihead attention altogether
)
self.requires_squared_context = (
True # Recovering the 2D structure from context assumes squared content
)
self.requires_input_projection = (
False # This mechanism does not require that the MHA projects inputs
)
def forward(self, q: torch.Tensor, *_, **__):
# Expose the 2D token structure
B, HW, C = q.shape
H = int(math.sqrt(HW))
assert H * H == HW
x = q.transpose(-2, -1).reshape(B, C, H, H)
# Large kernel attention
residual = x.clone()
x = self.block(x)
x = x + residual
# Get back to B HW C
return x.flatten(2, 3).transpose(-2, -1)