diff --git a/python/sglang/srt/layers/activation.py b/python/sglang/srt/layers/activation.py
index 82c39c2ac..1b8da93fe 100644
--- a/python/sglang/srt/layers/activation.py
+++ b/python/sglang/srt/layers/activation.py
@@ -14,6 +14,7 @@
 """Fused operators for activation layers."""
 
 import logging
+import math
 from typing import Optional
 
 import torch
@@ -72,6 +73,16 @@ class GeluAndMul(CustomOp):
         return out
 
 
+class NewGELU(CustomOp):
+    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
+        c = math.sqrt(2.0 / math.pi)
+        return 0.5 * x * (1.0 + torch.tanh(c * (x + 0.044715 * torch.pow(x, 3.0))))
+
+    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
+        # TODO: Implement the CUDA kernel for NewGELU in sgl-kernel
+        return self.forward_native(x)
+
+
 class QuickGELU(CustomOp):
     def forward_native(self, x: torch.Tensor) -> torch.Tensor:
         return x * torch.sigmoid(1.702 * x)
diff --git a/python/sglang/srt/models/gpt2.py b/python/sglang/srt/models/gpt2.py
index 04c3005ce..c9b78e6f6 100644
--- a/python/sglang/srt/models/gpt2.py
+++ b/python/sglang/srt/models/gpt2.py
@@ -17,14 +17,14 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """Inference-only GPT-2 model compatible with HuggingFace weights."""
-from typing import Iterable, Optional, Tuple
+from typing import Iterable, Optional, Tuple, Type
 
 import torch
 from torch import nn
 from transformers import GPT2Config
 
 from sglang.srt.distributed.parallel_state import get_tensor_model_parallel_world_size
-from sglang.srt.layers.activation import get_act_fn
+from sglang.srt.layers.activation import NewGELU
 from sglang.srt.layers.linear import (
     ColumnParallelLinear,
     QKVParallelLinear,
@@ -97,6 +97,7 @@ class GPT2MLP(nn.Module):
         self,
         intermediate_size: int,
         config: GPT2Config,
+        act_layer: Type[nn.Module] = NewGELU,
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
     ):
@@ -116,9 +117,7 @@ class GPT2MLP(nn.Module):
             quant_config=quant_config,
             prefix=f"{prefix}.c_proj",
         )
-        self.act = get_act_fn(
-            config.activation_function, quant_config, intermediate_size
-        )
+        self.act = act_layer()
 
     def forward(
         self,
@@ -136,6 +135,7 @@ class GPT2Block(nn.Module):
         self,
         layer_id: int,
         config: GPT2Config,
+        act_layer: Type[nn.Module] = NewGELU,
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
     ):
@@ -148,7 +148,13 @@ class GPT2Block(nn.Module):
             layer_id, config, quant_config, prefix=f"{prefix}.attn"
         )
         self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
-        self.mlp = GPT2MLP(inner_dim, config, quant_config, prefix=f"{prefix}.mlp")
+        self.mlp = GPT2MLP(
+            inner_dim,
+            config,
+            act_layer=act_layer,
+            quant_config=quant_config,
+            prefix=f"{prefix}.mlp",
+        )
 
     def forward(
         self,
@@ -190,7 +196,7 @@ class GPT2Model(nn.Module):
         self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim)
         self.h = nn.ModuleList(
             [
-                GPT2Block(i, config, quant_config)
+                GPT2Block(i, config, quant_config=quant_config)
                 for i in range(config.num_hidden_layers)
             ]
        )
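
For reference (not part of the patch): the `forward_native` formula above is the standard tanh approximation of GELU, the same "new GELU" that HuggingFace's GPT-2 uses by default (`activation_function="gelu_new"`). A minimal sanity check, assuming recent PyTorch where `F.gelu` accepts `approximate="tanh"`, is sketched below.

```python
import math

import torch
import torch.nn.functional as F


def new_gelu(x: torch.Tensor) -> torch.Tensor:
    # Same expression as NewGELU.forward_native in the patch above.
    c = math.sqrt(2.0 / math.pi)
    return 0.5 * x * (1.0 + torch.tanh(c * (x + 0.044715 * torch.pow(x, 3.0))))


x = torch.randn(4, 8, dtype=torch.float32)
# PyTorch's built-in tanh-approximated GELU implements the same formula,
# so the two should agree to within floating-point tolerance.
torch.testing.assert_close(new_gelu(x), F.gelu(x, approximate="tanh"))
```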