Files
enginex-bi_series-vllm/pkgs/xformers/ops/fmha/common.py

543 lines
19 KiB
Python
Raw Normal View History

2025-08-05 19:02:46 +08:00
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
import math
from dataclasses import dataclass
from typing import Any, List, Mapping, Optional, Set, Tuple, Type, Union
import torch
from ..._cpp_lib import _built_with_cuda
from ..common import BaseOperator
from .attn_bias import (
AttentionBias,
BlockDiagonalMask,
LowerTriangularMask,
LowerTriangularMaskWithTensorBias,
)
def _is_bias_type_supported_in_BMK(attn_bias_type: Any) -> bool:
# NoneType
if isinstance(None, attn_bias_type):
return True
if attn_bias_type in [LowerTriangularMask, torch.Tensor]:
return True
return False
@dataclass
class Inputs:
    """
    Stores inputs to the `memory_efficient_attention` operators

    The query/key/value tensors are accepted with 3, 4 or 5 dimensions, named
    BMK, BMHK and BMGHK in the error messages below
    (``[batch, seqlen, K]``, ``[batch, seqlen, num_heads, K]`` and
    ``[batch, seqlen, head_groups, num_heads_per_group, K]``).
    """

    query: torch.Tensor
    key: torch.Tensor
    value: torch.Tensor
    # Either a raw tensor bias, a structured ``AttentionBias``, or no bias.
    attn_bias: Optional[Union[torch.Tensor, AttentionBias]] = None
    # Dropout probability; ``validate_inputs`` requires 0.0 <= p <= 1.0.
    p: float = 0.0
    # Softmax scale; when None, ``scale_float`` falls back to K ** -0.5.
    scale: Optional[float] = None
    # NOTE(review): the three fields below are not read anywhere in this file;
    # presumably consumed by backend-specific operators - confirm.
    use_alibi: bool = False
    alibi_mode: int = 1
    imp_mode: int = 0

    @property
    def device(self) -> torch.device:
        """Device of the query tensor."""
        return self.query.device

    @property
    def scale_float(self) -> float:
        """Effective scale: ``self.scale`` if set, else ``query.shape[-1] ** -0.5``."""
        return self.query.shape[-1] ** (-0.5) if self.scale is None else self.scale

    def get_qkv_in_bmghk(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Return (query, key, value) as 5-dimensional views.

        5-D inputs are returned untouched, 4-D inputs get a singleton dim
        inserted at position 2, and 3-D inputs get two singleton dims inserted
        after the second dim.

        Raises:
            AssertionError: if the tensors have any other number of dimensions.
        """
        if self.query.ndim == 5:
            return self.query, self.key, self.value
        if self.query.ndim == 4:
            return (
                self.query.unsqueeze(2),
                self.key.unsqueeze(2),
                self.value.unsqueeze(2),
            )
        if self.value.ndim == 3:
            return (
                self.query[:, :, None, None],
                self.key[:, :, None, None],
                self.value[:, :, None, None],
            )
        # Explicit raise instead of `assert False`: a bare assert is stripped
        # under `python -O`, which would make this method return None.
        raise AssertionError(
            f"Unsupported number of dimensions for query/value: "
            f"{self.query.ndim}/{self.value.ndim}"
        )

    def normalize_bmhk(self) -> Tuple[int, ...]:
        """Convert legacy 3-D (BMK) inputs to 4-D in place and return the output shape.

        When the query is 3-D, query/key/value get a singleton dim inserted at
        position 2 and a tensor ``attn_bias`` gets one inserted at position 1.
        Inputs that are already 4-D or 5-D are left unchanged.

        Returns:
            The shape of the attention output, computed *before* the
            conversion: the query shape with its last dim replaced by the
            value's last dim (or the plain query shape for int32 values).

        Raises:
            ValueError: if the query is not 3-D/4-D/5-D, or if a tensor bias
                supplied with 3-D inputs is not itself 3-D.
        """
        if self.query.ndim not in [3, 4, 5]:
            raise ValueError(
                f"Invalid shape for query: {self.query.shape}. "
                "Expected shape [batch, seqlen, head_groups, num_heads_per_group, K]"
                ", [batch, seqlen, num_heads, K], or [batch, seqlen, K]."
            )
        if self.value.dtype == torch.int32:
            # Quantized K/V case, in which the last dims of Q and K are different.
            # NB we currently don't have any implementations for quantized KV with
            # SUPPORTS_DIFFERENT_VALUE_EMBED.
            output_shape = tuple(self.query.shape)
        else:
            output_shape = (self.query.shape[:-1]) + (self.value.shape[-1],)
        # Convert from legacy format
        if self.query.ndim == 3:
            self.query = self.query.unsqueeze(2)
            self.key = self.key.unsqueeze(2)
            self.value = self.value.unsqueeze(2)
            if isinstance(self.attn_bias, torch.Tensor):
                if self.attn_bias.ndim != 3:
                    raise ValueError(
                        f"Expected BMK format for attn_bias, but got {self.attn_bias.shape}"
                    )
                self.attn_bias = self.attn_bias.unsqueeze(1)
        return output_shape

    def validate_inputs(self) -> None:
        """Check that the stored inputs are mutually consistent.

        The checks run in this order: number of dimensions, devices, dtypes,
        bias type allowed with BMK inputs, tensor-bias shape (4-D inputs only),
        batch size with block-diagonal bias, dropout range, and finally
        query/key/value shape compatibility.

        Raises:
            ValueError: on the first check that fails.
        """
        qkv = (self.query, self.key, self.value)
        if self.query.ndim not in (3, 4, 5) or any(
            x.ndim != self.query.ndim for x in qkv
        ):
            raise ValueError(
                f"Query/Key/Value should all have BMGHK, BMHK, or BMK shape.\n"
                f" query.shape: {self.query.shape}\n"
                f" key.shape : {self.key.shape}\n"
                f" value.shape: {self.value.shape}"
            )
        if any(x.device != self.query.device for x in qkv):
            raise ValueError("Query/Key/Value should all be on the same device")
        # Either all three dtypes agree, or key and value are both int32.
        quantized_dtypes = self.key.dtype == self.value.dtype == torch.int32
        non_quantized_dtypes = all(x.dtype == self.query.dtype for x in qkv)
        if not (quantized_dtypes or non_quantized_dtypes):
            raise ValueError(
                "Query/Key/Value should either all have the same dtype, or "
                "(in the quantized case) Key/Value should have dtype torch.int32\n"
                f" query.dtype: {self.query.dtype}\n"
                f" key.dtype : {self.key.dtype}\n"
                f" value.dtype: {self.value.dtype}"
            )
        # Biases with tensors attached are meant to be in BMHK format
        # This would require to permute biases/gradients which can be expensive,
        # so let's just forbid it - BMK is a legacy format anyway
        if self.query.ndim == 3 and not _is_bias_type_supported_in_BMK(
            type(self.attn_bias)
        ):
            raise ValueError(
                f"Please provide inputs in BMHK format rather "
                f"than BMK when using bias type `{type(self.attn_bias).__name__}`"
            )
        # Extract the underlying tensor of the bias, if it carries one.
        attn_bias_t: Optional[torch.Tensor] = None
        if isinstance(self.attn_bias, torch.Tensor):
            attn_bias_t = self.attn_bias
        if isinstance(self.attn_bias, LowerTriangularMaskWithTensorBias):
            attn_bias_t = self.attn_bias._bias
        if self.query.ndim == 4 and attn_bias_t is not None:
            # (batch, num_heads, seqlen_q, seqlen_kv)
            expected_shape = (
                self.query.shape[0],
                self.query.shape[2],
                self.query.shape[1],
                self.key.shape[1],
            )
            if attn_bias_t.shape != expected_shape:
                raise ValueError(
                    f"Invalid shape for attention bias: {attn_bias_t.shape} (expected {expected_shape})\n"
                    f" query.shape: {self.query.shape}\n"
                    f" key.shape : {self.key.shape}\n"
                    f" value.shape: {self.value.shape}"
                )
        if isinstance(self.attn_bias, BlockDiagonalMask):
            if any(x.shape[0] != 1 for x in qkv):
                raise ValueError(
                    f"Expected batch_size=1 when using block-diagonal bias\n"
                    f" query.shape: {self.query.shape}\n"
                    f" key.shape : {self.key.shape}\n"
                    f" value.shape: {self.value.shape}"
                )
        if self.p < 0.0 or self.p > 1.0:
            raise ValueError(f"Invalid dropout probability: p={self.p}")
        # Check that shapes match between inputs
        B, Mq = self.query.shape[:2]
        K = self.query.shape[-1]
        # B is deliberately re-bound to the key's batch size: the comparisons
        # below then fail if query and key disagree on it.
        B, Mkv = self.key.shape[:2]
        Kv = self.value.shape[-1]
        valid_shapes = True
        if self.query.ndim == 3:  # BMK
            valid_shapes = (
                self.query.shape == (B, Mq, K)
                and self.key.shape == (B, Mkv, K)
                and self.value.shape == (B, Mkv, Kv)
            )
        H = self.query.shape[-2]
        if self.query.ndim == 4:  # BMHK
            # With int32 values the key's last dim is compared to the value's
            # rather than to the query's.
            quantized_kv_cache = self.value.dtype == torch.int32
            key_embed_dim = Kv if quantized_kv_cache else K
            valid_shapes = (
                self.query.shape == (B, Mq, H, K)
                and self.key.shape == (B, Mkv, H, key_embed_dim)
                and self.value.shape == (B, Mkv, H, Kv)
            )
        G = self.query.shape[2]
        if self.query.ndim == 5:  # BMNHK
            valid_shapes = (
                self.query.shape == (B, Mq, G, H, K)
                and self.key.shape == (B, Mkv, G, H, K)
                and self.value.shape == (B, Mkv, G, H, Kv)
            )
        if not valid_shapes:
            raise ValueError(
                f"Incompatible shapes for attention inputs:\n"
                f" query.shape: {self.query.shape}\n"
                f" key.shape : {self.key.shape}\n"
                f" value.shape: {self.value.shape}\n"
                "HINT: We don't support broadcasting, please use `expand` "
                "yourself before calling `memory_efficient_attention` if you need to"
            )
@dataclass
class Context:
    """State handed from a forward attention operator to the backward pass.

    ``AttentionFwOpBase.apply`` returns an optional ``Context`` and
    ``AttentionBwOpBase.apply`` takes one as its first argument.
    """

    # The code below treats dim 2 of ``lse`` as the (possibly padded) length.
    lse: torch.Tensor
    # ``out.shape[1]`` is used below as the number of valid ``lse`` entries.
    out: torch.Tensor
    # NOTE(review): presumably padded copies kept for the backward op - confirm.
    q_padded: Optional[torch.Tensor] = None
    k_padded: Optional[torch.Tensor] = None
    v_padded: Optional[torch.Tensor] = None
    o_padded: Optional[torch.Tensor] = None
    op_bw: Optional[Type["AttentionBwOpBase"]] = None
    rng_state: Optional[torch.Tensor] = None

    def get_padded_lse(self, pad_to: int, force_pad_inf: bool = False) -> torch.Tensor:
        """Return ``lse`` with its dim 2 padded up to a multiple of ``pad_to``.

        Args:
            pad_to: the required multiple for ``lse.shape[2]``.
            force_pad_inf: when set, every entry past ``out.shape[1]`` along
                dim 2 is forced to ``+inf``, including entries that were
                already present in ``self.lse``.

        Returns:
            A new, ``+inf``-padded tensor when padding is required; otherwise
            ``self.lse`` itself (modified in place if ``force_pad_inf`` had
            trailing entries to overwrite).
        """
        lse = self.lse
        missing = -lse.shape[2] % pad_to
        if missing <= 0:
            # Already aligned: at most overwrite the tail in place.
            if force_pad_inf and self.out.shape[1] != lse.shape[2]:
                lse[:, :, self.out.shape[1] :].fill_(math.inf)
            return lse
        if force_pad_inf:
            # Drop the stale tail first, then recompute how much to add back.
            lse = lse[:, :, : self.out.shape[1]]
            missing = -lse.shape[2] % pad_to
        return torch.nn.functional.pad(lse, [0, missing], value=math.inf)
@dataclass
class Gradients:
    """Result type of ``AttentionBwOpBase.apply``.

    Plain container with no behaviour of its own.
    """

    # NOTE(review): presumably the gradients with respect to query, key and
    # value respectively - confirm against the backward operators.
    dq: torch.Tensor
    dk: torch.Tensor
    dv: torch.Tensor
    # bias gradient. None if there is no tensor bias or if it doesn't require grad
    db: Optional[torch.Tensor] = None
class AttentionOpBase(BaseOperator):
    """Base class for any attention operator in xFormers

    See:

    - :attr:`xformers.ops.fmha.cutlass.FwOp`
    - :attr:`xformers.ops.fmha.cutlass.BwOp`
    - :attr:`xformers.ops.fmha.flash.FwOp`
    - :attr:`xformers.ops.fmha.flash.BwOp`
    - :attr:`xformers.ops.fmha.triton.FwOp`
    - :attr:`xformers.ops.fmha.triton.BwOp`
    - :attr:`xformers.ops.fmha.small_k.FwOp`
    - :attr:`xformers.ops.fmha.small_k.BwOp`

    Subclasses describe their capabilities through the class attributes below;
    :meth:`not_supported_reasons` compares them against a set of ``Inputs``.
    """

    OPERATOR: Any
    # Device types (``tensor.device.type`` strings) the operator can run on.
    SUPPORTED_DEVICES: Set[str]
    # Compared against ``torch.cuda.get_device_capability`` for CUDA inputs.
    CUDA_MINIMUM_COMPUTE_CAPABILITY: Tuple[int, int] = (5, 0)
    # Accepted dtypes for the query tensor.
    SUPPORTED_DTYPES: Set[torch.dtype]
    # Upper bound on the last dim of query and value.
    SUPPORTED_MAX_K: float
    # Accepted ``type(attn_bias)`` values; NoneType means "no bias".
    SUPPORTED_ATTN_BIAS_TYPES: Set[Any] = {type(None)}
    SUPPORTS_DROPOUT: bool
    SUPPORTS_CUSTOM_SCALE: bool = False
    # Whether value.shape[-1] may differ from query.shape[-1].
    SUPPORTS_DIFFERENT_VALUE_EMBED: bool = False
    IS_DETERMINISTIC: bool = True
    # Whether 5-dimensional (BMGHK) inputs are accepted.
    SUPPORTS_BMGHK: bool = False
    NAME: str
    OPERATOR_CATEGORY = "memory_efficient_attention"

    # NOTE(review): not read in this file; presumably used by the test suite.
    _TEST_BATCH_SIZES: List[int] = [1, 300]
    _TEST_K: List[int] = [32, 128]

    @classmethod
    def supports(cls, d: Inputs) -> bool:
        """Return True when :meth:`not_supported_reasons` finds nothing to report."""
        return not cls.not_supported_reasons(d)

    @classmethod
    def shape_not_supported_reasons(
        cls, Mq: int, Mkv: int, K: int, Kv: int
    ) -> List[str]:
        """Return the shape-related reasons why this operator cannot be used.

        Args:
            Mq: query sequence length (unused by this base implementation).
            Mkv: key/value sequence length (unused by this base implementation).
            K: last dim of the query.
            Kv: last dim of the value.

        Returns:
            A list of human-readable reasons; empty when the shape is supported.
        """
        reasons = []
        if not cls.SUPPORTS_DIFFERENT_VALUE_EMBED and K != Kv:
            reasons.append("query.shape[-1] != value.shape[-1]")
        if max(K, Kv) > cls.SUPPORTED_MAX_K:
            reasons.append(
                f"max(query.shape[-1], value.shape[-1]) > {cls.SUPPORTED_MAX_K}"
            )
        return reasons

    @classmethod
    def not_supported_reasons(cls, d: Inputs) -> List[str]:
        """
        Returns a list of reasons why this is not supported.
        The kernel can run these inputs only if the returned list is empty
        """
        # Use the value's own embedding size so that the
        # SUPPORTS_DIFFERENT_VALUE_EMBED / SUPPORTED_MAX_K checks see it.
        # int32 (quantized) values are the exception: their last dim is not
        # comparable to the query's, so the query's is used for both.
        if d.value.dtype == torch.int32:
            value_embed_dim = d.query.shape[-1]
        else:
            value_embed_dim = d.value.shape[-1]
        reasons = cls.shape_not_supported_reasons(
            Mq=d.query.shape[1],
            Mkv=d.key.shape[1],
            K=d.query.shape[-1],
            Kv=value_embed_dim,
        )
        device_type = d.query.device.type
        dtype = d.query.dtype
        if device_type not in cls.SUPPORTED_DEVICES:
            reasons.append(f"device={device_type} (supported: {cls.SUPPORTED_DEVICES})")
        if device_type == "cuda" and not _built_with_cuda:
            reasons.append("xFormers wasn't build with CUDA support")
        if device_type == "cuda":
            device_capability = torch.cuda.get_device_capability(d.device)
            if device_capability < cls.CUDA_MINIMUM_COMPUTE_CAPABILITY:
                reasons.append(
                    f"requires device with capability >= {cls.CUDA_MINIMUM_COMPUTE_CAPABILITY} "
                    f"but your GPU has capability {device_capability} (too old)"
                )
        if dtype not in cls.SUPPORTED_DTYPES:
            reasons.append(f"dtype={dtype} (supported: {cls.SUPPORTED_DTYPES})")
        if type(d.attn_bias) not in cls.SUPPORTED_ATTN_BIAS_TYPES:
            reasons.append(f"attn_bias type is {type(d.attn_bias)}")
        if (d.p != 0.0) and not cls.SUPPORTS_DROPOUT:
            reasons.append("dropout > 0.0")
        if d.scale is not None and not cls.SUPPORTS_CUSTOM_SCALE:
            reasons.append("has custom scale")
        # bfloat16 is only supported on A100+
        # ... although the kernels can still run and give the
        # correct result
        # NOTE(review): only the device *type* is checked here, not the
        # compute capability, despite the wording of the message.
        if dtype is torch.bfloat16 and (
            not device_type.startswith("cuda")
        ):
            reasons.append("bf16 is only supported on A100+ GPUs")
        if not cls.is_available():
            reasons.append(
                "operator wasn't built - see `python -m xformers.info` for more info"
            )
        if not cls.IS_DETERMINISTIC and torch.are_deterministic_algorithms_enabled():
            reasons.append(
                "operator is non-deterministic, but `torch.use_deterministic_algorithms` is set"
            )
        if not cls.SUPPORTS_BMGHK and d.query.ndim == 5:
            reasons.append("operator does not support BMGHK format")
        return reasons
class AttentionFwOpBase(AttentionOpBase):
    """Base class for forward attention operators.

    Subclasses implement :meth:`apply`; :meth:`attn_operator_flop` provides a
    FLOP estimate for the forward pass.
    """

    # Per-dtype absolute / relative error tolerances.
    # NOTE(review): not read in this file; presumably consumed by tests.
    ERROR_ATOL: Mapping[torch.dtype, float] = {
        torch.float: 3e-4,
        torch.half: 4e-3,
        torch.bfloat16: 2e-2,
    }
    ERROR_RTOL: Mapping[torch.dtype, float] = {
        torch.float: 2e-5,
        torch.half: 4e-4,
        torch.bfloat16: 5e-3,
    }

    @classmethod
    def apply(
        cls, inp: Inputs, needs_gradient: bool
    ) -> Tuple[torch.Tensor, Optional[Context]]:
        """Run the forward operator. Must be overridden by subclasses.

        Raises:
            NotImplementedError: always, in this base class.
        """
        raise NotImplementedError()

    @classmethod
    def attn_operator_flop(
        cls,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        causal: bool = False,
        seqstart_k: Optional[torch.Tensor] = None,
        seqstart_q: Optional[torch.Tensor] = None,
    ) -> int:
        """
        Computes total flops for the attention
        Assumes inputs in format BMHK

        Args:
            query, key, value: attention inputs; ``query`` must be 4-D.
            causal: when set, the total is halved.
            seqstart_k, seqstart_q: optional boundaries of the individual
                sequences along the key / query length dim; consecutive
                entries delimit one sequence. When omitted, the whole length
                is treated as a single sequence.

        Returns:
            The FLOP count of the two GEMMs (softmax is ignored), multiplied
            by ``value.shape[2] * value.shape[0]``.
        """
        assert query.ndim == 4
        if seqstart_q is not None:
            seqstart_q_py = seqstart_q.tolist()
        else:
            seqstart_q_py = [0, query.shape[1]]
        if seqstart_k is not None:
            seqstart_k_py = seqstart_k.tolist()
        else:
            seqstart_k_py = [0, key.shape[1]]
        total_flop = 0
        for q_start, q_end, k_start, k_end in zip(
            seqstart_q_py, seqstart_q_py[1:], seqstart_k_py, seqstart_k_py[1:]
        ):
            num_q = q_end - q_start
            num_kv = k_end - k_start
            # (M,K) @ (K,N) GEMM needs M*N*K*2 flop
            # Q @ K.transpose
            total_flop += num_q * num_kv * query.shape[-1] * 2
            # (ignore softmax)
            # attn @ V: the inner dim is num_kv and the output width is the
            # *value* embedding size, which may differ from the key's.
            total_flop += num_q * value.shape[-1] * num_kv * 2
        # Multiply by num_heads and batches
        total_flop = total_flop * value.shape[2] * value.shape[0]
        if causal:
            total_flop //= 2
        return total_flop
class AttentionBwOpBase(AttentionOpBase):
    """Base class for backward attention operators.

    Subclasses implement :meth:`apply`; on top of the generic support checks
    this class rejects inputs whose tensor bias needs a gradient unless
    ``SUPPORTS_ATTN_BIAS_GRAD`` is set.
    """

    # Per-dtype absolute / relative error tolerances.
    # NOTE(review): not read in this file; presumably consumed by tests.
    ERROR_ATOL: Mapping[torch.dtype, float] = {
        torch.float: 5e-4,
        torch.half: 9e-2,
        torch.bfloat16: 0.7,
    }
    ERROR_RTOL: Mapping[torch.dtype, float] = {
        torch.float: 1e-4,
        torch.half: 2e-2,
        torch.bfloat16: 0.1,
    }
    SUPPORTS_ATTN_BIAS_GRAD = False

    @classmethod
    def not_supported_reasons(cls, d: Inputs) -> List[str]:
        """Extend the base checks with the bias-gradient capability check."""
        reasons = super().not_supported_reasons(d)
        bias = d.attn_bias
        bias_needs_grad = isinstance(bias, torch.Tensor) and bias.requires_grad
        if bias_needs_grad and not cls.SUPPORTS_ATTN_BIAS_GRAD:
            reasons.append(
                "Computing the bias gradient is not supported (attn_bias.requires_grad = True)"
            )
        return reasons

    @classmethod
    def apply(cls, ctx: Context, inp: Inputs, grad: torch.Tensor) -> Gradients:
        """Run the backward operator. Must be overridden by subclasses."""
        raise NotImplementedError()

    @classmethod
    def attn_operator_flop(
        cls,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        causal: bool = False,
        seqstart_k: Optional[torch.Tensor] = None,
        seqstart_q: Optional[torch.Tensor] = None,
    ) -> int:
        """
        Computes total flops for the attention
        Assumes inputs in format BMHK

        Five GEMMs are counted per sequence pair; ``seqstart_q`` /
        ``seqstart_k`` optionally delimit the individual sequences, otherwise
        the full length is one sequence. The total is halved when ``causal``.
        """
        assert query.ndim == 4
        q_bounds = seqstart_q.tolist() if seqstart_q is not None else [0, query.shape[1]]
        k_bounds = seqstart_k.tolist() if seqstart_k is not None else [0, key.shape[1]]
        Kqk = query.shape[-1]
        Kv = value.shape[-1]
        total_flop = 0
        for q_start, q_end, k_start, k_end in zip(
            q_bounds, q_bounds[1:], k_bounds, k_bounds[1:]
        ):
            # Number of (query, key) position pairs in this sequence.
            pairs = (q_end - q_start) * (k_end - k_start)
            # (M,K) @ (K,N) GEMM needs M*N*K*2 flop
            total_flop += sum(
                (
                    pairs * Kqk * 2,  # att = Q @ K.transpose
                    pairs * Kv * 2,  # att @ dO
                    pairs * Kv * 2,  # dov = dO @ V
                    pairs * Kqk * 2,  # dov @ K
                    pairs * Kqk * 2,  # dov @ Q
                )
            )
        # Multiply by num_heads and batches
        total_flop *= value.shape[2]
        total_flop *= value.shape[0]
        return total_flop // 2 if causal else total_flop
# A (forward operator class, backward operator class) pair; either element
# may be None.
AttentionOp = Tuple[
    Optional[Type[AttentionFwOpBase]], Optional[Type[AttentionBwOpBase]]
]
@dataclass
class AttentionOpDispatch:
    """Dispatcher to automatically select
    the best operator to run memory-efficient attention.

    :Deprecated:

        This class is deprecated and will be removed in a later version
    """

    # The selected (forward, backward) operator pair.
    op: AttentionOp

    @classmethod
    def from_arguments(
        cls,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        attn_bias: Optional[Union[torch.Tensor, AttentionBias]] = None,
        p: float = 0.0,
        scale: Optional[float] = None,
    ) -> "AttentionOpDispatch":
        """Here for backward compatibility

        Wraps the arguments in an ``Inputs`` and asks the dispatch module for
        a forward and a backward operator.
        """
        # Imported lazily, at call time rather than at module import.
        from .dispatch import _dispatch_bw, _dispatch_fw

        inputs = Inputs(
            query=query, key=key, value=value, attn_bias=attn_bias, p=p, scale=scale
        )
        # NOTE(review): the second argument is presumably `needs_gradient`.
        forward_op = _dispatch_fw(inputs, True)
        backward_op = _dispatch_bw(inputs)
        return AttentionOpDispatch(op=(forward_op, backward_op))
def bmk2bmhk(tensor, num_heads: int) -> torch.Tensor:
    """Turn a 3-D tensor with heads folded into dim 0 into a 4-D BMHK tensor.

    Args:
        tensor: a tensor of shape ``(batch * num_heads, M, K)``, or an
            already 4-dimensional tensor.
        num_heads: number of heads folded into the first dim.

    Returns:
        ``tensor`` itself when it is 4-D; otherwise a permuted view/reshape of
        shape ``(batch, M, num_heads, K)``.
    """
    if tensor.ndim == 4:
        return tensor
    seq_len = tensor.shape[1]
    embed_dim = tensor.shape[2]
    # Unfold the heads out of dim 0, then move them after the sequence dim.
    unfolded = tensor.reshape([-1, num_heads, seq_len, embed_dim])
    return unfolded.permute((0, 2, 1, 3))
def check_lastdim_alignment_stride1(
    reasons: List[str], name: str, x: torch.Tensor, alignment: int
) -> None:
    """Append to ``reasons`` any alignment / stride problem found on ``x``.

    Args:
        reasons: list that receives the messages; modified in place.
        name: label used for ``x`` in the messages.
        x: tensor to inspect.
        alignment: required multiple for ``x.shape[-1]`` and ``x.stride(-2)``.

    The ``stride(-2)`` check only runs when the last-dim size check passed;
    the ``stride(-1)`` check always runs.
    """
    if x.shape[-1] % alignment != 0:
        reasons.append(f"{name}.shape[-1] % {alignment} != 0")
    else:
        row_stride = x.stride(-2)
        if row_stride % alignment != 0:
            reasons.append(
                f"{name}.stride(-2) % {alignment} != 0 ({name}.stride() = {x.stride()})"
            )
    # We can have stride=0 sometimes if dimension=1
    last_stride = x.stride(-1)
    if last_stride > 1:
        reasons.append(
            f"{name}.stride(-1) > 1 ({name}.stride() = {x.stride()}) - you should call `.contiguous()` on the input"
        )