Files
enginex-bi_series-vllm/pkgs/xformers/ops/fmha/flash.py
2025-08-05 19:02:46 +08:00

667 lines
21 KiB
Python

# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
from dataclasses import replace
from itertools import zip_longest
from typing import Any, List, Optional, Set, Tuple, Union
import os
import torch
from ..common import _get_storage_base, get_operator, register_operator
from .attn_bias import (
AttentionBias,
BlockDiagonalCausalFromBottomRightMask,
BlockDiagonalCausalLocalAttentionFromBottomRightMask,
BlockDiagonalCausalLocalAttentionMask,
BlockDiagonalCausalMask,
BlockDiagonalMask,
LowerTriangularMask,
)
from .common import (
AttentionBwOpBase,
AttentionFwOpBase,
Context,
Gradients,
Inputs,
check_lastdim_alignment_stride1,
)
# Backend switch read once at import time: the ixdnn entry points
# (`fwd_ixdnn` / `bwd_ixdnn`) are used unless the environment variable
# ENABLE_FLASH_ATTENTION_WITH_IXDNN is exactly "0".
# (A `global` statement at module scope has no effect, so none is needed.)
enable_ixdnn = os.getenv('ENABLE_FLASH_ATTENTION_WITH_IXDNN', '1') != '0'
# Placeholder version; replaced when a flash-attention backend is imported.
FLASH_VERSION = "0.0.0"
try:
    try:
        # Prefer the flash-attention kernels built into this package.
        from ... import _C_flashattention  # type: ignore[attr-defined]
        from ..._cpp_lib import _build_metadata

        if _build_metadata is not None:
            FLASH_VERSION = _build_metadata.flash_version
    except ImportError:
        # Fall back to a standalone `flash_attn` installation, which must be
        # at least 2.3 (sliding window support).
        import flash_attn
        from flash_attn.flash_attn_interface import flash_attn_cuda as _C_flashattention

        FLASH_VERSION = flash_attn.__version__
        flash_ver_parsed = tuple(int(s) for s in FLASH_VERSION.split(".")[:2])
        if flash_ver_parsed < (2, 3):
            raise ImportError("Requires 2.3 for sliding window support")

    # create library so that flash-attn goes through the PyTorch Dispatcher
    _flash_lib = torch.library.Library("xformers_flash", "DEF")
    _flash_lib.define(
        "flash_fwd(Tensor query, Tensor key, Tensor value, "
        "Tensor? cu_seqlens_q, Tensor? cu_seqlens_k, "
        "int max_seqlen_q, int max_seqlen_k, "
        "float p, float softmax_scale, "
        "bool is_causal, int window_size, bool return_softmax, bool use_alibi, int alibi_mode, int imp_mode) -> (Tensor, Tensor, Tensor, Tensor, Tensor, Tensor)"
    )
    # cu_seqlens are optional for the backward too: `_flash_bwd` handles
    # None (the non-ixdnn, non-varlen path), which a plain `Tensor` schema
    # argument would reject before the implementation is reached.
    _flash_lib.define(
        "flash_bwd(Tensor dout, Tensor query, Tensor key, Tensor value, "
        "Tensor out, Tensor softmax_lse_, Tensor dq, Tensor dk, Tensor dv, "
        "Tensor? cu_seqlens_q, Tensor? cu_seqlens_k, "
        "int max_seqlen_q, int max_seqlen_k, "
        "float p, float softmax_scale, bool is_causal, int window_size, bool use_alibi, int alibi_mode, int imp_mode) -> (Tensor, Tensor, Tensor)"
    )

    def _flash_fwd(
        query,
        key,
        value,
        cu_seq_lens_q,
        cu_seq_lens_k,
        max_seq_len_q,
        max_seq_len_k,
        p,
        softmax_scale,
        is_causal,
        window_size,
        return_softmax,
        use_alibi,
        alibi_mode,
        imp_mode
    ):
        """CUDA implementation of ``xformers_flash::flash_fwd``.

        Dispatches to ``fwd_ixdnn`` when ``enable_ixdnn`` is set; otherwise to
        ``fwd`` when no cumulative sequence lengths are given, or to
        ``varlen_fwd`` when they are.

        Returns:
            ``(out, softmax_lse, q_padded, k_padded, v_padded, out_padded)``.
        """
        if enable_ixdnn:
            # NOTE(review): softmax_scale, window_size and max_seq_len_* are not
            # forwarded to fwd_ixdnn, and the meaning of the two `-1`
            # arguments is not visible here -- confirm against the kernel.
            (
                out,
                q_padded,
                k_padded,
                v_padded,
                out_padded,
                softmax_lse,
                _extra,  # 7th kernel output, unused here
            ) = _C_flashattention.fwd_ixdnn(
                query,
                key,
                value,
                None,  # out
                cu_seq_lens_q,
                cu_seq_lens_k,
                p,
                is_causal,
                return_softmax,
                use_alibi,
                alibi_mode,
                imp_mode,
                -1, -1,
                None,  # rng
            )
        elif cu_seq_lens_q is None:
            assert cu_seq_lens_k is None
            (
                out,
                q_padded,
                k_padded,
                v_padded,
                out_padded,
                softmax_lse,
                _extra,  # 7th kernel output, unused here
            ) = _C_flashattention.fwd(
                query,
                key,
                value,
                None,  # out
                p,
                softmax_scale,
                is_causal,
                return_softmax,
                use_alibi,
                alibi_mode,
                None,  # rng
            )
        else:
            out = query.new_empty(query.shape[0], query.shape[1], value.shape[2])
            (
                out,
                q_padded,
                k_padded,
                v_padded,
                out_padded,
                softmax_lse,
                _extra,  # 7th kernel output, unused here
            ) = _C_flashattention.varlen_fwd(
                query,
                key,
                value,
                out,
                cu_seq_lens_q,
                cu_seq_lens_k,
                max_seq_len_q,
                max_seq_len_k,
                p,
                softmax_scale,
                False,  # presumably zero_tensors, as in varlen_bwd -- confirm
                is_causal,
                return_softmax,
                use_alibi,
                alibi_mode,
                None,
            )
        return out, softmax_lse, q_padded, k_padded, v_padded, out_padded

    def _flash_bwd(
        grad,
        query,
        key,
        value,
        out,
        lse,
        dq,
        dk,
        dv,
        cu_seq_lens_q,
        cu_seq_lens_k,
        max_seq_len_q,
        max_seq_len_k,
        p,
        softmax_scale,
        is_causal,
        window_size,
        use_alibi,
        alibi_mode,
        imp_mode
    ):
        """CUDA implementation of ``xformers_flash::flash_bwd``.

        Writes gradients into the caller-provided ``dq``/``dk``/``dv``
        buffers via ``bwd_ixdnn`` (when ``enable_ixdnn``), ``bwd`` (no
        cumulative sequence lengths) or ``varlen_bwd``, and returns those
        same three tensors.
        """
        if enable_ixdnn:
            # NOTE(review): as in the forward, softmax_scale / window_size /
            # max_seq_len_* are not passed to bwd_ixdnn -- confirm.
            _C_flashattention.bwd_ixdnn(
                grad,
                query,
                key,
                value,
                out,
                lse,
                dq,
                dk,
                dv,
                cu_seq_lens_q,
                cu_seq_lens_k,
                p,
                is_causal,
                use_alibi,
                alibi_mode,
                imp_mode,
                -1, -1,
                None,
            )
        elif cu_seq_lens_k is None:
            assert cu_seq_lens_q is None
            _C_flashattention.bwd(
                grad,
                query,
                key,
                value,
                out,
                lse,
                dq,
                dk,
                dv,
                p,
                softmax_scale,
                is_causal,
                use_alibi,
                alibi_mode,
                None,
            )
        else:
            _C_flashattention.varlen_bwd(
                grad,
                query,
                key,
                value,
                out,
                lse,
                dq,
                dk,
                dv,
                cu_seq_lens_q,
                cu_seq_lens_k,
                max_seq_len_q,
                max_seq_len_k,
                p,
                softmax_scale,
                False,  # zero_tensors
                is_causal,
                use_alibi,
                alibi_mode,
                None,
            )
        return dq, dk, dv

    _flash_lib.impl("flash_fwd", _flash_fwd, "CUDA")
    _flash_lib.impl("flash_bwd", _flash_bwd, "CUDA")
