[PD Perf] replace Queue with FastQueue (#6649)

Signed-off-by: Shangming Cai <caishangming@linux.alibaba.com>
Co-authored-by: Shangming Cai <caishangming@linux.alibaba.com>
This commit is contained in:
ybyang
2025-05-28 16:37:51 +08:00
committed by GitHub
parent b1c8d4e9f3
commit 6b231325b9
2 changed files with 171 additions and 137 deletions

View File

@@ -31,6 +31,7 @@ from sglang.srt.disaggregation.base.conn import (
from sglang.srt.disaggregation.mooncake.transfer_engine import MooncakeTransferEngine from sglang.srt.disaggregation.mooncake.transfer_engine import MooncakeTransferEngine
from sglang.srt.disaggregation.utils import ( from sglang.srt.disaggregation.utils import (
DisaggregationMode, DisaggregationMode,
FastQueue,
group_concurrent_contiguous, group_concurrent_contiguous,
) )
from sglang.srt.server_args import ServerArgs from sglang.srt.server_args import ServerArgs
@@ -151,7 +152,6 @@ class MooncakeKVManager(BaseKVManager):
self.server_socket = zmq.Context().socket(zmq.PULL) self.server_socket = zmq.Context().socket(zmq.PULL)
self.register_buffer_to_engine() self.register_buffer_to_engine()
if self.disaggregation_mode == DisaggregationMode.PREFILL: if self.disaggregation_mode == DisaggregationMode.PREFILL:
self.transfer_queue = queue.Queue()
self.transfer_infos: Dict[int, Dict[str, TransferInfo]] = {} self.transfer_infos: Dict[int, Dict[str, TransferInfo]] = {}
self.decode_kv_args_table: Dict[str, KVArgsRegisterInfo] = {} self.decode_kv_args_table: Dict[str, KVArgsRegisterInfo] = {}
self.start_prefill_thread() self.start_prefill_thread()
@@ -159,15 +159,31 @@ class MooncakeKVManager(BaseKVManager):
self.session_failures = defaultdict(int) self.session_failures = defaultdict(int)
self.failed_sessions = set() self.failed_sessions = set()
self.session_lock = threading.Lock() self.session_lock = threading.Lock()
# Determine the number of threads to use for kv sender # Determine the number of threads to use for kv sender
cpu_count = os.cpu_count() cpu_count = os.cpu_count()
self.executor = concurrent.futures.ThreadPoolExecutor( transfer_thread_pool_size = get_int_env_var(
get_int_env_var( "SGLANG_DISAGGREGATION_THREAD_POOL_SIZE",
"SGLANG_DISAGGREGATION_THREAD_POOL_SIZE", min(max(4, int(0.75 * cpu_count) // 8), 12),
min(max(1, cpu_count // 8), 8),
)
) )
transfer_queue_size = get_int_env_var("SGLANG_DISAGGREGATION_QUEUE_SIZE", 4)
self.transfer_queues: List[FastQueue] = [
FastQueue() for _ in range(transfer_queue_size)
]
assert transfer_thread_pool_size >= transfer_queue_size, (
f"The environment variable SGLANG_DISAGGREGATION_THREAD_POOL_SIZE={transfer_thread_pool_size} must be "
f"greater than or equal to SGLANG_DISAGGREGATION_QUEUE_SIZE={transfer_queue_size}."
)
self.executors = [
concurrent.futures.ThreadPoolExecutor(
transfer_thread_pool_size // transfer_queue_size
)
for _ in range(transfer_queue_size)
]
for queue, executor in zip(self.transfer_queues, self.executors):
threading.Thread(
target=self.transfer_worker, args=(queue, executor), daemon=True
).start()
self.bootstrap_time_out = get_int_env_var( self.bootstrap_time_out = get_int_env_var(
"SGLANG_DISAGGREGATION_BOOTSTRAP_TIMEOUT", 30 "SGLANG_DISAGGREGATION_BOOTSTRAP_TIMEOUT", 30
) )
@@ -183,7 +199,7 @@ class MooncakeKVManager(BaseKVManager):
) )
# Heartbeat failure should be at least 1 # Heartbeat failure should be at least 1
self.max_failures = max( self.max_failures = max(
int(os.getenv("SGLANG_DISAGGREGATION_HEARTBEAT_MAX_FAILURE", 2)), 1 get_int_env_var("SGLANG_DISAGGREGATION_HEARTBEAT_MAX_FAILURE", 2), 1
) )
self.start_decode_thread() self.start_decode_thread()
self.connection_pool: Dict[str, Dict[str, Union[str, int]]] = {} self.connection_pool: Dict[str, Dict[str, Union[str, int]]] = {}
@@ -220,6 +236,7 @@ class MooncakeKVManager(BaseKVManager):
prefill_kv_indices: npt.NDArray[np.int64], prefill_kv_indices: npt.NDArray[np.int64],
dst_kv_ptrs: list[int], dst_kv_ptrs: list[int],
dst_kv_indices: npt.NDArray[np.int64], dst_kv_indices: npt.NDArray[np.int64],
executor: concurrent.futures.ThreadPoolExecutor,
): ):
# Group by indices # Group by indices
prefill_kv_blocks, dst_kv_blocks = group_concurrent_contiguous( prefill_kv_blocks, dst_kv_blocks = group_concurrent_contiguous(
@@ -251,7 +268,7 @@ class MooncakeKVManager(BaseKVManager):
return 0 return 0
futures = [ futures = [
self.executor.submit( executor.submit(
process_layer, process_layer,
src_ptr, src_ptr,
dst_ptr, dst_ptr,
@@ -298,6 +315,123 @@ class MooncakeKVManager(BaseKVManager):
] ]
) )
def transfer_worker(
    self, queue: FastQueue, executor: concurrent.futures.ThreadPoolExecutor
):
    """Dedicated worker loop for one transfer-queue shard.

    Blocks on *queue* for ``TransferKVChunk`` items and, for each chunk,
    sends the KV cache (and the aux data on the last chunk) to every decode
    endpoint registered under the chunk's room, fanning per-layer transfers
    out onto *executor* via ``send_kvcache``.

    NOTE: the parameter name ``queue`` shadows the stdlib ``queue`` module
    inside this function; the module is not reachable here.
    """
    while True:
        try:
            kv_chunk: TransferKVChunk = queue.get()
            reqs_to_be_processed = (
                self.transfer_infos[kv_chunk.room].values()
                if kv_chunk.room in self.transfer_infos
                else []
            )
            polls = []
            dst_ranks_infos = []
            for req in reqs_to_be_processed:
                if not req.is_dummy:
                    # Early exit if the request has failed
                    with self.session_lock:
                        if req.mooncake_session_id in self.failed_sessions:
                            self.record_failure(
                                kv_chunk.room,
                                f"Decode instance could be dead, remote mooncake session {req.mooncake_session_id} is not alive",
                            )
                            self.update_status(kv_chunk.room, KVPoll.Failed)
                            self.sync_status_to_decode_endpoint(
                                req.endpoint,
                                req.dst_port,
                                req.room,
                                KVPoll.Failed,
                            )
                            break
                    chunked_dst_kv_indice = req.dst_kv_indices[kv_chunk.index_slice]
                    # NOTE: This is temporarily a workaround to deal with the case where the prefill_kv_indices
                    # is mismatched with the dst_kv_indices when page size > 1, this should never happen.
                    if len(chunked_dst_kv_indice) < len(kv_chunk.prefill_kv_indices):
                        # BUGFIX: truncate with a slice; the previous plain
                        # index `[len(chunked_dst_kv_indice)]` would replace
                        # the index array with a single scalar element.
                        kv_chunk.prefill_kv_indices = kv_chunk.prefill_kv_indices[
                            : len(chunked_dst_kv_indice)
                        ]
                        logger.warning(
                            f"len(chunked_dst_kv_indice) = {len(chunked_dst_kv_indice)}, len(kv_chunk.prefill_kv_indices) = {len(kv_chunk.prefill_kv_indices)}"
                        )
                    ret = self.send_kvcache(
                        req.mooncake_session_id,
                        kv_chunk.prefill_kv_indices,
                        self.decode_kv_args_table[
                            req.mooncake_session_id
                        ].dst_kv_ptrs,
                        chunked_dst_kv_indice,
                        executor,
                    )
                    if ret != 0:
                        with self.session_lock:
                            self.session_failures[req.mooncake_session_id] += 1
                            # Failures should never happen if the session is not dead, if the session fails once, mark it as failed
                            if self.session_failures[req.mooncake_session_id] >= 1:
                                self.failed_sessions.add(req.mooncake_session_id)
                                logger.error(
                                    f"Session {req.mooncake_session_id} failed."
                                )
                        self.record_failure(
                            kv_chunk.room,
                            f"Failed to send kv chunk of {kv_chunk.room} to {req.endpoint}:{req.dst_port}",
                        )
                        self.update_status(kv_chunk.room, KVPoll.Failed)
                        self.sync_status_to_decode_endpoint(
                            req.endpoint, req.dst_port, req.room, KVPoll.Failed
                        )
                        break
                    if kv_chunk.is_last:
                        # Only the last chunk we need to send the aux data
                        ret = self.send_aux(
                            req.mooncake_session_id,
                            kv_chunk.prefill_aux_index,
                            self.decode_kv_args_table[
                                req.mooncake_session_id
                            ].dst_aux_ptrs,
                            req.dst_aux_index,
                        )
                        polls.append(True if ret == 0 else False)
                        dst_ranks_infos.append(
                            (req.endpoint, req.dst_port, req.room)
                        )
                        # Only sync status when all the dst ranks have received the kvcache
                        if len(polls) == req.required_dst_info_num:
                            status = KVPoll.Success if all(polls) else KVPoll.Failed
                            self.update_status(req.room, status)
                            for endpoint, dst_port, room in dst_ranks_infos:
                                self.sync_status_to_decode_endpoint(
                                    endpoint, dst_port, room, status
                                )
                else:
                    # Dummy request means the decode instance is not used, so its status can be marked as success directly
                    # Dummy request does not need to sync status to decode endpoint
                    if kv_chunk.is_last and req.room in self.request_status:
                        self.update_status(req.room, KVPoll.Success)
            if (
                kv_chunk.room not in self.request_status
                or self.check_status(kv_chunk.room) == KVPoll.Success
            ):
                if kv_chunk.room in self.transfer_infos:
                    self.transfer_infos.pop(kv_chunk.room)
        except Exception as e:
            # BUGFIX: the old `except queue.Empty: continue` clause was removed.
            # FastQueue.get() blocks instead of raising, and `queue` here names
            # the FastQueue instance (not the stdlib module), so evaluating
            # `queue.Empty` while handling an exception raised AttributeError
            # and masked the real failure.
            # NOTE(shangming): Remove this when we make sure the transfer thread is bug-free
            raise RuntimeError(
                f"Transfer thread failed because of {e}. Prefill instance with bootstrap_port={self.bootstrap_port} is dead."
            ) from e
def start_prefill_thread(self): def start_prefill_thread(self):
self.rank_port = get_free_port() self.rank_port = get_free_port()
self.server_socket.bind(f"tcp://{get_local_ip_by_remote()}:{self.rank_port}") self.server_socket.bind(f"tcp://{get_local_ip_by_remote()}:{self.rank_port}")
@@ -335,134 +469,7 @@ class MooncakeKVManager(BaseKVManager):
if len(self.transfer_infos[room]) == required_dst_info_num: if len(self.transfer_infos[room]) == required_dst_info_num:
self.update_status(room, KVPoll.WaitingForInput) self.update_status(room, KVPoll.WaitingForInput)
def transfer_thread():
# TODO: Shall we use KVPoll.Transferring state?
while True:
try:
kv_chunk: TransferKVChunk = self.transfer_queue.get(timeout=0.01)
reqs_to_be_processed = (
self.transfer_infos[kv_chunk.room].values()
if kv_chunk.room in self.transfer_infos
else []
)
polls = []
dst_ranks_infos = []
for req in reqs_to_be_processed:
if not req.is_dummy:
# Early exit if the request has failed
with self.session_lock:
if req.mooncake_session_id in self.failed_sessions:
self.record_failure(
kv_chunk.room,
f"Decode instance could be dead, remote mooncake session {req.mooncake_session_id} is not alive",
)
self.update_status(kv_chunk.room, KVPoll.Failed)
self.sync_status_to_decode_endpoint(
req.endpoint,
req.dst_port,
req.room,
KVPoll.Failed,
)
break
chunked_dst_kv_indice = req.dst_kv_indices[
kv_chunk.index_slice
]
# NOTE: This is temporarily a workaround to deal with the case where the prefill_kv_indices
# is mismatched with the dst_kv_indices when page size > 1, this should never happen.
if len(chunked_dst_kv_indice) < len(
kv_chunk.prefill_kv_indices
):
kv_chunk.prefill_kv_indices = (
kv_chunk.prefill_kv_indices[
len(chunked_dst_kv_indice)
]
)
logger.warning(
f"len(chunked_dst_kv_indice) = {len(chunked_dst_kv_indice)}, len(kv_chunk.prefill_kv_indices) = {len(kv_chunk.prefill_kv_indices)}"
)
ret = self.send_kvcache(
req.mooncake_session_id,
kv_chunk.prefill_kv_indices,
self.decode_kv_args_table[
req.mooncake_session_id
].dst_kv_ptrs,
chunked_dst_kv_indice,
)
if ret != 0:
with self.session_lock:
self.session_failures[req.mooncake_session_id] += 1
# Failures should never happen if the session is not dead, if the session fails once, mark it as failed
if (
self.session_failures[req.mooncake_session_id]
>= 1
):
self.failed_sessions.add(
req.mooncake_session_id
)
logger.error(
f"Session {req.mooncake_session_id} failed."
)
self.record_failure(
kv_chunk.room,
f"Failed to send kv chunk of {kv_chunk.room} to {req.endpoint}:{req.dst_port}",
)
self.update_status(kv_chunk.room, KVPoll.Failed)
self.sync_status_to_decode_endpoint(
req.endpoint, req.dst_port, req.room, KVPoll.Failed
)
break
if kv_chunk.is_last:
# Only the last chunk we need to send the aux data
ret = self.send_aux(
req.mooncake_session_id,
kv_chunk.prefill_aux_index,
self.decode_kv_args_table[
req.mooncake_session_id
].dst_aux_ptrs,
req.dst_aux_index,
)
polls.append(True if ret == 0 else False)
dst_ranks_infos.append(
(req.endpoint, req.dst_port, req.room)
)
# Only sync status when all the dst ranks have received the kvcache
if len(polls) == req.required_dst_info_num:
status = (
KVPoll.Success if all(polls) else KVPoll.Failed
)
self.update_status(req.room, status)
for endpoint, dst_port, room in dst_ranks_infos:
self.sync_status_to_decode_endpoint(
endpoint, dst_port, room, status
)
else:
# Dummy request means the decode instance is not used, so its status can be marked as success directly
# Dummy request does not need to sync status to decode endpoint
if kv_chunk.is_last and req.room in self.request_status:
self.update_status(req.room, KVPoll.Success)
if (
kv_chunk.room not in self.request_status
or self.check_status(kv_chunk.room) == KVPoll.Success
):
if kv_chunk.room in self.transfer_infos:
self.transfer_infos.pop(kv_chunk.room)
except queue.Empty:
continue
except Exception as e:
# NOTE(shangming): Remove this when we make sure the transfer thread is bug-free
raise RuntimeError(
f"Transfer thread failed because of {e}. Prefill instance with bootstrap_port={self.bootstrap_port} is dead."
)
threading.Thread(target=bootstrap_thread).start() threading.Thread(target=bootstrap_thread).start()
threading.Thread(target=transfer_thread).start()
def start_decode_thread(self): def start_decode_thread(self):
self.rank_port = get_free_port() self.rank_port = get_free_port()
@@ -555,7 +562,14 @@ class MooncakeKVManager(BaseKVManager):
) )
return return
self.transfer_queue.put( # NOTE(shangming): sharding according to the dst_infos to make sure
# requests with the same dst_sessions will be added into the same
# queue, which enables early abort with failed sessions.
dst_infos = self.transfer_infos[bootstrap_room].keys()
session_port_sum = sum(int(session.split(":")[1]) for session in dst_infos)
shard_idx = session_port_sum % len(self.transfer_queues)
self.transfer_queues[shard_idx].put(
TransferKVChunk( TransferKVChunk(
room=bootstrap_room, room=bootstrap_room,
prefill_kv_indices=kv_indices, prefill_kv_indices=kv_indices,

View File

@@ -3,6 +3,7 @@ from __future__ import annotations
import dataclasses import dataclasses
import os import os
import random import random
import threading
import warnings import warnings
from collections import deque from collections import deque
from enum import Enum from enum import Enum
@@ -281,6 +282,25 @@ class MetadataBuffers:
) )
class FastQueue:
    """Minimal unbounded FIFO queue built on a deque and a condition variable.

    Lighter than ``queue.Queue``: no maxsize bookkeeping, no task tracking,
    no timeouts — just blocking ``get`` and non-blocking ``put``.
    """

    def __init__(self):
        self._items = deque()
        self._not_empty = threading.Condition()

    def put(self, item):
        """Append *item* and wake a single consumer blocked in get()."""
        with self._not_empty:
            self._items.append(item)
            self._not_empty.notify()

    def get(self):
        """Remove and return the oldest item, blocking while the queue is empty."""
        with self._not_empty:
            while not self._items:
                self._not_empty.wait()
            return self._items.popleft()
def group_concurrent_contiguous( def group_concurrent_contiguous(
src_indices: npt.NDArray[np.int64], dst_indices: npt.NDArray[np.int64] src_indices: npt.NDArray[np.int64], dst_indices: npt.NDArray[np.int64]
) -> Tuple[List[npt.NDArray[np.int64]], List[npt.NDArray[np.int64]]]: ) -> Tuple[List[npt.NDArray[np.int64]], List[npt.NDArray[np.int64]]]: