[Bugfix] Prevent engine hang during KVCacheSendingThread startup (#4754)
Previously, if the KVCacheSendingThread couldn't create a socket because
of port conflicts or other problems, the main thread would wait
endlessly for the ready_event signal, causing the entire engine
initialization to freeze. This update fixes the issue by adding timeouts
for thread startup and handling unexpected thread exits, so the
initialization process no longer gets stuck indefinitely.
- vLLM version: v0.12.0
- vLLM main:
ad32e3e19c
---------
Signed-off-by: Jade Zheng <zheng.shoujian@outlook.com>
This commit is contained in:
@@ -207,66 +207,71 @@ class KVCacheSendingThread(threading.Thread):
|
|||||||
|
|
||||||
def run(self):
|
def run(self):
|
||||||
"""Run the thread to handle KV cache transfer requests."""
|
"""Run the thread to handle KV cache transfer requests."""
|
||||||
|
try:
|
||||||
|
# Listen for new requests for metadata. NOTE(rob): we need each rank
|
||||||
|
# to have a unique port. This hack to keeps us moving. We will
|
||||||
|
# switch when moving to etcd or where we have a single ZMQ socket in
|
||||||
|
# the scheduler.
|
||||||
|
device_index = self.pp_rank * self.tp_size + self.tp_rank + self.pcp_rank * self.prefill_tp_size
|
||||||
|
handshake_port = self.side_channel_port + device_index
|
||||||
|
path = make_zmq_path("tcp", self.side_channel_host, handshake_port)
|
||||||
|
logger.info("Starting listening on path: %s", path)
|
||||||
|
with zmq_ctx(zmq.ROUTER, path) as sock: # type: ignore
|
||||||
|
self.ready_event.set()
|
||||||
|
self.run_busy_loop(sock)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error("Mooncake KVCacheSendingThread exception: %s",
|
||||||
|
e,
|
||||||
|
exc_info=True)
|
||||||
|
|
||||||
|
def run_busy_loop(self, sock: zmq.Socket): # type: ignore
|
||||||
encoder = msgspec.msgpack.Encoder()
|
encoder = msgspec.msgpack.Encoder()
|
||||||
encoded_data = encoder.encode(self.metadata)
|
encoded_data = encoder.encode(self.metadata)
|
||||||
size_in_bytes = len(encoded_data)
|
size_in_bytes = len(encoded_data)
|
||||||
logger.debug("Size of encoded MooncakeAgentMetadata: %s bytes",
|
logger.debug("Size of encoded MooncakeAgentMetadata: %s bytes",
|
||||||
str(size_in_bytes))
|
str(size_in_bytes))
|
||||||
|
|
||||||
# Listen for new requests for metadata.
|
decoder = msgspec.msgpack.Decoder(type=tuple)
|
||||||
# NOTE(rob): we need each rank to have a unique port. This hack to keeps
|
while True:
|
||||||
# us moving. We will switch when moving to etcd or where we have a
|
try:
|
||||||
# single ZMQ socket in the scheduler.
|
frames = sock.recv_multipart()
|
||||||
device_index = self.pp_rank * self.tp_size + self.tp_rank + self.pcp_rank * self.prefill_tp_size
|
if len(frames) < 2:
|
||||||
handshake_port = self.side_channel_port + device_index
|
logger.error("Invalid message format: %s", frames)
|
||||||
path = make_zmq_path("tcp", self.side_channel_host, handshake_port)
|
continue
|
||||||
logger.info("Starting listening on path: %s", path)
|
|
||||||
with zmq_ctx(zmq.ROUTER, path) as sock: # type: ignore
|
|
||||||
self.ready_event.set()
|
|
||||||
decoder = msgspec.msgpack.Decoder(type=tuple)
|
|
||||||
while True:
|
|
||||||
try:
|
|
||||||
frames = sock.recv_multipart()
|
|
||||||
if len(frames) < 2:
|
|
||||||
logger.error("Invalid message format: %s", frames)
|
|
||||||
continue
|
|
||||||
|
|
||||||
identity = frames[0]
|
identity = frames[0]
|
||||||
payload = [f for f in frames[1:] if f != b""]
|
payload = [f for f in frames[1:] if f != b""]
|
||||||
if len(payload) != 1:
|
if len(payload) != 1:
|
||||||
logger.error("Invalid message format: %s", frames)
|
logger.error("Invalid message format: %s", frames)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
msg = decoder.decode(payload[0])
|
msg = decoder.decode(payload[0])
|
||||||
if msg[0] == GET_META_MSG:
|
if msg[0] == GET_META_MSG:
|
||||||
sock.send_multipart((identity, b"", encoded_data))
|
sock.send_multipart((identity, b"", encoded_data))
|
||||||
elif msg[0] == DONE_RECVING_MSG:
|
elif msg[0] == DONE_RECVING_MSG:
|
||||||
logger.debug("Got DONE_RECVING_MSG for request %s",
|
logger.debug("Got DONE_RECVING_MSG for request %s", msg[1])
|
||||||
msg[1])
|
request_id = msg[1]
|
||||||
request_id = msg[1]
|
self.task_tracker.update_done_task_count(request_id)
|
||||||
self.task_tracker.update_done_task_count(request_id)
|
# Acknowledge the request completion.
|
||||||
# Acknowledge the request completion.
|
while True:
|
||||||
while True:
|
try:
|
||||||
try:
|
# Send ACK to the sender.
|
||||||
# Send ACK to the sender.
|
sock.send_multipart(
|
||||||
sock.send_multipart(
|
(identity, b"", b"ACK"),
|
||||||
(identity, b"", b"ACK"),
|
flags=zmq.NOBLOCK) # type: ignore
|
||||||
flags=zmq.NOBLOCK) # type: ignore
|
break
|
||||||
break
|
except zmq.Again: # type: ignore
|
||||||
except zmq.Again: # type: ignore
|
# If the socket is not ready, retry sending.
|
||||||
# If the socket is not ready, retry sending.
|
logger.debug(
|
||||||
logger.debug(
|
"Socket not ready, retrying to send ACK for "
|
||||||
"Socket not ready, retrying to send ACK for "
|
"request %s", msg[1])
|
||||||
"request %s", msg[1])
|
time.sleep(0.01)
|
||||||
time.sleep(0.01)
|
else:
|
||||||
else:
|
logger.error(
|
||||||
logger.error(
|
"Connection listener got unexpected message %s", msg)
|
||||||
"Connection listener got unexpected message %s",
|
except Exception as e:
|
||||||
msg)
|
logger.error("Connection listener got exception %s: %s",
|
||||||
except Exception as e:
|
type(e), e)
|
||||||
logger.error("Connection listener got exception %s: %s",
|
|
||||||
type(e), e)
|
|
||||||
|
|
||||||
|
|
||||||
class KVCacheRecvingThread(threading.Thread):
|
class KVCacheRecvingThread(threading.Thread):
|
||||||
@@ -1151,7 +1156,18 @@ class MooncakeConnectorWorker:
|
|||||||
self.block_len, ready_event, self.vllm_config, self.kv_caches,
|
self.block_len, ready_event, self.vllm_config, self.kv_caches,
|
||||||
self._prefill_pp_layer_partition)
|
self._prefill_pp_layer_partition)
|
||||||
self.kv_recv_thread.start()
|
self.kv_recv_thread.start()
|
||||||
ready_event.wait()
|
|
||||||
|
start_wait_time = time.time()
|
||||||
|
thread = self.kv_send_thread if self.kv_role == 'kv_producer' else self.kv_recv_thread
|
||||||
|
assert thread is not None
|
||||||
|
while not ready_event.is_set():
|
||||||
|
if not thread.is_alive():
|
||||||
|
raise RuntimeError(
|
||||||
|
"KV Cache sending/receiving thread failed to start.")
|
||||||
|
if time.time() - start_wait_time > 5 * 60:
|
||||||
|
raise RuntimeError(
|
||||||
|
"Timeout waiting for KV Cache thread to be ready.")
|
||||||
|
time.sleep(3)
|
||||||
|
|
||||||
def get_finished(self) -> tuple[set[str], set[str]]:
|
def get_finished(self) -> tuple[set[str], set[str]]:
|
||||||
done_sending = (
|
done_sending = (
|
||||||
|
|||||||
Reference in New Issue
Block a user