diff --git a/vllm_ascend/distributed/kv_transfer/kv_p2p/mooncake_layerwise_connector.py b/vllm_ascend/distributed/kv_transfer/kv_p2p/mooncake_layerwise_connector.py index 4288b0d1..2d28f1e8 100644 --- a/vllm_ascend/distributed/kv_transfer/kv_p2p/mooncake_layerwise_connector.py +++ b/vllm_ascend/distributed/kv_transfer/kv_p2p/mooncake_layerwise_connector.py @@ -541,12 +541,12 @@ class KVCacheRecvingLayerThread(threading.Thread): self.done_requests = set() return finished_requests - def update_task(self, req_id, trans_count): + def update_task(self, req_id, trans_count, side_channel_path): with self.lock: if req_id not in self.task_tracker: - self.task_tracker[req_id] = 0 - self.task_tracker[req_id] += 1 - if self.task_tracker[req_id] == trans_count: + self.task_tracker[req_id] = set() + self.task_tracker[req_id].add(side_channel_path) + if len(self.task_tracker[req_id]) == trans_count: self.task_tracker.pop(req_id) self.done_requests.add(req_id) @@ -581,7 +581,8 @@ class KVCacheRecvingLayerThread(threading.Thread): logger.debug("Got DONE_RECVING_MSG for request %s", msg[1]) request_id = msg[1] trans_count = msg[2] - self.update_task(request_id, trans_count) + side_channel_path = msg[3] + self.update_task(request_id, trans_count, side_channel_path) sock.send_multipart((identity, b"", b"ACK")) else: logger.error("Connection listener got unexpected message %s", msg) @@ -1760,24 +1761,39 @@ class MooncakeLayerwiseConnectorWorker: try: path = make_zmq_path("tcp", req_meta.remote_host, req_meta.remote_port) msg_encoder = msgspec.msgpack.Encoder() - encoded_data = msg_encoder.encode((DONE_SENDING_MSG, external_req_id, req_meta.trans_count[group_idx])) - with zmq_ctx(zmq.REQ, path) as sock: # type: ignore - ensure_zmq_send(sock, encoded_data, f"{req_meta.remote_host}:{req_meta.remote_port}") - # Avoid blocking forever waiting for the REQ/ACK response. - sock.setsockopt(zmq.RCVTIMEO, int(self.timeout * 1000)) # type: ignore + side_channel_path = f"{self.side_channel_host}:{self.handshake_port}" + encoded_data = msg_encoder.encode( + (DONE_SENDING_MSG, external_req_id, req_meta.trans_count[group_idx], side_channel_path) + ) + max_retries = 3 + for attempt in range(1, max_retries + 1): try: - ack = sock.recv() - except zmq.Again: # type: ignore - logger.warning( - "Timeout waiting ACK for request %s from %s:%d (timeout=%.3fs)", - external_req_id, - req_meta.remote_host, - req_meta.remote_port, - self.timeout, - ) - return - if ack != b"ACK": - raise ValueError(f"Unexpected ACK response: {ack}") + with zmq_ctx(zmq.REQ, path) as sock: # type: ignore + sock.setsockopt(zmq.SNDTIMEO, int(self.timeout * 1000)) + ensure_zmq_send(sock, encoded_data, f"{req_meta.remote_host}:{req_meta.remote_port}") + if not sock.poll(int(self.timeout * 1000), zmq.POLLIN): # type: ignore[attr-defined] + raise TimeoutError( + f"Timed out waiting for ACK from {req_meta.remote_host}:{req_meta.remote_port}" + ) + ack = sock.recv() + if ack != b"ACK": + raise ValueError(f"Unexpected ACK response: {ack}") + return + except Exception as e: + if attempt < max_retries: + logger.warning( + "Failed to send done sending signal for request %s to %s:%d on attempt %d/%d: %s. " + "Retrying...", + external_req_id, + req_meta.remote_host, + req_meta.remote_port, + attempt, + max_retries, + e, + ) + time.sleep(0.1) + else: + raise RuntimeError(f"Failed to receive ACK after {max_retries} attempts: {e}") from e except Exception as e: logger.error( f"Sending done sending signal for request {external_req_id} to "