feat: use FlashInfer rmsnorm and silu (#907)

Yineng Zhang
2024-08-11 12:57:13 +08:00
committed by GitHub
parent 43fbb6d919
commit 94752ac811
6 changed files with 156 additions and 10 deletions

View File

@@ -0,0 +1,29 @@
"""
Copyright 2023-2024 SGLang Team
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from flashinfer.activation import silu_and_mul


class SiluAndMul(nn.Module):
    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        # Pure PyTorch reference: SiLU on the first half of the last dim,
        # gated elementwise by the second half.
        d = x.shape[-1] // 2
        return F.silu(x[..., :d]) * x[..., d:]

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # FlashInfer fused kernel: writes SiLU(x[..., :d]) * x[..., d:]
        # into a preallocated output buffer.
        d = x.shape[-1] // 2
        output_shape = x.shape[:-1] + (d,)
        out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
        silu_and_mul(x, out)
        return out
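
A minimal usage sketch (illustrative, not part of the diff; assumes a CUDA device, a flashinfer install, and illustrative tolerances). The last dimension of the input holds the gate and up projections concatenated, so an input of width 2*d produces an output of width d, and the FlashInfer path should match the native reference:

import torch
from sglang.srt.layers.activation import SiluAndMul

layer = SiluAndMul()
x = torch.randn(16, 2 * 4096, dtype=torch.float16, device="cuda")
with torch.inference_mode():
    out = layer(x)                 # FlashInfer silu_and_mul kernel, shape (16, 4096)
    ref = layer.forward_native(x)  # pure PyTorch reference
assert torch.allclose(out, ref, atol=1e-2, rtol=1e-2)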

View File

@@ -0,0 +1,62 @@
"""
Copyright 2023-2024 SGLang Team
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
from typing import Optional, Tuple, Union

import torch
import torch.nn as nn
from flashinfer.norm import fused_add_rmsnorm, rmsnorm


class RMSNorm(nn.Module):
    def __init__(
        self,
        hidden_size: int,
        eps: float = 1e-6,
    ) -> None:
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(
        self,
        x: torch.Tensor,
        residual: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        if residual is not None:
            # FlashInfer fused kernel: adds the residual and applies RMSNorm,
            # updating both x and residual in place.
            fused_add_rmsnorm(x, residual, self.weight.data, self.variance_epsilon)
            return x, residual
        out = rmsnorm(x, self.weight.data, self.variance_epsilon)
        return out

    def forward_native(
        self,
        x: torch.Tensor,
        residual: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        # Pure PyTorch reference implementation, computed in float32.
        orig_dtype = x.dtype
        x = x.to(torch.float32)
        if residual is not None:
            x = x + residual.to(torch.float32)
            residual = x.to(orig_dtype)
        variance = x.pow(2).mean(dim=-1, keepdim=True)
        x = x * torch.rsqrt(variance + self.variance_epsilon)
        x = x.to(orig_dtype) * self.weight
        if residual is None:
            return x
        else:
            return x, residual
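
A short usage sketch (illustrative, not part of the diff; assumes CUDA and flashinfer). Without a residual the layer returns a single tensor; with one, the fused_add_rmsnorm path updates x and residual in place and returns them as a tuple:

import torch
from sglang.srt.layers.layernorm import RMSNorm

norm = RMSNorm(4096).to(device="cuda", dtype=torch.float16)
x = torch.randn(8, 4096, dtype=torch.float16, device="cuda")
residual = torch.randn_like(x)

out = norm(x)                              # plain RMSNorm, returns one tensor
hidden, new_residual = norm(x, residual)   # fused add + norm, returns a tuple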

View File

@@ -23,8 +23,6 @@ from torch import nn
from transformers import PretrainedConfig
from vllm.config import CacheConfig
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (
MergedColumnParallelLinear,
QKVParallelLinear,
@@ -38,13 +36,14 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from sglang.srt.layers.activation import SiluAndMul
from sglang.srt.layers.layernorm import RMSNorm
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.model_executor.forward_batch_info import InputMetadata
class InternLM2MLP(nn.Module):
def __init__(
self,
hidden_size: int,
@@ -74,7 +73,6 @@ class InternLM2MLP(nn.Module):
class InternLM2Attention(nn.Module):
def __init__(
self,
hidden_size: int,
@@ -150,7 +148,6 @@ class InternLM2Attention(nn.Module):
class InternLMDecoderLayer(nn.Module):
def __init__(
self,
config: PretrainedConfig,
@@ -207,7 +204,6 @@ class InternLMDecoderLayer(nn.Module):
class InternLM2Model(nn.Module):
def __init__(
self,
config: PretrainedConfig,
@@ -254,7 +250,6 @@ class InternLM2Model(nn.Module):
class InternLM2ForCausalLM(nn.Module):
def __init__(
self,
config: PretrainedConfig,

View File

@@ -24,8 +24,6 @@ from torch import nn
from transformers import LlamaConfig
from vllm.config import CacheConfig
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (
MergedColumnParallelLinear,
QKVParallelLinear,
@@ -39,6 +37,8 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from sglang.srt.layers.activation import SiluAndMul
from sglang.srt.layers.layernorm import RMSNorm
from sglang.srt.layers.logits_processor import LogitProcessorOutput, LogitsProcessor
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.model_executor.forward_batch_info import InputMetadata

View File

@@ -384,7 +384,7 @@ def _set_envs_and_config(server_args: ServerArgs):
if not server_args.disable_flashinfer:
assert_pkg_version(
"flashinfer",
"0.1.3",
"0.1.4",
"Please uninstall the old version and "
"reinstall the latest version by following the instructions "
"at https://docs.flashinfer.ai/installation.html.",
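
This hunk bumps the minimum FlashInfer version from 0.1.3 to 0.1.4. A rough sketch of what a minimum-version gate like this does (illustrative only; check_min_version is a hypothetical stand-in for sglang's assert_pkg_version helper and assumes the packaging library is available):

from importlib.metadata import version
from packaging.version import parse

def check_min_version(pkg: str, min_ver: str, message: str) -> None:
    # Raise if the installed package is older than the required minimum.
    if parse(version(pkg)) < parse(min_ver):
        raise RuntimeError(f"{pkg}>={min_ver} is required. {message}")

check_min_version("flashinfer", "0.1.4", "Please reinstall the latest version.")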

View File

@@ -0,0 +1,60 @@
import itertools
import unittest

import torch

from sglang.srt.layers.layernorm import RMSNorm


class TestRMSNorm(unittest.TestCase):
    DTYPES = [torch.half, torch.bfloat16]
    NUM_TOKENS = [7, 83, 4096]
    HIDDEN_SIZES = [768, 769, 770, 771, 5120, 5124, 5125, 5126, 8192, 8199]
    ADD_RESIDUAL = [False, True]
    SEEDS = [0]

    @classmethod
    def setUpClass(cls):
        if not torch.cuda.is_available():
            raise unittest.SkipTest("CUDA is not available")
        torch.set_default_device("cuda")

    def _run_rms_norm_test(self, num_tokens, hidden_size, add_residual, dtype, seed):
        torch.manual_seed(seed)

        layer = RMSNorm(hidden_size).to(dtype=dtype)
        layer.weight.data.normal_(mean=1.0, std=0.1)
        scale = 1 / (2 * hidden_size)
        x = torch.randn(num_tokens, hidden_size, dtype=dtype) * scale
        residual = torch.randn_like(x) * scale if add_residual else None

        with torch.inference_mode():
            ref_out = layer.forward_native(x, residual)
            out = layer(x, residual)

        if add_residual:
            self.assertTrue(torch.allclose(out[0], ref_out[0], atol=1e-2, rtol=1e-2))
            self.assertTrue(torch.allclose(out[1], ref_out[1], atol=1e-2, rtol=1e-2))
        else:
            self.assertTrue(torch.allclose(out, ref_out, atol=1e-2, rtol=1e-2))

    def test_rms_norm(self):
        for params in itertools.product(
            self.NUM_TOKENS,
            self.HIDDEN_SIZES,
            self.ADD_RESIDUAL,
            self.DTYPES,
            self.SEEDS,
        ):
            with self.subTest(
                num_tokens=params[0],
                hidden_size=params[1],
                add_residual=params[2],
                dtype=params[3],
                seed=params[4],
            ):
                self._run_rms_norm_test(*params)


if __name__ == "__main__":
    unittest.main(verbosity=2)