from typing import Optional, Tuple, Union

import torch


def RMSNorm_forward_vacc(
    self,
    x: torch.Tensor,
    residual: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
    """Run the RMSNorm forward pass through the fused VACC kernel.

    The last dimension of ``x`` is validated against ``self.hidden_size``.
    When ``self.variance_size_override`` is set, only the leading
    ``variance_size_override`` entries along the third axis of ``x`` are
    handed to the kernel.

    Args:
        x: Input tensor. The override branch indexes it with three
            subscripts, so it needs at least three dimensions there.
        residual: Optional residual tensor, forwarded to the kernel as-is
            (``None`` is passed through unchanged).

    Returns:
        Whatever ``torch.vacc.fused_residual_rmsnorm`` returns.
        NOTE(review): the annotation suggests a tensor, or a
        ``(output, residual)`` tuple when a residual is given -- confirm
        against the kernel's documentation.

    Raises:
        ValueError: If ``x.shape[-1]`` differs from ``self.hidden_size``,
            or is smaller than ``self.variance_size_override``.
    """
    hidden_size = x.shape[-1]
    if hidden_size != self.hidden_size:
        raise ValueError("Expected hidden_size to be "
                         f"{self.hidden_size}, but found: {hidden_size}")

    # Default to the full input; narrow it only when an override is set.
    kernel_input = x
    if self.variance_size_override is not None:
        if hidden_size < self.variance_size_override:
            raise ValueError(
                "Expected hidden_size to be at least "
                f"{self.variance_size_override}, but found: {hidden_size}")
        kernel_input = x[:, :, :self.variance_size_override]

    # The input and residual are each passed twice (arguments 5 and 6 repeat
    # arguments 1 and 3). Presumably the trailing pair are output buffers,
    # making the op in-place -- TODO confirm against the torch.vacc API.
    return torch.vacc.fused_residual_rmsnorm(
        kernel_input,
        self.weight,
        residual,
        self.variance_epsilon,
        kernel_input,
        residual,
    )