207 lines
6.0 KiB
Python
207 lines
6.0 KiB
Python
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
|
#
|
|
# This source code is licensed under the BSD license found in the
|
|
# LICENSE file in the root directory of this source tree.
|
|
|
|
|
|
from enum import Enum
|
|
from typing import List, Optional, Tuple, Union
|
|
|
|
import torch
|
|
import torch.nn as nn
|
|
|
|
from xformers import _is_triton_available
|
|
|
|
if _is_triton_available():
|
|
from xformers.triton.layer_norm import FusedLayerNorm
|
|
|
|
from collections import namedtuple
|
|
|
|
|
|
class ResidualNormStyle(str, Enum):
    """Enumerates the supported ways of combining the residual path
    with the normalization.

    Background reading: "On Layer Normalization in the Transformer
    Architecture", Xiong et al., https://arxiv.org/pdf/2002.04745v1.pdf

    Members are also plain strings, so each one compares equal to its
    lowercase value (e.g. ``ResidualNormStyle.Pre == "pre"``).
    """

    Pre = "pre"
    Post = "post"
    DeepNorm = "deepnorm"
|
|
|
|
|
|
class NormalizationType(str, Enum):
    """Names of the normalization flavours that can be requested.

    Members are also plain strings, so each one compares equal to its
    lowercase value.
    """

    LayerNorm = "layernorm"
    Skip = "skip"
    # TODO: BatchNorm = "batchnorm"
    # TODO: GroupNorm = "groupnorm"
|
|
|
|
|
|
def get_normalization_layer(normalization_type: NormalizationType):
    """Return the module *class* implementing the requested normalization.

    ``NormalizationType.LayerNorm`` maps to ``nn.LayerNorm`` and
    ``NormalizationType.Skip`` maps to a pass-through module. The class is
    returned uninstantiated; the caller is expected to construct it.

    Raises:
        KeyError: if ``normalization_type`` is not one of the two keys above.
    """

    class Skip(nn.Module):
        # Pass-through stand-in: the constructor swallows whatever arguments
        # it is given, and forward hands the input back untouched.
        def __init__(self, *_, **__) -> None:
            super().__init__()

        def forward(self, x: torch.Tensor, **_):
            return x

    layer_by_type = {
        NormalizationType.LayerNorm: nn.LayerNorm,
        NormalizationType.Skip: Skip,
    }
    return layer_by_type[normalization_type]
|
|
|
|
|
|
class RequiresWrappedInputs:
    """Marker base class.

    Inheriting from it flags that the subclass requires its inputs
    to be passed as a single list.
    """
|
|
|
|
|
|
# CREDITS: the following is inspired by FastAI's Transformer implementation
class Residual(nn.Module, RequiresWrappedInputs):
    """
    Wraps a layer and adds its output onto the first input (residual path).

    The residual branch can optionally be multiplied by a constant ``scale``,
    as proposed by DeepNet_

    .. _DeepNet: https://arxiv.org/pdf/2203.00555v1.pdf

    .. Note: ``forward`` takes all the input tensors as a single list
    """

    def __init__(self, layer: nn.Module, scale: Optional[float] = None):
        super().__init__()
        self.layer = layer
        self.scale = scale

        # Layers flagged with RequiresWrappedInputs (PreNorm and PostNorm for
        # instance) get the whole list through the `inputs` keyword;
        # any other layer receives the tensors unpacked as positional arguments
        self.wrap_inputs = isinstance(layer, RequiresWrappedInputs)

    def forward(self, inputs: List[torch.Tensor], **kwargs):
        # Only the first tensor travels along the residual path
        skip = inputs[0] if self.scale is None else inputs[0] * self.scale

        if self.wrap_inputs:
            branch = self.layer(inputs=inputs, **kwargs)
        else:
            branch = self.layer(*inputs, **kwargs)

        return skip + branch
