diff --git a/tests/ut/kv_connector/test_mooncake_connector.py b/tests/ut/kv_connector/test_mooncake_connector.py index c7a20e0..0b2782d 100644 --- a/tests/ut/kv_connector/test_mooncake_connector.py +++ b/tests/ut/kv_connector/test_mooncake_connector.py @@ -12,7 +12,6 @@ from unittest.mock import MagicMock, patch import msgspec import zmq from vllm.utils import make_zmq_path -from zmq import Context # type: ignore fake_engine = types.ModuleType("mooncake.engine") fake_engine.TransferEngine = MagicMock() # type: ignore[attr-defined] @@ -32,193 +31,17 @@ DONE_RECVING_MSG = b"done_recving_msg" class TestKVCacheTaskTrackerInit(unittest.TestCase): def test_init_basic_properties(self): - tracker = KVCacheTaskTracker(tp_rank=1, - local_engine_id="engine1", - target_count=10) - self.assertEqual(tracker.tp_rank, 1) - self.assertEqual(tracker.local_engine_id, "engine1") - self.assertEqual(tracker.target_count, 10) + tracker = KVCacheTaskTracker() self.assertIsInstance(tracker.done_task_lock, type(threading.Lock())) - self.assertIsInstance(tracker.done_task_counts, defaultdict) self.assertIsInstance(tracker.finished_requests, set) - - def test_socket_path_generation(self): - tracker = KVCacheTaskTracker(tp_rank=1, - local_engine_id="engine42", - target_count=1) - self.assertEqual(tracker.socket_path, - "ipc:///tmp/vllm_mooncake_connector_engine42.ipc") - - @patch("vllm_ascend.distributed.mooncake_connector.threading.Thread") - def test_tp_rank_zero_initialization(self, mock_thread): - tracker = KVCacheTaskTracker(tp_rank=0, - local_engine_id="test", - target_count=1) - mock_thread.assert_called_once_with( - target=tracker._listen_for_completion_signals, - daemon=True, - name="KVCacheTaskTrackerListenerThread") - mock_thread.return_value.start.assert_called_once() - self.assertIsNone(tracker.socket) - self.assertTrue(tracker.listener.daemon) - - @patch("vllm_ascend.distributed.mooncake_connector.make_zmq_socket") - @patch("vllm_ascend.distributed.mooncake_connector.logger") - def test_tp_rank_non_zero_initialization(self, mock_logger, - mock_make_zmq_socket): - mock_socket = MagicMock() - mock_make_zmq_socket.return_value = mock_socket - tracker = KVCacheTaskTracker(tp_rank=1, - local_engine_id="test", - target_count=1) - mock_make_zmq_socket.assert_called_once_with( - ctx=unittest.mock.ANY, - path="ipc:///tmp/vllm_mooncake_connector_test.ipc", - socket_type=zmq.PUSH, # type: ignore - bind=False) - mock_logger.info.assert_called_once_with( - "Connecting to transfer socket at %s", - "ipc:///tmp/vllm_mooncake_connector_test.ipc") - self.assertIsNone(tracker.listener) - self.assertEqual(tracker.socket, mock_socket) - - -class TestKVCacheTaskTrackerListenMethod(unittest.TestCase): - - def setUp(self): - self.tp_rank = 0 - self.local_engine_id = "test_engine_ut" - self.target_count = 3 - self.tracker = KVCacheTaskTracker(self.tp_rank, self.local_engine_id, - self.target_count) - self.original_listen = self.tracker._listen_for_completion_signals - - def tearDown(self): - self.tracker._listen_for_completion_signals = self.original_listen - Context.instance().term() - time.sleep(0.1) - - def test_normal_message_processing(self): - listener_thread = threading.Thread( - target=self.tracker._listen_for_completion_signals, daemon=True) - listener_thread.start() - time.sleep(0.2) - test_messages = [("request_001", 1), ("request_001", 2), - ("request_002", 0), ("request_003", 1)] - ctx = Context() - sender_socket = ctx.socket(zmq.PUSH) # type: ignore - sender_socket.connect(self.tracker.socket_path) - for msg in test_messages: - sender_socket.send_pyobj(msg) - time.sleep(0.05) - sender_socket.close() - time.sleep(0.2) - - with self.tracker.done_task_lock: - self.assertEqual(len(self.tracker.done_task_counts["request_001"]), - 2) - self.assertIn(1, self.tracker.done_task_counts["request_001"]) - self.assertIn(2, self.tracker.done_task_counts["request_001"]) - self.assertEqual(len(self.tracker.done_task_counts["request_002"]), - 1) - self.assertIn(0, self.tracker.done_task_counts["request_002"]) - self.assertEqual(len(self.tracker.done_task_counts["request_003"]), - 1) - self.assertIn(1, self.tracker.done_task_counts["request_003"]) - - @patch("vllm_ascend.distributed.mooncake_connector.make_zmq_socket", - autospec=True) - def test_listen_with_timeout(self, mock_make_socket): - mock_socket = MagicMock() - - def mock_recv(): - start = time.time() - while time.time() - start < 0.5: - time.sleep(0.01) - return ("req1", 0) - - mock_socket.recv_pyobj = mock_recv - mock_make_socket.return_value = mock_socket - - test_thread = threading.Thread( - target=self.tracker._listen_for_completion_signals, daemon=True) - test_thread.start() - test_thread.join(timeout=1.0) - mock_make_socket.assert_called_once() - - -class TestKVCacheTaskTrackerTP(unittest.TestCase): - - def setUp(self): - self.local_engine_id = "test_engine" - self.target_count = 3 - - def test_update_done_task_count_tp_rank_0(self): - tracker = KVCacheTaskTracker(tp_rank=0, - local_engine_id=self.local_engine_id, - target_count=self.target_count) - test_request_id = "test_req_001" - test_tp_rank = 1 - tracker.update_done_task_count(test_request_id, test_tp_rank) - with tracker.done_task_lock: - self.assertEqual(len(tracker.done_task_counts[test_request_id]), 1) - self.assertIn(test_tp_rank, - tracker.done_task_counts[test_request_id]) - - @patch("vllm_ascend.distributed.mooncake_connector.make_zmq_socket", - autospec=True) - def test_update_done_task_count_non_zero_tp(self, mock_make_socket): - mock_socket = MagicMock() - mock_make_socket.return_value = mock_socket - tracker = KVCacheTaskTracker(tp_rank=1, - local_engine_id=self.local_engine_id, - target_count=self.target_count) - test_request_id = "test_req_002" - test_tp_rank = 1 - tracker.update_done_task_count(test_request_id, test_tp_rank) - mock_socket.send_pyobj.assert_called_once_with( - (test_request_id, test_tp_rank)) - with tracker.done_task_lock: - self.assertNotIn(test_request_id, tracker.done_task_counts) - - @patch("vllm_ascend.distributed.mooncake_connector.logger", autospec=True) - @patch("vllm_ascend.distributed.mooncake_connector.make_zmq_socket", - autospec=True) - def test_update_done_task_count_logging(self, mock_make_socket, - mock_logger): - mock_socket = MagicMock() - mock_make_socket.return_value = mock_socket - tracker = KVCacheTaskTracker(tp_rank=2, - local_engine_id=self.local_engine_id, - target_count=self.target_count) - test_request_id = "test_req_003" - tracker.update_done_task_count(test_request_id, 2) - mock_logger.debug.assert_called_once_with( - "Sent done signal for request %s to tp 0", test_request_id) - - @patch("vllm_ascend.distributed.mooncake_connector.make_zmq_socket", - autospec=True) - def test_update_multiple_calls(self, mock_make_socket): - mock_socket = MagicMock() - mock_make_socket.return_value = mock_socket - tracker = KVCacheTaskTracker(tp_rank=1, - local_engine_id=self.local_engine_id, - target_count=self.target_count) - test_data = [("req1", 1), ("req1", 1), ("req2", 1)] - for req_id, rank in test_data: - tracker.update_done_task_count(req_id, rank) - self.assertEqual(mock_socket.send_pyobj.call_count, 3) - mock_socket.send_pyobj.assert_called_with(("req2", 1)) + self.assertIsInstance(tracker.delayed_free_requests, deque) class TestGetAndClearFinishedSingleRequests(unittest.TestCase): def setUp(self): - self.tracker = KVCacheTaskTracker(tp_rank=0, - local_engine_id="test", - target_count=3) + self.tracker = KVCacheTaskTracker() self.tracker.finished_requests = set() - self.tracker.done_task_counts = defaultdict(set) self.tracker.done_task_lock = threading.Lock() def test_empty_requests(self): @@ -251,14 +74,6 @@ class TestGetAndClearFinishedSingleRequests(unittest.TestCase): self.assertEqual(sum(1 for r in results if r), 1) self.assertEqual(len(self.tracker.finished_requests), 0) - def test_after_increment(self): - self.tracker._increment_task_count("req_123", 0) - self.tracker._increment_task_count("req_123", 1) - self.tracker._increment_task_count("req_123", 2) - result = self.tracker.get_and_clear_finished_requests() - self.assertEqual(result, {"req_123"}) - self.assertEqual(self.tracker.get_and_clear_finished_requests(), set()) - class TestKVCacheSendingThreadInit(unittest.TestCase): @@ -282,47 +97,6 @@ class TestKVCacheSendingThreadInit(unittest.TestCase): if hasattr(thread, 'is_alive') and thread.is_alive(): thread.join(timeout=0.1) - @patch('vllm_ascend.distributed.mooncake_connector.KVCacheTaskTracker') - def test_initialization_basic(self, mock_tracker): - thread = KVCacheSendingThread(**self.common_args) - self.threads.append(thread) - self.assertEqual(thread.tp_rank, 1) - self.assertEqual(thread.decode_tp_size, 4) - self.assertEqual(thread.local_engine_id, 'engine_1') - mock_tracker.assert_called_once() - args = mock_tracker.call_args[0] - kwargs = mock_tracker.call_args[1] - if args: - self.assertEqual(args[0], 1) - self.assertEqual(args[1], 'engine_1') - self.assertEqual(args[2], 4) - else: - self.assertEqual(kwargs['tp_rank'], 1) - self.assertEqual(kwargs['local_engine_id'], 'engine_1') - self.assertEqual(kwargs['target_count'], 4) - - @patch('vllm_ascend.distributed.mooncake_connector.KVCacheTaskTracker') - def test_task_tracker_initialization(self, mock_tracker): - args = self.common_args.copy() - args.update({ - 'tp_rank': 2, - 'decode_tp_size': 8, - 'local_engine_id': 'engine_2' - }) - thread = KVCacheSendingThread(**args) - self.threads.append(thread) - mock_tracker.assert_called_once() - call_args = mock_tracker.call_args[0] - call_kwargs = mock_tracker.call_args[1] - if call_args: - self.assertEqual(call_args[0], 2) - self.assertEqual(call_args[1], 'engine_2') - self.assertEqual(call_args[2], 8) - else: - self.assertEqual(call_kwargs['tp_rank'], 2) - self.assertEqual(call_kwargs['local_engine_id'], 'engine_2') - self.assertEqual(call_kwargs['target_count'], 8) - def test_thread_daemon_property(self): thread = KVCacheSendingThread(**self.common_args) self.threads.append(thread) @@ -542,7 +316,7 @@ class TestCoreFunctionality(unittest.TestCase): mock_transfer.assert_called_once_with(self.test_req) mock_send.assert_called_once_with("req1", "localhost", 6666) self.thread.task_tracker.update_done_task_count.assert_called_once_with( - "req1", self.thread.tp_rank) + "req1") self.mock_queue.task_done.assert_called_once() @patch.object(KVCacheRecvingThread, '_get_remote_metadata') @@ -675,9 +449,11 @@ class TestMainThreadLoop(unittest.TestCase): class MockVllmConfig: def __init__(self): + self.model_config = MagicMock() self.parallel_config = MagicMock() self.cache_config = MagicMock() self.kv_transfer_config = MagicMock() + self.model_config.use_mla = True self.parallel_config.tensor_parallel_size = 2 self.parallel_config.data_parallel_rank_local = 0 self.parallel_config.data_parallel_size_local = 1 @@ -714,28 +490,40 @@ class MockRequest: class TestKVCacheTaskTracker(unittest.TestCase): def setUp(self): - self.tracker = KVCacheTaskTracker(tp_rank=0, - local_engine_id="test_engine", - target_count=2) + self.tracker = KVCacheTaskTracker() - def test_update_task_count(self): - self.assertEqual(len(self.tracker.done_task_counts), 0) + def test_update_done_task_count(self): self.assertEqual(len(self.tracker.finished_requests), 0) + self.assertEqual(len(self.tracker.delayed_free_requests), 0) - self.tracker.update_done_task_count("req1", 0) - self.tracker.update_done_task_count("req1", 1) + current_time = time.time() + self.tracker.add_delayed_request("req_1", current_time) + result = self.tracker.delayed_free_requests + self.assertEqual(len(result), 1) + self.assertEqual(result[0], ("req_1", current_time)) - self.assertEqual(len(self.tracker.finished_requests), 1) - self.assertTrue("req1" in self.tracker.finished_requests) + self.tracker.update_done_task_count("req_1") + result_finished = self.tracker.finished_requests + result_delayed = self.tracker.delayed_free_requests + self.assertEqual(result_finished, {"req_1"}) + self.assertEqual(len(result_delayed), 0) - finished = self.tracker.get_and_clear_finished_requests() - self.assertEqual(finished, {"req1"}) - self.assertEqual(len(self.tracker.finished_requests), 0) + def test_retrieve_expired_requests(self): + current_time = time.time() + self.tracker.add_delayed_request("req_1", current_time - 600) + self.tracker.add_delayed_request("req_2", current_time) + result = self.tracker._retrieve_expired_requests() + self.assertEqual(result, { + "req_1", + }) + result_delay = self.tracker.delayed_free_requests + self.assertEqual(len(result_delay), 1) + self.assertEqual(result_delay[0], ("req_2", current_time)) def test_duplicate_task_update(self): - self.tracker.update_done_task_count("req1", 0) - self.tracker.update_done_task_count("req1", 0) - self.tracker.update_done_task_count("req1", 1) + self.tracker.update_done_task_count("req1") + self.tracker.update_done_task_count("req1") + self.tracker.update_done_task_count("req1") finished = self.tracker.get_and_clear_finished_requests() self.assertEqual(finished, {"req1"}) @@ -745,6 +533,9 @@ class TestMooncakeConnectorMetadata(unittest.TestCase): def test_add_new_req(self): meta = MooncakeConnectorMetadata() + self.assertEqual(len(meta.requests), 0) + self.assertEqual(len(meta.requests_to_send), 0) + meta.add_new_req(request_id="req1", local_block_ids=[1, 2, 3], kv_transfer_params={ @@ -802,6 +593,10 @@ class TestMooncakeConnectorSchedulerMatchedTokens(unittest.TestCase): self.assertEqual(meta.requests["req1"].remote_block_ids, [1, 2, 3]) self.assertEqual(len(self.scheduler._reqs_need_recv), 0) + def test_get_finished_count(self): + count = self.scheduler.get_finished_count() + self.assertEqual(count, 2) + class TestHelperFunctions(unittest.TestCase): diff --git a/vllm_ascend/distributed/mooncake_connector.py b/vllm_ascend/distributed/mooncake_connector.py index c527db3..4faf37d 100644 --- a/vllm_ascend/distributed/mooncake_connector.py +++ b/vllm_ascend/distributed/mooncake_connector.py @@ -58,74 +58,21 @@ class ReqMeta: class KVCacheTaskTracker: - def __init__(self, tp_rank: int, local_engine_id: str, target_count: int): + def __init__(self): super().__init__() - self.tp_rank = tp_rank - self.local_engine_id = local_engine_id - self.target_count = target_count self.done_task_lock = threading.Lock() - self.done_task_counts: defaultdict[str, set[int]] = defaultdict(set) self.finished_requests: set[str] = set() + # Only used in prefill node. Tracks requests whose kv blocks freeing is + # intentionally delayed. Each entry is a tuple of (request_id, + # timestamp). If a request remains in this queue for too long, it will + # be force-freed. + self.delayed_free_requests: deque[Tuple[str, float]] = deque() - self.socket_path = \ - f"ipc:///tmp/vllm_mooncake_connector_{self.local_engine_id}.ipc" - if tp_rank == 0: - self.listener = threading.Thread( - target=self._listen_for_completion_signals, - daemon=True, - name="KVCacheTaskTrackerListenerThread") - self.listener.start() - self.socket = None - else: - self.listener = None # type: ignore - self.socket = make_zmq_socket( - ctx=zmq.Context(), # type: ignore - path=self.socket_path, - socket_type=zmq.PUSH, # type: ignore - bind=False) - logger.info("Connecting to transfer socket at %s", - self.socket_path) - - def _listen_for_completion_signals(self): - socket = make_zmq_socket( - ctx=zmq.Context(), # type: ignore - path=self.socket_path, - socket_type=zmq.PULL, # type: ignore - bind=True) - logger.info("Listening for completion signals on %s", self.socket_path) - - while True: - try: - done_request_id, tp_rank = socket.recv_pyobj() - logger.debug("Received completion notification for request: " - f"{done_request_id} from tp rank {tp_rank}") - self._increment_task_count(done_request_id, tp_rank) - except Exception as e: - logger.error(f"Error in run_busy_loop: {e}") - - def update_done_task_count(self, request_id: str, tp_rank: int): - if self.tp_rank == 0: - self._increment_task_count(request_id, tp_rank) - else: - self.socket.send_pyobj((request_id, tp_rank)) # type: ignore - logger.debug("Sent done signal for request %s to tp 0", request_id) - - def _increment_task_count(self, request_id: str, tp_rank: int): + def update_done_task_count(self, request_id: str): with self.done_task_lock: - if tp_rank in self.done_task_counts[request_id]: - logger.warning( - f"Received duplicate done signal for request {request_id} " - f"from tp rank {tp_rank}. Ignoring.") - return - - self.done_task_counts[request_id].add(tp_rank) - if len(self.done_task_counts[request_id]) == self.target_count: - self.finished_requests.add(request_id) - self.done_task_counts.pop(request_id) - logger.info("All transfers completed for request: " - f"{request_id}. Total ranks: " - f"{self.target_count}.") + self.finished_requests.add(request_id) + self._remove_delayed_requests(request_id) def get_and_clear_finished_requests(self) -> set[str]: """ @@ -135,9 +82,37 @@ class KVCacheTaskTracker: """ with self.done_task_lock: finished_requests = self.finished_requests.copy() + expired_requests = self._retrieve_expired_requests() + finished_requests.update(expired_requests) self.finished_requests.clear() return finished_requests + def add_delayed_request(self, request_id: str, delay_start_time: float): + """Add a delayed free request.""" + with self.done_task_lock: + self.delayed_free_requests.append((request_id, delay_start_time)) + + def _retrieve_expired_requests(self): + """Retrieve all expired delayed requests.""" + expired_requests: set[str] = set() + # Free delayed requests if they exceed the timeout + current_time = time.time() + while self.delayed_free_requests: + request_id, delay_start_time = self.delayed_free_requests[0] + if (current_time - delay_start_time + > envs_ascend.VLLM_ASCEND_KVCACHE_DELAY_FREE_TIMEOUT): + self.delayed_free_requests.popleft() + expired_requests.add(request_id) + logger.info("Force freed request: %s", request_id) + else: + break + return expired_requests + + def _remove_delayed_requests(self, request_id: str): + """Remove all delayed free requests matching the given request_id.""" + self.delayed_free_requests = deque( + (r, t) for r, t in self.delayed_free_requests if r != request_id) + class KVCacheSendingThread(threading.Thread): @@ -154,9 +129,7 @@ class KVCacheSendingThread(threading.Thread): self.metadata = metadata self.ready_event = ready_event - self.task_tracker = KVCacheTaskTracker(self.tp_rank, - self.local_engine_id, - self.decode_tp_size) + self.task_tracker = KVCacheTaskTracker() def get_and_clear_finished_requests(self) -> set[str]: """ @@ -166,6 +139,10 @@ class KVCacheSendingThread(threading.Thread): """ return self.task_tracker.get_and_clear_finished_requests() + def add_delayed_request(self, request_id: str, delay_start_time: float): + return self.task_tracker.add_delayed_request(request_id, + delay_start_time) + def run(self): """Run the thread to handle KV cache transfer requests.""" @@ -204,9 +181,8 @@ class KVCacheSendingThread(threading.Thread): elif msg[0] == DONE_RECVING_MSG: logger.debug("Got DONE_RECVING_MSG for request %s", msg[1]) - request_id, decode_tp_rank = msg[1], msg[2] - self.task_tracker.update_done_task_count( - request_id, decode_tp_rank) + request_id = msg[1] + self.task_tracker.update_done_task_count(request_id) # Acknowledge the request completion. while True: try: @@ -259,9 +235,7 @@ class KVCacheRecvingThread(threading.Thread): # TODO(jianzs): make this configurable self.executor = ThreadPoolExecutor(max_workers=32) - self.task_tracker = KVCacheTaskTracker(self.tp_rank, - self.local_engine_id, - self.tp_size) + self.task_tracker = KVCacheTaskTracker() self.encoder = msgspec.msgpack.Encoder() self.decoder = msgspec.msgpack.Decoder(MooncakeAgentMetadata) @@ -323,7 +297,7 @@ class KVCacheRecvingThread(threading.Thread): logger.error("Failed to transfer KV cache for request " f"{request_id}: {e}") finally: - self.task_tracker.update_done_task_count(request_id, self.tp_rank) + self.task_tracker.update_done_task_count(request_id) # Always send the done signal to the remote host to ensure proper # resource cleanup. Failing to do so may cause a memory leak on the # remote host. @@ -422,8 +396,7 @@ class KVCacheRecvingThread(threading.Thread): sock: Optional[zmq.Socket] = None # type: ignore try: sock = self._get_remote_socket(remote_host, remote_handshake_port) - data_bytes = self.encoder.encode( - (DONE_RECVING_MSG, request_id, self.tp_rank)) + data_bytes = self.encoder.encode((DONE_RECVING_MSG, request_id)) ensure_zmq_send(sock, data_bytes) resp = ensure_zmq_recv(sock, self.remote_poller, @@ -479,6 +452,7 @@ class MooncakeConnectorMetadata(KVConnectorMetadata): def __init__(self): self.requests: dict[str, ReqMeta] = {} + self.requests_to_send: dict[str, float] = {} def add_new_req( self, @@ -543,6 +517,10 @@ class MooncakeConnector(KVConnectorBase_V1): assert self.connector_scheduler is not None return self.connector_scheduler.request_finished(request, block_ids) + def get_finished_count(self) -> Optional[int]: + assert self.connector_scheduler is not None + return self.connector_scheduler.get_finished_count() + ############################################################ # Worker Side Methods ############################################################ @@ -599,6 +577,7 @@ class MooncakeConnectorScheduler: # New requests are added by update_state_after_alloc in # the scheduler. Used to make metadata passed to Worker. self._reqs_need_recv: dict[str, tuple[Request, list[int]]] = {} + self._reqs_need_send: dict[str, float] = {} def get_num_new_matched_tokens( self, request: "Request", @@ -684,6 +663,8 @@ class MooncakeConnectorScheduler: # Clear the list once workers start the transfers self._reqs_need_recv.clear() + meta.requests_to_send = self._reqs_need_send + self._reqs_need_send = {} return meta @@ -711,6 +692,8 @@ class MooncakeConnectorScheduler: if delay_free_blocks: logger.info("Delaying free of %d blocks for request %s", len(computed_block_ids), request.request_id) + self._reqs_need_send[request.request_id] = time.time() + return delay_free_blocks, dict( do_remote_prefill=True, do_remote_decode=False, @@ -720,6 +703,27 @@ class MooncakeConnectorScheduler: remote_port=self.side_channel_port, ) + def get_finished_count(self) -> Optional[int]: + prefill_parallel_config: dict[ + str, + Any] = self.vllm_config.kv_transfer_config.get_from_extra_config( + "prefill", {}) + + assert "tp_size" in prefill_parallel_config.keys() + self._prefill_tp_size = prefill_parallel_config["tp_size"] + decode_parallel_config: dict[ + str, + Any] = self.vllm_config.kv_transfer_config.get_from_extra_config( + "decode", {}) + assert "tp_size" in decode_parallel_config.keys() + self._decode_tp_size = decode_parallel_config["tp_size"] + + if self.vllm_config.model_config.use_mla: + return self._decode_tp_size + else: + # TODO support mha and gqa + return None + class MooncakeConnectorWorker: """Implementation of Worker side methods""" @@ -737,6 +741,7 @@ class MooncakeConnectorWorker: self.engine = TransferEngine() # Metadata. + self.vllm_config = vllm_config self.engine_id = engine_id self.tp_rank = get_tensor_model_parallel_rank() self.tp_size = vllm_config.parallel_config.tensor_parallel_size @@ -946,6 +951,12 @@ class MooncakeConnectorWorker: remote_handshake_port=remote_handshake_port, ) + if self.kv_send_thread is not None: + for req_id, delay_start_time in metadata.requests_to_send.items(): + if self.tp_rank in self._get_remote_tp_ranks_for_req(req_id): + self.kv_send_thread.add_delayed_request( + req_id, delay_start_time) + def _get_remote_tp_rank(self, req_id: str) -> int: return self._get_remote_tp_ranks_for_req(req_id)[self.tp_rank] diff --git a/vllm_ascend/envs.py b/vllm_ascend/envs.py index 04d94a9..78f8c50 100644 --- a/vllm_ascend/envs.py +++ b/vllm_ascend/envs.py @@ -139,6 +139,11 @@ env_variables: Dict[str, Callable[[], Any]] = { # caused by the initialization of the Mooncake connector. "PHYSICAL_DEVICES": lambda: os.getenv("PHYSICAL_DEVICES", None), + # Timeout (in seconds) for delayed KVCache block release. In the prefill + # node, if a request is marked for delayed KV block release and the blocks + # are not freed within this timeout, they will be forcibly released. + "VLLM_ASCEND_KVCACHE_DELAY_FREE_TIMEOUT": + lambda: int(os.getenv("VLLM_ASCEND_KVCACHE_DELAY_FREE_TIMEOUT", 250)), } # end-env-vars-definition