Revert "[BugFix][Fusion] Fix graph fusion failure problem (#5253)" (#5667)

### What this PR does / why we need it?

Revert PR 5253 to fix the smoking problem

### Does this PR introduce _any_ user-facing change?

Does not.

### How was this patch tested?

It was tested in the failure case.

Signed-off-by: Rifa <865071616@qq.com>
This commit is contained in:
Fager10086
2026-01-06 21:55:47 +08:00
committed by GitHub
parent 330e25ab1d
commit 77a029979e
9 changed files with 267 additions and 36 deletions

View File

@@ -26,7 +26,6 @@ from torch._inductor.compile_fx import (graph_returns_tuple,
from torch._inductor.decomposition import select_decomp_table
from torch.fx import GraphModule
from vllm.compilation.compiler_interface import CompilerInterface
from vllm.config.utils import Range
from vllm_ascend.ascend_config import get_ascend_config
@@ -47,13 +46,13 @@ def fusion_pass_compile(
graph: fx.GraphModule,
example_inputs: list[Any],
compiler_config: dict[str, Any],
compile_range: Range,
runtime_shape: Optional[int] = None,
key: Optional[str] = None,
) -> tuple[Optional[Callable], Optional[Any]]:
def compile_inner(graph, example_inputs):
current_pass_manager = compiler_config["graph_fusion_manager"]
graph = current_pass_manager(graph, compile_range)
graph = current_pass_manager(graph, runtime_shape)
return graph
decompositions = select_decomp_table()
@@ -72,7 +71,7 @@ def npugraph_ex_compile(
graph: fx.GraphModule,
example_inputs: list[Any],
compiler_config: dict[str, Any],
compile_range: Range,
runtime_shape: Optional[int] = None,
key: Optional[str] = None,
) -> tuple[Optional[Callable], Optional[Any]]:
# When currently using the FULL_DECODE_ONLY mode,
@@ -125,14 +124,14 @@ class AscendCompiler(CompilerInterface):
graph: fx.GraphModule,
example_inputs: list[Any],
compiler_config: dict[str, Any],
compile_range: Range,
runtime_shape: Optional[int] = None,
key: Optional[str] = None,
) -> tuple[Optional[Callable], Optional[Any]]:
ascend_config = get_ascend_config()
if ascend_config.enable_npugraph_ex:
return npugraph_ex_compile(graph, example_inputs, compiler_config,
compile_range, key)
runtime_shape, key)
else:
return fusion_pass_compile(graph, example_inputs, compiler_config,
compile_range, key)
runtime_shape, key)