Files
enginex-bi_series-vllm/pkgs/xformers/helpers/timm_sparse_attention.py
2025-08-05 19:02:46 +08:00

56 lines
1.5 KiB
Python

# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
import torch
from xformers.components.attention.core import scaled_dot_product_attention
class TimmSparseAttention(torch.nn.Module):
    """
    Almost drop-in replacement for timm attention
    but using the sparsity-aware scaled_dot_product_attention from xformers.

    Args:
        dim: channel size of the input tokens; also the output size of the
            fused q/k/v projection (per q, k and v) and of the output
            projection. Must be divisible by ``num_heads``.
        num_heads: number of attention heads; must be positive.
        qkv_bias: whether the fused q/k/v ``Linear`` has a bias term.
        attn_drop: dropout probability of the ``Dropout`` module passed to
            ``scaled_dot_product_attention`` as ``dropout``.
        proj_drop: dropout probability applied after the output projection.
        attn_mask: forwarded unchanged to ``scaled_dot_product_attention``.

    Raises:
        ValueError: if ``num_heads`` is not positive or ``dim`` is not
            divisible by ``num_heads``.
    """

    def __init__(
        self,
        dim,
        num_heads=8,
        qkv_bias=False,
        attn_drop=0.0,
        proj_drop=0.0,
        attn_mask=None,
    ):
        super().__init__()
        # forward() splits the channel dim evenly across heads; an uneven or
        # impossible split would otherwise only fail later, inside forward,
        # as an opaque reshape / division error.
        if num_heads <= 0 or dim % num_heads != 0:
            raise ValueError(
                f"dim ({dim}) must be divisible by a positive num_heads "
                f"({num_heads})"
            )
        self.num_heads = num_heads

        # Single projection producing q, k and v side by side (3 * dim).
        self.qkv = torch.nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = torch.nn.Dropout(attn_drop)
        self.proj = torch.nn.Linear(dim, dim)
        self.proj_drop = torch.nn.Dropout(proj_drop)
        # NOTE(review): stored as a plain attribute, not a registered buffer,
        # so module.to()/.cuda() will not move a tensor mask — confirm
        # callers place it on the right device.
        self.attn_mask = attn_mask

    def forward(self, x):
        """
        Apply multi-head attention over the token dimension of ``x``.

        Args:
            x: 3-D tensor unpacked as (B, N, C); C is expected to equal the
                ``dim`` the module was built with.

        Returns:
            Tensor of shape (B, N, C).
        """
        B, N, C = x.shape
        head_dim = C // self.num_heads

        # (B, N, 3C) -> (3, B, heads, N, head_dim)
        qkv = (
            self.qkv(x)
            .reshape(B, N, 3, self.num_heads, head_dim)
            .permute(2, 0, 3, 1, 4)
        )
        # Fold heads into the batch dim so q, k, v are each 3-D:
        # (B * heads, N, head_dim).
        q, k, v = qkv.flatten(1, 2).unbind()

        x = scaled_dot_product_attention(
            q, k, v, self.attn_mask, dropout=self.attn_drop
        )

        # Undo the head folding: (B * heads, N, head_dim) -> (B, N, C).
        x = x.reshape(B, self.num_heads, N, head_dim)
        x = x.transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        return self.proj_drop(x)