feat: use gelu_tanh_and_mul (#1193)

This commit is contained in:
Yineng Zhang
2024-08-24 18:58:16 +10:00
committed by GitHub
parent a5b14ad043
commit c9064e6fd9
3 changed files with 74 additions and 3 deletions


@@ -15,7 +15,7 @@ limitations under the License.
 import torch
 import torch.nn.functional as F
-from flashinfer.activation import silu_and_mul
+from flashinfer.activation import gelu_tanh_and_mul, silu_and_mul
 from vllm.model_executor.custom_op import CustomOp
@@ -37,3 +37,19 @@ class SiluAndMul(CustomOp):
         out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
         silu_and_mul(x, out)
         return out
+
+
+class GeluAndMul(CustomOp):
+    def __init__(self, **kwargs):
+        super().__init__()
+
+    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
+        d = x.shape[-1] // 2
+        return F.gelu(x[..., :d], approximate="tanh") * x[..., d:]
+
+    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
+        d = x.shape[-1] // 2
+        output_shape = x.shape[:-1] + (d,)
+        out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
+        gelu_tanh_and_mul(x, out)
+        return out
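
Below is a minimal pure-PyTorch sketch of what the fused flashinfer kernel in forward_cuda computes; gelu_tanh_and_mul_reference is a name introduced here only for illustration, and it mirrors forward_native above rather than calling the CUDA kernel, so it runs without flashinfer or a GPU.

import torch
import torch.nn.functional as F

# For an input whose last dimension is 2*d, the gated activation computes
#   out = gelu_tanh(x[..., :d]) * x[..., d:]
# i.e. tanh-approximated GELU on the first half, gated by the second half.
def gelu_tanh_and_mul_reference(x: torch.Tensor) -> torch.Tensor:
    d = x.shape[-1] // 2
    return F.gelu(x[..., :d], approximate="tanh") * x[..., d:]

x = torch.randn(2, 8, 2 * 64)            # last dim is 2*d: [gate half | up half]
out = gelu_tanh_and_mul_reference(x)
assert out.shape == (2, 8, 64)           # output keeps only d features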