[bugfix] bugfix for PD disaggregate (#4319)
This PR fixes mooncake_connector in the pcp/dcp case. When executing
the function update_done_task_count, it is necessary to ensure that both
the pcp/dcp ranks and the TP ranks have finished transferring the KV cache.
- vLLM version: v0.11.0
- vLLM main: 2918c1b49c
---------
Signed-off-by: wangxiaochao <w00642655@china.huawei.com>
Co-authored-by: wangxiaochao <w00642655@china.huawei.com>
This commit is contained in:
@@ -89,7 +89,7 @@ class TestKVCacheSendingThreadInit(unittest.TestCase):
|
||||
kv_caches: Dict[str, Any] = {}
|
||||
self.common_args = {
|
||||
'tp_rank': 1,
|
||||
'decode_tp_size': 4,
|
||||
'prefill_tp_size': 4,
|
||||
'local_engine_id': 'engine_1',
|
||||
'side_channel_host': 'localhost',
|
||||
'side_channel_port': 5555,
|
||||
@@ -133,7 +133,7 @@ class TestGetAndClearFinishedRequests(unittest.TestCase):
|
||||
kv_caches: Dict[str, Any] = {}
|
||||
self.common_args = {
|
||||
'tp_rank': 1,
|
||||
'decode_tp_size': 4,
|
||||
'prefill_tp_size': 4,
|
||||
'local_engine_id': 'engine_1',
|
||||
'side_channel_host': 'localhost',
|
||||
'side_channel_port': 5555,
|
||||
@@ -171,7 +171,7 @@ class TestKVCacheSendingThread(unittest.TestCase):
|
||||
free_port = s.getsockname()[1]
|
||||
|
||||
thread = KVCacheSendingThread(tp_rank=0,
|
||||
decode_tp_size=1,
|
||||
prefill_tp_size=1,
|
||||
local_engine_id="engine1",
|
||||
side_channel_host=host,
|
||||
side_channel_port=free_port,
|
||||
@@ -237,7 +237,8 @@ class TestKVCacheRecvingThreadBasic(unittest.TestCase):
|
||||
"remote_host": "localhost",
|
||||
"remote_handshake_port": 6666,
|
||||
"offset": 0,
|
||||
"num_need_pulls": 2
|
||||
"num_need_pulls": 2,
|
||||
"all_task_done": False
|
||||
}
|
||||
self.thread.add_request(
|
||||
request_id=test_req["request_id"],
|
||||
@@ -247,7 +248,8 @@ class TestKVCacheRecvingThreadBasic(unittest.TestCase):
|
||||
remote_host=test_req["remote_host"],
|
||||
remote_handshake_port=test_req["remote_handshake_port"],
|
||||
offset=test_req["offset"],
|
||||
num_need_pulls=test_req["num_need_pulls"])
|
||||
num_need_pulls=test_req["num_need_pulls"],
|
||||
all_task_done=test_req["all_task_done"])
|
||||
queued = self.thread.request_queue.get_nowait()
|
||||
self.assertEqual(queued["request_id"], "req1")
|
||||
self.assertEqual(queued["remote_host"], "localhost")
|
||||
@@ -341,7 +343,8 @@ class TestCoreFunctionality(unittest.TestCase):
|
||||
"remote_handshake_port": 6666,
|
||||
"remote_transfer_port": 7777,
|
||||
"offset": 0,
|
||||
"num_need_pulls": 2
|
||||
"num_need_pulls": 2,
|
||||
"all_task_done": False
|
||||
}
|
||||
self.thread.task_tracker = MagicMock()
|
||||
self.engine.batch_transfer_sync_read.return_value = 0
|
||||
@@ -485,7 +488,8 @@ class TestMainThreadLoop(unittest.TestCase):
|
||||
"remote_handshake_port": 6666,
|
||||
"remote_transfer_port": 7777,
|
||||
"offset": 0,
|
||||
"num_need_pulls": 2
|
||||
"num_need_pulls": 2,
|
||||
"all_task_done": False
|
||||
}
|
||||
|
||||
self.thread.request_queue.put(test_request)
|
||||
|
||||
Reference in New Issue
Block a user