init
This commit is contained in:
42
torch_vacc/fused_ops/rms_norm.py
Normal file
42
torch_vacc/fused_ops/rms_norm.py
Normal file
@@ -0,0 +1,42 @@
import torch
import torch_vacc
from torch_vacc._vacc_libs import _torch_vacc


class FusedRMSNormFunction(torch.autograd.Function):
    """Autograd bridge for the fused VACC RMSNorm kernels.

    Forward dispatches to the op registered under ``torch.ops.vacc``;
    backward calls into the native library directly.
    NOTE(review): the forward/backward entry points come from two different
    bindings (`torch.ops.vacc` vs `_torch_vacc`) — confirm this is intended.
    """

    @staticmethod
    def forward(ctx, input: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6):
        # The kernel also returns rsigma/var, which the backward pass consumes.
        output, rsigma, var = torch.ops.vacc.rms_norm_forward(input, weight, eps)
        ctx.save_for_backward(input, weight, rsigma, var)
        ctx.eps = eps  # plain float: cannot go through save_for_backward
        return output

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor):
        # saved = (input, weight, rsigma, var), in the order saved by forward.
        saved = ctx.saved_tensors
        grad_input, grad_weight = _torch_vacc.rms_norm_backward(
            grad_output, *saved, ctx.eps
        )
        # eps is a non-tensor forward argument, so its gradient slot is None.
        return grad_input, grad_weight, None


def rms_norm(input: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6):
    """Functional entry point: apply the fused RMSNorm to *input*.

    Thin convenience wrapper; delegates straight to the autograd Function
    so gradients flow through the fused kernels.
    """
    return FusedRMSNormFunction.apply(input, weight, eps)


class FusedRMSNorm(torch.nn.Module):
    """``torch.nn.Module`` wrapper around the fused RMSNorm kernel.

    Holds a learnable per-feature scale (``weight``), initialised to ones.
    """

    def __init__(self, hidden_size, eps: float = 1e-6):
        super().__init__()
        # Learnable per-feature scale, initialised to ones.
        self.weight = torch.nn.Parameter(torch.ones(hidden_size))
        self.eps = eps
        # NOTE(review): duplicate of `eps`; presumably kept so external code
        # reading `variance_epsilon` keeps working — confirm before removing.
        self.variance_epsilon = eps

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # Compute in float32, then cast the result back to the caller's dtype.
        orig_dtype = hidden_states.dtype
        normalized = FusedRMSNormFunction.apply(
            hidden_states.to(torch.float32), self.weight, self.eps
        )
        return normalized.to(orig_dtype)
Reference in New Issue
Block a user