33 lines
1.2 KiB
Python
33 lines
1.2 KiB
Python
from typing import Optional, Tuple, Union
|
|
import torch
|
|
|
|
|
|
def RMSNorm_forward_vacc(
    self,
    x: torch.Tensor,
    residual: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
    """Run RMSNorm through the fused VACC residual + RMSNorm kernel.

    The input is validated against ``self.hidden_size``, optionally
    narrowed to its first ``self.variance_size_override`` features, and
    then handed to ``torch.vacc.fused_residual_rmsnorm`` together with
    ``self.weight``, ``residual`` and ``self.variance_epsilon``.

    Args:
        x: Input tensor whose last dimension must equal
            ``self.hidden_size``. Any rank is accepted by the validation
            and slicing done here.
        residual: Optional tensor forwarded as-is to the kernel (``None``
            is forwarded too).

    Returns:
        Whatever ``torch.vacc.fused_residual_rmsnorm`` returns.
        NOTE(review): the annotation suggests a single tensor, or an
        ``(output, residual)`` pair when ``residual`` is given -- confirm
        against the kernel's contract.

    Raises:
        ValueError: If the last dimension of ``x`` differs from
            ``self.hidden_size``, or is smaller than
            ``self.variance_size_override`` when an override is set.
    """
    hidden_size = x.shape[-1]
    if hidden_size != self.hidden_size:
        raise ValueError("Expected hidden_size to be "
                         f"{self.hidden_size}, but found: {hidden_size}")

    variance_size = self.variance_size_override
    if variance_size is None:
        x_var = x
    else:
        if hidden_size < variance_size:
            raise ValueError(
                "Expected hidden_size to be at least "
                f"{variance_size}, but found: {hidden_size}")
        # Slice the last axis -- the same axis that was validated above --
        # rather than assuming a rank-3 input. For 3-D tensors this is
        # identical to x[:, :, :variance_size].
        x_var = x[..., :variance_size]

    # x_var and residual are each passed twice.
    # NOTE(review): presumably the trailing pair are the kernel's output
    # buffers (i.e. the op works in place) -- confirm against the
    # torch.vacc kernel signature.
    return torch.vacc.fused_residual_rmsnorm(
        x_var, self.weight, residual, self.variance_epsilon, x_var, residual)
|