diff --git a/vllm_ascend/distributed/kv_transfer/kv_p2p/mooncake_layerwise_connector.py b/vllm_ascend/distributed/kv_transfer/kv_p2p/mooncake_layerwise_connector.py index 2bcc9288..d9ef029b 100644 --- a/vllm_ascend/distributed/kv_transfer/kv_p2p/mooncake_layerwise_connector.py +++ b/vllm_ascend/distributed/kv_transfer/kv_p2p/mooncake_layerwise_connector.py @@ -508,7 +508,6 @@ class KVCacheRecvingLayerThread(threading.Thread): self.lock = threading.Lock() self.done_requests = set[str]() self.task_tracker = dict[str, int]() - self.request_map = dict[str, str]() self.ready_event = ready_event self.metadata = metadata @@ -525,11 +524,12 @@ class KVCacheRecvingLayerThread(threading.Thread): def update_task(self, req_id, trans_count): with self.lock: + if req_id not in self.task_tracker: + self.task_tracker[req_id] = 0 self.task_tracker[req_id] += 1 if self.task_tracker[req_id] == trans_count: self.task_tracker.pop(req_id) - self.done_requests.add(self.request_map[req_id]) - self.request_map.pop(req_id) + self.done_requests.add(req_id) def run(self): """Run the thread to handle KV cache transfer requests.""" @@ -996,6 +996,7 @@ class MooncakeLayerwiseConnectorWorker: self.side_channel_host = get_ip() self.total_layers = vllm_config.model_config.get_num_layers(vllm_config.parallel_config) self.use_mla = self.vllm_config.model_config.use_mla + self.request_map = dict[str, str]() if self.use_mla: self.total_num_kv_heads = 1 else: @@ -1248,8 +1249,12 @@ class MooncakeLayerwiseConnectorWorker: if self.vllm_config.kv_transfer_config.is_kv_consumer else set() ) + done_recving = {self.request_map[s] for s in done_recving if s in self.request_map} done_recving.update(self.virtual_request) self.virtual_request = set() + for req_id in done_recving: + req_id = req_id[:-9] + self.request_map.pop(req_id, None) if len(done_recving) > 0: logger.info( f"Number of completed KV cache recv requests: {len(done_recving)}, receive requests: {done_recving}" @@ -1427,9 +1432,7 @@ class MooncakeLayerwiseConnectorWorker: continue external_req_id = get_external_request_id(req_id) assert self.kv_recv_layer_thread is not None - with self.kv_recv_layer_thread.lock: - self.kv_recv_layer_thread.task_tracker[external_req_id] = 0 - self.kv_recv_layer_thread.request_map[external_req_id] = req_id + self.request_map[external_req_id] = req_id elif self.vllm_config.kv_transfer_config.is_kv_producer: # update trans info update_metadata = {} @@ -1741,7 +1744,19 @@ class MooncakeLayerwiseConnectorWorker: encoded_data = msg_encoder.encode((DONE_SENDING_MSG, external_req_id, req_meta.trans_count[group_idx])) with zmq_ctx(zmq.REQ, path) as sock: # type: ignore ensure_zmq_send(sock, encoded_data, f"{req_meta.remote_host}:{req_meta.remote_port}") - ack = sock.recv() + # Avoid blocking forever waiting for the REQ/ACK response. + sock.setsockopt(zmq.RCVTIMEO, int(self.timeout * 1000)) # type: ignore + try: + ack = sock.recv() + except zmq.Again: # type: ignore + logger.warning( + "Timeout waiting ACK for request %s from %s:%d (timeout=%.3fs)", + external_req_id, + req_meta.remote_host, + req_meta.remote_port, + self.timeout, + ) + return if ack != b"ACK": raise ValueError(f"Unexpected ACK response: {ack}") except Exception as e: