Files
enginex-bi_series-vllm/pkgs/xformers/triton/k_activations.py

153 lines
3.5 KiB
Python
Raw Normal View History

2025-08-05 19:02:46 +08:00
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
import math
from typing import Optional
import triton
import triton.language as tl
from xformers.components import Activation
# Scale factor sqrt(2 / pi) used inside the tanh approximation of GeLU.
_kAlpha = math.sqrt(2 / math.pi)
def get_triton_activation_index(activation: Optional[Activation]) -> int:
    """
    Return the integer code associated with ``activation``.

    ``None`` maps to 0. Each supported Activation maps to a distinct positive
    integer (1-6). An Activation missing from the table raises ``KeyError``.
    """
    if activation is None:
        return 0

    index_by_activation = {
        Activation.ReLU: 1,
        Activation.LeakyReLU: 2,
        Activation.GeLU: 3,
        Activation.SquaredReLU: 4,
        Activation.SmeLU: 5,
        Activation.StarReLU: 6,
    }
    return index_by_activation[activation]
@triton.jit
def tanh(x):
    """
    Hyperbolic tangent, using the identity tanh(x) = 2 * sigmoid(2x) - 1.
    """
    # tanh is a sigmoid rescaled from (0, 1) to (-1, 1); x + x == 2 * x exactly.
    sig = tl.sigmoid(x + x)
    return sig * 2 - 1
@triton.jit
def cosh(x):
    """
    Hyperbolic cosine, (e^x + e^-x) / 2, computed with a single exponential.
    """
    # Reuse exp(x) for the e^-x term via its reciprocal.
    e = tl.exp(x)
    return 0.5 * (e + 1.0 / e)
# Triton implementations of the most commonly used activations.
# See for instance http://arxiv.org/abs/1606.08415 for an overview.
# ReLU
@triton.jit
def relu(x):
    """
    ReLU_ activation function: ``x`` where ``x >= 0``, otherwise ``0.0``.

    .. _ReLU: https://pytorch.org/docs/stable/generated/torch.nn.ReLU.html
    """
    keep = x >= 0
    return tl.where(keep, x, 0.0)
@triton.jit
def relu_grad(x):
    """
    Derivative mask of ReLU: 1.0 where ``x >= 0`` (including x == 0), else 0.0.

    Unlike the other activation gradients here, this one does not need the
    pre-activation input to be reconstructed; only the sign of the argument
    is inspected.
    NOTE(review): the original comment stated that the argument passed here is
    the downstream gradient rather than the input. Confirm against the callers.
    """
    passes = x >= 0
    return tl.where(passes, 1.0, 0.0)
@triton.jit
def squared_relu(x):
    """
    Squared ReLU activation, as proposed in the Primer_ paper: ``x * x`` for
    ``x > 0``, otherwise ``0.0``.

    .. _Primer: https://arxiv.org/abs/2109.08668
    """
    return tl.where(x > 0.0, x * x, 0.0)
@triton.jit
def squared_relu_grad(x):
    """
    Derivative of squared ReLU: ``2 * x`` for ``x >= 0``, otherwise ``0.0``.
    """
    slope = 2 * x
    return tl.where(x >= 0.0, slope, 0.0)
@triton.jit
def star_relu(x):
    """
    Star ReLU activation, as proposed in the "MetaFormer Baselines for Vision"_
    paper: ``0.8944 * relu(x)**2 - 0.4472``.

    .. _ "MetaFormer Baselines for Vision": https://arxiv.org/pdf/2210.13452.pdf
    """
    # Squared ReLU, then a fixed scale and shift.
    squared = tl.where(x > 0.0, x * x, 0.0)
    return 0.8944 * squared - 0.4472
@triton.jit
def star_relu_grad(x):
    """
    Derivative of Star ReLU: ``1.7888 * x`` (i.e. 2 * 0.8944 * x) for ``x >= 0``,
    otherwise ``0.0``. The constant shift drops out.
    """
    is_active = x >= 0.0
    return tl.where(is_active, 1.7888 * x, 0.0)
# Leaky ReLU
@triton.jit
def leaky_relu(x):
    """
    LeakyReLU_ activation with a fixed negative slope of 0.01.

    .. _LeakyReLU: https://pytorch.org/docs/stable/generated/torch.nn.LeakyReLU.html
    """
    leaked = 0.01 * x
    return tl.where(x >= 0.0, x, leaked)
@triton.jit
def leaky_relu_grad(x):
    """
    Derivative of LeakyReLU: 1.0 for ``x >= 0``, otherwise the slope 0.01.
    """
    is_nonneg = x >= 0.0
    return tl.where(is_nonneg, 1.0, 0.01)
@triton.jit
def gelu(x):
    """
    GeLU_ activation - Gaussian error linear unit, tanh approximation:
    ``0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x**3)))``.

    .. _GeLU: https://arxiv.org/pdf/1606.08415.pdf
    """
    # Operand order is kept as ((0.044715 * x) * x) * x so rounding is unchanged.
    inner = _kAlpha * (x + 0.044715 * x * x * x)
    return 0.5 * x * (1 + tanh(inner))
@triton.jit
def gelu_grad(x):
    """
    Derivative of the tanh-approximated GeLU.

    With u = 0.79788456 * x * (1 + 0.044715 * x**2) and t = tanh(u), this is
    0.5 * x * (1 - t**2) * du/dx + 0.5 * (1 + t), where
    du/dx = 0.79788456 + 0.1070322243 * x**2 (0.1070322243 ~= 3 * 0.044715 * 0.79788456).
    """
    # CREDITS: Fast implementation proposed in
    # https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/fused_bias_gelu.py#L30
    t = tanh(0.79788456 * x * (1 + 0.044715 * x * x))
    sech_sq = 1 - t * t
    du_dx = 0.79788456 + 0.1070322243 * x * x
    return 0.5 * x * (sech_sq * du_dx) + 0.5 * (1 + t)
@triton.jit
def smelu(x):
    """
    SmeLU_ activation - Smooth ReLU with beta=2.0.

    Piecewise: 0 for x < -beta, (x + beta)**2 / (4 * beta) for |x| <= beta,
    and x for x > beta.

    .. _SmeLU: https://arxiv.org/pdf/2202.06499.pdf
    """
    beta = 2.0
    # Outside the smooth band: identity above beta, zero below -beta.
    linear_part = tl.where(x >= beta, x, 0.0)
    shifted = x + beta
    quadratic_part = shifted * shifted / (4.0 * beta)
    return tl.where(tl.abs(x) <= beta, quadratic_part, linear_part)
@triton.jit
def smelu_grad(x):
    """
    Derivative of SmeLU with beta=2.0.

    Piecewise: 0 for x < -beta, (x + beta) / (2 * beta) for |x| <= beta,
    and 1 for x > beta.
    """
    beta = 2.0
    outer_grad = tl.where(x >= beta, 1.0, 0.0)
    inner_grad = (beta + x) / (2.0 * beta)
    return tl.where(tl.abs(x) <= beta, inner_grad, outer_grad)