667 lines
21 KiB
Python
667 lines
21 KiB
Python
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
|
#
|
|
# This source code is licensed under the BSD license found in the
|
|
# LICENSE file in the root directory of this source tree.
|
|
|
|
|
|
from dataclasses import replace
|
|
from itertools import zip_longest
|
|
from typing import Any, List, Optional, Set, Tuple, Union
|
|
import os
|
|
import torch
|
|
|
|
from ..common import _get_storage_base, get_operator, register_operator
|
|
from .attn_bias import (
|
|
AttentionBias,
|
|
BlockDiagonalCausalFromBottomRightMask,
|
|
BlockDiagonalCausalLocalAttentionFromBottomRightMask,
|
|
BlockDiagonalCausalLocalAttentionMask,
|
|
BlockDiagonalCausalMask,
|
|
BlockDiagonalMask,
|
|
LowerTriangularMask,
|
|
)
|
|
from .common import (
|
|
AttentionBwOpBase,
|
|
AttentionFwOpBase,
|
|
Context,
|
|
Gradients,
|
|
Inputs,
|
|
check_lastdim_alignment_stride1,
|
|
)
|
|
# Backend switch, read once at import time. Any value other than "0"
# (including the variable being unset) enables the ixDNN code paths guarded
# by `enable_ixdnn` in this module.
# NOTE: the original `global enable_ixdnn` statement was dropped; `global` has
# no effect at module scope, the assignment alone creates the module global.
enable_ixdnn = os.getenv("ENABLE_FLASH_ATTENTION_WITH_IXDNN", "1") != "0"

