109 lines
3.7 KiB
Python
109 lines
3.7 KiB
Python
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
|
|
|
|
from typing import Any, List, Optional, Set, Tuple
|
|
|
|
import numpy as np
|
|
import torch
|
|
|
|
from ..common import get_xformers_operator, register_operator
|
|
from .attn_bias import BlockDiagonalCausalWithOffsetPaddedKeysMask
|
|
from .common import AttentionFwOpBase, Context, Inputs
|
|
|
|
|
|
@register_operator
class FwOp(AttentionFwOpBase):
    """Decoding attention forward operator backed by the
    ``efficient_attention_forward_decoder`` kernel (``NAME = "decoderF"``).

    Only :class:`BlockDiagonalCausalWithOffsetPaddedKeysMask` biases are
    supported. For such a bias, :meth:`not_supported_reasons` additionally
    requires that:

    * query and key are 4-D (BMHK); BMK inputs are rejected,
    * the formal batch dimension ``query.shape[0]`` is 1,
    * the head dimension ``query.shape[-1]`` is exactly 128,
    * key and value are contiguous along their last dimension,
    * every query sequence has length 1 and ``query.shape[1]`` equals
      ``len(q_seqinfo.seqstart_py) - 1``,
    * the key padding ``k_seqinfo.padding`` is at most 8192.

    Custom softmax scales are honored. Dropout is not supported, and no
    backward pass is implemented.
    """

    # NOTE(review): an earlier docstring marked this operator as deprecated
    # (alongside small-K / BMK claims that contradict the checks below);
    # confirm whether a deprecation notice is actually intended here.

    OPERATOR = get_xformers_operator("efficient_attention_forward_decoder")
    SUPPORTED_DEVICES = {"cuda"}
    SUPPORTED_DTYPES = {torch.bfloat16, torch.half, torch.float32}
    CUDA_MINIMUM_COMPUTE_CAPABILITY = (7, 0)
    # Upper bound advertised to the dispatcher; not_supported_reasons is
    # stricter and only accepts a head dim of exactly 128.
    SUPPORTED_MAX_K: float = 128
    SUPPORTED_ATTN_BIAS_TYPES: Set[Any] = {BlockDiagonalCausalWithOffsetPaddedKeysMask}
    SUPPORTS_DROPOUT = False
    SUPPORTS_CUSTOM_SCALE = True
    NAME = "decoderF"

    @classmethod
    def not_supported_reasons(cls, d: Inputs) -> List[str]:
        """Return the reasons this operator cannot handle ``d``.

        Extends the base-class checks with the decoder kernel's layout and
        shape constraints. Those constraints are only evaluated when
        ``d.attn_bias`` is a :class:`BlockDiagonalCausalWithOffsetPaddedKeysMask`.

        Args:
            d: The attention inputs to check.

        Returns:
            A list of human-readable reasons. An empty list means no reason
            was found.
        """
        reasons = super(FwOp, cls).not_supported_reasons(d)

        attn_bias = d.attn_bias
        if isinstance(attn_bias, BlockDiagonalCausalWithOffsetPaddedKeysMask):
            # Other bias types should already have been reported by the base
            # class (presumably via SUPPORTED_ATTN_BIAS_TYPES); if they were
            # not, the error is elsewhere.
            if d.query.ndim != 4 or d.key.ndim != 4:
                reasons.append("Inputs must be BMHK. BMK not supported")

            # apply() only reads batch index 0 and splits M into padded rows.
            if d.query.shape[0] != 1:
                reasons.append("One formal batch element expected")

            # Stricter than SUPPORTED_MAX_K: exactly 128 is required.
            if d.query.shape[-1] != 128:
                reasons.append("Only head_dim==128 for now.")

            if d.key.stride(-1) != 1:
                reasons.append("expect keys to have last dim contiguous")

            if d.value.stride(-1) != 1:
                reasons.append("expect values to have last dim contiguous")

            # Pure decoding step: each query sequence has length 1, and the
            # query rows must line up with the bias' sequence starts (i.e.
            # presumably one row per sequence, with no empty lanes).
            q_starts = attn_bias.q_seqinfo.seqstart_py
            if attn_bias.q_seqinfo.max_seqlen != 1:
                reasons.append("decoding expects one query")
            elif d.query.shape[1] != len(q_starts) - 1:
                reasons.append("empty lanes not supported yet")

            if attn_bias.k_seqinfo.padding > 8192:
                reasons.append("key padding exceeds 8192")

        return reasons

    @classmethod
    def apply(
        cls, inp: Inputs, needs_gradient: bool
    ) -> Tuple[torch.Tensor, Optional[Context]]:
        """Run the decoder kernel on ``inp``.

        Args:
            inp: Attention inputs. They are expected to have passed
                :meth:`not_supported_reasons` (BMHK layout, formal batch of 1,
                one query per sequence, and a
                ``BlockDiagonalCausalWithOffsetPaddedKeysMask`` bias).
            needs_gradient: Must be False, because no backward pass exists.

        Returns:
            ``(out, None)``: the kernel output and no saved context.

        Raises:
            NotImplementedError: If ``needs_gradient`` is True.
        """
        if needs_gradient:
            raise NotImplementedError("gradient")
        attn_bias = inp.attn_bias
        assert isinstance(attn_bias, BlockDiagonalCausalWithOffsetPaddedKeysMask)

        # The return values are discarded, so ``.to`` presumably moves the
        # seqinfo tensors to the query's device in place.
        # TODO confirm against the seqinfo class.
        attn_bias.k_seqinfo.to(inp.query.device)
        attn_bias.q_seqinfo.to(inp.query.device)

        padding = attn_bias.k_seqinfo.padding
        # A zero stride on dim 2 (H in BMHK) means every head views the same
        # key memory (multi-query); pass a single head in that case.
        # NOTE(review): only the key's stride is inspected, but value is
        # sliced the same way, so value is assumed to be expanded likewise.
        multiquery = inp.key.stride(2) == 0
        # Drop the formal batch dim and split the packed key axis into rows
        # of ``padding`` keys: (Mk, H, K) -> (Mk // padding, padding, H, K),
        # with H reduced to 1 in the multi-query case.
        if multiquery:
            key = inp.key[0, :, :1].unflatten(0, (-1, padding))
            value = inp.value[0, :, :1].unflatten(0, (-1, padding))
        else:
            key = inp.key[0].unflatten(0, (-1, padding))
            value = inp.value[0].unflatten(0, (-1, padding))

        # Presumably the per-sequence key lengths within each padded row.
        seq_positions = attn_bias.k_seqinfo.seqlen

        # (Mq, H, K) -> (Mq, 1, H, K): a length-1 query axis per row.
        query = inp.query[0, :, None]

        if inp.scale is not None:
            qk_scale = inp.scale
        else:
            # Default softmax scaling of 1/sqrt(head_dim).
            qk_scale = 1.0 / np.sqrt(key.shape[-1])

        out = cls.OPERATOR(
            query=query,
            key=key,
            value=value,
            seq_positions=seq_positions,
            scale=qk_scale,
        )
        return out, None
|