Files
enginex-bi_series-vllm/pkgs/xformers/components/attention/feature_maps/softmax.py
2025-08-05 19:02:46 +08:00

289 lines
10 KiB
Python

# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
import math
from enum import Enum, auto
from typing import Optional
import torch
from torch.autograd.profiler import record_function
from .base import FeatureMap
"""
A set of feature maps which approximate the softmax kernel, as per the Performers_ paper.
_Performers: "Rethinking attention with performers." K. Choromanski et al. (2020).
https://arxiv.org/pdf/2009.14794v1.pdf
"""
class NormDistribution(Enum):
    """How the row norms of the random projection matrix are distributed."""

    # Rows are rescaled to Xi-distributed (chi) norms, averaging to sqrt(d)
    Xi = 1
    # Rows keep unit norm; the orthonormal basis is used as-is
    Uniform = 2
class SoftMaxPositiveEstimators(FeatureMap):
    """
    Base class for the positive random-feature softmax estimators from the
    Performers paper. It owns the shared machinery: scaling the inputs by the
    softmax temperature, maintaining/redrawing the random projection, and
    computing the log-space normalization offset applied inside the
    exponential by the concrete kernels.

    Args:
        dim_features: requested number of random features.
        iter_before_redraw: redraw the random projection after this many
            calls to ``pre_scale``; ``None`` disables redrawing.
        normalize_inputs: if True, subtract the max squared norm from the
            offset so the exponentiated features stay bounded.
        epsilon: small constant folded into the log-space offset.
        softmax_temp: softmax temperature; a negative value means "derive it
            from the input's last dimension on the first call".
    """

    def __init__(
        self,
        dim_features: int,
        iter_before_redraw: Optional[int],
        normalize_inputs: bool = False,
        epsilon: float = 1e-6,
        softmax_temp: float = -1,
    ):
        super().__init__(dim_features, iter_before_redraw, normalize_inputs, epsilon)
        self.softmax_temp = softmax_temp

        # Handle the scaling from all kernels by √m.
        # This normalizes for all the feature maps involved
        self.h_scale = math.log(math.sqrt(self.dim_features))

    def pre_scale(self, x: torch.Tensor) -> torch.Tensor:
        """
        Scale ``x`` by the softmax temperature and refresh ``self.offset``
        (the log-space normalization term). Redraws ``self.features`` when
        the redraw schedule, a missing map, or a device change requires it.
        Returns the scaled inputs; projecting them is kernel-specific.
        """
        with record_function("feature_map::pre_scale"):
            # Re-draw counting logic: redraw when the counter trips, when no
            # map exists yet, or when the inputs moved to another device
            if (
                (
                    self.iter_before_redraw is not None
                    and self._iter_counter > self.iter_before_redraw
                )
                or self.features is None
                or self.features.device != x.device
            ):
                # The feature map is actually using half the dimension, we'll concatenate + and - features
                self._iter_counter = 1
                self.features = self._get_feature_map(
                    x.shape[-1], self.dim_feature_map, x.device
                )

            features = self.features
            assert features is not None

            # Lazily convert the stored map to the incoming dtype
            if features.dtype != x.dtype:
                self.features = features.to(x.dtype)

            self._iter_counter += 1

            # Normalization / softmax
            if self.softmax_temp < 0:
                # A = exp(QK.t/√d), so each input will be scaled by √√d
                # NOTE(review): cached on self, so later inputs with a different
                # last dimension keep the first-seen temperature — confirm intended.
                self.softmax_temp = x.shape[-1] ** -0.25

            x_scaled = x * self.softmax_temp

            # Compute the scaling factors in logspace, applied from within the exponential
            # - diminish possible exponential overflow
            # - remove a multiply across the batch, replace by an addition
            norm_x_2 = torch.einsum("...d,...d->...", x_scaled, x_scaled).unsqueeze(-1)
            self.offset = -0.5 * norm_x_2 - self.h_scale + self.epsilon

            if self.normalize_inputs:
                # L0 normalize the exponential term, can be useful for numerical stability
                # This ensures that features +- offset is below 1
                self.offset -= norm_x_2.max(1, keepdim=True)[0]

        # Return the scaled inputs, the rest depends on the kernel being used
        return x_scaled

    @staticmethod
    @torch.no_grad()
    def _get_random_ortho_matrix(
        blocks: int,
        dim: int,
        device: torch.device,
        norm_distribution: NormDistribution = NormDistribution.Uniform,
    ) -> torch.Tensor:
        r"""
        Generate a random matrix whose rows are exactly orthonormal

        "How to generate random matrices from the classical compact groups", Mezzadri, 2007
        https://arxiv.org/pdf/math-ph/0609050v2.pdf

        .. note: the typical qr decomposition does not give uniform results, qr decomposition is not
          unique and the qr decomposition routines are biased towards numerical stability. See the above
          paper for more information.

        .. note: this does not follow the original implementation from the Performers authors.
          see docs/assets/kde plots to visualize the impact of using the R signs to correct Q
        """
        H = torch.randn((blocks, dim, dim), device=device, requires_grad=False)

        # Randomly scale the norms of the features, Xi distributed
        if norm_distribution == NormDistribution.Xi:
            # NOTE: This averages to sqrt(d)
            norms = torch.sqrt(torch.einsum("...d,...d->...", H, H))

        Q, R = torch.linalg.qr(H)
        # Fix the sign ambiguity of QR using R's diagonal signs (Mezzadri, 2007)
        Q = torch.diag_embed(torch.sign(torch.diagonal(R, dim1=1, dim2=2))) @ Q

        # Normalize if need be. Uniform NormDistribution does nothing, Q is already orthonormal
        if norm_distribution == NormDistribution.Xi:
            return torch.diag_embed(norms) @ Q

        return Q
