diff --git a/python/sglang/srt/layers/activation.py b/python/sglang/srt/layers/activation.py index 5df387cb2..9047197af 100644 --- a/python/sglang/srt/layers/activation.py +++ b/python/sglang/srt/layers/activation.py @@ -18,7 +18,7 @@ from typing import Optional import torch import torch.nn as nn import torch.nn.functional as F -from flashinfer.activation import gelu_tanh_and_mul, silu_and_mul +from flashinfer.activation import gelu_and_mul, gelu_tanh_and_mul, silu_and_mul from vllm.distributed import ( divide, get_tensor_model_parallel_rank, @@ -43,18 +43,24 @@ class SiluAndMul(CustomOp): class GeluAndMul(CustomOp): - def __init__(self, **kwargs): + def __init__(self, approximate="tanh"): super().__init__() + self.approximate = approximate def forward_native(self, x: torch.Tensor) -> torch.Tensor: d = x.shape[-1] // 2 - return F.gelu(x[..., :d], approximate="tanh") * x[..., d:] + return F.gelu(x[..., :d], approximate=self.approximate) * x[..., d:] def forward_cuda(self, x: torch.Tensor) -> torch.Tensor: d = x.shape[-1] // 2 output_shape = x.shape[:-1] + (d,) out = torch.empty(output_shape, dtype=x.dtype, device=x.device) - gelu_tanh_and_mul(x, out) + if self.approximate == "tanh": + gelu_tanh_and_mul(x, out) + elif self.approximate == "none": + gelu_and_mul(x, out) + else: + raise RuntimeError("GeluAndMul only support tanh or none") return out diff --git a/python/sglang/srt/models/gemma.py b/python/sglang/srt/models/gemma.py index 990937f51..ae3b1b194 100644 --- a/python/sglang/srt/models/gemma.py +++ b/python/sglang/srt/models/gemma.py @@ -23,7 +23,6 @@ from torch import nn from transformers import PretrainedConfig from vllm.config import CacheConfig, LoRAConfig from vllm.distributed import get_tensor_model_parallel_world_size -from vllm.model_executor.layers.activation import GeluAndMul from vllm.model_executor.layers.linear import ( MergedColumnParallelLinear, QKVParallelLinear, @@ -34,6 +33,7 @@ from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from sglang.srt.layers.activation import GeluAndMul from sglang.srt.layers.layernorm import RMSNorm from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.radix_attention import RadixAttention @@ -60,7 +60,7 @@ class GemmaMLP(nn.Module): bias=False, quant_config=quant_config, ) - self.act_fn = GeluAndMul() + self.act_fn = GeluAndMul("none") def forward(self, x): gate_up, _ = self.gate_up_proj(x) diff --git a/test/srt/models/test_generation_models.py b/test/srt/models/test_generation_models.py index e38584741..08288c510 100644 --- a/test/srt/models/test_generation_models.py +++ b/test/srt/models/test_generation_models.py @@ -96,7 +96,7 @@ class TestGenerationModels(unittest.TestCase): if hf_logprobs.shape[0] <= 100: assert torch.all( abs(hf_logprobs - srt_logprobs) < prefill_tolerance - ), "prefill logprobs are not all close" + ), f"prefill logprobs are not all close with model_path={model_path} prompts={prompts} prefill_tolerance={prefill_tolerance}" print(f"hf_outputs.output_strs={hf_outputs.output_strs}") print(f"srt_outputs.output_strs={srt_outputs.output_strs}")