From 87a0b7b7c7630ae112c8384ad1af11ed29d55801 Mon Sep 17 00:00:00 2001
From: iiiklw <852373687@qq.com>
Date: Fri, 13 Feb 2026 10:10:39 +0800
Subject: [PATCH] [bugfix] adapt bugfix for norm_quant_fusion_pass to
 npugraph_ex (#6726)

### What this PR does / why we need it?

This PR adapts the bugfixes from `norm_quant_fusion_pass` to
`graphex_norm_quant_fusion_pass` for the `npugraph_ex` backend. The main
changes are:

- Replaced `torch.ops.npu.npu_add_rms_norm` with
  `torch.ops._C_ascend.npu_add_rms_norm_bias`.
- For patterns without a bias, `None` is passed as the bias argument.
- For patterns with a bias, the separate `add` operation for the bias is
  removed and the bias is passed directly to `npu_add_rms_norm_bias`,
  which improves fusion.

These changes keep the RMSNorm + quantization fusion patterns consistent
and correct when using `npugraph_ex`.

### Does this PR introduce _any_ user-facing change?

### How was this patch tested?

- vLLM version: v0.15.0
- vLLM main: https://github.com/vllm-project/vllm/commit/9562912cead1f11e8540fb91306c5cbda66f0007

Signed-off-by: huyuanquan1
Co-authored-by: huyuanquan1
---
 .../graphex_norm_quant_fusion_pass.py         | 18 ++++++++++++------
 1 file changed, 12 insertions(+), 6 deletions(-)

diff --git a/vllm_ascend/compilation/npugraph_ex_passes/graphex_norm_quant_fusion_pass.py b/vllm_ascend/compilation/npugraph_ex_passes/graphex_norm_quant_fusion_pass.py
index 1534b038..54e37e21 100644
--- a/vllm_ascend/compilation/npugraph_ex_passes/graphex_norm_quant_fusion_pass.py
+++ b/vllm_ascend/compilation/npugraph_ex_passes/graphex_norm_quant_fusion_pass.py
@@ -58,7 +58,9 @@ class GraphEXAddRMSNormQuantPattern:
         """
         Pattern for AddRMSNormQuant fusion.
         """
-        output = torch.ops.npu.npu_add_rms_norm(rms_norm_input, residual, rms_norm_weight, self.eps)
+        output = torch.ops._C_ascend.npu_add_rms_norm_bias(
+            rms_norm_input, residual, rms_norm_weight, None, self.eps
+        )
         out0 = output[0]
         out1 = output[2]
         quantized_output = torch.ops.vllm.quantize(out0, scale, scale_reciprocal, offset)
@@ -123,10 +125,11 @@ class GraphEXAddRMSNormQuantPatternWithBias:
         """
         Pattern for AddRMSNormQuantWithBias fusion.
         """
-        output = torch.ops.npu.npu_add_rms_norm(rms_norm_input, residual, rms_norm_weight, self.eps)
+        output = torch.ops._C_ascend.npu_add_rms_norm_bias(
+            rms_norm_input, residual, rms_norm_weight, bias, self.eps
+        )
         out0 = output[0]
         out1 = output[2]
-        out0 = out0 + bias
         quantized_output = torch.ops.vllm.quantize(out0, scale, scale_reciprocal, offset)
         return quantized_output, out1
 
@@ -188,7 +191,9 @@ class GraphEXAddRMSNormQuantSPPattern:
         """
         Pattern for AddRMSNormQuantSPPattern fusion.
         """
-        output = torch.ops.npu.npu_add_rms_norm(rms_norm_input, residual, rms_norm_weight, self.eps)
+        output = torch.ops._C_ascend.npu_add_rms_norm_bias(
+            rms_norm_input, residual, rms_norm_weight, None, self.eps
+        )
         out0 = output[0]
         out1 = output[2]
         out0 = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(out0, True)
@@ -255,10 +260,11 @@ class GraphEXAddRMSNormQuantSPPatternWithBias:
         """
         Pattern for AddRMSNormQuantSPPatternWithBias fusion.
         """
-        output = torch.ops.npu.npu_add_rms_norm(rms_norm_input, residual, rms_norm_weight, self.eps)
+        output = torch.ops._C_ascend.npu_add_rms_norm_bias(
+            rms_norm_input, residual, rms_norm_weight, bias, self.eps
+        )
         out0 = output[0]
         out1 = output[2]
-        out0 = out0 + bias
         out0 = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(out0, True)
         quantized_output = torch.ops.vllm.quantize(out0, scale, scale_reciprocal, offset)
         return quantized_output, out1
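
For reviewers, a minimal before/after sketch of the rewrite applied to the bias-carrying patterns, assuming only the op signatures visible in the hunks above (the names `x`, `residual`, `weight`, `bias`, `eps`, and the two helper functions are illustrative placeholders; running this requires `torch_npu` plus the vllm-ascend `_C_ascend` extension):

```python
import torch  # torch_npu and the vllm-ascend _C_ascend extension must be loaded

def add_rms_norm_quant_before(x, residual, weight, bias, eps,
                              scale, scale_reciprocal, offset):
    # Old pattern: the bias is applied as a separate elementwise add that sits
    # between the fused norm op and the quantize op, blocking fusion.
    output = torch.ops.npu.npu_add_rms_norm(x, residual, weight, eps)
    hidden, new_residual = output[0] + bias, output[2]
    return torch.ops.vllm.quantize(hidden, scale, scale_reciprocal, offset), new_residual

def add_rms_norm_quant_after(x, residual, weight, bias, eps,
                             scale, scale_reciprocal, offset):
    # New pattern: the bias goes directly into the fused op, so the whole
    # add + RMSNorm (+ bias) + quantize chain matches a single fusion pattern.
    output = torch.ops._C_ascend.npu_add_rms_norm_bias(x, residual, weight, bias, eps)
    hidden, new_residual = output[0], output[2]
    return torch.ops.vllm.quantize(hidden, scale, scale_reciprocal, offset), new_residual
```

The same substitution with `None` in place of `bias` covers the two bias-free patterns, so all four patterns now target `npu_add_rms_norm_bias`.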