[npugraph_ex] Enable npugraph_ex by default (#6664)
### What this PR does / why we need it?
This pull request enables the `npugraph_ex` backend by default to improve performance on Ascend NPUs, as proposed in the [RFC](https://github.com/vllm-project/vllm-ascend/issues/6214).

### Does this PR introduce _any_ user-facing change?
Yes. `npugraph_ex` is now enabled by default. Users can disable it by setting `enable: false` in the `npugraph_ex_config` section of `additional_config`.

### How was this patch tested?
CI passed. The changes are covered by existing and new E2E tests (`test_aclgraph_accuracy.py`) and unit tests (`test_ascend_config.py`) that have been updated to reflect the new default behavior. The tests verify correctness and consistency with `npugraph_ex` enabled and disabled, as well as with the new static kernel option.

Signed-off-by: huyuanquan1 <huyuanquan1@huawei.com>
Co-authored-by: huyuanquan1 <huyuanquan1@huawei.com>
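To make the opt-out path concrete, here is a minimal sketch of disabling `npugraph_ex` through `additional_config`. The model name is a placeholder and the exact engine kwargs may differ by vLLM version; the `npugraph_ex_config` keys follow the documentation table changed below.

```python
# Sketch only: opting out of the new default via additional_config.
# The model name is a placeholder, not part of this PR.
from vllm import LLM

llm = LLM(
    model="some/served-model",      # placeholder
    additional_config={
        "npugraph_ex_config": {
            "enable": False,        # npugraph_ex is now on by default; set False to opt out
        }
    },
)
```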
@@ -94,7 +94,7 @@ The details of each configuration option are as follows:
 | Name | Type | Default | Description |
 |------------------------| ---- |---------|----------------------------------------------------------------------------------------|
-| `enable` | bool | `False` | Whether to enable npugraph_ex backend. |
+| `enable` | bool | `True` | Whether to enable npugraph_ex backend. |
 | `enable_static_kernel` | bool | `False` | Whether to enable static kernel. Suitable for scenarios where shape changes are minimal and some time is available for static kernel compilation. |
 | `fuse_norm_quant` | bool | `True` | Whether to enable fuse_norm_quant pass. |
 | `fuse_qknorm_rope` | bool | `True` | Whether to enable fuse_qknorm_rope pass. If Triton is not in the environment, set it to False. |
@@ -147,6 +147,11 @@ def test_full_decode_only_res_consistency(cur_case: LLMTestCase, monkeypatch):
             "cudagraph_mode": "FULL_DECODE_ONLY"
         },
         "quantization": cur_case.quantization,
+        "additional_config": {
+            "npugraph_ex_config": {
+                "enable": False
+            }
+        },
     }
     gen_and_valid(runner_kwargs=runner_kwargs,
                   prompts=cur_case.prompts,
@@ -67,7 +67,7 @@ class TestAscendConfig(TestBase):
         self.assertTrue(ascend_config.multistream_overlap_shared_expert)

         npugraph_ex_config = ascend_config.npugraph_ex_config
-        self.assertFalse(npugraph_ex_config.enable)
+        self.assertTrue(npugraph_ex_config.enable)
         self.assertFalse(npugraph_ex_config.enable_static_kernel)

         ascend_compilation_config = ascend_config.ascend_compilation_config
@@ -260,7 +260,7 @@ class NpugraphExConfig:

     def __init__(
         self,
-        enable: bool = False,
+        enable: bool = True,
         enable_static_kernel: bool = False,
         fuse_norm_quant: bool = True,
         fuse_qknorm_rope: bool = True,
@@ -274,7 +274,7 @@ class NpugraphExConfig:
             enable (bool): Whether to enable npugraph_ex backend.
                 When set to True, the Fx graph generated by Dynamo will be
                 optimized and compiled by the npugraph_ex backend.
-                Default: False
+                Default: True
             enable_static_kernel (bool): Whether to enable static kernel.
                 Static kernel is suitable for scenarios with purely static shapes
                 or minimal shape changes, and can improve network performance.
@@ -88,7 +88,7 @@ def npugraph_ex_compile(
     # that can trigger the compilation of static kernel. If this configuration is
     # not applied, new shapes will trigger the compilation of static kernels,
     # affecting program execution.
-    num_spec_tokens = vllm_config.speculative_config.num_speculative_token if vllm_config.speculative_config else 0
+    num_spec_tokens = vllm_config.speculative_config.num_speculative_tokens if vllm_config.speculative_config else 0
     uniform_decode_query_len = num_spec_tokens + 1
     max_num_tokens = vllm_config.scheduler_config.max_num_seqs * uniform_decode_query_len
     decode_cudagraph_batch_sizes = [
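As a rough illustration of the sizing logic in this hunk (the numbers are assumptions, not values from this PR): with `num_speculative_tokens = 2` and `max_num_seqs = 256`, each uniform decode step handles 3 tokens per sequence, so the largest decode batch that needs a pre-compiled static kernel is 768 tokens.

```python
# Illustrative numbers only; the real values come from vllm_config at runtime.
num_spec_tokens = 2       # hypothetical speculative_config.num_speculative_tokens
max_num_seqs = 256        # hypothetical scheduler_config.max_num_seqs

uniform_decode_query_len = num_spec_tokens + 1             # 3 tokens per sequence per decode step
max_num_tokens = max_num_seqs * uniform_decode_query_len   # 768: largest decode batch to warm up
print(uniform_decode_query_len, max_num_tokens)            # 3 768
```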
@@ -19,6 +19,7 @@
 from torch import fx as fx
 from vllm.config import VllmConfig

+from vllm_ascend.ascend_config import get_ascend_config
 from vllm_ascend.utils import vllm_version_is

 if vllm_version_is("0.15.0"):
@@ -55,18 +56,19 @@ class NpuGraphEXPassManager:

     def configure(self, config: VllmConfig):
         # By default, we enable the graph fusion and quantization fusion pass.
-        self.npugraph_ex_config: dict = config.additional_config.get("npugraph_ex_config", {})
-        if self.npugraph_ex_config.get("fuse_norm_quant", True):
+        self.npugraph_ex_config = get_ascend_config().npugraph_ex_config
+
+        if self.npugraph_ex_config.fuse_norm_quant:
             from .npugraph_ex_passes.graphex_norm_quant_fusion_pass import GraphEXAddRMSNormFusionPass

             self.passes.append(GraphEXAddRMSNormFusionPass(config))

-        if self.npugraph_ex_config.get("fuse_qknorm_rope", True):
+        if self.npugraph_ex_config.fuse_qknorm_rope:
             from .npugraph_ex_passes.graphex_qknorm_rope_fusion_pass import GraphEXQKNormRopeFusionPass

             self.passes.append(GraphEXQKNormRopeFusionPass(config))

-        if self.npugraph_ex_config.get("fuse_allreduce_rms", True):
+        if self.npugraph_ex_config.fuse_allreduce_rms:
             from .npugraph_ex_passes.graphex_allreduce_rmsnorm_fusion_pass import GraphEXMatmulAllReduceAddRMSNormPass

             self.passes.append(GraphEXMatmulAllReduceAddRMSNormPass(config))
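Since the pass toggles now come from `get_ascend_config().npugraph_ex_config` rather than a raw dict, turning an individual fusion off (for example `fuse_qknorm_rope` on hosts without Triton, as the docs table notes) still goes through `additional_config`. A minimal sketch; the values are examples, not recommendations:

```python
# Sketch: per-pass opt-out while keeping npugraph_ex itself enabled.
# Keys mirror the NpugraphExConfig fields used above.
additional_config = {
    "npugraph_ex_config": {
        "enable": True,              # default after this PR
        "fuse_norm_quant": True,
        "fuse_qknorm_rope": False,   # e.g. when Triton is not installed
        "fuse_allreduce_rms": True,
    }
}
```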
@@ -249,3 +249,17 @@
 #    make unquantized_gemm as a customop.
 #   Future Plan:
 #    Remove this patch when vLLM support the operator as customop.
+#
+# ** 13. File: worker/patch_npugraph_ex_triton.py**
+#   ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+#   1. `torchair.core._concrete_graph.ValuePack`,
+#      `torchair.npu_fx_compiler._unpack_meta`,
+#      `torchair.npu_fx_compiler._NpuGraphConverter._unpack_npu`
+#   Why:
+#    In the Triton scenario, npugraph_ex backend needs to process the value pack of the input parameters.
+#   How:
+#    Supplement the relevant processing logic through patches.
+#   Related PR (if no, explain why):
+#    https://gitcode.com/Ascend/torchair/pull/2575
+#   Future Plan:
+#    Remove this patch when the PTA version used by vllm-ascend has been upgraded.
@@ -33,3 +33,4 @@ import vllm_ascend.patch.worker.patch_rejection_sampler # noqa
 import vllm_ascend.patch.worker.patch_qwen3_next # noqa
 import vllm_ascend.patch.worker.patch_v2_egale # noqa
 import vllm_ascend.patch.worker.patch_huanyuan_vl # noqa
+import vllm_ascend.patch.worker.patch_npugraph_ex_triton # noqa
vllm_ascend/patch/worker/patch_npugraph_ex_triton.py (new file, 116 lines)
@@ -0,0 +1,116 @@
#
# Copyright (c) 2026 Huawei Technologies Co., Ltd. All Rights Reserved.
# This file is a part of the vllm-ascend project.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import importlib
import sys

import torch
import torchair
from torch._subclasses.fake_tensor import FakeTensor
from torchair.core._concrete_graph import _is_symlist
from torchair.npu_fx_compiler import _unpack_meta_list


class ValuePack:

    def __init__(self, meta, npu_meta=None) -> None:
        self._meta = meta
        self._npu_meta = meta if npu_meta is None else npu_meta

    @property
    def meta(self):
        return self._meta

    @property
    def npu(self):
        return self._npu_meta

    def __getitem__(self, key):
        if isinstance(self._meta, dict):
            return self._meta.get(key)
        raise ValueError(f"Unsupported meta type for ValuePack __getitem__, key:{key}, type: {type(self._meta)}")

    def __repr__(self) -> str:
        if isinstance(self._meta, FakeTensor):
            meta_str = f"FakeTensor(dtype={self._meta.dtype}, size={list(self._meta.size())}"
        elif isinstance(self._meta, torch.Tensor):
            meta_str = f"torch.Tensor(dtype={self._meta.dtype}, size={list(self._meta.size())}"
        elif isinstance(self._meta, torch.SymInt):
            meta_str = f"torch.SymInt({self._meta})"
        else:
            try:
                meta_str = f"{type(self._meta)}({self._meta})"
            except Exception:
                meta_str = f"{type(self._meta)}"
        return f"Pack(meta:{meta_str} npu:{self._npu_meta})"


def _unpack_meta(args, kwargs):
    unpacked_args = []
    unpacked_kwargs = {}

    def _get_meta_part(arg):
        if isinstance(arg, (list, tuple)) and any(isinstance(v, ValuePack) for v in arg):
            return _unpack_meta_list(arg)
        elif isinstance(arg, dict):
            return {k: v.meta if isinstance(v, ValuePack) else v for k, v in arg.items()}
        elif isinstance(arg, ValuePack):
            return arg.meta
        else:
            return arg

    for arg in args:
        unpacked_args.append(_get_meta_part(arg))

    for key, value in kwargs.items():
        unpacked_kwargs[key] = _get_meta_part(value)

    return list(unpacked_args), unpacked_kwargs


def _unpack_npu(self, args, kwargs):
    unpacked = []
    unpacked_kwargs = {}

    def _get_npu_part(arg):
        if isinstance(arg, (list, tuple)) and len(arg):
            if _is_symlist(arg):
                arg = self._graph.parse_symlist(arg)
            else:
                arg = [(v.npu if isinstance(v, ValuePack) else v) for v in arg]
            return arg
        elif isinstance(arg, dict):
            return {k: v.npu if isinstance(v, ValuePack) else v for k, v in arg.items()}
        elif isinstance(arg, ValuePack):
            return arg.npu
        else:
            return arg

    for arg in args:
        unpacked.append(_get_npu_part(arg))

    for key, value in kwargs.items():
        unpacked_kwargs[key] = _get_npu_part(value)

    return unpacked, unpacked_kwargs


torchair.core._concrete_graph.ValuePack = ValuePack
# The ValuePack class is referenced in these two modules, and after the patch, these two modules need to be reloaded.
importlib.reload(sys.modules["torchair.fx_summary"])
importlib.reload(sys.modules["torchair.npu_fx_compiler"])
torchair.npu_fx_compiler._unpack_meta = _unpack_meta
torchair.npu_fx_compiler._NpuGraphConverter._unpack_npu = _unpack_npu
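The reloads at the end of this patch file are the subtle part: modules that imported `ValuePack` by name keep their own binding, so rebinding `torchair.core._concrete_graph.ValuePack` alone would not affect them. A minimal, self-contained sketch of that behavior using hypothetical modules (not the real torchair layout):

```python
# Hypothetical "provider"/"consumer" modules illustrating why the reload step matters.
import sys
import types

provider = types.ModuleType("provider")          # stands in for torchair.core._concrete_graph
provider.ValuePack = type("ValuePack", (), {"version": "original"})
sys.modules["provider"] = provider

consumer = types.ModuleType("consumer")          # stands in for torchair.npu_fx_compiler
sys.modules["consumer"] = consumer
exec("from provider import ValuePack", consumer.__dict__)

provider.ValuePack = type("ValuePack", (), {"version": "patched"})  # the monkey-patch
print(consumer.ValuePack.version)   # "original": consumer still holds the old class
exec("from provider import ValuePack", consumer.__dict__)           # what a reload re-runs
print(consumer.ValuePack.version)   # "patched"
```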
@@ -210,6 +210,18 @@ class NPUPlatform(Platform):
                     "{new_compile_ranges_split_points} for matmul and allreduce fusion"
                 )

+            npugraph_ex_config = ascend_config.npugraph_ex_config
+            if npugraph_ex_config and npugraph_ex_config.fuse_allreduce_rms:
+                from vllm_ascend.compilation.passes.allreduce_rmsnorm_fusion_pass import ALLREDUCE_NORM_FUSE_THREHOLD
+
+                new_compile_ranges_split_points = vllm_config.compilation_config.compile_ranges_split_points
+                new_compile_ranges_split_points.append(ALLREDUCE_NORM_FUSE_THREHOLD)
+                new_compile_ranges_split_points = sorted(new_compile_ranges_split_points)
+                vllm_config.compilation_config.compile_ranges_split_points = new_compile_ranges_split_points
+                logger.debug(
+                    "set compile_ranges_split_points to {new_compile_ranges_split_points} for matmul and allreduce fusion"
+                )
+
         elif model_config and hasattr(model_config.hf_text_config, "index_topk"):
             vllm_config.cache_config.cache_dtype = str(model_config.dtype).replace("torch.", "")
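To make the split-point adjustment above concrete, a tiny sketch with assumed values (the real `ALLREDUCE_NORM_FUSE_THREHOLD` and the pre-existing split points come from vllm-ascend and the vLLM compilation config, not from this example):

```python
# Assumed values for illustration only.
ALLREDUCE_NORM_FUSE_THREHOLD = 1024        # placeholder for the imported constant
compile_ranges_split_points = [512, 8192]  # placeholder pre-existing split points

compile_ranges_split_points.append(ALLREDUCE_NORM_FUSE_THREHOLD)
compile_ranges_split_points = sorted(compile_ranges_split_points)
print(compile_ranges_split_points)         # [512, 1024, 8192]: the threshold adds one compile-range boundary
```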