|
|
|
|
|
|
class PreNorm(nn.Module, RequiresWrappedInputs):
    """Normalizes the inputs, then forwards them to the wrapped sublayer

    .. Note: If a list of inputs is passed, all of them get normalized"""

    def __init__(
        self,
        d_norm: int,
        sublayer: nn.Module,
        normalization: NormalizationType,
        use_triton: bool = True,
    ):
        super().__init__()

        # The fused Triton kernel is only picked when it is available,
        # requested, and the normalization asked for is a LayerNorm
        fused_layernorm = (
            _is_triton_available()
            and use_triton
            and normalization == NormalizationType.LayerNorm
        )
        norm_factory = (
            FusedLayerNorm
            if fused_layernorm
            else get_normalization_layer(normalization)
        )
        self.norm: Union[nn.LayerNorm, FusedLayerNorm] = norm_factory(d_norm)

        self.sublayer = sublayer
        self.wrap_inputs = isinstance(sublayer, RequiresWrappedInputs)

    def forward(self, inputs: List[torch.Tensor], **kwargs):
        assert len(inputs) > 0

        # Perf improvement: when every entry is the very same tensor object,
        # normalize it once and reuse the result for all the slots
        first = inputs[0]
        if all(tensor is first for tensor in inputs):
            shared_norm = self.norm(first)
            inputs_normed = [shared_norm] * len(inputs)
        else:
            # Distinct tensors, each one goes through the norm
            inputs_normed = [self.norm(tensor) for tensor in inputs]

        if not self.wrap_inputs:
            return self.sublayer(*inputs_normed, **kwargs)

        return self.sublayer(inputs=inputs_normed, **kwargs)
|
|
|
|
|
|
class PostNorm(nn.Module, RequiresWrappedInputs):
    """Runs the wrapped sublayer, then normalizes its output"""

    def __init__(
        self,
        d_norm: int,
        sublayer: nn.Module,
        normalization: NormalizationType,
        use_triton: bool = True,
    ):
        super().__init__()

        # The fused Triton kernel is only picked when it is available,
        # requested, and the normalization asked for is a LayerNorm
        fused_layernorm = (
            _is_triton_available()
            and use_triton
            and normalization == NormalizationType.LayerNorm
        )
        norm_factory = (
            FusedLayerNorm
            if fused_layernorm
            else get_normalization_layer(normalization)
        )
        self.norm: Union[nn.LayerNorm, FusedLayerNorm] = norm_factory(d_norm)

        self.sublayer = sublayer
        self.wrap_inputs = isinstance(sublayer, RequiresWrappedInputs)

    def forward(self, inputs: List[torch.Tensor], **kwargs):
        # Wrapped-input sublayers take the list as-is through the `inputs`
        # keyword, the others get the tensors unpacked
        if self.wrap_inputs:
            sublayer_out = self.sublayer(inputs=inputs, **kwargs)
        else:
            sublayer_out = self.sublayer(*inputs, **kwargs)

        return self.norm(sublayer_out)
|
|
|
|
|
|
# Pair of DeepNorm factors, readable as `.alpha` and `.beta`
DeepNormCoefficients = namedtuple("DeepNormCoefficients", ["alpha", "beta"])


def get_deepnorm_coefficients(
    encoder_layers: int, decoder_layers: int
) -> Tuple[Optional[DeepNormCoefficients], Optional[DeepNormCoefficients]]:
    """
    Compute the DeepNorm ``alpha`` and ``beta`` coefficients, see DeepNet_.

    The values depend on the number of encoder and decoder layers.
    The result is an ``(encoder, decoder)`` pair; the side whose layer count
    is zero is reported as ``None`` (the decoder count is tested first).

    .. _DeepNet: https://arxiv.org/pdf/2203.00555v1.pdf
    """
    N, M = encoder_layers, decoder_layers

    if M == 0:
        # Encoder only
        encoder_only = DeepNormCoefficients(
            alpha=(2 * N) ** 0.25, beta=(8 * N) ** -0.25
        )
        return encoder_only, None

    if N == 0:
        # Decoder only
        decoder_only = DeepNormCoefficients(
            alpha=(2 * M) ** 0.25, beta=(8 * M) ** -0.25
        )
        return None, decoder_only

    # Encoder/decoder: the encoder side mixes both depths
    depth_mix = (N**4) * M
    for_encoder = DeepNormCoefficients(
        alpha=0.81 * depth_mix ** 0.0625, beta=0.87 * depth_mix ** -0.0625
    )
    for_decoder = DeepNormCoefficients(
        alpha=(3 * M) ** 0.25, beta=(12 * M) ** -0.25
    )
    return (for_encoder, for_decoder)
|