[Quantization] register AscendQuantRMSNorm for quantization (#2856)

### What this PR does / why we need it?

modelslim generates a `self.bias` parameter for RMS norm during quantization. Since
RMSNorm in vLLM does not have this parameter, it is necessary
to create an AscendQuantRMSNorm.
### Does this PR introduce _any_ user-facing change?
No

### How was this patch tested?

Tested with deepseek-v3.1-w8a8.

<img width="2496" height="592" alt="image"
src="https://github.com/user-attachments/assets/004c6e76-3d7a-4a1f-b59f-a14304012663"
/>


- vLLM version: main
- vLLM main:
d6249d0699

Signed-off-by: 22dimensions <waitingwind@foxmail.com>
This commit is contained in:
22dimensions
2025-09-11 23:14:02 +08:00
committed by GitHub
parent eab3635850
commit f5a97e8fa5
4 changed files with 35 additions and 7 deletions

View File

@@ -15,7 +15,7 @@
# This file is a part of the vllm-ascend project.
#
from typing import Optional, Tuple, Union
from typing import Optional, Tuple, Union, cast
import torch
from vllm.model_executor.layers.layernorm import RMSNorm
@@ -89,3 +89,28 @@ class AscendRMSNorm(RMSNorm):
x, residual = torch_npu.npu_rms_norm(x, self.weight,
self.variance_epsilon)
return x
class AscendQuantRMSNorm(AscendRMSNorm):
    """RMSNorm variant carrying an extra per-channel bias.

    Quantization tooling (modelslim) emits a ``bias`` tensor for RMS norm
    layers; vLLM's stock ``RMSNorm`` has no such parameter, so this subclass
    registers one and adds it to the normalized output.
    """

    def __init__(
        self,
        hidden_size: int,
        eps: float = 1e-6,
        var_hidden_size: Optional[int] = None,
        has_weight: bool = True,
        dtype: Optional[torch.dtype] = None,
    ) -> None:
        super().__init__(hidden_size, eps, var_hidden_size, has_weight, dtype)
        # Frozen placeholder; real values are loaded from the quantized
        # checkpoint by the weight loader.
        self.bias = torch.nn.Parameter(torch.zeros(hidden_size),
                                       requires_grad=False)

    def forward_oot(
        self,
        x: torch.Tensor,
        residual: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """Apply the parent RMS norm, then add the quantization bias in place."""
        if residual is None:
            normed = cast(torch.Tensor, super().forward_oot(x))
            return normed.add_(self.bias)
        x, residual = super().forward_oot(x, residual)
        return x.add_(self.bias), residual

View File

@@ -9,8 +9,6 @@ from .w8a8 import (AscendC8KVCacheMethod, AscendW8A8FusedMoEMethod,
from .w8a8_dynamic import (AscendW8A8DynamicFusedMoEMethod,
AscendW8A8DynamicLinearMethod)
patched = False
ASCEND_QUANTIZATION_METHOD_MAP: Dict[str, Dict[str, Type[Any]]] = {
"W4A8_DYNAMIC": {
"linear": AscendW4A8DynamicLinearMethod,

View File

@@ -24,7 +24,7 @@ import os
from contextlib import contextmanager
from enum import Enum
from threading import Lock
from typing import TYPE_CHECKING, List, Tuple
from typing import TYPE_CHECKING, List, Optional, Tuple
import torch
import torch_npu # noqa: F401 # noqa: F401
@@ -483,7 +483,7 @@ def get_all_reduce_merge_state(ep_size: int, is_deepseek_v3_r1: bool):
return False
def register_ascend_customop():
def register_ascend_customop(vllm_config: Optional[VllmConfig] = None):
"""Register Ascend CustomOP
NOTE: if the register branch requires model type, please use `vllm.config.get_current_vllm_config`,
@@ -497,7 +497,7 @@ def register_ascend_customop():
from vllm_ascend.models.layers.mla import AscendMultiHeadLatentAttention
from vllm_ascend.ops.activation import AscendQuickGELU, AscendSiluAndMul
from vllm_ascend.ops.common_fused_moe import AscendFusedMoE
from vllm_ascend.ops.layernorm import AscendRMSNorm
from vllm_ascend.ops.layernorm import AscendQuantRMSNorm, AscendRMSNorm
from vllm_ascend.ops.linear import (AscendColumnParallelLinear,
AscendMergedColumnParallelLinear,
AscendQKVParallelLinear,
@@ -526,6 +526,11 @@ def register_ascend_customop():
"MultiHeadLatentAttention": AscendMultiHeadLatentAttention,
}
if vllm_config is not None and \
vllm_config.quant_config is not None and \
any("norm.bias" in name for name in vllm_config.quant_config.quant_description.keys()):
REGISTERED_ASCEND_OPS["RMSNorm"] = AscendQuantRMSNorm
for name, op_cls in REGISTERED_ASCEND_OPS.items():
CustomOp.register_oot(_decorated_op_cls=op_cls, name=name)

View File

@@ -83,7 +83,7 @@ class NPUWorker(WorkerBase):
from vllm_ascend import ops
ops.register_dummy_fusion_op()
_register_atb_extensions()
register_ascend_customop()
register_ascend_customop(vllm_config)
# init ascend config and soc version
init_ascend_config(vllm_config)
init_ascend_soc_version()