# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

import logging
import math
from dataclasses import dataclass

import torch

from xformers import _is_triton_available
from xformers.components.attention import Attention, AttentionConfig, register_attention

logger = logging.getLogger("xformers")


_is_blocksparse_available = _is_triton_available()

if _is_blocksparse_available:
    from triton.ops.blocksparse import matmul as blocksparse_matmul  # type: ignore
    from triton.ops.blocksparse import softmax as blocksparse_softmax  # type: ignore

    from xformers.triton.utils import gpu_capabilities_older_than_70

    # Blocksparse requires Tensor cores
    if gpu_capabilities_older_than_70():
        logger.warning(
            "Blocksparse is not available: the current GPU does not expose Tensor cores"
        )
        _is_blocksparse_available = False


if _is_blocksparse_available:

    @dataclass
    class BlockSparseAttentionConfig(AttentionConfig):
        layout: torch.Tensor  # The block-level sparsity pattern to use
        block_size: int
        dropout: float
        num_heads: int

    @register_attention("blocksparse", BlockSparseAttentionConfig)
    class BlockSparseAttention(Attention):
        r"""
        Thin wrapper over the Triton blocksparse computations. The sparsity pattern is determined through the layout.

        .. warning:: the layout is assumed to have the dimensions [heads, seq, seq].
            If some dimensions are missing, we assume that the same layout is to be used across heads.

        .. warning:: for now, the sequence (context) length has to be a power of two. This constraint could
            be relaxed in the future.

        .. warning:: the block size has to be picked from [16, 32, 64, 128]. Some speed is gained from bigger blocks.
            It is of course possible to reproduce coarser patterns given these primitives, as the user sees fit.
        """

        def __init__(
            self,
            layout: torch.Tensor,
            block_size: int = 16,
            dropout: float = 0.0,
            num_heads: int = 1,  # optional, used to adapt the layout if in need
            causal: bool = False,
            *args,
            **kwargs,
        ):
            if layout.dim() == 2:
                logger.warning(
                    "The layout passed is lacking a head dimension and a batch dimension"
                )
                logger.warning(
                    "Now assuming that the same layout is to be used across all heads"
                )
                layout = layout.unsqueeze(0).expand(num_heads, -1, -1)
                logger.warning(f"New layout dimensions: {layout.shape}")

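            # Note: the layout is expressed in blocks, not elements: each entry toggles a
            # block_size x block_size tile of the attention matrix, so a sequence of length
            # S maps to a [num_heads, S // block_size, S // block_size] layout
            # (for instance S=1024 with block_size=32 gives a [num_heads, 32, 32] layout).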
            assert block_size in (
                16,
                32,
                64,
                128,
            ), "Only block sizes in [16, 32, 64, 128] are supported"

            super().__init__()

            self.causal = causal

            self.attn_drop = torch.nn.Dropout(dropout, inplace=False)

            # Pure blocksparse data
            self.layout = layout
            self.block_size = block_size

            # make sure that the head dimension is not folded down with the batch
            self.requires_head_dimension = True

            # key padding mask and attention mask must be passed in separately
            self.requires_same_k_q_dimensions = True

            # The underlying triton op does not support per element attention mask
            self.supports_attention_mask = False
            self.supports_key_padding_mask = False

        def create_triton_kernels(self, device):
            # blocksparse operators
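            # In Triton's blocksparse matmul, the mode string names which operand is sparse
            # ("s") or dense ("d"), in (output, a, b) order:
            #   - "sdd": dense q @ dense k^T -> block-sparse attention scores
            #   - "dsd": block-sparse attention @ dense v -> dense output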
            self.sparse_dot_sdd = blocksparse_matmul(
                self.layout,
                self.block_size,
                "sdd",
                trans_a=False,
                trans_b=True,
                device=device,
            )

            self.sparse_dot_dsd = blocksparse_matmul(
                self.layout,
                self.block_size,
                "dsd",
                trans_a=False,
                trans_b=False,
                device=device,
            )

            self.sparse_softmax = blocksparse_softmax(
                self.layout,
                self.block_size,
                device=device,
            )

        def forward(
            self,
            q: torch.Tensor,
            k: torch.Tensor,
            v: torch.Tensor,
            scale: float = 1.0,
            *args,
            **kwargs,
        ) -> torch.Tensor:
            r"""
            A thin wrap around the Triton blocksparse attention operation

            .. note:: Per element attention mask is not supported, but you can specify causality
            """
            assert (
                "att_mask" not in kwargs.keys() and "att_mask" not in args
            ), "This attention does not support an attention mask, but you can specify causality."

            # Delayed triton init, to make sure that we get the right device
            # Infer device from query
            if not hasattr(self, "sparse_dot_sdd"):
                self.create_triton_kernels(q.device)

            assert (
                q.shape[-2] == k.shape[-2]
            ), "Blocksparse requires the same dimensions for K and Q for now"

            assert (
                q.shape[-2] == self.layout.shape[-2] * self.block_size
            ), "Actual sequence size and layout are inconsistent"
            assert (
                k.shape[-2] == self.layout.shape[-2] * self.block_size
            ), "Actual sequence size and layout are inconsistent"

            assert (
                q.shape[-2] % self.block_size
            ) == 0, "Sequence length {} must be a multiple of block size {}".format(
                q.shape[-2], self.block_size
            )

            # Self-attend: (B, nh, S, hs) x (B, nh, hs, S) -> (B, nh, S, S)
            # When the computations are block sparse, the matrix types change along the way:
            # - (sparse) attention matrix = (dense) Kt * (dense) Q
            q = q / math.sqrt(q.size(-1))
            sparse_att_mat = self.sparse_dot_sdd(q, k)
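            # The scores are kept in Triton's packed block format: only the tiles flagged in
            # the layout are materialized (roughly [batch, non_zero_blocks, block_size, block_size]).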

            # - softmax on the sparse attention matrix
            sparse_att_mat = self.sparse_softmax(
                sparse_att_mat, scale=scale, is_causal=self.causal
            )

            sparse_att_mat = self.attn_drop(sparse_att_mat)

            # - then (dense) attention is (sparse) attention matrix * dense (value)
            a = self.sparse_dot_dsd(sparse_att_mat, v)
            return a
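

# Minimal usage sketch (illustration only, not part of the library API): build a
# block-causal layout and run the attention. All sizes below are arbitrary example
# values; running this requires triton and a CUDA GPU with Tensor cores, and fp16
# inputs are assumed since the kernels target Tensor cores.
if __name__ == "__main__":
    if _is_blocksparse_available and torch.cuda.is_available():
        batch, heads, seq, head_dim = 2, 4, 512, 64
        block_size = 32
        blocks = seq // block_size

        # Lower-triangular block layout: keep every tile on or below the block diagonal
        layout = torch.tril(torch.ones(heads, blocks, blocks, dtype=torch.long))

        attention = BlockSparseAttention(
            layout=layout, block_size=block_size, num_heads=heads, causal=True
        ).cuda()

        q = torch.randn(batch, heads, seq, head_dim, device="cuda", dtype=torch.float16)
        k = torch.randn_like(q)
        v = torch.randn_like(q)

        out = attention(q, k, v)
        print(out.shape)  # expected: torch.Size([2, 4, 512, 64])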