diff --git a/tests/ut/kv_connector/test_mooncake_connector.py b/tests/ut/kv_connector/test_mooncake_connector.py index a0ef02f1..0b7ff138 100644 --- a/tests/ut/kv_connector/test_mooncake_connector.py +++ b/tests/ut/kv_connector/test_mooncake_connector.py @@ -246,7 +246,8 @@ class TestKVCacheRecvingThreadBasic(unittest.TestCase): block_len=[1024, 2048], ready_event=self.ready_event, vllm_config=self.vllm_config, - kv_caches=self.kv_caches) + kv_caches=self.kv_caches, + prefill_pp_layer_partition=None) def test_add_request(self): test_req = { @@ -300,7 +301,8 @@ class TestSocketManagement(unittest.TestCase): block_len=[1024, 2048], ready_event=self.ready_event, vllm_config=self.vllm_config, - kv_caches=self.kv_caches) + kv_caches=self.kv_caches, + prefill_pp_layer_partition=None) self.thread.remote_sockets = defaultdict(deque) self.thread.remote_poller = MagicMock() @@ -358,7 +360,8 @@ class TestCoreFunctionality(unittest.TestCase): block_len=[1024, 2048], ready_event=self.ready_event, vllm_config=self.vllm_config, - kv_caches=self.kv_caches) + kv_caches=self.kv_caches, + prefill_pp_layer_partition=None) self.thread.request_queue = self.mock_queue self.test_req = { "request_id": "req1", @@ -444,7 +447,8 @@ class TestMetadataHandling(unittest.TestCase): block_len=[1024, 2048], ready_event=self.ready_event, vllm_config=self.vllm_config, - kv_caches=self.kv_caches) + kv_caches=self.kv_caches, + prefill_pp_layer_partition=None) self.test_metadata = MooncakeAgentMetadata( engine_id="remote_engine", te_rpc_port=9090, @@ -509,7 +513,8 @@ class TestMainThreadLoop(unittest.TestCase): block_len=[1024, 2048], ready_event=self.ready_event, vllm_config=self.vllm_config, - kv_caches=self.kv_caches) + kv_caches=self.kv_caches, + prefill_pp_layer_partition=None) self.thread.request_queue = queue.Queue() @patch.object(KVCacheRecvingThread, '_handle_request') @@ -546,6 +551,7 @@ class MockVllmConfig: self.parallel_config = MagicMock() self.cache_config = MagicMock() 
self.kv_transfer_config = MagicMock() + self.speculative_config = MagicMock() self.model_config.use_mla = True self.parallel_config.tensor_parallel_size = 2 self.parallel_config.data_parallel_rank = 0 diff --git a/vllm_ascend/distributed/mooncake_connector.py b/vllm_ascend/distributed/mooncake_connector.py index 603a89b8..2b0fe92a 100644 --- a/vllm_ascend/distributed/mooncake_connector.py +++ b/vllm_ascend/distributed/mooncake_connector.py @@ -292,12 +292,20 @@ class KVCacheSendingThread(threading.Thread): class KVCacheRecvingThread(threading.Thread): - def __init__(self, tp_rank: int, tp_size: int, _prefill_pp_size: int, - engine: TransferEngine, local_engine_id: str, - local_handshake_port: int, side_channel_port: int, - local_kv_caches_base_addr: list[int], block_len: list[int], - ready_event: threading.Event, vllm_config: VllmConfig, - kv_caches: dict[str, Any]): + def __init__(self, + tp_rank: int, + tp_size: int, + _prefill_pp_size: int, + engine: TransferEngine, + local_engine_id: str, + local_handshake_port: int, + side_channel_port: int, + local_kv_caches_base_addr: list[int], + block_len: list[int], + ready_event: threading.Event, + vllm_config: VllmConfig, + kv_caches: dict[str, Any], + prefill_pp_layer_partition: Optional[str] = None): super().__init__(daemon=True, name="KVCacheRecvingThread") self.tp_rank = tp_rank self.tp_size = tp_size @@ -337,6 +345,14 @@ class KVCacheRecvingThread(threading.Thread): self.vllm_config = vllm_config self.model_config = self.vllm_config.model_config self.block_size = self.vllm_config.cache_config.block_size + self.num_layers = self.model_config.hf_config.num_hidden_layers + self.pp_layer_indices = { + rank: + get_prefill_pp_indices(self.num_layers, rank, + self._prefill_pp_size, + prefill_pp_layer_partition) + for rank in range(self._prefill_pp_size) + } if not is_vl_model(vllm_config): if self.use_mla: self.k_head_dim = self.model_config.hf_text_config.kv_lora_rank @@ -490,9 +506,13 @@ class 
KVCacheRecvingThread(threading.Thread): remote_kv_caches_base_addrs = \ self.kv_caches_base_addr[remote_engine_id][remote_handshake_port] - num_layers = self.model_config.hf_text_config.num_hidden_layers - first_layer_index, end_layer_index = get_pp_indices( - num_layers, prefill_pp_rank, self._prefill_pp_size) + first_layer_index, end_layer_index = self.pp_layer_indices[ + prefill_pp_rank] + # support MTP layer kv transfer + if self.vllm_config.speculative_config is not None: + # all MTP layer use the same kv cache layer, so only need to transfer once + if prefill_pp_rank == self._prefill_pp_size - 1: + end_layer_index = end_layer_index + 1 num_cache_per_layer = len(list( self.kv_caches.values())[0]) # Number of KV caches per layer local_kv_caches_base_addrs = \ @@ -1161,6 +1181,8 @@ class MooncakeConnectorWorker: # get prefill pp size from extra config self._decode_pp_size = decode_parallel_config.get("pp_size", 1) assert self._decode_pp_size == 1, "decode pp size must be 1" + self._prefill_pp_layer_partition = prefill_parallel_config.get( + "pp_layer_partition", None) def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): """Register the KV Cache data.""" @@ -1270,7 +1292,8 @@ class MooncakeConnectorWorker: self.tp_rank, self.tp_size, self._prefill_pp_size, self.engine, self.engine_id, self.handshake_port, self.side_channel_port, kv_caches_base_addr, self.block_len, ready_event, - self.vllm_config, self.kv_caches) + self.vllm_config, self.kv_caches, + self._prefill_pp_layer_partition) self.kv_recv_thread.start() start_wait_time = time.time() @@ -1761,3 +1784,31 @@ def ensure_zmq_recv( raise RuntimeError( f"Failed to receive data after {max_retries} " f"retries: {e}") + + +# decode node should know pp_partition_layer in prefill node, +# it is configured in kv_transfer_config by partition_list_str, +# default using vllm layer split algorithm. 
def get_prefill_pp_indices(
        num_hidden_layers: int,
        pp_rank: int,
        pp_size: int,
        partition_list_str: Optional[str] = None) -> tuple[int, int]:
    """Return the ``[start, end)`` hidden-layer range owned by a prefill PP rank.

    The decode node must mirror the prefill node's pipeline-parallel layer
    split. An explicit split can be supplied via ``partition_list_str`` — a
    comma-separated list of per-rank layer counts taken from
    ``kv_transfer_config``; when it is ``None``, vLLM's default layer-split
    algorithm (``get_pp_indices``) is used instead.

    Raises:
        ValueError: if the partition string is not a comma-separated list of
            integers, its length differs from ``pp_size``, or its sum differs
            from ``num_hidden_layers``.
    """
    # Fast path: no override configured — defer to vLLM's own split.
    if partition_list_str is None:
        return get_pp_indices(num_hidden_layers, pp_rank, pp_size)

    try:
        partitions = [int(part) for part in partition_list_str.split(",")]
    except ValueError as err:
        raise ValueError(
            "Invalid partition string: {}".format(partition_list_str)) from err

    # The override must describe exactly one count per rank and account for
    # every hidden layer.
    if len(partitions) != pp_size:
        raise ValueError(f"{len(partitions)=} does not match {pp_size=}.")
    if sum(partitions) != num_hidden_layers:
        raise ValueError(
            f"{sum(partitions)=} does not match {num_hidden_layers=}.")

    start_layer = sum(partitions[:pp_rank])
    return (start_layer, start_layer + partitions[pp_rank])