First commit
This commit is contained in:
87
pkgs/xformers/components/feedforward/__init__.py
Normal file
87
pkgs/xformers/components/feedforward/__init__.py
Normal file
@@ -0,0 +1,87 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the BSD license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Any, Callable, Dict, Set, Union
|
||||
|
||||
from xformers.utils import (
|
||||
generate_matching_config,
|
||||
get_registry_decorator,
|
||||
import_all_modules,
|
||||
)
|
||||
|
||||
from .base import Feedforward, FeedforwardConfig # noqa
|
||||
|
||||
# CREDITS: Classy Vision registry mechanism
|
||||
|
||||
FEEDFORWARD_REGISTRY: Dict[str, Any] = {}
|
||||
FEEDFORWARD_CLASS_NAMES: Set[str] = set()
|
||||
|
||||
|
||||
def build_feedforward(config: Union[Dict[str, Any], FeedforwardConfig]):
    """Build a feedforward block from a config.

    The config must carry a ``name`` entry, which selects the feedforward
    class to instantiate. For instance, a config ``{"name": "my_feedforward",
    "foo": "bar"}`` will find the class that was registered as
    "my_feedforward" (see :func:`register_feedforward`) and call
    ``.from_config`` on it.

    Args:
        config: either a plain dict, which is first converted with
            ``generate_matching_config`` using the config type stored in the
            registry entry, or a ``FeedforwardConfig`` instance, used as is.

    Returns:
        Whatever ``from_config`` of the registered constructor returns.

    Raises:
        KeyError: if the requested name is not present in
            ``FEEDFORWARD_REGISTRY``, or if a dict config has no ``"name"`` key.
    """

    def _registry_entry(name: str) -> Any:
        # A bare ``KeyError: 'foo'`` is hard to act on, so list what is registered.
        try:
            return FEEDFORWARD_REGISTRY[name]
        except KeyError:
            raise KeyError(
                f"Unknown feedforward '{name}'. "
                f"Registered feedforwards: {sorted(FEEDFORWARD_REGISTRY)}"
            ) from None

    if isinstance(config, FeedforwardConfig):
        config_instance = config
    else:
        # Plain dicts are turned into the config type registered for that name
        config_instance = generate_matching_config(
            config, _registry_entry(config["name"]).config
        )

    constructor = _registry_entry(config_instance.name).constructor
    return constructor.from_config(config_instance)
|
||||
|
||||
|
||||
"""Registers a Feedforward subclass.
|
||||
|
||||
This decorator allows xFormers to instantiate a subclass of Feedforward
|
||||
from a configuration file, even if the class itself is not part of the
|
||||
xFormers framework. To use it, apply this decorator to a Feedforward
|
||||
subclass, like this:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
@dataclass
|
||||
class MyConfig:
|
||||
...
|
||||
|
||||
@register_feedforward('my_ff', MyConfig)
|
||||
class MyFeedforward(Feedforward):
|
||||
...
|
||||
|
||||
To instantiate a feedforward from a configuration file, see :func:`build_feedforward`."""
|
||||
register_feedforward: Callable[
|
||||
[str, Any], Callable[[Any], Any]
|
||||
] = get_registry_decorator(
|
||||
FEEDFORWARD_REGISTRY, FEEDFORWARD_CLASS_NAMES, Feedforward, FeedforwardConfig
|
||||
)
|
||||
|
||||
# FusedMLP is optional: it is only exported when its import succeeds
try:
    from .fused_mlp import FusedMLP  # noqa
except ImportError:
    _fused_mlp_available = False
else:
    _fused_mlp_available = True

from .mlp import MLP  # noqa

# Names exposed by `from xformers.components.feedforward import *`
__all__ = ["MLP", "Feedforward", "build_feedforward", "register_feedforward"]
if _fused_mlp_available:
    __all__.append("FusedMLP")

# automatically import any Python files in the directory
import_all_modules(str(Path(__file__).parent), "xformers.components.feedforward")
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
53
pkgs/xformers/components/feedforward/base.py
Normal file
53
pkgs/xformers/components/feedforward/base.py
Normal file
@@ -0,0 +1,53 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the BSD license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
|
||||
from abc import ABCMeta, abstractmethod
|
||||
from dataclasses import asdict, dataclass
|
||||
from typing import Optional, Type, TypeVar
|
||||
|
||||
import torch.nn as nn
|
||||
|
||||
from xformers.components import Activation
|
||||
|
||||
Self = TypeVar("Self", bound="Feedforward")
|
||||
|
||||
|
||||
@dataclass
class FeedforwardConfig:
    """Fields shared by all the feedforward configurations."""

    # Name under which the feedforward class was registered
    name: str
    # Feature dimension of the model
    dim_model: int
    # Dropout probability
    dropout: float
    # Activation to use inside the feedforward
    activation: Activation
