First commit
This commit is contained in:
133
pkgs/xformers/components/attention/__init__.py
Normal file
133
pkgs/xformers/components/attention/__init__.py
Normal file
@@ -0,0 +1,133 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the BSD license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import Any, Callable, Dict, Set, Union
|
||||
|
||||
import torch
|
||||
|
||||
from xformers.utils import (
|
||||
generate_matching_config,
|
||||
get_registry_decorator,
|
||||
import_all_modules,
|
||||
)
|
||||
|
||||
from ._sputnik_sparse import SparseCS
|
||||
from .attention_mask import AttentionMask
|
||||
from .base import Attention, AttentionConfig # noqa
|
||||
|
||||
logger = logging.getLogger("xformers")
|
||||
|
||||
|
||||
# CREDITS: Classy Vision registry mechanism
|
||||
|
||||
ATTENTION_REGISTRY: Dict[str, Any] = {}
|
||||
ATTENTION_CLASS_NAMES: Set[str] = set()
|
||||
|
||||
# Arbitrary threshold for now,
|
||||
# in between dense and sparse matrix algorithms for the attention mechanism
|
||||
_DENSITY_THRESHOLD = 0.30 # noqa # from the sputnik paper, vs.
|
||||
_USE_SPUTNIK = True
|
||||
|
||||
|
||||
def build_attention(config: Union[Dict[str, Any], AttentionConfig]):
    """Builds an attention from a config.

    This assumes a 'name' key in the config which is used to determine what
    attention class to instantiate. For instance, a config `{"name": "my_attention",
    "foo": "bar"}` will find a class that was registered as "my_attention"
    (see :func:`register_attention`) and call .from_config on it.

    Args:
        config: either a plain dict (converted to the registered config class
            through `generate_matching_config`) or an already built AttentionConfig.

    Returns:
        The result of `.from_config(config_instance)` on the registered constructor.

    Raises:
        KeyError: re-raised when looking up the name in ATTENTION_REGISTRY
            (or building the matching config) fails.
    """

    if not isinstance(config, AttentionConfig):
        try:
            config_instance = generate_matching_config(
                config, ATTENTION_REGISTRY[config["name"]].config
            )
        except KeyError:
            # Use .get(): if the KeyError came from a missing 'name' key, indexing
            # again would raise a second KeyError from inside this handler
            name = config.get("name")
            logger.warning(f"{name} not available among {ATTENTION_REGISTRY.keys()}")
            # Bare raise keeps the original exception and traceback untouched
            raise
    else:
        config_instance = config

    return ATTENTION_REGISTRY[config_instance.name].constructor.from_config(
        config_instance
    )
|
||||
|
||||
|
||||
"""Registers an Attention subclass.
|
||||
|
||||
This decorator allows xFormers to instantiate a subclass of Attention
|
||||
from a configuration file, even if the class itself is not part of the
|
||||
xFormers library. To use it, apply this decorator to an Attention
|
||||
subclass, like this:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
@dataclass
|
||||
class MyConfig:
|
||||
...
|
||||
|
||||
@register_attention('my_attention', MyConfig)
|
||||
class MyAttention(Attention):
|
||||
...
|
||||
|
||||
To instantiate an attention from a configuration file, see :func:`build_attention`."""
|
||||
# `register_attention(name, config_class)` is the class decorator described by the
# string literal just above. It is produced by `get_registry_decorator` and bound to
# ATTENTION_REGISTRY / ATTENTION_CLASS_NAMES, with Attention and AttentionConfig as the
# reference types (the registration logic itself lives in xformers.utils).
register_attention: Callable[[str, Any], Callable[[Any], Any]] = get_registry_decorator(
    ATTENTION_REGISTRY, ATTENTION_CLASS_NAMES, Attention, AttentionConfig
)
|
||||
|
||||
|
||||
def maybe_sparsify(matrix) -> Any:
    """Pick a dense or a sparse representation for `matrix`, based on its density.

    The density is the fraction of non-zero coefficients. Above `_DENSITY_THRESHOLD`
    the matrix is wrapped with `AttentionMask.from_bool`, otherwise it goes
    through :func:`sparsify`.
    """
    density = torch.count_nonzero(matrix).item() / matrix.numel()

    # Sparse enough, move to a sparse representation
    if not (density > _DENSITY_THRESHOLD):
        return sparsify(matrix)

    # Too dense to be worth sparsifying: AttentionMask is the reference type
    return AttentionMask.from_bool(matrix)
|
||||
|
||||
|
||||
def sparsify(matrix):
    """Return a sparse version of `matrix`.

    A `SparseCS` wrapper is built when `_USE_SPUTNIK` is set, otherwise the
    result of `matrix.to_sparse()` is returned.
    """
    return SparseCS(matrix) if _USE_SPUTNIK else matrix.to_sparse()
|
||||
|
||||
|
||||
from .favor import FavorAttention # noqa
|
||||
from .global_tokens import GlobalAttention # noqa
|
||||
from .linformer import LinformerAttention # noqa
|
||||
from .local import LocalAttention # noqa
|
||||
from .nystrom import NystromAttention # noqa
|
||||
from .ortho import OrthoFormerAttention # noqa
|
||||
from .random import RandomAttention # noqa
|
||||
from .scaled_dot_product import ScaledDotProduct # noqa
|
||||
|
||||
__all__ = [
|
||||
"ScaledDotProduct",
|
||||
"LocalAttention",
|
||||
"LinformerAttention",
|
||||
"NystromAttention",
|
||||
"RandomAttention",
|
||||
"OrthoFormerAttention",
|
||||
"GlobalAttention",
|
||||
"FavorAttention",
|
||||
"Attention",
|
||||
"AttentionMask",
|
||||
"build_attention",
|
||||
"register_attention",
|
||||
]
|
||||
|
||||
# Optionally expose the BlockSparse attention
|
||||
try:
|
||||
from .blocksparse import BlockSparseAttention # noqa
|
||||
|
||||
__all__ += ["BlockSparseAttention"]
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
|
||||
# automatically import any Python files in the directory
|
||||
import_all_modules(str(Path(__file__).parent), "xformers.components.attention")
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
121
pkgs/xformers/components/attention/_sputnik_sparse.py
Normal file
121
pkgs/xformers/components/attention/_sputnik_sparse.py
Normal file
@@ -0,0 +1,121 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the BSD license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
|
||||
import torch
|
||||
|
||||
from xformers.ops import masked_matmul
|
||||
from xformers.sparse import SparseCSRTensor
|
||||
|
||||
# TODO: this is here for BC
|
||||
from xformers.sparse.utils import _csr_to_coo, _dense_to_sparse # noqa: F401
|
||||
|
||||
|
||||
class SparseCS:
    """Thin wrapper around a `SparseCSRTensor`, stored in `self._mat`.

    The wrapped tensor always carries a leading batch dimension (a 2d input is
    promoted to a batch of one), which `shape` hides from the caller.
    """

    def __init__(self, matrix, device=None):
        # Fall back to the CPU when no device is requested
        target_device = torch.device("cpu") if device is None else device

        # A single 2d matrix becomes a batch of one
        batched = matrix[None] if matrix.ndim == 2 else matrix
        assert batched.ndim == 3

        self._mat = SparseCSRTensor.from_dense(batched).to(target_device)

    @property
    def device(self):
        """Device of the wrapped tensor."""
        return self._mat.device

    @property
    def ndim(self):
        """Number of dimensions of the wrapped tensor."""
        return self._mat.ndim

    @property
    def dtype(self):
        """Data type of the wrapped tensor."""
        return self._mat.dtype

    @property
    def is_sparse(self):
        """Always True for this wrapper."""
        return True

    @property
    def shape(self):
        """Shape of the wrapped tensor, without its leading dimension."""
        return self._mat.shape[1:]

    @property
    def values(self):
        """Values held by the wrapped tensor."""
        return self._mat.values()

    @property
    def row_indices(self):
        return self._mat._csr_row_indices

    @property
    def column_indices(self):
        return self._mat._csr_column_indices

    @property
    def row_offsets(self):
        return self._mat._csr_row_offsets

    @property
    def _transp_info(self):
        return self._mat._csr_transp_info

    @classmethod
    def wrap(
        cls, shape, values, row_indices, row_offsets, column_indices, _transp_info
    ):
        """Build an instance from raw CSR components, bypassing `__init__`.

        The leading dimension of `values` is prepended to `shape`.
        """
        batched_shape = (values.shape[0],) + shape
        instance = cls.__new__(cls)
        instance._mat = SparseCSRTensor._wrap(
            batched_shape,
            values,
            row_indices,
            row_offsets,
            column_indices,
            _transp_info,
        )
        return instance

    @classmethod
    def _wrap(cls, csr_matrix):
        """Build an instance around an existing `SparseCSRTensor`, bypassing `__init__`."""
        assert isinstance(csr_matrix, SparseCSRTensor)
        instance = cls.__new__(cls)
        instance._mat = csr_matrix
        return instance

    def __mul__(self, other):
        # Only scaling by a python scalar is handled
        assert isinstance(other, (int, float))
        scaled = self._mat * other
        return type(self)._wrap(scaled)

    def __add__(self, other):
        # Only the sum of two instances of the same type is handled
        assert isinstance(other, type(self))
        total = self._mat + other._mat
        return type(self)._wrap(total)

    def matmul_with_mask(self, a, b):
        """Wrap the result of `masked_matmul(a, b, mask)`, this matrix being the mask."""
        product = masked_matmul(a, b, self._mat)
        return type(self)._wrap(product)

    def softmax(self):
        """Softmax over the last dimension, returned as a new wrapper."""
        return type(self)._wrap(torch.nn.functional.softmax(self._mat, -1))

    def spmm(self, b):
        """Batched product of the wrapped matrix with `b` (`torch.bmm`), returned unwrapped."""
        return torch.bmm(self._mat, b)

    def transpose(self):
        """Swap the last two dimensions, returned as a new wrapper."""
        return type(self)._wrap(torch.transpose(self._mat, -2, -1))

    def to(self, device):
        """Move the wrapped matrix to `device`, returned as a new wrapper."""
        assert isinstance(device, torch.device)
        return type(self)._wrap(self._mat.to(device))

    def to_dense(self):
        """Dense version of the wrapped matrix."""
        return self._mat.to_dense()

    def logical_and(self, other: torch.Tensor):
        """Logical and against `other`, which must not be a SparseCS."""
        assert not isinstance(other, SparseCS)
        return type(self)._wrap(torch.logical_and(self._mat, other))

    def __and__(self, other):
        return self.logical_and(other)
|
||||
143
pkgs/xformers/components/attention/attention_mask.py
Normal file
143
pkgs/xformers/components/attention/attention_mask.py
Normal file
@@ -0,0 +1,143 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the BSD license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
|
||||
from typing import Optional, Type, TypeVar
|
||||
|
||||
import torch
|
||||
|
||||
Self = TypeVar("Self", bound="AttentionMask")


class AttentionMask:
    """
    Holds an attention mask, along with a couple of helpers and attributes.

    .. note: this is an additive mask, meaning that coefficients which should be computed hold the '0.' value,
        and coefficients which should be skipped hold the '-inf' value. Any other value is possible if the purpose
        is to bias the attention computation for instance

    .. note: the attention mask dimensions are expected to be `[batch, to_sequence, from_sequence]`,
        `[to_sequence, from_sequence]`, or anything broadcastable in between

    Attributes:
        values: the additive mask, always stored with a leading batch dimension
        is_causal: whether the mask was flagged as causal
        seq_len: size of dimension 1 of `values`
        to_seq_len: size of dimension 2 of `values`
    """

    def __init__(self, additive_mask: torch.Tensor, is_causal: bool = False):
        assert additive_mask.is_floating_point(), additive_mask.dtype
        assert not additive_mask.requires_grad

        # A 2d mask gets a leading batch dimension of one
        if additive_mask.ndim == 2:
            additive_mask = additive_mask.unsqueeze(0)

        self.values = additive_mask
        self.is_causal = is_causal

        # Same layout as `make_causal` / `make_crop`: [batch, seq_len, to_seq_len].
        # Dimension 0 is the batch, so `to_seq_len` has to come from dimension 2.
        self.seq_len = additive_mask.shape[1]
        self.to_seq_len = additive_mask.shape[2]

    def to_bool(self) -> torch.Tensor:
        """
        .. warning: we assume here that True implies that the value should be computed
        """
        return self.values != float("-inf")

    @classmethod
    def from_bool(cls: Type[Self], x: torch.Tensor) -> Self:
        """
        Create an AttentionMask given a boolean pattern.
        .. warning: we assume here that True implies that the value should be computed
        """
        assert x.dtype == torch.bool

        # 0. where the coefficient is computed, -inf where it is skipped
        additive_mask = torch.zeros_like(x, dtype=torch.float)
        additive_mask.masked_fill_(~x, float("-inf"))

        return cls(additive_mask)

    @classmethod
    def from_multiplicative(cls: Type[Self], x: torch.Tensor) -> Self:
        """
        Create an AttentionMask given a multiplicative attention mask.
        Non-zero coefficients are kept, zeros are masked out.
        """
        assert not x.dtype == torch.bool

        # Same conversion as the boolean case once the mask is binarized
        return cls.from_bool(x.bool())

    @classmethod
    def make_causal(
        cls: Type[Self],
        seq_len: int,
        to_seq_len: Optional[int] = None,
        device: Optional[torch.device] = None,
        dtype: Optional[torch.dtype] = None,
    ) -> Self:
        """
        Create a causal mask of shape `[1, seq_len, to_seq_len]`: '-inf' strictly above
        the diagonal, '0.' elsewhere. `to_seq_len` defaults to `seq_len`.
        """
        if not to_seq_len:
            to_seq_len = seq_len

        additive_mask = torch.triu(
            torch.ones(seq_len, to_seq_len, device=device, dtype=dtype) * float("-inf"),
            diagonal=1,
        )
        return cls(additive_mask=additive_mask, is_causal=True)

    def make_crop(
        self, seq_len: int, to_seq_len: Optional[int] = None
    ) -> "AttentionMask":
        """
        Return a cropped attention mask, whose underlying tensor is a view of this one
        """

        if not to_seq_len:
            to_seq_len = seq_len

        return AttentionMask(
            self.values[:, :seq_len, :to_seq_len], is_causal=self.is_causal
        )

    def __repr__(self):
        return f"AttentionMask - causal {self.is_causal} - mask " + str(self.values)

    @property
    def device(self):
        return self.values.device

    @property
    def is_sparse(self):
        return False

    @property
    def ndim(self):
        return len(self.values.shape)

    @property
    def dtype(self):
        return self.values.dtype

    @property
    def shape(self):
        return self.values.shape

    def __add__(self, other):
        # The sum of two masks is never flagged as causal
        return AttentionMask(self.values + other.values, is_causal=False)

    def to(
        self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None
    ) -> "AttentionMask":
        """
        Return a mask on the requested device and/or with the requested dtype.
        At least one of the two must be given; `self` is returned when nothing changes.
        """
        assert device is None or isinstance(device, torch.device)
        assert dtype is None or isinstance(dtype, torch.dtype)
        assert device is not None or dtype is not None

        # Noop if we don't need to create another instance
        same_device = device is None or device == self.device
        same_dtype = dtype is None or dtype == self.dtype
        if same_device and same_dtype:
            return self

        return AttentionMask(self.values.to(device=device, dtype=dtype), self.is_causal)
|
||||
295
pkgs/xformers/components/attention/attention_patterns.py
Normal file
295
pkgs/xformers/components/attention/attention_patterns.py
Normal file
@@ -0,0 +1,295 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the BSD license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
|
||||
import math
|
||||
from typing import List
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from xformers.components.attention.sparsity_config import (
|
||||
BigBirdSparsityConfig,
|
||||
BSLongformerSparsityConfig,
|
||||
FixedSparsityConfig,
|
||||
VariableSparsityConfig,
|
||||
)
|
||||
|
||||
|
||||
# generic nd cases
|
||||
def _generate_nd_grid(*sizes):
|
||||
coords = [torch.arange(s) for s in sizes]
|
||||
return torch.meshgrid(*coords)
|
||||
|
||||
|
||||
def local_nd_distance(*sizes, p=2.0, weights=None):
    """Pairwise p-norm distances between all the cells of a grid of shape `sizes`.

    Each coordinate axis is scaled by the matching entry of `weights` (all ones by
    default) before the distances are computed with `torch.cdist`.
    """
    if weights is None:
        weights = (1,) * len(sizes)
    assert len(sizes) == len(weights)

    axes = _generate_nd_grid(*sizes)
    scaled_axes = [axis.flatten() * weight for axis, weight in zip(axes, weights)]

    # One row per grid cell, one column per dimension
    points = torch.stack(scaled_axes, dim=1).float()
    return torch.cdist(points, points, p=p)
|
||||
|
||||
|
||||
def local_nd_gaussian_distribution(*sizes, sigma=1):
    """Gaussian weights `exp(-0.5 * d^2 / sigma^2)` from the pairwise euclidean grid distances `d`."""
    squared_distance = local_nd_distance(*sizes, p=2.0) ** 2
    return torch.exp(-0.5 * sigma ** (-2.0) * squared_distance)
|
||||
|
||||
|
||||
def local_nd_pattern(*sizes, distance, p=2.0):
    """Boolean pattern, True where two grid cells are strictly closer than `distance` (p-norm)."""
    pairwise = local_nd_distance(*sizes, p=p)
    return pairwise < distance
|
||||
|
||||
|
||||
def axial_nd_pattern(*sizes):
    """Boolean axial pattern over a grid of shape `sizes`.

    Axial is a special case of the local pattern, with p=0 and distance=2.
    """
    pairwise = local_nd_distance(*sizes, p=0)
    return pairwise < 2
|
||||
|
||||
|
||||
def random_pattern_from_probability_matrix(dist_matrix, nnz):
    """Sample a boolean pattern with `nnz` True coefficients, without replacement.

    Args:
        dist_matrix: sampling weights, one per coefficient. It is never modified.
        nnz: number of coefficients to set to True.

    Returns:
        A boolean tensor with the shape of `dist_matrix`.
    """
    # Force a contiguous layout: by default zeros_like mirrors the strides of a
    # non-contiguous input (a transposed matrix for instance), and the flat
    # view used below would then fail
    att = torch.zeros_like(
        dist_matrix, dtype=torch.bool, memory_format=torch.contiguous_format
    )

    # PyTorch multinomial wrongly doesn't support sampling when number of categories
    # is > 2^24, arguing that it's because it's the max representable consecutive element
    # in fp32 and that the kernels use float32. This is actually not true, and the kernels
    # should work fine if double tensor is passed on CPU. This is a bug that was introduced
    # in https://github.com/pytorch/pytorch/commit/bf04c2ca2f591d98ce57816f0ef0cd20a21bbf66
    # when unifying the checks between CPU and CUDA. For now, just fall-back to numpy
    if dist_matrix.numel() > 2**24:
        probabilities = dist_matrix.double()
        # Out-of-place normalization: `.double()` returns the very same tensor when
        # the input is already float64, an in-place division would alter the caller's data
        probabilities = probabilities / probabilities.sum()
        idxs = np.random.choice(
            probabilities.numel(), nnz, p=probabilities.flatten(), replace=False
        )
        idxs = torch.as_tensor(idxs)
    else:
        idxs = torch.multinomial(dist_matrix.flatten(), nnz, replacement=False)

    att.view(-1)[idxs] = True
    return att
|
||||
|
||||
|
||||
def global_token_pattern(attention_query_mask: torch.Tensor) -> torch.Tensor:
    """Square boolean pattern, True on every row and every column flagged in the 1d boolean mask."""
    assert attention_query_mask.ndim == 1
    assert attention_query_mask.dtype == torch.bool

    as_row = attention_query_mask[None, :]
    as_column = as_row.transpose(1, 0)

    # Broadcasting the row against the column fills the whole square
    return as_row | as_column
|
||||
|
||||
|
||||
def random_pattern(attn_size: int, sparsity: float) -> torch.Tensor:
    """Random square boolean pattern: a coefficient is True when its uniform draw exceeds `sparsity`."""
    assert 0 < sparsity < 1
    draws = torch.rand(attn_size, attn_size)
    return draws > sparsity
|
||||
|
||||
|
||||
# 1d-specific cases
|
||||
def local_1d_pattern(attn_size: int, window_size: int) -> torch.Tensor:
    """1d local attention pattern for an odd `window_size` (the token itself plus two wings)."""
    assert (
        window_size % 2 == 1
    ), "The window size is assumed to be odd (counts self-attention + 2 wings)"

    # The distance test is strict, hence the +1 on the half window
    half_window = window_size // 2 + 1
    return local_nd_pattern(attn_size, distance=half_window, p=1.0)
|
||||
|
||||
|
||||
def causal_1d_pattern(attn_size: int) -> torch.Tensor:
    """Lower triangular (diagonal included) boolean pattern of size `attn_size`."""
    everything = torch.ones(attn_size, attn_size, dtype=torch.bool)
    return torch.tril(everything)
|
||||
|
||||
|
||||
# 2d-specific cases
|
||||
def horizontal_axial_2d_distance(H, W, p=2.0):
    """Pairwise distances over a (H, W) grid with weights (1, 0): only the first coordinate counts."""
    return local_nd_distance(H, W, p=p, weights=(1, 0))
|
||||
|
||||
|
||||
def vertical_axial_2d_distance(H, W, p=2.0):
    """Pairwise distances over a (H, W) grid with weights (0, 1): only the second coordinate counts."""
    return local_nd_distance(H, W, p=p, weights=(0, 1))
|
||||
|
||||
|
||||
def local_2d_distance(H, W, p=2.0):
    """2d shorthand for :func:`local_nd_distance` over a (H, W) grid."""
    distances = local_nd_distance(H, W, p=p)
    return distances
|
||||
|
||||
|
||||
def local_2d_gausian_distribution(H, W, sigma=1):
    """2d shorthand for :func:`local_nd_gaussian_distribution` over a (H, W) grid."""
    distribution = local_nd_gaussian_distribution(H, W, sigma=sigma)
    return distribution
|
||||
|
||||
|
||||
def local_2d_pattern(H, W, distance, p=2.0):
    """2d shorthand for :func:`local_nd_pattern` over a (H, W) grid."""
    pattern = local_nd_pattern(H, W, distance=distance, p=p)
    return pattern
|
||||
|
||||
|
||||
def axial_2d_pattern(H, W):
    """2d shorthand for :func:`axial_nd_pattern` over a (H, W) grid."""
    pattern = axial_nd_pattern(H, W)
    return pattern
|
||||
|
||||
|
||||
def swin_attention_pattern(H, W, window_size, shift_size=0):
    """Boolean pattern where two cells of a (H, W) grid attend to each other
    when they share the same (possibly shifted) window.

    Every cell is assigned to its closest window anchor, and two cells are
    connected when they share that anchor.
    """
    assert H % window_size == 0
    assert W % window_size == 0
    assert 0 <= shift_size < window_size, "shift_size must in 0-window_size"

    # input grid, using the cell centers
    rows, cols = _generate_nd_grid(H, W)
    rows, cols = rows + 0.5, cols + 0.5

    # anchors grid
    # if shift is present, add extra element to the grid
    # to account for the uneven partitioning
    extra = int(shift_size % window_size != 0)
    anchor_rows, anchor_cols = _generate_nd_grid(
        H // window_size + extra, W // window_size + extra
    )

    # convert shift to be compatible with the paper representation
    s = (-shift_size) % window_size
    offset = window_size / 2 - s
    anchor_rows = anchor_rows * window_size + offset
    anchor_cols = anchor_cols * window_size + offset

    input_coords = torch.stack([rows.flatten(), cols.flatten()], 1).float()
    anchors_coords = torch.stack(
        [anchor_rows.flatten(), anchor_cols.flatten()], 1
    ).float()

    # Index of the closest anchor, for every input cell
    closest_anchor = torch.cdist(input_coords, anchors_coords, p=2).argmin(1)
    return closest_anchor[:, None] == closest_anchor[None, :]
|
||||
|
||||
|
||||
def dilated_2d_pattern(H, W, k=2):
    """
    Returns a 2d pattern that samples 1 every k elements in the attention mask.
    Can be seen as a form of downsampling, where every pixel attends to a downsampled
    version of the input.
    """
    # Distances along each of the two axes, taken separately
    along_h = local_nd_distance(H, W, p=1, weights=(1, 0))
    along_w = local_nd_distance(H, W, p=1, weights=(0, 1))

    h_selected = along_h.floor() % k == 0
    w_selected = along_w.floor() % k == 0
    return h_selected & w_selected
|
||||
|
||||
|
||||
# Block sparse utils
|
||||
def block_sparsify_tensor(x, mask, block_size):
    """
    Block sparsify a tensor, given a mask and block size

    One `block_size` x `block_size` block of `x` is gathered per non-zero entry of `mask`,
    following the order given by `mask.nonzero`.
    """
    out_shape = (x.size(0), mask.sum(), block_size, block_size)
    ret = torch.empty(out_shape, dtype=x.dtype, device=x.device)

    for idx, (h, i, j) in enumerate(zip(*mask.nonzero(as_tuple=True))):
        row_span = slice(i * block_size, (i + 1) * block_size)
        col_span = slice(j * block_size, (j + 1) * block_size)
        ret[:, idx, :, :] = x[:, h, row_span, col_span]

    return ret
|
||||
|
||||
|
||||
def pattern_to_layout(mask: torch.Tensor, block_size: int) -> torch.Tensor:
    r"""
    Given a mask pattern and blocksize, return the corresponding layout
    which makes sure that all the positives in the mask are covered

    A block of the layout is set as soon as one coefficient of the mask is set within it.
    """
    assert mask.ndim >= 2, "We're expecting [Heads, Seq, Seq] or [Seq, Seq]"

    # A [Seq, Seq] mask temporarily gets a head dimension
    missing_head_dim = mask.ndim == 2
    if missing_head_dim:
        mask = mask.unsqueeze(0)

    assert (
        mask.shape[1] % block_size == 0 and mask.shape[2] % block_size == 0
    ), "We're only handling masks divisible by block_size"

    # Max pooling flags every block which holds at least one positive
    pooled = torch.nn.functional.max_pool2d(
        mask.to(torch.float), kernel_size=block_size, stride=block_size
    )
    layout = pooled.to(torch.long)

    if missing_head_dim:
        layout.squeeze_(0)

    return layout
|
||||
|
||||
|
||||
def alibi_pattern(threshold: float, mask_shape: torch.Size) -> torch.Tensor:
    r"""
    Use the additive bias computation from ALiBi_ to generate a mask.
    Note that this mask can in turn be used to generate a blocksparse attention computation layout

    The returned boolean tensor has the shape [heads, 1, seq] and is True where
    the bias is strictly below `threshold`.

    .. note: mask_shape is expected to hold the [heads, seq, seq] dimensions

    .. _ALiBi: https://arxiv.org/pdf/2108.12409.pdf
    """

    # CREDITS: code snippet from Ofir Press, one of the authors

    def get_slopes(n: int):
        def get_slopes_power_of_2(n: int) -> List[float]:
            start = 2 ** (-(2 ** -(math.log2(n) - 3)))
            ratio = start
            return [start * ratio**i for i in range(n)]

        # In the paper, we only train models that have 2^a heads for some a. This function has
        # some good properties that only occur when the input is a power of 2. To maintain that even
        # when the number of heads is not a power of 2, we use this workaround.
        if math.log2(n).is_integer():
            return get_slopes_power_of_2(n)

        closest_power_of_2 = 2 ** math.floor(math.log2(n))
        base_slopes = get_slopes_power_of_2(closest_power_of_2)
        # Every other slope of the next power of two fills the remaining heads
        extra_slopes = get_slopes(2 * closest_power_of_2)[0::2][
            : n - closest_power_of_2
        ]
        return base_slopes + extra_slopes

    attn_heads = mask_shape[0]
    maxpos = mask_shape[1]
    slopes = torch.Tensor(get_slopes(attn_heads))

    # The positions below are what constructs the diagonal matrix
    # (right matrix in Figure 3 in the paper).
    # If you run it you'll see that it doesn't exactly print out the same matrix as we have in Figure 3,
    # but one where all rows are identical.
    # This works because the softmax operation is invariant to translation,
    # and our bias functions are always linear.
    positions = torch.arange(maxpos).unsqueeze(0).unsqueeze(0).expand(attn_heads, -1, -1)
    alibi = slopes.unsqueeze(1).unsqueeze(1) * positions
    alibi = alibi.view(attn_heads, 1, maxpos)

    # Now threshold arbitrarily, report the mask
    return alibi < threshold
|
||||
|
||||
|
||||
def quick_fixed_layout(num_heads: int, block_size: int, seq_len: int):
    """Layout for `seq_len`, from a FixedSparsityConfig built with its default options."""
    return FixedSparsityConfig(
        num_heads=num_heads, block_size=block_size
    ).make_layout(seq_len)
|
||||
|
||||
|
||||
def quick_variable_layout(num_heads: int, block_size: int, seq_len: int):
    """Layout for `seq_len`, from a VariableSparsityConfig built with its default options."""
    return VariableSparsityConfig(
        num_heads=num_heads, block_size=block_size
    ).make_layout(seq_len)
|
||||
|
||||
|
||||
def quick_bigbird_layout(num_heads: int, block_size: int, seq_len: int):
    """Layout for `seq_len`, from a BigBirdSparsityConfig built with its default options."""
    return BigBirdSparsityConfig(
        num_heads=num_heads, block_size=block_size
    ).make_layout(seq_len)
|
||||
|
||||
|
||||
def quick_bslongformer_layout(num_heads: int, block_size: int, seq_len: int):
    """Layout for `seq_len`, from a BSLongformerSparsityConfig built with its default options."""
    return BSLongformerSparsityConfig(
        num_heads=num_heads, block_size=block_size
    ).make_layout(seq_len)
|
||||
|
||||
|
||||
def layout_to_pattern(layout: torch.Tensor, block_size: int):
    r"""
    create a pattern of shape [heads, seq, seq] out of a blocksparse
    layout of shape [heads, seq/block_size, seq/block_size]

    Every entry of the layout is expanded into a `block_size` x `block_size` block
    (Kronecker product with a block of ones).
    """
    # Create the block on the layout's device, so that layouts which do not live
    # on the default device can be expanded too. The dtype is left to the default
    # on purpose, the output dtype is unchanged.
    block = torch.ones(block_size, block_size, device=layout.device)
    return torch.kron(layout, block)
|
||||
93
pkgs/xformers/components/attention/base.py
Normal file
93
pkgs/xformers/components/attention/base.py
Normal file
@@ -0,0 +1,93 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the BSD license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
|
||||
from abc import ABCMeta, abstractmethod
|
||||
from dataclasses import asdict, dataclass
|
||||
from typing import Optional, Type, TypeVar
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from xformers.components.attention import AttentionMask
|
||||
|
||||
|
||||
@dataclass
class AttentionConfig:
    """Base set of parameters, shared by all the attention mechanisms.

    Can accept and store extra parameters.
    """

    # Name under which the attention mechanism is registered
    name: str
    # Dropout probability
    dropout: float
|
||||
|
||||
|
||||
Self = TypeVar("Self", bound="Attention")
|
||||
|
||||
|
||||
# Define the common interface, every attention block needs to derive from it
|
||||
class Attention(nn.Module, metaclass=ABCMeta):
    r"""Common interface for the attention mechanisms, typically a sub-part of the multi-head attention.

    The `requires_*` / `supports_*` flags set in the constructor describe what the
    mechanism expects, they are meant to be overridden by the subclasses.
    """

    _causal_mask: Optional[AttentionMask] = None

    @abstractmethod
    def __init__(self, dropout: Optional[float] = None, *args, **kwargs):
        super().__init__()

        # The inputs need to be projected beforehand
        self.requires_input_projection = True

        # The head dimension has to be kept (else it can be folded into the batch dimension)
        self.requires_head_dimension = False

        # Key padding mask and attention mask come as separate arguments, not as a merged attention mask
        self.requires_separate_masks = False

        # K and Q need to share the same sequence length
        self.requires_same_k_q_dimensions = False

        # The attention owns the single head/multihead mechanism,
        # the MHA wrapper should skip it
        self.requires_skip_multi_head = False

        # The context length needs to be a square, often due to 2D pooling
        self.requires_squared_context = False

        # Which masks are supported by this attention mechanism
        self.supports_attention_mask = True
        self.supports_key_padding_mask = False

    @classmethod
    def from_config(cls: Type[Self], config: AttentionConfig) -> Self:
        """Instantiate the attention from the fields of a config dataclass."""
        # None fields are dropped, so that the constructor defaults are used
        kwargs = {
            key: value for key, value in asdict(config).items() if value is not None
        }
        return cls(**kwargs)

    @abstractmethod
    def forward(
        self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, *args, **kwargs
    ) -> torch.Tensor:
        raise NotImplementedError

    @staticmethod
    def _maybe_pad_sequence(x: torch.Tensor, mask: torch.Tensor):
        """
        If the sequence is shorter than the mask, return a zero-padded version of it,
        else return the sequence as is
        """
        seq_len, mask_len = x.shape[-2], mask.shape[-1]

        # Nothing to do if the sizes already match
        if seq_len == mask_len:
            return x

        assert seq_len < mask_len, (
            "Sequence is bigger than the provided mask, cannot infer what to do with it."
            " Please update your attention mask"
        )

        # Pad at the end of the second to last dimension only
        pad_size = (0, 0, 0, mask_len - seq_len, 0, 0)
        return torch.nn.functional.pad(x, pad_size, mode="constant", value=0.0)
|
||||
190
pkgs/xformers/components/attention/blocksparse.py
Normal file
190
pkgs/xformers/components/attention/blocksparse.py
Normal file
@@ -0,0 +1,190 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the BSD license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
|
||||
import logging
|
||||
import math
|
||||
from dataclasses import dataclass
|
||||
|
||||
import torch
|
||||
|
||||
from xformers import _is_triton_available
|
||||
from xformers.components.attention import Attention, AttentionConfig, register_attention
|
||||
|
||||
logger = logging.getLogger("xformers")
|
||||
|
||||
|
||||
_is_blocksparse_available = _is_triton_available()
|
||||
|
||||
|
||||
if _is_blocksparse_available:
|
||||
from triton.ops.blocksparse import matmul as blocksparse_matmul # type: ignore
|
||||
from triton.ops.blocksparse import softmax as blocksparse_softmax # type: ignore
|
||||
|
||||
from xformers.triton.utils import gpu_capabilities_older_than_70
|
||||
|
||||
# Blocksparse requires Tensor cores
|
||||
if gpu_capabilities_older_than_70():
|
||||
logger.warning(
|
||||
"Blocksparse is not available: the current GPU does not expose Tensor cores"
|
||||
)
|
||||
_is_blocksparse_available = False
|
||||
|
||||
|
||||
if _is_blocksparse_available:
|
||||
|
||||
@dataclass
class BlockSparseAttentionConfig(AttentionConfig):
    # Configuration fields for the "blocksparse" attention registered right below.

    # Blocksparse layout. BlockSparseAttention documents it as [heads, seq, seq]
    # and expands a 2d layout across `num_heads`
    layout: torch.Tensor
    # Size of the blocks, BlockSparseAttention asserts that it is one of 16, 32, 64, 128
    block_size: int
    # Dropout probability, handed over to torch.nn.Dropout by BlockSparseAttention
    dropout: float
    # Number of heads, used by BlockSparseAttention to expand a 2d layout
    num_heads: int
