Revert "[BugFix][Fusion] Fix graph fusion failure problem (#5253)" (#5667)

### What this PR does / why we need it?

Revert PR 5253 to fix the smoking problem

### Does this PR introduce _any_ user-facing change?

Does not.

### How was this patch tested?

It was tested in the failure case.

Signed-off-by: Rifa <865071616@qq.com>
This commit is contained in:
Fager10086
2026-01-06 21:55:47 +08:00
committed by GitHub
parent 330e25ab1d
commit 77a029979e
9 changed files with 267 additions and 36 deletions

View File

@@ -1020,6 +1020,8 @@ class TestNPUWorker(TestBase):
# Verify _dummy_run call count and order (by size descending)
expected_calls = [
unittest.mock.call(16),
unittest.mock.call(8),
unittest.mock.call(4),
unittest.mock.call(1),
]
worker.model_runner._dummy_run.assert_has_calls(expected_calls)
@@ -1028,7 +1030,7 @@ class TestNPUWorker(TestBase):
worker.model_runner.capture_model.assert_not_called()
# Verify log output
self.assertEqual(mock_logger.info.call_count, 2)
self.assertEqual(mock_logger.info.call_count, 4)
# Verify seed setting
mock_seed_everything.assert_called_once_with(12345)

View File

@@ -26,7 +26,6 @@ from torch._inductor.compile_fx import (graph_returns_tuple,
from torch._inductor.decomposition import select_decomp_table
from torch.fx import GraphModule
from vllm.compilation.compiler_interface import CompilerInterface
from vllm.config.utils import Range
from vllm_ascend.ascend_config import get_ascend_config
@@ -47,13 +46,13 @@ def fusion_pass_compile(
graph: fx.GraphModule,
example_inputs: list[Any],
compiler_config: dict[str, Any],
compile_range: Range,
runtime_shape: Optional[int] = None,
key: Optional[str] = None,
) -> tuple[Optional[Callable], Optional[Any]]:
def compile_inner(graph, example_inputs):
current_pass_manager = compiler_config["graph_fusion_manager"]
graph = current_pass_manager(graph, compile_range)
graph = current_pass_manager(graph, runtime_shape)
return graph
decompositions = select_decomp_table()
@@ -72,7 +71,7 @@ def npugraph_ex_compile(
graph: fx.GraphModule,
example_inputs: list[Any],
compiler_config: dict[str, Any],
compile_range: Range,
runtime_shape: Optional[int] = None,
key: Optional[str] = None,
) -> tuple[Optional[Callable], Optional[Any]]:
# When currently using the FULL_DECODE_ONLY mode,
@@ -125,14 +124,14 @@ class AscendCompiler(CompilerInterface):
graph: fx.GraphModule,
example_inputs: list[Any],
compiler_config: dict[str, Any],
compile_range: Range,
runtime_shape: Optional[int] = None,
key: Optional[str] = None,
) -> tuple[Optional[Callable], Optional[Any]]:
ascend_config = get_ascend_config()
if ascend_config.enable_npugraph_ex:
return npugraph_ex_compile(graph, example_inputs, compiler_config,
compile_range, key)
runtime_shape, key)
else:
return fusion_pass_compile(graph, example_inputs, compiler_config,
compile_range, key)
runtime_shape, key)

View File

@@ -17,7 +17,6 @@
#
from torch import fx as fx
from vllm.compilation.inductor_pass import get_pass_context
from vllm.compilation.vllm_inductor_pass import VllmInductorPass
from vllm.config import VllmConfig
@@ -33,13 +32,10 @@ class GraphFusionPassManager:
def __init__(self):
self.passes: list[VllmInductorPass] = []
def __call__(self, graph: fx.Graph, compile_range) -> fx.Graph:
compile_range = get_pass_context().compile_range
def __call__(self, graph: fx.Graph, runtime_shape) -> fx.Graph:
for pass_ in self.passes:
if pass_.is_applicable_for_range(compile_range):
if pass_.is_applicable(runtime_shape):
pass_(graph)
graph.recompile()
return graph
def add(self, pass_: VllmInductorPass):

View File

