# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Base class for weight transfer engines."""

from abc import ABC, abstractmethod
from collections.abc import Callable
from dataclasses import KW_ONLY, dataclass, field
from typing import Any, Generic, TypeVar

import torch

from vllm.config.parallel import ParallelConfig
from vllm.config.weight_transfer import WeightTransferConfig

TInitInfo = TypeVar("TInitInfo", bound="WeightTransferInitInfo")
TUpdateInfo = TypeVar("TUpdateInfo", bound="WeightTransferUpdateInfo")
# Unbound type variable for the shared dict -> dataclass parsing helper.
_TInfo = TypeVar("_TInfo")


# Abstract base classes for backend-specific dataclasses (not typing.Protocols;
# backends subclass these directly).
@dataclass
class WeightTransferInitInfo(ABC):  # noqa: B024
    """Base class for backend-specific initialization info."""

    pass


@dataclass
class WeightTransferUpdateInfo(ABC):  # noqa: B024
    """Base class for backend-specific weight update info."""

    # Fields declared after this sentinel are keyword-only. Dataclasses move
    # keyword-only fields to the end of __init__, so subclasses may still add
    # required (non-default) fields despite the default below.
    _: KW_ONLY
    is_checkpoint_format: bool = True
    """Set to True if weights are in checkpoint/original model format and need
    layerwise processing. Set to False if weights have already been processed
    into kernel format (repacking, renaming, etc.)."""


# API-level request classes (accept dicts for backend-agnostic serialization)
@dataclass
class WeightTransferInitRequest:
    """API-level weight transfer initialization request."""

    # Raw backend-specific fields; converted to a typed dataclass by
    # WeightTransferEngine.parse_init_info.
    init_info: dict[str, Any] = field(default_factory=dict)


@dataclass
class WeightTransferUpdateRequest:
    """API-level weight update request."""

    # Raw backend-specific fields; converted to a typed dataclass by
    # WeightTransferEngine.parse_update_info.
    update_info: dict[str, Any] = field(default_factory=dict)


class WeightTransferEngine(ABC, Generic[TInitInfo, TUpdateInfo]):
    """
    Base class for weight transfer engines that handle transport of model
    weights from a trainer to inference workers.

    This abstraction separates weight transfer transport logic from the worker
    implementation, allowing different backends (NCCL, CUDA IPC[TODO],
    RDMA[TODO]) to be plugged in.

    Subclasses should define:
        init_info_cls: Type of backend-specific initialization info
        update_info_cls: Type of backend-specific update info
    """

    # Subclasses should override these class attributes
    init_info_cls: type[TInitInfo]
    update_info_cls: type[TUpdateInfo]

    def __init__(
        self, config: WeightTransferConfig, parallel_config: ParallelConfig
    ) -> None:
        """
        Initialize the weight transfer engine.

        Args:
            config: The configuration for the weight transfer engine
            parallel_config: The configuration for the parallel setup
        """
        self.config = config
        self.parallel_config = parallel_config

    def _parse_info(
        self, info_cls: type[_TInfo], info_dict: dict[str, Any], label: str
    ) -> _TInfo:
        """
        Instantiate ``info_cls`` from ``info_dict`` with uniform error reporting.

        Shared by ``parse_init_info`` and ``parse_update_info`` so both reject
        bad input the same way.

        Args:
            info_cls: Dataclass type to construct
            info_dict: Keyword arguments for ``info_cls``
            label: Name used in the error message ("init_info"/"update_info")

        Returns:
            The constructed ``info_cls`` instance

        Raises:
            ValueError: If constructing ``info_cls`` raises TypeError, e.g.
                unknown or missing fields, or ``info_dict`` not being a
                mapping. A TypeError raised by a subclass ``__post_init__`` is
                reported the same way.
        """
        try:
            return info_cls(**info_dict)
        except TypeError as e:
            raise ValueError(
                f"Invalid {label} for {self.__class__.__name__}: {e}"
            ) from e

    def parse_init_info(self, init_dict: dict[str, Any]) -> TInitInfo:
        """
        Construct typed init info from dict with validation.

        Args:
            init_dict: Dictionary containing backend-specific initialization
                parameters

        Returns:
            Typed backend-specific init info dataclass

        Raises:
            ValueError: If init_dict is invalid for this backend
        """
        return self._parse_info(self.init_info_cls, init_dict, "init_info")

    def parse_update_info(self, update_dict: dict[str, Any]) -> TUpdateInfo:
        """
        Construct typed update info from dict with validation.

        Args:
            update_dict: Dictionary containing backend-specific update
                parameters

        Returns:
            Typed backend-specific update info dataclass

        Raises:
            ValueError: If update_dict is invalid for this backend
        """
        return self._parse_info(self.update_info_cls, update_dict, "update_info")

    @abstractmethod
    def init_transfer_engine(self, init_info: TInitInfo) -> None:
        """
        Initialize the weight transfer mechanism.
        This is called once at the beginning of training.

        Args:
            init_info: Backend-specific initialization info
        """
        raise NotImplementedError

    @abstractmethod
    def receive_weights(
        self,
        update_info: TUpdateInfo,
        load_weights: Callable[[list[tuple[str, torch.Tensor]]], None],
    ) -> None:
        """
        Receive weights from the trainer and load them incrementally.

        Args:
            update_info: Backend-specific update info containing parameter
                metadata and any backend-specific data
            load_weights: Callable that loads weights into the model. Each
                call receives a list of (name, tensor) pairs; implementations
                should call it incrementally rather than materializing every
                weight at once, to avoid OOM.
        """
        raise NotImplementedError

    @abstractmethod
    def shutdown(self) -> None:
        """
        Shutdown the weight transfer engine.
        This should be called when the worker is shutting down.
        """
        raise NotImplementedError