First commit
This commit is contained in:
27
pkgs/xformers/triton/__init__.py
Normal file
27
pkgs/xformers/triton/__init__.py
Normal file
@@ -0,0 +1,27 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the BSD license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
|
||||
import torch
|
||||
|
||||
# The Triton kernels can only run on a CUDA device, so the fused layers are
# only imported when one is visible to PyTorch.
_triton_available = torch.cuda.is_available()

# Always define the public API: without this, `from xformers.triton import *`
# on a machine without CUDA would fall back to exporting every public name of
# this module (i.e. `torch`) instead of nothing.
__all__ = []

if _triton_available:
    try:
        from .dropout import FusedDropoutBias, dropout  # noqa
        from .fused_linear_layer import FusedLinear  # noqa
        from .layer_norm import FusedLayerNorm, layer_norm  # noqa
        from .softmax import log_softmax, softmax  # noqa

        __all__ = [
            "dropout",
            "softmax",
            "log_softmax",
            "FusedDropoutBias",
            "FusedLinear",
            "FusedLayerNorm",
            "layer_norm",
        ]
    except ImportError:
        # Best effort: one of the fused layers (or Triton itself) could not be
        # imported, expose an empty API rather than failing the package import.
        __all__ = []
|
||||
BIN
pkgs/xformers/triton/__pycache__/__init__.cpython-310.pyc
Normal file
BIN
pkgs/xformers/triton/__pycache__/__init__.cpython-310.pyc
Normal file
Binary file not shown.
BIN
pkgs/xformers/triton/__pycache__/dropout.cpython-310.pyc
Normal file
BIN
pkgs/xformers/triton/__pycache__/dropout.cpython-310.pyc
Normal file
Binary file not shown.
Binary file not shown.
BIN
pkgs/xformers/triton/__pycache__/k_activations.cpython-310.pyc
Normal file
BIN
pkgs/xformers/triton/__pycache__/k_activations.cpython-310.pyc
Normal file
Binary file not shown.
BIN
pkgs/xformers/triton/__pycache__/k_dropout.cpython-310.pyc
Normal file
BIN
pkgs/xformers/triton/__pycache__/k_dropout.cpython-310.pyc
Normal file
Binary file not shown.
Binary file not shown.
Binary file not shown.
BIN
pkgs/xformers/triton/__pycache__/k_layer_norm.cpython-310.pyc
Normal file
BIN
pkgs/xformers/triton/__pycache__/k_layer_norm.cpython-310.pyc
Normal file
Binary file not shown.
BIN
pkgs/xformers/triton/__pycache__/k_softmax.cpython-310.pyc
Normal file
BIN
pkgs/xformers/triton/__pycache__/k_softmax.cpython-310.pyc
Normal file
Binary file not shown.
BIN
pkgs/xformers/triton/__pycache__/k_sum.cpython-310.pyc
Normal file
BIN
pkgs/xformers/triton/__pycache__/k_sum.cpython-310.pyc
Normal file
Binary file not shown.
BIN
pkgs/xformers/triton/__pycache__/layer_norm.cpython-310.pyc
Normal file
BIN
pkgs/xformers/triton/__pycache__/layer_norm.cpython-310.pyc
Normal file
Binary file not shown.
BIN
pkgs/xformers/triton/__pycache__/softmax.cpython-310.pyc
Normal file
BIN
pkgs/xformers/triton/__pycache__/softmax.cpython-310.pyc
Normal file
Binary file not shown.
BIN
pkgs/xformers/triton/__pycache__/sum_strided.cpython-310.pyc
Normal file
BIN
pkgs/xformers/triton/__pycache__/sum_strided.cpython-310.pyc
Normal file
Binary file not shown.
BIN
pkgs/xformers/triton/__pycache__/utils.cpython-310.pyc
Normal file
BIN
pkgs/xformers/triton/__pycache__/utils.cpython-310.pyc
Normal file
Binary file not shown.
BIN
pkgs/xformers/triton/__pycache__/vararg_kernel.cpython-310.pyc
Normal file
BIN
pkgs/xformers/triton/__pycache__/vararg_kernel.cpython-310.pyc
Normal file
Binary file not shown.
243
pkgs/xformers/triton/dropout.py
Normal file
243
pkgs/xformers/triton/dropout.py
Normal file
@@ -0,0 +1,243 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the BSD license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
|
||||
# CREDITS: This is heavily inspired by the Triton dropout tutorial
|
||||
# https://raw.githubusercontent.com/openai/triton/master/python/tutorials/04-low-memory-dropout.py
|
||||
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import triton
|
||||
from torch.cuda.amp import custom_bwd, custom_fwd
|
||||
|
||||
from xformers.components.activations import Activation, build_activation
|
||||
from xformers.triton.k_activations import get_triton_activation_index
|
||||
from xformers.triton.k_dropout import k_dropout_bw, k_dropout_fw
|
||||
|
||||
# Tile sizes handed to the dropout kernels as the BLOCK_M (rows) and
# BLOCK_N (columns) meta-parameters; BLOCK_N also fixes how many seeds are drawn.
BLOCK_M = 32
BLOCK_N = 64  # NOTE: This should ideally be GPU dependent, big impact on perf
|
||||
|
||||
|
||||
# Helper to handle the SPMD launch grid and error cases
class _dropout(torch.autograd.Function):
    """
    Autograd function running the fused (bias +) (activation +) dropout
    through the `k_dropout_fw` / `k_dropout_bw` Triton kernels.

    The dropout mask is never stored: the seeds drawn in `forward` are saved
    and handed to the backward kernel, which regenerates its mask from them.
    """

    @staticmethod
    @custom_fwd(cast_inputs=torch.float16)
    def forward(ctx, x, p, bias, activation, trainable_bias):
        """
        Args:
            x: input tensor, all the leading dimensions are folded into rows.
            p: dropout probability, must be strictly positive.
            bias: optional tensor, same dtype as `x`, whose first dimension
                matches the last dimension of `x`.
            activation: value forwarded to the kernels as `ACTIVATION`.
            trainable_bias: whether a gradient is requested for `bias`.

        Returns:
            A tensor with the same shape as `x`.
        """
        # Soft-flatten an hypothetical 3rd dimension
        x_ = x.reshape(-1, x.shape[-1]).contiguous()
        y = torch.empty_like(x_)
        M, N = x_.shape

        assert bias is None or (bias.dtype == x.dtype and bias.shape[0] == N)
        assert p > 0.0

        def grid(meta):
            return (
                triton.cdiv(M, meta["BLOCK_M"]),
                triton.cdiv(N, meta["BLOCK_N"]),
            )

        N_BLOCK_N = triton.cdiv(N, BLOCK_N)

        # Generate one seed per block of BLOCK_N columns: the kernels read
        # their seed with the column program id.
        # The seeds are drawn in [0, 2**16), which fits the int32 buffer.
        seeds = torch.randint(65536, (N_BLOCK_N,), device=x.device, dtype=torch.int32)

        # fmt: off
        bias_ptr = bias if bias is not None else x_  # Possibly not being used

        k_dropout_fw[grid](
            y, x_,
            bias_ptr,
            seeds,
            y.stride(0),
            M, N,
            p,
            x.dtype == torch.float16,
            USE_BIAS=bias is not None,
            ACTIVATION=activation,
            BLOCK_M=BLOCK_M,
            BLOCK_N=BLOCK_N,
        )
        # fmt: on

        # The activation gradient is computed from the forward inputs
        if activation is not None:
            ctx.save_for_backward(seeds, bias, x)
        else:
            ctx.save_for_backward(seeds, bias, None)

        ctx.trainable_bias = bias is not None and trainable_bias
        ctx.activation = activation
        ctx.p = p

        return y.reshape_as(x)

    @staticmethod
    @custom_bwd
    def backward(
        ctx, grad_out
    ):  # pragma: no cover  # This is covered, but called from C++ and not tracked
        """
        Returns one gradient per `forward` argument:
        (grad_x, None, grad_bias or None, None, None).
        """
        (seeds, bias, inputs) = ctx.saved_tensors

        # Soft-flatten an hypothetical 3rd dimension
        grad_out_ = grad_out.reshape(-1, grad_out.shape[-1]).contiguous()
        grad_in = torch.empty_like(grad_out_)

        M, N = grad_out_.shape

        # Optional inputs to compute the activation contribution to the gradient
        assert inputs is not None or ctx.activation is None

        if inputs is None:
            inputs = grad_out_
        else:
            # The kernel only receives a row stride for the inputs and assumes
            # consecutive columns, so fold the leading dimensions and make sure
            # that the buffer is contiguous (no copy if it already is).
            inputs = inputs.reshape(-1, N).contiguous()

        # We split the problem in tiles:
        # - over M there will be a follow up reduction
        # - over N we compromise in between trying to use as much memory parallelism as possible,
        # (fill in the warps, there are 32 threads per warps, and 4 warps default), and not being too
        # big because of register spilling
        N_BLOCKS_M = triton.cdiv(M, BLOCK_M)

        if ctx.trainable_bias:
            # One partial sum per block of rows, reduced after the kernel
            grad_bias = torch.empty(
                (
                    N_BLOCKS_M,
                    N,
                ),
                device=grad_in.device,
                dtype=grad_in.dtype,
            )

        else:
            grad_bias = grad_in  # will not be used

        def grid(meta):
            # One program per (block of BLOCK_M rows, block of BLOCK_N columns),
            # which matches the forward launch grid
            return (
                N_BLOCKS_M,
                triton.cdiv(N, meta["BLOCK_N"]),
            )

        # fmt: off
        k_dropout_bw[grid](
            grad_in, grad_bias, grad_out_,
            inputs, bias if bias is not None else inputs,
            seeds,
            grad_out_.stride(0), inputs.stride(0),
            M, N,
            ctx.p,
            grad_in.dtype == torch.float16,
            USE_BIAS=bias is not None,
            ACTIVATION=ctx.activation,
            TRAINABLE_BIAS=ctx.trainable_bias,
            BLOCK_M=BLOCK_M,
            BLOCK_N=BLOCK_N,
        )
        # fmt: on

        # One entry per forward argument: x, p, bias, activation, trainable_bias
        return (
            grad_in.reshape_as(grad_out),
            None,
            torch.sum(grad_bias, dim=0) if ctx.trainable_bias else None,
            None,
            None,
        )
|
||||
|
||||
|
||||
def dropout(
    x: torch.Tensor,
    p: float,
    bias: Optional[torch.Tensor] = None,
    activation: Optional[Activation] = None,
):
    """
    Drop the values of `x` with probability `p`.

    A bias and an activation can optionally be applied before the dropout, in
    which case the whole computation is fused. `p` must lie in [0, 1]:
    `p == 1` returns zeros, and `p == 0` skips the Triton kernel entirely.
    """

    assert p <= 1.0 and p >= 0.0

    # Everything is dropped
    if p == 1.0:
        return torch.zeros_like(x)

    has_bias = bias is not None

    # Nothing is dropped: plain PyTorch bias and activation are enough
    if p == 0.0:
        out = x + bias if has_bias else x
        if activation is None:
            return out

        return build_activation(activation)(out)

    # General case, handled by the fused Triton kernel
    triton_activation = get_triton_activation_index(activation)
    wants_bias_grad = has_bias and bias.requires_grad
    return _dropout.apply(x, float(p), bias, triton_activation, wants_bias_grad)
