[Fusion] normalize fusion naming and enable e2e test (#4693)

### What this PR does / why we need it?
This PR standardizes the fusion naming, changing
`enable_quantization_fusion` to `fuse_norm_quant`, and enables e2e
testing.

### Does this PR introduce _any_ user-facing change?
N/A

### How was this patch tested?
CI passed with new added/existing test.

- vLLM version: v0.12.0
- vLLM main:
ad32e3e19c

---------

Signed-off-by: wxsIcey <1790571317@qq.com>
This commit is contained in:
Icey
2025-12-11 17:53:43 +08:00
committed by GitHub
parent 07c7131104
commit 18221c0e1d
8 changed files with 136 additions and 113 deletions

View File

@@ -103,6 +103,7 @@ jobs:
pytest -sv tests/e2e/singlecard/test_vlm.py pytest -sv tests/e2e/singlecard/test_vlm.py
pytest -sv tests/e2e/singlecard/test_xlite.py pytest -sv tests/e2e/singlecard/test_xlite.py
pytest -sv tests/e2e/singlecard/pooling/ pytest -sv tests/e2e/singlecard/pooling/
pytest -sv tests/e2e/singlecard/compile/test_norm_quant_fusion.py
# ------------------------------------ v1 spec decode test ------------------------------------ # # ------------------------------------ v1 spec decode test ------------------------------------ #
pytest -sv tests/e2e/singlecard/spec_decode_v1/test_v1_mtp_correctness.py pytest -sv tests/e2e/singlecard/spec_decode_v1/test_v1_mtp_correctness.py

View File

@@ -17,65 +17,12 @@
from copy import deepcopy from copy import deepcopy
from typing import Any, Callable, List, Optional, Sequence from typing import Any, Callable, List, Optional, Sequence
import pytest
import torch
import torch.fx as fx import torch.fx as fx
import torch.nn as nn
import torch_npu
import vllm.config
from torch._inductor.decomposition import select_decomp_table from torch._inductor.decomposition import select_decomp_table
from vllm.compilation.fx_utils import OpOverload from vllm.compilation.fx_utils import OpOverload
from vllm.config import ModelConfig, VllmConfig, get_current_vllm_config from vllm.config import get_current_vllm_config
from vllm_ascend.compilation.compiler_interface import compile_fx from vllm_ascend.compilation.compiler_interface import compile_fx
from vllm_ascend.compilation.passes.quant_fusion_pass import \
AddRMSNormQuantFusionPass
class TestModel(nn.Module):
"""
A minimal test model that simulates the pattern:
AddRMSNorm Quantization
"""
def __init__(self, hidden_size: int, eps: float = 1e-6, device="npu"):
super().__init__()
self.hidden_size = hidden_size
self.eps = eps
self.rms_norm_weight = nn.Parameter(
torch.randn(hidden_size, device=device))
self.quant_scale = torch.tensor([1.0], device=device)
self.quant_offset = torch.tensor([0.0], device=device)
def forward(self, x):
"""
Forward pass:
1. Perform npu_add_rms_norm
2. Quantize the normalized output to int8
Returns both quantized output and updated residual.
"""
residual = torch.zeros_like(x)
norm_output, _, new_residual = torch_npu.npu_add_rms_norm(
x, residual, self.rms_norm_weight, self.eps)
quantized_output = torch_npu.npu_quantize(norm_output,
self.quant_scale,
self.quant_offset,
torch.qint8, -1, False)
return quantized_output, new_residual
def ops_in_model_before(self) -> List[OpOverload]:
"""Return the list of expected operators BEFORE fusion."""
return [
torch.ops.npu.npu_add_rms_norm.default,
torch.ops.npu.npu_quantize.default
]
def ops_in_model_after(self) -> List[OpOverload]:
"""Return the list of expected operators AFTER successful fusion."""
return [torch.ops.npu.npu_add_rms_norm_quant.default]
class TestBackend: class TestBackend:
@@ -85,14 +32,12 @@ class TestBackend:
records the FX graph before and after the transformation. records the FX graph before and after the transformation.
""" """
def __init__(self): def __init__(self, custom_passes: Optional[List[Any]] = None):
vllm_config = get_current_vllm_config() vllm_config = get_current_vllm_config()
compile_config = vllm_config.compilation_config compile_config = vllm_config.compilation_config
self.custom_passes = [
AddRMSNormQuantFusionPass(vllm_config=vllm_config)
]
self.inductor_config = compile_config.inductor_compile_config self.inductor_config = compile_config.inductor_compile_config
self.inductor_config["graph_fusion_manager"] = self.post_pass self.inductor_config["graph_fusion_manager"] = self.post_pass
self.custom_passes = custom_passes
# Placeholders to store FX graphs for verification # Placeholders to store FX graphs for verification
self.graph_pre_pass = None self.graph_pre_pass = None
@@ -105,6 +50,7 @@ class TestBackend:
Apply custom graph transformation passes. Apply custom graph transformation passes.
""" """
self.graph_pre_pass = deepcopy(graph) self.graph_pre_pass = deepcopy(graph)
if self.custom_passes is not None:
for pass_ in self.custom_passes: for pass_ in self.custom_passes:
pass_(graph) pass_(graph)
self.graph_post_pass = deepcopy(graph) self.graph_post_pass = deepcopy(graph)
@@ -136,11 +82,13 @@ class TestBackend:
) )
return compiled_fn, None return compiled_fn, None
def __call__(self, gm: fx.GraphModule, example_inputs: List[Any]): def __call__(self, gm: fx.GraphModule,
example_inputs: Optional[List[Any]]):
""" """
Make the backend callable by torch.compile(). Make the backend callable by torch.compile().
Returns a compiled executable function. Returns a compiled executable function.
""" """
assert example_inputs is not None
compiled_fn, _ = self.compile( compiled_fn, _ = self.compile(
gm, gm,
example_inputs, example_inputs,
@@ -180,40 +128,3 @@ class TestBackend:
num_post = len(self.find_nodes_by_target(self.graph_post_pass, op)) num_post = len(self.find_nodes_by_target(self.graph_post_pass, op))
print(f"Op {op}: post={num_post}") print(f"Op {op}: post={num_post}")
assert num_post > 0, f"Op {op} not found in post-pass graph" assert num_post > 0, f"Op {op} not found in post-pass graph"
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
@pytest.mark.parametrize("hidden_size", [64])
@pytest.mark.parametrize("num_tokens", [257])
@pytest.mark.parametrize("eps", [1e-5, 1e-6])
def test_rmsnorm_quant_fusion(dtype: torch.dtype, hidden_size: int,
num_tokens: int, eps: float):
"""
End-to-end test for AddRMSNorm+Quantize fusion.
Compares: Operator presence/absence before and after graph transformation
"""
torch.set_default_dtype(dtype)
torch.manual_seed(1)
vllm_config = VllmConfig(model_config=ModelConfig(dtype=dtype))
with vllm.config.set_current_vllm_config(vllm_config):
backend = TestBackend()
model = TestModel(hidden_size, eps, device="npu")
model = model.to("npu")
x = torch.rand(num_tokens,
hidden_size,
device="npu",
dtype=dtype,
requires_grad=False)
result_unfused = model(x)
print("Unfused result:", [t.shape for t in result_unfused])
model_fused = torch.compile(model, backend=backend)
result_fused = model_fused(x)
print("Fused result:", [t.shape for t in result_fused])
print("=== Checking operator fusion ===")
backend.check_before_ops(model.ops_in_model_before())
backend.check_after_ops(model.ops_in_model_after())

View File

@@ -0,0 +1,113 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# Copyright 2023 The vLLM team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from typing import List
import pytest
import torch
import torch.nn as nn
import torch_npu
import vllm.config
from vllm.compilation.fx_utils import OpOverload
from vllm.config import ModelConfig, VllmConfig
from tests.e2e.singlecard.compile.backend import TestBackend
from vllm_ascend.compilation.passes.norm_quant_fusion_pass import \
AddRMSNormQuantFusionPass
class TestModel(nn.Module):
"""
A minimal test model that simulates the pattern:
AddRMSNorm → Quantization
"""
def __init__(self, hidden_size: int, eps: float = 1e-6, device="npu"):
super().__init__()
self.hidden_size = hidden_size
self.eps = eps
self.rms_norm_weight = nn.Parameter(
torch.randn(hidden_size, device=device))
self.quant_scale = torch.tensor([1.0], device=device)
self.quant_offset = torch.tensor([0.0], device=device)
def forward(self, x):
"""
Forward pass:
1. Perform npu_add_rms_norm
2. Quantize the normalized output to int8
Returns both quantized output and updated residual.
"""
residual = torch.zeros_like(x)
norm_output, _, new_residual = torch_npu.npu_add_rms_norm(
x, residual, self.rms_norm_weight, self.eps)
quantized_output = torch_npu.npu_quantize(norm_output,
self.quant_scale,
self.quant_offset,
torch.qint8, -1, False)
return quantized_output, new_residual
def ops_in_model_before(self) -> List[OpOverload]:
"""Return the list of expected operators BEFORE fusion."""
return [
torch.ops.npu.npu_add_rms_norm.default,
torch.ops.npu.npu_quantize.default
]
def ops_in_model_after(self) -> List[OpOverload]:
"""Return the list of expected operators AFTER successful fusion."""
return [torch.ops.npu.npu_add_rms_norm_quant.default]
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
@pytest.mark.parametrize("hidden_size", [64])
@pytest.mark.parametrize("num_tokens", [257])
@pytest.mark.parametrize("eps", [1e-5, 1e-6])
def test_rmsnorm_quant_fusion(dtype: torch.dtype, hidden_size: int,
num_tokens: int, eps: float):
"""
End-to-end test for AddRMSNorm+Quantize fusion.
Compares: Operator presence/absence before and after graph transformation
"""
torch.set_default_dtype(dtype)
torch.manual_seed(1)
vllm_config = VllmConfig(model_config=ModelConfig(dtype=dtype))
with vllm.config.set_current_vllm_config(vllm_config):
backend = TestBackend(
custom_passes=[AddRMSNormQuantFusionPass(vllm_config=vllm_config)])
model = TestModel(hidden_size, eps, device="npu")
model = model.to("npu")
x = torch.rand(num_tokens,
hidden_size,
device="npu",
dtype=dtype,
requires_grad=False)
result_unfused = model(x)
print("Unfused result:", [t.shape for t in result_unfused])
model_fused = torch.compile(model, backend=backend)
result_fused = model_fused(x)
print("Fused result:", [t.shape for t in result_fused])
print("=== Checking operator fusion ===")
backend.check_before_ops(model.ops_in_model_before())
backend.check_after_ops(model.ops_in_model_after())

View File

@@ -41,14 +41,14 @@ class TestAscendConfig(TestBase):
self.assertFalse(ascend_config.multistream_overlap_shared_expert) self.assertFalse(ascend_config.multistream_overlap_shared_expert)
ascend_compilation_config = ascend_config.ascend_compilation_config ascend_compilation_config = ascend_config.ascend_compilation_config
self.assertTrue(ascend_compilation_config.enable_quantization_fusion) self.assertTrue(ascend_compilation_config.fuse_norm_quant)
@_clean_up_ascend_config @_clean_up_ascend_config
def test_init_ascend_config_with_additional_config(self): def test_init_ascend_config_with_additional_config(self):
test_vllm_config = VllmConfig() test_vllm_config = VllmConfig()
test_vllm_config.additional_config = { test_vllm_config.additional_config = {
"ascend_compilation_config": { "ascend_compilation_config": {
"enable_quantization_fusion": False, "fuse_norm_quant": False,
}, },
"multistream_overlap_shared_expert": True, "multistream_overlap_shared_expert": True,
"expert_map_path": "test_expert_map_path", "expert_map_path": "test_expert_map_path",
@@ -60,7 +60,7 @@ class TestAscendConfig(TestBase):
self.assertFalse(ascend_config.enable_npugraph_ex) self.assertFalse(ascend_config.enable_npugraph_ex)
ascend_compilation_config = ascend_config.ascend_compilation_config ascend_compilation_config = ascend_config.ascend_compilation_config
self.assertFalse(ascend_compilation_config.enable_quantization_fusion) self.assertFalse(ascend_compilation_config.fuse_norm_quant)
@_clean_up_ascend_config @_clean_up_ascend_config
def test_init_ascend_config_enable_npugraph_ex(self): def test_init_ascend_config_enable_npugraph_ex(self):

View File

@@ -190,19 +190,18 @@ class AscendCompilationConfig:
deployed on Ascend platforms. deployed on Ascend platforms.
""" """
def __init__(self, enable_quantization_fusion: bool = True, **kwargs): def __init__(self, fuse_norm_quant: bool = True, **kwargs):
""" """
Initialize the configuration. Initialize the configuration.
Args: Args:
enable_quantization_fusion (bool): Whether to enable quantization fusion optimization. fuse_norm_quant (bool): Whether to enable norm and quant fusion optimization.
When set to True, the system will optimize quantization-related operations, When set to True, the system will optimize norm and quant operations.
reducing the number of quantization/dequantization nodes.
Default: True Default: True
**kwargs: Additional optional parameters for forward compatibility and configuration extension. **kwargs: Additional optional parameters for forward compatibility and configuration extension.
""" """
self.enable_quantization_fusion = enable_quantization_fusion self.fuse_norm_quant = fuse_norm_quant
# Add more compilation related configs here as needed # Add more compilation related configs here as needed

View File

@@ -46,8 +46,8 @@ class GraphFusionPassManager:
# By default, we enable the graph fusion and quantization fusion pass. # By default, we enable the graph fusion and quantization fusion pass.
self.ascend_compilation_config: dict = config.additional_config.get( self.ascend_compilation_config: dict = config.additional_config.get(
"ascend_compilation_config", {}) "ascend_compilation_config", {})
if self.ascend_compilation_config.get("enable_quantization_fusion", if self.ascend_compilation_config.get("fuse_norm_quant", True):
True): from .passes.norm_quant_fusion_pass import \
from .passes.quant_fusion_pass import AddRMSNormQuantFusionPass AddRMSNormQuantFusionPass
self.passes.append(AddRMSNormQuantFusionPass(config)) self.passes.append(AddRMSNormQuantFusionPass(config))
# Add more passes here as needed # Add more passes here as needed

View File

@@ -88,8 +88,7 @@ class NPUPlatform(Platform):
Get the custom compile backend. Previously, we used EagerAdaptor by default. Get the custom compile backend. Previously, we used EagerAdaptor by default.
To use graph fusion operations, we defined our own backend compiler. To use graph fusion operations, we defined our own backend compiler.
""" """
from vllm_ascend.compilation.compiler_interface import AscendCompiler return "vllm_ascend.compilation.compiler_interface.AscendCompiler"
return AscendCompiler.__module__ + "." + AscendCompiler.__name__
@classmethod @classmethod
def pre_register_and_update(cls, def pre_register_and_update(cls,
@@ -225,8 +224,8 @@ class NPUPlatform(Platform):
if compilation_config.cudagraph_mode == CUDAGraphMode.FULL_AND_PIECEWISE: if compilation_config.cudagraph_mode == CUDAGraphMode.FULL_AND_PIECEWISE:
compilation_config.cudagraph_mode = CUDAGraphMode.PIECEWISE compilation_config.cudagraph_mode = CUDAGraphMode.PIECEWISE
from vllm_ascend.compilation.compiler_interface import AscendCompiler # get custom compile backend for graph fusion
compilation_config.oot_compiler = AscendCompiler.__module__ + "." + AscendCompiler.__name__ compilation_config.oot_compiler = cls.get_compile_backend()
if compilation_config.cudagraph_mode == CUDAGraphMode.NONE: if compilation_config.cudagraph_mode == CUDAGraphMode.NONE:
compilation_config.mode = CompilationMode.NONE compilation_config.mode = CompilationMode.NONE