|
||||
|
||||
|
||||
# Define the common interface, every feedforward block needs to derive from it
|
||||
class Feedforward(nn.Module, metaclass=ABCMeta):
    """Base interface for the feedforward blocks.

    ``__init__`` is abstract, so this class cannot be instantiated directly;
    subclasses provide their own constructor and call ``super().__init__()``
    to get the default capability flags.
    """

    @abstractmethod
    def __init__(
        self,
        dim_model: Optional[int] = None,
        dropout: Optional[float] = None,
        activation: Optional[Activation] = None,
        *args,
        **kwargs,
    ):
        """Initialize the module and set the default capability flags.

        None of the arguments is used here, they only document the
        signature that the subclasses are expected to follow.
        """
        super().__init__()

        # Whether this feedforward requires a CUDA accelerator
        self.requires_cuda = False

        # Whether this feedforward requires a context length which is squared,
        # often due to 2D pooling
        self.requires_squared_context = False

    @classmethod
    def from_config(cls: Type[Self], config: FeedforwardConfig) -> Self:
        """Instantiate the class from a config dataclass.

        The config is converted to a dict with ``dataclasses.asdict`` and
        passed to the constructor as keyword arguments. Entries set to
        ``None`` are left out, so that the constructor defaults apply.
        """
        init_kwargs = {
            key: value for key, value in asdict(config).items() if value is not None
        }
        return cls(**init_kwargs)
|
||||
97
pkgs/xformers/components/feedforward/conv_mlp.py
Normal file
97
pkgs/xformers/components/feedforward/conv_mlp.py
Normal file
@@ -0,0 +1,97 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the BSD license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
|
||||
# CREDITS: Largely reusing the code from the reference VAN implementation
|
||||
# see https://github.com/Visual-Attention-Network
|
||||
|
||||
import math
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional
|
||||
|
||||
import torch.nn as nn
|
||||
|
||||
from xformers.components import Activation, build_activation
|
||||
from xformers.components.feedforward import Feedforward, FeedforwardConfig
|
||||
|
||||
from . import register_feedforward
|
||||
|
||||
|
||||
@dataclass
class ConvMlpConfig(FeedforwardConfig):
    """Configuration of the Conv2DFeedforward block."""

    # The hidden layers have `hidden_layer_multiplier * dim_model` channels
    hidden_layer_multiplier: int
    # Number of input channels (re-declared from FeedforwardConfig)
    dim_model: int
    # Number of output channels, `dim_model` is used when this is None
    dim_model_out: Optional[int]
    # NOTE(review): Conv2DFeedforward.__init__ takes `activation`, not
    # `act_layer`, so this field looks unused by the constructor - confirm
    act_layer: Activation
    # Dropout probability (re-declared from FeedforwardConfig)
    dropout: float
