[BugFix][Fusion] Fix graph fusion failure problem (#5676)

Currently, the vllm pull request
(https://github.com/vllm-project/vllm/pull/24252) is causing operator
fusion to fail. This issue was previously fixed by patching the backend.
The root cause has been identified, and the problem can be resolved with
this pull request.
- vLLM version: v0.13.0
- vLLM main:
2f4e6548ef

---------

Signed-off-by: wxsIcey <1790571317@qq.com>
This commit is contained in:
Icey
2026-01-07 18:42:55 +08:00
committed by GitHub
parent 137f28341d
commit b94fc13d3f
8 changed files with 37 additions and 265 deletions

View File

@@ -26,6 +26,7 @@ from torch._inductor.compile_fx import (graph_returns_tuple,
from torch._inductor.decomposition import select_decomp_table
from torch.fx import GraphModule
from vllm.compilation.compiler_interface import CompilerInterface
from vllm.config.utils import Range
from vllm_ascend.ascend_config import get_ascend_config
from vllm_ascend.utils import COMPILATION_PASS_KEY
@@ -47,13 +48,13 @@ def fusion_pass_compile(
graph: fx.GraphModule,
example_inputs: list[Any],
compiler_config: dict[str, Any],
runtime_shape: Optional[int] = None,
compile_range: Range,
key: Optional[str] = None,
) -> tuple[Optional[Callable], Optional[Any]]:
def compile_inner(graph, example_inputs):
current_pass_manager = compiler_config[COMPILATION_PASS_KEY]
graph = current_pass_manager(graph, runtime_shape)
graph = current_pass_manager(graph)
return graph
decompositions = select_decomp_table()
@@ -72,7 +73,7 @@ def npugraph_ex_compile(
graph: fx.GraphModule,
example_inputs: list[Any],
compiler_config: dict[str, Any],
runtime_shape: Optional[int] = None,
compile_range: Range,
key: Optional[str] = None,
) -> tuple[Optional[Callable], Optional[Any]]:
# When currently using the FULL_DECODE_ONLY mode,
@@ -125,14 +126,14 @@ class AscendCompiler(CompilerInterface):
graph: fx.GraphModule,
example_inputs: list[Any],
compiler_config: dict[str, Any],
runtime_shape: Optional[int] = None,
compile_range: Range,
key: Optional[str] = None,
) -> tuple[Optional[Callable], Optional[Any]]:
ascend_config = get_ascend_config()
if ascend_config.enable_npugraph_ex:
return npugraph_ex_compile(graph, example_inputs, compiler_config,
runtime_shape, key)
compile_range, key)
else:
return fusion_pass_compile(graph, example_inputs, compiler_config,
runtime_shape, key)
compile_range, key)

View File

@@ -17,6 +17,7 @@
#
from torch import fx as fx
from vllm.compilation.inductor_pass import get_pass_context
from vllm.compilation.vllm_inductor_pass import VllmInductorPass
from vllm.config import VllmConfig
@@ -32,10 +33,13 @@ class GraphFusionPassManager:
def __init__(self):
self.passes: list[VllmInductorPass] = []
def __call__(self, graph: fx.Graph, runtime_shape) -> fx.Graph:
def __call__(self, graph: fx.Graph) -> fx.Graph:
compile_range = get_pass_context().compile_range
for pass_ in self.passes:
if pass_.is_applicable(runtime_shape):
if pass_.is_applicable_for_range(compile_range):
pass_(graph)
graph.recompile()
return graph
def add(self, pass_: VllmInductorPass):

View File

@@ -20,6 +20,7 @@ import torch._inductor.pattern_matcher as pm
from torch._inductor.pattern_matcher import PatternMatcherPass
from vllm.compilation.vllm_inductor_pass import VllmInductorPass
from vllm.config import VllmConfig
from vllm.config.compilation import Range
from vllm.logger import logger
@@ -308,7 +309,7 @@ class AddRMSNormQuantFusionPass(VllmInductorPass):
logger.debug("Replaced %s patterns", self.matched_count)
self.end_and_log()
def is_applicable(self, runtime_shape: int | None = None) -> bool:
def is_applicable_for_range(self, compile_range: Range) -> bool:
"""
Check if the pass is applicable for the current configuration.
"""

View File

@@ -22,6 +22,7 @@ from torch._inductor.pattern_matcher import (PatternMatcherPass,
from vllm.attention.layer import Attention
from vllm.compilation.vllm_inductor_pass import VllmInductorPass
from vllm.config import VllmConfig, get_layers_from_vllm_config
from vllm.config.compilation import Range
from vllm.logger import logger
@@ -283,7 +284,7 @@ class QKNormRopeFusionPass(VllmInductorPass):
pattern_idx += 1
self.end_and_log()
def is_applicable(self, runtime_shape):
def is_applicable_for_range(self, compile_range: Range) -> bool:
"""
Check if the pass is applicable for the current configuration.
"""