Files
enginex-bi_series-vllm/pkgs/xformers/components/attention/utils.py
2025-08-05 19:02:46 +08:00

109 lines
3.7 KiB
Python

# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
from typing import Optional
import torch
# Reshapes key padding mask from (batch_size, src_len) -> (batch_size * num_heads, 1, src_len)
def reshape_key_padding_mask(
    key_padding_mask: torch.Tensor, batched_dim: int
) -> torch.Tensor:
    """Expand a per-sample key padding mask to one entry per (sample, head).

    The number of heads is derived from ``batched_dim``, the size of the
    flattened ``batch_size * num_heads`` dimension the mask must line up with.

    Args:
        key_padding_mask: 2-D tensor of shape ``(batch_size, src_len)``.
        batched_dim: Size of the flattened batch-times-heads dimension.

    Returns:
        The mask reshaped to ``(batch_size * num_heads, 1, src_len)``.

    Raises:
        AssertionError: If ``key_padding_mask`` is not 2-D.
        ValueError: If ``batched_dim`` is not a multiple of ``batch_size``.
    """
    assert key_padding_mask.ndim == 2
    batch_size, src_len = key_padding_mask.size()
    # Plain floor division would silently drop a remainder and produce a mask
    # whose leading dimension no longer matches ``batched_dim``; fail early
    # with an explicit message instead.
    num_heads, remainder = divmod(batched_dim, batch_size)
    if remainder:
        raise ValueError(
            f"batched_dim ({batched_dim}) must be a multiple of the key padding "
            f"mask batch size ({batch_size})"
        )
    return _reshape_key_padding_mask(key_padding_mask, batch_size, src_len, num_heads)
def _reshape_key_padding_mask(
    key_padding_mask: torch.Tensor, batch_size: int, src_len: int, num_heads: int
) -> torch.Tensor:
    """Repeat a ``(batch_size, src_len)`` mask once per head.

    Returns a tensor of shape ``(batch_size * num_heads, 1, src_len)`` in which
    the ``num_heads`` consecutive entries belonging to one sample all carry
    that sample's mask row.
    """
    assert key_padding_mask.shape == (batch_size, src_len)
    # Insert two singleton axes between the batch and source-length axes.
    with_unit_axes = key_padding_mask.view(batch_size, 1, 1, src_len)
    # Broadcast along the first singleton axis: one copy per head.
    per_head = with_unit_axes.expand(-1, num_heads, -1, -1)
    # Fold the batch and head axes into a single leading axis.
    return per_head.reshape(batch_size * num_heads, 1, src_len)
# Combine the attention mask and key padding mask into a single mask
# Taken from https://github.com/pytorch/pytorch/blob/master/torch/nn/functional.py
# Additive masking not yet supported
def maybe_merge_masks(
    att_mask: Optional[torch.Tensor],
    key_padding_mask: Optional[torch.Tensor],
    batch_size: int,
    src_len: int,
    num_heads: int,
    tgt_len: Optional[int] = None,
) -> Optional[torch.Tensor]:
    """Fold an optional key padding mask into an optional attention mask.

    Args:
        att_mask: Existing attention mask, or ``None``.
        key_padding_mask: Mask of shape ``(batch_size, src_len)``, or ``None``.
        batch_size: Expected first dimension of ``key_padding_mask``.
        src_len: Expected second dimension of ``key_padding_mask``.
        num_heads: Number of times each sample's padding row is repeated.
        tgt_len: Middle dimension of the mask built when ``att_mask`` is
            ``None``; falls back to ``src_len`` when not given.

    Returns:
        ``att_mask`` unchanged when there is no key padding mask; otherwise the
        combination of both masks (or the expanded padding mask alone when
        ``att_mask`` is ``None``).
    """
    # Nothing to merge: hand back whatever attention mask was given (maybe None).
    if key_padding_mask is None:
        return att_mask

    assert key_padding_mask.shape == (batch_size, src_len)
    padding = _reshape_key_padding_mask(
        key_padding_mask, batch_size, src_len, num_heads
    )

    if att_mask is None:
        # Give the padding mask the dimensions expected of an attention mask.
        n_rows = src_len if tgt_len is None else tgt_len
        return padding.expand(-1, n_rows, -1)

    # Assumption is that False means to mask.
    if att_mask.dtype == torch.bool:
        return att_mask.logical_and(padding)

    # Non-boolean attention mask: padded positions are filled with -inf.
    return att_mask.masked_fill(~padding, float("-inf"))