class SMOrf(SoftMaxPositiveEstimators):
    """
    "Positive random orthogonal features" softmax estimator,
    SM_ort^m+, as proposed in the Performers_ paper, Lemma 1.

    _Performers: "Rethinking attention with performers." K. Choromanski et al. (2020).
    https://arxiv.org/pdf/2009.14794v1.pdf
    """

    @torch.no_grad()
    def _get_feature_map(self, dim_input: int, dim_features: int, device: torch.device):
        """
        Build the random projection matrix mapping inputs onto the features.

        .. note: The heads dimension needs to be taken into account, hence the per-block random matrix
        and not uniformally random.
        """
        # Draw one orthonormal block per chunk of the input dimension, so that
        # the full input can be projected whatever feature size was requested
        n_blocks = math.ceil(dim_input / dim_features)
        ortho_blocks = self._get_random_ortho_matrix(
            n_blocks,
            dim_features,
            norm_distribution=NormDistribution.Xi,
            device=device,
        )

        # Stack the blocks and keep exactly dim_input projection rows
        return ortho_blocks.flatten(0, 1)[:dim_input]

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Temperature scaling + log-space offset, shared across all kernels
        x_scaled = super().pre_scale(x)
        assert self.features is not None

        # Project onto the random feature map, then exponentiate with the
        # log-space normalization folded into the exponent
        projected = x_scaled @ self.features
        return torch.exp(projected + self.offset)
