Files
enginex-bi_series-vllm/pkgs/xformers/triton/layer_norm.py

237 lines
7.5 KiB
Python
Raw Normal View History

2025-08-05 19:02:46 +08:00
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
# CREDITS: the underlying kernel comes straight from the Triton tutorials
# see https://github.com/openai/triton/blob/master/python/tutorials/05-layer-norm.py
import logging
from typing import Optional
import torch
import torch.nn as nn
import triton
from torch.cuda.amp import custom_bwd, custom_fwd
from xformers.triton.k_layer_norm import (
layer_norm_bwd_dwdb,
layer_norm_bwd_dx_fused,
layer_norm_fw,
)
# Logger registered under the name "xformers"; used below for the
# non-contiguous-input warning and the kernel-deactivation warning.
logger = logging.getLogger("xformers")
# Read once, when the `custom_fwd(...)` decorator on `_LayerNorm.forward` is
# evaluated at class-definition time: True selects `cast_inputs=torch.float16`,
# False selects `cast_inputs=None`.
_triton_layernorm_fp16_enabled = False # NOTE: PyTorch keeps layernorm as fp32
# Process-wide latch. It is set to True by `_LayerNorm.forward` after the
# non-contiguous-input warning has been logged, and by `layer_norm` after a
# RuntimeError from the Triton path. While it is True, `layer_norm` skips the
# Triton path and uses `torch.nn.functional.layer_norm`.
_triton_registered_warnings = False
class _LayerNorm(torch.autograd.Function):
    """Autograd function dispatching layer normalization to the Triton kernels.

    ``forward`` launches ``layer_norm_fw`` and ``backward`` launches
    ``layer_norm_bwd_dx_fused`` followed by ``layer_norm_bwd_dwdb``.
    All scratch buffers are allocated on the device of the input tensor so
    that the function also works when the input does not live on the current
    default CUDA device.
    """

    @staticmethod
    @custom_fwd(cast_inputs=torch.float16 if _triton_layernorm_fp16_enabled else None)
    def forward(ctx, x, weight, bias, eps):
        """Normalize ``x`` over its last dimension.

        Args:
            ctx: autograd context; receives the saved tensors plus the
                ``BLOCK_SIZE_N`` and ``num_warps`` launch parameters.
            x: input tensor; it is viewed as ``(-1, x.shape[-1])``.
            weight: optional scale tensor; ``None`` launches the kernel with
                ``affine=False``.
            bias: optional shift tensor, forwarded to the kernel as-is.
            eps: numerical-stability epsilon (raised to at least ``1.6e-5``
                for fp16 inputs).

        Returns:
            A tensor with the same shape as ``x``.

        Raises:
            RuntimeError: if one row of features does not fit in 64KB.
        """
        # catch eps being too small if the tensors are fp16
        if x.dtype == torch.float16:
            eps = max(eps, 1.6e-5)

        # allocate output
        y = torch.empty_like(x)

        # reshape input data into 2D tensor
        x_arg = x.reshape(-1, x.shape[-1])
        M, N = x_arg.shape

        # Per-row statistics buffers handed to the kernel and reused in the
        # backward pass. They must live on the same device as the input:
        # a bare "cuda" would resolve to the *current* device, which can
        # differ from x.device on multi-GPU hosts.
        mean = torch.empty((M,), dtype=torch.float32, device=x.device)
        rstd = torch.empty((M,), dtype=torch.float32, device=x.device)

        # Less than 64KB per feature: enqueue fused kernel
        MAX_FUSED_SIZE = 65536 // x.element_size()
        BLOCK_SIZE_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))
        if N > BLOCK_SIZE_N:
            raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.")

        if not x_arg.is_contiguous() or not y.is_contiguous():
            global _triton_registered_warnings
            if not _triton_registered_warnings:
                logger.warning(
                    "Non-contiguous input tensor found. Making it contiguous,"
                    + " but could have perf or trainer implications"
                )
                _triton_registered_warnings = True

            x_arg = x_arg.contiguous()
            y = y.contiguous()

        # heuristics for number of warps.
        num_warps = min(max(BLOCK_SIZE_N // 256, 1), 16)

        # enqueue kernel, one program per row
        # fmt: off
        layer_norm_fw[(M,)](
            x_arg, y, weight, bias, mean, rstd,
            x_arg.stride(0),
            N,
            eps,
            num_warps=num_warps,
            BLOCK_SIZE_N=BLOCK_SIZE_N,
            affine=weight is not None
        )
        # fmt: on

        ctx.save_for_backward(x, mean, rstd, weight)
        ctx.BLOCK_SIZE_N = BLOCK_SIZE_N
        ctx.num_warps = num_warps

        return y.reshape_as(x)

    @staticmethod
    @custom_bwd
    def backward(
        ctx, dy
    ):  # pragma: no cover  # this is covered, but called directly from C++
        """Compute gradients for ``x``, ``weight`` and ``bias``.

        Args:
            ctx: autograd context populated by ``forward``.
            dy: gradient of the loss with respect to the forward output.

        Returns:
            ``(dx, dw, db, None)``; the trailing ``None`` matches ``eps``.
        """
        x, mean, rstd, weight = ctx.saved_tensors

        # flatten the batch dimension, if any.
        # We're interested in 'samples' x norm_dimension
        x = x.reshape(-1, x.size(-1))
        M, N = x.size()

        # heuristics for amount of parallel reduction stream for DG/DB
        GROUP_SIZE_M = 32
        if N <= 8192:
            GROUP_SIZE_M = 64
        if N <= 4096:
            GROUP_SIZE_M = 96
        if N <= 2048:
            GROUP_SIZE_M = 128
        if N <= 1024:
            GROUP_SIZE_M = 256

        if dy.dtype == torch.float32:
            GROUP_SIZE_M = GROUP_SIZE_M // 2

        # allocate outputs and scratch space, all on the input's device
        # (the lock buffer previously used a bare "cuda" device).
        locks = torch.zeros(2 * GROUP_SIZE_M, dtype=torch.int32, device=x.device)
        t_args = {"dtype": x.dtype, "device": x.device}
        _dw = torch.empty((GROUP_SIZE_M, x.size(-1)), **t_args)
        _db = torch.empty_like(_dw)
        dw = torch.empty((x.size(-1),), **t_args)
        db = torch.empty_like(dw)
        dy = dy.contiguous()
        dx = torch.empty_like(dy)

        # Check the tensor shapes and layouts
        # we suppose in the kernel that they have the same size and are contiguous
        assert (
            dy.numel() == x.numel()
        ), "Something is wrong in the backward graph, possibly because of an inplace operation after the layernorm"

        # enqueue kernel using forward pass heuristics
        # also compute partial sums for DW and DB
        # (ctx.num_warps was derived from ctx.BLOCK_SIZE_N in forward)
        num_warps = ctx.num_warps

        # fmt: off
        layer_norm_bwd_dx_fused[(M,)](
            dx, dy, _dw, _db, x,
            # the kernel always receives a tensor here; x is passed as a
            # placeholder when there is no weight (affine is False then)
            weight if weight is not None else x,
            mean, rstd,
            locks,
            x.stride(0),
            N,
            affine=weight is not None,
            GROUP_SIZE_M=GROUP_SIZE_M,
            BLOCK_SIZE_N=ctx.BLOCK_SIZE_N,
            num_warps=num_warps
        )
        # fmt: on

        def grid(meta):
            # one program per BLOCK_SIZE_N-wide slice of the feature dimension
            return [triton.cdiv(N, meta["BLOCK_SIZE_N"])]

        # accumulate partial sums in separate kernel
        # fmt: off
        layer_norm_bwd_dwdb[grid](
            _dw, _db, dw, db,
            GROUP_SIZE_M,
            N,
            BLOCK_SIZE_M=32,
            BLOCK_SIZE_N=64
        )
        # fmt: on

        dx = dx.reshape_as(dy)
        return dx, dw, db, None
class FusedLayerNorm(nn.Module):
    """
    Layer normalization module, analogous to torch.nn.LayerNorm_.

    This implementation should be measurably faster than the default PyTorch
    layernorm (as of PyTorch 1.9), both for training and inference workloads.

    .. NOTE: Computations under Torch AMP are kept as float32 by default; this
        can be switched to float16 through the `_triton_layernorm_fp16_enabled`
        flag. (NOTE(review): the previous docstring located this flag at
        `xformers.triton.k_layer_norm._triton_layernorm_fp16_enabled` --
        confirm which module owns it.)

    .. _torch.nn.LayerNorm: https://pytorch.org/docs/stable/generated/torch.nn.LayerNorm.html

    Args:
        normalized_shape: shape of the learnable ``weight`` / ``bias`` tensors.
        affine: when True, create ``weight`` (ones) and ``bias`` (zeros)
            parameters; when False, both attributes are ``None``.
        eps: epsilon forwarded to ``layer_norm``.
    """

    def __init__(self, normalized_shape, affine=True, eps=1e-06):
        super().__init__()

        if not affine:
            # No learnable parameters: plain attributes set to None.
            self.weight = None
            self.bias = None
        else:
            # Registration order (weight, then bias) is kept on purpose.
            self.weight = nn.Parameter(torch.ones(normalized_shape))
            self.bias = nn.Parameter(torch.zeros(normalized_shape))

        self.epsilon = eps

    def forward(self, x):
        """Normalize ``x`` with this module's parameters and epsilon."""
        return layer_norm(x, weight=self.weight, bias=self.bias, eps=self.epsilon)

    def init_weights(self, *args, **kwargs):
        """Reset ``weight`` to ones and ``bias`` to zeros, when they exist."""
        with torch.no_grad():
            for tensor, fill_value in ((self.weight, 1.0), (self.bias, 0.0)):
                if tensor is not None:
                    tensor.fill_(fill_value)
def layer_norm(
    x: torch.Tensor,
    weight: Optional[torch.Tensor] = None,
    bias: Optional[torch.Tensor] = None,
    eps: float = 1e-06,
) -> torch.Tensor:
    """Applies normalization over a mini batch of inputs.

    The Triton implementation (``_LayerNorm``) is used when ``x`` is a CUDA
    tensor, both ``weight`` and ``bias`` are provided, and the Triton path has
    not been deactivated through the module-level
    ``_triton_registered_warnings`` flag. In every other case, and whenever
    the Triton path raises ``RuntimeError``, the call falls back to
    ``torch.nn.functional.layer_norm`` over the last dimension of ``x``.

    Args:
        x: input tensor, normalized over its last dimension.
        weight: optional scale tensor.
        bias: optional shift tensor.
        eps: numerical-stability epsilon.

    Returns:
        The normalized tensor.
    """
    # The docstring must be the first statement of the body to be attached to
    # the function; the `global` declaration therefore comes after it.
    global _triton_registered_warnings

    try:
        # Cheap per-argument checks first; all conditions are side-effect
        # free, so the order does not change the outcome.
        if (
            x.is_cuda
            and weight is not None
            and bias is not None
            and not _triton_registered_warnings
            and torch.cuda.is_available()
        ):
            return _LayerNorm.apply(x, weight, bias, eps)
    except RuntimeError as e:
        # Catch cases where the current GPU does not have enough registers to hold a full tensor line
        # fallback to PyTorch's implementation, which streams the tensor in and out
        _triton_registered_warnings = True
        logger.warning(
            "Triton layernorm kernel register spillover or invalid image caught. "
            "Deactivating this kernel, please file an issue in the xFormers repository"
        )
        logger.warning(e)

    return torch.nn.functional.layer_norm(
        x, [x.shape[-1]], weight=weight, bias=bias, eps=eps
    )