NPU fused op (#7386)
Co-authored-by: Li Junwen <lijunwen13@hisilicon.com>
This commit is contained in:
@@ -48,6 +48,9 @@ if _is_cuda:
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
if is_npu():
|
||||
import torch_npu
|
||||
|
||||
|
||||
class SiluAndMul(CustomOp):
|
||||
def forward_native(self, x: torch.Tensor) -> torch.Tensor:
|
||||
@@ -70,6 +73,10 @@ class SiluAndMul(CustomOp):
|
||||
else:
|
||||
return self.forward_native(x)
|
||||
|
||||
def forward_npu(self, x: torch.Tensor) -> torch.Tensor:
    """NPU-accelerated SiLU-and-multiply using the fused torch_npu kernel.

    Delegates to ``torch_npu.npu_swiglu``, which fuses the SiLU gate and
    elementwise multiply into a single op on Ascend NPU hardware.
    Requires ``torch_npu`` (imported at module level when ``is_npu()``).
    """
    # Single fused kernel call; no intermediate tensor needed.
    return torch_npu.npu_swiglu(x)
|
||||
|
||||
|
||||
class GeluAndMul(CustomOp):
|
||||
def __init__(self, approximate="tanh"):
|
||||
|
||||
Reference in New Issue
Block a user