class SMHyperbolic(SoftMaxPositiveEstimators):
    """
    "Positive random features hyperbolic" estimator, SMHyp+,
    as proposed in the Performers_ paper, Lemma 1.

    _Performers: "Rethinking attention with performers." K. Choromanski et al. (2020).
    https://arxiv.org/pdf/2009.14794v1.pdf
    """

    def __init__(
        self,
        dim_features: int,
        iter_before_redraw: Optional[int],
        normalize_inputs: bool = False,
        epsilon: float = 1e-6,
        softmax_temp: float = -1,
    ):
        super().__init__(
            dim_features, iter_before_redraw, normalize_inputs, epsilon, softmax_temp
        )

        # Half the features come from the +projection and half from the
        # -projection, so the requested dimension must split evenly in two
        assert (
            dim_features % 2 == 0
        ), "The feature dimension needs to be even with this kernel"
        self.dim_feature_map = self.dim_features // 2

    @torch.no_grad()
    def _get_feature_map(self, dim_input: int, dim_features: int, device: torch.device):
        """
        Build the random projection matrix mapping inputs onto the features.

        .. note: The heads dimension needs to be taken into account, hence the per-block random matrix
        and not uniformally random.
        """
        # Draw one orthonormal block per chunk of the input dimension, so that
        # the full input can be projected whatever feature size was requested
        n_blocks = math.ceil(dim_input / dim_features)
        ortho_blocks = self._get_random_ortho_matrix(
            n_blocks,
            dim_features,
            norm_distribution=NormDistribution.Xi,
            device=device,
        )

        # Stack the blocks and keep exactly dim_input projection rows
        return ortho_blocks.flatten(0, 1)[:dim_input]

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Temperature scaling + log-space offset, shared across all kernels
        x_scaled = super().pre_scale(x)

        # Project, then concatenate the + and - branches: the cosh
        # representation of Lemma 1 in the original Performers paper
        projected = x_scaled @ self.features
        pos_branch = torch.exp(projected + self.offset)
        neg_branch = torch.exp(-projected + self.offset)
        return torch.cat([pos_branch, neg_branch], dim=-1)
class SMReg(SoftMaxPositiveEstimators):
    """
    "Regularized softmax kernel" estimator, SMREG+, as proposed in the Performers_ paper.

    _Performers: "Rethinking attention with performers." K. Choromanski et al. (2020).
    https://arxiv.org/pdf/2009.14794v1.pdf
    """

    def __init__(
        self,
        dim_features: int,
        iter_before_redraw: Optional[int],
        normalize_inputs: bool = False,
        epsilon: float = 1e-6,
        softmax_temp: float = -1,
    ):
        super().__init__(
            dim_features, iter_before_redraw, normalize_inputs, epsilon, softmax_temp
        )

        # Half of the features are the +projection, the other half the
        # -projection, so the requested dimension must split evenly in two
        assert (
            dim_features % 2 == 0
        ), "The feature dimension needs to be even with this kernel"
        self.dim_feature_map = self.dim_features // 2

    @torch.no_grad()
    def _get_feature_map(self, dim_input: int, dim_features: int, device: torch.device):
        """
        Generate the projection matrix onto the random features

        .. note: The heads dimension needs to be taken into account, hence the per-block random matrix
        and not uniformally random.
        """
        # Get per block random unitary matrices.
        # We need enough of them to project the whole input dimension, regardless of the
        # requested dimension of the features
        features = self._get_random_ortho_matrix(
            math.ceil(dim_input / dim_features),
            dim_features,
            norm_distribution=NormDistribution.Uniform,
            device=device,
        ).flatten(0, 1)

        # Regularization: rescale every projection row to norm sqrt(dim_input).
        # The scale is the same for all rows, so a scalar multiply on the kept
        # rows replaces the previous diag-matrix construction + matmul
        # (O(n) instead of O(n^2), identical values).
        return math.sqrt(dim_input) * features[:dim_input]

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Softmax-dimension related scaling, shared for all kernels
        x_scaled = super().pre_scale(x)

        # Project onto the random feature map, concatenate both + and - results
        # This follows Lemma 1 in the original Performers Paper to best approximate a
        # softmax kernel (cosh representation + sample regularization)
        x_scaled = x_scaled @ self.features
        return torch.cat(
            [torch.exp(x_scaled + self.offset), torch.exp(-x_scaled + self.offset)],
            dim=-1,
        )