import torch
import torch_vacc
from torch_vacc._vacc_libs import _torch_vacc


class FusedRMSNormFunction(torch.autograd.Function):
    """Autograd wrapper around the fused VACC RMSNorm kernels."""

    @staticmethod
    def forward(ctx, input: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6):
        # The fused forward kernel returns the normalized output together with
        # the reciprocal standard deviation (rsigma) and the variance, which
        # the backward kernel reuses instead of recomputing them.
        output, rsigma, var = torch.ops.vacc.rms_norm_forward(input, weight, eps)
        ctx.save_for_backward(input, weight, rsigma, var)
        ctx.eps = eps
        return output

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor):
        input, weight, rsigma, var = ctx.saved_tensors
        grad_input, grad_weight = _torch_vacc.rms_norm_backward(
            grad_output, input, weight, rsigma, var, ctx.eps
        )
        # forward() takes (input, weight, eps); eps is not a tensor, so its
        # gradient slot is None.
        return grad_input, grad_weight, None


def rms_norm(input: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6):
    """Functional entry point for the fused RMSNorm."""
    return FusedRMSNormFunction.apply(input, weight, eps)


class FusedRMSNorm(torch.nn.Module):
    """Drop-in RMSNorm module backed by the fused VACC kernels."""

    def __init__(self, hidden_size, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = torch.nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps  # alias kept for HuggingFace-style naming

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # Normalize in float32 for numerical stability, then cast the result
        # back to the caller's dtype.
        dtype = hidden_states.dtype
        hidden_states = hidden_states.to(torch.float32)
        output = FusedRMSNormFunction.apply(hidden_states, self.weight, self.eps)
        return output.to(dtype)
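

# Minimal smoke test (a sketch, not part of the original module): exercises
# both the fused forward kernel and, via autograd, the fused backward kernel.
# The "vacc" device string and the tensor shapes below are assumptions about
# the torch_vacc build; adjust them for your environment.
if __name__ == "__main__":
    hidden_size = 4096
    layer = FusedRMSNorm(hidden_size).to("vacc")
    x = torch.randn(2, 8, hidden_size, device="vacc", requires_grad=True)
    y = layer(x)
    y.sum().backward()  # triggers FusedRMSNormFunction.backward
    print(y.shape, x.grad.shape)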