|
||||
|
||||
|
||||
class FusedDropoutBias(torch.nn.Module):
    """
    A layer computing Dropout(Activation(x + bias)), the bias and the
    activation both being optional.

    The fused Triton kernel is used for CUDA inputs whose last dimension is
    larger than 512, while training with a non-zero dropout probability.
    Every other case goes through the equivalent PyTorch operators.
    """

    def __init__(
        self,
        p: float,
        bias_shape: Optional[int],
        activation: Optional[Activation] = None,
    ) -> None:
        super().__init__()

        self.p = float(p)

        assert (
            self.p < 1.0
        ), f"We don't want to drop all the values, most probably p={p} is not properly set"

        self.activation_type = activation

        # The bias is a plain tensor attribute, starting from zero
        if bias_shape is None:
            self.bias = None
        else:
            self.bias = torch.zeros(bias_shape, requires_grad=True)

        # The same activation in both flavours: kernel index and PyTorch callable
        self.activation = get_triton_activation_index(self.activation_type)
        self.activation_pytorch = build_activation(self.activation_type)

    def init_weights(self, *args, **kwargs):
        """Reset the bias, if there is one, to zero."""
        if self.bias is None:
            return

        with torch.no_grad():
            self.bias.fill_(0.0)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the optional bias and activation, then the dropout when training."""
        # Convenience, align the bias with the dtype and device of the input
        if self.bias is not None:
            self.bias = self.bias.to(dtype=x.dtype, device=x.device)  # type: ignore

        # No dropout at inference time
        p = self.p if self.training else 0.0

        # The kernel is slower than pytorch for small buffers
        big_enough = x.shape[-1] > 512

        if x.is_cuda and big_enough and p != 0.0:
            # The normal, Triton-backed path
            return _dropout.apply(x, p, self.bias, self.activation, True)

        # PyTorch fallback: non-cuda setup, small buffer or nothing to drop
        if self.bias is not None:
            x = x + self.bias
        x = self.activation_pytorch(x)

        if p > 0.0:
            return torch.nn.functional.dropout(x, p)
        return x
|
||||
119
pkgs/xformers/triton/fused_linear_layer.py
Normal file
119
pkgs/xformers/triton/fused_linear_layer.py
Normal file
@@ -0,0 +1,119 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the BSD license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import math
|
||||
from typing import Any, Optional
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.cuda.amp import custom_bwd, custom_fwd
|
||||
|
||||
from xformers.components.activations import Activation
|
||||
from xformers.triton.k_activations import get_triton_activation_index
|
||||
from xformers.triton.k_fused_matmul_bw import fused_matmul_backward
|
||||
from xformers.triton.k_fused_matmul_fw import fused_matmul
|
||||
|
||||
|
||||
class _fused_linear_triton(torch.autograd.Function):
    """
    Autograd function around the fused Triton linear layer: the matrix
    multiplication, bias and activation are handled by `fused_matmul`, and the
    gradients by `fused_matmul_backward`.
    """

    @staticmethod
    @custom_fwd(cast_inputs=torch.float16)
    def forward(
        ctx,
        x,
        weight,
        bias,
        activation,
        trainable_weight,
        trainable_bias,
        save_activation_inputs,
    ):
        """
        Args:
            x: input tensor.
            weight: weight of the linear layer.
            bias: optional bias.
            activation: activation identifier forwarded to the kernels.
            trainable_weight: whether a gradient is requested for `weight`.
            trainable_bias: whether a gradient is requested for `bias`.
            save_activation_inputs: forwarded to `fused_matmul`, which returns
                the activation inputs as its second result.

        Returns:
            The output of the fused layer.
        """

        # Kick the fused Triton kernel, handling bias and activation in one go
        y, activation_inputs = fused_matmul(
            x, weight, bias, activation, save_activation_inputs
        )

        ctx.activation = activation
        ctx.trainable_weight = trainable_weight
        ctx.trainable_bias = trainable_bias

        # Micro-optimization: saving these is not always needed (?)
        if x.requires_grad or ctx.trainable_weight or ctx.trainable_bias:
            ctx.save_for_backward(weight, activation_inputs, x)

        return y

    @staticmethod
    @custom_bwd
    def backward(
        ctx: Any, grad_out: torch.Tensor
    ) -> Any:  # pragma: no cover  # this is covered, but called directly from C++
        """
        Compute the gradients with respect to the input and, when they were
        flagged as trainable in the forward pass, to the weight and the bias.

        Returns one entry per `forward` argument; the non-tensor arguments
        get `None`.
        """
        (weight, activation_inputs, x) = ctx.saved_tensors

        grad_input, grad_weight, grad_bias = fused_matmul_backward(
            grad_out=grad_out,
            inputs=x,
            act_in=activation_inputs,
            weight=weight,
            trainable_weight=ctx.trainable_weight,
            trainable_bias=ctx.trainable_bias,
            activation_grad=ctx.activation,
        )

        # x, weight, bias, activation, trainable_weight, trainable_bias, save_activation_inputs
        return grad_input, grad_weight, grad_bias, None, None, None, None
|
||||
|
||||
|
||||
class FusedLinear(nn.Module):
    """
    A linear transform, in the spirit of torch.nn.Linear_, followed by an
    optional activation, both computed by a single kernel.
    The whole transform is :math:`y = activation(xA^T + b)`.

    This is typically significantly faster than PyTorch while using fp16 and
    non-sigmoid activations, as of September 2021.

    .. _torch.nn.Linear: https://pytorch.org/docs/stable/generated/torch.nn.Linear.html
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        bias: bool = False,
        activation: Optional[Activation] = None,
        **_,
    ):
        super().__init__()

        weight_buffer = torch.empty(out_features, in_features)
        self.weight = nn.Parameter(weight_buffer, requires_grad=True)

        if bias:
            self.bias = nn.Parameter(torch.empty(out_features), requires_grad=True)
        else:
            self.bias = None

        # The kernels identify the activation with an integer
        self._activation_index = get_triton_activation_index(activation)
        self.reset_parameters()

    def reset_parameters(self) -> None:
        """Draw new values for the weight and, if present, the bias."""
        torch.nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))

        if self.bias is None:
            return

        # The bias is sampled uniformly in +/- 1 / sqrt(fan_in)
        fan_in, _ = torch.nn.init._calculate_fan_in_and_fan_out(self.weight)
        bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
        torch.nn.init.uniform_(self.bias, -bound, bound)

    def forward(self, x):
        """Run the fused matmul + bias + activation on `x`."""
        trainable_weight = self.weight.requires_grad
        trainable_bias = self.bias.requires_grad if self.bias is not None else False

        # The activation inputs are only requested when training, with an
        # input requiring a gradient and an actual activation
        save_activation_inputs = (
            self.training and x.requires_grad and self._activation_index > 0
        )

        return _fused_linear_triton.apply(
            x,
            self.weight,
            self.bias,
            self._activation_index,
            trainable_weight,
            trainable_bias,
            save_activation_inputs,
        )
|
||||
152
pkgs/xformers/triton/k_activations.py
Normal file
152
pkgs/xformers/triton/k_activations.py
Normal file
@@ -0,0 +1,152 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the BSD license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import math
|
||||
from typing import Optional
|
||||
|
||||
import triton
|
||||
import triton.language as tl
|
||||
|
||||
from xformers.components import Activation
|
||||
|
||||
# sqrt(2 / pi): scaling factor of the tanh argument in the `gelu` kernel below
_kAlpha = math.sqrt(2.0 / math.pi)
|
||||
|
||||
|
||||
def get_triton_activation_index(activation: Optional[Activation]) -> int:
    """
    Map an activation to the integer which identifies it in the Triton kernels.

    `None` (no activation) maps to 0. An activation without a Triton
    counterpart raises a `KeyError`.
    """
    if activation is None:
        return 0

    index_per_activation = {
        Activation.ReLU: 1,
        Activation.LeakyReLU: 2,
        Activation.GeLU: 3,
        Activation.SquaredReLU: 4,
        Activation.SmeLU: 5,
        Activation.StarReLU: 6,
    }
    return index_per_activation[activation]
|
||||
|
||||
|
||||
@triton.jit
def tanh(x):
    # Hyperbolic tangent, written as a rescaled sigmoid: 2 * sigmoid(2x) - 1
    scaled_sigmoid = 2 * tl.sigmoid(2 * x)
    return scaled_sigmoid - 1
|
||||
|
||||
|
||||
@triton.jit
def cosh(x):
    # Hyperbolic cosine: (e^x + e^-x) / 2, with e^-x obtained as 1 / e^x
    positive_exp = tl.exp(x)
    negative_exp = 1.0 / positive_exp
    return (positive_exp + negative_exp) * 0.5
|
||||
|
||||
|
||||
# a Triton implementation of the most used activations
|
||||
# See for instance http://arxiv.org/abs/1606.08415 for an overview
|
||||
|
||||
# ReLU
|
||||
@triton.jit
def relu(x):
    """
    ReLU_ activation: `x` where it is non-negative, zero elsewhere

    .. _ReLU: https://pytorch.org/docs/stable/generated/torch.nn.ReLU.html
    """
    non_negative = x >= 0
    return tl.where(non_negative, x, 0.0)
|
||||
|
||||
|
||||
@triton.jit
def relu_grad(x):
    # Derivative of ReLU: 1 where x is non-negative, 0 elsewhere.
    # Only the sign of `x` matters here, not its magnitude.
    non_negative = x >= 0
    return tl.where(non_negative, 1.0, 0.0)
|
||||
|
||||
|
||||
@triton.jit
def squared_relu(x):
    """
    Squared ReLU activation, from the Primer_ paper: `x * x` for the strictly
    positive inputs, zero elsewhere

    .. _Primer: https://arxiv.org/abs/2109.08668
    """
    squared = x * x
    positive = x > 0.0
    return tl.where(positive, squared, 0.0)
|
||||
|
||||
|
||||
@triton.jit
def squared_relu_grad(x):
    # Derivative of the squared ReLU: 2x for the non-negative inputs, 0 elsewhere
    non_negative = x >= 0.0
    return tl.where(non_negative, 2 * x, 0.0)
|
||||
|
||||
|
||||
@triton.jit
def star_relu(x):
    """
    Star ReLU activation, from the "MetaFormer Baselines for Vision"_ paper:
    a squared ReLU scaled by 0.8944 and shifted by -0.4472

    .. _ "MetaFormer Baselines for Vision": https://arxiv.org/pdf/2210.13452.pdf
    """
    squared = x * x
    squared_relu_out = tl.where(x > 0.0, squared, 0.0)
    return 0.8944 * squared_relu_out - 0.4472
|
||||
|
||||
|
||||
@triton.jit
def star_relu_grad(x):
    # Derivative of the Star ReLU: 1.7888 * x (twice the 0.8944 scale) for the
    # non-negative inputs, 0 elsewhere
    non_negative = x >= 0.0
    return tl.where(non_negative, 1.7888 * x, 0.0)
|
||||
|
||||
|
||||
# Leaky ReLU
|
||||
@triton.jit
def leaky_relu(x):
    """
    LeakyReLU_ activation, with a fixed slope of 0.01 for the negative inputs

    .. _LeakyReLU: https://pytorch.org/docs/stable/generated/torch.nn.LeakyReLU.html
    """
    non_negative = x >= 0.0
    return tl.where(non_negative, x, 0.01 * x)
|
||||
|
||||
|
||||
@triton.jit
def leaky_relu_grad(x):
    # Derivative of the LeakyReLU: 1 for the non-negative inputs, 0.01 elsewhere
    non_negative = x >= 0.0
    return tl.where(non_negative, 1.0, 0.01)
|
||||
|
||||
|
||||
@triton.jit
def gelu(x):
    """
    GeLU_ activation - Gaussian error linear unit, tanh-based formulation

    .. _GeLU: https://arxiv.org/pdf/1606.08415.pdf
    """
    cubic = x + 0.044715 * x * x * x
    return 0.5 * x * (1 + tanh(_kAlpha * cubic))
|
||||
|
||||
|
||||
@triton.jit
def gelu_grad(x):
    # Derivative of the tanh-based GeLU
    # CREDITS: Fast implementation proposed in
    # https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/fused_bias_gelu.py#L30
    tanh_out = tanh(0.79788456 * x * (1 + 0.044715 * x * x))
    one_minus_tanh_sq = 1 - tanh_out * tanh_out
    inner_slope = 0.79788456 + 0.1070322243 * x * x
    return 0.5 * x * (one_minus_tanh_sq * inner_slope) + 0.5 * (1 + tanh_out)
|
||||
|
||||
|
||||
@triton.jit
def smelu(x):
    """
    SmeLU_ activation - Smooth ReLU with beta=2.0

    Quadratic for `|x| <= beta`, identity above `beta` and zero below `-beta`

    .. _SmeLU: https://arxiv.org/pdf/2202.06499.pdf
    """
    beta = 2.0

    outside = tl.where(x >= beta, x, 0.0)
    shifted = x + beta
    quadratic = shifted * shifted / (4.0 * beta)
    return tl.where(tl.abs(x) <= beta, quadratic, outside)
|
||||
|
||||
|
||||
@triton.jit
def smelu_grad(x):
    # Derivative of the SmeLU (beta=2.0): linear ramp for |x| <= beta,
    # 1 above beta and 0 below -beta
    beta = 2.0

    outside = tl.where(x >= beta, 1.0, 0.0)
    ramp = (beta + x) / (2.0 * beta)
    return tl.where(tl.abs(x) <= beta, ramp, outside)
