### What this PR does / why we need it? Revert PR 5253 to fix the smoking problem ### Does this PR introduce _any_ user-facing change? Does not. ### How was this patch tested? It was tested in the failure case. Signed-off-by: Rifa <865071616@qq.com>
This commit is contained in:
@@ -26,7 +26,6 @@ from torch._inductor.compile_fx import (graph_returns_tuple,
|
||||
from torch._inductor.decomposition import select_decomp_table
|
||||
from torch.fx import GraphModule
|
||||
from vllm.compilation.compiler_interface import CompilerInterface
|
||||
from vllm.config.utils import Range
|
||||
|
||||
from vllm_ascend.ascend_config import get_ascend_config
|
||||
|
||||
@@ -47,13 +46,13 @@ def fusion_pass_compile(
|
||||
graph: fx.GraphModule,
|
||||
example_inputs: list[Any],
|
||||
compiler_config: dict[str, Any],
|
||||
compile_range: Range,
|
||||
runtime_shape: Optional[int] = None,
|
||||
key: Optional[str] = None,
|
||||
) -> tuple[Optional[Callable], Optional[Any]]:
|
||||
|
||||
def compile_inner(graph, example_inputs):
|
||||
current_pass_manager = compiler_config["graph_fusion_manager"]
|
||||
graph = current_pass_manager(graph, compile_range)
|
||||
graph = current_pass_manager(graph, runtime_shape)
|
||||
return graph
|
||||
|
||||
decompositions = select_decomp_table()
|
||||
@@ -72,7 +71,7 @@ def npugraph_ex_compile(
|
||||
graph: fx.GraphModule,
|
||||
example_inputs: list[Any],
|
||||
compiler_config: dict[str, Any],
|
||||
compile_range: Range,
|
||||
runtime_shape: Optional[int] = None,
|
||||
key: Optional[str] = None,
|
||||
) -> tuple[Optional[Callable], Optional[Any]]:
|
||||
# When currently using the FULL_DECODE_ONLY mode,
|
||||
@@ -125,14 +124,14 @@ class AscendCompiler(CompilerInterface):
|
||||
graph: fx.GraphModule,
|
||||
example_inputs: list[Any],
|
||||
compiler_config: dict[str, Any],
|
||||
compile_range: Range,
|
||||
runtime_shape: Optional[int] = None,
|
||||
key: Optional[str] = None,
|
||||
) -> tuple[Optional[Callable], Optional[Any]]:
|
||||
|
||||
ascend_config = get_ascend_config()
|
||||
if ascend_config.enable_npugraph_ex:
|
||||
return npugraph_ex_compile(graph, example_inputs, compiler_config,
|
||||
compile_range, key)
|
||||
runtime_shape, key)
|
||||
else:
|
||||
return fusion_pass_compile(graph, example_inputs, compiler_config,
|
||||
compile_range, key)
|
||||
runtime_shape, key)
|
||||
|
||||
@@ -17,7 +17,6 @@
|
||||
#
|
||||
|
||||
from torch import fx as fx
|
||||
from vllm.compilation.inductor_pass import get_pass_context
|
||||
from vllm.compilation.vllm_inductor_pass import VllmInductorPass
|
||||
from vllm.config import VllmConfig
|
||||
|
||||
@@ -33,13 +32,10 @@ class GraphFusionPassManager:
|
||||
def __init__(self):
|
||||
self.passes: list[VllmInductorPass] = []
|
||||
|
||||
def __call__(self, graph: fx.Graph, compile_range) -> fx.Graph:
|
||||
compile_range = get_pass_context().compile_range
|
||||
|
||||
def __call__(self, graph: fx.Graph, runtime_shape) -> fx.Graph:
|
||||
for pass_ in self.passes:
|
||||
if pass_.is_applicable_for_range(compile_range):
|
||||
if pass_.is_applicable(runtime_shape):
|
||||
pass_(graph)
|
||||
graph.recompile()
|
||||
return graph
|
||||
|
||||
def add(self, pass_: VllmInductorPass):
|
||||
|
||||
@@ -20,7 +20,6 @@ import torch._inductor.pattern_matcher as pm
|
||||
from torch._inductor.pattern_matcher import PatternMatcherPass
|
||||
from vllm.compilation.vllm_inductor_pass import VllmInductorPass
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.config.compilation import Range
|
||||
from vllm.logger import logger
|
||||
|
||||
|
||||
@@ -309,7 +308,7 @@ class AddRMSNormQuantFusionPass(VllmInductorPass):
|
||||
logger.debug("Replaced %s patterns", self.matched_count)
|
||||
self.end_and_log()
|
||||
|
||||
def is_applicable_for_range(self, compile_range: Range) -> bool:
|
||||
def is_applicable(self, runtime_shape: int | None = None) -> bool:
|
||||
"""
|
||||
Check if the pass is applicable for the current configuration.
|
||||
"""
|
||||
|
||||
@@ -22,7 +22,6 @@ from torch._inductor.pattern_matcher import (PatternMatcherPass,
|
||||
from vllm.attention.layer import Attention
|
||||
from vllm.compilation.vllm_inductor_pass import VllmInductorPass
|
||||
from vllm.config import VllmConfig, get_layers_from_vllm_config
|
||||
from vllm.config.compilation import Range
|
||||
from vllm.logger import logger
|
||||
|
||||
|
||||
@@ -284,7 +283,7 @@ class QKNormRopeFusionPass(VllmInductorPass):
|
||||
pattern_idx += 1
|
||||
self.end_and_log()
|
||||
|
||||
def is_applicable_for_range(self, compile_range: Range) -> bool:
|
||||
def is_applicable(self, runtime_shape):
|
||||
"""
|
||||
Check if the pass is applicable for the current configuration.
|
||||
"""
|
||||
|
||||
Reference in New Issue
Block a user