# Placeholder version; replaced if a flash-attention build can be imported.
FLASH_VERSION = "0.0.0"
|
|
try:
    try:
        # Preferred: the extension compiled together with this package.
        from ... import _C_flashattention  # type: ignore[attr-defined]
        from ..._cpp_lib import _build_metadata

        if _build_metadata is not None:
            FLASH_VERSION = _build_metadata.flash_version
    except ImportError:
        # Fallback: a separately installed `flash_attn` distribution.
        import flash_attn
        from flash_attn.flash_attn_interface import flash_attn_cuda as _C_flashattention

        FLASH_VERSION = flash_attn.__version__
        flash_ver_parsed = tuple(int(s) for s in FLASH_VERSION.split(".")[:2])
        if flash_ver_parsed < (2, 3):
            raise ImportError("Requires 2.3 for sliding window support")

    # create library so that flash-attn goes through the PyTorch Dispatcher
    _flash_lib = torch.library.Library("xformers_flash", "DEF")

    _flash_lib.define(
        "flash_fwd(Tensor query, Tensor key, Tensor value, "
        "Tensor? cu_seqlens_q, Tensor? cu_seqlens_k, "
        "int max_seqlen_q, int max_seqlen_k, "
        "float p, float softmax_scale, "
        "bool is_causal, int window_size, bool return_softmax, bool use_alibi, int alibi_mode, int imp_mode) -> (Tensor, Tensor, Tensor, Tensor, Tensor, Tensor)"
    )

    # The cumulative sequence lengths are optional here as well: `_flash_bwd`
    # handles `None` explicitly on its dense (non-varlen) path, and a
    # non-optional `Tensor` in the schema would reject that call before the
    # implementation is ever reached.
    _flash_lib.define(
        "flash_bwd(Tensor dout, Tensor query, Tensor key, Tensor value, "
        "Tensor out, Tensor softmax_lse_, Tensor dq, Tensor dk, Tensor dv, "
        "Tensor? cu_seqlens_q, Tensor? cu_seqlens_k, "
        "int max_seqlen_q, int max_seqlen_k, "
        "float p, float softmax_scale, bool is_causal, int window_size, bool use_alibi, int alibi_mode, int imp_mode) -> (Tensor, Tensor, Tensor)"
    )

    def _flash_fwd(
        query,
        key,
        value,
        cu_seq_lens_q,
        cu_seq_lens_k,
        max_seq_len_q,
        max_seq_len_k,
        p,
        softmax_scale,
        is_causal,
        window_size,
        return_softmax,
        use_alibi,
        alibi_mode,
        imp_mode
    ):
        """CUDA implementation of ``xformers_flash::flash_fwd``.

        Dispatches to one of three extension entry points:

        * ``fwd_ixdnn`` when ``enable_ixdnn`` is set,
        * ``fwd`` when no cumulative sequence lengths are given,
        * ``varlen_fwd`` otherwise.

        ``window_size`` is part of the operator schema but is not forwarded to
        any of the kernels here, and ``softmax_scale`` is not forwarded to
        ``fwd_ixdnn``.

        Returns:
            ``(out, softmax_lse, q_padded, k_padded, v_padded, out_padded)``.
            The seventh value produced by the kernels is discarded.
        """
        # The kernels' last return value is bound to `_softmax_ret` rather than
        # to `p`, so the dropout probability argument is never clobbered.
        if enable_ixdnn:
            (
                out,
                q_padded,
                k_padded,
                v_padded,
                out_padded,
                softmax_lse,
                _softmax_ret,
            ) = _C_flashattention.fwd_ixdnn(
                query,
                key,
                value,
                None,  # out
                cu_seq_lens_q,
                cu_seq_lens_k,
                p,
                is_causal,
                return_softmax,
                use_alibi,
                alibi_mode,
                imp_mode,
                -1, -1,  # meaning defined by the extension -- TODO confirm
                None,  # rng
            )
        else:
            if cu_seq_lens_q is None:
                assert cu_seq_lens_k is None
                (
                    out,
                    q_padded,
                    k_padded,
                    v_padded,
                    out_padded,
                    softmax_lse,
                    _softmax_ret,
                ) = _C_flashattention.fwd(
                    query,
                    key,
                    value,
                    None,  # out
                    p,
                    softmax_scale,
                    is_causal,
                    return_softmax,
                    use_alibi,
                    alibi_mode,
                    None,  # rng
                )
            else:
                # The varlen kernel writes into a caller-provided buffer.
                out = query.new_empty(query.shape[0], query.shape[1], value.shape[2])
                (
                    out,
                    q_padded,
                    k_padded,
                    v_padded,
                    out_padded,
                    softmax_lse,
                    _softmax_ret,
                ) = _C_flashattention.varlen_fwd(
                    query,
                    key,
                    value,
                    out,
                    cu_seq_lens_q,
                    cu_seq_lens_k,
                    max_seq_len_q,
                    max_seq_len_k,
                    p,
                    softmax_scale,
                    False,  # presumably zero_tensors, as in varlen_bwd -- TODO confirm
                    is_causal,
                    return_softmax,
                    use_alibi,
                    alibi_mode,
                    None,  # presumably rng, as in the other calls -- TODO confirm
                )
        return out, softmax_lse, q_padded, k_padded, v_padded, out_padded

    def _flash_bwd(
        grad,
        query,
        key,
        value,
        out,
        lse,
        dq,
        dk,
        dv,
        cu_seq_lens_q,
        cu_seq_lens_k,
        max_seq_len_q,
        max_seq_len_k,
        p,
        softmax_scale,
        is_causal,
        window_size,
        use_alibi,
        alibi_mode,
        imp_mode
    ):
        """CUDA implementation of ``xformers_flash::flash_bwd``.

        Dispatches to ``bwd_ixdnn``, ``bwd`` or ``varlen_bwd`` using the same
        rules as ``_flash_fwd``. The kernels' return values are ignored; the
        caller-provided ``dq``, ``dk`` and ``dv`` tensors are returned as-is
        (presumably filled in place by the kernels -- TODO confirm).

        ``window_size`` is part of the operator schema but is not forwarded to
        any of the kernels here, and ``softmax_scale`` is not forwarded to
        ``bwd_ixdnn``.
        """
        if enable_ixdnn:
            _C_flashattention.bwd_ixdnn(
                grad,
                query,
                key,
                value,
                out,
                lse,
                dq,
                dk,
                dv,
                cu_seq_lens_q,
                cu_seq_lens_k,
                p,
                is_causal,
                use_alibi,
                alibi_mode,
                imp_mode,
                -1, -1,  # meaning defined by the extension -- TODO confirm
                None,
            )
        else:
            if cu_seq_lens_k is None:
                assert cu_seq_lens_q is None
                _C_flashattention.bwd(
                    grad,
                    query,
                    key,
                    value,
                    out,
                    lse,
                    dq,
                    dk,
                    dv,
                    p,
                    softmax_scale,
                    is_causal,
                    use_alibi,
                    alibi_mode,
                    None,
                )
            else:
                _C_flashattention.varlen_bwd(
                    grad,
                    query,
                    key,
                    value,
                    out,
                    lse,
                    dq,
                    dk,
                    dv,
                    cu_seq_lens_q,
                    cu_seq_lens_k,
                    max_seq_len_q,
                    max_seq_len_k,
                    p,
                    softmax_scale,
                    False,  # zero_tensors
                    is_causal,
                    use_alibi,
                    alibi_mode,
                    None,
                )
        return dq, dk, dv

    _flash_lib.impl("flash_fwd", _flash_fwd, "CUDA")
    _flash_lib.impl("flash_bwd", _flash_bwd, "CUDA")
