### What this PR does / why we need it? Reverts PR 5253 to fix the smoke-test failure. ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? It was verified against the previously failing case. Signed-off-by: Rifa <865071616@qq.com>
236 lines
9.5 KiB
Python
236 lines
9.5 KiB
Python
from collections.abc import Callable
|
|
from typing import Any
|
|
|
|
import torch
|
|
import torch.fx as fx
|
|
import vllm.compilation.backends
|
|
import vllm.compilation.piecewise_backend
|
|
from torch._dispatch.python import enable_python_dispatcher
|
|
from vllm.compilation.backends import VllmBackend
|
|
from vllm.compilation.counter import compilation_counter
|
|
from vllm.compilation.piecewise_backend import RangeEntry
|
|
from vllm.config import CUDAGraphMode, VllmConfig
|
|
from vllm.config.utils import Range
|
|
from vllm.logger import init_logger
|
|
from vllm.platforms import current_platform
|
|
from vllm.utils.import_utils import resolve_obj_by_qualname
|
|
|
|
# Module-level logger for this file; AscendPiecewiseBackend uses it for
# its one-time (`debug_once`) compile-range / compile-size messages.
logger = init_logger(__name__)
|
|
|
|
|
|
class AscendPiecewiseCompileInterpreter(torch.fx.Interpreter):
    """FX interpreter that compiles selected submodules into piecewise backends.

    Adapted from ``torch.fx.passes.shape_prop.ShapeProp``. The graph is run
    on fake inputs. Every submodule whose name appears in
    ``compile_submod_names`` is compiled with the configured compilation
    settings, and the result is installed on the parent module in place of
    the original submodule.

    NOTE: the order of ``compile_submod_names`` matters. It fixes the order
    of the compiled piecewise graphs: the first graph handles logging and
    the last graph gets special cudagraph output handling.
    """

    def __init__(
        self,
        module: torch.fx.GraphModule,
        compile_submod_names: list[str],
        vllm_config: VllmConfig,
        vllm_backend: "VllmBackend",
    ):
        super().__init__(module)
        from torch._guards import detect_fake_mode

        # Reuse whatever fake mode the current tracing context has active.
        self.fake_mode = detect_fake_mode()
        self.compile_submod_names = compile_submod_names
        self.compilation_config = vllm_config.compilation_config
        self.vllm_config = vllm_config
        self.vllm_backend = vllm_backend
        # Must come after super().__init__(); when enabled, torch.fx dumps
        # the whole graph on errors, which is very noisy.
        self.extra_traceback = False

    def run(self, *args):
        """Run the graph after converting every tensor argument to a fake tensor."""
        # TODO: maybe just assert that the inputs are already fake instead.
        fake_args = tuple(
            self.fake_mode.from_tensor(arg) if isinstance(arg, torch.Tensor) else arg
            for arg in args)
        with self.fake_mode, enable_python_dispatcher():
            return super().run(*fake_args)

    def call_module(
        self,
        target: torch.fx.node.Target,
        args: tuple[torch.fx.node.Argument, ...],
        kwargs: dict[str, Any],
    ) -> Any:
        """Execute submodule ``target``; if it is selected, compile and swap it.

        The submodule always runs first so that its (fake) output can be
        returned. Submodules not listed in ``compile_submod_names`` are left
        untouched.
        """
        assert isinstance(target, str)

        output = super().call_module(target, args, kwargs)
        if target not in self.compile_submod_names:
            return output

        graph_index = self.compile_submod_names.index(target)
        num_graphs = len(self.compile_submod_names)
        submod = self.fetch_attr(target)

        # Positions of the symbolic-size (SymInt) arguments of this submodule.
        sym_shape_indices = [
            pos for pos, arg in enumerate(args) if isinstance(arg, torch.SymInt)
        ]

        # A single compilation covering [1, max_num_batched_tokens].
        general_range = Range(
            start=1,
            end=self.vllm_config.scheduler_config.max_num_batched_tokens)
        compiled_graph_for_dynamic_shape = self.vllm_backend.compiler_manager.compile(
            submod,
            args,
            self.vllm_backend.inductor_config,
            self.compilation_config,
            graph_index=graph_index,
            num_graphs=num_graphs,
            compile_range=general_range,
        )

        # Lazy import here to avoid circular import
        from vllm.compilation.piecewise_backend import PiecewiseBackend

        piecewise_backend = PiecewiseBackend(
            submod,
            self.vllm_config,
            graph_index,
            num_graphs,
            sym_shape_indices,
            compiled_graph_for_dynamic_shape,
            self.vllm_backend,
        )

        self.module.__dict__[target] = self._wrap_piecewise_backend(
            piecewise_backend)
        compilation_counter.num_piecewise_capturable_graphs_seen += 1
        return output

    def _wrap_piecewise_backend(self, piecewise_backend):
        """Return the runnable to install on the module for one compiled submodule.

        With Dynamo-based piecewise splitting (piecewise cudagraphs enabled
        and inductor graph partition disabled), the backend is wrapped in
        the platform's static graph wrapper. Otherwise it is installed as-is.
        """
        dynamo_split_with_cudagraphs = (
            self.compilation_config.cudagraph_mode.has_piecewise_cudagraphs()
            and not self.compilation_config.use_inductor_graph_partition)
        if not dynamo_split_with_cudagraphs:
            return piecewise_backend

        from vllm.compilation.cuda_graph import CUDAGraphOptions

        # The static graph wrapper class (e.g. CUDAGraphWrapper) is chosen
        # by the current platform.
        wrapper_cls = resolve_obj_by_qualname(
            current_platform.get_static_graph_wrapper_cls())

        # The runtime mode is always PIECEWISE for a piecewise backend. This
        # distinguishes it from the FULL cudagraph runtime mode, whether the
        # wrapped fx graph is a full or a piecewise one.
        return wrapper_cls(
            runnable=piecewise_backend,
            vllm_config=self.vllm_config,
            runtime_mode=CUDAGraphMode.PIECEWISE,
            cudagraph_options=CUDAGraphOptions(
                debug_log_enable=piecewise_backend.is_first_graph,
                gc_disable=not piecewise_backend.is_first_graph,
                weak_ref_output=piecewise_backend.is_last_graph,
            ),
        )
