diff --git a/python/sglang/srt/layers/activation.py b/python/sglang/srt/layers/activation.py new file mode 100644 index 000000000..836884dca --- /dev/null +++ b/python/sglang/srt/layers/activation.py @@ -0,0 +1,29 @@ +""" +Copyright 2023-2024 SGLang Team +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +import torch +import torch.nn as nn +from flashinfer.activation import silu_and_mul + + +class SiluAndMul(nn.Module): + def forward_native(self, x: torch.Tensor) -> torch.Tensor: + d = x.shape[-1] // 2 + return F.silu(x[..., :d]) * x[..., d:] + + def forward(self, x: torch.Tensor) -> torch.Tensor: + d = x.shape[-1] // 2 + output_shape = x.shape[:-1] + (d,) + out = torch.empty(output_shape, dtype=x.dtype, device=x.device) + silu_and_mul(x, out) + return out diff --git a/python/sglang/srt/layers/layernorm.py b/python/sglang/srt/layers/layernorm.py new file mode 100644 index 000000000..e29993a4c --- /dev/null +++ b/python/sglang/srt/layers/layernorm.py @@ -0,0 +1,62 @@ +""" +Copyright 2023-2024 SGLang Team +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +from typing import Optional, Tuple, Union + +import torch +import torch.nn as nn +from flashinfer.norm import fused_add_rmsnorm, rmsnorm + + +class RMSNorm(nn.Module): + def __init__( + self, + hidden_size: int, + eps: float = 1e-6, + ) -> None: + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward( + self, + x: torch.Tensor, + residual: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: + + if residual is not None: + fused_add_rmsnorm(x, residual, self.weight.data, self.variance_epsilon) + return x, residual + out = rmsnorm(x, self.weight.data, self.variance_epsilon) + return out + + def forward_native( + self, + x: torch.Tensor, + residual: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: + orig_dtype = x.dtype + x = x.to(torch.float32) + if residual is not None: + x = x + residual.to(torch.float32) + residual = x.to(orig_dtype) + + variance = x.pow(2).mean(dim=-1, keepdim=True) + x = x * torch.rsqrt(variance + self.variance_epsilon) + x = x.to(orig_dtype) * self.weight + if residual is None: + return x + else: + return x, residual diff --git a/python/sglang/srt/models/internlm2.py b/python/sglang/srt/models/internlm2.py index 394d00504..f2947e991 100644 --- a/python/sglang/srt/models/internlm2.py +++ b/python/sglang/srt/models/internlm2.py @@ -23,8 +23,6 @@ from torch import nn from transformers import PretrainedConfig from vllm.config import CacheConfig from vllm.distributed import get_tensor_model_parallel_world_size -from vllm.model_executor.layers.activation import SiluAndMul -from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import ( MergedColumnParallelLinear, QKVParallelLinear, @@ -38,13 +36,14 @@ from vllm.model_executor.layers.vocab_parallel_embedding import ( ) from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from sglang.srt.layers.activation import SiluAndMul +from sglang.srt.layers.layernorm import RMSNorm from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.radix_attention import RadixAttention from sglang.srt.model_executor.forward_batch_info import InputMetadata class InternLM2MLP(nn.Module): - def __init__( self, hidden_size: int, @@ -74,7 +73,6 @@ class InternLM2MLP(nn.Module): class InternLM2Attention(nn.Module): - def __init__( self, hidden_size: int, @@ -150,7 +148,6 @@ class InternLM2Attention(nn.Module): class InternLMDecoderLayer(nn.Module): - def __init__( self, config: PretrainedConfig, @@ -207,7 +204,6 @@ class InternLMDecoderLayer(nn.Module): class InternLM2Model(nn.Module): - def __init__( self, config: PretrainedConfig, @@ -254,7 +250,6 @@ class InternLM2Model(nn.Module): class InternLM2ForCausalLM(nn.Module): - def __init__( self, config: PretrainedConfig, diff --git a/python/sglang/srt/models/llama2.py b/python/sglang/srt/models/llama2.py index 20f8970f7..9de8d33c5 100644 --- a/python/sglang/srt/models/llama2.py +++ b/python/sglang/srt/models/llama2.py @@ -24,8 +24,6 @@ from torch import nn from transformers import LlamaConfig from vllm.config import CacheConfig from vllm.distributed import get_tensor_model_parallel_world_size -from vllm.model_executor.layers.activation import SiluAndMul -from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import ( MergedColumnParallelLinear, QKVParallelLinear, @@ -39,6 +37,8 @@ from vllm.model_executor.layers.vocab_parallel_embedding import ( ) from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from sglang.srt.layers.activation import SiluAndMul +from sglang.srt.layers.layernorm import RMSNorm from sglang.srt.layers.logits_processor import LogitProcessorOutput, LogitsProcessor from sglang.srt.layers.radix_attention import RadixAttention from sglang.srt.model_executor.forward_batch_info import InputMetadata diff --git a/python/sglang/srt/server.py b/python/sglang/srt/server.py index 269aed66f..8b6766335 100644 --- a/python/sglang/srt/server.py +++ b/python/sglang/srt/server.py @@ -384,7 +384,7 @@ def _set_envs_and_config(server_args: ServerArgs): if not server_args.disable_flashinfer: assert_pkg_version( "flashinfer", - "0.1.3", + "0.1.4", "Please uninstall the old version and " "reinstall the latest version by following the instructions " "at https://docs.flashinfer.ai/installation.html.", diff --git a/python/sglang/test/test_layernorm.py b/python/sglang/test/test_layernorm.py new file mode 100644 index 000000000..ab61aa804 --- /dev/null +++ b/python/sglang/test/test_layernorm.py @@ -0,0 +1,60 @@ +import itertools +import unittest + +import torch + +from sglang.srt.layers.layernorm import RMSNorm + + +class TestRMSNorm(unittest.TestCase): + DTYPES = [torch.half, torch.bfloat16] + NUM_TOKENS = [7, 83, 4096] + HIDDEN_SIZES = [768, 769, 770, 771, 5120, 5124, 5125, 5126, 8192, 8199] + ADD_RESIDUAL = [False, True] + SEEDS = [0] + + @classmethod + def setUpClass(cls): + if not torch.cuda.is_available(): + raise unittest.SkipTest("CUDA is not available") + torch.set_default_device("cuda") + + def _run_rms_norm_test(self, num_tokens, hidden_size, add_residual, dtype, seed): + torch.manual_seed(seed) + + layer = RMSNorm(hidden_size).to(dtype=dtype) + layer.weight.data.normal_(mean=1.0, std=0.1) + scale = 1 / (2 * hidden_size) + x = torch.randn(num_tokens, hidden_size, dtype=dtype) * scale + residual = torch.randn_like(x) * scale if add_residual else None + + with torch.inference_mode(): + ref_out = layer.forward_native(x, residual) + out = layer(x, residual) + + if add_residual: + self.assertTrue(torch.allclose(out[0], ref_out[0], atol=1e-2, rtol=1e-2)) + self.assertTrue(torch.allclose(out[1], ref_out[1], atol=1e-2, rtol=1e-2)) + else: + self.assertTrue(torch.allclose(out, ref_out, atol=1e-2, rtol=1e-2)) + + def test_rms_norm(self): + for params in itertools.product( + self.NUM_TOKENS, + self.HIDDEN_SIZES, + self.ADD_RESIDUAL, + self.DTYPES, + self.SEEDS, + ): + with self.subTest( + num_tokens=params[0], + hidden_size=params[1], + add_residual=params[2], + dtype=params[3], + seed=params[4], + ): + self._run_rms_norm_test(*params) + + +if __name__ == "__main__": + unittest.main(verbosity=2)