# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
|
||
|
|
|
||
|
|
|
||
|
|
from dataclasses import replace
|
||
|
|
from enum import Enum
|
||
|
|
from typing import Any, List, Mapping, Optional, Set, Tuple, Union
|
||
|
|
|
||
|
|
import torch
|
||
|
|
|
||
|
|
from ..common import get_xformers_operator, register_operator
|
||
|
|
from . import attn_bias
|
||
|
|
from .attn_bias import (
|
||
|
|
AttentionBias,
|
||
|
|
BlockDiagonalCausalLocalAttentionFromBottomRightMask,
|
||
|
|
BlockDiagonalCausalLocalAttentionMask,
|
||
|
|
BlockDiagonalCausalMask,
|
||
|
|
BlockDiagonalCausalWithOffsetPaddedKeysMask,
|
||
|
|
BlockDiagonalMask,
|
||
|
|
LowerTriangularMask,
|
||
|
|
LowerTriangularMaskWithTensorBias,
|
||
|
|
)
|
||
|
|
from .common import (
|
||
|
|
AttentionBwOpBase,
|
||
|
|
AttentionFwOpBase,
|
||
|
|
Context,
|
||
|
|
Gradients,
|
||
|
|
Inputs,
|
||
|
|
check_lastdim_alignment_stride1,
|
||
|
|
)
|
||
|
|
|
||
|
|
|
||
|
|
def _uses_tensorcores(sm: int, is_half: bool) -> bool:
    """Tell whether tensor cores are used for a given SM version and precision.

    Args:
        sm: compute capability encoded as ``major * 10 + minor`` (e.g. 75, 80).
        is_half: whether the scalar type is 16 bits wide.

    Returns:
        ``False`` below SM 70, ``is_half`` for SM 70 up to (excluding) 80,
        and ``True`` from SM 80 onwards.
    """
    if sm < 70:
        return False
    return True if sm >= 80 else is_half
|
||
|
|
|
||
|
|
|
||
|
|
def _minimum_gemm_alignment(inp: Inputs) -> int:
    """Compute the alignment (in elements) required by the matmuls for ``inp``.

    Args:
        inp: operator inputs; ``inp.device`` and ``inp.query.dtype`` are read.

    Returns:
        ``1`` for non-CUDA devices. On CUDA, at least ``4`` from SM 80 onwards,
        raised to ``128 // bits_per_scalar`` when ``_uses_tensorcores`` says
        tensor cores are in use.

    Raises:
        KeyError: if the query dtype is not float, half or bfloat16.
    """
    device = inp.device
    if device.type != "cuda":
        return 1

    major, minor = torch.cuda.get_device_capability(device)
    sm = 10 * major + minor

    scalar_bits = {torch.float: 32, torch.half: 16, torch.bfloat16: 16}[
        inp.query.dtype
    ]

    alignment = 4 if sm >= 80 else 1
    if _uses_tensorcores(sm, scalar_bits == 16):
        # Number of scalars that fit in 128 bits.
        alignment = max(alignment, 128 // scalar_bits)
    return alignment
|
||
|
|
|
||
|
|
|
||
|
|
def _get_seqlen_info(
    inp: Inputs,
) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor], int, int]:
    """Extract the variable-length sequence description from ``inp.attn_bias``.

    Returns:
        ``(seqstart_k, seqstart_q, max_seqlen_q, max_seqlen_k)``. For any bias
        that is neither a ``BlockDiagonalMask`` nor a
        ``BlockDiagonalCausalWithOffsetPaddedKeysMask`` this is
        ``(None, None, -1, -1)``.
    """
    bias = inp.attn_bias
    block_diagonal_types = (
        BlockDiagonalMask,
        BlockDiagonalCausalWithOffsetPaddedKeysMask,
    )
    if not isinstance(bias, block_diagonal_types):
        return None, None, -1, -1

    # The return values of `.to()` are discarded, so these calls presumably
    # move the sequence info in place -- TODO confirm against `attn_bias`.
    bias.k_seqinfo.to(inp.query.device)
    bias.q_seqinfo.to(inp.query.device)
    return (
        bias.k_seqinfo.seqstart,
        bias.q_seqinfo.seqstart,
        bias.q_seqinfo.max_seqlen,
        bias.k_seqinfo.max_seqlen,
    )
