From e1bed43cffcc8d3ee2b466f1d504fa68f9728236 Mon Sep 17 00:00:00 2001 From: liziyu <56102866+liziyu179@users.noreply.github.com> Date: Wed, 14 Jan 2026 08:51:31 +0800 Subject: [PATCH] [P/D] bugfix for p node force free request (#5431) ### What this PR does / why we need it? Fix the bug where the P-node's scheduler deadlocks after it force-frees a request due to timeout and then receives the completed kv cache pulled by the D-node again. This is fixed by adding a set to record all requests. - vLLM version: release/v0.13.0 - vLLM main: https://github.com/vllm-project/vllm/commit/81786c87748b0177111dfdc07af5351d8389baa1 --------- Signed-off-by: liziyu Signed-off-by: wangxiaoteng Co-authored-by: wangxiaoteng --- vllm_ascend/distributed/mooncake_connector.py | 40 +++++++++++++------ 1 file changed, 27 insertions(+), 13 deletions(-) diff --git a/vllm_ascend/distributed/mooncake_connector.py b/vllm_ascend/distributed/mooncake_connector.py index d46d64ba..9c33f25d 100644 --- a/vllm_ascend/distributed/mooncake_connector.py +++ b/vllm_ascend/distributed/mooncake_connector.py @@ -109,20 +109,27 @@ class KVCacheTaskTracker: # intentionally delayed. Each entry is a tuple of (request_id, # timestamp). If a request remains in this queue for too long, it will # be force-freed. 
- self.record_finished_requests: set[str] = set() self.delayed_free_requests: OrderedDict[str, float] = OrderedDict() + self.reqs_to_process: set[str] = set() + + def add_req_to_process(self, request_id: str): + self.reqs_to_process.add(request_id) def add_not_transfer_request(self, request_id: str): with self.done_task_lock: self.finished_requests.add(request_id) + self.reqs_to_process.discard(request_id) def update_done_task_count(self, request_id: str): with self.done_task_lock: - self.finished_requests.add(request_id) - if request_id in self.delayed_free_requests: - self._remove_delayed_requests(request_id) + if request_id in self.reqs_to_process: + self.finished_requests.add(request_id) + self.reqs_to_process.discard(request_id) + self.delayed_free_requests.pop(request_id, None) else: - self.record_finished_requests.add(request_id) + logger.error( + "MooncakeConnector finish req not in reqs to process.If it is a P node, this request may have been force freed." + ) def get_and_clear_finished_requests(self) -> set[str]: """ @@ -140,10 +147,7 @@ class KVCacheTaskTracker: def add_delayed_request(self, request_id: str, delay_start_time: float): """Add a delayed free request.""" with self.done_task_lock: - if request_id not in self.record_finished_requests: - self.delayed_free_requests[request_id] = delay_start_time - else: - self.record_finished_requests.discard(request_id) + self.delayed_free_requests[request_id] = delay_start_time def _retrieve_expired_requests(self): """Retrieve all expired delayed requests.""" @@ -156,16 +160,13 @@ class KVCacheTaskTracker: if (current_time - delay_start_time > envs.VLLM_NIXL_ABORT_REQUEST_TIMEOUT): self.delayed_free_requests.popitem(last=False) + self.reqs_to_process.discard(request_id) expired_requests.add(request_id) logger.info("Force freed request: %s", request_id) else: break return expired_requests - def _remove_delayed_requests(self, request_id: str): - """Remove all delayed free requests matching the given 
request_id.""" - self.delayed_free_requests.pop(request_id) - class KVCacheSendingThread(threading.Thread): @@ -769,6 +770,7 @@ class MooncakeConnectorMetadata(KVConnectorMetadata): def __init__(self): self.requests: dict[str, ReqMeta] = {} self.requests_to_send: dict[str, float] = {} + self.reqs_in_batch: set[str] = set() def add_new_req( self, @@ -932,6 +934,7 @@ class MooncakeConnectorScheduler: # the scheduler. Used to make metadata passed to Worker. self._reqs_need_recv: dict[str, tuple[Request, list[int], int]] = {} self._reqs_need_send: dict[str, float] = {} + self._reqs_in_batch: set[str] = set() # master-slave meta information for cross-nodes self.multi_nodes_meta_mapping: dict[str, dict[str, Any]] = {} @@ -980,6 +983,9 @@ class MooncakeConnectorScheduler: "num_external_tokens=%s, kv_transfer_params=%s", num_external_tokens, params) + if params is not None and (params.get("do_remote_prefill", False) + or params.get("do_remote_decode", False)): + self._reqs_in_batch.add(request.request_id) if params is not None and params.get("do_remote_prefill"): if params.get("remote_block_ids"): if all(p in params for p in ("remote_engine_id", "remote_host", @@ -1022,6 +1028,8 @@ class MooncakeConnectorScheduler: self._reqs_need_recv.clear() meta.requests_to_send = self._reqs_need_send self._reqs_need_send = {} + meta.reqs_in_batch = self._reqs_in_batch + self._reqs_in_batch = set() return meta @@ -1601,6 +1609,12 @@ class MooncakeConnectorWorker: all_task_done=(i == self.tp_num_need_pulls * self._prefill_pp_size - 1)) + for req_id in metadata.reqs_in_batch: + if self.kv_send_thread is not None: + self.kv_send_thread.task_tracker.add_req_to_process(req_id) + if self.kv_recv_thread is not None: + self.kv_recv_thread.task_tracker.add_req_to_process(req_id) + if self.kv_send_thread is not None and self.pcp_size * self.dcp_size == 1: for req_id, delay_start_time in metadata.requests_to_send.items(): if self.tp_rank in self._prefill_get_remote_rank(req_id):