diff --git a/docs/source/user_guide/configuration/additional_config.md b/docs/source/user_guide/configuration/additional_config.md
index 169ffeab..51315d6f 100644
--- a/docs/source/user_guide/configuration/additional_config.md
+++ b/docs/source/user_guide/configuration/additional_config.md
@@ -94,7 +94,7 @@ The details of each configuration option are as follows:
 
 | Name | Type | Default | Description |
 |------------------------| ---- |---------|----------------------------------------------------------------------------------------|
-| `enable` | bool | `False` | Whether to enable npugraph_ex backend. |
+| `enable` | bool | `True` | Whether to enable npugraph_ex backend. |
 | `enable_static_kernel` | bool | `False` | Whether to enable static kernel. Suitable for scenarios where shape changes are minimal and some time is available for static kernel compilation. |
 | `fuse_norm_quant` | bool | `True` | Whether to enable fuse_norm_quant pass. |
 | `fuse_qknorm_rope` | bool | `True` | Whether to enable fuse_qknorm_rope pass. If Triton is not in the environment, set it to False. |
diff --git a/tests/e2e/singlecard/test_aclgraph_accuracy.py b/tests/e2e/singlecard/test_aclgraph_accuracy.py
index 1ebd7b21..ac5c0de8 100644
--- a/tests/e2e/singlecard/test_aclgraph_accuracy.py
+++ b/tests/e2e/singlecard/test_aclgraph_accuracy.py
@@ -147,6 +147,11 @@ def test_full_decode_only_res_consistency(cur_case: LLMTestCase, monkeypatch):
             "cudagraph_mode": "FULL_DECODE_ONLY"
         },
         "quantization": cur_case.quantization,
+        "additional_config": {
+            "npugraph_ex_config": {
+                "enable": False
+            }
+        },
     }
     gen_and_valid(runner_kwargs=runner_kwargs,
                   prompts=cur_case.prompts,
diff --git a/tests/ut/test_ascend_config.py b/tests/ut/test_ascend_config.py
index b48d6d59..c3c42d4d 100644
--- a/tests/ut/test_ascend_config.py
+++ b/tests/ut/test_ascend_config.py
@@ -67,7 +67,7 @@ class TestAscendConfig(TestBase):
         self.assertTrue(ascend_config.multistream_overlap_shared_expert)
 
         npugraph_ex_config = ascend_config.npugraph_ex_config
-        self.assertFalse(npugraph_ex_config.enable)
+        self.assertTrue(npugraph_ex_config.enable)
         self.assertFalse(npugraph_ex_config.enable_static_kernel)
 
         ascend_compilation_config = ascend_config.ascend_compilation_config
diff --git a/vllm_ascend/ascend_config.py b/vllm_ascend/ascend_config.py
index 31cf69c5..facb2fe1 100644
--- a/vllm_ascend/ascend_config.py
+++ b/vllm_ascend/ascend_config.py
@@ -260,7 +260,7 @@ class NpugraphExConfig:
 
     def __init__(
         self,
-        enable: bool = False,
+        enable: bool = True,
         enable_static_kernel: bool = False,
         fuse_norm_quant: bool = True,
         fuse_qknorm_rope: bool = True,
@@ -274,7 +274,7 @@
             enable (bool): Whether to enable npugraph_ex backend. When set
                 to True, the Fx graph generated by Dymano will be optimized
                 and compiled by the npugraph_ex backend.
-                Default: False
+                Default: True
             enable_static_kernel (bool): Whether to enable static kernel.
                 Static kernel is suitable for scenarios with purely static
                 shapes or minimal shape changes, and can improve network performance.
diff --git a/vllm_ascend/compilation/compiler_interface.py b/vllm_ascend/compilation/compiler_interface.py
index cefac33d..69c2a377 100644
--- a/vllm_ascend/compilation/compiler_interface.py
+++ b/vllm_ascend/compilation/compiler_interface.py
@@ -88,7 +88,7 @@ def npugraph_ex_compile(
     # that can trigger the compilation of static kernel. If this configuration is
     # not applied, new shapes will trigger the compilation of static kernels,
    # affecting program execution.
-    num_spec_tokens = vllm_config.speculative_config.num_speculative_token if vllm_config.speculative_config else 0
+    num_spec_tokens = vllm_config.speculative_config.num_speculative_tokens if vllm_config.speculative_config else 0
     uniform_decode_query_len = num_spec_tokens + 1
     max_num_tokens = vllm_config.scheduler_config.max_num_seqs * uniform_decode_query_len
     decode_cudagraph_batch_sizes = [
diff --git a/vllm_ascend/compilation/npu_graph_ex_pass_manager.py b/vllm_ascend/compilation/npu_graph_ex_pass_manager.py
index c5802cef..5f7466fc 100644
--- a/vllm_ascend/compilation/npu_graph_ex_pass_manager.py
+++ b/vllm_ascend/compilation/npu_graph_ex_pass_manager.py
@@ -19,6 +19,7 @@ from torch import fx as fx
 
 from vllm.config import VllmConfig
 
+from vllm_ascend.ascend_config import get_ascend_config
 from vllm_ascend.utils import vllm_version_is
 
 if vllm_version_is("0.15.0"):
@@ -55,18 +56,19 @@ class NpuGraphEXPassManager:
 
     def configure(self, config: VllmConfig):
         # By default, we enable the graph fusion and quantization fusion pass.
-        self.npugraph_ex_config: dict = config.additional_config.get("npugraph_ex_config", {})
-        if self.npugraph_ex_config.get("fuse_norm_quant", True):
+        self.npugraph_ex_config = get_ascend_config().npugraph_ex_config
+
+        if self.npugraph_ex_config.fuse_norm_quant:
             from .npugraph_ex_passes.graphex_norm_quant_fusion_pass import GraphEXAddRMSNormFusionPass
 
             self.passes.append(GraphEXAddRMSNormFusionPass(config))
 
-        if self.npugraph_ex_config.get("fuse_qknorm_rope", True):
+        if self.npugraph_ex_config.fuse_qknorm_rope:
             from .npugraph_ex_passes.graphex_qknorm_rope_fusion_pass import GraphEXQKNormRopeFusionPass
 
             self.passes.append(GraphEXQKNormRopeFusionPass(config))
 
-        if self.npugraph_ex_config.get("fuse_allreduce_rms", True):
+        if self.npugraph_ex_config.fuse_allreduce_rms:
             from .npugraph_ex_passes.graphex_allreduce_rmsnorm_fusion_pass import GraphEXMatmulAllReduceAddRMSNormPass
 
             self.passes.append(GraphEXMatmulAllReduceAddRMSNormPass(config))
diff --git a/vllm_ascend/patch/__init__.py b/vllm_ascend/patch/__init__.py
index a5a32dd1..4e46c171 100644
--- a/vllm_ascend/patch/__init__.py
+++ b/vllm_ascend/patch/__init__.py
@@ -249,3 +249,17 @@
 #    make unquantized_gemm as a customop.
 # Future Plan:
 #    Remove this patch when vLLM support the operator as customop.
+#
+# ** 13. File: worker/patch_npugraph_ex_triton.py**
+#    ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+#    1. `torchair.core._concrete_graph.ValuePack`,
+#       `torchair.npu_fx_compiler._unpack_meta`,
+#       `torchair.npu_fx_compiler._NpuGraphConverter._unpack_npu`
+#    Why:
+#       In the Triton scenario, npugraph_ex backend needs to process the value pack of the input parameters.
+#    How:
+#       Supplement the relevant processing logic through patches.
+#    Related PR (if no, explain why):
+#       https://gitcode.com/Ascend/torchair/pull/2575
+#    Future Plan:
+#       Remove this patch when the PTA version used by vllm-ascend has been upgraded.
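
Usage note (illustrative, not part of the patch): because `enable` now defaults to `True`, the npugraph_ex backend is active unless a deployment opts out explicitly, which is exactly what the e2e test change above does via `additional_config`. A minimal sketch of the same override from user code, assuming vLLM's `LLM` entrypoint forwards `additional_config` unchanged to the Ascend config; the model name is a placeholder:

    from vllm import LLM

    # Hypothetical model name; any served model is configured the same way.
    llm = LLM(
        model="Qwen/Qwen2.5-7B-Instruct",
        additional_config={
            "npugraph_ex_config": {
                "enable": False,            # opt out of the new default (True)
                "fuse_qknorm_rope": False,  # e.g. when Triton is not installed
            }
        },
    )

The remaining keys in the table above (`enable_static_kernel`, `fuse_norm_quant`, `fuse_allreduce_rms`) can be overridden the same way; anything omitted falls back to the `NpugraphExConfig` defaults.
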
diff --git a/vllm_ascend/patch/worker/__init__.py b/vllm_ascend/patch/worker/__init__.py
index 1eac4d0c..bd44362d 100644
--- a/vllm_ascend/patch/worker/__init__.py
+++ b/vllm_ascend/patch/worker/__init__.py
@@ -33,3 +33,4 @@ import vllm_ascend.patch.worker.patch_rejection_sampler  # noqa
 import vllm_ascend.patch.worker.patch_qwen3_next  # noqa
 import vllm_ascend.patch.worker.patch_v2_egale  # noqa
 import vllm_ascend.patch.worker.patch_huanyuan_vl  # noqa
+import vllm_ascend.patch.worker.patch_npugraph_ex_triton  # noqa
diff --git a/vllm_ascend/patch/worker/patch_npugraph_ex_triton.py b/vllm_ascend/patch/worker/patch_npugraph_ex_triton.py
new file mode 100644
index 00000000..98124098
--- /dev/null
+++ b/vllm_ascend/patch/worker/patch_npugraph_ex_triton.py
@@ -0,0 +1,116 @@
+#
+# Copyright (c) 2026 Huawei Technologies Co., Ltd. All Rights Reserved.
+# This file is a part of the vllm-ascend project.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import importlib
+import sys
+
+import torch
+import torchair
+from torch._subclasses.fake_tensor import FakeTensor
+from torchair.core._concrete_graph import _is_symlist
+from torchair.npu_fx_compiler import _unpack_meta_list
+
+
+class ValuePack:
+    def __init__(self, meta, npu_meta=None) -> None:
+        self._meta = meta
+        self._npu_meta = meta if npu_meta is None else npu_meta
+
+    @property
+    def meta(self):
+        return self._meta
+
+    @property
+    def npu(self):
+        return self._npu_meta
+
+    def __getitem__(self, key):
+        if isinstance(self._meta, dict):
+            return self._meta.get(key)
+        raise ValueError(f"Unsupported meta type for ValuePack __getitem__, key:{key}, type: {type(self._meta)}")
+
+    def __repr__(self) -> str:
+        if isinstance(self._meta, FakeTensor):
+            meta_str = f"FakeTensor(dtype={self._meta.dtype}, size={list(self._meta.size())}"
+        elif isinstance(self._meta, torch.Tensor):
+            meta_str = f"torch.Tensor(dtype={self._meta.dtype}, size={list(self._meta.size())}"
+        elif isinstance(self._meta, torch.SymInt):
+            meta_str = f"torch.SymInt({self._meta})"
+        else:
+            try:
+                meta_str = f"{type(self._meta)}({self._meta})"
+            except Exception:
+                meta_str = f"{type(self._meta)}"
+        return f"Pack(meta:{meta_str} npu:{self._npu_meta})"
+
+
+def _unpack_meta(args, kwargs):
+    unpacked_args = []
+    unpacked_kwargs = {}
+
+    def _get_meta_part(arg):
+        if isinstance(arg, (list, tuple)) and any(isinstance(v, ValuePack) for v in arg):
+            return _unpack_meta_list(arg)
+        elif isinstance(arg, dict):
+            return {k: v.meta if isinstance(v, ValuePack) else v for k, v in arg.items()}
+        elif isinstance(arg, ValuePack):
+            return arg.meta
+        else:
+            return arg
+
+    for arg in args:
+        unpacked_args.append(_get_meta_part(arg))
+
+    for key, value in kwargs.items():
+        unpacked_kwargs[key] = _get_meta_part(value)
+
+    return list(unpacked_args), unpacked_kwargs
+
+
+def _unpack_npu(self, args, kwargs):
+    unpacked = []
+    unpacked_kwargs = {}
+
+    def _get_npu_part(arg):
+        if isinstance(arg, (list, tuple)) and len(arg):
+            if _is_symlist(arg):
+                arg = self._graph.parse_symlist(arg)
+            else:
+                arg = [(v.npu if isinstance(v, ValuePack) else v) for v in arg]
+            return arg
+        elif isinstance(arg, dict):
+            return {k: v.npu if isinstance(v, ValuePack) else v for k, v in arg.items()}
+        elif isinstance(arg, ValuePack):
+            return arg.npu
+        else:
+            return arg
+
+    for arg in args:
+        unpacked.append(_get_npu_part(arg))
+
+    for key, value in kwargs.items():
+        unpacked_kwargs[key] = _get_npu_part(value)
+
+    return unpacked, unpacked_kwargs
+
+
+torchair.core._concrete_graph.ValuePack = ValuePack
+# The ValuePack class is referenced in these two modules, and after the patch, these two modules need to be reloaded.
+importlib.reload(sys.modules["torchair.fx_summary"])
+importlib.reload(sys.modules["torchair.npu_fx_compiler"])
+torchair.npu_fx_compiler._unpack_meta = _unpack_meta
+torchair.npu_fx_compiler._NpuGraphConverter._unpack_npu = _unpack_npu
diff --git a/vllm_ascend/platform.py b/vllm_ascend/platform.py
index 00c4e6cf..da3ca2f7 100644
--- a/vllm_ascend/platform.py
+++ b/vllm_ascend/platform.py
@@ -210,6 +210,18 @@ class NPUPlatform(Platform):
                     "{new_compile_ranges_split_points} for matmul and allreduce fusion"
                 )
 
+            npugraph_ex_config = ascend_config.npugraph_ex_config
+            if npugraph_ex_config and npugraph_ex_config.fuse_allreduce_rms:
+                from vllm_ascend.compilation.passes.allreduce_rmsnorm_fusion_pass import ALLREDUCE_NORM_FUSE_THREHOLD
+
+                new_compile_ranges_split_points = vllm_config.compilation_config.compile_ranges_split_points
+                new_compile_ranges_split_points.append(ALLREDUCE_NORM_FUSE_THREHOLD)
+                new_compile_ranges_split_points = sorted(new_compile_ranges_split_points)
+                vllm_config.compilation_config.compile_ranges_split_points = new_compile_ranges_split_points
+                logger.debug(
+                    "set compile_ranges_split_points to {new_compile_ranges_split_points} for matmul and allreduce fusion"
+                )
+
         elif model_config and hasattr(model_config.hf_text_config, "index_topk"):
             vllm_config.cache_config.cache_dtype = str(model_config.dtype).replace("torch.", "")
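
For context (illustrative, not part of the patch): the platform.py change above only inserts the allreduce-RMSNorm fusion threshold into the compile-range split points and re-sorts the list, so that a dedicated compile range exists around the fusion boundary. A standalone sketch of that adjustment, using hypothetical numbers in place of `ALLREDUCE_NORM_FUSE_THREHOLD` and the existing split points:

    # Hypothetical values; the real threshold is imported from
    # vllm_ascend.compilation.passes.allreduce_rmsnorm_fusion_pass.
    ALLREDUCE_NORM_FUSE_THREHOLD = 1024

    split_points = [512, 4096]          # stand-in for compile_ranges_split_points
    split_points.append(ALLREDUCE_NORM_FUSE_THREHOLD)
    split_points = sorted(split_points)
    print(split_points)                 # [512, 1024, 4096]: a new range boundary at the fusion threshold
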