From 804a9ec4e6365aa484e3f9db61e226598b406ee6 Mon Sep 17 00:00:00 2001
From: Zhang-Bryan <250470359+Zhang-Bryan@users.noreply.github.com>
Date: Wed, 4 Feb 2026 15:53:53 +0800
Subject: [PATCH] [Fusion] Add rmsnorm dynamic quant fusion pass (#6274)

### What this PR does / why we need it?
This PR introduces four new patterns to support the fusion of RMSNorm and
DynamicQuant operators. After replacing the fusion operators, the execution
time has been reduced from 22.8us to 16.9us.

### Does this PR introduce _any_ user-facing change?
N/A

### How was this patch tested?

- vLLM version: v0.14.1
- vLLM main: https://github.com/vllm-project/vllm/commit/d7de043d55d1dd629554467e23874097e1c48993

Signed-off-by: Bryan <250470359+Zhang-Bryan@users.noreply.github.com>
---
 .../passes/norm_quant_fusion_pass.py          | 198 ++++++++++++++++++
 1 file changed, 198 insertions(+)

diff --git a/vllm_ascend/compilation/passes/norm_quant_fusion_pass.py b/vllm_ascend/compilation/passes/norm_quant_fusion_pass.py
index 9bea5ca5..5dcb98d1 100644
--- a/vllm_ascend/compilation/passes/norm_quant_fusion_pass.py
+++ b/vllm_ascend/compilation/passes/norm_quant_fusion_pass.py
@@ -268,6 +268,200 @@ class AddRMSNormQuantSPPatternWithBias:
         pm.register_replacement(pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass)
 
 
+class AddRMSNormDynamicQuantPattern:
+    def __init__(self, vllm_config: VllmConfig, eps: float = 1e-6):
+        self.vllm_config = vllm_config
+        self.dtype = vllm_config.model_config.dtype
+        self.eps = eps
+
+    def get_inputs(self):
+        """
+        Generate example inputs for the AddRMSNormQuant fusion pattern.
+        """
+        rms_norm_input = torch.randn(2, 4, device="npu", dtype=self.dtype)
+        residual = torch.randn(2, 4, device="npu", dtype=self.dtype)
+        rms_norm_weight = torch.randn(4, device="npu", dtype=self.dtype)
+        return [rms_norm_input, residual, rms_norm_weight]
+
+    def register(self, pm_pass: PatternMatcherPass):
+        def pattern(rms_norm_input: torch.Tensor, residual: torch.Tensor, rms_norm_weight: torch.Tensor):
+            """
+            Pattern for AddRMSNormQuant fusion.
+            """
+            output = torch.ops.npu.npu_add_rms_norm(rms_norm_input, residual, rms_norm_weight, self.eps)
+            out0 = output[0]
+            out1 = output[2]
+            quantized_output = torch.ops.npu.npu_dynamic_quant(out0)
+            return quantized_output[0], quantized_output[1], out1
+
+        def replacement(rms_norm_input: torch.Tensor, residual: torch.Tensor, rms_norm_weight: torch.Tensor):
+            """
+            Replacement for the AddRMSNormQuant fusion.
+            """
+            output = torch.ops.npu.npu_add_rms_norm_dynamic_quant(
+                rms_norm_input, residual, rms_norm_weight, epsilon=self.eps, output_mask=[True, True]
+            )
+            return (
+                output[0],
+                output[3],
+                output[2],
+            )
+
+        pm.register_replacement(pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass)
+
+
+class AddRMSNormDynamicQuantPatternWithBias:
+    def __init__(self, vllm_config: VllmConfig, eps: float = 1e-6):
+        self.vllm_config = vllm_config
+        self.dtype = vllm_config.model_config.dtype
+        self.eps = eps
+
+    def get_inputs(self):
+        """
+        Generate example inputs for the AddRMSNormQuant fusion pattern.
+        """
+        rms_norm_input = torch.randn(2, 4, device="npu", dtype=self.dtype)
+        residual = torch.randn(2, 4, device="npu", dtype=self.dtype)
+        rms_norm_weight = torch.randn(4, device="npu", dtype=self.dtype)
+        rmsnorm_bias = torch.randn(4, device="npu", dtype=self.dtype)
+        return [rms_norm_input, residual, rms_norm_weight, rmsnorm_bias]
+
+    def register(self, pm_pass: PatternMatcherPass):
+        def pattern(
+            rms_norm_input: torch.Tensor,
+            residual: torch.Tensor,
+            rms_norm_weight: torch.Tensor,
+            bias: torch.Tensor,
+        ):
+            """
+            Pattern for AddRMSNormQuant fusion.
+            """
+            output = torch.ops._C_ascend.npu_add_rms_norm_bias(
+                rms_norm_input, residual, rms_norm_weight, bias, self.eps
+            )
+            out0 = output[0]
+            out1 = output[2]
+            quantized_output = torch.ops.npu.npu_dynamic_quant(out0)
+            return quantized_output[0], quantized_output[1], out1
+
+        def replacement(
+            rms_norm_input: torch.Tensor,
+            residual: torch.Tensor,
+            rms_norm_weight: torch.Tensor,
+            bias: torch.Tensor,
+        ):
+            """
+            Replacement for the AddRMSNormQuant fusion.
+            """
+            output = torch.ops.npu.npu_add_rms_norm_dynamic_quant(
+                rms_norm_input, residual, rms_norm_weight, epsilon=self.eps, output_mask=[True, True], beta=bias
+            )
+            return (
+                output[0],
+                output[3],
+                output[2],
+            )
+
+        pm.register_replacement(pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass)
+
+
+class AddRMSNormDynamicQuantSPPattern:
+    def __init__(self, vllm_config: VllmConfig, eps: float = 1e-6):
+        self.vllm_config = vllm_config
+        self.dtype = vllm_config.model_config.dtype
+        self.eps = eps
+
+    def get_inputs(self):
+        """
+        Generate example inputs for the AddRMSNormQuant fusion pattern.
+        """
+        rms_norm_input = torch.randn(2, 4, device="npu", dtype=self.dtype)
+        residual = torch.randn(2, 4, device="npu", dtype=self.dtype)
+        rms_norm_weight = torch.randn(4, device="npu", dtype=self.dtype)
+        return [rms_norm_input, residual, rms_norm_weight]
+
+    def register(self, pm_pass: PatternMatcherPass):
+        def pattern(rms_norm_input: torch.Tensor, residual: torch.Tensor, rms_norm_weight: torch.Tensor):
+            """
+            Pattern for AddRMSNormQuant fusion.
+            """
+            output = torch.ops.npu.npu_add_rms_norm(rms_norm_input, residual, rms_norm_weight, self.eps)
+            out0 = output[0]
+            out1 = output[2]
+            out0 = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(out0, True)
+            quantized_output = torch.ops.npu.npu_dynamic_quant(out0)
+            return quantized_output[0], quantized_output[1], out1
+
+        def replacement(rms_norm_input: torch.Tensor, residual: torch.Tensor, rms_norm_weight: torch.Tensor):
+            """
+            Replacement for the AddRMSNormQuant fusion.
+            """
+            output = torch.ops.npu.npu_add_rms_norm_dynamic_quant(
+                rms_norm_input, residual, rms_norm_weight, epsilon=self.eps, output_mask=[True, True]
+            )
+            out3 = output[3]
+            quantized_output = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(output[0], True)
+            out3 = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(out3, True)
+            return quantized_output, out3, output[2]
+
+        pm.register_replacement(pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass)
+
+
+class AddRMSNormDynamicQuantSPPatternWithBias:
+    def __init__(self, vllm_config: VllmConfig, eps: float = 1e-6):
+        self.vllm_config = vllm_config
+        self.dtype = vllm_config.model_config.dtype
+        self.eps = eps
+
+    def get_inputs(self):
+        """
+        Generate example inputs for the AddRMSNormQuant fusion pattern.
+        """
+        rms_norm_input = torch.randn(2, 4, device="npu", dtype=self.dtype)
+        residual = torch.randn(2, 4, device="npu", dtype=self.dtype)
+        rms_norm_weight = torch.randn(4, device="npu", dtype=self.dtype)
+        rmsnorm_bias = torch.randn(4, device="npu", dtype=self.dtype)
+        return [rms_norm_input, residual, rms_norm_weight, rmsnorm_bias]
+
+    def register(self, pm_pass: PatternMatcherPass):
+        def pattern(
+            rms_norm_input: torch.Tensor,
+            residual: torch.Tensor,
+            rms_norm_weight: torch.Tensor,
+            bias: torch.Tensor,
+        ):
+            """
+            Pattern for AddRMSNormQuant fusion.
+            """
+            output = torch.ops._C_ascend.npu_add_rms_norm_bias(
+                rms_norm_input, residual, rms_norm_weight, bias, self.eps
+            )
+            out0 = output[0]
+            out1 = output[2]
+            out0 = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(out0, True)
+            quantized_output = torch.ops.npu.npu_dynamic_quant(out0)
+            return quantized_output[0], quantized_output[1], out1
+
+        def replacement(
+            rms_norm_input: torch.Tensor,
+            residual: torch.Tensor,
+            rms_norm_weight: torch.Tensor,
+            bias: torch.Tensor,
+        ):
+            """
+            Replacement for the AddRMSNormQuant fusion.
+            """
+            output = torch.ops.npu.npu_add_rms_norm_dynamic_quant(
+                rms_norm_input, residual, rms_norm_weight, epsilon=self.eps, output_mask=[True, True], beta=bias
+            )
+            out3 = output[3]
+            quantized_output = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(output[0], True)
+            out3 = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(out3, True)
+            return quantized_output, out3, output[2]
+
+        pm.register_replacement(pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass)
+
+
 class AddRMSNormQuantFusionPass(VllmInductorPass):
     """
     A pass for fusing AddRMSNorm and W8A8 quantization operations on Ascend.
@@ -286,9 +480,13 @@ class AddRMSNormQuantFusionPass(VllmInductorPass):
         for eps in common_epsilons:
             AddRMSNormQuantPattern(vllm_config, eps=eps).register(self.pattern_match_passes)
             AddRMSNormQuantSPPattern(vllm_config, eps=eps).register(self.pattern_match_passes)
+            AddRMSNormDynamicQuantPattern(vllm_config, eps=eps).register(self.pattern_match_passes)
+            AddRMSNormDynamicQuantSPPattern(vllm_config, eps=eps).register(self.pattern_match_passes)
             if enable_custom_op():
                 AddRMSNormQuantPatternWithBias(vllm_config, eps=eps).register(self.pattern_match_passes)
                 AddRMSNormQuantSPPatternWithBias(vllm_config, eps=eps).register(self.pattern_match_passes)
+                AddRMSNormDynamicQuantPatternWithBias(vllm_config, eps=eps).register(self.pattern_match_passes)
+                AddRMSNormDynamicQuantSPPatternWithBias(vllm_config, eps=eps).register(self.pattern_match_passes)
 
     def __call__(self, graph: torch.fx.Graph):
         self.begin()