NPU fused op (#7386)

Co-authored-by: Li Junwen <lijunwen13@hisilicon.com>
This commit is contained in:
ll819214
2025-06-25 16:54:20 +08:00
committed by GitHub
parent a07f8ae4b7
commit 506a2d5934
4 changed files with 70 additions and 2 deletions

View File

@@ -48,6 +48,9 @@ if _is_cuda:
logger = logging.getLogger(__name__)
if is_npu():
import torch_npu
class SiluAndMul(CustomOp):
def forward_native(self, x: torch.Tensor) -> torch.Tensor:
@@ -70,6 +73,10 @@ class SiluAndMul(CustomOp):
else:
return self.forward_native(x)
def forward_npu(self, x: torch.Tensor) -> torch.Tensor:
    """Ascend NPU path: delegate to the fused swiglu kernel.

    Replaces the split activate-then-multiply of the native path with a
    single ``torch_npu.npu_swiglu`` call; presumably numerically equivalent
    to ``forward_native`` — confirm against torch_npu docs.
    """
    return torch_npu.npu_swiglu(x)
class GeluAndMul(CustomOp):
def __init__(self, approximate="tanh"):