[CustomOp] Register RMSNorm instead of overwrite forward_oot (#2284)

### What this PR does / why we need it?
Use function CustomOp.register_oot to achieve the customop registery
```
from vllm.model_executor.custom_op import CustomOp
CustomOp.register_oot(_decorated_op_cls=AscendRMSNorm, name="RMSNorm")
```

### Does this PR introduce _any_ user-facing change?
N/A

### How was this patch tested?
CI passed with new added/existing test.

- vLLM version: v0.10.0
- vLLM main:
afa5b7ca0b

---------

Signed-off-by: Icey <1790571317@qq.com>
This commit is contained in:
Icey
2025-08-14 17:18:30 +08:00
committed by GitHub
parent e14f2ef669
commit c721ae6042
4 changed files with 85 additions and 28 deletions

View File

@@ -20,8 +20,6 @@ from typing import Optional, Tuple, Union
import torch
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm_ascend.utils import is_310p
class AddRMSNormW8A8Quant(RMSNorm):
# Fuse AddRmsNorm and W8A8 quantization ops together
@@ -60,27 +58,28 @@ class AddRMSNormW8A8Quant(RMSNorm):
return x
def forward_oot(
self,
x: torch.Tensor,
residual: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
import torch_npu
class AscendRMSNorm(RMSNorm):
if residual is not None:
if is_310p():
orig_dtype = residual.dtype
x = x + residual.to(x.dtype)
residual = x.to(orig_dtype)
x, _ = torch_npu.npu_rms_norm(x, self.weight,
self.variance_epsilon)
else:
x, _, residual = torch_npu.npu_add_rms_norm(
x, residual, self.weight, self.variance_epsilon)
return x, residual
def forward_oot(
self,
x: torch.Tensor,
residual: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
import torch_npu
x, residual = torch_npu.npu_rms_norm(x, self.weight, self.variance_epsilon)
return x
from vllm_ascend.utils import is_310p
if residual is not None:
if is_310p():
orig_dtype = residual.dtype
x = x + residual.to(x.dtype)
residual = x.to(orig_dtype)
x, _ = torch_npu.npu_rms_norm(x, self.weight,
self.variance_epsilon)
else:
x, _, residual = torch_npu.npu_add_rms_norm(
x, residual, self.weight, self.variance_epsilon)
return x, residual
RMSNorm.forward_oot = forward_oot
x, residual = torch_npu.npu_rms_norm(x, self.weight,
self.variance_epsilon)
return x