except ImportError:
    # No usable flash-attention backend: the xformers_flash ops stay undefined.
    pass
def _convert_input_format(
    inp: Inputs,
) -> Tuple[Inputs, Optional[torch.Tensor], int, Optional[torch.Tensor], int]:
    """Adapt ``inp`` to the layout and sequence-length arguments of the kernels.

    Returns:
        ``(new_inp, cu_seqlen_q, max_seqlen_q, cu_seqlen_k, max_seqlen_k)``.

    * ixdnn path (``enable_ixdnn``): the seqlen tensors hold one int32 entry
      per batch element, each equal to the sequence length of Q (resp. K).
    * Otherwise, with a ``BlockDiagonalMask``: the mask's ``seqstart``
      tensors are moved to the query device and used as cumulative sequence
      lengths, and Q/K/V are flattened to ``[batch * seqlen, heads, dim]``.
    * Otherwise: the seqlen tensors are ``None``.

    5-D (BMGHK) inputs have their group and head dims folded together. Key and
    value whose head dim is broadcast (stride 0) are narrowed to one head.
    """
    assert inp.query.ndim in [4, 5]
    query, key, value = inp.query, inp.key, inp.value
    batch = query.shape[0]
    seqlen_q = query.shape[1]
    seqlen_kv = key.shape[1]
    head_dim_q = query.shape[-1]
    head_dim_v = value.shape[-1]
    attn_bias = inp.attn_bias
    if enable_ixdnn:
        # One length per batch element (for batch == 1 this is [seqlen]).
        # Sizing by `batch` keeps the tensor long enough even when the
        # sequence is shorter than the batch.
        # NOTE(review): attn_bias (including BlockDiagonalMask boundaries) is
        # not consulted on this path -- confirm fwd_ixdnn/bwd_ixdnn do not
        # need it.
        cu_seqlen_q = torch.full((batch,), seqlen_q, dtype=torch.int32, device=query.device)
        cu_seqlen_k = torch.full((batch,), seqlen_kv, dtype=torch.int32, device=key.device)
        max_seqlen_q = seqlen_q
        max_seqlen_k = seqlen_kv
    elif isinstance(attn_bias, BlockDiagonalMask):
        attn_bias.k_seqinfo.seqstart = attn_bias.k_seqinfo.seqstart.to(
            inp.query.device, non_blocking=True
        )
        attn_bias.q_seqinfo.seqstart = attn_bias.q_seqinfo.seqstart.to(
            inp.query.device, non_blocking=True
        )
        cu_seqlen_k = attn_bias.k_seqinfo.seqstart
        cu_seqlen_q = attn_bias.q_seqinfo.seqstart
        max_seqlen_q = attn_bias.q_seqinfo.max_seqlen
        max_seqlen_k = attn_bias.k_seqinfo.max_seqlen
    else:
        cu_seqlen_k = None
        cu_seqlen_q = None
        max_seqlen_q = inp.query.shape[1]
        max_seqlen_k = inp.key.shape[1]

    if query.ndim == 5:  # QGA
        # Fold the group/head_in_group dimensions together
        def fold(x):
            # Either the head is replicated (stride 0): keep a single copy
            if x.stride(3) == 0:
                return x[:, :, :, 0]
            # Or we reshape
            return x.reshape(
                [
                    x.shape[0],
                    x.shape[1],
                    -1,
                    x.shape[4],
                ]
            )

        query = fold(query)
        key = fold(key)
        value = fold(value)

    # Optimize for MHA: a broadcast head dim on K/V only needs one head
    if key.ndim == 4 and key.stride(2) == 0 and value.stride(2) == 0:
        key = key[:, :, :1]
        value = value[:, :, :1]

    # Here query is `[batch, seqlen, num_heads, head_dim_q]`; the varlen
    # kernels want `[batch * seqlen, num_heads, head_dim_q]`.
    if cu_seqlen_k is not None and not enable_ixdnn:
        query = query.reshape([batch * seqlen_q, -1, head_dim_q])
        key = key.reshape([batch * seqlen_kv, -1, head_dim_q])
        value = value.reshape([batch * seqlen_kv, -1, head_dim_v])

    new_inp = replace(
        inp,
        query=query,
        key=key,
        value=value,
    )
    return new_inp, cu_seqlen_q, max_seqlen_q, cu_seqlen_k, max_seqlen_k
