Files
enginex-vllm-bi100-qwen36/pkgs/xformers/components/attention/core.py
2025-08-05 19:02:46 +08:00

343 lines
10 KiB
Python

# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
import logging
import math
from contextlib import nullcontext
from functools import lru_cache
from typing import Optional, Union
import torch
from xformers import _has_cpp_library, _is_triton_available
from xformers.components.attention.attention_mask import AttentionMask
if _has_cpp_library:
from ._sputnik_sparse import SparseCS
if _is_triton_available():
from xformers.triton.softmax import softmax as triton_softmax
from xformers.triton.utils import gpu_capabilities_older_than_70
_is_blocksparse_available = (
_is_triton_available() and not gpu_capabilities_older_than_70()
)
if _is_blocksparse_available:
from xformers.components.attention.blocksparse import BlockSparseAttention
logger = logging.getLogger("xformers")
def _create_random_sparsity(matrix, sparsity, divisible_by=4):
    """Return a copy of ``matrix`` with a random subset of entries zeroed out.

    One keep/drop pattern is drawn over the last two dimensions and shared by
    every entry of the leading (batch) dimension. A coordinate is kept when a
    uniform draw in [0, 1) is strictly greater than ``sparsity``. The number of
    kept coordinates is then rounded down to a multiple of ``divisible_by``,
    dropping the trailing ones in row-major order.

    Args:
        matrix: 3D tensor; the first dimension is indexed as the batch.
        sparsity: threshold the uniform draws are compared against.
        divisible_by: the kept-coordinate count is truncated to a multiple of
            this value (4 by default, which sputnik requires).

    Returns:
        A tensor shaped like ``matrix``. It holds the kept entries and zeros elsewhere.
    """
    assert matrix.ndim == 3
    keep_mask = torch.rand_like(matrix[0], dtype=torch.float32) > sparsity
    kept_coords = torch.nonzero(keep_mask)
    kept_count = kept_coords.shape[0]
    # NOTE: sputnik needs the number of non-zeros to be a multiple of 4
    usable = kept_count - kept_count % divisible_by
    rows, cols = kept_coords[:usable].unbind(1)
    # Column vector of batch ids so the index tensors broadcast over the batch
    batch_ids = torch.arange(matrix.shape[0], device=matrix.device).unsqueeze(1)
    sparse_copy = torch.zeros_like(matrix)
    sparse_copy[batch_ids, rows, cols] = matrix[batch_ids, rows, cols]
    return sparse_copy
def _broadcast_batch(mask, batch_size):
    """Give a 2D sparse COO mask a leading batch dimension of size ``batch_size``.

    3D inputs are returned unchanged. For a 2D input, the coalesced non-zero
    coordinates and values are replicated once per batch entry.

    Args:
        mask: 2D sparse COO tensor, or a 3D tensor that is passed through.
        batch_size: number of copies along the new leading dimension.

    Returns:
        ``mask`` itself when it is 3D. Otherwise, a sparse COO tensor of shape
        ``(batch_size, *mask.shape)``.
    """
    if mask.ndim == 3:
        return mask
    assert mask.ndim == 2

    coalesced = mask.coalesce()
    nz_values = coalesced.values()
    nz_indices = coalesced.indices()
    per_batch = len(nz_values)

    # (row, col) coordinates tiled once per batch entry
    tiled_coords = nz_indices.repeat(1, batch_size)
    # Matching batch coordinate for each tiled entry: 0..0, 1..1, ...
    batch_coords = torch.arange(
        batch_size, device=nz_indices.device
    ).repeat_interleave(per_batch)
    full_coords = torch.cat([batch_coords.unsqueeze(0), tiled_coords], dim=0)

    return torch.sparse_coo_tensor(
        full_coords, nz_values.repeat(batch_size), (batch_size, *mask.shape)
    )
