[feature] Mooncake: support PCP/DCP in common conditions (#5224)

### What this PR does / why we need it?
1. This PR adds support for complex PCP/DCP parallelism combinations between the Prefill and Decode nodes in Mooncake, such as Prefill: TP8/PCP2/DCP8 and Decode: TP8/DCP4/DP2, which are not supported today. We establish the link mappings used to transfer the KVCache between prefill and decode ranks. The main logic lives in `_get_kv_split_metadata` in `mooncake_connector.py` (a block-splitting sketch follows this list).
2. After a decode rank pulls KVCache from a prefill rank, the decode rank sends `DONE_RECVING_MSG` to the prefill rank, and the prefill rank then frees its KVCache blocks. With complex PCP/DCP parallelisms a prefill rank can be pulled from more than once by several decode ranks, so it would free its KVCache blocks multiple times, which can cause memory issues. This PR fixes that by counting how many times each prefill rank will be pulled from and freeing its KVCache blocks only after the last pull (see the counting sketch after this list). The related code is in `run_busy_loop` in `mooncake_connector.py`.
3. If a prefill rank is never pulled from by any decode rank, the first rank in the decode node sends `DONE_RECVING_MSG` so that the prefill rank can still free its blocks. The related code is in `_send_done_signal_to_free_remote_port` in `mooncake_connector.py`.
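
To make the mapping in point 1 concrete, here is a minimal sketch of the even block split that `_get_kv_split_metadata` performs across the remote prefill CP ranks; the helper name `split_prompt_blocks` is invented for illustration and is not part of the connector:

```python
# Simplified sketch (not the actual connector code): how the prompt blocks of
# one request can be split across the remote prefill CP ranks.
def split_prompt_blocks(num_prompt_blocks: int, remote_pcp_size: int,
                        remote_dcp_size: int) -> list[int]:
    """Return how many KV blocks each remote prefill CP rank holds."""
    remote_cp_size = remote_pcp_size * remote_dcp_size
    # Even split, with the remainder assigned to the first ranks.
    block_nums = [num_prompt_blocks // remote_cp_size] * remote_cp_size
    for i in range(num_prompt_blocks % remote_cp_size):
        block_nums[i] += 1
    return block_nums

# Example: 5 prompt blocks spread over PCP2/DCP2 (4 remote CP ranks).
print(split_prompt_blocks(5, 2, 2))  # -> [2, 1, 1, 1]
```

On top of this split, the connector also rotates the last (possibly not-full) prefill block to the end of the list so that it lands in the last local block of the decode rank, and subtracts any prefix-cached blocks before building the final port/block mapping.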
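
Points 2 and 3 boil down to the following counting scheme, shown here as a hedged sketch; `PullCounter` is a hypothetical stand-in for the bookkeeping the sending thread does when it handles `DONE_RECVING_MSG`:

```python
# Minimal sketch (assumptions, not the real KVCacheSendingThread) of the
# "free only after the last pull" logic described in point 2.
# `remote_port_send_num` maps each prefill handshake port to how many
# decode ranks are expected to pull from it for this request.
class PullCounter:
    def __init__(self) -> None:
        self.port_send_num: dict[str, int] = {}  # request_id -> pulls seen so far

    def on_done_recving(self, request_id: str, handshake_port: int,
                        remote_port_send_num: dict[int, int]) -> bool:
        """Return True when it is safe to free the prefill KVCache blocks."""
        self.port_send_num[request_id] = self.port_send_num.get(request_id, 0) + 1
        if self.port_send_num[request_id] >= remote_port_send_num[handshake_port]:
            del self.port_send_num[request_id]
            return True   # last expected pull: free the blocks now
        return False      # other decode ranks will still pull this request

counter = PullCounter()
# Two decode ranks are expected to pull from port 30000 for "req1".
print(counter.on_done_recving("req1", 30000, {30000: 2}))  # False
print(counter.on_done_recving("req1", 30000, {30000: 2}))  # True
```

For a prefill rank whose expected pull count is zero (point 3), no decode rank would ever trigger this path, so the first decode rank sends the done signal explicitly to let that prefill rank free its blocks.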

### How was this patch tested?
This PR was tested with many PCP/DCP parallelism combinations, and the accuracy is correct in all of them.
MLA model:
Prefill node: TP8/DP2, Decode node: TP8/DP2
Prefill node: TP8/PCP2/DCP8, Decode node: TP8/DP2
Prefill node: TP8/PCP2/DCP8, Decode node: TP8/DCP4/DP2
Prefill node: TP8/PCP2/DCP4, Decode node: TP4/DCP2/DP4
Prefill node: TP8/PCP2/DCP2, Decode node: TP4/DCP4/DP4
Prefill node: TP8/PCP2, Decode node: TP4/DCP2

GQA model:
Prefill node: TP8/DP2, Decode node: TP8/DP2
Prefill node: TP8/PCP2/DCP2, Decode node: TP8/DP2
Prefill node: TP8/PCP2/DCP2, Decode node: TP8/DCP2/DP2
Prefill node: TP8/PCP2/DCP2, Decode node: TP4/DP4
Prefill node: TP16/DCP2/PCP1, Decode node: TP8/DCP2/DP2


- vLLM version: release/v0.13.0
- vLLM main: ad32e3e19c
- Co-authored-by: Daishixun <dsxtsteven@sina.com>

---------

Signed-off-by: wangxiaochao <w00642655@china.huawei.com>
Co-authored-by: wangxiaochao <w00642655@china.huawei.com>
Co-authored-by: Jade Zheng <zheng.shoujian@outlook.com>
Commit a539ae753a (parent a5ae07a5d2), authored by wangxiaochao6 on 2025-12-31 09:53:03 +08:00, committed by GitHub.
2 changed files with 283 additions and 75 deletions


@@ -241,6 +241,7 @@ class TestKVCacheRecvingThreadBasic(unittest.TestCase):
engine=self.engine,
local_engine_id="local_engine",
local_handshake_port=5555,
side_channel_port=30000,
local_kv_caches_base_addr=[0x1000, 0x2000],
block_len=[1024, 2048],
ready_event=self.ready_event,
@@ -294,6 +295,7 @@ class TestSocketManagement(unittest.TestCase):
engine=self.engine,
local_engine_id="local_engine",
local_handshake_port=5555,
side_channel_port=30000,
local_kv_caches_base_addr=[0x1000, 0x2000],
block_len=[1024, 2048],
ready_event=self.ready_event,
@@ -351,6 +353,7 @@ class TestCoreFunctionality(unittest.TestCase):
engine=self.engine,
local_engine_id="local_engine",
local_handshake_port=5555,
side_channel_port=30000,
local_kv_caches_base_addr=[0x1000, 0x2000],
block_len=[1024, 2048],
ready_event=self.ready_event,
@@ -367,6 +370,9 @@ class TestCoreFunctionality(unittest.TestCase):
"remote_transfer_port": 7777,
"offset": 0,
"tp_num_need_pulls": 2,
"remote_port_send_num": {
6666: 1
},
"all_task_done": False
}
self.thread.task_tracker = MagicMock()
@@ -382,7 +388,7 @@ class TestCoreFunctionality(unittest.TestCase):
self.thread._handle_request(self.test_req)
mock_transfer.assert_called_once_with(self.test_req)
mock_send.assert_called_once_with("req1", "localhost", 6666)
mock_send.assert_called_once_with("req1", "localhost", 6666, {6666: 1})
if not self.thread.task_tracker.update_done_task_count.called:
self.thread.task_tracker.update_done_task_count("req1")
self.thread.task_tracker.update_done_task_count.assert_called_once_with(
@@ -433,6 +439,7 @@ class TestMetadataHandling(unittest.TestCase):
engine=self.engine,
local_engine_id="local_engine",
local_handshake_port=5555,
side_channel_port=30000,
local_kv_caches_base_addr=[0x1000, 0x2000],
block_len=[1024, 2048],
ready_event=self.ready_event,
@@ -497,6 +504,7 @@ class TestMainThreadLoop(unittest.TestCase):
engine=self.engine,
local_engine_id="local_engine",
local_handshake_port=5555,
side_channel_port=30000,
local_kv_caches_base_addr=[0x1000, 0x2000],
block_len=[1024, 2048],
ready_event=self.ready_event,
@@ -654,6 +662,7 @@ class TestMooncakeConnectorMetadata(unittest.TestCase):
meta.add_new_req(request_id="req1",
local_block_ids=[1, 2, 3],
num_external_tokens=48,
kv_transfer_params={
"remote_block_ids": [4, 5, 6],
"remote_engine_id": "remote_engine",
@@ -706,7 +715,7 @@ class TestMooncakeConnectorSchedulerMatchedTokens(unittest.TestCase):
request = MockRequest("req1")
blocks_mock = MagicMock()
blocks_mock.get_unhashed_block_ids.return_value = [4, 5, 6]
self.scheduler._reqs_need_recv["req1"] = (request, [4, 5, 6])
self.scheduler._reqs_need_recv["req1"] = (request, [4, 5, 6], 48)
request.kv_transfer_params = {
"remote_block_ids": [1, 2, 3],
"remote_engine_id": "remote",
@@ -1278,6 +1287,8 @@ class TestMooncakeConnectorWorker(unittest.TestCase):
worker.pcp_rank = pcp_rank
worker._prefill_tp_size = _prefill_tp_size
worker.local_remote_block_port_mapping = None
worker.block_size = 16
worker.num_key_value_heads = 1
meta = types.SimpleNamespace()
@@ -1286,6 +1297,9 @@ class TestMooncakeConnectorWorker(unittest.TestCase):
meta.remote_port = remote_port
meta.remote_block_ids = remote_block_ids
meta.local_block_ids = local_block_ids
meta.num_external_tokens = pcp_size * dcp_size * len(
local_block_ids) * worker.block_size
meta.num_prompt_blocks = pcp_size * dcp_size * len(local_block_ids)
remote_handshake_port_list, local_block_ids_list, remote_block_ids_list = worker._get_kv_split_metadata(
'0', meta)
@@ -1324,17 +1338,17 @@ class TestMooncakeConnectorWorker(unittest.TestCase):
self.assertEqual(
get_kv_split_metadata(True, 1, 2, 8, 1, 0, 8, 2, 2, 30000, [1],
[1]),
([[30009], [30001]], [[], [1]], [[], [1]]))
([[30000], [30008]], [[1], []], [[1], []]))
self.assertEqual(
get_kv_split_metadata(False, 1, 2, 8, 1, 0, 8, 2, 2, 30000, [1],
[1]),
([[30009], [30001]], [[], [1]], [[], [1]]))
([[30000], [30008]], [[1], []], [[1], []]))
self.assertEqual(
get_kv_split_metadata(True, 1, 2, 8, 0, 0, 8, 2, 2, 30000,
[1, 2, 3], [1, 2, 3, 4, 5]),
([[30008], [30000]], [[1, 2], [3, 4, 5]], [[1, 2], [1, 2, 3]]))
([[30000], [30008]], [[1, 2, 3], [4, 5]], [[1, 2, 3], [1, 2]]))
if __name__ == '__main__':


@@ -1,5 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
import contextlib
import copy
import hashlib
import math
import os
@@ -42,7 +43,7 @@ from vllm.v1.request import RequestStatus
from vllm_ascend.ascend_config import get_ascend_config, init_ascend_config
from vllm_ascend.distributed.mooncake_transfer_engine import global_te
from vllm_ascend.distributed.utils import get_transfer_timeout_value
from vllm_ascend.utils import is_vl_model, prefill_context_parallel_enable
from vllm_ascend.utils import is_vl_model
if TYPE_CHECKING:
from vllm.attention.backends.abstract import AttentionMetadata
@@ -65,6 +66,7 @@ class MooncakeAgentMetadata(msgspec.Struct, omit_defaults=True, dict=True):
@dataclass
class ReqMeta:
local_block_ids: list[int]
num_external_tokens: int
remote_block_ids: list[int]
remote_host: str
remote_port: int
@@ -72,6 +74,7 @@ class ReqMeta:
remote_pcp_size: int
remote_dcp_size: int
remote_multi_nodes_meta_mapping: dict[str, dict[str, Any]]
num_prompt_blocks: int
@dataclass
@@ -184,6 +187,7 @@ class KVCacheSendingThread(threading.Thread):
self.ready_event = ready_event
self.kv_caches = kv_caches
self.pcp_rank = pcp_rank
self.port_send_num: dict[str, int] = {}
self.task_tracker = KVCacheTaskTracker()
@@ -248,7 +252,22 @@ class KVCacheSendingThread(threading.Thread):
elif msg[0] == DONE_RECVING_MSG:
logger.debug("Got DONE_RECVING_MSG for request %s", msg[1])
request_id = msg[1]
self.task_tracker.update_done_task_count(request_id)
remote_port_send_num = msg[2]
if remote_port_send_num:
if request_id not in self.port_send_num:
self.port_send_num[request_id] = 0
self.port_send_num[request_id] += 1
device_index = self.pp_rank * self.tp_size + \
self.tp_rank + self.pcp_rank * \
self.prefill_tp_size
handshake_port = self.side_channel_port + device_index
if self.port_send_num[request_id] >= \
remote_port_send_num[handshake_port]:
self.task_tracker.update_done_task_count(
request_id)
del self.port_send_num[request_id]
else:
self.task_tracker.update_done_task_count(request_id)
# Acknowledge the request completion.
while True:
try:
@@ -275,7 +294,7 @@ class KVCacheRecvingThread(threading.Thread):
def __init__(self, tp_rank: int, tp_size: int, _prefill_pp_size: int,
engine: TransferEngine, local_engine_id: str,
local_handshake_port: int,
local_handshake_port: int, side_channel_port: int,
local_kv_caches_base_addr: list[int], block_len: list[int],
ready_event: threading.Event, vllm_config: VllmConfig,
kv_caches: dict[str, Any]):
@@ -285,6 +304,7 @@ class KVCacheRecvingThread(threading.Thread):
self._prefill_pp_size = _prefill_pp_size
self.local_engine_id = local_engine_id
self.local_handshake_port = local_handshake_port
self.side_channel_port = side_channel_port
self.engine = engine
self.ready_event = ready_event
@@ -328,11 +348,19 @@ class KVCacheRecvingThread(threading.Thread):
self.num_kv_heads = max(
self.model_config.hf_text_config.num_key_value_heads //
self.tp_size, 1)
self.proc_not_transfer_request: dict[str, bool] = {}
def add_request(self, request_id: str, local_block_ids: list[int],
remote_block_ids: list[int], remote_engine_id: str,
remote_host: str, remote_handshake_port: int, offset: int,
tp_num_need_pulls: int, all_task_done: bool):
def add_request(self,
request_id: str,
local_block_ids: list[int],
remote_block_ids: list[int],
remote_engine_id: str,
remote_host: str,
remote_handshake_port: int,
offset: int,
tp_num_need_pulls: int,
remote_port_send_num: dict[int, int] = {},
all_task_done: bool = False):
"""Add a new request to the queue for processing."""
logger.debug(f"Adding request {request_id} to the queue.")
self.request_queue.put({
@@ -344,6 +372,7 @@ class KVCacheRecvingThread(threading.Thread):
"remote_handshake_port": remote_handshake_port,
"offset": offset,
"tp_num_need_pulls": tp_num_need_pulls,
"remote_port_send_num": remote_port_send_num,
"all_task_done": all_task_done
})
@@ -373,6 +402,7 @@ class KVCacheRecvingThread(threading.Thread):
request_id = req_meta["request_id"]
remote_host = req_meta["remote_host"]
remote_handshake_port = req_meta["remote_handshake_port"]
remote_port_send_num = req_meta["remote_port_send_num"]
all_task_done = req_meta["all_task_done"]
try:
@@ -389,11 +419,31 @@ class KVCacheRecvingThread(threading.Thread):
# resource cleanup. Failing to do so may cause a memory leak on the
# remote host.
self._send_done_recv_signal(request_id, remote_host,
remote_handshake_port)
remote_handshake_port,
remote_port_send_num)
self._send_done_signal_to_free_remote_port(request_id, remote_host,
remote_port_send_num)
if all_task_done:
self.task_tracker.update_done_task_count(request_id)
if request_id in self.proc_not_transfer_request:
del self.proc_not_transfer_request[request_id]
self.request_queue.task_done()
def _send_done_signal_to_free_remote_port(self, request_id, remote_host,
remote_port_send_num):
if self.side_channel_port != self.local_handshake_port \
or not remote_port_send_num:
return
if request_id not in self.proc_not_transfer_request:
self.proc_not_transfer_request[request_id] = True
if self.proc_not_transfer_request[request_id]:
for remote_port in remote_port_send_num.keys():
if remote_port_send_num[remote_port] == 0:
self._send_done_recv_signal(request_id, remote_host,
remote_port,
remote_port_send_num)
self.proc_not_transfer_request[request_id] = False
def _transfer_kv_cache(self, req_meta: dict[str, Any]):
"""Handle a KV cache transfer request."""
request_id = req_meta["request_id"]
@@ -606,13 +656,15 @@ class KVCacheRecvingThread(threading.Thread):
remote_handshake_port)
def _send_done_recv_signal(self, request_id: str, remote_host: str,
remote_handshake_port: int):
remote_handshake_port: int,
remote_port_send_num: dict[int, int]):
logger.debug("Sending done recving signal for request %s to %s:%d",
request_id, remote_host, remote_handshake_port)
sock: Optional[zmq.Socket] = None # type: ignore
try:
sock = self._get_remote_socket(remote_host, remote_handshake_port)
data_bytes = self.encoder.encode((DONE_RECVING_MSG, request_id))
data_bytes = self.encoder.encode(
(DONE_RECVING_MSG, request_id, remote_port_send_num))
ensure_zmq_send(sock, data_bytes)
resp = ensure_zmq_recv(sock,
self.remote_poller,
@@ -674,10 +726,12 @@ class MooncakeConnectorMetadata(KVConnectorMetadata):
self,
request_id: str,
local_block_ids: list[int],
num_external_tokens: int,
kv_transfer_params: dict[str, Any],
):
self.requests[request_id] = ReqMeta(
local_block_ids=local_block_ids,
num_external_tokens=num_external_tokens,
remote_block_ids=kv_transfer_params["remote_block_ids"],
remote_engine_id=kv_transfer_params["remote_engine_id"],
remote_host=kv_transfer_params["remote_host"],
@@ -686,6 +740,7 @@ class MooncakeConnectorMetadata(KVConnectorMetadata):
remote_dcp_size=kv_transfer_params.get("remote_dcp_size", 1),
remote_multi_nodes_meta_mapping=kv_transfer_params.get(
"remote_multi_nodes_meta_mapping", {}),
num_prompt_blocks=kv_transfer_params.get("num_prompt_blocks", 0),
)
@@ -827,7 +882,7 @@ class MooncakeConnectorScheduler:
# Requests that need to start recv.
# New requests are added by update_state_after_alloc in
# the scheduler. Used to make metadata passed to Worker.
self._reqs_need_recv: dict[str, tuple[Request, list[int]]] = {}
self._reqs_need_recv: dict[str, tuple[Request, list[int], int]] = {}
self._reqs_need_send: dict[str, float] = {}
# master-slave meta information for cross-nodes
@@ -885,7 +940,7 @@ class MooncakeConnectorScheduler:
if num_external_tokens > 0 else [])
# Get unhashed blocks to pull from remote.
self._reqs_need_recv[request.request_id] = (
request, local_block_ids)
request, local_block_ids, num_external_tokens)
else:
logger.warning(
"Got invalid KVTransferParams: %s. This "
@@ -902,7 +957,8 @@ class MooncakeConnectorScheduler:
meta = MooncakeConnectorMetadata()
# Loop through scheduled reqs and convert to ReqMeta.
for req_id, (req, block_ids) in self._reqs_need_recv.items():
for req_id, (req, block_ids,
num_external_tokens) in self._reqs_need_recv.items():
assert req.kv_transfer_params is not None
# For the case where there are no remote blocks to pull
# (block_ids is empty), we don't need to schedule
@@ -910,6 +966,7 @@ class MooncakeConnectorScheduler:
meta.add_new_req(
request_id=req_id,
local_block_ids=block_ids,
num_external_tokens=num_external_tokens,
kv_transfer_params=req.kv_transfer_params,
)
@@ -946,6 +1003,9 @@ class MooncakeConnectorScheduler:
len(computed_block_ids), request.request_id)
self._reqs_need_send[request.request_id] = time.time()
num_prompt_blocks = math.ceil(
len(request.prompt_token_ids) / self.block_size)
return delay_free_blocks, dict(
do_remote_prefill=True,
do_remote_decode=False,
@@ -957,6 +1017,7 @@ class MooncakeConnectorScheduler:
remote_dcp_size=self.dcp_size,
last_token_id=request.output_token_ids[-1],
remote_multi_nodes_meta_mapping=self.multi_nodes_meta_mapping,
num_prompt_blocks=num_prompt_blocks,
)
def set_xfer_handshake_metadata(
@@ -1045,6 +1106,8 @@ class MooncakeConnectorWorker:
num_p_block_heads = max(
1, self.num_key_value_heads // self._prefill_tp_size)
self.tp_num_need_pulls = num_d_block_heads // num_p_block_heads
self.local_remote_block_port_mapping = None
self.remote_port_send_num: dict[int, int] = {}
def _get_prefill_decode_size(self, vllm_config: VllmConfig):
# get prefill tp and dp size from extra config
@@ -1177,8 +1240,9 @@ class MooncakeConnectorWorker:
else:
self.kv_recv_thread = KVCacheRecvingThread(
self.tp_rank, self.tp_size, self._prefill_pp_size, self.engine,
self.engine_id, self.handshake_port, kv_caches_base_addr,
self.block_len, ready_event, self.vllm_config, self.kv_caches)
self.engine_id, self.handshake_port, self.side_channel_port,
kv_caches_base_addr, self.block_len, ready_event,
self.vllm_config, self.kv_caches)
self.kv_recv_thread.start()
start_wait_time = time.time()
@@ -1227,59 +1291,186 @@ class MooncakeConnectorWorker:
], [meta.remote_block_ids]
return remote_handshake_port_list, local_block_ids_list, remote_block_ids_list
if self.pcp_size == meta.remote_pcp_size and self.dcp_size == meta.remote_dcp_size:
# remote & local cp/dcp are equal, do kv transfer point-to-point
remote_kv_num = 1
remote_ports = [meta.remote_port + self.pcp_rank * self.tp_size + tp_offset \
for tp_offset in range(self.tp_rank, int(self._prefill_tp_size), self.tp_size)]
remote_block_nums = [len(meta.remote_block_ids)]
else:
assert self.pcp_size == 1
if self.use_mla:
assert (self.dcp_size == 1 and (self.tp_size == 1 or self.tp_size == self._prefill_tp_size)) or \
(self.dcp_size == meta.remote_dcp_size and self.tp_size == self._prefill_tp_size)
else:
assert self.tp_size == self._prefill_tp_size and (
self.dcp_size == 1
or self.dcp_size == meta.remote_dcp_size)
# remote & local cp/dcp are not equal, each D node needs to pull from pcp(*dcp) P nodes
# 1. for mla, support D pcp_size = 1, D dcp_size = (1 or P dcp_size)
# 2. for gqa, support D tp_size = P tp_size, D dcp_size = P dcp_size
remote_dcp_size = meta.remote_dcp_size // self.dcp_size
remote_kv_num = meta.remote_pcp_size * remote_dcp_size
cp_dcp_offsets = []
for cp_idx in range(meta.remote_pcp_size):
cp_offset = cp_idx * self._prefill_tp_size
cp_dcp_offsets += list(
range(cp_offset, cp_offset + remote_dcp_size))
tp_offset = self.tp_rank // remote_dcp_size * remote_dcp_size
remote_ports = [meta.remote_port + cp_dcp_offset + tp_offset \
for cp_dcp_offset in cp_dcp_offsets]
# recompute cp/dcp block assign here, maybe we can also pass it from P node meta
local_block_num = len(meta.local_block_ids)
remote_block_nums = [
local_block_num // (meta.remote_pcp_size * remote_dcp_size)
] * meta.remote_pcp_size * remote_dcp_size
num_remain_blocks = local_block_num % (meta.remote_pcp_size *
remote_dcp_size)
for i in range(num_remain_blocks):
remote_block_nums[i] += 1
# make sure the last block (which may be unfull) of P nodes is put to the last block of D node
remote_ports = remote_ports[
num_remain_blocks:] + remote_ports[:num_remain_blocks]
remote_block_nums = remote_block_nums[
num_remain_blocks:] + remote_block_nums[:num_remain_blocks]
def context_parallel_parameters_check():
assert (meta.remote_pcp_size * meta.remote_dcp_size) % (
self.pcp_size * self.dcp_size) == 0
if not self.use_mla:
p_node_heads_per_rank = math.ceil(self.num_key_value_heads /
self._prefill_tp_size)
d_node_heads_per_rank = math.ceil(self.num_key_value_heads /
self.tp_size)
assert d_node_heads_per_rank % p_node_heads_per_rank == 0
remote_handshake_port_list = []
for remote_kv_id in range(remote_kv_num):
remote_handshake_port_list.append([remote_ports[remote_kv_id]])
def get_kv_head_groups(tp_size):
if self.use_mla:
kv_head_groups = []
kv_head_ids = [0]
kv_head_groups.append(tuple(kv_head_ids))
return kv_head_groups
if self.num_key_value_heads // tp_size >= 1:
kv_head_groups = []
for tp_rank in range(tp_size):
kv_head_ids = [head_idx + tp_rank * (self.num_key_value_heads // tp_size) \
for head_idx in range(self.num_key_value_heads // tp_size)]
kv_head_groups.append(tuple(kv_head_ids))
return kv_head_groups
if tp_size // self.num_key_value_heads > 1:
kv_head_groups = []
for kv_head_ids_ in range(self.num_key_value_heads):
kv_head_groups.append(tuple([kv_head_ids_]))
return kv_head_groups
def get_cp_group_meta(tp_size, pcp_size, dcp_size, port_base):
# key is kv_head_group, value is cp_groups and which cp_groups to select
cp_group_meta: dict = {}
kv_head_groups = get_kv_head_groups(tp_size)
dcp_repeat_num = tp_size // len(kv_head_groups) // dcp_size
for kv_head_group_idx, kv_head_group in enumerate(kv_head_groups):
if kv_head_group not in cp_group_meta:
cp_group_meta[kv_head_group] = {}
cp_group_meta[kv_head_group]['cp_groups'] = []
cp_group_meta[kv_head_group]['select_cp_groups_id'] = 0
kv_head_group_offset = tp_size // len(
kv_head_groups) * kv_head_group_idx
for dcp_repeat_idx in range(dcp_repeat_num):
# len(cp_group) == pcp_size * dcp_size
cp_group = []
dcp_repeat_offset = dcp_size * dcp_repeat_idx
for pcp_rank in range(pcp_size):
pcp_rank_offset = tp_size * pcp_rank
for dcp_rank in range(dcp_size):
cp_group.append(dcp_rank + port_base +
pcp_rank_offset +
dcp_repeat_offset +
kv_head_group_offset)
cp_group_meta[kv_head_group]['cp_groups'].append(cp_group)
return cp_group_meta
def get_local_remote_block_port_mappings():
context_parallel_parameters_check()
p_node_cp_group_meta = get_cp_group_meta(self._prefill_tp_size,
meta.remote_pcp_size,
meta.remote_dcp_size,
meta.remote_port)
d_node_cp_group_meta = get_cp_group_meta(self.tp_size,
self.pcp_size,
self.dcp_size,
self.side_channel_port)
local_remote_block_port_mappings: dict[int, list[list[int]]] = {}
for d_node_head_key in d_node_cp_group_meta.keys():
for p_node_head_key in p_node_cp_group_meta.keys():
if not set(p_node_head_key).issubset(set(d_node_head_key)):
continue
d_node_head_group = d_node_cp_group_meta[d_node_head_key]
p_node_head_group = p_node_cp_group_meta[p_node_head_key]
for d_cp_group in d_node_head_group['cp_groups']:
select_cp_groups_id = p_node_head_group[
'select_cp_groups_id']
p_cp_groups = p_node_head_group['cp_groups']
p_cp_group = p_cp_groups[select_cp_groups_id]
p_node_head_group['select_cp_groups_id'] = select_cp_groups_id + 1 \
if select_cp_groups_id + 1 < len(p_cp_groups) else 0
for d_idx, d_port in enumerate(d_cp_group):
if d_port not in local_remote_block_port_mappings:
local_remote_block_port_mappings[d_port] = []
p_port_remote_list = []
for p_idx, p_port in enumerate(p_cp_group):
if p_idx % len(d_cp_group) == d_idx:
p_port_remote_list.append(p_port)
local_remote_block_port_mappings[d_port].append(
p_port_remote_list)
logger.info(
"p_node_cp_group_meta is:: %s. d_node_cp_group_meta is:: %s. "
"local_remote_block_port_mappings is:: %s. ",
p_node_cp_group_meta, d_node_cp_group_meta,
local_remote_block_port_mappings)
return local_remote_block_port_mappings
def get_remote_port_send_num(local_remote_block_port_mappings):
remote_port_send_num: dict[int, int] = {}
for port in range(self._prefill_tp_size * meta.remote_pcp_size):
remote_port_send_num[meta.remote_port + port] = 0
for local_port in local_remote_block_port_mappings.keys():
remote_port_head_list = local_remote_block_port_mappings[
local_port]
for remote_port_list in remote_port_head_list:
for remote_port in remote_port_list:
remote_port_send_num[remote_port] += 1
return remote_port_send_num
if self.local_remote_block_port_mapping is None:
local_remote_block_port_mappings = get_local_remote_block_port_mappings(
)
self.local_remote_block_port_mapping = local_remote_block_port_mappings[
self.handshake_port]
self.remote_port_send_num = get_remote_port_send_num(
local_remote_block_port_mappings)
local_remote_block_port_mapping = copy.deepcopy(
self.local_remote_block_port_mapping)
num_external_blocks = math.ceil(meta.num_external_tokens /
self.block_size)
assert math.ceil(num_external_blocks / (self.pcp_size * self.dcp_size)) == len(meta.local_block_ids), \
f"num_external_blocks({num_external_blocks}), cp_size({self.pcp_size * self.dcp_size}), " \
f"local_block_ids_len ({len(meta.local_block_ids)})"
assert meta.num_prompt_blocks >= num_external_blocks, \
f"meta.num_prompt_blocks({meta.num_prompt_blocks}), num_external_blocks({num_external_blocks})"
remote_cp_size = meta.remote_pcp_size * meta.remote_dcp_size
remote_block_nums_all = [meta.num_prompt_blocks // remote_cp_size
] * remote_cp_size
num_remain_blocks = meta.num_prompt_blocks % remote_cp_size
for i in range(num_remain_blocks):
remote_block_nums_all[i] += 1
last_block_location = (num_remain_blocks + remote_cp_size -
1) % remote_cp_size
# Considering prefix cache, the remote_block_nums_all should be revised
num_prefix_cached_blocks = meta.num_prompt_blocks - num_external_blocks
remote_block_nums_all = [
num - num_prefix_cached_blocks // remote_cp_size
for num in remote_block_nums_all
]
num_remain_blocks = num_prefix_cached_blocks % remote_cp_size
for i in range(num_remain_blocks):
remote_block_nums_all[i] -= 1
# make sure the last block (which may be unfull) of P nodes is put to the last block of D node
remote_block_nums: list[int] = []
final_block_idx: int | None = None
local_cp_rank = self.dcp_rank + self.pcp_rank * self.dcp_size
local_cp_size = self.dcp_size * self.pcp_size
for cp_rank, block_num in enumerate(remote_block_nums_all):
if cp_rank % local_cp_size == local_cp_rank:
if last_block_location == cp_rank:
final_block_idx = len(remote_block_nums)
remote_block_nums.append(block_num)
assert local_remote_block_port_mapping is not None
if final_block_idx is not None:
final_block_num = remote_block_nums.pop(final_block_idx)
remote_block_nums.append(final_block_num)
for mapping in local_remote_block_port_mapping:
final_block_port = mapping.pop(final_block_idx)
mapping.append(final_block_port)
remote_handshake_port_list, local_block_ids_list, remote_block_ids_list = [], [], []
for idx in range(len(local_remote_block_port_mapping[0])):
mapping_list = []
for mapping in local_remote_block_port_mapping:
mapping_list.append(mapping[idx])
remote_handshake_port_list.append(mapping_list)
# the local_block_ids_list and remote_block_ids_list are related with remote_handshake_port_list
# such as: local_block_ids_list[[1],[2],[5],[6]], remote_block_ids_list[[1],[1],[1],[1]],
# remote_handshake_port_list[[30000],[30001],[30004],[30005]]
# D rank will get remote block 1 in port 30004 and save it in local block 5
local_block_ids_list = []
remote_block_ids_list = []
local_block_offset = 0
for remote_kv_id in range(len(remote_handshake_port_list)):
num_blocks_to_pull = remote_block_nums[remote_kv_id]
@@ -1289,8 +1480,9 @@ class MooncakeConnectorWorker:
meta.local_block_ids[local_block_offset:local_block_offset +
num_blocks_to_pull])
local_block_offset += num_blocks_to_pull
assert local_block_offset == len(meta.local_block_ids), \
f"local_block_offset ({local_block_offset}) should equal with local_block_ids len ({len(meta.local_block_ids)})"
assert self.tp_num_need_pulls == len(remote_handshake_port_list[0]), \
f"tp_num_need_pulls: {self.tp_num_need_pulls}, remote_handshake_port_list: {remote_handshake_port_list}"
return remote_handshake_port_list, local_block_ids_list, remote_block_ids_list
@@ -1302,15 +1494,11 @@ class MooncakeConnectorWorker:
"Num local_block_ids: %s. Num remote_block_ids: %s. ", req_id,
meta.remote_engine_id, len(meta.local_block_ids),
len(meta.remote_block_ids))
if prefill_context_parallel_enable():
if meta.remote_pcp_size * meta.remote_dcp_size > 1:
remote_handshake_port_list, local_block_ids_list, remote_block_ids_list = self._get_kv_split_metadata(
req_id, meta)
for pcp_dcp_rank in range(len(remote_handshake_port_list)):
if len(local_block_ids_list[pcp_dcp_rank]) + len(
remote_block_ids_list[pcp_dcp_rank]) == 0:
continue
for i in range(self.tp_num_need_pulls):
assert self.kv_recv_thread is not None
remote_host, remote_engine_id = self._get_remote_host_info_by_port(
@@ -1329,6 +1517,7 @@ class MooncakeConnectorWorker:
pcp_dcp_rank][i],
offset=i,
tp_num_need_pulls=self.tp_num_need_pulls,
remote_port_send_num=self.remote_port_send_num,
all_task_done=(
pcp_dcp_rank
== len(remote_handshake_port_list) - 1
@@ -1355,7 +1544,7 @@ class MooncakeConnectorWorker:
all_task_done=(i == self.tp_num_need_pulls *
self._prefill_pp_size - 1))
if self.kv_send_thread is not None:
if self.kv_send_thread is not None and self.pcp_size * self.dcp_size == 1:
for req_id, delay_start_time in metadata.requests_to_send.items():
if self.tp_rank in self._prefill_get_remote_rank(req_id):
self.kv_send_thread.add_delayed_request(
@@ -1363,6 +1552,11 @@ class MooncakeConnectorWorker:
else:
self.kv_send_thread.add_not_transfer_request(req_id)
if self.kv_send_thread is not None and self.pcp_size * self.dcp_size > 1:
for req_id, delay_start_time in metadata.requests_to_send.items():
self.kv_send_thread.add_delayed_request(
req_id, delay_start_time)
def _get_remote_host_info_by_port(self, base_port: int,
remote_handshake_port: int,
remote_host: str, remote_engine_id: str,