First commit
This commit is contained in:
55
pkgs/xformers/helpers/timm_sparse_attention.py
Normal file
55
pkgs/xformers/helpers/timm_sparse_attention.py
Normal file
@@ -0,0 +1,55 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the BSD license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
|
||||
import torch
|
||||
|
||||
from xformers.components.attention.core import scaled_dot_product_attention
|
||||
|
||||
|
||||
class TimmSparseAttention(torch.nn.Module):
|
||||
"""
|
||||
Almost drop-in replacement for timm attention
|
||||
but using the sparsity-aware scaled_dot_product_attention from xformers
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dim,
|
||||
num_heads=8,
|
||||
qkv_bias=False,
|
||||
attn_drop=0.0,
|
||||
proj_drop=0.0,
|
||||
attn_mask=None,
|
||||
):
|
||||
super().__init__()
|
||||
self.num_heads = num_heads
|
||||
|
||||
self.qkv = torch.nn.Linear(dim, dim * 3, bias=qkv_bias)
|
||||
self.attn_drop = torch.nn.Dropout(attn_drop)
|
||||
self.proj = torch.nn.Linear(dim, dim)
|
||||
self.proj_drop = torch.nn.Dropout(proj_drop)
|
||||
self.attn_mask = attn_mask
|
||||
|
||||
def forward(self, x):
|
||||
B, N, C = x.shape
|
||||
qkv = (
|
||||
self.qkv(x)
|
||||
.reshape(B, N, 3, self.num_heads, C // self.num_heads)
|
||||
.permute(2, 0, 3, 1, 4)
|
||||
)
|
||||
qkv = qkv.flatten(1, 2)
|
||||
|
||||
q, k, v = qkv.unbind()
|
||||
|
||||
x = scaled_dot_product_attention(
|
||||
q, k, v, self.attn_mask, dropout=self.attn_drop
|
||||
)
|
||||
x = x.reshape(B, self.num_heads, N, C // self.num_heads)
|
||||
|
||||
x = x.transpose(1, 2).reshape(B, N, C)
|
||||
x = self.proj(x)
|
||||
x = self.proj_drop(x)
|
||||
return x
|
||||
Reference in New Issue
Block a user