First commit
This commit is contained in:
79
pkgs/xformers/components/feedforward/fused_mlp.py
Normal file
79
pkgs/xformers/components/feedforward/fused_mlp.py
Normal file
@@ -0,0 +1,79 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the BSD license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from xformers.components import Activation
|
||||
from xformers.components.feedforward import (
|
||||
Feedforward,
|
||||
FeedforwardConfig,
|
||||
register_feedforward,
|
||||
)
|
||||
|
||||
logger = logging.getLogger("xformers")
|
||||
|
||||
|
||||
if torch.cuda.is_available():
|
||||
try:
|
||||
from xformers.triton import FusedDropoutBias
|
||||
|
||||
@dataclass
|
||||
class FusedMlpConfig(FeedforwardConfig):
|
||||
hidden_layer_multiplier: int
|
||||
|
||||
@register_feedforward("FusedMLP", FusedMlpConfig)
|
||||
class FusedMLP(Feedforward):
|
||||
"""
|
||||
A MLP using fused linear layers.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dim_model: int,
|
||||
dropout: float,
|
||||
activation: Activation,
|
||||
hidden_layer_multiplier: int,
|
||||
bias: bool = True,
|
||||
*args,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
dim_mlp = hidden_layer_multiplier * dim_model
|
||||
|
||||
self.mlp = nn.Sequential(
|
||||
nn.Linear(
|
||||
in_features=dim_model, out_features=dim_mlp, bias=False
|
||||
), # bias is handled in the next layer
|
||||
# pyre-ignore[16]: TODO(T101400990): Pyre did not recognize
|
||||
# the `FusedLinear` import.
|
||||
FusedDropoutBias(
|
||||
p=dropout,
|
||||
bias_shape=dim_mlp if bias else None,
|
||||
activation=activation,
|
||||
),
|
||||
nn.Linear(
|
||||
in_features=dim_mlp, out_features=dim_model, bias=False
|
||||
), # bias is handled in the next layer
|
||||
# pyre-ignore[16]: TODO(T101400990): Pyre did not recognize
|
||||
# the `FusedLinear` import.
|
||||
FusedDropoutBias(
|
||||
p=dropout,
|
||||
bias_shape=dim_model if bias else None,
|
||||
activation=None,
|
||||
),
|
||||
)
|
||||
self.requires_cuda = True
|
||||
|
||||
def forward(self, inputs: torch.Tensor) -> torch.Tensor:
|
||||
return self.mlp(inputs)
|
||||
|
||||
except ImportError:
|
||||
logger.warning("Triton is not available, FusedMLP will not be enabled.")
|
||||
Reference in New Issue
Block a user