# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.


from dataclasses import replace
from itertools import zip_longest
from typing import Any, List, Optional, Set, Tuple, Union

import os

import torch

from ..common import _get_storage_base, get_operator, register_operator
from .attn_bias import (
    AttentionBias,
    BlockDiagonalCausalFromBottomRightMask,
    BlockDiagonalCausalLocalAttentionFromBottomRightMask,
    BlockDiagonalCausalLocalAttentionMask,
    BlockDiagonalCausalMask,
    BlockDiagonalMask,
    LowerTriangularMask,
)
from .common import (
    AttentionBwOpBase,
    AttentionFwOpBase,
    Context,
    Gradients,
    Inputs,
    check_lastdim_alignment_stride1,
)

# Backend switch, read once at import time: any value of
# ENABLE_FLASH_ATTENTION_WITH_IXDNN other than "0" (default "1") routes the
# kernels below to the `*_ixdnn` entry points. Every code path in this module
# consults this flag instead of re-reading the environment, so the kernel
# choice, the cu_seqlens layout and the input reshaping always agree.
enable_ixdnn = os.getenv("ENABLE_FLASH_ATTENTION_WITH_IXDNN", "1") != "0"

FLASH_VERSION = "0.0.0"
try:
    try:
        from ... import _C_flashattention  # type: ignore[attr-defined]
        from ..._cpp_lib import _build_metadata

        if _build_metadata is not None:
            FLASH_VERSION = _build_metadata.flash_version
    except ImportError:
        # Fall back to the standalone `flash_attn` package.
        import flash_attn
        from flash_attn.flash_attn_interface import flash_attn_cuda as _C_flashattention

        FLASH_VERSION = flash_attn.__version__
        flash_ver_parsed = tuple(int(s) for s in FLASH_VERSION.split(".")[:2])
        if flash_ver_parsed < (2, 3):
            raise ImportError("Requires 2.3 for sliding window support")

    # create library so that flash-attn goes through the PyTorch Dispatcher
    _flash_lib = torch.library.Library("xformers_flash", "DEF")

    _flash_lib.define(
        "flash_fwd(Tensor query, Tensor key, Tensor value, "
        "Tensor? cu_seqlens_q, Tensor? cu_seqlens_k, "
        "int max_seqlen_q, int max_seqlen_k, "
        "float p, float softmax_scale, "
        "bool is_causal, int window_size, bool return_softmax, bool use_alibi, int alibi_mode, int imp_mode) -> (Tensor, Tensor, Tensor, Tensor, Tensor, Tensor)"
    )

    # cu_seqlens_q/k are optional (`Tensor?`): the dense, non-varlen path
    # passes None and `_flash_bwd` handles that case explicitly, which a
    # non-optional `Tensor` argument would reject at dispatch time.
    _flash_lib.define(
        "flash_bwd(Tensor dout, Tensor query, Tensor key, Tensor value, "
        "Tensor out, Tensor softmax_lse_, Tensor dq, Tensor dk, Tensor dv, "
        "Tensor? cu_seqlens_q, Tensor? cu_seqlens_k, "
        "int max_seqlen_q, int max_seqlen_k, "
        "float p, float softmax_scale, bool is_causal, int window_size, bool use_alibi, int alibi_mode, int imp_mode) -> (Tensor, Tensor, Tensor)"
    )

    def _flash_fwd(
        query,
        key,
        value,
        cu_seq_lens_q,
        cu_seq_lens_k,
        max_seq_len_q,
        max_seq_len_k,
        p,
        softmax_scale,
        is_causal,
        window_size,
        return_softmax,
        use_alibi,
        alibi_mode,
        imp_mode,
    ):
        """CUDA implementation of ``xformers_flash::flash_fwd``.

        Dispatches to ``fwd_ixdnn`` when ``enable_ixdnn`` is set, otherwise to
        ``fwd`` (no cu_seqlens) or ``varlen_fwd`` (cu_seqlens given).

        Returns ``(out, softmax_lse, q_padded, k_padded, v_padded, out_padded)``.

        ``window_size`` is part of the schema but is not forwarded to any
        kernel; ``FwOp``/``BwOp`` reject masks with a non-zero window.
        """
        if enable_ixdnn:
            # NOTE(review): softmax_scale is not passed to fwd_ixdnn (and
            # imp_mode only goes to this kernel) -- confirm the kernel honours
            # custom scales, since FwOp advertises SUPPORTS_CUSTOM_SCALE.
            (
                out,
                q_padded,
                k_padded,
                v_padded,
                out_padded,
                softmax_lse,
                _,
            ) = _C_flashattention.fwd_ixdnn(
                query,
                key,
                value,
                None,  # out
                cu_seq_lens_q,
                cu_seq_lens_k,
                p,
                is_causal,
                return_softmax,
                use_alibi,
                alibi_mode,
                imp_mode,
                -1,
                -1,
                None,  # rng
            )
        elif cu_seq_lens_q is None:
            # Dense (non-varlen) inputs.
            assert cu_seq_lens_k is None
            (
                out,
                q_padded,
                k_padded,
                v_padded,
                out_padded,
                softmax_lse,
                _,
            ) = _C_flashattention.fwd(
                query,
                key,
                value,
                None,  # out
                p,
                softmax_scale,
                is_causal,
                return_softmax,
                use_alibi,
                alibi_mode,
                None,  # rng
            )
        else:
            # Varlen inputs, as flattened by `_convert_input_format` to
            # [batch * seqlen, heads, head_dim].
            out = query.new_empty(query.shape[0], query.shape[1], value.shape[2])
            (
                out,
                q_padded,
                k_padded,
                v_padded,
                out_padded,
                softmax_lse,
                _,
            ) = _C_flashattention.varlen_fwd(
                query,
                key,
                value,
                out,
                cu_seq_lens_q,
                cu_seq_lens_k,
                max_seq_len_q,
                max_seq_len_k,
                p,
                softmax_scale,
                False,  # zero_tensors
                is_causal,
                return_softmax,
                use_alibi,
                alibi_mode,
                None,  # rng
            )
        return out, softmax_lse, q_padded, k_padded, v_padded, out_padded

    def _flash_bwd(
        grad,
        query,
        key,
        value,
        out,
        lse,
        dq,
        dk,
        dv,
        cu_seq_lens_q,
        cu_seq_lens_k,
        max_seq_len_q,
        max_seq_len_k,
        p,
        softmax_scale,
        is_causal,
        window_size,
        use_alibi,
        alibi_mode,
        imp_mode,
    ):
        """CUDA implementation of ``xformers_flash::flash_bwd``.

        Writes the gradients into the caller-provided ``dq``, ``dk`` and ``dv``
        buffers and returns them. Dispatches to ``bwd_ixdnn`` when
        ``enable_ixdnn`` is set, otherwise to ``bwd`` (no cu_seqlens) or
        ``varlen_bwd``. ``window_size`` is not forwarded to any kernel.
        """
        if enable_ixdnn:
            _C_flashattention.bwd_ixdnn(
                grad,
                query,
                key,
                value,
                out,
                lse,
                dq,
                dk,
                dv,
                cu_seq_lens_q,
                cu_seq_lens_k,
                p,
                is_causal,
                use_alibi,
                alibi_mode,
                imp_mode,
                -1,
                -1,
                None,
            )
        elif cu_seq_lens_k is None:
            assert cu_seq_lens_q is None
            _C_flashattention.bwd(
                grad,
                query,
                key,
                value,
                out,
                lse,
                dq,
                dk,
                dv,
                p,
                softmax_scale,
                is_causal,
                use_alibi,
                alibi_mode,
                None,
            )
        else:
            _C_flashattention.varlen_bwd(
                grad,
                query,
                key,
                value,
                out,
                lse,
                dq,
                dk,
                dv,
                cu_seq_lens_q,
                cu_seq_lens_k,
                max_seq_len_q,
                max_seq_len_k,
                p,
                softmax_scale,
                False,  # zero_tensors
                is_causal,
                use_alibi,
                alibi_mode,
                None,
            )
        return dq, dk, dv

    _flash_lib.impl("flash_fwd", _flash_fwd, "CUDA")
    _flash_lib.impl("flash_bwd", _flash_bwd, "CUDA")
