# SPDX-License-Identifier: Apache-2.0
|
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
|
"""
|
|
This module implements a PyNccl pipe for sending and receiving
|
|
Optional[torch.Tensor] between distributed ranks with advanced
|
|
communication features.
|
|
|
|
Key Features:
|
|
- Supports sending and receiving tensors with metadata
|
|
- Handles both CUDA and CPU device communications
|
|
- Implements a non-blocking tensor transfer mechanism
|
|
- Manages buffer size and provides backpressure control
|
|
- Supports distributed process groups with configurable parameters
|
|
"""
|
|
|
|
import threading
|
|
import time
|
|
from concurrent.futures import ThreadPoolExecutor
|
|
from typing import Callable, Optional
|
|
|
|
import torch
|
|
|
|
from vllm.config import KVTransferConfig
|
|
from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator
|
|
from vllm.distributed.kv_transfer.kv_pipe.base import KVPipeBase
|
|
from vllm.distributed.utils import StatelessProcessGroup
|
|
from vllm.logger import init_logger
|
|
|
|
logger = init_logger(__name__)
|
|
|
|
|
|
class BrokenPipeException(Exception):
    """Raised when a transfer over the pipe fails irrecoverably.

    Keeps the human-readable reason on ``self.message`` in addition to the
    standard ``Exception`` args.
    """

    def __init__(self, message):
        super().__init__(message)
        self.message = message
