[Feature] Support npuhraph_ex backend (#4700)

### What this PR does / why we need it?
We introduced the npugraph_ex backend through the vllm's adaptor
dispatch mechanism to accelerate aclgraph. This solution is based on
torch.compile and uses torchair to optimize the fx.graph. The
performance gains are mainly obtained from the static kernel. We
conducted tests on Qwen3-30B and achieved over 5% performance
optimization.

### Does this PR introduce _any_ user-facing change?
Yes, we add a new switch named"enable_npugraph_ex" in additional_config,
default is False.
We also add an example to show how to register custom replacement pass

### More information about this PR
This feature depends on the release of CANN and torch_npu in Q4. 
We tested it on a package that has not been publicly released yet and
verified that the functionality works.
This feature is still experimental at the moment; setting the config
true will directly raise error.
Merging into the main branch initially involves some preliminary commits
to facilitate subsequent development and testing of the feature, as well
as to avoid submitting an excessively large PR at once.

- vLLM version: v0.12.0
- vLLM main:
ad32e3e19c

---------

Signed-off-by: chencangtao <chencangtao@huawei.com>
Signed-off-by: ChenCangtao <50493711+ChenCangtao@users.noreply.github.com>
Co-authored-by: chencangtao <chencangtao@huawei.com>
Co-authored-by: panchao-hub <315134829@qq.com>
Co-authored-by: wbigat <wbigat@163.com>
Co-authored-by: Mengqing Cao <cmq0113@163.com>
This commit is contained in:
ChenCangtao
2025-12-10 20:48:05 +08:00
committed by GitHub
parent d7db6791e7
commit dd622aa6a6
7 changed files with 235 additions and 15 deletions

View File

@@ -0,0 +1,123 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# This file is a part of the vllm-ascend project.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import functools
import sys
import torch
from torch._inductor.pattern_matcher import Match
from vllm.logger import logger
@functools.lru_cache(None)
# The replacement registered here will be actually executed after AOT.
def _register_replacement(epsilon):
if 'torch_npu' not in sys.modules:
logger.info(
'The AddRMSNormQuant fusion will only be enabled in a torch npu env.'
'When there is no torch_npu in the env, skip fusion.')
return
def _extra_stream_scope_check(match: Match) -> bool:
"""
Checks if all nodes in the same stream.
"""
non_default_streams = set()
has_default = False
for node in match.nodes:
if node.op == "call_function":
current_stream = node.meta.get("stream_label")
if current_stream is None:
has_default = True
else:
non_default_streams.add(current_stream)
if len(non_default_streams) > 1:
logger.debug(
f"Cross-stream operation detected in pattern match for AddRMSNormQuant. "
f"Multiple streams found: {non_default_streams}. "
f"Fusion is not supported for cross-stream operations."
)
return False
if has_default and len(non_default_streams) > 0:
logger.debug(
f"Cross-stream operation detected in pattern match for AddRMSNormQuant. "
f"Multiple streams found: {non_default_streams}. "
f"Fusion is not supported for cross-stream operations.")
return False
return True
def pattern(rms_norm_input: torch.Tensor, residual: torch.Tensor,
rms_norm_weight: torch.Tensor, scale: torch.Tensor,
offset: torch.Tensor):
"""
Pattern for AddRMSNormQuant fusion.
"""
output = torch.ops.npu.npu_add_rms_norm(rms_norm_input, residual,
rms_norm_weight, epsilon)
out0 = output[0]
out1 = output[2]
quantized_output = torch.ops.npu.npu_quantize(out0, scale, offset,
torch.qint8, -1, False)
return quantized_output, out1
def replacement(rms_norm_input: torch.Tensor, residual: torch.Tensor,
rms_norm_weight: torch.Tensor, scale: torch.Tensor,
offset: torch.Tensor):
"""
Replacement for the AddRMSNormQuant fusion.
"""
output = torch.ops.npu.npu_add_rms_norm_quant(
rms_norm_input,
residual,
rms_norm_weight,
# The inverse of scale is required by npu_add_rms_norm_quant kernel which is opposite to the npu_quantize kernel.
1. / scale,
offset,
epsilon=epsilon)
quantized_output = output[0]
out1 = output[2]
return quantized_output, out1
def get_inputs():
"""
Generate example inputs for the AddRMSNormQuant fusion pattern.
"""
rms_norm_input = torch.randn(2, 4, device="npu")
residual = torch.randn(2, 4, device="npu")
rms_norm_weight = torch.randn(4, device="npu")
scale = torch.tensor([1.0], device="npu")
offset = torch.tensor([0.0], device="npu")
return [rms_norm_input, residual, rms_norm_weight, scale, offset]
import torchair
torchair.register_replacement(search_fn=pattern,
replace_fn=replacement,
example_inputs=get_inputs(),
extra_check=_extra_stream_scope_check)
# register converter for pass
common_epsilons = [1e-5, 1e-6]
for eps in common_epsilons:
logger.info(
f"Start register fusion pattern for AddRMSNormQuant with epsilons={eps}"
)
_register_replacement(eps)