First commit

This commit is contained in:
2025-08-05 19:02:46 +08:00
parent 9efe891f99
commit 99fb9f5cb0
1412 changed files with 203615 additions and 0 deletions

View File

@@ -0,0 +1,474 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
from typing import Any, Optional, Sequence, Tuple, Type, Union
import os
import torch
from . import cutlass, decoder, flash, small_k, triton, triton_splitk
from .attn_bias import AttentionBias, BlockDiagonalMask, LowerTriangularMask
from .common import (
AttentionBwOpBase,
AttentionFwOpBase,
AttentionOp,
AttentionOpBase,
AttentionOpDispatch,
Context,
Gradients,
Inputs,
bmk2bmhk,
)
from .dispatch import _dispatch_bw, _dispatch_fw, _ensure_op_supports_or_raise
MemoryEfficientAttentionCutlassOp = (cutlass.FwOp, cutlass.BwOp)
MemoryEfficientAttentionCutlassFwdFlashBwOp = (cutlass.FwOp, flash.BwOp)
MemoryEfficientAttentionDecoderOp = (decoder.FwOp, cutlass.BwOp)
MemoryEfficientAttentionTritonFwdFlashBwOp = (triton.FwOp, flash.BwOp)
MemoryEfficientAttentionFlashAttentionOp = (flash.FwOp, flash.BwOp)
MemoryEfficientAttentionOp = (small_k.FwOp, small_k.BwOp)
TritonFlashAttentionOp = (triton.FwOp, triton.BwOp)
class _fMHA(torch.autograd.Function):
@staticmethod
# type: ignore
def forward(ctx, op: AttentionOp, *args: Any) -> Any:
inp = Inputs(*args)
op_fw = op[0] if op is not None else None
op_bw = op[1] if op is not None else None
out, op_ctx = _memory_efficient_attention_forward_requires_grad(
inp=inp, op=op_fw
)
# Saving attn_bias is a bit complicated, as the
# torch part should go in `save_for_backward`
if isinstance(inp.attn_bias, torch.Tensor):
attn_bias_tensor = inp.attn_bias
attn_bias_ctx = None
else:
attn_bias_tensor = None
attn_bias_ctx = inp.attn_bias
ctx.save_for_backward(
inp.query,
inp.key,
inp.value,
op_ctx.q_padded,
op_ctx.k_padded,
op_ctx.v_padded,
op_ctx.o_padded,
op_ctx.out,
op_ctx.lse,
)
ctx.rng_state = op_ctx.rng_state
ctx.attn_bias_tensor = attn_bias_tensor
if op_ctx.op_bw is not None:
if op_bw is not None and op_bw is not op_ctx.op_bw:
raise ValueError(
f"Specified op_bw={op_bw.NAME}, but forward op "
f"can only run with op_bw={op_ctx.op_bw.NAME}. Please set op_bw=None."
)
op_bw = op_ctx.op_bw
ctx.op_fw = op_fw
ctx.op_bw = op_bw
ctx.p = inp.p
ctx.use_alibi = inp.use_alibi
ctx.alibi_mode = inp.alibi_mode
ctx.imp_mode = inp.imp_mode
ctx.scale = inp.scale
ctx.attn_bias_ctx = attn_bias_ctx
ctx.n_args = len(args)
return out
@staticmethod
def deserialize_bias(
attn_bias_ctx, attn_bias_tensor: Optional[torch.Tensor]
) -> Any:
if attn_bias_tensor is None:
return attn_bias_ctx
return attn_bias_tensor
@classmethod
@torch.autograd.function.once_differentiable
def backward(cls, ctx, grad):
# Re-create context
query, key, value, q_padded, k_padded, v_padded, o_padded, out, lse = ctx.saved_tensors
attn_bias_tensor = ctx.attn_bias_tensor
rng_state = ctx.rng_state
inp = Inputs(
query=query,
key=key,
value=value,
attn_bias=cls.deserialize_bias(ctx.attn_bias_ctx, attn_bias_tensor),
p=ctx.p,
scale=ctx.scale,
use_alibi=ctx.use_alibi,
alibi_mode=ctx.alibi_mode,
imp_mode = ctx.imp_mode,
)
op_ctx = Context(
lse=lse,
out=out,
q_padded=q_padded,
k_padded=k_padded,
v_padded=v_padded,
o_padded=o_padded,
rng_state=rng_state,
)
grads = _memory_efficient_attention_backward(
ctx=op_ctx, inp=inp, grad=grad, op=ctx.op_bw
)
return (None, grads.dq, grads.dk, grads.dv, grads.db) + (None,) * (
ctx.n_args - 2
)
def memory_efficient_attention(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attn_bias: Optional[Union[torch.Tensor, AttentionBias]] = None,
p: float = 0.0,
scale: Optional[float] = None,
use_alibi: bool = False,
alibi_mode: int = 1,
imp_mode: int = 0,
*,
op: Optional[AttentionOp] = None,
) -> torch.Tensor:
"""Implements the memory-efficient attention mechanism following
`"Self-Attention Does Not Need O(n^2) Memory" <http://arxiv.org/abs/2112.05682>`_.
:Inputs shape:
- Input tensors must be in format ``[B, M, H, K]``, where B is the batch size, M \
the sequence length, H the number of heads, and K the embeding size per head
- If inputs have dimension 3, it is assumed that the dimensions are ``[B, M, K]`` and ``H=1``
- Inputs can also be of dimension 5 with GQA - see note below
- Inputs can be non-contiguous - we only require the last dimension's stride to be 1
:Equivalent pytorch code:
.. code-block:: python
scale = 1 / query.shape[-1] ** 0.5
query = query * scale
attn = query @ key.transpose(-2, -1)
if attn_bias is not None:
attn = attn + attn_bias
attn = attn.softmax(-1)
attn = F.dropout(attn, p)
return attn @ value
:Examples:
.. code-block:: python
import xformers.ops as xops
# Compute regular attention
y = xops.memory_efficient_attention(q, k, v)
# With a dropout of 0.2
y = xops.memory_efficient_attention(q, k, v, p=0.2)
# Causal attention
y = xops.memory_efficient_attention(
q, k, v,
attn_bias=xops.LowerTriangularMask()
)
:Supported hardware:
NVIDIA GPUs with compute capability above 6.0 (P100+), datatype ``f16``, ``bf16`` and ``f32``.
:EXPERIMENTAL: Using with Multi Query Attention (MQA) and Grouped Query Attention (GQA):
MQA/GQA is an experimental feature supported only for the forward pass.
If you have 16 heads in query, and 2 in key/value, you can provide 5-dim tensors
in the ``[B, M, G, H, K]`` format, where ``G`` is the number of head groups (here 2), and
``H`` is the number of heads per group (8 in the example).
Please note that xFormers will not automatically broadcast the inputs, so you will need
to broadcast it manually before calling `memory_efficient_attention`.
:GQA/MQA example:
.. code-block:: python
import torch
import xformers.ops as xops
B, M, K = 3, 32, 128
kwargs = dict(device="cuda", dtype=torch.float16)
q = torch.randn([B, M, 8, K], **kwargs)
k = torch.randn([B, M, 2, K], **kwargs)
v = torch.randn([B, M, 2, K], **kwargs)
out_gqa = xops.memory_efficient_attention(
q.reshape([B, M, 2, 4, K]),
k.reshape([B, M, 2, 1, K]).expand([B, M, 2, 4, K]),
v.reshape([B, M, 2, 1, K]).expand([B, M, 2, 4, K]),
)
Raises:
NotImplementedError: if there is no operator available to compute the MHA
ValueError: if inputs are invalid
:parameter query: Tensor of shape ``[B, Mq, H, K]``
:parameter key: Tensor of shape ``[B, Mkv, H, K]``
:parameter value: Tensor of shape ``[B, Mkv, H, Kv]``
:parameter attn_bias: Bias to apply to the attention matrix - defaults to no masking. \
For common biases implemented efficiently in xFormers, see :attr:`xformers.ops.fmha.attn_bias.AttentionBias`. \
This can also be a :attr:`torch.Tensor` for an arbitrary mask (slower).
:parameter p: Dropout probability. Disabled if set to ``0.0``
:parameter scale: Scaling factor for ``Q @ K.transpose()``. If set to ``None``, the default \
scale (q.shape[-1]**-0.5) will be used.
:parameter op: The operators to use - see :attr:`xformers.ops.AttentionOpBase`. \
If set to ``None`` (recommended), xFormers \
will dispatch to the best available operator, depending on the inputs \
and options.
:return: multi-head attention Tensor with shape ``[B, Mq, H, Kv]``
"""
return _memory_efficient_attention(
Inputs(
query=query, key=key, value=value, p=p, attn_bias=attn_bias, scale=scale, use_alibi=use_alibi, alibi_mode=alibi_mode, imp_mode=imp_mode
),
op=op,
)
def memory_efficient_attention_forward(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attn_bias: Optional[Union[torch.Tensor, AttentionBias]] = None,
p: float = 0.0,
scale: Optional[float] = None,
use_alibi: bool = False,
alibi_mode: int = 1,
imp_mode: int = 0,
*,
op: Optional[Type[AttentionFwOpBase]] = None,
) -> torch.Tensor:
"""
Calculates the forward pass of :attr:`xformers.ops.memory_efficient_attention`.
"""
return _memory_efficient_attention_forward(
Inputs(
query=query, key=key, value=value, p=p, attn_bias=attn_bias, scale=scale, use_alibi=use_alibi, alibi_mode=alibi_mode, imp_mode=imp_mode
),
op=op,
)
def memory_efficient_attention_forward_requires_grad(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attn_bias: Optional[Union[torch.Tensor, AttentionBias]] = None,
p: float = 0.0,
scale: Optional[float] = None,
use_alibi: bool = False,
alibi_mode: int = 1,
imp_mode: int = 0,
*,
op: Optional[Type[AttentionFwOpBase]] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Returns a tuple (output, lse), where `lse` can be used to compute the backward pass later.
See :attr:`xformers.ops.memory_efficient_attention` for an explanation of the arguments
See :attr:`xformers.ops.memory_efficient_attention_backward` for running the backward pass
"""
if p != 0.0:
raise NotImplementedError(
"dropout is not supported on the non-autograd API."
" If you want to use dropout, please call `memory_efficient_attention` directly"
)
out, ctx = _memory_efficient_attention_forward_requires_grad(
Inputs(
query=query, key=key, value=value, p=p, attn_bias=attn_bias, scale=scale, use_alibi=use_alibi, alibi_mode=alibi_mode, imp_mode=imp_mode
),
op=op,
)
return out, ctx.lse
def memory_efficient_attention_backward(
grad: torch.Tensor,
output: torch.Tensor,
lse: torch.Tensor,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attn_bias: Optional[Union[torch.Tensor, AttentionBias]] = None,
p: float = 0.0,
scale: Optional[float] = None,
use_alibi: bool = False,
alibi_mode: int = 1,
imp_mode: int = 0,
*,
op: Optional[Type[AttentionBwOpBase]] = None,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Computes the gradient of the attention.
Returns a tuple (dq, dk, dv)
See :attr:`xformers.ops.memory_efficient_attention` for an explanation of the arguments.
`lse` is the tensor returned by :attr:`xformers.ops.memory_efficient_attention_forward_requires_grad`
"""
if p != 0.0:
raise NotImplementedError(
"dropout is not supported on the non-autograd API."
" If you want to use dropout, please call `memory_efficient_attention` directly"
)
gradients = _memory_efficient_attention_backward(
Context(out=output, lse=lse),
Inputs(
query=query, key=key, value=value, p=p, attn_bias=attn_bias, scale=scale, use_alibi=use_alibi, alibi_mode=alibi_mode, imp_mode=imp_mode
),
grad,
op=op,
)
return (gradients.dq, gradients.dk, gradients.dv)
def _memory_efficient_attention(
inp: Inputs, op: Optional[AttentionOp] = None
) -> torch.Tensor:
# fast-path that doesn't require computing the logsumexp for backward computation
if all(x.requires_grad is False for x in [inp.query, inp.key, inp.value]):
return _memory_efficient_attention_forward(
inp, op=op[0] if op is not None else None
)
output_shape = inp.normalize_bmhk()
return _fMHA.apply(
op, inp.query, inp.key, inp.value, inp.attn_bias, inp.p, inp.scale, inp.use_alibi, inp.alibi_mode, inp.imp_mode
).reshape(output_shape)
def _memory_efficient_attention_forward(
inp: Inputs, op: Optional[Type[AttentionFwOpBase]]
) -> torch.Tensor:
inp.validate_inputs()
output_shape = inp.normalize_bmhk()
if op is None:
op = _dispatch_fw(inp, False)
else:
_ensure_op_supports_or_raise(ValueError, "memory_efficient_attention", op, inp)
out, *_ = op.apply(inp, needs_gradient=False)
return out.reshape(output_shape)
def _memory_efficient_attention_forward_requires_grad(
inp: Inputs, op: Optional[Type[AttentionFwOpBase]]
) -> Tuple[torch.Tensor, Context]:
inp.validate_inputs()
output_shape = inp.normalize_bmhk()
if op is None:
op = _dispatch_fw(inp, True)
else:
_ensure_op_supports_or_raise(ValueError, "memory_efficient_attention", op, inp)
out = op.apply(inp, needs_gradient=True)
assert out[1] is not None
return (out[0].reshape(output_shape), out[1])
def _memory_efficient_attention_backward(
ctx: Context, inp: Inputs, grad: torch.Tensor, op: Optional[Type[AttentionBwOpBase]]
) -> Gradients:
"""Warning: grad/ctx.out is potentially in BMK format"""
inp.validate_inputs()
if grad.ndim != inp.query.ndim or grad.ndim != ctx.out.ndim:
raise ValueError(
"All tensors should be either in BMK (ndim=3) or BMHK (ndim=4) format. \n"
f"grad.shape : {grad.shape} \n"
f"out.shape : {ctx.out.shape} \n"
f"query.shape: {inp.query.shape}"
)
shape_dq, shape_dk, shape_dv = tuple(
x.shape for x in (inp.query, inp.key, inp.value)
)
inp.normalize_bmhk()
# LSE has shape [B, H, M] while query has shape [B, M, H, K]
if os.getenv('ENABLE_FLASH_ATTENTION_WITH_IXDNN', '1') == '0':
if (
ctx.lse.ndim != 3
# Dim 0
or (
not isinstance(inp.attn_bias, BlockDiagonalMask)
and ctx.lse.shape[0] != inp.query.shape[0]
)
or (
isinstance(inp.attn_bias, BlockDiagonalMask)
and ctx.lse.shape[0] != inp.attn_bias.q_seqinfo.seqstart.shape[0] - 1
)
# Dim 1
or ctx.lse.shape[1] != inp.query.shape[2]
# Dim 2
or (
not isinstance(inp.attn_bias, BlockDiagonalMask)
and ctx.lse.shape[2] < inp.query.shape[1]
)
):
raise ValueError(
"Input tensors have incompatible shapes."
f"lse.shape : {ctx.lse.shape} \n"
f"query.shape : {inp.query.shape}"
)
grad = bmk2bmhk(grad, 1)
ctx.out = bmk2bmhk(ctx.out, 1)
if op is None:
op = _dispatch_bw(inp)
else:
_ensure_op_supports_or_raise(
ValueError, "memory_efficient_attention_backward", op, inp
)
grads = op.apply(ctx, inp, grad)
grads.dq = grads.dq.reshape(shape_dq)
grads.dk = grads.dk.reshape(shape_dk)
grads.dv = grads.dv.reshape(shape_dv)
return grads
ALL_FW_OPS: Sequence[Type[AttentionFwOpBase]] = [
cutlass.FwOp,
flash.FwOp,
triton.FwOp,
small_k.FwOp,
triton_splitk.FwOp,
]
ALL_BW_OPS: Sequence[Type[AttentionBwOpBase]] = [
cutlass.BwOp,
flash.BwOp,
triton.BwOp,
small_k.BwOp,
]
__all__ = [
"AttentionBias",
"AttentionOp",
"AttentionOpBase",
"AttentionOpDispatch",
"LowerTriangularMask",
"MemoryEfficientAttentionCutlassFwdFlashBwOp",
"MemoryEfficientAttentionTritonFwdFlashBwOp",
"MemoryEfficientAttentionCutlassOp",
"MemoryEfficientAttentionFlashAttentionOp",
"MemoryEfficientAttentionOp",
"TritonFlashAttentionOp",
"memory_efficient_attention",
"ALL_FW_OPS",
"ALL_BW_OPS",
]

View File

@@ -0,0 +1,779 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
import math
from dataclasses import dataclass
from typing import Any, Iterable, List, Optional, Sequence, Tuple, Union
import torch
class AttentionBias:
"""Base class for a custom bias that can be applied \
as the attn_bias argument in
:attr:`xformers.ops.memory_efficient_attention`.
That function has the ability to add a tensor, the
attention bias, to the QK^T matrix before it is used
in the softmax part of the attention calculation.
The attention bias tensor with shape
(B or 1, n_queries, number of keys)
can be given as the attn_bias input.
The most common use case is for an attention bias is
to contain only zeros and negative infinities, which forms
a mask so that some queries only attend to some keys.
Children of this class define alternative things which can
be used as the attn_bias input to define an attention bias which
forms such a mask, for some common cases.
When using an :attr:`xformers.ops.AttentionBias`
instead of a :attr:`torch.Tensor`, the mask matrix does
not need to be materialized, and can be
hardcoded into some kernels for better performance.
See:
- :attr:`xformers.ops.fmha.attn_bias.LowerTriangularMask`
- :attr:`xformers.ops.fmha.attn_bias.LowerTriangularMaskWithTensorBias`
- :attr:`xformers.ops.fmha.attn_bias.BlockDiagonalMask`
- :attr:`xformers.ops.fmha.attn_bias.BlockDiagonalCausalMask`
"""
def materialize(
self,
shape: Tuple[int, ...],
dtype: torch.dtype = torch.float32,
device: Union[str, torch.device] = "cpu",
) -> torch.Tensor:
"""
Materializes the bias as a `torch.Tensor`. This is very slow
and we don't attempt to make it fast. Only use for debugging/testing.
Shape should be like `[*, q_seqlen, k_seqlen]`
"""
raise NotImplementedError()
class LowerTriangularMask(AttentionBias):
"""
A lower-triangular (aka causal) mask
A query Q cannot attend to a key which is farther from the
initial key than Q is from the initial query.
"""
def __init__(self, *tensor_args, **tensor_kwargs) -> None:
# NOTE: Unused arguments, we keep them for backward compatibility
super().__init__()
def materialize(
self,
shape: Tuple[int, ...],
dtype: torch.dtype = torch.float32,
device: Union[str, torch.device] = "cpu",
) -> torch.Tensor:
create_as = dtype if dtype is not torch.bfloat16 else torch.float32
tensor = torch.full( # type: ignore
shape,
dtype=create_as,
fill_value=float("-inf"),
device=device,
)
return torch.triu(tensor, diagonal=1).to(dtype) # type: ignore
def add_bias(self, bias: torch.Tensor) -> "LowerTriangularMaskWithTensorBias":
return LowerTriangularMaskWithTensorBias(bias)
class LowerTriangularMaskWithTensorBias(LowerTriangularMask):
"""A lower-triangular (aka causal) mask with an additive bias"""
def __init__(self, bias: torch.Tensor) -> None:
self._bias = bias
def materialize(
self,
shape: Tuple[int, ...],
dtype: torch.dtype = torch.float32,
device: Union[str, torch.device] = "cpu",
) -> torch.Tensor:
return super().materialize(shape, dtype=dtype, device=device) + self._bias
@dataclass
class _SeqLenInfo:
"""
(Internal) Represents the division of a dimension into blocks.
For example, to represents a dimension of length 7 divided into
three blocks of lengths 2, 3 and 2, use `from_seqlength([2, 3, 2])`.
The members will be:
max_seqlen: 3
min_seqlen: 2
seqstart_py: [0, 2, 5, 7]
seqstart: torch.IntTensor([0, 2, 5, 7])
"""
seqstart: torch.Tensor
max_seqlen: int
min_seqlen: int
seqstart_py: List[int]
def to(self, device: torch.device) -> None:
self.seqstart = self.seqstart.to(device, non_blocking=True)
def intervals(self) -> Iterable[Tuple[int, int]]:
yield from zip(self.seqstart_py, self.seqstart_py[1:])
@classmethod
def from_seqlens(cls, seqlens: Iterable[int]) -> "_SeqLenInfo":
"""
Input tensors are assumed to be in shape [B, M, *]
"""
assert not isinstance(seqlens, torch.Tensor)
seqstart_py = [0]
max_seqlen = -1
min_seqlen = -1
for seqlen in seqlens:
min_seqlen = min(min_seqlen, seqlen) if min_seqlen != -1 else seqlen
max_seqlen = max(max_seqlen, seqlen)
seqstart_py.append(seqstart_py[len(seqstart_py) - 1] + seqlen)
seqstart = torch.tensor(seqstart_py, dtype=torch.int32)
return cls(
max_seqlen=max_seqlen,
min_seqlen=min_seqlen,
seqstart=seqstart,
seqstart_py=seqstart_py,
)
def split(
self, x: torch.Tensor, batch_sizes: Optional[Sequence[int]] = None
) -> List[torch.Tensor]:
if self.seqstart_py[-1] != x.shape[1] or x.shape[0] != 1:
raise ValueError(
f"Invalid `torch.Tensor` of shape {x.shape}, expected format "
f"(B, M, *) with B=1 and M={self.seqstart_py[-1]}\n"
f" seqstart: {self.seqstart_py}"
)
if batch_sizes is None:
batch_sizes = [1] * (len(self.seqstart_py) - 1)
split_chunks = []
it = 0
for batch_size in batch_sizes:
split_chunks.append(
self.seqstart_py[it + batch_size] - self.seqstart_py[it]
)
it += batch_size
return [
tensor.reshape([bs, -1, *tensor.shape[2:]])
for bs, tensor in zip(batch_sizes, x.split(split_chunks, dim=1))
]
@dataclass
class _PaddedSeqLenInfo(_SeqLenInfo):
"""
(Internal) Represents the division of a dimension into blocks which are
padded out to the same total length.
For example, to represent a dimension of length 12 with space for
three blocks of length 4, but where the occupied lengths are
2, 3 and 2, use `from_seqlens_padded([2, 3, 2], 4)`.
The layout along the dimension is
0 ─► block 0
block 0
<space>
<space>
4 ─► block 1
block 1
block 1
<space>
8 ─► block 2
block 2
<space>
<space>
12 ─►
The members will be:
max_seqlen: 3
min_seqlen: 2
seqstart_py: [0, 4, 8, 12]
seqstart: torch.IntTensor([0, 4, 8, 12])
seqlen_py: [2, 3, 2]
seqlen: torch.IntTensor([2, 3, 2])
padding: 4
"""
seqlen: torch.Tensor
seqlen_py: Sequence[int]
padding: int
# From parent: seqstart[i] contains the start position
# of the i-th sequence
# seqstart: torch.Tensor
def __post_init__(self) -> None:
assert len(self.seqstart_py) == len(self.seqlen_py) + 1
def to(self, device: torch.device) -> None:
self.seqlen = self.seqlen.to(device, non_blocking=True)
super().to(device)
def intervals(self) -> Iterable[Tuple[int, int]]:
for (start, _), length in zip(super().intervals(), self.seqlen_py):
yield start, start + length
@classmethod
def from_seqlens(cls, seqlens: Iterable[int]) -> "_SeqLenInfo":
raise RuntimeError(
"Use either `_SeqLenInfo.from_seqlens` or `_PaddedSeqLenInfo.from_seqlens_padded`"
)
@classmethod
def from_seqlens_padded(
cls, seqlens: Sequence[int], padding: int
) -> "_PaddedSeqLenInfo":
"""
Input tensors are assumed to be in shape [B, M, *]
seqstart = padding * torch.arange(batch_size)
"""
assert not isinstance(seqlens, torch.Tensor)
assert all(seqlen <= padding for seqlen in seqlens)
seqstart_py = list(range(0, len(seqlens) * padding + 1, padding))
return cls(
seqlen=torch.tensor(seqlens, dtype=torch.int32),
seqlen_py=seqlens,
max_seqlen=max(seqlens),
min_seqlen=min(seqlens),
seqstart=torch.tensor(seqstart_py, dtype=torch.int32),
seqstart_py=seqstart_py,
padding=padding,
)
def split(
self, x: torch.Tensor, batch_sizes: Optional[Sequence[int]] = None
) -> List[torch.Tensor]:
raise NotImplementedError("_PaddedSeqLenInfo.split")
@dataclass
class BlockDiagonalMask(AttentionBias):
"""
A block-diagonal mask that can be passed as ``attn_bias``
argument to :attr:`xformers.ops.memory_efficient_attention`.
Queries and Keys are each divided into the same number of blocks.
Queries in block i only attend to keys in block i.
.. figure:: /_static/block_diag_bias.png
This bias can be used to handle a batch of sequences of
different lengths, via :attr:`BlockDiagonalMask.from_tensor_list`
:Example:
.. code-block:: python
import torch
from xformers.ops import fmha
K = 16
dtype = torch.float16
device = "cuda"
list_x = [
torch.randn([1, 3, 1, K], dtype=dtype, device=device),
torch.randn([1, 6, 1, K], dtype=dtype, device=device),
torch.randn([1, 2, 1, K], dtype=dtype, device=device),
]
attn_bias, x = fmha.BlockDiagonalMask.from_tensor_list(list_x)
linear = torch.nn.Linear(K, K * 3).to(device=device, dtype=dtype)
q, k, v = linear(x).reshape([1, -1, 1, 3, K]).unbind(-2)
out = fmha.memory_efficient_attention(q, k, v, attn_bias=attn_bias)
list_out = attn_bias.split(out)
print(list_out[0].shape) # [1, 3, 1, K]
assert tuple(list_out[0].shape) == (1, 3, 1, K)
"""
q_seqinfo: _SeqLenInfo
k_seqinfo: _SeqLenInfo
_batch_sizes: Optional[Sequence[int]] = None
def _create_block_mask(
self,
shape: Tuple[int, ...],
dtype: torch.dtype = torch.float32,
device: Union[str, torch.device] = "cpu",
) -> torch.Tensor:
return torch.zeros(
shape,
dtype=dtype,
device=device,
)
def materialize(
self,
shape: Tuple[int, ...],
dtype: torch.dtype = torch.float32,
device: Union[str, torch.device] = "cpu",
) -> torch.Tensor:
"""Materialize the attention bias - for debugging & testing"""
assert shape[-1] == self.k_seqinfo.seqstart_py[-1], (
shape[-1],
self.k_seqinfo.seqstart_py[-1],
)
assert shape[-2] == self.q_seqinfo.seqstart_py[-1], (
shape[-2],
self.q_seqinfo.seqstart_py[-1],
)
mask = torch.empty(shape[-2:], dtype=dtype, device=device)
mask.fill_(-math.inf)
for i, ((q_start, q_end), (k_start, k_end)) in enumerate(
zip(
self.q_seqinfo.intervals(),
self.k_seqinfo.intervals(),
)
):
mask[q_start:q_end, k_start:k_end] = self._create_block_mask(
(q_end - q_start, k_end - k_start),
dtype=dtype,
device=device,
)
for _ in range(len(shape) - 2):
mask = mask.unsqueeze(0)
return mask.expand(shape)
@classmethod
def from_seqlens(
cls,
q_seqlen: Sequence[int],
kv_seqlen: Optional[Sequence[int]] = None,
) -> "BlockDiagonalMask":
"""Creates a :attr:`BlockDiagonalMask` from a list of tensors lengths for query and key/value.
Args:
q_seqlen (Union[Sequence[int], torch.Tensor]): List or tensor of sequence lengths for query tensors
kv_seqlen (Union[Sequence[int], torch.Tensor], optional): List or tensor of sequence lengths for key/value.
(Defaults to ``q_seqlen``.)
Returns:
BlockDiagonalMask
"""
assert kv_seqlen is None or len(q_seqlen) == len(kv_seqlen)
q_seqinfo = _SeqLenInfo.from_seqlens(q_seqlen)
if kv_seqlen is None or q_seqlen == kv_seqlen:
k_seqinfo = q_seqinfo
else:
k_seqinfo = _SeqLenInfo.from_seqlens(kv_seqlen)
return cls(q_seqinfo=q_seqinfo, k_seqinfo=k_seqinfo)
@classmethod
def from_tensor_list(
cls,
tensors: Sequence[torch.Tensor],
) -> Tuple["BlockDiagonalMask", torch.Tensor]:
"""Creates a :attr:`BlockDiagonalMask` from a list of tensors, and returns the tensors
concatenated on the sequence length dimension
.. figure:: /_static/block_diag_cat_split.png
See also :attr:`BlockDiagonalMask.split` to split the returned
:attr:`torch.Tensor` back to a list of tensors of varying sequence length
Args:
tensors (Sequence[torch.Tensor]): A list of tensors of shape ``[B, M_i, *]``.
All tensors should have the same dimension and the same batch size ``B``, but
they can have different sequence length ``M``.
Returns:
Tuple[BlockDiagonalMask, torch.Tensor]: The corresponding bias for the attention
along with `tensors` concatenated on the sequence length dimension, with shape ``[1, sum_i{M_i}, *]``
"""
batch_sizes = [tensor.shape[0] for tensor in tensors]
seqlens = []
for x in tensors:
for _ in range(x.shape[0]):
seqlens.append(x.shape[1])
block_diag = cls.from_seqlens(seqlens)
block_diag._batch_sizes = batch_sizes
tensors_bs1 = tuple(x.reshape([1, -1, *x.shape[2:]]) for x in tensors)
concat_tensors = torch.cat(tensors_bs1, dim=1)
return block_diag, concat_tensors
@classmethod
def from_tensor_lists_qkv(
cls,
tensors_q: Sequence[torch.Tensor],
tensors_k: Sequence[torch.Tensor],
tensors_v: Optional[Sequence[torch.Tensor]] = None,
) -> Tuple["BlockDiagonalMask", torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
assert len(tensors_q) == len(tensors_k)
assert tensors_v is None or len(tensors_v) == len(tensors_q)
batch_sizes = [tensor.shape[0] for tensor in tensors_q]
q_seqlens, kv_seqlens = [], []
for i, (q, k) in enumerate(zip(tensors_q, tensors_k)):
assert q.shape[0] == k.shape[0]
q_seqlens += [q.shape[1]] * q.shape[0]
kv_seqlens += [k.shape[1]] * k.shape[0]
assert tensors_v is None or tensors_v[i].shape[:2] == k.shape[:2]
block_diag = cls.from_seqlens(q_seqlens, kv_seqlens)
block_diag._batch_sizes = batch_sizes
return (
block_diag,
torch.cat([x.reshape([1, -1, *x.shape[2:]]) for x in tensors_q], dim=1),
torch.cat([x.reshape([1, -1, *x.shape[2:]]) for x in tensors_k], dim=1),
torch.cat([x.reshape([1, -1, *x.shape[2:]]) for x in tensors_v], dim=1)
if tensors_v is not None
else None,
)
def split_queries(self, tensor: torch.Tensor) -> Sequence[torch.Tensor]:
return self.q_seqinfo.split(tensor, self._batch_sizes)
def split_kv(self, tensor: torch.Tensor) -> Sequence[torch.Tensor]:
return self.k_seqinfo.split(tensor, self._batch_sizes)
def split(self, tensor: torch.Tensor) -> Sequence[torch.Tensor]:
"""The inverse operation of :attr:`BlockDiagonalCausalMask.from_tensor_list`
Args:
tensor (torch.Tensor): Tensor of tokens of shape ``[1, sum_i{M_i}, *]``
Returns:
Sequence[torch.Tensor]: A list of tokens with possibly different sequence lengths
"""
assert self.q_seqinfo is self.k_seqinfo
return self.q_seqinfo.split(tensor, self._batch_sizes)
def make_causal(self) -> "BlockDiagonalCausalMask":
"""Makes each block causal"""
return BlockDiagonalCausalMask(
q_seqinfo=self.q_seqinfo,
k_seqinfo=self.k_seqinfo,
_batch_sizes=self._batch_sizes,
)
def make_causal_from_bottomright(self) -> "BlockDiagonalCausalFromBottomRightMask":
"""Makes each block causal with a possible non-causal prefix"""
return BlockDiagonalCausalFromBottomRightMask(
q_seqinfo=self.q_seqinfo,
k_seqinfo=self.k_seqinfo,
_batch_sizes=self._batch_sizes,
)
def make_local_attention(
self, window_size: int
) -> "BlockDiagonalCausalLocalAttentionMask":
"""Experimental: Makes each block causal with local attention"""
return BlockDiagonalCausalLocalAttentionMask(
q_seqinfo=self.q_seqinfo,
k_seqinfo=self.k_seqinfo,
_batch_sizes=self._batch_sizes,
_window_size=window_size,
)
def make_local_attention_from_bottomright(
self, window_size: int
) -> "BlockDiagonalCausalLocalAttentionFromBottomRightMask":
"""Experimental: Makes each block causal with local attention, start from bottom right"""
return BlockDiagonalCausalLocalAttentionFromBottomRightMask(
q_seqinfo=self.q_seqinfo,
k_seqinfo=self.k_seqinfo,
_batch_sizes=self._batch_sizes,
_window_size=window_size,
)
@dataclass
class BlockDiagonalCausalMask(BlockDiagonalMask):
"""
Same as :attr:`xformers.ops.fmha.attn_bias.BlockDiagonalMask`, except that each block is causal.
Queries and Keys are each divided into the same number of blocks.
A query Q in block i cannot attend to a key which is not in block i,
nor one which is farther from the initial key in block i than Q
is from the initial query in block i.
"""
def _create_block_mask(
self,
shape: Tuple[int, ...],
dtype: torch.dtype = torch.float32,
device: Union[str, torch.device] = "cpu",
) -> torch.Tensor:
return LowerTriangularMask().materialize(
shape,
dtype=dtype,
device=device,
)
@dataclass
class BlockDiagonalCausalFromBottomRightMask(BlockDiagonalMask):
"""
Same as :attr:`xformers.ops.fmha.attn_bias.BlockDiagonalMask`, except that each block is causal.
This mask allows for a non-causal prefix
NOTE: Each block should have `num_keys >= num_queries` otherwise the forward pass is not
defined (softmax of vector of `-inf` in the attention)
Queries and keys are each divided into the same number of blocks.
A query Q in block i cannot attend to a key which is not in block i,
nor one which nearer the final key in block i than Q is to the
final query in block i.
"""
def __post_init__(self) -> None:
for i, ((q_start, q_end), (k_start, k_end)) in enumerate(
zip(
self.q_seqinfo.intervals(),
self.k_seqinfo.intervals(),
)
):
num_queries = q_end - q_start
num_keys = k_end - k_start
if num_keys < num_queries:
raise ValueError(
f"Block #{i} has num_keys={num_keys} and num_queries={num_queries}."
" Expected `num_keys >= num_queries`"
)
def _create_block_mask(
self,
shape: Tuple[int, ...],
dtype: torch.dtype = torch.float32,
device: Union[str, torch.device] = "cpu",
) -> torch.Tensor:
create_as = dtype if dtype is not torch.bfloat16 else torch.float32
tensor = torch.full( # type: ignore
shape,
dtype=create_as,
fill_value=float("-inf"),
device=device,
)
num_queries, num_keys = shape[-2:]
return torch.triu(tensor, diagonal=num_keys - num_queries + 1).to(dtype) # type: ignore
@dataclass
class BlockDiagonalCausalWithOffsetPaddedKeysMask(AttentionBias):
"""
Same as :attr:`xformers.ops.fmha.attn_bias.BlockDiagonalCausalMask`,
except an offset on causality is allowed for each block and we support padding for k/v
The keys and values are divided into blocks which are padded out to
the same total length.
For example, if there is space for 12 keys, for three blocks of
max length 4, but we only want to use the first 2, 3 and 2
of each block, use `kv_padding=4` and `kv_seqlens=[2, 3, 2]`.
The queries are divided into blocks, without padding, of lengths given by
q_seqlen.
A query Q in block i cannot attend to a key which is not in block i,
nor one which is not in use (i.e. in the padded area),
nor one which is nearer to the final key in block i
than Q is to the final query in block i.
"""
q_seqinfo: _SeqLenInfo
k_seqinfo: _PaddedSeqLenInfo
causal_diagonal: Any = None # unused. Exists for BC only.
def _create_block_mask(
self,
shape: Tuple[int, ...],
dtype: torch.dtype = torch.float32,
device: Union[str, torch.device] = "cpu",
) -> torch.Tensor:
create_as = dtype if dtype is not torch.bfloat16 else torch.float32
tensor = torch.full( # type: ignore
shape,
dtype=create_as,
fill_value=float("-inf"),
device=device,
)
num_queries, num_keys = shape[-2:]
return torch.triu(tensor, diagonal=1 + num_keys - num_queries).to(dtype) # type: ignore
def materialize(
self,
shape: Tuple[int, ...],
dtype: torch.dtype = torch.float32,
device: Union[str, torch.device] = "cpu",
) -> torch.Tensor:
"""Materialize the attention bias - for debugging & testing"""
if shape[-1] != self.k_seqinfo.seqstart_py[-1]:
raise ValueError("k shapes wrong")
if shape[-2] != self.q_seqinfo.seqstart_py[-1]:
raise ValueError("q shapes wrong")
mask = torch.empty(shape[-2:], dtype=dtype, device=device)
mask.fill_(-math.inf)
for i, ((q_start, q_end), (k_start, k_end)) in enumerate(
zip(
self.q_seqinfo.intervals(),
self.k_seqinfo.intervals(),
)
):
mask[q_start:q_end, k_start:k_end] = self._create_block_mask(
(q_end - q_start, k_end - k_start),
dtype=dtype,
device=device,
)
for _ in range(len(shape) - 2):
mask = mask.unsqueeze(0)
return mask.expand(shape)
@classmethod
def from_seqlens(
cls,
q_seqlen: Sequence[int],
kv_padding: int,
kv_seqlen: Sequence[int],
causal_diagonal: Any = None,
) -> "BlockDiagonalCausalWithOffsetPaddedKeysMask":
"""Creates a :attr:`BlockDiagonalCausalWithOffsetPaddedKeysMask` from a list of tensor
lengths for query and key/value.
Args:
q_seqlen (Sequence[int]): List or tensor of sequence lengths for query tensors
kv_padding (int): Padding for k/v - also an upperbound on each individual key length
kv_seqlen (Sequence[int]): List or tensor of sequence lengths for key/value.
causal_diagonal: unused, for BC only
Returns:
BlockDiagonalCausalWithOffsetPaddedKeysMask
"""
assert kv_seqlen is None or len(q_seqlen) == len(kv_seqlen), (
q_seqlen,
kv_seqlen,
)
q_seqinfo = _SeqLenInfo.from_seqlens(q_seqlen)
k_seqinfo = _PaddedSeqLenInfo.from_seqlens_padded(kv_seqlen, kv_padding)
return cls(q_seqinfo=q_seqinfo, k_seqinfo=k_seqinfo)
@dataclass
class BlockDiagonalCausalLocalAttentionMask(BlockDiagonalCausalMask):
"""
(Experimental feature)
Same as :attr:`xformers.ops.fmha.attn_bias.BlockDiagonalCausalMask`.
This makes the mask "local" and the attention pattern banded.
Query i only attends to keys in its block and cannot attend keys further than "window_size"
from it.
"""
_window_size: int = 0 # forced due to inheritance and default arguments
def __post_init__(self):
if self._window_size <= 0:
raise ValueError(
f"Expected `window_size > 0`, but window_size={self._window_size}"
)
q_seqlen = [
y - x
for x, y in zip(
self.q_seqinfo.seqstart_py[:-1], self.q_seqinfo.seqstart_py[1:]
)
]
kv_seqlen = [
y - x
for x, y in zip(
self.k_seqinfo.seqstart_py[:-1], self.k_seqinfo.seqstart_py[1:]
)
]
for q, k in zip(q_seqlen, kv_seqlen):
if q - self._window_size >= k:
raise RuntimeError(
f"No keys are attended in q_seqlen {q} k_seqlen {k} with sliding window {self._window_size}"
)
def _create_block_mask(
self,
shape: Tuple[int, ...],
dtype: torch.dtype = torch.float32,
device: Union[str, torch.device] = "cpu",
) -> torch.Tensor:
create_as = dtype if dtype is not torch.bfloat16 else torch.float32
tensor = torch.full( # type: ignore
shape,
dtype=create_as,
fill_value=1,
device=device,
)
num_queries, num_keys = shape[-2:]
mask = torch.tril(tensor, diagonal=0).to(dtype) # type: ignore
if self._window_size is not None and self._window_size > 0:
mask = torch.triu(mask, diagonal=-self._window_size + 1)
mask = torch.log(mask)
return mask.to(dtype)
@dataclass
class BlockDiagonalCausalLocalAttentionFromBottomRightMask(
BlockDiagonalCausalFromBottomRightMask
):
"""
(Experimental feature)
Same as :attr:`xformers.ops.fmha.attn_bias.BlockDiagonalCausalMask`.
This makes the mask "local" and the attention pattern banded.
Query i only attends to keys in its block and cannot attend keys further than "window_size"
from it.
"""
_window_size: int = 0 # forced due to inheritance and default arguments
def __post_init__(self):
super().__post_init__()
if self._window_size <= 0:
raise ValueError(
f"Expected `window_size > 0`, but window_size={self._window_size}"
)
q_seqlen = [
y - x
for x, y in zip(
self.q_seqinfo.seqstart_py[:-1], self.q_seqinfo.seqstart_py[1:]
)
]
kv_seqlen = [
y - x
for x, y in zip(
self.k_seqinfo.seqstart_py[:-1], self.k_seqinfo.seqstart_py[1:]
)
]
for q, k in zip(q_seqlen, kv_seqlen):
if q + (q - k) - self._window_size >= k:
raise RuntimeError(
f"No keys are attended in q_seqlen {q} k_seqlen {k} with sliding window {self._window_size}"
)
materialized = self.materialize((sum(q_seqlen), sum(kv_seqlen)))
if torch.max(materialized, dim=1).values.min() == -float("inf"):
raise RuntimeError("FUCKING FUCK FUCK")
def _create_block_mask(
self,
shape: Tuple[int, ...],
dtype: torch.dtype = torch.float32,
device: Union[str, torch.device] = "cpu",
) -> torch.Tensor:
create_as = dtype if dtype is not torch.bfloat16 else torch.float32
tensor = torch.full( # type: ignore
shape,
dtype=create_as,
fill_value=1,
device=device,
)
num_queries, num_keys = shape[-2:]
mask = torch.tril(tensor, diagonal=num_keys - num_queries).to(dtype) # type: ignore
if self._window_size is not None:
mask = torch.triu(
mask, diagonal=num_keys - num_queries - self._window_size + 1
)
mask = torch.log(mask)
return mask.to(dtype)

View File

@@ -0,0 +1,542 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
import math
from dataclasses import dataclass
from typing import Any, List, Mapping, Optional, Set, Tuple, Type, Union
import torch
from ..._cpp_lib import _built_with_cuda
from ..common import BaseOperator
from .attn_bias import (
AttentionBias,
BlockDiagonalMask,
LowerTriangularMask,
LowerTriangularMaskWithTensorBias,
)
def _is_bias_type_supported_in_BMK(attn_bias_type: Any) -> bool:
# NoneType
if isinstance(None, attn_bias_type):
return True
if attn_bias_type in [LowerTriangularMask, torch.Tensor]:
return True
return False
@dataclass
class Inputs:
"""
Stores inputs to the `memory_efficient_attention` operators
"""
query: torch.Tensor
key: torch.Tensor
value: torch.Tensor
attn_bias: Optional[Union[torch.Tensor, AttentionBias]] = None
p: float = 0.0
scale: Optional[float] = None
use_alibi: bool = False
alibi_mode: int = 1
imp_mode: int = 0
@property
def device(self) -> torch.device:
return self.query.device
@property
def scale_float(self) -> float:
return self.query.shape[-1] ** (-0.5) if self.scale is None else self.scale
def get_qkv_in_bmghk(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
if self.query.ndim == 5:
return self.query, self.key, self.value
if self.query.ndim == 4:
return (
self.query.unsqueeze(2),
self.key.unsqueeze(2),
self.value.unsqueeze(2),
)
if self.value.ndim == 3:
return (
self.query[:, :, None, None],
self.key[:, :, None, None],
self.value[:, :, None, None],
)
assert False
def normalize_bmhk(self) -> Tuple[int, ...]:
if self.query.ndim not in [3, 4, 5]:
raise ValueError(
f"Invalid shape for query: {self.query.shape}. "
"Expected shape [batch, seqlen, head_groups, num_heads_per_group, K]"
", [batch, seqlen, num_heads, K], or [batch, seqlen, K]."
)
if self.value.dtype == torch.int32:
# Quantized K/V case, in which the last dims of Q and K are different.
# NB we currently don't have any implementations for quantized KV with
# SUPPORTS_DIFFERENT_VALUE_EMBED.
output_shape = tuple(self.query.shape)
else:
output_shape = (self.query.shape[:-1]) + (self.value.shape[-1],)
# Convert from legacy format
if self.query.ndim == 3:
self.query = self.query.unsqueeze(2)
self.key = self.key.unsqueeze(2)
self.value = self.value.unsqueeze(2)
if isinstance(self.attn_bias, torch.Tensor):
if self.attn_bias.ndim != 3:
raise ValueError(
f"Expected BMK format for attn_bias, but got {self.attn_bias.shape}"
)
self.attn_bias = self.attn_bias.unsqueeze(1)
return output_shape
def validate_inputs(self) -> None:
qkv = (self.query, self.key, self.value)
if self.query.ndim not in (3, 4, 5) or any(
x.ndim != self.query.ndim for x in qkv
):
raise ValueError(
f"Query/Key/Value should all have BMGHK, BMHK, or BMK shape.\n"
f" query.shape: {self.query.shape}\n"
f" key.shape : {self.key.shape}\n"
f" value.shape: {self.value.shape}"
)
if any(x.device != self.query.device for x in qkv):
raise ValueError("Query/Key/Value should all be on the same device")
quantized_dtypes = self.key.dtype == self.value.dtype == torch.int32
non_quantized_dtypes = all(x.dtype == self.query.dtype for x in qkv)
if not (quantized_dtypes or non_quantized_dtypes):
raise ValueError(
"Query/Key/Value should either all have the same dtype, or "
"(in the quantized case) Key/Value should have dtype torch.int32\n"
f" query.dtype: {self.query.dtype}\n"
f" key.dtype : {self.key.dtype}\n"
f" value.dtype: {self.value.dtype}"
)
# Biases with tensors attached are meant to be in BMHK format
# This would require to permute biases/gradients which can be expensive,
# so let's just forbid it - BMK is a legacy format anyway
if self.query.ndim == 3 and not _is_bias_type_supported_in_BMK(
type(self.attn_bias)
):
raise ValueError(
f"Please provide inputs in BMHK format rather "
f"than BMK when using bias type `{type(self.attn_bias).__name__}`"
)
attn_bias_t: Optional[torch.Tensor] = None
if isinstance(self.attn_bias, torch.Tensor):
attn_bias_t = self.attn_bias
if isinstance(self.attn_bias, LowerTriangularMaskWithTensorBias):
attn_bias_t = self.attn_bias._bias
if self.query.ndim == 4 and attn_bias_t is not None:
expected_shape = (
self.query.shape[0],
self.query.shape[2],
self.query.shape[1],
self.key.shape[1],
)
if attn_bias_t.shape != expected_shape:
raise ValueError(
f"Invalid shape for attention bias: {attn_bias_t.shape} (expected {expected_shape})\n"
f" query.shape: {self.query.shape}\n"
f" key.shape : {self.key.shape}\n"
f" value.shape: {self.value.shape}"
)
if isinstance(self.attn_bias, BlockDiagonalMask):
if any(x.shape[0] != 1 for x in qkv):
raise ValueError(
f"Expected batch_size=1 when using block-diagonal bias\n"
f" query.shape: {self.query.shape}\n"
f" key.shape : {self.key.shape}\n"
f" value.shape: {self.value.shape}"
)
if self.p < 0.0 or self.p > 1.0:
raise ValueError(f"Invalid dropout probability: p={self.p}")
# Check that shapes match between inputs
B, Mq = self.query.shape[:2]
K = self.query.shape[-1]
B, Mkv = self.key.shape[:2]
Kv = self.value.shape[-1]
valid_shapes = True
if self.query.ndim == 3: # BMK
valid_shapes = (
self.query.shape == (B, Mq, K)
and self.key.shape == (B, Mkv, K)
and self.value.shape == (B, Mkv, Kv)
)
H = self.query.shape[-2]
if self.query.ndim == 4: # BMHK
quantized_kv_cache = self.value.dtype == torch.int32
key_embed_dim = Kv if quantized_kv_cache else K
valid_shapes = (
self.query.shape == (B, Mq, H, K)
and self.key.shape == (B, Mkv, H, key_embed_dim)
and self.value.shape == (B, Mkv, H, Kv)
)
G = self.query.shape[2]
if self.query.ndim == 5: # BMNHK
valid_shapes = (
self.query.shape == (B, Mq, G, H, K)
and self.key.shape == (B, Mkv, G, H, K)
and self.value.shape == (B, Mkv, G, H, Kv)
)
if not valid_shapes:
raise ValueError(
f"Incompatible shapes for attention inputs:\n"
f" query.shape: {self.query.shape}\n"
f" key.shape : {self.key.shape}\n"
f" value.shape: {self.value.shape}\n"
"HINT: We don't support broadcasting, please use `expand` "
"yourself before calling `memory_efficient_attention` if you need to"
)
@dataclass
class Context:
lse: torch.Tensor
out: torch.Tensor
q_padded: Optional[torch.Tensor] = None
k_padded: Optional[torch.Tensor] = None
v_padded: Optional[torch.Tensor] = None
o_padded: Optional[torch.Tensor] = None
op_bw: Optional[Type["AttentionBwOpBase"]] = None
rng_state: Optional[torch.Tensor] = None
def get_padded_lse(self, pad_to: int, force_pad_inf: bool = False) -> torch.Tensor:
pad_amount = (pad_to - (self.lse.shape[2] % pad_to)) % pad_to
lse = self.lse
if pad_amount > 0:
if force_pad_inf:
lse = lse[:, :, : self.out.shape[1]]
pad_amount = (pad_to - (lse.shape[2] % pad_to)) % pad_to
lse = torch.nn.functional.pad(lse, [0, pad_amount], value=math.inf)
elif force_pad_inf and self.out.shape[1] != lse.shape[2]:
lse[:, :, self.out.shape[1] :].fill_(math.inf)
return lse
@dataclass
class Gradients:
dq: torch.Tensor
dk: torch.Tensor
dv: torch.Tensor
# bias gradient. None if there is no tensor bias or if it doesn't require grad
db: Optional[torch.Tensor] = None
class AttentionOpBase(BaseOperator):
"""Base class for any attention operator in xFormers
See:
- :attr:`xformers.ops.fmha.cutlass.FwOp`
- :attr:`xformers.ops.fmha.cutlass.BwOp`
- :attr:`xformers.ops.fmha.flash.FwOp`
- :attr:`xformers.ops.fmha.flash.BwOp`
- :attr:`xformers.ops.fmha.triton.FwOp`
- :attr:`xformers.ops.fmha.triton.BwOp`
- :attr:`xformers.ops.fmha.small_k.FwOp`
- :attr:`xformers.ops.fmha.small_k.BwOp`
"""
OPERATOR: Any
SUPPORTED_DEVICES: Set[str]
CUDA_MINIMUM_COMPUTE_CAPABILITY: Tuple[int, int] = (5, 0)
SUPPORTED_DTYPES: Set[torch.dtype]
SUPPORTED_MAX_K: float
SUPPORTED_ATTN_BIAS_TYPES: Set[Any] = {type(None)}
SUPPORTS_DROPOUT: bool
SUPPORTS_CUSTOM_SCALE: bool = False
SUPPORTS_DIFFERENT_VALUE_EMBED: bool = False
IS_DETERMINISTIC: bool = True
SUPPORTS_BMGHK: bool = False
NAME: str
OPERATOR_CATEGORY = "memory_efficient_attention"
_TEST_BATCH_SIZES: List[int] = [1, 300]
_TEST_K: List[int] = [32, 128]
@classmethod
def supports(cls, d: Inputs) -> bool:
return not cls.not_supported_reasons(d)
@classmethod
def shape_not_supported_reasons(
cls, Mq: int, Mkv: int, K: int, Kv: int
) -> List[str]:
reasons = []
if not cls.SUPPORTS_DIFFERENT_VALUE_EMBED and K != Kv:
reasons.append("query.shape[-1] != value.shape[-1]")
if max(K, Kv) > cls.SUPPORTED_MAX_K:
reasons.append(
f"max(query.shape[-1] != value.shape[-1]) > {cls.SUPPORTED_MAX_K}"
)
return reasons
@classmethod
def not_supported_reasons(cls, d: Inputs) -> List[str]:
"""
Returns a list of reasons why this is not supported.
The kernel can run these inputs only if the returned list is empty
"""
reasons = cls.shape_not_supported_reasons(
Mq=d.query.shape[1],
Mkv=d.key.shape[1],
K=d.query.shape[-1],
Kv=d.query.shape[-1],
)
device_type = d.query.device.type
dtype = d.query.dtype
if device_type not in cls.SUPPORTED_DEVICES:
reasons.append(f"device={device_type} (supported: {cls.SUPPORTED_DEVICES})")
if device_type == "cuda" and not _built_with_cuda:
reasons.append("xFormers wasn't build with CUDA support")
if device_type == "cuda":
device_capability = torch.cuda.get_device_capability(d.device)
if device_capability < cls.CUDA_MINIMUM_COMPUTE_CAPABILITY:
reasons.append(
f"requires device with capability > {cls.CUDA_MINIMUM_COMPUTE_CAPABILITY} "
f"but your GPU has capability {device_capability} (too old)"
)
if dtype not in cls.SUPPORTED_DTYPES:
reasons.append(f"dtype={dtype} (supported: {cls.SUPPORTED_DTYPES})")
if type(d.attn_bias) not in cls.SUPPORTED_ATTN_BIAS_TYPES:
reasons.append(f"attn_bias type is {type(d.attn_bias)}")
if (d.p != 0.0) and not cls.SUPPORTS_DROPOUT:
reasons.append("dropout > 0.0")
if d.scale is not None and not cls.SUPPORTS_CUSTOM_SCALE:
reasons.append("has custom scale")
# bfloat16 is only supported on A100+
# ... although the kernels can still run and give the
# correct result
if dtype is torch.bfloat16 and (
not device_type.startswith("cuda")
):
reasons.append("bf16 is only supported on A100+ GPUs")
if not cls.is_available():
reasons.append(
"operator wasn't built - see `python -m xformers.info` for more info"
)
if not cls.IS_DETERMINISTIC and torch.are_deterministic_algorithms_enabled():
reasons.append(
"operator is non-deterministic, but `torch.use_deterministic_algorithms` is set"
)
if not cls.SUPPORTS_BMGHK and d.query.ndim == 5:
reasons.append("operator does not support BMGHK format")
return reasons
class AttentionFwOpBase(AttentionOpBase):
ERROR_ATOL: Mapping[torch.dtype, float] = {
torch.float: 3e-4,
torch.half: 4e-3,
torch.bfloat16: 2e-2,
}
ERROR_RTOL: Mapping[torch.dtype, float] = {
torch.float: 2e-5,
torch.half: 4e-4,
torch.bfloat16: 5e-3,
}
@classmethod
def apply(
cls, inp: Inputs, needs_gradient: bool
) -> Tuple[torch.Tensor, Optional[Context]]:
raise NotImplementedError()
@classmethod
def attn_operator_flop(
cls,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
causal: bool = False,
seqstart_k: Optional[torch.Tensor] = None,
seqstart_q: Optional[torch.Tensor] = None,
) -> int:
"""
Computes total flops for the attention
Assumes inputs in format BMHK
"""
assert query.ndim == 4
if seqstart_q is not None:
seqstart_q_py = seqstart_q.tolist()
else:
seqstart_q_py = [0, query.shape[1]]
if seqstart_k is not None:
seqstart_k_py = seqstart_k.tolist()
else:
seqstart_k_py = [0, key.shape[1]]
total_flop = 0
for q_start, q_end, k_start, k_end in zip(
seqstart_q_py, seqstart_q_py[1:], seqstart_k_py, seqstart_k_py[1:]
):
num_q = q_end - q_start
num_kv = k_end - k_start
# (M,K) @ (K,N) GEMM needs M*N*K*2 flop
# Q @ K.transpose
total_flop += num_q * num_kv * query.shape[-1] * 2
# (ignore softmax)
# attn @ V
total_flop += num_q * key.shape[-1] * num_kv * 2
# Multiply by num_heads and batches
total_flop = total_flop * value.shape[2] * value.shape[0]
if causal:
total_flop //= 2
return total_flop
class AttentionBwOpBase(AttentionOpBase):
ERROR_ATOL: Mapping[torch.dtype, float] = {
torch.float: 5e-4,
torch.half: 9e-2,
torch.bfloat16: 0.7,
}
ERROR_RTOL: Mapping[torch.dtype, float] = {
torch.float: 1e-4,
torch.half: 2e-2,
torch.bfloat16: 0.1,
}
SUPPORTS_ATTN_BIAS_GRAD = False
@classmethod
def not_supported_reasons(cls, d: Inputs) -> List[str]:
reasons = super(AttentionBwOpBase, cls).not_supported_reasons(d)
if (
isinstance(d.attn_bias, torch.Tensor)
and d.attn_bias.requires_grad
and not cls.SUPPORTS_ATTN_BIAS_GRAD
):
reasons.append(
"Computing the bias gradient is not supported (attn_bias.requires_grad = True)"
)
return reasons
@classmethod
def apply(cls, ctx: Context, inp: Inputs, grad: torch.Tensor) -> Gradients:
raise NotImplementedError()
@classmethod
def attn_operator_flop(
cls,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
causal: bool = False,
seqstart_k: Optional[torch.Tensor] = None,
seqstart_q: Optional[torch.Tensor] = None,
) -> int:
"""
Computes total flops for the attention
Assumes inputs in format BMHK
"""
assert query.ndim == 4
if seqstart_q is not None:
seqstart_q_py = seqstart_q.tolist()
else:
seqstart_q_py = [0, query.shape[1]]
if seqstart_k is not None:
seqstart_k_py = seqstart_k.tolist()
else:
seqstart_k_py = [0, key.shape[1]]
total_flop = 0
for q_start, q_end, k_start, k_end in zip(
seqstart_q_py, seqstart_q_py[1:], seqstart_k_py, seqstart_k_py[1:]
):
num_q = q_end - q_start
num_kv = k_end - k_start
Kqk = query.shape[-1]
Kv = value.shape[-1]
# (M,K) @ (K,N) GEMM needs M*N*K*2 flop
# att = Q @ K.transpose
total_flop += num_q * num_kv * Kqk * 2
# att @ dO
total_flop += num_kv * num_q * Kv * 2
# dov = dO @ V
total_flop += num_q * Kv * num_kv * 2
# dov @ K
total_flop += num_q * Kqk * num_kv * 2
# dov @ Q
total_flop += num_q * Kqk * num_kv * 2
# Multiply by num_heads and batches
total_flop = total_flop * value.shape[2] * value.shape[0]
if causal:
total_flop //= 2
return total_flop
AttentionOp = Tuple[
Optional[Type[AttentionFwOpBase]], Optional[Type[AttentionBwOpBase]]
]
@dataclass
class AttentionOpDispatch:
"""Dispatcher to automatically select
the best operator to run memory-efficient attention.
:Deprecated:
This class is deprecated and will be removed in a later version
"""
op: AttentionOp
@classmethod
def from_arguments(
cls,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attn_bias: Optional[Union[torch.Tensor, AttentionBias]] = None,
p: float = 0.0,
scale: Optional[float] = None,
) -> "AttentionOpDispatch":
"""Here for backward compatibility"""
from .dispatch import _dispatch_bw, _dispatch_fw
inp = Inputs(
query=query,
key=key,
value=value,
attn_bias=attn_bias,
p=p,
scale=scale,
)
return AttentionOpDispatch(op=(_dispatch_fw(inp, True), _dispatch_bw(inp)))
def bmk2bmhk(tensor, num_heads: int) -> torch.Tensor:
if tensor.ndim == 4:
return tensor
return tensor.reshape([-1, num_heads, tensor.shape[1], tensor.shape[2]]).permute(
(0, 2, 1, 3)
)
def check_lastdim_alignment_stride1(
reasons: List[str], name: str, x: torch.Tensor, alignment: int
) -> None:
if x.shape[-1] % alignment != 0:
reasons.append(f"{name}.shape[-1] % {alignment} != 0")
elif x.stride(-2) % alignment != 0:
reasons.append(
f"{name}.stride(-2) % {alignment} != 0 ({name}.stride() = {x.stride()})"
)
# We can have stride=0 sometimes if dimension=1
if x.stride(-1) > 1:
reasons.append(
f"{name}.stride(-1) > 1 ({name}.stride() = {x.stride()}) - you should call `.contiguous()` on the input"
)

View File

@@ -0,0 +1,479 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
from dataclasses import replace
from enum import Enum
from typing import Any, List, Mapping, Optional, Set, Tuple, Union
import torch
from ..common import get_xformers_operator, register_operator
from . import attn_bias
from .attn_bias import (
AttentionBias,
BlockDiagonalCausalLocalAttentionFromBottomRightMask,
BlockDiagonalCausalLocalAttentionMask,
BlockDiagonalCausalMask,
BlockDiagonalCausalWithOffsetPaddedKeysMask,
BlockDiagonalMask,
LowerTriangularMask,
LowerTriangularMaskWithTensorBias,
)
from .common import (
AttentionBwOpBase,
AttentionFwOpBase,
Context,
Gradients,
Inputs,
check_lastdim_alignment_stride1,
)
def _uses_tensorcores(sm: int, is_half: bool) -> bool:
if sm >= 80:
return True
if sm >= 70:
return is_half
return False
def _minimum_gemm_alignment(inp: Inputs) -> int:
if inp.device.type != "cuda":
return 1
cap = torch.cuda.get_device_capability(inp.device)
sm = cap[0] * 10 + cap[1]
bits_per_scalar = {torch.float: 32, torch.half: 16, torch.bfloat16: 16}[
inp.query.dtype
]
uses_tensorcores = _uses_tensorcores(sm, bits_per_scalar == 16)
matmul_alignment_mn = 1
if sm >= 80:
matmul_alignment_mn = 4
if uses_tensorcores:
matmul_alignment_mn = max(matmul_alignment_mn, 128 // bits_per_scalar)
return matmul_alignment_mn
def _get_seqlen_info(
inp: Inputs,
) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor], int, int]:
attn_bias = inp.attn_bias
if isinstance(
attn_bias, (BlockDiagonalMask, BlockDiagonalCausalWithOffsetPaddedKeysMask)
):
attn_bias.k_seqinfo.to(inp.query.device)
attn_bias.q_seqinfo.to(inp.query.device)
seqstart_k = attn_bias.k_seqinfo.seqstart
seqstart_q = attn_bias.q_seqinfo.seqstart
max_seqlen_q = attn_bias.q_seqinfo.max_seqlen
max_seqlen_k = attn_bias.k_seqinfo.max_seqlen
else:
seqstart_k = None
seqstart_q = None
max_seqlen_q = -1
max_seqlen_k = -1
return seqstart_k, seqstart_q, max_seqlen_q, max_seqlen_k
def _get_tensor_bias(
attn_bias: Optional[Union[torch.Tensor, AttentionBias]]
) -> Optional[torch.Tensor]:
if isinstance(attn_bias, torch.Tensor):
return attn_bias
elif isinstance(attn_bias, LowerTriangularMaskWithTensorBias):
return attn_bias._bias
return None
def _check_bias_alignment(
reasons: List[str], attn_bias: Optional[Union[torch.Tensor, AttentionBias]]
) -> None:
attn_bias_tensor = _get_tensor_bias(attn_bias)
if attn_bias_tensor is not None:
alignment = 128 // torch.finfo(attn_bias_tensor.dtype).bits
show_padding_hint = False
for d in range(attn_bias_tensor.ndim - 1):
if attn_bias_tensor.stride(d) % alignment != 0:
reasons.append(
f"attn_bias.stride(-2) % {alignment} != 0 (attn_bias.stride() = {attn_bias_tensor.stride()})"
)
show_padding_hint = True
if show_padding_hint:
reasons.append(
"""\
HINT: To use an `attn_bias` with a sequence length that is not a multiple of 8, \
you need to ensure memory is aligned by slicing a bigger tensor. \
Example: use `attn_bias = torch.zeros([1, 1, 5, 8])[:,:,:,:5]` instead of `torch.zeros([1, 1, 5, 5])`"""
)
# We can have stride=0 sometimes if dimension=1
if attn_bias_tensor.stride(-1) > 1:
reasons.append(
f"attn_bias.stride(-1) > 1 (attn_bias.stride() = {attn_bias_tensor.stride()}) - "
"you should call `.contiguous()` on the bias"
)
class _CustomMaskType(int, Enum):
"""
(Matches CustomMaskType in C++.)
"""
NoCustomMask = 0
CausalFromTopLeft = 1
CausalFromBottomRight = 2
def _custom_mask_type(bias: Optional[Union[torch.Tensor, AttentionBias]]) -> int:
if isinstance(
bias,
(
LowerTriangularMask,
BlockDiagonalCausalMask,
BlockDiagonalCausalLocalAttentionMask,
),
):
return int(_CustomMaskType.CausalFromTopLeft)
if isinstance(
bias,
(
attn_bias.BlockDiagonalCausalFromBottomRightMask,
BlockDiagonalCausalWithOffsetPaddedKeysMask,
BlockDiagonalCausalLocalAttentionFromBottomRightMask,
),
):
return int(_CustomMaskType.CausalFromBottomRight)
return int(_CustomMaskType.NoCustomMask)
@register_operator
class FwOp(AttentionFwOpBase):
"""xFormers' MHA kernel based on CUTLASS.
Supports a large number of settings (including without TensorCores, f32 ...)
and GPUs as old as P100 (Sm60)
"""
OPERATOR = get_xformers_operator("efficient_attention_forward_cutlass")
SUPPORTED_DEVICES: Set[str] = {"cuda"}
SUPPORTED_DTYPES: Set[torch.dtype] = {torch.float, torch.half, torch.bfloat16}
SUPPORTED_MAX_K = 65536
SUPPORTED_ATTN_BIAS_TYPES: Set[Any] = {
type(None),
torch.Tensor,
LowerTriangularMask,
LowerTriangularMaskWithTensorBias,
BlockDiagonalMask,
BlockDiagonalCausalMask,
BlockDiagonalCausalWithOffsetPaddedKeysMask,
attn_bias.BlockDiagonalCausalFromBottomRightMask,
attn_bias.BlockDiagonalCausalLocalAttentionMask,
BlockDiagonalCausalLocalAttentionFromBottomRightMask,
}
SUPPORTS_DROPOUT = True
SUPPORTS_CUSTOM_SCALE = True
SUPPORTS_DIFFERENT_VALUE_EMBED = True
SUPPORTS_BMGHK = True
NAME = "cutlassF"
_TEST_K: List[int] = [
32, # 64x64 kernel
128, # 64x128 kernel
256, # 64x128 with accumulation in gmem
]
@classmethod
def apply(
cls, inp: Inputs, needs_gradient: bool
) -> Tuple[torch.Tensor, Optional[Context]]:
if type(inp.attn_bias) not in FwOp.SUPPORTED_ATTN_BIAS_TYPES:
raise NotImplementedError("Unsupported attn_bias type")
if inp.query.ndim in [3, 4]:
return cls.apply_bmhk(inp, needs_gradient=needs_gradient)
assert inp.query.ndim == 5, f"query has shape {inp.query.shape}"
# Workaround until this is properly implemented in C++
# run each head group in a different stream
n_groups = inp.key.shape[2]
main_stream = torch.cuda.current_stream()
streams = [main_stream] + [
torch.cuda.Stream(device=inp.query.device) for _ in range(n_groups - 1)
]
outs = []
for group, stream in enumerate(streams):
stream.wait_stream(main_stream)
with torch.cuda.stream(stream):
query = inp.query[:, :, group]
key = inp.key[:, :, group]
value = inp.value[:, :, group]
bias = inp.attn_bias
if isinstance(bias, torch.Tensor):
bias = bias[:, group]
if isinstance(bias, attn_bias.LowerTriangularMaskWithTensorBias):
bias = attn_bias.LowerTriangularMaskWithTensorBias(
bias._bias[:, group]
)
outs.append(
cls.apply_bmhk(
replace(inp, query=query, key=key, value=value, attn_bias=bias),
needs_gradient=needs_gradient,
)
)
for s in streams[1:]:
main_stream.wait_stream(s)
out = torch.stack([o[0] for o in outs], dim=2)
ctx: Optional[Context] = None
if needs_gradient:
ctx = Context(
out=out,
lse=torch.stack([o[1].lse for o in outs], dim=1), # type: ignore
op_bw=outs[0][1].op_bw, # type: ignore
)
return out, ctx
@classmethod
def apply_bmhk(
cls, inp: Inputs, needs_gradient: bool
) -> Tuple[torch.Tensor, Optional[Context]]:
if type(inp.attn_bias) not in FwOp.SUPPORTED_ATTN_BIAS_TYPES:
raise NotImplementedError("Unsupported attn_bias type")
seqstart_k, seqstart_q, max_seqlen_q, _ = _get_seqlen_info(inp)
out, lse, rng_seed, rng_offset = cls.OPERATOR(
query=inp.query,
key=inp.key,
value=inp.value,
attn_bias=_get_tensor_bias(inp.attn_bias),
seqstart_q=seqstart_q,
seqstart_k=seqstart_k,
max_seqlen_q=max_seqlen_q,
dropout_p=inp.p,
compute_logsumexp=needs_gradient,
custom_mask_type=_custom_mask_type(inp.attn_bias),
scale=inp.scale,
seqlen_k=inp.attn_bias.k_seqinfo.seqlen
if isinstance(inp.attn_bias, BlockDiagonalCausalWithOffsetPaddedKeysMask)
else None,
window_size=inp.attn_bias._window_size
if isinstance(
inp.attn_bias,
(
BlockDiagonalCausalLocalAttentionMask,
BlockDiagonalCausalLocalAttentionFromBottomRightMask,
),
)
else None,
)
ctx: Optional[Context] = None
if needs_gradient:
ctx = Context(
out=out,
lse=lse,
# cutlass forward is only compatible with cutlass backward if
# dropout is used (because of the way RNG states are passed and the
# way random numbers are generated during backward)
op_bw=BwOp if inp.p != 0 else None,
)
if inp.p != 0:
ctx.rng_state = torch.tensor(
[rng_seed, rng_offset], dtype=torch.int64, device="cpu"
)
return out, ctx
@classmethod
def not_supported_reasons(cls, d: Inputs) -> List[str]:
reasons = super(FwOp, cls).not_supported_reasons(d)
matmul_alignment_mn = _minimum_gemm_alignment(d)
check_lastdim_alignment_stride1(reasons, "query", d.query, matmul_alignment_mn)
check_lastdim_alignment_stride1(reasons, "value", d.value, matmul_alignment_mn)
_check_bias_alignment(reasons, d.attn_bias)
return reasons
@classmethod
# type: ignore
def operator_flop(
cls,
q,
k,
v,
b,
seqstart_q,
seqstart_k,
max_seqlen_q_,
compute_lse,
custom_mask_type,
*a,
) -> int:
return cls.attn_operator_flop(
q,
k,
v,
causal=custom_mask_type > 0,
seqstart_k=seqstart_k,
seqstart_q=seqstart_q,
)
@register_operator
class BwOp(AttentionBwOpBase):
__doc__ = FwOp.__doc__
OPERATOR = get_xformers_operator("efficient_attention_backward_cutlass")
SUPPORTED_DEVICES = FwOp.SUPPORTED_DEVICES
SUPPORTED_DTYPES = FwOp.SUPPORTED_DTYPES
SUPPORTED_MAX_K = FwOp.SUPPORTED_MAX_K
SUPPORTED_ATTN_BIAS_TYPES: Set[Any] = {
type(None),
torch.Tensor,
LowerTriangularMask,
# TODO: Fix handling of gradient through the fMHA autograd function
# LowerTriangularMaskWithTensorBias,
BlockDiagonalMask,
BlockDiagonalCausalMask,
attn_bias.BlockDiagonalCausalFromBottomRightMask,
attn_bias.BlockDiagonalCausalLocalAttentionMask,
}
SUPPORTS_ATTN_BIAS_GRAD = True
SUPPORTS_DROPOUT = FwOp.SUPPORTS_DROPOUT
SUPPORTS_CUSTOM_SCALE = FwOp.SUPPORTS_CUSTOM_SCALE
SUPPORTS_DIFFERENT_VALUE_EMBED = FwOp.SUPPORTS_DIFFERENT_VALUE_EMBED
NAME = "cutlassB"
ERROR_ATOL: Mapping[torch.dtype, float] = {
torch.float: 5e-4,
# increased from 9e-2, more opportunities for numerical errors when bias is
# used, noticed in gK on SM80
torch.half: 1e-1,
torch.bfloat16: 7e-1,
}
_TEST_K: List[int] = [
32, # 64x64 kernel
128, # 64x128/128x128 kernel
256, # 64x128 with accumulation in gmem
]
@classmethod
def not_supported_reasons(cls, d: Inputs) -> List[str]:
reasons = super(BwOp, cls).not_supported_reasons(d)
matmul_alignment_mn = _minimum_gemm_alignment(d)
check_lastdim_alignment_stride1(reasons, "query", d.query, matmul_alignment_mn)
check_lastdim_alignment_stride1(reasons, "key", d.key, matmul_alignment_mn)
check_lastdim_alignment_stride1(reasons, "value", d.value, matmul_alignment_mn)
_check_bias_alignment(reasons, d.attn_bias)
attn_bias_tensor = _get_tensor_bias(d.attn_bias)
# Backprop of gradient through broadcasted bias is not supported
if attn_bias_tensor is not None and attn_bias_tensor.requires_grad:
# Don't forget that inputs are either in BMK or BMHK!
if d.query.ndim == 3 and attn_bias_tensor.ndim == 3:
expected_bias_shape = (*d.query.shape[:2], d.key.shape[1])
else:
# bias is B H Mq Mk
expected_bias_shape = (
d.query.shape[0],
d.query.shape[2] if d.query.ndim == 4 else 1,
d.query.shape[1],
d.key.shape[1],
)
if tuple(attn_bias_tensor.shape) != expected_bias_shape:
reasons.append(
"Broadcasting the `attn_bias` tensor is not supported "
f"(shape: {tuple(attn_bias_tensor.shape)}"
f"/ expected: {expected_bias_shape})"
)
return reasons
@classmethod
def apply(cls, ctx: Context, inp: Inputs, grad: torch.Tensor) -> Gradients:
if type(inp.attn_bias) not in BwOp.SUPPORTED_ATTN_BIAS_TYPES:
raise NotImplementedError("Unsupported attn_bias type")
seqstart_k, seqstart_q, max_seqlen_q, max_seqlen_k = _get_seqlen_info(inp)
dtype = inp.query.dtype
rng_seed = rng_offset = 0
if inp.p != 0.0:
if (
ctx.rng_state is None
or ctx.rng_state.dtype != torch.int64
or ctx.rng_state.device.type != "cpu"
or ctx.rng_state.shape != (2,)
):
raise NotImplementedError(f"Invalid rng_state: {ctx.rng_state}")
rng_seed, rng_offset = ctx.rng_state.tolist()
force_pad_inf = torch.cuda.get_device_capability(inp.query.device) == (7, 5)
(grad_q, grad_k, grad_v, grad_bias) = cls.OPERATOR(
grad.to(dtype),
inp.query,
inp.key,
inp.value,
_get_tensor_bias(inp.attn_bias),
cu_seqlens_q=seqstart_q,
cu_seqlens_k=seqstart_k,
max_seqlen_q=max_seqlen_q,
max_seqlen_k=max_seqlen_k,
logsumexp=ctx.get_padded_lse(32, force_pad_inf=force_pad_inf),
output=ctx.out.to(dtype),
dropout_p=inp.p,
# if not using dropout, seed and offset are irrelevant but still expected
# in function signature so just pass 0
# seed and offset could be None if a different FW op other than cutlass
# was used.
rng_seed=rng_seed,
rng_offset=rng_offset,
custom_mask_type=_custom_mask_type(inp.attn_bias),
scale=inp.scale,
num_splits_key=-1, # Let C++ determine it
window_size=inp.attn_bias._window_size
if isinstance(
inp.attn_bias,
(
BlockDiagonalCausalLocalAttentionMask,
BlockDiagonalCausalLocalAttentionFromBottomRightMask,
),
)
else None,
)
# c++/CUDA implementation returns an uninitialized tensor if bias doesn't
# require grad
if not (
isinstance(inp.attn_bias, torch.Tensor) and inp.attn_bias.requires_grad
):
grad_bias = None
return Gradients(dq=grad_q, dk=grad_k, dv=grad_v, db=grad_bias)
@classmethod
# type: ignore
def operator_flop(
cls,
dO,
q,
k,
v,
b,
cu_seqlens_q,
cu_seqlens_k,
max_seqlen_q,
max_seqlen_k,
logsumexp,
output,
dropout_p,
rng_seed,
rng_offset,
custom_mask_type,
scale,
) -> int:
return cls.attn_operator_flop(
q,
k,
v,
seqstart_q=cu_seqlens_q,
seqstart_k=cu_seqlens_k,
causal=custom_mask_type > 0,
)

View File

@@ -0,0 +1,108 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
from typing import Any, List, Optional, Set, Tuple
import numpy as np
import torch
from ..common import get_xformers_operator, register_operator
from .attn_bias import BlockDiagonalCausalWithOffsetPaddedKeysMask
from .common import AttentionFwOpBase, Context, Inputs
@register_operator
class FwOp(AttentionFwOpBase):
"""An operator optimized for very small values of K (``K <= 32``) \
and f32 pre-Ampere as it does not use TensorCores.
Only supports contiguous inputs in BMK format, so an extra reshape \
or contiguous call might be done.
:Deprecated:
This operator is deprecated and should not be used in new code
"""
OPERATOR = get_xformers_operator("efficient_attention_forward_decoder")
SUPPORTED_DEVICES = {"cuda"}
SUPPORTED_DTYPES = {torch.bfloat16, torch.half, torch.float32}
CUDA_MINIMUM_COMPUTE_CAPABILITY = (7, 0)
SUPPORTED_MAX_K: float = 128
SUPPORTED_ATTN_BIAS_TYPES: Set[Any] = {BlockDiagonalCausalWithOffsetPaddedKeysMask}
SUPPORTS_DROPOUT = False
SUPPORTS_CUSTOM_SCALE = True
NAME = "decoderF"
@classmethod
def not_supported_reasons(cls, d: Inputs) -> List[str]:
reasons = super(FwOp, cls).not_supported_reasons(d)
attn_bias = d.attn_bias
if isinstance(attn_bias, BlockDiagonalCausalWithOffsetPaddedKeysMask):
# If we don't get here, we've an error elsewhere
if d.query.ndim != 4 or d.key.ndim != 4:
reasons.append("Inputs must be BMHK. BMK not supported")
if d.query.shape[0] != 1:
reasons.append("One formal batch element expected")
if d.query.shape[-1] != 128:
reasons.append("Only head_dim==128 for now.")
if d.key.stride(-1) != 1:
reasons.append("expect keys to have last dim contiguous")
if d.value.stride(-1) != 1:
reasons.append("expect values to have last dim contiguous")
q_starts = attn_bias.q_seqinfo.seqstart_py
if attn_bias.q_seqinfo.max_seqlen != 1:
reasons.append("decoding expects one query")
elif d.query.shape[1] != len(q_starts) - 1:
reasons.append("empty lanes not supported yet")
if attn_bias.k_seqinfo.padding > 8192:
reasons.append("key padding exceeds 8192")
return reasons
@classmethod
def apply(
cls, inp: Inputs, needs_gradient: bool
) -> Tuple[torch.Tensor, Optional[Context]]:
if needs_gradient:
raise NotImplementedError("gradient")
attn_bias = inp.attn_bias
assert isinstance(attn_bias, BlockDiagonalCausalWithOffsetPaddedKeysMask)
attn_bias.k_seqinfo.to(inp.query.device)
attn_bias.q_seqinfo.to(inp.query.device)
padding = attn_bias.k_seqinfo.padding
multiquery = inp.key.stride(2) == 0
if multiquery:
key = inp.key[0, :, :1].unflatten(0, (-1, padding))
value = inp.value[0, :, :1].unflatten(0, (-1, padding))
else:
key = inp.key[0].unflatten(0, (-1, padding))
value = inp.value[0].unflatten(0, (-1, padding))
seq_positions = attn_bias.k_seqinfo.seqlen
query = inp.query[0, :, None]
if inp.scale is not None:
qk_scale = inp.scale
else:
qk_scale = 1.0 / np.sqrt(key.shape[-1])
out = cls.OPERATOR(
query=query,
key=key,
value=value,
seq_positions=seq_positions,
scale=qk_scale,
)
return out, None

View File

@@ -0,0 +1,147 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
import textwrap
from collections import deque
from typing import List, Sequence, Type, TypeVar
from . import attn_bias, cutlass, decoder, flash, small_k, triton, triton_splitk
from .common import AttentionBwOpBase, AttentionFwOpBase, Inputs
def _is_cutlass_fwd_faster_than_flash(inp: Inputs) -> bool:
return False
def _is_triton_fwd_fastest(inp: Inputs) -> bool:
# TODO: fill out
return False
T = TypeVar("T", Type[AttentionFwOpBase], Type[AttentionBwOpBase])
def _format_inputs_description(inp: Inputs) -> str:
return f"""query : shape={tuple(inp.query.shape)} ({inp.query.dtype})
key : shape={tuple(inp.key.shape)} ({inp.key.dtype})
value : shape={tuple(inp.value.shape)} ({inp.value.dtype})
attn_bias : {type(inp.attn_bias)}
p : {inp.p}"""
def _ensure_op_supports_or_raise(exc_type, name: str, op, inp: Inputs) -> None:
reasons = op.not_supported_reasons(inp)
if not reasons:
return
raise exc_type(
f"""Operator `{name}` does not support inputs:
{textwrap.indent(_format_inputs_description(inp), ' ')}
{_format_not_supported_reasons(op, reasons)}"""
)
def _format_not_supported_reasons(op, reasons: List[str]) -> str:
return f"`{op.NAME}` is not supported because:\n " + "\n ".join(reasons)
def _run_priority_list(name: str, priority_list: Sequence[T], inp: Inputs) -> T:
not_supported_reasons: List[List[str]] = []
for op in priority_list:
not_supported = op.not_supported_reasons(inp)
if not not_supported:
return op
not_supported_reasons.append(not_supported)
# Let's write a nice message explaining what we tried and why it's not supported
msg = f"""No operator found for `{name}` with inputs:
{textwrap.indent(_format_inputs_description(inp), ' ')}"""
for op, not_supported in zip(priority_list, not_supported_reasons):
msg += "\n" + _format_not_supported_reasons(op, not_supported)
raise NotImplementedError(msg)
def _dispatch_fw_priority_list(
inp: Inputs, needs_gradient: bool
) -> Sequence[Type[AttentionFwOpBase]]:
priority_list_ops = deque(
[
flash.FwOp,
triton.FwOp,
cutlass.FwOp,
small_k.FwOp,
]
)
if _is_cutlass_fwd_faster_than_flash(inp):
priority_list_ops.remove(cutlass.FwOp)
priority_list_ops.appendleft(cutlass.FwOp)
if _is_triton_fwd_fastest(inp):
priority_list_ops.remove(triton.FwOp)
priority_list_ops.appendleft(triton.FwOp)
if not needs_gradient:
mqa_or_gqa = (
inp.key.ndim > 3 and inp.key.stride(-2) == 0 and inp.key.shape[-2] > 1
)
if not mqa_or_gqa:
# With multiquery, cutlass is sometimes faster than decoder
# but it's not currently clear when.
priority_list_ops.appendleft(decoder.FwOp)
# Split-KV is useful with MQA
# for short Q-seqlen / long K-seqlen
if mqa_or_gqa and inp.query.shape[1] <= 32 and inp.key.shape[1] >= 256:
parallelism_BH = 0 # BMK
if inp.query.ndim == 3:
parallelism_BH = inp.query.shape[0]
elif inp.query.ndim == 4: # BMHK
parallelism_BH = inp.query.shape[0] * inp.query.shape[2]
elif inp.query.ndim == 5: # BMGHK
parallelism_BH = inp.query.shape[0] * inp.query.shape[2]
if parallelism_BH > 0 and parallelism_BH < 64:
priority_list_ops.appendleft(triton_splitk.FwOp)
# Without variable seqlen flash is fastest
if not isinstance(inp.attn_bias, attn_bias.BlockDiagonalMask):
priority_list_ops.remove(flash.FwOp)
priority_list_ops.appendleft(flash.FwOp)
return priority_list_ops
def _dispatch_fw(inp: Inputs, needs_gradient: bool) -> Type[AttentionFwOpBase]:
"""Computes the best operator for forward
Raises:
NotImplementedError: if not operator was found
Returns:
AttentionOp: The best operator for the configuration
"""
# return _run_priority_list(
# "memory_efficient_attention_forward",
# _dispatch_fw_priority_list(inp, needs_gradient),
# inp,
# )
return flash.FwOp
def _is_cutlassB_faster_than_flash(inp: Inputs) -> bool:
return False
def _dispatch_bw(inp: Inputs) -> Type[AttentionBwOpBase]:
priority_list_ops: List[Type[AttentionBwOpBase]] = [
flash.BwOp,
cutlass.BwOp,
# CUDA illegal memory issues, race conditions etc..
# triton.BwOp,
# Deprecated
small_k.BwOp,
]
# if _is_cutlassB_faster_than_flash(inp):
# priority_list_ops.remove(cutlass.BwOp)
# priority_list_ops.insert(0, cutlass.BwOp)
# return _run_priority_list(
# "memory_efficient_attention_backward", priority_list_ops, inp
# )
return flash.BwOp

View File

@@ -0,0 +1,666 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
from dataclasses import replace
from itertools import zip_longest
from typing import Any, List, Optional, Set, Tuple, Union
import os
import torch
from ..common import _get_storage_base, get_operator, register_operator
from .attn_bias import (
AttentionBias,
BlockDiagonalCausalFromBottomRightMask,
BlockDiagonalCausalLocalAttentionFromBottomRightMask,
BlockDiagonalCausalLocalAttentionMask,
BlockDiagonalCausalMask,
BlockDiagonalMask,
LowerTriangularMask,
)
from .common import (
AttentionBwOpBase,
AttentionFwOpBase,
Context,
Gradients,
Inputs,
check_lastdim_alignment_stride1,
)
global enable_ixdnn
enable_ixdnn = os.getenv('ENABLE_FLASH_ATTENTION_WITH_IXDNN', '1') != '0'
FLASH_VERSION = "0.0.0"
try:
try:
from ... import _C_flashattention # type: ignore[attr-defined]
from ..._cpp_lib import _build_metadata
if _build_metadata is not None:
FLASH_VERSION = _build_metadata.flash_version
except ImportError:
import flash_attn
from flash_attn.flash_attn_interface import flash_attn_cuda as _C_flashattention
FLASH_VERSION = flash_attn.__version__
flash_ver_parsed = tuple(int(s) for s in FLASH_VERSION.split(".")[:2])
if flash_ver_parsed < (2, 3):
raise ImportError("Requires 2.3 for sliding window support")
# create library so that flash-attn goes through the PyTorch Dispatcher
_flash_lib = torch.library.Library("xformers_flash", "DEF")
_flash_lib.define(
"flash_fwd(Tensor query, Tensor key, Tensor value, "
"Tensor? cu_seqlens_q, Tensor? cu_seqlens_k, "
"int max_seqlen_q, int max_seqlen_k, "
"float p, float softmax_scale, "
"bool is_causal, int window_size, bool return_softmax, bool use_alibi, int alibi_mode, int imp_mode) -> (Tensor, Tensor, Tensor, Tensor, Tensor, Tensor)"
)
_flash_lib.define(
"flash_bwd(Tensor dout, Tensor query, Tensor key, Tensor value, "
"Tensor out, Tensor softmax_lse_, Tensor dq, Tensor dk, Tensor dv, "
"Tensor cu_seqlens_q, Tensor cu_seqlens_k, "
"int max_seqlen_q, int max_seqlen_k, "
"float p, float softmax_scale, bool is_causal, int window_size, bool use_alibi, int alibi_mode, int imp_mode) -> (Tensor, Tensor, Tensor)"
)
def _flash_fwd(
query,
key,
value,
cu_seq_lens_q,
cu_seq_lens_k,
max_seq_len_q,
max_seq_len_k,
p,
softmax_scale,
is_causal,
window_size,
return_softmax,
use_alibi,
alibi_mode,
imp_mode
):
if enable_ixdnn:
(
out,
q_padded,
k_padded,
v_padded,
out_padded,
softmax_lse,
p,
) = _C_flashattention.fwd_ixdnn(
query,
key,
value,
None, # out
cu_seq_lens_q,
cu_seq_lens_k,
p,
is_causal,
return_softmax,
use_alibi,
alibi_mode,
imp_mode,
-1, -1,
None, # rng
)
else:
if cu_seq_lens_q is None:
assert cu_seq_lens_k is None
(
out,
q_padded,
k_padded,
v_padded,
out_padded,
softmax_lse,
p,
) = _C_flashattention.fwd(
query,
key,
value,
None, # out
p,
softmax_scale,
is_causal,
return_softmax,
use_alibi,
alibi_mode,
None, # rng
)
else:
out = query.new_empty(query.shape[0], query.shape[1], value.shape[2])
(
out,
q_padded,
k_padded,
v_padded,
out_padded,
softmax_lse,
p,
) = _C_flashattention.varlen_fwd(
query,
key,
value,
out,
cu_seq_lens_q,
cu_seq_lens_k,
max_seq_len_q,
max_seq_len_k,
p,
softmax_scale,
False,
is_causal,
return_softmax,
use_alibi,
alibi_mode,
None,
)
return out, softmax_lse, q_padded, k_padded, v_padded, out_padded
def _flash_bwd(
grad,
query,
key,
value,
out,
lse,
dq,
dk,
dv,
cu_seq_lens_q,
cu_seq_lens_k,
max_seq_len_q,
max_seq_len_k,
p,
softmax_scale,
is_causal,
window_size,
use_alibi,
alibi_mode,
imp_mode
):
if enable_ixdnn:
_C_flashattention.bwd_ixdnn(
grad,
query,
key,
value,
out,
lse,
dq,
dk,
dv,
cu_seq_lens_q,
cu_seq_lens_k,
p,
is_causal,
use_alibi,
alibi_mode,
imp_mode,
-1, -1,
None,
)
else:
if cu_seq_lens_k is None:
assert cu_seq_lens_q is None
_C_flashattention.bwd(
grad,
query,
key,
value,
out,
lse,
dq,
dk,
dv,
p,
softmax_scale,
is_causal,
use_alibi,
alibi_mode,
None,
)
else:
_C_flashattention.varlen_bwd(
grad,
query,
key,
value,
out,
lse,
dq,
dk,
dv,
cu_seq_lens_q,
cu_seq_lens_k,
max_seq_len_q,
max_seq_len_k,
p,
softmax_scale,
False, # zero_tensors
is_causal,
use_alibi,
alibi_mode,
None,
)
return dq, dk, dv
_flash_lib.impl("flash_fwd", _flash_fwd, "CUDA")
_flash_lib.impl("flash_bwd", _flash_bwd, "CUDA")
except ImportError:
pass
def _convert_input_format(
inp: Inputs,
) -> Tuple[Inputs, Optional[torch.Tensor], int, Optional[torch.Tensor], int]:
assert inp.query.ndim in [4, 5]
query, key, value = inp.query, inp.key, inp.value
batch = query.shape[0]
seqlen_q = query.shape[1]
seqlen_kv = key.shape[1]
head_dim_q = query.shape[-1]
head_dim_v = value.shape[-1]
attn_bias = inp.attn_bias
if enable_ixdnn:
if query.shape[0] == 1:
cu_seqlen_q = torch.tensor([seqlen_q], dtype=torch.int32, device=query.device)
cu_seqlen_k = torch.tensor([seqlen_kv], dtype=torch.int32, device=key.device)
else:
cu_seqlen_q = torch.full((seqlen_q,), seqlen_q, dtype=torch.int32, device=query.device)
cu_seqlen_k = torch.full((seqlen_kv,), seqlen_kv, dtype=torch.int32, device=key.device)
max_seqlen_q = inp.query.shape[1]
max_seqlen_k = inp.key.shape[1]
else:
if isinstance(attn_bias, BlockDiagonalMask):
attn_bias.k_seqinfo.seqstart = attn_bias.k_seqinfo.seqstart.to(
inp.query.device, non_blocking=True
)
attn_bias.q_seqinfo.seqstart = attn_bias.q_seqinfo.seqstart.to(
inp.query.device, non_blocking=True
)
cu_seqlen_k = attn_bias.k_seqinfo.seqstart
cu_seqlen_q = attn_bias.q_seqinfo.seqstart
max_seqlen_q = attn_bias.q_seqinfo.max_seqlen
max_seqlen_k = attn_bias.k_seqinfo.max_seqlen
else:
cu_seqlen_k = None
cu_seqlen_q = None
max_seqlen_q = inp.query.shape[1]
max_seqlen_k = inp.key.shape[1]
if query.ndim == 5: # QGA
# Fold the group/head_in_group dimensions together
def fold(x):
# Either the head is replicated
if x.stride(3) == 0:
return x[:, :, :, 0]
# Or we reshape
return x.reshape(
[
x.shape[0],
x.shape[1],
-1,
x.shape[4],
]
)
query = fold(query)
key = fold(key)
value = fold(value)
# Optimize for MHA
if key.ndim == 4 and key.stride(2) == 0 and value.stride(2) == 0:
key = key[:, :, :1]
value = value[:, :, :1]
# Initially we have `query.shape = [batch, seqlen, head_dim_q]`
# We want format `[batch * seqlen, num_heads, head_dim_q]`
if cu_seqlen_k is not None and os.getenv('ENABLE_FLASH_ATTENTION_WITH_IXDNN', '1') == '0':
query = query.reshape([batch * seqlen_q, -1, head_dim_q])
key = key.reshape([batch * seqlen_kv, -1, head_dim_q])
value = value.reshape([batch * seqlen_kv, -1, head_dim_v])
new_inp = replace(
inp,
query=query,
key=key,
value=value,
)
return new_inp, cu_seqlen_q, max_seqlen_q, cu_seqlen_k, max_seqlen_k
def _is_causal(attn_bias: Optional[Union[torch.Tensor, AttentionBias]]) -> bool:
return isinstance(
attn_bias,
(
LowerTriangularMask,
BlockDiagonalCausalMask,
BlockDiagonalCausalLocalAttentionMask,
BlockDiagonalCausalFromBottomRightMask,
BlockDiagonalCausalLocalAttentionFromBottomRightMask,
),
)
def _window_size(attn_bias: Optional[Union[torch.Tensor, AttentionBias]]) -> int:
if isinstance(
attn_bias,
(BlockDiagonalCausalLocalAttentionMask,),
):
return attn_bias._window_size or 0
if isinstance(attn_bias, BlockDiagonalCausalLocalAttentionFromBottomRightMask):
return attn_bias._window_size
return 0
def _check_needs_no_topleft(d: Inputs, reasons: List[str]) -> None:
# Flash does not support TopLeft, so only allow causal masks with TopLeft
# if each batch element has equal number of queries and keys.
if isinstance(d.attn_bias, BlockDiagonalCausalMask):
# Flash does not support TopLeft, so only allow BlockDiagonalCausalMask
# if each batch element has equal number of queries and keys.
for k_start, q_start in zip_longest(
d.attn_bias.k_seqinfo.seqstart_py, d.attn_bias.q_seqinfo.seqstart_py
):
if k_start != q_start:
reasons.append(
"Only support BlockDiagonalCausalMask if equal"
" numbers of keys and queries"
)
break
elif isinstance(d.attn_bias, LowerTriangularMask):
if d.query.shape[1] != d.key.shape[1]:
reasons.append(
"Only support LowerTriangularMask if equal number of" "keys and queries"
)
def _check_strides_for_bmghk(x: torch.Tensor, name: str, reasons: List[str]) -> None:
"""
We want to be able to collapse the G/H dimensions together
"""
if x.ndim == 5:
stride_g, stride_h = x.stride(2), x.stride(3)
if x.shape[2] == 1:
return
if x.shape[3] == 1 or stride_h == 0:
return
if stride_g != stride_h * x.shape[-2]:
reasons.append(
f"GQA is only supported when the G/H dimensions are contiguous\n"
f" {name}.stride: {x.stride()}\n"
f" {name}.shape : {list(x.shape)}"
)
@register_operator
class FwOp(AttentionFwOpBase):
"""Operator that computes memory-efficient attention using \
`Flash-Attention <https://github.com/HazyResearch/flash-attention>`_ \
implementation.
"""
OPERATOR = get_operator("xformers_flash", "flash_fwd")
SUPPORTED_DEVICES: Set[str] = {"cuda"}
SUPPORTED_DTYPES: Set[torch.dtype] = {torch.half, torch.bfloat16}
SUPPORTED_MAX_K = 256
SUPPORTED_ATTN_BIAS_TYPES: Set[Any] = {
type(None),
LowerTriangularMask,
BlockDiagonalMask,
BlockDiagonalCausalMask,
BlockDiagonalCausalLocalAttentionMask,
BlockDiagonalCausalLocalAttentionFromBottomRightMask,
BlockDiagonalCausalFromBottomRightMask,
}
SUPPORTS_DROPOUT = True
SUPPORTS_CUSTOM_SCALE = True
SUPPORTS_DIFFERENT_VALUE_EMBED = False
SUPPORTS_BMGHK = True
NAME = f"flshattF@{FLASH_VERSION}"
VERSION = FLASH_VERSION
@classmethod
def not_supported_reasons(cls, d: Inputs) -> List[str]:
reasons = super(FwOp, cls).not_supported_reasons(d)
check_lastdim_alignment_stride1(reasons, "query", d.query, 8)
_check_needs_no_topleft(d, reasons)
_check_strides_for_bmghk(d.query, "query", reasons)
_check_strides_for_bmghk(d.key, "key", reasons)
_check_strides_for_bmghk(d.value, "value", reasons)
return reasons
@classmethod
def apply(
cls, inp: Inputs, needs_gradient: bool
) -> Tuple[torch.Tensor, Optional[Context]]:
return_softmax = False
out_shape = [
*inp.query.shape[:-1],
inp.value.shape[-1],
]
# no cumulative seqlen
(
inp,
cu_seqlens_q,
max_seqlen_q,
cu_seqlens_k,
max_seqlen_k,
) = _convert_input_format(inp)
out, softmax_lse, q_padded, k_padded, v_padded, o_padded = cls.OPERATOR(
inp.query,
inp.key,
inp.value,
cu_seqlens_q,
cu_seqlens_k,
max_seqlen_q,
max_seqlen_k,
inp.p,
inp.scale_float,
_is_causal(inp.attn_bias),
_window_size(inp.attn_bias),
return_softmax,
inp.use_alibi,
inp.alibi_mode,
inp.imp_mode,
)
out = out.reshape(out_shape)
ctx = Context(out=out, lse=softmax_lse, q_padded=q_padded, k_padded=k_padded, v_padded=v_padded, o_padded=o_padded)
if inp.p != 0.0:
ctx.op_bw = BwOp
return (out, ctx)
@classmethod
# type: ignore
def operator_flop(
cls,
query,
key,
value,
cu_seq_lens_q,
cu_seq_lens_k,
max_seq_len_q,
max_seq_len_k,
p,
softmax_scale,
causal,
return_softmax,
) -> int:
return cls.attn_operator_flop(
query.unsqueeze(0),
key.unsqueeze(0),
value.unsqueeze(0),
causal=causal,
seqstart_k=cu_seq_lens_k,
seqstart_q=cu_seq_lens_q,
)
@register_operator
class BwOp(AttentionBwOpBase):
__doc__ = FwOp.__doc__
OPERATOR = get_operator("xformers_flash", "flash_bwd")
SUPPORTED_DEVICES = FwOp.SUPPORTED_DEVICES
SUPPORTED_DTYPES = FwOp.SUPPORTED_DTYPES
SUPPORTED_MAX_K = FwOp.SUPPORTED_MAX_K
SUPPORTED_ATTN_BIAS_TYPES = FwOp.SUPPORTED_ATTN_BIAS_TYPES
SUPPORTS_DROPOUT = FwOp.SUPPORTS_DROPOUT
SUPPORTS_CUSTOM_SCALE = FwOp.SUPPORTS_CUSTOM_SCALE
SUPPORTS_DIFFERENT_VALUE_EMBED = FwOp.SUPPORTS_DIFFERENT_VALUE_EMBED
IS_DETERMINISTIC = False
SUPPORTS_BMGHK = False # NOTE: Don't forget to update fmha doc when changing this!
NAME = f"flshattB@{FLASH_VERSION}"
VERSION = FLASH_VERSION
MAX_HEADDIM_SM8x = 192
@classmethod
def shape_not_supported_reasons(
cls, Mq: int, Mkv: int, K: int, Kv: int
) -> List[str]:
reasons = super().shape_not_supported_reasons(Mq, Mkv, K, Kv)
# In fbcode in mode/dev-nosan, we get nans from flash v2.1 if there
# is a strange embedding dimension.
# if K not in {8, 16, 32, 64, 128, 256}:
# reasons.append(f"Embed dim {K} not supported")
return reasons
@classmethod
def not_supported_reasons(cls, d: Inputs) -> List[str]:
reasons = super(BwOp, cls).not_supported_reasons(d)
check_lastdim_alignment_stride1(reasons, "query", d.query, 8)
_check_needs_no_topleft(d, reasons)
if d.device.type == "cuda":
# Due to limited shared-memory, some GPUs are limited in head dimension
device_capability = torch.cuda.get_device_capability(d.device)
is_sm80_or_sm90 = device_capability in [(8, 0), (9, 0)]
if (
max(d.key.shape[-1], d.query.shape[-1]) > cls.MAX_HEADDIM_SM8x
and not is_sm80_or_sm90
):
reasons.append(
"requires a GPU with compute capability 8.0 "
f"(A100) or 9.0 (H100) for 'query.shape[-1] > {cls.MAX_HEADDIM_SM8x}'"
)
return reasons
@classmethod
def apply(cls, ctx: Context, inp: Inputs, grad: torch.Tensor) -> Gradients:
dq_shape, dk_shape, dv_shape = inp.query.shape, inp.key.shape, inp.value.shape
(
inp,
cu_seqlens_q,
max_seqlen_q,
cu_seqlens_k,
max_seqlen_k,
) = _convert_input_format(inp)
assert ctx.lse.is_contiguous
ctx_lse = ctx.lse
if os.getenv('ENABLE_FLASH_ATTENTION_WITH_IXDNN', '1') == '0':
assert ctx_lse.shape[2] >= max_seqlen_q
if max_seqlen_q != ctx_lse.shape[2]:
ctx_lse = ctx_lse[:, :, :max_seqlen_q].contiguous()
kernel_out_shape = [
*inp.query.shape[:-1],
inp.value.shape[-1],
]
# Create dq,dk,dv
# If Q/K/V come from a single QKV tensor, let's put the gradient in the
# right strides, so we can avoid a `cat`
if (
ctx.q_padded.shape[0] == ctx.k_padded.shape[0]
and ctx.q_padded.shape[-1] == ctx.v_padded.shape[-1]
and _get_storage_base(ctx.q_padded) == _get_storage_base(ctx.k_padded)
and _get_storage_base(ctx.q_padded) == _get_storage_base(ctx.k_padded)
):
# Create one big contiguous chunk
# This is because q, k and v usually come from a single
# output of a linear layer that is chunked.
# Creating the gradients with the right layout saves us
# a `torch.cat` call in the backward pass
chunk = torch.empty(
(*ctx.q_padded.shape[0:-2], 3, ctx.q_padded.shape[-2], ctx.q_padded.shape[-1]),
dtype=ctx.q_padded.dtype,
device=inp.device,
)
grads = Gradients(
dq=chunk.select(-3, 0),
dk=chunk.select(-3, 1),
dv=chunk.select(-3, 2),
)
else:
grads = Gradients(
dq=torch.empty_like(ctx.q_padded),
dk=torch.empty_like(ctx.k_padded),
dv=torch.empty_like(ctx.v_padded),
)
assert grad.dtype in cls.SUPPORTED_DTYPES
cls.OPERATOR(
grad.reshape(kernel_out_shape).contiguous(),
ctx.q_padded,
ctx.k_padded,
ctx.v_padded,
# ctx.out.reshape(kernel_out_shape),
ctx.o_padded,
ctx_lse,
grads.dq,
grads.dk,
grads.dv,
cu_seqlens_q,
cu_seqlens_k,
max_seqlen_q,
max_seqlen_k,
inp.p,
inp.scale_float,
_is_causal(inp.attn_bias),
_window_size(inp.attn_bias),
inp.use_alibi,
inp.alibi_mode,
inp.imp_mode,
)
grads.dq = grads.dq[..., :dq_shape[-1]].reshape(dq_shape) # We could have padded the head dimension
grads.dk = grads.dk[..., :dk_shape[-1]].reshape(dk_shape)
grads.dv = grads.dv[..., :dv_shape[-1]].reshape(dv_shape)
return grads
@classmethod
# type: ignore
def operator_flop(
cls,
grad,
query,
key,
value,
out,
lse,
dq,
dk,
dv,
cu_seq_lens_q,
cu_seq_lens_k,
max_seq_len_q,
max_seq_len_k,
p,
softmax_scale,
causal,
) -> int:
return cls.attn_operator_flop(
query.unsqueeze(0),
key.unsqueeze(0),
value.unsqueeze(0),
causal=causal,
seqstart_k=cu_seq_lens_k,
seqstart_q=cu_seq_lens_q,
)

View File

@@ -0,0 +1,186 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
from typing import Any, List, Mapping, Optional, Set, Tuple, Union
import torch
from ..common import get_xformers_operator, register_operator
from .attn_bias import AttentionBias
from .common import (
AttentionBwOpBase,
AttentionFwOpBase,
Context,
Gradients,
Inputs,
bmk2bmhk,
)
def _bmhk2bmk_contiguous(tensor) -> torch.Tensor:
return (
tensor.permute((0, 2, 1, 3))
.contiguous()
.view([tensor.shape[0] * tensor.shape[2], tensor.shape[1], tensor.shape[3]])
.contiguous()
)
def _get_tensor_bias_bmk(
attn_bias: Optional[Union[torch.Tensor, AttentionBias]]
) -> Optional[torch.Tensor]:
if not isinstance(attn_bias, torch.Tensor):
assert attn_bias is None
return None
# BMK -> BMHK
if attn_bias.ndim == 4:
attn_bias = attn_bias.reshape([-1, *attn_bias.shape[2:]])
return attn_bias
@register_operator
class FwOp(AttentionFwOpBase):
"""An operator optimized for very small values of K (``K <= 32``) \
and f32 pre-Ampere as it does not use TensorCores.
Only supports contiguous inputs in BMK format, so an extra reshape \
or contiguous call might be done.
:Deprecated:
This operator is deprecated and should not be used in new code
"""
OPERATOR = get_xformers_operator("efficient_attention_forward_small_k")
SUPPORTED_DEVICES = {"cuda"}
SUPPORTED_DTYPES = {torch.float}
SUPPORTED_MAX_K: float = 32
SUPPORTED_ATTN_BIAS_TYPES: Set[Any] = {type(None), torch.Tensor}
SUPPORTS_DROPOUT = True
SUPPORTS_CUSTOM_SCALE = False
NAME = "smallkF"
BACKWARD_ERROR_ATOL: Mapping[torch.dtype, float] = {
torch.float: 4e-3,
}
# as this kernel is a bit slow, this should make tests run faster
_TEST_BATCH_SIZES = [1, 3]
_TEST_K = [2, 3, 8, 16, 32]
@classmethod
def not_supported_reasons(cls, d: Inputs) -> List[str]:
reasons = super(FwOp, cls).not_supported_reasons(d)
if isinstance(d.attn_bias, torch.Tensor) and d.attn_bias.stride(1) != 0:
reasons.append("bias with non-zero stride not supported")
buffer_size = 8
k = d.query.shape[-1]
for pack in [1, 2, 4]:
if (k % pack) == 0 and (k // pack) <= buffer_size:
return reasons
reasons.append(f"unsupported embed per head: {k}")
return reasons
@classmethod
def apply(
cls, inp: Inputs, needs_gradient: bool
) -> Tuple[torch.Tensor, Optional[Context]]:
if inp.scale is not None:
raise NotImplementedError("Unsupport custom scale")
num_heads = inp.query.shape[2]
query = _bmhk2bmk_contiguous(inp.query)
key = _bmhk2bmk_contiguous(inp.key)
value = _bmhk2bmk_contiguous(inp.value)
out, lse, rng_seed, rng_offset = cls.OPERATOR(
query=query,
key=key,
value=value,
compute_logsumexp=needs_gradient,
attn_bias=_get_tensor_bias_bmk(inp.attn_bias),
p=inp.p,
)
out = bmk2bmhk(out, num_heads)
lse = lse.reshape([lse.shape[0] // num_heads, num_heads, lse.shape[1]])
if not needs_gradient:
return out, None
ctx = Context(out=out, lse=lse)
if inp.p != 0.0:
ctx.op_bw = BwOp
ctx.rng_state = torch.tensor(
[rng_seed, rng_offset], dtype=torch.int64, device="cpu"
)
return out, ctx
@register_operator
class BwOp(AttentionBwOpBase):
__doc__ = FwOp.__doc__
OPERATOR = get_xformers_operator("efficient_attention_backward_small_k")
SUPPORTED_DEVICES = FwOp.SUPPORTED_DEVICES
SUPPORTED_DTYPES = FwOp.SUPPORTED_DTYPES
SUPPORTED_MAX_K = FwOp.SUPPORTED_MAX_K
SUPPORTED_ATTN_BIAS_TYPES = FwOp.SUPPORTED_ATTN_BIAS_TYPES
SUPPORTS_DROPOUT = FwOp.SUPPORTS_DROPOUT
SUPPORTS_CUSTOM_SCALE = FwOp.SUPPORTS_CUSTOM_SCALE
SUPPORTS_DIFFERENT_VALUE_EMBED = FwOp.SUPPORTS_DIFFERENT_VALUE_EMBED
# there is some extra precision loss in the CPU implementation due to an
# extra accumulation step in grad_q, which is not present in the CUDA
# implementation
ERROR_ATOL: Mapping[torch.dtype, float] = {
torch.float: 4e-3,
}
NAME = "smallkB"
@classmethod
def not_supported_reasons(cls, d: Inputs) -> List[str]:
reasons = super(BwOp, cls).not_supported_reasons(d)
if isinstance(d.attn_bias, torch.Tensor) and d.attn_bias.stride(1) != 0:
reasons.append("bias with non-zero stride not supported")
buffer_size = 8
k = d.query.shape[-1]
for pack in [1, 2, 4]:
if (k % pack) == 0 and (k // pack) <= buffer_size:
return reasons
reasons.append(f"unsupported embed per head: {k}")
return reasons
@classmethod
def apply(cls, ctx: Context, inp: Inputs, grad: torch.Tensor) -> Gradients:
num_heads = grad.shape[2]
grad = _bmhk2bmk_contiguous(grad)
query = _bmhk2bmk_contiguous(inp.query)
key = _bmhk2bmk_contiguous(inp.key)
value = _bmhk2bmk_contiguous(inp.value)
out = _bmhk2bmk_contiguous(ctx.out)
rng_seed = rng_offset = 0
if inp.p != 0.0:
if (
ctx.rng_state is None
or ctx.rng_state.dtype != torch.int64
or ctx.rng_state.device.type != "cpu"
or ctx.rng_state.shape != (2,)
):
raise NotImplementedError(f"Invalid rng_state: {ctx.rng_state}")
rng_seed, rng_offset = ctx.rng_state.tolist()
grad_q, grad_k, grad_v = cls.OPERATOR(
grad,
query,
key,
value,
# LSE: BHM -> (BH)M
ctx.lse.reshape([-1, ctx.lse.shape[-1]]),
out,
_get_tensor_bias_bmk(inp.attn_bias),
inp.p,
rng_seed,
rng_offset,
)
return Gradients(
dq=bmk2bmhk(grad_q, num_heads),
dk=bmk2bmhk(grad_k, num_heads),
dv=bmk2bmhk(grad_v, num_heads),
)

View File

@@ -0,0 +1,201 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
from dataclasses import replace
from typing import TYPE_CHECKING, Any, List, Optional, Set, Tuple
import torch
from ... import _is_triton_available
from ..common import register_operator
# This implementation needs pre-MLIR triton
# The BW pass is not stable/well tested
# And also does not have the latest improvements
if TYPE_CHECKING or (False and _is_triton_available()):
try:
from flash_attn.flash_attn_triton import (
_flash_attn_backward,
_flash_attn_forward,
)
except ImportError:
import importlib
import pathlib
import sys
import types
def import_module_from_path(path: str) -> types.ModuleType:
"""Import a module from the given path, w/o __init__.py"""
module_path = pathlib.Path(path).resolve()
module_name = module_path.stem # 'path/x.py' -> 'x'
spec = importlib.util.spec_from_file_location(module_name, module_path) # type: ignore
assert isinstance(spec, importlib.machinery.ModuleSpec)
module = importlib.util.module_from_spec(spec) # type: ignore
sys.modules[module_name] = module
assert isinstance(spec.loader, importlib.abc.Loader)
spec.loader.exec_module(module)
return module
flash_attn = import_module_from_path(
"third_party/flash-attention/flash_attn/flash_attn_triton.py"
)
_flash_attn_backward = flash_attn._flash_attn_backward
_flash_attn_forward = flash_attn._flash_attn_forward
triton_flash_backward = _flash_attn_backward
triton_flash_forward = _flash_attn_forward
else:
triton_flash_backward = None
triton_flash_forward = None
from .attn_bias import LowerTriangularMask
from .common import (
AttentionBwOpBase,
AttentionFwOpBase,
Context,
Gradients,
Inputs,
check_lastdim_alignment_stride1,
)
def _prepare_inputs(inp: Inputs) -> Inputs:
attn_bias = inp.attn_bias
if isinstance(attn_bias, torch.Tensor) and attn_bias.ndim == 3:
B = inp.query.shape[0]
h = attn_bias.shape[0] // B
attn_bias = attn_bias.reshape(B, h, attn_bias.shape[1], attn_bias.shape[2])
# Make sure that the last dimension is contiguous
query, key, value = [
x if x.stride(-1) == 1 else x.contiguous()
for x in [inp.query, inp.key, inp.value]
]
return replace(inp, attn_bias=attn_bias, query=query, key=key, value=value)
@register_operator
class FwOp(AttentionFwOpBase):
"""Operator that computes memory-efficient attention using \
`Tri Dao's <https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/flash_attn_triton.py>`_ \
implementation, based on
`Phil Tillet's code <https://github.com/openai/triton/blob/master/python/tutorials/06-fused-attention.py>`_
"""
OPERATOR = triton_flash_forward
SUPPORTED_DEVICES = {"cuda"}
CUDA_MINIMUM_COMPUTE_CAPABILITY = (8, 0)
SUPPORTED_DTYPES = {torch.half, torch.bfloat16}
SUPPORTED_MAX_K = 128
SUPPORTED_ATTN_BIAS_TYPES: Set[Any] = {
type(None),
LowerTriangularMask,
# TODO: backwards accuracy is failing for a few cases, perhaps we want to disable this for now.
# torch.Tensor,
}
SUPPORTS_DROPOUT = False
SUPPORTS_CUSTOM_SCALE = True
NAME = "tritonflashattF"
@classmethod
def not_supported_reasons(cls, d: Inputs) -> List[str]:
reasons = super(FwOp, cls).not_supported_reasons(d)
check_lastdim_alignment_stride1(reasons, "query", d.query, 8)
check_lastdim_alignment_stride1(reasons, "key", d.key, 8)
check_lastdim_alignment_stride1(reasons, "value", d.value, 8)
if cls.OPERATOR is None:
reasons.append("triton is not available")
if d.device.type == "cuda":
# Has only been tested on 8.0 / 9.0.
# Fails on 7.5 with illegal memory access
if torch.cuda.get_device_capability(d.device) < (8, 0):
reasons.append(
"requires GPU with sm80 minimum compute capacity, e.g., A100/H100/L4"
)
if _is_triton_available():
import triton
if triton.__version__ > "2.0.0":
reasons.append("Only work on pre-MLIR triton for now")
return reasons
@classmethod
def apply(
cls, inp: Inputs, needs_gradient: bool
) -> Tuple[torch.Tensor, Optional[Context]]:
inp = _prepare_inputs(inp)
out, lse, softmax_scale = triton_flash_forward(
q=inp.query,
k=inp.key,
v=inp.value,
bias=inp.attn_bias if isinstance(inp.attn_bias, torch.Tensor) else None,
softmax_scale=inp.scale_float,
causal=isinstance(inp.attn_bias, LowerTriangularMask),
)
return out, Context(lse=lse, out=out)
@register_operator
class BwOp(AttentionBwOpBase):
__doc__ = FwOp.__doc__
OPERATOR = triton_flash_backward
SUPPORTED_DEVICES = FwOp.SUPPORTED_DEVICES
CUDA_MINIMUM_COMPUTE_CAPABILITY = FwOp.CUDA_MINIMUM_COMPUTE_CAPABILITY
SUPPORTED_DTYPES = FwOp.SUPPORTED_DTYPES
SUPPORTED_MAX_K = FwOp.SUPPORTED_MAX_K
SUPPORTED_ATTN_BIAS_TYPES = FwOp.SUPPORTED_ATTN_BIAS_TYPES
SUPPORTS_DROPOUT = FwOp.SUPPORTS_DROPOUT
SUPPORTS_CUSTOM_SCALE = FwOp.SUPPORTS_CUSTOM_SCALE
SUPPORTS_DIFFERENT_VALUE_EMBED = FwOp.SUPPORTS_DIFFERENT_VALUE_EMBED
NAME = "tritonflashattB"
@classmethod
def not_supported_reasons(cls, d: Inputs) -> List[str]:
reasons = super(BwOp, cls).not_supported_reasons(d)
check_lastdim_alignment_stride1(reasons, "query", d.query, 8)
check_lastdim_alignment_stride1(reasons, "key", d.key, 8)
check_lastdim_alignment_stride1(reasons, "value", d.value, 8)
if cls.OPERATOR is None:
reasons.append("triton is not available")
if d.device.type == "cuda":
if torch.cuda.get_device_capability(d.device) != (8, 0):
reasons.append("requires A100 GPU")
if _is_triton_available():
import triton
if triton.__version__ > "2.0.0":
reasons.append("Only work on pre-MLIR triton for now")
return reasons
@classmethod
def apply(cls, ctx: Context, inp: Inputs, grad: torch.Tensor) -> Gradients:
inp = _prepare_inputs(inp)
# Triton's autotune causes the Tensor._version to change, and so Pytorch autograd
# does a memcpy. To avoid this we run in inference_mode, which doesn't track the version.
with torch.inference_mode():
grads = Gradients(
dq=torch.empty_like(inp.query),
dk=torch.empty_like(inp.key),
dv=torch.empty_like(inp.value),
)
cls.OPERATOR(
grad,
inp.query,
inp.key,
inp.value,
ctx.out,
ctx.get_padded_lse(128),
grads.dq,
grads.dk,
grads.dv,
bias=inp.attn_bias if isinstance(inp.attn_bias, torch.Tensor) else None,
softmax_scale=inp.scale_float,
causal=isinstance(inp.attn_bias, LowerTriangularMask),
)
return grads

View File

@@ -0,0 +1,738 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
from typing import TYPE_CHECKING, Any, List, Optional, Set, Tuple
import torch
from ..common import _has_triton21, register_operator
from .attn_bias import BlockDiagonalCausalWithOffsetPaddedKeysMask
from .common import AttentionFwOpBase, Context, Inputs, check_lastdim_alignment_stride1
def _strides(x: torch.Tensor, *stride_names: str):
assert x.ndim == len(stride_names)
return {f"stride_{s}": x.stride(i) for i, s in enumerate(stride_names)}
if TYPE_CHECKING or _has_triton21():
import triton
import triton.language as tl
from xformers.triton.vararg_kernel import VAR_ARGS_ARRAY, unroll_varargs
@triton.jit
def _fwd_kernel_splitK(
Q,
K,
V,
sm_scale,
Out_splitK, # [B, H, split_k, Mq, K]
Metadata, # [B, H, 2, split_k, M_ceil] contains [mi, li]
Seq_len,
stride_qz,
stride_qm,
stride_qg,
stride_qh,
stride_qk,
stride_kz,
stride_kn,
stride_kg,
stride_kh,
stride_kk,
stride_vz,
stride_vn,
stride_vg,
stride_vh,
stride_vk,
stride_osk_zhg,
stride_osk_s,
stride_osk_m,
stride_osk_k,
stride_mzhg,
stride_m2,
stride_ms,
stride_mm,
Z,
N_CTX_Q,
N_CTX_K,
BLOCK_N_PER_SPLIT,
H: tl.constexpr,
G: tl.constexpr,
BLOCK_M: tl.constexpr,
BLOCK_DMODEL: tl.constexpr,
BLOCK_N: tl.constexpr,
BOUNDS_CHECKS_N: tl.constexpr,
USE_SEQ_LEN: tl.constexpr,
PACKED_PER_VAL: tl.constexpr = 1,
N_GROUPS: tl.constexpr = 1,
):
"""This kernel can accept non-quantized or int4-quantized keys/values.
PACKED_PER_VAL determines the quantization type:
- PACKED_PER_VAL == 1 means no quantization
- PACKED_PER_VAL == 8 means 4-bit quantization (8 packed quantized values inside one int32)
For the quantized case K/V should be int32 tensors.
Quantization can be row-wise (when N_GROUPS = 1) or group-wise with N_GROUPS = 2, 4, or 8.
Quantization coefficients are stored at the beginning of the row along the last dimension of K/V
So K[B, H, M, :] has a form
[ quant_coef0, quant_coef1, ...|
group0_quant_value0, group0_quant_value1,... |
group1_quant_value0, group1_quant_value1,...]
where each quant_coef is an int32 which should be interpreted as 2 packed float16: scale and offset.
Note: this kernel needs to be processed by xformers.triton.vararg_kernel.unroll_varargs
before compilation. That will unroll variables marked with "VAR_ARGS_ARRAY" into lists.
See how FwOp.apply does it below.
"""
tl.static_assert(
(PACKED_PER_VAL == 1 and tl.constexpr(K.dtype.element_ty != tl.int32))
or (PACKED_PER_VAL == 8 and tl.constexpr(K.dtype.element_ty == tl.int32)),
f"Only 4-bit quantization is supported, K/V should have dtype int32 in "
f"the quantized case: {PACKED_PER_VAL=} {tl.constexpr(K.dtype)=} {tl.constexpr(K.dtype.element_ty)=}",
)
tl.static_assert(
(((N_GROUPS == 1 or N_GROUPS == 2) or N_GROUPS == 4) or N_GROUPS == 8),
"Number of quantization groups can be 1 (row-wise quantization), 2, 4, or 8.",
)
QUANTIZED: tl.constexpr = PACKED_PER_VAL > 1
PACKED_D_PER_GROUP: tl.constexpr = BLOCK_DMODEL // PACKED_PER_VAL // N_GROUPS
D_PER_GROUP: tl.constexpr = BLOCK_DMODEL // N_GROUPS
start_m = tl.program_id(0)
off_zhg = tl.program_id(1)
off_z = off_zhg // (H * G)
off_h = (off_zhg // G) % H
off_g = off_zhg % G
splitk_idx = tl.program_id(2)
lo = splitk_idx * BLOCK_N_PER_SPLIT
if USE_SEQ_LEN:
kv_len = tl.load(Seq_len + off_z)
else:
kv_len = N_CTX_K
hi = tl.minimum((splitk_idx + 1) * BLOCK_N_PER_SPLIT, kv_len)
Q_block_ptr = tl.make_block_ptr(
base=Q + off_h * stride_qh + off_z * stride_qz + off_g * stride_qg,
shape=(N_CTX_Q, D_PER_GROUP),
strides=(stride_qm, stride_qk),
offsets=(start_m * BLOCK_M, 0),
block_shape=(BLOCK_M, D_PER_GROUP),
order=(1, 0),
)
k_base = K + off_h * stride_kh + off_z * stride_kz + off_g * stride_kg
# Additional shift by 1 along the last dimension in the quantized case, since
# the first element along that dim contains packed quantization coefficients.
K_block_ptr = tl.make_block_ptr(
base=k_base + stride_kk * QUANTIZED * N_GROUPS,
shape=(PACKED_D_PER_GROUP, hi),
strides=(stride_kk, stride_kn),
offsets=(0, lo),
block_shape=(PACKED_D_PER_GROUP, BLOCK_N),
order=(0, 1),
)
v_base = V + off_h * stride_vh + off_z * stride_vz + off_g * stride_vg
V_block_ptr = tl.make_block_ptr(
base=v_base + stride_vk * QUANTIZED * N_GROUPS,
shape=(hi, PACKED_D_PER_GROUP),
strides=(stride_vn, stride_vk),
offsets=(lo, 0),
block_shape=(BLOCK_N, PACKED_D_PER_GROUP),
order=(1, 0),
)
if QUANTIZED:
# Pointers to quantization coefficients. Even those they are 1D,
# we have to use block pointers, since usual pointers
# don't support boundary checks
K_scale_shift_block_ptr = tl.make_block_ptr(
base=k_base,
shape=(1, hi),
strides=(stride_kk, stride_kn),
offsets=(0, lo),
block_shape=(1, BLOCK_N),
order=(0, 1),
)
V_scale_shift_block_ptr = tl.make_block_ptr(
base=v_base,
shape=(hi, 1),
strides=(stride_vn, stride_vk),
offsets=(lo, 0),
block_shape=(BLOCK_N, 1),
order=(1, 0),
)
else:
K_scale_shift_block_ptr = None
V_scale_shift_block_ptr = None
# initialize pointer to m and l
m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
# Before compilation, this kernel will be processed by xformers.triton.vararg_kernel.unroll_varargs.
# That turns tensors annotated as the one below into lists of tensors of length N_GROUPS.
# This is a solution for Triton native lack of support for lists of tensors.
acc: "VAR_ARGS_ARRAY" # noqa: F821
for i in range(len(acc)): # noqa: F821
acc[i] = tl.zeros([BLOCK_M, D_PER_GROUP], dtype=tl.float32) # noqa: F821
# scale sm_scale by log_2(e) and use
# 2^x instead of exp in the loop because CSE and LICM
# don't work as expected with `exp` in the loop
qk_scale = sm_scale * 1.44269504
# load q: it will stay in SRAM throughout
q: "VAR_ARGS_ARRAY" # noqa: F821
for i in range(len(acc)): # noqa: F821
q[i] = tl.load( # noqa: F821
tl.advance(Q_block_ptr, (0, i * D_PER_GROUP)), boundary_check=(0,)
)
# loop over k, v and update accumulator
for start_n in range(lo, hi, BLOCK_N):
k: "VAR_ARGS_ARRAY" # noqa: F821
v: "VAR_ARGS_ARRAY" # noqa: F821
for i in range(len(acc)): # noqa: F821
k[i], v[i] = load_dequantize_k_v_group( # noqa: F821
K_block_ptr,
V_block_ptr,
K_scale_shift_block_ptr,
V_scale_shift_block_ptr,
BOUNDS_CHECKS_N,
PACKED_PER_VAL,
PACKED_D_PER_GROUP,
Q.dtype.element_ty,
i,
)
# -- compute qk ---
qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
for i in range(len(acc)): # noqa: F821
qk += tl.dot(q[i], k[i]) # noqa: F821
qk *= qk_scale
# TODO: This is slow, and only needed at the last iteration.
# Maybe we can unroll the last iteration instead?
if BOUNDS_CHECKS_N:
qk = tl.where(tl.arange(0, BLOCK_N) < hi - start_n, qk, float("-inf"))
# -- compute scaling constant ---
m_i_new = tl.maximum(m_i, tl.max(qk, 1))
alpha = tl.math.exp2(m_i - m_i_new)
p = tl.math.exp2(qk - m_i_new[:, None])
# -- update m_i and l_i --
l_i = l_i * alpha + tl.sum(p, 1)
m_i = m_i_new
p = p.to(Q.dtype.element_ty)
# -- scale and update acc --
for i in range(len(acc)): # noqa: F821
acc[i] *= alpha[:, None] # noqa: F821
acc[i] += tl.dot(p, v[i]) # noqa: F821
# update pointers
K_block_ptr = tl.advance(K_block_ptr, (0, BLOCK_N))
V_block_ptr = tl.advance(V_block_ptr, (BLOCK_N, 0))
if PACKED_PER_VAL > 1:
K_scale_shift_block_ptr = tl.advance(
K_scale_shift_block_ptr, (0, BLOCK_N)
)
V_scale_shift_block_ptr = tl.advance(
V_scale_shift_block_ptr, (BLOCK_N, 0)
)
# write back O
O_block_ptr = tl.make_block_ptr(
base=Out_splitK + off_zhg * stride_osk_zhg + splitk_idx * stride_osk_s,
shape=(N_CTX_Q, D_PER_GROUP),
strides=(stride_osk_m, 1),
offsets=(start_m * BLOCK_M, 0),
block_shape=(BLOCK_M, D_PER_GROUP),
order=(1, 0),
)
for i in range(len(acc)): # noqa: F821
tl.store(
tl.advance(O_block_ptr, (0, i * D_PER_GROUP)),
acc[i], # noqa: F821
boundary_check=(0,),
)
# Write metadata for split-K reduction
Metadata_ptr = (
Metadata
+ off_zhg * stride_mzhg
+ splitk_idx * stride_ms
+ start_m * BLOCK_M
+ tl.arange(0, BLOCK_M)
)
tl.store(Metadata_ptr, m_i)
tl.store(Metadata_ptr + stride_m2, l_i)
@triton.jit
def load_dequantize_k_v_group(
K_block_ptr,
V_block_ptr,
K_scale_shift_block_ptr,
V_scale_shift_block_ptr,
BOUNDS_CHECKS_N: tl.constexpr,
PACKED_PER_VAL: tl.constexpr,
PACKED_D_PER_GROUP: tl.constexpr,
dtype: tl.constexpr,
group_id: tl.constexpr,
):
"""Load K/V for a given block. In case of int4-quantized K/V, dequantize them after loading.
If quantization is group-wise, use group_id to advance the pointers to the current group.
"""
# Advance to the current quantization group
K_block_ptr = tl.advance(K_block_ptr, (PACKED_D_PER_GROUP * group_id, 0))
V_block_ptr = tl.advance(V_block_ptr, (0, PACKED_D_PER_GROUP * group_id))
# -- load k, v --
k = tl.load(K_block_ptr, boundary_check=(1,) if BOUNDS_CHECKS_N else ())
v = tl.load(V_block_ptr, boundary_check=(0,) if BOUNDS_CHECKS_N else ())
if PACKED_PER_VAL > 1:
# K/V are quantized, load quantization coefficients and dequantize
K_scale_shift_block_ptr = tl.advance(K_scale_shift_block_ptr, (group_id, 0))
V_scale_shift_block_ptr = tl.advance(V_scale_shift_block_ptr, (0, group_id))
k_scale_shift = tl.load(
K_scale_shift_block_ptr, boundary_check=(1,) if BOUNDS_CHECKS_N else ()
)
v_scale_shift = tl.load(
V_scale_shift_block_ptr, boundary_check=(0,) if BOUNDS_CHECKS_N else ()
)
k_scale, k_shift = cast_uint32_to_half2(k_scale_shift)
v_scale, v_shift = cast_uint32_to_half2(v_scale_shift)
v = dequantize(v, v_scale, v_shift, PACKED_PER_VAL).to(dtype)
k_t = dequantize(
tl.trans(k),
tl.trans(k_scale),
tl.trans(k_shift),
PACKED_PER_VAL,
).to(dtype)
k = tl.trans(k_t)
return k, v
@triton.jit
def cast_uint32_to_half2(scale_shift):
"""Extract two float16 packed into one int32"""
scale = scale_shift & 0xFFFF
shift = scale_shift >> 16
scale = scale.to(tl.uint16).to(tl.float16, bitcast=True)
shift = shift.to(tl.uint16).to(tl.float16, bitcast=True)
return scale, shift
@triton.jit
def dequantize(
x_,
scale,
shift,
PACKED_PER_VAL: tl.constexpr = 8,
):
"""PACKED_PER_VAL is the number of values packed into each element x_.
For example, for int4 quantization and x_ of type int32, PACKED_PER_VAL is 8.
"""
# Axis along which offsets are applied matters here
# It would be natural to have offsets in shape (BLOCK_N, D // PACKED_PER_VAL, PACKED_PER_VAL)
# and expand K/V to that shape before applying offsets
# However, Triton for some reason considers dim=1 as contiguous when doing tl.view below, and not dim=2
# Note that tl.view doesn't guarantee the order of elements in the result - thus the code below depends
# on the implementation details which might change in the future.
# Ideally we would like to use tl.reshape, but it's not implemented yet.
# See https://github.com/openai/triton/blob/9055af1a5dadc576804b38dd77ee91dc42af0bf7/python/triton/language/semantic.py#L541 # noqa: E501
# x_ : (BLOCK_N, D // PACKED_PER_VAL)
# scale: (BLOCK_N, 1)
# offsets: (PACKED_PER_VAL,)
BLOCK_N: tl.constexpr = x_.shape[0]
BLOCK_DMODEL_PACKED: tl.constexpr = x_.shape[1]
offsets = tl.arange(0, PACKED_PER_VAL) * 4
quant_offset = (
x_[:, None, :] >> offsets[None, :, None]
) # (BLOCK_N, PACKED_PER_VAL, D // PACKED_PER_VAL)
quant_offset = tl.view(
quant_offset, (BLOCK_N, BLOCK_DMODEL_PACKED * PACKED_PER_VAL)
)
# Trick - instead of converting int4 to float16 we view it as float16
# and then multiply by 32768 * 512 == 2**24
quant_offset = (quant_offset & 0xF).to(tl.uint16).to(tl.float16, bitcast=True)
quant_offset = (quant_offset * 32768.0).to(tl.float16)
scale_512 = scale * 512
dequant = quant_offset * scale_512 + shift
return dequant
@triton.jit
def _splitK_reduce(
Out_splitK, # [B, H, split_k, Mq, K]
Metadata, # [B, H, 2, split_k, M_ceil] contains [mi, li]
Out, # [B, H, M, K]
LSE, # [B, H, M]
split_k,
stride_osk_zhg,
stride_osk_s,
stride_osk_m,
stride_osk_k,
stride_mzhg,
stride_m2,
stride_ms,
stride_mm,
stride_oz,
stride_oh,
stride_og,
stride_om,
stride_ok,
stride_lse_zhg,
stride_lse_m,
BLOCK_SIZE: tl.constexpr,
H: tl.constexpr,
G: tl.constexpr,
):
off_zhg = tl.program_id(0)
off_z = off_zhg // (H * G)
off_h = (off_zhg // G) % H
off_g = off_zhg % G
off_m = tl.program_id(1)
Out_splitK_ptr = (
Out_splitK
+ stride_osk_zhg * off_zhg
+ stride_osk_m * off_m
+ tl.arange(0, BLOCK_SIZE)
)
Metadata_ptr = Metadata + stride_mzhg * off_zhg + off_m
m = tl.load(Metadata_ptr)
l_sum = tl.load(Metadata_ptr + stride_m2)
acc = tl.load(Out_splitK_ptr)
for split_k_idx in range(1, split_k):
Metadata_ptr = Metadata_ptr + stride_ms
Out_splitK_ptr = Out_splitK_ptr + stride_osk_s
m_k = tl.load(Metadata_ptr)
l_k = tl.load(Metadata_ptr + stride_m2)
acc_k = tl.load(Out_splitK_ptr)
m_new = tl.maximum(m, m_k)
if m_k < m:
# Scale incoming values
alpha = tl.math.exp2(m_k - m_new)
acc_k = acc_k * alpha
l_k = l_k * alpha
else:
# Scale our values
alpha = tl.math.exp2(m - m_new)
acc = acc * alpha
l_sum = l_sum * alpha
m = m_new
l_sum = l_sum + l_k
acc = acc + acc_k
acc = acc / l_sum
Out_ptr = (
Out
+ stride_oz * off_z
+ stride_oh * off_h
+ stride_og * off_g
+ stride_om * off_m
+ tl.arange(0, BLOCK_SIZE)
)
tl.store(Out_ptr, acc)
l_ptrs = LSE + off_zhg * stride_lse_zhg + off_m
tl.store(l_ptrs, (m + tl.math.log2(l_sum)) / 1.44269504)
else:
_fwd_kernel_splitK = None
_splitK_reduce = None
@register_operator
class FwOp(AttentionFwOpBase):
"""Flash-Attention with Split-K. Supports fused int-4 K/V quantization.
Quantized path will be taken if input K/V have type int32.
Quantization can be row-wise or group-wise (when cls.NUM_GROUPS > 1) along
the last dimension of K and V. Currently 1, 2, 4, or 8 groups per row are supported.
Quantization coefficients (scale and shift) are represented as two
float16 constants per group, packed into int32. Quantization coefficients of
all groups are placed at the beginning of the row. So, if unquantized K/V have head
dimension D, the quantized versions have head dimension D // 8 + NUM_GROUPS
and dtype int32.
Pseudocode for dequantizing one row can look like:
group_size = D // 8
for i in range(NUM_GROUPS):
group_start = NUM_GROUPS + i * group_size
group_quant = K[..., group_start: group_start + group_size]
scale, shift = unpack_int32_into_float16x2(group_quant[0])
group_dequant = group_quant[..., 1:] * scale + shift
...
"""
OPERATOR = _fwd_kernel_splitK
SUPPORTED_DEVICES = {"cuda"}
CUDA_MINIMUM_COMPUTE_CAPABILITY = (8, 0)
SUPPORTED_DTYPES = {
torch.half,
torch.bfloat16,
} # Those are dtypes of Q. In the quantized case K/V has dtype int32
SUPPORTED_MAX_K = 128
SUPPORTED_ATTN_BIAS_TYPES: Set[Any] = {
type(None),
BlockDiagonalCausalWithOffsetPaddedKeysMask,
}
SUPPORTS_DROPOUT = False
SUPPORTS_CUSTOM_SCALE = True
SUPPORTS_BMGHK = True
NAME = "triton_splitKF"
SPLIT_K: Optional[int] = None
BLOCK_M = 16
BLOCK_N = 64
NUM_GROUPS = 1 # Default quantization is row-wise
@classmethod
def shape_not_supported_reasons(
cls, Mq: int, Mkv: int, K: int, Kv: int
) -> List[str]:
reasons = super().shape_not_supported_reasons(Mq, Mkv, K, Kv)
if K not in {16, 32, 64, 128}:
reasons.append(f"Embed dim {K} not supported")
return reasons
@classmethod
def not_supported_reasons(cls, d: Inputs) -> List[str]:
reasons = super(FwOp, cls).not_supported_reasons(d)
check_lastdim_alignment_stride1(reasons, "query", d.query, 8)
if d.key.dtype != torch.int32:
check_lastdim_alignment_stride1(reasons, "key", d.key, 8)
check_lastdim_alignment_stride1(reasons, "value", d.value, 8)
if cls.OPERATOR is None:
reasons.append("triton is not available")
if d.device.type == "cuda":
# Has only been tested on 8.0 / 9.0.
if torch.cuda.get_device_capability(d.device) < (8, 0):
reasons.append(
"requires GPU with sm80 minimum compute capacity, e.g., A100/H100/L4"
)
q_len = d.query.shape[1]
if isinstance(d.attn_bias, BlockDiagonalCausalWithOffsetPaddedKeysMask):
seqinfo = d.attn_bias.q_seqinfo
if q_len != seqinfo.seqstart_py[-1]:
reasons.append(
f"Expected total {seqinfo.seqstart_py[-1]} queries not {q_len}"
)
q_len = seqinfo.min_seqlen
if q_len != seqinfo.max_seqlen:
reasons.append(
"Variable query len is not supported in the presence of causal mask."
)
if d.key.ndim in [4, 5] and d.key.shape[-2] != 1:
if d.key.stride(-2) == 0 and d.value.stride(-2) == 0 and q_len > 1:
reasons.append("multiquery is only supported with query seqlen=1")
if d.attn_bias is not None and q_len > 1:
reasons.append(
"query with seqlen > 1 is not supported in the presence of causal mask"
)
return reasons
@classmethod
def get_split_k(cls, B: int, H: int, Mk: int) -> int:
"""Heuristic for the number of splits"""
bh = B * H
split_k = max(Mk, 1024) // bh
max_chunk_size = 64 if Mk <= 512 and bh <= 64 else 128
while split_k > 0 and Mk / split_k < max_chunk_size:
split_k = split_k // 2
split_k = min(split_k, 64)
split_k = max(split_k, 1)
return split_k
@classmethod
def apply(
cls, inp: Inputs, needs_gradient: bool
) -> Tuple[torch.Tensor, Optional[Context]]:
attn_bias = inp.attn_bias
seq_len = None
q, k, v = inp.get_qkv_in_bmghk()
if attn_bias is not None:
assert isinstance(attn_bias, BlockDiagonalCausalWithOffsetPaddedKeysMask)
# TODO: do we really need to do this cast? seems fishy but
# I just copied it from the decoder.py
attn_bias.k_seqinfo.to(inp.query.device)
attn_bias.q_seqinfo.to(inp.query.device)
seq_len = attn_bias.k_seqinfo.seqlen
B = len(seq_len)
G, H, Kq = q.shape[-3:]
Kkv = v.shape[-1]
# assume kv has been padded
q = q.reshape(B, -1, G, H, Kq)
k = k.reshape(B, -1, G, H, Kkv)
v = v.reshape(B, -1, G, H, Kkv)
# Transpose in the case of MQA/GQA
mqa_swap_seqlen_head = False
if k.shape[3] > 1 and k.stride(3) == 0 and v.stride(3) == 0:
mqa_swap_seqlen_head = True
assert q.shape[1] == 1
q = q.transpose(1, 3)
k = k[:, :, :, :1]
v = v[:, :, :, :1]
if k.dtype == torch.int32:
# Quantized K/V
PACKED_PER_VAL = 8
Lk = (k.shape[-1] - cls.NUM_GROUPS) * 8
else:
Lk = k.shape[-1]
PACKED_PER_VAL = 1
B, Mk, G, H, Kkv = k.shape
B, M, G, H, Kq = q.shape
assert Lk == Kq, f"Keys have head dim {Lk} but queries have head dim {Kq}"
BLOCK_M = cls.BLOCK_M
BLOCK_N = cls.BLOCK_N
if cls.SPLIT_K is not None:
split_k = cls.SPLIT_K
else:
# Use heuristics
split_k = cls.get_split_k(B, H, Mk)
M_ceil = (M + BLOCK_M - 1) // BLOCK_M * BLOCK_M
o_splitk = torch.empty(
[B * G * H, split_k, M_ceil, Kq], dtype=torch.float32, device=q.device
)
metadata = torch.empty(
[B * G * H, 2, split_k, M_ceil], dtype=torch.float32, device=q.device
)
lse = torch.empty((B * G * H, M), device=q.device, dtype=torch.float32)
grid = (triton.cdiv(M, BLOCK_M), B * G * H, split_k)
num_warps = 2
split_size = (Mk + split_k - 1) // split_k
use_seq_len = seq_len is not None
_fwd_kernel_splitK_unrolled = unroll_varargs(
_fwd_kernel_splitK, N=cls.NUM_GROUPS if PACKED_PER_VAL > 1 else 1
)
_fwd_kernel_splitK_unrolled[grid](
Q=q,
K=k,
V=v,
sm_scale=inp.scale_float,
Out_splitK=o_splitk,
Metadata=metadata,
Seq_len=seq_len,
**_strides(q, "qz", "qm", "qg", "qh", "qk"),
**_strides(k, "kz", "kn", "kg", "kh", "kk"),
**_strides(v, "vz", "vn", "vg", "vh", "vk"),
**_strides(o_splitk, "osk_zhg", "osk_s", "osk_m", "osk_k"),
**_strides(metadata, "mzhg", "m2", "ms", "mm"),
Z=B,
H=H,
G=G,
N_CTX_Q=M,
N_CTX_K=Mk,
BLOCK_N_PER_SPLIT=split_size,
BLOCK_M=BLOCK_M,
BLOCK_N=BLOCK_N,
BLOCK_DMODEL=Lk,
BOUNDS_CHECKS_N=(split_size % BLOCK_N) > 0 or use_seq_len,
USE_SEQ_LEN=use_seq_len,
num_warps=num_warps,
num_stages=1,
PACKED_PER_VAL=PACKED_PER_VAL,
N_GROUPS=cls.NUM_GROUPS if PACKED_PER_VAL > 1 else 1,
)
if mqa_swap_seqlen_head:
out = torch.empty(
(B, H, G, M, Kq), device=q.device, dtype=q.dtype
).transpose(1, 3)
else:
out = torch.empty((B, M, G, H, Kq), device=q.device, dtype=q.dtype)
# Merge together
grid = (B * G * H, M, 1)
_splitK_reduce[grid](
o_splitk,
metadata,
out,
lse,
split_k=split_k,
**_strides(o_splitk, "osk_zhg", "osk_s", "osk_m", "osk_k"),
**_strides(metadata, "mzhg", "m2", "ms", "mm"),
**_strides(out, "oz", "om", "og", "oh", "ok"),
**_strides(lse, "lse_zhg", "lse_m"),
BLOCK_SIZE=out.shape[-1],
G=G,
H=H,
# TODO: Tune num_warps
)
lse = lse.reshape([B, G, H, M])
if mqa_swap_seqlen_head:
# H/M dimensions have been swapped
out = out.transpose(1, 3)
lse = lse.transpose(2, 3)
if inp.query.ndim == 4:
# BMGHK -> BMHK
assert G == 1
out = out[:, :, 0]
lse = lse[:, 0]
return out, Context(out=out, lse=lse)
class FwOp_S1(FwOp):
SPLIT_K = 1
NAME = "triton_splitK1"
class FwOp_S2(FwOp):
SPLIT_K = 2
NAME = "triton_splitK2"
class FwOp_S4(FwOp):
SPLIT_K = 4
NAME = "triton_splitK4"
class FwOp_S8(FwOp):
SPLIT_K = 8
NAME = "triton_splitK8"
class FwOp_S16(FwOp):
SPLIT_K = 16
NAME = "triton_splitK16"
class FwOp_S32(FwOp):
SPLIT_K = 32
NAME = "triton_splitK32"
class FwOp_S64(FwOp):
SPLIT_K = 64
NAME = "triton_splitK64"
class FwOp_S128(FwOp):
SPLIT_K = 128
NAME = "triton_splitK128"