except ImportError:
    # flash-attention is unavailable or too old: the `xformers_flash` ops are
    # left unregistered.
    pass


def _convert_input_format(
    inp: Inputs,
) -> Tuple[Inputs, Optional[torch.Tensor], int, Optional[torch.Tensor], int]:
    """Bring ``inp`` into the layout expected by the flash kernels.

    Returns ``(new_inp, cu_seqlens_q, max_seqlen_q, cu_seqlens_k, max_seqlen_k)``.

    * ixdnn backend: q/k/v keep their 4-D layout and ``cu_seqlens_q/k`` are
      int32 tensors with one entry per batch element, each holding the
      (uniform) query / key sequence length.
    * Otherwise, with a ``BlockDiagonalMask`` bias: the mask's seqstart tensors
      are moved to the query device and used as ``cu_seqlens``, and q/k/v are
      flattened to ``[batch * seqlen, heads, head_dim]``.
    * Otherwise: ``cu_seqlens_q/k`` are None and the 4-D layout is kept.

    5-D (BMGHK) inputs have their G/H dimensions folded into one head
    dimension (or the first head is taken when heads are broadcast with
    stride 0), and key/value broadcast over heads are reduced to one head.
    """
    assert inp.query.ndim in [4, 5]
    query, key, value = inp.query, inp.key, inp.value
    batch = query.shape[0]
    seqlen_q = query.shape[1]
    seqlen_kv = key.shape[1]
    head_dim_q = query.shape[-1]
    head_dim_v = value.shape[-1]

    attn_bias = inp.attn_bias
    if enable_ixdnn:
        # One length per batch element (for batch == 1 this is exactly
        # `[seqlen]`). Previously batch > 1 produced a length-`seqlen` vector,
        # which is too short whenever seqlen < batch.
        # NOTE(review): the block structure of a BlockDiagonalMask is not
        # forwarded on this path -- confirm the ixdnn kernels are never fed
        # multi-block varlen inputs.
        cu_seqlen_q = torch.full(
            (batch,), seqlen_q, dtype=torch.int32, device=query.device
        )
        cu_seqlen_k = torch.full(
            (batch,), seqlen_kv, dtype=torch.int32, device=key.device
        )
        max_seqlen_q = seqlen_q
        max_seqlen_k = seqlen_kv
    elif isinstance(attn_bias, BlockDiagonalMask):
        attn_bias.k_seqinfo.seqstart = attn_bias.k_seqinfo.seqstart.to(
            inp.query.device, non_blocking=True
        )
        attn_bias.q_seqinfo.seqstart = attn_bias.q_seqinfo.seqstart.to(
            inp.query.device, non_blocking=True
        )

        cu_seqlen_k = attn_bias.k_seqinfo.seqstart
        cu_seqlen_q = attn_bias.q_seqinfo.seqstart
        max_seqlen_q = attn_bias.q_seqinfo.max_seqlen
        max_seqlen_k = attn_bias.k_seqinfo.max_seqlen
    else:
        cu_seqlen_k = None
        cu_seqlen_q = None
        max_seqlen_q = seqlen_q
        max_seqlen_k = seqlen_kv

    if query.ndim == 5:  # QGA
        # Fold the group/head_in_group dimensions together
        def fold(x):
            # Either the head is replicated
            if x.stride(3) == 0:
                return x[:, :, :, 0]
            # Or we reshape
            return x.reshape(
                [
                    x.shape[0],
                    x.shape[1],
                    -1,
                    x.shape[4],
                ]
            )

        query = fold(query)
        key = fold(key)
        value = fold(value)
    # Optimize for MHA: key/value broadcast over heads need only one head.
    if key.ndim == 4 and key.stride(2) == 0 and value.stride(2) == 0:
        key = key[:, :, :1]
        value = value[:, :, :1]
    # Tensors are now [batch, seqlen, num_heads, head_dim]; the varlen kernels
    # want [batch * seqlen, num_heads, head_dim].
    if cu_seqlen_k is not None and not enable_ixdnn:
        query = query.reshape([batch * seqlen_q, -1, head_dim_q])
        key = key.reshape([batch * seqlen_kv, -1, head_dim_q])
        value = value.reshape([batch * seqlen_kv, -1, head_dim_v])
    new_inp = replace(
        inp,
        query=query,
        key=key,
        value=value,
    )
    return new_inp, cu_seqlen_q, max_seqlen_q, cu_seqlen_k, max_seqlen_k