|
||||
211
pkgs/xformers/triton/k_dropout.py
Normal file
211
pkgs/xformers/triton/k_dropout.py
Normal file
@@ -0,0 +1,211 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the BSD license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
|
||||
# CREDITS: This is heavily inspired by the Triton dropout tutorial
|
||||
# https://raw.githubusercontent.com/openai/triton/master/python/tutorials/04-low-memory-dropout.py
|
||||
|
||||
import triton
|
||||
import triton.language as tl
|
||||
|
||||
from xformers.triton.k_activations import (
|
||||
gelu,
|
||||
gelu_grad,
|
||||
leaky_relu,
|
||||
leaky_relu_grad,
|
||||
relu,
|
||||
relu_grad,
|
||||
smelu,
|
||||
smelu_grad,
|
||||
squared_relu,
|
||||
squared_relu_grad,
|
||||
)
|
||||
|
||||
# Autotuning candidates shared by the dropout kernels: only the number of
# warps varies, the block sizes are provided by the caller
_configs = [
    triton.Config({}, num_warps=warp_count) for warp_count in (1, 2, 4, 8, 16)
]
|
||||
|
||||
|
||||
# fmt: off
@triton.heuristics({"SIZE_RAND_BLOCK": lambda args: args["BLOCK_N"] * args["BLOCK_M"]})
@triton.autotune(
    configs=_configs,
    key=["M", "N", "is_fp16"],
)
@triton.jit
def k_dropout_fw(
    Y, X, BIAS, SEEDS,
    stride,
    M, N,
    p: tl.constexpr,
    is_fp16: tl.constexpr,  # autotune
    ACTIVATION: tl.constexpr,
    # Meta-parameters
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    SIZE_RAND_BLOCK: tl.constexpr,
    USE_BIAS: tl.constexpr,
):
    """
    Apply an optional bias, an optional activation, then dropout on an input tensor.
    Each program handles one (BLOCK_M, BLOCK_N) tile.

    Y : Output (M, N)
    X : Input (M, N)
    BIAS (N,), only read if USE_BIAS
    SEEDS : one seed per block of BLOCK_N columns, read with the column program id
    stride : row stride, shared by X and Y
    p : dropout probability
    ACTIVATION : 0 for none, 1 to 5 select one of the fused activations below
    """
    # fmt: on

    row_id = tl.program_id(axis=0)
    rows = row_id * BLOCK_M + tl.arange(0, BLOCK_M)

    col_id = tl.program_id(axis=1)
    cols = col_id * BLOCK_N + tl.arange(0, BLOCK_N)

    # pointers starting point
    # (consecutive columns are assumed to be one element apart)
    x_ptrs = X + rows[:, None] * stride + cols[None, :]
    y_ptrs = Y + rows[:, None] * stride + cols[None, :]

    # good to go, start the layer computations
    col_mask = cols[None, :] < N
    # the kept values are scaled by 1 / (1 - p)
    p_scale = 1. / (1. - p)
    if USE_BIAS:
        b_ptrs = BIAS + cols[None, :]
        bias = tl.load(b_ptrs, mask=cols[None, :] < N, other=0.)
    else:
        bias = x_ptrs  # will not be used

    # mask out the rows and columns which fall outside of the (M, N) buffer
    block_mask = (rows[:, None] < M) & col_mask
    x = tl.load(x_ptrs, mask=block_mask, other=0.0)

    # optionally apply a fused bias
    if USE_BIAS:
        x += bias

    # optional: fused activation (while the data is in shared memory)
    # NOTE(review): there is no branch for ACTIVATION == 6, which
    # get_triton_activation_index assigns to StarReLU; such a value leaves
    # x untouched here - confirm whether this is intended
    if ACTIVATION == 1:
        x = relu(x)
    elif ACTIVATION == 2:
        x = leaky_relu(x)
    elif ACTIVATION == 3:
        x = gelu(x)
    elif ACTIVATION == 4:
        x = squared_relu(x)
    elif ACTIVATION == 5:
        x = smelu(x)

    # get the random keep mask
    # NOTE(review): the seed only depends on col_id and the offsets are the
    # same for every program, so all the row tiles of a given column block
    # draw the same random values
    rand_offsets = tl.arange(0, SIZE_RAND_BLOCK)
    seed_int = tl.load(SEEDS + col_id)
    r = tl.rand(seed_int, rand_offsets)
    keep_mask = r > p

    # prune and normalize in one go
    keep = tl.view(keep_mask, x.shape)
    output = tl.where(keep, (x * p_scale).to(x.dtype), 0.)

    tl.store(y_ptrs, output, mask=block_mask)  # output
|
||||
|
||||
|
||||
# fmt: off
@triton.heuristics({"SIZE_RAND_BLOCK": lambda args: args["BLOCK_N"] * args["BLOCK_M"]})
@triton.autotune(
    configs=_configs,
    key=["M", "N", "is_fp16"],
)
@triton.jit
def k_dropout_bw(
    GRAD_IN, GRAD_BIAS, GRAD_OUT,
    INPUTS, BIAS, SEEDS,
    stride_grad, stride_inputs,
    M, N,
    p: tl.constexpr,
    is_fp16: tl.constexpr,  # autotune
    ACTIVATION: tl.constexpr,
    # Meta-parameters
    BLOCK_M: tl.constexpr,  # heuristics
    BLOCK_N: tl.constexpr,
    SIZE_RAND_BLOCK: tl.constexpr,
    TRAINABLE_BIAS: tl.constexpr,
    USE_BIAS: tl.constexpr,
):
    """
    Backward pass of the fused (bias +) (activation +) dropout.
    Each program handles one (BLOCK_M, BLOCK_N) tile.

    GRAD_OUT (M, N) : incoming gradient
    GRAD_BIAS : one row of N partial sums per block of BLOCK_M rows,
        written at `row_id * N`, only if TRAINABLE_BIAS
    GRAD_IN (M, N) : resulting gradient
    INPUTS : forward inputs, only read if ACTIVATION is set
    BIAS (N,), only read if USE_BIAS
    SEEDS : one seed per block of BLOCK_N columns, read with the column program id
    stride_grad : row stride, shared by GRAD_OUT and GRAD_IN
    stride_inputs : row stride of INPUTS
    p : dropout probability
    """
    # fmt: on

    row_id = tl.program_id(axis=0)
    rows = row_id * BLOCK_M + tl.arange(0, BLOCK_M)

    col_id = tl.program_id(axis=1)
    cols = col_id * BLOCK_N + tl.arange(0, BLOCK_N)

    # pointers starting point
    # (consecutive columns are assumed to be one element apart)
    grad_out_ptrs = GRAD_OUT + rows[:, None] * stride_grad + cols[None, :]
    grad_in_ptrs = GRAD_IN + rows[:, None] * stride_grad + cols[None, :]
    input_ptrs = INPUTS + rows[:, None] * stride_inputs + cols[None, :]

    # now go over the tiles
    grad_bias = tl.zeros((BLOCK_N,), dtype=tl.float32)
    col_mask = cols[None, :] < N
    # same scaling of the kept values as in the forward pass
    p_scale = 1. / (1. - p)

    if USE_BIAS:
        b_ptrs = BIAS + cols[None, :]
        bias = tl.load(b_ptrs, mask=col_mask, other=0.)

    # mask out the rows and columns which fall outside of the (M, N) buffers
    block_mask = (rows[:, None] < M) & col_mask
    grad_out = tl.load(grad_out_ptrs, mask=block_mask, other=0.)

    # optional: fused activation (while the data is in shared memory)
    if ACTIVATION:
        inputs = tl.load(input_ptrs, mask=block_mask, other=0.)

        # optionally apply a fused bias
        # (the activation gradient is evaluated on the biased inputs)
        if USE_BIAS:
            inputs += bias

        # NOTE(review): there is no branch for ACTIVATION == 6, which
        # get_triton_activation_index assigns to StarReLU; act_grad is never
        # assigned in that case although it is used right below
        if ACTIVATION == 1:
            act_grad = relu_grad(inputs)
        elif ACTIVATION == 2:
            act_grad = leaky_relu_grad(inputs)
        elif ACTIVATION == 3:
            act_grad = gelu_grad(inputs)
        elif ACTIVATION == 4:
            act_grad = squared_relu_grad(inputs)
        elif ACTIVATION == 5:
            act_grad = smelu_grad(inputs)

        grad_out *= act_grad

    # randomly prune (and scale) the resulting buffer, possibly a no-op
    # note that even if we did not save the mask from the FW pass, it is generated
    # from the same seeds, so the same drop mask is applied here
    rand_offsets = tl.arange(0, SIZE_RAND_BLOCK)
    seed_int = tl.load(SEEDS + col_id)
    r = tl.rand(seed_int, rand_offsets)
    r = tl.view(r, grad_out.shape)
    output = tl.where(r > p, (grad_out * p_scale).to(grad_out.dtype), 0.)

    # write-back
    tl.store(grad_in_ptrs, output, mask=block_mask)

    # optionally accumulate the bias gradient
    # (sum of this tile over its rows)
    if TRAINABLE_BIAS:
        grad_bias += tl.sum(output, axis=0)

    if TRAINABLE_BIAS:
        # one row of partial sums per row program, the caller reduces them
        grad_bias_ptr = GRAD_BIAS + row_id * N + cols
        tl.store(grad_bias_ptr, grad_bias, mask=cols < N)
|
||||
161
pkgs/xformers/triton/k_fused_matmul_bw.py
Normal file
161
pkgs/xformers/triton/k_fused_matmul_bw.py
Normal file
@@ -0,0 +1,161 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the BSD license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import triton
|
||||
import triton.language as tl
|
||||
|
||||
from xformers.triton.k_activations import (
|
||||
gelu_grad,
|
||||
leaky_relu_grad,
|
||||
relu_grad,
|
||||
smelu_grad,
|
||||
squared_relu_grad,
|
||||
star_relu_grad,
|
||||
)
|
||||
|
||||
|
||||
# fmt: off
@triton.autotune(
    configs=[
        triton.Config({"BLOCK_N": 64}, num_stages=4, num_warps=2),
        triton.Config({"BLOCK_N": 128}, num_stages=3, num_warps=2),
        triton.Config({"BLOCK_N": 256}, num_stages=3, num_warps=4),
        triton.Config({"BLOCK_N": 512}, num_stages=3, num_warps=4),
        triton.Config({"BLOCK_N": 1024}, num_stages=3, num_warps=4),
    ],
    key=["N"],
)
@triton.heuristics({
    'EVEN_N': lambda args: args["N"] % (args['BLOCK_N']) == 0,
})
@triton.jit
def kernel_bw(
    # Pointers to matrices
    GRAD_ACT, GRAD_OUT, ACT_INPUTS,
    # Matrix dimensions
    N,
    # The stride variables represent how much to increase the ptr by when moving by 1
    # element in a particular dimension. E.g. stride_am is how much to increase a_ptr
    # by to get the element one row down (A has M rows)
    # stride_gom is used for both GRAD_OUT and GRAD_ACT, stride_aim for ACT_INPUTS
    stride_gom, stride_aim,
    # Meta-parameters
    BLOCK_N: tl.constexpr,
    EVEN_N: tl.constexpr,
    ACTIVATION_GRAD: tl.constexpr,
):
    # fmt: on

    """
    Go over all the activation inputs, compute the corresponding gradient

    GRAD_ACT receives the activation gradient evaluated on ACT_INPUTS,
    multiplied by GRAD_OUT. ACTIVATION_GRAD selects the activation (1 to 6);
    any other value passes ACT_INPUTS through in place of a gradient.
    """

    # this kernel is relatively simple in terms of scheduling:
    # - per row (pid_m)
    # - each program a given chunk on the col axis,
    # since it's more effective memory and occupancy wise
    pid_m, pid_n = tl.program_id(axis=0), tl.program_id(axis=1)
    rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)

    # the memory addresses of the activation inputs handled by this program:
    # row pid_m, columns rn
    act_input_ptrs = ACT_INPUTS + pid_m * stride_aim + rn

    # compute the gradient which is related to this activation
    # (the loads are not masked when N is a multiple of BLOCK_N)
    if EVEN_N:
        act_in = tl.load(act_input_ptrs)
    else:
        act_in = tl.load(act_input_ptrs, mask=rn < N, other=0.0)

    if ACTIVATION_GRAD == 1:
        grad_act = relu_grad(act_in)
    elif ACTIVATION_GRAD == 2:
        grad_act = leaky_relu_grad(act_in)
    elif ACTIVATION_GRAD == 3:
        grad_act = gelu_grad(act_in)
    elif ACTIVATION_GRAD == 4:
        grad_act = squared_relu_grad(act_in)
    elif ACTIVATION_GRAD == 5:
        grad_act = smelu_grad(act_in)
    elif ACTIVATION_GRAD == 6:
        grad_act = star_relu_grad(act_in)
    else:
        grad_act = act_in

    # now read the incoming gradient, the backpropagated one is the multiple of both
    grad_out_ptrs = GRAD_OUT + pid_m * stride_gom + rn
    if EVEN_N:
        grad_out = tl.load(grad_out_ptrs)
    else:
        grad_out = tl.load(grad_out_ptrs, mask=rn < N)

    grad_act *= grad_out

    # write back result
    grad_act_ptrs = GRAD_ACT + pid_m * stride_gom + rn
    tl.store(grad_act_ptrs, grad_act, mask=rn < N)
