[0.11.0][Bugfix] fix delay free prefill req & D node support prefix cache (#3609)
### What this PR does / why we need it?
Fix the mooncake connector. When the prefill and decode TP sizes are not equal and the prefill TP size is less than the number of key-value heads, `_get_remote_tp_ranks_for_req` returns a list of np.arrays. Performing a membership test such as `int in <list of np.arrays>` on that result raises an error. Converting the list of np.arrays into a single np.array resolves this issue.

### Does this PR introduce _any_ user-facing change?

### How was this patch tested?
qwen235B
- P tp16, D tp1
- P tp8, D tp1
- P tp4, D tp1
- P tp8, D tp2

- vLLM version: v0.11.0rc3
- vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.0

---------

Signed-off-by: liziyu <liziyu16@huawei.com>
This commit is contained in:
@@ -628,10 +628,12 @@ class TestMooncakeConnectorSchedulerMatchedTokens(unittest.TestCase):
|
|||||||
|
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
config = MockVllmConfig()
|
config = MockVllmConfig()
|
||||||
with patch(
|
self.p1 = patch(
|
||||||
'vllm_ascend.distributed.mooncake_connector.init_ascend_config'
|
'vllm_ascend.distributed.mooncake_layerwise_connector.get_ascend_config',
|
||||||
):
|
new=MagicMock(return_value=None))
|
||||||
self.scheduler = MooncakeConnectorScheduler(config, "test_engine")
|
self.p1.start()
|
||||||
|
self.addCleanup(self.p1.stop)
|
||||||
|
self.scheduler = MooncakeConnectorScheduler(config, "test_engine")
|
||||||
|
|
||||||
def test_get_num_new_matched_tokens(self):
|
def test_get_num_new_matched_tokens(self):
|
||||||
request = MockRequest("req1")
|
request = MockRequest("req1")
|
||||||
@@ -643,7 +645,7 @@ class TestMooncakeConnectorSchedulerMatchedTokens(unittest.TestCase):
|
|||||||
request.kv_transfer_params = {"do_remote_prefill": True}
|
request.kv_transfer_params = {"do_remote_prefill": True}
|
||||||
tokens, async_flag = self.scheduler.get_num_new_matched_tokens(
|
tokens, async_flag = self.scheduler.get_num_new_matched_tokens(
|
||||||
request, 0)
|
request, 0)
|
||||||
self.assertEqual(tokens, 3)
|
self.assertEqual(tokens, 4)
|
||||||
self.assertTrue(async_flag)
|
self.assertTrue(async_flag)
|
||||||
|
|
||||||
def test_build_connector_meta(self):
|
def test_build_connector_meta(self):
|
||||||
@@ -820,7 +822,7 @@ class TestMooncakeConnectorScheduler(unittest.TestCase):
|
|||||||
kv_transfer_params={"do_remote_prefill": True})
|
kv_transfer_params={"do_remote_prefill": True})
|
||||||
tokens, async_flag = self.scheduler.get_num_new_matched_tokens(
|
tokens, async_flag = self.scheduler.get_num_new_matched_tokens(
|
||||||
request, 0)
|
request, 0)
|
||||||
self.assertEqual(tokens, 3)
|
self.assertEqual(tokens, 4)
|
||||||
self.assertTrue(async_flag)
|
self.assertTrue(async_flag)
|
||||||
|
|
||||||
def test_update_state_after_alloc_no_remote_prefill(self):
|
def test_update_state_after_alloc_no_remote_prefill(self):
|
||||||
@@ -1036,8 +1038,9 @@ class TestMooncakeConnectorWorker(unittest.TestCase):
|
|||||||
patch(
|
patch(
|
||||||
'vllm_ascend.distributed.mooncake_connector.string_to_int64_hash',
|
'vllm_ascend.distributed.mooncake_connector.string_to_int64_hash',
|
||||||
mock_string_to_int64_hash),
|
mock_string_to_int64_hash),
|
||||||
patch('vllm_ascend.distributed.mooncake_connector.TransferEngine',
|
patch(
|
||||||
return_value=self.mock_transfer_engine),
|
'vllm_ascend.distributed.mooncake.transfer_engine.TransferEngine',
|
||||||
|
return_value=self.mock_transfer_engine),
|
||||||
patch(
|
patch(
|
||||||
'vllm_ascend.distributed.mooncake_connector.KVCacheSendingThread',
|
'vllm_ascend.distributed.mooncake_connector.KVCacheSendingThread',
|
||||||
MagicMock()),
|
MagicMock()),
|
||||||
@@ -1063,7 +1066,6 @@ class TestMooncakeConnectorWorker(unittest.TestCase):
|
|||||||
for p in self.patches:
|
for p in self.patches:
|
||||||
p.stop() # type: ignore
|
p.stop() # type: ignore
|
||||||
|
|
||||||
@unittest.skip("skip")
|
|
||||||
def test_worker_use_ascend_direct(self):
|
def test_worker_use_ascend_direct(self):
|
||||||
test_case = [True, False]
|
test_case = [True, False]
|
||||||
|
|
||||||
@@ -1104,7 +1106,6 @@ class TestMooncakeConnectorWorker(unittest.TestCase):
|
|||||||
config, self.engine_id)
|
config, self.engine_id)
|
||||||
self.assertIsNotNone(worker)
|
self.assertIsNotNone(worker)
|
||||||
|
|
||||||
@unittest.skip("skip")
|
|
||||||
def test_register_kv_caches_producer(self):
|
def test_register_kv_caches_producer(self):
|
||||||
worker = MooncakeConnectorWorker(self.vllm_config, self.engine_id)
|
worker = MooncakeConnectorWorker(self.vllm_config, self.engine_id)
|
||||||
worker.register_kv_caches(self.kv_caches)
|
worker.register_kv_caches(self.kv_caches)
|
||||||
@@ -1112,7 +1113,6 @@ class TestMooncakeConnectorWorker(unittest.TestCase):
|
|||||||
self.assertIsNotNone(worker.kv_send_thread)
|
self.assertIsNotNone(worker.kv_send_thread)
|
||||||
self.assertIsNone(worker.kv_recv_thread)
|
self.assertIsNone(worker.kv_recv_thread)
|
||||||
|
|
||||||
@unittest.skip("skip")
|
|
||||||
def test_register_kv_caches_consumer(self):
|
def test_register_kv_caches_consumer(self):
|
||||||
self.vllm_config.kv_transfer_config.kv_role = 'kv_consumer'
|
self.vllm_config.kv_transfer_config.kv_role = 'kv_consumer'
|
||||||
worker = MooncakeConnectorWorker(self.vllm_config, self.engine_id)
|
worker = MooncakeConnectorWorker(self.vllm_config, self.engine_id)
|
||||||
@@ -1120,7 +1120,6 @@ class TestMooncakeConnectorWorker(unittest.TestCase):
|
|||||||
self.assertIsNone(worker.kv_send_thread)
|
self.assertIsNone(worker.kv_send_thread)
|
||||||
self.assertIsNotNone(worker.kv_recv_thread)
|
self.assertIsNotNone(worker.kv_recv_thread)
|
||||||
|
|
||||||
@unittest.skip("skip")
|
|
||||||
def test_register_kv_caches_mla_case(self):
|
def test_register_kv_caches_mla_case(self):
|
||||||
mla_cache1 = MagicMock()
|
mla_cache1 = MagicMock()
|
||||||
mla_cache1.size.return_value = (10, 16, 1, 16)
|
mla_cache1.size.return_value = (10, 16, 1, 16)
|
||||||
@@ -1133,7 +1132,6 @@ class TestMooncakeConnectorWorker(unittest.TestCase):
|
|||||||
self.assertTrue(worker.use_mla)
|
self.assertTrue(worker.use_mla)
|
||||||
self.assertEqual(len(worker.block_len), 2)
|
self.assertEqual(len(worker.block_len), 2)
|
||||||
|
|
||||||
@unittest.skip("skip")
|
|
||||||
def test_device_id_selection_with_physical_devices(self):
|
def test_device_id_selection_with_physical_devices(self):
|
||||||
# Test with physical devices set
|
# Test with physical devices set
|
||||||
worker = MooncakeConnectorWorker(self.vllm_config, self.engine_id)
|
worker = MooncakeConnectorWorker(self.vllm_config, self.engine_id)
|
||||||
|
|||||||
@@ -342,9 +342,15 @@ class KVCacheRecvingThread(threading.Thread):
|
|||||||
|
|
||||||
# Full prefix cache hit: do not need to read remote blocks, just notify
|
# Full prefix cache hit: do not need to read remote blocks, just notify
|
||||||
# P worker that we have the blocks we need.
|
# P worker that we have the blocks we need.
|
||||||
if len(local_block_ids) == 0:
|
num_local_blocks = len(local_block_ids)
|
||||||
|
if num_local_blocks == 0:
|
||||||
return
|
return
|
||||||
|
|
||||||
|
num_remote_blocks = len(remote_block_ids)
|
||||||
|
assert num_local_blocks <= num_remote_blocks
|
||||||
|
if num_local_blocks < num_remote_blocks:
|
||||||
|
remote_block_ids = remote_block_ids[-num_local_blocks:]
|
||||||
|
|
||||||
# Check if we have the remote metadata cached.
|
# Check if we have the remote metadata cached.
|
||||||
if remote_engine_id not in self.kv_caches_base_addr or \
|
if remote_engine_id not in self.kv_caches_base_addr or \
|
||||||
remote_handshake_port not in self.kv_caches_base_addr[remote_engine_id]:
|
remote_handshake_port not in self.kv_caches_base_addr[remote_engine_id]:
|
||||||
@@ -736,13 +742,11 @@ class MooncakeConnectorScheduler:
|
|||||||
num_computed_tokens, params)
|
num_computed_tokens, params)
|
||||||
|
|
||||||
if params is not None and params.get("do_remote_prefill"):
|
if params is not None and params.get("do_remote_prefill"):
|
||||||
assert num_computed_tokens == 0, "Currently only support " \
|
# Remote prefill: get all prompt blocks from remote.
|
||||||
"prefill with num_computed_tokens == 0."
|
assert num_computed_tokens % self.block_size == 0
|
||||||
# Assume that the request's KV cache is already fully prefilled and
|
# Note: We use the full token count as transmit data here.
|
||||||
# can be fetched entirely from the prefill node.
|
count = max(len(request.prompt_token_ids) - num_computed_tokens, 0)
|
||||||
count = max(len(request.prompt_token_ids) - 1, 0)
|
return count, count > 0
|
||||||
if count > 0:
|
|
||||||
return count, True
|
|
||||||
|
|
||||||
# No remote prefill for this request.
|
# No remote prefill for this request.
|
||||||
return 0, False
|
return 0, False
|
||||||
@@ -1137,10 +1141,13 @@ class MooncakeConnectorWorker:
|
|||||||
|
|
||||||
if self.kv_send_thread is not None:
|
if self.kv_send_thread is not None:
|
||||||
for req_id, delay_start_time in metadata.requests_to_send.items():
|
for req_id, delay_start_time in metadata.requests_to_send.items():
|
||||||
if self.tp_rank in self._get_remote_tp_ranks_for_req(req_id):
|
if self.tp_rank in self._prefill_get_remote_tp_rank(req_id):
|
||||||
self.kv_send_thread.add_delayed_request(
|
self.kv_send_thread.add_delayed_request(
|
||||||
req_id, delay_start_time)
|
req_id, delay_start_time)
|
||||||
|
|
||||||
|
def _prefill_get_remote_tp_rank(self, req_id: str) -> List[int]:
|
||||||
|
return sum(self._get_remote_tp_ranks_for_req(req_id), [])
|
||||||
|
|
||||||
def _get_remote_tp_rank(self, req_id: str) -> List[int]:
|
def _get_remote_tp_rank(self, req_id: str) -> List[int]:
|
||||||
return self._get_remote_tp_ranks_for_req(req_id)[self.tp_rank]
|
return self._get_remote_tp_ranks_for_req(req_id)[self.tp_rank]
|
||||||
|
|
||||||
@@ -1176,8 +1183,8 @@ class MooncakeConnectorWorker:
|
|||||||
else:
|
else:
|
||||||
group_size = self._prefill_tp_size // self._decode_tp_size
|
group_size = self._prefill_tp_size // self._decode_tp_size
|
||||||
for i in range(self._decode_tp_size):
|
for i in range(self._decode_tp_size):
|
||||||
slice = ori_data[i * group_size:(i + 1) * group_size]
|
ori_data_slice = ori_data[i * group_size:(i + 1) * group_size]
|
||||||
sampled_nums.append(slice)
|
sampled_nums.append(ori_data_slice.tolist())
|
||||||
return sampled_nums
|
return sampled_nums
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user