def _is_causal(attn_bias: Optional[Union[torch.Tensor, AttentionBias]]) -> bool:
    """Return True if ``attn_bias`` is one of the causal mask types."""
    return isinstance(
        attn_bias,
        (
            LowerTriangularMask,
            BlockDiagonalCausalMask,
            BlockDiagonalCausalLocalAttentionMask,
            BlockDiagonalCausalFromBottomRightMask,
            BlockDiagonalCausalLocalAttentionFromBottomRightMask,
        ),
    )


def _window_size(attn_bias: Optional[Union[torch.Tensor, AttentionBias]]) -> int:
    """Return the local-attention window of ``attn_bias`` (0 if there is none)."""
    if isinstance(
        attn_bias,
        (
            BlockDiagonalCausalLocalAttentionMask,
            BlockDiagonalCausalLocalAttentionFromBottomRightMask,
        ),
    ):
        return attn_bias._window_size or 0
    return 0


def _check_needs_no_topleft(d: Inputs, reasons: List[str]) -> None:
    """Reject causal masks whose top-left alignment flash cannot express."""
    # Flash does not support TopLeft, so only allow causal masks with TopLeft
    # if each batch element has equal number of queries and keys.
    if isinstance(d.attn_bias, BlockDiagonalCausalMask):
        for k_start, q_start in zip_longest(
            d.attn_bias.k_seqinfo.seqstart_py, d.attn_bias.q_seqinfo.seqstart_py
        ):
            if k_start != q_start:
                reasons.append(
                    "Only support BlockDiagonalCausalMask if equal"
                    " numbers of keys and queries"
                )
                break
    elif isinstance(d.attn_bias, LowerTriangularMask):
        if d.query.shape[1] != d.key.shape[1]:
            reasons.append(
                "Only support LowerTriangularMask if equal number of "
                "keys and queries"
            )


