From f5a97e8fa5440df6735d1121f813cda7f1257367 Mon Sep 17 00:00:00 2001 From: 22dimensions Date: Thu, 11 Sep 2025 23:14:02 +0800 Subject: [PATCH] [Quantization] register AscendQuantRMSNorm for quantization (#2856) ### What this PR does / why we need it? modelslim will generate self.bias for rms norm in quantization; since RMSNorm in vllm does not have this parameter, it's necessary to create an AscendQuantRMSNorm. ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? tested by deepseek-v3.1-w8a8 image - vLLM version: main - vLLM main: https://github.com/vllm-project/vllm/commit/d6249d069965f88ff0042b638704f4cc66d52de4 Signed-off-by: 22dimensions --- vllm_ascend/ops/layernorm.py | 27 ++++++++++++++++++++++++++- vllm_ascend/quantization/utils.py | 2 -- vllm_ascend/utils.py | 11 ++++++++--- vllm_ascend/worker/worker_v1.py | 2 +- 4 files changed, 35 insertions(+), 7 deletions(-) diff --git a/vllm_ascend/ops/layernorm.py b/vllm_ascend/ops/layernorm.py index d97d771..ccd031c 100644 --- a/vllm_ascend/ops/layernorm.py +++ b/vllm_ascend/ops/layernorm.py @@ -15,7 +15,7 @@ # This file is a part of the vllm-ascend project. 
# -from typing import Optional, Tuple, Union +from typing import Optional, Tuple, Union, cast import torch from vllm.model_executor.layers.layernorm import RMSNorm @@ -89,3 +89,28 @@ class AscendRMSNorm(RMSNorm): x, residual = torch_npu.npu_rms_norm(x, self.weight, self.variance_epsilon) return x + + +class AscendQuantRMSNorm(AscendRMSNorm): + + def __init__( + self, + hidden_size: int, + eps: float = 1e-6, + var_hidden_size: Optional[int] = None, + has_weight: bool = True, + dtype: Optional[torch.dtype] = None, + ) -> None: + super().__init__(hidden_size, eps, var_hidden_size, has_weight, dtype) + self.bias = torch.nn.Parameter(torch.zeros(hidden_size), + requires_grad=False) + + def forward_oot( + self, + x: torch.Tensor, + residual: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: + if residual is not None: + x, residual = super().forward_oot(x, residual) + return x.add_(self.bias), residual + return cast(torch.Tensor, super().forward_oot(x)).add_(self.bias) diff --git a/vllm_ascend/quantization/utils.py b/vllm_ascend/quantization/utils.py index f4cd0d0..dc5845a 100644 --- a/vllm_ascend/quantization/utils.py +++ b/vllm_ascend/quantization/utils.py @@ -9,8 +9,6 @@ from .w8a8 import (AscendC8KVCacheMethod, AscendW8A8FusedMoEMethod, from .w8a8_dynamic import (AscendW8A8DynamicFusedMoEMethod, AscendW8A8DynamicLinearMethod) -patched = False - ASCEND_QUANTIZATION_METHOD_MAP: Dict[str, Dict[str, Type[Any]]] = { "W4A8_DYNAMIC": { "linear": AscendW4A8DynamicLinearMethod, diff --git a/vllm_ascend/utils.py b/vllm_ascend/utils.py index cd014d0..ca51327 100644 --- a/vllm_ascend/utils.py +++ b/vllm_ascend/utils.py @@ -24,7 +24,7 @@ import os from contextlib import contextmanager from enum import Enum from threading import Lock -from typing import TYPE_CHECKING, List, Tuple +from typing import TYPE_CHECKING, List, Optional, Tuple import torch import torch_npu # noqa: F401 # noqa: F401 @@ -483,7 +483,7 @@ def 
get_all_reduce_merge_state(ep_size: int, is_deepseek_v3_r1: bool): return False -def register_ascend_customop(): +def register_ascend_customop(vllm_config: Optional[VllmConfig] = None): """Register Ascend CustomOP NOTE: if the register branch requires model type, please use `vllm.config.get_current_vllm_config`, @@ -497,7 +497,7 @@ def register_ascend_customop(): from vllm_ascend.models.layers.mla import AscendMultiHeadLatentAttention from vllm_ascend.ops.activation import AscendQuickGELU, AscendSiluAndMul from vllm_ascend.ops.common_fused_moe import AscendFusedMoE - from vllm_ascend.ops.layernorm import AscendRMSNorm + from vllm_ascend.ops.layernorm import AscendQuantRMSNorm, AscendRMSNorm from vllm_ascend.ops.linear import (AscendColumnParallelLinear, AscendMergedColumnParallelLinear, AscendQKVParallelLinear, @@ -526,6 +526,11 @@ def register_ascend_customop(): "MultiHeadLatentAttention": AscendMultiHeadLatentAttention, } + if vllm_config is not None and \ + vllm_config.quant_config is not None and \ + any("norm.bias" in name for name in vllm_config.quant_config.quant_description.keys()): + REGISTERED_ASCEND_OPS["RMSNorm"] = AscendQuantRMSNorm + for name, op_cls in REGISTERED_ASCEND_OPS.items(): CustomOp.register_oot(_decorated_op_cls=op_cls, name=name) diff --git a/vllm_ascend/worker/worker_v1.py b/vllm_ascend/worker/worker_v1.py index ef23645..8af3d31 100644 --- a/vllm_ascend/worker/worker_v1.py +++ b/vllm_ascend/worker/worker_v1.py @@ -83,7 +83,7 @@ class NPUWorker(WorkerBase): from vllm_ascend import ops ops.register_dummy_fusion_op() _register_atb_extensions() - register_ascend_customop() + register_ascend_customop(vllm_config) # init ascend config and soc version init_ascend_config(vllm_config) init_ascend_soc_version()