# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.


from enum import Enum
from typing import Optional

import torch
from torch import nn


class Activation(str, Enum):
    SquaredReLU = "squared_relu"
    GeLU = "gelu"
    LeakyReLU = "leaky_relu"
    ReLU = "relu"
    SmeLU = "smelu"
    StarReLU = "star_relu"
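
# Note: because Activation also subclasses str, its members can be built
# straight from plain config strings, e.g. Activation("gelu") is Activation.GeLU.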


# For unit testing / parity comparisons, probably not the fastest way
class SquaredReLU(nn.Module):
    def __init__(self) -> None:
        super().__init__()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x_ = torch.nn.functional.relu(x)
        return x_ * x_


class StarReLU(nn.Module):
    def __init__(self) -> None:
        super().__init__()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x_ = torch.nn.functional.relu(x)
        # StarReLU (Yu et al., MetaFormer): s * relu(x)**2 + b, with
        # s = 1 / sqrt(1.25) ~= 0.8944 and b = -0.5 * s ~= -0.4472, chosen so
        # the output is approximately zero-mean, unit-variance for x ~ N(0, 1).
        return 0.8944 * x_ * x_ - 0.4472


# SmeLU ("smooth ReLU", Shamir & Lin): 0 for x <= -beta, x for x >= beta, and a
# quadratic blend (x + beta)**2 / (4 * beta) in between, which makes the
# function continuously differentiable at +/- beta.
class SmeLU(nn.Module):
    def __init__(self, beta: float = 2.0) -> None:
        super().__init__()
        self.beta = beta

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Linear branch: x where x >= beta, 0 elsewhere (covers x <= -beta too).
        relu = torch.where(
            x >= self.beta,
            x,
            torch.tensor([0.0], device=x.device, dtype=x.dtype),
        )
        # Quadratic blend inside |x| <= beta; otherwise use the branch above.
        return torch.where(
            torch.abs(x) <= self.beta,
            ((x + self.beta) ** 2).type_as(x) / (4.0 * self.beta),
            relu,
        )


def build_activation(activation: Optional[Activation]) -> nn.Module:
    if not activation:
        return nn.Identity()

    return {
        Activation.ReLU: nn.ReLU,
        Activation.GeLU: nn.GELU,
        Activation.LeakyReLU: nn.LeakyReLU,
        Activation.SquaredReLU: SquaredReLU,
        Activation.StarReLU: StarReLU,
        Activation.SmeLU: SmeLU,
    }[activation]()
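

# Minimal usage sketch (illustrative): build each activation from the enum and
# run a forward pass; passing None falls back to nn.Identity.
if __name__ == "__main__":
    for act in list(Activation) + [None]:
        layer = build_activation(act)
        y = layer(torch.randn(2, 8))
        print(act, type(layer).__name__, tuple(y.shape))

    # Parity spot-check for the reference SquaredReLU implementation above.
    x = torch.randn(4)
    assert torch.allclose(SquaredReLU()(x), torch.relu(x) ** 2)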