[P/D] bugfix for P node force free request (#5431)
### What this PR does / why we need it?
Fix the bug where the P-node's scheduler stalls after it force-frees a
request due to timeout and then receives the completed KV cache pulled
by the D-node again. This is fixed by adding a list to record all requests.
- vLLM version: release/v0.13.0
- vLLM main:
81786c8774
---------
Signed-off-by: liziyu <liziyu16@huawei.com>
Signed-off-by: wangxiaoteng <wangxiaoteng@huawei.com>
Co-authored-by: wangxiaoteng <wangxiaoteng@huawei.com>
This commit is contained in:
@@ -109,20 +109,27 @@ class KVCacheTaskTracker:
|
|||||||
# intentionally delayed. Each entry is a tuple of (request_id,
|
# intentionally delayed. Each entry is a tuple of (request_id,
|
||||||
# timestamp). If a request remains in this queue for too long, it will
|
# timestamp). If a request remains in this queue for too long, it will
|
||||||
# be force-freed.
|
# be force-freed.
|
||||||
self.record_finished_requests: set[str] = set()
|
|
||||||
self.delayed_free_requests: OrderedDict[str, float] = OrderedDict()
|
self.delayed_free_requests: OrderedDict[str, float] = OrderedDict()
|
||||||
|
self.reqs_to_process: set[str] = set()
|
||||||
|
|
||||||
|
def add_req_to_process(self, request_id: str):
|
||||||
|
self.reqs_to_process.add(request_id)
|
||||||
|
|
||||||
def add_not_transfer_request(self, request_id: str):
|
def add_not_transfer_request(self, request_id: str):
|
||||||
with self.done_task_lock:
|
with self.done_task_lock:
|
||||||
self.finished_requests.add(request_id)
|
self.finished_requests.add(request_id)
|
||||||
|
self.reqs_to_process.discard(request_id)
|
||||||
|
|
||||||
def update_done_task_count(self, request_id: str):
|
def update_done_task_count(self, request_id: str):
|
||||||
with self.done_task_lock:
|
with self.done_task_lock:
|
||||||
self.finished_requests.add(request_id)
|
if request_id in self.reqs_to_process:
|
||||||
if request_id in self.delayed_free_requests:
|
self.finished_requests.add(request_id)
|
||||||
self._remove_delayed_requests(request_id)
|
self.reqs_to_process.discard(request_id)
|
||||||
|
self.delayed_free_requests.pop(request_id, None)
|
||||||
else:
|
else:
|
||||||
self.record_finished_requests.add(request_id)
|
logger.error(
|
||||||
|
"MooncakeConnector finish req not in reqs to process.If it is a P node, this request may have been force freed."
|
||||||
|
)
|
||||||
|
|
||||||
def get_and_clear_finished_requests(self) -> set[str]:
|
def get_and_clear_finished_requests(self) -> set[str]:
|
||||||
"""
|
"""
|
||||||
@@ -140,10 +147,7 @@ class KVCacheTaskTracker:
|
|||||||
def add_delayed_request(self, request_id: str, delay_start_time: float):
|
def add_delayed_request(self, request_id: str, delay_start_time: float):
|
||||||
"""Add a delayed free request."""
|
"""Add a delayed free request."""
|
||||||
with self.done_task_lock:
|
with self.done_task_lock:
|
||||||
if request_id not in self.record_finished_requests:
|
self.delayed_free_requests[request_id] = delay_start_time
|
||||||
self.delayed_free_requests[request_id] = delay_start_time
|
|
||||||
else:
|
|
||||||
self.record_finished_requests.discard(request_id)
|
|
||||||
|
|
||||||
def _retrieve_expired_requests(self):
|
def _retrieve_expired_requests(self):
|
||||||
"""Retrieve all expired delayed requests."""
|
"""Retrieve all expired delayed requests."""
|
||||||
@@ -156,16 +160,13 @@ class KVCacheTaskTracker:
|
|||||||
if (current_time - delay_start_time
|
if (current_time - delay_start_time
|
||||||
> envs.VLLM_NIXL_ABORT_REQUEST_TIMEOUT):
|
> envs.VLLM_NIXL_ABORT_REQUEST_TIMEOUT):
|
||||||
self.delayed_free_requests.popitem(last=False)
|
self.delayed_free_requests.popitem(last=False)
|
||||||
|
self.reqs_to_process.discard(request_id)
|
||||||
expired_requests.add(request_id)
|
expired_requests.add(request_id)
|
||||||
logger.info("Force freed request: %s", request_id)
|
logger.info("Force freed request: %s", request_id)
|
||||||
else:
|
else:
|
||||||
break
|
break
|
||||||
return expired_requests
|
return expired_requests
|
||||||
|
|
||||||
def _remove_delayed_requests(self, request_id: str):
|
|
||||||
"""Remove all delayed free requests matching the given request_id."""
|
|
||||||
self.delayed_free_requests.pop(request_id)
|
|
||||||
|
|
||||||
|
|
||||||
class KVCacheSendingThread(threading.Thread):
|
class KVCacheSendingThread(threading.Thread):
|
||||||
|
|
||||||
@@ -769,6 +770,7 @@ class MooncakeConnectorMetadata(KVConnectorMetadata):
|
|||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.requests: dict[str, ReqMeta] = {}
|
self.requests: dict[str, ReqMeta] = {}
|
||||||
self.requests_to_send: dict[str, float] = {}
|
self.requests_to_send: dict[str, float] = {}
|
||||||
|
self.reqs_in_batch: set[str] = set()
|
||||||
|
|
||||||
def add_new_req(
|
def add_new_req(
|
||||||
self,
|
self,
|
||||||
@@ -932,6 +934,7 @@ class MooncakeConnectorScheduler:
|
|||||||
# the scheduler. Used to make metadata passed to Worker.
|
# the scheduler. Used to make metadata passed to Worker.
|
||||||
self._reqs_need_recv: dict[str, tuple[Request, list[int], int]] = {}
|
self._reqs_need_recv: dict[str, tuple[Request, list[int], int]] = {}
|
||||||
self._reqs_need_send: dict[str, float] = {}
|
self._reqs_need_send: dict[str, float] = {}
|
||||||
|
self._reqs_in_batch: set[str] = set()
|
||||||
|
|
||||||
# master-slave meta information for cross-nodes
|
# master-slave meta information for cross-nodes
|
||||||
self.multi_nodes_meta_mapping: dict[str, dict[str, Any]] = {}
|
self.multi_nodes_meta_mapping: dict[str, dict[str, Any]] = {}
|
||||||
@@ -980,6 +983,9 @@ class MooncakeConnectorScheduler:
|
|||||||
"num_external_tokens=%s, kv_transfer_params=%s",
|
"num_external_tokens=%s, kv_transfer_params=%s",
|
||||||
num_external_tokens, params)
|
num_external_tokens, params)
|
||||||
|
|
||||||
|
if params is not None and (params.get("do_remote_prefill", False)
|
||||||
|
or params.get("do_remote_decode", False)):
|
||||||
|
self._reqs_in_batch.add(request.request_id)
|
||||||
if params is not None and params.get("do_remote_prefill"):
|
if params is not None and params.get("do_remote_prefill"):
|
||||||
if params.get("remote_block_ids"):
|
if params.get("remote_block_ids"):
|
||||||
if all(p in params for p in ("remote_engine_id", "remote_host",
|
if all(p in params for p in ("remote_engine_id", "remote_host",
|
||||||
@@ -1022,6 +1028,8 @@ class MooncakeConnectorScheduler:
|
|||||||
self._reqs_need_recv.clear()
|
self._reqs_need_recv.clear()
|
||||||
meta.requests_to_send = self._reqs_need_send
|
meta.requests_to_send = self._reqs_need_send
|
||||||
self._reqs_need_send = {}
|
self._reqs_need_send = {}
|
||||||
|
meta.reqs_in_batch = self._reqs_in_batch
|
||||||
|
self._reqs_in_batch = set()
|
||||||
|
|
||||||
return meta
|
return meta
|
||||||
|
|
||||||
@@ -1601,6 +1609,12 @@ class MooncakeConnectorWorker:
|
|||||||
all_task_done=(i == self.tp_num_need_pulls *
|
all_task_done=(i == self.tp_num_need_pulls *
|
||||||
self._prefill_pp_size - 1))
|
self._prefill_pp_size - 1))
|
||||||
|
|
||||||
|
for req_id in metadata.reqs_in_batch:
|
||||||
|
if self.kv_send_thread is not None:
|
||||||
|
self.kv_send_thread.task_tracker.add_req_to_process(req_id)
|
||||||
|
if self.kv_recv_thread is not None:
|
||||||
|
self.kv_recv_thread.task_tracker.add_req_to_process(req_id)
|
||||||
|
|
||||||
if self.kv_send_thread is not None and self.pcp_size * self.dcp_size == 1:
|
if self.kv_send_thread is not None and self.pcp_size * self.dcp_size == 1:
|
||||||
for req_id, delay_start_time in metadata.requests_to_send.items():
|
for req_id, delay_start_time in metadata.requests_to_send.items():
|
||||||
if self.tp_rank in self._prefill_get_remote_rank(req_id):
|
if self.tp_rank in self._prefill_get_remote_rank(req_id):
|
||||||
|
|||||||
Reference in New Issue
Block a user