Feat/add heartbeat mechanism for nixl conn (#10222)
Signed-off-by: Shahar Mor <smor@nvidia.com>
This commit is contained in:
@@ -2,14 +2,17 @@ from __future__ import annotations
|
||||
|
||||
import dataclasses
|
||||
import logging
|
||||
import os
|
||||
import struct
|
||||
import threading
|
||||
import time
|
||||
import uuid
|
||||
from collections import defaultdict
|
||||
from typing import Dict, List, Optional, Set
|
||||
|
||||
import numpy as np
|
||||
import numpy.typing as npt
|
||||
import requests
|
||||
|
||||
from sglang.srt.disaggregation.base.conn import KVArgs, KVPoll
|
||||
from sglang.srt.disaggregation.common.conn import (
|
||||
@@ -21,6 +24,7 @@ from sglang.srt.disaggregation.common.conn import (
|
||||
from sglang.srt.disaggregation.common.utils import group_concurrent_contiguous
|
||||
from sglang.srt.disaggregation.utils import DisaggregationMode
|
||||
from sglang.srt.server_args import ServerArgs
|
||||
from sglang.srt.utils import get_int_env_var
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -102,8 +106,14 @@ class TransferStatus:
|
||||
def is_done(self):
    """Report whether this transfer has reached a terminal state.

    Returns ``False`` while ``num_kvs_expected`` is still ``None``. Returns
    ``True`` when ``num_kvs_expected`` holds the failure marker ``-1``.
    Otherwise the result is truthy only when ``len(self.received_kvs)``
    equals ``num_kvs_expected`` and ``received_aux`` is set.
    """
    expected = self.num_kvs_expected
    if expected is None:
        # The expected count has not been set yet.
        return False
    if expected == -1:
        # Failure marker: a failed transfer also counts as "done".
        return True
    return expected == len(self.received_kvs) and self.received_aux
|
||||
|
||||
def is_failed(self):
    """Return ``True`` when ``num_kvs_expected`` holds the failure marker ``-1``."""
    failure_marker = -1
    return self.num_kvs_expected == failure_marker
|
||||
|
||||
|
||||
class NixlKVManager(CommonKVManager):
|
||||
def __init__(
|
||||
@@ -131,11 +141,125 @@ class NixlKVManager(CommonKVManager):
|
||||
self.transfer_statuses: Dict[int, TransferStatus] = defaultdict(
|
||||
TransferStatus
|
||||
)
|
||||
self.heartbeat_failures = {}
|
||||
self.session_pool = defaultdict(requests.Session)
|
||||
self.session_pool_lock = threading.Lock()
|
||||
self.addr_to_rooms_tracker = defaultdict(set)
|
||||
self.connection_lock = threading.Lock()
|
||||
|
||||
# Heartbeat interval should be at least 2 seconds
|
||||
self.heartbeat_interval = max(
|
||||
float(os.getenv("SGLANG_DISAGGREGATION_HEARTBEAT_INTERVAL", 5.0)), 2.0
|
||||
)
|
||||
# Heartbeat failure should be at least 1
|
||||
self.max_failures = max(
|
||||
get_int_env_var("SGLANG_DISAGGREGATION_HEARTBEAT_MAX_FAILURE", 2), 1
|
||||
)
|
||||
self._start_heartbeat_checker_thread()
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unsupported DisaggregationMode: {self.disaggregation_mode}"
|
||||
)
|
||||
|
||||
def _start_heartbeat_checker_thread(self):
    """
    Start the heartbeat checker thread for Decode worker.

    A daemon thread wakes every ``self.heartbeat_interval`` seconds and
    sends ``GET http://<bootstrap_addr>/health`` to every address currently
    present in ``self.prefill_dp_size_table``.

    * A 200 response resets the address's failure counter and drops rooms
      from ``self.addr_to_rooms_tracker`` that no longer have an entry in
      ``self.transfer_statuses``.
    * Any other status code, or any exception raised by the probe,
      increments the failure counter.
    * Once the counter reaches ``self.max_failures`` the address is handed
      to ``self._handle_node_failure`` and its HTTP session is discarded.

    TODO (smor): unite nixl heartbeat checker with mooncake's.
    """

    def drop_session(bootstrap_addr):
        # Remove the pooled session AND close it. Deleting the dict entry
        # alone leaves the session's pooled HTTP connections open until
        # garbage collection.
        with self.session_pool_lock:
            stale_session = self.session_pool.pop(bootstrap_addr, None)
        if stale_session is not None:
            stale_session.close()

    def record_failure(bootstrap_addr):
        logger.info(f"Attempting to reconnect to {bootstrap_addr}...")
        self.heartbeat_failures[bootstrap_addr] = (
            self.heartbeat_failures.get(bootstrap_addr, 0) + 1
        )

    def prune_finished_rooms(bootstrap_addr):
        # Iterate over a copy because the tracked set is mutated in the loop.
        tracked_rooms = self.addr_to_rooms_tracker[bootstrap_addr]
        for bootstrap_room in tracked_rooms.copy():
            # Remove successful transfers from the tracker
            if bootstrap_room not in self.transfer_statuses:
                tracked_rooms.discard(bootstrap_room)

    def heartbeat_checker():
        while True:
            time.sleep(self.heartbeat_interval)
            # Snapshot the addresses so the lock is not held during HTTP I/O.
            with self.connection_lock:
                addresses = list(self.prefill_dp_size_table)

            for bootstrap_addr in addresses:
                try:
                    with self.session_pool_lock:
                        session = self.session_pool[bootstrap_addr]
                    response = session.get(
                        f"http://{bootstrap_addr}/health",
                        timeout=(2, 3),
                        headers={"Connection": "keep-alive"},
                    )
                    if response.status_code == 200:
                        self.heartbeat_failures[bootstrap_addr] = 0
                        prune_finished_rooms(bootstrap_addr)
                    else:
                        record_failure(bootstrap_addr)
                        drop_session(bootstrap_addr)
                except Exception:
                    # Best-effort probe: any error counts as one failure.
                    # Keep the traceback at debug level for diagnosis.
                    logger.debug(
                        "Heartbeat probe to %s raised",
                        bootstrap_addr,
                        exc_info=True,
                    )
                    record_failure(bootstrap_addr)

                if (
                    self.heartbeat_failures.get(bootstrap_addr, 0)
                    >= self.max_failures
                ):
                    self._handle_node_failure(bootstrap_addr)
                    drop_session(bootstrap_addr)

    threading.Thread(target=heartbeat_checker, daemon=True).start()
|
||||
|
||||
def _handle_node_failure(self, failed_bootstrap_addr):
    """Handle failure of a prefill node.

    Removes every record kept for ``failed_bootstrap_addr``: matching
    ``connection_pool`` entries, the tp/dp/pp size-table entries, the
    tracked room set, and the heartbeat failure counter. Every tracked room
    whose transfer is not yet done is then marked as failed by setting its
    ``num_kvs_expected`` to ``-1``.

    Args:
        failed_bootstrap_addr: Bootstrap address of the prefill instance
            that is considered lost.
    """
    with self.connection_lock:
        # NOTE(review): this is a prefix match, so an address that is a
        # string prefix of another would evict both. Confirm the
        # connection_pool key format against the code that populates it.
        stale_keys = [
            k for k in self.connection_pool if k.startswith(failed_bootstrap_addr)
        ]
        for k in stale_keys:
            del self.connection_pool[k]

        for size_table in (
            self.prefill_tp_size_table,
            self.prefill_dp_size_table,
            self.prefill_pp_size_table,
        ):
            size_table.pop(failed_bootstrap_addr, None)

        possible_affected_rooms = self.addr_to_rooms_tracker.pop(
            failed_bootstrap_addr, ()
        )

    # Forget the failure count. Otherwise an address that is registered
    # again later would start at the threshold and be declared dead on its
    # first failed probe.
    self.heartbeat_failures.pop(failed_bootstrap_addr, None)

    # Mark all pending transfers associated with the failed node as failed
    affected_rooms = []
    for room in possible_affected_rooms:
        # Use one .get() lookup. transfer_statuses is a defaultdict, so a
        # membership test followed by indexing could recreate an entry that
        # was deleted in between.
        status = self.transfer_statuses.get(room)
        if status is not None and not status.is_done():
            # Mark the transfer as failed by setting a special state
            status.num_kvs_expected = -1  # Indicates failure
            affected_rooms.append(room)

    logger.error(
        f"Lost connection with prefill instance (bootstrap_addr: {failed_bootstrap_addr}), "
        f"{len(affected_rooms)} transfers affected"
    )
|
||||
|
||||
def check_status(self, bootstrap_room: int):
    """Look up the entry stored in ``self.request_status`` for ``bootstrap_room``."""
    status_by_room = self.request_status
    return status_by_room[bootstrap_room]
|
||||
|
||||
@@ -593,6 +717,12 @@ class NixlKVReceiver(CommonKVReceiver):
|
||||
self.conclude_state = None
|
||||
super().__init__(mgr, bootstrap_addr, bootstrap_room, prefill_dp_rank)
|
||||
|
||||
# Track this room with its bootstrap address for heartbeat monitoring
|
||||
if hasattr(self.kv_mgr, "addr_to_rooms_tracker"):
|
||||
self.kv_mgr.addr_to_rooms_tracker[self.bootstrap_addr].add(
|
||||
self.bootstrap_room
|
||||
)
|
||||
|
||||
def init(self, kv_indices: npt.NDArray[np.int32], aux_index: Optional[int] = None):
|
||||
for bootstrap_info in self.bootstrap_infos:
|
||||
logger.debug(
|
||||
@@ -627,9 +757,16 @@ class NixlKVReceiver(CommonKVReceiver):
|
||||
|
||||
self.kv_mgr.update_transfer_status()
|
||||
if self.kv_mgr.check_transfer_done(self.bootstrap_room): # type: ignore
|
||||
self.conclude_state = KVPoll.Success
|
||||
# Check if the transfer failed
|
||||
if self.kv_mgr.transfer_statuses[self.bootstrap_room].is_failed():
|
||||
self.conclude_state = KVPoll.Failed
|
||||
logger.error(
|
||||
f"Transfer for room {self.bootstrap_room} failed due to node failure"
|
||||
)
|
||||
else:
|
||||
self.conclude_state = KVPoll.Success
|
||||
del self.kv_mgr.transfer_statuses[self.bootstrap_room]
|
||||
return KVPoll.Success # type: ignore
|
||||
return self.conclude_state # type: ignore
|
||||
return KVPoll.WaitingForInput # type: ignore
|
||||
|
||||
def _register_kv_args(self):
|
||||
|
||||
Reference in New Issue
Block a user