|
|
|
|
|
|
# Control-plane metadata describing one transfer: keys "dtype" and "shape",
# each None when no tensor is being sent. NOTE(review): the value type is
# intentionally loose — in practice the values are a torch.dtype and a
# torch.Size (see PyNcclPipe._make_metadata), not tensors.
Metadata = dict[str, Optional[torch.Tensor]]
|
|
|
|
|
|
class PyNcclPipe(KVPipeBase):
    """
    A point-to-point pipe for sending/receiving ``Optional[torch.Tensor]``
    between distributed ranks.

    Tensor payloads go over PyNCCL when the device is CUDA; a metadata
    dict (dtype/shape) is always exchanged first over the stateless
    process group's object channel. Sends are queued onto a single-worker
    thread pool and are bounded by a byte-size threshold
    (``kv_buffer_size``) for backpressure; receives are blocking.
    """

    METADATA_LENGTH = 16
    MAX_TENSOR_DIMENSIONS = 14
    METADATA_DTYPE = torch.int64

    def __init__(self,
                 local_rank: int,
                 config: KVTransferConfig,
                 device: Optional[str] = None,
                 port_offset: int = 0):
        """
        Initialize the pipe: select a device, build the stateless process
        group, pick the device-level send/recv implementation, and set up
        the send-side buffering state.

        Args:
            local_rank: Local GPU index used for CUDA device selection.
            config: KV transfer configuration (ranks, ip/port, buffer
                size, extra options).
            device: Overrides ``config.kv_buffer_device`` when given.
            port_offset: Added to ``config.kv_port`` so multiple pipes can
                coexist on one host.
        """
        self.config = config
        self.local_rank = local_rank
        self.kv_rank = self.config.kv_rank
        self.kv_parallel_size = self.config.kv_parallel_size
        if device is None:
            self.device = self._select_device(self.config.kv_buffer_device)
        else:
            self.device = self._select_device(device)

        # build distributed connection and send/recv implementation
        store_timeout = self.config.get_from_extra_config("store_timeout", 300)
        self.group = StatelessProcessGroup.create(
            host=self.config.kv_ip,
            port=self.config.kv_port + port_offset,
            rank=self.kv_rank,
            world_size=self.kv_parallel_size,
            store_timeout=store_timeout,
        )
        # add a barrier to make sure the connection is initiated properly
        self.group.barrier()
        impl = self._get_device_send_recv_impl(self.group)
        self.device_send_func, self.device_recv_func = impl
        # set target rank: ring topology — send to the next rank, receive
        # from the previous one
        self.target_rank_for_send = (self.kv_rank + 1) % self.kv_parallel_size
        self.target_rank_for_recv = (self.kv_rank - 1) % self.kv_parallel_size

        # transportation-related variables
        self.transport_thread: Optional[ThreadPoolExecutor] = None
        self.buffer_size = 0
        self.buffer_size_lock = threading.Lock()
        self.buffer_size_thresh = self.config.kv_buffer_size

    def _get_device_send_recv_impl(
        self, group: StatelessProcessGroup
    ) -> tuple[Callable[[torch.Tensor, int], None], Callable[
        [torch.Tensor, int], None]]:
        """
        Return ``(send, recv)`` callables for the selected device.

        CUDA uses a dedicated PyNcclCommunicator; CPU falls back to the
        process group's object send/recv (control-plane only).
        """
        send: Callable[[torch.Tensor, int], None]
        recv: Callable[[torch.Tensor, int], None]
        if self.device.type == "cuda":
            # use PyNCCL for send / recv
            comm = PyNcclCommunicator(group, device=self.local_rank)
            comm.disabled = False
            send, recv = comm.send, comm.recv  # type: ignore
        else:
            # This send / recv implementation here is NOT intended to transfer
            # KV caches (and should NOT be repurposed to transfer KV caches).
            # Currently it is only used to transmit control-plane messages
            # for PyNcclBuffer.
            send = group.send_obj

            def my_recv(x, src):
                # Fill the caller-provided buffer in place so the CPU path
                # matches the NCCL recv-into-buffer signature.
                x[...] = group.recv_obj(src)

            recv = my_recv

        return send, recv

    def _select_device(self, device: str):
        """
        Map a device string to a ``torch.device``.

        ``"cuda"`` selects ``cuda:{local_rank}``; anything else falls back
        to CPU.
        """
        logger.info("Selecting device: %s", device)
        if device == "cuda":
            return torch.device(f"cuda:{self.local_rank}")
        else:
            return torch.device("cpu")

    def _make_metadata(self, tensor: Optional[torch.Tensor]) -> Metadata:
        """
        Create the metadata as a dictionary based on the input tensor.

        Args:
            tensor: The input tensor or None if no tensor is provided.

        Returns:
            metadata: A dictionary with the following keys:
                - "dtype": The data type of the tensor or None.
                - "shape": The shape of the tensor or None.
        """
        if tensor is None:
            return {"dtype": None, "shape": None}
        else:
            return {"dtype": tensor.dtype, "shape": tensor.shape}

    def _prepare_recv_buffer(self, metadata: Metadata) -> torch.Tensor:
        """
        Create a buffer to receive the tensor based on the provided metadata.

        Args:
            metadata: A dictionary with keys "dtype" and "shape",
                describing the tensor's data type and shape.

        Returns:
            buffer: A tensor of the specified type and shape,
                allocated on `self.device`.
        """
        return torch.empty(metadata["shape"],
                           dtype=metadata["dtype"],
                           device=self.device)

    def _send_metadata(self, metadata: Metadata):
        """
        Send the metadata dictionary to the target rank.

        Args:
            metadata: A dictionary with keys "dtype" and "shape".
        """
        self.group.send_obj(metadata, self.target_rank_for_send)

    def _recv_metadata(self) -> Metadata:
        """
        Receive the metadata dictionary from the target rank.

        Returns:
            metadata: A dictionary with keys "dtype" and "shape"
                describing the tensor.
        """
        return self.group.recv_obj(self.target_rank_for_recv)

    def _send_impl(self, tensor: Optional[torch.Tensor]) -> None:
        """
        The actual implementation of sending the tensor and its metadata to
        the target rank.

        Metadata always goes first; the tensor payload follows only when it
        exists, moved to ``self.device`` before the device-level send.

        Args:
            tensor: The input tensor to be sent, or `None` if no tensor is
                being sent.
        """
        metadata = self._make_metadata(tensor)
        self._send_metadata(metadata)
        if tensor is not None:
            self.device_send_func(tensor.to(self.device),
                                  self.target_rank_for_send)

    def _recv_impl(self) -> Optional[torch.Tensor]:
        """
        The actual implementation of receiving a tensor and its metadata from
        the target rank.

        Returns:
            buffer: The received tensor, or `None` if no tensor is received
                (signalled by a metadata dict with ``dtype`` set to None).
        """
        metadata = self._recv_metadata()
        if metadata["dtype"] is None:
            return None
        buffer = self._prepare_recv_buffer(metadata)
        self.device_recv_func(buffer, self.target_rank_for_recv)

        return buffer

    def send_tensor_wrapper(self, tensor: Optional[torch.Tensor],
                            tensor_size: int) -> None:
        """
        Wrapper for _send_impl to handle exceptions and update buffer size.

        The buffer-size release lives in ``finally``: previously a failed
        send never returned its reservation, so ``buffer_size`` stayed
        inflated and ``block_if_full`` could spin forever on later sends.

        Args:
            tensor: The tensor being sent (or None).
            tensor_size: The byte size reserved for it in ``send_tensor``.
        """
        try:
            self._send_impl(tensor)
        except Exception as e:
            logger.error("[rank%d]: Exception when trying to send %s, msg: %s",
                         torch.distributed.get_rank(), str(tensor), str(e))
            import traceback
            traceback.print_exc()
        finally:
            # Always release the space reserved by send_tensor, even when
            # the send failed — otherwise backpressure deadlocks.
            with self.buffer_size_lock:
                self.buffer_size -= tensor_size

    def block_if_full(self):
        """
        Block the current thread if the buffer size is larger than the
        threshold.
        """
        while self.buffer_size > self.buffer_size_thresh:
            logger.debug("KV cache transfer pipe is full. Waiting...")
            time.sleep(0.05)

    def send_tensor(self, tensor: Optional[torch.Tensor]) -> None:
        """
        Sends a tensor and its metadata to the destination rank in a
        non-blocking way.

        Reserves the tensor's byte size against the buffer threshold
        (blocking while the pipe is full) and queues the actual transfer on
        the single-worker transport thread.

        Args:
            tensor: The tensor to send, or `None` if no tensor is being sent.
        """
        if self.transport_thread is None:
            # Lazily created so the pipe can be constructed without
            # spawning a thread that may never be used.
            self.transport_thread = ThreadPoolExecutor(max_workers=1)

        if tensor is not None:
            tensor_size = tensor.element_size() * tensor.numel()
        else:
            tensor_size = 0

        self.block_if_full()

        with self.buffer_size_lock:
            self.buffer_size += tensor_size

        self.transport_thread.submit(self.send_tensor_wrapper, tensor,
                                     tensor_size)

    def recv_tensor(self) -> Optional[torch.Tensor]:
        """
        Receives a tensor and its metadata from the source rank.
        Blocking call.

        Returns:
            The received tensor, or `None` if no tensor is received.

        Raises:
            Exception: Re-raises whatever the transport thread hit while
                receiving, with its original traceback preserved.
        """
        if self.transport_thread is None:
            self.transport_thread = ThreadPoolExecutor(max_workers=1)

        future = self.transport_thread.submit(self._recv_impl)

        try:
            tensor = future.result()
        except Exception as e:
            logger.error("Encountering exception in KV receiving thread")
            logger.error("%s", e)
            logger.error("My device: %s", self.device)
            import traceback
            traceback.print_exc()
            # Bare raise keeps the original traceback intact (``raise e``
            # would restart it here).
            raise

        return tensor

    def close(self):
        """
        Close the pipe and release associated resources.
        """
        # hasattr guards against __init__ failing before the attribute
        # was assigned.
        if hasattr(self,
                   "transport_thread") and self.transport_thread is not None:
            self.transport_thread.shutdown()
|