From 38edfd585a3773398fa89a16b76a2f55c22d59c6 Mon Sep 17 00:00:00 2001 From: ChenCangtao <50493711+ChenCangtao@users.noreply.github.com> Date: Thu, 22 Jan 2026 12:35:06 +0800 Subject: [PATCH] [bugfix][npugraph_ex]fix the model output type issue caused by manually modify FX graph (#6015) ### What this PR does / why we need it? When using the full_decode_only mode, the vllm framework will still use the torch.fx.passes.split_module.split_module API to process the corresponding GraphModule of the model. However, the output of this API may cause the output of the fx graph to no longer be a tuple, and torch.compile enforces strict checks on this. Previously, we manually modified the fx graph, which introduced an abnormality in the model output type. In this PR, we switched to using PyTorch's native API to modify the FX graph, and removed the code that was previously added to handle output type anomalies. ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? - vLLM version: v0.13.0 - vLLM main: https://github.com/vllm-project/vllm/commit/2c24bc6996cb165fce92f780b388a5e39b3f4060 --------- Signed-off-by: chencangtao Co-authored-by: chencangtao --- vllm_ascend/compilation/compiler_interface.py | 20 ++++--------------- vllm_ascend/worker/model_runner_v1.py | 16 ++------------- 2 files changed, 6 insertions(+), 30 deletions(-) diff --git a/vllm_ascend/compilation/compiler_interface.py b/vllm_ascend/compilation/compiler_interface.py index b9917bb1..0c989539 100644 --- a/vllm_ascend/compilation/compiler_interface.py +++ b/vllm_ascend/compilation/compiler_interface.py @@ -74,20 +74,6 @@ def npugraph_ex_compile( compile_range: Range, key: str | None = None, ) -> tuple[Callable | None, Any | None]: - # When currently using the FULL_DECODE_ONLY mode, - # the piecewise compilation level slicing process - # in vllm is also encountered. - # This process causes the output to no longer be - # wrapped as a tuple when the fx graph has a single - # output, but torch.compile has a mandatory check. - fx_graph = graph.graph - if not graph_returns_tuple(graph): - output_node = fx_graph.output_node() - with fx_graph.inserting_before(output_node): - return_value = output_node.args[0] - tuple_node = fx_graph.create_node("call_function", tuple, args=([return_value],)) - output_node.args = (tuple_node,) - graph.recompile() import torchair # TODO: use a better way to lazy register replacement, instead of import one by one @@ -118,8 +104,10 @@ def npugraph_ex_compile( npugraph_ex = torchair.get_npu_backend(compiler_config=config) - compile_graph = npugraph_ex(graph, example_inputs) - return compile_graph, None + # torch.compile requires the output of the fx graph to be a tuple + if not graph_returns_tuple(graph): + return make_graph_return_tuple(graph, example_inputs, npugraph_ex), None + return npugraph_ex(graph, example_inputs), None class AscendCompiler(CompilerInterface): diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py index 135b9a9d..c8dab784 100644 --- a/vllm_ascend/worker/model_runner_v1.py +++ b/vllm_ascend/worker/model_runner_v1.py @@ -1586,12 +1586,6 @@ class NPUModelRunner(GPUModelRunner): self.debugger.stop() self.debugger.step() return pool_output - # Sometimes, after the model is compiled through the AOT backend, - # the model output may become a list containing only one Tensor object. - if isinstance(hidden_states, list) and \ - len(hidden_states) == 1 and \ - isinstance(hidden_states[0], torch.Tensor): - hidden_states = hidden_states[0] sample_hidden_states = hidden_states[logits_indices] logits = self.model.compute_logits(sample_hidden_states) if broadcast_pp_output: @@ -2300,14 +2294,8 @@ class NPUModelRunner(GPUModelRunner): dtype=np.int32) logit_indices = np.cumsum(num_scheduled_tokens) - 1 # TODO: need to rum a dummy sampler for generate task - # Sometimes, after the model is compiled through the AOT backend, - # the model output may become a list containing only one Tensor object. - if isinstance(hidden_states, list) and \ - len(hidden_states) == 1 and \ - isinstance(hidden_states[0], torch.Tensor): - hidden_states = hidden_states[0] - hidden_states = hidden_states[logit_indices] - output = self.model.compute_logits(hidden_states) + hidden_states = hidden_states[logit_indices] + output = self.model.compute_logits(hidden_states) return output def profile_run(self) -> None: