First commit
This commit is contained in:
779
pkgs/xformers/ops/fmha/attn_bias.py
Normal file
779
pkgs/xformers/ops/fmha/attn_bias.py
Normal file
@@ -0,0 +1,779 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the BSD license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
|
||||
import math
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Iterable, List, Optional, Sequence, Tuple, Union
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
class AttentionBias:
|
||||
"""Base class for a custom bias that can be applied \
|
||||
as the attn_bias argument in
|
||||
:attr:`xformers.ops.memory_efficient_attention`.
|
||||
|
||||
That function has the ability to add a tensor, the
|
||||
attention bias, to the QK^T matrix before it is used
|
||||
in the softmax part of the attention calculation.
|
||||
The attention bias tensor with shape
|
||||
(B or 1, n_queries, number of keys)
|
||||
can be given as the attn_bias input.
|
||||
The most common use case is for an attention bias is
|
||||
to contain only zeros and negative infinities, which forms
|
||||
a mask so that some queries only attend to some keys.
|
||||
|
||||
Children of this class define alternative things which can
|
||||
be used as the attn_bias input to define an attention bias which
|
||||
forms such a mask, for some common cases.
|
||||
|
||||
When using an :attr:`xformers.ops.AttentionBias`
|
||||
instead of a :attr:`torch.Tensor`, the mask matrix does
|
||||
not need to be materialized, and can be
|
||||
hardcoded into some kernels for better performance.
|
||||
|
||||
See:
|
||||
|
||||
- :attr:`xformers.ops.fmha.attn_bias.LowerTriangularMask`
|
||||
- :attr:`xformers.ops.fmha.attn_bias.LowerTriangularMaskWithTensorBias`
|
||||
- :attr:`xformers.ops.fmha.attn_bias.BlockDiagonalMask`
|
||||
- :attr:`xformers.ops.fmha.attn_bias.BlockDiagonalCausalMask`
|
||||
|
||||
"""
|
||||
|
||||
def materialize(
|
||||
self,
|
||||
shape: Tuple[int, ...],
|
||||
dtype: torch.dtype = torch.float32,
|
||||
device: Union[str, torch.device] = "cpu",
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Materializes the bias as a `torch.Tensor`. This is very slow
|
||||
and we don't attempt to make it fast. Only use for debugging/testing.
|
||||
|
||||
Shape should be like `[*, q_seqlen, k_seqlen]`
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
class LowerTriangularMask(AttentionBias):
|
||||
"""
|
||||
A lower-triangular (aka causal) mask
|
||||
|
||||
A query Q cannot attend to a key which is farther from the
|
||||
initial key than Q is from the initial query.
|
||||
"""
|
||||
|
||||
def __init__(self, *tensor_args, **tensor_kwargs) -> None:
|
||||
# NOTE: Unused arguments, we keep them for backward compatibility
|
||||
super().__init__()
|
||||
|
||||
def materialize(
|
||||
self,
|
||||
shape: Tuple[int, ...],
|
||||
dtype: torch.dtype = torch.float32,
|
||||
device: Union[str, torch.device] = "cpu",
|
||||
) -> torch.Tensor:
|
||||
create_as = dtype if dtype is not torch.bfloat16 else torch.float32
|
||||
tensor = torch.full( # type: ignore
|
||||
shape,
|
||||
dtype=create_as,
|
||||
fill_value=float("-inf"),
|
||||
device=device,
|
||||
)
|
||||
return torch.triu(tensor, diagonal=1).to(dtype) # type: ignore
|
||||
|
||||
def add_bias(self, bias: torch.Tensor) -> "LowerTriangularMaskWithTensorBias":
|
||||
return LowerTriangularMaskWithTensorBias(bias)
|
||||
|
||||
|
||||
class LowerTriangularMaskWithTensorBias(LowerTriangularMask):
|
||||
"""A lower-triangular (aka causal) mask with an additive bias"""
|
||||
|
||||
def __init__(self, bias: torch.Tensor) -> None:
|
||||
self._bias = bias
|
||||
|
||||
def materialize(
|
||||
self,
|
||||
shape: Tuple[int, ...],
|
||||
dtype: torch.dtype = torch.float32,
|
||||
device: Union[str, torch.device] = "cpu",
|
||||
) -> torch.Tensor:
|
||||
return super().materialize(shape, dtype=dtype, device=device) + self._bias
|
||||
|
||||
|
||||
@dataclass
|
||||
class _SeqLenInfo:
|
||||
"""
|
||||
(Internal) Represents the division of a dimension into blocks.
|
||||
|
||||
For example, to represents a dimension of length 7 divided into
|
||||
three blocks of lengths 2, 3 and 2, use `from_seqlength([2, 3, 2])`.
|
||||
The members will be:
|
||||
max_seqlen: 3
|
||||
min_seqlen: 2
|
||||
seqstart_py: [0, 2, 5, 7]
|
||||
seqstart: torch.IntTensor([0, 2, 5, 7])
|
||||
"""
|
||||
|
||||
seqstart: torch.Tensor
|
||||
max_seqlen: int
|
||||
min_seqlen: int
|
||||
seqstart_py: List[int]
|
||||
|
||||
def to(self, device: torch.device) -> None:
|
||||
self.seqstart = self.seqstart.to(device, non_blocking=True)
|
||||
|
||||
def intervals(self) -> Iterable[Tuple[int, int]]:
|
||||
yield from zip(self.seqstart_py, self.seqstart_py[1:])
|
||||
|
||||
@classmethod
|
||||
def from_seqlens(cls, seqlens: Iterable[int]) -> "_SeqLenInfo":
|
||||
"""
|
||||
Input tensors are assumed to be in shape [B, M, *]
|
||||
"""
|
||||
assert not isinstance(seqlens, torch.Tensor)
|
||||
seqstart_py = [0]
|
||||
max_seqlen = -1
|
||||
min_seqlen = -1
|
||||
for seqlen in seqlens:
|
||||
min_seqlen = min(min_seqlen, seqlen) if min_seqlen != -1 else seqlen
|
||||
max_seqlen = max(max_seqlen, seqlen)
|
||||
seqstart_py.append(seqstart_py[len(seqstart_py) - 1] + seqlen)
|
||||
seqstart = torch.tensor(seqstart_py, dtype=torch.int32)
|
||||
return cls(
|
||||
max_seqlen=max_seqlen,
|
||||
min_seqlen=min_seqlen,
|
||||
seqstart=seqstart,
|
||||
seqstart_py=seqstart_py,
|
||||
)
|
||||
|
||||
def split(
|
||||
self, x: torch.Tensor, batch_sizes: Optional[Sequence[int]] = None
|
||||
) -> List[torch.Tensor]:
|
||||
if self.seqstart_py[-1] != x.shape[1] or x.shape[0] != 1:
|
||||
raise ValueError(
|
||||
f"Invalid `torch.Tensor` of shape {x.shape}, expected format "
|
||||
f"(B, M, *) with B=1 and M={self.seqstart_py[-1]}\n"
|
||||
f" seqstart: {self.seqstart_py}"
|
||||
)
|
||||
if batch_sizes is None:
|
||||
batch_sizes = [1] * (len(self.seqstart_py) - 1)
|
||||
split_chunks = []
|
||||
it = 0
|
||||
for batch_size in batch_sizes:
|
||||
split_chunks.append(
|
||||
self.seqstart_py[it + batch_size] - self.seqstart_py[it]
|
||||
)
|
||||
it += batch_size
|
||||
return [
|
||||
tensor.reshape([bs, -1, *tensor.shape[2:]])
|
||||
for bs, tensor in zip(batch_sizes, x.split(split_chunks, dim=1))
|
||||
]
|
||||
|
||||
|
||||
@dataclass
|
||||
class _PaddedSeqLenInfo(_SeqLenInfo):
|
||||
"""
|
||||
(Internal) Represents the division of a dimension into blocks which are
|
||||
padded out to the same total length.
|
||||
|
||||
For example, to represent a dimension of length 12 with space for
|
||||
three blocks of length 4, but where the occupied lengths are
|
||||
2, 3 and 2, use `from_seqlens_padded([2, 3, 2], 4)`.
|
||||
|
||||
The layout along the dimension is
|
||||
|
||||
0 ─► block 0
|
||||
block 0
|
||||
<space>
|
||||
<space>
|
||||
4 ─► block 1
|
||||
block 1
|
||||
block 1
|
||||
<space>
|
||||
8 ─► block 2
|
||||
block 2
|
||||
<space>
|
||||
<space>
|
||||
12 ─►
|
||||
|
||||
The members will be:
|
||||
max_seqlen: 3
|
||||
min_seqlen: 2
|
||||
seqstart_py: [0, 4, 8, 12]
|
||||
seqstart: torch.IntTensor([0, 4, 8, 12])
|
||||
seqlen_py: [2, 3, 2]
|
||||
seqlen: torch.IntTensor([2, 3, 2])
|
||||
padding: 4
|
||||
"""
|
||||
|
||||
seqlen: torch.Tensor
|
||||
seqlen_py: Sequence[int]
|
||||
padding: int
|
||||
# From parent: seqstart[i] contains the start position
|
||||
# of the i-th sequence
|
||||
# seqstart: torch.Tensor
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
assert len(self.seqstart_py) == len(self.seqlen_py) + 1
|
||||
|
||||
def to(self, device: torch.device) -> None:
|
||||
self.seqlen = self.seqlen.to(device, non_blocking=True)
|
||||
super().to(device)
|
||||
|
||||
def intervals(self) -> Iterable[Tuple[int, int]]:
|
||||
for (start, _), length in zip(super().intervals(), self.seqlen_py):
|
||||
yield start, start + length
|
||||
|
||||
@classmethod
|
||||
def from_seqlens(cls, seqlens: Iterable[int]) -> "_SeqLenInfo":
|
||||
raise RuntimeError(
|
||||
"Use either `_SeqLenInfo.from_seqlens` or `_PaddedSeqLenInfo.from_seqlens_padded`"
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_seqlens_padded(
|
||||
cls, seqlens: Sequence[int], padding: int
|
||||
) -> "_PaddedSeqLenInfo":
|
||||
"""
|
||||
Input tensors are assumed to be in shape [B, M, *]
|
||||
seqstart = padding * torch.arange(batch_size)
|
||||
"""
|
||||
assert not isinstance(seqlens, torch.Tensor)
|
||||
assert all(seqlen <= padding for seqlen in seqlens)
|
||||
seqstart_py = list(range(0, len(seqlens) * padding + 1, padding))
|
||||
return cls(
|
||||
seqlen=torch.tensor(seqlens, dtype=torch.int32),
|
||||
seqlen_py=seqlens,
|
||||
max_seqlen=max(seqlens),
|
||||
min_seqlen=min(seqlens),
|
||||
seqstart=torch.tensor(seqstart_py, dtype=torch.int32),
|
||||
seqstart_py=seqstart_py,
|
||||
padding=padding,
|
||||
)
|
||||
|
||||
def split(
|
||||
self, x: torch.Tensor, batch_sizes: Optional[Sequence[int]] = None
|
||||
) -> List[torch.Tensor]:
|
||||
raise NotImplementedError("_PaddedSeqLenInfo.split")
|
||||
|
||||
|
||||
@dataclass
|
||||
class BlockDiagonalMask(AttentionBias):
|
||||
"""
|
||||
A block-diagonal mask that can be passed as ``attn_bias``
|
||||
argument to :attr:`xformers.ops.memory_efficient_attention`.
|
||||
|
||||
Queries and Keys are each divided into the same number of blocks.
|
||||
Queries in block i only attend to keys in block i.
|
||||
|
||||
.. figure:: /_static/block_diag_bias.png
|
||||
|
||||
This bias can be used to handle a batch of sequences of
|
||||
different lengths, via :attr:`BlockDiagonalMask.from_tensor_list`
|
||||
|
||||
:Example:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
import torch
|
||||
from xformers.ops import fmha
|
||||
|
||||
K = 16
|
||||
dtype = torch.float16
|
||||
device = "cuda"
|
||||
list_x = [
|
||||
torch.randn([1, 3, 1, K], dtype=dtype, device=device),
|
||||
torch.randn([1, 6, 1, K], dtype=dtype, device=device),
|
||||
torch.randn([1, 2, 1, K], dtype=dtype, device=device),
|
||||
]
|
||||
attn_bias, x = fmha.BlockDiagonalMask.from_tensor_list(list_x)
|
||||
linear = torch.nn.Linear(K, K * 3).to(device=device, dtype=dtype)
|
||||
|
||||
q, k, v = linear(x).reshape([1, -1, 1, 3, K]).unbind(-2)
|
||||
out = fmha.memory_efficient_attention(q, k, v, attn_bias=attn_bias)
|
||||
list_out = attn_bias.split(out)
|
||||
print(list_out[0].shape) # [1, 3, 1, K]
|
||||
assert tuple(list_out[0].shape) == (1, 3, 1, K)
|
||||
|
||||
"""
|
||||
|
||||
q_seqinfo: _SeqLenInfo
|
||||
k_seqinfo: _SeqLenInfo
|
||||
_batch_sizes: Optional[Sequence[int]] = None
|
||||
|
||||
def _create_block_mask(
|
||||
self,
|
||||
shape: Tuple[int, ...],
|
||||
dtype: torch.dtype = torch.float32,
|
||||
device: Union[str, torch.device] = "cpu",
|
||||
) -> torch.Tensor:
|
||||
return torch.zeros(
|
||||
shape,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
)
|
||||
|
||||
def materialize(
|
||||
self,
|
||||
shape: Tuple[int, ...],
|
||||
dtype: torch.dtype = torch.float32,
|
||||
device: Union[str, torch.device] = "cpu",
|
||||
) -> torch.Tensor:
|
||||
"""Materialize the attention bias - for debugging & testing"""
|
||||
assert shape[-1] == self.k_seqinfo.seqstart_py[-1], (
|
||||
shape[-1],
|
||||
self.k_seqinfo.seqstart_py[-1],
|
||||
)
|
||||
assert shape[-2] == self.q_seqinfo.seqstart_py[-1], (
|
||||
shape[-2],
|
||||
self.q_seqinfo.seqstart_py[-1],
|
||||
)
|
||||
mask = torch.empty(shape[-2:], dtype=dtype, device=device)
|
||||
mask.fill_(-math.inf)
|
||||
for i, ((q_start, q_end), (k_start, k_end)) in enumerate(
|
||||
zip(
|
||||
self.q_seqinfo.intervals(),
|
||||
self.k_seqinfo.intervals(),
|
||||
)
|
||||
):
|
||||
mask[q_start:q_end, k_start:k_end] = self._create_block_mask(
|
||||
(q_end - q_start, k_end - k_start),
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
)
|
||||
for _ in range(len(shape) - 2):
|
||||
mask = mask.unsqueeze(0)
|
||||
return mask.expand(shape)
|
||||
|
||||
@classmethod
|
||||
def from_seqlens(
|
||||
cls,
|
||||
q_seqlen: Sequence[int],
|
||||
kv_seqlen: Optional[Sequence[int]] = None,
|
||||
) -> "BlockDiagonalMask":
|
||||
"""Creates a :attr:`BlockDiagonalMask` from a list of tensors lengths for query and key/value.
|
||||
|
||||
Args:
|
||||
q_seqlen (Union[Sequence[int], torch.Tensor]): List or tensor of sequence lengths for query tensors
|
||||
kv_seqlen (Union[Sequence[int], torch.Tensor], optional): List or tensor of sequence lengths for key/value.
|
||||
(Defaults to ``q_seqlen``.)
|
||||
Returns:
|
||||
BlockDiagonalMask
|
||||
"""
|
||||
assert kv_seqlen is None or len(q_seqlen) == len(kv_seqlen)
|
||||
q_seqinfo = _SeqLenInfo.from_seqlens(q_seqlen)
|
||||
if kv_seqlen is None or q_seqlen == kv_seqlen:
|
||||
k_seqinfo = q_seqinfo
|
||||
else:
|
||||
k_seqinfo = _SeqLenInfo.from_seqlens(kv_seqlen)
|
||||
return cls(q_seqinfo=q_seqinfo, k_seqinfo=k_seqinfo)
|
||||
|
||||
@classmethod
|
||||
def from_tensor_list(
|
||||
cls,
|
||||
tensors: Sequence[torch.Tensor],
|
||||
) -> Tuple["BlockDiagonalMask", torch.Tensor]:
|
||||
"""Creates a :attr:`BlockDiagonalMask` from a list of tensors, and returns the tensors
|
||||
concatenated on the sequence length dimension
|
||||
|
||||
.. figure:: /_static/block_diag_cat_split.png
|
||||
|
||||
See also :attr:`BlockDiagonalMask.split` to split the returned
|
||||
:attr:`torch.Tensor` back to a list of tensors of varying sequence length
|
||||
|
||||
Args:
|
||||
tensors (Sequence[torch.Tensor]): A list of tensors of shape ``[B, M_i, *]``.
|
||||
All tensors should have the same dimension and the same batch size ``B``, but
|
||||
they can have different sequence length ``M``.
|
||||
|
||||
Returns:
|
||||
Tuple[BlockDiagonalMask, torch.Tensor]: The corresponding bias for the attention
|
||||
along with `tensors` concatenated on the sequence length dimension, with shape ``[1, sum_i{M_i}, *]``
|
||||
"""
|
||||
batch_sizes = [tensor.shape[0] for tensor in tensors]
|
||||
seqlens = []
|
||||
for x in tensors:
|
||||
for _ in range(x.shape[0]):
|
||||
seqlens.append(x.shape[1])
|
||||
block_diag = cls.from_seqlens(seqlens)
|
||||
block_diag._batch_sizes = batch_sizes
|
||||
tensors_bs1 = tuple(x.reshape([1, -1, *x.shape[2:]]) for x in tensors)
|
||||
concat_tensors = torch.cat(tensors_bs1, dim=1)
|
||||
return block_diag, concat_tensors
|
||||
|
||||
@classmethod
|
||||
def from_tensor_lists_qkv(
|
||||
cls,
|
||||
tensors_q: Sequence[torch.Tensor],
|
||||
tensors_k: Sequence[torch.Tensor],
|
||||
tensors_v: Optional[Sequence[torch.Tensor]] = None,
|
||||
) -> Tuple["BlockDiagonalMask", torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
|
||||
assert len(tensors_q) == len(tensors_k)
|
||||
assert tensors_v is None or len(tensors_v) == len(tensors_q)
|
||||
batch_sizes = [tensor.shape[0] for tensor in tensors_q]
|
||||
q_seqlens, kv_seqlens = [], []
|
||||
for i, (q, k) in enumerate(zip(tensors_q, tensors_k)):
|
||||
assert q.shape[0] == k.shape[0]
|
||||
q_seqlens += [q.shape[1]] * q.shape[0]
|
||||
kv_seqlens += [k.shape[1]] * k.shape[0]
|
||||
assert tensors_v is None or tensors_v[i].shape[:2] == k.shape[:2]
|
||||
block_diag = cls.from_seqlens(q_seqlens, kv_seqlens)
|
||||
block_diag._batch_sizes = batch_sizes
|
||||
return (
|
||||
block_diag,
|
||||
torch.cat([x.reshape([1, -1, *x.shape[2:]]) for x in tensors_q], dim=1),
|
||||
torch.cat([x.reshape([1, -1, *x.shape[2:]]) for x in tensors_k], dim=1),
|
||||
torch.cat([x.reshape([1, -1, *x.shape[2:]]) for x in tensors_v], dim=1)
|
||||
if tensors_v is not None
|
||||
else None,
|
||||
)
|
||||
|
||||
def split_queries(self, tensor: torch.Tensor) -> Sequence[torch.Tensor]:
|
||||
return self.q_seqinfo.split(tensor, self._batch_sizes)
|
||||
|
||||
def split_kv(self, tensor: torch.Tensor) -> Sequence[torch.Tensor]:
|
||||
return self.k_seqinfo.split(tensor, self._batch_sizes)
|
||||
|
||||
def split(self, tensor: torch.Tensor) -> Sequence[torch.Tensor]:
|
||||
"""The inverse operation of :attr:`BlockDiagonalCausalMask.from_tensor_list`
|
||||
|
||||
Args:
|
||||
tensor (torch.Tensor): Tensor of tokens of shape ``[1, sum_i{M_i}, *]``
|
||||
|
||||
Returns:
|
||||
Sequence[torch.Tensor]: A list of tokens with possibly different sequence lengths
|
||||
"""
|
||||
assert self.q_seqinfo is self.k_seqinfo
|
||||
return self.q_seqinfo.split(tensor, self._batch_sizes)
|
||||
|
||||
def make_causal(self) -> "BlockDiagonalCausalMask":
|
||||
"""Makes each block causal"""
|
||||
return BlockDiagonalCausalMask(
|
||||
q_seqinfo=self.q_seqinfo,
|
||||
k_seqinfo=self.k_seqinfo,
|
||||
_batch_sizes=self._batch_sizes,
|
||||
)
|
||||
|
||||
def make_causal_from_bottomright(self) -> "BlockDiagonalCausalFromBottomRightMask":
|
||||
"""Makes each block causal with a possible non-causal prefix"""
|
||||
return BlockDiagonalCausalFromBottomRightMask(
|
||||
q_seqinfo=self.q_seqinfo,
|
||||
k_seqinfo=self.k_seqinfo,
|
||||
_batch_sizes=self._batch_sizes,
|
||||
)
|
||||
|
||||
def make_local_attention(
|
||||
self, window_size: int
|
||||
) -> "BlockDiagonalCausalLocalAttentionMask":
|
||||
"""Experimental: Makes each block causal with local attention"""
|
||||
return BlockDiagonalCausalLocalAttentionMask(
|
||||
q_seqinfo=self.q_seqinfo,
|
||||
k_seqinfo=self.k_seqinfo,
|
||||
_batch_sizes=self._batch_sizes,
|
||||
_window_size=window_size,
|
||||
)
|
||||
|
||||
def make_local_attention_from_bottomright(
|
||||
self, window_size: int
|
||||
) -> "BlockDiagonalCausalLocalAttentionFromBottomRightMask":
|
||||
"""Experimental: Makes each block causal with local attention, start from bottom right"""
|
||||
return BlockDiagonalCausalLocalAttentionFromBottomRightMask(
|
||||
q_seqinfo=self.q_seqinfo,
|
||||
k_seqinfo=self.k_seqinfo,
|
||||
_batch_sizes=self._batch_sizes,
|
||||
_window_size=window_size,
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class BlockDiagonalCausalMask(BlockDiagonalMask):
|
||||
"""
|
||||
Same as :attr:`xformers.ops.fmha.attn_bias.BlockDiagonalMask`, except that each block is causal.
|
||||
|
||||
Queries and Keys are each divided into the same number of blocks.
|
||||
A query Q in block i cannot attend to a key which is not in block i,
|
||||
nor one which is farther from the initial key in block i than Q
|
||||
is from the initial query in block i.
|
||||
"""
|
||||
|
||||
def _create_block_mask(
|
||||
self,
|
||||
shape: Tuple[int, ...],
|
||||
dtype: torch.dtype = torch.float32,
|
||||
device: Union[str, torch.device] = "cpu",
|
||||
) -> torch.Tensor:
|
||||
return LowerTriangularMask().materialize(
|
||||
shape,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class BlockDiagonalCausalFromBottomRightMask(BlockDiagonalMask):
|
||||
"""
|
||||
Same as :attr:`xformers.ops.fmha.attn_bias.BlockDiagonalMask`, except that each block is causal.
|
||||
This mask allows for a non-causal prefix
|
||||
NOTE: Each block should have `num_keys >= num_queries` otherwise the forward pass is not
|
||||
defined (softmax of vector of `-inf` in the attention)
|
||||
|
||||
Queries and keys are each divided into the same number of blocks.
|
||||
A query Q in block i cannot attend to a key which is not in block i,
|
||||
nor one which nearer the final key in block i than Q is to the
|
||||
final query in block i.
|
||||
"""
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
for i, ((q_start, q_end), (k_start, k_end)) in enumerate(
|
||||
zip(
|
||||
self.q_seqinfo.intervals(),
|
||||
self.k_seqinfo.intervals(),
|
||||
)
|
||||
):
|
||||
num_queries = q_end - q_start
|
||||
num_keys = k_end - k_start
|
||||
if num_keys < num_queries:
|
||||
raise ValueError(
|
||||
f"Block #{i} has num_keys={num_keys} and num_queries={num_queries}."
|
||||
" Expected `num_keys >= num_queries`"
|
||||
)
|
||||
|
||||
def _create_block_mask(
|
||||
self,
|
||||
shape: Tuple[int, ...],
|
||||
dtype: torch.dtype = torch.float32,
|
||||
device: Union[str, torch.device] = "cpu",
|
||||
) -> torch.Tensor:
|
||||
create_as = dtype if dtype is not torch.bfloat16 else torch.float32
|
||||
tensor = torch.full( # type: ignore
|
||||
shape,
|
||||
dtype=create_as,
|
||||
fill_value=float("-inf"),
|
||||
device=device,
|
||||
)
|
||||
num_queries, num_keys = shape[-2:]
|
||||
return torch.triu(tensor, diagonal=num_keys - num_queries + 1).to(dtype) # type: ignore
|
||||
|
||||
|
||||
@dataclass
|
||||
class BlockDiagonalCausalWithOffsetPaddedKeysMask(AttentionBias):
|
||||
"""
|
||||
Same as :attr:`xformers.ops.fmha.attn_bias.BlockDiagonalCausalMask`,
|
||||
except an offset on causality is allowed for each block and we support padding for k/v
|
||||
|
||||
The keys and values are divided into blocks which are padded out to
|
||||
the same total length.
|
||||
For example, if there is space for 12 keys, for three blocks of
|
||||
max length 4, but we only want to use the first 2, 3 and 2
|
||||
of each block, use `kv_padding=4` and `kv_seqlens=[2, 3, 2]`.
|
||||
The queries are divided into blocks, without padding, of lengths given by
|
||||
q_seqlen.
|
||||
|
||||
A query Q in block i cannot attend to a key which is not in block i,
|
||||
nor one which is not in use (i.e. in the padded area),
|
||||
nor one which is nearer to the final key in block i
|
||||
than Q is to the final query in block i.
|
||||
"""
|
||||
|
||||
q_seqinfo: _SeqLenInfo
|
||||
k_seqinfo: _PaddedSeqLenInfo
|
||||
causal_diagonal: Any = None # unused. Exists for BC only.
|
||||
|
||||
def _create_block_mask(
|
||||
self,
|
||||
shape: Tuple[int, ...],
|
||||
dtype: torch.dtype = torch.float32,
|
||||
device: Union[str, torch.device] = "cpu",
|
||||
) -> torch.Tensor:
|
||||
create_as = dtype if dtype is not torch.bfloat16 else torch.float32
|
||||
tensor = torch.full( # type: ignore
|
||||
shape,
|
||||
dtype=create_as,
|
||||
fill_value=float("-inf"),
|
||||
device=device,
|
||||
)
|
||||
num_queries, num_keys = shape[-2:]
|
||||
return torch.triu(tensor, diagonal=1 + num_keys - num_queries).to(dtype) # type: ignore
|
||||
|
||||
def materialize(
|
||||
self,
|
||||
shape: Tuple[int, ...],
|
||||
dtype: torch.dtype = torch.float32,
|
||||
device: Union[str, torch.device] = "cpu",
|
||||
) -> torch.Tensor:
|
||||
"""Materialize the attention bias - for debugging & testing"""
|
||||
if shape[-1] != self.k_seqinfo.seqstart_py[-1]:
|
||||
raise ValueError("k shapes wrong")
|
||||
if shape[-2] != self.q_seqinfo.seqstart_py[-1]:
|
||||
raise ValueError("q shapes wrong")
|
||||
mask = torch.empty(shape[-2:], dtype=dtype, device=device)
|
||||
mask.fill_(-math.inf)
|
||||
for i, ((q_start, q_end), (k_start, k_end)) in enumerate(
|
||||
zip(
|
||||
self.q_seqinfo.intervals(),
|
||||
self.k_seqinfo.intervals(),
|
||||
)
|
||||
):
|
||||
mask[q_start:q_end, k_start:k_end] = self._create_block_mask(
|
||||
(q_end - q_start, k_end - k_start),
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
)
|
||||
for _ in range(len(shape) - 2):
|
||||
mask = mask.unsqueeze(0)
|
||||
return mask.expand(shape)
|
||||
|
||||
@classmethod
|
||||
def from_seqlens(
|
||||
cls,
|
||||
q_seqlen: Sequence[int],
|
||||
kv_padding: int,
|
||||
kv_seqlen: Sequence[int],
|
||||
causal_diagonal: Any = None,
|
||||
) -> "BlockDiagonalCausalWithOffsetPaddedKeysMask":
|
||||
"""Creates a :attr:`BlockDiagonalCausalWithOffsetPaddedKeysMask` from a list of tensor
|
||||
lengths for query and key/value.
|
||||
|
||||
Args:
|
||||
q_seqlen (Sequence[int]): List or tensor of sequence lengths for query tensors
|
||||
kv_padding (int): Padding for k/v - also an upperbound on each individual key length
|
||||
kv_seqlen (Sequence[int]): List or tensor of sequence lengths for key/value.
|
||||
causal_diagonal: unused, for BC only
|
||||
Returns:
|
||||
BlockDiagonalCausalWithOffsetPaddedKeysMask
|
||||
"""
|
||||
assert kv_seqlen is None or len(q_seqlen) == len(kv_seqlen), (
|
||||
q_seqlen,
|
||||
kv_seqlen,
|
||||
)
|
||||
q_seqinfo = _SeqLenInfo.from_seqlens(q_seqlen)
|
||||
k_seqinfo = _PaddedSeqLenInfo.from_seqlens_padded(kv_seqlen, kv_padding)
|
||||
return cls(q_seqinfo=q_seqinfo, k_seqinfo=k_seqinfo)
|
||||
|
||||
|
||||
@dataclass
|
||||
class BlockDiagonalCausalLocalAttentionMask(BlockDiagonalCausalMask):
|
||||
"""
|
||||
(Experimental feature)
|
||||
Same as :attr:`xformers.ops.fmha.attn_bias.BlockDiagonalCausalMask`.
|
||||
This makes the mask "local" and the attention pattern banded.
|
||||
|
||||
Query i only attends to keys in its block and cannot attend keys further than "window_size"
|
||||
from it.
|
||||
"""
|
||||
|
||||
_window_size: int = 0 # forced due to inheritance and default arguments
|
||||
|
||||
def __post_init__(self):
|
||||
if self._window_size <= 0:
|
||||
raise ValueError(
|
||||
f"Expected `window_size > 0`, but window_size={self._window_size}"
|
||||
)
|
||||
q_seqlen = [
|
||||
y - x
|
||||
for x, y in zip(
|
||||
self.q_seqinfo.seqstart_py[:-1], self.q_seqinfo.seqstart_py[1:]
|
||||
)
|
||||
]
|
||||
kv_seqlen = [
|
||||
y - x
|
||||
for x, y in zip(
|
||||
self.k_seqinfo.seqstart_py[:-1], self.k_seqinfo.seqstart_py[1:]
|
||||
)
|
||||
]
|
||||
for q, k in zip(q_seqlen, kv_seqlen):
|
||||
if q - self._window_size >= k:
|
||||
raise RuntimeError(
|
||||
f"No keys are attended in q_seqlen {q} k_seqlen {k} with sliding window {self._window_size}"
|
||||
)
|
||||
|
||||
def _create_block_mask(
|
||||
self,
|
||||
shape: Tuple[int, ...],
|
||||
dtype: torch.dtype = torch.float32,
|
||||
device: Union[str, torch.device] = "cpu",
|
||||
) -> torch.Tensor:
|
||||
create_as = dtype if dtype is not torch.bfloat16 else torch.float32
|
||||
tensor = torch.full( # type: ignore
|
||||
shape,
|
||||
dtype=create_as,
|
||||
fill_value=1,
|
||||
device=device,
|
||||
)
|
||||
|
||||
num_queries, num_keys = shape[-2:]
|
||||
mask = torch.tril(tensor, diagonal=0).to(dtype) # type: ignore
|
||||
if self._window_size is not None and self._window_size > 0:
|
||||
mask = torch.triu(mask, diagonal=-self._window_size + 1)
|
||||
mask = torch.log(mask)
|
||||
return mask.to(dtype)
|
||||
|
||||
|
||||
@dataclass
|
||||
class BlockDiagonalCausalLocalAttentionFromBottomRightMask(
|
||||
BlockDiagonalCausalFromBottomRightMask
|
||||
):
|
||||
"""
|
||||
(Experimental feature)
|
||||
Same as :attr:`xformers.ops.fmha.attn_bias.BlockDiagonalCausalMask`.
|
||||
This makes the mask "local" and the attention pattern banded.
|
||||
|
||||
Query i only attends to keys in its block and cannot attend keys further than "window_size"
|
||||
from it.
|
||||
"""
|
||||
|
||||
_window_size: int = 0 # forced due to inheritance and default arguments
|
||||
|
||||
def __post_init__(self):
|
||||
super().__post_init__()
|
||||
if self._window_size <= 0:
|
||||
raise ValueError(
|
||||
f"Expected `window_size > 0`, but window_size={self._window_size}"
|
||||
)
|
||||
q_seqlen = [
|
||||
y - x
|
||||
for x, y in zip(
|
||||
self.q_seqinfo.seqstart_py[:-1], self.q_seqinfo.seqstart_py[1:]
|
||||
)
|
||||
]
|
||||
kv_seqlen = [
|
||||
y - x
|
||||
for x, y in zip(
|
||||
self.k_seqinfo.seqstart_py[:-1], self.k_seqinfo.seqstart_py[1:]
|
||||
)
|
||||
]
|
||||
for q, k in zip(q_seqlen, kv_seqlen):
|
||||
if q + (q - k) - self._window_size >= k:
|
||||
raise RuntimeError(
|
||||
f"No keys are attended in q_seqlen {q} k_seqlen {k} with sliding window {self._window_size}"
|
||||
)
|
||||
materialized = self.materialize((sum(q_seqlen), sum(kv_seqlen)))
|
||||
if torch.max(materialized, dim=1).values.min() == -float("inf"):
|
||||
raise RuntimeError("FUCKING FUCK FUCK")
|
||||
|
||||
def _create_block_mask(
|
||||
self,
|
||||
shape: Tuple[int, ...],
|
||||
dtype: torch.dtype = torch.float32,
|
||||
device: Union[str, torch.device] = "cpu",
|
||||
) -> torch.Tensor:
|
||||
create_as = dtype if dtype is not torch.bfloat16 else torch.float32
|
||||
tensor = torch.full( # type: ignore
|
||||
shape,
|
||||
dtype=create_as,
|
||||
fill_value=1,
|
||||
device=device,
|
||||
)
|
||||
num_queries, num_keys = shape[-2:]
|
||||
mask = torch.tril(tensor, diagonal=num_keys - num_queries).to(dtype) # type: ignore
|
||||
if self._window_size is not None:
|
||||
mask = torch.triu(
|
||||
mask, diagonal=num_keys - num_queries - self._window_size + 1
|
||||
)
|
||||
mask = torch.log(mask)
|
||||
return mask.to(dtype)
|
||||
Reference in New Issue
Block a user