feat: update GemmaRMSNorm (#1232)

This commit is contained in:
Yineng Zhang
2024-08-28 22:47:34 +10:00
committed by GitHub
parent 66975360e7
commit b1a540ec42
3 changed files with 101 additions and 53 deletions

View File

@@ -19,7 +19,12 @@ from typing import Optional, Tuple, Union
import torch
import torch.nn as nn
from flashinfer.norm import fused_add_rmsnorm, rmsnorm
from flashinfer.norm import (
fused_add_rmsnorm,
gemma_fused_add_rmsnorm,
gemma_rmsnorm,
rmsnorm,
)
from vllm.model_executor.custom_op import CustomOp
@@ -63,3 +68,44 @@ class RMSNorm(CustomOp):
return x
else:
return x, residual
class GemmaRMSNorm(CustomOp):
    """Gemma-style RMS normalization.

    Unlike a plain RMSNorm, the learned weight is stored as an offset
    from one: the applied scale is ``(1 + weight)``, so a zero-initialized
    parameter corresponds to the identity scale.
    """

    def __init__(
        self,
        hidden_size: int,
        eps: float = 1e-6,
    ) -> None:
        super().__init__()
        # Zero init == unit scale, because forward applies (1 + weight).
        self.weight = nn.Parameter(torch.zeros(hidden_size))
        self.variance_epsilon = eps

    def forward_native(
        self,
        x: torch.Tensor,
        residual: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """Reference (pure-PyTorch) path.

        If ``residual`` is given, it is added to ``x`` first and the
        pre-normalization sum (at the original dtype) is returned as the
        new residual alongside the normalized output.
        """
        input_dtype = x.dtype
        if residual is not None:
            x = x + residual
            # Capture the residual before the float32 upcast below.
            residual = x
        hidden = x.float()
        variance = hidden.pow(2).mean(dim=-1, keepdim=True)
        normed = hidden * torch.rsqrt(variance + self.variance_epsilon)
        normed = normed * (1.0 + self.weight.float())
        out = normed.to(input_dtype)
        if residual is None:
            return out
        return out, residual

    def forward_cuda(
        self,
        x: torch.Tensor,
        residual: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """Fused flashinfer path.

        In the residual case the fused kernel's return value is not
        captured and ``x``/``residual`` are returned directly, i.e. the
        kernel is expected to update them in place.
        """
        if residual is None:
            return gemma_rmsnorm(x, self.weight.data, self.variance_epsilon)
        gemma_fused_add_rmsnorm(
            x, residual, self.weight.data, self.variance_epsilon
        )
        return x, residual