npu fused op (#7386)

Co-authored-by: Li Junwen <lijunwen13@hisilicon.com>
This commit is contained in:
ll819214
2025-06-25 16:54:20 +08:00
committed by GitHub
parent a07f8ae4b7
commit 506a2d5934
4 changed files with 70 additions and 2 deletions

View File

@@ -52,6 +52,9 @@ elif _is_hip:
logger = logging.getLogger(__name__)
if is_npu():
import torch_npu
class RMSNorm(CustomOp):
def __init__(
@@ -76,6 +79,18 @@ class RMSNorm(CustomOp):
out = rmsnorm(x, self.weight.data, self.variance_epsilon)
return out
def forward_npu(
    self,
    x: torch.Tensor,
    residual: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
    """Apply RMS normalization through the ``torch_npu`` kernels.

    Args:
        x: Tensor to normalize.
        residual: Optional tensor; when given, the combined
            ``npu_add_rms_norm`` kernel is used instead of the plain
            ``npu_rms_norm`` one.

    Returns:
        Without ``residual``: the first element of the ``npu_rms_norm``
        result. With ``residual``: a ``(out, residual_out)`` pair taken
        from the first and third elements of the ``npu_add_rms_norm``
        result.
    """
    gamma = self.weight.data
    eps = self.variance_epsilon

    # Plain path: no residual to fold in, keep only the first output.
    if residual is None:
        return torch_npu.npu_rms_norm(x, gamma, eps)[0]

    # Combined path: the kernel yields three values; the middle one is
    # not needed by callers and is discarded.
    normed, _unused, summed = torch_npu.npu_add_rms_norm(residual, x, gamma, eps)
    return normed, summed
def forward_aiter(
self,
x: torch.Tensor,