feat: update GemmaRMSNorm (#1232)

@@ -19,7 +19,12 @@ from typing import Optional, Tuple, Union
 
 import torch
 import torch.nn as nn
-from flashinfer.norm import fused_add_rmsnorm, rmsnorm
+from flashinfer.norm import (
+    fused_add_rmsnorm,
+    gemma_fused_add_rmsnorm,
+    gemma_rmsnorm,
+    rmsnorm,
+)
 from vllm.model_executor.custom_op import CustomOp
 
 
@@ -63,3 +68,44 @@ class RMSNorm(CustomOp):
             return x
         else:
             return x, residual
+
+
+class GemmaRMSNorm(CustomOp):
+    def __init__(
+        self,
+        hidden_size: int,
+        eps: float = 1e-6,
+    ) -> None:
+        super().__init__()
+        self.weight = nn.Parameter(torch.zeros(hidden_size))
+        self.variance_epsilon = eps
+
+    def forward_native(
+        self,
+        x: torch.Tensor,
+        residual: Optional[torch.Tensor] = None,
+    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
+        orig_dtype = x.dtype
+        if residual is not None:
+            x = x + residual
+            residual = x
+
+        x = x.float()
+        variance = x.pow(2).mean(dim=-1, keepdim=True)
+        x = x * torch.rsqrt(variance + self.variance_epsilon)
+        x = x * (1.0 + self.weight.float())
+        x = x.to(orig_dtype)
+        return x if residual is None else (x, residual)
+
+    def forward_cuda(
+        self,
+        x: torch.Tensor,
+        residual: Optional[torch.Tensor] = None,
+    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
+        if residual is not None:
+            gemma_fused_add_rmsnorm(
+                x, residual, self.weight.data, self.variance_epsilon
+            )
+            return x, residual
+        out = gemma_rmsnorm(x, self.weight.data, self.variance_epsilon)
+        return out
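
In the CUDA path above, gemma_fused_add_rmsnorm is called for its side effects only: no return value is captured, which indicates that, like the fused_add_rmsnorm kernel already imported for RMSNorm, it updates x and residual in place before the method returns the same tensors. Below is a minimal PyTorch-only sketch of the semantics this path is expected to match; the helper name is illustrative and not part of the diff.

import torch

def gemma_fused_add_rmsnorm_reference(x, residual, weight, eps):
    # residual <- x + residual, then x <- Gemma-style RMSNorm of that sum,
    # both updated in place, mirroring forward_native when a residual is given.
    orig_dtype = x.dtype
    residual.add_(x)
    h = residual.float()
    h = h * torch.rsqrt(h.pow(2).mean(dim=-1, keepdim=True) + eps)
    h = h * (1.0 + weight.float())
    x.copy_(h.to(orig_dtype))
    return x, residual
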
@@ -22,11 +22,6 @@ from torch import nn
 from transformers import PretrainedConfig
 from vllm.config import CacheConfig, LoRAConfig
 from vllm.distributed import get_tensor_model_parallel_world_size
-
-# FIXME: temporary solution, remove after next vllm release
-from vllm.model_executor.custom_op import CustomOp
-
-# from vllm.model_executor.layers.layernorm import GemmaRMSNorm
 from vllm.model_executor.layers.linear import (
     MergedColumnParallelLinear,
     QKVParallelLinear,
@@ -39,6 +34,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmb
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 
 from sglang.srt.layers.activation import GeluAndMul
+from sglang.srt.layers.layernorm import GemmaRMSNorm
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.model_executor.forward_batch_info import InputMetadata
@@ -50,52 +46,6 @@ def get_attention_sliding_window_size(config):
     return config.sliding_window - 1
 
 
-class GemmaRMSNorm(CustomOp):
-    """RMS normalization for Gemma.
-
-    Two differences from the above RMSNorm:
-        1. x * (1 + w) instead of x * w.
-        2. (x * w).to(orig_dtype) instead of x.to(orig_dtype) * w.
-    """
-
-    def __init__(
-        self,
-        hidden_size: int,
-        eps: float = 1e-6,
-    ) -> None:
-        super().__init__()
-        self.weight = nn.Parameter(torch.zeros(hidden_size))
-        self.variance_epsilon = eps
-
-    def forward_native(
-        self,
-        x: torch.Tensor,
-        residual: Optional[torch.Tensor] = None,
-    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
-        """PyTorch-native implementation equivalent to forward()."""
-        orig_dtype = x.dtype
-        if residual is not None:
-            x = x + residual
-            residual = x
-
-        x = x.float()
-        variance = x.pow(2).mean(dim=-1, keepdim=True)
-        x = x * torch.rsqrt(variance + self.variance_epsilon)
-        # Llama does x.to(float16) * w whilst Gemma is (x * w).to(float16)
-        # See https://github.com/huggingface/transformers/pull/29402
-        x = x * (1.0 + self.weight.float())
-        x = x.to(orig_dtype)
-        return x if residual is None else (x, residual)
-
-    def forward_cuda(
-        self,
-        x: torch.Tensor,
-        residual: Optional[torch.Tensor] = None,
-    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
-        # from vLLM: TODO(woosuk): Implement an optimized kernel for GemmaRMSNorm.
-        return self.forward_native(x, residual)
-
-
 # FIXME: temporary solution, remove after next vllm release
 from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding
 
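
The docstring of the removed class calls out the two ways GemmaRMSNorm differs from the standard RMSNorm, and both matter once activations are stored in half precision. A small standalone illustration follows; the values and variable names are chosen for demonstration and are not part of the diff.

import torch

x = torch.randn(4, 8, dtype=torch.float32)
w = torch.zeros(8)  # GemmaRMSNorm initializes its weight at zero

# 1. x * (1 + w): with a zero-initialized weight the layer starts out as a
#    pure RMS normalization, whereas x * w would start by zeroing the output.
gemma_scaled = x * (1.0 + w)

# 2. Cast order: Gemma multiplies in float32 and casts the product once at
#    the end, while the Llama-style norm casts x down before multiplying by w.
gemma_out = (x * (1.0 + w)).to(torch.float16)
llama_style_out = x.to(torch.float16) * (1.0 + w).to(torch.float16)
# The two results may differ in the low-order bits because rounding to
# float16 happens at different points in the computation.
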
@@ -3,7 +3,7 @@ import unittest
 
 import torch
 
-from sglang.srt.layers.layernorm import RMSNorm
+from sglang.srt.layers.layernorm import GemmaRMSNorm, RMSNorm
 
 
 class TestRMSNorm(unittest.TestCase):
@@ -56,5 +56,57 @@ class TestRMSNorm(unittest.TestCase):
                 self._run_rms_norm_test(*params)
 
 
+class TestGemmaRMSNorm(unittest.TestCase):
+    DTYPES = [torch.half, torch.bfloat16]
+    NUM_TOKENS = [7, 83, 4096]
+    HIDDEN_SIZES = [768, 769, 770, 771, 5120, 5124, 5125, 5126, 8192, 8199]
+    ADD_RESIDUAL = [False, True]
+    SEEDS = [0]
+
+    @classmethod
+    def setUpClass(cls):
+        if not torch.cuda.is_available():
+            raise unittest.SkipTest("CUDA is not available")
+        torch.set_default_device("cuda")
+
+    def _run_gemma_rms_norm_test(
+        self, num_tokens, hidden_size, add_residual, dtype, seed
+    ):
+        torch.manual_seed(seed)
+
+        layer = GemmaRMSNorm(hidden_size).to(dtype=dtype)
+        layer.weight.data.normal_(mean=1.0, std=0.1)
+        scale = 1 / (2 * hidden_size)
+        x = torch.randn(num_tokens, hidden_size, dtype=dtype) * scale
+        residual = torch.randn_like(x) * scale if add_residual else None
+
+        with torch.inference_mode():
+            ref_out = layer.forward_native(x, residual)
+            out = layer(x, residual)
+
+        if add_residual:
+            self.assertTrue(torch.allclose(out[0], ref_out[0], atol=1e-3, rtol=1e-3))
+            self.assertTrue(torch.allclose(out[1], ref_out[1], atol=1e-3, rtol=1e-3))
+        else:
+            self.assertTrue(torch.allclose(out, ref_out, atol=1e-3, rtol=1e-3))
+
+    def test_gemma_rms_norm(self):
+        for params in itertools.product(
+            self.NUM_TOKENS,
+            self.HIDDEN_SIZES,
+            self.ADD_RESIDUAL,
+            self.DTYPES,
+            self.SEEDS,
+        ):
+            with self.subTest(
+                num_tokens=params[0],
+                hidden_size=params[1],
+                add_residual=params[2],
+                dtype=params[3],
+                seed=params[4],
+            ):
+                self._run_gemma_rms_norm_test(*params)
+
+
 if __name__ == "__main__":
     unittest.main(verbosity=2)