|
||
|
|
|
||
|
|
|
||
|
|
def _get_tensor_bias(
    attn_bias: Optional[Union[torch.Tensor, AttentionBias]]
) -> Optional[torch.Tensor]:
    """Return the tensor carried by ``attn_bias``, if there is one.

    A plain ``torch.Tensor`` is returned unchanged, a
    ``LowerTriangularMaskWithTensorBias`` yields its ``_bias`` attribute, and
    anything else (including ``None``) yields ``None``.
    """
    return (
        attn_bias
        if isinstance(attn_bias, torch.Tensor)
        else attn_bias._bias
        if isinstance(attn_bias, LowerTriangularMaskWithTensorBias)
        else None
    )
|
||
|
|
|
||
|
|
|
||
|
|
def _check_bias_alignment(
    reasons: List[str], attn_bias: Optional[Union[torch.Tensor, AttentionBias]]
) -> None:
    """Append to ``reasons`` every alignment problem found in the bias tensor.

    The tensor is obtained through ``_get_tensor_bias``; when there is none,
    ``reasons`` is left untouched.

    Args:
        reasons: list of human-readable messages, extended in place.
        attn_bias: attention bias to inspect.
    """
    attn_bias_tensor = _get_tensor_bias(attn_bias)
    if attn_bias_tensor is None:
        return

    # Number of elements of the bias dtype that fit in 128 bits.
    alignment = 128 // torch.finfo(attn_bias_tensor.dtype).bits
    ndim = attn_bias_tensor.ndim
    strides = attn_bias_tensor.stride()

    show_padding_hint = False
    # Every stride except the innermost one has to be a multiple of `alignment`.
    for d in range(ndim - 1):
        if strides[d] % alignment != 0:
            # Report the dimension that actually failed, as a negative index
            # (the second-to-last dimension is still reported as `stride(-2)`).
            reasons.append(
                f"attn_bias.stride({d - ndim}) % {alignment} != 0 (attn_bias.stride() = {strides})"
            )
            show_padding_hint = True
    if show_padding_hint:
        reasons.append(
            "HINT: To use an `attn_bias` with a sequence length that is not a multiple of 8, "
            "you need to ensure memory is aligned by slicing a bigger tensor. "
            "Example: use `attn_bias = torch.zeros([1, 1, 5, 8])[:,:,:,:5]` instead of `torch.zeros([1, 1, 5, 5])`"
        )
    # We can have stride=0 sometimes if dimension=1
    if attn_bias_tensor.stride(-1) > 1:
        reasons.append(
            f"attn_bias.stride(-1) > 1 (attn_bias.stride() = {attn_bias_tensor.stride()}) - "
            "you should call `.contiguous()` on the bias"
        )
|
||
|
|
|
||
|
|
|
||
|
|
class _CustomMaskType(int, Enum):
    """
    Integer codes for the built-in mask kinds (matches CustomMaskType in C++).
    """

    # No built-in mask is applied.
    NoCustomMask = 0
    # Causal mask anchored at the top-left corner.
    CausalFromTopLeft = 1
    # Causal mask anchored at the bottom-right corner.
    CausalFromBottomRight = 2
|
||
|
|
|
||
|
|
|
||
|
|
def _custom_mask_type(bias: Optional[Union[torch.Tensor, AttentionBias]]) -> int:
    """Map an attention bias to the integer value of its ``_CustomMaskType``.

    Top-left causal biases are tested first, then bottom-right causal ones;
    every other bias (tensors and ``None`` included) maps to ``NoCustomMask``.
    """
    top_left_types = (
        LowerTriangularMask,
        BlockDiagonalCausalMask,
        BlockDiagonalCausalLocalAttentionMask,
    )
    bottom_right_types = (
        attn_bias.BlockDiagonalCausalFromBottomRightMask,
        BlockDiagonalCausalWithOffsetPaddedKeysMask,
        BlockDiagonalCausalLocalAttentionFromBottomRightMask,
    )

    if isinstance(bias, top_left_types):
        mask_kind = _CustomMaskType.CausalFromTopLeft
    elif isinstance(bias, bottom_right_types):
        mask_kind = _CustomMaskType.CausalFromBottomRight
    else:
        mask_kind = _CustomMaskType.NoCustomMask
    return int(mask_kind)
