Adopt inductor fusion and define quantization fusion pass (#4168)

### What this PR does / why we need it?
The main goal of this PR is to reduce the maintenance burden caused by model duplication during model optimization. Some of our optimized models diverge only slightly from vLLM's modeling code, yet they still require rewriting several parts of the original, which brings a non-negligible maintenance burden to vllm-ascend. To address this, we leverage `torch.compile` and the inductor pattern matcher to automatically fuse the patterns we want to merge. For more details, see the RFC: https://github.com/vllm-project/vllm-ascend/issues/4239

This PR fuses the `AddRMSNorm` and `Quant` operators, which improves inference speed for models using `w8a8` quantization.

### Does this PR introduce _any_ user-facing change?
Yes. A new `ascend_compilation_config` section is added to `additional_config`, with an `enable_quantization_fusion` switch (default `True`).
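For reference, a hedged sketch of how the new knob can be passed through vLLM's `additional_config` (the model id and exact `LLM` kwargs here are illustrative; fusion is on by default, so passing it explicitly is mainly useful for turning it off):

```python
from vllm import LLM

# Illustrative only: "ascend_compilation_config" / "enable_quantization_fusion"
# are introduced by this PR; the default is already True.
llm = LLM(
    model="vllm-ascend/Qwen3-8B-W8A8",
    quantization="ascend",
    additional_config={
        "ascend_compilation_config": {
            "enable_quantization_fusion": True,
        },
    },
)
```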

### How was this patch tested?
```python
from vllm import LLM, SamplingParams


def main():
    prompts = [
        "The president of the United States is Mr.",
    ]

    # Create a sampling params object.
    sampling_params = SamplingParams(max_tokens=100, temperature=0.6, top_k=40, top_p=0.95)
    # Create an LLM.
    llm = LLM(
        model="/root/.cache/modelscope/hub/models/vllm-ascend/Qwen3-8B-W8A8",
        # enforce_eager=True,
        tensor_parallel_size=1,
        trust_remote_code=True,
        gpu_memory_utilization=0.7,
        quantization="ascend",
    )

    # Generate texts from the prompts.
    outputs = llm.generate(prompts, sampling_params)
    for output in outputs:
        prompt = output.prompt
        generated_text = output.outputs[0].text
        print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")


if __name__ == "__main__":
    main()
```

```text
Prompt: 'The president of the United States is Mr.', Generated text: ' Trump. The president of the United States is Mr. Biden. Which of the following statements is correct? \n\nA. Mr. Trump is Mr. Biden.  \nB. Mr. Trump is not Mr. Biden.  \nC. The president of the United States is not Mr. Trump.  \nD. The president of the United States is not Mr. Biden.\n\nThe question presents a contradiction: it states that "The president of the United States is Mr. Trump" and "The president of'
```


- vLLM version: 86e178f7c4d8c3b0eaf3c8e3f810a83f63b90e24
- vLLM main: 86e178f7c4

---------

Signed-off-by: Icey <1790571317@qq.com>
Signed-off-by: wxsIcey <1790571317@qq.com>
Author: Icey
Date: 2025-12-04 10:29:48 +08:00
Committed by: GitHub
Commit: 178ca1607e (parent c4a71fc6d5)
13 changed files with 593 additions and 267 deletions


@@ -36,9 +36,15 @@ class AscendConfig:
        additional_config = vllm_config.additional_config if vllm_config.additional_config is not None else {}
        torchair_graph_config = additional_config.get("torchair_graph_config",
                                                      {})
        self.torchair_graph_config = TorchairGraphConfig(
            torchair_graph_config, vllm_config, additional_config)
        ascend_compilation_config = additional_config.get(
            "ascend_compilation_config", {})
        self.ascend_compilation_config = AscendCompilationConfig(
            **ascend_compilation_config)
        ascend_scheduler_config = additional_config.get(
            "ascend_scheduler_config", {})
        self.ascend_scheduler_config = AscendSchedulerConfig(
@@ -144,6 +150,31 @@ class AscendConfig:
            self, vllm_config)


class AscendCompilationConfig:
    """
    Configuration for controlling the behavior of Ascend graph optimization.
    This class provides a way to configure graph fusion optimizations.
    These configurations directly impact the performance and behavior of models
    deployed on Ascend platforms.
    """

    def __init__(self, enable_quantization_fusion: bool = True, **kwargs):
        """
        Initialize the configuration.

        Args:
            enable_quantization_fusion (bool): Whether to enable quantization fusion optimization.
                When set to True, the system will optimize quantization-related operations,
                reducing the number of quantization/dequantization nodes.
                Default: True
            **kwargs: Additional optional parameters for forward compatibility and configuration extension.
        """
        self.enable_quantization_fusion = enable_quantization_fusion
        # Add more compilation related configs here as needed


class TorchairGraphConfig:
    """
    Configuration Object for torchair_graph_config from additional_config
@@ -326,6 +357,11 @@ def check_ascend_config(vllm_config, enforce_eager):
"it has been disabled automatically.")
# aclgraph case
else:
if ascend_config.ascend_compilation_config.enable_quantization_fusion:
logger.info(
"Quantization fusion enabled! op fusion on quantization are expected. "
)
if vllm_config.model_config:
model_type = vllm_config.model_config.hf_config.model_type
if "qwen" not in model_type:


@@ -159,25 +159,6 @@ def set_ascend_forward_context(
        forward_context.weight_prefetch_method = weight_prefetch_method
        forward_context.is_mtp_model = is_mtp_model

        # TODO(rjg-lyh): The current implementation is somewhat brute force and not elegant.
        # It will be improved later by implementing operator fusion through the FX graph.
        #
        # set for addrmsnorm+quant fusion.
        # this optim now just support dense models due to the specific operators used.
        # Once the necessary conditions are met, support for MOE models will also be added.
        from vllm_ascend.quantization.quant_config import AscendQuantConfig
        model_type_scope = ["llama", "qwen2", "qwen3", "qwen3_moe"]
        addrmsnorm_quant_fusion_enabled = isinstance(vllm_config.quant_config, AscendQuantConfig) and \
            vllm_config.model_config.hf_config.model_type in model_type_scope and \
            forward_context.layer_idx is not None
        if addrmsnorm_quant_fusion_enabled:
            forward_context.model_instance = model_instance
            forward_context.num_hidden_layers = vllm_config.model_config.hf_config.num_hidden_layers
            forward_context.fusion_linear = "gate_up_dense" if forward_context.layer_idx == 0 else "qkv_dense"
            if vllm_config.model_config.hf_config.model_type == "qwen3_moe":
                forward_context.fusion_linear = "gate_moe" if forward_context.layer_idx == 0 else "qkv_moe"
        forward_context.addrmsnorm_quant_fusion_enabled = addrmsnorm_quant_fusion_enabled

        if num_tokens is None and attn_metadata is not None:
            num_tokens = attn_metadata.num_actual_tokens


@@ -0,0 +1,73 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# This file is a part of the vllm-ascend project.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import functools
from typing import Any, Callable, Optional

import torch.fx as fx
from torch._dynamo.backends.common import aot_autograd
from torch._inductor.compile_fx import (graph_returns_tuple,
                                         make_graph_return_tuple)
from torch._inductor.decomposition import select_decomp_table
from torch.fx import GraphModule
from vllm.compilation.compiler_interface import CompilerInterface


def compile_fx(graph: GraphModule, example_inputs: list,
               inner_compile: Callable, decompositions: dict) -> Callable:
    recursive_compile_fx = functools.partial(compile_fx,
                                             inner_compile=inner_compile,
                                             decompositions=decompositions)
    if not graph_returns_tuple(graph):
        return make_graph_return_tuple(graph, example_inputs,
                                       recursive_compile_fx)
    return aot_autograd(fw_compiler=inner_compile)(graph, example_inputs)


class AscendCompiler(CompilerInterface):
    """
    AscendCompiler is a custom compiler interface for the Ascend platform.
    This class provides a method to compile a PyTorch FX graph module with
    specific configurations for graph fusion and decomposition.
    """
    name = "AscendCompiler"

    def compile(
        self,
        graph: fx.GraphModule,
        example_inputs: list[Any],
        compiler_config: dict[str, Any],
        runtime_shape: Optional[int] = None,
        key: Optional[str] = None,
    ) -> tuple[Optional[Callable], Optional[Any]]:

        def compile_inner(graph, example_inputs):
            current_pass_manager = compiler_config["graph_fusion_manager"]
            graph = current_pass_manager(graph, runtime_shape)
            return graph

        decompositions = select_decomp_table()
        compiled_fn = compile_fx(
            graph=graph,
            example_inputs=example_inputs,
            inner_compile=compile_inner,
            decompositions=decompositions,
        )
        return compiled_fn, None

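As a side note, the `aot_autograd(fw_compiler=...)` mechanism used above can be exercised standalone. A minimal sketch (pure PyTorch, CPU; `my_fw_compiler` and `f` are hypothetical names, not part of this PR) showing how a custom forward compiler receives the FX graph that a pass manager would rewrite:

```python
import torch
from torch._dynamo.backends.common import aot_autograd


def my_fw_compiler(gm: torch.fx.GraphModule, example_inputs):
    # A real pass manager (GraphFusionPassManager above) would mutate gm.graph
    # here; this sketch only inspects it and returns the unmodified forward.
    print(gm.graph)
    return gm.forward


backend = aot_autograd(fw_compiler=my_fw_compiler)


@torch.compile(backend=backend)
def f(x):
    return torch.relu(x) + 1


print(f(torch.randn(4)))
```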

@@ -0,0 +1,53 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# This file is a part of the vllm-ascend project.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from torch import fx as fx
from vllm.compilation.vllm_inductor_pass import VllmInductorPass
from vllm.config import VllmConfig


class GraphFusionPassManager:
    """
    A pass manager for graph fusion passes.
    It handles the configuration and execution of passes.
    The counterpart in vllm is PostGradPassManager. Since torch_npu
    does not support triton for now, we define our own pass manager.
    """

    def __init__(self):
        self.passes: list[VllmInductorPass] = []

    def __call__(self, graph: fx.Graph, runtime_shape) -> fx.Graph:
        for pass_ in self.passes:
            if pass_.is_applicable(runtime_shape):
                pass_(graph)
        return graph

    def add(self, pass_: VllmInductorPass):
        assert isinstance(pass_, VllmInductorPass)
        self.passes.append(pass_)

    def configure(self, config: VllmConfig):
        # By default, we enable the graph fusion and quantization fusion pass.
        self.ascend_compilation_config: dict = config.additional_config.get(
            "ascend_compilation_config", {})
        if self.ascend_compilation_config.get("enable_quantization_fusion",
                                              True):
            from .passes.quant_fusion_pass import AddRMSNormQuantFusionPass
            self.passes.append(AddRMSNormQuantFusionPass(config))
        # Add more passes here as needed

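To make the manager's contract concrete, here is a hedged, self-contained stand-in (plain `torch.fx`; `_DemoPass` and `run_passes` are hypothetical names, the real manager stores `VllmInductorPass` instances) showing the same `is_applicable` gating and in-place graph mutation:

```python
import torch
import torch.fx as fx


class _DemoPass:

    def is_applicable(self, runtime_shape):
        # Pretend the pass only supports statically known shapes.
        return runtime_shape is not None

    def __call__(self, graph: fx.Graph):
        # Rewrite relu -> sigmoid, purely to demonstrate a graph mutation.
        for node in graph.nodes:
            if node.op == "call_function" and node.target is torch.relu:
                node.target = torch.sigmoid
        graph.lint()


def run_passes(graph: fx.Graph, runtime_shape, passes):
    for p in passes:
        if p.is_applicable(runtime_shape):
            p(graph)
    return graph


gm = fx.symbolic_trace(lambda x: torch.relu(x) + 1)
run_passes(gm.graph, runtime_shape=8, passes=[_DemoPass()])
gm.recompile()
print(gm.code)  # relu has been replaced by sigmoid
```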

@@ -0,0 +1,113 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# This file is a part of the vllm-ascend project.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import logging

import torch
import torch._inductor.pattern_matcher as pm
from torch._inductor.pattern_matcher import PatternMatcherPass
from vllm.compilation.vllm_inductor_pass import VllmInductorPass
from vllm.config import VllmConfig


class AddRMSNormQuantPattern:

    def __init__(self, vllm_config: VllmConfig, eps: float = 1e-6):
        self.vllm_config = vllm_config
        self.eps = eps

    def get_inputs(self):
        """
        Generate example inputs for the AddRMSNormQuant fusion pattern.
        """
        rms_norm_input = torch.randn(2, 4, device="npu")
        residual = torch.randn(2, 4, device="npu")
        rms_norm_weight = torch.randn(4, device="npu")
        scale = torch.tensor([1.0], device="npu")
        offset = torch.tensor([0.0], device="npu")
        return [rms_norm_input, residual, rms_norm_weight, scale, offset]

    def register(self, pm_pass: PatternMatcherPass):

        def pattern(rms_norm_input: torch.Tensor, residual: torch.Tensor,
                    rms_norm_weight: torch.Tensor, scale: torch.Tensor,
                    offset: torch.Tensor):
            """
            Pattern for AddRMSNormQuant fusion.
            """
            output = torch.ops.npu.npu_add_rms_norm(rms_norm_input, residual,
                                                    rms_norm_weight, self.eps)
            out0 = output[0]
            out1 = output[2]
            quantized_output = torch.ops.npu.npu_quantize(
                out0, scale, offset, torch.qint8, -1, False)
            return quantized_output, out1

        def replacement(rms_norm_input: torch.Tensor, residual: torch.Tensor,
                        rms_norm_weight: torch.Tensor, scale: torch.Tensor,
                        offset: torch.Tensor):
            """
            Replacement for the AddRMSNormQuant fusion.
            """
            output = torch.ops.npu.npu_add_rms_norm_quant(
                rms_norm_input,
                residual,
                rms_norm_weight,
                1. / scale,  # The inverse of scale is required by npu_add_rms_norm_quant kernel which is opposite to the npu_quantize kernel.
                offset,
                epsilon=self.eps)
            quantized_output = output[0]
            out1 = output[2]
            return quantized_output, out1

        pm.register_replacement(pattern, replacement, self.get_inputs(),
                                pm.fwd_only, pm_pass)


class AddRMSNormQuantFusionPass(VllmInductorPass):
    """
    A pass for fusing AddRMSNorm and W8A8 quantization operations on Ascend.
    """

    def __init__(self, vllm_config: VllmConfig):
        super().__init__(vllm_config)
        self.pattern_match_passes: PatternMatcherPass = PatternMatcherPass(
            pass_name="rmsnorm_quant_fusion_pass")
        dtype = vllm_config.model_config.dtype
        if dtype not in (torch.bfloat16, torch.float16):
            logging.info("Quant fusion not enabled: unsupported dtype %s",
                         dtype)
            return
        common_epsilons = [1e-5, 1e-6]
        for eps in common_epsilons:
            AddRMSNormQuantPattern(vllm_config,
                                   eps=eps).register(self.pattern_match_passes)

    def __call__(self, graph: torch.fx.Graph):
        self.begin()
        self.matched_count = self.pattern_match_passes.apply(graph)
        logging.debug("Replaced %s patterns", self.matched_count)
        self.end_and_log()

    def is_applicable(self, runtime_shape: int | None = None) -> bool:
        """
        Check if the pass is applicable for the current configuration.
        """
        return True

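The `1. / scale` in the replacement is the subtle part: the fused kernel interprets its scale argument inversely to `npu_quantize`. Below is a pure-PyTorch reference (CPU only, not the NPU kernels; function names are illustrative) of the add + RMSNorm + per-tensor int8 quantization being fused, showing that dividing by `scale` is the same as multiplying by `1 / scale`:

```python
import torch


def add_rms_norm_ref(x, residual, weight, eps=1e-6):
    # Reference of what the add + RMSNorm portion computes.
    added = x + residual
    norm = added * torch.rsqrt(added.pow(2).mean(-1, keepdim=True) + eps)
    return norm * weight, added


def quant_div(x, scale, offset):
    # Quantization written as a division by the scale.
    return torch.clamp(torch.round(x / scale + offset), -128, 127).to(torch.int8)


def quant_mul(x, inv_scale, offset):
    # The same quantization written as a multiplication by the inverse scale.
    return torch.clamp(torch.round(x * inv_scale + offset), -128, 127).to(torch.int8)


x, r = torch.randn(2, 4), torch.randn(2, 4)
w, scale, offset = torch.randn(4), torch.tensor(0.25), torch.tensor(0.0)

normed, added = add_rms_norm_ref(x, r, w)
assert torch.equal(quant_div(normed, scale, offset),
                   quant_mul(normed, 1.0 / scale, offset))
```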

@@ -19,70 +19,9 @@ from typing import Optional, Tuple, Union, cast
import torch
from vllm.config import get_current_vllm_config
from vllm.forward_context import get_forward_context
from vllm.model_executor.layers.layernorm import GemmaRMSNorm, RMSNorm


def _addrmsnorm_forward_oot(
        self,
        x: torch.Tensor,
        residual: torch.Tensor,
        layer: Optional[torch.nn.Module] = None,
        bias: Optional[torch.nn.Parameter] = None,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
    import torch_npu

    from vllm_ascend.utils import AscendDeviceType, get_ascend_device_type

    if layer is not None and get_ascend_device_type(
    ) != AscendDeviceType._310P:
        layer_cls_name = layer.__class__.__name__
        try:
            weight_prefetch_method = get_forward_context(
            ).weight_prefetch_method
        except AssertionError:
            weight_prefetch_method = None
        # prefetch qkvo_proj.weight preprocess
        if weight_prefetch_method:
            weight_prefetch_method.maybe_prefetch_attn_weight_preprocess(
                layer_cls_name=layer_cls_name,
                weight=layer.weight,
                start_flag=x,
            )
        # add_rms_norm_quant
        x, _, residual = torch_npu.npu_add_rms_norm_quant(
            x,
            residual,
            self.weight,
            layer.aclnn_input_scale,
            layer.aclnn_input_offset,
            beta=bias,
            epsilon=self.variance_epsilon)
        # prefetch qkvo_proj.weight postprocess
        if weight_prefetch_method:
            weight_prefetch_method.maybe_prefetch_attn_weight_postprocess(
                layer_cls_name=layer_cls_name,
                stop_flag=x,
            )
    else:
        if get_ascend_device_type() == AscendDeviceType._310P:
            orig_dtype = residual.dtype
            x = x + residual.to(x.dtype)
            residual = x.to(orig_dtype)
            x, _ = torch_npu.npu_rms_norm(x, self.weight,
                                          self.variance_epsilon)
        else:
            x, _, residual = torch_npu.npu_add_rms_norm(
                x, residual, self.weight, self.variance_epsilon)
        if bias is not None:
            x.add_(bias)
    torch.ops.vllm.maybe_wait_prefetch_done(x)
    return x, residual


class AscendRMSNorm(RMSNorm):

    def __init__(
@@ -109,59 +48,27 @@ class AscendRMSNorm(RMSNorm):
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        import torch_npu

        from vllm_ascend.utils import AscendDeviceType, get_ascend_device_type

        if residual is not None:
            residual = torch.ops.vllm.maybe_chunk_residual(x, residual)
            assert x.size(0) == residual.size(0)
            x, residual = _addrmsnorm_forward_oot(
                self, x, residual, self.next_need_quant_fusion_linear,
                self.bias)
            if get_ascend_device_type() == AscendDeviceType._310P:
                orig_dtype = residual.dtype
                x = x + residual.to(x.dtype)
                residual = x.to(orig_dtype)
                x, _ = torch_npu.npu_rms_norm(x, self.weight,
                                              self.variance_epsilon)
            else:
                x, _, residual = torch_npu.npu_add_rms_norm(
                    x, residual, self.weight, self.variance_epsilon)
            if self.bias is not None:
                x.add_(self.bias)
            return x, residual
        x, residual = torch_npu.npu_rms_norm(x, self.weight,
                                             self.variance_epsilon)
        if self.bias is not None:
            x.add_(self.bias)
        return x

    @property
    def next_need_quant_fusion_linear(self):
        try:
            forward_context = get_forward_context()
            if not forward_context.addrmsnorm_quant_fusion_enabled or \
                forward_context.layer_idx == forward_context.num_hidden_layers:
                return None
        except AssertionError:
            return None

        next_linear = None
        model_instance = forward_context.model_instance
        layer_idx = forward_context.layer_idx
        fusion_linear = forward_context.fusion_linear
        next_linear = None
        if fusion_linear == "qkv_dense":
            next_linear = model_instance.model.layers[
                layer_idx].self_attn.qkv_proj
            forward_context.fusion_linear = "gate_up_dense"
        elif fusion_linear == "gate_up_dense":
            next_linear = model_instance.model.layers[
                layer_idx].mlp.gate_up_proj
            forward_context.fusion_linear = "qkv_dense"
            # if prefetch_mlp_weight enabled, following accumulation operation
            # does not need to be repeated
            if not forward_context.prefetch_mlp_enabled:
                forward_context.layer_idx += 1
        elif fusion_linear == "qkv_moe":
            next_linear = model_instance.model.layers[
                layer_idx].self_attn.qkv_proj
            forward_context.fusion_linear = "gate_moe"
        elif fusion_linear == "gate_moe":
            forward_context.fusion_linear = "qkv_moe"
            forward_context.layer_idx += 1
        from vllm_ascend.quantization.w8a8 import AscendW8A8LinearMethod
        if next_linear is not None and \
            not isinstance(next_linear.quant_method.quant_method, AscendW8A8LinearMethod):
            next_linear = None
        return next_linear


class AscendQuantRMSNorm(AscendRMSNorm):

@@ -73,7 +73,10 @@ def _rope_forward_oot(
        query = query.contiguous().view(1, query.shape[0], -1,
                                        self.head_size)
        key = key.contiguous().view(1, key.shape[0], -1, self.head_size)
        torch_npu.npu_apply_rotary_pos_emb(query, key, self.cos, self.sin)
        # Although this function modifies in-place, please retain the function's return value.
        # Otherwise, the graph fusion operation may fail.
        query, key = torch_npu.npu_apply_rotary_pos_emb(
            query, key, self.cos, self.sin)
    elif self.rotary_dim < self.head_size:
        num_tokens = query.shape[0]
        query = query.view(num_tokens, -1, self.head_size)

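The new comment in this hunk is about keeping the rotated tensors visible as outputs of the op in the traced graph. A small illustrative sketch (plain `torch.fx`, not the rope code; one plausible reading of why the return value matters for fusion):

```python
import torch
import torch.fx as fx


def drops_result(x):
    x.add_(1)  # in-place, return value ignored
    return x   # the graph returns the original placeholder node


def keeps_result(x):
    x = x.add_(1)  # same kernel, but the returned tensor is what flows onward
    return x


print(fx.symbolic_trace(drops_result).graph)
print(fx.symbolic_trace(keeps_result).graph)
# In the first graph the add_ node has no users, so a downstream pattern cannot
# attach to the updated tensor; reassigning the return value keeps the chain intact.
```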

@@ -66,6 +66,32 @@ class NPUPlatform(Platform):
    def is_sleep_mode_available(self) -> bool:
        return True

    @property
    def pass_key(self) -> str:
        """
        Inductor config key for the PassManager custom pass, for example 'post_grad_custom_post_pass'.
        It is a parameter of inductor_config used to register custom passes.
        Currently, we only use Inductor's 'pattern matcher' functionality, so we define our own pass_key.
        """
        return "graph_fusion_manager"

    @classmethod
    def get_pass_manager_cls(cls) -> str:
        """
        Get the pass manager class for this platform.
        It will be registered as a custom pass under the current_platform.pass_key.
        """
        return "vllm_ascend.compilation.graph_fusion_pass_manager.GraphFusionPassManager"

    @classmethod
    def get_compile_backend(self) -> str:
        """
        Get the custom compile backend. Previously, we used EagerAdaptor by default.
        To use graph fusion operations, we defined our own backend compiler.
        """
        from vllm_ascend.compilation.compiler_interface import AscendCompiler
        return AscendCompiler.__module__ + "." + AscendCompiler.__name__

    @classmethod
    def pre_register_and_update(cls,
                                parser: Optional[FlexibleArgumentParser] = None
@@ -135,6 +161,13 @@ class NPUPlatform(Platform):
        parallel_config = vllm_config.parallel_config
        cache_config = vllm_config.cache_config
        ascend_scheduler_config = ascend_config.ascend_scheduler_config
        ascend_compilation_config = ascend_config.ascend_compilation_config
        if ascend_compilation_config:
            vllm_config.additional_config.setdefault(
                "ascend_compilation_config", {}).update(
                    vars(ascend_compilation_config
                         ) if not isinstance(ascend_compilation_config, dict)
                    else ascend_compilation_config)

        kv_cache_dtype = vllm_config.additional_config.get(
            "kv_cache_dtype", None)
@@ -214,6 +247,9 @@ class NPUPlatform(Platform):
        if compilation_config.cudagraph_mode == CUDAGraphMode.FULL_AND_PIECEWISE:
            compilation_config.cudagraph_mode = CUDAGraphMode.PIECEWISE

        from vllm_ascend.compilation.compiler_interface import AscendCompiler
        compilation_config.oot_compiler = AscendCompiler.__module__ + "." + AscendCompiler.__name__

        if compilation_config.cudagraph_mode == CUDAGraphMode.NONE:
            compilation_config.mode = CompilationMode.NONE
        elif compilation_config.cudagraph_mode == CUDAGraphMode.PIECEWISE: