diff --git a/tests/ut/test_ascend_config.py b/tests/ut/test_ascend_config.py
index a92bbc80..f56cf23b 100644
--- a/tests/ut/test_ascend_config.py
+++ b/tests/ut/test_ascend_config.py
@@ -57,10 +57,21 @@ class TestAscendConfig(TestBase):
         ascend_config = init_ascend_config(test_vllm_config)
         self.assertEqual(ascend_config.expert_map_path, "test_expert_map_path")
         self.assertTrue(ascend_config.multistream_overlap_shared_expert)
+        self.assertFalse(ascend_config.enable_npugraph_ex)
         ascend_compilation_config = ascend_config.ascend_compilation_config
         self.assertFalse(ascend_compilation_config.enable_quantization_fusion)
 
+    @_clean_up_ascend_config
+    def test_init_ascend_config_enable_npugraph_ex(self):
+        with self.assertRaises(NotImplementedError):
+            test_vllm_config = VllmConfig()
+            test_vllm_config.additional_config = {
+                "enable_npugraph_ex": True,
+                "refresh": True,
+            }
+            init_ascend_config(test_vllm_config)
+
     @_clean_up_ascend_config
     def test_get_ascend_config(self):
         test_vllm_config = VllmConfig()
diff --git a/vllm_ascend/ascend_config.py b/vllm_ascend/ascend_config.py
index 54ef914a..d27263bb 100644
--- a/vllm_ascend/ascend_config.py
+++ b/vllm_ascend/ascend_config.py
@@ -169,6 +169,12 @@ class AscendConfig:
             get_flashcomm2_oproj_tp_size_and_validate_config
         self.flashcomm2_oproj_tensor_parallel_size = get_flashcomm2_oproj_tp_size_and_validate_config(
             self, vllm_config)
+        self.enable_npugraph_ex = additional_config.get(
+            "enable_npugraph_ex", False)
+        if self.enable_npugraph_ex:
+            raise NotImplementedError(
+                "This feature is still experimental and will be supported soon."
+            )
 
         kv_cfg = vllm_config.kv_transfer_config
         if kv_cfg is not None and not getattr(kv_cfg, "_engine_id_patched", False):
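Reviewer note: for context, a minimal usage sketch of the new knob. Today this path intentionally raises `NotImplementedError`; the model name is illustrative, and `additional_config` is the usual vllm-ascend plugin config entry point.

```python
from vllm import LLM

# Hypothetical usage once enable_npugraph_ex is actually supported;
# currently init_ascend_config() raises NotImplementedError by design.
llm = LLM(
    model="Qwen/Qwen2.5-7B-Instruct",  # illustrative
    additional_config={"enable_npugraph_ex": True},
)
```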
diff --git a/vllm_ascend/compilation/compiler_interface.py b/vllm_ascend/compilation/compiler_interface.py
index 23e07bb4..4bb7deae 100644
--- a/vllm_ascend/compilation/compiler_interface.py
+++ b/vllm_ascend/compilation/compiler_interface.py
@@ -18,6 +18,7 @@
 import functools
 from typing import Any, Callable, Optional
 
+import torch
 import torch.fx as fx
 from torch._dynamo.backends.common import aot_autograd
 from torch._inductor.compile_fx import (graph_returns_tuple,
@@ -26,6 +27,8 @@ from torch._inductor.decomposition import select_decomp_table
 from torch.fx import GraphModule
 from vllm.compilation.compiler_interface import CompilerInterface
 
+from vllm_ascend.ascend_config import get_ascend_config
+
 
 def compile_fx(graph: GraphModule, example_inputs: list,
                inner_compile: Callable, decompositions: dict) -> Callable:
@@ -39,6 +42,75 @@ def compile_fx(graph: GraphModule, example_inputs: list,
     return aot_autograd(fw_compiler=inner_compile)(graph, example_inputs)
 
 
+def fusion_pass_compile(
+    graph: fx.GraphModule,
+    example_inputs: list[Any],
+    compiler_config: dict[str, Any],
+    runtime_shape: Optional[int] = None,
+    key: Optional[str] = None,
+) -> tuple[Optional[Callable], Optional[Any]]:
+
+    def compile_inner(graph, example_inputs):
+        current_pass_manager = compiler_config["graph_fusion_manager"]
+        graph = current_pass_manager(graph, runtime_shape)
+        return graph
+
+    decompositions = select_decomp_table()
+
+    compiled_fn = compile_fx(
+        graph=graph,
+        example_inputs=example_inputs,
+        inner_compile=compile_inner,
+        decompositions=decompositions,
+    )
+
+    return compiled_fn, None
+
+
+def npugraph_ex_compile(
+    graph: fx.GraphModule,
+    example_inputs: list[Any],
+    compiler_config: dict[str, Any],
+    runtime_shape: Optional[int] = None,
+    key: Optional[str] = None,
+) -> tuple[Optional[Callable], Optional[Any]]:
+    # In FULL_DECODE_ONLY mode, vLLM still runs its piecewise
+    # compilation splitting pass. After that pass, an fx graph with a
+    # single output is no longer wrapped in a tuple, but torch.compile
+    # requires compiled graphs to return tuples, so re-wrap the output
+    # here before handing the graph to torchair.
+    fx_graph = graph.graph
+    if not graph_returns_tuple(graph):
+        output_node = fx_graph.output_node()
+        with fx_graph.inserting_before(output_node):
+            return_value = output_node.args[0]
+            tuple_node = fx_graph.create_node("call_function",
+                                              tuple,
+                                              args=([return_value], ))
+            output_node.args = (tuple_node, )
+        graph.recompile()
+
+    import torchair
+
+    # TODO: find a better way to lazily register replacements instead of
+    # importing the pass modules one by one.
+    import vllm_ascend.compilation.npugraph_ex_passes.add_rms_norm_quant  # noqa
+
+    torch.npu.set_compile_mode(jit_compile=False)
+    config = torchair.CompilerConfig()
+    # Use aclgraph mode to avoid lowering the fx graph to Ascend IR.
+    config.mode = "reduce-overhead"
+    # Run the fx graph eagerly before graph mode so it can be optimized first.
+    config.debug.run_eagerly = True
+    # Enable static kernels; suitable for static shapes or workloads with
+    # few shape changes.
+    config.experimental_config.aclgraph._aclnn_static_shape_kernel = True
+
+    npugraph_ex = torchair.get_npu_backend(compiler_config=config)
+    compile_graph = npugraph_ex(graph, example_inputs)
+    return compile_graph, None
+
+
 class AscendCompiler(CompilerInterface):
     """
     AscendCompiler is a custom compiler interface for the Ascend platform.
@@ -56,18 +128,10 @@ class AscendCompiler(CompilerInterface):
         key: Optional[str] = None,
     ) -> tuple[Optional[Callable], Optional[Any]]:
 
-        def compile_inner(graph, example_inputs):
-            current_pass_manager = compiler_config["graph_fusion_manager"]
-            graph = current_pass_manager(graph, runtime_shape)
-            return graph
-
-        decompositions = select_decomp_table()
-
-        compiled_fn = compile_fx(
-            graph=graph,
-            example_inputs=example_inputs,
-            inner_compile=compile_inner,
-            decompositions=decompositions,
-        )
-
-        return compiled_fn, None
+        ascend_config = get_ascend_config()
+        if ascend_config.enable_npugraph_ex:
+            return npugraph_ex_compile(graph, example_inputs, compiler_config,
+                                       runtime_shape, key)
+        else:
+            return fusion_pass_compile(graph, example_inputs, compiler_config,
+                                       runtime_shape, key)
diff --git a/vllm_ascend/compilation/npugraph_ex_passes/__init__.py b/vllm_ascend/compilation/npugraph_ex_passes/__init__.py
new file mode 100644
index 00000000..e69de29b
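Reviewer note: the tuple re-wrapping in `npugraph_ex_compile` can be exercised on CPU without torchair or an NPU. A minimal sketch of the same normalization (`Single` is a toy module, not part of this PR):

```python
import torch
import torch.fx as fx
from torch._inductor.compile_fx import graph_returns_tuple


class Single(torch.nn.Module):

    def forward(self, x):
        return x + 1  # single, un-wrapped output


gm = fx.symbolic_trace(Single())
assert not graph_returns_tuple(gm)

# Same normalization as npugraph_ex_compile: wrap the lone output in a
# tuple so torch.compile's "graph returns tuple" invariant holds.
g = gm.graph
output_node = g.output_node()
with g.inserting_before(output_node):
    tuple_node = g.create_node("call_function",
                               tuple,
                               args=([output_node.args[0]], ))
output_node.args = (tuple_node, )
gm.recompile()

out = gm(torch.ones(2))
assert isinstance(out, tuple) and torch.equal(out[0], torch.ones(2) + 1)
```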
diff --git a/vllm_ascend/compilation/npugraph_ex_passes/add_rms_norm_quant.py b/vllm_ascend/compilation/npugraph_ex_passes/add_rms_norm_quant.py
new file mode 100644
index 00000000..724d8140
--- /dev/null
+++ b/vllm_ascend/compilation/npugraph_ex_passes/add_rms_norm_quant.py
@@ -0,0 +1,123 @@
+#
+# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
+# This file is a part of the vllm-ascend project.
+#
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+import functools
+import sys
+
+import torch
+from torch._inductor.pattern_matcher import Match
+from vllm.logger import logger
+
+
+# The replacements registered here are actually applied after AOT tracing.
+@functools.lru_cache(None)
+def _register_replacement(epsilon):
+    if 'torch_npu' not in sys.modules:
+        logger.info(
+            'The AddRMSNormQuant fusion is only enabled in a torch_npu '
+            'environment. torch_npu is not available, so the fusion is skipped.')
+        return
+
+    def _extra_stream_scope_check(match: Match) -> bool:
+        """
+        Checks that all matched nodes are on the same stream.
+        """
+        non_default_streams = set()
+        has_default = False
+
+        for node in match.nodes:
+            if node.op == "call_function":
+                current_stream = node.meta.get("stream_label")
+                if current_stream is None:
+                    has_default = True
+                else:
+                    non_default_streams.add(current_stream)
+
+        if len(non_default_streams) > 1:
+            logger.debug(
+                f"Cross-stream operation detected in pattern match for "
+                f"AddRMSNormQuant. Multiple streams found: "
+                f"{non_default_streams}. Fusion is not supported for "
+                f"cross-stream operations.")
+            return False
+
+        if has_default and len(non_default_streams) > 0:
+            logger.debug(
+                f"Cross-stream operation detected in pattern match for "
+                f"AddRMSNormQuant. Found both default-stream nodes and "
+                f"non-default streams: {non_default_streams}. Fusion is not "
+                f"supported for cross-stream operations.")
+            return False
+
+        return True
+
+    def pattern(rms_norm_input: torch.Tensor, residual: torch.Tensor,
+                rms_norm_weight: torch.Tensor, scale: torch.Tensor,
+                offset: torch.Tensor):
+        """
+        Pattern for the AddRMSNormQuant fusion.
+        """
+        output = torch.ops.npu.npu_add_rms_norm(rms_norm_input, residual,
+                                                rms_norm_weight, epsilon)
+        out0 = output[0]
+        out1 = output[2]
+        quantized_output = torch.ops.npu.npu_quantize(out0, scale, offset,
+                                                      torch.qint8, -1, False)
+        return quantized_output, out1
+
+    def replacement(rms_norm_input: torch.Tensor, residual: torch.Tensor,
+                    rms_norm_weight: torch.Tensor, scale: torch.Tensor,
+                    offset: torch.Tensor):
+        """
+        Replacement for the AddRMSNormQuant fusion.
+        """
+        output = torch.ops.npu.npu_add_rms_norm_quant(
+            rms_norm_input,
+            residual,
+            rms_norm_weight,
+            # The npu_add_rms_norm_quant kernel expects the inverse of the
+            # scale used by the npu_quantize kernel.
+            1. / scale,
+            offset,
+            epsilon=epsilon)
+        quantized_output = output[0]
+        out1 = output[2]
+        return quantized_output, out1
+
+    def get_inputs():
+        """
+        Generate example inputs for the AddRMSNormQuant fusion pattern.
+        """
+        rms_norm_input = torch.randn(2, 4, device="npu")
+        residual = torch.randn(2, 4, device="npu")
+        rms_norm_weight = torch.randn(4, device="npu")
+        scale = torch.tensor([1.0], device="npu")
+        offset = torch.tensor([0.0], device="npu")
+        return [rms_norm_input, residual, rms_norm_weight, scale, offset]
+
+    import torchair
+
+    torchair.register_replacement(search_fn=pattern,
+                                  replace_fn=replacement,
+                                  example_inputs=get_inputs(),
+                                  extra_check=_extra_stream_scope_check)
+
+
+# Register the fusion pattern for common epsilon values.
+common_epsilons = [1e-5, 1e-6]
+for eps in common_epsilons:
+    logger.info(
+        f"Registering fusion pattern for AddRMSNormQuant with epsilon={eps}")
+    _register_replacement(eps)
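Reviewer note: the eager-level identity this pattern/replacement pair encodes, as a standalone sketch. It assumes an Ascend device with torch_npu installed and the op signatures exactly as used in this file; shapes are illustrative.

```python
import torch
import torch_npu  # noqa: F401  # registers torch.ops.npu.*

eps = 1e-6
x = torch.randn(2, 4, device="npu")
residual = torch.randn(2, 4, device="npu")
weight = torch.randn(4, device="npu")
scale = torch.tensor([1.0], device="npu")
offset = torch.tensor([0.0], device="npu")

# Unfused (the "pattern" side): add + rmsnorm, then quantize.
out = torch.ops.npu.npu_add_rms_norm(x, residual, weight, eps)
q_ref = torch.ops.npu.npu_quantize(out[0], scale, offset, torch.qint8, -1,
                                   False)
res_ref = out[2]

# Fused (the "replacement" side); note the inverted scale.
fused = torch.ops.npu.npu_add_rms_norm_quant(x, residual, weight, 1. / scale,
                                             offset, epsilon=eps)
q_fused, res_fused = fused[0], fused[2]
# q_fused/res_fused should match q_ref/res_ref up to kernel rounding.
```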
+ """ + rms_norm_input = torch.randn(2, 4, device="npu") + residual = torch.randn(2, 4, device="npu") + rms_norm_weight = torch.randn(4, device="npu") + scale = torch.tensor([1.0], device="npu") + offset = torch.tensor([0.0], device="npu") + return [rms_norm_input, residual, rms_norm_weight, scale, offset] + + import torchair + + torchair.register_replacement(search_fn=pattern, + replace_fn=replacement, + example_inputs=get_inputs(), + extra_check=_extra_stream_scope_check) + + +# register converter for pass +common_epsilons = [1e-5, 1e-6] +for eps in common_epsilons: + logger.info( + f"Start register fusion pattern for AddRMSNormQuant with epsilons={eps}" + ) + _register_replacement(eps) diff --git a/vllm_ascend/platform.py b/vllm_ascend/platform.py index c9abc8ff..8a508acf 100644 --- a/vllm_ascend/platform.py +++ b/vllm_ascend/platform.py @@ -231,6 +231,7 @@ class NPUPlatform(Platform): if compilation_config.cudagraph_mode == CUDAGraphMode.NONE: compilation_config.mode = CompilationMode.NONE + ascend_config.enable_npugraph_ex = False elif compilation_config.cudagraph_mode == CUDAGraphMode.PIECEWISE: logger.info( "PIECEWISE compilation enabled on NPU. use_inductor not supported - " @@ -241,12 +242,14 @@ class NPUPlatform(Platform): compilation_config.use_inductor = False compilation_config.splitting_ops.extend(["vllm::mla_forward"]) update_aclgraph_sizes(vllm_config) + ascend_config.enable_npugraph_ex = False elif compilation_config.cudagraph_mode == CUDAGraphMode.FULL_DECODE_ONLY or\ compilation_config.cudagraph_mode == CUDAGraphMode.FULL: logger.info( "FULL_DECODE_ONLY compilation enabled on NPU. use_inductor not supported - " "using only ACL Graph mode") compilation_config.use_inductor = False + compilation_config.splitting_ops = [] warning_message = """\033[91m ********************************************************************************** * WARNING: You have enabled the *full graph* feature. @@ -266,6 +269,7 @@ class NPUPlatform(Platform): compilation_config.cudagraph_mode) compilation_config.cudagraph_mode = CUDAGraphMode.NONE compilation_config.mode = CompilationMode.NONE + ascend_config.enable_npugraph_ex = False # TODO: Remove this check when ACL Graph supports ASCEND_LAUNCH_BLOCKING=1 # Then, we will have to discuss the error handling strategy and user experience diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py index b9e23334..bfb9a510 100644 --- a/vllm_ascend/worker/model_runner_v1.py +++ b/vllm_ascend/worker/model_runner_v1.py @@ -2575,6 +2575,12 @@ class NPUModelRunner(LoRAModelRunnerMixin, ECConnectorModelRunnerMixin): self.debugger.stop() self.debugger.step() return pool_output + # Sometimes, after the model is compiled through the AOT backend, + # the model output may become a list containing only one Tensor object. + if isinstance(hidden_states, list) and \ + len(hidden_states) == 1 and \ + isinstance(hidden_states[0], torch.Tensor): + hidden_states = hidden_states[0] sample_hidden_states = hidden_states[logits_indices] logits = self.model.compute_logits(sample_hidden_states) if broadcast_pp_output: @@ -3296,6 +3302,12 @@ class NPUModelRunner(LoRAModelRunnerMixin, ECConnectorModelRunnerMixin): dtype=np.int32) logit_indices = np.cumsum(num_scheduled_tokens) - 1 # TODO: need to rum a dummy sampler for generate task + # Sometimes, after the model is compiled through the AOT backend, + # the model output may become a list containing only one Tensor object. 
diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py
index b9e23334..bfb9a510 100644
--- a/vllm_ascend/worker/model_runner_v1.py
+++ b/vllm_ascend/worker/model_runner_v1.py
@@ -2575,6 +2575,12 @@ class NPUModelRunner(LoRAModelRunnerMixin, ECConnectorModelRunnerMixin):
                 self.debugger.stop()
                 self.debugger.step()
             return pool_output
+        # After the model is compiled through the AOT backend, the model
+        # output may be a list containing a single Tensor.
+        if isinstance(hidden_states, list) and \
+                len(hidden_states) == 1 and \
+                isinstance(hidden_states[0], torch.Tensor):
+            hidden_states = hidden_states[0]
         sample_hidden_states = hidden_states[logits_indices]
         logits = self.model.compute_logits(sample_hidden_states)
         if broadcast_pp_output:
@@ -3296,6 +3302,12 @@ class NPUModelRunner(LoRAModelRunnerMixin, ECConnectorModelRunnerMixin):
                                        dtype=np.int32)
         logit_indices = np.cumsum(num_scheduled_tokens) - 1
         # TODO: need to run a dummy sampler for generate task
+        # After the model is compiled through the AOT backend, the model
+        # output may be a list containing a single Tensor.
+        if isinstance(hidden_states, list) and \
+                len(hidden_states) == 1 and \
+                isinstance(hidden_states[0], torch.Tensor):
+            hidden_states = hidden_states[0]
         hidden_states = hidden_states[logit_indices]
         output = self.model.compute_logits(hidden_states)
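Reviewer note: the single-Tensor unwrap now appears twice in model_runner_v1.py; it could be factored into a small helper. A sketch (`_maybe_unwrap_hidden_states` is a hypothetical name, not in this PR):

```python
import torch


def _maybe_unwrap_hidden_states(hidden_states):
    """AOT-compiled graphs may return [tensor] instead of a bare tensor."""
    if (isinstance(hidden_states, list) and len(hidden_states) == 1
            and isinstance(hidden_states[0], torch.Tensor)):
        return hidden_states[0]
    return hidden_states
```

Both call sites would then reduce to `hidden_states = _maybe_unwrap_hidden_states(hidden_states)`.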