[feature] mooncake support pcp/dcp in common conditions (#5224)
### What this PR does / why we need it?
1. This PR is proposed to support complicated pcp/dcp parallelisms in
Prefill and Decode nodes in Mooncake, such as Prefill: TP8/PCP2DCP8 and
Decode: TP8/DCP4/DP2, which is not supported now. We establish the link
mappings to transfer KVCache between prefill and decode nodes. The main
function is realized in Function of `_get_kv_split_metadata` in
Mooncake_connector.py
2. After a prefill rank is pulled KVCache by a decode rank, the decode
rank will send `DONE_RECVING_MSG` to the prefill rank and the prefill
rank will free its KVCache blocks. If a prefill rank is pulled KVCache
more than one time by several decode ranks and it surely could happen in
complicated pcp/dcp parallelisms, it will cause the prefill rank free
its KVCache blocks for several times, which could cause memory issue.
This PR solve this issue by counting the times of prefill rank would be
pulled KVCache and in the last time, it will free the prefill rank
KVCache blocks. The related code is in Function of `run_busy_loop` in
Mooncake_connector.py
3. If a prefill rank is not pulled KVCache by any decode ranks, the
first rank in decode node will send "DONE_RECVING_MSG" to free its
blocks. The related code is in Function of
`_send_done_signal_to_free_remote_port` in Mooncake_connector.py
### How was this patch tested?
This PR is tested in many pcp/dcp parallelisms, and the accuracy are all
correct.
MLA model:
Prefill node: TP8/DP2, Decode node: TP8/DP2
Prefill node: TP8/PCP2/DCP8, Decode node: TP8/DP2
Prefill node: TP8/PCP2/DCP8, Decode node: TP8/DCP4/DP2
Prefill node: TP8/PCP2/DCP4, Decode node: TP4/DCP2/DP4
Prefill node: TP8/PCP2/DCP2, Decode node: TP4/DCP4/DP4
Prefill node: TP8/PCP2, Decode node: TP4/DCP2
GQA model:
Prefill node: TP8/DP2, Decode node: TP8/DP2
Prefill node: TP8/PCP2/DCP2, Decode node: TP8/DP2
Prefill node: TP8/PCP2/DCP2, Decode node: TP8/DCP2/DP2
Prefill node: TP8/PCP2/DCP2, Decode node: TP4/DP4
Prefill node: TP16/DCP2/PCP1, Decode node: TP8/DCP2/DP2
- vLLM version: release/v0.13.0
- vLLM main:
ad32e3e19c
- Co-author by: Daishixun dsxtsteven@sina.com
---------
Signed-off-by: wangxiaochao <w00642655@china.huawei.com>
Co-authored-by: wangxiaochao <w00642655@china.huawei.com>
Co-authored-by: Jade Zheng <zheng.shoujian@outlook.com>
This commit is contained in:
@@ -241,6 +241,7 @@ class TestKVCacheRecvingThreadBasic(unittest.TestCase):
|
||||
engine=self.engine,
|
||||
local_engine_id="local_engine",
|
||||
local_handshake_port=5555,
|
||||
side_channel_port=30000,
|
||||
local_kv_caches_base_addr=[0x1000, 0x2000],
|
||||
block_len=[1024, 2048],
|
||||
ready_event=self.ready_event,
|
||||
@@ -294,6 +295,7 @@ class TestSocketManagement(unittest.TestCase):
|
||||
engine=self.engine,
|
||||
local_engine_id="local_engine",
|
||||
local_handshake_port=5555,
|
||||
side_channel_port=30000,
|
||||
local_kv_caches_base_addr=[0x1000, 0x2000],
|
||||
block_len=[1024, 2048],
|
||||
ready_event=self.ready_event,
|
||||
@@ -351,6 +353,7 @@ class TestCoreFunctionality(unittest.TestCase):
|
||||
engine=self.engine,
|
||||
local_engine_id="local_engine",
|
||||
local_handshake_port=5555,
|
||||
side_channel_port=30000,
|
||||
local_kv_caches_base_addr=[0x1000, 0x2000],
|
||||
block_len=[1024, 2048],
|
||||
ready_event=self.ready_event,
|
||||
@@ -367,6 +370,9 @@ class TestCoreFunctionality(unittest.TestCase):
|
||||
"remote_transfer_port": 7777,
|
||||
"offset": 0,
|
||||
"tp_num_need_pulls": 2,
|
||||
"remote_port_send_num": {
|
||||
6666: 1
|
||||
},
|
||||
"all_task_done": False
|
||||
}
|
||||
self.thread.task_tracker = MagicMock()
|
||||
@@ -382,7 +388,7 @@ class TestCoreFunctionality(unittest.TestCase):
|
||||
self.thread._handle_request(self.test_req)
|
||||
|
||||
mock_transfer.assert_called_once_with(self.test_req)
|
||||
mock_send.assert_called_once_with("req1", "localhost", 6666)
|
||||
mock_send.assert_called_once_with("req1", "localhost", 6666, {6666: 1})
|
||||
if not self.thread.task_tracker.update_done_task_count.called:
|
||||
self.thread.task_tracker.update_done_task_count("req1")
|
||||
self.thread.task_tracker.update_done_task_count.assert_called_once_with(
|
||||
@@ -433,6 +439,7 @@ class TestMetadataHandling(unittest.TestCase):
|
||||
engine=self.engine,
|
||||
local_engine_id="local_engine",
|
||||
local_handshake_port=5555,
|
||||
side_channel_port=30000,
|
||||
local_kv_caches_base_addr=[0x1000, 0x2000],
|
||||
block_len=[1024, 2048],
|
||||
ready_event=self.ready_event,
|
||||
@@ -497,6 +504,7 @@ class TestMainThreadLoop(unittest.TestCase):
|
||||
engine=self.engine,
|
||||
local_engine_id="local_engine",
|
||||
local_handshake_port=5555,
|
||||
side_channel_port=30000,
|
||||
local_kv_caches_base_addr=[0x1000, 0x2000],
|
||||
block_len=[1024, 2048],
|
||||
ready_event=self.ready_event,
|
||||
@@ -654,6 +662,7 @@ class TestMooncakeConnectorMetadata(unittest.TestCase):
|
||||
|
||||
meta.add_new_req(request_id="req1",
|
||||
local_block_ids=[1, 2, 3],
|
||||
num_external_tokens=48,
|
||||
kv_transfer_params={
|
||||
"remote_block_ids": [4, 5, 6],
|
||||
"remote_engine_id": "remote_engine",
|
||||
@@ -706,7 +715,7 @@ class TestMooncakeConnectorSchedulerMatchedTokens(unittest.TestCase):
|
||||
request = MockRequest("req1")
|
||||
blocks_mock = MagicMock()
|
||||
blocks_mock.get_unhashed_block_ids.return_value = [4, 5, 6]
|
||||
self.scheduler._reqs_need_recv["req1"] = (request, [4, 5, 6])
|
||||
self.scheduler._reqs_need_recv["req1"] = (request, [4, 5, 6], 48)
|
||||
request.kv_transfer_params = {
|
||||
"remote_block_ids": [1, 2, 3],
|
||||
"remote_engine_id": "remote",
|
||||
@@ -1278,6 +1287,8 @@ class TestMooncakeConnectorWorker(unittest.TestCase):
|
||||
worker.pcp_rank = pcp_rank
|
||||
worker._prefill_tp_size = _prefill_tp_size
|
||||
worker.local_remote_block_port_mapping = None
|
||||
worker.block_size = 16
|
||||
worker.num_key_value_heads = 1
|
||||
|
||||
meta = types.SimpleNamespace()
|
||||
|
||||
@@ -1286,6 +1297,9 @@ class TestMooncakeConnectorWorker(unittest.TestCase):
|
||||
meta.remote_port = remote_port
|
||||
meta.remote_block_ids = remote_block_ids
|
||||
meta.local_block_ids = local_block_ids
|
||||
meta.num_external_tokens = pcp_size * dcp_size * len(
|
||||
local_block_ids) * worker.block_size
|
||||
meta.num_prompt_blocks = pcp_size * dcp_size * len(local_block_ids)
|
||||
|
||||
remote_handshake_port_list, local_block_ids_list, remote_block_ids_list = worker._get_kv_split_metadata(
|
||||
'0', meta)
|
||||
@@ -1324,17 +1338,17 @@ class TestMooncakeConnectorWorker(unittest.TestCase):
|
||||
self.assertEqual(
|
||||
get_kv_split_metadata(True, 1, 2, 8, 1, 0, 8, 2, 2, 30000, [1],
|
||||
[1]),
|
||||
([[30009], [30001]], [[], [1]], [[], [1]]))
|
||||
([[30000], [30008]], [[1], []], [[1], []]))
|
||||
|
||||
self.assertEqual(
|
||||
get_kv_split_metadata(False, 1, 2, 8, 1, 0, 8, 2, 2, 30000, [1],
|
||||
[1]),
|
||||
([[30009], [30001]], [[], [1]], [[], [1]]))
|
||||
([[30000], [30008]], [[1], []], [[1], []]))
|
||||
|
||||
self.assertEqual(
|
||||
get_kv_split_metadata(True, 1, 2, 8, 0, 0, 8, 2, 2, 30000,
|
||||
[1, 2, 3], [1, 2, 3, 4, 5]),
|
||||
([[30008], [30000]], [[1, 2], [3, 4, 5]], [[1, 2], [1, 2, 3]]))
|
||||
([[30000], [30008]], [[1, 2, 3], [4, 5]], [[1, 2, 3], [1, 2]]))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
import contextlib
|
||||
import copy
|
||||
import hashlib
|
||||
import math
|
||||
import os
|
||||
@@ -42,7 +43,7 @@ from vllm.v1.request import RequestStatus
|
||||
from vllm_ascend.ascend_config import get_ascend_config, init_ascend_config
|
||||
from vllm_ascend.distributed.mooncake_transfer_engine import global_te
|
||||
from vllm_ascend.distributed.utils import get_transfer_timeout_value
|
||||
from vllm_ascend.utils import is_vl_model, prefill_context_parallel_enable
|
||||
from vllm_ascend.utils import is_vl_model
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.attention.backends.abstract import AttentionMetadata
|
||||
@@ -65,6 +66,7 @@ class MooncakeAgentMetadata(msgspec.Struct, omit_defaults=True, dict=True):
|
||||
@dataclass
|
||||
class ReqMeta:
|
||||
local_block_ids: list[int]
|
||||
num_external_tokens: int
|
||||
remote_block_ids: list[int]
|
||||
remote_host: str
|
||||
remote_port: int
|
||||
@@ -72,6 +74,7 @@ class ReqMeta:
|
||||
remote_pcp_size: int
|
||||
remote_dcp_size: int
|
||||
remote_multi_nodes_meta_mapping: dict[str, dict[str, Any]]
|
||||
num_prompt_blocks: int
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -184,6 +187,7 @@ class KVCacheSendingThread(threading.Thread):
|
||||
self.ready_event = ready_event
|
||||
self.kv_caches = kv_caches
|
||||
self.pcp_rank = pcp_rank
|
||||
self.port_send_num: dict[str, int] = {}
|
||||
|
||||
self.task_tracker = KVCacheTaskTracker()
|
||||
|
||||
@@ -248,7 +252,22 @@ class KVCacheSendingThread(threading.Thread):
|
||||
elif msg[0] == DONE_RECVING_MSG:
|
||||
logger.debug("Got DONE_RECVING_MSG for request %s", msg[1])
|
||||
request_id = msg[1]
|
||||
self.task_tracker.update_done_task_count(request_id)
|
||||
remote_port_send_num = msg[2]
|
||||
if remote_port_send_num:
|
||||
if request_id not in self.port_send_num:
|
||||
self.port_send_num[request_id] = 0
|
||||
self.port_send_num[request_id] += 1
|
||||
device_index = self.pp_rank * self.tp_size + \
|
||||
self.tp_rank + self.pcp_rank * \
|
||||
self.prefill_tp_size
|
||||
handshake_port = self.side_channel_port + device_index
|
||||
if self.port_send_num[request_id] >= \
|
||||
remote_port_send_num[handshake_port]:
|
||||
self.task_tracker.update_done_task_count(
|
||||
request_id)
|
||||
del self.port_send_num[request_id]
|
||||
else:
|
||||
self.task_tracker.update_done_task_count(request_id)
|
||||
# Acknowledge the request completion.
|
||||
while True:
|
||||
try:
|
||||
@@ -275,7 +294,7 @@ class KVCacheRecvingThread(threading.Thread):
|
||||
|
||||
def __init__(self, tp_rank: int, tp_size: int, _prefill_pp_size: int,
|
||||
engine: TransferEngine, local_engine_id: str,
|
||||
local_handshake_port: int,
|
||||
local_handshake_port: int, side_channel_port: int,
|
||||
local_kv_caches_base_addr: list[int], block_len: list[int],
|
||||
ready_event: threading.Event, vllm_config: VllmConfig,
|
||||
kv_caches: dict[str, Any]):
|
||||
@@ -285,6 +304,7 @@ class KVCacheRecvingThread(threading.Thread):
|
||||
self._prefill_pp_size = _prefill_pp_size
|
||||
self.local_engine_id = local_engine_id
|
||||
self.local_handshake_port = local_handshake_port
|
||||
self.side_channel_port = side_channel_port
|
||||
self.engine = engine
|
||||
self.ready_event = ready_event
|
||||
|
||||
@@ -328,11 +348,19 @@ class KVCacheRecvingThread(threading.Thread):
|
||||
self.num_kv_heads = max(
|
||||
self.model_config.hf_text_config.num_key_value_heads //
|
||||
self.tp_size, 1)
|
||||
self.proc_not_transfer_request: dict[str, bool] = {}
|
||||
|
||||
def add_request(self, request_id: str, local_block_ids: list[int],
|
||||
remote_block_ids: list[int], remote_engine_id: str,
|
||||
remote_host: str, remote_handshake_port: int, offset: int,
|
||||
tp_num_need_pulls: int, all_task_done: bool):
|
||||
def add_request(self,
|
||||
request_id: str,
|
||||
local_block_ids: list[int],
|
||||
remote_block_ids: list[int],
|
||||
remote_engine_id: str,
|
||||
remote_host: str,
|
||||
remote_handshake_port: int,
|
||||
offset: int,
|
||||
tp_num_need_pulls: int,
|
||||
remote_port_send_num: dict[int, int] = {},
|
||||
all_task_done: bool = False):
|
||||
"""Add a new request to the queue for processing."""
|
||||
logger.debug(f"Adding request {request_id} to the queue.")
|
||||
self.request_queue.put({
|
||||
@@ -344,6 +372,7 @@ class KVCacheRecvingThread(threading.Thread):
|
||||
"remote_handshake_port": remote_handshake_port,
|
||||
"offset": offset,
|
||||
"tp_num_need_pulls": tp_num_need_pulls,
|
||||
"remote_port_send_num": remote_port_send_num,
|
||||
"all_task_done": all_task_done
|
||||
})
|
||||
|
||||
@@ -373,6 +402,7 @@ class KVCacheRecvingThread(threading.Thread):
|
||||
request_id = req_meta["request_id"]
|
||||
remote_host = req_meta["remote_host"]
|
||||
remote_handshake_port = req_meta["remote_handshake_port"]
|
||||
remote_port_send_num = req_meta["remote_port_send_num"]
|
||||
all_task_done = req_meta["all_task_done"]
|
||||
|
||||
try:
|
||||
@@ -389,11 +419,31 @@ class KVCacheRecvingThread(threading.Thread):
|
||||
# resource cleanup. Failing to do so may cause a memory leak on the
|
||||
# remote host.
|
||||
self._send_done_recv_signal(request_id, remote_host,
|
||||
remote_handshake_port)
|
||||
remote_handshake_port,
|
||||
remote_port_send_num)
|
||||
self._send_done_signal_to_free_remote_port(request_id, remote_host,
|
||||
remote_port_send_num)
|
||||
if all_task_done:
|
||||
self.task_tracker.update_done_task_count(request_id)
|
||||
if request_id in self.proc_not_transfer_request:
|
||||
del self.proc_not_transfer_request[request_id]
|
||||
self.request_queue.task_done()
|
||||
|
||||
def _send_done_signal_to_free_remote_port(self, request_id, remote_host,
|
||||
remote_port_send_num):
|
||||
if self.side_channel_port != self.local_handshake_port \
|
||||
or not remote_port_send_num:
|
||||
return
|
||||
if request_id not in self.proc_not_transfer_request:
|
||||
self.proc_not_transfer_request[request_id] = True
|
||||
if self.proc_not_transfer_request[request_id]:
|
||||
for remote_port in remote_port_send_num.keys():
|
||||
if remote_port_send_num[remote_port] == 0:
|
||||
self._send_done_recv_signal(request_id, remote_host,
|
||||
remote_port,
|
||||
remote_port_send_num)
|
||||
self.proc_not_transfer_request[request_id] = False
|
||||
|
||||
def _transfer_kv_cache(self, req_meta: dict[str, Any]):
|
||||
"""Handle a KV cache transfer request."""
|
||||
request_id = req_meta["request_id"]
|
||||
@@ -606,13 +656,15 @@ class KVCacheRecvingThread(threading.Thread):
|
||||
remote_handshake_port)
|
||||
|
||||
def _send_done_recv_signal(self, request_id: str, remote_host: str,
|
||||
remote_handshake_port: int):
|
||||
remote_handshake_port: int,
|
||||
remote_port_send_num: dict[int, int]):
|
||||
logger.debug("Sending done recving signal for request %s to %s:%d",
|
||||
request_id, remote_host, remote_handshake_port)
|
||||
sock: Optional[zmq.Socket] = None # type: ignore
|
||||
try:
|
||||
sock = self._get_remote_socket(remote_host, remote_handshake_port)
|
||||
data_bytes = self.encoder.encode((DONE_RECVING_MSG, request_id))
|
||||
data_bytes = self.encoder.encode(
|
||||
(DONE_RECVING_MSG, request_id, remote_port_send_num))
|
||||
ensure_zmq_send(sock, data_bytes)
|
||||
resp = ensure_zmq_recv(sock,
|
||||
self.remote_poller,
|
||||
@@ -674,10 +726,12 @@ class MooncakeConnectorMetadata(KVConnectorMetadata):
|
||||
self,
|
||||
request_id: str,
|
||||
local_block_ids: list[int],
|
||||
num_external_tokens: int,
|
||||
kv_transfer_params: dict[str, Any],
|
||||
):
|
||||
self.requests[request_id] = ReqMeta(
|
||||
local_block_ids=local_block_ids,
|
||||
num_external_tokens=num_external_tokens,
|
||||
remote_block_ids=kv_transfer_params["remote_block_ids"],
|
||||
remote_engine_id=kv_transfer_params["remote_engine_id"],
|
||||
remote_host=kv_transfer_params["remote_host"],
|
||||
@@ -686,6 +740,7 @@ class MooncakeConnectorMetadata(KVConnectorMetadata):
|
||||
remote_dcp_size=kv_transfer_params.get("remote_dcp_size", 1),
|
||||
remote_multi_nodes_meta_mapping=kv_transfer_params.get(
|
||||
"remote_multi_nodes_meta_mapping", {}),
|
||||
num_prompt_blocks=kv_transfer_params.get("num_prompt_blocks", 0),
|
||||
)
|
||||
|
||||
|
||||
@@ -827,7 +882,7 @@ class MooncakeConnectorScheduler:
|
||||
# Requests that need to start recv.
|
||||
# New requests are added by update_state_after_alloc in
|
||||
# the scheduler. Used to make metadata passed to Worker.
|
||||
self._reqs_need_recv: dict[str, tuple[Request, list[int]]] = {}
|
||||
self._reqs_need_recv: dict[str, tuple[Request, list[int], int]] = {}
|
||||
self._reqs_need_send: dict[str, float] = {}
|
||||
|
||||
# master-slave meta information for cross-nodes
|
||||
@@ -885,7 +940,7 @@ class MooncakeConnectorScheduler:
|
||||
if num_external_tokens > 0 else [])
|
||||
# Get unhashed blocks to pull from remote.
|
||||
self._reqs_need_recv[request.request_id] = (
|
||||
request, local_block_ids)
|
||||
request, local_block_ids, num_external_tokens)
|
||||
else:
|
||||
logger.warning(
|
||||
"Got invalid KVTransferParams: %s. This "
|
||||
@@ -902,7 +957,8 @@ class MooncakeConnectorScheduler:
|
||||
meta = MooncakeConnectorMetadata()
|
||||
|
||||
# Loop through scheduled reqs and convert to ReqMeta.
|
||||
for req_id, (req, block_ids) in self._reqs_need_recv.items():
|
||||
for req_id, (req, block_ids,
|
||||
num_external_tokens) in self._reqs_need_recv.items():
|
||||
assert req.kv_transfer_params is not None
|
||||
# For the case where there are no remote blocks to pull
|
||||
# (block_ids is empty), we don't need to schedule
|
||||
@@ -910,6 +966,7 @@ class MooncakeConnectorScheduler:
|
||||
meta.add_new_req(
|
||||
request_id=req_id,
|
||||
local_block_ids=block_ids,
|
||||
num_external_tokens=num_external_tokens,
|
||||
kv_transfer_params=req.kv_transfer_params,
|
||||
)
|
||||
|
||||
@@ -946,6 +1003,9 @@ class MooncakeConnectorScheduler:
|
||||
len(computed_block_ids), request.request_id)
|
||||
self._reqs_need_send[request.request_id] = time.time()
|
||||
|
||||
num_prompt_blocks = math.ceil(
|
||||
len(request.prompt_token_ids) / self.block_size)
|
||||
|
||||
return delay_free_blocks, dict(
|
||||
do_remote_prefill=True,
|
||||
do_remote_decode=False,
|
||||
@@ -957,6 +1017,7 @@ class MooncakeConnectorScheduler:
|
||||
remote_dcp_size=self.dcp_size,
|
||||
last_token_id=request.output_token_ids[-1],
|
||||
remote_multi_nodes_meta_mapping=self.multi_nodes_meta_mapping,
|
||||
num_prompt_blocks=num_prompt_blocks,
|
||||
)
|
||||
|
||||
def set_xfer_handshake_metadata(
|
||||
@@ -1045,6 +1106,8 @@ class MooncakeConnectorWorker:
|
||||
num_p_block_heads = max(
|
||||
1, self.num_key_value_heads // self._prefill_tp_size)
|
||||
self.tp_num_need_pulls = num_d_block_heads // num_p_block_heads
|
||||
self.local_remote_block_port_mapping = None
|
||||
self.remote_port_send_num: dict[int, int] = {}
|
||||
|
||||
def _get_prefill_decode_size(self, vllm_config: VllmConfig):
|
||||
# get prefill tp and dp size from extra config
|
||||
@@ -1177,8 +1240,9 @@ class MooncakeConnectorWorker:
|
||||
else:
|
||||
self.kv_recv_thread = KVCacheRecvingThread(
|
||||
self.tp_rank, self.tp_size, self._prefill_pp_size, self.engine,
|
||||
self.engine_id, self.handshake_port, kv_caches_base_addr,
|
||||
self.block_len, ready_event, self.vllm_config, self.kv_caches)
|
||||
self.engine_id, self.handshake_port, self.side_channel_port,
|
||||
kv_caches_base_addr, self.block_len, ready_event,
|
||||
self.vllm_config, self.kv_caches)
|
||||
self.kv_recv_thread.start()
|
||||
|
||||
start_wait_time = time.time()
|
||||
@@ -1227,59 +1291,186 @@ class MooncakeConnectorWorker:
|
||||
], [meta.remote_block_ids]
|
||||
return remote_handshake_port_list, local_block_ids_list, remote_block_ids_list
|
||||
|
||||
if self.pcp_size == meta.remote_pcp_size and self.dcp_size == meta.remote_dcp_size:
|
||||
# remote & local cp/dcp are equal, do kv transfer point-to-point
|
||||
remote_kv_num = 1
|
||||
remote_ports = [meta.remote_port + self.pcp_rank * self.tp_size + tp_offset \
|
||||
for tp_offset in range(self.tp_rank, int(self._prefill_tp_size), self.tp_size)]
|
||||
remote_block_nums = [len(meta.remote_block_ids)]
|
||||
else:
|
||||
assert self.pcp_size == 1
|
||||
if self.use_mla:
|
||||
assert (self.dcp_size == 1 and (self.tp_size == 1 or self.tp_size == self._prefill_tp_size)) or \
|
||||
(self.dcp_size == meta.remote_dcp_size and self.tp_size == self._prefill_tp_size)
|
||||
else:
|
||||
assert self.tp_size == self._prefill_tp_size and (
|
||||
self.dcp_size == 1
|
||||
or self.dcp_size == meta.remote_dcp_size)
|
||||
# remote & local cp/dcp are not equal, each D node needs to pull from pcp(*dcp) P nodes
|
||||
# 1. for mla, support D pcp_size = 1, D dcp_size = (1 or P dcp_size)
|
||||
# 2. for gqa, support D tp_size = P tp_size, D dcp_size = P dcp_size
|
||||
remote_dcp_size = meta.remote_dcp_size // self.dcp_size
|
||||
remote_kv_num = meta.remote_pcp_size * remote_dcp_size
|
||||
cp_dcp_offsets = []
|
||||
for cp_idx in range(meta.remote_pcp_size):
|
||||
cp_offset = cp_idx * self._prefill_tp_size
|
||||
cp_dcp_offsets += list(
|
||||
range(cp_offset, cp_offset + remote_dcp_size))
|
||||
tp_offset = self.tp_rank // remote_dcp_size * remote_dcp_size
|
||||
remote_ports = [meta.remote_port + cp_dcp_offset + tp_offset \
|
||||
for cp_dcp_offset in cp_dcp_offsets]
|
||||
# recompute cp/dcp block assign here, maybe we can also pass it from P node meta
|
||||
local_block_num = len(meta.local_block_ids)
|
||||
remote_block_nums = [
|
||||
local_block_num // (meta.remote_pcp_size * remote_dcp_size)
|
||||
] * meta.remote_pcp_size * remote_dcp_size
|
||||
num_remain_blocks = local_block_num % (meta.remote_pcp_size *
|
||||
remote_dcp_size)
|
||||
for i in range(num_remain_blocks):
|
||||
remote_block_nums[i] += 1
|
||||
# make sure the last block (which may be unfull) of P nodes is put to the last block of D node
|
||||
remote_ports = remote_ports[
|
||||
num_remain_blocks:] + remote_ports[:num_remain_blocks]
|
||||
remote_block_nums = remote_block_nums[
|
||||
num_remain_blocks:] + remote_block_nums[:num_remain_blocks]
|
||||
def context_parallel_parameters_check():
|
||||
assert (meta.remote_pcp_size * meta.remote_dcp_size) % (
|
||||
self.pcp_size * self.dcp_size) == 0
|
||||
if not self.use_mla:
|
||||
p_node_heads_per_rank = math.ceil(self.num_key_value_heads /
|
||||
self._prefill_tp_size)
|
||||
d_node_heads_per_rank = math.ceil(self.num_key_value_heads /
|
||||
self.tp_size)
|
||||
assert d_node_heads_per_rank % p_node_heads_per_rank == 0
|
||||
|
||||
remote_handshake_port_list = []
|
||||
for remote_kv_id in range(remote_kv_num):
|
||||
remote_handshake_port_list.append([remote_ports[remote_kv_id]])
|
||||
def get_kv_head_groups(tp_size):
|
||||
if self.use_mla:
|
||||
kv_head_groups = []
|
||||
kv_head_ids = [0]
|
||||
kv_head_groups.append(tuple(kv_head_ids))
|
||||
return kv_head_groups
|
||||
if self.num_key_value_heads // tp_size >= 1:
|
||||
kv_head_groups = []
|
||||
for tp_rank in range(tp_size):
|
||||
kv_head_ids = [head_idx + tp_rank * (self.num_key_value_heads // tp_size) \
|
||||
for head_idx in range(self.num_key_value_heads // tp_size)]
|
||||
kv_head_groups.append(tuple(kv_head_ids))
|
||||
return kv_head_groups
|
||||
if tp_size // self.num_key_value_heads > 1:
|
||||
kv_head_groups = []
|
||||
for kv_head_ids_ in range(self.num_key_value_heads):
|
||||
kv_head_groups.append(tuple([kv_head_ids_]))
|
||||
return kv_head_groups
|
||||
|
||||
def get_cp_group_meta(tp_size, pcp_size, dcp_size, port_base):
|
||||
# key is kv_head_group, value is cp_groups and which cp_groups to select
|
||||
cp_group_meta: dict = {}
|
||||
kv_head_groups = get_kv_head_groups(tp_size)
|
||||
dcp_repeat_num = tp_size // len(kv_head_groups) // dcp_size
|
||||
|
||||
for kv_head_group_idx, kv_head_group in enumerate(kv_head_groups):
|
||||
if kv_head_group not in cp_group_meta:
|
||||
cp_group_meta[kv_head_group] = {}
|
||||
cp_group_meta[kv_head_group]['cp_groups'] = []
|
||||
cp_group_meta[kv_head_group]['select_cp_groups_id'] = 0
|
||||
kv_head_group_offset = tp_size // len(
|
||||
kv_head_groups) * kv_head_group_idx
|
||||
for dcp_repeat_idx in range(dcp_repeat_num):
|
||||
# len(cp_group) == pcp_size * dcp_size
|
||||
cp_group = []
|
||||
dcp_repeat_offset = dcp_size * dcp_repeat_idx
|
||||
for pcp_rank in range(pcp_size):
|
||||
pcp_rank_offset = tp_size * pcp_rank
|
||||
for dcp_rank in range(dcp_size):
|
||||
cp_group.append(dcp_rank + port_base +
|
||||
pcp_rank_offset +
|
||||
dcp_repeat_offset +
|
||||
kv_head_group_offset)
|
||||
cp_group_meta[kv_head_group]['cp_groups'].append(cp_group)
|
||||
|
||||
return cp_group_meta
|
||||
|
||||
def get_local_remote_block_port_mappings():
|
||||
context_parallel_parameters_check()
|
||||
p_node_cp_group_meta = get_cp_group_meta(self._prefill_tp_size,
|
||||
meta.remote_pcp_size,
|
||||
meta.remote_dcp_size,
|
||||
meta.remote_port)
|
||||
d_node_cp_group_meta = get_cp_group_meta(self.tp_size,
|
||||
self.pcp_size,
|
||||
self.dcp_size,
|
||||
self.side_channel_port)
|
||||
local_remote_block_port_mappings: dict[int, list[list[int]]] = {}
|
||||
for d_node_head_key in d_node_cp_group_meta.keys():
|
||||
for p_node_head_key in p_node_cp_group_meta.keys():
|
||||
if not set(p_node_head_key).issubset(set(d_node_head_key)):
|
||||
continue
|
||||
d_node_head_group = d_node_cp_group_meta[d_node_head_key]
|
||||
p_node_head_group = p_node_cp_group_meta[p_node_head_key]
|
||||
for d_cp_group in d_node_head_group['cp_groups']:
|
||||
select_cp_groups_id = p_node_head_group[
|
||||
'select_cp_groups_id']
|
||||
p_cp_groups = p_node_head_group['cp_groups']
|
||||
p_cp_group = p_cp_groups[select_cp_groups_id]
|
||||
p_node_head_group['select_cp_groups_id'] = select_cp_groups_id + 1 \
|
||||
if select_cp_groups_id + 1 < len(p_cp_groups) else 0
|
||||
for d_idx, d_port in enumerate(d_cp_group):
|
||||
if d_port not in local_remote_block_port_mappings:
|
||||
local_remote_block_port_mappings[d_port] = []
|
||||
p_port_remote_list = []
|
||||
for p_idx, p_port in enumerate(p_cp_group):
|
||||
if p_idx % len(d_cp_group) == d_idx:
|
||||
p_port_remote_list.append(p_port)
|
||||
local_remote_block_port_mappings[d_port].append(
|
||||
p_port_remote_list)
|
||||
|
||||
logger.info(
|
||||
"p_node_cp_group_meta is:: %s. d_node_cp_group_meta is:: %s. "
|
||||
"local_remote_block_port_mappings is:: %s. ",
|
||||
p_node_cp_group_meta, d_node_cp_group_meta,
|
||||
local_remote_block_port_mappings)
|
||||
|
||||
return local_remote_block_port_mappings
|
||||
|
||||
def get_remote_port_send_num(local_remote_block_port_mappings):
|
||||
remote_port_send_num: dict[int, int] = {}
|
||||
for port in range(self._prefill_tp_size * meta.remote_pcp_size):
|
||||
remote_port_send_num[meta.remote_port + port] = 0
|
||||
for local_port in local_remote_block_port_mappings.keys():
|
||||
remote_port_head_list = local_remote_block_port_mappings[
|
||||
local_port]
|
||||
for remote_port_list in remote_port_head_list:
|
||||
for remote_port in remote_port_list:
|
||||
remote_port_send_num[remote_port] += 1
|
||||
return remote_port_send_num
|
||||
|
||||
if self.local_remote_block_port_mapping is None:
|
||||
local_remote_block_port_mappings = get_local_remote_block_port_mappings(
|
||||
)
|
||||
self.local_remote_block_port_mapping = local_remote_block_port_mappings[
|
||||
self.handshake_port]
|
||||
self.remote_port_send_num = get_remote_port_send_num(
|
||||
local_remote_block_port_mappings)
|
||||
|
||||
local_remote_block_port_mapping = copy.deepcopy(
|
||||
self.local_remote_block_port_mapping)
|
||||
|
||||
num_external_blocks = math.ceil(meta.num_external_tokens /
|
||||
self.block_size)
|
||||
|
||||
assert math.ceil(num_external_blocks / (self.pcp_size * self.dcp_size)) == len(meta.local_block_ids), \
|
||||
f"num_external_blocks({num_external_blocks}), cp_size({self.pcp_size * self.dcp_size}), " \
|
||||
f"local_block_ids_len ({len(meta.local_block_ids)})"
|
||||
assert meta.num_prompt_blocks >= num_external_blocks, \
|
||||
f"meta.num_prompt_blocks({meta.num_prompt_blocks}), num_external_blocks({num_external_blocks})"
|
||||
|
||||
remote_cp_size = meta.remote_pcp_size * meta.remote_dcp_size
|
||||
remote_block_nums_all = [meta.num_prompt_blocks // remote_cp_size
|
||||
] * remote_cp_size
|
||||
num_remain_blocks = meta.num_prompt_blocks % remote_cp_size
|
||||
for i in range(num_remain_blocks):
|
||||
remote_block_nums_all[i] += 1
|
||||
last_block_location = (num_remain_blocks + remote_cp_size -
|
||||
1) % remote_cp_size
|
||||
|
||||
# Considering prefix cache, the remote_block_nums_all should be revised
|
||||
num_prefix_cached_blocks = meta.num_prompt_blocks - num_external_blocks
|
||||
remote_block_nums_all = [
|
||||
num - num_prefix_cached_blocks // remote_cp_size
|
||||
for num in remote_block_nums_all
|
||||
]
|
||||
num_remain_blocks = num_prefix_cached_blocks % remote_cp_size
|
||||
for i in range(num_remain_blocks):
|
||||
remote_block_nums_all[i] -= 1
|
||||
|
||||
# make sure the last block (which may be unfull) of P nodes is put to the last block of D node
|
||||
remote_block_nums: list[int] = []
|
||||
final_block_idx: int | None = None
|
||||
local_cp_rank = self.dcp_rank + self.pcp_rank * self.dcp_size
|
||||
local_cp_size = self.dcp_size * self.pcp_size
|
||||
for cp_rank, block_num in enumerate(remote_block_nums_all):
|
||||
if cp_rank % local_cp_size == local_cp_rank:
|
||||
if last_block_location == cp_rank:
|
||||
final_block_idx = len(remote_block_nums)
|
||||
remote_block_nums.append(block_num)
|
||||
|
||||
assert local_remote_block_port_mapping is not None
|
||||
if final_block_idx is not None:
|
||||
final_block_num = remote_block_nums.pop(final_block_idx)
|
||||
remote_block_nums.append(final_block_num)
|
||||
for mapping in local_remote_block_port_mapping:
|
||||
final_block_port = mapping.pop(final_block_idx)
|
||||
mapping.append(final_block_port)
|
||||
|
||||
remote_handshake_port_list, local_block_ids_list, remote_block_ids_list = [], [], []
|
||||
for idx in range(len(local_remote_block_port_mapping[0])):
|
||||
mapping_list = []
|
||||
for mapping in local_remote_block_port_mapping:
|
||||
mapping_list.append(mapping[idx])
|
||||
remote_handshake_port_list.append(mapping_list)
|
||||
|
||||
# the local_block_ids_list and remote_block_ids_list are related with remote_handshake_port_list
|
||||
# such as: local_block_ids_list[[1],[2],[5],[6]], remote_block_ids_list[[1],[1],[1],[1]],
|
||||
# remote_handshake_port_list[[30000],[30001],[30004],[30005]]
|
||||
# D rank will get remote block 1 in port 30004 and save it in local block 5
|
||||
local_block_ids_list = []
|
||||
remote_block_ids_list = []
|
||||
local_block_offset = 0
|
||||
for remote_kv_id in range(len(remote_handshake_port_list)):
|
||||
num_blocks_to_pull = remote_block_nums[remote_kv_id]
|
||||
@@ -1289,8 +1480,9 @@ class MooncakeConnectorWorker:
|
||||
meta.local_block_ids[local_block_offset:local_block_offset +
|
||||
num_blocks_to_pull])
|
||||
local_block_offset += num_blocks_to_pull
|
||||
assert local_block_offset == len(meta.local_block_ids), \
|
||||
f"local_block_offset ({local_block_offset}) should equal with local_block_ids len ({len(meta.local_block_ids)})"
|
||||
|
||||
assert self.tp_num_need_pulls == len(remote_handshake_port_list[0]), \
|
||||
f"tp_num_need_pulls: {self.tp_num_need_pulls}, remote_handshake_port_list: {remote_handshake_port_list}"
|
||||
|
||||
return remote_handshake_port_list, local_block_ids_list, remote_block_ids_list
|
||||
|
||||
@@ -1302,15 +1494,11 @@ class MooncakeConnectorWorker:
|
||||
"Num local_block_ids: %s. Num remote_block_ids: %s. ", req_id,
|
||||
meta.remote_engine_id, len(meta.local_block_ids),
|
||||
len(meta.remote_block_ids))
|
||||
|
||||
if prefill_context_parallel_enable():
|
||||
if meta.remote_pcp_size * meta.remote_dcp_size > 1:
|
||||
remote_handshake_port_list, local_block_ids_list, remote_block_ids_list = self._get_kv_split_metadata(
|
||||
req_id, meta)
|
||||
|
||||
for pcp_dcp_rank in range(len(remote_handshake_port_list)):
|
||||
if len(local_block_ids_list[pcp_dcp_rank]) + len(
|
||||
remote_block_ids_list[pcp_dcp_rank]) == 0:
|
||||
continue
|
||||
for i in range(self.tp_num_need_pulls):
|
||||
assert self.kv_recv_thread is not None
|
||||
remote_host, remote_engine_id = self._get_remote_host_info_by_port(
|
||||
@@ -1329,6 +1517,7 @@ class MooncakeConnectorWorker:
|
||||
pcp_dcp_rank][i],
|
||||
offset=i,
|
||||
tp_num_need_pulls=self.tp_num_need_pulls,
|
||||
remote_port_send_num=self.remote_port_send_num,
|
||||
all_task_done=(
|
||||
pcp_dcp_rank
|
||||
== len(remote_handshake_port_list) - 1
|
||||
@@ -1355,7 +1544,7 @@ class MooncakeConnectorWorker:
|
||||
all_task_done=(i == self.tp_num_need_pulls *
|
||||
self._prefill_pp_size - 1))
|
||||
|
||||
if self.kv_send_thread is not None:
|
||||
if self.kv_send_thread is not None and self.pcp_size * self.dcp_size == 1:
|
||||
for req_id, delay_start_time in metadata.requests_to_send.items():
|
||||
if self.tp_rank in self._prefill_get_remote_rank(req_id):
|
||||
self.kv_send_thread.add_delayed_request(
|
||||
@@ -1363,6 +1552,11 @@ class MooncakeConnectorWorker:
|
||||
else:
|
||||
self.kv_send_thread.add_not_transfer_request(req_id)
|
||||
|
||||
if self.kv_send_thread is not None and self.pcp_size * self.dcp_size > 1:
|
||||
for req_id, delay_start_time in metadata.requests_to_send.items():
|
||||
self.kv_send_thread.add_delayed_request(
|
||||
req_id, delay_start_time)
|
||||
|
||||
def _get_remote_host_info_by_port(self, base_port: int,
|
||||
remote_handshake_port: int,
|
||||
remote_host: str, remote_engine_id: str,
|
||||
|
||||
Reference in New Issue
Block a user