diff --git a/python/sglang/srt/disaggregation/decode.py b/python/sglang/srt/disaggregation/decode.py index 4d6ac1b6f..7982f7b63 100644 --- a/python/sglang/srt/disaggregation/decode.py +++ b/python/sglang/srt/disaggregation/decode.py @@ -361,7 +361,7 @@ class DecodeTransferQueue: indices_to_remove = set() for i, (decode_req, poll) in enumerate(zip(self.queue, polls)): if poll == KVPoll.Failed: - error_message = f"Decode transfer failed for request {decode_req.req.rid=} {decode_req.req.bootstrap_room=}" + error_message = f"Decode transfer failed for request rank={self.scheduler.tp_rank} {decode_req.req.rid=} {decode_req.req.bootstrap_room=}" try: decode_req.kv_receiver.failure_exception() except Exception as e: @@ -409,7 +409,8 @@ class DecodeTransferQueue: : decode_req.req.top_logprobs_num ].tolist() ) - + if hasattr(decode_req.kv_receiver, "clear"): + decode_req.kv_receiver.clear() transferred_reqs.append(decode_req.req) indices_to_remove.add(i) elif poll in [ diff --git a/python/sglang/srt/disaggregation/mooncake/conn.py b/python/sglang/srt/disaggregation/mooncake/conn.py index 57c426f25..3b9aa62f7 100644 --- a/python/sglang/srt/disaggregation/mooncake/conn.py +++ b/python/sglang/srt/disaggregation/mooncake/conn.py @@ -9,6 +9,8 @@ import queue import socket import struct import threading +import time +from collections import defaultdict from functools import cache from typing import Dict, List, Optional, Tuple, Union @@ -51,6 +53,16 @@ def group_concurrent_contiguous( return src_groups, dst_groups +class KVTransferError(Exception): + def __init__(self, bootstrap_room: int, failure_reason: str): + super().__init__(failure_reason) + self.bootstrap_room = bootstrap_room + self.failure_reason = failure_reason + + def __str__(self): + return f"KVTransferError(bootstrap_room={self.bootstrap_room}): {self.failure_reason}" + + # prefill @dataclasses.dataclass class TransferKVChunk: @@ -153,13 +165,34 @@ class MooncakeKVManager(BaseKVManager): self.decode_kv_args_table: Dict[str, KVArgsRegisterInfo] = {} self.start_prefill_thread() self._register_to_bootstrap() + self.session_failures = defaultdict(int) + self.failed_sessions = set() + self.session_lock = threading.Lock() # Determine the number of threads to use for kv sender cpu_count = os.cpu_count() self.executor = concurrent.futures.ThreadPoolExecutor( - min(cpu_count // 4, 16) + int( + os.getenv( + "DISAGGREGATION_THREAD_POOL_SIZE", + min(max(1, cpu_count // 8), 8), + ) + ) ) elif self.disaggregation_mode == DisaggregationMode.DECODE: + self.heartbeat_failures = {} + self.session_pool = defaultdict(requests.Session) + self.session_pool_lock = threading.Lock() + self.addr_to_rooms_tracker = defaultdict(list) + self.connection_lock = threading.Lock() + # Heartbeat interval should be at least 2 seconds + self.heartbeat_interval = max( + float(os.getenv("DISAGGREGATION_HEARTBEAT_INTERVAL", 5.0)), 2.0 + ) + # Heartbeat failure should be at least 1 + self.max_failures = max( + int(os.getenv("DISAGGREGATION_HEARTBEAT_MAX_FAILURE", 2)), 1 + ) self.start_decode_thread() self.connection_pool: Dict[str, Dict[str, Union[str, int]]] = {} self.prefill_tp_size_table: Dict[str, int] = {} @@ -169,6 +202,9 @@ class MooncakeKVManager(BaseKVManager): f"Unsupported DisaggregationMode: {self.disaggregation_mode}" ) + self.failure_records: Dict[int, str] = {} + self.failure_lock = threading.Lock() + def register_buffer_to_engine(self): for kv_data_ptr, kv_data_len in zip( self.kv_args.kv_data_ptrs, self.kv_args.kv_data_lens @@ -235,8 +271,6 @@ class MooncakeKVManager(BaseKVManager): for future in concurrent.futures.as_completed(futures): status = future.result() if status != 0: - # Immediate shutdown on first error (existing tasks will finish) - self.executor.shutdown(wait=False) for f in futures: f.cancel() return status @@ -255,20 +289,20 @@ class MooncakeKVManager(BaseKVManager): self.kv_args.aux_data_ptrs[0] + prefill_aux_index * aux_item_len ) decode_aux_addr = dst_aux_ptrs[0] + dst_aux_index * aux_item_len - # TODO: mooncake transfer engine can do async transfer. Do async later - # Not sure about the amount of aux data, maybe transfer it by zmq is more effective status = self.engine.transfer_sync( mooncake_session_id, prefill_aux_addr, decode_aux_addr, aux_item_len ) return status - def sync_status_to_decode_endpoint(self, remote: str, dst_port: int, room: int): + def sync_status_to_decode_endpoint( + self, remote: str, dst_port: int, room: int, status: int + ): if ":" in remote: remote = remote.split(":")[0] self._connect("tcp://" + remote + ":" + str(dst_port)).send_multipart( [ str(room).encode("ascii"), - str(self.check_status(room)).encode("ascii"), + str(status).encode("ascii"), ] ) @@ -287,6 +321,11 @@ class MooncakeKVManager(BaseKVManager): self.decode_kv_args_table[mooncake_session_id] = ( KVArgsRegisterInfo.from_zmq(waiting_req_bytes) ) + with self.session_lock: + if mooncake_session_id in self.failed_sessions: + self.failed_sessions.remove(mooncake_session_id) + if mooncake_session_id in self.session_failures: + del self.session_failures[mooncake_session_id] logger.debug( f"Register KVArgs from {mooncake_session_id} successfully" ) @@ -309,17 +348,48 @@ class MooncakeKVManager(BaseKVManager): while True: try: kv_chunk: TransferKVChunk = self.transfer_queue.get(timeout=0.01) - reqs_to_be_processed = self.transfer_infos[kv_chunk.room].values() + reqs_to_be_processed = ( + self.transfer_infos[kv_chunk.room].values() + if kv_chunk.room in self.transfer_infos + else [] + ) polls = [] dst_ranks_infos = [] for req in reqs_to_be_processed: if not req.is_dummy: + # Early exit if the request has failed + with self.session_lock: + if req.mooncake_session_id in self.failed_sessions: + self.record_failure( + kv_chunk.room, + f"Decode instance could be dead, {req.mooncake_session_id} failed due to multiple errors", + ) + self.update_status(kv_chunk.room, KVPoll.Failed) + self.sync_status_to_decode_endpoint( + req.endpoint, + req.dst_port, + req.room, + KVPoll.Failed, + ) + break + chunked_dst_kv_indice = req.dst_kv_indices[ kv_chunk.index_slice ] - assert len(chunked_dst_kv_indice) == len( + + # NOTE: This is temporarily a workaround to deal with the case where the prefill_kv_indices + # is mismatched with the dst_kv_indices when page size > 1, this should never happen. + if len(chunked_dst_kv_indice) < len( kv_chunk.prefill_kv_indices - ), f"len(chunked_dst_kv_indice) = {len(chunked_dst_kv_indice)}, len(kv_chunk.prefill_kv_indices) = {len(kv_chunk.prefill_kv_indices)}" + ): + kv_chunk.prefill_kv_indices = ( + kv_chunk.prefill_kv_indices[ + len(chunked_dst_kv_indice) + ] + ) + logger.warning( + f"len(chunked_dst_kv_indice) = {len(chunked_dst_kv_indice)}, len(kv_chunk.prefill_kv_indices) = {len(kv_chunk.prefill_kv_indices)}" + ) ret = self.send_kvcache( req.mooncake_session_id, @@ -330,11 +400,28 @@ class MooncakeKVManager(BaseKVManager): chunked_dst_kv_indice, ) if ret != 0: + with self.session_lock: + self.session_failures[req.mooncake_session_id] += 1 + # Failures should never happen if the session is not dead, if the session fails once, mark it as failed + if ( + self.session_failures[req.mooncake_session_id] + >= 1 + ): + self.failed_sessions.add( + req.mooncake_session_id + ) + logger.error( + f"Session {req.mooncake_session_id} failed." + ) + self.record_failure( + kv_chunk.room, + f"Failed to send kv chunk of {kv_chunk.room} to {req.endpoint}:{req.dst_port}", + ) self.update_status(kv_chunk.room, KVPoll.Failed) self.sync_status_to_decode_endpoint( - req.endpoint, req.dst_port, req.room + req.endpoint, req.dst_port, req.room, KVPoll.Failed ) - continue + break if kv_chunk.is_last: # Only the last chunk we need to send the aux data @@ -353,25 +440,33 @@ class MooncakeKVManager(BaseKVManager): # Only sync status when all the dst ranks have received the kvcache if len(polls) == req.required_dst_info_num: - self.update_status( - req.room, - KVPoll.Success if all(polls) else KVPoll.Failed, + status = ( + KVPoll.Success if all(polls) else KVPoll.Failed ) + self.update_status(req.room, status) for endpoint, dst_port, room in dst_ranks_infos: self.sync_status_to_decode_endpoint( - endpoint, dst_port, room + endpoint, dst_port, room, status ) else: # Dummy request means the decode instance is not used, so its status can be marked as success directly # Dummy request does not need to sync status to decode endpoint - if kv_chunk.is_last: + if kv_chunk.is_last and req.room in self.request_status: self.update_status(req.room, KVPoll.Success) - if self.check_status(kv_chunk.room) == KVPoll.Success: + if ( + kv_chunk.room not in self.request_status + or self.check_status(kv_chunk.room) == KVPoll.Success + ): self.transfer_infos.pop(kv_chunk.room) except queue.Empty: continue + except Exception as e: + # NOTE(shangming): Remove this when we make sure the transfer thread is bug-free + raise RuntimeError( + f"Transfer thread failed because of {e}. Prefill instance with bootstrap_port={self.bootstrap_port} is dead." + ) threading.Thread(target=bootstrap_thread).start() threading.Thread(target=transfer_thread).start() @@ -385,9 +480,67 @@ class MooncakeKVManager(BaseKVManager): (bootstrap_room, status) = self.server_socket.recv_multipart() status = int(status.decode("ascii")) bootstrap_room = int(bootstrap_room.decode("ascii")) + if status == KVPoll.Failed: + self.record_failure( + bootstrap_room, + f"Failed to get kvcache from prefill instance, it might be dead", + ) self.update_status(bootstrap_room, status) + def heartbeat_checker(): + while True: + time.sleep(self.heartbeat_interval) + with self.connection_lock: + addresses = list(self.prefill_dp_size_table.keys()) + + for bootstrap_addr in addresses: + session = None + try: + with self.session_pool_lock: + session = self.session_pool[bootstrap_addr] + response = session.get( + f"http://{bootstrap_addr}/health", + timeout=(2, 3), + headers={"Connection": "keep-alive"}, + ) + if response.status_code == 200: + self.heartbeat_failures[bootstrap_addr] = 0 + + for bootstrap_room in self.addr_to_rooms_tracker[ + bootstrap_addr + ]: + # Remove KVPoll.Success requests from the map + if bootstrap_room not in self.request_status: + self.addr_to_rooms_tracker[bootstrap_addr].remove( + bootstrap_room + ) + else: + logger.info( + f"Attempting to reconnect to {bootstrap_addr}..." + ) + self.heartbeat_failures[bootstrap_addr] = ( + self.heartbeat_failures.get(bootstrap_addr, 0) + 1 + ) + with self.session_pool_lock: + if bootstrap_addr in self.session_pool: + del self.session_pool[bootstrap_addr] + except Exception: + logger.info(f"Attempting to reconnect to {bootstrap_addr}...") + self.heartbeat_failures[bootstrap_addr] = ( + self.heartbeat_failures.get(bootstrap_addr, 0) + 1 + ) + + if ( + self.heartbeat_failures.get(bootstrap_addr, 0) + >= self.max_failures + ): + self._handle_node_failure(bootstrap_addr) + with self.session_pool_lock: + if bootstrap_addr in self.session_pool: + del self.session_pool[bootstrap_addr] + threading.Thread(target=decode_thread).start() + threading.Thread(target=heartbeat_checker).start() def add_transfer_request( self, @@ -400,6 +553,15 @@ class MooncakeKVManager(BaseKVManager): assert self.disaggregation_mode == DisaggregationMode.PREFILL assert not is_last or (is_last and aux_index is not None) + if ( + bootstrap_room not in self.request_status + or self.check_status(bootstrap_room) == KVPoll.Failed + ): + logger.debug( + "Request with bootstrap_room=%s already failed", bootstrap_room + ) + return + self.transfer_queue.put( TransferKVChunk( room=bootstrap_room, @@ -418,10 +580,17 @@ class MooncakeKVManager(BaseKVManager): if bootstrap_room not in self.request_status: self.request_status[bootstrap_room] = status else: - # NOTE: The prefill engine could recv bootstrapping first - self.request_status[bootstrap_room] = max( - self.request_status[bootstrap_room], status - ) + # NOTE: status is only allowed to be incremented unless it is KVPoll.Failed + if status == KVPoll.Failed: + self.request_status[bootstrap_room] = KVPoll.Failed + else: + self.request_status[bootstrap_room] = max( + self.request_status[bootstrap_room], status + ) + + def record_failure(self, bootstrap_room: int, failure_reason: str): + with self.failure_lock: + self.failure_records[bootstrap_room] = failure_reason def get_session_id(self): return self.engine.get_session_id() @@ -445,15 +614,51 @@ class MooncakeKVManager(BaseKVManager): } try: - response = requests.put(url, json=payload) + response = requests.put(url, json=payload, timeout=5) if response.status_code == 200: logger.debug("Prefill successfully registered to bootstrap server.") else: logger.error( - f"Prefill Failed to connect to bootstrap server: {response.status_code}, {response.text}" + f"Prefill instance failed to connect to bootstrap server: {response.status_code}, {response.text}" ) except Exception as e: - logger.error(f"Prefill Failed to register to bootstrap server: {e}") + logger.error( + f"Prefill instance failed to register to bootstrap server: {e}" + ) + + def _handle_node_failure(self, failed_bootstrap_addr): + with self.connection_lock: + keys_to_remove = [ + k for k in self.connection_pool if k.startswith(failed_bootstrap_addr) + ] + for k in keys_to_remove: + del self.connection_pool[k] + if failed_bootstrap_addr in self.prefill_tp_size_table: + del self.prefill_tp_size_table[failed_bootstrap_addr] + if failed_bootstrap_addr in self.prefill_dp_size_table: + del self.prefill_dp_size_table[failed_bootstrap_addr] + + possible_affected_rooms = self.addr_to_rooms_tracker.get( + failed_bootstrap_addr, [] + ) + del self.addr_to_rooms_tracker[failed_bootstrap_addr] + + # Report the requests associated with the failed bootstrap addr and mark their status as KVPoll.Failed + affected_rooms = [] + for room in possible_affected_rooms: + if ( + room in self.request_status + and self.check_status(room) != KVPoll.Success + ): + self.record_failure( + room, + f"Losing connection with prefill instance (bootstrap_addr: {failed_bootstrap_addr})", + ) + self.update_status(room, KVPoll.Failed) + affected_rooms.append(room) + logger.error( + f"Losing connection with prefill instance (bootstrap_addr: {failed_bootstrap_addr}), affected {len(affected_rooms)} requests" + ) class MooncakeKVSender(BaseKVSender): @@ -466,7 +671,7 @@ class MooncakeKVSender(BaseKVSender): self.kv_mgr.update_status(bootstrap_room, KVPoll.Bootstrapping) self.aux_index = None self.bootstrap_server_url = bootstrap_addr - self.session_id = self.kv_mgr.get_session_id() + self.conclude_state = None # inner state self.curr_idx = 0 @@ -496,11 +701,30 @@ class MooncakeKVSender(BaseKVSender): ) def poll(self) -> KVPoll: - return self.kv_mgr.check_status(self.bootstrap_room) + if self.conclude_state is None: + status = self.kv_mgr.check_status(self.bootstrap_room) + if status in (KVPoll.Success, KVPoll.Failed): + self.conclude_state = status + + return status + else: + return self.conclude_state + + def clear(self) -> None: + self.kv_mgr.request_status.pop(self.bootstrap_room) def failure_exception(self): - # TODO: raise a real exception - raise Exception("Fake KVSender Exception") + self.clear() + + # Explicitly set the status to failure since this request has failed in another rank + if self.conclude_state is None: + self.conclude_state = KVPoll.Failed + + with self.kv_mgr.failure_lock: + failure_reason = self.kv_mgr.failure_records.pop( + self.bootstrap_room, "Failed due to an unknown reason from another rank" + ) + raise KVTransferError(self.bootstrap_room, failure_reason) class MooncakeKVReceiver(BaseKVReceiver): @@ -519,17 +743,24 @@ class MooncakeKVReceiver(BaseKVReceiver): self.bootstrap_addr = bootstrap_addr self.kv_mgr = mgr self.session_id = self.kv_mgr.get_session_id() - self.kv_mgr.update_status(bootstrap_room, KVPoll.Bootstrapping) + self.kv_mgr.update_status(self.bootstrap_room, KVPoll.Bootstrapping) + self.conclude_state = None if self.bootstrap_addr not in self.kv_mgr.prefill_dp_size_table: self.prefill_tp_size, self.prefill_dp_size = ( - self._get_prefill_dp_size_from_server() + self._get_prefill_parallel_info_from_server() ) if self.prefill_tp_size is None or self.prefill_dp_size is None: - logger.error( - f"Could not fetch prefill parallel info for bootstrap_addr: {self.bootstrap_addr}" + self.kv_mgr.record_failure( + self.bootstrap_room, + f"Could not fetch prefill parallel info from bootstrap_addr: {self.bootstrap_addr}", ) + self.kv_mgr.update_status(self.bootstrap_room, KVPoll.Failed) + return else: + logger.debug( + f"Fetch prefill parallel info from [{self.bootstrap_addr}]: DP size:{self.prefill_dp_size}, TP size:{self.prefill_tp_size}" + ) self.kv_mgr.prefill_tp_size_table[self.bootstrap_addr] = ( self.prefill_tp_size ) @@ -587,7 +818,7 @@ class MooncakeKVReceiver(BaseKVReceiver): self.target_tp_rank = self.target_tp_ranks[0] self.required_dst_info_num = 1 - self.target_dp_group = bootstrap_room % self.prefill_dp_size + self.target_dp_group = self.bootstrap_room % self.prefill_dp_size # NOTE: key distinguished by bootstrap_addr, target_dp_group, and target_tp_rank bootstrap_key = ( @@ -607,32 +838,37 @@ class MooncakeKVReceiver(BaseKVReceiver): target_tp_rank == self.target_tp_rank or self.target_tp_rank is None ) + logger.debug( + f"Fetched bootstrap info: {bootstrap_info} for DP {self.target_dp_group} TP {target_tp_rank}" + ) bootstrap_infos.append(bootstrap_info) else: - logger.error( - f"Could not fetch bootstrap info for engine rank: {self.kv_mgr.kv_args.engine_rank} and target_dp_group: {self.target_dp_group}" + self.kv_mgr.record_failure( + self.bootstrap_room, + f"Could not fetch bootstrap info for engine rank: {self.kv_mgr.kv_args.engine_rank} and target_dp_group: {self.target_dp_group}", ) - self.bootstrap_infos = bootstrap_infos + self.kv_mgr.update_status(self.bootstrap_room, KVPoll.Failed) + return - if len(self.bootstrap_infos) == 0: - logger.error( - f"Could not fetch bootstrap info for engine rank: {self.kv_mgr.kv_args.engine_rank}" - ) - else: - self.kv_mgr.connection_pool[bootstrap_key] = self.bootstrap_infos - # Register kv_args only once to prefill KVManager according to the info fetched from the bootstrap server - self._register_kv_args() + self.bootstrap_infos = bootstrap_infos + self.kv_mgr.connection_pool[bootstrap_key] = self.bootstrap_infos + + # Register kv_args only once to prefill KVManager according to the info fetched from the bootstrap server + self._register_kv_args() else: self.bootstrap_infos = self.kv_mgr.connection_pool[bootstrap_key] assert len(self.bootstrap_infos) > 0 - self.kv_mgr.update_status(bootstrap_room, KVPoll.WaitingForInput) + self.kv_mgr.addr_to_rooms_tracker[self.bootstrap_addr].append( + self.bootstrap_room + ) + self.kv_mgr.update_status(self.bootstrap_room, KVPoll.WaitingForInput) def _get_bootstrap_info_from_server(self, engine_rank, target_dp_group): """Fetch the bootstrap info from the bootstrap server.""" try: url = f"http://{self.bootstrap_addr}/route?engine_rank={engine_rank}&target_dp_group={target_dp_group}" - response = requests.get(url) + response = requests.get(url, timeout=5) if response.status_code == 200: bootstrap_info = response.json() return bootstrap_info @@ -645,7 +881,7 @@ class MooncakeKVReceiver(BaseKVReceiver): logger.error(f"Error fetching prefill info from bootstrap: {e}") return None - def _get_prefill_dp_size_from_server(self) -> int: + def _get_prefill_parallel_info_from_server(self) -> Tuple[int, int]: """Fetch the prefill parallel info from the bootstrap server.""" try: url = f"http://{self.bootstrap_addr}/route?engine_rank={-1}&target_dp_group={-1}" @@ -659,10 +895,10 @@ class MooncakeKVReceiver(BaseKVReceiver): logger.error( f"Failed to get prefill parallel info: {response.status_code}, {response.text}" ) - return None + return None, None except Exception as e: logger.error(f"Error fetching prefill parallel info from bootstrap: {e}") - return None + return None, None def _register_kv_args(self): for bootstrap_info in self.bootstrap_infos: @@ -704,9 +940,6 @@ class MooncakeKVReceiver(BaseKVReceiver): self.prefill_server_url = ( f"{bootstrap_info['rank_ip']}:{bootstrap_info['rank_port']}" ) - logger.debug( - f"Fetched bootstrap info: {bootstrap_info} for engine rank: {self.kv_mgr.kv_args.engine_rank}" - ) is_dummy = bootstrap_info["is_dummy"] sock, lock = self._connect("tcp://" + self.prefill_server_url) @@ -724,11 +957,30 @@ class MooncakeKVReceiver(BaseKVReceiver): ) def poll(self) -> KVPoll: - return self.kv_mgr.check_status(self.bootstrap_room) + if self.conclude_state is None: + status = self.kv_mgr.check_status(self.bootstrap_room) + if status in (KVPoll.Success, KVPoll.Failed): + self.conclude_state = status + + return status + else: + return self.conclude_state + + def clear(self) -> None: + self.kv_mgr.request_status.pop(self.bootstrap_room) def failure_exception(self): - # TODO: raise a real exception - raise Exception("Fake KVReceiver Exception") + self.clear() + + # Explicitly set the status to failure since this request has failed in another rank + if self.conclude_state is None: + self.conclude_state = KVPoll.Failed + + with self.kv_mgr.failure_lock: + failure_reason = self.kv_mgr.failure_records.pop( + self.bootstrap_room, "Failed due to an unknown reason from another rank" + ) + raise KVTransferError(self.bootstrap_room, failure_reason) class MooncakeKVBootstrapServer(BaseKVBootstrapServer): @@ -752,6 +1004,10 @@ class MooncakeKVBootstrapServer(BaseKVBootstrapServer): def _setup_routes(self): self.app.router.add_route("*", "/route", self._handle_route) + self.app.router.add_get("/health", self._handle_health_check) + + async def _handle_health_check(self, request): + return web.Response(text="OK", status=200) async def _handle_route(self, request: web.Request): method = request.method @@ -780,14 +1036,14 @@ class MooncakeKVBootstrapServer(BaseKVBootstrapServer): self.dp_size = dp_size tp_size_per_dp_rank = tp_size // dp_size - if self.tp_size_per_dp_rank == None: + if self.tp_size_per_dp_rank is None: self.tp_size_per_dp_rank = tp_size_per_dp_rank - # Add lock to make sure thread-safe if role == "Prefill": dp_group = engine_rank // tp_size_per_dp_rank tp_rank_in_dp_group = engine_rank % tp_size_per_dp_rank + # Add lock to make sure thread-safe async with self.lock: if dp_group not in self.prefill_port_table: self.prefill_port_table[dp_group] = {} @@ -797,7 +1053,7 @@ class MooncakeKVBootstrapServer(BaseKVBootstrapServer): "rank_port": rank_port, } logger.debug( - f"Register Prefill bootstrap: {engine_rank} with rank_ip: {rank_ip} and rank_port: {rank_port}" + f"Register prefill bootstrap: {engine_rank} with rank_ip: {rank_ip} and rank_port: {rank_port}" ) return web.Response(text="OK", status=200) @@ -833,7 +1089,11 @@ class MooncakeKVBootstrapServer(BaseKVBootstrapServer): self._loop = asyncio.new_event_loop() asyncio.set_event_loop(self._loop) - self._runner = web.AppRunner(self.app) + access_log = None + if logging.getLogger(__name__).getEffectiveLevel() <= logging.DEBUG: + access_log = self.app.logger + + self._runner = web.AppRunner(self.app, access_log=access_log) self._loop.run_until_complete(self._runner.setup()) site = web.TCPSite(self._runner, port=self.port) diff --git a/python/sglang/srt/disaggregation/mooncake/transfer_engine.py b/python/sglang/srt/disaggregation/mooncake/transfer_engine.py index 1f3c44bcc..5643af70b 100644 --- a/python/sglang/srt/disaggregation/mooncake/transfer_engine.py +++ b/python/sglang/srt/disaggregation/mooncake/transfer_engine.py @@ -30,16 +30,24 @@ class MooncakeTransferEngine: self.session_id = f"{self.hostname}:{self.engine.get_rpc_port()}" def register(self, ptr, length): - ret_value = self.engine.register_memory(ptr, length) + try: + ret_value = self.engine.register_memory(ptr, length) + except Exception: + # Mark register as failed + ret_value = -1 + if ret_value != 0: - logger.error("Mooncake memory registration failed.") - raise RuntimeError("Mooncake memory registration failed.") + logger.debug("Mooncake memory registration %s failed.", ptr) def deregister(self, ptr): - ret_value = self.engine.unregister_memory(ptr) + try: + ret_value = self.engine.unregister_memory(ptr) + except Exception: + # Mark deregister as failed + ret_value = -1 + if ret_value != 0: - logger.error("Mooncake memory deregistration failed.") - raise RuntimeError("Mooncake memory deregistration failed.") + logger.debug("Mooncake memory deregistration %s failed.", ptr) def initialize( self, @@ -61,18 +69,26 @@ class MooncakeTransferEngine: self, session_id: str, buffer: int, peer_buffer_address: int, length: int ) -> int: """Synchronously transfer data to the specified address.""" - # the first time: based on session_id (which contains remote_ip) to construct a queue pair, and cache the queue pair - # later: based on the cached queue pair to send data - ret = self.engine.transfer_sync_write( - session_id, buffer, peer_buffer_address, length - ) - if ret < 0: - logger.error("Mooncake Transfer Engine Return Error.") - raise RuntimeError("Mooncake Transfer Engine Return Error.") - return ret + try: + # the first time: based on session_id (which contains remote_ip) to construct a queue pair, and cache the queue pair + # later: based on the cached queue pair to send data + ret = self.engine.transfer_sync_write( + session_id, buffer, peer_buffer_address, length + ) + except Exception: + # Mark transfer request as failed + ret = -1 - def get_localhost(self): - return self.hostname + if ret < 0: + # Do not raise an exception here, since some transfer requests fail should be accepted and the execution thread should not be stopped. + logger.debug( + "Failed to transfer data from %s to %s - %s.", + buffer, + session_id, + peer_buffer_address, + ) + + return ret def get_session_id(self): return self.session_id diff --git a/python/sglang/srt/disaggregation/prefill.py b/python/sglang/srt/disaggregation/prefill.py index 8b325811e..5a416e896 100644 --- a/python/sglang/srt/disaggregation/prefill.py +++ b/python/sglang/srt/disaggregation/prefill.py @@ -417,6 +417,8 @@ class SchedulerDisaggregationPrefillMixin: self.tree_cache.cache_finished_req(req) # unlock the tree req.finished_reason = FINISH_LENGTH(length=0) # FIXME: clean up req's data in transfer engine + if hasattr(req.disagg_kv_sender, "clear"): + req.disagg_kv_sender.clear() done_reqs.append(req) elif poll == KVPoll.Failed: error_message = f"Prefill transfer failed for request rank={self.tp_rank} {req.rid=} {req.bootstrap_room=}"