[feature] Mooncake_connector support pcp/dcp (#4183)

add feature for Mooncake_connector supporting pcp/dcp

- vLLM version: v0.11.0
- vLLM main:
2918c1b49c

---------

Signed-off-by: wangxiaochao <w00642655@china.huawei.com>
Co-authored-by: wangxiaochao <w00642655@china.huawei.com>
This commit is contained in:
wangxiaochao
2025-11-18 10:17:48 +08:00
committed by GitHub
parent 10a046ddce
commit 0d04ad8c8f
2 changed files with 180 additions and 37 deletions

View File

@@ -12,6 +12,7 @@ from unittest.mock import MagicMock, patch
import msgspec import msgspec
import zmq import zmq
from vllm.distributed.parallel_state import GroupCoordinator
from vllm_ascend.utils import vllm_version_is from vllm_ascend.utils import vllm_version_is
@@ -94,7 +95,8 @@ class TestKVCacheSendingThreadInit(unittest.TestCase):
'side_channel_port': 5555, 'side_channel_port': 5555,
'metadata': MagicMock(), 'metadata': MagicMock(),
'ready_event': threading.Event(), 'ready_event': threading.Event(),
'kv_caches': kv_caches 'kv_caches': kv_caches,
'pcp_rank': 0
} }
self.threads = [] self.threads = []
@@ -139,7 +141,8 @@ class TestGetAndClearFinishedRequests(unittest.TestCase):
"test": "metadata" "test": "metadata"
}, },
'ready_event': threading.Event(), 'ready_event': threading.Event(),
'kv_caches': kv_caches 'kv_caches': kv_caches,
'pcp_rank': 0
} }
self.thread = KVCacheSendingThread(**self.common_args) self.thread = KVCacheSendingThread(**self.common_args)
@@ -174,7 +177,8 @@ class TestKVCacheSendingThread(unittest.TestCase):
side_channel_port=free_port, side_channel_port=free_port,
metadata=metadata, metadata=metadata,
ready_event=ready_event, ready_event=ready_event,
kv_caches={}) kv_caches={},
pcp_rank=0)
thread.start() thread.start()
self.assertTrue(ready_event.wait(timeout=3), self.assertTrue(ready_event.wait(timeout=3),
"Server thread startup timeout") "Server thread startup timeout")
@@ -617,7 +621,9 @@ class TestMooncakeConnectorMetadata(unittest.TestCase):
"remote_block_ids": [4, 5, 6], "remote_block_ids": [4, 5, 6],
"remote_engine_id": "remote_engine", "remote_engine_id": "remote_engine",
"remote_host": "localhost", "remote_host": "localhost",
"remote_port": 5000 "remote_port": 5000,
"remote_pcp_size": 1,
"remote_dcp_size": 1
}) })
self.assertEqual(len(meta.requests), 1) self.assertEqual(len(meta.requests), 1)
@@ -663,7 +669,9 @@ class TestMooncakeConnectorSchedulerMatchedTokens(unittest.TestCase):
"remote_block_ids": [1, 2, 3], "remote_block_ids": [1, 2, 3],
"remote_engine_id": "remote", "remote_engine_id": "remote",
"remote_host": "localhost", "remote_host": "localhost",
"remote_port": 5000 "remote_port": 5000,
"remote_pcp_size": 1,
"remote_dcp_size": 1
} }
meta = self.scheduler.build_connector_meta(MagicMock()) meta = self.scheduler.build_connector_meta(MagicMock())
@@ -1018,6 +1026,12 @@ class TestMooncakeConnectorWorker(unittest.TestCase):
self.mock_transfer_engine.get_rpc_port.return_value = 9090 self.mock_transfer_engine.get_rpc_port.return_value = 9090
self.mock_transfer_engine.initialize.return_value = 0 self.mock_transfer_engine.initialize.return_value = 0
self.mock_transfer_engine.register_memory.return_value = 0 self.mock_transfer_engine.register_memory.return_value = 0
self.mock_dcp_group = MagicMock(spec=GroupCoordinator)
self.mock_dcp_group.rank_in_group = 0
self.mock_dcp_group.world_size = 1
self.mock_dcp_group.device_group = MagicMock()
self.mock_dcp = MagicMock()
self.mock_dcp.world_size = 1
self.patches = [ self.patches = [
patch( patch(
@@ -1051,6 +1065,13 @@ class TestMooncakeConnectorWorker(unittest.TestCase):
MagicMock()), MagicMock()),
patch('vllm_ascend.distributed.mooncake_connector.threading.Event', patch('vllm_ascend.distributed.mooncake_connector.threading.Event',
MagicMock()), MagicMock()),
patch('vllm.distributed.parallel_state.get_dcp_group',
return_value=self.mock_dcp_group),
patch('vllm.distributed.parallel_state._DCP',
return_value=self.mock_dcp),
patch(
'vllm.distributed.get_decode_context_model_parallel_world_size',
return_value=1)
] ]
for p in self.patches: for p in self.patches:

View File

@@ -25,8 +25,10 @@ from vllm import envs
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.distributed.kv_transfer.kv_connector.v1.base import ( from vllm.distributed.kv_transfer.kv_connector.v1.base import (
KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole) KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole)
from vllm.distributed.parallel_state import (get_tensor_model_parallel_rank, from vllm.distributed.parallel_state import (
get_tp_group) get_decode_context_model_parallel_rank,
get_decode_context_model_parallel_world_size,
get_tensor_model_parallel_rank, get_tp_group)
from vllm.utils import logger from vllm.utils import logger
from vllm.v1.core.sched.output import SchedulerOutput from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.request import RequestStatus from vllm.v1.request import RequestStatus
@@ -35,7 +37,14 @@ import vllm_ascend.envs as envs_ascend
from vllm_ascend.ascend_config import get_ascend_config, init_ascend_config from vllm_ascend.ascend_config import get_ascend_config, init_ascend_config
from vllm_ascend.distributed.mooncake.transfer_engine import get_global_te from vllm_ascend.distributed.mooncake.transfer_engine import get_global_te
from vllm_ascend.distributed.utils import get_transfer_timeout_value from vllm_ascend.distributed.utils import get_transfer_timeout_value
from vllm_ascend.utils import vllm_version_is from vllm_ascend.utils import prefill_context_parallel_enable, vllm_version_is
# isort: off
if prefill_context_parallel_enable():
from vllm.distributed import (get_prefill_context_model_parallel_rank,
get_prefill_context_model_parallel_world_size
)
# isort: on
if vllm_version_is("0.11.0"): if vllm_version_is("0.11.0"):
from vllm.utils import get_ip, make_zmq_path, make_zmq_socket from vllm.utils import get_ip, make_zmq_path, make_zmq_socket
@@ -66,6 +75,8 @@ class ReqMeta:
remote_host: str remote_host: str
remote_port: int remote_port: int
remote_engine_id: str remote_engine_id: str
remote_pcp_size: int
remote_dcp_size: int
class KVCacheTaskTracker: class KVCacheTaskTracker:
@@ -142,7 +153,7 @@ class KVCacheSendingThread(threading.Thread):
def __init__(self, tp_rank: int, decode_tp_size: int, local_engine_id: str, def __init__(self, tp_rank: int, decode_tp_size: int, local_engine_id: str,
side_channel_host: str, side_channel_port: int, side_channel_host: str, side_channel_port: int,
metadata: MooncakeAgentMetadata, ready_event: threading.Event, metadata: MooncakeAgentMetadata, ready_event: threading.Event,
kv_caches: dict[str, Any]): kv_caches: dict[str, Any], pcp_rank: int):
super().__init__(daemon=True, name="KVCacheSendingThread") super().__init__(daemon=True, name="KVCacheSendingThread")
self.tp_rank = tp_rank self.tp_rank = tp_rank
self.decode_tp_size = decode_tp_size self.decode_tp_size = decode_tp_size
@@ -152,6 +163,7 @@ class KVCacheSendingThread(threading.Thread):
self.metadata = metadata self.metadata = metadata
self.ready_event = ready_event self.ready_event = ready_event
self.kv_caches = kv_caches self.kv_caches = kv_caches
self.pcp_rank = pcp_rank
self.task_tracker = KVCacheTaskTracker() self.task_tracker = KVCacheTaskTracker()
@@ -183,7 +195,8 @@ class KVCacheSendingThread(threading.Thread):
# NOTE(rob): we need each rank to have a unique port. This hack to keeps # NOTE(rob): we need each rank to have a unique port. This hack to keeps
# us moving. We will switch when moving to etcd or where we have a # us moving. We will switch when moving to etcd or where we have a
# single ZMQ socket in the scheduler. # single ZMQ socket in the scheduler.
handshake_port = self.side_channel_port + self.tp_rank handshake_port = self.side_channel_port + self.pcp_rank * self.decode_tp_size \
+ self.tp_rank
path = make_zmq_path("tcp", self.side_channel_host, handshake_port) path = make_zmq_path("tcp", self.side_channel_host, handshake_port)
logger.info("Starting listening on path: %s", path) logger.info("Starting listening on path: %s", path)
with zmq_ctx(zmq.ROUTER, path) as sock: # type: ignore with zmq_ctx(zmq.ROUTER, path) as sock: # type: ignore
@@ -616,6 +629,8 @@ class MooncakeConnectorMetadata(KVConnectorMetadata):
remote_engine_id=kv_transfer_params["remote_engine_id"], remote_engine_id=kv_transfer_params["remote_engine_id"],
remote_host=kv_transfer_params["remote_host"], remote_host=kv_transfer_params["remote_host"],
remote_port=kv_transfer_params["remote_port"], remote_port=kv_transfer_params["remote_port"],
remote_pcp_size=kv_transfer_params["remote_pcp_size"],
remote_dcp_size=kv_transfer_params["remote_dcp_size"],
) )
@@ -713,14 +728,18 @@ class MooncakeConnectorScheduler:
logger.info("Initializing Mooncake Scheduler %s", engine_id) logger.info("Initializing Mooncake Scheduler %s", engine_id)
self.side_channel_host = get_ip() self.side_channel_host = get_ip()
self.pcp_size = vllm_config.parallel_config.prefill_context_parallel_size \
if prefill_context_parallel_enable() else 1
self.dcp_size = vllm_config.parallel_config.decode_context_parallel_size
self.max_device_id = vllm_config.parallel_config.tensor_parallel_size * \ self.max_device_id = vllm_config.parallel_config.tensor_parallel_size * \
vllm_config.parallel_config.data_parallel_size vllm_config.parallel_config.data_parallel_size * \
self.pcp_size
# Handshake base port # Handshake base port
self.side_channel_port = ( self.side_channel_port = (
vllm_config.kv_transfer_config.kv_port + vllm_config.kv_transfer_config.kv_port +
vllm_config.parallel_config.data_parallel_rank * vllm_config.parallel_config.data_parallel_rank *
vllm_config.parallel_config.tensor_parallel_size) vllm_config.parallel_config.tensor_parallel_size * self.pcp_size)
# Requests that need to start recv. # Requests that need to start recv.
# New requests are added by update_state_after_alloc in # New requests are added by update_state_after_alloc in
@@ -848,6 +867,8 @@ class MooncakeConnectorScheduler:
remote_engine_id=self.engine_id, remote_engine_id=self.engine_id,
remote_host=self.side_channel_host, remote_host=self.side_channel_host,
remote_port=self.side_channel_port, remote_port=self.side_channel_port,
remote_pcp_size=self.pcp_size,
remote_dcp_size=self.dcp_size,
last_token_id=request.output_token_ids[-1], last_token_id=request.output_token_ids[-1],
) )
@@ -875,7 +896,15 @@ class MooncakeConnectorWorker:
self.dp_size = vllm_config.parallel_config.data_parallel_size_local self.dp_size = vllm_config.parallel_config.data_parallel_size_local
self.kv_caches: dict[str, torch.Tensor] = {} self.kv_caches: dict[str, torch.Tensor] = {}
self.side_channel_host = get_ip() self.side_channel_host = get_ip()
self.max_device_id = self.tp_size * self.dp_size self.pcp_size = get_prefill_context_model_parallel_world_size(
) if prefill_context_parallel_enable() else 1
self.pcp_rank = get_prefill_context_model_parallel_rank(
) if self.pcp_size > 1 else 0
self.dcp_size = get_decode_context_model_parallel_world_size()
self.dcp_rank = get_decode_context_model_parallel_rank(
) if self.dcp_size > 1 else 0
self.max_device_id = self.tp_size * self.dp_size * self.pcp_size
self.kv_role = vllm_config.kv_transfer_config.kv_role self.kv_role = vllm_config.kv_transfer_config.kv_role
self.num_key_value_heads = self.vllm_config.model_config.hf_config.num_key_value_heads self.num_key_value_heads = self.vllm_config.model_config.hf_config.num_key_value_heads
@@ -883,8 +912,8 @@ class MooncakeConnectorWorker:
self.side_channel_port = ( self.side_channel_port = (
vllm_config.kv_transfer_config.kv_port + vllm_config.kv_transfer_config.kv_port +
vllm_config.parallel_config.data_parallel_rank * vllm_config.parallel_config.data_parallel_rank *
vllm_config.parallel_config.tensor_parallel_size) vllm_config.parallel_config.tensor_parallel_size * self.pcp_size)
self.handshake_port = self.side_channel_port + self.tp_rank self.handshake_port = self.side_channel_port + self.pcp_rank * self.tp_size + self.tp_rank
self.sockets: dict = {} self.sockets: dict = {}
# get tp device id # get tp device id
@@ -893,20 +922,23 @@ class MooncakeConnectorWorker:
device_ids_str = envs_ascend.PHYSICAL_DEVICES device_ids_str = envs_ascend.PHYSICAL_DEVICES
if device_ids_str is None: if device_ids_str is None:
device_ids = list( device_ids = list(
range(self.dp_rank * self.tp_size, range(self.dp_rank * self.tp_size * self.pcp_size,
(self.dp_rank + 1) * self.tp_size)) (self.dp_rank + 1) * self.tp_size * self.pcp_size))
else: else:
device_ids = list(map(int, device_ids_str.split(','))) device_ids = list(map(int, device_ids_str.split(',')))
start_index = self.dp_rank * self.tp_size start_index = self.dp_rank * self.tp_size * self.pcp_size
end_index = start_index + self.tp_size end_index = start_index + self.tp_size * self.pcp_size
if len(device_ids) < end_index: if len(device_ids) < end_index:
raise ValueError( raise ValueError(
f"Not enough physical devices available for DP rank {self.dp_rank}. " f"Not enough physical devices available for DP rank {self.dp_rank}. "
f"Expected at least {end_index} devices, but found {len(device_ids)} " f"Expected at least {end_index} devices, but found {len(device_ids)} "
"in PHYSICAL_DEVICES.") "in PHYSICAL_DEVICES.")
device_ids = device_ids[start_index:end_index] device_ids = device_ids[start_index:end_index]
assert len(device_ids) > self.tp_rank # type: ignore assert len(
self.device_id = device_ids[self.tp_rank] # type: ignore device_ids
) > self.pcp_rank * self.tp_size + self.tp_rank # type: ignore
self.device_id = device_ids[self.pcp_rank * self.tp_size +
self.tp_rank] # type: ignore
if vllm_config.kv_transfer_config.get_from_extra_config( if vllm_config.kv_transfer_config.get_from_extra_config(
'use_ascend_direct', True): 'use_ascend_direct', True):
@@ -1061,7 +1093,7 @@ class MooncakeConnectorWorker:
self.kv_send_thread = KVCacheSendingThread( self.kv_send_thread = KVCacheSendingThread(
self.tp_rank, self._decode_tp_size, self.engine_id, self.tp_rank, self._decode_tp_size, self.engine_id,
self.side_channel_host, self.side_channel_port, metadata, self.side_channel_host, self.side_channel_port, metadata,
ready_event, self.kv_caches) ready_event, self.kv_caches, self.pcp_rank)
self.kv_send_thread.start() self.kv_send_thread.start()
else: else:
self.kv_recv_thread = KVCacheRecvingThread( self.kv_recv_thread = KVCacheRecvingThread(
@@ -1094,6 +1126,92 @@ class MooncakeConnectorWorker:
"requests: %d", len(done_sending), len(done_recving)) "requests: %d", len(done_sending), len(done_recving))
return done_sending, done_recving return done_sending, done_recving
def _get_kv_split_metadata(
self,
req_id: str,
meta: ReqMeta,
) -> tuple[list[list[int]], list[list[int]], list[list[int]]]:
"""
In cp/dcp scenario, kv_cache may be split, so we need to pull multiple blocks from multiple remote P node.
Use this function to calculate remote port and remote block number of each remote P node that we need to pull.
"""
if meta.remote_pcp_size * meta.remote_dcp_size * self.pcp_size * self.dcp_size == 1:
choosen_rank_list = self._get_remote_tp_rank(req_id)
remote_handshake_port_list = [[
x + meta.remote_port for x in choosen_rank_list
]]
local_block_ids_list, remote_block_ids_list = [
meta.local_block_ids
], [meta.remote_block_ids]
return remote_handshake_port_list, local_block_ids_list, remote_block_ids_list
if self.pcp_size == meta.remote_pcp_size and self.dcp_size == meta.remote_dcp_size:
# remote & local cp/dcp are equal, do kv transfer point-to-point
remote_kv_num = 1
remote_ports = [meta.remote_port + self.pcp_rank * self.tp_size + tp_offset \
for tp_offset in range(self.tp_rank, int(self._prefill_tp_size), self.tp_size)]
remote_block_nums = [len(meta.remote_block_ids)]
else:
assert self.pcp_size == 1
if self.use_mla:
assert (self.dcp_size == 1 and (self.tp_size == 1 or self.tp_size == self._prefill_tp_size)) or \
(self.dcp_size == meta.remote_dcp_size and self.tp_size == self._prefill_tp_size)
else:
assert self.tp_size == self._prefill_tp_size and (
self.dcp_size == 1
or self.dcp_size == meta.remote_dcp_size)
# remote & local cp/dcp are not equal, each D node needs to pull from pcp(*dcp) P nodes
# 1. for mla, support D pcp_size = 1, D dcp_size = (1 or P dcp_size)
# 2. for gqa, support D tp_size = P tp_size, D dcp_size = P dcp_size
remote_dcp_size = meta.remote_dcp_size // self.dcp_size
remote_kv_num = meta.remote_pcp_size * remote_dcp_size
cp_dcp_offsets = []
for cp_idx in range(meta.remote_pcp_size):
cp_offset = cp_idx * self._prefill_tp_size
cp_dcp_offsets += list(
range(cp_offset, cp_offset + remote_dcp_size))
tp_offset = self.tp_rank // remote_dcp_size * remote_dcp_size
remote_ports = [meta.remote_port + cp_dcp_offset + tp_offset \
for cp_dcp_offset in cp_dcp_offsets]
# recompute cp/dcp block assign here, maybe we can also pass it from P node meta
local_block_num = len(meta.local_block_ids)
remote_block_nums = [
local_block_num // (meta.remote_pcp_size * remote_dcp_size)
] * meta.remote_pcp_size * remote_dcp_size
num_remain_blocks = local_block_num % (meta.remote_pcp_size *
remote_dcp_size)
for i in range(num_remain_blocks):
remote_block_nums[i] += 1
# make sure the last block (which may be unfull) of P nodes is put to the last block of D node
remote_ports = remote_ports[
num_remain_blocks:] + remote_ports[:num_remain_blocks]
remote_block_nums = remote_block_nums[
num_remain_blocks:] + remote_block_nums[:num_remain_blocks]
remote_handshake_port_list = []
for remote_kv_id in range(remote_kv_num):
remote_handshake_port_list.append([remote_ports[remote_kv_id]])
# the local_block_ids_list and remote_block_ids_list are related with remote_handshake_port_list
# such as: local_block_ids_list[[1],[2],[5],[6]], remote_block_ids_list[[1],[1],[1],[1]],
# remote_handshake_port_list[[30000],[30001],[30004],[30005]]
# D rank will get remote block 1 in port 30004 and save it in local block 5
local_block_ids_list = []
remote_block_ids_list = []
local_block_offset = 0
for remote_kv_id in range(len(remote_handshake_port_list)):
num_blocks_to_pull = remote_block_nums[remote_kv_id]
remote_block_ids_list.append(
meta.remote_block_ids[:num_blocks_to_pull])
local_block_ids_list.append(
meta.local_block_ids[local_block_offset:local_block_offset +
num_blocks_to_pull])
local_block_offset += num_blocks_to_pull
assert local_block_offset == len(meta.local_block_ids), \
f"local_block_offset ({local_block_offset}) should equal with local_block_ids len ({len(meta.local_block_ids)})"
return remote_handshake_port_list, local_block_ids_list, remote_block_ids_list
def start_load_kv(self, metadata: MooncakeConnectorMetadata): def start_load_kv(self, metadata: MooncakeConnectorMetadata):
"""Start loading KV blocks from remote engine.""" """Start loading KV blocks from remote engine."""
for req_id, meta in metadata.requests.items(): for req_id, meta in metadata.requests.items():
@@ -1103,19 +1221,23 @@ class MooncakeConnectorWorker:
meta.remote_engine_id, len(meta.local_block_ids), meta.remote_engine_id, len(meta.local_block_ids),
len(meta.remote_block_ids)) len(meta.remote_block_ids))
choosen_rank_list = self._get_remote_tp_rank(req_id) remote_handshake_port_list, local_block_ids_list, remote_block_ids_list = self._get_kv_split_metadata(
remote_handshake_port_list = [ req_id, meta)
x + meta.remote_port for x in choosen_rank_list
] for pcp_dcp_rank in range(len(remote_handshake_port_list)):
if len(local_block_ids_list[pcp_dcp_rank]) + len(
remote_block_ids_list[pcp_dcp_rank]) == 0:
continue
for i in range(self.num_need_pulls): for i in range(self.num_need_pulls):
assert self.kv_recv_thread is not None assert self.kv_recv_thread is not None
self.kv_recv_thread.add_request( self.kv_recv_thread.add_request(
request_id=req_id, request_id=req_id,
local_block_ids=meta.local_block_ids, local_block_ids=local_block_ids_list[pcp_dcp_rank],
remote_block_ids=meta.remote_block_ids, remote_block_ids=remote_block_ids_list[pcp_dcp_rank],
remote_engine_id=meta.remote_engine_id, remote_engine_id=meta.remote_engine_id,
remote_host=meta.remote_host, remote_host=meta.remote_host,
remote_handshake_port=remote_handshake_port_list[i], remote_handshake_port=remote_handshake_port_list[
pcp_dcp_rank][i],
offset=i, offset=i,
num_need_pulls=self.num_need_pulls) num_need_pulls=self.num_need_pulls)