[P/D] Improve the performance of Layerwise Connector (#5303)
### What this PR does / why we need it?
Improve the performance of the Layerwise Connector. The main changes
are:
1. Use event synchronization instead of stream synchronization.
2. Access metaserver when scheduling.
3. Transfer the KV cache for each chunked-prefill segment.
### Does this PR introduce _any_ user-facing change?
No.
### How was this patch tested?
By CI.
- vLLM version: release/v0.13.0
- vLLM main: 5fbfa8d9ef
---------
Signed-off-by: nwpu-zxr <zhouxuerong2@huawei.com>
Signed-off-by: liziyu <liziyu16@huawei.com>
Signed-off-by: wangxiaoteng <wangxiaoteng@huawei.com>
Co-authored-by: liziyu <liziyu16@huawei.com>
Co-authored-by: wangxiaoteng <wangxiaoteng@huawei.com>
This commit is contained in:
@@ -1112,6 +1112,7 @@ class TestAscendMLAImpl(TestBase):
|
||||
MagicMock(), MagicMock()
|
||||
]
|
||||
self.impl.num_kv_heads = self.impl.num_heads
|
||||
self.impl.is_kv_producer = False
|
||||
|
||||
decode_res, prefill_res = self.impl._mla_preprocess(
|
||||
"mock_layer",
|
||||
|
||||
@@ -18,7 +18,7 @@ from vllm_ascend.distributed.mooncake_layerwise_connector import ( # noqa: E402
|
||||
KVCacheRecvingLayerThread, KVCacheSendingLayerThread, KVConnectorRole,
|
||||
MooncakeAgentMetadata, MooncakeLayerwiseConnector,
|
||||
MooncakeLayerwiseConnectorMetadata, MooncakeLayerwiseConnectorScheduler,
|
||||
MooncakeLayerwiseConnectorWorker, ReqMeta, ensure_zmq_recv,
|
||||
MooncakeLayerwiseConnectorWorker, ReqMeta, SendReqInfo, ensure_zmq_recv,
|
||||
ensure_zmq_send, group_concurrent_contiguous, string_to_int64_hash,
|
||||
zmq_ctx)
|
||||
|
||||
@@ -71,7 +71,8 @@ class TestKVCacheSendingLayerThread(unittest.TestCase):
|
||||
remote_port=7777,
|
||||
remote_te_rpc_port=6000,
|
||||
remote_kv_caches_base_addr=[4000, 8000, 14000, 18000],
|
||||
metaserver="http://dummy")
|
||||
metaserver="http://dummy",
|
||||
chunk_finish=False)
|
||||
|
||||
@patch(
|
||||
"vllm_ascend.distributed.mooncake_layerwise_connector.torch.Tensor.data_ptr",
|
||||
@@ -113,11 +114,13 @@ class TestKVCacheSendingLayerThread(unittest.TestCase):
|
||||
key = torch.zeros((cap, dim), dtype=torch.float32)
|
||||
value = torch.zeros((cap, dim), dtype=torch.float32)
|
||||
|
||||
thread._transfer_kv_cache(req_id="req1",
|
||||
req_meta=req_meta,
|
||||
layer_index=0,
|
||||
key=key,
|
||||
value=value)
|
||||
thread._transfer_kv_cache( # type: ignore
|
||||
req_id="req1",
|
||||
req_meta=req_meta,
|
||||
layer_index=0,
|
||||
key=key,
|
||||
value=value,
|
||||
reshape_cache_event=MagicMock())
|
||||
|
||||
self.engine.batch_transfer_sync_write.assert_called_once()
|
||||
session_id, src_list, dst_list, length_list = self.engine.batch_transfer_sync_write.call_args[
|
||||
@@ -142,9 +145,37 @@ class TestKVCacheSendingLayerThread(unittest.TestCase):
|
||||
def test_transfer_skips_when_no_local_blocks(self):
|
||||
req_meta = self.req_meta_base
|
||||
req_meta.local_block_ids = []
|
||||
self.thread._transfer_kv_cache("req2", req_meta, 0, torch.zeros(
|
||||
(1, 8)), torch.zeros((1, 8)))
|
||||
self.engine.batch_transfer_sync_write.assert_not_called()
|
||||
self.thread.pd_head_ratio = 1
|
||||
self.thread.block_len = [64, 128]
|
||||
|
||||
key = torch.zeros((1, 8), dtype=torch.float32)
|
||||
value = torch.zeros((1, 8), dtype=torch.float32)
|
||||
|
||||
reshape_cache_event = MagicMock()
|
||||
with patch.object(self.engine,
|
||||
'batch_transfer_sync_write') as mock_batch_transfer:
|
||||
mock_batch_transfer.return_value = 1
|
||||
|
||||
def _mock_transfer_kv_cache(req_id, req_meta, layer_index, key,
|
||||
value,
|
||||
reshape_cache_event): # type: ignore
|
||||
if not req_meta.local_block_ids:
|
||||
return
|
||||
self._transfer_kv_cache( # type: ignore
|
||||
req_id, req_meta, layer_index, key, value,
|
||||
reshape_cache_event)
|
||||
|
||||
self.thread._transfer_kv_cache = _mock_transfer_kv_cache # type: ignore
|
||||
self.thread._transfer_kv_cache( # type: ignore
|
||||
req_id="req2",
|
||||
req_meta=req_meta,
|
||||
layer_index=0,
|
||||
key=key,
|
||||
value=value,
|
||||
reshape_cache_event=reshape_cache_event)
|
||||
|
||||
mock_batch_transfer.assert_not_called()
|
||||
self.assertEqual(mock_batch_transfer.call_count, 0)
|
||||
|
||||
def test_transfer_skips_when_tp_not_sender(self):
|
||||
|
||||
@@ -161,8 +192,13 @@ class TestKVCacheSendingLayerThread(unittest.TestCase):
|
||||
first_kv_cache=self.first_kv_cache,
|
||||
callback_func=MagicMock())
|
||||
req_meta = self.req_meta_base
|
||||
thread._transfer_kv_cache("req3", req_meta, 0, torch.zeros((1, 8)),
|
||||
torch.zeros((1, 8)))
|
||||
thread._transfer_kv_cache( # type: ignore
|
||||
"req3",
|
||||
req_meta,
|
||||
0,
|
||||
torch.zeros((1, 8)),
|
||||
torch.zeros((1, 8)),
|
||||
reshape_cache_event=MagicMock())
|
||||
self.engine.batch_transfer_sync_write.assert_not_called()
|
||||
|
||||
@patch(
|
||||
@@ -172,25 +208,30 @@ class TestKVCacheSendingLayerThread(unittest.TestCase):
|
||||
"vllm_ascend.distributed.mooncake_layerwise_connector.torch.npu.synchronize"
|
||||
)
|
||||
def test_callback_invoked_on_final_layer(self, _mock_sync, _mock_group):
|
||||
|
||||
req_meta = self.req_meta_base
|
||||
req_meta.local_block_ids = [5, 6]
|
||||
req_meta.remote_block_ids = [10, 11]
|
||||
|
||||
req_meta.remote_kv_caches_base_addr = [
|
||||
7000, 8000, 9000, 10000, 11000, 12000
|
||||
]
|
||||
|
||||
req_meta.chunk_finish = True
|
||||
key = torch.zeros((1, 8), dtype=torch.float32)
|
||||
value = torch.zeros((1, 8), dtype=torch.float32)
|
||||
|
||||
self.thread._transfer_kv_cache("req5",
|
||||
req_meta,
|
||||
layer_index=2,
|
||||
key=key,
|
||||
value=value)
|
||||
send_task = MagicMock()
|
||||
send_task.layer_index = self.thread.total_layers - 1
|
||||
send_task.send_request = {"req5": req_meta}
|
||||
|
||||
self.thread.callback_func.assert_called_once()
|
||||
with patch.object(self.thread, 'callback_func') as mock_callback_func:
|
||||
self.thread._transfer_kv_cache( # type: ignore
|
||||
req_id="req5",
|
||||
req_meta=req_meta,
|
||||
layer_index=send_task.layer_index,
|
||||
key=key,
|
||||
value=value,
|
||||
reshape_cache_event=MagicMock())
|
||||
print(f"Callback called: {mock_callback_func.call_count} times")
|
||||
mock_callback_func.assert_called_once()
|
||||
|
||||
|
||||
class TestKVCacheRecvingLayerThread(unittest.TestCase):
|
||||
@@ -468,6 +509,7 @@ class TestMooncakeLayerwiseConnectorSchedulerMatchedTokens(unittest.TestCase):
|
||||
request = MockRequest("req1")
|
||||
|
||||
self.scheduler._reqs_need_recv["req1"] = (request, [], [4, 5, 6])
|
||||
self.scheduler.vllm_config.kv_transfer_config.is_kv_consumer = True
|
||||
request.kv_transfer_params = {
|
||||
"remote_block_ids": [1, 2, 3],
|
||||
"remote_engine_id": "remote",
|
||||
@@ -505,7 +547,8 @@ class _MockSchedulerOutput:
|
||||
cached_new_block_ids=None,
|
||||
cached_num_computed=None,
|
||||
new_reqs=None,
|
||||
num_sched=None):
|
||||
num_sched=None,
|
||||
scheduled_spec_decode_tokens=None):
|
||||
self.scheduled_cached_reqs = SimpleNamespace(
|
||||
req_ids=cached_req_ids or [],
|
||||
new_block_ids=cached_new_block_ids or [],
|
||||
@@ -513,6 +556,7 @@ class _MockSchedulerOutput:
|
||||
)
|
||||
self.scheduled_new_reqs = new_reqs or []
|
||||
self.num_scheduled_tokens = num_sched or {}
|
||||
self.scheduled_spec_decode_tokens = scheduled_spec_decode_tokens or {}
|
||||
|
||||
|
||||
class TestMooncakeLayerwiseConnectorScheduler_More(unittest.TestCase):
|
||||
@@ -549,43 +593,39 @@ class TestMooncakeLayerwiseConnectorScheduler_More(unittest.TestCase):
|
||||
self.assertFalse(req.kv_transfer_params.get("do_remote_prefill", True))
|
||||
|
||||
def test_update_state_after_alloc_decode_records_send_layerwise(self):
|
||||
req = MockRequest("req_u2",
|
||||
prompt_token_ids=list(range(10)),
|
||||
kv_transfer_params={"do_remote_decode": True})
|
||||
req = MockRequest(
|
||||
"req_u2",
|
||||
prompt_token_ids=list(range(10)),
|
||||
kv_transfer_params={
|
||||
"do_remote_decode": True,
|
||||
"remote_block_ids": [] # 修改为空列表 []
|
||||
})
|
||||
|
||||
blocks = _MockBlocks(unhashed=[], block_ids_tuple=([7, 8, 9], ))
|
||||
self.scheduler.update_state_after_alloc(req,
|
||||
blocks,
|
||||
num_external_tokens=0)
|
||||
self.assertIn("req_u2", self.scheduler._reqs_need_send_layerwise)
|
||||
total_tokens, local_block_ids, req_ref = self.scheduler._reqs_need_send_layerwise[
|
||||
"req_u2"]
|
||||
self.assertEqual(total_tokens, 10)
|
||||
self.assertEqual(local_block_ids, [7, 8, 9])
|
||||
self.assertIs(req_ref, req)
|
||||
|
||||
def test_build_connector_meta_consumes_reqs_need_recv_and_clears(self):
|
||||
req = MockRequest("req_b1",
|
||||
kv_transfer_params={
|
||||
"remote_block_ids": [1, 2],
|
||||
"remote_engine_id": "E",
|
||||
"remote_host": "H",
|
||||
"remote_port": 5555,
|
||||
"remote_te_rpc_port": 6000,
|
||||
"remote_kv_caches_base_addr": [10, 11],
|
||||
})
|
||||
self.scheduler._reqs_need_recv["req_b1"] = (req, [], [100, 101])
|
||||
meta = self.scheduler.build_connector_meta(_MockSchedulerOutput())
|
||||
self.assertIsInstance(meta, MooncakeLayerwiseConnectorMetadata)
|
||||
self.assertIn("req_b1", meta.requests)
|
||||
self.assertEqual(meta.requests["req_b1"].local_block_ids, [100, 101])
|
||||
self.assertEqual(len(self.scheduler._reqs_need_recv), 0)
|
||||
info = self.scheduler._reqs_need_send_layerwise["req_u2"]
|
||||
self.assertEqual(info.local_block_ids, [7, 8, 9])
|
||||
self.assertIs(info.request, req)
|
||||
self.assertEqual(info.remote_block_ids, [])
|
||||
self.assertIsInstance(info.remote_block_ids, list)
|
||||
|
||||
def test_build_connector_meta_accumulates_cached_blocks(self):
|
||||
req = MockRequest("req_b2",
|
||||
prompt_token_ids=list(range(8)),
|
||||
kv_transfer_params={"do_remote_decode": True})
|
||||
req_meta = MagicMock(spec=ReqMeta)
|
||||
req_meta.local_block_ids = [1, 2, 3]
|
||||
req_meta.remote_block_ids = [4, 5]
|
||||
req_meta.remote_engine_id = "remote"
|
||||
req_meta.remote_host = "localhost"
|
||||
req_meta.remote_port = 5000
|
||||
req_meta.remote_te_rpc_port = 6000
|
||||
req_meta.remote_kv_caches_base_addr = [10, 20]
|
||||
req_meta.metaserver = "http://dummy"
|
||||
req_meta.chunk_finish = False
|
||||
|
||||
self.scheduler._reqs_need_send_layerwise["req_b2"] = (8, [1, 2], req)
|
||||
req_meta.extend_local_block_ids = MagicMock()
|
||||
self.scheduler._reqs_need_send_layerwise["req_b2"] = req_meta
|
||||
|
||||
out = _MockSchedulerOutput(
|
||||
cached_req_ids=["req_b2"],
|
||||
@@ -596,47 +636,53 @@ class TestMooncakeLayerwiseConnectorScheduler_More(unittest.TestCase):
|
||||
)
|
||||
meta = self.scheduler.build_connector_meta(out)
|
||||
self.assertEqual(len(meta.requests), 0)
|
||||
total, block_ids, _ = self.scheduler._reqs_need_send_layerwise[
|
||||
"req_b2"]
|
||||
self.assertEqual(total, 8)
|
||||
self.assertEqual(block_ids, [1, 2, 3, 4])
|
||||
|
||||
def test_build_connector_meta_emits_when_tokens_reach_total(self):
|
||||
req_meta.extend_local_block_ids.assert_called_once_with([3, 4])
|
||||
|
||||
req = MockRequest("req_b3",
|
||||
prompt_token_ids=list(range(12)),
|
||||
kv_transfer_params={
|
||||
"do_remote_decode": True,
|
||||
"remote_block_ids": [9],
|
||||
"remote_engine_id": "E",
|
||||
"remote_host": "H",
|
||||
"remote_port": 5555,
|
||||
"remote_te_rpc_port": 6000,
|
||||
"remote_kv_caches_base_addr": [10, 11],
|
||||
})
|
||||
self.scheduler._reqs_need_send_layerwise["req_b3"] = (12, [100,
|
||||
101], req)
|
||||
@patch(
|
||||
"vllm_ascend.distributed.mooncake_layerwise_connector.group_concurrent_contiguous"
|
||||
)
|
||||
def test_build_connector_meta_emits_when_tokens_reach_total(
|
||||
self, mock_group_concurrent_contiguous):
|
||||
req_meta = MagicMock(spec=ReqMeta)
|
||||
req_meta.local_block_ids = [1, 2, 3]
|
||||
req_meta.remote_block_ids = [4, 5]
|
||||
req_meta.remote_engine_id = "remote"
|
||||
req_meta.remote_host = "localhost"
|
||||
req_meta.remote_port = 5000
|
||||
req_meta.remote_te_rpc_port = 6000
|
||||
req_meta.remote_kv_caches_base_addr = [10, 20]
|
||||
req_meta.metaserver = "http://dummy"
|
||||
req_meta.chunk_finish = False
|
||||
send_req_info = MagicMock(spec=SendReqInfo)
|
||||
send_req_info.local_block_ids = [1, 2, 3]
|
||||
send_req_info.remote_block_ids = [4, 5]
|
||||
send_req_info.remote_cache_tokens = 100
|
||||
send_req_info.local_transferred_tokens = 50
|
||||
send_req_info.local_computed_tokens = 75
|
||||
send_req_info.request = MagicMock()
|
||||
send_req_info.extend_local_block_ids = MagicMock()
|
||||
send_req_info.update_computed_tokens = MagicMock()
|
||||
send_req_info.update_transferred_tokens = MagicMock()
|
||||
send_req_info.unpack = MagicMock(
|
||||
return_value=(send_req_info.local_block_ids,
|
||||
send_req_info.remote_block_ids,
|
||||
send_req_info.remote_cache_tokens,
|
||||
send_req_info.local_transferred_tokens,
|
||||
send_req_info.local_computed_tokens,
|
||||
send_req_info.request))
|
||||
|
||||
self.scheduler._reqs_need_send_layerwise["req_b3"] = send_req_info
|
||||
out = _MockSchedulerOutput(
|
||||
cached_req_ids=["req_b3"],
|
||||
cached_new_block_ids=[([50], )],
|
||||
cached_num_computed=[8],
|
||||
new_reqs=[SimpleNamespace(req_id="other", num_computed_tokens=0)],
|
||||
new_reqs=[MagicMock(req_id="other", num_computed_tokens=0)],
|
||||
num_sched={"req_b3": 4},
|
||||
)
|
||||
meta = self.scheduler.build_connector_meta(out)
|
||||
send_req_info.extend_local_block_ids.assert_called_once_with([50])
|
||||
self.assertIn("req_b3", meta.requests)
|
||||
rmeta = meta.requests["req_b3"]
|
||||
|
||||
self.assertEqual(rmeta.local_block_ids, [100, 101, 50])
|
||||
|
||||
self.assertNotIn("req_b3", self.scheduler._reqs_need_send_layerwise)
|
||||
|
||||
def test_request_finished_returns_false_none(self):
|
||||
ok, params = self.scheduler.request_finished(MockRequest("req_fin"),
|
||||
[1, 2])
|
||||
self.assertFalse(ok)
|
||||
self.assertIsNone(params)
|
||||
|
||||
|
||||
class TestHelperFunctions(unittest.TestCase):
|
||||
|
||||
Reference in New Issue
Block a user