feat: use FlashInfer rmsnorm and silu (#907)
This commit is contained in:
29
python/sglang/srt/layers/activation.py
Normal file
29
python/sglang/srt/layers/activation.py
Normal file
@@ -0,0 +1,29 @@
|
|||||||
|
"""
|
||||||
|
Copyright 2023-2024 SGLang Team
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
from flashinfer.activation import silu_and_mul
|
||||||
|
|
||||||
|
|
||||||
|
class SiluAndMul(nn.Module):
    """SiLU-and-multiply gated activation (SwiGLU-style gate).

    The input's last dimension is split in half: the first half is the gate
    (passed through SiLU) and the second half is the linear branch; the
    output is their elementwise product, so the output's last dimension is
    half of the input's.
    """

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        """Pure-PyTorch reference implementation (used for testing).

        Fix: the original referenced an undefined name ``F`` (this module
        never imports ``torch.nn.functional as F``); use the functional API
        through the already-imported ``nn`` module instead.
        """
        d = x.shape[-1] // 2
        return nn.functional.silu(x[..., :d]) * x[..., d:]

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """FlashInfer fused silu-and-mul kernel."""
        d = x.shape[-1] // 2
        output_shape = x.shape[:-1] + (d,)
        # Pre-allocate the output buffer; the FlashInfer kernel writes into it.
        out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
        silu_and_mul(x, out)
        return out
|
||||||
62
python/sglang/srt/layers/layernorm.py
Normal file
62
python/sglang/srt/layers/layernorm.py
Normal file
@@ -0,0 +1,62 @@
|
|||||||
|
"""
|
||||||
|
Copyright 2023-2024 SGLang Team
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from typing import Optional, Tuple, Union
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
from flashinfer.norm import fused_add_rmsnorm, rmsnorm
|
||||||
|
|
||||||
|
|
||||||
|
class RMSNorm(nn.Module):
    """Root-mean-square layer normalization with an optional fused residual add.

    ``forward`` dispatches to FlashInfer kernels; ``forward_native`` is the
    pure-PyTorch reference used for testing.
    """

    def __init__(
        self,
        hidden_size: int,
        eps: float = 1e-6,
    ) -> None:
        super().__init__()
        # Learnable per-channel scale, initialized to the identity.
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(
        self,
        x: torch.Tensor,
        residual: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        # No residual: plain RMSNorm through the FlashInfer kernel.
        if residual is None:
            return rmsnorm(x, self.weight.data, self.variance_epsilon)
        # With a residual, the fused kernel updates `x` and `residual`
        # in place, so the (mutated) inputs are returned directly.
        fused_add_rmsnorm(x, residual, self.weight.data, self.variance_epsilon)
        return x, residual

    def forward_native(
        self,
        x: torch.Tensor,
        residual: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """Reference implementation: computed in float32, cast back at the end."""
        input_dtype = x.dtype
        hidden = x.to(torch.float32)
        if residual is not None:
            hidden = hidden + residual.to(torch.float32)
            # The post-add activation (in the caller's dtype) becomes the
            # residual stream handed back to the caller.
            residual = hidden.to(input_dtype)

        mean_sq = hidden.pow(2).mean(dim=-1, keepdim=True)
        hidden = hidden * torch.rsqrt(mean_sq + self.variance_epsilon)
        normed = hidden.to(input_dtype) * self.weight
        return normed if residual is None else (normed, residual)
|
||||||
@@ -23,8 +23,6 @@ from torch import nn
|
|||||||
from transformers import PretrainedConfig
|
from transformers import PretrainedConfig
|
||||||
from vllm.config import CacheConfig
|
from vllm.config import CacheConfig
|
||||||
from vllm.distributed import get_tensor_model_parallel_world_size
|
from vllm.distributed import get_tensor_model_parallel_world_size
|
||||||
from vllm.model_executor.layers.activation import SiluAndMul
|
|
||||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
|
||||||
from vllm.model_executor.layers.linear import (
|
from vllm.model_executor.layers.linear import (
|
||||||
MergedColumnParallelLinear,
|
MergedColumnParallelLinear,
|
||||||
QKVParallelLinear,
|
QKVParallelLinear,
|
||||||
@@ -38,13 +36,14 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
|
|||||||
)
|
)
|
||||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||||
|
|
||||||
|
from sglang.srt.layers.activation import SiluAndMul
|
||||||
|
from sglang.srt.layers.layernorm import RMSNorm
|
||||||
from sglang.srt.layers.logits_processor import LogitsProcessor
|
from sglang.srt.layers.logits_processor import LogitsProcessor
|
||||||
from sglang.srt.layers.radix_attention import RadixAttention
|
from sglang.srt.layers.radix_attention import RadixAttention
|
||||||
from sglang.srt.model_executor.forward_batch_info import InputMetadata
|
from sglang.srt.model_executor.forward_batch_info import InputMetadata
|
||||||
|
|
||||||
|
|
||||||
class InternLM2MLP(nn.Module):
|
class InternLM2MLP(nn.Module):
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
hidden_size: int,
|
hidden_size: int,
|
||||||
@@ -74,7 +73,6 @@ class InternLM2MLP(nn.Module):
|
|||||||
|
|
||||||
|
|
||||||
class InternLM2Attention(nn.Module):
|
class InternLM2Attention(nn.Module):
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
hidden_size: int,
|
hidden_size: int,
|
||||||
@@ -150,7 +148,6 @@ class InternLM2Attention(nn.Module):
|
|||||||
|
|
||||||
|
|
||||||
class InternLMDecoderLayer(nn.Module):
|
class InternLMDecoderLayer(nn.Module):
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
config: PretrainedConfig,
|
config: PretrainedConfig,
|
||||||
@@ -207,7 +204,6 @@ class InternLMDecoderLayer(nn.Module):
|
|||||||
|
|
||||||
|
|
||||||
class InternLM2Model(nn.Module):
|
class InternLM2Model(nn.Module):
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
config: PretrainedConfig,
|
config: PretrainedConfig,
|
||||||
@@ -254,7 +250,6 @@ class InternLM2Model(nn.Module):
|
|||||||
|
|
||||||
|
|
||||||
class InternLM2ForCausalLM(nn.Module):
|
class InternLM2ForCausalLM(nn.Module):
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
config: PretrainedConfig,
|
config: PretrainedConfig,
|
||||||
|
|||||||
@@ -24,8 +24,6 @@ from torch import nn
|
|||||||
from transformers import LlamaConfig
|
from transformers import LlamaConfig
|
||||||
from vllm.config import CacheConfig
|
from vllm.config import CacheConfig
|
||||||
from vllm.distributed import get_tensor_model_parallel_world_size
|
from vllm.distributed import get_tensor_model_parallel_world_size
|
||||||
from vllm.model_executor.layers.activation import SiluAndMul
|
|
||||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
|
||||||
from vllm.model_executor.layers.linear import (
|
from vllm.model_executor.layers.linear import (
|
||||||
MergedColumnParallelLinear,
|
MergedColumnParallelLinear,
|
||||||
QKVParallelLinear,
|
QKVParallelLinear,
|
||||||
@@ -39,6 +37,8 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
|
|||||||
)
|
)
|
||||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||||
|
|
||||||
|
from sglang.srt.layers.activation import SiluAndMul
|
||||||
|
from sglang.srt.layers.layernorm import RMSNorm
|
||||||
from sglang.srt.layers.logits_processor import LogitProcessorOutput, LogitsProcessor
|
from sglang.srt.layers.logits_processor import LogitProcessorOutput, LogitsProcessor
|
||||||
from sglang.srt.layers.radix_attention import RadixAttention
|
from sglang.srt.layers.radix_attention import RadixAttention
|
||||||
from sglang.srt.model_executor.forward_batch_info import InputMetadata
|
from sglang.srt.model_executor.forward_batch_info import InputMetadata
|
||||||
|
|||||||
@@ -384,7 +384,7 @@ def _set_envs_and_config(server_args: ServerArgs):
|
|||||||
if not server_args.disable_flashinfer:
|
if not server_args.disable_flashinfer:
|
||||||
assert_pkg_version(
|
assert_pkg_version(
|
||||||
"flashinfer",
|
"flashinfer",
|
||||||
"0.1.3",
|
"0.1.4",
|
||||||
"Please uninstall the old version and "
|
"Please uninstall the old version and "
|
||||||
"reinstall the latest version by following the instructions "
|
"reinstall the latest version by following the instructions "
|
||||||
"at https://docs.flashinfer.ai/installation.html.",
|
"at https://docs.flashinfer.ai/installation.html.",
|
||||||
|
|||||||
60
python/sglang/test/test_layernorm.py
Normal file
60
python/sglang/test/test_layernorm.py
Normal file
@@ -0,0 +1,60 @@
|
|||||||
|
import itertools
|
||||||
|
import unittest
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from sglang.srt.layers.layernorm import RMSNorm
|
||||||
|
|
||||||
|
|
||||||
|
class TestRMSNorm(unittest.TestCase):
    """Compares the FlashInfer-backed RMSNorm kernels against forward_native."""

    DTYPES = [torch.half, torch.bfloat16]
    NUM_TOKENS = [7, 83, 4096]
    HIDDEN_SIZES = [768, 769, 770, 771, 5120, 5124, 5125, 5126, 8192, 8199]
    ADD_RESIDUAL = [False, True]
    SEEDS = [0]

    @classmethod
    def setUpClass(cls):
        # The FlashInfer kernels require a GPU; skip the whole suite otherwise.
        if not torch.cuda.is_available():
            raise unittest.SkipTest("CUDA is not available")
        torch.set_default_device("cuda")

    def _run_rms_norm_test(self, num_tokens, hidden_size, add_residual, dtype, seed):
        torch.manual_seed(seed)

        layer = RMSNorm(hidden_size).to(dtype=dtype)
        layer.weight.data.normal_(mean=1.0, std=0.1)
        # Scale inputs down so reduced-precision dtypes stay well in range.
        scale = 1 / (2 * hidden_size)
        x = torch.randn(num_tokens, hidden_size, dtype=dtype) * scale
        residual = torch.randn_like(x) * scale if add_residual else None

        with torch.inference_mode():
            # Run the reference first: the fused kernel mutates x/residual in place.
            ref_out = layer.forward_native(x, residual)
            out = layer(x, residual)

        tol = dict(atol=1e-2, rtol=1e-2)
        if add_residual:
            self.assertTrue(torch.allclose(out[0], ref_out[0], **tol))
            self.assertTrue(torch.allclose(out[1], ref_out[1], **tol))
        else:
            self.assertTrue(torch.allclose(out, ref_out, **tol))

    def test_rms_norm(self):
        names = ("num_tokens", "hidden_size", "add_residual", "dtype", "seed")
        for combo in itertools.product(
            self.NUM_TOKENS,
            self.HIDDEN_SIZES,
            self.ADD_RESIDUAL,
            self.DTYPES,
            self.SEEDS,
        ):
            with self.subTest(**dict(zip(names, combo))):
                self._run_rms_norm_test(*combo)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
    # Allow running this file directly: `python test_layernorm.py`.
    unittest.main(verbosity=2)
|
||||||
Reference in New Issue
Block a user