First commit
This commit is contained in:
474
pkgs/xformers/ops/fmha/__init__.py
Normal file
474
pkgs/xformers/ops/fmha/__init__.py
Normal file
@@ -0,0 +1,474 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the BSD license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
from typing import Any, Optional, Sequence, Tuple, Type, Union
|
||||
import os
|
||||
import torch
|
||||
|
||||
from . import cutlass, decoder, flash, small_k, triton, triton_splitk
|
||||
from .attn_bias import AttentionBias, BlockDiagonalMask, LowerTriangularMask
|
||||
from .common import (
|
||||
AttentionBwOpBase,
|
||||
AttentionFwOpBase,
|
||||
AttentionOp,
|
||||
AttentionOpBase,
|
||||
AttentionOpDispatch,
|
||||
Context,
|
||||
Gradients,
|
||||
Inputs,
|
||||
bmk2bmhk,
|
||||
)
|
||||
from .dispatch import _dispatch_bw, _dispatch_fw, _ensure_op_supports_or_raise
|
||||
|
||||
MemoryEfficientAttentionCutlassOp = (cutlass.FwOp, cutlass.BwOp)
|
||||
MemoryEfficientAttentionCutlassFwdFlashBwOp = (cutlass.FwOp, flash.BwOp)
|
||||
MemoryEfficientAttentionDecoderOp = (decoder.FwOp, cutlass.BwOp)
|
||||
MemoryEfficientAttentionTritonFwdFlashBwOp = (triton.FwOp, flash.BwOp)
|
||||
MemoryEfficientAttentionFlashAttentionOp = (flash.FwOp, flash.BwOp)
|
||||
MemoryEfficientAttentionOp = (small_k.FwOp, small_k.BwOp)
|
||||
TritonFlashAttentionOp = (triton.FwOp, triton.BwOp)
|
||||
|
||||
|
||||
class _fMHA(torch.autograd.Function):
|
||||
@staticmethod
|
||||
# type: ignore
|
||||
def forward(ctx, op: AttentionOp, *args: Any) -> Any:
|
||||
inp = Inputs(*args)
|
||||
op_fw = op[0] if op is not None else None
|
||||
op_bw = op[1] if op is not None else None
|
||||
|
||||
out, op_ctx = _memory_efficient_attention_forward_requires_grad(
|
||||
inp=inp, op=op_fw
|
||||
)
|
||||
|
||||
# Saving attn_bias is a bit complicated, as the
|
||||
# torch part should go in `save_for_backward`
|
||||
if isinstance(inp.attn_bias, torch.Tensor):
|
||||
attn_bias_tensor = inp.attn_bias
|
||||
attn_bias_ctx = None
|
||||
else:
|
||||
attn_bias_tensor = None
|
||||
attn_bias_ctx = inp.attn_bias
|
||||
|
||||
ctx.save_for_backward(
|
||||
inp.query,
|
||||
inp.key,
|
||||
inp.value,
|
||||
op_ctx.q_padded,
|
||||
op_ctx.k_padded,
|
||||
op_ctx.v_padded,
|
||||
op_ctx.o_padded,
|
||||
op_ctx.out,
|
||||
op_ctx.lse,
|
||||
)
|
||||
ctx.rng_state = op_ctx.rng_state
|
||||
ctx.attn_bias_tensor = attn_bias_tensor
|
||||
if op_ctx.op_bw is not None:
|
||||
if op_bw is not None and op_bw is not op_ctx.op_bw:
|
||||
raise ValueError(
|
||||
f"Specified op_bw={op_bw.NAME}, but forward op "
|
||||
f"can only run with op_bw={op_ctx.op_bw.NAME}. Please set op_bw=None."
|
||||
)
|
||||
op_bw = op_ctx.op_bw
|
||||
ctx.op_fw = op_fw
|
||||
ctx.op_bw = op_bw
|
||||
ctx.p = inp.p
|
||||
ctx.use_alibi = inp.use_alibi
|
||||
ctx.alibi_mode = inp.alibi_mode
|
||||
ctx.imp_mode = inp.imp_mode
|
||||
|
||||
ctx.scale = inp.scale
|
||||
ctx.attn_bias_ctx = attn_bias_ctx
|
||||
ctx.n_args = len(args)
|
||||
return out
|
||||
|
||||
@staticmethod
|
||||
def deserialize_bias(
|
||||
attn_bias_ctx, attn_bias_tensor: Optional[torch.Tensor]
|
||||
) -> Any:
|
||||
if attn_bias_tensor is None:
|
||||
return attn_bias_ctx
|
||||
return attn_bias_tensor
|
||||
|
||||
@classmethod
|
||||
@torch.autograd.function.once_differentiable
|
||||
def backward(cls, ctx, grad):
|
||||
# Re-create context
|
||||
query, key, value, q_padded, k_padded, v_padded, o_padded, out, lse = ctx.saved_tensors
|
||||
attn_bias_tensor = ctx.attn_bias_tensor
|
||||
rng_state = ctx.rng_state
|
||||
inp = Inputs(
|
||||
query=query,
|
||||
key=key,
|
||||
value=value,
|
||||
attn_bias=cls.deserialize_bias(ctx.attn_bias_ctx, attn_bias_tensor),
|
||||
p=ctx.p,
|
||||
scale=ctx.scale,
|
||||
use_alibi=ctx.use_alibi,
|
||||
alibi_mode=ctx.alibi_mode,
|
||||
imp_mode = ctx.imp_mode,
|
||||
)
|
||||
op_ctx = Context(
|
||||
lse=lse,
|
||||
out=out,
|
||||
q_padded=q_padded,
|
||||
k_padded=k_padded,
|
||||
v_padded=v_padded,
|
||||
o_padded=o_padded,
|
||||
rng_state=rng_state,
|
||||
)
|
||||
grads = _memory_efficient_attention_backward(
|
||||
ctx=op_ctx, inp=inp, grad=grad, op=ctx.op_bw
|
||||
)
|
||||
return (None, grads.dq, grads.dk, grads.dv, grads.db) + (None,) * (
|
||||
ctx.n_args - 2
|
||||
)
|
||||
|
||||
|
||||
def memory_efficient_attention(
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
attn_bias: Optional[Union[torch.Tensor, AttentionBias]] = None,
|
||||
p: float = 0.0,
|
||||
scale: Optional[float] = None,
|
||||
use_alibi: bool = False,
|
||||
alibi_mode: int = 1,
|
||||
imp_mode: int = 0,
|
||||
*,
|
||||
op: Optional[AttentionOp] = None,
|
||||
) -> torch.Tensor:
|
||||
"""Implements the memory-efficient attention mechanism following
|
||||
`"Self-Attention Does Not Need O(n^2) Memory" <http://arxiv.org/abs/2112.05682>`_.
|
||||
|
||||
:Inputs shape:
|
||||
|
||||
- Input tensors must be in format ``[B, M, H, K]``, where B is the batch size, M \
|
||||
the sequence length, H the number of heads, and K the embeding size per head
|
||||
|
||||
- If inputs have dimension 3, it is assumed that the dimensions are ``[B, M, K]`` and ``H=1``
|
||||
|
||||
- Inputs can also be of dimension 5 with GQA - see note below
|
||||
|
||||
- Inputs can be non-contiguous - we only require the last dimension's stride to be 1
|
||||
|
||||
|
||||
:Equivalent pytorch code:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
scale = 1 / query.shape[-1] ** 0.5
|
||||
query = query * scale
|
||||
attn = query @ key.transpose(-2, -1)
|
||||
if attn_bias is not None:
|
||||
attn = attn + attn_bias
|
||||
attn = attn.softmax(-1)
|
||||
attn = F.dropout(attn, p)
|
||||
return attn @ value
|
||||
|
||||
:Examples:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
import xformers.ops as xops
|
||||
|
||||
# Compute regular attention
|
||||
y = xops.memory_efficient_attention(q, k, v)
|
||||
|
||||
# With a dropout of 0.2
|
||||
y = xops.memory_efficient_attention(q, k, v, p=0.2)
|
||||
|
||||
# Causal attention
|
||||
y = xops.memory_efficient_attention(
|
||||
q, k, v,
|
||||
attn_bias=xops.LowerTriangularMask()
|
||||
)
|
||||
|
||||
:Supported hardware:
|
||||
|
||||
NVIDIA GPUs with compute capability above 6.0 (P100+), datatype ``f16``, ``bf16`` and ``f32``.
|
||||
|
||||
:EXPERIMENTAL: Using with Multi Query Attention (MQA) and Grouped Query Attention (GQA):
|
||||
|
||||
MQA/GQA is an experimental feature supported only for the forward pass.
|
||||
If you have 16 heads in query, and 2 in key/value, you can provide 5-dim tensors
|
||||
in the ``[B, M, G, H, K]`` format, where ``G`` is the number of head groups (here 2), and
|
||||
``H`` is the number of heads per group (8 in the example).
|
||||
|
||||
Please note that xFormers will not automatically broadcast the inputs, so you will need
|
||||
to broadcast it manually before calling `memory_efficient_attention`.
|
||||
|
||||
:GQA/MQA example:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
import torch
|
||||
import xformers.ops as xops
|
||||
|
||||
B, M, K = 3, 32, 128
|
||||
kwargs = dict(device="cuda", dtype=torch.float16)
|
||||
q = torch.randn([B, M, 8, K], **kwargs)
|
||||
k = torch.randn([B, M, 2, K], **kwargs)
|
||||
v = torch.randn([B, M, 2, K], **kwargs)
|
||||
out_gqa = xops.memory_efficient_attention(
|
||||
q.reshape([B, M, 2, 4, K]),
|
||||
k.reshape([B, M, 2, 1, K]).expand([B, M, 2, 4, K]),
|
||||
v.reshape([B, M, 2, 1, K]).expand([B, M, 2, 4, K]),
|
||||
)
|
||||
|
||||
Raises:
|
||||
NotImplementedError: if there is no operator available to compute the MHA
|
||||
ValueError: if inputs are invalid
|
||||
|
||||
:parameter query: Tensor of shape ``[B, Mq, H, K]``
|
||||
:parameter key: Tensor of shape ``[B, Mkv, H, K]``
|
||||
:parameter value: Tensor of shape ``[B, Mkv, H, Kv]``
|
||||
:parameter attn_bias: Bias to apply to the attention matrix - defaults to no masking. \
|
||||
For common biases implemented efficiently in xFormers, see :attr:`xformers.ops.fmha.attn_bias.AttentionBias`. \
|
||||
This can also be a :attr:`torch.Tensor` for an arbitrary mask (slower).
|
||||
:parameter p: Dropout probability. Disabled if set to ``0.0``
|
||||
:parameter scale: Scaling factor for ``Q @ K.transpose()``. If set to ``None``, the default \
|
||||
scale (q.shape[-1]**-0.5) will be used.
|
||||
:parameter op: The operators to use - see :attr:`xformers.ops.AttentionOpBase`. \
|
||||
If set to ``None`` (recommended), xFormers \
|
||||
will dispatch to the best available operator, depending on the inputs \
|
||||
and options.
|
||||
:return: multi-head attention Tensor with shape ``[B, Mq, H, Kv]``
|
||||
"""
|
||||
return _memory_efficient_attention(
|
||||
Inputs(
|
||||
query=query, key=key, value=value, p=p, attn_bias=attn_bias, scale=scale, use_alibi=use_alibi, alibi_mode=alibi_mode, imp_mode=imp_mode
|
||||
),
|
||||
op=op,
|
||||
)
|
||||
|
||||
|
||||
def memory_efficient_attention_forward(
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
attn_bias: Optional[Union[torch.Tensor, AttentionBias]] = None,
|
||||
p: float = 0.0,
|
||||
scale: Optional[float] = None,
|
||||
use_alibi: bool = False,
|
||||
alibi_mode: int = 1,
|
||||
imp_mode: int = 0,
|
||||
*,
|
||||
op: Optional[Type[AttentionFwOpBase]] = None,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Calculates the forward pass of :attr:`xformers.ops.memory_efficient_attention`.
|
||||
"""
|
||||
return _memory_efficient_attention_forward(
|
||||
Inputs(
|
||||
query=query, key=key, value=value, p=p, attn_bias=attn_bias, scale=scale, use_alibi=use_alibi, alibi_mode=alibi_mode, imp_mode=imp_mode
|
||||
),
|
||||
op=op,
|
||||
)
|
||||
|
||||
|
||||
def memory_efficient_attention_forward_requires_grad(
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
attn_bias: Optional[Union[torch.Tensor, AttentionBias]] = None,
|
||||
p: float = 0.0,
|
||||
scale: Optional[float] = None,
|
||||
use_alibi: bool = False,
|
||||
alibi_mode: int = 1,
|
||||
imp_mode: int = 0,
|
||||
*,
|
||||
op: Optional[Type[AttentionFwOpBase]] = None,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Returns a tuple (output, lse), where `lse` can be used to compute the backward pass later.
|
||||
See :attr:`xformers.ops.memory_efficient_attention` for an explanation of the arguments
|
||||
See :attr:`xformers.ops.memory_efficient_attention_backward` for running the backward pass
|
||||
"""
|
||||
if p != 0.0:
|
||||
raise NotImplementedError(
|
||||
"dropout is not supported on the non-autograd API."
|
||||
" If you want to use dropout, please call `memory_efficient_attention` directly"
|
||||
)
|
||||
out, ctx = _memory_efficient_attention_forward_requires_grad(
|
||||
Inputs(
|
||||
query=query, key=key, value=value, p=p, attn_bias=attn_bias, scale=scale, use_alibi=use_alibi, alibi_mode=alibi_mode, imp_mode=imp_mode
|
||||
),
|
||||
op=op,
|
||||
)
|
||||
return out, ctx.lse
|
||||
|
||||
|
||||
def memory_efficient_attention_backward(
|
||||
grad: torch.Tensor,
|
||||
output: torch.Tensor,
|
||||
lse: torch.Tensor,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
attn_bias: Optional[Union[torch.Tensor, AttentionBias]] = None,
|
||||
p: float = 0.0,
|
||||
scale: Optional[float] = None,
|
||||
use_alibi: bool = False,
|
||||
alibi_mode: int = 1,
|
||||
imp_mode: int = 0,
|
||||
*,
|
||||
op: Optional[Type[AttentionBwOpBase]] = None,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Computes the gradient of the attention.
|
||||
Returns a tuple (dq, dk, dv)
|
||||
See :attr:`xformers.ops.memory_efficient_attention` for an explanation of the arguments.
|
||||
`lse` is the tensor returned by :attr:`xformers.ops.memory_efficient_attention_forward_requires_grad`
|
||||
"""
|
||||
if p != 0.0:
|
||||
raise NotImplementedError(
|
||||
"dropout is not supported on the non-autograd API."
|
||||
" If you want to use dropout, please call `memory_efficient_attention` directly"
|
||||
)
|
||||
gradients = _memory_efficient_attention_backward(
|
||||
Context(out=output, lse=lse),
|
||||
Inputs(
|
||||
query=query, key=key, value=value, p=p, attn_bias=attn_bias, scale=scale, use_alibi=use_alibi, alibi_mode=alibi_mode, imp_mode=imp_mode
|
||||
),
|
||||
grad,
|
||||
op=op,
|
||||
)
|
||||
return (gradients.dq, gradients.dk, gradients.dv)
|
||||
|
||||
|
||||
def _memory_efficient_attention(
|
||||
inp: Inputs, op: Optional[AttentionOp] = None
|
||||
) -> torch.Tensor:
|
||||
# fast-path that doesn't require computing the logsumexp for backward computation
|
||||
if all(x.requires_grad is False for x in [inp.query, inp.key, inp.value]):
|
||||
return _memory_efficient_attention_forward(
|
||||
inp, op=op[0] if op is not None else None
|
||||
)
|
||||
|
||||
output_shape = inp.normalize_bmhk()
|
||||
return _fMHA.apply(
|
||||
op, inp.query, inp.key, inp.value, inp.attn_bias, inp.p, inp.scale, inp.use_alibi, inp.alibi_mode, inp.imp_mode
|
||||
).reshape(output_shape)
|
||||
|
||||
|
||||
def _memory_efficient_attention_forward(
|
||||
inp: Inputs, op: Optional[Type[AttentionFwOpBase]]
|
||||
) -> torch.Tensor:
|
||||
inp.validate_inputs()
|
||||
output_shape = inp.normalize_bmhk()
|
||||
if op is None:
|
||||
op = _dispatch_fw(inp, False)
|
||||
else:
|
||||
_ensure_op_supports_or_raise(ValueError, "memory_efficient_attention", op, inp)
|
||||
|
||||
out, *_ = op.apply(inp, needs_gradient=False)
|
||||
return out.reshape(output_shape)
|
||||
|
||||
|
||||
def _memory_efficient_attention_forward_requires_grad(
|
||||
inp: Inputs, op: Optional[Type[AttentionFwOpBase]]
|
||||
) -> Tuple[torch.Tensor, Context]:
|
||||
inp.validate_inputs()
|
||||
output_shape = inp.normalize_bmhk()
|
||||
if op is None:
|
||||
op = _dispatch_fw(inp, True)
|
||||
else:
|
||||
_ensure_op_supports_or_raise(ValueError, "memory_efficient_attention", op, inp)
|
||||
out = op.apply(inp, needs_gradient=True)
|
||||
assert out[1] is not None
|
||||
return (out[0].reshape(output_shape), out[1])
|
||||
|
||||
|
||||
def _memory_efficient_attention_backward(
|
||||
ctx: Context, inp: Inputs, grad: torch.Tensor, op: Optional[Type[AttentionBwOpBase]]
|
||||
) -> Gradients:
|
||||
"""Warning: grad/ctx.out is potentially in BMK format"""
|
||||
inp.validate_inputs()
|
||||
if grad.ndim != inp.query.ndim or grad.ndim != ctx.out.ndim:
|
||||
raise ValueError(
|
||||
"All tensors should be either in BMK (ndim=3) or BMHK (ndim=4) format. \n"
|
||||
f"grad.shape : {grad.shape} \n"
|
||||
f"out.shape : {ctx.out.shape} \n"
|
||||
f"query.shape: {inp.query.shape}"
|
||||
)
|
||||
shape_dq, shape_dk, shape_dv = tuple(
|
||||
x.shape for x in (inp.query, inp.key, inp.value)
|
||||
)
|
||||
inp.normalize_bmhk()
|
||||
# LSE has shape [B, H, M] while query has shape [B, M, H, K]
|
||||
if os.getenv('ENABLE_FLASH_ATTENTION_WITH_IXDNN', '1') == '0':
|
||||
if (
|
||||
ctx.lse.ndim != 3
|
||||
# Dim 0
|
||||
or (
|
||||
not isinstance(inp.attn_bias, BlockDiagonalMask)
|
||||
and ctx.lse.shape[0] != inp.query.shape[0]
|
||||
)
|
||||
or (
|
||||
isinstance(inp.attn_bias, BlockDiagonalMask)
|
||||
and ctx.lse.shape[0] != inp.attn_bias.q_seqinfo.seqstart.shape[0] - 1
|
||||
)
|
||||
# Dim 1
|
||||
or ctx.lse.shape[1] != inp.query.shape[2]
|
||||
# Dim 2
|
||||
or (
|
||||
not isinstance(inp.attn_bias, BlockDiagonalMask)
|
||||
and ctx.lse.shape[2] < inp.query.shape[1]
|
||||
)
|
||||
):
|
||||
raise ValueError(
|
||||
"Input tensors have incompatible shapes."
|
||||
f"lse.shape : {ctx.lse.shape} \n"
|
||||
f"query.shape : {inp.query.shape}"
|
||||
)
|
||||
grad = bmk2bmhk(grad, 1)
|
||||
ctx.out = bmk2bmhk(ctx.out, 1)
|
||||
|
||||
if op is None:
|
||||
op = _dispatch_bw(inp)
|
||||
else:
|
||||
_ensure_op_supports_or_raise(
|
||||
ValueError, "memory_efficient_attention_backward", op, inp
|
||||
)
|
||||
|
||||
grads = op.apply(ctx, inp, grad)
|
||||
grads.dq = grads.dq.reshape(shape_dq)
|
||||
grads.dk = grads.dk.reshape(shape_dk)
|
||||
grads.dv = grads.dv.reshape(shape_dv)
|
||||
return grads
|
||||
|
||||
|
||||
ALL_FW_OPS: Sequence[Type[AttentionFwOpBase]] = [
|
||||
cutlass.FwOp,
|
||||
flash.FwOp,
|
||||
triton.FwOp,
|
||||
small_k.FwOp,
|
||||
triton_splitk.FwOp,
|
||||
]
|
||||
|
||||
ALL_BW_OPS: Sequence[Type[AttentionBwOpBase]] = [
|
||||
cutlass.BwOp,
|
||||
flash.BwOp,
|
||||
triton.BwOp,
|
||||
small_k.BwOp,
|
||||
]
|
||||
|
||||
__all__ = [
|
||||
"AttentionBias",
|
||||
"AttentionOp",
|
||||
"AttentionOpBase",
|
||||
"AttentionOpDispatch",
|
||||
"LowerTriangularMask",
|
||||
"MemoryEfficientAttentionCutlassFwdFlashBwOp",
|
||||
"MemoryEfficientAttentionTritonFwdFlashBwOp",
|
||||
"MemoryEfficientAttentionCutlassOp",
|
||||
"MemoryEfficientAttentionFlashAttentionOp",
|
||||
"MemoryEfficientAttentionOp",
|
||||
"TritonFlashAttentionOp",
|
||||
"memory_efficient_attention",
|
||||
"ALL_FW_OPS",
|
||||
"ALL_BW_OPS",
|
||||
]
|
||||
Reference in New Issue
Block a user