First commit

This commit is contained in:
2025-08-05 19:02:46 +08:00
parent 9efe891f99
commit 99fb9f5cb0
1412 changed files with 203615 additions and 0 deletions

View File

@@ -0,0 +1,97 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
import torch
from .fmha import (
AttentionBias,
AttentionOp,
AttentionOpBase,
AttentionOpDispatch,
LowerTriangularMask,
MemoryEfficientAttentionCutlassFwdFlashBwOp,
MemoryEfficientAttentionCutlassOp,
MemoryEfficientAttentionFlashAttentionOp,
MemoryEfficientAttentionOp,
MemoryEfficientAttentionTritonFwdFlashBwOp,
TritonFlashAttentionOp,
memory_efficient_attention,
memory_efficient_attention_backward,
memory_efficient_attention_forward,
memory_efficient_attention_forward_requires_grad,
)
from .indexing import index_select_cat, scaled_index_add
from .rmsnorm import RMSNorm
from .rope_padded import rope_padded
from .swiglu_op import (
SwiGLU,
SwiGLUEagerOp,
SwiGLUFusedOp,
SwiGLUOp,
SwiGLUOpDispatch,
SwiGLUPackedFusedOp,
swiglu,
)
from .unbind import get_stack_strides, stack_or_none, unbind
# Backward-compatibility alias: `AttentionMask` is bound to the very same class
# object as `AttentionBias`, so code importing either name gets the same type.
AttentionMask = AttentionBias
def masked_matmul(a, b, mask=None):
    """Compute ``a @ b`` and optionally apply a mask to the product.

    :parameter a: left operand of the matrix product
    :parameter b: right operand of the matrix product
    :parameter mask: optional mask. A boolean mask marks the entries to keep
        (``False`` entries are overwritten with ``-inf``); a 2-D boolean mask is
        expanded along the first dimension of the product. A mask of any other
        dtype is added to the product in place.
    :return: the (masked) product

    If any argument overrides ``__torch_function__``, the call is delegated to
    that override instead.
    """
    operands = (a, b, mask)
    if torch.overrides.has_torch_function(operands):
        return torch.overrides.handle_torch_function(
            masked_matmul, operands, a, b, mask
        )

    scores = a @ b
    if mask is None:
        return scores

    if mask.dtype != torch.bool:
        # non-boolean masks are presumed additive
        scores += mask
        return scores

    if mask.ndim == 2:
        mask = mask.unsqueeze(0).expand(scores.shape[0], -1, -1)
    # boolean masks are presumed to mean: False == ignore this entry
    scores[~mask] = float("-inf")
    return scores
# Names exported by `from <this package> import *`: the re-exported attention,
# RMSNorm, SwiGLU, unbind, indexing and rope helpers imported above, plus
# `masked_matmul` and the `AttentionMask` alias defined in this module.
__all__ = [
    "memory_efficient_attention",
    "AttentionBias",
    "AttentionMask",
    "AttentionOp",
    "AttentionOpBase",
    "AttentionOpDispatch",
    "LowerTriangularMask",
    "MemoryEfficientAttentionCutlassFwdFlashBwOp",
    "MemoryEfficientAttentionCutlassOp",
    "MemoryEfficientAttentionFlashAttentionOp",
    "MemoryEfficientAttentionOp",
    "MemoryEfficientAttentionTritonFwdFlashBwOp",
    "memory_efficient_attention_backward",
    "memory_efficient_attention_forward",
    "memory_efficient_attention_forward_requires_grad",
    "RMSNorm",
    "SwiGLU",
    "SwiGLUEagerOp",
    "SwiGLUFusedOp",
    "SwiGLUOp",
    "SwiGLUOpDispatch",
    "SwiGLUPackedFusedOp",
    "swiglu",
    "TritonFlashAttentionOp",
    "unbind",
    "stack_or_none",
    "get_stack_strides",
    "masked_matmul",
    "scaled_index_add",
    "index_select_cat",
    "rope_padded",
]

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

133
pkgs/xformers/ops/common.py Normal file
View File

@@ -0,0 +1,133 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
import inspect
import os
from typing import Any, Dict, List, Type, TypeVar
import torch
from torch.torch_version import TorchVersion
def get_operator(library: str, name: str):
    """Look up ``torch.ops.<library>.<name>``.

    :parameter library: name of the operator namespace under ``torch.ops``
    :parameter name: name of the operator inside that namespace
    :return: the operator when the lookup succeeds; otherwise a placeholder
        function named ``no_such_operator`` that raises ``RuntimeError`` when
        called (the lookup failure itself is not raised).
    """

    def no_such_operator(*args, **kwargs):
        raise RuntimeError(
            f"No such operator {library}::{name} - did you forget to build xformers with `python setup.py develop`?"
        )

    try:
        namespace = getattr(torch.ops, library)
        resolved = getattr(namespace, name)
    except (RuntimeError, AttributeError):
        # The placeholder's name is what `BaseOperator.is_available` looks for.
        return no_such_operator
    return resolved
def get_xformers_operator(name: str):
    """Look up operator ``name`` in the ``xformers`` namespace via `get_operator`."""
    return get_operator(library="xformers", name=name)
class BaseOperator:
    """Base class describing an operator through class-level attributes."""

    # Callable implementing the operator; `None` or the `no_such_operator`
    # placeholder both mean "not available".
    OPERATOR: Any
    NAME: str
    OPERATOR_CATEGORY: str

    @classmethod
    def is_available(cls) -> bool:
        """Return ``False`` when ``OPERATOR`` is ``None`` or is the
        ``no_such_operator`` placeholder (matched by ``__name__``), else ``True``."""
        op = cls.OPERATOR
        if op is None:
            return False
        return op.__name__ != "no_such_operator"

    @classmethod
    def operator_flop(cls, *inputs) -> int:
        """Calculate number of FLOP given inputs to `OPERATOR`.

        This base implementation always returns ``-1``.
        """
        return -1
# Every class passed to `register_operator`, in registration order.
OPERATORS_REGISTRY: List[Type[BaseOperator]] = []
# Maps a registered class's `OPERATOR` attribute to that class.
FUNC_TO_XFORMERS_OPERATOR: Dict[Any, Type[BaseOperator]] = {}

# Type variable used by the decorators in this module so that they are typed as
# returning the same type they receive.
ClsT = TypeVar("ClsT")
def register_operator(cls: ClsT) -> ClsT:
    """Class decorator recording ``cls`` in the module-level registries.

    Appends ``cls`` to ``OPERATORS_REGISTRY`` and indexes it in
    ``FUNC_TO_XFORMERS_OPERATOR`` under its ``OPERATOR`` attribute, then returns
    ``cls`` unchanged.
    """
    # Both registries are mutated in place, so no `global` declaration is needed.
    operator_fn = cls.OPERATOR  # type: ignore
    OPERATORS_REGISTRY.append(cls)  # type: ignore
    FUNC_TO_XFORMERS_OPERATOR[operator_fn] = cls  # type: ignore
    return cls
# Accessor used to reach a tensor's storage.
# post-2.0: `untyped_storage` avoids a warning
# (`torch.Tensor.storage` will also be deleted in the future)
try:
    _GET_TENSOR_STORAGE = torch.Tensor.untyped_storage
except AttributeError:  # pre-2.0, `untyped_storage` didn't exist
    _GET_TENSOR_STORAGE = torch.Tensor.storage
def _get_storage_base(x: torch.Tensor) -> int:
    """Return ``data_ptr()`` of the storage backing ``x``."""
    storage = _GET_TENSOR_STORAGE(x)  # type: ignore
    return storage.data_ptr()
def make_pytorch_cuda_operator(fn: ClsT) -> ClsT:
    """Register ``fn`` as the CUDA implementation of a new operator.

    A schema string ``name(type arg, ...) -> type`` is built from the signature
    and annotations of ``fn``, defined on the library object returned by
    ``get_python_lib()``, and ``fn`` is registered as its ``"CUDA"``
    implementation.

    :parameter fn: function whose parameters and return value are annotated
        with one of ``torch.Tensor``, ``bool``, ``int``, ``List[int]`` or
        ``List[torch.Tensor]``
    :return: a wrapper that forwards its arguments to
        ``torch.ops.<library namespace>.<fn.__name__>``
    :raises AssertionError: if an annotation is not one of the supported types
    """
    from .. import get_python_lib

    def render_arg_type(annotation) -> str:
        """Translate a Python annotation into its operator-schema type name."""
        if annotation is torch.Tensor:
            return "Tensor"
        if annotation is bool:
            return "bool"
        if annotation is int:
            return "int"
        # Parameterized generics are compared with `==`: two `List[int]`
        # expressions are only the same object thanks to an internal cache of
        # the typing module, so `is` is not a reliable test for them.
        if annotation == List[int]:
            return "int[]"
        if annotation == List[torch.Tensor]:
            return "Tensor[]"
        # Raised explicitly (instead of `assert False`) so that an unsupported
        # annotation is still rejected when Python runs with `-O`.
        raise AssertionError(f"Unable to parse annotation: `{annotation}`")

    sign = inspect.signature(fn)  # type: ignore
    arguments = [
        f"{render_arg_type(arg.annotation)} {arg.name}"
        for arg in sign.parameters.values()
    ]
    op_name = fn.__name__  # type: ignore
    definition = f"{op_name}({', '.join(arguments)}) -> {render_arg_type(sign.return_annotation)}"

    xformers_lib = get_python_lib()
    xformers_lib.define(definition)
    xformers_lib.impl(op_name, fn, "CUDA")
    # Resolve the dispatcher entry once; the wrapper only forwards to it.
    dispatcher_impl = getattr(getattr(torch.ops, xformers_lib.ns), op_name)

    def wrapper(*args, **kwargs):
        return dispatcher_impl(*args, **kwargs)

    return wrapper  # type: ignore
def _has_a_version_of_triton():
    """Return ``True`` when some version of Triton can be used.

    That requires all of: the environment variable
    ``XFORMERS_FORCE_DISABLE_TRITON`` is not ``"1"``, CUDA is available
    according to ``torch.cuda.is_available()``, and ``import triton`` succeeds.
    """
    force_disabled = os.environ.get("XFORMERS_FORCE_DISABLE_TRITON", "0") == "1"
    if force_disabled or not torch.cuda.is_available():
        return False
    try:
        import triton  # noqa: F401
    except ImportError:
        return False
    else:
        return True
def _has_triton2():
    """Return ``True`` when Triton is usable and its version, compared through
    ``TorchVersion``, is ``>= (2, 1)`` or equal to ``(2, 0)``."""
    if _has_a_version_of_triton():
        import triton

        version = TorchVersion(triton.__version__)
        return version >= (2, 1) or version == (2, 0)
    return False
def _has_triton21():
    """Return ``True`` when Triton is usable and its version, compared through
    ``TorchVersion``, is ``>= (2, 1)``."""
    if _has_a_version_of_triton():
        import triton

        version = TorchVersion(triton.__version__)
        return version >= (2, 1)
    return False

View File

@@ -0,0 +1,474 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
from typing import Any, Optional, Sequence, Tuple, Type, Union
import os
import torch
from . import cutlass, decoder, flash, small_k, triton, triton_splitk
from .attn_bias import AttentionBias, BlockDiagonalMask, LowerTriangularMask
from .common import (
AttentionBwOpBase,
AttentionFwOpBase,
AttentionOp,
AttentionOpBase,
AttentionOpDispatch,
Context,
Gradients,
Inputs,
bmk2bmhk,
)
from .dispatch import _dispatch_bw, _dispatch_fw, _ensure_op_supports_or_raise
# Named operator pairs. Each value is a `(forward op, backward op)` tuple, which
# is the layout `_fMHA.forward` expects for its `op` argument (`op[0]` is used
# for the forward pass and `op[1]` for the backward pass).
MemoryEfficientAttentionCutlassOp = (cutlass.FwOp, cutlass.BwOp)
MemoryEfficientAttentionCutlassFwdFlashBwOp = (cutlass.FwOp, flash.BwOp)
MemoryEfficientAttentionDecoderOp = (decoder.FwOp, cutlass.BwOp)
MemoryEfficientAttentionTritonFwdFlashBwOp = (triton.FwOp, flash.BwOp)
MemoryEfficientAttentionFlashAttentionOp = (flash.FwOp, flash.BwOp)
MemoryEfficientAttentionOp = (small_k.FwOp, small_k.BwOp)
TritonFlashAttentionOp = (triton.FwOp, triton.BwOp)
class _fMHA(torch.autograd.Function):
    """Autograd function tying the attention forward operator to its backward operator.

    ``forward`` runs the forward pass and stashes on ``ctx`` everything that
    ``backward`` needs to rebuild the ``Inputs`` and ``Context`` objects.
    """

    @staticmethod
    # type: ignore
    def forward(ctx, op: AttentionOp, *args: Any) -> Any:
        """Run the forward pass.

        ``op`` is either ``None`` or a ``(forward op, backward op)`` pair, and
        ``args`` are passed positionally to ``Inputs``. Returns the output
        produced by the forward operator.
        """
        inp = Inputs(*args)
        op_fw = op[0] if op is not None else None
        op_bw = op[1] if op is not None else None

        out, op_ctx = _memory_efficient_attention_forward_requires_grad(
            inp=inp, op=op_fw
        )

        # Saving attn_bias is a bit complicated, as the
        # torch part should go in `save_for_backward`
        if isinstance(inp.attn_bias, torch.Tensor):
            attn_bias_tensor = inp.attn_bias
            attn_bias_ctx = None
        else:
            attn_bias_tensor = None
            attn_bias_ctx = inp.attn_bias

        # The order here must match the unpacking of `ctx.saved_tensors`
        # at the top of `backward`.
        ctx.save_for_backward(
            inp.query,
            inp.key,
            inp.value,
            op_ctx.q_padded,
            op_ctx.k_padded,
            op_ctx.v_padded,
            op_ctx.o_padded,
            op_ctx.out,
            op_ctx.lse,
        )
        ctx.rng_state = op_ctx.rng_state
        ctx.attn_bias_tensor = attn_bias_tensor
        # The forward operator may impose its own backward operator; an
        # explicitly requested backward operator must then be that same one.
        if op_ctx.op_bw is not None:
            if op_bw is not None and op_bw is not op_ctx.op_bw:
                raise ValueError(
                    f"Specified op_bw={op_bw.NAME}, but forward op "
                    f"can only run with op_bw={op_ctx.op_bw.NAME}. Please set op_bw=None."
                )
            op_bw = op_ctx.op_bw
        ctx.op_fw = op_fw
        ctx.op_bw = op_bw
        ctx.p = inp.p
        ctx.use_alibi = inp.use_alibi
        ctx.alibi_mode = inp.alibi_mode
        ctx.imp_mode = inp.imp_mode

        ctx.scale = inp.scale
        ctx.attn_bias_ctx = attn_bias_ctx
        # Number of `Inputs` arguments; used to size the gradient tuple in `backward`.
        ctx.n_args = len(args)
        return out

    @staticmethod
    def deserialize_bias(
        attn_bias_ctx, attn_bias_tensor: Optional[torch.Tensor]
    ) -> Any:
        """Return the tensor bias when there is one, else the non-tensor bias object."""
        if attn_bias_tensor is None:
            return attn_bias_ctx
        return attn_bias_tensor

    @classmethod
    @torch.autograd.function.once_differentiable
    def backward(cls, ctx, grad):
        """Rebuild ``Inputs``/``Context`` from ``ctx`` and run the backward operator."""
        # Re-create context
        query, key, value, q_padded, k_padded, v_padded, o_padded, out, lse = ctx.saved_tensors
        attn_bias_tensor = ctx.attn_bias_tensor
        rng_state = ctx.rng_state
        inp = Inputs(
            query=query,
            key=key,
            value=value,
            attn_bias=cls.deserialize_bias(ctx.attn_bias_ctx, attn_bias_tensor),
            p=ctx.p,
            scale=ctx.scale,
            use_alibi=ctx.use_alibi,
            alibi_mode=ctx.alibi_mode,
            imp_mode = ctx.imp_mode,
        )
        op_ctx = Context(
            lse=lse,
            out=out,
            q_padded=q_padded,
            k_padded=k_padded,
            v_padded=v_padded,
            o_padded=o_padded,
            rng_state=rng_state,
        )
        grads = _memory_efficient_attention_backward(
            ctx=op_ctx, inp=inp, grad=grad, op=ctx.op_bw
        )
        # First entry is for `op`; then query, key, value and attn_bias.
        # NOTE(review): this yields 5 + (n_args - 2) entries while `forward`
        # takes 1 + n_args inputs; presumably relies on autograd accepting
        # surplus trailing `None` gradients - confirm.
        return (None, grads.dq, grads.dk, grads.dv, grads.db) + (None,) * (
            ctx.n_args - 2
        )
def memory_efficient_attention(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attn_bias: Optional[Union[torch.Tensor, AttentionBias]] = None,
    p: float = 0.0,
    scale: Optional[float] = None,
    use_alibi: bool = False,
    alibi_mode: int = 1,
    imp_mode: int = 0,
    *,
    op: Optional[AttentionOp] = None,
) -> torch.Tensor:
    """Implements the memory-efficient attention mechanism following
    `"Self-Attention Does Not Need O(n^2) Memory" <http://arxiv.org/abs/2112.05682>`_.

    :Inputs shape:

    - Input tensors must be in format ``[B, M, H, K]``, where B is the batch size, M
      the sequence length, H the number of heads, and K the embedding size per head

    - If inputs have dimension 3, it is assumed that the dimensions are ``[B, M, K]`` and ``H=1``

    - Inputs can also be of dimension 5 with GQA - see note below

    - Inputs can be non-contiguous - we only require the last dimension's stride to be 1


    :Equivalent pytorch code:

    .. code-block:: python

        scale = 1 / query.shape[-1] ** 0.5
        query = query * scale
        attn = query @ key.transpose(-2, -1)
        if attn_bias is not None:
            attn = attn + attn_bias
        attn = attn.softmax(-1)
        attn = F.dropout(attn, p)
        return attn @ value

    :Examples:

    .. code-block:: python

        import xformers.ops as xops

        # Compute regular attention
        y = xops.memory_efficient_attention(q, k, v)

        # With a dropout of 0.2
        y = xops.memory_efficient_attention(q, k, v, p=0.2)

        # Causal attention
        y = xops.memory_efficient_attention(
            q, k, v,
            attn_bias=xops.LowerTriangularMask()
        )

    :Supported hardware:

        NVIDIA GPUs with compute capability above 6.0 (P100+), datatype ``f16``, ``bf16`` and ``f32``.

    :EXPERIMENTAL: Using with Multi Query Attention (MQA) and Grouped Query Attention (GQA):

        MQA/GQA is an experimental feature supported only for the forward pass.
        If you have 16 heads in query, and 2 in key/value, you can provide 5-dim tensors
        in the ``[B, M, G, H, K]`` format, where ``G`` is the number of head groups (here 2), and
        ``H`` is the number of heads per group (8 in the example).

        Please note that xFormers will not automatically broadcast the inputs, so you will need
        to broadcast it manually before calling `memory_efficient_attention`.

    :GQA/MQA example:

    .. code-block:: python

        import torch
        import xformers.ops as xops

        B, M, K = 3, 32, 128
        kwargs = dict(device="cuda", dtype=torch.float16)
        q = torch.randn([B, M, 8, K], **kwargs)
        k = torch.randn([B, M, 2, K], **kwargs)
        v = torch.randn([B, M, 2, K], **kwargs)
        out_gqa = xops.memory_efficient_attention(
            q.reshape([B, M, 2, 4, K]),
            k.reshape([B, M, 2, 1, K]).expand([B, M, 2, 4, K]),
            v.reshape([B, M, 2, 1, K]).expand([B, M, 2, 4, K]),
        )

    Raises:
        NotImplementedError: if there is no operator available to compute the MHA
        ValueError: if inputs are invalid

    :parameter query: Tensor of shape ``[B, Mq, H, K]``
    :parameter key: Tensor of shape ``[B, Mkv, H, K]``
    :parameter value: Tensor of shape ``[B, Mkv, H, Kv]``
    :parameter attn_bias: Bias to apply to the attention matrix - defaults to no masking.
        For common biases implemented efficiently in xFormers, see
        :attr:`xformers.ops.fmha.attn_bias.AttentionBias`.
        This can also be a :attr:`torch.Tensor` for an arbitrary mask (slower).
    :parameter p: Dropout probability. Disabled if set to ``0.0``
    :parameter scale: Scaling factor for ``Q @ K.transpose()``. If set to ``None``, the default
        scale (q.shape[-1]**-0.5) will be used.
    :parameter use_alibi: Stored unchanged on the ``Inputs`` object handed to the
        operator; its meaning is not defined in this module (presumably toggles an
        ALiBi bias - confirm with the operator implementation).
    :parameter alibi_mode: Stored unchanged on the ``Inputs`` object; meaning not
        defined in this module.
    :parameter imp_mode: Stored unchanged on the ``Inputs`` object; meaning not
        defined in this module.
    :parameter op: The operators to use - see :attr:`xformers.ops.AttentionOpBase`.
        If set to ``None`` (recommended), xFormers
        will dispatch to the best available operator, depending on the inputs
        and options.
    :return: multi-head attention Tensor with shape ``[B, Mq, H, Kv]``
    """
    inputs = Inputs(
        query=query,
        key=key,
        value=value,
        p=p,
        attn_bias=attn_bias,
        scale=scale,
        use_alibi=use_alibi,
        alibi_mode=alibi_mode,
        imp_mode=imp_mode,
    )
    return _memory_efficient_attention(inputs, op=op)
def memory_efficient_attention_forward(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attn_bias: Optional[Union[torch.Tensor, AttentionBias]] = None,
    p: float = 0.0,
    scale: Optional[float] = None,
    use_alibi: bool = False,
    alibi_mode: int = 1,
    imp_mode: int = 0,
    *,
    op: Optional[Type[AttentionFwOpBase]] = None,
) -> torch.Tensor:
    """
    Calculates the forward pass of :attr:`xformers.ops.memory_efficient_attention`.

    The arguments are bundled into an ``Inputs`` object and handed, together
    with the optional forward operator ``op``, to the internal forward helper.
    """
    inputs = Inputs(
        query=query,
        key=key,
        value=value,
        p=p,
        attn_bias=attn_bias,
        scale=scale,
        use_alibi=use_alibi,
        alibi_mode=alibi_mode,
        imp_mode=imp_mode,
    )
    return _memory_efficient_attention_forward(inputs, op=op)
