xc-llm-ascend/vllm_ascend/patch/platform/patch_compile_backend.py
Icey 9b2a7d8866 [BugFix][Fusion] Patch compile backend to make fusion available (#5308)
Currently, vLLM PR https://github.com/vllm-project/vllm/pull/24252
causes operator fusion to fail; this can be mitigated by patching the
compile backend. Once the problem is fully resolved, I will submit a
follow-up pull request to remove the patch.

- vLLM version: release/v0.13.0
- vLLM main: 5fbfa8d9ef
---------
Signed-off-by: wxsIcey <1790571317@qq.com>
2025-12-26 09:18:16 +08:00


from collections.abc import Callable
from typing import Any

import torch
import torch.fx as fx
import vllm.compilation.backends
import vllm.compilation.piecewise_backend
from torch._dispatch.python import enable_python_dispatcher
from vllm.compilation.backends import VllmBackend
from vllm.compilation.counter import compilation_counter
from vllm.compilation.piecewise_backend import RangeEntry
from vllm.config import CUDAGraphMode, VllmConfig
from vllm.config.utils import Range
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.utils.import_utils import resolve_obj_by_qualname

logger = init_logger(__name__)

class AscendPiecewiseCompileInterpreter(torch.fx.Interpreter):
    """Code adapted from `torch.fx.passes.shape_prop.ShapeProp`.

    It runs the given graph with fake inputs and compiles the
    submodules specified by `compile_submod_names` with the given
    compilation configs.

    NOTE: the order in `compile_submod_names` matters, because
    it is used to determine the order of the compiled piecewise
    graphs. The first graph handles logging, and the last graph
    has some special cudagraph output handling.
    """

    def __init__(
        self,
        module: torch.fx.GraphModule,
        compile_submod_names: list[str],
        vllm_config: VllmConfig,
        vllm_backend: "VllmBackend",
    ):
        super().__init__(module)
        from torch._guards import detect_fake_mode

        self.fake_mode = detect_fake_mode()
        self.compile_submod_names = compile_submod_names
        self.compilation_config = vllm_config.compilation_config
        self.vllm_config = vllm_config
        self.vllm_backend = vllm_backend
        # When True, it annoyingly dumps the torch.fx.Graph on errors.
        self.extra_traceback = False

    def run(self, *args):
        # maybe instead just assert inputs are fake?
        fake_args = [
            self.fake_mode.from_tensor(t) if isinstance(t, torch.Tensor) else t
            for t in args
        ]
        with self.fake_mode, enable_python_dispatcher():
            return super().run(*fake_args)
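
    # `run` executes the whole graph under the detected FakeTensorMode, so no
    # real device memory is touched during tracing; each submodule listed in
    # `compile_submod_names` is then compiled as it is visited by
    # `call_module` below.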
    def call_module(
        self,
        target: torch.fx.node.Target,
        args: tuple[torch.fx.node.Argument, ...],
        kwargs: dict[str, Any],
    ) -> Any:
        assert isinstance(target, str)
        output = super().call_module(target, args, kwargs)

        if target in self.compile_submod_names:
            index = self.compile_submod_names.index(target)
            submod = self.fetch_attr(target)
            sym_shape_indices = [
                i for i, x in enumerate(args) if isinstance(x, torch.SymInt)
            ]
            max_num_batched_tokens = (
                self.vllm_config.scheduler_config.max_num_batched_tokens)
            r1 = Range(start=1, end=max_num_batched_tokens)
            compiled_graph_for_dynamic_shape = (
                self.vllm_backend.compiler_manager.compile(
                    submod,
                    args,
                    self.vllm_backend.inductor_config,
                    self.compilation_config,
                    graph_index=index,
                    num_graphs=len(self.compile_submod_names),
                    compile_range=r1,
                ))
            # Lazy import here to avoid circular import
            from vllm.compilation.piecewise_backend import PiecewiseBackend

            piecewise_backend = PiecewiseBackend(
                submod,
                self.vllm_config,
                index,
                len(self.compile_submod_names),
                sym_shape_indices,
                compiled_graph_for_dynamic_shape,
                self.vllm_backend,
            )

            if (self.compilation_config.cudagraph_mode
                    .has_piecewise_cudagraphs() and
                    not self.compilation_config.use_inductor_graph_partition):
                # We're using Dynamo-based piecewise splitting, so we wrap
                # the whole subgraph with a static graph wrapper.
                from vllm.compilation.cuda_graph import CUDAGraphOptions

                # Resolve the static graph wrapper class (e.g. the
                # CUDAGraphWrapper class) in a platform-dependent way.
                static_graph_wrapper_class = resolve_obj_by_qualname(
                    current_platform.get_static_graph_wrapper_cls())

                # Always assign the PIECEWISE runtime mode to the
                # CUDAGraphWrapper for piecewise_backend, to distinguish
                # it from the FULL cudagraph runtime mode, regardless of
                # whether it wraps a full or piecewise fx graph.
                self.module.__dict__[target] = static_graph_wrapper_class(
                    runnable=piecewise_backend,
                    vllm_config=self.vllm_config,
                    runtime_mode=CUDAGraphMode.PIECEWISE,
                    cudagraph_options=CUDAGraphOptions(
                        debug_log_enable=piecewise_backend.is_first_graph,
                        gc_disable=not piecewise_backend.is_first_graph,
                        weak_ref_output=piecewise_backend.is_last_graph,
                    ),
                )
            else:
                self.module.__dict__[target] = piecewise_backend

            compilation_counter.num_piecewise_capturable_graphs_seen += 1

        return output
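

# NOTE: this patched backend keeps vLLM's range/size bookkeeping in
# `__init__`, but its `__call__` only validates that the runtime shape falls
# inside a known compile range and then always executes the general-shape
# compiled graph; no per-range dispatch happens here.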
class AscendPiecewiseBackend:

    def __init__(
        self,
        graph: fx.GraphModule,
        vllm_config: VllmConfig,
        piecewise_compile_index: int,
        total_piecewise_compiles: int,
        sym_shape_indices: list[int],
        compiled_graph_for_general_shape: Callable,
        vllm_backend: VllmBackend,
    ):
        """
        The backend for piecewise compilation.
        It mainly handles the compilation of static shapes and
        dispatching based on the runtime shape.

        We compile `self.graph` once for the general shape, and then
        compile for the different shapes specified in
        `compilation_config.compile_sizes`.
        """
        self.compiled_graph_for_general_shape = compiled_graph_for_general_shape
        self.graph = graph
        self.vllm_config = vllm_config
        self.compilation_config = vllm_config.compilation_config
        self.piecewise_compile_index = piecewise_compile_index
        self.total_piecewise_compiles = total_piecewise_compiles
        self.vllm_backend = vllm_backend

        self.is_first_graph = piecewise_compile_index == 0
        self.is_last_graph = (
            piecewise_compile_index == total_piecewise_compiles - 1)
        self.is_full_graph = total_piecewise_compiles == 1
        self.is_encoder_compilation = vllm_backend.is_encoder

        self.compile_ranges = self.compilation_config.get_compile_ranges()
        if self.is_encoder_compilation:
            # For encoder compilation we use the max int32 value
            # as the upper bound of the compile ranges.
            max_int32 = 2**31 - 1
            last_compile_range = self.compile_ranges[-1]
            assert (last_compile_range.end ==
                    vllm_config.scheduler_config.max_num_batched_tokens)
            self.compile_ranges[-1] = Range(start=last_compile_range.start,
                                            end=max_int32)

        log_string = f"PiecewiseBackend: compile_ranges: {self.compile_ranges}"
        logger.debug_once(log_string)

        self.compile_sizes = self.compilation_config.compile_sizes
        log_string = f"PiecewiseBackend: compile_sizes: {self.compile_sizes}"
        logger.debug_once(log_string)

        self.sym_shape_indices = sym_shape_indices

        # The entries for the ranges that we need to compile.
        self.range_entries: dict[Range, RangeEntry] = {}

        # to_be_compiled_ranges tracks the remaining ranges to compile,
        # and is updated during compilation, so we need a copy here.
        self.to_be_compiled_ranges: set[Range] = set(self.compile_ranges)
        # We only keep compilation management inside this class directly.
        for size in self.compile_sizes:
            range = Range(start=size, end=size)
            if range not in self.compile_ranges:
                self.range_entries[range] = RangeEntry(compile_range=range)
                self.to_be_compiled_ranges.add(range)
        for range in self.compile_ranges:
            self.range_entries[range] = RangeEntry(compile_range=range)

    def _find_range_for_shape(self, runtime_shape: int) -> RangeEntry | None:
        # First try to find the range entry for the concrete compile size.
        # If not found, search for the range entry that contains the
        # runtime shape.
        if runtime_shape in self.compile_sizes:
            return self.range_entries[Range(start=runtime_shape,
                                            end=runtime_shape)]
        else:
            for range in self.compile_ranges:
                if runtime_shape in range:
                    return self.range_entries[range]
        return None
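
    # Illustrative dispatch (hypothetical values): with compile_sizes = [8]
    # and compile_ranges = [Range(start=1, end=8192)], a runtime shape of 8
    # resolves to the dedicated Range(start=8, end=8) entry, a runtime shape
    # of 100 resolves to the Range(start=1, end=8192) entry, and a shape
    # outside every range yields None, which trips the assert in __call__.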

    def __call__(self, *args) -> Any:
        runtime_shape = args[self.sym_shape_indices[0]]
        range_entry = self._find_range_for_shape(runtime_shape)
        assert range_entry is not None, (
            f"Shape out of considered range: {runtime_shape} "
            "[1, max_num_batched_tokens]")
        return self.compiled_graph_for_general_shape(*args)

vllm.compilation.backends.PiecewiseCompileInterpreter = (
    AscendPiecewiseCompileInterpreter)
vllm.compilation.piecewise_backend.PiecewiseBackend.__init__ = (
    AscendPiecewiseBackend.__init__)
vllm.compilation.piecewise_backend.PiecewiseBackend.__call__ = (
    AscendPiecewiseBackend.__call__)
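
# Usage sketch: importing this module is enough to apply the patch, since the
# assignments above rebind the vLLM classes at import time. The import path
# below matches this file's location in vllm-ascend:
#
#   import vllm_ascend.patch.platform.patch_compile_backend  # noqa: F401
#
# Afterwards, vllm.compilation.backends.PiecewiseCompileInterpreter resolves
# to AscendPiecewiseCompileInterpreter, and PiecewiseBackend instances use the
# patched __init__/__call__ defined here.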