First commit

This commit is contained in:
2025-08-05 19:02:46 +08:00
parent 9efe891f99
commit 99fb9f5cb0
1412 changed files with 203615 additions and 0 deletions

View File

@@ -0,0 +1,27 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
import torch
_triton_available = torch.cuda.is_available()
if _triton_available:
try:
from .dropout import FusedDropoutBias, dropout # noqa
from .fused_linear_layer import FusedLinear # noqa
from .layer_norm import FusedLayerNorm, layer_norm # noqa
from .softmax import log_softmax, softmax # noqa
__all__ = [
"dropout",
"softmax",
"log_softmax",
"FusedDropoutBias",
"FusedLinear",
"FusedLayerNorm",
"layer_norm",
]
except ImportError:
__all__ = []

Binary file not shown.

Binary file not shown.

View File

@@ -0,0 +1,243 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
# CREDITS: This is heavily inspired by the Triton dropout tutorial
# https://raw.githubusercontent.com/openai/triton/master/python/tutorials/04-low-memory-dropout.py
from typing import Optional
import torch
import triton
from torch.cuda.amp import custom_bwd, custom_fwd
from xformers.components.activations import Activation, build_activation
from xformers.triton.k_activations import get_triton_activation_index
from xformers.triton.k_dropout import k_dropout_bw, k_dropout_fw
BLOCK_M = 32
BLOCK_N = 64 # NOTE: This should ideally be GPU dependent, big impact on perf
# Helper to handle the SPMD launch grid and error cases
class _dropout(torch.autograd.Function):
@staticmethod
@custom_fwd(cast_inputs=torch.float16)
def forward(ctx, x, p, bias, activation, trainable_bias):
# Soft-flatten an hypothetical 3rd dimension
x_ = x.reshape(-1, x.shape[-1]).contiguous()
y = torch.empty_like(x_)
M, N = x_.shape
assert bias is None or (bias.dtype == x.dtype and bias.shape[0] == N)
assert p > 0.0
def grid(meta):
return (
triton.cdiv(M, meta["BLOCK_M"]),
triton.cdiv(N, meta["BLOCK_N"]),
)
N_BLOCK_N = triton.cdiv(N, BLOCK_N)
# Generate one seed per sample
# seed max is int32 max for positive numbers: 2**16
seeds = torch.randint(65536, (N_BLOCK_N,), device=x.device, dtype=torch.int32)
# fmt: off
bias_ptr = bias if bias is not None else x_ # Possibly not being used
k_dropout_fw[grid](
y, x_,
bias_ptr,
seeds,
y.stride(0),
M, N,
p,
x.dtype == torch.float16,
USE_BIAS=bias is not None,
ACTIVATION=activation,
BLOCK_M=BLOCK_M,
BLOCK_N=BLOCK_N,
)
# fmt: on
if activation is not None:
ctx.save_for_backward(seeds, bias, x)
else:
ctx.save_for_backward(seeds, bias, None)
ctx.trainable_bias = bias is not None and trainable_bias
ctx.activation = activation
ctx.p = p
return y.reshape_as(x)
@staticmethod
@custom_bwd
def backward(
ctx, grad_out
): # pragma: no cover # This is covered, but called from C++ and not tracked
(seeds, bias, inputs) = ctx.saved_tensors
# Soft-flatten an hypothetical 3rd dimension
grad_out_ = grad_out.reshape(-1, grad_out.shape[-1]).contiguous()
grad_in = torch.empty_like(grad_out_)
M, N = grad_out_.shape
# Optional inputs to compute the activation contribution to the gradient
assert inputs is not None or ctx.activation is None
if inputs is None:
inputs = grad_out_
elif inputs.ndim > 2:
inputs = inputs.reshape(-1, N)
# We split the problem in tiles:
# - over M there will be a follow up reduction
# - over N we compromise in between trying to use as much memory paralellism as possible,
# (fill in the warps, there are 32 threads per warps, and 4 warps default), and not being too
# big because of register spilling
N_BLOCKS_M = triton.cdiv(M, BLOCK_M)
if ctx.trainable_bias:
grad_bias = torch.empty(
(
N_BLOCKS_M,
N,
),
device=grad_in.device,
dtype=grad_in.dtype,
)
else:
grad_bias = grad_in # will not be used
def grid(meta):
# NOTE: We use Triton Philox random number generator, which optimally generates 4 blocks for
# a given seed and offsets. "BLOCK_M" here describes the size of one of these blocks
# but we need to take this factor of 4 into account when scheduling all the kernels
return (
N_BLOCKS_M,
triton.cdiv(N, meta["BLOCK_N"]),
)
# fmt: off
k_dropout_bw[grid](
grad_in, grad_bias, grad_out_,
inputs, bias if bias is not None else inputs,
seeds,
grad_out_.stride(0), inputs.stride(0),
M, N,
ctx.p,
grad_in.dtype == torch.float16,
USE_BIAS=bias is not None,
ACTIVATION=ctx.activation,
TRAINABLE_BIAS=ctx.trainable_bias,
BLOCK_M=BLOCK_M,
BLOCK_N=BLOCK_N,
)
# fmt: on
return (
grad_in.reshape_as(grad_out),
None,
torch.sum(grad_bias, dim=0) if ctx.trainable_bias else None,
None,
None,
None,
)
def dropout(
x: torch.Tensor,
p: float,
bias: Optional[torch.Tensor] = None,
activation: Optional[Activation] = None,
):
"""
Apply dropout on the input tensor.
Optionally add a bias, the computation will be fused.
"""
assert p <= 1.0 and p >= 0.0
if p == 1.0:
return torch.zeros_like(x)
# Micro optim, skip dropout
if p == 0.0:
x = x + bias if bias is not None else x
if activation is not None:
activation_fn = build_activation(activation)
return activation_fn(x)
return x
# The normal triton enabled codepath
activation_index = get_triton_activation_index(activation)
return _dropout.apply(
x,
float(p),
bias,
activation_index,
bias is not None and bias.requires_grad,
)
class FusedDropoutBias(torch.nn.Module):
"""
A layer which fuses the computation of Dropout(Activation(x))
in a single GPU kernel
"""
def __init__(
self,
p: float,
bias_shape: Optional[int],
activation: Optional[Activation] = None,
) -> None:
super().__init__()
self.p = float(p)
assert (
self.p < 1.0
), f"We don't want to drop all the values, most probably p={p} is not properly set"
self.activation_type = activation
self.bias = (
torch.zeros(bias_shape, requires_grad=True)
if bias_shape is not None
else None
)
self.activation = get_triton_activation_index(self.activation_type)
self.activation_pytorch = build_activation(self.activation_type)
def init_weights(self, *args, **kwargs):
with torch.no_grad():
if self.bias is not None:
self.bias.fill_(0.0)
def forward(self, x: torch.Tensor) -> torch.Tensor:
# Convenience, catch a possible type or device mismatch
if self.bias is not None:
self.bias = self.bias.to(dtype=x.dtype, device=x.device) # type: ignore
# Train/inference
p = self.p if self.training else 0.0
# This kernel is slower than pytorch for small buffers, bypassing it in that case
perf_check = x.shape[-1] > 512
# Catch a non-cuda setup, fallback to pytorch
if not x.is_cuda or not perf_check or p == 0.0:
x = x + self.bias if self.bias is not None else x
x = self.activation_pytorch(x)
return torch.nn.functional.dropout(x, p) if p > 0.0 else x
# The normal, Triton-backed path
return _dropout.apply(x, p, self.bias, self.activation, True)

View File

