[bugfix] bugfix for PD disaggregate (#4319)
This PR is used to fix mooncake_connector in pcp/dcp case. When
executing function update_done_task_count, it is necessary to ensure
that both pcp/dcp and TP ranks have finished transferring KV cache.
- vLLM version: v0.11.0
- vLLM main:
2918c1b49c
---------
Signed-off-by: wangxiaochao <w00642655@china.huawei.com>
Co-authored-by: wangxiaochao <w00642655@china.huawei.com>
This commit is contained in:
@@ -89,7 +89,7 @@ class TestKVCacheSendingThreadInit(unittest.TestCase):
|
|||||||
kv_caches: Dict[str, Any] = {}
|
kv_caches: Dict[str, Any] = {}
|
||||||
self.common_args = {
|
self.common_args = {
|
||||||
'tp_rank': 1,
|
'tp_rank': 1,
|
||||||
'decode_tp_size': 4,
|
'prefill_tp_size': 4,
|
||||||
'local_engine_id': 'engine_1',
|
'local_engine_id': 'engine_1',
|
||||||
'side_channel_host': 'localhost',
|
'side_channel_host': 'localhost',
|
||||||
'side_channel_port': 5555,
|
'side_channel_port': 5555,
|
||||||
@@ -133,7 +133,7 @@ class TestGetAndClearFinishedRequests(unittest.TestCase):
|
|||||||
kv_caches: Dict[str, Any] = {}
|
kv_caches: Dict[str, Any] = {}
|
||||||
self.common_args = {
|
self.common_args = {
|
||||||
'tp_rank': 1,
|
'tp_rank': 1,
|
||||||
'decode_tp_size': 4,
|
'prefill_tp_size': 4,
|
||||||
'local_engine_id': 'engine_1',
|
'local_engine_id': 'engine_1',
|
||||||
'side_channel_host': 'localhost',
|
'side_channel_host': 'localhost',
|
||||||
'side_channel_port': 5555,
|
'side_channel_port': 5555,
|
||||||
@@ -171,7 +171,7 @@ class TestKVCacheSendingThread(unittest.TestCase):
|
|||||||
free_port = s.getsockname()[1]
|
free_port = s.getsockname()[1]
|
||||||
|
|
||||||
thread = KVCacheSendingThread(tp_rank=0,
|
thread = KVCacheSendingThread(tp_rank=0,
|
||||||
decode_tp_size=1,
|
prefill_tp_size=1,
|
||||||
local_engine_id="engine1",
|
local_engine_id="engine1",
|
||||||
side_channel_host=host,
|
side_channel_host=host,
|
||||||
side_channel_port=free_port,
|
side_channel_port=free_port,
|
||||||
@@ -237,7 +237,8 @@ class TestKVCacheRecvingThreadBasic(unittest.TestCase):
|
|||||||
"remote_host": "localhost",
|
"remote_host": "localhost",
|
||||||
"remote_handshake_port": 6666,
|
"remote_handshake_port": 6666,
|
||||||
"offset": 0,
|
"offset": 0,
|
||||||
"num_need_pulls": 2
|
"num_need_pulls": 2,
|
||||||
|
"all_task_done": False
|
||||||
}
|
}
|
||||||
self.thread.add_request(
|
self.thread.add_request(
|
||||||
request_id=test_req["request_id"],
|
request_id=test_req["request_id"],
|
||||||
@@ -247,7 +248,8 @@ class TestKVCacheRecvingThreadBasic(unittest.TestCase):
|
|||||||
remote_host=test_req["remote_host"],
|
remote_host=test_req["remote_host"],
|
||||||
remote_handshake_port=test_req["remote_handshake_port"],
|
remote_handshake_port=test_req["remote_handshake_port"],
|
||||||
offset=test_req["offset"],
|
offset=test_req["offset"],
|
||||||
num_need_pulls=test_req["num_need_pulls"])
|
num_need_pulls=test_req["num_need_pulls"],
|
||||||
|
all_task_done=test_req["all_task_done"])
|
||||||
queued = self.thread.request_queue.get_nowait()
|
queued = self.thread.request_queue.get_nowait()
|
||||||
self.assertEqual(queued["request_id"], "req1")
|
self.assertEqual(queued["request_id"], "req1")
|
||||||
self.assertEqual(queued["remote_host"], "localhost")
|
self.assertEqual(queued["remote_host"], "localhost")
|
||||||
@@ -341,7 +343,8 @@ class TestCoreFunctionality(unittest.TestCase):
|
|||||||
"remote_handshake_port": 6666,
|
"remote_handshake_port": 6666,
|
||||||
"remote_transfer_port": 7777,
|
"remote_transfer_port": 7777,
|
||||||
"offset": 0,
|
"offset": 0,
|
||||||
"num_need_pulls": 2
|
"num_need_pulls": 2,
|
||||||
|
"all_task_done": False
|
||||||
}
|
}
|
||||||
self.thread.task_tracker = MagicMock()
|
self.thread.task_tracker = MagicMock()
|
||||||
self.engine.batch_transfer_sync_read.return_value = 0
|
self.engine.batch_transfer_sync_read.return_value = 0
|
||||||
@@ -485,7 +488,8 @@ class TestMainThreadLoop(unittest.TestCase):
|
|||||||
"remote_handshake_port": 6666,
|
"remote_handshake_port": 6666,
|
||||||
"remote_transfer_port": 7777,
|
"remote_transfer_port": 7777,
|
||||||
"offset": 0,
|
"offset": 0,
|
||||||
"num_need_pulls": 2
|
"num_need_pulls": 2,
|
||||||
|
"all_task_done": False
|
||||||
}
|
}
|
||||||
|
|
||||||
self.thread.request_queue.put(test_request)
|
self.thread.request_queue.put(test_request)
|
||||||
|
|||||||
@@ -150,13 +150,14 @@ class KVCacheTaskTracker:
|
|||||||
|
|
||||||
class KVCacheSendingThread(threading.Thread):
|
class KVCacheSendingThread(threading.Thread):
|
||||||
|
|
||||||
def __init__(self, tp_rank: int, decode_tp_size: int, local_engine_id: str,
|
def __init__(self, tp_rank: int, prefill_tp_size: int,
|
||||||
side_channel_host: str, side_channel_port: int,
|
local_engine_id: str, side_channel_host: str,
|
||||||
metadata: MooncakeAgentMetadata, ready_event: threading.Event,
|
side_channel_port: int, metadata: MooncakeAgentMetadata,
|
||||||
kv_caches: dict[str, Any], pcp_rank: int):
|
ready_event: threading.Event, kv_caches: dict[str, Any],
|
||||||
|
pcp_rank: int):
|
||||||
super().__init__(daemon=True, name="KVCacheSendingThread")
|
super().__init__(daemon=True, name="KVCacheSendingThread")
|
||||||
self.tp_rank = tp_rank
|
self.tp_rank = tp_rank
|
||||||
self.decode_tp_size = decode_tp_size
|
self.prefill_tp_size = prefill_tp_size
|
||||||
self.local_engine_id = local_engine_id
|
self.local_engine_id = local_engine_id
|
||||||
self.side_channel_host = side_channel_host
|
self.side_channel_host = side_channel_host
|
||||||
self.side_channel_port = side_channel_port
|
self.side_channel_port = side_channel_port
|
||||||
@@ -195,7 +196,7 @@ class KVCacheSendingThread(threading.Thread):
|
|||||||
# NOTE(rob): we need each rank to have a unique port. This hack to keeps
|
# NOTE(rob): we need each rank to have a unique port. This hack to keeps
|
||||||
# us moving. We will switch when moving to etcd or where we have a
|
# us moving. We will switch when moving to etcd or where we have a
|
||||||
# single ZMQ socket in the scheduler.
|
# single ZMQ socket in the scheduler.
|
||||||
handshake_port = self.side_channel_port + self.pcp_rank * self.decode_tp_size \
|
handshake_port = self.side_channel_port + self.pcp_rank * self.prefill_tp_size \
|
||||||
+ self.tp_rank
|
+ self.tp_rank
|
||||||
path = make_zmq_path("tcp", self.side_channel_host, handshake_port)
|
path = make_zmq_path("tcp", self.side_channel_host, handshake_port)
|
||||||
logger.info("Starting listening on path: %s", path)
|
logger.info("Starting listening on path: %s", path)
|
||||||
@@ -295,7 +296,7 @@ class KVCacheRecvingThread(threading.Thread):
|
|||||||
def add_request(self, request_id: str, local_block_ids: list[int],
|
def add_request(self, request_id: str, local_block_ids: list[int],
|
||||||
remote_block_ids: list[int], remote_engine_id: str,
|
remote_block_ids: list[int], remote_engine_id: str,
|
||||||
remote_host: str, remote_handshake_port: int, offset: int,
|
remote_host: str, remote_handshake_port: int, offset: int,
|
||||||
num_need_pulls: int):
|
num_need_pulls: int, all_task_done: bool):
|
||||||
"""Add a new request to the queue for processing."""
|
"""Add a new request to the queue for processing."""
|
||||||
logger.debug(f"Adding request {request_id} to the queue.")
|
logger.debug(f"Adding request {request_id} to the queue.")
|
||||||
self.request_queue.put({
|
self.request_queue.put({
|
||||||
@@ -306,7 +307,8 @@ class KVCacheRecvingThread(threading.Thread):
|
|||||||
"remote_host": remote_host,
|
"remote_host": remote_host,
|
||||||
"remote_handshake_port": remote_handshake_port,
|
"remote_handshake_port": remote_handshake_port,
|
||||||
"offset": offset,
|
"offset": offset,
|
||||||
"num_need_pulls": num_need_pulls
|
"num_need_pulls": num_need_pulls,
|
||||||
|
"all_task_done": all_task_done
|
||||||
})
|
})
|
||||||
|
|
||||||
def get_and_clear_finished_requests(self) -> set[str]:
|
def get_and_clear_finished_requests(self) -> set[str]:
|
||||||
@@ -335,8 +337,7 @@ class KVCacheRecvingThread(threading.Thread):
|
|||||||
request_id = req_meta["request_id"]
|
request_id = req_meta["request_id"]
|
||||||
remote_host = req_meta["remote_host"]
|
remote_host = req_meta["remote_host"]
|
||||||
remote_handshake_port = req_meta["remote_handshake_port"]
|
remote_handshake_port = req_meta["remote_handshake_port"]
|
||||||
offset = req_meta["offset"]
|
all_task_done = req_meta["all_task_done"]
|
||||||
num_need_pulls = req_meta["num_need_pulls"]
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
logger.debug(
|
logger.debug(
|
||||||
@@ -353,7 +354,7 @@ class KVCacheRecvingThread(threading.Thread):
|
|||||||
# remote host.
|
# remote host.
|
||||||
self._send_done_recv_signal(request_id, remote_host,
|
self._send_done_recv_signal(request_id, remote_host,
|
||||||
remote_handshake_port)
|
remote_handshake_port)
|
||||||
if offset == num_need_pulls - 1:
|
if all_task_done:
|
||||||
self.task_tracker.update_done_task_count(request_id)
|
self.task_tracker.update_done_task_count(request_id)
|
||||||
self.request_queue.task_done()
|
self.request_queue.task_done()
|
||||||
|
|
||||||
@@ -1091,7 +1092,7 @@ class MooncakeConnectorWorker:
|
|||||||
ready_event = threading.Event()
|
ready_event = threading.Event()
|
||||||
if self.kv_role == 'kv_producer':
|
if self.kv_role == 'kv_producer':
|
||||||
self.kv_send_thread = KVCacheSendingThread(
|
self.kv_send_thread = KVCacheSendingThread(
|
||||||
self.tp_rank, self._decode_tp_size, self.engine_id,
|
self.tp_rank, self._prefill_tp_size, self.engine_id,
|
||||||
self.side_channel_host, self.side_channel_port, metadata,
|
self.side_channel_host, self.side_channel_port, metadata,
|
||||||
ready_event, self.kv_caches, self.pcp_rank)
|
ready_event, self.kv_caches, self.pcp_rank)
|
||||||
self.kv_send_thread.start()
|
self.kv_send_thread.start()
|
||||||
@@ -1239,7 +1240,10 @@ class MooncakeConnectorWorker:
|
|||||||
remote_handshake_port=remote_handshake_port_list[
|
remote_handshake_port=remote_handshake_port_list[
|
||||||
pcp_dcp_rank][i],
|
pcp_dcp_rank][i],
|
||||||
offset=i,
|
offset=i,
|
||||||
num_need_pulls=self.num_need_pulls)
|
num_need_pulls=self.num_need_pulls,
|
||||||
|
all_task_done=(pcp_dcp_rank
|
||||||
|
== len(remote_handshake_port_list) - 1
|
||||||
|
and i == self.num_need_pulls - 1))
|
||||||
|
|
||||||
if self.kv_send_thread is not None:
|
if self.kv_send_thread is not None:
|
||||||
for req_id, delay_start_time in metadata.requests_to_send.items():
|
for req_id, delay_start_time in metadata.requests_to_send.items():
|
||||||
|
|||||||
Reference in New Issue
Block a user