[P/D][BugFix] Fix layerwise P/D request_id error (#6360)

### What this PR does / why we need it?
Fix layerwise Connector P/D request_id error, due to vllm pr:
https://github.com/vllm-project/vllm/pull/27987, which will add a random
suffix to request_id in EngineCore.

- vLLM version: v0.14.1
- vLLM main:
dc917cceb8
---------
Signed-off-by: nwpu-zxr <zhouxuerong2@huawei.com>
This commit is contained in:
zxr2333
2026-01-29 20:19:05 +08:00
committed by GitHub
parent feab047084
commit 14bd55f30c
2 changed files with 25 additions and 6 deletions

View File

@@ -240,6 +240,7 @@ class TestKVCacheRecvingLayerThread(unittest.TestCase):
with th.lock:
th.task_tracker["reqX"] = 0
th.request_map["reqX"] = "reqX"
th.update_task("reqX")
with th.lock:
@@ -313,6 +314,7 @@ class TestKVCacheRecvingLayerThread(unittest.TestCase):
with th.lock:
th.task_tracker["reqA"] = 0
th.request_map["reqA"] = "reqA"
with self.assertRaises(SystemExit):
th.run()
@@ -386,6 +388,7 @@ class TestKVCacheRecvingLayerThread(unittest.TestCase):
ready_event=self.ready_event)
with th.lock:
th.task_tracker["reqB"] = 0
th.request_map["reqB"] = "reqB"
with self.assertRaises(SystemExit):
th.run()

View File

@@ -355,6 +355,7 @@ class KVCacheRecvingLayerThread(threading.Thread):
self.lock = threading.Lock()
self.done_requests = set[str]()
self.task_tracker = dict[str, int]()
self.request_map = dict[str, str]()
self.ready_event = ready_event
self.metadata = metadata
@@ -374,7 +375,8 @@ class KVCacheRecvingLayerThread(threading.Thread):
self.task_tracker[req_id] += 1
if self.task_tracker[req_id] == self.pd_head_ratio:
self.task_tracker.pop(req_id)
self.done_requests.add(req_id)
self.done_requests.add(self.request_map[req_id])
self.request_map.pop(req_id)
def run(self):
"""Run the thread to handle KV cache transfer requests."""
@@ -615,9 +617,11 @@ class MooncakeLayerwiseConnectorScheduler:
logger.info(f"Send request: {request.request_id} to proxy metaserver: {params.get('metaserver', None)}")
# All parameters here should appear in the returned dict of
# request_finished in the scheduler side except "request_id".
# change the format of request_id if vllm-version >= 0.14.0
external_req_id = get_external_request_id(request.request_id)
kv_transfer_params = dict(
token_ids=[],
request_id=request.request_id,
request_id=external_req_id,
do_remote_prefill=False,
do_remote_decode=True,
remote_block_ids=local_block_ids,
@@ -1050,9 +1054,11 @@ class MooncakeLayerwiseConnectorWorker:
self.current_layer = 0
if self.vllm_config.kv_transfer_config.is_kv_consumer:
for req_id, meta in metadata.requests.items():
external_req_id = get_external_request_id(req_id)
assert self.kv_recv_layer_thread is not None
with self.kv_recv_layer_thread.lock:
self.kv_recv_layer_thread.task_tracker[req_id] = 0
self.kv_recv_layer_thread.task_tracker[external_req_id] = 0
self.kv_recv_layer_thread.request_map[external_req_id] = req_id
def save_kv_layer(
self,
@@ -1213,13 +1219,17 @@ class MooncakeLayerwiseConnectorWorker:
return req_meta_update
def send_done_send_signal(self, req_id, req_meta):
external_req_id = get_external_request_id(req_id)
logger.info(
"Sending done sending signal for request %s to %s:%d", req_id, req_meta.remote_host, req_meta.remote_port
"Sending done sending signal for request %s to %s:%d",
external_req_id,
req_meta.remote_host,
req_meta.remote_port,
)
try:
path = make_zmq_path("tcp", req_meta.remote_host, req_meta.remote_port)
msg_encoder = msgspec.msgpack.Encoder()
encoded_data = msg_encoder.encode((DONE_SENDING_MSG, req_id))
encoded_data = msg_encoder.encode((DONE_SENDING_MSG, external_req_id))
with zmq_ctx(zmq.REQ, path) as sock: # type: ignore
ensure_zmq_send(sock, encoded_data)
ack = sock.recv()
@@ -1227,7 +1237,7 @@ class MooncakeLayerwiseConnectorWorker:
raise ValueError(f"Unexpected ACK response: {ack}")
except Exception as e:
logger.error(
f"Sending done sending signal for request {req_id} to "
f"Sending done sending signal for request {external_req_id} to "
f"{req_meta.remote_host}:{req_meta.remote_port} fail with error: {e}"
)
@@ -1338,3 +1348,9 @@ def ensure_zmq_recv(
else:
logger.error(f"Receive failed after all retries: {e}")
raise RuntimeError(f"Failed to receive data after {max_retries} retries: {e}")
def get_external_request_id(request_id: str):
# NOTE(zxr): vLLM PR #27987 add additional suffix
# to EngineCore request_id with len(suffix) == 9
return request_id[:-9]