Files
enginex-bi_series-vllm/pkgs/xformers/components/attention/fourier_mix.py
2025-08-05 19:02:46 +08:00

36 lines
1.2 KiB
Python

# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
import torch
from torch.cuda.amp import autocast
from xformers.components.attention import Attention, AttentionConfig, register_attention
@register_attention("fourier_mix", AttentionConfig)
class FourierMix(Attention):
    """
    FFT-based pseudo-attention mechanism, from
    "FNet: Mixing Tokens with Fourier Transforms"
    Lee-Thorp et al., 2021, https://arxiv.org/pdf/2105.03824.pdf

    The input is mixed by taking the real part of its 2D Fourier transform
    (``torch.fft.fft2`` with its default dimensions), followed by dropout.
    There are no learned parameters in this block.
    """

    def __init__(self, dropout: float, *_, **__):
        """
        Args:
            dropout: drop probability handed to ``torch.nn.Dropout``, applied
                to the mixed output.
            *_, **__: accepted and ignored.
        """
        super().__init__()
        self.attn_drop = torch.nn.Dropout(dropout, inplace=False)

        # Properties specific to this attention mechanism
        # NOTE(review): these flags are presumably read by the surrounding
        # attention wrappers -- confirm against the Attention base class.
        self.supports_attention_mask = False
        self.requires_input_projection = False

    def forward(self, q: torch.Tensor, *_, **__):
        """
        Mix ``q`` with a 2D FFT and apply dropout.

        Args:
            q: tensor to mix. Half-precision inputs (float16 / bfloat16) are
                transformed in float32 and the result is cast back to the
                dtype of ``q``; every other dtype is passed to the FFT as is.
            *_, **__: accepted and ignored.

        Returns:
            The real part of ``torch.fft.fft2(q)`` after dropout.
        """
        # Disabling autocast only prevents *new* down-casts; a tensor that is
        # already fp16 / bf16 would still reach the FFT in reduced precision,
        # which torch.fft.fft2 does not reliably support. Upcast explicitly.
        needs_upcast = q.dtype in (torch.float16, torch.bfloat16)

        # Guard against autocast / fp16, not supported by torch.fft.fft2
        with autocast(enabled=False):
            fft_input = q.float() if needs_upcast else q
            att = torch.fft.fft2(fft_input).real

        if needs_upcast:
            # Hand back the dtype the caller gave us so downstream
            # half-precision layers keep working.
            att = att.to(q.dtype)

        att = self.attn_drop(att)

        return att