First commit
This commit is contained in:
288
pkgs/xformers/components/attention/feature_maps/softmax.py
Normal file
288
pkgs/xformers/components/attention/feature_maps/softmax.py
Normal file
@@ -0,0 +1,288 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the BSD license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
|
||||
import math
|
||||
from enum import Enum, auto
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from torch.autograd.profiler import record_function
|
||||
|
||||
from .base import FeatureMap
|
||||
|
||||
"""
|
||||
A set of feature maps which approximate the softmax kernel, as per the Performers_ paper.
|
||||
|
||||
_Performers: "Rethinking attention with performers." K. Choromanski et al. (2020).
|
||||
https://arxiv.org/pdf/2009.14794v1.pdf
|
||||
"""
|
||||
|
||||
|
||||
class NormDistribution(Enum):
|
||||
Xi = auto()
|
||||
Uniform = auto()
|
||||
|
||||
|
||||
class SoftMaxPositiveEstimators(FeatureMap):
|
||||
def __init__(
|
||||
self,
|
||||
dim_features: int,
|
||||
iter_before_redraw: Optional[int],
|
||||
normalize_inputs: bool = False,
|
||||
epsilon: float = 1e-6,
|
||||
softmax_temp: float = -1,
|
||||
):
|
||||
super().__init__(dim_features, iter_before_redraw, normalize_inputs, epsilon)
|
||||
self.softmax_temp = softmax_temp
|
||||
|
||||
# Handle the scaling from all kernels by √m.
|
||||
# This normalizes for all the feature maps involved
|
||||
self.h_scale = math.log(math.sqrt(self.dim_features))
|
||||
|
||||
def pre_scale(self, x: torch.Tensor) -> torch.Tensor:
|
||||
with record_function("feature_map::pre_scale"):
|
||||
# Re-draw counting logic
|
||||
if (
|
||||
(
|
||||
self.iter_before_redraw is not None
|
||||
and self._iter_counter > self.iter_before_redraw
|
||||
)
|
||||
or self.features is None
|
||||
or self.features.device != x.device
|
||||
):
|
||||
# The feature map is actually using half the dimension, we'll concatenate + and - features
|
||||
self._iter_counter = 1
|
||||
self.features = self._get_feature_map(
|
||||
x.shape[-1], self.dim_feature_map, x.device
|
||||
)
|
||||
|
||||
features = self.features
|
||||
assert features is not None
|
||||
|
||||
if features.dtype != x.dtype:
|
||||
self.features = features.to(x.dtype)
|
||||
|
||||
self._iter_counter += 1
|
||||
|
||||
# Normalization / softmax
|
||||
if self.softmax_temp < 0:
|
||||
# A = exp(QK.t/√d), so each input will be scaled by √√d
|
||||
self.softmax_temp = x.shape[-1] ** -0.25
|
||||
|
||||
x_scaled = x * self.softmax_temp
|
||||
|
||||
# Compute the scaling factors in logspace, applied from within the exponential
|
||||
# - dimnish possible exponential overflow
|
||||
# - remove a multiply across the batch, replace by an addition
|
||||
norm_x_2 = torch.einsum("...d,...d->...", x_scaled, x_scaled).unsqueeze(-1)
|
||||
self.offset = -0.5 * norm_x_2 - self.h_scale + self.epsilon
|
||||
|
||||
if self.normalize_inputs:
|
||||
# L0 normalize the exponential term, can be useful for numerical stability
|
||||
# This ensures that features +- offset is below 1
|
||||
self.offset -= norm_x_2.max(1, keepdim=True)[0]
|
||||
|
||||
# Return the scaled inputs, the rest depends on the kernel being used
|
||||
return x_scaled
|
||||
|
||||
@staticmethod
|
||||
@torch.no_grad()
|
||||
def _get_random_ortho_matrix(
|
||||
blocks: int,
|
||||
dim: int,
|
||||
device: torch.device,
|
||||
norm_distribution: NormDistribution = NormDistribution.Uniform,
|
||||
) -> torch.Tensor:
|
||||
r"""
|
||||
Generate a random matrix whose rows are exactly orthonormal
|
||||
|
||||
"How to generate random matrices from the classical compact groups", Mezzadri, 2007
|
||||
https://arxiv.org/pdf/math-ph/0609050v2.pdf
|
||||
|
||||
.. note: the typical qr decomposition does not give uniform results, qr decomposition is not
|
||||
unique and the qr decomposition routines are biased towards numerical stability. See the above
|
||||
paper for more information.
|
||||
|
||||
.. note: this does not follow the original implementation from the Performers authors.
|
||||
see docs/assets/kde plots to visualize the impact of using the R signs to correct Q
|
||||
"""
|
||||
|
||||
H = torch.randn((blocks, dim, dim), device=device, requires_grad=False)
|
||||
|
||||
# Randomly scale the norms of the features, Xi distributed
|
||||
if norm_distribution == NormDistribution.Xi:
|
||||
# NOTE: This averages to sqrt(d)
|
||||
norms = torch.sqrt(torch.einsum("...d,...d->...", H, H))
|
||||
|
||||
Q, R = torch.linalg.qr(H)
|
||||
Q = torch.diag_embed(torch.sign(torch.diagonal(R, dim1=1, dim2=2))) @ Q
|
||||
|
||||
# Normalize if need be. Uniform NormDistribution does nothing, Q is already orthonormal
|
||||
if norm_distribution == NormDistribution.Xi:
|
||||
return torch.diag_embed(norms) @ Q
|
||||
|
||||
return Q
|
||||
|
||||
|
||||
class SMOrf(SoftMaxPositiveEstimators):
|
||||
"""
|
||||
"Positive random orthogonal features" softmax estimator,
|
||||
SM_ort^m+, as proposed in the Performers_ paper, Lemma 1.
|
||||
|
||||
_Performers: "Rethinking attention with performers." K. Choromanski et al. (2020).
|
||||
https://arxiv.org/pdf/2009.14794v1.pdf
|
||||
"""
|
||||
|
||||
@torch.no_grad()
|
||||
def _get_feature_map(self, dim_input: int, dim_features: int, device: torch.device):
|
||||
"""
|
||||
Generate the projection matrix onto the random features
|
||||
|
||||
.. note: The heads dimension needs to be taken into account, hence the per-block random matrix
|
||||
and not uniformally random.
|
||||
"""
|
||||
|
||||
# Get per block random unitary matrices.
|
||||
# We need enough of them to project the whole input dimension, regardless of the
|
||||
# requested dimension of the features
|
||||
features = self._get_random_ortho_matrix(
|
||||
math.ceil(dim_input / dim_features),
|
||||
dim_features,
|
||||
norm_distribution=NormDistribution.Xi,
|
||||
device=device,
|
||||
)
|
||||
|
||||
return features.flatten(0, 1)[:dim_input]
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
# Softmax-dimension related scaling, shared for all kernels
|
||||
x_scaled = super().pre_scale(x)
|
||||
assert self.features is not None
|
||||
|
||||
# Project onto the random feature map.
|
||||
x_scaled = x_scaled @ self.features
|
||||
return torch.exp(x_scaled + self.offset)
|
||||
|
||||
|
||||
class SMHyperbolic(SoftMaxPositiveEstimators):
|
||||
"""
|
||||
"Positive random features hyperbolic" estimator, SMHyp+,
|
||||
as proposed in the Performers_ paper, Lemma 1.
|
||||
|
||||
_Performers: "Rethinking attention with performers." K. Choromanski et al. (2020).
|
||||
https://arxiv.org/pdf/2009.14794v1.pdf
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dim_features: int,
|
||||
iter_before_redraw: Optional[int],
|
||||
normalize_inputs: bool = False,
|
||||
epsilon: float = 1e-6,
|
||||
softmax_temp: float = -1,
|
||||
):
|
||||
super().__init__(
|
||||
dim_features, iter_before_redraw, normalize_inputs, epsilon, softmax_temp
|
||||
)
|
||||
|
||||
assert (
|
||||
dim_features % 2 == 0
|
||||
), "The feature dimension needs to be even with this kernel"
|
||||
self.dim_feature_map = self.dim_features // 2
|
||||
|
||||
@torch.no_grad()
|
||||
def _get_feature_map(self, dim_input: int, dim_features: int, device: torch.device):
|
||||
"""
|
||||
Generate the projection matrix onto the random features
|
||||
|
||||
.. note: The heads dimension needs to be taken into account, hence the per-block random matrix
|
||||
and not uniformally random.
|
||||
"""
|
||||
|
||||
# Get per block random unitary matrices.
|
||||
# We need enough of them to project the whole input dimension, regardless of the
|
||||
# requested dimension of the features
|
||||
features = self._get_random_ortho_matrix(
|
||||
math.ceil(dim_input / dim_features),
|
||||
dim_features,
|
||||
norm_distribution=NormDistribution.Xi,
|
||||
device=device,
|
||||
)
|
||||
|
||||
return features.flatten(0, 1)[:dim_input]
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
# Softmax-dimension related scaling, shared for all kernels
|
||||
x_scaled = super().pre_scale(x)
|
||||
|
||||
# Project onto the random feature map, concatenate both + and - results
|
||||
# This follows Lemma 1 in the original Performers Paper to best approximate a
|
||||
# softmax kernel (cosh representation)
|
||||
x_scaled = x_scaled @ self.features
|
||||
return torch.cat(
|
||||
[torch.exp(x_scaled + self.offset), torch.exp(-x_scaled + self.offset)],
|
||||
dim=-1,
|
||||
)
|
||||
|
||||
|
||||
class SMReg(SoftMaxPositiveEstimators):
|
||||
"""
|
||||
"Regularized softmax kernel" estimator, SMREG+, as proposed in the Performers_ paper.
|
||||
|
||||
_Performers: "Rethinking attention with performers." K. Choromanski et al. (2020).
|
||||
https://arxiv.org/pdf/2009.14794v1.pdf
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dim_features: int,
|
||||
iter_before_redraw: Optional[int],
|
||||
normalize_inputs: bool = False,
|
||||
epsilon: float = 1e-6,
|
||||
softmax_temp: float = -1,
|
||||
):
|
||||
super().__init__(
|
||||
dim_features, iter_before_redraw, normalize_inputs, epsilon, softmax_temp
|
||||
)
|
||||
|
||||
assert (
|
||||
dim_features % 2 == 0
|
||||
), "The feature dimension needs to be even with this kernel"
|
||||
self.dim_feature_map = self.dim_features // 2
|
||||
|
||||
@torch.no_grad()
|
||||
def _get_feature_map(self, dim_input: int, dim_features: int, device: torch.device):
|
||||
"""
|
||||
Generate the projection matrix onto the random features
|
||||
|
||||
.. note: The heads dimension needs to be taken into account, hence the per-block random matrix
|
||||
and not uniformally random.
|
||||
"""
|
||||
|
||||
# Get per block random unitary matrices.
|
||||
# We need enough of them to project the whole input dimension, regardless of the
|
||||
# requested dimension of the features
|
||||
features = self._get_random_ortho_matrix(
|
||||
math.ceil(dim_input / dim_features),
|
||||
dim_features,
|
||||
norm_distribution=NormDistribution.Uniform,
|
||||
device=device,
|
||||
).flatten(0, 1)
|
||||
norms = math.sqrt(dim_input) * torch.ones(features.shape[0], device=device)
|
||||
return (torch.diag(norms) @ features)[:dim_input]
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
# Softmax-dimension related scaling, shared for all kernels
|
||||
x_scaled = super().pre_scale(x)
|
||||
|
||||
# Project onto the random feature map, concatenate both + and - results
|
||||
# This follows Lemma 1 in the original Performers Paper to best approximate a
|
||||
# softmax kernel (cosh representation + sample regularization)
|
||||
x_scaled = x_scaled @ self.features
|
||||
return torch.cat(
|
||||
[torch.exp(x_scaled + self.offset), torch.exp(-x_scaled + self.offset)],
|
||||
dim=-1,
|
||||
)
|
||||
Reference in New Issue
Block a user