# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
This module implements a PyNccl pipe for sending and receiving
Optional[torch.Tensor] between distributed ranks with advanced
communication features.

Key Features:
- Supports sending and receiving tensors with metadata
- Handles both CUDA and CPU device communications
- Implements a non-blocking tensor transfer mechanism
- Manages buffer size and provides backpressure control
- Supports distributed process groups with configurable parameters
"""
import threading
import time
from concurrent.futures import ThreadPoolExecutor
from typing import Callable, Optional

import torch

from vllm.config.kv_transfer import KVTransferConfig
from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator
from vllm.distributed.kv_transfer.kv_pipe.base import KVPipeBase
from vllm.distributed.utils import StatelessProcessGroup
from vllm.logger import init_logger

logger = init_logger(__name__)


class BrokenPipeException(Exception):
    """Raised when the KV transfer pipe is no longer usable."""

    def __init__(self, message):
        self.message = message
        super().__init__(self.message)


# Metadata describing one optional tensor: {"dtype": ..., "shape": ...},
# with both values None when no tensor is being transferred.
Metadata = dict[str, Optional[torch.Tensor]]


class PyNcclPipe(KVPipeBase):
    """
    A KV-cache transfer pipe built on top of PyNCCL (CUDA) or a
    StatelessProcessGroup object channel (CPU), with a single background
    worker thread for non-blocking sends and simple byte-count
    backpressure.
    """

    METADATA_LENGTH = 16
    MAX_TENSOR_DIMENSIONS = 14
    METADATA_DTYPE = torch.int64

    def __init__(self,
                 local_rank: int,
                 config: KVTransferConfig,
                 device: Optional[str] = None,
                 port_offset: int = 0):
        """
        Initialize the pipe: select the device, build the stateless
        process group, pick the device send/recv implementation, and set
        up the send/recv target ranks and backpressure bookkeeping.

        Args:
            local_rank: Local GPU index used when the device is CUDA.
            config: KV transfer configuration (ranks, ports, buffer size).
            device: Optional device override; falls back to
                ``config.kv_buffer_device`` when None.
            port_offset: Offset added to ``config.kv_port`` so multiple
                pipes can coexist on one host.
        """
        self.config = config
        self.local_rank = local_rank
        self.kv_rank = self.config.kv_rank
        self.kv_parallel_size = self.config.kv_parallel_size
        if device is None:
            self.device = self._select_device(self.config.kv_buffer_device)
        else:
            self.device = self._select_device(device)

        # build distributed connection and send/recv implementation
        store_timeout = self.config.get_from_extra_config("store_timeout", 300)
        self.group = StatelessProcessGroup.create(
            host=self.config.kv_ip,
            port=self.config.kv_port + port_offset,
            rank=self.kv_rank,
            world_size=self.kv_parallel_size,
            store_timeout=store_timeout,
        )
        # add a barrier to make sure the connection is initiated properly
        self.group.barrier()
        impl = self._get_device_send_recv_impl(self.group)
        self.device_send_func, self.device_recv_func = impl

        # set target rank: ring topology — send to the next rank, receive
        # from the previous one
        self.target_rank_for_send = (self.kv_rank + 1) % self.kv_parallel_size
        self.target_rank_for_recv = (self.kv_rank - 1) % self.kv_parallel_size

        # transportation-related variables; the single-worker executor is
        # created lazily on first send/recv
        self.transport_thread: Optional[ThreadPoolExecutor] = None
        self.buffer_size = 0
        self.buffer_size_lock = threading.Lock()
        self.buffer_size_thresh = self.config.kv_buffer_size

    def _get_device_send_recv_impl(
        self, group: StatelessProcessGroup
    ) -> tuple[Callable[[torch.Tensor, int], None], Callable[
        [torch.Tensor, int], None]]:
        """
        Return the (send, recv) callables appropriate for ``self.device``.

        CUDA uses PyNCCL; CPU uses the process group's object channel.
        """
        send: Callable[[torch.Tensor, int], None]
        recv: Callable[[torch.Tensor, int], None]
        if self.device.type == "cuda":
            # use PyNCCL for send / recv
            comm = PyNcclCommunicator(group, device=self.local_rank)
            comm.disabled = False
            send, recv = comm.send, comm.recv  # type: ignore
        else:
            # This send / recv implementation here is NOT intended to transfer
            # KV caches (and should NOT be repurposed to transfer KV caches).
            # Currently it is only used to transmit control-plane messages
            # for PyNcclBuffer.
            send = group.send_obj

            def my_recv(x, src):
                # recv_obj returns a fresh object; copy it into the
                # caller-provided buffer to match the NCCL recv signature.
                x[...] = group.recv_obj(src)

            recv = my_recv
        return send, recv

    def _select_device(self, device: str):
        """Map a device string to a concrete ``torch.device``."""
        logger.info("Selecting device: %s", device)
        if device == "cuda":
            return torch.device(f"cuda:{self.local_rank}")
        else:
            return torch.device("cpu")

    def _make_metadata(self, tensor: Optional[torch.Tensor]) -> Metadata:
        """
        Create the metadata as a dictionary based on the input tensor.

        Args:
            tensor: The input tensor or None if no tensor is provided.

        Returns:
            metadata: A dictionary with the following keys:
                - "dtype": The data type of the tensor or None.
                - "shape": The shape of the tensor or None.
        """
        if tensor is None:
            return {"dtype": None, "shape": None}
        else:
            return {"dtype": tensor.dtype, "shape": tensor.shape}

    def _prepare_recv_buffer(self, metadata: Metadata) -> torch.Tensor:
        """
        Create a buffer to receive the tensor based on the provided metadata.

        Args:
            metadata: A dictionary with keys "dtype" and "shape", describing
                the tensor's data type and shape.

        Returns:
            buffer: A tensor of the specified type and shape, allocated on
                `self.device`.
        """
        return torch.empty(metadata["shape"],
                           dtype=metadata["dtype"],
                           device=self.device)

    def _send_metadata(self, metadata: Metadata):
        """
        Send the metadata dictionary to the target rank.

        Args:
            metadata: A dictionary with keys "dtype" and "shape".
        """
        self.group.send_obj(metadata, self.target_rank_for_send)

    def _recv_metadata(self) -> Metadata:
        """
        Receive the metadata dictionary from the target rank.

        Returns:
            metadata: A dictionary with keys "dtype" and "shape" describing
                the tensor.
        """
        return self.group.recv_obj(self.target_rank_for_recv)

    def _send_impl(self, tensor: Optional[torch.Tensor]) -> None:
        """
        The actual implementation of sending the tensor and its metadata to
        the target rank.

        Args:
            tensor: The input tensor to be sent, or `None` if no tensor is
                being sent.
        """
        metadata = self._make_metadata(tensor)
        self._send_metadata(metadata)
        if tensor is not None:
            # Move to the pipe's device first so the device-level send
            # always operates on a tensor it can handle.
            self.device_send_func(tensor.to(self.device),
                                  self.target_rank_for_send)

    def _recv_impl(self) -> Optional[torch.Tensor]:
        """
        The actual implementation of receiving a tensor and its metadata from
        the target rank.

        Returns:
            buffer: The received tensor, or `None` if no tensor is received.
        """
        metadata = self._recv_metadata()
        # A None dtype is the sender's signal that no tensor follows.
        if metadata["dtype"] is None:
            return None
        buffer = self._prepare_recv_buffer(metadata)
        self.device_recv_func(buffer, self.target_rank_for_recv)
        return buffer

    def send_tensor_wrapper(self, tensor: Optional[torch.Tensor],
                            tensor_size: int) -> None:
        """
        Wrapper for _send_impl to handle exceptions and update buffer size.

        The buffer accounting is released in ``finally`` so that a failed
        send cannot permanently consume buffer budget: previously the
        decrement was skipped on error, which could leave ``buffer_size``
        above the threshold forever and deadlock `block_if_full`.
        """
        try:
            self._send_impl(tensor)
        except Exception as e:
            logger.error("[rank%d]: Exception when trying to send %s, msg: %s",
                         torch.distributed.get_rank(), str(tensor), str(e))
            import traceback
            traceback.print_exc()
        finally:
            with self.buffer_size_lock:
                self.buffer_size -= tensor_size

    def block_if_full(self):
        """
        Block the current thread if the buffer size is larger than the
        threshold.
        """
        # Unlocked read is acceptable: this is a best-effort polling check
        # and the value is updated atomically under buffer_size_lock.
        while self.buffer_size > self.buffer_size_thresh:
            logger.debug("KV cache transfer pipe is full. Waiting...")
            time.sleep(0.05)

    def send_tensor(self, tensor: Optional[torch.Tensor]) -> None:
        """
        Sends a tensor and its metadata to the destination rank in a
        non-blocking way.

        Args:
            tensor: The tensor to send, or `None` if no tensor is being sent.
        """
        if self.transport_thread is None:
            self.transport_thread = ThreadPoolExecutor(max_workers=1)

        tensor_size = 0 if tensor is None else (tensor.element_size() *
                                                tensor.numel())

        # Apply backpressure before reserving buffer space for this send.
        self.block_if_full()

        with self.buffer_size_lock:
            self.buffer_size += tensor_size

        self.transport_thread.submit(self.send_tensor_wrapper, tensor,
                                     tensor_size)

    def recv_tensor(self) -> Optional[torch.Tensor]:
        """
        Receives a tensor and its metadata from the source rank. Blocking call.

        Returns:
            The received tensor, or `None` if no tensor is received.

        Raises:
            Exception: Re-raises whatever the receiving thread raised.
        """
        if self.transport_thread is None:
            self.transport_thread = ThreadPoolExecutor(max_workers=1)

        future = self.transport_thread.submit(self._recv_impl)

        try:
            tensor = future.result()
        except Exception as e:
            logger.error("Encountering exception in KV receiving thread")
            logger.error("%s", e)
            logger.error("My device: %s", self.device)
            import traceback
            traceback.print_exc()
            # Bare raise preserves the original traceback for the caller.
            raise

        return tensor

    def close(self):
        """
        Close the pipe and release associated resources.
        """
        # hasattr guards against __init__ failing before the attribute
        # is assigned (e.g. during process-group creation).
        if hasattr(self,
                   "transport_thread") and self.transport_thread is not None:
            self.transport_thread.shutdown()