|
||||
|
||||
@register_attention("blocksparse", BlockSparseAttentionConfig)
|
||||
class BlockSparseAttention(Attention):
|
||||
r"""
|
||||
Thin wrap over the Triton blocksparse computations. The sparsity pattern is determined through the layout.
|
||||
|
||||
.. warning: the layout is assumed to have the dimensions [heads, seq, seq].
|
||||
If some dimensions are missing, we assume that the same layout is to be used across heads.
|
||||
|
||||
.. warning: for now, the sequence (context) length has to be a power of two. This constraint could
|
||||
be relaxed in the future.
|
||||
|
||||
.. warning: the block size has to be picked from [16, 32, 64]. Some speed is gained from bigger blocks.
|
||||
It is of course possible to reproduce coarser patterns given these primitives, as the user sees fit.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
layout: torch.Tensor,
|
||||
block_size: int = 16,
|
||||
dropout: float = 0.0,
|
||||
num_heads: int = 1, # optional, used to adapt the layout if in need
|
||||
causal: bool = False,
|
||||
*args,
|
||||
**kwargs,
|
||||
):
|
||||
if layout.dim() == 2:
|
||||
logger.warning(
|
||||
"The layout passed is lacking a head dimension and a batch dimension"
|
||||
)
|
||||
logger.warning(
|
||||
"Now assuming that the same layout is to be used across all heads"
|
||||
)
|
||||
layout = layout.unsqueeze(0).expand(num_heads, -1, -1)
|
||||
logger.warning(f"New layout dimensions: {layout.shape}")
|
||||
|
||||
assert block_size in (
|
||||
16,
|
||||
32,
|
||||
64,
|
||||
128,
|
||||
), "Only block sizes in [16, 32, 64, 128] are supported"
|
||||
|
||||
super().__init__()
|
||||
|
||||
self.causal = causal
|
||||
|
||||
self.attn_drop = torch.nn.Dropout(dropout, inplace=False)
|
||||
|
||||
# Pure blocksparse data
|
||||
self.layout = layout
|
||||
self.block_size = block_size
|
||||
|
||||
# make sure that the head dimension is not folded down with the batch
|
||||
self.requires_head_dimension = True
|
||||
|
||||
# key padding mask and attention mask must be passed in separately
|
||||
self.requires_same_k_q_dimensions = True
|
||||
|
||||
# The underlying triton op does not support per element attention mask
|
||||
self.supports_attention_mask = False
|
||||
self.supports_key_padding_mask = False
|
||||
|
||||
def create_triton_kernels(self, device):
    """Instantiate the three blocksparse operators for ``device``.

    Sets ``self.sparse_dot_sdd``, ``self.sparse_dot_dsd`` and
    ``self.sparse_softmax``, all built from ``self.layout`` and
    ``self.block_size``.
    """
    # Every operator shares the same layout and block size
    layout_and_block = (self.layout, self.block_size)

    # blocksparse operators
    self.sparse_dot_sdd = blocksparse_matmul(
        *layout_and_block,
        "sdd",
        trans_a=False,
        trans_b=True,
        device=device,
    )
    self.sparse_dot_dsd = blocksparse_matmul(
        *layout_and_block,
        "dsd",
        trans_a=False,
        trans_b=False,
        device=device,
    )
    self.sparse_softmax = blocksparse_softmax(*layout_and_block, device=device)
|
||||
|
||||
def forward(
    self,
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    scale: float = 1.0,
    *args,
    **kwargs,
) -> torch.Tensor:
    r"""
    A thin wrap around the Triton blocksparse attention operation

    Args:
        q, k, v: query, key and value tensors. The second to last dimension
            is the sequence length and must equal
            ``layout.shape[-2] * block_size`` for both ``q`` and ``k``.
        scale: forwarded to the sparse softmax. The query is additionally
            divided by ``sqrt(q.size(-1))`` before the QK product.

    Returns:
        The result of the (sparse attention) x (dense value) product.

    .. note: Per element attention mask is not supported, but you can specify causality
    """
    assert (
        "att_mask" not in kwargs.keys() and "att_mask" not in args
    ), "This attention does not support an attention mask, but you can specify causality."

    # Delayed triton init, to make sure that we get the right device
    # Infer device from query
    if not hasattr(self, "sparse_dot_sdd"):
        self.create_triton_kernels(q.device)

    assert (
        q.shape[-2] == k.shape[-2]
    ), "Blocksparse requires the same dimensions for K and Q for now"

    assert (
        q.shape[-2] == self.layout.shape[-2] * self.block_size
    ), "Actual sequence size and layout are inconsistent"
    assert (
        k.shape[-2] == self.layout.shape[-2] * self.block_size
    ), "Actual sequence size and layout are inconsistent"

    assert (
        q.shape[-2] % self.block_size
    ) == 0, "Sequence length {} must be a multiple of block size {}".format(
        q.shape[-2], self.block_size
    )

    # Self-attend: (B, nh, S, hs) x (B, nh, hs, S) -> (B, nh, S, S)
    # When the computations are block sparse, the matrix types change along the way:
    # - (sparse) attention matrix = (dense) Kt * (dense) Q
    q = q / math.sqrt(q.size(-1))
    sparse_att_mat = self.sparse_dot_sdd(q, k)

    # - softmax on the sparse attention matrix
    sparse_att_mat = self.sparse_softmax(
        sparse_att_mat, scale=scale, is_causal=self.causal
    )

    sparse_att_mat = self.attn_drop(sparse_att_mat)

    # - then (dense) attention is (sparse) attention matrix * dense (value)
    a = self.sparse_dot_dsd(sparse_att_mat, v)
    return a
|
||||
341
pkgs/xformers/components/attention/compositional.py
Normal file
341
pkgs/xformers/components/attention/compositional.py
Normal file
@@ -0,0 +1,341 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the BSD license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
|
||||
# Credits: this is heavily inspired by the official implementation, present in
|
||||
# https://github.com/sarthmit/Compositional-Attention
|
||||
# Original author: Sarthak Mittal
|
||||
|
||||
# This is a simplified version, for the sake of clarity, and because some features could be exposed later
|
||||
# via the library directly.
|
||||
# In particular, code paths for TPUs, quantization and gumbel softmax have been removed
|
||||
# We're also following the same dimension ordering as in the rest of the xformers library
|
||||
# which is to say [Batch, Sequence, Embedding] wherever possible
|
||||
|
||||
import math
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch import Tensor, nn
|
||||
|
||||
from xformers.components.attention import (
|
||||
Attention,
|
||||
AttentionConfig,
|
||||
AttentionMask,
|
||||
register_attention,
|
||||
)
|
||||
from xformers.components.attention.core import _softmax
|
||||
from xformers.components.input_projection import InputProjection, InputProjectionConfig
|
||||
|
||||
|
||||
def _either_or(a: Optional[int], b: int) -> int:
    """Return ``a`` unless it is ``None``, in which case fall back to ``b``."""
    if a is None:
        return b
    return a
|
||||
|
||||
|
||||
@dataclass
class CompositionalAttentionConfig(AttentionConfig):
    """Configuration fields for :class:`CompositionalAttention`."""

    # Required dimensions
    dim_model: int
    num_heads: int

    # Optional sizes; CompositionalAttention derives a value when left to None
    dim_attn: Optional[int] = None
    num_rules: Optional[int] = None
    dim_key: Optional[int] = None
    dim_value: Optional[int] = None
    dim_selection: Optional[int] = None

    # NOTE(review): no default here although defaulted fields precede it;
    # presumably this re-declares a field inherited from AttentionConfig - confirm
    dropout: float

    # Behaviour switches
    qk_rule: bool = False
    nonlinear: bool = False
    q_compose: bool = False
    bias: bool = True
    causal: Optional[bool] = False

    # Input projection handling
    in_proj_container: Optional[InputProjection] = None
    use_separate_proj_weight: Optional[bool] = False
|
||||
|
||||
|
||||
@register_attention("compositional", CompositionalAttentionConfig)
class CompositionalAttention(Attention):
    """Compositional Attention, as proposed in
    "Compositional Attention: Disentangling search and retrieval"_, S. Mittal et al.

    A key insight from this proposal is that the attention mechanism can be conceived as two steps:
    a search and a retrieval operation. When queried, the model can search for the most relevant information
    (Softmax(QKt)), then retrieve information given the Value.

    Contrary to the original attention proposal, which does not consider interactions in between heads,
    the compositional attention will consider all possible interactions and softmax over that dimension,
    so that the information retrieved covers the most relevant dimensions. The number of heads and rules to
    use is thus typically smaller than for a comparable traditional Transformer, and asking for the same number of heads
    may not fit in memory.

    Args:
        dim_model: dimension of the incoming latent space
        num_heads: number of heads *for the search operation*
        dim_attn: dimension (embedding) of the attention
        num_rules: number of rules to consider *for the retrieval operation*
        dim_selection: dimension of the scoring/selection space for the retrievals
        dim_key, dim_value: dimensions of K and V, if different from Q
        dropout: attention dropout probability
        qk_rule: QK product will drive the retrieval process
        nonlinear: use a non linear method to score the retrievals
        q_compose: score the retrievals from the projected, per-head query
            instead of the unprojected input query
        in_proj_container: input projection to use; a new one is created when None
        use_separate_proj_weight: when creating the input projection, pass
            dedicated key and value projection settings instead of None
        bias: use bias in the initial projection step
        causal: causal computations (attend to the past only)

    _"Compositional Attention: Disentangling search and retrieval": https://arxiv.org/pdf/2110.09419v1.pdf
    """

    def __init__(
        self,
        dim_model: int,
        num_heads: int,
        dim_attn: Optional[int] = None,
        num_rules: Optional[int] = None,
        dim_selection: Optional[int] = None,
        dim_key: Optional[int] = None,
        dim_value: Optional[int] = None,
        dropout=0.0,
        qk_rule=False,
        nonlinear=False,
        q_compose=False,
        in_proj_container: Optional[InputProjection] = None,
        use_separate_proj_weight: Optional[bool] = False,
        bias=True,
        causal=False,
        *_,
        **__,
    ):
        super().__init__()

        # Define the inherited flags
        self.requires_skip_multi_head = (
            True  # This attention owns the multi-head mechanism
        )

        # Handle defaults / undefined values
        self.dim_model = dim_model
        num_rules = _either_or(num_rules, num_heads)
        dim_selection = _either_or(dim_selection, dim_model // num_heads)

        # All the initial definition plumbing
        dim_attn = _either_or(dim_attn, dim_model)
        dim_key = _either_or(dim_key, dim_model)
        dim_value = _either_or(dim_value, dim_model)

        self.in_proj_container = (
            in_proj_container
            if in_proj_container is not None
            else InputProjection(
                query_proj_params=InputProjectionConfig(dim_model, dim_key, bias=bias),
                key_proj_params=InputProjectionConfig(dim_model, dim_key, bias=bias)
                if use_separate_proj_weight
                else None,
                value_proj_params=InputProjectionConfig(dim_model, dim_value, bias=bias)
                if use_separate_proj_weight
                else None,
            )
        )

        self.num_heads = num_heads
        self.num_rules = num_rules
        self.qk_rule = qk_rule
        self.dim_selection = dim_selection
        self.nonlinear = nonlinear
        self.q_compose = q_compose

        self.dropout_module = nn.Dropout(dropout)
        self.dim_head = dim_model // num_heads
        self.value_dim = dim_attn // num_rules

        assert (
            self.value_dim * num_rules == dim_attn
        ), "dim_attn must be divisible by num_rules"

        self.scaling = self.dim_head**-0.5
        self.scaling_values = self.dim_selection**-0.5

        self.out_proj = nn.Linear(self.num_heads * self.value_dim, dim_model, bias=bias)

        # Module creation order is kept (value_k, value_q, score_network) so
        # that the random initialization sequence is unchanged
        if self.qk_rule:
            self.value_k = nn.Linear(self.value_dim, self.dim_selection, bias=bias)

        # The query side of the retrieval scoring is needed in every mode
        if self.q_compose:
            self.value_q = nn.Linear(self.dim_head, self.dim_selection, bias=bias)
        else:
            self.value_q = nn.Linear(
                dim_model, self.dim_selection * self.num_heads, bias=bias
            )

        if not self.qk_rule:
            # Without the QK rule, a small network scores [retrieval, query] pairs
            if self.nonlinear:
                self.score_network: nn.Module = nn.Sequential(
                    nn.Linear(
                        self.dim_selection + self.value_dim,
                        self.dim_selection,
                        bias=bias,
                    ),
                    nn.ReLU(),
                    nn.Linear(self.dim_selection, 1, bias=bias),
                )
            else:
                self.score_network = nn.Linear(
                    self.dim_selection + self.value_dim, 1, bias=bias
                )

        self.causal = causal

        # Properties specific to this attention mechanism
        self.supports_attention_mask = True
        self.supports_key_padding_mask = False

        self._reset_parameters()

    def _reset_parameters(self):
        """Xavier-initialize the weights owned by this module, zero the output bias."""
        # NOTE: in_proj_container is already initialized

        if self.qk_rule:
            nn.init.xavier_uniform_(self.value_k.weight, gain=1 / math.sqrt(2))
            nn.init.xavier_uniform_(self.value_q.weight, gain=1 / math.sqrt(2))
        else:
            nn.init.xavier_uniform_(self.value_q.weight)
            if self.nonlinear:
                nn.init.xavier_uniform_(self.score_network[0].weight)
                nn.init.xavier_uniform_(self.score_network[2].weight)
            else:
                nn.init.xavier_uniform_(self.score_network.weight)

        nn.init.xavier_uniform_(self.out_proj.weight)
        if self.out_proj.bias is not None:
            nn.init.constant_(self.out_proj.bias, 0.0)

    def forward(
        self,
        q: Tensor,
        k: Tensor,
        v: Tensor,
        att_mask: Optional[Tensor] = None,
        *args,
        **kwargs,
    ) -> Tensor:
        """
        Input shape: Batch x Sequence x Embedding, the embedding size of ``q``
        must be ``dim_model``.

        Args:
            att_mask (Tensor, optional): typically used to
                implement causal attention, where the mask prevents the
                attention from looking forward in time (default: None).
                A boolean tensor is converted with ``AttentionMask.from_bool``,
                any other tensor is used as an additive mask.

        Returns:
            The output of the final projection, last dimension ``dim_model``.
        """

        B, Sq, E = q.shape
        _, Sk, _ = k.shape

        assert E == self.dim_model

        # First define projected query/key/values
        # We keep the projected and original tensors in flight,
        # depending on the options the original values could be reused
        q_unprojected = q
        q, k, v = self.in_proj_container(query=q, key=k, value=v)
        q *= self.scaling

        # Init causal mask if needed, now that we know the context length
        # NOTE(review): the staleness check uses Sk while the mask is built
        # from Sq; left untouched - confirm the intent when Sq != Sk
        if self.causal and (
            self._causal_mask is None or self._causal_mask.shape[0] != Sk
        ):
            self._causal_mask = AttentionMask.make_causal(Sq, Sq, device=q.device)

        # Convenience, create an attention mask if a tensor was passed
        # This sanitizes different mask types being passed, from now on it's additive
        if isinstance(att_mask, torch.Tensor):
            # By default we don't know of the causality, and a check would be expensive
            att_mask_additive: Optional[AttentionMask] = (
                AttentionMask.from_bool(att_mask)
                if att_mask.dtype == torch.bool
                else AttentionMask(att_mask, is_causal=False)
            )
        else:
            att_mask_additive = None

        # Handle the attention and key padding masks
        if self._causal_mask is not None:
            # Optionally add the causal mask
            if att_mask_additive is not None:
                att_mask_additive += self._causal_mask
            else:
                att_mask_additive = self._causal_mask

        # Flatten the heads or the rules
        q = (
            q.view(B, Sq, self.num_heads, self.dim_head)
            .movedim(2, 1)
            .flatten(0, 1)  # [B * num_heads, Sq, dim_head]
        )
        k = (
            k.view(B, Sk, self.num_heads, self.dim_head).movedim(2, 1).flatten(0, 1)
        )  # [B * num_heads, Sk, dim_head]
        v = v.view(B, -1, self.num_rules, self.value_dim).movedim(2, 1).flatten(0, 1)

        # Compute the search: Softmax(QKt)
        attn_weights = torch.bmm(q, k.transpose(1, 2))  # [B * self.num_heads, Sq, Sk]

        if att_mask_additive is not None:
            attn_weights += att_mask_additive.values

        attn_weights = _softmax(attn_weights, causal=self.causal)

        attn_weights = attn_weights.view(B, self.num_heads, Sq, Sk)
        attn_probs = self.dropout_module(attn_weights)

        # Now compute the information retrieval
        # keep all the heads in flight, we'll score the different possibilities
        # - compute all the possible retrievals
        v = v.view(B, 1, self.num_rules, Sk, self.value_dim)
        attn_probs = attn_probs.unsqueeze(2)
        attn = torch.matmul(attn_probs, v).view(
            B, self.num_heads, self.num_rules, Sq, self.value_dim
        )

        attn = attn.movedim(3, 1)  # [B, Sq, H, Rules, Values]

        # - search the most appropriate retrieval among all the values
        if self.q_compose:
            # q is [B * num_heads, Sq, dim_head] at this point. Un-flatten it and
            # bring the heads behind the sequence so that the scores line up with
            # attn, which is [B, Sq, H, Rules, Values]. Transposing the first two
            # dimensions and viewing as [B, Sq, H, ...] would mix batch and
            # sequence entries as soon as B > 1.
            q_per_head = q.reshape(B, self.num_heads, Sq, self.dim_head).movedim(1, 2)
            v_q = self.value_q(q_per_head).reshape(
                B, Sq, self.num_heads, 1, self.dim_selection
            )
        else:
            v_q = self.value_q(q_unprojected).view(
                B, Sq, self.num_heads, 1, self.dim_selection
            )

        if self.qk_rule:
            v_q *= self.scaling_values
            v_k = (
                self.value_k(attn)
                .view(B, Sq, self.num_heads, self.num_rules, self.dim_selection)
                .transpose(4, 3)
                .contiguous()
            )
            v_score = torch.matmul(v_q, v_k).view(
                B, Sq, self.num_heads, self.num_rules, 1
            )
        else:
            v_q = v_q.expand(-1, -1, -1, self.num_rules, -1)
            v_in = torch.cat([attn, v_q], dim=-1)
            v_score = self.score_network(v_in).view(
                B, Sq, self.num_heads, self.num_rules, 1
            )

        # Softmax over the rules dimension
        v_score = F.softmax(v_score, dim=3)

        # - extracted values are the original attention (inc. all the values) weighted by value score
        attn = (attn * v_score).sum(dim=3).view(B, Sq, self.num_heads * self.value_dim)

        # Final attention projection, same as other mechanisms
        attn = self.out_proj(attn)

        return attn
|
||||
342
pkgs/xformers/components/attention/core.py
Normal file
342
pkgs/xformers/components/attention/core.py
Normal file
@@ -0,0 +1,342 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the BSD license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
|
||||
import logging
|
||||
import math
|
||||
from contextlib import nullcontext
|
||||
from functools import lru_cache
|
||||
from typing import Optional, Union
|
||||
|
||||
import torch
|
||||
|
||||
from xformers import _has_cpp_library, _is_triton_available
|
||||
from xformers.components.attention.attention_mask import AttentionMask
|
||||
|
||||
if _has_cpp_library:
|
||||
from ._sputnik_sparse import SparseCS
|
||||
|
||||
if _is_triton_available():
|
||||
from xformers.triton.softmax import softmax as triton_softmax
|
||||
from xformers.triton.utils import gpu_capabilities_older_than_70
|
||||
|
||||
_is_blocksparse_available = (
|
||||
_is_triton_available() and not gpu_capabilities_older_than_70()
|
||||
)
|
||||
|
||||
if _is_blocksparse_available:
|
||||
from xformers.components.attention.blocksparse import BlockSparseAttention
|
||||
|
||||
|
||||
logger = logging.getLogger("xformers")
|
||||
|
||||
|
||||
def _create_random_sparsity(matrix, sparsity, divisible_by=4):
    """Return a copy of ``matrix`` in which randomly chosen positions are zeroed.

    One random pattern is drawn for the last two dimensions and applied to
    every entry of the first dimension. A position is kept when its uniform
    draw is above ``sparsity``; the list of kept positions is then truncated
    so that its length is a multiple of ``divisible_by``.
    """
    assert matrix.ndim == 3

    kept_mask = torch.rand_like(matrix[0], dtype=torch.float32) > sparsity
    kept_coords = torch.nonzero(kept_mask)

    # NOTE: need to make it a multiple of 4 for sputnik
    surplus = kept_coords.shape[0] % divisible_by
    kept_coords = kept_coords[: kept_coords.shape[0] - surplus]
    rows, cols = kept_coords[:, 0], kept_coords[:, 1]

    # Column vector of batch indices, broadcast against the kept coordinates
    batch_idx = torch.arange(matrix.shape[0], device=matrix.device).unsqueeze(1)

    sparsified = torch.zeros_like(matrix)
    sparsified[batch_idx, rows, cols] = matrix[batch_idx, rows, cols]
    return sparsified
|
||||
|
||||
|
||||
def _broadcast_batch(mask, batch_size):
    """Replicate a 2D sparse COO mask ``batch_size`` times along a new first axis.

    A 3D mask is returned untouched. The result is a new sparse COO tensor of
    size ``(batch_size,) + mask.shape``.
    """
    if mask.ndim == 3:
        return mask
    assert mask.ndim == 2

    coalesced = mask.coalesce()
    entries = coalesced.values()
    coords = coalesced.indices()
    nnz = len(entries)

    # strategy: repeat the indices and append the extra batch dimension to the indices
    # batch index of every repeated entry: 0 ... 0, 1 ... 1, etc.
    batch_row = (
        torch.arange(batch_size, device=coords.device)
        .unsqueeze(1)
        .expand(batch_size, nnz)
        .reshape(-1)
    )
    full_coords = torch.cat(
        [batch_row.unsqueeze(0), coords.repeat(1, batch_size)], dim=0
    )

    return torch.sparse_coo_tensor(
        full_coords,
        entries.repeat(batch_size),
        (batch_size,) + coalesced.shape,
    )
|
||||
|
||||
|
||||
def _matmul_with_mask(
    a: torch.Tensor,
    b: torch.Tensor,
    mask: Optional[Union[torch.Tensor, "SparseCS"]],
) -> torch.Tensor:
    """Compute ``a @ b`` restricted or biased by ``mask``.

    Args:
        a, b: operands of the matrix product.
        mask: ``None`` (plain product), a boolean mask where ``False`` means
            "ignore" (the entry becomes ``-inf``), or an additive mask which
            is summed onto the product.

    Returns:
        The masked product. With the C++ extension and a boolean mask, the
        result comes from the dedicated masked matmul.
    """
    if mask is None:
        return a @ b

    # SparseCS is only imported when the C++ extension is present, so every
    # isinstance check has to be guarded to avoid a NameError without it
    mask_is_sparse_cs = _has_cpp_library and isinstance(mask, SparseCS)

    if _has_cpp_library and mask.dtype == torch.bool:
        if mask_is_sparse_cs:
            return mask.matmul_with_mask(a, b)
        if mask.is_sparse:
            # perform broadcasting if needed
            mask = _broadcast_batch(mask, a.shape[0])

            # coalesced is not implemented for bool tensors, so need to cast
            mask = mask.to(dtype=a.dtype)  # type: ignore  # mypy is missing the catch above

        return torch.ops.xformers.matmul_with_mask(a, b, mask)

    # Non optimized codepath
    assert not mask_is_sparse_cs

    att = a @ b
    if mask.dtype == torch.bool:
        if mask.ndim == 2:
            mask = mask.unsqueeze(0).expand(att.shape[0], -1, -1)
        # mask is presumed false == ignore
        att[~mask] = float("-inf")
    else:
        # mask is presumed additive
        # repeat if batch sizes don't match
        if (
            mask.ndim == 3
            and mask.shape[0] != att.shape[0]
            and (att.shape[0] % mask.shape[0]) == 0
        ):
            repeat_factor = att.shape[0] // mask.shape[0]
            mask = mask.repeat([repeat_factor, 1, 1])
            logger.info("Mismatched batch dimensions for mask, repeating mask.")
        att += mask
    return att
|
||||
|
||||
|
||||
def _softmax(a: torch.Tensor, causal: bool = False) -> torch.Tensor:
    """Softmax over the last dimension, dispatched on the input type.

    ``SparseCS`` inputs (C++ extension present) and sparse tensors use their
    own softmax. Other inputs go through the Triton softmax when Triton is
    available, otherwise through ``torch.softmax``. ``causal`` is only
    forwarded to the Triton implementation.
    """
    if _has_cpp_library and isinstance(a, SparseCS):
        return a.softmax()

    if a.is_sparse:
        return torch.sparse.softmax(a, dim=a.ndim - 1)

    if not _is_triton_available():
        return torch.softmax(a, dim=a.ndim - 1)

    return triton_softmax(a, mask=None, causal=causal)
|
||||
|
||||
|
||||
if _has_cpp_library:

    class SparseBMM(torch.autograd.Function):
        """Batched (sparse x dense) matrix product with a hand written backward."""

        @staticmethod
        def forward(ctx, a, b):
            # Coalesce once, and keep the coalesced tensor for the backward pass
            coalesced = a.coalesce()
            product = torch.bmm(coalesced, b)
            ctx.save_for_backward(coalesced, b)
            return product

        @staticmethod
        def backward(ctx, grad):
            sparse_a, dense_b = ctx.saved_tensors

            # gradients w.r.t. a, only computed on the pattern of a
            grad_a = None
            if ctx.needs_input_grad[0]:
                grad_a = torch.ops.xformers.matmul_with_mask(
                    grad, dense_b.transpose(-2, -1), sparse_a
                )

            # gradients w.r.t. b
            grad_b = None
            if ctx.needs_input_grad[1]:
                grad_b = sparse_a.transpose(1, 2).bmm(grad)

            return grad_a, grad_b

    def _sparse_bmm(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
        """
        Batch matrix multiply between a sparse matrix and a dense matrix

        Both operands must be 3D, share the same first dimension and have
        compatible inner dimensions.
        """
        assert a.ndim == b.ndim == 3
        assert a.shape[0] == b.shape[0]
        assert a.shape[2] == b.shape[1]
        return SparseBMM.apply(a, b)
|
||||
|
||||
|
||||
def bmm(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    """Multiply ``a`` by ``b``, using a sparse kernel when the C++ extension allows it.

    Without the extension this is always ``a @ b``. With it, a ``SparseCS``
    uses ``spmm``, a sparse tensor uses ``_sparse_bmm`` and anything else
    falls back to ``a @ b``.
    """
    if not _has_cpp_library:
        return a @ b

    if isinstance(a, SparseCS):
        return a.spmm(b)

    return _sparse_bmm(a, b) if a.is_sparse else a @ b
|
||||
|
||||
|
||||
def _apply_dropout(att, dropout):
    """Apply ``dropout`` to ``att``, which may be dense, sparse COO or ``SparseCS``.

    ``dropout`` may be None, in which case ``att`` is returned as is. For the
    sparse flavours only the stored values go through dropout, and a new
    container with the same structure is returned.
    """
    if dropout is None:
        return att

    if not _has_cpp_library:
        # Non optimized vanilla dropout
        return dropout(att)

    # Dropout chokes on sparse tensors
    if isinstance(att, SparseCS):
        dropped = dropout(att.values.clone())
        return SparseCS.wrap(
            att.shape,
            dropped,
            att.row_indices,
            att.row_offsets,
            att.column_indices,
            att._transp_info,
        )

    if att.is_sparse:
        att = att.coalesce()
        dropped = dropout(att.values().clone())  # protect against in-place dropout
        return torch.sparse_coo_tensor(att.indices(), dropped, att.shape)

    # Simple dense case
    return dropout(att)
|
||||
|
||||
|
||||
def scaled_query_key_softmax(
    q: torch.Tensor,
    k: torch.Tensor,
    att_mask: Optional[Union[AttentionMask, "SparseCS", torch.Tensor]],
) -> torch.Tensor:
    """Return ``softmax(mask(q @ k^T / sqrt(d)))`` with ``d = k.size(-1)``.

    An ``AttentionMask`` contributes its additive ``values``, and its
    ``is_causal`` flag is passed on to the softmax. Any other mask is handed
    to the masked matmul untouched.
    """
    # TODO assume we have (N, S, hs) instead of (B, nh, S, hs), with N = B x nh
    # this is needed due to limitations in sparse_bmm for now

    # Self-attend: (N, S, hs) x (N, hs, S) -> (N, S, S)
    q = q / math.sqrt(k.size(-1))

    # Matmul with mask; an AttentionMask wraps an additive mask
    wrapped = isinstance(att_mask, AttentionMask)
    mask = att_mask.values if wrapped else att_mask
    att = _matmul_with_mask(q, k.transpose(-2, -1), mask)

    # Softmax to get the attention probabilities
    is_causal = wrapped and att_mask.is_causal
    return _softmax(att, causal=is_causal)
|
||||
|
||||
|
||||
if _is_blocksparse_available:
    # 128 is default maxsize
    @lru_cache(maxsize=128)
    def _retrieve_blocksparse(
        num_heads: int, seq_len: int, block_size: int
    ) -> BlockSparseAttention:
        """Build, or fetch from the LRU cache, a causal attention with a full layout."""
        n_blocks = seq_len // block_size
        full_layout = torch.ones((num_heads, n_blocks, n_blocks), dtype=torch.long)
        return BlockSparseAttention(
            layout=full_layout, block_size=block_size, causal=True
        )

    def blocksparse_attention(
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        dropout: Optional[torch.nn.Module] = None,
        block_size: int = 128,
    ) -> torch.Tensor:
        """Run causal attention through the cached blocksparse module.

        3D inputs get a singleton head dimension for the call, and the result
        is flattened back to 3D. The sequence length must be divisible by
        ``block_size``.
        """
        input_rank = q.dim()
        seq_len = q.shape[-2]
        # Layout head dimension: 1 or batch size (q.shape[0])
        layout_heads = 1

        # TODO perhaps add functionality to pad qkv if sequence length is not divisible by block size?
        assert seq_len % block_size == 0, "Sequence length must be divisible by block size"

        if input_rank == 3:
            # Reshape from (N, S, hs) to (B, nh, S, hs) where N = B x nh, hs = D / nh
            # Assuming num_heads = 1, (N, S, hs) to (B, 1, S, hs)
            head_axis = 1 if layout_heads == 1 else 0
            q, k, v = (t.unsqueeze(head_axis) for t in (q, k, v))

        attention_op = _retrieve_blocksparse(layout_heads, seq_len, block_size)

        # Dropout is a no-op in evaluation mode
        # The cached module is shared, so its dropout is reset on every call
        attention_op.attn_drop = (
            dropout if isinstance(dropout, torch.nn.Dropout) else torch.nn.Dropout(0.0)
        )
        att = attention_op(q, k, v)

        # Reshape attention (B, nh, S, hs) back to (N, S, hs)
        return att.flatten(0, 1) if input_rank == 3 else att
|
||||
|
||||
|
||||
def scaled_dot_product_attention(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    att_mask: Optional[Union[AttentionMask, "SparseCS", torch.Tensor]],
    dropout: Optional[torch.nn.Module] = None,
    block_size: int = 128,
) -> torch.Tensor:
    """Masked, scaled dot-product attention: ``dropout(softmax(q k^T / sqrt(d))) v``.

    A causal ``AttentionMask`` may be rerouted to the Triton blocksparse
    kernel (conditions listed in the body). With sparse masks, autocast is
    switched off and the inputs are converted to float32.
    """
    autocast_disabled = (_has_cpp_library and isinstance(att_mask, SparseCS)) or (
        att_mask is not None and att_mask.is_sparse
    )
    seq_len = q.shape[-2]

    # switch if:
    #  causal is required but mask is not sparse
    #  fp16 or under amp context
    #  sequence length is divisible by block size
    #  same seq len for K and Q
    switch_to_blocksparse = (
        _is_blocksparse_available
        and (att_mask is not None and not att_mask.is_sparse)
        and (isinstance(att_mask, AttentionMask) and att_mask.is_causal)
        and (q.dtype == torch.float16 or torch.is_autocast_enabled())
        and not seq_len % block_size
        and q.shape[-2] == k.shape[-2]
    )

    if switch_to_blocksparse:
        logger.info("Switching causal attention to Triton blocksparse...")
        return blocksparse_attention(q, k, v, dropout, block_size)

    if autocast_disabled:
        amp_context = torch.cuda.amp.autocast(enabled=False)
    else:
        amp_context = nullcontext()

    with amp_context:  # type: ignore
        if autocast_disabled:
            q, k, v = q.float(), k.float(), v.float()

        att = scaled_query_key_softmax(q, k, att_mask=att_mask)

        #  Optional dropout, could be part of the masking in the future
        att = _apply_dropout(att, dropout)

        # Get to the predicted values, for all heads
        # y = att @ v  # (N, S, S) x (N, S, hs) -> (N, S, hs)
        y = bmm(att, v)
    return y
|
||||
173
pkgs/xformers/components/attention/favor.py
Normal file
173
pkgs/xformers/components/attention/favor.py
Normal file
@@ -0,0 +1,173 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the BSD license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import logging
|
||||
import math
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.cuda.amp import autocast
|
||||
|
||||
from xformers.components.attention import Attention, AttentionConfig, register_attention
|
||||
from xformers.components.attention.feature_maps import (
|
||||
FeatureMap,
|
||||
FeatureMapType,
|
||||
SMHyperbolic,
|
||||
SMOrf,
|
||||
SMReg,
|
||||
)
|
||||
|
||||
logger = logging.getLogger("xformers")
|
||||
|
||||
|
||||
@dataclass
class FavorAttentionConfig(AttentionConfig):
    """Configuration fields for :class:`FavorAttention`."""

    causal: Optional[bool]

    # The dimensions of the random features
    dim_features: Optional[int] = None

    # The embedding dimension of the inputs. Only useful to get a dim_features estimate
    dim_head: Optional[int] = None

    # The number of iterations before the random features are re-drawn from scratch
    iter_before_redraw: Optional[int] = None

    # NOTE(review): FavorAttention.__init__ names this argument
    # ``feature_map_type``; confirm that this field reaches it
    feature_map: Optional[FeatureMapType] = None
|
||||
|
||||
|
||||
@register_attention("favor", FavorAttentionConfig)
|
||||
class FavorAttention(Attention):
|
||||
def __init__(
|
||||
self,
|
||||
causal: bool = False,
|
||||
dropout: float = 0.0,
|
||||
dim_features: Optional[int] = None,
|
||||
dim_head: Optional[int] = None,
|
||||
iter_before_redraw: Optional[int] = None,
|
||||
feature_map_type: FeatureMapType = FeatureMapType.SMReg,
|
||||
normalize_inputs: bool = False,
|
||||
*_,
|
||||
**__,
|
||||
):
|
||||
r"""
|
||||
Kernelized attention, as proposed in Performers_
|
||||
("Rethinking attention with performers." K. Choromanski et al. (2020).).
|
||||
|
||||
FAVOR stands for "Fast Attention Via positive Orthogonal Random features"
|
||||
|
||||
Args:
|
||||
dropout (float): the probability of an output to be randomly dropped at training time
|
||||
dim_features (int): the dimension of the random features space
|
||||
iter_before_redraw (int): the number of steps (forward calls) before a redraw of the features
|
||||
feature_map_type (FeatureMapType): the type of feature map being used,
|
||||
for instance orthogonal random features.
|
||||
|
||||
.. _Performers: https://arxiv.org/pdf/2009.14794v1.pdf
|
||||
"""
|
||||
super().__init__()
|
||||
|
||||
self.causal = causal
|
||||
self.iter_before_redraw = (
|
||||
(2 * iter_before_redraw)
|
||||
if iter_before_redraw is not None
|
||||
else iter_before_redraw
|
||||
) # This will be used for both key and query
|
||||
self.normalize_inputs = normalize_inputs
|
||||
self.feature_map_type = feature_map_type
|
||||
self.attn_drop = nn.Dropout(dropout, inplace=True)
|
||||
|
||||
# Setup dimension-dependent variables
|
||||
# Reasonable dimension default
|
||||
if dim_features is None:
|
||||
assert dim_head is not None, "dim_features or dim_head needs to be passed"
|
||||
self.dim_features = math.ceil(dim_head * (1 + math.log2(dim_head)))
|
||||
self.dim_features = 2 * (
|
||||
self.dim_features // 2
|
||||
) # needs to be even for some variants
|
||||
logger.info(
|
||||
f"FAVOR: Automatically setting the random mapping dimension to {self.dim_features} from {dim_head}"
|
||||
)
|
||||
else:
|
||||
self.dim_features = dim_features
|
||||
|
||||
feature_map_constructor = {
|
||||
FeatureMapType.SMHyp: SMHyperbolic,
|
||||
FeatureMapType.SMReg: SMReg,
|
||||
FeatureMapType.SMOrf: SMOrf,
|
||||
}[self.feature_map_type]
|
||||
|
||||
feature_settings = {
|
||||
"dim_features": self.dim_features,
|
||||
"iter_before_redraw": self.iter_before_redraw,
|
||||
"normalize_inputs": self.normalize_inputs,
|
||||
}
|
||||
|
||||
self.feature_map: FeatureMap = feature_map_constructor(**feature_settings) # type: ignore
|
||||
|
||||
# Properties specific to this attention mechanism
|
||||
self.supports_attention_mask = False
|
||||
self.supports_key_padding_mask = False
|
||||
|
||||
@staticmethod
|
||||
def _maybe_promote(x: torch.Tensor) -> torch.Tensor:
|
||||
# Only promote fp16 buffers, bfloat16 would be fine for instance
|
||||
return x.float() if x.dtype == torch.float16 else x
|
||||
|
||||
@staticmethod
def _causal_attention(
    k_prime: torch.Tensor, q_prime: torch.Tensor, v: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Causal linear attention through prefix sums (Algorithm 1 in the Performers paper).

    For every position ``i`` only the positions ``j <= i`` contribute:

    - ``att_raw[b, i]  = sum_{j <= i} (q'_i . k'_j) * v_j``
    - ``att_norm[b, i] = sum_{j <= i} (q'_i . k'_j)``

    Args:
        k_prime: keys projected onto the feature map, BATCH x SEQ x FEATURES
        q_prime: queries projected onto the feature map, BATCH x SEQ x FEATURES
        v: values, BATCH x SEQ x EMB

    Returns:
        ``(att_raw, att_norm)``, both BATCH x SEQ x EMB. ``att_norm`` is a
        broadcast (expanded) view, the same normalizer being shared by every
        embedding channel.
    """
    # Running sum, over the sequence (dim 1), of the per-position outer products k'_j (x) v_j.
    # The cumulative sum has to happen *before* the contraction with the query, and along
    # the sequence axis, for position i to aggregate all the positions j <= i.
    prefix_kv = (k_prime.unsqueeze(3) * v.unsqueeze(2)).cumsum(1)  # BATCH x SEQ x FEATURES x EMB

    # The normalizer is the same computation against all-ones values,
    # which reduces to a running sum of the keys alone
    prefix_k = k_prime.cumsum(1)  # BATCH x SEQ x FEATURES

    # Consolidate against the feature dimension
    att_raw = torch.einsum("bcfe,bcf->bce", prefix_kv, q_prime)
    att_norm = (
        torch.einsum("bcf,bcf->bc", prefix_k, q_prime).unsqueeze(-1).expand_as(att_raw)
    )

    return att_raw, att_norm
