[PD] Add a kvargs table and a thread pool for the Mooncake kvcache sender (#5738)

Signed-off-by: Shangming Cai <caishangming@linux.alibaba.com>
This commit is contained in:
shangmingc
2025-04-25 18:15:01 +08:00
committed by GitHub
parent c55550cbf0
commit 50eda8398e

View File

@@ -1,8 +1,10 @@
from __future__ import annotations from __future__ import annotations
import asyncio import asyncio
import concurrent.futures
import dataclasses import dataclasses
import logging import logging
import os
import queue import queue
import socket import socket
import struct import struct
@@ -73,9 +75,7 @@ class TransferInfo:
endpoint: str endpoint: str
dst_port: int dst_port: int
mooncake_session_id: str mooncake_session_id: str
dst_kv_ptrs: list[int]
dst_kv_indices: npt.NDArray[np.int64] dst_kv_indices: npt.NDArray[np.int64]
dst_aux_ptrs: list[int]
dst_aux_index: int dst_aux_index: int
@classmethod @classmethod
@@ -85,10 +85,29 @@ class TransferInfo:
endpoint=msg[1].decode("ascii"), endpoint=msg[1].decode("ascii"),
dst_port=int(msg[2].decode("ascii")), dst_port=int(msg[2].decode("ascii")),
mooncake_session_id=msg[3].decode("ascii"), mooncake_session_id=msg[3].decode("ascii"),
dst_kv_indices=np.frombuffer(msg[4], dtype=np.int64),
dst_aux_index=int(msg[5].decode("ascii")),
)
@dataclasses.dataclass
class KVArgsRegisterInfo:
    """One-time registration message a decode rank sends to the prefill
    KVManager, carrying the destination buffer base pointers so that later
    per-request transfer messages do not need to repeat them.
    """

    # Sentinel room string ("None") marking this as a registration message.
    room: str
    endpoint: str
    dst_port: int
    mooncake_session_id: str
    # Destination KV-cache base pointers, one per layer buffer.
    dst_kv_ptrs: list[int]
    # Destination auxiliary-data base pointers.
    dst_aux_ptrs: list[int]

    @classmethod
    def from_zmq(cls, msg: List[bytes]):
        """Decode a registration message from its ZMQ multipart frames.

        Frame layout: [room, endpoint, port, session_id,
        packed kv ptrs, packed aux ptrs] — pointers packed as 8-byte
        unsigned integers ("Q"), matching the sender in _register_kv_args.
        """
        return cls(
            # .decode() already yields str; no extra str() wrap needed.
            room=msg[0].decode("ascii"),
            endpoint=msg[1].decode("ascii"),
            dst_port=int(msg[2].decode("ascii")),
            mooncake_session_id=msg[3].decode("ascii"),
            dst_kv_ptrs=list(struct.unpack(f"{len(msg[4])//8}Q", msg[4])),
            dst_aux_ptrs=list(struct.unpack(f"{len(msg[5])//8}Q", msg[5])),
        )
