[Fusion] normalize fusion naming and enable e2e test (#4693)

### What this PR does / why we need it?
This PR standardizes the fusion naming, changing
`enable_quantization_fusion` to `fuse_norm_quant`, and enables e2e
testing.

### Does this PR introduce _any_ user-facing change?
N/A

### How was this patch tested?
CI passed with new added/existing test.

- vLLM version: v0.12.0
- vLLM main:
ad32e3e19c

---------

Signed-off-by: wxsIcey <1790571317@qq.com>
This commit is contained in:
Icey
2025-12-11 17:53:43 +08:00
committed by GitHub
parent 07c7131104
commit 18221c0e1d
8 changed files with 136 additions and 113 deletions

View File

@@ -103,6 +103,7 @@ jobs:
pytest -sv tests/e2e/singlecard/test_vlm.py
pytest -sv tests/e2e/singlecard/test_xlite.py
pytest -sv tests/e2e/singlecard/pooling/
pytest -sv tests/e2e/singlecard/compile/test_norm_quant_fusion.py
# ------------------------------------ v1 spec decode test ------------------------------------ #
pytest -sv tests/e2e/singlecard/spec_decode_v1/test_v1_mtp_correctness.py

View File

@@ -17,65 +17,12 @@
from copy import deepcopy
from typing import Any, Callable, List, Optional, Sequence
import pytest
import torch
import torch.fx as fx
import torch.nn as nn
import torch_npu
import vllm.config
from torch._inductor.decomposition import select_decomp_table
from vllm.compilation.fx_utils import OpOverload
from vllm.config import ModelConfig, VllmConfig, get_current_vllm_config
from vllm.config import get_current_vllm_config
from vllm_ascend.compilation.compiler_interface import compile_fx
from vllm_ascend.compilation.passes.quant_fusion_pass import \
AddRMSNormQuantFusionPass
class TestModel(nn.Module):
"""
A minimal test model that simulates the pattern:
AddRMSNorm Quantization
"""
def __init__(self, hidden_size: int, eps: float = 1e-6, device="npu"):
super().__init__()
self.hidden_size = hidden_size
self.eps = eps
self.rms_norm_weight = nn.Parameter(
torch.randn(hidden_size, device=device))
self.quant_scale = torch.tensor([1.0], device=device)
self.quant_offset = torch.tensor([0.0], device=device)
def forward(self, x):
"""
Forward pass:
1. Perform npu_add_rms_norm
2. Quantize the normalized output to int8
Returns both quantized output and updated residual.
"""
residual = torch.zeros_like(x)
norm_output, _, new_residual = torch_npu.npu_add_rms_norm(
x, residual, self.rms_norm_weight, self.eps)
quantized_output = torch_npu.npu_quantize(norm_output,
self.quant_scale,
self.quant_offset,
torch.qint8, -1, False)
return quantized_output, new_residual
def ops_in_model_before(self) -> List[OpOverload]:
"""Return the list of expected operators BEFORE fusion."""
return [
torch.ops.npu.npu_add_rms_norm.default,
torch.ops.npu.npu_quantize.default
]
def ops_in_model_after(self) -> List[OpOverload]:
"""Return the list of expected operators AFTER successful fusion."""
return [torch.ops.npu.npu_add_rms_norm_quant.default]
class TestBackend:
@@ -85,14 +32,12 @@ class TestBackend:
records the FX graph before and after the transformation.
"""
def __init__(self):
def __init__(self, custom_passes: Optional[List[Any]] = None):
vllm_config = get_current_vllm_config()
compile_config = vllm_config.compilation_config
self.custom_passes = [
AddRMSNormQuantFusionPass(vllm_config=vllm_config)
]
self.inductor_config = compile_config.inductor_compile_config
self.inductor_config["graph_fusion_manager"] = self.post_pass
self.custom_passes = custom_passes
# Placeholders to store FX graphs for verification
self.graph_pre_pass = None
@@ -105,8 +50,9 @@ class TestBackend:
Apply custom graph transformation passes.
"""
self.graph_pre_pass = deepcopy(graph)
for pass_ in self.custom_passes:
pass_(graph)
if self.custom_passes is not None:
for pass_ in self.custom_passes:
pass_(graph)
self.graph_post_pass = deepcopy(graph)
return graph
@@ -136,11 +82,13 @@ class TestBackend:
)
return compiled_fn, None
def __call__(self, gm: fx.GraphModule, example_inputs: List[Any]):
def __call__(self, gm: fx.GraphModule,
example_inputs: Optional[List[Any]]):
"""
Make the backend callable by torch.compile().
Returns a compiled executable function.
"""
assert example_inputs is not None
compiled_fn, _ = self.compile(
gm,
example_inputs,
@@ -180,40 +128,3 @@ class TestBackend:
num_post = len(self.find_nodes_by_target(self.graph_post_pass, op))
print(f"Op {op}: post={num_post}")
assert num_post > 0, f"Op {op} not found in post-pass graph"
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
@pytest.mark.parametrize("hidden_size", [64])
@pytest.mark.parametrize("num_tokens", [257])
@pytest.mark.parametrize("eps", [1e-5, 1e-6])
def test_rmsnorm_quant_fusion(dtype: torch.dtype, hidden_size: int,
num_tokens: int, eps: float):
"""
End-to-end test for AddRMSNorm+Quantize fusion.
Compares: Operator presence/absence before and after graph transformation
"""
torch.set_default_dtype(dtype)
torch.manual_seed(1)
vllm_config = VllmConfig(model_config=ModelConfig(dtype=dtype))
with vllm.config.set_current_vllm_config(vllm_config):
backend = TestBackend()
model = TestModel(hidden_size, eps, device="npu")
model = model.to("npu")
x = torch.rand(num_tokens,
hidden_size,
device="npu",
dtype=dtype,
requires_grad=False)
result_unfused = model(x)
print("Unfused result:", [t.shape for t in result_unfused])
model_fused = torch.compile(model, backend=backend)
result_fused = model_fused(x)
print("Fused result:", [t.shape for t in result_fused])
print("=== Checking operator fusion ===")
backend.check_before_ops(model.ops_in_model_before())
backend.check_after_ops(model.ops_in_model_after())

View File

@@ -0,0 +1,113 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# Copyright 2023 The vLLM team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from typing import List
import pytest
import torch
import torch.nn as nn
import torch_npu
import vllm.config
from vllm.compilation.fx_utils import OpOverload
from vllm.config import ModelConfig, VllmConfig
from tests.e2e.singlecard.compile.backend import TestBackend
from vllm_ascend.compilation.passes.norm_quant_fusion_pass import \
AddRMSNormQuantFusionPass
class TestModel(nn.Module):
"""
A minimal test model that simulates the pattern:
AddRMSNorm → Quantization
"""
def __init__(self, hidden_size: int, eps: float = 1e-6, device="npu"):
super().__init__()
self.hidden_size = hidden_size
self.eps = eps
self.rms_norm_weight = nn.Parameter(
torch.randn(hidden_size, device=device))
self.quant_scale = torch.tensor([1.0], device=device)
self.quant_offset = torch.tensor([0.0], device=device)
def forward(self, x):
"""
Forward pass:
1. Perform npu_add_rms_norm
2. Quantize the normalized output to int8
Returns both quantized output and updated residual.
"""
residual = torch.zeros_like(x)
norm_output, _, new_residual = torch_npu.npu_add_rms_norm(
x, residual, self.rms_norm_weight, self.eps)
quantized_output = torch_npu.npu_quantize(norm_output,
self.quant_scale,
self.quant_offset,
torch.qint8, -1, False)
return quantized_output, new_residual
def ops_in_model_before(self) -> List[OpOverload]:
"""Return the list of expected operators BEFORE fusion."""
return [
torch.ops.npu.npu_add_rms_norm.default,
torch.ops.npu.npu_quantize.default
]
def ops_in_model_after(self) -> List[OpOverload]:
"""Return the list of expected operators AFTER successful fusion."""
return [torch.ops.npu.npu_add_rms_norm_quant.default]
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
@pytest.mark.parametrize("hidden_size", [64])
@pytest.mark.parametrize("num_tokens", [257])
@pytest.mark.parametrize("eps", [1e-5, 1e-6])
def test_rmsnorm_quant_fusion(dtype: torch.dtype, hidden_size: int,
num_tokens: int, eps: float):
"""
End-to-end test for AddRMSNorm+Quantize fusion.
Compares: Operator presence/absence before and after graph transformation
"""
torch.set_default_dtype(dtype)
torch.manual_seed(1)
vllm_config = VllmConfig(model_config=ModelConfig(dtype=dtype))
with vllm.config.set_current_vllm_config(vllm_config):
backend = TestBackend(
custom_passes=[AddRMSNormQuantFusionPass(vllm_config=vllm_config)])
model = TestModel(hidden_size, eps, device="npu")
model = model.to("npu")
x = torch.rand(num_tokens,
hidden_size,
device="npu",
dtype=dtype,
requires_grad=False)
result_unfused = model(x)
print("Unfused result:", [t.shape for t in result_unfused])
model_fused = torch.compile(model, backend=backend)
result_fused = model_fused(x)
print("Fused result:", [t.shape for t in result_fused])
print("=== Checking operator fusion ===")
backend.check_before_ops(model.ops_in_model_before())
backend.check_after_ops(model.ops_in_model_after())

View File

@@ -41,14 +41,14 @@ class TestAscendConfig(TestBase):
self.assertFalse(ascend_config.multistream_overlap_shared_expert)
ascend_compilation_config = ascend_config.ascend_compilation_config
self.assertTrue(ascend_compilation_config.enable_quantization_fusion)
self.assertTrue(ascend_compilation_config.fuse_norm_quant)
@_clean_up_ascend_config
def test_init_ascend_config_with_additional_config(self):
test_vllm_config = VllmConfig()
test_vllm_config.additional_config = {
"ascend_compilation_config": {
"enable_quantization_fusion": False,
"fuse_norm_quant": False,
},
"multistream_overlap_shared_expert": True,
"expert_map_path": "test_expert_map_path",
@@ -60,7 +60,7 @@ class TestAscendConfig(TestBase):
self.assertFalse(ascend_config.enable_npugraph_ex)
ascend_compilation_config = ascend_config.ascend_compilation_config
self.assertFalse(ascend_compilation_config.enable_quantization_fusion)
self.assertFalse(ascend_compilation_config.fuse_norm_quant)
@_clean_up_ascend_config
def test_init_ascend_config_enable_npugraph_ex(self):

View File

@@ -190,19 +190,18 @@ class AscendCompilationConfig:
deployed on Ascend platforms.
"""
def __init__(self, enable_quantization_fusion: bool = True, **kwargs):
def __init__(self, fuse_norm_quant: bool = True, **kwargs):
"""
Initialize the configuration.
Args:
enable_quantization_fusion (bool): Whether to enable quantization fusion optimization.
When set to True, the system will optimize quantization-related operations,
reducing the number of quantization/dequantization nodes.
fuse_norm_quant (bool): Whether to enable norm and quant fusion optimization.
When set to True, the system will optimize norm and quant operations.
Default: True
**kwargs: Additional optional parameters for forward compatibility and configuration extension.
"""
self.enable_quantization_fusion = enable_quantization_fusion
self.fuse_norm_quant = fuse_norm_quant
# Add more compilation related configs here as needed

View File

@@ -46,8 +46,8 @@ class GraphFusionPassManager:
# By default, we enable the graph fusion and quantization fusion pass.
self.ascend_compilation_config: dict = config.additional_config.get(
"ascend_compilation_config", {})
if self.ascend_compilation_config.get("enable_quantization_fusion",
True):
from .passes.quant_fusion_pass import AddRMSNormQuantFusionPass
if self.ascend_compilation_config.get("fuse_norm_quant", True):
from .passes.norm_quant_fusion_pass import \
AddRMSNormQuantFusionPass
self.passes.append(AddRMSNormQuantFusionPass(config))
# Add more passes here as needed

View File

@@ -88,8 +88,7 @@ class NPUPlatform(Platform):
Get the custom compile backend. Previously, we used EagerAdaptor by default.
To use graph fusion operations, we defined our own backend compiler.
"""
from vllm_ascend.compilation.compiler_interface import AscendCompiler
return AscendCompiler.__module__ + "." + AscendCompiler.__name__
return "vllm_ascend.compilation.compiler_interface.AscendCompiler"
@classmethod
def pre_register_and_update(cls,
@@ -225,8 +224,8 @@ class NPUPlatform(Platform):
if compilation_config.cudagraph_mode == CUDAGraphMode.FULL_AND_PIECEWISE:
compilation_config.cudagraph_mode = CUDAGraphMode.PIECEWISE
from vllm_ascend.compilation.compiler_interface import AscendCompiler
compilation_config.oot_compiler = AscendCompiler.__module__ + "." + AscendCompiler.__name__
# get custom compile backend for graph fusion
compilation_config.oot_compiler = cls.get_compile_backend()
if compilation_config.cudagraph_mode == CUDAGraphMode.NONE:
compilation_config.mode = CompilationMode.NONE