[bugfix][npugraph_ex]fix the model output type issue caused by manually modify FX graph (#6015)

### What this PR does / why we need it?

When using the full_decode_only mode, the vllm framework will still use
the torch.fx.passes.split_module.split_module API to process the
corresponding GraphModule of the model.
However, the output of this API may cause the output of the fx graph to
no longer be a tuple, and torch.compile enforces strict checks on this.
Previously, we manually modified the fx graph, which introduced an
abnormality in the model output type.
In this PR, we switched to using PyTorch's native API to modify the FX
graph, and removed the code that was previously added to handle output
type anomalies.

### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?

- vLLM version: v0.13.0
- vLLM main:
2c24bc6996

---------

Signed-off-by: chencangtao <chencangtao@huawei.com>
Co-authored-by: chencangtao <chencangtao@huawei.com>
This commit is contained in:
ChenCangtao
2026-01-22 12:35:06 +08:00
committed by GitHub
parent 34fb628248
commit 38edfd585a
2 changed files with 6 additions and 30 deletions

View File

@@ -1586,12 +1586,6 @@ class NPUModelRunner(GPUModelRunner):
self.debugger.stop()
self.debugger.step()
return pool_output
# Sometimes, after the model is compiled through the AOT backend,
# the model output may become a list containing only one Tensor object.
if isinstance(hidden_states, list) and \
len(hidden_states) == 1 and \
isinstance(hidden_states[0], torch.Tensor):
hidden_states = hidden_states[0]
sample_hidden_states = hidden_states[logits_indices]
logits = self.model.compute_logits(sample_hidden_states)
if broadcast_pp_output:
@@ -2300,14 +2294,8 @@ class NPUModelRunner(GPUModelRunner):
dtype=np.int32)
logit_indices = np.cumsum(num_scheduled_tokens) - 1
# TODO: need to rum a dummy sampler for generate task
# Sometimes, after the model is compiled through the AOT backend,
# the model output may become a list containing only one Tensor object.
if isinstance(hidden_states, list) and \
len(hidden_states) == 1 and \
isinstance(hidden_states[0], torch.Tensor):
hidden_states = hidden_states[0]
hidden_states = hidden_states[logit_indices]
output = self.model.compute_logits(hidden_states)
hidden_states = hidden_states[logit_indices]
output = self.model.compute_logits(hidden_states)
return output
def profile_run(self) -> None: