# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

import math
from dataclasses import dataclass
from typing import Any, List, Mapping, Optional, Set, Tuple, Type, Union

import torch

from ..._cpp_lib import _built_with_cuda
from ..common import BaseOperator
from .attn_bias import (
    AttentionBias,
    BlockDiagonalMask,
    LowerTriangularMask,
    LowerTriangularMaskWithTensorBias,
)


def _is_bias_type_supported_in_BMK(attn_bias_type: Any) -> bool:
    """Return True if a bias of type `attn_bias_type` is accepted with BMK inputs.

    Accepted types are `NoneType` (no bias), `LowerTriangularMask` and
    `torch.Tensor`.
    """
    # NoneType
    if isinstance(None, attn_bias_type):
        return True
    if attn_bias_type in [LowerTriangularMask, torch.Tensor]:
        return True
    return False


@dataclass
class Inputs:
    """
    Stores inputs to the `memory_efficient_attention` operators
    """

    query: torch.Tensor
    key: torch.Tensor
    value: torch.Tensor
    attn_bias: Optional[Union[torch.Tensor, AttentionBias]] = None
    # Dropout probability; `validate_inputs` requires 0.0 <= p <= 1.0
    p: float = 0.0
    # Softmax scale; when None, `scale_float` falls back to K ** -0.5
    scale: Optional[float] = None
    # NOTE(review): the three fields below are not read anywhere in this
    # module - presumably consumed by specific operators; confirm with callers
    use_alibi: bool = False
    alibi_mode: int = 1
    imp_mode: int = 0

    @property
    def device(self) -> torch.device:
        """Device of the query tensor."""
        return self.query.device

    @property
    def scale_float(self) -> float:
        """Effective scale: `self.scale`, or `query.shape[-1] ** -0.5` if unset."""
        return self.query.shape[-1] ** (-0.5) if self.scale is None else self.scale

    def get_qkv_in_bmghk(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Return (query, key, value) as 5-dimensional tensors.

        5-dim inputs are returned as-is, 4-dim inputs get a singleton
        dimension inserted at index 2, and 3-dim inputs get two singleton
        dimensions inserted at indices 2 and 3.

        Raises:
            AssertionError: if the inputs have none of the ranks above.
        """
        if self.query.ndim == 5:
            return self.query, self.key, self.value
        if self.query.ndim == 4:
            return (
                self.query.unsqueeze(2),
                self.key.unsqueeze(2),
                self.value.unsqueeze(2),
            )
        if self.value.ndim == 3:
            return (
                self.query[:, :, None, None],
                self.key[:, :, None, None],
                self.value[:, :, None, None],
            )
        # Explicit raise rather than `assert False`: asserts are stripped
        # under `python -O`, which would silently return None here
        raise AssertionError(
            f"Unsupported number of dimensions for query/value: "
            f"{self.query.ndim}/{self.value.ndim}"
        )

    def normalize_bmhk(self) -> Tuple[int, ...]:
        """Convert legacy 3-dim inputs to 4 dims in place; return the output shape.

        The returned shape is computed from the tensors *before* conversion:
        `query.shape` when `value` is int32 (quantized K/V), otherwise
        `query.shape[:-1] + (value.shape[-1],)`.

        When `query` is 3-dimensional, `query`/`key`/`value` are replaced by
        versions with a singleton dimension at index 2, and a tensor
        `attn_bias` gets a singleton dimension at index 1.

        Raises:
            ValueError: if `query` is not 3/4/5-dimensional, or if `query` is
                3-dimensional and a tensor `attn_bias` is not 3-dimensional.
        """
        if self.query.ndim not in [3, 4, 5]:
            raise ValueError(
                f"Invalid shape for query: {self.query.shape}. "
                "Expected shape [batch, seqlen, head_groups, num_heads_per_group, K]"
                ", [batch, seqlen, num_heads, K], or [batch, seqlen, K]."
            )
        if self.value.dtype == torch.int32:
            # Quantized K/V case, in which the last dims of Q and K are different.
            # NB we currently don't have any implementations for quantized KV with
            # SUPPORTS_DIFFERENT_VALUE_EMBED.
            output_shape = tuple(self.query.shape)
        else:
            output_shape = (self.query.shape[:-1]) + (self.value.shape[-1],)
        # Convert from legacy format
        if self.query.ndim == 3:
            self.query = self.query.unsqueeze(2)
            self.key = self.key.unsqueeze(2)
            self.value = self.value.unsqueeze(2)
            if isinstance(self.attn_bias, torch.Tensor):
                if self.attn_bias.ndim != 3:
                    raise ValueError(
                        f"Expected BMK format for attn_bias, but got {self.attn_bias.shape}"
                    )
                self.attn_bias = self.attn_bias.unsqueeze(1)
        return output_shape

    def validate_inputs(self) -> None:
        """Check that query/key/value/attn_bias/p are mutually consistent.

        Verifies, in order: matching ranks (3, 4 or 5), same device,
        compatible dtypes, bias type allowed for BMK inputs, tensor-bias shape
        for 4-dim inputs, batch size 1 with `BlockDiagonalMask`, dropout
        probability range, and finally that all shapes agree (no broadcast).

        Raises:
            ValueError: on the first check that fails.
        """
        qkv = (self.query, self.key, self.value)
        if self.query.ndim not in (3, 4, 5) or any(
            x.ndim != self.query.ndim for x in qkv
        ):
            raise ValueError(
                f"Query/Key/Value should all have BMGHK, BMHK, or BMK shape.\n"
                f" query.shape: {self.query.shape}\n"
                f" key.shape : {self.key.shape}\n"
                f" value.shape: {self.value.shape}"
            )
        if any(x.device != self.query.device for x in qkv):
            raise ValueError("Query/Key/Value should all be on the same device")
        quantized_dtypes = self.key.dtype == self.value.dtype == torch.int32
        non_quantized_dtypes = all(x.dtype == self.query.dtype for x in qkv)
        if not (quantized_dtypes or non_quantized_dtypes):
            raise ValueError(
                "Query/Key/Value should either all have the same dtype, or "
                "(in the quantized case) Key/Value should have dtype torch.int32\n"
                f" query.dtype: {self.query.dtype}\n"
                f" key.dtype : {self.key.dtype}\n"
                f" value.dtype: {self.value.dtype}"
            )
        # Biases with tensors attached are meant to be in BMHK format
        # This would require to permute biases/gradients which can be expensive,
        # so let's just forbid it - BMK is a legacy format anyway
        if self.query.ndim == 3 and not _is_bias_type_supported_in_BMK(
            type(self.attn_bias)
        ):
            raise ValueError(
                f"Please provide inputs in BMHK format rather "
                f"than BMK when using bias type `{type(self.attn_bias).__name__}`"
            )
        # Extract the raw bias tensor (if any) to validate its shape
        attn_bias_t: Optional[torch.Tensor] = None
        if isinstance(self.attn_bias, torch.Tensor):
            attn_bias_t = self.attn_bias
        if isinstance(self.attn_bias, LowerTriangularMaskWithTensorBias):
            attn_bias_t = self.attn_bias._bias
        if self.query.ndim == 4 and attn_bias_t is not None:
            # (batch, heads, query seqlen, key seqlen)
            expected_shape = (
                self.query.shape[0],
                self.query.shape[2],
                self.query.shape[1],
                self.key.shape[1],
            )
            if attn_bias_t.shape != expected_shape:
                raise ValueError(
                    f"Invalid shape for attention bias: {attn_bias_t.shape} (expected {expected_shape})\n"
                    f" query.shape: {self.query.shape}\n"
                    f" key.shape : {self.key.shape}\n"
                    f" value.shape: {self.value.shape}"
                )
        if isinstance(self.attn_bias, BlockDiagonalMask):
            if any(x.shape[0] != 1 for x in qkv):
                raise ValueError(
                    f"Expected batch_size=1 when using block-diagonal bias\n"
                    f" query.shape: {self.query.shape}\n"
                    f" key.shape : {self.key.shape}\n"
                    f" value.shape: {self.value.shape}"
                )
        if self.p < 0.0 or self.p > 1.0:
            raise ValueError(f"Invalid dropout probability: p={self.p}")
        # Check that shapes match between inputs
        B, Mq = self.query.shape[:2]
        K = self.query.shape[-1]
        # `B` is deliberately re-bound to the key's batch size: the comparisons
        # below then also enforce query.shape[0] == key.shape[0]
        B, Mkv = self.key.shape[:2]
        Kv = self.value.shape[-1]

        valid_shapes = True
        if self.query.ndim == 3:  # BMK
            valid_shapes = (
                self.query.shape == (B, Mq, K)
                and self.key.shape == (B, Mkv, K)
                and self.value.shape == (B, Mkv, Kv)
            )
        H = self.query.shape[-2]
        if self.query.ndim == 4:  # BMHK
            # With an int32 (quantized) value, the key's last dim is compared
            # against the value's last dim rather than the query's
            quantized_kv_cache = self.value.dtype == torch.int32
            key_embed_dim = Kv if quantized_kv_cache else K
            valid_shapes = (
                self.query.shape == (B, Mq, H, K)
                and self.key.shape == (B, Mkv, H, key_embed_dim)
                and self.value.shape == (B, Mkv, H, Kv)
            )
        G = self.query.shape[2]
        if self.query.ndim == 5:  # BMNHK
            valid_shapes = (
                self.query.shape == (B, Mq, G, H, K)
                and self.key.shape == (B, Mkv, G, H, K)
                and self.value.shape == (B, Mkv, G, H, Kv)
            )
        if not valid_shapes:
            raise ValueError(
                f"Incompatible shapes for attention inputs:\n"
                f" query.shape: {self.query.shape}\n"
                f" key.shape : {self.key.shape}\n"
                f" value.shape: {self.value.shape}\n"
                "HINT: We don't support broadcasting, please use `expand` "
                "yourself before calling `memory_efficient_attention` if you need to"
            )


@dataclass
class Context:
    """State carried from a forward operator to the backward operator."""

    lse: torch.Tensor
    out: torch.Tensor
    q_padded: Optional[torch.Tensor] = None
    k_padded: Optional[torch.Tensor] = None
    v_padded: Optional[torch.Tensor] = None
    o_padded: Optional[torch.Tensor] = None
    op_bw: Optional[Type["AttentionBwOpBase"]] = None
    rng_state: Optional[torch.Tensor] = None

    def get_padded_lse(self, pad_to: int, force_pad_inf: bool = False) -> torch.Tensor:
        """Return `lse` with its dimension 2 padded up to a multiple of `pad_to`.

        Padding values are `+inf`. With `force_pad_inf`, every entry of
        dimension 2 beyond `out.shape[1]` is also set to `+inf`: when padding
        is needed, `lse` is first truncated to `out.shape[1]` and then
        re-padded; when no padding is needed, the tail is filled in place
        (this modifies `self.lse`, since no copy is made on that path).
        """
        pad_amount = (pad_to - (self.lse.shape[2] % pad_to)) % pad_to
        lse = self.lse
        if pad_amount > 0:
            if force_pad_inf:
                lse = lse[:, :, : self.out.shape[1]]
                # Recompute: truncation may have changed how much is missing
                pad_amount = (pad_to - (lse.shape[2] % pad_to)) % pad_to
            lse = torch.nn.functional.pad(lse, [0, pad_amount], value=math.inf)
        elif force_pad_inf and self.out.shape[1] != lse.shape[2]:
            lse[:, :, self.out.shape[1] :].fill_(math.inf)
        return lse


@dataclass
class Gradients:
    """Gradients returned by a backward operator."""

    dq: torch.Tensor
    dk: torch.Tensor
    dv: torch.Tensor
    # bias gradient. None if there is no tensor bias or if it doesn't require grad
    db: Optional[torch.Tensor] = None


class AttentionOpBase(BaseOperator):
    """Base class for any attention operator in xFormers

    See:

    - :attr:`xformers.ops.fmha.cutlass.FwOp`
    - :attr:`xformers.ops.fmha.cutlass.BwOp`
    - :attr:`xformers.ops.fmha.flash.FwOp`
    - :attr:`xformers.ops.fmha.flash.BwOp`
    - :attr:`xformers.ops.fmha.triton.FwOp`
    - :attr:`xformers.ops.fmha.triton.BwOp`
    - :attr:`xformers.ops.fmha.small_k.FwOp`
    - :attr:`xformers.ops.fmha.small_k.BwOp`
    """

    OPERATOR: Any
    SUPPORTED_DEVICES: Set[str]
    CUDA_MINIMUM_COMPUTE_CAPABILITY: Tuple[int, int] = (5, 0)
    SUPPORTED_DTYPES: Set[torch.dtype]
    SUPPORTED_MAX_K: float
    SUPPORTED_ATTN_BIAS_TYPES: Set[Any] = {type(None)}
    SUPPORTS_DROPOUT: bool
    SUPPORTS_CUSTOM_SCALE: bool = False
    SUPPORTS_DIFFERENT_VALUE_EMBED: bool = False
    IS_DETERMINISTIC: bool = True
    SUPPORTS_BMGHK: bool = False
    NAME: str
    OPERATOR_CATEGORY = "memory_efficient_attention"

    _TEST_BATCH_SIZES: List[int] = [1, 300]
    _TEST_K: List[int] = [32, 128]

    @classmethod
    def supports(cls, d: Inputs) -> bool:
        """Return True if `not_supported_reasons(d)` is empty."""
        return not cls.not_supported_reasons(d)

    @classmethod
    def shape_not_supported_reasons(
        cls, Mq: int, Mkv: int, K: int, Kv: int
    ) -> List[str]:
        """Return the reasons why these embedding sizes are not supported.

        `K` / `Kv` are the last dimensions of query / value. `Mq` and `Mkv`
        are not used by this base implementation.
        """
        reasons = []
        if not cls.SUPPORTS_DIFFERENT_VALUE_EMBED and K != Kv:
            reasons.append("query.shape[-1] != value.shape[-1]")
        if max(K, Kv) > cls.SUPPORTED_MAX_K:
            reasons.append(
                f"max(query.shape[-1], value.shape[-1]) > {cls.SUPPORTED_MAX_K}"
            )
        return reasons

    @classmethod
    def not_supported_reasons(cls, d: Inputs) -> List[str]:
        """
        Returns a list of reasons why this is not supported.
        The kernel can run these inputs only if the returned list is empty
        """
        query_shape = d.query.shape
        reasons = cls.shape_not_supported_reasons(
            Mq=query_shape[1],
            Mkv=d.key.shape[1],
            K=query_shape[-1],
            # For int32 (quantized) values the last dim of `value` is not the
            # embedding size, so fall back to the query's; otherwise use the
            # real value embedding so K != Kv is actually detected
            Kv=query_shape[-1] if d.value.dtype == torch.int32 else d.value.shape[-1],
        )
        device_type = d.query.device.type
        dtype = d.query.dtype
        if device_type not in cls.SUPPORTED_DEVICES:
            reasons.append(f"device={device_type} (supported: {cls.SUPPORTED_DEVICES})")
        if device_type == "cuda" and not _built_with_cuda:
            reasons.append("xFormers wasn't build with CUDA support")
        if device_type == "cuda":
            device_capability = torch.cuda.get_device_capability(d.device)
            if device_capability < cls.CUDA_MINIMUM_COMPUTE_CAPABILITY:
                reasons.append(
                    f"requires device with capability > {cls.CUDA_MINIMUM_COMPUTE_CAPABILITY} "
                    f"but your GPU has capability {device_capability} (too old)"
                )
        if dtype not in cls.SUPPORTED_DTYPES:
            reasons.append(f"dtype={dtype} (supported: {cls.SUPPORTED_DTYPES})")
        if type(d.attn_bias) not in cls.SUPPORTED_ATTN_BIAS_TYPES:
            reasons.append(f"attn_bias type is {type(d.attn_bias)}")
        if (d.p != 0.0) and not cls.SUPPORTS_DROPOUT:
            reasons.append("dropout > 0.0")
        if d.scale is not None and not cls.SUPPORTS_CUSTOM_SCALE:
            reasons.append("has custom scale")
        # bfloat16 is only supported on A100+
        # ... although the kernels can still run and give the
        # correct result
        # NOTE(review): only the device *type* is checked here, not the
        # compute capability - bf16 is reported as unsupported only off-CUDA
        if dtype is torch.bfloat16 and (
            not device_type.startswith("cuda")
        ):
            reasons.append("bf16 is only supported on A100+ GPUs")
        if not cls.is_available():
            reasons.append(
                "operator wasn't built - see `python -m xformers.info` for more info"
            )
        if not cls.IS_DETERMINISTIC and torch.are_deterministic_algorithms_enabled():
            reasons.append(
                "operator is non-deterministic, but `torch.use_deterministic_algorithms` is set"
            )
        if not cls.SUPPORTS_BMGHK and d.query.ndim == 5:
            reasons.append("operator does not support BMGHK format")
        return reasons


class AttentionFwOpBase(AttentionOpBase):
    """Base class for forward attention operators."""

    ERROR_ATOL: Mapping[torch.dtype, float] = {
        torch.float: 3e-4,
        torch.half: 4e-3,
        torch.bfloat16: 2e-2,
    }
    ERROR_RTOL: Mapping[torch.dtype, float] = {
        torch.float: 2e-5,
        torch.half: 4e-4,
        torch.bfloat16: 5e-3,
    }

    @classmethod
    def apply(
        cls, inp: Inputs, needs_gradient: bool
    ) -> Tuple[torch.Tensor, Optional[Context]]:
        """Run the forward pass; must be implemented by subclasses."""
        raise NotImplementedError()

    @classmethod
    def attn_operator_flop(
        cls,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        causal: bool = False,
        seqstart_k: Optional[torch.Tensor] = None,
        seqstart_q: Optional[torch.Tensor] = None,
    ) -> int:
        """
        Computes total flops for the attention
        Assumes inputs in format BMHK

        `seqstart_q` / `seqstart_k`, when given, hold the boundaries of
        consecutive sub-sequences; flops are summed per sub-sequence pair.
        The total is halved when `causal` is set.
        """
        assert query.ndim == 4

        if seqstart_q is not None:
            seqstart_q_py = seqstart_q.tolist()
        else:
            seqstart_q_py = [0, query.shape[1]]
        if seqstart_k is not None:
            seqstart_k_py = seqstart_k.tolist()
        else:
            seqstart_k_py = [0, key.shape[1]]

        total_flop = 0
        for q_start, q_end, k_start, k_end in zip(
            seqstart_q_py, seqstart_q_py[1:], seqstart_k_py, seqstart_k_py[1:]
        ):
            num_q = q_end - q_start
            num_kv = k_end - k_start
            # (M,K) @ (K,N) GEMM needs M*N*K*2 flop
            # Q @ K.transpose
            total_flop += num_q * num_kv * query.shape[-1] * 2
            # (ignore softmax)
            # attn @ V: the output width is the *value* embedding size
            total_flop += num_q * value.shape[-1] * num_kv * 2
        # Multiply by num_heads and batches
        total_flop = total_flop * value.shape[2] * value.shape[0]
        if causal:
            total_flop //= 2
        return total_flop


class AttentionBwOpBase(AttentionOpBase):
    """Base class for backward attention operators."""

    ERROR_ATOL: Mapping[torch.dtype, float] = {
        torch.float: 5e-4,
        torch.half: 9e-2,
        torch.bfloat16: 0.7,
    }
    ERROR_RTOL: Mapping[torch.dtype, float] = {
        torch.float: 1e-4,
        torch.half: 2e-2,
        torch.bfloat16: 0.1,
    }
    SUPPORTS_ATTN_BIAS_GRAD = False

    @classmethod
    def not_supported_reasons(cls, d: Inputs) -> List[str]:
        """Base-class reasons, plus a reason if a tensor bias needs a gradient
        and this operator does not set `SUPPORTS_ATTN_BIAS_GRAD`."""
        reasons = super(AttentionBwOpBase, cls).not_supported_reasons(d)
        if (
            isinstance(d.attn_bias, torch.Tensor)
            and d.attn_bias.requires_grad
            and not cls.SUPPORTS_ATTN_BIAS_GRAD
        ):
            reasons.append(
                "Computing the bias gradient is not supported (attn_bias.requires_grad = True)"
            )
        return reasons

    @classmethod
    def apply(cls, ctx: Context, inp: Inputs, grad: torch.Tensor) -> Gradients:
        """Run the backward pass; must be implemented by subclasses."""
        raise NotImplementedError()

    @classmethod
    def attn_operator_flop(
        cls,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        causal: bool = False,
        seqstart_k: Optional[torch.Tensor] = None,
        seqstart_q: Optional[torch.Tensor] = None,
    ) -> int:
        """
        Computes total flops for the attention
        Assumes inputs in format BMHK

        `seqstart_q` / `seqstart_k`, when given, hold the boundaries of
        consecutive sub-sequences; flops are summed per sub-sequence pair.
        The total is halved when `causal` is set.
        """
        assert query.ndim == 4

        if seqstart_q is not None:
            seqstart_q_py = seqstart_q.tolist()
        else:
            seqstart_q_py = [0, query.shape[1]]
        if seqstart_k is not None:
            seqstart_k_py = seqstart_k.tolist()
        else:
            seqstart_k_py = [0, key.shape[1]]

        total_flop = 0
        for q_start, q_end, k_start, k_end in zip(
            seqstart_q_py, seqstart_q_py[1:], seqstart_k_py, seqstart_k_py[1:]
        ):
            num_q = q_end - q_start
            num_kv = k_end - k_start
            Kqk = query.shape[-1]
            Kv = value.shape[-1]
            # (M,K) @ (K,N) GEMM needs M*N*K*2 flop
            # att = Q @ K.transpose
            total_flop += num_q * num_kv * Kqk * 2
            # att @ dO
            total_flop += num_kv * num_q * Kv * 2
            # dov = dO @ V
            total_flop += num_q * Kv * num_kv * 2
            # dov @ K
            total_flop += num_q * Kqk * num_kv * 2
            # dov @ Q
            total_flop += num_q * Kqk * num_kv * 2
        # Multiply by num_heads and batches
        total_flop = total_flop * value.shape[2] * value.shape[0]
        if causal:
            total_flop //= 2
        return total_flop


# A (forward operator, backward operator) pair; either entry may be None
AttentionOp = Tuple[
    Optional[Type[AttentionFwOpBase]], Optional[Type[AttentionBwOpBase]]
]


@dataclass
class AttentionOpDispatch:
    """Dispatcher to automatically select the best operator to run memory-efficient attention.

    :Deprecated:

        This class is deprecated and will be removed in a later version
    """

    op: AttentionOp

    @classmethod
    def from_arguments(
        cls,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        attn_bias: Optional[Union[torch.Tensor, AttentionBias]] = None,
        p: float = 0.0,
        scale: Optional[float] = None,
    ) -> "AttentionOpDispatch":
        """Here for backward compatibility"""
        # Imported locally rather than at module level
        # (presumably to avoid a circular import - TODO confirm)
        from .dispatch import _dispatch_bw, _dispatch_fw

        inp = Inputs(
            query=query,
            key=key,
            value=value,
            attn_bias=attn_bias,
            p=p,
            scale=scale,
        )
        return AttentionOpDispatch(op=(_dispatch_fw(inp, True), _dispatch_bw(inp)))


def bmk2bmhk(tensor, num_heads: int) -> torch.Tensor:
    """Reshape a 3-dim tensor to 4 dims by splitting dim 0 into (-1, num_heads).

    The result has shape `(tensor.shape[0] // num_heads, tensor.shape[1],
    num_heads, tensor.shape[2])`. 4-dim tensors are returned unchanged.
    """
    if tensor.ndim == 4:
        return tensor
    return tensor.reshape([-1, num_heads, tensor.shape[1], tensor.shape[2]]).permute(
        (0, 2, 1, 3)
    )


def check_lastdim_alignment_stride1(
    reasons: List[str], name: str, x: torch.Tensor, alignment: int
) -> None:
    """Append to `reasons` any alignment/stride problem found on tensor `x`.

    Checks that `x.shape[-1]` (and, if that passes, `x.stride(-2)`) is a
    multiple of `alignment`, and that the last dimension has stride <= 1.
    `name` is only used to build the messages.
    """
    if x.shape[-1] % alignment != 0:
        reasons.append(f"{name}.shape[-1] % {alignment} != 0")
    elif x.stride(-2) % alignment != 0:
        reasons.append(
            f"{name}.stride(-2) % {alignment} != 0 ({name}.stride() = {x.stride()})"
        )
    # We can have stride=0 sometimes if dimension=1
    if x.stride(-1) > 1:
        reasons.append(
            f"{name}.stride(-1) > 1 ({name}.stride() = {x.stride()}) - you should call `.contiguous()` on the input"
        )