[npugraph_ex]enable npugraph_ex by default (#6664)

### What this PR does / why we need it?

This pull request enables the `npugraph_ex` backend by default to
improve performance on Ascend NPUs, as proposed in the
[RFC](https://github.com/vllm-project/vllm-ascend/issues/6214).


### Does this PR introduce _any_ user-facing change?

Yes. `npugraph_ex` is now enabled by default. Users can disable it by
setting `enable: false` in the `npugraph_ex_config` section of the
`additional_config`.

### How was this patch tested?

CI passed. The changes are covered by existing and new E2E tests
(`test_aclgraph_accuracy.py`) and unit tests (`test_ascend_config.py`)
that have been updated to reflect the new default behavior. The tests
verify correctness and consistency with `npugraph_ex` enabled and
disabled, as well as with the new static kernel option.

Signed-off-by: huyuanquan1 <huyuanquan1@huawei.com>
Co-authored-by: huyuanquan1 <huyuanquan1@huawei.com>
This commit is contained in:
iiiklw
2026-02-12 08:44:06 +08:00
committed by GitHub
parent b86ea66b0a
commit a0315f6697
10 changed files with 159 additions and 9 deletions

View File

@@ -94,7 +94,7 @@ The details of each configuration option are as follows:
| Name | Type | Default | Description | | Name | Type | Default | Description |
|------------------------| ---- |---------|----------------------------------------------------------------------------------------| |------------------------| ---- |---------|----------------------------------------------------------------------------------------|
| `enable` | bool | `False` | Whether to enable npugraph_ex backend. | | `enable` | bool | `True` | Whether to enable npugraph_ex backend. |
| `enable_static_kernel` | bool | `False` | Whether to enable static kernel. Suitable for scenarios where shape changes are minimal and some time is available for static kernel compilation. | | `enable_static_kernel` | bool | `False` | Whether to enable static kernel. Suitable for scenarios where shape changes are minimal and some time is available for static kernel compilation. |
| `fuse_norm_quant` | bool | `True` | Whether to enable fuse_norm_quant pass. | | `fuse_norm_quant` | bool | `True` | Whether to enable fuse_norm_quant pass. |
| `fuse_qknorm_rope` | bool | `True` | Whether to enable fuse_qknorm_rope pass. If Triton is not in the environment, set it to False. | | `fuse_qknorm_rope` | bool | `True` | Whether to enable fuse_qknorm_rope pass. If Triton is not in the environment, set it to False. |

View File

@@ -147,6 +147,11 @@ def test_full_decode_only_res_consistency(cur_case: LLMTestCase, monkeypatch):
"cudagraph_mode": "FULL_DECODE_ONLY" "cudagraph_mode": "FULL_DECODE_ONLY"
}, },
"quantization": cur_case.quantization, "quantization": cur_case.quantization,
"additional_config": {
"npugraph_ex_config": {
"enable": False
}
},
} }
gen_and_valid(runner_kwargs=runner_kwargs, gen_and_valid(runner_kwargs=runner_kwargs,
prompts=cur_case.prompts, prompts=cur_case.prompts,

View File

@@ -67,7 +67,7 @@ class TestAscendConfig(TestBase):
self.assertTrue(ascend_config.multistream_overlap_shared_expert) self.assertTrue(ascend_config.multistream_overlap_shared_expert)
npugraph_ex_config = ascend_config.npugraph_ex_config npugraph_ex_config = ascend_config.npugraph_ex_config
self.assertFalse(npugraph_ex_config.enable) self.assertTrue(npugraph_ex_config.enable)
self.assertFalse(npugraph_ex_config.enable_static_kernel) self.assertFalse(npugraph_ex_config.enable_static_kernel)
ascend_compilation_config = ascend_config.ascend_compilation_config ascend_compilation_config = ascend_config.ascend_compilation_config

View File

@@ -260,7 +260,7 @@ class NpugraphExConfig:
def __init__( def __init__(
self, self,
enable: bool = False, enable: bool = True,
enable_static_kernel: bool = False, enable_static_kernel: bool = False,
fuse_norm_quant: bool = True, fuse_norm_quant: bool = True,
fuse_qknorm_rope: bool = True, fuse_qknorm_rope: bool = True,
@@ -274,7 +274,7 @@ class NpugraphExConfig:
enable (bool): Whether to enable npugraph_ex backend. enable (bool): Whether to enable npugraph_ex backend.
When set to True, the Fx graph generated by Dynamo will be When set to True, the Fx graph generated by Dynamo will be
optimized and compiled by the npugraph_ex backend. optimized and compiled by the npugraph_ex backend.
Default: False Default: True
enable_static_kernel (bool): Whether to enable static kernel. enable_static_kernel (bool): Whether to enable static kernel.
Static kernel is suitable for scenarios with purely static shapes Static kernel is suitable for scenarios with purely static shapes
or minimal shape changes, and can improve network performance. or minimal shape changes, and can improve network performance.

View File

@@ -88,7 +88,7 @@ def npugraph_ex_compile(
# that can trigger the compilation of static kernel. If this configuration is # that can trigger the compilation of static kernel. If this configuration is
# not applied, new shapes will trigger the compilation of static kernels, # not applied, new shapes will trigger the compilation of static kernels,
# affecting program execution. # affecting program execution.
num_spec_tokens = vllm_config.speculative_config.num_speculative_token if vllm_config.speculative_config else 0 num_spec_tokens = vllm_config.speculative_config.num_speculative_tokens if vllm_config.speculative_config else 0
uniform_decode_query_len = num_spec_tokens + 1 uniform_decode_query_len = num_spec_tokens + 1
max_num_tokens = vllm_config.scheduler_config.max_num_seqs * uniform_decode_query_len max_num_tokens = vllm_config.scheduler_config.max_num_seqs * uniform_decode_query_len
decode_cudagraph_batch_sizes = [ decode_cudagraph_batch_sizes = [

View File

@@ -19,6 +19,7 @@
from torch import fx as fx from torch import fx as fx
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm_ascend.ascend_config import get_ascend_config
from vllm_ascend.utils import vllm_version_is from vllm_ascend.utils import vllm_version_is
if vllm_version_is("0.15.0"): if vllm_version_is("0.15.0"):
@@ -55,18 +56,19 @@ class NpuGraphEXPassManager:
def configure(self, config: VllmConfig): def configure(self, config: VllmConfig):
# By default, we enable the graph fusion and quantization fusion pass. # By default, we enable the graph fusion and quantization fusion pass.
self.npugraph_ex_config: dict = config.additional_config.get("npugraph_ex_config", {}) self.npugraph_ex_config = get_ascend_config().npugraph_ex_config
if self.npugraph_ex_config.get("fuse_norm_quant", True):
if self.npugraph_ex_config.fuse_norm_quant:
from .npugraph_ex_passes.graphex_norm_quant_fusion_pass import GraphEXAddRMSNormFusionPass from .npugraph_ex_passes.graphex_norm_quant_fusion_pass import GraphEXAddRMSNormFusionPass
self.passes.append(GraphEXAddRMSNormFusionPass(config)) self.passes.append(GraphEXAddRMSNormFusionPass(config))
if self.npugraph_ex_config.get("fuse_qknorm_rope", True): if self.npugraph_ex_config.fuse_qknorm_rope:
from .npugraph_ex_passes.graphex_qknorm_rope_fusion_pass import GraphEXQKNormRopeFusionPass from .npugraph_ex_passes.graphex_qknorm_rope_fusion_pass import GraphEXQKNormRopeFusionPass
self.passes.append(GraphEXQKNormRopeFusionPass(config)) self.passes.append(GraphEXQKNormRopeFusionPass(config))
if self.npugraph_ex_config.get("fuse_allreduce_rms", True): if self.npugraph_ex_config.fuse_allreduce_rms:
from .npugraph_ex_passes.graphex_allreduce_rmsnorm_fusion_pass import GraphEXMatmulAllReduceAddRMSNormPass from .npugraph_ex_passes.graphex_allreduce_rmsnorm_fusion_pass import GraphEXMatmulAllReduceAddRMSNormPass
self.passes.append(GraphEXMatmulAllReduceAddRMSNormPass(config)) self.passes.append(GraphEXMatmulAllReduceAddRMSNormPass(config))

View File

@@ -249,3 +249,17 @@
# make unquantized_gemm as a customop. # make unquantized_gemm as a customop.
# Future Plan: # Future Plan:
# Remove this patch when vLLM support the operator as customop. # Remove this patch when vLLM support the operator as customop.
#
# ** 13. File: worker/patch_npugraph_ex_triton.py**
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# 1. `torchair.core._concrete_graph.ValuePack`,
# `torchair.npu_fx_compiler._unpack_meta`,
# `torchair.npu_fx_compiler._NpuGraphConverter._unpack_npu`
# Why:
# In the Triton scenario, npugraph_ex backend needs to process the value pack of the input parameters.
# How:
# Supplement the relevant processing logic through patches.
# Related PR (if no, explain why):
# https://gitcode.com/Ascend/torchair/pull/2575
# Future Plan:
# Remove this patch when the PTA version used by vllm-ascend has been upgraded.

View File

@@ -33,3 +33,4 @@ import vllm_ascend.patch.worker.patch_rejection_sampler # noqa
import vllm_ascend.patch.worker.patch_qwen3_next # noqa import vllm_ascend.patch.worker.patch_qwen3_next # noqa
import vllm_ascend.patch.worker.patch_v2_egale # noqa import vllm_ascend.patch.worker.patch_v2_egale # noqa
import vllm_ascend.patch.worker.patch_huanyuan_vl # noqa import vllm_ascend.patch.worker.patch_huanyuan_vl # noqa
import vllm_ascend.patch.worker.patch_npugraph_ex_triton # noqa

View File

@@ -0,0 +1,116 @@
#
# Copyright (c) 2026 Huawei Technologies Co., Ltd. All Rights Reserved.
# This file is a part of the vllm-ascend project.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import importlib
import sys
import torch
import torchair
from torch._subclasses.fake_tensor import FakeTensor
from torchair.core._concrete_graph import _is_symlist
from torchair.npu_fx_compiler import _unpack_meta_list
class ValuePack:
    """Pair a "meta" value with the NPU-side value that corresponds to it.

    This class is installed over ``torchair.core._concrete_graph.ValuePack``
    by the monkey-patch at the bottom of this module.

    Args:
        meta: The meta-side value (for example a ``FakeTensor``, a
            ``torch.Tensor``, a ``torch.SymInt`` or a ``dict``).
        npu_meta: The NPU-side value. When ``None``, ``meta`` is used for
            both sides.
    """

    def __init__(self, meta, npu_meta=None) -> None:
        self._meta = meta
        # Without an explicit NPU value the meta value doubles as the NPU one.
        self._npu_meta = meta if npu_meta is None else npu_meta

    @property
    def meta(self):
        """The meta-side value."""
        return self._meta

    @property
    def npu(self):
        """The NPU-side value (falls back to ``meta`` when none was given)."""
        return self._npu_meta

    def __getitem__(self, key):
        """Look up ``key`` in a dict-typed meta value.

        Returns:
            ``meta.get(key)``, i.e. ``None`` when the key is absent.

        Raises:
            ValueError: If the wrapped meta value is not a ``dict``.
        """
        if isinstance(self._meta, dict):
            return self._meta.get(key)
        raise ValueError(f"Unsupported meta type for ValuePack __getitem__, key:{key}, type: {type(self._meta)}")

    def __repr__(self) -> str:
        # FakeTensor is checked before torch.Tensor so the more specific
        # branch wins when both isinstance checks would match.
        if isinstance(self._meta, FakeTensor):
            meta_str = f"FakeTensor(dtype={self._meta.dtype}, size={list(self._meta.size())})"
        elif isinstance(self._meta, torch.Tensor):
            meta_str = f"torch.Tensor(dtype={self._meta.dtype}, size={list(self._meta.size())})"
        elif isinstance(self._meta, torch.SymInt):
            meta_str = f"torch.SymInt({self._meta})"
        else:
            try:
                meta_str = f"{type(self._meta)}({self._meta})"
            except Exception:
                # Formatting an arbitrary object may fail; fall back to
                # the type alone so repr() itself never raises here.
                meta_str = f"{type(self._meta)}"
        return f"Pack(meta:{meta_str} npu:{self._npu_meta})"
def _unpack_meta(args, kwargs):
unpacked_args = []
unpacked_kwargs = {}
def _get_meta_part(arg):
if isinstance(arg, (list, tuple)) and any(isinstance(v, ValuePack) for v in arg):
return _unpack_meta_list(arg)
elif isinstance(arg, dict):
return {k: v.meta if isinstance(v, ValuePack) else v for k, v in arg.items()}
elif isinstance(arg, ValuePack):
return arg.meta
else:
return arg
for arg in args:
unpacked_args.append(_get_meta_part(arg))
for key, value in kwargs.items():
unpacked_kwargs[key] = _get_meta_part(value)
return list(unpacked_args), unpacked_kwargs
def _unpack_npu(self, args, kwargs):
unpacked = []
unpacked_kwargs = {}
def _get_npu_part(arg):
if isinstance(arg, (list, tuple)) and len(arg):
if _is_symlist(arg):
arg = self._graph.parse_symlist(arg)
else:
arg = [(v.npu if isinstance(v, ValuePack) else v) for v in arg]
return arg
elif isinstance(arg, dict):
return {k: v.npu if isinstance(v, ValuePack) else v for k, v in arg.items()}
elif isinstance(arg, ValuePack):
return arg.npu
else:
return arg
for arg in args:
unpacked.append(_get_npu_part(arg))
for key, value in kwargs.items():
unpacked_kwargs[key] = _get_npu_part(value)
return unpacked, unpacked_kwargs
# Apply the monkey-patch to torchair. The order of the statements below matters.
# Step 1: swap in the ValuePack class defined in this module.
torchair.core._concrete_graph.ValuePack = ValuePack
# Step 2: the ValuePack class is referenced in these two modules, so after the
# patch they need to be reloaded to pick up the replacement class.
# NOTE(review): both modules are looked up in sys.modules, so this assumes
# `import torchair` has already imported them — confirm.
importlib.reload(sys.modules["torchair.fx_summary"])
importlib.reload(sys.modules["torchair.npu_fx_compiler"])
# Step 3: install the function overrides only after the reload, because
# reloading re-executes the module and would rebind these names.
torchair.npu_fx_compiler._unpack_meta = _unpack_meta
torchair.npu_fx_compiler._NpuGraphConverter._unpack_npu = _unpack_npu

View File

@@ -210,6 +210,18 @@ class NPUPlatform(Platform):
"{new_compile_ranges_split_points} for matmul and allreduce fusion" "{new_compile_ranges_split_points} for matmul and allreduce fusion"
) )
npugraph_ex_config = ascend_config.npugraph_ex_config
if npugraph_ex_config and npugraph_ex_config.fuse_allreduce_rms:
from vllm_ascend.compilation.passes.allreduce_rmsnorm_fusion_pass import ALLREDUCE_NORM_FUSE_THREHOLD
new_compile_ranges_split_points = vllm_config.compilation_config.compile_ranges_split_points
new_compile_ranges_split_points.append(ALLREDUCE_NORM_FUSE_THREHOLD)
new_compile_ranges_split_points = sorted(new_compile_ranges_split_points)
vllm_config.compilation_config.compile_ranges_split_points = new_compile_ranges_split_points
logger.debug(
"set compile_ranges_split_points to {new_compile_ranges_split_points} for matmul and allreduce fusion"
)
elif model_config and hasattr(model_config.hf_text_config, "index_topk"): elif model_config and hasattr(model_config.hf_text_config, "index_topk"):
vllm_config.cache_config.cache_dtype = str(model_config.dtype).replace("torch.", "") vllm_config.cache_config.cache_dtype = str(model_config.dtype).replace("torch.", "")