|
|
|
|
|
|
class AscendPiecewiseBackend:
    """Ascend replacements for ``PiecewiseBackend.__init__`` / ``__call__``.

    This module patches ``__init__`` and ``__call__`` onto
    ``vllm.compilation.piecewise_backend.PiecewiseBackend``, so at runtime
    ``self`` is an upstream ``PiecewiseBackend`` instance. Each call is
    dispatched to the single graph compiled for the general (dynamic) shape.
    The compile-range / compile-size bookkeeping built here is only used to
    validate the runtime shape.
    """

    def __init__(
        self,
        graph: fx.GraphModule,
        vllm_config: VllmConfig,
        piecewise_compile_index: int,
        total_piecewise_compiles: int,
        sym_shape_indices: list[int],
        compiled_graph_for_general_shape: Callable,
        vllm_backend: VllmBackend,
    ):
        """Set up the piecewise backend for one compiled subgraph.

        Args:
            graph: The FX subgraph served by this backend.
            vllm_config: Global vLLM config. Compilation and scheduler
                settings are read from it.
            piecewise_compile_index: 0-based position of this subgraph among
                all piecewise subgraphs.
            total_piecewise_compiles: Total number of piecewise subgraphs.
            sym_shape_indices: Indices of the call arguments that are
                symbolic sizes. ``__call__`` reads the runtime shape from
                the first one.
            compiled_graph_for_general_shape: The graph compiled once for the
                dynamic shape. Every call is forwarded to it.
            vllm_backend: The owning backend. Its ``is_encoder`` flag
                enables the encoder range handling below.
        """
        self.compiled_graph_for_general_shape = compiled_graph_for_general_shape
        self.graph = graph
        self.vllm_config = vllm_config
        self.compilation_config = vllm_config.compilation_config
        self.piecewise_compile_index = piecewise_compile_index
        self.total_piecewise_compiles = total_piecewise_compiles
        self.vllm_backend = vllm_backend

        self.is_first_graph = piecewise_compile_index == 0
        self.is_last_graph = piecewise_compile_index == total_piecewise_compiles - 1

        self.is_full_graph = total_piecewise_compiles == 1
        self.is_encoder_compilation = vllm_backend.is_encoder

        # Take a private copy: the encoder branch below overwrites the last
        # element, and that must never leak back into the compilation config.
        self.compile_ranges = list(self.compilation_config.get_compile_ranges())
        if self.is_encoder_compilation:
            # For encoder compilation, the max int32 value is used as the
            # upper bound of the last compile range.
            max_int32 = 2**31 - 1
            last_compile_range = self.compile_ranges[-1]
            assert (last_compile_range.end ==
                    vllm_config.scheduler_config.max_num_batched_tokens)
            self.compile_ranges[-1] = Range(start=last_compile_range.start,
                                            end=max_int32)

        log_string = f"PiecewiseBackend: compile_ranges: {self.compile_ranges}"
        logger.debug_once(log_string)

        self.compile_sizes = self.compilation_config.compile_sizes
        log_string = f"PiecewiseBackend: compile_sizes: {self.compile_sizes}"
        logger.debug_once(log_string)

        self.sym_shape_indices = sym_shape_indices

        # One RangeEntry per known range: every compile range, plus a
        # single-point range for each compile size not already among them.
        self.range_entries: dict[Range, RangeEntry] = {}

        # Ranges still to be compiled. This class never reads it; presumably
        # it is consumed by upstream PiecewiseBackend logic (verify). It gets
        # its own set so that updates do not touch compile_ranges.
        self.to_be_compiled_ranges: set[Range] = set(self.compile_ranges)

        for size in self.compile_sizes:
            size_range = Range(start=size, end=size)
            if size_range not in self.compile_ranges:
                self.range_entries[size_range] = RangeEntry(compile_range=size_range)
                self.to_be_compiled_ranges.add(size_range)

        for compile_range in self.compile_ranges:
            self.range_entries[compile_range] = RangeEntry(
                compile_range=compile_range)

    def _find_range_for_shape(self, runtime_shape: int) -> RangeEntry | None:
        """Return the RangeEntry that covers ``runtime_shape``, or None.

        An exact compile size takes precedence. Otherwise the first compile
        range containing the shape is used.

        NOTE(review): this module patches only ``__init__`` and ``__call__``
        onto vLLM's ``PiecewiseBackend``. At runtime, ``self._find_range_for_shape``
        therefore resolves to the upstream method, not this copy. Confirm
        whether this method should be patched as well.
        """
        if runtime_shape in self.compile_sizes:
            return self.range_entries[Range(start=runtime_shape, end=runtime_shape)]
        for compile_range in self.compile_ranges:
            if runtime_shape in compile_range:
                return self.range_entries[compile_range]
        return None

    def __call__(self, *args) -> Any:
        """Validate the runtime shape, then run the general-shape compiled graph.

        The runtime shape is the argument at ``sym_shape_indices[0]``.

        Raises:
            AssertionError: if the runtime shape falls outside every known
                compile range / compile size (only when asserts are enabled).
        """
        runtime_shape = args[self.sym_shape_indices[0]]
        range_entry = self._find_range_for_shape(runtime_shape)

        # Report the actual ranges: for encoder compilation the upper bound
        # is int32 max, not max_num_batched_tokens.
        assert range_entry is not None, (
            f"Shape out of considered range: {runtime_shape} "
            f"not in any of {self.compile_ranges}")

        return self.compiled_graph_for_general_shape(*args)
|
|
|
|
|
|
# Install the Ascend implementations into vLLM when this module is imported.
# The interpreter class is replaced on the `vllm.compilation.backends` module.
# PiecewiseBackend keeps its identity: only its constructor and call operator
# are swapped. Its other methods (e.g. `_find_range_for_shape`, which the
# patched `__call__` invokes) remain upstream's.
# NOTE(review): any module that bound PiecewiseCompileInterpreter via
# `from vllm.compilation.backends import ...` before this runs keeps the
# original class — confirm import order.
vllm.compilation.backends.PiecewiseCompileInterpreter = AscendPiecewiseCompileInterpreter
setattr(vllm.compilation.piecewise_backend.PiecewiseBackend, "__init__",
        AscendPiecewiseBackend.__init__)
setattr(vllm.compilation.piecewise_backend.PiecewiseBackend, "__call__",
        AscendPiecewiseBackend.__call__)
|