except ImportError:
    # No usable flash-attention build: the dispatcher ops are simply not
    # registered.
    pass
|
|
|
|
|
|
def _convert_input_format(
    inp: Inputs,
) -> Tuple[Inputs, Optional[torch.Tensor], int, Optional[torch.Tensor], int]:
    """Prepare ``inp`` for the flash kernels.

    Folds a 5-dimensional query/key/value down to 4 dimensions, collapses
    broadcast key/value heads, and derives the sequence-length arguments.

    Returns:
        ``(new_inp, cu_seqlen_q, max_seqlen_q, cu_seqlen_k, max_seqlen_k)``
        where ``new_inp`` is a copy of ``inp`` with the converted tensors.

    Note:
        When the bias is a ``BlockDiagonalMask`` and ixDNN is disabled, the
        bias' ``seqstart`` tensors are moved to the query device in place.
    """
    assert inp.query.ndim in [4, 5]
    query, key, value = inp.query, inp.key, inp.value
    batch = query.shape[0]
    seqlen_q = query.shape[1]
    seqlen_kv = key.shape[1]
    head_dim_q = query.shape[-1]
    head_dim_v = value.shape[-1]

    attn_bias = inp.attn_bias
    if enable_ixdnn:
        # The ixDNN path builds its own length tensors and does not consult
        # `attn_bias` for sequence boundaries.
        if query.shape[0] == 1:
            cu_seqlen_q = torch.tensor([seqlen_q], dtype=torch.int32, device=query.device)
            cu_seqlen_k = torch.tensor([seqlen_kv], dtype=torch.int32, device=key.device)
        else:
            # NOTE(review): these tensors have `seqlen` entries, not `batch`
            # entries -- confirm this is what the ixDNN kernel expects.
            cu_seqlen_q = torch.full((seqlen_q,), seqlen_q, dtype=torch.int32, device=query.device)
            cu_seqlen_k = torch.full((seqlen_kv,), seqlen_kv, dtype=torch.int32, device=key.device)
        max_seqlen_q = inp.query.shape[1]
        max_seqlen_k = inp.key.shape[1]
    else:
        if isinstance(attn_bias, BlockDiagonalMask):
            attn_bias.k_seqinfo.seqstart = attn_bias.k_seqinfo.seqstart.to(
                inp.query.device, non_blocking=True
            )
            attn_bias.q_seqinfo.seqstart = attn_bias.q_seqinfo.seqstart.to(
                inp.query.device, non_blocking=True
            )

            cu_seqlen_k = attn_bias.k_seqinfo.seqstart
            cu_seqlen_q = attn_bias.q_seqinfo.seqstart
            max_seqlen_q = attn_bias.q_seqinfo.max_seqlen
            max_seqlen_k = attn_bias.k_seqinfo.max_seqlen
        else:
            cu_seqlen_k = None
            cu_seqlen_q = None
            max_seqlen_q = inp.query.shape[1]
            max_seqlen_k = inp.key.shape[1]

    if query.ndim == 5:  # QGA
        # Fold the group/head_in_group dimensions together
        def fold(x):
            # Either the head is replicated
            if x.stride(3) == 0:
                return x[:, :, :, 0]
            # Or we reshape
            return x.reshape(
                [
                    x.shape[0],
                    x.shape[1],
                    -1,
                    x.shape[4],
                ]
            )

        query = fold(query)
        key = fold(key)
        value = fold(value)
    # Optimize for MHA
    if key.ndim == 4 and key.stride(2) == 0 and value.stride(2) == 0:
        key = key[:, :, :1]
        value = value[:, :, :1]
    # Initially we have `query.shape = [batch, seqlen, head_dim_q]`
    # We want format `[batch * seqlen, num_heads, head_dim_q]`
    # Use the module-level flag rather than re-reading the environment, so
    # this decision can never disagree with the branch taken above.
    if cu_seqlen_k is not None and not enable_ixdnn:
        query = query.reshape([batch * seqlen_q, -1, head_dim_q])
        key = key.reshape([batch * seqlen_kv, -1, head_dim_q])
        value = value.reshape([batch * seqlen_kv, -1, head_dim_v])
    new_inp = replace(
        inp,
        query=query,
        key=key,
        value=value,
    )
    return new_inp, cu_seqlen_q, max_seqlen_q, cu_seqlen_k, max_seqlen_k
