# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.


import torch

from xformers.components.attention.core import scaled_dot_product_attention


class TimmSparseAttention(torch.nn.Module):
    """
    Almost drop-in replacement for timm attention,
    but using the sparsity-aware scaled_dot_product_attention from xformers.
    """

    def __init__(
        self,
        dim,
        num_heads=8,
        qkv_bias=False,
        attn_drop=0.0,
        proj_drop=0.0,
        attn_mask=None,
    ):
        super().__init__()
        self.num_heads = num_heads

        self.qkv = torch.nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = torch.nn.Dropout(attn_drop)
        self.proj = torch.nn.Linear(dim, dim)
        self.proj_drop = torch.nn.Dropout(proj_drop)
        self.attn_mask = attn_mask

    def forward(self, x):
        B, N, C = x.shape

        # Project to queries, keys and values, then split the heads:
        # (B, N, 3C) -> (B, N, 3, heads, head_dim) -> (3, B, heads, N, head_dim)
        qkv = (
            self.qkv(x)
            .reshape(B, N, 3, self.num_heads, C // self.num_heads)
            .permute(2, 0, 3, 1, 4)
        )
        # Fold batch and head dimensions together: (3, B * heads, N, head_dim)
        qkv = qkv.flatten(1, 2)

        q, k, v = qkv.unbind()

        # Sparsity-aware attention; attn_mask restricts which token pairs attend
        x = scaled_dot_product_attention(
            q, k, v, self.attn_mask, dropout=self.attn_drop
        )
        x = x.reshape(B, self.num_heads, N, C // self.num_heads)

        # Merge the heads back and project to the output dimension
        x = x.transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x
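
# ---------------------------------------------------------------------------
# Minimal usage sketch, assuming scaled_dot_product_attention accepts a dense
# boolean mask (True = attend) broadcast over the (batch * heads) dimension.
# The sizes below are arbitrary; in a real timm model you would instead swap
# each timm Attention block for a TimmSparseAttention with matching
# dim / num_heads and a sparsity pattern of your choosing.
if __name__ == "__main__":
    dim, num_heads, seq_len, batch = 96, 8, 197, 2

    # Random sparsity pattern; in practice this would encode e.g. local attention
    mask = torch.rand(seq_len, seq_len) > 0.5

    attention = TimmSparseAttention(dim, num_heads=num_heads, attn_mask=mask)
    x = torch.randn(batch, seq_len, dim)
    y = attention(x)
    print(y.shape)  # expected: torch.Size([2, 197, 96])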