[Bugfix] Prevent engine hang during KVCacheSendingThread startup (#4754)

Previously, if the KVCacheSendingThread couldn't create a socket because
of port conflicts or other problems, the main thread would wait
endlessly for the ready_event signal, causing the entire engine
initialization to freeze. This update fixes the issue by adding timeouts
for thread startup and handling unexpected thread exits, so the
initialization process no longer gets stuck indefinitely.

- vLLM version: v0.12.0
- vLLM main:
ad32e3e19c

---------

Signed-off-by: Jade Zheng <zheng.shoujian@outlook.com>
This commit is contained in:
Jade Zheng
2025-12-11 18:39:25 +08:00
committed by GitHub
parent 18221c0e1d
commit 3fade30275

View File

@@ -207,23 +207,30 @@ class KVCacheSendingThread(threading.Thread):
def run(self): def run(self):
"""Run the thread to handle KV cache transfer requests.""" """Run the thread to handle KV cache transfer requests."""
try:
encoder = msgspec.msgpack.Encoder() # Listen for new requests for metadata. NOTE(rob): we need each rank
encoded_data = encoder.encode(self.metadata) # to have a unique port. This hack to keeps us moving. We will
size_in_bytes = len(encoded_data) # switch when moving to etcd or where we have a single ZMQ socket in
logger.debug("Size of encoded MooncakeAgentMetadata: %s bytes", # the scheduler.
str(size_in_bytes))
# Listen for new requests for metadata.
# NOTE(rob): we need each rank to have a unique port. This hack to keeps
# us moving. We will switch when moving to etcd or where we have a
# single ZMQ socket in the scheduler.
device_index = self.pp_rank * self.tp_size + self.tp_rank + self.pcp_rank * self.prefill_tp_size device_index = self.pp_rank * self.tp_size + self.tp_rank + self.pcp_rank * self.prefill_tp_size
handshake_port = self.side_channel_port + device_index handshake_port = self.side_channel_port + device_index
path = make_zmq_path("tcp", self.side_channel_host, handshake_port) path = make_zmq_path("tcp", self.side_channel_host, handshake_port)
logger.info("Starting listening on path: %s", path) logger.info("Starting listening on path: %s", path)
with zmq_ctx(zmq.ROUTER, path) as sock: # type: ignore with zmq_ctx(zmq.ROUTER, path) as sock: # type: ignore
self.ready_event.set() self.ready_event.set()
self.run_busy_loop(sock)
except Exception as e:
logger.error("Mooncake KVCacheSendingThread exception: %s",
e,
exc_info=True)
def run_busy_loop(self, sock: zmq.Socket): # type: ignore
encoder = msgspec.msgpack.Encoder()
encoded_data = encoder.encode(self.metadata)
size_in_bytes = len(encoded_data)
logger.debug("Size of encoded MooncakeAgentMetadata: %s bytes",
str(size_in_bytes))
decoder = msgspec.msgpack.Decoder(type=tuple) decoder = msgspec.msgpack.Decoder(type=tuple)
while True: while True:
try: try:
@@ -242,8 +249,7 @@ class KVCacheSendingThread(threading.Thread):
if msg[0] == GET_META_MSG: if msg[0] == GET_META_MSG:
sock.send_multipart((identity, b"", encoded_data)) sock.send_multipart((identity, b"", encoded_data))
elif msg[0] == DONE_RECVING_MSG: elif msg[0] == DONE_RECVING_MSG:
logger.debug("Got DONE_RECVING_MSG for request %s", logger.debug("Got DONE_RECVING_MSG for request %s", msg[1])
msg[1])
request_id = msg[1] request_id = msg[1]
self.task_tracker.update_done_task_count(request_id) self.task_tracker.update_done_task_count(request_id)
# Acknowledge the request completion. # Acknowledge the request completion.
@@ -262,8 +268,7 @@ class KVCacheSendingThread(threading.Thread):
time.sleep(0.01) time.sleep(0.01)
else: else:
logger.error( logger.error(
"Connection listener got unexpected message %s", "Connection listener got unexpected message %s", msg)
msg)
except Exception as e: except Exception as e:
logger.error("Connection listener got exception %s: %s", logger.error("Connection listener got exception %s: %s",
type(e), e) type(e), e)
@@ -1151,7 +1156,18 @@ class MooncakeConnectorWorker:
self.block_len, ready_event, self.vllm_config, self.kv_caches, self.block_len, ready_event, self.vllm_config, self.kv_caches,
self._prefill_pp_layer_partition) self._prefill_pp_layer_partition)
self.kv_recv_thread.start() self.kv_recv_thread.start()
ready_event.wait()
start_wait_time = time.time()
thread = self.kv_send_thread if self.kv_role == 'kv_producer' else self.kv_recv_thread
assert thread is not None
while not ready_event.is_set():
if not thread.is_alive():
raise RuntimeError(
"KV Cache sending/receiving thread failed to start.")
if time.time() - start_wait_time > 5 * 60:
raise RuntimeError(
"Timeout waiting for KV Cache thread to be ready.")
time.sleep(3)
def get_finished(self) -> tuple[set[str], set[str]]: def get_finished(self) -> tuple[set[str], set[str]]:
done_sending = ( done_sending = (