36 lines
1.2 KiB
Python
36 lines
1.2 KiB
Python
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
|
#
|
|
# This source code is licensed under the BSD license found in the
|
|
# LICENSE file in the root directory of this source tree.
|
|
|
|
import torch
|
|
from torch.cuda.amp import autocast
|
|
|
|
from xformers.components.attention import Attention, AttentionConfig, register_attention
|
|
|
|
|
|
@register_attention("fourier_mix", AttentionConfig)
class FourierMix(Attention):
    """
    FFT-based pseudo-attention mechanism, from
    "FNet: Mixing Tokens with Fourier Transforms"
    Lee-Thorp et al., 2021, https://arxiv.org/pdf/2105.03824.pdf

    The only submodule held by this layer is a dropout; the mixing itself is
    the real part of a 2D FFT of the input.
    """

    def __init__(self, dropout: float, *_, **__):
        """
        Args:
            dropout: probability passed to the ``torch.nn.Dropout`` applied to
                the mixed output.
            *_, **__: accepted and ignored, so the constructor tolerates extra
                configuration fields.
        """
        super().__init__()
        self.attn_drop = torch.nn.Dropout(dropout, inplace=False)

        # Properties specific to this attention mechanism
        # NOTE(review): these flags are presumably read by the surrounding
        # xformers machinery - confirm against the Attention base class.
        self.supports_attention_mask = False
        self.requires_input_projection = False

    def forward(self, q: torch.Tensor, *_, **__) -> torch.Tensor:
        """
        Mix the input with a 2D FFT and keep the real part.

        Args:
            q: input tensor; the FFT runs over its last two dimensions
                (the default of ``torch.fft.fft2``).
            *_, **__: accepted and ignored (e.g. keys / values / masks).

        Returns:
            The real part of ``fft2(q)`` after dropout. For float16 / bfloat16
            inputs the FFT is computed in float32 and the result is cast back
            to the input dtype.
        """
        in_dtype = q.dtype
        low_precision = in_dtype in (torch.float16, torch.bfloat16)

        # Guard against autocast / fp16, not supported by torch.fft.fft2
        with autocast(enabled=False):
            # Disabling autocast only stops *new* casts: a tensor that is
            # already half precision must be upcast explicitly before the FFT.
            fft_input = q.float() if low_precision else q
            att = torch.fft.fft2(fft_input).real

        if low_precision:
            # Hand back the dtype the caller gave us.
            att = att.to(in_dtype)

        att = self.attn_drop(att)

        return att
|