Files
xc-llm-ascend/vllm_ascend/_310p/ops/layernorm.py
Shaoxu Cheng 83bd77c983 [310p]: add rmsnorm gated fallback and unit test (#7424)
### What this PR does / why we need it?
RFC #7394
310P cannot use the fused `rmsnormgated` operator and must fall back to
the native implementation.

### Does this PR introduce _any_ user-facing change?
NO
### How was this patch tested?
Unit tests (ut).
- vLLM version: v0.17.0
- vLLM main:
4497431df6

---------

Signed-off-by: Tflowers-0129 <2906339855@qq.com>
2026-03-24 09:00:11 +08:00

52 lines
1.7 KiB
Python

import torch
import torch_npu
from vllm.model_executor.layers.layernorm import RMSNormGated
from vllm_ascend.ops.layernorm import AscendGemmaRMSNorm, AscendRMSNorm
class AscendRMSNorm310(AscendRMSNorm):
    """RMSNorm layer for Ascend 310P built directly on ``torch_npu`` kernels.

    With a residual, the ``npu_add_rms_norm`` kernel is used; without one,
    ``npu_rms_norm`` is used. In both cases an optional ``self.bias`` is
    added to the normalized output in place.
    """

    def forward_oot(
        self,
        x: torch.Tensor,
        residual: torch.Tensor | None = None,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        """Normalize ``x``, optionally together with a residual.

        Args:
            x: Input activations.
            residual: Optional residual tensor. When provided, it is passed to
                ``npu_add_rms_norm`` and replaced by that kernel's third output.

        Returns:
            The normalized tensor when ``residual`` is ``None``; otherwise the
            pair ``(normalized, updated_residual)``.
        """
        fused = residual is not None

        if fused:
            # The kernel yields three outputs: the middle one is unused here and
            # the last one becomes the residual handed back to the caller.
            normed, _, residual = torch_npu.npu_add_rms_norm(
                x, residual, self.weight, self.variance_epsilon
            )
        else:
            normed, _ = torch_npu.npu_rms_norm(x, self.weight, self.variance_epsilon)

        # Optional additive bias, applied in place on the kernel output.
        if self.bias is not None:
            normed.add_(self.bias)

        if fused:
            return normed, residual
        return normed
class AscendGemmaRMSNorm310(AscendGemmaRMSNorm):
    """Gemma-style RMSNorm layer for Ascend 310P.

    The kernel weight is ``1.0 + self.weight`` rather than ``self.weight``.
    When a residual is supplied, the add is done with plain tensor ops before
    calling ``npu_rms_norm``. NOTE(review): unlike the plain RMSNorm path,
    this does not use ``npu_add_rms_norm``; the reason is not visible here.
    """

    def forward_oot(
        self,
        x: torch.Tensor,
        residual: torch.Tensor | None = None,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        """Apply Gemma RMSNorm to ``x``, optionally after adding ``residual``.

        Args:
            x: Input activations.
            residual: Optional residual. It is cast to ``x``'s dtype for the
                add, and the sum is cast back to the residual's original dtype
                before being returned as the new residual.

        Returns:
            The normalized tensor when ``residual`` is ``None``; otherwise the
            pair ``(normalized, new_residual)``.
        """
        gain = 1.0 + self.weight

        if residual is None:
            normed, _ = torch_npu.npu_rms_norm(x, gain, self.variance_epsilon)
            return normed

        # Sum in the activation dtype, but return the residual in the dtype
        # the caller originally passed it in.
        residual_dtype = residual.dtype
        summed = x + residual.to(x.dtype)
        new_residual = summed.to(residual_dtype)

        normed, _ = torch_npu.npu_rms_norm(summed, gain, self.variance_epsilon)
        return normed, new_residual
class AscendRMSNormGated310(RMSNormGated):
    """Gated RMSNorm for Ascend 310P that always takes the native code path.

    310P cannot use the fused gated RMSNorm operator. ``forward_oot`` is
    therefore routed to the inherited ``forward_native`` implementation
    instead of a Triton-based gated layernorm.
    """

    def forward_oot(
        self,
        x: torch.Tensor,
        z: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Run gated RMSNorm through the parent class's native implementation.

        Args:
            x: Input activations.
            z: Optional second input (presumably the gate tensor), forwarded
                unchanged.

        Returns:
            The result of the parent ``forward_native(x, z)``.
        """
        native_forward = super().forward_native
        return native_forward(x, z)