Upgrade to vllm 0.17.0 corex v4.1 overlay
This commit is contained in:
@@ -3,7 +3,7 @@
|
||||
"""Base class for weight transfer engines."""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import Callable
|
||||
from collections.abc import Callable, Iterator
|
||||
from dataclasses import KW_ONLY, dataclass, field
|
||||
from typing import Any, Generic, TypeVar
|
||||
|
||||
@@ -156,3 +156,30 @@ class WeightTransferEngine(ABC, Generic[TInitInfo, TUpdateInfo]):
|
||||
This should be called when the worker is shutting down.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@staticmethod
|
||||
@abstractmethod
|
||||
def trainer_send_weights(
|
||||
iterator: Iterator[tuple[str, torch.Tensor]],
|
||||
trainer_args: dict[str, Any] | Any,
|
||||
) -> None:
|
||||
"""
|
||||
Send weights from trainer to inference workers.
|
||||
|
||||
This is a static method that can be called from the trainer process
|
||||
to send weights to all inference workers.
|
||||
|
||||
Args:
|
||||
iterator: Iterator of model parameters. Returns (name, tensor) tuples.
|
||||
The tensors should be on the appropriate device for the backend.
|
||||
trainer_args: Dictionary containing backend-specific arguments needed
|
||||
to send weights. The structure depends on the backend:
|
||||
- NCCL: Contains 'group', 'src', 'packed', etc.
|
||||
- IPC: Contains 'mode' ('http' or 'ray'),
|
||||
'llm_handle' (for Ray), 'url' (for HTTP), etc.
|
||||
|
||||
Example:
|
||||
>>> param_iter = ((n, p) for n, p in model.named_parameters())
|
||||
>>> engine.trainer_send_weights(param_iter, trainer_args)
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
Reference in New Issue
Block a user