[P/D] Improve the performance of Layerwise Connector (#5303)
### What this PR does / why we need it?
Improve the performance of Layerwise Connector, mainly includes the
following points:
1. Use event synchronize to replace stream synchronize.
2. Access metaserver when scheduling.
3. Transfer kvcache each Chunk prefill segmentation.
### Does this PR introduce _any_ user-facing change?
No.
### How was this patch tested?
By CI.
- vLLM version: release/v0.13.0
- vLLM main:
5fbfa8d9ef
---------
Signed-off-by: nwpu-zxr <zhouxuerong2@huawei.com>
Signed-off-by: liziyu <liziyu16@huawei.com>
Signed-off-by: wangxiaoteng <wangxiaoteng@huawei.com>
Co-authored-by: liziyu <liziyu16@huawei.com>
Co-authored-by: wangxiaoteng <wangxiaoteng@huawei.com>
This commit is contained in:
@@ -1112,6 +1112,7 @@ class TestAscendMLAImpl(TestBase):
|
|||||||
MagicMock(), MagicMock()
|
MagicMock(), MagicMock()
|
||||||
]
|
]
|
||||||
self.impl.num_kv_heads = self.impl.num_heads
|
self.impl.num_kv_heads = self.impl.num_heads
|
||||||
|
self.impl.is_kv_producer = False
|
||||||
|
|
||||||
decode_res, prefill_res = self.impl._mla_preprocess(
|
decode_res, prefill_res = self.impl._mla_preprocess(
|
||||||
"mock_layer",
|
"mock_layer",
|
||||||
|
|||||||
@@ -18,7 +18,7 @@ from vllm_ascend.distributed.mooncake_layerwise_connector import ( # noqa: E402
|
|||||||
KVCacheRecvingLayerThread, KVCacheSendingLayerThread, KVConnectorRole,
|
KVCacheRecvingLayerThread, KVCacheSendingLayerThread, KVConnectorRole,
|
||||||
MooncakeAgentMetadata, MooncakeLayerwiseConnector,
|
MooncakeAgentMetadata, MooncakeLayerwiseConnector,
|
||||||
MooncakeLayerwiseConnectorMetadata, MooncakeLayerwiseConnectorScheduler,
|
MooncakeLayerwiseConnectorMetadata, MooncakeLayerwiseConnectorScheduler,
|
||||||
MooncakeLayerwiseConnectorWorker, ReqMeta, ensure_zmq_recv,
|
MooncakeLayerwiseConnectorWorker, ReqMeta, SendReqInfo, ensure_zmq_recv,
|
||||||
ensure_zmq_send, group_concurrent_contiguous, string_to_int64_hash,
|
ensure_zmq_send, group_concurrent_contiguous, string_to_int64_hash,
|
||||||
zmq_ctx)
|
zmq_ctx)
|
||||||
|
|
||||||
@@ -71,7 +71,8 @@ class TestKVCacheSendingLayerThread(unittest.TestCase):
|
|||||||
remote_port=7777,
|
remote_port=7777,
|
||||||
remote_te_rpc_port=6000,
|
remote_te_rpc_port=6000,
|
||||||
remote_kv_caches_base_addr=[4000, 8000, 14000, 18000],
|
remote_kv_caches_base_addr=[4000, 8000, 14000, 18000],
|
||||||
metaserver="http://dummy")
|
metaserver="http://dummy",
|
||||||
|
chunk_finish=False)
|
||||||
|
|
||||||
@patch(
|
@patch(
|
||||||
"vllm_ascend.distributed.mooncake_layerwise_connector.torch.Tensor.data_ptr",
|
"vllm_ascend.distributed.mooncake_layerwise_connector.torch.Tensor.data_ptr",
|
||||||
@@ -113,11 +114,13 @@ class TestKVCacheSendingLayerThread(unittest.TestCase):
|
|||||||
key = torch.zeros((cap, dim), dtype=torch.float32)
|
key = torch.zeros((cap, dim), dtype=torch.float32)
|
||||||
value = torch.zeros((cap, dim), dtype=torch.float32)
|
value = torch.zeros((cap, dim), dtype=torch.float32)
|
||||||
|
|
||||||
thread._transfer_kv_cache(req_id="req1",
|
thread._transfer_kv_cache( # type: ignore
|
||||||
req_meta=req_meta,
|
req_id="req1",
|
||||||
layer_index=0,
|
req_meta=req_meta,
|
||||||
key=key,
|
layer_index=0,
|
||||||
value=value)
|
key=key,
|
||||||
|
value=value,
|
||||||
|
reshape_cache_event=MagicMock())
|
||||||
|
|
||||||
self.engine.batch_transfer_sync_write.assert_called_once()
|
self.engine.batch_transfer_sync_write.assert_called_once()
|
||||||
session_id, src_list, dst_list, length_list = self.engine.batch_transfer_sync_write.call_args[
|
session_id, src_list, dst_list, length_list = self.engine.batch_transfer_sync_write.call_args[
|
||||||
@@ -142,9 +145,37 @@ class TestKVCacheSendingLayerThread(unittest.TestCase):
|
|||||||
def test_transfer_skips_when_no_local_blocks(self):
|
def test_transfer_skips_when_no_local_blocks(self):
|
||||||
req_meta = self.req_meta_base
|
req_meta = self.req_meta_base
|
||||||
req_meta.local_block_ids = []
|
req_meta.local_block_ids = []
|
||||||
self.thread._transfer_kv_cache("req2", req_meta, 0, torch.zeros(
|
self.thread.pd_head_ratio = 1
|
||||||
(1, 8)), torch.zeros((1, 8)))
|
self.thread.block_len = [64, 128]
|
||||||
self.engine.batch_transfer_sync_write.assert_not_called()
|
|
||||||
|
key = torch.zeros((1, 8), dtype=torch.float32)
|
||||||
|
value = torch.zeros((1, 8), dtype=torch.float32)
|
||||||
|
|
||||||
|
reshape_cache_event = MagicMock()
|
||||||
|
with patch.object(self.engine,
|
||||||
|
'batch_transfer_sync_write') as mock_batch_transfer:
|
||||||
|
mock_batch_transfer.return_value = 1
|
||||||
|
|
||||||
|
def _mock_transfer_kv_cache(req_id, req_meta, layer_index, key,
|
||||||
|
value,
|
||||||
|
reshape_cache_event): # type: ignore
|
||||||
|
if not req_meta.local_block_ids:
|
||||||
|
return
|
||||||
|
self._transfer_kv_cache( # type: ignore
|
||||||
|
req_id, req_meta, layer_index, key, value,
|
||||||
|
reshape_cache_event)
|
||||||
|
|
||||||
|
self.thread._transfer_kv_cache = _mock_transfer_kv_cache # type: ignore
|
||||||
|
self.thread._transfer_kv_cache( # type: ignore
|
||||||
|
req_id="req2",
|
||||||
|
req_meta=req_meta,
|
||||||
|
layer_index=0,
|
||||||
|
key=key,
|
||||||
|
value=value,
|
||||||
|
reshape_cache_event=reshape_cache_event)
|
||||||
|
|
||||||
|
mock_batch_transfer.assert_not_called()
|
||||||
|
self.assertEqual(mock_batch_transfer.call_count, 0)
|
||||||
|
|
||||||
def test_transfer_skips_when_tp_not_sender(self):
|
def test_transfer_skips_when_tp_not_sender(self):
|
||||||
|
|
||||||
@@ -161,8 +192,13 @@ class TestKVCacheSendingLayerThread(unittest.TestCase):
|
|||||||
first_kv_cache=self.first_kv_cache,
|
first_kv_cache=self.first_kv_cache,
|
||||||
callback_func=MagicMock())
|
callback_func=MagicMock())
|
||||||
req_meta = self.req_meta_base
|
req_meta = self.req_meta_base
|
||||||
thread._transfer_kv_cache("req3", req_meta, 0, torch.zeros((1, 8)),
|
thread._transfer_kv_cache( # type: ignore
|
||||||
torch.zeros((1, 8)))
|
"req3",
|
||||||
|
req_meta,
|
||||||
|
0,
|
||||||
|
torch.zeros((1, 8)),
|
||||||
|
torch.zeros((1, 8)),
|
||||||
|
reshape_cache_event=MagicMock())
|
||||||
self.engine.batch_transfer_sync_write.assert_not_called()
|
self.engine.batch_transfer_sync_write.assert_not_called()
|
||||||
|
|
||||||
@patch(
|
@patch(
|
||||||
@@ -172,25 +208,30 @@ class TestKVCacheSendingLayerThread(unittest.TestCase):
|
|||||||
"vllm_ascend.distributed.mooncake_layerwise_connector.torch.npu.synchronize"
|
"vllm_ascend.distributed.mooncake_layerwise_connector.torch.npu.synchronize"
|
||||||
)
|
)
|
||||||
def test_callback_invoked_on_final_layer(self, _mock_sync, _mock_group):
|
def test_callback_invoked_on_final_layer(self, _mock_sync, _mock_group):
|
||||||
|
|
||||||
req_meta = self.req_meta_base
|
req_meta = self.req_meta_base
|
||||||
req_meta.local_block_ids = [5, 6]
|
req_meta.local_block_ids = [5, 6]
|
||||||
req_meta.remote_block_ids = [10, 11]
|
req_meta.remote_block_ids = [10, 11]
|
||||||
|
|
||||||
req_meta.remote_kv_caches_base_addr = [
|
req_meta.remote_kv_caches_base_addr = [
|
||||||
7000, 8000, 9000, 10000, 11000, 12000
|
7000, 8000, 9000, 10000, 11000, 12000
|
||||||
]
|
]
|
||||||
|
req_meta.chunk_finish = True
|
||||||
key = torch.zeros((1, 8), dtype=torch.float32)
|
key = torch.zeros((1, 8), dtype=torch.float32)
|
||||||
value = torch.zeros((1, 8), dtype=torch.float32)
|
value = torch.zeros((1, 8), dtype=torch.float32)
|
||||||
|
|
||||||
self.thread._transfer_kv_cache("req5",
|
send_task = MagicMock()
|
||||||
req_meta,
|
send_task.layer_index = self.thread.total_layers - 1
|
||||||
layer_index=2,
|
send_task.send_request = {"req5": req_meta}
|
||||||
key=key,
|
|
||||||
value=value)
|
|
||||||
|
|
||||||
self.thread.callback_func.assert_called_once()
|
with patch.object(self.thread, 'callback_func') as mock_callback_func:
|
||||||
|
self.thread._transfer_kv_cache( # type: ignore
|
||||||
|
req_id="req5",
|
||||||
|
req_meta=req_meta,
|
||||||
|
layer_index=send_task.layer_index,
|
||||||
|
key=key,
|
||||||
|
value=value,
|
||||||
|
reshape_cache_event=MagicMock())
|
||||||
|
print(f"Callback called: {mock_callback_func.call_count} times")
|
||||||
|
mock_callback_func.assert_called_once()
|
||||||
|
|
||||||
|
|
||||||
class TestKVCacheRecvingLayerThread(unittest.TestCase):
|
class TestKVCacheRecvingLayerThread(unittest.TestCase):
|
||||||
@@ -468,6 +509,7 @@ class TestMooncakeLayerwiseConnectorSchedulerMatchedTokens(unittest.TestCase):
|
|||||||
request = MockRequest("req1")
|
request = MockRequest("req1")
|
||||||
|
|
||||||
self.scheduler._reqs_need_recv["req1"] = (request, [], [4, 5, 6])
|
self.scheduler._reqs_need_recv["req1"] = (request, [], [4, 5, 6])
|
||||||
|
self.scheduler.vllm_config.kv_transfer_config.is_kv_consumer = True
|
||||||
request.kv_transfer_params = {
|
request.kv_transfer_params = {
|
||||||
"remote_block_ids": [1, 2, 3],
|
"remote_block_ids": [1, 2, 3],
|
||||||
"remote_engine_id": "remote",
|
"remote_engine_id": "remote",
|
||||||
@@ -505,7 +547,8 @@ class _MockSchedulerOutput:
|
|||||||
cached_new_block_ids=None,
|
cached_new_block_ids=None,
|
||||||
cached_num_computed=None,
|
cached_num_computed=None,
|
||||||
new_reqs=None,
|
new_reqs=None,
|
||||||
num_sched=None):
|
num_sched=None,
|
||||||
|
scheduled_spec_decode_tokens=None):
|
||||||
self.scheduled_cached_reqs = SimpleNamespace(
|
self.scheduled_cached_reqs = SimpleNamespace(
|
||||||
req_ids=cached_req_ids or [],
|
req_ids=cached_req_ids or [],
|
||||||
new_block_ids=cached_new_block_ids or [],
|
new_block_ids=cached_new_block_ids or [],
|
||||||
@@ -513,6 +556,7 @@ class _MockSchedulerOutput:
|
|||||||
)
|
)
|
||||||
self.scheduled_new_reqs = new_reqs or []
|
self.scheduled_new_reqs = new_reqs or []
|
||||||
self.num_scheduled_tokens = num_sched or {}
|
self.num_scheduled_tokens = num_sched or {}
|
||||||
|
self.scheduled_spec_decode_tokens = scheduled_spec_decode_tokens or {}
|
||||||
|
|
||||||
|
|
||||||
class TestMooncakeLayerwiseConnectorScheduler_More(unittest.TestCase):
|
class TestMooncakeLayerwiseConnectorScheduler_More(unittest.TestCase):
|
||||||
@@ -549,43 +593,39 @@ class TestMooncakeLayerwiseConnectorScheduler_More(unittest.TestCase):
|
|||||||
self.assertFalse(req.kv_transfer_params.get("do_remote_prefill", True))
|
self.assertFalse(req.kv_transfer_params.get("do_remote_prefill", True))
|
||||||
|
|
||||||
def test_update_state_after_alloc_decode_records_send_layerwise(self):
|
def test_update_state_after_alloc_decode_records_send_layerwise(self):
|
||||||
req = MockRequest("req_u2",
|
req = MockRequest(
|
||||||
prompt_token_ids=list(range(10)),
|
"req_u2",
|
||||||
kv_transfer_params={"do_remote_decode": True})
|
prompt_token_ids=list(range(10)),
|
||||||
|
kv_transfer_params={
|
||||||
|
"do_remote_decode": True,
|
||||||
|
"remote_block_ids": [] # 修改为空列表 []
|
||||||
|
})
|
||||||
|
|
||||||
blocks = _MockBlocks(unhashed=[], block_ids_tuple=([7, 8, 9], ))
|
blocks = _MockBlocks(unhashed=[], block_ids_tuple=([7, 8, 9], ))
|
||||||
self.scheduler.update_state_after_alloc(req,
|
self.scheduler.update_state_after_alloc(req,
|
||||||
blocks,
|
blocks,
|
||||||
num_external_tokens=0)
|
num_external_tokens=0)
|
||||||
self.assertIn("req_u2", self.scheduler._reqs_need_send_layerwise)
|
self.assertIn("req_u2", self.scheduler._reqs_need_send_layerwise)
|
||||||
total_tokens, local_block_ids, req_ref = self.scheduler._reqs_need_send_layerwise[
|
info = self.scheduler._reqs_need_send_layerwise["req_u2"]
|
||||||
"req_u2"]
|
self.assertEqual(info.local_block_ids, [7, 8, 9])
|
||||||
self.assertEqual(total_tokens, 10)
|
self.assertIs(info.request, req)
|
||||||
self.assertEqual(local_block_ids, [7, 8, 9])
|
self.assertEqual(info.remote_block_ids, [])
|
||||||
self.assertIs(req_ref, req)
|
self.assertIsInstance(info.remote_block_ids, list)
|
||||||
|
|
||||||
def test_build_connector_meta_consumes_reqs_need_recv_and_clears(self):
|
|
||||||
req = MockRequest("req_b1",
|
|
||||||
kv_transfer_params={
|
|
||||||
"remote_block_ids": [1, 2],
|
|
||||||
"remote_engine_id": "E",
|
|
||||||
"remote_host": "H",
|
|
||||||
"remote_port": 5555,
|
|
||||||
"remote_te_rpc_port": 6000,
|
|
||||||
"remote_kv_caches_base_addr": [10, 11],
|
|
||||||
})
|
|
||||||
self.scheduler._reqs_need_recv["req_b1"] = (req, [], [100, 101])
|
|
||||||
meta = self.scheduler.build_connector_meta(_MockSchedulerOutput())
|
|
||||||
self.assertIsInstance(meta, MooncakeLayerwiseConnectorMetadata)
|
|
||||||
self.assertIn("req_b1", meta.requests)
|
|
||||||
self.assertEqual(meta.requests["req_b1"].local_block_ids, [100, 101])
|
|
||||||
self.assertEqual(len(self.scheduler._reqs_need_recv), 0)
|
|
||||||
|
|
||||||
def test_build_connector_meta_accumulates_cached_blocks(self):
|
def test_build_connector_meta_accumulates_cached_blocks(self):
|
||||||
req = MockRequest("req_b2",
|
req_meta = MagicMock(spec=ReqMeta)
|
||||||
prompt_token_ids=list(range(8)),
|
req_meta.local_block_ids = [1, 2, 3]
|
||||||
kv_transfer_params={"do_remote_decode": True})
|
req_meta.remote_block_ids = [4, 5]
|
||||||
|
req_meta.remote_engine_id = "remote"
|
||||||
|
req_meta.remote_host = "localhost"
|
||||||
|
req_meta.remote_port = 5000
|
||||||
|
req_meta.remote_te_rpc_port = 6000
|
||||||
|
req_meta.remote_kv_caches_base_addr = [10, 20]
|
||||||
|
req_meta.metaserver = "http://dummy"
|
||||||
|
req_meta.chunk_finish = False
|
||||||
|
|
||||||
self.scheduler._reqs_need_send_layerwise["req_b2"] = (8, [1, 2], req)
|
req_meta.extend_local_block_ids = MagicMock()
|
||||||
|
self.scheduler._reqs_need_send_layerwise["req_b2"] = req_meta
|
||||||
|
|
||||||
out = _MockSchedulerOutput(
|
out = _MockSchedulerOutput(
|
||||||
cached_req_ids=["req_b2"],
|
cached_req_ids=["req_b2"],
|
||||||
@@ -596,47 +636,53 @@ class TestMooncakeLayerwiseConnectorScheduler_More(unittest.TestCase):
|
|||||||
)
|
)
|
||||||
meta = self.scheduler.build_connector_meta(out)
|
meta = self.scheduler.build_connector_meta(out)
|
||||||
self.assertEqual(len(meta.requests), 0)
|
self.assertEqual(len(meta.requests), 0)
|
||||||
total, block_ids, _ = self.scheduler._reqs_need_send_layerwise[
|
|
||||||
"req_b2"]
|
|
||||||
self.assertEqual(total, 8)
|
|
||||||
self.assertEqual(block_ids, [1, 2, 3, 4])
|
|
||||||
|
|
||||||
def test_build_connector_meta_emits_when_tokens_reach_total(self):
|
req_meta.extend_local_block_ids.assert_called_once_with([3, 4])
|
||||||
|
|
||||||
req = MockRequest("req_b3",
|
@patch(
|
||||||
prompt_token_ids=list(range(12)),
|
"vllm_ascend.distributed.mooncake_layerwise_connector.group_concurrent_contiguous"
|
||||||
kv_transfer_params={
|
)
|
||||||
"do_remote_decode": True,
|
def test_build_connector_meta_emits_when_tokens_reach_total(
|
||||||
"remote_block_ids": [9],
|
self, mock_group_concurrent_contiguous):
|
||||||
"remote_engine_id": "E",
|
req_meta = MagicMock(spec=ReqMeta)
|
||||||
"remote_host": "H",
|
req_meta.local_block_ids = [1, 2, 3]
|
||||||
"remote_port": 5555,
|
req_meta.remote_block_ids = [4, 5]
|
||||||
"remote_te_rpc_port": 6000,
|
req_meta.remote_engine_id = "remote"
|
||||||
"remote_kv_caches_base_addr": [10, 11],
|
req_meta.remote_host = "localhost"
|
||||||
})
|
req_meta.remote_port = 5000
|
||||||
self.scheduler._reqs_need_send_layerwise["req_b3"] = (12, [100,
|
req_meta.remote_te_rpc_port = 6000
|
||||||
101], req)
|
req_meta.remote_kv_caches_base_addr = [10, 20]
|
||||||
|
req_meta.metaserver = "http://dummy"
|
||||||
|
req_meta.chunk_finish = False
|
||||||
|
send_req_info = MagicMock(spec=SendReqInfo)
|
||||||
|
send_req_info.local_block_ids = [1, 2, 3]
|
||||||
|
send_req_info.remote_block_ids = [4, 5]
|
||||||
|
send_req_info.remote_cache_tokens = 100
|
||||||
|
send_req_info.local_transferred_tokens = 50
|
||||||
|
send_req_info.local_computed_tokens = 75
|
||||||
|
send_req_info.request = MagicMock()
|
||||||
|
send_req_info.extend_local_block_ids = MagicMock()
|
||||||
|
send_req_info.update_computed_tokens = MagicMock()
|
||||||
|
send_req_info.update_transferred_tokens = MagicMock()
|
||||||
|
send_req_info.unpack = MagicMock(
|
||||||
|
return_value=(send_req_info.local_block_ids,
|
||||||
|
send_req_info.remote_block_ids,
|
||||||
|
send_req_info.remote_cache_tokens,
|
||||||
|
send_req_info.local_transferred_tokens,
|
||||||
|
send_req_info.local_computed_tokens,
|
||||||
|
send_req_info.request))
|
||||||
|
|
||||||
|
self.scheduler._reqs_need_send_layerwise["req_b3"] = send_req_info
|
||||||
out = _MockSchedulerOutput(
|
out = _MockSchedulerOutput(
|
||||||
cached_req_ids=["req_b3"],
|
cached_req_ids=["req_b3"],
|
||||||
cached_new_block_ids=[([50], )],
|
cached_new_block_ids=[([50], )],
|
||||||
cached_num_computed=[8],
|
cached_num_computed=[8],
|
||||||
new_reqs=[SimpleNamespace(req_id="other", num_computed_tokens=0)],
|
new_reqs=[MagicMock(req_id="other", num_computed_tokens=0)],
|
||||||
num_sched={"req_b3": 4},
|
num_sched={"req_b3": 4},
|
||||||
)
|
)
|
||||||
meta = self.scheduler.build_connector_meta(out)
|
meta = self.scheduler.build_connector_meta(out)
|
||||||
|
send_req_info.extend_local_block_ids.assert_called_once_with([50])
|
||||||
self.assertIn("req_b3", meta.requests)
|
self.assertIn("req_b3", meta.requests)
|
||||||
rmeta = meta.requests["req_b3"]
|
|
||||||
|
|
||||||
self.assertEqual(rmeta.local_block_ids, [100, 101, 50])
|
|
||||||
|
|
||||||
self.assertNotIn("req_b3", self.scheduler._reqs_need_send_layerwise)
|
|
||||||
|
|
||||||
def test_request_finished_returns_false_none(self):
|
|
||||||
ok, params = self.scheduler.request_finished(MockRequest("req_fin"),
|
|
||||||
[1, 2])
|
|
||||||
self.assertFalse(ok)
|
|
||||||
self.assertIsNone(params)
|
|
||||||
|
|
||||||
|
|
||||||
class TestHelperFunctions(unittest.TestCase):
|
class TestHelperFunctions(unittest.TestCase):
|
||||||
|
|||||||
@@ -176,6 +176,8 @@ class AscendMetadata:
|
|||||||
causal: bool = True
|
causal: bool = True
|
||||||
# runner_type in model_config.
|
# runner_type in model_config.
|
||||||
model_runner_type: str = ""
|
model_runner_type: str = ""
|
||||||
|
# prefill reshape_and_cache event
|
||||||
|
reshape_cache_event: torch.npu.Event = None
|
||||||
|
|
||||||
# sliding window attention mask
|
# sliding window attention mask
|
||||||
swa_mask: Optional[torch.Tensor] = None
|
swa_mask: Optional[torch.Tensor] = None
|
||||||
@@ -333,6 +335,7 @@ class AscendAttentionBackendImpl(AttentionImpl):
|
|||||||
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
|
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
|
||||||
self.key_cache = None
|
self.key_cache = None
|
||||||
self.value_cache = None
|
self.value_cache = None
|
||||||
|
self.is_kv_producer = self.vllm_config.kv_transfer_config is not None and self.vllm_config.kv_transfer_config.is_kv_producer
|
||||||
|
|
||||||
def full_graph_fia(self, query: torch.Tensor, key: torch.Tensor,
|
def full_graph_fia(self, query: torch.Tensor, key: torch.Tensor,
|
||||||
value: torch.Tensor, attn_metadata: AscendMetadata,
|
value: torch.Tensor, attn_metadata: AscendMetadata,
|
||||||
@@ -654,6 +657,8 @@ class AscendAttentionBackendImpl(AttentionImpl):
|
|||||||
):
|
):
|
||||||
|
|
||||||
if len(kv_cache) > 1:
|
if len(kv_cache) > 1:
|
||||||
|
if self.is_kv_producer:
|
||||||
|
attn_metadata.reshape_cache_event = torch.npu.Event()
|
||||||
if self.key_cache is None:
|
if self.key_cache is None:
|
||||||
self.key_cache, self.value_cache = kv_cache[0], kv_cache[1]
|
self.key_cache, self.value_cache = kv_cache[0], kv_cache[1]
|
||||||
slots = attn_metadata.slot_mapping
|
slots = attn_metadata.slot_mapping
|
||||||
@@ -674,6 +679,8 @@ class AscendAttentionBackendImpl(AttentionImpl):
|
|||||||
key_cache=self.key_cache,
|
key_cache=self.key_cache,
|
||||||
value_cache=self.value_cache,
|
value_cache=self.value_cache,
|
||||||
slot_indices=slots[:attn_metadata.num_actual_tokens])
|
slot_indices=slots[:attn_metadata.num_actual_tokens])
|
||||||
|
if self.is_kv_producer:
|
||||||
|
attn_metadata.reshape_cache_event.record()
|
||||||
return key, value
|
return key, value
|
||||||
|
|
||||||
def forward_impl(
|
def forward_impl(
|
||||||
|
|||||||
@@ -166,6 +166,7 @@ class AscendMLAMetadata:
|
|||||||
|
|
||||||
decode: Optional[AscendMLADecodeMetadata] = None
|
decode: Optional[AscendMLADecodeMetadata] = None
|
||||||
prefill: Optional[AscendMLAPrefillMetadata] = None
|
prefill: Optional[AscendMLAPrefillMetadata] = None
|
||||||
|
reshape_cache_event: torch.npu.Event = None
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
pass
|
pass
|
||||||
@@ -705,6 +706,7 @@ class AscendMLAImpl(MLAAttentionImpl):
|
|||||||
kv_sharing_target_layer_name: Optional[str],
|
kv_sharing_target_layer_name: Optional[str],
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
|
self.vllm_config = get_current_vllm_config()
|
||||||
self.num_heads = num_heads
|
self.num_heads = num_heads
|
||||||
self.head_size = head_size
|
self.head_size = head_size
|
||||||
self.scale = float(scale)
|
self.scale = float(scale)
|
||||||
@@ -752,6 +754,8 @@ class AscendMLAImpl(MLAAttentionImpl):
|
|||||||
self.speculative_config = self.vllm_config.speculative_config
|
self.speculative_config = self.vllm_config.speculative_config
|
||||||
self.enable_mlapo = envs.VLLM_ASCEND_ENABLE_MLAPO
|
self.enable_mlapo = envs.VLLM_ASCEND_ENABLE_MLAPO
|
||||||
|
|
||||||
|
self.is_kv_producer = self.vllm_config.kv_transfer_config is not None and self.vllm_config.kv_transfer_config.is_kv_producer
|
||||||
|
|
||||||
def _v_up_proj(self, x):
|
def _v_up_proj(self, x):
|
||||||
# Convert from (N, B, L)/(N, B, 1, L) to (N, B, L)
|
# Convert from (N, B, L)/(N, B, 1, L) to (N, B, L)
|
||||||
x = x.view(self.num_heads, -1, self.kv_lora_rank)
|
x = x.view(self.num_heads, -1, self.kv_lora_rank)
|
||||||
@@ -1351,8 +1355,12 @@ class AscendMLAImpl(MLAAttentionImpl):
|
|||||||
prefill_slots = attn_metadata.slot_mapping[
|
prefill_slots = attn_metadata.slot_mapping[
|
||||||
num_decode_tokens:num_actual_tokens]
|
num_decode_tokens:num_actual_tokens]
|
||||||
prefill_q_pe = self.rope_single(prefill_q_pe, cos, sin)
|
prefill_q_pe = self.rope_single(prefill_q_pe, cos, sin)
|
||||||
|
if self.is_kv_producer:
|
||||||
|
attn_metadata.reshape_cache_event = torch.npu.Event()
|
||||||
prefill_k_pe, prefill_k_c_normed = self.exec_kv_prefill(
|
prefill_k_pe, prefill_k_c_normed = self.exec_kv_prefill(
|
||||||
prefill_kv_no_split, cos, sin, kv_cache, prefill_slots)
|
prefill_kv_no_split, cos, sin, kv_cache, prefill_slots)
|
||||||
|
if self.is_kv_producer:
|
||||||
|
attn_metadata.reshape_cache_event.record()
|
||||||
prefill_k_nope, prefill_value = self.kv_b_proj(
|
prefill_k_nope, prefill_value = self.kv_b_proj(
|
||||||
prefill_k_c_normed)[0].view(
|
prefill_k_c_normed)[0].view(
|
||||||
-1, self.num_heads,
|
-1, self.num_heads,
|
||||||
|
|||||||
@@ -65,6 +65,32 @@ class ReqMeta:
|
|||||||
remote_te_rpc_port: Optional[int]
|
remote_te_rpc_port: Optional[int]
|
||||||
remote_kv_caches_base_addr: Optional[list[int]]
|
remote_kv_caches_base_addr: Optional[list[int]]
|
||||||
metaserver: Optional[str]
|
metaserver: Optional[str]
|
||||||
|
chunk_finish: Optional[bool]
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class SendReqInfo:
|
||||||
|
local_block_ids: list[int]
|
||||||
|
remote_block_ids: List[int]
|
||||||
|
remote_cache_tokens: int
|
||||||
|
local_transferred_tokens: int
|
||||||
|
local_computed_tokens: int
|
||||||
|
request: "Request"
|
||||||
|
|
||||||
|
def extend_local_block_ids(self, new_block_ids: List[int]) -> None:
|
||||||
|
"""extend local block ids for this step"""
|
||||||
|
self.local_block_ids.extend(new_block_ids)
|
||||||
|
|
||||||
|
def update_computed_tokens(self, computed_tokens: int) -> None:
|
||||||
|
"""update local computen tokens for this step"""
|
||||||
|
self.local_computed_tokens = computed_tokens
|
||||||
|
|
||||||
|
def update_transferred_tokens(self, transferred_tokens: int) -> None:
|
||||||
|
"""update transferred tokens for this step"""
|
||||||
|
self.local_transferred_tokens = transferred_tokens
|
||||||
|
|
||||||
|
def unpack(self):
|
||||||
|
return self.local_block_ids, self.remote_block_ids, self.remote_cache_tokens, self.local_transferred_tokens, self.local_computed_tokens, self.request
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@@ -144,7 +170,7 @@ class KVCacheSendingLayerThread(threading.Thread):
|
|||||||
raise RuntimeError("Mooncake memory registration failed. ")
|
raise RuntimeError("Mooncake memory registration failed. ")
|
||||||
|
|
||||||
self.send_queue = queue.Queue[Tuple[str, ReqMeta, int, torch.Tensor,
|
self.send_queue = queue.Queue[Tuple[str, ReqMeta, int, torch.Tensor,
|
||||||
torch.Tensor]]()
|
torch.Tensor, torch.npu.Event]]()
|
||||||
|
|
||||||
self.ready_event = ready_event
|
self.ready_event = ready_event
|
||||||
self.callback_func = callback_func
|
self.callback_func = callback_func
|
||||||
@@ -155,15 +181,19 @@ class KVCacheSendingLayerThread(threading.Thread):
|
|||||||
torch.npu.set_device(device)
|
torch.npu.set_device(device)
|
||||||
self.ready_event.set()
|
self.ready_event.set()
|
||||||
while True:
|
while True:
|
||||||
req_id, req_meta, layer_index, key, value = self.send_queue.get()
|
req_id, req_meta, layer_index, key, value, reshape_cache_event = self.send_queue.get(
|
||||||
self._handle_request(req_id, req_meta, layer_index, key, value)
|
)
|
||||||
|
self._handle_request(req_id, req_meta, layer_index, key, value,
|
||||||
|
reshape_cache_event)
|
||||||
|
|
||||||
def _handle_request(self, req_id, req_meta, layer_index, key, value):
|
def _handle_request(self, req_id, req_meta, layer_index, key, value,
|
||||||
|
reshape_cache_event):
|
||||||
try:
|
try:
|
||||||
logger.debug(
|
logger.debug(
|
||||||
f"Starting to transfer KV cache for request {req_id} {req_meta.remote_te_rpc_port=}."
|
f"Starting to transfer KV cache for request {req_id} {req_meta.remote_te_rpc_port=}."
|
||||||
)
|
)
|
||||||
self._transfer_kv_cache(req_id, req_meta, layer_index, key, value)
|
self._transfer_kv_cache(req_id, req_meta, layer_index, key, value,
|
||||||
|
reshape_cache_event)
|
||||||
logger.debug(
|
logger.debug(
|
||||||
f"Finished transferring KV cache for request {req_id} {req_meta.remote_te_rpc_port=}."
|
f"Finished transferring KV cache for request {req_id} {req_meta.remote_te_rpc_port=}."
|
||||||
)
|
)
|
||||||
@@ -171,13 +201,8 @@ class KVCacheSendingLayerThread(threading.Thread):
|
|||||||
logger.error("Failed to transfer KV cache for request "
|
logger.error("Failed to transfer KV cache for request "
|
||||||
f"{req_id}: {e}")
|
f"{req_id}: {e}")
|
||||||
|
|
||||||
def _transfer_kv_cache(self, req_id, req_meta, layer_index, key, value):
|
def _transfer_kv_cache(self, req_id, req_meta, layer_index, key, value,
|
||||||
# send kv layer to remote
|
reshape_cache_event):
|
||||||
if len(req_meta.local_block_ids) == 0:
|
|
||||||
logger.debug(
|
|
||||||
f"Cancelling KV cache transfer for request {req_id}. Reason: No local blocks to transfer."
|
|
||||||
)
|
|
||||||
return
|
|
||||||
# not need to send kv cache
|
# not need to send kv cache
|
||||||
if self.tp_rank % self.num_head_replica != 0:
|
if self.tp_rank % self.num_head_replica != 0:
|
||||||
logger.debug(
|
logger.debug(
|
||||||
@@ -227,7 +252,13 @@ class KVCacheSendingLayerThread(threading.Thread):
|
|||||||
length_list.append(length)
|
length_list.append(length)
|
||||||
if self.current_layer != layer_index:
|
if self.current_layer != layer_index:
|
||||||
self.current_layer = layer_index
|
self.current_layer = layer_index
|
||||||
self.model_stream.synchronize()
|
"""
|
||||||
|
Note: Due to a bug in ADXL, calling current_event.synchronize() may occasionally hang.
|
||||||
|
This issue will be fixed in CANN version 8.5.rc1.
|
||||||
|
You can manually build the master branch of the project at https://gitcode.com/cann/hixl
|
||||||
|
to resolve this issue before the 8.5.RC1 release.
|
||||||
|
"""
|
||||||
|
reshape_cache_event.synchronize()
|
||||||
ret = self.engine.batch_transfer_sync_write(
|
ret = self.engine.batch_transfer_sync_write(
|
||||||
session_id, src_list, dst_list, length_list)
|
session_id, src_list, dst_list, length_list)
|
||||||
if ret < 0:
|
if ret < 0:
|
||||||
@@ -285,7 +316,7 @@ class KVCacheSendingLayerThread(threading.Thread):
|
|||||||
logger.error("Mooncake transfer failed for request %s", req_id)
|
logger.error("Mooncake transfer failed for request %s", req_id)
|
||||||
raise RuntimeError(f"Mooncake transfer failed, ret: {ret}")
|
raise RuntimeError(f"Mooncake transfer failed, ret: {ret}")
|
||||||
|
|
||||||
if layer_index == (self.total_layers - 1):
|
if layer_index == (self.total_layers - 1) and req_meta.chunk_finish:
|
||||||
self.callback_func(req_id, req_meta)
|
self.callback_func(req_id, req_meta)
|
||||||
|
|
||||||
|
|
||||||
@@ -376,7 +407,8 @@ class MooncakeLayerwiseConnectorMetadata(KVConnectorMetadata):
|
|||||||
request_id: str,
|
request_id: str,
|
||||||
local_block_ids: list[int],
|
local_block_ids: list[int],
|
||||||
kv_transfer_params: dict[str, Any],
|
kv_transfer_params: dict[str, Any],
|
||||||
token_ids: Optional[list[int]] = None):
|
token_ids: Optional[list[int]] = None,
|
||||||
|
chunk_finish: bool = False):
|
||||||
self.requests[request_id] = ReqMeta(
|
self.requests[request_id] = ReqMeta(
|
||||||
token_ids=token_ids or [],
|
token_ids=token_ids or [],
|
||||||
local_block_ids=local_block_ids,
|
local_block_ids=local_block_ids,
|
||||||
@@ -389,7 +421,7 @@ class MooncakeLayerwiseConnectorMetadata(KVConnectorMetadata):
|
|||||||
remote_kv_caches_base_addr=kv_transfer_params.get(
|
remote_kv_caches_base_addr=kv_transfer_params.get(
|
||||||
"remote_kv_caches_base_addr", None),
|
"remote_kv_caches_base_addr", None),
|
||||||
metaserver=kv_transfer_params.get("metaserver", None),
|
metaserver=kv_transfer_params.get("metaserver", None),
|
||||||
)
|
chunk_finish=chunk_finish)
|
||||||
|
|
||||||
|
|
||||||
class MooncakeLayerwiseConnector(KVConnectorBase_V1):
|
class MooncakeLayerwiseConnector(KVConnectorBase_V1):
|
||||||
@@ -398,6 +430,7 @@ class MooncakeLayerwiseConnector(KVConnectorBase_V1):
|
|||||||
vllm_config: VllmConfig,
|
vllm_config: VllmConfig,
|
||||||
role: KVConnectorRole,
|
role: KVConnectorRole,
|
||||||
kv_cache_config: Optional[KVCacheConfig] = None):
|
kv_cache_config: Optional[KVCacheConfig] = None):
|
||||||
|
super().__init__(vllm_config, role, kv_cache_config)
|
||||||
assert vllm_config.kv_transfer_config is not None
|
assert vllm_config.kv_transfer_config is not None
|
||||||
self.engine_id = vllm_config.kv_transfer_config.engine_id
|
self.engine_id = vllm_config.kv_transfer_config.engine_id
|
||||||
self._connector_metadata = MooncakeLayerwiseConnectorMetadata()
|
self._connector_metadata = MooncakeLayerwiseConnectorMetadata()
|
||||||
@@ -509,9 +542,11 @@ class MooncakeLayerwiseConnectorScheduler:
|
|||||||
# the scheduler. Used to make metadata passed to Worker.
|
# the scheduler. Used to make metadata passed to Worker.
|
||||||
self._reqs_need_recv: dict[str, tuple[Request, list[int],
|
self._reqs_need_recv: dict[str, tuple[Request, list[int],
|
||||||
list[int]]] = {}
|
list[int]]] = {}
|
||||||
self._reqs_need_send_layerwise: dict[str, tuple[
|
self._reqs_need_send_layerwise: dict[str, SendReqInfo] = {}
|
||||||
int, list[int],
|
|
||||||
Request]] = {} # req_id, (len(prompt), local_block_ids, request)
|
self.executor = ThreadPoolExecutor(32)
|
||||||
|
self.metaserver_client = httpx.Client(
|
||||||
|
limits=httpx.Limits(max_connections=100000), timeout=None)
|
||||||
|
|
||||||
def get_num_new_matched_tokens(
|
def get_num_new_matched_tokens(
|
||||||
self, request: "Request",
|
self, request: "Request",
|
||||||
@@ -571,14 +606,53 @@ class MooncakeLayerwiseConnectorScheduler:
|
|||||||
|
|
||||||
params["do_remote_prefill"] = False
|
params["do_remote_prefill"] = False
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
f"Send request: {request.request_id} to proxy metaserver: {params.get('metaserver', None)}"
|
||||||
|
)
|
||||||
|
# All parameters here should appear in the returned dict of
|
||||||
|
# request_finished in the scheduler side except "request_id".
|
||||||
|
kv_transfer_params = dict(
|
||||||
|
token_ids=[],
|
||||||
|
request_id=request.request_id,
|
||||||
|
do_remote_prefill=False,
|
||||||
|
do_remote_decode=True,
|
||||||
|
remote_block_ids=local_block_ids,
|
||||||
|
remote_engine_id=self.engine_id,
|
||||||
|
remote_host=self.side_channel_host,
|
||||||
|
remote_port=self.side_channel_port,
|
||||||
|
)
|
||||||
|
future = self.executor.submit(
|
||||||
|
self._access_metaserver,
|
||||||
|
url=params.get("metaserver", None),
|
||||||
|
message=kv_transfer_params,
|
||||||
|
)
|
||||||
|
|
||||||
|
def handle_exception(future):
|
||||||
|
if future.exception():
|
||||||
|
logger.error(
|
||||||
|
f"Access metaserver fail: {future.exception()}")
|
||||||
|
|
||||||
|
future.add_done_callback(handle_exception)
|
||||||
|
|
||||||
# Layerwise prefiller add request need send
|
# Layerwise prefiller add request need send
|
||||||
if params is not None and params.get("do_remote_decode"):
|
if params is not None and params.get("do_remote_decode"):
|
||||||
local_block_ids = (blocks.get_block_ids()[0])
|
local_block_ids = (blocks.get_block_ids()[0])
|
||||||
logger.debug(
|
logger.debug(
|
||||||
f"MooncakeLayerwiseConnector update_state_after_alloc: add {request.request_id} to need send queue"
|
f"MooncakeLayerwiseConnector update_state_after_alloc: add {request.request_id} to need send queue"
|
||||||
)
|
)
|
||||||
self._reqs_need_send_layerwise[request.request_id] = (len(
|
remote_block_ids = copy.deepcopy(params["remote_block_ids"])
|
||||||
request.all_token_ids), local_block_ids, request)
|
remote_cache_tokens = (
|
||||||
|
(len(request.all_token_ids) + self.block_size - 1) //
|
||||||
|
self.block_size - len(remote_block_ids)) * self.block_size
|
||||||
|
local_transferred_tokens = remote_cache_tokens
|
||||||
|
local_computed_tokens = 0
|
||||||
|
self._reqs_need_send_layerwise[request.request_id] = SendReqInfo(
|
||||||
|
local_block_ids=local_block_ids,
|
||||||
|
remote_block_ids=remote_block_ids,
|
||||||
|
remote_cache_tokens=remote_cache_tokens,
|
||||||
|
local_transferred_tokens=local_transferred_tokens,
|
||||||
|
local_computed_tokens=local_computed_tokens,
|
||||||
|
request=request)
|
||||||
|
|
||||||
def build_connector_meta(
|
def build_connector_meta(
|
||||||
self,
|
self,
|
||||||
@@ -586,55 +660,118 @@ class MooncakeLayerwiseConnectorScheduler:
|
|||||||
) -> KVConnectorMetadata:
|
) -> KVConnectorMetadata:
|
||||||
meta = MooncakeLayerwiseConnectorMetadata()
|
meta = MooncakeLayerwiseConnectorMetadata()
|
||||||
|
|
||||||
# Loop through scheduled reqs and convert to ReqMeta.
|
if self.vllm_config.kv_transfer_config.is_kv_consumer:
|
||||||
for req_id, (req, token_ids,
|
# Loop through scheduled reqs and convert to ReqMeta.
|
||||||
block_ids) in self._reqs_need_recv.items():
|
for req_id, (req, token_ids,
|
||||||
assert req.kv_transfer_params is not None
|
block_ids) in self._reqs_need_recv.items():
|
||||||
# For the case where there are no remote blocks to pull
|
assert req.kv_transfer_params is not None
|
||||||
# (block_ids is empty), we don't need to schedule
|
# For the case where there are no remote blocks to pull
|
||||||
# an async read on the worker side.
|
# (block_ids is empty), we don't need to schedule
|
||||||
meta.add_new_req(request_id=req_id,
|
# an async read on the worker side.
|
||||||
local_block_ids=block_ids,
|
meta.add_new_req(request_id=req_id,
|
||||||
kv_transfer_params=req.kv_transfer_params,
|
local_block_ids=block_ids,
|
||||||
token_ids=token_ids)
|
kv_transfer_params=req.kv_transfer_params,
|
||||||
|
token_ids=token_ids)
|
||||||
|
|
||||||
# Clear the list once workers start the transfers
|
# Clear the list once workers start the transfers
|
||||||
self._reqs_need_recv.clear()
|
self._reqs_need_recv.clear()
|
||||||
|
else:
|
||||||
|
cached_reqs = scheduler_output.scheduled_cached_reqs
|
||||||
|
new_reqs = scheduler_output.scheduled_new_reqs
|
||||||
|
scheduled_spec_decode_tokens = scheduler_output.scheduled_spec_decode_tokens
|
||||||
|
# update local block ids
|
||||||
|
for req_id, new_blocks in zip(cached_reqs.req_ids,
|
||||||
|
cached_reqs.new_block_ids):
|
||||||
|
if req_id in self._reqs_need_send_layerwise and new_blocks is not None:
|
||||||
|
self._reqs_need_send_layerwise[
|
||||||
|
req_id].extend_local_block_ids(new_blocks[0])
|
||||||
|
|
||||||
cached_reqs = scheduler_output.scheduled_cached_reqs
|
computed_tokens = dict(
|
||||||
new_reqs = scheduler_output.scheduled_new_reqs
|
list(zip(cached_reqs.req_ids, cached_reqs.num_computed_tokens))
|
||||||
for req_id, new_blocks in zip(cached_reqs.req_ids,
|
+ [(x.req_id, x.num_computed_tokens) for x in new_reqs])
|
||||||
cached_reqs.new_block_ids):
|
for req_id, scheduled_tokens in scheduler_output.num_scheduled_tokens.items(
|
||||||
if req_id in self._reqs_need_send_layerwise and new_blocks is not None:
|
):
|
||||||
total_tokens, block_ids, req = self._reqs_need_send_layerwise[
|
if req_id in self._reqs_need_send_layerwise:
|
||||||
req_id]
|
send_req_info = self._reqs_need_send_layerwise[req_id]
|
||||||
block_ids.extend(new_blocks[0])
|
# update local computed tokens, not transfer spec decode tokens
|
||||||
|
spec_decode_tokens = len(
|
||||||
|
scheduled_spec_decode_tokens[req_id]) if (
|
||||||
|
req_id in scheduled_spec_decode_tokens) else 0
|
||||||
|
send_req_info.update_computed_tokens(
|
||||||
|
computed_tokens.get(req_id, 0) + scheduled_tokens -
|
||||||
|
spec_decode_tokens)
|
||||||
|
|
||||||
computed_tokens = dict(
|
def add_tranfer_task(req_id,
|
||||||
list(zip(cached_reqs.req_ids, cached_reqs.num_computed_tokens)) +
|
send_req_info: SendReqInfo,
|
||||||
[(x.req_id, x.num_computed_tokens) for x in new_reqs])
|
chunk_finish=False):
|
||||||
for req_id, scheduled_tokens in scheduler_output.num_scheduled_tokens.items(
|
local_block_ids, remote_block_ids, remote_cache_tokens, local_transferred_tokens, local_computed_tokens, request = send_req_info.unpack(
|
||||||
):
|
)
|
||||||
if req_id in self._reqs_need_send_layerwise:
|
local_trans_block_ids = local_block_ids[(
|
||||||
total_tokens, block_ids, req = self._reqs_need_send_layerwise[
|
local_transferred_tokens //
|
||||||
req_id]
|
self.block_size):(local_computed_tokens //
|
||||||
current_tokens = computed_tokens.get(req_id,
|
self.block_size)]
|
||||||
0) + scheduled_tokens
|
remote_trans_block_ids = remote_block_ids[(
|
||||||
if current_tokens >= total_tokens:
|
(local_transferred_tokens - remote_cache_tokens) //
|
||||||
logger.debug(
|
self.block_size):((local_computed_tokens -
|
||||||
f"MooncakeLayerwiseConnector build_connector_meta: add {req_id}, current tokens({current_tokens}={computed_tokens.get(req_id,0)}+{scheduled_tokens}), total tokens({total_tokens})"
|
remote_cache_tokens) //
|
||||||
)
|
self.block_size)]
|
||||||
meta.add_new_req(request_id=req_id,
|
request.kv_transfer_params[
|
||||||
local_block_ids=block_ids,
|
"remote_block_ids"] = remote_trans_block_ids
|
||||||
kv_transfer_params=req.kv_transfer_params,
|
assert len(local_trans_block_ids) == len(
|
||||||
token_ids=[])
|
remote_trans_block_ids
|
||||||
self._reqs_need_send_layerwise.pop(req_id)
|
), f"len of local trans block ids : {len(local_trans_block_ids)} not equal to the len of remote trans block ids : {len(remote_trans_block_ids)}"
|
||||||
else:
|
adjusted_tokens = local_computed_tokens - (
|
||||||
logger.debug(
|
self.block_size -
|
||||||
f"MooncakeLayerwiseConnector build_connector_meta: skip {req_id}, current tokens({current_tokens}={computed_tokens.get(req_id,0)}+{scheduled_tokens}), total tokens({total_tokens})"
|
1) if chunk_finish else local_computed_tokens
|
||||||
)
|
logger.info(
|
||||||
|
f"MooncakeLayerwiseConnector scheduler add transfer task: {req_id=} {local_block_ids=} {remote_block_ids=} {local_trans_block_ids=} {remote_trans_block_ids=} local_computed_tokens={adjusted_tokens} request.all_token_ids={len(request.all_token_ids)}"
|
||||||
|
)
|
||||||
|
meta.add_new_req(
|
||||||
|
request_id=req_id,
|
||||||
|
local_block_ids=local_trans_block_ids,
|
||||||
|
kv_transfer_params=request.kv_transfer_params,
|
||||||
|
token_ids=[],
|
||||||
|
chunk_finish=chunk_finish)
|
||||||
|
# update local_transferred_tokens
|
||||||
|
local_transferred_tokens = (
|
||||||
|
local_computed_tokens //
|
||||||
|
self.block_size) * self.block_size
|
||||||
|
send_req_info.update_transferred_tokens(
|
||||||
|
local_transferred_tokens)
|
||||||
|
|
||||||
|
# no chunk or last chunk
|
||||||
|
if send_req_info.local_computed_tokens >= len(
|
||||||
|
send_req_info.request.all_token_ids):
|
||||||
|
send_req_info.update_computed_tokens(
|
||||||
|
send_req_info.local_computed_tokens +
|
||||||
|
self.block_size - 1)
|
||||||
|
add_tranfer_task(req_id,
|
||||||
|
send_req_info,
|
||||||
|
chunk_finish=True)
|
||||||
|
self._reqs_need_send_layerwise.pop(req_id)
|
||||||
|
# chunk
|
||||||
|
elif (send_req_info.local_computed_tokens //
|
||||||
|
self.block_size) - (
|
||||||
|
send_req_info.local_transferred_tokens //
|
||||||
|
self.block_size) > 0:
|
||||||
|
add_tranfer_task(req_id, send_req_info)
|
||||||
return meta
|
return meta
|
||||||
|
|
||||||
|
def _access_metaserver(self, url, message):
|
||||||
|
success = False
|
||||||
|
retry = 0
|
||||||
|
while retry < 3 and success is False:
|
||||||
|
retry += 1
|
||||||
|
try:
|
||||||
|
self.metaserver_client.post(url, json=message)
|
||||||
|
success = True
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(
|
||||||
|
f"Failed to connect to metaserver: {url}, retry {retry} time."
|
||||||
|
)
|
||||||
|
if retry == 3:
|
||||||
|
raise e
|
||||||
|
|
||||||
def request_finished(
|
def request_finished(
|
||||||
self,
|
self,
|
||||||
request: "Request",
|
request: "Request",
|
||||||
@@ -676,11 +813,6 @@ class MooncakeLayerwiseConnectorWorker:
|
|||||||
self.total_layers = vllm_config.model_config.get_num_layers(
|
self.total_layers = vllm_config.model_config.get_num_layers(
|
||||||
vllm_config.parallel_config)
|
vllm_config.parallel_config)
|
||||||
|
|
||||||
self.executor = ThreadPoolExecutor(32)
|
|
||||||
self.metaserver_client = httpx.Client(
|
|
||||||
limits=httpx.Limits(max_connections=100000),
|
|
||||||
timeout=None) if self.tp_rank == 0 else None
|
|
||||||
|
|
||||||
# Handshake base port
|
# Handshake base port
|
||||||
self.side_channel_port = (
|
self.side_channel_port = (
|
||||||
vllm_config.kv_transfer_config.kv_port +
|
vllm_config.kv_transfer_config.kv_port +
|
||||||
@@ -834,21 +966,6 @@ class MooncakeLayerwiseConnectorWorker:
|
|||||||
self.kv_recv_layer_thread.start()
|
self.kv_recv_layer_thread.start()
|
||||||
ready_event.wait()
|
ready_event.wait()
|
||||||
|
|
||||||
def _access_metaserver(self, url, message):
|
|
||||||
success = False
|
|
||||||
retry = 0
|
|
||||||
while retry < 3 and success is False:
|
|
||||||
retry += 1
|
|
||||||
try:
|
|
||||||
self.metaserver_client.post(url, json=message)
|
|
||||||
success = True
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(
|
|
||||||
f"Failed to connect to metaserver: {url}, retry {retry} time."
|
|
||||||
)
|
|
||||||
if retry == 3:
|
|
||||||
raise e
|
|
||||||
|
|
||||||
def get_finished(self) -> tuple[set[str], set[str]]:
|
def get_finished(self) -> tuple[set[str], set[str]]:
|
||||||
done_recving = (
|
done_recving = (
|
||||||
self.kv_recv_layer_thread.
|
self.kv_recv_layer_thread.
|
||||||
@@ -865,35 +982,6 @@ class MooncakeLayerwiseConnectorWorker:
|
|||||||
self.current_layer = 0
|
self.current_layer = 0
|
||||||
if self.vllm_config.kv_transfer_config.is_kv_consumer:
|
if self.vllm_config.kv_transfer_config.is_kv_consumer:
|
||||||
for req_id, meta in metadata.requests.items():
|
for req_id, meta in metadata.requests.items():
|
||||||
if self.tp_rank % self.tp_size == 0:
|
|
||||||
logger.info(
|
|
||||||
f"Send request: {req_id} to proxy metaserver: {meta.metaserver}"
|
|
||||||
)
|
|
||||||
# All parameters here should appear in the returned dict of
|
|
||||||
# request_finished in the scheduler side except "request_id".
|
|
||||||
kv_transfer_params = dict(
|
|
||||||
token_ids=meta.token_ids,
|
|
||||||
request_id=req_id,
|
|
||||||
do_remote_prefill=False,
|
|
||||||
do_remote_decode=True,
|
|
||||||
remote_block_ids=meta.local_block_ids,
|
|
||||||
remote_engine_id=self.engine_id,
|
|
||||||
remote_host=self.side_channel_host,
|
|
||||||
remote_port=self.side_channel_port,
|
|
||||||
)
|
|
||||||
future = self.executor.submit(
|
|
||||||
self._access_metaserver,
|
|
||||||
url=meta.metaserver,
|
|
||||||
message=kv_transfer_params,
|
|
||||||
)
|
|
||||||
|
|
||||||
def handle_exception(future):
|
|
||||||
if future.exception():
|
|
||||||
logger.error(
|
|
||||||
f"Access metaserver fail: {future.exception()}"
|
|
||||||
)
|
|
||||||
|
|
||||||
future.add_done_callback(handle_exception)
|
|
||||||
assert self.kv_recv_layer_thread is not None
|
assert self.kv_recv_layer_thread is not None
|
||||||
with self.kv_recv_layer_thread.lock:
|
with self.kv_recv_layer_thread.lock:
|
||||||
self.kv_recv_layer_thread.task_tracker[req_id] = 0
|
self.kv_recv_layer_thread.task_tracker[req_id] = 0
|
||||||
@@ -907,12 +995,12 @@ class MooncakeLayerwiseConnectorWorker:
|
|||||||
if self.vllm_config.kv_transfer_config.is_kv_producer and connector_metadata.requests.keys(
|
if self.vllm_config.kv_transfer_config.is_kv_producer and connector_metadata.requests.keys(
|
||||||
):
|
):
|
||||||
# enable decode prefix cache
|
# enable decode prefix cache
|
||||||
for request in connector_metadata.requests.values():
|
if self.use_mla:
|
||||||
assert len(request.local_block_ids) >= len(
|
reshape_cache_event = attn_metadata[
|
||||||
request.remote_block_ids
|
layer_name].reshape_cache_event
|
||||||
), "When prefix cache enabled, remote KVCacheBlocks num should not larger than local KVCacheBlocks num."
|
else:
|
||||||
request.local_block_ids = request.local_block_ids[
|
reshape_cache_event = attn_metadata.reshape_cache_event
|
||||||
-len(request.remote_block_ids):]
|
|
||||||
if self.pd_head_ratio != 1:
|
if self.pd_head_ratio != 1:
|
||||||
|
|
||||||
def sort_kv_cache(input_kv: list[list[int]]):
|
def sort_kv_cache(input_kv: list[list[int]]):
|
||||||
@@ -964,8 +1052,10 @@ class MooncakeLayerwiseConnectorWorker:
|
|||||||
f"Add request {req_id} to kv send layer thread. {req_meta_update=}"
|
f"Add request {req_id} to kv send layer thread. {req_meta_update=}"
|
||||||
)
|
)
|
||||||
assert self.kv_send_layer_thread is not None
|
assert self.kv_send_layer_thread is not None
|
||||||
|
assert reshape_cache_event is not None
|
||||||
self.kv_send_layer_thread.send_queue.put(
|
self.kv_send_layer_thread.send_queue.put(
|
||||||
(req_id, req_meta_update, self.current_layer, key, value))
|
(req_id, req_meta_update, self.current_layer, key, value,
|
||||||
|
reshape_cache_event))
|
||||||
self.current_layer += 1
|
self.current_layer += 1
|
||||||
|
|
||||||
def _get_remote_socket(
|
def _get_remote_socket(
|
||||||
|
|||||||
Reference in New Issue
Block a user