[P/D] Improve the performance of Layerwise Connector (#5303)

### What this PR does / why we need it?
Improve the performance of the Layerwise Connector. The main changes are:
1. Use event synchronization instead of stream synchronization.
2. Access the metaserver during scheduling.
3. Transfer the KV cache for each chunked-prefill segment.

### Does this PR introduce _any_ user-facing change?
No.

### How was this patch tested?
By CI.
- vLLM version: release/v0.13.0
- vLLM main:
5fbfa8d9ef

---------

Signed-off-by: nwpu-zxr <zhouxuerong2@huawei.com>
Signed-off-by: liziyu <liziyu16@huawei.com>
Signed-off-by: wangxiaoteng <wangxiaoteng@huawei.com>
Co-authored-by: liziyu <liziyu16@huawei.com>
Co-authored-by: wangxiaoteng <wangxiaoteng@huawei.com>
This commit is contained in:
zxr2333
2025-12-31 15:09:01 +08:00
committed by GitHub
parent 7d5242faca
commit 46a1614387
5 changed files with 354 additions and 202 deletions

View File

@@ -1112,6 +1112,7 @@ class TestAscendMLAImpl(TestBase):
MagicMock(), MagicMock() MagicMock(), MagicMock()
] ]
self.impl.num_kv_heads = self.impl.num_heads self.impl.num_kv_heads = self.impl.num_heads
self.impl.is_kv_producer = False
decode_res, prefill_res = self.impl._mla_preprocess( decode_res, prefill_res = self.impl._mla_preprocess(
"mock_layer", "mock_layer",

View File

@@ -18,7 +18,7 @@ from vllm_ascend.distributed.mooncake_layerwise_connector import ( # noqa: E402
KVCacheRecvingLayerThread, KVCacheSendingLayerThread, KVConnectorRole, KVCacheRecvingLayerThread, KVCacheSendingLayerThread, KVConnectorRole,
MooncakeAgentMetadata, MooncakeLayerwiseConnector, MooncakeAgentMetadata, MooncakeLayerwiseConnector,
MooncakeLayerwiseConnectorMetadata, MooncakeLayerwiseConnectorScheduler, MooncakeLayerwiseConnectorMetadata, MooncakeLayerwiseConnectorScheduler,
MooncakeLayerwiseConnectorWorker, ReqMeta, ensure_zmq_recv, MooncakeLayerwiseConnectorWorker, ReqMeta, SendReqInfo, ensure_zmq_recv,
ensure_zmq_send, group_concurrent_contiguous, string_to_int64_hash, ensure_zmq_send, group_concurrent_contiguous, string_to_int64_hash,
zmq_ctx) zmq_ctx)
@@ -71,7 +71,8 @@ class TestKVCacheSendingLayerThread(unittest.TestCase):
remote_port=7777, remote_port=7777,
remote_te_rpc_port=6000, remote_te_rpc_port=6000,
remote_kv_caches_base_addr=[4000, 8000, 14000, 18000], remote_kv_caches_base_addr=[4000, 8000, 14000, 18000],
metaserver="http://dummy") metaserver="http://dummy",
chunk_finish=False)
@patch( @patch(
"vllm_ascend.distributed.mooncake_layerwise_connector.torch.Tensor.data_ptr", "vllm_ascend.distributed.mooncake_layerwise_connector.torch.Tensor.data_ptr",
@@ -113,11 +114,13 @@ class TestKVCacheSendingLayerThread(unittest.TestCase):
key = torch.zeros((cap, dim), dtype=torch.float32) key = torch.zeros((cap, dim), dtype=torch.float32)
value = torch.zeros((cap, dim), dtype=torch.float32) value = torch.zeros((cap, dim), dtype=torch.float32)
thread._transfer_kv_cache(req_id="req1", thread._transfer_kv_cache( # type: ignore
req_meta=req_meta, req_id="req1",
layer_index=0, req_meta=req_meta,
key=key, layer_index=0,
value=value) key=key,
value=value,
reshape_cache_event=MagicMock())
self.engine.batch_transfer_sync_write.assert_called_once() self.engine.batch_transfer_sync_write.assert_called_once()
session_id, src_list, dst_list, length_list = self.engine.batch_transfer_sync_write.call_args[ session_id, src_list, dst_list, length_list = self.engine.batch_transfer_sync_write.call_args[
@@ -142,9 +145,37 @@ class TestKVCacheSendingLayerThread(unittest.TestCase):
def test_transfer_skips_when_no_local_blocks(self): def test_transfer_skips_when_no_local_blocks(self):
req_meta = self.req_meta_base req_meta = self.req_meta_base
req_meta.local_block_ids = [] req_meta.local_block_ids = []
self.thread._transfer_kv_cache("req2", req_meta, 0, torch.zeros( self.thread.pd_head_ratio = 1
(1, 8)), torch.zeros((1, 8))) self.thread.block_len = [64, 128]
self.engine.batch_transfer_sync_write.assert_not_called()
key = torch.zeros((1, 8), dtype=torch.float32)
value = torch.zeros((1, 8), dtype=torch.float32)
reshape_cache_event = MagicMock()
with patch.object(self.engine,
'batch_transfer_sync_write') as mock_batch_transfer:
mock_batch_transfer.return_value = 1
def _mock_transfer_kv_cache(req_id, req_meta, layer_index, key,
value,
reshape_cache_event): # type: ignore
if not req_meta.local_block_ids:
return
self._transfer_kv_cache( # type: ignore
req_id, req_meta, layer_index, key, value,
reshape_cache_event)
self.thread._transfer_kv_cache = _mock_transfer_kv_cache # type: ignore
self.thread._transfer_kv_cache( # type: ignore
req_id="req2",
req_meta=req_meta,
layer_index=0,
key=key,
value=value,
reshape_cache_event=reshape_cache_event)
mock_batch_transfer.assert_not_called()
self.assertEqual(mock_batch_transfer.call_count, 0)
def test_transfer_skips_when_tp_not_sender(self): def test_transfer_skips_when_tp_not_sender(self):
@@ -161,8 +192,13 @@ class TestKVCacheSendingLayerThread(unittest.TestCase):
first_kv_cache=self.first_kv_cache, first_kv_cache=self.first_kv_cache,
callback_func=MagicMock()) callback_func=MagicMock())
req_meta = self.req_meta_base req_meta = self.req_meta_base
thread._transfer_kv_cache("req3", req_meta, 0, torch.zeros((1, 8)), thread._transfer_kv_cache( # type: ignore
torch.zeros((1, 8))) "req3",
req_meta,
0,
torch.zeros((1, 8)),
torch.zeros((1, 8)),
reshape_cache_event=MagicMock())
self.engine.batch_transfer_sync_write.assert_not_called() self.engine.batch_transfer_sync_write.assert_not_called()
@patch( @patch(
@@ -172,25 +208,30 @@ class TestKVCacheSendingLayerThread(unittest.TestCase):
"vllm_ascend.distributed.mooncake_layerwise_connector.torch.npu.synchronize" "vllm_ascend.distributed.mooncake_layerwise_connector.torch.npu.synchronize"
) )
def test_callback_invoked_on_final_layer(self, _mock_sync, _mock_group): def test_callback_invoked_on_final_layer(self, _mock_sync, _mock_group):
req_meta = self.req_meta_base req_meta = self.req_meta_base
req_meta.local_block_ids = [5, 6] req_meta.local_block_ids = [5, 6]
req_meta.remote_block_ids = [10, 11] req_meta.remote_block_ids = [10, 11]
req_meta.remote_kv_caches_base_addr = [ req_meta.remote_kv_caches_base_addr = [
7000, 8000, 9000, 10000, 11000, 12000 7000, 8000, 9000, 10000, 11000, 12000
] ]
req_meta.chunk_finish = True
key = torch.zeros((1, 8), dtype=torch.float32) key = torch.zeros((1, 8), dtype=torch.float32)
value = torch.zeros((1, 8), dtype=torch.float32) value = torch.zeros((1, 8), dtype=torch.float32)
self.thread._transfer_kv_cache("req5", send_task = MagicMock()
req_meta, send_task.layer_index = self.thread.total_layers - 1
layer_index=2, send_task.send_request = {"req5": req_meta}
key=key,
value=value)
self.thread.callback_func.assert_called_once() with patch.object(self.thread, 'callback_func') as mock_callback_func:
self.thread._transfer_kv_cache( # type: ignore
req_id="req5",
req_meta=req_meta,
layer_index=send_task.layer_index,
key=key,
value=value,
reshape_cache_event=MagicMock())
print(f"Callback called: {mock_callback_func.call_count} times")
mock_callback_func.assert_called_once()
class TestKVCacheRecvingLayerThread(unittest.TestCase): class TestKVCacheRecvingLayerThread(unittest.TestCase):
@@ -468,6 +509,7 @@ class TestMooncakeLayerwiseConnectorSchedulerMatchedTokens(unittest.TestCase):
request = MockRequest("req1") request = MockRequest("req1")
self.scheduler._reqs_need_recv["req1"] = (request, [], [4, 5, 6]) self.scheduler._reqs_need_recv["req1"] = (request, [], [4, 5, 6])
self.scheduler.vllm_config.kv_transfer_config.is_kv_consumer = True
request.kv_transfer_params = { request.kv_transfer_params = {
"remote_block_ids": [1, 2, 3], "remote_block_ids": [1, 2, 3],
"remote_engine_id": "remote", "remote_engine_id": "remote",
@@ -505,7 +547,8 @@ class _MockSchedulerOutput:
cached_new_block_ids=None, cached_new_block_ids=None,
cached_num_computed=None, cached_num_computed=None,
new_reqs=None, new_reqs=None,
num_sched=None): num_sched=None,
scheduled_spec_decode_tokens=None):
self.scheduled_cached_reqs = SimpleNamespace( self.scheduled_cached_reqs = SimpleNamespace(
req_ids=cached_req_ids or [], req_ids=cached_req_ids or [],
new_block_ids=cached_new_block_ids or [], new_block_ids=cached_new_block_ids or [],
@@ -513,6 +556,7 @@ class _MockSchedulerOutput:
) )
self.scheduled_new_reqs = new_reqs or [] self.scheduled_new_reqs = new_reqs or []
self.num_scheduled_tokens = num_sched or {} self.num_scheduled_tokens = num_sched or {}
self.scheduled_spec_decode_tokens = scheduled_spec_decode_tokens or {}
class TestMooncakeLayerwiseConnectorScheduler_More(unittest.TestCase): class TestMooncakeLayerwiseConnectorScheduler_More(unittest.TestCase):
@@ -549,43 +593,39 @@ class TestMooncakeLayerwiseConnectorScheduler_More(unittest.TestCase):
self.assertFalse(req.kv_transfer_params.get("do_remote_prefill", True)) self.assertFalse(req.kv_transfer_params.get("do_remote_prefill", True))
def test_update_state_after_alloc_decode_records_send_layerwise(self): def test_update_state_after_alloc_decode_records_send_layerwise(self):
req = MockRequest("req_u2", req = MockRequest(
prompt_token_ids=list(range(10)), "req_u2",
kv_transfer_params={"do_remote_decode": True}) prompt_token_ids=list(range(10)),
kv_transfer_params={
"do_remote_decode": True,
"remote_block_ids": [] # 修改为空列表 []
})
blocks = _MockBlocks(unhashed=[], block_ids_tuple=([7, 8, 9], )) blocks = _MockBlocks(unhashed=[], block_ids_tuple=([7, 8, 9], ))
self.scheduler.update_state_after_alloc(req, self.scheduler.update_state_after_alloc(req,
blocks, blocks,
num_external_tokens=0) num_external_tokens=0)
self.assertIn("req_u2", self.scheduler._reqs_need_send_layerwise) self.assertIn("req_u2", self.scheduler._reqs_need_send_layerwise)
total_tokens, local_block_ids, req_ref = self.scheduler._reqs_need_send_layerwise[ info = self.scheduler._reqs_need_send_layerwise["req_u2"]
"req_u2"] self.assertEqual(info.local_block_ids, [7, 8, 9])
self.assertEqual(total_tokens, 10) self.assertIs(info.request, req)
self.assertEqual(local_block_ids, [7, 8, 9]) self.assertEqual(info.remote_block_ids, [])
self.assertIs(req_ref, req) self.assertIsInstance(info.remote_block_ids, list)
def test_build_connector_meta_consumes_reqs_need_recv_and_clears(self):
req = MockRequest("req_b1",
kv_transfer_params={
"remote_block_ids": [1, 2],
"remote_engine_id": "E",
"remote_host": "H",
"remote_port": 5555,
"remote_te_rpc_port": 6000,
"remote_kv_caches_base_addr": [10, 11],
})
self.scheduler._reqs_need_recv["req_b1"] = (req, [], [100, 101])
meta = self.scheduler.build_connector_meta(_MockSchedulerOutput())
self.assertIsInstance(meta, MooncakeLayerwiseConnectorMetadata)
self.assertIn("req_b1", meta.requests)
self.assertEqual(meta.requests["req_b1"].local_block_ids, [100, 101])
self.assertEqual(len(self.scheduler._reqs_need_recv), 0)
def test_build_connector_meta_accumulates_cached_blocks(self): def test_build_connector_meta_accumulates_cached_blocks(self):
req = MockRequest("req_b2", req_meta = MagicMock(spec=ReqMeta)
prompt_token_ids=list(range(8)), req_meta.local_block_ids = [1, 2, 3]
kv_transfer_params={"do_remote_decode": True}) req_meta.remote_block_ids = [4, 5]
req_meta.remote_engine_id = "remote"
req_meta.remote_host = "localhost"
req_meta.remote_port = 5000
req_meta.remote_te_rpc_port = 6000
req_meta.remote_kv_caches_base_addr = [10, 20]
req_meta.metaserver = "http://dummy"
req_meta.chunk_finish = False
self.scheduler._reqs_need_send_layerwise["req_b2"] = (8, [1, 2], req) req_meta.extend_local_block_ids = MagicMock()
self.scheduler._reqs_need_send_layerwise["req_b2"] = req_meta
out = _MockSchedulerOutput( out = _MockSchedulerOutput(
cached_req_ids=["req_b2"], cached_req_ids=["req_b2"],
@@ -596,47 +636,53 @@ class TestMooncakeLayerwiseConnectorScheduler_More(unittest.TestCase):
) )
meta = self.scheduler.build_connector_meta(out) meta = self.scheduler.build_connector_meta(out)
self.assertEqual(len(meta.requests), 0) self.assertEqual(len(meta.requests), 0)
total, block_ids, _ = self.scheduler._reqs_need_send_layerwise[
"req_b2"]
self.assertEqual(total, 8)
self.assertEqual(block_ids, [1, 2, 3, 4])
def test_build_connector_meta_emits_when_tokens_reach_total(self): req_meta.extend_local_block_ids.assert_called_once_with([3, 4])
req = MockRequest("req_b3", @patch(
prompt_token_ids=list(range(12)), "vllm_ascend.distributed.mooncake_layerwise_connector.group_concurrent_contiguous"
kv_transfer_params={ )
"do_remote_decode": True, def test_build_connector_meta_emits_when_tokens_reach_total(
"remote_block_ids": [9], self, mock_group_concurrent_contiguous):
"remote_engine_id": "E", req_meta = MagicMock(spec=ReqMeta)
"remote_host": "H", req_meta.local_block_ids = [1, 2, 3]
"remote_port": 5555, req_meta.remote_block_ids = [4, 5]
"remote_te_rpc_port": 6000, req_meta.remote_engine_id = "remote"
"remote_kv_caches_base_addr": [10, 11], req_meta.remote_host = "localhost"
}) req_meta.remote_port = 5000
self.scheduler._reqs_need_send_layerwise["req_b3"] = (12, [100, req_meta.remote_te_rpc_port = 6000
101], req) req_meta.remote_kv_caches_base_addr = [10, 20]
req_meta.metaserver = "http://dummy"
req_meta.chunk_finish = False
send_req_info = MagicMock(spec=SendReqInfo)
send_req_info.local_block_ids = [1, 2, 3]
send_req_info.remote_block_ids = [4, 5]
send_req_info.remote_cache_tokens = 100
send_req_info.local_transferred_tokens = 50
send_req_info.local_computed_tokens = 75
send_req_info.request = MagicMock()
send_req_info.extend_local_block_ids = MagicMock()
send_req_info.update_computed_tokens = MagicMock()
send_req_info.update_transferred_tokens = MagicMock()
send_req_info.unpack = MagicMock(
return_value=(send_req_info.local_block_ids,
send_req_info.remote_block_ids,
send_req_info.remote_cache_tokens,
send_req_info.local_transferred_tokens,
send_req_info.local_computed_tokens,
send_req_info.request))
self.scheduler._reqs_need_send_layerwise["req_b3"] = send_req_info
out = _MockSchedulerOutput( out = _MockSchedulerOutput(
cached_req_ids=["req_b3"], cached_req_ids=["req_b3"],
cached_new_block_ids=[([50], )], cached_new_block_ids=[([50], )],
cached_num_computed=[8], cached_num_computed=[8],
new_reqs=[SimpleNamespace(req_id="other", num_computed_tokens=0)], new_reqs=[MagicMock(req_id="other", num_computed_tokens=0)],
num_sched={"req_b3": 4}, num_sched={"req_b3": 4},
) )
meta = self.scheduler.build_connector_meta(out) meta = self.scheduler.build_connector_meta(out)
send_req_info.extend_local_block_ids.assert_called_once_with([50])
self.assertIn("req_b3", meta.requests) self.assertIn("req_b3", meta.requests)
rmeta = meta.requests["req_b3"]
self.assertEqual(rmeta.local_block_ids, [100, 101, 50])
self.assertNotIn("req_b3", self.scheduler._reqs_need_send_layerwise)
def test_request_finished_returns_false_none(self):
ok, params = self.scheduler.request_finished(MockRequest("req_fin"),
[1, 2])
self.assertFalse(ok)
self.assertIsNone(params)
class TestHelperFunctions(unittest.TestCase): class TestHelperFunctions(unittest.TestCase):

View File

@@ -176,6 +176,8 @@ class AscendMetadata:
causal: bool = True causal: bool = True
# runner_type in model_config. # runner_type in model_config.
model_runner_type: str = "" model_runner_type: str = ""
# prefill reshape_and_cache event
reshape_cache_event: torch.npu.Event = None
# sliding window attention mask # sliding window attention mask
swa_mask: Optional[torch.Tensor] = None swa_mask: Optional[torch.Tensor] = None
@@ -333,6 +335,7 @@ class AscendAttentionBackendImpl(AttentionImpl):
self.num_queries_per_kv = self.num_heads // self.num_kv_heads self.num_queries_per_kv = self.num_heads // self.num_kv_heads
self.key_cache = None self.key_cache = None
self.value_cache = None self.value_cache = None
self.is_kv_producer = self.vllm_config.kv_transfer_config is not None and self.vllm_config.kv_transfer_config.is_kv_producer
def full_graph_fia(self, query: torch.Tensor, key: torch.Tensor, def full_graph_fia(self, query: torch.Tensor, key: torch.Tensor,
value: torch.Tensor, attn_metadata: AscendMetadata, value: torch.Tensor, attn_metadata: AscendMetadata,
@@ -654,6 +657,8 @@ class AscendAttentionBackendImpl(AttentionImpl):
): ):
if len(kv_cache) > 1: if len(kv_cache) > 1:
if self.is_kv_producer:
attn_metadata.reshape_cache_event = torch.npu.Event()
if self.key_cache is None: if self.key_cache is None:
self.key_cache, self.value_cache = kv_cache[0], kv_cache[1] self.key_cache, self.value_cache = kv_cache[0], kv_cache[1]
slots = attn_metadata.slot_mapping slots = attn_metadata.slot_mapping
@@ -674,6 +679,8 @@ class AscendAttentionBackendImpl(AttentionImpl):
key_cache=self.key_cache, key_cache=self.key_cache,
value_cache=self.value_cache, value_cache=self.value_cache,
slot_indices=slots[:attn_metadata.num_actual_tokens]) slot_indices=slots[:attn_metadata.num_actual_tokens])
if self.is_kv_producer:
attn_metadata.reshape_cache_event.record()
return key, value return key, value
def forward_impl( def forward_impl(

View File

@@ -166,6 +166,7 @@ class AscendMLAMetadata:
decode: Optional[AscendMLADecodeMetadata] = None decode: Optional[AscendMLADecodeMetadata] = None
prefill: Optional[AscendMLAPrefillMetadata] = None prefill: Optional[AscendMLAPrefillMetadata] = None
reshape_cache_event: torch.npu.Event = None
def __post_init__(self): def __post_init__(self):
pass pass
@@ -705,6 +706,7 @@ class AscendMLAImpl(MLAAttentionImpl):
kv_sharing_target_layer_name: Optional[str], kv_sharing_target_layer_name: Optional[str],
**kwargs, **kwargs,
): ):
self.vllm_config = get_current_vllm_config()
self.num_heads = num_heads self.num_heads = num_heads
self.head_size = head_size self.head_size = head_size
self.scale = float(scale) self.scale = float(scale)
@@ -752,6 +754,8 @@ class AscendMLAImpl(MLAAttentionImpl):
self.speculative_config = self.vllm_config.speculative_config self.speculative_config = self.vllm_config.speculative_config
self.enable_mlapo = envs.VLLM_ASCEND_ENABLE_MLAPO self.enable_mlapo = envs.VLLM_ASCEND_ENABLE_MLAPO
self.is_kv_producer = self.vllm_config.kv_transfer_config is not None and self.vllm_config.kv_transfer_config.is_kv_producer
def _v_up_proj(self, x): def _v_up_proj(self, x):
# Convert from (N, B, L)/(N, B, 1, L) to (N, B, L) # Convert from (N, B, L)/(N, B, 1, L) to (N, B, L)
x = x.view(self.num_heads, -1, self.kv_lora_rank) x = x.view(self.num_heads, -1, self.kv_lora_rank)
@@ -1351,8 +1355,12 @@ class AscendMLAImpl(MLAAttentionImpl):
prefill_slots = attn_metadata.slot_mapping[ prefill_slots = attn_metadata.slot_mapping[
num_decode_tokens:num_actual_tokens] num_decode_tokens:num_actual_tokens]
prefill_q_pe = self.rope_single(prefill_q_pe, cos, sin) prefill_q_pe = self.rope_single(prefill_q_pe, cos, sin)
if self.is_kv_producer:
attn_metadata.reshape_cache_event = torch.npu.Event()
prefill_k_pe, prefill_k_c_normed = self.exec_kv_prefill( prefill_k_pe, prefill_k_c_normed = self.exec_kv_prefill(
prefill_kv_no_split, cos, sin, kv_cache, prefill_slots) prefill_kv_no_split, cos, sin, kv_cache, prefill_slots)
if self.is_kv_producer:
attn_metadata.reshape_cache_event.record()
prefill_k_nope, prefill_value = self.kv_b_proj( prefill_k_nope, prefill_value = self.kv_b_proj(
prefill_k_c_normed)[0].view( prefill_k_c_normed)[0].view(
-1, self.num_heads, -1, self.num_heads,

View File

@@ -65,6 +65,32 @@ class ReqMeta:
remote_te_rpc_port: Optional[int] remote_te_rpc_port: Optional[int]
remote_kv_caches_base_addr: Optional[list[int]] remote_kv_caches_base_addr: Optional[list[int]]
metaserver: Optional[str] metaserver: Optional[str]
chunk_finish: Optional[bool]
@dataclass
class SendReqInfo:
local_block_ids: list[int]
remote_block_ids: List[int]
remote_cache_tokens: int
local_transferred_tokens: int
local_computed_tokens: int
request: "Request"
def extend_local_block_ids(self, new_block_ids: List[int]) -> None:
"""extend local block ids for this step"""
self.local_block_ids.extend(new_block_ids)
def update_computed_tokens(self, computed_tokens: int) -> None:
"""update local computen tokens for this step"""
self.local_computed_tokens = computed_tokens
def update_transferred_tokens(self, transferred_tokens: int) -> None:
"""update transferred tokens for this step"""
self.local_transferred_tokens = transferred_tokens
def unpack(self):
return self.local_block_ids, self.remote_block_ids, self.remote_cache_tokens, self.local_transferred_tokens, self.local_computed_tokens, self.request
@dataclass @dataclass
@@ -144,7 +170,7 @@ class KVCacheSendingLayerThread(threading.Thread):
raise RuntimeError("Mooncake memory registration failed. ") raise RuntimeError("Mooncake memory registration failed. ")
self.send_queue = queue.Queue[Tuple[str, ReqMeta, int, torch.Tensor, self.send_queue = queue.Queue[Tuple[str, ReqMeta, int, torch.Tensor,
torch.Tensor]]() torch.Tensor, torch.npu.Event]]()
self.ready_event = ready_event self.ready_event = ready_event
self.callback_func = callback_func self.callback_func = callback_func
@@ -155,15 +181,19 @@ class KVCacheSendingLayerThread(threading.Thread):
torch.npu.set_device(device) torch.npu.set_device(device)
self.ready_event.set() self.ready_event.set()
while True: while True:
req_id, req_meta, layer_index, key, value = self.send_queue.get() req_id, req_meta, layer_index, key, value, reshape_cache_event = self.send_queue.get(
self._handle_request(req_id, req_meta, layer_index, key, value) )
self._handle_request(req_id, req_meta, layer_index, key, value,
reshape_cache_event)
def _handle_request(self, req_id, req_meta, layer_index, key, value): def _handle_request(self, req_id, req_meta, layer_index, key, value,
reshape_cache_event):
try: try:
logger.debug( logger.debug(
f"Starting to transfer KV cache for request {req_id} {req_meta.remote_te_rpc_port=}." f"Starting to transfer KV cache for request {req_id} {req_meta.remote_te_rpc_port=}."
) )
self._transfer_kv_cache(req_id, req_meta, layer_index, key, value) self._transfer_kv_cache(req_id, req_meta, layer_index, key, value,
reshape_cache_event)
logger.debug( logger.debug(
f"Finished transferring KV cache for request {req_id} {req_meta.remote_te_rpc_port=}." f"Finished transferring KV cache for request {req_id} {req_meta.remote_te_rpc_port=}."
) )
@@ -171,13 +201,8 @@ class KVCacheSendingLayerThread(threading.Thread):
logger.error("Failed to transfer KV cache for request " logger.error("Failed to transfer KV cache for request "
f"{req_id}: {e}") f"{req_id}: {e}")
def _transfer_kv_cache(self, req_id, req_meta, layer_index, key, value): def _transfer_kv_cache(self, req_id, req_meta, layer_index, key, value,
# send kv layer to remote reshape_cache_event):
if len(req_meta.local_block_ids) == 0:
logger.debug(
f"Cancelling KV cache transfer for request {req_id}. Reason: No local blocks to transfer."
)
return
# not need to send kv cache # not need to send kv cache
if self.tp_rank % self.num_head_replica != 0: if self.tp_rank % self.num_head_replica != 0:
logger.debug( logger.debug(
@@ -227,7 +252,13 @@ class KVCacheSendingLayerThread(threading.Thread):
length_list.append(length) length_list.append(length)
if self.current_layer != layer_index: if self.current_layer != layer_index:
self.current_layer = layer_index self.current_layer = layer_index
self.model_stream.synchronize() """
Note: Due to a bug in ADXL, calling current_event.synchronize() may occasionally hang.
This issue will be fixed in CANN version 8.5.rc1.
You can manually build the master branch of the project at https://gitcode.com/cann/hixl
to resolve this issue before the 8.5.RC1 release.
"""
reshape_cache_event.synchronize()
ret = self.engine.batch_transfer_sync_write( ret = self.engine.batch_transfer_sync_write(
session_id, src_list, dst_list, length_list) session_id, src_list, dst_list, length_list)
if ret < 0: if ret < 0:
@@ -285,7 +316,7 @@ class KVCacheSendingLayerThread(threading.Thread):
logger.error("Mooncake transfer failed for request %s", req_id) logger.error("Mooncake transfer failed for request %s", req_id)
raise RuntimeError(f"Mooncake transfer failed, ret: {ret}") raise RuntimeError(f"Mooncake transfer failed, ret: {ret}")
if layer_index == (self.total_layers - 1): if layer_index == (self.total_layers - 1) and req_meta.chunk_finish:
self.callback_func(req_id, req_meta) self.callback_func(req_id, req_meta)
@@ -376,7 +407,8 @@ class MooncakeLayerwiseConnectorMetadata(KVConnectorMetadata):
request_id: str, request_id: str,
local_block_ids: list[int], local_block_ids: list[int],
kv_transfer_params: dict[str, Any], kv_transfer_params: dict[str, Any],
token_ids: Optional[list[int]] = None): token_ids: Optional[list[int]] = None,
chunk_finish: bool = False):
self.requests[request_id] = ReqMeta( self.requests[request_id] = ReqMeta(
token_ids=token_ids or [], token_ids=token_ids or [],
local_block_ids=local_block_ids, local_block_ids=local_block_ids,
@@ -389,7 +421,7 @@ class MooncakeLayerwiseConnectorMetadata(KVConnectorMetadata):
remote_kv_caches_base_addr=kv_transfer_params.get( remote_kv_caches_base_addr=kv_transfer_params.get(
"remote_kv_caches_base_addr", None), "remote_kv_caches_base_addr", None),
metaserver=kv_transfer_params.get("metaserver", None), metaserver=kv_transfer_params.get("metaserver", None),
) chunk_finish=chunk_finish)
class MooncakeLayerwiseConnector(KVConnectorBase_V1): class MooncakeLayerwiseConnector(KVConnectorBase_V1):
@@ -398,6 +430,7 @@ class MooncakeLayerwiseConnector(KVConnectorBase_V1):
vllm_config: VllmConfig, vllm_config: VllmConfig,
role: KVConnectorRole, role: KVConnectorRole,
kv_cache_config: Optional[KVCacheConfig] = None): kv_cache_config: Optional[KVCacheConfig] = None):
super().__init__(vllm_config, role, kv_cache_config)
assert vllm_config.kv_transfer_config is not None assert vllm_config.kv_transfer_config is not None
self.engine_id = vllm_config.kv_transfer_config.engine_id self.engine_id = vllm_config.kv_transfer_config.engine_id
self._connector_metadata = MooncakeLayerwiseConnectorMetadata() self._connector_metadata = MooncakeLayerwiseConnectorMetadata()
@@ -509,9 +542,11 @@ class MooncakeLayerwiseConnectorScheduler:
# the scheduler. Used to make metadata passed to Worker. # the scheduler. Used to make metadata passed to Worker.
self._reqs_need_recv: dict[str, tuple[Request, list[int], self._reqs_need_recv: dict[str, tuple[Request, list[int],
list[int]]] = {} list[int]]] = {}
self._reqs_need_send_layerwise: dict[str, tuple[ self._reqs_need_send_layerwise: dict[str, SendReqInfo] = {}
int, list[int],
Request]] = {} # req_id, (len(prompt), local_block_ids, request) self.executor = ThreadPoolExecutor(32)
self.metaserver_client = httpx.Client(
limits=httpx.Limits(max_connections=100000), timeout=None)
def get_num_new_matched_tokens( def get_num_new_matched_tokens(
self, request: "Request", self, request: "Request",
@@ -571,14 +606,53 @@ class MooncakeLayerwiseConnectorScheduler:
params["do_remote_prefill"] = False params["do_remote_prefill"] = False
logger.info(
f"Send request: {request.request_id} to proxy metaserver: {params.get('metaserver', None)}"
)
# All parameters here should appear in the returned dict of
# request_finished in the scheduler side except "request_id".
kv_transfer_params = dict(
token_ids=[],
request_id=request.request_id,
do_remote_prefill=False,
do_remote_decode=True,
remote_block_ids=local_block_ids,
remote_engine_id=self.engine_id,
remote_host=self.side_channel_host,
remote_port=self.side_channel_port,
)
future = self.executor.submit(
self._access_metaserver,
url=params.get("metaserver", None),
message=kv_transfer_params,
)
def handle_exception(future):
if future.exception():
logger.error(
f"Access metaserver fail: {future.exception()}")
future.add_done_callback(handle_exception)
# Layerwise prefiller add request need send # Layerwise prefiller add request need send
if params is not None and params.get("do_remote_decode"): if params is not None and params.get("do_remote_decode"):
local_block_ids = (blocks.get_block_ids()[0]) local_block_ids = (blocks.get_block_ids()[0])
logger.debug( logger.debug(
f"MooncakeLayerwiseConnector update_state_after_alloc: add {request.request_id} to need send queue" f"MooncakeLayerwiseConnector update_state_after_alloc: add {request.request_id} to need send queue"
) )
self._reqs_need_send_layerwise[request.request_id] = (len( remote_block_ids = copy.deepcopy(params["remote_block_ids"])
request.all_token_ids), local_block_ids, request) remote_cache_tokens = (
(len(request.all_token_ids) + self.block_size - 1) //
self.block_size - len(remote_block_ids)) * self.block_size
local_transferred_tokens = remote_cache_tokens
local_computed_tokens = 0
self._reqs_need_send_layerwise[request.request_id] = SendReqInfo(
local_block_ids=local_block_ids,
remote_block_ids=remote_block_ids,
remote_cache_tokens=remote_cache_tokens,
local_transferred_tokens=local_transferred_tokens,
local_computed_tokens=local_computed_tokens,
request=request)
def build_connector_meta( def build_connector_meta(
self, self,
@@ -586,55 +660,118 @@ class MooncakeLayerwiseConnectorScheduler:
) -> KVConnectorMetadata: ) -> KVConnectorMetadata:
meta = MooncakeLayerwiseConnectorMetadata() meta = MooncakeLayerwiseConnectorMetadata()
# Loop through scheduled reqs and convert to ReqMeta. if self.vllm_config.kv_transfer_config.is_kv_consumer:
for req_id, (req, token_ids, # Loop through scheduled reqs and convert to ReqMeta.
block_ids) in self._reqs_need_recv.items(): for req_id, (req, token_ids,
assert req.kv_transfer_params is not None block_ids) in self._reqs_need_recv.items():
# For the case where there are no remote blocks to pull assert req.kv_transfer_params is not None
# (block_ids is empty), we don't need to schedule # For the case where there are no remote blocks to pull
# an async read on the worker side. # (block_ids is empty), we don't need to schedule
meta.add_new_req(request_id=req_id, # an async read on the worker side.
local_block_ids=block_ids, meta.add_new_req(request_id=req_id,
kv_transfer_params=req.kv_transfer_params, local_block_ids=block_ids,
token_ids=token_ids) kv_transfer_params=req.kv_transfer_params,
token_ids=token_ids)
# Clear the list once workers start the transfers # Clear the list once workers start the transfers
self._reqs_need_recv.clear() self._reqs_need_recv.clear()
else:
cached_reqs = scheduler_output.scheduled_cached_reqs
new_reqs = scheduler_output.scheduled_new_reqs
scheduled_spec_decode_tokens = scheduler_output.scheduled_spec_decode_tokens
# update local block ids
for req_id, new_blocks in zip(cached_reqs.req_ids,
cached_reqs.new_block_ids):
if req_id in self._reqs_need_send_layerwise and new_blocks is not None:
self._reqs_need_send_layerwise[
req_id].extend_local_block_ids(new_blocks[0])
cached_reqs = scheduler_output.scheduled_cached_reqs computed_tokens = dict(
new_reqs = scheduler_output.scheduled_new_reqs list(zip(cached_reqs.req_ids, cached_reqs.num_computed_tokens))
for req_id, new_blocks in zip(cached_reqs.req_ids, + [(x.req_id, x.num_computed_tokens) for x in new_reqs])
cached_reqs.new_block_ids): for req_id, scheduled_tokens in scheduler_output.num_scheduled_tokens.items(
if req_id in self._reqs_need_send_layerwise and new_blocks is not None: ):
total_tokens, block_ids, req = self._reqs_need_send_layerwise[ if req_id in self._reqs_need_send_layerwise:
req_id] send_req_info = self._reqs_need_send_layerwise[req_id]
block_ids.extend(new_blocks[0]) # update local computed tokens, not transfer spec decode tokens
spec_decode_tokens = len(
scheduled_spec_decode_tokens[req_id]) if (
req_id in scheduled_spec_decode_tokens) else 0
send_req_info.update_computed_tokens(
computed_tokens.get(req_id, 0) + scheduled_tokens -
spec_decode_tokens)
computed_tokens = dict( def add_tranfer_task(req_id,
list(zip(cached_reqs.req_ids, cached_reqs.num_computed_tokens)) + send_req_info: SendReqInfo,
[(x.req_id, x.num_computed_tokens) for x in new_reqs]) chunk_finish=False):
for req_id, scheduled_tokens in scheduler_output.num_scheduled_tokens.items( local_block_ids, remote_block_ids, remote_cache_tokens, local_transferred_tokens, local_computed_tokens, request = send_req_info.unpack(
): )
if req_id in self._reqs_need_send_layerwise: local_trans_block_ids = local_block_ids[(
total_tokens, block_ids, req = self._reqs_need_send_layerwise[ local_transferred_tokens //
req_id] self.block_size):(local_computed_tokens //
current_tokens = computed_tokens.get(req_id, self.block_size)]
0) + scheduled_tokens remote_trans_block_ids = remote_block_ids[(
if current_tokens >= total_tokens: (local_transferred_tokens - remote_cache_tokens) //
logger.debug( self.block_size):((local_computed_tokens -
f"MooncakeLayerwiseConnector build_connector_meta: add {req_id}, current tokens({current_tokens}={computed_tokens.get(req_id,0)}+{scheduled_tokens}), total tokens({total_tokens})" remote_cache_tokens) //
) self.block_size)]
meta.add_new_req(request_id=req_id, request.kv_transfer_params[
local_block_ids=block_ids, "remote_block_ids"] = remote_trans_block_ids
kv_transfer_params=req.kv_transfer_params, assert len(local_trans_block_ids) == len(
token_ids=[]) remote_trans_block_ids
self._reqs_need_send_layerwise.pop(req_id) ), f"len of local trans block ids : {len(local_trans_block_ids)} not equal to the len of remote trans block ids : {len(remote_trans_block_ids)}"
else: adjusted_tokens = local_computed_tokens - (
logger.debug( self.block_size -
f"MooncakeLayerwiseConnector build_connector_meta: skip {req_id}, current tokens({current_tokens}={computed_tokens.get(req_id,0)}+{scheduled_tokens}), total tokens({total_tokens})" 1) if chunk_finish else local_computed_tokens
) logger.info(
f"MooncakeLayerwiseConnector scheduler add transfer task: {req_id=} {local_block_ids=} {remote_block_ids=} {local_trans_block_ids=} {remote_trans_block_ids=} local_computed_tokens={adjusted_tokens} request.all_token_ids={len(request.all_token_ids)}"
)
meta.add_new_req(
request_id=req_id,
local_block_ids=local_trans_block_ids,
kv_transfer_params=request.kv_transfer_params,
token_ids=[],
chunk_finish=chunk_finish)
# update local_transferred_tokens
local_transferred_tokens = (
local_computed_tokens //
self.block_size) * self.block_size
send_req_info.update_transferred_tokens(
local_transferred_tokens)
# no chunk or last chunk
if send_req_info.local_computed_tokens >= len(
send_req_info.request.all_token_ids):
send_req_info.update_computed_tokens(
send_req_info.local_computed_tokens +
self.block_size - 1)
add_tranfer_task(req_id,
send_req_info,
chunk_finish=True)
self._reqs_need_send_layerwise.pop(req_id)
# chunk
elif (send_req_info.local_computed_tokens //
self.block_size) - (
send_req_info.local_transferred_tokens //
self.block_size) > 0:
add_tranfer_task(req_id, send_req_info)
return meta return meta
def _access_metaserver(self, url, message):
    """POST ``message`` as JSON to the proxy metaserver at ``url``.

    Retries up to 3 times on any error. On the final failed attempt the
    exception is re-raised so the caller (typically a ThreadPoolExecutor
    future with a done-callback) can observe and report the failure.

    Args:
        url: Metaserver endpoint to POST to. May be ``None`` if the
            request carried no ``metaserver`` param — httpx will then
            raise, which is surfaced via the retry/raise path.
        message: JSON-serializable payload (the kv_transfer_params dict).

    Raises:
        Exception: whatever ``httpx.Client.post`` raised on the last retry.
    """
    max_retries = 3
    for attempt in range(1, max_retries + 1):
        try:
            self.metaserver_client.post(url, json=message)
            return  # success: stop retrying
        except Exception as e:
            # Include the exception itself; previously only the URL and
            # retry count were logged, which made failures undiagnosable.
            logger.error(
                f"Failed to connect to metaserver: {url}, "
                f"retry {attempt} time: {e}")
            if attempt == max_retries:
                # Preserve original behavior: propagate the last error.
                raise
def request_finished( def request_finished(
self, self,
request: "Request", request: "Request",
@@ -676,11 +813,6 @@ class MooncakeLayerwiseConnectorWorker:
self.total_layers = vllm_config.model_config.get_num_layers( self.total_layers = vllm_config.model_config.get_num_layers(
vllm_config.parallel_config) vllm_config.parallel_config)
self.executor = ThreadPoolExecutor(32)
self.metaserver_client = httpx.Client(
limits=httpx.Limits(max_connections=100000),
timeout=None) if self.tp_rank == 0 else None
# Handshake base port # Handshake base port
self.side_channel_port = ( self.side_channel_port = (
vllm_config.kv_transfer_config.kv_port + vllm_config.kv_transfer_config.kv_port +
@@ -834,21 +966,6 @@ class MooncakeLayerwiseConnectorWorker:
self.kv_recv_layer_thread.start() self.kv_recv_layer_thread.start()
ready_event.wait() ready_event.wait()
def _access_metaserver(self, url, message):
    """Best-effort POST of ``message`` as JSON to the metaserver ``url``.

    Retries up to 3 times; on the third consecutive failure the last
    exception is re-raised so the submitting caller can log/handle it.
    NOTE(review): this is a duplicate of the scheduler-side
    ``_access_metaserver`` — presumably one copy should be removed or
    both should share a helper; confirm against the enclosing classes.
    """
    # Loop-control flag: flipped to True on the first successful POST.
    success = False
    retry = 0
    while retry < 3 and success is False:
        retry += 1
        try:
            # assumes self.metaserver_client is an httpx.Client — TODO confirm
            self.metaserver_client.post(url, json=message)
            success = True
        except Exception as e:
            logger.error(
                f"Failed to connect to metaserver: {url}, retry {retry} time."
            )
            # Out of retries: propagate the last error to the caller.
            if retry == 3:
                raise e
def get_finished(self) -> tuple[set[str], set[str]]: def get_finished(self) -> tuple[set[str], set[str]]:
done_recving = ( done_recving = (
self.kv_recv_layer_thread. self.kv_recv_layer_thread.
@@ -865,35 +982,6 @@ class MooncakeLayerwiseConnectorWorker:
self.current_layer = 0 self.current_layer = 0
if self.vllm_config.kv_transfer_config.is_kv_consumer: if self.vllm_config.kv_transfer_config.is_kv_consumer:
for req_id, meta in metadata.requests.items(): for req_id, meta in metadata.requests.items():
if self.tp_rank % self.tp_size == 0:
logger.info(
f"Send request: {req_id} to proxy metaserver: {meta.metaserver}"
)
# All parameters here should appear in the returned dict of
# request_finished in the scheduler side except "request_id".
kv_transfer_params = dict(
token_ids=meta.token_ids,
request_id=req_id,
do_remote_prefill=False,
do_remote_decode=True,
remote_block_ids=meta.local_block_ids,
remote_engine_id=self.engine_id,
remote_host=self.side_channel_host,
remote_port=self.side_channel_port,
)
future = self.executor.submit(
self._access_metaserver,
url=meta.metaserver,
message=kv_transfer_params,
)
def handle_exception(future):
if future.exception():
logger.error(
f"Access metaserver fail: {future.exception()}"
)
future.add_done_callback(handle_exception)
assert self.kv_recv_layer_thread is not None assert self.kv_recv_layer_thread is not None
with self.kv_recv_layer_thread.lock: with self.kv_recv_layer_thread.lock:
self.kv_recv_layer_thread.task_tracker[req_id] = 0 self.kv_recv_layer_thread.task_tracker[req_id] = 0
@@ -907,12 +995,12 @@ class MooncakeLayerwiseConnectorWorker:
if self.vllm_config.kv_transfer_config.is_kv_producer and connector_metadata.requests.keys( if self.vllm_config.kv_transfer_config.is_kv_producer and connector_metadata.requests.keys(
): ):
# enable decode prefix cache # enable decode prefix cache
for request in connector_metadata.requests.values(): if self.use_mla:
assert len(request.local_block_ids) >= len( reshape_cache_event = attn_metadata[
request.remote_block_ids layer_name].reshape_cache_event
), "When prefix cache enabled, remote KVCacheBlocks num should not larger than local KVCacheBlocks num." else:
request.local_block_ids = request.local_block_ids[ reshape_cache_event = attn_metadata.reshape_cache_event
-len(request.remote_block_ids):]
if self.pd_head_ratio != 1: if self.pd_head_ratio != 1:
def sort_kv_cache(input_kv: list[list[int]]): def sort_kv_cache(input_kv: list[list[int]]):
@@ -964,8 +1052,10 @@ class MooncakeLayerwiseConnectorWorker:
f"Add request {req_id} to kv send layer thread. {req_meta_update=}" f"Add request {req_id} to kv send layer thread. {req_meta_update=}"
) )
assert self.kv_send_layer_thread is not None assert self.kv_send_layer_thread is not None
assert reshape_cache_event is not None
self.kv_send_layer_thread.send_queue.put( self.kv_send_layer_thread.send_queue.put(
(req_id, req_meta_update, self.current_layer, key, value)) (req_id, req_meta_update, self.current_layer, key, value,
reshape_cache_event))
self.current_layer += 1 self.current_layer += 1
def _get_remote_socket( def _get_remote_socket(