|
|
|
|
|
|
def _is_causal(attn_bias: Optional[Union[torch.Tensor, AttentionBias]]) -> bool:
    """Return True when ``attn_bias`` is one of the causal mask types."""
    causal_bias_types = (
        LowerTriangularMask,
        BlockDiagonalCausalMask,
        BlockDiagonalCausalLocalAttentionMask,
        BlockDiagonalCausalFromBottomRightMask,
        BlockDiagonalCausalLocalAttentionFromBottomRightMask,
    )
    return isinstance(attn_bias, causal_bias_types)
|
|
|
|
|
|
def _window_size(attn_bias: Optional[Union[torch.Tensor, AttentionBias]]) -> int:
    """Return the local-attention window carried by ``attn_bias``.

    Biases that are not one of the two local-attention mask types yield 0.
    """
    window = 0
    if isinstance(attn_bias, BlockDiagonalCausalLocalAttentionMask):
        # A falsy window (e.g. None) on this mask type is reported as 0.
        window = attn_bias._window_size or 0
    elif isinstance(attn_bias, BlockDiagonalCausalLocalAttentionFromBottomRightMask):
        window = attn_bias._window_size
    return window
|
|
|
|
|
|
def _check_needs_no_topleft(d: Inputs, reasons: List[str]) -> None:
    """Append a reason to ``reasons`` if ``d`` needs a top-left causal mask.

    Flash does not support TopLeft, so causal masks are only accepted when
    each batch element has an equal number of queries and keys.

    Args:
        d: Inputs whose ``attn_bias`` (and, for ``LowerTriangularMask``,
            query/key sequence lengths) are inspected.
        reasons: List extended in place with at most one message.
    """
    bias = d.attn_bias
    if isinstance(bias, BlockDiagonalCausalMask):
        # zip_longest also flags the case where one side has more sequence
        # starts than the other.
        starts = zip_longest(bias.k_seqinfo.seqstart_py, bias.q_seqinfo.seqstart_py)
        if any(k_start != q_start for k_start, q_start in starts):
            reasons.append(
                "Only support BlockDiagonalCausalMask if equal"
                " numbers of keys and queries"
            )
    elif isinstance(bias, LowerTriangularMask):
        if d.query.shape[1] != d.key.shape[1]:
            # The original message was built from two adjacent literals with
            # no separating space ("number ofkeys"); fixed here.
            reasons.append(
                "Only support LowerTriangularMask if equal number of keys and queries"
            )
