First commit
This commit is contained in:
47
pkgs/xformers/components/feedforward/mlp.py
Normal file
47
pkgs/xformers/components/feedforward/mlp.py
Normal file
@@ -0,0 +1,47 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the BSD license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from xformers.components import Activation, build_activation
|
||||
from xformers.components.feedforward import Feedforward, FeedforwardConfig
|
||||
|
||||
from . import register_feedforward
|
||||
|
||||
|
||||
@dataclass
|
||||
class MlpConfig(FeedforwardConfig):
|
||||
hidden_layer_multiplier: int
|
||||
bias: bool
|
||||
|
||||
|
||||
@register_feedforward("MLP", MlpConfig)
|
||||
class MLP(Feedforward):
|
||||
def __init__(
|
||||
self,
|
||||
dim_model: int,
|
||||
dropout: float,
|
||||
activation: Activation,
|
||||
hidden_layer_multiplier: int,
|
||||
bias: bool = True,
|
||||
*args,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__()
|
||||
dim_mlp = hidden_layer_multiplier * dim_model
|
||||
self.mlp = nn.Sequential(
|
||||
nn.Linear(in_features=dim_model, out_features=dim_mlp, bias=bias),
|
||||
build_activation(activation),
|
||||
nn.Dropout(dropout),
|
||||
nn.Linear(in_features=dim_mlp, out_features=dim_model, bias=bias),
|
||||
nn.Dropout(dropout),
|
||||
)
|
||||
|
||||
def forward(self, inputs: torch.Tensor) -> torch.Tensor:
|
||||
return self.mlp(inputs)
|
||||
Reference in New Issue
Block a user