diff --git a/python/sglang/srt/disaggregation/mooncake/conn.py b/python/sglang/srt/disaggregation/mooncake/conn.py index 8ab5066ec..9ebdd60f0 100644 --- a/python/sglang/srt/disaggregation/mooncake/conn.py +++ b/python/sglang/srt/disaggregation/mooncake/conn.py @@ -31,6 +31,7 @@ from sglang.srt.disaggregation.base.conn import ( from sglang.srt.disaggregation.mooncake.transfer_engine import MooncakeTransferEngine from sglang.srt.disaggregation.utils import ( DisaggregationMode, + FastQueue, group_concurrent_contiguous, ) from sglang.srt.server_args import ServerArgs @@ -151,7 +152,6 @@ class MooncakeKVManager(BaseKVManager): self.server_socket = zmq.Context().socket(zmq.PULL) self.register_buffer_to_engine() if self.disaggregation_mode == DisaggregationMode.PREFILL: - self.transfer_queue = queue.Queue() self.transfer_infos: Dict[int, Dict[str, TransferInfo]] = {} self.decode_kv_args_table: Dict[str, KVArgsRegisterInfo] = {} self.start_prefill_thread() @@ -159,15 +159,31 @@ class MooncakeKVManager(BaseKVManager): self.session_failures = defaultdict(int) self.failed_sessions = set() self.session_lock = threading.Lock() - # Determine the number of threads to use for kv sender cpu_count = os.cpu_count() - self.executor = concurrent.futures.ThreadPoolExecutor( - get_int_env_var( - "SGLANG_DISAGGREGATION_THREAD_POOL_SIZE", - min(max(1, cpu_count // 8), 8), - ) + transfer_thread_pool_size = get_int_env_var( + "SGLANG_DISAGGREGATION_THREAD_POOL_SIZE", + min(max(4, int(0.75 * cpu_count) // 8), 12), ) + transfer_queue_size = get_int_env_var("SGLANG_DISAGGREGATION_QUEUE_SIZE", 4) + self.transfer_queues: List[FastQueue] = [ + FastQueue() for _ in range(transfer_queue_size) + ] + assert transfer_thread_pool_size >= transfer_queue_size, ( + f"The environment variable SGLANG_DISAGGREGATION_THREAD_POOL_SIZE={transfer_thread_pool_size} must be " + f"greater than or equal to SGLANG_DISAGGREGATION_QUEUE_SIZE={transfer_queue_size}." 
+        )
+        self.executors = [
+            concurrent.futures.ThreadPoolExecutor(
+                transfer_thread_pool_size // transfer_queue_size
+            )
+            for _ in range(transfer_queue_size)
+        ]
+        for queue, executor in zip(self.transfer_queues, self.executors):
+            threading.Thread(
+                target=self.transfer_worker, args=(queue, executor), daemon=True
+            ).start()
+
         self.bootstrap_time_out = get_int_env_var(
             "SGLANG_DISAGGREGATION_BOOTSTRAP_TIMEOUT", 30
         )
@@ -183,7 +199,7 @@ class MooncakeKVManager(BaseKVManager):
             )
             # Heartbeat failure should be at least 1
             self.max_failures = max(
-                int(os.getenv("SGLANG_DISAGGREGATION_HEARTBEAT_MAX_FAILURE", 2)), 1
+                get_int_env_var("SGLANG_DISAGGREGATION_HEARTBEAT_MAX_FAILURE", 2), 1
             )
             self.start_decode_thread()
             self.connection_pool: Dict[str, Dict[str, Union[str, int]]] = {}
@@ -220,6 +236,7 @@ class MooncakeKVManager(BaseKVManager):
         prefill_kv_indices: npt.NDArray[np.int64],
         dst_kv_ptrs: list[int],
         dst_kv_indices: npt.NDArray[np.int64],
+        executor: concurrent.futures.ThreadPoolExecutor,
     ):
         # Group by indices
         prefill_kv_blocks, dst_kv_blocks = group_concurrent_contiguous(
@@ -251,7 +268,7 @@ class MooncakeKVManager(BaseKVManager):
             return 0

         futures = [
-            self.executor.submit(
+            executor.submit(
                 process_layer,
                 src_ptr,
                 dst_ptr,
@@ -298,6 +315,123 @@ class MooncakeKVManager(BaseKVManager):
             ]
         )

+    def transfer_worker(
+        self, queue: FastQueue, executor: concurrent.futures.ThreadPoolExecutor
+    ):
+        while True:
+            try:
+                kv_chunk: TransferKVChunk = queue.get()
+                reqs_to_be_processed = (
+                    self.transfer_infos[kv_chunk.room].values()
+                    if kv_chunk.room in self.transfer_infos
+                    else []
+                )
+                polls = []
+                dst_ranks_infos = []
+                for req in reqs_to_be_processed:
+                    if not req.is_dummy:
+                        # Early exit if the request has failed
+                        with self.session_lock:
+                            if req.mooncake_session_id in self.failed_sessions:
+                                self.record_failure(
+                                    kv_chunk.room,
+                                    f"Decode instance could be dead, remote mooncake session {req.mooncake_session_id} is not alive",
+                                )
+                                self.update_status(kv_chunk.room, KVPoll.Failed)
+                                self.sync_status_to_decode_endpoint(
+                                    req.endpoint,
+                                    req.dst_port,
+                                    req.room,
+                                    KVPoll.Failed,
+                                )
+                                break
+
+                        chunked_dst_kv_indice = req.dst_kv_indices[kv_chunk.index_slice]
+
+                        # NOTE: This is a temporary workaround for the case where prefill_kv_indices
+                        # is mismatched with dst_kv_indices when page size > 1; this should never happen.
+                        if len(chunked_dst_kv_indice) < len(
+                            kv_chunk.prefill_kv_indices
+                        ):
+                            kv_chunk.prefill_kv_indices = kv_chunk.prefill_kv_indices[
+                                : len(chunked_dst_kv_indice)
+                            ]
+                            logger.warning(
+                                f"len(chunked_dst_kv_indice) = {len(chunked_dst_kv_indice)}, len(kv_chunk.prefill_kv_indices) = {len(kv_chunk.prefill_kv_indices)}"
+                            )
+
+                        ret = self.send_kvcache(
+                            req.mooncake_session_id,
+                            kv_chunk.prefill_kv_indices,
+                            self.decode_kv_args_table[
+                                req.mooncake_session_id
+                            ].dst_kv_ptrs,
+                            chunked_dst_kv_indice,
+                            executor,
+                        )
+                        if ret != 0:
+                            with self.session_lock:
+                                self.session_failures[req.mooncake_session_id] += 1
+                                # Failures should never happen while the session is alive; if the session fails even once, mark it as failed
+                                if self.session_failures[req.mooncake_session_id] >= 1:
+                                    self.failed_sessions.add(req.mooncake_session_id)
+                                    logger.error(
+                                        f"Session {req.mooncake_session_id} failed."
+                                    )
+                            self.record_failure(
+                                kv_chunk.room,
+                                f"Failed to send kv chunk of {kv_chunk.room} to {req.endpoint}:{req.dst_port}",
+                            )
+                            self.update_status(kv_chunk.room, KVPoll.Failed)
+                            self.sync_status_to_decode_endpoint(
+                                req.endpoint, req.dst_port, req.room, KVPoll.Failed
+                            )
+                            break
+
+                        if kv_chunk.is_last:
+                            # Only the last chunk needs to send the aux data
+                            ret = self.send_aux(
+                                req.mooncake_session_id,
+                                kv_chunk.prefill_aux_index,
+                                self.decode_kv_args_table[
+                                    req.mooncake_session_id
+                                ].dst_aux_ptrs,
+                                req.dst_aux_index,
+                            )
+                            polls.append(True if ret == 0 else False)
+                            dst_ranks_infos.append(
+                                (req.endpoint, req.dst_port, req.room)
+                            )
+
+                            # Only sync status when all the dst ranks have received the kvcache
+                            if len(polls) == req.required_dst_info_num:
+                                status = KVPoll.Success if all(polls) else KVPoll.Failed
+                                self.update_status(req.room, status)
+                                for endpoint, dst_port, room in dst_ranks_infos:
+                                    self.sync_status_to_decode_endpoint(
+                                        endpoint, dst_port, room, status
+                                    )
+                    else:
+                        # Dummy request means the decode instance is not used, so its status can be marked as success directly
+                        # Dummy request does not need to sync status to decode endpoint
+                        if kv_chunk.is_last and req.room in self.request_status:
+                            self.update_status(req.room, KVPoll.Success)
+
+                if (
+                    kv_chunk.room not in self.request_status
+                    or self.check_status(kv_chunk.room) == KVPoll.Success
+                ):
+                    if kv_chunk.room in self.transfer_infos:
+                        self.transfer_infos.pop(kv_chunk.room)
+
+            except Exception as e:
+                # NOTE(shangming): Remove this when we make sure the transfer thread is bug-free
+                raise RuntimeError(
+                    f"Transfer thread failed because of {e}. Prefill instance with bootstrap_port={self.bootstrap_port} is dead."
+                )
+
     def start_prefill_thread(self):
         self.rank_port = get_free_port()
         self.server_socket.bind(f"tcp://{get_local_ip_by_remote()}:{self.rank_port}")
@@ -335,134 +469,7 @@ class MooncakeKVManager(BaseKVManager):
                     if len(self.transfer_infos[room]) == required_dst_info_num:
                         self.update_status(room, KVPoll.WaitingForInput)

-        def transfer_thread():
-            # TODO: Shall we use KVPoll.Transferring state?
-            while True:
-                try:
-                    kv_chunk: TransferKVChunk = self.transfer_queue.get(timeout=0.01)
-                    reqs_to_be_processed = (
-                        self.transfer_infos[kv_chunk.room].values()
-                        if kv_chunk.room in self.transfer_infos
-                        else []
-                    )
-                    polls = []
-                    dst_ranks_infos = []
-                    for req in reqs_to_be_processed:
-                        if not req.is_dummy:
-                            # Early exit if the request has failed
-                            with self.session_lock:
-                                if req.mooncake_session_id in self.failed_sessions:
-                                    self.record_failure(
-                                        kv_chunk.room,
-                                        f"Decode instance could be dead, remote mooncake session {req.mooncake_session_id} is not alive",
-                                    )
-                                    self.update_status(kv_chunk.room, KVPoll.Failed)
-                                    self.sync_status_to_decode_endpoint(
-                                        req.endpoint,
-                                        req.dst_port,
-                                        req.room,
-                                        KVPoll.Failed,
-                                    )
-                                    break
-
-                            chunked_dst_kv_indice = req.dst_kv_indices[
-                                kv_chunk.index_slice
-                            ]
-
-                            # NOTE: This is temporarily a workaround to deal with the case where the prefill_kv_indices
-                            # is mismatched with the dst_kv_indices when page size > 1, this should never happen.
- if len(chunked_dst_kv_indice) < len( - kv_chunk.prefill_kv_indices - ): - kv_chunk.prefill_kv_indices = ( - kv_chunk.prefill_kv_indices[ - len(chunked_dst_kv_indice) - ] - ) - logger.warning( - f"len(chunked_dst_kv_indice) = {len(chunked_dst_kv_indice)}, len(kv_chunk.prefill_kv_indices) = {len(kv_chunk.prefill_kv_indices)}" - ) - - ret = self.send_kvcache( - req.mooncake_session_id, - kv_chunk.prefill_kv_indices, - self.decode_kv_args_table[ - req.mooncake_session_id - ].dst_kv_ptrs, - chunked_dst_kv_indice, - ) - if ret != 0: - with self.session_lock: - self.session_failures[req.mooncake_session_id] += 1 - # Failures should never happen if the session is not dead, if the session fails once, mark it as failed - if ( - self.session_failures[req.mooncake_session_id] - >= 1 - ): - self.failed_sessions.add( - req.mooncake_session_id - ) - logger.error( - f"Session {req.mooncake_session_id} failed." - ) - self.record_failure( - kv_chunk.room, - f"Failed to send kv chunk of {kv_chunk.room} to {req.endpoint}:{req.dst_port}", - ) - self.update_status(kv_chunk.room, KVPoll.Failed) - self.sync_status_to_decode_endpoint( - req.endpoint, req.dst_port, req.room, KVPoll.Failed - ) - break - - if kv_chunk.is_last: - # Only the last chunk we need to send the aux data - ret = self.send_aux( - req.mooncake_session_id, - kv_chunk.prefill_aux_index, - self.decode_kv_args_table[ - req.mooncake_session_id - ].dst_aux_ptrs, - req.dst_aux_index, - ) - polls.append(True if ret == 0 else False) - dst_ranks_infos.append( - (req.endpoint, req.dst_port, req.room) - ) - - # Only sync status when all the dst ranks have received the kvcache - if len(polls) == req.required_dst_info_num: - status = ( - KVPoll.Success if all(polls) else KVPoll.Failed - ) - self.update_status(req.room, status) - for endpoint, dst_port, room in dst_ranks_infos: - self.sync_status_to_decode_endpoint( - endpoint, dst_port, room, status - ) - else: - # Dummy request means the decode instance is not used, so its status can be marked as success directly - # Dummy request does not need to sync status to decode endpoint - if kv_chunk.is_last and req.room in self.request_status: - self.update_status(req.room, KVPoll.Success) - - if ( - kv_chunk.room not in self.request_status - or self.check_status(kv_chunk.room) == KVPoll.Success - ): - if kv_chunk.room in self.transfer_infos: - self.transfer_infos.pop(kv_chunk.room) - - except queue.Empty: - continue - except Exception as e: - # NOTE(shangming): Remove this when we make sure the transfer thread is bug-free - raise RuntimeError( - f"Transfer thread failed because of {e}. Prefill instance with bootstrap_port={self.bootstrap_port} is dead." - ) - threading.Thread(target=bootstrap_thread).start() - threading.Thread(target=transfer_thread).start() def start_decode_thread(self): self.rank_port = get_free_port() @@ -555,7 +562,14 @@ class MooncakeKVManager(BaseKVManager): ) return - self.transfer_queue.put( + # NOTE(shangming): sharding according to the dst_infos to make sure + # requests with the same dst_sessions will be added into the same + # queue, which enables early abort with failed sessions. 
+        dst_infos = self.transfer_infos[bootstrap_room].keys()
+        session_port_sum = sum(int(session.split(":")[1]) for session in dst_infos)
+        shard_idx = session_port_sum % len(self.transfer_queues)
+
+        self.transfer_queues[shard_idx].put(
             TransferKVChunk(
                 room=bootstrap_room,
                 prefill_kv_indices=kv_indices,
diff --git a/python/sglang/srt/disaggregation/utils.py b/python/sglang/srt/disaggregation/utils.py
index 8841d5f1a..db7dd3239 100644
--- a/python/sglang/srt/disaggregation/utils.py
+++ b/python/sglang/srt/disaggregation/utils.py
@@ -3,6 +3,7 @@ from __future__ import annotations
 import dataclasses
 import os
 import random
+import threading
 import warnings
 from collections import deque
 from enum import Enum
@@ -281,6 +282,25 @@ class MetadataBuffers:
         )


+class FastQueue:
+    def __init__(self):
+        self._buf = deque()
+        self._cond = threading.Condition()
+
+    def put(self, item):
+        with self._cond:
+            self._buf.append(item)
+            # wake up one thread waiting in get()
+            self._cond.notify()
+
+    def get(self):
+        with self._cond:
+            # if the queue is empty, block until notify() is called
+            while not self._buf:
+                self._cond.wait()
+            return self._buf.popleft()
+
+
 def group_concurrent_contiguous(
     src_indices: npt.NDArray[np.int64], dst_indices: npt.NDArray[np.int64]
 ) -> Tuple[List[npt.NDArray[np.int64]], List[npt.NDArray[np.int64]]]:
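
For context, the patch replaces the single transfer_queue/ThreadPoolExecutor pair with several FastQueue instances, each drained by a dedicated worker thread that owns its own ThreadPoolExecutor, and each request is sharded onto one queue by the sum of its destination session ports. The standalone sketch below is illustrative only and not part of the patch; the chunk payload, layer count, pool sizes, and shutdown sentinel are made-up placeholders:

import concurrent.futures
import threading
from collections import deque


class FastQueue:
    # Minimal unbounded queue: append under a condition variable, block in get().
    def __init__(self):
        self._buf = deque()
        self._cond = threading.Condition()

    def put(self, item):
        with self._cond:
            self._buf.append(item)
            self._cond.notify()  # wake one waiting get()

    def get(self):
        with self._cond:
            while not self._buf:
                self._cond.wait()
            return self._buf.popleft()


NUM_QUEUES = 4         # plays the role of SGLANG_DISAGGREGATION_QUEUE_SIZE
THREADS_PER_QUEUE = 2  # plays the role of pool size // queue count

queues = [FastQueue() for _ in range(NUM_QUEUES)]
executors = [
    concurrent.futures.ThreadPoolExecutor(THREADS_PER_QUEUE) for _ in range(NUM_QUEUES)
]


def worker(q: FastQueue, executor: concurrent.futures.ThreadPoolExecutor):
    # Each worker drains exactly one queue; per-layer sends are fanned out through
    # its private executor, so one slow request cannot stall the other queues.
    while True:
        chunk = q.get()
        if chunk is None:  # hypothetical shutdown sentinel
            break
        futures = [
            executor.submit(print, f"send layer {i} of {chunk}") for i in range(3)
        ]
        concurrent.futures.wait(futures)


threads = [
    threading.Thread(target=worker, args=(q, ex), daemon=True)
    for q, ex in zip(queues, executors)
]
for t in threads:
    t.start()

# Shard by destination session ports so chunks bound for the same decode sessions
# always land on the same queue (mirrors the port-sum modulo in the patch).
dst_sessions = ["10.0.0.1:7001", "10.0.0.2:7002"]
shard_idx = sum(int(s.split(":")[1]) for s in dst_sessions) % NUM_QUEUES
queues[shard_idx].put("room-42")

for q in queues:
    q.put(None)  # stop all workers
for t in threads:
    t.join()

Keying the shard on the destination sessions rather than on the room id keeps every chunk bound for a given decode session on the same queue, so once that session is marked as failed the worker draining that queue can abort its pending chunks early, which is what the NOTE above the sharding code refers to.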