diff --git a/python/sglang/srt/disaggregation/nixl/conn.py b/python/sglang/srt/disaggregation/nixl/conn.py
index 871bdcbfc..1579209a3 100644
--- a/python/sglang/srt/disaggregation/nixl/conn.py
+++ b/python/sglang/srt/disaggregation/nixl/conn.py
@@ -2,14 +2,17 @@ from __future__ import annotations
 
 import dataclasses
 import logging
+import os
 import struct
 import threading
+import time
 import uuid
 from collections import defaultdict
 from typing import Dict, List, Optional, Set
 
 import numpy as np
 import numpy.typing as npt
+import requests
 
 from sglang.srt.disaggregation.base.conn import KVArgs, KVPoll
 from sglang.srt.disaggregation.common.conn import (
@@ -21,6 +24,7 @@ from sglang.srt.disaggregation.common.conn import (
 from sglang.srt.disaggregation.common.utils import group_concurrent_contiguous
 from sglang.srt.disaggregation.utils import DisaggregationMode
 from sglang.srt.server_args import ServerArgs
+from sglang.srt.utils import get_int_env_var
 
 logger = logging.getLogger(__name__)
 
@@ -102,8 +106,15 @@ class TransferStatus:
     def is_done(self):
         if self.num_kvs_expected is None:
             return False
+        # Check for failure state
+        if self.num_kvs_expected == -1:
+            return True  # Failed transfers are considered "done"
         return self.num_kvs_expected == len(self.received_kvs) and self.received_aux
 
+    def is_failed(self):
+        """Return True if the transfer was aborted (``num_kvs_expected == -1``)."""
+        return self.num_kvs_expected == -1
+
 
 class NixlKVManager(CommonKVManager):
     def __init__(
@@ -131,11 +142,144 @@ class NixlKVManager(CommonKVManager):
             self.transfer_statuses: Dict[int, TransferStatus] = defaultdict(
                 TransferStatus
             )
+            self.heartbeat_failures = {}
+            self.session_pool = defaultdict(requests.Session)
+            self.session_pool_lock = threading.Lock()
+            self.addr_to_rooms_tracker = defaultdict(set)
+            self.connection_lock = threading.Lock()
+
+            # Heartbeat interval should be at least 2 seconds
+            self.heartbeat_interval = max(
+                float(os.getenv("SGLANG_DISAGGREGATION_HEARTBEAT_INTERVAL", 5.0)), 2.0
+            )
+            # Heartbeat failure should be at least 1
+            self.max_failures = max(
+                get_int_env_var("SGLANG_DISAGGREGATION_HEARTBEAT_MAX_FAILURE", 2), 1
+            )
+            self._start_heartbeat_checker_thread()
         else:
             raise ValueError(
                 f"Unsupported DisaggregationMode: {self.disaggregation_mode}"
             )
 
+    def _start_heartbeat_checker_thread(self):
+        """
+        Start the heartbeat checker thread for Decode worker.
+
+        The daemon thread polls ``/health`` on every known prefill bootstrap
+        address each ``heartbeat_interval`` seconds and calls
+        ``_handle_node_failure`` once an address has accumulated
+        ``max_failures`` consecutive failed checks.
+        TODO (smor): unite nixl heartbeat checker with mooncake's.
+        """
+
+        def heartbeat_checker():
+            while True:
+                time.sleep(self.heartbeat_interval)
+                with self.connection_lock:
+                    addresses = list(self.prefill_dp_size_table.keys())
+
+                for bootstrap_addr in addresses:
+                    session = None
+                    try:
+                        with self.session_pool_lock:
+                            session = self.session_pool[bootstrap_addr]
+                        response = session.get(
+                            f"http://{bootstrap_addr}/health",
+                            timeout=(2, 3),
+                            headers={"Connection": "keep-alive"},
+                        )
+                        if response.status_code == 200:
+                            self.heartbeat_failures[bootstrap_addr] = 0
+
+                            current_rooms = self.addr_to_rooms_tracker[
+                                bootstrap_addr
+                            ].copy()
+
+                            for bootstrap_room in current_rooms:
+                                # Remove successful transfers from the tracker
+                                if bootstrap_room not in self.transfer_statuses:
+                                    self.addr_to_rooms_tracker[bootstrap_addr].discard(
+                                        bootstrap_room
+                                    )
+                        else:
+                            logger.info(
+                                f"Attempting to reconnect to {bootstrap_addr}..."
+                            )
+                            self.heartbeat_failures[bootstrap_addr] = (
+                                self.heartbeat_failures.get(bootstrap_addr, 0) + 1
+                            )
+                            self._drop_session(bootstrap_addr)
+                    except Exception:
+                        logger.info(f"Attempting to reconnect to {bootstrap_addr}...")
+                        self.heartbeat_failures[bootstrap_addr] = (
+                            self.heartbeat_failures.get(bootstrap_addr, 0) + 1
+                        )
+
+                    if (
+                        self.heartbeat_failures.get(bootstrap_addr, 0)
+                        >= self.max_failures
+                    ):
+                        self._handle_node_failure(bootstrap_addr)
+                        self._drop_session(bootstrap_addr)
+                        # Forget the count so a re-registered address gets the
+                        # full max_failures budget again.
+                        self.heartbeat_failures.pop(bootstrap_addr, None)
+
+        threading.Thread(target=heartbeat_checker, daemon=True).start()
+
+    def _drop_session(self, bootstrap_addr):
+        """Evict the pooled HTTP session for ``bootstrap_addr`` and close it."""
+        with self.session_pool_lock:
+            session = self.session_pool.pop(bootstrap_addr, None)
+        if session is not None:
+            # Close explicitly so pooled sockets are released immediately
+            # instead of waiting for garbage collection.
+            session.close()
+
+    def _handle_node_failure(self, failed_bootstrap_addr):
+        """Handle failure of a prefill node.
+
+        Removes the address from the connection pool, the prefill size tables
+        and the room tracker, then marks its unfinished transfers as failed.
+        """
+        with self.connection_lock:
+            # NOTE(review): plain prefix match; verify against the pool key
+            # format that e.g. "host:80" cannot also match "host:8000" keys.
+            keys_to_remove = [
+                k for k in self.connection_pool if k.startswith(failed_bootstrap_addr)
+            ]
+            for k in keys_to_remove:
+                del self.connection_pool[k]
+            if failed_bootstrap_addr in self.prefill_tp_size_table:
+                del self.prefill_tp_size_table[failed_bootstrap_addr]
+            if failed_bootstrap_addr in self.prefill_dp_size_table:
+                del self.prefill_dp_size_table[failed_bootstrap_addr]
+            if failed_bootstrap_addr in self.prefill_pp_size_table:
+                del self.prefill_pp_size_table[failed_bootstrap_addr]
+
+            possible_affected_rooms = self.addr_to_rooms_tracker.get(
+                failed_bootstrap_addr, []
+            )
+            if failed_bootstrap_addr in self.addr_to_rooms_tracker:
+                del self.addr_to_rooms_tracker[failed_bootstrap_addr]
+
+        # Mark all pending transfers associated with the failed node as failed
+        affected_rooms = []
+        for room in possible_affected_rooms:
+            if (
+                room in self.transfer_statuses
+                and not self.transfer_statuses[room].is_done()
+            ):
+                # Mark the transfer as failed by setting a special state
+                self.transfer_statuses[room].num_kvs_expected = -1  # Indicates failure
+                affected_rooms.append(room)
+
+        logger.error(
+            f"Lost connection with prefill instance (bootstrap_addr: {failed_bootstrap_addr}), "
+            f"{len(affected_rooms)} transfers affected"
+        )
+
     def check_status(self, bootstrap_room: int):
         return self.request_status[bootstrap_room]
 
@@ -593,6 +737,12 @@ class NixlKVReceiver(CommonKVReceiver):
         self.conclude_state = None
         super().__init__(mgr, bootstrap_addr, bootstrap_room, prefill_dp_rank)
 
+        # Track this room with its bootstrap address for heartbeat monitoring
+        if hasattr(self.kv_mgr, "addr_to_rooms_tracker"):
+            self.kv_mgr.addr_to_rooms_tracker[self.bootstrap_addr].add(
+                self.bootstrap_room
+            )
+
     def init(self, kv_indices: npt.NDArray[np.int32], aux_index: Optional[int] = None):
         for bootstrap_info in self.bootstrap_infos:
             logger.debug(
@@ -627,9 +777,16 @@ class NixlKVReceiver(CommonKVReceiver):
         self.kv_mgr.update_transfer_status()
 
         if self.kv_mgr.check_transfer_done(self.bootstrap_room):  # type: ignore
-            self.conclude_state = KVPoll.Success
+            # Check if the transfer failed
+            if self.kv_mgr.transfer_statuses[self.bootstrap_room].is_failed():
+                self.conclude_state = KVPoll.Failed
+                logger.error(
+                    f"Transfer for room {self.bootstrap_room} failed due to node failure"
+                )
+            else:
+                self.conclude_state = KVPoll.Success
             del self.kv_mgr.transfer_statuses[self.bootstrap_room]
-            return KVPoll.Success  # type: ignore
+            return self.conclude_state  # type: ignore
         return KVPoll.WaitingForInput  # type: ignore
 
     def _register_kv_args(self):