@@ -123,8 +142,15 @@ class MooncakeKVManager(BaseKVManager):
if self.disaggregation_mode == DisaggregationMode.PREFILL: if self.disaggregation_mode == DisaggregationMode.PREFILL:
self.transfer_queue = queue.Queue() self.transfer_queue = queue.Queue()
self.transfer_infos: Dict[int, TransferInfo] = {} self.transfer_infos: Dict[int, TransferInfo] = {}
self.decode_kv_args_table: Dict[str, KVArgsRegisterInfo] = {}
self.start_prefill_thread() self.start_prefill_thread()
self._register_to_bootstrap() self._register_to_bootstrap()
# Determine the number of threads to use for kv sender
cpu_count = os.cpu_count()
self.executor = concurrent.futures.ThreadPoolExecutor(
max_workers=cpu_count if cpu_count is not None else 64
)
elif self.disaggregation_mode == DisaggregationMode.DECODE: elif self.disaggregation_mode == DisaggregationMode.DECODE:
self.start_decode_thread() self.start_decode_thread()
self.connection_pool: Dict[str, Dict[str, Union[str, int]]] = {} self.connection_pool: Dict[str, Dict[str, Union[str, int]]] = {}
@@ -158,28 +184,53 @@ class MooncakeKVManager(BaseKVManager):
def send_kvcache(
    self,
    mooncake_session_id: str,
    prefill_kv_indices: npt.NDArray[np.int64],
    dst_kv_ptrs: list[int],
    dst_kv_indices: npt.NDArray[np.int64],
):
    """Transfer the KV cache for one request to the decode side, one
    thread-pool task per layer.

    Returns 0 on success, or the first nonzero status reported by
    engine.transfer_sync on failure.
    """
    # Group contiguous index runs so each run moves as one transfer.
    prefill_kv_blocks, dst_kv_blocks = group_concurrent_contiguous(
        prefill_kv_indices, dst_kv_indices
    )
    num_layers = len(self.kv_args.kv_data_ptrs)
    layers_params = [
        (
            self.kv_args.kv_data_ptrs[layer_id],
            dst_kv_ptrs[layer_id],
            self.kv_args.kv_item_lens[layer_id],
        )
        for layer_id in range(num_layers)
    ]

    # Worker function for processing a single layer; returns the first
    # nonzero engine status, or 0 when every block transferred cleanly.
    def process_layer(src_ptr: int, dst_ptr: int, item_len: int) -> int:
        for prefill_index, decode_index in zip(prefill_kv_blocks, dst_kv_blocks):
            src_addr = src_ptr + int(prefill_index[0]) * item_len
            dst_addr = dst_ptr + int(decode_index[0]) * item_len
            length = item_len * len(prefill_index)
            status = self.engine.transfer_sync(
                mooncake_session_id, src_addr, dst_addr, length
            )
            if status != 0:
                return status
        return 0

    futures = [
        self.executor.submit(
            process_layer,
            src_ptr,
            dst_ptr,
            item_len,
        )
        for (src_ptr, dst_ptr, item_len) in layers_params
    ]
    for future in concurrent.futures.as_completed(futures):
        status = future.result()
        if status != 0:
            # On the first error, cancel any not-yet-started layer tasks and
            # bail out. Do NOT shut down self.executor here: the pool is
            # shared for the manager's lifetime, and the original
            # `executor.shutdown(wait=False)` referenced an undefined name
            # (NameError) — and shutting the shared pool down would break
            # every subsequent send.
            for f in futures:
                f.cancel()
            return status
    return 0
@@ -223,6 +274,13 @@ class MooncakeKVManager(BaseKVManager):
waiting_req_bytes = self.server_socket.recv_multipart() waiting_req_bytes = self.server_socket.recv_multipart()
room = waiting_req_bytes[0].decode("ascii") room = waiting_req_bytes[0].decode("ascii")
if room == "None": if room == "None":
mooncake_session_id = waiting_req_bytes[3].decode("ascii")
self.decode_kv_args_table[mooncake_session_id] = (
KVArgsRegisterInfo.from_zmq(waiting_req_bytes)
)
logger.debug(
f"Register KVArgs from {mooncake_session_id} successfully"
)
continue continue
room = int(room) room = int(room)
self.transfer_infos[room] = TransferInfo.from_zmq(waiting_req_bytes) self.transfer_infos[room] = TransferInfo.from_zmq(waiting_req_bytes)
@@ -244,7 +302,7 @@ class MooncakeKVManager(BaseKVManager):
ret = self.send_kvcache( ret = self.send_kvcache(
req.mooncake_session_id, req.mooncake_session_id,
kv_chunk.prefill_kv_indices, kv_chunk.prefill_kv_indices,
req.dst_kv_ptrs, self.decode_kv_args_table[req.mooncake_session_id].dst_kv_ptrs,
chunked_dst_kv_indice, chunked_dst_kv_indice,
) )
if ret != 0: if ret != 0:
@@ -259,7 +317,9 @@ class MooncakeKVManager(BaseKVManager):
ret = self.send_aux( ret = self.send_aux(
req.mooncake_session_id, req.mooncake_session_id,
kv_chunk.prefill_aux_index, kv_chunk.prefill_aux_index,
req.dst_aux_ptrs, self.decode_kv_args_table[
req.mooncake_session_id
].dst_aux_ptrs,
req.dst_aux_index, req.dst_aux_index,
) )
self.request_status[req.room] = ( self.request_status[req.room] = (
@@ -460,6 +520,8 @@ class MooncakeKVReceiver(BaseKVReceiver):
) )
else: else:
self.kv_mgr.connection_pool[bootstrap_key] = self.bootstrap_info self.kv_mgr.connection_pool[bootstrap_key] = self.bootstrap_info
# Register kv_args only once to prefill KVManager according to the info fetched from the bootstrap server
self._register_kv_args()
else: else:
self.bootstrap_info = self.kv_mgr.connection_pool[bootstrap_key] self.bootstrap_info = self.kv_mgr.connection_pool[bootstrap_key]
@@ -502,6 +564,30 @@ class MooncakeKVReceiver(BaseKVReceiver):
logger.error(f"Error fetching prefill parallel info from bootstrap: {e}") logger.error(f"Error fetching prefill parallel info from bootstrap: {e}")
return None return None
def _register_kv_args(self):
    """Register this decode rank's buffer base pointers with the prefill
    KVManager (called once per prefill endpoint, right after the bootstrap
    info is first fetched).

    The first frame is the sentinel room "None", which the prefill side
    uses to recognize a registration message rather than a per-request
    transfer message.
    """
    self.prefill_server_url = (
        f"{self.bootstrap_info['rank_ip']}:{self.bootstrap_info['rank_port']}"
    )
    # Pack each base pointer as an 8-byte unsigned integer ("Q").
    packed_kv_data_ptrs = b"".join(
        struct.pack("Q", ptr) for ptr in self.kv_mgr.kv_args.kv_data_ptrs
    )
    packed_aux_data_ptrs = b"".join(
        struct.pack("Q", ptr) for ptr in self.kv_mgr.kv_args.aux_data_ptrs
    )
    sock, lock = self._connect("tcp://" + self.prefill_server_url)
    with lock:
        # Frame order is a wire contract: it must match the unpacking in
        # KVArgsRegisterInfo.from_zmq on the prefill side.
        sock.send_multipart(
            [
                "None".encode("ascii"),
                get_local_ip_by_remote().encode("ascii"),
                str(self.kv_mgr.rank_port).encode("ascii"),
                self.session_id.encode("ascii"),
                packed_kv_data_ptrs,
                packed_aux_data_ptrs,
            ]
        )
