[feature] Mooncake_connector support pcp/dcp (#4183)
Add prefill context parallel (pcp) and decode context parallel (dcp) support to the Mooncake connector.
- vLLM version: v0.11.0
- vLLM main: 2918c1b49c
---------
Signed-off-by: wangxiaochao <w00642655@china.huawei.com>
Co-authored-by: wangxiaochao <w00642655@china.huawei.com>
@@ -12,6 +12,7 @@ from unittest.mock import MagicMock, patch
 import msgspec
 import zmq
+from vllm.distributed.parallel_state import GroupCoordinator
 
 from vllm_ascend.utils import vllm_version_is
 
@@ -94,7 +95,8 @@ class TestKVCacheSendingThreadInit(unittest.TestCase):
             'side_channel_port': 5555,
             'metadata': MagicMock(),
             'ready_event': threading.Event(),
-            'kv_caches': kv_caches
+            'kv_caches': kv_caches,
+            'pcp_rank': 0
         }
         self.threads = []
 
@@ -139,7 +141,8 @@ class TestGetAndClearFinishedRequests(unittest.TestCase):
                 "test": "metadata"
             },
             'ready_event': threading.Event(),
-            'kv_caches': kv_caches
+            'kv_caches': kv_caches,
+            'pcp_rank': 0
         }
         self.thread = KVCacheSendingThread(**self.common_args)
 
@@ -174,7 +177,8 @@ class TestKVCacheSendingThread(unittest.TestCase):
             side_channel_port=free_port,
             metadata=metadata,
             ready_event=ready_event,
-            kv_caches={})
+            kv_caches={},
+            pcp_rank=0)
         thread.start()
         self.assertTrue(ready_event.wait(timeout=3),
                         "Server thread startup timeout")
@@ -617,7 +621,9 @@ class TestMooncakeConnectorMetadata(unittest.TestCase):
                 "remote_block_ids": [4, 5, 6],
                 "remote_engine_id": "remote_engine",
                 "remote_host": "localhost",
-                "remote_port": 5000
+                "remote_port": 5000,
+                "remote_pcp_size": 1,
+                "remote_dcp_size": 1
             })
 
         self.assertEqual(len(meta.requests), 1)
@@ -663,7 +669,9 @@ class TestMooncakeConnectorSchedulerMatchedTokens(unittest.TestCase):
             "remote_block_ids": [1, 2, 3],
             "remote_engine_id": "remote",
             "remote_host": "localhost",
-            "remote_port": 5000
+            "remote_port": 5000,
+            "remote_pcp_size": 1,
+            "remote_dcp_size": 1
         }
 
         meta = self.scheduler.build_connector_meta(MagicMock())
@@ -1018,6 +1026,12 @@ class TestMooncakeConnectorWorker(unittest.TestCase):
         self.mock_transfer_engine.get_rpc_port.return_value = 9090
         self.mock_transfer_engine.initialize.return_value = 0
         self.mock_transfer_engine.register_memory.return_value = 0
+        self.mock_dcp_group = MagicMock(spec=GroupCoordinator)
+        self.mock_dcp_group.rank_in_group = 0
+        self.mock_dcp_group.world_size = 1
+        self.mock_dcp_group.device_group = MagicMock()
+        self.mock_dcp = MagicMock()
+        self.mock_dcp.world_size = 1
 
         self.patches = [
             patch(
@@ -1051,6 +1065,13 @@ class TestMooncakeConnectorWorker(unittest.TestCase):
                   MagicMock()),
             patch('vllm_ascend.distributed.mooncake_connector.threading.Event',
                   MagicMock()),
+            patch('vllm.distributed.parallel_state.get_dcp_group',
+                  return_value=self.mock_dcp_group),
+            patch('vllm.distributed.parallel_state._DCP',
+                  return_value=self.mock_dcp),
+            patch(
+                'vllm.distributed.get_decode_context_model_parallel_world_size',
+                return_value=1)
         ]
 
         for p in self.patches:
@@ -25,8 +25,10 @@ from vllm import envs
 from vllm.config import VllmConfig
 from vllm.distributed.kv_transfer.kv_connector.v1.base import (
     KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole)
-from vllm.distributed.parallel_state import (get_tensor_model_parallel_rank,
-                                             get_tp_group)
+from vllm.distributed.parallel_state import (
+    get_decode_context_model_parallel_rank,
+    get_decode_context_model_parallel_world_size,
+    get_tensor_model_parallel_rank, get_tp_group)
 from vllm.utils import logger
 from vllm.v1.core.sched.output import SchedulerOutput
 from vllm.v1.request import RequestStatus
@@ -35,7 +37,14 @@ import vllm_ascend.envs as envs_ascend
 from vllm_ascend.ascend_config import get_ascend_config, init_ascend_config
 from vllm_ascend.distributed.mooncake.transfer_engine import get_global_te
 from vllm_ascend.distributed.utils import get_transfer_timeout_value
-from vllm_ascend.utils import vllm_version_is
+from vllm_ascend.utils import prefill_context_parallel_enable, vllm_version_is
+
+# isort: off
+if prefill_context_parallel_enable():
+    from vllm.distributed import (get_prefill_context_model_parallel_rank,
+                                  get_prefill_context_model_parallel_world_size
+                                  )
+# isort: on
 
 if vllm_version_is("0.11.0"):
     from vllm.utils import get_ip, make_zmq_path, make_zmq_socket
@@ -66,6 +75,8 @@ class ReqMeta:
     remote_host: str
     remote_port: int
     remote_engine_id: str
+    remote_pcp_size: int
+    remote_dcp_size: int
 
 
 class KVCacheTaskTracker:
@@ -142,7 +153,7 @@ class KVCacheSendingThread(threading.Thread):
     def __init__(self, tp_rank: int, decode_tp_size: int, local_engine_id: str,
                  side_channel_host: str, side_channel_port: int,
                  metadata: MooncakeAgentMetadata, ready_event: threading.Event,
-                 kv_caches: dict[str, Any]):
+                 kv_caches: dict[str, Any], pcp_rank: int):
         super().__init__(daemon=True, name="KVCacheSendingThread")
         self.tp_rank = tp_rank
         self.decode_tp_size = decode_tp_size
@@ -152,6 +163,7 @@ class KVCacheSendingThread(threading.Thread):
         self.metadata = metadata
         self.ready_event = ready_event
         self.kv_caches = kv_caches
+        self.pcp_rank = pcp_rank
 
         self.task_tracker = KVCacheTaskTracker()
 
@@ -183,7 +195,8 @@ class KVCacheSendingThread(threading.Thread):
         # NOTE(rob): we need each rank to have a unique port. This hack to keeps
         # us moving. We will switch when moving to etcd or where we have a
         # single ZMQ socket in the scheduler.
-        handshake_port = self.side_channel_port + self.tp_rank
+        handshake_port = self.side_channel_port + self.pcp_rank * self.decode_tp_size \
+            + self.tp_rank
         path = make_zmq_path("tcp", self.side_channel_host, handshake_port)
         logger.info("Starting listening on path: %s", path)
         with zmq_ctx(zmq.ROUTER, path) as sock:  # type: ignore
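On the prefill side, each (pcp_rank, tp_rank) pair now listens on its own handshake port, with pcp ranks spaced decode_tp_size ports apart. A minimal sketch of the arithmetic, using made-up values for the base port and sizes (not taken from this change):

side_channel_port = 6000      # assumed base port, illustration only
decode_tp_size = 2            # assumed decode-side TP size
for pcp_rank in range(2):
    for tp_rank in range(2):  # assuming prefill TP size == decode TP size == 2
        print(side_channel_port + pcp_rank * decode_tp_size + tp_rank)
# prints 6000, 6001, 6002, 6003 -- one unique listening port per (pcp_rank, tp_rank)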
@@ -616,6 +629,8 @@ class MooncakeConnectorMetadata(KVConnectorMetadata):
                 remote_engine_id=kv_transfer_params["remote_engine_id"],
                 remote_host=kv_transfer_params["remote_host"],
                 remote_port=kv_transfer_params["remote_port"],
+                remote_pcp_size=kv_transfer_params["remote_pcp_size"],
+                remote_dcp_size=kv_transfer_params["remote_dcp_size"],
             )
 
 
@@ -713,14 +728,18 @@ class MooncakeConnectorScheduler:
         logger.info("Initializing Mooncake Scheduler %s", engine_id)
 
         self.side_channel_host = get_ip()
+        self.pcp_size = vllm_config.parallel_config.prefill_context_parallel_size \
+            if prefill_context_parallel_enable() else 1
+        self.dcp_size = vllm_config.parallel_config.decode_context_parallel_size
         self.max_device_id = vllm_config.parallel_config.tensor_parallel_size * \
-            vllm_config.parallel_config.data_parallel_size
+            vllm_config.parallel_config.data_parallel_size * \
+            self.pcp_size
 
         # Handshake base port
         self.side_channel_port = (
             vllm_config.kv_transfer_config.kv_port +
             vllm_config.parallel_config.data_parallel_rank *
-            vllm_config.parallel_config.tensor_parallel_size)
+            vllm_config.parallel_config.tensor_parallel_size * self.pcp_size)
 
         # Requests that need to start recv.
         # New requests are added by update_state_after_alloc in
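With pcp, each data-parallel rank now reserves a window of tensor_parallel_size * pcp_size handshake ports above kv_port. A rough sketch of the numbering, with invented port and sizes for illustration:

kv_port = 30000               # assumed value, illustration only
tp_size, pcp_size = 2, 2      # assumed sizes
for dp_rank in range(2):
    base = kv_port + dp_rank * tp_size * pcp_size
    print(dp_rank, base, base + tp_size * pcp_size - 1)
# dp_rank 0 owns ports 30000..30003, dp_rank 1 owns 30004..30007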
@@ -848,6 +867,8 @@ class MooncakeConnectorScheduler:
             remote_engine_id=self.engine_id,
             remote_host=self.side_channel_host,
             remote_port=self.side_channel_port,
+            remote_pcp_size=self.pcp_size,
+            remote_dcp_size=self.dcp_size,
             last_token_id=request.output_token_ids[-1],
         )
 
@@ -875,7 +896,15 @@ class MooncakeConnectorWorker:
         self.dp_size = vllm_config.parallel_config.data_parallel_size_local
         self.kv_caches: dict[str, torch.Tensor] = {}
         self.side_channel_host = get_ip()
-        self.max_device_id = self.tp_size * self.dp_size
+        self.pcp_size = get_prefill_context_model_parallel_world_size(
+        ) if prefill_context_parallel_enable() else 1
+        self.pcp_rank = get_prefill_context_model_parallel_rank(
+        ) if self.pcp_size > 1 else 0
+        self.dcp_size = get_decode_context_model_parallel_world_size()
+        self.dcp_rank = get_decode_context_model_parallel_rank(
+        ) if self.dcp_size > 1 else 0
+
+        self.max_device_id = self.tp_size * self.dp_size * self.pcp_size
         self.kv_role = vllm_config.kv_transfer_config.kv_role
         self.num_key_value_heads = self.vllm_config.model_config.hf_config.num_key_value_heads
 
@@ -883,8 +912,8 @@ class MooncakeConnectorWorker:
         self.side_channel_port = (
             vllm_config.kv_transfer_config.kv_port +
             vllm_config.parallel_config.data_parallel_rank *
-            vllm_config.parallel_config.tensor_parallel_size)
-        self.handshake_port = self.side_channel_port + self.tp_rank
+            vllm_config.parallel_config.tensor_parallel_size * self.pcp_size)
+        self.handshake_port = self.side_channel_port + self.pcp_rank * self.tp_size + self.tp_rank
         self.sockets: dict = {}
 
         # get tp device id
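Inside that window, each worker derives its own handshake port from its pcp and tp rank, so no two workers of the same DP rank collide. A small sketch with assumed sizes:

side_channel_port = 30000     # assumed base of this DP rank's window
tp_size, pcp_size = 2, 2      # assumed sizes
ports = [side_channel_port + pcp_rank * tp_size + tp_rank
         for pcp_rank in range(pcp_size) for tp_rank in range(tp_size)]
print(ports)  # [30000, 30001, 30002, 30003], unique per (pcp_rank, tp_rank)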
@@ -893,20 +922,23 @@ class MooncakeConnectorWorker:
         device_ids_str = envs_ascend.PHYSICAL_DEVICES
         if device_ids_str is None:
             device_ids = list(
-                range(self.dp_rank * self.tp_size,
-                      (self.dp_rank + 1) * self.tp_size))
+                range(self.dp_rank * self.tp_size * self.pcp_size,
+                      (self.dp_rank + 1) * self.tp_size * self.pcp_size))
         else:
             device_ids = list(map(int, device_ids_str.split(',')))
-            start_index = self.dp_rank * self.tp_size
-            end_index = start_index + self.tp_size
+            start_index = self.dp_rank * self.tp_size * self.pcp_size
+            end_index = start_index + self.tp_size * self.pcp_size
             if len(device_ids) < end_index:
                 raise ValueError(
                     f"Not enough physical devices available for DP rank {self.dp_rank}. "
                     f"Expected at least {end_index} devices, but found {len(device_ids)} "
                     "in PHYSICAL_DEVICES.")
             device_ids = device_ids[start_index:end_index]
-        assert len(device_ids) > self.tp_rank  # type: ignore
-        self.device_id = device_ids[self.tp_rank]  # type: ignore
+        assert len(
+            device_ids
+        ) > self.pcp_rank * self.tp_size + self.tp_rank  # type: ignore
+        self.device_id = device_ids[self.pcp_rank * self.tp_size +
+                                    self.tp_rank]  # type: ignore
 
         if vllm_config.kv_transfer_config.get_from_extra_config(
                 'use_ascend_direct', True):
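Device selection follows the same layout: each DP rank takes a slice of tp_size * pcp_size physical devices, and a worker indexes into that slice with pcp_rank * tp_size + tp_rank. A toy example where the device list and sizes are assumptions, not values from this change:

device_ids = [0, 1, 2, 3, 4, 5, 6, 7]   # pretend PHYSICAL_DEVICES
tp_size, pcp_size = 2, 2
dp_rank, pcp_rank, tp_rank = 1, 1, 0
start = dp_rank * tp_size * pcp_size                      # 4
dp_slice = device_ids[start:start + tp_size * pcp_size]   # [4, 5, 6, 7]
device_id = dp_slice[pcp_rank * tp_size + tp_rank]        # 6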
@@ -1061,7 +1093,7 @@ class MooncakeConnectorWorker:
             self.kv_send_thread = KVCacheSendingThread(
                 self.tp_rank, self._decode_tp_size, self.engine_id,
                 self.side_channel_host, self.side_channel_port, metadata,
-                ready_event, self.kv_caches)
+                ready_event, self.kv_caches, self.pcp_rank)
             self.kv_send_thread.start()
         else:
             self.kv_recv_thread = KVCacheRecvingThread(
@@ -1094,6 +1126,92 @@ class MooncakeConnectorWorker:
                     "requests: %d", len(done_sending), len(done_recving))
         return done_sending, done_recving
 
+    def _get_kv_split_metadata(
+        self,
+        req_id: str,
+        meta: ReqMeta,
+    ) -> tuple[list[list[int]], list[list[int]], list[list[int]]]:
+        """
+        In the pcp/dcp scenario the kv_cache may be split, so we may need to pull blocks from multiple remote P nodes.
+        This function calculates the remote port and the number of remote blocks to pull for each remote P node.
+        """
+        if meta.remote_pcp_size * meta.remote_dcp_size * self.pcp_size * self.dcp_size == 1:
+            choosen_rank_list = self._get_remote_tp_rank(req_id)
+            remote_handshake_port_list = [[
+                x + meta.remote_port for x in choosen_rank_list
+            ]]
+            local_block_ids_list, remote_block_ids_list = [
+                meta.local_block_ids
+            ], [meta.remote_block_ids]
+            return remote_handshake_port_list, local_block_ids_list, remote_block_ids_list
+
+        if self.pcp_size == meta.remote_pcp_size and self.dcp_size == meta.remote_dcp_size:
+            # remote & local pcp/dcp are equal, do the kv transfer point-to-point
+            remote_kv_num = 1
+            remote_ports = [meta.remote_port + self.pcp_rank * self.tp_size + tp_offset \
+                for tp_offset in range(self.tp_rank, int(self._prefill_tp_size), self.tp_size)]
+            remote_block_nums = [len(meta.remote_block_ids)]
+        else:
+            assert self.pcp_size == 1
+            if self.use_mla:
+                assert (self.dcp_size == 1 and (self.tp_size == 1 or self.tp_size == self._prefill_tp_size)) or \
+                    (self.dcp_size == meta.remote_dcp_size and self.tp_size == self._prefill_tp_size)
+            else:
+                assert self.tp_size == self._prefill_tp_size and (
+                    self.dcp_size == 1
+                    or self.dcp_size == meta.remote_dcp_size)
+            # remote & local pcp/dcp are not equal, so each D node needs to pull from pcp(*dcp) P nodes
+            # 1. for mla, support D pcp_size = 1, D dcp_size = (1 or P dcp_size)
+            # 2. for gqa, support D tp_size = P tp_size, D dcp_size = P dcp_size
+            remote_dcp_size = meta.remote_dcp_size // self.dcp_size
+            remote_kv_num = meta.remote_pcp_size * remote_dcp_size
+            cp_dcp_offsets = []
+            for cp_idx in range(meta.remote_pcp_size):
+                cp_offset = cp_idx * self._prefill_tp_size
+                cp_dcp_offsets += list(
+                    range(cp_offset, cp_offset + remote_dcp_size))
+            tp_offset = self.tp_rank // remote_dcp_size * remote_dcp_size
+            remote_ports = [meta.remote_port + cp_dcp_offset + tp_offset \
+                for cp_dcp_offset in cp_dcp_offsets]
+            # recompute the pcp/dcp block assignment here; alternatively it could be passed along in the P node meta
+            local_block_num = len(meta.local_block_ids)
+            remote_block_nums = [
+                local_block_num // (meta.remote_pcp_size * remote_dcp_size)
+            ] * meta.remote_pcp_size * remote_dcp_size
+            num_remain_blocks = local_block_num % (meta.remote_pcp_size *
+                                                   remote_dcp_size)
+            for i in range(num_remain_blocks):
+                remote_block_nums[i] += 1
+            # make sure the last block of the P nodes (which may not be full) lands in the last block of the D node
+            remote_ports = remote_ports[
+                num_remain_blocks:] + remote_ports[:num_remain_blocks]
+            remote_block_nums = remote_block_nums[
+                num_remain_blocks:] + remote_block_nums[:num_remain_blocks]
+
+        remote_handshake_port_list = []
+        for remote_kv_id in range(remote_kv_num):
+            remote_handshake_port_list.append([remote_ports[remote_kv_id]])
+
+        # local_block_ids_list and remote_block_ids_list are aligned with remote_handshake_port_list,
+        # e.g. local_block_ids_list [[1],[2],[5],[6]], remote_block_ids_list [[1],[1],[1],[1]],
+        # remote_handshake_port_list [[30000],[30001],[30004],[30005]]:
+        # the D rank pulls remote block 1 via port 30004 and stores it in local block 5
+        local_block_ids_list = []
+        remote_block_ids_list = []
+        local_block_offset = 0
+        for remote_kv_id in range(len(remote_handshake_port_list)):
+            num_blocks_to_pull = remote_block_nums[remote_kv_id]
+            remote_block_ids_list.append(
+                meta.remote_block_ids[:num_blocks_to_pull])
+            local_block_ids_list.append(
+                meta.local_block_ids[local_block_offset:local_block_offset +
+                                     num_blocks_to_pull])
+            local_block_offset += num_blocks_to_pull
+        assert local_block_offset == len(meta.local_block_ids), \
+            f"local_block_offset ({local_block_offset}) should equal with local_block_ids len ({len(meta.local_block_ids)})"
+
+        return remote_handshake_port_list, local_block_ids_list, remote_block_ids_list
+
     def start_load_kv(self, metadata: MooncakeConnectorMetadata):
         """Start loading KV blocks from remote engine."""
         for req_id, meta in metadata.requests.items():
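The block assignment in the asymmetric branch above is an even split with remainder, followed by a rotation so the remotes that receive an extra block, including the one holding the possibly not-full last prefill block, end up last on the decode side. A standalone sketch of just that distribution step (function name and values are illustrative, not part of the connector API):

def split_block_counts(local_block_num: int, num_remotes: int) -> list[int]:
    # Even split; the first `remainder` remotes get one extra block.
    counts = [local_block_num // num_remotes] * num_remotes
    remainder = local_block_num % num_remotes
    for i in range(remainder):
        counts[i] += 1
    # Rotate so the remotes with the extra (trailing) blocks come last.
    return counts[remainder:] + counts[:remainder]

print(split_block_counts(10, 4))  # [2, 2, 3, 3]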
@@ -1103,21 +1221,25 @@ class MooncakeConnectorWorker:
                 meta.remote_engine_id, len(meta.local_block_ids),
                 len(meta.remote_block_ids))
 
-            choosen_rank_list = self._get_remote_tp_rank(req_id)
-            remote_handshake_port_list = [
-                x + meta.remote_port for x in choosen_rank_list
-            ]
-            for i in range(self.num_need_pulls):
-                assert self.kv_recv_thread is not None
-                self.kv_recv_thread.add_request(
-                    request_id=req_id,
-                    local_block_ids=meta.local_block_ids,
-                    remote_block_ids=meta.remote_block_ids,
-                    remote_engine_id=meta.remote_engine_id,
-                    remote_host=meta.remote_host,
-                    remote_handshake_port=remote_handshake_port_list[i],
-                    offset=i,
-                    num_need_pulls=self.num_need_pulls)
+            remote_handshake_port_list, local_block_ids_list, remote_block_ids_list = self._get_kv_split_metadata(
+                req_id, meta)
+
+            for pcp_dcp_rank in range(len(remote_handshake_port_list)):
+                if len(local_block_ids_list[pcp_dcp_rank]) + len(
+                        remote_block_ids_list[pcp_dcp_rank]) == 0:
+                    continue
+                for i in range(self.num_need_pulls):
+                    assert self.kv_recv_thread is not None
+                    self.kv_recv_thread.add_request(
+                        request_id=req_id,
+                        local_block_ids=local_block_ids_list[pcp_dcp_rank],
+                        remote_block_ids=remote_block_ids_list[pcp_dcp_rank],
+                        remote_engine_id=meta.remote_engine_id,
+                        remote_host=meta.remote_host,
+                        remote_handshake_port=remote_handshake_port_list[
+                            pcp_dcp_rank][i],
+                        offset=i,
+                        num_need_pulls=self.num_need_pulls)
 
         if self.kv_send_thread is not None:
             for req_id, delay_start_time in metadata.requests_to_send.items():