From f3ea657e932dbecff2df7bb9e8fa50a167d3dd1d Mon Sep 17 00:00:00 2001 From: liziyu <56102866+liziyu179@users.noreply.github.com> Date: Thu, 23 Oct 2025 20:39:35 +0800 Subject: [PATCH] [0.11.0][Bugfix] fix delay free prefill req & D node support prefix cache (#3609) ### What this PR does / why we need it? Fix mooncake connector. In scenarios where TP is not equal, when the prefill TP size is less than the number of key-value heads, _get_remote_tp_ranks_for_req will return a list of np.arrays. Performing an operation like int in list of np.arrays will cause an error. Converting the list of np.arrays into a single np.array resolves this issue. ### Does this PR introduce _any_ user-facing change? ### How was this patch tested? qwen235B P tp16, D tp1 P tp8, D tp1 P tp4, D tp1 P tp8, D tp2 - vLLM version: v0.11.0rc3 - vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.0 --------- Signed-off-by: liziyu --- .../kv_connector/test_mooncake_connector.py | 24 +++++++-------- vllm_ascend/distributed/mooncake_connector.py | 29 ++++++++++++------- 2 files changed, 29 insertions(+), 24 deletions(-) diff --git a/tests/ut/kv_connector/test_mooncake_connector.py b/tests/ut/kv_connector/test_mooncake_connector.py index 7bb2142..6c6c609 100644 --- a/tests/ut/kv_connector/test_mooncake_connector.py +++ b/tests/ut/kv_connector/test_mooncake_connector.py @@ -628,10 +628,12 @@ class TestMooncakeConnectorSchedulerMatchedTokens(unittest.TestCase): def setUp(self): config = MockVllmConfig() - with patch( - 'vllm_ascend.distributed.mooncake_connector.init_ascend_config' - ): - self.scheduler = MooncakeConnectorScheduler(config, "test_engine") + self.p1 = patch( + 'vllm_ascend.distributed.mooncake_layerwise_connector.get_ascend_config', + new=MagicMock(return_value=None)) + self.p1.start() + self.addCleanup(self.p1.stop) + self.scheduler = MooncakeConnectorScheduler(config, "test_engine") def test_get_num_new_matched_tokens(self): request = MockRequest("req1") @@ -643,7 +645,7 @@ class TestMooncakeConnectorSchedulerMatchedTokens(unittest.TestCase): request.kv_transfer_params = {"do_remote_prefill": True} tokens, async_flag = self.scheduler.get_num_new_matched_tokens( request, 0) - self.assertEqual(tokens, 3) + self.assertEqual(tokens, 4) self.assertTrue(async_flag) def test_build_connector_meta(self): @@ -820,7 +822,7 @@ class TestMooncakeConnectorScheduler(unittest.TestCase): kv_transfer_params={"do_remote_prefill": True}) tokens, async_flag = self.scheduler.get_num_new_matched_tokens( request, 0) - self.assertEqual(tokens, 3) + self.assertEqual(tokens, 4) self.assertTrue(async_flag) def test_update_state_after_alloc_no_remote_prefill(self): @@ -1036,8 +1038,9 @@ class TestMooncakeConnectorWorker(unittest.TestCase): patch( 'vllm_ascend.distributed.mooncake_connector.string_to_int64_hash', mock_string_to_int64_hash), - patch('vllm_ascend.distributed.mooncake_connector.TransferEngine', - return_value=self.mock_transfer_engine), + patch( + 'vllm_ascend.distributed.mooncake.transfer_engine.TransferEngine', + return_value=self.mock_transfer_engine), patch( 'vllm_ascend.distributed.mooncake_connector.KVCacheSendingThread', MagicMock()), @@ -1063,7 +1066,6 @@ class TestMooncakeConnectorWorker(unittest.TestCase): for p in self.patches: p.stop() # type: ignore - @unittest.skip("skip") def test_worker_use_ascend_direct(self): test_case = [True, False] @@ -1104,7 +1106,6 @@ class TestMooncakeConnectorWorker(unittest.TestCase): config, self.engine_id) self.assertIsNotNone(worker) - @unittest.skip("skip") def test_register_kv_caches_producer(self): worker = MooncakeConnectorWorker(self.vllm_config, self.engine_id) worker.register_kv_caches(self.kv_caches) @@ -1112,7 +1113,6 @@ class TestMooncakeConnectorWorker(unittest.TestCase): self.assertIsNotNone(worker.kv_send_thread) self.assertIsNone(worker.kv_recv_thread) - @unittest.skip("skip") def test_register_kv_caches_consumer(self): self.vllm_config.kv_transfer_config.kv_role = 'kv_consumer' worker = MooncakeConnectorWorker(self.vllm_config, self.engine_id) @@ -1120,7 +1120,6 @@ class TestMooncakeConnectorWorker(unittest.TestCase): self.assertIsNone(worker.kv_send_thread) self.assertIsNotNone(worker.kv_recv_thread) - @unittest.skip("skip") def test_register_kv_caches_mla_case(self): mla_cache1 = MagicMock() mla_cache1.size.return_value = (10, 16, 1, 16) @@ -1133,7 +1132,6 @@ class TestMooncakeConnectorWorker(unittest.TestCase): self.assertTrue(worker.use_mla) self.assertEqual(len(worker.block_len), 2) - @unittest.skip("skip") def test_device_id_selection_with_physical_devices(self): # Test with physical devices set worker = MooncakeConnectorWorker(self.vllm_config, self.engine_id) diff --git a/vllm_ascend/distributed/mooncake_connector.py b/vllm_ascend/distributed/mooncake_connector.py index 23dfb32..57b4494 100644 --- a/vllm_ascend/distributed/mooncake_connector.py +++ b/vllm_ascend/distributed/mooncake_connector.py @@ -342,9 +342,15 @@ class KVCacheRecvingThread(threading.Thread): # Full prefix cache hit: do not need to read remote blocks, just notify # P worker that we have the blocks we need. - if len(local_block_ids) == 0: + num_local_blocks = len(local_block_ids) + if num_local_blocks == 0: return + num_remote_blocks = len(remote_block_ids) + assert num_local_blocks <= num_remote_blocks + if num_local_blocks < num_remote_blocks: + remote_block_ids = remote_block_ids[-num_local_blocks:] + # Check if we have the remote metadata cached. if remote_engine_id not in self.kv_caches_base_addr or \ remote_handshake_port not in self.kv_caches_base_addr[remote_engine_id]: @@ -736,13 +742,11 @@ class MooncakeConnectorScheduler: num_computed_tokens, params) if params is not None and params.get("do_remote_prefill"): - assert num_computed_tokens == 0, "Currently only support " \ - "prefill with num_computed_tokens == 0." - # Assume that the request's KV cache is already fully prefilled and - # can be fetched entirely from the prefill node. - count = max(len(request.prompt_token_ids) - 1, 0) - if count > 0: - return count, True + # Remote prefill: get all prompt blocks from remote. + assert num_computed_tokens % self.block_size == 0 + # Note: We use the full token count as transmit data here. + count = max(len(request.prompt_token_ids) - num_computed_tokens, 0) + return count, count > 0 # No remote prefill for this request. return 0, False @@ -1137,10 +1141,13 @@ class MooncakeConnectorWorker: if self.kv_send_thread is not None: for req_id, delay_start_time in metadata.requests_to_send.items(): - if self.tp_rank in self._get_remote_tp_ranks_for_req(req_id): + if self.tp_rank in self._prefill_get_remote_tp_rank(req_id): self.kv_send_thread.add_delayed_request( req_id, delay_start_time) + def _prefill_get_remote_tp_rank(self, req_id: str) -> List[int]: + return sum(self._get_remote_tp_ranks_for_req(req_id), []) + def _get_remote_tp_rank(self, req_id: str) -> List[int]: return self._get_remote_tp_ranks_for_req(req_id)[self.tp_rank] @@ -1176,8 +1183,8 @@ class MooncakeConnectorWorker: else: group_size = self._prefill_tp_size // self._decode_tp_size for i in range(self._decode_tp_size): - slice = ori_data[i * group_size:(i + 1) * group_size] - sampled_nums.append(slice) + ori_data_slice = ori_data[i * group_size:(i + 1) * group_size] + sampled_nums.append(ori_data_slice.tolist()) return sampled_nums