325 lines
12 KiB
Python
325 lines
12 KiB
Python
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
|
||
|
||
import logging
|
||
from dataclasses import dataclass
|
||
from enum import Enum
|
||
from typing import Optional, Union
|
||
|
||
import torch
|
||
import torch.autograd.profiler as profiler
|
||
import torch.nn as nn
|
||
import torch.nn.functional as Fn
|
||
|
||
from xformers.components.attention import (
|
||
Attention,
|
||
AttentionConfig,
|
||
AttentionMask,
|
||
register_attention,
|
||
)
|
||
from xformers.components.attention.core import (
|
||
scaled_dot_product_attention,
|
||
scaled_query_key_softmax,
|
||
)
|
||
|
||
logger = logging.getLogger("xformers")
|
||
|
||
|
||
class LandmarkSelection(str, Enum):
    """
    Names of the strategies that can be used to pick attention landmarks.

    The enum derives from ``str``, so each member compares equal to its
    plain string value (e.g. ``LandmarkSelection.KMeans == "kmeans"``).
    """

    # Greedy selection of mutually near-orthogonal samples
    Orthogonal = "orthogonal"
    # Euclidean k-means centroids
    KMeans = "kmeans"
    # k-means on L2-normalised samples, using cosine similarity
    KMeans_Spherical = "kmeans_spherical"
    # Uniformly random samples
    Random = "random"
|
||
|
||
|
||
@dataclass
class OrthoformerAttentionConfig(AttentionConfig):
    """
    Settings specific to the Orthoformer attention.

    num_landmarks           Number of landmarks to use for softmax approximation.
    subsample_fraction      Fraction (between 0 and 1) of the query rows which is
                            sampled before the landmarks are selected
    landmark_selection      Landmark selection strategy
    """

    num_landmarks: Optional[int]
    subsample_fraction: Optional[float]
    landmark_selection: Optional[LandmarkSelection]
