Files
enginex-bi_150-vllm/vllm/distributed/weight_transfer/base.py
2026-04-09 11:23:47 +08:00

159 lines
5.0 KiB
Python

# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Base class for weight transfer engines."""
from abc import ABC, abstractmethod
from collections.abc import Callable
from dataclasses import KW_ONLY, dataclass, field
from typing import Any, Generic, TypeVar
import torch
from vllm.config.parallel import ParallelConfig
from vllm.config.weight_transfer import WeightTransferConfig
# Type variables tying an engine subclass to its backend-specific info
# dataclasses; bounds are forward references since the base classes are
# defined below. See WeightTransferEngine.init_info_cls / update_info_cls.
TInitInfo = TypeVar("TInitInfo", bound="WeightTransferInitInfo")
TUpdateInfo = TypeVar("TUpdateInfo", bound="WeightTransferUpdateInfo")
# Base protocols for backend-specific dataclasses
@dataclass
class WeightTransferInitInfo(ABC):  # noqa: B024
    """Marker base for backend-specific initialization payloads.

    Concrete transfer backends subclass this with whatever fields their
    handshake requires; the base itself carries no data.
    """
@dataclass
class WeightTransferUpdateInfo(ABC): # noqa: B024
"""Base class for backend-specific weight update info."""
_: KW_ONLY
is_checkpoint_format: bool = True
"""Set to True if weights are in checkpoint/original model format and need
layerwise processing. Set to False if weights have already been processed
into kernel format (repacking, renaming, etc.)."""
# API-level request classes (accept dicts for backend-agnostic serialization)
@dataclass
class WeightTransferInitRequest:
    """Transport-agnostic initialization request.

    ``init_info`` is the serialized (dict) form of a backend-specific
    ``WeightTransferInitInfo``; engines re-typed it via
    ``parse_init_info``. Defaults to an empty dict per instance.
    """

    init_info: dict[str, Any] = field(default_factory=dict)
@dataclass
class WeightTransferUpdateRequest:
    """Transport-agnostic weight-update request.

    ``update_info`` is the serialized (dict) form of a backend-specific
    ``WeightTransferUpdateInfo``; engines re-type it via
    ``parse_update_info``. Defaults to an empty dict per instance.
    """

    update_info: dict[str, Any] = field(default_factory=dict)
class WeightTransferEngine(ABC, Generic[TInitInfo, TUpdateInfo]):
    """
    Base class for weight transfer engines that handle transport of model
    weights from a trainer to inference workers.

    This abstraction separates weight transfer transport logic from the worker
    implementation, allowing different backends (NCCL, CUDA IPC[TODO],
    RDMA[TODO]) to be plugged in.

    Subclasses should define:
        init_info_cls: Type of backend-specific initialization info
        update_info_cls: Type of backend-specific update info
    """

    # Subclasses should override these class attributes.
    init_info_cls: type[TInitInfo]
    update_info_cls: type[TUpdateInfo]

    def __init__(
        self, config: WeightTransferConfig, parallel_config: ParallelConfig
    ) -> None:
        """
        Initialize the weight transfer engine.

        Args:
            config: The configuration for the weight transfer engine
            parallel_config: The configuration for the parallel setup
        """
        self.config = config
        self.parallel_config = parallel_config

    def _parse_info(
        self, info_cls: type, info_dict: dict[str, Any], label: str
    ) -> Any:
        """Instantiate ``info_cls`` from ``info_dict``, normalizing errors.

        A ``TypeError`` from the dataclass constructor (unknown or missing
        keys) is re-raised as ``ValueError`` so callers see a uniform
        validation error regardless of backend. ``label`` names the field
        ("init_info" / "update_info") in the error message.
        """
        try:
            return info_cls(**info_dict)
        except TypeError as e:
            raise ValueError(
                f"Invalid {label} for {self.__class__.__name__}: {e}"
            ) from e

    def parse_init_info(self, init_dict: dict[str, Any]) -> TInitInfo:
        """
        Construct typed init info from dict with validation.

        Args:
            init_dict: Dictionary containing backend-specific initialization
                parameters

        Returns:
            Typed backend-specific init info dataclass

        Raises:
            ValueError: If init_dict is invalid for this backend
        """
        return self._parse_info(self.init_info_cls, init_dict, "init_info")

    def parse_update_info(self, update_dict: dict[str, Any]) -> TUpdateInfo:
        """
        Construct typed update info from dict with validation.

        Args:
            update_dict: Dictionary containing backend-specific update
                parameters

        Returns:
            Typed backend-specific update info dataclass

        Raises:
            ValueError: If update_dict is invalid for this backend
        """
        return self._parse_info(self.update_info_cls, update_dict, "update_info")

    @abstractmethod
    def init_transfer_engine(self, init_info: TInitInfo) -> None:
        """
        Initialize the weight transfer mechanism.

        This is called once at the beginning of training.

        Args:
            init_info: Backend-specific initialization info
        """
        raise NotImplementedError

    @abstractmethod
    def receive_weights(
        self,
        update_info: TUpdateInfo,
        load_weights: Callable[[list[tuple[str, torch.Tensor]]], None],
    ) -> None:
        """
        Receive weights from the trainer and load them incrementally.

        Args:
            update_info: Backend-specific update info containing parameter
                metadata and any backend-specific data
            load_weights: Callable that loads weights into the model. Called
                incrementally for each weight to avoid OOM.
        """
        raise NotImplementedError

    @abstractmethod
    def shutdown(self) -> None:
        """
        Shutdown the weight transfer engine.

        This should be called when the worker is shutting down.
        """
        raise NotImplementedError