def _matmul_with_mask(
    a: torch.Tensor,
    b: torch.Tensor,
    mask: Optional[Union[torch.Tensor, "SparseCS"]],
) -> torch.Tensor:
    """Compute ``a @ b`` and apply ``mask`` to the product.

    Mask semantics:
      * ``None``: plain ``a @ b``.
      * boolean: ``False`` entries are ignored. On the non-optimized path they
        are set to ``-inf``.
      * any other dtype: treated as additive and summed into the product.

    When the xformers C++ extension is available, boolean masks are routed to
    the dedicated ``matmul_with_mask`` kernels. This covers ``SparseCS``,
    sparse COO and dense masks. Otherwise the product is computed densely:
      * a 2D boolean mask is expanded over the batch;
      * a 3D mask (boolean or additive) whose batch size evenly divides the
        product's batch size is repeated to match.

    Args:
        a: left operand. ``a.shape[0]`` is used as the batch size when
            broadcasting a sparse boolean mask.
        b: right operand.
        mask: optional mask, see above.

    Returns:
        The masked product.
    """
    if mask is None:
        return a @ b

    # SparseCS only exists when the C++ extension is built. Guard the
    # isinstance check so the pure-PyTorch path does not raise NameError.
    mask_is_sparse_cs = _has_cpp_library and isinstance(mask, SparseCS)

    if _has_cpp_library and mask.dtype == torch.bool:
        if mask_is_sparse_cs:
            return mask.matmul_with_mask(a, b)
        if mask.is_sparse:
            # perform broadcasting if needed
            mask = _broadcast_batch(mask, a.shape[0])
            # coalesced is not implemented for bool tensors, so need to cast
            mask = mask.to(dtype=a.dtype)  # type: ignore  # mypy is missing the catch above
        return torch.ops.xformers.matmul_with_mask(a, b, mask)

    # Non optimized codepath: SparseCS masks are only supported above
    assert not mask_is_sparse_cs

    att = a @ b

    # Repeat the mask when the product's batch size is a multiple of the
    # mask's. This applies to both boolean and additive masks.
    if (
        mask.ndim == 3
        and mask.shape[0] != att.shape[0]
        and att.shape[0] % mask.shape[0] == 0
    ):
        repeat_factor = att.shape[0] // mask.shape[0]
        mask = mask.repeat([repeat_factor, 1, 1])
        logger.info("Mismatched batch dimensions for mask, repeating mask.")

    if mask.dtype == torch.bool:
        if mask.ndim == 2:
            mask = mask.unsqueeze(0).expand(att.shape[0], -1, -1)
        # mask is presumed false == ignore
        att[~mask] = float("-inf")
    else:
        # mask is presumed additive
        att += mask
    return att
def _softmax(a: torch.Tensor, causal: bool = False) -> torch.Tensor:
    """Apply a softmax to ``a``, dispatching on how it is stored.

    * ``SparseCS`` (C++ extension builds only): delegated to ``a.softmax()``.
    * sparse COO: ``torch.sparse.softmax`` over the last dimension.
    * dense, Triton available: the Triton softmax kernel, with ``causal``
      forwarded to it.
    * dense, no Triton: ``torch.softmax`` over the last dimension. ``causal``
      is unused on this path.
    """
    if _has_cpp_library and isinstance(a, SparseCS):
        return a.softmax()

    last_dim = a.ndim - 1
    if a.is_sparse:
        return torch.sparse.softmax(a, dim=last_dim)
    if not _is_triton_available():
        return torch.softmax(a, dim=last_dim)
    return triton_softmax(a, mask=None, causal=causal)
