diff --git a/vllm_ascend/distributed/mooncake_connector.py b/vllm_ascend/distributed/mooncake_connector.py index 6cacef83..603a89b8 100644 --- a/vllm_ascend/distributed/mooncake_connector.py +++ b/vllm_ascend/distributed/mooncake_connector.py @@ -417,6 +417,11 @@ class KVCacheRecvingThread(threading.Thread): f"{request_id}: {e}", exc_info=True) finally: + if all_task_done: + self.task_tracker.update_done_task_count(request_id) + if request_id in self.proc_not_transfer_request: + del self.proc_not_transfer_request[request_id] + self.request_queue.task_done() # Always send the done signal to the remote host to ensure proper # resource cleanup. Failing to do so may cause a memory leak on the # remote host. @@ -425,11 +430,6 @@ class KVCacheRecvingThread(threading.Thread): remote_port_send_num) self._send_done_signal_to_free_remote_port(request_id, remote_host, remote_port_send_num) - if all_task_done: - self.task_tracker.update_done_task_count(request_id) - if request_id in self.proc_not_transfer_request: - del self.proc_not_transfer_request[request_id] - self.request_queue.task_done() def _send_done_signal_to_free_remote_port(self, request_id, remote_host, remote_port_send_num): @@ -698,6 +698,13 @@ class KVCacheRecvingThread(threading.Thread): request_id, remote_host, remote_handshake_port) raise RuntimeError( f"Failed to receive ACK, resp: {resp.decode('utf-8')}") + except RuntimeError as e: + if isinstance(sock, zmq.Socket): # type: ignore + sock.close() + sock = None + logger.warning( + f"Unexpected error occurred in socket, {e}, closing the original channel" + ) finally: if sock is not None: self._return_remote_socket(sock, remote_host,