|
||||
|
||||
|
||||
@register_feedforward("Conv2DFeedforward", ConvMlpConfig)
|
||||
class Conv2DFeedforward(Feedforward):
|
||||
"""
|
||||
A Convolutional feed-forward network, as proposed in VAN_ (Vision Attention Network, Guo et al.)
|
||||
|
||||
.. _VAN: https://arxiv.org/pdf/2202.09741.pdf
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dim_model: int,
|
||||
hidden_layer_multiplier: int = 1,
|
||||
dim_model_out: Optional[int] = None,
|
||||
activation: Activation = Activation.GeLU,
|
||||
dropout=0.0,
|
||||
*args,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__()
|
||||
out_features = dim_model_out or dim_model
|
||||
hidden_features = hidden_layer_multiplier * dim_model
|
||||
|
||||
self.conv_mlp = nn.Sequential(
|
||||
nn.Conv2d(dim_model, hidden_features, 1),
|
||||
nn.Conv2d(
|
||||
hidden_features,
|
||||
hidden_features,
|
||||
3,
|
||||
1,
|
||||
1,
|
||||
bias=True,
|
||||
groups=hidden_features,
|
||||
),
|
||||
build_activation(activation),
|
||||
nn.Conv2d(hidden_features, out_features, 1),
|
||||
nn.Dropout(dropout),
|
||||
)
|
||||
|
||||
# This feedforward requires a context length which is squared, often due to 2D pooling
|
||||
self.requires_squared_context = True
|
||||
|
||||
def init_weights(self, **kwargs):
|
||||
# Follow the original init, but also make it possible to initialize from the outside
|
||||
def init_module(m: nn.Module):
|
||||
if isinstance(m, nn.Conv2d):
|
||||
fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
|
||||
fan_out //= m.groups
|
||||
m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
|
||||
if m.bias is not None:
|
||||
m.bias.data.zero_()
|
||||
|
||||
self.apply(init_module)
|
||||
|
||||
def forward(self, x):
|
||||
# The conv layers expect NCHW, we have NLC by default
|
||||
B, L, C = x.shape
|
||||
HW = int(math.sqrt(x.shape[-2]))
|
||||
assert HW**2 == L, "Conv2DFeedforward requires squared context lengths"
|
||||
|
||||
x = x.reshape((B, HW, HW, C)).swapdims(1, -1)
|
||||
|
||||
# The actual FW, including the 2d convolutions
|
||||
x = self.conv_mlp(x)
|
||||
|
||||
# back to NLC
|
||||
x = x.transpose(1, -1)
|
||||
return x.flatten(1, 2)
|
||||
79
pkgs/xformers/components/feedforward/fused_mlp.py
Normal file
79
pkgs/xformers/components/feedforward/fused_mlp.py
Normal file
@@ -0,0 +1,79 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the BSD license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from xformers.components import Activation
|
||||
from xformers.components.feedforward import (
|
||||
Feedforward,
|
||||
FeedforwardConfig,
|
||||
register_feedforward,
|
||||
)
|
||||
|
||||
logger = logging.getLogger("xformers")
|
||||
|
||||
|
||||
if torch.cuda.is_available():
|
||||
try:
|
||||
from xformers.triton import FusedDropoutBias
|
||||
|
||||
@dataclass
class FusedMlpConfig(FeedforwardConfig):
    """Configuration of the FusedMLP block."""

    # The hidden layer has `hidden_layer_multiplier * dim_model` features
    hidden_layer_multiplier: int
|
||||
|
||||
@register_feedforward("FusedMLP", FusedMlpConfig)
|
||||
class FusedMLP(Feedforward):
|
||||
"""
|
||||
A MLP using fused linear layers.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dim_model: int,
|
||||
dropout: float,
|
||||
activation: Activation,
|
||||
hidden_layer_multiplier: int,
|
||||
bias: bool = True,
|
||||
*args,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
dim_mlp = hidden_layer_multiplier * dim_model
|
||||
|
||||
self.mlp = nn.Sequential(
|
||||
nn.Linear(
|
||||
in_features=dim_model, out_features=dim_mlp, bias=False
|
||||
), # bias is handled in the next layer
|
||||
# pyre-ignore[16]: TODO(T101400990): Pyre did not recognize
|
||||
# the `FusedLinear` import.
|
||||
FusedDropoutBias(
|
||||
p=dropout,
|
||||
bias_shape=dim_mlp if bias else None,
|
||||
activation=activation,
|
||||
),
|
||||
nn.Linear(
|
||||
in_features=dim_mlp, out_features=dim_model, bias=False
|
||||
), # bias is handled in the next layer
|
||||
# pyre-ignore[16]: TODO(T101400990): Pyre did not recognize
|
||||
# the `FusedLinear` import.
|
||||
FusedDropoutBias(
|
||||
p=dropout,
|
||||
bias_shape=dim_model if bias else None,
|
||||
activation=None,
|
||||
),
|
||||
)
|
||||
self.requires_cuda = True
|
||||
|
||||
def forward(self, inputs: torch.Tensor) -> torch.Tensor:
|
||||
return self.mlp(inputs)
|
||||
|
||||
except ImportError:
|
||||
logger.warning("Triton is not available, FusedMLP will not be enabled.")
|
||||
153
pkgs/xformers/components/feedforward/mixture_of_experts.py
Normal file
153
pkgs/xformers/components/feedforward/mixture_of_experts.py
Normal file
@@ -0,0 +1,153 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the BSD license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
from typing import Any, Callable, Optional, Union
|
||||
|
||||
import torch
|
||||
|
||||
from xformers.components import Activation
|
||||
from xformers.components.feedforward import (
|
||||
Feedforward,
|
||||
FeedforwardConfig,
|
||||
register_feedforward,
|
||||
)
|
||||
|
||||
logger = logging.getLogger("xformers")
|
||||
|
||||
|
||||
# MixtureOfExperts is only exposed when all of these imports succeed
try:
    import torch.distributed as dist
    from fairscale.nn import MOELayer, Top2Gate  # type: ignore

    from xformers.components.feedforward import MLP
