diff --git a/python/sglang/srt/layers/activation.py b/python/sglang/srt/layers/activation.py
index 836884dca..c767327a6 100644
--- a/python/sglang/srt/layers/activation.py
+++ b/python/sglang/srt/layers/activation.py
@@ -13,15 +13,17 @@ limitations under the License.
 
 import torch
 import torch.nn as nn
+import torch.nn.functional as F
 from flashinfer.activation import silu_and_mul
+from vllm.model_executor.custom_op import CustomOp
 
 
-class SiluAndMul(nn.Module):
+class SiluAndMul(CustomOp):
     def forward_native(self, x: torch.Tensor) -> torch.Tensor:
         d = x.shape[-1] // 2
         return F.silu(x[..., :d]) * x[..., d:]
 
-    def forward(self, x: torch.Tensor) -> torch.Tensor:
+    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
         d = x.shape[-1] // 2
         output_shape = x.shape[:-1] + (d,)
         out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
diff --git a/python/sglang/srt/layers/layernorm.py b/python/sglang/srt/layers/layernorm.py
index e29993a4c..2a55c25e5 100644
--- a/python/sglang/srt/layers/layernorm.py
+++ b/python/sglang/srt/layers/layernorm.py
@@ -18,9 +18,10 @@ from typing import Optional, Tuple, Union
 import torch
 import torch.nn as nn
 from flashinfer.norm import fused_add_rmsnorm, rmsnorm
+from vllm.model_executor.custom_op import CustomOp
 
 
-class RMSNorm(nn.Module):
+class RMSNorm(CustomOp):
     def __init__(
         self,
         hidden_size: int,
@@ -30,7 +31,7 @@ class RMSNorm(nn.Module):
         self.weight = nn.Parameter(torch.ones(hidden_size))
         self.variance_epsilon = eps
 
-    def forward(
+    def forward_cuda(
         self,
         x: torch.Tensor,
         residual: Optional[torch.Tensor] = None,