[Graph][Fusion] Add AddRMSNorm(with bias) and Quant Fusion Pattern (#5011)

### What this PR does / why we need it?
This PR adds a graph-fusion pattern that matches AddRMSNorm followed by a bias add and W8A8 quantization, and rewrites that subgraph into the single fused `npu_add_rms_norm_quant` kernel (passing the bias as `beta`), complementing the existing bias-free pattern.
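
In short, the pass turns a three-kernel subgraph into one fused call. A schematic sketch restating the pattern/replacement pair from the diffs below (it runs only on an Ascend device with `torch_npu` installed; argument names follow the diff):

```python
import torch
import torch_npu

def before_fusion(x, residual, weight, bias, scale, offset, eps=1e-6):
    # Three separate kernels: add+rmsnorm, bias add, quantize.
    out, _, new_residual = torch_npu.npu_add_rms_norm(x, residual, weight, eps)
    out = out + bias
    q = torch_npu.npu_quantize(out, scale, offset, torch.qint8, -1, False)
    return q, new_residual

def after_fusion(x, residual, weight, bias, scale, offset, eps=1e-6):
    # One fused kernel; note the inverted scale and beta=bias.
    out = torch_npu.npu_add_rms_norm_quant(
        x, residual, weight, 1. / scale, offset, epsilon=eps, beta=bias)
    return out[0], out[2]
```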

### Does this PR introduce _any_ user-facing change?
N/A

### How was this patch tested?
CI passed with newly added and existing tests.

- vLLM version: v0.12.0
- vLLM main: ad32e3e19c

---------

Signed-off-by: wxsIcey <1790571317@qq.com>
Author: Icey
Date: 2025-12-15 18:37:56 +08:00
Committed by: GitHub
Parent commit: 6de4bedd04
This commit: 5fae65f3a8
2 changed files with 120 additions and 4 deletions

Test file:

```diff
@@ -29,10 +29,10 @@ from vllm_ascend.compilation.passes.norm_quant_fusion_pass import \
     AddRMSNormQuantFusionPass
 
 
-class TestModel(nn.Module):
+class TestModelWithoutBias(nn.Module):
     """
     A minimal test model that simulates the pattern:
-    AddRMSNorm → Quantization
+    AddRMSNorm → Quantization (without bias)
     """
 
     def __init__(self, hidden_size: int, eps: float = 1e-6, device="npu"):
```
```diff
@@ -75,12 +75,65 @@ class TestModel(nn.Module):
         return [torch.ops.npu.npu_add_rms_norm_quant.default]
 
 
+class TestModelWithBias(nn.Module):
+    """
+    A test model that simulates the pattern:
+    AddRMSNorm → Add Bias → Quantization (with bias)
+    """
+
+    def __init__(self, hidden_size: int, eps: float = 1e-6, device="npu"):
+        super().__init__()
+        self.hidden_size = hidden_size
+        self.eps = eps
+        self.rms_norm_weight = nn.Parameter(
+            torch.randn(hidden_size, device=device))
+        self.bias = nn.Parameter(torch.randn(hidden_size, device=device))
+        self.quant_scale = torch.tensor([1.0], device=device)
+        self.quant_offset = torch.tensor([0.0], device=device)
+
+    def forward(self, x):
+        """
+        Forward pass:
+        1. Perform npu_add_rms_norm
+        2. Add bias
+        3. Quantize to int8
+        Returns both quantized output and updated residual.
+        """
+        residual = torch.zeros_like(x)
+        norm_output, _, new_residual = torch_npu.npu_add_rms_norm(
+            x, residual, self.rms_norm_weight, self.eps)
+        # Add bias
+        norm_output_with_bias = norm_output + self.bias
+        quantized_output = torch_npu.npu_quantize(norm_output_with_bias,
+                                                  self.quant_scale,
+                                                  self.quant_offset,
+                                                  torch.qint8, -1, False)
+        return quantized_output, new_residual
+
+    def ops_in_model_before(self) -> List[OpOverload]:
+        """Return the list of expected operators BEFORE fusion."""
+        return [
+            torch.ops.npu.npu_add_rms_norm.default,
+            torch.ops.aten.add.Tensor,  # Add bias operation
+            torch.ops.npu.npu_quantize.default
+        ]
+
+    def ops_in_model_after(self) -> List[OpOverload]:
+        """Return the list of expected operators AFTER successful fusion."""
+        return [torch.ops.npu.npu_add_rms_norm_quant.default]
+
+
 @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
 @pytest.mark.parametrize("hidden_size", [64])
 @pytest.mark.parametrize("num_tokens", [257])
 @pytest.mark.parametrize("eps", [1e-5, 1e-6])
+@pytest.mark.parametrize("use_bias", [False, True])
 def test_rmsnorm_quant_fusion(dtype: torch.dtype, hidden_size: int,
-                              num_tokens: int, eps: float):
+                              num_tokens: int, eps: float, use_bias: bool):
     """
     End-to-end test for AddRMSNorm+Quantize fusion.
     Compares: Operator presence/absence before and after graph transformation
```
```diff
@@ -93,7 +146,10 @@ def test_rmsnorm_quant_fusion(dtype: torch.dtype, hidden_size: int,
     with vllm.config.set_current_vllm_config(vllm_config):
         backend = TestBackend(
             custom_passes=[AddRMSNormQuantFusionPass(vllm_config=vllm_config)])
-        model = TestModel(hidden_size, eps, device="npu")
+        if use_bias:
+            model = TestModelWithBias(hidden_size, eps, device="npu")
+        else:
+            model = TestModelWithoutBias(hidden_size, eps, device="npu")
         model = model.to("npu")
         x = torch.rand(num_tokens,
```
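
The test drives the model through `TestBackend` with the fusion pass installed and asserts which operators appear in the FX graph before and after the rewrite. A minimal sketch of that kind of membership check (the helper name is illustrative, not `TestBackend`'s actual API):

```python
import torch
from torch.fx import GraphModule

def graph_contains(gm: GraphModule, op) -> bool:
    # True if any call_function node in the FX graph targets `op`,
    # e.g. torch.ops.npu.npu_add_rms_norm_quant.default after fusion.
    return any(node.op == "call_function" and node.target is op
               for node in gm.graph.nodes)
```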

vllm_ascend/compilation/passes/norm_quant_fusion_pass.py:

```diff
@@ -79,6 +79,64 @@ class AddRMSNormQuantPattern:
                                 pm.fwd_only, pm_pass)
 
 
+class AddRMSNormQuantPatternWithBias:
+
+    def __init__(self, vllm_config: VllmConfig, eps: float = 1e-6):
+        self.vllm_config = vllm_config
+        self.eps = eps
+
+    def get_inputs(self):
+        """
+        Generate example inputs for the AddRMSNormQuant fusion pattern.
+        """
+        rms_norm_input = torch.randn(2, 4, device="npu")
+        residual = torch.randn(2, 4, device="npu")
+        rms_norm_weight = torch.randn(4, device="npu")
+        scale = torch.tensor([1.0], device="npu")
+        offset = torch.tensor([0.0], device="npu")
+        bias = torch.randn(4, device="npu")
+        return [rms_norm_input, residual, rms_norm_weight, scale, offset, bias]
+
+    def register(self, pm_pass: PatternMatcherPass):
+
+        def pattern(rms_norm_input: torch.Tensor, residual: torch.Tensor,
+                    rms_norm_weight: torch.Tensor, scale: torch.Tensor,
+                    offset: torch.Tensor, bias: torch.Tensor):
+            """
+            Pattern for AddRMSNormQuant fusion.
+            """
+            output = torch.ops.npu.npu_add_rms_norm(rms_norm_input, residual,
+                                                    rms_norm_weight, self.eps)
+            out0 = output[0]
+            out1 = output[2]
+            out0 = out0 + bias
+            quantized_output = torch.ops.npu.npu_quantize(
+                out0, scale, offset, torch.qint8, -1, False)
+            return quantized_output, out1
+
+        def replacement(rms_norm_input: torch.Tensor, residual: torch.Tensor,
+                        rms_norm_weight: torch.Tensor, scale: torch.Tensor,
+                        offset: torch.Tensor, bias: torch.Tensor):
+            """
+            Replacement for the AddRMSNormQuant fusion.
+            """
+            output = torch.ops.npu.npu_add_rms_norm_quant(
+                rms_norm_input,
+                residual,
+                rms_norm_weight,
+                # npu_add_rms_norm_quant expects the inverse of the scale
+                # that npu_quantize takes, so pass 1 / scale here.
+                1. / scale,
+                offset,
+                epsilon=self.eps,
+                beta=bias)
+            quantized_output = output[0]
+            out1 = output[2]
+            return quantized_output, out1
+
+        pm.register_replacement(pattern, replacement, self.get_inputs(),
+                                pm.fwd_only, pm_pass)
+
+
 class AddRMSNormQuantFusionPass(VllmInductorPass):
     """
     A pass for fusing AddRMSNorm and W8A8 quantization operations on Ascend.
```
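
One subtlety in the replacement above is the inverted scale. Assuming `npu_quantize` with `div_mode=False` multiplies by `scale` while the fused kernel divides by its scale argument (the behavior the inline comment implies), passing `1. / scale` keeps the two paths numerically identical. A plain-torch sketch of that equivalence under those assumed semantics:

```python
import torch

x = torch.randn(2, 4)
scale = torch.tensor([0.5])

# Standalone path: multiply by scale (assumed div_mode=False semantics).
q_standalone = torch.round(x * scale)
# Fused path: divide by the kernel's scale argument, here 1 / scale.
q_fused = torch.round(x / (1. / scale))

assert torch.equal(q_standalone, q_fused)
```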
```diff
@@ -99,6 +157,8 @@ class AddRMSNormQuantFusionPass(VllmInductorPass):
         for eps in common_epsilons:
             AddRMSNormQuantPattern(vllm_config,
                                    eps=eps).register(self.pattern_match_passes)
+            AddRMSNormQuantPatternWithBias(vllm_config, eps=eps).register(
+                self.pattern_match_passes)
 
     def __call__(self, graph: torch.fx.Graph):
         self.begin()
```
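
Both pattern classes build on Inductor's pattern matcher: `register_replacement` traces `pattern` and `replacement` with the example inputs from `get_inputs()` and rewrites every matching subgraph when the pass runs. A minimal CPU-only sketch of the same mechanism, with plain aten ops standing in for the NPU kernels:

```python
import torch
from torch._inductor.pattern_matcher import (PatternMatcherPass, fwd_only,
                                             register_replacement)

matcher = PatternMatcherPass()

def pattern(x: torch.Tensor, y: torch.Tensor):
    # Subgraph to search for: two elementwise ops.
    return x * y + y

def replacement(x: torch.Tensor, y: torch.Tensor):
    # Algebraically equivalent single-expression form.
    return (x + 1) * y

# Trace both functions with example inputs and register the rewrite.
register_replacement(pattern, replacement,
                     [torch.empty(2, 4), torch.empty(2, 4)],
                     fwd_only, matcher)

def run(graph: torch.fx.Graph) -> None:
    # Rewrites every occurrence of the pattern in `graph`, much as the
    # fusion pass's __call__ applies its registered patterns.
    matcher.apply(graph)
```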