From 56269eae0e553909ab8dad60deb5c9235988a04e Mon Sep 17 00:00:00 2001
From: ZYang6263 <50876451+ZYang6263@users.noreply.github.com>
Date: Thu, 12 Feb 2026 09:26:05 +0800
Subject: [PATCH] [BugFix] Fix AddRMSNormQuant not taking effect (#6620)

### What this PR does / why we need it?

Fix the issue where, in graph mode, the fused `AddRMSNormQuant` operator
does not take effect when there is no bias.

### Does this PR introduce _any_ user-facing change?

### How was this patch tested?

- vLLM version: v0.15.0
- vLLM main:
  https://github.com/vllm-project/vllm/commit/d7e17aaacd5ed1b4b4be6bcfef3a1b7cbc84fc9a

---------

Signed-off-by: ZYang6263
---
 .../e2e/singlecard/compile/test_norm_quant_fusion.py | 12 ++++++------
 tests/e2e/singlecard/test_quantization.py            |  4 ++--
 .../compilation/passes/norm_quant_fusion_pass.py     | 12 ++++++++----
 3 files changed, 16 insertions(+), 12 deletions(-)

diff --git a/tests/e2e/singlecard/compile/test_norm_quant_fusion.py b/tests/e2e/singlecard/compile/test_norm_quant_fusion.py
index 99637868..b39f4f43 100644
--- a/tests/e2e/singlecard/compile/test_norm_quant_fusion.py
+++ b/tests/e2e/singlecard/compile/test_norm_quant_fusion.py
@@ -74,8 +74,8 @@ class TestModelWithoutBias(nn.Module):
         """
         residual = torch.zeros_like(x)
 
-        norm_output, _, new_residual = torch_npu.npu_add_rms_norm(
-            x, residual, self.rms_norm_weight, self.eps)
+        norm_output, _, new_residual = torch.ops._C_ascend.npu_add_rms_norm_bias(
+            x, residual, self.rms_norm_weight, None, self.eps)
 
         quantized_output = torch.ops.vllm.quantize(norm_output,
                                                    self.quant_scale,
@@ -87,7 +87,7 @@ class TestModelWithoutBias(nn.Module):
     def ops_in_model_before(self) -> List[OpOverload]:
         """Return the list of expected operators BEFORE fusion."""
         return [
-            torch.ops.npu.npu_add_rms_norm.default,
+            torch.ops._C_ascend.npu_add_rms_norm_bias.default,
             torch.ops.vllm.quantize.default
         ]
@@ -187,8 +187,8 @@ class TestModelSPWithoutBias(nn.Module):
         """
         residual = torch.zeros_like(x)
 
-        norm_output, _, new_residual = torch_npu.npu_add_rms_norm(
-            x, residual, self.rms_norm_weight, self.eps)
+        norm_output, _, new_residual = torch.ops._C_ascend.npu_add_rms_norm_bias(
+            x, residual, self.rms_norm_weight, None, self.eps)
 
         norm_output = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
             norm_output, True)
@@ -203,7 +203,7 @@ class TestModelSPWithoutBias(nn.Module):
     def ops_in_model_before(self) -> List[OpOverload]:
         """Return the list of expected operators BEFORE fusion."""
         return [
-            torch.ops.npu.npu_add_rms_norm.default,
+            torch.ops._C_ascend.npu_add_rms_norm_bias.default,
             torch.ops.vllm.maybe_all_gather_and_maybe_unpad.default,
             torch.ops.vllm.quantize.default
         ]
diff --git a/tests/e2e/singlecard/test_quantization.py b/tests/e2e/singlecard/test_quantization.py
index 4457a05f..119be0c2 100644
--- a/tests/e2e/singlecard/test_quantization.py
+++ b/tests/e2e/singlecard/test_quantization.py
@@ -26,8 +26,8 @@ def test_qwen3_w8a8_quant():
     ]
     vllm_target_outputs = [([
         85, 4086, 44, 374, 264, 1550, 42747, 628, 323, 4938, 72816, 44378, 323,
-        13480, 4712, 369, 444, 10994, 82, 13, 1084, 374, 6188, 369, 3460
-    ], 'vLLM is a high-throughput and memory-efficient inference and serving engine for LLMs. It is designed for large'
+        13480, 4712, 369, 444, 10994, 82, 13, 1084, 374, 6188, 311, 387
+    ], 'vLLM is a high-throughput and memory-efficient inference and serving engine for LLMs. It is designed to be'
     )]
 
     with VllmRunner(
diff --git a/vllm_ascend/compilation/passes/norm_quant_fusion_pass.py b/vllm_ascend/compilation/passes/norm_quant_fusion_pass.py
index e91c54a6..04d823ee 100644
--- a/vllm_ascend/compilation/passes/norm_quant_fusion_pass.py
+++ b/vllm_ascend/compilation/passes/norm_quant_fusion_pass.py
@@ -60,7 +60,9 @@ class AddRMSNormQuantPattern:
         """
         Pattern for AddRMSNormQuant fusion.
         """
-        output = torch.ops.npu.npu_add_rms_norm(rms_norm_input, residual, rms_norm_weight, self.eps)
+        output = torch.ops._C_ascend.npu_add_rms_norm_bias(
+            rms_norm_input, residual, rms_norm_weight, None, self.eps
+        )
         out0 = output[0]
         out1 = output[2]
         quantized_output = torch.ops.vllm.quantize(out0, scale, scale_reciprocal, offset)
@@ -179,7 +181,9 @@ class AddRMSNormQuantSPPattern:
         """
         Pattern for AddRMSNormQuant fusion.
         """
-        output = torch.ops.npu.npu_add_rms_norm(rms_norm_input, residual, rms_norm_weight, self.eps)
+        output = torch.ops._C_ascend.npu_add_rms_norm_bias(
+            rms_norm_input, residual, rms_norm_weight, None, self.eps
+        )
         out0 = output[0]
         out1 = output[2]
         out0 = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(out0, True)
@@ -482,11 +486,11 @@ class AddRMSNormQuantFusionPass(VllmInductorPass):
         common_epsilons = [1e-5, 1e-6]
 
         for eps in common_epsilons:
-            AddRMSNormQuantPattern(vllm_config, eps=eps).register(self.pattern_match_passes)
-            AddRMSNormQuantSPPattern(vllm_config, eps=eps).register(self.pattern_match_passes)
             AddRMSNormDynamicQuantPattern(vllm_config, eps=eps).register(self.pattern_match_passes)
             AddRMSNormDynamicQuantSPPattern(vllm_config, eps=eps).register(self.pattern_match_passes)
             if enable_custom_op():
+                AddRMSNormQuantPattern(vllm_config, eps=eps).register(self.pattern_match_passes)
+                AddRMSNormQuantSPPattern(vllm_config, eps=eps).register(self.pattern_match_passes)
                 AddRMSNormQuantPatternWithBias(vllm_config, eps=eps).register(self.pattern_match_passes)
                 AddRMSNormQuantSPPatternWithBias(vllm_config, eps=eps).register(self.pattern_match_passes)
                 AddRMSNormDynamicQuantPatternWithBias(vllm_config, eps=eps).register(self.pattern_match_passes)
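
For readers of this patch, the sketch below is a plain-PyTorch reference of the computation that `AddRMSNormQuant` fuses: a residual add, an RMSNorm, and a static per-tensor int8 quantization. It is an illustration only, not the Ascend kernel or the vLLM implementation; the function name `add_rms_norm_quant_ref` and the round/clamp quantization convention are assumptions made for this example.

```python
# Reference semantics of the AddRMSNormQuant fusion, sketched in plain
# PyTorch. The pass in this patch replaces npu_add_rms_norm_bias
# (with bias=None) followed by vllm.quantize with one fused kernel;
# this is only an illustrative re-implementation of the math.
import torch


def add_rms_norm_quant_ref(
    x: torch.Tensor,
    residual: torch.Tensor,
    weight: torch.Tensor,
    scale: torch.Tensor,
    offset: torch.Tensor,
    eps: float = 1e-6,
) -> tuple[torch.Tensor, torch.Tensor]:
    # Residual add happens first; the sum is also returned so the next
    # layer can keep threading the residual stream.
    new_residual = x + residual
    # RMSNorm: scale by the reciprocal root-mean-square, then by weight.
    variance = new_residual.float().pow(2).mean(dim=-1, keepdim=True)
    normed = new_residual.float() * torch.rsqrt(variance + eps)
    normed = (normed * weight.float()).to(x.dtype)
    # Static per-tensor quantization to int8 (convention assumed here).
    q = torch.round(normed.float() / scale + offset)
    q = q.clamp(-128, 127).to(torch.int8)
    return q, new_residual


# Tiny smoke test on CPU.
if __name__ == "__main__":
    x = torch.randn(2, 8, dtype=torch.float16)
    res = torch.randn(2, 8, dtype=torch.float16)
    w = torch.ones(8, dtype=torch.float16)
    q, new_res = add_rms_norm_quant_ref(
        x, res, w, scale=torch.tensor(0.05), offset=torch.tensor(0.0))
    print(q.dtype, q.shape, new_res.shape)  # torch.int8 (2, 8) (2, 8)
```

The bug fix itself is orthogonal to the math: the no-bias patterns are now written against `torch.ops._C_ascend.npu_add_rms_norm_bias` with a `None` bias, which the updated tests show is the op present in the model graph when custom ops are enabled, and they are registered only under `enable_custom_op()`, so the pattern matcher can find and rewrite the no-bias graph.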