|
||||
|
||||
|
||||
def fused_matmul_backward(
    grad_out: torch.Tensor,
    inputs: torch.Tensor,
    act_in: Optional[torch.Tensor],
    weight: torch.Tensor,
    trainable_weight: bool,
    trainable_bias: bool,
    activation_grad: int = 0,
):
    """
    Backward pass of the fused linear layer.

    The incoming gradient is first multiplied by the gradient of the
    activation (if any), then propagated to the inputs, and to the weight and
    the bias when they are trainable.

    Returns the tuple (grad_in, grad_weight, grad_bias): `grad_in` has the
    shape of `inputs`, the two others are `None` when not requested.

    .. note: Activation gradient needs to be a Triton kernel
    """

    # Work on a contiguous (rows, cols) gradient, so that the kernel does not
    # have to handle a stride over the columns (no copy if already contiguous)
    grad_out = grad_out.contiguous()

    if grad_out.ndim == 2:
        grad_out_ = grad_out
    else:
        grad_out_ = grad_out.flatten(0, -2)

    if inputs.ndim == 2:
        inputs_ = inputs
    else:
        inputs_ = inputs.flatten(0, -2)

    assert grad_out_.shape[1] == weight.shape[0], "Incompatible dimensions in between grad_out and weight"

    n_rows, n_cols = grad_out_.shape
    n_cols, _ = weight.shape

    # Fold the gradient of the activation in the incoming gradient
    if activation_grad > 0:
        grad_act = torch.empty_like(grad_out_)

        # Without saved activation inputs, the downstream gradient is
        # handed to the kernel in their place
        if act_in is None:
            act_in = grad_out_

        def grid(META):
            # One program per row and per block of BLOCK_N columns
            return (n_rows, triton.cdiv(n_cols, META["BLOCK_N"]))

        # fmt: off
        kernel_bw[grid](
            grad_act, grad_out_, act_in,            # data ptrs
            n_cols,                                 # shapes
            grad_act.stride(0), act_in.stride(0),   # strides
            ACTIVATION_GRAD=activation_grad,        # optional fused activation
        )
        # fmt: on

        # From now on the reference gradient is the one just before the activation
        grad_out_ = grad_act

    # The following ops can also be handled by pytorch
    grad_in = triton.ops.matmul(grad_out_, weight)

    grad_weight = None
    if trainable_weight:
        grad_weight = grad_out_.transpose(1, 0) @ inputs_

    grad_bias = None
    if trainable_bias:
        grad_bias = torch.sum(grad_out_, dim=0)

    return grad_in.reshape_as(inputs), grad_weight, grad_bias
|
||||
253
pkgs/xformers/triton/k_fused_matmul_fw.py
Normal file
253
pkgs/xformers/triton/k_fused_matmul_fw.py
Normal file
@@ -0,0 +1,253 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the BSD license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import triton
|
||||
import triton.language as tl
|
||||
|
||||
from xformers.triton.k_activations import (
|
||||
gelu,
|
||||
leaky_relu,
|
||||
relu,
|
||||
smelu,
|
||||
squared_relu,
|
||||
star_relu,
|
||||
)
|
||||
|
||||
# CREDITS: Initially inspired by the Triton tutorial on matrix multiplications
|
||||
|
||||
|
||||
def get_configs(block_k):
    """
    Build the list of autotuning candidates for a given K-tile size.

    Every candidate shares ``BLOCK_K = block_k`` and differs only in the
    (BLOCK_M, BLOCK_N) output tile and in the launch parameters.
    """
    # One row per candidate: (BLOCK_M, BLOCK_N, num_stages, num_warps)
    candidates = (
        (64, 32, 4, 2),
        (32, 64, 4, 2),
        (128, 64, 3, 4),
        (64, 128, 3, 4),
        (128, 128, 3, 4),
        (64, 256, 3, 4),
        (256, 64, 3, 4),
        # Fails on small GPUS, hence left out:
        # (128, 256, 3, 8),
        # (256, 128, 3, 8),
    )

    return [
        triton.Config(
            {"BLOCK_M": block_m, "BLOCK_N": block_n, "BLOCK_K": block_k},
            num_stages=stages,
            num_warps=warps,
        )
        for block_m, block_n, stages, warps in candidates
    ]
|
||||
|
||||
|
||||
# fmt: off
|
||||
@triton.autotune(
    configs=[c for block_k in [32, 64] for c in get_configs(block_k)],
    key=["M", "N", "K"],
)
@triton.heuristics({
    'EVEN_N': lambda args: args["N"] % (args['BLOCK_N']) == 0,
})
@triton.jit
def kernel_fma(
    # Pointers to matrices
    OUT, ACT_INPUTS, INPUT, WEIGHT, bias,
    # Matrix dimensions
    M, N, K,
    # Row strides, in elements:
    #   stride_om: between consecutive rows of OUT (also used for ACT_INPUTS)
    #   stride_im: between consecutive rows of INPUT
    #   stride_wn: between consecutive rows of WEIGHT
    # The innermost dimension of every buffer is addressed with unit stride.
    stride_om, stride_im,
    stride_wn,
    # Meta-parameters
    BLOCK_M: tl.constexpr, GROUP_M: tl.constexpr,
    BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
    EVEN_N: tl.constexpr,
    BIAS: tl.constexpr,
    SAVE_ACT_INPUTS: tl.constexpr,
    ACTIVATION: tl.constexpr,
    is_fp16: tl.constexpr,  # autotune (not read anywhere in the kernel body)
):
    # fmt: on

    """
    Kernel for computing Out = activation(Input x Weight^T + bias)

    - Input has shape (M, K)
    - Weight is addressed as (N, K): element (n, k) is read at
      WEIGHT + n * stride_wn + k, and tiles are loaded already transposed
    - Bias has shape (N,), only read when BIAS is set
    - Output has shape (M, N)
    - ActInputs (optional) has shape (M, N), only written when SAVE_ACT_INPUTS is set

    'ActInputs' optionally saves the Input x Weight^T + bias intermediate
    (i.e. the value before the activation) for backward computations

    ACTIVATION selects the fused activation: 1 relu, 2 leaky_relu, 3 gelu,
    4 squared_relu, 5 smelu, 6 star_relu. Any other value leaves the
    accumulator unchanged.

    EVEN_N is filled in by the heuristic above (N divisible by BLOCK_N) and
    lets the bias load skip its bounds mask.

    This kernel will consolidate over K
    """

    # programs are grouped together to improve L2 hit rate
    # the logic is that we'll consolidate over K. If the programs were not grouped,
    # then multiple cols/rows in the result would end up pulling in the same row and lines
    # from the inputs. By grouping the computation we ensure some data reuse, which the hardware
    # covers via the L2 cache
    pid = tl.program_id(axis=0)

    num_pid_m = tl.cdiv(M, BLOCK_M)  # number of program ids along the M axis
    num_pid_n = tl.cdiv(N, BLOCK_N)  # number of programs ids along the N axis
    num_pid_in_group = GROUP_M * num_pid_n  # number of programs in group
    group_id = pid // num_pid_in_group  # id of the group this program is in
    first_pid_m = group_id * GROUP_M  # row-id of the first program in the group
    GROUP_M = min(
        num_pid_m - first_pid_m, GROUP_M
    )  # if `num_pid_m` isn't divisible by `GROUP_M`, the last group is smaller

    # *within groups*, programs are ordered in a column-major order
    # row-id /col-id of the program in the *launch grid*
    pid_m = first_pid_m + (pid % GROUP_M)
    pid_n = (pid % num_pid_in_group) // GROUP_M

    # now compute the block that each program will go through
    # rm (resp. rn) denotes a range of indices
    # for rows (resp. col) of C
    rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    # NOTE: this value of rk is overwritten at the top of the K loop before it is ever read
    rk = tl.arange(0, BLOCK_K)

    # the memory addresses of elements can follow numpy broadcasting
    # (only the row offsets are set here, the K offsets are added inside the loop)
    input_ptrs = INPUT + rm[:, None] * stride_im
    weight_ptrs = WEIGHT + rn[None, :] * stride_wn

    # initialize and iteratively update accumulator
    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)

    if BIAS:
        # `bias` is rebound from a pointer to the loaded (BLOCK_N,) values
        if EVEN_N:
            bias = tl.load(bias + rn).to(tl.float32)
        else:
            bias = tl.load(bias + rn, mask=rn < N, other=0.0).to(tl.float32)
        # seed the accumulator with the bias, broadcast over the rows
        acc += bias[None, :]

    # block level matrix multiplication.
    # We fetch a block memory block from both inputs, matmul and accumulate, then repeat
    mask_rn = rn < N
    mask_rm = rm < M

    for i in range(0, K, BLOCK_K):
        rk = tl.arange(0, BLOCK_K) + i
        # a is a (BLOCK_M, BLOCK_K) tile, w a (BLOCK_K, BLOCK_N) tile;
        # out-of-range rows/cols/k are loaded as 0 so they do not contribute
        a = tl.load(input_ptrs + rk[None, :], mask=((rk[None, :] < K) & mask_rm[:, None]), other=0.0)
        w = tl.load(weight_ptrs + rk[:, None], mask=((rk[:, None] < K) & mask_rn[None, :]), other=0.0)

        acc += tl.dot(a, w)

    # recompute the row/col indices (same values as before the loop)
    rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)

    # optional: save the activation inputs
    if SAVE_ACT_INPUTS:
        act_in_ptrs = ACT_INPUTS + rm[:, None] * stride_om + rn[None, :]
        tl.store(act_in_ptrs, acc, mask=mask_rm[:, None] & mask_rn[None, :])

    # optional: fused activation, applied to the accumulator before it is written out
    if ACTIVATION == 1:
        acc = relu(acc)
    elif ACTIVATION == 2:
        acc = leaky_relu(acc)
    elif ACTIVATION == 3:
        acc = gelu(acc)
    elif ACTIVATION == 4:
        acc = squared_relu(acc)
    elif ACTIVATION == 5:
        acc = smelu(acc)
    elif ACTIVATION == 6:
        acc = star_relu(acc)

    # write back result
    out_ptrs = OUT + rm[:, None] * stride_om + rn[None, :]
    tl.store(out_ptrs, acc, mask=mask_rm[:, None] & mask_rn[None, :])
|
||||
|
||||
|
||||
# Activation needs to be a triton kernel
|
||||
def fused_matmul(
    x: torch.Tensor,
    weight: torch.Tensor,
    bias: Optional[torch.Tensor],
    activation=0,
    save_act_inputs: bool = False
):
    """
    Compute e = activation(x @ weight.T + bias) with the `kernel_fma` Triton kernel.

    `weight` is expected as (N, K) and contiguous, `x` as (..., K); all leading
    dimensions of `x` are flattened for the kernel and restored on the output.

    Returns a tuple `(outputs, act_inputs)`. `act_inputs` holds the
    pre-activation values, flattened to (M, N), when `save_act_inputs` is set,
    and is None otherwise.
    """

    # `contiguous()` hands back the same tensor when nothing needs to be done
    x = x.contiguous()

    if x.ndim == 2:
        x_ = x
    else:
        x_ = x.flatten(0, -2)

    has_bias = bias is not None

    assert (
        x_.shape[1] == weight.shape[1]
    ), f"Incompatible dimensions in between inputs and weight, {x_.shape} - {weight.shape}"
    assert not has_bias or bias.is_contiguous()
    assert (
        not has_bias or bias.shape[0] == weight.shape[0]
    ), "Incompatible dimensions in between weight and bias"
    assert weight.is_contiguous()

    M = x_.shape[0]
    N, K = weight.shape

    outputs = torch.empty((M, N), device=x.device, dtype=x.dtype)

    # When the kernel is not asked to save them, it still needs *some* pointer:
    # `x` is passed as a placeholder which is never written to.
    if save_act_inputs:
        act_inputs = torch.empty_like(outputs)
    else:
        act_inputs = x

    # Same trick for the bias pointer, the kernel skips it when BIAS is False
    bias_arg = bias if has_bias else x

    def launch_grid(META):
        # 1D launch: one program per (BLOCK_M x BLOCK_N) output tile
        tiles_m = triton.cdiv(M, META["BLOCK_M"])
        tiles_n = triton.cdiv(N, META["BLOCK_N"])
        return (tiles_m * tiles_n,)

    # fmt: off
    kernel_fma[launch_grid](
        outputs, act_inputs, x_, weight,    # data ptrs
        bias_arg,                           # placeholder if there is no bias
        M, N, K,                            # shapes
        outputs.stride(0), x_.stride(0),    # strides
        weight.stride(0),
        ACTIVATION=activation,              # optional fused activation
        BIAS=has_bias,                      # optional fused bias
        GROUP_M=8,                          # speed optimization: group the programs
        SAVE_ACT_INPUTS=save_act_inputs,
        is_fp16=x_.dtype == torch.float16
    )
    # fmt: on

    if x.ndim != 2:
        # restore the leading dimensions of the input
        outputs = outputs.reshape(*x.shape[:-1], N)

    return outputs, act_inputs if save_act_inputs else None