|
||
|
||
|
||
@register_attention("orthoformer", OrthoformerAttentionConfig)
class OrthoFormerAttention(Attention):
    def __init__(
        self,
        dropout: float,
        num_landmarks: int = 32,
        subsample_fraction: float = 1.0,
        landmark_selection: LandmarkSelection = LandmarkSelection.Orthogonal,
        *args,
        **kwargs,
    ):
        """
        Orthoformer_ attention mechanism.
        ::

            "Keeping Your Eye on the Ball: Trajectory Attention in Video Transformers"
            Patrick, M., Campbell, D., Asano, Y., Misra, I., Metze, F., Feichtenhofer,
            C., Vedaldi, A., Henriques, J. (2021)

            Reference codebase: https://github.com/facebookresearch/Motionformer

        .. _Orthoformer: https://arxiv.org/abs/2106.05392

        Args:
            dropout: probability handed to the ``nn.Dropout`` applied to the
                attention output.
            num_landmarks: number of landmarks used for the approximation. When it
                equals the key sequence length, exact attention is used instead.
            subsample_fraction: when < 1.0, fraction of the query rows which are
                randomly sampled (with replacement) before selecting landmarks.
            landmark_selection: strategy used to pick the landmarks.
        """
        super().__init__()

        self.num_landmarks = num_landmarks
        self.attn_drop = nn.Dropout(dropout)
        self.subsample_fraction = subsample_fraction
        self.landmark_selection = landmark_selection

        # Properties specific to this attention mechanism
        self.supports_attention_mask = True
        self.supports_key_padding_mask = False

    def forward(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        att_mask: Optional[Union[AttentionMask, torch.Tensor]] = None,
        *args,
        **kwargs,
    ):
        """
        Compute the attention output, exactly or through landmarks.

        Args:
            q, k, v: query / key / value tensors. The sequence length is read from
                ``k.shape[1]`` - assumed layout is (B, N, D), TODO confirm with callers.
            att_mask: optional attention mask. Only honoured on the exact path,
                it is dropped (with a warning) when landmarks are used.

        Raises:
            ValueError: if ``self.landmark_selection`` is not a known strategy.
        """
        N = k.shape[1]

        if self.num_landmarks == N:
            # Default attention
            x = scaled_dot_product_attention(q, k, v, att_mask)
        else:
            # Landmark selection is not differentiated through
            with torch.no_grad(), profiler.record_function("select landmarks"):
                if self.landmark_selection == LandmarkSelection.Orthogonal:
                    landmarks = self._compute_orthogonal_landmarks(q)
                elif self.landmark_selection == LandmarkSelection.Random:
                    # Split the budget between queries and keys. The keys take the
                    # remainder so that an odd budget is not silently truncated.
                    half_L = self.num_landmarks // 2
                    rest_L = self.num_landmarks - half_L
                    landmarks_q = q[
                        :, torch.randint(q.size(1), (half_L,), device=q.device), :
                    ]
                    landmarks_k = k[
                        :, torch.randint(k.size(1), (rest_L,), device=k.device), :
                    ]
                    landmarks = torch.cat((landmarks_q, landmarks_k), dim=-2)
                elif self.landmark_selection == LandmarkSelection.KMeans:
                    landmarks = self._cluster_landmarks(q)
                elif self.landmark_selection == LandmarkSelection.KMeans_Spherical:
                    landmarks = self._cluster_landmarks(q, spherical=True)
                else:
                    # Fail with an explicit message instead of an UnboundLocalError
                    # further down when `landmarks` would be used.
                    raise ValueError(
                        f"Unsupported landmark selection: {self.landmark_selection!r}"
                    )

            if att_mask is not None:
                # Adjacent literals rather than a backslash continuation, which
                # would embed the source indentation in the logged message.
                logger.warning(
                    "Orthoformer: attention mask passed alongside with using landmarks "
                    "to reduce dimensions. The two are typically not compatible"
                )
                # FIXME: Should we still accept a mask in that case ?
                att_mask = None

            # Two-step approximation: queries -> landmarks, then landmarks -> keys
            kernel_1 = scaled_query_key_softmax(q, landmarks, att_mask)
            kernel_2 = scaled_query_key_softmax(landmarks, k, att_mask)
            x = torch.matmul(kernel_1, torch.matmul(kernel_2, v))

        x = self.attn_drop(x)
        return x

    def _cluster_landmarks(
        self,
        q: torch.Tensor,
        spherical: bool = False,
        num_iters: int = 6,
    ) -> torch.Tensor:
        """
        Construct the set of landmarks by clustering the (optionally subsampled)
        queries with k-means, the centroids being used as landmarks.

        Arguments:
            q: (B, N, D)
            spherical: if True, the samples are L2-normalised and clustered with
                the cosine-similarity variant of k-means.
            num_iters: the number of kmeans updates

        Returns the centroids, with shape (B, M, D), M = min(num_landmarks, N).
        """

        num_landmarks = min(self.num_landmarks, q.shape[1])

        if self.subsample_fraction < 1.0:
            # Need at least M samples, one per requested centroid
            num_samples = max(
                int(self.subsample_fraction * q.size(-2)), num_landmarks
            )
            # Sampled with replacement, indices created on the device of `q`
            q_samples = q[
                :, torch.randint(q.size(-2), (num_samples,), device=q.device), :
            ]  # (B, num_samples, D)
        else:
            q_samples = q  # (B, N, D)

        if spherical:
            # may need to change default eps to eps=1e-8 for mixed precision compatibility
            q_samples_normalized = Fn.normalize(q_samples, p=2, dim=-1)
            landmarks = self._kmeans_spherical(
                q_samples_normalized, num_landmarks, num_iters
            )
        else:
            landmarks = self._kmeans(q_samples, num_landmarks, num_iters)

        return landmarks  # (B, M, D)

    def _kmeans(self, x: torch.Tensor, K: int, num_iters: int = 10):
        """
        Batched k-means (Lloyd iterations) using squared euclidean distances.

        Arguments:
            x: (B, N, D)
            K: number of clusters
            num_iters: the number of kmeans updates

        Returns the (B, K, D) centroids. A cluster which ends up with no point
        assigned is left at zero by the update step.
        """

        B, N, D = x.size()
        assert K <= N, f"{K} > {N}"

        # initialisation for the centroids: K distinct samples, shared across the batch
        c = x[:, torch.randperm(N, device=x.device)[:K], :].clone()

        with profiler.record_function("kmeans"):
            # `unsqueeze` instead of `view`, so that non-contiguous inputs are accepted.
            # `c_j` is a view of `c`: it follows the in-place updates below.
            x_i = x.unsqueeze(2)  # (B, N, 1, D)
            c_j = c.unsqueeze(1)  # (B, 1, K, D)
            counts = c.new_zeros(B, K)
            ones = x.new_ones((B, N))

            for _ in range(num_iters):
                # E step: assign points to the nearest cluster
                D_ij = ((x_i - c_j) ** 2).sum(-1)  # (B, N, K) squared distances
                cl = D_ij.argmin(
                    dim=-1, keepdim=True
                ).long()  # (B, N, 1) index of point to nearest cluster

                # M step: update the centroids
                c.zero_()
                c.scatter_add_(-2, cl.repeat(1, 1, D), x)  # sum of points per cluster
                counts.fill_(1e-6)  # avoid div0
                counts.scatter_add_(
                    -1, cl.squeeze(-1), ones
                )  # number of points per cluster
                c.divide_(counts.unsqueeze(-1))  # compute the average

        return c

    def _kmeans_spherical(self, x: torch.Tensor, K: int, num_iters=10):
        """
        Batched spherical k-means: points are assigned to the centroid with the
        highest dot product, and the centroids are re-normalised after each update.

        Arguments:
            x: (B, N, D), the caller in this class passes L2-normalised rows
            K: number of clusters
            num_iters: the number of kmeans updates

        Returns the (B, K, D) centroids.
        """
        B, N, D = x.size()
        assert K <= N, f"{K} > {N}"

        # initialisation for the centroids
        c = x[:, torch.randperm(N, device=x.device)[:K], :].clone()

        with profiler.record_function("kmeans_spherical"):
            counts = c.new_zeros(B, K)
            ones = x.new_ones((B, N))

            for _ in range(num_iters):
                # E step: assign points to the nearest cluster
                D_ij = torch.matmul(
                    x, c.transpose(-2, -1)
                )  # (B, N, K) cosine similarity
                cl = D_ij.argmax(
                    dim=-1, keepdim=True
                ).long()  # (B, N, 1) index of point to nearest cluster

                # M step: update the centroids
                c.zero_()
                c.scatter_add_(-2, cl.repeat(1, 1, D), x)  # sum of points per cluster
                counts.fill_(1e-6)  # avoid div0
                counts.scatter_add_(
                    -1, cl.squeeze(-1), ones
                )  # number of points per cluster
                c.divide_(counts.unsqueeze(-1))  # compute the average
                c = Fn.normalize(c, p=2, dim=-1)  # renormalise

        return c

    def _compute_orthogonal_landmarks(self, q: torch.Tensor) -> torch.Tensor:
        """
        Construct set of landmarks by recursively selecting new landmarks
        that are maximally orthogonal to the existing set.

        Arguments:
            q: (B, N, D)

        Returns near orthogonal landmarks with shape (B, M, D). The rows returned
        are the original (not normalised) samples, M being capped by the number
        of available samples.
        """

        if self.subsample_fraction < 1.0:
            # Need at least M samples of queries
            num_samples = max(
                int(self.subsample_fraction * q.size(-2)), self.num_landmarks
            )
            q_samples = q[
                :, torch.randint(q.size(-2), (num_samples,), device=q.device), :
            ]
        else:
            # (B, N, D)
            q_samples = q

        # may need to change default eps to eps=1e-8 for mixed precision compatibility
        q_samples_normalized = Fn.normalize(q_samples, p=2, dim=-1)
        B, N, D = q_samples_normalized.shape
        device = q_samples_normalized.device
        dtype = q_samples_normalized.dtype

        # No more than N distinct rows can be selected: asking for more would only
        # run useless iterations, the final selection being the same.
        num_landmarks = min(self.num_landmarks, N)

        # Row index of each batch element, used for the per-batch gathers below
        batch_idx = torch.arange(B, device=device)

        selected_mask = torch.zeros((B, N, 1), device=device)
        landmark_mask = torch.ones((B, 1, 1), dtype=selected_mask.dtype, device=device)

        # Get initial random landmark
        random_idx = torch.randint(N, (B, 1, 1), device=device)
        selected_mask.scatter_(-2, random_idx, landmark_mask)

        # Selected landmarks
        selected_landmarks = torch.empty((B, num_landmarks, D), device=device, dtype=dtype)
        selected_landmarks[:, 0, :] = q_samples_normalized[
            batch_idx, random_idx.view(-1), :
        ].view(B, D)

        # Store computed cosine similarities
        cos_sims = torch.empty((B, N, num_landmarks), device=device, dtype=dtype)

        for M in range(1, num_landmarks):
            with profiler.record_function("find new landmark"):
                # Calculate absolute cosine similarity between selected and unselected landmarks
                # (B, N, D) * (B, D) -> (B, N)
                cos_sims[:, :, M - 1] = torch.einsum(
                    "b n d, b d -> b n",
                    q_samples_normalized,
                    selected_landmarks[:, M - 1, :],
                ).abs()

                # (B, N, M) cosine similarities of current set of landmarks wrt all queries and keys
                cos_sim_set = cos_sims[:, :, :M]

                # Get orthogonal landmark: landmark with smallest absolute cosine similarity:
                # set cosine similarity for already selected landmarks to > 1
                cos_sim_set.view(-1, M)[selected_mask.flatten().bool(), :] = 10

                # (B,) - sample whose worst-case (max) similarity to the current set is smallest
                selected_landmark_idx = cos_sim_set.amax(-1).argmin(-1)

                # Add most orthogonal landmark to selected landmarks:
                selected_landmarks[:, M, :] = q_samples_normalized[
                    batch_idx, selected_landmark_idx, :
                ].view(B, D)

                # Removed selected indices from non-selected mask:
                selected_mask.scatter_(
                    -2, selected_landmark_idx.unsqueeze(-1).unsqueeze(-1), landmark_mask
                )

        # (B, M, D) - gathered from the un-normalised samples, in sample order
        landmarks = torch.masked_select(q_samples, selected_mask.bool()).reshape(
            B, -1, D
        )
        return landmarks  # (B, M, D)
|