except ImportError:
    _is_fairscale_available = False
    logger.warning(
        "Either FairScale or torch distributed is not available, MixtureOfExperts will not be exposed."
        " Please install them if you would like to use MoE"
    )
else:
    _is_fairscale_available = True
|
||||
|
||||
|
||||
if _is_fairscale_available:
|
||||
|
||||
# Credits: initially implemented in FairScale for sanity checking
|
||||
class RoundRobinGate(torch.nn.Module):
    """Deterministic gate which sends token ``i`` to expert ``i % num_experts``.

    There is no learnable parameter, ``model_dim`` is only stored.
    """

    def __init__(self, model_dim, num_experts):
        super().__init__()
        self.model_dim = model_dim
        self.num_experts = num_experts

    def forward(self, input):
        """Compute the round-robin assignment for the given input.

        Only ``input.shape[0]`` (the number of tokens ``s``), the dtype and
        the device of the input are used. ``s`` needs to be a multiple of
        ``num_experts``.

        Returns:
            A tuple ``(0.0, output, output.bool())``, where ``output`` has the
            shape ``(s, num_experts, 2 * s // num_experts)`` and holds a one
            at ``[i, i % num_experts, i // num_experts]`` for every token ``i``.
            NOTE(review): presumably the (auxiliary loss, combine weights,
            dispatch mask) triple expected by FairScale's MOELayer - confirm.
        """
        s = input.shape[0]
        assert s % self.num_experts == 0, f"{s} % {self.num_experts} != 0"
        capacity = 2 * s // self.num_experts
        output = torch.zeros(
            s, self.num_experts, capacity, dtype=input.dtype, device=input.device
        )

        # One vectorized indexed write instead of a Python loop doing one
        # tensor write per token
        tokens = torch.arange(s, device=input.device)
        slots = torch.div(tokens, self.num_experts, rounding_mode="floor")
        output[tokens, tokens % self.num_experts, slots] = 1.0

        return 0.0, output, output.bool()
|
||||
|
||||
class GateConfig(str, Enum):
    """The gating techniques which can be selected by name."""

    # Deterministic round-robin assignment
    RoundRobin = "round_robin"
    # Top-2 gating
    Top2 = "top_2"
    # Other gating techniques could be exposed here
|
||||
|
||||
@dataclass
class MoEConfig(FeedforwardConfig):
    """Configuration of the MixtureOfExperts block."""

    # Total number of experts, passed to the gate
    number_of_experts: int
    # Gating technique to use
    gate: GateConfig
    # Number of experts instantiated locally, resolved by MixtureOfExperts when None
    number_of_local_experts: Optional[int] = None
    # Factory called once per local expert, a MLP is used when None
    expert_constructor: Optional[Any] = None
    # Multiplier of the default MLP experts, 4 is used when None
    hidden_layer_multiplier: Optional[int] = None
    # Forwarded as the `group` argument of the MOELayer
    group: Optional[Any] = None
