From 332b547728c24a900a40c47a4eba6f5048117efc Mon Sep 17 00:00:00 2001
From: lidenghui1110 <30521952+lidenghui1110@users.noreply.github.com>
Date: Thu, 11 Dec 2025 17:23:21 +0800
Subject: [PATCH] [Bugfix] support mtp kv transfer and pp partition by hand in kv transfer (#4892)

### What this PR does / why we need it?
The current mooncake connector has the following problems when PP and MTP are enabled:
1. MTP-layer KV caches are not transferred, which may lower the acceptance ratio.
   This PR adds the MTP layer indices for the last PP stage after calculating end_layer in transfer_kv_cache.
2. With MTP enabled, the default PP layer split may be imbalanced across stages, so the `VLLM_PP_LAYER_PARTITION` environment variable is needed to balance them by hand; however, during mooncake connector KV transfer the decode node does not know the prefill node's partition.
   This PR adds a `pp_layer_partition` entry to `kv_connector_extra_config` so the decode node can obtain the prefill node's partition information.

### Does this PR introduce _any_ user-facing change?
When prefill uses the `VLLM_PP_LAYER_PARTITION` environment variable, add `pp_layer_partition` to `kv_connector_extra_config` as below:
```
export VLLM_PP_LAYER_PARTITION=33,28

"kv_connector_extra_config": {
    "use_ascend_direct": true,
    "prefill": {
        "dp_size": 1,
        "tp_size": 8,
        "pp_size": 2,
        "pp_layer_partition": "33,28"
    },
    "decode": {
        "dp_size": 16,
        "tp_size": 1,
        "pp_size": 1
    }
}
```

### How was this patch tested?
- vLLM version: v0.12.0
- vLLM main: https://github.com/vllm-project/vllm/commit/ad32e3e19ccf0526cb6744a5fed09a138a5fb2f9

---------

Signed-off-by: lidenghui
---
 .../kv_connector/test_mooncake_connector.py   | 16 +++--
 vllm_ascend/distributed/mooncake_connector.py | 69 ++++++++++++++++---
 2 files changed, 71 insertions(+), 14 deletions(-)

diff --git a/tests/ut/kv_connector/test_mooncake_connector.py b/tests/ut/kv_connector/test_mooncake_connector.py
index 3b305e9e..4db7e887 100644
--- a/tests/ut/kv_connector/test_mooncake_connector.py
+++ b/tests/ut/kv_connector/test_mooncake_connector.py
@@ -242,7 +242,8 @@ class TestKVCacheRecvingThreadBasic(unittest.TestCase):
             block_len=[1024, 2048],
             ready_event=self.ready_event,
             vllm_config=self.vllm_config,
-            kv_caches=self.kv_caches)
+            kv_caches=self.kv_caches,
+            prefill_pp_layer_partition=None)
 
     def test_add_request(self):
         test_req = {
@@ -295,7 +296,8 @@ class TestSocketManagement(unittest.TestCase):
             block_len=[1024, 2048],
             ready_event=self.ready_event,
             vllm_config=self.vllm_config,
-            kv_caches=self.kv_caches)
+            kv_caches=self.kv_caches,
+            prefill_pp_layer_partition=None)
 
         self.thread.remote_sockets = defaultdict(deque)
         self.thread.remote_poller = MagicMock()
@@ -352,7 +354,8 @@ class TestCoreFunctionality(unittest.TestCase):
             block_len=[1024, 2048],
             ready_event=self.ready_event,
             vllm_config=self.vllm_config,
-            kv_caches=self.kv_caches)
+            kv_caches=self.kv_caches,
+            prefill_pp_layer_partition=None)
         self.thread.request_queue = self.mock_queue
         self.test_req = {
             "request_id": "req1",
@@ -434,7 +437,8 @@ class TestMetadataHandling(unittest.TestCase):
             block_len=[1024, 2048],
             ready_event=self.ready_event,
             vllm_config=self.vllm_config,
-            kv_caches=self.kv_caches)
+            kv_caches=self.kv_caches,
+            prefill_pp_layer_partition=None)
         self.test_metadata = MooncakeAgentMetadata(
             engine_id="remote_engine",
             te_rpc_port=9090,
@@ -498,7 +502,8 @@ class TestMainThreadLoop(unittest.TestCase):
             block_len=[1024, 2048],
             ready_event=self.ready_event,
             vllm_config=self.vllm_config,
-            kv_caches=self.kv_caches)
+            kv_caches=self.kv_caches,
+            prefill_pp_layer_partition=None)
         self.thread.request_queue = queue.Queue()
 
     @patch.object(KVCacheRecvingThread, '_handle_request')
@@ -535,6 +540,7 @@ class MockVllmConfig:
         self.parallel_config = MagicMock()
         self.cache_config = MagicMock()
         self.kv_transfer_config = MagicMock()
+        self.speculative_config = MagicMock()
         self.model_config.use_mla = True
         self.parallel_config.tensor_parallel_size = 2
         self.parallel_config.data_parallel_rank = 0
diff --git a/vllm_ascend/distributed/mooncake_connector.py b/vllm_ascend/distributed/mooncake_connector.py
index 2414b92d..7d45c475 100644
--- a/vllm_ascend/distributed/mooncake_connector.py
+++ b/vllm_ascend/distributed/mooncake_connector.py
@@ -271,12 +271,19 @@ class KVCacheSendingThread(threading.Thread):
 
 class KVCacheRecvingThread(threading.Thread):
 
-    def __init__(self, tp_rank: int, tp_size: int, _prefill_pp_size: int,
-                 engine: TransferEngine, local_engine_id: str,
+    def __init__(self,
+                 tp_rank: int,
+                 tp_size: int,
+                 _prefill_pp_size: int,
+                 engine: TransferEngine,
+                 local_engine_id: str,
                  local_handshake_port: int,
-                 local_kv_caches_base_addr: list[int], block_len: list[int],
-                 ready_event: threading.Event, vllm_config: VllmConfig,
-                 kv_caches: dict[str, Any]):
+                 local_kv_caches_base_addr: list[int],
+                 block_len: list[int],
+                 ready_event: threading.Event,
+                 vllm_config: VllmConfig,
+                 kv_caches: dict[str, Any],
+                 prefill_pp_layer_partition: Optional[str] = None):
         super().__init__(daemon=True, name="KVCacheRecvingThread")
         self.tp_rank = tp_rank
         self.tp_size = tp_size
@@ -315,6 +322,14 @@ class KVCacheRecvingThread(threading.Thread):
         self.vllm_config = vllm_config
         self.model_config = self.vllm_config.model_config
         self.block_size = self.vllm_config.cache_config.block_size
+        self.num_layers = self.model_config.hf_config.num_hidden_layers
+        self.pp_layer_indices = {
+            rank:
+            get_prefill_pp_indices(self.num_layers, rank,
+                                   self._prefill_pp_size,
+                                   prefill_pp_layer_partition)
+            for rank in range(self._prefill_pp_size)
+        }
         if self.use_mla:
             self.k_head_dim = self.model_config.hf_config.kv_lora_rank
             self.v_head_dim = self.model_config.hf_config.qk_rope_head_dim
@@ -435,9 +450,14 @@ class KVCacheRecvingThread(threading.Thread):
         remote_kv_caches_base_addrs = \
             self.kv_caches_base_addr[remote_engine_id][remote_handshake_port]
 
-        num_layers = self.model_config.hf_config.num_hidden_layers
-        first_layer_index, end_layer_index = get_pp_indices(
-            num_layers, prefill_pp_rank, self._prefill_pp_size)
+        first_layer_index, end_layer_index = self.pp_layer_indices[
+            prefill_pp_rank]
+        # support MTP layer kv transfer
+        if self.vllm_config.speculative_config is not None:
+            num_speculative_tokens = self.vllm_config.speculative_config.num_speculative_tokens
+            num_speculative_tokens = 0 if num_speculative_tokens is None else num_speculative_tokens
+            if prefill_pp_rank == self._prefill_pp_size - 1:
+                end_layer_index = end_layer_index + num_speculative_tokens
         num_cache_per_layer = len(list(
             self.kv_caches.values())[0])  # Number of KV caches per layer
         local_kv_caches_base_addrs = \
@@ -1020,6 +1040,8 @@ class MooncakeConnectorWorker:
             # get prefill pp size from extra config
             self._decode_pp_size = decode_parallel_config.get("pp_size", 1)
             assert self._decode_pp_size == 1, "decode pp size must be 1"
+            self._prefill_pp_layer_partition = prefill_parallel_config.get(
+                "pp_layer_partition", None)
 
     def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
         """Register the KV Cache data."""
@@ -1126,7 +1148,8 @@ class MooncakeConnectorWorker:
             self.kv_recv_thread = KVCacheRecvingThread(
                 self.tp_rank, self.tp_size,
                 self._prefill_pp_size, self.engine, self.engine_id,
                 self.handshake_port, kv_caches_base_addr,
-                self.block_len, ready_event, self.vllm_config, self.kv_caches)
+                self.block_len, ready_event, self.vllm_config, self.kv_caches,
+                self._prefill_pp_layer_partition)
             self.kv_recv_thread.start()
             ready_event.wait()
@@ -1455,3 +1478,31 @@ def ensure_zmq_recv(
             raise RuntimeError(
                 f"Failed to receive data after {max_retries} "
                 f"retries: {e}")
+
+
+# decode node should know pp_partition_layer in prefill node,
+# it is configured in kv_transfer_config by partition_list_str,
+# default using vllm layer split algorithm.
+def get_prefill_pp_indices(
+        num_hidden_layers: int,
+        pp_rank: int,
+        pp_size: int,
+        partition_list_str: Optional[str] = None) -> tuple[int, int]:
+    if partition_list_str is None:
+        return get_pp_indices(num_hidden_layers, pp_rank, pp_size)
+    else:
+        try:
+            partitions = [
+                int(layer) for layer in partition_list_str.split(",")
+            ]
+        except ValueError as err:
+            raise ValueError("Invalid partition string: {}".format(
+                partition_list_str)) from err
+        if len(partitions) != pp_size:
+            raise ValueError(f"{len(partitions)=} does not match {pp_size=}.")
+        if sum(partitions) != num_hidden_layers:
+            raise ValueError(
+                f"{sum(partitions)=} does not match {num_hidden_layers=}.")
+        start_layer = sum(partitions[:pp_rank])
+        end_layer = start_layer + partitions[pp_rank]
+        return (start_layer, end_layer)
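
For illustration, here is a minimal standalone sketch (not part of the patch) of the partition arithmetic the connector now applies. It assumes a hypothetical 61-layer model so that the example partition `33,28` from the description sums correctly, and folds in the MTP extension that `transfer_kv_cache` applies only to the last prefill PP stage; the name `prefill_pp_indices` and the even-split fallback are illustrative stand-ins, not vLLM's actual `get_pp_indices`.
```
from typing import Optional


def prefill_pp_indices(num_hidden_layers: int,
                       pp_rank: int,
                       pp_size: int,
                       partition_list_str: Optional[str] = None,
                       num_speculative_tokens: int = 0) -> tuple[int, int]:
    if partition_list_str is not None:
        # Manual partition: "33,28" -> [33, 28]; must match pp_size and layer count.
        partitions = [int(p) for p in partition_list_str.split(",")]
        assert len(partitions) == pp_size
        assert sum(partitions) == num_hidden_layers
        start = sum(partitions[:pp_rank])
        end = start + partitions[pp_rank]
    else:
        # Simplified even split; vLLM's real get_pp_indices handles remainders
        # and VLLM_PP_LAYER_PARTITION itself.
        per_rank = num_hidden_layers // pp_size
        start = pp_rank * per_rank
        end = num_hidden_layers if pp_rank == pp_size - 1 else start + per_rank
    # MTP (speculative) layers sit after the last stage's layers, so only the
    # last prefill PP rank transfers the extra KV-cache layers.
    if pp_rank == pp_size - 1:
        end += num_speculative_tokens
    return start, end


if __name__ == "__main__":
    for rank in range(2):
        print(rank, prefill_pp_indices(61, rank, 2, "33,28",
                                       num_speculative_tokens=1))
    # rank 0 -> (0, 33); rank 1 -> (33, 62), where layer 61 is the MTP layer.
```
With these ranges the decode node addresses exactly the layer slots each prefill PP rank holds, including the MTP layer on the last stage, instead of recomputing an even split that would not match a hand-tuned `VLLM_PP_LAYER_PARTITION`.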