[P/D][BugFix] Fix layerwise P/D request_id error (#6360)
### What this PR does / why we need it?
Fix the request_id error in the layerwise Connector for P/D, caused by vLLM PR
https://github.com/vllm-project/vllm/pull/27987, which adds a random
suffix to the request_id in EngineCore.
- vLLM version: v0.14.1
- vLLM main:
dc917cceb8
---------
Signed-off-by: nwpu-zxr <zhouxuerong2@huawei.com>
This commit is contained in:
@@ -240,6 +240,7 @@ class TestKVCacheRecvingLayerThread(unittest.TestCase):
|
|||||||
|
|
||||||
with th.lock:
|
with th.lock:
|
||||||
th.task_tracker["reqX"] = 0
|
th.task_tracker["reqX"] = 0
|
||||||
|
th.request_map["reqX"] = "reqX"
|
||||||
|
|
||||||
th.update_task("reqX")
|
th.update_task("reqX")
|
||||||
with th.lock:
|
with th.lock:
|
||||||
@@ -313,6 +314,7 @@ class TestKVCacheRecvingLayerThread(unittest.TestCase):
|
|||||||
|
|
||||||
with th.lock:
|
with th.lock:
|
||||||
th.task_tracker["reqA"] = 0
|
th.task_tracker["reqA"] = 0
|
||||||
|
th.request_map["reqA"] = "reqA"
|
||||||
|
|
||||||
with self.assertRaises(SystemExit):
|
with self.assertRaises(SystemExit):
|
||||||
th.run()
|
th.run()
|
||||||
@@ -386,6 +388,7 @@ class TestKVCacheRecvingLayerThread(unittest.TestCase):
|
|||||||
ready_event=self.ready_event)
|
ready_event=self.ready_event)
|
||||||
with th.lock:
|
with th.lock:
|
||||||
th.task_tracker["reqB"] = 0
|
th.task_tracker["reqB"] = 0
|
||||||
|
th.request_map["reqB"] = "reqB"
|
||||||
with self.assertRaises(SystemExit):
|
with self.assertRaises(SystemExit):
|
||||||
th.run()
|
th.run()
|
||||||
|
|
||||||
|
|||||||
@@ -355,6 +355,7 @@ class KVCacheRecvingLayerThread(threading.Thread):
|
|||||||
self.lock = threading.Lock()
|
self.lock = threading.Lock()
|
||||||
self.done_requests = set[str]()
|
self.done_requests = set[str]()
|
||||||
self.task_tracker = dict[str, int]()
|
self.task_tracker = dict[str, int]()
|
||||||
|
self.request_map = dict[str, str]()
|
||||||
self.ready_event = ready_event
|
self.ready_event = ready_event
|
||||||
self.metadata = metadata
|
self.metadata = metadata
|
||||||
|
|
||||||
@@ -374,7 +375,8 @@ class KVCacheRecvingLayerThread(threading.Thread):
|
|||||||
self.task_tracker[req_id] += 1
|
self.task_tracker[req_id] += 1
|
||||||
if self.task_tracker[req_id] == self.pd_head_ratio:
|
if self.task_tracker[req_id] == self.pd_head_ratio:
|
||||||
self.task_tracker.pop(req_id)
|
self.task_tracker.pop(req_id)
|
||||||
self.done_requests.add(req_id)
|
self.done_requests.add(self.request_map[req_id])
|
||||||
|
self.request_map.pop(req_id)
|
||||||
|
|
||||||
def run(self):
|
def run(self):
|
||||||
"""Run the thread to handle KV cache transfer requests."""
|
"""Run the thread to handle KV cache transfer requests."""
|
||||||
@@ -615,9 +617,11 @@ class MooncakeLayerwiseConnectorScheduler:
|
|||||||
logger.info(f"Send request: {request.request_id} to proxy metaserver: {params.get('metaserver', None)}")
|
logger.info(f"Send request: {request.request_id} to proxy metaserver: {params.get('metaserver', None)}")
|
||||||
# All parameters here should appear in the returned dict of
|
# All parameters here should appear in the returned dict of
|
||||||
# request_finished in the scheduler side except "request_id".
|
# request_finished in the scheduler side except "request_id".
|
||||||
|
# change the format of request_id if vllm-version >= 0.14.0
|
||||||
|
external_req_id = get_external_request_id(request.request_id)
|
||||||
kv_transfer_params = dict(
|
kv_transfer_params = dict(
|
||||||
token_ids=[],
|
token_ids=[],
|
||||||
request_id=request.request_id,
|
request_id=external_req_id,
|
||||||
do_remote_prefill=False,
|
do_remote_prefill=False,
|
||||||
do_remote_decode=True,
|
do_remote_decode=True,
|
||||||
remote_block_ids=local_block_ids,
|
remote_block_ids=local_block_ids,
|
||||||
@@ -1050,9 +1054,11 @@ class MooncakeLayerwiseConnectorWorker:
|
|||||||
self.current_layer = 0
|
self.current_layer = 0
|
||||||
if self.vllm_config.kv_transfer_config.is_kv_consumer:
|
if self.vllm_config.kv_transfer_config.is_kv_consumer:
|
||||||
for req_id, meta in metadata.requests.items():
|
for req_id, meta in metadata.requests.items():
|
||||||
|
external_req_id = get_external_request_id(req_id)
|
||||||
assert self.kv_recv_layer_thread is not None
|
assert self.kv_recv_layer_thread is not None
|
||||||
with self.kv_recv_layer_thread.lock:
|
with self.kv_recv_layer_thread.lock:
|
||||||
self.kv_recv_layer_thread.task_tracker[req_id] = 0
|
self.kv_recv_layer_thread.task_tracker[external_req_id] = 0
|
||||||
|
self.kv_recv_layer_thread.request_map[external_req_id] = req_id
|
||||||
|
|
||||||
def save_kv_layer(
|
def save_kv_layer(
|
||||||
self,
|
self,
|
||||||
@@ -1213,13 +1219,17 @@ class MooncakeLayerwiseConnectorWorker:
|
|||||||
return req_meta_update
|
return req_meta_update
|
||||||
|
|
||||||
def send_done_send_signal(self, req_id, req_meta):
|
def send_done_send_signal(self, req_id, req_meta):
|
||||||
|
external_req_id = get_external_request_id(req_id)
|
||||||
logger.info(
|
logger.info(
|
||||||
"Sending done sending signal for request %s to %s:%d", req_id, req_meta.remote_host, req_meta.remote_port
|
"Sending done sending signal for request %s to %s:%d",
|
||||||
|
external_req_id,
|
||||||
|
req_meta.remote_host,
|
||||||
|
req_meta.remote_port,
|
||||||
)
|
)
|
||||||
try:
|
try:
|
||||||
path = make_zmq_path("tcp", req_meta.remote_host, req_meta.remote_port)
|
path = make_zmq_path("tcp", req_meta.remote_host, req_meta.remote_port)
|
||||||
msg_encoder = msgspec.msgpack.Encoder()
|
msg_encoder = msgspec.msgpack.Encoder()
|
||||||
encoded_data = msg_encoder.encode((DONE_SENDING_MSG, req_id))
|
encoded_data = msg_encoder.encode((DONE_SENDING_MSG, external_req_id))
|
||||||
with zmq_ctx(zmq.REQ, path) as sock: # type: ignore
|
with zmq_ctx(zmq.REQ, path) as sock: # type: ignore
|
||||||
ensure_zmq_send(sock, encoded_data)
|
ensure_zmq_send(sock, encoded_data)
|
||||||
ack = sock.recv()
|
ack = sock.recv()
|
||||||
@@ -1227,7 +1237,7 @@ class MooncakeLayerwiseConnectorWorker:
|
|||||||
raise ValueError(f"Unexpected ACK response: {ack}")
|
raise ValueError(f"Unexpected ACK response: {ack}")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(
|
logger.error(
|
||||||
f"Sending done sending signal for request {req_id} to "
|
f"Sending done sending signal for request {external_req_id} to "
|
||||||
f"{req_meta.remote_host}:{req_meta.remote_port} fail with error: {e}"
|
f"{req_meta.remote_host}:{req_meta.remote_port} fail with error: {e}"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -1338,3 +1348,9 @@ def ensure_zmq_recv(
|
|||||||
else:
|
else:
|
||||||
logger.error(f"Receive failed after all retries: {e}")
|
logger.error(f"Receive failed after all retries: {e}")
|
||||||
raise RuntimeError(f"Failed to receive data after {max_retries} retries: {e}")
|
raise RuntimeError(f"Failed to receive data after {max_retries} retries: {e}")
|
||||||
|
|
||||||
|
|
||||||
|
def get_external_request_id(request_id: str, suffix_len: int = 9) -> str:
    """Return ``request_id`` with its trailing EngineCore suffix removed.

    NOTE(zxr): vLLM PR #27987 adds an additional suffix to the EngineCore
    request_id with len(suffix) == 9, so the last ``suffix_len`` characters
    are dropped to recover the id without that suffix.

    Args:
        request_id: The request id to strip.
        suffix_len: Number of trailing characters that make up the suffix.
            Defaults to 9.

    Returns:
        ``request_id`` without its last ``suffix_len`` characters. An id that
        is not longer than the suffix (or a non-positive ``suffix_len``)
        is returned unchanged.
    """
    # An id no longer than the suffix cannot hold both an id and the suffix;
    # slicing it would collapse distinct short ids into the same empty key.
    # The `suffix_len <= 0` guard also avoids `request_id[:-0]`, which is "".
    if suffix_len <= 0 or len(request_id) <= suffix_len:
        return request_id
    return request_id[:-suffix_len]
|
||||||
|
|||||||
Reference in New Issue
Block a user