@@ -0,0 +1,119 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
import math
from typing import Any, Optional
import torch
import torch.nn as nn
from torch.cuda.amp import custom_bwd, custom_fwd
from xformers.components.activations import Activation
from xformers.triton.k_activations import get_triton_activation_index
from xformers.triton.k_fused_matmul_bw import fused_matmul_backward
from xformers.triton.k_fused_matmul_fw import fused_matmul
class _fused_linear_triton(torch.autograd.Function):
@staticmethod
@custom_fwd(cast_inputs=torch.float16)
def forward(
ctx,
x,
weight,
bias,
activation,
trainable_weight,
trainable_bias,
save_activation_inputs,
):
# Kick the fused Triton kernel, handling bias and activation in one go
y, activation_inputs = fused_matmul(
x, weight, bias, activation, save_activation_inputs
)
ctx.activation = activation
ctx.trainable_weight = trainable_weight
ctx.trainable_bias = trainable_bias
# Micro-optimization: saving these is not always needed (?)
if x.requires_grad or ctx.trainable_weight or ctx.trainable_bias:
ctx.save_for_backward(weight, activation_inputs, x)
return y
@staticmethod
@custom_bwd
def backward(
ctx: Any, grad_out: torch.Tensor
) -> Any: # pragma: no cover # this is covered, but called directly from C++
"""
Compute the derivative with respect to x, other tensors were not trainable inputs.
"""
(weight, activation_inputs, x) = ctx.saved_tensors
grad_input, grad_weight, grad_bias = fused_matmul_backward(
grad_out=grad_out,
inputs=x,
act_in=activation_inputs,
weight=weight,
trainable_weight=ctx.trainable_weight,
trainable_bias=ctx.trainable_bias,
activation_grad=ctx.activation,
)
return grad_input, grad_weight, grad_bias, None, None, None, None, None, None
class FusedLinear(nn.Module):
"""
Handle a linear transform, like torch.nn.Linear_, and a given activation, in a single kernel.
The whole transform: is :math:`y = activation(xA^T + b)`.
This is typically significantly faster than PyTorch while using fp16 and non-sigmoid activations,
as of September 2021.
.. _torch.nn.Linear: https://pytorch.org/docs/stable/generated/torch.nn.Linear.html
"""
def __init__(
self,
in_features: int,
out_features: int,
bias: bool = False,
activation: Optional[Activation] = None,
**_,
):
super().__init__()
self.weight = nn.Parameter(
torch.empty(out_features, in_features), requires_grad=True
)
self.bias = (
nn.Parameter(torch.empty(out_features), requires_grad=True)
if bias
else None
)
self._activation_index = get_triton_activation_index(activation)
self.reset_parameters()
def reset_parameters(self) -> None:
torch.nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
if self.bias is not None:
fan_in, _ = torch.nn.init._calculate_fan_in_and_fan_out(self.weight)
bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
torch.nn.init.uniform_(self.bias, -bound, bound)
def forward(self, x):
return _fused_linear_triton.apply(
x,
self.weight,
self.bias,
self._activation_index,
self.weight.requires_grad,
self.bias.requires_grad if self.bias is not None else False,
self.training and x.requires_grad and self._activation_index > 0,
)

View File

@@ -0,0 +1,152 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
import math
from typing import Optional
import triton
import triton.language as tl
from xformers.components import Activation
_kAlpha = math.sqrt(2.0 / math.pi)
def get_triton_activation_index(activation: Optional[Activation]) -> int:
return (
{
Activation.ReLU: 1,
Activation.LeakyReLU: 2,
Activation.GeLU: 3,
Activation.SquaredReLU: 4,
Activation.SmeLU: 5,
Activation.StarReLU: 6,
}[activation]
if activation is not None
else 0
)
@triton.jit
def tanh(x):
# Tanh is just a scaled sigmoid
return 2 * tl.sigmoid(2 * x) - 1
@triton.jit
def cosh(x):
exp_x = tl.exp(x)
return (exp_x + 1.0 / exp_x) * 0.5
# a Triton implementation of the most used activations
# See for instance http://arxiv.org/abs/1606.08415 for an overview
# ReLU
@triton.jit
def relu(x):
"""
ReLU_ activation function
.. _ReLU: https://pytorch.org/docs/stable/generated/torch.nn.ReLU.html
"""
return tl.where(x >= 0, x, 0.0)
@triton.jit
def relu_grad(x):
# ReLU is different from other activations
# in that it does not require the input to retrospectively compute its gradient
# here the input is the downstream gradient, and we return the upstream gradient directly
return tl.where(x >= 0, 1.0, 0.0)
@triton.jit
def squared_relu(x):
"""
Squared ReLU activation, as proposed in the Primer_ paper.
.. _Primer: https://arxiv.org/abs/2109.08668
"""
x_sq = x * x
return tl.where(x > 0.0, x_sq, 0.0)
@triton.jit
def squared_relu_grad(x):
return tl.where(x >= 0.0, 2 * x, 0.0)
@triton.jit
def star_relu(x):
"""
Star ReLU activation, as proposed in the "MetaFormer Baselines for Vision"_ paper.
.. _ "MetaFormer Baselines for Vision": https://arxiv.org/pdf/2210.13452.pdf
"""
x_sq = x * x
return 0.8944 * tl.where(x > 0.0, x_sq, 0.0) - 0.4472
@triton.jit
def star_relu_grad(x):
return tl.where(x >= 0.0, 1.7888 * x, 0.0)
# Leaky ReLU
@triton.jit
def leaky_relu(x):
"""
LeakyReLU_ activation
.. _LeakyReLU: https://pytorch.org/docs/stable/generated/torch.nn.LeakyReLU.html
"""
return tl.where(x >= 0.0, x, 0.01 * x)
@triton.jit
def leaky_relu_grad(x):
return tl.where(x >= 0.0, 1.0, 0.01)
@triton.jit
def gelu(x):
"""
GeLU_ activation - Gaussian error linear unit
.. _GeLU: https://arxiv.org/pdf/1606.08415.pdf
"""
return 0.5 * x * (1 + tanh(_kAlpha * (x + 0.044715 * x * x * x)))
@triton.jit
def gelu_grad(x):
# CREDITS: Fast implementation proposed in
# https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/fused_bias_gelu.py#L30
tanh_out = tanh(0.79788456 * x * (1 + 0.044715 * x * x))
return 0.5 * x * (
(1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * x * x)
) + 0.5 * (1 + tanh_out)
@triton.jit
def smelu(x):
"""
SmeLU_ activation - Smooth ReLU with beta=2.0
.. _SmeLU: https://arxiv.org/pdf/2202.06499.pdf
"""
beta = 2.0
relu = tl.where(x >= beta, x, 0.0)
return tl.where(tl.abs(x) <= beta, (x + beta) * (x + beta) / (4.0 * beta), relu)
@triton.jit
def smelu_grad(x):
beta = 2.0
relu_grad = tl.where(x >= beta, 1.0, 0.0)
return tl.where(tl.abs(x) <= beta, (beta + x) / (2.0 * beta), relu_grad)

View File

