First commit
This commit is contained in:
97
pkgs/xformers/ops/__init__.py
Normal file
97
pkgs/xformers/ops/__init__.py
Normal file
@@ -0,0 +1,97 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the BSD license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import torch
|
||||
|
||||
from .fmha import (
|
||||
AttentionBias,
|
||||
AttentionOp,
|
||||
AttentionOpBase,
|
||||
AttentionOpDispatch,
|
||||
LowerTriangularMask,
|
||||
MemoryEfficientAttentionCutlassFwdFlashBwOp,
|
||||
MemoryEfficientAttentionCutlassOp,
|
||||
MemoryEfficientAttentionFlashAttentionOp,
|
||||
MemoryEfficientAttentionOp,
|
||||
MemoryEfficientAttentionTritonFwdFlashBwOp,
|
||||
TritonFlashAttentionOp,
|
||||
memory_efficient_attention,
|
||||
memory_efficient_attention_backward,
|
||||
memory_efficient_attention_forward,
|
||||
memory_efficient_attention_forward_requires_grad,
|
||||
)
|
||||
from .indexing import index_select_cat, scaled_index_add
|
||||
from .rmsnorm import RMSNorm
|
||||
from .rope_padded import rope_padded
|
||||
from .swiglu_op import (
|
||||
SwiGLU,
|
||||
SwiGLUEagerOp,
|
||||
SwiGLUFusedOp,
|
||||
SwiGLUOp,
|
||||
SwiGLUOpDispatch,
|
||||
SwiGLUPackedFusedOp,
|
||||
swiglu,
|
||||
)
|
||||
from .unbind import get_stack_strides, stack_or_none, unbind
|
||||
|
||||
# BW compatibility
|
||||
AttentionMask = AttentionBias
|
||||
|
||||
|
||||
def masked_matmul(a, b, mask=None):
|
||||
if torch.overrides.has_torch_function((a, b, mask)):
|
||||
return torch.overrides.handle_torch_function(
|
||||
masked_matmul, (a, b, mask), a, b, mask
|
||||
)
|
||||
|
||||
att = a @ b
|
||||
|
||||
if mask is None:
|
||||
return att
|
||||
|
||||
if mask.dtype == torch.bool:
|
||||
if mask.ndim == 2:
|
||||
mask = mask.unsqueeze(0).expand(att.shape[0], -1, -1)
|
||||
# mask is presumed false == ignore
|
||||
att[~mask] = float("-inf")
|
||||
else:
|
||||
# mask is presumed additive
|
||||
att += mask
|
||||
return att
|
||||
|
||||
|
||||
__all__ = [
|
||||
"memory_efficient_attention",
|
||||
"AttentionBias",
|
||||
"AttentionMask",
|
||||
"AttentionOp",
|
||||
"AttentionOpBase",
|
||||
"AttentionOpDispatch",
|
||||
"LowerTriangularMask",
|
||||
"MemoryEfficientAttentionCutlassFwdFlashBwOp",
|
||||
"MemoryEfficientAttentionCutlassOp",
|
||||
"MemoryEfficientAttentionFlashAttentionOp",
|
||||
"MemoryEfficientAttentionOp",
|
||||
"MemoryEfficientAttentionTritonFwdFlashBwOp",
|
||||
"memory_efficient_attention_backward",
|
||||
"memory_efficient_attention_forward",
|
||||
"memory_efficient_attention_forward_requires_grad",
|
||||
"RMSNorm",
|
||||
"SwiGLU",
|
||||
"SwiGLUEagerOp",
|
||||
"SwiGLUFusedOp",
|
||||
"SwiGLUOp",
|
||||
"SwiGLUOpDispatch",
|
||||
"SwiGLUPackedFusedOp",
|
||||
"swiglu",
|
||||
"TritonFlashAttentionOp",
|
||||
"unbind",
|
||||
"stack_or_none",
|
||||
"get_stack_strides",
|
||||
"masked_matmul",
|
||||
"scaled_index_add",
|
||||
"index_select_cat",
|
||||
"rope_padded",
|
||||
]
|
||||
BIN
pkgs/xformers/ops/__pycache__/__init__.cpython-310.pyc
Normal file
BIN
pkgs/xformers/ops/__pycache__/__init__.cpython-310.pyc
Normal file
Binary file not shown.
BIN
pkgs/xformers/ops/__pycache__/common.cpython-310.pyc
Normal file
BIN
pkgs/xformers/ops/__pycache__/common.cpython-310.pyc
Normal file
Binary file not shown.
BIN
pkgs/xformers/ops/__pycache__/indexing.cpython-310.pyc
Normal file
BIN
pkgs/xformers/ops/__pycache__/indexing.cpython-310.pyc
Normal file
Binary file not shown.
BIN
pkgs/xformers/ops/__pycache__/rmsnorm.cpython-310.pyc
Normal file
BIN
pkgs/xformers/ops/__pycache__/rmsnorm.cpython-310.pyc
Normal file
Binary file not shown.
BIN
pkgs/xformers/ops/__pycache__/rope_padded.cpython-310.pyc
Normal file
BIN
pkgs/xformers/ops/__pycache__/rope_padded.cpython-310.pyc
Normal file
Binary file not shown.
BIN
pkgs/xformers/ops/__pycache__/swiglu_op.cpython-310.pyc
Normal file
BIN
pkgs/xformers/ops/__pycache__/swiglu_op.cpython-310.pyc
Normal file
Binary file not shown.
BIN
pkgs/xformers/ops/__pycache__/unbind.cpython-310.pyc
Normal file
BIN
pkgs/xformers/ops/__pycache__/unbind.cpython-310.pyc
Normal file
Binary file not shown.
133
pkgs/xformers/ops/common.py
Normal file
133
pkgs/xformers/ops/common.py
Normal file
@@ -0,0 +1,133 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the BSD license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import inspect
|
||||
import os
|
||||
from typing import Any, Dict, List, Type, TypeVar
|
||||
|
||||
import torch
|
||||
from torch.torch_version import TorchVersion
|
||||
|
||||
|
||||
def get_operator(library: str, name: str):
|
||||
def no_such_operator(*args, **kwargs):
|
||||
raise RuntimeError(
|
||||
f"No such operator {library}::{name} - did you forget to build xformers with `python setup.py develop`?"
|
||||
)
|
||||
|
||||
try:
|
||||
return getattr(getattr(torch.ops, library), name)
|
||||
except (RuntimeError, AttributeError):
|
||||
return no_such_operator
|
||||
|
||||
|
||||
def get_xformers_operator(name: str):
|
||||
return get_operator("xformers", name)
|
||||
|
||||
|
||||
class BaseOperator:
|
||||
OPERATOR: Any
|
||||
NAME: str
|
||||
OPERATOR_CATEGORY: str
|
||||
|
||||
@classmethod
|
||||
def is_available(cls) -> bool:
|
||||
if cls.OPERATOR is None or cls.OPERATOR.__name__ == "no_such_operator":
|
||||
return False
|
||||
return True
|
||||
|
||||
@classmethod
|
||||
def operator_flop(cls, *inputs) -> int:
|
||||
"""Calculate number of FLOP given inputs to `OPERATOR`"""
|
||||
return -1
|
||||
|
||||
|
||||
OPERATORS_REGISTRY: List[Type[BaseOperator]] = []
|
||||
FUNC_TO_XFORMERS_OPERATOR: Dict[Any, Type[BaseOperator]] = {}
|
||||
|
||||
ClsT = TypeVar("ClsT")
|
||||
|
||||
|
||||
def register_operator(cls: ClsT) -> ClsT:
|
||||
global OPERATORS_REGISTRY, FUNC_TO_XFORMERS_OPERATOR
|
||||
OPERATORS_REGISTRY.append(cls) # type: ignore
|
||||
FUNC_TO_XFORMERS_OPERATOR[cls.OPERATOR] = cls # type: ignore
|
||||
return cls
|
||||
|
||||
|
||||
# post-2.0, avoids a warning
|
||||
# (`torch.Tensor.storage` will also be deleted in the future)
|
||||
_GET_TENSOR_STORAGE = getattr(torch.Tensor, "untyped_storage", None)
|
||||
if _GET_TENSOR_STORAGE is None: # pre-2.0, `untyped_storage` didn't exist
|
||||
_GET_TENSOR_STORAGE = torch.Tensor.storage
|
||||
|
||||
|
||||
def _get_storage_base(x: torch.Tensor) -> int:
|
||||
return _GET_TENSOR_STORAGE(x).data_ptr() # type: ignore
|
||||
|
||||
|
||||
def make_pytorch_cuda_operator(fn: ClsT) -> ClsT:
|
||||
from .. import get_python_lib
|
||||
|
||||
def render_arg_type(annotation) -> str:
|
||||
if annotation is torch.Tensor:
|
||||
return "Tensor"
|
||||
if annotation is bool:
|
||||
return "bool"
|
||||
if annotation is int:
|
||||
return "int"
|
||||
if annotation is List[int]:
|
||||
return "int[]"
|
||||
if annotation is List[torch.Tensor]:
|
||||
return "Tensor[]"
|
||||
assert False, f"Unable to parse annotation: `{annotation}`"
|
||||
|
||||
sign = inspect.signature(fn) # type: ignore
|
||||
arguments = [
|
||||
f"{render_arg_type(arg.annotation)} {arg.name}"
|
||||
for arg in sign.parameters.values()
|
||||
]
|
||||
op_name = fn.__name__ # type: ignore
|
||||
definition = f"{op_name}({', '.join(arguments)}) -> {render_arg_type(sign.return_annotation)}"
|
||||
|
||||
xformers_lib = get_python_lib()
|
||||
xformers_lib.define(definition)
|
||||
xformers_lib.impl(op_name, fn, "CUDA")
|
||||
dispatcher_impl = getattr(getattr(torch.ops, xformers_lib.ns), op_name)
|
||||
|
||||
def wrapper(*args, **kwargs):
|
||||
return dispatcher_impl(*args, **kwargs)
|
||||
|
||||
return wrapper # type: ignore
|
||||
|
||||
|
||||
def _has_a_version_of_triton():
|
||||
if os.environ.get("XFORMERS_FORCE_DISABLE_TRITON", "0") == "1":
|
||||
return False
|
||||
if not torch.cuda.is_available():
|
||||
return False
|
||||
try:
|
||||
import triton # noqa: F401
|
||||
except ImportError:
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def _has_triton2():
|
||||
if not _has_a_version_of_triton():
|
||||
return False
|
||||
import triton
|
||||
|
||||
tv = TorchVersion(triton.__version__)
|
||||
return tv >= (2, 1) or tv == (2, 0)
|
||||
|
||||
|
||||
def _has_triton21():
|
||||
if not _has_a_version_of_triton():
|
||||
return False
|
||||
import triton
|
||||
|
||||
tv = TorchVersion(triton.__version__)
|
||||
return tv >= (2, 1)
|
||||
474
pkgs/xformers/ops/fmha/__init__.py
Normal file
474
pkgs/xformers/ops/fmha/__init__.py
Normal file
@@ -0,0 +1,474 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the BSD license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
from typing import Any, Optional, Sequence, Tuple, Type, Union
|
||||
import os
|
||||
import torch
|
||||
|
||||
from . import cutlass, decoder, flash, small_k, triton, triton_splitk
|
||||
from .attn_bias import AttentionBias, BlockDiagonalMask, LowerTriangularMask
|
||||
from .common import (
|
||||
AttentionBwOpBase,
|
||||
AttentionFwOpBase,
|
||||
AttentionOp,
|
||||
AttentionOpBase,
|
||||
AttentionOpDispatch,
|
||||
Context,
|
||||
Gradients,
|
||||
Inputs,
|
||||
bmk2bmhk,
|
||||
)
|
||||
from .dispatch import _dispatch_bw, _dispatch_fw, _ensure_op_supports_or_raise
|
||||
|
||||
MemoryEfficientAttentionCutlassOp = (cutlass.FwOp, cutlass.BwOp)
|
||||
MemoryEfficientAttentionCutlassFwdFlashBwOp = (cutlass.FwOp, flash.BwOp)
|
||||
MemoryEfficientAttentionDecoderOp = (decoder.FwOp, cutlass.BwOp)
|
||||
MemoryEfficientAttentionTritonFwdFlashBwOp = (triton.FwOp, flash.BwOp)
|
||||
MemoryEfficientAttentionFlashAttentionOp = (flash.FwOp, flash.BwOp)
|
||||
MemoryEfficientAttentionOp = (small_k.FwOp, small_k.BwOp)
|
||||
TritonFlashAttentionOp = (triton.FwOp, triton.BwOp)
|
||||
|
||||
|
||||
class _fMHA(torch.autograd.Function):
|
||||
@staticmethod
|
||||
# type: ignore
|
||||
def forward(ctx, op: AttentionOp, *args: Any) -> Any:
|
||||
inp = Inputs(*args)
|
||||
op_fw = op[0] if op is not None else None
|
||||
op_bw = op[1] if op is not None else None
|
||||
|
||||
out, op_ctx = _memory_efficient_attention_forward_requires_grad(
|
||||
inp=inp, op=op_fw
|
||||
)
|
||||
|
||||
# Saving attn_bias is a bit complicated, as the
|
||||
# torch part should go in `save_for_backward`
|
||||
if isinstance(inp.attn_bias, torch.Tensor):
|
||||
attn_bias_tensor = inp.attn_bias
|
||||
attn_bias_ctx = None
|
||||
else:
|
||||
attn_bias_tensor = None
|
||||
attn_bias_ctx = inp.attn_bias
|
||||
|
||||
ctx.save_for_backward(
|
||||
inp.query,
|
||||
inp.key,
|
||||
inp.value,
|
||||
op_ctx.q_padded,
|
||||
op_ctx.k_padded,
|
||||
op_ctx.v_padded,
|
||||
op_ctx.o_padded,
|
||||
op_ctx.out,
|
||||
op_ctx.lse,
|
||||
)
|
||||
ctx.rng_state = op_ctx.rng_state
|
||||
ctx.attn_bias_tensor = attn_bias_tensor
|
||||
if op_ctx.op_bw is not None:
|
||||
if op_bw is not None and op_bw is not op_ctx.op_bw:
|
||||
raise ValueError(
|
||||
f"Specified op_bw={op_bw.NAME}, but forward op "
|
||||
f"can only run with op_bw={op_ctx.op_bw.NAME}. Please set op_bw=None."
|
||||
)
|
||||
op_bw = op_ctx.op_bw
|
||||
ctx.op_fw = op_fw
|
||||
ctx.op_bw = op_bw
|
||||
ctx.p = inp.p
|
||||
ctx.use_alibi = inp.use_alibi
|
||||
ctx.alibi_mode = inp.alibi_mode
|
||||
ctx.imp_mode = inp.imp_mode
|
||||
|
||||
ctx.scale = inp.scale
|
||||
ctx.attn_bias_ctx = attn_bias_ctx
|
||||
ctx.n_args = len(args)
|
||||
return out
|
||||
|
||||
@staticmethod
|
||||
def deserialize_bias(
|
||||
attn_bias_ctx, attn_bias_tensor: Optional[torch.Tensor]
|
||||
) -> Any:
|
||||
if attn_bias_tensor is None:
|
||||
return attn_bias_ctx
|
||||
return attn_bias_tensor
|
||||
|
||||
@classmethod
|
||||
@torch.autograd.function.once_differentiable
|
||||
def backward(cls, ctx, grad):
|
||||
# Re-create context
|
||||
query, key, value, q_padded, k_padded, v_padded, o_padded, out, lse = ctx.saved_tensors
|
||||
attn_bias_tensor = ctx.attn_bias_tensor
|
||||
rng_state = ctx.rng_state
|
||||
inp = Inputs(
|
||||
query=query,
|
||||
key=key,
|
||||
value=value,
|
||||
attn_bias=cls.deserialize_bias(ctx.attn_bias_ctx, attn_bias_tensor),
|
||||
p=ctx.p,
|
||||
scale=ctx.scale,
|
||||
use_alibi=ctx.use_alibi,
|
||||
alibi_mode=ctx.alibi_mode,
|
||||
imp_mode = ctx.imp_mode,
|
||||
)
|
||||
op_ctx = Context(
|
||||
lse=lse,
|
||||
out=out,
|
||||
q_padded=q_padded,
|
||||
k_padded=k_padded,
|
||||
v_padded=v_padded,
|
||||
o_padded=o_padded,
|
||||
rng_state=rng_state,
|
||||
)
|
||||
grads = _memory_efficient_attention_backward(
|
||||
ctx=op_ctx, inp=inp, grad=grad, op=ctx.op_bw
|
||||
)
|
||||
return (None, grads.dq, grads.dk, grads.dv, grads.db) + (None,) * (
|
||||
ctx.n_args - 2
|
||||
)
|
||||
|
||||
|
||||
def memory_efficient_attention(
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
attn_bias: Optional[Union[torch.Tensor, AttentionBias]] = None,
|
||||
p: float = 0.0,
|
||||
scale: Optional[float] = None,
|
||||
use_alibi: bool = False,
|
||||
alibi_mode: int = 1,
|
||||
imp_mode: int = 0,
|
||||
*,
|
||||
op: Optional[AttentionOp] = None,
|
||||
) -> torch.Tensor:
|
||||
"""Implements the memory-efficient attention mechanism following
|
||||
`"Self-Attention Does Not Need O(n^2) Memory" <http://arxiv.org/abs/2112.05682>`_.
|
||||
|
||||
:Inputs shape:
|
||||
|
||||
- Input tensors must be in format ``[B, M, H, K]``, where B is the batch size, M \
|
||||
the sequence length, H the number of heads, and K the embeding size per head
|
||||
|
||||
- If inputs have dimension 3, it is assumed that the dimensions are ``[B, M, K]`` and ``H=1``
|
||||
|
||||
- Inputs can also be of dimension 5 with GQA - see note below
|
||||
|
||||
- Inputs can be non-contiguous - we only require the last dimension's stride to be 1
|
||||
|
||||
|
||||
:Equivalent pytorch code:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
scale = 1 / query.shape[-1] ** 0.5
|
||||
query = query * scale
|
||||
attn = query @ key.transpose(-2, -1)
|
||||
if attn_bias is not None:
|
||||
attn = attn + attn_bias
|
||||
attn = attn.softmax(-1)
|
||||
attn = F.dropout(attn, p)
|
||||
return attn @ value
|
||||
|
||||
:Examples:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
import xformers.ops as xops
|
||||
|
||||
# Compute regular attention
|
||||
y = xops.memory_efficient_attention(q, k, v)
|
||||
|
||||
# With a dropout of 0.2
|
||||
y = xops.memory_efficient_attention(q, k, v, p=0.2)
|
||||
|
||||
# Causal attention
|
||||
y = xops.memory_efficient_attention(
|
||||
q, k, v,
|
||||
attn_bias=xops.LowerTriangularMask()
|
||||
)
|
||||
|
||||
:Supported hardware:
|
||||
|
||||
NVIDIA GPUs with compute capability above 6.0 (P100+), datatype ``f16``, ``bf16`` and ``f32``.
|
||||
|
||||
:EXPERIMENTAL: Using with Multi Query Attention (MQA) and Grouped Query Attention (GQA):
|
||||
|
||||
MQA/GQA is an experimental feature supported only for the forward pass.
|
||||
If you have 16 heads in query, and 2 in key/value, you can provide 5-dim tensors
|
||||
in the ``[B, M, G, H, K]`` format, where ``G`` is the number of head groups (here 2), and
|
||||
``H`` is the number of heads per group (8 in the example).
|
||||
|
||||
Please note that xFormers will not automatically broadcast the inputs, so you will need
|
||||
to broadcast it manually before calling `memory_efficient_attention`.
|
||||
|
||||
:GQA/MQA example:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
import torch
|
||||
import xformers.ops as xops
|
||||
|
||||
B, M, K = 3, 32, 128
|
||||
kwargs = dict(device="cuda", dtype=torch.float16)
|
||||
q = torch.randn([B, M, 8, K], **kwargs)
|
||||
k = torch.randn([B, M, 2, K], **kwargs)
|
||||
v = torch.randn([B, M, 2, K], **kwargs)
|
||||
out_gqa = xops.memory_efficient_attention(
|
||||
q.reshape([B, M, 2, 4, K]),
|
||||
k.reshape([B, M, 2, 1, K]).expand([B, M, 2, 4, K]),
|
||||
v.reshape([B, M, 2, 1, K]).expand([B, M, 2, 4, K]),
|
||||
)
|
||||
|
||||
Raises:
|
||||
NotImplementedError: if there is no operator available to compute the MHA
|
||||
ValueError: if inputs are invalid
|
||||
|
||||
:parameter query: Tensor of shape ``[B, Mq, H, K]``
|
||||
:parameter key: Tensor of shape ``[B, Mkv, H, K]``
|
||||
:parameter value: Tensor of shape ``[B, Mkv, H, Kv]``
|
||||
:parameter attn_bias: Bias to apply to the attention matrix - defaults to no masking. \
|
||||
For common biases implemented efficiently in xFormers, see :attr:`xformers.ops.fmha.attn_bias.AttentionBias`. \
|
||||
This can also be a :attr:`torch.Tensor` for an arbitrary mask (slower).
|
||||
:parameter p: Dropout probability. Disabled if set to ``0.0``
|
||||
:parameter scale: Scaling factor for ``Q @ K.transpose()``. If set to ``None``, the default \
|
||||
scale (q.shape[-1]**-0.5) will be used.
|
||||
:parameter op: The operators to use - see :attr:`xformers.ops.AttentionOpBase`. \
|
||||
If set to ``None`` (recommended), xFormers \
|
||||
will dispatch to the best available operator, depending on the inputs \
|
||||
and options.
|
||||
:return: multi-head attention Tensor with shape ``[B, Mq, H, Kv]``
|
||||
"""
|
||||
return _memory_efficient_attention(
|
||||
Inputs(
|
||||
query=query, key=key, value=value, p=p, attn_bias=attn_bias, scale=scale, use_alibi=use_alibi, alibi_mode=alibi_mode, imp_mode=imp_mode
|
||||
),
|
||||
op=op,
|
||||
)
|
||||
|
||||
|
||||
def memory_efficient_attention_forward(
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
attn_bias: Optional[Union[torch.Tensor, AttentionBias]] = None,
|
||||
p: float = 0.0,
|
||||
scale: Optional[float] = None,
|
||||
use_alibi: bool = False,
|
||||
alibi_mode: int = 1,
|
||||
imp_mode: int = 0,
|
||||
*,
|
||||
op: Optional[Type[AttentionFwOpBase]] = None,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Calculates the forward pass of :attr:`xformers.ops.memory_efficient_attention`.
|
||||
"""
|
||||
return _memory_efficient_attention_forward(
|
||||
Inputs(
|
||||
query=query, key=key, value=value, p=p, attn_bias=attn_bias, scale=scale, use_alibi=use_alibi, alibi_mode=alibi_mode, imp_mode=imp_mode
|
||||
),
|
||||
op=op,
|
||||
)
|
||||
|
||||
|
||||
def memory_efficient_attention_forward_requires_grad(
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
attn_bias: Optional[Union[torch.Tensor, AttentionBias]] = None,
|
||||
p: float = 0.0,
|
||||
scale: Optional[float] = None,
|
||||
use_alibi: bool = False,
|
||||
alibi_mode: int = 1,
|
||||
imp_mode: int = 0,
|
||||
*,
|
||||
op: Optional[Type[AttentionFwOpBase]] = None,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Returns a tuple (output, lse), where `lse` can be used to compute the backward pass later.
|
||||
See :attr:`xformers.ops.memory_efficient_attention` for an explanation of the arguments
|
||||
See :attr:`xformers.ops.memory_efficient_attention_backward` for running the backward pass
|
||||
"""
|
||||
if p != 0.0:
|
||||
raise NotImplementedError(
|
||||
"dropout is not supported on the non-autograd API."
|
||||
" If you want to use dropout, please call `memory_efficient_attention` directly"
|
||||
)
|
||||
out, ctx = _memory_efficient_attention_forward_requires_grad(
|
||||
Inputs(
|
||||
query=query, key=key, value=value, p=p, attn_bias=attn_bias, scale=scale, use_alibi=use_alibi, alibi_mode=alibi_mode, imp_mode=imp_mode
|
||||
),
|
||||
op=op,
|
||||
)
|
||||
return out, ctx.lse
|
||||
|
||||
|
||||
def memory_efficient_attention_backward(
|
||||
grad: torch.Tensor,
|
||||
output: torch.Tensor,
|
||||
lse: torch.Tensor,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
attn_bias: Optional[Union[torch.Tensor, AttentionBias]] = None,
|
||||
p: float = 0.0,
|
||||
scale: Optional[float] = None,
|
||||
use_alibi: bool = False,
|
||||
alibi_mode: int = 1,
|
||||
imp_mode: int = 0,
|
||||
*,
|
||||
op: Optional[Type[AttentionBwOpBase]] = None,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Computes the gradient of the attention.
|
||||
Returns a tuple (dq, dk, dv)
|
||||
See :attr:`xformers.ops.memory_efficient_attention` for an explanation of the arguments.
|
||||
`lse` is the tensor returned by :attr:`xformers.ops.memory_efficient_attention_forward_requires_grad`
|
||||
"""
|
||||
if p != 0.0:
|
||||
raise NotImplementedError(
|
||||
"dropout is not supported on the non-autograd API."
|
||||
" If you want to use dropout, please call `memory_efficient_attention` directly"
|
||||
)
|
||||
gradients = _memory_efficient_attention_backward(
|
||||
Context(out=output, lse=lse),
|
||||
Inputs(
|
||||
query=query, key=key, value=value, p=p, attn_bias=attn_bias, scale=scale, use_alibi=use_alibi, alibi_mode=alibi_mode, imp_mode=imp_mode
|
||||
),
|
||||
grad,
|
||||
op=op,
|
||||
)
|
||||
return (gradients.dq, gradients.dk, gradients.dv)
|
||||
|
||||
|
||||
def _memory_efficient_attention(
|
||||
inp: Inputs, op: Optional[AttentionOp] = None
|
||||
) -> torch.Tensor:
|
||||
# fast-path that doesn't require computing the logsumexp for backward computation
|
||||
if all(x.requires_grad is False for x in [inp.query, inp.key, inp.value]):
|
||||
return _memory_efficient_attention_forward(
|
||||
inp, op=op[0] if op is not None else None
|
||||
)
|
||||
|
||||
output_shape = inp.normalize_bmhk()
|
||||
return _fMHA.apply(
|
||||
op, inp.query, inp.key, inp.value, inp.attn_bias, inp.p, inp.scale, inp.use_alibi, inp.alibi_mode, inp.imp_mode
|
||||
).reshape(output_shape)
|
||||
|
||||
|
||||
def _memory_efficient_attention_forward(
|
||||
inp: Inputs, op: Optional[Type[AttentionFwOpBase]]
|
||||
) -> torch.Tensor:
|
||||
inp.validate_inputs()
|
||||
output_shape = inp.normalize_bmhk()
|
||||
if op is None:
|
||||
op = _dispatch_fw(inp, False)
|
||||
else:
|
||||
_ensure_op_supports_or_raise(ValueError, "memory_efficient_attention", op, inp)
|
||||
|
||||
out, *_ = op.apply(inp, needs_gradient=False)
|
||||
return out.reshape(output_shape)
|
||||
|
||||
|
||||
def _memory_efficient_attention_forward_requires_grad(
|
||||
inp: Inputs, op: Optional[Type[AttentionFwOpBase]]
|
||||
) -> Tuple[torch.Tensor, Context]:
|
||||
inp.validate_inputs()
|
||||
output_shape = inp.normalize_bmhk()
|
||||
if op is None:
|
||||
op = _dispatch_fw(inp, True)
|
||||
else:
|
||||
_ensure_op_supports_or_raise(ValueError, "memory_efficient_attention", op, inp)
|
||||
out = op.apply(inp, needs_gradient=True)
|
||||
assert out[1] is not None
|
||||
return (out[0].reshape(output_shape), out[1])
|
||||
|
||||
|
||||
def _memory_efficient_attention_backward(
|
||||
ctx: Context, inp: Inputs, grad: torch.Tensor, op: Optional[Type[AttentionBwOpBase]]
|
||||
) -> Gradients:
|
||||
"""Warning: grad/ctx.out is potentially in BMK format"""
|
||||
inp.validate_inputs()
|
||||
if grad.ndim != inp.query.ndim or grad.ndim != ctx.out.ndim:
|
||||
raise ValueError(
|
||||
"All tensors should be either in BMK (ndim=3) or BMHK (ndim=4) format. \n"
|
||||
f"grad.shape : {grad.shape} \n"
|
||||
f"out.shape : {ctx.out.shape} \n"
|
||||
f"query.shape: {inp.query.shape}"
|
||||
)
|
||||
shape_dq, shape_dk, shape_dv = tuple(
|
||||
x.shape for x in (inp.query, inp.key, inp.value)
|
||||
)
|
||||
inp.normalize_bmhk()
|
||||
# LSE has shape [B, H, M] while query has shape [B, M, H, K]
|
||||
if os.getenv('ENABLE_FLASH_ATTENTION_WITH_IXDNN', '1') == '0':
|
||||
if (
|
||||
ctx.lse.ndim != 3
|
||||
# Dim 0
|
||||
or (
|
||||
not isinstance(inp.attn_bias, BlockDiagonalMask)
|
||||
and ctx.lse.shape[0] != inp.query.shape[0]
|
||||
)
|
||||
or (
|
||||
isinstance(inp.attn_bias, BlockDiagonalMask)
|
||||
and ctx.lse.shape[0] != inp.attn_bias.q_seqinfo.seqstart.shape[0] - 1
|
||||
)
|
||||
# Dim 1
|
||||
or ctx.lse.shape[1] != inp.query.shape[2]
|
||||
# Dim 2
|
||||
or (
|
||||
not isinstance(inp.attn_bias, BlockDiagonalMask)
|
||||
and ctx.lse.shape[2] < inp.query.shape[1]
|
||||
)
|
||||
):
|
||||
raise ValueError(
|
||||
"Input tensors have incompatible shapes."
|
||||
f"lse.shape : {ctx.lse.shape} \n"
|
||||
f"query.shape : {inp.query.shape}"
|
||||
)
|
||||
grad = bmk2bmhk(grad, 1)
|
||||
ctx.out = bmk2bmhk(ctx.out, 1)
|
||||
|
||||
if op is None:
|
||||
op = _dispatch_bw(inp)
|
||||
else:
|
||||
_ensure_op_supports_or_raise(
|
||||
ValueError, "memory_efficient_attention_backward", op, inp
|
||||
)
|
||||
|
||||
grads = op.apply(ctx, inp, grad)
|
||||
grads.dq = grads.dq.reshape(shape_dq)
|
||||
grads.dk = grads.dk.reshape(shape_dk)
|
||||
grads.dv = grads.dv.reshape(shape_dv)
|
||||
return grads
|
||||
|
||||
|
||||
ALL_FW_OPS: Sequence[Type[AttentionFwOpBase]] = [
|
||||
cutlass.FwOp,
|
||||
flash.FwOp,
|
||||
triton.FwOp,
|
||||
small_k.FwOp,
|
||||
triton_splitk.FwOp,
|
||||
]
|
||||
|
||||
ALL_BW_OPS: Sequence[Type[AttentionBwOpBase]] = [
|
||||
cutlass.BwOp,
|
||||
flash.BwOp,
|
||||
triton.BwOp,
|
||||
small_k.BwOp,
|
||||
]
|
||||
|
||||
__all__ = [
|
||||
"AttentionBias",
|
||||
"AttentionOp",
|
||||
"AttentionOpBase",
|
||||
"AttentionOpDispatch",
|
||||
"LowerTriangularMask",
|
||||
"MemoryEfficientAttentionCutlassFwdFlashBwOp",
|
||||
"MemoryEfficientAttentionTritonFwdFlashBwOp",
|
||||
"MemoryEfficientAttentionCutlassOp",
|
||||
"MemoryEfficientAttentionFlashAttentionOp",
|
||||
"MemoryEfficientAttentionOp",
|
||||
"TritonFlashAttentionOp",
|
||||
"memory_efficient_attention",
|
||||
"ALL_FW_OPS",
|
||||
"ALL_BW_OPS",
|
||||
]
|
||||
BIN
pkgs/xformers/ops/fmha/__pycache__/__init__.cpython-310.pyc
Normal file
BIN
pkgs/xformers/ops/fmha/__pycache__/__init__.cpython-310.pyc
Normal file
Binary file not shown.
BIN
pkgs/xformers/ops/fmha/__pycache__/attn_bias.cpython-310.pyc
Normal file
BIN
pkgs/xformers/ops/fmha/__pycache__/attn_bias.cpython-310.pyc
Normal file
Binary file not shown.
BIN
pkgs/xformers/ops/fmha/__pycache__/common.cpython-310.pyc
Normal file
BIN
pkgs/xformers/ops/fmha/__pycache__/common.cpython-310.pyc
Normal file
Binary file not shown.
BIN
pkgs/xformers/ops/fmha/__pycache__/cutlass.cpython-310.pyc
Normal file
BIN
pkgs/xformers/ops/fmha/__pycache__/cutlass.cpython-310.pyc
Normal file
Binary file not shown.
BIN
pkgs/xformers/ops/fmha/__pycache__/decoder.cpython-310.pyc
Normal file
BIN
pkgs/xformers/ops/fmha/__pycache__/decoder.cpython-310.pyc
Normal file
Binary file not shown.
BIN
pkgs/xformers/ops/fmha/__pycache__/dispatch.cpython-310.pyc
Normal file
BIN
pkgs/xformers/ops/fmha/__pycache__/dispatch.cpython-310.pyc
Normal file
Binary file not shown.
BIN
pkgs/xformers/ops/fmha/__pycache__/flash.cpython-310.pyc
Normal file
BIN
pkgs/xformers/ops/fmha/__pycache__/flash.cpython-310.pyc
Normal file
Binary file not shown.
BIN
pkgs/xformers/ops/fmha/__pycache__/small_k.cpython-310.pyc
Normal file
BIN
pkgs/xformers/ops/fmha/__pycache__/small_k.cpython-310.pyc
Normal file
Binary file not shown.
BIN
pkgs/xformers/ops/fmha/__pycache__/triton.cpython-310.pyc
Normal file
BIN
pkgs/xformers/ops/fmha/__pycache__/triton.cpython-310.pyc
Normal file
Binary file not shown.
BIN
pkgs/xformers/ops/fmha/__pycache__/triton_splitk.cpython-310.pyc
Normal file
BIN
pkgs/xformers/ops/fmha/__pycache__/triton_splitk.cpython-310.pyc
Normal file
Binary file not shown.
779
pkgs/xformers/ops/fmha/attn_bias.py
Normal file
779
pkgs/xformers/ops/fmha/attn_bias.py
Normal file
@@ -0,0 +1,779 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the BSD license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
|
||||
import math
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Iterable, List, Optional, Sequence, Tuple, Union
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
class AttentionBias:
|
||||
"""Base class for a custom bias that can be applied \
|
||||
as the attn_bias argument in
|
||||
:attr:`xformers.ops.memory_efficient_attention`.
|
||||
|
||||
That function has the ability to add a tensor, the
|
||||
attention bias, to the QK^T matrix before it is used
|
||||
in the softmax part of the attention calculation.
|
||||
The attention bias tensor with shape
|
||||
(B or 1, n_queries, number of keys)
|
||||
can be given as the attn_bias input.
|
||||
The most common use case is for an attention bias is
|
||||
to contain only zeros and negative infinities, which forms
|
||||
a mask so that some queries only attend to some keys.
|
||||
|
||||
Children of this class define alternative things which can
|
||||
be used as the attn_bias input to define an attention bias which
|
||||
forms such a mask, for some common cases.
|
||||
|
||||
When using an :attr:`xformers.ops.AttentionBias`
|
||||
instead of a :attr:`torch.Tensor`, the mask matrix does
|
||||
not need to be materialized, and can be
|
||||
hardcoded into some kernels for better performance.
|
||||
|
||||
See:
|
||||
|
||||
- :attr:`xformers.ops.fmha.attn_bias.LowerTriangularMask`
|
||||
- :attr:`xformers.ops.fmha.attn_bias.LowerTriangularMaskWithTensorBias`
|
||||
- :attr:`xformers.ops.fmha.attn_bias.BlockDiagonalMask`
|
||||
- :attr:`xformers.ops.fmha.attn_bias.BlockDiagonalCausalMask`
|
||||
|
||||
"""
|
||||
|
||||
def materialize(
|
||||
self,
|
||||
shape: Tuple[int, ...],
|
||||
dtype: torch.dtype = torch.float32,
|
||||
device: Union[str, torch.device] = "cpu",
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Materializes the bias as a `torch.Tensor`. This is very slow
|
||||
and we don't attempt to make it fast. Only use for debugging/testing.
|
||||
|
||||
Shape should be like `[*, q_seqlen, k_seqlen]`
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
class LowerTriangularMask(AttentionBias):
|
||||
"""
|
||||
A lower-triangular (aka causal) mask
|
||||
|
||||
A query Q cannot attend to a key which is farther from the
|
||||
initial key than Q is from the initial query.
|
||||
"""
|
||||
|
||||
def __init__(self, *tensor_args, **tensor_kwargs) -> None:
|
||||
# NOTE: Unused arguments, we keep them for backward compatibility
|
||||
super().__init__()
|
||||
|
||||
def materialize(
|
||||
self,
|
||||
shape: Tuple[int, ...],
|
||||
dtype: torch.dtype = torch.float32,
|
||||
device: Union[str, torch.device] = "cpu",
|
||||
) -> torch.Tensor:
|
||||
create_as = dtype if dtype is not torch.bfloat16 else torch.float32
|
||||
tensor = torch.full( # type: ignore
|
||||
shape,
|
||||
dtype=create_as,
|
||||
fill_value=float("-inf"),
|
||||
device=device,
|
||||
)
|
||||
return torch.triu(tensor, diagonal=1).to(dtype) # type: ignore
|
||||
|
||||
def add_bias(self, bias: torch.Tensor) -> "LowerTriangularMaskWithTensorBias":
|
||||
return LowerTriangularMaskWithTensorBias(bias)
|
||||
|
||||
|
||||
class LowerTriangularMaskWithTensorBias(LowerTriangularMask):
|
||||
"""A lower-triangular (aka causal) mask with an additive bias"""
|
||||
|
||||
def __init__(self, bias: torch.Tensor) -> None:
|
||||
self._bias = bias
|
||||
|
||||
def materialize(
|
||||
self,
|
||||
shape: Tuple[int, ...],
|
||||
dtype: torch.dtype = torch.float32,
|
||||
device: Union[str, torch.device] = "cpu",
|
||||
) -> torch.Tensor:
|
||||
return super().materialize(shape, dtype=dtype, device=device) + self._bias
|
||||
|
||||
|
||||
@dataclass
|
||||
class _SeqLenInfo:
|
||||
"""
|
||||
(Internal) Represents the division of a dimension into blocks.
|
||||
|
||||
For example, to represents a dimension of length 7 divided into
|
||||
three blocks of lengths 2, 3 and 2, use `from_seqlength([2, 3, 2])`.
|
||||
The members will be:
|
||||
max_seqlen: 3
|
||||
min_seqlen: 2
|
||||
seqstart_py: [0, 2, 5, 7]
|
||||
seqstart: torch.IntTensor([0, 2, 5, 7])
|
||||
"""
|
||||
|
||||
seqstart: torch.Tensor
|
||||
max_seqlen: int
|
||||
min_seqlen: int
|
||||
seqstart_py: List[int]
|
||||
|
||||
def to(self, device: torch.device) -> None:
|
||||
self.seqstart = self.seqstart.to(device, non_blocking=True)
|
||||
|
||||
def intervals(self) -> Iterable[Tuple[int, int]]:
|
||||
yield from zip(self.seqstart_py, self.seqstart_py[1:])
|
||||
|
||||
@classmethod
|
||||
def from_seqlens(cls, seqlens: Iterable[int]) -> "_SeqLenInfo":
|
||||
"""
|
||||
Input tensors are assumed to be in shape [B, M, *]
|
||||
"""
|
||||
assert not isinstance(seqlens, torch.Tensor)
|
||||
seqstart_py = [0]
|
||||
max_seqlen = -1
|
||||
min_seqlen = -1
|
||||
for seqlen in seqlens:
|
||||
min_seqlen = min(min_seqlen, seqlen) if min_seqlen != -1 else seqlen
|
||||
max_seqlen = max(max_seqlen, seqlen)
|
||||
seqstart_py.append(seqstart_py[len(seqstart_py) - 1] + seqlen)
|
||||
seqstart = torch.tensor(seqstart_py, dtype=torch.int32)
|
||||
return cls(
|
||||
max_seqlen=max_seqlen,
|
||||
min_seqlen=min_seqlen,
|
||||
seqstart=seqstart,
|
||||
seqstart_py=seqstart_py,
|
||||
)
|
||||
|
||||
def split(
|
||||
self, x: torch.Tensor, batch_sizes: Optional[Sequence[int]] = None
|
||||
) -> List[torch.Tensor]:
|
||||
if self.seqstart_py[-1] != x.shape[1] or x.shape[0] != 1:
|
||||
raise ValueError(
|
||||
f"Invalid `torch.Tensor` of shape {x.shape}, expected format "
|
||||
f"(B, M, *) with B=1 and M={self.seqstart_py[-1]}\n"
|
||||
f" seqstart: {self.seqstart_py}"
|
||||
)
|
||||
if batch_sizes is None:
|
||||
batch_sizes = [1] * (len(self.seqstart_py) - 1)
|
||||
split_chunks = []
|
||||
it = 0
|
||||
for batch_size in batch_sizes:
|
||||
split_chunks.append(
|
||||
self.seqstart_py[it + batch_size] - self.seqstart_py[it]
|
||||
)
|
||||
it += batch_size
|
||||
return [
|
||||
tensor.reshape([bs, -1, *tensor.shape[2:]])
|
||||
for bs, tensor in zip(batch_sizes, x.split(split_chunks, dim=1))
|
||||
]
|
||||
|
||||
|
||||
@dataclass
|
||||
class _PaddedSeqLenInfo(_SeqLenInfo):
|
||||
"""
|
||||
(Internal) Represents the division of a dimension into blocks which are
|
||||
padded out to the same total length.
|
||||
|
||||
For example, to represent a dimension of length 12 with space for
|
||||
three blocks of length 4, but where the occupied lengths are
|
||||
2, 3 and 2, use `from_seqlens_padded([2, 3, 2], 4)`.
|
||||
|
||||
The layout along the dimension is
|
||||
|
||||
0 ─► block 0
|
||||
block 0
|
||||
<space>
|
||||
<space>
|
||||
4 ─► block 1
|
||||
block 1
|
||||
block 1
|
||||
<space>
|
||||
8 ─► block 2
|
||||
block 2
|
||||
<space>
|
||||
<space>
|
||||
12 ─►
|
||||
|
||||
The members will be:
|
||||
max_seqlen: 3
|
||||
min_seqlen: 2
|
||||
seqstart_py: [0, 4, 8, 12]
|
||||
seqstart: torch.IntTensor([0, 4, 8, 12])
|
||||
seqlen_py: [2, 3, 2]
|
||||
seqlen: torch.IntTensor([2, 3, 2])
|
||||
padding: 4
|
||||
"""
|
||||
|
||||
seqlen: torch.Tensor
|
||||
seqlen_py: Sequence[int]
|
||||
padding: int
|
||||
# From parent: seqstart[i] contains the start position
|
||||
# of the i-th sequence
|
||||
# seqstart: torch.Tensor
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
assert len(self.seqstart_py) == len(self.seqlen_py) + 1
|
||||
|
||||
def to(self, device: torch.device) -> None:
|
||||
self.seqlen = self.seqlen.to(device, non_blocking=True)
|
||||
super().to(device)
|
||||
|
||||
def intervals(self) -> Iterable[Tuple[int, int]]:
|
||||
for (start, _), length in zip(super().intervals(), self.seqlen_py):
|
||||
yield start, start + length
|
||||
|
||||
@classmethod
|
||||
def from_seqlens(cls, seqlens: Iterable[int]) -> "_SeqLenInfo":
|
||||
raise RuntimeError(
|
||||
"Use either `_SeqLenInfo.from_seqlens` or `_PaddedSeqLenInfo.from_seqlens_padded`"
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_seqlens_padded(
|
||||
cls, seqlens: Sequence[int], padding: int
|
||||
) -> "_PaddedSeqLenInfo":
|
||||
"""
|
||||
Input tensors are assumed to be in shape [B, M, *]
|
||||
seqstart = padding * torch.arange(batch_size)
|
||||
"""
|
||||
assert not isinstance(seqlens, torch.Tensor)
|
||||
assert all(seqlen <= padding for seqlen in seqlens)
|
||||
seqstart_py = list(range(0, len(seqlens) * padding + 1, padding))
|
||||
return cls(
|
||||
seqlen=torch.tensor(seqlens, dtype=torch.int32),
|
||||
seqlen_py=seqlens,
|
||||
max_seqlen=max(seqlens),
|
||||
min_seqlen=min(seqlens),
|
||||
seqstart=torch.tensor(seqstart_py, dtype=torch.int32),
|
||||
seqstart_py=seqstart_py,
|
||||
padding=padding,
|
||||
)
|
||||
|
||||
def split(
|
||||
self, x: torch.Tensor, batch_sizes: Optional[Sequence[int]] = None
|
||||
) -> List[torch.Tensor]:
|
||||
raise NotImplementedError("_PaddedSeqLenInfo.split")
|
||||
|
||||
|
||||
@dataclass
|
||||
class BlockDiagonalMask(AttentionBias):
|
||||
"""
|
||||
A block-diagonal mask that can be passed as ``attn_bias``
|
||||
argument to :attr:`xformers.ops.memory_efficient_attention`.
|
||||
|
||||
Queries and Keys are each divided into the same number of blocks.
|
||||
Queries in block i only attend to keys in block i.
|
||||
|
||||
.. figure:: /_static/block_diag_bias.png
|
||||
|
||||
This bias can be used to handle a batch of sequences of
|
||||
different lengths, via :attr:`BlockDiagonalMask.from_tensor_list`
|
||||
|
||||
:Example:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
import torch
|
||||
from xformers.ops import fmha
|
||||
|
||||
K = 16
|
||||
dtype = torch.float16
|
||||
device = "cuda"
|
||||
list_x = [
|
||||
torch.randn([1, 3, 1, K], dtype=dtype, device=device),
|
||||
torch.randn([1, 6, 1, K], dtype=dtype, device=device),
|
||||
torch.randn([1, 2, 1, K], dtype=dtype, device=device),
|
||||
]
|
||||
attn_bias, x = fmha.BlockDiagonalMask.from_tensor_list(list_x)
|
||||
linear = torch.nn.Linear(K, K * 3).to(device=device, dtype=dtype)
|
||||
|
||||
q, k, v = linear(x).reshape([1, -1, 1, 3, K]).unbind(-2)
|
||||
out = fmha.memory_efficient_attention(q, k, v, attn_bias=attn_bias)
|
||||
list_out = attn_bias.split(out)
|
||||
print(list_out[0].shape) # [1, 3, 1, K]
|
||||
assert tuple(list_out[0].shape) == (1, 3, 1, K)
|
||||
|
||||
"""
|
||||
|
||||
q_seqinfo: _SeqLenInfo
|
||||
k_seqinfo: _SeqLenInfo
|
||||
_batch_sizes: Optional[Sequence[int]] = None
|
||||
|
||||
def _create_block_mask(
|
||||
self,
|
||||
shape: Tuple[int, ...],
|
||||
dtype: torch.dtype = torch.float32,
|
||||
device: Union[str, torch.device] = "cpu",
|
||||
) -> torch.Tensor:
|
||||
return torch.zeros(
|
||||
shape,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
)
|
||||
|
||||
def materialize(
|
||||
self,
|
||||
shape: Tuple[int, ...],
|
||||
dtype: torch.dtype = torch.float32,
|
||||
device: Union[str, torch.device] = "cpu",
|
||||
) -> torch.Tensor:
|
||||
"""Materialize the attention bias - for debugging & testing"""
|
||||
assert shape[-1] == self.k_seqinfo.seqstart_py[-1], (
|
||||
shape[-1],
|
||||
self.k_seqinfo.seqstart_py[-1],
|
||||
)
|
||||
assert shape[-2] == self.q_seqinfo.seqstart_py[-1], (
|
||||
shape[-2],
|
||||
self.q_seqinfo.seqstart_py[-1],
|
||||
)
|
||||
mask = torch.empty(shape[-2:], dtype=dtype, device=device)
|
||||
mask.fill_(-math.inf)
|
||||
for i, ((q_start, q_end), (k_start, k_end)) in enumerate(
|
||||
zip(
|
||||
self.q_seqinfo.intervals(),
|
||||
self.k_seqinfo.intervals(),
|
||||
)
|
||||
):
|
||||
mask[q_start:q_end, k_start:k_end] = self._create_block_mask(
|
||||
(q_end - q_start, k_end - k_start),
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
)
|
||||
for _ in range(len(shape) - 2):
|
||||
mask = mask.unsqueeze(0)
|
||||
return mask.expand(shape)
|
||||
|
||||
@classmethod
|
||||
def from_seqlens(
|
||||
cls,
|
||||
q_seqlen: Sequence[int],
|
||||
kv_seqlen: Optional[Sequence[int]] = None,
|
||||
) -> "BlockDiagonalMask":
|
||||
"""Creates a :attr:`BlockDiagonalMask` from a list of tensors lengths for query and key/value.
|
||||
|
||||
Args:
|
||||
q_seqlen (Union[Sequence[int], torch.Tensor]): List or tensor of sequence lengths for query tensors
|
||||
kv_seqlen (Union[Sequence[int], torch.Tensor], optional): List or tensor of sequence lengths for key/value.
|
||||
(Defaults to ``q_seqlen``.)
|
||||
Returns:
|
||||
BlockDiagonalMask
|
||||
"""
|
||||
assert kv_seqlen is None or len(q_seqlen) == len(kv_seqlen)
|
||||
q_seqinfo = _SeqLenInfo.from_seqlens(q_seqlen)
|
||||
if kv_seqlen is None or q_seqlen == kv_seqlen:
|
||||
k_seqinfo = q_seqinfo
|
||||
else:
|
||||
k_seqinfo = _SeqLenInfo.from_seqlens(kv_seqlen)
|
||||
return cls(q_seqinfo=q_seqinfo, k_seqinfo=k_seqinfo)
|
||||
|
||||
@classmethod
|
||||
def from_tensor_list(
|
||||
cls,
|
||||
tensors: Sequence[torch.Tensor],
|
||||
) -> Tuple["BlockDiagonalMask", torch.Tensor]:
|
||||
"""Creates a :attr:`BlockDiagonalMask` from a list of tensors, and returns the tensors
|
||||
concatenated on the sequence length dimension
|
||||
|
||||
.. figure:: /_static/block_diag_cat_split.png
|
||||
|
||||
See also :attr:`BlockDiagonalMask.split` to split the returned
|
||||
:attr:`torch.Tensor` back to a list of tensors of varying sequence length
|
||||
|
||||
Args:
|
||||
tensors (Sequence[torch.Tensor]): A list of tensors of shape ``[B, M_i, *]``.
|
||||
All tensors should have the same dimension and the same batch size ``B``, but
|
||||
they can have different sequence length ``M``.
|
||||
|
||||
Returns:
|
||||
Tuple[BlockDiagonalMask, torch.Tensor]: The corresponding bias for the attention
|
||||
along with `tensors` concatenated on the sequence length dimension, with shape ``[1, sum_i{M_i}, *]``
|
||||
"""
|
||||
batch_sizes = [tensor.shape[0] for tensor in tensors]
|
||||
seqlens = []
|
||||
for x in tensors:
|
||||
for _ in range(x.shape[0]):
|
||||
seqlens.append(x.shape[1])
|
||||
block_diag = cls.from_seqlens(seqlens)
|
||||
block_diag._batch_sizes = batch_sizes
|
||||
tensors_bs1 = tuple(x.reshape([1, -1, *x.shape[2:]]) for x in tensors)
|
||||
concat_tensors = torch.cat(tensors_bs1, dim=1)
|
||||
return block_diag, concat_tensors
|
||||
|
||||
@classmethod
|
||||
def from_tensor_lists_qkv(
|
||||
cls,
|
||||
tensors_q: Sequence[torch.Tensor],
|
||||
tensors_k: Sequence[torch.Tensor],
|
||||
tensors_v: Optional[Sequence[torch.Tensor]] = None,
|
||||
) -> Tuple["BlockDiagonalMask", torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
|
||||
assert len(tensors_q) == len(tensors_k)
|
||||
assert tensors_v is None or len(tensors_v) == len(tensors_q)
|
||||
batch_sizes = [tensor.shape[0] for tensor in tensors_q]
|
||||
q_seqlens, kv_seqlens = [], []
|
||||
for i, (q, k) in enumerate(zip(tensors_q, tensors_k)):
|
||||
assert q.shape[0] == k.shape[0]
|
||||
q_seqlens += [q.shape[1]] * q.shape[0]
|
||||
kv_seqlens += [k.shape[1]] * k.shape[0]
|
||||
assert tensors_v is None or tensors_v[i].shape[:2] == k.shape[:2]
|
||||
block_diag = cls.from_seqlens(q_seqlens, kv_seqlens)
|
||||
block_diag._batch_sizes = batch_sizes
|
||||
return (
|
||||
block_diag,
|
||||
torch.cat([x.reshape([1, -1, *x.shape[2:]]) for x in tensors_q], dim=1),
|
||||
torch.cat([x.reshape([1, -1, *x.shape[2:]]) for x in tensors_k], dim=1),
|
||||
torch.cat([x.reshape([1, -1, *x.shape[2:]]) for x in tensors_v], dim=1)
|
||||
if tensors_v is not None
|
||||
else None,
|
||||
)
|
||||
|
||||
def split_queries(self, tensor: torch.Tensor) -> Sequence[torch.Tensor]:
|
||||
return self.q_seqinfo.split(tensor, self._batch_sizes)
|
||||
|
||||
def split_kv(self, tensor: torch.Tensor) -> Sequence[torch.Tensor]:
|
||||
return self.k_seqinfo.split(tensor, self._batch_sizes)
|
||||
|
||||
def split(self, tensor: torch.Tensor) -> Sequence[torch.Tensor]:
|
||||
"""The inverse operation of :attr:`BlockDiagonalCausalMask.from_tensor_list`
|
||||
|
||||
Args:
|
||||
tensor (torch.Tensor): Tensor of tokens of shape ``[1, sum_i{M_i}, *]``
|
||||
|
||||
Returns:
|
||||
Sequence[torch.Tensor]: A list of tokens with possibly different sequence lengths
|
||||
"""
|
||||
assert self.q_seqinfo is self.k_seqinfo
|
||||
return self.q_seqinfo.split(tensor, self._batch_sizes)
|
||||
|
||||
def make_causal(self) -> "BlockDiagonalCausalMask":
|
||||
"""Makes each block causal"""
|
||||
return BlockDiagonalCausalMask(
|
||||
q_seqinfo=self.q_seqinfo,
|
||||
k_seqinfo=self.k_seqinfo,
|
||||
_batch_sizes=self._batch_sizes,
|
||||
)
|
||||
|
||||
def make_causal_from_bottomright(self) -> "BlockDiagonalCausalFromBottomRightMask":
|
||||
"""Makes each block causal with a possible non-causal prefix"""
|
||||
return BlockDiagonalCausalFromBottomRightMask(
|
||||
q_seqinfo=self.q_seqinfo,
|
||||
k_seqinfo=self.k_seqinfo,
|
||||
_batch_sizes=self._batch_sizes,
|
||||
)
|
||||
|
||||
def make_local_attention(
|
||||
self, window_size: int
|
||||
) -> "BlockDiagonalCausalLocalAttentionMask":
|
||||
"""Experimental: Makes each block causal with local attention"""
|
||||
return BlockDiagonalCausalLocalAttentionMask(
|
||||
q_seqinfo=self.q_seqinfo,
|
||||
k_seqinfo=self.k_seqinfo,
|
||||
_batch_sizes=self._batch_sizes,
|
||||
_window_size=window_size,
|
||||
)
|
||||
|
||||
def make_local_attention_from_bottomright(
|
||||
self, window_size: int
|
||||
) -> "BlockDiagonalCausalLocalAttentionFromBottomRightMask":
|
||||
"""Experimental: Makes each block causal with local attention, start from bottom right"""
|
||||
return BlockDiagonalCausalLocalAttentionFromBottomRightMask(
|
||||
q_seqinfo=self.q_seqinfo,
|
||||
k_seqinfo=self.k_seqinfo,
|
||||
_batch_sizes=self._batch_sizes,
|
||||
_window_size=window_size,
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class BlockDiagonalCausalMask(BlockDiagonalMask):
|
||||
"""
|
||||
Same as :attr:`xformers.ops.fmha.attn_bias.BlockDiagonalMask`, except that each block is causal.
|
||||
|
||||
Queries and Keys are each divided into the same number of blocks.
|
||||
A query Q in block i cannot attend to a key which is not in block i,
|
||||
nor one which is farther from the initial key in block i than Q
|
||||
is from the initial query in block i.
|
||||
"""
|
||||
|
||||
def _create_block_mask(
|
||||
self,
|
||||
shape: Tuple[int, ...],
|
||||
dtype: torch.dtype = torch.float32,
|
||||
device: Union[str, torch.device] = "cpu",
|
||||
) -> torch.Tensor:
|
||||
return LowerTriangularMask().materialize(
|
||||
shape,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class BlockDiagonalCausalFromBottomRightMask(BlockDiagonalMask):
|
||||
"""
|
||||
Same as :attr:`xformers.ops.fmha.attn_bias.BlockDiagonalMask`, except that each block is causal.
|
||||
This mask allows for a non-causal prefix
|
||||
NOTE: Each block should have `num_keys >= num_queries` otherwise the forward pass is not
|
||||
defined (softmax of vector of `-inf` in the attention)
|
||||
|
||||
Queries and keys are each divided into the same number of blocks.
|
||||
A query Q in block i cannot attend to a key which is not in block i,
|
||||
nor one which nearer the final key in block i than Q is to the
|
||||
final query in block i.
|
||||
"""
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
for i, ((q_start, q_end), (k_start, k_end)) in enumerate(
|
||||
zip(
|
||||
self.q_seqinfo.intervals(),
|
||||
self.k_seqinfo.intervals(),
|
||||
)
|
||||
):
|
||||
num_queries = q_end - q_start
|
||||
num_keys = k_end - k_start
|
||||
if num_keys < num_queries:
|
||||
raise ValueError(
|
||||
f"Block #{i} has num_keys={num_keys} and num_queries={num_queries}."
|
||||
" Expected `num_keys >= num_queries`"
|
||||
)
|
||||
|
||||
def _create_block_mask(
|
||||
self,
|
||||
shape: Tuple[int, ...],
|
||||
dtype: torch.dtype = torch.float32,
|
||||
device: Union[str, torch.device] = "cpu",
|
||||
) -> torch.Tensor:
|
||||
create_as = dtype if dtype is not torch.bfloat16 else torch.float32
|
||||
tensor = torch.full( # type: ignore
|
||||
shape,
|
||||
dtype=create_as,
|
||||
fill_value=float("-inf"),
|
||||
device=device,
|
||||
)
|
||||
num_queries, num_keys = shape[-2:]
|
||||
return torch.triu(tensor, diagonal=num_keys - num_queries + 1).to(dtype) # type: ignore
|
||||
|
||||
|
||||
@dataclass
|
||||
class BlockDiagonalCausalWithOffsetPaddedKeysMask(AttentionBias):
|
||||
"""
|
||||
Same as :attr:`xformers.ops.fmha.attn_bias.BlockDiagonalCausalMask`,
|
||||
except an offset on causality is allowed for each block and we support padding for k/v
|
||||
|
||||
The keys and values are divided into blocks which are padded out to
|
||||
the same total length.
|
||||
For example, if there is space for 12 keys, for three blocks of
|
||||
max length 4, but we only want to use the first 2, 3 and 2
|
||||
of each block, use `kv_padding=4` and `kv_seqlens=[2, 3, 2]`.
|
||||
The queries are divided into blocks, without padding, of lengths given by
|
||||
q_seqlen.
|
||||
|
||||
A query Q in block i cannot attend to a key which is not in block i,
|
||||
nor one which is not in use (i.e. in the padded area),
|
||||
nor one which is nearer to the final key in block i
|
||||
than Q is to the final query in block i.
|
||||
"""
|
||||
|
||||
q_seqinfo: _SeqLenInfo
|
||||
k_seqinfo: _PaddedSeqLenInfo
|
||||
causal_diagonal: Any = None # unused. Exists for BC only.
|
||||
|
||||
def _create_block_mask(
|
||||
self,
|
||||
shape: Tuple[int, ...],
|
||||
dtype: torch.dtype = torch.float32,
|
||||
device: Union[str, torch.device] = "cpu",
|
||||
) -> torch.Tensor:
|
||||
create_as = dtype if dtype is not torch.bfloat16 else torch.float32
|
||||
tensor = torch.full( # type: ignore
|
||||
shape,
|
||||
dtype=create_as,
|
||||
fill_value=float("-inf"),
|
||||
device=device,
|
||||
)
|
||||
num_queries, num_keys = shape[-2:]
|
||||
return torch.triu(tensor, diagonal=1 + num_keys - num_queries).to(dtype) # type: ignore
|
||||
|
||||
def materialize(
|
||||
self,
|
||||
shape: Tuple[int, ...],
|
||||
dtype: torch.dtype = torch.float32,
|
||||
device: Union[str, torch.device] = "cpu",
|
||||
) -> torch.Tensor:
|
||||
"""Materialize the attention bias - for debugging & testing"""
|
||||
if shape[-1] != self.k_seqinfo.seqstart_py[-1]:
|
||||
raise ValueError("k shapes wrong")
|
||||
if shape[-2] != self.q_seqinfo.seqstart_py[-1]:
|
||||
raise ValueError("q shapes wrong")
|
||||
mask = torch.empty(shape[-2:], dtype=dtype, device=device)
|
||||
mask.fill_(-math.inf)
|
||||
for i, ((q_start, q_end), (k_start, k_end)) in enumerate(
|
||||
zip(
|
||||
self.q_seqinfo.intervals(),
|
||||
self.k_seqinfo.intervals(),
|
||||
)
|
||||
):
|
||||
mask[q_start:q_end, k_start:k_end] = self._create_block_mask(
|
||||
(q_end - q_start, k_end - k_start),
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
)
|
||||
for _ in range(len(shape) - 2):
|
||||
mask = mask.unsqueeze(0)
|
||||
return mask.expand(shape)
|
||||
|
||||
@classmethod
|
||||
def from_seqlens(
|
||||
cls,
|
||||
q_seqlen: Sequence[int],
|
||||
kv_padding: int,
|
||||
kv_seqlen: Sequence[int],
|
||||
causal_diagonal: Any = None,
|
||||
) -> "BlockDiagonalCausalWithOffsetPaddedKeysMask":
|
||||
"""Creates a :attr:`BlockDiagonalCausalWithOffsetPaddedKeysMask` from a list of tensor
|
||||
lengths for query and key/value.
|
||||
|
||||
Args:
|
||||
q_seqlen (Sequence[int]): List or tensor of sequence lengths for query tensors
|
||||
kv_padding (int): Padding for k/v - also an upperbound on each individual key length
|
||||
kv_seqlen (Sequence[int]): List or tensor of sequence lengths for key/value.
|
||||
causal_diagonal: unused, for BC only
|
||||
Returns:
|
||||
BlockDiagonalCausalWithOffsetPaddedKeysMask
|
||||
"""
|
||||
assert kv_seqlen is None or len(q_seqlen) == len(kv_seqlen), (
|
||||
q_seqlen,
|
||||
kv_seqlen,
|
||||
)
|
||||
q_seqinfo = _SeqLenInfo.from_seqlens(q_seqlen)
|
||||
k_seqinfo = _PaddedSeqLenInfo.from_seqlens_padded(kv_seqlen, kv_padding)
|
||||
return cls(q_seqinfo=q_seqinfo, k_seqinfo=k_seqinfo)
|
||||
|
||||
|
||||
@dataclass
|
||||
class BlockDiagonalCausalLocalAttentionMask(BlockDiagonalCausalMask):
|
||||
"""
|
||||
(Experimental feature)
|
||||
Same as :attr:`xformers.ops.fmha.attn_bias.BlockDiagonalCausalMask`.
|
||||
This makes the mask "local" and the attention pattern banded.
|
||||
|
||||
Query i only attends to keys in its block and cannot attend keys further than "window_size"
|
||||
from it.
|
||||
"""
|
||||
|
||||
_window_size: int = 0 # forced due to inheritance and default arguments
|
||||
|
||||
def __post_init__(self):
|
||||
if self._window_size <= 0:
|
||||
raise ValueError(
|
||||
f"Expected `window_size > 0`, but window_size={self._window_size}"
|
||||
)
|
||||
q_seqlen = [
|
||||
y - x
|
||||
for x, y in zip(
|
||||
self.q_seqinfo.seqstart_py[:-1], self.q_seqinfo.seqstart_py[1:]
|
||||
)
|
||||
]
|
||||
kv_seqlen = [
|
||||
y - x
|
||||
for x, y in zip(
|
||||
self.k_seqinfo.seqstart_py[:-1], self.k_seqinfo.seqstart_py[1:]
|
||||
)
|
||||
]
|
||||
for q, k in zip(q_seqlen, kv_seqlen):
|
||||
if q - self._window_size >= k:
|
||||
raise RuntimeError(
|
||||
f"No keys are attended in q_seqlen {q} k_seqlen {k} with sliding window {self._window_size}"
|
||||
)
|
||||
|
||||
def _create_block_mask(
|
||||
self,
|
||||
shape: Tuple[int, ...],
|
||||
dtype: torch.dtype = torch.float32,
|
||||
device: Union[str, torch.device] = "cpu",
|
||||
) -> torch.Tensor:
|
||||
create_as = dtype if dtype is not torch.bfloat16 else torch.float32
|
||||
tensor = torch.full( # type: ignore
|
||||
shape,
|
||||
dtype=create_as,
|
||||
fill_value=1,
|
||||
device=device,
|
||||
)
|
||||
|
||||
num_queries, num_keys = shape[-2:]
|
||||
mask = torch.tril(tensor, diagonal=0).to(dtype) # type: ignore
|
||||
if self._window_size is not None and self._window_size > 0:
|
||||
mask = torch.triu(mask, diagonal=-self._window_size + 1)
|
||||
mask = torch.log(mask)
|
||||
return mask.to(dtype)
|
||||
|
||||
|
||||
@dataclass
|
||||
class BlockDiagonalCausalLocalAttentionFromBottomRightMask(
|
||||
BlockDiagonalCausalFromBottomRightMask
|
||||
):
|
||||
"""
|
||||
(Experimental feature)
|
||||
Same as :attr:`xformers.ops.fmha.attn_bias.BlockDiagonalCausalMask`.
|
||||
This makes the mask "local" and the attention pattern banded.
|
||||
|
||||
Query i only attends to keys in its block and cannot attend keys further than "window_size"
|
||||
from it.
|
||||
"""
|
||||
|
||||
_window_size: int = 0 # forced due to inheritance and default arguments
|
||||
|
||||
def __post_init__(self):
|
||||
super().__post_init__()
|
||||
if self._window_size <= 0:
|
||||
raise ValueError(
|
||||
f"Expected `window_size > 0`, but window_size={self._window_size}"
|
||||
)
|
||||
q_seqlen = [
|
||||
y - x
|
||||
for x, y in zip(
|
||||
self.q_seqinfo.seqstart_py[:-1], self.q_seqinfo.seqstart_py[1:]
|
||||
)
|
||||
]
|
||||
kv_seqlen = [
|
||||
y - x
|
||||
for x, y in zip(
|
||||
self.k_seqinfo.seqstart_py[:-1], self.k_seqinfo.seqstart_py[1:]
|
||||
)
|
||||
]
|
||||
for q, k in zip(q_seqlen, kv_seqlen):
|
||||
if q + (q - k) - self._window_size >= k:
|
||||
raise RuntimeError(
|
||||
f"No keys are attended in q_seqlen {q} k_seqlen {k} with sliding window {self._window_size}"
|
||||
)
|
||||
materialized = self.materialize((sum(q_seqlen), sum(kv_seqlen)))
|
||||
if torch.max(materialized, dim=1).values.min() == -float("inf"):
|
||||
raise RuntimeError("FUCKING FUCK FUCK")
|
||||
|
||||
def _create_block_mask(
|
||||
self,
|
||||
shape: Tuple[int, ...],
|
||||
dtype: torch.dtype = torch.float32,
|
||||
device: Union[str, torch.device] = "cpu",
|
||||
) -> torch.Tensor:
|
||||
create_as = dtype if dtype is not torch.bfloat16 else torch.float32
|
||||
tensor = torch.full( # type: ignore
|
||||
shape,
|
||||
dtype=create_as,
|
||||
fill_value=1,
|
||||
device=device,
|
||||
)
|
||||
num_queries, num_keys = shape[-2:]
|
||||
mask = torch.tril(tensor, diagonal=num_keys - num_queries).to(dtype) # type: ignore
|
||||
if self._window_size is not None:
|
||||
mask = torch.triu(
|
||||
mask, diagonal=num_keys - num_queries - self._window_size + 1
|
||||
)
|
||||
mask = torch.log(mask)
|
||||
return mask.to(dtype)
|
||||
542
pkgs/xformers/ops/fmha/common.py
Normal file
542
pkgs/xformers/ops/fmha/common.py
Normal file
@@ -0,0 +1,542 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the BSD license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import math
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, List, Mapping, Optional, Set, Tuple, Type, Union
|
||||
|
||||
import torch
|
||||
|
||||
from ..._cpp_lib import _built_with_cuda
|
||||
from ..common import BaseOperator
|
||||
from .attn_bias import (
|
||||
AttentionBias,
|
||||
BlockDiagonalMask,
|
||||
LowerTriangularMask,
|
||||
LowerTriangularMaskWithTensorBias,
|
||||
)
|
||||
|
||||
|
||||
def _is_bias_type_supported_in_BMK(attn_bias_type: Any) -> bool:
|
||||
# NoneType
|
||||
if isinstance(None, attn_bias_type):
|
||||
return True
|
||||
if attn_bias_type in [LowerTriangularMask, torch.Tensor]:
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
@dataclass
|
||||
class Inputs:
|
||||
"""
|
||||
Stores inputs to the `memory_efficient_attention` operators
|
||||
"""
|
||||
|
||||
query: torch.Tensor
|
||||
key: torch.Tensor
|
||||
value: torch.Tensor
|
||||
attn_bias: Optional[Union[torch.Tensor, AttentionBias]] = None
|
||||
p: float = 0.0
|
||||
scale: Optional[float] = None
|
||||
use_alibi: bool = False
|
||||
alibi_mode: int = 1
|
||||
imp_mode: int = 0
|
||||
|
||||
@property
|
||||
def device(self) -> torch.device:
|
||||
return self.query.device
|
||||
|
||||
@property
|
||||
def scale_float(self) -> float:
|
||||
return self.query.shape[-1] ** (-0.5) if self.scale is None else self.scale
|
||||
|
||||
def get_qkv_in_bmghk(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
if self.query.ndim == 5:
|
||||
return self.query, self.key, self.value
|
||||
if self.query.ndim == 4:
|
||||
return (
|
||||
self.query.unsqueeze(2),
|
||||
self.key.unsqueeze(2),
|
||||
self.value.unsqueeze(2),
|
||||
)
|
||||
if self.value.ndim == 3:
|
||||
return (
|
||||
self.query[:, :, None, None],
|
||||
self.key[:, :, None, None],
|
||||
self.value[:, :, None, None],
|
||||
)
|
||||
assert False
|
||||
|
||||
def normalize_bmhk(self) -> Tuple[int, ...]:
|
||||
if self.query.ndim not in [3, 4, 5]:
|
||||
raise ValueError(
|
||||
f"Invalid shape for query: {self.query.shape}. "
|
||||
"Expected shape [batch, seqlen, head_groups, num_heads_per_group, K]"
|
||||
", [batch, seqlen, num_heads, K], or [batch, seqlen, K]."
|
||||
)
|
||||
if self.value.dtype == torch.int32:
|
||||
# Quantized K/V case, in which the last dims of Q and K are different.
|
||||
# NB we currently don't have any implementations for quantized KV with
|
||||
# SUPPORTS_DIFFERENT_VALUE_EMBED.
|
||||
output_shape = tuple(self.query.shape)
|
||||
else:
|
||||
output_shape = (self.query.shape[:-1]) + (self.value.shape[-1],)
|
||||
# Convert from legacy format
|
||||
if self.query.ndim == 3:
|
||||
self.query = self.query.unsqueeze(2)
|
||||
self.key = self.key.unsqueeze(2)
|
||||
self.value = self.value.unsqueeze(2)
|
||||
if isinstance(self.attn_bias, torch.Tensor):
|
||||
if self.attn_bias.ndim != 3:
|
||||
raise ValueError(
|
||||
f"Expected BMK format for attn_bias, but got {self.attn_bias.shape}"
|
||||
)
|
||||
self.attn_bias = self.attn_bias.unsqueeze(1)
|
||||
return output_shape
|
||||
|
||||
def validate_inputs(self) -> None:
|
||||
qkv = (self.query, self.key, self.value)
|
||||
if self.query.ndim not in (3, 4, 5) or any(
|
||||
x.ndim != self.query.ndim for x in qkv
|
||||
):
|
||||
raise ValueError(
|
||||
f"Query/Key/Value should all have BMGHK, BMHK, or BMK shape.\n"
|
||||
f" query.shape: {self.query.shape}\n"
|
||||
f" key.shape : {self.key.shape}\n"
|
||||
f" value.shape: {self.value.shape}"
|
||||
)
|
||||
if any(x.device != self.query.device for x in qkv):
|
||||
raise ValueError("Query/Key/Value should all be on the same device")
|
||||
quantized_dtypes = self.key.dtype == self.value.dtype == torch.int32
|
||||
non_quantized_dtypes = all(x.dtype == self.query.dtype for x in qkv)
|
||||
if not (quantized_dtypes or non_quantized_dtypes):
|
||||
raise ValueError(
|
||||
"Query/Key/Value should either all have the same dtype, or "
|
||||
"(in the quantized case) Key/Value should have dtype torch.int32\n"
|
||||
f" query.dtype: {self.query.dtype}\n"
|
||||
f" key.dtype : {self.key.dtype}\n"
|
||||
f" value.dtype: {self.value.dtype}"
|
||||
)
|
||||
# Biases with tensors attached are meant to be in BMHK format
|
||||
# This would require to permute biases/gradients which can be expensive,
|
||||
# so let's just forbid it - BMK is a legacy format anyway
|
||||
if self.query.ndim == 3 and not _is_bias_type_supported_in_BMK(
|
||||
type(self.attn_bias)
|
||||
):
|
||||
raise ValueError(
|
||||
f"Please provide inputs in BMHK format rather "
|
||||
f"than BMK when using bias type `{type(self.attn_bias).__name__}`"
|
||||
)
|
||||
attn_bias_t: Optional[torch.Tensor] = None
|
||||
if isinstance(self.attn_bias, torch.Tensor):
|
||||
attn_bias_t = self.attn_bias
|
||||
if isinstance(self.attn_bias, LowerTriangularMaskWithTensorBias):
|
||||
attn_bias_t = self.attn_bias._bias
|
||||
if self.query.ndim == 4 and attn_bias_t is not None:
|
||||
expected_shape = (
|
||||
self.query.shape[0],
|
||||
self.query.shape[2],
|
||||
self.query.shape[1],
|
||||
self.key.shape[1],
|
||||
)
|
||||
if attn_bias_t.shape != expected_shape:
|
||||
raise ValueError(
|
||||
f"Invalid shape for attention bias: {attn_bias_t.shape} (expected {expected_shape})\n"
|
||||
f" query.shape: {self.query.shape}\n"
|
||||
f" key.shape : {self.key.shape}\n"
|
||||
f" value.shape: {self.value.shape}"
|
||||
)
|
||||
if isinstance(self.attn_bias, BlockDiagonalMask):
|
||||
if any(x.shape[0] != 1 for x in qkv):
|
||||
raise ValueError(
|
||||
f"Expected batch_size=1 when using block-diagonal bias\n"
|
||||
f" query.shape: {self.query.shape}\n"
|
||||
f" key.shape : {self.key.shape}\n"
|
||||
f" value.shape: {self.value.shape}"
|
||||
)
|
||||
if self.p < 0.0 or self.p > 1.0:
|
||||
raise ValueError(f"Invalid dropout probability: p={self.p}")
|
||||
# Check that shapes match between inputs
|
||||
B, Mq = self.query.shape[:2]
|
||||
K = self.query.shape[-1]
|
||||
B, Mkv = self.key.shape[:2]
|
||||
Kv = self.value.shape[-1]
|
||||
|
||||
valid_shapes = True
|
||||
if self.query.ndim == 3: # BMK
|
||||
valid_shapes = (
|
||||
self.query.shape == (B, Mq, K)
|
||||
and self.key.shape == (B, Mkv, K)
|
||||
and self.value.shape == (B, Mkv, Kv)
|
||||
)
|
||||
H = self.query.shape[-2]
|
||||
if self.query.ndim == 4: # BMHK
|
||||
quantized_kv_cache = self.value.dtype == torch.int32
|
||||
key_embed_dim = Kv if quantized_kv_cache else K
|
||||
valid_shapes = (
|
||||
self.query.shape == (B, Mq, H, K)
|
||||
and self.key.shape == (B, Mkv, H, key_embed_dim)
|
||||
and self.value.shape == (B, Mkv, H, Kv)
|
||||
)
|
||||
G = self.query.shape[2]
|
||||
if self.query.ndim == 5: # BMNHK
|
||||
valid_shapes = (
|
||||
self.query.shape == (B, Mq, G, H, K)
|
||||
and self.key.shape == (B, Mkv, G, H, K)
|
||||
and self.value.shape == (B, Mkv, G, H, Kv)
|
||||
)
|
||||
if not valid_shapes:
|
||||
raise ValueError(
|
||||
f"Incompatible shapes for attention inputs:\n"
|
||||
f" query.shape: {self.query.shape}\n"
|
||||
f" key.shape : {self.key.shape}\n"
|
||||
f" value.shape: {self.value.shape}\n"
|
||||
"HINT: We don't support broadcasting, please use `expand` "
|
||||
"yourself before calling `memory_efficient_attention` if you need to"
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class Context:
|
||||
lse: torch.Tensor
|
||||
out: torch.Tensor
|
||||
q_padded: Optional[torch.Tensor] = None
|
||||
k_padded: Optional[torch.Tensor] = None
|
||||
v_padded: Optional[torch.Tensor] = None
|
||||
o_padded: Optional[torch.Tensor] = None
|
||||
op_bw: Optional[Type["AttentionBwOpBase"]] = None
|
||||
rng_state: Optional[torch.Tensor] = None
|
||||
|
||||
def get_padded_lse(self, pad_to: int, force_pad_inf: bool = False) -> torch.Tensor:
|
||||
pad_amount = (pad_to - (self.lse.shape[2] % pad_to)) % pad_to
|
||||
lse = self.lse
|
||||
if pad_amount > 0:
|
||||
if force_pad_inf:
|
||||
lse = lse[:, :, : self.out.shape[1]]
|
||||
pad_amount = (pad_to - (lse.shape[2] % pad_to)) % pad_to
|
||||
lse = torch.nn.functional.pad(lse, [0, pad_amount], value=math.inf)
|
||||
elif force_pad_inf and self.out.shape[1] != lse.shape[2]:
|
||||
lse[:, :, self.out.shape[1] :].fill_(math.inf)
|
||||
return lse
|
||||
|
||||
|
||||
@dataclass
|
||||
class Gradients:
|
||||
dq: torch.Tensor
|
||||
dk: torch.Tensor
|
||||
dv: torch.Tensor
|
||||
# bias gradient. None if there is no tensor bias or if it doesn't require grad
|
||||
db: Optional[torch.Tensor] = None
|
||||
|
||||
|
||||
class AttentionOpBase(BaseOperator):
|
||||
"""Base class for any attention operator in xFormers
|
||||
|
||||
See:
|
||||
|
||||
- :attr:`xformers.ops.fmha.cutlass.FwOp`
|
||||
- :attr:`xformers.ops.fmha.cutlass.BwOp`
|
||||
- :attr:`xformers.ops.fmha.flash.FwOp`
|
||||
- :attr:`xformers.ops.fmha.flash.BwOp`
|
||||
- :attr:`xformers.ops.fmha.triton.FwOp`
|
||||
- :attr:`xformers.ops.fmha.triton.BwOp`
|
||||
- :attr:`xformers.ops.fmha.small_k.FwOp`
|
||||
- :attr:`xformers.ops.fmha.small_k.BwOp`
|
||||
"""
|
||||
|
||||
OPERATOR: Any
|
||||
SUPPORTED_DEVICES: Set[str]
|
||||
CUDA_MINIMUM_COMPUTE_CAPABILITY: Tuple[int, int] = (5, 0)
|
||||
SUPPORTED_DTYPES: Set[torch.dtype]
|
||||
SUPPORTED_MAX_K: float
|
||||
SUPPORTED_ATTN_BIAS_TYPES: Set[Any] = {type(None)}
|
||||
SUPPORTS_DROPOUT: bool
|
||||
SUPPORTS_CUSTOM_SCALE: bool = False
|
||||
SUPPORTS_DIFFERENT_VALUE_EMBED: bool = False
|
||||
IS_DETERMINISTIC: bool = True
|
||||
SUPPORTS_BMGHK: bool = False
|
||||
NAME: str
|
||||
OPERATOR_CATEGORY = "memory_efficient_attention"
|
||||
|
||||
_TEST_BATCH_SIZES: List[int] = [1, 300]
|
||||
_TEST_K: List[int] = [32, 128]
|
||||
|
||||
@classmethod
|
||||
def supports(cls, d: Inputs) -> bool:
|
||||
return not cls.not_supported_reasons(d)
|
||||
|
||||
@classmethod
|
||||
def shape_not_supported_reasons(
|
||||
cls, Mq: int, Mkv: int, K: int, Kv: int
|
||||
) -> List[str]:
|
||||
reasons = []
|
||||
if not cls.SUPPORTS_DIFFERENT_VALUE_EMBED and K != Kv:
|
||||
reasons.append("query.shape[-1] != value.shape[-1]")
|
||||
if max(K, Kv) > cls.SUPPORTED_MAX_K:
|
||||
reasons.append(
|
||||
f"max(query.shape[-1] != value.shape[-1]) > {cls.SUPPORTED_MAX_K}"
|
||||
)
|
||||
return reasons
|
||||
|
||||
@classmethod
|
||||
def not_supported_reasons(cls, d: Inputs) -> List[str]:
|
||||
"""
|
||||
Returns a list of reasons why this is not supported.
|
||||
The kernel can run these inputs only if the returned list is empty
|
||||
"""
|
||||
reasons = cls.shape_not_supported_reasons(
|
||||
Mq=d.query.shape[1],
|
||||
Mkv=d.key.shape[1],
|
||||
K=d.query.shape[-1],
|
||||
Kv=d.query.shape[-1],
|
||||
)
|
||||
device_type = d.query.device.type
|
||||
dtype = d.query.dtype
|
||||
if device_type not in cls.SUPPORTED_DEVICES:
|
||||
reasons.append(f"device={device_type} (supported: {cls.SUPPORTED_DEVICES})")
|
||||
if device_type == "cuda" and not _built_with_cuda:
|
||||
reasons.append("xFormers wasn't build with CUDA support")
|
||||
if device_type == "cuda":
|
||||
device_capability = torch.cuda.get_device_capability(d.device)
|
||||
if device_capability < cls.CUDA_MINIMUM_COMPUTE_CAPABILITY:
|
||||
reasons.append(
|
||||
f"requires device with capability > {cls.CUDA_MINIMUM_COMPUTE_CAPABILITY} "
|
||||
f"but your GPU has capability {device_capability} (too old)"
|
||||
)
|
||||
if dtype not in cls.SUPPORTED_DTYPES:
|
||||
reasons.append(f"dtype={dtype} (supported: {cls.SUPPORTED_DTYPES})")
|
||||
if type(d.attn_bias) not in cls.SUPPORTED_ATTN_BIAS_TYPES:
|
||||
reasons.append(f"attn_bias type is {type(d.attn_bias)}")
|
||||
if (d.p != 0.0) and not cls.SUPPORTS_DROPOUT:
|
||||
reasons.append("dropout > 0.0")
|
||||
if d.scale is not None and not cls.SUPPORTS_CUSTOM_SCALE:
|
||||
reasons.append("has custom scale")
|
||||
# bfloat16 is only supported on A100+
|
||||
# ... although the kernels can still run and give the
|
||||
# correct result
|
||||
if dtype is torch.bfloat16 and (
|
||||
not device_type.startswith("cuda")
|
||||
):
|
||||
reasons.append("bf16 is only supported on A100+ GPUs")
|
||||
if not cls.is_available():
|
||||
reasons.append(
|
||||
"operator wasn't built - see `python -m xformers.info` for more info"
|
||||
)
|
||||
if not cls.IS_DETERMINISTIC and torch.are_deterministic_algorithms_enabled():
|
||||
reasons.append(
|
||||
"operator is non-deterministic, but `torch.use_deterministic_algorithms` is set"
|
||||
)
|
||||
if not cls.SUPPORTS_BMGHK and d.query.ndim == 5:
|
||||
reasons.append("operator does not support BMGHK format")
|
||||
return reasons
|
||||
|
||||
|
||||
class AttentionFwOpBase(AttentionOpBase):
|
||||
ERROR_ATOL: Mapping[torch.dtype, float] = {
|
||||
torch.float: 3e-4,
|
||||
torch.half: 4e-3,
|
||||
torch.bfloat16: 2e-2,
|
||||
}
|
||||
ERROR_RTOL: Mapping[torch.dtype, float] = {
|
||||
torch.float: 2e-5,
|
||||
torch.half: 4e-4,
|
||||
torch.bfloat16: 5e-3,
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def apply(
|
||||
cls, inp: Inputs, needs_gradient: bool
|
||||
) -> Tuple[torch.Tensor, Optional[Context]]:
|
||||
raise NotImplementedError()
|
||||
|
||||
@classmethod
|
||||
def attn_operator_flop(
|
||||
cls,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
causal: bool = False,
|
||||
seqstart_k: Optional[torch.Tensor] = None,
|
||||
seqstart_q: Optional[torch.Tensor] = None,
|
||||
) -> int:
|
||||
"""
|
||||
Computes total flops for the attention
|
||||
Assumes inputs in format BMHK
|
||||
"""
|
||||
assert query.ndim == 4
|
||||
|
||||
if seqstart_q is not None:
|
||||
seqstart_q_py = seqstart_q.tolist()
|
||||
else:
|
||||
seqstart_q_py = [0, query.shape[1]]
|
||||
if seqstart_k is not None:
|
||||
seqstart_k_py = seqstart_k.tolist()
|
||||
else:
|
||||
seqstart_k_py = [0, key.shape[1]]
|
||||
|
||||
total_flop = 0
|
||||
for q_start, q_end, k_start, k_end in zip(
|
||||
seqstart_q_py, seqstart_q_py[1:], seqstart_k_py, seqstart_k_py[1:]
|
||||
):
|
||||
num_q = q_end - q_start
|
||||
num_kv = k_end - k_start
|
||||
# (M,K) @ (K,N) GEMM needs M*N*K*2 flop
|
||||
# Q @ K.transpose
|
||||
total_flop += num_q * num_kv * query.shape[-1] * 2
|
||||
# (ignore softmax)
|
||||
# attn @ V
|
||||
total_flop += num_q * key.shape[-1] * num_kv * 2
|
||||
# Multiply by num_heads and batches
|
||||
total_flop = total_flop * value.shape[2] * value.shape[0]
|
||||
if causal:
|
||||
total_flop //= 2
|
||||
return total_flop
|
||||
|
||||
|
||||
class AttentionBwOpBase(AttentionOpBase):
|
||||
ERROR_ATOL: Mapping[torch.dtype, float] = {
|
||||
torch.float: 5e-4,
|
||||
torch.half: 9e-2,
|
||||
torch.bfloat16: 0.7,
|
||||
}
|
||||
ERROR_RTOL: Mapping[torch.dtype, float] = {
|
||||
torch.float: 1e-4,
|
||||
torch.half: 2e-2,
|
||||
torch.bfloat16: 0.1,
|
||||
}
|
||||
SUPPORTS_ATTN_BIAS_GRAD = False
|
||||
|
||||
@classmethod
|
||||
def not_supported_reasons(cls, d: Inputs) -> List[str]:
|
||||
reasons = super(AttentionBwOpBase, cls).not_supported_reasons(d)
|
||||
if (
|
||||
isinstance(d.attn_bias, torch.Tensor)
|
||||
and d.attn_bias.requires_grad
|
||||
and not cls.SUPPORTS_ATTN_BIAS_GRAD
|
||||
):
|
||||
reasons.append(
|
||||
"Computing the bias gradient is not supported (attn_bias.requires_grad = True)"
|
||||
)
|
||||
|
||||
return reasons
|
||||
|
||||
@classmethod
|
||||
def apply(cls, ctx: Context, inp: Inputs, grad: torch.Tensor) -> Gradients:
|
||||
raise NotImplementedError()
|
||||
|
||||
@classmethod
|
||||
def attn_operator_flop(
|
||||
cls,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
causal: bool = False,
|
||||
seqstart_k: Optional[torch.Tensor] = None,
|
||||
seqstart_q: Optional[torch.Tensor] = None,
|
||||
) -> int:
|
||||
"""
|
||||
Computes total flops for the attention
|
||||
Assumes inputs in format BMHK
|
||||
"""
|
||||
assert query.ndim == 4
|
||||
|
||||
if seqstart_q is not None:
|
||||
seqstart_q_py = seqstart_q.tolist()
|
||||
else:
|
||||
seqstart_q_py = [0, query.shape[1]]
|
||||
if seqstart_k is not None:
|
||||
seqstart_k_py = seqstart_k.tolist()
|
||||
else:
|
||||
seqstart_k_py = [0, key.shape[1]]
|
||||
|
||||
total_flop = 0
|
||||
for q_start, q_end, k_start, k_end in zip(
|
||||
seqstart_q_py, seqstart_q_py[1:], seqstart_k_py, seqstart_k_py[1:]
|
||||
):
|
||||
num_q = q_end - q_start
|
||||
num_kv = k_end - k_start
|
||||
Kqk = query.shape[-1]
|
||||
Kv = value.shape[-1]
|
||||
# (M,K) @ (K,N) GEMM needs M*N*K*2 flop
|
||||
# att = Q @ K.transpose
|
||||
total_flop += num_q * num_kv * Kqk * 2
|
||||
# att @ dO
|
||||
total_flop += num_kv * num_q * Kv * 2
|
||||
# dov = dO @ V
|
||||
total_flop += num_q * Kv * num_kv * 2
|
||||
# dov @ K
|
||||
total_flop += num_q * Kqk * num_kv * 2
|
||||
# dov @ Q
|
||||
total_flop += num_q * Kqk * num_kv * 2
|
||||
# Multiply by num_heads and batches
|
||||
total_flop = total_flop * value.shape[2] * value.shape[0]
|
||||
if causal:
|
||||
total_flop //= 2
|
||||
return total_flop
|
||||
|
||||
|
||||
AttentionOp = Tuple[
|
||||
Optional[Type[AttentionFwOpBase]], Optional[Type[AttentionBwOpBase]]
|
||||
]
|
||||
|
||||
|
||||
@dataclass
|
||||
class AttentionOpDispatch:
|
||||
"""Dispatcher to automatically select
|
||||
the best operator to run memory-efficient attention.
|
||||
|
||||
:Deprecated:
|
||||
|
||||
This class is deprecated and will be removed in a later version
|
||||
"""
|
||||
|
||||
op: AttentionOp
|
||||
|
||||
@classmethod
|
||||
def from_arguments(
|
||||
cls,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
attn_bias: Optional[Union[torch.Tensor, AttentionBias]] = None,
|
||||
p: float = 0.0,
|
||||
scale: Optional[float] = None,
|
||||
) -> "AttentionOpDispatch":
|
||||
"""Here for backward compatibility"""
|
||||
from .dispatch import _dispatch_bw, _dispatch_fw
|
||||
|
||||
inp = Inputs(
|
||||
query=query,
|
||||
key=key,
|
||||
value=value,
|
||||
attn_bias=attn_bias,
|
||||
p=p,
|
||||
scale=scale,
|
||||
)
|
||||
return AttentionOpDispatch(op=(_dispatch_fw(inp, True), _dispatch_bw(inp)))
|
||||
|
||||
|
||||
def bmk2bmhk(tensor, num_heads: int) -> torch.Tensor:
|
||||
if tensor.ndim == 4:
|
||||
return tensor
|
||||
return tensor.reshape([-1, num_heads, tensor.shape[1], tensor.shape[2]]).permute(
|
||||
(0, 2, 1, 3)
|
||||
)
|
||||
|
||||
|
||||
def check_lastdim_alignment_stride1(
|
||||
reasons: List[str], name: str, x: torch.Tensor, alignment: int
|
||||
) -> None:
|
||||
if x.shape[-1] % alignment != 0:
|
||||
reasons.append(f"{name}.shape[-1] % {alignment} != 0")
|
||||
elif x.stride(-2) % alignment != 0:
|
||||
reasons.append(
|
||||
f"{name}.stride(-2) % {alignment} != 0 ({name}.stride() = {x.stride()})"
|
||||
)
|
||||
# We can have stride=0 sometimes if dimension=1
|
||||
if x.stride(-1) > 1:
|
||||
reasons.append(
|
||||
f"{name}.stride(-1) > 1 ({name}.stride() = {x.stride()}) - you should call `.contiguous()` on the input"
|
||||
)
|
||||
479
pkgs/xformers/ops/fmha/cutlass.py
Normal file
479
pkgs/xformers/ops/fmha/cutlass.py
Normal file
@@ -0,0 +1,479 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the BSD license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
|
||||
from dataclasses import replace
|
||||
from enum import Enum
|
||||
from typing import Any, List, Mapping, Optional, Set, Tuple, Union
|
||||
|
||||
import torch
|
||||
|
||||
from ..common import get_xformers_operator, register_operator
|
||||
from . import attn_bias
|
||||
from .attn_bias import (
|
||||
AttentionBias,
|
||||
BlockDiagonalCausalLocalAttentionFromBottomRightMask,
|
||||
BlockDiagonalCausalLocalAttentionMask,
|
||||
BlockDiagonalCausalMask,
|
||||
BlockDiagonalCausalWithOffsetPaddedKeysMask,
|
||||
BlockDiagonalMask,
|
||||
LowerTriangularMask,
|
||||
LowerTriangularMaskWithTensorBias,
|
||||
)
|
||||
from .common import (
|
||||
AttentionBwOpBase,
|
||||
AttentionFwOpBase,
|
||||
Context,
|
||||
Gradients,
|
||||
Inputs,
|
||||
check_lastdim_alignment_stride1,
|
||||
)
|
||||
|
||||
|
||||
def _uses_tensorcores(sm: int, is_half: bool) -> bool:
|
||||
if sm >= 80:
|
||||
return True
|
||||
if sm >= 70:
|
||||
return is_half
|
||||
return False
|
||||
|
||||
|
||||
def _minimum_gemm_alignment(inp: Inputs) -> int:
|
||||
if inp.device.type != "cuda":
|
||||
return 1
|
||||
cap = torch.cuda.get_device_capability(inp.device)
|
||||
sm = cap[0] * 10 + cap[1]
|
||||
bits_per_scalar = {torch.float: 32, torch.half: 16, torch.bfloat16: 16}[
|
||||
inp.query.dtype
|
||||
]
|
||||
uses_tensorcores = _uses_tensorcores(sm, bits_per_scalar == 16)
|
||||
matmul_alignment_mn = 1
|
||||
if sm >= 80:
|
||||
matmul_alignment_mn = 4
|
||||
if uses_tensorcores:
|
||||
matmul_alignment_mn = max(matmul_alignment_mn, 128 // bits_per_scalar)
|
||||
return matmul_alignment_mn
|
||||
|
||||
|
||||
def _get_seqlen_info(
|
||||
inp: Inputs,
|
||||
) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor], int, int]:
|
||||
attn_bias = inp.attn_bias
|
||||
if isinstance(
|
||||
attn_bias, (BlockDiagonalMask, BlockDiagonalCausalWithOffsetPaddedKeysMask)
|
||||
):
|
||||
attn_bias.k_seqinfo.to(inp.query.device)
|
||||
attn_bias.q_seqinfo.to(inp.query.device)
|
||||
seqstart_k = attn_bias.k_seqinfo.seqstart
|
||||
seqstart_q = attn_bias.q_seqinfo.seqstart
|
||||
max_seqlen_q = attn_bias.q_seqinfo.max_seqlen
|
||||
max_seqlen_k = attn_bias.k_seqinfo.max_seqlen
|
||||
else:
|
||||
seqstart_k = None
|
||||
seqstart_q = None
|
||||
max_seqlen_q = -1
|
||||
max_seqlen_k = -1
|
||||
|
||||
return seqstart_k, seqstart_q, max_seqlen_q, max_seqlen_k
|
||||
|
||||
|
||||
def _get_tensor_bias(
|
||||
attn_bias: Optional[Union[torch.Tensor, AttentionBias]]
|
||||
) -> Optional[torch.Tensor]:
|
||||
if isinstance(attn_bias, torch.Tensor):
|
||||
return attn_bias
|
||||
elif isinstance(attn_bias, LowerTriangularMaskWithTensorBias):
|
||||
return attn_bias._bias
|
||||
return None
|
||||
|
||||
|
||||
def _check_bias_alignment(
|
||||
reasons: List[str], attn_bias: Optional[Union[torch.Tensor, AttentionBias]]
|
||||
) -> None:
|
||||
attn_bias_tensor = _get_tensor_bias(attn_bias)
|
||||
if attn_bias_tensor is not None:
|
||||
alignment = 128 // torch.finfo(attn_bias_tensor.dtype).bits
|
||||
show_padding_hint = False
|
||||
for d in range(attn_bias_tensor.ndim - 1):
|
||||
if attn_bias_tensor.stride(d) % alignment != 0:
|
||||
reasons.append(
|
||||
f"attn_bias.stride(-2) % {alignment} != 0 (attn_bias.stride() = {attn_bias_tensor.stride()})"
|
||||
)
|
||||
show_padding_hint = True
|
||||
if show_padding_hint:
|
||||
reasons.append(
|
||||
"""\
|
||||
HINT: To use an `attn_bias` with a sequence length that is not a multiple of 8, \
|
||||
you need to ensure memory is aligned by slicing a bigger tensor. \
|
||||
Example: use `attn_bias = torch.zeros([1, 1, 5, 8])[:,:,:,:5]` instead of `torch.zeros([1, 1, 5, 5])`"""
|
||||
)
|
||||
# We can have stride=0 sometimes if dimension=1
|
||||
if attn_bias_tensor.stride(-1) > 1:
|
||||
reasons.append(
|
||||
f"attn_bias.stride(-1) > 1 (attn_bias.stride() = {attn_bias_tensor.stride()}) - "
|
||||
"you should call `.contiguous()` on the bias"
|
||||
)
|
||||
|
||||
|
||||
class _CustomMaskType(int, Enum):
|
||||
"""
|
||||
(Matches CustomMaskType in C++.)
|
||||
"""
|
||||
|
||||
NoCustomMask = 0
|
||||
CausalFromTopLeft = 1
|
||||
CausalFromBottomRight = 2
|
||||
|
||||
|
||||
def _custom_mask_type(bias: Optional[Union[torch.Tensor, AttentionBias]]) -> int:
|
||||
if isinstance(
|
||||
bias,
|
||||
(
|
||||
LowerTriangularMask,
|
||||
BlockDiagonalCausalMask,
|
||||
BlockDiagonalCausalLocalAttentionMask,
|
||||
),
|
||||
):
|
||||
return int(_CustomMaskType.CausalFromTopLeft)
|
||||
if isinstance(
|
||||
bias,
|
||||
(
|
||||
attn_bias.BlockDiagonalCausalFromBottomRightMask,
|
||||
BlockDiagonalCausalWithOffsetPaddedKeysMask,
|
||||
BlockDiagonalCausalLocalAttentionFromBottomRightMask,
|
||||
),
|
||||
):
|
||||
return int(_CustomMaskType.CausalFromBottomRight)
|
||||
return int(_CustomMaskType.NoCustomMask)
|
||||
|
||||
|
||||
@register_operator
|
||||
class FwOp(AttentionFwOpBase):
|
||||
"""xFormers' MHA kernel based on CUTLASS.
|
||||
Supports a large number of settings (including without TensorCores, f32 ...)
|
||||
and GPUs as old as P100 (Sm60)
|
||||
"""
|
||||
|
||||
OPERATOR = get_xformers_operator("efficient_attention_forward_cutlass")
|
||||
SUPPORTED_DEVICES: Set[str] = {"cuda"}
|
||||
SUPPORTED_DTYPES: Set[torch.dtype] = {torch.float, torch.half, torch.bfloat16}
|
||||
SUPPORTED_MAX_K = 65536
|
||||
SUPPORTED_ATTN_BIAS_TYPES: Set[Any] = {
|
||||
type(None),
|
||||
torch.Tensor,
|
||||
LowerTriangularMask,
|
||||
LowerTriangularMaskWithTensorBias,
|
||||
BlockDiagonalMask,
|
||||
BlockDiagonalCausalMask,
|
||||
BlockDiagonalCausalWithOffsetPaddedKeysMask,
|
||||
attn_bias.BlockDiagonalCausalFromBottomRightMask,
|
||||
attn_bias.BlockDiagonalCausalLocalAttentionMask,
|
||||
BlockDiagonalCausalLocalAttentionFromBottomRightMask,
|
||||
}
|
||||
SUPPORTS_DROPOUT = True
|
||||
SUPPORTS_CUSTOM_SCALE = True
|
||||
SUPPORTS_DIFFERENT_VALUE_EMBED = True
|
||||
SUPPORTS_BMGHK = True
|
||||
NAME = "cutlassF"
|
||||
|
||||
_TEST_K: List[int] = [
|
||||
32, # 64x64 kernel
|
||||
128, # 64x128 kernel
|
||||
256, # 64x128 with accumulation in gmem
|
||||
]
|
||||
|
||||
@classmethod
|
||||
def apply(
|
||||
cls, inp: Inputs, needs_gradient: bool
|
||||
) -> Tuple[torch.Tensor, Optional[Context]]:
|
||||
if type(inp.attn_bias) not in FwOp.SUPPORTED_ATTN_BIAS_TYPES:
|
||||
raise NotImplementedError("Unsupported attn_bias type")
|
||||
if inp.query.ndim in [3, 4]:
|
||||
return cls.apply_bmhk(inp, needs_gradient=needs_gradient)
|
||||
assert inp.query.ndim == 5, f"query has shape {inp.query.shape}"
|
||||
|
||||
# Workaround until this is properly implemented in C++
|
||||
# run each head group in a different stream
|
||||
n_groups = inp.key.shape[2]
|
||||
main_stream = torch.cuda.current_stream()
|
||||
streams = [main_stream] + [
|
||||
torch.cuda.Stream(device=inp.query.device) for _ in range(n_groups - 1)
|
||||
]
|
||||
outs = []
|
||||
for group, stream in enumerate(streams):
|
||||
stream.wait_stream(main_stream)
|
||||
with torch.cuda.stream(stream):
|
||||
query = inp.query[:, :, group]
|
||||
key = inp.key[:, :, group]
|
||||
value = inp.value[:, :, group]
|
||||
bias = inp.attn_bias
|
||||
if isinstance(bias, torch.Tensor):
|
||||
bias = bias[:, group]
|
||||
if isinstance(bias, attn_bias.LowerTriangularMaskWithTensorBias):
|
||||
bias = attn_bias.LowerTriangularMaskWithTensorBias(
|
||||
bias._bias[:, group]
|
||||
)
|
||||
outs.append(
|
||||
cls.apply_bmhk(
|
||||
replace(inp, query=query, key=key, value=value, attn_bias=bias),
|
||||
needs_gradient=needs_gradient,
|
||||
)
|
||||
)
|
||||
for s in streams[1:]:
|
||||
main_stream.wait_stream(s)
|
||||
out = torch.stack([o[0] for o in outs], dim=2)
|
||||
ctx: Optional[Context] = None
|
||||
if needs_gradient:
|
||||
ctx = Context(
|
||||
out=out,
|
||||
lse=torch.stack([o[1].lse for o in outs], dim=1), # type: ignore
|
||||
op_bw=outs[0][1].op_bw, # type: ignore
|
||||
)
|
||||
return out, ctx
|
||||
|
||||
@classmethod
|
||||
def apply_bmhk(
|
||||
cls, inp: Inputs, needs_gradient: bool
|
||||
) -> Tuple[torch.Tensor, Optional[Context]]:
|
||||
if type(inp.attn_bias) not in FwOp.SUPPORTED_ATTN_BIAS_TYPES:
|
||||
raise NotImplementedError("Unsupported attn_bias type")
|
||||
seqstart_k, seqstart_q, max_seqlen_q, _ = _get_seqlen_info(inp)
|
||||
out, lse, rng_seed, rng_offset = cls.OPERATOR(
|
||||
query=inp.query,
|
||||
key=inp.key,
|
||||
value=inp.value,
|
||||
attn_bias=_get_tensor_bias(inp.attn_bias),
|
||||
seqstart_q=seqstart_q,
|
||||
seqstart_k=seqstart_k,
|
||||
max_seqlen_q=max_seqlen_q,
|
||||
dropout_p=inp.p,
|
||||
compute_logsumexp=needs_gradient,
|
||||
custom_mask_type=_custom_mask_type(inp.attn_bias),
|
||||
scale=inp.scale,
|
||||
seqlen_k=inp.attn_bias.k_seqinfo.seqlen
|
||||
if isinstance(inp.attn_bias, BlockDiagonalCausalWithOffsetPaddedKeysMask)
|
||||
else None,
|
||||
window_size=inp.attn_bias._window_size
|
||||
if isinstance(
|
||||
inp.attn_bias,
|
||||
(
|
||||
BlockDiagonalCausalLocalAttentionMask,
|
||||
BlockDiagonalCausalLocalAttentionFromBottomRightMask,
|
||||
),
|
||||
)
|
||||
else None,
|
||||
)
|
||||
ctx: Optional[Context] = None
|
||||
if needs_gradient:
|
||||
ctx = Context(
|
||||
out=out,
|
||||
lse=lse,
|
||||
# cutlass forward is only compatible with cutlass backward if
|
||||
# dropout is used (because of the way RNG states are passed and the
|
||||
# way random numbers are generated during backward)
|
||||
op_bw=BwOp if inp.p != 0 else None,
|
||||
)
|
||||
if inp.p != 0:
|
||||
ctx.rng_state = torch.tensor(
|
||||
[rng_seed, rng_offset], dtype=torch.int64, device="cpu"
|
||||
)
|
||||
return out, ctx
|
||||
|
||||
@classmethod
|
||||
def not_supported_reasons(cls, d: Inputs) -> List[str]:
|
||||
reasons = super(FwOp, cls).not_supported_reasons(d)
|
||||
matmul_alignment_mn = _minimum_gemm_alignment(d)
|
||||
check_lastdim_alignment_stride1(reasons, "query", d.query, matmul_alignment_mn)
|
||||
check_lastdim_alignment_stride1(reasons, "value", d.value, matmul_alignment_mn)
|
||||
_check_bias_alignment(reasons, d.attn_bias)
|
||||
return reasons
|
||||
|
||||
@classmethod
|
||||
# type: ignore
|
||||
def operator_flop(
|
||||
cls,
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
b,
|
||||
seqstart_q,
|
||||
seqstart_k,
|
||||
max_seqlen_q_,
|
||||
compute_lse,
|
||||
custom_mask_type,
|
||||
*a,
|
||||
) -> int:
|
||||
return cls.attn_operator_flop(
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
causal=custom_mask_type > 0,
|
||||
seqstart_k=seqstart_k,
|
||||
seqstart_q=seqstart_q,
|
||||
)
|
||||
|
||||
|
||||
@register_operator
|
||||
class BwOp(AttentionBwOpBase):
|
||||
__doc__ = FwOp.__doc__
|
||||
|
||||
OPERATOR = get_xformers_operator("efficient_attention_backward_cutlass")
|
||||
SUPPORTED_DEVICES = FwOp.SUPPORTED_DEVICES
|
||||
SUPPORTED_DTYPES = FwOp.SUPPORTED_DTYPES
|
||||
SUPPORTED_MAX_K = FwOp.SUPPORTED_MAX_K
|
||||
SUPPORTED_ATTN_BIAS_TYPES: Set[Any] = {
|
||||
type(None),
|
||||
torch.Tensor,
|
||||
LowerTriangularMask,
|
||||
# TODO: Fix handling of gradient through the fMHA autograd function
|
||||
# LowerTriangularMaskWithTensorBias,
|
||||
BlockDiagonalMask,
|
||||
BlockDiagonalCausalMask,
|
||||
attn_bias.BlockDiagonalCausalFromBottomRightMask,
|
||||
attn_bias.BlockDiagonalCausalLocalAttentionMask,
|
||||
}
|
||||
SUPPORTS_ATTN_BIAS_GRAD = True
|
||||
SUPPORTS_DROPOUT = FwOp.SUPPORTS_DROPOUT
|
||||
SUPPORTS_CUSTOM_SCALE = FwOp.SUPPORTS_CUSTOM_SCALE
|
||||
SUPPORTS_DIFFERENT_VALUE_EMBED = FwOp.SUPPORTS_DIFFERENT_VALUE_EMBED
|
||||
NAME = "cutlassB"
|
||||
|
||||
ERROR_ATOL: Mapping[torch.dtype, float] = {
|
||||
torch.float: 5e-4,
|
||||
# increased from 9e-2, more opportunities for numerical errors when bias is
|
||||
# used, noticed in gK on SM80
|
||||
torch.half: 1e-1,
|
||||
torch.bfloat16: 7e-1,
|
||||
}
|
||||
|
||||
_TEST_K: List[int] = [
|
||||
32, # 64x64 kernel
|
||||
128, # 64x128/128x128 kernel
|
||||
256, # 64x128 with accumulation in gmem
|
||||
]
|
||||
|
||||
@classmethod
|
||||
def not_supported_reasons(cls, d: Inputs) -> List[str]:
|
||||
reasons = super(BwOp, cls).not_supported_reasons(d)
|
||||
matmul_alignment_mn = _minimum_gemm_alignment(d)
|
||||
|
||||
check_lastdim_alignment_stride1(reasons, "query", d.query, matmul_alignment_mn)
|
||||
check_lastdim_alignment_stride1(reasons, "key", d.key, matmul_alignment_mn)
|
||||
check_lastdim_alignment_stride1(reasons, "value", d.value, matmul_alignment_mn)
|
||||
_check_bias_alignment(reasons, d.attn_bias)
|
||||
attn_bias_tensor = _get_tensor_bias(d.attn_bias)
|
||||
|
||||
# Backprop of gradient through broadcasted bias is not supported
|
||||
if attn_bias_tensor is not None and attn_bias_tensor.requires_grad:
|
||||
# Don't forget that inputs are either in BMK or BMHK!
|
||||
if d.query.ndim == 3 and attn_bias_tensor.ndim == 3:
|
||||
expected_bias_shape = (*d.query.shape[:2], d.key.shape[1])
|
||||
else:
|
||||
# bias is B H Mq Mk
|
||||
expected_bias_shape = (
|
||||
d.query.shape[0],
|
||||
d.query.shape[2] if d.query.ndim == 4 else 1,
|
||||
d.query.shape[1],
|
||||
d.key.shape[1],
|
||||
)
|
||||
if tuple(attn_bias_tensor.shape) != expected_bias_shape:
|
||||
reasons.append(
|
||||
"Broadcasting the `attn_bias` tensor is not supported "
|
||||
f"(shape: {tuple(attn_bias_tensor.shape)}"
|
||||
f"/ expected: {expected_bias_shape})"
|
||||
)
|
||||
return reasons
|
||||
|
||||
@classmethod
|
||||
def apply(cls, ctx: Context, inp: Inputs, grad: torch.Tensor) -> Gradients:
|
||||
if type(inp.attn_bias) not in BwOp.SUPPORTED_ATTN_BIAS_TYPES:
|
||||
raise NotImplementedError("Unsupported attn_bias type")
|
||||
|
||||
seqstart_k, seqstart_q, max_seqlen_q, max_seqlen_k = _get_seqlen_info(inp)
|
||||
dtype = inp.query.dtype
|
||||
|
||||
rng_seed = rng_offset = 0
|
||||
if inp.p != 0.0:
|
||||
if (
|
||||
ctx.rng_state is None
|
||||
or ctx.rng_state.dtype != torch.int64
|
||||
or ctx.rng_state.device.type != "cpu"
|
||||
or ctx.rng_state.shape != (2,)
|
||||
):
|
||||
raise NotImplementedError(f"Invalid rng_state: {ctx.rng_state}")
|
||||
rng_seed, rng_offset = ctx.rng_state.tolist()
|
||||
|
||||
force_pad_inf = torch.cuda.get_device_capability(inp.query.device) == (7, 5)
|
||||
(grad_q, grad_k, grad_v, grad_bias) = cls.OPERATOR(
|
||||
grad.to(dtype),
|
||||
inp.query,
|
||||
inp.key,
|
||||
inp.value,
|
||||
_get_tensor_bias(inp.attn_bias),
|
||||
cu_seqlens_q=seqstart_q,
|
||||
cu_seqlens_k=seqstart_k,
|
||||
max_seqlen_q=max_seqlen_q,
|
||||
max_seqlen_k=max_seqlen_k,
|
||||
logsumexp=ctx.get_padded_lse(32, force_pad_inf=force_pad_inf),
|
||||
output=ctx.out.to(dtype),
|
||||
dropout_p=inp.p,
|
||||
# if not using dropout, seed and offset are irrelevant but still expected
|
||||
# in function signature so just pass 0
|
||||
# seed and offset could be None if a different FW op other than cutlass
|
||||
# was used.
|
||||
rng_seed=rng_seed,
|
||||
rng_offset=rng_offset,
|
||||
custom_mask_type=_custom_mask_type(inp.attn_bias),
|
||||
scale=inp.scale,
|
||||
num_splits_key=-1, # Let C++ determine it
|
||||
window_size=inp.attn_bias._window_size
|
||||
if isinstance(
|
||||
inp.attn_bias,
|
||||
(
|
||||
BlockDiagonalCausalLocalAttentionMask,
|
||||
BlockDiagonalCausalLocalAttentionFromBottomRightMask,
|
||||
),
|
||||
)
|
||||
else None,
|
||||
)
|
||||
|
||||
# c++/CUDA implementation returns an uninitialized tensor if bias doesn't
|
||||
# require grad
|
||||
if not (
|
||||
isinstance(inp.attn_bias, torch.Tensor) and inp.attn_bias.requires_grad
|
||||
):
|
||||
grad_bias = None
|
||||
|
||||
return Gradients(dq=grad_q, dk=grad_k, dv=grad_v, db=grad_bias)
|
||||
|
||||
@classmethod
|
||||
# type: ignore
|
||||
def operator_flop(
|
||||
cls,
|
||||
dO,
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
b,
|
||||
cu_seqlens_q,
|
||||
cu_seqlens_k,
|
||||
max_seqlen_q,
|
||||
max_seqlen_k,
|
||||
logsumexp,
|
||||
output,
|
||||
dropout_p,
|
||||
rng_seed,
|
||||
rng_offset,
|
||||
custom_mask_type,
|
||||
scale,
|
||||
) -> int:
|
||||
return cls.attn_operator_flop(
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
seqstart_q=cu_seqlens_q,
|
||||
seqstart_k=cu_seqlens_k,
|
||||
causal=custom_mask_type > 0,
|
||||
)
|
||||
108
pkgs/xformers/ops/fmha/decoder.py
Normal file
108
pkgs/xformers/ops/fmha/decoder.py
Normal file
@@ -0,0 +1,108 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the BSD license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
from typing import Any, List, Optional, Set, Tuple
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from ..common import get_xformers_operator, register_operator
|
||||
from .attn_bias import BlockDiagonalCausalWithOffsetPaddedKeysMask
|
||||
from .common import AttentionFwOpBase, Context, Inputs
|
||||
|
||||
|
||||
@register_operator
|
||||
class FwOp(AttentionFwOpBase):
|
||||
"""An operator optimized for very small values of K (``K <= 32``) \
|
||||
and f32 pre-Ampere as it does not use TensorCores.
|
||||
Only supports contiguous inputs in BMK format, so an extra reshape \
|
||||
or contiguous call might be done.
|
||||
|
||||
:Deprecated:
|
||||
|
||||
This operator is deprecated and should not be used in new code
|
||||
"""
|
||||
|
||||
OPERATOR = get_xformers_operator("efficient_attention_forward_decoder")
|
||||
SUPPORTED_DEVICES = {"cuda"}
|
||||
SUPPORTED_DTYPES = {torch.bfloat16, torch.half, torch.float32}
|
||||
CUDA_MINIMUM_COMPUTE_CAPABILITY = (7, 0)
|
||||
SUPPORTED_MAX_K: float = 128
|
||||
SUPPORTED_ATTN_BIAS_TYPES: Set[Any] = {BlockDiagonalCausalWithOffsetPaddedKeysMask}
|
||||
SUPPORTS_DROPOUT = False
|
||||
SUPPORTS_CUSTOM_SCALE = True
|
||||
NAME = "decoderF"
|
||||
|
||||
@classmethod
|
||||
def not_supported_reasons(cls, d: Inputs) -> List[str]:
|
||||
reasons = super(FwOp, cls).not_supported_reasons(d)
|
||||
|
||||
attn_bias = d.attn_bias
|
||||
if isinstance(attn_bias, BlockDiagonalCausalWithOffsetPaddedKeysMask):
|
||||
# If we don't get here, we've an error elsewhere
|
||||
if d.query.ndim != 4 or d.key.ndim != 4:
|
||||
reasons.append("Inputs must be BMHK. BMK not supported")
|
||||
|
||||
if d.query.shape[0] != 1:
|
||||
reasons.append("One formal batch element expected")
|
||||
|
||||
if d.query.shape[-1] != 128:
|
||||
reasons.append("Only head_dim==128 for now.")
|
||||
|
||||
if d.key.stride(-1) != 1:
|
||||
reasons.append("expect keys to have last dim contiguous")
|
||||
|
||||
if d.value.stride(-1) != 1:
|
||||
reasons.append("expect values to have last dim contiguous")
|
||||
|
||||
q_starts = attn_bias.q_seqinfo.seqstart_py
|
||||
if attn_bias.q_seqinfo.max_seqlen != 1:
|
||||
reasons.append("decoding expects one query")
|
||||
elif d.query.shape[1] != len(q_starts) - 1:
|
||||
reasons.append("empty lanes not supported yet")
|
||||
|
||||
if attn_bias.k_seqinfo.padding > 8192:
|
||||
reasons.append("key padding exceeds 8192")
|
||||
|
||||
return reasons
|
||||
|
||||
@classmethod
|
||||
def apply(
|
||||
cls, inp: Inputs, needs_gradient: bool
|
||||
) -> Tuple[torch.Tensor, Optional[Context]]:
|
||||
if needs_gradient:
|
||||
raise NotImplementedError("gradient")
|
||||
attn_bias = inp.attn_bias
|
||||
assert isinstance(attn_bias, BlockDiagonalCausalWithOffsetPaddedKeysMask)
|
||||
|
||||
attn_bias.k_seqinfo.to(inp.query.device)
|
||||
attn_bias.q_seqinfo.to(inp.query.device)
|
||||
|
||||
padding = attn_bias.k_seqinfo.padding
|
||||
multiquery = inp.key.stride(2) == 0
|
||||
if multiquery:
|
||||
key = inp.key[0, :, :1].unflatten(0, (-1, padding))
|
||||
value = inp.value[0, :, :1].unflatten(0, (-1, padding))
|
||||
else:
|
||||
key = inp.key[0].unflatten(0, (-1, padding))
|
||||
value = inp.value[0].unflatten(0, (-1, padding))
|
||||
|
||||
seq_positions = attn_bias.k_seqinfo.seqlen
|
||||
|
||||
query = inp.query[0, :, None]
|
||||
|
||||
if inp.scale is not None:
|
||||
qk_scale = inp.scale
|
||||
else:
|
||||
qk_scale = 1.0 / np.sqrt(key.shape[-1])
|
||||
|
||||
out = cls.OPERATOR(
|
||||
query=query,
|
||||
key=key,
|
||||
value=value,
|
||||
seq_positions=seq_positions,
|
||||
scale=qk_scale,
|
||||
)
|
||||
return out, None
|
||||
147
pkgs/xformers/ops/fmha/dispatch.py
Normal file
147
pkgs/xformers/ops/fmha/dispatch.py
Normal file
@@ -0,0 +1,147 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the BSD license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
|
||||
import textwrap
|
||||
from collections import deque
|
||||
from typing import List, Sequence, Type, TypeVar
|
||||
|
||||
from . import attn_bias, cutlass, decoder, flash, small_k, triton, triton_splitk
|
||||
from .common import AttentionBwOpBase, AttentionFwOpBase, Inputs
|
||||
|
||||
|
||||
def _is_cutlass_fwd_faster_than_flash(inp: Inputs) -> bool:
|
||||
return False
|
||||
|
||||
|
||||
def _is_triton_fwd_fastest(inp: Inputs) -> bool:
|
||||
# TODO: fill out
|
||||
return False
|
||||
|
||||
|
||||
T = TypeVar("T", Type[AttentionFwOpBase], Type[AttentionBwOpBase])
|
||||
|
||||
|
||||
def _format_inputs_description(inp: Inputs) -> str:
|
||||
return f"""query : shape={tuple(inp.query.shape)} ({inp.query.dtype})
|
||||
key : shape={tuple(inp.key.shape)} ({inp.key.dtype})
|
||||
value : shape={tuple(inp.value.shape)} ({inp.value.dtype})
|
||||
attn_bias : {type(inp.attn_bias)}
|
||||
p : {inp.p}"""
|
||||
|
||||
|
||||
def _ensure_op_supports_or_raise(exc_type, name: str, op, inp: Inputs) -> None:
|
||||
reasons = op.not_supported_reasons(inp)
|
||||
if not reasons:
|
||||
return
|
||||
raise exc_type(
|
||||
f"""Operator `{name}` does not support inputs:
|
||||
{textwrap.indent(_format_inputs_description(inp), ' ')}
|
||||
{_format_not_supported_reasons(op, reasons)}"""
|
||||
)
|
||||
|
||||
|
||||
def _format_not_supported_reasons(op, reasons: List[str]) -> str:
|
||||
return f"`{op.NAME}` is not supported because:\n " + "\n ".join(reasons)
|
||||
|
||||
|
||||
def _run_priority_list(name: str, priority_list: Sequence[T], inp: Inputs) -> T:
|
||||
not_supported_reasons: List[List[str]] = []
|
||||
for op in priority_list:
|
||||
not_supported = op.not_supported_reasons(inp)
|
||||
if not not_supported:
|
||||
return op
|
||||
not_supported_reasons.append(not_supported)
|
||||
|
||||
# Let's write a nice message explaining what we tried and why it's not supported
|
||||
msg = f"""No operator found for `{name}` with inputs:
|
||||
{textwrap.indent(_format_inputs_description(inp), ' ')}"""
|
||||
for op, not_supported in zip(priority_list, not_supported_reasons):
|
||||
msg += "\n" + _format_not_supported_reasons(op, not_supported)
|
||||
raise NotImplementedError(msg)
|
||||
|
||||
|
||||
def _dispatch_fw_priority_list(
|
||||
inp: Inputs, needs_gradient: bool
|
||||
) -> Sequence[Type[AttentionFwOpBase]]:
|
||||
priority_list_ops = deque(
|
||||
[
|
||||
flash.FwOp,
|
||||
triton.FwOp,
|
||||
cutlass.FwOp,
|
||||
small_k.FwOp,
|
||||
]
|
||||
)
|
||||
if _is_cutlass_fwd_faster_than_flash(inp):
|
||||
priority_list_ops.remove(cutlass.FwOp)
|
||||
priority_list_ops.appendleft(cutlass.FwOp)
|
||||
if _is_triton_fwd_fastest(inp):
|
||||
priority_list_ops.remove(triton.FwOp)
|
||||
priority_list_ops.appendleft(triton.FwOp)
|
||||
if not needs_gradient:
|
||||
mqa_or_gqa = (
|
||||
inp.key.ndim > 3 and inp.key.stride(-2) == 0 and inp.key.shape[-2] > 1
|
||||
)
|
||||
if not mqa_or_gqa:
|
||||
# With multiquery, cutlass is sometimes faster than decoder
|
||||
# but it's not currently clear when.
|
||||
priority_list_ops.appendleft(decoder.FwOp)
|
||||
# Split-KV is useful with MQA
|
||||
# for short Q-seqlen / long K-seqlen
|
||||
if mqa_or_gqa and inp.query.shape[1] <= 32 and inp.key.shape[1] >= 256:
|
||||
parallelism_BH = 0 # BMK
|
||||
if inp.query.ndim == 3:
|
||||
parallelism_BH = inp.query.shape[0]
|
||||
elif inp.query.ndim == 4: # BMHK
|
||||
parallelism_BH = inp.query.shape[0] * inp.query.shape[2]
|
||||
elif inp.query.ndim == 5: # BMGHK
|
||||
parallelism_BH = inp.query.shape[0] * inp.query.shape[2]
|
||||
if parallelism_BH > 0 and parallelism_BH < 64:
|
||||
priority_list_ops.appendleft(triton_splitk.FwOp)
|
||||
# Without variable seqlen flash is fastest
|
||||
if not isinstance(inp.attn_bias, attn_bias.BlockDiagonalMask):
|
||||
priority_list_ops.remove(flash.FwOp)
|
||||
priority_list_ops.appendleft(flash.FwOp)
|
||||
|
||||
return priority_list_ops
|
||||
|
||||
|
||||
def _dispatch_fw(inp: Inputs, needs_gradient: bool) -> Type[AttentionFwOpBase]:
|
||||
"""Computes the best operator for forward
|
||||
|
||||
Raises:
|
||||
NotImplementedError: if not operator was found
|
||||
|
||||
Returns:
|
||||
AttentionOp: The best operator for the configuration
|
||||
"""
|
||||
# return _run_priority_list(
|
||||
# "memory_efficient_attention_forward",
|
||||
# _dispatch_fw_priority_list(inp, needs_gradient),
|
||||
# inp,
|
||||
# )
|
||||
return flash.FwOp
|
||||
|
||||
|
||||
def _is_cutlassB_faster_than_flash(inp: Inputs) -> bool:
|
||||
return False
|
||||
|
||||
|
||||
def _dispatch_bw(inp: Inputs) -> Type[AttentionBwOpBase]:
|
||||
priority_list_ops: List[Type[AttentionBwOpBase]] = [
|
||||
flash.BwOp,
|
||||
cutlass.BwOp,
|
||||
# CUDA illegal memory issues, race conditions etc..
|
||||
# triton.BwOp,
|
||||
# Deprecated
|
||||
small_k.BwOp,
|
||||
]
|
||||
# if _is_cutlassB_faster_than_flash(inp):
|
||||
# priority_list_ops.remove(cutlass.BwOp)
|
||||
# priority_list_ops.insert(0, cutlass.BwOp)
|
||||
# return _run_priority_list(
|
||||
# "memory_efficient_attention_backward", priority_list_ops, inp
|
||||
# )
|
||||
return flash.BwOp
|
||||
666
pkgs/xformers/ops/fmha/flash.py
Normal file
666
pkgs/xformers/ops/fmha/flash.py
Normal file
@@ -0,0 +1,666 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the BSD license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
|
||||
from dataclasses import replace
|
||||
from itertools import zip_longest
|
||||
from typing import Any, List, Optional, Set, Tuple, Union
|
||||
import os
|
||||
import torch
|
||||
|
||||
from ..common import _get_storage_base, get_operator, register_operator
|
||||
from .attn_bias import (
|
||||
AttentionBias,
|
||||
BlockDiagonalCausalFromBottomRightMask,
|
||||
BlockDiagonalCausalLocalAttentionFromBottomRightMask,
|
||||
BlockDiagonalCausalLocalAttentionMask,
|
||||
BlockDiagonalCausalMask,
|
||||
BlockDiagonalMask,
|
||||
LowerTriangularMask,
|
||||
)
|
||||
from .common import (
|
||||
AttentionBwOpBase,
|
||||
AttentionFwOpBase,
|
||||
Context,
|
||||
Gradients,
|
||||
Inputs,
|
||||
check_lastdim_alignment_stride1,
|
||||
)
|
||||
global enable_ixdnn
|
||||
enable_ixdnn = os.getenv('ENABLE_FLASH_ATTENTION_WITH_IXDNN', '1') != '0'
|
||||
|
||||
FLASH_VERSION = "0.0.0"
|
||||
try:
|
||||
try:
|
||||
from ... import _C_flashattention # type: ignore[attr-defined]
|
||||
from ..._cpp_lib import _build_metadata
|
||||
|
||||
if _build_metadata is not None:
|
||||
FLASH_VERSION = _build_metadata.flash_version
|
||||
except ImportError:
|
||||
import flash_attn
|
||||
from flash_attn.flash_attn_interface import flash_attn_cuda as _C_flashattention
|
||||
|
||||
FLASH_VERSION = flash_attn.__version__
|
||||
flash_ver_parsed = tuple(int(s) for s in FLASH_VERSION.split(".")[:2])
|
||||
if flash_ver_parsed < (2, 3):
|
||||
raise ImportError("Requires 2.3 for sliding window support")
|
||||
|
||||
# create library so that flash-attn goes through the PyTorch Dispatcher
|
||||
_flash_lib = torch.library.Library("xformers_flash", "DEF")
|
||||
|
||||
_flash_lib.define(
|
||||
"flash_fwd(Tensor query, Tensor key, Tensor value, "
|
||||
"Tensor? cu_seqlens_q, Tensor? cu_seqlens_k, "
|
||||
"int max_seqlen_q, int max_seqlen_k, "
|
||||
"float p, float softmax_scale, "
|
||||
"bool is_causal, int window_size, bool return_softmax, bool use_alibi, int alibi_mode, int imp_mode) -> (Tensor, Tensor, Tensor, Tensor, Tensor, Tensor)"
|
||||
)
|
||||
|
||||
_flash_lib.define(
|
||||
"flash_bwd(Tensor dout, Tensor query, Tensor key, Tensor value, "
|
||||
"Tensor out, Tensor softmax_lse_, Tensor dq, Tensor dk, Tensor dv, "
|
||||
"Tensor cu_seqlens_q, Tensor cu_seqlens_k, "
|
||||
"int max_seqlen_q, int max_seqlen_k, "
|
||||
"float p, float softmax_scale, bool is_causal, int window_size, bool use_alibi, int alibi_mode, int imp_mode) -> (Tensor, Tensor, Tensor)"
|
||||
)
|
||||
|
||||
def _flash_fwd(
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
cu_seq_lens_q,
|
||||
cu_seq_lens_k,
|
||||
max_seq_len_q,
|
||||
max_seq_len_k,
|
||||
p,
|
||||
softmax_scale,
|
||||
is_causal,
|
||||
window_size,
|
||||
return_softmax,
|
||||
use_alibi,
|
||||
alibi_mode,
|
||||
imp_mode
|
||||
):
|
||||
if enable_ixdnn:
|
||||
(
|
||||
out,
|
||||
q_padded,
|
||||
k_padded,
|
||||
v_padded,
|
||||
out_padded,
|
||||
softmax_lse,
|
||||
p,
|
||||
) = _C_flashattention.fwd_ixdnn(
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
None, # out
|
||||
cu_seq_lens_q,
|
||||
cu_seq_lens_k,
|
||||
p,
|
||||
is_causal,
|
||||
return_softmax,
|
||||
use_alibi,
|
||||
alibi_mode,
|
||||
imp_mode,
|
||||
-1, -1,
|
||||
None, # rng
|
||||
)
|
||||
else:
|
||||
if cu_seq_lens_q is None:
|
||||
assert cu_seq_lens_k is None
|
||||
(
|
||||
out,
|
||||
q_padded,
|
||||
k_padded,
|
||||
v_padded,
|
||||
out_padded,
|
||||
softmax_lse,
|
||||
p,
|
||||
) = _C_flashattention.fwd(
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
None, # out
|
||||
p,
|
||||
softmax_scale,
|
||||
is_causal,
|
||||
return_softmax,
|
||||
use_alibi,
|
||||
alibi_mode,
|
||||
None, # rng
|
||||
)
|
||||
else:
|
||||
out = query.new_empty(query.shape[0], query.shape[1], value.shape[2])
|
||||
(
|
||||
out,
|
||||
q_padded,
|
||||
k_padded,
|
||||
v_padded,
|
||||
out_padded,
|
||||
softmax_lse,
|
||||
p,
|
||||
) = _C_flashattention.varlen_fwd(
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
out,
|
||||
cu_seq_lens_q,
|
||||
cu_seq_lens_k,
|
||||
max_seq_len_q,
|
||||
max_seq_len_k,
|
||||
p,
|
||||
softmax_scale,
|
||||
False,
|
||||
is_causal,
|
||||
return_softmax,
|
||||
use_alibi,
|
||||
alibi_mode,
|
||||
None,
|
||||
)
|
||||
return out, softmax_lse, q_padded, k_padded, v_padded, out_padded
|
||||
|
||||
def _flash_bwd(
|
||||
grad,
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
out,
|
||||
lse,
|
||||
dq,
|
||||
dk,
|
||||
dv,
|
||||
cu_seq_lens_q,
|
||||
cu_seq_lens_k,
|
||||
max_seq_len_q,
|
||||
max_seq_len_k,
|
||||
p,
|
||||
softmax_scale,
|
||||
is_causal,
|
||||
window_size,
|
||||
use_alibi,
|
||||
alibi_mode,
|
||||
imp_mode
|
||||
):
|
||||
if enable_ixdnn:
|
||||
_C_flashattention.bwd_ixdnn(
|
||||
grad,
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
out,
|
||||
lse,
|
||||
dq,
|
||||
dk,
|
||||
dv,
|
||||
cu_seq_lens_q,
|
||||
cu_seq_lens_k,
|
||||
p,
|
||||
is_causal,
|
||||
use_alibi,
|
||||
alibi_mode,
|
||||
imp_mode,
|
||||
-1, -1,
|
||||
None,
|
||||
)
|
||||
else:
|
||||
if cu_seq_lens_k is None:
|
||||
assert cu_seq_lens_q is None
|
||||
_C_flashattention.bwd(
|
||||
grad,
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
out,
|
||||
lse,
|
||||
dq,
|
||||
dk,
|
||||
dv,
|
||||
p,
|
||||
softmax_scale,
|
||||
is_causal,
|
||||
use_alibi,
|
||||
alibi_mode,
|
||||
None,
|
||||
)
|
||||
else:
|
||||
_C_flashattention.varlen_bwd(
|
||||
grad,
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
out,
|
||||
lse,
|
||||
dq,
|
||||
dk,
|
||||
dv,
|
||||
cu_seq_lens_q,
|
||||
cu_seq_lens_k,
|
||||
max_seq_len_q,
|
||||
max_seq_len_k,
|
||||
p,
|
||||
softmax_scale,
|
||||
False, # zero_tensors
|
||||
is_causal,
|
||||
use_alibi,
|
||||
alibi_mode,
|
||||
None,
|
||||
)
|
||||
return dq, dk, dv
|
||||
|
||||
_flash_lib.impl("flash_fwd", _flash_fwd, "CUDA")
|
||||
_flash_lib.impl("flash_bwd", _flash_bwd, "CUDA")
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
|
||||
def _convert_input_format(
|
||||
inp: Inputs,
|
||||
) -> Tuple[Inputs, Optional[torch.Tensor], int, Optional[torch.Tensor], int]:
|
||||
assert inp.query.ndim in [4, 5]
|
||||
query, key, value = inp.query, inp.key, inp.value
|
||||
batch = query.shape[0]
|
||||
seqlen_q = query.shape[1]
|
||||
seqlen_kv = key.shape[1]
|
||||
head_dim_q = query.shape[-1]
|
||||
head_dim_v = value.shape[-1]
|
||||
|
||||
attn_bias = inp.attn_bias
|
||||
if enable_ixdnn:
|
||||
if query.shape[0] == 1:
|
||||
cu_seqlen_q = torch.tensor([seqlen_q], dtype=torch.int32, device=query.device)
|
||||
cu_seqlen_k = torch.tensor([seqlen_kv], dtype=torch.int32, device=key.device)
|
||||
else:
|
||||
cu_seqlen_q = torch.full((seqlen_q,), seqlen_q, dtype=torch.int32, device=query.device)
|
||||
cu_seqlen_k = torch.full((seqlen_kv,), seqlen_kv, dtype=torch.int32, device=key.device)
|
||||
max_seqlen_q = inp.query.shape[1]
|
||||
max_seqlen_k = inp.key.shape[1]
|
||||
else:
|
||||
if isinstance(attn_bias, BlockDiagonalMask):
|
||||
attn_bias.k_seqinfo.seqstart = attn_bias.k_seqinfo.seqstart.to(
|
||||
inp.query.device, non_blocking=True
|
||||
)
|
||||
attn_bias.q_seqinfo.seqstart = attn_bias.q_seqinfo.seqstart.to(
|
||||
inp.query.device, non_blocking=True
|
||||
)
|
||||
|
||||
cu_seqlen_k = attn_bias.k_seqinfo.seqstart
|
||||
cu_seqlen_q = attn_bias.q_seqinfo.seqstart
|
||||
max_seqlen_q = attn_bias.q_seqinfo.max_seqlen
|
||||
max_seqlen_k = attn_bias.k_seqinfo.max_seqlen
|
||||
else:
|
||||
cu_seqlen_k = None
|
||||
cu_seqlen_q = None
|
||||
max_seqlen_q = inp.query.shape[1]
|
||||
max_seqlen_k = inp.key.shape[1]
|
||||
|
||||
if query.ndim == 5: # QGA
|
||||
# Fold the group/head_in_group dimensions together
|
||||
def fold(x):
|
||||
# Either the head is replicated
|
||||
if x.stride(3) == 0:
|
||||
return x[:, :, :, 0]
|
||||
# Or we reshape
|
||||
return x.reshape(
|
||||
[
|
||||
x.shape[0],
|
||||
x.shape[1],
|
||||
-1,
|
||||
x.shape[4],
|
||||
]
|
||||
)
|
||||
|
||||
query = fold(query)
|
||||
key = fold(key)
|
||||
value = fold(value)
|
||||
# Optimize for MHA
|
||||
if key.ndim == 4 and key.stride(2) == 0 and value.stride(2) == 0:
|
||||
key = key[:, :, :1]
|
||||
value = value[:, :, :1]
|
||||
# Initially we have `query.shape = [batch, seqlen, head_dim_q]`
|
||||
# We want format `[batch * seqlen, num_heads, head_dim_q]`
|
||||
if cu_seqlen_k is not None and os.getenv('ENABLE_FLASH_ATTENTION_WITH_IXDNN', '1') == '0':
|
||||
query = query.reshape([batch * seqlen_q, -1, head_dim_q])
|
||||
key = key.reshape([batch * seqlen_kv, -1, head_dim_q])
|
||||
value = value.reshape([batch * seqlen_kv, -1, head_dim_v])
|
||||
new_inp = replace(
|
||||
inp,
|
||||
query=query,
|
||||
key=key,
|
||||
value=value,
|
||||
)
|
||||
return new_inp, cu_seqlen_q, max_seqlen_q, cu_seqlen_k, max_seqlen_k
|
||||
|
||||
|
||||
def _is_causal(attn_bias: Optional[Union[torch.Tensor, AttentionBias]]) -> bool:
|
||||
return isinstance(
|
||||
attn_bias,
|
||||
(
|
||||
LowerTriangularMask,
|
||||
BlockDiagonalCausalMask,
|
||||
BlockDiagonalCausalLocalAttentionMask,
|
||||
BlockDiagonalCausalFromBottomRightMask,
|
||||
BlockDiagonalCausalLocalAttentionFromBottomRightMask,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def _window_size(attn_bias: Optional[Union[torch.Tensor, AttentionBias]]) -> int:
|
||||
if isinstance(
|
||||
attn_bias,
|
||||
(BlockDiagonalCausalLocalAttentionMask,),
|
||||
):
|
||||
return attn_bias._window_size or 0
|
||||
if isinstance(attn_bias, BlockDiagonalCausalLocalAttentionFromBottomRightMask):
|
||||
return attn_bias._window_size
|
||||
return 0
|
||||
|
||||
|
||||
def _check_needs_no_topleft(d: Inputs, reasons: List[str]) -> None:
|
||||
# Flash does not support TopLeft, so only allow causal masks with TopLeft
|
||||
# if each batch element has equal number of queries and keys.
|
||||
if isinstance(d.attn_bias, BlockDiagonalCausalMask):
|
||||
# Flash does not support TopLeft, so only allow BlockDiagonalCausalMask
|
||||
# if each batch element has equal number of queries and keys.
|
||||
for k_start, q_start in zip_longest(
|
||||
d.attn_bias.k_seqinfo.seqstart_py, d.attn_bias.q_seqinfo.seqstart_py
|
||||
):
|
||||
if k_start != q_start:
|
||||
reasons.append(
|
||||
"Only support BlockDiagonalCausalMask if equal"
|
||||
" numbers of keys and queries"
|
||||
)
|
||||
break
|
||||
elif isinstance(d.attn_bias, LowerTriangularMask):
|
||||
if d.query.shape[1] != d.key.shape[1]:
|
||||
reasons.append(
|
||||
"Only support LowerTriangularMask if equal number of" "keys and queries"
|
||||
)
|
||||
|
||||
|
||||
def _check_strides_for_bmghk(x: torch.Tensor, name: str, reasons: List[str]) -> None:
|
||||
"""
|
||||
We want to be able to collapse the G/H dimensions together
|
||||
"""
|
||||
if x.ndim == 5:
|
||||
stride_g, stride_h = x.stride(2), x.stride(3)
|
||||
if x.shape[2] == 1:
|
||||
return
|
||||
if x.shape[3] == 1 or stride_h == 0:
|
||||
return
|
||||
if stride_g != stride_h * x.shape[-2]:
|
||||
reasons.append(
|
||||
f"GQA is only supported when the G/H dimensions are contiguous\n"
|
||||
f" {name}.stride: {x.stride()}\n"
|
||||
f" {name}.shape : {list(x.shape)}"
|
||||
)
|
||||
|
||||
|
||||
@register_operator
|
||||
class FwOp(AttentionFwOpBase):
|
||||
"""Operator that computes memory-efficient attention using \
|
||||
`Flash-Attention <https://github.com/HazyResearch/flash-attention>`_ \
|
||||
implementation.
|
||||
"""
|
||||
|
||||
OPERATOR = get_operator("xformers_flash", "flash_fwd")
|
||||
SUPPORTED_DEVICES: Set[str] = {"cuda"}
|
||||
SUPPORTED_DTYPES: Set[torch.dtype] = {torch.half, torch.bfloat16}
|
||||
SUPPORTED_MAX_K = 256
|
||||
SUPPORTED_ATTN_BIAS_TYPES: Set[Any] = {
|
||||
type(None),
|
||||
LowerTriangularMask,
|
||||
BlockDiagonalMask,
|
||||
BlockDiagonalCausalMask,
|
||||
BlockDiagonalCausalLocalAttentionMask,
|
||||
BlockDiagonalCausalLocalAttentionFromBottomRightMask,
|
||||
BlockDiagonalCausalFromBottomRightMask,
|
||||
}
|
||||
SUPPORTS_DROPOUT = True
|
||||
SUPPORTS_CUSTOM_SCALE = True
|
||||
SUPPORTS_DIFFERENT_VALUE_EMBED = False
|
||||
SUPPORTS_BMGHK = True
|
||||
NAME = f"flshattF@{FLASH_VERSION}"
|
||||
VERSION = FLASH_VERSION
|
||||
|
||||
@classmethod
|
||||
def not_supported_reasons(cls, d: Inputs) -> List[str]:
|
||||
reasons = super(FwOp, cls).not_supported_reasons(d)
|
||||
check_lastdim_alignment_stride1(reasons, "query", d.query, 8)
|
||||
_check_needs_no_topleft(d, reasons)
|
||||
_check_strides_for_bmghk(d.query, "query", reasons)
|
||||
_check_strides_for_bmghk(d.key, "key", reasons)
|
||||
_check_strides_for_bmghk(d.value, "value", reasons)
|
||||
return reasons
|
||||
|
||||
@classmethod
|
||||
def apply(
|
||||
cls, inp: Inputs, needs_gradient: bool
|
||||
) -> Tuple[torch.Tensor, Optional[Context]]:
|
||||
return_softmax = False
|
||||
out_shape = [
|
||||
*inp.query.shape[:-1],
|
||||
inp.value.shape[-1],
|
||||
]
|
||||
# no cumulative seqlen
|
||||
(
|
||||
inp,
|
||||
cu_seqlens_q,
|
||||
max_seqlen_q,
|
||||
cu_seqlens_k,
|
||||
max_seqlen_k,
|
||||
) = _convert_input_format(inp)
|
||||
out, softmax_lse, q_padded, k_padded, v_padded, o_padded = cls.OPERATOR(
|
||||
inp.query,
|
||||
inp.key,
|
||||
inp.value,
|
||||
cu_seqlens_q,
|
||||
cu_seqlens_k,
|
||||
max_seqlen_q,
|
||||
max_seqlen_k,
|
||||
inp.p,
|
||||
inp.scale_float,
|
||||
_is_causal(inp.attn_bias),
|
||||
_window_size(inp.attn_bias),
|
||||
return_softmax,
|
||||
inp.use_alibi,
|
||||
inp.alibi_mode,
|
||||
inp.imp_mode,
|
||||
)
|
||||
out = out.reshape(out_shape)
|
||||
ctx = Context(out=out, lse=softmax_lse, q_padded=q_padded, k_padded=k_padded, v_padded=v_padded, o_padded=o_padded)
|
||||
if inp.p != 0.0:
|
||||
ctx.op_bw = BwOp
|
||||
return (out, ctx)
|
||||
|
||||
@classmethod
|
||||
# type: ignore
|
||||
def operator_flop(
|
||||
cls,
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
cu_seq_lens_q,
|
||||
cu_seq_lens_k,
|
||||
max_seq_len_q,
|
||||
max_seq_len_k,
|
||||
p,
|
||||
softmax_scale,
|
||||
causal,
|
||||
return_softmax,
|
||||
) -> int:
|
||||
return cls.attn_operator_flop(
|
||||
query.unsqueeze(0),
|
||||
key.unsqueeze(0),
|
||||
value.unsqueeze(0),
|
||||
causal=causal,
|
||||
seqstart_k=cu_seq_lens_k,
|
||||
seqstart_q=cu_seq_lens_q,
|
||||
)
|
||||
|
||||
|
||||
@register_operator
|
||||
class BwOp(AttentionBwOpBase):
|
||||
__doc__ = FwOp.__doc__
|
||||
|
||||
OPERATOR = get_operator("xformers_flash", "flash_bwd")
|
||||
SUPPORTED_DEVICES = FwOp.SUPPORTED_DEVICES
|
||||
SUPPORTED_DTYPES = FwOp.SUPPORTED_DTYPES
|
||||
SUPPORTED_MAX_K = FwOp.SUPPORTED_MAX_K
|
||||
SUPPORTED_ATTN_BIAS_TYPES = FwOp.SUPPORTED_ATTN_BIAS_TYPES
|
||||
SUPPORTS_DROPOUT = FwOp.SUPPORTS_DROPOUT
|
||||
SUPPORTS_CUSTOM_SCALE = FwOp.SUPPORTS_CUSTOM_SCALE
|
||||
SUPPORTS_DIFFERENT_VALUE_EMBED = FwOp.SUPPORTS_DIFFERENT_VALUE_EMBED
|
||||
IS_DETERMINISTIC = False
|
||||
SUPPORTS_BMGHK = False # NOTE: Don't forget to update fmha doc when changing this!
|
||||
NAME = f"flshattB@{FLASH_VERSION}"
|
||||
VERSION = FLASH_VERSION
|
||||
|
||||
MAX_HEADDIM_SM8x = 192
|
||||
|
||||
@classmethod
|
||||
def shape_not_supported_reasons(
|
||||
cls, Mq: int, Mkv: int, K: int, Kv: int
|
||||
) -> List[str]:
|
||||
reasons = super().shape_not_supported_reasons(Mq, Mkv, K, Kv)
|
||||
|
||||
# In fbcode in mode/dev-nosan, we get nans from flash v2.1 if there
|
||||
# is a strange embedding dimension.
|
||||
# if K not in {8, 16, 32, 64, 128, 256}:
|
||||
# reasons.append(f"Embed dim {K} not supported")
|
||||
|
||||
return reasons
|
||||
|
||||
@classmethod
|
||||
def not_supported_reasons(cls, d: Inputs) -> List[str]:
|
||||
reasons = super(BwOp, cls).not_supported_reasons(d)
|
||||
check_lastdim_alignment_stride1(reasons, "query", d.query, 8)
|
||||
_check_needs_no_topleft(d, reasons)
|
||||
if d.device.type == "cuda":
|
||||
# Due to limited shared-memory, some GPUs are limited in head dimension
|
||||
device_capability = torch.cuda.get_device_capability(d.device)
|
||||
is_sm80_or_sm90 = device_capability in [(8, 0), (9, 0)]
|
||||
if (
|
||||
max(d.key.shape[-1], d.query.shape[-1]) > cls.MAX_HEADDIM_SM8x
|
||||
and not is_sm80_or_sm90
|
||||
):
|
||||
reasons.append(
|
||||
"requires a GPU with compute capability 8.0 "
|
||||
f"(A100) or 9.0 (H100) for 'query.shape[-1] > {cls.MAX_HEADDIM_SM8x}'"
|
||||
)
|
||||
return reasons
|
||||
|
||||
@classmethod
|
||||
def apply(cls, ctx: Context, inp: Inputs, grad: torch.Tensor) -> Gradients:
|
||||
dq_shape, dk_shape, dv_shape = inp.query.shape, inp.key.shape, inp.value.shape
|
||||
(
|
||||
inp,
|
||||
cu_seqlens_q,
|
||||
max_seqlen_q,
|
||||
cu_seqlens_k,
|
||||
max_seqlen_k,
|
||||
) = _convert_input_format(inp)
|
||||
assert ctx.lse.is_contiguous
|
||||
ctx_lse = ctx.lse
|
||||
if os.getenv('ENABLE_FLASH_ATTENTION_WITH_IXDNN', '1') == '0':
|
||||
assert ctx_lse.shape[2] >= max_seqlen_q
|
||||
if max_seqlen_q != ctx_lse.shape[2]:
|
||||
ctx_lse = ctx_lse[:, :, :max_seqlen_q].contiguous()
|
||||
kernel_out_shape = [
|
||||
*inp.query.shape[:-1],
|
||||
inp.value.shape[-1],
|
||||
]
|
||||
|
||||
# Create dq,dk,dv
|
||||
# If Q/K/V come from a single QKV tensor, let's put the gradient in the
|
||||
# right strides, so we can avoid a `cat`
|
||||
if (
|
||||
ctx.q_padded.shape[0] == ctx.k_padded.shape[0]
|
||||
and ctx.q_padded.shape[-1] == ctx.v_padded.shape[-1]
|
||||
and _get_storage_base(ctx.q_padded) == _get_storage_base(ctx.k_padded)
|
||||
and _get_storage_base(ctx.q_padded) == _get_storage_base(ctx.k_padded)
|
||||
):
|
||||
# Create one big contiguous chunk
|
||||
# This is because q, k and v usually come from a single
|
||||
# output of a linear layer that is chunked.
|
||||
# Creating the gradients with the right layout saves us
|
||||
# a `torch.cat` call in the backward pass
|
||||
chunk = torch.empty(
|
||||
(*ctx.q_padded.shape[0:-2], 3, ctx.q_padded.shape[-2], ctx.q_padded.shape[-1]),
|
||||
dtype=ctx.q_padded.dtype,
|
||||
device=inp.device,
|
||||
)
|
||||
grads = Gradients(
|
||||
dq=chunk.select(-3, 0),
|
||||
dk=chunk.select(-3, 1),
|
||||
dv=chunk.select(-3, 2),
|
||||
)
|
||||
else:
|
||||
grads = Gradients(
|
||||
dq=torch.empty_like(ctx.q_padded),
|
||||
dk=torch.empty_like(ctx.k_padded),
|
||||
dv=torch.empty_like(ctx.v_padded),
|
||||
)
|
||||
|
||||
assert grad.dtype in cls.SUPPORTED_DTYPES
|
||||
cls.OPERATOR(
|
||||
grad.reshape(kernel_out_shape).contiguous(),
|
||||
ctx.q_padded,
|
||||
ctx.k_padded,
|
||||
ctx.v_padded,
|
||||
# ctx.out.reshape(kernel_out_shape),
|
||||
ctx.o_padded,
|
||||
ctx_lse,
|
||||
grads.dq,
|
||||
grads.dk,
|
||||
grads.dv,
|
||||
cu_seqlens_q,
|
||||
cu_seqlens_k,
|
||||
max_seqlen_q,
|
||||
max_seqlen_k,
|
||||
inp.p,
|
||||
inp.scale_float,
|
||||
_is_causal(inp.attn_bias),
|
||||
_window_size(inp.attn_bias),
|
||||
inp.use_alibi,
|
||||
inp.alibi_mode,
|
||||
inp.imp_mode,
|
||||
)
|
||||
grads.dq = grads.dq[..., :dq_shape[-1]].reshape(dq_shape) # We could have padded the head dimension
|
||||
grads.dk = grads.dk[..., :dk_shape[-1]].reshape(dk_shape)
|
||||
grads.dv = grads.dv[..., :dv_shape[-1]].reshape(dv_shape)
|
||||
|
||||
return grads
|
||||
|
||||
@classmethod
|
||||
# type: ignore
|
||||
def operator_flop(
|
||||
cls,
|
||||
grad,
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
out,
|
||||
lse,
|
||||
dq,
|
||||
dk,
|
||||
dv,
|
||||
cu_seq_lens_q,
|
||||
cu_seq_lens_k,
|
||||
max_seq_len_q,
|
||||
max_seq_len_k,
|
||||
p,
|
||||
softmax_scale,
|
||||
causal,
|
||||
) -> int:
|
||||
return cls.attn_operator_flop(
|
||||
query.unsqueeze(0),
|
||||
key.unsqueeze(0),
|
||||
value.unsqueeze(0),
|
||||
causal=causal,
|
||||
seqstart_k=cu_seq_lens_k,
|
||||
seqstart_q=cu_seq_lens_q,
|
||||
)
|
||||
186
pkgs/xformers/ops/fmha/small_k.py
Normal file
186
pkgs/xformers/ops/fmha/small_k.py
Normal file
@@ -0,0 +1,186 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the BSD license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
from typing import Any, List, Mapping, Optional, Set, Tuple, Union
|
||||
|
||||
import torch
|
||||
|
||||
from ..common import get_xformers_operator, register_operator
|
||||
from .attn_bias import AttentionBias
|
||||
from .common import (
|
||||
AttentionBwOpBase,
|
||||
AttentionFwOpBase,
|
||||
Context,
|
||||
Gradients,
|
||||
Inputs,
|
||||
bmk2bmhk,
|
||||
)
|
||||
|
||||
|
||||
def _bmhk2bmk_contiguous(tensor) -> torch.Tensor:
|
||||
return (
|
||||
tensor.permute((0, 2, 1, 3))
|
||||
.contiguous()
|
||||
.view([tensor.shape[0] * tensor.shape[2], tensor.shape[1], tensor.shape[3]])
|
||||
.contiguous()
|
||||
)
|
||||
|
||||
|
||||
def _get_tensor_bias_bmk(
|
||||
attn_bias: Optional[Union[torch.Tensor, AttentionBias]]
|
||||
) -> Optional[torch.Tensor]:
|
||||
if not isinstance(attn_bias, torch.Tensor):
|
||||
assert attn_bias is None
|
||||
return None
|
||||
# BMK -> BMHK
|
||||
if attn_bias.ndim == 4:
|
||||
attn_bias = attn_bias.reshape([-1, *attn_bias.shape[2:]])
|
||||
return attn_bias
|
||||
|
||||
|
||||
@register_operator
|
||||
class FwOp(AttentionFwOpBase):
|
||||
"""An operator optimized for very small values of K (``K <= 32``) \
|
||||
and f32 pre-Ampere as it does not use TensorCores.
|
||||
Only supports contiguous inputs in BMK format, so an extra reshape \
|
||||
or contiguous call might be done.
|
||||
|
||||
:Deprecated:
|
||||
|
||||
This operator is deprecated and should not be used in new code
|
||||
"""
|
||||
|
||||
OPERATOR = get_xformers_operator("efficient_attention_forward_small_k")
|
||||
SUPPORTED_DEVICES = {"cuda"}
|
||||
SUPPORTED_DTYPES = {torch.float}
|
||||
SUPPORTED_MAX_K: float = 32
|
||||
SUPPORTED_ATTN_BIAS_TYPES: Set[Any] = {type(None), torch.Tensor}
|
||||
SUPPORTS_DROPOUT = True
|
||||
SUPPORTS_CUSTOM_SCALE = False
|
||||
NAME = "smallkF"
|
||||
|
||||
BACKWARD_ERROR_ATOL: Mapping[torch.dtype, float] = {
|
||||
torch.float: 4e-3,
|
||||
}
|
||||
# as this kernel is a bit slow, this should make tests run faster
|
||||
_TEST_BATCH_SIZES = [1, 3]
|
||||
_TEST_K = [2, 3, 8, 16, 32]
|
||||
|
||||
@classmethod
|
||||
def not_supported_reasons(cls, d: Inputs) -> List[str]:
|
||||
reasons = super(FwOp, cls).not_supported_reasons(d)
|
||||
if isinstance(d.attn_bias, torch.Tensor) and d.attn_bias.stride(1) != 0:
|
||||
reasons.append("bias with non-zero stride not supported")
|
||||
buffer_size = 8
|
||||
k = d.query.shape[-1]
|
||||
for pack in [1, 2, 4]:
|
||||
if (k % pack) == 0 and (k // pack) <= buffer_size:
|
||||
return reasons
|
||||
reasons.append(f"unsupported embed per head: {k}")
|
||||
return reasons
|
||||
|
||||
@classmethod
|
||||
def apply(
|
||||
cls, inp: Inputs, needs_gradient: bool
|
||||
) -> Tuple[torch.Tensor, Optional[Context]]:
|
||||
if inp.scale is not None:
|
||||
raise NotImplementedError("Unsupport custom scale")
|
||||
num_heads = inp.query.shape[2]
|
||||
query = _bmhk2bmk_contiguous(inp.query)
|
||||
key = _bmhk2bmk_contiguous(inp.key)
|
||||
value = _bmhk2bmk_contiguous(inp.value)
|
||||
|
||||
out, lse, rng_seed, rng_offset = cls.OPERATOR(
|
||||
query=query,
|
||||
key=key,
|
||||
value=value,
|
||||
compute_logsumexp=needs_gradient,
|
||||
attn_bias=_get_tensor_bias_bmk(inp.attn_bias),
|
||||
p=inp.p,
|
||||
)
|
||||
out = bmk2bmhk(out, num_heads)
|
||||
lse = lse.reshape([lse.shape[0] // num_heads, num_heads, lse.shape[1]])
|
||||
if not needs_gradient:
|
||||
return out, None
|
||||
ctx = Context(out=out, lse=lse)
|
||||
if inp.p != 0.0:
|
||||
ctx.op_bw = BwOp
|
||||
ctx.rng_state = torch.tensor(
|
||||
[rng_seed, rng_offset], dtype=torch.int64, device="cpu"
|
||||
)
|
||||
return out, ctx
|
||||
|
||||
|
||||
@register_operator
|
||||
class BwOp(AttentionBwOpBase):
|
||||
__doc__ = FwOp.__doc__
|
||||
|
||||
OPERATOR = get_xformers_operator("efficient_attention_backward_small_k")
|
||||
SUPPORTED_DEVICES = FwOp.SUPPORTED_DEVICES
|
||||
SUPPORTED_DTYPES = FwOp.SUPPORTED_DTYPES
|
||||
SUPPORTED_MAX_K = FwOp.SUPPORTED_MAX_K
|
||||
SUPPORTED_ATTN_BIAS_TYPES = FwOp.SUPPORTED_ATTN_BIAS_TYPES
|
||||
SUPPORTS_DROPOUT = FwOp.SUPPORTS_DROPOUT
|
||||
SUPPORTS_CUSTOM_SCALE = FwOp.SUPPORTS_CUSTOM_SCALE
|
||||
SUPPORTS_DIFFERENT_VALUE_EMBED = FwOp.SUPPORTS_DIFFERENT_VALUE_EMBED
|
||||
|
||||
# there is some extra precision loss in the CPU implementation due to an
|
||||
# extra accumulation step in grad_q, which is not present in the CUDA
|
||||
# implementation
|
||||
ERROR_ATOL: Mapping[torch.dtype, float] = {
|
||||
torch.float: 4e-3,
|
||||
}
|
||||
NAME = "smallkB"
|
||||
|
||||
@classmethod
|
||||
def not_supported_reasons(cls, d: Inputs) -> List[str]:
|
||||
reasons = super(BwOp, cls).not_supported_reasons(d)
|
||||
if isinstance(d.attn_bias, torch.Tensor) and d.attn_bias.stride(1) != 0:
|
||||
reasons.append("bias with non-zero stride not supported")
|
||||
buffer_size = 8
|
||||
k = d.query.shape[-1]
|
||||
for pack in [1, 2, 4]:
|
||||
if (k % pack) == 0 and (k // pack) <= buffer_size:
|
||||
return reasons
|
||||
reasons.append(f"unsupported embed per head: {k}")
|
||||
return reasons
|
||||
|
||||
@classmethod
|
||||
def apply(cls, ctx: Context, inp: Inputs, grad: torch.Tensor) -> Gradients:
|
||||
num_heads = grad.shape[2]
|
||||
grad = _bmhk2bmk_contiguous(grad)
|
||||
query = _bmhk2bmk_contiguous(inp.query)
|
||||
key = _bmhk2bmk_contiguous(inp.key)
|
||||
value = _bmhk2bmk_contiguous(inp.value)
|
||||
out = _bmhk2bmk_contiguous(ctx.out)
|
||||
|
||||
rng_seed = rng_offset = 0
|
||||
if inp.p != 0.0:
|
||||
if (
|
||||
ctx.rng_state is None
|
||||
or ctx.rng_state.dtype != torch.int64
|
||||
or ctx.rng_state.device.type != "cpu"
|
||||
or ctx.rng_state.shape != (2,)
|
||||
):
|
||||
raise NotImplementedError(f"Invalid rng_state: {ctx.rng_state}")
|
||||
rng_seed, rng_offset = ctx.rng_state.tolist()
|
||||
grad_q, grad_k, grad_v = cls.OPERATOR(
|
||||
grad,
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
# LSE: BHM -> (BH)M
|
||||
ctx.lse.reshape([-1, ctx.lse.shape[-1]]),
|
||||
out,
|
||||
_get_tensor_bias_bmk(inp.attn_bias),
|
||||
inp.p,
|
||||
rng_seed,
|
||||
rng_offset,
|
||||
)
|
||||
return Gradients(
|
||||
dq=bmk2bmhk(grad_q, num_heads),
|
||||
dk=bmk2bmhk(grad_k, num_heads),
|
||||
dv=bmk2bmhk(grad_v, num_heads),
|
||||
)
|
||||
201
pkgs/xformers/ops/fmha/triton.py
Normal file
201
pkgs/xformers/ops/fmha/triton.py
Normal file
@@ -0,0 +1,201 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the BSD license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
|
||||
from dataclasses import replace
|
||||
from typing import TYPE_CHECKING, Any, List, Optional, Set, Tuple
|
||||
|
||||
import torch
|
||||
|
||||
from ... import _is_triton_available
|
||||
from ..common import register_operator
|
||||
|
||||
# This implementation needs pre-MLIR triton
|
||||
# The BW pass is not stable/well tested
|
||||
# And also does not have the latest improvements
|
||||
if TYPE_CHECKING or (False and _is_triton_available()):
|
||||
try:
|
||||
from flash_attn.flash_attn_triton import (
|
||||
_flash_attn_backward,
|
||||
_flash_attn_forward,
|
||||
)
|
||||
except ImportError:
|
||||
import importlib
|
||||
import pathlib
|
||||
import sys
|
||||
import types
|
||||
|
||||
def import_module_from_path(path: str) -> types.ModuleType:
|
||||
"""Import a module from the given path, w/o __init__.py"""
|
||||
module_path = pathlib.Path(path).resolve()
|
||||
module_name = module_path.stem # 'path/x.py' -> 'x'
|
||||
spec = importlib.util.spec_from_file_location(module_name, module_path) # type: ignore
|
||||
assert isinstance(spec, importlib.machinery.ModuleSpec)
|
||||
module = importlib.util.module_from_spec(spec) # type: ignore
|
||||
sys.modules[module_name] = module
|
||||
assert isinstance(spec.loader, importlib.abc.Loader)
|
||||
spec.loader.exec_module(module)
|
||||
return module
|
||||
|
||||
flash_attn = import_module_from_path(
|
||||
"third_party/flash-attention/flash_attn/flash_attn_triton.py"
|
||||
)
|
||||
_flash_attn_backward = flash_attn._flash_attn_backward
|
||||
_flash_attn_forward = flash_attn._flash_attn_forward
|
||||
|
||||
triton_flash_backward = _flash_attn_backward
|
||||
triton_flash_forward = _flash_attn_forward
|
||||
else:
|
||||
triton_flash_backward = None
|
||||
triton_flash_forward = None
|
||||
|
||||
from .attn_bias import LowerTriangularMask
|
||||
from .common import (
|
||||
AttentionBwOpBase,
|
||||
AttentionFwOpBase,
|
||||
Context,
|
||||
Gradients,
|
||||
Inputs,
|
||||
check_lastdim_alignment_stride1,
|
||||
)
|
||||
|
||||
|
||||
def _prepare_inputs(inp: Inputs) -> Inputs:
|
||||
attn_bias = inp.attn_bias
|
||||
if isinstance(attn_bias, torch.Tensor) and attn_bias.ndim == 3:
|
||||
B = inp.query.shape[0]
|
||||
h = attn_bias.shape[0] // B
|
||||
attn_bias = attn_bias.reshape(B, h, attn_bias.shape[1], attn_bias.shape[2])
|
||||
|
||||
# Make sure that the last dimension is contiguous
|
||||
query, key, value = [
|
||||
x if x.stride(-1) == 1 else x.contiguous()
|
||||
for x in [inp.query, inp.key, inp.value]
|
||||
]
|
||||
return replace(inp, attn_bias=attn_bias, query=query, key=key, value=value)
|
||||
|
||||
|
||||
@register_operator
|
||||
class FwOp(AttentionFwOpBase):
|
||||
"""Operator that computes memory-efficient attention using \
|
||||
`Tri Dao's <https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/flash_attn_triton.py>`_ \
|
||||
implementation, based on
|
||||
`Phil Tillet's code <https://github.com/openai/triton/blob/master/python/tutorials/06-fused-attention.py>`_
|
||||
"""
|
||||
|
||||
OPERATOR = triton_flash_forward
|
||||
SUPPORTED_DEVICES = {"cuda"}
|
||||
CUDA_MINIMUM_COMPUTE_CAPABILITY = (8, 0)
|
||||
SUPPORTED_DTYPES = {torch.half, torch.bfloat16}
|
||||
SUPPORTED_MAX_K = 128
|
||||
SUPPORTED_ATTN_BIAS_TYPES: Set[Any] = {
|
||||
type(None),
|
||||
LowerTriangularMask,
|
||||
# TODO: backwards accuracy is failing for a few cases, perhaps we want to disable this for now.
|
||||
# torch.Tensor,
|
||||
}
|
||||
SUPPORTS_DROPOUT = False
|
||||
SUPPORTS_CUSTOM_SCALE = True
|
||||
NAME = "tritonflashattF"
|
||||
|
||||
@classmethod
|
||||
def not_supported_reasons(cls, d: Inputs) -> List[str]:
|
||||
reasons = super(FwOp, cls).not_supported_reasons(d)
|
||||
check_lastdim_alignment_stride1(reasons, "query", d.query, 8)
|
||||
check_lastdim_alignment_stride1(reasons, "key", d.key, 8)
|
||||
check_lastdim_alignment_stride1(reasons, "value", d.value, 8)
|
||||
if cls.OPERATOR is None:
|
||||
reasons.append("triton is not available")
|
||||
if d.device.type == "cuda":
|
||||
# Has only been tested on 8.0 / 9.0.
|
||||
# Fails on 7.5 with illegal memory access
|
||||
if torch.cuda.get_device_capability(d.device) < (8, 0):
|
||||
reasons.append(
|
||||
"requires GPU with sm80 minimum compute capacity, e.g., A100/H100/L4"
|
||||
)
|
||||
if _is_triton_available():
|
||||
import triton
|
||||
|
||||
if triton.__version__ > "2.0.0":
|
||||
reasons.append("Only work on pre-MLIR triton for now")
|
||||
return reasons
|
||||
|
||||
@classmethod
|
||||
def apply(
|
||||
cls, inp: Inputs, needs_gradient: bool
|
||||
) -> Tuple[torch.Tensor, Optional[Context]]:
|
||||
inp = _prepare_inputs(inp)
|
||||
|
||||
out, lse, softmax_scale = triton_flash_forward(
|
||||
q=inp.query,
|
||||
k=inp.key,
|
||||
v=inp.value,
|
||||
bias=inp.attn_bias if isinstance(inp.attn_bias, torch.Tensor) else None,
|
||||
softmax_scale=inp.scale_float,
|
||||
causal=isinstance(inp.attn_bias, LowerTriangularMask),
|
||||
)
|
||||
return out, Context(lse=lse, out=out)
|
||||
|
||||
|
||||
@register_operator
|
||||
class BwOp(AttentionBwOpBase):
|
||||
__doc__ = FwOp.__doc__
|
||||
|
||||
OPERATOR = triton_flash_backward
|
||||
SUPPORTED_DEVICES = FwOp.SUPPORTED_DEVICES
|
||||
CUDA_MINIMUM_COMPUTE_CAPABILITY = FwOp.CUDA_MINIMUM_COMPUTE_CAPABILITY
|
||||
SUPPORTED_DTYPES = FwOp.SUPPORTED_DTYPES
|
||||
SUPPORTED_MAX_K = FwOp.SUPPORTED_MAX_K
|
||||
SUPPORTED_ATTN_BIAS_TYPES = FwOp.SUPPORTED_ATTN_BIAS_TYPES
|
||||
SUPPORTS_DROPOUT = FwOp.SUPPORTS_DROPOUT
|
||||
SUPPORTS_CUSTOM_SCALE = FwOp.SUPPORTS_CUSTOM_SCALE
|
||||
SUPPORTS_DIFFERENT_VALUE_EMBED = FwOp.SUPPORTS_DIFFERENT_VALUE_EMBED
|
||||
NAME = "tritonflashattB"
|
||||
|
||||
@classmethod
|
||||
def not_supported_reasons(cls, d: Inputs) -> List[str]:
|
||||
reasons = super(BwOp, cls).not_supported_reasons(d)
|
||||
check_lastdim_alignment_stride1(reasons, "query", d.query, 8)
|
||||
check_lastdim_alignment_stride1(reasons, "key", d.key, 8)
|
||||
check_lastdim_alignment_stride1(reasons, "value", d.value, 8)
|
||||
if cls.OPERATOR is None:
|
||||
reasons.append("triton is not available")
|
||||
if d.device.type == "cuda":
|
||||
if torch.cuda.get_device_capability(d.device) != (8, 0):
|
||||
reasons.append("requires A100 GPU")
|
||||
if _is_triton_available():
|
||||
import triton
|
||||
|
||||
if triton.__version__ > "2.0.0":
|
||||
reasons.append("Only work on pre-MLIR triton for now")
|
||||
return reasons
|
||||
|
||||
@classmethod
|
||||
def apply(cls, ctx: Context, inp: Inputs, grad: torch.Tensor) -> Gradients:
|
||||
inp = _prepare_inputs(inp)
|
||||
|
||||
# Triton's autotune causes the Tensor._version to change, and so Pytorch autograd
|
||||
# does a memcpy. To avoid this we run in inference_mode, which doesn't track the version.
|
||||
with torch.inference_mode():
|
||||
grads = Gradients(
|
||||
dq=torch.empty_like(inp.query),
|
||||
dk=torch.empty_like(inp.key),
|
||||
dv=torch.empty_like(inp.value),
|
||||
)
|
||||
cls.OPERATOR(
|
||||
grad,
|
||||
inp.query,
|
||||
inp.key,
|
||||
inp.value,
|
||||
ctx.out,
|
||||
ctx.get_padded_lse(128),
|
||||
grads.dq,
|
||||
grads.dk,
|
||||
grads.dv,
|
||||
bias=inp.attn_bias if isinstance(inp.attn_bias, torch.Tensor) else None,
|
||||
softmax_scale=inp.scale_float,
|
||||
causal=isinstance(inp.attn_bias, LowerTriangularMask),
|
||||
)
|
||||
return grads
|
||||
738
pkgs/xformers/ops/fmha/triton_splitk.py
Normal file
738
pkgs/xformers/ops/fmha/triton_splitk.py
Normal file
@@ -0,0 +1,738 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the BSD license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
|
||||
from typing import TYPE_CHECKING, Any, List, Optional, Set, Tuple
|
||||
|
||||
import torch
|
||||
|
||||
from ..common import _has_triton21, register_operator
|
||||
from .attn_bias import BlockDiagonalCausalWithOffsetPaddedKeysMask
|
||||
from .common import AttentionFwOpBase, Context, Inputs, check_lastdim_alignment_stride1
|
||||
|
||||
|
||||
def _strides(x: torch.Tensor, *stride_names: str):
|
||||
assert x.ndim == len(stride_names)
|
||||
return {f"stride_{s}": x.stride(i) for i, s in enumerate(stride_names)}
|
||||
|
||||
|
||||
if TYPE_CHECKING or _has_triton21():
|
||||
import triton
|
||||
import triton.language as tl
|
||||
|
||||
from xformers.triton.vararg_kernel import VAR_ARGS_ARRAY, unroll_varargs
|
||||
|
||||
@triton.jit
|
||||
def _fwd_kernel_splitK(
|
||||
Q,
|
||||
K,
|
||||
V,
|
||||
sm_scale,
|
||||
Out_splitK, # [B, H, split_k, Mq, K]
|
||||
Metadata, # [B, H, 2, split_k, M_ceil] contains [mi, li]
|
||||
Seq_len,
|
||||
stride_qz,
|
||||
stride_qm,
|
||||
stride_qg,
|
||||
stride_qh,
|
||||
stride_qk,
|
||||
stride_kz,
|
||||
stride_kn,
|
||||
stride_kg,
|
||||
stride_kh,
|
||||
stride_kk,
|
||||
stride_vz,
|
||||
stride_vn,
|
||||
stride_vg,
|
||||
stride_vh,
|
||||
stride_vk,
|
||||
stride_osk_zhg,
|
||||
stride_osk_s,
|
||||
stride_osk_m,
|
||||
stride_osk_k,
|
||||
stride_mzhg,
|
||||
stride_m2,
|
||||
stride_ms,
|
||||
stride_mm,
|
||||
Z,
|
||||
N_CTX_Q,
|
||||
N_CTX_K,
|
||||
BLOCK_N_PER_SPLIT,
|
||||
H: tl.constexpr,
|
||||
G: tl.constexpr,
|
||||
BLOCK_M: tl.constexpr,
|
||||
BLOCK_DMODEL: tl.constexpr,
|
||||
BLOCK_N: tl.constexpr,
|
||||
BOUNDS_CHECKS_N: tl.constexpr,
|
||||
USE_SEQ_LEN: tl.constexpr,
|
||||
PACKED_PER_VAL: tl.constexpr = 1,
|
||||
N_GROUPS: tl.constexpr = 1,
|
||||
):
|
||||
"""This kernel can accept non-quantized or int4-quantized keys/values.
|
||||
PACKED_PER_VAL determines the quantization type:
|
||||
- PACKED_PER_VAL == 1 means no quantization
|
||||
- PACKED_PER_VAL == 8 means 4-bit quantization (8 packed quantized values inside one int32)
|
||||
For the quantized case K/V should be int32 tensors.
|
||||
Quantization can be row-wise (when N_GROUPS = 1) or group-wise with N_GROUPS = 2, 4, or 8.
|
||||
Quantization coefficients are stored at the beginning of the row along the last dimension of K/V
|
||||
So K[B, H, M, :] has a form
|
||||
[ quant_coef0, quant_coef1, ...|
|
||||
group0_quant_value0, group0_quant_value1,... |
|
||||
group1_quant_value0, group1_quant_value1,...]
|
||||
where each quant_coef is an int32 which should be interpreted as 2 packed float16: scale and offset.
|
||||
|
||||
Note: this kernel needs to be processed by xformers.triton.vararg_kernel.unroll_varargs
|
||||
before compilation. That will unroll variables marked with "VAR_ARGS_ARRAY" into lists.
|
||||
See how FwOp.apply does it below.
|
||||
"""
|
||||
tl.static_assert(
|
||||
(PACKED_PER_VAL == 1 and tl.constexpr(K.dtype.element_ty != tl.int32))
|
||||
or (PACKED_PER_VAL == 8 and tl.constexpr(K.dtype.element_ty == tl.int32)),
|
||||
f"Only 4-bit quantization is supported, K/V should have dtype int32 in "
|
||||
f"the quantized case: {PACKED_PER_VAL=} {tl.constexpr(K.dtype)=} {tl.constexpr(K.dtype.element_ty)=}",
|
||||
)
|
||||
tl.static_assert(
|
||||
(((N_GROUPS == 1 or N_GROUPS == 2) or N_GROUPS == 4) or N_GROUPS == 8),
|
||||
"Number of quantization groups can be 1 (row-wise quantization), 2, 4, or 8.",
|
||||
)
|
||||
|
||||
QUANTIZED: tl.constexpr = PACKED_PER_VAL > 1
|
||||
PACKED_D_PER_GROUP: tl.constexpr = BLOCK_DMODEL // PACKED_PER_VAL // N_GROUPS
|
||||
D_PER_GROUP: tl.constexpr = BLOCK_DMODEL // N_GROUPS
|
||||
|
||||
start_m = tl.program_id(0)
|
||||
off_zhg = tl.program_id(1)
|
||||
off_z = off_zhg // (H * G)
|
||||
off_h = (off_zhg // G) % H
|
||||
off_g = off_zhg % G
|
||||
splitk_idx = tl.program_id(2)
|
||||
|
||||
lo = splitk_idx * BLOCK_N_PER_SPLIT
|
||||
if USE_SEQ_LEN:
|
||||
kv_len = tl.load(Seq_len + off_z)
|
||||
else:
|
||||
kv_len = N_CTX_K
|
||||
hi = tl.minimum((splitk_idx + 1) * BLOCK_N_PER_SPLIT, kv_len)
|
||||
|
||||
Q_block_ptr = tl.make_block_ptr(
|
||||
base=Q + off_h * stride_qh + off_z * stride_qz + off_g * stride_qg,
|
||||
shape=(N_CTX_Q, D_PER_GROUP),
|
||||
strides=(stride_qm, stride_qk),
|
||||
offsets=(start_m * BLOCK_M, 0),
|
||||
block_shape=(BLOCK_M, D_PER_GROUP),
|
||||
order=(1, 0),
|
||||
)
|
||||
|
||||
k_base = K + off_h * stride_kh + off_z * stride_kz + off_g * stride_kg
|
||||
# Additional shift by 1 along the last dimension in the quantized case, since
|
||||
# the first element along that dim contains packed quantization coefficients.
|
||||
K_block_ptr = tl.make_block_ptr(
|
||||
base=k_base + stride_kk * QUANTIZED * N_GROUPS,
|
||||
shape=(PACKED_D_PER_GROUP, hi),
|
||||
strides=(stride_kk, stride_kn),
|
||||
offsets=(0, lo),
|
||||
block_shape=(PACKED_D_PER_GROUP, BLOCK_N),
|
||||
order=(0, 1),
|
||||
)
|
||||
v_base = V + off_h * stride_vh + off_z * stride_vz + off_g * stride_vg
|
||||
V_block_ptr = tl.make_block_ptr(
|
||||
base=v_base + stride_vk * QUANTIZED * N_GROUPS,
|
||||
shape=(hi, PACKED_D_PER_GROUP),
|
||||
strides=(stride_vn, stride_vk),
|
||||
offsets=(lo, 0),
|
||||
block_shape=(BLOCK_N, PACKED_D_PER_GROUP),
|
||||
order=(1, 0),
|
||||
)
|
||||
|
||||
if QUANTIZED:
|
||||
# Pointers to quantization coefficients. Even those they are 1D,
|
||||
# we have to use block pointers, since usual pointers
|
||||
# don't support boundary checks
|
||||
K_scale_shift_block_ptr = tl.make_block_ptr(
|
||||
base=k_base,
|
||||
shape=(1, hi),
|
||||
strides=(stride_kk, stride_kn),
|
||||
offsets=(0, lo),
|
||||
block_shape=(1, BLOCK_N),
|
||||
order=(0, 1),
|
||||
)
|
||||
V_scale_shift_block_ptr = tl.make_block_ptr(
|
||||
base=v_base,
|
||||
shape=(hi, 1),
|
||||
strides=(stride_vn, stride_vk),
|
||||
offsets=(lo, 0),
|
||||
block_shape=(BLOCK_N, 1),
|
||||
order=(1, 0),
|
||||
)
|
||||
else:
|
||||
K_scale_shift_block_ptr = None
|
||||
V_scale_shift_block_ptr = None
|
||||
|
||||
# initialize pointer to m and l
|
||||
m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
|
||||
l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
|
||||
|
||||
# Before compilation, this kernel will be processed by xformers.triton.vararg_kernel.unroll_varargs.
|
||||
# That turns tensors annotated as the one below into lists of tensors of length N_GROUPS.
|
||||
# This is a solution for Triton native lack of support for lists of tensors.
|
||||
acc: "VAR_ARGS_ARRAY" # noqa: F821
|
||||
|
||||
for i in range(len(acc)): # noqa: F821
|
||||
acc[i] = tl.zeros([BLOCK_M, D_PER_GROUP], dtype=tl.float32) # noqa: F821
|
||||
# scale sm_scale by log_2(e) and use
|
||||
# 2^x instead of exp in the loop because CSE and LICM
|
||||
# don't work as expected with `exp` in the loop
|
||||
qk_scale = sm_scale * 1.44269504
|
||||
# load q: it will stay in SRAM throughout
|
||||
q: "VAR_ARGS_ARRAY" # noqa: F821
|
||||
for i in range(len(acc)): # noqa: F821
|
||||
q[i] = tl.load( # noqa: F821
|
||||
tl.advance(Q_block_ptr, (0, i * D_PER_GROUP)), boundary_check=(0,)
|
||||
)
|
||||
# loop over k, v and update accumulator
|
||||
for start_n in range(lo, hi, BLOCK_N):
|
||||
k: "VAR_ARGS_ARRAY" # noqa: F821
|
||||
v: "VAR_ARGS_ARRAY" # noqa: F821
|
||||
for i in range(len(acc)): # noqa: F821
|
||||
k[i], v[i] = load_dequantize_k_v_group( # noqa: F821
|
||||
K_block_ptr,
|
||||
V_block_ptr,
|
||||
K_scale_shift_block_ptr,
|
||||
V_scale_shift_block_ptr,
|
||||
BOUNDS_CHECKS_N,
|
||||
PACKED_PER_VAL,
|
||||
PACKED_D_PER_GROUP,
|
||||
Q.dtype.element_ty,
|
||||
i,
|
||||
)
|
||||
|
||||
# -- compute qk ---
|
||||
qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
|
||||
for i in range(len(acc)): # noqa: F821
|
||||
qk += tl.dot(q[i], k[i]) # noqa: F821
|
||||
qk *= qk_scale
|
||||
|
||||
# TODO: This is slow, and only needed at the last iteration.
|
||||
# Maybe we can unroll the last iteration instead?
|
||||
if BOUNDS_CHECKS_N:
|
||||
qk = tl.where(tl.arange(0, BLOCK_N) < hi - start_n, qk, float("-inf"))
|
||||
# -- compute scaling constant ---
|
||||
m_i_new = tl.maximum(m_i, tl.max(qk, 1))
|
||||
alpha = tl.math.exp2(m_i - m_i_new)
|
||||
p = tl.math.exp2(qk - m_i_new[:, None])
|
||||
|
||||
# -- update m_i and l_i --
|
||||
l_i = l_i * alpha + tl.sum(p, 1)
|
||||
m_i = m_i_new
|
||||
p = p.to(Q.dtype.element_ty)
|
||||
|
||||
# -- scale and update acc --
|
||||
for i in range(len(acc)): # noqa: F821
|
||||
acc[i] *= alpha[:, None] # noqa: F821
|
||||
acc[i] += tl.dot(p, v[i]) # noqa: F821
|
||||
# update pointers
|
||||
K_block_ptr = tl.advance(K_block_ptr, (0, BLOCK_N))
|
||||
V_block_ptr = tl.advance(V_block_ptr, (BLOCK_N, 0))
|
||||
if PACKED_PER_VAL > 1:
|
||||
K_scale_shift_block_ptr = tl.advance(
|
||||
K_scale_shift_block_ptr, (0, BLOCK_N)
|
||||
)
|
||||
V_scale_shift_block_ptr = tl.advance(
|
||||
V_scale_shift_block_ptr, (BLOCK_N, 0)
|
||||
)
|
||||
|
||||
# write back O
|
||||
O_block_ptr = tl.make_block_ptr(
|
||||
base=Out_splitK + off_zhg * stride_osk_zhg + splitk_idx * stride_osk_s,
|
||||
shape=(N_CTX_Q, D_PER_GROUP),
|
||||
strides=(stride_osk_m, 1),
|
||||
offsets=(start_m * BLOCK_M, 0),
|
||||
block_shape=(BLOCK_M, D_PER_GROUP),
|
||||
order=(1, 0),
|
||||
)
|
||||
for i in range(len(acc)): # noqa: F821
|
||||
tl.store(
|
||||
tl.advance(O_block_ptr, (0, i * D_PER_GROUP)),
|
||||
acc[i], # noqa: F821
|
||||
boundary_check=(0,),
|
||||
)
|
||||
# Write metadata for split-K reduction
|
||||
Metadata_ptr = (
|
||||
Metadata
|
||||
+ off_zhg * stride_mzhg
|
||||
+ splitk_idx * stride_ms
|
||||
+ start_m * BLOCK_M
|
||||
+ tl.arange(0, BLOCK_M)
|
||||
)
|
||||
tl.store(Metadata_ptr, m_i)
|
||||
tl.store(Metadata_ptr + stride_m2, l_i)
|
||||
|
||||
@triton.jit
|
||||
def load_dequantize_k_v_group(
|
||||
K_block_ptr,
|
||||
V_block_ptr,
|
||||
K_scale_shift_block_ptr,
|
||||
V_scale_shift_block_ptr,
|
||||
BOUNDS_CHECKS_N: tl.constexpr,
|
||||
PACKED_PER_VAL: tl.constexpr,
|
||||
PACKED_D_PER_GROUP: tl.constexpr,
|
||||
dtype: tl.constexpr,
|
||||
group_id: tl.constexpr,
|
||||
):
|
||||
"""Load K/V for a given block. In case of int4-quantized K/V, dequantize them after loading.
|
||||
If quantization is group-wise, use group_id to advance the pointers to the current group.
|
||||
"""
|
||||
# Advance to the current quantization group
|
||||
K_block_ptr = tl.advance(K_block_ptr, (PACKED_D_PER_GROUP * group_id, 0))
|
||||
V_block_ptr = tl.advance(V_block_ptr, (0, PACKED_D_PER_GROUP * group_id))
|
||||
|
||||
# -- load k, v --
|
||||
k = tl.load(K_block_ptr, boundary_check=(1,) if BOUNDS_CHECKS_N else ())
|
||||
v = tl.load(V_block_ptr, boundary_check=(0,) if BOUNDS_CHECKS_N else ())
|
||||
|
||||
if PACKED_PER_VAL > 1:
|
||||
# K/V are quantized, load quantization coefficients and dequantize
|
||||
|
||||
K_scale_shift_block_ptr = tl.advance(K_scale_shift_block_ptr, (group_id, 0))
|
||||
V_scale_shift_block_ptr = tl.advance(V_scale_shift_block_ptr, (0, group_id))
|
||||
|
||||
k_scale_shift = tl.load(
|
||||
K_scale_shift_block_ptr, boundary_check=(1,) if BOUNDS_CHECKS_N else ()
|
||||
)
|
||||
v_scale_shift = tl.load(
|
||||
V_scale_shift_block_ptr, boundary_check=(0,) if BOUNDS_CHECKS_N else ()
|
||||
)
|
||||
|
||||
k_scale, k_shift = cast_uint32_to_half2(k_scale_shift)
|
||||
v_scale, v_shift = cast_uint32_to_half2(v_scale_shift)
|
||||
v = dequantize(v, v_scale, v_shift, PACKED_PER_VAL).to(dtype)
|
||||
k_t = dequantize(
|
||||
tl.trans(k),
|
||||
tl.trans(k_scale),
|
||||
tl.trans(k_shift),
|
||||
PACKED_PER_VAL,
|
||||
).to(dtype)
|
||||
k = tl.trans(k_t)
|
||||
return k, v
|
||||
|
||||
@triton.jit
|
||||
def cast_uint32_to_half2(scale_shift):
|
||||
"""Extract two float16 packed into one int32"""
|
||||
scale = scale_shift & 0xFFFF
|
||||
shift = scale_shift >> 16
|
||||
scale = scale.to(tl.uint16).to(tl.float16, bitcast=True)
|
||||
shift = shift.to(tl.uint16).to(tl.float16, bitcast=True)
|
||||
return scale, shift
|
||||
|
||||
@triton.jit
|
||||
def dequantize(
|
||||
x_,
|
||||
scale,
|
||||
shift,
|
||||
PACKED_PER_VAL: tl.constexpr = 8,
|
||||
):
|
||||
"""PACKED_PER_VAL is the number of values packed into each element x_.
|
||||
For example, for int4 quantization and x_ of type int32, PACKED_PER_VAL is 8.
|
||||
"""
|
||||
|
||||
# Axis along which offsets are applied matters here
|
||||
# It would be natural to have offsets in shape (BLOCK_N, D // PACKED_PER_VAL, PACKED_PER_VAL)
|
||||
# and expand K/V to that shape before applying offsets
|
||||
# However, Triton for some reason considers dim=1 as contiguous when doing tl.view below, and not dim=2
|
||||
# Note that tl.view doesn't guarantee the order of elements in the result - thus the code below depends
|
||||
# on the implementation details which might change in the future.
|
||||
# Ideally we would like to use tl.reshape, but it's not implemented yet.
|
||||
# See https://github.com/openai/triton/blob/9055af1a5dadc576804b38dd77ee91dc42af0bf7/python/triton/language/semantic.py#L541 # noqa: E501
|
||||
|
||||
# x_ : (BLOCK_N, D // PACKED_PER_VAL)
|
||||
# scale: (BLOCK_N, 1)
|
||||
# offsets: (PACKED_PER_VAL,)
|
||||
BLOCK_N: tl.constexpr = x_.shape[0]
|
||||
BLOCK_DMODEL_PACKED: tl.constexpr = x_.shape[1]
|
||||
offsets = tl.arange(0, PACKED_PER_VAL) * 4
|
||||
quant_offset = (
|
||||
x_[:, None, :] >> offsets[None, :, None]
|
||||
) # (BLOCK_N, PACKED_PER_VAL, D // PACKED_PER_VAL)
|
||||
|
||||
quant_offset = tl.view(
|
||||
quant_offset, (BLOCK_N, BLOCK_DMODEL_PACKED * PACKED_PER_VAL)
|
||||
)
|
||||
# Trick - instead of converting int4 to float16 we view it as float16
|
||||
# and then multiply by 32768 * 512 == 2**24
|
||||
quant_offset = (quant_offset & 0xF).to(tl.uint16).to(tl.float16, bitcast=True)
|
||||
quant_offset = (quant_offset * 32768.0).to(tl.float16)
|
||||
scale_512 = scale * 512
|
||||
|
||||
dequant = quant_offset * scale_512 + shift
|
||||
return dequant
|
||||
|
||||
@triton.jit
|
||||
def _splitK_reduce(
|
||||
Out_splitK, # [B, H, split_k, Mq, K]
|
||||
Metadata, # [B, H, 2, split_k, M_ceil] contains [mi, li]
|
||||
Out, # [B, H, M, K]
|
||||
LSE, # [B, H, M]
|
||||
split_k,
|
||||
stride_osk_zhg,
|
||||
stride_osk_s,
|
||||
stride_osk_m,
|
||||
stride_osk_k,
|
||||
stride_mzhg,
|
||||
stride_m2,
|
||||
stride_ms,
|
||||
stride_mm,
|
||||
stride_oz,
|
||||
stride_oh,
|
||||
stride_og,
|
||||
stride_om,
|
||||
stride_ok,
|
||||
stride_lse_zhg,
|
||||
stride_lse_m,
|
||||
BLOCK_SIZE: tl.constexpr,
|
||||
H: tl.constexpr,
|
||||
G: tl.constexpr,
|
||||
):
|
||||
off_zhg = tl.program_id(0)
|
||||
off_z = off_zhg // (H * G)
|
||||
off_h = (off_zhg // G) % H
|
||||
off_g = off_zhg % G
|
||||
off_m = tl.program_id(1)
|
||||
|
||||
Out_splitK_ptr = (
|
||||
Out_splitK
|
||||
+ stride_osk_zhg * off_zhg
|
||||
+ stride_osk_m * off_m
|
||||
+ tl.arange(0, BLOCK_SIZE)
|
||||
)
|
||||
Metadata_ptr = Metadata + stride_mzhg * off_zhg + off_m
|
||||
m = tl.load(Metadata_ptr)
|
||||
l_sum = tl.load(Metadata_ptr + stride_m2)
|
||||
acc = tl.load(Out_splitK_ptr)
|
||||
|
||||
for split_k_idx in range(1, split_k):
|
||||
Metadata_ptr = Metadata_ptr + stride_ms
|
||||
Out_splitK_ptr = Out_splitK_ptr + stride_osk_s
|
||||
|
||||
m_k = tl.load(Metadata_ptr)
|
||||
l_k = tl.load(Metadata_ptr + stride_m2)
|
||||
acc_k = tl.load(Out_splitK_ptr)
|
||||
|
||||
m_new = tl.maximum(m, m_k)
|
||||
if m_k < m:
|
||||
# Scale incoming values
|
||||
alpha = tl.math.exp2(m_k - m_new)
|
||||
acc_k = acc_k * alpha
|
||||
l_k = l_k * alpha
|
||||
else:
|
||||
# Scale our values
|
||||
alpha = tl.math.exp2(m - m_new)
|
||||
acc = acc * alpha
|
||||
l_sum = l_sum * alpha
|
||||
|
||||
m = m_new
|
||||
l_sum = l_sum + l_k
|
||||
acc = acc + acc_k
|
||||
|
||||
acc = acc / l_sum
|
||||
Out_ptr = (
|
||||
Out
|
||||
+ stride_oz * off_z
|
||||
+ stride_oh * off_h
|
||||
+ stride_og * off_g
|
||||
+ stride_om * off_m
|
||||
+ tl.arange(0, BLOCK_SIZE)
|
||||
)
|
||||
tl.store(Out_ptr, acc)
|
||||
|
||||
l_ptrs = LSE + off_zhg * stride_lse_zhg + off_m
|
||||
tl.store(l_ptrs, (m + tl.math.log2(l_sum)) / 1.44269504)
|
||||
|
||||
else:
|
||||
_fwd_kernel_splitK = None
|
||||
_splitK_reduce = None
|
||||
|
||||
|
||||
@register_operator
|
||||
class FwOp(AttentionFwOpBase):
|
||||
"""Flash-Attention with Split-K. Supports fused int-4 K/V quantization.
|
||||
Quantized path will be taken if input K/V have type int32.
|
||||
|
||||
Quantization can be row-wise or group-wise (when cls.NUM_GROUPS > 1) along
|
||||
the last dimension of K and V. Currently 1, 2, 4, or 8 groups per row are supported.
|
||||
Quantization coefficients (scale and shift) are represented as two
|
||||
float16 constants per group, packed into int32. Quantization coefficients of
|
||||
all groups are placed at the beginning of the row. So, if unquantized K/V have head
|
||||
dimension D, the quantized versions have head dimension D // 8 + NUM_GROUPS
|
||||
and dtype int32.
|
||||
Pseudocode for dequantizing one row can look like:
|
||||
group_size = D // 8
|
||||
for i in range(NUM_GROUPS):
|
||||
group_start = NUM_GROUPS + i * group_size
|
||||
group_quant = K[..., group_start: group_start + group_size]
|
||||
scale, shift = unpack_int32_into_float16x2(group_quant[0])
|
||||
group_dequant = group_quant[..., 1:] * scale + shift
|
||||
...
|
||||
|
||||
"""
|
||||
|
||||
OPERATOR = _fwd_kernel_splitK
|
||||
SUPPORTED_DEVICES = {"cuda"}
|
||||
CUDA_MINIMUM_COMPUTE_CAPABILITY = (8, 0)
|
||||
SUPPORTED_DTYPES = {
|
||||
torch.half,
|
||||
torch.bfloat16,
|
||||
} # Those are dtypes of Q. In the quantized case K/V has dtype int32
|
||||
SUPPORTED_MAX_K = 128
|
||||
SUPPORTED_ATTN_BIAS_TYPES: Set[Any] = {
|
||||
type(None),
|
||||
BlockDiagonalCausalWithOffsetPaddedKeysMask,
|
||||
}
|
||||
SUPPORTS_DROPOUT = False
|
||||
SUPPORTS_CUSTOM_SCALE = True
|
||||
SUPPORTS_BMGHK = True
|
||||
NAME = "triton_splitKF"
|
||||
|
||||
SPLIT_K: Optional[int] = None
|
||||
BLOCK_M = 16
|
||||
BLOCK_N = 64
|
||||
|
||||
NUM_GROUPS = 1 # Default quantization is row-wise
|
||||
|
||||
@classmethod
|
||||
def shape_not_supported_reasons(
|
||||
cls, Mq: int, Mkv: int, K: int, Kv: int
|
||||
) -> List[str]:
|
||||
reasons = super().shape_not_supported_reasons(Mq, Mkv, K, Kv)
|
||||
if K not in {16, 32, 64, 128}:
|
||||
reasons.append(f"Embed dim {K} not supported")
|
||||
return reasons
|
||||
|
||||
@classmethod
|
||||
def not_supported_reasons(cls, d: Inputs) -> List[str]:
|
||||
reasons = super(FwOp, cls).not_supported_reasons(d)
|
||||
check_lastdim_alignment_stride1(reasons, "query", d.query, 8)
|
||||
if d.key.dtype != torch.int32:
|
||||
check_lastdim_alignment_stride1(reasons, "key", d.key, 8)
|
||||
check_lastdim_alignment_stride1(reasons, "value", d.value, 8)
|
||||
if cls.OPERATOR is None:
|
||||
reasons.append("triton is not available")
|
||||
if d.device.type == "cuda":
|
||||
# Has only been tested on 8.0 / 9.0.
|
||||
if torch.cuda.get_device_capability(d.device) < (8, 0):
|
||||
reasons.append(
|
||||
"requires GPU with sm80 minimum compute capacity, e.g., A100/H100/L4"
|
||||
)
|
||||
|
||||
q_len = d.query.shape[1]
|
||||
if isinstance(d.attn_bias, BlockDiagonalCausalWithOffsetPaddedKeysMask):
|
||||
seqinfo = d.attn_bias.q_seqinfo
|
||||
if q_len != seqinfo.seqstart_py[-1]:
|
||||
reasons.append(
|
||||
f"Expected total {seqinfo.seqstart_py[-1]} queries not {q_len}"
|
||||
)
|
||||
q_len = seqinfo.min_seqlen
|
||||
if q_len != seqinfo.max_seqlen:
|
||||
reasons.append(
|
||||
"Variable query len is not supported in the presence of causal mask."
|
||||
)
|
||||
|
||||
if d.key.ndim in [4, 5] and d.key.shape[-2] != 1:
|
||||
if d.key.stride(-2) == 0 and d.value.stride(-2) == 0 and q_len > 1:
|
||||
reasons.append("multiquery is only supported with query seqlen=1")
|
||||
|
||||
if d.attn_bias is not None and q_len > 1:
|
||||
reasons.append(
|
||||
"query with seqlen > 1 is not supported in the presence of causal mask"
|
||||
)
|
||||
return reasons
|
||||
|
||||
@classmethod
|
||||
def get_split_k(cls, B: int, H: int, Mk: int) -> int:
|
||||
"""Heuristic for the number of splits"""
|
||||
bh = B * H
|
||||
split_k = max(Mk, 1024) // bh
|
||||
max_chunk_size = 64 if Mk <= 512 and bh <= 64 else 128
|
||||
while split_k > 0 and Mk / split_k < max_chunk_size:
|
||||
split_k = split_k // 2
|
||||
split_k = min(split_k, 64)
|
||||
split_k = max(split_k, 1)
|
||||
return split_k
|
||||
|
||||
@classmethod
|
||||
def apply(
|
||||
cls, inp: Inputs, needs_gradient: bool
|
||||
) -> Tuple[torch.Tensor, Optional[Context]]:
|
||||
attn_bias = inp.attn_bias
|
||||
seq_len = None
|
||||
q, k, v = inp.get_qkv_in_bmghk()
|
||||
|
||||
if attn_bias is not None:
|
||||
assert isinstance(attn_bias, BlockDiagonalCausalWithOffsetPaddedKeysMask)
|
||||
# TODO: do we really need to do this cast? seems fishy but
|
||||
# I just copied it from the decoder.py
|
||||
attn_bias.k_seqinfo.to(inp.query.device)
|
||||
attn_bias.q_seqinfo.to(inp.query.device)
|
||||
seq_len = attn_bias.k_seqinfo.seqlen
|
||||
B = len(seq_len)
|
||||
G, H, Kq = q.shape[-3:]
|
||||
Kkv = v.shape[-1]
|
||||
|
||||
# assume kv has been padded
|
||||
q = q.reshape(B, -1, G, H, Kq)
|
||||
k = k.reshape(B, -1, G, H, Kkv)
|
||||
v = v.reshape(B, -1, G, H, Kkv)
|
||||
|
||||
# Transpose in the case of MQA/GQA
|
||||
mqa_swap_seqlen_head = False
|
||||
if k.shape[3] > 1 and k.stride(3) == 0 and v.stride(3) == 0:
|
||||
mqa_swap_seqlen_head = True
|
||||
assert q.shape[1] == 1
|
||||
q = q.transpose(1, 3)
|
||||
k = k[:, :, :, :1]
|
||||
v = v[:, :, :, :1]
|
||||
|
||||
if k.dtype == torch.int32:
|
||||
# Quantized K/V
|
||||
PACKED_PER_VAL = 8
|
||||
Lk = (k.shape[-1] - cls.NUM_GROUPS) * 8
|
||||
else:
|
||||
Lk = k.shape[-1]
|
||||
PACKED_PER_VAL = 1
|
||||
|
||||
B, Mk, G, H, Kkv = k.shape
|
||||
B, M, G, H, Kq = q.shape
|
||||
assert Lk == Kq, f"Keys have head dim {Lk} but queries have head dim {Kq}"
|
||||
|
||||
BLOCK_M = cls.BLOCK_M
|
||||
BLOCK_N = cls.BLOCK_N
|
||||
if cls.SPLIT_K is not None:
|
||||
split_k = cls.SPLIT_K
|
||||
else:
|
||||
# Use heuristics
|
||||
split_k = cls.get_split_k(B, H, Mk)
|
||||
|
||||
M_ceil = (M + BLOCK_M - 1) // BLOCK_M * BLOCK_M
|
||||
o_splitk = torch.empty(
|
||||
[B * G * H, split_k, M_ceil, Kq], dtype=torch.float32, device=q.device
|
||||
)
|
||||
metadata = torch.empty(
|
||||
[B * G * H, 2, split_k, M_ceil], dtype=torch.float32, device=q.device
|
||||
)
|
||||
lse = torch.empty((B * G * H, M), device=q.device, dtype=torch.float32)
|
||||
grid = (triton.cdiv(M, BLOCK_M), B * G * H, split_k)
|
||||
|
||||
num_warps = 2
|
||||
split_size = (Mk + split_k - 1) // split_k
|
||||
use_seq_len = seq_len is not None
|
||||
_fwd_kernel_splitK_unrolled = unroll_varargs(
|
||||
_fwd_kernel_splitK, N=cls.NUM_GROUPS if PACKED_PER_VAL > 1 else 1
|
||||
)
|
||||
|
||||
_fwd_kernel_splitK_unrolled[grid](
|
||||
Q=q,
|
||||
K=k,
|
||||
V=v,
|
||||
sm_scale=inp.scale_float,
|
||||
Out_splitK=o_splitk,
|
||||
Metadata=metadata,
|
||||
Seq_len=seq_len,
|
||||
**_strides(q, "qz", "qm", "qg", "qh", "qk"),
|
||||
**_strides(k, "kz", "kn", "kg", "kh", "kk"),
|
||||
**_strides(v, "vz", "vn", "vg", "vh", "vk"),
|
||||
**_strides(o_splitk, "osk_zhg", "osk_s", "osk_m", "osk_k"),
|
||||
**_strides(metadata, "mzhg", "m2", "ms", "mm"),
|
||||
Z=B,
|
||||
H=H,
|
||||
G=G,
|
||||
N_CTX_Q=M,
|
||||
N_CTX_K=Mk,
|
||||
BLOCK_N_PER_SPLIT=split_size,
|
||||
BLOCK_M=BLOCK_M,
|
||||
BLOCK_N=BLOCK_N,
|
||||
BLOCK_DMODEL=Lk,
|
||||
BOUNDS_CHECKS_N=(split_size % BLOCK_N) > 0 or use_seq_len,
|
||||
USE_SEQ_LEN=use_seq_len,
|
||||
num_warps=num_warps,
|
||||
num_stages=1,
|
||||
PACKED_PER_VAL=PACKED_PER_VAL,
|
||||
N_GROUPS=cls.NUM_GROUPS if PACKED_PER_VAL > 1 else 1,
|
||||
)
|
||||
|
||||
if mqa_swap_seqlen_head:
|
||||
out = torch.empty(
|
||||
(B, H, G, M, Kq), device=q.device, dtype=q.dtype
|
||||
).transpose(1, 3)
|
||||
else:
|
||||
out = torch.empty((B, M, G, H, Kq), device=q.device, dtype=q.dtype)
|
||||
|
||||
# Merge together
|
||||
grid = (B * G * H, M, 1)
|
||||
_splitK_reduce[grid](
|
||||
o_splitk,
|
||||
metadata,
|
||||
out,
|
||||
lse,
|
||||
split_k=split_k,
|
||||
**_strides(o_splitk, "osk_zhg", "osk_s", "osk_m", "osk_k"),
|
||||
**_strides(metadata, "mzhg", "m2", "ms", "mm"),
|
||||
**_strides(out, "oz", "om", "og", "oh", "ok"),
|
||||
**_strides(lse, "lse_zhg", "lse_m"),
|
||||
BLOCK_SIZE=out.shape[-1],
|
||||
G=G,
|
||||
H=H,
|
||||
# TODO: Tune num_warps
|
||||
)
|
||||
lse = lse.reshape([B, G, H, M])
|
||||
if mqa_swap_seqlen_head:
|
||||
# H/M dimensions have been swapped
|
||||
out = out.transpose(1, 3)
|
||||
lse = lse.transpose(2, 3)
|
||||
if inp.query.ndim == 4:
|
||||
# BMGHK -> BMHK
|
||||
assert G == 1
|
||||
out = out[:, :, 0]
|
||||
lse = lse[:, 0]
|
||||
|
||||
return out, Context(out=out, lse=lse)
|
||||
|
||||
|
||||
class FwOp_S1(FwOp):
|
||||
SPLIT_K = 1
|
||||
NAME = "triton_splitK1"
|
||||
|
||||
|
||||
class FwOp_S2(FwOp):
|
||||
SPLIT_K = 2
|
||||
NAME = "triton_splitK2"
|
||||
|
||||
|
||||
class FwOp_S4(FwOp):
|
||||
SPLIT_K = 4
|
||||
NAME = "triton_splitK4"
|
||||
|
||||
|
||||
class FwOp_S8(FwOp):
|
||||
SPLIT_K = 8
|
||||
NAME = "triton_splitK8"
|
||||
|
||||
|
||||
class FwOp_S16(FwOp):
|
||||
SPLIT_K = 16
|
||||
NAME = "triton_splitK16"
|
||||
|
||||
|
||||
class FwOp_S32(FwOp):
|
||||
SPLIT_K = 32
|
||||
NAME = "triton_splitK32"
|
||||
|
||||
|
||||
class FwOp_S64(FwOp):
|
||||
SPLIT_K = 64
|
||||
NAME = "triton_splitK64"
|
||||
|
||||
|
||||
class FwOp_S128(FwOp):
|
||||
SPLIT_K = 128
|
||||
NAME = "triton_splitK128"
|
||||
223
pkgs/xformers/ops/indexing.py
Normal file
223
pkgs/xformers/ops/indexing.py
Normal file
@@ -0,0 +1,223 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the BSD license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
|
||||
from typing import Optional, Sequence
|
||||
|
||||
import torch
|
||||
|
||||
from .common import BaseOperator, get_xformers_operator, register_operator
|
||||
|
||||
|
||||
@register_operator
|
||||
class ScaledIndexAddFw(BaseOperator):
|
||||
OPERATOR = get_xformers_operator("scaled_index_addF")
|
||||
OPERATOR_CATEGORY = "indexing"
|
||||
NAME = "scaled_index_addF"
|
||||
|
||||
|
||||
@register_operator
|
||||
class ScaledIndexAddBw(BaseOperator):
|
||||
OPERATOR = get_xformers_operator("scaled_index_addB")
|
||||
OPERATOR_CATEGORY = "indexing"
|
||||
NAME = "scaled_index_addB"
|
||||
|
||||
|
||||
@register_operator
|
||||
class IndexSelect(BaseOperator):
|
||||
OPERATOR = get_xformers_operator("index_select")
|
||||
OPERATOR_CATEGORY = "indexing"
|
||||
NAME = "index_select"
|
||||
|
||||
|
||||
class _ScaledIndexAdd(torch.autograd.Function):
|
||||
@staticmethod
|
||||
# type: ignore
|
||||
def forward(
|
||||
ctx,
|
||||
input: torch.Tensor,
|
||||
index: torch.Tensor,
|
||||
source: torch.Tensor,
|
||||
scaling: Optional[torch.Tensor],
|
||||
alpha: float,
|
||||
) -> torch.Tensor:
|
||||
ScaledIndexAddFw.OPERATOR(
|
||||
output=input, # in-place
|
||||
input=input,
|
||||
source=source,
|
||||
index=index,
|
||||
source_scaling=scaling,
|
||||
alpha=alpha,
|
||||
)
|
||||
ctx.mark_dirty(input)
|
||||
ctx.save_for_backward(index, scaling, source)
|
||||
ctx.source_shape = source.shape
|
||||
ctx.alpha = alpha
|
||||
return input
|
||||
|
||||
@staticmethod
|
||||
@torch.autograd.function.once_differentiable
|
||||
def backward(ctx, grad_output):
|
||||
index, scaling, source = ctx.saved_tensors
|
||||
grad_source = torch.empty_like(grad_output[: index.shape[0]])
|
||||
grad_source_scaling = (
|
||||
torch.empty(
|
||||
ctx.source_shape,
|
||||
dtype=scaling.dtype,
|
||||
device=scaling.device,
|
||||
)
|
||||
if scaling is not None
|
||||
else None
|
||||
)
|
||||
ScaledIndexAddBw.OPERATOR(
|
||||
grad_source=grad_source,
|
||||
grad_source_scaling=grad_source_scaling,
|
||||
grad_output=grad_output,
|
||||
source=source,
|
||||
index=index,
|
||||
source_scaling=scaling,
|
||||
alpha=ctx.alpha,
|
||||
)
|
||||
if grad_source_scaling is not None:
|
||||
grad_source_scaling = grad_source_scaling.sum((0, 1))
|
||||
return (
|
||||
grad_output, # input
|
||||
None, # index
|
||||
grad_source, # source
|
||||
grad_source_scaling, # scaling
|
||||
None, # alpha
|
||||
)
|
||||
|
||||
|
||||
def scaled_index_add(
|
||||
input: torch.Tensor, # [B, M, D]
|
||||
index: torch.Tensor, # [Bi] - int64
|
||||
source: torch.Tensor, # [Bi, M, D]
|
||||
scaling: Optional[torch.Tensor] = None, # [D]
|
||||
alpha: float = 1.0,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
In-place scaling+index_add
|
||||
|
||||
Indices in ``index`` are assumed to be unique
|
||||
|
||||
:Note:
|
||||
|
||||
The FW pass is done in-place (``input`` is modified)
|
||||
|
||||
:Note:
|
||||
|
||||
This is experimental and has only been optimized for a few shapes
|
||||
|
||||
:Equivalent pytorch code:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
return torch.index_add(inp, dim=0, source=scaling * src, index=indices, alpha=alpha)
|
||||
"""
|
||||
return _ScaledIndexAdd.apply(
|
||||
input,
|
||||
index,
|
||||
source,
|
||||
scaling,
|
||||
alpha,
|
||||
)
|
||||
|
||||
|
||||
class _IndexSelectCat(torch.autograd.Function):
|
||||
@staticmethod
|
||||
# type: ignore
|
||||
def forward(
|
||||
ctx,
|
||||
*args: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
assert len(args) % 2 == 0
|
||||
sources = args[: len(args) // 2]
|
||||
indices = args[len(args) // 2 :]
|
||||
output_shape = 0
|
||||
total_source_elements = 0
|
||||
for source, index in zip(sources, indices):
|
||||
output_shape += index.shape[0] * source.shape[1]
|
||||
total_source_elements += source.shape[0] * source.shape[1]
|
||||
output = torch.empty(
|
||||
[output_shape], dtype=sources[0].dtype, device=sources[0].device
|
||||
)
|
||||
output_i = 0
|
||||
for source, index in zip(sources, indices):
|
||||
elements_here = index.shape[0] * source.shape[1]
|
||||
IndexSelect.OPERATOR(
|
||||
output=output[output_i : output_i + elements_here].view(
|
||||
[index.shape[0], source.shape[1]]
|
||||
),
|
||||
source=source,
|
||||
index=index,
|
||||
)
|
||||
output_i += elements_here
|
||||
ctx.save_for_backward(*indices)
|
||||
ctx.total_source_elements = total_source_elements
|
||||
ctx.source_shapes = [s.shape for s in sources]
|
||||
return output
|
||||
|
||||
@staticmethod
|
||||
@torch.autograd.function.once_differentiable
|
||||
def backward(ctx, grad_output):
|
||||
indices = ctx.saved_tensors
|
||||
grad_sources = torch.zeros(
|
||||
[ctx.total_source_elements],
|
||||
dtype=grad_output.dtype,
|
||||
device=grad_output.device,
|
||||
)
|
||||
grad_sources_i = 0
|
||||
grad_output_i = 0
|
||||
gradients = []
|
||||
for source_shape, index in zip(ctx.source_shapes, indices):
|
||||
grad_output_slice = grad_output[
|
||||
grad_output_i : grad_output_i + index.shape[0] * source_shape[1]
|
||||
].reshape([index.shape[0], source_shape[1]])
|
||||
grad_output_i += index.shape[0] * source_shape[1]
|
||||
|
||||
gradient_source = grad_sources[
|
||||
grad_sources_i : grad_sources_i + source_shape[0] * source_shape[1]
|
||||
].reshape(source_shape)
|
||||
grad_sources_i += source_shape[0] * source_shape[1]
|
||||
|
||||
ScaledIndexAddFw.OPERATOR(
|
||||
output=gradient_source.unsqueeze(1),
|
||||
input=None,
|
||||
source=grad_output_slice.unsqueeze(1),
|
||||
index=index,
|
||||
source_scaling=None,
|
||||
alpha=1.0,
|
||||
)
|
||||
gradients.append(gradient_source)
|
||||
return (*gradients, *([None] * len(gradients)))
|
||||
|
||||
|
||||
def index_select_cat(
|
||||
sources: Sequence[torch.Tensor], indices: Sequence[torch.Tensor]
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Indices in ``index`` are assumed to be unique
|
||||
|
||||
:Note:
|
||||
|
||||
This is experimental and has only been optimized for a few shapes
|
||||
|
||||
:Example:
|
||||
|
||||
Given:
|
||||
- ``sources[0]`` of shape ``[S0, D0]``
|
||||
- ``indices[0]`` of shape ``[I0]``
|
||||
- ``sources[1]`` of shape ``[S1, D1]``
|
||||
- ``indices[1]`` of shape ``[I1]``
|
||||
returns a ``torch.Tensor`` of shape ``[I0 * D0 + I1 * D1]``
|
||||
|
||||
:Equivalent pytorch code:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
return torch.cat([s[i.long()].flatten() for s, i in zip(sources, indices)], dim=0)
|
||||
"""
|
||||
return _IndexSelectCat.apply(*sources, *indices)
|
||||
113
pkgs/xformers/ops/rmsnorm.py
Normal file
113
pkgs/xformers/ops/rmsnorm.py
Normal file
@@ -0,0 +1,113 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the BSD license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from .. import _is_triton_available
|
||||
|
||||
|
||||
def rms_norm(x, weight: Optional[torch.Tensor], eps: float = 1e-6):
|
||||
"""
|
||||
RMS Normalization along the last dimension.
|
||||
|
||||
This is similar to torch.nn.functional.normalize but with eps being added
|
||||
instead of max.
|
||||
|
||||
Expects x contiguous of shape (..., dim), and returns normalized data
|
||||
of the same shape. For each dim-length vector x, the result has
|
||||
|
||||
x / sqrt( x*x.sum() + eps)
|
||||
|
||||
If weights are included, they are a contiguous parameter of length dim
|
||||
which multiplies the result.
|
||||
|
||||
This functionality is experimental. Its API might be changed without warnings.
|
||||
Use it at your own risk.
|
||||
"""
|
||||
assert _is_triton_available()
|
||||
from .triton.rmsnorm_kernels import _rms_norm_forward
|
||||
|
||||
if torch.is_grad_enabled() and (
|
||||
x.requires_grad or (weight is not None and weight.requires_grad)
|
||||
):
|
||||
raise ValueError("Gradients not supported.")
|
||||
|
||||
return _rms_norm_forward(x, weight, eps)
|
||||
|
||||
|
||||
def rms_norm_add(
|
||||
x: torch.Tensor, y: torch.Tensor, weight: Optional[torch.Tensor], eps: float = 1e-6
|
||||
):
|
||||
"""
|
||||
An addition fused with rms_norm.
|
||||
|
||||
z = rms_norm_add(x, y, weight, eps)
|
||||
|
||||
is equivalent to
|
||||
|
||||
x += y
|
||||
z = rms_norm(x, weight, eps)
|
||||
|
||||
where x, y and z are all contiguous.
|
||||
|
||||
This functionality is experimental. Its API might be changed without warnings.
|
||||
Use it at your own risk.
|
||||
"""
|
||||
if torch.is_grad_enabled() and (
|
||||
x.requires_grad
|
||||
or y.requires_grad
|
||||
or (weight is not None and weight.requires_grad)
|
||||
):
|
||||
raise ValueError("Gradients not supported.")
|
||||
assert _is_triton_available()
|
||||
from .triton.rmsnorm_kernels import _rms_norm_add_forward
|
||||
|
||||
return _rms_norm_add_forward(x, y, weight, eps)
|
||||
|
||||
|
||||
class RMSNorm(torch.nn.Module):
|
||||
"""
|
||||
RMS Normalization layer along the last dimension.
|
||||
|
||||
This is similar to torch.nn.functional.normalize but with eps being added
|
||||
instead of max.
|
||||
|
||||
Expects contiguous input of shape (..., dim), and returns normalized data
|
||||
of the same shape. For each dim-length vector x, the result has
|
||||
|
||||
x / sqrt( x*x.sum() + eps)
|
||||
|
||||
If weights are included, they are a parameter of length dim which multiplies
|
||||
the result.
|
||||
|
||||
This functionality is experimental. Its API might be changed without warnings.
|
||||
Use it at your own risk.
|
||||
"""
|
||||
|
||||
def __init__(self, dim: int, include_weight: bool = True, eps: float = 1e-6):
|
||||
super().__init__()
|
||||
self.eps = eps
|
||||
if include_weight:
|
||||
self.weight: Optional[nn.Parameter] = nn.Parameter(torch.ones(dim))
|
||||
else:
|
||||
self.weight = None
|
||||
|
||||
def forward(self, x: torch.Tensor):
|
||||
return rms_norm(x, self.weight, self.eps) # type: ignore
|
||||
|
||||
def increment_and_forward_(self, x: torch.Tensor, y: torch.Tensor):
|
||||
"""
|
||||
An addition fused with forward.
|
||||
|
||||
z = layer.increment_and_forward_(x, y)
|
||||
|
||||
is equivalent to
|
||||
|
||||
x += y
|
||||
z = layer(x)
|
||||
"""
|
||||
return rms_norm_add(x, y, self.weight, self.eps) # type: ignore
|
||||
188
pkgs/xformers/ops/rope_padded.py
Normal file
188
pkgs/xformers/ops/rope_padded.py
Normal file
@@ -0,0 +1,188 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the BSD license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import torch
|
||||
|
||||
from xformers.ops.fmha.attn_bias import ( # type: ignore
|
||||
BlockDiagonalCausalWithOffsetPaddedKeysMask,
|
||||
)
|
||||
|
||||
from .. import _is_triton_available
|
||||
|
||||
|
||||
def rope_padded(
|
||||
xq: torch.Tensor,
|
||||
xk: torch.Tensor,
|
||||
xv: torch.Tensor,
|
||||
cache_k: torch.Tensor,
|
||||
cache_v: torch.Tensor,
|
||||
attn_bias: BlockDiagonalCausalWithOffsetPaddedKeysMask,
|
||||
*,
|
||||
theta: float = 10000.0,
|
||||
out_q: Optional[torch.Tensor] = None,
|
||||
adjacents: bool = True,
|
||||
internal_dtype: str = "",
|
||||
):
|
||||
"""
|
||||
Performs RoPE (rotary embeddings) and kv-cache emplacement for a heterogeneous
|
||||
batch for inference in the style given by
|
||||
BlockDiagonalCausalWithOffsetPaddedKeysMask.
|
||||
The batch is concatted along the sequence dimension, so the
|
||||
actual dim-0 length of all tensors is 1.
|
||||
|
||||
xq, xk and xv should be (1, slen, n_heads, dim), where
|
||||
xq's n_heads can differ from xk and xv.
|
||||
|
||||
This function places the roped xk in the right place in cache_k, and
|
||||
xv (unmodified) in the right place in cache_v, and returns out_q
|
||||
(the roped xq) such that things are ready to call
|
||||
|
||||
xformers.ops.memory_efficient_attention(
|
||||
out_q, cache_k, cache_v, attn_bias=attn_bias
|
||||
)
|
||||
|
||||
This functionality is experimental. Its API might be changed without warnings.
|
||||
Use it at your own risk.
|
||||
|
||||
Arguments:
|
||||
xq: tensor of queries to apply rope to
|
||||
xk: tensor of keys to apply rope to
|
||||
xv: tensor of values to copy into cache_v
|
||||
cache_k: cache of keys, MODIFIED IN PLACE
|
||||
cache_v: cache of values, MODIFIED IN PLACE
|
||||
attn_bias: details the layout of caches.
|
||||
Used to determine frequencies for the
|
||||
RoPE calculation as well as the locations in cache_k and cache_v
|
||||
to write to. Must be on the device.
|
||||
adjacents: If True, the inputs are in adjacent pairs along the final dim axis.
|
||||
This is like the released LLaMA model.
|
||||
If False, the dim axis is split in two equal pieces.
|
||||
I.e. the features are ordered with all the real parts before all
|
||||
the imaginary parts. This matches HuggingFace, e.g.
|
||||
https://github.com/huggingface/transformers/blob/
|
||||
f143037789288ba532dada934a118e648e715738/
|
||||
src/transformers/models/llama/modeling_llama.py#L126-L130
|
||||
internal_dtype: set to "f32" or "f64" to enforce dtype in the calculation
|
||||
"""
|
||||
if torch.is_grad_enabled() and (
|
||||
xq.requires_grad
|
||||
or xk.requires_grad
|
||||
or xv.requires_grad
|
||||
or cache_k.requires_grad
|
||||
or cache_v.requires_grad
|
||||
or out_q is not None
|
||||
):
|
||||
raise ValueError("Gradients not supported.")
|
||||
assert _is_triton_available()
|
||||
import triton
|
||||
|
||||
from .triton.rope_padded_kernels import _rope_padded_kernel
|
||||
|
||||
n_total_queries = attn_bias.q_seqinfo.seqstart_py[-1]
|
||||
cache_length = attn_bias.k_seqinfo.seqstart_py[-1]
|
||||
bsz, q_len, n_q_heads, dim = xq.shape
|
||||
assert q_len == n_total_queries
|
||||
if bsz != 1:
|
||||
raise ValueError(
|
||||
"Expected batch size dimension to be 1" "as batches should be concatenated."
|
||||
)
|
||||
xk_shape = xk.shape
|
||||
n_kv_heads = xk_shape[2]
|
||||
if xk_shape != (1, n_total_queries, n_kv_heads, dim):
|
||||
raise ValueError("unexpected k shape")
|
||||
if xv.shape != (1, n_total_queries, n_kv_heads, dim):
|
||||
raise ValueError("unexpected v shape")
|
||||
if cache_k.shape != (1, cache_length, n_kv_heads, dim):
|
||||
raise ValueError("unexpected cache_k length")
|
||||
if cache_v.shape != (1, cache_length, n_kv_heads, dim):
|
||||
raise ValueError("unexpected cache_v length")
|
||||
|
||||
xq_stride = xq.stride()
|
||||
xk_stride = xk.stride()
|
||||
xv_stride = xv.stride()
|
||||
cache_k_stride = cache_k.stride()
|
||||
cache_v_stride = cache_v.stride()
|
||||
if xq_stride[3] != 1:
|
||||
raise ValueError("Each q head must be contiguous")
|
||||
if xk_stride[3] != 1:
|
||||
raise ValueError("Each k head must be contiguous")
|
||||
if xv_stride[3] != 1:
|
||||
raise ValueError("Each v head must be contiguous")
|
||||
if cache_k_stride[3] != 1:
|
||||
raise ValueError("Each cache_k head must be contiguous")
|
||||
if cache_v_stride[3] != 1:
|
||||
raise ValueError("Each cache_v head must be contiguous")
|
||||
n_total_heads = n_q_heads + 2 * n_kv_heads
|
||||
v_start = n_total_heads - n_kv_heads
|
||||
k_start = n_q_heads
|
||||
if out_q is None:
|
||||
out_q = xq.new_empty(1, n_total_queries, n_q_heads, dim)
|
||||
out_q_stride: Tuple[int, ...] = (0, n_q_heads * dim, dim, 1)
|
||||
else:
|
||||
if out_q.shape != xq.shape:
|
||||
raise ValueError("Unexpected shape of out_q")
|
||||
out_q_stride = out_q.stride()
|
||||
if out_q_stride[3] != 1:
|
||||
raise ValueError("Each out_q head must be contiguous")
|
||||
|
||||
assert out_q is not None
|
||||
|
||||
logical_bsz = len(attn_bias.q_seqinfo.seqstart_py) - 1
|
||||
|
||||
# Less than 64KB per feature: enqueue fused kernel
|
||||
MAX_FUSED_SIZE = 65536 // xq.element_size()
|
||||
BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(dim))
|
||||
BLOCK_SIZE = max(BLOCK_SIZE, 128)
|
||||
BLOCK_SIZE = min(BLOCK_SIZE, 4096)
|
||||
# heuristics for number of warps
|
||||
num_warps = min(max(BLOCK_SIZE // 256, 1), 8)
|
||||
device = xq.device
|
||||
# Move these to the right device, like fmha does.
|
||||
attn_bias.k_seqinfo.to(device)
|
||||
attn_bias.q_seqinfo.to(device)
|
||||
seqstartq = attn_bias.q_seqinfo.seqstart
|
||||
seqstartk = attn_bias.k_seqinfo.seqstart
|
||||
seqlenk = attn_bias.k_seqinfo.seqlen
|
||||
assert internal_dtype in ["", "f32", "f64"]
|
||||
# experiment with the order of dims here.
|
||||
_rope_padded_kernel[(logical_bsz, attn_bias.q_seqinfo.max_seqlen, n_total_heads)](
|
||||
xq,
|
||||
xk,
|
||||
xv,
|
||||
out_q,
|
||||
cache_k,
|
||||
cache_v,
|
||||
seqstartq,
|
||||
seqstartk,
|
||||
seqlenk,
|
||||
theta,
|
||||
k_start,
|
||||
v_start,
|
||||
dim,
|
||||
xq_stride[1],
|
||||
xq_stride[2],
|
||||
xk_stride[1],
|
||||
xk_stride[2],
|
||||
xv_stride[1],
|
||||
xv_stride[2],
|
||||
cache_k_stride[1],
|
||||
cache_k_stride[2],
|
||||
cache_v_stride[1],
|
||||
cache_v_stride[2],
|
||||
seqstartq.stride(0),
|
||||
seqstartk.stride(0),
|
||||
seqlenk.stride(0),
|
||||
out_q_stride[1],
|
||||
out_q_stride[2],
|
||||
internal_dtype,
|
||||
const_batch_strides=False,
|
||||
cache_padding_length=0,
|
||||
seqlenk_shift=0,
|
||||
BLOCK_SIZE=BLOCK_SIZE,
|
||||
adjacents=adjacents,
|
||||
num_warps=num_warps,
|
||||
)
|
||||
return out_q
|
||||
467
pkgs/xformers/ops/swiglu_op.py
Normal file
467
pkgs/xformers/ops/swiglu_op.py
Normal file
@@ -0,0 +1,467 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the BSD license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict, Optional, Sequence, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch import nn
|
||||
|
||||
from .common import BaseOperator, get_xformers_operator, register_operator
|
||||
from .unbind import stack_or_none, unbind
|
||||
|
||||
|
||||
@register_operator
|
||||
class DualGemmSiluOp(BaseOperator):
|
||||
OPERATOR = get_xformers_operator("dual_gemm_silu_identity_mul")
|
||||
OPERATOR_CATEGORY = "swiglu"
|
||||
NAME = "dual_gemm_silu"
|
||||
|
||||
@classmethod
|
||||
# type: ignore
|
||||
def operator_flop(
|
||||
cls, x: torch.Tensor, w1: torch.Tensor, b1, w2: torch.Tensor, b2
|
||||
) -> int:
|
||||
"""NOTE: we neglect the impact of biases / pointwises"""
|
||||
M, N, K = x.shape[0], w1.shape[0], w1.shape[1]
|
||||
return M * N * K * 2 * 2
|
||||
|
||||
|
||||
@register_operator
|
||||
class GemmFusedSumOp(BaseOperator):
|
||||
OPERATOR = get_xformers_operator("gemm_fused_operand_sum")
|
||||
OPERATOR_CATEGORY = "swiglu"
|
||||
NAME = "gemm_fused_operand_sum"
|
||||
|
||||
@classmethod
|
||||
# type: ignore
|
||||
def operator_flop(cls, a: torch.Tensor, b: torch.Tensor, out1, out2) -> int:
|
||||
M, N, K = a.shape[0], b.shape[1], a.shape[1]
|
||||
return M * N * K * 2
|
||||
|
||||
|
||||
class _SwiGLUDecomposedFunc(torch.autograd.Function):
|
||||
"""
|
||||
This is just an example implementation with all
|
||||
operations explicited. This implementation is worse
|
||||
than pytorch, because pytorch is able to fuse some operations
|
||||
(eg the linear forward ...) that are decomposed here.
|
||||
|
||||
The time measurements were made on the ViT-Giant setting:
|
||||
- A100/f16
|
||||
- input: [4440, 1536]
|
||||
- hidden: [4440, 4096]
|
||||
"""
|
||||
|
||||
NAME = "decomposed"
|
||||
FORCE_BW_F32 = False
|
||||
|
||||
def _silu_backward(dy, x):
|
||||
# https://github.com/pytorch/pytorch/blob/563b065f5a4b4055fa6b025c2514b566d5fd9439/aten/src/ATen/native/Activation.cpp#L483
|
||||
sigm = 1 / (1 + torch.exp(-x.float()))
|
||||
return (dy.float() * sigm * (1 + x.float() * (1 - sigm))).to(x.dtype)
|
||||
|
||||
# 952us
|
||||
@classmethod
|
||||
def forward(cls, ctx, x, w1, b1, w2, b2, w3, b3):
|
||||
x1 = x @ w1.transpose(-2, -1) + b1 # 275us
|
||||
x2 = x @ w2.transpose(-2, -1) + b2 # 275us
|
||||
x3 = F.silu(x1) # 62us
|
||||
x4 = x3 * x2 # 90us
|
||||
x5 = x4 @ w3.transpose(-2, -1) + b3 # 250us
|
||||
|
||||
ctx.save_for_backward(x, w1, b1, w2, b2, w3, b3, x1, x2, x3, x4, x5)
|
||||
return x5
|
||||
|
||||
# 1900us
|
||||
@classmethod
|
||||
def backward(cls, ctx, dx5):
|
||||
saved_tensors = ctx.saved_tensors
|
||||
if cls.FORCE_BW_F32:
|
||||
dx5 = dx5.float()
|
||||
saved_tensors = [t.float() for t in ctx.saved_tensors]
|
||||
x, w1, b1, w2, b2, w3, b3, x1, x2, x3, x4, x5 = saved_tensors
|
||||
dx4 = dx5 @ w3 # 255us (nn)
|
||||
dw3 = dx5.transpose(-2, -1) @ x4 # 247us (nt)
|
||||
db3 = dx5.sum(0) # 25us
|
||||
dx3 = dx4 * x2 # 88us
|
||||
dx2 = dx4 * x3 # 88us
|
||||
dx1 = cls._silu_backward(dx3, x1) # 90us
|
||||
dx = dx2 @ w2 # 260us (nn)
|
||||
dw2 = dx2.transpose(-2, -1) @ x # 245us (nt)
|
||||
db2 = dx2.sum(0) # 50us
|
||||
dx += dx1 @ w1 # 260us (nn)
|
||||
dw1 = dx1.transpose(-2, -1) @ x # 245us (nt)
|
||||
db1 = dx1.sum(0) # 50us
|
||||
return (dx, dw1, db1, dw2, db2, dw3, db3)
|
||||
|
||||
|
||||
class _SwiGLUFusedFunc(torch.autograd.Function):
|
||||
NAME = "fused.py"
|
||||
|
||||
@classmethod
|
||||
@torch.cuda.amp.custom_fwd
|
||||
def forward(cls, ctx, x, w1, b1, w2, b2, w3, b3):
|
||||
x1, x2, x4 = DualGemmSiluOp.OPERATOR(x, w1, b1, w2, b2)
|
||||
|
||||
x5 = F.linear(x4, w3, b3)
|
||||
ctx.save_for_backward(x, w1, w2, w3, x1, x2)
|
||||
ctx.bias = [b1 is not None, b2 is not None, b3 is not None]
|
||||
return x5
|
||||
|
||||
@staticmethod
|
||||
def _linear_bw(
|
||||
dy: torch.Tensor, x: torch.Tensor, bias: bool
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
|
||||
if not bias:
|
||||
return (dy.transpose(-2, -1) @ x), None
|
||||
db = torch.empty([dy.shape[1]], dtype=dy.dtype, device=dy.device)
|
||||
dw = torch.empty([dy.shape[1], x.shape[1]], dtype=dy.dtype, device=dy.device)
|
||||
GemmFusedSumOp.OPERATOR(dy.transpose(-2, -1), x, dw, db)
|
||||
return dw, db
|
||||
|
||||
@classmethod
|
||||
@torch.cuda.amp.custom_bwd
|
||||
def backward(cls, ctx, dx5):
|
||||
x, w1, w2, w3, x1, x2 = ctx.saved_tensors
|
||||
w1w2 = stack_or_none([w1, w2], dim=0)
|
||||
|
||||
dx4 = dx5 @ w3 # 255us (nn)
|
||||
dx1dx2, x4 = torch.ops.xformers.silu_bw_fused(x1, x2, dx4)
|
||||
dx1, dx2 = dx1dx2.unbind(1)
|
||||
del x1, x2, dx4
|
||||
|
||||
dw3, db3 = cls._linear_bw(dx5, x4, bias=ctx.bias[2])
|
||||
del x4, dx5
|
||||
if w1w2 is not None:
|
||||
assert dx1dx2.is_contiguous()
|
||||
assert w1w2.is_contiguous()
|
||||
w1w2 = w1w2.view([w1.shape[0] * 2, w1.shape[1]])
|
||||
dx = dx1dx2.view([dx1.shape[0], 2 * dx1.shape[1]]) @ w1w2
|
||||
|
||||
# backward of linear1 + linear2 - packed
|
||||
dw1dw2 = dx1dx2.view([dx1.shape[0], 2 * dx1.shape[1]]).transpose(-2, -1) @ x
|
||||
dw1dw2, db1db2 = cls._linear_bw(
|
||||
dx1dx2.view([dx1.shape[0], 2 * dx1.shape[1]]), x, bias=ctx.bias[0]
|
||||
)
|
||||
dw1, dw2 = dw1dw2.view([2, *w1.shape]).unbind(0)
|
||||
if ctx.bias[0]:
|
||||
db1db2 = db1db2.view([2, dx1.shape[1]])
|
||||
db1, db2 = torch.unbind(db1db2, dim=0)
|
||||
else:
|
||||
db1 = db2 = None
|
||||
else:
|
||||
dx = dx2 @ w2 # 260us (nn)
|
||||
torch.addmm(
|
||||
dx, dx1, w1.to(dx1.dtype), beta=1, alpha=1, out=dx
|
||||
) # dx += dx1 @ w1
|
||||
dw2, db2 = cls._linear_bw(dx2, x, bias=ctx.bias[1])
|
||||
dw1, db1 = cls._linear_bw(dx1, x, bias=ctx.bias[0])
|
||||
return (dx, dw1, db1, dw2, db2, dw3, db3)
|
||||
|
||||
|
||||
class SwiGLUOp:
|
||||
"""Base class for any swiglu operator in :attr:`xformers.ops.swiglu`"""
|
||||
|
||||
def __init__(self, op, packed_weights: bool, name: str, constraints):
|
||||
self.NAME = name
|
||||
self.PACKED_WEIGHTS = packed_weights
|
||||
self.op = op
|
||||
self.constraints = constraints
|
||||
|
||||
def supports(self, op: "SwiGLUOpDispatch") -> bool:
|
||||
if self.PACKED_WEIGHTS and not op.packed_weights:
|
||||
return False
|
||||
return all(c(op) for c in self.constraints)
|
||||
|
||||
def __call__(self, *args: Optional[torch.Tensor]) -> torch.Tensor:
|
||||
pass
|
||||
|
||||
def __str__(self) -> str:
|
||||
return f"SwiGLUOp:{self.NAME}"
|
||||
|
||||
|
||||
class _ForwardToPythonAutogradFunc(SwiGLUOp):
|
||||
def supports(self, op: "SwiGLUOpDispatch") -> bool:
|
||||
# Let's disable autocast in bf16 until this issue is fixed
|
||||
# https://github.com/pytorch/pytorch/issues/87979
|
||||
if op.dtype_autocast_gpu == torch.bfloat16:
|
||||
return False
|
||||
return super().supports(op)
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
return self.op.apply(*args, **kwargs)
|
||||
|
||||
|
||||
class _ForwardToFunc(SwiGLUOp):
|
||||
def __call__(self, *args, **kwargs):
|
||||
return self.op(*args, **kwargs)
|
||||
|
||||
def info(self):
|
||||
if self.op.__name__ == "no_such_operator":
|
||||
return "not built"
|
||||
return "available"
|
||||
|
||||
|
||||
def _eager_functional_swiglu(
|
||||
x: torch.Tensor,
|
||||
w1: torch.Tensor,
|
||||
b1: torch.Tensor,
|
||||
w2: torch.Tensor,
|
||||
b2: torch.Tensor,
|
||||
w3: torch.Tensor,
|
||||
b3: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
x1 = F.linear(x, w1, b1)
|
||||
x2 = F.linear(x, w2, b2)
|
||||
hidden = F.silu(x1) * x2
|
||||
return F.linear(hidden, w3, b3)
|
||||
|
||||
|
||||
@dataclass
|
||||
class SwiGLUOpDispatch:
|
||||
"""Dispatcher to automatically select
|
||||
the best operator in :attr:`xformers.ops.swiglu`
|
||||
"""
|
||||
|
||||
device: Union[torch.device, str]
|
||||
dtype: torch.dtype
|
||||
dtype_autocast_gpu: Optional[torch.dtype]
|
||||
packed_weights: bool
|
||||
bias_enabled: bool
|
||||
|
||||
@property
|
||||
def op(self) -> SwiGLUOp:
|
||||
"""Computes the best operator
|
||||
|
||||
Returns:
|
||||
SwiGLUOp: The best operator for the configuration
|
||||
"""
|
||||
priorities: Sequence[SwiGLUOp] = [
|
||||
SwiGLUPackedFusedOp,
|
||||
SwiGLUFusedOp,
|
||||
]
|
||||
for op in priorities:
|
||||
if op.supports(self):
|
||||
return op
|
||||
return SwiGLUEagerOp
|
||||
|
||||
@staticmethod
|
||||
def from_arguments(
|
||||
x: torch.Tensor,
|
||||
w1: torch.Tensor,
|
||||
b1: Optional[torch.Tensor],
|
||||
w2: torch.Tensor,
|
||||
b2: Optional[torch.Tensor],
|
||||
w3: torch.Tensor,
|
||||
b3: Optional[torch.Tensor],
|
||||
) -> "SwiGLUOpDispatch":
|
||||
return SwiGLUOpDispatch(
|
||||
device=x.device,
|
||||
dtype=x.dtype,
|
||||
packed_weights=stack_or_none((w1, w2), dim=0) is not None,
|
||||
dtype_autocast_gpu=torch.get_autocast_gpu_dtype()
|
||||
if torch.is_autocast_enabled()
|
||||
else w1.dtype,
|
||||
bias_enabled=b1 is not None and b2 is not None and b3 is not None,
|
||||
)
|
||||
|
||||
|
||||
def _only_sm80(op: SwiGLUOpDispatch) -> bool:
|
||||
device_type = op.device if isinstance(op.device, str) else op.device.type
|
||||
return device_type == "cuda" and torch.cuda.get_device_capability(op.device)[0] >= 8
|
||||
|
||||
|
||||
def _only_half_or_autocast(op: SwiGLUOpDispatch) -> bool:
|
||||
HALF_DTYPES = [torch.half, torch.bfloat16]
|
||||
return op.dtype in HALF_DTYPES or (
|
||||
op.dtype_autocast_gpu is not None and op.dtype_autocast_gpu in HALF_DTYPES
|
||||
)
|
||||
|
||||
|
||||
def _bias_enabled(op: SwiGLUOpDispatch) -> bool:
|
||||
return op.bias_enabled
|
||||
|
||||
|
||||
_SwiGLUDecomposedOp = _ForwardToPythonAutogradFunc(
|
||||
_SwiGLUDecomposedFunc, False, "decomposed", constraints=[_bias_enabled]
|
||||
)
|
||||
SwiGLUFusedOp = _ForwardToPythonAutogradFunc(
|
||||
_SwiGLUFusedFunc, False, "fused", constraints=[_only_sm80, _only_half_or_autocast]
|
||||
)
|
||||
SwiGLUPackedFusedOp = _ForwardToFunc(
|
||||
get_xformers_operator("swiglu_packedw"),
|
||||
True,
|
||||
"fused.p.cpp",
|
||||
constraints=[_only_sm80, _only_half_or_autocast],
|
||||
)
|
||||
SwiGLUEagerOp = _ForwardToFunc(
|
||||
_eager_functional_swiglu,
|
||||
False,
|
||||
"eager",
|
||||
constraints=[],
|
||||
)
|
||||
|
||||
|
||||
def _info() -> Dict[str, str]:
|
||||
return {op.NAME: op.info() for op in [SwiGLUPackedFusedOp]}
|
||||
|
||||
|
||||
def swiglu(
|
||||
x: torch.Tensor,
|
||||
w1: torch.Tensor,
|
||||
b1: Optional[torch.Tensor],
|
||||
w2: torch.Tensor,
|
||||
b2: Optional[torch.Tensor],
|
||||
w3: torch.Tensor,
|
||||
b3: Optional[torch.Tensor],
|
||||
*,
|
||||
op: SwiGLUOp = None,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Computes a SwiGLU block given the weights/bias of the 3
|
||||
linear layers.
|
||||
|
||||
- It is recommended to keep ``op=None`` so the best implementation \
|
||||
available for the inputs will be used.
|
||||
|
||||
|
||||
:Equivalent pytorch code:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
x1 = F.linear(x, w1, b1)
|
||||
x2 = F.linear(x, w2, b2)
|
||||
hidden = F.silu(x1) * x2
|
||||
return F.linear(hidden, w3, b3)
|
||||
|
||||
:Packing weights:
|
||||
|
||||
To allow faster implementations, it's recommended to have w1/w2 come from the same storage, as in:
|
||||
.. code-block:: python
|
||||
|
||||
w1, w2 = xformers.ops.unbind(w12, 0)
|
||||
|
||||
:Supported hardware:
|
||||
|
||||
This operator is only optimized on A100+ on ``torch.half`` or ``torch.bfloat16`` \
|
||||
(autocast is supported), and will fallback to a functional pytorch \
|
||||
implementation otherwise.
|
||||
"""
|
||||
|
||||
batch_shape = x.shape[:-1]
|
||||
x = x.reshape([-1, x.shape[-1]])
|
||||
if w1.ndim != 2 or w1.shape != w2.shape:
|
||||
raise ValueError(f"Invalid shapes for w1: {w1.shape} / w2: {w2.shape}")
|
||||
if b1 is not None:
|
||||
if b1.ndim != 1 or b1.shape[0] != w1.shape[0]:
|
||||
raise ValueError(f"Invalid shapes for b1: {b1.shape}")
|
||||
if b2 is not None:
|
||||
if b2.ndim != 1 or b2.shape[0] != w2.shape[0]:
|
||||
raise ValueError(f"Invalid shapes for b2: {b2.shape}")
|
||||
if w3.ndim != 2 or w3.shape[1] != w2.shape[0]:
|
||||
raise ValueError(f"Invalid shape for w3: {w3.shape}")
|
||||
if b3 is not None:
|
||||
if b3.ndim != 1 or b3.shape[0] != w3.shape[0]:
|
||||
raise ValueError(f"Invalid shapes for w3: {w3.shape} / b3: {b3.shape}")
|
||||
|
||||
if op is None:
|
||||
op = SwiGLUOpDispatch.from_arguments(x, w1, b1, w2, b2, w3, b3).op
|
||||
|
||||
if not op.PACKED_WEIGHTS:
|
||||
return op(x, w1, b1, w2, b2, w3, b3).reshape([*batch_shape, -1])
|
||||
w1w2 = stack_or_none((w1, w2), dim=0)
|
||||
if b1 is not None and b2 is not None:
|
||||
b1b2: Optional[torch.Tensor] = stack_or_none((b1, b2), dim=0)
|
||||
if b1b2 is None:
|
||||
raise NotImplementedError("b1/b2 needs to be properly packed")
|
||||
else:
|
||||
b1b2 = None
|
||||
assert b1 is None and b2 is None
|
||||
|
||||
if w1w2 is None:
|
||||
raise NotImplementedError("w1/w2 needs to be properly packed")
|
||||
return op(x, w1w2, b1b2, w3, b3).reshape([*batch_shape, -1])
|
||||
|
||||
|
||||
class SwiGLU(nn.Module):
|
||||
"""
|
||||
A Module that encapsulates the call to :attr:`xformers.ops.swiglu`,
|
||||
and holds the weights for the 3 linear layers
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_features: int,
|
||||
hidden_features: int,
|
||||
out_features: Optional[int] = None,
|
||||
bias: bool = True,
|
||||
*,
|
||||
_pack_weights: bool = True,
|
||||
) -> None:
|
||||
"""Create a SwiGLU module
|
||||
|
||||
Args:
|
||||
in_features (int): Number of features of the input
|
||||
hidden_features (int): Number of hidden features
|
||||
out_features (Optional[int], optional): Number of features of the input. Defaults to None.
|
||||
bias (bool, optional): Whether linear layers also include a bias. Defaults to True.
|
||||
"""
|
||||
super().__init__()
|
||||
out_features = out_features or in_features
|
||||
hidden_features = hidden_features or in_features
|
||||
|
||||
self.w12: Optional[nn.Linear]
|
||||
if _pack_weights:
|
||||
self.w12 = nn.Linear(in_features, 2 * hidden_features, bias=bias)
|
||||
else:
|
||||
self.w12 = None
|
||||
self.w1 = nn.Linear(in_features, hidden_features, bias=bias)
|
||||
self.w2 = nn.Linear(in_features, hidden_features, bias=bias)
|
||||
self.w3 = nn.Linear(hidden_features, out_features, bias=bias)
|
||||
|
||||
self.hidden_features = hidden_features
|
||||
self.out_features = out_features
|
||||
self.in_features = in_features
|
||||
self.op = None
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
"""Computes :attr:`swiglu` with the module's weights
|
||||
|
||||
Args:
|
||||
x (torch.Tensor): A Tensor of shape ``[..., in_features]``
|
||||
|
||||
Returns:
|
||||
torch.Tensor: A Tensor of shape ``[..., out_features]``
|
||||
"""
|
||||
return swiglu(x, *self._ordered_params(), op=self.op)
|
||||
|
||||
def _ordered_params(self):
|
||||
"""Used for testing - returns ordered arguments for operators"""
|
||||
b1: Optional[torch.Tensor]
|
||||
b2: Optional[torch.Tensor]
|
||||
if self.w12 is not None:
|
||||
w1w2 = self.w12.weight
|
||||
b1b2 = self.w12.bias
|
||||
w1, w2 = unbind(
|
||||
w1w2.view([2, w1w2.shape[0] // 2, w1w2.shape[1]]),
|
||||
dim=0,
|
||||
)
|
||||
if b1b2 is not None:
|
||||
b1, b2 = unbind(b1b2.view([2, b1b2.shape[0] // 2]), dim=0)
|
||||
else:
|
||||
b1, b2 = None, None
|
||||
else:
|
||||
w1, w2 = self.w1.weight, self.w2.weight
|
||||
b1, b2 = self.w1.bias, self.w2.bias
|
||||
return [
|
||||
w1,
|
||||
b1,
|
||||
w2,
|
||||
b2,
|
||||
self.w3.weight,
|
||||
self.w3.bias,
|
||||
]
|
||||
4
pkgs/xformers/ops/triton/__init__.py
Normal file
4
pkgs/xformers/ops/triton/__init__.py
Normal file
@@ -0,0 +1,4 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the BSD license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
BIN
pkgs/xformers/ops/triton/__pycache__/__init__.cpython-310.pyc
Normal file
BIN
pkgs/xformers/ops/triton/__pycache__/__init__.cpython-310.pyc
Normal file
Binary file not shown.
Binary file not shown.
Binary file not shown.
158
pkgs/xformers/ops/triton/rmsnorm_kernels.py
Normal file
158
pkgs/xformers/ops/triton/rmsnorm_kernels.py
Normal file
@@ -0,0 +1,158 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the BSD license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
import torch
|
||||
import triton
|
||||
import triton.language as tl
|
||||
|
||||
if hasattr(tl, "libdevice"):
|
||||
tl_math = tl.libdevice
|
||||
else:
|
||||
tl_math = tl.math
|
||||
|
||||
|
||||
@triton.jit
|
||||
def _rms_norm_kernel(
|
||||
x_ptr,
|
||||
h1_ptr,
|
||||
w_ptr,
|
||||
eps,
|
||||
stride,
|
||||
N_COLS,
|
||||
BLOCK_SIZE: tl.constexpr,
|
||||
INCLUDE_WEIGHT: tl.constexpr,
|
||||
):
|
||||
row = tl.program_id(0)
|
||||
x_ptr += row * stride
|
||||
h1_ptr += row * stride
|
||||
|
||||
_mean = tl.zeros([BLOCK_SIZE], dtype=tl.float32)
|
||||
for offset in range(0, N_COLS, BLOCK_SIZE):
|
||||
cols = offset + tl.arange(0, BLOCK_SIZE)
|
||||
a = tl.load(
|
||||
x_ptr + cols, mask=cols < N_COLS, other=0.0, eviction_policy="evict_last"
|
||||
).to(tl.float32)
|
||||
_mean += a * a
|
||||
rstd = tl_math.rsqrt((tl.sum(_mean, axis=0) / N_COLS) + eps)
|
||||
for offset in range(0, N_COLS, BLOCK_SIZE):
|
||||
cols = offset + tl.arange(0, BLOCK_SIZE)
|
||||
mask = cols < N_COLS
|
||||
a = tl.load(
|
||||
x_ptr + cols, mask=mask, other=0.0, eviction_policy="evict_first"
|
||||
).to(tl.float32)
|
||||
if INCLUDE_WEIGHT:
|
||||
w = tl.load(w_ptr + cols, mask=mask)
|
||||
tl.store(h1_ptr + cols, a * rstd * w, mask=mask)
|
||||
else:
|
||||
tl.store(h1_ptr + cols, a * rstd, mask=mask)
|
||||
|
||||
|
||||
@triton.jit
|
||||
def _rms_norm_add_kernel(
|
||||
x_ptr,
|
||||
y_ptr,
|
||||
h1_ptr,
|
||||
w_ptr,
|
||||
eps,
|
||||
stride,
|
||||
N_COLS,
|
||||
BLOCK_SIZE: tl.constexpr,
|
||||
INCLUDE_WEIGHT: tl.constexpr,
|
||||
):
|
||||
row = tl.program_id(0)
|
||||
x_ptr += row * stride
|
||||
y_ptr += row * stride
|
||||
h1_ptr += row * stride
|
||||
|
||||
_mean = tl.zeros([BLOCK_SIZE], dtype=tl.float32)
|
||||
for offset in range(0, N_COLS, BLOCK_SIZE):
|
||||
cols = offset + tl.arange(0, BLOCK_SIZE)
|
||||
mask = cols < N_COLS
|
||||
ax = tl.load(
|
||||
x_ptr + cols, mask=mask, other=0.0, eviction_policy="evict_last"
|
||||
).to(tl.float32)
|
||||
ay = tl.load(
|
||||
y_ptr + cols, mask=mask, other=0.0, eviction_policy="evict_first"
|
||||
).to(tl.float32)
|
||||
a = ax + ay
|
||||
tl.store(x_ptr + cols, a, mask=mask)
|
||||
_mean += a * a
|
||||
rstd = tl_math.rsqrt((tl.sum(_mean, axis=0) / N_COLS) + eps)
|
||||
for offset in range(0, N_COLS, BLOCK_SIZE):
|
||||
cols = offset + tl.arange(0, BLOCK_SIZE)
|
||||
mask = cols < N_COLS
|
||||
a = tl.load(
|
||||
x_ptr + cols, mask=mask, other=0.0, eviction_policy="evict_first"
|
||||
).to(tl.float32)
|
||||
if INCLUDE_WEIGHT:
|
||||
w = tl.load(w_ptr + cols, mask=mask)
|
||||
tl.store(h1_ptr + cols, a * rstd * w, mask=mask)
|
||||
else:
|
||||
tl.store(h1_ptr + cols, a * rstd, mask=mask)
|
||||
|
||||
|
||||
def _rms_norm_forward(x, attn_norm_weights, eps):
|
||||
if not x.is_contiguous():
|
||||
raise ValueError("data must be contiguous")
|
||||
if attn_norm_weights is not None:
|
||||
if not attn_norm_weights.is_contiguous():
|
||||
raise ValueError("weights must be contiguous")
|
||||
out = torch.empty_like(x)
|
||||
x_arg = x.reshape(-1, x.shape[-1])
|
||||
M, N = x_arg.shape
|
||||
# Less than 64KB per feature: enqueue fused kernel
|
||||
MAX_FUSED_SIZE = 65536 // x.element_size()
|
||||
BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))
|
||||
BLOCK_SIZE = max(BLOCK_SIZE, 128)
|
||||
BLOCK_SIZE = min(BLOCK_SIZE, 4096)
|
||||
# heuristics for number of warps
|
||||
num_warps = min(max(BLOCK_SIZE // 256, 1), 8)
|
||||
_rms_norm_kernel[(M,)](
|
||||
x_arg,
|
||||
out,
|
||||
attn_norm_weights,
|
||||
eps,
|
||||
x_arg.stride(0),
|
||||
N,
|
||||
BLOCK_SIZE=BLOCK_SIZE,
|
||||
num_warps=num_warps,
|
||||
INCLUDE_WEIGHT=attn_norm_weights is not None,
|
||||
)
|
||||
return out
|
||||
|
||||
|
||||
def _rms_norm_add_forward(x, y, attn_norm_weights, eps):
|
||||
# x, y contiguous of same shape [..., n]
|
||||
# output of same shape, normed over the last dim.
|
||||
if not x.is_contiguous():
|
||||
raise ValueError("x must be contiguous")
|
||||
if not y.is_contiguous():
|
||||
raise ValueError("y must be contiguous")
|
||||
if attn_norm_weights is not None:
|
||||
if not attn_norm_weights.is_contiguous():
|
||||
raise ValueError("weights must be contiguous")
|
||||
out = torch.empty_like(x)
|
||||
x_arg = x.reshape(-1, x.shape[-1])
|
||||
y_arg = y.reshape(-1, x.shape[-1])
|
||||
M, N = x_arg.shape
|
||||
# Less than 64KB per feature: enqueue fused kernel
|
||||
MAX_FUSED_SIZE = 65536 // x.element_size()
|
||||
BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))
|
||||
BLOCK_SIZE = max(BLOCK_SIZE, 128)
|
||||
BLOCK_SIZE = min(BLOCK_SIZE, 4096)
|
||||
# heuristics for number of warps
|
||||
num_warps = min(max(BLOCK_SIZE // 256, 1), 8)
|
||||
_rms_norm_add_kernel[(M,)](
|
||||
x_arg,
|
||||
y_arg,
|
||||
out,
|
||||
attn_norm_weights,
|
||||
eps,
|
||||
x_arg.stride(0),
|
||||
N,
|
||||
BLOCK_SIZE=BLOCK_SIZE,
|
||||
num_warps=num_warps,
|
||||
INCLUDE_WEIGHT=attn_norm_weights is not None,
|
||||
)
|
||||
return out
|
||||
161
pkgs/xformers/ops/triton/rope_padded_kernels.py
Normal file
161
pkgs/xformers/ops/triton/rope_padded_kernels.py
Normal file
@@ -0,0 +1,161 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the BSD license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
import triton # type: ignore
|
||||
import triton.language as tl # type: ignore
|
||||
|
||||
if hasattr(tl, "libdevice"):
|
||||
tl_math = tl.libdevice
|
||||
else:
|
||||
tl_math = tl.math
|
||||
|
||||
|
||||
@triton.jit
|
||||
def _rope_padded_kernel(
|
||||
xq,
|
||||
xk,
|
||||
xv,
|
||||
out_q,
|
||||
cache_k,
|
||||
cache_v,
|
||||
seqstartq,
|
||||
seqstartk,
|
||||
seqlenk,
|
||||
theta,
|
||||
k_start: tl.constexpr,
|
||||
v_start: tl.constexpr,
|
||||
dim: tl.constexpr, # dimension of each head
|
||||
stride_xqM,
|
||||
stride_xqH,
|
||||
stride_xkM,
|
||||
stride_xkH,
|
||||
stride_xvM,
|
||||
stride_xvH,
|
||||
stride_cachekM,
|
||||
stride_cachekH,
|
||||
stride_cachevM,
|
||||
stride_cachevH,
|
||||
stride_seqstartq,
|
||||
stride_seqstartk,
|
||||
stride_seqlenk,
|
||||
stride_outqM,
|
||||
stride_outqH,
|
||||
internal_dtype: tl.constexpr,
|
||||
# If True, seqstartq and seqstartk are not used but rather we
|
||||
# assume that every batch element has the same number of
|
||||
# queries (i.e. num_queries := tl.num_programs(1) )
|
||||
# and the same cache space cache_padding_length.
|
||||
# Always False when called below.
|
||||
const_batch_strides: tl.constexpr,
|
||||
# If const_batch_strides==True, the common cache length for each batch element.
|
||||
# (Only the first seqlenk[i] elements are actually in use, and only the last
|
||||
# num_queries of those are actually written to.)
|
||||
cache_padding_length,
|
||||
# offset added to all values in seqlenk before using them.
|
||||
# Always 0 when called below.
|
||||
seqlenk_shift: tl.constexpr,
|
||||
BLOCK_SIZE: tl.constexpr,
|
||||
adjacents: tl.constexpr,
|
||||
):
|
||||
"""
|
||||
Each letter in this diagram is a whole row of length dim.
|
||||
|
||||
INPUT xq xk xv
|
||||
|
||||
head_dim ─►
|
||||
|
||||
batch qqqqqq kk vv
|
||||
│ qqqqqq kk vv
|
||||
▼ qqqqqq kk vv
|
||||
|
||||
head_idx: (goes across all heads of all 3 inputs)
|
||||
▲ ▲ ▲ ▲ ▲ ▲
|
||||
│ │ │ │ │ │
|
||||
│ │
|
||||
0 k_start │v_start │n_total_heads
|
||||
│ │
|
||||
│ │
|
||||
k_start v_start
|
||||
|
||||
Output is to out_q (same shape as xq), an xk-shaped part
|
||||
of cache_k and an xv-shaped part of cache_v
|
||||
"""
|
||||
batch_elt = tl.program_id(0)
|
||||
query_pos_in_batch_elt = tl.program_id(1)
|
||||
head_idx = tl.program_id(2)
|
||||
|
||||
if internal_dtype == "f32":
|
||||
theta = theta.to(tl.float32)
|
||||
elif internal_dtype == "f64":
|
||||
theta = theta.to(tl.float64)
|
||||
|
||||
if const_batch_strides:
|
||||
query_pos = query_pos_in_batch_elt + tl.num_programs(1) * batch_elt
|
||||
end_query_pos = tl.num_programs(1) * (batch_elt + 1)
|
||||
else:
|
||||
query_pos = query_pos_in_batch_elt + tl.load(
|
||||
seqstartq + batch_elt * stride_seqstartq
|
||||
)
|
||||
end_query_pos = tl.load(seqstartq + (batch_elt + 1) * stride_seqstartq)
|
||||
if query_pos >= end_query_pos:
|
||||
return
|
||||
|
||||
is_q = head_idx < k_start
|
||||
is_v = head_idx >= v_start
|
||||
|
||||
xq += query_pos * stride_xqM + head_idx * stride_xqH
|
||||
out_q += query_pos * stride_outqM + head_idx * stride_outqH
|
||||
|
||||
if const_batch_strides:
|
||||
cache_start = cache_padding_length * batch_elt
|
||||
else:
|
||||
cache_start = tl.load(seqstartk + batch_elt * stride_seqstartk)
|
||||
end_of_batch_elt_cache = (
|
||||
cache_start + tl.load(seqlenk + batch_elt * stride_seqlenk) + seqlenk_shift
|
||||
)
|
||||
|
||||
cache_pos = end_of_batch_elt_cache - (end_query_pos - query_pos)
|
||||
seq_pos = cache_pos - cache_start
|
||||
cache_k += (head_idx - k_start) * stride_cachekH + cache_pos * stride_cachekM
|
||||
xk += query_pos * stride_xkM + (head_idx - k_start) * stride_xkH
|
||||
in_qk = tl.where(is_q, xq, xk)
|
||||
out_qk = tl.where(is_q, out_q, cache_k)
|
||||
|
||||
cache_v += (head_idx - v_start) * stride_cachevH + cache_pos * stride_cachevM
|
||||
xv += query_pos * stride_xvM + (head_idx - v_start) * stride_xvH
|
||||
|
||||
out = tl.where(is_v, cache_v, out_qk)
|
||||
x_in = tl.where(is_v, xv, in_qk)
|
||||
|
||||
for offset in range(0, dim // 2, BLOCK_SIZE // 2):
|
||||
c = tl.arange(0, BLOCK_SIZE // 2)
|
||||
powers = (offset + c) * 2.0
|
||||
if adjacents:
|
||||
cols_re = (offset + c) * 2
|
||||
cols_im = cols_re + 1
|
||||
else:
|
||||
cols_re = offset + c
|
||||
cols_im = cols_re + dim // 2
|
||||
|
||||
mask = cols_im < dim
|
||||
|
||||
re_x = tl.load(x_in + cols_re, mask=mask)
|
||||
im_x = tl.load(x_in + cols_im, mask=mask)
|
||||
# freqs = seq_pos / (theta ** (powers / dim))
|
||||
freqs = seq_pos * tl_math.pow(theta, powers / (-dim))
|
||||
sines = tl.sin(freqs)
|
||||
cosines = tl.cos(freqs)
|
||||
re_out = re_x * cosines - im_x * sines
|
||||
im_out = im_x * cosines + re_x * sines
|
||||
|
||||
re_out_ = tl.where(is_v, re_x, re_out)
|
||||
im_out_ = tl.where(is_v, im_x, im_out)
|
||||
if internal_dtype == "f64":
|
||||
if re_x.dtype == tl.bfloat16:
|
||||
# triton 2.0.0 crashes if you try to convert
|
||||
# float64 directly to bfloat16, so make an intermediate step.
|
||||
re_out_ = re_out_.to(tl.float32)
|
||||
im_out_ = im_out_.to(tl.float32)
|
||||
tl.store(out + cols_re, re_out_, mask=mask)
|
||||
tl.store(out + cols_im, im_out_, mask=mask)
|
||||
125
pkgs/xformers/ops/unbind.py
Normal file
125
pkgs/xformers/ops/unbind.py
Normal file
@@ -0,0 +1,125 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the BSD license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
from typing import List, Optional, Sequence, Tuple, Union
|
||||
|
||||
import torch
|
||||
|
||||
from .common import _get_storage_base
|
||||
|
||||
|
||||
def get_stack_strides(
|
||||
tensors: Sequence[torch.Tensor], dim: int
|
||||
) -> Optional[Tuple[int, ...]]:
|
||||
"""
|
||||
If the tensors are already stacked on dimension :code:`dim`, \
|
||||
returns the strides of the stacked tensors. \
|
||||
Otherwise returns :code:`None`.
|
||||
"""
|
||||
if len(tensors) <= 1 or dim > tensors[0].ndim:
|
||||
return None
|
||||
|
||||
final_stride = []
|
||||
for i in range(tensors[0].ndim + 1):
|
||||
if i == dim:
|
||||
final_stride.append(
|
||||
tensors[1].storage_offset() - tensors[0].storage_offset()
|
||||
)
|
||||
continue
|
||||
if i > dim:
|
||||
i -= 1
|
||||
final_stride.append(tensors[0].stride(i))
|
||||
|
||||
storage_data_ptr: Optional[int] = None
|
||||
for i, x in enumerate(tensors[1:]):
|
||||
# Sanity checks
|
||||
if x.shape != tensors[0].shape:
|
||||
return None
|
||||
if x.stride() != tensors[0].stride():
|
||||
return None
|
||||
if (
|
||||
x.storage_offset()
|
||||
!= tensors[0].storage_offset() + (i + 1) * final_stride[dim]
|
||||
):
|
||||
return None
|
||||
if storage_data_ptr is None:
|
||||
storage_data_ptr = _get_storage_base(tensors[0])
|
||||
# Actual storage check
|
||||
if _get_storage_base(x) != storage_data_ptr:
|
||||
return None
|
||||
return tuple(final_stride)
|
||||
|
||||
|
||||
def _stack_or_none_fw(
|
||||
tensors: Union[Tuple[torch.Tensor, ...], List[torch.Tensor]],
|
||||
dim: int,
|
||||
) -> Optional[torch.Tensor]:
|
||||
strides = get_stack_strides(tensors, dim)
|
||||
if strides is not None:
|
||||
input_shape = list(tensors[0].shape)
|
||||
input_shape.insert(dim, len(tensors))
|
||||
return tensors[0].as_strided(input_shape, strides)
|
||||
return None
|
||||
|
||||
|
||||
def _stack_fw(
|
||||
tensors: Union[Tuple[torch.Tensor, ...], List[torch.Tensor]],
|
||||
dim: int,
|
||||
) -> torch.Tensor:
|
||||
out = _stack_or_none_fw(tensors, dim)
|
||||
if out is None:
|
||||
out = torch.stack(tensors, dim=dim)
|
||||
return out
|
||||
|
||||
|
||||
class _Unbind(torch.autograd.Function):
|
||||
"""
|
||||
See function `unbind`
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
# type: ignore
|
||||
def forward(ctx, x: torch.Tensor, dim: int):
|
||||
ctx.dim = dim
|
||||
return x.unbind(dim)
|
||||
|
||||
@classmethod
|
||||
# type: ignore
|
||||
def backward(cls, ctx, *tensors: torch.Tensor):
|
||||
return _stack_fw(tensors, ctx.dim), None
|
||||
|
||||
|
||||
class _StackOrNone(torch.autograd.Function):
|
||||
"""
|
||||
See function `stack_or_none`
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
# type: ignore
|
||||
def forward(ctx, dim: int, *tensors: torch.Tensor):
|
||||
ctx.dim = dim
|
||||
return _stack_or_none_fw(tensors, dim=dim)
|
||||
|
||||
@classmethod
|
||||
# type: ignore
|
||||
def backward(cls, ctx, grad: torch.Tensor):
|
||||
return (None, *grad.unbind(dim=ctx.dim))
|
||||
|
||||
|
||||
def unbind(x: torch.Tensor, dim: int) -> Tuple[torch.Tensor, ...]:
|
||||
"""
|
||||
Does exactly the same as :attr:`torch.unbind` for the forward.
|
||||
In backward, avoids a :attr:`torch.cat` if the gradients
|
||||
are already multiple views of the same storage
|
||||
"""
|
||||
return _Unbind.apply(x, dim)
|
||||
|
||||
|
||||
def stack_or_none(tensors: Sequence[torch.Tensor], dim: int) -> torch.Tensor:
|
||||
"""
|
||||
Does exactly the same as :attr:`torch.stack` if the tensors can be concatenated
|
||||
without any memory operation. Otherwise returns None.
|
||||
"""
|
||||
return _StackOrNone.apply(dim, *tensors)
|
||||
Reference in New Issue
Block a user