# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import dataclasses
from collections.abc import Callable
from typing import Any

import torch.fx as fx

import vllm.envs as envs
from vllm.compilation.backends import VllmBackend
from vllm.compilation.monitor import end_monitoring_torch_compile
from vllm.config import VllmConfig
from vllm.logger import init_logger

logger = init_logger(__name__)


@dataclasses.dataclass
class ConcreteSizeEntry:
    """Bookkeeping record for one concrete size listed in `compile_sizes`."""

    # the concrete size this entry is keyed by
    runtime_shape: int
    # set to True by `PiecewiseBackend.__call__` the first time this size
    # is dispatched after the initial run
    compiled: bool = False
    # callable invoked for this size; `PiecewiseBackend` seeds it with the
    # general-shape graph and replaces it with the size-specific compile result
    runnable: Callable = None  # type: ignore


class PiecewiseBackend:
    """Callable wrapper around one graph piece that dispatches on runtime shape."""

    def __init__(
        self,
        graph: fx.GraphModule,
        vllm_config: VllmConfig,
        piecewise_compile_index: int,
        total_piecewise_compiles: int,
        sym_shape_indices: list[int],
        compiled_graph_for_general_shape: Callable,
        vllm_backend: VllmBackend,
    ):
        """
        Set up the piecewise compilation backend for a single graph piece.

        The backend is responsible for compiling static shapes and for
        choosing which compiled callable to run based on the runtime shape.
        `graph` has already been compiled once for the general shape
        (`compiled_graph_for_general_shape`); each size listed in
        `compilation_config.compile_sizes` is compiled on demand the first
        time it shows up in `__call__`.
        """
        compilation_config = vllm_config.compilation_config

        self.graph = graph
        self.vllm_config = vllm_config
        self.compilation_config = compilation_config
        self.piecewise_compile_index = piecewise_compile_index
        self.total_piecewise_compiles = total_piecewise_compiles
        self.vllm_backend = vllm_backend

        # position of this piece among all the pieces being compiled
        last_index = total_piecewise_compiles - 1
        self.is_first_graph = piecewise_compile_index == 0
        self.is_last_graph = piecewise_compile_index == last_index
        self.is_full_graph = total_piecewise_compiles == 1

        self.compile_sizes: set[int] = set(compilation_config.compile_sizes)

        self.first_run_finished = False

        self.compiled_graph_for_general_shape = (
            compiled_graph_for_general_shape
        )

        self.sym_shape_indices = sym_shape_indices

        self.is_debugging_mode = envs.VLLM_LOGGING_LEVEL == "DEBUG"

        # One entry per size we need to compile. Compilation management is
        # kept directly inside this class; until a size gets its own compile,
        # its entry runs the general-shape graph.
        self.concrete_size_entries: dict[int, ConcreteSizeEntry] = {
            size: ConcreteSizeEntry(
                runtime_shape=size,
                runnable=compiled_graph_for_general_shape,
            )
            for size in self.compile_sizes
        }

        # Sizes still waiting to be compiled. This set shrinks as
        # compilation proceeds, so it must be a copy of `compile_sizes`.
        self.to_be_compiled_sizes: set[int] = set(self.compile_sizes)

    def check_for_ending_compilation(self):
        """Finalize compilation once the last piece has no sizes pending."""
        if not self.is_last_graph or self.to_be_compiled_sizes:
            return

        # nothing (more) to compile for specific sizes:
        # save the hash of the inductor graph for the next run
        self.vllm_backend.compiler_manager.save_to_file()
        end_monitoring_torch_compile(self.vllm_config)

    def __call__(self, *args) -> Any:
        """Run the piece on `args`, compiling for the runtime shape if needed."""
        if not self.first_run_finished:
            # the very first call always goes through the general-shape graph
            self.first_run_finished = True
            self.check_for_ending_compilation()
            return self.compiled_graph_for_general_shape(*args)

        shape_arg_index = self.sym_shape_indices[0]
        runtime_shape = args[shape_arg_index]

        entry = self.concrete_size_entries.get(runtime_shape)
        if entry is None:
            # not one of the sizes we specialize for; nothing else to do
            return self.compiled_graph_for_general_shape(*args)

        if entry.compiled:
            return entry.runnable(*args)

        # first time we see this size: mark it before compiling
        entry.compiled = True
        self.to_be_compiled_sizes.remove(runtime_shape)

        # args are real arguments
        compiler_manager = self.vllm_backend.compiler_manager
        entry.runnable = compiler_manager.compile(
            self.graph,
            args,
            self.compilation_config.inductor_compile_config,
            self.compilation_config,
            graph_index=self.piecewise_compile_index,
            num_graphs=self.total_piecewise_compiles,
            runtime_shape=runtime_shape,
        )

        # finished compilations for all required shapes
        if self.is_last_graph and not self.to_be_compiled_sizes:
            self.check_for_ending_compilation()

        return entry.runnable(*args)