def _is_causal(attn_bias: Optional[Union[torch.Tensor, AttentionBias]]) -> bool:
    """Tell whether ``attn_bias`` is an instance of one of the causal mask classes.

    Subclasses count as well, since the test is ``isinstance``. Tensors and
    ``None`` are never causal.
    """
    causal_kinds = (
        LowerTriangularMask,
        BlockDiagonalCausalMask,
        BlockDiagonalCausalLocalAttentionMask,
        BlockDiagonalCausalFromBottomRightMask,
        BlockDiagonalCausalLocalAttentionFromBottomRightMask,
    )
    matches = isinstance(attn_bias, causal_kinds)
    return matches
def _window_size(attn_bias: Optional[Union[torch.Tensor, AttentionBias]]) -> int:
    """Extract the local-attention window size carried by ``attn_bias``.

    - ``BlockDiagonalCausalLocalAttentionMask``: its ``_window_size``, with a
      falsy value (e.g. ``None``) mapped to 0.
    - ``BlockDiagonalCausalLocalAttentionFromBottomRightMask``: its
      ``_window_size`` as-is.
    - Anything else: 0.
    """
    window = 0
    if isinstance(attn_bias, BlockDiagonalCausalLocalAttentionMask):
        window = attn_bias._window_size or 0
    elif isinstance(attn_bias, BlockDiagonalCausalLocalAttentionFromBottomRightMask):
        window = attn_bias._window_size
    return window
def _check_needs_no_topleft(d: Inputs, reasons: List[str]) -> None:
    """Append to ``reasons`` when a causal bias needs top-left alignment.

    Flash does not support TopLeft causal masks, so ``BlockDiagonalCausalMask``
    and ``LowerTriangularMask`` are only accepted when every batch element has
    equal numbers of queries and keys.
    """
    bias = d.attn_bias
    if isinstance(bias, BlockDiagonalCausalMask):
        # zip_longest pads with None, so differing block counts also mismatch.
        starts = zip_longest(
            bias.k_seqinfo.seqstart_py, bias.q_seqinfo.seqstart_py
        )
        if any(k_start != q_start for k_start, q_start in starts):
            reasons.append(
                "Only support BlockDiagonalCausalMask if equal"
                " numbers of keys and queries"
            )
    elif isinstance(bias, LowerTriangularMask):
        if d.query.shape[1] != d.key.shape[1]:
            reasons.append(
                "Only support LowerTriangularMask if equal number of"
                " keys and queries"
            )
def _check_strides_for_bmghk(x: torch.Tensor, name: str, reasons: List[str]) -> None:
"""
We want to be able to collapse the G/H dimensions together
"""
if x.ndim == 5:
stride_g, stride_h = x.stride(2), x.stride(3)
if x.shape[2] == 1:
return
if x.shape[3] == 1 or stride_h == 0:
return
if stride_g != stride_h * x.shape[-2]:
reasons.append(
f"GQA is only supported when the G/H dimensions are contiguous\n"
f" {name}.stride: {x.stride()}\n"
f" {name}.shape : {list(x.shape)}"
)
@register_operator
class FwOp(AttentionFwOpBase):
    """Operator that computes memory-efficient attention using \
    `Flash-Attention <https://github.com/HazyResearch/flash-attention>`_ \
    implementation.
    """
    OPERATOR = get_operator("xformers_flash", "flash_fwd")
    SUPPORTED_DEVICES: Set[str] = {"cuda"}
    SUPPORTED_DTYPES: Set[torch.dtype] = {torch.half, torch.bfloat16}
    SUPPORTED_MAX_K = 256
    SUPPORTED_ATTN_BIAS_TYPES: Set[Any] = {
        type(None),
        LowerTriangularMask,
        BlockDiagonalMask,
        BlockDiagonalCausalMask,
        BlockDiagonalCausalLocalAttentionMask,
        BlockDiagonalCausalLocalAttentionFromBottomRightMask,
        BlockDiagonalCausalFromBottomRightMask,
    }
    SUPPORTS_DROPOUT = True
    # NOTE(review): the ixdnn forward call does not receive softmax_scale;
    # confirm custom scales are honoured on that path.
    SUPPORTS_CUSTOM_SCALE = True
    SUPPORTS_DIFFERENT_VALUE_EMBED = False
    SUPPORTS_BMGHK = True
    NAME = f"flshattF@{FLASH_VERSION}"
    VERSION = FLASH_VERSION

    @classmethod
    def not_supported_reasons(cls, d: Inputs) -> List[str]:
        """Base-class checks plus query alignment, causal-mask alignment and
        BMGHK stride layout of query/key/value."""
        reasons = super(FwOp, cls).not_supported_reasons(d)
        check_lastdim_alignment_stride1(reasons, "query", d.query, 8)
        _check_needs_no_topleft(d, reasons)
        _check_strides_for_bmghk(d.query, "query", reasons)
        _check_strides_for_bmghk(d.key, "key", reasons)
        _check_strides_for_bmghk(d.value, "value", reasons)
        return reasons

    @classmethod
    def apply(
        cls, inp: Inputs, needs_gradient: bool
    ) -> Tuple[torch.Tensor, Optional[Context]]:
        """Run the flash forward kernel.

        Returns the attention output reshaped to ``query.shape[:-1] +
        (value.shape[-1],)`` and a ``Context`` holding the LSE and the padded
        Q/K/V/O produced by the kernel (consumed by ``BwOp.apply``).
        """
        return_softmax = False
        out_shape = [
            *inp.query.shape[:-1],
            inp.value.shape[-1],
        ]
        (
            inp,
            cu_seqlens_q,
            max_seqlen_q,
            cu_seqlens_k,
            max_seqlen_k,
        ) = _convert_input_format(inp)
        out, softmax_lse, q_padded, k_padded, v_padded, o_padded = cls.OPERATOR(
            inp.query,
            inp.key,
            inp.value,
            cu_seqlens_q,
            cu_seqlens_k,
            max_seqlen_q,
            max_seqlen_k,
            inp.p,
            inp.scale_float,
            _is_causal(inp.attn_bias),
            _window_size(inp.attn_bias),
            return_softmax,
            inp.use_alibi,
            inp.alibi_mode,
            inp.imp_mode,
        )
        out = out.reshape(out_shape)
        ctx = Context(
            out=out,
            lse=softmax_lse,
            q_padded=q_padded,
            k_padded=k_padded,
            v_padded=v_padded,
            o_padded=o_padded,
        )
        # NOTE(review): no RNG state is stored in ctx and the backward kernels
        # are called with rng=None; confirm dropout (p != 0) replays the same
        # mask in the backward.
        if inp.p != 0.0:
            ctx.op_bw = BwOp
        return (out, ctx)

    @classmethod
    # type: ignore
    def operator_flop(
        cls,
        query,
        key,
        value,
        cu_seq_lens_q,
        cu_seq_lens_k,
        max_seq_len_q,
        max_seq_len_k,
        p,
        softmax_scale,
        causal,
        window_size=0,
        return_softmax=False,
        use_alibi=False,
        alibi_mode=0,
        imp_mode=0,
    ) -> int:
        """Estimate FLOPs of one ``flash_fwd`` call.

        Parameters mirror the ``xformers_flash::flash_fwd`` schema positionally
        (including window_size and the alibi/imp arguments, which are accepted
        but unused). Older 11-argument calls still work: their last argument
        lands in ``window_size``, which does not affect the estimate.
        """
        # NOTE(review): `unsqueeze(0)` assumes the 3-D varlen layout, and on
        # the ixdnn path cu_seq_lens hold per-batch lengths rather than
        # cumulative starts -- estimates there may be off.
        return cls.attn_operator_flop(
            query.unsqueeze(0),
            key.unsqueeze(0),
            value.unsqueeze(0),
            causal=causal,
            seqstart_k=cu_seq_lens_k,
            seqstart_q=cu_seq_lens_q,
        )