|
||
|
|
|
||
|
|
|
||
|
|
@register_operator
class FwOp(AttentionFwOpBase):
    """xFormers' MHA kernel based on CUTLASS.
    Supports a large number of settings (including without TensorCores, f32 ...)
    and GPUs as old as P100 (Sm60)
    """

    OPERATOR = get_xformers_operator("efficient_attention_forward_cutlass")
    SUPPORTED_DEVICES: Set[str] = {"cuda"}
    SUPPORTED_DTYPES: Set[torch.dtype] = {torch.float, torch.half, torch.bfloat16}
    SUPPORTED_MAX_K = 65536
    # Exact types accepted by `apply` / `apply_bmhk` (checked with `type(...)`,
    # so subclasses of these are not accepted implicitly).
    SUPPORTED_ATTN_BIAS_TYPES: Set[Any] = {
        type(None),
        torch.Tensor,
        LowerTriangularMask,
        LowerTriangularMaskWithTensorBias,
        BlockDiagonalMask,
        BlockDiagonalCausalMask,
        BlockDiagonalCausalWithOffsetPaddedKeysMask,
        attn_bias.BlockDiagonalCausalFromBottomRightMask,
        attn_bias.BlockDiagonalCausalLocalAttentionMask,
        BlockDiagonalCausalLocalAttentionFromBottomRightMask,
    }
    SUPPORTS_DROPOUT = True
    SUPPORTS_CUSTOM_SCALE = True
    SUPPORTS_DIFFERENT_VALUE_EMBED = True
    SUPPORTS_BMGHK = True
    NAME = "cutlassF"

    _TEST_K: List[int] = [
        32,  # 64x64 kernel
        128,  # 64x128 kernel
        256,  # 64x128 with accumulation in gmem
    ]

    @classmethod
    def apply(
        cls, inp: Inputs, needs_gradient: bool
    ) -> Tuple[torch.Tensor, Optional[Context]]:
        """Run the forward operator on ``inp``.

        3-D and 4-D queries are forwarded directly to ``apply_bmhk``. 5-D
        queries are split along dimension 2 (one slice per head group); each
        slice is run through ``apply_bmhk`` on its own CUDA stream and the
        results are stacked back together.

        Args:
            inp: operator inputs.
            needs_gradient: whether a ``Context`` for the backward pass must
                be returned.

        Returns:
            ``(out, ctx)`` where ``ctx`` is ``None`` unless ``needs_gradient``.

        Raises:
            NotImplementedError: if the type of ``inp.attn_bias`` is not in
                ``SUPPORTED_ATTN_BIAS_TYPES``.
        """
        if type(inp.attn_bias) not in FwOp.SUPPORTED_ATTN_BIAS_TYPES:
            raise NotImplementedError("Unsupported attn_bias type")
        if inp.query.ndim in [3, 4]:
            return cls.apply_bmhk(inp, needs_gradient=needs_gradient)
        assert inp.query.ndim == 5, f"query has shape {inp.query.shape}"

        # Workaround until this is properly implemented in C++
        # run each head group in a different stream
        n_groups = inp.key.shape[2]
        # Query the current stream of the device holding the inputs: the side
        # streams below are created on that device and must synchronise with
        # the stream the inputs were produced on, which is not necessarily the
        # stream of the process-wide current device.
        main_stream = torch.cuda.current_stream(inp.query.device)
        streams = [main_stream] + [
            torch.cuda.Stream(device=inp.query.device) for _ in range(n_groups - 1)
        ]
        outs = []
        for group, stream in enumerate(streams):
            # Work queued on the main stream so far must be visible to `stream`.
            stream.wait_stream(main_stream)
            with torch.cuda.stream(stream):
                query = inp.query[:, :, group]
                key = inp.key[:, :, group]
                value = inp.value[:, :, group]
                bias = inp.attn_bias
                # Tensor biases are sliced along dimension 1 for this group.
                if isinstance(bias, torch.Tensor):
                    bias = bias[:, group]
                if isinstance(bias, attn_bias.LowerTriangularMaskWithTensorBias):
                    bias = attn_bias.LowerTriangularMaskWithTensorBias(
                        bias._bias[:, group]
                    )
                outs.append(
                    cls.apply_bmhk(
                        replace(inp, query=query, key=key, value=value, attn_bias=bias),
                        needs_gradient=needs_gradient,
                    )
                )
        # Join the side streams back before the results are consumed.
        for s in streams[1:]:
            main_stream.wait_stream(s)
        out = torch.stack([o[0] for o in outs], dim=2)
        ctx: Optional[Context] = None
        if needs_gradient:
            ctx = Context(
                out=out,
                lse=torch.stack([o[1].lse for o in outs], dim=1),  # type: ignore
                op_bw=outs[0][1].op_bw,  # type: ignore
            )
        return out, ctx

    @classmethod
    def apply_bmhk(
        cls, inp: Inputs, needs_gradient: bool
    ) -> Tuple[torch.Tensor, Optional[Context]]:
        """Call the forward operator once on ``inp``.

        Args:
            inp: operator inputs.
            needs_gradient: whether the log-sum-exp must be computed and a
                ``Context`` returned for the backward pass.

        Returns:
            ``(out, ctx)`` where ``ctx`` is ``None`` unless ``needs_gradient``.
            When dropout is enabled (``inp.p != 0``) the context also carries
            the RNG seed/offset as a CPU int64 tensor in ``rng_state``.

        Raises:
            NotImplementedError: if the type of ``inp.attn_bias`` is not in
                ``SUPPORTED_ATTN_BIAS_TYPES``.
        """
        if type(inp.attn_bias) not in FwOp.SUPPORTED_ATTN_BIAS_TYPES:
            raise NotImplementedError("Unsupported attn_bias type")
        seqstart_k, seqstart_q, max_seqlen_q, _ = _get_seqlen_info(inp)
        out, lse, rng_seed, rng_offset = cls.OPERATOR(
            query=inp.query,
            key=inp.key,
            value=inp.value,
            attn_bias=_get_tensor_bias(inp.attn_bias),
            seqstart_q=seqstart_q,
            seqstart_k=seqstart_k,
            max_seqlen_q=max_seqlen_q,
            dropout_p=inp.p,
            compute_logsumexp=needs_gradient,
            custom_mask_type=_custom_mask_type(inp.attn_bias),
            scale=inp.scale,
            # Per-sequence key lengths are only passed for padded-keys biases.
            seqlen_k=inp.attn_bias.k_seqinfo.seqlen
            if isinstance(inp.attn_bias, BlockDiagonalCausalWithOffsetPaddedKeysMask)
            else None,
            # Only the local-attention biases carry a window size.
            window_size=inp.attn_bias._window_size
            if isinstance(
                inp.attn_bias,
                (
                    BlockDiagonalCausalLocalAttentionMask,
                    BlockDiagonalCausalLocalAttentionFromBottomRightMask,
                ),
            )
            else None,
        )
        ctx: Optional[Context] = None
        if needs_gradient:
            ctx = Context(
                out=out,
                lse=lse,
                # cutlass forward is only compatible with cutlass backward if
                # dropout is used (because of the way RNG states are passed and the
                # way random numbers are generated during backward)
                op_bw=BwOp if inp.p != 0 else None,
            )
            if inp.p != 0:
                ctx.rng_state = torch.tensor(
                    [rng_seed, rng_offset], dtype=torch.int64, device="cpu"
                )
        return out, ctx

    @classmethod
    def not_supported_reasons(cls, d: Inputs) -> List[str]:
        """List the reasons why this operator cannot run on ``d``.

        Extends the base-class reasons with alignment checks on the last
        dimension of ``query`` and ``value`` and on the bias tensor strides.
        """
        reasons = super(FwOp, cls).not_supported_reasons(d)
        matmul_alignment_mn = _minimum_gemm_alignment(d)
        check_lastdim_alignment_stride1(reasons, "query", d.query, matmul_alignment_mn)
        check_lastdim_alignment_stride1(reasons, "value", d.value, matmul_alignment_mn)
        _check_bias_alignment(reasons, d.attn_bias)
        return reasons

    @classmethod
    def operator_flop(  # type: ignore
        cls,
        q,
        k,
        v,
        b,
        seqstart_q,
        seqstart_k,
        max_seqlen_q_,
        compute_lse,
        custom_mask_type,
        *a,
    ) -> int:
        """Estimate the FLOP count of one operator call from its arguments.

        The leading positional parameters mirror the operator's arguments;
        any trailing arguments are accepted and ignored. The call is treated
        as causal whenever ``custom_mask_type`` is greater than zero.
        """
        return cls.attn_operator_flop(
            q,
            k,
            v,
            causal=custom_mask_type > 0,
            seqstart_k=seqstart_k,
            seqstart_q=seqstart_q,
        )
