[feature] Mooncake_connector support pcp/dcp (#4183)

Add prefill context parallel (pcp) and decode context parallel (dcp) support to the Mooncake connector: kv_transfer_params now carry the remote pcp/dcp sizes, handshake ports are derived from (dp_rank, pcp_rank, tp_rank) instead of (dp_rank, tp_rank), and a decode node whose pcp/dcp layout differs from the prefill side pulls its KV blocks from multiple remote prefill ranks.
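
As a rough illustration of the new port layout (all values below are hypothetical; the formulas mirror the scheduler- and worker-side changes in this diff), every (dp_rank, pcp_rank, tp_rank) combination gets a unique handshake port:

# Illustrative sketch only, not part of the diff.
kv_port = 30000           # hypothetical kv_transfer_config.kv_port
tp_size, pcp_size = 2, 2  # hypothetical parallel sizes

for dp_rank in range(2):
    # base port per DP replica, as computed by the scheduler/worker below
    base = kv_port + dp_rank * tp_size * pcp_size
    for pcp_rank in range(pcp_size):
        for tp_rank in range(tp_size):
            port = base + pcp_rank * tp_size + tp_rank
            print(dp_rank, pcp_rank, tp_rank, port)
# dp_rank 0 covers ports 30000-30003, dp_rank 1 covers ports 30004-30007.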

- vLLM version: v0.11.0
- vLLM main: 2918c1b49c

---------

Signed-off-by: wangxiaochao <w00642655@china.huawei.com>
Co-authored-by: wangxiaochao <w00642655@china.huawei.com>
Author: wangxiaochao
Date: 2025-11-18 10:17:48 +08:00
Committed by: GitHub
Parent: 10a046ddce
Commit: 0d04ad8c8f
2 changed files with 180 additions and 37 deletions


@@ -12,6 +12,7 @@ from unittest.mock import MagicMock, patch
 import msgspec
 import zmq
+from vllm.distributed.parallel_state import GroupCoordinator
 from vllm_ascend.utils import vllm_version_is
@@ -94,7 +95,8 @@ class TestKVCacheSendingThreadInit(unittest.TestCase):
             'side_channel_port': 5555,
             'metadata': MagicMock(),
             'ready_event': threading.Event(),
-            'kv_caches': kv_caches
+            'kv_caches': kv_caches,
+            'pcp_rank': 0
         }
         self.threads = []
@@ -139,7 +141,8 @@ class TestGetAndClearFinishedRequests(unittest.TestCase):
"test": "metadata"
},
'ready_event': threading.Event(),
'kv_caches': kv_caches
'kv_caches': kv_caches,
'pcp_rank': 0
}
self.thread = KVCacheSendingThread(**self.common_args)
@@ -174,7 +177,8 @@ class TestKVCacheSendingThread(unittest.TestCase):
             side_channel_port=free_port,
             metadata=metadata,
             ready_event=ready_event,
-            kv_caches={})
+            kv_caches={},
+            pcp_rank=0)
         thread.start()
         self.assertTrue(ready_event.wait(timeout=3),
                         "Server thread startup timeout")
@@ -617,7 +621,9 @@ class TestMooncakeConnectorMetadata(unittest.TestCase):
"remote_block_ids": [4, 5, 6],
"remote_engine_id": "remote_engine",
"remote_host": "localhost",
"remote_port": 5000
"remote_port": 5000,
"remote_pcp_size": 1,
"remote_dcp_size": 1
})
self.assertEqual(len(meta.requests), 1)
@@ -663,7 +669,9 @@ class TestMooncakeConnectorSchedulerMatchedTokens(unittest.TestCase):
"remote_block_ids": [1, 2, 3],
"remote_engine_id": "remote",
"remote_host": "localhost",
"remote_port": 5000
"remote_port": 5000,
"remote_pcp_size": 1,
"remote_dcp_size": 1
}
meta = self.scheduler.build_connector_meta(MagicMock())
@@ -1018,6 +1026,12 @@ class TestMooncakeConnectorWorker(unittest.TestCase):
         self.mock_transfer_engine.get_rpc_port.return_value = 9090
         self.mock_transfer_engine.initialize.return_value = 0
         self.mock_transfer_engine.register_memory.return_value = 0
+        self.mock_dcp_group = MagicMock(spec=GroupCoordinator)
+        self.mock_dcp_group.rank_in_group = 0
+        self.mock_dcp_group.world_size = 1
+        self.mock_dcp_group.device_group = MagicMock()
+        self.mock_dcp = MagicMock()
+        self.mock_dcp.world_size = 1
         self.patches = [
             patch(
@@ -1051,6 +1065,13 @@ class TestMooncakeConnectorWorker(unittest.TestCase):
                   MagicMock()),
             patch('vllm_ascend.distributed.mooncake_connector.threading.Event',
                   MagicMock()),
+            patch('vllm.distributed.parallel_state.get_dcp_group',
+                  return_value=self.mock_dcp_group),
+            patch('vllm.distributed.parallel_state._DCP',
+                  return_value=self.mock_dcp),
+            patch(
+                'vllm.distributed.get_decode_context_model_parallel_world_size',
+                return_value=1)
         ]
         for p in self.patches:


@@ -25,8 +25,10 @@ from vllm import envs
 from vllm.config import VllmConfig
 from vllm.distributed.kv_transfer.kv_connector.v1.base import (
     KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole)
-from vllm.distributed.parallel_state import (get_tensor_model_parallel_rank,
-                                             get_tp_group)
+from vllm.distributed.parallel_state import (
+    get_decode_context_model_parallel_rank,
+    get_decode_context_model_parallel_world_size,
+    get_tensor_model_parallel_rank, get_tp_group)
 from vllm.utils import logger
 from vllm.v1.core.sched.output import SchedulerOutput
 from vllm.v1.request import RequestStatus
@@ -35,7 +37,14 @@ import vllm_ascend.envs as envs_ascend
 from vllm_ascend.ascend_config import get_ascend_config, init_ascend_config
 from vllm_ascend.distributed.mooncake.transfer_engine import get_global_te
 from vllm_ascend.distributed.utils import get_transfer_timeout_value
-from vllm_ascend.utils import vllm_version_is
+from vllm_ascend.utils import prefill_context_parallel_enable, vllm_version_is
+
+# isort: off
+if prefill_context_parallel_enable():
+    from vllm.distributed import (get_prefill_context_model_parallel_rank,
+                                  get_prefill_context_model_parallel_world_size
+                                  )
+# isort: on
 if vllm_version_is("0.11.0"):
     from vllm.utils import get_ip, make_zmq_path, make_zmq_socket
@@ -66,6 +75,8 @@ class ReqMeta:
     remote_host: str
     remote_port: int
     remote_engine_id: str
+    remote_pcp_size: int
+    remote_dcp_size: int


 class KVCacheTaskTracker:
@@ -142,7 +153,7 @@ class KVCacheSendingThread(threading.Thread):
     def __init__(self, tp_rank: int, decode_tp_size: int, local_engine_id: str,
                  side_channel_host: str, side_channel_port: int,
                  metadata: MooncakeAgentMetadata, ready_event: threading.Event,
-                 kv_caches: dict[str, Any]):
+                 kv_caches: dict[str, Any], pcp_rank: int):
         super().__init__(daemon=True, name="KVCacheSendingThread")
         self.tp_rank = tp_rank
         self.decode_tp_size = decode_tp_size
@@ -152,6 +163,7 @@ class KVCacheSendingThread(threading.Thread):
         self.metadata = metadata
         self.ready_event = ready_event
         self.kv_caches = kv_caches
+        self.pcp_rank = pcp_rank
         self.task_tracker = KVCacheTaskTracker()
@@ -183,7 +195,8 @@ class KVCacheSendingThread(threading.Thread):
         # NOTE(rob): we need each rank to have a unique port. This hack to keeps
         # us moving. We will switch when moving to etcd or where we have a
         # single ZMQ socket in the scheduler.
-        handshake_port = self.side_channel_port + self.tp_rank
+        handshake_port = self.side_channel_port + self.pcp_rank * self.decode_tp_size \
+            + self.tp_rank
         path = make_zmq_path("tcp", self.side_channel_host, handshake_port)
         logger.info("Starting listening on path: %s", path)
         with zmq_ctx(zmq.ROUTER, path) as sock:  # type: ignore
@@ -616,6 +629,8 @@ class MooncakeConnectorMetadata(KVConnectorMetadata):
                 remote_engine_id=kv_transfer_params["remote_engine_id"],
                 remote_host=kv_transfer_params["remote_host"],
                 remote_port=kv_transfer_params["remote_port"],
+                remote_pcp_size=kv_transfer_params["remote_pcp_size"],
+                remote_dcp_size=kv_transfer_params["remote_dcp_size"],
             )
@@ -713,14 +728,18 @@ class MooncakeConnectorScheduler:
logger.info("Initializing Mooncake Scheduler %s", engine_id)
self.side_channel_host = get_ip()
self.pcp_size = vllm_config.parallel_config.prefill_context_parallel_size \
if prefill_context_parallel_enable() else 1
self.dcp_size = vllm_config.parallel_config.decode_context_parallel_size
self.max_device_id = vllm_config.parallel_config.tensor_parallel_size * \
vllm_config.parallel_config.data_parallel_size
vllm_config.parallel_config.data_parallel_size * \
self.pcp_size
# Handshake base port
self.side_channel_port = (
vllm_config.kv_transfer_config.kv_port +
vllm_config.parallel_config.data_parallel_rank *
vllm_config.parallel_config.tensor_parallel_size)
vllm_config.parallel_config.tensor_parallel_size * self.pcp_size)
# Requests that need to start recv.
# New requests are added by update_state_after_alloc in
@@ -848,6 +867,8 @@ class MooncakeConnectorScheduler:
                 remote_engine_id=self.engine_id,
                 remote_host=self.side_channel_host,
                 remote_port=self.side_channel_port,
+                remote_pcp_size=self.pcp_size,
+                remote_dcp_size=self.dcp_size,
                 last_token_id=request.output_token_ids[-1],
             )
@@ -875,7 +896,15 @@ class MooncakeConnectorWorker:
         self.dp_size = vllm_config.parallel_config.data_parallel_size_local
         self.kv_caches: dict[str, torch.Tensor] = {}
         self.side_channel_host = get_ip()
-        self.max_device_id = self.tp_size * self.dp_size
+        self.pcp_size = get_prefill_context_model_parallel_world_size(
+        ) if prefill_context_parallel_enable() else 1
+        self.pcp_rank = get_prefill_context_model_parallel_rank(
+        ) if self.pcp_size > 1 else 0
+        self.dcp_size = get_decode_context_model_parallel_world_size()
+        self.dcp_rank = get_decode_context_model_parallel_rank(
+        ) if self.dcp_size > 1 else 0
+        self.max_device_id = self.tp_size * self.dp_size * self.pcp_size
         self.kv_role = vllm_config.kv_transfer_config.kv_role
         self.num_key_value_heads = self.vllm_config.model_config.hf_config.num_key_value_heads
@@ -883,8 +912,8 @@ class MooncakeConnectorWorker:
         self.side_channel_port = (
             vllm_config.kv_transfer_config.kv_port +
             vllm_config.parallel_config.data_parallel_rank *
-            vllm_config.parallel_config.tensor_parallel_size)
-        self.handshake_port = self.side_channel_port + self.tp_rank
+            vllm_config.parallel_config.tensor_parallel_size * self.pcp_size)
+        self.handshake_port = self.side_channel_port + self.pcp_rank * self.tp_size + self.tp_rank
         self.sockets: dict = {}

         # get tp device id
@@ -893,20 +922,23 @@ class MooncakeConnectorWorker:
         device_ids_str = envs_ascend.PHYSICAL_DEVICES
         if device_ids_str is None:
             device_ids = list(
-                range(self.dp_rank * self.tp_size,
-                      (self.dp_rank + 1) * self.tp_size))
+                range(self.dp_rank * self.tp_size * self.pcp_size,
+                      (self.dp_rank + 1) * self.tp_size * self.pcp_size))
         else:
             device_ids = list(map(int, device_ids_str.split(',')))
-            start_index = self.dp_rank * self.tp_size
-            end_index = start_index + self.tp_size
+            start_index = self.dp_rank * self.tp_size * self.pcp_size
+            end_index = start_index + self.tp_size * self.pcp_size
             if len(device_ids) < end_index:
                 raise ValueError(
                     f"Not enough physical devices available for DP rank {self.dp_rank}. "
                     f"Expected at least {end_index} devices, but found {len(device_ids)} "
                     "in PHYSICAL_DEVICES.")
             device_ids = device_ids[start_index:end_index]
-        assert len(device_ids) > self.tp_rank  # type: ignore
-        self.device_id = device_ids[self.tp_rank]  # type: ignore
+        assert len(
+            device_ids
+        ) > self.pcp_rank * self.tp_size + self.tp_rank  # type: ignore
+        self.device_id = device_ids[self.pcp_rank * self.tp_size +
+                                    self.tp_rank]  # type: ignore

         if vllm_config.kv_transfer_config.get_from_extra_config(
                 'use_ascend_direct', True):
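
Condensed, runnable illustration of the device-id selection above (the PHYSICAL_DEVICES value and parallel sizes are hypothetical): each DP replica now owns tp_size * pcp_size consecutive devices, and a rank indexes into its slice with pcp_rank * tp_size + tp_rank.

# Illustrative sketch only, not part of the diff.
physical_devices = "0,1,2,3,4,5,6,7"  # assumed PHYSICAL_DEVICES env value
dp_rank, tp_size, pcp_size = 1, 2, 2
pcp_rank, tp_rank = 1, 0

device_ids = list(map(int, physical_devices.split(',')))
start = dp_rank * tp_size * pcp_size                       # 4
device_ids = device_ids[start:start + tp_size * pcp_size]  # [4, 5, 6, 7]
device_id = device_ids[pcp_rank * tp_size + tp_rank]       # 6
print(device_id)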
@@ -1061,7 +1093,7 @@ class MooncakeConnectorWorker:
             self.kv_send_thread = KVCacheSendingThread(
                 self.tp_rank, self._decode_tp_size, self.engine_id,
                 self.side_channel_host, self.side_channel_port, metadata,
-                ready_event, self.kv_caches)
+                ready_event, self.kv_caches, self.pcp_rank)
             self.kv_send_thread.start()
         else:
             self.kv_recv_thread = KVCacheRecvingThread(
@@ -1094,6 +1126,92 @@ class MooncakeConnectorWorker:
"requests: %d", len(done_sending), len(done_recving))
return done_sending, done_recving
def _get_kv_split_metadata(
self,
req_id: str,
meta: ReqMeta,
) -> tuple[list[list[int]], list[list[int]], list[list[int]]]:
"""
In cp/dcp scenario, kv_cache may be split, so we need to pull multiple blocks from multiple remote P node.
Use this function to calculate remote port and remote block number of each remote P node that we need to pull.
"""
if meta.remote_pcp_size * meta.remote_dcp_size * self.pcp_size * self.dcp_size == 1:
choosen_rank_list = self._get_remote_tp_rank(req_id)
remote_handshake_port_list = [[
x + meta.remote_port for x in choosen_rank_list
]]
local_block_ids_list, remote_block_ids_list = [
meta.local_block_ids
], [meta.remote_block_ids]
return remote_handshake_port_list, local_block_ids_list, remote_block_ids_list
if self.pcp_size == meta.remote_pcp_size and self.dcp_size == meta.remote_dcp_size:
# remote & local cp/dcp are equal, do kv transfer point-to-point
remote_kv_num = 1
remote_ports = [meta.remote_port + self.pcp_rank * self.tp_size + tp_offset \
for tp_offset in range(self.tp_rank, int(self._prefill_tp_size), self.tp_size)]
remote_block_nums = [len(meta.remote_block_ids)]
else:
assert self.pcp_size == 1
if self.use_mla:
assert (self.dcp_size == 1 and (self.tp_size == 1 or self.tp_size == self._prefill_tp_size)) or \
(self.dcp_size == meta.remote_dcp_size and self.tp_size == self._prefill_tp_size)
else:
assert self.tp_size == self._prefill_tp_size and (
self.dcp_size == 1
or self.dcp_size == meta.remote_dcp_size)
# remote & local cp/dcp are not equal, each D node needs to pull from pcp(*dcp) P nodes
# 1. for mla, support D pcp_size = 1, D dcp_size = (1 or P dcp_size)
# 2. for gqa, support D tp_size = P tp_size, D dcp_size = P dcp_size
remote_dcp_size = meta.remote_dcp_size // self.dcp_size
remote_kv_num = meta.remote_pcp_size * remote_dcp_size
cp_dcp_offsets = []
for cp_idx in range(meta.remote_pcp_size):
cp_offset = cp_idx * self._prefill_tp_size
cp_dcp_offsets += list(
range(cp_offset, cp_offset + remote_dcp_size))
tp_offset = self.tp_rank // remote_dcp_size * remote_dcp_size
remote_ports = [meta.remote_port + cp_dcp_offset + tp_offset \
for cp_dcp_offset in cp_dcp_offsets]
# recompute cp/dcp block assign here, maybe we can also pass it from P node meta
local_block_num = len(meta.local_block_ids)
remote_block_nums = [
local_block_num // (meta.remote_pcp_size * remote_dcp_size)
] * meta.remote_pcp_size * remote_dcp_size
num_remain_blocks = local_block_num % (meta.remote_pcp_size *
remote_dcp_size)
for i in range(num_remain_blocks):
remote_block_nums[i] += 1
# make sure the last block (which may be unfull) of P nodes is put to the last block of D node
remote_ports = remote_ports[
num_remain_blocks:] + remote_ports[:num_remain_blocks]
remote_block_nums = remote_block_nums[
num_remain_blocks:] + remote_block_nums[:num_remain_blocks]
remote_handshake_port_list = []
for remote_kv_id in range(remote_kv_num):
remote_handshake_port_list.append([remote_ports[remote_kv_id]])
# the local_block_ids_list and remote_block_ids_list are related with remote_handshake_port_list
# such as: local_block_ids_list[[1],[2],[5],[6]], remote_block_ids_list[[1],[1],[1],[1]],
# remote_handshake_port_list[[30000],[30001],[30004],[30005]]
# D rank will get remote block 1 in port 30004 and save it in local block 5
local_block_ids_list = []
remote_block_ids_list = []
local_block_offset = 0
for remote_kv_id in range(len(remote_handshake_port_list)):
num_blocks_to_pull = remote_block_nums[remote_kv_id]
remote_block_ids_list.append(
meta.remote_block_ids[:num_blocks_to_pull])
local_block_ids_list.append(
meta.local_block_ids[local_block_offset:local_block_offset +
num_blocks_to_pull])
local_block_offset += num_blocks_to_pull
assert local_block_offset == len(meta.local_block_ids), \
f"local_block_offset ({local_block_offset}) should equal with local_block_ids len ({len(meta.local_block_ids)})"
return remote_handshake_port_list, local_block_ids_list, remote_block_ids_list
def start_load_kv(self, metadata: MooncakeConnectorMetadata):
"""Start loading KV blocks from remote engine."""
for req_id, meta in metadata.requests.items():
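
The split logic above can be replayed in isolation. A minimal sketch of the else-branch with hypothetical sizes (remote pcp = dcp = 2, local pcp = dcp = 1, prefill tp_size = 2, tp_rank = 0, ten local blocks), including the rotation that keeps the possibly unfull last prefill block at the end of the decode layout:

# Illustrative sketch only, not part of the diff.
remote_pcp, remote_dcp, local_dcp = 2, 2, 1
prefill_tp_size, tp_rank, remote_port = 2, 0, 30000
local_block_ids = list(range(10))

per_pcp_dcp = remote_dcp // local_dcp   # remote dcp shards per local rank
num_remote = remote_pcp * per_pcp_dcp   # P nodes to pull from
offsets = [cp * prefill_tp_size + d
           for cp in range(remote_pcp) for d in range(per_pcp_dcp)]
tp_offset = tp_rank // per_pcp_dcp * per_pcp_dcp
ports = [remote_port + o + tp_offset for o in offsets]  # [30000..30003]

nums = [len(local_block_ids) // num_remote] * num_remote
rem = len(local_block_ids) % num_remote
for i in range(rem):
    nums[i] += 1                         # [3, 3, 2, 2]
# rotate so the possibly unfull last P block lands in the last local slot
ports = ports[rem:] + ports[:rem]        # [30002, 30003, 30000, 30001]
nums = nums[rem:] + nums[:rem]           # [2, 2, 3, 3]

start = 0
for port, n in zip(ports, nums):
    print(port, '->', local_block_ids[start:start + n])
    start += n
# 30002 -> [0, 1], 30003 -> [2, 3], 30000 -> [4, 5, 6], 30001 -> [7, 8, 9]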
@@ -1103,21 +1221,25 @@ class MooncakeConnectorWorker:
                         meta.remote_engine_id, len(meta.local_block_ids),
                         len(meta.remote_block_ids))
-            choosen_rank_list = self._get_remote_tp_rank(req_id)
-            remote_handshake_port_list = [
-                x + meta.remote_port for x in choosen_rank_list
-            ]
-            for i in range(self.num_need_pulls):
-                assert self.kv_recv_thread is not None
-                self.kv_recv_thread.add_request(
-                    request_id=req_id,
-                    local_block_ids=meta.local_block_ids,
-                    remote_block_ids=meta.remote_block_ids,
-                    remote_engine_id=meta.remote_engine_id,
-                    remote_host=meta.remote_host,
-                    remote_handshake_port=remote_handshake_port_list[i],
-                    offset=i,
-                    num_need_pulls=self.num_need_pulls)
+            remote_handshake_port_list, local_block_ids_list, remote_block_ids_list = self._get_kv_split_metadata(
+                req_id, meta)
+            for pcp_dcp_rank in range(len(remote_handshake_port_list)):
+                if len(local_block_ids_list[pcp_dcp_rank]) + len(
+                        remote_block_ids_list[pcp_dcp_rank]) == 0:
+                    continue
+                for i in range(self.num_need_pulls):
+                    assert self.kv_recv_thread is not None
+                    self.kv_recv_thread.add_request(
+                        request_id=req_id,
+                        local_block_ids=local_block_ids_list[pcp_dcp_rank],
+                        remote_block_ids=remote_block_ids_list[pcp_dcp_rank],
+                        remote_engine_id=meta.remote_engine_id,
+                        remote_host=meta.remote_host,
+                        remote_handshake_port=remote_handshake_port_list[
+                            pcp_dcp_rank][i],
+                        offset=i,
+                        num_need_pulls=self.num_need_pulls)

         if self.kv_send_thread is not None:
             for req_id, delay_start_time in metadata.requests_to_send.items():