# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

import logging
from dataclasses import dataclass

import torch
import torch.nn as nn

from xformers.components import Activation
from xformers.components.feedforward import (
    Feedforward,
    FeedforwardConfig,
    register_feedforward,
)

logger = logging.getLogger("xformers")

if torch.cuda.is_available():
    try:
        from xformers.triton import FusedDropoutBias

        @dataclass
        class FusedMlpConfig(FeedforwardConfig):
            hidden_layer_multiplier: int

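        # Registering the block below also makes it constructible from a
        # plain config through xformers' feedforward factory. A rough sketch,
        # assuming the build_feedforward helper and a dict-style config
        # (names and values here are illustrative, not part of this file):
        #
        #   from xformers.components.feedforward import build_feedforward
        #   mlp = build_feedforward({
        #       "name": "FusedMLP",
        #       "dim_model": 512,
        #       "dropout": 0.1,
        #       "activation": "gelu",
        #       "hidden_layer_multiplier": 4,
        #   })
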
        @register_feedforward("FusedMLP", FusedMlpConfig)
        class FusedMLP(Feedforward):
            """
            An MLP using fused Triton kernels for bias, activation and
            dropout: the linear layers are bias-free, and each is followed by
            a FusedDropoutBias layer that applies bias + activation + dropout
            in a single kernel.
            """

            def __init__(
                self,
                dim_model: int,
                dropout: float,
                activation: Activation,
                hidden_layer_multiplier: int,
                bias: bool = True,
                *args,
                **kwargs,
            ):
                super().__init__()

                dim_mlp = hidden_layer_multiplier * dim_model

                self.mlp = nn.Sequential(
                    nn.Linear(
                        in_features=dim_model, out_features=dim_mlp, bias=False
                    ),  # bias is handled in the next layer
                    # pyre-ignore[16]: TODO(T101400990): Pyre did not recognize
                    # the `FusedDropoutBias` import.
                    FusedDropoutBias(
                        p=dropout,
                        bias_shape=dim_mlp if bias else None,
                        activation=activation,
                    ),
                    nn.Linear(
                        in_features=dim_mlp, out_features=dim_model, bias=False
                    ),  # bias is handled in the next layer
                    # pyre-ignore[16]: TODO(T101400990): Pyre did not recognize
                    # the `FusedDropoutBias` import.
                    FusedDropoutBias(
                        p=dropout,
                        bias_shape=dim_model if bias else None,
                        activation=None,
                    ),
                )
                self.requires_cuda = True

            def forward(self, inputs: torch.Tensor) -> torch.Tensor:
                return self.mlp(inputs)

    except ImportError:
        logger.warning("Triton is not available, FusedMLP will not be enabled.")
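
# Minimal usage sketch (not part of the original file). It assumes a CUDA
# device and a working Triton install, since FusedMLP is only defined when
# the import above succeeds; Activation.GeLU and the sizes are illustrative.
if __name__ == "__main__":
    if torch.cuda.is_available() and "FusedMLP" in globals():
        mlp = FusedMLP(
            dim_model=512,
            dropout=0.1,
            activation=Activation.GeLU,
            hidden_layer_multiplier=4,  # hidden size: 4 * 512 = 2048
        ).cuda()
        x = torch.randn(2, 64, 512, device="cuda")
        y = mlp(x)  # bias, activation and dropout run as fused Triton kernels
        assert y.shape == x.shape  # the block preserves dim_model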