diff --git a/tests/ut/kv_connector/test_mooncake_connector.py b/tests/ut/kv_connector/test_mooncake_connector.py index a3a593e..2ea23bc 100644 --- a/tests/ut/kv_connector/test_mooncake_connector.py +++ b/tests/ut/kv_connector/test_mooncake_connector.py @@ -7,6 +7,7 @@ import time import types import unittest from collections import defaultdict, deque +from typing import OrderedDict from unittest.mock import MagicMock, patch import msgspec @@ -34,7 +35,7 @@ class TestKVCacheTaskTrackerInit(unittest.TestCase): tracker = KVCacheTaskTracker() self.assertIsInstance(tracker.done_task_lock, type(threading.Lock())) self.assertIsInstance(tracker.finished_requests, set) - self.assertIsInstance(tracker.delayed_free_requests, deque) + self.assertIsInstance(tracker.delayed_free_requests, OrderedDict) class TestGetAndClearFinishedSingleRequests(unittest.TestCase): @@ -495,18 +496,42 @@ class TestKVCacheTaskTracker(unittest.TestCase): def test_update_done_task_count(self): self.assertEqual(len(self.tracker.finished_requests), 0) self.assertEqual(len(self.tracker.delayed_free_requests), 0) + self.assertEqual(len(self.tracker.record_finished_requests), 0) current_time = time.time() self.tracker.add_delayed_request("req_1", current_time) result = self.tracker.delayed_free_requests + result_record = self.tracker.record_finished_requests self.assertEqual(len(result), 1) - self.assertEqual(result[0], ("req_1", current_time)) + self.assertEqual(result["req_1"], current_time) + self.assertEqual(len(result_record), 0) self.tracker.update_done_task_count("req_1") result_finished = self.tracker.finished_requests result_delayed = self.tracker.delayed_free_requests + result_record = self.tracker.record_finished_requests self.assertEqual(result_finished, {"req_1"}) self.assertEqual(len(result_delayed), 0) + self.assertEqual(len(result_record), 0) + + self.tracker.update_done_task_count("req_2") + result_finished = self.tracker.finished_requests + result_delayed = self.tracker.delayed_free_requests + result_record = self.tracker.record_finished_requests + self.assertEqual(result_finished, {"req_1", "req_2"}) + self.assertEqual(len(result_delayed), 0) + self.assertEqual(len(result_record), 1) + self.assertEqual(result_record, {"req_2"}) + + def test_updtate_add_delayed_request(self) -> None: + self.tracker.update_done_task_count("req2") + result_start_record = self.tracker.record_finished_requests + self.assertEqual(len(result_start_record), 1) + self.tracker.add_delayed_request("req2", time.time()) + result_delayed = self.tracker.delayed_free_requests + result_end_record = self.tracker.record_finished_requests + self.assertEqual(len(result_delayed), 0) + self.assertEqual(len(result_end_record), 0) def test_retrieve_expired_requests(self): current_time = time.time() @@ -518,7 +543,7 @@ class TestKVCacheTaskTracker(unittest.TestCase): }) result_delay = self.tracker.delayed_free_requests self.assertEqual(len(result_delay), 1) - self.assertEqual(result_delay[0], ("req_2", current_time)) + self.assertIn("req_2", result_delay) def test_duplicate_task_update(self): self.tracker.update_done_task_count("req1") diff --git a/vllm_ascend/distributed/mooncake_connector.py b/vllm_ascend/distributed/mooncake_connector.py index 8c37ad4..7faf1be 100644 --- a/vllm_ascend/distributed/mooncake_connector.py +++ b/vllm_ascend/distributed/mooncake_connector.py @@ -11,7 +11,7 @@ from collections import defaultdict, deque from collections.abc import Iterator from concurrent.futures import ThreadPoolExecutor from dataclasses import dataclass -from typing import TYPE_CHECKING, Any, List, Optional, Tuple +from typing import TYPE_CHECKING, Any, List, Optional, OrderedDict, Tuple import msgspec import numpy as np @@ -68,12 +68,16 @@ class KVCacheTaskTracker: # intentionally delayed. Each entry is a tuple of (request_id, # timestamp). If a request remains in this queue for too long, it will # be force-freed. - self.delayed_free_requests: deque[Tuple[str, float]] = deque() + self.record_finished_requests: set[str] = set() + self.delayed_free_requests: OrderedDict[str, float] = OrderedDict() def update_done_task_count(self, request_id: str): with self.done_task_lock: self.finished_requests.add(request_id) - self._remove_delayed_requests(request_id) + if request_id in self.delayed_free_requests: + self._remove_delayed_requests(request_id) + else: + self.record_finished_requests.add(request_id) def get_and_clear_finished_requests(self) -> set[str]: """ @@ -91,7 +95,10 @@ class KVCacheTaskTracker: def add_delayed_request(self, request_id: str, delay_start_time: float): """Add a delayed free request.""" with self.done_task_lock: - self.delayed_free_requests.append((request_id, delay_start_time)) + if request_id not in self.record_finished_requests: + self.delayed_free_requests[request_id] = delay_start_time + else: + self.record_finished_requests.discard(request_id) def _retrieve_expired_requests(self): """Retrieve all expired delayed requests.""" @@ -99,10 +106,11 @@ class KVCacheTaskTracker: # Free delayed requests if they exceed the timeout current_time = time.time() while self.delayed_free_requests: - request_id, delay_start_time = self.delayed_free_requests[0] + request_id = next(iter(self.delayed_free_requests)) + delay_start_time = self.delayed_free_requests[request_id] if (current_time - delay_start_time > envs.VLLM_NIXL_ABORT_REQUEST_TIMEOUT): - self.delayed_free_requests.popleft() + self.delayed_free_requests.popitem(last=False) expired_requests.add(request_id) logger.info("Force freed request: %s", request_id) else: @@ -111,8 +119,7 @@ class KVCacheTaskTracker: def _remove_delayed_requests(self, request_id: str): """Remove all delayed free requests matching the given request_id.""" - self.delayed_free_requests = deque( - (r, t) for r, t in self.delayed_free_requests if r != request_id) + self.delayed_free_requests.pop(request_id) class KVCacheSendingThread(threading.Thread):