@register_operator
class BwOp(AttentionBwOpBase):
    __doc__ = FwOp.__doc__
    OPERATOR = get_operator("xformers_flash", "flash_bwd")
    SUPPORTED_DEVICES = FwOp.SUPPORTED_DEVICES
    SUPPORTED_DTYPES = FwOp.SUPPORTED_DTYPES
    SUPPORTED_MAX_K = FwOp.SUPPORTED_MAX_K
    SUPPORTED_ATTN_BIAS_TYPES = FwOp.SUPPORTED_ATTN_BIAS_TYPES
    SUPPORTS_DROPOUT = FwOp.SUPPORTS_DROPOUT
    SUPPORTS_CUSTOM_SCALE = FwOp.SUPPORTS_CUSTOM_SCALE
    SUPPORTS_DIFFERENT_VALUE_EMBED = FwOp.SUPPORTS_DIFFERENT_VALUE_EMBED
    IS_DETERMINISTIC = False
    SUPPORTS_BMGHK = False  # NOTE: Don't forget to update fmha doc when changing this!
    NAME = f"flshattB@{FLASH_VERSION}"
    VERSION = FLASH_VERSION
    # Largest head dim accepted on GPUs whose compute capability is neither
    # 8.0 nor 9.0 (see `not_supported_reasons`).
    MAX_HEADDIM_SM8x = 192

    @classmethod
    def shape_not_supported_reasons(
        cls, Mq: int, Mkv: int, K: int, Kv: int
    ) -> List[str]:
        """Shape checks; currently only those of the base class."""
        reasons = super().shape_not_supported_reasons(Mq, Mkv, K, Kv)
        # In fbcode in mode/dev-nosan, we get nans from flash v2.1 if there
        # is a strange embedding dimension. A restriction of K to
        # {8, 16, 32, 64, 128, 256} is intentionally not enforced here.
        return reasons

    @classmethod
    def not_supported_reasons(cls, d: Inputs) -> List[str]:
        """Base-class checks plus query alignment, causal-mask alignment and a
        head-dim limit on CUDA devices other than compute capability 8.0/9.0."""
        reasons = super(BwOp, cls).not_supported_reasons(d)
        check_lastdim_alignment_stride1(reasons, "query", d.query, 8)
        _check_needs_no_topleft(d, reasons)
        if d.device.type == "cuda":
            # Due to limited shared-memory, some GPUs are limited in head dimension
            device_capability = torch.cuda.get_device_capability(d.device)
            is_sm80_or_sm90 = device_capability in [(8, 0), (9, 0)]
            if (
                max(d.key.shape[-1], d.query.shape[-1]) > cls.MAX_HEADDIM_SM8x
                and not is_sm80_or_sm90
            ):
                reasons.append(
                    "requires a GPU with compute capability 8.0 "
                    f"(A100) or 9.0 (H100) for 'query.shape[-1] > {cls.MAX_HEADDIM_SM8x}'"
                )
        return reasons

    @classmethod
    def apply(cls, ctx: Context, inp: Inputs, grad: torch.Tensor) -> Gradients:
        """Run the flash backward kernel.

        Uses the padded Q/K/V/O and the LSE saved in ``ctx`` by
        ``FwOp.apply``. Returns dq/dk/dv trimmed to the caller's last-dim sizes
        and reshaped to the original query/key/value shapes.
        """
        dq_shape, dk_shape, dv_shape = inp.query.shape, inp.key.shape, inp.value.shape
        (
            inp,
            cu_seqlens_q,
            max_seqlen_q,
            cu_seqlens_k,
            max_seqlen_k,
        ) = _convert_input_format(inp)
        # The LSE is meant to be contiguous; `.contiguous()` is a no-op when
        # it already is, and copies otherwise.
        ctx_lse = ctx.lse.contiguous()
        if not enable_ixdnn:
            assert ctx_lse.shape[2] >= max_seqlen_q
            if max_seqlen_q != ctx_lse.shape[2]:
                ctx_lse = ctx_lse[:, :, :max_seqlen_q].contiguous()
        kernel_out_shape = [
            *inp.query.shape[:-1],
            inp.value.shape[-1],
        ]
        # Create dq,dk,dv
        # If Q/K/V come from a single QKV tensor, let's put the gradient in the
        # right strides, so we can avoid a `cat`.
        q_pad, k_pad, v_pad = ctx.q_padded, ctx.k_padded, ctx.v_padded
        # All three gradients are carved out of one buffer shaped like q, so
        # K and V must have exactly q's shape and share its storage.
        packed_qkv = (
            q_pad.shape == k_pad.shape == v_pad.shape
            and _get_storage_base(q_pad) == _get_storage_base(k_pad)
            and _get_storage_base(q_pad) == _get_storage_base(v_pad)
        )
        if packed_qkv:
            # Create one big contiguous chunk
            # This is because q, k and v usually come from a single
            # output of a linear layer that is chunked.
            # Creating the gradients with the right layout saves us
            # a `torch.cat` call in the backward pass
            chunk = torch.empty(
                (*q_pad.shape[0:-2], 3, q_pad.shape[-2], q_pad.shape[-1]),
                dtype=q_pad.dtype,
                device=inp.device,
            )
            grads = Gradients(
                dq=chunk.select(-3, 0),
                dk=chunk.select(-3, 1),
                dv=chunk.select(-3, 2),
            )
        else:
            grads = Gradients(
                dq=torch.empty_like(q_pad),
                dk=torch.empty_like(k_pad),
                dv=torch.empty_like(v_pad),
            )
        assert grad.dtype in cls.SUPPORTED_DTYPES
        cls.OPERATOR(
            grad.reshape(kernel_out_shape).contiguous(),
            q_pad,
            k_pad,
            v_pad,
            ctx.o_padded,  # padded forward output from the kernel
            ctx_lse,
            grads.dq,
            grads.dk,
            grads.dv,
            cu_seqlens_q,
            cu_seqlens_k,
            max_seqlen_q,
            max_seqlen_k,
            inp.p,
            inp.scale_float,
            _is_causal(inp.attn_bias),
            _window_size(inp.attn_bias),
            inp.use_alibi,
            inp.alibi_mode,
            inp.imp_mode,
        )
        grads.dq = grads.dq[..., :dq_shape[-1]].reshape(dq_shape)  # We could have padded the head dimension
        grads.dk = grads.dk[..., :dk_shape[-1]].reshape(dk_shape)
        grads.dv = grads.dv[..., :dv_shape[-1]].reshape(dv_shape)
        return grads

    @classmethod
    # type: ignore
    def operator_flop(
        cls,
        grad,
        query,
        key,
        value,
        out,
        lse,
        dq,
        dk,
        dv,
        cu_seq_lens_q,
        cu_seq_lens_k,
        max_seq_len_q,
        max_seq_len_k,
        p,
        softmax_scale,
        causal,
        window_size=0,
        use_alibi=False,
        alibi_mode=0,
        imp_mode=0,
    ) -> int:
        """Estimate FLOPs of one ``flash_bwd`` call.

        Parameters mirror the ``xformers_flash::flash_bwd`` schema positionally;
        the trailing window/alibi/imp arguments are accepted (with defaults, so
        16-argument calls still work) but do not affect the estimate.
        """
        # NOTE(review): `unsqueeze(0)` assumes the 3-D varlen layout; on the
        # ixdnn path cu_seq_lens hold lengths, not cumulative starts.
        return cls.attn_operator_flop(
            query.unsqueeze(0),
            key.unsqueeze(0),
            value.unsqueeze(0),
            causal=causal,
            seqstart_k=cu_seq_lens_k,
            seqstart_q=cu_seq_lens_q,
        )