[Feat.][310P] addrmsnorm for 300I DUO (#6704)

### What this PR does / why we need it?
This PR integrates the `npu_add_rms_norm` fused kernel for RMSNorm
operations with residual connections on 310P devices. This change
optimizes the computation by replacing a two-step process (manual
residual addition followed by RMSNorm) with a single, more efficient
fused operation. This is needed to improve the performance of models
utilizing RMSNorm with residual connections on the 310P architecture.

Fixes #

### Does this PR introduce _any_ user-facing change?
No, this PR introduces an internal optimization and does not change any
user-facing APIs or behaviors.

### How was this patch tested?
This patch was tested with updated unit tests
(`test_RMSNorm_forward_310p`) that mock the `npu_add_rms_norm` operation
to verify the correctness of the fused kernel integration.

---------

Signed-off-by: Tflowers-0129 <2906339855@qq.com>
This commit is contained in:
Shaoxu Cheng
2026-02-13 15:40:49 +08:00
committed by GitHub
parent 7164990904
commit f40256b697
4 changed files with 12 additions and 80 deletions

View File

@@ -11,13 +11,9 @@ class AscendRMSNorm310(AscendRMSNorm):
residual: torch.Tensor | None = None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
if residual is not None:
if x is None or x.numel() == 0 or x.shape[-1] == 0:
x = residual
else:
x = x + residual
residual = x
x, _ = torch_npu.npu_rms_norm(x, self.weight, self.variance_epsilon)
x, _, residual = torch_npu.npu_add_rms_norm(x, residual, self.weight, self.variance_epsilon)
if self.bias is not None:
x.add_(self.bias)
return x, residual
x, _ = torch_npu.npu_rms_norm(x, self.weight, self.variance_epsilon)