diff --git a/tests/ut/kv_connector/test_mooncake_connector.py b/tests/ut/kv_connector/test_mooncake_connector.py
index 48271269..a0ef02f1 100644
--- a/tests/ut/kv_connector/test_mooncake_connector.py
+++ b/tests/ut/kv_connector/test_mooncake_connector.py
@@ -241,6 +241,7 @@ class TestKVCacheRecvingThreadBasic(unittest.TestCase):
             engine=self.engine,
             local_engine_id="local_engine",
             local_handshake_port=5555,
+            side_channel_port=30000,
             local_kv_caches_base_addr=[0x1000, 0x2000],
             block_len=[1024, 2048],
             ready_event=self.ready_event,
@@ -294,6 +295,7 @@ class TestSocketManagement(unittest.TestCase):
             engine=self.engine,
             local_engine_id="local_engine",
             local_handshake_port=5555,
+            side_channel_port=30000,
             local_kv_caches_base_addr=[0x1000, 0x2000],
             block_len=[1024, 2048],
             ready_event=self.ready_event,
@@ -351,6 +353,7 @@ class TestCoreFunctionality(unittest.TestCase):
             engine=self.engine,
             local_engine_id="local_engine",
             local_handshake_port=5555,
+            side_channel_port=30000,
             local_kv_caches_base_addr=[0x1000, 0x2000],
             block_len=[1024, 2048],
             ready_event=self.ready_event,
@@ -367,6 +370,9 @@ class TestCoreFunctionality(unittest.TestCase):
             "remote_transfer_port": 7777,
             "offset": 0,
             "tp_num_need_pulls": 2,
+            "remote_port_send_num": {
+                6666: 1
+            },
             "all_task_done": False
         }
         self.thread.task_tracker = MagicMock()
@@ -382,7 +388,7 @@ class TestCoreFunctionality(unittest.TestCase):
         self.thread._handle_request(self.test_req)

         mock_transfer.assert_called_once_with(self.test_req)
-        mock_send.assert_called_once_with("req1", "localhost", 6666)
+        mock_send.assert_called_once_with("req1", "localhost", 6666, {6666: 1})
         if not self.thread.task_tracker.update_done_task_count.called:
             self.thread.task_tracker.update_done_task_count("req1")
         self.thread.task_tracker.update_done_task_count.assert_called_once_with(
@@ -433,6 +439,7 @@ class TestMetadataHandling(unittest.TestCase):
             engine=self.engine,
             local_engine_id="local_engine",
             local_handshake_port=5555,
+            side_channel_port=30000,
             local_kv_caches_base_addr=[0x1000, 0x2000],
             block_len=[1024, 2048],
             ready_event=self.ready_event,
@@ -497,6 +504,7 @@ class TestMainThreadLoop(unittest.TestCase):
             engine=self.engine,
             local_engine_id="local_engine",
             local_handshake_port=5555,
+            side_channel_port=30000,
             local_kv_caches_base_addr=[0x1000, 0x2000],
             block_len=[1024, 2048],
             ready_event=self.ready_event,
@@ -654,6 +662,7 @@ class TestMooncakeConnectorMetadata(unittest.TestCase):
         meta.add_new_req(request_id="req1",
                          local_block_ids=[1, 2, 3],
+                         num_external_tokens=48,
                          kv_transfer_params={
                              "remote_block_ids": [4, 5, 6],
                              "remote_engine_id": "remote_engine",
@@ -706,7 +715,7 @@ class TestMooncakeConnectorSchedulerMatchedTokens(unittest.TestCase):
         request = MockRequest("req1")
         blocks_mock = MagicMock()
         blocks_mock.get_unhashed_block_ids.return_value = [4, 5, 6]
-        self.scheduler._reqs_need_recv["req1"] = (request, [4, 5, 6])
+        self.scheduler._reqs_need_recv["req1"] = (request, [4, 5, 6], 48)
         request.kv_transfer_params = {
             "remote_block_ids": [1, 2, 3],
             "remote_engine_id": "remote",
@@ -1278,6 +1287,8 @@ class TestMooncakeConnectorWorker(unittest.TestCase):
             worker.pcp_rank = pcp_rank
             worker._prefill_tp_size = _prefill_tp_size
             worker.local_remote_block_port_mapping = None
+            worker.block_size = 16
+            worker.num_key_value_heads = 1

             meta = types.SimpleNamespace()
@@ -1286,6 +1297,9 @@
             meta.remote_port = remote_port
             meta.remote_block_ids = remote_block_ids
             meta.local_block_ids = local_block_ids
+            meta.num_external_tokens = pcp_size * dcp_size * len(
+                local_block_ids) * worker.block_size
+            meta.num_prompt_blocks = pcp_size * dcp_size * len(local_block_ids)
             remote_handshake_port_list, local_block_ids_list, remote_block_ids_list = worker._get_kv_split_metadata(
                 '0', meta)
@@ -1324,17 +1338,17 @@
         self.assertEqual(
             get_kv_split_metadata(True, 1, 2, 8, 1, 0, 8, 2, 2, 30000, [1],
                                   [1]),
-            ([[30009], [30001]], [[], [1]], [[], [1]]))
+            ([[30000], [30008]], [[1], []], [[1], []]))

         self.assertEqual(
             get_kv_split_metadata(False, 1, 2, 8, 1, 0, 8, 2, 2, 30000, [1],
                                   [1]),
-            ([[30009], [30001]], [[], [1]], [[], [1]]))
+            ([[30000], [30008]], [[1], []], [[1], []]))

         self.assertEqual(
             get_kv_split_metadata(True, 1, 2, 8, 0, 0, 8, 2, 2, 30000,
                                   [1, 2, 3], [1, 2, 3, 4, 5]),
-            ([[30008], [30000]], [[1, 2], [3, 4, 5]], [[1, 2], [1, 2, 3]]))
+            ([[30000], [30008]], [[1, 2, 3], [4, 5]], [[1, 2, 3], [1, 2]]))


 if __name__ == '__main__':
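Note on the revised `get_kv_split_metadata` expectations above: the rewrite enumerates remote handshake ports upward from the base port (30000, 30008, ...) instead of downward, and spreads prompt blocks across the prefill CP ranks so that the leading ranks absorb the remainder. A minimal sketch of the distribution rule, with a helper name of our own (not part of the patch):

def split_prompt_blocks(num_prompt_blocks: int,
                        remote_cp_size: int) -> list[int]:
    # Every remote CP rank gets floor(n / cp) blocks; the first
    # (n % cp) ranks take one extra block each.
    nums = [num_prompt_blocks // remote_cp_size] * remote_cp_size
    for i in range(num_prompt_blocks % remote_cp_size):
        nums[i] += 1
    return nums

# 5 prompt blocks over 2 prefill CP ranks -> [3, 2]: the first rank serves
# three blocks and the second two, matching the updated third assertion.
print(split_prompt_blocks(5, 2))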
diff --git a/vllm_ascend/distributed/mooncake_connector.py b/vllm_ascend/distributed/mooncake_connector.py
index 8f269610..38284335 100644
--- a/vllm_ascend/distributed/mooncake_connector.py
+++ b/vllm_ascend/distributed/mooncake_connector.py
@@ -1,5 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0
 import contextlib
+import copy
 import hashlib
 import math
 import os
@@ -42,7 +43,7 @@ from vllm.v1.request import RequestStatus
 from vllm_ascend.ascend_config import get_ascend_config, init_ascend_config
 from vllm_ascend.distributed.mooncake_transfer_engine import global_te
 from vllm_ascend.distributed.utils import get_transfer_timeout_value
-from vllm_ascend.utils import is_vl_model, prefill_context_parallel_enable
+from vllm_ascend.utils import is_vl_model

 if TYPE_CHECKING:
     from vllm.attention.backends.abstract import AttentionMetadata
@@ -65,6 +66,7 @@ class MooncakeAgentMetadata(msgspec.Struct, omit_defaults=True, dict=True):
 @dataclass
 class ReqMeta:
     local_block_ids: list[int]
+    num_external_tokens: int
     remote_block_ids: list[int]
     remote_host: str
     remote_port: int
@@ -72,6 +74,7 @@ class ReqMeta:
     remote_pcp_size: int
     remote_dcp_size: int
     remote_multi_nodes_meta_mapping: dict[str, dict[str, Any]]
+    num_prompt_blocks: int


 @dataclass
@@ -184,6 +187,7 @@ class KVCacheSendingThread(threading.Thread):
         self.ready_event = ready_event
         self.kv_caches = kv_caches
         self.pcp_rank = pcp_rank
+        self.port_send_num: dict[str, int] = {}

         self.task_tracker = KVCacheTaskTracker()
@@ -248,7 +252,22 @@
             elif msg[0] == DONE_RECVING_MSG:
                 logger.debug("Got DONE_RECVING_MSG for request %s", msg[1])
                 request_id = msg[1]
-                self.task_tracker.update_done_task_count(request_id)
+                remote_port_send_num = msg[2]
+                if remote_port_send_num:
+                    if request_id not in self.port_send_num:
+                        self.port_send_num[request_id] = 0
+                    self.port_send_num[request_id] += 1
+                    device_index = self.pp_rank * self.tp_size + \
+                        self.tp_rank + self.pcp_rank * \
+                        self.prefill_tp_size
+                    handshake_port = self.side_channel_port + device_index
+                    if self.port_send_num[request_id] >= \
+                            remote_port_send_num[handshake_port]:
+                        self.task_tracker.update_done_task_count(
+                            request_id)
+                        del self.port_send_num[request_id]
+                else:
+                    self.task_tracker.update_done_task_count(request_id)

                 # Acknowledge the request completion.
                 while True:
                     try:
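The DONE_RECVING branch above now counts acknowledgements per request instead of finishing on the first one: the decode side announces, per prefill handshake port, how many pulls it scheduled against that port, and the sending thread only reports the request done once it has seen that many messages for its own port. A condensed, standalone sketch of the rule (names assumed for illustration):

def on_done_recving(done_counts: dict[str, int], request_id: str,
                    my_handshake_port: int,
                    remote_port_send_num: dict[int, int]) -> bool:
    """Return True once this prefill rank may mark the request done."""
    if not remote_port_send_num:
        # Peer sent no per-port counts: a single ack completes the request.
        return True
    done_counts[request_id] = done_counts.get(request_id, 0) + 1
    if done_counts[request_id] >= remote_port_send_num[my_handshake_port]:
        del done_counts[request_id]
        return True
    return False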
@@ -275,7 +294,7 @@ class KVCacheRecvingThread(threading.Thread):

     def __init__(self, tp_rank: int, tp_size: int, _prefill_pp_size: int,
                  engine: TransferEngine, local_engine_id: str,
-                 local_handshake_port: int,
+                 local_handshake_port: int, side_channel_port: int,
                  local_kv_caches_base_addr: list[int], block_len: list[int],
                  ready_event: threading.Event, vllm_config: VllmConfig,
                  kv_caches: dict[str, Any]):
@@ -285,6 +304,7 @@
         self._prefill_pp_size = _prefill_pp_size
         self.local_engine_id = local_engine_id
         self.local_handshake_port = local_handshake_port
+        self.side_channel_port = side_channel_port
         self.engine = engine
         self.ready_event = ready_event
@@ -328,11 +348,19 @@
         self.num_kv_heads = max(
             self.model_config.hf_text_config.num_key_value_heads //
             self.tp_size, 1)
+        self.proc_not_transfer_request: dict[str, bool] = {}

-    def add_request(self, request_id: str, local_block_ids: list[int],
-                    remote_block_ids: list[int], remote_engine_id: str,
-                    remote_host: str, remote_handshake_port: int, offset: int,
-                    tp_num_need_pulls: int, all_task_done: bool):
+    def add_request(self,
+                    request_id: str,
+                    local_block_ids: list[int],
+                    remote_block_ids: list[int],
+                    remote_engine_id: str,
+                    remote_host: str,
+                    remote_handshake_port: int,
+                    offset: int,
+                    tp_num_need_pulls: int,
+                    remote_port_send_num: dict[int, int] = {},
+                    all_task_done: bool = False):
         """Add a new request to the queue for processing."""
         logger.debug(f"Adding request {request_id} to the queue.")
         self.request_queue.put({
@@ -344,6 +372,7 @@
             "remote_handshake_port": remote_handshake_port,
             "offset": offset,
             "tp_num_need_pulls": tp_num_need_pulls,
+            "remote_port_send_num": remote_port_send_num,
             "all_task_done": all_task_done
         })
@@ -373,6 +402,7 @@
         request_id = req_meta["request_id"]
         remote_host = req_meta["remote_host"]
         remote_handshake_port = req_meta["remote_handshake_port"]
+        remote_port_send_num = req_meta["remote_port_send_num"]
         all_task_done = req_meta["all_task_done"]

         try:
@@ -389,11 +419,31 @@
             # resource cleanup. Failing to do so may cause a memory leak on the
             # remote host.
             self._send_done_recv_signal(request_id, remote_host,
-                                        remote_handshake_port)
+                                        remote_handshake_port,
+                                        remote_port_send_num)
+            self._send_done_signal_to_free_remote_port(request_id, remote_host,
+                                                       remote_port_send_num)
             if all_task_done:
                 self.task_tracker.update_done_task_count(request_id)
+                if request_id in self.proc_not_transfer_request:
+                    del self.proc_not_transfer_request[request_id]
         self.request_queue.task_done()

+    def _send_done_signal_to_free_remote_port(self, request_id, remote_host,
+                                              remote_port_send_num):
+        if self.side_channel_port != self.local_handshake_port \
+                or not remote_port_send_num:
+            return
+        if request_id not in self.proc_not_transfer_request:
+            self.proc_not_transfer_request[request_id] = True
+        if self.proc_not_transfer_request[request_id]:
+            for remote_port in remote_port_send_num.keys():
+                if remote_port_send_num[remote_port] == 0:
+                    self._send_done_recv_signal(request_id, remote_host,
+                                                remote_port,
+                                                remote_port_send_num)
+            self.proc_not_transfer_request[request_id] = False
+
     def _transfer_kv_cache(self, req_meta: dict[str, Any]):
         """Handle a KV cache transfer request."""
         request_id = req_meta["request_id"]
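`_send_done_signal_to_free_remote_port` closes a hole the per-port counting would otherwise open: a prefill port that no decode rank pulls from would wait forever for acknowledgements that can never arrive. Exactly one receiving thread per decode worker (the one whose `local_handshake_port` equals the `side_channel_port`) sends an explicit done signal to every such port, at most once per request. The ports it loops over are simply those with a planned send count of zero (hypothetical helper):

def zero_send_ports(remote_port_send_num: dict[int, int]) -> list[int]:
    # Ports the mapping never pulls from still need an explicit release.
    return [port for port, n in remote_port_send_num.items() if n == 0]

assert zero_send_ports({30000: 2, 30001: 0, 30002: 1}) == [30001]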
@@ -606,13 +656,15 @@
                                      remote_handshake_port)

     def _send_done_recv_signal(self, request_id: str, remote_host: str,
-                               remote_handshake_port: int):
+                               remote_handshake_port: int,
+                               remote_port_send_num: dict[int, int]):
         logger.debug("Sending done recving signal for request %s to %s:%d",
                      request_id, remote_host, remote_handshake_port)
         sock: Optional[zmq.Socket] = None  # type: ignore
         try:
             sock = self._get_remote_socket(remote_host, remote_handshake_port)
-            data_bytes = self.encoder.encode((DONE_RECVING_MSG, request_id))
+            data_bytes = self.encoder.encode(
+                (DONE_RECVING_MSG, request_id, remote_port_send_num))
             ensure_zmq_send(sock, data_bytes)
             resp = ensure_zmq_recv(sock,
                                    self.remote_poller,
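The hunk above also changes the wire format: the DONE_RECVING tuple gains the per-port send counts as a third element. Assuming the thread's encoder/decoder are msgspec's msgpack pair (and an illustrative constant value, not the connector's actual one), the round trip looks like:

import msgspec

DONE_RECVING_MSG = "done_recving"  # illustrative value only

encoder = msgspec.msgpack.Encoder()
decoder = msgspec.msgpack.Decoder()

data_bytes = encoder.encode((DONE_RECVING_MSG, "req1", {30000: 2, 30001: 0}))
msg = decoder.decode(data_bytes)
# msg[2] is the remote_port_send_num consumed by the sending thread above.
assert msg[1] == "req1" and msg[2] == {30000: 2, 30001: 0}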
This " @@ -902,7 +957,8 @@ class MooncakeConnectorScheduler: meta = MooncakeConnectorMetadata() # Loop through scheduled reqs and convert to ReqMeta. - for req_id, (req, block_ids) in self._reqs_need_recv.items(): + for req_id, (req, block_ids, + num_external_tokens) in self._reqs_need_recv.items(): assert req.kv_transfer_params is not None # For the case where there are no remote blocks to pull # (block_ids is empty), we don't need to schedule @@ -910,6 +966,7 @@ class MooncakeConnectorScheduler: meta.add_new_req( request_id=req_id, local_block_ids=block_ids, + num_external_tokens=num_external_tokens, kv_transfer_params=req.kv_transfer_params, ) @@ -946,6 +1003,9 @@ class MooncakeConnectorScheduler: len(computed_block_ids), request.request_id) self._reqs_need_send[request.request_id] = time.time() + num_prompt_blocks = math.ceil( + len(request.prompt_token_ids) / self.block_size) + return delay_free_blocks, dict( do_remote_prefill=True, do_remote_decode=False, @@ -957,6 +1017,7 @@ class MooncakeConnectorScheduler: remote_dcp_size=self.dcp_size, last_token_id=request.output_token_ids[-1], remote_multi_nodes_meta_mapping=self.multi_nodes_meta_mapping, + num_prompt_blocks=num_prompt_blocks, ) def set_xfer_handshake_metadata( @@ -1045,6 +1106,8 @@ class MooncakeConnectorWorker: num_p_block_heads = max( 1, self.num_key_value_heads // self._prefill_tp_size) self.tp_num_need_pulls = num_d_block_heads // num_p_block_heads + self.local_remote_block_port_mapping = None + self.remote_port_send_num: dict[int, int] = {} def _get_prefill_decode_size(self, vllm_config: VllmConfig): # get prefill tp and dp size from extra config @@ -1177,8 +1240,9 @@ class MooncakeConnectorWorker: else: self.kv_recv_thread = KVCacheRecvingThread( self.tp_rank, self.tp_size, self._prefill_pp_size, self.engine, - self.engine_id, self.handshake_port, kv_caches_base_addr, - self.block_len, ready_event, self.vllm_config, self.kv_caches) + self.engine_id, self.handshake_port, self.side_channel_port, + kv_caches_base_addr, self.block_len, ready_event, + self.vllm_config, self.kv_caches) self.kv_recv_thread.start() start_wait_time = time.time() @@ -1227,59 +1291,186 @@ class MooncakeConnectorWorker: ], [meta.remote_block_ids] return remote_handshake_port_list, local_block_ids_list, remote_block_ids_list - if self.pcp_size == meta.remote_pcp_size and self.dcp_size == meta.remote_dcp_size: - # remote & local cp/dcp are equal, do kv transfer point-to-point - remote_kv_num = 1 - remote_ports = [meta.remote_port + self.pcp_rank * self.tp_size + tp_offset \ - for tp_offset in range(self.tp_rank, int(self._prefill_tp_size), self.tp_size)] - remote_block_nums = [len(meta.remote_block_ids)] - else: - assert self.pcp_size == 1 - if self.use_mla: - assert (self.dcp_size == 1 and (self.tp_size == 1 or self.tp_size == self._prefill_tp_size)) or \ - (self.dcp_size == meta.remote_dcp_size and self.tp_size == self._prefill_tp_size) - else: - assert self.tp_size == self._prefill_tp_size and ( - self.dcp_size == 1 - or self.dcp_size == meta.remote_dcp_size) - # remote & local cp/dcp are not equal, each D node needs to pull from pcp(*dcp) P nodes - # 1. for mla, support D pcp_size = 1, D dcp_size = (1 or P dcp_size) - # 2. 
@@ -1045,6 +1106,8 @@
         num_p_block_heads = max(
             1, self.num_key_value_heads // self._prefill_tp_size)
         self.tp_num_need_pulls = num_d_block_heads // num_p_block_heads
+        self.local_remote_block_port_mapping = None
+        self.remote_port_send_num: dict[int, int] = {}

     def _get_prefill_decode_size(self, vllm_config: VllmConfig):
         # get prefill tp and dp size from extra config
@@ -1177,8 +1240,9 @@
         else:
             self.kv_recv_thread = KVCacheRecvingThread(
                 self.tp_rank, self.tp_size, self._prefill_pp_size, self.engine,
-                self.engine_id, self.handshake_port, kv_caches_base_addr,
-                self.block_len, ready_event, self.vllm_config, self.kv_caches)
+                self.engine_id, self.handshake_port, self.side_channel_port,
+                kv_caches_base_addr, self.block_len, ready_event,
+                self.vllm_config, self.kv_caches)
             self.kv_recv_thread.start()

         start_wait_time = time.time()
@@ -1227,59 +1291,186 @@
             ], [meta.remote_block_ids]
             return remote_handshake_port_list, local_block_ids_list, remote_block_ids_list

-        if self.pcp_size == meta.remote_pcp_size and self.dcp_size == meta.remote_dcp_size:
-            # remote & local cp/dcp are equal, do kv transfer point-to-point
-            remote_kv_num = 1
-            remote_ports = [meta.remote_port + self.pcp_rank * self.tp_size + tp_offset \
-                for tp_offset in range(self.tp_rank, int(self._prefill_tp_size), self.tp_size)]
-            remote_block_nums = [len(meta.remote_block_ids)]
-        else:
-            assert self.pcp_size == 1
-            if self.use_mla:
-                assert (self.dcp_size == 1 and (self.tp_size == 1 or self.tp_size == self._prefill_tp_size)) or \
-                    (self.dcp_size == meta.remote_dcp_size and self.tp_size == self._prefill_tp_size)
-            else:
-                assert self.tp_size == self._prefill_tp_size and (
-                    self.dcp_size == 1
-                    or self.dcp_size == meta.remote_dcp_size)
-            # remote & local cp/dcp are not equal, each D node needs to pull from pcp(*dcp) P nodes
-            # 1. for mla, support D pcp_size = 1, D dcp_size = (1 or P dcp_size)
-            # 2. for gqa, support D tp_size = P tp_size, D dcp_size = P dcp_size
-            remote_dcp_size = meta.remote_dcp_size // self.dcp_size
-            remote_kv_num = meta.remote_pcp_size * remote_dcp_size
-            cp_dcp_offsets = []
-            for cp_idx in range(meta.remote_pcp_size):
-                cp_offset = cp_idx * self._prefill_tp_size
-                cp_dcp_offsets += list(
-                    range(cp_offset, cp_offset + remote_dcp_size))
-            tp_offset = self.tp_rank // remote_dcp_size * remote_dcp_size
-            remote_ports = [meta.remote_port + cp_dcp_offset + tp_offset \
-                for cp_dcp_offset in cp_dcp_offsets]
-            # recompute cp/dcp block assign here, maybe we can also pass it from P node meta
-            local_block_num = len(meta.local_block_ids)
-            remote_block_nums = [
-                local_block_num // (meta.remote_pcp_size * remote_dcp_size)
-            ] * meta.remote_pcp_size * remote_dcp_size
-            num_remain_blocks = local_block_num % (meta.remote_pcp_size *
                                                   remote_dcp_size)
-            for i in range(num_remain_blocks):
-                remote_block_nums[i] += 1
-            # make sure the last block (which may be unfull) of P nodes is put to the last block of D node
-            remote_ports = remote_ports[
-                num_remain_blocks:] + remote_ports[:num_remain_blocks]
-            remote_block_nums = remote_block_nums[
-                num_remain_blocks:] + remote_block_nums[:num_remain_blocks]
+        def context_parallel_parameters_check():
+            assert (meta.remote_pcp_size * meta.remote_dcp_size) % (
+                self.pcp_size * self.dcp_size) == 0
+            if not self.use_mla:
+                p_node_heads_per_rank = math.ceil(self.num_key_value_heads /
+                                                  self._prefill_tp_size)
+                d_node_heads_per_rank = math.ceil(self.num_key_value_heads /
+                                                  self.tp_size)
+                assert d_node_heads_per_rank % p_node_heads_per_rank == 0

-        remote_handshake_port_list = []
-        for remote_kv_id in range(remote_kv_num):
-            remote_handshake_port_list.append([remote_ports[remote_kv_id]])
+        def get_kv_head_groups(tp_size):
+            if self.use_mla:
+                kv_head_groups = []
+                kv_head_ids = [0]
+                kv_head_groups.append(tuple(kv_head_ids))
+                return kv_head_groups
+            if self.num_key_value_heads // tp_size >= 1:
+                kv_head_groups = []
+                for tp_rank in range(tp_size):
+                    kv_head_ids = [head_idx + tp_rank * (self.num_key_value_heads // tp_size) \
+                        for head_idx in range(self.num_key_value_heads // tp_size)]
+                    kv_head_groups.append(tuple(kv_head_ids))
+                return kv_head_groups
+            if tp_size // self.num_key_value_heads > 1:
+                kv_head_groups = []
+                for kv_head_ids_ in range(self.num_key_value_heads):
+                    kv_head_groups.append(tuple([kv_head_ids_]))
+                return kv_head_groups
+
+        def get_cp_group_meta(tp_size, pcp_size, dcp_size, port_base):
+            # key is kv_head_group, value is cp_groups and which cp_groups to select
+            cp_group_meta: dict = {}
+            kv_head_groups = get_kv_head_groups(tp_size)
+            dcp_repeat_num = tp_size // len(kv_head_groups) // dcp_size
+
+            for kv_head_group_idx, kv_head_group in enumerate(kv_head_groups):
+                if kv_head_group not in cp_group_meta:
+                    cp_group_meta[kv_head_group] = {}
+                    cp_group_meta[kv_head_group]['cp_groups'] = []
+                    cp_group_meta[kv_head_group]['select_cp_groups_id'] = 0
+                kv_head_group_offset = tp_size // len(
+                    kv_head_groups) * kv_head_group_idx
+                for dcp_repeat_idx in range(dcp_repeat_num):
+                    # len(cp_group) == pcp_size * dcp_size
+                    cp_group = []
+                    dcp_repeat_offset = dcp_size * dcp_repeat_idx
+                    for pcp_rank in range(pcp_size):
+                        pcp_rank_offset = tp_size * pcp_rank
+                        for dcp_rank in range(dcp_size):
+                            cp_group.append(dcp_rank + port_base +
+                                            pcp_rank_offset +
+                                            dcp_repeat_offset +
+                                            kv_head_group_offset)
+                    cp_group_meta[kv_head_group]['cp_groups'].append(cp_group)
+
+            return cp_group_meta
+
+        def get_local_remote_block_port_mappings():
+            context_parallel_parameters_check()
+            p_node_cp_group_meta = get_cp_group_meta(self._prefill_tp_size,
+                                                     meta.remote_pcp_size,
+                                                     meta.remote_dcp_size,
+                                                     meta.remote_port)
+            d_node_cp_group_meta = get_cp_group_meta(self.tp_size,
+                                                     self.pcp_size,
+                                                     self.dcp_size,
+                                                     self.side_channel_port)
+            local_remote_block_port_mappings: dict[int, list[list[int]]] = {}
+            for d_node_head_key in d_node_cp_group_meta.keys():
+                for p_node_head_key in p_node_cp_group_meta.keys():
+                    if not set(p_node_head_key).issubset(set(d_node_head_key)):
+                        continue
+                    d_node_head_group = d_node_cp_group_meta[d_node_head_key]
+                    p_node_head_group = p_node_cp_group_meta[p_node_head_key]
+                    for d_cp_group in d_node_head_group['cp_groups']:
+                        select_cp_groups_id = p_node_head_group[
+                            'select_cp_groups_id']
+                        p_cp_groups = p_node_head_group['cp_groups']
+                        p_cp_group = p_cp_groups[select_cp_groups_id]
+                        p_node_head_group['select_cp_groups_id'] = select_cp_groups_id + 1 \
+                            if select_cp_groups_id + 1 < len(p_cp_groups) else 0
+                        for d_idx, d_port in enumerate(d_cp_group):
+                            if d_port not in local_remote_block_port_mappings:
+                                local_remote_block_port_mappings[d_port] = []
+                            p_port_remote_list = []
+                            for p_idx, p_port in enumerate(p_cp_group):
+                                if p_idx % len(d_cp_group) == d_idx:
+                                    p_port_remote_list.append(p_port)
+                            local_remote_block_port_mappings[d_port].append(
+                                p_port_remote_list)
+
+            logger.info(
+                "p_node_cp_group_meta is:: %s. d_node_cp_group_meta is:: %s. "
+                "local_remote_block_port_mappings is:: %s. ",
+                p_node_cp_group_meta, d_node_cp_group_meta,
+                local_remote_block_port_mappings)
+
+            return local_remote_block_port_mappings
+
+        def get_remote_port_send_num(local_remote_block_port_mappings):
+            remote_port_send_num: dict[int, int] = {}
+            for port in range(self._prefill_tp_size * meta.remote_pcp_size):
+                remote_port_send_num[meta.remote_port + port] = 0
+            for local_port in local_remote_block_port_mappings.keys():
+                remote_port_head_list = local_remote_block_port_mappings[
+                    local_port]
+                for remote_port_list in remote_port_head_list:
+                    for remote_port in remote_port_list:
+                        remote_port_send_num[remote_port] += 1
+            return remote_port_send_num
+
+        if self.local_remote_block_port_mapping is None:
+            local_remote_block_port_mappings = get_local_remote_block_port_mappings(
+            )
+            self.local_remote_block_port_mapping = local_remote_block_port_mappings[
+                self.handshake_port]
+            self.remote_port_send_num = get_remote_port_send_num(
+                local_remote_block_port_mappings)
+
+        local_remote_block_port_mapping = copy.deepcopy(
+            self.local_remote_block_port_mapping)
+
+        num_external_blocks = math.ceil(meta.num_external_tokens /
+                                        self.block_size)
+
+        assert math.ceil(num_external_blocks / (self.pcp_size * self.dcp_size)) == len(meta.local_block_ids), \
+            f"num_external_blocks({num_external_blocks}), cp_size({self.pcp_size * self.dcp_size}), " \
+            f"local_block_ids_len ({len(meta.local_block_ids)})"
+        assert meta.num_prompt_blocks >= num_external_blocks, \
+            f"meta.num_prompt_blocks({meta.num_prompt_blocks}), num_external_blocks({num_external_blocks})"
+
+        remote_cp_size = meta.remote_pcp_size * meta.remote_dcp_size
+        remote_block_nums_all = [meta.num_prompt_blocks // remote_cp_size
+                                 ] * remote_cp_size
+        num_remain_blocks = meta.num_prompt_blocks % remote_cp_size
+        for i in range(num_remain_blocks):
+            remote_block_nums_all[i] += 1
+        last_block_location = (num_remain_blocks + remote_cp_size -
+                               1) % remote_cp_size
+
+        # Considering prefix cache, the remote_block_nums_all should be revised
+        num_prefix_cached_blocks = meta.num_prompt_blocks - num_external_blocks
+        remote_block_nums_all = [
+            num - num_prefix_cached_blocks // remote_cp_size
+            for num in remote_block_nums_all
+        ]
+        num_remain_blocks = num_prefix_cached_blocks % remote_cp_size
+        for i in range(num_remain_blocks):
+            remote_block_nums_all[i] -= 1
+
+        # make sure the last block (which may be unfull) of P nodes is put to the last block of D node
+        remote_block_nums: list[int] = []
+        final_block_idx: int | None = None
+        local_cp_rank = self.dcp_rank + self.pcp_rank * self.dcp_size
+        local_cp_size = self.dcp_size * self.pcp_size
+        for cp_rank, block_num in enumerate(remote_block_nums_all):
+            if cp_rank % local_cp_size == local_cp_rank:
+                if last_block_location == cp_rank:
+                    final_block_idx = len(remote_block_nums)
+                remote_block_nums.append(block_num)
+
+        assert local_remote_block_port_mapping is not None
+        if final_block_idx is not None:
+            final_block_num = remote_block_nums.pop(final_block_idx)
+            remote_block_nums.append(final_block_num)
+            for mapping in local_remote_block_port_mapping:
+                final_block_port = mapping.pop(final_block_idx)
+                mapping.append(final_block_port)
+
+        remote_handshake_port_list, local_block_ids_list, remote_block_ids_list = [], [], []
+        for idx in range(len(local_remote_block_port_mapping[0])):
+            mapping_list = []
+            for mapping in local_remote_block_port_mapping:
+                mapping_list.append(mapping[idx])
+            remote_handshake_port_list.append(mapping_list)
         # the local_block_ids_list and remote_block_ids_list are related with remote_handshake_port_list
         # such as: local_block_ids_list[[1],[2],[5],[6]], remote_block_ids_list[[1],[1],[1],[1]],
         # remote_handshake_port_list[[30000],[30001],[30004],[30005]]
         # D rank will get remote block 1 in port 30004 and save it in local block 5
-        local_block_ids_list = []
-        remote_block_ids_list = []
         local_block_offset = 0
         for remote_kv_id in range(len(remote_handshake_port_list)):
             num_blocks_to_pull = remote_block_nums[remote_kv_id]
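To make `get_cp_group_meta` concrete, consider a hypothetical prefill deployment with `tp_size=4`, `pcp_size=2`, `dcp_size=1`, two GQA KV heads, and handshake ports assigned one per device index from `port_base=40000`. Hand-tracing the loops above yields one entry per KV-head group, each listing the port groups that together hold a complete copy of that head's prompt KV:

# hand-traced result of get_cp_group_meta(4, 2, 1, 40000)
# with num_key_value_heads == 2 and use_mla == False
cp_group_meta = {
    (0,): {  # ranks serving KV head 0
        'cp_groups': [[40000, 40004], [40001, 40005]],
        'select_cp_groups_id': 0,
    },
    (1,): {  # ranks serving KV head 1
        'cp_groups': [[40002, 40006], [40003, 40007]],
        'select_cp_groups_id': 0,
    },
}

Each inner list pairs one port per pcp rank, and `select_cp_groups_id` round-robins decode pullers across the duplicate groups so the pull load spreads over the redundant prefill ranks.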
@@ -1289,8 +1480,9 @@
                 meta.local_block_ids[local_block_offset:local_block_offset +
                                      num_blocks_to_pull])
             local_block_offset += num_blocks_to_pull
-        assert local_block_offset == len(meta.local_block_ids), \
-            f"local_block_offset ({local_block_offset}) should equal with local_block_ids len ({len(meta.local_block_ids)})"
+
+        assert self.tp_num_need_pulls == len(remote_handshake_port_list[0]), \
+            f"tp_num_need_pulls: {self.tp_num_need_pulls}, remote_handshake_port_list: {remote_handshake_port_list}"

         return remote_handshake_port_list, local_block_ids_list, remote_block_ids_list
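A worked example of the prefix-cache revision of `remote_block_nums_all` in `_get_kv_split_metadata` above: 7 prompt blocks spread round-robin over 2 prefill CP ranks, of which 2 blocks are already prefix-cached on the decode side (numbers assumed for illustration):

remote_cp_size = 2
num_prompt_blocks = 7    # whole prompt, including prefix-cached blocks
num_external_blocks = 5  # blocks that actually need to be pulled

# round-robin totals per prefill CP rank: [4, 3]
nums = [num_prompt_blocks // remote_cp_size] * remote_cp_size
for i in range(num_prompt_blocks % remote_cp_size):
    nums[i] += 1

# remove each rank's share of the cached prefix: 2 cached -> [3, 2]
num_prefix_cached_blocks = num_prompt_blocks - num_external_blocks
nums = [n - num_prefix_cached_blocks // remote_cp_size for n in nums]
for i in range(num_prefix_cached_blocks % remote_cp_size):
    nums[i] -= 1

assert nums == [3, 2] and sum(nums) == num_external_blocks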
@@ -1302,15 +1494,11 @@
             "Num local_block_ids: %s. Num remote_block_ids: %s. ", req_id,
             meta.remote_engine_id, len(meta.local_block_ids),
             len(meta.remote_block_ids))
-
-        if prefill_context_parallel_enable():
+        if meta.remote_pcp_size * meta.remote_dcp_size > 1:
             remote_handshake_port_list, local_block_ids_list, remote_block_ids_list = self._get_kv_split_metadata(
                 req_id, meta)
             for pcp_dcp_rank in range(len(remote_handshake_port_list)):
-                if len(local_block_ids_list[pcp_dcp_rank]) + len(
-                        remote_block_ids_list[pcp_dcp_rank]) == 0:
-                    continue
                 for i in range(self.tp_num_need_pulls):
                     assert self.kv_recv_thread is not None
                     remote_host, remote_engine_id = self._get_remote_host_info_by_port(
@@ -1329,6 +1517,7 @@
                             pcp_dcp_rank][i],
                         offset=i,
                         tp_num_need_pulls=self.tp_num_need_pulls,
+                        remote_port_send_num=self.remote_port_send_num,
                         all_task_done=(
                             pcp_dcp_rank ==
                             len(remote_handshake_port_list) - 1
", req_id, meta.remote_engine_id, len(meta.local_block_ids), len(meta.remote_block_ids)) - - if prefill_context_parallel_enable(): + if meta.remote_pcp_size * meta.remote_dcp_size > 1: remote_handshake_port_list, local_block_ids_list, remote_block_ids_list = self._get_kv_split_metadata( req_id, meta) for pcp_dcp_rank in range(len(remote_handshake_port_list)): - if len(local_block_ids_list[pcp_dcp_rank]) + len( - remote_block_ids_list[pcp_dcp_rank]) == 0: - continue for i in range(self.tp_num_need_pulls): assert self.kv_recv_thread is not None remote_host, remote_engine_id = self._get_remote_host_info_by_port( @@ -1329,6 +1517,7 @@ class MooncakeConnectorWorker: pcp_dcp_rank][i], offset=i, tp_num_need_pulls=self.tp_num_need_pulls, + remote_port_send_num=self.remote_port_send_num, all_task_done=( pcp_dcp_rank == len(remote_handshake_port_list) - 1 @@ -1355,7 +1544,7 @@ class MooncakeConnectorWorker: all_task_done=(i == self.tp_num_need_pulls * self._prefill_pp_size - 1)) - if self.kv_send_thread is not None: + if self.kv_send_thread is not None and self.pcp_size * self.dcp_size == 1: for req_id, delay_start_time in metadata.requests_to_send.items(): if self.tp_rank in self._prefill_get_remote_rank(req_id): self.kv_send_thread.add_delayed_request( @@ -1363,6 +1552,11 @@ class MooncakeConnectorWorker: else: self.kv_send_thread.add_not_transfer_request(req_id) + if self.kv_send_thread is not None and self.pcp_size * self.dcp_size > 1: + for req_id, delay_start_time in metadata.requests_to_send.items(): + self.kv_send_thread.add_delayed_request( + req_id, delay_start_time) + def _get_remote_host_info_by_port(self, base_port: int, remote_handshake_port: int, remote_host: str, remote_engine_id: str,