|
||||
179
pkgs/xformers/triton/k_layer_norm.py
Normal file
179
pkgs/xformers/triton/k_layer_norm.py
Normal file
@@ -0,0 +1,179 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the BSD license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
|
||||
# CREDITS: This comes almost as-is from the Triton layer norm tutorial
|
||||
# https://github.com/openai/triton/blob/master/python/tutorials/05-layer-norm.py
|
||||
|
||||
|
||||
import triton
|
||||
import triton.language as tl
|
||||
|
||||
|
||||
# fmt: off
|
||||
@triton.jit
def layer_norm_fw(X, Y, W, B, M, V, stride, N, eps, affine: tl.constexpr, BLOCK_SIZE_N: tl.constexpr):
    # fmt: on
    """
    Fused layernorm forward kernel.
    Each program instance normalizes one row of N elements (the last dimension).

    Compute
        y = (x - E(x))/(sqrt(var(x) + epsilon)) * gamma + beta

    - X, Y: input / output pointers, rows are `stride` elements apart in both
    - W, B: per-column gamma / beta pointers, only read when `affine` is set
    - M, V: per-row outputs, receive the mean and the reciprocal standard
      deviation (1 / sqrt(var + eps)) respectively
    - N: number of valid columns, must not exceed BLOCK_SIZE_N since a single
      block of BLOCK_SIZE_N columns is processed per row
    """

    row = tl.program_id(0)
    cols = tl.arange(0, BLOCK_SIZE_N)
    mask = cols < N

    # Move to this row
    x_ptrs = X + row * stride + cols
    # statistics are computed in fp32, padding columns are read as 0
    x = tl.load(x_ptrs, mask=mask, other=0.0).to(tl.float32)

    # Compute mean and variance
    mean = tl.sum(x, axis=0) / N
    # zero out the padding columns again, `x - mean` would make them non-zero
    x_zm = tl.where(mask, x - mean, 0.0)
    tl.store(M + row, mean)

    # biased variance (divided by N)
    x_var = tl.sum(x_zm * x_zm, axis=0) / N
    rstd = 1.0 / tl.sqrt(x_var + eps)

    # Normalize, optionally affine
    y = x_zm * rstd
    tl.store(V + row, rstd)

    # same value as the mask computed above
    mask = cols < N
    if affine:
        w = tl.load(W + cols, mask=mask, other=1.0)
        b = tl.load(B + cols, mask=mask, other=0.0)
        y = y * w + b

    y_ptrs = Y + row * stride + cols
    tl.store(y_ptrs, y, mask=mask)
|
||||
|
||||
|
||||
# Backward pass (DX + partial DW + partial DB)
|
||||
# fmt: off
|
||||
@triton.jit
def layer_norm_bwd_dx_fused(
    DX, DY, DW, DB,
    X, W, M, V,
    Lock, stride, N,
    # META-parameters
    affine: tl.constexpr,
    GROUP_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
):
    # fmt: on
    """
    Layernorm backward kernel: input gradient, plus partial weight/bias gradients.
    Each program instance handles one row of N elements.

    - DX, DY, X: gradient-in, gradient-out and forward input pointers. The same
      `stride` (elements between rows) is used for all three.
    - M, V: per-row mean and reciprocal standard deviation, as read here
    - W: per-column weight pointer, only read when `affine` is set
    - DW, DB: partial-sum buffers, addressed as GROUP_SIZE_M rows of N elements.
      Row `row % GROUP_SIZE_M` accumulates the contributions of this program.
    - Lock: integer buffer of 2 * GROUP_SIZE_M entries: the first GROUP_SIZE_M are
      the locks, the next GROUP_SIZE_M are "has been written once" flags.
      Both halves are compared against 0 below, so they are expected to be
      zero on entry.
    """

    # position of elements processed by this program
    row = tl.program_id(0)
    cols = tl.arange(0, BLOCK_SIZE_N)
    mask = cols < N

    # offset data pointers to start at the row of interest
    x_ptrs = X + row * stride + cols
    dy_ptrs = DY + row * stride + cols

    # load data to SRAM
    x = tl.load(x_ptrs, mask=mask, other=0)
    dy = tl.load(dy_ptrs, mask=mask, other=0)
    mean = tl.load(M + row)
    rstd = tl.load(V + row)

    # compute dx
    xhat = (x - mean) * rstd

    if affine:
        w = tl.load(W + cols, mask=mask, other=0)
        wdy = w * dy
    else:
        wdy = dy

    # padding columns must not contribute to the two row-wise means below
    xhat = tl.where(mask, xhat, 0.)
    wdy = tl.where(mask, wdy, 0.)
    mean1 = tl.sum(xhat * wdy, axis=0) / N
    mean2 = tl.sum(wdy, axis=0) / N
    dx = (wdy - (xhat * mean1 + mean2)) * rstd

    # write-back dx
    cols = tl.arange(0, BLOCK_SIZE_N)
    mask = cols < N  # re-materialize the mask to save registers
    dx_ptrs = DX + row * stride + cols
    tl.store(dx_ptrs, dx, mask=mask)

    if affine:
        # accumulate partial sums for dw/db
        partial_dw = (dy * xhat).to(w.dtype)
        partial_db = dy.to(w.dtype)

        # offset locks and weight/bias gradient pointer
        # each kernel instance accumulates partial sums for
        # DW and DB into one of GROUP_SIZE_M independent buffers
        # these buffers stay in the L2, which allow this kernel
        # to be fast
        lock_id = row % GROUP_SIZE_M
        Lock += lock_id
        Count = Lock + GROUP_SIZE_M

        # - wait for a lock on the accumulated dw/db
        #   (spin until the compare-and-swap 0 -> 1 succeeds)
        while tl.atomic_cas(Lock, 0, 1) == 1:
            pass
        count = tl.load(Count)

        # - we got the lock, accumulate this kernel's results with
        # the stored values.
        dw_ptrs = DW + lock_id * N + cols
        db_ptrs = DB + lock_id * N + cols

        if count == 0:
            # first store doesn't accumulate: the buffer content is not read,
            # only the flag is raised for the next program using this buffer
            tl.atomic_xchg(Count, 1)
        else:
            partial_dw += tl.load(dw_ptrs, mask=mask, other=0.)
            partial_db += tl.load(db_ptrs, mask=mask, other=0.)

        tl.store(dw_ptrs, partial_dw, mask=mask)
        tl.store(db_ptrs, partial_db, mask=mask)

        # release lock
        # NOTE(review): no explicit barrier separates the stores above from this
        # release - confirm the ordering is guaranteed on the targeted Triton version
        tl.atomic_xchg(Lock, 0)
|
||||
|
||||
|
||||
# Backward pass (total DW + total DB)
|
||||
# fmt: off
|
||||
@triton.jit
def layer_norm_bwd_dwdb(
    DW, DB, FINAL_DW, FINAL_DB,
    M, N,
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr
):
    # fmt: on
    """
    Reduce the partial weight/bias gradients over their first dimension.

    - DW, DB: partial sums, addressed as M rows of N elements (row-major,
      element (r, c) at offset r * N + c)
    - FINAL_DW, FINAL_DB: outputs, N elements each
    - M: number of partial rows to sum (not the number of samples)

    Each program instance produces BLOCK_SIZE_N consecutive output columns.
    """

    pid = tl.program_id(0)

    cols = pid * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    mask_cols = cols < N

    # fp32 accumulators, one (BLOCK_SIZE_M, BLOCK_SIZE_N) tile each
    dw = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
    db = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)

    # walk down the rows BLOCK_SIZE_M at a time, out-of-range entries count as 0
    for i in range(0, M, BLOCK_SIZE_M):
        rows = i + tl.arange(0, BLOCK_SIZE_M)
        offs = rows[:, None] * N + cols[None, :]
        mask_rm = rows < M

        dw += tl.load(DW + offs, mask=mask_rm[:, None] & mask_cols[None, :], other=0.0)
        db += tl.load(DB + offs, mask=mask_rm[:, None] & mask_cols[None, :], other=0.0)

    # collapse the tile over its rows to get one value per column
    sum_dw = tl.sum(dw, axis=0)
    sum_db = tl.sum(db, axis=0)

    # recompute the column indices and mask (same values as above)
    cols = pid * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    mask_cols = cols < N

    tl.store(FINAL_DW + cols, sum_dw, mask=mask_cols)
    tl.store(FINAL_DB + cols, sum_db, mask=mask_cols)
|
||||
175
pkgs/xformers/triton/k_softmax.py
Normal file
175
pkgs/xformers/triton/k_softmax.py
Normal file
@@ -0,0 +1,175 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the BSD license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
|
||||
import triton
|
||||
import triton.language as tl
|
||||
|
||||
# CREDITS: This is adapted from the vanilla Triton example. See https://openai.com/blog/triton/
|
||||
# and https://triton-lang.org/getting-started/tutorials/02-fused-softmax.html
|
||||
|
||||
|
||||
def get_depth(args):
    """Heuristic for the `depth` meta-parameter: K rounded up to a power of two."""
    reduction_size = args["K"]
    return triton.next_power_of_2(reduction_size)
|
||||
|
||||
|
||||
# autotune: Triton will test out these configurations, and automatically pick the fastest one.
|
||||
# heuristic: add arguments to the kernel call automatically given some heuristics. These arguments are passed in "meta"
|
||||
# fmt: off
|
||||
@triton.autotune(
    configs=[
        triton.Config({}, num_warps=1),
        triton.Config({}, num_warps=2),
        triton.Config({}, num_warps=4),
        triton.Config({}, num_warps=8),
        triton.Config({}, num_warps=16),
        triton.Config({}, num_warps=32),
    ],
    key=["K"],
)
@triton.heuristics(values={"depth": get_depth})
@triton.jit
def _softmax(
    Y, X, M,
    stride_ym, stride_yn,
    stride_xm, stride_xn,
    stride_mn,
    K,
    # Meta-params
    depth: tl.constexpr,
    causal: tl.constexpr,
    use_mask: tl.constexpr,
    log: tl.constexpr,
):
    # fmt: on

    """
    Fused softmax kernel over a 3d tensor.
    The softmax is applied over the last dimension, meaning that this is equivalent to torch.softmax(tensor, dim=-1)

    One program instance handles one (m, n) line of K elements, loaded as a
    single block of `depth` elements (`depth` is K rounded up to a power of two,
    see `get_depth`). The last dimension is addressed with unit stride.

    - Y, X: output / input pointers
    - M: additive mask pointer, only read when `use_mask` is set. It is indexed
      by n and k only, so the same mask is applied for every m.
    - causal: if set, elements with k > n are treated as -inf
    - log: if set, compute log-softmax instead of softmax

    Note, if the last dimension is large, say 128K elements, the kernel compile time can shoot up to many minutes when
    the kernel is run for the first time.
    """

    m = tl.program_id(0)
    n = tl.program_id(1)

    # col indices
    k = tl.arange(0, depth)

    # the memory address of all the elements that we want to load can be computed as follows
    x_ptrs = X + m * stride_xm + n * stride_xn + k

    # load input data; pad out-of-bounds elements with -inf
    # so that they contribute exp(-inf) = 0 to the sum
    io_mask = k < K

    # Causal - 1: skip on the loads directly
    if causal:
        io_mask = io_mask & (k <= n)

    x = tl.load(x_ptrs, mask=io_mask, other=float("-inf")).to(tl.float32)

    # Causal - 2: enforce correctness over a couple of misloaded values
    if causal:
        off = float("-inf")
        off = off.to(x.dtype)  # type: ignore
        x = tl.where(k > n, off, x)

    if use_mask:
        mask_ptrs = M + n * stride_mn + k
        add_mask = tl.load(mask_ptrs, io_mask, other=float("-inf")).to(tl.float32)
        x += add_mask

    # compute numerically-stable softmax
    z = x - tl.max(x, axis=0)
    num = tl.exp(z)
    denom = tl.sum(num, axis=0)

    if log:
        y = z - tl.log(denom)
    else:
        y = num / denom

    # write back to Y.
    # we only write once, hence the "fused" softmax naming
    y_ptrs = Y + m * stride_ym + n * stride_yn + k

    # technically we could write only the lower triangular matrix in the causal case
    # but this is deemed too error prone
    tl.store(y_ptrs, y, mask=k < K)
|
||||
|
||||
|
||||
# fmt: off
|
||||
@triton.autotune(
    configs=[
        triton.Config({}, num_warps=1),
        triton.Config({}, num_warps=2),
        triton.Config({}, num_warps=4),
        triton.Config({}, num_warps=8),
        triton.Config({}, num_warps=16),
    ],
    key=["K"],
)
@triton.jit
def _softmax_backward(
    GradIn, GradOut, Out,
    stride_bm, stride_bn,
    stride_gm, stride_gn,
    stride_om, stride_on,
    K,
    # meta-params
    depth: tl.constexpr,
    causal: tl.constexpr,
    log: tl.constexpr,
):
    # fmt: on

    """
    Compute the softmax gradients.

    One program instance handles one (m, n) line of K elements, loaded as a
    single block of `depth` elements. The last dimension is addressed with unit stride.

    - GradIn: output pointer (gradient with respect to the softmax input),
      addressed with stride_bm / stride_bn
    - GradOut: incoming gradient pointer, addressed with stride_gm / stride_gn
    - Out: forward output pointer (softmax or log-softmax values),
      addressed with stride_om / stride_on
    - depth: block size over the last dimension. No heuristic is attached to
      this kernel, so it presumably has to be passed by the caller.
    - causal: if set, elements with k > n are treated as 0
    - log: if set, `Out` holds log-softmax values

    ..Note: Not autotuning for now because this would lead to broken accumulated gradients
    NOTE(review): the kernel *is* decorated with `triton.autotune` above, so the
    note above looks stale - confirm which one is intended.
    """

    m = tl.program_id(0)
    n = tl.program_id(1)

    # col indices
    k = tl.arange(0, depth)

    # the memory address of all the elements that we want to load can be computed as follows
    grad_out_ptrs = GradOut + m * stride_gm + n * stride_gn + k
    out_ptrs = Out + m * stride_om + n * stride_on + k

    # load input data; pad out-of-bounds elements with 0
    io_mask = k < K

    # Causal - 1: skip on the loads directly
    if causal:
        io_mask = io_mask & (k <= n)

    g = tl.load(grad_out_ptrs, mask=io_mask, other=float(0)).to(tl.float32)
    o = tl.load(out_ptrs, mask=io_mask, other=float(0)).to(tl.float32)

    # Causal - 2: enforce correctness over a couple of misloaded values
    if causal:
        zero = float(0)
        zero = zero.to(g.dtype)  # type: ignore
        g = tl.where(k > n, zero, g)
        o = tl.where(k > n, zero, o)

    if log:
        # log-softmax: grad_in = g - softmax * sum(g), with softmax = exp(o)
        s = tl.sum(g, 0)
        grad_in = g - tl.exp(o) * s
    else:
        # Step 1: Compute the intermediate sum used for the gradient
        s = tl.sum(g * o, 0)

        # Step 2: Compute the gradients
        grad_in = o * (g - s)

    # write back to the input gradients
    # technically we could write only the lower triangular matrix in the causal case
    # but this is deemed too error prone
    grad_in_ptrs = GradIn + m * stride_bm + n * stride_bn + k
    tl.store(grad_in_ptrs, grad_in, mask=k < K)
