diff --git a/tests/ut/kv_connector/test_mooncake_connector.py b/tests/ut/kv_connector/test_mooncake_connector.py index 13a24596..20ae60f0 100644 --- a/tests/ut/kv_connector/test_mooncake_connector.py +++ b/tests/ut/kv_connector/test_mooncake_connector.py @@ -89,7 +89,7 @@ class TestKVCacheSendingThreadInit(unittest.TestCase): kv_caches: Dict[str, Any] = {} self.common_args = { 'tp_rank': 1, - 'decode_tp_size': 4, + 'prefill_tp_size': 4, 'local_engine_id': 'engine_1', 'side_channel_host': 'localhost', 'side_channel_port': 5555, @@ -133,7 +133,7 @@ class TestGetAndClearFinishedRequests(unittest.TestCase): kv_caches: Dict[str, Any] = {} self.common_args = { 'tp_rank': 1, - 'decode_tp_size': 4, + 'prefill_tp_size': 4, 'local_engine_id': 'engine_1', 'side_channel_host': 'localhost', 'side_channel_port': 5555, @@ -171,7 +171,7 @@ class TestKVCacheSendingThread(unittest.TestCase): free_port = s.getsockname()[1] thread = KVCacheSendingThread(tp_rank=0, - decode_tp_size=1, + prefill_tp_size=1, local_engine_id="engine1", side_channel_host=host, side_channel_port=free_port, @@ -237,7 +237,8 @@ class TestKVCacheRecvingThreadBasic(unittest.TestCase): "remote_host": "localhost", "remote_handshake_port": 6666, "offset": 0, - "num_need_pulls": 2 + "num_need_pulls": 2, + "all_task_done": False } self.thread.add_request( request_id=test_req["request_id"], @@ -247,7 +248,8 @@ class TestKVCacheRecvingThreadBasic(unittest.TestCase): remote_host=test_req["remote_host"], remote_handshake_port=test_req["remote_handshake_port"], offset=test_req["offset"], - num_need_pulls=test_req["num_need_pulls"]) + num_need_pulls=test_req["num_need_pulls"], + all_task_done=test_req["all_task_done"]) queued = self.thread.request_queue.get_nowait() self.assertEqual(queued["request_id"], "req1") self.assertEqual(queued["remote_host"], "localhost") @@ -341,7 +343,8 @@ class TestCoreFunctionality(unittest.TestCase): "remote_handshake_port": 6666, "remote_transfer_port": 7777, "offset": 0, - "num_need_pulls": 2 + "num_need_pulls": 2, + "all_task_done": False } self.thread.task_tracker = MagicMock() self.engine.batch_transfer_sync_read.return_value = 0 @@ -485,7 +488,8 @@ class TestMainThreadLoop(unittest.TestCase): "remote_handshake_port": 6666, "remote_transfer_port": 7777, "offset": 0, - "num_need_pulls": 2 + "num_need_pulls": 2, + "all_task_done": False } self.thread.request_queue.put(test_request) diff --git a/vllm_ascend/distributed/mooncake_connector.py b/vllm_ascend/distributed/mooncake_connector.py index 3ca17a59..cf3bbaa0 100644 --- a/vllm_ascend/distributed/mooncake_connector.py +++ b/vllm_ascend/distributed/mooncake_connector.py @@ -150,13 +150,14 @@ class KVCacheTaskTracker: class KVCacheSendingThread(threading.Thread): - def __init__(self, tp_rank: int, decode_tp_size: int, local_engine_id: str, - side_channel_host: str, side_channel_port: int, - metadata: MooncakeAgentMetadata, ready_event: threading.Event, - kv_caches: dict[str, Any], pcp_rank: int): + def __init__(self, tp_rank: int, prefill_tp_size: int, + local_engine_id: str, side_channel_host: str, + side_channel_port: int, metadata: MooncakeAgentMetadata, + ready_event: threading.Event, kv_caches: dict[str, Any], + pcp_rank: int): super().__init__(daemon=True, name="KVCacheSendingThread") self.tp_rank = tp_rank - self.decode_tp_size = decode_tp_size + self.prefill_tp_size = prefill_tp_size self.local_engine_id = local_engine_id self.side_channel_host = side_channel_host self.side_channel_port = side_channel_port @@ -195,7 +196,7 @@ class KVCacheSendingThread(threading.Thread): # NOTE(rob): we need each rank to have a unique port. This hack to keeps # us moving. We will switch when moving to etcd or where we have a # single ZMQ socket in the scheduler. - handshake_port = self.side_channel_port + self.pcp_rank * self.decode_tp_size \ + handshake_port = self.side_channel_port + self.pcp_rank * self.prefill_tp_size \ + self.tp_rank path = make_zmq_path("tcp", self.side_channel_host, handshake_port) logger.info("Starting listening on path: %s", path) @@ -295,7 +296,7 @@ class KVCacheRecvingThread(threading.Thread): def add_request(self, request_id: str, local_block_ids: list[int], remote_block_ids: list[int], remote_engine_id: str, remote_host: str, remote_handshake_port: int, offset: int, - num_need_pulls: int): + num_need_pulls: int, all_task_done: bool): """Add a new request to the queue for processing.""" logger.debug(f"Adding request {request_id} to the queue.") self.request_queue.put({ @@ -306,7 +307,8 @@ class KVCacheRecvingThread(threading.Thread): "remote_host": remote_host, "remote_handshake_port": remote_handshake_port, "offset": offset, - "num_need_pulls": num_need_pulls + "num_need_pulls": num_need_pulls, + "all_task_done": all_task_done }) def get_and_clear_finished_requests(self) -> set[str]: @@ -335,8 +337,7 @@ class KVCacheRecvingThread(threading.Thread): request_id = req_meta["request_id"] remote_host = req_meta["remote_host"] remote_handshake_port = req_meta["remote_handshake_port"] - offset = req_meta["offset"] - num_need_pulls = req_meta["num_need_pulls"] + all_task_done = req_meta["all_task_done"] try: logger.debug( @@ -353,7 +354,7 @@ class KVCacheRecvingThread(threading.Thread): # remote host. self._send_done_recv_signal(request_id, remote_host, remote_handshake_port) - if offset == num_need_pulls - 1: + if all_task_done: self.task_tracker.update_done_task_count(request_id) self.request_queue.task_done() @@ -1091,7 +1092,7 @@ class MooncakeConnectorWorker: ready_event = threading.Event() if self.kv_role == 'kv_producer': self.kv_send_thread = KVCacheSendingThread( - self.tp_rank, self._decode_tp_size, self.engine_id, + self.tp_rank, self._prefill_tp_size, self.engine_id, self.side_channel_host, self.side_channel_port, metadata, ready_event, self.kv_caches, self.pcp_rank) self.kv_send_thread.start() @@ -1239,7 +1240,10 @@ class MooncakeConnectorWorker: remote_handshake_port=remote_handshake_port_list[ pcp_dcp_rank][i], offset=i, - num_need_pulls=self.num_need_pulls) + num_need_pulls=self.num_need_pulls, + all_task_done=(pcp_dcp_rank + == len(remote_handshake_port_list) - 1 + and i == self.num_need_pulls - 1)) if self.kv_send_thread is not None: for req_id, delay_start_time in metadata.requests_to_send.items():