# Assumes that matrix passed in has had softmax applied to it.
def iterative_pinv(
    softmax_mat: torch.Tensor, n_iter: int = 6, pinverse_original_init: bool = False
) -> torch.Tensor:
    """
    Approximate the Moore-Penrose inverse of a softmax-normalized matrix.

    Uses an iterative method from (Razavi et al. 2014) that only needs
    matrix-matrix multiplications.

    Args:
        softmax_mat: Tensor of shape ``(..., rows, cols)`` with any number of
            leading batch dimensions (including none). It is assumed to have
            had softmax applied over its last dimension.
        n_iter: Number of refinement iterations; ``0`` returns the initial
            estimate.
        pinverse_original_init: If True, scale the initial estimate by a single
            coefficient taken over the whole tensor (more conservative). If
            False, scale each matrix by its own ``1 / ||K||_1``, which leads to
            faster convergence.

    Returns:
        Tensor of shape ``(..., cols, rows)`` approximating the pseudo-inverse
        of each matrix in ``softmax_mat``.
    """
    k = softmax_mat
    # K @ Z has shape (rows, rows), so the identity is sized by the
    # second-to-last dimension. For square inputs this equals size(-1).
    identity = torch.eye(k.size(-2), device=k.device, dtype=k.dtype)
    k_t = k.transpose(-1, -2)

    # The entries of K are positive and ||K||_{\infty} = 1 due to softmax
    if pinverse_original_init:
        # This original implementation is more conservative to compute coefficient of Z_0.
        v = 1 / torch.max(torch.sum(k, dim=-2)) * k_t
    else:
        # This is the exact coefficient computation, 1 / ||K||_1, of initialization of Z_0,
        # leading to faster convergence. keepdim=True keeps the result
        # broadcastable as (..., 1, 1) for any number of batch dimensions.
        column_sums = torch.sum(k, dim=-2, keepdim=True)
        norm_1 = torch.max(column_sums, dim=-1, keepdim=True).values
        v = 1 / norm_1 * k_t

    for _ in range(n_iter):
        kv = torch.matmul(k, v)
        v = torch.matmul(
            0.25 * v,
            13 * identity
            - torch.matmul(kv, 15 * identity - torch.matmul(kv, 7 * identity - kv)),
        )
    return v
def bool_mask_to_additive(
    mask: torch.Tensor, dtype: Optional[torch.dtype] = torch.float32
) -> torch.Tensor:
    """Convert a boolean mask into an additive one.

    Positions where ``mask`` is True become ``0`` and positions where it is
    False become ``-inf``.

    Args:
        mask: Boolean tensor.
        dtype: Floating-point dtype of the returned mask. ``None`` selects
            ``torch.get_default_dtype()``.

    Returns:
        A new tensor with the same shape as ``mask`` and the requested dtype.

    Raises:
        AssertionError: If ``mask`` is not a boolean tensor.
    """
    assert (
        mask.dtype == torch.bool
    ), "This util is meant to convert in between bool masks and additive ones"
    # zeros_like(..., dtype=None) would inherit the bool dtype of ``mask`` and
    # could not represent -inf, so resolve None to a floating-point dtype first.
    if dtype is None:
        dtype = torch.get_default_dtype()
    mask_ = torch.zeros_like(mask, dtype=dtype)
    mask_[~mask] = float("-inf")
    return mask_