|
|
|
|
|
|
def _check_strides_for_bmghk(x: torch.Tensor, name: str, reasons: List[str]) -> None:
|
|
"""
|
|
We want to be able to collapse the G/H dimensions together
|
|
"""
|
|
if x.ndim == 5:
|
|
stride_g, stride_h = x.stride(2), x.stride(3)
|
|
if x.shape[2] == 1:
|
|
return
|
|
if x.shape[3] == 1 or stride_h == 0:
|
|
return
|
|
if stride_g != stride_h * x.shape[-2]:
|
|
reasons.append(
|
|
f"GQA is only supported when the G/H dimensions are contiguous\n"
|
|
f" {name}.stride: {x.stride()}\n"
|
|
f" {name}.shape : {list(x.shape)}"
|
|
)
|
|
|
|
|
|
@register_operator
class FwOp(AttentionFwOpBase):
    """Operator that computes memory-efficient attention using \
        `Flash-Attention <https://github.com/HazyResearch/flash-attention>`_ \
        implementation.
    """

    OPERATOR = get_operator("xformers_flash", "flash_fwd")
    SUPPORTED_DEVICES: Set[str] = {"cuda"}
    SUPPORTED_DTYPES: Set[torch.dtype] = {torch.half, torch.bfloat16}
    SUPPORTED_MAX_K = 256
    SUPPORTED_ATTN_BIAS_TYPES: Set[Any] = {
        type(None),
        LowerTriangularMask,
        BlockDiagonalMask,
        BlockDiagonalCausalMask,
        BlockDiagonalCausalLocalAttentionMask,
        BlockDiagonalCausalLocalAttentionFromBottomRightMask,
        BlockDiagonalCausalFromBottomRightMask,
    }
    SUPPORTS_DROPOUT = True
    SUPPORTS_CUSTOM_SCALE = True
    SUPPORTS_DIFFERENT_VALUE_EMBED = False
    SUPPORTS_BMGHK = True
    NAME = f"flshattF@{FLASH_VERSION}"
    VERSION = FLASH_VERSION

    @classmethod
    def not_supported_reasons(cls, d: Inputs) -> List[str]:
        """Collect the reasons why ``d`` cannot be handled by this operator.

        Extends the base-class checks with the query last-dimension alignment
        check, the top-left causal mask check, and the G/H stride check on
        query, key and value.
        """
        reasons = super(FwOp, cls).not_supported_reasons(d)
        check_lastdim_alignment_stride1(reasons, "query", d.query, 8)
        _check_needs_no_topleft(d, reasons)
        for label, tensor in (("query", d.query), ("key", d.key), ("value", d.value)):
            _check_strides_for_bmghk(tensor, label, reasons)
        return reasons

    @classmethod
    def apply(
        cls, inp: Inputs, needs_gradient: bool
    ) -> Tuple[torch.Tensor, Optional[Context]]:
        """Run the forward kernel and return ``(out, ctx)``.

        ``needs_gradient`` is not consulted; a ``Context`` holding the output,
        the log-sum-exp and the padded tensors is always returned.
        """
        # Output shape is taken from the caller's tensors, before conversion.
        out_shape = list(inp.query.shape[:-1]) + [inp.value.shape[-1]]

        converted = _convert_input_format(inp)
        inp, cu_seqlens_q, max_seqlen_q, cu_seqlens_k, max_seqlen_k = converted

        kernel_outputs = cls.OPERATOR(
            inp.query,
            inp.key,
            inp.value,
            cu_seqlens_q,
            cu_seqlens_k,
            max_seqlen_q,
            max_seqlen_k,
            inp.p,
            inp.scale_float,
            _is_causal(inp.attn_bias),
            _window_size(inp.attn_bias),
            False,  # return_softmax
            inp.use_alibi,
            inp.alibi_mode,
            inp.imp_mode,
        )
        out, softmax_lse, q_padded, k_padded, v_padded, o_padded = kernel_outputs

        out = out.reshape(out_shape)
        ctx = Context(
            out=out,
            lse=softmax_lse,
            q_padded=q_padded,
            k_padded=k_padded,
            v_padded=v_padded,
            o_padded=o_padded,
        )
        # With dropout enabled, the backward pass is pinned to the flash BwOp.
        if inp.p != 0.0:
            ctx.op_bw = BwOp
        return (out, ctx)

    @classmethod
    # type: ignore
    def operator_flop(
        cls,
        query,
        key,
        value,
        cu_seq_lens_q,
        cu_seq_lens_k,
        max_seq_len_q,
        max_seq_len_k,
        p,
        softmax_scale,
        causal,
        return_softmax,
    ) -> int:
        """Return the FLOP count reported by ``attn_operator_flop``.

        NOTE(review): this signature has fewer parameters than the
        ``flash_fwd`` operator schema -- confirm how callers invoke it.
        """
        batched_query = query.unsqueeze(0)
        batched_key = key.unsqueeze(0)
        batched_value = value.unsqueeze(0)
        return cls.attn_operator_flop(
            batched_query,
            batched_key,
            batched_value,
            causal=causal,
            seqstart_k=cu_seq_lens_k,
            seqstart_q=cu_seq_lens_q,
        )
