First commit
This commit is contained in:
474
pkgs/xformers/ops/fmha/__init__.py
Normal file
474
pkgs/xformers/ops/fmha/__init__.py
Normal file
@@ -0,0 +1,474 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the BSD license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
from typing import Any, Optional, Sequence, Tuple, Type, Union
|
||||
import os
|
||||
import torch
|
||||
|
||||
from . import cutlass, decoder, flash, small_k, triton, triton_splitk
|
||||
from .attn_bias import AttentionBias, BlockDiagonalMask, LowerTriangularMask
|
||||
from .common import (
|
||||
AttentionBwOpBase,
|
||||
AttentionFwOpBase,
|
||||
AttentionOp,
|
||||
AttentionOpBase,
|
||||
AttentionOpDispatch,
|
||||
Context,
|
||||
Gradients,
|
||||
Inputs,
|
||||
bmk2bmhk,
|
||||
)
|
||||
from .dispatch import _dispatch_bw, _dispatch_fw, _ensure_op_supports_or_raise
|
||||
|
||||
MemoryEfficientAttentionCutlassOp = (cutlass.FwOp, cutlass.BwOp)
|
||||
MemoryEfficientAttentionCutlassFwdFlashBwOp = (cutlass.FwOp, flash.BwOp)
|
||||
MemoryEfficientAttentionDecoderOp = (decoder.FwOp, cutlass.BwOp)
|
||||
MemoryEfficientAttentionTritonFwdFlashBwOp = (triton.FwOp, flash.BwOp)
|
||||
MemoryEfficientAttentionFlashAttentionOp = (flash.FwOp, flash.BwOp)
|
||||
MemoryEfficientAttentionOp = (small_k.FwOp, small_k.BwOp)
|
||||
TritonFlashAttentionOp = (triton.FwOp, triton.BwOp)
|
||||
|
||||
|
||||
class _fMHA(torch.autograd.Function):
|
||||
@staticmethod
|
||||
# type: ignore
|
||||
def forward(ctx, op: AttentionOp, *args: Any) -> Any:
|
||||
inp = Inputs(*args)
|
||||
op_fw = op[0] if op is not None else None
|
||||
op_bw = op[1] if op is not None else None
|
||||
|
||||
out, op_ctx = _memory_efficient_attention_forward_requires_grad(
|
||||
inp=inp, op=op_fw
|
||||
)
|
||||
|
||||
# Saving attn_bias is a bit complicated, as the
|
||||
# torch part should go in `save_for_backward`
|
||||
if isinstance(inp.attn_bias, torch.Tensor):
|
||||
attn_bias_tensor = inp.attn_bias
|
||||
attn_bias_ctx = None
|
||||
else:
|
||||
attn_bias_tensor = None
|
||||
attn_bias_ctx = inp.attn_bias
|
||||
|
||||
ctx.save_for_backward(
|
||||
inp.query,
|
||||
inp.key,
|
||||
inp.value,
|
||||
op_ctx.q_padded,
|
||||
op_ctx.k_padded,
|
||||
op_ctx.v_padded,
|
||||
op_ctx.o_padded,
|
||||
op_ctx.out,
|
||||
op_ctx.lse,
|
||||
)
|
||||
ctx.rng_state = op_ctx.rng_state
|
||||
ctx.attn_bias_tensor = attn_bias_tensor
|
||||
if op_ctx.op_bw is not None:
|
||||
if op_bw is not None and op_bw is not op_ctx.op_bw:
|
||||
raise ValueError(
|
||||
f"Specified op_bw={op_bw.NAME}, but forward op "
|
||||
f"can only run with op_bw={op_ctx.op_bw.NAME}. Please set op_bw=None."
|
||||
)
|
||||
op_bw = op_ctx.op_bw
|
||||
ctx.op_fw = op_fw
|
||||
ctx.op_bw = op_bw
|
||||
ctx.p = inp.p
|
||||
ctx.use_alibi = inp.use_alibi
|
||||
ctx.alibi_mode = inp.alibi_mode
|
||||
ctx.imp_mode = inp.imp_mode
|
||||
|
||||
ctx.scale = inp.scale
|
||||
ctx.attn_bias_ctx = attn_bias_ctx
|
||||
ctx.n_args = len(args)
|
||||
return out
|
||||
|
||||
@staticmethod
|
||||
def deserialize_bias(
|
||||
attn_bias_ctx, attn_bias_tensor: Optional[torch.Tensor]
|
||||
) -> Any:
|
||||
if attn_bias_tensor is None:
|
||||
return attn_bias_ctx
|
||||
return attn_bias_tensor
|
||||
|
||||
@classmethod
|
||||
@torch.autograd.function.once_differentiable
|
||||
def backward(cls, ctx, grad):
|
||||
# Re-create context
|
||||
query, key, value, q_padded, k_padded, v_padded, o_padded, out, lse = ctx.saved_tensors
|
||||
attn_bias_tensor = ctx.attn_bias_tensor
|
||||
rng_state = ctx.rng_state
|
||||
inp = Inputs(
|
||||
query=query,
|
||||
key=key,
|
||||
value=value,
|
||||
attn_bias=cls.deserialize_bias(ctx.attn_bias_ctx, attn_bias_tensor),
|
||||
p=ctx.p,
|
||||
scale=ctx.scale,
|
||||
use_alibi=ctx.use_alibi,
|
||||
alibi_mode=ctx.alibi_mode,
|
||||
imp_mode = ctx.imp_mode,
|
||||
)
|
||||
op_ctx = Context(
|
||||
lse=lse,
|
||||
out=out,
|
||||
q_padded=q_padded,
|
||||
k_padded=k_padded,
|
||||
v_padded=v_padded,
|
||||
o_padded=o_padded,
|
||||
rng_state=rng_state,
|
||||
)
|
||||
grads = _memory_efficient_attention_backward(
|
||||
ctx=op_ctx, inp=inp, grad=grad, op=ctx.op_bw
|
||||
)
|
||||
return (None, grads.dq, grads.dk, grads.dv, grads.db) + (None,) * (
|
||||
ctx.n_args - 2
|
||||
)
|
||||
|
||||
|
||||
def memory_efficient_attention(
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
attn_bias: Optional[Union[torch.Tensor, AttentionBias]] = None,
|
||||
p: float = 0.0,
|
||||
scale: Optional[float] = None,
|
||||
use_alibi: bool = False,
|
||||
alibi_mode: int = 1,
|
||||
imp_mode: int = 0,
|
||||
*,
|
||||
op: Optional[AttentionOp] = None,
|
||||
) -> torch.Tensor:
|
||||
"""Implements the memory-efficient attention mechanism following
|
||||
`"Self-Attention Does Not Need O(n^2) Memory" <http://arxiv.org/abs/2112.05682>`_.
|
||||
|
||||
:Inputs shape:
|
||||
|
||||
- Input tensors must be in format ``[B, M, H, K]``, where B is the batch size, M \
|
||||
the sequence length, H the number of heads, and K the embeding size per head
|
||||
|
||||
- If inputs have dimension 3, it is assumed that the dimensions are ``[B, M, K]`` and ``H=1``
|
||||
|
||||
- Inputs can also be of dimension 5 with GQA - see note below
|
||||
|
||||
- Inputs can be non-contiguous - we only require the last dimension's stride to be 1
|
||||
|
||||
|
||||
:Equivalent pytorch code:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
scale = 1 / query.shape[-1] ** 0.5
|
||||
query = query * scale
|
||||
attn = query @ key.transpose(-2, -1)
|
||||
if attn_bias is not None:
|
||||
attn = attn + attn_bias
|
||||
attn = attn.softmax(-1)
|
||||
attn = F.dropout(attn, p)
|
||||
return attn @ value
|
||||
|
||||
:Examples:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
import xformers.ops as xops
|
||||
|
||||
# Compute regular attention
|
||||
y = xops.memory_efficient_attention(q, k, v)
|
||||
|
||||
# With a dropout of 0.2
|
||||
y = xops.memory_efficient_attention(q, k, v, p=0.2)
|
||||
|
||||
# Causal attention
|
||||
y = xops.memory_efficient_attention(
|
||||
q, k, v,
|
||||
attn_bias=xops.LowerTriangularMask()
|
||||
)
|
||||
|
||||
:Supported hardware:
|
||||
|
||||
NVIDIA GPUs with compute capability above 6.0 (P100+), datatype ``f16``, ``bf16`` and ``f32``.
|
||||
|
||||
:EXPERIMENTAL: Using with Multi Query Attention (MQA) and Grouped Query Attention (GQA):
|
||||
|
||||
MQA/GQA is an experimental feature supported only for the forward pass.
|
||||
If you have 16 heads in query, and 2 in key/value, you can provide 5-dim tensors
|
||||
in the ``[B, M, G, H, K]`` format, where ``G`` is the number of head groups (here 2), and
|
||||
``H`` is the number of heads per group (8 in the example).
|
||||
|
||||
Please note that xFormers will not automatically broadcast the inputs, so you will need
|
||||
to broadcast it manually before calling `memory_efficient_attention`.
|
||||
|
||||
:GQA/MQA example:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
import torch
|
||||
import xformers.ops as xops
|
||||
|
||||
B, M, K = 3, 32, 128
|
||||
kwargs = dict(device="cuda", dtype=torch.float16)
|
||||
q = torch.randn([B, M, 8, K], **kwargs)
|
||||
k = torch.randn([B, M, 2, K], **kwargs)
|
||||
v = torch.randn([B, M, 2, K], **kwargs)
|
||||
out_gqa = xops.memory_efficient_attention(
|
||||
q.reshape([B, M, 2, 4, K]),
|
||||
k.reshape([B, M, 2, 1, K]).expand([B, M, 2, 4, K]),
|
||||
v.reshape([B, M, 2, 1, K]).expand([B, M, 2, 4, K]),
|
||||
)
|
||||
|
||||
Raises:
|
||||
NotImplementedError: if there is no operator available to compute the MHA
|
||||
ValueError: if inputs are invalid
|
||||
|
||||
:parameter query: Tensor of shape ``[B, Mq, H, K]``
|
||||
:parameter key: Tensor of shape ``[B, Mkv, H, K]``
|
||||
:parameter value: Tensor of shape ``[B, Mkv, H, Kv]``
|
||||
:parameter attn_bias: Bias to apply to the attention matrix - defaults to no masking. \
|
||||
For common biases implemented efficiently in xFormers, see :attr:`xformers.ops.fmha.attn_bias.AttentionBias`. \
|
||||
This can also be a :attr:`torch.Tensor` for an arbitrary mask (slower).
|
||||
:parameter p: Dropout probability. Disabled if set to ``0.0``
|
||||
:parameter scale: Scaling factor for ``Q @ K.transpose()``. If set to ``None``, the default \
|
||||
scale (q.shape[-1]**-0.5) will be used.
|
||||
:parameter op: The operators to use - see :attr:`xformers.ops.AttentionOpBase`. \
|
||||
If set to ``None`` (recommended), xFormers \
|
||||
will dispatch to the best available operator, depending on the inputs \
|
||||
and options.
|
||||
:return: multi-head attention Tensor with shape ``[B, Mq, H, Kv]``
|
||||
"""
|
||||
return _memory_efficient_attention(
|
||||
Inputs(
|
||||
query=query, key=key, value=value, p=p, attn_bias=attn_bias, scale=scale, use_alibi=use_alibi, alibi_mode=alibi_mode, imp_mode=imp_mode
|
||||
),
|
||||
op=op,
|
||||
)
|
||||
|
||||
|
||||
def memory_efficient_attention_forward(
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
attn_bias: Optional[Union[torch.Tensor, AttentionBias]] = None,
|
||||
p: float = 0.0,
|
||||
scale: Optional[float] = None,
|
||||
use_alibi: bool = False,
|
||||
alibi_mode: int = 1,
|
||||
imp_mode: int = 0,
|
||||
*,
|
||||
op: Optional[Type[AttentionFwOpBase]] = None,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Calculates the forward pass of :attr:`xformers.ops.memory_efficient_attention`.
|
||||
"""
|
||||
return _memory_efficient_attention_forward(
|
||||
Inputs(
|
||||
query=query, key=key, value=value, p=p, attn_bias=attn_bias, scale=scale, use_alibi=use_alibi, alibi_mode=alibi_mode, imp_mode=imp_mode
|
||||
),
|
||||
op=op,
|
||||
)
|
||||
|
||||
|
||||
def memory_efficient_attention_forward_requires_grad(
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
attn_bias: Optional[Union[torch.Tensor, AttentionBias]] = None,
|
||||
p: float = 0.0,
|
||||
scale: Optional[float] = None,
|
||||
use_alibi: bool = False,
|
||||
alibi_mode: int = 1,
|
||||
imp_mode: int = 0,
|
||||
*,
|
||||
op: Optional[Type[AttentionFwOpBase]] = None,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Returns a tuple (output, lse), where `lse` can be used to compute the backward pass later.
|
||||
See :attr:`xformers.ops.memory_efficient_attention` for an explanation of the arguments
|
||||
See :attr:`xformers.ops.memory_efficient_attention_backward` for running the backward pass
|
||||
"""
|
||||
if p != 0.0:
|
||||
raise NotImplementedError(
|
||||
"dropout is not supported on the non-autograd API."
|
||||
" If you want to use dropout, please call `memory_efficient_attention` directly"
|
||||
)
|
||||
out, ctx = _memory_efficient_attention_forward_requires_grad(
|
||||
Inputs(
|
||||
query=query, key=key, value=value, p=p, attn_bias=attn_bias, scale=scale, use_alibi=use_alibi, alibi_mode=alibi_mode, imp_mode=imp_mode
|
||||
),
|
||||
op=op,
|
||||
)
|
||||
return out, ctx.lse
|
||||
|
||||
|
||||
def memory_efficient_attention_backward(
|
||||
grad: torch.Tensor,
|
||||
output: torch.Tensor,
|
||||
lse: torch.Tensor,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
attn_bias: Optional[Union[torch.Tensor, AttentionBias]] = None,
|
||||
p: float = 0.0,
|
||||
scale: Optional[float] = None,
|
||||
use_alibi: bool = False,
|
||||
alibi_mode: int = 1,
|
||||
imp_mode: int = 0,
|
||||
*,
|
||||
op: Optional[Type[AttentionBwOpBase]] = None,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Computes the gradient of the attention.
|
||||
Returns a tuple (dq, dk, dv)
|
||||
See :attr:`xformers.ops.memory_efficient_attention` for an explanation of the arguments.
|
||||
`lse` is the tensor returned by :attr:`xformers.ops.memory_efficient_attention_forward_requires_grad`
|
||||
"""
|
||||
if p != 0.0:
|
||||
raise NotImplementedError(
|
||||
"dropout is not supported on the non-autograd API."
|
||||
" If you want to use dropout, please call `memory_efficient_attention` directly"
|
||||
)
|
||||
gradients = _memory_efficient_attention_backward(
|
||||
Context(out=output, lse=lse),
|
||||
Inputs(
|
||||
query=query, key=key, value=value, p=p, attn_bias=attn_bias, scale=scale, use_alibi=use_alibi, alibi_mode=alibi_mode, imp_mode=imp_mode
|
||||
),
|
||||
grad,
|
||||
op=op,
|
||||
)
|
||||
return (gradients.dq, gradients.dk, gradients.dv)
|
||||
|
||||
|
||||
def _memory_efficient_attention(
|
||||
inp: Inputs, op: Optional[AttentionOp] = None
|
||||
) -> torch.Tensor:
|
||||
# fast-path that doesn't require computing the logsumexp for backward computation
|
||||
if all(x.requires_grad is False for x in [inp.query, inp.key, inp.value]):
|
||||
return _memory_efficient_attention_forward(
|
||||
inp, op=op[0] if op is not None else None
|
||||
)
|
||||
|
||||
output_shape = inp.normalize_bmhk()
|
||||
return _fMHA.apply(
|
||||
op, inp.query, inp.key, inp.value, inp.attn_bias, inp.p, inp.scale, inp.use_alibi, inp.alibi_mode, inp.imp_mode
|
||||
).reshape(output_shape)
|
||||
|
||||
|
||||
def _memory_efficient_attention_forward(
|
||||
inp: Inputs, op: Optional[Type[AttentionFwOpBase]]
|
||||
) -> torch.Tensor:
|
||||
inp.validate_inputs()
|
||||
output_shape = inp.normalize_bmhk()
|
||||
if op is None:
|
||||
op = _dispatch_fw(inp, False)
|
||||
else:
|
||||
_ensure_op_supports_or_raise(ValueError, "memory_efficient_attention", op, inp)
|
||||
|
||||
out, *_ = op.apply(inp, needs_gradient=False)
|
||||
return out.reshape(output_shape)
|
||||
|
||||
|
||||
def _memory_efficient_attention_forward_requires_grad(
|
||||
inp: Inputs, op: Optional[Type[AttentionFwOpBase]]
|
||||
) -> Tuple[torch.Tensor, Context]:
|
||||
inp.validate_inputs()
|
||||
output_shape = inp.normalize_bmhk()
|
||||
if op is None:
|
||||
op = _dispatch_fw(inp, True)
|
||||
else:
|
||||
_ensure_op_supports_or_raise(ValueError, "memory_efficient_attention", op, inp)
|
||||
out = op.apply(inp, needs_gradient=True)
|
||||
assert out[1] is not None
|
||||
return (out[0].reshape(output_shape), out[1])
|
||||
|
||||
|
||||
def _memory_efficient_attention_backward(
|
||||
ctx: Context, inp: Inputs, grad: torch.Tensor, op: Optional[Type[AttentionBwOpBase]]
|
||||
) -> Gradients:
|
||||
"""Warning: grad/ctx.out is potentially in BMK format"""
|
||||
inp.validate_inputs()
|
||||
if grad.ndim != inp.query.ndim or grad.ndim != ctx.out.ndim:
|
||||
raise ValueError(
|
||||
"All tensors should be either in BMK (ndim=3) or BMHK (ndim=4) format. \n"
|
||||
f"grad.shape : {grad.shape} \n"
|
||||
f"out.shape : {ctx.out.shape} \n"
|
||||
f"query.shape: {inp.query.shape}"
|
||||
)
|
||||
shape_dq, shape_dk, shape_dv = tuple(
|
||||
x.shape for x in (inp.query, inp.key, inp.value)
|
||||
)
|
||||
inp.normalize_bmhk()
|
||||
# LSE has shape [B, H, M] while query has shape [B, M, H, K]
|
||||
if os.getenv('ENABLE_FLASH_ATTENTION_WITH_IXDNN', '1') == '0':
|
||||
if (
|
||||
ctx.lse.ndim != 3
|
||||
# Dim 0
|
||||
or (
|
||||
not isinstance(inp.attn_bias, BlockDiagonalMask)
|
||||
and ctx.lse.shape[0] != inp.query.shape[0]
|
||||
)
|
||||
or (
|
||||
isinstance(inp.attn_bias, BlockDiagonalMask)
|
||||
and ctx.lse.shape[0] != inp.attn_bias.q_seqinfo.seqstart.shape[0] - 1
|
||||
)
|
||||
# Dim 1
|
||||
or ctx.lse.shape[1] != inp.query.shape[2]
|
||||
# Dim 2
|
||||
or (
|
||||
not isinstance(inp.attn_bias, BlockDiagonalMask)
|
||||
and ctx.lse.shape[2] < inp.query.shape[1]
|
||||
)
|
||||
):
|
||||
raise ValueError(
|
||||
"Input tensors have incompatible shapes."
|
||||
f"lse.shape : {ctx.lse.shape} \n"
|
||||
f"query.shape : {inp.query.shape}"
|
||||
)
|
||||
grad = bmk2bmhk(grad, 1)
|
||||
ctx.out = bmk2bmhk(ctx.out, 1)
|
||||
|
||||
if op is None:
|
||||
op = _dispatch_bw(inp)
|
||||
else:
|
||||
_ensure_op_supports_or_raise(
|
||||
ValueError, "memory_efficient_attention_backward", op, inp
|
||||
)
|
||||
|
||||
grads = op.apply(ctx, inp, grad)
|
||||
grads.dq = grads.dq.reshape(shape_dq)
|
||||
grads.dk = grads.dk.reshape(shape_dk)
|
||||
grads.dv = grads.dv.reshape(shape_dv)
|
||||
return grads
|
||||
|
||||
|
||||
ALL_FW_OPS: Sequence[Type[AttentionFwOpBase]] = [
|
||||
cutlass.FwOp,
|
||||
flash.FwOp,
|
||||
triton.FwOp,
|
||||
small_k.FwOp,
|
||||
triton_splitk.FwOp,
|
||||
]
|
||||
|
||||
ALL_BW_OPS: Sequence[Type[AttentionBwOpBase]] = [
|
||||
cutlass.BwOp,
|
||||
flash.BwOp,
|
||||
triton.BwOp,
|
||||
small_k.BwOp,
|
||||
]
|
||||
|
||||
__all__ = [
|
||||
"AttentionBias",
|
||||
"AttentionOp",
|
||||
"AttentionOpBase",
|
||||
"AttentionOpDispatch",
|
||||
"LowerTriangularMask",
|
||||
"MemoryEfficientAttentionCutlassFwdFlashBwOp",
|
||||
"MemoryEfficientAttentionTritonFwdFlashBwOp",
|
||||
"MemoryEfficientAttentionCutlassOp",
|
||||
"MemoryEfficientAttentionFlashAttentionOp",
|
||||
"MemoryEfficientAttentionOp",
|
||||
"TritonFlashAttentionOp",
|
||||
"memory_efficient_attention",
|
||||
"ALL_FW_OPS",
|
||||
"ALL_BW_OPS",
|
||||
]
|
||||
BIN
pkgs/xformers/ops/fmha/__pycache__/__init__.cpython-310.pyc
Normal file
BIN
pkgs/xformers/ops/fmha/__pycache__/__init__.cpython-310.pyc
Normal file
Binary file not shown.
BIN
pkgs/xformers/ops/fmha/__pycache__/attn_bias.cpython-310.pyc
Normal file
BIN
pkgs/xformers/ops/fmha/__pycache__/attn_bias.cpython-310.pyc
Normal file
Binary file not shown.
BIN
pkgs/xformers/ops/fmha/__pycache__/common.cpython-310.pyc
Normal file
BIN
pkgs/xformers/ops/fmha/__pycache__/common.cpython-310.pyc
Normal file
Binary file not shown.
BIN
pkgs/xformers/ops/fmha/__pycache__/cutlass.cpython-310.pyc
Normal file
BIN
pkgs/xformers/ops/fmha/__pycache__/cutlass.cpython-310.pyc
Normal file
Binary file not shown.
BIN
pkgs/xformers/ops/fmha/__pycache__/decoder.cpython-310.pyc
Normal file
BIN
pkgs/xformers/ops/fmha/__pycache__/decoder.cpython-310.pyc
Normal file
Binary file not shown.
BIN
pkgs/xformers/ops/fmha/__pycache__/dispatch.cpython-310.pyc
Normal file
BIN
pkgs/xformers/ops/fmha/__pycache__/dispatch.cpython-310.pyc
Normal file
Binary file not shown.
BIN
pkgs/xformers/ops/fmha/__pycache__/flash.cpython-310.pyc
Normal file
BIN
pkgs/xformers/ops/fmha/__pycache__/flash.cpython-310.pyc
Normal file
Binary file not shown.
BIN
pkgs/xformers/ops/fmha/__pycache__/small_k.cpython-310.pyc
Normal file
BIN
pkgs/xformers/ops/fmha/__pycache__/small_k.cpython-310.pyc
Normal file
Binary file not shown.
BIN
pkgs/xformers/ops/fmha/__pycache__/triton.cpython-310.pyc
Normal file
BIN
pkgs/xformers/ops/fmha/__pycache__/triton.cpython-310.pyc
Normal file
Binary file not shown.
BIN
pkgs/xformers/ops/fmha/__pycache__/triton_splitk.cpython-310.pyc
Normal file
BIN
pkgs/xformers/ops/fmha/__pycache__/triton_splitk.cpython-310.pyc
Normal file
Binary file not shown.
779
pkgs/xformers/ops/fmha/attn_bias.py
Normal file
779
pkgs/xformers/ops/fmha/attn_bias.py
Normal file
@@ -0,0 +1,779 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the BSD license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
|
||||
import math
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Iterable, List, Optional, Sequence, Tuple, Union
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
class AttentionBias:
|
||||
"""Base class for a custom bias that can be applied \
|
||||
as the attn_bias argument in
|
||||
:attr:`xformers.ops.memory_efficient_attention`.
|
||||
|
||||
That function has the ability to add a tensor, the
|
||||
attention bias, to the QK^T matrix before it is used
|
||||
in the softmax part of the attention calculation.
|
||||
The attention bias tensor with shape
|
||||
(B or 1, n_queries, number of keys)
|
||||
can be given as the attn_bias input.
|
||||
The most common use case is for an attention bias is
|
||||
to contain only zeros and negative infinities, which forms
|
||||
a mask so that some queries only attend to some keys.
|
||||
|
||||
Children of this class define alternative things which can
|
||||
be used as the attn_bias input to define an attention bias which
|
||||
forms such a mask, for some common cases.
|
||||
|
||||
When using an :attr:`xformers.ops.AttentionBias`
|
||||
instead of a :attr:`torch.Tensor`, the mask matrix does
|
||||
not need to be materialized, and can be
|
||||
hardcoded into some kernels for better performance.
|
||||
|
||||
See:
|
||||
|
||||
- :attr:`xformers.ops.fmha.attn_bias.LowerTriangularMask`
|
||||
- :attr:`xformers.ops.fmha.attn_bias.LowerTriangularMaskWithTensorBias`
|
||||
- :attr:`xformers.ops.fmha.attn_bias.BlockDiagonalMask`
|
||||
- :attr:`xformers.ops.fmha.attn_bias.BlockDiagonalCausalMask`
|
||||
|
||||
"""
|
||||
|
||||
def materialize(
|
||||
self,
|
||||
shape: Tuple[int, ...],
|
||||
dtype: torch.dtype = torch.float32,
|
||||
device: Union[str, torch.device] = "cpu",
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Materializes the bias as a `torch.Tensor`. This is very slow
|
||||
and we don't attempt to make it fast. Only use for debugging/testing.
|
||||
|
||||
Shape should be like `[*, q_seqlen, k_seqlen]`
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
class LowerTriangularMask(AttentionBias):
|
||||
"""
|
||||
A lower-triangular (aka causal) mask
|
||||
|
||||
A query Q cannot attend to a key which is farther from the
|
||||
initial key than Q is from the initial query.
|
||||
"""
|
||||
|
||||
def __init__(self, *tensor_args, **tensor_kwargs) -> None:
|
||||
# NOTE: Unused arguments, we keep them for backward compatibility
|
||||
super().__init__()
|
||||
|
||||
def materialize(
|
||||
self,
|
||||
shape: Tuple[int, ...],
|
||||
dtype: torch.dtype = torch.float32,
|
||||
device: Union[str, torch.device] = "cpu",
|
||||
) -> torch.Tensor:
|
||||
create_as = dtype if dtype is not torch.bfloat16 else torch.float32
|
||||
tensor = torch.full( # type: ignore
|
||||
shape,
|
||||
dtype=create_as,
|
||||
fill_value=float("-inf"),
|
||||
device=device,
|
||||
)
|
||||
return torch.triu(tensor, diagonal=1).to(dtype) # type: ignore
|
||||
|
||||
def add_bias(self, bias: torch.Tensor) -> "LowerTriangularMaskWithTensorBias":
|
||||
return LowerTriangularMaskWithTensorBias(bias)
|
||||
|
||||
|
||||
class LowerTriangularMaskWithTensorBias(LowerTriangularMask):
|
||||
"""A lower-triangular (aka causal) mask with an additive bias"""
|
||||
|
||||
def __init__(self, bias: torch.Tensor) -> None:
|
||||
self._bias = bias
|
||||
|
||||
def materialize(
|
||||
self,
|
||||
shape: Tuple[int, ...],
|
||||
dtype: torch.dtype = torch.float32,
|
||||
device: Union[str, torch.device] = "cpu",
|
||||
) -> torch.Tensor:
|
||||
return super().materialize(shape, dtype=dtype, device=device) + self._bias
|
||||
|
||||
|
||||
@dataclass
|
||||
class _SeqLenInfo:
|
||||
"""
|
||||
(Internal) Represents the division of a dimension into blocks.
|
||||
|
||||
For example, to represents a dimension of length 7 divided into
|
||||
three blocks of lengths 2, 3 and 2, use `from_seqlength([2, 3, 2])`.
|
||||
The members will be:
|
||||
max_seqlen: 3
|
||||
min_seqlen: 2
|
||||
seqstart_py: [0, 2, 5, 7]
|
||||
seqstart: torch.IntTensor([0, 2, 5, 7])
|
||||
"""
|
||||
|
||||
seqstart: torch.Tensor
|
||||
max_seqlen: int
|
||||
min_seqlen: int
|
||||
seqstart_py: List[int]
|
||||
|
||||
def to(self, device: torch.device) -> None:
|
||||
self.seqstart = self.seqstart.to(device, non_blocking=True)
|
||||
|
||||
def intervals(self) -> Iterable[Tuple[int, int]]:
|
||||
yield from zip(self.seqstart_py, self.seqstart_py[1:])
|
||||
|
||||
@classmethod
|
||||
def from_seqlens(cls, seqlens: Iterable[int]) -> "_SeqLenInfo":
|
||||
"""
|
||||
Input tensors are assumed to be in shape [B, M, *]
|
||||
"""
|
||||
assert not isinstance(seqlens, torch.Tensor)
|
||||
seqstart_py = [0]
|
||||
max_seqlen = -1
|
||||
min_seqlen = -1
|
||||
for seqlen in seqlens:
|
||||
min_seqlen = min(min_seqlen, seqlen) if min_seqlen != -1 else seqlen
|
||||
max_seqlen = max(max_seqlen, seqlen)
|
||||
seqstart_py.append(seqstart_py[len(seqstart_py) - 1] + seqlen)
|
||||
seqstart = torch.tensor(seqstart_py, dtype=torch.int32)
|
||||
return cls(
|
||||
max_seqlen=max_seqlen,
|
||||
min_seqlen=min_seqlen,
|
||||
seqstart=seqstart,
|
||||
seqstart_py=seqstart_py,
|
||||
)
|
||||
|
||||
def split(
|
||||
self, x: torch.Tensor, batch_sizes: Optional[Sequence[int]] = None
|
||||
) -> List[torch.Tensor]:
|
||||
if self.seqstart_py[-1] != x.shape[1] or x.shape[0] != 1:
|
||||
raise ValueError(
|
||||
f"Invalid `torch.Tensor` of shape {x.shape}, expected format "
|
||||
f"(B, M, *) with B=1 and M={self.seqstart_py[-1]}\n"
|
||||
f" seqstart: {self.seqstart_py}"
|
||||
)
|
||||
if batch_sizes is None:
|
||||
batch_sizes = [1] * (len(self.seqstart_py) - 1)
|
||||
split_chunks = []
|
||||
it = 0
|
||||
for batch_size in batch_sizes:
|
||||
split_chunks.append(
|
||||
self.seqstart_py[it + batch_size] - self.seqstart_py[it]
|
||||
)
|
||||
it += batch_size
|
||||
return [
|
||||
tensor.reshape([bs, -1, *tensor.shape[2:]])
|
||||
for bs, tensor in zip(batch_sizes, x.split(split_chunks, dim=1))
|
||||
]
|
||||
|
||||
|
||||
@dataclass
|
||||
class _PaddedSeqLenInfo(_SeqLenInfo):
|
||||
"""
|
||||
(Internal) Represents the division of a dimension into blocks which are
|
||||
padded out to the same total length.
|
||||
|
||||
For example, to represent a dimension of length 12 with space for
|
||||
three blocks of length 4, but where the occupied lengths are
|
||||
2, 3 and 2, use `from_seqlens_padded([2, 3, 2], 4)`.
|
||||
|
||||
The layout along the dimension is
|
||||
|
||||
0 ─► block 0
|
||||
block 0
|
||||
<space>
|
||||
<space>
|
||||
4 ─► block 1
|
||||
block 1
|
||||
block 1
|
||||
<space>
|
||||
8 ─► block 2
|
||||
block 2
|
||||
<space>
|
||||
<space>
|
||||
12 ─►
|
||||
|
||||
The members will be:
|
||||
max_seqlen: 3
|
||||
min_seqlen: 2
|
||||
seqstart_py: [0, 4, 8, 12]
|
||||
seqstart: torch.IntTensor([0, 4, 8, 12])
|
||||
seqlen_py: [2, 3, 2]
|
||||
seqlen: torch.IntTensor([2, 3, 2])
|
||||
padding: 4
|
||||
"""
|
||||
|
||||
seqlen: torch.Tensor
|
||||
seqlen_py: Sequence[int]
|
||||
padding: int
|
||||
# From parent: seqstart[i] contains the start position
|
||||
# of the i-th sequence
|
||||
# seqstart: torch.Tensor
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
assert len(self.seqstart_py) == len(self.seqlen_py) + 1
|
||||
|
||||
def to(self, device: torch.device) -> None:
|
||||
self.seqlen = self.seqlen.to(device, non_blocking=True)
|
||||
super().to(device)
|
||||
|
||||
def intervals(self) -> Iterable[Tuple[int, int]]:
|
||||
for (start, _), length in zip(super().intervals(), self.seqlen_py):
|
||||
yield start, start + length
|
||||
|
||||
@classmethod
|
||||
def from_seqlens(cls, seqlens: Iterable[int]) -> "_SeqLenInfo":
|
||||
raise RuntimeError(
|
||||
"Use either `_SeqLenInfo.from_seqlens` or `_PaddedSeqLenInfo.from_seqlens_padded`"
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_seqlens_padded(
|
||||
cls, seqlens: Sequence[int], padding: int
|
||||
) -> "_PaddedSeqLenInfo":
|
||||
"""
|
||||
Input tensors are assumed to be in shape [B, M, *]
|
||||
seqstart = padding * torch.arange(batch_size)
|
||||
"""
|
||||
assert not isinstance(seqlens, torch.Tensor)
|
||||
assert all(seqlen <= padding for seqlen in seqlens)
|
||||
seqstart_py = list(range(0, len(seqlens) * padding + 1, padding))
|
||||
return cls(
|
||||
seqlen=torch.tensor(seqlens, dtype=torch.int32),
|
||||
seqlen_py=seqlens,
|
||||
max_seqlen=max(seqlens),
|
||||
min_seqlen=min(seqlens),
|
||||
seqstart=torch.tensor(seqstart_py, dtype=torch.int32),
|
||||
seqstart_py=seqstart_py,
|
||||
padding=padding,
|
||||
)
|
||||
|
||||
def split(
|
||||
self, x: torch.Tensor, batch_sizes: Optional[Sequence[int]] = None
|
||||
) -> List[torch.Tensor]:
|
||||
raise NotImplementedError("_PaddedSeqLenInfo.split")
|
||||
|
||||
|
||||
@dataclass
|
||||
class BlockDiagonalMask(AttentionBias):
|
||||
"""
|
||||
A block-diagonal mask that can be passed as ``attn_bias``
|
||||
argument to :attr:`xformers.ops.memory_efficient_attention`.
|
||||
|
||||
Queries and Keys are each divided into the same number of blocks.
|
||||
Queries in block i only attend to keys in block i.
|
||||
|
||||
.. figure:: /_static/block_diag_bias.png
|
||||
|
||||
This bias can be used to handle a batch of sequences of
|
||||
different lengths, via :attr:`BlockDiagonalMask.from_tensor_list`
|
||||
|
||||
:Example:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
import torch
|
||||
from xformers.ops import fmha
|
||||
|
||||
K = 16
|
||||
dtype = torch.float16
|
||||
device = "cuda"
|
||||
list_x = [
|
||||
torch.randn([1, 3, 1, K], dtype=dtype, device=device),
|
||||
torch.randn([1, 6, 1, K], dtype=dtype, device=device),
|
||||
torch.randn([1, 2, 1, K], dtype=dtype, device=device),
|
||||
]
|
||||
attn_bias, x = fmha.BlockDiagonalMask.from_tensor_list(list_x)
|
||||
linear = torch.nn.Linear(K, K * 3).to(device=device, dtype=dtype)
|
||||
|
||||
q, k, v = linear(x).reshape([1, -1, 1, 3, K]).unbind(-2)
|
||||
out = fmha.memory_efficient_attention(q, k, v, attn_bias=attn_bias)
|
||||
list_out = attn_bias.split(out)
|
||||
print(list_out[0].shape) # [1, 3, 1, K]
|
||||
assert tuple(list_out[0].shape) == (1, 3, 1, K)
|
||||
|
||||
"""
|
||||
|
||||
q_seqinfo: _SeqLenInfo
|
||||
k_seqinfo: _SeqLenInfo
|
||||
_batch_sizes: Optional[Sequence[int]] = None
|
||||
|
||||
def _create_block_mask(
|
||||
self,
|
||||
shape: Tuple[int, ...],
|
||||
dtype: torch.dtype = torch.float32,
|
||||
device: Union[str, torch.device] = "cpu",
|
||||
) -> torch.Tensor:
|
||||
return torch.zeros(
|
||||
shape,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
)
|
||||
|
||||
def materialize(
|
||||
self,
|
||||
shape: Tuple[int, ...],
|
||||
dtype: torch.dtype = torch.float32,
|
||||
device: Union[str, torch.device] = "cpu",
|
||||
) -> torch.Tensor:
|
||||
"""Materialize the attention bias - for debugging & testing"""
|
||||
assert shape[-1] == self.k_seqinfo.seqstart_py[-1], (
|
||||
shape[-1],
|
||||
self.k_seqinfo.seqstart_py[-1],
|
||||
)
|
||||
assert shape[-2] == self.q_seqinfo.seqstart_py[-1], (
|
||||
shape[-2],
|
||||
self.q_seqinfo.seqstart_py[-1],
|
||||
)
|
||||
mask = torch.empty(shape[-2:], dtype=dtype, device=device)
|
||||
mask.fill_(-math.inf)
|
||||
for i, ((q_start, q_end), (k_start, k_end)) in enumerate(
|
||||
zip(
|
||||
self.q_seqinfo.intervals(),
|
||||
self.k_seqinfo.intervals(),
|
||||
)
|
||||
):
|
||||
mask[q_start:q_end, k_start:k_end] = self._create_block_mask(
|
||||
(q_end - q_start, k_end - k_start),
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
)
|
||||
for _ in range(len(shape) - 2):
|
||||
mask = mask.unsqueeze(0)
|
||||
return mask.expand(shape)
|
||||
|
||||
@classmethod
|
||||
def from_seqlens(
|
||||
cls,
|
||||
q_seqlen: Sequence[int],
|
||||
kv_seqlen: Optional[Sequence[int]] = None,
|
||||
) -> "BlockDiagonalMask":
|
||||
"""Creates a :attr:`BlockDiagonalMask` from a list of tensors lengths for query and key/value.
|
||||
|
||||
Args:
|
||||
q_seqlen (Union[Sequence[int], torch.Tensor]): List or tensor of sequence lengths for query tensors
|
||||
kv_seqlen (Union[Sequence[int], torch.Tensor], optional): List or tensor of sequence lengths for key/value.
|
||||
(Defaults to ``q_seqlen``.)
|
||||
Returns:
|
||||
BlockDiagonalMask
|
||||
"""
|
||||
assert kv_seqlen is None or len(q_seqlen) == len(kv_seqlen)
|
||||
q_seqinfo = _SeqLenInfo.from_seqlens(q_seqlen)
|
||||
if kv_seqlen is None or q_seqlen == kv_seqlen:
|
||||
k_seqinfo = q_seqinfo
|
||||
else:
|
||||
k_seqinfo = _SeqLenInfo.from_seqlens(kv_seqlen)
|
||||
return cls(q_seqinfo=q_seqinfo, k_seqinfo=k_seqinfo)
|
||||
|
||||
@classmethod
|
||||
def from_tensor_list(
|
||||
cls,
|
||||
tensors: Sequence[torch.Tensor],
|
||||
) -> Tuple["BlockDiagonalMask", torch.Tensor]:
|
||||
"""Creates a :attr:`BlockDiagonalMask` from a list of tensors, and returns the tensors
|
||||
concatenated on the sequence length dimension
|
||||
|
||||
.. figure:: /_static/block_diag_cat_split.png
|
||||
|
||||
See also :attr:`BlockDiagonalMask.split` to split the returned
|
||||
:attr:`torch.Tensor` back to a list of tensors of varying sequence length
|
||||
|
||||
Args:
|
||||
tensors (Sequence[torch.Tensor]): A list of tensors of shape ``[B, M_i, *]``.
|
||||
All tensors should have the same dimension and the same batch size ``B``, but
|
||||
they can have different sequence length ``M``.
|
||||
|
||||
Returns:
|
||||
Tuple[BlockDiagonalMask, torch.Tensor]: The corresponding bias for the attention
|
||||
along with `tensors` concatenated on the sequence length dimension, with shape ``[1, sum_i{M_i}, *]``
|
||||
"""
|
||||
batch_sizes = [tensor.shape[0] for tensor in tensors]
|
||||
seqlens = []
|
||||
for x in tensors:
|
||||
for _ in range(x.shape[0]):
|
||||
seqlens.append(x.shape[1])
|
||||
block_diag = cls.from_seqlens(seqlens)
|
||||
block_diag._batch_sizes = batch_sizes
|
||||
tensors_bs1 = tuple(x.reshape([1, -1, *x.shape[2:]]) for x in tensors)
|
||||
concat_tensors = torch.cat(tensors_bs1, dim=1)
|
||||
return block_diag, concat_tensors
|
||||
|
||||
@classmethod
|
||||
def from_tensor_lists_qkv(
|
||||
cls,
|
||||
tensors_q: Sequence[torch.Tensor],
|
||||
tensors_k: Sequence[torch.Tensor],
|
||||
tensors_v: Optional[Sequence[torch.Tensor]] = None,
|
||||
) -> Tuple["BlockDiagonalMask", torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
|
||||
assert len(tensors_q) == len(tensors_k)
|
||||
assert tensors_v is None or len(tensors_v) == len(tensors_q)
|
||||
batch_sizes = [tensor.shape[0] for tensor in tensors_q]
|
||||
q_seqlens, kv_seqlens = [], []
|
||||
for i, (q, k) in enumerate(zip(tensors_q, tensors_k)):
|
||||
assert q.shape[0] == k.shape[0]
|
||||
q_seqlens += [q.shape[1]] * q.shape[0]
|
||||
kv_seqlens += [k.shape[1]] * k.shape[0]
|
||||
assert tensors_v is None or tensors_v[i].shape[:2] == k.shape[:2]
|
||||
block_diag = cls.from_seqlens(q_seqlens, kv_seqlens)
|
||||
block_diag._batch_sizes = batch_sizes
|
||||
return (
|
||||
block_diag,
|
||||
torch.cat([x.reshape([1, -1, *x.shape[2:]]) for x in tensors_q], dim=1),
|
||||
torch.cat([x.reshape([1, -1, *x.shape[2:]]) for x in tensors_k], dim=1),
|
||||
torch.cat([x.reshape([1, -1, *x.shape[2:]]) for x in tensors_v], dim=1)
|
||||
if tensors_v is not None
|
||||
else None,
|
||||
)
|
||||
|
||||
def split_queries(self, tensor: torch.Tensor) -> Sequence[torch.Tensor]:
|
||||
return self.q_seqinfo.split(tensor, self._batch_sizes)
|
||||
|
||||
def split_kv(self, tensor: torch.Tensor) -> Sequence[torch.Tensor]:
|
||||
return self.k_seqinfo.split(tensor, self._batch_sizes)
|
||||
|
||||
def split(self, tensor: torch.Tensor) -> Sequence[torch.Tensor]:
|
||||
"""The inverse operation of :attr:`BlockDiagonalCausalMask.from_tensor_list`
|
||||
|
||||
Args:
|
||||
tensor (torch.Tensor): Tensor of tokens of shape ``[1, sum_i{M_i}, *]``
|
||||
|
||||
Returns:
|
||||
Sequence[torch.Tensor]: A list of tokens with possibly different sequence lengths
|
||||
"""
|
||||
assert self.q_seqinfo is self.k_seqinfo
|
||||
return self.q_seqinfo.split(tensor, self._batch_sizes)
|
||||
|
||||
def make_causal(self) -> "BlockDiagonalCausalMask":
|
||||
"""Makes each block causal"""
|
||||
return BlockDiagonalCausalMask(
|
||||
q_seqinfo=self.q_seqinfo,
|
||||
k_seqinfo=self.k_seqinfo,
|
||||
_batch_sizes=self._batch_sizes,
|
||||
)
|
||||
|
||||
def make_causal_from_bottomright(self) -> "BlockDiagonalCausalFromBottomRightMask":
|
||||
"""Makes each block causal with a possible non-causal prefix"""
|
||||
return BlockDiagonalCausalFromBottomRightMask(
|
||||
q_seqinfo=self.q_seqinfo,
|
||||
k_seqinfo=self.k_seqinfo,
|
||||
_batch_sizes=self._batch_sizes,
|
||||
)
|
||||
|
||||
def make_local_attention(
|
||||
self, window_size: int
|
||||
) -> "BlockDiagonalCausalLocalAttentionMask":
|
||||
"""Experimental: Makes each block causal with local attention"""
|
||||
return BlockDiagonalCausalLocalAttentionMask(
|
||||
q_seqinfo=self.q_seqinfo,
|
||||
k_seqinfo=self.k_seqinfo,
|
||||
_batch_sizes=self._batch_sizes,
|
||||
_window_size=window_size,
|
||||
)
|
||||
|
||||
def make_local_attention_from_bottomright(
|
||||
self, window_size: int
|
||||
) -> "BlockDiagonalCausalLocalAttentionFromBottomRightMask":
|
||||
"""Experimental: Makes each block causal with local attention, start from bottom right"""
|
||||
return BlockDiagonalCausalLocalAttentionFromBottomRightMask(
|
||||
q_seqinfo=self.q_seqinfo,
|
||||
k_seqinfo=self.k_seqinfo,
|
||||
_batch_sizes=self._batch_sizes,
|
||||
_window_size=window_size,
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class BlockDiagonalCausalMask(BlockDiagonalMask):
|
||||
"""
|
||||
Same as :attr:`xformers.ops.fmha.attn_bias.BlockDiagonalMask`, except that each block is causal.
|
||||
|
||||
Queries and Keys are each divided into the same number of blocks.
|
||||
A query Q in block i cannot attend to a key which is not in block i,
|
||||
nor one which is farther from the initial key in block i than Q
|
||||
is from the initial query in block i.
|
||||
"""
|
||||
|
||||
def _create_block_mask(
|
||||
self,
|
||||
shape: Tuple[int, ...],
|
||||
dtype: torch.dtype = torch.float32,
|
||||
device: Union[str, torch.device] = "cpu",
|
||||
) -> torch.Tensor:
|
||||
return LowerTriangularMask().materialize(
|
||||
shape,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class BlockDiagonalCausalFromBottomRightMask(BlockDiagonalMask):
|
||||
"""
|
||||
Same as :attr:`xformers.ops.fmha.attn_bias.BlockDiagonalMask`, except that each block is causal.
|
||||
This mask allows for a non-causal prefix
|
||||
NOTE: Each block should have `num_keys >= num_queries` otherwise the forward pass is not
|
||||
defined (softmax of vector of `-inf` in the attention)
|
||||
|
||||
Queries and keys are each divided into the same number of blocks.
|
||||
A query Q in block i cannot attend to a key which is not in block i,
|
||||
nor one which nearer the final key in block i than Q is to the
|
||||
final query in block i.
|
||||
"""
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
for i, ((q_start, q_end), (k_start, k_end)) in enumerate(
|
||||
zip(
|
||||
self.q_seqinfo.intervals(),
|
||||
self.k_seqinfo.intervals(),
|
||||
)
|
||||
):
|
||||
num_queries = q_end - q_start
|
||||
num_keys = k_end - k_start
|
||||
if num_keys < num_queries:
|
||||
raise ValueError(
|
||||
f"Block #{i} has num_keys={num_keys} and num_queries={num_queries}."
|
||||
" Expected `num_keys >= num_queries`"
|
||||
)
|
||||
|
||||
def _create_block_mask(
|
||||
self,
|
||||
shape: Tuple[int, ...],
|
||||
dtype: torch.dtype = torch.float32,
|
||||
device: Union[str, torch.device] = "cpu",
|
||||
) -> torch.Tensor:
|
||||
create_as = dtype if dtype is not torch.bfloat16 else torch.float32
|
||||
tensor = torch.full( # type: ignore
|
||||
shape,
|
||||
dtype=create_as,
|
||||
fill_value=float("-inf"),
|
||||
device=device,
|
||||
)
|
||||
num_queries, num_keys = shape[-2:]
|
||||
return torch.triu(tensor, diagonal=num_keys - num_queries + 1).to(dtype) # type: ignore
|
||||
|
||||
|
||||
@dataclass
|
||||
class BlockDiagonalCausalWithOffsetPaddedKeysMask(AttentionBias):
|
||||
"""
|
||||
Same as :attr:`xformers.ops.fmha.attn_bias.BlockDiagonalCausalMask`,
|
||||
except an offset on causality is allowed for each block and we support padding for k/v
|
||||
|
||||
The keys and values are divided into blocks which are padded out to
|
||||
the same total length.
|
||||
For example, if there is space for 12 keys, for three blocks of
|
||||
max length 4, but we only want to use the first 2, 3 and 2
|
||||
of each block, use `kv_padding=4` and `kv_seqlens=[2, 3, 2]`.
|
||||
The queries are divided into blocks, without padding, of lengths given by
|
||||
q_seqlen.
|
||||
|
||||
A query Q in block i cannot attend to a key which is not in block i,
|
||||
nor one which is not in use (i.e. in the padded area),
|
||||
nor one which is nearer to the final key in block i
|
||||
than Q is to the final query in block i.
|
||||
"""
|
||||
|
||||
q_seqinfo: _SeqLenInfo
|
||||
k_seqinfo: _PaddedSeqLenInfo
|
||||
causal_diagonal: Any = None # unused. Exists for BC only.
|
||||
|
||||
def _create_block_mask(
|
||||
self,
|
||||
shape: Tuple[int, ...],
|
||||
dtype: torch.dtype = torch.float32,
|
||||
device: Union[str, torch.device] = "cpu",
|
||||
) -> torch.Tensor:
|
||||
create_as = dtype if dtype is not torch.bfloat16 else torch.float32
|
||||
tensor = torch.full( # type: ignore
|
||||
shape,
|
||||
dtype=create_as,
|
||||
fill_value=float("-inf"),
|
||||
device=device,
|
||||
)
|
||||
num_queries, num_keys = shape[-2:]
|
||||
return torch.triu(tensor, diagonal=1 + num_keys - num_queries).to(dtype) # type: ignore
|
||||
|
||||
def materialize(
|
||||
self,
|
||||
shape: Tuple[int, ...],
|
||||
dtype: torch.dtype = torch.float32,
|
||||
device: Union[str, torch.device] = "cpu",
|
||||
) -> torch.Tensor:
|
||||
"""Materialize the attention bias - for debugging & testing"""
|
||||
if shape[-1] != self.k_seqinfo.seqstart_py[-1]:
|
||||
raise ValueError("k shapes wrong")
|
||||
if shape[-2] != self.q_seqinfo.seqstart_py[-1]:
|
||||
raise ValueError("q shapes wrong")
|
||||
mask = torch.empty(shape[-2:], dtype=dtype, device=device)
|
||||
mask.fill_(-math.inf)
|
||||
for i, ((q_start, q_end), (k_start, k_end)) in enumerate(
|
||||
zip(
|
||||
self.q_seqinfo.intervals(),
|
||||
self.k_seqinfo.intervals(),
|
||||
)
|
||||
):
|
||||
mask[q_start:q_end, k_start:k_end] = self._create_block_mask(
|
||||
(q_end - q_start, k_end - k_start),
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
)
|
||||
for _ in range(len(shape) - 2):
|
||||
mask = mask.unsqueeze(0)
|
||||
return mask.expand(shape)
|
||||
|
||||
@classmethod
|
||||
def from_seqlens(
|
||||
cls,
|
||||
q_seqlen: Sequence[int],
|
||||
kv_padding: int,
|
||||
kv_seqlen: Sequence[int],
|
||||
causal_diagonal: Any = None,
|
||||
) -> "BlockDiagonalCausalWithOffsetPaddedKeysMask":
|
||||
"""Creates a :attr:`BlockDiagonalCausalWithOffsetPaddedKeysMask` from a list of tensor
|
||||
lengths for query and key/value.
|
||||
|
||||
Args:
|
||||
q_seqlen (Sequence[int]): List or tensor of sequence lengths for query tensors
|
||||
kv_padding (int): Padding for k/v - also an upperbound on each individual key length
|
||||
kv_seqlen (Sequence[int]): List or tensor of sequence lengths for key/value.
|
||||
causal_diagonal: unused, for BC only
|
||||
Returns:
|
||||
BlockDiagonalCausalWithOffsetPaddedKeysMask
|
||||
"""
|
||||
assert kv_seqlen is None or len(q_seqlen) == len(kv_seqlen), (
|
||||
q_seqlen,
|
||||
kv_seqlen,
|
||||
)
|
||||
q_seqinfo = _SeqLenInfo.from_seqlens(q_seqlen)
|
||||
k_seqinfo = _PaddedSeqLenInfo.from_seqlens_padded(kv_seqlen, kv_padding)
|
||||
return cls(q_seqinfo=q_seqinfo, k_seqinfo=k_seqinfo)
|
||||
|
||||
|
||||
@dataclass
|
||||
class BlockDiagonalCausalLocalAttentionMask(BlockDiagonalCausalMask):
|
||||
"""
|
||||
(Experimental feature)
|
||||
Same as :attr:`xformers.ops.fmha.attn_bias.BlockDiagonalCausalMask`.
|
||||
This makes the mask "local" and the attention pattern banded.
|
||||
|
||||
Query i only attends to keys in its block and cannot attend keys further than "window_size"
|
||||
from it.
|
||||
"""
|
||||
|
||||
_window_size: int = 0 # forced due to inheritance and default arguments
|
||||
|
||||
def __post_init__(self):
|
||||
if self._window_size <= 0:
|
||||
raise ValueError(
|
||||
f"Expected `window_size > 0`, but window_size={self._window_size}"
|
||||
)
|
||||
q_seqlen = [
|
||||
y - x
|
||||
for x, y in zip(
|
||||
self.q_seqinfo.seqstart_py[:-1], self.q_seqinfo.seqstart_py[1:]
|
||||
)
|
||||
]
|
||||
kv_seqlen = [
|
||||
y - x
|
||||
for x, y in zip(
|
||||
self.k_seqinfo.seqstart_py[:-1], self.k_seqinfo.seqstart_py[1:]
|
||||
)
|
||||
]
|
||||
for q, k in zip(q_seqlen, kv_seqlen):
|
||||
if q - self._window_size >= k:
|
||||
raise RuntimeError(
|
||||
f"No keys are attended in q_seqlen {q} k_seqlen {k} with sliding window {self._window_size}"
|
||||
)
|
||||
|
||||
def _create_block_mask(
|
||||
self,
|
||||
shape: Tuple[int, ...],
|
||||
dtype: torch.dtype = torch.float32,
|
||||
device: Union[str, torch.device] = "cpu",
|
||||
) -> torch.Tensor:
|
||||
create_as = dtype if dtype is not torch.bfloat16 else torch.float32
|
||||
tensor = torch.full( # type: ignore
|
||||
shape,
|
||||
dtype=create_as,
|
||||
fill_value=1,
|
||||
device=device,
|
||||
)
|
||||
|
||||
num_queries, num_keys = shape[-2:]
|
||||
mask = torch.tril(tensor, diagonal=0).to(dtype) # type: ignore
|
||||
if self._window_size is not None and self._window_size > 0:
|
||||
mask = torch.triu(mask, diagonal=-self._window_size + 1)
|
||||
mask = torch.log(mask)
|
||||
return mask.to(dtype)
|
||||
|
||||
|
||||
@dataclass
|
||||
class BlockDiagonalCausalLocalAttentionFromBottomRightMask(
|
||||
BlockDiagonalCausalFromBottomRightMask
|
||||
):
|
||||
"""
|
||||
(Experimental feature)
|
||||
Same as :attr:`xformers.ops.fmha.attn_bias.BlockDiagonalCausalMask`.
|
||||
This makes the mask "local" and the attention pattern banded.
|
||||
|
||||
Query i only attends to keys in its block and cannot attend keys further than "window_size"
|
||||
from it.
|
||||
"""
|
||||
|
||||
_window_size: int = 0 # forced due to inheritance and default arguments
|
||||
|
||||
def __post_init__(self):
|
||||
super().__post_init__()
|
||||
if self._window_size <= 0:
|
||||
raise ValueError(
|
||||
f"Expected `window_size > 0`, but window_size={self._window_size}"
|
||||
)
|
||||
q_seqlen = [
|
||||
y - x
|
||||
for x, y in zip(
|
||||
self.q_seqinfo.seqstart_py[:-1], self.q_seqinfo.seqstart_py[1:]
|
||||
)
|
||||
]
|
||||
kv_seqlen = [
|
||||
y - x
|
||||
for x, y in zip(
|
||||
self.k_seqinfo.seqstart_py[:-1], self.k_seqinfo.seqstart_py[1:]
|
||||
)
|
||||
]
|
||||
for q, k in zip(q_seqlen, kv_seqlen):
|
||||
if q + (q - k) - self._window_size >= k:
|
||||
raise RuntimeError(
|
||||
f"No keys are attended in q_seqlen {q} k_seqlen {k} with sliding window {self._window_size}"
|
||||
)
|
||||
materialized = self.materialize((sum(q_seqlen), sum(kv_seqlen)))
|
||||
if torch.max(materialized, dim=1).values.min() == -float("inf"):
|
||||
raise RuntimeError("FUCKING FUCK FUCK")
|
||||
|
||||
def _create_block_mask(
|
||||
self,
|
||||
shape: Tuple[int, ...],
|
||||
dtype: torch.dtype = torch.float32,
|
||||
device: Union[str, torch.device] = "cpu",
|
||||
) -> torch.Tensor:
|
||||
create_as = dtype if dtype is not torch.bfloat16 else torch.float32
|
||||
tensor = torch.full( # type: ignore
|
||||
shape,
|
||||
dtype=create_as,
|
||||
fill_value=1,
|
||||
device=device,
|
||||
)
|
||||
num_queries, num_keys = shape[-2:]
|
||||
mask = torch.tril(tensor, diagonal=num_keys - num_queries).to(dtype) # type: ignore
|
||||
if self._window_size is not None:
|
||||
mask = torch.triu(
|
||||
mask, diagonal=num_keys - num_queries - self._window_size + 1
|
||||
)
|
||||
mask = torch.log(mask)
|
||||
return mask.to(dtype)
|
||||
542
pkgs/xformers/ops/fmha/common.py
Normal file
542
pkgs/xformers/ops/fmha/common.py
Normal file
@@ -0,0 +1,542 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the BSD license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import math
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, List, Mapping, Optional, Set, Tuple, Type, Union
|
||||
|
||||
import torch
|
||||
|
||||
from ..._cpp_lib import _built_with_cuda
|
||||
from ..common import BaseOperator
|
||||
from .attn_bias import (
|
||||
AttentionBias,
|
||||
BlockDiagonalMask,
|
||||
LowerTriangularMask,
|
||||
LowerTriangularMaskWithTensorBias,
|
||||
)
|
||||
|
||||
|
||||
def _is_bias_type_supported_in_BMK(attn_bias_type: Any) -> bool:
|
||||
# NoneType
|
||||
if isinstance(None, attn_bias_type):
|
||||
return True
|
||||
if attn_bias_type in [LowerTriangularMask, torch.Tensor]:
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
@dataclass
|
||||
class Inputs:
|
||||
"""
|
||||
Stores inputs to the `memory_efficient_attention` operators
|
||||
"""
|
||||
|
||||
query: torch.Tensor
|
||||
key: torch.Tensor
|
||||
value: torch.Tensor
|
||||
attn_bias: Optional[Union[torch.Tensor, AttentionBias]] = None
|
||||
p: float = 0.0
|
||||
scale: Optional[float] = None
|
||||
use_alibi: bool = False
|
||||
alibi_mode: int = 1
|
||||
imp_mode: int = 0
|
||||
|
||||
@property
|
||||
def device(self) -> torch.device:
|
||||
return self.query.device
|
||||
|
||||
@property
|
||||
def scale_float(self) -> float:
|
||||
return self.query.shape[-1] ** (-0.5) if self.scale is None else self.scale
|
||||
|
||||
def get_qkv_in_bmghk(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
if self.query.ndim == 5:
|
||||
return self.query, self.key, self.value
|
||||
if self.query.ndim == 4:
|
||||
return (
|
||||
self.query.unsqueeze(2),
|
||||
self.key.unsqueeze(2),
|
||||
self.value.unsqueeze(2),
|
||||
)
|
||||
if self.value.ndim == 3:
|
||||
return (
|
||||
self.query[:, :, None, None],
|
||||
self.key[:, :, None, None],
|
||||
self.value[:, :, None, None],
|
||||
)
|
||||
assert False
|
||||
|
||||
def normalize_bmhk(self) -> Tuple[int, ...]:
|
||||
if self.query.ndim not in [3, 4, 5]:
|
||||
raise ValueError(
|
||||
f"Invalid shape for query: {self.query.shape}. "
|
||||
"Expected shape [batch, seqlen, head_groups, num_heads_per_group, K]"
|
||||
", [batch, seqlen, num_heads, K], or [batch, seqlen, K]."
|
||||
)
|
||||
if self.value.dtype == torch.int32:
|
||||
# Quantized K/V case, in which the last dims of Q and K are different.
|
||||
# NB we currently don't have any implementations for quantized KV with
|
||||
# SUPPORTS_DIFFERENT_VALUE_EMBED.
|
||||
output_shape = tuple(self.query.shape)
|
||||
else:
|
||||
output_shape = (self.query.shape[:-1]) + (self.value.shape[-1],)
|
||||
# Convert from legacy format
|
||||
if self.query.ndim == 3:
|
||||
self.query = self.query.unsqueeze(2)
|
||||
self.key = self.key.unsqueeze(2)
|
||||
self.value = self.value.unsqueeze(2)
|
||||
if isinstance(self.attn_bias, torch.Tensor):
|
||||
if self.attn_bias.ndim != 3:
|
||||
raise ValueError(
|
||||
f"Expected BMK format for attn_bias, but got {self.attn_bias.shape}"
|
||||
)
|
||||
self.attn_bias = self.attn_bias.unsqueeze(1)
|
||||
return output_shape
|
||||
|
||||
def validate_inputs(self) -> None:
|
||||
qkv = (self.query, self.key, self.value)
|
||||
if self.query.ndim not in (3, 4, 5) or any(
|
||||
x.ndim != self.query.ndim for x in qkv
|
||||
):
|
||||
raise ValueError(
|
||||
f"Query/Key/Value should all have BMGHK, BMHK, or BMK shape.\n"
|
||||
f" query.shape: {self.query.shape}\n"
|
||||
f" key.shape : {self.key.shape}\n"
|
||||
f" value.shape: {self.value.shape}"
|
||||
)
|
||||
if any(x.device != self.query.device for x in qkv):
|
||||
raise ValueError("Query/Key/Value should all be on the same device")
|
||||
quantized_dtypes = self.key.dtype == self.value.dtype == torch.int32
|
||||
non_quantized_dtypes = all(x.dtype == self.query.dtype for x in qkv)
|
||||
if not (quantized_dtypes or non_quantized_dtypes):
|
||||
raise ValueError(
|
||||
"Query/Key/Value should either all have the same dtype, or "
|
||||
"(in the quantized case) Key/Value should have dtype torch.int32\n"
|
||||
f" query.dtype: {self.query.dtype}\n"
|
||||
f" key.dtype : {self.key.dtype}\n"
|
||||
f" value.dtype: {self.value.dtype}"
|
||||
)
|
||||
# Biases with tensors attached are meant to be in BMHK format
|
||||
# This would require to permute biases/gradients which can be expensive,
|
||||
# so let's just forbid it - BMK is a legacy format anyway
|
||||
if self.query.ndim == 3 and not _is_bias_type_supported_in_BMK(
|
||||
type(self.attn_bias)
|
||||
):
|
||||
raise ValueError(
|
||||
f"Please provide inputs in BMHK format rather "
|
||||
f"than BMK when using bias type `{type(self.attn_bias).__name__}`"
|
||||
)
|
||||
attn_bias_t: Optional[torch.Tensor] = None
|
||||
if isinstance(self.attn_bias, torch.Tensor):
|
||||
attn_bias_t = self.attn_bias
|
||||
if isinstance(self.attn_bias, LowerTriangularMaskWithTensorBias):
|
||||
attn_bias_t = self.attn_bias._bias
|
||||
if self.query.ndim == 4 and attn_bias_t is not None:
|
||||
expected_shape = (
|
||||
self.query.shape[0],
|
||||
self.query.shape[2],
|
||||
self.query.shape[1],
|
||||
self.key.shape[1],
|
||||
)
|
||||
if attn_bias_t.shape != expected_shape:
|
||||
raise ValueError(
|
||||
f"Invalid shape for attention bias: {attn_bias_t.shape} (expected {expected_shape})\n"
|
||||
f" query.shape: {self.query.shape}\n"
|
||||
f" key.shape : {self.key.shape}\n"
|
||||
f" value.shape: {self.value.shape}"
|
||||
)
|
||||
if isinstance(self.attn_bias, BlockDiagonalMask):
|
||||
if any(x.shape[0] != 1 for x in qkv):
|
||||
raise ValueError(
|
||||
f"Expected batch_size=1 when using block-diagonal bias\n"
|
||||
f" query.shape: {self.query.shape}\n"
|
||||
f" key.shape : {self.key.shape}\n"
|
||||
f" value.shape: {self.value.shape}"
|
||||
)
|
||||
if self.p < 0.0 or self.p > 1.0:
|
||||
raise ValueError(f"Invalid dropout probability: p={self.p}")
|
||||
# Check that shapes match between inputs
|
||||
B, Mq = self.query.shape[:2]
|
||||
K = self.query.shape[-1]
|
||||
B, Mkv = self.key.shape[:2]
|
||||
Kv = self.value.shape[-1]
|
||||
|
||||
valid_shapes = True
|
||||
if self.query.ndim == 3: # BMK
|
||||
valid_shapes = (
|
||||
self.query.shape == (B, Mq, K)
|
||||
and self.key.shape == (B, Mkv, K)
|
||||
and self.value.shape == (B, Mkv, Kv)
|
||||
)
|
||||
H = self.query.shape[-2]
|
||||
if self.query.ndim == 4: # BMHK
|
||||
quantized_kv_cache = self.value.dtype == torch.int32
|
||||
key_embed_dim = Kv if quantized_kv_cache else K
|
||||
valid_shapes = (
|
||||
self.query.shape == (B, Mq, H, K)
|
||||
and self.key.shape == (B, Mkv, H, key_embed_dim)
|
||||
and self.value.shape == (B, Mkv, H, Kv)
|
||||
)
|
||||
G = self.query.shape[2]
|
||||
if self.query.ndim == 5: # BMNHK
|
||||
valid_shapes = (
|
||||
self.query.shape == (B, Mq, G, H, K)
|
||||
and self.key.shape == (B, Mkv, G, H, K)
|
||||
and self.value.shape == (B, Mkv, G, H, Kv)
|
||||
)
|
||||
if not valid_shapes:
|
||||
raise ValueError(
|
||||
f"Incompatible shapes for attention inputs:\n"
|
||||
f" query.shape: {self.query.shape}\n"
|
||||
f" key.shape : {self.key.shape}\n"
|
||||
f" value.shape: {self.value.shape}\n"
|
||||
"HINT: We don't support broadcasting, please use `expand` "
|
||||
"yourself before calling `memory_efficient_attention` if you need to"
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class Context:
|
||||
lse: torch.Tensor
|
||||
out: torch.Tensor
|
||||
q_padded: Optional[torch.Tensor] = None
|
||||
k_padded: Optional[torch.Tensor] = None
|
||||
v_padded: Optional[torch.Tensor] = None
|
||||
o_padded: Optional[torch.Tensor] = None
|
||||
op_bw: Optional[Type["AttentionBwOpBase"]] = None
|
||||
rng_state: Optional[torch.Tensor] = None
|
||||
|
||||
def get_padded_lse(self, pad_to: int, force_pad_inf: bool = False) -> torch.Tensor:
|
||||
pad_amount = (pad_to - (self.lse.shape[2] % pad_to)) % pad_to
|
||||
lse = self.lse
|
||||
if pad_amount > 0:
|
||||
if force_pad_inf:
|
||||
lse = lse[:, :, : self.out.shape[1]]
|
||||
pad_amount = (pad_to - (lse.shape[2] % pad_to)) % pad_to
|
||||
lse = torch.nn.functional.pad(lse, [0, pad_amount], value=math.inf)
|
||||
elif force_pad_inf and self.out.shape[1] != lse.shape[2]:
|
||||
lse[:, :, self.out.shape[1] :].fill_(math.inf)
|
||||
return lse
|
||||
|
||||
|
||||
@dataclass
|
||||
class Gradients:
|
||||
dq: torch.Tensor
|
||||
dk: torch.Tensor
|
||||
dv: torch.Tensor
|
||||
# bias gradient. None if there is no tensor bias or if it doesn't require grad
|
||||
db: Optional[torch.Tensor] = None
|
||||
|
||||
|
||||
class AttentionOpBase(BaseOperator):
|
||||
"""Base class for any attention operator in xFormers
|
||||
|
||||
See:
|
||||
|
||||
- :attr:`xformers.ops.fmha.cutlass.FwOp`
|
||||
- :attr:`xformers.ops.fmha.cutlass.BwOp`
|
||||
- :attr:`xformers.ops.fmha.flash.FwOp`
|
||||
- :attr:`xformers.ops.fmha.flash.BwOp`
|
||||
- :attr:`xformers.ops.fmha.triton.FwOp`
|
||||
- :attr:`xformers.ops.fmha.triton.BwOp`
|
||||
- :attr:`xformers.ops.fmha.small_k.FwOp`
|
||||
- :attr:`xformers.ops.fmha.small_k.BwOp`
|
||||
"""
|
||||
|
||||
OPERATOR: Any
|
||||
SUPPORTED_DEVICES: Set[str]
|
||||
CUDA_MINIMUM_COMPUTE_CAPABILITY: Tuple[int, int] = (5, 0)
|
||||
SUPPORTED_DTYPES: Set[torch.dtype]
|
||||
SUPPORTED_MAX_K: float
|
||||
SUPPORTED_ATTN_BIAS_TYPES: Set[Any] = {type(None)}
|
||||
SUPPORTS_DROPOUT: bool
|
||||
SUPPORTS_CUSTOM_SCALE: bool = False
|
||||
SUPPORTS_DIFFERENT_VALUE_EMBED: bool = False
|
||||
IS_DETERMINISTIC: bool = True
|
||||
SUPPORTS_BMGHK: bool = False
|
||||
NAME: str
|
||||
OPERATOR_CATEGORY = "memory_efficient_attention"
|
||||
|
||||
_TEST_BATCH_SIZES: List[int] = [1, 300]
|
||||
_TEST_K: List[int] = [32, 128]
|
||||
|
||||
@classmethod
|
||||
def supports(cls, d: Inputs) -> bool:
|
||||
return not cls.not_supported_reasons(d)
|
||||
|
||||
@classmethod
|
||||
def shape_not_supported_reasons(
|
||||
cls, Mq: int, Mkv: int, K: int, Kv: int
|
||||
) -> List[str]:
|
||||
reasons = []
|
||||
if not cls.SUPPORTS_DIFFERENT_VALUE_EMBED and K != Kv:
|
||||
reasons.append("query.shape[-1] != value.shape[-1]")
|
||||
if max(K, Kv) > cls.SUPPORTED_MAX_K:
|
||||
reasons.append(
|
||||
f"max(query.shape[-1] != value.shape[-1]) > {cls.SUPPORTED_MAX_K}"
|
||||
)
|
||||
return reasons
|
||||
|
||||
@classmethod
|
||||
def not_supported_reasons(cls, d: Inputs) -> List[str]:
|
||||
"""
|
||||
Returns a list of reasons why this is not supported.
|
||||
The kernel can run these inputs only if the returned list is empty
|
||||
"""
|
||||
reasons = cls.shape_not_supported_reasons(
|
||||
Mq=d.query.shape[1],
|
||||
Mkv=d.key.shape[1],
|
||||
K=d.query.shape[-1],
|
||||
Kv=d.query.shape[-1],
|
||||
)
|
||||
device_type = d.query.device.type
|
||||
dtype = d.query.dtype
|
||||
if device_type not in cls.SUPPORTED_DEVICES:
|
||||
reasons.append(f"device={device_type} (supported: {cls.SUPPORTED_DEVICES})")
|
||||
if device_type == "cuda" and not _built_with_cuda:
|
||||
reasons.append("xFormers wasn't build with CUDA support")
|
||||
if device_type == "cuda":
|
||||
device_capability = torch.cuda.get_device_capability(d.device)
|
||||
if device_capability < cls.CUDA_MINIMUM_COMPUTE_CAPABILITY:
|
||||
reasons.append(
|
||||
f"requires device with capability > {cls.CUDA_MINIMUM_COMPUTE_CAPABILITY} "
|
||||
f"but your GPU has capability {device_capability} (too old)"
|
||||
)
|
||||
if dtype not in cls.SUPPORTED_DTYPES:
|
||||
reasons.append(f"dtype={dtype} (supported: {cls.SUPPORTED_DTYPES})")
|
||||
if type(d.attn_bias) not in cls.SUPPORTED_ATTN_BIAS_TYPES:
|
||||
reasons.append(f"attn_bias type is {type(d.attn_bias)}")
|
||||
if (d.p != 0.0) and not cls.SUPPORTS_DROPOUT:
|
||||
reasons.append("dropout > 0.0")
|
||||
if d.scale is not None and not cls.SUPPORTS_CUSTOM_SCALE:
|
||||
reasons.append("has custom scale")
|
||||
# bfloat16 is only supported on A100+
|
||||
# ... although the kernels can still run and give the
|
||||
# correct result
|
||||
if dtype is torch.bfloat16 and (
|
||||
not device_type.startswith("cuda")
|
||||
):
|
||||
reasons.append("bf16 is only supported on A100+ GPUs")
|
||||
if not cls.is_available():
|
||||
reasons.append(
|
||||
"operator wasn't built - see `python -m xformers.info` for more info"
|
||||
)
|
||||
if not cls.IS_DETERMINISTIC and torch.are_deterministic_algorithms_enabled():
|
||||
reasons.append(
|
||||
"operator is non-deterministic, but `torch.use_deterministic_algorithms` is set"
|
||||
)
|
||||
if not cls.SUPPORTS_BMGHK and d.query.ndim == 5:
|
||||
reasons.append("operator does not support BMGHK format")
|
||||
return reasons
|
||||
|
||||
|
||||
class AttentionFwOpBase(AttentionOpBase):
|
||||
ERROR_ATOL: Mapping[torch.dtype, float] = {
|
||||
torch.float: 3e-4,
|
||||
torch.half: 4e-3,
|
||||
torch.bfloat16: 2e-2,
|
||||
}
|
||||
ERROR_RTOL: Mapping[torch.dtype, float] = {
|
||||
torch.float: 2e-5,
|
||||
torch.half: 4e-4,
|
||||
torch.bfloat16: 5e-3,
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def apply(
|
||||
cls, inp: Inputs, needs_gradient: bool
|
||||
) -> Tuple[torch.Tensor, Optional[Context]]:
|
||||
raise NotImplementedError()
|
||||
|
||||
@classmethod
|
||||
def attn_operator_flop(
|
||||
cls,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
causal: bool = False,
|
||||
seqstart_k: Optional[torch.Tensor] = None,
|
||||
seqstart_q: Optional[torch.Tensor] = None,
|
||||
) -> int:
|
||||
"""
|
||||
Computes total flops for the attention
|
||||
Assumes inputs in format BMHK
|
||||
"""
|
||||
assert query.ndim == 4
|
||||
|
||||
if seqstart_q is not None:
|
||||
seqstart_q_py = seqstart_q.tolist()
|
||||
else:
|
||||
seqstart_q_py = [0, query.shape[1]]
|
||||
if seqstart_k is not None:
|
||||
seqstart_k_py = seqstart_k.tolist()
|
||||
else:
|
||||
seqstart_k_py = [0, key.shape[1]]
|
||||
|
||||
total_flop = 0
|
||||
for q_start, q_end, k_start, k_end in zip(
|
||||
seqstart_q_py, seqstart_q_py[1:], seqstart_k_py, seqstart_k_py[1:]
|
||||
):
|
||||
num_q = q_end - q_start
|
||||
num_kv = k_end - k_start
|
||||
# (M,K) @ (K,N) GEMM needs M*N*K*2 flop
|
||||
# Q @ K.transpose
|
||||
total_flop += num_q * num_kv * query.shape[-1] * 2
|
||||
# (ignore softmax)
|
||||
# attn @ V
|
||||
total_flop += num_q * key.shape[-1] * num_kv * 2
|
||||
# Multiply by num_heads and batches
|
||||
total_flop = total_flop * value.shape[2] * value.shape[0]
|
||||
if causal:
|
||||
total_flop //= 2
|
||||
return total_flop
|
||||
|
||||
|
||||
class AttentionBwOpBase(AttentionOpBase):
|
||||
ERROR_ATOL: Mapping[torch.dtype, float] = {
|
||||
torch.float: 5e-4,
|
||||
torch.half: 9e-2,
|
||||
torch.bfloat16: 0.7,
|
||||
}
|
||||
ERROR_RTOL: Mapping[torch.dtype, float] = {
|
||||
torch.float: 1e-4,
|
||||
torch.half: 2e-2,
|
||||
torch.bfloat16: 0.1,
|
||||
}
|
||||
SUPPORTS_ATTN_BIAS_GRAD = False
|
||||
|
||||
@classmethod
|
||||
def not_supported_reasons(cls, d: Inputs) -> List[str]:
|
||||
reasons = super(AttentionBwOpBase, cls).not_supported_reasons(d)
|
||||
if (
|
||||
isinstance(d.attn_bias, torch.Tensor)
|
||||
and d.attn_bias.requires_grad
|
||||
and not cls.SUPPORTS_ATTN_BIAS_GRAD
|
||||
):
|
||||
reasons.append(
|
||||
"Computing the bias gradient is not supported (attn_bias.requires_grad = True)"
|
||||
)
|
||||
|
||||
return reasons
|
||||
|
||||
@classmethod
|
||||
def apply(cls, ctx: Context, inp: Inputs, grad: torch.Tensor) -> Gradients:
|
||||
raise NotImplementedError()
|
||||
|
||||
@classmethod
|
||||
def attn_operator_flop(
|
||||
cls,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
causal: bool = False,
|
||||
seqstart_k: Optional[torch.Tensor] = None,
|
||||
seqstart_q: Optional[torch.Tensor] = None,
|
||||
) -> int:
|
||||
"""
|
||||
Computes total flops for the attention
|
||||
Assumes inputs in format BMHK
|
||||
"""
|
||||
assert query.ndim == 4
|
||||
|
||||
if seqstart_q is not None:
|
||||
seqstart_q_py = seqstart_q.tolist()
|
||||
else:
|
||||
seqstart_q_py = [0, query.shape[1]]
|
||||
if seqstart_k is not None:
|
||||
seqstart_k_py = seqstart_k.tolist()
|
||||
else:
|
||||
seqstart_k_py = [0, key.shape[1]]
|
||||
|
||||
total_flop = 0
|
||||
for q_start, q_end, k_start, k_end in zip(
|
||||
seqstart_q_py, seqstart_q_py[1:], seqstart_k_py, seqstart_k_py[1:]
|
||||
):
|
||||
num_q = q_end - q_start
|
||||
num_kv = k_end - k_start
|
||||
Kqk = query.shape[-1]
|
||||
Kv = value.shape[-1]
|
||||
# (M,K) @ (K,N) GEMM needs M*N*K*2 flop
|
||||
# att = Q @ K.transpose
|
||||
total_flop += num_q * num_kv * Kqk * 2
|
||||
# att @ dO
|
||||
total_flop += num_kv * num_q * Kv * 2
|
||||
# dov = dO @ V
|
||||
total_flop += num_q * Kv * num_kv * 2
|
||||
# dov @ K
|
||||
total_flop += num_q * Kqk * num_kv * 2
|
||||
# dov @ Q
|
||||
total_flop += num_q * Kqk * num_kv * 2
|
||||
# Multiply by num_heads and batches
|
||||
total_flop = total_flop * value.shape[2] * value.shape[0]
|
||||
if causal:
|
||||
total_flop //= 2
|
||||
return total_flop
|
||||
|
||||
|
||||
AttentionOp = Tuple[
|
||||
Optional[Type[AttentionFwOpBase]], Optional[Type[AttentionBwOpBase]]
|
||||
]
|
||||
|
||||
|
||||
@dataclass
|
||||
class AttentionOpDispatch:
|
||||
"""Dispatcher to automatically select
|
||||
the best operator to run memory-efficient attention.
|
||||
|
||||
:Deprecated:
|
||||
|
||||
This class is deprecated and will be removed in a later version
|
||||
"""
|
||||
|
||||
op: AttentionOp
|
||||
|
||||
@classmethod
|
||||
def from_arguments(
|
||||
cls,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
attn_bias: Optional[Union[torch.Tensor, AttentionBias]] = None,
|
||||
p: float = 0.0,
|
||||
scale: Optional[float] = None,
|
||||
) -> "AttentionOpDispatch":
|
||||
"""Here for backward compatibility"""
|
||||
from .dispatch import _dispatch_bw, _dispatch_fw
|
||||
|
||||
inp = Inputs(
|
||||
query=query,
|
||||
key=key,
|
||||
value=value,
|
||||
attn_bias=attn_bias,
|
||||
p=p,
|
||||
scale=scale,
|
||||
)
|
||||
return AttentionOpDispatch(op=(_dispatch_fw(inp, True), _dispatch_bw(inp)))
|
||||
|
||||
|
||||
def bmk2bmhk(tensor, num_heads: int) -> torch.Tensor:
|
||||
if tensor.ndim == 4:
|
||||
return tensor
|
||||
return tensor.reshape([-1, num_heads, tensor.shape[1], tensor.shape[2]]).permute(
|
||||
(0, 2, 1, 3)
|
||||
)
|
||||
|
||||
|
||||
def check_lastdim_alignment_stride1(
|
||||
reasons: List[str], name: str, x: torch.Tensor, alignment: int
|
||||
) -> None:
|
||||
if x.shape[-1] % alignment != 0:
|
||||
reasons.append(f"{name}.shape[-1] % {alignment} != 0")
|
||||
elif x.stride(-2) % alignment != 0:
|
||||
reasons.append(
|
||||
f"{name}.stride(-2) % {alignment} != 0 ({name}.stride() = {x.stride()})"
|
||||
)
|
||||
# We can have stride=0 sometimes if dimension=1
|
||||
if x.stride(-1) > 1:
|
||||
reasons.append(
|
||||
f"{name}.stride(-1) > 1 ({name}.stride() = {x.stride()}) - you should call `.contiguous()` on the input"
|
||||
)
|
||||
479
pkgs/xformers/ops/fmha/cutlass.py
Normal file
479
pkgs/xformers/ops/fmha/cutlass.py
Normal file
@@ -0,0 +1,479 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the BSD license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
|
||||
from dataclasses import replace
|
||||
from enum import Enum
|
||||
from typing import Any, List, Mapping, Optional, Set, Tuple, Union
|
||||
|
||||
import torch
|
||||
|
||||
from ..common import get_xformers_operator, register_operator
|
||||
from . import attn_bias
|
||||
from .attn_bias import (
|
||||
AttentionBias,
|
||||
BlockDiagonalCausalLocalAttentionFromBottomRightMask,
|
||||
BlockDiagonalCausalLocalAttentionMask,
|
||||
BlockDiagonalCausalMask,
|
||||
BlockDiagonalCausalWithOffsetPaddedKeysMask,
|
||||
BlockDiagonalMask,
|
||||
LowerTriangularMask,
|
||||
LowerTriangularMaskWithTensorBias,
|
||||
)
|
||||
from .common import (
|
||||
AttentionBwOpBase,
|
||||
AttentionFwOpBase,
|
||||
Context,
|
||||
Gradients,
|
||||
Inputs,
|
||||
check_lastdim_alignment_stride1,
|
||||
)
|
||||
|
||||
|
||||
def _uses_tensorcores(sm: int, is_half: bool) -> bool:
|
||||
if sm >= 80:
|
||||
return True
|
||||
if sm >= 70:
|
||||
return is_half
|
||||
return False
|
||||
|
||||
|
||||
def _minimum_gemm_alignment(inp: Inputs) -> int:
|
||||
if inp.device.type != "cuda":
|
||||
return 1
|
||||
cap = torch.cuda.get_device_capability(inp.device)
|
||||
sm = cap[0] * 10 + cap[1]
|
||||
bits_per_scalar = {torch.float: 32, torch.half: 16, torch.bfloat16: 16}[
|
||||
inp.query.dtype
|
||||
]
|
||||
uses_tensorcores = _uses_tensorcores(sm, bits_per_scalar == 16)
|
||||
matmul_alignment_mn = 1
|
||||
if sm >= 80:
|
||||
matmul_alignment_mn = 4
|
||||
if uses_tensorcores:
|
||||
matmul_alignment_mn = max(matmul_alignment_mn, 128 // bits_per_scalar)
|
||||
return matmul_alignment_mn
|
||||
|
||||
|
||||
def _get_seqlen_info(
|
||||
inp: Inputs,
|
||||
) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor], int, int]:
|
||||
attn_bias = inp.attn_bias
|
||||
if isinstance(
|
||||
attn_bias, (BlockDiagonalMask, BlockDiagonalCausalWithOffsetPaddedKeysMask)
|
||||
):
|
||||
attn_bias.k_seqinfo.to(inp.query.device)
|
||||
attn_bias.q_seqinfo.to(inp.query.device)
|
||||
seqstart_k = attn_bias.k_seqinfo.seqstart
|
||||
seqstart_q = attn_bias.q_seqinfo.seqstart
|
||||
max_seqlen_q = attn_bias.q_seqinfo.max_seqlen
|
||||
max_seqlen_k = attn_bias.k_seqinfo.max_seqlen
|
||||
else:
|
||||
seqstart_k = None
|
||||
seqstart_q = None
|
||||
max_seqlen_q = -1
|
||||
max_seqlen_k = -1
|
||||
|
||||
return seqstart_k, seqstart_q, max_seqlen_q, max_seqlen_k
|
||||
|
||||
|
||||
def _get_tensor_bias(
|
||||
attn_bias: Optional[Union[torch.Tensor, AttentionBias]]
|
||||
) -> Optional[torch.Tensor]:
|
||||
if isinstance(attn_bias, torch.Tensor):
|
||||
return attn_bias
|
||||
elif isinstance(attn_bias, LowerTriangularMaskWithTensorBias):
|
||||
return attn_bias._bias
|
||||
return None
|
||||
|
||||
|
||||
def _check_bias_alignment(
|
||||
reasons: List[str], attn_bias: Optional[Union[torch.Tensor, AttentionBias]]
|
||||
) -> None:
|
||||
attn_bias_tensor = _get_tensor_bias(attn_bias)
|
||||
if attn_bias_tensor is not None:
|
||||
alignment = 128 // torch.finfo(attn_bias_tensor.dtype).bits
|
||||
show_padding_hint = False
|
||||
for d in range(attn_bias_tensor.ndim - 1):
|
||||
if attn_bias_tensor.stride(d) % alignment != 0:
|
||||
reasons.append(
|
||||
f"attn_bias.stride(-2) % {alignment} != 0 (attn_bias.stride() = {attn_bias_tensor.stride()})"
|
||||
)
|
||||
show_padding_hint = True
|
||||
if show_padding_hint:
|
||||
reasons.append(
|
||||
"""\
|
||||
HINT: To use an `attn_bias` with a sequence length that is not a multiple of 8, \
|
||||
you need to ensure memory is aligned by slicing a bigger tensor. \
|
||||
Example: use `attn_bias = torch.zeros([1, 1, 5, 8])[:,:,:,:5]` instead of `torch.zeros([1, 1, 5, 5])`"""
|
||||
)
|
||||
# We can have stride=0 sometimes if dimension=1
|
||||
if attn_bias_tensor.stride(-1) > 1:
|
||||
reasons.append(
|
||||
f"attn_bias.stride(-1) > 1 (attn_bias.stride() = {attn_bias_tensor.stride()}) - "
|
||||
"you should call `.contiguous()` on the bias"
|
||||
)
|
||||
|
||||
|
||||
class _CustomMaskType(int, Enum):
|
||||
"""
|
||||
(Matches CustomMaskType in C++.)
|
||||
"""
|
||||
|
||||
NoCustomMask = 0
|
||||
CausalFromTopLeft = 1
|
||||
CausalFromBottomRight = 2
|
||||
|
||||
|
||||
def _custom_mask_type(bias: Optional[Union[torch.Tensor, AttentionBias]]) -> int:
|
||||
if isinstance(
|
||||
bias,
|
||||
(
|
||||
LowerTriangularMask,
|
||||
BlockDiagonalCausalMask,
|
||||
BlockDiagonalCausalLocalAttentionMask,
|
||||
),
|
||||
):
|
||||
return int(_CustomMaskType.CausalFromTopLeft)
|
||||
if isinstance(
|
||||
bias,
|
||||
(
|
||||
attn_bias.BlockDiagonalCausalFromBottomRightMask,
|
||||
BlockDiagonalCausalWithOffsetPaddedKeysMask,
|
||||
BlockDiagonalCausalLocalAttentionFromBottomRightMask,
|
||||
),
|
||||
):
|
||||
return int(_CustomMaskType.CausalFromBottomRight)
|
||||
return int(_CustomMaskType.NoCustomMask)
|
||||
|
||||
|
||||
@register_operator
|
||||
class FwOp(AttentionFwOpBase):
|
||||
"""xFormers' MHA kernel based on CUTLASS.
|
||||
Supports a large number of settings (including without TensorCores, f32 ...)
|
||||
and GPUs as old as P100 (Sm60)
|
||||
"""
|
||||
|
||||
OPERATOR = get_xformers_operator("efficient_attention_forward_cutlass")
|
||||
SUPPORTED_DEVICES: Set[str] = {"cuda"}
|
||||
SUPPORTED_DTYPES: Set[torch.dtype] = {torch.float, torch.half, torch.bfloat16}
|
||||
SUPPORTED_MAX_K = 65536
|
||||
SUPPORTED_ATTN_BIAS_TYPES: Set[Any] = {
|
||||
type(None),
|
||||
torch.Tensor,
|
||||
LowerTriangularMask,
|
||||
LowerTriangularMaskWithTensorBias,
|
||||
BlockDiagonalMask,
|
||||
BlockDiagonalCausalMask,
|
||||
BlockDiagonalCausalWithOffsetPaddedKeysMask,
|
||||
attn_bias.BlockDiagonalCausalFromBottomRightMask,
|
||||
attn_bias.BlockDiagonalCausalLocalAttentionMask,
|
||||
BlockDiagonalCausalLocalAttentionFromBottomRightMask,
|
||||
}
|
||||
SUPPORTS_DROPOUT = True
|
||||
SUPPORTS_CUSTOM_SCALE = True
|
||||
SUPPORTS_DIFFERENT_VALUE_EMBED = True
|
||||
SUPPORTS_BMGHK = True
|
||||
NAME = "cutlassF"
|
||||
|
||||
_TEST_K: List[int] = [
|
||||
32, # 64x64 kernel
|
||||
128, # 64x128 kernel
|
||||
256, # 64x128 with accumulation in gmem
|
||||
]
|
||||
|
||||
@classmethod
|
||||
def apply(
|
||||
cls, inp: Inputs, needs_gradient: bool
|
||||
) -> Tuple[torch.Tensor, Optional[Context]]:
|
||||
if type(inp.attn_bias) not in FwOp.SUPPORTED_ATTN_BIAS_TYPES:
|
||||
raise NotImplementedError("Unsupported attn_bias type")
|
||||
if inp.query.ndim in [3, 4]:
|
||||
return cls.apply_bmhk(inp, needs_gradient=needs_gradient)
|
||||
assert inp.query.ndim == 5, f"query has shape {inp.query.shape}"
|
||||
|
||||
# Workaround until this is properly implemented in C++
|
||||
# run each head group in a different stream
|
||||
n_groups = inp.key.shape[2]
|
||||
main_stream = torch.cuda.current_stream()
|
||||
streams = [main_stream] + [
|
||||
torch.cuda.Stream(device=inp.query.device) for _ in range(n_groups - 1)
|
||||
]
|
||||
outs = []
|
||||
for group, stream in enumerate(streams):
|
||||
stream.wait_stream(main_stream)
|
||||
with torch.cuda.stream(stream):
|
||||
query = inp.query[:, :, group]
|
||||
key = inp.key[:, :, group]
|
||||
value = inp.value[:, :, group]
|
||||
bias = inp.attn_bias
|
||||
if isinstance(bias, torch.Tensor):
|
||||
bias = bias[:, group]
|
||||
if isinstance(bias, attn_bias.LowerTriangularMaskWithTensorBias):
|
||||
bias = attn_bias.LowerTriangularMaskWithTensorBias(
|
||||
bias._bias[:, group]
|
||||
)
|
||||
outs.append(
|
||||
cls.apply_bmhk(
|
||||
replace(inp, query=query, key=key, value=value, attn_bias=bias),
|
||||
needs_gradient=needs_gradient,
|
||||
)
|
||||
)
|
||||
for s in streams[1:]:
|
||||
main_stream.wait_stream(s)
|
||||
out = torch.stack([o[0] for o in outs], dim=2)
|
||||
ctx: Optional[Context] = None
|
||||
if needs_gradient:
|
||||
ctx = Context(
|
||||
out=out,
|
||||
lse=torch.stack([o[1].lse for o in outs], dim=1), # type: ignore
|
||||
op_bw=outs[0][1].op_bw, # type: ignore
|
||||
)
|
||||
return out, ctx
|
||||
|
||||
@classmethod
|
||||
def apply_bmhk(
|
||||
cls, inp: Inputs, needs_gradient: bool
|
||||
) -> Tuple[torch.Tensor, Optional[Context]]:
|
||||
if type(inp.attn_bias) not in FwOp.SUPPORTED_ATTN_BIAS_TYPES:
|
||||
raise NotImplementedError("Unsupported attn_bias type")
|
||||
seqstart_k, seqstart_q, max_seqlen_q, _ = _get_seqlen_info(inp)
|
||||
out, lse, rng_seed, rng_offset = cls.OPERATOR(
|
||||
query=inp.query,
|
||||
key=inp.key,
|
||||
value=inp.value,
|
||||
attn_bias=_get_tensor_bias(inp.attn_bias),
|
||||
seqstart_q=seqstart_q,
|
||||
seqstart_k=seqstart_k,
|
||||
max_seqlen_q=max_seqlen_q,
|
||||
dropout_p=inp.p,
|
||||
compute_logsumexp=needs_gradient,
|
||||
custom_mask_type=_custom_mask_type(inp.attn_bias),
|
||||
scale=inp.scale,
|
||||
seqlen_k=inp.attn_bias.k_seqinfo.seqlen
|
||||
if isinstance(inp.attn_bias, BlockDiagonalCausalWithOffsetPaddedKeysMask)
|
||||
else None,
|
||||
window_size=inp.attn_bias._window_size
|
||||
if isinstance(
|
||||
inp.attn_bias,
|
||||
(
|
||||
BlockDiagonalCausalLocalAttentionMask,
|
||||
BlockDiagonalCausalLocalAttentionFromBottomRightMask,
|
||||
),
|
||||
)
|
||||
else None,
|
||||
)
|
||||
ctx: Optional[Context] = None
|
||||
if needs_gradient:
|
||||
ctx = Context(
|
||||
out=out,
|
||||
lse=lse,
|
||||
# cutlass forward is only compatible with cutlass backward if
|
||||
# dropout is used (because of the way RNG states are passed and the
|
||||
# way random numbers are generated during backward)
|
||||
op_bw=BwOp if inp.p != 0 else None,
|
||||
)
|
||||
if inp.p != 0:
|
||||
ctx.rng_state = torch.tensor(
|
||||
[rng_seed, rng_offset], dtype=torch.int64, device="cpu"
|
||||
)
|
||||
return out, ctx
|
||||
|
||||
@classmethod
|
||||
def not_supported_reasons(cls, d: Inputs) -> List[str]:
|
||||
reasons = super(FwOp, cls).not_supported_reasons(d)
|
||||
matmul_alignment_mn = _minimum_gemm_alignment(d)
|
||||
check_lastdim_alignment_stride1(reasons, "query", d.query, matmul_alignment_mn)
|
||||
check_lastdim_alignment_stride1(reasons, "value", d.value, matmul_alignment_mn)
|
||||
_check_bias_alignment(reasons, d.attn_bias)
|
||||
return reasons
|
||||
|
||||
@classmethod
|
||||
# type: ignore
|
||||
def operator_flop(
|
||||
cls,
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
b,
|
||||
seqstart_q,
|
||||
seqstart_k,
|
||||
max_seqlen_q_,
|
||||
compute_lse,
|
||||
custom_mask_type,
|
||||
*a,
|
||||
) -> int:
|
||||
return cls.attn_operator_flop(
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
causal=custom_mask_type > 0,
|
||||
seqstart_k=seqstart_k,
|
||||
seqstart_q=seqstart_q,
|
||||
)
|
||||
|
||||
|
||||
@register_operator
|
||||
class BwOp(AttentionBwOpBase):
|
||||
__doc__ = FwOp.__doc__
|
||||
|
||||
OPERATOR = get_xformers_operator("efficient_attention_backward_cutlass")
|
||||
SUPPORTED_DEVICES = FwOp.SUPPORTED_DEVICES
|
||||
SUPPORTED_DTYPES = FwOp.SUPPORTED_DTYPES
|
||||
SUPPORTED_MAX_K = FwOp.SUPPORTED_MAX_K
|
||||
SUPPORTED_ATTN_BIAS_TYPES: Set[Any] = {
|
||||
type(None),
|
||||
torch.Tensor,
|
||||
LowerTriangularMask,
|
||||
# TODO: Fix handling of gradient through the fMHA autograd function
|
||||
# LowerTriangularMaskWithTensorBias,
|
||||
BlockDiagonalMask,
|
||||
BlockDiagonalCausalMask,
|
||||
attn_bias.BlockDiagonalCausalFromBottomRightMask,
|
||||
attn_bias.BlockDiagonalCausalLocalAttentionMask,
|
||||
}
|
||||
SUPPORTS_ATTN_BIAS_GRAD = True
|
||||
SUPPORTS_DROPOUT = FwOp.SUPPORTS_DROPOUT
|
||||
SUPPORTS_CUSTOM_SCALE = FwOp.SUPPORTS_CUSTOM_SCALE
|
||||
SUPPORTS_DIFFERENT_VALUE_EMBED = FwOp.SUPPORTS_DIFFERENT_VALUE_EMBED
|
||||
NAME = "cutlassB"
|
||||
|
||||
ERROR_ATOL: Mapping[torch.dtype, float] = {
|
||||
torch.float: 5e-4,
|
||||
# increased from 9e-2, more opportunities for numerical errors when bias is
|
||||
# used, noticed in gK on SM80
|
||||
torch.half: 1e-1,
|
||||
torch.bfloat16: 7e-1,
|
||||
}
|
||||
|
||||
_TEST_K: List[int] = [
|
||||
32, # 64x64 kernel
|
||||
128, # 64x128/128x128 kernel
|
||||
256, # 64x128 with accumulation in gmem
|
||||
]
|
||||
|
||||
@classmethod
|
||||
def not_supported_reasons(cls, d: Inputs) -> List[str]:
|
||||
reasons = super(BwOp, cls).not_supported_reasons(d)
|
||||
matmul_alignment_mn = _minimum_gemm_alignment(d)
|
||||
|
||||
check_lastdim_alignment_stride1(reasons, "query", d.query, matmul_alignment_mn)
|
||||
check_lastdim_alignment_stride1(reasons, "key", d.key, matmul_alignment_mn)
|
||||
check_lastdim_alignment_stride1(reasons, "value", d.value, matmul_alignment_mn)
|
||||
_check_bias_alignment(reasons, d.attn_bias)
|
||||
attn_bias_tensor = _get_tensor_bias(d.attn_bias)
|
||||
|
||||
# Backprop of gradient through broadcasted bias is not supported
|
||||
if attn_bias_tensor is not None and attn_bias_tensor.requires_grad:
|
||||
# Don't forget that inputs are either in BMK or BMHK!
|
||||
if d.query.ndim == 3 and attn_bias_tensor.ndim == 3:
|
||||
expected_bias_shape = (*d.query.shape[:2], d.key.shape[1])
|
||||
else:
|
||||
# bias is B H Mq Mk
|
||||
expected_bias_shape = (
|
||||
d.query.shape[0],
|
||||
d.query.shape[2] if d.query.ndim == 4 else 1,
|
||||
d.query.shape[1],
|
||||
d.key.shape[1],
|
||||
)
|
||||
if tuple(attn_bias_tensor.shape) != expected_bias_shape:
|
||||
reasons.append(
|
||||
"Broadcasting the `attn_bias` tensor is not supported "
|
||||
f"(shape: {tuple(attn_bias_tensor.shape)}"
|
||||
f"/ expected: {expected_bias_shape})"
|
||||
)
|
||||
return reasons
|
||||
|
||||
@classmethod
|
||||
def apply(cls, ctx: Context, inp: Inputs, grad: torch.Tensor) -> Gradients:
|
||||
if type(inp.attn_bias) not in BwOp.SUPPORTED_ATTN_BIAS_TYPES:
|
||||
raise NotImplementedError("Unsupported attn_bias type")
|
||||
|
||||
seqstart_k, seqstart_q, max_seqlen_q, max_seqlen_k = _get_seqlen_info(inp)
|
||||
dtype = inp.query.dtype
|
||||
|
||||
rng_seed = rng_offset = 0
|
||||
if inp.p != 0.0:
|
||||
if (
|
||||
ctx.rng_state is None
|
||||
or ctx.rng_state.dtype != torch.int64
|
||||
or ctx.rng_state.device.type != "cpu"
|
||||
or ctx.rng_state.shape != (2,)
|
||||
):
|
||||
raise NotImplementedError(f"Invalid rng_state: {ctx.rng_state}")
|
||||
rng_seed, rng_offset = ctx.rng_state.tolist()
|
||||
|
||||
force_pad_inf = torch.cuda.get_device_capability(inp.query.device) == (7, 5)
|
||||
(grad_q, grad_k, grad_v, grad_bias) = cls.OPERATOR(
|
||||
grad.to(dtype),
|
||||
inp.query,
|
||||
inp.key,
|
||||
inp.value,
|
||||
_get_tensor_bias(inp.attn_bias),
|
||||
cu_seqlens_q=seqstart_q,
|
||||
cu_seqlens_k=seqstart_k,
|
||||
max_seqlen_q=max_seqlen_q,
|
||||
max_seqlen_k=max_seqlen_k,
|
||||
logsumexp=ctx.get_padded_lse(32, force_pad_inf=force_pad_inf),
|
||||
output=ctx.out.to(dtype),
|
||||
dropout_p=inp.p,
|
||||
# if not using dropout, seed and offset are irrelevant but still expected
|
||||
# in function signature so just pass 0
|
||||
# seed and offset could be None if a different FW op other than cutlass
|
||||
# was used.
|
||||
rng_seed=rng_seed,
|
||||
rng_offset=rng_offset,
|
||||
custom_mask_type=_custom_mask_type(inp.attn_bias),
|
||||
scale=inp.scale,
|
||||
num_splits_key=-1, # Let C++ determine it
|
||||
window_size=inp.attn_bias._window_size
|
||||
if isinstance(
|
||||
inp.attn_bias,
|
||||
(
|
||||
BlockDiagonalCausalLocalAttentionMask,
|
||||
BlockDiagonalCausalLocalAttentionFromBottomRightMask,
|
||||
),
|
||||
)
|
||||
else None,
|
||||
)
|
||||
|
||||
# c++/CUDA implementation returns an uninitialized tensor if bias doesn't
|
||||
# require grad
|
||||
if not (
|
||||
isinstance(inp.attn_bias, torch.Tensor) and inp.attn_bias.requires_grad
|
||||
):
|
||||
grad_bias = None
|
||||
|
||||
return Gradients(dq=grad_q, dk=grad_k, dv=grad_v, db=grad_bias)
|
||||
|
||||
@classmethod
|
||||
# type: ignore
|
||||
def operator_flop(
|
||||
cls,
|
||||
dO,
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
b,
|
||||
cu_seqlens_q,
|
||||
cu_seqlens_k,
|
||||
max_seqlen_q,
|
||||
max_seqlen_k,
|
||||
logsumexp,
|
||||
output,
|
||||
dropout_p,
|
||||
rng_seed,
|
||||
rng_offset,
|
||||
custom_mask_type,
|
||||
scale,
|
||||
) -> int:
|
||||
return cls.attn_operator_flop(
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
seqstart_q=cu_seqlens_q,
|
||||
seqstart_k=cu_seqlens_k,
|
||||
causal=custom_mask_type > 0,
|
||||
)
|
||||
108
pkgs/xformers/ops/fmha/decoder.py
Normal file
108
pkgs/xformers/ops/fmha/decoder.py
Normal file
@@ -0,0 +1,108 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the BSD license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
from typing import Any, List, Optional, Set, Tuple
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from ..common import get_xformers_operator, register_operator
|
||||
from .attn_bias import BlockDiagonalCausalWithOffsetPaddedKeysMask
|
||||
from .common import AttentionFwOpBase, Context, Inputs
|
||||
|
||||
|
||||
@register_operator
|
||||
class FwOp(AttentionFwOpBase):
|
||||
"""An operator optimized for very small values of K (``K <= 32``) \
|
||||
and f32 pre-Ampere as it does not use TensorCores.
|
||||
Only supports contiguous inputs in BMK format, so an extra reshape \
|
||||
or contiguous call might be done.
|
||||
|
||||
:Deprecated:
|
||||
|
||||
This operator is deprecated and should not be used in new code
|
||||
"""
|
||||
|
||||
OPERATOR = get_xformers_operator("efficient_attention_forward_decoder")
|
||||
SUPPORTED_DEVICES = {"cuda"}
|
||||
SUPPORTED_DTYPES = {torch.bfloat16, torch.half, torch.float32}
|
||||
CUDA_MINIMUM_COMPUTE_CAPABILITY = (7, 0)
|
||||
SUPPORTED_MAX_K: float = 128
|
||||
SUPPORTED_ATTN_BIAS_TYPES: Set[Any] = {BlockDiagonalCausalWithOffsetPaddedKeysMask}
|
||||
SUPPORTS_DROPOUT = False
|
||||
SUPPORTS_CUSTOM_SCALE = True
|
||||
NAME = "decoderF"
|
||||
|
||||
@classmethod
|
||||
def not_supported_reasons(cls, d: Inputs) -> List[str]:
|
||||
reasons = super(FwOp, cls).not_supported_reasons(d)
|
||||
|
||||
attn_bias = d.attn_bias
|
||||
if isinstance(attn_bias, BlockDiagonalCausalWithOffsetPaddedKeysMask):
|
||||
# If we don't get here, we've an error elsewhere
|
||||
if d.query.ndim != 4 or d.key.ndim != 4:
|
||||
reasons.append("Inputs must be BMHK. BMK not supported")
|
||||
|
||||
if d.query.shape[0] != 1:
|
||||
reasons.append("One formal batch element expected")
|
||||
|
||||
if d.query.shape[-1] != 128:
|
||||
reasons.append("Only head_dim==128 for now.")
|
||||
|
||||
if d.key.stride(-1) != 1:
|
||||
reasons.append("expect keys to have last dim contiguous")
|
||||
|
||||
if d.value.stride(-1) != 1:
|
||||
reasons.append("expect values to have last dim contiguous")
|
||||
|
||||
q_starts = attn_bias.q_seqinfo.seqstart_py
|
||||
if attn_bias.q_seqinfo.max_seqlen != 1:
|
||||
reasons.append("decoding expects one query")
|
||||
elif d.query.shape[1] != len(q_starts) - 1:
|
||||
reasons.append("empty lanes not supported yet")
|
||||
|
||||
if attn_bias.k_seqinfo.padding > 8192:
|
||||
reasons.append("key padding exceeds 8192")
|
||||
|
||||
return reasons
|
||||
|
||||
@classmethod
|
||||
def apply(
|
||||
cls, inp: Inputs, needs_gradient: bool
|
||||
) -> Tuple[torch.Tensor, Optional[Context]]:
|
||||
if needs_gradient:
|
||||
raise NotImplementedError("gradient")
|
||||
attn_bias = inp.attn_bias
|
||||
assert isinstance(attn_bias, BlockDiagonalCausalWithOffsetPaddedKeysMask)
|
||||
|
||||
attn_bias.k_seqinfo.to(inp.query.device)
|
||||
attn_bias.q_seqinfo.to(inp.query.device)
|
||||
|
||||
padding = attn_bias.k_seqinfo.padding
|
||||
multiquery = inp.key.stride(2) == 0
|
||||
if multiquery:
|
||||
key = inp.key[0, :, :1].unflatten(0, (-1, padding))
|
||||
value = inp.value[0, :, :1].unflatten(0, (-1, padding))
|
||||
else:
|
||||
key = inp.key[0].unflatten(0, (-1, padding))
|
||||
value = inp.value[0].unflatten(0, (-1, padding))
|
||||
|
||||
seq_positions = attn_bias.k_seqinfo.seqlen
|
||||
|
||||
query = inp.query[0, :, None]
|
||||
|
||||
if inp.scale is not None:
|
||||
qk_scale = inp.scale
|
||||
else:
|
||||
qk_scale = 1.0 / np.sqrt(key.shape[-1])
|
||||
|
||||
out = cls.OPERATOR(
|
||||
query=query,
|
||||
key=key,
|
||||
value=value,
|
||||
seq_positions=seq_positions,
|
||||
scale=qk_scale,
|
||||
)
|
||||
return out, None
|
||||
147
pkgs/xformers/ops/fmha/dispatch.py
Normal file
147
pkgs/xformers/ops/fmha/dispatch.py
Normal file
@@ -0,0 +1,147 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the BSD license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
|
||||
import textwrap
|
||||
from collections import deque
|
||||
from typing import List, Sequence, Type, TypeVar
|
||||
|
||||
from . import attn_bias, cutlass, decoder, flash, small_k, triton, triton_splitk
|
||||
from .common import AttentionBwOpBase, AttentionFwOpBase, Inputs
|
||||
|
||||
|
||||
def _is_cutlass_fwd_faster_than_flash(inp: Inputs) -> bool:
|
||||
return False
|
||||
|
||||
|
||||
def _is_triton_fwd_fastest(inp: Inputs) -> bool:
|
||||
# TODO: fill out
|
||||
return False
|
||||
|
||||
|
||||
T = TypeVar("T", Type[AttentionFwOpBase], Type[AttentionBwOpBase])
|
||||
|
||||
|
||||
def _format_inputs_description(inp: Inputs) -> str:
|
||||
return f"""query : shape={tuple(inp.query.shape)} ({inp.query.dtype})
|
||||
key : shape={tuple(inp.key.shape)} ({inp.key.dtype})
|
||||
value : shape={tuple(inp.value.shape)} ({inp.value.dtype})
|
||||
attn_bias : {type(inp.attn_bias)}
|
||||
p : {inp.p}"""
|
||||
|
||||
|
||||
def _ensure_op_supports_or_raise(exc_type, name: str, op, inp: Inputs) -> None:
|
||||
reasons = op.not_supported_reasons(inp)
|
||||
if not reasons:
|
||||
return
|
||||
raise exc_type(
|
||||
f"""Operator `{name}` does not support inputs:
|
||||
{textwrap.indent(_format_inputs_description(inp), ' ')}
|
||||
{_format_not_supported_reasons(op, reasons)}"""
|
||||
)
|
||||
|
||||
|
||||
def _format_not_supported_reasons(op, reasons: List[str]) -> str:
|
||||
return f"`{op.NAME}` is not supported because:\n " + "\n ".join(reasons)
|
||||
|
||||
|
||||
def _run_priority_list(name: str, priority_list: Sequence[T], inp: Inputs) -> T:
|
||||
not_supported_reasons: List[List[str]] = []
|
||||
for op in priority_list:
|
||||
not_supported = op.not_supported_reasons(inp)
|
||||
if not not_supported:
|
||||
return op
|
||||
not_supported_reasons.append(not_supported)
|
||||
|
||||
# Let's write a nice message explaining what we tried and why it's not supported
|
||||
msg = f"""No operator found for `{name}` with inputs:
|
||||
{textwrap.indent(_format_inputs_description(inp), ' ')}"""
|
||||
for op, not_supported in zip(priority_list, not_supported_reasons):
|
||||
msg += "\n" + _format_not_supported_reasons(op, not_supported)
|
||||
raise NotImplementedError(msg)
|
||||
|
||||
|
||||
def _dispatch_fw_priority_list(
|
||||
inp: Inputs, needs_gradient: bool
|
||||
) -> Sequence[Type[AttentionFwOpBase]]:
|
||||
priority_list_ops = deque(
|
||||
[
|
||||
flash.FwOp,
|
||||
triton.FwOp,
|
||||
cutlass.FwOp,
|
||||
small_k.FwOp,
|
||||
]
|
||||
)
|
||||
if _is_cutlass_fwd_faster_than_flash(inp):
|
||||
priority_list_ops.remove(cutlass.FwOp)
|
||||
priority_list_ops.appendleft(cutlass.FwOp)
|
||||
if _is_triton_fwd_fastest(inp):
|
||||
priority_list_ops.remove(triton.FwOp)
|
||||
priority_list_ops.appendleft(triton.FwOp)
|
||||
if not needs_gradient:
|
||||
mqa_or_gqa = (
|
||||
inp.key.ndim > 3 and inp.key.stride(-2) == 0 and inp.key.shape[-2] > 1
|
||||
)
|
||||
if not mqa_or_gqa:
|
||||
# With multiquery, cutlass is sometimes faster than decoder
|
||||
# but it's not currently clear when.
|
||||
priority_list_ops.appendleft(decoder.FwOp)
|
||||
# Split-KV is useful with MQA
|
||||
# for short Q-seqlen / long K-seqlen
|
||||
if mqa_or_gqa and inp.query.shape[1] <= 32 and inp.key.shape[1] >= 256:
|
||||
parallelism_BH = 0 # BMK
|
||||
if inp.query.ndim == 3:
|
||||
parallelism_BH = inp.query.shape[0]
|
||||
elif inp.query.ndim == 4: # BMHK
|
||||
parallelism_BH = inp.query.shape[0] * inp.query.shape[2]
|
||||
elif inp.query.ndim == 5: # BMGHK
|
||||
parallelism_BH = inp.query.shape[0] * inp.query.shape[2]
|
||||
if parallelism_BH > 0 and parallelism_BH < 64:
|
||||
priority_list_ops.appendleft(triton_splitk.FwOp)
|
||||
# Without variable seqlen flash is fastest
|
||||
if not isinstance(inp.attn_bias, attn_bias.BlockDiagonalMask):
|
||||
priority_list_ops.remove(flash.FwOp)
|
||||
priority_list_ops.appendleft(flash.FwOp)
|
||||
|
||||
return priority_list_ops
|
||||
|
||||
|
||||
def _dispatch_fw(inp: Inputs, needs_gradient: bool) -> Type[AttentionFwOpBase]:
|
||||
"""Computes the best operator for forward
|
||||
|
||||
Raises:
|
||||
NotImplementedError: if not operator was found
|
||||
|
||||
Returns:
|
||||
AttentionOp: The best operator for the configuration
|
||||
"""
|
||||
# return _run_priority_list(
|
||||
# "memory_efficient_attention_forward",
|
||||
# _dispatch_fw_priority_list(inp, needs_gradient),
|
||||
# inp,
|
||||
# )
|
||||
return flash.FwOp
|
||||
|
||||
|
||||
def _is_cutlassB_faster_than_flash(inp: Inputs) -> bool:
|
||||
return False
|
||||
|
||||
|
||||
def _dispatch_bw(inp: Inputs) -> Type[AttentionBwOpBase]:
|
||||
priority_list_ops: List[Type[AttentionBwOpBase]] = [
|
||||
flash.BwOp,
|
||||
cutlass.BwOp,
|
||||
# CUDA illegal memory issues, race conditions etc..
|
||||
# triton.BwOp,
|
||||
# Deprecated
|
||||
small_k.BwOp,
|
||||
]
|
||||
# if _is_cutlassB_faster_than_flash(inp):
|
||||
# priority_list_ops.remove(cutlass.BwOp)
|
||||
# priority_list_ops.insert(0, cutlass.BwOp)
|
||||
# return _run_priority_list(
|
||||
# "memory_efficient_attention_backward", priority_list_ops, inp
|
||||
# )
|
||||
return flash.BwOp
|
||||
666
pkgs/xformers/ops/fmha/flash.py
Normal file
666
pkgs/xformers/ops/fmha/flash.py
Normal file
@@ -0,0 +1,666 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the BSD license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
|
||||
from dataclasses import replace
|
||||
from itertools import zip_longest
|
||||
from typing import Any, List, Optional, Set, Tuple, Union
|
||||
import os
|
||||
import torch
|
||||
|
||||
from ..common import _get_storage_base, get_operator, register_operator
|
||||
from .attn_bias import (
|
||||
AttentionBias,
|
||||
BlockDiagonalCausalFromBottomRightMask,
|
||||
BlockDiagonalCausalLocalAttentionFromBottomRightMask,
|
||||
BlockDiagonalCausalLocalAttentionMask,
|
||||
BlockDiagonalCausalMask,
|
||||
BlockDiagonalMask,
|
||||
LowerTriangularMask,
|
||||
)
|
||||
from .common import (
|
||||
AttentionBwOpBase,
|
||||
AttentionFwOpBase,
|
||||
Context,
|
||||
Gradients,
|
||||
Inputs,
|
||||
check_lastdim_alignment_stride1,
|
||||
)
|
||||
global enable_ixdnn
|
||||
enable_ixdnn = os.getenv('ENABLE_FLASH_ATTENTION_WITH_IXDNN', '1') != '0'
|
||||
|
||||
FLASH_VERSION = "0.0.0"
|
||||
try:
|
||||
try:
|
||||
from ... import _C_flashattention # type: ignore[attr-defined]
|
||||
from ..._cpp_lib import _build_metadata
|
||||
|
||||
if _build_metadata is not None:
|
||||
FLASH_VERSION = _build_metadata.flash_version
|
||||
except ImportError:
|
||||
import flash_attn
|
||||
from flash_attn.flash_attn_interface import flash_attn_cuda as _C_flashattention
|
||||
|
||||
FLASH_VERSION = flash_attn.__version__
|
||||
flash_ver_parsed = tuple(int(s) for s in FLASH_VERSION.split(".")[:2])
|
||||
if flash_ver_parsed < (2, 3):
|
||||
raise ImportError("Requires 2.3 for sliding window support")
|
||||
|
||||
# create library so that flash-attn goes through the PyTorch Dispatcher
|
||||
_flash_lib = torch.library.Library("xformers_flash", "DEF")
|
||||
|
||||
_flash_lib.define(
|
||||
"flash_fwd(Tensor query, Tensor key, Tensor value, "
|
||||
"Tensor? cu_seqlens_q, Tensor? cu_seqlens_k, "
|
||||
"int max_seqlen_q, int max_seqlen_k, "
|
||||
"float p, float softmax_scale, "
|
||||
"bool is_causal, int window_size, bool return_softmax, bool use_alibi, int alibi_mode, int imp_mode) -> (Tensor, Tensor, Tensor, Tensor, Tensor, Tensor)"
|
||||
)
|
||||
|
||||
_flash_lib.define(
|
||||
"flash_bwd(Tensor dout, Tensor query, Tensor key, Tensor value, "
|
||||
"Tensor out, Tensor softmax_lse_, Tensor dq, Tensor dk, Tensor dv, "
|
||||
"Tensor cu_seqlens_q, Tensor cu_seqlens_k, "
|
||||
"int max_seqlen_q, int max_seqlen_k, "
|
||||
"float p, float softmax_scale, bool is_causal, int window_size, bool use_alibi, int alibi_mode, int imp_mode) -> (Tensor, Tensor, Tensor)"
|
||||
)
|
||||
|
||||
def _flash_fwd(
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
cu_seq_lens_q,
|
||||
cu_seq_lens_k,
|
||||
max_seq_len_q,
|
||||
max_seq_len_k,
|
||||
p,
|
||||
softmax_scale,
|
||||
is_causal,
|
||||
window_size,
|
||||
return_softmax,
|
||||
use_alibi,
|
||||
alibi_mode,
|
||||
imp_mode
|
||||
):
|
||||
if enable_ixdnn:
|
||||
(
|
||||
out,
|
||||
q_padded,
|
||||
k_padded,
|
||||
v_padded,
|
||||
out_padded,
|
||||
softmax_lse,
|
||||
p,
|
||||
) = _C_flashattention.fwd_ixdnn(
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
None, # out
|
||||
cu_seq_lens_q,
|
||||
cu_seq_lens_k,
|
||||
p,
|
||||
is_causal,
|
||||
return_softmax,
|
||||
use_alibi,
|
||||
alibi_mode,
|
||||
imp_mode,
|
||||
-1, -1,
|
||||
None, # rng
|
||||
)
|
||||
else:
|
||||
if cu_seq_lens_q is None:
|
||||
assert cu_seq_lens_k is None
|
||||
(
|
||||
out,
|
||||
q_padded,
|
||||
k_padded,
|
||||
v_padded,
|
||||
out_padded,
|
||||
softmax_lse,
|
||||
p,
|
||||
) = _C_flashattention.fwd(
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
None, # out
|
||||
p,
|
||||
softmax_scale,
|
||||
is_causal,
|
||||
return_softmax,
|
||||
use_alibi,
|
||||
alibi_mode,
|
||||
None, # rng
|
||||
)
|
||||
else:
|
||||
out = query.new_empty(query.shape[0], query.shape[1], value.shape[2])
|
||||
(
|
||||
out,
|
||||
q_padded,
|
||||
k_padded,
|
||||
v_padded,
|
||||
out_padded,
|
||||
softmax_lse,
|
||||
p,
|
||||
) = _C_flashattention.varlen_fwd(
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
out,
|
||||
cu_seq_lens_q,
|
||||
cu_seq_lens_k,
|
||||
max_seq_len_q,
|
||||
max_seq_len_k,
|
||||
p,
|
||||
softmax_scale,
|
||||
False,
|
||||
is_causal,
|
||||
return_softmax,
|
||||
use_alibi,
|
||||
alibi_mode,
|
||||
None,
|
||||
)
|
||||
return out, softmax_lse, q_padded, k_padded, v_padded, out_padded
|
||||
|
||||
def _flash_bwd(
|
||||
grad,
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
out,
|
||||
lse,
|
||||
dq,
|
||||
dk,
|
||||
dv,
|
||||
cu_seq_lens_q,
|
||||
cu_seq_lens_k,
|
||||
max_seq_len_q,
|
||||
max_seq_len_k,
|
||||
p,
|
||||
softmax_scale,
|
||||
is_causal,
|
||||
window_size,
|
||||
use_alibi,
|
||||
alibi_mode,
|
||||
imp_mode
|
||||
):
|
||||
if enable_ixdnn:
|
||||
_C_flashattention.bwd_ixdnn(
|
||||
grad,
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
out,
|
||||
lse,
|
||||
dq,
|
||||
dk,
|
||||
dv,
|
||||
cu_seq_lens_q,
|
||||
cu_seq_lens_k,
|
||||
p,
|
||||
is_causal,
|
||||
use_alibi,
|
||||
alibi_mode,
|
||||
imp_mode,
|
||||
-1, -1,
|
||||
None,
|
||||
)
|
||||
else:
|
||||
if cu_seq_lens_k is None:
|
||||
assert cu_seq_lens_q is None
|
||||
_C_flashattention.bwd(
|
||||
grad,
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
out,
|
||||
lse,
|
||||
dq,
|
||||
dk,
|
||||
dv,
|
||||
p,
|
||||
softmax_scale,
|
||||
is_causal,
|
||||
use_alibi,
|
||||
alibi_mode,
|
||||
None,
|
||||
)
|
||||
else:
|
||||
_C_flashattention.varlen_bwd(
|
||||
grad,
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
out,
|
||||
lse,
|
||||
dq,
|
||||
dk,
|
||||
dv,
|
||||
cu_seq_lens_q,
|
||||
cu_seq_lens_k,
|
||||
max_seq_len_q,
|
||||
max_seq_len_k,
|
||||
p,
|
||||
softmax_scale,
|
||||
False, # zero_tensors
|
||||
is_causal,
|
||||
use_alibi,
|
||||
alibi_mode,
|
||||
None,
|
||||
)
|
||||
return dq, dk, dv
|
||||
|
||||
_flash_lib.impl("flash_fwd", _flash_fwd, "CUDA")
|
||||
_flash_lib.impl("flash_bwd", _flash_bwd, "CUDA")
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
|
||||
def _convert_input_format(
|
||||
inp: Inputs,
|
||||
) -> Tuple[Inputs, Optional[torch.Tensor], int, Optional[torch.Tensor], int]:
|
||||
assert inp.query.ndim in [4, 5]
|
||||
query, key, value = inp.query, inp.key, inp.value
|
||||
batch = query.shape[0]
|
||||
seqlen_q = query.shape[1]
|
||||
seqlen_kv = key.shape[1]
|
||||
head_dim_q = query.shape[-1]
|
||||
head_dim_v = value.shape[-1]
|
||||
|
||||
attn_bias = inp.attn_bias
|
||||
if enable_ixdnn:
|
||||
if query.shape[0] == 1:
|
||||
cu_seqlen_q = torch.tensor([seqlen_q], dtype=torch.int32, device=query.device)
|
||||
cu_seqlen_k = torch.tensor([seqlen_kv], dtype=torch.int32, device=key.device)
|
||||
else:
|
||||
cu_seqlen_q = torch.full((seqlen_q,), seqlen_q, dtype=torch.int32, device=query.device)
|
||||
cu_seqlen_k = torch.full((seqlen_kv,), seqlen_kv, dtype=torch.int32, device=key.device)
|
||||
max_seqlen_q = inp.query.shape[1]
|
||||
max_seqlen_k = inp.key.shape[1]
|
||||
else:
|
||||
if isinstance(attn_bias, BlockDiagonalMask):
|
||||
attn_bias.k_seqinfo.seqstart = attn_bias.k_seqinfo.seqstart.to(
|
||||
inp.query.device, non_blocking=True
|
||||
)
|
||||
attn_bias.q_seqinfo.seqstart = attn_bias.q_seqinfo.seqstart.to(
|
||||
inp.query.device, non_blocking=True
|
||||
)
|
||||
|
||||
cu_seqlen_k = attn_bias.k_seqinfo.seqstart
|
||||
cu_seqlen_q = attn_bias.q_seqinfo.seqstart
|
||||
max_seqlen_q = attn_bias.q_seqinfo.max_seqlen
|
||||
max_seqlen_k = attn_bias.k_seqinfo.max_seqlen
|
||||
else:
|
||||
cu_seqlen_k = None
|
||||
cu_seqlen_q = None
|
||||
max_seqlen_q = inp.query.shape[1]
|
||||
max_seqlen_k = inp.key.shape[1]
|
||||
|
||||
if query.ndim == 5: # QGA
|
||||
# Fold the group/head_in_group dimensions together
|
||||
def fold(x):
|
||||
# Either the head is replicated
|
||||
if x.stride(3) == 0:
|
||||
return x[:, :, :, 0]
|
||||
# Or we reshape
|
||||
return x.reshape(
|
||||
[
|
||||
x.shape[0],
|
||||
x.shape[1],
|
||||
-1,
|
||||
x.shape[4],
|
||||
]
|
||||
)
|
||||
|
||||
query = fold(query)
|
||||
key = fold(key)
|
||||
value = fold(value)
|
||||
# Optimize for MHA
|
||||
if key.ndim == 4 and key.stride(2) == 0 and value.stride(2) == 0:
|
||||
key = key[:, :, :1]
|
||||
value = value[:, :, :1]
|
||||
# Initially we have `query.shape = [batch, seqlen, head_dim_q]`
|
||||
# We want format `[batch * seqlen, num_heads, head_dim_q]`
|
||||
if cu_seqlen_k is not None and os.getenv('ENABLE_FLASH_ATTENTION_WITH_IXDNN', '1') == '0':
|
||||
query = query.reshape([batch * seqlen_q, -1, head_dim_q])
|
||||
key = key.reshape([batch * seqlen_kv, -1, head_dim_q])
|
||||
value = value.reshape([batch * seqlen_kv, -1, head_dim_v])
|
||||
new_inp = replace(
|
||||
inp,
|
||||
query=query,
|
||||
key=key,
|
||||
value=value,
|
||||
)
|
||||
return new_inp, cu_seqlen_q, max_seqlen_q, cu_seqlen_k, max_seqlen_k
|
||||
|
||||
|
||||
def _is_causal(attn_bias: Optional[Union[torch.Tensor, AttentionBias]]) -> bool:
|
||||
return isinstance(
|
||||
attn_bias,
|
||||
(
|
||||
LowerTriangularMask,
|
||||
BlockDiagonalCausalMask,
|
||||
BlockDiagonalCausalLocalAttentionMask,
|
||||
BlockDiagonalCausalFromBottomRightMask,
|
||||
BlockDiagonalCausalLocalAttentionFromBottomRightMask,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def _window_size(attn_bias: Optional[Union[torch.Tensor, AttentionBias]]) -> int:
|
||||
if isinstance(
|
||||
attn_bias,
|
||||
(BlockDiagonalCausalLocalAttentionMask,),
|
||||
):
|
||||
return attn_bias._window_size or 0
|
||||
if isinstance(attn_bias, BlockDiagonalCausalLocalAttentionFromBottomRightMask):
|
||||
return attn_bias._window_size
|
||||
return 0
|
||||
|
||||
|
||||
def _check_needs_no_topleft(d: Inputs, reasons: List[str]) -> None:
|
||||
# Flash does not support TopLeft, so only allow causal masks with TopLeft
|
||||
# if each batch element has equal number of queries and keys.
|
||||
if isinstance(d.attn_bias, BlockDiagonalCausalMask):
|
||||
# Flash does not support TopLeft, so only allow BlockDiagonalCausalMask
|
||||
# if each batch element has equal number of queries and keys.
|
||||
for k_start, q_start in zip_longest(
|
||||
d.attn_bias.k_seqinfo.seqstart_py, d.attn_bias.q_seqinfo.seqstart_py
|
||||
):
|
||||
if k_start != q_start:
|
||||
reasons.append(
|
||||
"Only support BlockDiagonalCausalMask if equal"
|
||||
" numbers of keys and queries"
|
||||
)
|
||||
break
|
||||
elif isinstance(d.attn_bias, LowerTriangularMask):
|
||||
if d.query.shape[1] != d.key.shape[1]:
|
||||
reasons.append(
|
||||
"Only support LowerTriangularMask if equal number of" "keys and queries"
|
||||
)
|
||||
|
||||
|
||||
def _check_strides_for_bmghk(x: torch.Tensor, name: str, reasons: List[str]) -> None:
|
||||
"""
|
||||
We want to be able to collapse the G/H dimensions together
|
||||
"""
|
||||
if x.ndim == 5:
|
||||
stride_g, stride_h = x.stride(2), x.stride(3)
|
||||
if x.shape[2] == 1:
|
||||
return
|
||||
if x.shape[3] == 1 or stride_h == 0:
|
||||
return
|
||||
if stride_g != stride_h * x.shape[-2]:
|
||||
reasons.append(
|
||||
f"GQA is only supported when the G/H dimensions are contiguous\n"
|
||||
f" {name}.stride: {x.stride()}\n"
|
||||
f" {name}.shape : {list(x.shape)}"
|
||||
)
|
||||
|
||||
|
||||
@register_operator
|
||||
class FwOp(AttentionFwOpBase):
|
||||
"""Operator that computes memory-efficient attention using \
|
||||
`Flash-Attention <https://github.com/HazyResearch/flash-attention>`_ \
|
||||
implementation.
|
||||
"""
|
||||
|
||||
OPERATOR = get_operator("xformers_flash", "flash_fwd")
|
||||
SUPPORTED_DEVICES: Set[str] = {"cuda"}
|
||||
SUPPORTED_DTYPES: Set[torch.dtype] = {torch.half, torch.bfloat16}
|
||||
SUPPORTED_MAX_K = 256
|
||||
SUPPORTED_ATTN_BIAS_TYPES: Set[Any] = {
|
||||
type(None),
|
||||
LowerTriangularMask,
|
||||
BlockDiagonalMask,
|
||||
BlockDiagonalCausalMask,
|
||||
BlockDiagonalCausalLocalAttentionMask,
|
||||
BlockDiagonalCausalLocalAttentionFromBottomRightMask,
|
||||
BlockDiagonalCausalFromBottomRightMask,
|
||||
}
|
||||
SUPPORTS_DROPOUT = True
|
||||
SUPPORTS_CUSTOM_SCALE = True
|
||||
SUPPORTS_DIFFERENT_VALUE_EMBED = False
|
||||
SUPPORTS_BMGHK = True
|
||||
NAME = f"flshattF@{FLASH_VERSION}"
|
||||
VERSION = FLASH_VERSION
|
||||
|
||||
@classmethod
|
||||
def not_supported_reasons(cls, d: Inputs) -> List[str]:
|
||||
reasons = super(FwOp, cls).not_supported_reasons(d)
|
||||
check_lastdim_alignment_stride1(reasons, "query", d.query, 8)
|
||||
_check_needs_no_topleft(d, reasons)
|
||||
_check_strides_for_bmghk(d.query, "query", reasons)
|
||||
_check_strides_for_bmghk(d.key, "key", reasons)
|
||||
_check_strides_for_bmghk(d.value, "value", reasons)
|
||||
return reasons
|
||||
|
||||
@classmethod
|
||||
def apply(
|
||||
cls, inp: Inputs, needs_gradient: bool
|
||||
) -> Tuple[torch.Tensor, Optional[Context]]:
|
||||
return_softmax = False
|
||||
out_shape = [
|
||||
*inp.query.shape[:-1],
|
||||
inp.value.shape[-1],
|
||||
]
|
||||
# no cumulative seqlen
|
||||
(
|
||||
inp,
|
||||
cu_seqlens_q,
|
||||
max_seqlen_q,
|
||||
cu_seqlens_k,
|
||||
max_seqlen_k,
|
||||
) = _convert_input_format(inp)
|
||||
out, softmax_lse, q_padded, k_padded, v_padded, o_padded = cls.OPERATOR(
|
||||
inp.query,
|
||||
inp.key,
|
||||
inp.value,
|
||||
cu_seqlens_q,
|
||||
cu_seqlens_k,
|
||||
max_seqlen_q,
|
||||
max_seqlen_k,
|
||||
inp.p,
|
||||
inp.scale_float,
|
||||
_is_causal(inp.attn_bias),
|
||||
_window_size(inp.attn_bias),
|
||||
return_softmax,
|
||||
inp.use_alibi,
|
||||
inp.alibi_mode,
|
||||
inp.imp_mode,
|
||||
)
|
||||
out = out.reshape(out_shape)
|
||||
ctx = Context(out=out, lse=softmax_lse, q_padded=q_padded, k_padded=k_padded, v_padded=v_padded, o_padded=o_padded)
|
||||
if inp.p != 0.0:
|
||||
ctx.op_bw = BwOp
|
||||
return (out, ctx)
|
||||
|
||||
@classmethod
|
||||
# type: ignore
|
||||
def operator_flop(
|
||||
cls,
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
cu_seq_lens_q,
|
||||
cu_seq_lens_k,
|
||||
max_seq_len_q,
|
||||
max_seq_len_k,
|
||||
p,
|
||||
softmax_scale,
|
||||
causal,
|
||||
return_softmax,
|
||||
) -> int:
|
||||
return cls.attn_operator_flop(
|
||||
query.unsqueeze(0),
|
||||
key.unsqueeze(0),
|
||||
value.unsqueeze(0),
|
||||
causal=causal,
|
||||
seqstart_k=cu_seq_lens_k,
|
||||
seqstart_q=cu_seq_lens_q,
|
||||
)
|
||||
|
||||
|
||||
@register_operator
|
||||
class BwOp(AttentionBwOpBase):
|
||||
__doc__ = FwOp.__doc__
|
||||
|
||||
OPERATOR = get_operator("xformers_flash", "flash_bwd")
|
||||
SUPPORTED_DEVICES = FwOp.SUPPORTED_DEVICES
|
||||
SUPPORTED_DTYPES = FwOp.SUPPORTED_DTYPES
|
||||
SUPPORTED_MAX_K = FwOp.SUPPORTED_MAX_K
|
||||
SUPPORTED_ATTN_BIAS_TYPES = FwOp.SUPPORTED_ATTN_BIAS_TYPES
|
||||
SUPPORTS_DROPOUT = FwOp.SUPPORTS_DROPOUT
|
||||
SUPPORTS_CUSTOM_SCALE = FwOp.SUPPORTS_CUSTOM_SCALE
|
||||
SUPPORTS_DIFFERENT_VALUE_EMBED = FwOp.SUPPORTS_DIFFERENT_VALUE_EMBED
|
||||
IS_DETERMINISTIC = False
|
||||
SUPPORTS_BMGHK = False # NOTE: Don't forget to update fmha doc when changing this!
|
||||
NAME = f"flshattB@{FLASH_VERSION}"
|
||||
VERSION = FLASH_VERSION
|
||||
|
||||
MAX_HEADDIM_SM8x = 192
|
||||
|
||||
@classmethod
|
||||
def shape_not_supported_reasons(
|
||||
cls, Mq: int, Mkv: int, K: int, Kv: int
|
||||
) -> List[str]:
|
||||
reasons = super().shape_not_supported_reasons(Mq, Mkv, K, Kv)
|
||||
|
||||
# In fbcode in mode/dev-nosan, we get nans from flash v2.1 if there
|
||||
# is a strange embedding dimension.
|
||||
# if K not in {8, 16, 32, 64, 128, 256}:
|
||||
# reasons.append(f"Embed dim {K} not supported")
|
||||
|
||||
return reasons
|
||||
|
||||
@classmethod
|
||||
def not_supported_reasons(cls, d: Inputs) -> List[str]:
|
||||
reasons = super(BwOp, cls).not_supported_reasons(d)
|
||||
check_lastdim_alignment_stride1(reasons, "query", d.query, 8)
|
||||
_check_needs_no_topleft(d, reasons)
|
||||
if d.device.type == "cuda":
|
||||
# Due to limited shared-memory, some GPUs are limited in head dimension
|
||||
device_capability = torch.cuda.get_device_capability(d.device)
|
||||
is_sm80_or_sm90 = device_capability in [(8, 0), (9, 0)]
|
||||
if (
|
||||
max(d.key.shape[-1], d.query.shape[-1]) > cls.MAX_HEADDIM_SM8x
|
||||
and not is_sm80_or_sm90
|
||||
):
|
||||
reasons.append(
|
||||
"requires a GPU with compute capability 8.0 "
|
||||
f"(A100) or 9.0 (H100) for 'query.shape[-1] > {cls.MAX_HEADDIM_SM8x}'"
|
||||
)
|
||||
return reasons
|
||||
|
||||
@classmethod
|
||||
def apply(cls, ctx: Context, inp: Inputs, grad: torch.Tensor) -> Gradients:
|
||||
dq_shape, dk_shape, dv_shape = inp.query.shape, inp.key.shape, inp.value.shape
|
||||
(
|
||||
inp,
|
||||
cu_seqlens_q,
|
||||
max_seqlen_q,
|
||||
cu_seqlens_k,
|
||||
max_seqlen_k,
|
||||
) = _convert_input_format(inp)
|
||||
assert ctx.lse.is_contiguous
|
||||
ctx_lse = ctx.lse
|
||||
if os.getenv('ENABLE_FLASH_ATTENTION_WITH_IXDNN', '1') == '0':
|
||||
assert ctx_lse.shape[2] >= max_seqlen_q
|
||||
if max_seqlen_q != ctx_lse.shape[2]:
|
||||
ctx_lse = ctx_lse[:, :, :max_seqlen_q].contiguous()
|
||||
kernel_out_shape = [
|
||||
*inp.query.shape[:-1],
|
||||
inp.value.shape[-1],
|
||||
]
|
||||
|
||||
# Create dq,dk,dv
|
||||
# If Q/K/V come from a single QKV tensor, let's put the gradient in the
|
||||
# right strides, so we can avoid a `cat`
|
||||
if (
|
||||
ctx.q_padded.shape[0] == ctx.k_padded.shape[0]
|
||||
and ctx.q_padded.shape[-1] == ctx.v_padded.shape[-1]
|
||||
and _get_storage_base(ctx.q_padded) == _get_storage_base(ctx.k_padded)
|
||||
and _get_storage_base(ctx.q_padded) == _get_storage_base(ctx.k_padded)
|
||||
):
|
||||
# Create one big contiguous chunk
|
||||
# This is because q, k and v usually come from a single
|
||||
# output of a linear layer that is chunked.
|
||||
# Creating the gradients with the right layout saves us
|
||||
# a `torch.cat` call in the backward pass
|
||||
chunk = torch.empty(
|
||||
(*ctx.q_padded.shape[0:-2], 3, ctx.q_padded.shape[-2], ctx.q_padded.shape[-1]),
|
||||
dtype=ctx.q_padded.dtype,
|
||||
device=inp.device,
|
||||
)
|
||||
grads = Gradients(
|
||||
dq=chunk.select(-3, 0),
|
||||
dk=chunk.select(-3, 1),
|
||||
dv=chunk.select(-3, 2),
|
||||
)
|
||||
else:
|
||||
grads = Gradients(
|
||||
dq=torch.empty_like(ctx.q_padded),
|
||||
dk=torch.empty_like(ctx.k_padded),
|
||||
dv=torch.empty_like(ctx.v_padded),
|
||||
)
|
||||
|
||||
assert grad.dtype in cls.SUPPORTED_DTYPES
|
||||
cls.OPERATOR(
|
||||
grad.reshape(kernel_out_shape).contiguous(),
|
||||
ctx.q_padded,
|
||||
ctx.k_padded,
|
||||
ctx.v_padded,
|
||||
# ctx.out.reshape(kernel_out_shape),
|
||||
ctx.o_padded,
|
||||
ctx_lse,
|
||||
grads.dq,
|
||||
grads.dk,
|
||||
grads.dv,
|
||||
cu_seqlens_q,
|
||||
cu_seqlens_k,
|
||||
max_seqlen_q,
|
||||
max_seqlen_k,
|
||||
inp.p,
|
||||
inp.scale_float,
|
||||
_is_causal(inp.attn_bias),
|
||||
_window_size(inp.attn_bias),
|
||||
inp.use_alibi,
|
||||
inp.alibi_mode,
|
||||
inp.imp_mode,
|
||||
)
|
||||
grads.dq = grads.dq[..., :dq_shape[-1]].reshape(dq_shape) # We could have padded the head dimension
|
||||
grads.dk = grads.dk[..., :dk_shape[-1]].reshape(dk_shape)
|
||||
grads.dv = grads.dv[..., :dv_shape[-1]].reshape(dv_shape)
|
||||
|
||||
return grads
|
||||
|
||||
@classmethod
|
||||
# type: ignore
|
||||
def operator_flop(
|
||||
cls,
|
||||
grad,
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
out,
|
||||
lse,
|
||||
dq,
|
||||
dk,
|
||||
dv,
|
||||
cu_seq_lens_q,
|
||||
cu_seq_lens_k,
|
||||
max_seq_len_q,
|
||||
max_seq_len_k,
|
||||
p,
|
||||
softmax_scale,
|
||||
causal,
|
||||
) -> int:
|
||||
return cls.attn_operator_flop(
|
||||
query.unsqueeze(0),
|
||||
key.unsqueeze(0),
|
||||
value.unsqueeze(0),
|
||||
causal=causal,
|
||||
seqstart_k=cu_seq_lens_k,
|
||||
seqstart_q=cu_seq_lens_q,
|
||||
)
|
||||
186
pkgs/xformers/ops/fmha/small_k.py
Normal file
186
pkgs/xformers/ops/fmha/small_k.py
Normal file
@@ -0,0 +1,186 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the BSD license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
from typing import Any, List, Mapping, Optional, Set, Tuple, Union
|
||||
|
||||
import torch
|
||||
|
||||
from ..common import get_xformers_operator, register_operator
|
||||
from .attn_bias import AttentionBias
|
||||
from .common import (
|
||||
AttentionBwOpBase,
|
||||
AttentionFwOpBase,
|
||||
Context,
|
||||
Gradients,
|
||||
Inputs,
|
||||
bmk2bmhk,
|
||||
)
|
||||
|
||||
|
||||
def _bmhk2bmk_contiguous(tensor) -> torch.Tensor:
|
||||
return (
|
||||
tensor.permute((0, 2, 1, 3))
|
||||
.contiguous()
|
||||
.view([tensor.shape[0] * tensor.shape[2], tensor.shape[1], tensor.shape[3]])
|
||||
.contiguous()
|
||||
)
|
||||
|
||||
|
||||
def _get_tensor_bias_bmk(
|
||||
attn_bias: Optional[Union[torch.Tensor, AttentionBias]]
|
||||
) -> Optional[torch.Tensor]:
|
||||
if not isinstance(attn_bias, torch.Tensor):
|
||||
assert attn_bias is None
|
||||
return None
|
||||
# BMK -> BMHK
|
||||
if attn_bias.ndim == 4:
|
||||
attn_bias = attn_bias.reshape([-1, *attn_bias.shape[2:]])
|
||||
return attn_bias
|
||||
|
||||
|
||||
@register_operator
|
||||
class FwOp(AttentionFwOpBase):
|
||||
"""An operator optimized for very small values of K (``K <= 32``) \
|
||||
and f32 pre-Ampere as it does not use TensorCores.
|
||||
Only supports contiguous inputs in BMK format, so an extra reshape \
|
||||
or contiguous call might be done.
|
||||
|
||||
:Deprecated:
|
||||
|
||||
This operator is deprecated and should not be used in new code
|
||||
"""
|
||||
|
||||
OPERATOR = get_xformers_operator("efficient_attention_forward_small_k")
|
||||
SUPPORTED_DEVICES = {"cuda"}
|
||||
SUPPORTED_DTYPES = {torch.float}
|
||||
SUPPORTED_MAX_K: float = 32
|
||||
SUPPORTED_ATTN_BIAS_TYPES: Set[Any] = {type(None), torch.Tensor}
|
||||
SUPPORTS_DROPOUT = True
|
||||
SUPPORTS_CUSTOM_SCALE = False
|
||||
NAME = "smallkF"
|
||||
|
||||
BACKWARD_ERROR_ATOL: Mapping[torch.dtype, float] = {
|
||||
torch.float: 4e-3,
|
||||
}
|
||||
# as this kernel is a bit slow, this should make tests run faster
|
||||
_TEST_BATCH_SIZES = [1, 3]
|
||||
_TEST_K = [2, 3, 8, 16, 32]
|
||||
|
||||
@classmethod
|
||||
def not_supported_reasons(cls, d: Inputs) -> List[str]:
|
||||
reasons = super(FwOp, cls).not_supported_reasons(d)
|
||||
if isinstance(d.attn_bias, torch.Tensor) and d.attn_bias.stride(1) != 0:
|
||||
reasons.append("bias with non-zero stride not supported")
|
||||
buffer_size = 8
|
||||
k = d.query.shape[-1]
|
||||
for pack in [1, 2, 4]:
|
||||
if (k % pack) == 0 and (k // pack) <= buffer_size:
|
||||
return reasons
|
||||
reasons.append(f"unsupported embed per head: {k}")
|
||||
return reasons
|
||||
|
||||
@classmethod
|
||||
def apply(
|
||||
cls, inp: Inputs, needs_gradient: bool
|
||||
) -> Tuple[torch.Tensor, Optional[Context]]:
|
||||
if inp.scale is not None:
|
||||
raise NotImplementedError("Unsupport custom scale")
|
||||
num_heads = inp.query.shape[2]
|
||||
query = _bmhk2bmk_contiguous(inp.query)
|
||||
key = _bmhk2bmk_contiguous(inp.key)
|
||||
value = _bmhk2bmk_contiguous(inp.value)
|
||||
|
||||
out, lse, rng_seed, rng_offset = cls.OPERATOR(
|
||||
query=query,
|
||||
key=key,
|
||||
value=value,
|
||||
compute_logsumexp=needs_gradient,
|
||||
attn_bias=_get_tensor_bias_bmk(inp.attn_bias),
|
||||
p=inp.p,
|
||||
)
|
||||
out = bmk2bmhk(out, num_heads)
|
||||
lse = lse.reshape([lse.shape[0] // num_heads, num_heads, lse.shape[1]])
|
||||
if not needs_gradient:
|
||||
return out, None
|
||||
ctx = Context(out=out, lse=lse)
|
||||
if inp.p != 0.0:
|
||||
ctx.op_bw = BwOp
|
||||
ctx.rng_state = torch.tensor(
|
||||
[rng_seed, rng_offset], dtype=torch.int64, device="cpu"
|
||||
)
|
||||
return out, ctx
|
||||
|
||||
|
||||
@register_operator
|
||||
class BwOp(AttentionBwOpBase):
|
||||
__doc__ = FwOp.__doc__
|
||||
|
||||
OPERATOR = get_xformers_operator("efficient_attention_backward_small_k")
|
||||
SUPPORTED_DEVICES = FwOp.SUPPORTED_DEVICES
|
||||
SUPPORTED_DTYPES = FwOp.SUPPORTED_DTYPES
|
||||
SUPPORTED_MAX_K = FwOp.SUPPORTED_MAX_K
|
||||
SUPPORTED_ATTN_BIAS_TYPES = FwOp.SUPPORTED_ATTN_BIAS_TYPES
|
||||
SUPPORTS_DROPOUT = FwOp.SUPPORTS_DROPOUT
|
||||
SUPPORTS_CUSTOM_SCALE = FwOp.SUPPORTS_CUSTOM_SCALE
|
||||
SUPPORTS_DIFFERENT_VALUE_EMBED = FwOp.SUPPORTS_DIFFERENT_VALUE_EMBED
|
||||
|
||||
# there is some extra precision loss in the CPU implementation due to an
|
||||
# extra accumulation step in grad_q, which is not present in the CUDA
|
||||
# implementation
|
||||
ERROR_ATOL: Mapping[torch.dtype, float] = {
|
||||
torch.float: 4e-3,
|
||||
}
|
||||
NAME = "smallkB"
|
||||
|
||||
@classmethod
|
||||
def not_supported_reasons(cls, d: Inputs) -> List[str]:
|
||||
reasons = super(BwOp, cls).not_supported_reasons(d)
|
||||
if isinstance(d.attn_bias, torch.Tensor) and d.attn_bias.stride(1) != 0:
|
||||
reasons.append("bias with non-zero stride not supported")
|
||||
buffer_size = 8
|
||||
k = d.query.shape[-1]
|
||||
for pack in [1, 2, 4]:
|
||||
if (k % pack) == 0 and (k // pack) <= buffer_size:
|
||||
return reasons
|
||||
reasons.append(f"unsupported embed per head: {k}")
|
||||
return reasons
|
||||
|
||||
@classmethod
|
||||
def apply(cls, ctx: Context, inp: Inputs, grad: torch.Tensor) -> Gradients:
|
||||
num_heads = grad.shape[2]
|
||||
grad = _bmhk2bmk_contiguous(grad)
|
||||
query = _bmhk2bmk_contiguous(inp.query)
|
||||
key = _bmhk2bmk_contiguous(inp.key)
|
||||
value = _bmhk2bmk_contiguous(inp.value)
|
||||
out = _bmhk2bmk_contiguous(ctx.out)
|
||||
|
||||
rng_seed = rng_offset = 0
|
||||
if inp.p != 0.0:
|
||||
if (
|
||||
ctx.rng_state is None
|
||||
or ctx.rng_state.dtype != torch.int64
|
||||
or ctx.rng_state.device.type != "cpu"
|
||||
or ctx.rng_state.shape != (2,)
|
||||
):
|
||||
raise NotImplementedError(f"Invalid rng_state: {ctx.rng_state}")
|
||||
rng_seed, rng_offset = ctx.rng_state.tolist()
|
||||
grad_q, grad_k, grad_v = cls.OPERATOR(
|
||||
grad,
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
# LSE: BHM -> (BH)M
|
||||
ctx.lse.reshape([-1, ctx.lse.shape[-1]]),
|
||||
out,
|
||||
_get_tensor_bias_bmk(inp.attn_bias),
|
||||
inp.p,
|
||||
rng_seed,
|
||||
rng_offset,
|
||||
)
|
||||
return Gradients(
|
||||
dq=bmk2bmhk(grad_q, num_heads),
|
||||
dk=bmk2bmhk(grad_k, num_heads),
|
||||
dv=bmk2bmhk(grad_v, num_heads),
|
||||
)
|
||||
201
pkgs/xformers/ops/fmha/triton.py
Normal file
201
pkgs/xformers/ops/fmha/triton.py
Normal file
@@ -0,0 +1,201 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the BSD license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
|
||||
from dataclasses import replace
|
||||
from typing import TYPE_CHECKING, Any, List, Optional, Set, Tuple
|
||||
|
||||
import torch
|
||||
|
||||
from ... import _is_triton_available
|
||||
from ..common import register_operator
|
||||
|
||||
# This implementation needs pre-MLIR triton
|
||||
# The BW pass is not stable/well tested
|
||||
# And also does not have the latest improvements
|
||||
if TYPE_CHECKING or (False and _is_triton_available()):
|
||||
try:
|
||||
from flash_attn.flash_attn_triton import (
|
||||
_flash_attn_backward,
|
||||
_flash_attn_forward,
|
||||
)
|
||||
except ImportError:
|
||||
import importlib
|
||||
import pathlib
|
||||
import sys
|
||||
import types
|
||||
|
||||
def import_module_from_path(path: str) -> types.ModuleType:
|
||||
"""Import a module from the given path, w/o __init__.py"""
|
||||
module_path = pathlib.Path(path).resolve()
|
||||
module_name = module_path.stem # 'path/x.py' -> 'x'
|
||||
spec = importlib.util.spec_from_file_location(module_name, module_path) # type: ignore
|
||||
assert isinstance(spec, importlib.machinery.ModuleSpec)
|
||||
module = importlib.util.module_from_spec(spec) # type: ignore
|
||||
sys.modules[module_name] = module
|
||||
assert isinstance(spec.loader, importlib.abc.Loader)
|
||||
spec.loader.exec_module(module)
|
||||
return module
|
||||
|
||||
flash_attn = import_module_from_path(
|
||||
"third_party/flash-attention/flash_attn/flash_attn_triton.py"
|
||||
)
|
||||
_flash_attn_backward = flash_attn._flash_attn_backward
|
||||
_flash_attn_forward = flash_attn._flash_attn_forward
|
||||
|
||||
triton_flash_backward = _flash_attn_backward
|
||||
triton_flash_forward = _flash_attn_forward
|
||||
else:
|
||||
triton_flash_backward = None
|
||||
triton_flash_forward = None
|
||||
|
||||
from .attn_bias import LowerTriangularMask
|
||||
from .common import (
|
||||
AttentionBwOpBase,
|
||||
AttentionFwOpBase,
|
||||
Context,
|
||||
Gradients,
|
||||
Inputs,
|
||||
check_lastdim_alignment_stride1,
|
||||
)
|
||||
|
||||
|
||||
def _prepare_inputs(inp: Inputs) -> Inputs:
|
||||
attn_bias = inp.attn_bias
|
||||
if isinstance(attn_bias, torch.Tensor) and attn_bias.ndim == 3:
|
||||
B = inp.query.shape[0]
|
||||
h = attn_bias.shape[0] // B
|
||||
attn_bias = attn_bias.reshape(B, h, attn_bias.shape[1], attn_bias.shape[2])
|
||||
|
||||
# Make sure that the last dimension is contiguous
|
||||
query, key, value = [
|
||||
x if x.stride(-1) == 1 else x.contiguous()
|
||||
for x in [inp.query, inp.key, inp.value]
|
||||
]
|
||||
return replace(inp, attn_bias=attn_bias, query=query, key=key, value=value)
|
||||
|
||||
|
||||
@register_operator
|
||||
class FwOp(AttentionFwOpBase):
|
||||
"""Operator that computes memory-efficient attention using \
|
||||
`Tri Dao's <https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/flash_attn_triton.py>`_ \
|
||||
implementation, based on
|
||||
`Phil Tillet's code <https://github.com/openai/triton/blob/master/python/tutorials/06-fused-attention.py>`_
|
||||
"""
|
||||
|
||||
OPERATOR = triton_flash_forward
|
||||
SUPPORTED_DEVICES = {"cuda"}
|
||||
CUDA_MINIMUM_COMPUTE_CAPABILITY = (8, 0)
|
||||
SUPPORTED_DTYPES = {torch.half, torch.bfloat16}
|
||||
SUPPORTED_MAX_K = 128
|
||||
SUPPORTED_ATTN_BIAS_TYPES: Set[Any] = {
|
||||
type(None),
|
||||
LowerTriangularMask,
|
||||
# TODO: backwards accuracy is failing for a few cases, perhaps we want to disable this for now.
|
||||
# torch.Tensor,
|
||||
}
|
||||
SUPPORTS_DROPOUT = False
|
||||
SUPPORTS_CUSTOM_SCALE = True
|
||||
NAME = "tritonflashattF"
|
||||
|
||||
@classmethod
|
||||
def not_supported_reasons(cls, d: Inputs) -> List[str]:
|
||||
reasons = super(FwOp, cls).not_supported_reasons(d)
|
||||
check_lastdim_alignment_stride1(reasons, "query", d.query, 8)
|
||||
check_lastdim_alignment_stride1(reasons, "key", d.key, 8)
|
||||
check_lastdim_alignment_stride1(reasons, "value", d.value, 8)
|
||||
if cls.OPERATOR is None:
|
||||
reasons.append("triton is not available")
|
||||
if d.device.type == "cuda":
|
||||
# Has only been tested on 8.0 / 9.0.
|
||||
# Fails on 7.5 with illegal memory access
|
||||
if torch.cuda.get_device_capability(d.device) < (8, 0):
|
||||
reasons.append(
|
||||
"requires GPU with sm80 minimum compute capacity, e.g., A100/H100/L4"
|
||||
)
|
||||
if _is_triton_available():
|
||||
import triton
|
||||
|
||||
if triton.__version__ > "2.0.0":
|
||||
reasons.append("Only work on pre-MLIR triton for now")
|
||||
return reasons
|
||||
|
||||
@classmethod
|
||||
def apply(
|
||||
cls, inp: Inputs, needs_gradient: bool
|
||||
) -> Tuple[torch.Tensor, Optional[Context]]:
|
||||
inp = _prepare_inputs(inp)
|
||||
|
||||
out, lse, softmax_scale = triton_flash_forward(
|
||||
q=inp.query,
|
||||
k=inp.key,
|
||||
v=inp.value,
|
||||
bias=inp.attn_bias if isinstance(inp.attn_bias, torch.Tensor) else None,
|
||||
softmax_scale=inp.scale_float,
|
||||
causal=isinstance(inp.attn_bias, LowerTriangularMask),
|
||||
)
|
||||
return out, Context(lse=lse, out=out)
|
||||
|
||||
|
||||
@register_operator
|
||||
class BwOp(AttentionBwOpBase):
|
||||
__doc__ = FwOp.__doc__
|
||||
|
||||
OPERATOR = triton_flash_backward
|
||||
SUPPORTED_DEVICES = FwOp.SUPPORTED_DEVICES
|
||||
CUDA_MINIMUM_COMPUTE_CAPABILITY = FwOp.CUDA_MINIMUM_COMPUTE_CAPABILITY
|
||||
SUPPORTED_DTYPES = FwOp.SUPPORTED_DTYPES
|
||||
SUPPORTED_MAX_K = FwOp.SUPPORTED_MAX_K
|
||||
SUPPORTED_ATTN_BIAS_TYPES = FwOp.SUPPORTED_ATTN_BIAS_TYPES
|
||||
SUPPORTS_DROPOUT = FwOp.SUPPORTS_DROPOUT
|
||||
SUPPORTS_CUSTOM_SCALE = FwOp.SUPPORTS_CUSTOM_SCALE
|
||||
SUPPORTS_DIFFERENT_VALUE_EMBED = FwOp.SUPPORTS_DIFFERENT_VALUE_EMBED
|
||||
NAME = "tritonflashattB"
|
||||
|
||||
@classmethod
|
||||
def not_supported_reasons(cls, d: Inputs) -> List[str]:
|
||||
reasons = super(BwOp, cls).not_supported_reasons(d)
|
||||
check_lastdim_alignment_stride1(reasons, "query", d.query, 8)
|
||||
check_lastdim_alignment_stride1(reasons, "key", d.key, 8)
|
||||
check_lastdim_alignment_stride1(reasons, "value", d.value, 8)
|
||||
if cls.OPERATOR is None:
|
||||
reasons.append("triton is not available")
|
||||
if d.device.type == "cuda":
|
||||
if torch.cuda.get_device_capability(d.device) != (8, 0):
|
||||
reasons.append("requires A100 GPU")
|
||||
if _is_triton_available():
|
||||
import triton
|
||||
|
||||
if triton.__version__ > "2.0.0":
|
||||
reasons.append("Only work on pre-MLIR triton for now")
|
||||
return reasons
|
||||
|
||||
@classmethod
|
||||
def apply(cls, ctx: Context, inp: Inputs, grad: torch.Tensor) -> Gradients:
|
||||
inp = _prepare_inputs(inp)
|
||||
|
||||
# Triton's autotune causes the Tensor._version to change, and so Pytorch autograd
|
||||
# does a memcpy. To avoid this we run in inference_mode, which doesn't track the version.
|
||||
with torch.inference_mode():
|
||||
grads = Gradients(
|
||||
dq=torch.empty_like(inp.query),
|
||||
dk=torch.empty_like(inp.key),
|
||||
dv=torch.empty_like(inp.value),
|
||||
)
|
||||
cls.OPERATOR(
|
||||
grad,
|
||||
inp.query,
|
||||
inp.key,
|
||||
inp.value,
|
||||
ctx.out,
|
||||
ctx.get_padded_lse(128),
|
||||
grads.dq,
|
||||
grads.dk,
|
||||
grads.dv,
|
||||
bias=inp.attn_bias if isinstance(inp.attn_bias, torch.Tensor) else None,
|
||||
softmax_scale=inp.scale_float,
|
||||
causal=isinstance(inp.attn_bias, LowerTriangularMask),
|
||||
)
|
||||
return grads
|
||||
738
pkgs/xformers/ops/fmha/triton_splitk.py
Normal file
738
pkgs/xformers/ops/fmha/triton_splitk.py
Normal file
@@ -0,0 +1,738 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the BSD license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
|
||||
from typing import TYPE_CHECKING, Any, List, Optional, Set, Tuple
|
||||
|
||||
import torch
|
||||
|
||||
from ..common import _has_triton21, register_operator
|
||||
from .attn_bias import BlockDiagonalCausalWithOffsetPaddedKeysMask
|
||||
from .common import AttentionFwOpBase, Context, Inputs, check_lastdim_alignment_stride1
|
||||
|
||||
|
||||
def _strides(x: torch.Tensor, *stride_names: str):
|
||||
assert x.ndim == len(stride_names)
|
||||
return {f"stride_{s}": x.stride(i) for i, s in enumerate(stride_names)}
|
||||
|
||||
|
||||
if TYPE_CHECKING or _has_triton21():
|
||||
import triton
|
||||
import triton.language as tl
|
||||
|
||||
from xformers.triton.vararg_kernel import VAR_ARGS_ARRAY, unroll_varargs
|
||||
|
||||
@triton.jit
|
||||
def _fwd_kernel_splitK(
|
||||
Q,
|
||||
K,
|
||||
V,
|
||||
sm_scale,
|
||||
Out_splitK, # [B, H, split_k, Mq, K]
|
||||
Metadata, # [B, H, 2, split_k, M_ceil] contains [mi, li]
|
||||
Seq_len,
|
||||
stride_qz,
|
||||
stride_qm,
|
||||
stride_qg,
|
||||
stride_qh,
|
||||
stride_qk,
|
||||
stride_kz,
|
||||
stride_kn,
|
||||
stride_kg,
|
||||
stride_kh,
|
||||
stride_kk,
|
||||
stride_vz,
|
||||
stride_vn,
|
||||
stride_vg,
|
||||
stride_vh,
|
||||
stride_vk,
|
||||
stride_osk_zhg,
|
||||
stride_osk_s,
|
||||
stride_osk_m,
|
||||
stride_osk_k,
|
||||
stride_mzhg,
|
||||
stride_m2,
|
||||
stride_ms,
|
||||
stride_mm,
|
||||
Z,
|
||||
N_CTX_Q,
|
||||
N_CTX_K,
|
||||
BLOCK_N_PER_SPLIT,
|
||||
H: tl.constexpr,
|
||||
G: tl.constexpr,
|
||||
BLOCK_M: tl.constexpr,
|
||||
BLOCK_DMODEL: tl.constexpr,
|
||||
BLOCK_N: tl.constexpr,
|
||||
BOUNDS_CHECKS_N: tl.constexpr,
|
||||
USE_SEQ_LEN: tl.constexpr,
|
||||
PACKED_PER_VAL: tl.constexpr = 1,
|
||||
N_GROUPS: tl.constexpr = 1,
|
||||
):
|
||||
"""This kernel can accept non-quantized or int4-quantized keys/values.
|
||||
PACKED_PER_VAL determines the quantization type:
|
||||
- PACKED_PER_VAL == 1 means no quantization
|
||||
- PACKED_PER_VAL == 8 means 4-bit quantization (8 packed quantized values inside one int32)
|
||||
For the quantized case K/V should be int32 tensors.
|
||||
Quantization can be row-wise (when N_GROUPS = 1) or group-wise with N_GROUPS = 2, 4, or 8.
|
||||
Quantization coefficients are stored at the beginning of the row along the last dimension of K/V
|
||||
So K[B, H, M, :] has a form
|
||||
[ quant_coef0, quant_coef1, ...|
|
||||
group0_quant_value0, group0_quant_value1,... |
|
||||
group1_quant_value0, group1_quant_value1,...]
|
||||
where each quant_coef is an int32 which should be interpreted as 2 packed float16: scale and offset.
|
||||
|
||||
Note: this kernel needs to be processed by xformers.triton.vararg_kernel.unroll_varargs
|
||||
before compilation. That will unroll variables marked with "VAR_ARGS_ARRAY" into lists.
|
||||
See how FwOp.apply does it below.
|
||||
"""
|
||||
tl.static_assert(
|
||||
(PACKED_PER_VAL == 1 and tl.constexpr(K.dtype.element_ty != tl.int32))
|
||||
or (PACKED_PER_VAL == 8 and tl.constexpr(K.dtype.element_ty == tl.int32)),
|
||||
f"Only 4-bit quantization is supported, K/V should have dtype int32 in "
|
||||
f"the quantized case: {PACKED_PER_VAL=} {tl.constexpr(K.dtype)=} {tl.constexpr(K.dtype.element_ty)=}",
|
||||
)
|
||||
tl.static_assert(
|
||||
(((N_GROUPS == 1 or N_GROUPS == 2) or N_GROUPS == 4) or N_GROUPS == 8),
|
||||
"Number of quantization groups can be 1 (row-wise quantization), 2, 4, or 8.",
|
||||
)
|
||||
|
||||
QUANTIZED: tl.constexpr = PACKED_PER_VAL > 1
|
||||
PACKED_D_PER_GROUP: tl.constexpr = BLOCK_DMODEL // PACKED_PER_VAL // N_GROUPS
|
||||
D_PER_GROUP: tl.constexpr = BLOCK_DMODEL // N_GROUPS
|
||||
|
||||
start_m = tl.program_id(0)
|
||||
off_zhg = tl.program_id(1)
|
||||
off_z = off_zhg // (H * G)
|
||||
off_h = (off_zhg // G) % H
|
||||
off_g = off_zhg % G
|
||||
splitk_idx = tl.program_id(2)
|
||||
|
||||
lo = splitk_idx * BLOCK_N_PER_SPLIT
|
||||
if USE_SEQ_LEN:
|
||||
kv_len = tl.load(Seq_len + off_z)
|
||||
else:
|
||||
kv_len = N_CTX_K
|
||||
hi = tl.minimum((splitk_idx + 1) * BLOCK_N_PER_SPLIT, kv_len)
|
||||
|
||||
Q_block_ptr = tl.make_block_ptr(
|
||||
base=Q + off_h * stride_qh + off_z * stride_qz + off_g * stride_qg,
|
||||
shape=(N_CTX_Q, D_PER_GROUP),
|
||||
strides=(stride_qm, stride_qk),
|
||||
offsets=(start_m * BLOCK_M, 0),
|
||||
block_shape=(BLOCK_M, D_PER_GROUP),
|
||||
order=(1, 0),
|
||||
)
|
||||
|
||||
k_base = K + off_h * stride_kh + off_z * stride_kz + off_g * stride_kg
|
||||
# Additional shift by 1 along the last dimension in the quantized case, since
|
||||
# the first element along that dim contains packed quantization coefficients.
|
||||
K_block_ptr = tl.make_block_ptr(
|
||||
base=k_base + stride_kk * QUANTIZED * N_GROUPS,
|
||||
shape=(PACKED_D_PER_GROUP, hi),
|
||||
strides=(stride_kk, stride_kn),
|
||||
offsets=(0, lo),
|
||||
block_shape=(PACKED_D_PER_GROUP, BLOCK_N),
|
||||
order=(0, 1),
|
||||
)
|
||||
v_base = V + off_h * stride_vh + off_z * stride_vz + off_g * stride_vg
|
||||
V_block_ptr = tl.make_block_ptr(
|
||||
base=v_base + stride_vk * QUANTIZED * N_GROUPS,
|
||||
shape=(hi, PACKED_D_PER_GROUP),
|
||||
strides=(stride_vn, stride_vk),
|
||||
offsets=(lo, 0),
|
||||
block_shape=(BLOCK_N, PACKED_D_PER_GROUP),
|
||||
order=(1, 0),
|
||||
)
|
||||
|
||||
if QUANTIZED:
|
||||
# Pointers to quantization coefficients. Even those they are 1D,
|
||||
# we have to use block pointers, since usual pointers
|
||||
# don't support boundary checks
|
||||
K_scale_shift_block_ptr = tl.make_block_ptr(
|
||||
base=k_base,
|
||||
shape=(1, hi),
|
||||
strides=(stride_kk, stride_kn),
|
||||
offsets=(0, lo),
|
||||
block_shape=(1, BLOCK_N),
|
||||
order=(0, 1),
|
||||
)
|
||||
V_scale_shift_block_ptr = tl.make_block_ptr(
|
||||
base=v_base,
|
||||
shape=(hi, 1),
|
||||
strides=(stride_vn, stride_vk),
|
||||
offsets=(lo, 0),
|
||||
block_shape=(BLOCK_N, 1),
|
||||
order=(1, 0),
|
||||
)
|
||||
else:
|
||||
K_scale_shift_block_ptr = None
|
||||
V_scale_shift_block_ptr = None
|
||||
|
||||
# initialize pointer to m and l
|
||||
m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
|
||||
l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
|
||||
|
||||
# Before compilation, this kernel will be processed by xformers.triton.vararg_kernel.unroll_varargs.
|
||||
# That turns tensors annotated as the one below into lists of tensors of length N_GROUPS.
|
||||
# This is a solution for Triton native lack of support for lists of tensors.
|
||||
acc: "VAR_ARGS_ARRAY" # noqa: F821
|
||||
|
||||
for i in range(len(acc)): # noqa: F821
|
||||
acc[i] = tl.zeros([BLOCK_M, D_PER_GROUP], dtype=tl.float32) # noqa: F821
|
||||
# scale sm_scale by log_2(e) and use
|
||||
# 2^x instead of exp in the loop because CSE and LICM
|
||||
# don't work as expected with `exp` in the loop
|
||||
qk_scale = sm_scale * 1.44269504
|
||||
# load q: it will stay in SRAM throughout
|
||||
q: "VAR_ARGS_ARRAY" # noqa: F821
|
||||
for i in range(len(acc)): # noqa: F821
|
||||
q[i] = tl.load( # noqa: F821
|
||||
tl.advance(Q_block_ptr, (0, i * D_PER_GROUP)), boundary_check=(0,)
|
||||
)
|
||||
# loop over k, v and update accumulator
|
||||
for start_n in range(lo, hi, BLOCK_N):
|
||||
k: "VAR_ARGS_ARRAY" # noqa: F821
|
||||
v: "VAR_ARGS_ARRAY" # noqa: F821
|
||||
for i in range(len(acc)): # noqa: F821
|
||||
k[i], v[i] = load_dequantize_k_v_group( # noqa: F821
|
||||
K_block_ptr,
|
||||
V_block_ptr,
|
||||
K_scale_shift_block_ptr,
|
||||
V_scale_shift_block_ptr,
|
||||
BOUNDS_CHECKS_N,
|
||||
PACKED_PER_VAL,
|
||||
PACKED_D_PER_GROUP,
|
||||
Q.dtype.element_ty,
|
||||
i,
|
||||
)
|
||||
|
||||
# -- compute qk ---
|
||||
qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
|
||||
for i in range(len(acc)): # noqa: F821
|
||||
qk += tl.dot(q[i], k[i]) # noqa: F821
|
||||
qk *= qk_scale
|
||||
|
||||
# TODO: This is slow, and only needed at the last iteration.
|
||||
# Maybe we can unroll the last iteration instead?
|
||||
if BOUNDS_CHECKS_N:
|
||||
qk = tl.where(tl.arange(0, BLOCK_N) < hi - start_n, qk, float("-inf"))
|
||||
# -- compute scaling constant ---
|
||||
m_i_new = tl.maximum(m_i, tl.max(qk, 1))
|
||||
alpha = tl.math.exp2(m_i - m_i_new)
|
||||
p = tl.math.exp2(qk - m_i_new[:, None])
|
||||
|
||||
# -- update m_i and l_i --
|
||||
l_i = l_i * alpha + tl.sum(p, 1)
|
||||
m_i = m_i_new
|
||||
p = p.to(Q.dtype.element_ty)
|
||||
|
||||
# -- scale and update acc --
|
||||
for i in range(len(acc)): # noqa: F821
|
||||
acc[i] *= alpha[:, None] # noqa: F821
|
||||
acc[i] += tl.dot(p, v[i]) # noqa: F821
|
||||
# update pointers
|
||||
K_block_ptr = tl.advance(K_block_ptr, (0, BLOCK_N))
|
||||
V_block_ptr = tl.advance(V_block_ptr, (BLOCK_N, 0))
|
||||
if PACKED_PER_VAL > 1:
|
||||
K_scale_shift_block_ptr = tl.advance(
|
||||
K_scale_shift_block_ptr, (0, BLOCK_N)
|
||||
)
|
||||
V_scale_shift_block_ptr = tl.advance(
|
||||
V_scale_shift_block_ptr, (BLOCK_N, 0)
|
||||
)
|
||||
|
||||
# write back O
|
||||
O_block_ptr = tl.make_block_ptr(
|
||||
base=Out_splitK + off_zhg * stride_osk_zhg + splitk_idx * stride_osk_s,
|
||||
shape=(N_CTX_Q, D_PER_GROUP),
|
||||
strides=(stride_osk_m, 1),
|
||||
offsets=(start_m * BLOCK_M, 0),
|
||||
block_shape=(BLOCK_M, D_PER_GROUP),
|
||||
order=(1, 0),
|
||||
)
|
||||
for i in range(len(acc)): # noqa: F821
|
||||
tl.store(
|
||||
tl.advance(O_block_ptr, (0, i * D_PER_GROUP)),
|
||||
acc[i], # noqa: F821
|
||||
boundary_check=(0,),
|
||||
)
|
||||
# Write metadata for split-K reduction
|
||||
Metadata_ptr = (
|
||||
Metadata
|
||||
+ off_zhg * stride_mzhg
|
||||
+ splitk_idx * stride_ms
|
||||
+ start_m * BLOCK_M
|
||||
+ tl.arange(0, BLOCK_M)
|
||||
)
|
||||
tl.store(Metadata_ptr, m_i)
|
||||
tl.store(Metadata_ptr + stride_m2, l_i)
|
||||
|
||||
@triton.jit
|
||||
def load_dequantize_k_v_group(
|
||||
K_block_ptr,
|
||||
V_block_ptr,
|
||||
K_scale_shift_block_ptr,
|
||||
V_scale_shift_block_ptr,
|
||||
BOUNDS_CHECKS_N: tl.constexpr,
|
||||
PACKED_PER_VAL: tl.constexpr,
|
||||
PACKED_D_PER_GROUP: tl.constexpr,
|
||||
dtype: tl.constexpr,
|
||||
group_id: tl.constexpr,
|
||||
):
|
||||
"""Load K/V for a given block. In case of int4-quantized K/V, dequantize them after loading.
|
||||
If quantization is group-wise, use group_id to advance the pointers to the current group.
|
||||
"""
|
||||
# Advance to the current quantization group
|
||||
K_block_ptr = tl.advance(K_block_ptr, (PACKED_D_PER_GROUP * group_id, 0))
|
||||
V_block_ptr = tl.advance(V_block_ptr, (0, PACKED_D_PER_GROUP * group_id))
|
||||
|
||||
# -- load k, v --
|
||||
k = tl.load(K_block_ptr, boundary_check=(1,) if BOUNDS_CHECKS_N else ())
|
||||
v = tl.load(V_block_ptr, boundary_check=(0,) if BOUNDS_CHECKS_N else ())
|
||||
|
||||
if PACKED_PER_VAL > 1:
|
||||
# K/V are quantized, load quantization coefficients and dequantize
|
||||
|
||||
K_scale_shift_block_ptr = tl.advance(K_scale_shift_block_ptr, (group_id, 0))
|
||||
V_scale_shift_block_ptr = tl.advance(V_scale_shift_block_ptr, (0, group_id))
|
||||
|
||||
k_scale_shift = tl.load(
|
||||
K_scale_shift_block_ptr, boundary_check=(1,) if BOUNDS_CHECKS_N else ()
|
||||
)
|
||||
v_scale_shift = tl.load(
|
||||
V_scale_shift_block_ptr, boundary_check=(0,) if BOUNDS_CHECKS_N else ()
|
||||
)
|
||||
|
||||
k_scale, k_shift = cast_uint32_to_half2(k_scale_shift)
|
||||
v_scale, v_shift = cast_uint32_to_half2(v_scale_shift)
|
||||
v = dequantize(v, v_scale, v_shift, PACKED_PER_VAL).to(dtype)
|
||||
k_t = dequantize(
|
||||
tl.trans(k),
|
||||
tl.trans(k_scale),
|
||||
tl.trans(k_shift),
|
||||
PACKED_PER_VAL,
|
||||
).to(dtype)
|
||||
k = tl.trans(k_t)
|
||||
return k, v
|
||||
|
||||
@triton.jit
|
||||
def cast_uint32_to_half2(scale_shift):
|
||||
"""Extract two float16 packed into one int32"""
|
||||
scale = scale_shift & 0xFFFF
|
||||
shift = scale_shift >> 16
|
||||
scale = scale.to(tl.uint16).to(tl.float16, bitcast=True)
|
||||
shift = shift.to(tl.uint16).to(tl.float16, bitcast=True)
|
||||
return scale, shift
|
||||
|
||||
@triton.jit
|
||||
def dequantize(
|
||||
x_,
|
||||
scale,
|
||||
shift,
|
||||
PACKED_PER_VAL: tl.constexpr = 8,
|
||||
):
|
||||
"""PACKED_PER_VAL is the number of values packed into each element x_.
|
||||
For example, for int4 quantization and x_ of type int32, PACKED_PER_VAL is 8.
|
||||
"""
|
||||
|
||||
# Axis along which offsets are applied matters here
|
||||
# It would be natural to have offsets in shape (BLOCK_N, D // PACKED_PER_VAL, PACKED_PER_VAL)
|
||||
# and expand K/V to that shape before applying offsets
|
||||
# However, Triton for some reason considers dim=1 as contiguous when doing tl.view below, and not dim=2
|
||||
# Note that tl.view doesn't guarantee the order of elements in the result - thus the code below depends
|
||||
# on the implementation details which might change in the future.
|
||||
# Ideally we would like to use tl.reshape, but it's not implemented yet.
|
||||
# See https://github.com/openai/triton/blob/9055af1a5dadc576804b38dd77ee91dc42af0bf7/python/triton/language/semantic.py#L541 # noqa: E501
|
||||
|
||||
# x_ : (BLOCK_N, D // PACKED_PER_VAL)
|
||||
# scale: (BLOCK_N, 1)
|
||||
# offsets: (PACKED_PER_VAL,)
|
||||
BLOCK_N: tl.constexpr = x_.shape[0]
|
||||
BLOCK_DMODEL_PACKED: tl.constexpr = x_.shape[1]
|
||||
offsets = tl.arange(0, PACKED_PER_VAL) * 4
|
||||
quant_offset = (
|
||||
x_[:, None, :] >> offsets[None, :, None]
|
||||
) # (BLOCK_N, PACKED_PER_VAL, D // PACKED_PER_VAL)
|
||||
|
||||
quant_offset = tl.view(
|
||||
quant_offset, (BLOCK_N, BLOCK_DMODEL_PACKED * PACKED_PER_VAL)
|
||||
)
|
||||
# Trick - instead of converting int4 to float16 we view it as float16
|
||||
# and then multiply by 32768 * 512 == 2**24
|
||||
quant_offset = (quant_offset & 0xF).to(tl.uint16).to(tl.float16, bitcast=True)
|
||||
quant_offset = (quant_offset * 32768.0).to(tl.float16)
|
||||
scale_512 = scale * 512
|
||||
|
||||
dequant = quant_offset * scale_512 + shift
|
||||
return dequant
|
||||
|
||||
@triton.jit
|
||||
def _splitK_reduce(
|
||||
Out_splitK, # [B, H, split_k, Mq, K]
|
||||
Metadata, # [B, H, 2, split_k, M_ceil] contains [mi, li]
|
||||
Out, # [B, H, M, K]
|
||||
LSE, # [B, H, M]
|
||||
split_k,
|
||||
stride_osk_zhg,
|
||||
stride_osk_s,
|
||||
stride_osk_m,
|
||||
stride_osk_k,
|
||||
stride_mzhg,
|
||||
stride_m2,
|
||||
stride_ms,
|
||||
stride_mm,
|
||||
stride_oz,
|
||||
stride_oh,
|
||||
stride_og,
|
||||
stride_om,
|
||||
stride_ok,
|
||||
stride_lse_zhg,
|
||||
stride_lse_m,
|
||||
BLOCK_SIZE: tl.constexpr,
|
||||
H: tl.constexpr,
|
||||
G: tl.constexpr,
|
||||
):
|
||||
off_zhg = tl.program_id(0)
|
||||
off_z = off_zhg // (H * G)
|
||||
off_h = (off_zhg // G) % H
|
||||
off_g = off_zhg % G
|
||||
off_m = tl.program_id(1)
|
||||
|
||||
Out_splitK_ptr = (
|
||||
Out_splitK
|
||||
+ stride_osk_zhg * off_zhg
|
||||
+ stride_osk_m * off_m
|
||||
+ tl.arange(0, BLOCK_SIZE)
|
||||
)
|
||||
Metadata_ptr = Metadata + stride_mzhg * off_zhg + off_m
|
||||
m = tl.load(Metadata_ptr)
|
||||
l_sum = tl.load(Metadata_ptr + stride_m2)
|
||||
acc = tl.load(Out_splitK_ptr)
|
||||
|
||||
for split_k_idx in range(1, split_k):
|
||||
Metadata_ptr = Metadata_ptr + stride_ms
|
||||
Out_splitK_ptr = Out_splitK_ptr + stride_osk_s
|
||||
|
||||
m_k = tl.load(Metadata_ptr)
|
||||
l_k = tl.load(Metadata_ptr + stride_m2)
|
||||
acc_k = tl.load(Out_splitK_ptr)
|
||||
|
||||
m_new = tl.maximum(m, m_k)
|
||||
if m_k < m:
|
||||
# Scale incoming values
|
||||
alpha = tl.math.exp2(m_k - m_new)
|
||||
acc_k = acc_k * alpha
|
||||
l_k = l_k * alpha
|
||||
else:
|
||||
# Scale our values
|
||||
alpha = tl.math.exp2(m - m_new)
|
||||
acc = acc * alpha
|
||||
l_sum = l_sum * alpha
|
||||
|
||||
m = m_new
|
||||
l_sum = l_sum + l_k
|
||||
acc = acc + acc_k
|
||||
|
||||
acc = acc / l_sum
|
||||
Out_ptr = (
|
||||
Out
|
||||
+ stride_oz * off_z
|
||||
+ stride_oh * off_h
|
||||
+ stride_og * off_g
|
||||
+ stride_om * off_m
|
||||
+ tl.arange(0, BLOCK_SIZE)
|
||||
)
|
||||
tl.store(Out_ptr, acc)
|
||||
|
||||
l_ptrs = LSE + off_zhg * stride_lse_zhg + off_m
|
||||
tl.store(l_ptrs, (m + tl.math.log2(l_sum)) / 1.44269504)
|
||||
|
||||
else:
|
||||
_fwd_kernel_splitK = None
|
||||
_splitK_reduce = None
|
||||
|
||||
|
||||
@register_operator
|
||||
class FwOp(AttentionFwOpBase):
|
||||
"""Flash-Attention with Split-K. Supports fused int-4 K/V quantization.
|
||||
Quantized path will be taken if input K/V have type int32.
|
||||
|
||||
Quantization can be row-wise or group-wise (when cls.NUM_GROUPS > 1) along
|
||||
the last dimension of K and V. Currently 1, 2, 4, or 8 groups per row are supported.
|
||||
Quantization coefficients (scale and shift) are represented as two
|
||||
float16 constants per group, packed into int32. Quantization coefficients of
|
||||
all groups are placed at the beginning of the row. So, if unquantized K/V have head
|
||||
dimension D, the quantized versions have head dimension D // 8 + NUM_GROUPS
|
||||
and dtype int32.
|
||||
Pseudocode for dequantizing one row can look like:
|
||||
group_size = D // 8
|
||||
for i in range(NUM_GROUPS):
|
||||
group_start = NUM_GROUPS + i * group_size
|
||||
group_quant = K[..., group_start: group_start + group_size]
|
||||
scale, shift = unpack_int32_into_float16x2(group_quant[0])
|
||||
group_dequant = group_quant[..., 1:] * scale + shift
|
||||
...
|
||||
|
||||
"""
|
||||
|
||||
OPERATOR = _fwd_kernel_splitK
|
||||
SUPPORTED_DEVICES = {"cuda"}
|
||||
CUDA_MINIMUM_COMPUTE_CAPABILITY = (8, 0)
|
||||
SUPPORTED_DTYPES = {
|
||||
torch.half,
|
||||
torch.bfloat16,
|
||||
} # Those are dtypes of Q. In the quantized case K/V has dtype int32
|
||||
SUPPORTED_MAX_K = 128
|
||||
SUPPORTED_ATTN_BIAS_TYPES: Set[Any] = {
|
||||
type(None),
|
||||
BlockDiagonalCausalWithOffsetPaddedKeysMask,
|
||||
}
|
||||
SUPPORTS_DROPOUT = False
|
||||
SUPPORTS_CUSTOM_SCALE = True
|
||||
SUPPORTS_BMGHK = True
|
||||
NAME = "triton_splitKF"
|
||||
|
||||
SPLIT_K: Optional[int] = None
|
||||
BLOCK_M = 16
|
||||
BLOCK_N = 64
|
||||
|
||||
NUM_GROUPS = 1 # Default quantization is row-wise
|
||||
|
||||
@classmethod
|
||||
def shape_not_supported_reasons(
|
||||
cls, Mq: int, Mkv: int, K: int, Kv: int
|
||||
) -> List[str]:
|
||||
reasons = super().shape_not_supported_reasons(Mq, Mkv, K, Kv)
|
||||
if K not in {16, 32, 64, 128}:
|
||||
reasons.append(f"Embed dim {K} not supported")
|
||||
return reasons
|
||||
|
||||
@classmethod
|
||||
def not_supported_reasons(cls, d: Inputs) -> List[str]:
|
||||
reasons = super(FwOp, cls).not_supported_reasons(d)
|
||||
check_lastdim_alignment_stride1(reasons, "query", d.query, 8)
|
||||
if d.key.dtype != torch.int32:
|
||||
check_lastdim_alignment_stride1(reasons, "key", d.key, 8)
|
||||
check_lastdim_alignment_stride1(reasons, "value", d.value, 8)
|
||||
if cls.OPERATOR is None:
|
||||
reasons.append("triton is not available")
|
||||
if d.device.type == "cuda":
|
||||
# Has only been tested on 8.0 / 9.0.
|
||||
if torch.cuda.get_device_capability(d.device) < (8, 0):
|
||||
reasons.append(
|
||||
"requires GPU with sm80 minimum compute capacity, e.g., A100/H100/L4"
|
||||
)
|
||||
|
||||
q_len = d.query.shape[1]
|
||||
if isinstance(d.attn_bias, BlockDiagonalCausalWithOffsetPaddedKeysMask):
|
||||
seqinfo = d.attn_bias.q_seqinfo
|
||||
if q_len != seqinfo.seqstart_py[-1]:
|
||||
reasons.append(
|
||||
f"Expected total {seqinfo.seqstart_py[-1]} queries not {q_len}"
|
||||
)
|
||||
q_len = seqinfo.min_seqlen
|
||||
if q_len != seqinfo.max_seqlen:
|
||||
reasons.append(
|
||||
"Variable query len is not supported in the presence of causal mask."
|
||||
)
|
||||
|
||||
if d.key.ndim in [4, 5] and d.key.shape[-2] != 1:
|
||||
if d.key.stride(-2) == 0 and d.value.stride(-2) == 0 and q_len > 1:
|
||||
reasons.append("multiquery is only supported with query seqlen=1")
|
||||
|
||||
if d.attn_bias is not None and q_len > 1:
|
||||
reasons.append(
|
||||
"query with seqlen > 1 is not supported in the presence of causal mask"
|
||||
)
|
||||
return reasons
|
||||
|
||||
@classmethod
|
||||
def get_split_k(cls, B: int, H: int, Mk: int) -> int:
|
||||
"""Heuristic for the number of splits"""
|
||||
bh = B * H
|
||||
split_k = max(Mk, 1024) // bh
|
||||
max_chunk_size = 64 if Mk <= 512 and bh <= 64 else 128
|
||||
while split_k > 0 and Mk / split_k < max_chunk_size:
|
||||
split_k = split_k // 2
|
||||
split_k = min(split_k, 64)
|
||||
split_k = max(split_k, 1)
|
||||
return split_k
|
||||
|
||||
@classmethod
|
||||
def apply(
|
||||
cls, inp: Inputs, needs_gradient: bool
|
||||
) -> Tuple[torch.Tensor, Optional[Context]]:
|
||||
attn_bias = inp.attn_bias
|
||||
seq_len = None
|
||||
q, k, v = inp.get_qkv_in_bmghk()
|
||||
|
||||
if attn_bias is not None:
|
||||
assert isinstance(attn_bias, BlockDiagonalCausalWithOffsetPaddedKeysMask)
|
||||
# TODO: do we really need to do this cast? seems fishy but
|
||||
# I just copied it from the decoder.py
|
||||
attn_bias.k_seqinfo.to(inp.query.device)
|
||||
attn_bias.q_seqinfo.to(inp.query.device)
|
||||
seq_len = attn_bias.k_seqinfo.seqlen
|
||||
B = len(seq_len)
|
||||
G, H, Kq = q.shape[-3:]
|
||||
Kkv = v.shape[-1]
|
||||
|
||||
# assume kv has been padded
|
||||
q = q.reshape(B, -1, G, H, Kq)
|
||||
k = k.reshape(B, -1, G, H, Kkv)
|
||||
v = v.reshape(B, -1, G, H, Kkv)
|
||||
|
||||
# Transpose in the case of MQA/GQA
|
||||
mqa_swap_seqlen_head = False
|
||||
if k.shape[3] > 1 and k.stride(3) == 0 and v.stride(3) == 0:
|
||||
mqa_swap_seqlen_head = True
|
||||
assert q.shape[1] == 1
|
||||
q = q.transpose(1, 3)
|
||||
k = k[:, :, :, :1]
|
||||
v = v[:, :, :, :1]
|
||||
|
||||
if k.dtype == torch.int32:
|
||||
# Quantized K/V
|
||||
PACKED_PER_VAL = 8
|
||||
Lk = (k.shape[-1] - cls.NUM_GROUPS) * 8
|
||||
else:
|
||||
Lk = k.shape[-1]
|
||||
PACKED_PER_VAL = 1
|
||||
|
||||
B, Mk, G, H, Kkv = k.shape
|
||||
B, M, G, H, Kq = q.shape
|
||||
assert Lk == Kq, f"Keys have head dim {Lk} but queries have head dim {Kq}"
|
||||
|
||||
BLOCK_M = cls.BLOCK_M
|
||||
BLOCK_N = cls.BLOCK_N
|
||||
if cls.SPLIT_K is not None:
|
||||
split_k = cls.SPLIT_K
|
||||
else:
|
||||
# Use heuristics
|
||||
split_k = cls.get_split_k(B, H, Mk)
|
||||
|
||||
M_ceil = (M + BLOCK_M - 1) // BLOCK_M * BLOCK_M
|
||||
o_splitk = torch.empty(
|
||||
[B * G * H, split_k, M_ceil, Kq], dtype=torch.float32, device=q.device
|
||||
)
|
||||
metadata = torch.empty(
|
||||
[B * G * H, 2, split_k, M_ceil], dtype=torch.float32, device=q.device
|
||||
)
|
||||
lse = torch.empty((B * G * H, M), device=q.device, dtype=torch.float32)
|
||||
grid = (triton.cdiv(M, BLOCK_M), B * G * H, split_k)
|
||||
|
||||
num_warps = 2
|
||||
split_size = (Mk + split_k - 1) // split_k
|
||||
use_seq_len = seq_len is not None
|
||||
_fwd_kernel_splitK_unrolled = unroll_varargs(
|
||||
_fwd_kernel_splitK, N=cls.NUM_GROUPS if PACKED_PER_VAL > 1 else 1
|
||||
)
|
||||
|
||||
_fwd_kernel_splitK_unrolled[grid](
|
||||
Q=q,
|
||||
K=k,
|
||||
V=v,
|
||||
sm_scale=inp.scale_float,
|
||||
Out_splitK=o_splitk,
|
||||
Metadata=metadata,
|
||||
Seq_len=seq_len,
|
||||
**_strides(q, "qz", "qm", "qg", "qh", "qk"),
|
||||
**_strides(k, "kz", "kn", "kg", "kh", "kk"),
|
||||
**_strides(v, "vz", "vn", "vg", "vh", "vk"),
|
||||
**_strides(o_splitk, "osk_zhg", "osk_s", "osk_m", "osk_k"),
|
||||
**_strides(metadata, "mzhg", "m2", "ms", "mm"),
|
||||
Z=B,
|
||||
H=H,
|
||||
G=G,
|
||||
N_CTX_Q=M,
|
||||
N_CTX_K=Mk,
|
||||
BLOCK_N_PER_SPLIT=split_size,
|
||||
BLOCK_M=BLOCK_M,
|
||||
BLOCK_N=BLOCK_N,
|
||||
BLOCK_DMODEL=Lk,
|
||||
BOUNDS_CHECKS_N=(split_size % BLOCK_N) > 0 or use_seq_len,
|
||||
USE_SEQ_LEN=use_seq_len,
|
||||
num_warps=num_warps,
|
||||
num_stages=1,
|
||||
PACKED_PER_VAL=PACKED_PER_VAL,
|
||||
N_GROUPS=cls.NUM_GROUPS if PACKED_PER_VAL > 1 else 1,
|
||||
)
|
||||
|
||||
if mqa_swap_seqlen_head:
|
||||
out = torch.empty(
|
||||
(B, H, G, M, Kq), device=q.device, dtype=q.dtype
|
||||
).transpose(1, 3)
|
||||
else:
|
||||
out = torch.empty((B, M, G, H, Kq), device=q.device, dtype=q.dtype)
|
||||
|
||||
# Merge together
|
||||
grid = (B * G * H, M, 1)
|
||||
_splitK_reduce[grid](
|
||||
o_splitk,
|
||||
metadata,
|
||||
out,
|
||||
lse,
|
||||
split_k=split_k,
|
||||
**_strides(o_splitk, "osk_zhg", "osk_s", "osk_m", "osk_k"),
|
||||
**_strides(metadata, "mzhg", "m2", "ms", "mm"),
|
||||
**_strides(out, "oz", "om", "og", "oh", "ok"),
|
||||
**_strides(lse, "lse_zhg", "lse_m"),
|
||||
BLOCK_SIZE=out.shape[-1],
|
||||
G=G,
|
||||
H=H,
|
||||
# TODO: Tune num_warps
|
||||
)
|
||||
lse = lse.reshape([B, G, H, M])
|
||||
if mqa_swap_seqlen_head:
|
||||
# H/M dimensions have been swapped
|
||||
out = out.transpose(1, 3)
|
||||
lse = lse.transpose(2, 3)
|
||||
if inp.query.ndim == 4:
|
||||
# BMGHK -> BMHK
|
||||
assert G == 1
|
||||
out = out[:, :, 0]
|
||||
lse = lse[:, 0]
|
||||
|
||||
return out, Context(out=out, lse=lse)
|
||||
|
||||
|
||||
class FwOp_S1(FwOp):
|
||||
SPLIT_K = 1
|
||||
NAME = "triton_splitK1"
|
||||
|
||||
|
||||
class FwOp_S2(FwOp):
|
||||
SPLIT_K = 2
|
||||
NAME = "triton_splitK2"
|
||||
|
||||
|
||||
class FwOp_S4(FwOp):
|
||||
SPLIT_K = 4
|
||||
NAME = "triton_splitK4"
|
||||
|
||||
|
||||
class FwOp_S8(FwOp):
|
||||
SPLIT_K = 8
|
||||
NAME = "triton_splitK8"
|
||||
|
||||
|
||||
class FwOp_S16(FwOp):
|
||||
SPLIT_K = 16
|
||||
NAME = "triton_splitK16"
|
||||
|
||||
|
||||
class FwOp_S32(FwOp):
|
||||
SPLIT_K = 32
|
||||
NAME = "triton_splitK32"
|
||||
|
||||
|
||||
class FwOp_S64(FwOp):
|
||||
SPLIT_K = 64
|
||||
NAME = "triton_splitK64"
|
||||
|
||||
|
||||
class FwOp_S128(FwOp):
|
||||
SPLIT_K = 128
|
||||
NAME = "triton_splitK128"
|
||||
Reference in New Issue
Block a user