Feat/add heartbeat mechanism for nixl conn (#10222)
Signed-off-by: Shahar Mor <smor@nvidia.com>
This commit is contained in:
@@ -2,14 +2,17 @@ from __future__ import annotations
|
|||||||
|
|
||||||
import dataclasses
|
import dataclasses
|
||||||
import logging
|
import logging
|
||||||
|
import os
|
||||||
import struct
|
import struct
|
||||||
import threading
|
import threading
|
||||||
|
import time
|
||||||
import uuid
|
import uuid
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from typing import Dict, List, Optional, Set
|
from typing import Dict, List, Optional, Set
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import numpy.typing as npt
|
import numpy.typing as npt
|
||||||
|
import requests
|
||||||
|
|
||||||
from sglang.srt.disaggregation.base.conn import KVArgs, KVPoll
|
from sglang.srt.disaggregation.base.conn import KVArgs, KVPoll
|
||||||
from sglang.srt.disaggregation.common.conn import (
|
from sglang.srt.disaggregation.common.conn import (
|
||||||
@@ -21,6 +24,7 @@ from sglang.srt.disaggregation.common.conn import (
|
|||||||
from sglang.srt.disaggregation.common.utils import group_concurrent_contiguous
|
from sglang.srt.disaggregation.common.utils import group_concurrent_contiguous
|
||||||
from sglang.srt.disaggregation.utils import DisaggregationMode
|
from sglang.srt.disaggregation.utils import DisaggregationMode
|
||||||
from sglang.srt.server_args import ServerArgs
|
from sglang.srt.server_args import ServerArgs
|
||||||
|
from sglang.srt.utils import get_int_env_var
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -102,8 +106,14 @@ class TransferStatus:
|
|||||||
def is_done(self):
    """Return whether this transfer has concluded, successfully or not.

    * ``num_kvs_expected is None`` -- the expected count is not known yet,
      so the transfer cannot be done.
    * ``num_kvs_expected == -1`` -- failure sentinel; a failed transfer is
      reported as done (``is_failed`` distinguishes the two outcomes).
    * otherwise -- done once the number of received KV entries equals the
      expected count and ``received_aux`` is set.
    """
    expected = self.num_kvs_expected
    if expected is None:
        return False
    if expected == -1:
        # Failed transfers are considered "done"
        return True
    return expected == len(self.received_kvs) and self.received_aux
|
||||||
|
|
||||||
|
def is_failed(self):
    """Return True when the transfer carries the failure sentinel.

    A transfer is marked failed by setting ``num_kvs_expected`` to ``-1``.
    """
    failure_sentinel = -1
    return self.num_kvs_expected == failure_sentinel
|
||||||
|
|
||||||
|
|
||||||
class NixlKVManager(CommonKVManager):
|
class NixlKVManager(CommonKVManager):
|
||||||
def __init__(
|
def __init__(
|
||||||
@@ -131,11 +141,125 @@ class NixlKVManager(CommonKVManager):
|
|||||||
self.transfer_statuses: Dict[int, TransferStatus] = defaultdict(
|
self.transfer_statuses: Dict[int, TransferStatus] = defaultdict(
|
||||||
TransferStatus
|
TransferStatus
|
||||||
)
|
)
|
||||||
|
self.heartbeat_failures = {}
|
||||||
|
self.session_pool = defaultdict(requests.Session)
|
||||||
|
self.session_pool_lock = threading.Lock()
|
||||||
|
self.addr_to_rooms_tracker = defaultdict(set)
|
||||||
|
self.connection_lock = threading.Lock()
|
||||||
|
|
||||||
|
# Heartbeat interval should be at least 2 seconds
|
||||||
|
self.heartbeat_interval = max(
|
||||||
|
float(os.getenv("SGLANG_DISAGGREGATION_HEARTBEAT_INTERVAL", 5.0)), 2.0
|
||||||
|
)
|
||||||
|
# Heartbeat failure should be at least 1
|
||||||
|
self.max_failures = max(
|
||||||
|
get_int_env_var("SGLANG_DISAGGREGATION_HEARTBEAT_MAX_FAILURE", 2), 1
|
||||||
|
)
|
||||||
|
self._start_heartbeat_checker_thread()
|
||||||
else:
|
else:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Unsupported DisaggregationMode: {self.disaggregation_mode}"
|
f"Unsupported DisaggregationMode: {self.disaggregation_mode}"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def _start_heartbeat_checker_thread(self):
    """
    Start the heartbeat checker thread for Decode worker.

    A daemon thread sleeps ``self.heartbeat_interval`` seconds, snapshots the
    addresses in ``self.prefill_dp_size_table`` under ``self.connection_lock``
    and issues ``GET http://<bootstrap_addr>/health`` to each of them:

    * HTTP 200 resets the address' failure counter and prunes rooms that no
      longer have an entry in ``self.transfer_statuses`` from the tracker.
    * Any other status code, or any exception, increments the counter.
    * Once the counter reaches ``self.max_failures`` the address is handed
      to ``self._handle_node_failure``.

    TODO (smor): unite nixl heartbeat checker with mooncake's.
    """

    def drop_session(bootstrap_addr):
        # Remove the pooled session AND close it; merely deleting the dict
        # entry leaves its connection pool open until garbage collection.
        with self.session_pool_lock:
            stale_session = self.session_pool.pop(bootstrap_addr, None)
        if stale_session is not None:
            stale_session.close()

    def record_failure(bootstrap_addr):
        logger.info(f"Attempting to reconnect to {bootstrap_addr}...")
        self.heartbeat_failures[bootstrap_addr] = (
            self.heartbeat_failures.get(bootstrap_addr, 0) + 1
        )

    def prune_finished_rooms(bootstrap_addr):
        # Rooms whose transfer status entry is gone have concluded and no
        # longer need to be tracked. Iterate over a copy because the set is
        # mutated inside the loop.
        tracked_rooms = self.addr_to_rooms_tracker[bootstrap_addr]
        for bootstrap_room in tracked_rooms.copy():
            if bootstrap_room not in self.transfer_statuses:
                tracked_rooms.discard(bootstrap_room)

    def check_once(bootstrap_addr):
        try:
            with self.session_pool_lock:
                session = self.session_pool[bootstrap_addr]
            response = session.get(
                f"http://{bootstrap_addr}/health",
                timeout=(2, 3),
                headers={"Connection": "keep-alive"},
            )
            if response.status_code == 200:
                self.heartbeat_failures[bootstrap_addr] = 0
                prune_finished_rooms(bootstrap_addr)
            else:
                record_failure(bootstrap_addr)
                drop_session(bootstrap_addr)
        except Exception:
            # Deliberately broad: the checker thread must survive any error.
            # Keep the cause available at debug level instead of losing it.
            logger.debug(
                "Heartbeat request to %s raised", bootstrap_addr, exc_info=True
            )
            record_failure(bootstrap_addr)

        if self.heartbeat_failures.get(bootstrap_addr, 0) >= self.max_failures:
            self._handle_node_failure(bootstrap_addr)
            drop_session(bootstrap_addr)
            # Forget the counter so that an address which registers again
            # later gets the full ``max_failures`` budget, and so the dict
            # does not accumulate entries for dead nodes.
            self.heartbeat_failures.pop(bootstrap_addr, None)

    def heartbeat_checker():
        while True:
            time.sleep(self.heartbeat_interval)
            # Snapshot under the lock; the HTTP requests run without it.
            with self.connection_lock:
                addresses = list(self.prefill_dp_size_table.keys())

            for bootstrap_addr in addresses:
                check_once(bootstrap_addr)

    threading.Thread(target=heartbeat_checker, daemon=True).start()
|
||||||
|
|
||||||
|
def _handle_node_failure(self, failed_bootstrap_addr):
    """Handle failure of a prefill node.

    Drops the bookkeeping keyed by ``failed_bootstrap_addr`` (matching
    ``connection_pool`` entries, the TP/DP/PP size tables and the room
    tracker entry) and marks every still-pending transfer that was tracked
    for that address as failed by setting ``num_kvs_expected`` to ``-1``.

    Args:
        failed_bootstrap_addr: Bootstrap address of the prefill node that is
            considered lost.
    """
    with self.connection_lock:
        # NOTE(review): plain prefix match -- presumably pool keys start with
        # "<bootstrap_addr>" followed by a separator; confirm that no other
        # address can share this prefix (e.g. port 80 vs 8000).
        stale_keys = [
            k for k in self.connection_pool if k.startswith(failed_bootstrap_addr)
        ]
        for k in stale_keys:
            del self.connection_pool[k]

        for size_table in (
            self.prefill_tp_size_table,
            self.prefill_dp_size_table,
            self.prefill_pp_size_table,
        ):
            size_table.pop(failed_bootstrap_addr, None)

        # Detach the tracked rooms in one step and snapshot them, so later
        # additions by other threads cannot disturb the iteration below.
        possible_affected_rooms = tuple(
            self.addr_to_rooms_tracker.pop(failed_bootstrap_addr, ())
        )

    # Mark all pending transfers associated with the failed node as failed
    affected_rooms = []
    for room in possible_affected_rooms:
        # ``.get`` never triggers the defaultdict factory, so a status entry
        # that was deleted concurrently is not recreated here.
        status = self.transfer_statuses.get(room)
        if status is not None and not status.is_done():
            status.num_kvs_expected = -1  # Indicates failure
            affected_rooms.append(room)

    logger.error(
        f"Lost connection with prefill instance (bootstrap_addr: {failed_bootstrap_addr}), "
        f"{len(affected_rooms)} transfers affected"
    )
|
||||||
|
|
||||||
def check_status(self, bootstrap_room: int):
    """Return the entry stored in ``self.request_status`` for ``bootstrap_room``."""
    room_status = self.request_status[bootstrap_room]
    return room_status
|
||||||
|
|
||||||
@@ -593,6 +717,12 @@ class NixlKVReceiver(CommonKVReceiver):
|
|||||||
self.conclude_state = None
|
self.conclude_state = None
|
||||||
super().__init__(mgr, bootstrap_addr, bootstrap_room, prefill_dp_rank)
|
super().__init__(mgr, bootstrap_addr, bootstrap_room, prefill_dp_rank)
|
||||||
|
|
||||||
|
# Track this room with its bootstrap address for heartbeat monitoring
|
||||||
|
if hasattr(self.kv_mgr, "addr_to_rooms_tracker"):
|
||||||
|
self.kv_mgr.addr_to_rooms_tracker[self.bootstrap_addr].add(
|
||||||
|
self.bootstrap_room
|
||||||
|
)
|
||||||
|
|
||||||
def init(self, kv_indices: npt.NDArray[np.int32], aux_index: Optional[int] = None):
|
def init(self, kv_indices: npt.NDArray[np.int32], aux_index: Optional[int] = None):
|
||||||
for bootstrap_info in self.bootstrap_infos:
|
for bootstrap_info in self.bootstrap_infos:
|
||||||
logger.debug(
|
logger.debug(
|
||||||
@@ -627,9 +757,16 @@ class NixlKVReceiver(CommonKVReceiver):
|
|||||||
|
|
||||||
self.kv_mgr.update_transfer_status()
|
self.kv_mgr.update_transfer_status()
|
||||||
if self.kv_mgr.check_transfer_done(self.bootstrap_room): # type: ignore
|
if self.kv_mgr.check_transfer_done(self.bootstrap_room): # type: ignore
|
||||||
self.conclude_state = KVPoll.Success
|
# Check if the transfer failed
|
||||||
|
if self.kv_mgr.transfer_statuses[self.bootstrap_room].is_failed():
|
||||||
|
self.conclude_state = KVPoll.Failed
|
||||||
|
logger.error(
|
||||||
|
f"Transfer for room {self.bootstrap_room} failed due to node failure"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self.conclude_state = KVPoll.Success
|
||||||
del self.kv_mgr.transfer_statuses[self.bootstrap_room]
|
del self.kv_mgr.transfer_statuses[self.bootstrap_room]
|
||||||
return KVPoll.Success # type: ignore
|
return self.conclude_state # type: ignore
|
||||||
return KVPoll.WaitingForInput # type: ignore
|
return KVPoll.WaitingForInput # type: ignore
|
||||||
|
|
||||||
def _register_kv_args(self):
|
def _register_kv_args(self):
|
||||||
|
|||||||
Reference in New Issue
Block a user