def _check_no_sliding_window(d: Inputs, reasons: List[str]) -> None:
    """Reject local-attention masks with a non-zero window.

    Neither `_flash_fwd` nor `_flash_bwd` forwards ``window_size`` to a kernel,
    so such a mask would silently be computed as plain causal attention.
    Refusing it lets the dispatcher pick an operator that honours the window.
    """
    if _window_size(d.attn_bias) > 0:
        reasons.append(
            "Sliding-window (local) attention is not supported: window_size "
            "is not forwarded to the flash kernels"
        )


def _check_strides_for_bmghk(x: torch.Tensor, name: str, reasons: List[str]) -> None:
    """
    We want to be able to collapse the G/H dimensions together
    """
    if x.ndim == 5:
        stride_g, stride_h = x.stride(2), x.stride(3)
        if x.shape[2] == 1:
            return
        if x.shape[3] == 1 or stride_h == 0:
            return
        if stride_g != stride_h * x.shape[-2]:
            reasons.append(
                f"GQA is only supported when the G/H dimensions are contiguous\n"
                f"    {name}.stride: {x.stride()}\n"
                f"    {name}.shape : {list(x.shape)}"
            )


@register_operator
class FwOp(AttentionFwOpBase):
    """Operator that computes memory-efficient attention using \
        `Flash-Attention <https://github.com/HazyResearch/flash-attention>`_ \
        implementation.
    """

    OPERATOR = get_operator("xformers_flash", "flash_fwd")
    SUPPORTED_DEVICES: Set[str] = {"cuda"}
    SUPPORTED_DTYPES: Set[torch.dtype] = {torch.half, torch.bfloat16}
    SUPPORTED_MAX_K = 256
    SUPPORTED_ATTN_BIAS_TYPES: Set[Any] = {
        type(None),
        LowerTriangularMask,
        BlockDiagonalMask,
        BlockDiagonalCausalMask,
        BlockDiagonalCausalLocalAttentionMask,
        BlockDiagonalCausalLocalAttentionFromBottomRightMask,
        BlockDiagonalCausalFromBottomRightMask,
    }
    SUPPORTS_DROPOUT = True
    SUPPORTS_CUSTOM_SCALE = True
    SUPPORTS_DIFFERENT_VALUE_EMBED = False
    SUPPORTS_BMGHK = True
    NAME = f"flshattF@{FLASH_VERSION}"
    VERSION = FLASH_VERSION

    @classmethod
    def not_supported_reasons(cls, d: Inputs) -> List[str]:
        """Return the reasons (possibly none) why ``d`` cannot run on this op."""
        reasons = super(FwOp, cls).not_supported_reasons(d)
        check_lastdim_alignment_stride1(reasons, "query", d.query, 8)
        _check_needs_no_topleft(d, reasons)
        _check_no_sliding_window(d, reasons)
        _check_strides_for_bmghk(d.query, "query", reasons)
        _check_strides_for_bmghk(d.key, "key", reasons)
        _check_strides_for_bmghk(d.value, "value", reasons)
        return reasons

    @classmethod
    def apply(
        cls, inp: Inputs, needs_gradient: bool
    ) -> Tuple[torch.Tensor, Optional[Context]]:
        """Run the forward kernel; return the output and the backward context."""
        return_softmax = False
        # User-facing output shape: the query shape with the value head dim,
        # captured before `_convert_input_format` reshapes the inputs.
        out_shape = [
            *inp.query.shape[:-1],
            inp.value.shape[-1],
        ]
        (
            inp,
            cu_seqlens_q,
            max_seqlen_q,
            cu_seqlens_k,
            max_seqlen_k,
        ) = _convert_input_format(inp)
        out, softmax_lse, q_padded, k_padded, v_padded, o_padded = cls.OPERATOR(
            inp.query,
            inp.key,
            inp.value,
            cu_seqlens_q,
            cu_seqlens_k,
            max_seqlen_q,
            max_seqlen_k,
            inp.p,
            inp.scale_float,
            _is_causal(inp.attn_bias),
            _window_size(inp.attn_bias),
            return_softmax,
            inp.use_alibi,
            inp.alibi_mode,
            inp.imp_mode,
        )
        out = out.reshape(out_shape)
        ctx = Context(
            out=out,
            lse=softmax_lse,
            q_padded=q_padded,
            k_padded=k_padded,
            v_padded=v_padded,
            o_padded=o_padded,
        )
        if inp.p != 0.0:
            # With dropout, pin the backward to this operator (presumably so
            # the same kernel family regenerates the dropout mask).
            ctx.op_bw = BwOp
        return (out, ctx)

    @classmethod  # type: ignore
    def operator_flop(
        cls,
        query,
        key,
        value,
        cu_seq_lens_q,
        cu_seq_lens_k,
        max_seq_len_q,
        max_seq_len_k,
        p,
        softmax_scale,
        causal,
        window_size=0,
        return_softmax=False,
        use_alibi=False,
        alibi_mode=0,
        imp_mode=0,
    ) -> int:
        """FLOP estimate for one ``flash_fwd`` call.

        Accepts the full ``flash_fwd`` argument list (the trailing parameters
        are unused but must be accepted so the op's arguments can be passed
        through unchanged).
        NOTE(review): under ixdnn, cu_seq_lens hold per-batch lengths rather
        than sequence starts -- confirm attn_operator_flop handles that.
        """
        return cls.attn_operator_flop(
            query.unsqueeze(0),
            key.unsqueeze(0),
            value.unsqueeze(0),
            causal=causal,
            seqstart_k=cu_seq_lens_k,
            seqstart_q=cu_seq_lens_q,
        )


