v1.0
This commit is contained in:
0
distributed/kv_transfer/kv_pipe/__init__.py
Normal file
0
distributed/kv_transfer/kv_pipe/__init__.py
Normal file
Binary file not shown.
BIN
distributed/kv_transfer/kv_pipe/__pycache__/base.cpython-312.pyc
Normal file
BIN
distributed/kv_transfer/kv_pipe/__pycache__/base.cpython-312.pyc
Normal file
Binary file not shown.
Binary file not shown.
Binary file not shown.
66
distributed/kv_transfer/kv_pipe/base.py
Normal file
66
distributed/kv_transfer/kv_pipe/base.py
Normal file
@@ -0,0 +1,66 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""
|
||||
This file defines an interface `KVPipeBase`
|
||||
that provides an abstraction for sending and receiving tensors, or None, via
|
||||
distributed communications.
|
||||
|
||||
All classes implementing this interface are assumed to be FIFO pipes.
|
||||
|
||||
If your distributed communication platform already supports key-value lookup,
|
||||
you can bypass this interface and directly start from `kv_lookup_buffer`.
|
||||
"""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
class KVPipeBase(ABC):
|
||||
"""
|
||||
This class provides an interface for sending and receiving tensors, or
|
||||
None, by distributed communications.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def send_tensor(self, tensor: torch.Tensor | None) -> None:
|
||||
"""Send a tensor, or None, via the pipe.
|
||||
|
||||
Need to support sending None -- important for error handling.
|
||||
|
||||
TODO: add a `key` argument so that we can use traditional
|
||||
key-value database as the distributed communication mechanism behind
|
||||
the pipe.
|
||||
|
||||
Args:
|
||||
tensor (Optional[torch.Tensor]): The tensor to be sent. Can be None.
|
||||
|
||||
Raises:
|
||||
NotImplementedError: This method must be implemented in subclasses.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def recv_tensor(self) -> torch.Tensor | None:
|
||||
"""Receive a tensor (can be None) from the pipeline.
|
||||
|
||||
Returns:
|
||||
Optional[torch.Tensor]: The tensor received from the pipeline. Can
|
||||
be None.
|
||||
|
||||
Raises:
|
||||
NotImplementedError: This method must be implemented in subclasses.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def close(self) -> None:
|
||||
"""Close the pipeline and release resources.
|
||||
|
||||
This method is responsible for closing the communication pipeline
|
||||
and releasing any resources associated with it.
|
||||
|
||||
Raises:
|
||||
NotImplementedError: This method must be implemented in subclasses.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
295
distributed/kv_transfer/kv_pipe/mooncake_pipe.py
Normal file
295
distributed/kv_transfer/kv_pipe/mooncake_pipe.py
Normal file
@@ -0,0 +1,295 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import json
|
||||
import os
|
||||
import struct
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from dataclasses import dataclass
|
||||
|
||||
import torch
|
||||
import zmq
|
||||
from safetensors.torch import load as safetensors_load
|
||||
from safetensors.torch import save as safetensors_save
|
||||
|
||||
from vllm.config.kv_transfer import KVTransferConfig
|
||||
from vllm.distributed.kv_transfer.kv_pipe.base import KVPipeBase
|
||||
from vllm.logger import init_logger
|
||||
from vllm.utils.network_utils import join_host_port, make_zmq_path, split_host_port
|
||||
|
||||
# Module-level logger, named after this module.
logger = init_logger(__name__)
|
||||
# Sentinel used on the wire in place of `None`: `MooncakePipe` sends a
# one-element tensor holding this integer when asked to send `None`, and maps
# a received one-element tensor with this value back to `None`.
NONE_INT = -150886311
|
||||
|
||||
|
||||
@dataclass
|
||||
class MooncakeTransferEngineConfig:
|
||||
prefill_url: str
|
||||
decode_url: str
|
||||
metadata_backend: str | None
|
||||
metadata_server: str
|
||||
protocol: str
|
||||
device_name: str
|
||||
|
||||
@staticmethod
|
||||
def from_file(file_path: str) -> "MooncakeTransferEngineConfig":
|
||||
"""Load the config from a JSON file."""
|
||||
with open(file_path) as fin:
|
||||
config = json.load(fin)
|
||||
return MooncakeTransferEngineConfig(
|
||||
prefill_url=config.get("prefill_url"),
|
||||
decode_url=config.get("decode_url"),
|
||||
metadata_backend=config.get("metadata_backend", None),
|
||||
metadata_server=config.get("metadata_server"),
|
||||
protocol=config.get("protocol", "tcp"),
|
||||
device_name=config.get("device_name", ""),
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def load_from_env() -> "MooncakeTransferEngineConfig":
|
||||
"""Load config from a file specified in the environment variable."""
|
||||
config_file_path = os.getenv("MOONCAKE_CONFIG_PATH")
|
||||
if config_file_path is None:
|
||||
raise ValueError(
|
||||
"The environment variable 'MOONCAKE_CONFIG_PATH' is not set."
|
||||
)
|
||||
return MooncakeTransferEngineConfig.from_file(config_file_path)
|
||||
|
||||
|
||||
class MooncakeTransferEngine:
|
||||
"""Handles the transfer of data using mooncake_vllm_adaptor and ZeroMQ."""
|
||||
|
||||
def __init__(self, kv_rank: int, local_rank: int):
|
||||
try:
|
||||
from mooncake.engine import TransferEngine
|
||||
except ImportError as e:
|
||||
raise ImportError(
|
||||
"Please install mooncake by following the instructions at "
|
||||
"https://github.com/kvcache-ai/Mooncake/blob/main/doc/en/build.md " # noqa: E501
|
||||
"to run vLLM with MooncakeConnector."
|
||||
) from e
|
||||
|
||||
self.engine = TransferEngine()
|
||||
self.local_rank = local_rank
|
||||
|
||||
try:
|
||||
self.config = MooncakeTransferEngineConfig.load_from_env()
|
||||
logger.info("Mooncake Configuration loaded successfully.")
|
||||
except ValueError as e:
|
||||
logger.error(e)
|
||||
raise
|
||||
except Exception as exc:
|
||||
logger.error("An error occurred while loading the configuration: %s", exc)
|
||||
raise
|
||||
prefill_host, base_prefill_port = split_host_port(self.config.prefill_url)
|
||||
decode_host, base_decode_port = split_host_port(self.config.decode_url)
|
||||
|
||||
# Avoid ports conflict when running prefill and decode on the same node
|
||||
if prefill_host == decode_host and base_prefill_port == base_decode_port:
|
||||
base_decode_port = base_decode_port + 100
|
||||
|
||||
prefill_port = base_prefill_port + self.local_rank
|
||||
decode_port = base_decode_port + self.local_rank
|
||||
self.prefill_url = join_host_port(prefill_host, prefill_port)
|
||||
self.decode_url = join_host_port(decode_host, decode_port)
|
||||
|
||||
self.initialize(
|
||||
self.prefill_url if kv_rank == 0 else self.decode_url,
|
||||
self.config.metadata_server,
|
||||
self.config.protocol,
|
||||
self.config.device_name,
|
||||
self.config.metadata_backend,
|
||||
)
|
||||
|
||||
self.remote_url = self.decode_url if kv_rank == 0 else self.prefill_url
|
||||
|
||||
# Initialize ZeroMQ context and sockets
|
||||
self.context = zmq.Context() # type: ignore[attr-defined]
|
||||
self.sender_socket = self.context.socket(zmq.constants.PUSH)
|
||||
self.receiver_socket = self.context.socket(zmq.constants.PULL)
|
||||
self.sender_ack = self.context.socket(zmq.constants.PULL)
|
||||
self.receiver_ack = self.context.socket(zmq.constants.PUSH)
|
||||
|
||||
self.buffer_cleaner = ThreadPoolExecutor(max_workers=1)
|
||||
self._setup_metadata_sockets(
|
||||
kv_rank, prefill_host, base_prefill_port, decode_host, base_decode_port
|
||||
)
|
||||
|
||||
def _setup_metadata_sockets(
|
||||
self, kv_rank: int, p_host: str, p_port: int, d_host: str, d_port: int
|
||||
) -> None:
|
||||
"""Set up ZeroMQ sockets for sending and receiving data."""
|
||||
# Offsets < 8 are left for initialization in case tp and pp are enabled
|
||||
p_rank_offset = p_port + 8 + self.local_rank * 2
|
||||
d_rank_offset = d_port + 8 + self.local_rank * 2
|
||||
if kv_rank == 0:
|
||||
self.sender_socket.bind(make_zmq_path("tcp", p_host, p_rank_offset + 1))
|
||||
self.receiver_socket.connect(
|
||||
make_zmq_path("tcp", d_host, d_rank_offset + 1)
|
||||
)
|
||||
self.sender_ack.connect(make_zmq_path("tcp", d_host, d_rank_offset + 2))
|
||||
self.receiver_ack.bind(make_zmq_path("tcp", p_host, p_rank_offset + 2))
|
||||
else:
|
||||
self.receiver_socket.connect(
|
||||
make_zmq_path("tcp", p_host, p_rank_offset + 1)
|
||||
)
|
||||
self.sender_socket.bind(make_zmq_path("tcp", d_host, d_rank_offset + 1))
|
||||
self.receiver_ack.bind(make_zmq_path("tcp", d_host, d_rank_offset + 2))
|
||||
self.sender_ack.connect(make_zmq_path("tcp", p_host, p_rank_offset + 2))
|
||||
|
||||
def initialize(
|
||||
self,
|
||||
local_hostname: str,
|
||||
metadata_server: str,
|
||||
protocol: str,
|
||||
device_name: str,
|
||||
metadata_backend: str | None,
|
||||
) -> None:
|
||||
"""Initialize the mooncake instance."""
|
||||
if metadata_backend is None:
|
||||
self.engine.initialize(
|
||||
local_hostname, metadata_server, protocol, device_name
|
||||
)
|
||||
else:
|
||||
supported_backend = ["etcd", "redis"]
|
||||
metadata_backend = metadata_backend.lower()
|
||||
if metadata_backend not in supported_backend:
|
||||
raise ValueError(
|
||||
"Mooncake Configuration error. `metadata_backend`"
|
||||
f" should be one of {supported_backend}."
|
||||
)
|
||||
|
||||
self.engine.initialize_ext(
|
||||
local_hostname, metadata_server, protocol, device_name, metadata_backend
|
||||
)
|
||||
|
||||
def allocate_managed_buffer(self, length: int) -> int:
|
||||
"""Allocate a managed buffer of the specified length."""
|
||||
ret = self.engine.allocate_managed_buffer(length)
|
||||
if ret <= 0:
|
||||
logger.error("Allocation Return Error")
|
||||
raise Exception("Allocation Return Error")
|
||||
return ret
|
||||
|
||||
def free_managed_buffer(self, buffer: int, length: int) -> int:
|
||||
"""Free a previously allocated managed buffer."""
|
||||
return self.engine.free_managed_buffer(buffer, length)
|
||||
|
||||
def transfer_sync(self, buffer: int, peer_buffer_address: int, length: int) -> int:
|
||||
"""Synchronously transfer data to the specified address."""
|
||||
ret = self.engine.transfer_sync_read(
|
||||
self.remote_url, buffer, peer_buffer_address, length
|
||||
)
|
||||
if ret < 0:
|
||||
logger.error("Transfer Return Error")
|
||||
raise Exception("Transfer Return Error")
|
||||
return ret
|
||||
|
||||
def write_bytes_to_buffer(self, buffer: int, user_data: bytes, length: int) -> int:
|
||||
"""Write bytes to the allocated buffer."""
|
||||
return self.engine.write_bytes_to_buffer(buffer, user_data, length)
|
||||
|
||||
def read_bytes_from_buffer(self, buffer: int, length: int) -> bytes:
|
||||
"""Read bytes from the allocated buffer."""
|
||||
return self.engine.read_bytes_from_buffer(buffer, length)
|
||||
|
||||
def wait_for_ack(self, src_ptr: int, length: int) -> None:
|
||||
"""Asynchronously wait for ACK from the receiver."""
|
||||
ack = self.sender_ack.recv()
|
||||
if ack != b"ACK":
|
||||
logger.error("Failed to receive ACK from the receiver")
|
||||
|
||||
self.free_managed_buffer(src_ptr, length)
|
||||
|
||||
def send_bytes(self, user_data: bytes) -> None:
|
||||
"""Send bytes to the remote process."""
|
||||
length = len(user_data)
|
||||
src_ptr = self.allocate_managed_buffer(length)
|
||||
self.write_bytes_to_buffer(src_ptr, user_data, length)
|
||||
self.sender_socket.send_multipart(
|
||||
[struct.pack("!Q", src_ptr), struct.pack("!Q", length)]
|
||||
)
|
||||
self.buffer_cleaner.submit(self.wait_for_ack, src_ptr, length)
|
||||
|
||||
def recv_bytes(self) -> bytes:
|
||||
"""Receive bytes from the remote process."""
|
||||
data = self.receiver_socket.recv_multipart()
|
||||
src_ptr = struct.unpack("!Q", data[0])[0]
|
||||
length = struct.unpack("!Q", data[1])[0]
|
||||
dst_ptr = self.allocate_managed_buffer(length)
|
||||
self.transfer_sync(dst_ptr, src_ptr, length)
|
||||
ret = self.read_bytes_from_buffer(dst_ptr, length)
|
||||
|
||||
# Buffer cleanup
|
||||
self.receiver_ack.send(b"ACK")
|
||||
self.free_managed_buffer(dst_ptr, length)
|
||||
|
||||
return ret
|
||||
|
||||
|
||||
class MooncakePipe(KVPipeBase):
    """MooncakeTransferEngine based Pipe implementation."""

    def __init__(
        self, local_rank: int, config: KVTransferConfig, device: str | None = None
    ):
        """Initialize the mooncake pipe and set related parameters.

        Args:
            local_rank: Local rank; also used as the CUDA device index.
            config: KV transfer configuration; `kv_rank` must be set.
            device: "cuda" for `cuda:<local_rank>`, any other string for
                CPU. When None, `config.kv_buffer_device` is used.
        """
        self.config = config
        self.local_rank = local_rank
        self.kv_rank = self.config.kv_rank
        assert self.kv_rank is not None
        if device is None:
            self.device = self._select_device(self.config.kv_buffer_device)
        else:
            self.device = self._select_device(device)

        self.transfer_engine = MooncakeTransferEngine(self.kv_rank, self.local_rank)
        # Created lazily on the first send/recv.
        self.transport_thread: ThreadPoolExecutor | None = None
        # Stand-in that is transmitted whenever the caller sends None.
        self.none_tensor = torch.tensor([NONE_INT], device=self.device)

    def _select_device(self, device: str) -> torch.device:
        """Select available device (CUDA or CPU)."""
        logger.info("Selecting device: %s", device)
        if device == "cuda":
            return torch.device(f"cuda:{self.local_rank}")
        else:
            return torch.device("cpu")

    def tensor_hash(self, tensor: torch.Tensor) -> int:
        """Return the hash of the tensor's data pointer (not of its contents)."""
        return hash(tensor.data_ptr())

    def _send_impl(self, tensor: torch.Tensor) -> None:
        """Implement the tensor sending logic using safetensors."""
        self.transfer_engine.send_bytes(safetensors_save({"tensor": tensor}))

    def _recv_impl(self) -> torch.Tensor:
        """Implement the tensor receiving logic using safetensors."""
        data = self.transfer_engine.recv_bytes()
        return safetensors_load(data)["tensor"].to(self.device)

    @staticmethod
    def _log_send_failure(future) -> None:
        """Log the exception, if any, raised by a background send."""
        if future.cancelled():
            return
        exc = future.exception()
        if exc is not None:
            logger.error("Exception when trying to send tensor, msg: %s", exc)

    def send_tensor(self, tensor: torch.Tensor | None) -> None:
        """Send tensor to the target process.

        The send runs on the transport worker and this call returns right
        after queuing it. `None` is transmitted as `self.none_tensor`.

        Args:
            tensor: Tensor with at least one dimension, or None.
        """
        if self.transport_thread is None:
            self.transport_thread = ThreadPoolExecutor(max_workers=1)
        tensor = tensor if tensor is not None else self.none_tensor
        assert len(tensor.shape) > 0
        future = self.transport_thread.submit(self._send_impl, tensor)
        # The future is otherwise never inspected, so a failed send would
        # disappear silently; surface it in the log.
        future.add_done_callback(self._log_send_failure)

    def recv_tensor(self) -> torch.Tensor | None:
        """Receive tensor from other processes. Blocking call.

        Returns:
            The received tensor, or None when the payload is the
            one-element `NONE_INT` sentinel. NOTE(review): a genuine
            one-element tensor holding `NONE_INT` is indistinguishable from
            None on this path.
        """
        if self.transport_thread is None:
            self.transport_thread = ThreadPoolExecutor(max_workers=1)
        tensor = self.transport_thread.submit(self._recv_impl).result()
        if tensor.numel() == 1 and tensor.item() == NONE_INT:
            return None
        else:
            return tensor

    def close(self) -> None:
        """Cleanup logic when closing the pipe."""
        # Stop accepting new work and let idle workers exit. Do not wait:
        # an in-flight transfer may be blocked on a peer that is gone.
        if self.transport_thread is not None:
            self.transport_thread.shutdown(wait=False)
        self.transfer_engine.buffer_cleaner.shutdown(wait=False)

        self.transfer_engine.sender_socket.close()
        self.transfer_engine.receiver_socket.close()
        self.transfer_engine.sender_ack.close()
        self.transfer_engine.receiver_ack.close()
        self.transfer_engine.context.term()  # Terminate the ZMQ context
        logger.info("Closed the transfer engine and cleaned up resources.")