|
||||
|
||||
def forward(
    self,
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    *_,
    **__,
):
    """Compute the kernelized attention for ``q``, ``k`` and ``v``.

    Keys and queries are first mapped through ``self.feature_map`` (keys first,
    then queries), the attention is then evaluated with autocast disabled and
    normalized, and ``self.attn_drop`` is applied to the result.
    Any extra positional or keyword argument is ignored.
    """
    # Feature-space projections of the keys and the queries
    k_prime = self.feature_map(k)
    q_prime = self.feature_map(q)

    with autocast(enabled=False):
        # The softmax kernel approximation overflows easily in half precision,
        # so fp16 operands are promoted before the matrix products.
        # The dimensions involved are much smaller than for scaled_dot_product.
        k_prime = self._maybe_promote(k_prime)
        q_prime = self._maybe_promote(q_prime)
        v = self._maybe_promote(v)

        if self.causal:
            att_raw, att_normalization = self._causal_attention(k_prime, q_prime, v)
        else:
            k_prime_t = k_prime.transpose(-2, -1)
            att_normalization = q_prime @ (k_prime_t @ torch.ones_like(v))
            att_raw = q_prime @ (k_prime_t @ v)

        # Normalize
        att = att_raw / att_normalization

    if self.attn_drop is not None:
        att = self.attn_drop(att)

    return att
|
||||
26
pkgs/xformers/components/attention/feature_maps/__init__.py
Normal file
26
pkgs/xformers/components/attention/feature_maps/__init__.py
Normal file
@@ -0,0 +1,26 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the BSD license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
|
||||
from enum import Enum
|
||||
|
||||
from .base import FeatureMap, FeatureMapConfig
|
||||
from .softmax import NormDistribution, SMHyperbolic, SMOrf, SMReg
|
||||
|
||||
|
||||
class FeatureMapType(str, Enum):
    """String identifiers of the available softmax-kernel feature maps.

    Being a ``str`` subclass, each member compares equal to its string value.
    """

    SMOrf = "sm_orf"
    SMHyp = "sm_hyp"
    SMReg = "sm_reg"  # regularized softmax kernel
|
||||
|
||||
|
||||
# Public API of the package. FeatureMapType is defined in this very module and
# has to be listed as well, otherwise `from ... import *` silently drops it.
__all__ = [
    "SMOrf",
    "SMReg",
    "SMHyperbolic",
    "NormDistribution",
    "FeatureMapConfig",
    "FeatureMap",
    "FeatureMapType",
]
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
61
pkgs/xformers/components/attention/feature_maps/base.py
Normal file
61
pkgs/xformers/components/attention/feature_maps/base.py
Normal file
@@ -0,0 +1,61 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the BSD license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
|
||||
from abc import abstractmethod
|
||||
from dataclasses import asdict, dataclass
|
||||
from typing import Optional, Type, TypeVar
|
||||
|
||||
import torch
|
||||
|
||||
"""
|
||||
Feature maps allow for a given query or key to be encoded in a different space.
|
||||
"""
|
||||
|
||||
Self = TypeVar("Self", bound="FeatureMap")
|
||||
|
||||
|
||||
@dataclass
class FeatureMapConfig:
    """Settings describing a feature map.

    Fields left to ``None`` are skipped by ``FeatureMap.from_config``,
    so that the constructor defaults apply.
    """

    name: str
    dim_features: int  # dimension of the feature space
    iter_before_redraw: Optional[int]
    normalize_inputs: Optional[bool]
    epsilon: Optional[float]
|
||||
|
||||
|
||||
class FeatureMap(torch.nn.Module):
    """Base class for the feature maps, which encode a query or a key in a different space.

    Subclasses provide ``_get_feature_map`` to generate the projection matrix.
    """

    def __init__(
        self,
        dim_features: int,
        iter_before_redraw: Optional[int] = None,
        normalize_inputs: bool = False,
        epsilon: float = 1e-6,
    ):
        """
        Args:
            dim_features: dimension of the feature space
            iter_before_redraw: number of calls after which subclasses re-draw
                the projection, ``None`` to never re-draw on a call count basis
            normalize_inputs: stored for the subclasses, see their ``pre_scale``
            epsilon: small constant stored for the subclasses
        """
        super().__init__()

        self.dim_features = dim_features
        # Dimension actually requested from `_get_feature_map`; subclasses may override it
        self.dim_feature_map = dim_features

        self.iter_before_redraw = iter_before_redraw
        # The projection matrix is generated lazily by the subclasses
        self.features: Optional[torch.Tensor] = None
        self.epsilon = epsilon
        self.normalize_inputs = normalize_inputs

        self._iter_counter = 0

    @abstractmethod
    def _get_feature_map(self, dim_input: int, dim_features: int, device: torch.device):
        """Generate the projection matrix, to be implemented by the subclasses."""
        raise NotImplementedError()

    @classmethod
    def from_config(cls: "Type[Self]", config: "FeatureMapConfig") -> "Self":
        """Instantiate the feature map from a config dataclass.

        ``None`` fields are skipped so that the constructor defaults are used.
        The ``name`` field is a lookup key rather than a constructor argument:
        it is dropped unless the constructor explicitly accepts it (by name or
        through ``**kwargs``), which avoids an unconditional ``TypeError``.
        """
        import inspect

        # Generate the class inputs from the config, skipping all Nones
        fields = {k: v for k, v in asdict(config).items() if v is not None}

        params = inspect.signature(cls.__init__).parameters
        takes_kwargs = any(
            p.kind is inspect.Parameter.VAR_KEYWORD for p in params.values()
        )
        if "name" not in params and not takes_kwargs:
            fields.pop("name", None)

        return cls(**fields)
|
||||
288
pkgs/xformers/components/attention/feature_maps/softmax.py
Normal file
288
pkgs/xformers/components/attention/feature_maps/softmax.py
Normal file
@@ -0,0 +1,288 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the BSD license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
|
||||
import math
|
||||
from enum import Enum, auto
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from torch.autograd.profiler import record_function
|
||||
|
||||
from .base import FeatureMap
|
||||
|
||||
"""
|
||||
A set of feature maps which approximate the softmax kernel, as per the Performers_ paper.
|
||||
|
||||
_Performers: "Rethinking attention with performers." K. Choromanski et al. (2020).
|
||||
https://arxiv.org/pdf/2009.14794v1.pdf
|
||||
"""
|
||||
|
||||
|
||||
class NormDistribution(Enum):
    """How the row norms of the random projection matrix are distributed."""

    Xi = auto()  # rows rescaled by the norms of gaussian vectors
    Uniform = auto()  # rows kept orthonormal
|
||||
|
||||
|
||||
class SoftMaxPositiveEstimators(FeatureMap):
    """Common base of the positive random features estimators of the softmax kernel.

    Takes care of drawing (and periodically re-drawing) the random projection,
    of scaling the inputs, and of the log-space offset which the subclasses add
    within their exponential.
    """

    def __init__(
        self,
        dim_features: int,
        iter_before_redraw: Optional[int],
        normalize_inputs: bool = False,
        epsilon: float = 1e-6,
        softmax_temp: float = -1,
    ):
        super().__init__(dim_features, iter_before_redraw, normalize_inputs, epsilon)

        # A negative value is a placeholder, replaced on the first `pre_scale` call
        self.softmax_temp = softmax_temp

        # log(√m): subtracted from the exponent in `pre_scale`,
        # which normalizes all the feature maps involved by √m
        self.h_scale = math.log(math.sqrt(self.dim_features))

    def pre_scale(self, x: torch.Tensor) -> torch.Tensor:
        """Scale ``x`` and refresh the state shared by all the kernels.

        Side effects: may (re)draw ``self.features``, casts it to the dtype of ``x``,
        bumps the call counter and stores the log-space offset in ``self.offset``.

        Returns:
            ``x`` multiplied by the softmax temperature.
        """
        with record_function("feature_map::pre_scale"):
            # Re-draw counting logic
            redraw_due = (
                self.iter_before_redraw is not None
                and self._iter_counter > self.iter_before_redraw
            )
            if redraw_due or self.features is None or self.features.device != x.device:
                # The feature map is actually using half the dimension,
                # we'll concatenate + and - features
                self._iter_counter = 1
                self.features = self._get_feature_map(
                    x.shape[-1], self.dim_feature_map, x.device
                )

            current = self.features
            assert current is not None

            if current.dtype != x.dtype:
                self.features = current.to(x.dtype)

            self._iter_counter += 1

            # Normalization / softmax
            if self.softmax_temp < 0:
                # A = exp(QK.t/√d), so each input will be scaled by √√d
                self.softmax_temp = x.shape[-1] ** -0.25

            x_scaled = x * self.softmax_temp

            # The scaling factors are computed in logspace and applied from within
            # the exponential, which:
            # - diminishes the possible exponential overflow
            # - replaces a multiply across the batch by an addition
            norm_x_2 = torch.einsum("...d,...d->...", x_scaled, x_scaled).unsqueeze(-1)
            self.offset = -0.5 * norm_x_2 - self.h_scale + self.epsilon

            if self.normalize_inputs:
                # Normalize the exponential term by its maximum over dim 1,
                # can be useful for numerical stability
                self.offset -= norm_x_2.max(1, keepdim=True)[0]

        # The rest of the computation depends on the kernel being used
        return x_scaled

    @staticmethod
    @torch.no_grad()
    def _get_random_ortho_matrix(
        blocks: int,
        dim: int,
        device: torch.device,
        norm_distribution: NormDistribution = NormDistribution.Uniform,
    ) -> torch.Tensor:
        r"""
        Generate ``blocks`` random ``dim x dim`` matrices whose rows are exactly orthonormal,
        optionally rescaled row-wise (``NormDistribution.Xi``).

        "How to generate random matrices from the classical compact groups", Mezzadri, 2007
        https://arxiv.org/pdf/math-ph/0609050v2.pdf

        .. note: the typical qr decomposition does not give uniform results, qr decomposition is not
            unique and the qr decomposition routines are biased towards numerical stability. See the above
            paper for more information.

        .. note: this does not follow the original implementation from the Performers authors.
            see docs/assets/kde plots to visualize the impact of using the R signs to correct Q
        """
        use_xi = norm_distribution == NormDistribution.Xi

        H = torch.randn((blocks, dim, dim), device=device, requires_grad=False)

        if use_xi:
            # Row norms of the gaussian draw, used to rescale the result
            # NOTE: This averages to sqrt(d)
            norms = torch.sqrt(torch.einsum("...d,...d->...", H, H))

        Q, R = torch.linalg.qr(H)
        # Correct Q with the signs of the diagonal of R
        r_signs = torch.sign(torch.diagonal(R, dim1=1, dim2=2))
        Q = torch.diag_embed(r_signs) @ Q

        # With the Uniform distribution there is nothing left to do, Q is already orthonormal
        if use_xi:
            return torch.diag_embed(norms) @ Q

        return Q
|
||||
|
||||
|
||||
class SMOrf(SoftMaxPositiveEstimators):
    """
    "Positive random orthogonal features" softmax estimator,
    SM_ort^m+, as proposed in the Performers_ paper, Lemma 1.

    _Performers: "Rethinking attention with performers." K. Choromanski et al. (2020).
    https://arxiv.org/pdf/2009.14794v1.pdf
    """

    @torch.no_grad()
    def _get_feature_map(self, dim_input: int, dim_features: int, device: torch.device):
        """
        Generate the projection matrix onto the random features

        .. note: The heads dimension needs to be taken into account, hence the per-block random matrix
            and not uniformly random.
        """
        # Enough random blocks are requested to cover the whole input dimension,
        # whatever the requested dimension of the features
        n_blocks = math.ceil(dim_input / dim_features)
        blocks = self._get_random_ortho_matrix(
            n_blocks,
            dim_features,
            norm_distribution=NormDistribution.Xi,
            device=device,
        )

        # Stack the blocks row-wise, and keep one row per input dimension
        stacked = blocks.flatten(0, 1)
        return stacked[:dim_input]

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map ``x`` to its positive random features."""
        # Softmax-dimension related scaling, shared for all kernels
        x_scaled = super().pre_scale(x)
        assert self.features is not None

        # Project onto the random feature map, the offset is applied within the exponential
        projected = x_scaled @ self.features
        return torch.exp(projected + self.offset)
|
||||
|
||||
|
||||
class SMHyperbolic(SoftMaxPositiveEstimators):
    """
    "Positive random features hyperbolic" estimator, SMHyp+,
    as proposed in the Performers_ paper, Lemma 1.

    _Performers: "Rethinking attention with performers." K. Choromanski et al. (2020).
    https://arxiv.org/pdf/2009.14794v1.pdf
    """

    def __init__(
        self,
        dim_features: int,
        iter_before_redraw: Optional[int],
        normalize_inputs: bool = False,
        epsilon: float = 1e-6,
        softmax_temp: float = -1,
    ):
        super().__init__(
            dim_features, iter_before_redraw, normalize_inputs, epsilon, softmax_temp
        )

        assert (
            dim_features % 2 == 0
        ), "The feature dimension needs to be even with this kernel"

        # Half of the features come from the + projection, the other half from the - one
        self.dim_feature_map = self.dim_features // 2

    @torch.no_grad()
    def _get_feature_map(self, dim_input: int, dim_features: int, device: torch.device):
        """
        Generate the projection matrix onto the random features

        .. note: The heads dimension needs to be taken into account, hence the per-block random matrix
            and not uniformly random.
        """
        # Enough random blocks are requested to cover the whole input dimension,
        # whatever the requested dimension of the features
        n_blocks = math.ceil(dim_input / dim_features)
        blocks = self._get_random_ortho_matrix(
            n_blocks,
            dim_features,
            norm_distribution=NormDistribution.Xi,
            device=device,
        )

        stacked = blocks.flatten(0, 1)
        return stacked[:dim_input]

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map ``x`` to the concatenation of its + and - positive random features."""
        # Softmax-dimension related scaling, shared for all kernels
        x_scaled = super().pre_scale(x)

        # Project onto the random feature map, then concatenate both + and - results.
        # This follows Lemma 1 in the original Performers Paper to best approximate a
        # softmax kernel (cosh representation)
        projected = x_scaled @ self.features
        positive = torch.exp(projected + self.offset)
        negative = torch.exp(-projected + self.offset)
        return torch.cat([positive, negative], dim=-1)