@@ -0,0 +1,211 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
# CREDITS: This is heavily inspired by the Triton dropout tutorial
# https://raw.githubusercontent.com/openai/triton/master/python/tutorials/04-low-memory-dropout.py
import triton
import triton.language as tl
from xformers.triton.k_activations import (
gelu,
gelu_grad,
leaky_relu,
leaky_relu_grad,
relu,
relu_grad,
smelu,
smelu_grad,
squared_relu,
squared_relu_grad,
)
_configs = [
triton.Config({}, num_warps=1),
triton.Config({}, num_warps=2),
triton.Config({}, num_warps=4),
triton.Config({}, num_warps=8),
triton.Config({}, num_warps=16),
]
# fmt: off
@triton.heuristics({"SIZE_RAND_BLOCK": lambda args: args["BLOCK_N"] * args["BLOCK_M"]})
@triton.autotune(
configs=_configs,
key=["M", "N", "is_fp16"],
)
@triton.jit
def k_dropout_fw(
Y, X, BIAS, SEEDS,
stride,
M, N,
p: tl.constexpr,
is_fp16: tl.constexpr, # autotune
ACTIVATION: tl.constexpr,
# Meta-parameters
BLOCK_M: tl.constexpr,
BLOCK_N: tl.constexpr,
SIZE_RAND_BLOCK: tl.constexpr,
USE_BIAS: tl.constexpr,
):
"""
Apply dropout on an input tensor
Y : Output (M, N)
X : Input (M, N)
BIAS (N,)
SEEDS (M,)
p : dropout probability
"""
# fmt: on
row_id = tl.program_id(axis=0)
rows = row_id * BLOCK_M + tl.arange(0, BLOCK_M)
col_id = tl.program_id(axis=1)
cols = col_id * BLOCK_N + tl.arange(0, BLOCK_N)
# pointers starting point
x_ptrs = X + rows[:, None] * stride + cols[None, :]
y_ptrs = Y + rows[:, None] * stride + cols[None, :]
# good to go, start the layer computations
col_mask = cols[None, :] < N
p_scale = 1. / (1. - p)
if USE_BIAS:
b_ptrs = BIAS + cols[None, :]
bias = tl.load(b_ptrs, mask=cols[None, :] < N, other=0.)
else:
bias = x_ptrs # will not be used
block_mask = (rows[:, None] < M) & col_mask
x = tl.load(x_ptrs, mask=block_mask, other=0.0)
# optionally apply a fused bias
if USE_BIAS:
x += bias
# optional: fused activation (while the data is in shared memory)
if ACTIVATION == 1:
x = relu(x)
elif ACTIVATION == 2:
x = leaky_relu(x)
elif ACTIVATION == 3:
x = gelu(x)
elif ACTIVATION == 4:
x = squared_relu(x)
elif ACTIVATION == 5:
x = smelu(x)
# get the random keep mask
rand_offsets = tl.arange(0, SIZE_RAND_BLOCK)
seed_int = tl.load(SEEDS + col_id)
r = tl.rand(seed_int, rand_offsets)
keep_mask = r > p
# prune and normalize in one go
keep = tl.view(keep_mask, x.shape)
output = tl.where(keep, (x * p_scale).to(x.dtype), 0.)
tl.store(y_ptrs, output, mask=block_mask) # output
# fmt: off
@triton.heuristics({"SIZE_RAND_BLOCK": lambda args: args["BLOCK_N"] * args["BLOCK_M"]})
@triton.autotune(
configs=_configs,
key=["M", "N", "is_fp16"],
)
@triton.jit
def k_dropout_bw(
GRAD_IN, GRAD_BIAS, GRAD_OUT,
INPUTS, BIAS, SEEDS,
stride_grad, stride_inputs,
M, N,
p: tl.constexpr,
is_fp16: tl.constexpr, # autotune
ACTIVATION: tl.constexpr,
# Meta-parameters
BLOCK_M: tl.constexpr, # heuristics
BLOCK_N: tl.constexpr,
SIZE_RAND_BLOCK: tl.constexpr,
TRAINABLE_BIAS: tl.constexpr,
USE_BIAS: tl.constexpr,
):
"""
Apply dropout on an input tensor
GRAD_OUT (M, N)
GRAD_BIAS (N,)
GRAD_IN (M, N)
BIAS (N,)
SEEDS (N,)
p : dropout probability
"""
# fmt: on
row_id = tl.program_id(axis=0)
rows = row_id * BLOCK_M + tl.arange(0, BLOCK_M)
col_id = tl.program_id(axis=1)
cols = col_id * BLOCK_N + tl.arange(0, BLOCK_N)
# pointers starting point
grad_out_ptrs = GRAD_OUT + rows[:, None] * stride_grad + cols[None, :]
grad_in_ptrs = GRAD_IN + rows[:, None] * stride_grad + cols[None, :]
input_ptrs = INPUTS + rows[:, None] * stride_inputs + cols[None, :]
# now go over the tiles
grad_bias = tl.zeros((BLOCK_N,), dtype=tl.float32)
col_mask = cols[None, :] < N
p_scale = 1. / (1. - p)
if USE_BIAS:
b_ptrs = BIAS + cols[None, :]
bias = tl.load(b_ptrs, mask=col_mask, other=0.)
block_mask = (rows[:, None] < M) & col_mask
grad_out = tl.load(grad_out_ptrs, mask=block_mask, other=0.)
# optional: fused activation (while the data is in shared memory)
if ACTIVATION:
inputs = tl.load(input_ptrs, mask=block_mask, other=0.)
# optionally apply a fused bias
if USE_BIAS:
inputs += bias
if ACTIVATION == 1:
act_grad = relu_grad(inputs)
elif ACTIVATION == 2:
act_grad = leaky_relu_grad(inputs)
elif ACTIVATION == 3:
act_grad = gelu_grad(inputs)
elif ACTIVATION == 4:
act_grad = squared_relu_grad(inputs)
elif ACTIVATION == 5:
act_grad = smelu_grad(inputs)
grad_out *= act_grad
# randomly prune (and scale) the resulting buffer, possibly a no-op
# note that even if we did not save the mask from the FW pass, it is generated
# from the same seeds, so the same drop mask is applied here
rand_offsets = tl.arange(0, SIZE_RAND_BLOCK)
seed_int = tl.load(SEEDS + col_id)
r = tl.rand(seed_int, rand_offsets)
r = tl.view(r, grad_out.shape)
output = tl.where(r > p, (grad_out * p_scale).to(grad_out.dtype), 0.)
# write-back
tl.store(grad_in_ptrs, output, mask=block_mask)
# optionally accumulate the bias gradient
if TRAINABLE_BIAS:
grad_bias += tl.sum(output, axis=0)
if TRAINABLE_BIAS:
grad_bias_ptr = GRAD_BIAS + row_id * N + cols
tl.store(grad_bias_ptr, grad_bias, mask=cols < N)

View File

@@ -0,0 +1,161 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
from typing import Optional
import torch
import triton
import triton.language as tl
from xformers.triton.k_activations import (
gelu_grad,
leaky_relu_grad,
relu_grad,
smelu_grad,
squared_relu_grad,
star_relu_grad,
)
# fmt: off
@triton.autotune(
configs=[
triton.Config({"BLOCK_N": 64}, num_stages=4, num_warps=2),
triton.Config({"BLOCK_N": 128}, num_stages=3, num_warps=2),
triton.Config({"BLOCK_N": 256}, num_stages=3, num_warps=4),
triton.Config({"BLOCK_N": 512}, num_stages=3, num_warps=4),
triton.Config({"BLOCK_N": 1024}, num_stages=3, num_warps=4),
],
key=["N"],
)
@triton.heuristics({
'EVEN_N': lambda args: args["N"] % (args['BLOCK_N']) == 0,
})
@triton.jit
def kernel_bw(
# Pointers to matrices
GRAD_ACT, GRAD_OUT, ACT_INPUTS,
# Matrix dimensions
N,
# The stride variables represent how much to increase the ptr by when moving by 1
# element in a particular dimension. E.g. stride_am is how much to increase a_ptr
# by to get the element one row down (A has M rows)
stride_gom, stride_aim,
# Meta-parameters
BLOCK_N: tl.constexpr,
EVEN_N: tl.constexpr,
ACTIVATION_GRAD: tl.constexpr,
):
# fmt: on
"""
Go over all the activation inputs, compute the corresponding gradient
"""
# this kernel is relatively simple in terms of scheduling:
# - per row (pid_m)
# - each program a given chunk on the col axis,
# since it's more effective memory and occupancy wise
pid_m, pid_n = tl.program_id(axis=0), tl.program_id(axis=1)
rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
# the memory addresses of elements in the first block of
# A and W can be computed using numpy-style broadcasting
act_input_ptrs = ACT_INPUTS + pid_m * stride_aim + rn
# compute the gradient which is related to this activation
if EVEN_N:
act_in = tl.load(act_input_ptrs)
else:
act_in = tl.load(act_input_ptrs, mask=rn < N, other=0.0)
if ACTIVATION_GRAD == 1:
grad_act = relu_grad(act_in)
elif ACTIVATION_GRAD == 2:
grad_act = leaky_relu_grad(act_in)
elif ACTIVATION_GRAD == 3:
grad_act = gelu_grad(act_in)
elif ACTIVATION_GRAD == 4:
grad_act = squared_relu_grad(act_in)
elif ACTIVATION_GRAD == 5:
grad_act = smelu_grad(act_in)
elif ACTIVATION_GRAD == 6:
grad_act = star_relu_grad(act_in)
else:
grad_act = act_in
# now read the incoming gradient, the backpropagated one is the multiple of both
grad_out_ptrs = GRAD_OUT + pid_m * stride_gom + rn
if EVEN_N:
grad_out = tl.load(grad_out_ptrs)
else:
grad_out = tl.load(grad_out_ptrs, mask=rn < N)
grad_act *= grad_out
# write back result
grad_act_ptrs = GRAD_ACT + pid_m * stride_gom + rn
tl.store(grad_act_ptrs, grad_act, mask=rn < N)
def fused_matmul_backward(
grad_out: torch.Tensor,
inputs: torch.Tensor,
act_in: Optional[torch.Tensor],
weight: torch.Tensor,
trainable_weight: bool,
trainable_bias: bool,
activation_grad: int = 0,
):
"""
Compute grad_in = activation^-1(grad_out) @ weight.transpose()
.. note: The weight buffer is transposed on the fly
.. note: Activation gradient needs to be a Triton kernel
"""
# Make sure that we don't have to handle the stride over cols
if not grad_out.is_contiguous():
grad_out = grad_out.contiguous()
grad_out_ = grad_out if grad_out.ndim == 2 else grad_out.flatten(0, -2)
inputs_ = inputs if inputs.ndim == 2 else inputs.flatten(0, -2)
assert grad_out_.shape[1] == weight.shape[0], "Incompatible dimensions in between grad_out and weight"
M, N = grad_out_.shape
N, _ = weight.shape
# Compute the gradient for the activation
if activation_grad > 0:
grad_act = torch.empty_like(grad_out_)
# Some activations do not require their inputs to
# know of their grad, the downstream grad is enough
if act_in is None:
act_in = grad_out_
grid = lambda META: (M, triton.cdiv(N, META["BLOCK_N"])) # noqa
# fmt: off
kernel_bw[grid](
grad_act, grad_out_, act_in, # data ptrs
N, # shapes
grad_act.stride(0), act_in.stride(0), # strides
ACTIVATION_GRAD=activation_grad, # optional fused activation
)
# fmt: on
# Backpropagation going up, the reference gradient is now
# just before the activation
grad_out_ = grad_act
# The following ops can also be handled by pytorch
grad_in = triton.ops.matmul(grad_out_, weight)
grad_weight = grad_out_.transpose(1, 0) @ inputs_ if trainable_weight else None
grad_bias = torch.sum(grad_out_, dim=0) if trainable_bias else None
return grad_in.reshape_as(inputs), grad_weight, grad_bias

