# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
ECConnectorBase Class for Distributed Encoder Cache &
P2P Encoder cache communication in V1

The class provides the following primitives:
    Scheduler-side: runs in the scheduler, binds metadata, which
    is used by the worker-side to load/save Encoder cache.
        has_caches() - check, for each multimodal item of a request,
            whether its Encoder cache exists
        update_state_after_alloc() - update ECConnector state after
            allocation. This will decide to load the cache or not
        build_connector_meta() - build the connector metadata for the
            current step
        update_connector_output() - update ECConnector state from the
            worker-side connectors' output
        request_finished() - called when a request has finished,
            before its encoder cache is freed

    Worker-side: runs in each worker, loads/saves Encoder Cache to/from
    the Connector based on the metadata.
        bind_connector_metadata() / clear_connector_metadata() - set and
            clear the metadata around each model execution
        register_caches() - initialize the connector with the EC caches
        start_load_caches() - starts loading the ECs (maybe async)
        save_caches() - saves the Encoder cache of one `mm_hash`

        get_finished() - called with ids of finished requests, returns
            ids of requests that have completed async sending/recving.
"""

import enum
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Any

import torch

from vllm.logger import init_logger
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.outputs import ECConnectorOutput

if TYPE_CHECKING:
    from vllm.config import VllmConfig
    from vllm.v1.request import Request

logger = init_logger(__name__)


class ECConnectorRole(enum.Enum):
    """Process in which an ECConnector instance is running."""

    # Connector running in the scheduler process
    SCHEDULER = 0

    # Connector running in the worker process
    WORKER = 1


class ECConnectorMetadata(ABC):  # noqa: B024
    """
    Abstract Metadata used to communicate between the
    Scheduler ECConnector and Worker ECConnector.
    """


class ECConnectorBase(ABC):
    """
    Abstract base class for Encoder Cache (EC) connectors.

    A connector is created with a role (scheduler or worker) and reads
    whether it acts as an EC producer from
    `vllm_config.ec_transfer_config`.

    Subclasses must implement `start_load_caches()`, `save_caches()`,
    `has_caches()`, `update_state_after_alloc()` and
    `build_connector_meta()`. The remaining methods have no-op defaults.
    """

    def __init__(self, vllm_config: "VllmConfig", role: ECConnectorRole):
        """
        Args:
            vllm_config (VllmConfig): the vLLM configuration. Its
                `ec_transfer_config` must not be None.
            role (ECConnectorRole): the process this connector runs in.

        Raises:
            ValueError: if `vllm_config.ec_transfer_config` is None.
        """
        self._connector_metadata: ECConnectorMetadata | None = None
        self._vllm_config = vllm_config
        self._role = role

        ec_transfer_config = vllm_config.ec_transfer_config
        if ec_transfer_config is None:
            raise ValueError("ec_transfer_config must be set for ECConnectorBase")
        self._is_producer = ec_transfer_config.is_ec_producer

    @property
    def role(self) -> ECConnectorRole:
        """The role (scheduler or worker) this connector was created with."""
        return self._role

    @property
    def is_producer(self) -> bool:
        """Value of `ec_transfer_config.is_ec_producer` at construction."""
        return self._is_producer

    # ==============================
    # Worker-side methods
    # ==============================

    def bind_connector_metadata(self, connector_metadata: ECConnectorMetadata) -> None:
        """Set the connector metadata from the scheduler.

        This function should be called by the model runner every time
        before the model execution. The metadata will be used for runtime
        EC cache loading.

        Args:
            connector_metadata (ECConnectorMetadata): the connector metadata.
        """
        self._connector_metadata = connector_metadata

    def clear_connector_metadata(self) -> None:
        """Clear the connector metadata.

        This function should be called by the model runner every time
        after the model execution.
        """
        self._connector_metadata = None

    def _get_connector_metadata(self) -> ECConnectorMetadata:
        """Get the connector metadata.

        This function should only be called inside the connector.

        Returns:
            ECConnectorMetadata: the connector metadata.

        Raises:
            AssertionError: if no metadata is currently bound.
        """
        # Should only be called while set to valid metadata, i.e. between
        # bind_connector_metadata() and clear_connector_metadata().
        assert self._connector_metadata is not None, (
            "connector metadata is not bound; call bind_connector_metadata() first"
        )
        return self._connector_metadata

    def register_caches(
        self,
        ec_caches: dict[str, torch.Tensor],
    ) -> None:
        """
        Initialize with the EC caches.

        The base implementation does nothing.

        Args:
            ec_caches: dictionary of encoder cache
        """
        # TODO: Implement this later for P2P feature

    @abstractmethod
    def start_load_caches(
        self, encoder_cache: dict[str, torch.Tensor], **kwargs
    ) -> None:
        """
        Start loading the cache from the connector into vLLM's encoder cache.

        This method loads the encoder cache based on metadata provided by
        the scheduler. It is called before `_gather_mm_embeddings` for the
        EC Connector. The `encoder_cache` is passed explicitly; any
        additional connector-specific arguments arrive through `kwargs`.

        Args:
            encoder_cache (dict[str, torch.Tensor]): A dictionary mapping
                multimodal data hashes (`mm_hash`) to encoder cache tensors.
            kwargs (dict): Additional keyword arguments for the connector.
        """

    @abstractmethod
    def save_caches(
        self, encoder_cache: dict[str, torch.Tensor], mm_hash: str, **kwargs
    ) -> None:
        """
        Save the encoder cache to the connector.

        This method saves the encoder cache from the worker's local storage
        to shared storage or another external connector.

        Args:
            encoder_cache (dict[str, torch.Tensor]): A dictionary mapping
                multimodal data hashes (`mm_hash`) to encoder cache tensors.
            mm_hash (str): The hash of the multimodal data whose cache is
                being saved.
            kwargs (dict): Additional keyword arguments for the connector.
        """

    def get_finished(
        self, finished_req_ids: set[str]
    ) -> tuple[set[str] | None, set[str] | None]:
        """
        Notifies the worker-side connector of the ids of requests that have
        finished generating tokens on the worker.
        The scheduler process (via the Executors) will use this output
        to track which workers are done.

        The base implementation reports nothing and returns `(None, None)`.

        Args:
            finished_req_ids (set[str]): ids of the finished requests.

        Returns:
            ids of requests that have finished asynchronous transfer
            (requests that previously returned True from
            request_finished()), as a tuple of
            (sending/saving ids, recving/loading ids).
            The finished saves/sends req ids must belong to a set provided
            in a call to this method (this call or a prior one).
        """
        return None, None

    # ==============================
    # Scheduler-side methods
    # ==============================

    @abstractmethod
    def has_caches(
        self,
        request: "Request",
    ) -> list[bool]:
        """
        Check if the encoder cache exists for each mm data of the request.

        Args:
            request (Request): the request object.

        Returns:
            A list of bools where the ith value is True if the cache exists
            for the ith mm_data of the request.
        """

    @abstractmethod
    def update_state_after_alloc(self, request: "Request", index: int) -> None:
        """
        Update ECConnector state after cache allocation for a request.

        Args:
            request (Request): the request object.
            index (int): NOTE(review): not documented in this file;
                presumably the index of the multimodal item within
                `request` - confirm against the scheduler call site.
        """

    @abstractmethod
    def build_connector_meta(
        self, scheduler_output: SchedulerOutput
    ) -> ECConnectorMetadata:
        """
        Build the connector metadata for this step.

        This function should NOT modify fields in the scheduler_output.
        Also, calling this function will reset the state of the connector.

        Args:
            scheduler_output (SchedulerOutput): the scheduler output object.

        Returns:
            ECConnectorMetadata: the metadata for this step.
        """

    def update_connector_output(self, connector_output: ECConnectorOutput) -> None:
        """
        Update ECConnector state from worker-side connectors output.

        The base implementation does nothing.

        Args:
            connector_output (ECConnectorOutput): the worker-side
                connectors output.
        """

    def request_finished(
        self, request: "Request"
    ) -> tuple[bool, dict[str, Any] | None]:
        """
        Called when a request has finished, before its encoder cache is freed.

        The base implementation returns `(False, None)`.

        Args:
            request (Request): the finished request.

        Returns:
            A tuple whose first element is True if the request is being
            saved/sent asynchronously and its cache should not be freed
            until the request_id is returned from get_finished(). The
            second element is an optional dict.
            NOTE(review): the meaning of the dict is not defined in this
            file - confirm against the caller.
        """
        return False, None