From c9064e6fd9a5356ee579e9d452bfad725f8e6f2c Mon Sep 17 00:00:00 2001
From: Yineng Zhang
Date: Sat, 24 Aug 2024 18:58:16 +1000
Subject: [PATCH] feat: use gelu_tanh_and_mul (#1193)

---
 python/sglang/srt/layers/activation.py | 18 ++++++++-
 python/sglang/srt/models/gemma2.py     |  4 +-
 python/sglang/test/test_activation.py  | 55 ++++++++++++++++++++++++++
 3 files changed, 74 insertions(+), 3 deletions(-)
 create mode 100644 python/sglang/test/test_activation.py

diff --git a/python/sglang/srt/layers/activation.py b/python/sglang/srt/layers/activation.py
index a6f05610b..d0e062660 100644
--- a/python/sglang/srt/layers/activation.py
+++ b/python/sglang/srt/layers/activation.py
@@ -15,7 +15,7 @@ limitations under the License.
 
 import torch
 import torch.nn.functional as F
-from flashinfer.activation import silu_and_mul
+from flashinfer.activation import gelu_tanh_and_mul, silu_and_mul
 from vllm.model_executor.custom_op import CustomOp
 
 
@@ -37,3 +37,19 @@ class SiluAndMul(CustomOp):
         out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
         silu_and_mul(x, out)
         return out
+
+
+class GeluAndMul(CustomOp):
+    def __init__(self, **kwargs):
+        super().__init__()
+
+    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
+        d = x.shape[-1] // 2
+        return F.gelu(x[..., :d], approximate="tanh") * x[..., d:]
+
+    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
+        d = x.shape[-1] // 2
+        output_shape = x.shape[:-1] + (d,)
+        out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
+        gelu_tanh_and_mul(x, out)
+        return out
diff --git a/python/sglang/srt/models/gemma2.py b/python/sglang/srt/models/gemma2.py
index 80b99742e..37d926c34 100644
--- a/python/sglang/srt/models/gemma2.py
+++ b/python/sglang/srt/models/gemma2.py
@@ -25,7 +25,6 @@ from vllm.distributed import get_tensor_model_parallel_world_size
 
 # FIXME: temporary solution, remove after next vllm release
 from vllm.model_executor.custom_op import CustomOp
-from vllm.model_executor.layers.activation import GeluAndMul
 
 # from vllm.model_executor.layers.layernorm import GemmaRMSNorm
 from vllm.model_executor.layers.linear import (
@@ -39,6 +38,7 @@ from vllm.model_executor.layers.quantization.base_config import QuantizationConf
 from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 
+from sglang.srt.layers.activation import GeluAndMul
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.model_executor.forward_batch_info import InputMetadata
@@ -135,7 +135,7 @@ class Gemma2MLP(nn.Module):
             "function. Please set `hidden_act` and `hidden_activation` to "
             "`gelu_pytorch_tanh`."
         )
-        self.act_fn = GeluAndMul(approximate="tanh")
+        self.act_fn = GeluAndMul()
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         gate_up, _ = self.gate_up_proj(x)
diff --git a/python/sglang/test/test_activation.py b/python/sglang/test/test_activation.py
new file mode 100644
index 000000000..357a23319
--- /dev/null
+++ b/python/sglang/test/test_activation.py
@@ -0,0 +1,55 @@
+import itertools
+import unittest
+
+import torch
+
+from sglang.srt.layers.activation import GeluAndMul
+
+
+class TestGeluAndMul(unittest.TestCase):
+    DTYPES = [torch.half, torch.bfloat16]
+    NUM_TOKENS = [7, 83, 2048]
+    D = [512, 4096, 5120, 13824]
+    SEEDS = [0]
+
+    @classmethod
+    def setUpClass(cls):
+        if not torch.cuda.is_available():
+            raise unittest.SkipTest("CUDA is not available")
+        torch.set_default_device("cuda")
+
+    def _run_gelu_and_mul_test(self, num_tokens, d, dtype, seed):
+        torch.manual_seed(seed)
+
+        layer = GeluAndMul().to(dtype=dtype)
+        x = torch.randn(num_tokens, 2 * d, dtype=dtype)
+
+        with torch.inference_mode():
+            ref_out = layer.forward_native(x)
+            out = layer.forward_cuda(x)
+
+        if dtype == torch.bfloat16:
+            atol = rtol = 1e-2
+        else:
+            atol = rtol = 1e-3
+
+        self.assertTrue(torch.allclose(out, ref_out, atol=atol, rtol=rtol))
+
+    def test_gelu_and_mul(self):
+        for params in itertools.product(
+            self.NUM_TOKENS,
+            self.D,
+            self.DTYPES,
+            self.SEEDS,
+        ):
+            with self.subTest(
+                num_tokens=params[0],
+                d=params[1],
+                dtype=params[2],
+                seed=params[3],
+            ):
+                self._run_gelu_and_mul_test(*params)
+
+
+if __name__ == "__main__":
+    unittest.main(verbosity=2)
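
Usage sketch (not part of the patch): a minimal standalone check of the new layer, mirroring the added unit test. It assumes a CUDA device with flashinfer installed; the shape and tolerances below are illustrative, not prescribed by this change.

    import torch

    from sglang.srt.layers.activation import GeluAndMul

    torch.set_default_device("cuda")

    layer = GeluAndMul().to(dtype=torch.float16)
    x = torch.randn(8, 2 * 4096, dtype=torch.float16)  # last dim is 2*d; the op returns d features

    with torch.inference_mode():
        ref = layer.forward_native(x)  # pure PyTorch: F.gelu(..., approximate="tanh") * gate
        out = layer.forward_cuda(x)    # flashinfer gelu_tanh_and_mul kernel

    assert torch.allclose(out, ref, atol=1e-3, rtol=1e-3)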