Files

33 lines
1.2 KiB
Python
Raw Permalink Normal View History

2026-04-02 04:53:13 +00:00
from typing import Optional, Tuple, Union
import torch
def RMSNorm_forward_vacc(
    self,
    x: torch.Tensor,
    residual: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
    """Run RMSNorm through the fused VACC residual/RMSNorm kernel.

    Checks the last dimension of ``x`` against ``self.hidden_size``,
    optionally narrows ``x`` to its first ``self.variance_size_override``
    features, and passes the result to
    ``torch.vacc.fused_residual_rmsnorm``.

    Args:
        x: Input tensor; its last dimension must equal
            ``self.hidden_size``. Any number of leading dimensions is
            accepted by the checks and the slicing done here.
        residual: Optional residual tensor, forwarded to the kernel as-is.

    Returns:
        The value returned by ``torch.vacc.fused_residual_rmsnorm``.
        NOTE(review): the annotation suggests a single tensor when
        ``residual`` is None and an ``(output, residual)`` pair otherwise;
        confirm against the kernel's documentation.

    Raises:
        ValueError: If the last dimension of ``x`` differs from
            ``self.hidden_size``, or is smaller than
            ``self.variance_size_override`` when an override is set.
    """
    hidden_size = x.shape[-1]
    if hidden_size != self.hidden_size:
        raise ValueError("Expected hidden_size to be "
                         f"{self.hidden_size}, but found: {hidden_size}")

    override = self.variance_size_override
    if override is None:
        x_var = x
    else:
        if hidden_size < override:
            raise ValueError(
                "Expected hidden_size to be at least "
                f"{override}, but found: {hidden_size}")
        # Slice the last axis only. The ellipsis keeps this valid for any
        # input rank, where a fixed `[:, :, :n]` index would reject
        # anything that is not exactly 3-D.
        x_var = x[..., :override]

    # No manual `x + residual` is done here; `residual` is handed to the
    # fused kernel instead.
    # NOTE(review): `x_var` and `residual` are each passed twice; the
    # trailing pair looks like output buffers (in-place write) -- confirm
    # against the kernel signature.
    out = torch.vacc.fused_residual_rmsnorm(
        x_var,
        self.weight,
        residual,
        self.variance_epsilon,
        x_var,
        residual,
    )
    return out