fix: support gelu_new activation function in gpt2 (#3712)
This commit is contained in:
@@ -14,6 +14,7 @@
|
|||||||
"""Fused operators for activation layers."""
|
"""Fused operators for activation layers."""
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
import math
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
@@ -72,6 +73,16 @@ class GeluAndMul(CustomOp):
|
|||||||
return out
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
class NewGELU(CustomOp):
|
||||||
|
def forward_native(self, x: torch.Tensor) -> torch.Tensor:
|
||||||
|
c = math.sqrt(2.0 / math.pi)
|
||||||
|
return 0.5 * x * (1.0 + torch.tanh(c * (x + 0.044715 * torch.pow(x, 3.0))))
|
||||||
|
|
||||||
|
def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
|
||||||
|
# TODO: Implement the CUDA kernel for NewGELU in sgl-kernel
|
||||||
|
return self.forward_native(x)
|
||||||
|
|
||||||
|
|
||||||
class QuickGELU(CustomOp):
|
class QuickGELU(CustomOp):
|
||||||
def forward_native(self, x: torch.Tensor) -> torch.Tensor:
|
def forward_native(self, x: torch.Tensor) -> torch.Tensor:
|
||||||
return x * torch.sigmoid(1.702 * x)
|
return x * torch.sigmoid(1.702 * x)
|
||||||
|
|||||||
@@ -17,14 +17,14 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
"""Inference-only GPT-2 model compatible with HuggingFace weights."""
|
"""Inference-only GPT-2 model compatible with HuggingFace weights."""
|
||||||
from typing import Iterable, Optional, Tuple
|
from typing import Iterable, Optional, Tuple, Type
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torch import nn
|
from torch import nn
|
||||||
from transformers import GPT2Config
|
from transformers import GPT2Config
|
||||||
|
|
||||||
from sglang.srt.distributed.parallel_state import get_tensor_model_parallel_world_size
|
from sglang.srt.distributed.parallel_state import get_tensor_model_parallel_world_size
|
||||||
from sglang.srt.layers.activation import get_act_fn
|
from sglang.srt.layers.activation import NewGELU
|
||||||
from sglang.srt.layers.linear import (
|
from sglang.srt.layers.linear import (
|
||||||
ColumnParallelLinear,
|
ColumnParallelLinear,
|
||||||
QKVParallelLinear,
|
QKVParallelLinear,
|
||||||
@@ -97,6 +97,7 @@ class GPT2MLP(nn.Module):
|
|||||||
self,
|
self,
|
||||||
intermediate_size: int,
|
intermediate_size: int,
|
||||||
config: GPT2Config,
|
config: GPT2Config,
|
||||||
|
act_layer: Type[nn.Module] = NewGELU,
|
||||||
quant_config: Optional[QuantizationConfig] = None,
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
prefix: str = "",
|
prefix: str = "",
|
||||||
):
|
):
|
||||||
@@ -116,9 +117,7 @@ class GPT2MLP(nn.Module):
|
|||||||
quant_config=quant_config,
|
quant_config=quant_config,
|
||||||
prefix=f"{prefix}.c_proj",
|
prefix=f"{prefix}.c_proj",
|
||||||
)
|
)
|
||||||
self.act = get_act_fn(
|
self.act = act_layer()
|
||||||
config.activation_function, quant_config, intermediate_size
|
|
||||||
)
|
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
@@ -136,6 +135,7 @@ class GPT2Block(nn.Module):
|
|||||||
self,
|
self,
|
||||||
layer_id: int,
|
layer_id: int,
|
||||||
config: GPT2Config,
|
config: GPT2Config,
|
||||||
|
act_layer: Type[nn.Module] = NewGELU,
|
||||||
quant_config: Optional[QuantizationConfig] = None,
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
prefix: str = "",
|
prefix: str = "",
|
||||||
):
|
):
|
||||||
@@ -148,7 +148,13 @@ class GPT2Block(nn.Module):
|
|||||||
layer_id, config, quant_config, prefix=f"{prefix}.attn"
|
layer_id, config, quant_config, prefix=f"{prefix}.attn"
|
||||||
)
|
)
|
||||||
self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
|
self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
|
||||||
self.mlp = GPT2MLP(inner_dim, config, quant_config, prefix=f"{prefix}.mlp")
|
self.mlp = GPT2MLP(
|
||||||
|
inner_dim,
|
||||||
|
config,
|
||||||
|
act_layer=act_layer,
|
||||||
|
quant_config=quant_config,
|
||||||
|
prefix=f"{prefix}.mlp",
|
||||||
|
)
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
@@ -190,7 +196,7 @@ class GPT2Model(nn.Module):
|
|||||||
self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim)
|
self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim)
|
||||||
self.h = nn.ModuleList(
|
self.h = nn.ModuleList(
|
||||||
[
|
[
|
||||||
GPT2Block(i, config, quant_config)
|
GPT2Block(i, config, quant_config=quant_config)
|
||||||
for i in range(config.num_hidden_layers)
|
for i in range(config.num_hidden_layers)
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|||||||
Reference in New Issue
Block a user