# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import dataclasses
from typing import Any, Callable

import torch.fx as fx

import vllm.envs as envs
from vllm.compilation.backends import VllmBackend
from vllm.compilation.monitor import end_monitoring_torch_compile
from vllm.config import VllmConfig
from vllm.logger import init_logger

logger = init_logger(__name__)


@dataclasses.dataclass
class ConcreteSizeEntry:
    """Per-shape dispatch state kept by `PiecewiseBackend`."""

    # the concrete runtime shape this entry is registered for
    runtime_shape: int
    # set once a shape-specific compilation has been stored in `runnable`
    compiled: bool = False
    # the callable invoked for this shape; `PiecewiseBackend` seeds it with
    # the general-shape graph and replaces it after compiling for the shape
    runnable: Callable = None  # type: ignore


class PiecewiseBackend:

    def __init__(self, graph: fx.GraphModule, vllm_config: VllmConfig,
                 piecewise_compile_index: int, total_piecewise_compiles: int,
                 sym_shape_indices: list[int],
                 compiled_graph_for_general_shape: Callable,
                 vllm_backend: VllmBackend):
        """
        The backend for piecewise compilation.
        It mainly handles the compilation of static shapes and
        dispatching based on runtime shape.

        We will compile `self.graph` once for the general shape,
        and then compile for different shapes specified in
        `compilation_config.compile_sizes`.

        Args:
            graph: the fx graph handed to the compiler manager whenever
                a shape-specific compilation is requested.
            vllm_config: the config whose `compilation_config` provides
                `compile_sizes` and `inductor_compile_config`.
            piecewise_compile_index: index of this piece; forwarded to
                the compiler manager as `graph_index`.
            total_piecewise_compiles: number of pieces; forwarded to
                the compiler manager as `num_graphs`.
            sym_shape_indices: positions in the call arguments; the
                argument at the first position is read as the runtime
                shape.
            compiled_graph_for_general_shape: the callable used for the
                first call and for every shape not in `compile_sizes`.
            vllm_backend: the backend whose `compiler_manager` performs
                the compilation and saves its state to file.
        """
        self.graph = graph
        self.vllm_config = vllm_config
        self.compilation_config = vllm_config.compilation_config
        self.piecewise_compile_index = piecewise_compile_index
        self.total_piecewise_compiles = total_piecewise_compiles
        self.vllm_backend = vllm_backend

        self.is_first_graph = piecewise_compile_index == 0
        self.is_last_graph = (
            piecewise_compile_index == total_piecewise_compiles - 1)

        self.is_full_graph = total_piecewise_compiles == 1

        self.compile_sizes: set[int] = set(
            self.compilation_config.compile_sizes)

        self.first_run_finished = False

        self.compiled_graph_for_general_shape = compiled_graph_for_general_shape  # noqa

        self.sym_shape_indices = sym_shape_indices

        self.is_debugging_mode = envs.VLLM_LOGGING_LEVEL == "DEBUG"

        # to_be_compiled_sizes tracks the remaining sizes to compile,
        # and updates during the compilation process, so we need to copy it
        self.to_be_compiled_sizes: set[int] = self.compile_sizes.copy()

        # We only keep compilation management inside this class directly.
        # the entries for different shapes that we need to compile; every
        # entry starts out running the general-shape graph
        self.concrete_size_entries: dict[int, ConcreteSizeEntry] = {
            shape: ConcreteSizeEntry(
                runtime_shape=shape,
                runnable=self.compiled_graph_for_general_shape,
            )
            for shape in self.compile_sizes
        }

    def check_for_ending_compilation(self):
        """
        Finish compilation bookkeeping once nothing is left to compile.

        This only acts for the last piecewise graph, and only when
        `to_be_compiled_sizes` is empty; otherwise it is a no-op.
        """
        if self.is_last_graph and not self.to_be_compiled_sizes:
            # no specific sizes to compile
            # save the hash of the inductor graph for the next run
            self.vllm_backend.compiler_manager.save_to_file()
            end_monitoring_torch_compile(self.vllm_config)

    def __call__(self, *args) -> Any:
        """
        Run the graph, dispatching on the runtime shape.

        The very first call always runs the general-shape graph. Later
        calls read the runtime shape from `args[sym_shape_indices[0]]`;
        shapes without a registered entry run the general-shape graph,
        while registered shapes are compiled on first use and then run
        through their own compiled callable.
        """
        if not self.first_run_finished:
            self.first_run_finished = True
            self.check_for_ending_compilation()
            return self.compiled_graph_for_general_shape(*args)

        runtime_shape = args[self.sym_shape_indices[0]]
        entry = self.concrete_size_entries.get(runtime_shape)
        if entry is None:
            # we don't need to do anything for this shape
            return self.compiled_graph_for_general_shape(*args)

        if not entry.compiled:
            # args are real arguments
            compiled_runnable = self.vllm_backend.compiler_manager.compile(
                self.graph,
                args,
                self.compilation_config.inductor_compile_config,
                self.compilation_config,
                graph_index=self.piecewise_compile_index,
                num_graphs=self.total_piecewise_compiles,
                runtime_shape=runtime_shape)

            # Only record the shape as compiled after compilation has
            # succeeded. If `compile` raises, the entry stays pending, so
            # the shape is not silently served by the general-shape graph
            # under a `compiled` flag, and the end-of-compilation
            # bookkeeping is not triggered for a size that never compiled.
            entry.runnable = compiled_runnable
            entry.compiled = True
            self.to_be_compiled_sizes.remove(runtime_shape)

            # finished compilations for all required shapes; the helper
            # itself checks `is_last_graph` and the remaining sizes
            self.check_for_ending_compilation()

        return entry.runnable(*args)