fix: support gelu_new activation function in gpt2 (#3712)

This commit is contained in:
Xiuyu Li
2025-03-04 04:09:52 -08:00
committed by GitHub
parent 37373ef2bb
commit 9545bfb28a
2 changed files with 24 additions and 7 deletions

View File

@@ -14,6 +14,7 @@
"""Fused operators for activation layers.""" """Fused operators for activation layers."""
import logging import logging
import math
from typing import Optional from typing import Optional
import torch import torch
@@ -72,6 +73,16 @@ class GeluAndMul(CustomOp):
return out return out
class NewGELU(CustomOp):
def forward_native(self, x: torch.Tensor) -> torch.Tensor:
c = math.sqrt(2.0 / math.pi)
return 0.5 * x * (1.0 + torch.tanh(c * (x + 0.044715 * torch.pow(x, 3.0))))
def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
# TODO: Implement the CUDA kernel for NewGELU in sgl-kernel
return self.forward_native(x)
class QuickGELU(CustomOp): class QuickGELU(CustomOp):
def forward_native(self, x: torch.Tensor) -> torch.Tensor: def forward_native(self, x: torch.Tensor) -> torch.Tensor:
return x * torch.sigmoid(1.702 * x) return x * torch.sigmoid(1.702 * x)

View File

@@ -17,14 +17,14 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
"""Inference-only GPT-2 model compatible with HuggingFace weights.""" """Inference-only GPT-2 model compatible with HuggingFace weights."""
from typing import Iterable, Optional, Tuple from typing import Iterable, Optional, Tuple, Type
import torch import torch
from torch import nn from torch import nn
from transformers import GPT2Config from transformers import GPT2Config
from sglang.srt.distributed.parallel_state import get_tensor_model_parallel_world_size from sglang.srt.distributed.parallel_state import get_tensor_model_parallel_world_size
from sglang.srt.layers.activation import get_act_fn from sglang.srt.layers.activation import NewGELU
from sglang.srt.layers.linear import ( from sglang.srt.layers.linear import (
ColumnParallelLinear, ColumnParallelLinear,
QKVParallelLinear, QKVParallelLinear,
@@ -97,6 +97,7 @@ class GPT2MLP(nn.Module):
self, self,
intermediate_size: int, intermediate_size: int,
config: GPT2Config, config: GPT2Config,
act_layer: Type[nn.Module] = NewGELU,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
prefix: str = "", prefix: str = "",
): ):
@@ -116,9 +117,7 @@ class GPT2MLP(nn.Module):
quant_config=quant_config, quant_config=quant_config,
prefix=f"{prefix}.c_proj", prefix=f"{prefix}.c_proj",
) )
self.act = get_act_fn( self.act = act_layer()
config.activation_function, quant_config, intermediate_size
)
def forward( def forward(
self, self,
@@ -136,6 +135,7 @@ class GPT2Block(nn.Module):
self, self,
layer_id: int, layer_id: int,
config: GPT2Config, config: GPT2Config,
act_layer: Type[nn.Module] = NewGELU,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
prefix: str = "", prefix: str = "",
): ):
@@ -148,7 +148,13 @@ class GPT2Block(nn.Module):
layer_id, config, quant_config, prefix=f"{prefix}.attn" layer_id, config, quant_config, prefix=f"{prefix}.attn"
) )
self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
self.mlp = GPT2MLP(inner_dim, config, quant_config, prefix=f"{prefix}.mlp") self.mlp = GPT2MLP(
inner_dim,
config,
act_layer=act_layer,
quant_config=quant_config,
prefix=f"{prefix}.mlp",
)
def forward( def forward(
self, self,
@@ -190,7 +196,7 @@ class GPT2Model(nn.Module):
self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim) self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim)
self.h = nn.ModuleList( self.h = nn.ModuleList(
[ [
GPT2Block(i, config, quant_config) GPT2Block(i, config, quant_config=quant_config)
for i in range(config.num_hidden_layers) for i in range(config.num_hidden_layers)
] ]
) )