|
||||
55
pkgs/xformers/triton/k_sum.py
Normal file
55
pkgs/xformers/triton/k_sum.py
Normal file
@@ -0,0 +1,55 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the BSD license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import triton
|
||||
import triton.language as tl
|
||||
|
||||
|
||||
# fmt: off
|
||||
@triton.jit
def k_sum_0(
    Y, X,
    stride_xm,
    M, N,
    is_fp16,
    # META-params
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    # fmt: on

    """
    Sum a 2d tensor over the first (strided) dimension.
    This extracts some speed through a parallel sum across the second dimension

    - X: input pointer, M rows of N elements, rows are `stride_xm` elements apart
    - Y: output pointer, N elements
    - is_fp16: not read anywhere in the kernel body

    Each program instance produces BLOCK_N consecutive output columns, walking
    down the rows BLOCK_M at a time.
    """

    # partial row indices. We'll reduce over this dimension
    m = tl.arange(0, BLOCK_M)

    # To get some extra parallelization, we handle several columns in the same thread block
    rn = tl.program_id(axis=0) * BLOCK_N + tl.arange(0, BLOCK_N)

    # the memory address of all the elements that we want to load can be computed as follows
    x_ptrs = X + m[:, None] * stride_xm + rn[None, :]
    x_sum = tl.zeros((BLOCK_N,), dtype=tl.float32)

    # number of row tiles, i.e. ceil(M / BLOCK_M)
    tiles = M // BLOCK_M
    if M % BLOCK_M > 0:
        tiles += 1

    col_mask = (rn[None, :] < N)

    for _ in range(tiles):
        # load input data; pad out-of-bounds elements with 0
        # NOTE: make sure to accumulate in fp32 to prevent a trivial overflow
        # (the running total `x_sum` is fp32; the per-tile `tl.sum` is applied
        # to the values as loaded)
        mask = (m[:, None] < M) & col_mask
        x = tl.load(x_ptrs, mask=mask, other=0.0)
        x_sum += tl.sum(x, 0)

        # move the load pointer
        x_ptrs += BLOCK_M * stride_xm
        m += BLOCK_M  # update the mask check

    tl.store(Y + rn, x_sum, mask=rn < N)
|
||||
236
pkgs/xformers/triton/layer_norm.py
Normal file
236
pkgs/xformers/triton/layer_norm.py
Normal file
@@ -0,0 +1,236 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the BSD license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
# CREDITS: the underlying kernel comes straight from the Triton tutorials
|
||||
# see https://github.com/openai/triton/blob/master/python/tutorials/05-layer-norm.py
|
||||
|
||||
import logging
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import triton
|
||||
from torch.cuda.amp import custom_bwd, custom_fwd
|
||||
|
||||
from xformers.triton.k_layer_norm import (
|
||||
layer_norm_bwd_dwdb,
|
||||
layer_norm_bwd_dx_fused,
|
||||
layer_norm_fw,
|
||||
)
|
||||
|
||||
# Shared "xformers" logger, used below for one-off warnings
logger = logging.getLogger("xformers")

# Read once, when `_LayerNorm` is defined, by the `custom_fwd(cast_inputs=...)`
# decorator of its forward: True requests a cast of the inputs to fp16.
_triton_layernorm_fp16_enabled = False  # NOTE: PyTorch keeps layernorm as fp32

# Raised by `_LayerNorm.forward` after its non-contiguous input warning and by
# `layer_norm` after a kernel failure. `layer_norm` stops using the Triton
# kernel once this is True.
_triton_registered_warnings = False
|
||||
|
||||
|
||||
class _LayerNorm(torch.autograd.Function):
    """
    Autograd binding for the Triton layer norm kernels.

    The normalization is applied over the last dimension of `x`; all the other
    dimensions are flattened into rows for the kernels.
    """

    @staticmethod
    @custom_fwd(cast_inputs=torch.float16 if _triton_layernorm_fp16_enabled else None)
    def forward(ctx, x, weight, bias, eps):
        """
        Normalize `x` over its last dimension, optionally scaled by `weight`
        and shifted by `bias` (the affine transform is applied when `weight`
        is not None).

        Raises a RuntimeError if one row does not fit in a single 64KB block.
        """
        # catch eps being too small if the tensors are fp16
        if x.dtype == torch.float16:
            eps = max(eps, 1.6e-5)

        # allocate output
        y = torch.empty_like(x)

        # reshape input data into 2D tensor
        x_arg = x.reshape(-1, x.shape[-1])
        M, N = x_arg.shape

        # allocate mean and std, they'll be used in the backward pass.
        # They must live on the same device as `x`: a bare "cuda" would pick
        # the *current* device, which can differ from the one holding `x`.
        mean = torch.empty((M,), dtype=torch.float32, device=x.device)
        rstd = torch.empty((M,), dtype=torch.float32, device=x.device)

        # Less than 64KB per feature: enqueue fused kernel
        MAX_FUSED_SIZE = 65536 // x.element_size()
        BLOCK_SIZE_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))
        if N > BLOCK_SIZE_N:
            raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.")

        if not x_arg.is_contiguous() or not y.is_contiguous():
            global _triton_registered_warnings
            if not _triton_registered_warnings:
                logger.warning(
                    "Non-contiguous input tensor found. Making it contiguous,"
                    + " but could have perf or trainer implications"
                )

                _triton_registered_warnings = True

            x_arg = x_arg.contiguous()
            y = y.contiguous()

        # heuristics for number of warps.
        num_warps = min(max(BLOCK_SIZE_N // 256, 1), 16)

        # enqueue kernel
        # fmt: off
        layer_norm_fw[(M,)](
            x_arg, y, weight, bias, mean, rstd,
            x_arg.stride(0),
            N,
            eps,
            num_warps=num_warps,
            BLOCK_SIZE_N=BLOCK_SIZE_N,
            affine=weight is not None
        )
        # fmt: on

        ctx.save_for_backward(x, mean, rstd, weight)
        ctx.BLOCK_SIZE_N = BLOCK_SIZE_N
        ctx.num_warps = num_warps

        return y.reshape_as(x)

    @staticmethod
    @custom_bwd
    def backward(
        ctx, dy
    ):  # pragma: no cover  # this is covered, but called directly from C++
        """
        Return the gradients for (x, weight, bias, eps); eps gets None.
        """
        x, mean, rstd, weight = ctx.saved_tensors

        # flatten the batch dimension, if any.
        # We're interested in 'samples' x norm_dimension
        # The kernel uses one row stride for x, dy and dx: since dy and dx are
        # contiguous, x has to be contiguous as well (the saved tensor is the
        # original input, which may be a strided view).
        x = x.reshape(-1, x.size(-1)).contiguous()
        M, N = x.size()

        # heuristics for amount of parallel reduction stream for DG/DB
        GROUP_SIZE_M = 32
        if N <= 8192:
            GROUP_SIZE_M = 64
        if N <= 4096:
            GROUP_SIZE_M = 96
        if N <= 2048:
            GROUP_SIZE_M = 128
        if N <= 1024:
            GROUP_SIZE_M = 256

        if dy.dtype == torch.float32:
            GROUP_SIZE_M = GROUP_SIZE_M // 2

        # allocate output
        # (locks and counters are expected to start at zero by the kernel)
        locks = torch.zeros(2 * GROUP_SIZE_M, dtype=torch.int32, device=x.device)
        t_args = {"dtype": x.dtype, "device": x.device}
        # The partial buffers are zero-initialized: only the rows
        # `row % GROUP_SIZE_M` for existing rows are written by the first
        # kernel, but all GROUP_SIZE_M of them are summed by the second one.
        # With fewer than GROUP_SIZE_M rows, uninitialized memory would
        # otherwise end up in dw/db.
        _dw = torch.zeros((GROUP_SIZE_M, x.size(-1)), **t_args)
        _db = torch.zeros_like(_dw)
        dw = torch.empty((x.size(-1),), **t_args)
        db = torch.empty_like(dw)
        dy = dy.contiguous()
        dx = torch.empty_like(dy)

        # Check the tensor shapes and layouts
        # we suppose in the kernel that they have the same size and are contiguous
        assert (
            dy.numel() == x.numel()
        ), "Something is wrong in the backward graph, possibly because of an inplace operation after the layernorm"

        # enqueue kernel using forward pass heuristics
        # also compute partial sums for DW and DB
        num_warps = ctx.num_warps

        # fmt: off
        layer_norm_bwd_dx_fused[(M,)](
            dx, dy, _dw, _db, x,
            weight if weight is not None else x,    # placeholder, not read if not affine
            mean, rstd,
            locks,
            x.stride(0),
            N,
            affine=weight is not None,
            GROUP_SIZE_M=GROUP_SIZE_M,
            BLOCK_SIZE_N=ctx.BLOCK_SIZE_N,
            num_warps=num_warps
        )
        # fmt: on

        def grid(meta):
            return [triton.cdiv(N, meta["BLOCK_SIZE_N"])]

        # accumulate partial sums in separate kernel
        # (the "M" of this kernel is the number of partial buffers)
        # fmt: off
        layer_norm_bwd_dwdb[grid](
            _dw, _db, dw, db,
            GROUP_SIZE_M,
            N,
            BLOCK_SIZE_M=32,
            BLOCK_SIZE_N=64
        )
        # fmt: on

        dx = dx.reshape_as(dy)
        return dx, dw, db, None
|
||||
|
||||
|
||||
class FusedLayerNorm(nn.Module):
    """
    Layer normalization module, in the spirit of torch.nn.LayerNorm_,
    which dispatches to the `layer_norm` function of this file.

    This implementation should be measurably faster than the default PyTorch layernorm (as of PyTorch 1.9),
    both for training and inference workloads.

    .. NOTE: Computations under Torch AMP are kept as float32 by default. This is controlled by the
        `_triton_layernorm_fp16_enabled` flag of `xformers.triton.layer_norm`, which is read when
        that module is imported.

    .. _torch.nn.LayerNorm: https://pytorch.org/docs/stable/generated/torch.nn.LayerNorm.html

    """

    def __init__(self, normalized_shape, affine=True, eps=1e-06):
        super().__init__()

        if not affine:
            # no learnable scale / shift
            self.weight = None
            self.bias = None
        else:
            # identity transform to begin with: scale of 1, shift of 0
            self.weight = nn.Parameter(torch.ones(normalized_shape))
            self.bias = nn.Parameter(torch.zeros(normalized_shape))

        self.epsilon = eps

    def forward(self, x):
        return layer_norm(x, weight=self.weight, bias=self.bias, eps=self.epsilon)

    def init_weights(self, *args, **kwargs):
        """Reset the affine parameters, if any, to the identity transform."""
        with torch.no_grad():
            for param, value in ((self.weight, 1.0), (self.bias, 0.0)):
                if param is not None:
                    param.fill_(value)
|
||||
|
||||
|
||||
def layer_norm(
    x: torch.Tensor,
    weight: Optional[torch.Tensor] = None,
    bias: Optional[torch.Tensor] = None,
    eps: float = 1e-06,
) -> torch.Tensor:
    r"""Applies normalization over a mini batch of inputs

    The normalization runs over the last dimension of `x`. The Triton kernel
    is used for CUDA inputs which come with both a `weight` and a `bias`;
    everything else goes through `torch.nn.functional.layer_norm`.

    If the Triton kernel raises a RuntimeError, a warning is logged, the
    kernel is deactivated for the rest of the process and the PyTorch
    implementation is used instead.
    """

    global _triton_registered_warnings

    try:
        # Per-call tensor checks first, then the process-wide switches.
        # NOTE(review): `_triton_registered_warnings` is also raised by
        # `_LayerNorm.forward` when it warns about a non-contiguous input, so
        # that warning turns the Triton path off as well - confirm this is intended.
        if (
            x.is_cuda
            and weight is not None
            and bias is not None
            and torch.cuda.is_available()
            and not _triton_registered_warnings
        ):
            return _LayerNorm.apply(x, weight, bias, eps)
    except RuntimeError as e:
        # Catch cases where the current GPU does not have enough registers to hold a full tensor line
        # fallback to PyTorch's implementation, which streams the tensor in and out
        _triton_registered_warnings = True
        logger.warning(
            "Triton layernorm kernel register spillover or invalid image caught. "
            "Deactivating this kernel, please file an issue in the xFormers repository"
        )
        logger.warning(e)

    return torch.nn.functional.layer_norm(
        x, [x.shape[-1]], weight=weight, bias=bias, eps=eps
    )
|
||||
203
pkgs/xformers/triton/softmax.py
Normal file
203
pkgs/xformers/triton/softmax.py
Normal file
@@ -0,0 +1,203 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the BSD license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
|
||||
import logging
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import triton
|
||||
from torch.cuda.amp import custom_bwd, custom_fwd
|
||||
|
||||
from xformers.triton.k_softmax import _softmax, _softmax_backward
|
||||
|
||||
# CREDITS: This is adapted from the vanilla Triton example. See https://openai.com/blog/triton/
|
||||
# and https://triton-lang.org/getting-started/tutorials/02-fused-softmax.html
|
||||
|
||||
|
||||
# Shared "xformers" logger
logger = logging.getLogger("xformers")

# Read once, when `_softmax_triton` is defined, by the
# `custom_fwd(cast_inputs=...)` decorator of its forward: True requests a cast
# of the inputs to fp16.
_triton_softmax_fp16_enabled = False  # NOTE: PyTorch keeps softmax as fp32

# Presumably a "warning already logged" switch, like its layer norm
# counterpart - TODO confirm against the code using it further down.
_triton_registered_warnings = False
|
||||
|
||||
|
||||
# Helper to handle the SPMD launch grid and error cases
|
||||
class _softmax_triton(torch.autograd.Function):
|
||||
@staticmethod
|
||||
@custom_fwd(cast_inputs=torch.float16 if _triton_softmax_fp16_enabled else None)
|
||||
def forward(ctx, x, mask, log_outputs, causal):
|
||||
"""
|
||||
Fused softmax implementation, using the Triton programming model.
|
||||
This only supports a reduction over the last dimension for now
|
||||
"""
|
||||
|
||||
# Handle 2D/3D tensors
|
||||
x_ = x.unsqueeze(0) if x.ndim == 2 else x
|
||||
x_ = x_.flatten(0, -3)
|
||||
|
||||
if not x_.is_contiguous():
|
||||
x_ = x_.contiguous()
|
||||
|
||||
y = torch.empty_like(x_)
|
||||
assert (
|
||||
y.stride(2) == 1 and x_.stride(2) == 1
|
||||
), f"{x.shape} - {x_.shape} - {x_.stride()}"
|
||||
|
||||
# SPMD launch grid
|
||||
grid_2d = (
|
||||
x_.shape[0],
|
||||
x_.shape[1],
|
||||
)
|
||||
|
||||
# enqueue GPU kernel
|
||||
use_mask = True
|
||||
if mask is None:
|
||||
# placeholder, will not be used
|
||||
mask = x_
|
||||
use_mask = False
|
||||
else:
|
||||
# Make sure that the mask is binary
|
||||
assert mask.dtype == x.dtype, "An additive mask is requested"
|
||||
|
||||
_softmax[grid_2d](
|
||||
y,
|
||||
x_,
|
||||
mask,
|
||||
y.stride(0),
|
||||
y.stride(1),
|
||||
x_.stride(0),
|
||||
x_.stride(1),
|
||||
mask.stride(0),
|
||||
x_.shape[2],
|
||||
log=log_outputs,
|
||||
use_mask=use_mask,
|
||||
causal=causal,
|
||||
)
|
||||
|
||||
ctx.save_for_backward(y)
|
||||
ctx.log_outputs = log_outputs
|
||||
ctx.causal = causal
|
||||
return y.reshape_as(x)
|
||||
|
||||
@staticmethod
|
||||
@custom_bwd
|
||||
def backward(
|
||||
ctx, grad_out
|
||||
): # pragma: no cover # this is covered, but called directly from C++
|
||||
(out,) = ctx.saved_tensors
|
||||
|
||||
# Handle 2D/3D tensors
|
||||
grad_out_ = grad_out.unsqueeze(0) if grad_out.ndim == 2 else grad_out
|
||||
grad_out_ = grad_out_.flatten(0, -3)
|
||||
|
||||
# SPMD launch grid
|
||||
grid_2d = (
|
||||
grad_out_.shape[0],
|
||||
grad_out_.shape[1],
|
||||
)
|
||||
|
||||
depth = triton.next_power_of_2(grad_out_.shape[2])
|
||||
grad_in = torch.empty_like(
|
||||
out
|
||||
) # torch.zeros is measurably slower, we'll zero out in the kernel
|
||||
|
||||
# Make sure that the tensor are contiguous
|
||||
grad_in, grad_out, out = map(lambda x: x.contiguous(), [grad_in, grad_out, out])
|
||||
|
||||
# fmt: off
|
||||
_softmax_backward[grid_2d](
|
||||
grad_in, grad_out_, out,
|
||||
grad_in.stride(0), grad_in.stride(1),
|
||||
grad_out_.stride(0), grad_out_.stride(1),
|
||||
out.stride(0), out.stride(1),
|
||||
out.shape[2],
|
||||
depth=depth,
|
||||
log=ctx.log_outputs,
|
||||
causal=ctx.causal
|
||||
)
|
||||
# fmt: on
|
||||
return grad_in.reshape_as(grad_out), None, None, None
|
||||
|
||||
|
||||
def softmax(
    x: torch.Tensor,
    mask: Optional[torch.Tensor] = None,
    causal: bool = False,
) -> torch.Tensor:
    r"""Compute the softmax of ``x`` along its last dimension.

    Every slice along the last dimension is rescaled so that its elements
    lie in the range [0, 1] and sum to 1:

    .. math::
        \text{Softmax}(x_{i}) = \frac{\exp(x_i)}{\sum_j \exp(x_j)}

    .. warning:: softmax is computed on the last dimension of the input tensor.

    Args:
        x: input tensor.
        mask: optional additive mask; when Triton is used, applying it is
            fused into the softmax computation.
        causal: optional performance optimization, if Triton is used and the
            attention is causal.

    Returns:
        a Tensor of the same dimension and shape as the input with
        values in the range [0, 1] and sum to 1
    """
    # Thin public entry point: the Triton / PyTorch choice is made by the dispatcher.
    return _softmax_dispatch(x, False, mask, causal)
|
||||
|
||||
|
||||
def log_softmax(
    x: torch.Tensor,
    mask: Optional[torch.Tensor] = None,
    causal: bool = False,
) -> torch.Tensor:
    r"""Compute :math:`\log(\text{Softmax}(x))` along the last dimension of ``x``.

    The LogSoftmax formulation can be written as:

    .. math::
        \text{LogSoftmax}(x_{i}) = \log\left(\frac{\exp(x_i) }{ \sum_j \exp(x_j)} \right)

    Args:
        x: input tensor.
        mask: optional additive mask, forwarded to the dispatcher.
        causal: optional causal flag, forwarded to the dispatcher.

    Returns:
        a Tensor of the same dimension and shape as the input with
        values in the range [-inf, 0)
    """
    # Same dispatcher as `softmax`, with the log flag switched on.
    return _softmax_dispatch(x, True, mask, causal)
|
||||
|
||||
|
||||
def _softmax_dispatch(
    x: torch.Tensor, log: bool, mask: Optional[torch.Tensor], causal: bool = False
) -> torch.Tensor:
    """
    Run (log-)softmax over the last dimension, through Triton when possible.

    Triton is used if
    - CUDA is available and `x` lives on a CUDA device
    - there was no previous failure of the Triton kernel in this process

    Otherwise, or when the Triton path raises a `RuntimeError`, the
    computation falls back to PyTorch.

    Args:
        x: input tensor.
        log: return log-softmax instead of softmax.
        mask: optional additive mask, summed with `x` in the fallback path.
        causal: in the fallback path, adds `-inf` strictly above the diagonal
            of the last two dimensions.

    Returns:
        a tensor with the same shape as `x`.
    """

    global _triton_registered_warnings

    try:
        if torch.cuda.is_available() and x.is_cuda and not _triton_registered_warnings:
            return _softmax_triton.apply(x, mask, log, causal)
    except RuntimeError as e:
        # Catch cases where the current GPU does not have enough registers to hold a full tensor line
        # fallback to PyTorch's implementation, which streams the tensor in and out
        _triton_registered_warnings = True
        logger.warning(
            "Triton softmax kernel register spillover or invalid image caught. "
            "Deactivating this kernel, please file an issue in the xFormers repository"
        )
        logger.warning(e)

    # PyTorch fallback
    if mask is not None:
        x = x + mask

    if causal:
        x = x + torch.triu(torch.full_like(x, float("-inf")), diagonal=1)

    if log:
        return torch.log_softmax(x, dim=-1)
    else:
        return torch.softmax(x, dim=-1)
|
||||
60
pkgs/xformers/triton/sum_strided.py
Normal file
60
pkgs/xformers/triton/sum_strided.py
Normal file
@@ -0,0 +1,60 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the BSD license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
|
||||
import torch
|
||||
import triton
|
||||
|
||||
from xformers.triton.k_sum import k_sum_0
|
||||
|
||||
|
||||
def sum_2d_dim_0(x: torch.Tensor):
    """
    Sum a 2D tensor across the first dimension

    For reduction sizes where the Triton kernel is not competitive
    (more than 2048 rows or fewer than 8) this simply returns `x.sum(dim=0)`.

    Args:
        x: 2D tensor; when the Triton kernel is used it must have unit stride
           along dim 1.

    Returns:
        a 1D tensor of length `x.shape[1]` with the dtype and device of `x`.

    Raises:
        AssertionError: if `x` is not 2D, or if the kernel path is taken with
            a tensor that is not contiguous along dim 1.
    """

    # Validate the rank before touching `x.shape[1]`, so that a wrongly shaped
    # input reports this message instead of an IndexError
    assert (
        x.ndim == 2
    ), "This is a very specific kernel, only for 2-dim tensors and summing along dim 0"
    M, N = x.shape

    # This kernel is not competitive for these sizes
    if M > 2048 or M < 8:
        return x.sum(dim=0)

    assert x.stride(1) == 1, (
        "We're expecting x to be contiguous along dim 1, and non contiguous along dim 0.\n"
        " You would probably be better served with torch.sum()"
    )

    # Only allocate the output once we know the kernel will actually run
    out = torch.empty(N, device=x.device, dtype=x.dtype)

    BLOCK_M = min(triton.next_power_of_2(M), 2048)
    # Use fewer columns per program as the number of rows per program grows
    BLOCK_N = 32
    if BLOCK_M > 256:
        BLOCK_N = 16
    if BLOCK_M > 1024:
        BLOCK_N = 8

    def grid(meta):
        # One program per block of BLOCK_N columns
        return (triton.cdiv(N, meta["BLOCK_N"]),)

    # fmt: off
    k_sum_0[grid](
        out, x,
        x.stride(0),
        M, N,
        x.dtype == torch.float16,
        BLOCK_M=BLOCK_M,
        BLOCK_N=BLOCK_N,
        num_stages=4,
    )
    # fmt: on

    return out
|
||||
40
pkgs/xformers/triton/utils.py
Normal file
40
pkgs/xformers/triton/utils.py
Normal file
@@ -0,0 +1,40 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the BSD license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import logging
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
|
||||
logger = logging.getLogger("xformers")


# Memoized result of the capability scan below; None until the first query
_gpu_is_old: Optional[bool] = None


def gpu_capabilities_older_than_70() -> bool:
    """Return True if any visible CUDA device has a compute capability below SM70.

    The devices are scanned only once; the answer is cached at module level and
    reused by later calls. With no CUDA device visible the answer is False.
    """
    global _gpu_is_old

    # Fast path: the scan already happened
    if _gpu_is_old is not None:
        return _gpu_is_old

    # Query the major capability of every device, then reduce
    majors = [
        torch.cuda.get_device_capability(f"cuda:{index}")[0]
        for index in range(torch.cuda.device_count())
    ]
    _gpu_is_old = any(major < 7 for major in majors)
    return _gpu_is_old
|
||||
|
||||
|
||||
SUPPORTED_CUDA_DEVICES = ["V100", "A100", "T4"]


def get_current_cuda_device() -> str:
    """Return the entry of `SUPPORTED_CUDA_DEVICES` matching the current CUDA device.

    The match is a substring search in the string representation of the
    current device's properties. When nothing matches, a warning is logged
    and "P100" is returned.
    """
    current_device = str(torch.cuda.get_device_properties(torch.cuda.current_device()))
    for device_str in SUPPORTED_CUDA_DEVICES:
        # `in` also accepts a match at position 0, which `find(...) > 0` rejected
        if device_str in current_device:
            return device_str

    logger.warning("Unsupported device, Triton code generation may fail")
    return "P100"  # default to an old GPU
|
||||
173
pkgs/xformers/triton/vararg_kernel.py
Normal file
173
pkgs/xformers/triton/vararg_kernel.py
Normal file
@@ -0,0 +1,173 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the BSD license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import ast
|
||||
import copy
|
||||
import functools
|
||||
import linecache
|
||||
import sys
|
||||
from typing import Any, Dict, List
|
||||
|
||||
import triton
|
||||
|
||||
|
||||
class _ForLoopUnroller(ast.NodeTransformer):
    """
    Specializes one copy of a loop body for a fixed iteration index.

    - `inline_var[target]` becomes the name `inline_var<loop_iter>`
    - every other use of the loop variable `target` becomes the literal
      `loop_iter`
    """

    def __init__(self, target, inline_variables, loop_iter):
        # target: name of the loop variable being eliminated
        # inline_variables: names of the variables expanded into `name0`, `name1`, ...
        # loop_iter: integer index of the iteration this instance specializes
        self.loop_iter = loop_iter
        self.target = target
        self.inline_variables = inline_variables

    def visit_Name(self, node):
        if node.id != self.target:
            return node
        # The loop variable is a known integer here: emit it as a literal
        return ast.copy_location(ast.Constant(value=self.loop_iter), node)

    def visit_Subscript(self, node):
        # Pattern-matching `value[slice]`
        if (
            isinstance(node.slice, ast.Name)
            and node.slice.id == self.target
            and isinstance(node.value, ast.Name)
            and node.value.id in self.inline_variables
        ):
            return ast.copy_location(
                ast.Name(id=f"{node.value.id}{self.loop_iter}", ctx=node.ctx), node
            )
        # Not an unrolled array access: still recurse, so that the loop variable
        # (or a nested `inline_var[target]`) inside this subscript gets rewritten
        return self.generic_visit(node)
|
||||
|
||||
|
||||
class _VisitorUnrollKernel(ast.NodeTransformer):
    """
    Rewrites a kernel taking a variable number of inputs into one taking `N`.

    - arguments annotated with "VAR_ARGS_ARRAY" (and `*args`) are expanded
      into `name0 ... name{N-1}`
    - local declarations `name: "VAR_ARGS_ARRAY"` are removed and remembered
    - loops of the form `for i in range(len(name))` over such a variable are
      unrolled `N` times
    """

    def __init__(self, N):
        # Names of the variables which have been expanded into N numbered names
        self.inline_variables = set()
        self.N = N

    def visit_AnnAssign(self, node):
        # Pattern-matching:
        # var_name: "VAR_ARGS_ARRAY"
        if (
            node.value is None
            and node.simple == 1
            and isinstance(node.target, ast.Name)
            and isinstance(node.annotation, ast.Constant)
            and node.annotation.value == "VAR_ARGS_ARRAY"
        ):
            self.inline_variables.add(node.target.id)
            # Returning an empty list removes the declaration from the tree
            return []
        if node.value is not None:
            node.value = self.visit(node.value)
        if node.annotation is not None:
            node.annotation = self.visit(node.annotation)
        if node.target is not None:
            node.target = self.visit(node.target)
        return node

    def visit_arguments(self, node):
        # Replace `args` annotated with `VAR_ARGS_ARRAY`
        # NOTE(review): `defaults` / `kw_defaults` are left untouched while the
        # argument list is rebuilt - presumably the kernels have no default
        # values; confirm before relying on defaults here
        new_args = []
        for arg in node.args:
            if (
                arg.annotation is not None
                and isinstance(arg.annotation, ast.Constant)
                and arg.annotation.value == "VAR_ARGS_ARRAY"
            ):
                self.inline_variables.add(arg.arg)
                new_args += [ast.arg(f"{arg.arg}{i}") for i in range(self.N)]
                continue
            new_args.append(arg)
        if node.vararg is not None:
            self.inline_variables.add(node.vararg.arg)
            new_args += [ast.arg(f"{node.vararg.arg}{i}") for i in range(self.N)]
            node.vararg = None
            # Keyword-only arguments become regular arguments once `*args` is gone
            new_args += node.kwonlyargs
            node.kwonlyargs = []
        node.args = new_args
        return node

    def _iterates_over_inline_variable(self, node):
        """Return True for loops of the form `for ... in range(len(<inline variable>))`."""
        loop_iter = node.iter
        if not (
            isinstance(loop_iter, ast.Call)
            and isinstance(loop_iter.func, ast.Name)
            and loop_iter.func.id == "range"
            and len(loop_iter.args) == 1
        ):
            return False
        length_call = loop_iter.args[0]
        return (
            isinstance(length_call, ast.Call)
            and isinstance(length_call.func, ast.Name)
            and length_call.func.id == "len"
            and len(length_call.args) == 1
            and isinstance(length_call.args[0], ast.Name)
            and length_call.args[0].id in self.inline_variables
        )

    def visit_For(self, node):
        if not self._iterates_over_inline_variable(node):
            # Any other loop (including ones whose iterable is an attribute call
            # such as `tl.static_range(...)`) is kept, only its content is visited
            return self.generic_visit(node)
        # We know we have to modify this loop
        new_nodes = []
        for i in range(self.N):
            unroller = _ForLoopUnroller(
                target=node.target.id,
                inline_variables=self.inline_variables,
                loop_iter=i,
            )
            for body in node.body:
                body = copy.deepcopy(body)
                new_node = ast.fix_missing_locations(unroller.visit(body))
                new_node = self.visit(new_node)
                # A visit may return a list (nested unrolled loop, removed
                # declaration): splice it in rather than nesting lists
                if isinstance(new_node, ast.AST):
                    new_nodes.append(new_node)
                elif new_node is not None:
                    new_nodes.extend(new_node)
        return new_nodes
|
||||
|
||||
|
||||
# Hackfix to get access to get source-code for
# `exec`-created functions - see https://stackoverflow.com/a/69668999
_getlines_orig = None  # original `linecache.getlines`, saved when the patch is installed
_FILENAME_TO_SRC: Dict[str, str] = {}  # synthetic filename -> generated source


def _monkey_patched_getlines(filename, module_globals=None):
    """
    Replacement for `linecache.getlines` aware of the generated kernels.

    For a filename registered in `_FILENAME_TO_SRC`, returns the generated
    source as a list of newline-terminated lines, which is the format
    `linecache.getlines` callers expect. Any other filename is forwarded to
    the original implementation.
    """
    if filename not in _FILENAME_TO_SRC:
        return _getlines_orig(filename, module_globals)  # type: ignore

    source = _FILENAME_TO_SRC[filename]
    if not isinstance(source, str):
        # Already stored as a list of lines
        return source
    lines = source.splitlines(keepends=True)
    # The generated source may not end with a newline, whereas every line
    # handed out by linecache does
    if lines and not lines[-1].endswith("\n"):
        lines[-1] += "\n"
    return lines
|
||||
|
||||
|
||||
@functools.lru_cache(None)
def unroll_varargs(kernel, N: int):
    """
    Build a version of a variable-input triton kernel specialized
    for exactly `N` inputs.

    The kernel source is parsed, rewritten by `_VisitorUnrollKernel`,
    turned back into source code, executed to obtain a plain function
    and finally passed to `triton.jit`.

    NOTE: `triton.jit` is quite costly to call, which is why the
    result is memoized with `lru_cache`
    """
    global _FILENAME_TO_SRC, _getlines_orig

    # Rewrite the kernel's syntax tree for N inputs
    source_kernel = triton.JITFunction(kernel.fn)
    tree = ast.parse(source_kernel.src)
    tree = ast.fix_missing_locations(_VisitorUnrollKernel(N=N).visit(tree))

    # NOTE: `ast.unparse` requires python 3.9+
    if sys.version_info[:2] <= (3, 8):
        raise RuntimeError("Error: This functionality requires python 3.9 or above")
    new_src = ast.unparse(tree)  # type: ignore

    # Evaluating the generated source is not enough: triton also wants
    # `inspect.getsource` to work on the result, hence the synthetic
    # filename and the `linecache` patch below
    fn_filename = f"<unroll_varargs-{kernel.fn.__name__}-{N}>"

    # Create function given source
    namespace: Dict[str, Any] = {}
    exec(compile(new_src, fn_filename, "exec"), kernel.fn.__globals__, namespace)
    assert len(namespace) == 1, len(namespace)
    fn = next(iter(namespace.values()))

    # Patch `getlines` only the first time
    if not _FILENAME_TO_SRC:
        _getlines_orig = linecache.getlines
        linecache.getlines = _monkey_patched_getlines
    _FILENAME_TO_SRC[fn_filename] = new_src

    jitted_fn = triton.jit(fn)
    jitted_fn.src = new_src
    return jitted_fn


# Note: just import this to make mypy happy
# when annotating variables with `VAR_ARGS_ARRAY`
VAR_ARGS_ARRAY = List[Any]
|
||||
Reference in New Issue
Block a user