First commit
This commit is contained in:
71
pkgs/xformers/components/activations.py
Normal file
71
pkgs/xformers/components/activations.py
Normal file
@@ -0,0 +1,71 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the BSD license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
|
||||
from enum import Enum
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
|
||||
class Activation(str, Enum):
|
||||
SquaredReLU = "squared_relu"
|
||||
GeLU = "gelu"
|
||||
LeakyReLU = "leaky_relu"
|
||||
ReLU = "relu"
|
||||
SmeLU = "smelu"
|
||||
StarReLU = "star_relu"
|
||||
|
||||
|
||||
# For unit testing / parity comparisons, probably not the fastest way
|
||||
class SquaredReLU(nn.Module):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
x_ = torch.nn.functional.relu(x)
|
||||
return x_ * x_
|
||||
|
||||
|
||||
class StarReLU(nn.Module):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
x_ = torch.nn.functional.relu(x)
|
||||
return 0.8944 * x_ * x_ - 0.4472
|
||||
|
||||
|
||||
class SmeLU(nn.Module):
|
||||
def __init__(self, beta: float = 2.0) -> None:
|
||||
super().__init__()
|
||||
self.beta = beta
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
relu = torch.where(
|
||||
x >= self.beta,
|
||||
x,
|
||||
torch.tensor([0.0], device=x.device, dtype=x.dtype),
|
||||
)
|
||||
return torch.where(
|
||||
torch.abs(x) <= self.beta,
|
||||
((x + self.beta) ** 2).type_as(x) / (4.0 * self.beta),
|
||||
relu,
|
||||
)
|
||||
|
||||
|
||||
def build_activation(activation: Optional[Activation]):
|
||||
if not activation:
|
||||
return nn.Identity()
|
||||
|
||||
return {
|
||||
Activation.ReLU: nn.ReLU,
|
||||
Activation.GeLU: nn.GELU,
|
||||
Activation.LeakyReLU: nn.LeakyReLU,
|
||||
Activation.SquaredReLU: SquaredReLU,
|
||||
Activation.StarReLU: StarReLU,
|
||||
Activation.SmeLU: SmeLU,
|
||||
}[activation]()
|
||||
Reference in New Issue
Block a user