|
||||
|
||||
@register_feedforward("MixtureOfExperts", MoEConfig)
|
||||
class MixtureOfExperts(Feedforward):
|
||||
"""
|
||||
A MLP variant which uses the "Mixture of Experts" paradigm, as described in Gshard_.
|
||||
xFormers uses the FairScale_ implementation under the hood.
|
||||
|
||||
.. warning: Please note that most of the benefits of MoE are present in a distributed training environmentt
|
||||
|
||||
.. _Gshard: https://arxiv.org/pdf/2006.16668.pdf
|
||||
.. _FairScale: https://github.com/facebookresearch/fairscale/
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dim_model: int,
|
||||
dropout: float,
|
||||
activation: Activation,
|
||||
number_of_experts: int,
|
||||
gate: Union[GateConfig, torch.nn.Module],
|
||||
number_of_local_experts: Optional[int] = None,
|
||||
expert_constructor: Optional[Callable[[], torch.nn.Module]] = None,
|
||||
hidden_layer_multiplier: Optional[int] = None,
|
||||
group: Optional[Any] = None,
|
||||
*_,
|
||||
**__,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
# Handle a possibly uninitialized process group
|
||||
assert (
|
||||
dist.is_initialized()
|
||||
), "Mixture of Experts require torch distributed to be initialized"
|
||||
|
||||
if number_of_local_experts is not None:
|
||||
assert number_of_experts >= number_of_local_experts
|
||||
else:
|
||||
if dist.get_world_size() == 1:
|
||||
logger.warning("Local experts no specified but world size of 1")
|
||||
logger.warning("Assuming that all experts are local")
|
||||
number_of_local_experts = number_of_experts
|
||||
else:
|
||||
number_of_local_experts = 1
|
||||
|
||||
# Programatically handle the gating technique
|
||||
if not isinstance(gate, torch.nn.Module):
|
||||
gate_constructor = {
|
||||
GateConfig.RoundRobin: RoundRobinGate,
|
||||
GateConfig.Top2: Top2Gate,
|
||||
}[gate]
|
||||
|
||||
self.gate = gate_constructor(dim_model, number_of_experts)
|
||||
else:
|
||||
self.gate = gate
|
||||
|
||||
# Programatically handle the experts
|
||||
if expert_constructor is None:
|
||||
|
||||
multiplier = (
|
||||
hidden_layer_multiplier
|
||||
if hidden_layer_multiplier is not None
|
||||
else 4
|
||||
)
|
||||
|
||||
def expert_constructor() -> torch.nn.Module:
|
||||
return MLP(dim_model, dropout, activation, multiplier)
|
||||
|
||||
assert expert_constructor is not None
|
||||
|
||||
local_experts = torch.nn.ModuleList(
|
||||
[expert_constructor() for _ in range(number_of_local_experts)]
|
||||
)
|
||||
|
||||
self.moe = MOELayer(gate=self.gate, experts=local_experts, group=group)
|
||||
|
||||
self.requires_cuda = True
|
||||
|
||||
def forward(self, inputs: torch.Tensor) -> torch.Tensor:
|
||||
# FairScale MoE assumes that the dimensions are [S, B, E]
|
||||
# xFormers assumes [B, S, E]
|
||||
return self.moe(inputs.movedim(0, 1)).movedim(0, 1)
|
||||
47
pkgs/xformers/components/feedforward/mlp.py
Normal file
47
pkgs/xformers/components/feedforward/mlp.py
Normal file
@@ -0,0 +1,47 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the BSD license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from xformers.components import Activation, build_activation
|
||||
from xformers.components.feedforward import Feedforward, FeedforwardConfig
|
||||
|
||||
from . import register_feedforward
|
||||
|
||||
|
||||
@dataclass
class MlpConfig(FeedforwardConfig):
    """Configuration of the MLP block."""

    # The hidden layer has `hidden_layer_multiplier * dim_model` features
    hidden_layer_multiplier: int
    # Whether the linear layers have a bias
    bias: bool
|
||||
|
||||
|
||||
@register_feedforward("MLP", MlpConfig)
|
||||
class MLP(Feedforward):
|
||||
def __init__(
|
||||
self,
|
||||
dim_model: int,
|
||||
dropout: float,
|
||||
activation: Activation,
|
||||
hidden_layer_multiplier: int,
|
||||
bias: bool = True,
|
||||
*args,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__()
|
||||
dim_mlp = hidden_layer_multiplier * dim_model
|
||||
self.mlp = nn.Sequential(
|
||||
nn.Linear(in_features=dim_model, out_features=dim_mlp, bias=bias),
|
||||
build_activation(activation),
|
||||
nn.Dropout(dropout),
|
||||
nn.Linear(in_features=dim_mlp, out_features=dim_model, bias=bias),
|
||||
nn.Dropout(dropout),
|
||||
)
|
||||
|
||||
def forward(self, inputs: torch.Tensor) -> torch.Tensor:
|
||||
return self.mlp(inputs)
|
||||
Reference in New Issue
Block a user