View File

@@ -0,0 +1,253 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
from typing import Optional
import torch
import triton
import triton.language as tl
from xformers.triton.k_activations import (
gelu,
leaky_relu,
relu,
smelu,
squared_relu,
star_relu,
)
# CREDITS: Initially inspired by the Triton tutorial on matrix multiplications
def get_configs(block_k):
return [
triton.Config(
{"BLOCK_M": 64, "BLOCK_N": 32, "BLOCK_K": block_k},
num_stages=4,
num_warps=2,
),
triton.Config(
{"BLOCK_M": 32, "BLOCK_N": 64, "BLOCK_K": block_k},
num_stages=4,
num_warps=2,
),
triton.Config(
{"BLOCK_M": 128, "BLOCK_N": 64, "BLOCK_K": block_k},
num_stages=3,
num_warps=4,
),
triton.Config(
{"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": block_k},
num_stages=3,
num_warps=4,
),
triton.Config(
{"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": block_k},
num_stages=3,
num_warps=4,
),
triton.Config(
{"BLOCK_M": 64, "BLOCK_N": 256, "BLOCK_K": block_k},
num_stages=3,
num_warps=4,
),
triton.Config(
{"BLOCK_M": 256, "BLOCK_N": 64, "BLOCK_K": block_k},
num_stages=3,
num_warps=4,
),
# Fails on small GPUS
# triton.Config(
# {"BLOCK_M": 128, "BLOCK_N": 256, "BLOCK_K": block_k},
# num_stages=3,
# num_warps=8,
# ),
# triton.Config(
# {"BLOCK_M": 256, "BLOCK_N": 128, "BLOCK_K": block_k},
# num_stages=3,
# num_warps=8,
# ),
]
# fmt: off
@triton.autotune(
configs=[c for block_k in [32, 64] for c in get_configs(block_k)],
key=["M", "N", "K"],
)
@triton.heuristics({
'EVEN_N': lambda args: args["N"] % (args['BLOCK_N']) == 0,
})
@triton.jit
def kernel_fma(
# Pointers to matrices
OUT, ACT_INPUTS, INPUT, WEIGHT, bias,
# Matrix dimensions
M, N, K,
# The stride variables represent how much to increase the ptr by when moving by 1
# element in a particular dimension. E.g. stride_am is how much to increase a_ptr
# by to get the element one row down (A has M rows)
stride_om, stride_im,
stride_wn,
# Meta-parameters
BLOCK_M: tl.constexpr, GROUP_M: tl.constexpr,
BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
EVEN_N: tl.constexpr,
BIAS: tl.constexpr,
SAVE_ACT_INPUTS: tl.constexpr,
ACTIVATION: tl.constexpr,
is_fp16: tl.constexpr, # autotune
):
# fmt: on
"""
Kernel for computing Out = activation(A x W + C)
- Input has shape (M, K)
- Weight has shape (K, N)
- Bias has shape (N,)
- Output has shape (M, N)
- ActInputs (optional) has shape (M, N)
'ActInputs' optionally saves the A x W + C intermediate for backward computations
This kernel will consolidate over K
"""
# programs are grouped together to improve L2 hit rate
# the logic is that we'll consolidate over K. If the programs were not grouped,
# then multiple cols/rows in the result would end up pulling in the same row and lines
# from the inputs. By grouping the computation we ensure some data reuse, which the hardware
# covers via the L2 cache
pid = tl.program_id(axis=0)
num_pid_m = tl.cdiv(M, BLOCK_M) # number of program ids along the M axis
num_pid_n = tl.cdiv(N, BLOCK_N) # number of programs ids along the N axis
num_pid_in_group = GROUP_M * num_pid_n # number of programs in group
group_id = pid // num_pid_in_group # id of the group this program is in
first_pid_m = group_id * GROUP_M # row-id of the first program in the group
GROUP_M = min(
num_pid_m - first_pid_m, GROUP_M
) # if `num_pid_m` isn't divisible by `GROUP_M`, the last group is smaller
# *within groups*, programs are ordered in a column-major order
# row-id /col-id of the program in the *launch grid*
pid_m = first_pid_m + (pid % GROUP_M)
pid_n = (pid % num_pid_in_group) // GROUP_M
# now compute the block that each program will go through
# rm (resp. rn) denotes a range of indices
# for rows (resp. col) of C
rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
rk = tl.arange(0, BLOCK_K)
# the memory addresses of elements can follow numpy broadcasting
input_ptrs = INPUT + rm[:, None] * stride_im
weight_ptrs = WEIGHT + rn[None, :] * stride_wn
# initialize and iteratively update accumulator
acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
if BIAS:
if EVEN_N:
bias = tl.load(bias + rn).to(tl.float32)
else:
bias = tl.load(bias + rn, mask=rn < N, other=0.0).to(tl.float32)
acc += bias[None, :]
# block level matrix multiplication.
# We fetch a block memory block from both inputs, matmul and accumulate, then repeat
mask_rn = rn < N
mask_rm = rm < M
for i in range(0, K, BLOCK_K):
rk = tl.arange(0, BLOCK_K) + i
a = tl.load(input_ptrs + rk[None, :], mask=((rk[None, :] < K) & mask_rm[:, None]), other=0.0)
w = tl.load(weight_ptrs + rk[:, None], mask=((rk[:, None] < K) & mask_rn[None, :]), other=0.0)
acc += tl.dot(a, w)
rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
# optional: save the activation inputs
if SAVE_ACT_INPUTS:
act_in_ptrs = ACT_INPUTS + rm[:, None] * stride_om + rn[None, :]
tl.store(act_in_ptrs, acc, mask=mask_rm[:, None] & mask_rn[None, :])
# optional: fused activation (while the data is in shared memory)
if ACTIVATION == 1:
acc = relu(acc)
elif ACTIVATION == 2:
acc = leaky_relu(acc)
elif ACTIVATION == 3:
acc = gelu(acc)
elif ACTIVATION == 4:
acc = squared_relu(acc)
elif ACTIVATION == 5:
acc = smelu(acc)
elif ACTIVATION == 6:
acc = star_relu(acc)
# write back result
out_ptrs = OUT + rm[:, None] * stride_om + rn[None, :]
tl.store(out_ptrs, acc, mask=mask_rm[:, None] & mask_rn[None, :])
# Activation needs to be a triton kernel
def fused_matmul(
x: torch.Tensor,
weight: torch.Tensor,
bias: Optional[torch.Tensor],
activation=0,
save_act_inputs: bool = False
):
"""
Compute e = activation(x @ weight + bias).
This wrapper kicks the `kernel_fma` Triton kernel
"""
if not x.is_contiguous():
x = x.contiguous()
x_ = x if x.ndim == 2 else x.flatten(0, -2)
assert (
x_.shape[1] == weight.shape[1]
), f"Incompatible dimensions in between inputs and weight, {x_.shape} - {weight.shape}"
assert bias is None or bias.is_contiguous()
assert (
bias is None or bias.shape[0] == weight.shape[0]
), "Incompatible dimensions in between weight and bias"
assert weight.is_contiguous()
M, K = x_.shape
N, K = weight.shape
outputs = torch.empty((M, N), device=x.device, dtype=x.dtype)
act_inputs = torch.empty_like(outputs) if save_act_inputs else x # will not be used in that case
# 1D launch kernel where each block gets its own program.
grid = lambda META: (triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"]),) # noqa
# fmt: off
kernel_fma[grid](
outputs, act_inputs, x_, weight, # data ptrs
bias if bias is not None else x, # auto skip bias if not present
M, N, K, # shapes
outputs.stride(0), x_.stride(0), # strides
weight.stride(0),
ACTIVATION=activation, # optional fused activation
BIAS=bias is not None, # optional fused bias
GROUP_M=8, # speed optimization: group the programs
SAVE_ACT_INPUTS=save_act_inputs,
is_fp16=x_.dtype == torch.float16
)
# fmt: on
outputs = outputs if x.ndim == 2 else outputs.reshape(*x.shape[:-1], N)
return outputs, act_inputs if save_act_inputs else None

View File

@@ -0,0 +1,179 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
# CREDITS: This comes almost as-is from the Triton layer norm tutorial
# https://github.com/openai/triton/blob/master/python/tutorials/05-layer-norm.py
import triton
import triton.language as tl
# fmt: off
@triton.jit
def layer_norm_fw(X, Y, W, B, M, V, stride, N, eps, affine: tl.constexpr, BLOCK_SIZE_N: tl.constexpr):
# fmt: on
"""
Fused layernorm kernel over a 3d tensor.
The layer norm is applied over the last dimension.
Compute
y = (x - E(x))/(sqrt(var(x) + epsilon)) * gamma + beta
"""
row = tl.program_id(0)
cols = tl.arange(0, BLOCK_SIZE_N)
mask = cols < N
# Move to this row
x_ptrs = X + row * stride + cols
x = tl.load(x_ptrs, mask=mask, other=0.0).to(tl.float32)
# Compute mean and variance
mean = tl.sum(x, axis=0) / N
x_zm = tl.where(mask, x - mean, 0.0)
tl.store(M + row, mean)
x_var = tl.sum(x_zm * x_zm, axis=0) / N
rstd = 1.0 / tl.sqrt(x_var + eps)
# Normalize, optionally affine
y = x_zm * rstd
tl.store(V + row, rstd)
mask = cols < N
if affine:
w = tl.load(W + cols, mask=mask, other=1.0)
b = tl.load(B + cols, mask=mask, other=0.0)
y = y * w + b
y_ptrs = Y + row * stride + cols
tl.store(y_ptrs, y, mask=mask)
# Backward pass (DX + partial DW + partial DB)
# fmt: off
@triton.jit
def layer_norm_bwd_dx_fused(
DX, DY, DW, DB,
X, W, M, V,
Lock, stride, N,
# META-parameters
affine: tl.constexpr,
GROUP_SIZE_M: tl.constexpr,
BLOCK_SIZE_N: tl.constexpr,
):
# fmt: on
# position of elements processed by this program
row = tl.program_id(0)
cols = tl.arange(0, BLOCK_SIZE_N)
mask = cols < N
# offset data pointers to start at the row of interest
x_ptrs = X + row * stride + cols
dy_ptrs = DY + row * stride + cols
# load data to SRAM
x = tl.load(x_ptrs, mask=mask, other=0)
dy = tl.load(dy_ptrs, mask=mask, other=0)
mean = tl.load(M + row)
rstd = tl.load(V + row)
# compute dx
xhat = (x - mean) * rstd
if affine:
w = tl.load(W + cols, mask=mask, other=0)
wdy = w * dy
else:
wdy = dy
xhat = tl.where(mask, xhat, 0.)
wdy = tl.where(mask, wdy, 0.)
mean1 = tl.sum(xhat * wdy, axis=0) / N
mean2 = tl.sum(wdy, axis=0) / N
dx = (wdy - (xhat * mean1 + mean2)) * rstd
# write-back dx
cols = tl.arange(0, BLOCK_SIZE_N)
mask = cols < N # re-materialize the mask to save registers
dx_ptrs = DX + row * stride + cols
tl.store(dx_ptrs, dx, mask=mask)
if affine:
# accumulate partial sums for dw/db
partial_dw = (dy * xhat).to(w.dtype)
partial_db = dy.to(w.dtype)
# offset locks and weight/bias gradient pointer
# each kernel instance accumulates partial sums for
# DW and DB into one of GROUP_SIZE_M independent buffers
# these buffers stay in the L2, which allow this kernel
# to be fast
lock_id = row % GROUP_SIZE_M
Lock += lock_id
Count = Lock + GROUP_SIZE_M
# - wait for a lock on the accumulated dw/db
while tl.atomic_cas(Lock, 0, 1) == 1:
pass
count = tl.load(Count)
# - we got the lock, accumulate this kernel's results with
# the stored values.
dw_ptrs = DW + lock_id * N + cols
db_ptrs = DB + lock_id * N + cols
if count == 0:
# first store doesn't accumulate
tl.atomic_xchg(Count, 1)
else:
partial_dw += tl.load(dw_ptrs, mask=mask, other=0.)
partial_db += tl.load(db_ptrs, mask=mask, other=0.)
tl.store(dw_ptrs, partial_dw, mask=mask)
tl.store(db_ptrs, partial_db, mask=mask)
# release lock
tl.atomic_xchg(Lock, 0)
# Backward pass (total DW + total DB)
# fmt: off
@triton.jit
def layer_norm_bwd_dwdb(
DW, DB, FINAL_DW, FINAL_DB,
M, N,
BLOCK_SIZE_M: tl.constexpr,
BLOCK_SIZE_N: tl.constexpr
):
# fmt: on
pid = tl.program_id(0)
cols = pid * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
mask_cols = cols < N
dw = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
db = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
for i in range(0, M, BLOCK_SIZE_M):
rows = i + tl.arange(0, BLOCK_SIZE_M)
offs = rows[:, None] * N + cols[None, :]
mask_rm = rows < M
dw += tl.load(DW + offs, mask=mask_rm[:, None] & mask_cols[None, :], other=0.0)
db += tl.load(DB + offs, mask=mask_rm[:, None] & mask_cols[None, :], other=0.0)
sum_dw = tl.sum(dw, axis=0)
sum_db = tl.sum(db, axis=0)
cols = pid * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
mask_cols = cols < N
tl.store(FINAL_DW + cols, sum_dw, mask=mask_cols)
tl.store(FINAL_DB + cols, sum_db, mask=mask_cols)

View File

@@ -0,0 +1,175 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
import triton
import triton.language as tl
# CREDITS: This is adapted from the vanilla Triton example. See https://openai.com/blog/triton/
# and https://triton-lang.org/getting-started/tutorials/02-fused-softmax.html
def get_depth(args):
return triton.next_power_of_2(args["K"])
# autotune: Triton will test out these configurations, and automatically pick the fastest one.
# heuristic: add arguments to the kernel call automatically given some heuristics. These arguments are passed in "meta"
# fmt: off
@triton.autotune(
configs=[
triton.Config({}, num_warps=1),
triton.Config({}, num_warps=2),
triton.Config({}, num_warps=4),
triton.Config({}, num_warps=8),
triton.Config({}, num_warps=16),
triton.Config({}, num_warps=32),
],
key=["K"],
)
@triton.heuristics(values={"depth": get_depth})
@triton.jit
def _softmax(
Y, X, M,
stride_ym, stride_yn,
stride_xm, stride_xn,
stride_mn,
K,
# Meta-params
depth: tl.constexpr,
causal: tl.constexpr,
use_mask: tl.constexpr,
log: tl.constexpr,
):
# fmt: om
"""
Fused softmax kernel over a 3d tensor.
The softmax is applied over the last dimension, meaning that this is equivalent to torch.softmax(tensor, dim=-1)
Note, if the last dimension is large, say 128K elements, the kernel compile time can shot up to many minutes when
the kernel is run for the first time.
"""
m = tl.program_id(0)
n = tl.program_id(1)
# col indices
k = tl.arange(0, depth)
# the memory address of all the elements that we want to load can be computed as follows
x_ptrs = X + m * stride_xm + n * stride_xn + k
# load input data; pad out-of-bounds elements with 0
io_mask = k < K
# Causal - 1: skip on the loads directly
if causal:
io_mask = io_mask & (k <= n)
x = tl.load(x_ptrs, mask=io_mask, other=float("-inf")).to(tl.float32)
# Causal - 2: enforce correctness over a couple of misloaded values
if causal:
off = float("-inf")
off = off.to(x.dtype) # type: ignore
x = tl.where(k > n, off, x)
if use_mask:
mask_ptrs = M + n * stride_mn + k
add_mask = tl.load(mask_ptrs, io_mask, other=float("-inf")).to(tl.float32)
x += add_mask
# compute numerically-stable softmax
z = x - tl.max(x, axis=0)
num = tl.exp(z)
denom = tl.sum(num, axis=0)
if log:
y = z - tl.log(denom)
else:
y = num / denom
# write back to Y.
# we only write once, hence the "fused" softmax naming
y_ptrs = Y + m * stride_ym + n * stride_yn + k
# technically we could write only the lower triangular matrix in the causal case
# but this is deemed to error prone
tl.store(y_ptrs, y, mask=k < K)
# fmt: off
@triton.autotune(
configs=[
triton.Config({}, num_warps=1),
triton.Config({}, num_warps=2),
triton.Config({}, num_warps=4),
triton.Config({}, num_warps=8),
triton.Config({}, num_warps=16),
],
key=["K"],
)
@triton.jit
def _softmax_backward(
GradIn, GradOut, Out,
stride_bm, stride_bn,
stride_gm, stride_gn,
stride_om, stride_on,
K,
# meta-params
depth: tl.constexpr,
causal: tl.constexpr,
log: tl.constexpr,
):
# fmt: on
"""
Compute the softmax gradients.
..Note: Not autotuning for now because this would lead to broken accumulated gradients
"""
m = tl.program_id(0)
n = tl.program_id(1)
# col indices
k = tl.arange(0, depth)
# the memory address of all the elements that we want to load can be computed as follows
grad_out_ptrs = GradOut + m * stride_gm + n * stride_gn + k
out_ptrs = Out + m * stride_om + n * stride_on + k
# load input data; pad out-of-bounds elements with 0
io_mask = k < K
# Causal - 1: skip on the loads directly
if causal:
io_mask = io_mask & (k <= n)
g = tl.load(grad_out_ptrs, mask=io_mask, other=float(0)).to(tl.float32)
o = tl.load(out_ptrs, mask=io_mask, other=float(0)).to(tl.float32)
# Causal - 2: enforce correctness over a couple of misloaded values
if causal:
zero = float(0)
zero = zero.to(g.dtype) # type: ignore
g = tl.where(k > n, zero, g)
o = tl.where(k > n, zero, o)
if log:
s = tl.sum(g, 0)
grad_in = g - tl.exp(o) * s
else:
# Step 1: Compute the intermediate sum used for the gradient
s = tl.sum(g * o, 0)
# Step 2: Compute the gradients
grad_in = o * (g - s)
# write back to the input gradients
# technically we could write only the lower triangular matrix in the causal case
# but this is deemed to error prone
grad_in_ptrs = GradIn + m * stride_bm + n * stride_bn + k
tl.store(grad_in_ptrs, grad_in, mask=k < K)

View File

@@ -0,0 +1,55 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
import triton
import triton.language as tl
# fmt: off
@triton.jit
def k_sum_0(
Y, X,
stride_xm,
M, N,
is_fp16,
# META-params
BLOCK_M: tl.constexpr,
BLOCK_N: tl.constexpr,
):
# fmt: om
"""
Sum a 2d tensor over the first (strided) dimension.
This extracts some speed through a parallel sum across the second dimension
"""
# partial row indices. We'll reduce over this dimension
m = tl.arange(0, BLOCK_M)
# To get some extra parallelization, we handle several columns in the same thread block
rn = tl.program_id(axis=0) * BLOCK_N + tl.arange(0, BLOCK_N)
# the memory address of all the elements that we want to load can be computed as follows
x_ptrs = X + m[:, None] * stride_xm + rn[None, :]
x_sum = tl.zeros((BLOCK_N,), dtype=tl.float32)
tiles = M // BLOCK_M
if M % BLOCK_M > 0:
tiles += 1
col_mask = (rn[None, :] < N)
for _ in range(tiles):
# load input data; pad out-of-bounds elements with 0
# NOTE: make sure to accumulate in fp32 to prevent a trivial overflow
mask = (m[:, None] < M) & col_mask
x = tl.load(x_ptrs, mask=mask, other=0.0)
x_sum += tl.sum(x, 0)
# move the load pointer
x_ptrs += BLOCK_M * stride_xm
m += BLOCK_M # update the mask check
tl.store(Y + rn, x_sum, mask=rn < N)

View File

@@ -0,0 +1,236 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
# CREDITS: the underlying kernel comes straight from the Triton tutorials
# see https://github.com/openai/triton/blob/master/python/tutorials/05-layer-norm.py
import logging
from typing import Optional
import torch
import torch.nn as nn
import triton
from torch.cuda.amp import custom_bwd, custom_fwd
from xformers.triton.k_layer_norm import (
layer_norm_bwd_dwdb,
layer_norm_bwd_dx_fused,
layer_norm_fw,
)
logger = logging.getLogger("xformers")
_triton_layernorm_fp16_enabled = False # NOTE: PyTorch keeps layernorm as fp32
_triton_registered_warnings = False
class _LayerNorm(torch.autograd.Function):
@staticmethod
@custom_fwd(cast_inputs=torch.float16 if _triton_layernorm_fp16_enabled else None)
def forward(ctx, x, weight, bias, eps):
# catch eps being too small if the tensors are fp16
if x.dtype == torch.float16:
eps = max(eps, 1.6e-5)
# allocate output
y = torch.empty_like(x)
# reshape input data into 2D tensor
x_arg = x.reshape(-1, x.shape[-1])
M, N = x_arg.shape
# allocate mean and std, they'll be used in the backward pass
mean = torch.empty((M,), dtype=torch.float32, device="cuda")
rstd = torch.empty((M,), dtype=torch.float32, device="cuda")
# Less than 64KB per feature: enqueue fused kernel
MAX_FUSED_SIZE = 65536 // x.element_size()
BLOCK_SIZE_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))
if N > BLOCK_SIZE_N:
raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.")
if not x_arg.is_contiguous() or not y.is_contiguous():
global _triton_registered_warnings
if not _triton_registered_warnings:
logger.warning(
"Non-contiguous input tensor found. Making it contiguous,"
+ " but could have perf or trainer implications"
)
_triton_registered_warnings = True
x_arg = x_arg.contiguous()
y = y.contiguous()
# heuristics for number of warps.
num_warps = min(max(BLOCK_SIZE_N // 256, 1), 16)
# enqueue kernel
# fmt: off
layer_norm_fw[(M,)](
x_arg, y, weight, bias, mean, rstd,
x_arg.stride(0),
N,
eps,
num_warps=num_warps,
BLOCK_SIZE_N=BLOCK_SIZE_N,
affine=weight is not None
)
# fmt: on
ctx.save_for_backward(x, mean, rstd, weight)
ctx.BLOCK_SIZE_N = BLOCK_SIZE_N
ctx.num_warps = num_warps
return y.reshape_as(x)
@staticmethod
@custom_bwd
def backward(
ctx, dy
): # pragma: no cover # this is covered, but called directly from C++
x, mean, rstd, weight = ctx.saved_tensors
# flatten the batch dimension, if any.
# We're interested in 'samples' x norm_dimension
x = x.reshape(-1, x.size(-1))
M, N = x.size()
# heuristics for amount of parallel reduction stream for DG/DB
GROUP_SIZE_M = 32
if N <= 8192:
GROUP_SIZE_M = 64
if N <= 4096:
GROUP_SIZE_M = 96
if N <= 2048:
GROUP_SIZE_M = 128
if N <= 1024:
GROUP_SIZE_M = 256
if dy.dtype == torch.float32:
GROUP_SIZE_M = GROUP_SIZE_M // 2
# allocate output
locks = torch.zeros(2 * GROUP_SIZE_M, dtype=torch.int32, device="cuda")
t_args = {"dtype": x.dtype, "device": x.device}
_dw = torch.empty((GROUP_SIZE_M, x.size(-1)), **t_args)
_db = torch.empty_like(_dw)
dw = torch.empty((x.size(-1),), **t_args)
db = torch.empty_like(dw)
dy = dy.contiguous()
dx = torch.empty_like(dy)
# Check the tensor shapes and layouts
# we suppose in the kernel that they have the same size and are contiguous
assert (
dy.numel() == x.numel()
), "Something is wrong in the backward graph, possibly because of an inplace operation after the layernorm"
# enqueue kernel using forward pass heuristics
# also compute partial sums for DW and DB
num_warps = min(max(ctx.BLOCK_SIZE_N // 256, 1), 16)
# fmt: off
layer_norm_bwd_dx_fused[(M,)](
dx, dy, _dw, _db, x,
weight if weight is not None else x,
mean, rstd,
locks,
x.stride(0),
N,
affine=weight is not None,
GROUP_SIZE_M=GROUP_SIZE_M,
BLOCK_SIZE_N=ctx.BLOCK_SIZE_N,
num_warps=num_warps
)
# fmt: on
def grid(meta):
return [triton.cdiv(N, meta["BLOCK_SIZE_N"])]
# accumulate partial sums in separate kernel
# fmt: off
layer_norm_bwd_dwdb[grid](
_dw, _db, dw, db,
GROUP_SIZE_M,
N,
BLOCK_SIZE_M=32,
BLOCK_SIZE_N=64
)
# fmt: on
dx = dx.reshape_as(dy)
return dx, dw, db, None
class FusedLayerNorm(nn.Module):
"""
Handle a layer normalization, like torch.nn.LayerNorm_.
This implementation should be measurably faster than the default PyTorch layernorm (as of PyTorch 1.9),
both for training and inference worloads.
.. NOTE: Computations under Torch AMP are kept as float32 by default, one can change this to be float16
by setting the flag `xformers.triton.k_layer_norm._triton_layernorm_fp16_enabled = True`
.. _torch.nn.LayerNorm: https://pytorch.org/docs/stable/generated/torch.nn.LayerNorm.html
"""
def __init__(self, normalized_shape, affine=True, eps=1e-06):
super().__init__()
if affine:
self.weight = nn.Parameter(torch.ones(normalized_shape))
self.bias = nn.Parameter(torch.zeros(normalized_shape))
else:
self.weight = self.bias = None
self.epsilon = eps
def forward(self, x):
return layer_norm(x, self.weight, self.bias, self.epsilon)
def init_weights(self, *args, **kwargs):
with torch.no_grad():
if self.weight is not None:
self.weight.fill_(1.0)
if self.bias is not None:
self.bias.fill_(0.0)
def layer_norm(
x: torch.Tensor,
weight: Optional[torch.Tensor] = None,
bias: Optional[torch.Tensor] = None,
eps: float = 1e-06,
) -> torch.Tensor:
global _triton_registered_warnings
r"""Applies normalization over a mini batch of inputs"""
try:
if (
not _triton_registered_warnings
and torch.cuda.is_available()
and x.is_cuda
and weight is not None
and bias is not None
):
return _LayerNorm.apply(x, weight, bias, eps)
except RuntimeError as e:
# Catch cases where the current GPU does not have enough registers to hold a full tensor line
# fallback to PyTorch's implementation, which streams the tensor in and out
_triton_registered_warnings = True
logger.warning(
"Triton layernorm kernel register spillover or invalid image caught. "
"Deactivating this kernel, please file an issue in the xFormers repository"
)
logger.warning(e)
return torch.nn.functional.layer_norm(
x, [x.shape[-1]], weight=weight, bias=bias, eps=eps
)

View File

@@ -0,0 +1,203 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
import logging
from typing import Optional
import torch
import triton
from torch.cuda.amp import custom_bwd, custom_fwd
from xformers.triton.k_softmax import _softmax, _softmax_backward
# CREDITS: This is adapted from the vanilla Triton example. See https://openai.com/blog/triton/
# and https://triton-lang.org/getting-started/tutorials/02-fused-softmax.html
logger = logging.getLogger("xformers")
_triton_softmax_fp16_enabled = False # NOTE: PyTorch keeps softmax as fp32
_triton_registered_warnings = False
# Helper to handle the SPMD launch grid and error cases
class _softmax_triton(torch.autograd.Function):
@staticmethod
@custom_fwd(cast_inputs=torch.float16 if _triton_softmax_fp16_enabled else None)
def forward(ctx, x, mask, log_outputs, causal):
"""
Fused softmax implementation, using the Triton programming model.
This only supports a reduction over the last dimension for now
"""
# Handle 2D/3D tensors
x_ = x.unsqueeze(0) if x.ndim == 2 else x
x_ = x_.flatten(0, -3)
if not x_.is_contiguous():
x_ = x_.contiguous()
y = torch.empty_like(x_)
assert (
y.stride(2) == 1 and x_.stride(2) == 1
), f"{x.shape} - {x_.shape} - {x_.stride()}"
# SPMD launch grid
grid_2d = (
x_.shape[0],
x_.shape[1],
)
# enqueue GPU kernel
use_mask = True
if mask is None:
# placeholder, will not be used
mask = x_
use_mask = False
else:
# Make sure that the mask is binary
assert mask.dtype == x.dtype, "An additive mask is requested"
_softmax[grid_2d](
y,
x_,
mask,
y.stride(0),
y.stride(1),
x_.stride(0),
x_.stride(1),
mask.stride(0),
x_.shape[2],
log=log_outputs,
use_mask=use_mask,
causal=causal,
)
ctx.save_for_backward(y)
ctx.log_outputs = log_outputs
ctx.causal = causal
return y.reshape_as(x)
@staticmethod
@custom_bwd
def backward(
ctx, grad_out
): # pragma: no cover # this is covered, but called directly from C++
(out,) = ctx.saved_tensors
# Handle 2D/3D tensors
grad_out_ = grad_out.unsqueeze(0) if grad_out.ndim == 2 else grad_out
grad_out_ = grad_out_.flatten(0, -3)
# SPMD launch grid
grid_2d = (
grad_out_.shape[0],
grad_out_.shape[1],
)
depth = triton.next_power_of_2(grad_out_.shape[2])
grad_in = torch.empty_like(
out
) # torch.zeros is measurably slower, we'll zero out in the kernel
# Make sure that the tensor are contiguous
grad_in, grad_out, out = map(lambda x: x.contiguous(), [grad_in, grad_out, out])
# fmt: off
_softmax_backward[grid_2d](
grad_in, grad_out_, out,
grad_in.stride(0), grad_in.stride(1),
grad_out_.stride(0), grad_out_.stride(1),
out.stride(0), out.stride(1),
out.shape[2],
depth=depth,
log=ctx.log_outputs,
causal=ctx.causal
)
# fmt: on
return grad_in.reshape_as(grad_out), None, None, None
def softmax(
x: torch.Tensor, mask: Optional[torch.Tensor] = None, causal: bool = False
) -> torch.Tensor:
r"""Applies the Softmax function to an 3-dimensional input Tensor
rescaling them so that the elements of the n-dimensional output Tensor
lie in the range [0,1] and sum to 1.
Softmax is defined as:
.. math::
\text{Softmax}(x_{i}) = \frac{\exp(x_i)}{\sum_j \exp(x_j)}
.. warning: softmax is computed on the last dimension of the input tensor.
Args:
x: input tensor.
mask: optional mask, its application will be fused to the softmax computation if triton is used
causal: optional performance optimization, if triton is used and the attention is causal
Returns:
a Tensor of the same dimension and shape as the input with
values in the range [0, 1] and sum to 1
"""
return _softmax_dispatch(x, log=False, mask=mask, causal=causal)
def log_softmax(
x: torch.Tensor, mask: Optional[torch.Tensor] = None, causal: bool = False
) -> torch.Tensor:
r"""Applies the :math:`\log(\text{Softmax}(x))` function to an 3-dimensional
input Tensor. The LogSoftmax formulation can be simplified as:
.. math::
\text{LogSoftmax}(x_{i}) = \log\left(\frac{\exp(x_i) }{ \sum_j \exp(x_j)} \right)
Args:
x: input tensor.
Returns:
a Tensor of the same dimension and shape as the input with
values in the range [-inf, 0)
"""
return _softmax_dispatch(x, log=True, mask=mask, causal=causal)
def _softmax_dispatch(
x: torch.Tensor, log: bool, mask: Optional[torch.Tensor], causal: bool = False
) -> torch.Tensor:
# Triton is used if
# - CUDA
# - there's enough data to make it faster than pytorch. This could change over time, Triton is improving
# - there was no previous failure
global _triton_registered_warnings
try:
if torch.cuda.is_available() and x.is_cuda and not _triton_registered_warnings:
return _softmax_triton.apply(x, mask, log, causal)
except RuntimeError as e:
# Catch cases where the current GPU does not have enough registers to hold a full tensor line
# fallback to PyTorch's implementation, which streams the tensor in and out
_triton_registered_warnings = True
logger.warning(
"Triton softmax kernel register spillover or invalid image caught."
"Deactivating this kernel, please file an issue int the xFormers repository"
)
logger.warning(e)
if mask is not None:
x = x + mask
if causal:
x = x + torch.triu(torch.full_like(x, float("-inf")), diagonal=1)
if log:
return torch.log_softmax(x, dim=-1)
else:
return torch.softmax(x, dim=-1)

View File

@@ -0,0 +1,60 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
import torch
import triton
from xformers.triton.k_sum import k_sum_0
def sum_2d_dim_0(x: torch.Tensor):
"""
Sum a 2D tensor across the first dimension
"""
out = torch.empty(x.shape[1], device=x.device, dtype=x.dtype)
assert (
x.ndim == 2
), "This is a very specific kernel, only for 2-dim tensors and summing along dim 0"
M, N = x.shape
# This kernel is not competitive for these sizes
if M > 2048 or M < 8:
return x.sum(dim=0)
assert (
M >= 4
), "This is a very specific kernel, requires the reduction dimension to be bigger than 4"
assert x.stride(1) == 1, (
"We're expecting x to be contiguous along dim 1, and non contiguous along dim 0.\n"
" You would probably be better served with torch.sum()"
)
BLOCK_M = min(triton.next_power_of_2(M), 2048)
BLOCK_N = 32
if BLOCK_M > 256:
BLOCK_N = 16
if BLOCK_M > 1024:
BLOCK_N = 8
def grid(meta):
return (triton.cdiv(N, meta["BLOCK_N"]),)
# fmt: off
k_sum_0[grid](
out, x,
x.stride(0),
M, N,
x.dtype == torch.float16,
BLOCK_M=BLOCK_M,
BLOCK_N=BLOCK_N,
num_stages=4,
)
# fmt: on
return out

View File

@@ -0,0 +1,40 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
import logging
from typing import Optional
import torch
logger = logging.getLogger("xformers")
_gpu_is_old: Optional[bool] = None
def gpu_capabilities_older_than_70() -> bool:
"""Return True if the GPU's compute capability is older than SM70."""
global _gpu_is_old
if _gpu_is_old is None:
for i in range(torch.cuda.device_count()):
major, _ = torch.cuda.get_device_capability(f"cuda:{i}")
if major < 7:
_gpu_is_old = True
if _gpu_is_old is None:
_gpu_is_old = False
return _gpu_is_old
SUPPORTED_CUDA_DEVICES = ["V100", "A100", "T4"]
def get_current_cuda_device():
current_device = str(torch.cuda.get_device_properties(torch.cuda.current_device()))
for device_str in SUPPORTED_CUDA_DEVICES:
if current_device.find(device_str) > 0:
return device_str
logger.warning("Unsupported device, Triton code generation may fail")
return "P100" # default to an old GPU

View File

@@ -0,0 +1,173 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
import ast
import copy
import functools
import linecache
import sys
from typing import Any, Dict, List
import triton
class _ForLoopUnroller(ast.NodeTransformer):
def __init__(self, target, inline_variables, loop_iter):
self.loop_iter = loop_iter
self.target = target
self.inline_variables = inline_variables
def visit_Name(self, node):
if node.id != self.target:
return node
return ast.Name(str(self.loop_iter))
def visit_Subscript(self, node):
# Pattern-matching `value[slice]`
if (
isinstance(node.slice, ast.Name)
and node.slice.id == self.target
and isinstance(node.value, ast.Name)
and node.value.id in self.inline_variables
):
return ast.Name(f"{node.value.id}{self.loop_iter}")
return node
class _VisitorUnrollKernel(ast.NodeTransformer):
def __init__(self, N):
self.inline_variables = set()
self.N = N
def visit_AnnAssign(self, node):
# Pattern-matching:
# var_name: "VAR_ARGS_ARRAY"
if (
node.value is None
and node.simple == 1
and isinstance(node.target, ast.Name)
and isinstance(node.annotation, ast.Constant)
and node.annotation.value == "VAR_ARGS_ARRAY"
):
self.inline_variables.add(node.target.id)
return []
if node.value is not None:
node.value = self.visit(node.value)
if node.annotation is not None:
node.annotation = self.visit(node.annotation)
if node.target is not None:
node.target = self.visit(node.target)
return node
def visit_arguments(self, node):
# Replace `args` annotated with `VAR_ARGS_ARRAY`
new_args = []
for arg in node.args:
if (
arg.annotation is not None
and isinstance(arg.annotation, ast.Constant)
and arg.annotation.value == "VAR_ARGS_ARRAY"
):
self.inline_variables.add(arg.arg)
new_args += [ast.arg(f"{arg.arg}{i}") for i in range(self.N)]
continue
new_args.append(arg)
if node.vararg is not None:
self.inline_variables.add(node.vararg.arg)
new_args += [ast.arg(f"{node.vararg.arg}{i}") for i in range(self.N)]
node.vararg = None
new_args += node.kwonlyargs
node.kwonlyargs = []
node.args = new_args
return node
def visit_For(self, node):
if (
not isinstance(node.iter, ast.Call)
or node.iter.func.id != "range"
or len(node.iter.args) != 1
or not isinstance(node.iter.args[0], ast.Call)
or node.iter.args[0].func.id != "len"
or len(node.iter.args[0].args) != 1
or node.iter.args[0].args[0].id not in self.inline_variables
):
node.body = [self.visit(x) for x in node.body]
return node
# We know we have to modify this loop
new_nodes = []
for i in range(self.N):
unroller = _ForLoopUnroller(
target=node.target.id,
inline_variables=self.inline_variables,
loop_iter=i,
)
for body in node.body:
body = copy.deepcopy(body)
new_node = ast.fix_missing_locations(unroller.visit(body))
new_node = self.visit(new_node)
new_nodes.append(new_node)
return new_nodes
# Hackfix to get access to get source-code for
# `exec`-created functions - see https://stackoverflow.com/a/69668999
_getlines_orig = None
_FILENAME_TO_SRC: Dict[str, str] = {}
def _monkey_patched_getlines(filename, module_globals=None):
if filename in _FILENAME_TO_SRC:
return _FILENAME_TO_SRC[filename]
else:
return _getlines_orig(filename, module_globals) # type: ignore
@functools.lru_cache(None)
def unroll_varargs(kernel, N: int):
"""
Specializes a triton kernel with variable number of inputs
to a specific number of inputs `N`.
NOTE: Because it's quite costly to call `triton.jit`,
we cache the returned value with `lru_cache`
"""
global _FILENAME_TO_SRC, _getlines_orig
k = triton.JITFunction(kernel.fn)
parsed = ast.parse(k.src)
nodeVisitor = _VisitorUnrollKernel(N=N)
parsed = nodeVisitor.visit(parsed)
parsed = ast.fix_missing_locations(parsed)
# NOTE: `ast.unparse` requires python 3.9+
if (sys.version_info.major, sys.version_info.minor) <= (3, 8):
raise RuntimeError("Error: This functionality requires python 3.9 or above")
new_src = ast.unparse(parsed) # type: ignore
# Now we want to `eval` the function, but we need all this
# boilerplate code to make sure triton can run `inspect.getsource`
fn_filename = f"<unroll_varargs-{kernel.fn.__name__}-{N}>"
# Create function given source
code = compile(new_src, fn_filename, "exec")
_locals: Dict[str, Any] = {}
exec(code, kernel.fn.__globals__, _locals)
assert len(_locals) == 1, len(_locals)
fn = next(iter(_locals.values()))
# Patch `getlines` only the first time
if not _FILENAME_TO_SRC:
_getlines_orig = linecache.getlines
linecache.getlines = _monkey_patched_getlines
_FILENAME_TO_SRC[fn_filename] = new_src
jitted_fn = triton.jit(fn)
jitted_fn.src = new_src
return jitted_fn
# Note: just import this to make mypy happy
# when annotating variables with `VAR_ARGS_ARRAY`
VAR_ARGS_ARRAY = List[Any]