[npugraph_ex]enable npugraph_ex by default (#6664)
### What this PR does / why we need it?

This pull request enables the `npugraph_ex` backend by default to improve performance on Ascend NPUs, as proposed in the [RFC](https://github.com/vllm-project/vllm-ascend/issues/6214).

### Does this PR introduce _any_ user-facing change?

Yes. `npugraph_ex` is now enabled by default. Users can disable it by setting `enable: false` in the `npugraph_ex_config` section of the `additional_config`.

### How was this patch tested?

CI passed. The changes are covered by existing and new E2E tests (`test_aclgraph_accuracy.py`) and unit tests (`test_ascend_config.py`) that have been updated to reflect the new default behavior. The tests verify correctness and consistency with `npugraph_ex` enabled and disabled, as well as with the new static kernel option.

Signed-off-by: huyuanquan1 <huyuanquan1@huawei.com>
Co-authored-by: huyuanquan1 <huyuanquan1@huawei.com>
This commit is contained in:
@@ -94,7 +94,7 @@ The details of each configuration option are as follows:
|
|||||||
|
|
||||||
| Name | Type | Default | Description |
|
| Name | Type | Default | Description |
|
||||||
|------------------------| ---- |---------|----------------------------------------------------------------------------------------|
|
|------------------------| ---- |---------|----------------------------------------------------------------------------------------|
|
||||||
| `enable` | bool | `False` | Whether to enable npugraph_ex backend. |
|
| `enable` | bool | `True` | Whether to enable npugraph_ex backend. |
|
||||||
| `enable_static_kernel` | bool | `False` | Whether to enable static kernel. Suitable for scenarios where shape changes are minimal and some time is available for static kernel compilation. |
|
| `enable_static_kernel` | bool | `False` | Whether to enable static kernel. Suitable for scenarios where shape changes are minimal and some time is available for static kernel compilation. |
|
||||||
| `fuse_norm_quant` | bool | `True` | Whether to enable fuse_norm_quant pass. |
|
| `fuse_norm_quant` | bool | `True` | Whether to enable fuse_norm_quant pass. |
|
||||||
| `fuse_qknorm_rope` | bool | `True` | Whether to enable fuse_qknorm_rope pass. If Triton is not in the environment, set it to False. |
|
| `fuse_qknorm_rope` | bool | `True` | Whether to enable fuse_qknorm_rope pass. If Triton is not in the environment, set it to False. |
|
||||||
|
|||||||
@@ -147,6 +147,11 @@ def test_full_decode_only_res_consistency(cur_case: LLMTestCase, monkeypatch):
|
|||||||
"cudagraph_mode": "FULL_DECODE_ONLY"
|
"cudagraph_mode": "FULL_DECODE_ONLY"
|
||||||
},
|
},
|
||||||
"quantization": cur_case.quantization,
|
"quantization": cur_case.quantization,
|
||||||
|
"additional_config": {
|
||||||
|
"npugraph_ex_config": {
|
||||||
|
"enable": False
|
||||||
|
}
|
||||||
|
},
|
||||||
}
|
}
|
||||||
gen_and_valid(runner_kwargs=runner_kwargs,
|
gen_and_valid(runner_kwargs=runner_kwargs,
|
||||||
prompts=cur_case.prompts,
|
prompts=cur_case.prompts,
|
||||||
|
|||||||
@@ -67,7 +67,7 @@ class TestAscendConfig(TestBase):
|
|||||||
self.assertTrue(ascend_config.multistream_overlap_shared_expert)
|
self.assertTrue(ascend_config.multistream_overlap_shared_expert)
|
||||||
|
|
||||||
npugraph_ex_config = ascend_config.npugraph_ex_config
|
npugraph_ex_config = ascend_config.npugraph_ex_config
|
||||||
self.assertFalse(npugraph_ex_config.enable)
|
self.assertTrue(npugraph_ex_config.enable)
|
||||||
self.assertFalse(npugraph_ex_config.enable_static_kernel)
|
self.assertFalse(npugraph_ex_config.enable_static_kernel)
|
||||||
|
|
||||||
ascend_compilation_config = ascend_config.ascend_compilation_config
|
ascend_compilation_config = ascend_config.ascend_compilation_config
|
||||||
|
|||||||
@@ -260,7 +260,7 @@ class NpugraphExConfig:
|
|||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
enable: bool = False,
|
enable: bool = True,
|
||||||
enable_static_kernel: bool = False,
|
enable_static_kernel: bool = False,
|
||||||
fuse_norm_quant: bool = True,
|
fuse_norm_quant: bool = True,
|
||||||
fuse_qknorm_rope: bool = True,
|
fuse_qknorm_rope: bool = True,
|
||||||
@@ -274,7 +274,7 @@ class NpugraphExConfig:
|
|||||||
enable (bool): Whether to enable npugraph_ex backend.
|
enable (bool): Whether to enable npugraph_ex backend.
|
||||||
When set to True, the Fx graph generated by Dynamo will be
|
When set to True, the Fx graph generated by Dynamo will be
|
||||||
optimized and compiled by the npugraph_ex backend.
|
optimized and compiled by the npugraph_ex backend.
|
||||||
Default: False
|
Default: True
|
||||||
enable_static_kernel (bool): Whether to enable static kernel.
|
enable_static_kernel (bool): Whether to enable static kernel.
|
||||||
Static kernel is suitable for scenarios with purely static shapes
|
Static kernel is suitable for scenarios with purely static shapes
|
||||||
or minimal shape changes, and can improve network performance.
|
or minimal shape changes, and can improve network performance.
|
||||||
|
|||||||
@@ -88,7 +88,7 @@ def npugraph_ex_compile(
|
|||||||
# that can trigger the compilation of static kernel. If this configuration is
|
# that can trigger the compilation of static kernel. If this configuration is
|
||||||
# not applied, new shapes will trigger the compilation of static kernels,
|
# not applied, new shapes will trigger the compilation of static kernels,
|
||||||
# affecting program execution.
|
# affecting program execution.
|
||||||
num_spec_tokens = vllm_config.speculative_config.num_speculative_token if vllm_config.speculative_config else 0
|
num_spec_tokens = vllm_config.speculative_config.num_speculative_tokens if vllm_config.speculative_config else 0
|
||||||
uniform_decode_query_len = num_spec_tokens + 1
|
uniform_decode_query_len = num_spec_tokens + 1
|
||||||
max_num_tokens = vllm_config.scheduler_config.max_num_seqs * uniform_decode_query_len
|
max_num_tokens = vllm_config.scheduler_config.max_num_seqs * uniform_decode_query_len
|
||||||
decode_cudagraph_batch_sizes = [
|
decode_cudagraph_batch_sizes = [
|
||||||
|
|||||||
@@ -19,6 +19,7 @@
|
|||||||
from torch import fx as fx
|
from torch import fx as fx
|
||||||
from vllm.config import VllmConfig
|
from vllm.config import VllmConfig
|
||||||
|
|
||||||
|
from vllm_ascend.ascend_config import get_ascend_config
|
||||||
from vllm_ascend.utils import vllm_version_is
|
from vllm_ascend.utils import vllm_version_is
|
||||||
|
|
||||||
if vllm_version_is("0.15.0"):
|
if vllm_version_is("0.15.0"):
|
||||||
@@ -55,18 +56,19 @@ class NpuGraphEXPassManager:
|
|||||||
|
|
||||||
def configure(self, config: VllmConfig):
|
def configure(self, config: VllmConfig):
|
||||||
# By default, we enable the graph fusion and quantization fusion pass.
|
# By default, we enable the graph fusion and quantization fusion pass.
|
||||||
self.npugraph_ex_config: dict = config.additional_config.get("npugraph_ex_config", {})
|
self.npugraph_ex_config = get_ascend_config().npugraph_ex_config
|
||||||
if self.npugraph_ex_config.get("fuse_norm_quant", True):
|
|
||||||
|
if self.npugraph_ex_config.fuse_norm_quant:
|
||||||
from .npugraph_ex_passes.graphex_norm_quant_fusion_pass import GraphEXAddRMSNormFusionPass
|
from .npugraph_ex_passes.graphex_norm_quant_fusion_pass import GraphEXAddRMSNormFusionPass
|
||||||
|
|
||||||
self.passes.append(GraphEXAddRMSNormFusionPass(config))
|
self.passes.append(GraphEXAddRMSNormFusionPass(config))
|
||||||
|
|
||||||
if self.npugraph_ex_config.get("fuse_qknorm_rope", True):
|
if self.npugraph_ex_config.fuse_qknorm_rope:
|
||||||
from .npugraph_ex_passes.graphex_qknorm_rope_fusion_pass import GraphEXQKNormRopeFusionPass
|
from .npugraph_ex_passes.graphex_qknorm_rope_fusion_pass import GraphEXQKNormRopeFusionPass
|
||||||
|
|
||||||
self.passes.append(GraphEXQKNormRopeFusionPass(config))
|
self.passes.append(GraphEXQKNormRopeFusionPass(config))
|
||||||
|
|
||||||
if self.npugraph_ex_config.get("fuse_allreduce_rms", True):
|
if self.npugraph_ex_config.fuse_allreduce_rms:
|
||||||
from .npugraph_ex_passes.graphex_allreduce_rmsnorm_fusion_pass import GraphEXMatmulAllReduceAddRMSNormPass
|
from .npugraph_ex_passes.graphex_allreduce_rmsnorm_fusion_pass import GraphEXMatmulAllReduceAddRMSNormPass
|
||||||
|
|
||||||
self.passes.append(GraphEXMatmulAllReduceAddRMSNormPass(config))
|
self.passes.append(GraphEXMatmulAllReduceAddRMSNormPass(config))
|
||||||
|
|||||||
@@ -249,3 +249,17 @@
|
|||||||
# make unquantized_gemm as a customop.
|
# make unquantized_gemm as a customop.
|
||||||
# Future Plan:
|
# Future Plan:
|
||||||
# Remove this patch when vLLM support the operator as customop.
|
# Remove this patch when vLLM support the operator as customop.
|
||||||
|
#
|
||||||
|
# ** 13. File: worker/patch_npugraph_ex_triton.py**
|
||||||
|
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||||
|
# 1. `torchair.core._concrete_graph.ValuePack`,
|
||||||
|
# `torchair.npu_fx_compiler._unpack_meta`,
|
||||||
|
# `torchair.npu_fx_compiler._NpuGraphConverter._unpack_npu`
|
||||||
|
# Why:
|
||||||
|
# In the Triton scenario, npugraph_ex backend needs to process the value pack of the input parameters.
|
||||||
|
# How:
|
||||||
|
# Supplement the relevant processing logic through patches.
|
||||||
|
# Related PR (if no, explain why):
|
||||||
|
# https://gitcode.com/Ascend/torchair/pull/2575
|
||||||
|
# Future Plan:
|
||||||
|
# Remove this patch when the PTA version used by vllm-ascend has been upgraded.
|
||||||
|
|||||||
@@ -33,3 +33,4 @@ import vllm_ascend.patch.worker.patch_rejection_sampler # noqa
|
|||||||
import vllm_ascend.patch.worker.patch_qwen3_next # noqa
|
import vllm_ascend.patch.worker.patch_qwen3_next # noqa
|
||||||
import vllm_ascend.patch.worker.patch_v2_egale # noqa
|
import vllm_ascend.patch.worker.patch_v2_egale # noqa
|
||||||
import vllm_ascend.patch.worker.patch_huanyuan_vl # noqa
|
import vllm_ascend.patch.worker.patch_huanyuan_vl # noqa
|
||||||
|
import vllm_ascend.patch.worker.patch_npugraph_ex_triton # noqa
|
||||||
|
|||||||
116
vllm_ascend/patch/worker/patch_npugraph_ex_triton.py
Normal file
116
vllm_ascend/patch/worker/patch_npugraph_ex_triton.py
Normal file
@@ -0,0 +1,116 @@
|
|||||||
|
#
|
||||||
|
# Copyright (c) 2026 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||||||
|
# This file is a part of the vllm-ascend project.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
#
|
||||||
|
|
||||||
|
import importlib
|
||||||
|
import sys
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torchair
|
||||||
|
from torch._subclasses.fake_tensor import FakeTensor
|
||||||
|
from torchair.core._concrete_graph import _is_symlist
|
||||||
|
from torchair.npu_fx_compiler import _unpack_meta_list
|
||||||
|
|
||||||
|
|
||||||
|
class ValuePack:
    """Container pairing a ``meta`` value with its ``npu`` counterpart.

    When no ``npu_meta`` is given, the ``meta`` object is reused for both
    roles. Packs whose ``meta`` is a dict additionally support ``pack[key]``.
    """

    def __init__(self, meta, npu_meta=None) -> None:
        """Store both values.

        Args:
            meta: The meta-side value.
            npu_meta: The NPU-side value; falls back to ``meta`` when ``None``.
        """
        self._meta = meta
        self._npu_meta = meta if npu_meta is None else npu_meta

    @property
    def meta(self):
        """Return the meta-side value passed to the constructor."""
        return self._meta

    @property
    def npu(self):
        """Return the NPU-side value (``meta`` itself if none was supplied)."""
        return self._npu_meta

    def __getitem__(self, key):
        """Look up ``key`` in a dict-valued ``meta``.

        Returns:
            ``meta.get(key)``, i.e. ``None`` when the key is absent.

        Raises:
            ValueError: If ``meta`` is not a dict.
        """
        if isinstance(self._meta, dict):
            return self._meta.get(key)
        raise ValueError(f"Unsupported meta type for ValuePack __getitem__, key:{key}, type: {type(self._meta)}")

    def __repr__(self) -> str:
        """Return a compact description of the pack for logs and debugging."""
        meta = self._meta
        # FakeTensor must be tested before torch.Tensor so it gets its own label.
        if isinstance(meta, FakeTensor):
            meta_str = f"FakeTensor(dtype={meta.dtype}, size={list(meta.size())})"
        elif isinstance(meta, torch.Tensor):
            meta_str = f"torch.Tensor(dtype={meta.dtype}, size={list(meta.size())})"
        elif isinstance(meta, torch.SymInt):
            meta_str = f"torch.SymInt({meta})"
        else:
            try:
                meta_str = f"{type(meta)}({meta})"
            except Exception:
                # Formatting an arbitrary object may fail; fall back to the type only
                # so that repr() itself never raises.
                meta_str = f"{type(meta)}"
        return f"Pack(meta:{meta_str} npu:{self._npu_meta})"
|
||||||
|
|
||||||
|
|
||||||
|
def _unpack_meta(args, kwargs):
    """Replace every ``ValuePack`` in ``args``/``kwargs`` with its ``meta`` value.

    Args:
        args: Iterable of positional arguments.
        kwargs: Mapping of keyword arguments.

    Returns:
        A ``(list, dict)`` tuple holding the converted positional and keyword
        arguments. Values that contain no ``ValuePack`` are passed through as-is.
    """

    def _meta_of(item):
        # A bare pack unwraps directly.
        if isinstance(item, ValuePack):
            return item.meta
        # Dicts are unwrapped one level deep, value by value.
        if isinstance(item, dict):
            return {name: (val.meta if isinstance(val, ValuePack) else val) for name, val in item.items()}
        # Sequences are only handed to torchair's list helper when they actually
        # hold at least one pack; otherwise they are returned untouched.
        if isinstance(item, (list, tuple)) and any(isinstance(elem, ValuePack) for elem in item):
            return _unpack_meta_list(item)
        return item

    meta_args = [_meta_of(item) for item in args]
    meta_kwargs = {name: _meta_of(val) for name, val in kwargs.items()}
    return meta_args, meta_kwargs
|
||||||
|
|
||||||
|
|
||||||
|
def _unpack_npu(self, args, kwargs):
    """Replace every ``ValuePack`` in ``args``/``kwargs`` with its ``npu`` value.

    Args:
        args: Iterable of positional arguments.
        kwargs: Mapping of keyword arguments.

    Returns:
        A ``(list, dict)`` tuple holding the converted positional and keyword
        arguments.
    """

    def _npu_of(item):
        # A bare pack unwraps directly.
        if isinstance(item, ValuePack):
            return item.npu
        # Dicts are unwrapped one level deep, value by value.
        if isinstance(item, dict):
            return {name: (val.npu if isinstance(val, ValuePack) else val) for name, val in item.items()}
        # Non-empty sequences: symbolic lists go through the graph's own parser,
        # anything else is unwrapped element-wise (tuples come back as lists).
        # Empty sequences fall through and are returned unchanged.
        if isinstance(item, (list, tuple)) and len(item) > 0:
            if _is_symlist(item):
                return self._graph.parse_symlist(item)
            return [(elem.npu if isinstance(elem, ValuePack) else elem) for elem in item]
        return item

    npu_args = [_npu_of(item) for item in args]
    npu_kwargs = {name: _npu_of(val) for name, val in kwargs.items()}
    return npu_args, npu_kwargs
|
||||||
|
|
||||||
|
|
||||||
|
# Install the patches on torchair. The order of the statements below matters.
# Step 1: swap in the replacement ValuePack class defined in this file.
torchair.core._concrete_graph.ValuePack = ValuePack
|
||||||
|
# The ValuePack class is referenced in these two modules, and after the patch, these two modules need to be reloaded.
|
||||||
|
# Step 2: reload the dependent modules so they re-bind ValuePack to the patched class.
# NOTE(review): assumes `import torchair` has already loaded both submodules;
# otherwise the sys.modules lookup raises KeyError - confirm.
importlib.reload(sys.modules["torchair.fx_summary"])
|
||||||
|
importlib.reload(sys.modules["torchair.npu_fx_compiler"])
|
||||||
|
# Step 3: override the unpack helpers. This is done after the reload because a
# reload re-executes the module and would rebind these names to torchair's own versions.
torchair.npu_fx_compiler._unpack_meta = _unpack_meta
|
||||||
|
torchair.npu_fx_compiler._NpuGraphConverter._unpack_npu = _unpack_npu
|
||||||
@@ -210,6 +210,18 @@ class NPUPlatform(Platform):
|
|||||||
"{new_compile_ranges_split_points} for matmul and allreduce fusion"
|
"{new_compile_ranges_split_points} for matmul and allreduce fusion"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
npugraph_ex_config = ascend_config.npugraph_ex_config
|
||||||
|
if npugraph_ex_config and npugraph_ex_config.fuse_allreduce_rms:
|
||||||
|
from vllm_ascend.compilation.passes.allreduce_rmsnorm_fusion_pass import ALLREDUCE_NORM_FUSE_THREHOLD
|
||||||
|
|
||||||
|
new_compile_ranges_split_points = vllm_config.compilation_config.compile_ranges_split_points
|
||||||
|
new_compile_ranges_split_points.append(ALLREDUCE_NORM_FUSE_THREHOLD)
|
||||||
|
new_compile_ranges_split_points = sorted(new_compile_ranges_split_points)
|
||||||
|
vllm_config.compilation_config.compile_ranges_split_points = new_compile_ranges_split_points
|
||||||
|
logger.debug(
|
||||||
|
"set compile_ranges_split_points to {new_compile_ranges_split_points} for matmul and allreduce fusion"
|
||||||
|
)
|
||||||
|
|
||||||
elif model_config and hasattr(model_config.hf_text_config, "index_topk"):
|
elif model_config and hasattr(model_config.hf_text_config, "index_topk"):
|
||||||
vllm_config.cache_config.cache_dtype = str(model_config.dtype).replace("torch.", "")
|
vllm_config.cache_config.cache_dtype = str(model_config.dtype).replace("torch.", "")
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user