|
||
|
|
|
||
|
|
|
||
|
|
@register_operator
class BwOp(AttentionBwOpBase):
    __doc__ = FwOp.__doc__

    OPERATOR = get_xformers_operator("efficient_attention_backward_cutlass")
    SUPPORTED_DEVICES = FwOp.SUPPORTED_DEVICES
    SUPPORTED_DTYPES = FwOp.SUPPORTED_DTYPES
    SUPPORTED_MAX_K = FwOp.SUPPORTED_MAX_K
    # Exact types accepted by `apply` (checked with `type(...)`).
    SUPPORTED_ATTN_BIAS_TYPES: Set[Any] = {
        type(None),
        torch.Tensor,
        LowerTriangularMask,
        # TODO: Fix handling of gradient through the fMHA autograd function
        # LowerTriangularMaskWithTensorBias,
        BlockDiagonalMask,
        BlockDiagonalCausalMask,
        attn_bias.BlockDiagonalCausalFromBottomRightMask,
        attn_bias.BlockDiagonalCausalLocalAttentionMask,
    }
    SUPPORTS_ATTN_BIAS_GRAD = True
    SUPPORTS_DROPOUT = FwOp.SUPPORTS_DROPOUT
    SUPPORTS_CUSTOM_SCALE = FwOp.SUPPORTS_CUSTOM_SCALE
    SUPPORTS_DIFFERENT_VALUE_EMBED = FwOp.SUPPORTS_DIFFERENT_VALUE_EMBED
    NAME = "cutlassB"

    ERROR_ATOL: Mapping[torch.dtype, float] = {
        torch.float: 5e-4,
        # increased from 9e-2, more opportunities for numerical errors when bias is
        # used, noticed in gK on SM80
        torch.half: 1e-1,
        torch.bfloat16: 7e-1,
    }

    _TEST_K: List[int] = [
        32,  # 64x64 kernel
        128,  # 64x128/128x128 kernel
        256,  # 64x128 with accumulation in gmem
    ]

    @classmethod
    def not_supported_reasons(cls, d: Inputs) -> List[str]:
        """List the reasons why this operator cannot run on ``d``.

        Extends the base-class reasons with alignment checks on the last
        dimension of ``query``, ``key`` and ``value``, on the bias tensor
        strides, and rejects bias tensors that require a gradient while
        being broadcast.
        """
        reasons = super(BwOp, cls).not_supported_reasons(d)
        matmul_alignment_mn = _minimum_gemm_alignment(d)

        check_lastdim_alignment_stride1(reasons, "query", d.query, matmul_alignment_mn)
        check_lastdim_alignment_stride1(reasons, "key", d.key, matmul_alignment_mn)
        check_lastdim_alignment_stride1(reasons, "value", d.value, matmul_alignment_mn)
        _check_bias_alignment(reasons, d.attn_bias)
        attn_bias_tensor = _get_tensor_bias(d.attn_bias)

        # Backprop of gradient through broadcasted bias is not supported
        if attn_bias_tensor is not None and attn_bias_tensor.requires_grad:
            # Don't forget that inputs are either in BMK or BMHK!
            if d.query.ndim == 3 and attn_bias_tensor.ndim == 3:
                expected_bias_shape = (*d.query.shape[:2], d.key.shape[1])
            else:
                # bias is B H Mq Mk
                expected_bias_shape = (
                    d.query.shape[0],
                    d.query.shape[2] if d.query.ndim == 4 else 1,
                    d.query.shape[1],
                    d.key.shape[1],
                )
            if tuple(attn_bias_tensor.shape) != expected_bias_shape:
                reasons.append(
                    "Broadcasting the `attn_bias` tensor is not supported "
                    f"(shape: {tuple(attn_bias_tensor.shape)}"
                    f"/ expected: {expected_bias_shape})"
                )
        return reasons

    @classmethod
    def apply(cls, ctx: Context, inp: Inputs, grad: torch.Tensor) -> Gradients:
        """Run the backward operator.

        Args:
            ctx: context saved by the forward pass (output, log-sum-exp and,
                when dropout is enabled, the RNG state).
            inp: the inputs of the forward pass.
            grad: gradient with respect to the forward output.

        Returns:
            Gradients for query, key, value and, only when ``inp.attn_bias``
            is a tensor that requires grad, the bias.

        Raises:
            NotImplementedError: if the bias type is unsupported, or if
                dropout is enabled and ``ctx.rng_state`` is not a CPU int64
                tensor of shape ``(2,)``.
        """
        if type(inp.attn_bias) not in BwOp.SUPPORTED_ATTN_BIAS_TYPES:
            raise NotImplementedError("Unsupported attn_bias type")

        seqstart_k, seqstart_q, max_seqlen_q, max_seqlen_k = _get_seqlen_info(inp)
        dtype = inp.query.dtype

        rng_seed = rng_offset = 0
        if inp.p != 0.0:
            if (
                ctx.rng_state is None
                or ctx.rng_state.dtype != torch.int64
                or ctx.rng_state.device.type != "cpu"
                or ctx.rng_state.shape != (2,)
            ):
                raise NotImplementedError(f"Invalid rng_state: {ctx.rng_state}")
            rng_seed, rng_offset = ctx.rng_state.tolist()

        # Padding of the log-sum-exp is forced to inf on compute capability 7.5.
        force_pad_inf = torch.cuda.get_device_capability(inp.query.device) == (7, 5)
        (grad_q, grad_k, grad_v, grad_bias) = cls.OPERATOR(
            grad.to(dtype),
            inp.query,
            inp.key,
            inp.value,
            _get_tensor_bias(inp.attn_bias),
            cu_seqlens_q=seqstart_q,
            cu_seqlens_k=seqstart_k,
            max_seqlen_q=max_seqlen_q,
            max_seqlen_k=max_seqlen_k,
            logsumexp=ctx.get_padded_lse(32, force_pad_inf=force_pad_inf),
            output=ctx.out.to(dtype),
            dropout_p=inp.p,
            # if not using dropout, seed and offset are irrelevant but still expected
            # in function signature so just pass 0
            # seed and offset could be None if a different FW op other than cutlass
            # was used.
            rng_seed=rng_seed,
            rng_offset=rng_offset,
            custom_mask_type=_custom_mask_type(inp.attn_bias),
            scale=inp.scale,
            num_splits_key=-1,  # Let C++ determine it
            # Only the local-attention biases carry a window size.
            window_size=inp.attn_bias._window_size
            if isinstance(
                inp.attn_bias,
                (
                    BlockDiagonalCausalLocalAttentionMask,
                    BlockDiagonalCausalLocalAttentionFromBottomRightMask,
                ),
            )
            else None,
        )

        # c++/CUDA implementation returns an uninitialized tensor if bias doesn't
        # require grad
        if not (
            isinstance(inp.attn_bias, torch.Tensor) and inp.attn_bias.requires_grad
        ):
            grad_bias = None

        return Gradients(dq=grad_q, dk=grad_k, dv=grad_v, db=grad_bias)

    @classmethod
    def operator_flop(  # type: ignore
        cls,
        dO,
        q,
        k,
        v,
        b,
        cu_seqlens_q,
        cu_seqlens_k,
        max_seqlen_q,
        max_seqlen_k,
        logsumexp,
        output,
        dropout_p,
        rng_seed,
        rng_offset,
        custom_mask_type,
        scale,
        num_splits_key=None,
        window_size=None,
    ) -> int:
        """Estimate the FLOP count of one operator call from its arguments.

        The parameters mirror the arguments ``apply`` passes to the operator,
        including ``num_splits_key`` and ``window_size``; those two are
        optional and unused here, they only keep this signature callable with
        the operator's full argument list. The call is treated as causal
        whenever ``custom_mask_type`` is greater than zero.
        """
        return cls.attn_operator_flop(
            q,
            k,
            v,
            seqstart_q=cu_seqlens_q,
            seqstart_k=cu_seqlens_k,
            causal=custom_mask_type > 0,
        )
|