From 82e26b5a6e168d3eae4c033bce4a98085b9a5119 Mon Sep 17 00:00:00 2001 From: wangxiaoteng888 <56506195+wangxiaoteng888@users.noreply.github.com> Date: Tue, 31 Mar 2026 18:55:36 +0800 Subject: [PATCH] [BugFix][v0.18.0]Adjust request map pop time (#7857) ### What this PR does / why we need it? Adjust request map pop time.This pull request optimizes the KV cache transfer mechanism by streamlining how requests are tracked and cleaned up. By removing unnecessary mapping structures and adjusting the timing of request removal, the system achieves more efficient state management during the transfer process. pick-from:https://github.com/vllm-project/vllm-ascend/pull/7855 ### How was this patch tested? By ci Signed-off-by: wangxiaoteng --- .../kv_p2p/mooncake_layerwise_connector.py | 29 ++++++++++++++----- 1 file changed, 22 insertions(+), 7 deletions(-) diff --git a/vllm_ascend/distributed/kv_transfer/kv_p2p/mooncake_layerwise_connector.py b/vllm_ascend/distributed/kv_transfer/kv_p2p/mooncake_layerwise_connector.py index 2bcc9288..d9ef029b 100644 --- a/vllm_ascend/distributed/kv_transfer/kv_p2p/mooncake_layerwise_connector.py +++ b/vllm_ascend/distributed/kv_transfer/kv_p2p/mooncake_layerwise_connector.py @@ -508,7 +508,6 @@ class KVCacheRecvingLayerThread(threading.Thread): self.lock = threading.Lock() self.done_requests = set[str]() self.task_tracker = dict[str, int]() - self.request_map = dict[str, str]() self.ready_event = ready_event self.metadata = metadata @@ -525,11 +524,12 @@ class KVCacheRecvingLayerThread(threading.Thread): def update_task(self, req_id, trans_count): with self.lock: + if req_id not in self.task_tracker: + self.task_tracker[req_id] = 0 self.task_tracker[req_id] += 1 if self.task_tracker[req_id] == trans_count: self.task_tracker.pop(req_id) - self.done_requests.add(self.request_map[req_id]) - self.request_map.pop(req_id) + self.done_requests.add(req_id) def run(self): """Run the thread to handle KV cache transfer requests.""" @@ -996,6 +996,7 @@ class MooncakeLayerwiseConnectorWorker: self.side_channel_host = get_ip() self.total_layers = vllm_config.model_config.get_num_layers(vllm_config.parallel_config) self.use_mla = self.vllm_config.model_config.use_mla + self.request_map = dict[str, str]() if self.use_mla: self.total_num_kv_heads = 1 else: @@ -1248,8 +1249,12 @@ class MooncakeLayerwiseConnectorWorker: if self.vllm_config.kv_transfer_config.is_kv_consumer else set() ) + done_recving = {self.request_map[s] for s in done_recving if s in self.request_map} done_recving.update(self.virtual_request) self.virtual_request = set() + for req_id in done_recving: + req_id = req_id[:-9] + self.request_map.pop(req_id, None) if len(done_recving) > 0: logger.info( f"Number of completed KV cache recv requests: {len(done_recving)}, receive requests: {done_recving}" @@ -1427,9 +1432,7 @@ class MooncakeLayerwiseConnectorWorker: continue external_req_id = get_external_request_id(req_id) assert self.kv_recv_layer_thread is not None - with self.kv_recv_layer_thread.lock: - self.kv_recv_layer_thread.task_tracker[external_req_id] = 0 - self.kv_recv_layer_thread.request_map[external_req_id] = req_id + self.request_map[external_req_id] = req_id elif self.vllm_config.kv_transfer_config.is_kv_producer: # update trans info update_metadata = {} @@ -1741,7 +1744,19 @@ class MooncakeLayerwiseConnectorWorker: encoded_data = msg_encoder.encode((DONE_SENDING_MSG, external_req_id, req_meta.trans_count[group_idx])) with zmq_ctx(zmq.REQ, path) as sock: # type: ignore ensure_zmq_send(sock, encoded_data, f"{req_meta.remote_host}:{req_meta.remote_port}") - ack = sock.recv() + # Avoid blocking forever waiting for the REQ/ACK response. + sock.setsockopt(zmq.RCVTIMEO, int(self.timeout * 1000)) # type: ignore + try: + ack = sock.recv() + except zmq.Again: # type: ignore + logger.warning( + "Timeout waiting ACK for request %s from %s:%d (timeout=%.3fs)", + external_req_id, + req_meta.remote_host, + req_meta.remote_port, + self.timeout, + ) + return if ack != b"ACK": raise ValueError(f"Unexpected ACK response: {ack}") except Exception as e: