[Graph][Fusion] Add AddRMSNorm(with bias) and Quant Fusion Pattern (#5011)

### What this PR does / why we need it?
AddRMSNorm(with bias) and Quant Fusion Pattern

### Does this PR introduce _any_ user-facing change?
N/A

### How was this patch tested?
CI passed with new added/existing test.

- vLLM version: v0.12.0
- vLLM main:
ad32e3e19c

---------

Signed-off-by: wxsIcey <1790571317@qq.com>
This commit is contained in:
Icey
2025-12-15 18:37:56 +08:00
committed by GitHub
parent 6de4bedd04
commit 5fae65f3a8
2 changed files with 120 additions and 4 deletions

View File

@@ -79,6 +79,64 @@ class AddRMSNormQuantPattern:
pm.fwd_only, pm_pass)
class AddRMSNormQuantPatternWithBias:
def __init__(self, vllm_config: VllmConfig, eps: float = 1e-6):
self.vllm_config = vllm_config
self.eps = eps
def get_inputs(self):
"""
Generate example inputs for the AddRMSNormQuant fusion pattern.
"""
rms_norm_input = torch.randn(2, 4, device="npu")
residual = torch.randn(2, 4, device="npu")
rms_norm_weight = torch.randn(4, device="npu")
scale = torch.tensor([1.0], device="npu")
offset = torch.tensor([0.0], device="npu")
bias = torch.randn(4, device="npu")
return [rms_norm_input, residual, rms_norm_weight, scale, offset, bias]
def register(self, pm_pass: PatternMatcherPass):
def pattern(rms_norm_input: torch.Tensor, residual: torch.Tensor,
rms_norm_weight: torch.Tensor, scale: torch.Tensor,
offset: torch.Tensor, bias: torch.Tensor):
"""
Pattern for AddRMSNormQuant fusion.
"""
output = torch.ops.npu.npu_add_rms_norm(rms_norm_input, residual,
rms_norm_weight, self.eps)
out0 = output[0]
out1 = output[2]
out0 = out0 + bias
quantized_output = torch.ops.npu.npu_quantize(
out0, scale, offset, torch.qint8, -1, False)
return quantized_output, out1
def replacement(rms_norm_input: torch.Tensor, residual: torch.Tensor,
rms_norm_weight: torch.Tensor, scale: torch.Tensor,
offset: torch.Tensor, bias: torch.Tensor):
"""
Replacement for the AddRMSNormQuant fusion.
"""
output = torch.ops.npu.npu_add_rms_norm_quant(
rms_norm_input,
residual,
rms_norm_weight,
1. /
scale, # The inverse of scale is required by npu_add_rms_norm_quant kernel which is opposite to the npu_quantize kernel.
offset,
epsilon=self.eps,
beta=bias)
quantized_output = output[0]
out1 = output[2]
return quantized_output, out1
pm.register_replacement(pattern, replacement, self.get_inputs(),
pm.fwd_only, pm_pass)
class AddRMSNormQuantFusionPass(VllmInductorPass):
"""
A pass for fusing AddRMSNorm and W8A8 quantization operations on Ascend.
@@ -99,6 +157,8 @@ class AddRMSNormQuantFusionPass(VllmInductorPass):
for eps in common_epsilons:
AddRMSNormQuantPattern(vllm_config,
eps=eps).register(self.pattern_match_passes)
AddRMSNormQuantPatternWithBias(vllm_config, eps=eps).register(
self.pattern_match_passes)
def __call__(self, graph: torch.fx.Graph):
self.begin()