First commit
This commit is contained in:
206
pkgs/xformers/components/residual.py
Normal file
206
pkgs/xformers/components/residual.py
Normal file
@@ -0,0 +1,206 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the BSD license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
|
||||
from enum import Enum
|
||||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from xformers import _is_triton_available
|
||||
|
||||
if _is_triton_available():
|
||||
from xformers.triton.layer_norm import FusedLayerNorm
|
||||
|
||||
from collections import namedtuple
|
||||
|
||||
|
||||
class ResidualNormStyle(str, Enum):
|
||||
"""Support different residual path and norm styles.
|
||||
See "On Layer Normalization in the Transformer Architecture",
|
||||
Xiong et al., https://arxiv.org/pdf/2002.04745v1.pdf
|
||||
"""
|
||||
|
||||
Pre = "pre"
|
||||
Post = "post"
|
||||
DeepNorm = "deepnorm"
|
||||
|
||||
|
||||
class NormalizationType(str, Enum):
|
||||
LayerNorm = "layernorm"
|
||||
Skip = "skip"
|
||||
# TODO: BatchNorm = "batchnorm"
|
||||
# TODO: GroupNorm = "groupnorm"
|
||||
|
||||
|
||||
def get_normalization_layer(normalization_type: NormalizationType):
|
||||
class Skip(nn.Module):
|
||||
def __init__(self, *_, **__) -> None:
|
||||
super().__init__()
|
||||
|
||||
def forward(self, x: torch.Tensor, **_):
|
||||
return x
|
||||
|
||||
return {
|
||||
NormalizationType.LayerNorm: nn.LayerNorm,
|
||||
NormalizationType.Skip: Skip,
|
||||
}[normalization_type]
|
||||
|
||||
|
||||
class RequiresWrappedInputs:
|
||||
"""Used to mark, through inheritance,
|
||||
the fact that this class will require inputs to be passed as a single list"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
# CREDITS: the following is inspired by FastAI's Transformer implementation
|
||||
class Residual(nn.Module, RequiresWrappedInputs):
|
||||
"""
|
||||
Object-oriented handling of the residual path
|
||||
|
||||
This supports scaling of the residual path, as proposed by DeepNet_
|
||||
.. _DeepNet: https://arxiv.org/pdf/2203.00555v1.pdf
|
||||
|
||||
.. Note: the wrapped layers must accept all the inputs as a single list
|
||||
"""
|
||||
|
||||
def __init__(self, layer: nn.Module, scale: Optional[float] = None):
|
||||
super().__init__()
|
||||
self.layer = layer
|
||||
self.scale = scale
|
||||
|
||||
# PreNorm and PostNorm require all the tensors to be passed as a list
|
||||
self.wrap_inputs = isinstance(layer, RequiresWrappedInputs)
|
||||
|
||||
def forward(self, inputs: List[torch.Tensor], **kwargs):
|
||||
if self.scale is not None:
|
||||
residue = inputs[0] * self.scale
|
||||
else:
|
||||
residue = inputs[0]
|
||||
|
||||
if self.wrap_inputs:
|
||||
return residue + self.layer(inputs=inputs, **kwargs)
|
||||
|
||||
else:
|
||||
return residue + self.layer(*inputs, **kwargs)
|
||||
|
||||
|
||||
class PreNorm(nn.Module, RequiresWrappedInputs):
|
||||
"""Adds a normalization before computing attention
|
||||
|
||||
..Note: If a list of inputs is passed, all of them get normalized"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
d_norm: int,
|
||||
sublayer: nn.Module,
|
||||
normalization: NormalizationType,
|
||||
use_triton: bool = True,
|
||||
):
|
||||
|
||||
super().__init__()
|
||||
if (
|
||||
_is_triton_available()
|
||||
and use_triton
|
||||
and normalization == NormalizationType.LayerNorm
|
||||
):
|
||||
self.norm: Union[nn.LayerNorm, FusedLayerNorm] = FusedLayerNorm(d_norm)
|
||||
else:
|
||||
self.norm = get_normalization_layer(normalization)(d_norm)
|
||||
|
||||
self.sublayer = sublayer
|
||||
self.wrap_inputs = isinstance(sublayer, RequiresWrappedInputs)
|
||||
|
||||
def forward(self, inputs: List[torch.Tensor], **kwargs):
|
||||
assert len(inputs) > 0
|
||||
|
||||
# Perf improvement: if the inputs are all the same, only norm once
|
||||
ids = [id(x) for x in inputs]
|
||||
if ids.count(ids[0]) == len(ids):
|
||||
# The same tensor is passed multiple times
|
||||
x_norm = self.norm(inputs[0])
|
||||
inputs_normed = [x_norm for _ in inputs]
|
||||
else:
|
||||
# The inputs differ, norm them all
|
||||
inputs_normed = [self.norm(x_) for x_ in inputs]
|
||||
|
||||
if self.wrap_inputs:
|
||||
return self.sublayer(inputs=inputs_normed, **kwargs)
|
||||
else:
|
||||
return self.sublayer(*inputs_normed, **kwargs)
|
||||
|
||||
|
||||
class PostNorm(nn.Module, RequiresWrappedInputs):
|
||||
"""Adds LayerNorm after computing attention"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
d_norm: int,
|
||||
sublayer: nn.Module,
|
||||
normalization: NormalizationType,
|
||||
use_triton: bool = True,
|
||||
):
|
||||
super().__init__()
|
||||
if (
|
||||
_is_triton_available()
|
||||
and use_triton
|
||||
and normalization == NormalizationType.LayerNorm
|
||||
):
|
||||
self.norm: Union[nn.LayerNorm, FusedLayerNorm] = FusedLayerNorm(d_norm)
|
||||
else:
|
||||
self.norm = get_normalization_layer(normalization)(d_norm)
|
||||
|
||||
self.sublayer = sublayer
|
||||
self.wrap_inputs = isinstance(sublayer, RequiresWrappedInputs)
|
||||
|
||||
def forward(self, inputs: List[torch.Tensor], **kwargs):
|
||||
if self.wrap_inputs:
|
||||
x = self.sublayer(inputs=inputs, **kwargs)
|
||||
else:
|
||||
x = self.sublayer(*inputs, **kwargs)
|
||||
return self.norm(x)
|
||||
|
||||
|
||||
DeepNormCoefficients = namedtuple("DeepNormCoefficients", ["alpha", "beta"])
|
||||
|
||||
|
||||
def get_deepnorm_coefficients(
|
||||
encoder_layers: int, decoder_layers: int
|
||||
) -> Tuple[Optional[DeepNormCoefficients], Optional[DeepNormCoefficients]]:
|
||||
"""
|
||||
See DeepNet_.
|
||||
|
||||
Returns alpha and beta depending on the number of encoder and decoder layers,
|
||||
first tuple is for the encoder and second for the decoder
|
||||
|
||||
.. _DeepNet: https://arxiv.org/pdf/2203.00555v1.pdf
|
||||
"""
|
||||
|
||||
N = encoder_layers
|
||||
M = decoder_layers
|
||||
|
||||
if decoder_layers == 0:
|
||||
# Encoder only
|
||||
return (
|
||||
DeepNormCoefficients(alpha=(2 * N) ** 0.25, beta=(8 * N) ** -0.25),
|
||||
None,
|
||||
)
|
||||
|
||||
elif encoder_layers == 0:
|
||||
# Decoder only
|
||||
return None, DeepNormCoefficients(alpha=(2 * M) ** 0.25, beta=(8 * M) ** -0.25)
|
||||
else:
|
||||
# Encoder/decoder
|
||||
encoder_coeffs = DeepNormCoefficients(
|
||||
alpha=0.81 * ((N**4) * M) ** 0.0625, beta=0.87 * ((N**4) * M) ** -0.0625
|
||||
)
|
||||
|
||||
decoder_coeffs = DeepNormCoefficients(
|
||||
alpha=(3 * M) ** 0.25, beta=(12 * M) ** -0.25
|
||||
)
|
||||
|
||||
return (encoder_coeffs, decoder_coeffs)
|
||||
Reference in New Issue
Block a user