From eb205d9f357848092946fe9d98dabde4bc3ee709 Mon Sep 17 00:00:00 2001 From: baxingpiaochong <771405853@qq.com> Date: Wed, 24 Sep 2025 11:22:46 +0800 Subject: [PATCH] [P/D][BugFix]Mooncake timeout release bug fix (#2899) ### What this PR does / why we need it? In the P node timeout release mechanism during PD separation, the req_id that requires timeout release is transmitted from the scheduler to the worker. If the KV cache between PDs is transferred too quickly, the P node's req_id may be released twice. The first release is when the D node notifies the P node that the KV cache has been pulled, and the second release is when the scheduler transmits the timeout release to the worker. To address this bug, an intermediate component is introduced to manage the release of req_ids. Pull kv and forward2 may occur one after the other in timing. The previous timeout defaulted to forward2 being before pull_kv. ### How was this patch tested? - vLLM version: v0.10.2 - vLLM main: https://github.com/vllm-project/vllm/commit/f225ea7dd98e9f29752e5c032cd4a8ee1d712f16 --------- Signed-off-by: baxingpiaochong <771405853@qq.com> --- .../kv_connector/test_mooncake_connector.py | 31 +++++++++++++++++-- vllm_ascend/distributed/mooncake_connector.py | 23 +++++++++----- 2 files changed, 43 insertions(+), 11 deletions(-) diff --git a/tests/ut/kv_connector/test_mooncake_connector.py b/tests/ut/kv_connector/test_mooncake_connector.py index a3a593e..2ea23bc 100644 --- a/tests/ut/kv_connector/test_mooncake_connector.py +++ b/tests/ut/kv_connector/test_mooncake_connector.py @@ -7,6 +7,7 @@ import time import types import unittest from collections import defaultdict, deque +from typing import OrderedDict from unittest.mock import MagicMock, patch import msgspec @@ -34,7 +35,7 @@ class TestKVCacheTaskTrackerInit(unittest.TestCase): tracker = KVCacheTaskTracker() self.assertIsInstance(tracker.done_task_lock, type(threading.Lock())) self.assertIsInstance(tracker.finished_requests, set) - self.assertIsInstance(tracker.delayed_free_requests, deque) + self.assertIsInstance(tracker.delayed_free_requests, OrderedDict) class TestGetAndClearFinishedSingleRequests(unittest.TestCase): @@ -495,18 +496,42 @@ class TestKVCacheTaskTracker(unittest.TestCase): def test_update_done_task_count(self): self.assertEqual(len(self.tracker.finished_requests), 0) self.assertEqual(len(self.tracker.delayed_free_requests), 0) + self.assertEqual(len(self.tracker.record_finished_requests), 0) current_time = time.time() self.tracker.add_delayed_request("req_1", current_time) result = self.tracker.delayed_free_requests + result_record = self.tracker.record_finished_requests self.assertEqual(len(result), 1) - self.assertEqual(result[0], ("req_1", current_time)) + self.assertEqual(result["req_1"], current_time) + self.assertEqual(len(result_record), 0) self.tracker.update_done_task_count("req_1") result_finished = self.tracker.finished_requests result_delayed = self.tracker.delayed_free_requests + result_record = self.tracker.record_finished_requests self.assertEqual(result_finished, {"req_1"}) self.assertEqual(len(result_delayed), 0) + self.assertEqual(len(result_record), 0) + + self.tracker.update_done_task_count("req_2") + result_finished = self.tracker.finished_requests + result_delayed = self.tracker.delayed_free_requests + result_record = self.tracker.record_finished_requests + self.assertEqual(result_finished, {"req_1", "req_2"}) + self.assertEqual(len(result_delayed), 0) + self.assertEqual(len(result_record), 1) + self.assertEqual(result_record, {"req_2"}) + + def test_updtate_add_delayed_request(self) -> None: + self.tracker.update_done_task_count("req2") + result_start_record = self.tracker.record_finished_requests + self.assertEqual(len(result_start_record), 1) + self.tracker.add_delayed_request("req2", time.time()) + result_delayed = self.tracker.delayed_free_requests + result_end_record = self.tracker.record_finished_requests + self.assertEqual(len(result_delayed), 0) + self.assertEqual(len(result_end_record), 0) def test_retrieve_expired_requests(self): current_time = time.time() @@ -518,7 +543,7 @@ class TestKVCacheTaskTracker(unittest.TestCase): }) result_delay = self.tracker.delayed_free_requests self.assertEqual(len(result_delay), 1) - self.assertEqual(result_delay[0], ("req_2", current_time)) + self.assertIn("req_2", result_delay) def test_duplicate_task_update(self): self.tracker.update_done_task_count("req1") diff --git a/vllm_ascend/distributed/mooncake_connector.py b/vllm_ascend/distributed/mooncake_connector.py index 8c37ad4..7faf1be 100644 --- a/vllm_ascend/distributed/mooncake_connector.py +++ b/vllm_ascend/distributed/mooncake_connector.py @@ -11,7 +11,7 @@ from collections import defaultdict, deque from collections.abc import Iterator from concurrent.futures import ThreadPoolExecutor from dataclasses import dataclass -from typing import TYPE_CHECKING, Any, List, Optional, Tuple +from typing import TYPE_CHECKING, Any, List, Optional, OrderedDict, Tuple import msgspec import numpy as np @@ -68,12 +68,16 @@ class KVCacheTaskTracker: # intentionally delayed. Each entry is a tuple of (request_id, # timestamp). If a request remains in this queue for too long, it will # be force-freed. - self.delayed_free_requests: deque[Tuple[str, float]] = deque() + self.record_finished_requests: set[str] = set() + self.delayed_free_requests: OrderedDict[str, float] = OrderedDict() def update_done_task_count(self, request_id: str): with self.done_task_lock: self.finished_requests.add(request_id) - self._remove_delayed_requests(request_id) + if request_id in self.delayed_free_requests: + self._remove_delayed_requests(request_id) + else: + self.record_finished_requests.add(request_id) def get_and_clear_finished_requests(self) -> set[str]: """ @@ -91,7 +95,10 @@ class KVCacheTaskTracker: def add_delayed_request(self, request_id: str, delay_start_time: float): """Add a delayed free request.""" with self.done_task_lock: - self.delayed_free_requests.append((request_id, delay_start_time)) + if request_id not in self.record_finished_requests: + self.delayed_free_requests[request_id] = delay_start_time + else: + self.record_finished_requests.discard(request_id) def _retrieve_expired_requests(self): """Retrieve all expired delayed requests.""" @@ -99,10 +106,11 @@ class KVCacheTaskTracker: # Free delayed requests if they exceed the timeout current_time = time.time() while self.delayed_free_requests: - request_id, delay_start_time = self.delayed_free_requests[0] + request_id = next(iter(self.delayed_free_requests)) + delay_start_time = self.delayed_free_requests[request_id] if (current_time - delay_start_time > envs.VLLM_NIXL_ABORT_REQUEST_TIMEOUT): - self.delayed_free_requests.popleft() + self.delayed_free_requests.popitem(last=False) expired_requests.add(request_id) logger.info("Force freed request: %s", request_id) else: @@ -111,8 +119,7 @@ class KVCacheTaskTracker: def _remove_delayed_requests(self, request_id: str): """Remove all delayed free requests matching the given request_id.""" - self.delayed_free_requests = deque( - (r, t) for r, t in self.delayed_free_requests if r != request_id) + self.delayed_free_requests.pop(request_id) class KVCacheSendingThread(threading.Thread):