First commit
This commit is contained in:
342
pkgs/xformers/components/attention/core.py
Normal file
342
pkgs/xformers/components/attention/core.py
Normal file
@@ -0,0 +1,342 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the BSD license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
|
||||
import logging
|
||||
import math
|
||||
from contextlib import nullcontext
|
||||
from functools import lru_cache
|
||||
from typing import Optional, Union
|
||||
|
||||
import torch
|
||||
|
||||
from xformers import _has_cpp_library, _is_triton_available
|
||||
from xformers.components.attention.attention_mask import AttentionMask
|
||||
|
||||
if _has_cpp_library:
|
||||
from ._sputnik_sparse import SparseCS
|
||||
|
||||
if _is_triton_available():
|
||||
from xformers.triton.softmax import softmax as triton_softmax
|
||||
from xformers.triton.utils import gpu_capabilities_older_than_70
|
||||
|
||||
_is_blocksparse_available = (
|
||||
_is_triton_available() and not gpu_capabilities_older_than_70()
|
||||
)
|
||||
|
||||
if _is_blocksparse_available:
|
||||
from xformers.components.attention.blocksparse import BlockSparseAttention
|
||||
|
||||
|
||||
logger = logging.getLogger("xformers")
|
||||
|
||||
|
||||
def _create_random_sparsity(matrix, sparsity, divisible_by=4):
|
||||
assert matrix.ndim == 3
|
||||
keep = torch.rand_like(matrix[0], dtype=torch.float32) > sparsity
|
||||
nonzero = torch.nonzero(keep)
|
||||
nnz = nonzero.shape[0]
|
||||
# NOTE: need to make it a multiple of 4 for sputnik
|
||||
nonzero = nonzero[: (nnz - nnz % divisible_by)]
|
||||
i, j = nonzero.unbind(1)
|
||||
output = torch.zeros_like(matrix)
|
||||
bdim = torch.arange(matrix.shape[0], device=matrix.device)[:, None]
|
||||
output[bdim, i, j] = matrix[bdim, i, j]
|
||||
return output
|
||||
|
||||
|
||||
def _broadcast_batch(mask, batch_size):
|
||||
if mask.ndim == 3:
|
||||
return mask
|
||||
assert mask.ndim == 2
|
||||
|
||||
mask = mask.coalesce()
|
||||
values = mask.values()
|
||||
indices = mask.indices()
|
||||
nnz = len(values)
|
||||
# strategy: repeat the indices and append the extra batch dimension to the indices
|
||||
indices = indices.repeat(1, batch_size)
|
||||
# now create the batch indices
|
||||
batch_indices = torch.arange(batch_size, device=indices.device)
|
||||
batch_indices = batch_indices[:, None].expand(batch_size, nnz).flatten()
|
||||
|
||||
# put them together
|
||||
indices = torch.cat([batch_indices[None, :], indices], dim=0)
|
||||
|
||||
# now repeat the values
|
||||
values = values.repeat(batch_size)
|
||||
|
||||
size = (batch_size,) + mask.shape
|
||||
|
||||
return torch.sparse_coo_tensor(indices, values, size)
|
||||
|
||||
|
||||
def _matmul_with_mask(
|
||||
a: torch.Tensor,
|
||||
b: torch.Tensor,
|
||||
mask: Optional[Union[torch.Tensor, "SparseCS"]],
|
||||
) -> torch.Tensor:
|
||||
if mask is None:
|
||||
return a @ b
|
||||
|
||||
if _has_cpp_library and mask.dtype == torch.bool:
|
||||
if isinstance(mask, SparseCS):
|
||||
return mask.matmul_with_mask(a, b)
|
||||
if mask.is_sparse:
|
||||
# perform broadcasting if needed
|
||||
mask = _broadcast_batch(mask, a.shape[0])
|
||||
|
||||
# coalesced is not implemented for bool tensors, so need to cast
|
||||
mask = mask.to(dtype=a.dtype) # type: ignore # mypy is missing the catch above
|
||||
|
||||
return torch.ops.xformers.matmul_with_mask(a, b, mask)
|
||||
|
||||
# Non optimized codepath
|
||||
if _has_cpp_library:
|
||||
assert not isinstance(mask, SparseCS)
|
||||
|
||||
att = a @ b
|
||||
if mask.dtype == torch.bool:
|
||||
assert not isinstance(mask, SparseCS)
|
||||
if mask.ndim == 2:
|
||||
mask = mask.unsqueeze(0).expand(att.shape[0], -1, -1)
|
||||
# mask is presumed false == ignore
|
||||
att[~mask] = float("-inf")
|
||||
else:
|
||||
# mask is presumed additive
|
||||
# repeat if batch sizes don't match
|
||||
if (
|
||||
not isinstance(mask, SparseCS)
|
||||
and mask.ndim == 3
|
||||
and mask.shape[0] != att.shape[0]
|
||||
and (att.shape[0] % mask.shape[0]) == 0
|
||||
):
|
||||
repeat_factor = att.shape[0] // mask.shape[0]
|
||||
mask = mask.repeat([repeat_factor, 1, 1])
|
||||
logger.info("Mismatched batch dimensions for mask, repeating mask.")
|
||||
att += mask
|
||||
return att
|
||||
|
||||
|
||||
def _softmax(a: torch.Tensor, causal: bool = False) -> torch.Tensor:
|
||||
if _has_cpp_library and isinstance(a, SparseCS):
|
||||
return a.softmax()
|
||||
|
||||
if a.is_sparse:
|
||||
return torch.sparse.softmax(a, dim=a.ndim - 1)
|
||||
|
||||
if _is_triton_available():
|
||||
return triton_softmax(a, mask=None, causal=causal)
|
||||
else:
|
||||
return torch.softmax(a, dim=a.ndim - 1)
|
||||
|
||||
|
||||
if _has_cpp_library:
|
||||
|
||||
class SparseBMM(torch.autograd.Function):
|
||||
@staticmethod
|
||||
def forward(ctx, a, b):
|
||||
a = a.coalesce()
|
||||
r = torch.bmm(a, b)
|
||||
ctx.save_for_backward(a, b)
|
||||
return r
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad):
|
||||
a, b = ctx.saved_tensors
|
||||
|
||||
# gradients w.r.t. a
|
||||
ga = None
|
||||
if ctx.needs_input_grad[0]:
|
||||
ga = torch.ops.xformers.matmul_with_mask(grad, b.transpose(-2, -1), a)
|
||||
|
||||
# gradients w.r.t. b
|
||||
gb = None
|
||||
if ctx.needs_input_grad[1]:
|
||||
gb = a.transpose(1, 2).bmm(grad)
|
||||
|
||||
return ga, gb
|
||||
|
||||
def _sparse_bmm(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Batch matrix multiply between a sparse matrix and a dense matrix
|
||||
"""
|
||||
assert a.ndim == b.ndim == 3
|
||||
assert a.shape[0] == b.shape[0]
|
||||
assert a.shape[2] == b.shape[1]
|
||||
return SparseBMM.apply(a, b)
|
||||
|
||||
|
||||
def bmm(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
|
||||
if _has_cpp_library:
|
||||
if isinstance(a, SparseCS):
|
||||
return a.spmm(b)
|
||||
if a.is_sparse:
|
||||
return _sparse_bmm(a, b)
|
||||
return a @ b
|
||||
|
||||
|
||||
def _apply_dropout(att, dropout):
|
||||
if dropout is None:
|
||||
return att
|
||||
|
||||
# Dropout chokes on sparse tensors
|
||||
if _has_cpp_library:
|
||||
if isinstance(att, SparseCS):
|
||||
values = att.values.clone()
|
||||
values = dropout(values)
|
||||
att = SparseCS.wrap(
|
||||
att.shape,
|
||||
values,
|
||||
att.row_indices,
|
||||
att.row_offsets,
|
||||
att.column_indices,
|
||||
att._transp_info,
|
||||
)
|
||||
elif att.is_sparse:
|
||||
att = att.coalesce()
|
||||
values = att.values().clone() # protect against in-place dropout
|
||||
values = dropout(values)
|
||||
att = torch.sparse_coo_tensor(att.indices(), values, att.shape)
|
||||
else:
|
||||
# Simple dense case
|
||||
att = dropout(att)
|
||||
|
||||
return att
|
||||
|
||||
# Non optimized vanilla dropout
|
||||
att = dropout(att)
|
||||
return att
|
||||
|
||||
|
||||
def scaled_query_key_softmax(
|
||||
q: torch.Tensor,
|
||||
k: torch.Tensor,
|
||||
att_mask: Optional[Union[AttentionMask, "SparseCS", torch.Tensor]],
|
||||
) -> torch.Tensor:
|
||||
# TODO assume we have (N, S, hs) instead of (B, nh, S, hs), with N = B x nh
|
||||
# this is needed due to limitations in sparse_bmm for now
|
||||
|
||||
# Self-attend: (N, S, hs) x (N, hs, S) -> (N, S, S)
|
||||
q = q / math.sqrt(k.size(-1))
|
||||
|
||||
# Matmul with mask
|
||||
if att_mask is not None and isinstance(att_mask, AttentionMask):
|
||||
# Additive mask
|
||||
mask: Optional[Union[SparseCS, torch.Tensor]] = att_mask.values
|
||||
else:
|
||||
mask = att_mask
|
||||
|
||||
att = _matmul_with_mask(q, k.transpose(-2, -1), mask)
|
||||
|
||||
# Softmax to get the attention probabilities
|
||||
is_causal = isinstance(att_mask, AttentionMask) and att_mask.is_causal
|
||||
att = _softmax(att, causal=is_causal)
|
||||
return att
|
||||
|
||||
|
||||
if _is_blocksparse_available:
|
||||
# 128 is default maxsize
|
||||
@lru_cache(maxsize=128)
|
||||
def _retrieve_blocksparse(
|
||||
num_heads: int, seq_len: int, block_size: int
|
||||
) -> BlockSparseAttention:
|
||||
# Checks if blocksparse object exists in cache
|
||||
|
||||
blocks = seq_len // block_size
|
||||
layout_fill = torch.ones((num_heads, blocks, blocks), dtype=torch.long)
|
||||
return BlockSparseAttention(
|
||||
layout=layout_fill, block_size=block_size, causal=True
|
||||
)
|
||||
|
||||
|
||||
def blocksparse_attention(
|
||||
q: torch.Tensor,
|
||||
k: torch.Tensor,
|
||||
v: torch.Tensor,
|
||||
dropout: Optional[torch.nn.Module] = None,
|
||||
block_size: int = 128,
|
||||
) -> torch.Tensor:
|
||||
|
||||
orig_dim = q.dim()
|
||||
seq_len = q.shape[-2]
|
||||
# Layout head dimension: 1 or batch size (q.shape[0])
|
||||
layout_heads = 1
|
||||
|
||||
# TODO perhaps add functionality to pad qkv if sequence length is not divisible by block size?
|
||||
assert seq_len % block_size == 0, "Sequence length must be divisible by block size"
|
||||
|
||||
if orig_dim == 3:
|
||||
# Reshape from (N, S, hs) to (B, nh, S, hs) where N = B x nh, hs = D / nh
|
||||
# Assuming num_heads = 1, (N, S, hs) to (B, 1, S, hs)
|
||||
if layout_heads == 1:
|
||||
q = q.unsqueeze(1)
|
||||
k = k.unsqueeze(1)
|
||||
v = v.unsqueeze(1)
|
||||
else:
|
||||
q = q.unsqueeze(0)
|
||||
k = k.unsqueeze(0)
|
||||
v = v.unsqueeze(0)
|
||||
|
||||
blocksparse_attention = _retrieve_blocksparse(layout_heads, seq_len, block_size)
|
||||
# Dropout is a no-op in evaluation mode
|
||||
if isinstance(dropout, torch.nn.Dropout):
|
||||
blocksparse_attention.attn_drop = dropout
|
||||
else:
|
||||
blocksparse_attention.attn_drop = torch.nn.Dropout(0.0)
|
||||
att = blocksparse_attention(q, k, v)
|
||||
|
||||
# Reshape attention (B, nh, S, hs) back to (N, S, hs)
|
||||
if orig_dim == 3:
|
||||
return att.flatten(0, 1)
|
||||
return att
|
||||
|
||||
|
||||
def scaled_dot_product_attention(
|
||||
q: torch.Tensor,
|
||||
k: torch.Tensor,
|
||||
v: torch.Tensor,
|
||||
att_mask: Optional[Union[AttentionMask, "SparseCS", torch.Tensor]],
|
||||
dropout: Optional[torch.nn.Module] = None,
|
||||
block_size: int = 128,
|
||||
) -> torch.Tensor:
|
||||
autocast_disabled = (
|
||||
_has_cpp_library
|
||||
and isinstance(att_mask, SparseCS)
|
||||
or (att_mask is not None and att_mask.is_sparse)
|
||||
)
|
||||
seq_len = q.shape[-2]
|
||||
|
||||
# switch if:
|
||||
# causal is required but mask is not sparse
|
||||
# fp16 or under amp context
|
||||
# sequence length is divisible by block size
|
||||
# same seq len for K and Q
|
||||
switch_to_blocksparse = (
|
||||
_is_blocksparse_available
|
||||
and (att_mask is not None and not att_mask.is_sparse)
|
||||
and (isinstance(att_mask, AttentionMask) and att_mask.is_causal)
|
||||
and (q.dtype == torch.float16 or torch.is_autocast_enabled())
|
||||
and not seq_len % block_size
|
||||
and q.shape[-2] == k.shape[-2]
|
||||
)
|
||||
|
||||
if switch_to_blocksparse:
|
||||
logger.info("Switching causal attention to Triton blocksparse...")
|
||||
return blocksparse_attention(q, k, v, dropout, block_size)
|
||||
|
||||
with torch.cuda.amp.autocast(enabled=False) if autocast_disabled else nullcontext(): # type: ignore
|
||||
if autocast_disabled:
|
||||
q, k, v = q.float(), k.float(), v.float()
|
||||
|
||||
att = scaled_query_key_softmax(q, k, att_mask=att_mask)
|
||||
|
||||
# Optional dropout, could be part of the masking in the future
|
||||
att = _apply_dropout(att, dropout)
|
||||
|
||||
# Get to the predicted values, for all heads
|
||||
# y = att @ v # (N, S, S) x (N, S, hs) -> (N, S, hs)
|
||||
y = bmm(att, v)
|
||||
return y
|
||||
Reference in New Issue
Block a user