Iluvatar-mrv100 SDK 4.3.0
This commit is contained in:
0
vllm/distributed/kv_transfer/kv_pipe/__init__.py
Normal file
0
vllm/distributed/kv_transfer/kv_pipe/__init__.py
Normal file
66
vllm/distributed/kv_transfer/kv_pipe/base.py
Normal file
66
vllm/distributed/kv_transfer/kv_pipe/base.py
Normal file
@@ -0,0 +1,66 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
"""
|
||||
This file defines an interface `KVPipeBase`
|
||||
that provides an abstraction for sending and receiving tensors, or None, via
|
||||
distributed communications.
|
||||
|
||||
All classes instantiated from this interface are assumed to be a FIFO pipe.
|
||||
|
||||
If your distributed communication platform already supports key-value lookup,
|
||||
you can bypass this interface and directly start from `kv_lookup_buffer`.
|
||||
"""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
class KVPipeBase(ABC):
|
||||
"""
|
||||
This class provides an interface for sending and receiving tensors, or
|
||||
None, by distributed communications.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def send_tensor(self, tensor: Optional[torch.Tensor]) -> None:
|
||||
"""Send a tensor, or None, via the pipe.
|
||||
|
||||
Need to support sending None -- important for error handling.
|
||||
|
||||
TODO: add a `key` argument so that we can use traditional
|
||||
key-value database as the distributed communication mechanism behind
|
||||
the pipe.
|
||||
|
||||
Args:
|
||||
tensor (Optional[torch.Tensor]): The tensor to be sent. Can be None.
|
||||
|
||||
Raises:
|
||||
NotImplementedError: This method must be implemented in subclasses.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def recv_tensor(self) -> Optional[torch.Tensor]:
|
||||
"""Receive a tensor (can be None) from the pipeline.
|
||||
|
||||
Returns:
|
||||
Optional[torch.Tensor]: The tensor received from the pipeline. Can
|
||||
be None.
|
||||
|
||||
Raises:
|
||||
NotImplementedError: This method must be implemented in subclasses.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def close(self) -> None:
|
||||
"""Close the pipeline and release resources.
|
||||
|
||||
This method is responsible for closing the communication pipeline
|
||||
and releasing any resources associated with it.
|
||||
|
||||
Raises:
|
||||
NotImplementedError: This method must be implemented in subclasses.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
458
vllm/distributed/kv_transfer/kv_pipe/flagcx_p2p_nccl_pipe.py
Normal file
458
vllm/distributed/kv_transfer/kv_pipe/flagcx_p2p_nccl_pipe.py
Normal file
@@ -0,0 +1,458 @@
|
||||
# Mainly adopted from https://github.com/vllm-project/vllm/blob/1ad957950ffc1552af5abda78c03d88ddb67945b/vllm/distributed/kv_transfer/kv_pipe/p2p_nccl_pipe.py.
|
||||
# Below is the original copyright:
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
import os
|
||||
import logging
|
||||
import threading
|
||||
import time
|
||||
import typing
|
||||
from collections import deque
|
||||
from typing import Any, Deque, Dict, List, Optional
|
||||
|
||||
import msgpack
|
||||
import torch
|
||||
import zmq
|
||||
import ctypes
|
||||
import sys
|
||||
sys.path.append(os.getenv('FLAGCX_PATH'))
|
||||
from plugin.interservice.flagcx_wrapper import (
|
||||
FLAGCXLibrary,
|
||||
buffer_type,
|
||||
cudaStream_t,
|
||||
flagcxComm_t,
|
||||
flagcxDataTypeEnum,
|
||||
)
|
||||
from vllm.config import KVTransferConfig
|
||||
from vllm.utils import current_stream, get_ip
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class P2pNcclPipe:
|
||||
|
||||
def __init__(self,
|
||||
local_rank: int,
|
||||
config: KVTransferConfig,
|
||||
hostname: str = "",
|
||||
port_offset: int = 0,
|
||||
library_path: Optional[str] = None) -> None:
|
||||
self.config = config
|
||||
self.rank = port_offset
|
||||
self.local_rank = local_rank
|
||||
self.device = torch.device(f"cuda:{self.local_rank}")
|
||||
flagcx_path = os.getenv('FLAGCX_PATH')
|
||||
library_path=os.path.join(flagcx_path, "build/lib/libflagcx.so")
|
||||
self.flagcx = FLAGCXLibrary(library_path)
|
||||
|
||||
if not hostname:
|
||||
hostname = get_ip()
|
||||
port = self.config.kv_port + port_offset
|
||||
if port == 0:
|
||||
raise ValueError("Port cannot be 0")
|
||||
self._hostname = hostname
|
||||
self._port = port
|
||||
|
||||
# Each card corresponds to a ZMQ address.
|
||||
self.zmq_address = f"{self._hostname}:{self._port}"
|
||||
|
||||
# The `http_port` must be consistent with the port of OpenAI.
|
||||
self.http_address = (
|
||||
f"{self._hostname}:"
|
||||
f"{self.config.kv_connector_extra_config['http_port']}")
|
||||
|
||||
# If `proxy_ip` or `proxy_port` is `""`,
|
||||
# then the ping thread will not be enabled.
|
||||
proxy_ip = self.config.get_from_extra_config("proxy_ip", "")
|
||||
proxy_port = self.config.get_from_extra_config("proxy_port", "")
|
||||
if proxy_ip == "" or proxy_port == "":
|
||||
self.proxy_address = ""
|
||||
else:
|
||||
self.proxy_address = proxy_ip + ":" + proxy_port
|
||||
|
||||
self.context = zmq.Context()
|
||||
self.router_socket = self.context.socket(zmq.ROUTER)
|
||||
self.router_socket.bind(f"tcp://{self.zmq_address}")
|
||||
|
||||
self.poller = zmq.Poller()
|
||||
self.poller.register(self.router_socket, zmq.POLLIN)
|
||||
|
||||
self.send_store_cv = threading.Condition()
|
||||
self.send_queue_cv = threading.Condition()
|
||||
self.recv_store_cv = threading.Condition()
|
||||
self.comm_cv = threading.Condition()
|
||||
|
||||
# The sending type includes tree mutually exclusive options:
|
||||
# PUT, GET, PUT_ASYNC.
|
||||
self.send_type = self.config.get_from_extra_config("send_type", "PUT")
|
||||
if self.send_type == "GET":
|
||||
self.send_store: Dict[str,
|
||||
torch.Tensor] = {} # tensor_id: torch.Tensor
|
||||
else:
|
||||
# PUT or PUT_ASYNC
|
||||
self.send_queue: Deque[
|
||||
List[Any]] = deque() # tensor_id: torch.Tensor
|
||||
if self.send_type == "PUT_ASYNC":
|
||||
self._send_thread = threading.Thread(target=self._send_async,
|
||||
daemon=True)
|
||||
self._send_thread.start()
|
||||
|
||||
self.recv_store: Dict[str,
|
||||
torch.Tensor] = {} # tensor_id: torch.Tensor
|
||||
self.socks: Dict[str, Any] = {} # remote_address: client socket
|
||||
self.comms: Dict[str, Any] = {} # remote_address: (ncclComm_t, rank)
|
||||
|
||||
self.buffer_size = 0
|
||||
self.buffer_size_threshold = self.config.kv_buffer_size
|
||||
|
||||
self._listener_thread = threading.Thread(
|
||||
target=self._listen_for_requests, daemon=True)
|
||||
self._listener_thread.start()
|
||||
|
||||
self._ping_thread = None
|
||||
if port_offset == 0 and self.proxy_address != "":
|
||||
self._ping_thread = threading.Thread(target=self._ping,
|
||||
daemon=True)
|
||||
self._ping_thread.start()
|
||||
|
||||
def _create_connect(self, remote_address: typing.Optional[str] = None):
|
||||
assert remote_address is not None
|
||||
if remote_address not in self.socks:
|
||||
sock = self.context.socket(zmq.DEALER)
|
||||
sock.setsockopt_string(zmq.IDENTITY, self.zmq_address)
|
||||
sock.connect(f"tcp://{remote_address}")
|
||||
self.socks[remote_address] = sock
|
||||
if remote_address in self.comms:
|
||||
logger.info("👋comm exists, remote_address:%s, comms:%s",
|
||||
remote_address, self.comms)
|
||||
return sock, self.comms[remote_address]
|
||||
|
||||
unique_id = self.flagcx.flagcxGetUniqueId().contents
|
||||
data = {"cmd": "NEW", "unique_id": bytes(unique_id.internal)}
|
||||
sock.send(msgpack.dumps(data))
|
||||
|
||||
with torch.cuda.device(self.device):
|
||||
rank = 0
|
||||
comm = self.flagcx.flagcxCommInitRank(
|
||||
2, ctypes.byref(unique_id), rank)
|
||||
self.comms[remote_address] = (comm, rank)
|
||||
logger.info("🤝ncclCommInitRank Success, %s👉%s, MyRank: %s",
|
||||
self.zmq_address, remote_address, rank)
|
||||
|
||||
return self.socks[remote_address], self.comms[remote_address]
|
||||
|
||||
def send_tensor(
|
||||
self,
|
||||
tensor_id: str,
|
||||
tensor: torch.Tensor,
|
||||
remote_address: typing.Optional[str] = None,
|
||||
) -> bool:
|
||||
if remote_address is None:
|
||||
with self.recv_store_cv:
|
||||
self.recv_store[tensor_id] = tensor
|
||||
self.recv_store_cv.notify()
|
||||
return True
|
||||
else:
|
||||
if self.send_type == "PUT":
|
||||
return self._send_sync(tensor_id, tensor, remote_address)
|
||||
elif self.send_type == "PUT_ASYNC":
|
||||
with self.send_queue_cv:
|
||||
self.send_queue.append([tensor_id, remote_address, tensor])
|
||||
self.send_queue_cv.notify()
|
||||
else: # GET
|
||||
with self.send_store_cv:
|
||||
tensor_size = tensor.element_size() * tensor.numel()
|
||||
while (self.buffer_size + tensor_size
|
||||
> self.buffer_size_threshold):
|
||||
oldest_tenser_id = next(iter(self.send_store))
|
||||
oldest_tenser = self.send_store.pop(oldest_tenser_id)
|
||||
oldest_tenser_size = oldest_tenser.element_size(
|
||||
) * oldest_tenser.numel()
|
||||
self.buffer_size -= oldest_tenser_size
|
||||
logger.info(
|
||||
"⛔[GET]Send to %s, tensor_id:%s, tensor_size:%d,"
|
||||
" buffer_size:%d, oldest_tenser_size:%d, rank:%d",
|
||||
remote_address, tensor_id, tensor_size,
|
||||
self.buffer_size, oldest_tenser_size, self.rank)
|
||||
|
||||
self.send_store[tensor_id] = tensor
|
||||
self.buffer_size += tensor_size
|
||||
logger.info(
|
||||
"🔵[GET]Send to %s, tensor_id:%s, tensor_size:%d, "
|
||||
"shape:%s, rank:%d, buffer_size:%d(%.2f%%)",
|
||||
remote_address, tensor_id, tensor_size, tensor.shape,
|
||||
self.rank, self.buffer_size,
|
||||
self.buffer_size / self.buffer_size_threshold * 100)
|
||||
|
||||
return True
|
||||
|
||||
def recv_tensor(
|
||||
self,
|
||||
tensor_id: str,
|
||||
remote_address: typing.Optional[str] = None,
|
||||
) -> torch.Tensor:
|
||||
if self.send_type == "PUT" or self.send_type == "PUT_ASYNC":
|
||||
start_time = time.time()
|
||||
with self.recv_store_cv:
|
||||
while tensor_id not in self.recv_store:
|
||||
self.recv_store_cv.wait()
|
||||
tensor = self.recv_store[tensor_id]
|
||||
self.recv_store[tensor_id] = None
|
||||
while len(self.recv_store) > 10000:
|
||||
self.recv_store.pop(next(iter(self.recv_store)))
|
||||
|
||||
duration = time.time() - start_time
|
||||
if tensor is not None:
|
||||
self.buffer_size -= (tensor.element_size() * tensor.numel())
|
||||
logger.info(
|
||||
"🔵[PUT]Recv From %s, tensor_id:%s, shape:%s, "
|
||||
"duration:%.3fms, size:%.3fGB, rank:%d", remote_address,
|
||||
tensor_id, tensor.shape, duration * 1000,
|
||||
tensor.element_size() * tensor.numel() / 1024**3,
|
||||
self.rank)
|
||||
else:
|
||||
logger.warning(
|
||||
"🔴[PUT]Recv From %s, tensor_id:%s, duration:%.3fms, "
|
||||
"rank:%d", remote_address, tensor_id, duration * 1000,
|
||||
self.rank)
|
||||
return tensor
|
||||
|
||||
# GET
|
||||
if remote_address is None:
|
||||
return None
|
||||
|
||||
if remote_address not in self.socks:
|
||||
self._create_connect(remote_address)
|
||||
|
||||
sock = self.socks[remote_address]
|
||||
comm, rank = self.comms[remote_address]
|
||||
|
||||
data = {"cmd": "GET", "tensor_id": tensor_id}
|
||||
sock.send(msgpack.dumps(data))
|
||||
|
||||
message = sock.recv()
|
||||
data = msgpack.loads(message)
|
||||
if data["ret"] != 0:
|
||||
logger.warning("🔴[GET]Recv From %s, tensor_id: %s, ret: %d",
|
||||
remote_address, tensor_id, data["ret"])
|
||||
return None
|
||||
|
||||
tensor = torch.empty(data["shape"],
|
||||
dtype=getattr(torch, data["dtype"]),
|
||||
device=self.device)
|
||||
|
||||
start_time = time.time()
|
||||
self._recv(comm, tensor, rank ^ 1)
|
||||
duration = time.time() - start_time
|
||||
logger.info(
|
||||
"🔵[GET]Recv From %s, tensor_id:%s, shape:%s, duration:%.3fms, "
|
||||
"size:%.3fGB, rank:%d", remote_address, tensor_id, tensor.shape,
|
||||
duration * 1000,
|
||||
tensor.element_size() * tensor.numel() / 1024**3, self.rank)
|
||||
|
||||
return tensor
|
||||
|
||||
def _listen_for_requests(self):
|
||||
while True:
|
||||
socks = dict(self.poller.poll())
|
||||
if self.router_socket in socks:
|
||||
remote_address, message = self.router_socket.recv_multipart()
|
||||
data = msgpack.loads(message)
|
||||
logger.debug("Received message from %s, data:%s",
|
||||
remote_address.decode(), data)
|
||||
if data["cmd"] == "NEW":
|
||||
unique_id = self.flagcx.unique_id_from_bytes(
|
||||
bytes(data["unique_id"]))
|
||||
with torch.cuda.device(self.device):
|
||||
rank = 1
|
||||
# comm: ncclComm_t = self.nccl.ncclCommInitRank(
|
||||
# 2, unique_id, rank)
|
||||
comm = self.flagcx.flagcxCommInitRank(
|
||||
2, ctypes.byref(unique_id), rank)
|
||||
self.comms[remote_address.decode()] = (comm, rank)
|
||||
logger.info(
|
||||
"🤝ncclCommInitRank Success, %s👈%s, MyRank:%s",
|
||||
self.zmq_address, remote_address.decode(), rank)
|
||||
elif data["cmd"] == "PUT":
|
||||
tensor_id = data["tensor_id"]
|
||||
try:
|
||||
tensor = torch.empty(data["shape"],
|
||||
dtype=getattr(
|
||||
torch, data["dtype"]),
|
||||
device=self.device)
|
||||
|
||||
tensor_size = tensor.element_size() * tensor.numel()
|
||||
if (self.buffer_size + tensor_size
|
||||
> self.buffer_size_threshold):
|
||||
self.router_socket.send_multipart(
|
||||
[remote_address, b"2"])
|
||||
logger.warning(
|
||||
"🔴[PUT]Recv Tensor, Out Of Threshold, "
|
||||
"%s👈%s, data:%s", self.zmq_address,
|
||||
remote_address.decode(), data)
|
||||
tensor = None
|
||||
else:
|
||||
self.buffer_size += tensor_size
|
||||
self.router_socket.send_multipart(
|
||||
[remote_address, b"0"])
|
||||
comm, rank = self.comms[remote_address.decode()]
|
||||
self._recv(comm, tensor, rank ^ 1)
|
||||
logger.info(
|
||||
"🔵[PUT]Recv Tensor, %s👈%s, MyRank:%s, "
|
||||
"data:%s, shape:%s", self.zmq_address,
|
||||
remote_address.decode(), rank, data,
|
||||
tensor.shape)
|
||||
|
||||
except torch.cuda.OutOfMemoryError:
|
||||
self.router_socket.send_multipart(
|
||||
[remote_address, b"1"])
|
||||
tensor = None
|
||||
logger.warning(
|
||||
"🔴[PUT]Recv Tensor, Out Of Memory, %s👈%s, "
|
||||
"data:%s", self.zmq_address,
|
||||
remote_address.decode(), data)
|
||||
|
||||
with self.recv_store_cv:
|
||||
self.recv_store[tensor_id] = tensor
|
||||
self.recv_store_cv.notify()
|
||||
|
||||
elif data["cmd"] == "GET":
|
||||
tensor_id = data["tensor_id"]
|
||||
with self.send_store_cv:
|
||||
tensor = self.send_store.pop(tensor_id, None)
|
||||
if tensor is not None:
|
||||
data = {
|
||||
"ret": 0,
|
||||
"shape": tensor.shape,
|
||||
"dtype":
|
||||
str(tensor.dtype).replace("torch.", "")
|
||||
}
|
||||
# LRU
|
||||
self.send_store[tensor_id] = tensor
|
||||
else:
|
||||
data = {"ret": 1}
|
||||
|
||||
self.router_socket.send_multipart(
|
||||
[remote_address, msgpack.dumps(data)])
|
||||
|
||||
if data["ret"] == 0:
|
||||
self._send(comm, tensor.to(self.device), rank ^ 1)
|
||||
|
||||
logger.info(
|
||||
"🔵[GET]Send Tensor, %s👉%s, "
|
||||
"MyRank:%s, data:%s", self.zmq_address,
|
||||
remote_address.decode(), rank, data)
|
||||
else:
|
||||
logger.warning(
|
||||
"🚧Unexpected, Received message from %s, data:%s",
|
||||
remote_address, data)
|
||||
|
||||
# Asynchronous sending may cause conflicts between P2P NCCL and
|
||||
# NCCL used in TP/PP, which can lead to deadlock issues.
|
||||
def _send_async(self):
|
||||
while True:
|
||||
with self.send_queue_cv:
|
||||
while not self.send_queue:
|
||||
self.send_queue_cv.wait()
|
||||
tensor_id, remote_address, tensor = self.send_queue.popleft()
|
||||
if not self.send_queue:
|
||||
self.send_queue_cv.notify()
|
||||
self._send_sync(tensor_id, tensor, remote_address)
|
||||
|
||||
def wait_for_sent(self):
|
||||
if self.send_type == "PUT_ASYNC":
|
||||
start_time = time.time()
|
||||
with self.send_queue_cv:
|
||||
while self.send_queue:
|
||||
self.send_queue_cv.wait()
|
||||
duration = time.time() - start_time
|
||||
logger.info(
|
||||
"🚧[PUT_ASYNC]It took %.3fms to wait for the send_queue"
|
||||
" to be empty, rank:%d", duration * 1000, self.rank)
|
||||
|
||||
def _send_sync(
|
||||
self,
|
||||
tensor_id: str,
|
||||
tensor: torch.Tensor,
|
||||
remote_address: typing.Optional[str] = None,
|
||||
) -> bool:
|
||||
if remote_address is None:
|
||||
return False
|
||||
if remote_address not in self.socks:
|
||||
self._create_connect(remote_address)
|
||||
|
||||
sock = self.socks[remote_address]
|
||||
comm, rank = self.comms[remote_address]
|
||||
data = {
|
||||
"cmd": "PUT",
|
||||
"tensor_id": tensor_id,
|
||||
"shape": tensor.shape,
|
||||
"dtype": str(tensor.dtype).replace("torch.", "")
|
||||
}
|
||||
sock.send(msgpack.dumps(data))
|
||||
|
||||
response = sock.recv()
|
||||
if response != b"0":
|
||||
# with self.send_queue_cv:
|
||||
# self.send_queue.append([tensor_id, remote_address, tensor])
|
||||
# self.send_queue_cv.notify()
|
||||
logger.warning(
|
||||
"🔴Send Tensor, Peer Out Of Memory/Threshold, %s 👉 %s, "
|
||||
"MyRank:%s, data:%s, tensor:%s, size:%fGB, response:%s",
|
||||
self.zmq_address, remote_address, rank, data, tensor.shape,
|
||||
tensor.element_size() * tensor.numel() / 1024**3,
|
||||
response.decode())
|
||||
return False
|
||||
|
||||
self._send(comm, tensor.to(self.device), rank ^ 1)
|
||||
logger.info("🔵Send Tensor, %s👉%s, MyRank:%s, data:%s, tensor:%s",
|
||||
self.zmq_address, remote_address, rank, data, tensor.shape)
|
||||
return True
|
||||
|
||||
def _ping(self):
|
||||
sock = self.context.socket(zmq.DEALER)
|
||||
sock.setsockopt_string(zmq.IDENTITY, self.zmq_address)
|
||||
logger.debug("ping start, zmq_address:%s", self.zmq_address)
|
||||
sock.connect(f"tcp://{self.proxy_address}")
|
||||
data = {
|
||||
"type": "P" if self.config.is_kv_producer else "D",
|
||||
"http_address": self.http_address,
|
||||
"zmq_address": self.zmq_address
|
||||
}
|
||||
while True:
|
||||
sock.send(msgpack.dumps(data))
|
||||
time.sleep(3)
|
||||
|
||||
def _send(self, comm, tensor: torch.Tensor, dst: int, stream=None):
|
||||
assert tensor.device == self.device, (
|
||||
f"this nccl communicator is created to work on {self.device}, "
|
||||
f"but the input tensor is on {tensor.device}")
|
||||
if stream is None:
|
||||
stream = current_stream()
|
||||
|
||||
with self.comm_cv:
|
||||
flagcx_stream = self.flagcx.adaptor_stream_copy(stream)
|
||||
self.flagcx.flagcxSend(buffer_type(tensor.data_ptr()), tensor.numel(),
|
||||
flagcxDataTypeEnum.from_torch(tensor.dtype), dst,
|
||||
comm, flagcx_stream)
|
||||
self.flagcx.adaptor_stream_free(flagcx_stream)
|
||||
|
||||
def _recv(self, comm, tensor: torch.Tensor, src: int, stream=None):
|
||||
assert tensor.device == self.device, (
|
||||
f"this nccl communicator is created to work on {self.device}, "
|
||||
f"but the input tensor is on {tensor.device}")
|
||||
if stream is None:
|
||||
stream = current_stream()
|
||||
|
||||
with self.comm_cv:
|
||||
flagcx_stream = self.flagcx.adaptor_stream_copy(stream)
|
||||
self.flagcx.flagcxRecv(buffer_type(tensor.data_ptr()), tensor.numel(),
|
||||
flagcxDataTypeEnum.from_torch(tensor.dtype), src,
|
||||
comm, flagcx_stream)
|
||||
self.flagcx.adaptor_stream_free(flagcx_stream)
|
||||
|
||||
def close(self) -> None:
|
||||
self._listener_thread.join()
|
||||
if self.send_type == "PUT_ASYNC":
|
||||
self._send_thread.join()
|
||||
if self._ping_thread is not None:
|
||||
self._ping_thread.join()
|
||||
274
vllm/distributed/kv_transfer/kv_pipe/mooncake_pipe.py
Normal file
274
vllm/distributed/kv_transfer/kv_pipe/mooncake_pipe.py
Normal file
@@ -0,0 +1,274 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
import json
|
||||
import os
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional, Union
|
||||
|
||||
import torch
|
||||
import zmq
|
||||
from safetensors.torch import load as safetensors_load
|
||||
from safetensors.torch import save as safetensors_save
|
||||
|
||||
from vllm.config import KVTransferConfig
|
||||
from vllm.distributed.kv_transfer.kv_pipe.base import KVPipeBase
|
||||
from vllm.logger import init_logger
|
||||
|
||||
logger = init_logger(__name__)
|
||||
NONE_INT = -150886311
|
||||
|
||||
|
||||
@dataclass
|
||||
class MooncakeTransferEngineConfig:
|
||||
prefill_url: str
|
||||
decode_url: str
|
||||
metadata_backend: Union[str, None]
|
||||
metadata_server: str
|
||||
protocol: str
|
||||
device_name: str
|
||||
|
||||
@staticmethod
|
||||
def from_file(file_path: str) -> 'MooncakeTransferEngineConfig':
|
||||
"""Load the config from a JSON file."""
|
||||
with open(file_path) as fin:
|
||||
config = json.load(fin)
|
||||
return MooncakeTransferEngineConfig(
|
||||
prefill_url=config.get("prefill_url"),
|
||||
decode_url=config.get("decode_url"),
|
||||
metadata_backend=config.get("metadata_backend", None),
|
||||
metadata_server=config.get("metadata_server"),
|
||||
protocol=config.get("protocol", "tcp"),
|
||||
device_name=config.get("device_name", ""),
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def load_from_env() -> 'MooncakeTransferEngineConfig':
|
||||
"""Load config from a file specified in the environment variable."""
|
||||
config_file_path = os.getenv('MOONCAKE_CONFIG_PATH')
|
||||
if config_file_path is None:
|
||||
raise ValueError(
|
||||
"The environment variable 'MOONCAKE_CONFIG_PATH' is not set.")
|
||||
return MooncakeTransferEngineConfig.from_file(config_file_path)
|
||||
|
||||
|
||||
class MooncakeTransferEngine:
|
||||
"""Handles the transfer of data using mooncake_vllm_adaptor and ZeroMQ."""
|
||||
|
||||
def __init__(self, kv_rank: int, local_rank: int):
|
||||
try:
|
||||
import mooncake_vllm_adaptor as mva
|
||||
except ImportError as e:
|
||||
raise ImportError(
|
||||
"Please install mooncake by following the instructions at "
|
||||
"https://github.com/kvcache-ai/Mooncake/blob/main/doc/en/build.md " # noqa: E501
|
||||
"to run vLLM with MooncakeConnector.") from e
|
||||
|
||||
self.engine = mva.mooncake_vllm_adaptor()
|
||||
self.local_rank = local_rank
|
||||
|
||||
try:
|
||||
self.config = MooncakeTransferEngineConfig.load_from_env()
|
||||
logger.info("Mooncake Configuration loaded successfully.")
|
||||
except ValueError as e:
|
||||
logger.error(e)
|
||||
raise
|
||||
except Exception as exc:
|
||||
logger.error(
|
||||
"An error occurred while loading the configuration: %s", exc)
|
||||
raise
|
||||
prefill_host, base_prefill_port = self.config.prefill_url.split(':')
|
||||
decode_host, base_decode_port = self.config.decode_url.split(':')
|
||||
|
||||
# Avoid ports conflict when running prefill and decode on the same node
|
||||
if prefill_host == decode_host and \
|
||||
base_prefill_port == base_decode_port:
|
||||
base_decode_port = str(int(base_decode_port) + 100)
|
||||
|
||||
prefill_port = int(base_prefill_port) + self.local_rank
|
||||
decode_port = int(base_decode_port) + self.local_rank
|
||||
self.prefill_url = ':'.join([prefill_host, str(prefill_port)])
|
||||
self.decode_url = ':'.join([decode_host, str(decode_port)])
|
||||
|
||||
self.initialize(self.prefill_url if kv_rank == 0 else self.decode_url,
|
||||
self.config.metadata_server, self.config.protocol,
|
||||
self.config.device_name, self.config.metadata_backend)
|
||||
|
||||
self.remote_url = (self.decode_url
|
||||
if kv_rank == 0 else self.prefill_url)
|
||||
|
||||
# Initialize ZeroMQ context and sockets
|
||||
self.context = zmq.Context() # type: ignore[attr-defined]
|
||||
self.sender_socket = self.context.socket(zmq.constants.PUSH)
|
||||
self.receiver_socket = self.context.socket(zmq.constants.PULL)
|
||||
self.sender_ack = self.context.socket(zmq.constants.PULL)
|
||||
self.receiver_ack = self.context.socket(zmq.constants.PUSH)
|
||||
|
||||
self.buffer_cleaner = ThreadPoolExecutor(max_workers=1)
|
||||
self._setup_metadata_sockets(kv_rank, prefill_host, base_prefill_port,
|
||||
decode_host, base_decode_port)
|
||||
|
||||
def _setup_metadata_sockets(self, kv_rank: int, p_host: str, p_port: str,
|
||||
d_host: str, d_port: str) -> None:
|
||||
"""Set up ZeroMQ sockets for sending and receiving data."""
|
||||
# Offsets < 8 are left for initialization in case tp and pp are enabled
|
||||
p_rank_offset = int(p_port) + 8 + self.local_rank * 2
|
||||
d_rank_offset = int(d_port) + 8 + self.local_rank * 2
|
||||
if kv_rank == 0:
|
||||
self.sender_socket.bind(f"tcp://*:{p_rank_offset + 1}")
|
||||
self.receiver_socket.connect(f"tcp://{d_host}:{d_rank_offset + 1}")
|
||||
self.sender_ack.connect(f"tcp://{d_host}:{d_rank_offset + 2}")
|
||||
self.receiver_ack.bind(f"tcp://*:{p_rank_offset + 2}")
|
||||
else:
|
||||
self.receiver_socket.connect(f"tcp://{p_host}:{p_rank_offset + 1}")
|
||||
self.sender_socket.bind(f"tcp://*:{d_rank_offset + 1}")
|
||||
self.receiver_ack.bind(f"tcp://*:{d_rank_offset + 2}")
|
||||
self.sender_ack.connect(f"tcp://{p_host}:{p_rank_offset + 2}")
|
||||
|
||||
def initialize(self, local_hostname: str, metadata_server: str,
|
||||
protocol: str, device_name: str,
|
||||
metadata_backend: Union[str, None]) -> None:
|
||||
"""Initialize the mooncake instance."""
|
||||
if metadata_backend is None:
|
||||
self.engine.initialize(local_hostname, metadata_server, protocol,
|
||||
device_name)
|
||||
else:
|
||||
supported_backend = ["etcd", "redis"]
|
||||
metadata_backend = metadata_backend.lower()
|
||||
if metadata_backend not in supported_backend:
|
||||
raise ValueError(
|
||||
"Mooncake Configuration error. `metadata_backend`"
|
||||
f" should be one of {supported_backend}.")
|
||||
|
||||
self.engine.initializeExt(local_hostname, metadata_server,
|
||||
protocol, device_name, metadata_backend)
|
||||
|
||||
def allocate_managed_buffer(self, length: int) -> int:
|
||||
"""Allocate a managed buffer of the specified length."""
|
||||
ret = self.engine.allocateManagedBuffer(length)
|
||||
if ret <= 0:
|
||||
logger.error("Allocation Return Error")
|
||||
raise Exception("Allocation Return Error")
|
||||
return ret
|
||||
|
||||
def free_managed_buffer(self, buffer: int, length: int) -> int:
|
||||
"""Free a previously allocated managed buffer."""
|
||||
return self.engine.freeManagedBuffer(buffer, length)
|
||||
|
||||
def transfer_sync(self, buffer: int, peer_buffer_address: int,
|
||||
length: int) -> int:
|
||||
"""Synchronously transfer data to the specified address."""
|
||||
ret = self.engine.transferSync(self.remote_url, buffer,
|
||||
peer_buffer_address, length)
|
||||
if ret < 0:
|
||||
logger.error("Transfer Return Error")
|
||||
raise Exception("Transfer Return Error")
|
||||
return ret
|
||||
|
||||
def write_bytes_to_buffer(self, buffer: int, user_data: bytes,
|
||||
length: int) -> int:
|
||||
"""Write bytes to the allocated buffer."""
|
||||
return self.engine.writeBytesToBuffer(buffer, user_data, length)
|
||||
|
||||
def read_bytes_from_buffer(self, buffer: int, length: int) -> bytes:
|
||||
"""Read bytes from the allocated buffer."""
|
||||
return self.engine.readBytesFromBuffer(buffer, length)
|
||||
|
||||
def wait_for_ack(self, src_ptr: int, length: int) -> None:
|
||||
"""Asynchronously wait for ACK from the receiver."""
|
||||
ack = self.sender_ack.recv_pyobj()
|
||||
if ack != b'ACK':
|
||||
logger.error("Failed to receive ACK from the receiver")
|
||||
|
||||
self.free_managed_buffer(src_ptr, length)
|
||||
|
||||
def send_bytes(self, user_data: bytes) -> None:
|
||||
"""Send bytes to the remote process."""
|
||||
length = len(user_data)
|
||||
src_ptr = self.allocate_managed_buffer(length)
|
||||
self.write_bytes_to_buffer(src_ptr, user_data, length)
|
||||
self.sender_socket.send_pyobj((src_ptr, length))
|
||||
self.buffer_cleaner.submit(self.wait_for_ack, src_ptr, length)
|
||||
|
||||
def recv_bytes(self) -> bytes:
|
||||
"""Receive bytes from the remote process."""
|
||||
src_ptr, length = self.receiver_socket.recv_pyobj()
|
||||
dst_ptr = self.allocate_managed_buffer(length)
|
||||
self.transfer_sync(dst_ptr, src_ptr, length)
|
||||
ret = self.read_bytes_from_buffer(dst_ptr, length)
|
||||
|
||||
# Buffer cleanup
|
||||
self.receiver_ack.send_pyobj(b'ACK')
|
||||
self.free_managed_buffer(dst_ptr, length)
|
||||
|
||||
return ret
|
||||
|
||||
|
||||
class MooncakePipe(KVPipeBase):
|
||||
"""MooncakeTransferEngine based Pipe implementation."""
|
||||
|
||||
def __init__(self,
|
||||
local_rank: int,
|
||||
config: KVTransferConfig,
|
||||
device: Optional[str] = None):
|
||||
"""Initialize the mooncake pipe and set related parameters."""
|
||||
self.config = config
|
||||
self.local_rank = local_rank
|
||||
self.kv_rank = self.config.kv_rank
|
||||
if device is None:
|
||||
self.device = self._select_device(self.config.kv_buffer_device)
|
||||
else:
|
||||
self.device = self._select_device(device)
|
||||
|
||||
self.transfer_engine = MooncakeTransferEngine(self.kv_rank,
|
||||
self.local_rank)
|
||||
self.transport_thread: Optional[ThreadPoolExecutor] = None
|
||||
self.none_tensor = torch.tensor([NONE_INT], device=self.device)
|
||||
|
||||
def _select_device(self, device: str) -> torch.device:
|
||||
"""Select available device (CUDA or CPU)."""
|
||||
logger.info("Selecting device: %s", device)
|
||||
if device == "cuda":
|
||||
return torch.device(f"cuda:{self.local_rank}")
|
||||
else:
|
||||
return torch.device("cpu")
|
||||
|
||||
def tensor_hash(self, tensor: torch.Tensor) -> int:
|
||||
"""Calculate the hash value of the tensor."""
|
||||
return hash(tensor.data_ptr())
|
||||
|
||||
def _send_impl(self, tensor: torch.Tensor) -> None:
|
||||
"""Implement the tensor sending logic using safetensors."""
|
||||
self.transfer_engine.send_bytes(safetensors_save({"tensor": tensor}))
|
||||
|
||||
def _recv_impl(self) -> torch.Tensor:
|
||||
"""Implement the tensor receiving logic using safetensors."""
|
||||
data = self.transfer_engine.recv_bytes()
|
||||
return safetensors_load(data)["tensor"].to(self.device)
|
||||
|
||||
def send_tensor(self, tensor: Optional[torch.Tensor]) -> None:
|
||||
"""Send tensor to the target process."""
|
||||
if self.transport_thread is None:
|
||||
self.transport_thread = ThreadPoolExecutor(max_workers=1)
|
||||
tensor = tensor if tensor is not None else self.none_tensor
|
||||
assert (len(tensor.shape) > 0)
|
||||
self.transport_thread.submit(self._send_impl, tensor)
|
||||
|
||||
def recv_tensor(self) -> Optional[torch.Tensor]:
|
||||
"""Receive tensor from other processes."""
|
||||
if self.transport_thread is None:
|
||||
self.transport_thread = ThreadPoolExecutor(max_workers=1)
|
||||
tensor = self.transport_thread.submit(self._recv_impl).result()
|
||||
if tensor.numel() == 1 and tensor.item() == NONE_INT:
|
||||
return None
|
||||
else:
|
||||
return tensor
|
||||
|
||||
def close(self) -> None:
|
||||
"""Cleanup logic when closing the pipe."""
|
||||
self.transfer_engine.sender_socket.close()
|
||||
self.transfer_engine.receiver_socket.close()
|
||||
self.transfer_engine.sender_ack.close()
|
||||
self.transfer_engine.receiver_ack.close()
|
||||
self.transfer_engine.context.term() # Terminate the ZMQ context
|
||||
logger.info("Closed the transfer engine and cleaned up resources.")
|
||||
445
vllm/distributed/kv_transfer/kv_pipe/p2p_nccl_pipe.py
Normal file
445
vllm/distributed/kv_transfer/kv_pipe/p2p_nccl_pipe.py
Normal file
@@ -0,0 +1,445 @@
|
||||
# Copied adopted from https://github.com/vllm-project/vllm/blob/1ad957950ffc1552af5abda78c03d88ddb67945b/vllm/distributed/kv_transfer/kv_pipe/p2p_nccl_pipe.py.
|
||||
# Below is the original copyright:
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
import logging
|
||||
import threading
|
||||
import time
|
||||
import typing
|
||||
from collections import deque
|
||||
from typing import Any, Deque, Dict, List, Optional
|
||||
|
||||
import msgpack
|
||||
import torch
|
||||
import zmq
|
||||
|
||||
from vllm.config import KVTransferConfig
|
||||
from vllm.distributed.device_communicators.pynccl_wrapper import (
|
||||
NCCLLibrary, buffer_type, cudaStream_t, ncclComm_t, ncclDataTypeEnum)
|
||||
from vllm.utils import current_stream, get_ip
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class P2pNcclPipe:
|
||||
|
||||
def __init__(self,
|
||||
local_rank: int,
|
||||
config: KVTransferConfig,
|
||||
hostname: str = "",
|
||||
port_offset: int = 0,
|
||||
library_path: Optional[str] = None) -> None:
|
||||
self.config = config
|
||||
self.rank = port_offset
|
||||
self.local_rank = local_rank
|
||||
self.device = torch.device(f"cuda:{self.local_rank}")
|
||||
self.nccl = NCCLLibrary(library_path)
|
||||
|
||||
if not hostname:
|
||||
hostname = get_ip()
|
||||
port = self.config.kv_port + port_offset
|
||||
if port == 0:
|
||||
raise ValueError("Port cannot be 0")
|
||||
self._hostname = hostname
|
||||
self._port = port
|
||||
|
||||
# Each card corresponds to a ZMQ address.
|
||||
self.zmq_address = f"{self._hostname}:{self._port}"
|
||||
|
||||
# The `http_port` must be consistent with the port of OpenAI.
|
||||
self.http_address = (
|
||||
f"{self._hostname}:"
|
||||
f"{self.config.kv_connector_extra_config['http_port']}")
|
||||
|
||||
# If `proxy_ip` or `proxy_port` is `""`,
|
||||
# then the ping thread will not be enabled.
|
||||
proxy_ip = self.config.get_from_extra_config("proxy_ip", "")
|
||||
proxy_port = self.config.get_from_extra_config("proxy_port", "")
|
||||
if proxy_ip == "" or proxy_port == "":
|
||||
self.proxy_address = ""
|
||||
else:
|
||||
self.proxy_address = proxy_ip + ":" + proxy_port
|
||||
|
||||
self.context = zmq.Context()
|
||||
self.router_socket = self.context.socket(zmq.ROUTER)
|
||||
self.router_socket.bind(f"tcp://{self.zmq_address}")
|
||||
|
||||
self.poller = zmq.Poller()
|
||||
self.poller.register(self.router_socket, zmq.POLLIN)
|
||||
|
||||
self.send_store_cv = threading.Condition()
|
||||
self.send_queue_cv = threading.Condition()
|
||||
self.recv_store_cv = threading.Condition()
|
||||
|
||||
self.send_stream = torch.cuda.Stream()
|
||||
self.recv_stream = torch.cuda.Stream()
|
||||
|
||||
# The sending type includes tree mutually exclusive options:
|
||||
# PUT, GET, PUT_ASYNC.
|
||||
self.send_type = self.config.get_from_extra_config("send_type", "PUT")
|
||||
if self.send_type == "GET":
|
||||
self.send_store: Dict[str,
|
||||
torch.Tensor] = {} # tensor_id: torch.Tensor
|
||||
else:
|
||||
# PUT or PUT_ASYNC
|
||||
self.send_queue: Deque[
|
||||
List[Any]] = deque() # tensor_id: torch.Tensor
|
||||
if self.send_type == "PUT_ASYNC":
|
||||
self._send_thread = threading.Thread(target=self._send_async,
|
||||
daemon=True)
|
||||
self._send_thread.start()
|
||||
|
||||
self.recv_store: Dict[str,
|
||||
torch.Tensor] = {} # tensor_id: torch.Tensor
|
||||
self.socks: Dict[str, Any] = {} # remote_address: client socket
|
||||
self.comms: Dict[str, Any] = {} # remote_address: (ncclComm_t, rank)
|
||||
|
||||
self.buffer_size = 0
|
||||
self.buffer_size_threshold = self.config.kv_buffer_size
|
||||
|
||||
self._listener_thread = threading.Thread(
|
||||
target=self._listen_for_requests, daemon=True)
|
||||
self._listener_thread.start()
|
||||
|
||||
self._ping_thread = None
|
||||
if port_offset == 0 and self.proxy_address != "":
|
||||
self._ping_thread = threading.Thread(target=self._ping,
|
||||
daemon=True)
|
||||
self._ping_thread.start()
|
||||
|
||||
def _create_connect(self, remote_address: typing.Optional[str] = None):
|
||||
assert remote_address is not None
|
||||
if remote_address not in self.socks:
|
||||
sock = self.context.socket(zmq.DEALER)
|
||||
sock.setsockopt_string(zmq.IDENTITY, self.zmq_address)
|
||||
sock.connect(f"tcp://{remote_address}")
|
||||
self.socks[remote_address] = sock
|
||||
if remote_address in self.comms:
|
||||
logger.info("👋comm exists, remote_address:%s, comms:%s",
|
||||
remote_address, self.comms)
|
||||
return sock, self.comms[remote_address]
|
||||
|
||||
unique_id = self.nccl.ncclGetUniqueId()
|
||||
data = {"cmd": "NEW", "unique_id": bytes(unique_id.internal)}
|
||||
sock.send(msgpack.dumps(data))
|
||||
|
||||
with torch.cuda.device(self.device):
|
||||
rank = 0
|
||||
comm: ncclComm_t = self.nccl.ncclCommInitRank(
|
||||
2, unique_id, rank)
|
||||
self.comms[remote_address] = (comm, rank)
|
||||
logger.info("🤝ncclCommInitRank Success, %s👉%s, MyRank: %s",
|
||||
self.zmq_address, remote_address, rank)
|
||||
|
||||
return self.socks[remote_address], self.comms[remote_address]
|
||||
|
||||
def send_tensor(
|
||||
self,
|
||||
tensor_id: str,
|
||||
tensor: torch.Tensor,
|
||||
remote_address: typing.Optional[str] = None,
|
||||
) -> bool:
|
||||
if remote_address is None:
|
||||
with self.recv_store_cv:
|
||||
self.recv_store[tensor_id] = tensor
|
||||
self.recv_store_cv.notify()
|
||||
return True
|
||||
else:
|
||||
if self.send_type == "PUT":
|
||||
return self._send_sync(tensor_id, tensor, remote_address)
|
||||
elif self.send_type == "PUT_ASYNC":
|
||||
with self.send_queue_cv:
|
||||
self.send_queue.append([tensor_id, remote_address, tensor])
|
||||
self.send_queue_cv.notify()
|
||||
else: # GET
|
||||
with self.send_store_cv:
|
||||
tensor_size = tensor.element_size() * tensor.numel()
|
||||
while (self.buffer_size + tensor_size
|
||||
> self.buffer_size_threshold):
|
||||
oldest_tenser_id = next(iter(self.send_store))
|
||||
oldest_tenser = self.send_store.pop(oldest_tenser_id)
|
||||
oldest_tenser_size = oldest_tenser.element_size(
|
||||
) * oldest_tenser.numel()
|
||||
self.buffer_size -= oldest_tenser_size
|
||||
logger.info(
|
||||
"⛔[GET]Send to %s, tensor_id:%s, tensor_size:%d,"
|
||||
" buffer_size:%d, oldest_tenser_size:%d, rank:%d",
|
||||
remote_address, tensor_id, tensor_size,
|
||||
self.buffer_size, oldest_tenser_size, self.rank)
|
||||
|
||||
self.send_store[tensor_id] = tensor
|
||||
self.buffer_size += tensor_size
|
||||
logger.info(
|
||||
"🔵[GET]Send to %s, tensor_id:%s, tensor_size:%d, "
|
||||
"shape:%s, rank:%d, buffer_size:%d(%.2f%%)",
|
||||
remote_address, tensor_id, tensor_size, tensor.shape,
|
||||
self.rank, self.buffer_size,
|
||||
self.buffer_size / self.buffer_size_threshold * 100)
|
||||
|
||||
return True
|
||||
|
||||
def recv_tensor(
|
||||
self,
|
||||
tensor_id: str,
|
||||
remote_address: typing.Optional[str] = None,
|
||||
) -> torch.Tensor:
|
||||
if self.send_type == "PUT" or self.send_type == "PUT_ASYNC":
|
||||
start_time = time.time()
|
||||
with self.recv_store_cv:
|
||||
while tensor_id not in self.recv_store:
|
||||
self.recv_store_cv.wait()
|
||||
tensor = self.recv_store[tensor_id]
|
||||
self.recv_store[tensor_id] = None
|
||||
while len(self.recv_store) > 10000:
|
||||
self.recv_store.pop(next(iter(self.recv_store)))
|
||||
|
||||
duration = time.time() - start_time
|
||||
if tensor is not None:
|
||||
self.buffer_size -= (tensor.element_size() * tensor.numel())
|
||||
logger.info(
|
||||
"🔵[PUT]Recv From %s, tensor_id:%s, shape:%s, "
|
||||
"duration:%.3fms, size:%.3fGB, rank:%d", remote_address,
|
||||
tensor_id, tensor.shape, duration * 1000,
|
||||
tensor.element_size() * tensor.numel() / 1024**3,
|
||||
self.rank)
|
||||
else:
|
||||
logger.warning(
|
||||
"🔴[PUT]Recv From %s, tensor_id:%s, duration:%.3fms, "
|
||||
"rank:%d", remote_address, tensor_id, duration * 1000,
|
||||
self.rank)
|
||||
return tensor
|
||||
|
||||
# GET
|
||||
if remote_address is None:
|
||||
return None
|
||||
|
||||
if remote_address not in self.socks:
|
||||
self._create_connect(remote_address)
|
||||
|
||||
sock = self.socks[remote_address]
|
||||
comm, rank = self.comms[remote_address]
|
||||
|
||||
data = {"cmd": "GET", "tensor_id": tensor_id}
|
||||
sock.send(msgpack.dumps(data))
|
||||
|
||||
message = sock.recv()
|
||||
data = msgpack.loads(message)
|
||||
if data["ret"] != 0:
|
||||
logger.warning("🔴[GET]Recv From %s, tensor_id: %s, ret: %d",
|
||||
remote_address, tensor_id, data["ret"])
|
||||
return None
|
||||
|
||||
tensor = torch.empty(data["shape"],
|
||||
dtype=getattr(torch, data["dtype"]),
|
||||
device=self.device)
|
||||
|
||||
start_time = time.time()
|
||||
self._recv(comm, tensor, rank ^ 1, self.recv_stream)
|
||||
duration = time.time() - start_time
|
||||
logger.info(
|
||||
"🔵[GET]Recv From %s, tensor_id:%s, shape:%s, duration:%.3fms, "
|
||||
"size:%.3fGB, rank:%d", remote_address, tensor_id, tensor.shape,
|
||||
duration * 1000,
|
||||
tensor.element_size() * tensor.numel() / 1024**3, self.rank)
|
||||
|
||||
return tensor
|
||||
|
||||
def _listen_for_requests(self):
|
||||
while True:
|
||||
socks = dict(self.poller.poll())
|
||||
if self.router_socket in socks:
|
||||
remote_address, message = self.router_socket.recv_multipart()
|
||||
data = msgpack.loads(message)
|
||||
logger.debug("Received message from %s, data:%s",
|
||||
remote_address.decode(), data)
|
||||
if data["cmd"] == "NEW":
|
||||
unique_id = self.nccl.unique_id_from_bytes(
|
||||
bytes(data["unique_id"]))
|
||||
with torch.cuda.device(self.device):
|
||||
rank = 1
|
||||
comm: ncclComm_t = self.nccl.ncclCommInitRank(
|
||||
2, unique_id, rank)
|
||||
self.comms[remote_address.decode()] = (comm, rank)
|
||||
logger.info(
|
||||
"🤝ncclCommInitRank Success, %s👈%s, MyRank:%s",
|
||||
self.zmq_address, remote_address.decode(), rank)
|
||||
elif data["cmd"] == "PUT":
|
||||
tensor_id = data["tensor_id"]
|
||||
try:
|
||||
tensor = torch.empty(data["shape"],
|
||||
dtype=getattr(
|
||||
torch, data["dtype"]),
|
||||
device=self.device)
|
||||
|
||||
tensor_size = tensor.element_size() * tensor.numel()
|
||||
if (self.buffer_size + tensor_size
|
||||
> self.buffer_size_threshold):
|
||||
self.router_socket.send_multipart(
|
||||
[remote_address, b"2"])
|
||||
logger.warning(
|
||||
"🔴[PUT]Recv Tensor, Out Of Threshold, "
|
||||
"%s👈%s, data:%s", self.zmq_address,
|
||||
remote_address.decode(), data)
|
||||
tensor = None
|
||||
else:
|
||||
self.buffer_size += tensor_size
|
||||
self.router_socket.send_multipart(
|
||||
[remote_address, b"0"])
|
||||
comm, rank = self.comms[remote_address.decode()]
|
||||
self._recv(comm, tensor, rank ^ 1,
|
||||
self.recv_stream)
|
||||
logger.info(
|
||||
"🔵[PUT]Recv Tensor, %s👈%s, MyRank:%s, "
|
||||
"data:%s, shape:%s", self.zmq_address,
|
||||
remote_address.decode(), rank, data,
|
||||
tensor.shape)
|
||||
|
||||
except torch.cuda.OutOfMemoryError:
|
||||
self.router_socket.send_multipart(
|
||||
[remote_address, b"1"])
|
||||
tensor = None
|
||||
logger.warning(
|
||||
"🔴[PUT]Recv Tensor, Out Of Memory, %s👈%s, "
|
||||
"data:%s", self.zmq_address,
|
||||
remote_address.decode(), data)
|
||||
|
||||
with self.recv_store_cv:
|
||||
self.recv_store[tensor_id] = tensor
|
||||
self.recv_store_cv.notify()
|
||||
|
||||
elif data["cmd"] == "GET":
|
||||
tensor_id = data["tensor_id"]
|
||||
with self.send_store_cv:
|
||||
tensor = self.send_store.pop(tensor_id, None)
|
||||
if tensor is not None:
|
||||
data = {
|
||||
"ret": 0,
|
||||
"shape": tensor.shape,
|
||||
"dtype":
|
||||
str(tensor.dtype).replace("torch.", "")
|
||||
}
|
||||
# LRU
|
||||
self.send_store[tensor_id] = tensor
|
||||
else:
|
||||
data = {"ret": 1}
|
||||
|
||||
self.router_socket.send_multipart(
|
||||
[remote_address, msgpack.dumps(data)])
|
||||
|
||||
if data["ret"] == 0:
|
||||
comm, rank = self.comms[remote_address.decode()]
|
||||
self._send(comm, tensor.to(self.device), rank ^ 1,
|
||||
self.send_stream)
|
||||
|
||||
logger.info(
|
||||
"🔵[GET]Send Tensor, %s👉%s, "
|
||||
"MyRank:%s, data:%s", self.zmq_address,
|
||||
remote_address.decode(), rank, data)
|
||||
else:
|
||||
logger.warning(
|
||||
"🚧Unexpected, Received message from %s, data:%s",
|
||||
remote_address, data)
|
||||
|
||||
def _send_async(self):
|
||||
while True:
|
||||
with self.send_queue_cv:
|
||||
while not self.send_queue:
|
||||
self.send_queue_cv.wait()
|
||||
tensor_id, remote_address, tensor = self.send_queue.popleft()
|
||||
if not self.send_queue:
|
||||
self.send_queue_cv.notify()
|
||||
self._send_sync(tensor_id, tensor, remote_address)
|
||||
|
||||
def wait_for_sent(self):
|
||||
if self.send_type == "PUT_ASYNC":
|
||||
start_time = time.time()
|
||||
with self.send_queue_cv:
|
||||
while self.send_queue:
|
||||
self.send_queue_cv.wait()
|
||||
duration = time.time() - start_time
|
||||
logger.info(
|
||||
"🚧[PUT_ASYNC]It took %.3fms to wait for the send_queue"
|
||||
" to be empty, rank:%d", duration * 1000, self.rank)
|
||||
|
||||
def _send_sync(
|
||||
self,
|
||||
tensor_id: str,
|
||||
tensor: torch.Tensor,
|
||||
remote_address: typing.Optional[str] = None,
|
||||
) -> bool:
|
||||
if remote_address is None:
|
||||
return False
|
||||
if remote_address not in self.socks:
|
||||
self._create_connect(remote_address)
|
||||
|
||||
sock = self.socks[remote_address]
|
||||
comm, rank = self.comms[remote_address]
|
||||
data = {
|
||||
"cmd": "PUT",
|
||||
"tensor_id": tensor_id,
|
||||
"shape": tensor.shape,
|
||||
"dtype": str(tensor.dtype).replace("torch.", "")
|
||||
}
|
||||
sock.send(msgpack.dumps(data))
|
||||
|
||||
response = sock.recv()
|
||||
if response != b"0":
|
||||
# with self.send_queue_cv:
|
||||
# self.send_queue.append([tensor_id, remote_address, tensor])
|
||||
# self.send_queue_cv.notify()
|
||||
logger.warning(
|
||||
"🔴Send Tensor, Peer Out Of Memory/Threshold, %s 👉 %s, "
|
||||
"MyRank:%s, data:%s, tensor:%s, size:%fGB, response:%s",
|
||||
self.zmq_address, remote_address, rank, data, tensor.shape,
|
||||
tensor.element_size() * tensor.numel() / 1024**3,
|
||||
response.decode())
|
||||
return False
|
||||
|
||||
self._send(comm, tensor.to(self.device), rank ^ 1, self.send_stream)
|
||||
logger.info("🔵Send Tensor, %s👉%s, MyRank:%s, data:%s, tensor:%s",
|
||||
self.zmq_address, remote_address, rank, data, tensor.shape)
|
||||
return True
|
||||
|
||||
def _ping(self):
|
||||
sock = self.context.socket(zmq.DEALER)
|
||||
sock.setsockopt_string(zmq.IDENTITY, self.zmq_address)
|
||||
logger.debug("ping start, zmq_address:%s", self.zmq_address)
|
||||
sock.connect(f"tcp://{self.proxy_address}")
|
||||
data = {
|
||||
"type": "P" if self.config.is_kv_producer else "D",
|
||||
"http_address": self.http_address,
|
||||
"zmq_address": self.zmq_address
|
||||
}
|
||||
while True:
|
||||
sock.send(msgpack.dumps(data))
|
||||
time.sleep(3)
|
||||
|
||||
def _send(self, comm, tensor: torch.Tensor, dst: int, stream=None):
|
||||
assert tensor.device == self.device, (
|
||||
f"this nccl communicator is created to work on {self.device}, "
|
||||
f"but the input tensor is on {tensor.device}")
|
||||
if stream is None:
|
||||
stream = current_stream()
|
||||
|
||||
with torch.cuda.stream(stream):
|
||||
self.nccl.ncclSend(buffer_type(tensor.data_ptr()), tensor.numel(),
|
||||
ncclDataTypeEnum.from_torch(tensor.dtype), dst,
|
||||
comm, cudaStream_t(stream.cuda_stream))
|
||||
|
||||
def _recv(self, comm, tensor: torch.Tensor, src: int, stream=None):
|
||||
assert tensor.device == self.device, (
|
||||
f"this nccl communicator is created to work on {self.device}, "
|
||||
f"but the input tensor is on {tensor.device}")
|
||||
if stream is None:
|
||||
stream = current_stream()
|
||||
|
||||
with torch.cuda.stream(stream):
|
||||
self.nccl.ncclRecv(buffer_type(tensor.data_ptr()), tensor.numel(),
|
||||
ncclDataTypeEnum.from_torch(tensor.dtype), src,
|
||||
comm, cudaStream_t(stream.cuda_stream))
|
||||
|
||||
def close(self) -> None:
|
||||
self._listener_thread.join()
|
||||
if self.send_type == "PUT_ASYNC":
|
||||
self._send_thread.join()
|
||||
if self._ping_thread is not None:
|
||||
self._ping_thread.join()
|
||||
279
vllm/distributed/kv_transfer/kv_pipe/pynccl_pipe.py
Normal file
279
vllm/distributed/kv_transfer/kv_pipe/pynccl_pipe.py
Normal file
@@ -0,0 +1,279 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
"""
|
||||
This module implements a PyNccl pipe for sending and receiving
|
||||
Optional[torch.Tensor] between distributed ranks with advanced
|
||||
communication features.
|
||||
|
||||
Key Features:
|
||||
- Supports sending and receiving tensors with metadata
|
||||
- Handles both CUDA and CPU device communications
|
||||
- Implements a non-blocking tensor transfer mechanism
|
||||
- Manages buffer size and provides backpressure control
|
||||
- Supports distributed process groups with configurable parameters
|
||||
"""
|
||||
|
||||
import threading
|
||||
import time
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from typing import Callable, Dict, Optional, Tuple
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.config import KVTransferConfig
|
||||
from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator
|
||||
from vllm.distributed.kv_transfer.kv_pipe.base import KVPipeBase
|
||||
from vllm.distributed.utils import StatelessProcessGroup
|
||||
from vllm.logger import init_logger
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class BrokenPipeException(Exception):
|
||||
|
||||
def __init__(self, message):
|
||||
self.message = message
|
||||
super().__init__(self.message)
|
||||
|
||||
|
||||
Metadata = Dict[str, Optional[torch.Tensor]]
|
||||
|
||||
|
||||
class PyNcclPipe(KVPipeBase):
|
||||
|
||||
METADATA_LENGTH = 16
|
||||
MAX_TENSOR_DIMENSIONS = 14
|
||||
METADATA_DTYPE = torch.int64
|
||||
|
||||
def __init__(self,
|
||||
local_rank: int,
|
||||
config: KVTransferConfig,
|
||||
device: Optional[str] = None,
|
||||
port_offset: int = 0):
|
||||
self.config = config
|
||||
self.local_rank = local_rank
|
||||
self.kv_rank = self.config.kv_rank
|
||||
self.kv_parallel_size = self.config.kv_parallel_size
|
||||
if device is None:
|
||||
self.device = self._select_device(self.config.kv_buffer_device)
|
||||
else:
|
||||
self.device = self._select_device(device)
|
||||
|
||||
# build distributed connection and send/recv implementation
|
||||
store_timeout = self.config.get_from_extra_config("store_timeout", 300)
|
||||
self.group = StatelessProcessGroup.create(
|
||||
host=self.config.kv_ip,
|
||||
port=self.config.kv_port + port_offset,
|
||||
rank=self.kv_rank,
|
||||
world_size=self.kv_parallel_size,
|
||||
store_timeout=store_timeout,
|
||||
)
|
||||
# add a barrier to make sure the connection is initiated properly
|
||||
self.group.barrier()
|
||||
impl = self._get_device_send_recv_impl(self.group)
|
||||
self.device_send_func, self.device_recv_func = impl
|
||||
# set target rank
|
||||
self.target_rank_for_send = (self.kv_rank + 1) % self.kv_parallel_size
|
||||
self.target_rank_for_recv = (self.kv_rank - 1) % self.kv_parallel_size
|
||||
|
||||
# transportation-related variables
|
||||
self.transport_thread: Optional[ThreadPoolExecutor] = None
|
||||
self.buffer_size = 0
|
||||
self.buffer_size_lock = threading.Lock()
|
||||
self.buffer_size_thresh = self.config.kv_buffer_size
|
||||
|
||||
def _get_device_send_recv_impl(
|
||||
self, group: StatelessProcessGroup
|
||||
) -> Tuple[Callable[[torch.Tensor, int], None], Callable[
|
||||
[torch.Tensor, int], None]]:
|
||||
|
||||
send: Callable[[torch.Tensor, int], None]
|
||||
recv: Callable[[torch.Tensor, int], None]
|
||||
if self.device.type == "cuda":
|
||||
# use PyNCCL for send / recv
|
||||
comm = PyNcclCommunicator(group, device=self.local_rank)
|
||||
comm.disabled = False
|
||||
send, recv = comm.send, comm.recv # type: ignore
|
||||
else:
|
||||
# This send / recv implementation here is NOT intended to transfer
|
||||
# KV caches (and should NOT be repurposed to transfer KV caches).
|
||||
# Currently it is only used to transmit control-plane messages
|
||||
# for PyNcclBuffer.
|
||||
send = group.send_obj
|
||||
|
||||
def my_recv(x, src):
|
||||
x[...] = group.recv_obj(src)
|
||||
|
||||
recv = my_recv
|
||||
|
||||
return send, recv
|
||||
|
||||
def _select_device(self, device: str):
|
||||
logger.info("Selecting device: %s", device)
|
||||
if device == "cuda":
|
||||
return torch.device(f"cuda:{self.local_rank}")
|
||||
else:
|
||||
return torch.device("cpu")
|
||||
|
||||
def _make_metadata(self, tensor: Optional[torch.Tensor]) -> Metadata:
|
||||
"""
|
||||
Create the metadata as a dictionary based on the input tensor.
|
||||
|
||||
Parameters:
|
||||
- tensor: The input tensor or None if no tensor is provided.
|
||||
|
||||
Returns:
|
||||
- metadata: A dictionary with the following keys:
|
||||
- "dtype": The data type of the tensor or None.
|
||||
- "shape": The shape of the tensor or None.
|
||||
"""
|
||||
if tensor is None:
|
||||
return {"dtype": None, "shape": None}
|
||||
else:
|
||||
return {"dtype": tensor.dtype, "shape": tensor.shape}
|
||||
|
||||
def _prepare_recv_buffer(self, metadata: Metadata) -> torch.Tensor:
|
||||
"""
|
||||
Create a buffer to receive the tensor based on the provided metadata.
|
||||
|
||||
Parameters:
|
||||
- metadata: A dictionary with keys "dtype" and "shape", describing
|
||||
the tensor's data type and shape.
|
||||
|
||||
Returns:
|
||||
- buffer: A tensor of the specified type and shape, allocated on
|
||||
self.device.
|
||||
"""
|
||||
return torch.empty(metadata["shape"],
|
||||
dtype=metadata["dtype"],
|
||||
device=self.device)
|
||||
|
||||
def _send_metadata(self, metadata: Metadata):
|
||||
"""
|
||||
Send the metadata dictionary to the target rank.
|
||||
|
||||
Parameters:
|
||||
- metadata: A dictionary with keys "dtype" and "shape".
|
||||
"""
|
||||
self.group.send_obj(metadata, self.target_rank_for_send)
|
||||
|
||||
def _recv_metadata(self) -> Metadata:
|
||||
"""
|
||||
Receive the metadata dictionary from the target rank.
|
||||
|
||||
Returns:
|
||||
- metadata: A dictionary with keys "dtype" and "shape" describing
|
||||
the tensor.
|
||||
"""
|
||||
return self.group.recv_obj(self.target_rank_for_recv)
|
||||
|
||||
def _send_impl(self, tensor: Optional[torch.Tensor]) -> None:
|
||||
"""
|
||||
The actual implementation of sending the tensor and its metadata to the
|
||||
target rank.
|
||||
|
||||
Parameters:
|
||||
- tensor: The input tensor to be sent, or None if no tensor is
|
||||
being sent.
|
||||
"""
|
||||
metadata = self._make_metadata(tensor)
|
||||
self._send_metadata(metadata)
|
||||
if tensor is not None:
|
||||
self.device_send_func(tensor.to(self.device),
|
||||
self.target_rank_for_send)
|
||||
|
||||
def _recv_impl(self) -> Optional[torch.Tensor]:
|
||||
"""
|
||||
The actual implementation of receiving a tensor and its metadata from
|
||||
the target rank.
|
||||
|
||||
Returns:
|
||||
- buffer: The received tensor, or None if no tensor is received.
|
||||
"""
|
||||
metadata = self._recv_metadata()
|
||||
if metadata["dtype"] is None:
|
||||
return None
|
||||
buffer = self._prepare_recv_buffer(metadata)
|
||||
self.device_recv_func(buffer, self.target_rank_for_recv)
|
||||
|
||||
return buffer
|
||||
|
||||
def send_tensor_wrapper(self, tensor: Optional[torch.Tensor],
|
||||
tensor_size: int) -> None:
|
||||
"""
|
||||
Wrapper for _send_impl to handle exceptions and update buffer size.
|
||||
"""
|
||||
try:
|
||||
self._send_impl(tensor)
|
||||
|
||||
with self.buffer_size_lock:
|
||||
self.buffer_size -= tensor_size
|
||||
except Exception as e:
|
||||
logger.error("[rank%d]: Exception when trying to send %s, msg: %s",
|
||||
torch.distributed.get_rank(), str(tensor), str(e))
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
|
||||
def block_if_full(self):
|
||||
"""
|
||||
Block the current thread if the buffer size is larger than the
|
||||
threshold.
|
||||
"""
|
||||
while self.buffer_size > self.buffer_size_thresh:
|
||||
logger.debug("KV cache transfer pipe is full. Waiting...")
|
||||
time.sleep(0.05)
|
||||
|
||||
def send_tensor(self, tensor: Optional[torch.Tensor]) -> None:
|
||||
"""
|
||||
Sends a tensor and its metadata to the destination rank in a
|
||||
non-blocking way.
|
||||
|
||||
Parameters:
|
||||
- tensor: The tensor to send, or None if no tensor is being sent.
|
||||
"""
|
||||
if self.transport_thread is None:
|
||||
self.transport_thread = ThreadPoolExecutor(max_workers=1)
|
||||
|
||||
if tensor is not None:
|
||||
tensor_size = tensor.element_size() * tensor.numel()
|
||||
else:
|
||||
tensor_size = 0
|
||||
|
||||
self.block_if_full()
|
||||
|
||||
with self.buffer_size_lock:
|
||||
self.buffer_size += tensor_size
|
||||
|
||||
self.transport_thread.submit(self.send_tensor_wrapper, tensor,
|
||||
tensor_size)
|
||||
|
||||
def recv_tensor(self) -> Optional[torch.Tensor]:
|
||||
"""
|
||||
Receives a tensor and its metadata from the source rank. Blocking call.
|
||||
|
||||
Returns:
|
||||
- tensor: The received tensor, or None if no tensor is received.
|
||||
"""
|
||||
if self.transport_thread is None:
|
||||
self.transport_thread = ThreadPoolExecutor(max_workers=1)
|
||||
|
||||
future = self.transport_thread.submit(self._recv_impl)
|
||||
|
||||
try:
|
||||
tensor = future.result()
|
||||
except Exception as e:
|
||||
logger.error("Encountering exception in KV receiving thread")
|
||||
logger.error("%s", e)
|
||||
logger.error("My device: %s", self.device)
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
raise e
|
||||
|
||||
return tensor
|
||||
|
||||
def close(self):
|
||||
"""
|
||||
Close the pipe and release associated resources.
|
||||
"""
|
||||
if hasattr(self,
|
||||
"transport_thread") and self.transport_thread is not None:
|
||||
self.transport_thread.shutdown()
|
||||
Reference in New Issue
Block a user