diff --git a/vllm_ascend/compilation/npugraph_ex_passes/graphex_norm_quant_fusion_pass.py b/vllm_ascend/compilation/npugraph_ex_passes/graphex_norm_quant_fusion_pass.py index 1534b038..54e37e21 100644 --- a/vllm_ascend/compilation/npugraph_ex_passes/graphex_norm_quant_fusion_pass.py +++ b/vllm_ascend/compilation/npugraph_ex_passes/graphex_norm_quant_fusion_pass.py @@ -58,7 +58,9 @@ class GraphEXAddRMSNormQuantPattern: """ Pattern for AddRMSNormQuant fusion. """ - output = torch.ops.npu.npu_add_rms_norm(rms_norm_input, residual, rms_norm_weight, self.eps) + output = torch.ops._C_ascend.npu_add_rms_norm_bias( + rms_norm_input, residual, rms_norm_weight, None, self.eps + ) out0 = output[0] out1 = output[2] quantized_output = torch.ops.vllm.quantize(out0, scale, scale_reciprocal, offset) @@ -123,10 +125,11 @@ class GraphEXAddRMSNormQuantPatternWithBias: """ Pattern for AddRMSNormQuantWithBias fusion. """ - output = torch.ops.npu.npu_add_rms_norm(rms_norm_input, residual, rms_norm_weight, self.eps) + output = torch.ops._C_ascend.npu_add_rms_norm_bias( + rms_norm_input, residual, rms_norm_weight, bias, self.eps + ) out0 = output[0] out1 = output[2] - out0 = out0 + bias quantized_output = torch.ops.vllm.quantize(out0, scale, scale_reciprocal, offset) return quantized_output, out1 @@ -188,7 +191,9 @@ class GraphEXAddRMSNormQuantSPPattern: """ Pattern for AddRMSNormQuantSPPattern fusion. """ - output = torch.ops.npu.npu_add_rms_norm(rms_norm_input, residual, rms_norm_weight, self.eps) + output = torch.ops._C_ascend.npu_add_rms_norm_bias( + rms_norm_input, residual, rms_norm_weight, None, self.eps + ) out0 = output[0] out1 = output[2] out0 = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(out0, True) @@ -255,10 +260,11 @@ class GraphEXAddRMSNormQuantSPPatternWithBias: """ Pattern for AddRMSNormQuantSPPatternWithBias fusion. """ - output = torch.ops.npu.npu_add_rms_norm(rms_norm_input, residual, rms_norm_weight, self.eps) + output = torch.ops._C_ascend.npu_add_rms_norm_bias( + rms_norm_input, residual, rms_norm_weight, bias, self.eps + ) out0 = output[0] out1 = output[2] - out0 = out0 + bias out0 = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(out0, True) quantized_output = torch.ops.vllm.quantize(out0, scale, scale_reciprocal, offset) return quantized_output, out1