145 lines
4.6 KiB
Python
145 lines
4.6 KiB
Python
# SPDX-License-Identifier: Apache-2.0
|
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
|
from abc import ABC, abstractmethod
|
|
|
|
from vllm.logger import init_logger
|
|
from vllm.v1.kv_offload.abstract import LoadStoreSpec
|
|
|
|
# One transfer request, expressed as (src_blocks_spec, dst_blocks_spec).
TransferSpec = tuple[LoadStoreSpec, LoadStoreSpec]

# Key used to route a transfer to its handler: (src_medium, dst_medium).
TransferType = tuple[str, str]

# Outcome of one transfer: (job_id, success).
TransferResult = tuple[int, bool]

# Module-level logger.
logger = init_logger(__name__)
|
|
|
|
|
|
class OffloadingHandler(ABC):
    """
    Abstract interface for handlers of asynchronous KV data transfers.

    This class runs in the worker.
    An implementation starts asynchronous KV data transfers on request
    and lets the caller collect their completion statuses afterwards.

    Primitives every implementation must provide:
        transfer_async() - kicks off a new transfer job.
        get_finished() - returns the (job_id, success) results of
            transfers that finished since the previous call.
    """

    @abstractmethod
    def transfer_async(self, job_id: int, spec: TransferSpec) -> bool:
        """
        Start an asynchronous transfer of KV data.

        Args:
            job_id: a unique ID, reported back when the transfer completes.
            spec: the (src, dst) spec of the KV data transfer.

        Returns:
            True if the transfer was submitted successfully.
        """

    @abstractmethod
    def get_finished(self) -> list[TransferResult]:
        """
        Collect the transfers that finished since the last call.

        Returns:
            A list of (job_id, success) tuples, one per finished transfer.
        """
|
|
|
|
|
|
class OffloadingWorker:
    """
    OffloadingWorker class for managing asynchronous KV data transfers
    using multiple OffloadingHandlers

    This class runs in the worker.
    It kicks off async KV data transfer requests, by delegating
    to one of its registered OffloadingHandlers, based on the transfer type.

    The class provides the following primitives:
        register_handler() - registers a new handler to handle
            a specific transfer type
        transfer_async() - kicks off a new transfer job
            using one of the registered handlers.
        get_finished() - returns a list of (job_id, success) results
            of newly finished transfers from all handlers.
    """

    def __init__(self):
        # Distinct handlers polled by get_finished(). Being a set, a handler
        # registered for several transfer types is polled only once.
        self.handlers: set[OffloadingHandler] = set()
        # Routing table: (src_medium, dst_medium) -> handler.
        self.transfer_type_to_handler: dict[TransferType, OffloadingHandler] = {}

    def register_handler(
        self,
        src_cls: type[LoadStoreSpec],
        dst_cls: type[LoadStoreSpec],
        handler: OffloadingHandler,
    ) -> None:
        """
        Registers a new handler.

        Args:
            src_cls: the source type of transfers handled by this handler.
            dst_cls: the destination type of transfers handled by this handler.
            handler: the handler that will handle transfers.

        Raises:
            AssertionError: if a handler is already registered for the
                (src_cls.medium(), dst_cls.medium()) transfer type.
        """
        transfer_type = (src_cls.medium(), dst_cls.medium())
        # Raised explicitly rather than via `assert` so the check still runs
        # under `python -O`; otherwise the routing entry would be silently
        # overwritten while the replaced handler stays in self.handlers.
        if transfer_type in self.transfer_type_to_handler:
            raise AssertionError(
                f"A handler is already registered for transfer type "
                f"{transfer_type!r}"
            )
        self.handlers.add(handler)
        self.transfer_type_to_handler[transfer_type] = handler

    def transfer_async(self, job_id: int, spec: TransferSpec) -> bool:
        """
        Initiates an asynchronous transfer of KV data.

        Args:
            job_id: a unique ID that will be used when notifying back on
                transfer completion.
            spec: the (src, dst) spec of the KV data transfer.

        Returns:
            True if transfer was submitted successfully. False if the
            handler reported a failed submission or raised an exception
            (the exception is logged, not propagated).

        Raises:
            AssertionError: if no handler is registered for the
                (src.medium(), dst.medium()) transfer type.
        """
        src, dst = spec
        transfer_type = (src.medium(), dst.medium())
        handler = self.transfer_type_to_handler.get(transfer_type)
        # Raised explicitly rather than via `assert`: with asserts stripped
        # (`python -O`) a missing handler would surface as an AttributeError
        # on None inside the try below and be logged as a transfer failure.
        if handler is None:
            raise AssertionError(
                f"No handler registered for transfer type {transfer_type!r}"
            )

        try:
            success = handler.transfer_async(job_id, spec)
        except Exception as e:
            # A failing handler must not propagate into the caller:
            # log with traceback and report the submission as failed.
            logger.warning(
                "Exception in %r transfer %d: %r",
                transfer_type,
                job_id,
                e,
                exc_info=True,
            )
            return False

        if not success:
            logger.warning("Failed to submit %r transfer %d", transfer_type, job_id)
        else:
            logger.debug("Submitted %r transfer %d: %r", transfer_type, job_id, spec)

        return success

    def get_finished(self) -> list[TransferResult]:
        """
        Get transfers finished since last call.

        Returns:
            A list of (job_id, success) of transfers, gathered from every
            registered handler. Handlers are kept in a set, so the order
            of results across handlers is unspecified.
        """
        return [
            result
            for handler in self.handlers
            for result in handler.get_finished()
        ]
|