feat: replace GeluAndMul (#1234)

Yineng Zhang, committed via GitHub
2024-08-29 00:07:02 +10:00
parent bf53bf5142
commit c411f32e1c
3 changed files with 13 additions and 7 deletions


@@ -18,7 +18,7 @@ from typing import Optional
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from flashinfer.activation import gelu_tanh_and_mul, silu_and_mul
+from flashinfer.activation import gelu_and_mul, gelu_tanh_and_mul, silu_and_mul
 from vllm.distributed import (
     divide,
     get_tensor_model_parallel_rank,
@@ -43,18 +43,24 @@ class SiluAndMul(CustomOp):
 class GeluAndMul(CustomOp):
-    def __init__(self, **kwargs):
+    def __init__(self, approximate="tanh"):
         super().__init__()
+        self.approximate = approximate

     def forward_native(self, x: torch.Tensor) -> torch.Tensor:
         d = x.shape[-1] // 2
-        return F.gelu(x[..., :d], approximate="tanh") * x[..., d:]
+        return F.gelu(x[..., :d], approximate=self.approximate) * x[..., d:]

     def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
         d = x.shape[-1] // 2
         output_shape = x.shape[:-1] + (d,)
         out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
-        gelu_tanh_and_mul(x, out)
+        if self.approximate == "tanh":
+            gelu_tanh_and_mul(x, out)
+        elif self.approximate == "none":
+            gelu_and_mul(x, out)
+        else:
+            raise RuntimeError("GeluAndMul only supports tanh or none")
         return out
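
For readers skimming the diff, here is a minimal standalone sketch of what the changed class now does on the native path. It re-implements only the forward_native logic shown above in plain PyTorch (no flashinfer kernels, no CustomOp base class), so the class name GeluAndMulSketch and the test tensor are illustrative, not the project's actual module:

import torch
import torch.nn.functional as F

class GeluAndMulSketch(torch.nn.Module):
    # Split the last dimension in half, apply GELU to the first half,
    # and gate it by elementwise multiplication with the second half.
    def __init__(self, approximate: str = "tanh"):
        super().__init__()
        if approximate not in ("tanh", "none"):
            raise RuntimeError("GeluAndMul only supports tanh or none")
        self.approximate = approximate

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        d = x.shape[-1] // 2
        return F.gelu(x[..., :d], approximate=self.approximate) * x[..., d:]

x = torch.randn(2, 8)
out_tanh = GeluAndMulSketch(approximate="tanh")(x)   # previous hard-coded behavior
out_exact = GeluAndMulSketch(approximate="none")(x)  # erf-based GELU enabled by this commit
print(out_tanh.shape, out_exact.shape)  # both torch.Size([2, 4]): output is half the input width

The point of the change: before this commit the tanh approximation was baked into both paths; after it, callers can pass approximate="none" to get the exact (erf-based) GELU, and the CUDA path dispatches to flashinfer's gelu_and_mul instead of gelu_tanh_and_mul accordingly.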