From dd622aa6a6ea34a6bc799d52171ebe6c3e062972 Mon Sep 17 00:00:00 2001 From: ChenCangtao <50493711+ChenCangtao@users.noreply.github.com> Date: Wed, 10 Dec 2025 20:48:05 +0800 Subject: [PATCH] [Feature] Support npuhraph_ex backend (#4700) ### What this PR does / why we need it? We introduced the npugraph_ex backend through the vllm's adaptor dispatch mechanism to accelerate aclgraph. This solution is based on torch.compile and uses torchair to optimize the fx.graph. The performance gains are mainly obtained from the static kernel. We conducted tests on Qwen3-30B and achieved over 5% performance optimization. ### Does this PR introduce _any_ user-facing change? Yes, we add a new switch named"enable_npugraph_ex" in additional_config, default is False. We also add an example to show how to register custom replacement pass ### More information about this PR This feature depends on the release of CANN and torch_npu in Q4. We tested it on a package that has not been publicly released yet and verified that the functionality works. This feature is still experimental at the moment; setting the config true will directly raise error. Merging into the main branch initially involves some preliminary commits to facilitate subsequent development and testing of the feature, as well as to avoid submitting an excessively large PR at once. - vLLM version: v0.12.0 - vLLM main: https://github.com/vllm-project/vllm/commit/ad32e3e19ccf0526cb6744a5fed09a138a5fb2f9 --------- Signed-off-by: chencangtao Signed-off-by: ChenCangtao <50493711+ChenCangtao@users.noreply.github.com> Co-authored-by: chencangtao Co-authored-by: panchao-hub <315134829@qq.com> Co-authored-by: wbigat Co-authored-by: Mengqing Cao --- tests/ut/test_ascend_config.py | 11 ++ vllm_ascend/ascend_config.py | 6 + vllm_ascend/compilation/compiler_interface.py | 94 ++++++++++--- .../npugraph_ex_passes/__init__.py | 0 .../npugraph_ex_passes/add_rms_norm_quant.py | 123 ++++++++++++++++++ vllm_ascend/platform.py | 4 + vllm_ascend/worker/model_runner_v1.py | 12 ++ 7 files changed, 235 insertions(+), 15 deletions(-) create mode 100644 vllm_ascend/compilation/npugraph_ex_passes/__init__.py create mode 100644 vllm_ascend/compilation/npugraph_ex_passes/add_rms_norm_quant.py diff --git a/tests/ut/test_ascend_config.py b/tests/ut/test_ascend_config.py index a92bbc80..f56cf23b 100644 --- a/tests/ut/test_ascend_config.py +++ b/tests/ut/test_ascend_config.py @@ -57,10 +57,21 @@ class TestAscendConfig(TestBase): ascend_config = init_ascend_config(test_vllm_config) self.assertEqual(ascend_config.expert_map_path, "test_expert_map_path") self.assertTrue(ascend_config.multistream_overlap_shared_expert) + self.assertFalse(ascend_config.enable_npugraph_ex) ascend_compilation_config = ascend_config.ascend_compilation_config self.assertFalse(ascend_compilation_config.enable_quantization_fusion) + @_clean_up_ascend_config + def test_init_ascend_config_enable_npugraph_ex(self): + with self.assertRaises(NotImplementedError): + test_vllm_config = VllmConfig() + test_vllm_config.additional_config = { + "enable_npugraph_ex": True, + "refresh": True, + } + init_ascend_config(test_vllm_config) + @_clean_up_ascend_config def test_get_ascend_config(self): test_vllm_config = VllmConfig() diff --git a/vllm_ascend/ascend_config.py b/vllm_ascend/ascend_config.py index 54ef914a..d27263bb 100644 --- a/vllm_ascend/ascend_config.py +++ b/vllm_ascend/ascend_config.py @@ -169,6 +169,12 @@ class AscendConfig: get_flashcomm2_oproj_tp_size_and_validate_config self.flashcomm2_oproj_tensor_parallel_size = get_flashcomm2_oproj_tp_size_and_validate_config( self, vllm_config) + self.enable_npugraph_ex = additional_config.get( + "enable_npugraph_ex", False) + if self.enable_npugraph_ex: + raise NotImplementedError( + "This feature is still in the experiment and will be supported soon." + ) kv_cfg = vllm_config.kv_transfer_config if kv_cfg is not None and not getattr(kv_cfg, "_engine_id_patched", False): diff --git a/vllm_ascend/compilation/compiler_interface.py b/vllm_ascend/compilation/compiler_interface.py index 23e07bb4..4bb7deae 100644 --- a/vllm_ascend/compilation/compiler_interface.py +++ b/vllm_ascend/compilation/compiler_interface.py @@ -18,6 +18,7 @@ import functools from typing import Any, Callable, Optional +import torch import torch.fx as fx from torch._dynamo.backends.common import aot_autograd from torch._inductor.compile_fx import (graph_returns_tuple, @@ -26,6 +27,8 @@ from torch._inductor.decomposition import select_decomp_table from torch.fx import GraphModule from vllm.compilation.compiler_interface import CompilerInterface +from vllm_ascend.ascend_config import get_ascend_config + def compile_fx(graph: GraphModule, example_inputs: list, inner_compile: Callable, decompositions: dict) -> Callable: @@ -39,6 +42,75 @@ def compile_fx(graph: GraphModule, example_inputs: list, return aot_autograd(fw_compiler=inner_compile)(graph, example_inputs) +def fusion_pass_compile( + graph: fx.GraphModule, + example_inputs: list[Any], + compiler_config: dict[str, Any], + runtime_shape: Optional[int] = None, + key: Optional[str] = None, +) -> tuple[Optional[Callable], Optional[Any]]: + + def compile_inner(graph, example_inputs): + current_pass_manager = compiler_config["graph_fusion_manager"] + graph = current_pass_manager(graph, runtime_shape) + return graph + + decompositions = select_decomp_table() + + compiled_fn = compile_fx( + graph=graph, + example_inputs=example_inputs, + inner_compile=compile_inner, + decompositions=decompositions, + ) + + return compiled_fn, None + + +def npugraph_ex_compile( + graph: fx.GraphModule, + example_inputs: list[Any], + compiler_config: dict[str, Any], + runtime_shape: Optional[int] = None, + key: Optional[str] = None, +) -> tuple[Optional[Callable], Optional[Any]]: + # When currently using the FULL_DECODE_ONLY mode, + # the piecewise compilation level slicing process + # in vllm is also encountered. + # This process causes the output to no longer be + # wrapped as a tuple when the fx graph has a single + # output, but torch.compile has a mandatory check. + fx_graph = graph.graph + if not graph_returns_tuple(graph): + output_node = fx_graph.output_node() + with fx_graph.inserting_before(output_node): + return_value = output_node.args[0] + tuple_node = fx_graph.create_node("call_function", + tuple, + args=([return_value], )) + output_node.args = (tuple_node, ) + fx_graph.recompile() + + import torchair + + # TODO: use a better way to lazy register replacement, instead of import one by one + # As an example, we directly import here to register replacement. + import vllm_ascend.compilation.npugraph_ex_passes.add_rms_norm_quant # noqa + + torch.npu.set_compile_mode(jit_compile=False) + config = torchair.CompilerConfig() + # use aclgraph mode, avoid the transformation from fx graph to Ascend IR. + config.mode = "reduce-overhead" + # execute FX graph in eager mode before graph mode to optimize FX graph. + config.debug.run_eagerly = True + # static kernel switch, suitable for static shapes or scenes with less shape changes. + config.experimental_config.aclgraph._aclnn_static_shape_kernel = True + + npugraph_ex = torchair.get_npu_backend(compiler_config=config) + compile_graph = npugraph_ex(graph, example_inputs) + return compile_graph, None + + class AscendCompiler(CompilerInterface): """ AscendCompiler is a custom compiler interface for the Ascend platform. @@ -56,18 +128,10 @@ class AscendCompiler(CompilerInterface): key: Optional[str] = None, ) -> tuple[Optional[Callable], Optional[Any]]: - def compile_inner(graph, example_inputs): - current_pass_manager = compiler_config["graph_fusion_manager"] - graph = current_pass_manager(graph, runtime_shape) - return graph - - decompositions = select_decomp_table() - - compiled_fn = compile_fx( - graph=graph, - example_inputs=example_inputs, - inner_compile=compile_inner, - decompositions=decompositions, - ) - - return compiled_fn, None + ascend_config = get_ascend_config() + if ascend_config.enable_npugraph_ex: + return npugraph_ex_compile(graph, example_inputs, compiler_config, + runtime_shape, key) + else: + return fusion_pass_compile(graph, example_inputs, compiler_config, + runtime_shape, key) diff --git a/vllm_ascend/compilation/npugraph_ex_passes/__init__.py b/vllm_ascend/compilation/npugraph_ex_passes/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/vllm_ascend/compilation/npugraph_ex_passes/add_rms_norm_quant.py b/vllm_ascend/compilation/npugraph_ex_passes/add_rms_norm_quant.py new file mode 100644 index 00000000..724d8140 --- /dev/null +++ b/vllm_ascend/compilation/npugraph_ex_passes/add_rms_norm_quant.py @@ -0,0 +1,123 @@ +# +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# This file is a part of the vllm-ascend project. +# +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import functools +import sys + +import torch +from torch._inductor.pattern_matcher import Match +from vllm.logger import logger + + +@functools.lru_cache(None) +# The replacement registered here will be actually executed after AOT. +def _register_replacement(epsilon): + if 'torch_npu' not in sys.modules: + logger.info( + 'The AddRMSNormQuant fusion will only be enabled in a torch npu env.' + 'When there is no torch_npu in the env, skip fusion.') + return + + def _extra_stream_scope_check(match: Match) -> bool: + """ + Checks if all nodes in the same stream. + """ + non_default_streams = set() + has_default = False + + for node in match.nodes: + if node.op == "call_function": + current_stream = node.meta.get("stream_label") + if current_stream is None: + has_default = True + else: + non_default_streams.add(current_stream) + if len(non_default_streams) > 1: + logger.debug( + f"Cross-stream operation detected in pattern match for AddRMSNormQuant. " + f"Multiple streams found: {non_default_streams}. " + f"Fusion is not supported for cross-stream operations." + ) + return False + + if has_default and len(non_default_streams) > 0: + logger.debug( + f"Cross-stream operation detected in pattern match for AddRMSNormQuant. " + f"Multiple streams found: {non_default_streams}. " + f"Fusion is not supported for cross-stream operations.") + return False + + return True + + def pattern(rms_norm_input: torch.Tensor, residual: torch.Tensor, + rms_norm_weight: torch.Tensor, scale: torch.Tensor, + offset: torch.Tensor): + """ + Pattern for AddRMSNormQuant fusion. + """ + output = torch.ops.npu.npu_add_rms_norm(rms_norm_input, residual, + rms_norm_weight, epsilon) + out0 = output[0] + out1 = output[2] + quantized_output = torch.ops.npu.npu_quantize(out0, scale, offset, + torch.qint8, -1, False) + return quantized_output, out1 + + def replacement(rms_norm_input: torch.Tensor, residual: torch.Tensor, + rms_norm_weight: torch.Tensor, scale: torch.Tensor, + offset: torch.Tensor): + """ + Replacement for the AddRMSNormQuant fusion. + """ + output = torch.ops.npu.npu_add_rms_norm_quant( + rms_norm_input, + residual, + rms_norm_weight, + # The inverse of scale is required by npu_add_rms_norm_quant kernel which is opposite to the npu_quantize kernel. + 1. / scale, + offset, + epsilon=epsilon) + quantized_output = output[0] + out1 = output[2] + return quantized_output, out1 + + def get_inputs(): + """ + Generate example inputs for the AddRMSNormQuant fusion pattern. + """ + rms_norm_input = torch.randn(2, 4, device="npu") + residual = torch.randn(2, 4, device="npu") + rms_norm_weight = torch.randn(4, device="npu") + scale = torch.tensor([1.0], device="npu") + offset = torch.tensor([0.0], device="npu") + return [rms_norm_input, residual, rms_norm_weight, scale, offset] + + import torchair + + torchair.register_replacement(search_fn=pattern, + replace_fn=replacement, + example_inputs=get_inputs(), + extra_check=_extra_stream_scope_check) + + +# register converter for pass +common_epsilons = [1e-5, 1e-6] +for eps in common_epsilons: + logger.info( + f"Start register fusion pattern for AddRMSNormQuant with epsilons={eps}" + ) + _register_replacement(eps) diff --git a/vllm_ascend/platform.py b/vllm_ascend/platform.py index c9abc8ff..8a508acf 100644 --- a/vllm_ascend/platform.py +++ b/vllm_ascend/platform.py @@ -231,6 +231,7 @@ class NPUPlatform(Platform): if compilation_config.cudagraph_mode == CUDAGraphMode.NONE: compilation_config.mode = CompilationMode.NONE + ascend_config.enable_npugraph_ex = False elif compilation_config.cudagraph_mode == CUDAGraphMode.PIECEWISE: logger.info( "PIECEWISE compilation enabled on NPU. use_inductor not supported - " @@ -241,12 +242,14 @@ class NPUPlatform(Platform): compilation_config.use_inductor = False compilation_config.splitting_ops.extend(["vllm::mla_forward"]) update_aclgraph_sizes(vllm_config) + ascend_config.enable_npugraph_ex = False elif compilation_config.cudagraph_mode == CUDAGraphMode.FULL_DECODE_ONLY or\ compilation_config.cudagraph_mode == CUDAGraphMode.FULL: logger.info( "FULL_DECODE_ONLY compilation enabled on NPU. use_inductor not supported - " "using only ACL Graph mode") compilation_config.use_inductor = False + compilation_config.splitting_ops = [] warning_message = """\033[91m ********************************************************************************** * WARNING: You have enabled the *full graph* feature. @@ -266,6 +269,7 @@ class NPUPlatform(Platform): compilation_config.cudagraph_mode) compilation_config.cudagraph_mode = CUDAGraphMode.NONE compilation_config.mode = CompilationMode.NONE + ascend_config.enable_npugraph_ex = False # TODO: Remove this check when ACL Graph supports ASCEND_LAUNCH_BLOCKING=1 # Then, we will have to discuss the error handling strategy and user experience diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py index b9e23334..bfb9a510 100644 --- a/vllm_ascend/worker/model_runner_v1.py +++ b/vllm_ascend/worker/model_runner_v1.py @@ -2575,6 +2575,12 @@ class NPUModelRunner(LoRAModelRunnerMixin, ECConnectorModelRunnerMixin): self.debugger.stop() self.debugger.step() return pool_output + # Sometimes, after the model is compiled through the AOT backend, + # the model output may become a list containing only one Tensor object. + if isinstance(hidden_states, list) and \ + len(hidden_states) == 1 and \ + isinstance(hidden_states[0], torch.Tensor): + hidden_states = hidden_states[0] sample_hidden_states = hidden_states[logits_indices] logits = self.model.compute_logits(sample_hidden_states) if broadcast_pp_output: @@ -3296,6 +3302,12 @@ class NPUModelRunner(LoRAModelRunnerMixin, ECConnectorModelRunnerMixin): dtype=np.int32) logit_indices = np.cumsum(num_scheduled_tokens) - 1 # TODO: need to rum a dummy sampler for generate task + # Sometimes, after the model is compiled through the AOT backend, + # the model output may become a list containing only one Tensor object. + if isinstance(hidden_states, list) and \ + len(hidden_states) == 1 and \ + isinstance(hidden_states[0], torch.Tensor): + hidden_states = hidden_states[0] hidden_states = hidden_states[logit_indices] output = self.model.compute_logits(hidden_states)