# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

from dataclasses import dataclass

import torch
import torch.nn as nn

from xformers.components import Activation, build_activation
from xformers.components.feedforward import Feedforward, FeedforwardConfig

from . import register_feedforward

@dataclass
class MlpConfig(FeedforwardConfig):
    # Width multiplier for the hidden layer: hidden size = hidden_layer_multiplier * dim_model
    hidden_layer_multiplier: int
    # Whether the two Linear projections carry an additive bias term
    bias: bool


@register_feedforward("MLP", MlpConfig)
class MLP(Feedforward):
    """Standard two-layer feed-forward block: expand, activate, project back.

    The hidden width is ``hidden_layer_multiplier * dim_model``. Dropout is
    applied twice, once after the activation and once after the output
    projection. Input and output both have trailing dimension ``dim_model``.
    """

    def __init__(
        self,
        dim_model: int,
        dropout: float,
        activation: Activation,
        hidden_layer_multiplier: int,
        bias: bool = True,
        *args,
        **kwargs,
    ):
        super().__init__()

        hidden_dim = hidden_layer_multiplier * dim_model

        # Assemble the pipeline as a list first, then splat into Sequential.
        layers = [
            nn.Linear(in_features=dim_model, out_features=hidden_dim, bias=bias),
            build_activation(activation),
            nn.Dropout(dropout),
            nn.Linear(in_features=hidden_dim, out_features=dim_model, bias=bias),
            nn.Dropout(dropout),
        ]
        self.mlp = nn.Sequential(*layers)

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        # Shape-preserving: (..., dim_model) -> (..., dim_model)
        return self.mlp(inputs)