# SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import time from abc import ABC, abstractmethod from collections.abc import Callable from concurrent.futures import Future from functools import cached_property from typing import TYPE_CHECKING, Literal, TypeVar, overload from vllm.config import VllmConfig from vllm.distributed.kv_transfer.kv_connector.utils import KVOutputAggregator from vllm.distributed.kv_transfer.kv_connector.v1.base import ( KVConnectorHandshakeMetadata, ) from vllm.logger import init_logger from vllm.lora.request import LoRARequest from vllm.tasks import SupportedTask from vllm.utils.import_utils import resolve_obj_by_qualname from vllm.v1.core.sched.output import GrammarOutput, SchedulerOutput from vllm.v1.engine import ReconfigureDistributedRequest from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec from vllm.v1.outputs import DraftTokenIds, ModelRunnerOutput from vllm.v1.worker.worker_base import WorkerBase if TYPE_CHECKING: from vllm.distributed.kv_transfer.kv_connector.base import KVConnectorBase logger = init_logger(__name__) _R = TypeVar("_R") FailureCallback = Callable[[], None] class Executor(ABC): """Abstract base class for vLLM executors." An executor is responsible for executing the model on one device, or it can be a distributed executor that can execute the model on multiple devices. """ uses_ray: bool = False # whether the executor uses Ray for orchestration. supports_pp: bool = False # whether the executor supports PP @staticmethod def get_class(vllm_config: VllmConfig) -> type["Executor"]: executor_class: type[Executor] parallel_config = vllm_config.parallel_config distributed_executor_backend = parallel_config.distributed_executor_backend # distributed_executor_backend must be set in VllmConfig.__post_init__ if isinstance(distributed_executor_backend, type): if not issubclass(distributed_executor_backend, Executor): raise TypeError( "distributed_executor_backend must be a subclass of " f"Executor. Got {distributed_executor_backend}." ) executor_class = distributed_executor_backend elif distributed_executor_backend == "ray": from vllm.v1.executor.ray_executor import RayDistributedExecutor executor_class = RayDistributedExecutor elif distributed_executor_backend == "mp": from vllm.v1.executor.multiproc_executor import MultiprocExecutor executor_class = MultiprocExecutor elif distributed_executor_backend == "uni": from vllm.v1.executor.uniproc_executor import UniProcExecutor executor_class = UniProcExecutor elif distributed_executor_backend == "external_launcher": # TODO: make v1 scheduling deterministic # to support external launcher executor_class = ExecutorWithExternalLauncher elif isinstance(distributed_executor_backend, str): executor_class = resolve_obj_by_qualname(distributed_executor_backend) if not issubclass(executor_class, Executor): raise TypeError( "distributed_executor_backend must be a subclass of " f"Executor. Got {executor_class}." ) else: raise ValueError( f"Unknown distributed executor backend: {distributed_executor_backend}" ) return executor_class def __init__( self, vllm_config: VllmConfig, ) -> None: self.vllm_config = vllm_config self.model_config = vllm_config.model_config self.cache_config = vllm_config.cache_config self.lora_config = vllm_config.lora_config self.load_config = vllm_config.load_config self.parallel_config = vllm_config.parallel_config self.scheduler_config = vllm_config.scheduler_config self.device_config = vllm_config.device_config self.speculative_config = vllm_config.speculative_config self.observability_config = vllm_config.observability_config self._init_executor() self.is_sleeping = False self.sleeping_tags: set[str] = set() self.kv_output_aggregator: KVOutputAggregator | None = None @abstractmethod def _init_executor(self) -> None: raise NotImplementedError def initialize_from_config(self, kv_cache_configs: list[KVCacheConfig]) -> None: """ Initialize the KV caches and begin the model execution loop of the underlying workers. """ self.collective_rpc("initialize_from_config", args=(kv_cache_configs,)) self.collective_rpc("compile_or_warm_up_model") def register_failure_callback(self, callback: FailureCallback): # noqa: B027 """ Register a function to be called if the executor enters a permanent failed state. """ pass def determine_available_memory(self) -> list[int]: # in bytes return self.collective_rpc("determine_available_memory") def get_kv_cache_specs(self) -> list[dict[str, KVCacheSpec]]: return self.collective_rpc("get_kv_cache_spec") @overload def collective_rpc( self, method: str | Callable[[WorkerBase], _R], timeout: float | None = None, args: tuple = (), kwargs: dict | None = None, non_block: Literal[False] = False, ) -> list[_R]: """ Execute an RPC call on all workers. Args: method: Name of the worker method to execute, or a callable that is serialized and sent to all workers to execute. If the method is a callable, it should accept an additional `self` argument, in addition to the arguments passed in `args` and `kwargs`. The `self` argument will be the worker object. timeout: Maximum time in seconds to wait for execution. Raises a [`TimeoutError`][] on timeout. `None` means wait indefinitely. args: Positional arguments to pass to the worker method. kwargs: Keyword arguments to pass to the worker method. non_block: If `True`, returns a list of Futures instead of waiting for the results. Returns: A list containing the results from each worker. Note: It is recommended to use this API to only pass control messages, and set up data-plane communication to pass data. """ pass @overload def collective_rpc( self, method: str | Callable[[WorkerBase], _R], timeout: float | None = None, args: tuple = (), kwargs: dict | None = None, non_block: Literal[True] = True, ) -> Future[list[_R]]: pass @abstractmethod def collective_rpc( self, method, timeout=None, args=(), kwargs=None, non_block: bool = False ): raise NotImplementedError def get_kv_connector_handshake_metadata( self, ) -> list[dict[int, KVConnectorHandshakeMetadata]]: return self.collective_rpc("get_kv_connector_handshake_metadata") @overload def execute_model( self, scheduler_output: SchedulerOutput, non_block: Literal[False] = False ) -> ModelRunnerOutput | None: pass @overload def execute_model( self, scheduler_output: SchedulerOutput, non_block: Literal[True] = True ) -> Future[ModelRunnerOutput | None]: pass def execute_model( self, scheduler_output: SchedulerOutput, non_block: bool = False ) -> ModelRunnerOutput | None | Future[ModelRunnerOutput | None]: output = self.collective_rpc( # type: ignore[call-overload] "execute_model", args=(scheduler_output,), non_block=non_block ) return output[0] @overload def sample_tokens( self, grammar_output: GrammarOutput | None, non_block: Literal[False] = False ) -> ModelRunnerOutput: pass @overload def sample_tokens( self, grammar_output: GrammarOutput | None, non_block: Literal[True] = True ) -> Future[ModelRunnerOutput]: pass def sample_tokens( self, grammar_output: GrammarOutput | None, non_block: bool = False ) -> ModelRunnerOutput | None | Future[ModelRunnerOutput | None]: output = self.collective_rpc( # type: ignore[call-overload] "sample_tokens", args=(grammar_output,), non_block=non_block ) return output[0] def execute_dummy_batch(self) -> None: self.collective_rpc("execute_dummy_batch") def take_draft_token_ids(self) -> DraftTokenIds | None: output: list[DraftTokenIds] = self.collective_rpc("take_draft_token_ids") return output[0] @property def max_concurrent_batches(self) -> int: return 1 def profile(self, is_start: bool = True): self.collective_rpc("profile", args=(is_start,)) def save_sharded_state( self, path: str, pattern: str | None = None, max_size: int | None = None, ) -> None: self.collective_rpc( "save_sharded_state", kwargs=dict(path=path, pattern=pattern, max_size=max_size), ) @abstractmethod def check_health(self) -> None: """Checks if the executor is healthy. If not, it should raise an exception.""" raise NotImplementedError def shutdown(self) -> None: """Shutdown the executor.""" self.collective_rpc("shutdown") def init_kv_output_aggregator(self, connector: "KVConnectorBase") -> None: """Init KVOutputAggregator""" self.kv_output_aggregator = KVOutputAggregator.from_connector( connector, self.parallel_config.world_size ) @cached_property # Avoid unnecessary RPC calls def supported_tasks(self) -> tuple[SupportedTask, ...]: output: list[tuple[SupportedTask, ...]] output = self.collective_rpc("get_supported_tasks") return output[0] def add_lora(self, lora_request: LoRARequest) -> bool: assert lora_request.lora_int_id > 0, "lora_id must be greater than 0." return all(self.collective_rpc("add_lora", args=(lora_request,))) def remove_lora(self, lora_id: int) -> bool: assert lora_id > 0, "lora_id must be greater than 0." return all(self.collective_rpc("remove_lora", args=(lora_id,))) def pin_lora(self, lora_id: int) -> bool: assert lora_id > 0, "lora_id must be greater than 0." return all(self.collective_rpc("pin_lora", args=(lora_id,))) def list_loras(self) -> set[int]: sets: list[set[int]] = self.collective_rpc("list_loras") for s in sets: assert s == sets[0], "All workers should have the same LORAs." return sets[0] def reset_mm_cache(self) -> None: """Reset the multi-modal cache in each worker.""" self.collective_rpc("reset_mm_cache") def sleep(self, level: int = 1): if self.is_sleeping: logger.warning("Executor is already sleeping.") return time_before_sleep = time.perf_counter() self.collective_rpc("sleep", kwargs=dict(level=level)) time_after_sleep = time.perf_counter() self.sleeping_tags = {"weights", "kv_cache"} self.is_sleeping = True logger.info( "It took %.6f seconds to fall asleep.", time_after_sleep - time_before_sleep ) def wake_up(self, tags: list[str] | None = None): if not self.is_sleeping: logger.warning("Executor is not sleeping.") return if tags: for tag in tags: if tag not in self.sleeping_tags: logger.warning( "Tag %s is not in sleeping tags %s", tag, self.sleeping_tags ) return time_before_wakeup = time.perf_counter() self.collective_rpc("wake_up", kwargs=dict(tags=tags)) time_after_wakeup = time.perf_counter() logger.info( "It took %.6f seconds to wake up tags %s.", time_after_wakeup - time_before_wakeup, tags if tags is not None else self.sleeping_tags, ) if tags: for tag in tags: self.sleeping_tags.remove(tag) else: self.sleeping_tags.clear() if not self.sleeping_tags: self.is_sleeping = False def reinitialize_distributed( self, reconfig_request: ReconfigureDistributedRequest ) -> None: raise NotImplementedError from vllm.v1.executor.uniproc_executor import ( # noqa: E402 ExecutorWithExternalLauncher as _ExecutorWithExternalLauncher, ) from vllm.v1.executor.uniproc_executor import ( # noqa: E402 UniProcExecutor as _UniProcExecutor, ) # For backwards compatibility. UniProcExecutor = _UniProcExecutor ExecutorWithExternalLauncher = _ExecutorWithExternalLauncher