def memory_efficient_attention_forward_requires_grad(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attn_bias: Optional[Union[torch.Tensor, AttentionBias]] = None,
    p: float = 0.0,
    scale: Optional[float] = None,
    use_alibi: bool = False,
    alibi_mode: int = 1,
    imp_mode: int = 0,
    *,
    op: Optional[Type[AttentionFwOpBase]] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Returns a tuple (output, lse), where `lse` can be used to compute the backward pass later.
    See :attr:`xformers.ops.memory_efficient_attention` for an explanation of the arguments
    See :attr:`xformers.ops.memory_efficient_attention_backward` for running the backward pass

    Raises:
        NotImplementedError: if ``p`` is not ``0.0`` (dropout is unavailable here)
    """
    if p != 0.0:
        raise NotImplementedError(
            "dropout is not supported on the non-autograd API."
            " If you want to use dropout, please call `memory_efficient_attention` directly"
        )
    inputs = Inputs(
        query=query,
        key=key,
        value=value,
        p=p,
        attn_bias=attn_bias,
        scale=scale,
        use_alibi=use_alibi,
        alibi_mode=alibi_mode,
        imp_mode=imp_mode,
    )
    output, op_ctx = _memory_efficient_attention_forward_requires_grad(
        inputs,
        op=op,
    )
    return output, op_ctx.lse
def memory_efficient_attention_backward(
    grad: torch.Tensor,
    output: torch.Tensor,
    lse: torch.Tensor,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attn_bias: Optional[Union[torch.Tensor, AttentionBias]] = None,
    p: float = 0.0,
    scale: Optional[float] = None,
    use_alibi: bool = False,
    alibi_mode: int = 1,
    imp_mode: int = 0,
    *,
    op: Optional[Type[AttentionBwOpBase]] = None,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Computes the gradient of the attention.
    Returns a tuple (dq, dk, dv)
    See :attr:`xformers.ops.memory_efficient_attention` for an explanation of the arguments.
    `lse` is the tensor returned by :attr:`xformers.ops.memory_efficient_attention_forward_requires_grad`

    Raises:
        NotImplementedError: if ``p`` is not ``0.0`` (dropout is unavailable here)
    """
    if p != 0.0:
        raise NotImplementedError(
            "dropout is not supported on the non-autograd API."
            " If you want to use dropout, please call `memory_efficient_attention` directly"
        )
    op_ctx = Context(out=output, lse=lse)
    inputs = Inputs(
        query=query,
        key=key,
        value=value,
        p=p,
        attn_bias=attn_bias,
        scale=scale,
        use_alibi=use_alibi,
        alibi_mode=alibi_mode,
        imp_mode=imp_mode,
    )
    gradients = _memory_efficient_attention_backward(op_ctx, inputs, grad, op=op)
    # The bias gradient (`db`) is not part of this function's result.
    return (gradients.dq, gradients.dk, gradients.dv)
def _memory_efficient_attention(
    inp: Inputs, op: Optional[AttentionOp] = None
) -> torch.Tensor:
    """Run attention on ``inp``, going through autograd only when needed.

    When ``requires_grad`` is ``False`` on query, key and value alike, only the
    forward operator (``op[0]`` when ``op`` is given) is run. Otherwise the
    inputs are normalized and the call goes through ``_fMHA``.
    """
    needs_grad = any(
        t.requires_grad is not False for t in (inp.query, inp.key, inp.value)
    )
    if not needs_grad:
        # fast-path that doesn't require computing the logsumexp for backward computation
        fw_op = None if op is None else op[0]
        return _memory_efficient_attention_forward(inp, op=fw_op)

    output_shape = inp.normalize_bmhk()
    result = _fMHA.apply(
        op,
        inp.query,
        inp.key,
        inp.value,
        inp.attn_bias,
        inp.p,
        inp.scale,
        inp.use_alibi,
        inp.alibi_mode,
        inp.imp_mode,
    )
    return result.reshape(output_shape)
def _memory_efficient_attention_forward(
    inp: Inputs, op: Optional[Type[AttentionFwOpBase]]
) -> torch.Tensor:
    """Validate and normalize ``inp``, then run a forward operator without
    requesting gradient information.

    When ``op`` is ``None`` an operator is chosen by ``_dispatch_fw``; otherwise
    the given operator is checked against ``inp`` first. The output is reshaped
    to the shape returned by ``inp.normalize_bmhk()``.
    """
    inp.validate_inputs()
    output_shape = inp.normalize_bmhk()
    if op is not None:
        _ensure_op_supports_or_raise(ValueError, "memory_efficient_attention", op, inp)
    else:
        op = _dispatch_fw(inp, False)

    result = op.apply(inp, needs_gradient=False)
    # Only the first element of the operator's result is used here.
    out, *_ = result
    return out.reshape(output_shape)
def _memory_efficient_attention_forward_requires_grad(
    inp: Inputs, op: Optional[Type[AttentionFwOpBase]]
) -> Tuple[torch.Tensor, Context]:
    """Validate and normalize ``inp``, then run a forward operator with
    ``needs_gradient=True``.

    Returns the reshaped output together with the context object produced by
    the operator, which must not be ``None``.
    """
    inp.validate_inputs()
    output_shape = inp.normalize_bmhk()
    if op is not None:
        _ensure_op_supports_or_raise(ValueError, "memory_efficient_attention", op, inp)
    else:
        op = _dispatch_fw(inp, True)

    fw_result = op.apply(inp, needs_gradient=True)
    assert fw_result[1] is not None
    return (fw_result[0].reshape(output_shape), fw_result[1])
def _memory_efficient_attention_backward(
    ctx: Context, inp: Inputs, grad: torch.Tensor, op: Optional[Type[AttentionBwOpBase]]
) -> Gradients:
    """Run the backward operator and return gradients in the callers' layout.

    Warning: grad/ctx.out is potentially in BMK format

    :parameter ctx: forward context; ``ctx.out`` and ``ctx.lse`` are read, and
        ``ctx.out`` is replaced by its ``bmk2bmhk`` conversion
    :parameter inp: attention inputs; validated, then normalized in place
    :parameter grad: gradient of the attention output
    :parameter op: backward operator to use, or ``None`` to let ``_dispatch_bw``
        choose one
    :return: the operator's gradients, with ``dq``/``dk``/``dv`` reshaped to the
        shapes query/key/value had before normalization
    :raises ValueError: if ``grad``, ``ctx.out`` and ``inp.query`` do not have
        the same number of dimensions, or if the (optional) LSE shape check
        fails
    """
    inp.validate_inputs()
    if grad.ndim != inp.query.ndim or grad.ndim != ctx.out.ndim:
        raise ValueError(
            "All tensors should be either in BMK (ndim=3) or BMHK (ndim=4) format. \n"
            f"grad.shape : {grad.shape} \n"
            f"out.shape  : {ctx.out.shape} \n"
            f"query.shape: {inp.query.shape}"
        )
    # Record the shapes before normalization: gradients are reshaped back to them.
    shape_dq, shape_dk, shape_dv = tuple(
        x.shape for x in (inp.query, inp.key, inp.value)
    )
    inp.normalize_bmhk()
    # LSE has shape [B, H, M] while query has shape [B, M, H, K]
    # The check below only runs when the environment variable is explicitly
    # set to '0'; when it is unset the default '1' applies and it is skipped.
    if os.getenv('ENABLE_FLASH_ATTENTION_WITH_IXDNN', '1') == '0':
        is_block_diagonal = isinstance(inp.attn_bias, BlockDiagonalMask)
        if (
            ctx.lse.ndim != 3
            # Dim 0
            or (
                not is_block_diagonal
                and ctx.lse.shape[0] != inp.query.shape[0]
            )
            or (
                is_block_diagonal
                and ctx.lse.shape[0] != inp.attn_bias.q_seqinfo.seqstart.shape[0] - 1
            )
            # Dim 1
            or ctx.lse.shape[1] != inp.query.shape[2]
            # Dim 2
            or (
                not is_block_diagonal
                and ctx.lse.shape[2] < inp.query.shape[1]
            )
        ):
            # The headline ends with a newline so that it is not glued to the
            # `lse.shape` line that follows it.
            raise ValueError(
                "Input tensors have incompatible shapes.\n"
                f"lse.shape    : {ctx.lse.shape} \n"
                f"query.shape  : {inp.query.shape}"
            )
    grad = bmk2bmhk(grad, 1)
    ctx.out = bmk2bmhk(ctx.out, 1)

    if op is None:
        op = _dispatch_bw(inp)
    else:
        _ensure_op_supports_or_raise(
            ValueError, "memory_efficient_attention_backward", op, inp
        )

    grads = op.apply(ctx, inp, grad)
    grads.dq = grads.dq.reshape(shape_dq)
    grads.dk = grads.dk.reshape(shape_dk)
    grads.dv = grads.dv.reshape(shape_dv)
    return grads
# Forward operator classes exposed by this module.
# NOTE(review): `decoder.FwOp` is used by `MemoryEfficientAttentionDecoderOp`
# in this file but is not listed here - confirm whether that is intentional.
ALL_FW_OPS: Sequence[Type[AttentionFwOpBase]] = [
    cutlass.FwOp,
    flash.FwOp,
    triton.FwOp,
    small_k.FwOp,
    triton_splitk.FwOp,
]

# Backward operator classes exposed by this module.
ALL_BW_OPS: Sequence[Type[AttentionBwOpBase]] = [
    cutlass.BwOp,
    flash.BwOp,
    triton.BwOp,
    small_k.BwOp,
]
# Names exported by `from <this package> import *`.
# NOTE(review): `MemoryEfficientAttentionDecoderOp` and the
# `memory_efficient_attention_forward*` / `memory_efficient_attention_backward`
# functions are defined in this file but not listed here; explicit
# `from ... import name` still works for them.
__all__ = [
    "AttentionBias",
    "AttentionOp",
    "AttentionOpBase",
    "AttentionOpDispatch",
    "LowerTriangularMask",
    "MemoryEfficientAttentionCutlassFwdFlashBwOp",
    "MemoryEfficientAttentionTritonFwdFlashBwOp",
    "MemoryEfficientAttentionCutlassOp",
    "MemoryEfficientAttentionFlashAttentionOp",
    "MemoryEfficientAttentionOp",
    "TritonFlashAttentionOp",
    "memory_efficient_attention",
    "ALL_FW_OPS",
    "ALL_BW_OPS",
]

View File

@@ -0,0 +1,779 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
import math
from dataclasses import dataclass
from typing import Any, Iterable, List, Optional, Sequence, Tuple, Union
import torch
class AttentionBias:
    """Base class for a custom bias that can be applied
    as the attn_bias argument in
    :attr:`xformers.ops.memory_efficient_attention`.

    That function has the ability to add a tensor, the
    attention bias, to the QK^T matrix before it is used
    in the softmax part of the attention calculation.
    The attention bias tensor with shape
    (B or 1, n_queries, number of keys)
    can be given as the attn_bias input.
    The most common use case is for the attention bias
    to contain only zeros and negative infinities, which forms
    a mask so that some queries only attend to some keys.

    Children of this class define alternative things which can
    be used as the attn_bias input to define an attention bias which
    forms such a mask, for some common cases.

    When using an :attr:`xformers.ops.AttentionBias`
    instead of a :attr:`torch.Tensor`, the mask matrix does
    not need to be materialized, and can be
    hardcoded into some kernels for better performance.

    See:

    - :attr:`xformers.ops.fmha.attn_bias.LowerTriangularMask`
    - :attr:`xformers.ops.fmha.attn_bias.LowerTriangularMaskWithTensorBias`
    - :attr:`xformers.ops.fmha.attn_bias.BlockDiagonalMask`
    - :attr:`xformers.ops.fmha.attn_bias.BlockDiagonalCausalMask`

    """

    def materialize(
        self,
        shape: Tuple[int, ...],
        dtype: torch.dtype = torch.float32,
        device: Union[str, torch.device] = "cpu",
    ) -> torch.Tensor:
        """
        Materializes the bias as a `torch.Tensor`. This is very slow
        and we don't attempt to make it fast. Only use for debugging/testing.

        Shape should be like `[*, q_seqlen, k_seqlen]`

        The base class provides no implementation and always raises
        ``NotImplementedError``.
        """
        raise NotImplementedError
class LowerTriangularMask(AttentionBias):
    """
    A lower-triangular (aka causal) mask

    A query Q cannot attend to a key which is farther from the
    initial key than Q is from the initial query.
    """

    def __init__(self, *tensor_args, **tensor_kwargs) -> None:
        # NOTE: the arguments are ignored; they are only accepted for
        # backward compatibility
        super().__init__()

    def materialize(
        self,
        shape: Tuple[int, ...],
        dtype: torch.dtype = torch.float32,
        device: Union[str, torch.device] = "cpu",
    ) -> torch.Tensor:
        """Build the mask as a tensor of ``shape``: ``-inf`` strictly above the
        main diagonal of the last two dimensions and zero elsewhere."""
        # bfloat16 masks are assembled in float32 and only cast at the very end.
        working_dtype = torch.float32 if dtype is torch.bfloat16 else dtype
        neg_inf = torch.full(  # type: ignore
            shape,
            float("-inf"),
            dtype=working_dtype,
            device=device,
        )
        upper = torch.triu(neg_inf, diagonal=1)
        return upper.to(dtype)  # type: ignore

    def add_bias(self, bias: torch.Tensor) -> "LowerTriangularMaskWithTensorBias":
        """Return a `LowerTriangularMaskWithTensorBias` built from ``bias``."""
        return LowerTriangularMaskWithTensorBias(bias)
class LowerTriangularMaskWithTensorBias(LowerTriangularMask):
    """A lower-triangular (aka causal) mask with an additive bias"""

    def __init__(self, bias: torch.Tensor) -> None:
        # Tensor added to the causal mask when it is materialized.
        self._bias = bias

    def materialize(
        self,
        shape: Tuple[int, ...],
        dtype: torch.dtype = torch.float32,
        device: Union[str, torch.device] = "cpu",
    ) -> torch.Tensor:
        """Return the parent's materialized mask plus the stored bias."""
        causal = super().materialize(shape, dtype=dtype, device=device)
        return causal + self._bias
@dataclass
class _SeqLenInfo:
    """
    (Internal) Represents the division of a dimension into blocks.

    For example, to represent a dimension of length 7 divided into
    three blocks of lengths 2, 3 and 2, use `from_seqlens([2, 3, 2])`.
    The members will be:
        max_seqlen: 3
        min_seqlen: 2
        seqstart_py: [0, 2, 5, 7]
        seqstart: torch.IntTensor([0, 2, 5, 7])
    """

    seqstart: torch.Tensor
    max_seqlen: int
    min_seqlen: int
    seqstart_py: List[int]

    def to(self, device: torch.device) -> None:
        """Move the `seqstart` tensor to ``device`` in place (non-blocking copy)."""
        moved = self.seqstart.to(device, non_blocking=True)
        self.seqstart = moved

    def intervals(self) -> Iterable[Tuple[int, int]]:
        """Yield the ``(start, end)`` offsets of every block, in order."""
        starts = self.seqstart_py
        for bounds in zip(starts[:-1], starts[1:]):
            yield bounds

    @classmethod
    def from_seqlens(cls, seqlens: Iterable[int]) -> "_SeqLenInfo":
        """
        Build the block description from the block lengths.

        Input tensors are assumed to be in shape [B, M, *]
        """
        assert not isinstance(seqlens, torch.Tensor)
        offsets = [0]
        # -1 means "no block seen yet"; both stay at -1 for an empty input.
        longest = -1
        shortest = -1
        for length in seqlens:
            shortest = length if shortest == -1 else min(shortest, length)
            longest = max(longest, length)
            offsets.append(offsets[-1] + length)
        return cls(
            seqstart=torch.tensor(offsets, dtype=torch.int32),
            max_seqlen=longest,
            min_seqlen=shortest,
            seqstart_py=offsets,
        )

    def split(
        self, x: torch.Tensor, batch_sizes: Optional[Sequence[int]] = None
    ) -> List[torch.Tensor]:
        """Split ``x`` (shape ``[1, M, *]``) along dimension 1 into its blocks.

        ``batch_sizes[i]`` consecutive blocks are grouped into the i-th returned
        tensor, which is reshaped to ``[batch_sizes[i], -1, *]``. Without
        ``batch_sizes``, every block is returned on its own.

        Raises ``ValueError`` when ``x`` does not have batch size 1 or when its
        second dimension differs from the total length described here.
        """
        if self.seqstart_py[-1] != x.shape[1] or x.shape[0] != 1:
            raise ValueError(
                f"Invalid `torch.Tensor` of shape {x.shape}, expected format "
                f"(B, M, *) with B=1 and M={self.seqstart_py[-1]}\n"
                f" seqstart: {self.seqstart_py}"
            )
        if batch_sizes is None:
            batch_sizes = [1] * (len(self.seqstart_py) - 1)

        # Index (into `seqstart_py`) at which each group of blocks begins.
        group_starts = [0]
        for group_size in batch_sizes:
            group_starts.append(group_starts[-1] + group_size)
        chunk_lengths = [
            self.seqstart_py[hi] - self.seqstart_py[lo]
            for lo, hi in zip(group_starts, group_starts[1:])
        ]

        pieces = x.split(chunk_lengths, dim=1)
        return [
            piece.reshape([group_size, -1, *piece.shape[2:]])
            for group_size, piece in zip(batch_sizes, pieces)
        ]
@dataclass
class _PaddedSeqLenInfo(_SeqLenInfo):
    """
    (Internal) Represents the division of a dimension into blocks which are
    padded out to the same total length.

    For example, to represent a dimension of length 12 with space for
    three blocks of length 4, but where the occupied lengths are
    2, 3 and 2, use `from_seqlens_padded([2, 3, 2], 4)`.

    The layout along the dimension is

     0 ─►  block 0
           block 0
           <space>
           <space>
     4 ─►  block 1
           block 1
           block 1
           <space>
     8 ─►  block 2
           block 2
           <space>
           <space>
    12 ─►

    The members will be:
        max_seqlen: 3
        min_seqlen: 2
        seqstart_py: [0, 4, 8, 12]
        seqstart: torch.IntTensor([0, 4, 8, 12])
        seqlen_py: [2, 3, 2]
        seqlen: torch.IntTensor([2, 3, 2])
        padding: 4
    """

    seqlen: torch.Tensor
    seqlen_py: Sequence[int]
    padding: int
    # From parent: seqstart[i] contains the start position
    # of the i-th sequence
    # seqstart: torch.Tensor

    def __post_init__(self) -> None:
        # One start offset per block plus the final end offset.
        assert len(self.seqstart_py) == len(self.seqlen_py) + 1

    def to(self, device: torch.device) -> None:
        """Move the `seqlen` tensor, then the parent's tensor, to ``device`` in place."""
        moved = self.seqlen.to(device, non_blocking=True)
        self.seqlen = moved
        super().to(device)

    def intervals(self) -> Iterable[Tuple[int, int]]:
        """Yield ``(start, start + occupied length)`` for every block, so the
        padding at the end of each slot is left out."""
        padded_spans = super().intervals()
        for span, occupied in zip(padded_spans, self.seqlen_py):
            begin = span[0]
            yield begin, begin + occupied

    @classmethod
    def from_seqlens(cls, seqlens: Iterable[int]) -> "_SeqLenInfo":
        """Not available on this class: always raises ``RuntimeError``."""
        raise RuntimeError(
            "Use either `_SeqLenInfo.from_seqlens` or `_PaddedSeqLenInfo.from_seqlens_padded`"
        )

    @classmethod
    def from_seqlens_padded(
        cls, seqlens: Sequence[int], padding: int
    ) -> "_PaddedSeqLenInfo":
        """
        Build the description from the occupied lengths and the slot size.

        Input tensors are assumed to be in shape [B, M, *]
        seqstart = padding * torch.arange(batch_size)
        """
        assert not isinstance(seqlens, torch.Tensor)
        assert all(seqlen <= padding for seqlen in seqlens)
        # Slot boundaries: 0, padding, 2 * padding, ..., len(seqlens) * padding.
        slot_starts = list(range(0, len(seqlens) * padding + 1, padding))
        occupied = torch.tensor(seqlens, dtype=torch.int32)
        longest = max(seqlens)
        shortest = min(seqlens)
        starts_tensor = torch.tensor(slot_starts, dtype=torch.int32)
        return cls(
            seqstart=starts_tensor,
            max_seqlen=longest,
            min_seqlen=shortest,
            seqstart_py=slot_starts,
            seqlen=occupied,
            seqlen_py=seqlens,
            padding=padding,
        )

    def split(
        self, x: torch.Tensor, batch_sizes: Optional[Sequence[int]] = None
    ) -> List[torch.Tensor]:
        """Not implemented for padded layouts: always raises ``NotImplementedError``."""
        raise NotImplementedError("_PaddedSeqLenInfo.split")
@dataclass
class BlockDiagonalMask(AttentionBias):
"""
A block-diagonal mask that can be passed as ``attn_bias``
argument to :attr:`xformers.ops.memory_efficient_attention`.
Queries and Keys are each divided into the same number of blocks.
Queries in block i only attend to keys in block i.
.. figure:: /_static/block_diag_bias.png
This bias can be used to handle a batch of sequences of
different lengths, via :attr:`BlockDiagonalMask.from_tensor_list`
:Example:
.. code-block:: python
import torch
from xformers.ops import fmha
K = 16
dtype = torch.float16
device = "cuda"
list_x = [
torch.randn([1, 3, 1, K], dtype=dtype, device=device),
torch.randn([1, 6, 1, K], dtype=dtype, device=device),
torch.randn([1, 2, 1, K], dtype=dtype, device=device),
]
attn_bias, x = fmha.BlockDiagonalMask.from_tensor_list(list_x)
linear = torch.nn.Linear(K, K * 3).to(device=device, dtype=dtype)
q, k, v = linear(x).reshape([1, -1, 1, 3, K]).unbind(-2)
out = fmha.memory_efficient_attention(q, k, v, attn_bias=attn_bias)
list_out = attn_bias.split(out)
print(list_out[0].shape) # [1, 3, 1, K]
assert tuple(list_out[0].shape) == (1, 3, 1, K)
"""
q_seqinfo: _SeqLenInfo
k_seqinfo: _SeqLenInfo
_batch_sizes: Optional[Sequence[int]] = None
def _create_block_mask(
self,
shape: Tuple[int, ...],
dtype: torch.dtype = torch.float32,
device: Union[str, torch.device] = "cpu",
) -> torch.Tensor:
return torch.zeros(
shape,
dtype=dtype,
device=device,
)
def materialize(
self,
shape: Tuple[int, ...],
dtype: torch.dtype = torch.float32,
device: Union[str, torch.device] = "cpu",
) -> torch.Tensor:
"""Materialize the attention bias - for debugging & testing"""
assert shape[-1] == self.k_seqinfo.seqstart_py[-1], (
shape[-1],
self.k_seqinfo.seqstart_py[-1],
)
assert shape[-2] == self.q_seqinfo.seqstart_py[-1], (
shape[-2],
self.q_seqinfo.seqstart_py[-1],
)
mask = torch.empty(shape[-2:], dtype=dtype, device=device)
mask.fill_(-math.inf)
for i, ((q_start, q_end), (k_start, k_end)) in enumerate(
zip(
self.q_seqinfo.intervals(),
self.k_seqinfo.intervals(),
)
):
mask[q_start:q_end, k_start:k_end] = self._create_block_mask(
(q_end - q_start, k_end - k_start),
dtype=dtype,
device=device,
)
for _ in range(len(shape) - 2):
mask = mask.unsqueeze(0)
return mask.expand(shape)
@classmethod
def from_seqlens(
cls,
q_seqlen: Sequence[int],
kv_seqlen: Optional[Sequence[int]] = None,
) -> "BlockDiagonalMask":
"""Creates a :attr:`BlockDiagonalMask` from a list of tensors lengths for query and key/value.
Args:
q_seqlen (Union[Sequence[int], torch.Tensor]): List or tensor of sequence lengths for query tensors
kv_seqlen (Union[Sequence[int], torch.Tensor], optional): List or tensor of sequence lengths for key/value.
(Defaults to ``q_seqlen``.)
Returns:
BlockDiagonalMask
"""
assert kv_seqlen is None or len(q_seqlen) == len(kv_seqlen)
q_seqinfo = _SeqLenInfo.from_seqlens(q_seqlen)
if kv_seqlen is None or q_seqlen == kv_seqlen:
k_seqinfo = q_seqinfo
else:
k_seqinfo = _SeqLenInfo.from_seqlens(kv_seqlen)
return cls(q_seqinfo=q_seqinfo, k_seqinfo=k_seqinfo)
@classmethod
def from_tensor_list(
cls,
tensors: Sequence[torch.Tensor],
) -> Tuple["BlockDiagonalMask", torch.Tensor]:
"""Creates a :attr:`BlockDiagonalMask` from a list of tensors, and returns the tensors
concatenated on the sequence length dimension
.. figure:: /_static/block_diag_cat_split.png
See also :attr:`BlockDiagonalMask.split` to split the returned
:attr:`torch.Tensor` back to a list of tensors of varying sequence length
Args:
tensors (Sequence[torch.Tensor]): A list of tensors of shape ``[B, M_i, *]``.
All tensors should have the same dimension and the same batch size ``B``, but
they can have different sequence length ``M``.
Returns:
Tuple[BlockDiagonalMask, torch.Tensor]: The corresponding bias for the attention
along with `tensors` concatenated on the sequence length dimension, with shape ``[1, sum_i{M_i}, *]``
"""
batch_sizes = [tensor.shape[0] for tensor in tensors]
seqlens = []
for x in tensors:
for _ in range(x.shape[0]):
seqlens.append(x.shape[1])
block_diag = cls.from_seqlens(seqlens)
block_diag._batch_sizes = batch_sizes
tensors_bs1 = tuple(x.reshape([1, -1, *x.shape[2:]]) for x in tensors)
concat_tensors = torch.cat(tensors_bs1, dim=1)
return block_diag, concat_tensors
@classmethod
def from_tensor_lists_qkv(
cls,
tensors_q: Sequence[torch.Tensor],
tensors_k: Sequence[torch.Tensor],
tensors_v: Optional[Sequence[torch.Tensor]] = None,
) -> Tuple["BlockDiagonalMask", torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
assert len(tensors_q) == len(tensors_k)
assert tensors_v is None or len(tensors_v) == len(tensors_q)
batch_sizes = [tensor.shape[0] for tensor in tensors_q]
q_seqlens, kv_seqlens = [], []
for i, (q, k) in enumerate(zip(tensors_q, tensors_k)):
assert q.shape[0] == k.shape[0]
q_seqlens += [q.shape[1]] * q.shape[0]
kv_seqlens += [k.shape[1]] * k.shape[0]
assert tensors_v is None or tensors_v[i].shape[:2] == k.shape[:2]
block_diag = cls.from_seqlens(q_seqlens, kv_seqlens)
block_diag._batch_sizes = batch_sizes
return (
block_diag,
torch.cat([x.reshape([1, -1, *x.shape[2:]]) for x in tensors_q], dim=1),
torch.cat([x.reshape([1, -1, *x.shape[2:]]) for x in tensors_k], dim=1),
torch.cat([x.reshape([1, -1, *x.shape[2:]]) for x in tensors_v], dim=1)
if tensors_v is not None
else None,
)
def split_queries(self, tensor: torch.Tensor) -> Sequence[torch.Tensor]:
    """Split ``tensor`` using the query sequence info and the stored batch sizes."""
    query_info = self.q_seqinfo
    return query_info.split(tensor, self._batch_sizes)
def split_kv(self, tensor: torch.Tensor) -> Sequence[torch.Tensor]:
    """Split ``tensor`` using the key/value sequence info and the stored batch sizes."""
    kv_info = self.k_seqinfo
    return kv_info.split(tensor, self._batch_sizes)
def split(self, tensor: torch.Tensor) -> Sequence[torch.Tensor]:
    """The inverse operation of :attr:`BlockDiagonalMask.from_tensor_list`

    Only valid when queries and keys share the very same sequence-info
    object (asserted below).

    Args:
        tensor (torch.Tensor): Tensor of tokens of shape ``[1, sum_i{M_i}, *]``

    Returns:
        Sequence[torch.Tensor]: A list of tokens with possibly different sequence lengths
    """
    seqinfo = self.q_seqinfo
    assert seqinfo is self.k_seqinfo
    return seqinfo.split(tensor, self._batch_sizes)
def make_causal(self) -> "BlockDiagonalCausalMask":
    """Makes each block causal"""
    # The new mask reuses this mask's sequence info and batch sizes.
    shared = dict(
        q_seqinfo=self.q_seqinfo,
        k_seqinfo=self.k_seqinfo,
        _batch_sizes=self._batch_sizes,
    )
    return BlockDiagonalCausalMask(**shared)
def make_causal_from_bottomright(self) -> "BlockDiagonalCausalFromBottomRightMask":
    """Makes each block causal with a possible non-causal prefix"""
    # The new mask reuses this mask's sequence info and batch sizes.
    shared = dict(
        q_seqinfo=self.q_seqinfo,
        k_seqinfo=self.k_seqinfo,
        _batch_sizes=self._batch_sizes,
    )
    return BlockDiagonalCausalFromBottomRightMask(**shared)
def make_local_attention(
    self, window_size: int
) -> "BlockDiagonalCausalLocalAttentionMask":
    """Experimental: Makes each block causal with local attention"""
    # Same sequence info / batch sizes, plus the requested window.
    shared = dict(
        q_seqinfo=self.q_seqinfo,
        k_seqinfo=self.k_seqinfo,
        _batch_sizes=self._batch_sizes,
    )
    return BlockDiagonalCausalLocalAttentionMask(_window_size=window_size, **shared)
def make_local_attention_from_bottomright(
    self, window_size: int
) -> "BlockDiagonalCausalLocalAttentionFromBottomRightMask":
    """Experimental: Makes each block causal with local attention, start from bottom right"""
    # Same sequence info / batch sizes, plus the requested window.
    shared = dict(
        q_seqinfo=self.q_seqinfo,
        k_seqinfo=self.k_seqinfo,
        _batch_sizes=self._batch_sizes,
    )
    return BlockDiagonalCausalLocalAttentionFromBottomRightMask(
        _window_size=window_size, **shared
    )
@dataclass
class BlockDiagonalCausalMask(BlockDiagonalMask):
    """
    Same as :attr:`xformers.ops.fmha.attn_bias.BlockDiagonalMask`, except that each block is causal.

    Queries and Keys are each divided into the same number of blocks.
    A query Q in block i cannot attend to a key which is not in block i,
    nor one which is farther from the initial key in block i than Q
    is from the initial query in block i.
    """

    def _create_block_mask(
        self,
        shape: Tuple[int, ...],
        dtype: torch.dtype = torch.float32,
        device: Union[str, torch.device] = "cpu",
    ) -> torch.Tensor:
        """Materialize one block by delegating to :class:`LowerTriangularMask`."""
        causal = LowerTriangularMask()
        return causal.materialize(shape, dtype=dtype, device=device)
@dataclass
class BlockDiagonalCausalFromBottomRightMask(BlockDiagonalMask):
    """
    Same as :attr:`xformers.ops.fmha.attn_bias.BlockDiagonalMask`, except that each block is causal.
    This mask allows for a non-causal prefix

    NOTE: Each block should have `num_keys >= num_queries` otherwise the forward pass is not
    defined (softmax of vector of `-inf` in the attention)

    Queries and keys are each divided into the same number of blocks.
    A query Q in block i cannot attend to a key which is not in block i,
    nor one which is nearer to the final key in block i than Q is to the
    final query in block i.
    """

    def __post_init__(self) -> None:
        """Reject any block that has fewer keys than queries."""
        paired_intervals = zip(
            self.q_seqinfo.intervals(),
            self.k_seqinfo.intervals(),
        )
        for i, ((q_begin, q_stop), (k_begin, k_stop)) in enumerate(paired_intervals):
            num_queries = q_stop - q_begin
            num_keys = k_stop - k_begin
            if num_keys >= num_queries:
                continue
            raise ValueError(
                f"Block #{i} has num_keys={num_keys} and num_queries={num_queries}."
                " Expected `num_keys >= num_queries`"
            )

    def _create_block_mask(
        self,
        shape: Tuple[int, ...],
        dtype: torch.dtype = torch.float32,
        device: Union[str, torch.device] = "cpu",
    ) -> torch.Tensor:
        """Build one block: ``-inf`` strictly above the bottom-right-aligned diagonal, 0 elsewhere."""
        # bfloat16 masks are built in float32 and converted at the end.
        work_dtype = torch.float32 if dtype is torch.bfloat16 else dtype
        all_masked = torch.full(  # type: ignore
            shape,
            float("-inf"),
            dtype=work_dtype,
            device=device,
        )
        num_queries, num_keys = shape[-2:]
        # Shifting the diagonal by (num_keys - num_queries) anchors causality
        # at the bottom-right corner of the block.
        shifted_diagonal = num_keys - num_queries + 1
        upper = torch.triu(all_masked, diagonal=shifted_diagonal)  # type: ignore
        return upper.to(dtype)
@dataclass
class BlockDiagonalCausalWithOffsetPaddedKeysMask(AttentionBias):
    """
    Same as :attr:`xformers.ops.fmha.attn_bias.BlockDiagonalCausalMask`,
    except an offset on causality is allowed for each block and we support padding for k/v

    The keys and values are divided into blocks which are padded out to
    the same total length.
    For example, if there is space for 12 keys, for three blocks of
    max length 4, but we only want to use the first 2, 3 and 2
    of each block, use `kv_padding=4` and `kv_seqlens=[2, 3, 2]`.
    The queries are divided into blocks, without padding, of lengths given by
    q_seqlen.

    A query Q in block i cannot attend to a key which is not in block i,
    nor one which is not in use (i.e. in the padded area),
    nor one which is nearer to the final key in block i
    than Q is to the final query in block i.
    """

    q_seqinfo: _SeqLenInfo
    k_seqinfo: _PaddedSeqLenInfo
    causal_diagonal: Any = None  # unused. Exists for BC only.

    def _create_block_mask(
        self,
        shape: Tuple[int, ...],
        dtype: torch.dtype = torch.float32,
        device: Union[str, torch.device] = "cpu",
    ) -> torch.Tensor:
        """Build one block: ``-inf`` strictly above the bottom-right-aligned diagonal, 0 elsewhere."""
        # bfloat16 masks are built in float32 and converted at the end.
        work_dtype = torch.float32 if dtype is torch.bfloat16 else dtype
        all_masked = torch.full(  # type: ignore
            shape,
            float("-inf"),
            dtype=work_dtype,
            device=device,
        )
        num_queries, num_keys = shape[-2:]
        shifted_diagonal = 1 + num_keys - num_queries
        upper = torch.triu(all_masked, diagonal=shifted_diagonal)  # type: ignore
        return upper.to(dtype)

    def materialize(
        self,
        shape: Tuple[int, ...],
        dtype: torch.dtype = torch.float32,
        device: Union[str, torch.device] = "cpu",
    ) -> torch.Tensor:
        """Materialize the attention bias - for debugging & testing

        Raises:
            ValueError: if the last two entries of ``shape`` do not match the
                total key / query extents recorded in the sequence infos.
        """
        total_keys = self.k_seqinfo.seqstart_py[-1]
        if shape[-1] != total_keys:
            raise ValueError("k shapes wrong")
        total_queries = self.q_seqinfo.seqstart_py[-1]
        if shape[-2] != total_queries:
            raise ValueError("q shapes wrong")

        # Start fully masked, then overwrite each (query block, key block) tile.
        mask = torch.empty(shape[-2:], dtype=dtype, device=device).fill_(-math.inf)
        paired_intervals = zip(
            self.q_seqinfo.intervals(),
            self.k_seqinfo.intervals(),
        )
        for (q_begin, q_stop), (k_begin, k_stop) in paired_intervals:
            tile_shape = (q_stop - q_begin, k_stop - k_begin)
            mask[q_begin:q_stop, k_begin:k_stop] = self._create_block_mask(
                tile_shape,
                dtype=dtype,
                device=device,
            )

        # Prepend singleton dims until the rank matches, then broadcast.
        while mask.ndim < len(shape):
            mask = mask.unsqueeze(0)
        return mask.expand(shape)

    @classmethod
    def from_seqlens(
        cls,
        q_seqlen: Sequence[int],
        kv_padding: int,
        kv_seqlen: Sequence[int],
        causal_diagonal: Any = None,
    ) -> "BlockDiagonalCausalWithOffsetPaddedKeysMask":
        """Creates a :attr:`BlockDiagonalCausalWithOffsetPaddedKeysMask` from a list of tensor
        lengths for query and key/value.

        Args:
            q_seqlen (Sequence[int]): List or tensor of sequence lengths for query tensors
            kv_padding (int): Padding for k/v - also an upperbound on each individual key length
            kv_seqlen (Sequence[int]): List or tensor of sequence lengths for key/value.
            causal_diagonal: unused, for BC only
        Returns:
            BlockDiagonalCausalWithOffsetPaddedKeysMask
        """
        lengths_agree = kv_seqlen is None or len(q_seqlen) == len(kv_seqlen)
        assert lengths_agree, (
            q_seqlen,
            kv_seqlen,
        )
        return cls(
            q_seqinfo=_SeqLenInfo.from_seqlens(q_seqlen),
            k_seqinfo=_PaddedSeqLenInfo.from_seqlens_padded(kv_seqlen, kv_padding),
        )
@dataclass
class BlockDiagonalCausalLocalAttentionMask(BlockDiagonalCausalMask):
    """
    (Experimental feature)
    Same as :attr:`xformers.ops.fmha.attn_bias.BlockDiagonalCausalMask`.
    This makes the mask "local" and the attention pattern banded.

    Query i only attends to keys in its block and cannot attend keys further than "window_size"
    from it.
    """

    _window_size: int = 0  # forced due to inheritance and default arguments

    def __post_init__(self):
        """Validate the window size and that no block is left with nothing to attend to."""
        window = self._window_size
        if window <= 0:
            raise ValueError(
                f"Expected `window_size > 0`, but window_size={self._window_size}"
            )
        # Per-block lengths are the differences of consecutive sequence starts.
        q_starts = self.q_seqinfo.seqstart_py
        k_starts = self.k_seqinfo.seqstart_py
        q_seqlen = [stop - begin for begin, stop in zip(q_starts[:-1], q_starts[1:])]
        kv_seqlen = [stop - begin for begin, stop in zip(k_starts[:-1], k_starts[1:])]
        for q, k in zip(q_seqlen, kv_seqlen):
            if q - window >= k:
                raise RuntimeError(
                    f"No keys are attended in q_seqlen {q} k_seqlen {k} with sliding window {self._window_size}"
                )

    def _create_block_mask(
        self,
        shape: Tuple[int, ...],
        dtype: torch.dtype = torch.float32,
        device: Union[str, torch.device] = "cpu",
    ) -> torch.Tensor:
        """Build one banded causal block: 0 inside the band, ``-inf`` outside."""
        work_dtype = torch.float32 if dtype is torch.bfloat16 else dtype
        ones = torch.full(  # type: ignore
            shape,
            1,
            dtype=work_dtype,
            device=device,
        )
        _num_queries, _num_keys = shape[-2:]
        # Keep the lower triangle (causal part)...
        band = torch.tril(ones, diagonal=0).to(dtype)  # type: ignore
        # ...then drop everything further than the window below the diagonal.
        window = self._window_size
        if window is not None and window > 0:
            band = torch.triu(band, diagonal=1 - window)
        # log(1) == 0 (attend), log(0) == -inf (masked).
        return torch.log(band).to(dtype)
@dataclass
class BlockDiagonalCausalLocalAttentionFromBottomRightMask(
    BlockDiagonalCausalFromBottomRightMask
):
    """
    (Experimental feature)
    Same as :attr:`xformers.ops.fmha.attn_bias.BlockDiagonalCausalFromBottomRightMask`.
    This makes the mask "local" and the attention pattern banded.

    Query i only attends to keys in its block and cannot attend keys further than "window_size"
    from it.
    """

    _window_size: int = 0  # forced due to inheritance and default arguments

    def __post_init__(self):
        """Validate the window size and that every query attends to at least one key.

        Raises:
            ValueError: if ``_window_size`` is not strictly positive (or if the
                parent check on ``num_keys >= num_queries`` fails).
            RuntimeError: if some block, or some individual query, would be
                left without any key to attend to.
        """
        super().__post_init__()
        if self._window_size <= 0:
            raise ValueError(
                f"Expected `window_size > 0`, but window_size={self._window_size}"
            )
        # Per-block lengths are the differences of consecutive sequence starts.
        q_seqlen = [
            y - x
            for x, y in zip(
                self.q_seqinfo.seqstart_py[:-1], self.q_seqinfo.seqstart_py[1:]
            )
        ]
        kv_seqlen = [
            y - x
            for x, y in zip(
                self.k_seqinfo.seqstart_py[:-1], self.k_seqinfo.seqstart_py[1:]
            )
        ]
        for q, k in zip(q_seqlen, kv_seqlen):
            if q + (q - k) - self._window_size >= k:
                raise RuntimeError(
                    f"No keys are attended in q_seqlen {q} k_seqlen {k} with sliding window {self._window_size}"
                )
        # Final sanity check on the materialized mask: a row whose maximum is
        # -inf belongs to a query that cannot attend to any key.
        materialized = self.materialize((sum(q_seqlen), sum(kv_seqlen)))
        if torch.max(materialized, dim=1).values.min() == -float("inf"):
            raise RuntimeError(
                "At least one query attends to no keys with sliding window "
                f"{self._window_size} (q_seqlen={q_seqlen}, kv_seqlen={kv_seqlen})"
            )

    def _create_block_mask(
        self,
        shape: Tuple[int, ...],
        dtype: torch.dtype = torch.float32,
        device: Union[str, torch.device] = "cpu",
    ) -> torch.Tensor:
        """Build one banded block aligned to the bottom-right corner: 0 inside the band, ``-inf`` outside."""
        create_as = dtype if dtype is not torch.bfloat16 else torch.float32
        tensor = torch.full(  # type: ignore
            shape,
            dtype=create_as,
            fill_value=1,
            device=device,
        )
        num_queries, num_keys = shape[-2:]
        # Causal part, with the diagonal shifted so that the last query sees the last key.
        offset = num_keys - num_queries
        mask = torch.tril(tensor, diagonal=offset).to(dtype)  # type: ignore
        if self._window_size is not None:
            # Drop keys that are further than the window from the shifted diagonal.
            mask = torch.triu(mask, diagonal=offset - self._window_size + 1)
        # log(1) == 0 (attend), log(0) == -inf (masked).
        mask = torch.log(mask)
        return mask.to(dtype)

View File

@@ -0,0 +1,542 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
import math
from dataclasses import dataclass
from typing import Any, List, Mapping, Optional, Set, Tuple, Type, Union
import torch
from ..._cpp_lib import _built_with_cuda
from ..common import BaseOperator
from .attn_bias import (
AttentionBias,
BlockDiagonalMask,
LowerTriangularMask,
LowerTriangularMaskWithTensorBias,
)
def _is_bias_type_supported_in_BMK(attn_bias_type: Any) -> bool:
    """Return True if a bias of type ``attn_bias_type`` is accepted with BMK inputs.

    Accepted: any type that ``None`` is an instance of (i.e. NoneType),
    ``LowerTriangularMask`` and ``torch.Tensor``.
    """
    # NoneType
    accepts_none = isinstance(None, attn_bias_type)
    return accepts_none or attn_bias_type in (LowerTriangularMask, torch.Tensor)
@dataclass
class Inputs:
    """
    Stores inputs to the `memory_efficient_attention` operators
    """

    query: torch.Tensor
    key: torch.Tensor
    value: torch.Tensor
    attn_bias: Optional[Union[torch.Tensor, AttentionBias]] = None
    # Dropout probability; validated to lie in [0, 1] by `validate_inputs`.
    p: float = 0.0
    # Softmax scale; `scale_float` falls back to K ** -0.5 when this is None.
    scale: Optional[float] = None
    # NOTE(review): the three fields below are not read by any method of this
    # class; their semantics are defined by whoever consumes `Inputs`.
    use_alibi: bool = False
    alibi_mode: int = 1
    imp_mode: int = 0

    @property
    def device(self) -> torch.device:
        """Device of the query tensor."""
        return self.query.device

    @property
    def scale_float(self) -> float:
        """Effective scale: ``scale`` if given, else ``query.shape[-1] ** -0.5``."""
        return self.query.shape[-1] ** (-0.5) if self.scale is None else self.scale

    def get_qkv_in_bmghk(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Return query/key/value as 5-D tensors, inserting singleton dims as needed.

        Raises:
            AssertionError: if the inputs are neither 3-, 4- nor 5-dimensional.
        """
        if self.query.ndim == 5:
            return self.query, self.key, self.value
        if self.query.ndim == 4:
            return (
                self.query.unsqueeze(2),
                self.key.unsqueeze(2),
                self.value.unsqueeze(2),
            )
        if self.value.ndim == 3:
            return (
                self.query[:, :, None, None],
                self.key[:, :, None, None],
                self.value[:, :, None, None],
            )
        # Raised explicitly (instead of `assert False`) so that running under
        # `python -O` cannot silently return None here.
        raise AssertionError(
            f"Unsupported number of dimensions for query: {self.query.ndim}"
        )

    def normalize_bmhk(self) -> Tuple[int, ...]:
        """Convert legacy 3-D inputs to 4-D in place and return the expected output shape.

        Raises:
            ValueError: if the query is not 3-, 4- or 5-dimensional, or if a
                tensor bias given with 3-D inputs is not itself 3-D.
        """
        if self.query.ndim not in [3, 4, 5]:
            raise ValueError(
                f"Invalid shape for query: {self.query.shape}. "
                "Expected shape [batch, seqlen, head_groups, num_heads_per_group, K]"
                ", [batch, seqlen, num_heads, K], or [batch, seqlen, K]."
            )
        if self.value.dtype == torch.int32:
            # Quantized K/V case, in which the last dims of Q and K are different.
            # NB we currently don't have any implementations for quantized KV with
            # SUPPORTS_DIFFERENT_VALUE_EMBED.
            output_shape = tuple(self.query.shape)
        else:
            output_shape = (self.query.shape[:-1]) + (self.value.shape[-1],)
        # Convert from legacy format
        if self.query.ndim == 3:
            self.query = self.query.unsqueeze(2)
            self.key = self.key.unsqueeze(2)
            self.value = self.value.unsqueeze(2)
            if isinstance(self.attn_bias, torch.Tensor):
                if self.attn_bias.ndim != 3:
                    raise ValueError(
                        f"Expected BMK format for attn_bias, but got {self.attn_bias.shape}"
                    )
                self.attn_bias = self.attn_bias.unsqueeze(1)
        return output_shape

    def validate_inputs(self) -> None:
        """Check ranks, devices, dtypes, bias shape and q/k/v shape compatibility.

        Raises:
            ValueError: on the first inconsistency found.
        """
        qkv = (self.query, self.key, self.value)
        if self.query.ndim not in (3, 4, 5) or any(
            x.ndim != self.query.ndim for x in qkv
        ):
            raise ValueError(
                f"Query/Key/Value should all have BMGHK, BMHK, or BMK shape.\n"
                f"  query.shape: {self.query.shape}\n"
                f"  key.shape  : {self.key.shape}\n"
                f"  value.shape: {self.value.shape}"
            )
        if any(x.device != self.query.device for x in qkv):
            raise ValueError("Query/Key/Value should all be on the same device")
        quantized_dtypes = self.key.dtype == self.value.dtype == torch.int32
        non_quantized_dtypes = all(x.dtype == self.query.dtype for x in qkv)
        if not (quantized_dtypes or non_quantized_dtypes):
            raise ValueError(
                "Query/Key/Value should either all have the same dtype, or "
                "(in the quantized case) Key/Value should have dtype torch.int32\n"
                f"  query.dtype: {self.query.dtype}\n"
                f"  key.dtype  : {self.key.dtype}\n"
                f"  value.dtype: {self.value.dtype}"
            )
        # Biases with tensors attached are meant to be in BMHK format
        # This would require to permute biases/gradients which can be expensive,
        # so let's just forbid it - BMK is a legacy format anyway
        if self.query.ndim == 3 and not _is_bias_type_supported_in_BMK(
            type(self.attn_bias)
        ):
            raise ValueError(
                f"Please provide inputs in BMHK format rather "
                f"than BMK when using bias type `{type(self.attn_bias).__name__}`"
            )
        attn_bias_t: Optional[torch.Tensor] = None
        if isinstance(self.attn_bias, torch.Tensor):
            attn_bias_t = self.attn_bias
        if isinstance(self.attn_bias, LowerTriangularMaskWithTensorBias):
            attn_bias_t = self.attn_bias._bias
        if self.query.ndim == 4 and attn_bias_t is not None:
            # Tensor biases for BMHK inputs are laid out as [B, H, Mq, Mkv].
            expected_shape = (
                self.query.shape[0],
                self.query.shape[2],
                self.query.shape[1],
                self.key.shape[1],
            )
            if attn_bias_t.shape != expected_shape:
                raise ValueError(
                    f"Invalid shape for attention bias: {attn_bias_t.shape} (expected {expected_shape})\n"
                    f"  query.shape: {self.query.shape}\n"
                    f"  key.shape  : {self.key.shape}\n"
                    f"  value.shape: {self.value.shape}"
                )
        if isinstance(self.attn_bias, BlockDiagonalMask):
            if any(x.shape[0] != 1 for x in qkv):
                raise ValueError(
                    f"Expected batch_size=1 when using block-diagonal bias\n"
                    f"  query.shape: {self.query.shape}\n"
                    f"  key.shape  : {self.key.shape}\n"
                    f"  value.shape: {self.value.shape}"
                )
        if self.p < 0.0 or self.p > 1.0:
            raise ValueError(f"Invalid dropout probability: p={self.p}")
        # Check that shapes match between inputs
        B, Mq = self.query.shape[:2]
        K = self.query.shape[-1]
        # B is deliberately re-bound to the key's batch size: the comparisons
        # below then also verify that query and key agree on it.
        B, Mkv = self.key.shape[:2]
        Kv = self.value.shape[-1]

        valid_shapes = True
        if self.query.ndim == 3:  # BMK
            valid_shapes = (
                self.query.shape == (B, Mq, K)
                and self.key.shape == (B, Mkv, K)
                and self.value.shape == (B, Mkv, Kv)
            )
        H = self.query.shape[-2]
        if self.query.ndim == 4:  # BMHK
            quantized_kv_cache = self.value.dtype == torch.int32
            key_embed_dim = Kv if quantized_kv_cache else K
            valid_shapes = (
                self.query.shape == (B, Mq, H, K)
                and self.key.shape == (B, Mkv, H, key_embed_dim)
                and self.value.shape == (B, Mkv, H, Kv)
            )
        G = self.query.shape[2]
        if self.query.ndim == 5:  # BMNHK
            valid_shapes = (
                self.query.shape == (B, Mq, G, H, K)
                and self.key.shape == (B, Mkv, G, H, K)
                and self.value.shape == (B, Mkv, G, H, Kv)
            )
        if not valid_shapes:
            raise ValueError(
                f"Incompatible shapes for attention inputs:\n"
                f"  query.shape: {self.query.shape}\n"
                f"  key.shape  : {self.key.shape}\n"
                f"  value.shape: {self.value.shape}\n"
                "HINT: We don't support broadcasting, please use `expand` "
                "yourself before calling `memory_efficient_attention` if you need to"
            )
@dataclass
class Context:
    """Tensors and metadata saved by a forward operator for the backward pass."""

    lse: torch.Tensor
    out: torch.Tensor
    q_padded: Optional[torch.Tensor] = None
    k_padded: Optional[torch.Tensor] = None
    v_padded: Optional[torch.Tensor] = None
    o_padded: Optional[torch.Tensor] = None
    op_bw: Optional[Type["AttentionBwOpBase"]] = None
    rng_state: Optional[torch.Tensor] = None

    def get_padded_lse(self, pad_to: int, force_pad_inf: bool = False) -> torch.Tensor:
        """Return ``lse`` with dim 2 padded with ``+inf`` up to a multiple of ``pad_to``.

        With ``force_pad_inf``, everything past ``out.shape[1]`` along dim 2 is
        forced to ``+inf`` as well. When ``lse`` is already a multiple of
        ``pad_to`` this is done in place on ``self.lse``.
        """
        lse = self.lse
        pad_amount = -lse.shape[2] % pad_to
        if pad_amount <= 0:
            # Already aligned: at most overwrite the tail in place.
            if force_pad_inf and self.out.shape[1] != lse.shape[2]:
                lse[:, :, self.out.shape[1] :].fill_(math.inf)
            return lse
        if force_pad_inf:
            # Drop the existing tail first so the whole padding is +inf.
            lse = lse[:, :, : self.out.shape[1]]
            pad_amount = -lse.shape[2] % pad_to
        return torch.nn.functional.pad(lse, [0, pad_amount], value=math.inf)
@dataclass
class Gradients:
    """Gradients returned by a backward attention operator."""

    # Gradients w.r.t. query, key and value.
    dq: torch.Tensor
    dk: torch.Tensor
    dv: torch.Tensor
    # bias gradient. None if there is no tensor bias or if it doesn't require grad
    db: Optional[torch.Tensor] = None
class AttentionOpBase(BaseOperator):
    """Base class for any attention operator in xFormers

    See:

    - :attr:`xformers.ops.fmha.cutlass.FwOp`
    - :attr:`xformers.ops.fmha.cutlass.BwOp`
    - :attr:`xformers.ops.fmha.flash.FwOp`
    - :attr:`xformers.ops.fmha.flash.BwOp`
    - :attr:`xformers.ops.fmha.triton.FwOp`
    - :attr:`xformers.ops.fmha.triton.BwOp`
    - :attr:`xformers.ops.fmha.small_k.FwOp`
    - :attr:`xformers.ops.fmha.small_k.BwOp`
    """

    OPERATOR: Any
    SUPPORTED_DEVICES: Set[str]
    CUDA_MINIMUM_COMPUTE_CAPABILITY: Tuple[int, int] = (5, 0)
    SUPPORTED_DTYPES: Set[torch.dtype]
    SUPPORTED_MAX_K: float
    SUPPORTED_ATTN_BIAS_TYPES: Set[Any] = {type(None)}
    SUPPORTS_DROPOUT: bool
    SUPPORTS_CUSTOM_SCALE: bool = False
    SUPPORTS_DIFFERENT_VALUE_EMBED: bool = False
    IS_DETERMINISTIC: bool = True
    SUPPORTS_BMGHK: bool = False
    NAME: str
    OPERATOR_CATEGORY = "memory_efficient_attention"

    _TEST_BATCH_SIZES: List[int] = [1, 300]
    _TEST_K: List[int] = [32, 128]

    @classmethod
    def supports(cls, d: Inputs) -> bool:
        """Return True when `not_supported_reasons` finds nothing to report."""
        return not cls.not_supported_reasons(d)

    @classmethod
    def shape_not_supported_reasons(
        cls, Mq: int, Mkv: int, K: int, Kv: int
    ) -> List[str]:
        """Return the shape-related reasons why this operator cannot run.

        Args:
            Mq: Number of queries (unused by this base implementation).
            Mkv: Number of keys/values (unused by this base implementation).
            K: Embedding size of the query.
            Kv: Embedding size of the value.
        """
        reasons = []
        if not cls.SUPPORTS_DIFFERENT_VALUE_EMBED and K != Kv:
            reasons.append("query.shape[-1] != value.shape[-1]")
        if max(K, Kv) > cls.SUPPORTED_MAX_K:
            reasons.append(
                f"max(query.shape[-1], value.shape[-1]) > {cls.SUPPORTED_MAX_K}"
            )
        return reasons

    @classmethod
    def not_supported_reasons(cls, d: Inputs) -> List[str]:
        """
        Returns a list of reasons why this is not supported.
        The kernel can run these inputs only if the returned list is empty
        """
        # In the quantized (int32) K/V case the last dim of `value` is a packed
        # size, not an embedding size, so the query's last dim is used instead
        # (this mirrors how `Inputs.normalize_bmhk` derives the output shape).
        # Otherwise use the real value embedding so that the
        # SUPPORTS_DIFFERENT_VALUE_EMBED check can actually trigger.
        if d.value.dtype == torch.int32:
            value_embed = d.query.shape[-1]
        else:
            value_embed = d.value.shape[-1]
        reasons = cls.shape_not_supported_reasons(
            Mq=d.query.shape[1],
            Mkv=d.key.shape[1],
            K=d.query.shape[-1],
            Kv=value_embed,
        )
        device_type = d.query.device.type
        dtype = d.query.dtype
        if device_type not in cls.SUPPORTED_DEVICES:
            reasons.append(f"device={device_type} (supported: {cls.SUPPORTED_DEVICES})")
        if device_type == "cuda" and not _built_with_cuda:
            reasons.append("xFormers wasn't built with CUDA support")
        if device_type == "cuda":
            device_capability = torch.cuda.get_device_capability(d.device)
            if device_capability < cls.CUDA_MINIMUM_COMPUTE_CAPABILITY:
                reasons.append(
                    f"requires device with capability >= {cls.CUDA_MINIMUM_COMPUTE_CAPABILITY} "
                    f"but your GPU has capability {device_capability} (too old)"
                )
        if dtype not in cls.SUPPORTED_DTYPES:
            reasons.append(f"dtype={dtype} (supported: {cls.SUPPORTED_DTYPES})")
        if type(d.attn_bias) not in cls.SUPPORTED_ATTN_BIAS_TYPES:
            reasons.append(f"attn_bias type is {type(d.attn_bias)}")
        if (d.p != 0.0) and not cls.SUPPORTS_DROPOUT:
            reasons.append("dropout > 0.0")
        if d.scale is not None and not cls.SUPPORTS_CUSTOM_SCALE:
            reasons.append("has custom scale")
        # bfloat16 is only supported on A100+
        # ... although the kernels can still run and give the
        # correct result
        # NOTE(review): only the device *type* is checked here, not the
        # compute capability, so any CUDA device passes this test.
        if dtype is torch.bfloat16 and (
            not device_type.startswith("cuda")
        ):
            reasons.append("bf16 is only supported on A100+ GPUs")
        if not cls.is_available():
            reasons.append(
                "operator wasn't built - see `python -m xformers.info` for more info"
            )
        if not cls.IS_DETERMINISTIC and torch.are_deterministic_algorithms_enabled():
            reasons.append(
                "operator is non-deterministic, but `torch.use_deterministic_algorithms` is set"
            )
        if not cls.SUPPORTS_BMGHK and d.query.ndim == 5:
            reasons.append("operator does not support BMGHK format")
        return reasons
class AttentionFwOpBase(AttentionOpBase):
    """Base class for forward attention operators."""

    # Absolute / relative error tolerances, keyed by dtype.
    ERROR_ATOL: Mapping[torch.dtype, float] = {
        torch.float: 3e-4,
        torch.half: 4e-3,
        torch.bfloat16: 2e-2,
    }
    ERROR_RTOL: Mapping[torch.dtype, float] = {
        torch.float: 2e-5,
        torch.half: 4e-4,
        torch.bfloat16: 5e-3,
    }

    @classmethod
    def apply(
        cls, inp: Inputs, needs_gradient: bool
    ) -> Tuple[torch.Tensor, Optional[Context]]:
        """Run the forward pass; must be overridden by concrete operators."""
        raise NotImplementedError()

    @classmethod
    def attn_operator_flop(
        cls,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        causal: bool = False,
        seqstart_k: Optional[torch.Tensor] = None,
        seqstart_q: Optional[torch.Tensor] = None,
    ) -> int:
        """
        Computes total flops for the attention
        Assumes inputs in format BMHK

        When ``seqstart_q`` / ``seqstart_k`` are given, the count is summed
        over the blocks they delimit; otherwise a single block spanning the
        whole sequence is assumed. The total is halved when ``causal``.
        """
        assert query.ndim == 4

        if seqstart_q is not None:
            seqstart_q_py = seqstart_q.tolist()
        else:
            seqstart_q_py = [0, query.shape[1]]
        if seqstart_k is not None:
            seqstart_k_py = seqstart_k.tolist()
        else:
            seqstart_k_py = [0, key.shape[1]]

        total_flop = 0
        for q_start, q_end, k_start, k_end in zip(
            seqstart_q_py, seqstart_q_py[1:], seqstart_k_py, seqstart_k_py[1:]
        ):
            num_q = q_end - q_start
            num_kv = k_end - k_start
            # (M,K) @ (K,N) GEMM needs M*N*K*2 flop
            # Q @ K.transpose
            total_flop += num_q * num_kv * query.shape[-1] * 2
            # (ignore softmax)
            # attn @ V: the inner product runs over the keys and the output
            # width is the *value* embedding size.
            total_flop += num_q * value.shape[-1] * num_kv * 2
        # Multiply by num_heads and batches
        total_flop = total_flop * value.shape[2] * value.shape[0]
        if causal:
            total_flop //= 2
        return total_flop
class AttentionBwOpBase(AttentionOpBase):
    """Base class for backward attention operators."""

    # Absolute / relative error tolerances, keyed by dtype.
    ERROR_ATOL: Mapping[torch.dtype, float] = {
        torch.float: 5e-4,
        torch.half: 9e-2,
        torch.bfloat16: 0.7,
    }
    ERROR_RTOL: Mapping[torch.dtype, float] = {
        torch.float: 1e-4,
        torch.half: 2e-2,
        torch.bfloat16: 0.1,
    }
    SUPPORTS_ATTN_BIAS_GRAD = False

    @classmethod
    def not_supported_reasons(cls, d: Inputs) -> List[str]:
        """Extend the base reasons with the tensor-bias gradient restriction."""
        reasons = super(AttentionBwOpBase, cls).not_supported_reasons(d)
        bias = d.attn_bias
        wants_bias_grad = isinstance(bias, torch.Tensor) and bias.requires_grad
        if wants_bias_grad and not cls.SUPPORTS_ATTN_BIAS_GRAD:
            reasons.append(
                "Computing the bias gradient is not supported (attn_bias.requires_grad = True)"
            )
        return reasons

    @classmethod
    def apply(cls, ctx: Context, inp: Inputs, grad: torch.Tensor) -> Gradients:
        """Run the backward pass; must be overridden by concrete operators."""
        raise NotImplementedError()

    @classmethod
    def attn_operator_flop(
        cls,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        causal: bool = False,
        seqstart_k: Optional[torch.Tensor] = None,
        seqstart_q: Optional[torch.Tensor] = None,
    ) -> int:
        """
        Computes total flops for the attention
        Assumes inputs in format BMHK
        """
        assert query.ndim == 4

        q_bounds = [0, query.shape[1]] if seqstart_q is None else seqstart_q.tolist()
        k_bounds = [0, key.shape[1]] if seqstart_k is None else seqstart_k.tolist()

        flop_sum = 0
        blocks = zip(q_bounds, q_bounds[1:], k_bounds, k_bounds[1:])
        for q_begin, q_stop, k_begin, k_stop in blocks:
            n_q = q_stop - q_begin
            n_kv = k_stop - k_begin
            embed_qk = query.shape[-1]
            embed_v = value.shape[-1]
            # (M,K) @ (K,N) GEMM needs M*N*K*2 flop. Five GEMMs per block:
            #   att = Q @ K.transpose   -> embed_qk
            #   att @ dO                -> embed_v
            #   dov = dO @ V            -> embed_v
            #   dov @ K                 -> embed_qk
            #   dov @ Q                 -> embed_qk
            for inner in (embed_qk, embed_v, embed_v, embed_qk, embed_qk):
                flop_sum += n_q * n_kv * inner * 2
        # Multiply by num_heads and batches
        flop_sum = flop_sum * value.shape[2] * value.shape[0]
        return flop_sum // 2 if causal else flop_sum
# Type alias: a (forward operator class, backward operator class) pair.
# Either entry may be None.
AttentionOp = Tuple[
    Optional[Type[AttentionFwOpBase]], Optional[Type[AttentionBwOpBase]]
]
@dataclass
class AttentionOpDispatch:
    """Dispatcher to automatically select
    the best operator to run memory-efficient attention.

    :Deprecated:

        This class is deprecated and will be removed in a later version
    """

    op: AttentionOp

    @classmethod
    def from_arguments(
        cls,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        attn_bias: Optional[Union[torch.Tensor, AttentionBias]] = None,
        p: float = 0.0,
        scale: Optional[float] = None,
    ) -> "AttentionOpDispatch":
        """Here for backward compatibility"""
        # Imported lazily, inside the function, as in the original code.
        from .dispatch import _dispatch_bw, _dispatch_fw

        inputs = Inputs(
            query=query,
            key=key,
            value=value,
            attn_bias=attn_bias,
            p=p,
            scale=scale,
        )
        forward_op = _dispatch_fw(inputs, True)
        backward_op = _dispatch_bw(inputs)
        return AttentionOpDispatch(op=(forward_op, backward_op))
def bmk2bmhk(tensor, num_heads: int) -> torch.Tensor:
    """Convert a 3-D ``[B * num_heads, M, K]`` tensor to 4-D ``[B, M, num_heads, K]``.

    4-D inputs are returned unchanged.
    """
    if tensor.ndim == 4:
        return tensor
    seqlen, embed = tensor.shape[1], tensor.shape[2]
    # Un-fold the heads from the leading dim, then move them after the sequence dim.
    per_head = tensor.reshape([-1, num_heads, seqlen, embed])
    return per_head.permute((0, 2, 1, 3))
def check_lastdim_alignment_stride1(
    reasons: List[str], name: str, x: torch.Tensor, alignment: int
) -> None:
    """Append to ``reasons`` any alignment / contiguity problem found on ``x``.

    Checks that ``x.shape[-1]`` (or, failing that check being relevant,
    ``x.stride(-2)``) is a multiple of ``alignment`` and that the last
    dimension has stride at most 1.
    """
    last_dim_misaligned = x.shape[-1] % alignment != 0
    if last_dim_misaligned:
        reasons.append(f"{name}.shape[-1] % {alignment} != 0")
    else:
        row_stride_misaligned = x.stride(-2) % alignment != 0
        if row_stride_misaligned:
            reasons.append(
                f"{name}.stride(-2) % {alignment} != 0 ({name}.stride() = {x.stride()})"
            )
    # We can have stride=0 sometimes if dimension=1
    if x.stride(-1) <= 1:
        return
    reasons.append(
        f"{name}.stride(-1) > 1 ({name}.stride() = {x.stride()}) - you should call `.contiguous()` on the input"
    )

View File

@@ -0,0 +1,479 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
from dataclasses import replace
from enum import Enum
from typing import Any, List, Mapping, Optional, Set, Tuple, Union
import torch
from ..common import get_xformers_operator, register_operator
from . import attn_bias
from .attn_bias import (
AttentionBias,
BlockDiagonalCausalLocalAttentionFromBottomRightMask,
BlockDiagonalCausalLocalAttentionMask,
BlockDiagonalCausalMask,
BlockDiagonalCausalWithOffsetPaddedKeysMask,
BlockDiagonalMask,
LowerTriangularMask,
LowerTriangularMaskWithTensorBias,
)
from .common import (
AttentionBwOpBase,
AttentionFwOpBase,
Context,
Gradients,
Inputs,
check_lastdim_alignment_stride1,
)
def _uses_tensorcores(sm: int, is_half: bool) -> bool:
    """Return whether tensor cores are used for compute capability ``sm``.

    ``sm >= 80``: always; ``70 <= sm < 80``: only when ``is_half``; below 70: never.
    """
    if sm < 70:
        return False
    if sm < 80:
        return is_half
    return True
def _minimum_gemm_alignment(inp: Inputs) -> int:
    """Return the minimum alignment (in elements) required on the GEMM dimensions.

    Non-CUDA devices need no alignment (1). On CUDA the requirement depends on
    the compute capability and on the width of the query dtype.
    """
    if inp.device.type != "cuda":
        return 1
    cap = torch.cuda.get_device_capability(inp.device)
    sm = cap[0] * 10 + cap[1]
    known_bits = {torch.float: 32, torch.half: 16, torch.bfloat16: 16}
    bits_per_scalar = known_bits.get(inp.query.dtype)
    if bits_per_scalar is None:
        # Unsupported dtypes are reported by the operator's dtype check; fall
        # back to the storage width here so that computing the list of
        # "not supported" reasons does not crash with a KeyError.
        bits_per_scalar = inp.query.element_size() * 8
    uses_tensorcores = _uses_tensorcores(sm, bits_per_scalar == 16)
    matmul_alignment_mn = 1
    if sm >= 80:
        matmul_alignment_mn = 4
    if uses_tensorcores:
        # Tensor-core paths need 128-bit aligned rows.
        matmul_alignment_mn = max(matmul_alignment_mn, 128 // bits_per_scalar)
    return matmul_alignment_mn
def _get_seqlen_info(
    inp: Inputs,
) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor], int, int]:
    """Extract ``(seqstart_k, seqstart_q, max_seqlen_q, max_seqlen_k)`` from the bias.

    For block-diagonal biases the sequence infos are first moved to the query
    device; for any other bias ``(None, None, -1, -1)`` is returned.
    """
    bias = inp.attn_bias
    block_diagonal_types = (
        BlockDiagonalMask,
        BlockDiagonalCausalWithOffsetPaddedKeysMask,
    )
    if not isinstance(bias, block_diagonal_types):
        return None, None, -1, -1

    k_info, q_info = bias.k_seqinfo, bias.q_seqinfo
    k_info.to(inp.query.device)
    q_info.to(inp.query.device)
    return k_info.seqstart, q_info.seqstart, q_info.max_seqlen, k_info.max_seqlen
def _get_tensor_bias(
    attn_bias: Optional[Union[torch.Tensor, AttentionBias]]
) -> Optional[torch.Tensor]:
    """Return the tensor carried by ``attn_bias``, or ``None`` if it carries none."""
    if isinstance(attn_bias, LowerTriangularMaskWithTensorBias) and not isinstance(
        attn_bias, torch.Tensor
    ):
        return attn_bias._bias
    # Plain tensors are returned as-is; everything else has no tensor bias.
    return attn_bias if isinstance(attn_bias, torch.Tensor) else None
def _check_bias_alignment(
    reasons: List[str], attn_bias: Optional[Union[torch.Tensor, AttentionBias]]
) -> None:
    """Append to ``reasons`` any stride problem found on the bias tensor, if any.

    Every stride except the last must be a multiple of ``128 // bits`` of the
    bias dtype, and the last stride must be at most 1. Nothing is checked when
    ``attn_bias`` carries no tensor.
    """
    attn_bias_tensor = _get_tensor_bias(attn_bias)
    if attn_bias_tensor is None:
        return

    alignment = 128 // torch.finfo(attn_bias_tensor.dtype).bits
    show_padding_hint = False
    for d in range(attn_bias_tensor.ndim - 1):
        if attn_bias_tensor.stride(d) % alignment != 0:
            # Name the dimension that actually failed rather than always
            # blaming stride(-2).
            reasons.append(
                f"attn_bias.stride({d}) % {alignment} != 0 (attn_bias.stride() = {attn_bias_tensor.stride()})"
            )
            show_padding_hint = True
    if show_padding_hint:
        reasons.append(
            "HINT: To use an `attn_bias` with a sequence length that is not a multiple of 8, "
            "you need to ensure memory is aligned by slicing a bigger tensor. "
            "Example: use `attn_bias = torch.zeros([1, 1, 5, 8])[:,:,:,:5]` instead of `torch.zeros([1, 1, 5, 5])`"
        )
    # We can have stride=0 sometimes if dimension=1
    if attn_bias_tensor.stride(-1) > 1:
        reasons.append(
            f"attn_bias.stride(-1) > 1 (attn_bias.stride() = {attn_bias_tensor.stride()}) - "
            "you should call `.contiguous()` on the bias"
        )
class _CustomMaskType(int, Enum):
    """
    (Matches CustomMaskType in C++.)
    """

    # Integer codes returned by `_custom_mask_type`.
    NoCustomMask = 0
    CausalFromTopLeft = 1
    CausalFromBottomRight = 2
def _custom_mask_type(bias: Optional[Union[torch.Tensor, AttentionBias]]) -> int:
    """Map an attention bias to the integer value of its `_CustomMaskType`."""
    top_left_types = (
        LowerTriangularMask,
        BlockDiagonalCausalMask,
        BlockDiagonalCausalLocalAttentionMask,
    )
    if isinstance(bias, top_left_types):
        kind = _CustomMaskType.CausalFromTopLeft
    elif isinstance(
        bias,
        (
            attn_bias.BlockDiagonalCausalFromBottomRightMask,
            BlockDiagonalCausalWithOffsetPaddedKeysMask,
            BlockDiagonalCausalLocalAttentionFromBottomRightMask,
        ),
    ):
        kind = _CustomMaskType.CausalFromBottomRight
    else:
        kind = _CustomMaskType.NoCustomMask
    return int(kind)
@register_operator
class FwOp(AttentionFwOpBase):
"""xFormers' MHA kernel based on CUTLASS.
Supports a large number of settings (including without TensorCores, f32 ...)
and GPUs as old as P100 (Sm60)
"""
OPERATOR = get_xformers_operator("efficient_attention_forward_cutlass")
SUPPORTED_DEVICES: Set[str] = {"cuda"}
SUPPORTED_DTYPES: Set[torch.dtype] = {torch.float, torch.half, torch.bfloat16}
SUPPORTED_MAX_K = 65536
SUPPORTED_ATTN_BIAS_TYPES: Set[Any] = {
type(None),
torch.Tensor,
LowerTriangularMask,
LowerTriangularMaskWithTensorBias,
BlockDiagonalMask,
BlockDiagonalCausalMask,
BlockDiagonalCausalWithOffsetPaddedKeysMask,
attn_bias.BlockDiagonalCausalFromBottomRightMask,
attn_bias.BlockDiagonalCausalLocalAttentionMask,
BlockDiagonalCausalLocalAttentionFromBottomRightMask,
}
SUPPORTS_DROPOUT = True
SUPPORTS_CUSTOM_SCALE = True
SUPPORTS_DIFFERENT_VALUE_EMBED = True
SUPPORTS_BMGHK = True
NAME = "cutlassF"
_TEST_K: List[int] = [
32, # 64x64 kernel
128, # 64x128 kernel
256, # 64x128 with accumulation in gmem
]
@classmethod
def apply(
    cls, inp: Inputs, needs_gradient: bool
) -> Tuple[torch.Tensor, Optional[Context]]:
    """Run the forward pass, dispatching on the rank of ``inp.query``.

    3-D and 4-D inputs go straight to `apply_bmhk`. 5-D inputs are processed
    one slice of dim 2 at a time, each on its own CUDA stream, and the
    per-slice results are stacked back together.

    Raises:
        NotImplementedError: if the bias type is not supported by this operator.
    """
    if type(inp.attn_bias) not in FwOp.SUPPORTED_ATTN_BIAS_TYPES:
        raise NotImplementedError("Unsupported attn_bias type")
    if inp.query.ndim in [3, 4]:
        return cls.apply_bmhk(inp, needs_gradient=needs_gradient)
    assert inp.query.ndim == 5, f"query has shape {inp.query.shape}"

    # Workaround until this is properly implemented in C++
    # run each head group in a different stream
    n_groups = inp.key.shape[2]
    main_stream = torch.cuda.current_stream()
    # The first group reuses the current stream; the others get fresh streams.
    streams = [main_stream] + [
        torch.cuda.Stream(device=inp.query.device) for _ in range(n_groups - 1)
    ]
    outs = []
    for group, stream in enumerate(streams):
        # Make the side stream wait for work already queued on the main stream.
        stream.wait_stream(main_stream)
        with torch.cuda.stream(stream):
            query = inp.query[:, :, group]
            key = inp.key[:, :, group]
            value = inp.value[:, :, group]
            bias = inp.attn_bias
            # Tensor-carrying biases are sliced along dim 1 for this group.
            if isinstance(bias, torch.Tensor):
                bias = bias[:, group]
            if isinstance(bias, attn_bias.LowerTriangularMaskWithTensorBias):
                bias = attn_bias.LowerTriangularMaskWithTensorBias(
                    bias._bias[:, group]
                )
            outs.append(
                cls.apply_bmhk(
                    replace(inp, query=query, key=key, value=value, attn_bias=bias),
                    needs_gradient=needs_gradient,
                )
            )
    # Join: the main stream waits for every side stream before stacking.
    for s in streams[1:]:
        main_stream.wait_stream(s)
    out = torch.stack([o[0] for o in outs], dim=2)
    ctx: Optional[Context] = None
    if needs_gradient:
        # NOTE(review): only `lse` and `op_bw` are carried over from the
        # per-group contexts; `rng_state` is not — confirm this is intended
        # when dropout is used.
        ctx = Context(
            out=out,
            lse=torch.stack([o[1].lse for o in outs], dim=1),  # type: ignore
            op_bw=outs[0][1].op_bw,  # type: ignore
        )
    return out, ctx
@classmethod
def apply_bmhk(
    cls, inp: Inputs, needs_gradient: bool
) -> Tuple[torch.Tensor, Optional[Context]]:
    """Invoke the forward operator once on ``inp``.

    Returns:
        ``(out, ctx)`` where ``ctx`` is ``None`` unless ``needs_gradient``.

    Raises:
        NotImplementedError: if the exact type of ``inp.attn_bias`` is not in
            ``FwOp.SUPPORTED_ATTN_BIAS_TYPES``.
    """
    bias = inp.attn_bias
    if type(bias) not in FwOp.SUPPORTED_ATTN_BIAS_TYPES:
        raise NotImplementedError("Unsupported attn_bias type")
    seqstart_k, seqstart_q, max_seqlen_q, _ = _get_seqlen_info(inp)

    # Per-sequence key lengths are only forwarded for padded-keys masks.
    padded_seqlen_k = None
    if isinstance(bias, BlockDiagonalCausalWithOffsetPaddedKeysMask):
        padded_seqlen_k = bias.k_seqinfo.seqlen

    # The window size is only forwarded for local-attention masks.
    local_window = None
    if isinstance(
        bias,
        (
            BlockDiagonalCausalLocalAttentionMask,
            BlockDiagonalCausalLocalAttentionFromBottomRightMask,
        ),
    ):
        local_window = bias._window_size

    out, lse, rng_seed, rng_offset = cls.OPERATOR(
        query=inp.query,
        key=inp.key,
        value=inp.value,
        attn_bias=_get_tensor_bias(bias),
        seqstart_q=seqstart_q,
        seqstart_k=seqstart_k,
        max_seqlen_q=max_seqlen_q,
        dropout_p=inp.p,
        compute_logsumexp=needs_gradient,
        custom_mask_type=_custom_mask_type(bias),
        scale=inp.scale,
        seqlen_k=padded_seqlen_k,
        window_size=local_window,
    )
    if not needs_gradient:
        return out, None

    with_dropout = inp.p != 0
    ctx = Context(
        out=out,
        lse=lse,
        # cutlass forward is only compatible with cutlass backward if
        # dropout is used (because of the way RNG states are passed and the
        # way random numbers are generated during backward)
        op_bw=BwOp if with_dropout else None,
    )
    if with_dropout:
        # Seed/offset are stored on CPU so backward can replay the dropout.
        ctx.rng_state = torch.tensor(
            [rng_seed, rng_offset], dtype=torch.int64, device="cpu"
        )
    return out, ctx
@classmethod
def not_supported_reasons(cls, d: Inputs) -> List[str]:
    """Return the reasons this operator cannot run on ``d`` (empty if it can).

    Extends the base-class checks with last-dimension alignment checks on
    query and value, plus the bias alignment check.
    """
    reasons = super(FwOp, cls).not_supported_reasons(d)
    alignment = _minimum_gemm_alignment(d)
    for tensor_name, tensor in (("query", d.query), ("value", d.value)):
        check_lastdim_alignment_stride1(reasons, tensor_name, tensor, alignment)
    _check_bias_alignment(reasons, d.attn_bias)
    return reasons
@classmethod
# type: ignore
def operator_flop(
    cls,
    q,
    k,
    v,
    b,
    seqstart_q,
    seqstart_k,
    max_seqlen_q_,
    compute_lse,
    custom_mask_type,
    *a,
) -> int:
    """Delegate the FLOP estimate for one operator call to ``attn_operator_flop``.

    Only ``q``, ``k``, ``v``, the two ``seqstart`` values and
    ``custom_mask_type`` are used; any trailing arguments are ignored.
    """
    # Any strictly positive mask type is counted as causal.
    is_causal = custom_mask_type > 0
    return cls.attn_operator_flop(
        q, k, v, causal=is_causal, seqstart_k=seqstart_k, seqstart_q=seqstart_q
    )
@register_operator
class BwOp(AttentionBwOpBase):
    __doc__ = FwOp.__doc__

    # Backward operator paired with the cutlass forward operator; device,
    # dtype and head-dim limits are inherited from FwOp.
    OPERATOR = get_xformers_operator("efficient_attention_backward_cutlass")
    SUPPORTED_DEVICES = FwOp.SUPPORTED_DEVICES
    SUPPORTED_DTYPES = FwOp.SUPPORTED_DTYPES
    SUPPORTED_MAX_K = FwOp.SUPPORTED_MAX_K
    SUPPORTED_ATTN_BIAS_TYPES: Set[Any] = {
        type(None),
        torch.Tensor,
        LowerTriangularMask,
        # TODO: Fix handling of gradient through the fMHA autograd function
        # LowerTriangularMaskWithTensorBias,
        BlockDiagonalMask,
        BlockDiagonalCausalMask,
        attn_bias.BlockDiagonalCausalFromBottomRightMask,
        attn_bias.BlockDiagonalCausalLocalAttentionMask,
    }
    SUPPORTS_ATTN_BIAS_GRAD = True
    SUPPORTS_DROPOUT = FwOp.SUPPORTS_DROPOUT
    SUPPORTS_CUSTOM_SCALE = FwOp.SUPPORTS_CUSTOM_SCALE
    SUPPORTS_DIFFERENT_VALUE_EMBED = FwOp.SUPPORTS_DIFFERENT_VALUE_EMBED
    NAME = "cutlassB"

    ERROR_ATOL: Mapping[torch.dtype, float] = {
        torch.float: 5e-4,
        # increased from 9e-2, more opportunities for numerical errors when bias is
        # used, noticed in gK on SM80
        torch.half: 1e-1,
        torch.bfloat16: 7e-1,
    }

    _TEST_K: List[int] = [
        32,  # 64x64 kernel
        128,  # 64x128/128x128 kernel
        256,  # 64x128 with accumulation in gmem
    ]

    @classmethod
    def not_supported_reasons(cls, d: Inputs) -> List[str]:
        """Return the reasons this operator cannot run on ``d`` (empty if it can).

        On top of the base-class checks this verifies last-dimension
        alignment of query/key/value, bias alignment, and that a bias which
        requires a gradient is not broadcast.
        """
        reasons = super(BwOp, cls).not_supported_reasons(d)
        matmul_alignment_mn = _minimum_gemm_alignment(d)

        check_lastdim_alignment_stride1(reasons, "query", d.query, matmul_alignment_mn)
        check_lastdim_alignment_stride1(reasons, "key", d.key, matmul_alignment_mn)
        check_lastdim_alignment_stride1(reasons, "value", d.value, matmul_alignment_mn)
        _check_bias_alignment(reasons, d.attn_bias)
        attn_bias_tensor = _get_tensor_bias(d.attn_bias)

        # Backprop of gradient through broadcasted bias is not supported
        if attn_bias_tensor is not None and attn_bias_tensor.requires_grad:
            # Don't forget that inputs are either in BMK or BMHK!
            if d.query.ndim == 3 and attn_bias_tensor.ndim == 3:
                expected_bias_shape = (*d.query.shape[:2], d.key.shape[1])
            else:
                # bias is B H Mq Mk
                expected_bias_shape = (
                    d.query.shape[0],
                    d.query.shape[2] if d.query.ndim == 4 else 1,
                    d.query.shape[1],
                    d.key.shape[1],
                )
            if tuple(attn_bias_tensor.shape) != expected_bias_shape:
                reasons.append(
                    "Broadcasting the `attn_bias` tensor is not supported "
                    f"(shape: {tuple(attn_bias_tensor.shape)}"
                    f"/ expected: {expected_bias_shape})"
                )
        return reasons

    @classmethod
    def apply(cls, ctx: Context, inp: Inputs, grad: torch.Tensor) -> Gradients:
        """Compute the gradients of query, key, value and (optionally) the bias.

        Raises:
            NotImplementedError: if the exact type of ``inp.attn_bias`` is not
                supported, or if dropout is enabled and ``ctx.rng_state`` is
                not a CPU int64 tensor of shape ``(2,)``.
        """
        if type(inp.attn_bias) not in BwOp.SUPPORTED_ATTN_BIAS_TYPES:
            raise NotImplementedError("Unsupported attn_bias type")

        seqstart_k, seqstart_q, max_seqlen_q, max_seqlen_k = _get_seqlen_info(inp)
        dtype = inp.query.dtype

        rng_seed = rng_offset = 0
        if inp.p != 0.0:
            if (
                ctx.rng_state is None
                or ctx.rng_state.dtype != torch.int64
                or ctx.rng_state.device.type != "cpu"
                or ctx.rng_state.shape != (2,)
            ):
                raise NotImplementedError(f"Invalid rng_state: {ctx.rng_state}")
            rng_seed, rng_offset = ctx.rng_state.tolist()

        # Only devices reporting compute capability (7, 5) request inf-padding
        # of the LSE.
        force_pad_inf = torch.cuda.get_device_capability(inp.query.device) == (7, 5)
        (grad_q, grad_k, grad_v, grad_bias) = cls.OPERATOR(
            grad.to(dtype),
            inp.query,
            inp.key,
            inp.value,
            _get_tensor_bias(inp.attn_bias),
            cu_seqlens_q=seqstart_q,
            cu_seqlens_k=seqstart_k,
            max_seqlen_q=max_seqlen_q,
            max_seqlen_k=max_seqlen_k,
            logsumexp=ctx.get_padded_lse(32, force_pad_inf=force_pad_inf),
            output=ctx.out.to(dtype),
            dropout_p=inp.p,
            # if not using dropout, seed and offset are irrelevant but still expected
            # in function signature so just pass 0
            # seed and offset could be None if a different FW op other than cutlass
            # was used.
            rng_seed=rng_seed,
            rng_offset=rng_offset,
            custom_mask_type=_custom_mask_type(inp.attn_bias),
            scale=inp.scale,
            num_splits_key=-1,  # Let C++ determine it
            window_size=inp.attn_bias._window_size
            if isinstance(
                inp.attn_bias,
                (
                    BlockDiagonalCausalLocalAttentionMask,
                    BlockDiagonalCausalLocalAttentionFromBottomRightMask,
                ),
            )
            else None,
        )

        # c++/CUDA implementation returns an uninitialized tensor if bias doesn't
        # require grad
        if not (
            isinstance(inp.attn_bias, torch.Tensor) and inp.attn_bias.requires_grad
        ):
            grad_bias = None

        return Gradients(dq=grad_q, dk=grad_k, dv=grad_v, db=grad_bias)

    @classmethod
    # type: ignore
    def operator_flop(
        cls,
        dO,
        q,
        k,
        v,
        b,
        cu_seqlens_q,
        cu_seqlens_k,
        max_seqlen_q,
        max_seqlen_k,
        logsumexp,
        output,
        dropout_p,
        rng_seed,
        rng_offset,
        custom_mask_type,
        scale,
        *a,
        **kw,
    ) -> int:
        """Delegate the FLOP estimate for one operator call to ``attn_operator_flop``.

        ``apply`` invokes the operator with two arguments beyond ``scale``
        (``num_splits_key`` and ``window_size``). They are absorbed by
        ``*a``/``**kw`` and ignored, so forwarding the operator's full
        argument list here does not raise ``TypeError``.
        """
        return cls.attn_operator_flop(
            q,
            k,
            v,
            seqstart_q=cu_seqlens_q,
            seqstart_k=cu_seqlens_k,
            causal=custom_mask_type > 0,
        )

View File

@@ -0,0 +1,108 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
from typing import Any, List, Optional, Set, Tuple
import numpy as np
import torch
from ..common import get_xformers_operator, register_operator
from .attn_bias import BlockDiagonalCausalWithOffsetPaddedKeysMask
from .common import AttentionFwOpBase, Context, Inputs
@register_operator
class FwOp(AttentionFwOpBase):
    """Forward-only attention operator for single-query decoding.

    Only :class:`BlockDiagonalCausalWithOffsetPaddedKeysMask` biases are
    accepted. As enforced by :meth:`not_supported_reasons`, inputs must be
    BMHK with one formal batch element, ``head_dim == 128``, one query per
    sequence, keys and values contiguous in their last dimension, and a key
    padding of at most 8192. Dropout is not supported, and :meth:`apply`
    raises if a gradient is requested.
    """

    # NOTE(review): the previous docstring was a copy of the small-K
    # operator's (``K <= 32``, BMK inputs, "deprecated"), which contradicts
    # the checks below. Confirm whether this operator is actually deprecated.
    OPERATOR = get_xformers_operator("efficient_attention_forward_decoder")
    SUPPORTED_DEVICES = {"cuda"}
    SUPPORTED_DTYPES = {torch.bfloat16, torch.half, torch.float32}
    CUDA_MINIMUM_COMPUTE_CAPABILITY = (7, 0)
    SUPPORTED_MAX_K: float = 128
    SUPPORTED_ATTN_BIAS_TYPES: Set[Any] = {BlockDiagonalCausalWithOffsetPaddedKeysMask}
    SUPPORTS_DROPOUT = False
    SUPPORTS_CUSTOM_SCALE = True
    NAME = "decoderF"

    @classmethod
    def not_supported_reasons(cls, d: Inputs) -> List[str]:
        """Return the reasons this operator cannot run on ``d`` (empty if it can)."""
        reasons = super(FwOp, cls).not_supported_reasons(d)

        attn_bias = d.attn_bias
        if isinstance(attn_bias, BlockDiagonalCausalWithOffsetPaddedKeysMask):
            # If we don't get here, we've an error elsewhere
            if d.query.ndim != 4 or d.key.ndim != 4:
                reasons.append("Inputs must be BMHK. BMK not supported")

            if d.query.shape[0] != 1:
                reasons.append("One formal batch element expected")

            if d.query.shape[-1] != 128:
                reasons.append("Only head_dim==128 for now.")

            if d.key.stride(-1) != 1:
                reasons.append("expect keys to have last dim contiguous")

            if d.value.stride(-1) != 1:
                reasons.append("expect values to have last dim contiguous")

            q_starts = attn_bias.q_seqinfo.seqstart_py
            if attn_bias.q_seqinfo.max_seqlen != 1:
                reasons.append("decoding expects one query")
            elif d.query.shape[1] != len(q_starts) - 1:
                reasons.append("empty lanes not supported yet")

            if attn_bias.k_seqinfo.padding > 8192:
                reasons.append("key padding exceeds 8192")

        return reasons

    @classmethod
    def apply(
        cls, inp: Inputs, needs_gradient: bool
    ) -> Tuple[torch.Tensor, Optional[Context]]:
        """Run the decoder kernel and return ``(out, None)``.

        Raises:
            NotImplementedError: if ``needs_gradient`` is true.
        """
        if needs_gradient:
            raise NotImplementedError("gradient")
        attn_bias = inp.attn_bias
        assert isinstance(attn_bias, BlockDiagonalCausalWithOffsetPaddedKeysMask)

        attn_bias.k_seqinfo.to(inp.query.device)
        attn_bias.q_seqinfo.to(inp.query.device)

        padding = attn_bias.k_seqinfo.padding
        # A zero stride over dim 2 means the key heads are broadcast; in that
        # case only the first head is passed to the kernel.
        multiquery = inp.key.stride(2) == 0
        if multiquery:
            key = inp.key[0, :, :1].unflatten(0, (-1, padding))
            value = inp.value[0, :, :1].unflatten(0, (-1, padding))
        else:
            key = inp.key[0].unflatten(0, (-1, padding))
            value = inp.value[0].unflatten(0, (-1, padding))

        seq_positions = attn_bias.k_seqinfo.seqlen

        # Drop the formal batch dimension and insert a length-1 axis at dim 1.
        query = inp.query[0, :, None]

        # Default to 1/sqrt(head_dim) when no explicit scale is given.
        if inp.scale is not None:
            qk_scale = inp.scale
        else:
            qk_scale = 1.0 / np.sqrt(key.shape[-1])

        out = cls.OPERATOR(
            query=query,
            key=key,
            value=value,
            seq_positions=seq_positions,
            scale=qk_scale,
        )
        return out, None

View File

@@ -0,0 +1,147 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
import textwrap
from collections import deque
from typing import List, Sequence, Type, TypeVar
from . import attn_bias, cutlass, decoder, flash, small_k, triton, triton_splitk
from .common import AttentionBwOpBase, AttentionFwOpBase, Inputs
def _is_cutlass_fwd_faster_than_flash(inp: Inputs) -> bool:
# Stub heuristic: `inp` is not inspected and the answer is always False.
return False
def _is_triton_fwd_fastest(inp: Inputs) -> bool:
# Stub heuristic: `inp` is not inspected and the answer is always False.
# TODO: fill out
return False
# Constrained type variable: either a forward-op class or a backward-op class.
T = TypeVar("T", Type[AttentionFwOpBase], Type[AttentionBwOpBase])
def _format_inputs_description(inp: Inputs) -> str:
# Build a multi-line, human-readable summary of `inp` for error messages:
# shape and dtype of query/key/value, the type of `attn_bias`, and `p`.
return f"""query : shape={tuple(inp.query.shape)} ({inp.query.dtype})
key : shape={tuple(inp.key.shape)} ({inp.key.dtype})
value : shape={tuple(inp.value.shape)} ({inp.value.dtype})
attn_bias : {type(inp.attn_bias)}
p : {inp.p}"""
def _ensure_op_supports_or_raise(exc_type, name: str, op, inp: Inputs) -> None:
# Return silently when `op.not_supported_reasons(inp)` is empty; otherwise
# raise `exc_type` with a message listing the inputs and every reason.
# `name` is the label used for the operator in the first line of the message.
reasons = op.not_supported_reasons(inp)
if not reasons:
return
raise exc_type(
f"""Operator `{name}` does not support inputs:
{textwrap.indent(_format_inputs_description(inp), ' ')}
{_format_not_supported_reasons(op, reasons)}"""
)
def _format_not_supported_reasons(op, reasons: List[str]) -> str:
return f"`{op.NAME}` is not supported because:\n " + "\n ".join(reasons)
def _run_priority_list(name: str, priority_list: Sequence[T], inp: Inputs) -> T:
not_supported_reasons: List[List[str]] = []
for op in priority_list:
not_supported = op.not_supported_reasons(inp)
if not not_supported:
return op
not_supported_reasons.append(not_supported)
# Let's write a nice message explaining what we tried and why it's not supported
msg = f"""No operator found for `{name}` with inputs:
{textwrap.indent(_format_inputs_description(inp), ' ')}"""
for op, not_supported in zip(priority_list, not_supported_reasons):
msg += "\n" + _format_not_supported_reasons(op, not_supported)
raise NotImplementedError(msg)
def _dispatch_fw_priority_list(
inp: Inputs, needs_gradient: bool
) -> Sequence[Type[AttentionFwOpBase]]:
# Build the ordered list of forward operators to try for `inp`, most
# preferred first. The result is a deque that starts from a fixed default
# order and is then re-ordered or extended by the heuristics below.
priority_list_ops = deque(
[
flash.FwOp,
triton.FwOp,
cutlass.FwOp,
small_k.FwOp,
]
)
# Each heuristic that fires moves its operator to the front.
if _is_cutlass_fwd_faster_than_flash(inp):
priority_list_ops.remove(cutlass.FwOp)
priority_list_ops.appendleft(cutlass.FwOp)
if _is_triton_fwd_fastest(inp):
priority_list_ops.remove(triton.FwOp)
priority_list_ops.appendleft(triton.FwOp)
if not needs_gradient:
# Keys with rank > 3 whose head dimension (dim -2) has size > 1 and
# stride 0 are treated as multi-query / grouped-query.
mqa_or_gqa = (
inp.key.ndim > 3 and inp.key.stride(-2) == 0 and inp.key.shape[-2] > 1
)
if not mqa_or_gqa:
# With multiquery, cutlass is sometimes faster than decoder
# but it's not currently clear when.
priority_list_ops.appendleft(decoder.FwOp)
# Split-KV is useful with MQA
# for short Q-seqlen / long K-seqlen
if mqa_or_gqa and inp.query.shape[1] <= 32 and inp.key.shape[1] >= 256:
# Product of query dims 0 and 2 (dim 0 alone for 3-D queries); 0 when
# the query rank is none of 3/4/5, which disables the split-K choice.
parallelism_BH = 0 # BMK
if inp.query.ndim == 3:
parallelism_BH = inp.query.shape[0]
elif inp.query.ndim == 4: # BMHK
parallelism_BH = inp.query.shape[0] * inp.query.shape[2]
elif inp.query.ndim == 5: # BMGHK
parallelism_BH = inp.query.shape[0] * inp.query.shape[2]
if parallelism_BH > 0 and parallelism_BH < 64:
priority_list_ops.appendleft(triton_splitk.FwOp)
# Without variable seqlen flash is fastest
if not isinstance(inp.attn_bias, attn_bias.BlockDiagonalMask):
priority_list_ops.remove(flash.FwOp)
priority_list_ops.appendleft(flash.FwOp)
return priority_list_ops
def _dispatch_fw(inp: Inputs, needs_gradient: bool) -> Type[AttentionFwOpBase]:
"""Return the operator to use for the forward pass.
Priority-list dispatch is currently disabled (see the commented-out call
below): `flash.FwOp` is returned unconditionally, so `inp` and
`needs_gradient` are not inspected and nothing is raised here.
Returns:
AttentionOp: always `flash.FwOp`
"""
# Original dispatch, kept for reference; it raised NotImplementedError when
# no operator was found.
# return _run_priority_list(
# "memory_efficient_attention_forward",
# _dispatch_fw_priority_list(inp, needs_gradient),
# inp,
# )
return flash.FwOp
def _is_cutlassB_faster_than_flash(inp: Inputs) -> bool:
# Stub heuristic: `inp` is not inspected and the answer is always False.
return False
def _dispatch_bw(inp: Inputs) -> Type[AttentionBwOpBase]:
# Return the operator to use for the backward pass. Priority-list dispatch
# is currently disabled: `flash.BwOp` is returned unconditionally and `inp`
# is not inspected.
# NOTE: `priority_list_ops` is built but unused while the dispatch call
# below stays commented out.
priority_list_ops: List[Type[AttentionBwOpBase]] = [
flash.BwOp,
cutlass.BwOp,
# CUDA illegal memory issues, race conditions etc..
# triton.BwOp,
# Deprecated
small_k.BwOp,
]
# if _is_cutlassB_faster_than_flash(inp):
# priority_list_ops.remove(cutlass.BwOp)
# priority_list_ops.insert(0, cutlass.BwOp)
# return _run_priority_list(
# "memory_efficient_attention_backward", priority_list_ops, inp
# )
return flash.BwOp

View File

@@ -0,0 +1,666 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
from dataclasses import replace
from itertools import zip_longest
from typing import Any, List, Optional, Set, Tuple, Union
import os
import torch
from ..common import _get_storage_base, get_operator, register_operator
from .attn_bias import (
AttentionBias,
BlockDiagonalCausalFromBottomRightMask,
BlockDiagonalCausalLocalAttentionFromBottomRightMask,
BlockDiagonalCausalLocalAttentionMask,
BlockDiagonalCausalMask,
BlockDiagonalMask,
LowerTriangularMask,
)
from .common import (
AttentionBwOpBase,
AttentionFwOpBase,
Context,
Gradients,
Inputs,
check_lastdim_alignment_stride1,
)
# Whether the flash kernels go through the `*_ixdnn` entry points. Read once
# at import time; enabled unless ENABLE_FLASH_ATTENTION_WITH_IXDNN is "0".
# (A `global` statement at module scope is a no-op, so none is needed.)
enable_ixdnn = os.getenv('ENABLE_FLASH_ATTENTION_WITH_IXDNN', '1') != '0'
# Placeholder version, overwritten below when a flash-attention build is found.
FLASH_VERSION = "0.0.0"
# Locate a flash-attention extension and register `xformers_flash::flash_fwd`
# and `xformers_flash::flash_bwd` with the PyTorch dispatcher. Any ImportError
# raised in this block is swallowed, leaving the ops unregistered.
try:
try:
# Prefer the extension bundled with this package ...
from ... import _C_flashattention # type: ignore[attr-defined]
from ..._cpp_lib import _build_metadata
if _build_metadata is not None:
FLASH_VERSION = _build_metadata.flash_version
except ImportError:
# ... and fall back to an installed `flash_attn` package.
import flash_attn
from flash_attn.flash_attn_interface import flash_attn_cuda as _C_flashattention
FLASH_VERSION = flash_attn.__version__
# Only the first two dot-separated version components are compared.
flash_ver_parsed = tuple(int(s) for s in FLASH_VERSION.split(".")[:2])
if flash_ver_parsed < (2, 3):
raise ImportError("Requires 2.3 for sliding window support")
# create library so that flash-attn goes through the PyTorch Dispatcher
_flash_lib = torch.library.Library("xformers_flash", "DEF")
_flash_lib.define(
"flash_fwd(Tensor query, Tensor key, Tensor value, "
"Tensor? cu_seqlens_q, Tensor? cu_seqlens_k, "
"int max_seqlen_q, int max_seqlen_k, "
"float p, float softmax_scale, "
"bool is_causal, int window_size, bool return_softmax, bool use_alibi, int alibi_mode, int imp_mode) -> (Tensor, Tensor, Tensor, Tensor, Tensor, Tensor)"
)
_flash_lib.define(
"flash_bwd(Tensor dout, Tensor query, Tensor key, Tensor value, "
"Tensor out, Tensor softmax_lse_, Tensor dq, Tensor dk, Tensor dv, "
"Tensor cu_seqlens_q, Tensor cu_seqlens_k, "
"int max_seqlen_q, int max_seqlen_k, "
"float p, float softmax_scale, bool is_causal, int window_size, bool use_alibi, int alibi_mode, int imp_mode) -> (Tensor, Tensor, Tensor)"
)
# CUDA implementation of `flash_fwd`. Picks one of three kernel entry points
# (`fwd_ixdnn`, `fwd`, `varlen_fwd`) and returns
# (out, softmax_lse, q_padded, k_padded, v_padded, out_padded).
# NOTE(review): `window_size` is accepted but not forwarded to any of the three
# kernels; confirm sliding-window masks are handled elsewhere.
def _flash_fwd(
query,
key,
value,
cu_seq_lens_q,
cu_seq_lens_k,
max_seq_len_q,
max_seq_len_k,
p,
softmax_scale,
is_causal,
window_size,
return_softmax,
use_alibi,
alibi_mode,
imp_mode
):
if enable_ixdnn:
# ixdnn path: `softmax_scale` and the max sequence lengths are not passed.
# The local `p` is rebound to the kernel's 7th return value.
(
out,
q_padded,
k_padded,
v_padded,
out_padded,
softmax_lse,
p,
) = _C_flashattention.fwd_ixdnn(
query,
key,
value,
None, # out
cu_seq_lens_q,
cu_seq_lens_k,
p,
is_causal,
return_softmax,
use_alibi,
alibi_mode,
imp_mode,
-1, -1,
None, # rng
)
else:
# Without cumulative sequence lengths use the dense kernel, otherwise the
# variable-length one. `imp_mode` is not passed on either path.
if cu_seq_lens_q is None:
assert cu_seq_lens_k is None
(
out,
q_padded,
k_padded,
v_padded,
out_padded,
softmax_lse,
p,
) = _C_flashattention.fwd(
query,
key,
value,
None, # out
p,
softmax_scale,
is_causal,
return_softmax,
use_alibi,
alibi_mode,
None, # rng
)
else:
# Pre-allocate the output with query's first two dims and value's dim 2.
out = query.new_empty(query.shape[0], query.shape[1], value.shape[2])
(
out,
q_padded,
k_padded,
v_padded,
out_padded,
softmax_lse,
p,
) = _C_flashattention.varlen_fwd(
query,
key,
value,
out,
cu_seq_lens_q,
cu_seq_lens_k,
max_seq_len_q,
max_seq_len_k,
p,
softmax_scale,
False,
is_causal,
return_softmax,
use_alibi,
alibi_mode,
None,
)
return out, softmax_lse, q_padded, k_padded, v_padded, out_padded
# CUDA implementation of `flash_bwd`. `dq`, `dk` and `dv` are passed to the
# kernel and the same three objects are returned; kernel return values are
# ignored.
# NOTE(review): `window_size` is accepted but not forwarded to any kernel.
def _flash_bwd(
grad,
query,
key,
value,
out,
lse,
dq,
dk,
dv,
cu_seq_lens_q,
cu_seq_lens_k,
max_seq_len_q,
max_seq_len_k,
p,
softmax_scale,
is_causal,
window_size,
use_alibi,
alibi_mode,
imp_mode
):
if enable_ixdnn:
# ixdnn path: `softmax_scale` and the max sequence lengths are not passed.
_C_flashattention.bwd_ixdnn(
grad,
query,
key,
value,
out,
lse,
dq,
dk,
dv,
cu_seq_lens_q,
cu_seq_lens_k,
p,
is_causal,
use_alibi,
alibi_mode,
imp_mode,
-1, -1,
None,
)
else:
# Dense kernel when no cumulative sequence lengths are given, else varlen.
if cu_seq_lens_k is None:
assert cu_seq_lens_q is None
_C_flashattention.bwd(
grad,
query,
key,
value,
out,
lse,
dq,
dk,
dv,
p,
softmax_scale,
is_causal,
use_alibi,
alibi_mode,
None,
)
else:
_C_flashattention.varlen_bwd(
grad,
query,
key,
value,
out,
lse,
dq,
dk,
dv,
cu_seq_lens_q,
cu_seq_lens_k,
max_seq_len_q,
max_seq_len_k,
p,
softmax_scale,
False, # zero_tensors
is_causal,
use_alibi,
alibi_mode,
None,
)
return dq, dk, dv
# Bind both Python functions as the CUDA implementations of the ops above.
_flash_lib.impl("flash_fwd", _flash_fwd, "CUDA")
_flash_lib.impl("flash_bwd", _flash_bwd, "CUDA")
except ImportError:
pass
def _convert_input_format(
    inp: Inputs,
) -> Tuple[Inputs, Optional[torch.Tensor], int, Optional[torch.Tensor], int]:
    """Reshape attention inputs into the layout handed to the flash operators.

    Args:
        inp: inputs whose query is 4-D or 5-D (asserted).

    Returns:
        ``(new_inp, cu_seqlen_q, max_seqlen_q, cu_seqlen_k, max_seqlen_k)``.
        ``new_inp`` is a copy of ``inp`` with query/key/value replaced. The
        two ``cu_seqlen`` values are ``None`` only on the non-ixdnn path when
        the bias is not a ``BlockDiagonalMask``.
    """
    assert inp.query.ndim in [4, 5]
    query, key, value = inp.query, inp.key, inp.value
    batch = query.shape[0]
    seqlen_q = query.shape[1]
    seqlen_kv = key.shape[1]
    head_dim_q = query.shape[-1]
    head_dim_v = value.shape[-1]

    attn_bias = inp.attn_bias
    if enable_ixdnn:
        # NOTE(review): for batch > 1 these tensors hold `seqlen` entries
        # rather than `batch` entries; confirm this matches what the ixdnn
        # kernels expect.
        if query.shape[0] == 1:
            cu_seqlen_q = torch.tensor([seqlen_q], dtype=torch.int32, device=query.device)
            cu_seqlen_k = torch.tensor([seqlen_kv], dtype=torch.int32, device=key.device)
        else:
            cu_seqlen_q = torch.full((seqlen_q,), seqlen_q, dtype=torch.int32, device=query.device)
            cu_seqlen_k = torch.full((seqlen_kv,), seqlen_kv, dtype=torch.int32, device=key.device)
        max_seqlen_q = inp.query.shape[1]
        max_seqlen_k = inp.key.shape[1]
    else:
        if isinstance(attn_bias, BlockDiagonalMask):
            # Move the sequence-start tensors to the query's device; this
            # updates the bias object in place.
            attn_bias.k_seqinfo.seqstart = attn_bias.k_seqinfo.seqstart.to(
                inp.query.device, non_blocking=True
            )
            attn_bias.q_seqinfo.seqstart = attn_bias.q_seqinfo.seqstart.to(
                inp.query.device, non_blocking=True
            )

            cu_seqlen_k = attn_bias.k_seqinfo.seqstart
            cu_seqlen_q = attn_bias.q_seqinfo.seqstart
            max_seqlen_q = attn_bias.q_seqinfo.max_seqlen
            max_seqlen_k = attn_bias.k_seqinfo.max_seqlen
        else:
            cu_seqlen_k = None
            cu_seqlen_q = None
            max_seqlen_q = inp.query.shape[1]
            max_seqlen_k = inp.key.shape[1]

    if query.ndim == 5:  # GQA
        # Fold the group/head_in_group dimensions together
        def fold(x):
            # Either the head is replicated
            if x.stride(3) == 0:
                return x[:, :, :, 0]
            # Or we reshape
            return x.reshape(
                [
                    x.shape[0],
                    x.shape[1],
                    -1,
                    x.shape[4],
                ]
            )

        query = fold(query)
        key = fold(key)
        value = fold(value)
    # Heads broadcast with stride 0 (multi-query): keep a single K/V head.
    if key.ndim == 4 and key.stride(2) == 0 and value.stride(2) == 0:
        key = key[:, :, :1]
        value = value[:, :, :1]
    # The non-ixdnn variable-length path flattens batch and sequence into
    # `[batch * seqlen, num_heads, head_dim]`.
    # Use the flag captured at import time instead of re-reading the
    # environment, so this stays consistent with the kernel selection in
    # `_flash_fwd` / `_flash_bwd` even if the variable changes later.
    if cu_seqlen_k is not None and not enable_ixdnn:
        query = query.reshape([batch * seqlen_q, -1, head_dim_q])
        key = key.reshape([batch * seqlen_kv, -1, head_dim_q])
        value = value.reshape([batch * seqlen_kv, -1, head_dim_v])
    new_inp = replace(
        inp,
        query=query,
        key=key,
        value=value,
    )
    return new_inp, cu_seqlen_q, max_seqlen_q, cu_seqlen_k, max_seqlen_k
def _is_causal(attn_bias: Optional[Union[torch.Tensor, AttentionBias]]) -> bool:
    """Tell whether ``attn_bias`` is an instance of one of the causal mask classes."""
    causal_mask_types = (
        LowerTriangularMask,
        BlockDiagonalCausalMask,
        BlockDiagonalCausalLocalAttentionMask,
        BlockDiagonalCausalFromBottomRightMask,
        BlockDiagonalCausalLocalAttentionFromBottomRightMask,
    )
    return isinstance(attn_bias, causal_mask_types)
def _window_size(attn_bias: Optional[Union[torch.Tensor, AttentionBias]]) -> int:
    """Return the window size carried by a local-attention mask, else 0.

    For ``BlockDiagonalCausalLocalAttentionMask`` a falsy ``_window_size`` is
    mapped to 0; for the from-bottom-right variant the stored value is
    returned unchanged.
    """
    if isinstance(attn_bias, BlockDiagonalCausalLocalAttentionMask):
        window = attn_bias._window_size
        return window if window else 0
    if isinstance(attn_bias, BlockDiagonalCausalLocalAttentionFromBottomRightMask):
        return attn_bias._window_size
    return 0
def _check_needs_no_topleft(d: Inputs, reasons: List[str]) -> None:
    """Append a reason to ``reasons`` when a causal mask would need TopLeft alignment.

    Flash does not support TopLeft, so causal masks are only allowed when
    each batch element has an equal number of queries and keys.
    """
    if isinstance(d.attn_bias, BlockDiagonalCausalMask):
        # zip_longest also flags sequence-start lists of different lengths.
        for k_start, q_start in zip_longest(
            d.attn_bias.k_seqinfo.seqstart_py, d.attn_bias.q_seqinfo.seqstart_py
        ):
            if k_start != q_start:
                reasons.append(
                    "Only support BlockDiagonalCausalMask if equal"
                    " numbers of keys and queries"
                )
                break
    elif isinstance(d.attn_bias, LowerTriangularMask):
        if d.query.shape[1] != d.key.shape[1]:
            # The two literals used to concatenate without a separating space
            # ("number ofkeys").
            reasons.append(
                "Only support LowerTriangularMask if equal number of"
                " keys and queries"
            )
def _check_strides_for_bmghk(x: torch.Tensor, name: str, reasons: List[str]) -> None:
"""
We want to be able to collapse the G/H dimensions together
"""
if x.ndim == 5:
stride_g, stride_h = x.stride(2), x.stride(3)
if x.shape[2] == 1:
return
if x.shape[3] == 1 or stride_h == 0:
return
if stride_g != stride_h * x.shape[-2]:
reasons.append(
f"GQA is only supported when the G/H dimensions are contiguous\n"
f" {name}.stride: {x.stride()}\n"
f" {name}.shape : {list(x.shape)}"
)
@register_operator
class FwOp(AttentionFwOpBase):
    """Operator that computes memory-efficient attention using \
        `Flash-Attention <https://github.com/HazyResearch/flash-attention>`_ \
        implementation.
    """

    OPERATOR = get_operator("xformers_flash", "flash_fwd")
    SUPPORTED_DEVICES: Set[str] = {"cuda"}
    SUPPORTED_DTYPES: Set[torch.dtype] = {torch.half, torch.bfloat16}
    SUPPORTED_MAX_K = 256
    SUPPORTED_ATTN_BIAS_TYPES: Set[Any] = {
        type(None),
        LowerTriangularMask,
        BlockDiagonalMask,
        BlockDiagonalCausalMask,
        BlockDiagonalCausalLocalAttentionMask,
        BlockDiagonalCausalLocalAttentionFromBottomRightMask,
        BlockDiagonalCausalFromBottomRightMask,
    }
    SUPPORTS_DROPOUT = True
    SUPPORTS_CUSTOM_SCALE = True
    SUPPORTS_DIFFERENT_VALUE_EMBED = False
    SUPPORTS_BMGHK = True
    NAME = f"flshattF@{FLASH_VERSION}"
    VERSION = FLASH_VERSION

    @classmethod
    def not_supported_reasons(cls, d: Inputs) -> List[str]:
        """Return the reasons this operator cannot run on ``d`` (empty if it can)."""
        reasons = super(FwOp, cls).not_supported_reasons(d)
        check_lastdim_alignment_stride1(reasons, "query", d.query, 8)
        _check_needs_no_topleft(d, reasons)
        _check_strides_for_bmghk(d.query, "query", reasons)
        _check_strides_for_bmghk(d.key, "key", reasons)
        _check_strides_for_bmghk(d.value, "value", reasons)
        return reasons

    @classmethod
    def apply(
        cls, inp: Inputs, needs_gradient: bool
    ) -> Tuple[torch.Tensor, Optional[Context]]:
        """Run the flash forward operator and return ``(out, ctx)``.

        A context holding the kernel's padded q/k/v/out tensors is always
        created, whether or not ``needs_gradient`` is set.
        """
        return_softmax = False
        # Output shape: the query's shape with the value's last dimension.
        out_shape = [
            *inp.query.shape[:-1],
            inp.value.shape[-1],
        ]
        (
            inp,
            cu_seqlens_q,
            max_seqlen_q,
            cu_seqlens_k,
            max_seqlen_k,
        ) = _convert_input_format(inp)
        out, softmax_lse, q_padded, k_padded, v_padded, o_padded = cls.OPERATOR(
            inp.query,
            inp.key,
            inp.value,
            cu_seqlens_q,
            cu_seqlens_k,
            max_seqlen_q,
            max_seqlen_k,
            inp.p,
            inp.scale_float,
            _is_causal(inp.attn_bias),
            _window_size(inp.attn_bias),
            return_softmax,
            inp.use_alibi,
            inp.alibi_mode,
            inp.imp_mode,
        )
        out = out.reshape(out_shape)
        ctx = Context(out=out, lse=softmax_lse, q_padded=q_padded, k_padded=k_padded, v_padded=v_padded, o_padded=o_padded)
        # With dropout, the backward pass is pinned to the flash backward op.
        if inp.p != 0.0:
            ctx.op_bw = BwOp
        return (out, ctx)

    @classmethod
    # type: ignore
    def operator_flop(
        cls,
        query,
        key,
        value,
        cu_seq_lens_q,
        cu_seq_lens_k,
        max_seq_len_q,
        max_seq_len_k,
        p,
        softmax_scale,
        causal,
        return_softmax,
        *_extra,
    ) -> int:
        """Delegate the FLOP estimate for one ``flash_fwd`` call to ``attn_operator_flop``.

        The ``flash_fwd`` schema takes more arguments than are named here;
        the surplus is absorbed by ``*_extra`` and ignored, so forwarding the
        operator's full positional argument list does not raise
        ``TypeError``. Only query/key/value, ``causal`` and the two
        cumulative sequence lengths are used.
        """
        return cls.attn_operator_flop(
            query.unsqueeze(0),
            key.unsqueeze(0),
            value.unsqueeze(0),
            causal=causal,
            seqstart_k=cu_seq_lens_k,
            seqstart_q=cu_seq_lens_q,
        )
@register_operator
class BwOp(AttentionBwOpBase):
    __doc__ = FwOp.__doc__

    OPERATOR = get_operator("xformers_flash", "flash_bwd")
    SUPPORTED_DEVICES = FwOp.SUPPORTED_DEVICES
    SUPPORTED_DTYPES = FwOp.SUPPORTED_DTYPES
    SUPPORTED_MAX_K = FwOp.SUPPORTED_MAX_K
    SUPPORTED_ATTN_BIAS_TYPES = FwOp.SUPPORTED_ATTN_BIAS_TYPES
    SUPPORTS_DROPOUT = FwOp.SUPPORTS_DROPOUT
    SUPPORTS_CUSTOM_SCALE = FwOp.SUPPORTS_CUSTOM_SCALE
    SUPPORTS_DIFFERENT_VALUE_EMBED = FwOp.SUPPORTS_DIFFERENT_VALUE_EMBED
    IS_DETERMINISTIC = False
    SUPPORTS_BMGHK = False  # NOTE: Don't forget to update fmha doc when changing this!
    NAME = f"flshattB@{FLASH_VERSION}"
    VERSION = FLASH_VERSION

    # Largest head dimension accepted on GPUs other than (8, 0) / (9, 0).
    MAX_HEADDIM_SM8x = 192

    @classmethod
    def shape_not_supported_reasons(
        cls, Mq: int, Mkv: int, K: int, Kv: int
    ) -> List[str]:
        """Return the base-class shape restrictions unchanged."""
        reasons = super().shape_not_supported_reasons(Mq, Mkv, K, Kv)

        # In fbcode in mode/dev-nosan, we get nans from flash v2.1 if there
        # is a strange embedding dimension.
        # if K not in {8, 16, 32, 64, 128, 256}:
        #     reasons.append(f"Embed dim {K} not supported")
        return reasons

    @classmethod
    def not_supported_reasons(cls, d: Inputs) -> List[str]:
        """Return the reasons this operator cannot run on ``d`` (empty if it can)."""
        reasons = super(BwOp, cls).not_supported_reasons(d)
        check_lastdim_alignment_stride1(reasons, "query", d.query, 8)
        _check_needs_no_topleft(d, reasons)
        if d.device.type == "cuda":
            # Due to limited shared-memory, some GPUs are limited in head dimension
            device_capability = torch.cuda.get_device_capability(d.device)
            is_sm80_or_sm90 = device_capability in [(8, 0), (9, 0)]
            if (
                max(d.key.shape[-1], d.query.shape[-1]) > cls.MAX_HEADDIM_SM8x
                and not is_sm80_or_sm90
            ):
                reasons.append(
                    "requires a GPU with compute capability 8.0 "
                    f"(A100) or 9.0 (H100) for 'query.shape[-1] > {cls.MAX_HEADDIM_SM8x}'"
                )
        return reasons

    @classmethod
    def apply(cls, ctx: Context, inp: Inputs, grad: torch.Tensor) -> Gradients:
        """Compute dq/dk/dv with the flash backward operator.

        The operator writes into pre-allocated gradient buffers shaped like
        the padded tensors saved in ``ctx`` by the forward pass. The results
        are then cut back to the original head dimension and reshaped to the
        original input shapes.
        """
        dq_shape, dk_shape, dv_shape = inp.query.shape, inp.key.shape, inp.value.shape
        (
            inp,
            cu_seqlens_q,
            max_seqlen_q,
            cu_seqlens_k,
            max_seqlen_k,
        ) = _convert_input_format(inp)
        # `is_contiguous` must be *called*: the bare bound method is always
        # truthy, which made this assertion a no-op.
        assert ctx.lse.is_contiguous()
        ctx_lse = ctx.lse
        # Use the flag captured at import time rather than re-reading the
        # environment, so this matches the kernel chosen in `_flash_bwd`.
        if not enable_ixdnn:
            assert ctx_lse.shape[2] >= max_seqlen_q
            if max_seqlen_q != ctx_lse.shape[2]:
                ctx_lse = ctx_lse[:, :, :max_seqlen_q].contiguous()
        kernel_out_shape = [
            *inp.query.shape[:-1],
            inp.value.shape[-1],
        ]

        # Create dq,dk,dv
        # If Q/K/V come from a single QKV tensor, let's put the gradient in the
        # right strides, so we can avoid a `cat`
        if (
            ctx.q_padded.shape[0] == ctx.k_padded.shape[0]
            and ctx.q_padded.shape[-1] == ctx.v_padded.shape[-1]
            and _get_storage_base(ctx.q_padded) == _get_storage_base(ctx.k_padded)
            # The second storage check used to repeat k_padded, so V was
            # never compared against Q.
            and _get_storage_base(ctx.q_padded) == _get_storage_base(ctx.v_padded)
        ):
            # Create one big contiguous chunk
            # This is because q, k and v usually come from a single
            # output of a linear layer that is chunked.
            # Creating the gradients with the right layout saves us
            # a `torch.cat` call in the backward pass
            chunk = torch.empty(
                (*ctx.q_padded.shape[0:-2], 3, ctx.q_padded.shape[-2], ctx.q_padded.shape[-1]),
                dtype=ctx.q_padded.dtype,
                device=inp.device,
            )
            grads = Gradients(
                dq=chunk.select(-3, 0),
                dk=chunk.select(-3, 1),
                dv=chunk.select(-3, 2),
            )
        else:
            grads = Gradients(
                dq=torch.empty_like(ctx.q_padded),
                dk=torch.empty_like(ctx.k_padded),
                dv=torch.empty_like(ctx.v_padded),
            )

        assert grad.dtype in cls.SUPPORTED_DTYPES
        cls.OPERATOR(
            grad.reshape(kernel_out_shape).contiguous(),
            ctx.q_padded,
            ctx.k_padded,
            ctx.v_padded,
            # ctx.out.reshape(kernel_out_shape),
            ctx.o_padded,
            ctx_lse,
            grads.dq,
            grads.dk,
            grads.dv,
            cu_seqlens_q,
            cu_seqlens_k,
            max_seqlen_q,
            max_seqlen_k,
            inp.p,
            inp.scale_float,
            _is_causal(inp.attn_bias),
            _window_size(inp.attn_bias),
            inp.use_alibi,
            inp.alibi_mode,
            inp.imp_mode,
        )
        grads.dq = grads.dq[..., :dq_shape[-1]].reshape(dq_shape)  # We could have padded the head dimension
        grads.dk = grads.dk[..., :dk_shape[-1]].reshape(dk_shape)
        grads.dv = grads.dv[..., :dv_shape[-1]].reshape(dv_shape)
        return grads

    @classmethod
    # type: ignore
    def operator_flop(
        cls,
        grad,
        query,
        key,
        value,
        out,
        lse,
        dq,
        dk,
        dv,
        cu_seq_lens_q,
        cu_seq_lens_k,
        max_seq_len_q,
        max_seq_len_k,
        p,
        softmax_scale,
        causal,
        *_extra,
    ) -> int:
        """Delegate the FLOP estimate for one ``flash_bwd`` call to ``attn_operator_flop``.

        The ``flash_bwd`` schema has four arguments after ``is_causal``
        (``window_size``, ``use_alibi``, ``alibi_mode``, ``imp_mode``). They
        are absorbed by ``*_extra`` and ignored, so forwarding the operator's
        full positional argument list does not raise ``TypeError``.
        """
        return cls.attn_operator_flop(
            query.unsqueeze(0),
            key.unsqueeze(0),
            value.unsqueeze(0),
            causal=causal,
            seqstart_k=cu_seq_lens_k,
            seqstart_q=cu_seq_lens_q,
        )

View File

@@ -0,0 +1,186 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
from typing import Any, List, Mapping, Optional, Set, Tuple, Union
import torch
from ..common import get_xformers_operator, register_operator
from .attn_bias import AttentionBias
from .common import (
AttentionBwOpBase,
AttentionFwOpBase,
Context,
Gradients,
Inputs,
bmk2bmhk,
)
def _bmhk2bmk_contiguous(tensor) -> torch.Tensor:
return (
tensor.permute((0, 2, 1, 3))
.contiguous()
.view([tensor.shape[0] * tensor.shape[2], tensor.shape[1], tensor.shape[3]])
.contiguous()
)
def _get_tensor_bias_bmk(
attn_bias: Optional[Union[torch.Tensor, AttentionBias]]
) -> Optional[torch.Tensor]:
if not isinstance(attn_bias, torch.Tensor):
assert attn_bias is None
return None
# BMK -> BMHK
if attn_bias.ndim == 4:
attn_bias = attn_bias.reshape([-1, *attn_bias.shape[2:]])
return attn_bias
@register_operator
class FwOp(AttentionFwOpBase):
    """An operator optimized for very small values of K (``K <= 32``) \
        and f32 pre-Ampere as it does not use TensorCores.
    Only supports contiguous inputs in BMK format, so an extra reshape \
        or contiguous call might be done.

    :Deprecated:

        This operator is deprecated and should not be used in new code
    """

    OPERATOR = get_xformers_operator("efficient_attention_forward_small_k")
    SUPPORTED_DEVICES = {"cuda"}
    SUPPORTED_DTYPES = {torch.float}
    SUPPORTED_MAX_K: float = 32
    SUPPORTED_ATTN_BIAS_TYPES: Set[Any] = {type(None), torch.Tensor}
    SUPPORTS_DROPOUT = True
    SUPPORTS_CUSTOM_SCALE = False
    NAME = "smallkF"

    BACKWARD_ERROR_ATOL: Mapping[torch.dtype, float] = {
        torch.float: 4e-3,
    }
    # as this kernel is a bit slow, this should make tests run faster
    _TEST_BATCH_SIZES = [1, 3]
    _TEST_K = [2, 3, 8, 16, 32]

    @classmethod
    def not_supported_reasons(cls, d: Inputs) -> List[str]:
        """Return the reasons this operator cannot run on ``d`` (empty if it can)."""
        reasons = super(FwOp, cls).not_supported_reasons(d)
        bias = d.attn_bias
        if isinstance(bias, torch.Tensor) and bias.stride(1) != 0:
            reasons.append("bias with non-zero stride not supported")
        buffer_size = 8
        k = d.query.shape[-1]
        # The head dim must split evenly into packs of 1, 2 or 4 with at most
        # `buffer_size` packs.
        fits_buffer = any(
            (k % pack) == 0 and (k // pack) <= buffer_size for pack in (1, 2, 4)
        )
        if not fits_buffer:
            reasons.append(f"unsupported embed per head: {k}")
        return reasons

    @classmethod
    def apply(
        cls, inp: Inputs, needs_gradient: bool
    ) -> Tuple[torch.Tensor, Optional[Context]]:
        """Run the small-K forward operator and return ``(out, ctx)``.

        ``ctx`` is ``None`` unless ``needs_gradient`` is set.

        Raises:
            NotImplementedError: if ``inp.scale`` is set.
        """
        if inp.scale is not None:
            raise NotImplementedError("Unsupport custom scale")
        num_heads = inp.query.shape[2]
        query, key, value = (
            _bmhk2bmk_contiguous(t) for t in (inp.query, inp.key, inp.value)
        )

        out, lse, rng_seed, rng_offset = cls.OPERATOR(
            query=query,
            key=key,
            value=value,
            compute_logsumexp=needs_gradient,
            attn_bias=_get_tensor_bias_bmk(inp.attn_bias),
            p=inp.p,
        )
        out = bmk2bmhk(out, num_heads)
        # Split the merged leading axis of the LSE back into two dims.
        lse = lse.reshape([lse.shape[0] // num_heads, num_heads, lse.shape[1]])
        if needs_gradient:
            ctx = Context(out=out, lse=lse)
            if inp.p != 0.0:
                # With dropout, pin the backward op and keep the RNG state.
                ctx.op_bw = BwOp
                ctx.rng_state = torch.tensor(
                    [rng_seed, rng_offset], dtype=torch.int64, device="cpu"
                )
            return out, ctx
        return out, None
@register_operator
class BwOp(AttentionBwOpBase):
    __doc__ = FwOp.__doc__

    OPERATOR = get_xformers_operator("efficient_attention_backward_small_k")
    # Every capability flag below mirrors the forward operator.
    SUPPORTED_DEVICES = FwOp.SUPPORTED_DEVICES
    SUPPORTED_DTYPES = FwOp.SUPPORTED_DTYPES
    SUPPORTED_MAX_K = FwOp.SUPPORTED_MAX_K
    SUPPORTED_ATTN_BIAS_TYPES = FwOp.SUPPORTED_ATTN_BIAS_TYPES
    SUPPORTS_DROPOUT = FwOp.SUPPORTS_DROPOUT
    SUPPORTS_CUSTOM_SCALE = FwOp.SUPPORTS_CUSTOM_SCALE
    SUPPORTS_DIFFERENT_VALUE_EMBED = FwOp.SUPPORTS_DIFFERENT_VALUE_EMBED

    # there is some extra precision loss in the CPU implementation due to an
    # extra accumulation step in grad_q, which is not present in the CUDA
    # implementation
    ERROR_ATOL: Mapping[torch.dtype, float] = {
        torch.float: 4e-3,
    }

    NAME = "smallkB"

    @classmethod
    def not_supported_reasons(cls, d: Inputs) -> List[str]:
        """Collect the reasons why this operator cannot process ``d``.

        Adds to the parent's reasons: a tensor bias must have a zero stride
        along dimension 1, and the query's last dimension must be divisible
        by 1, 2 or 4 and at most 8 once divided.
        """
        reasons = super(BwOp, cls).not_supported_reasons(d)
        bias = d.attn_bias
        if isinstance(bias, torch.Tensor) and bias.stride(1) != 0:
            reasons.append("bias with non-zero stride not supported")
        max_packed = 8
        embed = d.query.shape[-1]
        fits = any(
            embed % pack == 0 and embed // pack <= max_packed
            for pack in (1, 2, 4)
        )
        if not fits:
            reasons.append(f"unsupported embed per head: {embed}")
        return reasons

    @classmethod
    def apply(cls, ctx: Context, inp: Inputs, grad: torch.Tensor) -> Gradients:
        """Run the backward operator and return the q/k/v gradients.

        Args:
            ctx: forward context providing ``out``, ``lse`` and, when
                dropout is active, ``rng_state``.
            inp: the inputs of the forward pass.
            grad: gradient of the output; its dimension 2 is used as the
                number of heads.

        Raises:
            NotImplementedError: if dropout is active and ``ctx.rng_state``
                is not an int64 CPU tensor of shape ``(2,)``.
        """
        heads = grad.shape[2]
        flat_grad = _bmhk2bmk_contiguous(grad)
        flat_q = _bmhk2bmk_contiguous(inp.query)
        flat_k = _bmhk2bmk_contiguous(inp.key)
        flat_v = _bmhk2bmk_contiguous(inp.value)
        flat_out = _bmhk2bmk_contiguous(ctx.out)

        seed = offset = 0
        if inp.p != 0.0:
            state = ctx.rng_state
            state_ok = (
                state is not None
                and state.dtype == torch.int64
                and state.device.type == "cpu"
                and state.shape == (2,)
            )
            if not state_ok:
                raise NotImplementedError(f"Invalid rng_state: {ctx.rng_state}")
            seed, offset = state.tolist()

        # LSE: BHM -> (BH)M
        flat_lse = ctx.lse.reshape([-1, ctx.lse.shape[-1]])
        flat_bias = _get_tensor_bias_bmk(inp.attn_bias)
        grad_q, grad_k, grad_v = cls.OPERATOR(
            flat_grad,
            flat_q,
            flat_k,
            flat_v,
            flat_lse,
            flat_out,
            flat_bias,
            inp.p,
            seed,
            offset,
        )
        return Gradients(
            dq=bmk2bmhk(grad_q, heads),
            dk=bmk2bmhk(grad_k, heads),
            dv=bmk2bmhk(grad_v, heads),
        )

View File

@@ -0,0 +1,201 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
from dataclasses import replace
from typing import TYPE_CHECKING, Any, List, Optional, Set, Tuple
import torch
from ... import _is_triton_available
from ..common import register_operator
# This implementation needs pre-MLIR triton
# The BW pass is not stable/well tested
# And also does not have the latest improvements
# NOTE: the `False and ...` below deliberately disables the runtime import;
# the branch is only taken by static type checkers.
if TYPE_CHECKING or (False and _is_triton_available()):
    try:
        from flash_attn.flash_attn_triton import (
            _flash_attn_backward,
            _flash_attn_forward,
        )
    except ImportError:
        import importlib

        # `import importlib` alone does not guarantee that these submodules
        # are loaded as attributes of the package; import them explicitly.
        import importlib.abc
        import importlib.machinery
        import importlib.util
        import pathlib
        import sys
        import types

        def import_module_from_path(path: str) -> types.ModuleType:
            """Import a module from the given path, w/o __init__.py

            The module is registered in ``sys.modules`` under the file's
            stem (``'path/x.py'`` -> ``'x'``) before being executed.
            """
            module_path = pathlib.Path(path).resolve()
            module_name = module_path.stem  # 'path/x.py' -> 'x'
            spec = importlib.util.spec_from_file_location(module_name, module_path)  # type: ignore
            assert isinstance(spec, importlib.machinery.ModuleSpec)
            module = importlib.util.module_from_spec(spec)  # type: ignore
            sys.modules[module_name] = module
            assert isinstance(spec.loader, importlib.abc.Loader)
            spec.loader.exec_module(module)
            return module

        # Fall back to the vendored copy; the path is relative to the
        # current working directory.
        flash_attn = import_module_from_path(
            "third_party/flash-attention/flash_attn/flash_attn_triton.py"
        )
        _flash_attn_backward = flash_attn._flash_attn_backward
        _flash_attn_forward = flash_attn._flash_attn_forward

    triton_flash_backward = _flash_attn_backward
    triton_flash_forward = _flash_attn_forward
else:
    # Kernels unavailable: the operators below report this via
    # `cls.OPERATOR is None`.
    triton_flash_backward = None
    triton_flash_forward = None
from .attn_bias import LowerTriangularMask
from .common import (
AttentionBwOpBase,
AttentionFwOpBase,
Context,
Gradients,
Inputs,
check_lastdim_alignment_stride1,
)
def _prepare_inputs(inp: Inputs) -> Inputs:
    """Return a copy of ``inp`` normalised for the Triton flash kernels.

    * A 3-dimensional tensor bias is reshaped to 4 dimensions by splitting
      its leading dimension into ``(B, leading // B)``, where ``B`` is the
      first dimension of the query.
    * Query, key and value are made contiguous unless their last dimension
      already has stride 1.
    """

    def _unit_last_stride(t):
        # Make sure that the last dimension is contiguous
        return t if t.stride(-1) == 1 else t.contiguous()

    bias = inp.attn_bias
    if isinstance(bias, torch.Tensor) and bias.ndim == 3:
        batch = inp.query.shape[0]
        flat_heads, rows, cols = bias.shape
        bias = bias.reshape(batch, flat_heads // batch, rows, cols)

    return replace(
        inp,
        attn_bias=bias,
        query=_unit_last_stride(inp.query),
        key=_unit_last_stride(inp.key),
        value=_unit_last_stride(inp.value),
    )
@register_operator
class FwOp(AttentionFwOpBase):
    """Operator that computes memory-efficient attention using \
        `Tri Dao's <https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/flash_attn_triton.py>`_ \
        implementation, based on
        `Phil Tillet's code <https://github.com/openai/triton/blob/master/python/tutorials/06-fused-attention.py>`_
    """

    OPERATOR = triton_flash_forward
    SUPPORTED_DEVICES = {"cuda"}
    CUDA_MINIMUM_COMPUTE_CAPABILITY = (8, 0)
    SUPPORTED_DTYPES = {torch.half, torch.bfloat16}
    SUPPORTED_MAX_K = 128
    SUPPORTED_ATTN_BIAS_TYPES: Set[Any] = {
        type(None),
        LowerTriangularMask,
        # TODO: backwards accuracy is failing for a few cases, perhaps we want to disable this for now.
        # torch.Tensor,
    }
    SUPPORTS_DROPOUT = False
    SUPPORTS_CUSTOM_SCALE = True
    NAME = "tritonflashattF"

    @classmethod
    def not_supported_reasons(cls, d: Inputs) -> List[str]:
        """Collect the reasons why this operator cannot process ``d``.

        On top of the parent's checks: alignment of the last dimension of
        q/k/v, availability of the kernel, compute capability of the CUDA
        device and the installed triton version.
        """
        reasons = super(FwOp, cls).not_supported_reasons(d)
        for label, tensor in (("query", d.query), ("key", d.key), ("value", d.value)):
            check_lastdim_alignment_stride1(reasons, label, tensor, 8)
        if cls.OPERATOR is None:
            reasons.append("triton is not available")
        # Has only been tested on 8.0 / 9.0.
        # Fails on 7.5 with illegal memory access
        if d.device.type == "cuda" and torch.cuda.get_device_capability(
            d.device
        ) < (8, 0):
            reasons.append(
                "requires GPU with sm80 minimum compute capacity, e.g., A100/H100/L4"
            )
        if _is_triton_available():
            import triton

            if triton.__version__ > "2.0.0":
                reasons.append("Only work on pre-MLIR triton for now")
        return reasons

    @classmethod
    def apply(
        cls, inp: Inputs, needs_gradient: bool
    ) -> Tuple[torch.Tensor, Optional[Context]]:
        """Run the Triton flash-attention forward kernel.

        A tensor bias is passed through as ``bias``; a
        ``LowerTriangularMask`` is translated into ``causal=True``. A
        context holding ``out`` and ``lse`` is returned regardless of
        ``needs_gradient``.
        """
        inp = _prepare_inputs(inp)
        bias = inp.attn_bias
        tensor_bias = bias if isinstance(bias, torch.Tensor) else None
        is_causal = isinstance(bias, LowerTriangularMask)
        # The third value returned by the kernel is not used.
        out, lse, _ = triton_flash_forward(
            q=inp.query,
            k=inp.key,
            v=inp.value,
            bias=tensor_bias,
            softmax_scale=inp.scale_float,
            causal=is_causal,
        )
        ctx = Context(lse=lse, out=out)
        return out, ctx
@register_operator
class BwOp(AttentionBwOpBase):
    __doc__ = FwOp.__doc__

    OPERATOR = triton_flash_backward
    # Every capability flag below mirrors the forward operator.
    SUPPORTED_DEVICES = FwOp.SUPPORTED_DEVICES
    CUDA_MINIMUM_COMPUTE_CAPABILITY = FwOp.CUDA_MINIMUM_COMPUTE_CAPABILITY
    SUPPORTED_DTYPES = FwOp.SUPPORTED_DTYPES
    SUPPORTED_MAX_K = FwOp.SUPPORTED_MAX_K
    SUPPORTED_ATTN_BIAS_TYPES = FwOp.SUPPORTED_ATTN_BIAS_TYPES
    SUPPORTS_DROPOUT = FwOp.SUPPORTS_DROPOUT
    SUPPORTS_CUSTOM_SCALE = FwOp.SUPPORTS_CUSTOM_SCALE
    SUPPORTS_DIFFERENT_VALUE_EMBED = FwOp.SUPPORTS_DIFFERENT_VALUE_EMBED
    NAME = "tritonflashattB"

    @classmethod
    def not_supported_reasons(cls, d: Inputs) -> List[str]:
        """Collect the reasons why this operator cannot process ``d``.

        Same checks as the forward operator, except that the CUDA device
        must report compute capability exactly ``(8, 0)``.
        """
        reasons = super(BwOp, cls).not_supported_reasons(d)
        for label, tensor in (("query", d.query), ("key", d.key), ("value", d.value)):
            check_lastdim_alignment_stride1(reasons, label, tensor, 8)
        if cls.OPERATOR is None:
            reasons.append("triton is not available")
        if d.device.type == "cuda" and torch.cuda.get_device_capability(
            d.device
        ) != (8, 0):
            reasons.append("requires A100 GPU")
        if _is_triton_available():
            import triton

            if triton.__version__ > "2.0.0":
                reasons.append("Only work on pre-MLIR triton for now")
        return reasons

    @classmethod
    def apply(cls, ctx: Context, inp: Inputs, grad: torch.Tensor) -> Gradients:
        """Run the Triton flash-attention backward kernel.

        The gradient buffers are allocated with ``torch.empty_like`` and
        filled in place by the kernel.
        """
        inp = _prepare_inputs(inp)
        bias = inp.attn_bias
        tensor_bias = bias if isinstance(bias, torch.Tensor) else None
        is_causal = isinstance(bias, LowerTriangularMask)

        # Triton's autotune causes the Tensor._version to change, and so Pytorch autograd
        # does a memcpy. To avoid this we run in inference_mode, which doesn't track the version.
        with torch.inference_mode():
            dq = torch.empty_like(inp.query)
            dk = torch.empty_like(inp.key)
            dv = torch.empty_like(inp.value)
            grads = Gradients(dq=dq, dk=dk, dv=dv)
            cls.OPERATOR(
                grad,
                inp.query,
                inp.key,
                inp.value,
                ctx.out,
                ctx.get_padded_lse(128),
                dq,
                dk,
                dv,
                bias=tensor_bias,
                softmax_scale=inp.scale_float,
                causal=is_causal,
            )
        return grads

View File

@@ -0,0 +1,738 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
from typing import TYPE_CHECKING, Any, List, Optional, Set, Tuple
import torch
from ..common import _has_triton21, register_operator
from .attn_bias import BlockDiagonalCausalWithOffsetPaddedKeysMask
from .common import AttentionFwOpBase, Context, Inputs, check_lastdim_alignment_stride1
def _strides(x: torch.Tensor, *stride_names: str):
    """Map ``stride_<name>`` to the stride of the matching dimension of ``x``.

    Exactly one name must be given per dimension of ``x``; the result is
    meant to be splatted as keyword arguments into a kernel launch.
    """
    assert x.ndim == len(stride_names)
    keys = (f"stride_{name}" for name in stride_names)
    return dict(zip(keys, x.stride()))
if TYPE_CHECKING or _has_triton21():
import triton
import triton.language as tl
from xformers.triton.vararg_kernel import VAR_ARGS_ARRAY, unroll_varargs
@triton.jit
def _fwd_kernel_splitK(
    Q,
    K,
    V,
    sm_scale,
    Out_splitK,  # [B, H, split_k, Mq, K]
    Metadata,  # [B, H, 2, split_k, M_ceil] contains [mi, li]
    Seq_len,
    stride_qz,
    stride_qm,
    stride_qg,
    stride_qh,
    stride_qk,
    stride_kz,
    stride_kn,
    stride_kg,
    stride_kh,
    stride_kk,
    stride_vz,
    stride_vn,
    stride_vg,
    stride_vh,
    stride_vk,
    stride_osk_zhg,
    stride_osk_s,
    stride_osk_m,
    stride_osk_k,
    stride_mzhg,
    stride_m2,
    stride_ms,
    stride_mm,
    Z,
    N_CTX_Q,
    N_CTX_K,
    BLOCK_N_PER_SPLIT,
    H: tl.constexpr,
    G: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_DMODEL: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BOUNDS_CHECKS_N: tl.constexpr,
    USE_SEQ_LEN: tl.constexpr,
    PACKED_PER_VAL: tl.constexpr = 1,
    N_GROUPS: tl.constexpr = 1,
):
    """This kernel can accept non-quantized or int4-quantized keys/values.
    PACKED_PER_VAL determines the quantization type:
        - PACKED_PER_VAL == 1 means no quantization
        - PACKED_PER_VAL == 8 means 4-bit quantization (8 packed quantized values inside one int32)
    For the quantized case K/V should be int32 tensors.
    Quantization can be row-wise (when N_GROUPS = 1) or group-wise with N_GROUPS = 2, 4, or 8.
    Quantization coefficients are stored at the beginning of the row along the last dimension of K/V
    So K[B, H, M, :] has a form
    [ quant_coef0, quant_coef1, ...|
      group0_quant_value0, group0_quant_value1,... |
      group1_quant_value0, group1_quant_value1,...]
    where each quant_coef is an int32 which should be interpreted as 2 packed float16: scale and offset.

    Each program handles one block of BLOCK_M query rows (program_id 0), one
    flattened batch/head/group index (program_id 1) and one split of the
    key/value sequence (program_id 2). It writes the un-normalized partial
    output to Out_splitK and the running maximum / sum of exponentials to
    Metadata; the normalization is left to a later reduction step.

    Note: this kernel needs to be processed by xformers.triton.vararg_kernel.unroll_varargs
    before compilation. That will unroll variables marked with "VAR_ARGS_ARRAY" into lists.
    See how FwOp.apply does it below.
    """
    tl.static_assert(
        (PACKED_PER_VAL == 1 and tl.constexpr(K.dtype.element_ty != tl.int32))
        or (PACKED_PER_VAL == 8 and tl.constexpr(K.dtype.element_ty == tl.int32)),
        f"Only 4-bit quantization is supported, K/V should have dtype int32 in "
        f"the quantized case: {PACKED_PER_VAL=} {tl.constexpr(K.dtype)=} {tl.constexpr(K.dtype.element_ty)=}",
    )
    tl.static_assert(
        (((N_GROUPS == 1 or N_GROUPS == 2) or N_GROUPS == 4) or N_GROUPS == 8),
        "Number of quantization groups can be 1 (row-wise quantization), 2, 4, or 8.",
    )

    QUANTIZED: tl.constexpr = PACKED_PER_VAL > 1
    # Number of packed (stored) elements per quantization group
    PACKED_D_PER_GROUP: tl.constexpr = BLOCK_DMODEL // PACKED_PER_VAL // N_GROUPS
    # Number of unpacked elements per quantization group
    D_PER_GROUP: tl.constexpr = BLOCK_DMODEL // N_GROUPS

    start_m = tl.program_id(0)
    # Decompose the flat program index into its z / h / g components
    off_zhg = tl.program_id(1)
    off_z = off_zhg // (H * G)
    off_h = (off_zhg // G) % H
    off_g = off_zhg % G
    splitk_idx = tl.program_id(2)

    # This program covers keys [lo, hi), clipped to the usable key length
    lo = splitk_idx * BLOCK_N_PER_SPLIT
    if USE_SEQ_LEN:
        kv_len = tl.load(Seq_len + off_z)
    else:
        kv_len = N_CTX_K
    hi = tl.minimum((splitk_idx + 1) * BLOCK_N_PER_SPLIT, kv_len)

    Q_block_ptr = tl.make_block_ptr(
        base=Q + off_h * stride_qh + off_z * stride_qz + off_g * stride_qg,
        shape=(N_CTX_Q, D_PER_GROUP),
        strides=(stride_qm, stride_qk),
        offsets=(start_m * BLOCK_M, 0),
        block_shape=(BLOCK_M, D_PER_GROUP),
        order=(1, 0),
    )

    k_base = K + off_h * stride_kh + off_z * stride_kz + off_g * stride_kg
    # Additional shift by N_GROUPS along the last dimension in the quantized case,
    # since the first N_GROUPS elements along that dim contain the packed
    # quantization coefficients (one per group).
    K_block_ptr = tl.make_block_ptr(
        base=k_base + stride_kk * QUANTIZED * N_GROUPS,
        shape=(PACKED_D_PER_GROUP, hi),
        strides=(stride_kk, stride_kn),
        offsets=(0, lo),
        block_shape=(PACKED_D_PER_GROUP, BLOCK_N),
        order=(0, 1),
    )
    v_base = V + off_h * stride_vh + off_z * stride_vz + off_g * stride_vg
    V_block_ptr = tl.make_block_ptr(
        base=v_base + stride_vk * QUANTIZED * N_GROUPS,
        shape=(hi, PACKED_D_PER_GROUP),
        strides=(stride_vn, stride_vk),
        offsets=(lo, 0),
        block_shape=(BLOCK_N, PACKED_D_PER_GROUP),
        order=(1, 0),
    )

    if QUANTIZED:
        # Pointers to quantization coefficients. Even though they are 1D,
        # we have to use block pointers, since usual pointers
        # don't support boundary checks
        K_scale_shift_block_ptr = tl.make_block_ptr(
            base=k_base,
            shape=(1, hi),
            strides=(stride_kk, stride_kn),
            offsets=(0, lo),
            block_shape=(1, BLOCK_N),
            order=(0, 1),
        )
        V_scale_shift_block_ptr = tl.make_block_ptr(
            base=v_base,
            shape=(hi, 1),
            strides=(stride_vn, stride_vk),
            offsets=(lo, 0),
            block_shape=(BLOCK_N, 1),
            order=(1, 0),
        )
    else:
        K_scale_shift_block_ptr = None
        V_scale_shift_block_ptr = None

    # initialize the running row-wise maximum (m_i) and the running sum of
    # exponentials (l_i)
    m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
    l_i = tl.zeros([BLOCK_M], dtype=tl.float32)

    # Before compilation, this kernel will be processed by xformers.triton.vararg_kernel.unroll_varargs.
    # That turns tensors annotated as the one below into lists of tensors of length N_GROUPS.
    # This is a solution for Triton native lack of support for lists of tensors.
    acc: "VAR_ARGS_ARRAY"  # noqa: F821

    for i in range(len(acc)):  # noqa: F821
        acc[i] = tl.zeros([BLOCK_M, D_PER_GROUP], dtype=tl.float32)  # noqa: F821
    # scale sm_scale by log_2(e) and use
    # 2^x instead of exp in the loop because CSE and LICM
    # don't work as expected with `exp` in the loop
    qk_scale = sm_scale * 1.44269504
    # load q: it will stay in SRAM throughout
    q: "VAR_ARGS_ARRAY"  # noqa: F821
    for i in range(len(acc)):  # noqa: F821
        q[i] = tl.load(  # noqa: F821
            tl.advance(Q_block_ptr, (0, i * D_PER_GROUP)), boundary_check=(0,)
        )
    # loop over k, v and update accumulator
    for start_n in range(lo, hi, BLOCK_N):
        k: "VAR_ARGS_ARRAY"  # noqa: F821
        v: "VAR_ARGS_ARRAY"  # noqa: F821
        for i in range(len(acc)):  # noqa: F821
            k[i], v[i] = load_dequantize_k_v_group(  # noqa: F821
                K_block_ptr,
                V_block_ptr,
                K_scale_shift_block_ptr,
                V_scale_shift_block_ptr,
                BOUNDS_CHECKS_N,
                PACKED_PER_VAL,
                PACKED_D_PER_GROUP,
                Q.dtype.element_ty,
                i,
            )

        # -- compute qk ---
        # The dot product over the full head dimension is the sum of the
        # per-group dot products.
        qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
        for i in range(len(acc)):  # noqa: F821
            qk += tl.dot(q[i], k[i])  # noqa: F821
        qk *= qk_scale

        # TODO: This is slow, and only needed at the last iteration.
        # Maybe we can unroll the last iteration instead?
        if BOUNDS_CHECKS_N:
            qk = tl.where(tl.arange(0, BLOCK_N) < hi - start_n, qk, float("-inf"))
        # -- compute scaling constant ---
        m_i_new = tl.maximum(m_i, tl.max(qk, 1))
        alpha = tl.math.exp2(m_i - m_i_new)
        p = tl.math.exp2(qk - m_i_new[:, None])

        # -- update m_i and l_i --
        l_i = l_i * alpha + tl.sum(p, 1)
        m_i = m_i_new
        p = p.to(Q.dtype.element_ty)

        # -- scale and update acc --
        for i in range(len(acc)):  # noqa: F821
            acc[i] *= alpha[:, None]  # noqa: F821
            acc[i] += tl.dot(p, v[i])  # noqa: F821
        # update pointers
        K_block_ptr = tl.advance(K_block_ptr, (0, BLOCK_N))
        V_block_ptr = tl.advance(V_block_ptr, (BLOCK_N, 0))
        if PACKED_PER_VAL > 1:
            K_scale_shift_block_ptr = tl.advance(
                K_scale_shift_block_ptr, (0, BLOCK_N)
            )
            V_scale_shift_block_ptr = tl.advance(
                V_scale_shift_block_ptr, (BLOCK_N, 0)
            )

    # write back O (un-normalized: the division by l_i is not done here)
    O_block_ptr = tl.make_block_ptr(
        base=Out_splitK + off_zhg * stride_osk_zhg + splitk_idx * stride_osk_s,
        shape=(N_CTX_Q, D_PER_GROUP),
        strides=(stride_osk_m, 1),
        offsets=(start_m * BLOCK_M, 0),
        block_shape=(BLOCK_M, D_PER_GROUP),
        order=(1, 0),
    )
    for i in range(len(acc)):  # noqa: F821
        tl.store(
            tl.advance(O_block_ptr, (0, i * D_PER_GROUP)),
            acc[i],  # noqa: F821
            boundary_check=(0,),
        )
    # Write metadata for split-K reduction
    # NOTE: these stores are unmasked over the whole BLOCK_M rows; FwOp.apply
    # allocates Metadata with its last dimension rounded up to a multiple of BLOCK_M.
    Metadata_ptr = (
        Metadata
        + off_zhg * stride_mzhg
        + splitk_idx * stride_ms
        + start_m * BLOCK_M
        + tl.arange(0, BLOCK_M)
    )
    tl.store(Metadata_ptr, m_i)
    tl.store(Metadata_ptr + stride_m2, l_i)
@triton.jit
def load_dequantize_k_v_group(
    K_block_ptr,
    V_block_ptr,
    K_scale_shift_block_ptr,
    V_scale_shift_block_ptr,
    BOUNDS_CHECKS_N: tl.constexpr,
    PACKED_PER_VAL: tl.constexpr,
    PACKED_D_PER_GROUP: tl.constexpr,
    dtype: tl.constexpr,
    group_id: tl.constexpr,
):
    """Load K/V for a given block. In case of int4-quantized K/V, dequantize them after loading.
    If quantization is group-wise, use group_id to advance the pointers to the current group.

    The K block pointer is advanced along its first dimension and the V block
    pointer along its second one, so K is handled in transposed layout
    relative to V. Returns the pair ``(k, v)``; in the quantized case both
    are cast to ``dtype``.
    """
    # Advance to the current quantization group
    K_block_ptr = tl.advance(K_block_ptr, (PACKED_D_PER_GROUP * group_id, 0))
    V_block_ptr = tl.advance(V_block_ptr, (0, PACKED_D_PER_GROUP * group_id))

    # -- load k, v --
    # Boundary checks are only requested along the sequence dimension.
    k = tl.load(K_block_ptr, boundary_check=(1,) if BOUNDS_CHECKS_N else ())
    v = tl.load(V_block_ptr, boundary_check=(0,) if BOUNDS_CHECKS_N else ())

    if PACKED_PER_VAL > 1:
        # K/V are quantized, load quantization coefficients and dequantize
        # The coefficient of group `group_id` is the `group_id`-th element
        # from the start of the row.
        K_scale_shift_block_ptr = tl.advance(K_scale_shift_block_ptr, (group_id, 0))
        V_scale_shift_block_ptr = tl.advance(V_scale_shift_block_ptr, (0, group_id))

        k_scale_shift = tl.load(
            K_scale_shift_block_ptr, boundary_check=(1,) if BOUNDS_CHECKS_N else ()
        )
        v_scale_shift = tl.load(
            V_scale_shift_block_ptr, boundary_check=(0,) if BOUNDS_CHECKS_N else ()
        )

        k_scale, k_shift = cast_uint32_to_half2(k_scale_shift)
        v_scale, v_shift = cast_uint32_to_half2(v_scale_shift)
        v = dequantize(v, v_scale, v_shift, PACKED_PER_VAL).to(dtype)
        # `dequantize` expects the sequence dimension first, so K is
        # transposed before the call and transposed back afterwards.
        k_t = dequantize(
            tl.trans(k),
            tl.trans(k_scale),
            tl.trans(k_shift),
            PACKED_PER_VAL,
        ).to(dtype)
        k = tl.trans(k_t)
    return k, v
@triton.jit
def cast_uint32_to_half2(scale_shift):
    """Extract two float16 packed into one int32

    The low 16 bits are reinterpreted as the scale and the high 16 bits as
    the shift; returns ``(scale, shift)``.
    """
    scale = scale_shift & 0xFFFF
    shift = scale_shift >> 16
    # Narrow to 16 bits, then reinterpret the bit pattern (no numeric conversion)
    scale = scale.to(tl.uint16).to(tl.float16, bitcast=True)
    shift = shift.to(tl.uint16).to(tl.float16, bitcast=True)
    return scale, shift
@triton.jit
def dequantize(
    x_,
    scale,
    shift,
    PACKED_PER_VAL: tl.constexpr = 8,
):
    """PACKED_PER_VAL is the number of values packed into each element x_.
    For example, for int4 quantization and x_ of type int32, PACKED_PER_VAL is 8.

    Returns ``unpacked * scale + shift``, where ``unpacked`` holds the 4-bit
    fields extracted from ``x_``.
    """

    # Axis along which offsets are applied matters here
    # It would be natural to have offsets in shape (BLOCK_N, D // PACKED_PER_VAL, PACKED_PER_VAL)
    # and expand K/V to that shape before applying offsets
    # However, Triton for some reason considers dim=1 as contiguous when doing tl.view below, and not dim=2
    # Note that tl.view doesn't guarantee the order of elements in the result - thus the code below depends
    # on the implementation details which might change in the future.
    # Ideally we would like to use tl.reshape, but it's not implemented yet.
    # See https://github.com/openai/triton/blob/9055af1a5dadc576804b38dd77ee91dc42af0bf7/python/triton/language/semantic.py#L541 # noqa: E501

    # x_ : (BLOCK_N, D // PACKED_PER_VAL)
    # scale: (BLOCK_N, 1)
    # offsets: (PACKED_PER_VAL,)
    BLOCK_N: tl.constexpr = x_.shape[0]
    BLOCK_DMODEL_PACKED: tl.constexpr = x_.shape[1]
    # Bit offset of each packed 4-bit value: 0, 4, 8, ...
    offsets = tl.arange(0, PACKED_PER_VAL) * 4
    quant_offset = (
        x_[:, None, :] >> offsets[None, :, None]
    )  # (BLOCK_N, PACKED_PER_VAL, D // PACKED_PER_VAL)

    quant_offset = tl.view(
        quant_offset, (BLOCK_N, BLOCK_DMODEL_PACKED * PACKED_PER_VAL)
    )
    # Trick - instead of converting int4 to float16 we view it as float16
    # and then multiply by 32768 * 512 == 2**24
    quant_offset = (quant_offset & 0xF).to(tl.uint16).to(tl.float16, bitcast=True)
    quant_offset = (quant_offset * 32768.0).to(tl.float16)
    # The remaining factor of 512 is folded into the scale
    scale_512 = scale * 512

    dequant = quant_offset * scale_512 + shift
    return dequant
@triton.jit
def _splitK_reduce(
    Out_splitK,  # [B, H, split_k, Mq, K]
    Metadata,  # [B, H, 2, split_k, M_ceil] contains [mi, li]
    Out,  # [B, H, M, K]
    LSE,  # [B, H, M]
    split_k,
    stride_osk_zhg,
    stride_osk_s,
    stride_osk_m,
    stride_osk_k,
    stride_mzhg,
    stride_m2,
    stride_ms,
    stride_mm,
    stride_oz,
    stride_oh,
    stride_og,
    stride_om,
    stride_ok,
    stride_lse_zhg,
    stride_lse_m,
    BLOCK_SIZE: tl.constexpr,
    H: tl.constexpr,
    G: tl.constexpr,
):
    # Merges the `split_k` partial results written by the split-K forward
    # kernel into the final output row and its log-sum-exp.
    # Each program handles one query row (program_id 1) of one flattened
    # batch/head/group index (program_id 0).
    off_zhg = tl.program_id(0)
    off_z = off_zhg // (H * G)
    off_h = (off_zhg // G) % H
    off_g = off_zhg % G
    off_m = tl.program_id(1)

    # One full output row of BLOCK_SIZE elements, starting with split 0
    Out_splitK_ptr = (
        Out_splitK
        + stride_osk_zhg * off_zhg
        + stride_osk_m * off_m
        + tl.arange(0, BLOCK_SIZE)
    )
    Metadata_ptr = Metadata + stride_mzhg * off_zhg + off_m
    # Running maximum, running sum and accumulator, initialized from split 0
    m = tl.load(Metadata_ptr)
    l_sum = tl.load(Metadata_ptr + stride_m2)
    acc = tl.load(Out_splitK_ptr)

    for split_k_idx in range(1, split_k):
        Metadata_ptr = Metadata_ptr + stride_ms
        Out_splitK_ptr = Out_splitK_ptr + stride_osk_s

        m_k = tl.load(Metadata_ptr)
        l_k = tl.load(Metadata_ptr + stride_m2)
        acc_k = tl.load(Out_splitK_ptr)

        # Bring both sides to the common maximum before adding them
        m_new = tl.maximum(m, m_k)
        if m_k < m:
            # Scale incoming values
            alpha = tl.math.exp2(m_k - m_new)
            acc_k = acc_k * alpha
            l_k = l_k * alpha
        else:
            # Scale our values
            alpha = tl.math.exp2(m - m_new)
            acc = acc * alpha
            l_sum = l_sum * alpha

        m = m_new
        l_sum = l_sum + l_k
        acc = acc + acc_k

    # Normalize by the total sum of exponentials
    acc = acc / l_sum
    Out_ptr = (
        Out
        + stride_oz * off_z
        + stride_oh * off_h
        + stride_og * off_g
        + stride_om * off_m
        + tl.arange(0, BLOCK_SIZE)
    )
    tl.store(Out_ptr, acc)

    l_ptrs = LSE + off_zhg * stride_lse_zhg + off_m
    # m and l_sum are in base 2 (exp2 is used throughout); dividing by
    # log2(e) = 1.44269504 converts the log-sum-exp to the natural base.
    tl.store(l_ptrs, (m + tl.math.log2(l_sum)) / 1.44269504)
else:
    # The `TYPE_CHECKING or _has_triton21()` condition above was false, so the
    # kernels are not defined; expose the names as None so that FwOp.OPERATOR
    # is None and FwOp.not_supported_reasons can report it.
    _fwd_kernel_splitK = None
    _splitK_reduce = None
@register_operator
class FwOp(AttentionFwOpBase):
    """Flash-Attention with Split-K. Supports fused int-4 K/V quantization.
    Quantized path will be taken if input K/V have type int32.

    Quantization can be row-wise or group-wise (when cls.NUM_GROUPS > 1) along
    the last dimension of K and V. Currently 1, 2, 4, or 8 groups per row are supported.
    Quantization coefficients (scale and shift) are represented as two
    float16 constants per group, packed into int32. Quantization coefficients of
    all groups are placed at the beginning of the row. So, if unquantized K/V have head
    dimension D, the quantized versions have head dimension D // 8 + NUM_GROUPS
    and dtype int32.
    Pseudocode for dequantizing one row can look like:
    group_size = D // 8 // NUM_GROUPS
    for i in range(NUM_GROUPS):
        scale, shift = unpack_int32_into_float16x2(K[..., i])
        group_start = NUM_GROUPS + i * group_size
        group_quant = K[..., group_start: group_start + group_size]
        group_dequant = unpack_int4x8(group_quant) * scale + shift
    ...

    """

    OPERATOR = _fwd_kernel_splitK
    SUPPORTED_DEVICES = {"cuda"}
    CUDA_MINIMUM_COMPUTE_CAPABILITY = (8, 0)
    SUPPORTED_DTYPES = {
        torch.half,
        torch.bfloat16,
    }  # Those are dtypes of Q. In the quantized case K/V has dtype int32
    SUPPORTED_MAX_K = 128
    SUPPORTED_ATTN_BIAS_TYPES: Set[Any] = {
        type(None),
        BlockDiagonalCausalWithOffsetPaddedKeysMask,
    }
    SUPPORTS_DROPOUT = False
    SUPPORTS_CUSTOM_SCALE = True
    SUPPORTS_BMGHK = True
    NAME = "triton_splitKF"

    # Number of splits along the key/value sequence; None means "use the
    # get_split_k heuristic" (see apply).
    SPLIT_K: Optional[int] = None
    # Tile sizes passed to the kernel as BLOCK_M / BLOCK_N
    BLOCK_M = 16
    BLOCK_N = 64

    NUM_GROUPS = 1  # Default quantization is row-wise

    @classmethod
    def shape_not_supported_reasons(
        cls, Mq: int, Mkv: int, K: int, Kv: int
    ) -> List[str]:
        """Extend the parent's shape checks: ``K`` must be 16, 32, 64 or 128."""
        reasons = super().shape_not_supported_reasons(Mq, Mkv, K, Kv)
        if K not in {16, 32, 64, 128}:
            reasons.append(f"Embed dim {K} not supported")
        return reasons

    @classmethod
    def not_supported_reasons(cls, d: Inputs) -> List[str]:
        """Collect the reasons why this operator cannot process ``d``.

        On top of the parent's checks: last-dimension alignment (key/value
        are only checked when the key is not int32, i.e. not quantized),
        kernel availability, compute capability, and restrictions on the
        query length with a causal mask or multiquery layout.
        """
        reasons = super(FwOp, cls).not_supported_reasons(d)
        check_lastdim_alignment_stride1(reasons, "query", d.query, 8)
        if d.key.dtype != torch.int32:
            check_lastdim_alignment_stride1(reasons, "key", d.key, 8)
            check_lastdim_alignment_stride1(reasons, "value", d.value, 8)
        if cls.OPERATOR is None:
            reasons.append("triton is not available")
        if d.device.type == "cuda":
            # Has only been tested on 8.0 / 9.0.
            if torch.cuda.get_device_capability(d.device) < (8, 0):
                reasons.append(
                    "requires GPU with sm80 minimum compute capacity, e.g., A100/H100/L4"
                )

        q_len = d.query.shape[1]
        if isinstance(d.attn_bias, BlockDiagonalCausalWithOffsetPaddedKeysMask):
            seqinfo = d.attn_bias.q_seqinfo
            if q_len != seqinfo.seqstart_py[-1]:
                reasons.append(
                    f"Expected total {seqinfo.seqstart_py[-1]} queries not {q_len}"
                )
            # From here on q_len is the per-sequence query length
            q_len = seqinfo.min_seqlen
            if q_len != seqinfo.max_seqlen:
                reasons.append(
                    "Variable query len is not supported in the presence of causal mask."
                )

        if d.key.ndim in [4, 5] and d.key.shape[-2] != 1:
            # Zero stride on the head dimension of both K and V: the heads
            # share the same memory (multiquery)
            if d.key.stride(-2) == 0 and d.value.stride(-2) == 0 and q_len > 1:
                reasons.append("multiquery is only supported with query seqlen=1")

        if d.attn_bias is not None and q_len > 1:
            reasons.append(
                "query with seqlen > 1 is not supported in the presence of causal mask"
            )
        return reasons

    @classmethod
    def get_split_k(cls, B: int, H: int, Mk: int) -> int:
        """Heuristic for the number of splits

        Starts from ``max(Mk, 1024) // (B * H)``, halves it while each split
        would cover fewer than ``max_chunk_size`` keys (64 when ``Mk <= 512``
        and ``B * H <= 64``, otherwise 128), then clamps the result to
        ``[1, 64]``.
        """
        bh = B * H
        split_k = max(Mk, 1024) // bh
        max_chunk_size = 64 if Mk <= 512 and bh <= 64 else 128
        while split_k > 0 and Mk / split_k < max_chunk_size:
            split_k = split_k // 2
        split_k = min(split_k, 64)
        split_k = max(split_k, 1)
        return split_k

    @classmethod
    def apply(
        cls, inp: Inputs, needs_gradient: bool
    ) -> Tuple[torch.Tensor, Optional[Context]]:
        """Run the split-K forward kernel followed by the reduction kernel.

        The first kernel writes one partial result per split into temporary
        float32 buffers; the second kernel merges them into ``out`` and
        ``lse``. A context holding both is returned regardless of
        ``needs_gradient``.
        """
        attn_bias = inp.attn_bias
        seq_len = None
        q, k, v = inp.get_qkv_in_bmghk()

        if attn_bias is not None:
            assert isinstance(attn_bias, BlockDiagonalCausalWithOffsetPaddedKeysMask)
            # TODO: do we really need to do this cast? seems fishy but
            # I just copied it from the decoder.py
            attn_bias.k_seqinfo.to(inp.query.device)
            attn_bias.q_seqinfo.to(inp.query.device)
            seq_len = attn_bias.k_seqinfo.seqlen
            B = len(seq_len)
            G, H, Kq = q.shape[-3:]
            Kkv = v.shape[-1]

            # assume kv has been padded
            q = q.reshape(B, -1, G, H, Kq)
            k = k.reshape(B, -1, G, H, Kkv)
            v = v.reshape(B, -1, G, H, Kkv)

        # Transpose in the case of MQA/GQA
        # (K/V expanded over heads with stride 0: the query heads are moved
        # into the sequence dimension and a single K/V head is kept)
        mqa_swap_seqlen_head = False
        if k.shape[3] > 1 and k.stride(3) == 0 and v.stride(3) == 0:
            mqa_swap_seqlen_head = True
            assert q.shape[1] == 1
            q = q.transpose(1, 3)
            k = k[:, :, :, :1]
            v = v[:, :, :, :1]

        if k.dtype == torch.int32:
            # Quantized K/V
            PACKED_PER_VAL = 8
            # Unpacked head dim: drop the NUM_GROUPS coefficient slots, then
            # 8 values per int32
            Lk = (k.shape[-1] - cls.NUM_GROUPS) * 8
        else:
            Lk = k.shape[-1]
            PACKED_PER_VAL = 1

        B, Mk, G, H, Kkv = k.shape
        B, M, G, H, Kq = q.shape
        assert Lk == Kq, f"Keys have head dim {Lk} but queries have head dim {Kq}"

        BLOCK_M = cls.BLOCK_M
        BLOCK_N = cls.BLOCK_N
        if cls.SPLIT_K is not None:
            split_k = cls.SPLIT_K
        else:
            # Use heuristics
            split_k = cls.get_split_k(B, H, Mk)

        # Query length rounded up to a multiple of BLOCK_M
        M_ceil = (M + BLOCK_M - 1) // BLOCK_M * BLOCK_M
        # Per-split partial outputs and [max, sum] metadata, in float32
        o_splitk = torch.empty(
            [B * G * H, split_k, M_ceil, Kq], dtype=torch.float32, device=q.device
        )
        metadata = torch.empty(
            [B * G * H, 2, split_k, M_ceil], dtype=torch.float32, device=q.device
        )
        lse = torch.empty((B * G * H, M), device=q.device, dtype=torch.float32)
        grid = (triton.cdiv(M, BLOCK_M), B * G * H, split_k)

        num_warps = 2
        # Number of keys handled by each split (rounded up)
        split_size = (Mk + split_k - 1) // split_k
        use_seq_len = seq_len is not None
        # Unroll the "VAR_ARGS_ARRAY" variables of the kernel, one entry per
        # quantization group (a single entry when not quantized)
        _fwd_kernel_splitK_unrolled = unroll_varargs(
            _fwd_kernel_splitK, N=cls.NUM_GROUPS if PACKED_PER_VAL > 1 else 1
        )

        _fwd_kernel_splitK_unrolled[grid](
            Q=q,
            K=k,
            V=v,
            sm_scale=inp.scale_float,
            Out_splitK=o_splitk,
            Metadata=metadata,
            Seq_len=seq_len,
            **_strides(q, "qz", "qm", "qg", "qh", "qk"),
            **_strides(k, "kz", "kn", "kg", "kh", "kk"),
            **_strides(v, "vz", "vn", "vg", "vh", "vk"),
            **_strides(o_splitk, "osk_zhg", "osk_s", "osk_m", "osk_k"),
            **_strides(metadata, "mzhg", "m2", "ms", "mm"),
            Z=B,
            H=H,
            G=G,
            N_CTX_Q=M,
            N_CTX_K=Mk,
            BLOCK_N_PER_SPLIT=split_size,
            BLOCK_M=BLOCK_M,
            BLOCK_N=BLOCK_N,
            BLOCK_DMODEL=Lk,
            BOUNDS_CHECKS_N=(split_size % BLOCK_N) > 0 or use_seq_len,
            USE_SEQ_LEN=use_seq_len,
            num_warps=num_warps,
            num_stages=1,
            PACKED_PER_VAL=PACKED_PER_VAL,
            N_GROUPS=cls.NUM_GROUPS if PACKED_PER_VAL > 1 else 1,
        )

        if mqa_swap_seqlen_head:
            # Allocated with dims 1 and 3 exchanged, then viewed back, so
            # that the final transpose below undoes the swap
            out = torch.empty(
                (B, H, G, M, Kq), device=q.device, dtype=q.dtype
            ).transpose(1, 3)
        else:
            out = torch.empty((B, M, G, H, Kq), device=q.device, dtype=q.dtype)

        # Merge together
        grid = (B * G * H, M, 1)
        _splitK_reduce[grid](
            o_splitk,
            metadata,
            out,
            lse,
            split_k=split_k,
            **_strides(o_splitk, "osk_zhg", "osk_s", "osk_m", "osk_k"),
            **_strides(metadata, "mzhg", "m2", "ms", "mm"),
            **_strides(out, "oz", "om", "og", "oh", "ok"),
            **_strides(lse, "lse_zhg", "lse_m"),
            BLOCK_SIZE=out.shape[-1],
            G=G,
            H=H,
            # TODO: Tune num_warps
        )
        lse = lse.reshape([B, G, H, M])
        if mqa_swap_seqlen_head:
            # H/M dimensions have been swapped
            out = out.transpose(1, 3)
            lse = lse.transpose(2, 3)
        if inp.query.ndim == 4:
            # BMGHK -> BMHK
            assert G == 1
            out = out[:, :, 0]
            lse = lse[:, 0]

        return out, Context(out=out, lse=lse)
class FwOp_S1(FwOp):
    # FwOp variant with the number of splits pinned to 1: a non-None SPLIT_K
    # makes FwOp.apply skip the get_split_k heuristic.
    NAME = "triton_splitK1"
    SPLIT_K = 1
class FwOp_S2(FwOp):
    # FwOp variant with the number of splits pinned to 2: a non-None SPLIT_K
    # makes FwOp.apply skip the get_split_k heuristic.
    NAME = "triton_splitK2"
    SPLIT_K = 2
class FwOp_S4(FwOp):
    # FwOp variant with the number of splits pinned to 4: a non-None SPLIT_K
    # makes FwOp.apply skip the get_split_k heuristic.
    NAME = "triton_splitK4"
    SPLIT_K = 4
class FwOp_S8(FwOp):
    # FwOp variant with the number of splits pinned to 8: a non-None SPLIT_K
    # makes FwOp.apply skip the get_split_k heuristic.
    NAME = "triton_splitK8"
    SPLIT_K = 8
class FwOp_S16(FwOp):
    # FwOp variant with the number of splits pinned to 16: a non-None SPLIT_K
    # makes FwOp.apply skip the get_split_k heuristic.
    NAME = "triton_splitK16"
    SPLIT_K = 16
class FwOp_S32(FwOp):
    # FwOp variant with the number of splits pinned to 32: a non-None SPLIT_K
    # makes FwOp.apply skip the get_split_k heuristic.
    NAME = "triton_splitK32"
    SPLIT_K = 32
class FwOp_S64(FwOp):
    # FwOp variant with the number of splits pinned to 64: a non-None SPLIT_K
    # makes FwOp.apply skip the get_split_k heuristic.
    NAME = "triton_splitK64"
    SPLIT_K = 64
class FwOp_S128(FwOp):
    # FwOp variant with the number of splits pinned to 128: a non-None SPLIT_K
    # makes FwOp.apply skip the get_split_k heuristic (which never goes above 64).
    NAME = "triton_splitK128"
    SPLIT_K = 128

View File

@@ -0,0 +1,223 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
from typing import Optional, Sequence
import torch
from .common import BaseOperator, get_xformers_operator, register_operator
@register_operator
class ScaledIndexAddFw(BaseOperator):
    # Registers the "scaled_index_addF" xformers operator under the
    # "indexing" category; it is called by _ScaledIndexAdd.forward and
    # _IndexSelectCat.backward.
    NAME = "scaled_index_addF"
    OPERATOR_CATEGORY = "indexing"
    OPERATOR = get_xformers_operator("scaled_index_addF")
@register_operator
class ScaledIndexAddBw(BaseOperator):
    # Registers the "scaled_index_addB" xformers operator under the
    # "indexing" category; it is called by _ScaledIndexAdd.backward.
    NAME = "scaled_index_addB"
    OPERATOR_CATEGORY = "indexing"
    OPERATOR = get_xformers_operator("scaled_index_addB")
@register_operator
class IndexSelect(BaseOperator):
    # Registers the "index_select" xformers operator under the "indexing"
    # category; it is called by _IndexSelectCat.forward.
    NAME = "index_select"
    OPERATOR_CATEGORY = "indexing"
    OPERATOR = get_xformers_operator("index_select")
class _ScaledIndexAdd(torch.autograd.Function):
    # Autograd wrapper around the ScaledIndexAddFw / ScaledIndexAddBw operators.

    @staticmethod
    # type: ignore
    def forward(
        ctx,
        input: torch.Tensor,
        index: torch.Tensor,
        source: torch.Tensor,
        scaling: Optional[torch.Tensor],
        alpha: float,
    ) -> torch.Tensor:
        """Apply the forward operator in place on ``input`` and return it."""
        # `input` is passed both as the operand and as the destination.
        ScaledIndexAddFw.OPERATOR(
            output=input,  # in-place
            input=input,
            source=source,
            index=index,
            source_scaling=scaling,
            alpha=alpha,
        )
        # Record what backward needs.
        ctx.alpha = alpha
        ctx.source_shape = source.shape
        ctx.save_for_backward(index, scaling, source)
        # Tell autograd that `input` was modified in place.
        ctx.mark_dirty(input)
        return input

    @staticmethod
    @torch.autograd.function.once_differentiable
    def backward(ctx, grad_output):
        """Return gradients for ``(input, index, source, scaling, alpha)``.

        ``grad_output`` is passed through unchanged as the gradient of
        ``input``; ``index`` and ``alpha`` get no gradient.
        """
        index, scaling, source = ctx.saved_tensors
        grad_source = torch.empty_like(grad_output[: index.shape[0]])
        if scaling is None:
            scaling_grad_buffer = None
        else:
            # Buffer with the shape of `source`; reduced over the first two
            # dimensions after the operator call.
            scaling_grad_buffer = torch.empty(
                ctx.source_shape,
                dtype=scaling.dtype,
                device=scaling.device,
            )
        ScaledIndexAddBw.OPERATOR(
            grad_source=grad_source,
            grad_source_scaling=scaling_grad_buffer,
            grad_output=grad_output,
            source=source,
            index=index,
            source_scaling=scaling,
            alpha=ctx.alpha,
        )
        grad_scaling = (
            None if scaling_grad_buffer is None else scaling_grad_buffer.sum((0, 1))
        )
        return (
            grad_output,  # input
            None,  # index
            grad_source,  # source
            grad_scaling,  # scaling
            None,  # alpha
        )
def scaled_index_add(
    input: torch.Tensor,  # [B, M, D]
    index: torch.Tensor,  # [Bi] - int64
    source: torch.Tensor,  # [Bi, M, D]
    scaling: Optional[torch.Tensor] = None,  # [D]
    alpha: float = 1.0,
) -> torch.Tensor:
    """
    In-place scaling+index_add

    Indices in ``index`` are assumed to be unique

    :Note:

        The FW pass is done in-place (``input`` is modified)

    :Note:

        This is experimental and has only been optimized for a few shapes

    :Equivalent pytorch code:

    .. code-block:: python

        return torch.index_add(input, dim=0, source=scaling * source, index=index, alpha=alpha)
    """
    # torch.autograd.Function.apply only takes positional arguments.
    args = (input, index, source, scaling, alpha)
    return _ScaledIndexAdd.apply(*args)
class _IndexSelectCat(torch.autograd.Function):
    # Autograd wrapper: index_select on several sources, written into one
    # flat output buffer.

    @staticmethod
    # type: ignore
    def forward(
        ctx,
        *args: torch.Tensor,
    ) -> torch.Tensor:
        """Select rows from each source and concatenate them in a flat tensor.

        ``args`` holds all the sources first, then the matching indices.
        """
        assert len(args) % 2 == 0
        half = len(args) // 2
        sources, indices = args[:half], args[half:]
        pairs = list(zip(sources, indices))

        # Number of output elements contributed by each (source, index) pair
        chunk_sizes = [idx.shape[0] * src.shape[1] for src, idx in pairs]
        total_source_elements = sum(src.shape[0] * src.shape[1] for src, _ in pairs)

        output = torch.empty(
            [sum(chunk_sizes)], dtype=sources[0].dtype, device=sources[0].device
        )
        offset = 0
        for (src, idx), size in zip(pairs, chunk_sizes):
            # The operator writes directly into a 2D view of the flat buffer
            destination = output[offset : offset + size].view(
                [idx.shape[0], src.shape[1]]
            )
            IndexSelect.OPERATOR(
                output=destination,
                source=src,
                index=idx,
            )
            offset += size

        ctx.save_for_backward(*indices)
        ctx.total_source_elements = total_source_elements
        ctx.source_shapes = [s.shape for s in sources]
        return output

    @staticmethod
    @torch.autograd.function.once_differentiable
    def backward(ctx, grad_output):
        """Scatter ``grad_output`` back to one gradient per source.

        Returns the source gradients followed by one ``None`` per index.
        """
        indices = ctx.saved_tensors
        # Single zero-initialized buffer; each source gradient is a view in it
        flat_grads = torch.zeros(
            [ctx.total_source_elements],
            dtype=grad_output.dtype,
            device=grad_output.device,
        )
        source_offset = 0
        output_offset = 0
        gradients = []
        for shape, idx in zip(ctx.source_shapes, indices):
            selected = idx.shape[0] * shape[1]
            incoming = grad_output[output_offset : output_offset + selected].reshape(
                [idx.shape[0], shape[1]]
            )
            output_offset += selected

            source_numel = shape[0] * shape[1]
            source_grad = flat_grads[
                source_offset : source_offset + source_numel
            ].reshape(shape)
            source_offset += source_numel

            ScaledIndexAddFw.OPERATOR(
                output=source_grad.unsqueeze(1),
                input=None,
                source=incoming.unsqueeze(1),
                index=idx,
                source_scaling=None,
                alpha=1.0,
            )
            gradients.append(source_grad)
        return tuple(gradients) + (None,) * len(gradients)
def index_select_cat(
    sources: Sequence[torch.Tensor], indices: Sequence[torch.Tensor]
) -> torch.Tensor:
    """
    Select rows from every source and concatenate the flattened results.

    Indices in ``index`` are assumed to be unique

    :Note:
        This is experimental and has only been optimized for a few shapes

    :Example:

    Given:
        - ``sources[0]`` of shape ``[S0, D0]``
        - ``indices[0]`` of shape ``[I0]``
        - ``sources[1]`` of shape ``[S1, D1]``
        - ``indices[1]`` of shape ``[I1]``
    returns a ``torch.Tensor`` of shape ``[I0 * D0 + I1 * D1]``

    :Raises:
        ValueError: if ``sources`` and ``indices`` do not have the same length

    :Equivalent pytorch code:

    .. code-block:: python

        return torch.cat([s[i.long()].flatten() for s, i in zip(sources, indices)], dim=0)
    """
    sources = tuple(sources)
    indices = tuple(indices)
    # The autograd function receives one flat argument list and splits it in
    # half, so a length mismatch would silently pair sources with the wrong
    # indices whenever the total count happens to be even.
    if len(sources) != len(indices):
        raise ValueError(
            f"Expected as many sources as indices, got {len(sources)} sources "
            f"and {len(indices)} indices"
        )
    return _IndexSelectCat.apply(*sources, *indices)

View File

@@ -0,0 +1,113 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
from typing import Optional
import torch
from torch import nn
from .. import _is_triton_available
def rms_norm(x, weight: Optional[torch.Tensor], eps: float = 1e-6):
    """
    RMS Normalization along the last dimension.

    This is similar to torch.nn.functional.normalize but with eps being added
    instead of max.

    Expects x contiguous of shape (..., dim), and returns normalized data
    of the same shape. For each dim-length vector x, the result is

        x / sqrt(mean(x * x) + eps)

    If weights are included, they are a contiguous parameter of length dim
    which multiplies the result.

    This functionality is experimental. Its API might be changed without warnings.
    Use it at your own risk.

    Raises:
        ValueError: if grad mode is enabled and ``x`` or ``weight`` requires grad.
    """
    assert _is_triton_available()
    from .triton.rmsnorm_kernels import _rms_norm_forward

    # Forward-only implementation: refuse to run where autograd would expect
    # a backward pass.
    if torch.is_grad_enabled():
        weight_needs_grad = weight is not None and weight.requires_grad
        if x.requires_grad or weight_needs_grad:
            raise ValueError("Gradients not supported.")

    return _rms_norm_forward(x, weight, eps)
def rms_norm_add(
    x: torch.Tensor, y: torch.Tensor, weight: Optional[torch.Tensor], eps: float = 1e-6
):
    """
    An addition fused with rms_norm.

        z = rms_norm_add(x, y, weight, eps)

    is equivalent to

        x += y
        z = rms_norm(x, weight, eps)

    where x, y and z are all contiguous.

    This functionality is experimental. Its API might be changed without warnings.
    Use it at your own risk.

    Raises:
        ValueError: if grad mode is enabled and any of ``x``, ``y`` or
            ``weight`` requires grad.
    """
    # Forward-only implementation: reject inputs that take part in autograd.
    if torch.is_grad_enabled():
        weight_needs_grad = weight is not None and weight.requires_grad
        if x.requires_grad or y.requires_grad or weight_needs_grad:
            raise ValueError("Gradients not supported.")

    assert _is_triton_available()
    from .triton.rmsnorm_kernels import _rms_norm_add_forward

    return _rms_norm_add_forward(x, y, weight, eps)
class RMSNorm(torch.nn.Module):
    """
    RMS Normalization layer along the last dimension.

    This is similar to torch.nn.functional.normalize but with eps being added
    instead of max.

    Expects contiguous input of shape (..., dim), and returns normalized data
    of the same shape. For each dim-length vector x, the result is

        x / sqrt(mean(x * x) + eps)

    If weights are included, they are a parameter of length dim which multiplies
    the result.

    This functionality is experimental. Its API might be changed without warnings.
    Use it at your own risk.
    """

    def __init__(self, dim: int, include_weight: bool = True, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        # The optional scale starts as all ones; without it the attribute is
        # a plain ``None``.
        self.weight: Optional[nn.Parameter] = (
            nn.Parameter(torch.ones(dim)) if include_weight else None
        )

    def forward(self, x: torch.Tensor):
        """Return ``rms_norm`` of ``x`` using this layer's weight and eps."""
        return rms_norm(x, self.weight, self.eps)  # type: ignore

    def increment_and_forward_(self, x: torch.Tensor, y: torch.Tensor):
        """
        An addition fused with forward.

            z = layer.increment_and_forward_(x, y)

        is equivalent to

            x += y
            z = layer(x)
        """
        return rms_norm_add(x, y, self.weight, self.eps)  # type: ignore

View File

@@ -0,0 +1,188 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
from typing import Optional, Tuple
import torch
from xformers.ops.fmha.attn_bias import ( # type: ignore
BlockDiagonalCausalWithOffsetPaddedKeysMask,
)
from .. import _is_triton_available
def rope_padded(
    xq: torch.Tensor,
    xk: torch.Tensor,
    xv: torch.Tensor,
    cache_k: torch.Tensor,
    cache_v: torch.Tensor,
    attn_bias: BlockDiagonalCausalWithOffsetPaddedKeysMask,
    *,
    theta: float = 10000.0,
    out_q: Optional[torch.Tensor] = None,
    adjacents: bool = True,
    internal_dtype: str = "",
):
    """
    Performs RoPE (rotary embeddings) and kv-cache emplacement for a heterogeneous
    batch for inference in the style given by
    BlockDiagonalCausalWithOffsetPaddedKeysMask.
    The batch is concatted along the sequence dimension, so the
    actual dim-0 length of all tensors is 1.

    xq, xk and xv should be (1, slen, n_heads, dim), where
    xq's n_heads can differ from xk and xv.

    This function places the roped xk in the right place in cache_k, and
    xv (unmodified) in the right place in cache_v, and returns out_q
    (the roped xq) such that things are ready to call

    xformers.ops.memory_efficient_attention(
        out_q, cache_k, cache_v, attn_bias=attn_bias
    )

    This functionality is experimental. Its API might be changed without warnings.
    Use it at your own risk.

    Arguments:
        xq: tensor of queries to apply rope to
        xk: tensor of keys to apply rope to
        xv: tensor of values to copy into cache_v
        cache_k: cache of keys, MODIFIED IN PLACE
        cache_v: cache of values, MODIFIED IN PLACE
        attn_bias: details the layout of caches.
                Used to determine frequencies for the
                RoPE calculation as well as the locations in cache_k and cache_v
                to write to. Must be on the device.
        theta: base of the rotary frequencies, forwarded to the kernel.
        out_q: optional preallocated output of the same shape as xq. Note that
                passing it while grad mode is enabled raises ValueError.
        adjacents: If True, the inputs are in adjacent pairs along the final dim axis.
                  This is like the released LLaMA model.
                  If False, the dim axis is split in two equal pieces.
                   I.e. the features are ordered with all the real parts before all
                   the imaginary parts. This matches HuggingFace, e.g.
                   https://github.com/huggingface/transformers/blob/
                   f143037789288ba532dada934a118e648e715738/
                   src/transformers/models/llama/modeling_llama.py#L126-L130
        internal_dtype: set to "f32" or "f64" to enforce dtype in the calculation

    Returns:
        out_q, the roped queries (a new tensor unless ``out_q`` was supplied).

    Raises:
        ValueError: on unsupported gradients, unexpected shapes, or a last
            dimension that is not contiguous.
    """
    if torch.is_grad_enabled() and (
        xq.requires_grad
        or xk.requires_grad
        or xv.requires_grad
        or cache_k.requires_grad
        or cache_v.requires_grad
        or out_q is not None
    ):
        raise ValueError("Gradients not supported.")
    assert _is_triton_available()
    import triton

    from .triton.rope_padded_kernels import _rope_padded_kernel

    n_total_queries = attn_bias.q_seqinfo.seqstart_py[-1]
    cache_length = attn_bias.k_seqinfo.seqstart_py[-1]

    bsz, q_len, n_q_heads, dim = xq.shape
    assert q_len == n_total_queries
    if bsz != 1:
        # The two literals are concatenated, so the separating space must be
        # part of one of them.
        raise ValueError(
            "Expected batch size dimension to be 1 "
            "as batches should be concatenated."
        )
    xk_shape = xk.shape
    n_kv_heads = xk_shape[2]
    if xk_shape != (1, n_total_queries, n_kv_heads, dim):
        raise ValueError("unexpected k shape")
    if xv.shape != (1, n_total_queries, n_kv_heads, dim):
        raise ValueError("unexpected v shape")
    if cache_k.shape != (1, cache_length, n_kv_heads, dim):
        raise ValueError("unexpected cache_k length")
    if cache_v.shape != (1, cache_length, n_kv_heads, dim):
        raise ValueError("unexpected cache_v length")

    xq_stride = xq.stride()
    xk_stride = xk.stride()
    xv_stride = xv.stride()
    cache_k_stride = cache_k.stride()
    cache_v_stride = cache_v.stride()

    # The kernel addresses the feature dimension with unit stride.
    if xq_stride[3] != 1:
        raise ValueError("Each q head must be contiguous")
    if xk_stride[3] != 1:
        raise ValueError("Each k head must be contiguous")
    if xv_stride[3] != 1:
        raise ValueError("Each v head must be contiguous")
    if cache_k_stride[3] != 1:
        raise ValueError("Each cache_k head must be contiguous")
    if cache_v_stride[3] != 1:
        raise ValueError("Each cache_v head must be contiguous")

    # The third grid axis enumerates q heads, then k heads, then v heads.
    n_total_heads = n_q_heads + 2 * n_kv_heads
    v_start = n_total_heads - n_kv_heads
    k_start = n_q_heads
    if out_q is None:
        out_q = xq.new_empty(1, n_total_queries, n_q_heads, dim)
        out_q_stride: Tuple[int, ...] = (0, n_q_heads * dim, dim, 1)
    else:
        if out_q.shape != xq.shape:
            raise ValueError("Unexpected shape of out_q")
        out_q_stride = out_q.stride()
        if out_q_stride[3] != 1:
            raise ValueError("Each out_q head must be contiguous")

    assert out_q is not None

    logical_bsz = len(attn_bias.q_seqinfo.seqstart_py) - 1

    # Less than 64KB per feature: enqueue fused kernel
    MAX_FUSED_SIZE = 65536 // xq.element_size()
    BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(dim))
    BLOCK_SIZE = max(BLOCK_SIZE, 128)
    BLOCK_SIZE = min(BLOCK_SIZE, 4096)
    # heuristics for number of warps
    num_warps = min(max(BLOCK_SIZE // 256, 1), 8)

    device = xq.device
    # Move these to the right device, like fmha does.
    attn_bias.k_seqinfo.to(device)
    attn_bias.q_seqinfo.to(device)
    seqstartq = attn_bias.q_seqinfo.seqstart
    seqstartk = attn_bias.k_seqinfo.seqstart
    seqlenk = attn_bias.k_seqinfo.seqlen

    assert internal_dtype in ["", "f32", "f64"]
    # experiment with the order of dims here.
    _rope_padded_kernel[(logical_bsz, attn_bias.q_seqinfo.max_seqlen, n_total_heads)](
        xq,
        xk,
        xv,
        out_q,
        cache_k,
        cache_v,
        seqstartq,
        seqstartk,
        seqlenk,
        theta,
        k_start,
        v_start,
        dim,
        xq_stride[1],
        xq_stride[2],
        xk_stride[1],
        xk_stride[2],
        xv_stride[1],
        xv_stride[2],
        cache_k_stride[1],
        cache_k_stride[2],
        cache_v_stride[1],
        cache_v_stride[2],
        seqstartq.stride(0),
        seqstartk.stride(0),
        seqlenk.stride(0),
        out_q_stride[1],
        out_q_stride[2],
        internal_dtype,
        const_batch_strides=False,
        cache_padding_length=0,
        seqlenk_shift=0,
        BLOCK_SIZE=BLOCK_SIZE,
        adjacents=adjacents,
        num_warps=num_warps,
    )
    return out_q

View File

@@ -0,0 +1,467 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
from dataclasses import dataclass
from typing import Dict, Optional, Sequence, Tuple, Union
import torch
import torch.nn.functional as F
from torch import nn
from .common import BaseOperator, get_xformers_operator, register_operator
from .unbind import stack_or_none, unbind
@register_operator
class DualGemmSiluOp(BaseOperator):
    """Registry entry for the ``dual_gemm_silu_identity_mul`` operator."""

    OPERATOR = get_xformers_operator("dual_gemm_silu_identity_mul")
    OPERATOR_CATEGORY = "swiglu"
    NAME = "dual_gemm_silu"

    @classmethod
    # type: ignore
    def operator_flop(
        cls, x: torch.Tensor, w1: torch.Tensor, b1, w2: torch.Tensor, b2
    ) -> int:
        """NOTE: we neglect the impact of biases / pointwises"""
        rows = x.shape[0]
        out_features = w1.shape[0]
        in_features = w1.shape[1]
        # 2 flops per multiply-add, times two GEMMs.
        return rows * out_features * in_features * 2 * 2
@register_operator
class GemmFusedSumOp(BaseOperator):
    """Registry entry for the ``gemm_fused_operand_sum`` operator."""

    OPERATOR = get_xformers_operator("gemm_fused_operand_sum")
    OPERATOR_CATEGORY = "swiglu"
    NAME = "gemm_fused_operand_sum"

    @classmethod
    # type: ignore
    def operator_flop(cls, a: torch.Tensor, b: torch.Tensor, out1, out2) -> int:
        """Flop count of the ``a @ b`` product (2 flops per multiply-add)."""
        rows = a.shape[0]
        cols = b.shape[1]
        inner = a.shape[1]
        return rows * cols * inner * 2
class _SwiGLUDecomposedFunc(torch.autograd.Function):
"""
This is just an example implementation with all
operations explicited. This implementation is worse
than pytorch, because pytorch is able to fuse some operations
(eg the linear forward ...) that are decomposed here.
The time measurements were made on the ViT-Giant setting:
- A100/f16
- input: [4440, 1536]
- hidden: [4440, 4096]
"""
NAME = "decomposed"
FORCE_BW_F32 = False
def _silu_backward(dy, x):
# https://github.com/pytorch/pytorch/blob/563b065f5a4b4055fa6b025c2514b566d5fd9439/aten/src/ATen/native/Activation.cpp#L483
sigm = 1 / (1 + torch.exp(-x.float()))
return (dy.float() * sigm * (1 + x.float() * (1 - sigm))).to(x.dtype)
# 952us
@classmethod
def forward(cls, ctx, x, w1, b1, w2, b2, w3, b3):
x1 = x @ w1.transpose(-2, -1) + b1 # 275us
x2 = x @ w2.transpose(-2, -1) + b2 # 275us
x3 = F.silu(x1) # 62us
x4 = x3 * x2 # 90us
x5 = x4 @ w3.transpose(-2, -1) + b3 # 250us
ctx.save_for_backward(x, w1, b1, w2, b2, w3, b3, x1, x2, x3, x4, x5)
return x5
# 1900us
@classmethod
def backward(cls, ctx, dx5):
saved_tensors = ctx.saved_tensors
if cls.FORCE_BW_F32:
dx5 = dx5.float()
saved_tensors = [t.float() for t in ctx.saved_tensors]
x, w1, b1, w2, b2, w3, b3, x1, x2, x3, x4, x5 = saved_tensors
dx4 = dx5 @ w3 # 255us (nn)
dw3 = dx5.transpose(-2, -1) @ x4 # 247us (nt)
db3 = dx5.sum(0) # 25us
dx3 = dx4 * x2 # 88us
dx2 = dx4 * x3 # 88us
dx1 = cls._silu_backward(dx3, x1) # 90us
dx = dx2 @ w2 # 260us (nn)
dw2 = dx2.transpose(-2, -1) @ x # 245us (nt)
db2 = dx2.sum(0) # 50us
dx += dx1 @ w1 # 260us (nn)
dw1 = dx1.transpose(-2, -1) @ x # 245us (nt)
db1 = dx1.sum(0) # 50us
return (dx, dw1, db1, dw2, db2, dw3, db3)
class _SwiGLUFusedFunc(torch.autograd.Function):
    """
    SwiGLU autograd function built on the fused xformers operators.

    Forward uses ``DualGemmSiluOp.OPERATOR`` for the first two linear layers
    plus the gated activation, then ``F.linear`` for the last layer. Backward
    uses ``torch.ops.xformers.silu_bw_fused`` and, when a bias gradient is
    needed, ``GemmFusedSumOp.OPERATOR``.
    """

    NAME = "fused.py"

    @classmethod
    @torch.cuda.amp.custom_fwd
    def forward(cls, ctx, x, w1, b1, w2, b2, w3, b3):
        x1, x2, x4 = DualGemmSiluOp.OPERATOR(x, w1, b1, w2, b2)

        x5 = F.linear(x4, w3, b3)
        ctx.save_for_backward(x, w1, w2, w3, x1, x2)
        # Only remember WHETHER each bias exists; their values are not needed
        # in the backward pass.
        ctx.bias = [b1 is not None, b2 is not None, b3 is not None]
        return x5

    @staticmethod
    def _linear_bw(
        dy: torch.Tensor, x: torch.Tensor, bias: bool
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Return ``(dw, db)`` of a linear layer; ``db`` is None without bias."""
        if not bias:
            return (dy.transpose(-2, -1) @ x), None
        db = torch.empty([dy.shape[1]], dtype=dy.dtype, device=dy.device)
        dw = torch.empty([dy.shape[1], x.shape[1]], dtype=dy.dtype, device=dy.device)
        # The operator fills both preallocated outputs in one call.
        GemmFusedSumOp.OPERATOR(dy.transpose(-2, -1), x, dw, db)
        return dw, db

    @classmethod
    @torch.cuda.amp.custom_bwd
    def backward(cls, ctx, dx5):
        x, w1, w2, w3, x1, x2 = ctx.saved_tensors
        w1w2 = stack_or_none([w1, w2], dim=0)

        dx4 = dx5 @ w3  # 255us (nn)
        # NOTE(review): x4 is presumably recomputed by the operator rather
        # than saved in forward - confirm against the kernel.
        dx1dx2, x4 = torch.ops.xformers.silu_bw_fused(x1, x2, dx4)
        dx1, dx2 = dx1dx2.unbind(1)
        del x1, x2, dx4

        dw3, db3 = cls._linear_bw(dx5, x4, bias=ctx.bias[2])
        del x4, dx5
        if w1w2 is not None:
            assert dx1dx2.is_contiguous()
            assert w1w2.is_contiguous()
            w1w2 = w1w2.view([w1.shape[0] * 2, w1.shape[1]])
            dx1dx2_2d = dx1dx2.view([dx1.shape[0], 2 * dx1.shape[1]])
            dx = dx1dx2_2d @ w1w2

            # backward of linear1 + linear2 - packed
            # (A separate ``dx1dx2_2d.T @ x`` GEMM used to be computed here
            # and then immediately overwritten by the call below.)
            dw1dw2, db1db2 = cls._linear_bw(dx1dx2_2d, x, bias=ctx.bias[0])
            dw1, dw2 = dw1dw2.view([2, *w1.shape]).unbind(0)
            if ctx.bias[0]:
                db1db2 = db1db2.view([2, dx1.shape[1]])
                db1, db2 = torch.unbind(db1db2, dim=0)
            else:
                db1 = db2 = None
        else:
            dx = dx2 @ w2  # 260us (nn)
            torch.addmm(
                dx, dx1, w1.to(dx1.dtype), beta=1, alpha=1, out=dx
            )  # dx += dx1 @ w1
            dw2, db2 = cls._linear_bw(dx2, x, bias=ctx.bias[1])
            dw1, db1 = cls._linear_bw(dx1, x, bias=ctx.bias[0])
        return (dx, dw1, db1, dw2, db2, dw3, db3)
class SwiGLUOp:
    """Base class for any swiglu operator in :attr:`xformers.ops.swiglu`"""

    def __init__(self, op, packed_weights: bool, name: str, constraints):
        self.NAME = name
        self.PACKED_WEIGHTS = packed_weights
        self.op = op
        self.constraints = constraints

    def supports(self, op: "SwiGLUOpDispatch") -> bool:
        """Return True if this operator can serve the dispatch request ``op``."""
        # An operator that needs packed weights cannot run on unpacked ones.
        if self.PACKED_WEIGHTS and not op.packed_weights:
            return False
        # Every constraint must accept the request.
        for constraint in self.constraints:
            if not constraint(op):
                return False
        return True

    def __call__(self, *args: Optional[torch.Tensor]) -> torch.Tensor:
        # Placeholder only; the subclasses provide the actual call.
        pass

    def __str__(self) -> str:
        return f"SwiGLUOp:{self.NAME}"
class _ForwardToPythonAutogradFunc(SwiGLUOp):
    """SwiGLU operator whose ``op`` is invoked through its ``apply`` method."""

    def supports(self, op: "SwiGLUOpDispatch") -> bool:
        # Let's disable autocast in bf16 until this issue is fixed
        # https://github.com/pytorch/pytorch/issues/87979
        autocast_is_bf16 = op.dtype_autocast_gpu == torch.bfloat16
        return False if autocast_is_bf16 else super().supports(op)

    def __call__(self, *args, **kwargs):
        apply = self.op.apply
        return apply(*args, **kwargs)
class _ForwardToFunc(SwiGLUOp):
    """SwiGLU operator whose ``op`` is a plain callable."""

    def __call__(self, *args, **kwargs):
        func = self.op
        return func(*args, **kwargs)

    def info(self):
        """Return "not built" if ``op`` is named "no_such_operator", else "available"."""
        missing = self.op.__name__ == "no_such_operator"
        return "not built" if missing else "available"
def _eager_functional_swiglu(
x: torch.Tensor,
w1: torch.Tensor,
b1: torch.Tensor,
w2: torch.Tensor,
b2: torch.Tensor,
w3: torch.Tensor,
b3: torch.Tensor,
) -> torch.Tensor:
x1 = F.linear(x, w1, b1)
x2 = F.linear(x, w2, b2)
hidden = F.silu(x1) * x2
return F.linear(hidden, w3, b3)
@dataclass
class SwiGLUOpDispatch:
    """Dispatcher to automatically select
    the best operator in :attr:`xformers.ops.swiglu`
    """

    # Device and dtype of the input tensor.
    device: Union[torch.device, str]
    dtype: torch.dtype
    # Autocast dtype when autocast is enabled (see ``from_arguments``).
    dtype_autocast_gpu: Optional[torch.dtype]
    # Whether w1 and w2 are views into one stacked tensor.
    packed_weights: bool
    # Whether all three biases are present.
    bias_enabled: bool

    @property
    def op(self) -> SwiGLUOp:
        """Computes the best operator

        Returns:
            SwiGLUOp: The best operator for the configuration
        """
        # Candidates in decreasing order of preference; the eager
        # implementation is the fallback.
        for candidate in (SwiGLUPackedFusedOp, SwiGLUFusedOp):
            if candidate.supports(self):
                return candidate
        return SwiGLUEagerOp

    @staticmethod
    def from_arguments(
        x: torch.Tensor,
        w1: torch.Tensor,
        b1: Optional[torch.Tensor],
        w2: torch.Tensor,
        b2: Optional[torch.Tensor],
        w3: torch.Tensor,
        b3: Optional[torch.Tensor],
    ) -> "SwiGLUOpDispatch":
        """Build the dispatch description for a concrete ``swiglu`` call."""
        if torch.is_autocast_enabled():
            autocast_dtype = torch.get_autocast_gpu_dtype()
        else:
            # Without autocast, the dtype of the weights is recorded instead.
            autocast_dtype = w1.dtype

        weights_are_packed = stack_or_none((w1, w2), dim=0) is not None
        has_all_biases = b1 is not None and b2 is not None and b3 is not None

        return SwiGLUOpDispatch(
            device=x.device,
            dtype=x.dtype,
            dtype_autocast_gpu=autocast_dtype,
            packed_weights=weights_are_packed,
            bias_enabled=has_all_biases,
        )
def _only_sm80(op: SwiGLUOpDispatch) -> bool:
device_type = op.device if isinstance(op.device, str) else op.device.type
return device_type == "cuda" and torch.cuda.get_device_capability(op.device)[0] >= 8
def _only_half_or_autocast(op: SwiGLUOpDispatch) -> bool:
    """Constraint: the input dtype or the autocast dtype is half/bfloat16."""
    half_dtypes = [torch.half, torch.bfloat16]
    if op.dtype in half_dtypes:
        return True
    autocast_dtype = op.dtype_autocast_gpu
    return autocast_dtype is not None and autocast_dtype in half_dtypes
def _bias_enabled(op: SwiGLUOpDispatch) -> bool:
"""Constraint: returns the ``bias_enabled`` flag of the dispatch request."""
return op.bias_enabled
# Operator instances. Each one pairs an implementation with the constraints
# that must all hold for it to be eligible (see ``SwiGLUOp.supports``).

# Explicit step-by-step implementation; requires all biases.
_SwiGLUDecomposedOp = _ForwardToPythonAutogradFunc(
    op=_SwiGLUDecomposedFunc,
    packed_weights=False,
    name="decomposed",
    constraints=[_bias_enabled],
)

# Python autograd function built on the fused operators.
SwiGLUFusedOp = _ForwardToPythonAutogradFunc(
    op=_SwiGLUFusedFunc,
    packed_weights=False,
    name="fused",
    constraints=[_only_sm80, _only_half_or_autocast],
)

# ``swiglu_packedw`` xformers operator; needs packed w1/w2.
SwiGLUPackedFusedOp = _ForwardToFunc(
    op=get_xformers_operator("swiglu_packedw"),
    packed_weights=True,
    name="fused.p.cpp",
    constraints=[_only_sm80, _only_half_or_autocast],
)

# Functional PyTorch implementation with no constraints.
SwiGLUEagerOp = _ForwardToFunc(
    op=_eager_functional_swiglu,
    packed_weights=False,
    name="eager",
    constraints=[],
)
def _info() -> Dict[str, str]:
"""Map operator name to its ``info()`` string; only SwiGLUPackedFusedOp is reported."""
return {op.NAME: op.info() for op in [SwiGLUPackedFusedOp]}
def swiglu(
    x: torch.Tensor,
    w1: torch.Tensor,
    b1: Optional[torch.Tensor],
    w2: torch.Tensor,
    b2: Optional[torch.Tensor],
    w3: torch.Tensor,
    b3: Optional[torch.Tensor],
    *,
    op: Optional["SwiGLUOp"] = None,
) -> torch.Tensor:
    """
    Computes a SwiGLU block given the weights/bias of the 3
    linear layers.

    - It is recommended to keep ``op=None`` so the best implementation \
    available for the inputs will be used.


    :Equivalent pytorch code:

    .. code-block:: python

        x1 = F.linear(x, w1, b1)
        x2 = F.linear(x, w2, b2)
        hidden = F.silu(x1) * x2
        return F.linear(hidden, w3, b3)

    :Packing weights:

    To allow faster implementations, it's recommended to have w1/w2 come from the same storage, as in:
        .. code-block:: python

            w1, w2 = xformers.ops.unbind(w12, 0)

    :Supported hardware:

    This operator is only optimized on A100+ on ``torch.half`` or ``torch.bfloat16`` \
        (autocast is supported), and will fallback to a functional pytorch \
        implementation otherwise.

    :Returns:

    A tensor of shape ``[*x.shape[:-1], w3.shape[0]]``.

    :Raises:

    - ``ValueError`` if the weight/bias shapes are inconsistent
    - ``NotImplementedError`` if a packed operator is selected but w1/w2 \
        (or b1/b2) are not views into one stacked tensor
    """

    batch_shape = x.shape[:-1]
    x = x.reshape([-1, x.shape[-1]])
    if w1.ndim != 2 or w1.shape != w2.shape:
        raise ValueError(f"Invalid shapes for w1: {w1.shape} / w2: {w2.shape}")
    if b1 is not None:
        if b1.ndim != 1 or b1.shape[0] != w1.shape[0]:
            raise ValueError(f"Invalid shapes for b1: {b1.shape}")
    if b2 is not None:
        if b2.ndim != 1 or b2.shape[0] != w2.shape[0]:
            raise ValueError(f"Invalid shapes for b2: {b2.shape}")
    if w3.ndim != 2 or w3.shape[1] != w2.shape[0]:
        raise ValueError(f"Invalid shape for w3: {w3.shape}")
    if b3 is not None:
        if b3.ndim != 1 or b3.shape[0] != w3.shape[0]:
            raise ValueError(f"Invalid shapes for w3: {w3.shape} / b3: {b3.shape}")

    # The last layer maps to w3.shape[0] features. Spelling the size out
    # instead of using -1 keeps the final reshape well-defined when the batch
    # is empty (a -1 next to a zero-sized dimension is ambiguous and raises).
    out_shape = [*batch_shape, w3.shape[0]]

    if op is None:
        op = SwiGLUOpDispatch.from_arguments(x, w1, b1, w2, b2, w3, b3).op

    if not op.PACKED_WEIGHTS:
        return op(x, w1, b1, w2, b2, w3, b3).reshape(out_shape)

    # Packed operators take w1/w2 (and b1/b2) as one stacked tensor each.
    w1w2 = stack_or_none((w1, w2), dim=0)
    if b1 is not None and b2 is not None:
        b1b2: Optional[torch.Tensor] = stack_or_none((b1, b2), dim=0)
        if b1b2 is None:
            raise NotImplementedError("b1/b2 needs to be properly packed")
    else:
        b1b2 = None
        assert b1 is None and b2 is None

    if w1w2 is None:
        raise NotImplementedError("w1/w2 needs to be properly packed")
    return op(x, w1w2, b1b2, w3, b3).reshape(out_shape)
class SwiGLU(nn.Module):
    """
    A Module that encapsulates the call to :attr:`xformers.ops.swiglu`,
    and holds the weights for the 3 linear layers
    """

    def __init__(
        self,
        in_features: int,
        hidden_features: int,
        out_features: Optional[int] = None,
        bias: bool = True,
        *,
        _pack_weights: bool = True,
    ) -> None:
        """Create a SwiGLU module

        Args:
            in_features (int): Number of features of the input
            hidden_features (int): Number of hidden features
            out_features (Optional[int], optional): Number of features of the output.
                Falls back to ``in_features`` when falsy. Defaults to None.
            bias (bool, optional): Whether linear layers also include a bias. Defaults to True.
        """
        super().__init__()
        if not out_features:
            out_features = in_features
        if not hidden_features:
            hidden_features = in_features

        self.w12: Optional[nn.Linear]
        if _pack_weights:
            # A single linear layer holds w1 and w2 stacked along the output
            # dimension.
            self.w12 = nn.Linear(in_features, 2 * hidden_features, bias=bias)
        else:
            self.w12 = None
            self.w1 = nn.Linear(in_features, hidden_features, bias=bias)
            self.w2 = nn.Linear(in_features, hidden_features, bias=bias)
        self.w3 = nn.Linear(hidden_features, out_features, bias=bias)

        self.hidden_features = hidden_features
        self.out_features = out_features
        self.in_features = in_features
        # ``None`` lets ``swiglu`` pick the operator.
        self.op = None

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Computes :attr:`swiglu` with the module's weights

        Args:
            x (torch.Tensor): A Tensor of shape ``[..., in_features]``

        Returns:
            torch.Tensor: A Tensor of shape ``[..., out_features]``
        """
        params = self._ordered_params()
        return swiglu(x, *params, op=self.op)

    def _ordered_params(self):
        """Used for testing - returns ordered arguments for operators"""
        b1: Optional[torch.Tensor]
        b2: Optional[torch.Tensor]
        w3, b3 = self.w3.weight, self.w3.bias

        if self.w12 is None:
            w1, w2 = self.w1.weight, self.w2.weight
            b1, b2 = self.w1.bias, self.w2.bias
            return [w1, b1, w2, b2, w3, b3]

        # Split the packed layer into its two halves as views.
        packed_w = self.w12.weight
        packed_b = self.w12.bias
        w1, w2 = unbind(
            packed_w.view([2, packed_w.shape[0] // 2, packed_w.shape[1]]),
            dim=0,
        )
        if packed_b is None:
            b1, b2 = None, None
        else:
            b1, b2 = unbind(packed_b.view([2, packed_b.shape[0] // 2]), dim=0)
        return [w1, b1, w2, b2, w3, b3]

View File

@@ -0,0 +1,4 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

View File

@@ -0,0 +1,158 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
import torch
import triton
import triton.language as tl
# Use ``tl.libdevice`` when this triton build exposes it, ``tl.math`` otherwise.
tl_math = tl.libdevice if hasattr(tl, "libdevice") else tl.math
@triton.jit
def _rms_norm_kernel(
x_ptr,
h1_ptr,
w_ptr,
eps,
stride,
N_COLS,
BLOCK_SIZE: tl.constexpr,
INCLUDE_WEIGHT: tl.constexpr,
):
# RMS-normalises one row per program:
# h1[row] = x[row] * rsqrt(sum(x[row]^2) / N_COLS + eps),
# additionally multiplied element-wise by w when INCLUDE_WEIGHT is set.
# x_ptr: input, h1_ptr: output, w_ptr: weights (only read if INCLUDE_WEIGHT).
row = tl.program_id(0)
# The same row stride is applied to both the input and the output pointer.
x_ptr += row * stride
h1_ptr += row * stride
# Pass 1: accumulate the squares in float32, BLOCK_SIZE columns at a time.
# Despite its name, _mean holds per-lane partial sums of squares.
_mean = tl.zeros([BLOCK_SIZE], dtype=tl.float32)
for offset in range(0, N_COLS, BLOCK_SIZE):
cols = offset + tl.arange(0, BLOCK_SIZE)
# Lanes past the end of the row load 0.0, which adds nothing to the sum.
a = tl.load(
x_ptr + cols, mask=cols < N_COLS, other=0.0, eviction_policy="evict_last"
).to(tl.float32)
_mean += a * a
# Reciprocal root of (mean of squares + eps) for this row.
rstd = tl_math.rsqrt((tl.sum(_mean, axis=0) / N_COLS) + eps)
# Pass 2: re-read the row, scale it and write the result.
for offset in range(0, N_COLS, BLOCK_SIZE):
cols = offset + tl.arange(0, BLOCK_SIZE)
mask = cols < N_COLS
a = tl.load(
x_ptr + cols, mask=mask, other=0.0, eviction_policy="evict_first"
).to(tl.float32)
if INCLUDE_WEIGHT:
w = tl.load(w_ptr + cols, mask=mask)
tl.store(h1_ptr + cols, a * rstd * w, mask=mask)
else:
tl.store(h1_ptr + cols, a * rstd, mask=mask)
@triton.jit
def _rms_norm_add_kernel(
x_ptr,
y_ptr,
h1_ptr,
w_ptr,
eps,
stride,
N_COLS,
BLOCK_SIZE: tl.constexpr,
INCLUDE_WEIGHT: tl.constexpr,
):
# Fused "x += y" followed by RMS normalisation, one row per program.
# x_ptr is both read and overwritten with x + y; h1_ptr receives the
# normalised result; w_ptr is only read when INCLUDE_WEIGHT is set.
row = tl.program_id(0)
# One common row stride is applied to x, y and the output.
x_ptr += row * stride
y_ptr += row * stride
h1_ptr += row * stride
# Pass 1: add y to x, write the sum back into x and accumulate its squares
# in float32. Despite its name, _mean holds per-lane partial sums of squares.
_mean = tl.zeros([BLOCK_SIZE], dtype=tl.float32)
for offset in range(0, N_COLS, BLOCK_SIZE):
cols = offset + tl.arange(0, BLOCK_SIZE)
mask = cols < N_COLS
ax = tl.load(
x_ptr + cols, mask=mask, other=0.0, eviction_policy="evict_last"
).to(tl.float32)
ay = tl.load(
y_ptr + cols, mask=mask, other=0.0, eviction_policy="evict_first"
).to(tl.float32)
a = ax + ay
# In-place update of x with the sum.
tl.store(x_ptr + cols, a, mask=mask)
_mean += a * a
# Reciprocal root of (mean of squares + eps) for this row.
rstd = tl_math.rsqrt((tl.sum(_mean, axis=0) / N_COLS) + eps)
# Pass 2: re-read the updated x (the value stored above, not the float32
# temporary), scale it and write the result.
for offset in range(0, N_COLS, BLOCK_SIZE):
cols = offset + tl.arange(0, BLOCK_SIZE)
mask = cols < N_COLS
a = tl.load(
x_ptr + cols, mask=mask, other=0.0, eviction_policy="evict_first"
).to(tl.float32)
if INCLUDE_WEIGHT:
w = tl.load(w_ptr + cols, mask=mask)
tl.store(h1_ptr + cols, a * rstd * w, mask=mask)
else:
tl.store(h1_ptr + cols, a * rstd, mask=mask)
def _rms_norm_forward(x, attn_norm_weights, eps):
if not x.is_contiguous():
raise ValueError("data must be contiguous")
if attn_norm_weights is not None:
if not attn_norm_weights.is_contiguous():
raise ValueError("weights must be contiguous")
out = torch.empty_like(x)
x_arg = x.reshape(-1, x.shape[-1])
M, N = x_arg.shape
# Less than 64KB per feature: enqueue fused kernel
MAX_FUSED_SIZE = 65536 // x.element_size()
BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))
BLOCK_SIZE = max(BLOCK_SIZE, 128)
BLOCK_SIZE = min(BLOCK_SIZE, 4096)
# heuristics for number of warps
num_warps = min(max(BLOCK_SIZE // 256, 1), 8)
_rms_norm_kernel[(M,)](
x_arg,
out,
attn_norm_weights,
eps,
x_arg.stride(0),
N,
BLOCK_SIZE=BLOCK_SIZE,
num_warps=num_warps,
INCLUDE_WEIGHT=attn_norm_weights is not None,
)
return out
def _rms_norm_add_forward(x, y, attn_norm_weights, eps):
# x, y contiguous of same shape [..., n]
# output of same shape, normed over the last dim.
if not x.is_contiguous():
raise ValueError("x must be contiguous")
if not y.is_contiguous():
raise ValueError("y must be contiguous")
if attn_norm_weights is not None:
if not attn_norm_weights.is_contiguous():
raise ValueError("weights must be contiguous")
out = torch.empty_like(x)
x_arg = x.reshape(-1, x.shape[-1])
y_arg = y.reshape(-1, x.shape[-1])
M, N = x_arg.shape
# Less than 64KB per feature: enqueue fused kernel
MAX_FUSED_SIZE = 65536 // x.element_size()
BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))
BLOCK_SIZE = max(BLOCK_SIZE, 128)
BLOCK_SIZE = min(BLOCK_SIZE, 4096)
# heuristics for number of warps
num_warps = min(max(BLOCK_SIZE // 256, 1), 8)
_rms_norm_add_kernel[(M,)](
x_arg,
y_arg,
out,
attn_norm_weights,
eps,
x_arg.stride(0),
N,
BLOCK_SIZE=BLOCK_SIZE,
num_warps=num_warps,
INCLUDE_WEIGHT=attn_norm_weights is not None,
)
return out

View File

@@ -0,0 +1,161 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
import triton # type: ignore
import triton.language as tl # type: ignore
# Use ``tl.libdevice`` when this triton build exposes it, ``tl.math`` otherwise.
tl_math = tl.libdevice if hasattr(tl, "libdevice") else tl.math
@triton.jit
def _rope_padded_kernel(
xq,
xk,
xv,
out_q,
cache_k,
cache_v,
seqstartq,
seqstartk,
seqlenk,
theta,
k_start: tl.constexpr,
v_start: tl.constexpr,
dim: tl.constexpr, # dimension of each head
stride_xqM,
stride_xqH,
stride_xkM,
stride_xkH,
stride_xvM,
stride_xvH,
stride_cachekM,
stride_cachekH,
stride_cachevM,
stride_cachevH,
stride_seqstartq,
stride_seqstartk,
stride_seqlenk,
stride_outqM,
stride_outqH,
internal_dtype: tl.constexpr,
# If True, seqstartq and seqstartk are not used but rather we
# assume that every batch element has the same number of
# queries (i.e. num_queries := tl.num_programs(1) )
# and the same cache space cache_padding_length.
# Always False when called below.
const_batch_strides: tl.constexpr,
# If const_batch_strides==True, the common cache length for each batch element.
# (Only the first seqlenk[i] elements are actually in use, and only the last
# num_queries of those are actually written to.)
cache_padding_length,
# offset added to all values in seqlenk before using them.
# Always 0 when called below.
seqlenk_shift: tl.constexpr,
BLOCK_SIZE: tl.constexpr,
adjacents: tl.constexpr,
):
"""
Each letter in this diagram is a whole row of length dim.
INPUT xq xk xv
head_dim ─►
batch qqqqqq kk vv
│ qqqqqq kk vv
▼ qqqqqq kk vv
head_idx: (goes across all heads of all 3 inputs)
▲ ▲ ▲ ▲ ▲ ▲
│ │ │ │ │ │
│ │
0 k_start │v_start │n_total_heads
│ │
│ │
k_start v_start
Output is to out_q (same shape as xq), an xk-shaped part
of cache_k and an xv-shaped part of cache_v
"""
# Grid axes: 0 = batch element, 1 = query position inside that batch
# element, 2 = head index across the q, k and v heads (see diagram above).
batch_elt = tl.program_id(0)
query_pos_in_batch_elt = tl.program_id(1)
head_idx = tl.program_id(2)
# Optionally force the dtype that the frequency computation is based on.
if internal_dtype == "f32":
theta = theta.to(tl.float32)
elif internal_dtype == "f64":
theta = theta.to(tl.float64)
# query_pos is this program's row in xq/xk/xv; end_query_pos is one past
# the last row of the current batch element.
if const_batch_strides:
query_pos = query_pos_in_batch_elt + tl.num_programs(1) * batch_elt
end_query_pos = tl.num_programs(1) * (batch_elt + 1)
else:
query_pos = query_pos_in_batch_elt + tl.load(
seqstartq + batch_elt * stride_seqstartq
)
end_query_pos = tl.load(seqstartq + (batch_elt + 1) * stride_seqstartq)
# Programs that fall past this batch element's last query have nothing to do.
if query_pos >= end_query_pos:
return
# Which of the three inputs this head belongs to (neither flag => k head).
is_q = head_idx < k_start
is_v = head_idx >= v_start
# Pointers for all three roles are advanced unconditionally; tl.where
# below picks the pair that matches this program's head.
xq += query_pos * stride_xqM + head_idx * stride_xqH
out_q += query_pos * stride_outqM + head_idx * stride_outqH
if const_batch_strides:
cache_start = cache_padding_length * batch_elt
else:
cache_start = tl.load(seqstartk + batch_elt * stride_seqstartk)
end_of_batch_elt_cache = (
cache_start + tl.load(seqlenk + batch_elt * stride_seqlenk) + seqlenk_shift
)
# The rows of this batch element are written to the last
# (end_query_pos - start) slots of its used cache region.
cache_pos = end_of_batch_elt_cache - (end_query_pos - query_pos)
# Position within the batch element's own sequence; drives the rotation angle.
seq_pos = cache_pos - cache_start
cache_k += (head_idx - k_start) * stride_cachekH + cache_pos * stride_cachekM
xk += query_pos * stride_xkM + (head_idx - k_start) * stride_xkH
in_qk = tl.where(is_q, xq, xk)
out_qk = tl.where(is_q, out_q, cache_k)
cache_v += (head_idx - v_start) * stride_cachevH + cache_pos * stride_cachevM
xv += query_pos * stride_xvM + (head_idx - v_start) * stride_xvH
# Final source and destination rows for this program.
out = tl.where(is_v, cache_v, out_qk)
x_in = tl.where(is_v, xv, in_qk)
# Walk the dim // 2 (real, imaginary) pairs, BLOCK_SIZE // 2 pairs per step.
for offset in range(0, dim // 2, BLOCK_SIZE // 2):
c = tl.arange(0, BLOCK_SIZE // 2)
powers = (offset + c) * 2.0
if adjacents:
# Pairs are interleaved: (0, 1), (2, 3), ...
cols_re = (offset + c) * 2
cols_im = cols_re + 1
else:
# Real parts in the first half, imaginary parts in the second half.
cols_re = offset + c
cols_im = cols_re + dim // 2
mask = cols_im < dim
re_x = tl.load(x_in + cols_re, mask=mask)
im_x = tl.load(x_in + cols_im, mask=mask)
# freqs = seq_pos / (theta ** (powers / dim))
freqs = seq_pos * tl_math.pow(theta, powers / (-dim))
sines = tl.sin(freqs)
cosines = tl.cos(freqs)
# Rotate each (re, im) pair by the angle freqs.
re_out = re_x * cosines - im_x * sines
im_out = im_x * cosines + re_x * sines
# v heads are copied through unrotated.
re_out_ = tl.where(is_v, re_x, re_out)
im_out_ = tl.where(is_v, im_x, im_out)
if internal_dtype == "f64":
if re_x.dtype == tl.bfloat16:
# triton 2.0.0 crashes if you try to convert
# float64 directly to bfloat16, so make an intermediate step.
re_out_ = re_out_.to(tl.float32)
im_out_ = im_out_.to(tl.float32)
tl.store(out + cols_re, re_out_, mask=mask)
tl.store(out + cols_im, im_out_, mask=mask)

125
pkgs/xformers/ops/unbind.py Normal file
View File

@@ -0,0 +1,125 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
from typing import List, Optional, Sequence, Tuple, Union
import torch
from .common import _get_storage_base
def get_stack_strides(
    tensors: Sequence[torch.Tensor], dim: int
) -> Optional[Tuple[int, ...]]:
    """
    If the tensors are already stacked on dimension :code:`dim`,
    returns the strides of the stacked tensors.
    Otherwise returns :code:`None`.

    The tensors are considered "already stacked" when they all have the same
    shape and strides, their storage offsets are evenly spaced (tensor ``k``
    sits ``k * step`` elements after the first one), and
    ``_get_storage_base`` reports the same value for each of them.

    Args:
        tensors: Candidate slices, in stacking order.
        dim: Non-negative position of the new axis in the stacked result,
            between ``0`` and ``tensors[0].ndim`` inclusive.

    Returns:
        The ``tensors[0].ndim + 1`` strides of the stacked view, or ``None``
        when no such view exists. ``None`` is also returned for fewer than
        two tensors, for a ``dim`` outside the supported range (negative
        values included), and when the slices are laid out in decreasing
        storage order, since a negative stride cannot be used to build a view.
    """
    if len(tensors) <= 1:
        return None
    first = tensors[0]
    # A negative `dim` would never match an axis index below, so the stacking
    # stride would silently be left out of the result: reject it explicitly.
    if dim < 0 or dim > first.ndim:
        return None

    base_offset = first.storage_offset()
    stack_stride = tensors[1].storage_offset() - base_offset
    if stack_stride < 0:
        # Strided views do not support negative strides.
        return None

    # Strides of the stacked view: those of one slice, with the distance
    # between two consecutive slices inserted at position `dim`.
    final_stride = list(first.stride())
    final_stride.insert(dim, stack_stride)

    storage_data_ptr: Optional[int] = None
    for idx, x in enumerate(tensors[1:], start=1):
        # Cheap sanity checks first
        if x.shape != first.shape:
            return None
        if x.stride() != first.stride():
            return None
        if x.storage_offset() != base_offset + idx * stack_stride:
            return None
        if storage_data_ptr is None:
            storage_data_ptr = _get_storage_base(first)
        # Actual storage check
        if _get_storage_base(x) != storage_data_ptr:
            return None
    return tuple(final_stride)
def _stack_or_none_fw(
    tensors: Union[Tuple[torch.Tensor, ...], List[torch.Tensor]],
    dim: int,
) -> Optional[torch.Tensor]:
    """
    Build the stack of ``tensors`` along ``dim`` as a strided view of
    ``tensors[0]`` when ``get_stack_strides`` finds suitable strides.

    Args:
        tensors: Slices to stack, in order.
        dim: Position of the new axis. Negative values follow the
            :attr:`torch.stack` convention and are counted from the end of
            the stacked shape, which has ``tensors[0].ndim + 1`` dimensions.

    Returns:
        A view with ``len(tensors)`` inserted at ``dim`` in the shape of
        ``tensors[0]``, or ``None`` when ``get_stack_strides`` returns
        ``None`` (including for an empty input or an out-of-range ``dim``).
    """
    if not tensors:
        return None
    if dim < 0:
        # `list.insert` and the stride lookup both need a non-negative axis:
        # `insert(-1, n)` would put the new axis *before* the last one.
        dim += tensors[0].ndim + 1
        if dim < 0:
            return None
    strides = get_stack_strides(tensors, dim)
    if strides is None:
        return None
    stacked_shape = list(tensors[0].shape)
    stacked_shape.insert(dim, len(tensors))
    return tensors[0].as_strided(stacked_shape, strides)
def _stack_fw(
    tensors: Union[Tuple[torch.Tensor, ...], List[torch.Tensor]],
    dim: int,
) -> torch.Tensor:
    """
    Stack ``tensors`` along ``dim``.

    Uses the result of ``_stack_or_none_fw`` when it is not ``None``, and
    falls back to a regular :attr:`torch.stack` otherwise.
    """
    stacked = _stack_or_none_fw(tensors, dim)
    if stacked is not None:
        return stacked
    return torch.stack(tensors, dim=dim)
class _Unbind(torch.autograd.Function):
    """
    Autograd function backing `unbind`.

    Forward splits ``x`` along ``dim``; backward re-assembles the incoming
    gradients along the same dimension with ``_stack_fw``.
    """

    @staticmethod
    # type: ignore
    def forward(ctx, x: torch.Tensor, dim: int):
        # Remember the split axis: backward stacks the gradients along it.
        ctx.dim = dim
        return torch.unbind(x, dim)

    @classmethod
    # type: ignore
    def backward(cls, ctx, *tensors: torch.Tensor):
        grad_x = _stack_fw(tensors, ctx.dim)
        # One gradient per forward input: `dim` gets none.
        return grad_x, None
class _StackOrNone(torch.autograd.Function):
    """
    Autograd function backing `stack_or_none`.

    Forward returns whatever ``_stack_or_none_fw`` produces for the inputs;
    backward splits the incoming gradient along the stacking dimension.
    """

    @staticmethod
    # type: ignore
    def forward(ctx, dim: int, *tensors: torch.Tensor):
        # Remember the stacking axis: backward splits the gradient along it.
        ctx.dim = dim
        return _stack_or_none_fw(tensors, dim=dim)

    @classmethod
    # type: ignore
    def backward(cls, ctx, grad: torch.Tensor):
        per_tensor_grads = grad.unbind(dim=ctx.dim)
        # Leading None is the (absent) gradient of the `dim` argument.
        return (None,) + tuple(per_tensor_grads)
def unbind(x: torch.Tensor, dim: int) -> Tuple[torch.Tensor, ...]:
    """
    Forward pass: splits :attr:`x` along :attr:`dim`, like :attr:`torch.unbind`.

    Backward pass: the gradients are re-assembled by ``_Unbind``, which skips
    the copy done by a regular stack when the gradients are already
    multiple views of the same storage.
    """
    pieces = _Unbind.apply(x, dim)
    return pieces
def stack_or_none(
    tensors: Sequence[torch.Tensor], dim: int
) -> Optional[torch.Tensor]:
    """
    Does exactly the same as :attr:`torch.stack` if the tensors can be concatenated
    without any memory operation. Otherwise returns None.

    Args:
        tensors: Tensors to stack, in order.
        dim: Dimension along which to stack.

    Returns:
        The stacked tensor, or ``None`` when no copy-free stacking is
        possible. Callers must handle the ``None`` case themselves.
    """
    return _StackOrNone.apply(dim, *tensors)