@register_operator
class BwOp(AttentionBwOpBase):
    __doc__ = FwOp.__doc__

    OPERATOR = get_operator("xformers_flash", "flash_bwd")
    SUPPORTED_DEVICES = FwOp.SUPPORTED_DEVICES
    SUPPORTED_DTYPES = FwOp.SUPPORTED_DTYPES
    SUPPORTED_MAX_K = FwOp.SUPPORTED_MAX_K
    SUPPORTED_ATTN_BIAS_TYPES = FwOp.SUPPORTED_ATTN_BIAS_TYPES
    SUPPORTS_DROPOUT = FwOp.SUPPORTS_DROPOUT
    SUPPORTS_CUSTOM_SCALE = FwOp.SUPPORTS_CUSTOM_SCALE
    SUPPORTS_DIFFERENT_VALUE_EMBED = FwOp.SUPPORTS_DIFFERENT_VALUE_EMBED
    IS_DETERMINISTIC = False
    SUPPORTS_BMGHK = False  # NOTE: Don't forget to update fmha doc when changing this!
    NAME = f"flshattB@{FLASH_VERSION}"
    VERSION = FLASH_VERSION

    MAX_HEADDIM_SM8x = 192

    @classmethod
    def shape_not_supported_reasons(
        cls, Mq: int, Mkv: int, K: int, Kv: int
    ) -> List[str]:
        reasons = super().shape_not_supported_reasons(Mq, Mkv, K, Kv)

        # In fbcode in mode/dev-nosan, we get nans from flash v2.1 if there
        # is a strange embedding dimension; a check restricting K to
        # {8, 16, 32, 64, 128, 256} is intentionally disabled here.
        return reasons

    @classmethod
    def not_supported_reasons(cls, d: Inputs) -> List[str]:
        """Return the reasons (possibly none) why ``d`` cannot run on this op."""
        reasons = super(BwOp, cls).not_supported_reasons(d)
        check_lastdim_alignment_stride1(reasons, "query", d.query, 8)
        _check_needs_no_topleft(d, reasons)
        _check_no_sliding_window(d, reasons)
        if d.device.type == "cuda":
            # Due to limited shared-memory, some GPUs are limited in head dimension
            device_capability = torch.cuda.get_device_capability(d.device)
            is_sm80_or_sm90 = device_capability in [(8, 0), (9, 0)]
            if (
                max(d.key.shape[-1], d.query.shape[-1]) > cls.MAX_HEADDIM_SM8x
                and not is_sm80_or_sm90
            ):
                reasons.append(
                    "requires a GPU with compute capability 8.0 "
                    f"(A100) or 9.0 (H100) for 'query.shape[-1] > {cls.MAX_HEADDIM_SM8x}'"
                )
        return reasons

    @classmethod
    def apply(cls, ctx: Context, inp: Inputs, grad: torch.Tensor) -> Gradients:
        """Run the backward kernel and return dq/dk/dv in the input shapes."""
        dq_shape, dk_shape, dv_shape = inp.query.shape, inp.key.shape, inp.value.shape
        (
            inp,
            cu_seqlens_q,
            max_seqlen_q,
            cu_seqlens_k,
            max_seqlen_k,
        ) = _convert_input_format(inp)
        # The kernels need a contiguous LSE. The former
        # `assert ctx.lse.is_contiguous` (method not called) never checked
        # anything; normalise instead (no copy when already contiguous).
        ctx_lse = ctx.lse.contiguous()
        if not enable_ixdnn:
            assert ctx_lse.shape[2] >= max_seqlen_q
            if max_seqlen_q != ctx_lse.shape[2]:
                ctx_lse = ctx_lse[:, :, :max_seqlen_q].contiguous()
        kernel_out_shape = [
            *inp.query.shape[:-1],
            inp.value.shape[-1],
        ]

        # Create dq,dk,dv
        # If Q/K/V come from a single QKV tensor, let's put the gradient in the
        # right strides, so we can avoid a `cat`
        if (
            ctx.q_padded.shape[0] == ctx.k_padded.shape[0]
            and ctx.q_padded.shape[-1] == ctx.v_padded.shape[-1]
            and _get_storage_base(ctx.q_padded) == _get_storage_base(ctx.k_padded)
            and _get_storage_base(ctx.q_padded) == _get_storage_base(ctx.v_padded)
        ):
            # Create one big contiguous chunk
            # This is because q, k and v usually come from a single
            # output of a linear layer that is chunked.
            # Creating the gradients with the right layout saves us
            # a `torch.cat` call in the backward pass
            chunk = torch.empty(
                (
                    *ctx.q_padded.shape[0:-2],
                    3,
                    ctx.q_padded.shape[-2],
                    ctx.q_padded.shape[-1],
                ),
                dtype=ctx.q_padded.dtype,
                device=inp.device,
            )
            grads = Gradients(
                dq=chunk.select(-3, 0),
                dk=chunk.select(-3, 1),
                dv=chunk.select(-3, 2),
            )
        else:
            grads = Gradients(
                dq=torch.empty_like(ctx.q_padded),
                dk=torch.empty_like(ctx.k_padded),
                dv=torch.empty_like(ctx.v_padded),
            )

        assert grad.dtype in cls.SUPPORTED_DTYPES
        cls.OPERATOR(
            grad.reshape(kernel_out_shape).contiguous(),
            ctx.q_padded,
            ctx.k_padded,
            ctx.v_padded,
            ctx.o_padded,
            ctx_lse,
            grads.dq,
            grads.dk,
            grads.dv,
            cu_seqlens_q,
            cu_seqlens_k,
            max_seqlen_q,
            max_seqlen_k,
            inp.p,
            inp.scale_float,
            _is_causal(inp.attn_bias),
            _window_size(inp.attn_bias),
            inp.use_alibi,
            inp.alibi_mode,
            inp.imp_mode,
        )
        # We could have padded the head dimension
        grads.dq = grads.dq[..., : dq_shape[-1]].reshape(dq_shape)
        grads.dk = grads.dk[..., : dk_shape[-1]].reshape(dk_shape)
        grads.dv = grads.dv[..., : dv_shape[-1]].reshape(dv_shape)
        return grads

    @classmethod  # type: ignore
    def operator_flop(
        cls,
        grad,
        query,
        key,
        value,
        out,
        lse,
        dq,
        dk,
        dv,
        cu_seq_lens_q,
        cu_seq_lens_k,
        max_seq_len_q,
        max_seq_len_k,
        p,
        softmax_scale,
        causal,
        window_size=0,
        use_alibi=False,
        alibi_mode=0,
        imp_mode=0,
    ) -> int:
        """FLOP estimate for one ``flash_bwd`` call.

        Accepts the full ``flash_bwd`` argument list; the trailing parameters
        are unused.
        """
        return cls.attn_operator_flop(
            query.unsqueeze(0),
            key.unsqueeze(0),
            value.unsqueeze(0),
            causal=causal,
            seqstart_k=cu_seq_lens_k,
            seqstart_q=cu_seq_lens_q,
        )