@classmethod @classmethod
def _connect(cls, endpoint: str): def _connect(cls, endpoint: str):
with cls._global_lock: with cls._global_lock:
@@ -520,12 +606,6 @@ class MooncakeKVReceiver(BaseKVReceiver):
f"Fetched bootstrap info: {self.bootstrap_info} for engine rank: {self.kv_mgr.kv_args.engine_rank}" f"Fetched bootstrap info: {self.bootstrap_info} for engine rank: {self.kv_mgr.kv_args.engine_rank}"
) )
packed_kv_data_ptrs = b"".join(
struct.pack("Q", ptr) for ptr in self.kv_mgr.kv_args.kv_data_ptrs
)
packed_aux_data_ptrs = b"".join(
struct.pack("Q", ptr) for ptr in self.kv_mgr.kv_args.aux_data_ptrs
)
sock, lock = self._connect("tcp://" + self.prefill_server_url) sock, lock = self._connect("tcp://" + self.prefill_server_url)
with lock: with lock:
sock.send_multipart( sock.send_multipart(
@@ -534,9 +614,7 @@ class MooncakeKVReceiver(BaseKVReceiver):
get_local_ip_by_remote().encode("ascii"), get_local_ip_by_remote().encode("ascii"),
str(self.kv_mgr.rank_port).encode("ascii"), str(self.kv_mgr.rank_port).encode("ascii"),
self.session_id.encode("ascii"), self.session_id.encode("ascii"),
packed_kv_data_ptrs,
kv_indices.tobytes(), kv_indices.tobytes(),
packed_aux_data_ptrs,
str(aux_index).encode("ascii"), str(aux_index).encode("ascii"),
] ]
) )
@@ -610,7 +688,7 @@ class MooncakeKVBootstrapServer(BaseKVBootstrapServer):
"rank_port": rank_port, "rank_port": rank_port,
} }
logger.debug( logger.debug(
f"Registered Prefill bootstrap: {engine_rank} with rank_ip: {rank_ip} and rank_port: {rank_port}" f"Register Prefill bootstrap: {engine_rank} with rank_ip: {rank_ip} and rank_port: {rank_port}"
) )
return web.Response(text="OK", status=200) return web.Response(text="OK", status=200)