feat: replace GeluAndMul (#1234)
This commit is contained in:
@@ -18,7 +18,7 @@ from typing import Optional
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from flashinfer.activation import gelu_tanh_and_mul, silu_and_mul
|
||||
from flashinfer.activation import gelu_and_mul, gelu_tanh_and_mul, silu_and_mul
|
||||
from vllm.distributed import (
|
||||
divide,
|
||||
get_tensor_model_parallel_rank,
|
||||
@@ -43,18 +43,24 @@ class SiluAndMul(CustomOp):
|
||||
|
||||
|
||||
class GeluAndMul(CustomOp):
|
||||
def __init__(self, **kwargs):
|
||||
def __init__(self, approximate="tanh"):
|
||||
super().__init__()
|
||||
self.approximate = approximate
|
||||
|
||||
def forward_native(self, x: torch.Tensor) -> torch.Tensor:
|
||||
d = x.shape[-1] // 2
|
||||
return F.gelu(x[..., :d], approximate="tanh") * x[..., d:]
|
||||
return F.gelu(x[..., :d], approximate=self.approximate) * x[..., d:]
|
||||
|
||||
def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
|
||||
d = x.shape[-1] // 2
|
||||
output_shape = x.shape[:-1] + (d,)
|
||||
out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
|
||||
gelu_tanh_and_mul(x, out)
|
||||
if self.approximate == "tanh":
|
||||
gelu_tanh_and_mul(x, out)
|
||||
elif self.approximate == "none":
|
||||
gelu_and_mul(x, out)
|
||||
else:
|
||||
raise RuntimeError("GeluAndMul only support tanh or none")
|
||||
return out
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user