[BugFix][Fusion] Fix graph fusion failure problem (#5676)
Currently, the vLLM pull request
(https://github.com/vllm-project/vllm/pull/24252) breaks operator
fusion in vllm-ascend. The breakage was previously worked around by
patching the compilation backend. The root cause has now been
identified, so this pull request fixes it at the source and removes
the backend patch.
- vLLM version: v0.13.0
- vLLM main:
2f4e6548ef
---------
Signed-off-by: wxsIcey <1790571317@qq.com>
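
For context, a minimal, self-contained sketch of the interface change this fix follows: fusion passes are now asked whether they apply to a whole compile range (`is_applicable_for_range`) instead of to a single concrete runtime shape (`is_applicable`). The `Range` class below is a simplified stand-in for `vllm.config.utils.Range` (inclusive bounds and membership testing are assumptions here), and `DemoFusionPass` is illustrative, not a vLLM class.

```python
# Sketch of the old vs. new pass-gating contract, under the
# assumptions stated above.
from dataclasses import dataclass


@dataclass(frozen=True)
class Range:
    """Stand-in for vllm.config.utils.Range: inclusive [start, end]."""
    start: int
    end: int

    def __contains__(self, shape: int) -> bool:
        return self.start <= shape <= self.end


class DemoFusionPass:
    def is_applicable(self, runtime_shape: int | None = None) -> bool:
        # Old contract (pre-#24252): decided per concrete shape.
        return runtime_shape is None or runtime_shape >= 1

    def is_applicable_for_range(self, compile_range: Range) -> bool:
        # New contract: the decision must hold for every shape the
        # compiled graph in this range will serve.
        return compile_range.start >= 1


if __name__ == "__main__":
    assert DemoFusionPass().is_applicable_for_range(Range(start=1, end=8192))
```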
@@ -26,6 +26,7 @@ from torch._inductor.compile_fx import (graph_returns_tuple,
 from torch._inductor.decomposition import select_decomp_table
 from torch.fx import GraphModule
 from vllm.compilation.compiler_interface import CompilerInterface
+from vllm.config.utils import Range
 
 from vllm_ascend.ascend_config import get_ascend_config
 from vllm_ascend.utils import COMPILATION_PASS_KEY
@@ -47,13 +48,13 @@ def fusion_pass_compile(
     graph: fx.GraphModule,
     example_inputs: list[Any],
     compiler_config: dict[str, Any],
-    runtime_shape: Optional[int] = None,
+    compile_range: Range,
     key: Optional[str] = None,
 ) -> tuple[Optional[Callable], Optional[Any]]:
 
     def compile_inner(graph, example_inputs):
         current_pass_manager = compiler_config[COMPILATION_PASS_KEY]
-        graph = current_pass_manager(graph, runtime_shape)
+        graph = current_pass_manager(graph)
         return graph
 
     decompositions = select_decomp_table()
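
A minimal sketch of the inner-compile step above: the pass manager is fetched from `compiler_config` under `COMPILATION_PASS_KEY` and applied to the FX graph, now without a shape argument. The key's string value and `apply_fusion_passes` below are illustrative stand-ins, not the actual vllm-ascend definitions.

```python
# Lookup-and-apply shape mirroring compile_inner in the diff above.
from collections.abc import Callable
from typing import Any

COMPILATION_PASS_KEY = "compilation_pass_manager"  # illustrative value


def apply_fusion_passes(graph: Any, compiler_config: dict[str, Any]) -> Any:
    # The compile range now comes from the shared pass context, so no
    # shape argument is threaded through this call.
    current_pass_manager: Callable[[Any], Any] = compiler_config[
        COMPILATION_PASS_KEY]
    return current_pass_manager(graph)


# Usage with a trivial identity "manager":
assert apply_fusion_passes("g", {COMPILATION_PASS_KEY: lambda g: g}) == "g"
```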
@@ -72,7 +73,7 @@ def npugraph_ex_compile(
     graph: fx.GraphModule,
     example_inputs: list[Any],
     compiler_config: dict[str, Any],
-    runtime_shape: Optional[int] = None,
+    compile_range: Range,
     key: Optional[str] = None,
 ) -> tuple[Optional[Callable], Optional[Any]]:
     # When currently using the FULL_DECODE_ONLY mode,
@@ -125,14 +126,14 @@ class AscendCompiler(CompilerInterface):
         graph: fx.GraphModule,
         example_inputs: list[Any],
         compiler_config: dict[str, Any],
-        runtime_shape: Optional[int] = None,
+        compile_range: Range,
         key: Optional[str] = None,
     ) -> tuple[Optional[Callable], Optional[Any]]:
 
         ascend_config = get_ascend_config()
         if ascend_config.enable_npugraph_ex:
             return npugraph_ex_compile(graph, example_inputs, compiler_config,
-                                       runtime_shape, key)
+                                       compile_range, key)
         else:
             return fusion_pass_compile(graph, example_inputs, compiler_config,
-                                       runtime_shape, key)
+                                       compile_range, key)
@@ -17,6 +17,7 @@
 #
 
 from torch import fx as fx
+from vllm.compilation.inductor_pass import get_pass_context
 from vllm.compilation.vllm_inductor_pass import VllmInductorPass
 from vllm.config import VllmConfig
 
@@ -32,10 +33,13 @@ class GraphFusionPassManager:
     def __init__(self):
        self.passes: list[VllmInductorPass] = []
 
-    def __call__(self, graph: fx.Graph, runtime_shape) -> fx.Graph:
+    def __call__(self, graph: fx.Graph) -> fx.Graph:
+        compile_range = get_pass_context().compile_range
+
         for pass_ in self.passes:
-            if pass_.is_applicable(runtime_shape):
+            if pass_.is_applicable_for_range(compile_range):
                 pass_(graph)
         graph.recompile()
         return graph
 
     def add(self, pass_: VllmInductorPass):
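
To make the new control flow concrete, a hedged sketch of the calling convention above: the compile range is no longer threaded through `__call__` the way `runtime_shape` was; it is read from a shared pass context. `DemoPassContext`, `set_pass_context`, and the demo manager below are illustrative stand-ins, not vLLM's actual pass-context API; only the shape of `__call__` mirrors the diff.

```python
# Range travels via a context, not an argument.
from contextlib import contextmanager
from dataclasses import dataclass


@dataclass(frozen=True)
class Range:
    start: int
    end: int


@dataclass
class DemoPassContext:
    compile_range: Range


_CTX: list[DemoPassContext] = []


@contextmanager
def set_pass_context(ctx: DemoPassContext):
    _CTX.append(ctx)
    try:
        yield
    finally:
        _CTX.pop()


def get_pass_context() -> DemoPassContext:
    return _CTX[-1]


class DemoPassManager:
    def __init__(self):
        self.passes = []

    def __call__(self, graph):
        # Mirrors the new GraphFusionPassManager.__call__: no
        # runtime_shape parameter; the range comes from the context.
        compile_range = get_pass_context().compile_range
        for pass_ in self.passes:
            if pass_.is_applicable_for_range(compile_range):
                pass_(graph)
        return graph


class NoopPass:
    def is_applicable_for_range(self, compile_range: Range) -> bool:
        return compile_range.start >= 1

    def __call__(self, graph):
        pass


mgr = DemoPassManager()
mgr.passes.append(NoopPass())
with set_pass_context(DemoPassContext(Range(1, 8192))):
    mgr("fake-graph")
```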
@@ -20,6 +20,7 @@ import torch._inductor.pattern_matcher as pm
 from torch._inductor.pattern_matcher import PatternMatcherPass
 from vllm.compilation.vllm_inductor_pass import VllmInductorPass
 from vllm.config import VllmConfig
+from vllm.config.compilation import Range
 from vllm.logger import logger
 
 
@@ -308,7 +309,7 @@ class AddRMSNormQuantFusionPass(VllmInductorPass):
         logger.debug("Replaced %s patterns", self.matched_count)
         self.end_and_log()
 
-    def is_applicable(self, runtime_shape: int | None = None) -> bool:
+    def is_applicable_for_range(self, compile_range: Range) -> bool:
         """
         Check if the pass is applicable for the current configuration.
         """
@@ -22,6 +22,7 @@ from torch._inductor.pattern_matcher import (PatternMatcherPass,
 from vllm.attention.layer import Attention
 from vllm.compilation.vllm_inductor_pass import VllmInductorPass
 from vllm.config import VllmConfig, get_layers_from_vllm_config
+from vllm.config.compilation import Range
 from vllm.logger import logger
 
 
@@ -283,7 +284,7 @@ class QKNormRopeFusionPass(VllmInductorPass):
             pattern_idx += 1
         self.end_and_log()
 
-    def is_applicable(self, runtime_shape):
+    def is_applicable_for_range(self, compile_range: Range) -> bool:
         """
         Check if the pass is applicable for the current configuration.
         """
@@ -106,20 +106,6 @@
 #    Future Plan:
 #      Remove this patch when vLLM merge the PR.
 #
-# ** 7. File: platform/patch_compile_backend.py**
-#    ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
-#   1. `vllm.compilation.backends.PiecewiseCompileInterpreter`
-#      `vllm.compilation.piecewise_backend.PiecewiseBackend`
-#    Why:
-#      vllm removed the compile graph for general shape, which caused operator fusion to fail.
-#      This issue affects the performance of model inference on Ascend.
-#    How:
-#      recover the compiled graph for dynamic_shape in PiecewiseBackend.
-#    Related PR (if no, explain why):
-#      https://github.com/vllm-project/vllm/pull/24252
-#    Future Plan:
-#      Remove this patch when fix the problem.
-#
 # * Worker Patch:
 # ===============
 #
@@ -16,7 +16,6 @@
 
 import os
 
-import vllm_ascend.patch.platform.patch_compile_backend  # noqa
 import vllm_ascend.patch.platform.patch_distributed  # noqa
 import vllm_ascend.patch.platform.patch_ec_connector  # noqa
 import vllm_ascend.patch.platform.patch_mamba_config  # noqa
@@ -1,235 +0,0 @@
-from collections.abc import Callable
-from typing import Any
-
-import torch
-import torch.fx as fx
-import vllm.compilation.backends
-import vllm.compilation.piecewise_backend
-from torch._dispatch.python import enable_python_dispatcher
-from vllm.compilation.backends import VllmBackend
-from vllm.compilation.counter import compilation_counter
-from vllm.compilation.piecewise_backend import RangeEntry
-from vllm.config import CUDAGraphMode, VllmConfig
-from vllm.config.utils import Range
-from vllm.logger import init_logger
-from vllm.platforms import current_platform
-from vllm.utils.import_utils import resolve_obj_by_qualname
-
-logger = init_logger(__name__)
-
-
-class AscendPiecewiseCompileInterpreter(torch.fx.Interpreter):
-    """Code adapted from `torch.fx.passes.shape_prop.ShapeProp`.
-    It runs the given graph with fake inputs, and compile some
-    submodules specified by `compile_submod_names` with the given
-    compilation configs.
-
-    NOTE: the order in `compile_submod_names` matters, because
-    it will be used to determine the order of the compiled piecewise
-    graphs. The first graph will handle logging, and the last graph
-    has some special cudagraph output handling.
-    """
-
-    def __init__(
-        self,
-        module: torch.fx.GraphModule,
-        compile_submod_names: list[str],
-        vllm_config: VllmConfig,
-        vllm_backend: "VllmBackend",
-    ):
-        super().__init__(module)
-        from torch._guards import detect_fake_mode
-
-        self.fake_mode = detect_fake_mode()
-        self.compile_submod_names = compile_submod_names
-        self.compilation_config = vllm_config.compilation_config
-        self.vllm_config = vllm_config
-        self.vllm_backend = vllm_backend
-        # When True, it annoyingly dumps the torch.fx.Graph on errors.
-        self.extra_traceback = False
-
-    def run(self, *args):
-        # maybe instead just assert inputs are fake?
-        fake_args = [
-            self.fake_mode.from_tensor(t) if isinstance(t, torch.Tensor) else t
-            for t in args
-        ]
-        with self.fake_mode, enable_python_dispatcher():
-            return super().run(*fake_args)
-
-    def call_module(
-        self,
-        target: torch.fx.node.Target,
-        args: tuple[torch.fx.node.Argument, ...],
-        kwargs: dict[str, Any],
-    ) -> Any:
-        assert isinstance(target, str)
-
-        output = super().call_module(target, args, kwargs)
-
-        if target in self.compile_submod_names:
-            index = self.compile_submod_names.index(target)
-            submod = self.fetch_attr(target)
-
-            sym_shape_indices = [
-                i for i, x in enumerate(args) if isinstance(x, torch.SymInt)
-            ]
-            max_num_batched_tokens = self.vllm_config.scheduler_config.max_num_batched_tokens
-            r1 = Range(start=1, end=max_num_batched_tokens)
-            compiled_graph_for_dynamic_shape = (
-                self.vllm_backend.compiler_manager.compile(
-                    submod,
-                    args,
-                    self.vllm_backend.inductor_config,
-                    self.compilation_config,
-                    graph_index=index,
-                    num_graphs=len(self.compile_submod_names),
-                    compile_range=r1,
-                ))
-
-            # Lazy import here to avoid circular import
-            from vllm.compilation.piecewise_backend import PiecewiseBackend
-
-            piecewise_backend = PiecewiseBackend(
-                submod,
-                self.vllm_config,
-                index,
-                len(self.compile_submod_names),
-                sym_shape_indices,
-                compiled_graph_for_dynamic_shape,
-                self.vllm_backend,
-            )
-
-            if (self.compilation_config.cudagraph_mode.
-                    has_piecewise_cudagraphs() and
-                    not self.compilation_config.use_inductor_graph_partition):
-                # We're using Dynamo-based piecewise splitting, so we wrap
-                # the whole subgraph with a static graph wrapper.
-                from vllm.compilation.cuda_graph import CUDAGraphOptions
-
-                # resolve the static graph wrapper class (e.g. CUDAGraphWrapper
-                # class) as platform dependent.
-                static_graph_wrapper_class = resolve_obj_by_qualname(
-                    current_platform.get_static_graph_wrapper_cls())
-
-                # Always assign PIECEWISE runtime mode to the
-                # CUDAGraphWrapper for piecewise_backend, to distinguish
-                # it from the FULL cudagraph runtime mode, no matter it
-                # is wrapped on a full or piecewise fx graph.
-                self.module.__dict__[target] = static_graph_wrapper_class(
-                    runnable=piecewise_backend,
-                    vllm_config=self.vllm_config,
-                    runtime_mode=CUDAGraphMode.PIECEWISE,
-                    cudagraph_options=CUDAGraphOptions(
-                        debug_log_enable=piecewise_backend.is_first_graph,
-                        gc_disable=not piecewise_backend.is_first_graph,
-                        weak_ref_output=piecewise_backend.is_last_graph,
-                    ),
-                )
-            else:
-                self.module.__dict__[target] = piecewise_backend
-
-            compilation_counter.num_piecewise_capturable_graphs_seen += 1
-
-        return output
-
-
-class AscendPiecewiseBackend:
-
-    def __init__(
-        self,
-        graph: fx.GraphModule,
-        vllm_config: VllmConfig,
-        piecewise_compile_index: int,
-        total_piecewise_compiles: int,
-        sym_shape_indices: list[int],
-        compiled_graph_for_general_shape: Callable,
-        vllm_backend: VllmBackend,
-    ):
-        """
-        The backend for piecewise compilation.
-        It mainly handles the compilation of static shapes and
-        dispatching based on runtime shape.
-
-        We will compile `self.graph` once for the general shape,
-        and then compile for different shapes specified in
-        `compilation_config.compile_sizes`.
-        """
-        self.compiled_graph_for_general_shape = compiled_graph_for_general_shape
-        self.graph = graph
-        self.vllm_config = vllm_config
-        self.compilation_config = vllm_config.compilation_config
-        self.piecewise_compile_index = piecewise_compile_index
-        self.total_piecewise_compiles = total_piecewise_compiles
-        self.vllm_backend = vllm_backend
-
-        self.is_first_graph = piecewise_compile_index == 0
-        self.is_last_graph = piecewise_compile_index == total_piecewise_compiles - 1
-
-        self.is_full_graph = total_piecewise_compiles == 1
-        self.is_encoder_compilation = vllm_backend.is_encoder
-
-        self.compile_ranges = self.compilation_config.get_compile_ranges()
-        if self.is_encoder_compilation:
-            # For encoder compilation we use the max int32 value
-            # to set the upper bound of the compile ranges
-            max_int32 = 2**31 - 1
-            last_compile_range = self.compile_ranges[-1]
-            assert (last_compile_range.end ==
-                    vllm_config.scheduler_config.max_num_batched_tokens)
-            self.compile_ranges[-1] = Range(start=last_compile_range.start,
-                                            end=max_int32)
-
-        log_string = f"PiecewiseBackend: compile_ranges: {self.compile_ranges}"
-        logger.debug_once(log_string)
-
-        self.compile_sizes = self.compilation_config.compile_sizes
-        log_string = f"PiecewiseBackend: compile_sizes: {self.compile_sizes}"
-        logger.debug_once(log_string)
-
-        self.sym_shape_indices = sym_shape_indices
-
-        # the entries for ranges that we need to either
-        self.range_entries: dict[Range, RangeEntry] = {}
-
-        # to_be_compiled_ranges tracks the remaining ranges to compile,
-        # and updates during the compilation process, so we need to copy it
-        self.to_be_compiled_ranges: set[Range] = set(self.compile_ranges)
-
-        # We only keep compilation management inside this class directly.
-        for size in self.compile_sizes:
-            range = Range(start=size, end=size)
-            if range not in self.compile_ranges:
-                self.range_entries[range] = RangeEntry(compile_range=range, )
-                self.to_be_compiled_ranges.add(range)
-
-        for range in self.compile_ranges:
-            self.range_entries[range] = RangeEntry(compile_range=range, )
-
-    def _find_range_for_shape(self, runtime_shape: int) -> Range | None:
-        # First we try to find the range entry for the concrete compile size
-        # If not found, we search for the range entry
-        # that contains the runtime shape.
-        if runtime_shape in self.compile_sizes:
-            return self.range_entries[Range(start=runtime_shape,
-                                            end=runtime_shape)]
-        else:
-            for range in self.compile_ranges:
-                if runtime_shape in range:
-                    return self.range_entries[range]
-            return None
-
-    def __call__(self, *args) -> Any:
-        runtime_shape = args[self.sym_shape_indices[0]]
-        range_entry = self._find_range_for_shape(runtime_shape)
-
-        assert range_entry is not None, (
-            f"Shape out of considered range: {runtime_shape} "
-            "[1, max_num_batched_tokens]")
-
-        return self.compiled_graph_for_general_shape(*args)
-
-
-vllm.compilation.backends.PiecewiseCompileInterpreter = AscendPiecewiseCompileInterpreter
-vllm.compilation.piecewise_backend.PiecewiseBackend.__init__ = AscendPiecewiseBackend.__init__
-vllm.compilation.piecewise_backend.PiecewiseBackend.__call__ = AscendPiecewiseBackend.__call__
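
The deleted patch above replicated vLLM's range dispatch; a standalone sketch of its `_find_range_for_shape` rule (an exact compile-size match wins over a containing range) is shown below. `Range` is again a simplified stand-in with inclusive bounds, and `find_range` is illustrative, not the patched method itself.

```python
# Dispatch rule: exact compile size first, then the first containing range.
from dataclasses import dataclass


@dataclass(frozen=True)
class Range:
    start: int
    end: int

    def __contains__(self, shape: int) -> bool:
        return self.start <= shape <= self.end


def find_range(runtime_shape: int, compile_sizes: set[int],
               compile_ranges: list[Range]) -> Range | None:
    # A graph compiled for exactly this shape takes priority.
    if runtime_shape in compile_sizes:
        return Range(start=runtime_shape, end=runtime_shape)
    # Otherwise fall back to the range that contains the shape.
    for r in compile_ranges:
        if runtime_shape in r:
            return r
    return None


assert find_range(64, {64}, [Range(1, 8192)]) == Range(64, 64)
assert find_range(100, {64}, [Range(1, 8192)]) == Range(1, 8192)
```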
@@ -28,7 +28,7 @@ import torch_npu
 import vllm.envs as envs_vllm
 from torch_npu.op_plugin.atb._atb_ops import _register_atb_extensions
 from torch_npu.profiler import dynamic_profile as dp
-from vllm.config import VllmConfig, set_current_vllm_config
+from vllm.config import CUDAGraphMode, VllmConfig, set_current_vllm_config
 from vllm.distributed import (ensure_model_parallel_initialized,
                               init_distributed_environment)
 from vllm.distributed.ec_transfer import ensure_ec_transfer_initialized
@@ -381,10 +381,25 @@ class NPUWorker(WorkerBase):
         warmup_sizes = (self.vllm_config.compilation_config.compile_sizes
                         or []).copy()
-        if not self.model_config.enforce_eager:
-            warmup_sizes = [
-                x for x in warmup_sizes if x not in
-                self.vllm_config.compilation_config.cudagraph_capture_sizes
-            ]
+        cg_capture_sizes: list[int] = []
+        if self.vllm_config.compilation_config.cudagraph_mode != CUDAGraphMode.NONE:
+            cg_sizes = self.vllm_config.compilation_config.cudagraph_capture_sizes
+            cg_capture_sizes = [] if cg_sizes is None else cg_sizes
+            warmup_sizes = [
+                x for x in warmup_sizes if x not in cg_capture_sizes
+            ]
+
+        compile_ranges = self.vllm_config.compilation_config.get_compile_ranges(
+        )
+        # For each compile_range, if none of the batch sizes
+        # in warmup_sizes or cudagraph_capture_sizes are in the range,
+        # add the end of the range to ensure compilation/warmup.
+        all_sizes = set(cg_capture_sizes)
+        all_sizes.update([x for x in warmup_sizes if isinstance(x, int)])
+        for compile_range in compile_ranges:
+            if not any(x in compile_range for x in all_sizes):
+                warmup_sizes.append(compile_range.end)
+
         for size in sorted(warmup_sizes, reverse=True):
             logger.info("Compile and warming up model for size %d", size)
             self.model_runner._dummy_run(size)
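
A worked sketch of the warmup-coverage rule added above: any compile range not hit by an existing warmup or cudagraph-capture size gets its end appended, so every range is still compiled during warmup. `Range` is the same simplified stand-in as in the earlier sketches, and `pad_warmup_sizes` is illustrative, not the worker method.

```python
# Range-coverage padding for warmup sizes.
from dataclasses import dataclass


@dataclass(frozen=True)
class Range:
    start: int
    end: int

    def __contains__(self, shape: int) -> bool:
        return self.start <= shape <= self.end


def pad_warmup_sizes(warmup_sizes: list[int],
                     cg_capture_sizes: list[int],
                     compile_ranges: list[Range]) -> list[int]:
    all_sizes = set(cg_capture_sizes)
    all_sizes.update(x for x in warmup_sizes if isinstance(x, int))
    out = list(warmup_sizes)
    for r in compile_ranges:
        # No existing size lands in this range: warm up its end.
        if not any(x in r for x in all_sizes):
            out.append(r.end)
    return out


# Ranges [1, 512] and [513, 8192]; only 256 is covered, so 8192 is added.
assert pad_warmup_sizes([256], [], [Range(1, 512), Range(513, 8192)]) == [256, 8192]
```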