diff --git a/tests/ut/kv_connector/test_mooncake_connector.py b/tests/ut/kv_connector/test_mooncake_connector.py index 4db7e887..3b305e9e 100644 --- a/tests/ut/kv_connector/test_mooncake_connector.py +++ b/tests/ut/kv_connector/test_mooncake_connector.py @@ -242,8 +242,7 @@ class TestKVCacheRecvingThreadBasic(unittest.TestCase): block_len=[1024, 2048], ready_event=self.ready_event, vllm_config=self.vllm_config, - kv_caches=self.kv_caches, - prefill_pp_layer_partition=None) + kv_caches=self.kv_caches) def test_add_request(self): test_req = { @@ -296,8 +295,7 @@ class TestSocketManagement(unittest.TestCase): block_len=[1024, 2048], ready_event=self.ready_event, vllm_config=self.vllm_config, - kv_caches=self.kv_caches, - prefill_pp_layer_partition=None) + kv_caches=self.kv_caches) self.thread.remote_sockets = defaultdict(deque) self.thread.remote_poller = MagicMock() @@ -354,8 +352,7 @@ class TestCoreFunctionality(unittest.TestCase): block_len=[1024, 2048], ready_event=self.ready_event, vllm_config=self.vllm_config, - kv_caches=self.kv_caches, - prefill_pp_layer_partition=None) + kv_caches=self.kv_caches) self.thread.request_queue = self.mock_queue self.test_req = { "request_id": "req1", @@ -437,8 +434,7 @@ class TestMetadataHandling(unittest.TestCase): block_len=[1024, 2048], ready_event=self.ready_event, vllm_config=self.vllm_config, - kv_caches=self.kv_caches, - prefill_pp_layer_partition=None) + kv_caches=self.kv_caches) self.test_metadata = MooncakeAgentMetadata( engine_id="remote_engine", te_rpc_port=9090, @@ -502,8 +498,7 @@ class TestMainThreadLoop(unittest.TestCase): block_len=[1024, 2048], ready_event=self.ready_event, vllm_config=self.vllm_config, - kv_caches=self.kv_caches, - prefill_pp_layer_partition=None) + kv_caches=self.kv_caches) self.thread.request_queue = queue.Queue() @patch.object(KVCacheRecvingThread, '_handle_request') @@ -540,7 +535,6 @@ class MockVllmConfig: self.parallel_config = MagicMock() self.cache_config = MagicMock() self.kv_transfer_config = MagicMock() - self.speculative_config = MagicMock() self.model_config.use_mla = True self.parallel_config.tensor_parallel_size = 2 self.parallel_config.data_parallel_rank = 0 diff --git a/vllm_ascend/distributed/mooncake_connector.py b/vllm_ascend/distributed/mooncake_connector.py index 1d37498f..b8b66b4c 100644 --- a/vllm_ascend/distributed/mooncake_connector.py +++ b/vllm_ascend/distributed/mooncake_connector.py @@ -276,19 +276,12 @@ class KVCacheSendingThread(threading.Thread): class KVCacheRecvingThread(threading.Thread): - def __init__(self, - tp_rank: int, - tp_size: int, - _prefill_pp_size: int, - engine: TransferEngine, - local_engine_id: str, + def __init__(self, tp_rank: int, tp_size: int, _prefill_pp_size: int, + engine: TransferEngine, local_engine_id: str, local_handshake_port: int, - local_kv_caches_base_addr: list[int], - block_len: list[int], - ready_event: threading.Event, - vllm_config: VllmConfig, - kv_caches: dict[str, Any], - prefill_pp_layer_partition: Optional[str] = None): + local_kv_caches_base_addr: list[int], block_len: list[int], + ready_event: threading.Event, vllm_config: VllmConfig, + kv_caches: dict[str, Any]): super().__init__(daemon=True, name="KVCacheRecvingThread") self.tp_rank = tp_rank self.tp_size = tp_size @@ -327,14 +320,6 @@ class KVCacheRecvingThread(threading.Thread): self.vllm_config = vllm_config self.model_config = self.vllm_config.model_config self.block_size = self.vllm_config.cache_config.block_size - self.num_layers = self.model_config.hf_config.num_hidden_layers - self.pp_layer_indices = { - rank: - get_prefill_pp_indices(self.num_layers, rank, - self._prefill_pp_size, - prefill_pp_layer_partition) - for rank in range(self._prefill_pp_size) - } if self.use_mla: self.k_head_dim = self.model_config.hf_config.kv_lora_rank self.v_head_dim = self.model_config.hf_config.qk_rope_head_dim @@ -455,14 +440,9 @@ class KVCacheRecvingThread(threading.Thread): remote_kv_caches_base_addrs = \ self.kv_caches_base_addr[remote_engine_id][remote_handshake_port] - first_layer_index, end_layer_index = self.pp_layer_indices[ - prefill_pp_rank] - # support MTP layer kv transfer - if self.vllm_config.speculative_config is not None: - num_speculative_tokens = self.vllm_config.speculative_config.num_speculative_tokens - num_speculative_tokens = 0 if num_speculative_tokens is None else num_speculative_tokens - if prefill_pp_rank == self._prefill_pp_size - 1: - end_layer_index = end_layer_index + num_speculative_tokens + num_layers = self.model_config.hf_config.num_hidden_layers + first_layer_index, end_layer_index = get_pp_indices( + num_layers, prefill_pp_rank, self._prefill_pp_size) num_cache_per_layer = len(list( self.kv_caches.values())[0]) # Number of KV caches per layer local_kv_caches_base_addrs = \ @@ -1045,8 +1025,6 @@ class MooncakeConnectorWorker: # get prefill pp size from extra config self._decode_pp_size = decode_parallel_config.get("pp_size", 1) assert self._decode_pp_size == 1, "decode pp size must be 1" - self._prefill_pp_layer_partition = prefill_parallel_config.get( - "pp_layer_partition", None) def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): """Register the KV Cache data.""" @@ -1153,8 +1131,7 @@ class MooncakeConnectorWorker: self.kv_recv_thread = KVCacheRecvingThread( self.tp_rank, self.tp_size, self._prefill_pp_size, self.engine, self.engine_id, self.handshake_port, kv_caches_base_addr, - self.block_len, ready_event, self.vllm_config, self.kv_caches, - self._prefill_pp_layer_partition) + self.block_len, ready_event, self.vllm_config, self.kv_caches) self.kv_recv_thread.start() start_wait_time = time.time() @@ -1494,31 +1471,3 @@ def ensure_zmq_recv( raise RuntimeError( f"Failed to receive data after {max_retries} " f"retries: {e}") - - -# decode node should know pp_partition_layer in prefill node, -# it is configured in kv_transfer_config by partition_list_str, -# default using vllm layer split algorithm. -def get_prefill_pp_indices( - num_hidden_layers: int, - pp_rank: int, - pp_size: int, - partition_list_str: Optional[str] = None) -> tuple[int, int]: - if partition_list_str is None: - return get_pp_indices(num_hidden_layers, pp_rank, pp_size) - else: - try: - partitions = [ - int(layer) for layer in partition_list_str.split(",") - ] - except ValueError as err: - raise ValueError("Invalid partition string: {}".format( - partition_list_str)) from err - if len(partitions) != pp_size: - raise ValueError(f"{len(partitions)=} does not match {pp_size=}.") - if sum(partitions) != num_hidden_layers: - raise ValueError( - f"{sum(partitions)=} does not match {num_hidden_layers=}.") - start_layer = sum(partitions[:pp_rank]) - end_layer = start_layer + partitions[pp_rank] - return (start_layer, end_layer)