@@ -20,7 +20,6 @@ import torch._inductor.pattern_matcher as pm
from torch._inductor.pattern_matcher import PatternMatcherPass
from vllm.compilation.vllm_inductor_pass import VllmInductorPass
from vllm.config import VllmConfig
from vllm.config.compilation import Range
from vllm.logger import logger
@@ -309,7 +308,7 @@ class AddRMSNormQuantFusionPass(VllmInductorPass):
logger.debug("Replaced %s patterns", self.matched_count)
self.end_and_log()
def is_applicable_for_range(self, compile_range: Range) -> bool:
def is_applicable(self, runtime_shape: int | None = None) -> bool:
"""
Check if the pass is applicable for the current configuration.
"""

View File

@@ -22,7 +22,6 @@ from torch._inductor.pattern_matcher import (PatternMatcherPass,
from vllm.attention.layer import Attention
from vllm.compilation.vllm_inductor_pass import VllmInductorPass
from vllm.config import VllmConfig, get_layers_from_vllm_config
from vllm.config.compilation import Range
from vllm.logger import logger
@@ -284,7 +283,7 @@ class QKNormRopeFusionPass(VllmInductorPass):
pattern_idx += 1
self.end_and_log()
def is_applicable_for_range(self, compile_range: Range) -> bool:
def is_applicable(self, runtime_shape):
"""
Check if the pass is applicable for the current configuration.
"""

View File

@@ -106,6 +106,20 @@
# Future Plan:
# Remove this patch when vLLM merge the PR.
#
# ** 7. File: platform/patch_compile_backend.py**
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# 1. `vllm.compilation.backends.PiecewiseCompileInterpreter`
# `vllm.compilation.piecewise_backend.PiecewiseBackend`
# Why:
# vllm removed the compile graph for general shape, which caused operator fusion to fail.
# This issue affects the performance of model inference on Ascend.
# How
# recover the compiled graph for dynamic_shape in PiecewiseBackend.
# Related PR (if no, explain why):
# https://github.com/vllm-project/vllm/pull/24252
# Future Plan:
# Remove this patch when fix the problem.
#
# * Worker Patch:
# ===============
#

View File

@@ -16,6 +16,7 @@
import os
import vllm_ascend.patch.platform.patch_compile_backend # noqa
import vllm_ascend.patch.platform.patch_distributed # noqa
import vllm_ascend.patch.platform.patch_ec_connector # noqa
import vllm_ascend.patch.platform.patch_mamba_config # noqa

View File

@@ -0,0 +1,235 @@
from collections.abc import Callable
from typing import Any
import torch
import torch.fx as fx
import vllm.compilation.backends
import vllm.compilation.piecewise_backend
from torch._dispatch.python import enable_python_dispatcher
from vllm.compilation.backends import VllmBackend
from vllm.compilation.counter import compilation_counter
from vllm.compilation.piecewise_backend import RangeEntry
from vllm.config import CUDAGraphMode, VllmConfig
from vllm.config.utils import Range
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.utils.import_utils import resolve_obj_by_qualname
logger = init_logger(__name__)
class AscendPiecewiseCompileInterpreter(torch.fx.Interpreter):
"""Code adapted from `torch.fx.passes.shape_prop.ShapeProp`.
It runs the given graph with fake inputs, and compile some
submodules specified by `compile_submod_names` with the given
compilation configs.
NOTE: the order in `compile_submod_names` matters, because
it will be used to determine the order of the compiled piecewise
graphs. The first graph will handle logging, and the last graph
has some special cudagraph output handling.
"""
def __init__(
self,
module: torch.fx.GraphModule,
compile_submod_names: list[str],
vllm_config: VllmConfig,
vllm_backend: "VllmBackend",
):
super().__init__(module)
from torch._guards import detect_fake_mode
self.fake_mode = detect_fake_mode()
self.compile_submod_names = compile_submod_names
self.compilation_config = vllm_config.compilation_config
self.vllm_config = vllm_config
self.vllm_backend = vllm_backend
# When True, it annoyingly dumps the torch.fx.Graph on errors.
self.extra_traceback = False
def run(self, *args):
# maybe instead just assert inputs are fake?
fake_args = [
self.fake_mode.from_tensor(t) if isinstance(t, torch.Tensor) else t
for t in args
]
with self.fake_mode, enable_python_dispatcher():
return super().run(*fake_args)
def call_module(
self,
target: torch.fx.node.Target,
args: tuple[torch.fx.node.Argument, ...],
kwargs: dict[str, Any],
) -> Any:
assert isinstance(target, str)
output = super().call_module(target, args, kwargs)
if target in self.compile_submod_names:
index = self.compile_submod_names.index(target)
submod = self.fetch_attr(target)
sym_shape_indices = [
i for i, x in enumerate(args) if isinstance(x, torch.SymInt)
]
max_num_batched_tokens = self.vllm_config.scheduler_config.max_num_batched_tokens
r1 = Range(start=1, end=max_num_batched_tokens)
compiled_graph_for_dynamic_shape = (
self.vllm_backend.compiler_manager.compile(
submod,
args,
self.vllm_backend.inductor_config,
self.compilation_config,
graph_index=index,
num_graphs=len(self.compile_submod_names),
compile_range=r1,
))
# Lazy import here to avoid circular import
from vllm.compilation.piecewise_backend import PiecewiseBackend
piecewise_backend = PiecewiseBackend(
submod,
self.vllm_config,
index,
len(self.compile_submod_names),
sym_shape_indices,
compiled_graph_for_dynamic_shape,
self.vllm_backend,
)
if (self.compilation_config.cudagraph_mode.
has_piecewise_cudagraphs() and
not self.compilation_config.use_inductor_graph_partition):
# We're using Dynamo-based piecewise splitting, so we wrap
# the whole subgraph with a static graph wrapper.
from vllm.compilation.cuda_graph import CUDAGraphOptions
# resolve the static graph wrapper class (e.g. CUDAGraphWrapper
# class) as platform dependent.
static_graph_wrapper_class = resolve_obj_by_qualname(
current_platform.get_static_graph_wrapper_cls())
# Always assign PIECEWISE runtime mode to the
# CUDAGraphWrapper for piecewise_backend, to distinguish
# it from the FULL cudagraph runtime mode, no matter it
# is wrapped on a full or piecewise fx graph.
self.module.__dict__[target] = static_graph_wrapper_class(
runnable=piecewise_backend,
vllm_config=self.vllm_config,
runtime_mode=CUDAGraphMode.PIECEWISE,
cudagraph_options=CUDAGraphOptions(
debug_log_enable=piecewise_backend.is_first_graph,
gc_disable=not piecewise_backend.is_first_graph,
weak_ref_output=piecewise_backend.is_last_graph,
),
)
else:
self.module.__dict__[target] = piecewise_backend
compilation_counter.num_piecewise_capturable_graphs_seen += 1
return output
class AscendPiecewiseBackend:
def __init__(
self,
graph: fx.GraphModule,
vllm_config: VllmConfig,
piecewise_compile_index: int,
total_piecewise_compiles: int,
sym_shape_indices: list[int],
compiled_graph_for_general_shape: Callable,
vllm_backend: VllmBackend,
):
"""
The backend for piecewise compilation.
It mainly handles the compilation of static shapes and
dispatching based on runtime shape.
We will compile `self.graph` once for the general shape,
and then compile for different shapes specified in
`compilation_config.compile_sizes`.
"""
self.compiled_graph_for_general_shape = compiled_graph_for_general_shape
self.graph = graph
self.vllm_config = vllm_config
self.compilation_config = vllm_config.compilation_config
self.piecewise_compile_index = piecewise_compile_index
self.total_piecewise_compiles = total_piecewise_compiles
self.vllm_backend = vllm_backend
self.is_first_graph = piecewise_compile_index == 0
self.is_last_graph = piecewise_compile_index == total_piecewise_compiles - 1
self.is_full_graph = total_piecewise_compiles == 1
self.is_encoder_compilation = vllm_backend.is_encoder
self.compile_ranges = self.compilation_config.get_compile_ranges()
if self.is_encoder_compilation:
# For encoder compilation we use the max int32 value
# to set the upper bound of the compile ranges
max_int32 = 2**31 - 1
last_compile_range = self.compile_ranges[-1]
assert (last_compile_range.end ==
vllm_config.scheduler_config.max_num_batched_tokens)
self.compile_ranges[-1] = Range(start=last_compile_range.start,
end=max_int32)
log_string = f"PiecewiseBackend: compile_ranges: {self.compile_ranges}"
logger.debug_once(log_string)
self.compile_sizes = self.compilation_config.compile_sizes
log_string = f"PiecewiseBackend: compile_sizes: {self.compile_sizes}"
logger.debug_once(log_string)
self.sym_shape_indices = sym_shape_indices
# the entries for ranges that we need to either
self.range_entries: dict[Range, RangeEntry] = {}
# to_be_compiled_ranges tracks the remaining ranges to compile,
# and updates during the compilation process, so we need to copy it
self.to_be_compiled_ranges: set[Range] = set(self.compile_ranges)
# We only keep compilation management inside this class directly.
for size in self.compile_sizes:
range = Range(start=size, end=size)
if range not in self.compile_ranges:
self.range_entries[range] = RangeEntry(compile_range=range, )
self.to_be_compiled_ranges.add(range)
for range in self.compile_ranges:
self.range_entries[range] = RangeEntry(compile_range=range, )
def _find_range_for_shape(self, runtime_shape: int) -> Range | None:
# First we try to find the range entry for the concrete compile size
# If not found, we search for the range entry
# that contains the runtime shape.
if runtime_shape in self.compile_sizes:
return self.range_entries[Range(start=runtime_shape,
end=runtime_shape)]
else:
for range in self.compile_ranges:
if runtime_shape in range:
return self.range_entries[range]
return None
def __call__(self, *args) -> Any:
runtime_shape = args[self.sym_shape_indices[0]]
range_entry = self._find_range_for_shape(runtime_shape)
assert range_entry is not None, (
f"Shape out of considered range: {runtime_shape} "
"[1, max_num_batched_tokens]")
return self.compiled_graph_for_general_shape(*args)
vllm.compilation.backends.PiecewiseCompileInterpreter = AscendPiecewiseCompileInterpreter
vllm.compilation.piecewise_backend.PiecewiseBackend.__init__ = AscendPiecewiseBackend.__init__
vllm.compilation.piecewise_backend.PiecewiseBackend.__call__ = AscendPiecewiseBackend.__call__

View File

@@ -27,7 +27,7 @@ import torch_npu
import vllm.envs as envs_vllm
from torch_npu.op_plugin.atb._atb_ops import _register_atb_extensions
from torch_npu.profiler import dynamic_profile as dp
from vllm.config import CUDAGraphMode, VllmConfig, set_current_vllm_config
from vllm.config import VllmConfig, set_current_vllm_config
from vllm.distributed import (ensure_model_parallel_initialized,
init_distributed_environment)
from vllm.distributed.ec_transfer import ensure_ec_transfer_initialized
@@ -372,25 +372,11 @@ class NPUWorker(WorkerBase):
self.model_runner.eplb_warmup()
warmup_sizes = (self.vllm_config.compilation_config.compile_sizes
or []).copy()
cg_capture_sizes: list[int] = []
if self.vllm_config.compilation_config.cudagraph_mode != CUDAGraphMode.NONE:
cg_sizes = self.vllm_config.compilation_config.cudagraph_capture_sizes
cg_capture_sizes = [] if cg_sizes is None else cg_sizes
if not self.model_config.enforce_eager:
warmup_sizes = [
x for x in warmup_sizes if x not in cg_capture_sizes
x for x in warmup_sizes if x not in
self.vllm_config.compilation_config.cudagraph_capture_sizes
]
compile_ranges = self.vllm_config.compilation_config.get_compile_ranges(
)
# For each compile_range, if none of the batch sizes
# in warmup_sizes or cudagraph_capture_sizes are in the range,
# add the end of the range to ensure compilation/warmup.
all_sizes = set(cg_capture_sizes)
all_sizes.update([x for x in warmup_sizes if isinstance(x, int)])
for compile_range in compile_ranges:
if not any(x in compile_range for x in all_sizes):
warmup_sizes.append(compile_range.end)
for size in sorted(warmup_sizes, reverse=True):
logger.info("Compile and warming up model for size %d", size)
self.model_runner._dummy_run(size)