diff --git a/tests/e2e/singlecard/compile/test_norm_quant_fusion.py b/tests/e2e/singlecard/compile/test_norm_quant_fusion.py index 057fe888..d08e69c4 100644 --- a/tests/e2e/singlecard/compile/test_norm_quant_fusion.py +++ b/tests/e2e/singlecard/compile/test_norm_quant_fusion.py @@ -32,6 +32,7 @@ from tests.e2e.singlecard.compile.backend import TestBackend from vllm_ascend.ascend_forward_context import set_ascend_forward_context from vllm_ascend.compilation.passes.norm_quant_fusion_pass import \ AddRMSNormQuantFusionPass +from vllm_ascend.utils import enable_custom_op class TestModelWithoutBias(nn.Module): @@ -124,11 +125,8 @@ class TestModelWithBias(nn.Module): """ residual = torch.zeros_like(x) - norm_output, _, new_residual = torch_npu.npu_add_rms_norm( - x, residual, self.rms_norm_weight, self.eps) - - # Add bias - norm_output_with_bias = norm_output + self.bias + norm_output_with_bias, _, new_residual = torch.ops._C_ascend.npu_add_rms_norm_bias( + x, residual, self.rms_norm_weight, self.bias, self.eps) quantized_output = torch.ops.vllm.quantize(norm_output_with_bias, self.quant_scale, @@ -140,8 +138,7 @@ class TestModelWithBias(nn.Module): def ops_in_model_before(self) -> List[OpOverload]: """Return the list of expected operators BEFORE fusion.""" return [ - torch.ops.npu.npu_add_rms_norm.default, - torch.ops.aten.add.Tensor, # Add bias operation + torch.ops._C_ascend.npu_add_rms_norm_bias.default, torch.ops.vllm.quantize.default ] @@ -249,11 +246,8 @@ class TestModelSPWithBias(nn.Module): """ residual = torch.zeros_like(x) - norm_output, _, new_residual = torch_npu.npu_add_rms_norm( - x, residual, self.rms_norm_weight, self.eps) - - # Add bias - norm_output_with_bias = norm_output + self.bias + norm_output_with_bias, _, new_residual = torch.ops._C_ascend.npu_add_rms_norm_bias( + x, residual, self.rms_norm_weight, self.bias, self.eps) norm_output_with_bias = torch.ops.vllm.maybe_all_gather_and_maybe_unpad( norm_output_with_bias, True) @@ -268,8 +262,7 @@ class TestModelSPWithBias(nn.Module): def ops_in_model_before(self) -> List[OpOverload]: """Return the list of expected operators BEFORE fusion.""" return [ - torch.ops.npu.npu_add_rms_norm.default, - torch.ops.aten.add.Tensor, # Add bias operation + torch.ops._C_ascend.npu_add_rms_norm_bias.default, torch.ops.vllm.maybe_all_gather_and_maybe_unpad.default, torch.ops.vllm.quantize.default ] @@ -322,6 +315,8 @@ def test_rmsnorm_quant_fusion( AddRMSNormQuantFusionPass(vllm_config=vllm_config) ]) if use_bias: + if not enable_custom_op(): + return if sp_enable: model = TestModelSPWithBias(hidden_size, dtype, diff --git a/vllm_ascend/compilation/passes/norm_quant_fusion_pass.py b/vllm_ascend/compilation/passes/norm_quant_fusion_pass.py index e26b8429..9bea5ca5 100644 --- a/vllm_ascend/compilation/passes/norm_quant_fusion_pass.py +++ b/vllm_ascend/compilation/passes/norm_quant_fusion_pass.py @@ -23,6 +23,8 @@ from vllm.config import VllmConfig from vllm.config.compilation import Range from vllm.logger import logger +from vllm_ascend.utils import enable_custom_op + class AddRMSNormQuantPattern: def __init__(self, vllm_config: VllmConfig, eps: float = 1e-6): @@ -113,10 +115,11 @@ class AddRMSNormQuantPatternWithBias: """ Pattern for AddRMSNormQuant fusion. """ - output = torch.ops.npu.npu_add_rms_norm(rms_norm_input, residual, rms_norm_weight, self.eps) + output = torch.ops._C_ascend.npu_add_rms_norm_bias( + rms_norm_input, residual, rms_norm_weight, bias, self.eps + ) out0 = output[0] out1 = output[2] - out0 = out0 + bias quantized_output = torch.ops.vllm.quantize(out0, scale, scale_reciprocal, offset) return quantized_output, out1 @@ -233,10 +236,11 @@ class AddRMSNormQuantSPPatternWithBias: """ Pattern for AddRMSNormQuant fusion. """ - output = torch.ops.npu.npu_add_rms_norm(rms_norm_input, residual, rms_norm_weight, self.eps) + output = torch.ops._C_ascend.npu_add_rms_norm_bias( + rms_norm_input, residual, rms_norm_weight, bias, self.eps + ) out0 = output[0] out1 = output[2] - out0 = out0 + bias out0 = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(out0, True) quantized_output = torch.ops.vllm.quantize(out0, scale, scale_reciprocal, offset) return quantized_output, out1 @@ -281,9 +285,10 @@ class AddRMSNormQuantFusionPass(VllmInductorPass): common_epsilons = [1e-5, 1e-6] for eps in common_epsilons: AddRMSNormQuantPattern(vllm_config, eps=eps).register(self.pattern_match_passes) - AddRMSNormQuantPatternWithBias(vllm_config, eps=eps).register(self.pattern_match_passes) AddRMSNormQuantSPPattern(vllm_config, eps=eps).register(self.pattern_match_passes) - AddRMSNormQuantSPPatternWithBias(vllm_config, eps=eps).register(self.pattern_match_passes) + if enable_custom_op(): + AddRMSNormQuantPatternWithBias(vllm_config, eps=eps).register(self.pattern_match_passes) + AddRMSNormQuantSPPatternWithBias(vllm_config, eps=eps).register(self.pattern_match_passes) def __call__(self, graph: torch.fx.Graph): self.begin()