From 23db56a34063c5285f7d34e80e45c1888b877bd3 Mon Sep 17 00:00:00 2001 From: huangdong2022 <161736910+huangdong2022@users.noreply.github.com> Date: Thu, 9 Oct 2025 20:18:10 +0800 Subject: [PATCH] [Feat]Qwen3 Moe supports npu_add_rms_norm_quant op by default, update op with norm bias (#3205) ### What this PR does / why we need it? 1. qwen3 moe uses the add_rms_norm_quant op instead of separate 'add_rms_norm op and quant op' in the quantization scenario. 2. the torch_npu.add_rms_norm_quant op fixes accuracy when model weights are quantized by anti_method m4; m4 quantization is an asymmetric outlier suppression method that generates a non-zero norm bias, so the add_rms_norm_quant op is updated to take this parameter into account in its calculation. ### Does this PR introduce _any_ user-facing change? Please use a torch_npu version >= torch_npu-2.7.1.dev20250919 ### How was this patch tested? 1. No special parameters to set, no new envs to set. 2. Use a qwen3 moe quantization model to test, such as Qwen3-235B-A22B-W8A8, Qwen3-30B-A3B-W8A8, Qwen3-235B-A22B-Instruct-2507-m4 (anti_method m4) - vLLM version: v0.11.0rc3 - vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.0 --------- Signed-off-by: huangdong2022 Signed-off-by: h30027576 --- tests/ut/ops/test_layernorm.py | 24 ++++++++--- vllm_ascend/ascend_forward_context.py | 4 +- vllm_ascend/ops/layernorm.py | 61 +++++++++++++++------------ vllm_ascend/utils.py | 8 +--- 4 files changed, 57 insertions(+), 40 deletions(-) diff --git a/tests/ut/ops/test_layernorm.py b/tests/ut/ops/test_layernorm.py index b0c05a2..dd99088 100644 --- a/tests/ut/ops/test_layernorm.py +++ b/tests/ut/ops/test_layernorm.py @@ -24,7 +24,7 @@ def mock_add_rms_norm(x, residual, weight, eps): def mock_add_rms_norm_quant(x, residual, weight, quant_scale, quant_offset, - epsilon): + beta, epsilon): x_out = 2 * x residual_out = 2 * residual x_out_quant = x_out.to(torch.int8) @@ -94,7 +94,7 @@ class TestAscendRMSNorm(PytestBase): mock_model_instance = mocker.MagicMock() 
mock_forward_context.model_instance = mock_model_instance mock_model_instance.model.layers = [ - mocker.MagicMock() for _ in range(2) + mocker.MagicMock() for _ in range(3) ] mock_layer_0 = mock_model_instance.model.layers[0] @@ -124,7 +124,7 @@ class TestAscendRMSNorm(PytestBase): mock_forward_context.addrmsnorm_quant_fusion_enabled = True mock_forward_context.prefetch_mlp_enabled = False mock_forward_context.layer_idx = 0 - mock_forward_context.num_hidden_layers = 2 + mock_forward_context.num_hidden_layers = 3 mock_forward_context.fusion_linear = "gate_up_dense" # Ensure fusion and layer_idx increment are handled correctly @@ -144,18 +144,32 @@ class TestAscendRMSNorm(PytestBase): assert mock_forward_context.fusion_linear == "gate_up_dense" assert mock_forward_context.layer_idx == 1 + mock_forward_context.fusion_linear = "gate_moe" x_out, residual_out = layer.forward_oot(x, residual) assert mock_get_forward_context.call_count == 3 - assert mock_forward_context.fusion_linear == "qkv_dense" + assert mock_forward_context.fusion_linear == "qkv_moe" assert mock_forward_context.layer_idx == 2 x_out, residual_out = layer.forward_oot(x, residual) assert mock_get_forward_context.call_count == 4 - assert mock_forward_context.fusion_linear == "qkv_dense" + assert mock_forward_context.fusion_linear == "gate_moe" assert mock_forward_context.layer_idx == 2 + # last layer returned directly + x_out, residual_out = layer.forward_oot(x, residual) + + assert mock_get_forward_context.call_count == 5 + assert mock_forward_context.fusion_linear == "qkv_moe" + assert mock_forward_context.layer_idx == 3 + + x_out, residual_out = layer.forward_oot(x, residual) + + assert mock_get_forward_context.call_count == 6 + assert mock_forward_context.fusion_linear == "qkv_moe" + assert mock_forward_context.layer_idx == 3 + if __name__ == '__main__': unittest.main() diff --git a/vllm_ascend/ascend_forward_context.py b/vllm_ascend/ascend_forward_context.py index 607f029..ad61245 100644 --- 
a/vllm_ascend/ascend_forward_context.py +++ b/vllm_ascend/ascend_forward_context.py @@ -147,12 +147,14 @@ def set_ascend_forward_context( # Once the necessary conditions are met, support for MOE models will also be added. from vllm_ascend.quantization.quant_config import AscendQuantConfig addrmsnorm_quant_fusion_enabled = isinstance(vllm_config.quant_config, AscendQuantConfig) and \ - vllm_config.model_config.hf_config.model_type in ["llama", "qwen2", "qwen3"] and \ + vllm_config.model_config.hf_config.model_type in ["llama", "qwen2", "qwen3", "qwen3_moe"] and \ forward_context.layer_idx is not None if addrmsnorm_quant_fusion_enabled: forward_context.model_instance = model_instance forward_context.num_hidden_layers = vllm_config.model_config.hf_config.num_hidden_layers forward_context.fusion_linear = "gate_up_dense" if forward_context.layer_idx == 0 else "qkv_dense" + if vllm_config.model_config.hf_config.model_type == "qwen3_moe": + forward_context.fusion_linear = "gate_moe" if forward_context.layer_idx == 0 else "qkv_moe" forward_context.addrmsnorm_quant_fusion_enabled = addrmsnorm_quant_fusion_enabled if num_tokens is None and attn_metadata is not None: diff --git a/vllm_ascend/ops/layernorm.py b/vllm_ascend/ops/layernorm.py index 3dfca53..344a8dc 100644 --- a/vllm_ascend/ops/layernorm.py +++ b/vllm_ascend/ops/layernorm.py @@ -15,9 +15,10 @@ # This file is a part of the vllm-ascend project. 
# -from typing import Optional, Tuple, Union, cast +from typing import Optional, Tuple, Union import torch +from vllm.config import get_current_vllm_config from vllm.forward_context import get_forward_context from vllm.model_executor.layers.layernorm import GemmaRMSNorm, RMSNorm @@ -27,6 +28,7 @@ def _addrmsnorm_forward_oot( x: torch.Tensor, residual: torch.Tensor, layer: Optional[torch.nn.Module] = None, + bias: Optional[torch.nn.Parameter] = None, ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: import torch_npu @@ -39,6 +41,7 @@ def _addrmsnorm_forward_oot( self.weight, layer.aclnn_input_scale, layer.aclnn_input_offset, + beta=bias, epsilon=self.variance_epsilon) else: if is_310p(): @@ -50,12 +53,31 @@ def _addrmsnorm_forward_oot( else: x, _, residual = torch_npu.npu_add_rms_norm( x, residual, self.weight, self.variance_epsilon) + if bias is not None: + x.add_(bias) torch.ops.vllm.maybe_wait_prefetch_done(x) return x, residual class AscendRMSNorm(RMSNorm): + def __init__( + self, + hidden_size: int, + eps: float = 1e-6, + var_hidden_size: Optional[int] = None, + has_weight: bool = True, + dtype: Optional[torch.dtype] = None, + ) -> None: + super().__init__(hidden_size, eps, var_hidden_size, has_weight, dtype) + vllm_config = get_current_vllm_config() + self.bias = None + # quantization with anti_method m4 will generate none-zero norm bias + if vllm_config is not None and vllm_config.quant_config is not None and \ + any("norm.bias" in name for name in vllm_config.quant_config.quant_description.keys()): + self.bias = torch.nn.Parameter(torch.zeros(hidden_size), + requires_grad=False) + def forward_oot( self, x: torch.Tensor, @@ -67,10 +89,13 @@ class AscendRMSNorm(RMSNorm): residual = torch.ops.vllm.maybe_chunk_residual(x, residual) assert x.size(0) == residual.size(0) x, residual = _addrmsnorm_forward_oot( - self, x, residual, self.next_need_quant_fusion_linear) + self, x, residual, self.next_need_quant_fusion_linear, + self.bias) return x, residual 
x, residual = torch_npu.npu_rms_norm(x, self.weight, self.variance_epsilon) + if self.bias is not None: + x.add_(self.bias) return x @property @@ -100,6 +125,13 @@ class AscendRMSNorm(RMSNorm): # does not need to be repeated if not forward_context.prefetch_mlp_enabled: forward_context.layer_idx += 1 + elif fusion_linear == "qkv_moe": + next_linear = model_instance.model.layers[ + layer_idx].self_attn.qkv_proj + forward_context.fusion_linear = "gate_moe" + elif fusion_linear == "gate_moe": + forward_context.fusion_linear = "qkv_moe" + forward_context.layer_idx += 1 from vllm_ascend.quantization.w8a8 import AscendW8A8LinearMethod if next_linear is not None and \ not isinstance(next_linear.quant_method.quant_method, AscendW8A8LinearMethod): @@ -107,31 +139,6 @@ class AscendRMSNorm(RMSNorm): return next_linear -class AscendQuantRMSNorm(AscendRMSNorm): - - def __init__( - self, - hidden_size: int, - eps: float = 1e-6, - var_hidden_size: Optional[int] = None, - has_weight: bool = True, - dtype: Optional[torch.dtype] = None, - ) -> None: - super().__init__(hidden_size, eps, var_hidden_size, has_weight, dtype) - self.bias = torch.nn.Parameter(torch.zeros(hidden_size), - requires_grad=False) - - def forward_oot( - self, - x: torch.Tensor, - residual: Optional[torch.Tensor] = None, - ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: - if residual is not None: - x, residual = super().forward_oot(x, residual) - return x.add_(self.bias), residual - return cast(torch.Tensor, super().forward_oot(x)).add_(self.bias) - - class AscendGemmaRMSNorm(GemmaRMSNorm): def forward_oot( diff --git a/vllm_ascend/utils.py b/vllm_ascend/utils.py index 805fd57..6157914 100644 --- a/vllm_ascend/utils.py +++ b/vllm_ascend/utils.py @@ -505,8 +505,7 @@ def register_ascend_customop(vllm_config: Optional[VllmConfig] = None): from vllm_ascend.ops.activation import AscendQuickGELU, AscendSiluAndMul from vllm_ascend.ops.common_fused_moe import (AscendFusedMoE, AscendSharedFusedMoE) - from 
vllm_ascend.ops.layernorm import (AscendGemmaRMSNorm, - AscendQuantRMSNorm, AscendRMSNorm) + from vllm_ascend.ops.layernorm import AscendGemmaRMSNorm, AscendRMSNorm from vllm_ascend.ops.linear import (AscendColumnParallelLinear, AscendMergedColumnParallelLinear, AscendQKVParallelLinear, @@ -537,11 +536,6 @@ def register_ascend_customop(vllm_config: Optional[VllmConfig] = None): "MultiHeadLatentAttention": AscendMultiHeadLatentAttention, } - if vllm_config is not None and \ - vllm_config.quant_config is not None and \ - any("norm.bias" in name for name in vllm_config.quant_config.quant_description.keys()): - REGISTERED_ASCEND_OPS["RMSNorm"] = AscendQuantRMSNorm - for name, op_cls in REGISTERED_ASCEND_OPS.items(): CustomOp.register_oot(_decorated_op_cls=op_cls, name=name)