[P/D] bugfix for P-node force-free request (#5431)

### What this PR does / why we need it?
Fix the bug where the P-node's scheduler deadlocks after it force-frees a
request due to timeout and then receives the completed kv cache pulled
by the D-node again, by adding a set to record all in-flight requests.


- vLLM version: release/v0.13.0
- vLLM main:
81786c8774

---------

Signed-off-by: liziyu <liziyu16@huawei.com>
Signed-off-by: wangxiaoteng <wangxiaoteng@huawei.com>
Co-authored-by: wangxiaoteng <wangxiaoteng@huawei.com>
This commit is contained in:
liziyu
2026-01-14 08:51:31 +08:00
committed by GitHub
parent 78d5ce3e01
commit e1bed43cff

View File

@@ -109,20 +109,27 @@ class KVCacheTaskTracker:
# intentionally delayed. Each entry is a tuple of (request_id, # intentionally delayed. Each entry is a tuple of (request_id,
# timestamp). If a request remains in this queue for too long, it will # timestamp). If a request remains in this queue for too long, it will
# be force-freed. # be force-freed.
self.record_finished_requests: set[str] = set()
self.delayed_free_requests: OrderedDict[str, float] = OrderedDict() self.delayed_free_requests: OrderedDict[str, float] = OrderedDict()
self.reqs_to_process: set[str] = set()
def add_req_to_process(self, request_id: str):
self.reqs_to_process.add(request_id)
def add_not_transfer_request(self, request_id: str): def add_not_transfer_request(self, request_id: str):
with self.done_task_lock: with self.done_task_lock:
self.finished_requests.add(request_id) self.finished_requests.add(request_id)
self.reqs_to_process.discard(request_id)
def update_done_task_count(self, request_id: str): def update_done_task_count(self, request_id: str):
with self.done_task_lock: with self.done_task_lock:
if request_id in self.reqs_to_process:
self.finished_requests.add(request_id) self.finished_requests.add(request_id)
if request_id in self.delayed_free_requests: self.reqs_to_process.discard(request_id)
self._remove_delayed_requests(request_id) self.delayed_free_requests.pop(request_id, None)
else: else:
self.record_finished_requests.add(request_id) logger.error(
"MooncakeConnector finish req not in reqs to process.If it is a P node, this request may have been force freed."
)
def get_and_clear_finished_requests(self) -> set[str]: def get_and_clear_finished_requests(self) -> set[str]:
""" """
@@ -140,10 +147,7 @@ class KVCacheTaskTracker:
def add_delayed_request(self, request_id: str, delay_start_time: float): def add_delayed_request(self, request_id: str, delay_start_time: float):
"""Add a delayed free request.""" """Add a delayed free request."""
with self.done_task_lock: with self.done_task_lock:
if request_id not in self.record_finished_requests:
self.delayed_free_requests[request_id] = delay_start_time self.delayed_free_requests[request_id] = delay_start_time
else:
self.record_finished_requests.discard(request_id)
def _retrieve_expired_requests(self): def _retrieve_expired_requests(self):
"""Retrieve all expired delayed requests.""" """Retrieve all expired delayed requests."""
@@ -156,16 +160,13 @@ class KVCacheTaskTracker:
if (current_time - delay_start_time if (current_time - delay_start_time
> envs.VLLM_NIXL_ABORT_REQUEST_TIMEOUT): > envs.VLLM_NIXL_ABORT_REQUEST_TIMEOUT):
self.delayed_free_requests.popitem(last=False) self.delayed_free_requests.popitem(last=False)
self.reqs_to_process.discard(request_id)
expired_requests.add(request_id) expired_requests.add(request_id)
logger.info("Force freed request: %s", request_id) logger.info("Force freed request: %s", request_id)
else: else:
break break
return expired_requests return expired_requests
def _remove_delayed_requests(self, request_id: str):
"""Remove all delayed free requests matching the given request_id."""
self.delayed_free_requests.pop(request_id)
class KVCacheSendingThread(threading.Thread): class KVCacheSendingThread(threading.Thread):
@@ -769,6 +770,7 @@ class MooncakeConnectorMetadata(KVConnectorMetadata):
def __init__(self): def __init__(self):
self.requests: dict[str, ReqMeta] = {} self.requests: dict[str, ReqMeta] = {}
self.requests_to_send: dict[str, float] = {} self.requests_to_send: dict[str, float] = {}
self.reqs_in_batch: set[str] = set()
def add_new_req( def add_new_req(
self, self,
@@ -932,6 +934,7 @@ class MooncakeConnectorScheduler:
# the scheduler. Used to make metadata passed to Worker. # the scheduler. Used to make metadata passed to Worker.
self._reqs_need_recv: dict[str, tuple[Request, list[int], int]] = {} self._reqs_need_recv: dict[str, tuple[Request, list[int], int]] = {}
self._reqs_need_send: dict[str, float] = {} self._reqs_need_send: dict[str, float] = {}
self._reqs_in_batch: set[str] = set()
# master-slave meta information for cross-nodes # master-slave meta information for cross-nodes
self.multi_nodes_meta_mapping: dict[str, dict[str, Any]] = {} self.multi_nodes_meta_mapping: dict[str, dict[str, Any]] = {}
@@ -980,6 +983,9 @@ class MooncakeConnectorScheduler:
"num_external_tokens=%s, kv_transfer_params=%s", "num_external_tokens=%s, kv_transfer_params=%s",
num_external_tokens, params) num_external_tokens, params)
if params is not None and (params.get("do_remote_prefill", False)
or params.get("do_remote_decode", False)):
self._reqs_in_batch.add(request.request_id)
if params is not None and params.get("do_remote_prefill"): if params is not None and params.get("do_remote_prefill"):
if params.get("remote_block_ids"): if params.get("remote_block_ids"):
if all(p in params for p in ("remote_engine_id", "remote_host", if all(p in params for p in ("remote_engine_id", "remote_host",
@@ -1022,6 +1028,8 @@ class MooncakeConnectorScheduler:
self._reqs_need_recv.clear() self._reqs_need_recv.clear()
meta.requests_to_send = self._reqs_need_send meta.requests_to_send = self._reqs_need_send
self._reqs_need_send = {} self._reqs_need_send = {}
meta.reqs_in_batch = self._reqs_in_batch
self._reqs_in_batch = set()
return meta return meta
@@ -1601,6 +1609,12 @@ class MooncakeConnectorWorker:
all_task_done=(i == self.tp_num_need_pulls * all_task_done=(i == self.tp_num_need_pulls *
self._prefill_pp_size - 1)) self._prefill_pp_size - 1))
for req_id in metadata.reqs_in_batch:
if self.kv_send_thread is not None:
self.kv_send_thread.task_tracker.add_req_to_process(req_id)
if self.kv_recv_thread is not None:
self.kv_recv_thread.task_tracker.add_req_to_process(req_id)
if self.kv_send_thread is not None and self.pcp_size * self.dcp_size == 1: if self.kv_send_thread is not None and self.pcp_size * self.dcp_size == 1:
for req_id, delay_start_time in metadata.requests_to_send.items(): for req_id, delay_start_time in metadata.requests_to_send.items():
if self.tp_rank in self._prefill_get_remote_rank(req_id): if self.tp_rank in self._prefill_get_remote_rank(req_id):