Files
enginex-bi_series-vllm/pkgs/xformers/ops/rmsnorm.py
2025-08-05 19:02:46 +08:00

114 lines
3.3 KiB
Python

# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
from typing import Optional
import torch
from torch import nn
from .. import _is_triton_available
def rms_norm(
    x: torch.Tensor, weight: Optional[torch.Tensor], eps: float = 1e-6
) -> torch.Tensor:
    """
    RMS Normalization along the last dimension (forward only).

    This is similar to torch.nn.functional.normalize but with eps being added
    instead of max.

    Args:
        x: contiguous input of shape (..., dim).
        weight: optional contiguous parameter of length dim that multiplies
            the normalized result, or None for no scaling.
        eps: small constant added inside the square root for stability.

    Returns:
        Normalized data of the same shape as ``x``. For each dim-length
        vector x, the result has x / sqrt( x*x.sum() + eps ).

    Raises:
        ValueError: if gradients are enabled and would flow through ``x``
            or ``weight`` — no backward pass is implemented.
        AssertionError: if Triton is not available.

    This functionality is experimental. Its API might be changed without
    warnings. Use it at your own risk.
    """
    # The only implementation is the Triton kernel; bail out early without it.
    assert _is_triton_available()
    from .triton.rmsnorm_kernels import _rms_norm_forward

    # Forward-only kernel: refuse inputs that would silently lose gradients.
    if torch.is_grad_enabled() and (
        x.requires_grad or (weight is not None and weight.requires_grad)
    ):
        raise ValueError("Gradients not supported.")
    return _rms_norm_forward(x, weight, eps)
def rms_norm_add(
    x: torch.Tensor, y: torch.Tensor, weight: Optional[torch.Tensor], eps: float = 1e-6
):
    """
    An addition fused with rms_norm.

    ``z = rms_norm_add(x, y, weight, eps)`` is equivalent to

        x += y
        z = rms_norm(x, weight, eps)

    where x, y and z are all contiguous.

    This functionality is experimental. Its API might be changed without
    warnings. Use it at your own risk.
    """
    # No backward pass exists for the fused kernel, so reject any input
    # that would need gradients flowing through it.
    needs_grad = (
        x.requires_grad
        or y.requires_grad
        or (weight is not None and weight.requires_grad)
    )
    if torch.is_grad_enabled() and needs_grad:
        raise ValueError("Gradients not supported.")
    # Triton provides the only implementation.
    assert _is_triton_available()
    from .triton.rmsnorm_kernels import _rms_norm_add_forward

    return _rms_norm_add_forward(x, y, weight, eps)
class RMSNorm(torch.nn.Module):
    """
    RMS Normalization layer along the last dimension.

    This is similar to torch.nn.functional.normalize but with eps being
    added instead of max. Expects contiguous input of shape (..., dim) and
    returns normalized data of the same shape; for each dim-length vector
    x, the result has x / sqrt( x*x.sum() + eps ). If weights are
    included, they are a parameter of length dim which multiplies the
    result.

    This functionality is experimental. Its API might be changed without
    warnings. Use it at your own risk.
    """

    def __init__(self, dim: int, include_weight: bool = True, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        # Learnable per-channel scale (initialized to ones), or None when
        # scaling is disabled.
        self.weight: Optional[nn.Parameter] = (
            nn.Parameter(torch.ones(dim)) if include_weight else None
        )

    def forward(self, x: torch.Tensor):
        """Apply RMS normalization to ``x``."""
        return rms_norm(x, self.weight, self.eps)

    def increment_and_forward_(self, x: torch.Tensor, y: torch.Tensor):
        """
        An addition fused with forward.

        ``z = layer.increment_and_forward_(x, y)`` is equivalent to

            x += y
            z = layer(x)
        """
        return rms_norm_add(x, y, self.weight, self.eps)