diff --git a/vllm_ascend/distributed/mooncake_connector.py b/vllm_ascend/distributed/mooncake_connector.py
index 7d45c475..1d37498f 100644
--- a/vllm_ascend/distributed/mooncake_connector.py
+++ b/vllm_ascend/distributed/mooncake_connector.py
@@ -207,66 +207,71 @@ class KVCacheSendingThread(threading.Thread):
 
     def run(self):
         """Run the thread to handle KV cache transfer requests."""
+        try:
+            # Listen for new requests for metadata. NOTE(rob): we need each rank
+            # to have a unique port. This hack keeps us moving. We will
+            # switch when moving to etcd or where we have a single ZMQ socket in
+            # the scheduler.
+            device_index = self.pp_rank * self.tp_size + self.tp_rank + self.pcp_rank * self.prefill_tp_size
+            handshake_port = self.side_channel_port + device_index
+            path = make_zmq_path("tcp", self.side_channel_host, handshake_port)
+            logger.info("Starting listening on path: %s", path)
+            with zmq_ctx(zmq.ROUTER, path) as sock:  # type: ignore
+                self.ready_event.set()
+                self.run_busy_loop(sock)
+        except Exception as e:
+            logger.error("Mooncake KVCacheSendingThread exception: %s",
+                         e,
+                         exc_info=True)
+
+    def run_busy_loop(self, sock: zmq.Socket):  # type: ignore
         encoder = msgspec.msgpack.Encoder()
         encoded_data = encoder.encode(self.metadata)
         size_in_bytes = len(encoded_data)
         logger.debug("Size of encoded MooncakeAgentMetadata: %s bytes",
                      str(size_in_bytes))
 
-        # Listen for new requests for metadata.
-        # NOTE(rob): we need each rank to have a unique port. This hack to keeps
-        # us moving. We will switch when moving to etcd or where we have a
-        # single ZMQ socket in the scheduler.
-        device_index = self.pp_rank * self.tp_size + self.tp_rank + self.pcp_rank * self.prefill_tp_size
-        handshake_port = self.side_channel_port + device_index
-        path = make_zmq_path("tcp", self.side_channel_host, handshake_port)
-        logger.info("Starting listening on path: %s", path)
-        with zmq_ctx(zmq.ROUTER, path) as sock:  # type: ignore
-            self.ready_event.set()
-            decoder = msgspec.msgpack.Decoder(type=tuple)
-            while True:
-                try:
-                    frames = sock.recv_multipart()
-                    if len(frames) < 2:
-                        logger.error("Invalid message format: %s", frames)
-                        continue
+        decoder = msgspec.msgpack.Decoder(type=tuple)
+        while True:
+            try:
+                frames = sock.recv_multipart()
+                if len(frames) < 2:
+                    logger.error("Invalid message format: %s", frames)
+                    continue
 
-                    identity = frames[0]
-                    payload = [f for f in frames[1:] if f != b""]
-                    if len(payload) != 1:
-                        logger.error("Invalid message format: %s", frames)
-                        continue
+                identity = frames[0]
+                payload = [f for f in frames[1:] if f != b""]
+                if len(payload) != 1:
+                    logger.error("Invalid message format: %s", frames)
+                    continue
 
-                    msg = decoder.decode(payload[0])
-                    if msg[0] == GET_META_MSG:
-                        sock.send_multipart((identity, b"", encoded_data))
-                    elif msg[0] == DONE_RECVING_MSG:
-                        logger.debug("Got DONE_RECVING_MSG for request %s",
-                                     msg[1])
-                        request_id = msg[1]
-                        self.task_tracker.update_done_task_count(request_id)
-                        # Acknowledge the request completion.
-                        while True:
-                            try:
-                                # Send ACK to the sender.
-                                sock.send_multipart(
-                                    (identity, b"", b"ACK"),
-                                    flags=zmq.NOBLOCK)  # type: ignore
-                                break
-                            except zmq.Again:  # type: ignore
-                                # If the socket is not ready, retry sending.
-                                logger.debug(
-                                    "Socket not ready, retrying to send ACK for "
-                                    "request %s", msg[1])
-                                time.sleep(0.01)
-                    else:
-                        logger.error(
-                            "Connection listener got unexpected message %s",
-                            msg)
-                except Exception as e:
-                    logger.error("Connection listener got exception %s: %s",
-                                 type(e), e)
+                msg = decoder.decode(payload[0])
+                if msg[0] == GET_META_MSG:
+                    sock.send_multipart((identity, b"", encoded_data))
+                elif msg[0] == DONE_RECVING_MSG:
+                    logger.debug("Got DONE_RECVING_MSG for request %s", msg[1])
+                    request_id = msg[1]
+                    self.task_tracker.update_done_task_count(request_id)
+                    # Acknowledge the request completion.
+                    while True:
+                        try:
+                            # Send ACK to the sender.
+                            sock.send_multipart(
+                                (identity, b"", b"ACK"),
+                                flags=zmq.NOBLOCK)  # type: ignore
+                            break
+                        except zmq.Again:  # type: ignore
+                            # If the socket is not ready, retry sending.
+                            logger.debug(
+                                "Socket not ready, retrying to send ACK for "
+                                "request %s", msg[1])
+                            time.sleep(0.01)
+                else:
+                    logger.error(
+                        "Connection listener got unexpected message %s", msg)
+            except Exception as e:
+                logger.error("Connection listener got exception %s: %s",
+                             type(e), e)
 
 
 class KVCacheRecvingThread(threading.Thread):
@@ -1151,7 +1156,20 @@ class MooncakeConnectorWorker:
                 self.block_len, ready_event, self.vllm_config, self.kv_caches,
                 self._prefill_pp_layer_partition)
             self.kv_recv_thread.start()
-        ready_event.wait()
+
+        # Block until the worker thread signals readiness. Event.wait() returns
+        # as soon as the event is set; the short timeout only exists so that a
+        # thread which died during start-up is noticed instead of hanging.
+        deadline = time.monotonic() + 5 * 60
+        thread = self.kv_send_thread if self.kv_role == 'kv_producer' else self.kv_recv_thread
+        assert thread is not None
+        while not ready_event.wait(timeout=1):
+            if not thread.is_alive():
+                raise RuntimeError(
+                    "KV Cache sending/receiving thread failed to start.")
+            if time.monotonic() > deadline:
+                raise RuntimeError(
+                    "Timeout waiting for KV Cache thread to be ready.")
 
     def get_finished(self) -> tuple[set[str], set[str]]:
         done_sending = (