From d9ac7e8539aac5947df30d56ed41784ea2d9a69a Mon Sep 17 00:00:00 2001 From: Chao Lei Date: Tue, 17 Mar 2026 23:17:45 +0800 Subject: [PATCH] [Bugfix] Assertion error when decode prefix cache fully hits (#7236) ### What this PR does / why we need it? #### Problem When the decode node has prefix caching enabled and a request is fully served by the local prefix cache, the following assertion error occurs: ``` (EngineCore_DP3 pid=34912) File "/usr/local/python3.11.14/lib/python3.11/site-packages/vllm/v1/engine/core.py", line 520, in step_with_batch_queue (EngineCore_DP3 pid=34912) engine_core_outputs = self.scheduler.update_from_output( (EngineCore_DP3 pid=34912) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ (EngineCore_DP3 pid=34912) File "/usr/local/python3.11.14/lib/python3.11/site-packages/vllm/v1/core/sched/scheduler.py", line 1520, in update_from_output (EngineCore_DP3 pid=34912) self._update_from_kv_xfer_finished(kv_connector_output) (EngineCore_DP3 pid=34912) File "/usr/local/python3.11.14/lib/python3.11/site-packages/vllm/v1/core/sched/scheduler.py", line 2120, in _update_from_kv_xfer_finished (EngineCore_DP3 pid=34912) assert RequestStatus.is_finished(req.status) (EngineCore_DP3 pid=34912) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ (EngineCore_DP3 pid=34912) AssertionError ``` The error is triggered in scheduler.py at _update_from_kv_xfer_finished: ``` if req.status == RequestStatus.WAITING_FOR_REMOTE_KVS: self.finished_recving_kv_req_ids.add(req_id) else: assert RequestStatus.is_finished(req.status) ``` #### Root Cause When the decode node has prefix caching enabled and the local prefix cache is fully hit: 1. get_num_new_matched_tokens returns ext_tokens=0 and load_kv_async=False 2. The request status becomes RUNNING (not WAITING_FOR_REMOTE_KVS) 3. However, update_state_after_alloc still adds the request to _reqs_need_recv because remote_block_ids is present in kv_transfer_params 4.
The worker processes the request in _handle_request: - _transfer_kv_cache returns immediately (no actual transfer, because local_block_ids is empty) - the finally block still calls update_done_task_count(request_id) 5. As a result, finished_recving contains this request 6. When _update_from_kv_xfer_finished processes finished_recving, the request status is RUNNING 7. RUNNING is neither WAITING_FOR_REMOTE_KVS nor a finished status, so the assertion fails #### Solution In _handle_request, only notify the scheduler (update_done_task_count) when an actual KV transfer happened (local_block_ids is not empty). The signals that notify the prefill node to release its KV cache (_send_done_signal_to_free_remote_port and _send_done_recv_signal) are still sent regardless. ### Does this PR introduce _any_ user-facing change? ### How was this patch tested? - vLLM version: v0.17.0 - vLLM main: https://github.com/vllm-project/vllm/commit/4034c3d32e30d01639459edd3ab486f56993876d Signed-off-by: LCAIZJ --- .../distributed/kv_transfer/kv_p2p/mooncake_connector.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/vllm_ascend/distributed/kv_transfer/kv_p2p/mooncake_connector.py b/vllm_ascend/distributed/kv_transfer/kv_p2p/mooncake_connector.py index 4c692e33..95d05fdc 100644 --- a/vllm_ascend/distributed/kv_transfer/kv_p2p/mooncake_connector.py +++ b/vllm_ascend/distributed/kv_transfer/kv_p2p/mooncake_connector.py @@ -442,7 +442,8 @@ class KVCacheRecvingThread(threading.Thread): finally: self._send_done_signal_to_free_remote_port(remote_request_id, remote_host, remote_port_send_num) if all_task_done: - self.task_tracker.update_done_task_count(request_id) + if len(req_meta["local_block_ids"]) > 0: + self.task_tracker.update_done_task_count(request_id) if request_id in self.proc_not_transfer_request: del self.proc_not_transfer_request[request_id] self.request_queue.task_done()