import torch
import torch_vacc
from torch_vacc._vacc_libs import _torch_vacc


class FusedRMSNormFunction(torch.autograd.Function):
    """Autograd bridge for the fused RMSNorm kernels provided by torch_vacc.

    The forward pass calls the registered custom op
    ``torch.ops.vacc.rms_norm_forward`` and stashes the intermediates
    (``rsigma``, ``var``) that the backward kernel needs, so they are not
    recomputed during the backward pass.
    """

    @staticmethod
    def forward(ctx, input: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
        """Run the fused RMSNorm forward kernel.

        Args:
            ctx: autograd context used to save tensors for backward.
            input: activation tensor to normalize.
            weight: learnable scale applied after normalization.
            eps: small constant added for numerical stability.

        Returns:
            The normalized (and scaled) output tensor.
        """
        # NOTE(review): the op returns (output, rsigma, var); presumably
        # rsigma is the reciprocal RMS and var the variance term — the exact
        # contract lives in the vacc extension, not visible here.
        output, rsigma, var = torch.ops.vacc.rms_norm_forward(input, weight, eps)
        ctx.save_for_backward(input, weight, rsigma, var)
        ctx.eps = eps  # eps is a plain float, so it cannot go in save_for_backward
        return output

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor):
        """Run the fused RMSNorm backward kernel.

        Returns gradients matching forward's inputs: (grad_input,
        grad_weight, None) — ``None`` because ``eps`` is a non-tensor
        argument and receives no gradient.
        """
        input, weight, rsigma, var = ctx.saved_tensors
        grad_input, grad_weight = _torch_vacc.rms_norm_backward(
            grad_output, input, weight, rsigma, var, ctx.eps
        )
        return grad_input, grad_weight, None


def rms_norm(input: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    """Functional entry point for the fused RMSNorm op."""
    return FusedRMSNormFunction.apply(input, weight, eps)


class FusedRMSNorm(torch.nn.Module):
    """``nn.Module`` wrapper around the fused RMSNorm kernel.

    The forward pass upcasts activations to float32 for the kernel and
    casts the result back to the caller's dtype, mirroring the usual
    mixed-precision RMSNorm recipe.

    Args:
        hidden_size: size of the last dimension being normalized; also the
            shape of the learnable ``weight`` parameter (initialized to ones).
        eps: numerical-stability constant passed to the kernel.
    """

    def __init__(self, hidden_size, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = torch.nn.Parameter(torch.ones(hidden_size))
        # Redundant alias of ``eps`` — kept for backward compatibility with
        # callers that read ``variance_epsilon`` (HF-style naming); both
        # attributes always hold the same value.
        self.variance_epsilon = eps

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Apply fused RMSNorm, preserving the input dtype.

        Computation happens in float32; the output is cast back to the
        original dtype of ``hidden_states``.
        """
        dtype = hidden_states.dtype
        hidden_states = hidden_states.to(torch.float32)
        output = FusedRMSNormFunction.apply(hidden_states, self.weight, self.eps)
        output = output.to(dtype)
        return output