Files
enginex-vastai-va16-vllm/torch_vacc/fused_ops/rms_norm.py
2026-04-02 04:55:00 +00:00

43 lines
1.4 KiB
Python

import torch
import torch_vacc
from torch_vacc._vacc_libs import _torch_vacc
class FusedRMSNormFunction(torch.autograd.Function):
    """Autograd bridge to the fused VACC RMSNorm kernels.

    ``forward`` dispatches to ``torch.ops.vacc.rms_norm_forward`` and
    ``backward`` to ``_torch_vacc.rms_norm_backward``; the function is
    differentiable w.r.t. ``input`` and ``weight`` but not ``eps``.
    """

    @staticmethod
    def forward(ctx, input: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6):
        """Run the fused forward kernel and stash intermediates for backward."""
        output, rsigma, var = torch.ops.vacc.rms_norm_forward(input, weight, eps)
        # rsigma/var are statistics produced by the forward kernel and
        # consumed again by the backward kernel.
        ctx.save_for_backward(input, weight, rsigma, var)
        ctx.eps = eps
        return output

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor):
        """Run the fused backward kernel; ``eps`` receives no gradient."""
        saved = ctx.saved_tensors  # (input, weight, rsigma, var)
        grad_input, grad_weight = _torch_vacc.rms_norm_backward(
            grad_output, *saved, ctx.eps
        )
        # One gradient slot per forward argument: input, weight, eps.
        return grad_input, grad_weight, None
def rms_norm(input: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6):
    """Functional entry point for the fused RMSNorm.

    Thin convenience wrapper around ``FusedRMSNormFunction.apply``.
    """
    return FusedRMSNormFunction.apply(input, weight, eps)
class FusedRMSNorm(torch.nn.Module):
    """RMSNorm layer backed by the fused VACC kernels.

    The normalization is computed in float32 and the result is cast
    back to the caller's original dtype.
    """

    def __init__(self, hidden_size, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        # Learnable per-channel scale, initialized to ones.
        self.weight = torch.nn.Parameter(torch.ones(hidden_size))
        # Duplicate of `eps` — presumably kept so callers expecting the
        # HF-style `variance_epsilon` attribute still work; confirm before removing.
        self.variance_epsilon = eps

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        orig_dtype = hidden_states.dtype
        upcast = hidden_states.to(torch.float32)
        # NOTE(review): the input is upcast to float32 but `self.weight` is
        # passed in its stored dtype — presumably the kernel accepts mixed
        # dtypes; verify against the vacc op's contract.
        normalized = FusedRMSNormFunction.apply(upcast, self.weight, self.eps)
        return normalized.to(orig_dtype)