if _has_cpp_library:

    class SparseBMM(torch.autograd.Function):
        """Autograd function for a batched product of a sparse COO and a dense tensor.

        The gradient w.r.t. the sparse operand is computed with the xformers
        ``matmul_with_mask`` op, using the saved sparse operand as the mask.
        """

        @staticmethod
        def forward(ctx, a, b):
            sparse_a = a.coalesce()
            product = torch.bmm(sparse_a, b)
            ctx.save_for_backward(sparse_a, b)
            return product

        @staticmethod
        def backward(ctx, grad):
            sparse_a, dense_b = ctx.saved_tensors
            need_grad_a, need_grad_b = ctx.needs_input_grad

            # gradient w.r.t. a: grad @ b^T, masked by a's sparsity pattern
            grad_a = (
                torch.ops.xformers.matmul_with_mask(
                    grad, dense_b.transpose(-2, -1), sparse_a
                )
                if need_grad_a
                else None
            )
            # gradient w.r.t. b: a^T @ grad
            grad_b = sparse_a.transpose(1, 2).bmm(grad) if need_grad_b else None
            return grad_a, grad_b

    def _sparse_bmm(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
        """Batch matrix multiply a sparse ``(N, n, m)`` tensor with a dense ``(N, m, p)`` one.

        Raises:
            AssertionError: if either operand is not 3D, or if the batch or
                inner dimensions disagree.
        """
        assert a.ndim == 3 and b.ndim == 3
        assert a.shape[0] == b.shape[0]
        assert a.shape[-1] == b.shape[-2]
        return SparseBMM.apply(a, b)
def bmm(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    """Batched ``a @ b`` that also accepts a sparse left operand.

    With the C++ extension, a ``SparseCS`` ``a`` uses ``a.spmm(b)`` and a
    sparse COO ``a`` uses the autograd-aware sparse bmm. Every other case is a
    plain ``a @ b``.
    """
    if not _has_cpp_library:
        return a @ b
    if isinstance(a, SparseCS):
        return a.spmm(b)
    return _sparse_bmm(a, b) if a.is_sparse else a @ b
def _apply_dropout(att, dropout):
    """Apply ``dropout`` to an attention matrix, preserving sparse layouts.

    Dropout chokes on sparse tensors. For ``SparseCS`` and sparse COO inputs it
    is therefore applied to a clone of the stored values, and the sparse
    structure is rebuilt around the result. These sparse paths are only taken
    when the C++ extension is available. Everything else goes straight through
    ``dropout``.

    Args:
        att: dense tensor, sparse COO tensor or ``SparseCS``.
        dropout: callable applied to the values, or ``None`` to skip dropout.

    Returns:
        ``att`` unchanged when ``dropout`` is ``None``. Otherwise, the
        dropped-out matrix in the same layout as the input.
    """
    if dropout is None:
        return att

    sparse_capable = _has_cpp_library

    if sparse_capable and isinstance(att, SparseCS):
        # clone to protect the original values against in-place dropout
        dropped = dropout(att.values.clone())
        return SparseCS.wrap(
            att.shape,
            dropped,
            att.row_indices,
            att.row_offsets,
            att.column_indices,
            att._transp_info,
        )

    if sparse_capable and att.is_sparse:
        coalesced = att.coalesce()
        # clone to protect the original values against in-place dropout
        dropped = dropout(coalesced.values().clone())
        return torch.sparse_coo_tensor(coalesced.indices(), dropped, coalesced.shape)

    # Dense case (or any input when the C++ extension is missing)
    return dropout(att)
def scaled_query_key_softmax(
    q: torch.Tensor,
    k: torch.Tensor,
    att_mask: Optional[Union[AttentionMask, "SparseCS", torch.Tensor]],
) -> torch.Tensor:
    """Return the attention probabilities for ``q`` against ``k``.

    ``q`` is divided by ``sqrt(k.size(-1))`` before being multiplied with
    ``k^T`` under the mask. An ``AttentionMask`` contributes its ``values``
    (an additive mask), and its ``is_causal`` flag is forwarded to the softmax.
    Any other mask is passed to ``_matmul_with_mask`` unchanged.
    """
    # TODO assume we have (N, S, hs) instead of (B, nh, S, hs), with N = B x nh
    # this is needed due to limitations in sparse_bmm for now
    scaled_q = q / math.sqrt(k.size(-1))

    wrapped = isinstance(att_mask, AttentionMask)
    mask: Optional[Union[SparseCS, torch.Tensor]] = (
        att_mask.values if wrapped else att_mask
    )

    # Self-attend: (N, S, hs) x (N, hs, S) -> (N, S, S)
    scores = _matmul_with_mask(scaled_q, k.transpose(-2, -1), mask)
    return _softmax(scores, causal=wrapped and att_mask.is_causal)
if _is_blocksparse_available:

    @lru_cache(maxsize=128)  # 128 is also lru_cache's default maxsize
    def _retrieve_blocksparse(
        num_heads: int, seq_len: int, block_size: int
    ) -> BlockSparseAttention:
        """Build, or fetch from the LRU cache, a causal ``BlockSparseAttention``.

        The layout is fully populated: it is all ones, with shape
        ``(num_heads, seq_len // block_size, seq_len // block_size)``.
        Repeated calls with the same arguments return the same cached instance.
        """
        n_blocks = seq_len // block_size
        full_layout = torch.ones((num_heads, n_blocks, n_blocks), dtype=torch.long)
        return BlockSparseAttention(
            layout=full_layout, block_size=block_size, causal=True
        )
def blocksparse_attention(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    dropout: Optional[torch.nn.Module] = None,
    block_size: int = 128,
) -> torch.Tensor:
    """Causal attention through a cached Triton ``BlockSparseAttention``.

    Args:
        q, k, v: 3D ``(N, S, hs)`` tensors, which are treated as single-head
            and unsqueezed to ``(N, 1, S, hs)``. Otherwise presumably 4D
            ``(B, nh, S, hs)``, which are passed through as-is.
        dropout: used as the attention dropout only when it is a
            ``torch.nn.Dropout``. Anything else, including ``None``, disables
            dropout.
        block_size: layout block size; must divide the sequence length.

    Returns:
        The attention output. For 3D inputs the singleton head dimension is
        flattened away again.

    Raises:
        AssertionError: if the sequence length is not a multiple of ``block_size``.
    """
    orig_dim = q.dim()
    seq_len = q.shape[-2]
    # TODO perhaps add functionality to pad qkv if sequence length is not divisible by block size?
    assert seq_len % block_size == 0, "Sequence length must be divisible by block size"

    # The layout always has a single head. A per-batch layout
    # (q.shape[0] heads) is not implemented.
    layout_heads = 1

    if orig_dim == 3:
        # (N, S, hs) -> (N, 1, S, hs), where N = B x nh
        q, k, v = q.unsqueeze(1), k.unsqueeze(1), v.unsqueeze(1)

    attention = _retrieve_blocksparse(layout_heads, seq_len, block_size)

    # NOTE(review): ``attention`` is an lru_cache'd instance shared by every
    # caller with the same arguments. Assigning attn_drop mutates that shared
    # module.
    # Dropout is a no-op in evaluation mode
    if isinstance(dropout, torch.nn.Dropout):
        attention.attn_drop = dropout
    else:
        attention.attn_drop = torch.nn.Dropout(0.0)

    att = attention(q, k, v)

    # Undo the head unsqueeze for 3D inputs: (N, 1, S, hs) -> (N, S, hs)
    if orig_dim == 3:
        return att.flatten(0, 1)
    return att
def scaled_dot_product_attention(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    att_mask: Optional[Union[AttentionMask, "SparseCS", torch.Tensor]],
    dropout: Optional[torch.nn.Module] = None,
    block_size: int = 128,
) -> torch.Tensor:
    """Masked scaled dot-product attention, optionally with dropout.

    A causal ``AttentionMask`` is routed to the Triton blocksparse kernel when
    all of the following hold:
      * the kernel is available;
      * ``q`` is fp16 or autocast is enabled;
      * the sequence length is a multiple of ``block_size``;
      * ``q`` and ``k`` have the same sequence length.

    Sparse masks (``SparseCS`` or sparse COO) run with CUDA autocast disabled
    and with q/k/v cast to float32.

    Args:
        q, k, v: query, key and value tensors. The sequence length is read
            from ``shape[-2]``.
        att_mask: optional ``AttentionMask``, ``SparseCS`` or tensor mask.
        dropout: optional module applied to the attention probabilities.
        block_size: block size used by the blocksparse path.

    Returns:
        The attention output.
    """
    autocast_disabled = (_has_cpp_library and isinstance(att_mask, SparseCS)) or (
        att_mask is not None and att_mask.is_sparse
    )
    seq_len = q.shape[-2]

    # switch if:
    #   causal is required but mask is not sparse
    #   fp16 or under amp context
    #   sequence length is divisible by block size
    #   same seq len for K and Q
    switch_to_blocksparse = (
        _is_blocksparse_available
        and (att_mask is not None and not att_mask.is_sparse)
        and (isinstance(att_mask, AttentionMask) and att_mask.is_causal)
        and (q.dtype == torch.float16 or torch.is_autocast_enabled())
        and not seq_len % block_size
        and q.shape[-2] == k.shape[-2]
    )

    if switch_to_blocksparse:
        logger.info("Switching causal attention to Triton blocksparse...")
        return blocksparse_attention(q, k, v, dropout, block_size)

    # torch.cuda.amp.autocast is deprecated in recent PyTorch releases and emits
    # a FutureWarning. torch.autocast(device_type="cuda") is the equivalent API.
    autocast_ctx = (
        torch.autocast(device_type="cuda", enabled=False)
        if autocast_disabled
        else nullcontext()
    )
    with autocast_ctx:
        if autocast_disabled:
            q, k, v = q.float(), k.float(), v.float()

        att = scaled_query_key_softmax(q, k, att_mask=att_mask)

        # Optional dropout, could be part of the masking in the future
        att = _apply_dropout(att, dropout)

        # Get to the predicted values, for all heads
        # y = att @ v  # (N, S, S) x (N, S, hs) -> (N, S, hs)
        y = bmm(att, v)
    return y