|
|
|
|
|
|
@register_operator
class BwOp(AttentionBwOpBase):
    __doc__ = FwOp.__doc__

    OPERATOR = get_operator("xformers_flash", "flash_bwd")
    SUPPORTED_DEVICES = FwOp.SUPPORTED_DEVICES
    SUPPORTED_DTYPES = FwOp.SUPPORTED_DTYPES
    SUPPORTED_MAX_K = FwOp.SUPPORTED_MAX_K
    SUPPORTED_ATTN_BIAS_TYPES = FwOp.SUPPORTED_ATTN_BIAS_TYPES
    SUPPORTS_DROPOUT = FwOp.SUPPORTS_DROPOUT
    SUPPORTS_CUSTOM_SCALE = FwOp.SUPPORTS_CUSTOM_SCALE
    SUPPORTS_DIFFERENT_VALUE_EMBED = FwOp.SUPPORTS_DIFFERENT_VALUE_EMBED
    IS_DETERMINISTIC = False
    SUPPORTS_BMGHK = False  # NOTE: Don't forget to update fmha doc when changing this!
    NAME = f"flshattB@{FLASH_VERSION}"
    VERSION = FLASH_VERSION

    # Largest head dimension accepted on GPUs other than sm80 / sm90.
    MAX_HEADDIM_SM8x = 192

    @classmethod
    def shape_not_supported_reasons(
        cls, Mq: int, Mkv: int, K: int, Kv: int
    ) -> List[str]:
        """Return the base-class shape restrictions unchanged."""
        reasons = super().shape_not_supported_reasons(Mq, Mkv, K, Kv)

        # In fbcode in mode/dev-nosan, we get nans from flash v2.1 if there
        # is a strange embedding dimension.
        # if K not in {8, 16, 32, 64, 128, 256}:
        #     reasons.append(f"Embed dim {K} not supported")

        return reasons

    @classmethod
    def not_supported_reasons(cls, d: Inputs) -> List[str]:
        """Collect the reasons why ``d`` cannot be handled by this operator.

        Extends the base-class checks with the query last-dimension alignment
        check, the top-left causal mask check, and a head-dimension limit on
        CUDA devices that are neither sm80 nor sm90.
        """
        reasons = super(BwOp, cls).not_supported_reasons(d)
        check_lastdim_alignment_stride1(reasons, "query", d.query, 8)
        _check_needs_no_topleft(d, reasons)
        if d.device.type == "cuda":
            # Due to limited shared-memory, some GPUs are limited in head dimension
            device_capability = torch.cuda.get_device_capability(d.device)
            is_sm80_or_sm90 = device_capability in [(8, 0), (9, 0)]
            if (
                max(d.key.shape[-1], d.query.shape[-1]) > cls.MAX_HEADDIM_SM8x
                and not is_sm80_or_sm90
            ):
                reasons.append(
                    "requires a GPU with compute capability 8.0 "
                    f"(A100) or 9.0 (H100) for 'query.shape[-1] > {cls.MAX_HEADDIM_SM8x}'"
                )
        return reasons

    @classmethod
    def apply(cls, ctx: Context, inp: Inputs, grad: torch.Tensor) -> Gradients:
        """Run the backward kernel and return the gradients of q, k and v.

        Args:
            ctx: Forward context; its ``lse`` and ``q/k/v/o_padded`` tensors
                are passed to the kernel.
            inp: The forward inputs.
            grad: Gradient of the attention output; its dtype must be in
                ``SUPPORTED_DTYPES``.

        Returns:
            ``Gradients`` whose tensors are sliced on the last dimension and
            reshaped to the shapes of the original query, key and value.
        """
        dq_shape, dk_shape, dv_shape = inp.query.shape, inp.key.shape, inp.value.shape
        (
            inp,
            cu_seqlens_q,
            max_seqlen_q,
            cu_seqlens_k,
            max_seqlen_k,
        ) = _convert_input_format(inp)
        # `is_contiguous` must be called: the bare bound method is always
        # truthy, which made the original assertion a no-op.
        assert ctx.lse.is_contiguous()
        ctx_lse = ctx.lse
        # Use the module-level flag (same environment variable, read at
        # import time) so this agrees with `_convert_input_format` and the
        # kernel dispatch.
        if not enable_ixdnn:
            assert ctx_lse.shape[2] >= max_seqlen_q
            if max_seqlen_q != ctx_lse.shape[2]:
                ctx_lse = ctx_lse[:, :, :max_seqlen_q].contiguous()
        kernel_out_shape = [
            *inp.query.shape[:-1],
            inp.value.shape[-1],
        ]

        # Create dq,dk,dv
        # If Q/K/V come from a single QKV tensor, let's put the gradient in the
        # right strides, so we can avoid a `cat`
        q_storage = _get_storage_base(ctx.q_padded)
        if (
            ctx.q_padded.shape[0] == ctx.k_padded.shape[0]
            and ctx.q_padded.shape[-1] == ctx.v_padded.shape[-1]
            and q_storage == _get_storage_base(ctx.k_padded)
            # The original repeated the q/k comparison here and never looked
            # at v; all three must share storage for the fused layout.
            and q_storage == _get_storage_base(ctx.v_padded)
        ):
            # Create one big contiguous chunk
            # This is because q, k and v usually come from a single
            # output of a linear layer that is chunked.
            # Creating the gradients with the right layout saves us
            # a `torch.cat` call in the backward pass
            chunk = torch.empty(
                (*ctx.q_padded.shape[0:-2], 3, ctx.q_padded.shape[-2], ctx.q_padded.shape[-1]),
                dtype=ctx.q_padded.dtype,
                device=inp.device,
            )
            grads = Gradients(
                dq=chunk.select(-3, 0),
                dk=chunk.select(-3, 1),
                dv=chunk.select(-3, 2),
            )
        else:
            grads = Gradients(
                dq=torch.empty_like(ctx.q_padded),
                dk=torch.empty_like(ctx.k_padded),
                dv=torch.empty_like(ctx.v_padded),
            )

        assert grad.dtype in cls.SUPPORTED_DTYPES
        cls.OPERATOR(
            grad.reshape(kernel_out_shape).contiguous(),
            ctx.q_padded,
            ctx.k_padded,
            ctx.v_padded,
            # ctx.out.reshape(kernel_out_shape),
            ctx.o_padded,
            ctx_lse,
            grads.dq,
            grads.dk,
            grads.dv,
            cu_seqlens_q,
            cu_seqlens_k,
            max_seqlen_q,
            max_seqlen_k,
            inp.p,
            inp.scale_float,
            _is_causal(inp.attn_bias),
            _window_size(inp.attn_bias),
            inp.use_alibi,
            inp.alibi_mode,
            inp.imp_mode,
        )
        grads.dq = grads.dq[..., :dq_shape[-1]].reshape(dq_shape)  # We could have padded the head dimension
        grads.dk = grads.dk[..., :dk_shape[-1]].reshape(dk_shape)
        grads.dv = grads.dv[..., :dv_shape[-1]].reshape(dv_shape)

        return grads

    @classmethod
    # type: ignore
    def operator_flop(
        cls,
        grad,
        query,
        key,
        value,
        out,
        lse,
        dq,
        dk,
        dv,
        cu_seq_lens_q,
        cu_seq_lens_k,
        max_seq_len_q,
        max_seq_len_k,
        p,
        softmax_scale,
        causal,
    ) -> int:
        """Return the FLOP count reported by ``attn_operator_flop``.

        NOTE(review): this signature has fewer parameters than the
        ``flash_bwd`` operator schema -- confirm how callers invoke it.
        """
        return cls.attn_operator_flop(
            query.unsqueeze(0),
            key.unsqueeze(0),
            value.unsqueeze(0),
            causal=causal,
            seqstart_k=cu_seq_lens_k,
            seqstart_q=cu_seq_lens_q,
        )
|