update
This commit is contained in:
158
vllm/distributed/weight_transfer/base.py
Normal file
158
vllm/distributed/weight_transfer/base.py
Normal file
@@ -0,0 +1,158 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""Base class for weight transfer engines."""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import Callable
|
||||
from dataclasses import KW_ONLY, dataclass, field
|
||||
from typing import Any, Generic, TypeVar
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.config.parallel import ParallelConfig
|
||||
from vllm.config.weight_transfer import WeightTransferConfig
|
||||
|
||||
TInitInfo = TypeVar("TInitInfo", bound="WeightTransferInitInfo")
|
||||
TUpdateInfo = TypeVar("TUpdateInfo", bound="WeightTransferUpdateInfo")
|
||||
|
||||
|
||||
# Base classes for backend-specific dataclasses
|
||||
@dataclass
class WeightTransferInitInfo(ABC):  # noqa: B024
    """Base class for backend-specific initialization info.

    This base declares no fields of its own; it exists so that
    backend-specific init-info dataclasses share a common parent type
    (it is the bound of the ``TInitInfo`` type variable in this module).
    It defines no abstract methods, hence the ``noqa`` on the class line.
    """
|
||||
|
||||
|
||||
@dataclass
|
||||
class WeightTransferUpdateInfo(ABC): # noqa: B024
|
||||
"""Base class for backend-specific weight update info."""
|
||||
|
||||
_: KW_ONLY
|
||||
is_checkpoint_format: bool = True
|
||||
"""Set to True if weights are in checkpoint/original model format and need
|
||||
layerwise processing. Set to False if weights have already been processed
|
||||
into kernel format (repacking, renaming, etc.)."""
|
||||
|
||||
|
||||
# API-level request classes (accept dicts for backend-agnostic serialization)
|
||||
@dataclass
class WeightTransferInitRequest:
    """API-level weight transfer initialization request.

    Carries the initialization parameters as a plain dict so the request
    stays backend-agnostic; ``WeightTransferEngine.parse_init_info`` turns
    such a dict into the engine's typed init-info dataclass.
    """

    # A fresh dict per instance (default_factory) so requests never share
    # one mutable default.
    init_info: dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
|
||||
@dataclass
class WeightTransferUpdateRequest:
    """API-level weight update request.

    Carries the update parameters as a plain dict so the request stays
    backend-agnostic; ``WeightTransferEngine.parse_update_info`` turns such a
    dict into the engine's typed update-info dataclass.
    """

    # A fresh dict per instance (default_factory) so requests never share
    # one mutable default.
    update_info: dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
|
||||
class WeightTransferEngine(ABC, Generic[TInitInfo, TUpdateInfo]):
    """
    Base class for weight transfer engines that handle transport of model weights
    from a trainer to inference workers.

    This abstraction separates weight transfer transport logic from the worker
    implementation, allowing different backends (NCCL, CUDA IPC[TODO],
    RDMA[TODO]) to be plugged in.

    Subclasses should define:
        init_info_cls: Type of backend-specific initialization info
        update_info_cls: Type of backend-specific update info
    """

    # Subclasses should override these class attributes
    init_info_cls: type[TInitInfo]
    update_info_cls: type[TUpdateInfo]

    def __init__(
        self, config: WeightTransferConfig, parallel_config: ParallelConfig
    ) -> None:
        """
        Initialize the weight transfer engine.

        Args:
            config: The configuration for the weight transfer engine
            parallel_config: The configuration for the parallel setup
        """
        self.config = config
        self.parallel_config = parallel_config

    def parse_init_info(self, init_dict: dict[str, Any]) -> TInitInfo:
        """
        Construct typed init info from dict with validation.

        Args:
            init_dict: Dictionary containing backend-specific initialization
                parameters

        Returns:
            Typed backend-specific init info dataclass

        Raises:
            ValueError: If init_dict is invalid for this backend
        """
        try:
            return self.init_info_cls(**init_dict)
        except TypeError as e:
            # A TypeError from the constructor call (e.g. unexpected or
            # missing keyword arguments) is re-raised as ValueError, with the
            # original exception chained as the cause.
            raise ValueError(
                f"Invalid init_info for {self.__class__.__name__}: {e}"
            ) from e

    def parse_update_info(self, update_dict: dict[str, Any]) -> TUpdateInfo:
        """
        Construct typed update info from dict with validation.

        Args:
            update_dict: Dictionary containing backend-specific update
                parameters

        Returns:
            Typed backend-specific update info dataclass

        Raises:
            ValueError: If update_dict is invalid for this backend
        """
        try:
            return self.update_info_cls(**update_dict)
        except TypeError as e:
            # Same translation as parse_init_info: constructor TypeError
            # becomes ValueError with the cause preserved.
            raise ValueError(
                f"Invalid update_info for {self.__class__.__name__}: {e}"
            ) from e

    @abstractmethod
    def init_transfer_engine(self, init_info: TInitInfo) -> None:
        """
        Initialize the weight transfer mechanism.
        This is called once at the beginning of training.

        Args:
            init_info: Backend-specific initialization info
        """
        raise NotImplementedError

    @abstractmethod
    def receive_weights(
        self,
        update_info: TUpdateInfo,
        load_weights: Callable[[list[tuple[str, torch.Tensor]]], None],
    ) -> None:
        """
        Receive weights from the trainer and load them incrementally.

        Args:
            update_info: Backend-specific update info containing parameter
                metadata and any backend-specific data
            load_weights: Callable that loads weights into the model. Called
                incrementally for each weight to avoid OOM.
        """
        raise NotImplementedError

    @abstractmethod
    def shutdown(self) -> None:
        """
        Shutdown the weight transfer engine.
        This should be called when the worker is shutting down.
        """
        raise NotImplementedError
|
||||
Reference in New Issue
Block a user