[npugraph_ex] enable npugraph_ex by default (#6664)

### What this PR does / why we need it?

This pull request enables the `npugraph_ex` backend by default to
improve performance on Ascend NPUs, as proposed in the
[RFC](https://github.com/vllm-project/vllm-ascend/issues/6214).


### Does this PR introduce _any_ user-facing change?

Yes. `npugraph_ex` is now enabled by default. Users can disable it by
setting `enable: false` in the `npugraph_ex_config` section of the
`additional_config`.
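
For illustration, a minimal sketch of opting out from Python (the model name is a placeholder; the `additional_config` layout mirrors the E2E test further down):

```python
# Sketch only: explicitly opt out of the npugraph_ex backend now that it defaults to on.
from vllm import LLM

llm = LLM(
    model="Qwen/Qwen2.5-7B-Instruct",  # placeholder model name
    additional_config={
        "npugraph_ex_config": {
            "enable": False,
        }
    },
)
```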

### How was this patch tested?

CI passed. The changes are covered by existing and new E2E tests
(`test_aclgraph_accuracy.py`) and unit tests (`test_ascend_config.py`)
that have been updated to reflect the new default behavior. The tests
verify correctness and consistency with `npugraph_ex` enabled and
disabled, as well as with the new static kernel option.

Signed-off-by: huyuanquan1 <huyuanquan1@huawei.com>
Co-authored-by: huyuanquan1 <huyuanquan1@huawei.com>
iiiklw authored on 2026-02-12 08:44:06 +08:00; committed by GitHub
parent b86ea66b0a
commit a0315f6697
10 changed files with 159 additions and 9 deletions


@@ -94,7 +94,7 @@ The details of each configuration option are as follows:
| Name | Type | Default | Description |
|------------------------| ---- |---------|----------------------------------------------------------------------------------------|
| `enable` | bool | `False` | Whether to enable npugraph_ex backend. |
| `enable` | bool | `True` | Whether to enable npugraph_ex backend. |
| `enable_static_kernel` | bool | `False` | Whether to enable static kernel. Suitable for scenarios where shape changes are minimal and some time is available for static kernel compilation. |
| `fuse_norm_quant` | bool | `True` | Whether to enable fuse_norm_quant pass. |
| `fuse_qknorm_rope` | bool | `True` | Whether to enable fuse_qknorm_rope pass. If Triton is not in the environment, set it to False. |
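
For reference, a hedged sketch of an `additional_config` that spells out the options documented in this table with their post-change defaults (with `vllm serve`, the same structure can typically be passed as JSON via `--additional-config`):

```python
# Sketch of the documented npugraph_ex options with their post-PR defaults.
additional_config = {
    "npugraph_ex_config": {
        "enable": True,                 # new default introduced by this PR
        "enable_static_kernel": False,  # only worthwhile when shapes barely change
        "fuse_norm_quant": True,
        "fuse_qknorm_rope": True,       # set to False if Triton is not available
    }
}
```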


@@ -147,6 +147,11 @@ def test_full_decode_only_res_consistency(cur_case: LLMTestCase, monkeypatch):
"cudagraph_mode": "FULL_DECODE_ONLY"
},
"quantization": cur_case.quantization,
"additional_config": {
"npugraph_ex_config": {
"enable": False
}
},
}
gen_and_valid(runner_kwargs=runner_kwargs,
prompts=cur_case.prompts,


@@ -67,7 +67,7 @@ class TestAscendConfig(TestBase):
self.assertTrue(ascend_config.multistream_overlap_shared_expert)
npugraph_ex_config = ascend_config.npugraph_ex_config
self.assertFalse(npugraph_ex_config.enable)
self.assertTrue(npugraph_ex_config.enable)
self.assertFalse(npugraph_ex_config.enable_static_kernel)
ascend_compilation_config = ascend_config.ascend_compilation_config


@@ -260,7 +260,7 @@ class NpugraphExConfig:
def __init__(
self,
enable: bool = False,
enable: bool = True,
enable_static_kernel: bool = False,
fuse_norm_quant: bool = True,
fuse_qknorm_rope: bool = True,
@@ -274,7 +274,7 @@ class NpugraphExConfig:
enable (bool): Whether to enable npugraph_ex backend.
When set to True, the Fx graph generated by Dymano will be
optimized and compiled by the npugraph_ex backend.
Default: False
Default: True
enable_static_kernel (bool): Whether to enable static kernel.
Static kernel is suitable for scenarios with purely static shapes
or minimal shape changes, and can improve network performance.


@@ -88,7 +88,7 @@ def npugraph_ex_compile(
# that can trigger the compilation of static kernel. If this configuration is
# not applied, new shapes will trigger the compilation of static kernels,
# affecting program execution.
num_spec_tokens = vllm_config.speculative_config.num_speculative_token if vllm_config.speculative_config else 0
num_spec_tokens = vllm_config.speculative_config.num_speculative_tokens if vllm_config.speculative_config else 0
uniform_decode_query_len = num_spec_tokens + 1
max_num_tokens = vllm_config.scheduler_config.max_num_seqs * uniform_decode_query_len
decode_cudagraph_batch_sizes = [

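As a worked illustration of the shape bookkeeping above (the concrete numbers are assumptions, not values from this PR):

```python
# Illustrative calculation only; num_spec_tokens and max_num_seqs are assumed values.
num_spec_tokens = 1                                       # e.g. one speculative (MTP) token per step
uniform_decode_query_len = num_spec_tokens + 1            # 2 tokens queried per sequence during decode
max_num_seqs = 128
max_num_tokens = max_num_seqs * uniform_decode_query_len  # 256, the largest uniform decode batch
# Declaring these sizes up front lets static kernels be compiled ahead of time,
# so a previously unseen shape does not trigger compilation in the middle of a run.
```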

@@ -19,6 +19,7 @@
from torch import fx as fx
from vllm.config import VllmConfig
from vllm_ascend.ascend_config import get_ascend_config
from vllm_ascend.utils import vllm_version_is
if vllm_version_is("0.15.0"):
@@ -55,18 +56,19 @@ class NpuGraphEXPassManager:
def configure(self, config: VllmConfig):
# By default, we enable the graph fusion and quantization fusion pass.
self.npugraph_ex_config: dict = config.additional_config.get("npugraph_ex_config", {})
if self.npugraph_ex_config.get("fuse_norm_quant", True):
self.npugraph_ex_config = get_ascend_config().npugraph_ex_config
if self.npugraph_ex_config.fuse_norm_quant:
from .npugraph_ex_passes.graphex_norm_quant_fusion_pass import GraphEXAddRMSNormFusionPass
self.passes.append(GraphEXAddRMSNormFusionPass(config))
if self.npugraph_ex_config.get("fuse_qknorm_rope", True):
if self.npugraph_ex_config.fuse_qknorm_rope:
from .npugraph_ex_passes.graphex_qknorm_rope_fusion_pass import GraphEXQKNormRopeFusionPass
self.passes.append(GraphEXQKNormRopeFusionPass(config))
if self.npugraph_ex_config.get("fuse_allreduce_rms", True):
if self.npugraph_ex_config.fuse_allreduce_rms:
from .npugraph_ex_passes.graphex_allreduce_rmsnorm_fusion_pass import GraphEXMatmulAllReduceAddRMSNormPass
self.passes.append(GraphEXMatmulAllReduceAddRMSNormPass(config))


@@ -249,3 +249,17 @@
# make unquantized_gemm as a customop.
# Future Plan:
# Remove this patch when vLLM support the operator as customop.
#
# ** 13. File: worker/patch_npugraph_ex_triton.py**
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# 1. `torchair.core._concrete_graph.ValuePack`,
# `torchair.npu_fx_compiler._unpack_meta`,
# `torchair.npu_fx_compiler._NpuGraphConverter._unpack_npu`
# Why:
# In the Triton scenario, the npugraph_ex backend needs to process the ValuePack of the input parameters.
# How:
# Supplement the relevant processing logic through patches.
# Related PR (if no, explain why):
# https://gitcode.com/Ascend/torchair/pull/2575
# Future Plan:
# Remove this patch when the PTA version used by vllm-ascend has been upgraded.


@@ -33,3 +33,4 @@ import vllm_ascend.patch.worker.patch_rejection_sampler # noqa
import vllm_ascend.patch.worker.patch_qwen3_next # noqa
import vllm_ascend.patch.worker.patch_v2_egale # noqa
import vllm_ascend.patch.worker.patch_huanyuan_vl # noqa
import vllm_ascend.patch.worker.patch_npugraph_ex_triton # noqa


@@ -0,0 +1,116 @@
#
# Copyright (c) 2026 Huawei Technologies Co., Ltd. All Rights Reserved.
# This file is a part of the vllm-ascend project.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import importlib
import sys
import torch
import torchair
from torch._subclasses.fake_tensor import FakeTensor
from torchair.core._concrete_graph import _is_symlist
from torchair.npu_fx_compiler import _unpack_meta_list
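
# Drop-in replacement for torchair.core._concrete_graph.ValuePack (installed below):
# pairs a "meta" value (fake/symbolic tensor) with its NPU-graph counterpart and
# adds __getitem__ so a dict-typed meta can be indexed directly.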
class ValuePack:
def __init__(self, meta, npu_meta=None) -> None:
self._meta = meta
self._npu_meta = meta if npu_meta is None else npu_meta
@property
def meta(self):
return self._meta
@property
def npu(self):
return self._npu_meta
def __getitem__(self, key):
if isinstance(self._meta, dict):
return self._meta.get(key)
raise ValueError(f"Unsupported meta type for ValuePack __getitem__, key:{key}, type: {type(self._meta)}")
def __repr__(self) -> str:
if isinstance(self._meta, FakeTensor):
meta_str = f"FakeTensor(dtype={self._meta.dtype}, size={list(self._meta.size())}"
elif isinstance(self._meta, torch.Tensor):
meta_str = f"torch.Tensor(dtype={self._meta.dtype}, size={list(self._meta.size())}"
elif isinstance(self._meta, torch.SymInt):
meta_str = f"torch.SymInt({self._meta})"
else:
try:
meta_str = f"{type(self._meta)}({self._meta})"
except Exception:
meta_str = f"{type(self._meta)}"
return f"Pack(meta:{meta_str} npu:{self._npu_meta})"
def _unpack_meta(args, kwargs):
unpacked_args = []
unpacked_kwargs = {}
def _get_meta_part(arg):
if isinstance(arg, (list, tuple)) and any(isinstance(v, ValuePack) for v in arg):
return _unpack_meta_list(arg)
elif isinstance(arg, dict):
return {k: v.meta if isinstance(v, ValuePack) else v for k, v in arg.items()}
elif isinstance(arg, ValuePack):
return arg.meta
else:
return arg
for arg in args:
unpacked_args.append(_get_meta_part(arg))
for key, value in kwargs.items():
unpacked_kwargs[key] = _get_meta_part(value)
return list(unpacked_args), unpacked_kwargs
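
# Bound as _NpuGraphConverter._unpack_npu (see the assignment at the bottom of this
# file): extracts the NPU-graph side of args/kwargs, converting symint lists through
# self._graph.parse_symlist and unwrapping ValuePacks in lists/tuples and dict values.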
def _unpack_npu(self, args, kwargs):
unpacked = []
unpacked_kwargs = {}
def _get_npu_part(arg):
if isinstance(arg, (list, tuple)) and len(arg):
if _is_symlist(arg):
arg = self._graph.parse_symlist(arg)
else:
arg = [(v.npu if isinstance(v, ValuePack) else v) for v in arg]
return arg
elif isinstance(arg, dict):
return {k: v.npu if isinstance(v, ValuePack) else v for k, v in arg.items()}
elif isinstance(arg, ValuePack):
return arg.npu
else:
return arg
for arg in args:
unpacked.append(_get_npu_part(arg))
for key, value in kwargs.items():
unpacked_kwargs[key] = _get_npu_part(value)
return unpacked, unpacked_kwargs
torchair.core._concrete_graph.ValuePack = ValuePack
# The ValuePack class is referenced in these two modules, and after the patch, these two modules need to be reloaded.
importlib.reload(sys.modules["torchair.fx_summary"])
importlib.reload(sys.modules["torchair.npu_fx_compiler"])
torchair.npu_fx_compiler._unpack_meta = _unpack_meta
torchair.npu_fx_compiler._NpuGraphConverter._unpack_npu = _unpack_npu


@@ -210,6 +210,18 @@ class NPUPlatform(Platform):
"{new_compile_ranges_split_points} for matmul and allreduce fusion"
)
npugraph_ex_config = ascend_config.npugraph_ex_config
if npugraph_ex_config and npugraph_ex_config.fuse_allreduce_rms:
from vllm_ascend.compilation.passes.allreduce_rmsnorm_fusion_pass import ALLREDUCE_NORM_FUSE_THREHOLD
new_compile_ranges_split_points = vllm_config.compilation_config.compile_ranges_split_points
new_compile_ranges_split_points.append(ALLREDUCE_NORM_FUSE_THREHOLD)
new_compile_ranges_split_points = sorted(new_compile_ranges_split_points)
vllm_config.compilation_config.compile_ranges_split_points = new_compile_ranges_split_points
logger.debug(
"set compile_ranges_split_points to {new_compile_ranges_split_points} for matmul and allreduce fusion"
)
elif model_config and hasattr(model_config.hf_text_config, "index_topk"):
vllm_config.cache_config.cache_dtype = str(model_config.dtype).replace("torch.", "")