[PD] Add kvargs table and thread pool for kvcache sender of mooncake (#5738)
Signed-off-by: Shangming Cai <caishangming@linux.alibaba.com>
This commit is contained in:
@@ -1,8 +1,10 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
|
import concurrent.futures
|
||||||
import dataclasses
|
import dataclasses
|
||||||
import logging
|
import logging
|
||||||
|
import os
|
||||||
import queue
|
import queue
|
||||||
import socket
|
import socket
|
||||||
import struct
|
import struct
|
||||||
@@ -73,9 +75,7 @@ class TransferInfo:
|
|||||||
endpoint: str
|
endpoint: str
|
||||||
dst_port: int
|
dst_port: int
|
||||||
mooncake_session_id: str
|
mooncake_session_id: str
|
||||||
dst_kv_ptrs: list[int]
|
|
||||||
dst_kv_indices: npt.NDArray[np.int64]
|
dst_kv_indices: npt.NDArray[np.int64]
|
||||||
dst_aux_ptrs: list[int]
|
|
||||||
dst_aux_index: int
|
dst_aux_index: int
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@@ -85,10 +85,29 @@ class TransferInfo:
|
|||||||
endpoint=msg[1].decode("ascii"),
|
endpoint=msg[1].decode("ascii"),
|
||||||
dst_port=int(msg[2].decode("ascii")),
|
dst_port=int(msg[2].decode("ascii")),
|
||||||
mooncake_session_id=msg[3].decode("ascii"),
|
mooncake_session_id=msg[3].decode("ascii"),
|
||||||
|
dst_kv_indices=np.frombuffer(msg[4], dtype=np.int64),
|
||||||
|
dst_aux_index=int(msg[5].decode("ascii")),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclasses.dataclass
|
||||||
|
class KVArgsRegisterInfo:
|
||||||
|
room: str
|
||||||
|
endpoint: str
|
||||||
|
dst_port: int
|
||||||
|
mooncake_session_id: str
|
||||||
|
dst_kv_ptrs: list[int]
|
||||||
|
dst_aux_ptrs: list[int]
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_zmq(cls, msg: List[bytes]):
|
||||||
|
return cls(
|
||||||
|
room=str(msg[0].decode("ascii")),
|
||||||
|
endpoint=msg[1].decode("ascii"),
|
||||||
|
dst_port=int(msg[2].decode("ascii")),
|
||||||
|
mooncake_session_id=msg[3].decode("ascii"),
|
||||||
dst_kv_ptrs=list(struct.unpack(f"{len(msg[4])//8}Q", msg[4])),
|
dst_kv_ptrs=list(struct.unpack(f"{len(msg[4])//8}Q", msg[4])),
|
||||||
dst_kv_indices=np.frombuffer(msg[5], dtype=np.int64),
|
dst_aux_ptrs=list(struct.unpack(f"{len(msg[5])//8}Q", msg[5])),
|
||||||
dst_aux_ptrs=list(struct.unpack(f"{len(msg[6])//8}Q", msg[6])),
|
|
||||||
dst_aux_index=int(msg[7].decode("ascii")),
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -123,8 +142,15 @@ class MooncakeKVManager(BaseKVManager):
|
|||||||
if self.disaggregation_mode == DisaggregationMode.PREFILL:
|
if self.disaggregation_mode == DisaggregationMode.PREFILL:
|
||||||
self.transfer_queue = queue.Queue()
|
self.transfer_queue = queue.Queue()
|
||||||
self.transfer_infos: Dict[int, TransferInfo] = {}
|
self.transfer_infos: Dict[int, TransferInfo] = {}
|
||||||
|
self.decode_kv_args_table: Dict[str, KVArgsRegisterInfo] = {}
|
||||||
self.start_prefill_thread()
|
self.start_prefill_thread()
|
||||||
self._register_to_bootstrap()
|
self._register_to_bootstrap()
|
||||||
|
|
||||||
|
# Determine the number of threads to use for kv sender
|
||||||
|
cpu_count = os.cpu_count()
|
||||||
|
self.executor = concurrent.futures.ThreadPoolExecutor(
|
||||||
|
max_workers=cpu_count if cpu_count is not None else 64
|
||||||
|
)
|
||||||
elif self.disaggregation_mode == DisaggregationMode.DECODE:
|
elif self.disaggregation_mode == DisaggregationMode.DECODE:
|
||||||
self.start_decode_thread()
|
self.start_decode_thread()
|
||||||
self.connection_pool: Dict[str, Dict[str, Union[str, int]]] = {}
|
self.connection_pool: Dict[str, Dict[str, Union[str, int]]] = {}
|
||||||
@@ -158,28 +184,53 @@ class MooncakeKVManager(BaseKVManager):
|
|||||||
dst_kv_ptrs: list[int],
|
dst_kv_ptrs: list[int],
|
||||||
dst_kv_indices: npt.NDArray[np.int64],
|
dst_kv_indices: npt.NDArray[np.int64],
|
||||||
):
|
):
|
||||||
# group by indices
|
# Group by indices
|
||||||
prefill_kv_blocks, dst_kv_blocks = group_concurrent_contiguous(
|
prefill_kv_blocks, dst_kv_blocks = group_concurrent_contiguous(
|
||||||
prefill_kv_indices, dst_kv_indices
|
prefill_kv_indices, dst_kv_indices
|
||||||
)
|
)
|
||||||
|
|
||||||
num_layers = len(self.kv_args.kv_data_ptrs)
|
num_layers = len(self.kv_args.kv_data_ptrs)
|
||||||
for layer_id in range(num_layers):
|
layers_params = [
|
||||||
src_ptr = self.kv_args.kv_data_ptrs[layer_id]
|
(
|
||||||
dst_ptr = dst_kv_ptrs[layer_id]
|
self.kv_args.kv_data_ptrs[layer_id],
|
||||||
item_len = self.kv_args.kv_item_lens[layer_id]
|
dst_kv_ptrs[layer_id],
|
||||||
|
self.kv_args.kv_item_lens[layer_id],
|
||||||
|
)
|
||||||
|
for layer_id in range(num_layers)
|
||||||
|
]
|
||||||
|
|
||||||
|
# Worker function for processing a single layer
|
||||||
|
def process_layer(src_ptr: int, dst_ptr: int, item_len: int) -> int:
|
||||||
for prefill_index, decode_index in zip(prefill_kv_blocks, dst_kv_blocks):
|
for prefill_index, decode_index in zip(prefill_kv_blocks, dst_kv_blocks):
|
||||||
src_addr = src_ptr + int(prefill_index[0]) * item_len
|
src_addr = src_ptr + int(prefill_index[0]) * item_len
|
||||||
dst_addr = dst_ptr + int(decode_index[0]) * item_len
|
dst_addr = dst_ptr + int(decode_index[0]) * item_len
|
||||||
length = item_len * len(prefill_index)
|
length = item_len * len(prefill_index)
|
||||||
|
|
||||||
# TODO: make async later
|
|
||||||
status = self.engine.transfer_sync(
|
status = self.engine.transfer_sync(
|
||||||
mooncake_session_id, src_addr, dst_addr, length
|
mooncake_session_id, src_addr, dst_addr, length
|
||||||
)
|
)
|
||||||
if status != 0:
|
if status != 0:
|
||||||
return status
|
return status
|
||||||
|
return 0
|
||||||
|
|
||||||
|
futures = [
|
||||||
|
self.executor.submit(
|
||||||
|
process_layer,
|
||||||
|
src_ptr,
|
||||||
|
dst_ptr,
|
||||||
|
item_len,
|
||||||
|
)
|
||||||
|
for (src_ptr, dst_ptr, item_len) in layers_params
|
||||||
|
]
|
||||||
|
|
||||||
|
for future in concurrent.futures.as_completed(futures):
|
||||||
|
status = future.result()
|
||||||
|
if status != 0:
|
||||||
|
# Immediate shutdown on first error (existing tasks will finish)
|
||||||
|
executor.shutdown(wait=False)
|
||||||
|
for f in futures:
|
||||||
|
f.cancel()
|
||||||
|
return status
|
||||||
|
|
||||||
return 0
|
return 0
|
||||||
|
|
||||||
@@ -223,6 +274,13 @@ class MooncakeKVManager(BaseKVManager):
|
|||||||
waiting_req_bytes = self.server_socket.recv_multipart()
|
waiting_req_bytes = self.server_socket.recv_multipart()
|
||||||
room = waiting_req_bytes[0].decode("ascii")
|
room = waiting_req_bytes[0].decode("ascii")
|
||||||
if room == "None":
|
if room == "None":
|
||||||
|
mooncake_session_id = waiting_req_bytes[3].decode("ascii")
|
||||||
|
self.decode_kv_args_table[mooncake_session_id] = (
|
||||||
|
KVArgsRegisterInfo.from_zmq(waiting_req_bytes)
|
||||||
|
)
|
||||||
|
logger.debug(
|
||||||
|
f"Register KVArgs from {mooncake_session_id} successfully"
|
||||||
|
)
|
||||||
continue
|
continue
|
||||||
room = int(room)
|
room = int(room)
|
||||||
self.transfer_infos[room] = TransferInfo.from_zmq(waiting_req_bytes)
|
self.transfer_infos[room] = TransferInfo.from_zmq(waiting_req_bytes)
|
||||||
@@ -244,7 +302,7 @@ class MooncakeKVManager(BaseKVManager):
|
|||||||
ret = self.send_kvcache(
|
ret = self.send_kvcache(
|
||||||
req.mooncake_session_id,
|
req.mooncake_session_id,
|
||||||
kv_chunk.prefill_kv_indices,
|
kv_chunk.prefill_kv_indices,
|
||||||
req.dst_kv_ptrs,
|
self.decode_kv_args_table[req.mooncake_session_id].dst_kv_ptrs,
|
||||||
chunked_dst_kv_indice,
|
chunked_dst_kv_indice,
|
||||||
)
|
)
|
||||||
if ret != 0:
|
if ret != 0:
|
||||||
@@ -259,7 +317,9 @@ class MooncakeKVManager(BaseKVManager):
|
|||||||
ret = self.send_aux(
|
ret = self.send_aux(
|
||||||
req.mooncake_session_id,
|
req.mooncake_session_id,
|
||||||
kv_chunk.prefill_aux_index,
|
kv_chunk.prefill_aux_index,
|
||||||
req.dst_aux_ptrs,
|
self.decode_kv_args_table[
|
||||||
|
req.mooncake_session_id
|
||||||
|
].dst_aux_ptrs,
|
||||||
req.dst_aux_index,
|
req.dst_aux_index,
|
||||||
)
|
)
|
||||||
self.request_status[req.room] = (
|
self.request_status[req.room] = (
|
||||||
@@ -460,6 +520,8 @@ class MooncakeKVReceiver(BaseKVReceiver):
|
|||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
self.kv_mgr.connection_pool[bootstrap_key] = self.bootstrap_info
|
self.kv_mgr.connection_pool[bootstrap_key] = self.bootstrap_info
|
||||||
|
# Register kv_args only once to prefill KVManager according to the info fetched from the bootstrap server
|
||||||
|
self._register_kv_args()
|
||||||
else:
|
else:
|
||||||
self.bootstrap_info = self.kv_mgr.connection_pool[bootstrap_key]
|
self.bootstrap_info = self.kv_mgr.connection_pool[bootstrap_key]
|
||||||
|
|
||||||
@@ -502,6 +564,30 @@ class MooncakeKVReceiver(BaseKVReceiver):
|
|||||||
logger.error(f"Error fetching prefill parallel info from bootstrap: {e}")
|
logger.error(f"Error fetching prefill parallel info from bootstrap: {e}")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
def _register_kv_args(self):
|
||||||
|
self.prefill_server_url = (
|
||||||
|
f"{self.bootstrap_info['rank_ip']}:{self.bootstrap_info['rank_port']}"
|
||||||
|
)
|
||||||
|
|
||||||
|
packed_kv_data_ptrs = b"".join(
|
||||||
|
struct.pack("Q", ptr) for ptr in self.kv_mgr.kv_args.kv_data_ptrs
|
||||||
|
)
|
||||||
|
packed_aux_data_ptrs = b"".join(
|
||||||
|
struct.pack("Q", ptr) for ptr in self.kv_mgr.kv_args.aux_data_ptrs
|
||||||
|
)
|
||||||
|
sock, lock = self._connect("tcp://" + self.prefill_server_url)
|
||||||
|
with lock:
|
||||||
|
sock.send_multipart(
|
||||||
|
[
|
||||||
|
"None".encode("ascii"),
|
||||||
|
get_local_ip_by_remote().encode("ascii"),
|
||||||
|
str(self.kv_mgr.rank_port).encode("ascii"),
|
||||||
|
self.session_id.encode("ascii"),
|
||||||
|
packed_kv_data_ptrs,
|
||||||
|
packed_aux_data_ptrs,
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _connect(cls, endpoint: str):
|
def _connect(cls, endpoint: str):
|
||||||
with cls._global_lock:
|
with cls._global_lock:
|
||||||
@@ -520,12 +606,6 @@ class MooncakeKVReceiver(BaseKVReceiver):
|
|||||||
f"Fetched bootstrap info: {self.bootstrap_info} for engine rank: {self.kv_mgr.kv_args.engine_rank}"
|
f"Fetched bootstrap info: {self.bootstrap_info} for engine rank: {self.kv_mgr.kv_args.engine_rank}"
|
||||||
)
|
)
|
||||||
|
|
||||||
packed_kv_data_ptrs = b"".join(
|
|
||||||
struct.pack("Q", ptr) for ptr in self.kv_mgr.kv_args.kv_data_ptrs
|
|
||||||
)
|
|
||||||
packed_aux_data_ptrs = b"".join(
|
|
||||||
struct.pack("Q", ptr) for ptr in self.kv_mgr.kv_args.aux_data_ptrs
|
|
||||||
)
|
|
||||||
sock, lock = self._connect("tcp://" + self.prefill_server_url)
|
sock, lock = self._connect("tcp://" + self.prefill_server_url)
|
||||||
with lock:
|
with lock:
|
||||||
sock.send_multipart(
|
sock.send_multipart(
|
||||||
@@ -534,9 +614,7 @@ class MooncakeKVReceiver(BaseKVReceiver):
|
|||||||
get_local_ip_by_remote().encode("ascii"),
|
get_local_ip_by_remote().encode("ascii"),
|
||||||
str(self.kv_mgr.rank_port).encode("ascii"),
|
str(self.kv_mgr.rank_port).encode("ascii"),
|
||||||
self.session_id.encode("ascii"),
|
self.session_id.encode("ascii"),
|
||||||
packed_kv_data_ptrs,
|
|
||||||
kv_indices.tobytes(),
|
kv_indices.tobytes(),
|
||||||
packed_aux_data_ptrs,
|
|
||||||
str(aux_index).encode("ascii"),
|
str(aux_index).encode("ascii"),
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
@@ -610,7 +688,7 @@ class MooncakeKVBootstrapServer(BaseKVBootstrapServer):
|
|||||||
"rank_port": rank_port,
|
"rank_port": rank_port,
|
||||||
}
|
}
|
||||||
logger.debug(
|
logger.debug(
|
||||||
f"Registered Prefill bootstrap: {engine_rank} with rank_ip: {rank_ip} and rank_port: {rank_port}"
|
f"Register Prefill bootstrap: {engine_rank} with rank_ip: {rank_ip} and rank_port: {rank_port}"
|
||||||
)
|
)
|
||||||
|
|
||||||
return web.Response(text="OK", status=200)
|
return web.Response(text="OK", status=200)
|
||||||
|
|||||||
Reference in New Issue
Block a user