237 lines
7.5 KiB
Python
237 lines
7.5 KiB
Python
|
|
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||
|
|
#
|
||
|
|
# This source code is licensed under the BSD license found in the
|
||
|
|
# LICENSE file in the root directory of this source tree.
|
||
|
|
|
||
|
|
# CREDITS: the underlying kernel comes straight from the Triton tutorials
|
||
|
|
# see https://github.com/openai/triton/blob/master/python/tutorials/05-layer-norm.py
|
||
|
|
|
||
|
|
import logging
|
||
|
|
from typing import Optional
|
||
|
|
|
||
|
|
import torch
|
||
|
|
import torch.nn as nn
|
||
|
|
import triton
|
||
|
|
from torch.cuda.amp import custom_bwd, custom_fwd
|
||
|
|
|
||
|
|
from xformers.triton.k_layer_norm import (
|
||
|
|
layer_norm_bwd_dwdb,
|
||
|
|
layer_norm_bwd_dx_fused,
|
||
|
|
layer_norm_fw,
|
||
|
|
)
|
||
|
|
|
||
|
|
# Module-level logger; records are emitted under the "xformers" logger name.
logger = logging.getLogger("xformers")
|
||
|
|
|
||
|
|
|
||
|
|
# Selects the `cast_inputs` dtype given to `custom_fwd` on `_LayerNorm.forward`:
# torch.float16 when True, None (no cast) when False. The decorator argument is
# evaluated once, when the `_LayerNorm` class body executes.
_triton_layernorm_fp16_enabled = False # NOTE: PyTorch keeps layernorm as fp32
|
||
|
|
# One-shot flag shared by two code paths in this file:
#   * `_LayerNorm.forward` sets it after logging the non-contiguous-input warning;
#   * `layer_norm` sets it after catching a RuntimeError from the Triton path.
# `layer_norm` only dispatches to the Triton kernel while this is False.
_triton_registered_warnings = False
|
||
|
|
|
||
|
|
|
||
|
|
class _LayerNorm(torch.autograd.Function):
    """Autograd function wrapping the Triton layer-norm kernels.

    The input is flattened to ``(rows, features)`` where ``features`` is the
    size of the last dimension, and one kernel program is launched per row.
    """

    @staticmethod
    @custom_fwd(cast_inputs=torch.float16 if _triton_layernorm_fp16_enabled else None)
    def forward(ctx, x, weight, bias, eps):
        """Run the fused forward kernel and stash what backward needs.

        Args:
            ctx: autograd context.
            x: input tensor; normalization runs over its last dimension.
            weight: optional scale; the kernel is launched with
                ``affine=weight is not None``.
            bias: optional shift, forwarded to the kernel as-is.
            eps: numerical-stability term; raised to at least 1.6e-5 when
                ``x`` is float16.

        Returns:
            A tensor with the same shape as ``x``.

        Raises:
            RuntimeError: if the last dimension does not fit in a single
                fused block (64KB worth of elements).
        """
        global _triton_registered_warnings

        # catch eps being too small if the tensors are fp16
        if x.dtype == torch.float16:
            eps = max(eps, 1.6e-5)

        # allocate output
        y = torch.empty_like(x)

        # reshape input data into 2D tensor
        x_arg = x.reshape(-1, x.shape[-1])
        M, N = x_arg.shape

        # Per-row statistics, reused by the backward pass. Allocated next to
        # the input (not on the current default CUDA device) so that tensors
        # living on a non-default GPU get matching buffers.
        mean = torch.empty((M,), dtype=torch.float32, device=x.device)
        rstd = torch.empty((M,), dtype=torch.float32, device=x.device)

        # Less than 64KB per feature: enqueue fused kernel
        MAX_FUSED_SIZE = 65536 // x.element_size()
        BLOCK_SIZE_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))
        if N > BLOCK_SIZE_N:
            raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.")

        if not x_arg.is_contiguous() or not y.is_contiguous():
            # Warn only once per process.
            # NOTE(review): this is the same flag `layer_norm` checks before
            # dispatching here, so a single non-contiguous input also turns
            # the Triton path off for later calls - confirm this is intended.
            if not _triton_registered_warnings:
                logger.warning(
                    "Non-contiguous input tensor found. Making it contiguous,"
                    " but could have perf or trainer implications"
                )

                _triton_registered_warnings = True

            x_arg = x_arg.contiguous()
            y = y.contiguous()

        # heuristics for number of warps.
        num_warps = min(max(BLOCK_SIZE_N // 256, 1), 16)

        # enqueue kernel, one program per row
        # fmt: off
        layer_norm_fw[(M,)](
            x_arg, y, weight, bias, mean, rstd,
            x_arg.stride(0),
            N,
            eps,
            num_warps=num_warps,
            BLOCK_SIZE_N=BLOCK_SIZE_N,
            affine=weight is not None
        )
        # fmt: on

        ctx.save_for_backward(x, mean, rstd, weight)
        ctx.BLOCK_SIZE_N = BLOCK_SIZE_N
        ctx.num_warps = num_warps

        return y.reshape_as(x)

    @staticmethod
    @custom_bwd
    def backward(
        ctx, dy
    ):  # pragma: no cover  # this is covered, but called directly from C++
        """Compute gradients w.r.t. ``x``, ``weight`` and ``bias``.

        Args:
            ctx: autograd context populated by ``forward``.
            dy: gradient of the loss w.r.t. the forward output.

        Returns:
            ``(dx, dw, db, None)``; the trailing ``None`` is for ``eps``.
        """
        x, mean, rstd, weight = ctx.saved_tensors

        # flatten the batch dimension, if any.
        # We're interested in 'samples' x norm_dimension.
        # The kernel receives a single row stride and is written for
        # contiguous tensors; forward ran on a contiguous copy, so do the same
        # here. This is a no-op when x is already contiguous.
        x = x.reshape(-1, x.size(-1)).contiguous()
        M, N = x.size()

        # heuristics for amount of parallel reduction stream for DG/DB
        if N <= 1024:
            GROUP_SIZE_M = 256
        elif N <= 2048:
            GROUP_SIZE_M = 128
        elif N <= 4096:
            GROUP_SIZE_M = 96
        elif N <= 8192:
            GROUP_SIZE_M = 64
        else:
            GROUP_SIZE_M = 32

        if dy.dtype == torch.float32:
            GROUP_SIZE_M = GROUP_SIZE_M // 2

        # allocate output; every buffer lives on the same device as x
        locks = torch.zeros(2 * GROUP_SIZE_M, dtype=torch.int32, device=x.device)
        t_args = {"dtype": x.dtype, "device": x.device}
        _dw = torch.empty((GROUP_SIZE_M, x.size(-1)), **t_args)
        _db = torch.empty_like(_dw)
        dw = torch.empty((x.size(-1),), **t_args)
        db = torch.empty_like(dw)
        dy = dy.contiguous()
        dx = torch.empty_like(dy)

        # Check the tensor shapes and layouts
        # we suppose in the kernel that they have the same size and are contiguous
        assert (
            dy.numel() == x.numel()
        ), "Something is wrong in the backward graph, possibly because of an inplace operation after the layernorm"

        # enqueue kernel using forward pass heuristics
        # also compute partial sums for DW and DB
        num_warps = ctx.num_warps

        # fmt: off
        layer_norm_bwd_dx_fused[(M,)](
            dx, dy, _dw, _db, x,
            # the kernel still needs a tensor argument when there is no weight
            weight if weight is not None else x,
            mean, rstd,
            locks,
            x.stride(0),
            N,
            affine=weight is not None,
            GROUP_SIZE_M=GROUP_SIZE_M,
            BLOCK_SIZE_N=ctx.BLOCK_SIZE_N,
            num_warps=num_warps
        )
        # fmt: on

        def grid(meta):
            return [triton.cdiv(N, meta["BLOCK_SIZE_N"])]

        # accumulate partial sums in separate kernel
        # fmt: off
        layer_norm_bwd_dwdb[grid](
            _dw, _db, dw, db,
            GROUP_SIZE_M,
            N,
            BLOCK_SIZE_M=32,
            BLOCK_SIZE_N=64
        )
        # fmt: on

        dx = dx.reshape_as(dy)
        return dx, dw, db, None
|
||
|
|
|
||
|
|
|
||
|
|
class FusedLayerNorm(nn.Module):
    """
    Layer normalization module, in the spirit of torch.nn.LayerNorm_.

    The implementation is intended to be measurably faster than the default
    PyTorch layernorm (as of PyTorch 1.9), both for training and inference
    workloads.

    ``normalized_shape`` only sizes the ``weight`` / ``bias`` parameters; the
    normalization itself is delegated to ``layer_norm``, which works over the
    last dimension of the input.

    .. NOTE: Computations under Torch AMP are kept as float32 by default. The
        module-level flag ``_triton_layernorm_fp16_enabled`` switches this to
        float16.
        NOTE(review): that flag is read when ``_LayerNorm`` is defined, so
        changing it after import may have no effect - confirm.

    .. _torch.nn.LayerNorm: https://pytorch.org/docs/stable/generated/torch.nn.LayerNorm.html

    """

    def __init__(self, normalized_shape, affine=True, eps=1e-06):
        super().__init__()

        if not affine:
            # No learnable scale / shift
            self.weight = None
            self.bias = None
        else:
            self.weight = nn.Parameter(torch.ones(normalized_shape))
            self.bias = nn.Parameter(torch.zeros(normalized_shape))

        self.epsilon = eps

    def forward(self, x):
        """Normalize ``x`` with this module's parameters and epsilon."""
        return layer_norm(x, self.weight, self.bias, self.epsilon)

    def init_weights(self, *args, **kwargs):
        """Reset ``weight`` to ones and ``bias`` to zeros, when they exist.

        Extra positional and keyword arguments are accepted and ignored.
        """
        with torch.no_grad():
            for param, value in ((self.weight, 1.0), (self.bias, 0.0)):
                if param is not None:
                    param.fill_(value)
|
||
|
|
|
||
|
|
|
||
|
|
def layer_norm(
    x: torch.Tensor,
    weight: Optional[torch.Tensor] = None,
    bias: Optional[torch.Tensor] = None,
    eps: float = 1e-06,
) -> torch.Tensor:
    r"""Applies normalization over a mini batch of inputs.

    The Triton kernel (``_LayerNorm``) is used when ``x`` is a CUDA tensor,
    both ``weight`` and ``bias`` are given, and the kernel has not been
    deactivated through ``_triton_registered_warnings``. In every other case,
    or when the Triton path raises ``RuntimeError``, the call falls back to
    ``torch.nn.functional.layer_norm`` over the last dimension of ``x``.

    Args:
        x: input tensor, normalized over its last dimension.
        weight: optional scale.
        bias: optional shift.
        eps: numerical-stability term.

    Returns:
        The normalized tensor.
    """
    global _triton_registered_warnings

    try:
        # Cheap local checks first; the module-level flag is consulted last.
        if (
            weight is not None
            and bias is not None
            and x.is_cuda
            and torch.cuda.is_available()
            and not _triton_registered_warnings
        ):
            return _LayerNorm.apply(x, weight, bias, eps)
    except RuntimeError as e:
        # Catch cases where the current GPU does not have enough registers to hold a full tensor line
        # fallback to PyTorch's implementation, which streams the tensor in and out
        _triton_registered_warnings = True
        logger.warning(
            "Triton layernorm kernel register spillover or invalid image caught. "
            "Deactivating this kernel, please file an issue in the xFormers repository"
        )
        logger.warning(e)

    return torch.nn.functional.layer_norm(
        x, [x.shape[-1]], weight=weight, bias=bias, eps=eps
    )
|