From 0d04ad8c8f00081932566576da1e6d1dcd963d8d Mon Sep 17 00:00:00 2001
From: wangxiaochao
Date: Tue, 18 Nov 2025 10:17:48 +0800
Subject: [PATCH] [feature] Mooncake_connector support pcp/dcp (#4183)

Add pcp/dcp (prefill/decode context parallel) support to the Mooncake
connector.

- vLLM version: v0.11.0
- vLLM main: https://github.com/vllm-project/vllm/commit/2918c1b49c88c29783c86f78d2c4221cb9622379

---------

Signed-off-by: wangxiaochao
Co-authored-by: wangxiaochao
---
 .../kv_connector/test_mooncake_connector.py   |  31 ++-
 vllm_ascend/distributed/mooncake_connector.py | 186 +++++++++++++++---
 2 files changed, 180 insertions(+), 37 deletions(-)

diff --git a/tests/ut/kv_connector/test_mooncake_connector.py b/tests/ut/kv_connector/test_mooncake_connector.py
index a5bc066f..13a24596 100644
--- a/tests/ut/kv_connector/test_mooncake_connector.py
+++ b/tests/ut/kv_connector/test_mooncake_connector.py
@@ -12,6 +12,7 @@
 from unittest.mock import MagicMock, patch

 import msgspec
 import zmq
+from vllm.distributed.parallel_state import GroupCoordinator

 from vllm_ascend.utils import vllm_version_is
@@ -94,7 +95,8 @@ class TestKVCacheSendingThreadInit(unittest.TestCase):
             'side_channel_port': 5555,
             'metadata': MagicMock(),
             'ready_event': threading.Event(),
-            'kv_caches': kv_caches
+            'kv_caches': kv_caches,
+            'pcp_rank': 0
         }

         self.threads = []
@@ -139,7 +141,8 @@ class TestGetAndClearFinishedRequests(unittest.TestCase):
                 "test": "metadata"
             },
             'ready_event': threading.Event(),
-            'kv_caches': kv_caches
+            'kv_caches': kv_caches,
+            'pcp_rank': 0
         }

         self.thread = KVCacheSendingThread(**self.common_args)
@@ -174,7 +177,8 @@ class TestKVCacheSendingThread(unittest.TestCase):
                                       side_channel_port=free_port,
                                       metadata=metadata,
                                       ready_event=ready_event,
-                                      kv_caches={})
+                                      kv_caches={},
+                                      pcp_rank=0)
         thread.start()
         self.assertTrue(ready_event.wait(timeout=3),
                         "Server thread startup timeout")
@@ -617,7 +621,9 @@ class TestMooncakeConnectorMetadata(unittest.TestCase):
             "remote_block_ids": [4, 5, 6],
             "remote_engine_id": "remote_engine",
             "remote_host": "localhost",
-            "remote_port": 5000
+            "remote_port": 5000,
+            "remote_pcp_size": 1,
+            "remote_dcp_size": 1
         })

         self.assertEqual(len(meta.requests), 1)
@@ -663,7 +669,9 @@ class TestMooncakeConnectorSchedulerMatchedTokens(unittest.TestCase):
             "remote_block_ids": [1, 2, 3],
             "remote_engine_id": "remote",
             "remote_host": "localhost",
-            "remote_port": 5000
+            "remote_port": 5000,
+            "remote_pcp_size": 1,
+            "remote_dcp_size": 1
         }

         meta = self.scheduler.build_connector_meta(MagicMock())
@@ -1018,6 +1026,12 @@ class TestMooncakeConnectorWorker(unittest.TestCase):
         self.mock_transfer_engine.get_rpc_port.return_value = 9090
         self.mock_transfer_engine.initialize.return_value = 0
         self.mock_transfer_engine.register_memory.return_value = 0
+        self.mock_dcp_group = MagicMock(spec=GroupCoordinator)
+        self.mock_dcp_group.rank_in_group = 0
+        self.mock_dcp_group.world_size = 1
+        self.mock_dcp_group.device_group = MagicMock()
+        self.mock_dcp = MagicMock()
+        self.mock_dcp.world_size = 1

         self.patches = [
             patch(
@@ -1051,6 +1065,13 @@ class TestMooncakeConnectorWorker(unittest.TestCase):
                   MagicMock()),
             patch('vllm_ascend.distributed.mooncake_connector.threading.Event',
                   MagicMock()),
+            patch('vllm.distributed.parallel_state.get_dcp_group',
+                  return_value=self.mock_dcp_group),
+            patch('vllm.distributed.parallel_state._DCP',
+                  return_value=self.mock_dcp),
+            patch(
+                'vllm.distributed.get_decode_context_model_parallel_world_size',
+                return_value=1)
         ]

         for p in self.patches:
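The test changes above assume that `kv_transfer_params` now carries `remote_pcp_size` and `remote_dcp_size` alongside the existing remote fields. A minimal, self-contained sketch of such a record (the `ReqMetaSketch` dataclass below is illustrative only, not the connector's real `ReqMeta`):

```python
from dataclasses import dataclass


@dataclass
class ReqMetaSketch:
    """Illustrative mirror of the connector's per-request metadata."""
    local_block_ids: list[int]
    remote_block_ids: list[int]
    remote_host: str
    remote_port: int
    remote_engine_id: str
    remote_pcp_size: int  # prefill context parallel world size on the P side
    remote_dcp_size: int  # decode context parallel world size on the P side


kv_transfer_params = {
    "remote_block_ids": [4, 5, 6],
    "remote_engine_id": "remote_engine",
    "remote_host": "localhost",
    "remote_port": 5000,
    "remote_pcp_size": 1,
    "remote_dcp_size": 1,
}

meta = ReqMetaSketch(local_block_ids=[1, 2, 3], **kv_transfer_params)
print(meta)
```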
diff --git a/vllm_ascend/distributed/mooncake_connector.py b/vllm_ascend/distributed/mooncake_connector.py
index 7951760d..3ca17a59 100644
--- a/vllm_ascend/distributed/mooncake_connector.py
+++ b/vllm_ascend/distributed/mooncake_connector.py
@@ -25,8 +25,10 @@ from vllm import envs
 from vllm.config import VllmConfig
 from vllm.distributed.kv_transfer.kv_connector.v1.base import (
     KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole)
-from vllm.distributed.parallel_state import (get_tensor_model_parallel_rank,
-                                             get_tp_group)
+from vllm.distributed.parallel_state import (
+    get_decode_context_model_parallel_rank,
+    get_decode_context_model_parallel_world_size,
+    get_tensor_model_parallel_rank, get_tp_group)
 from vllm.utils import logger
 from vllm.v1.core.sched.output import SchedulerOutput
 from vllm.v1.request import RequestStatus
@@ -35,7 +37,14 @@ import vllm_ascend.envs as envs_ascend
 from vllm_ascend.ascend_config import get_ascend_config, init_ascend_config
 from vllm_ascend.distributed.mooncake.transfer_engine import get_global_te
 from vllm_ascend.distributed.utils import get_transfer_timeout_value
-from vllm_ascend.utils import vllm_version_is
+from vllm_ascend.utils import prefill_context_parallel_enable, vllm_version_is
+
+# isort: off
+if prefill_context_parallel_enable():
+    from vllm.distributed import (get_prefill_context_model_parallel_rank,
+                                  get_prefill_context_model_parallel_world_size
+                                  )
+# isort: on

 if vllm_version_is("0.11.0"):
     from vllm.utils import get_ip, make_zmq_path, make_zmq_socket
@@ -66,6 +75,8 @@ class ReqMeta:
     remote_host: str
     remote_port: int
     remote_engine_id: str
+    remote_pcp_size: int
+    remote_dcp_size: int


 class KVCacheTaskTracker:
@@ -142,7 +153,7 @@ class KVCacheSendingThread(threading.Thread):
     def __init__(self, tp_rank: int, decode_tp_size: int, local_engine_id: str,
                  side_channel_host: str, side_channel_port: int,
                  metadata: MooncakeAgentMetadata, ready_event: threading.Event,
-                 kv_caches: dict[str, Any]):
+                 kv_caches: dict[str, Any], pcp_rank: int):
         super().__init__(daemon=True, name="KVCacheSendingThread")
         self.tp_rank = tp_rank
         self.decode_tp_size = decode_tp_size
@@ -152,6 +163,7 @@ class KVCacheSendingThread(threading.Thread):
         self.metadata = metadata
         self.ready_event = ready_event
         self.kv_caches = kv_caches
+        self.pcp_rank = pcp_rank

         self.task_tracker = KVCacheTaskTracker()

@@ -183,7 +195,8 @@ class KVCacheSendingThread(threading.Thread):
         # NOTE(rob): we need each rank to have a unique port. This hack to keeps
         # us moving. We will switch when moving to etcd or where we have a
         # single ZMQ socket in the scheduler.
-        handshake_port = self.side_channel_port + self.tp_rank
+        handshake_port = self.side_channel_port + self.pcp_rank * self.decode_tp_size \
+            + self.tp_rank
         path = make_zmq_path("tcp", self.side_channel_host, handshake_port)
         logger.info("Starting listening on path: %s", path)
         with zmq_ctx(zmq.ROUTER, path) as sock:  # type: ignore
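The sending thread now folds the prefill context-parallel rank into its listening port. A toy rendering of that expression, assuming `base_port`, `decode_tp_size`, and the rank values are arbitrary example numbers; it only shows that each `(pcp_rank, tp_rank)` pair lands on a distinct port:

```python
def sender_handshake_port(base_port: int, pcp_rank: int,
                          decode_tp_size: int, tp_rank: int) -> int:
    # Mirrors: side_channel_port + pcp_rank * decode_tp_size + tp_rank
    return base_port + pcp_rank * decode_tp_size + tp_rank


ports = {(p, t): sender_handshake_port(30000, p, 4, t)
         for p in range(2) for t in range(4)}
assert len(set(ports.values())) == len(ports)  # every rank pair gets its own port
print(ports)
```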
@@ -616,6 +629,8 @@ class MooncakeConnectorMetadata(KVConnectorMetadata):
             remote_engine_id=kv_transfer_params["remote_engine_id"],
             remote_host=kv_transfer_params["remote_host"],
             remote_port=kv_transfer_params["remote_port"],
+            remote_pcp_size=kv_transfer_params["remote_pcp_size"],
+            remote_dcp_size=kv_transfer_params["remote_dcp_size"],
         )


@@ -713,14 +728,18 @@ class MooncakeConnectorScheduler:
         logger.info("Initializing Mooncake Scheduler %s", engine_id)

         self.side_channel_host = get_ip()
+        self.pcp_size = vllm_config.parallel_config.prefill_context_parallel_size \
+            if prefill_context_parallel_enable() else 1
+        self.dcp_size = vllm_config.parallel_config.decode_context_parallel_size
         self.max_device_id = vllm_config.parallel_config.tensor_parallel_size * \
-            vllm_config.parallel_config.data_parallel_size
+            vllm_config.parallel_config.data_parallel_size * \
+            self.pcp_size

         # Handshake base port
         self.side_channel_port = (
             vllm_config.kv_transfer_config.kv_port +
             vllm_config.parallel_config.data_parallel_rank *
-            vllm_config.parallel_config.tensor_parallel_size)
+            vllm_config.parallel_config.tensor_parallel_size * self.pcp_size)

         # Requests that need to start recv.
         # New requests are added by update_state_after_alloc in
@@ -848,6 +867,8 @@ class MooncakeConnectorScheduler:
             remote_engine_id=self.engine_id,
             remote_host=self.side_channel_host,
             remote_port=self.side_channel_port,
+            remote_pcp_size=self.pcp_size,
+            remote_dcp_size=self.dcp_size,
             last_token_id=request.output_token_ids[-1],
         )

@@ -875,7 +896,15 @@ class MooncakeConnectorWorker:
         self.dp_size = vllm_config.parallel_config.data_parallel_size_local
         self.kv_caches: dict[str, torch.Tensor] = {}
         self.side_channel_host = get_ip()
-        self.max_device_id = self.tp_size * self.dp_size
+        self.pcp_size = get_prefill_context_model_parallel_world_size(
+        ) if prefill_context_parallel_enable() else 1
+        self.pcp_rank = get_prefill_context_model_parallel_rank(
+        ) if self.pcp_size > 1 else 0
+        self.dcp_size = get_decode_context_model_parallel_world_size()
+        self.dcp_rank = get_decode_context_model_parallel_rank(
+        ) if self.dcp_size > 1 else 0
+
+        self.max_device_id = self.tp_size * self.dp_size * self.pcp_size
         self.kv_role = vllm_config.kv_transfer_config.kv_role

         self.num_key_value_heads = self.vllm_config.model_config.hf_config.num_key_value_heads
@@ -883,8 +912,8 @@ class MooncakeConnectorWorker:
         self.side_channel_port = (
             vllm_config.kv_transfer_config.kv_port +
             vllm_config.parallel_config.data_parallel_rank *
-            vllm_config.parallel_config.tensor_parallel_size)
-        self.handshake_port = self.side_channel_port + self.tp_rank
+            vllm_config.parallel_config.tensor_parallel_size * self.pcp_size)
+        self.handshake_port = self.side_channel_port + self.pcp_rank * self.tp_size + self.tp_rank
         self.sockets: dict = {}

         # get tp device id
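Both the scheduler and the worker now reserve `tp_size * pcp_size` consecutive handshake ports per data-parallel rank instead of `tp_size`. A rough sketch of that numbering, with invented port and size values:

```python
def side_channel_base(kv_port: int, dp_rank: int, tp_size: int,
                      pcp_size: int) -> int:
    # Each DP rank owns a contiguous block of tp_size * pcp_size ports.
    return kv_port + dp_rank * tp_size * pcp_size


def worker_handshake_port(base: int, pcp_rank: int, tp_size: int,
                          tp_rank: int) -> int:
    # Within the block, ports are laid out pcp-major, tp-minor.
    return base + pcp_rank * tp_size + tp_rank


base = side_channel_base(kv_port=20000, dp_rank=1, tp_size=4, pcp_size=2)
assert base == 20008
assert worker_handshake_port(base, pcp_rank=1, tp_size=4, tp_rank=3) == 20015
```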
@@ -893,20 +922,23 @@ class MooncakeConnectorWorker:
         device_ids_str = envs_ascend.PHYSICAL_DEVICES
         if device_ids_str is None:
             device_ids = list(
-                range(self.dp_rank * self.tp_size,
-                      (self.dp_rank + 1) * self.tp_size))
+                range(self.dp_rank * self.tp_size * self.pcp_size,
+                      (self.dp_rank + 1) * self.tp_size * self.pcp_size))
         else:
             device_ids = list(map(int, device_ids_str.split(',')))
-            start_index = self.dp_rank * self.tp_size
-            end_index = start_index + self.tp_size
+            start_index = self.dp_rank * self.tp_size * self.pcp_size
+            end_index = start_index + self.tp_size * self.pcp_size
             if len(device_ids) < end_index:
                 raise ValueError(
                     f"Not enough physical devices available for DP rank {self.dp_rank}. "
                     f"Expected at least {end_index} devices, but found {len(device_ids)} "
                     "in PHYSICAL_DEVICES.")
             device_ids = device_ids[start_index:end_index]
-        assert len(device_ids) > self.tp_rank  # type: ignore
-        self.device_id = device_ids[self.tp_rank]  # type: ignore
+        assert len(
+            device_ids
+        ) > self.pcp_rank * self.tp_size + self.tp_rank  # type: ignore
+        self.device_id = device_ids[self.pcp_rank * self.tp_size +
+                                    self.tp_rank]  # type: ignore

         if vllm_config.kv_transfer_config.get_from_extra_config(
                 'use_ascend_direct', True):
@@ -1061,7 +1093,7 @@ class MooncakeConnectorWorker:
             self.kv_send_thread = KVCacheSendingThread(
                 self.tp_rank, self._decode_tp_size, self.engine_id,
                 self.side_channel_host, self.side_channel_port, metadata,
-                ready_event, self.kv_caches)
+                ready_event, self.kv_caches, self.pcp_rank)
             self.kv_send_thread.start()
         else:
             self.kv_recv_thread = KVCacheRecvingThread(
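Device selection follows the same layout: each DP rank owns a `tp_size * pcp_size` slice of `PHYSICAL_DEVICES`, and a worker indexes into it with `pcp_rank * tp_size + tp_rank`. A minimal sketch with made-up device IDs:

```python
def pick_device(device_ids: list[int], dp_rank: int, tp_size: int,
                pcp_size: int, pcp_rank: int, tp_rank: int) -> int:
    start = dp_rank * tp_size * pcp_size
    end = start + tp_size * pcp_size
    if len(device_ids) < end:
        raise ValueError("not enough physical devices for this DP rank")
    local = device_ids[start:end]
    return local[pcp_rank * tp_size + tp_rank]


# 2 DP ranks x 2 PCP ranks x 2 TP ranks over 8 devices
devices = [0, 1, 2, 3, 4, 5, 6, 7]
assert pick_device(devices, dp_rank=1, tp_size=2, pcp_size=2,
                   pcp_rank=1, tp_rank=0) == 6
```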
@@ -1094,6 +1126,92 @@ class MooncakeConnectorWorker:
             "requests: %d", len(done_sending), len(done_recving))
         return done_sending, done_recving

+    def _get_kv_split_metadata(
+        self,
+        req_id: str,
+        meta: ReqMeta,
+    ) -> tuple[list[list[int]], list[list[int]], list[list[int]]]:
+        """
+        In cp/dcp scenario, kv_cache may be split, so we need to pull multiple blocks from multiple remote P node.
+        Use this function to calculate remote port and remote block number of each remote P node that we need to pull.
+        """
+        if meta.remote_pcp_size * meta.remote_dcp_size * self.pcp_size * self.dcp_size == 1:
+            choosen_rank_list = self._get_remote_tp_rank(req_id)
+            remote_handshake_port_list = [[
+                x + meta.remote_port for x in choosen_rank_list
+            ]]
+            local_block_ids_list, remote_block_ids_list = [
+                meta.local_block_ids
+            ], [meta.remote_block_ids]
+            return remote_handshake_port_list, local_block_ids_list, remote_block_ids_list
+
+        if self.pcp_size == meta.remote_pcp_size and self.dcp_size == meta.remote_dcp_size:
+            # remote & local cp/dcp are equal, do kv transfer point-to-point
+            remote_kv_num = 1
+            remote_ports = [meta.remote_port + self.pcp_rank * self.tp_size + tp_offset \
+                for tp_offset in range(self.tp_rank, int(self._prefill_tp_size), self.tp_size)]
+            remote_block_nums = [len(meta.remote_block_ids)]
+        else:
+            assert self.pcp_size == 1
+            if self.use_mla:
+                assert (self.dcp_size == 1 and (self.tp_size == 1 or self.tp_size == self._prefill_tp_size)) or \
+                    (self.dcp_size == meta.remote_dcp_size and self.tp_size == self._prefill_tp_size)
+            else:
+                assert self.tp_size == self._prefill_tp_size and (
+                    self.dcp_size == 1
+                    or self.dcp_size == meta.remote_dcp_size)
+            # remote & local cp/dcp are not equal, each D node needs to pull from pcp(*dcp) P nodes
+            # 1. for mla, support D pcp_size = 1, D dcp_size = (1 or P dcp_size)
+            # 2. for gqa, support D tp_size = P tp_size, D dcp_size = P dcp_size
+            remote_dcp_size = meta.remote_dcp_size // self.dcp_size
+            remote_kv_num = meta.remote_pcp_size * remote_dcp_size
+            cp_dcp_offsets = []
+            for cp_idx in range(meta.remote_pcp_size):
+                cp_offset = cp_idx * self._prefill_tp_size
+                cp_dcp_offsets += list(
+                    range(cp_offset, cp_offset + remote_dcp_size))
+            tp_offset = self.tp_rank // remote_dcp_size * remote_dcp_size
+            remote_ports = [meta.remote_port + cp_dcp_offset + tp_offset \
+                for cp_dcp_offset in cp_dcp_offsets]
+            # recompute cp/dcp block assign here, maybe we can also pass it from P node meta
+            local_block_num = len(meta.local_block_ids)
+            remote_block_nums = [
+                local_block_num // (meta.remote_pcp_size * remote_dcp_size)
+            ] * meta.remote_pcp_size * remote_dcp_size
+            num_remain_blocks = local_block_num % (meta.remote_pcp_size *
+                                                   remote_dcp_size)
+            for i in range(num_remain_blocks):
+                remote_block_nums[i] += 1
+            # make sure the last block (which may be unfull) of P nodes is put to the last block of D node
+            remote_ports = remote_ports[
+                num_remain_blocks:] + remote_ports[:num_remain_blocks]
+            remote_block_nums = remote_block_nums[
+                num_remain_blocks:] + remote_block_nums[:num_remain_blocks]
+
+        remote_handshake_port_list = []
+        for remote_kv_id in range(remote_kv_num):
+            remote_handshake_port_list.append([remote_ports[remote_kv_id]])
+
+        # the local_block_ids_list and remote_block_ids_list are related with remote_handshake_port_list
+        # such as: local_block_ids_list[[1],[2],[5],[6]], remote_block_ids_list[[1],[1],[1],[1]],
+        # remote_handshake_port_list[[30000],[30001],[30004],[30005]]
+        # D rank will get remote block 1 in port 30004 and save it in local block 5
+        local_block_ids_list = []
+        remote_block_ids_list = []
+        local_block_offset = 0
+        for remote_kv_id in range(len(remote_handshake_port_list)):
+            num_blocks_to_pull = remote_block_nums[remote_kv_id]
+            remote_block_ids_list.append(
+                meta.remote_block_ids[:num_blocks_to_pull])
+            local_block_ids_list.append(
+                meta.local_block_ids[local_block_offset:local_block_offset +
+                                     num_blocks_to_pull])
+            local_block_offset += num_blocks_to_pull
+        assert local_block_offset == len(meta.local_block_ids), \
+            f"local_block_offset ({local_block_offset}) should equal with local_block_ids len ({len(meta.local_block_ids)})"
+
+        return remote_handshake_port_list, local_block_ids_list, remote_block_ids_list
+
     def start_load_kv(self, metadata: MooncakeConnectorMetadata):
         """Start loading KV blocks from remote engine."""
         for req_id, meta in metadata.requests.items():
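The core of `_get_kv_split_metadata` is the arithmetic that spreads the decode side's block list over the remote prefill shards and rotates the assignment so the final (possibly unfull) block stays last. The standalone sketch below isolates just that splitting step; the shard count and block IDs are arbitrary:

```python
def split_blocks(block_ids: list[int], num_shards: int) -> list[list[int]]:
    """Split block_ids into num_shards contiguous runs whose sizes differ by
    at most one, then rotate the run sizes the same way the connector rotates
    its remote ports, so the shorter runs come first in pull order."""
    counts = [len(block_ids) // num_shards] * num_shards
    remainder = len(block_ids) % num_shards
    for i in range(remainder):
        counts[i] += 1  # first `remainder` shards get one extra block
    counts = counts[remainder:] + counts[:remainder]  # rotate, as in the diff
    runs, offset = [], 0
    for n in counts:
        runs.append(block_ids[offset:offset + n])
        offset += n
    return runs


# 7 local blocks over 4 remote shards -> run sizes [1, 2, 2, 2]
print(split_blocks(list(range(7)), 4))  # [[0], [1, 2], [3, 4], [5, 6]]
```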
@@ -1103,21 +1221,25 @@ class MooncakeConnectorWorker:
                 meta.remote_engine_id, len(meta.local_block_ids),
                 len(meta.remote_block_ids))

-            choosen_rank_list = self._get_remote_tp_rank(req_id)
-            remote_handshake_port_list = [
-                x + meta.remote_port for x in choosen_rank_list
-            ]
-            for i in range(self.num_need_pulls):
-                assert self.kv_recv_thread is not None
-                self.kv_recv_thread.add_request(
-                    request_id=req_id,
-                    local_block_ids=meta.local_block_ids,
-                    remote_block_ids=meta.remote_block_ids,
-                    remote_engine_id=meta.remote_engine_id,
-                    remote_host=meta.remote_host,
-                    remote_handshake_port=remote_handshake_port_list[i],
-                    offset=i,
-                    num_need_pulls=self.num_need_pulls)
+            remote_handshake_port_list, local_block_ids_list, remote_block_ids_list = self._get_kv_split_metadata(
+                req_id, meta)
+
+            for pcp_dcp_rank in range(len(remote_handshake_port_list)):
+                if len(local_block_ids_list[pcp_dcp_rank]) + len(
+                        remote_block_ids_list[pcp_dcp_rank]) == 0:
+                    continue
+                for i in range(self.num_need_pulls):
+                    assert self.kv_recv_thread is not None
+                    self.kv_recv_thread.add_request(
+                        request_id=req_id,
+                        local_block_ids=local_block_ids_list[pcp_dcp_rank],
+                        remote_block_ids=remote_block_ids_list[pcp_dcp_rank],
+                        remote_engine_id=meta.remote_engine_id,
+                        remote_host=meta.remote_host,
+                        remote_handshake_port=remote_handshake_port_list[
+                            pcp_dcp_rank][i],
+                        offset=i,
+                        num_need_pulls=self.num_need_pulls)

         if self.kv_send_thread is not None:
             for req_id, delay_start_time in metadata.requests_to_send.items():
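When local and remote cp/dcp sizes differ, each decode rank pulls from several prefill ranks; the port offsets are built per prefill CP group and then shifted by a TP-aligned offset, mirroring the `cp_dcp_offsets` / `tp_offset` computation in `_get_kv_split_metadata`. A toy rendering with invented sizes (add the result to `remote_port` to get actual handshake ports):

```python
def remote_port_offsets(remote_pcp_size: int, remote_dcp_size: int,
                        local_dcp_size: int, prefill_tp_size: int,
                        tp_rank: int) -> list[int]:
    # Mirrors the cp_dcp_offsets / tp_offset arithmetic in the new method.
    per_rank_dcp = remote_dcp_size // local_dcp_size
    offsets = []
    for cp_idx in range(remote_pcp_size):
        cp_base = cp_idx * prefill_tp_size
        offsets += list(range(cp_base, cp_base + per_rank_dcp))
    tp_offset = tp_rank // per_rank_dcp * per_rank_dcp
    return [o + tp_offset for o in offsets]


# 2 prefill CP ranks, prefill dcp=2, decode dcp=1, prefill tp=4, decode tp_rank=2
print(remote_port_offsets(2, 2, 1, 4, 2))  # -> [2, 3, 6, 7]
```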