From 5fae65f3a8612ed4034544a22632b9cdaad32d61 Mon Sep 17 00:00:00 2001
From: Icey <1790571317@qq.com>
Date: Mon, 15 Dec 2025 18:37:56 +0800
Subject: [PATCH] [Graph][Fusion] Add AddRMSNorm(with bias) and Quant Fusion
 Pattern (#5011)

### What this PR does / why we need it?
Adds a fusion pattern that matches AddRMSNorm followed by a bias add and
quantization, and rewrites the sequence into a single
`npu_add_rms_norm_quant` call, passing the bias as `beta`. This complements
the existing bias-free AddRMSNorm + Quant fusion.

### Does this PR introduce _any_ user-facing change?
N/A

### How was this patch tested?
CI passed with newly added and existing tests.

- vLLM version: v0.12.0
- vLLM main: https://github.com/vllm-project/vllm/commit/ad32e3e19ccf0526cb6744a5fed09a138a5fb2f9

---------

Signed-off-by: wxsIcey <1790571317@qq.com>
---
 .../compile/test_norm_quant_fusion.py         | 64 +++++++++++++++++--
 .../passes/norm_quant_fusion_pass.py          | 60 +++++++++++++++++
 2 files changed, 120 insertions(+), 4 deletions(-)

diff --git a/tests/e2e/singlecard/compile/test_norm_quant_fusion.py b/tests/e2e/singlecard/compile/test_norm_quant_fusion.py
index 3b864e8c..2dd09de1 100644
--- a/tests/e2e/singlecard/compile/test_norm_quant_fusion.py
+++ b/tests/e2e/singlecard/compile/test_norm_quant_fusion.py
@@ -29,10 +29,10 @@ from vllm_ascend.compilation.passes.norm_quant_fusion_pass import \
     AddRMSNormQuantFusionPass
 
 
-class TestModel(nn.Module):
+class TestModelWithoutBias(nn.Module):
     """
     A minimal test model that simulates the pattern:
-    AddRMSNorm → Quantization
+    AddRMSNorm → Quantization (without bias)
     """
 
     def __init__(self, hidden_size: int, eps: float = 1e-6, device="npu"):
@@ -75,12 +75,65 @@ class TestModel(nn.Module):
         return [torch.ops.npu.npu_add_rms_norm_quant.default]
 
 
+class TestModelWithBias(nn.Module):
+    """
+    A test model that simulates the pattern:
+    AddRMSNorm → Add Bias → Quantization (with bias)
+    """
+
+    def __init__(self, hidden_size: int, eps: float = 1e-6, device="npu"):
+        super().__init__()
+        self.hidden_size = hidden_size
+        self.eps = eps
+        self.rms_norm_weight = nn.Parameter(
+            torch.randn(hidden_size, device=device))
+        self.bias = nn.Parameter(torch.randn(hidden_size, device=device))
+        self.quant_scale = torch.tensor([1.0], device=device)
+        self.quant_offset = torch.tensor([0.0], device=device)
+
+    def forward(self, x):
+        """
+        Forward pass:
+        1. Perform npu_add_rms_norm
+        2. Add bias
+        3. Quantize to int8
+        Returns both quantized output and updated residual.
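+        The fusion pass is expected to rewrite this three-op sequence
+        into a single npu_add_rms_norm_quant call (see ops_in_model_after).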
+ """ + residual = torch.zeros_like(x) + + norm_output, _, new_residual = torch_npu.npu_add_rms_norm( + x, residual, self.rms_norm_weight, self.eps) + + # Add bias + norm_output_with_bias = norm_output + self.bias + + quantized_output = torch_npu.npu_quantize(norm_output_with_bias, + self.quant_scale, + self.quant_offset, + torch.qint8, -1, False) + + return quantized_output, new_residual + + def ops_in_model_before(self) -> List[OpOverload]: + """Return the list of expected operators BEFORE fusion.""" + return [ + torch.ops.npu.npu_add_rms_norm.default, + torch.ops.aten.add.Tensor, # Add bias operation + torch.ops.npu.npu_quantize.default + ] + + def ops_in_model_after(self) -> List[OpOverload]: + """Return the list of expected operators AFTER successful fusion.""" + return [torch.ops.npu.npu_add_rms_norm_quant.default] + + @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) @pytest.mark.parametrize("hidden_size", [64]) @pytest.mark.parametrize("num_tokens", [257]) @pytest.mark.parametrize("eps", [1e-5, 1e-6]) +@pytest.mark.parametrize("use_bias", [False, True]) def test_rmsnorm_quant_fusion(dtype: torch.dtype, hidden_size: int, - num_tokens: int, eps: float): + num_tokens: int, eps: float, use_bias: bool): """ End-to-end test for AddRMSNorm+Quantize fusion. Compares: Operator presence/absence before and after graph transformation @@ -93,7 +146,10 @@ def test_rmsnorm_quant_fusion(dtype: torch.dtype, hidden_size: int, with vllm.config.set_current_vllm_config(vllm_config): backend = TestBackend( custom_passes=[AddRMSNormQuantFusionPass(vllm_config=vllm_config)]) - model = TestModel(hidden_size, eps, device="npu") + if use_bias: + model = TestModelWithBias(hidden_size, eps, device="npu") + else: + model = TestModelWithoutBias(hidden_size, eps, device="npu") model = model.to("npu") x = torch.rand(num_tokens, diff --git a/vllm_ascend/compilation/passes/norm_quant_fusion_pass.py b/vllm_ascend/compilation/passes/norm_quant_fusion_pass.py index 87f0e1a4..c23449bf 100644 --- a/vllm_ascend/compilation/passes/norm_quant_fusion_pass.py +++ b/vllm_ascend/compilation/passes/norm_quant_fusion_pass.py @@ -79,6 +79,64 @@ class AddRMSNormQuantPattern: pm.fwd_only, pm_pass) +class AddRMSNormQuantPatternWithBias: + + def __init__(self, vllm_config: VllmConfig, eps: float = 1e-6): + self.vllm_config = vllm_config + self.eps = eps + + def get_inputs(self): + """ + Generate example inputs for the AddRMSNormQuant fusion pattern. + """ + rms_norm_input = torch.randn(2, 4, device="npu") + residual = torch.randn(2, 4, device="npu") + rms_norm_weight = torch.randn(4, device="npu") + scale = torch.tensor([1.0], device="npu") + offset = torch.tensor([0.0], device="npu") + bias = torch.randn(4, device="npu") + return [rms_norm_input, residual, rms_norm_weight, scale, offset, bias] + + def register(self, pm_pass: PatternMatcherPass): + + def pattern(rms_norm_input: torch.Tensor, residual: torch.Tensor, + rms_norm_weight: torch.Tensor, scale: torch.Tensor, + offset: torch.Tensor, bias: torch.Tensor): + """ + Pattern for AddRMSNormQuant fusion. 
+ """ + output = torch.ops.npu.npu_add_rms_norm(rms_norm_input, residual, + rms_norm_weight, self.eps) + out0 = output[0] + out1 = output[2] + out0 = out0 + bias + quantized_output = torch.ops.npu.npu_quantize( + out0, scale, offset, torch.qint8, -1, False) + return quantized_output, out1 + + def replacement(rms_norm_input: torch.Tensor, residual: torch.Tensor, + rms_norm_weight: torch.Tensor, scale: torch.Tensor, + offset: torch.Tensor, bias: torch.Tensor): + """ + Replacement for the AddRMSNormQuant fusion. + """ + output = torch.ops.npu.npu_add_rms_norm_quant( + rms_norm_input, + residual, + rms_norm_weight, + 1. / + scale, # The inverse of scale is required by npu_add_rms_norm_quant kernel which is opposite to the npu_quantize kernel. + offset, + epsilon=self.eps, + beta=bias) + quantized_output = output[0] + out1 = output[2] + return quantized_output, out1 + + pm.register_replacement(pattern, replacement, self.get_inputs(), + pm.fwd_only, pm_pass) + + class AddRMSNormQuantFusionPass(VllmInductorPass): """ A pass for fusing AddRMSNorm and W8A8 quantization operations on Ascend. @@ -99,6 +157,8 @@ class AddRMSNormQuantFusionPass(VllmInductorPass): for eps in common_epsilons: AddRMSNormQuantPattern(vllm_config, eps=eps).register(self.pattern_match_passes) + AddRMSNormQuantPatternWithBias(vllm_config, eps=eps).register( + self.pattern_match_passes) def __call__(self, graph: torch.fx.Graph): self.begin()