|
||||
|
||||
|
||||
class SMReg(SoftMaxPositiveEstimators):
    """
    "Regularized softmax kernel" estimator, SMREG+, as proposed in the Performers_ paper.

    _Performers: "Rethinking attention with performers." K. Choromanski et al. (2020).
    https://arxiv.org/pdf/2009.14794v1.pdf
    """

    def __init__(
        self,
        dim_features: int,
        iter_before_redraw: Optional[int],
        normalize_inputs: bool = False,
        epsilon: float = 1e-6,
        softmax_temp: float = -1,
    ):
        super().__init__(
            dim_features, iter_before_redraw, normalize_inputs, epsilon, softmax_temp
        )

        assert (
            dim_features % 2 == 0
        ), "The feature dimension needs to be even with this kernel"

        # Half of the features come from the + projection, the other half from the - one
        self.dim_feature_map = self.dim_features // 2

    @torch.no_grad()
    def _get_feature_map(self, dim_input: int, dim_features: int, device: torch.device):
        """
        Generate the projection matrix onto the random features

        .. note: The heads dimension needs to be taken into account, hence the per-block random matrix
            and not uniformly random.
        """

        # Get per block random unitary matrices.
        # We need enough of them to project the whole input dimension, regardless of the
        # requested dimension of the features
        features = self._get_random_ortho_matrix(
            math.ceil(dim_input / dim_features),
            dim_features,
            norm_distribution=NormDistribution.Uniform,
            device=device,
        ).flatten(0, 1)

        # Every row gets the same sqrt(dim_input) norm. A plain scalar multiply is
        # equivalent to left-multiplying by a constant diagonal matrix, without
        # materializing that (rows x rows) matrix nor paying for the matmul.
        return features[:dim_input] * math.sqrt(dim_input)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map ``x`` to the concatenation of its + and - positive random features."""
        # Softmax-dimension related scaling, shared for all kernels
        x_scaled = super().pre_scale(x)

        # Project onto the random feature map, concatenate both + and - results
        # This follows Lemma 1 in the original Performers Paper to best approximate a
        # softmax kernel (cosh representation + sample regularization)
        x_scaled = x_scaled @ self.features
        return torch.cat(
            [torch.exp(x_scaled + self.offset), torch.exp(-x_scaled + self.offset)],
            dim=-1,
        )
|
||||
35
pkgs/xformers/components/attention/fourier_mix.py
Normal file
35
pkgs/xformers/components/attention/fourier_mix.py
Normal file
@@ -0,0 +1,35 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the BSD license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import torch
|
||||
from torch.cuda.amp import autocast
|
||||
|
||||
from xformers.components.attention import Attention, AttentionConfig, register_attention
|
||||
|
||||
|
||||
@register_attention("fourier_mix", AttentionConfig)
class FourierMix(Attention):
    def __init__(self, dropout: float, *_, **__):
        """
        FFT-based pseudo-attention mechanism, from
        "FNet: Mixing Tokens with Fourier Transforms"
        Lee-Thorp et al., 2021, https://arxiv.org/pdf/2105.03824.pdf

        Args:
            dropout (float): dropout probability, applied to the mixed tokens
        """
        super().__init__()
        self.attn_drop = torch.nn.Dropout(dropout, inplace=False)

        # Properties specific to this attention mechanism
        self.supports_attention_mask = False
        self.requires_input_projection = False

    def forward(self, q: torch.Tensor, *_, **__):
        """Return the real part of the 2D FFT of ``q``, after dropout. Only ``q`` is used."""
        # Autocast is disabled around the FFT, as fp16 is not supported by torch.fft.fft2
        # NOTE(review): `q` itself is not promoted here, an input which is already
        # fp16 reaches fft2 as is. To be confirmed against the callers.
        with autocast(enabled=False):
            mixed = torch.fft.fft2(q).real

        return self.attn_drop(mixed)
|
||||
122
pkgs/xformers/components/attention/global_tokens.py
Normal file
122
pkgs/xformers/components/attention/global_tokens.py
Normal file
@@ -0,0 +1,122 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the BSD license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from xformers.components.attention import (
|
||||
Attention,
|
||||
AttentionConfig,
|
||||
AttentionMask,
|
||||
maybe_sparsify,
|
||||
register_attention,
|
||||
sparsify,
|
||||
)
|
||||
from xformers.components.attention.attention_patterns import (
|
||||
causal_1d_pattern,
|
||||
global_token_pattern,
|
||||
)
|
||||
from xformers.components.attention.core import scaled_dot_product_attention
|
||||
|
||||
|
||||
@dataclass
class GlobalAttentionConfig(AttentionConfig):
    """Settings of the ``global`` attention, on top of the generic attention ones."""

    # Mark the queries which have global attention
    attention_query_mask: torch.Tensor
    causal: Optional[bool]
    force_sparsity: Optional[bool]
|
||||
|
||||
|
||||
@register_attention("global", GlobalAttentionConfig)
class GlobalAttention(Attention):
    def __init__(
        self,
        dropout: float,
        attention_query_mask: torch.Tensor,
        causal: bool = False,
        force_sparsity: bool = False,
        *_,
        **__,
    ):
        r"""
        Global attention, as proposed for instance in BigBird_ or Longformer_.

        Global means in that case that the queries positively labelled in the ```attention_query_mask``` can attend
        to all the other queries. The queries negatively labelled in the ```attention_query_mask``` cannot attend to
        any other query.

        This implementation is sparse-aware, meaning that the empty attention parts will not be represented in memory.

        Args:
            dropout (float): probability of an element to be zeroed
            attention_query_mask (torch.Tensor): boolean N x 1 mask, if true this query can attend to all the others
            causal (bool): additionally combine the global pattern with a causal (N x N) pattern
            force_sparsity (bool): convert the mask with ``sparsify`` instead of ``maybe_sparsify``

        """
        super().__init__()

        assert attention_query_mask.dtype == torch.bool, "A boolean mask is expected"
        assert (
            attention_query_mask.shape[1] == 1
            and attention_query_mask.shape[0] > attention_query_mask.shape[1]
        ), "A N x 1 query mask is expected"

        self.attn_drop = nn.Dropout(dropout, inplace=False)
        self.attention_mask = global_token_pattern(attention_query_mask[:, 0])
        self.force_sparsity = force_sparsity

        if causal:
            # The causal pattern has to match the sequence length N, which is shape[0].
            # shape[1] is asserted to be 1 above: it would produce a 1 x 1 pattern which
            # broadcasts to a no-op.
            self.attention_mask &= causal_1d_pattern(attention_query_mask.shape[0])

        self.attention_mask = (
            sparsify(self.attention_mask)
            if self.force_sparsity
            else maybe_sparsify(self.attention_mask)
        )

        # Properties specific to this attention mechanism
        self.requires_same_k_q_dimensions = True
        self.supports_attention_mask = False
        self.supports_key_padding_mask = False

    def forward(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        att_mask: Optional[Union[torch.Tensor, AttentionMask]] = None,
        *_,
        **__,
    ):
        """Scaled dot product attention restricted to the global tokens pattern.

        Args:
            q, k, v: queries, keys and values
            att_mask: optional extra mask, combined with the global tokens mask

        Returns:
            The attention output, cropped back to the sequence length of ``q``.
        """
        # Make sure that the mask is on the right device
        if self.attention_mask.device != q.device:
            self.attention_mask = self.attention_mask.to(q.device)

        # Mask-aware attention
        if att_mask is not None:
            if att_mask.dtype == torch.bool and isinstance(
                self.attention_mask, AttentionMask
            ):
                if not isinstance(att_mask, AttentionMask):
                    att_mask = AttentionMask.from_bool(att_mask)
                mask = self.attention_mask + att_mask
            else:
                mask = self.attention_mask & att_mask
        else:
            mask = self.attention_mask

        # Handle q/k/v which would not fit the mask
        seq_len = q.shape[-2]
        q_, k_, v_ = map(lambda x: self._maybe_pad_sequence(x, mask), (q, k, v))

        # Normal attention with the global tokens mask
        att = scaled_dot_product_attention(
            q=q_, k=k_, v=v_, att_mask=mask, dropout=self.attn_drop
        )

        # Take into account an hypothetical padding
        return att[:, :seq_len, :]
|
||||
78
pkgs/xformers/components/attention/lambda_layer.py
Normal file
78
pkgs/xformers/components/attention/lambda_layer.py
Normal file
@@ -0,0 +1,78 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the BSD license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
import torch
|
||||
|
||||
from xformers.components.attention import Attention, AttentionConfig, register_attention
|
||||
|
||||
|
||||
def calc_rel_pos(n: int):
    """Return the ``[n, n]`` table of shifted relative positions.

    Entry ``[i, j]`` is ``j - i + n - 1``, so that all the values lie in ``[0, 2n-2]``.
    """
    # Adapted from LucidRains
    # https://github.com/lucidrains/lambda-networks/blob/main/lambda_networks/lambda_networks.py
    positions = torch.arange(n)
    offsets = positions.unsqueeze(0) - positions.unsqueeze(1)  # [n, n], in [-n+1, n-1]
    return offsets + (n - 1)
|
||||
|
||||
|
||||
@dataclass
class LambdaLayerConfig(AttentionConfig):
    """Settings of the ``lambda`` attention, on top of the generic attention ones."""

    # dimension of the input sequence
    seq_len: int
    dim_head: int
|
||||
|
||||
|
||||
@register_attention("lambda", LambdaLayerConfig)
class LambdaLayer(Attention):
    def __init__(self, dropout: float, seq_len: int, dim_head: int, *_, **__):
        """
        Attention approximation using Lambda layers, from
        "Lambda networks: modeling long-range interactions without attention.", Bello, I. (2021).

        Args:
            dropout (float): dropout probability, applied to the output
            seq_len (int): maximum sequence length, sizes the relative position embeddings
            dim_head (int): dimension of the relative position embeddings
        """
        super().__init__()

        # Possible extensions:
        # - support different dimensions for key and queries
        # - support varying dimensions in between inputs and outputs
        # - support u hyperparam

        self.rel_pos_emb = torch.nn.Parameter(
            torch.randn(2 * seq_len - 1, int(dim_head))
        )
        self.rel_pos = calc_rel_pos(seq_len)
        self.attn_drop = torch.nn.Dropout(dropout, inplace=True)

        # Properties specific to this attention mechanism
        self.requires_same_k_q_dimensions = True
        self.supports_attention_mask = False
        self.supports_key_padding_mask = False

    def forward(
        self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, *args, **kwargs
    ):
        """..NOTE: We're reusing the einsum notation suggested by the paper, changed in that
        heads are folded in the batch dimension"""

        content_lambda = torch.einsum("bnk,bnv->bkv", torch.softmax(k, dim=-1), v)
        content_output = torch.einsum("bnk,bkv->bnv", q, content_lambda)

        # Handle real sequence length being possibly smaller: the index table is cropped
        # before the lookup, so that only the needed embeddings are gathered
        seq_len = q.shape[1]
        rel_pos_emb = self.rel_pos_emb[self.rel_pos[:seq_len, :seq_len]]

        # Compute the position lambda for every possible combination in one go, then compute the
        # position related contribution
        position_lambdas = torch.einsum(
            "mnk,bnv->bnkv", rel_pos_emb, v
        )  # one lambda per position

        # (b, n, 1, k) @ (b, n, k, v) -> (b, n, 1, v). Only the helper dimension is removed:
        # a bare squeeze() would also drop a batch or sequence dimension of size 1 and
        # make the sum below broadcast to the wrong shape
        position_output = (q.unsqueeze(2) @ position_lambdas).squeeze(2)
        att = content_output + position_output

        att = self.attn_drop(att)

        return att
|
||||
74
pkgs/xformers/components/attention/linformer.py
Normal file
74
pkgs/xformers/components/attention/linformer.py
Normal file
@@ -0,0 +1,74 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the BSD license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from xformers.components.attention import Attention, AttentionConfig, register_attention
|
||||
from xformers.components.attention.core import scaled_dot_product_attention
|
||||
|
||||
|
||||
@dataclass
class LinformerSelfAttentionConfig(AttentionConfig):
    """Settings of the ``linformer`` attention, on top of the generic attention ones."""

    # dimension of the input sequence
    seq_len: int
    # dimension of the internal space
    k: Optional[int]
|
||||
|
||||
|
||||
@register_attention("linformer", LinformerSelfAttentionConfig)
class LinformerAttention(Attention):
    def __init__(
        self, dropout: float, seq_len: int, k: Optional[int] = None, *args, **kwargs
    ):
        """
        Linformer attention mechanism,
        from `Linformer: Self-Attention with Linear Complexity`_, Wang et al (2020).
        The original notation is kept as is.

        Args:
            dropout (float): dropout probability
            seq_len (int): sequence length expected by the projections, shorter inputs are padded
            k (int, optional): dimension of the projected sequence, defaults to ``seq_len // 4``
                (at least 1)

        .. _`Linformer: Self-Attention with Linear Complexity` : https://arxiv.org/abs/2006.04768v2
        """
        super().__init__()

        if k is None:
            # Never fall back to an empty projection, which would happen for seq_len < 4
            k = max(1, seq_len // 4)

        self.k = k
        self.E = nn.Linear(seq_len, k, bias=False)
        self.F = nn.Linear(seq_len, k, bias=False)
        self.attn_drop = nn.Dropout(dropout, inplace=False)
        self.seq_len = seq_len

        # MHA related flags:
        # kq need to have the same dimension
        self.requires_same_k_q_dimensions = True

        # This attention does not support attention masks
        self.supports_attention_mask = False

    def forward(
        self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, *args, **kwargs
    ):
        """Attention over keys and values projected along the sequence dimension.

        Inputs shorter than ``seq_len`` are zero-padded along dim 1, the padding
        being removed from the output.
        """
        # Handle a smaller dimension than expected
        padding = 0
        if q.shape[1] < self.seq_len:
            padding = self.seq_len - q.shape[1]
            pad_dims = (0, 0, 0, padding)
            q = torch.nn.functional.pad(q, pad_dims)
            k = torch.nn.functional.pad(k, pad_dims)
            v = torch.nn.functional.pad(v, pad_dims)

        # The linear layers act on the last dimension, hence the transpositions
        # to project along the sequence
        k_projected = self.E(k.transpose(-2, -1)).transpose(-2, -1)
        v_projected = self.F(v.transpose(-2, -1)).transpose(-2, -1)

        y = scaled_dot_product_attention(
            q=q, k=k_projected, v=v_projected, att_mask=None, dropout=self.attn_drop
        )

        # NOTE(review): the same dropout module is also handed to
        # scaled_dot_product_attention above, so dropout is applied twice here.
        # Kept as is, to be confirmed that this is intended.
        y = self.attn_drop(y)

        return y[:, :-padding, :] if padding > 0 else y
|
||||
120
pkgs/xformers/components/attention/local.py
Normal file
120
pkgs/xformers/components/attention/local.py
Normal file
@@ -0,0 +1,120 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the BSD license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from xformers.components.attention import (
|
||||
Attention,
|
||||
AttentionConfig,
|
||||
AttentionMask,
|
||||
maybe_sparsify,
|
||||
register_attention,
|
||||
sparsify,
|
||||
)
|
||||
from xformers.components.attention.attention_patterns import (
|
||||
causal_1d_pattern,
|
||||
local_1d_pattern,
|
||||
)
|
||||
from xformers.components.attention.core import scaled_dot_product_attention
|
||||
|
||||
|
||||
@dataclass
class LocalAttentionConfig(AttentionConfig):
    """Settings of the ``local`` attention, on top of the generic attention ones.

    All the fields are optional and default to ``None``.
    """

    causal: Optional[bool] = None
    window_size: Optional[int] = None
    force_sparsity: Optional[bool] = None
|
||||
|
||||
|
||||
@register_attention("local", LocalAttentionConfig)
class LocalAttention(Attention):
    def __init__(
        self,
        dropout: float = 0.0,
        causal: bool = False,
        window_size: int = 5,
        force_sparsity: bool = False,
        *args,
        **kwargs,
    ):

        r"""
        An implementation of a sliding window attention, as proposed in RoutingTransformer_, LongFormer_ or BigBird_


        Args:
            dropout (float): the probability of an output to be randomly dropped at training time
            causal (bool): apply a causal mask, in that the attention cannot be applied to the future
            window_size (int): the overall window size for local attention.
                Odd number is expected if the mask is not causal, as the window size will be evenly
                distributed on both sides of each query
            force_sparsity (bool): convert the mask with ``sparsify`` instead of ``maybe_sparsify``


        .. _RoutingTransformer: https://arxiv.org/pdf/2003.05997.pdf

        .. _BigBird: https://arxiv.org/pdf/2007.14062.pdf

        .. _Longformer: https://arxiv.org/pdf/2004.05150.pdf

        """
        super().__init__()

        self.attn_drop = nn.Dropout(dropout, inplace=False)
        self.causal = causal
        self.force_sparsity = force_sparsity

        if not self.causal:
            assert (
                window_size % 2 == 1
            ), "The window size is assumed to be odd (counts self-attention + 2 wings)"

        self.window_size = window_size
        # Built lazily on the first forward call, once the sequence length is known
        self.attention_mask: Optional[torch.Tensor] = None
        self.requires_same_k_q_dimensions = True

        # Properties specific to this attention mechanism
        self.supports_attention_mask = True
        self.supports_key_padding_mask = False

    def _get_local_mask(self, shape: torch.Size) -> torch.Tensor:
        """Build the local window mask for a sequence of length ``shape[1]``."""
        # In the causal case the window is doubled (+1) before the causal pattern is applied
        window_size = self.window_size * 2 + 1 if self.causal else self.window_size
        mask = local_1d_pattern(shape[1], window_size)

        if self.causal:
            mask &= causal_1d_pattern(shape[1])

        mask = sparsify(mask) if self.force_sparsity else maybe_sparsify(mask)

        return mask

    def forward(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        att_mask: Optional[Union[torch.Tensor, AttentionMask]] = None,
        *args,
        **kwargs,
    ):
        """Scaled dot product attention restricted to a local window.

        Args:
            q, k, v: queries, keys and values
            att_mask: optional extra mask, combined with the local window mask
        """
        # Local window attention masking, the mask being cached across calls
        # and rebuilt when the sequence length changes
        if self.attention_mask is None or self.attention_mask.shape[1] != q.shape[1]:
            self.attention_mask = self._get_local_mask(q.shape).to(q.device)

        # Make sure that the cached mask follows the inputs if they moved to another device
        if self.attention_mask.device != q.device:
            self.attention_mask = self.attention_mask.to(q.device)

        # Take into account the optional user mask
        if att_mask is None:
            mask = self.attention_mask
        else:
            if isinstance(att_mask, AttentionMask):
                # Needed because & op not defined for SparseCS with AttentionMask
                att_mask = att_mask.to_bool()
            mask = self.attention_mask & att_mask

        return scaled_dot_product_attention(
            q=q, k=k, v=v, att_mask=mask, dropout=self.attn_drop
        )
|
||||
295
pkgs/xformers/components/attention/nystrom.py
Normal file
295
pkgs/xformers/components/attention/nystrom.py
Normal file
@@ -0,0 +1,295 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the BSD license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from xformers.components.attention import Attention, AttentionConfig, register_attention
|
||||
from xformers.components.attention.core import (
|
||||
scaled_dot_product_attention,
|
||||
scaled_query_key_softmax,
|
||||
)
|
||||
from xformers.components.attention.utils import (
|
||||
bool_mask_to_additive,
|
||||
iterative_pinv,
|
||||
reshape_key_padding_mask,
|
||||
)
|
||||
|
||||
logger = logging.getLogger("xformers")
|
||||
|
||||
|
||||
@dataclass
class NystromSelfAttentionConfig(AttentionConfig):
    """
    Configuration fields for :class:`NystromAttention`.

    num_heads               Number of heads.
    num_landmarks           Number of landmarks to use for softmax approximation. 64 often sufficient for a good
                            approximation according to https://arxiv.org/pdf/2102.03902.pdf.
    causal                  Apply a causal mask, in that the attention cannot be applied to the future.
    use_razavi_pinverse     If true, use the iterative method from (Razavi et al. 2014) to approximate the
                            Moore-Penrose inverse, otherwise use ``torch.linalg.pinv``.
    pinverse_original_init  True if using original initialization when calculating Moore-Penrose pseudo inverse using
                            method from (Razavi et al. 2014).
                            False if using exact coefficient computation (leads to faster convergence).
    inv_iterations          Number of iterations for calculating the Moore-Penrose pseudo inverse.
    v_skip_connection       A module that will take V as input and will be added as a skip connection to the
                            softmax approximation. A skip connection is added in the paper to help with training.
    conv_kernel_size        Kernel size for convolution optionally added to help in training.
                            If v_skip_connection is not specified, this will be used to define the default
                            depth wise convolution used as a skip connection.
                            If both conv_kernel_size and v_skip_connection are None, no skip connection will
                            be added.
    landmark_pooling        Which module to use when computing landmarks. When left unset, NystromAttention
                            falls back to the segment-averaging ``AvgPool`` module defined in this file.
    """

    num_heads: int
    num_landmarks: Optional[int]
    landmark_pooling: Optional[nn.Module]
    causal: Optional[bool]
    pinverse_original_init: Optional[bool]
    inv_iterations: Optional[int]
    v_skip_connection: Optional[nn.Module]
    conv_kernel_size: Optional[int]
    use_razavi_pinverse: Optional[bool]
|
||||
|
||||
|
||||
class AvgPool(nn.Module):
    """
    Reduce the sequence dimension of a 3D tensor to ``n`` rows by averaging
    contiguous segments.

    When the sequence length is not a multiple of ``n``, the leading segments
    hold ``seq_len // n`` rows and the trailing ``seq_len % n`` segments hold
    one extra row each, so that every input row is used exactly once.
    """

    def __init__(self, n: int):
        super().__init__()
        self.n = n

    def forward(self, x: torch.Tensor):
        """Average ``x`` (indexed as batch, sequence, feature) down to ``n`` rows."""
        seq_len, head_dim = x.shape[1], x.shape[2]
        segments, remainder = divmod(seq_len, self.n)
        assert segments > 0, "num_landmarks should be smaller than the sequence length"

        # The sequence splits evenly: a single reshape is enough
        if remainder == 0:
            grouped = x.reshape(-1, self.n, segments, head_dim)
            return grouped.mean(dim=-2)

        # Uneven split: the first `short_count` segments are `segments` long,
        # the remaining `remainder` segments are one row longer
        short_count = self.n - remainder
        boundary = short_count * segments

        short_part = x[:, :boundary, :].reshape(-1, short_count, segments, head_dim)
        long_part = x[:, boundary:, :].reshape(-1, remainder, segments + 1, head_dim)

        pooled = (short_part.mean(dim=-2), long_part.mean(dim=-2))
        return torch.cat(pooled, dim=-2)
|
||||
|
||||
|
||||
@register_attention("nystrom", NystromSelfAttentionConfig)
class NystromAttention(Attention):
    # TODO: update defaults for use_razavi_pinverse and inv_iterations
    def __init__(
        self,
        dropout: float,
        num_heads: int,
        num_landmarks: int = 64,
        landmark_pooling: Optional[nn.Module] = None,
        causal: bool = False,
        use_razavi_pinverse: bool = True,
        pinverse_original_init: bool = False,
        inv_iterations: int = 6,  # recommended default in paper was 6.
        v_skip_connection: Optional[nn.Module] = None,
        conv_kernel_size: Optional[int] = None,
        *args,
        **kwargs,
    ):
        """
        Nystrom attention mechanism, from Nystromformer_.
        ::

            "A Nystrom-based Algorithm for Approximating Self-Attention."
            Xiong, Y., Zeng, Z., Chakraborty, R., Tan, M., Fung, G., Li, Y., Singh, V. (2021)

            Reference codebase: https://github.com/mlpen/Nystromformer

        .. _Nystromformer: https://arxiv.org/pdf/2102.03902.pdf

        When ``v_skip_connection`` is not given but ``conv_kernel_size`` is, a
        depth-wise ``nn.Conv2d`` (one group per head) is created as the skip
        connection. When ``landmark_pooling`` is not given, ``AvgPool`` with
        ``num_landmarks`` segments is used.
        """
        super().__init__()

        # merged key padding mask and attention mask is not accepted
        self.requires_separate_masks = True

        # Hyper-parameters
        self.num_landmarks = num_landmarks
        # TODO: should be able to not have to pass in num_heads
        self.num_heads = num_heads
        self.use_razavi_pinverse = use_razavi_pinverse
        self.pinverse_original_init = pinverse_original_init
        self.inv_iterations = inv_iterations
        self.causal = causal

        self.attn_drop = nn.Dropout(dropout)

        # Skip connection on V: user-provided module, default depth-wise conv, or nothing
        skip_module = v_skip_connection
        if skip_module is None and conv_kernel_size is not None:
            skip_module = nn.Conv2d(
                in_channels=num_heads,
                out_channels=num_heads,
                kernel_size=(conv_kernel_size, 1),
                padding=(conv_kernel_size // 2, 0),
                bias=False,
                groups=num_heads,
            )
        self.skip_connection = skip_module

        if landmark_pooling is None:
            landmark_pooling = AvgPool(n=num_landmarks)
        self.landmark_pooling = landmark_pooling

        # Optional lower triangular masks for causal attention
        self.causal_mask_1: Optional[torch.Tensor] = None
        self.causal_mask_2: Optional[torch.Tensor] = None
        self.causal_mask_3: Optional[torch.Tensor] = None

        # This attention does not support attention masks
        self.supports_attention_mask = False
        self.supports_key_padding_mask = True

    def forward(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        key_padding_mask: Optional[torch.Tensor] = None,
        *args,
        **kwargs,
    ):
        r"""
        key_padding_mask    Only a key padding mask is accepted here. The size must be (batch size, sequence length) or
                            (batch size * num_heads, 1, sequence length); any other shape trips an assertion.
                            An additive mask is expected, meaning float values using "-inf" to mask values.
                            A boolean mask is converted (with a warning).
        """
        batched_dim = k.size(0)
        seq_len = k.size(-2)
        tensor_opts = {"dtype": q.dtype, "device": q.device}

        if key_padding_mask is not None:
            if key_padding_mask.dtype == torch.bool:
                logger.warning(
                    "Bool mask found, but an additive mask is expected. Converting but this is slow"
                )
                key_padding_mask = bool_mask_to_additive(key_padding_mask)

            if key_padding_mask.ndim == 2:
                key_padding_mask = reshape_key_padding_mask(
                    key_padding_mask, batched_dim
                )

            # keep_mask takes 1 if the token is not padded, otherwise 0.
            padded = torch.isinf(-key_padding_mask)
            keep_mask = torch.where(
                padded,
                torch.zeros_like(key_padding_mask),
                torch.ones_like(key_padding_mask),
            ).transpose(2, 1)
            assert keep_mask.shape == (batched_dim, q.shape[1], 1)

            # Mask q and k before pooling
            # https://github.com/mlpen/Nystromformer/blob/main/code/attention_nystrom.py#L31
            q = q * keep_mask
            k = k * keep_mask

            assert key_padding_mask.size() == (batched_dim, 1, seq_len), (
                f"key_padding_mask has invalid dimensions {key_padding_mask.size()}."
                f" Must have dimensions {batched_dim, 1, seq_len} or (batch_size, {seq_len})."
            )

        if self.num_landmarks >= seq_len:
            # Not enough tokens to pool: plain attention, optionally causal
            mask: Optional[torch.Tensor] = None
            if self.causal:
                mask = self._triu_mask(batched_dim, seq_len, seq_len, **tensor_opts)

            if key_padding_mask is not None:
                mask = key_padding_mask if mask is None else mask + key_padding_mask

            x = scaled_dot_product_attention(q=q, k=k, v=v, att_mask=mask)
        else:
            q_landmarks = self.landmark_pooling(q)
            k_landmarks = self.landmark_pooling(k)

            if self.causal:
                expected_size = (batched_dim, seq_len, self.num_landmarks)
                if (
                    self.causal_mask_1 is None
                    or expected_size != self.causal_mask_1.size()
                ):
                    # (Re)build the cached causal masks for the current sizes
                    self.causal_mask_1 = self._triu_mask(
                        batched_dim, seq_len, self.num_landmarks, **tensor_opts
                    )
                    self.causal_mask_2 = self._triu_mask(
                        batched_dim,
                        self.num_landmarks,
                        self.num_landmarks,
                        **tensor_opts,
                    )
                    self.causal_mask_3 = self._triu_mask(
                        batched_dim, self.num_landmarks, seq_len, **tensor_opts
                    )

            mask_3: Optional[torch.Tensor] = self.causal_mask_3
            if key_padding_mask is not None:
                if mask_3 is None:
                    mask_3 = key_padding_mask
                else:
                    mask_3 = mask_3 + key_padding_mask

            # The three Nystrom kernels; only the last one receives a mask
            kernel_1 = scaled_query_key_softmax(q=q, k=k_landmarks, att_mask=None)
            kernel_2 = scaled_query_key_softmax(
                q=q_landmarks, k=k_landmarks, att_mask=None
            )
            kernel_3 = scaled_dot_product_attention(
                q=q_landmarks, k=k, v=v, att_mask=mask_3
            )

            if self.use_razavi_pinverse:
                kernel_2_inv = iterative_pinv(
                    kernel_2, self.inv_iterations, self.pinverse_original_init
                )
            else:
                kernel_2_inv = torch.linalg.pinv(kernel_2)

            left = torch.matmul(kernel_1, kernel_2_inv)
            x = torch.matmul(left, kernel_3)

        if self.skip_connection:
            # Assumption here is that v is 3D.
            v_per_head = v.reshape(-1, self.num_heads, v.size(-2), v.size(-1))
            v_conv = self.skip_connection(v_per_head)
            x += v_conv.reshape(-1, v_conv.size(-2), v_conv.size(-1))

        return self.attn_drop(x)

    def _triu_mask(self, dim_1: int, dim_2: int, dim_3: int, **kwargs) -> torch.Tensor:
        """
        Additive mask of size (dim_1, dim_2, dim_3): ``-inf`` strictly above
        the diagonal of each (dim_2, dim_3) slice, zero elsewhere.

        ``kwargs`` must provide ``device`` and ``dtype``.
        """
        device = kwargs["device"]
        dtype = kwargs["dtype"]

        neg_inf = torch.ones(dim_2, dim_3, dtype=dtype, device=device) * float("-inf")
        upper = torch.triu(neg_inf, diagonal=1)

        # micro optim, save memory on the batch dimension
        return upper.expand(dim_1, -1, -1)
|
||||
324
pkgs/xformers/components/attention/ortho.py
Normal file
324
pkgs/xformers/components/attention/ortho.py
Normal file
@@ -0,0 +1,324 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the BSD license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
from typing import Optional, Union
|
||||
|
||||
import torch
|
||||
import torch.autograd.profiler as profiler
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as Fn
|
||||
|
||||
from xformers.components.attention import (
|
||||
Attention,
|
||||
AttentionConfig,
|
||||
AttentionMask,
|
||||
register_attention,
|
||||
)
|
||||
from xformers.components.attention.core import (
|
||||
scaled_dot_product_attention,
|
||||
scaled_query_key_softmax,
|
||||
)
|
||||
|
||||
logger = logging.getLogger("xformers")
|
||||
|
||||
|
||||
class LandmarkSelection(str, Enum):
    """
    Landmark selection strategies understood by ``OrthoFormerAttention.forward``.

    Members are also ``str`` instances, so they compare equal to their
    plain-string value.
    """

    # Landmarks picked by `_compute_orthogonal_landmarks`
    Orthogonal = "orthogonal"
    # Landmarks are k-means centroids (`_cluster_landmarks`)
    KMeans = "kmeans"
    # Same, on L2-normalised samples (`_cluster_landmarks(..., spherical=True)`)
    KMeans_Spherical = "kmeans_spherical"
    # Landmarks drawn at random from the queries and the keys
    Random = "random"
|
||||
|
||||
|
||||
@dataclass
class OrthoformerAttentionConfig(AttentionConfig):
    """
    Configuration fields for :class:`OrthoFormerAttention`.

    num_landmarks           Number of landmarks to use for softmax approximation.
    subsample_fraction      Fraction (in [0, 1]) of the queries which is sampled before selecting
                            landmarks. Values below 1.0 enable the sub-sampling.
    landmark_selection      Landmark selection strategy, see :class:`LandmarkSelection`.
    """

    num_landmarks: Optional[int]
    subsample_fraction: Optional[float]
    landmark_selection: Optional[LandmarkSelection]
|
||||
|
||||
|
||||
@register_attention("orthoformer", OrthoformerAttentionConfig)
class OrthoFormerAttention(Attention):
    def __init__(
        self,
        dropout: float,
        num_landmarks: int = 32,
        subsample_fraction: float = 1.0,
        landmark_selection: LandmarkSelection = LandmarkSelection.Orthogonal,
        *args,
        **kwargs,
    ):
        """
        Orthoformer_ attention mechanism.
        ::

            "Keeping Your Eye on the Ball: Trajectory Attention in Video Transformers"
            Patrick, M., Campbell, D., Asano, Y., Misra, I., Metze, F., Feichtenhofer,
            C., Vedaldi, A., Henriques, J. (2021)

            Reference codebase: https://github.com/facebookresearch/Motionformer

        .. _Orthoformer: https://arxiv.org/abs/2106.05392

        Args:
            dropout: dropout probability, applied to the attention output.
            num_landmarks: number of landmarks used for the approximation.
            subsample_fraction: fraction of the queries considered when
                selecting landmarks; sub-sampling is enabled below 1.0.
            landmark_selection: strategy used to pick the landmarks.
        """
        super().__init__()

        self.num_landmarks = num_landmarks
        self.attn_drop = nn.Dropout(dropout)
        self.subsample_fraction = subsample_fraction
        self.landmark_selection = landmark_selection

        # Properties specific to this attention mechanism
        self.supports_attention_mask = True
        self.supports_key_padding_mask = False

    def forward(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        att_mask: Optional[Union[AttentionMask, torch.Tensor]] = None,
        *args,
        **kwargs,
    ):
        """
        Landmark-based approximation of softmax attention.

        When the number of keys equals ``num_landmarks`` plain scaled
        dot-product attention is used (and ``att_mask`` is honoured).
        Otherwise landmarks are selected without tracking gradients, any
        ``att_mask`` is dropped with a warning, and the output is
        ``softmax(q, landmarks) @ (softmax(landmarks, k) @ v)``.

        Raises:
            ValueError: if ``landmark_selection`` is not a known strategy.
        """
        N = k.shape[1]

        if self.num_landmarks == N:
            # Default attention
            x = scaled_dot_product_attention(q, k, v, att_mask)
        else:
            with torch.no_grad(), profiler.record_function("select landmarks"):
                if self.landmark_selection == LandmarkSelection.Orthogonal:
                    landmarks = self._compute_orthogonal_landmarks(q)
                elif self.landmark_selection == LandmarkSelection.Random:
                    # Half of the landmarks come from the queries, half from the keys
                    half_L = self.num_landmarks // 2
                    landmarks_q = q[:, torch.randint(q.size(1), (half_L,)), :]
                    landmarks_k = k[:, torch.randint(k.size(1), (half_L,)), :]
                    landmarks = torch.cat((landmarks_q, landmarks_k), dim=-2)
                elif self.landmark_selection == LandmarkSelection.KMeans:
                    landmarks = self._cluster_landmarks(q)
                elif self.landmark_selection == LandmarkSelection.KMeans_Spherical:
                    landmarks = self._cluster_landmarks(q, spherical=True)
                else:
                    # Fail early with a clear message instead of an
                    # UnboundLocalError on `landmarks` further down
                    raise ValueError(
                        "Unsupported landmark selection strategy: "
                        f"{self.landmark_selection!r}"
                    )

            if att_mask is not None:
                logger.warning(
                    "Orthoformer: attention mask passed alongside with using landmarks to reduce dimensions. "
                    "The two are typically not compatible"
                )
                # FIXME: Should we still accept a mask in that case ?
                att_mask = None

            kernel_1 = scaled_query_key_softmax(q, landmarks, att_mask)
            kernel_2 = scaled_query_key_softmax(landmarks, k, att_mask)
            x = torch.matmul(kernel_1, torch.matmul(kernel_2, v))
        x = self.attn_drop(x)
        return x

    def _cluster_landmarks(
        self,
        q: torch.Tensor,
        spherical: bool = False,
        num_iters: int = 6,
    ) -> torch.Tensor:
        """
        Construct the set of landmarks by clustering the (optionally
        sub-sampled) queries with k-means and returning the centroids.

        With ``spherical=True`` the samples are L2-normalised and clustered
        with the cosine-similarity variant. The number of clusters is capped
        by the sequence length. Returns centroids with shape (B, M, D).
        """

        num_landmarks = min(self.num_landmarks, q.shape[1])

        if self.subsample_fraction < 1.0:
            num_samples = max(
                int(self.subsample_fraction * q.size(-2)), num_landmarks
            )  # Need at least M samples, one per cluster
            q_samples = q[:, torch.randint(q.size(-2), (num_samples,)), :]  # (B, N, D)
        else:
            q_samples = q  # (B, N, D)

        if spherical:
            q_samples_normalized = Fn.normalize(
                q_samples, p=2, dim=-1
            )  # may need to change default eps to eps=1e-8 for mixed precision compatibility
            landmarks = self._kmeans_spherical(
                q_samples_normalized, num_landmarks, num_iters
            )
        else:
            landmarks = self._kmeans(q_samples, num_landmarks, num_iters)
        return landmarks  # (B, M, D)

    def _kmeans(self, x: torch.Tensor, K: int, num_iters: int = 10):
        """
        Batched k-means with squared Euclidean distances.

        Arguments:
            x: (B, N, D)
            K: number of clusters
            num_iters: the number of kmeans updates

        Returns the (B, K, D) centroids, initialised from K distinct random
        rows of ``x``.
        """

        B, N, D = x.size()
        assert K <= N, f"{K} > {N}"

        c = x[
            :, torch.randperm(N, device=x.device)[:K], :
        ].clone()  # initialisation for the centroids

        with profiler.record_function("kmeans"):
            x_i = x.view(B, N, 1, D)
            # c_j is a view on c: the in-place centroid updates below are
            # visible through it on the next iteration
            c_j = c.view(B, 1, K, D)
            counts = c.new_zeros(B, K)
            ones = x.new_ones((B, N))

            for _ in range(num_iters):
                # E step: assign points to the nearest cluster
                D_ij = ((x_i - c_j) ** 2).sum(-1)  # (B, N, K) squared distances
                cl = D_ij.argmin(
                    dim=-1, keepdim=True
                ).long()  # (B, N, 1) index of point to nearest cluster

                # M step: update the centroids
                c.zero_()
                c.scatter_add_(-2, cl.repeat(1, 1, D), x)  # sum of points per cluster
                counts.fill_(1e-6)  # avoid div0
                counts.scatter_add_(
                    -1, cl.squeeze(-1), ones
                )  # number of points per cluster
                c.divide_(counts.unsqueeze(-1))  # compute the average

        return c

    def _kmeans_spherical(self, x: torch.Tensor, K: int, num_iters=10):
        """
        Batched k-means where points are assigned to the centroid with the
        largest dot product, and centroids are re-normalised after each update.

        Arguments:
            x: (B, N, D), the caller in this class passes L2-normalised rows
            K: number of clusters
            num_iters: the number of kmeans updates
        """
        B, N, D = x.size()
        assert K <= N, f"{K} > {N}"

        # initialisation for the centroids
        c = x[:, torch.randperm(N, device=x.device)[:K], :].clone()

        with profiler.record_function("kmeans_spherical"):
            counts = c.new_zeros(B, K)
            ones = x.new_ones((B, N))

            for _ in range(num_iters):
                # E step: assign points to the nearest cluster
                D_ij = torch.matmul(
                    x, c.transpose(-2, -1)
                )  # (B, N, K) cosine similarity
                cl = D_ij.argmax(
                    dim=-1, keepdim=True
                ).long()  # (B, N, 1) index of point to nearest cluster

                # M step: update the centroids
                c.zero_()
                c.scatter_add_(-2, cl.repeat(1, 1, D), x)  # sum of points per cluster
                counts.fill_(1e-6)  # avoid div0
                counts.scatter_add_(
                    -1, cl.squeeze(-1), ones
                )  # number of points per cluster
                c.divide_(counts.unsqueeze(-1))  # compute the average
                c = Fn.normalize(c, p=2, dim=-1)  # renormalise
        return c

    def _compute_orthogonal_landmarks(self, q: torch.Tensor) -> torch.Tensor:
        """
        Construct set of landmarks by recursively selecting new landmarks
        that are maximally orthogonal to the existing set.

        Selection works on L2-normalised samples, but the returned rows are
        gathered from the un-normalised samples through the selection mask.
        Returns near orthogonal landmarks with shape (B, M, D).
        """

        if self.subsample_fraction < 1.0:
            # Need at least M samples of queries
            num_samples = max(
                int(self.subsample_fraction * q.size(-2)), self.num_landmarks
            )
            q_samples = q[
                :, torch.randint(q.size(-2), (num_samples,), device=q.device), :
            ]
        else:
            # (B, N, D)
            q_samples = q

        # may need to change default eps to eps=1e-8 for mixed precision compatibility
        q_samples_normalized = Fn.normalize(q_samples, p=2, dim=-1)
        B, N, D = q_samples_normalized.shape

        # selected_mask is 1 for every sample already picked as a landmark
        selected_mask = torch.zeros((B, N, 1), device=q_samples_normalized.device)
        landmark_mask = torch.ones(
            (B, 1, 1), dtype=selected_mask.dtype, device=q_samples_normalized.device
        )

        # Get initial random landmark
        random_idx = torch.randint(
            q_samples_normalized.size(-2), (B, 1, 1), device=q_samples_normalized.device
        )
        selected_mask.scatter_(-2, random_idx, landmark_mask)

        # Selected landmarks
        selected_landmarks = torch.empty(
            (B, self.num_landmarks, D),
            device=q_samples_normalized.device,
            dtype=q_samples_normalized.dtype,
        )
        selected_landmarks[:, 0, :] = q_samples_normalized[
            torch.arange(q_samples_normalized.size(0)), random_idx.view(-1), :
        ].view(B, D)

        # Store computed cosine similarities
        cos_sims = torch.empty(
            (B, N, self.num_landmarks),
            device=q_samples_normalized.device,
            dtype=q_samples_normalized.dtype,
        )

        for M in range(1, self.num_landmarks):
            with profiler.record_function("find new landmark"):
                # Calculate absolute cosine similarity between selected and unselected landmarks
                # (B, N, D) * (B, D) -> (B, N)
                cos_sims[:, :, M - 1] = torch.einsum(
                    "b n d, b d -> b n",
                    q_samples_normalized,
                    selected_landmarks[:, M - 1, :],
                ).abs()

                # (B, N, M) cosine similarities of current set of landmarks wrt all samples
                # (a slice of cos_sims, so the write below lands in cos_sims)
                cos_sim_set = cos_sims[:, :, :M]

                # Get orthogonal landmark: landmark with smallest absolute cosine similarity:
                # set cosine similarity for already selected landmarks to > 1
                cos_sim_set.view(-1, M)[selected_mask.flatten().bool(), :] = 10

                # (B,) - for each sample take its largest similarity to the
                # current landmarks, then pick the sample where that is smallest
                selected_landmark_idx = cos_sim_set.amax(-1).argmin(-1)

                # Add most orthogonal landmark to selected landmarks:
                selected_landmarks[:, M, :] = q_samples_normalized[
                    torch.arange(q_samples_normalized.size(0)), selected_landmark_idx, :
                ].view(B, D)

                # Removed selected indices from non-selected mask:
                selected_mask.scatter_(
                    -2, selected_landmark_idx.unsqueeze(-1).unsqueeze(-1), landmark_mask
                )

        # (B, M, D)
        landmarks = torch.masked_select(q_samples, selected_mask.bool()).reshape(
            B, -1, D
        )
        return landmarks  # (B, M, D)
|
||||
82
pkgs/xformers/components/attention/pooling.py
Normal file
82
pkgs/xformers/components/attention/pooling.py
Normal file
@@ -0,0 +1,82 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the BSD license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
|
||||
import math
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from xformers.components.attention import Attention, AttentionConfig, register_attention
|
||||
|
||||
|
||||
@dataclass
class PoolingAttentionConfig(AttentionConfig):
    """Configuration fields for the :class:`Pooling` token mixer."""

    pool_size: int  # kernel size of the 2D average pooling window
    stride: Optional[int]  # stride of the 2D average pooling
    padding: Optional[int]  # optional `padding` argument of the Pooling constructor
|
||||
|
||||
|
||||
@register_attention("pooling", PoolingAttentionConfig)
class Pooling(Attention):
    def __init__(
        self,
        pool_size: int = 3,
        stride: int = 1,
        padding: Optional[int] = None,
        *_,
        **__,
    ):
        """
        Pooling token mixing mechanism, as proposed in
        `Metaformer is actually what you need for vision`_, Yu et al (2021).

        The original notation is kept as is.

        Args:
            pool_size: kernel size of the 2D average pooling.
            stride: stride of the 2D average pooling.
            padding: implicit padding of the pooling; defaults to
                ``pool_size // 2``. ``forward`` subtracts the input from the
                pooled tensor, so the chosen values must keep the spatial
                size unchanged.

        .. _`Metaformer is actually what you need for vision` : https://arxiv.org/pdf/2111.11418v1.pdf
        """
        super().__init__()

        # Honour an explicit padding, only fall back when none is given
        if padding is None:
            padding = pool_size // 2

        self.pool = nn.AvgPool2d(
            pool_size,
            stride=stride,
            padding=padding,
            count_include_pad=False,
        )

        # MHA related flags:
        # This operator does not really handle q,k,v
        self.requires_same_k_q_dimensions = True

        # This attention does not support attention masks
        self.supports_attention_mask = False

        # This "attention" (token mixing) skips the multihead attention altogether
        self.requires_skip_multi_head = True
        self.requires_input_projection = False

        # This attention requires the 2d structure out of the context,
        # implicitly assumed to be a squared length
        self.requires_squared_context = True

    def forward(self, q: torch.Tensor, *_, **__):
        """
        Mix tokens by 2D average pooling.

        ``q`` is unpacked as (B, HW, C) where HW must be a perfect square.
        Returns a (B, HW, C) tensor holding ``pool(q) - q``.
        """
        # Expose the 2D token structure
        B, HW, C = q.shape
        H = int(math.sqrt(HW))
        assert H * H == HW, f"Sequence length {HW} is not a perfect square"

        q = q.transpose(-2, -1).reshape(B, C, H, H)

        # 2D pool
        x_pool = self.pool(q) - q  # compensate for the residual path

        # Get back to B HW C
        return x_pool.flatten(2, 3).transpose(-2, -1)
|
||||
126
pkgs/xformers/components/attention/random.py
Normal file
126
pkgs/xformers/components/attention/random.py
Normal file
@@ -0,0 +1,126 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the BSD license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from xformers.components.attention import (
|
||||
Attention,
|
||||
AttentionConfig,
|
||||
AttentionMask,
|
||||
maybe_sparsify,
|
||||
register_attention,
|
||||
sparsify,
|
||||
)
|
||||
from xformers.components.attention.attention_patterns import (
|
||||
causal_1d_pattern,
|
||||
random_pattern,
|
||||
)
|
||||
from xformers.components.attention.core import scaled_dot_product_attention
|
||||
|
||||
|
||||
@dataclass
class RandomAttentionConfig(AttentionConfig):
    """Configuration fields for :class:`RandomAttention`."""

    r: Optional[
        float
    ]  # the ratio of keys that the query can attend to. 1.0 means dense attention
    constant_masking: Optional[
        bool
    ]  # if true the random mask is drawn once (first forward call) and reused, else redrawn on every call
    force_sparsity: Optional[bool]  # use sparsity in any case (potentially slower)
|
||||
|
||||
|
||||
@register_attention("random", RandomAttentionConfig)
class RandomAttention(Attention):
    def __init__(
        self,
        dropout: float,
        causal: bool = False,
        r: float = 0.01,
        constant_masking: bool = True,
        force_sparsity: bool = False,
        *args,
        **kwargs,
    ):
        """
        "Random" attention, as proposed for instance in BigBird_.
        Random means in that case that each query can attend to a random set of keys.
        This implementation is sparse-aware, meaning that the empty attention parts will not be represented in memory.

        Args:
            dropout (float): dropout probability handed to the attention kernel
            causal (bool): additionally intersect the random pattern with a causal one
            r (float): the ratio in [0,1] of keys that the query can attend to
            constant_masking (bool): if true, the random pattern is drawn once and reused,
                otherwise a new one is drawn on every forward call.
            force_sparsity (bool): always sparsify the pattern

        .. _BigBird: https://arxiv.org/pdf/2007.14062.pdf

        """
        super().__init__()

        self.attn_drop = nn.Dropout(dropout, inplace=False)

        # Pattern settings
        self.r = r
        self.causal = causal
        self.constant_masking = constant_masking
        self.force_sparsity = force_sparsity

        # Built lazily by `forward`
        self.rand_attention_mask: Optional[torch.Tensor] = None

        # Properties specific to this attention mechanism
        self.supports_attention_mask = True
        self.supports_key_padding_mask = False

        self.requires_same_k_q_dimensions = True

    def _get_rand_mask(self, shape: torch.Size) -> torch.Tensor:
        """Draw a random pattern of side ``shape[1]`` with sparsity ``1 - r``."""
        seq_len = shape[1]
        pattern = random_pattern(seq_len, sparsity=1 - self.r)

        if self.causal:
            pattern &= causal_1d_pattern(seq_len)

        if self.force_sparsity:
            return sparsify(pattern)
        return maybe_sparsify(pattern)

    def forward(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        att_mask: Optional[Union[torch.Tensor, AttentionMask]] = None,
        *args,
        **kwargs,
    ):
        """
        Scaled dot-product attention restricted to the random pattern,
        optionally combined with a user-provided mask.
        """
        # Rand masking
        if self.rand_attention_mask is None or not self.constant_masking:
            self.rand_attention_mask = self._get_rand_mask(q.shape).to(q.device)

        # Mask-aware attention
        if att_mask is None:
            mask = self.rand_attention_mask
        elif att_mask.dtype == torch.bool and isinstance(
            self.rand_attention_mask, AttentionMask
        ):
            mask = self.rand_attention_mask + AttentionMask.from_bool(att_mask)
        else:
            if isinstance(att_mask, AttentionMask):
                # Needed because & op not defined for SparseCS with AttentionMask
                att_mask = att_mask.to_bool()
            mask = self.rand_attention_mask & att_mask

        # Handle q/k/v which would not fit the mask
        seq_len = q.shape[-2]
        q_, k_, v_ = (self._maybe_pad_sequence(t, mask) for t in (q, k, v))

        # Normal attention with the random mask
        att = scaled_dot_product_attention(
            q=q_, k=k_, v=v_, att_mask=mask, dropout=self.attn_drop
        )

        # Take into account an hypothetical padding
        return att[:, :seq_len, :]
|
||||
134
pkgs/xformers/components/attention/scaled_dot_product.py
Normal file
134
pkgs/xformers/components/attention/scaled_dot_product.py
Normal file
@@ -0,0 +1,134 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the BSD license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional, Union
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from xformers.components.attention import (
|
||||
Attention,
|
||||
AttentionConfig,
|
||||
AttentionMask,
|
||||
register_attention,
|
||||
)
|
||||
from xformers.components.attention.core import scaled_dot_product_attention
|
||||
|
||||
logger = logging.getLogger("xformers")
|
||||
|
||||
|
||||
@dataclass
class ScaledDotProductConfig(AttentionConfig):
    """Configuration fields for :class:`ScaledDotProduct`."""

    # Whether a causal mask is applied
    causal: Optional[bool]
    # When set together with `causal`, the causal mask is built once in the
    # constructor through AttentionMask.make_causal(seq_len, to_seq_len)
    seq_len: Optional[int]
    # Second argument of that make_causal call
    # (presumably the key/target sequence length - TODO confirm)
    to_seq_len: Optional[int]
|
||||
|
||||
|
||||
@register_attention("scaled_dot_product", ScaledDotProductConfig)
class ScaledDotProduct(Attention):
    r"""
    Scaled Dot-Product attention, as proposed in
    `Attention is all you need`_, Vaswani et al.

    .. _`Attention is all you need`: https://arxiv.org/abs/1706.03762v5
    """

    mask: Optional[AttentionMask]

    def __init__(
        self,
        dropout: float = 0.0,
        causal: bool = False,
        seq_len: Optional[int] = None,
        to_seq_len: Optional[int] = None,
        *args,
        **kwargs,
    ):
        r"""
        Args:
            dropout: probability handed to the ``nn.Dropout`` stored as ``attn_drop``.
            causal: whether a causal mask is combined with the user-provided one.
            seq_len: when given together with ``causal``, the causal mask is built here;
                otherwise its construction is deferred to ``forward``.
            to_seq_len: second length forwarded to ``AttentionMask.make_causal``.
            *args, **kwargs: accepted and ignored.
        """
        super().__init__()

        self.attn_drop = nn.Dropout(dropout, inplace=False)
        self.causal = causal
        self.seq_len = seq_len

        # The causal mask can only be pre-computed when the sequence length is known upfront
        self.mask = (
            AttentionMask.make_causal(seq_len, to_seq_len)
            if causal and seq_len is not None
            else None
        )

        # Properties specific to this attention mechanism
        self.supports_attention_mask = True
        self.supports_key_padding_mask = False

    def forward(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        att_mask: Optional[Union[AttentionMask, torch.Tensor]] = None,
        *args,
        **kwargs,
    ) -> torch.Tensor:
        r"""
        att_mask    A 2D or 3D mask which ignores attention at certain positions.

                    - If the mask is boolean, a value of True will keep the value,
                        while a value of False will mask the value.

                        Key padding masks (dimension: batch x sequence length) and attention masks
                        (dimension: sequence length x sequence length OR batch x sequence length x sequence length)
                        can be combined and passed in here. Method maybe_merge_masks provided in the utils can be
                        used for that merging.

                    - If the mask has the float type, then an additive mask is expected (masked values are -inf)

        Raises NotImplementedError when the mask is larger than the sequence and is not an
        ``AttentionMask`` (so it cannot be cropped here).
        """

        # Convenience: wrap a raw tensor into an AttentionMask.
        # The causality of a user tensor is unknown, and checking it would be expensive.
        if isinstance(att_mask, torch.Tensor):
            if att_mask.dtype == torch.bool:
                att_mask = AttentionMask.from_bool(att_mask)
            else:
                att_mask = AttentionMask(att_mask, is_causal=False)

        # Build the causal mask now if its construction was deferred at init time
        causal_mask = self.mask
        if causal_mask is None and self.causal:
            query_len = q.shape[-2]
            causal_mask = AttentionMask.make_causal(
                seq_len=query_len,
                to_seq_len=query_len,
                device=q.device,
                dtype=q.dtype,
            )

        # Merge the optional causal mask and the user-provided mask
        if causal_mask is not None:
            causal_mask = causal_mask.to(dtype=q.dtype, device=q.device)
            att_mask = causal_mask if att_mask is None else att_mask + causal_mask

        # Try to handle a case where the sequence is smaller than the mask
        mask_too_large = (
            att_mask is not None
            and q.shape[-2] == k.shape[-2]
            and q.shape[-2] < att_mask.shape[1]
        )
        if mask_too_large:
            if not isinstance(att_mask, AttentionMask):
                logger.error(
                    "Mismatching sparse attention mask and sequence length."
                    + " Please pad the inputs or adjust the attention mask"
                )
                raise NotImplementedError

            att_mask = att_mask.make_crop(seq_len=q.shape[-2])

        # Attend: (B x nh, S, hs) x (B x nh, hs, S) -> (B x nh, S, S)
        return scaled_dot_product_attention(
            q=q, k=k, v=v, att_mask=att_mask, dropout=self.attn_drop
        )
|
||||
812
pkgs/xformers/components/attention/sparsity_config.py
Normal file
812
pkgs/xformers/components/attention/sparsity_config.py
Normal file
@@ -0,0 +1,812 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the BSD license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
"""
|
||||
The code has been adopted from DeepSpeed
|
||||
(https://github.com/microsoft/DeepSpeed/blob/master/deepspeed/ops/sparse_attention/sparsity_config.py)
|
||||
"""
|
||||
|
||||
import random
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
class SparsityConfig:
    """Abstract configuration class storing the sparsity configuration of a self attention layer.

    It holds the properties shared by the different block-sparse sparsity patterns. Each concrete
    pattern extends it with its own properties and layout-building logic.
    """

    def __init__(self, num_heads, block_size=16, different_layout_per_head=False):
        """Initialize the Sparsity Pattern Config.

        Arguments:
             num_heads: required: an integer determining number of attention heads of the layer.
             block_size: optional: an integer determining the block size. Current implementation of
                sparse self-attention is based on blocked sparse matrices. In which this parameter
                defines size of such blocks, `Block X Block`.
             different_layout_per_head: optional: a boolean determining if each head should be
                assigned a different sparsity layout; default is false and this will be satisfied
                based on availability.

        Raises:
             ValueError: if `block_size` is not strictly positive.
        """

        # Fail early: a zero block size would otherwise surface as a ZeroDivisionError
        # in `setup_layout`, far from the faulty call site.
        if block_size <= 0:
            raise ValueError(f"Block size, {block_size}, must be a positive integer!")

        self.num_heads = num_heads
        self.block_size = block_size
        self.different_layout_per_head = different_layout_per_head
        # Number of heads whose layout is actually computed; the others are copies of head 0
        self.num_layout_heads = num_heads if different_layout_per_head else 1

    def setup_layout(self, seq_len):
        """Create the (all zero) layout tensor for the given sequence length.

        Arguments:
             seq_len: required: an integer determining the sequence length; it must be a
                multiple of `block_size`.

        Return:
             layout: a `torch.int64` tensor of dimension (num_heads, num_blocks, num_blocks) for
                the sparsity layout of all heads; initialized with zero

        Raises:
             ValueError: if `seq_len` is not a multiple of `block_size`.
        """

        if seq_len % self.block_size != 0:
            raise ValueError(
                f"Sequence Length, {seq_len}, needs to be dividable by Block size {self.block_size}!"
            )
        num_blocks = seq_len // self.block_size
        # TODO Currently we allocate layout per head; needs to be updated if heads share a single layout.
        layout = torch.zeros(
            (self.num_heads, num_blocks, num_blocks), dtype=torch.int64
        )
        return layout

    def check_and_propagate_first_head_layout(self, layout):
        """If all heads share the same sparsity layout, copy the first head layout to all heads.

        The tensor is modified in place and also returned.

        Arguments:
             layout: required: a tensor of dimension (num_heads, num_blocks, num_blocks) containing
                sparsity layout of all head; may not be completely set at this step

        Return:
             layout: a tensor of dimension (num_heads, num_blocks, num_blocks) containing sparsity
                layout of all head
        """

        if not self.different_layout_per_head:
            layout[1 : self.num_heads, :, :] = layout[0, :, :]
        return layout
|
||||
|
||||
|
||||
class DenseSparsityConfig(SparsityConfig):
    """Configuration class for the `Dense` layout.

    Nothing is sparse here: every block of the layout is used. The class exists for the sake
    of comparison with, and comprehension of, the actual sparse layouts.
    """

    def __init__(self, num_heads, block_size=16, different_layout_per_head=False):
        """Initialize the Dense Sparsity Pattern Config.

        Arguments:
             num_heads: required: an integer determining number of attention heads of the layer.
             block_size: optional: an integer determining the block size. Current implementation of
                sparse self-attention is based on blocked sparse matrices. In which this parameter
                defines size of such blocks, `Block X Block`.
             different_layout_per_head: optional: only present for consistency with the other
                sparsity formats; it has no effect on the dense layout itself.
        """

        super().__init__(
            num_heads=num_heads,
            block_size=block_size,
            different_layout_per_head=different_layout_per_head,
        )

    def make_layout(self, seq_len):
        """Build a layout in which every block is set to 1, meaning the pattern is dense.

        Arguments:
             seq_len: required: an integer determining the underlying sequence length.

        Return:
             layout: a tensor of dimension (num_heads, num_blocks, num_blocks) containing the
                sparsity layout of all heads; for dense everything is 1
        """

        dense = self.setup_layout(seq_len)
        # In-place fill of the whole tensor
        dense[...] = 1
        return dense
|
||||
|
||||
|
||||
class FixedSparsityConfig(SparsityConfig):
    """Configuration class to store `Fixed` sparsity configuration.

    For more details about this sparsity config, please see `Generative Modeling with
    Sparse Transformers`: https://arxiv.org/abs/1904.10509; this has been customized.
    This class extends parent class of `SparsityConfig` and customizes it for `Fixed` sparsity.
    """

    def __init__(
        self,
        num_heads,
        block_size=16,
        different_layout_per_head=False,
        num_local_blocks=4,
        num_global_blocks=1,
        attention="bidirectional",
        horizontal_global_attention=False,
        num_different_global_patterns=1,
    ):
        """Initialize `Fixed` Sparsity Pattern Config.

        Arguments:
             num_heads: required: an integer determining number of attention heads of the layer.
             block_size: optional: an integer determining the block size. Current implementation of
                sparse self-attention is based on blocked sparse matrices. In which this parameter
                defines size of such blocks, `Block X Block`.
             different_layout_per_head: optional: a boolean determining if each head should be
                assigned a different sparsity layout; default is false and this will be satisfied
                based on availability.
             num_local_blocks: optional: an integer determining the number of blocks in local attention
                window.
             num_global_blocks: optional: an integer determining how many consecutive blocks in a local
                window is used as the representative of the window for global attention.
             attention: optional: a string determining attention type. Attention can be `unidirectional`,
                such as autoregressive models, in which tokens attend only to tokens appear before them
                in the context. Or it can be `bidirectional`, such as BERT, in which tokens can attend
                to any other tokens before or after them.
             horizontal_global_attention: optional: a boolean determining if blocks that are global
                representative of a local window, also attend to all other blocks. This is valid only if
                attention type is `bidirectional`. Looking at the attention matrix, that means global
                attention not only includes the vertical blocks, but also horizontal blocks.
             num_different_global_patterns: optional: an integer determining number of different global
                attentions layouts. While global attention can be fixed by which block/s are representative
                of any local window, since there are multi-heads, each head can use a different global
                representative. For example, with 4 blocks local window and global attention size of
                1 block, we can have 4 different versions in which the first, second, third, or fourth
                block of each local window can be global representative of that window. This parameter
                determines how many of such patterns we want. It cannot exceed
                `num_local_blocks // num_global_blocks`.

        Raises:
             ValueError: if a block count is not strictly positive, or if the arguments are
                inconsistent with each other.
             NotImplementedError: if `attention` is neither "unidirectional" nor "bidirectional".
        """

        super().__init__(num_heads, block_size, different_layout_per_head)

        # Zero counts would otherwise only fail later, as ZeroDivisionError (modulo below and
        # in `set_global_layout`) or as a zero range step in `set_local_layout`.
        for arg_name, arg_value in (
            ("num_local_blocks", num_local_blocks),
            ("num_global_blocks", num_global_blocks),
            ("num_different_global_patterns", num_different_global_patterns),
        ):
            if arg_value < 1:
                raise ValueError(
                    f"{arg_name} must be a positive integer, got {arg_value}!"
                )

        self.num_local_blocks = num_local_blocks

        if num_local_blocks % num_global_blocks != 0:
            raise ValueError(
                f"""Number of blocks in a local window, {num_local_blocks},
                must be dividable by number of global blocks, {num_global_blocks}!"""
            )
        self.num_global_blocks = num_global_blocks

        if attention not in ("unidirectional", "bidirectional"):
            raise NotImplementedError(
                'only "uni/bi-directional" attentions are supported for now!'
            )
        self.attention = attention

        if attention != "bidirectional" and horizontal_global_attention:
            raise ValueError(
                'only "bi-directional" attentions can support horizontal global attention!'
            )
        self.horizontal_global_attention = horizontal_global_attention

        if num_different_global_patterns > 1 and not different_layout_per_head:
            raise ValueError(
                """Number of different layouts cannot be more than one when you have set a single layout
                for all heads! Set different_layout_per_head to True."""
            )
        if num_different_global_patterns > (num_local_blocks // num_global_blocks):
            raise ValueError(
                f"""Number of layout versions (num_different_global_patterns), {num_different_global_patterns},
                cannot be larger than number of local window blocks divided by number of global blocks,
                {num_local_blocks} / {num_global_blocks} = {num_local_blocks//num_global_blocks}!"""
            )
        self.num_different_global_patterns = num_different_global_patterns

    def set_local_layout(self, h, layout):
        """Sets local attention layout used by the given head in the sparse attention.

        The block rows are split in consecutive windows of `num_local_blocks` blocks (the last
        one may be shorter). Within a window every block attends to every block of the window
        (`bidirectional`), or only to the blocks up to and including itself (`unidirectional`).

        Arguments:
             h: required: an integer determining head index
             layout: required: a tensor of dimension (num_heads, num_blocks, num_blocks) containing
                sparsity layout of all head; may not be completely set at this step

        Return:
             layout: a tensor of dimension (num_heads, num_blocks, num_blocks) containing sparsity
                layout of all head in which local layout is set
        """

        num_blocks = layout.shape[1]
        causal = self.attention == "unidirectional"
        for window_start in range(0, num_blocks, self.num_local_blocks):
            window_end = min(window_start + self.num_local_blocks, num_blocks)
            if causal:
                # Lower triangle of the window, one block row at a time
                for row in range(window_start, window_end):
                    layout[h, row, window_start : row + 1] = 1
            else:
                layout[h, window_start:window_end, window_start:window_end] = 1
        return layout

    def _set_global_span(self, h, layout, start):
        """Mark the `num_global_blocks` block columns beginning at `start` as global for head `h`."""

        stop = start + self.num_global_blocks

        # vertical global attention; unidirectional attention only sets the rows from `start` on
        first_row = 0 if self.attention == "bidirectional" else start
        layout[h, first_row:, start:stop] = 1

        # horizontal global attention; only in bidirectional attention
        if self.horizontal_global_attention:
            layout[h, start:stop, :] = 1

    def set_global_layout(self, h, layout):
        """Sets global attention layout used by the given head in the sparse attention.

        Currently we set global blocks starting from the last block of a local window to the first one.
        That means if a local window consists of 4 blocks and global attention size is one block, we use
        block #4 in each local window as global. If we have different layout per head, then other heads
        will get #3, #2, and #1. And if we have more heads (and different layout has set) than num of global
        attentions, multiple head may have same global attentions.
        Note) if horizontal_global_attention is set, global blocks will be set both horizontally and
        vertically.

        Arguments:
             h: required: an integer determining head index
             layout: required: a tensor of dimension (num_heads, num_blocks, num_blocks) containing
                sparsity layout of all head; may not be completely set at this step

        Return:
             layout: a tensor of dimension (num_heads, num_blocks, num_blocks) containing sparsity
                layout of all head in which global layout is set
        """

        num_blocks = layout.shape[1]
        # Offset, inside a local window, of the first global block used by this head
        first_global_block_idx = (
            self.num_local_blocks
            - (1 + h % self.num_different_global_patterns) * self.num_global_blocks
        )

        # set the global blocks of all the complete local windows
        full_windows_end = num_blocks - (num_blocks % self.num_local_blocks)
        for start in range(
            first_global_block_idx, full_windows_end, self.num_local_blocks
        ):
            self._set_global_span(h, layout, start)

        # set last global blocks; handle possible short last local window
        if full_windows_end < num_blocks:
            start = min(
                full_windows_end + first_global_block_idx,
                num_blocks - self.num_global_blocks,
            )
            self._set_global_span(h, layout, start)
        return layout

    def make_layout(self, seq_len):
        """Generates `Fixed` sparsity layout used by each head in the sparse attention.

        Arguments:
             seq_len: required: an integer determining the sequence length; it is forwarded to
                `setup_layout`.

        Return:
             layout: a tensor of dimension (num_heads, num_blocks, num_blocks) containing `Fixed`
                sparsity layout of all head
        """

        layout = self.setup_layout(seq_len)
        for h in range(0, self.num_layout_heads):
            layout = self.set_local_layout(h, layout)
            layout = self.set_global_layout(h, layout)

        layout = self.check_and_propagate_first_head_layout(layout)
        return layout
|
||||
|
||||
|
||||
class VariableSparsityConfig(SparsityConfig):
    """Configuration class to store `Variable` sparsity configuration.

    This layout is an extension of FixedSparsityConfig in which:
      - user can set random layout; default value is zero means no random block
      - user can provide a list of local block sizes
      - user can provide a list of global block indices.
    For more details about `Fixed` sparsity config, please see `Generative Modeling with
    Sparse Transformers`: https://arxiv.org/abs/1904.10509; this has been customized.
    This class extends parent class of `SparsityConfig` and customizes it for `Variable` sparsity.
    """

    def __init__(
        self,
        num_heads,
        block_size=16,
        different_layout_per_head=False,
        num_random_blocks=0,
        local_window_blocks=None,
        global_block_indices=None,
        global_block_end_indices=None,
        attention="bidirectional",
        horizontal_global_attention=False,
    ):
        """Initialize `Variable` Sparsity Pattern Config.

        Arguments:
             num_heads: required: an integer determining number of attention heads of the layer.
             block_size: optional: an integer determining the block size. Current implementation of sparse
                self-attention is based on blocked sparse matrices. In which this parameter defines
                size of such blocks, `Block X Block`.
             different_layout_per_head: optional: a boolean determining if each head should be assigned a
                different sparsity layout; default is false and this will be satisfied based on
                availability.
             num_random_blocks: optional: an integer determining the number of random blocks in each block row.
             local_window_blocks: optional: a list of integers determining the number of blocks in each
                local attention window. It assumes first number determines # of blocks in the first local
                window, second the second window, ..., and the last number determines the number of blocks
                in the remaining local windows. Defaults to `[4]` when left to None.
             global_block_indices: optional: a list of integers determining which blocks are considered
                as global attention. Default value (when left to None) is only index 0.
                Notice that if global_block_end_indices parameter is set, this parameter is used as
                starting index of each global window.
             global_block_end_indices: optional: a list of integers determining end indices of global
                window blocks. By default this is not used. But if it is set, it must have the same size
                of global_block_indices parameter, and combining this two parameters, for each index i,
                blocks from global_block_indices[i] to global_block_end_indices[i] (exclusive) are
                considered as global attention.
             attention: optional: a string determining attention type. Attention can be `unidirectional`,
                such as autoregressive models, in which tokens attend only to tokens appear before them
                in the context. Or it can be `bidirectional`, such as BERT, in which tokens can attend
                to any other tokens before or after them.
             horizontal_global_attention: optional: a boolean determining if blocks that are global
                representative of a local window, also attend to all other blocks. This is valid only if
                attention type is `bidirectional`. Looking at the attention matrix, that means global
                attention not only includes the vertical blocks, but also horizontal blocks.

        Raises:
             ValueError: if the global start/end indices are inconsistent, or if horizontal global
                attention is requested for a non-bidirectional attention.
             NotImplementedError: if `attention` is neither "unidirectional" nor "bidirectional".
        """

        super().__init__(num_heads, block_size, different_layout_per_head)

        # Fresh lists per instance: list literals used as default argument values would be
        # shared by every instance relying on the defaults.
        if local_window_blocks is None:
            local_window_blocks = [4]
        if global_block_indices is None:
            global_block_indices = [0]

        self.num_random_blocks = num_random_blocks
        self.local_window_blocks = local_window_blocks
        self.global_block_indices = global_block_indices

        if global_block_end_indices is not None:
            if len(global_block_indices) != len(global_block_end_indices):
                raise ValueError(
                    f"""Global block start indices length, {len(global_block_indices)}, must be same as
                    global block end indices length, {len(global_block_end_indices)}!"""
                )
            for start_idx, end_idx in zip(
                global_block_indices, global_block_end_indices
            ):
                if start_idx >= end_idx:
                    raise ValueError(
                        f"""Global block start index, {start_idx}, must be smaller than global block end
                        index, {end_idx}!"""
                    )
        self.global_block_end_indices = global_block_end_indices

        if attention not in ("unidirectional", "bidirectional"):
            raise NotImplementedError(
                'only "uni/bi-directional" attentions are supported for now!'
            )
        self.attention = attention

        if attention != "bidirectional" and horizontal_global_attention:
            raise ValueError(
                'only "bi-directional" attentions can support horizontal global attention!'
            )
        self.horizontal_global_attention = horizontal_global_attention

    def set_random_layout(self, h, layout):
        """Sets random attention layout used by the given head in the sparse attention.

        Each block row gets `num_random_blocks` distinct block columns, drawn with the `random`
        module over all the columns (regardless of the attention direction).
        Note) By default, it assumes there will be a unique random block layout for all heads; unless
        `different_layout_per_head` parameter is set in which each head can have a different random
        layout.

        Arguments:
             h: required: an integer determining head index
             layout: required: a tensor of dimension (num_heads, num_blocks, num_blocks) containing
                sparsity layout of all head; may not be completely set at this step

        Return:
             layout: a tensor of dimension (num_heads, num_blocks, num_blocks) containing sparsity
                layout of all head in which random layout is set

        Raises:
             ValueError: if there are fewer blocks in a row than `num_random_blocks`.
        """

        num_blocks = layout.shape[1]
        if num_blocks < self.num_random_blocks:
            raise ValueError(
                f"""Number of random blocks, {self.num_random_blocks}, must be smaller than overall number
                of blocks in a row, {num_blocks}!"""
            )
        for row in range(0, num_blocks):
            rnd_cols = random.sample(range(0, num_blocks), self.num_random_blocks)
            layout[h, row, rnd_cols] = 1
        return layout

    def _fill_local_window(self, h, layout, start, end):
        """Set the local attention blocks of the window covering block rows [start, end)."""

        if self.attention == "unidirectional":
            # Lower triangle of the window, one block row at a time
            for row in range(start, end):
                layout[h, row, start : row + 1] = 1
        else:
            layout[h, start:end, start:end] = 1

    def set_local_layout(self, h, layout):
        """Sets local attention layout used by the given head in the sparse attention.

        Windows are laid out one after the other following `local_window_blocks`; once the list
        is exhausted, its last size is reused for all the remaining windows.

        Arguments:
             h: required: an integer determining head index
             layout: required: a tensor of dimension (num_heads, num_blocks, num_blocks) containing
                sparsity layout of all head; may not be completely set at this step

        Return:
             layout: a tensor of dimension (num_heads, num_blocks, num_blocks) containing sparsity
                layout of all head in which local layout is set

        Raises:
             ValueError: if `local_window_blocks` is empty.
        """

        # Without any window size the "last window size" used below would be undefined
        if len(self.local_window_blocks) == 0:
            raise ValueError("local_window_blocks must contain at least one window size!")

        num_blocks = layout.shape[1]
        window_start = 0
        for window_size in self.local_window_blocks:
            window_end = min(window_start + window_size, num_blocks)
            self._fill_local_window(h, layout, window_start, window_end)
            window_start += window_size

        # if there is any remaining not attended part, use the last local window block size as local
        # window for the remaining applicable local windows
        for tail_start in range(window_start, num_blocks, window_size):
            tail_end = min(tail_start + window_size, num_blocks)
            self._fill_local_window(h, layout, tail_start, tail_end)
        return layout

    def set_global_layout(self, h, layout):
        """Sets global attention layout used by the given head in the sparse attention.

        Global blocks starting beyond the number of blocks of the layout are ignored.

        Arguments:
             h: required: an integer determining head index
             layout: required: a tensor of dimension (num_heads, num_blocks, num_blocks) containing
                sparsity layout of all head; may not be completely set at this step

        Return:
             layout: a tensor of dimension (num_heads, num_blocks, num_blocks) containing sparsity
                layout of all head in which global layout is set
        """

        num_blocks = layout.shape[1]
        bidirectional = self.attention == "bidirectional"

        if self.global_block_end_indices is None:
            # Every index designates a single global block
            for idx in self.global_block_indices:
                # skip the global blocks out of the range of the sequence blocks
                if idx >= num_blocks:
                    continue

                # global rows
                if self.horizontal_global_attention:
                    layout[h, idx, :] = 1

                # global columns
                first_row = 0 if bidirectional else idx
                layout[h, first_row:, idx] = 1
            return layout

        # Every (start, end) pair designates a span of global blocks, `end` being exclusive
        for start_idx, end_idx in zip(
            self.global_block_indices, self.global_block_end_indices
        ):
            # skip the global spans starting out of the range of the sequence blocks
            if start_idx >= num_blocks:
                continue
            end_idx = min(end_idx, num_blocks)

            # global rows
            if self.horizontal_global_attention:
                layout[h, start_idx:end_idx, :] = 1

            # global columns
            first_row = 0 if bidirectional else start_idx
            layout[h, first_row:, start_idx:end_idx] = 1
        return layout

    def make_layout(self, seq_len):
        """Generates `Variable` sparsity layout used by each head in the sparse attention.

        Arguments:
             seq_len: required: an integer determining the sequence length; it is forwarded to
                `setup_layout`.

        Return:
             layout: a tensor of dimension (num_heads, num_blocks, num_blocks) containing `Variable`
                sparsity layout of all head
        """

        layout = self.setup_layout(seq_len)
        for h in range(0, self.num_layout_heads):
            layout = self.set_random_layout(h, layout)
            layout = self.set_local_layout(h, layout)
            layout = self.set_global_layout(h, layout)

        layout = self.check_and_propagate_first_head_layout(layout)
        return layout
|
||||
|
||||
|
||||
class BigBirdSparsityConfig(SparsityConfig):
|
||||
"""Configuration class to store `BigBird` sparsity configuration.
|
||||
For more details about this sparsity config, please see `Big Bird: Transformers for
|
||||
Longer Sequences`: https://arxiv.org/pdf/2007.14062.pdf
|
||||
This class extends parent class of `SparsityConfig` and customizes it for `BigBird` sparsity.
|
||||
"""
|
||||
|
||||
def __init__(
    self,
    num_heads,
    block_size=16,
    different_layout_per_head=False,
    num_random_blocks=1,
    num_sliding_window_blocks=3,
    num_global_blocks=1,
    attention="bidirectional",
):
    """Initialize the BigBird Sparsity Pattern Config.

    Arguments:
         num_heads: required: an integer determining number of attention heads of the layer.
         block_size: optional: an integer determining the block size. Current implementation of
            sparse self-attention is based on blocked sparse matrices. In which this parameter
            defines size of such blocks, `Block X Block`.
         different_layout_per_head: optional: a boolean determining if each head should be assigned
            a different sparsity layout; default is false and this will be satisfied based on
            availability.
         num_random_blocks: optional: an integer determining the number of random blocks in each
            block row.
         num_sliding_window_blocks: optional: an integer determining the number of blocks in sliding
            local attention window.
         num_global_blocks: optional: an integer determining how many consecutive blocks, starting
            from index 0, are considered as global attention. Global block tokens will be attended
            by all other block tokens and will attend to all other block tokens as well.
         attention: optional: a string determining attention type. Attention can be `unidirectional`,
            such as autoregressive models, in which tokens attend only to tokens appear before them
            in the context. Or it can be `bidirectional`, such as BERT, in which tokens can attend
            to any other tokens before or after them.

    Raises:
         NotImplementedError: if `attention` is neither "unidirectional" nor "bidirectional".
    """

    super().__init__(num_heads, block_size, different_layout_per_head)

    self.num_random_blocks = num_random_blocks
    self.num_sliding_window_blocks = num_sliding_window_blocks
    self.num_global_blocks = num_global_blocks

    # Only the two attention directions below are handled by the layout builders
    if attention not in ("unidirectional", "bidirectional"):
        raise NotImplementedError(
            'only "uni/bi-directional" attentions are supported for now!'
        )
    self.attention = attention
|
||||
|
||||
def set_random_layout(self, h, layout):
    """Sets random attention layout used by the given head in the sparse attention.

    Each block row gets `num_random_blocks` distinct random block columns. For `bidirectional`
    attention they are drawn over all the columns; otherwise only over the columns up to and
    including the row itself. In the latter case the first rows have fewer candidates than
    `num_random_blocks`: all their candidates are then selected.
    Note) By default, it assumes there will be a unique random block layout for all heads; unless
    `different_layout_per_head` parameter is set in which each head can have a different random layout.

    Arguments:
         h: required: an integer determining head index
         layout: required: a tensor of dimension (num_heads, num_blocks, num_blocks) containing
            sparsity layout of all head; may not be completely set at this step

    Return:
         layout: a tensor of dimension (num_heads, num_blocks, num_blocks) containing sparsity
            layout of all head in which random layout is set

    Raises:
         ValueError: if there are fewer blocks in a row than `num_random_blocks`.
    """

    num_blocks = layout.shape[1]
    if num_blocks < self.num_random_blocks:
        raise ValueError(
            f"""Number of random blocks, {self.num_random_blocks}, must be smaller than overall number
                of blocks in a row, {num_blocks}!"""
        )

    bidirectional = self.attention == "bidirectional"
    for row in range(0, num_blocks):
        sample_range = range(0, num_blocks) if bidirectional else range(0, row + 1)
        # random.sample raises ValueError when asked for more items than the population
        # holds, which happens on the first causal rows: clamp the sample size.
        num_samples = min(self.num_random_blocks, len(sample_range))
        rnd_cols = random.sample(sample_range, num_samples)
        layout[h, row, rnd_cols] = 1
    return layout
|
||||
|
||||
def set_sliding_window_layout(self, h, layout):
    """Marks the local sliding-window band for head `h` in the block layout.

    For every block row, the blocks within `num_sliding_window_blocks // 2` positions on
    either side of the diagonal (clipped to the layout bounds) are set to 1.

    Arguments:
         h: required: an integer determining head index
         layout: required: a tensor of dimension (num_heads, num_blocks, num_blocks) containing
            sparsity layout of all head; may not be completely set at this step

    Return:
         layout: the same tensor, updated in place, with the sliding window blocks of head
            `h` set

    Raises:
         ValueError: if the layout has fewer blocks per row than `num_sliding_window_blocks`.
    """

    num_blocks = layout.shape[1]
    if self.num_sliding_window_blocks > num_blocks:
        raise ValueError(
            f"""Number of sliding window blocks, {self.num_sliding_window_blocks}, must be smaller than
overall number of blocks in a row, {num_blocks}!"""
        )

    half_window = self.num_sliding_window_blocks // 2
    for row in range(num_blocks):
        # Clip the band at both edges of the layout.
        first = row - half_window if row > half_window else 0
        last = min(num_blocks, row + half_window + 1)
        layout[h, row, first:last] = 1
    return layout
|
||||
|
||||
def set_global_layout_itc(self, h, layout):
    """Marks the leading global blocks for head `h` in the block layout.

    The first `num_global_blocks` block rows and block columns of head `h` are set to 1.
    In `unidirectional` mode the whole layout is then passed through `torch.tril`, which
    clears every entry above the diagonal.

    Arguments:
         h: required: an integer determining head index
         layout: required: a tensor of dimension (num_heads, num_blocks, num_blocks) containing
            sparsity layout of all head; may not be completely set at this step

    Return:
         layout: a tensor of dimension (num_heads, num_blocks, num_blocks) containing sparsity
            layout of all head in which global layout is set

    Raises:
         ValueError: if the layout has fewer blocks per row than `num_global_blocks`.
    """

    num_blocks = layout.shape[1]
    if self.num_global_blocks > num_blocks:
        raise ValueError(
            f"""Number of global blocks, {self.num_global_blocks}, must be smaller than overall number
of blocks in a row, {num_blocks}!"""
        )

    leading = slice(0, self.num_global_blocks)
    layout[h, leading, :] = 1  # global rows
    layout[h, :, leading] = 1  # global columns

    if self.attention != "unidirectional":
        return layout
    # zero out anything attending to the future
    return torch.tril(layout)
|
||||
|
||||
def make_layout(self, seq_len):
    """Generates `BigBird` sparsity layout used by each head in the sparse attention.

    Starts from the layout returned by `setup_layout(seq_len)` and, for each of the
    `num_layout_heads` heads, applies the random, sliding-window and global patterns in
    that order. The result is finally handed to `check_and_propagate_first_head_layout`.

    Arguments:
         seq_len: required: value forwarded to `setup_layout` (presumably the sequence
            length in tokens -- confirm against the parent class).

    Return:
         layout: a tensor of dimension (num_heads, num_blocks, num_blocks) containing `BigBird`
            sparsity layout of all head
    """

    layout = self.setup_layout(seq_len)
    pattern_setters = (
        self.set_random_layout,
        self.set_sliding_window_layout,
        self.set_global_layout_itc,
    )
    for head in range(self.num_layout_heads):
        for apply_pattern in pattern_setters:
            layout = apply_pattern(head, layout)

    return self.check_and_propagate_first_head_layout(layout)
|
||||
|
||||
|
||||
class BSLongformerSparsityConfig(SparsityConfig):
    """Configuration class to store edited `Longformer` sparsity configuration.

    Note) this is a block-sparse version of the Longformer which is slightly different than original
    Longformer; which is element-wise sparsity.
    For more details about this sparsity config, please see `Longformer:
    The Long-Document Transformer`: https://arxiv.org/pdf/2004.05150.pdf
    This class extends parent class of `SparsityConfig` and customizes it for `Longformer` sparsity.
    """

    def __init__(
        self,
        num_heads,
        block_size=16,
        different_layout_per_head=False,
        num_sliding_window_blocks=3,
        global_block_indices=None,
        global_block_end_indices=None,
        attention="bidirectional",
    ):
        """Initialize the edited `Longformer` Sparsity Pattern Config.

        For usage example please see, TODO DeepSpeed Sparse Transformer Tutorial

        Arguments:
             num_heads: required: an integer determining number of attention heads of the layer.
             block_size: optional: an integer determining the block size. Current implementation of sparse
                self-attention is based on blocked sparse matrices. In which this parameter defines size
                of such blocks, `Block X Block`.
             different_layout_per_head: optional: a boolean determining if each head should be assigned a
                different sparsity layout; default is false and this will be satisfied based on
                availability.
             num_sliding_window_blocks: optional: an integer determining the number of blocks in sliding
                local attention window.
             global_block_indices: optional: a list of integers determining which blocks are considered
                as global attention. Given indices, determine the blocks that all other token blocks
                attend to and they attend to all other token blocks. Default (`None`) means only
                index 0. Notice that if global_block_end_indices parameter is set, this parameter is
                used as starting index of each global window.
             global_block_end_indices: optional: a list of integers determining end indices of global
                window blocks. By default this is not used. But if it is set, it must have the same size
                of global_block_indices parameter, and combining this two parameters, for each index i,
                blocks from global_block_indices[i] to global_block_end_indices[i] (exclusive) are
                considered as global attention.
             attention: optional: a string determining attention type. Attention can be `unidirectional`,
                such as autoregressive models, in which tokens attend only to tokens appear before them
                in the context; the upper triangular part of the attention matrix is then empty. Or it
                can be `bidirectional`, such as BERT, in which tokens can attend to any other tokens
                before or after them.

        Raises:
             ValueError: if `global_block_end_indices` is given with a different length than
                `global_block_indices`, or if any start index is not smaller than its end index.
        """

        super().__init__(num_heads, block_size, different_layout_per_head)

        # A list literal as the default would be a single object shared by every
        # instance that relies on it; build a fresh list per instance instead.
        if global_block_indices is None:
            global_block_indices = [0]

        self.num_sliding_window_blocks = num_sliding_window_blocks
        self.global_block_indices = global_block_indices
        self.attention = attention

        if global_block_end_indices is not None:
            if len(global_block_indices) != len(global_block_end_indices):
                raise ValueError(
                    f"""Global block start indices length, {len(global_block_indices)}, must be same as
global block end indices length, {len(global_block_end_indices)}!"""
                )
            for start_idx, end_idx in zip(
                global_block_indices, global_block_end_indices
            ):
                if start_idx >= end_idx:
                    raise ValueError(
                        f"""Global block start index, {start_idx}, must be smaller than global block end
index, {end_idx}!"""
                    )
        self.global_block_end_indices = global_block_end_indices

    def set_sliding_window_layout(self, h, layout):
        """Sets sliding local attention layout used by the given head in the sparse attention.

        Arguments:
             h: required: an integer determining head index
             layout: required: a tensor of dimension (num_heads, num_blocks, num_blocks) containing
                sparsity layout of all head; may not be completely set at this step

        Return:
             layout: a tensor of dimension (num_heads, num_blocks, num_blocks) containing sparsity layout
                of all head in which local sliding window layout is set

        Raises:
             ValueError: if the layout has fewer blocks per row than `num_sliding_window_blocks`.
        """

        num_blocks = layout.shape[1]
        if num_blocks < self.num_sliding_window_blocks:
            raise ValueError(
                f"""Number of sliding window blocks, {self.num_sliding_window_blocks}, must be smaller
than overall number of blocks in a row, {num_blocks}!"""
            )

        # Half-width of the band on each side of the diagonal.
        w = self.num_sliding_window_blocks // 2
        for row in range(num_blocks):
            start = max(0, row - w)
            end = min(row + w + 1, num_blocks)
            layout[h, row, start:end] = 1
        return layout

    def set_global_layout(self, h, layout):
        """Sets global attention layout used by the given head in the sparse attention.

        Global indices (or windows) that start beyond the number of blocks are ignored, and
        window end indices are clipped to the number of blocks.

        Arguments:
             h: required: an integer determining head index
             layout: required: a tensor of dimension (num_heads, num_blocks, num_blocks) containing
                sparsity layout of all head; may not be completely set at this step

        Return:
             layout: a tensor of dimension (num_heads, num_blocks, num_blocks) containing sparsity
                layout of all head in which global layout is set
        """

        num_blocks = layout.shape[1]
        if self.global_block_end_indices is None:
            for idx in self.global_block_indices:
                # if global block idx is in the range of the sequence blocks
                if idx < num_blocks:
                    # global rows
                    layout[h, idx, :] = 1

                    # global columns
                    layout[h, :, idx] = 1
        else:
            for start_idx, end_idx in zip(
                self.global_block_indices, self.global_block_end_indices
            ):
                # if global block idx is in the range of the sequence blocks
                if start_idx < num_blocks:
                    end_idx = min(end_idx, num_blocks)
                    # global rows
                    layout[h, start_idx:end_idx, :] = 1

                    # global columns
                    layout[h, :, start_idx:end_idx] = 1
        if self.attention == "unidirectional":
            # zero out anything attending to the future
            layout = torch.tril(layout)
        return layout

    def make_layout(self, seq_len):
        """Generates edited `Longformer` sparsity layout used by each head in the sparse attention.

        Arguments:
             seq_len: required: value forwarded to `setup_layout` (presumably the sequence
                length in tokens -- confirm against the parent class).

        Return:
             layout: a tensor of dimension (num_heads, num_blocks, num_blocks) containing `BSLongformer`
                sparsity layout of all head
        """

        layout = self.setup_layout(seq_len)
        for h in range(self.num_layout_heads):
            layout = self.set_sliding_window_layout(h, layout)
            layout = self.set_global_layout(h, layout)

        layout = self.check_and_propagate_first_head_layout(layout)
        return layout
|
||||
108
pkgs/xformers/components/attention/utils.py
Normal file
108
pkgs/xformers/components/attention/utils.py
Normal file
@@ -0,0 +1,108 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the BSD license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
# Reshapes key padding mask from (batch_size, src_len) -> (batch_size * num_heads, 1, src_len)
|
||||
def reshape_key_padding_mask(
    key_padding_mask: torch.Tensor, batched_dim: int
) -> torch.Tensor:
    """Expand a 2-D key padding mask so there is one entry per (batch, head) pair.

    The number of heads is derived as `batched_dim // batch_size`, where `batch_size`
    is the first dimension of the mask; the actual reshaping is delegated to
    `_reshape_key_padding_mask`.

    Args:
        key_padding_mask: tensor of shape (batch_size, src_len).
        batched_dim: size of the flattened batch-times-heads dimension.

    Returns:
        Tensor of shape (batch_size * num_heads, 1, src_len).
    """
    assert key_padding_mask.ndim == 2
    n_batch, n_keys = key_padding_mask.shape
    return _reshape_key_padding_mask(
        key_padding_mask, n_batch, n_keys, batched_dim // n_batch
    )
|
||||
|
||||
|
||||
def _reshape_key_padding_mask(
|
||||
key_padding_mask: torch.Tensor, batch_size: int, src_len: int, num_heads: int
|
||||
) -> torch.Tensor:
|
||||
assert key_padding_mask.shape == (batch_size, src_len)
|
||||
key_padding_mask = (
|
||||
key_padding_mask.view(batch_size, 1, 1, src_len)
|
||||
.expand(-1, num_heads, -1, -1)
|
||||
.reshape(batch_size * num_heads, 1, src_len)
|
||||
)
|
||||
return key_padding_mask
|
||||
|
||||
|
||||
# Combine the attention mask and key padding mask into a single mask
|
||||
# Taken from https://github.com/pytorch/pytorch/blob/master/torch/nn/functional.py
|
||||
# Additive masking not yet supported
|
||||
def maybe_merge_masks(
    att_mask: Optional[torch.Tensor],
    key_padding_mask: Optional[torch.Tensor],
    batch_size: int,
    src_len: int,
    num_heads: int,
    tgt_len: Optional[int] = None,
) -> Optional[torch.Tensor]:
    """Fold an optional key padding mask into an optional attention mask.

    Args:
        att_mask: existing attention mask, or None.
        key_padding_mask: mask of shape (batch_size, src_len), or None.
        batch_size: number of sequences in the batch.
        src_len: number of key positions.
        num_heads: number of attention heads.
        tgt_len: number of query positions; defaults to `src_len`.

    Returns:
        `att_mask` unchanged when there is no key padding mask. Otherwise the padding
        mask expanded to (batch_size * num_heads, tgt_len, src_len) when `att_mask` is
        None, the logical AND of both for a boolean `att_mask`, or `att_mask` with
        `-inf` written wherever the padding mask is False for any other dtype.
    """
    if key_padding_mask is None:
        # Nothing to merge.
        return att_mask

    assert key_padding_mask.shape == (batch_size, src_len)
    padding = _reshape_key_padding_mask(
        key_padding_mask, batch_size, src_len, num_heads
    )

    if att_mask is None:
        # make sure dimensions of key padding mask are the same as those expected for att_mask
        n_queries = src_len if tgt_len is None else tgt_len
        return padding.expand(-1, n_queries, -1)

    # Assumption is that False means to mask.
    if att_mask.dtype == torch.bool:
        return att_mask.logical_and(padding)
    return att_mask.masked_fill(~padding, float("-inf"))
|
||||
|
||||
|
||||
# Assumes that matrix passed in has had softmax applied to it.
|
||||
def iterative_pinv(softmax_mat: torch.Tensor, n_iter=6, pinverse_original_init=False):
    """
    Computing the Moore-Penrose inverse.
    Use an iterative method from (Razavi et al. 2014) to approximate the Moore-Penrose inverse via efficient
    matrix-matrix multiplications.

    Args:
        softmax_mat: tensor of shape (..., n, n) that has already had softmax applied.
            Any number of leading batch dimensions (including none) is accepted.
        n_iter: number of refinement iterations.
        pinverse_original_init: if True, scale the initial guess by a single maximum
            taken over the whole tensor; otherwise use a separate maximum column sum
            for each matrix.

    Returns:
        Tensor with the same shape as `softmax_mat` holding the approximate pseudo-inverse.
    """

    i = torch.eye(
        softmax_mat.size(-1), device=softmax_mat.device, dtype=softmax_mat.dtype
    )
    k = softmax_mat
    k_t = k.transpose(-1, -2)
    col_sums = torch.sum(k, dim=-2)

    # The entries of K are positive and ||K||_{\infty} = 1 due to softmax
    if pinverse_original_init:
        # This original implementation is more conservative to compute coefficient of Z_0.
        v = 1 / torch.max(col_sums) * k_t
    else:
        # This is the exact coefficient computation, 1 / ||K||_1, of initialization of Z_0, leading to faster
        # convergence.
        # `[..., None, None]` appends the two matrix axes after however many batch
        # axes there are, so 2-D and 4-D inputs broadcast correctly as well as 3-D.
        v = 1 / torch.max(col_sums, dim=-1).values[..., None, None] * k_t

    for _ in range(n_iter):
        kv = torch.matmul(k, v)
        v = torch.matmul(
            0.25 * v,
            13 * i - torch.matmul(kv, 15 * i - torch.matmul(kv, 7 * i - kv)),
        )
    return v
|
||||
|
||||
|
||||
def bool_mask_to_additive(
    mask: torch.Tensor, dtype: Optional[torch.dtype] = torch.float32
) -> torch.Tensor:
    """Convert a boolean mask into an additive one.

    Args:
        mask: boolean tensor; True marks positions to keep.
        dtype: dtype of the returned tensor.

    Returns:
        Tensor shaped like `mask`, holding 0 where `mask` is True and `-inf`
        where it is False.
    """
    assert (
        mask.dtype == torch.bool
    ), "This util is meant to convert in between bool masks and additive ones"

    additive = torch.zeros_like(mask, dtype=dtype)
    blocked = mask.logical_not()
    additive[blocked] = float("-inf")
    return additive
|
||||
96
pkgs/xformers/components/attention/visual.py
Normal file
96
pkgs/xformers/components/attention/visual.py
Normal file
@@ -0,0 +1,96 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the BSD license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
|
||||
import math
|
||||
from dataclasses import dataclass
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from xformers.components.attention import Attention, AttentionConfig, register_attention
|
||||
|
||||
|
||||
# Configuration for the attention registered under the name "visual";
# extends the base AttentionConfig with the single field that mechanism needs.
@dataclass
class VisualAttentionConfig(AttentionConfig):
    dim_model: int  # dimension of the input sequence
|
||||
|
||||
|
||||
class LKA(nn.Module):
    """Large Kernel Attention: gates the input with a map built from stacked convolutions.

    The map is produced by a per-channel 5x5 convolution, a per-channel dilated 7x7
    convolution (dilation 3) and a final 1x1 convolution. Paddings are chosen so the
    spatial size of the input is preserved.
    """

    def __init__(self, dim: int):
        """
        Args:
            dim: number of input (and output) channels.
        """
        super().__init__()
        self.conv0 = nn.Conv2d(dim, dim, 5, padding=2, groups=dim)
        self.conv_spatial = nn.Conv2d(
            dim, dim, 7, stride=1, padding=9, groups=dim, dilation=3
        )
        self.conv1 = nn.Conv2d(dim, dim, 1)

    def forward(self, x: torch.Tensor):
        """Return `x` multiplied element-wise by its attention map.

        Args:
            x: tensor laid out as expected by `nn.Conv2d`, with `dim` channels.

        Returns:
            Tensor with the same shape as `x`.
        """
        attn = self.conv0(x)
        attn = self.conv_spatial(attn)
        attn = self.conv1(attn)

        # None of the convolutions modify `x` in place, so it can be used
        # directly as the gate input without a defensive copy.
        return x * attn
|
||||
|
||||
|
||||
@register_attention("visual", VisualAttentionConfig)
class Visual(Attention):
    def __init__(
        self,
        dim_model: int,
        *_,
        **__,
    ):
        """
        Large kernel attention mechanism, as proposed in `Visual Attention Network`_, Guo et al (2022).
        The original notation is tentatively kept as is. See https://github.com/Visual-Attention-Network
        for the reference implementation

        .. Note: compared to the paper, this block contains the LKA (Large Kernel Attention)
            and the prior and posterior transformations (Conv2d and activation)

        .. _`Visual Attention Network` : https://arxiv.org/pdf/2202.09741.pdf

        Args:
            dim_model: number of channels of the input; used for every convolution in the block.
        """
        super().__init__()

        self.block = nn.Sequential(
            nn.Conv2d(dim_model, dim_model, 1),
            nn.GELU(),
            LKA(dim_model),
            nn.Conv2d(dim_model, dim_model, 1),
        )

        # MHA related flags:
        self.requires_same_k_q_dimensions = (
            True  # This mechanism only really supports self attention
        )
        self.supports_attention_mask = False
        self.requires_skip_multi_head = (
            True  # This mechanism skips the multihead attention altogether
        )
        self.requires_squared_context = (
            True  # Recovering the 2D structure from context assumes squared content
        )

        self.requires_input_projection = (
            False  # This mechanism does not require that the MHA projects inputs
        )

    def forward(self, q: torch.Tensor, *_, **__):
        """Apply large kernel attention with a residual connection.

        Args:
            q: tensor of shape (B, HW, C) where HW must be a perfect square.

        Returns:
            Tensor of shape (B, HW, C).
        """
        # Expose the 2D token structure
        B, HW, C = q.shape
        H = int(math.sqrt(HW))
        assert H * H == HW

        x = q.transpose(-2, -1).reshape(B, C, H, H)

        # Large kernel attention
        # `self.block` returns a new tensor and leaves `x` untouched, so `x`
        # itself serves as the residual without an extra copy.
        x = self.block(x) + x

        # Get back to B HW C
        return x.flatten(2, 3).transpose(-2, -1)
|
||||
Reference in New Issue
Block a user