[BugFix][Fusion] Fix graph fusion failure problem (#5676)
Currently, the vllm pull request
(https://github.com/vllm-project/vllm/pull/24252) is causing operator
fusion to fail. This issue was previously fixed by patching the backend.
The root cause has been identified, and the problem can be resolved with
this pull request.
- vLLM version: v0.13.0
- vLLM main:
2f4e6548ef
---------
Signed-off-by: wxsIcey <1790571317@qq.com>
This commit is contained in:
@@ -17,6 +17,7 @@
|
||||
#
|
||||
|
||||
from torch import fx as fx
|
||||
from vllm.compilation.inductor_pass import get_pass_context
|
||||
from vllm.compilation.vllm_inductor_pass import VllmInductorPass
|
||||
from vllm.config import VllmConfig
|
||||
|
||||
@@ -32,10 +33,13 @@ class GraphFusionPassManager:
|
||||
def __init__(self):
    """Create a manager holding no fusion passes yet.

    Passes are registered via ``add`` and executed in registration
    order when the manager is called on a graph.
    """
    # Ordered registry of passes; populated later through add().
    self.passes: list[VllmInductorPass] = []
|
||||
|
||||
def __call__(self, graph: fx.Graph, runtime_shape) -> fx.Graph:
|
||||
def __call__(self, graph: fx.Graph) -> fx.Graph:
|
||||
compile_range = get_pass_context().compile_range
|
||||
|
||||
for pass_ in self.passes:
|
||||
if pass_.is_applicable(runtime_shape):
|
||||
if pass_.is_applicable_for_range(compile_range):
|
||||
pass_(graph)
|
||||
graph.recompile()
|
||||
return graph
|
||||
|
||||
def add(self, pass_: VllmInductorPass):
|
||||
|
||||
Reference in New Issue
Block a user