# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

import math
from enum import Enum, auto
from typing import Optional

import torch
from torch.autograd.profiler import record_function

from .base import FeatureMap

"""
A set of feature maps which approximate the softmax kernel, as per the Performers_ paper.

_Performers: "Rethinking attention with performers." K. Choromanski et al. (2020).
https://arxiv.org/pdf/2009.14794v1.pdf
"""

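# Illustrative sketch (not part of the library API): all the estimators below build on the
# identity behind Lemma 1 of the Performers paper,
#   exp(q · k) = E_w[ exp(w · q - ‖q‖² / 2) · exp(w · k - ‖k‖² / 2) ],   w ~ N(0, I_d),
# which lets the softmax kernel be written as an expectation over strictly positive features.
# The helper name below is hypothetical, it is only a quick Monte-Carlo sanity check.
def _positive_feature_identity_sketch(dim: int = 16, num_samples: int = 200_000) -> None:
    # Scale the inputs down, mirroring the d**-0.25 softmax temperature used by the kernels
    q = torch.randn(dim) / dim ** 0.25
    k = torch.randn(dim) / dim ** 0.25
    w = torch.randn(num_samples, dim)  # one Gaussian draw per row

    # Positive random features phi(x) = exp(w · x - ‖x‖² / 2), averaged over the draws
    phi_q = torch.exp(w @ q - q.dot(q) / 2)
    phi_k = torch.exp(w @ k - k.dot(k) / 2)
    estimate = (phi_q * phi_k).mean()

    # The Monte-Carlo estimate is noisy but should hover around the exact value
    print(f"exact={torch.exp(q.dot(k)).item():.4f}  estimate={estimate.item():.4f}")
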
class NormDistribution(Enum):
    Xi = auto()
    Uniform = auto()

class SoftMaxPositiveEstimators(FeatureMap):
    def __init__(
        self,
        dim_features: int,
        iter_before_redraw: Optional[int],
        normalize_inputs: bool = False,
        epsilon: float = 1e-6,
        softmax_temp: float = -1,
    ):
        super().__init__(dim_features, iter_before_redraw, normalize_inputs, epsilon)
        self.softmax_temp = softmax_temp

        # Handle the scaling from all kernels by √m.
        # This normalizes for all the feature maps involved
        self.h_scale = math.log(math.sqrt(self.dim_features))

    def pre_scale(self, x: torch.Tensor) -> torch.Tensor:
        with record_function("feature_map::pre_scale"):
            # Re-draw counting logic
            if (
                (
                    self.iter_before_redraw is not None
                    and self._iter_counter > self.iter_before_redraw
                )
                or self.features is None
                or self.features.device != x.device
            ):
                # Some kernels only draw half the feature dimension,
                # the + and - features are concatenated downstream
                self._iter_counter = 1
                self.features = self._get_feature_map(
                    x.shape[-1], self.dim_feature_map, x.device
                )

            features = self.features
            assert features is not None

            if features.dtype != x.dtype:
                self.features = features.to(x.dtype)

            self._iter_counter += 1

            # Normalization / softmax
            if self.softmax_temp < 0:
                # A = exp(Q @ K.T / √d), so each input is scaled by d**-0.25
                self.softmax_temp = x.shape[-1] ** -0.25

            x_scaled = x * self.softmax_temp

            # Compute the scaling factors in log space, applied from within the exponential
            # - diminishes the risk of exponential overflow
            # - replaces a multiply across the batch with an addition
            norm_x_2 = torch.einsum("...d,...d->...", x_scaled, x_scaled).unsqueeze(-1)
            self.offset = -0.5 * norm_x_2 - self.h_scale + self.epsilon

            if self.normalize_inputs:
                # Normalize the exponential term by subtracting the max squared norm,
                # which can be useful for numerical stability.
                # This keeps exp(± features + offset) below 1
                self.offset -= norm_x_2.max(1, keepdim=True)[0]

            # Return the scaled inputs, the rest depends on the kernel being used
            return x_scaled

    @staticmethod
    @torch.no_grad()
    def _get_random_ortho_matrix(
        blocks: int,
        dim: int,
        device: torch.device,
        norm_distribution: NormDistribution = NormDistribution.Uniform,
    ) -> torch.Tensor:
        r"""
        Generate a random matrix whose rows are exactly orthonormal

        "How to generate random matrices from the classical compact groups", Mezzadri, 2007
        https://arxiv.org/pdf/math-ph/0609050v2.pdf

        .. note: a plain QR decomposition does not give uniformly distributed results: the QR
            decomposition is not unique, and the QR routines are biased towards numerical
            stability. See the above paper for more information.

        .. note: this does not follow the original implementation from the Performers authors.
            See docs/assets/kde plots to visualize the impact of using the R signs to correct Q.
        """

        H = torch.randn((blocks, dim, dim), device=device, requires_grad=False)

        # Randomly scale the norms of the features, Xi distributed
        if norm_distribution == NormDistribution.Xi:
            # NOTE: This averages to sqrt(d)
            norms = torch.sqrt(torch.einsum("...d,...d->...", H, H))

        Q, R = torch.linalg.qr(H)
        Q = torch.diag_embed(torch.sign(torch.diagonal(R, dim1=1, dim2=2))) @ Q

        # Normalize if need be. Uniform NormDistribution does nothing, Q is already orthonormal
        if norm_distribution == NormDistribution.Xi:
            return torch.diag_embed(norms) @ Q

        return Q

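# Illustrative sketch (not part of the library API): what `_get_random_ortho_matrix` produces.
# With the Uniform norm distribution every block is exactly orthonormal (Q @ Q.T == I); with
# the Xi norm distribution the rows stay orthogonal but their norms are chi distributed,
# averaging to roughly sqrt(dim). The helper name is hypothetical.
def _random_ortho_matrix_sketch(blocks: int = 2, dim: int = 8) -> None:
    q = SoftMaxPositiveEstimators._get_random_ortho_matrix(
        blocks,
        dim,
        device=torch.device("cpu"),
        norm_distribution=NormDistribution.Uniform,
    )

    # Every block should be orthonormal, up to numerical precision
    eye = torch.eye(dim).expand(blocks, dim, dim)
    assert torch.allclose(q @ q.transpose(-2, -1), eye, atol=1e-5)
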
class SMOrf(SoftMaxPositiveEstimators):
    """
    "Positive random orthogonal features" softmax estimator,
    SM_ort^m+, as proposed in the Performers_ paper, Lemma 1.

    _Performers: "Rethinking attention with performers." K. Choromanski et al. (2020).
    https://arxiv.org/pdf/2009.14794v1.pdf
    """

    @torch.no_grad()
    def _get_feature_map(self, dim_input: int, dim_features: int, device: torch.device):
        """
        Generate the projection matrix onto the random features

        .. note: The heads dimension needs to be taken into account, hence the per-block
            random matrix and not a uniformly random one.
        """

        # Get per-block random unitary matrices.
        # We need enough of them to project the whole input dimension, regardless of the
        # requested dimension of the features
        features = self._get_random_ortho_matrix(
            math.ceil(dim_input / dim_features),
            dim_features,
            norm_distribution=NormDistribution.Xi,
            device=device,
        )

        return features.flatten(0, 1)[:dim_input]

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Softmax-dimension related scaling, shared by all kernels
        x_scaled = super().pre_scale(x)
        assert self.features is not None

        # Project onto the random feature map.
        x_scaled = x_scaled @ self.features
        return torch.exp(x_scaled + self.offset)

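# Illustrative usage sketch (hypothetical values, not part of the library API), assuming the
# `FeatureMap` base class is an `nn.Module` so the instance is directly callable: with the
# d**-0.25 scaling and the log-space offset applied in `pre_scale`, phi(q) @ phi(k).T is a
# rough Monte-Carlo estimate of the un-normalized attention scores exp(q · k / √d). The error
# shrinks as `dim_features` grows and as the query / key norms get smaller.
def _smorf_usage_sketch(seq: int = 16, dim: int = 32) -> None:
    q = 0.3 * torch.randn(1, seq, dim)
    k = 0.3 * torch.randn(1, seq, dim)

    feature_map = SMOrf(dim_features=256, iter_before_redraw=None)
    phi_q, phi_k = feature_map(q), feature_map(k)

    approx = phi_q @ phi_k.transpose(-2, -1)                 # (1, seq, seq)
    exact = torch.exp(q @ k.transpose(-2, -1) / dim ** 0.5)  # un-normalized softmax scores
    print("mean relative error:", ((approx - exact).abs() / exact).mean().item())
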
class SMHyperbolic(SoftMaxPositiveEstimators):
    """
    "Positive random features hyperbolic" estimator, SMHyp+,
    as proposed in the Performers_ paper, Lemma 1.

    _Performers: "Rethinking attention with performers." K. Choromanski et al. (2020).
    https://arxiv.org/pdf/2009.14794v1.pdf
    """

    def __init__(
        self,
        dim_features: int,
        iter_before_redraw: Optional[int],
        normalize_inputs: bool = False,
        epsilon: float = 1e-6,
        softmax_temp: float = -1,
    ):
        super().__init__(
            dim_features, iter_before_redraw, normalize_inputs, epsilon, softmax_temp
        )

        assert (
            dim_features % 2 == 0
        ), "The feature dimension needs to be even with this kernel"
        self.dim_feature_map = self.dim_features // 2

    @torch.no_grad()
    def _get_feature_map(self, dim_input: int, dim_features: int, device: torch.device):
        """
        Generate the projection matrix onto the random features

        .. note: The heads dimension needs to be taken into account, hence the per-block
            random matrix and not a uniformly random one.
        """

        # Get per-block random unitary matrices.
        # We need enough of them to project the whole input dimension, regardless of the
        # requested dimension of the features
        features = self._get_random_ortho_matrix(
            math.ceil(dim_input / dim_features),
            dim_features,
            norm_distribution=NormDistribution.Xi,
            device=device,
        )

        return features.flatten(0, 1)[:dim_input]

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Softmax-dimension related scaling, shared by all kernels
        x_scaled = super().pre_scale(x)

        # Project onto the random feature map, concatenate both + and - results.
        # This follows Lemma 1 in the original Performers paper to best approximate a
        # softmax kernel (cosh representation)
        x_scaled = x_scaled @ self.features
        return torch.cat(
            [torch.exp(x_scaled + self.offset), torch.exp(-x_scaled + self.offset)],
            dim=-1,
        )

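# Illustrative sketch (hypothetical values, not part of the library API), assuming the
# `FeatureMap` base class is an `nn.Module`: the hyperbolic estimator only draws
# `dim_features // 2` random projections and concatenates the + and - halves, so the output
# feature dimension is still `dim_features`.
def _smhyperbolic_shape_sketch(seq: int = 8, dim: int = 32) -> None:
    x = torch.randn(1, seq, dim)
    feature_map = SMHyperbolic(dim_features=128, iter_before_redraw=None)
    phi_x = feature_map(x)
    assert phi_x.shape == (1, seq, 128)  # 64 "+" features and 64 "-" features, concatenated
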
class SMReg(SoftMaxPositiveEstimators):
    """
    "Regularized softmax kernel" estimator, SMREG+, as proposed in the Performers_ paper.

    _Performers: "Rethinking attention with performers." K. Choromanski et al. (2020).
    https://arxiv.org/pdf/2009.14794v1.pdf
    """

    def __init__(
        self,
        dim_features: int,
        iter_before_redraw: Optional[int],
        normalize_inputs: bool = False,
        epsilon: float = 1e-6,
        softmax_temp: float = -1,
    ):
        super().__init__(
            dim_features, iter_before_redraw, normalize_inputs, epsilon, softmax_temp
        )

        assert (
            dim_features % 2 == 0
        ), "The feature dimension needs to be even with this kernel"
        self.dim_feature_map = self.dim_features // 2

    @torch.no_grad()
    def _get_feature_map(self, dim_input: int, dim_features: int, device: torch.device):
        """
        Generate the projection matrix onto the random features

        .. note: The heads dimension needs to be taken into account, hence the per-block
            random matrix and not a uniformly random one.
        """

        # Get per-block random unitary matrices.
        # We need enough of them to project the whole input dimension, regardless of the
        # requested dimension of the features
        features = self._get_random_ortho_matrix(
            math.ceil(dim_input / dim_features),
            dim_features,
            norm_distribution=NormDistribution.Uniform,
            device=device,
        ).flatten(0, 1)

        norms = math.sqrt(dim_input) * torch.ones(features.shape[0], device=device)
        return (torch.diag(norms) @ features)[:dim_input]

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Softmax-dimension related scaling, shared by all kernels
        x_scaled = super().pre_scale(x)

        # Project onto the random feature map, concatenate both + and - results.
        # This follows Lemma 1 in the original Performers paper to best approximate a
        # softmax kernel (cosh representation + sample regularization)
        x_scaled = x_scaled @ self.features
        return torch.cat(
            [torch.exp(x_scaled + self.offset), torch.exp(-x_scaled + self.offset)],
            dim=-1,
        )
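
# Illustrative end-to-end sketch (hypothetical values, not part of the library API), assuming
# the `FeatureMap` base class is an `nn.Module`: these feature maps are meant to be plugged
# into linear attention, where softmax attention is approximated without materializing the
# (seq x seq) attention matrix:
#   attention(q, k, v) ≈ D⁻¹ · phi(q) @ (phi(k).T @ v),   D = diag(phi(q) @ (phi(k).T @ 1))
def _linear_attention_sketch(seq: int = 16, dim: int = 32) -> None:
    q, k, v = (torch.randn(1, seq, dim) for _ in range(3))

    feature_map = SMOrf(dim_features=128, iter_before_redraw=None)
    phi_q, phi_k = feature_map(q), feature_map(k)

    # Associativity trick: phi_q @ (phi_k.T @ v) never forms the (seq, seq) matrix,
    # so the cost scales linearly with the sequence length
    kv = phi_k.transpose(-2, -1) @ v                                    # (1, dim_features, dim)
    normalizer = phi_q @ phi_k.transpose(-2, -1).sum(-1, keepdim=True)  # (1, seq, 1)
    out = (phi_q @ kv) / normalizer
    assert out.shape == v.shape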