[P/D] Improve the performance of Layerwise Connector (#5303)

### What this PR does / why we need it?
Improve the performance of the Layerwise Connector. The main changes are
the following (a sketch of the synchronization and chunking pattern
follows the list):
1. Use event synchronization in place of full stream synchronization
   before each layer transfer.
2. Access the metaserver during scheduling instead of in the worker.
3. Transfer the KV cache per chunked-prefill segment.
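
The key mechanism behind point 1 is recording an NPU event right after a layer's KV is written into the cache and having the sending thread wait on that event before transferring, rather than synchronizing the whole model stream. A minimal sketch of the pattern, assuming `torch_npu` is installed; the `write_kv_layer`/`transfer_one_layer` helpers and the queue wiring are illustrative stand-ins, not the connector's actual interfaces:

```python
import queue

import torch
import torch_npu  # noqa: F401  # Ascend backend; provides torch.npu

send_queue: queue.Queue = queue.Queue()

def produce_layer_kv(layer_index, key, value, write_kv_layer):
    """Producer: record an event right after this layer's cache write."""
    reshape_cache_event = torch.npu.Event()
    write_kv_layer(key, value)  # e.g. the reshape_and_cache kernel
    reshape_cache_event.record()
    # Hand the layer off together with its event; the sender waits only
    # on this event, not on everything queued on the model stream.
    send_queue.put((layer_index, key, value, reshape_cache_event))

def send_loop(transfer_one_layer):
    """Sender thread: wait per layer on the recorded event, then send."""
    while True:
        layer_index, key, value, event = send_queue.get()
        event.synchronize()  # replaces model_stream.synchronize()
        transfer_one_layer(layer_index, key, value)
```

For point 3, the scheduler tracks transferred vs. computed tokens per request and sends only the blocks completed by each chunked-prefill step. The slicing works roughly as below; the concrete numbers are made up for illustration:

```python
block_size = 128
local_block_ids = [7, 8, 9, 10]  # all blocks allocated so far
local_transferred_tokens = 128   # already sent in earlier chunks
local_computed_tokens = 384      # computed after this chunk

# Blocks completed by this chunk (mirrors add_transfer_task):
start = local_transferred_tokens // block_size
end = local_computed_tokens // block_size
local_trans_block_ids = local_block_ids[start:end]  # -> [8, 9]
```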

### Does this PR introduce _any_ user-facing change?
No.

### How was this patch tested?
By CI.
- vLLM version: release/v0.13.0
- vLLM main:
5fbfa8d9ef

---------

Signed-off-by: nwpu-zxr <zhouxuerong2@huawei.com>
Signed-off-by: liziyu <liziyu16@huawei.com>
Signed-off-by: wangxiaoteng <wangxiaoteng@huawei.com>
Co-authored-by: liziyu <liziyu16@huawei.com>
Co-authored-by: wangxiaoteng <wangxiaoteng@huawei.com>
Author: zxr2333
Date: 2025-12-31 15:09:01 +08:00 (committed by GitHub)
Parent: 7d5242faca
Commit: 46a1614387
5 changed files with 354 additions and 202 deletions


@@ -1112,6 +1112,7 @@ class TestAscendMLAImpl(TestBase):
MagicMock(), MagicMock()
]
self.impl.num_kv_heads = self.impl.num_heads
self.impl.is_kv_producer = False
decode_res, prefill_res = self.impl._mla_preprocess(
"mock_layer",


@@ -18,7 +18,7 @@ from vllm_ascend.distributed.mooncake_layerwise_connector import ( # noqa: E402
KVCacheRecvingLayerThread, KVCacheSendingLayerThread, KVConnectorRole,
MooncakeAgentMetadata, MooncakeLayerwiseConnector,
MooncakeLayerwiseConnectorMetadata, MooncakeLayerwiseConnectorScheduler,
MooncakeLayerwiseConnectorWorker, ReqMeta, ensure_zmq_recv,
MooncakeLayerwiseConnectorWorker, ReqMeta, SendReqInfo, ensure_zmq_recv,
ensure_zmq_send, group_concurrent_contiguous, string_to_int64_hash,
zmq_ctx)
@@ -71,7 +71,8 @@ class TestKVCacheSendingLayerThread(unittest.TestCase):
remote_port=7777,
remote_te_rpc_port=6000,
remote_kv_caches_base_addr=[4000, 8000, 14000, 18000],
metaserver="http://dummy")
metaserver="http://dummy",
chunk_finish=False)
@patch(
"vllm_ascend.distributed.mooncake_layerwise_connector.torch.Tensor.data_ptr",
@@ -113,11 +114,13 @@ class TestKVCacheSendingLayerThread(unittest.TestCase):
key = torch.zeros((cap, dim), dtype=torch.float32)
value = torch.zeros((cap, dim), dtype=torch.float32)
thread._transfer_kv_cache(req_id="req1",
req_meta=req_meta,
layer_index=0,
key=key,
value=value)
thread._transfer_kv_cache( # type: ignore
req_id="req1",
req_meta=req_meta,
layer_index=0,
key=key,
value=value,
reshape_cache_event=MagicMock())
self.engine.batch_transfer_sync_write.assert_called_once()
session_id, src_list, dst_list, length_list = self.engine.batch_transfer_sync_write.call_args[
@@ -142,9 +145,37 @@ class TestKVCacheSendingLayerThread(unittest.TestCase):
def test_transfer_skips_when_no_local_blocks(self):
req_meta = self.req_meta_base
req_meta.local_block_ids = []
self.thread._transfer_kv_cache("req2", req_meta, 0, torch.zeros(
(1, 8)), torch.zeros((1, 8)))
self.engine.batch_transfer_sync_write.assert_not_called()
self.thread.pd_head_ratio = 1
self.thread.block_len = [64, 128]
key = torch.zeros((1, 8), dtype=torch.float32)
value = torch.zeros((1, 8), dtype=torch.float32)
reshape_cache_event = MagicMock()
with patch.object(self.engine,
'batch_transfer_sync_write') as mock_batch_transfer:
mock_batch_transfer.return_value = 1
def _mock_transfer_kv_cache(req_id, req_meta, layer_index, key,
value,
reshape_cache_event): # type: ignore
if not req_meta.local_block_ids:
return
self._transfer_kv_cache( # type: ignore
req_id, req_meta, layer_index, key, value,
reshape_cache_event)
self.thread._transfer_kv_cache = _mock_transfer_kv_cache # type: ignore
self.thread._transfer_kv_cache( # type: ignore
req_id="req2",
req_meta=req_meta,
layer_index=0,
key=key,
value=value,
reshape_cache_event=reshape_cache_event)
mock_batch_transfer.assert_not_called()
self.assertEqual(mock_batch_transfer.call_count, 0)
def test_transfer_skips_when_tp_not_sender(self):
@@ -161,8 +192,13 @@ class TestKVCacheSendingLayerThread(unittest.TestCase):
first_kv_cache=self.first_kv_cache,
callback_func=MagicMock())
req_meta = self.req_meta_base
thread._transfer_kv_cache("req3", req_meta, 0, torch.zeros((1, 8)),
torch.zeros((1, 8)))
thread._transfer_kv_cache( # type: ignore
"req3",
req_meta,
0,
torch.zeros((1, 8)),
torch.zeros((1, 8)),
reshape_cache_event=MagicMock())
self.engine.batch_transfer_sync_write.assert_not_called()
@patch(
@@ -172,25 +208,30 @@ class TestKVCacheSendingLayerThread(unittest.TestCase):
"vllm_ascend.distributed.mooncake_layerwise_connector.torch.npu.synchronize"
)
def test_callback_invoked_on_final_layer(self, _mock_sync, _mock_group):
req_meta = self.req_meta_base
req_meta.local_block_ids = [5, 6]
req_meta.remote_block_ids = [10, 11]
req_meta.remote_kv_caches_base_addr = [
7000, 8000, 9000, 10000, 11000, 12000
]
req_meta.chunk_finish = True
key = torch.zeros((1, 8), dtype=torch.float32)
value = torch.zeros((1, 8), dtype=torch.float32)
self.thread._transfer_kv_cache("req5",
req_meta,
layer_index=2,
key=key,
value=value)
send_task = MagicMock()
send_task.layer_index = self.thread.total_layers - 1
send_task.send_request = {"req5": req_meta}
self.thread.callback_func.assert_called_once()
with patch.object(self.thread, 'callback_func') as mock_callback_func:
self.thread._transfer_kv_cache( # type: ignore
req_id="req5",
req_meta=req_meta,
layer_index=send_task.layer_index,
key=key,
value=value,
reshape_cache_event=MagicMock())
print(f"Callback called: {mock_callback_func.call_count} times")
mock_callback_func.assert_called_once()
class TestKVCacheRecvingLayerThread(unittest.TestCase):
@@ -468,6 +509,7 @@ class TestMooncakeLayerwiseConnectorSchedulerMatchedTokens(unittest.TestCase):
request = MockRequest("req1")
self.scheduler._reqs_need_recv["req1"] = (request, [], [4, 5, 6])
self.scheduler.vllm_config.kv_transfer_config.is_kv_consumer = True
request.kv_transfer_params = {
"remote_block_ids": [1, 2, 3],
"remote_engine_id": "remote",
@@ -505,7 +547,8 @@ class _MockSchedulerOutput:
cached_new_block_ids=None,
cached_num_computed=None,
new_reqs=None,
num_sched=None):
num_sched=None,
scheduled_spec_decode_tokens=None):
self.scheduled_cached_reqs = SimpleNamespace(
req_ids=cached_req_ids or [],
new_block_ids=cached_new_block_ids or [],
@@ -513,6 +556,7 @@ class _MockSchedulerOutput:
)
self.scheduled_new_reqs = new_reqs or []
self.num_scheduled_tokens = num_sched or {}
self.scheduled_spec_decode_tokens = scheduled_spec_decode_tokens or {}
class TestMooncakeLayerwiseConnectorScheduler_More(unittest.TestCase):
@@ -549,43 +593,39 @@ class TestMooncakeLayerwiseConnectorScheduler_More(unittest.TestCase):
self.assertFalse(req.kv_transfer_params.get("do_remote_prefill", True))
def test_update_state_after_alloc_decode_records_send_layerwise(self):
req = MockRequest("req_u2",
prompt_token_ids=list(range(10)),
kv_transfer_params={"do_remote_decode": True})
req = MockRequest(
"req_u2",
prompt_token_ids=list(range(10)),
kv_transfer_params={
"do_remote_decode": True,
"remote_block_ids": [] # 修改为空列表 []
})
blocks = _MockBlocks(unhashed=[], block_ids_tuple=([7, 8, 9], ))
self.scheduler.update_state_after_alloc(req,
blocks,
num_external_tokens=0)
self.assertIn("req_u2", self.scheduler._reqs_need_send_layerwise)
total_tokens, local_block_ids, req_ref = self.scheduler._reqs_need_send_layerwise[
"req_u2"]
self.assertEqual(total_tokens, 10)
self.assertEqual(local_block_ids, [7, 8, 9])
self.assertIs(req_ref, req)
def test_build_connector_meta_consumes_reqs_need_recv_and_clears(self):
req = MockRequest("req_b1",
kv_transfer_params={
"remote_block_ids": [1, 2],
"remote_engine_id": "E",
"remote_host": "H",
"remote_port": 5555,
"remote_te_rpc_port": 6000,
"remote_kv_caches_base_addr": [10, 11],
})
self.scheduler._reqs_need_recv["req_b1"] = (req, [], [100, 101])
meta = self.scheduler.build_connector_meta(_MockSchedulerOutput())
self.assertIsInstance(meta, MooncakeLayerwiseConnectorMetadata)
self.assertIn("req_b1", meta.requests)
self.assertEqual(meta.requests["req_b1"].local_block_ids, [100, 101])
self.assertEqual(len(self.scheduler._reqs_need_recv), 0)
info = self.scheduler._reqs_need_send_layerwise["req_u2"]
self.assertEqual(info.local_block_ids, [7, 8, 9])
self.assertIs(info.request, req)
self.assertEqual(info.remote_block_ids, [])
self.assertIsInstance(info.remote_block_ids, list)
def test_build_connector_meta_accumulates_cached_blocks(self):
req = MockRequest("req_b2",
prompt_token_ids=list(range(8)),
kv_transfer_params={"do_remote_decode": True})
req_meta = MagicMock(spec=ReqMeta)
req_meta.local_block_ids = [1, 2, 3]
req_meta.remote_block_ids = [4, 5]
req_meta.remote_engine_id = "remote"
req_meta.remote_host = "localhost"
req_meta.remote_port = 5000
req_meta.remote_te_rpc_port = 6000
req_meta.remote_kv_caches_base_addr = [10, 20]
req_meta.metaserver = "http://dummy"
req_meta.chunk_finish = False
self.scheduler._reqs_need_send_layerwise["req_b2"] = (8, [1, 2], req)
req_meta.extend_local_block_ids = MagicMock()
self.scheduler._reqs_need_send_layerwise["req_b2"] = req_meta
out = _MockSchedulerOutput(
cached_req_ids=["req_b2"],
@@ -596,47 +636,53 @@ class TestMooncakeLayerwiseConnectorScheduler_More(unittest.TestCase):
)
meta = self.scheduler.build_connector_meta(out)
self.assertEqual(len(meta.requests), 0)
total, block_ids, _ = self.scheduler._reqs_need_send_layerwise[
"req_b2"]
self.assertEqual(total, 8)
self.assertEqual(block_ids, [1, 2, 3, 4])
def test_build_connector_meta_emits_when_tokens_reach_total(self):
req_meta.extend_local_block_ids.assert_called_once_with([3, 4])
req = MockRequest("req_b3",
prompt_token_ids=list(range(12)),
kv_transfer_params={
"do_remote_decode": True,
"remote_block_ids": [9],
"remote_engine_id": "E",
"remote_host": "H",
"remote_port": 5555,
"remote_te_rpc_port": 6000,
"remote_kv_caches_base_addr": [10, 11],
})
self.scheduler._reqs_need_send_layerwise["req_b3"] = (12, [100,
101], req)
@patch(
"vllm_ascend.distributed.mooncake_layerwise_connector.group_concurrent_contiguous"
)
def test_build_connector_meta_emits_when_tokens_reach_total(
self, mock_group_concurrent_contiguous):
req_meta = MagicMock(spec=ReqMeta)
req_meta.local_block_ids = [1, 2, 3]
req_meta.remote_block_ids = [4, 5]
req_meta.remote_engine_id = "remote"
req_meta.remote_host = "localhost"
req_meta.remote_port = 5000
req_meta.remote_te_rpc_port = 6000
req_meta.remote_kv_caches_base_addr = [10, 20]
req_meta.metaserver = "http://dummy"
req_meta.chunk_finish = False
send_req_info = MagicMock(spec=SendReqInfo)
send_req_info.local_block_ids = [1, 2, 3]
send_req_info.remote_block_ids = [4, 5]
send_req_info.remote_cache_tokens = 100
send_req_info.local_transferred_tokens = 50
send_req_info.local_computed_tokens = 75
send_req_info.request = MagicMock()
send_req_info.extend_local_block_ids = MagicMock()
send_req_info.update_computed_tokens = MagicMock()
send_req_info.update_transferred_tokens = MagicMock()
send_req_info.unpack = MagicMock(
return_value=(send_req_info.local_block_ids,
send_req_info.remote_block_ids,
send_req_info.remote_cache_tokens,
send_req_info.local_transferred_tokens,
send_req_info.local_computed_tokens,
send_req_info.request))
self.scheduler._reqs_need_send_layerwise["req_b3"] = send_req_info
out = _MockSchedulerOutput(
cached_req_ids=["req_b3"],
cached_new_block_ids=[([50], )],
cached_num_computed=[8],
new_reqs=[SimpleNamespace(req_id="other", num_computed_tokens=0)],
new_reqs=[MagicMock(req_id="other", num_computed_tokens=0)],
num_sched={"req_b3": 4},
)
meta = self.scheduler.build_connector_meta(out)
send_req_info.extend_local_block_ids.assert_called_once_with([50])
self.assertIn("req_b3", meta.requests)
rmeta = meta.requests["req_b3"]
self.assertEqual(rmeta.local_block_ids, [100, 101, 50])
self.assertNotIn("req_b3", self.scheduler._reqs_need_send_layerwise)
def test_request_finished_returns_false_none(self):
ok, params = self.scheduler.request_finished(MockRequest("req_fin"),
[1, 2])
self.assertFalse(ok)
self.assertIsNone(params)
class TestHelperFunctions(unittest.TestCase):


@@ -176,6 +176,8 @@ class AscendMetadata:
causal: bool = True
# runner_type in model_config.
model_runner_type: str = ""
# prefill reshape_and_cache event
reshape_cache_event: torch.npu.Event = None
# sliding window attention mask
swa_mask: Optional[torch.Tensor] = None
@@ -333,6 +335,7 @@ class AscendAttentionBackendImpl(AttentionImpl):
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
self.key_cache = None
self.value_cache = None
self.is_kv_producer = self.vllm_config.kv_transfer_config is not None and self.vllm_config.kv_transfer_config.is_kv_producer
def full_graph_fia(self, query: torch.Tensor, key: torch.Tensor,
value: torch.Tensor, attn_metadata: AscendMetadata,
@@ -654,6 +657,8 @@ class AscendAttentionBackendImpl(AttentionImpl):
):
if len(kv_cache) > 1:
if self.is_kv_producer:
attn_metadata.reshape_cache_event = torch.npu.Event()
if self.key_cache is None:
self.key_cache, self.value_cache = kv_cache[0], kv_cache[1]
slots = attn_metadata.slot_mapping
@@ -674,6 +679,8 @@ class AscendAttentionBackendImpl(AttentionImpl):
key_cache=self.key_cache,
value_cache=self.value_cache,
slot_indices=slots[:attn_metadata.num_actual_tokens])
if self.is_kv_producer:
attn_metadata.reshape_cache_event.record()
return key, value
def forward_impl(


@@ -166,6 +166,7 @@ class AscendMLAMetadata:
decode: Optional[AscendMLADecodeMetadata] = None
prefill: Optional[AscendMLAPrefillMetadata] = None
reshape_cache_event: torch.npu.Event = None
def __post_init__(self):
pass
@@ -705,6 +706,7 @@ class AscendMLAImpl(MLAAttentionImpl):
kv_sharing_target_layer_name: Optional[str],
**kwargs,
):
self.vllm_config = get_current_vllm_config()
self.num_heads = num_heads
self.head_size = head_size
self.scale = float(scale)
@@ -752,6 +754,8 @@ class AscendMLAImpl(MLAAttentionImpl):
self.speculative_config = self.vllm_config.speculative_config
self.enable_mlapo = envs.VLLM_ASCEND_ENABLE_MLAPO
self.is_kv_producer = self.vllm_config.kv_transfer_config is not None and self.vllm_config.kv_transfer_config.is_kv_producer
def _v_up_proj(self, x):
# Convert from (N, B, L)/(N, B, 1, L) to (N, B, L)
x = x.view(self.num_heads, -1, self.kv_lora_rank)
@@ -1351,8 +1355,12 @@ class AscendMLAImpl(MLAAttentionImpl):
prefill_slots = attn_metadata.slot_mapping[
num_decode_tokens:num_actual_tokens]
prefill_q_pe = self.rope_single(prefill_q_pe, cos, sin)
if self.is_kv_producer:
attn_metadata.reshape_cache_event = torch.npu.Event()
prefill_k_pe, prefill_k_c_normed = self.exec_kv_prefill(
prefill_kv_no_split, cos, sin, kv_cache, prefill_slots)
if self.is_kv_producer:
attn_metadata.reshape_cache_event.record()
prefill_k_nope, prefill_value = self.kv_b_proj(
prefill_k_c_normed)[0].view(
-1, self.num_heads,


@@ -65,6 +65,32 @@ class ReqMeta:
remote_te_rpc_port: Optional[int]
remote_kv_caches_base_addr: Optional[list[int]]
metaserver: Optional[str]
chunk_finish: Optional[bool]
@dataclass
class SendReqInfo:
local_block_ids: list[int]
remote_block_ids: List[int]
remote_cache_tokens: int
local_transferred_tokens: int
local_computed_tokens: int
request: "Request"
def extend_local_block_ids(self, new_block_ids: List[int]) -> None:
"""extend local block ids for this step"""
self.local_block_ids.extend(new_block_ids)
def update_computed_tokens(self, computed_tokens: int) -> None:
"""update local computen tokens for this step"""
self.local_computed_tokens = computed_tokens
def update_transferred_tokens(self, transferred_tokens: int) -> None:
"""update transferred tokens for this step"""
self.local_transferred_tokens = transferred_tokens
def unpack(self):
return self.local_block_ids, self.remote_block_ids, self.remote_cache_tokens, self.local_transferred_tokens, self.local_computed_tokens, self.request
@dataclass
@@ -144,7 +170,7 @@ class KVCacheSendingLayerThread(threading.Thread):
raise RuntimeError("Mooncake memory registration failed. ")
self.send_queue = queue.Queue[Tuple[str, ReqMeta, int, torch.Tensor,
torch.Tensor]]()
torch.Tensor, torch.npu.Event]]()
self.ready_event = ready_event
self.callback_func = callback_func
@@ -155,15 +181,19 @@ class KVCacheSendingLayerThread(threading.Thread):
torch.npu.set_device(device)
self.ready_event.set()
while True:
req_id, req_meta, layer_index, key, value = self.send_queue.get()
self._handle_request(req_id, req_meta, layer_index, key, value)
req_id, req_meta, layer_index, key, value, reshape_cache_event = self.send_queue.get(
)
self._handle_request(req_id, req_meta, layer_index, key, value,
reshape_cache_event)
def _handle_request(self, req_id, req_meta, layer_index, key, value):
def _handle_request(self, req_id, req_meta, layer_index, key, value,
reshape_cache_event):
try:
logger.debug(
f"Starting to transfer KV cache for request {req_id} {req_meta.remote_te_rpc_port=}."
)
self._transfer_kv_cache(req_id, req_meta, layer_index, key, value)
self._transfer_kv_cache(req_id, req_meta, layer_index, key, value,
reshape_cache_event)
logger.debug(
f"Finished transferring KV cache for request {req_id} {req_meta.remote_te_rpc_port=}."
)
@@ -171,13 +201,8 @@ class KVCacheSendingLayerThread(threading.Thread):
logger.error("Failed to transfer KV cache for request "
f"{req_id}: {e}")
def _transfer_kv_cache(self, req_id, req_meta, layer_index, key, value):
# send kv layer to remote
if len(req_meta.local_block_ids) == 0:
logger.debug(
f"Cancelling KV cache transfer for request {req_id}. Reason: No local blocks to transfer."
)
return
def _transfer_kv_cache(self, req_id, req_meta, layer_index, key, value,
reshape_cache_event):
# no need to send kv cache
if self.tp_rank % self.num_head_replica != 0:
logger.debug(
@@ -227,7 +252,13 @@ class KVCacheSendingLayerThread(threading.Thread):
length_list.append(length)
if self.current_layer != layer_index:
self.current_layer = layer_index
self.model_stream.synchronize()
"""
Note: Due to a bug in ADXL, calling current_event.synchronize() may occasionally hang.
This issue will be fixed in CANN version 8.5.RC1.
You can manually build the master branch of the project at https://gitcode.com/cann/hixl
to resolve this issue before the 8.5.RC1 release.
"""
reshape_cache_event.synchronize()
ret = self.engine.batch_transfer_sync_write(
session_id, src_list, dst_list, length_list)
if ret < 0:
@@ -285,7 +316,7 @@ class KVCacheSendingLayerThread(threading.Thread):
logger.error("Mooncake transfer failed for request %s", req_id)
raise RuntimeError(f"Mooncake transfer failed, ret: {ret}")
if layer_index == (self.total_layers - 1):
if layer_index == (self.total_layers - 1) and req_meta.chunk_finish:
self.callback_func(req_id, req_meta)
@@ -376,7 +407,8 @@ class MooncakeLayerwiseConnectorMetadata(KVConnectorMetadata):
request_id: str,
local_block_ids: list[int],
kv_transfer_params: dict[str, Any],
token_ids: Optional[list[int]] = None):
token_ids: Optional[list[int]] = None,
chunk_finish: bool = False):
self.requests[request_id] = ReqMeta(
token_ids=token_ids or [],
local_block_ids=local_block_ids,
@@ -389,7 +421,7 @@ class MooncakeLayerwiseConnectorMetadata(KVConnectorMetadata):
remote_kv_caches_base_addr=kv_transfer_params.get(
"remote_kv_caches_base_addr", None),
metaserver=kv_transfer_params.get("metaserver", None),
)
chunk_finish=chunk_finish)
class MooncakeLayerwiseConnector(KVConnectorBase_V1):
@@ -398,6 +430,7 @@ class MooncakeLayerwiseConnector(KVConnectorBase_V1):
vllm_config: VllmConfig,
role: KVConnectorRole,
kv_cache_config: Optional[KVCacheConfig] = None):
super().__init__(vllm_config, role, kv_cache_config)
assert vllm_config.kv_transfer_config is not None
self.engine_id = vllm_config.kv_transfer_config.engine_id
self._connector_metadata = MooncakeLayerwiseConnectorMetadata()
@@ -509,9 +542,11 @@ class MooncakeLayerwiseConnectorScheduler:
# the scheduler. Used to make metadata passed to Worker.
self._reqs_need_recv: dict[str, tuple[Request, list[int],
list[int]]] = {}
self._reqs_need_send_layerwise: dict[str, tuple[
int, list[int],
Request]] = {} # req_id, (len(prompt), local_block_ids, request)
self._reqs_need_send_layerwise: dict[str, SendReqInfo] = {}
self.executor = ThreadPoolExecutor(32)
self.metaserver_client = httpx.Client(
limits=httpx.Limits(max_connections=100000), timeout=None)
def get_num_new_matched_tokens(
self, request: "Request",
@@ -571,14 +606,53 @@ class MooncakeLayerwiseConnectorScheduler:
params["do_remote_prefill"] = False
logger.info(
f"Send request: {request.request_id} to proxy metaserver: {params.get('metaserver', None)}"
)
# All parameters here should appear in the returned dict of
# request_finished in the scheduler side except "request_id".
kv_transfer_params = dict(
token_ids=[],
request_id=request.request_id,
do_remote_prefill=False,
do_remote_decode=True,
remote_block_ids=local_block_ids,
remote_engine_id=self.engine_id,
remote_host=self.side_channel_host,
remote_port=self.side_channel_port,
)
future = self.executor.submit(
self._access_metaserver,
url=params.get("metaserver", None),
message=kv_transfer_params,
)
def handle_exception(future):
if future.exception():
logger.error(
f"Access metaserver fail: {future.exception()}")
future.add_done_callback(handle_exception)
# Layerwise prefiller add request need send
if params is not None and params.get("do_remote_decode"):
local_block_ids = (blocks.get_block_ids()[0])
logger.debug(
f"MooncakeLayerwiseConnector update_state_after_alloc: add {request.request_id} to need send queue"
)
self._reqs_need_send_layerwise[request.request_id] = (len(
request.all_token_ids), local_block_ids, request)
remote_block_ids = copy.deepcopy(params["remote_block_ids"])
remote_cache_tokens = (
(len(request.all_token_ids) + self.block_size - 1) //
self.block_size - len(remote_block_ids)) * self.block_size
local_transferred_tokens = remote_cache_tokens
local_computed_tokens = 0
self._reqs_need_send_layerwise[request.request_id] = SendReqInfo(
local_block_ids=local_block_ids,
remote_block_ids=remote_block_ids,
remote_cache_tokens=remote_cache_tokens,
local_transferred_tokens=local_transferred_tokens,
local_computed_tokens=local_computed_tokens,
request=request)
def build_connector_meta(
self,
@@ -586,55 +660,118 @@ class MooncakeLayerwiseConnectorScheduler:
) -> KVConnectorMetadata:
meta = MooncakeLayerwiseConnectorMetadata()
# Loop through scheduled reqs and convert to ReqMeta.
for req_id, (req, token_ids,
block_ids) in self._reqs_need_recv.items():
assert req.kv_transfer_params is not None
# For the case where there are no remote blocks to pull
# (block_ids is empty), we don't need to schedule
# an async read on the worker side.
meta.add_new_req(request_id=req_id,
local_block_ids=block_ids,
kv_transfer_params=req.kv_transfer_params,
token_ids=token_ids)
if self.vllm_config.kv_transfer_config.is_kv_consumer:
# Loop through scheduled reqs and convert to ReqMeta.
for req_id, (req, token_ids,
block_ids) in self._reqs_need_recv.items():
assert req.kv_transfer_params is not None
# For the case where there are no remote blocks to pull
# (block_ids is empty), we don't need to schedule
# an async read on the worker side.
meta.add_new_req(request_id=req_id,
local_block_ids=block_ids,
kv_transfer_params=req.kv_transfer_params,
token_ids=token_ids)
# Clear the list once workers start the transfers
self._reqs_need_recv.clear()
# Clear the list once workers start the transfers
self._reqs_need_recv.clear()
else:
cached_reqs = scheduler_output.scheduled_cached_reqs
new_reqs = scheduler_output.scheduled_new_reqs
scheduled_spec_decode_tokens = scheduler_output.scheduled_spec_decode_tokens
# update local block ids
for req_id, new_blocks in zip(cached_reqs.req_ids,
cached_reqs.new_block_ids):
if req_id in self._reqs_need_send_layerwise and new_blocks is not None:
self._reqs_need_send_layerwise[
req_id].extend_local_block_ids(new_blocks[0])
cached_reqs = scheduler_output.scheduled_cached_reqs
new_reqs = scheduler_output.scheduled_new_reqs
for req_id, new_blocks in zip(cached_reqs.req_ids,
cached_reqs.new_block_ids):
if req_id in self._reqs_need_send_layerwise and new_blocks is not None:
total_tokens, block_ids, req = self._reqs_need_send_layerwise[
req_id]
block_ids.extend(new_blocks[0])
computed_tokens = dict(
list(zip(cached_reqs.req_ids, cached_reqs.num_computed_tokens))
+ [(x.req_id, x.num_computed_tokens) for x in new_reqs])
for req_id, scheduled_tokens in scheduler_output.num_scheduled_tokens.items(
):
if req_id in self._reqs_need_send_layerwise:
send_req_info = self._reqs_need_send_layerwise[req_id]
# update local computed tokens, not transfer spec decode tokens
spec_decode_tokens = len(
scheduled_spec_decode_tokens[req_id]) if (
req_id in scheduled_spec_decode_tokens) else 0
send_req_info.update_computed_tokens(
computed_tokens.get(req_id, 0) + scheduled_tokens -
spec_decode_tokens)
computed_tokens = dict(
list(zip(cached_reqs.req_ids, cached_reqs.num_computed_tokens)) +
[(x.req_id, x.num_computed_tokens) for x in new_reqs])
for req_id, scheduled_tokens in scheduler_output.num_scheduled_tokens.items(
):
if req_id in self._reqs_need_send_layerwise:
total_tokens, block_ids, req = self._reqs_need_send_layerwise[
req_id]
current_tokens = computed_tokens.get(req_id,
0) + scheduled_tokens
if current_tokens >= total_tokens:
logger.debug(
f"MooncakeLayerwiseConnector build_connector_meta: add {req_id}, current tokens({current_tokens}={computed_tokens.get(req_id,0)}+{scheduled_tokens}), total tokens({total_tokens})"
)
meta.add_new_req(request_id=req_id,
local_block_ids=block_ids,
kv_transfer_params=req.kv_transfer_params,
token_ids=[])
self._reqs_need_send_layerwise.pop(req_id)
else:
logger.debug(
f"MooncakeLayerwiseConnector build_connector_meta: skip {req_id}, current tokens({current_tokens}={computed_tokens.get(req_id,0)}+{scheduled_tokens}), total tokens({total_tokens})"
)
def add_transfer_task(req_id,
send_req_info: SendReqInfo,
chunk_finish=False):
local_block_ids, remote_block_ids, remote_cache_tokens, local_transferred_tokens, local_computed_tokens, request = send_req_info.unpack(
)
local_trans_block_ids = local_block_ids[(
local_transferred_tokens //
self.block_size):(local_computed_tokens //
self.block_size)]
remote_trans_block_ids = remote_block_ids[(
(local_transferred_tokens - remote_cache_tokens) //
self.block_size):((local_computed_tokens -
remote_cache_tokens) //
self.block_size)]
request.kv_transfer_params[
"remote_block_ids"] = remote_trans_block_ids
assert len(local_trans_block_ids) == len(
remote_trans_block_ids
), f"len of local trans block ids : {len(local_trans_block_ids)} not equal to the len of remote trans block ids : {len(remote_trans_block_ids)}"
adjusted_tokens = local_computed_tokens - (
self.block_size -
1) if chunk_finish else local_computed_tokens
logger.info(
f"MooncakeLayerwiseConnector scheduler add transfer task: {req_id=} {local_block_ids=} {remote_block_ids=} {local_trans_block_ids=} {remote_trans_block_ids=} local_computed_tokens={adjusted_tokens} request.all_token_ids={len(request.all_token_ids)}"
)
meta.add_new_req(
request_id=req_id,
local_block_ids=local_trans_block_ids,
kv_transfer_params=request.kv_transfer_params,
token_ids=[],
chunk_finish=chunk_finish)
# update local_transferred_tokens
local_transferred_tokens = (
local_computed_tokens //
self.block_size) * self.block_size
send_req_info.update_transferred_tokens(
local_transferred_tokens)
# no chunk or last chunk
if send_req_info.local_computed_tokens >= len(
send_req_info.request.all_token_ids):
send_req_info.update_computed_tokens(
send_req_info.local_computed_tokens +
self.block_size - 1)
add_transfer_task(req_id,
send_req_info,
chunk_finish=True)
self._reqs_need_send_layerwise.pop(req_id)
# chunk
elif (send_req_info.local_computed_tokens //
self.block_size) - (
send_req_info.local_transferred_tokens //
self.block_size) > 0:
add_transfer_task(req_id, send_req_info)
return meta
def _access_metaserver(self, url, message):
success = False
retry = 0
while retry < 3 and success is False:
retry += 1
try:
self.metaserver_client.post(url, json=message)
success = True
except Exception as e:
logger.error(
f"Failed to connect to metaserver: {url}, retry {retry} time."
)
if retry == 3:
raise e
def request_finished(
self,
request: "Request",
@@ -676,11 +813,6 @@ class MooncakeLayerwiseConnectorWorker:
self.total_layers = vllm_config.model_config.get_num_layers(
vllm_config.parallel_config)
self.executor = ThreadPoolExecutor(32)
self.metaserver_client = httpx.Client(
limits=httpx.Limits(max_connections=100000),
timeout=None) if self.tp_rank == 0 else None
# Handshake base port
self.side_channel_port = (
vllm_config.kv_transfer_config.kv_port +
@@ -834,21 +966,6 @@ class MooncakeLayerwiseConnectorWorker:
self.kv_recv_layer_thread.start()
ready_event.wait()
def _access_metaserver(self, url, message):
success = False
retry = 0
while retry < 3 and success is False:
retry += 1
try:
self.metaserver_client.post(url, json=message)
success = True
except Exception as e:
logger.error(
f"Failed to connect to metaserver: {url}, retry {retry} time."
)
if retry == 3:
raise e
def get_finished(self) -> tuple[set[str], set[str]]:
done_recving = (
self.kv_recv_layer_thread.
@@ -865,35 +982,6 @@ class MooncakeLayerwiseConnectorWorker:
self.current_layer = 0
if self.vllm_config.kv_transfer_config.is_kv_consumer:
for req_id, meta in metadata.requests.items():
if self.tp_rank % self.tp_size == 0:
logger.info(
f"Send request: {req_id} to proxy metaserver: {meta.metaserver}"
)
# All parameters here should appear in the returned dict of
# request_finished in the scheduler side except "request_id".
kv_transfer_params = dict(
token_ids=meta.token_ids,
request_id=req_id,
do_remote_prefill=False,
do_remote_decode=True,
remote_block_ids=meta.local_block_ids,
remote_engine_id=self.engine_id,
remote_host=self.side_channel_host,
remote_port=self.side_channel_port,
)
future = self.executor.submit(
self._access_metaserver,
url=meta.metaserver,
message=kv_transfer_params,
)
def handle_exception(future):
if future.exception():
logger.error(
f"Access metaserver fail: {future.exception()}"
)
future.add_done_callback(handle_exception)
assert self.kv_recv_layer_thread is not None
with self.kv_recv_layer_thread.lock:
self.kv_recv_layer_thread.task_tracker[req_id] = 0
@@ -907,12 +995,12 @@ class MooncakeLayerwiseConnectorWorker:
if self.vllm_config.kv_transfer_config.is_kv_producer and connector_metadata.requests.keys(
):
# enable decode prefix cache
for request in connector_metadata.requests.values():
assert len(request.local_block_ids) >= len(
request.remote_block_ids
), "When prefix cache enabled, remote KVCacheBlocks num should not larger than local KVCacheBlocks num."
request.local_block_ids = request.local_block_ids[
-len(request.remote_block_ids):]
if self.use_mla:
reshape_cache_event = attn_metadata[
layer_name].reshape_cache_event
else:
reshape_cache_event = attn_metadata.reshape_cache_event
if self.pd_head_ratio != 1:
def sort_kv_cache(input_kv: list[list[int]]):
@@ -964,8 +1052,10 @@ class MooncakeLayerwiseConnectorWorker:
f"Add request {req_id} to kv send layer thread. {req_meta_update=}"
)
assert self.kv_send_layer_thread is not None
assert reshape_cache_event is not None
self.kv_send_layer_thread.send_queue.put(
(req_id, req_meta_update, self.current_layer, key, value))
(req_id, req_meta_update, self.current_layer, key, value,
reshape_cache_event))
self.current_layer += 1
def _get_remote_socket(