First commit
This commit is contained in:
97
pkgs/xformers/components/feedforward/conv_mlp.py
Normal file
97
pkgs/xformers/components/feedforward/conv_mlp.py
Normal file
@@ -0,0 +1,97 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the BSD license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
|
||||
# CREDITS: Largely reusing the code from the reference VAN implementation
|
||||
# see https://github.com/Visual-Attention-Network
|
||||
|
||||
import math
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional
|
||||
|
||||
import torch.nn as nn
|
||||
|
||||
from xformers.components import Activation, build_activation
|
||||
from xformers.components.feedforward import Feedforward, FeedforwardConfig
|
||||
|
||||
from . import register_feedforward
|
||||
|
||||
|
||||
@dataclass
|
||||
class ConvMlpConfig(FeedforwardConfig):
|
||||
hidden_layer_multiplier: int
|
||||
dim_model: int
|
||||
dim_model_out: Optional[int]
|
||||
act_layer: Activation
|
||||
dropout: float
|
||||
|
||||
|
||||
@register_feedforward("Conv2DFeedforward", ConvMlpConfig)
|
||||
class Conv2DFeedforward(Feedforward):
|
||||
"""
|
||||
A Convolutional feed-forward network, as proposed in VAN_ (Vision Attention Network, Guo et al.)
|
||||
|
||||
.. _VAN: https://arxiv.org/pdf/2202.09741.pdf
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dim_model: int,
|
||||
hidden_layer_multiplier: int = 1,
|
||||
dim_model_out: Optional[int] = None,
|
||||
activation: Activation = Activation.GeLU,
|
||||
dropout=0.0,
|
||||
*args,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__()
|
||||
out_features = dim_model_out or dim_model
|
||||
hidden_features = hidden_layer_multiplier * dim_model
|
||||
|
||||
self.conv_mlp = nn.Sequential(
|
||||
nn.Conv2d(dim_model, hidden_features, 1),
|
||||
nn.Conv2d(
|
||||
hidden_features,
|
||||
hidden_features,
|
||||
3,
|
||||
1,
|
||||
1,
|
||||
bias=True,
|
||||
groups=hidden_features,
|
||||
),
|
||||
build_activation(activation),
|
||||
nn.Conv2d(hidden_features, out_features, 1),
|
||||
nn.Dropout(dropout),
|
||||
)
|
||||
|
||||
# This feedforward requires a context length which is squared, often due to 2D pooling
|
||||
self.requires_squared_context = True
|
||||
|
||||
def init_weights(self, **kwargs):
|
||||
# Follow the original init, but also make it possible to initialize from the outside
|
||||
def init_module(m: nn.Module):
|
||||
if isinstance(m, nn.Conv2d):
|
||||
fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
|
||||
fan_out //= m.groups
|
||||
m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
|
||||
if m.bias is not None:
|
||||
m.bias.data.zero_()
|
||||
|
||||
self.apply(init_module)
|
||||
|
||||
def forward(self, x):
|
||||
# The conv layers expect NCHW, we have NLC by default
|
||||
B, L, C = x.shape
|
||||
HW = int(math.sqrt(x.shape[-2]))
|
||||
assert HW**2 == L, "Conv2DFeedforward requires squared context lengths"
|
||||
|
||||
x = x.reshape((B, HW, HW, C)).swapdims(1, -1)
|
||||
|
||||
# The actual FW, including the 2d convolutions
|
||||
x = self.conv_mlp(x)
|
||||
|
||||
# back to NLC
|
||||
x = x.transpose(1, -1)
|
||||
return x.flatten(1, 2)
|
||||
Reference in New Issue
Block a user