|
||||
285
distributed/kv_transfer/kv_pipe/pynccl_pipe.py
Normal file
285
distributed/kv_transfer/kv_pipe/pynccl_pipe.py
Normal file
@@ -0,0 +1,285 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""
|
||||
This module implements a PyNccl pipe for sending and receiving
|
||||
Optional[torch.Tensor] between distributed ranks with advanced
|
||||
communication features.
|
||||
|
||||
Key Features:
|
||||
- Supports sending and receiving tensors with metadata
|
||||
- Handles both CUDA and CPU device communications
|
||||
- Implements a non-blocking tensor transfer mechanism
|
||||
- Manages buffer size and provides backpressure control
|
||||
- Supports distributed process groups with configurable parameters
|
||||
"""
|
||||
|
||||
import threading
|
||||
import time
|
||||
from collections.abc import Callable
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.config.kv_transfer import KVTransferConfig
|
||||
from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator
|
||||
from vllm.distributed.kv_transfer.kv_pipe.base import KVPipeBase
|
||||
from vllm.distributed.utils import StatelessProcessGroup
|
||||
from vllm.logger import init_logger
|
||||
|
||||
# Module-level logger, named after this module.
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class BrokenPipeException(Exception):
    """Exception that keeps its text in a `message` attribute.

    NOTE(review): nothing in this module raises it; presumably callers use
    it to signal a broken pipe -- confirm against its users.
    """

    def __init__(self, message):
        super().__init__(message)
        self.message = message
|
||||
|
||||
|
||||
Metadata = dict[str, torch.Tensor | None]
|
||||
|
||||
|
||||
class PyNcclPipe(KVPipeBase):
    """Pipe that sends tensors (or None) to the next rank in a ring.

    Each item is preceded by a metadata dictionary sent through the
    stateless process group. The payload then goes through PyNCCL on CUDA
    devices, or through the group's object send/recv on CPU.
    """

    # NOTE(review): these three constants are not referenced in this class.
    METADATA_LENGTH = 16
    MAX_TENSOR_DIMENSIONS = 14
    METADATA_DTYPE = torch.int64

    def __init__(
        self,
        local_rank: int,
        config: KVTransferConfig,
        device: str | None = None,
        port_offset: int = 0,
    ):
        """Create the process group and pick the send/recv implementation.

        Args:
            local_rank: Local rank; also used as the CUDA device index.
            config: KV transfer configuration; `kv_rank` must be set.
            device: "cuda" for `cuda:<local_rank>`, any other string for
                CPU. When None, `config.kv_buffer_device` is used.
            port_offset: Added to `config.kv_port` for the group's store.
        """
        self.config = config
        self.local_rank = local_rank
        self.kv_rank = self.config.kv_rank
        assert self.kv_rank is not None
        self.kv_parallel_size = self.config.kv_parallel_size
        if device is None:
            self.device = self._select_device(self.config.kv_buffer_device)
        else:
            self.device = self._select_device(device)

        # build distributed connection and send/recv implementation
        store_timeout = self.config.get_from_extra_config("store_timeout", 300)
        self.group = StatelessProcessGroup.create(
            host=self.config.kv_ip,
            port=self.config.kv_port + port_offset,
            rank=self.kv_rank,
            world_size=self.kv_parallel_size,
            store_timeout=store_timeout,
        )
        # add a barrier to make sure the connection is initiated properly
        self.group.barrier()
        impl = self._get_device_send_recv_impl(self.group)
        self.device_send_func, self.device_recv_func = impl
        # set target rank
        self.target_rank_for_send = (self.kv_rank + 1) % self.kv_parallel_size
        self.target_rank_for_recv = (self.kv_rank - 1) % self.kv_parallel_size

        # transportation-related variables
        self.transport_thread: ThreadPoolExecutor | None = None
        # Bytes queued for sending but not yet sent; guarded by the lock.
        self.buffer_size = 0
        self.buffer_size_lock = threading.Lock()
        self.buffer_size_thresh = self.config.kv_buffer_size

    def _get_device_send_recv_impl(
        self, group: StatelessProcessGroup
    ) -> tuple[
        Callable[[torch.Tensor, int], None], Callable[[torch.Tensor, int], None]
    ]:
        """Return the `(send, recv)` callables matching `self.device`."""
        send: Callable[[torch.Tensor, int], None]
        recv: Callable[[torch.Tensor, int], None]
        if self.device.type == "cuda":
            # use PyNCCL for send / recv
            comm = PyNcclCommunicator(group, device=self.local_rank)
            comm.disabled = False
            send, recv = comm.send, comm.recv  # type: ignore
        else:
            # This send / recv implementation here is NOT intended to transfer
            # KV caches (and should NOT be repurposed to transfer KV caches).
            # Currently it is only used to transmit control-plane messages
            # for PyNcclBuffer.
            send = group.send_obj

            def my_recv(x, src):
                # Copy the received object into the caller's buffer in place.
                x[...] = group.recv_obj(src)

            recv = my_recv

        return send, recv

    def _select_device(self, device: str) -> torch.device:
        """Map "cuda" to `cuda:<local_rank>` and anything else to CPU."""
        logger.info("Selecting device: %s", device)
        if device == "cuda":
            return torch.device(f"cuda:{self.local_rank}")
        else:
            return torch.device("cpu")

    def _make_metadata(self, tensor: torch.Tensor | None) -> Metadata:
        """
        Create the metadata as a dictionary based on the input tensor.

        Args:
            tensor: The input tensor or None if no tensor is provided.

        Returns:
            metadata: A dictionary with the following keys:
                - "dtype": The data type of the tensor or None.
                - "shape": The shape of the tensor or None.
        """
        if tensor is None:
            return {"dtype": None, "shape": None}
        else:
            return {"dtype": tensor.dtype, "shape": tensor.shape}

    def _prepare_recv_buffer(self, metadata: Metadata) -> torch.Tensor:
        """
        Create a buffer to receive the tensor based on the provided metadata.

        Args:
            metadata: A dictionary with keys "dtype" and "shape",
                describing the tensor's data type and shape.

        Returns:
            buffer: A tensor of the specified type and shape,
                allocated on `self.device`.
        """
        return torch.empty(
            metadata["shape"], dtype=metadata["dtype"], device=self.device
        )

    def _send_metadata(self, metadata: Metadata) -> None:
        """
        Send the metadata dictionary to the target rank.

        Args:
            metadata: A dictionary with keys "dtype" and "shape".
        """
        self.group.send_obj(metadata, self.target_rank_for_send)

    def _recv_metadata(self) -> Metadata:
        """
        Receive the metadata dictionary from the target rank.

        Returns:
            metadata: A dictionary with keys "dtype" and "shape"
                describing the tensor.
        """
        return self.group.recv_obj(self.target_rank_for_recv)

    def _send_impl(self, tensor: torch.Tensor | None) -> None:
        """
        The actual implementation of sending the tensor and its metadata to the
        target rank.

        Args:
            tensor: The input tensor to be sent, or `None` if no tensor is
                being sent.
        """
        metadata = self._make_metadata(tensor)
        self._send_metadata(metadata)
        if tensor is not None:
            self.device_send_func(tensor.to(self.device), self.target_rank_for_send)

    def _recv_impl(self) -> torch.Tensor | None:
        """
        The actual implementation of receiving a tensor and its metadata from
        the target rank.

        Returns:
            buffer: The received tensor, or `None` if no tensor is received.
        """
        metadata = self._recv_metadata()
        if metadata["dtype"] is None:
            return None
        buffer = self._prepare_recv_buffer(metadata)
        self.device_recv_func(buffer, self.target_rank_for_recv)

        return buffer

    def _rank_for_logging(self) -> int:
        """Return a rank for log lines without ever raising.

        Uses the global rank when a default `torch.distributed` process
        group exists, and falls back to this pipe's `kv_rank` otherwise.
        """
        dist = torch.distributed
        if dist.is_available() and dist.is_initialized():
            return dist.get_rank()
        return self.kv_rank

    def send_tensor_wrapper(
        self, tensor: torch.Tensor | None, tensor_size: int
    ) -> None:
        """
        Wrapper for _send_impl to handle exceptions and update buffer size.

        Runs on the transport worker. Exceptions are logged, not re-raised.

        Args:
            tensor: The tensor to send, or `None`.
            tensor_size: Bytes that `send_tensor` added to `buffer_size`
                for this item.
        """
        try:
            self._send_impl(tensor)
        except Exception as e:
            # Describe the tensor by shape and dtype rather than printing
            # its contents.
            if tensor is None:
                description = "None"
            else:
                description = (
                    f"tensor(shape={tuple(tensor.shape)}, dtype={tensor.dtype})"
                )
            logger.exception(
                "[rank%d]: Exception when trying to send %s, msg: %s",
                self._rank_for_logging(),
                description,
                e,
            )
        finally:
            # Release the accounted bytes on failure too. Otherwise a single
            # failed send inflates `buffer_size` for good and `block_if_full`
            # may wait forever.
            with self.buffer_size_lock:
                self.buffer_size -= tensor_size

    def block_if_full(self):
        """
        Block the current thread if the buffer size is larger than the
        threshold.
        """
        while self.buffer_size > self.buffer_size_thresh:
            logger.debug("KV cache transfer pipe is full. Waiting...")
            time.sleep(0.05)

    def send_tensor(self, tensor: torch.Tensor | None) -> None:
        """
        Sends a tensor and its metadata to the destination rank in a
        non-blocking way.

        The call only blocks while the queued bytes exceed
        `buffer_size_thresh`.

        Args:
            tensor: The tensor to send, or `None` if no tensor is being sent.
        """
        if self.transport_thread is None:
            self.transport_thread = ThreadPoolExecutor(max_workers=1)

        if tensor is not None:
            tensor_size = tensor.element_size() * tensor.numel()
        else:
            tensor_size = 0

        self.block_if_full()

        with self.buffer_size_lock:
            self.buffer_size += tensor_size

        self.transport_thread.submit(self.send_tensor_wrapper, tensor, tensor_size)

    def recv_tensor(self) -> torch.Tensor | None:
        """
        Receives a tensor and its metadata from the source rank. Blocking call.

        Returns:
            The received tensor, or `None` if no tensor is received.

        Raises:
            Exception: Whatever the receiving worker raised, after logging.
        """
        if self.transport_thread is None:
            self.transport_thread = ThreadPoolExecutor(max_workers=1)

        future = self.transport_thread.submit(self._recv_impl)

        try:
            tensor = future.result()
        except Exception:
            logger.exception("Encountering exception in KV receiving thread")
            logger.error("My device: %s", self.device)
            # Bare raise keeps the original traceback intact.
            raise

        return tensor

    def close(self) -> None:
        """
        Close the pipe and release associated resources.
        """
        # `hasattr` guards against `__init__` failing before the attribute
        # was assigned.
        if hasattr(self, "transport_thread") and self.transport_thread is not None:
            self.transport_thread.shutdown()
|
||||
Reference in New Issue
Block a user