[bugfix][npugraph_ex] fix the model output type issue caused by manually modifying the FX graph (#6015)
### What this PR does / why we need it?
When using the full_decode_only mode, the vllm framework will still use
the torch.fx.passes.split_module.split_module API to process the
corresponding GraphModule of the model.
However, the output of this API may cause the output of the fx graph to
no longer be a tuple, and torch.compile enforces strict checks on this.
Previously, we manually modified the fx graph, which introduced an
abnormality in the model output type.
In this PR, we switched to using PyTorch's native API to modify the FX
graph, and removed the code that was previously added to handle output
type anomalies.
### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?
- vLLM version: v0.13.0
- vLLM main:
2c24bc6996
---------
Signed-off-by: chencangtao <chencangtao@huawei.com>
Co-authored-by: chencangtao <chencangtao@huawei.com>
This commit is contained in:
@@ -74,20 +74,6 @@ def npugraph_ex_compile(
|
|||||||
compile_range: Range,
|
compile_range: Range,
|
||||||
key: str | None = None,
|
key: str | None = None,
|
||||||
) -> tuple[Callable | None, Any | None]:
|
) -> tuple[Callable | None, Any | None]:
|
||||||
# When currently using the FULL_DECODE_ONLY mode,
|
|
||||||
# the piecewise compilation level slicing process
|
|
||||||
# in vllm is also encountered.
|
|
||||||
# This process causes the output to no longer be
|
|
||||||
# wrapped as a tuple when the fx graph has a single
|
|
||||||
# output, but torch.compile has a mandatory check.
|
|
||||||
fx_graph = graph.graph
|
|
||||||
if not graph_returns_tuple(graph):
|
|
||||||
output_node = fx_graph.output_node()
|
|
||||||
with fx_graph.inserting_before(output_node):
|
|
||||||
return_value = output_node.args[0]
|
|
||||||
tuple_node = fx_graph.create_node("call_function", tuple, args=([return_value],))
|
|
||||||
output_node.args = (tuple_node,)
|
|
||||||
graph.recompile()
|
|
||||||
import torchair
|
import torchair
|
||||||
|
|
||||||
# TODO: use a better way to lazy register replacement, instead of import one by one
|
# TODO: use a better way to lazy register replacement, instead of import one by one
|
||||||
@@ -118,8 +104,10 @@ def npugraph_ex_compile(
|
|||||||
|
|
||||||
npugraph_ex = torchair.get_npu_backend(compiler_config=config)
|
npugraph_ex = torchair.get_npu_backend(compiler_config=config)
|
||||||
|
|
||||||
compile_graph = npugraph_ex(graph, example_inputs)
|
# torch.compile requires the output of the fx graph to be a tuple
|
||||||
return compile_graph, None
|
if not graph_returns_tuple(graph):
|
||||||
|
return make_graph_return_tuple(graph, example_inputs, npugraph_ex), None
|
||||||
|
return npugraph_ex(graph, example_inputs), None
|
||||||
|
|
||||||
|
|
||||||
class AscendCompiler(CompilerInterface):
|
class AscendCompiler(CompilerInterface):
|
||||||
|
|||||||
@@ -1586,12 +1586,6 @@ class NPUModelRunner(GPUModelRunner):
|
|||||||
self.debugger.stop()
|
self.debugger.stop()
|
||||||
self.debugger.step()
|
self.debugger.step()
|
||||||
return pool_output
|
return pool_output
|
||||||
# Sometimes, after the model is compiled through the AOT backend,
|
|
||||||
# the model output may become a list containing only one Tensor object.
|
|
||||||
if isinstance(hidden_states, list) and \
|
|
||||||
len(hidden_states) == 1 and \
|
|
||||||
isinstance(hidden_states[0], torch.Tensor):
|
|
||||||
hidden_states = hidden_states[0]
|
|
||||||
sample_hidden_states = hidden_states[logits_indices]
|
sample_hidden_states = hidden_states[logits_indices]
|
||||||
logits = self.model.compute_logits(sample_hidden_states)
|
logits = self.model.compute_logits(sample_hidden_states)
|
||||||
if broadcast_pp_output:
|
if broadcast_pp_output:
|
||||||
@@ -2300,12 +2294,6 @@ class NPUModelRunner(GPUModelRunner):
|
|||||||
dtype=np.int32)
|
dtype=np.int32)
|
||||||
logit_indices = np.cumsum(num_scheduled_tokens) - 1
|
logit_indices = np.cumsum(num_scheduled_tokens) - 1
|
||||||
# TODO: need to run a dummy sampler for generate task
|
# TODO: need to run a dummy sampler for generate task
|
||||||
# Sometimes, after the model is compiled through the AOT backend,
|
|
||||||
# the model output may become a list containing only one Tensor object.
|
|
||||||
if isinstance(hidden_states, list) and \
|
|
||||||
len(hidden_states) == 1 and \
|
|
||||||
isinstance(hidden_states[0], torch.Tensor):
|
|
||||||
hidden_states = hidden_states[0]
|
|
||||||
hidden_states = hidden_states[logit_indices]
|
hidden_states = hidden_states[logit_indices]
|
||||||
output = self.model.compute_logits(hidden_states)
|
output = self.model.compute_logits(hidden_states)
|
||||||
return output
|
return output
|
||||||
|
|||||||
Reference in New Issue
Block a user