[BugFix][P/D][0.18.0]Add a retry mechanism to prevent packet loss (#8167)
### What this PR does / why we need it? Add a retry mechanism to prevent packet loss Signed-off-by: wangxiaoteng <wangxiaoteng@huawei.com>
This commit is contained in:
@@ -541,12 +541,12 @@ class KVCacheRecvingLayerThread(threading.Thread):
|
|||||||
self.done_requests = set()
|
self.done_requests = set()
|
||||||
return finished_requests
|
return finished_requests
|
||||||
|
|
||||||
def update_task(self, req_id, trans_count):
|
def update_task(self, req_id, trans_count, side_channel_path):
|
||||||
with self.lock:
|
with self.lock:
|
||||||
if req_id not in self.task_tracker:
|
if req_id not in self.task_tracker:
|
||||||
self.task_tracker[req_id] = 0
|
self.task_tracker[req_id] = set()
|
||||||
self.task_tracker[req_id] += 1
|
self.task_tracker[req_id].add(side_channel_path)
|
||||||
if self.task_tracker[req_id] == trans_count:
|
if len(self.task_tracker[req_id]) == trans_count:
|
||||||
self.task_tracker.pop(req_id)
|
self.task_tracker.pop(req_id)
|
||||||
self.done_requests.add(req_id)
|
self.done_requests.add(req_id)
|
||||||
|
|
||||||
@@ -581,7 +581,8 @@ class KVCacheRecvingLayerThread(threading.Thread):
|
|||||||
logger.debug("Got DONE_RECVING_MSG for request %s", msg[1])
|
logger.debug("Got DONE_RECVING_MSG for request %s", msg[1])
|
||||||
request_id = msg[1]
|
request_id = msg[1]
|
||||||
trans_count = msg[2]
|
trans_count = msg[2]
|
||||||
self.update_task(request_id, trans_count)
|
side_channel_path = msg[3]
|
||||||
|
self.update_task(request_id, trans_count, side_channel_path)
|
||||||
sock.send_multipart((identity, b"", b"ACK"))
|
sock.send_multipart((identity, b"", b"ACK"))
|
||||||
else:
|
else:
|
||||||
logger.error("Connection listener got unexpected message %s", msg)
|
logger.error("Connection listener got unexpected message %s", msg)
|
||||||
@@ -1760,24 +1761,39 @@ class MooncakeLayerwiseConnectorWorker:
|
|||||||
try:
|
try:
|
||||||
path = make_zmq_path("tcp", req_meta.remote_host, req_meta.remote_port)
|
path = make_zmq_path("tcp", req_meta.remote_host, req_meta.remote_port)
|
||||||
msg_encoder = msgspec.msgpack.Encoder()
|
msg_encoder = msgspec.msgpack.Encoder()
|
||||||
encoded_data = msg_encoder.encode((DONE_SENDING_MSG, external_req_id, req_meta.trans_count[group_idx]))
|
side_channel_path = f"{self.side_channel_host}:{self.handshake_port}"
|
||||||
with zmq_ctx(zmq.REQ, path) as sock: # type: ignore
|
encoded_data = msg_encoder.encode(
|
||||||
ensure_zmq_send(sock, encoded_data, f"{req_meta.remote_host}:{req_meta.remote_port}")
|
(DONE_SENDING_MSG, external_req_id, req_meta.trans_count[group_idx], side_channel_path)
|
||||||
# Avoid blocking forever waiting for the REQ/ACK response.
|
)
|
||||||
sock.setsockopt(zmq.RCVTIMEO, int(self.timeout * 1000)) # type: ignore
|
max_retries = 3
|
||||||
|
for attempt in range(1, max_retries + 1):
|
||||||
try:
|
try:
|
||||||
ack = sock.recv()
|
with zmq_ctx(zmq.REQ, path) as sock: # type: ignore
|
||||||
except zmq.Again: # type: ignore
|
sock.setsockopt(zmq.SNDTIMEO, int(self.timeout * 1000))
|
||||||
logger.warning(
|
ensure_zmq_send(sock, encoded_data, f"{req_meta.remote_host}:{req_meta.remote_port}")
|
||||||
"Timeout waiting ACK for request %s from %s:%d (timeout=%.3fs)",
|
if not sock.poll(int(self.timeout * 1000), zmq.POLLIN): # type: ignore[attr-defined]
|
||||||
external_req_id,
|
raise TimeoutError(
|
||||||
req_meta.remote_host,
|
f"Timed out waiting for ACK from {req_meta.remote_host}:{req_meta.remote_port}"
|
||||||
req_meta.remote_port,
|
)
|
||||||
self.timeout,
|
ack = sock.recv()
|
||||||
)
|
if ack != b"ACK":
|
||||||
return
|
raise ValueError(f"Unexpected ACK response: {ack}")
|
||||||
if ack != b"ACK":
|
return
|
||||||
raise ValueError(f"Unexpected ACK response: {ack}")
|
except Exception as e:
|
||||||
|
if attempt < max_retries:
|
||||||
|
logger.warning(
|
||||||
|
"Failed to send done sending signal for request %s to %s:%d on attempt %d/%d: %s. "
|
||||||
|
"Retrying...",
|
||||||
|
external_req_id,
|
||||||
|
req_meta.remote_host,
|
||||||
|
req_meta.remote_port,
|
||||||
|
attempt,
|
||||||
|
max_retries,
|
||||||
|
e,
|
||||||
|
)
|
||||||
|
time.sleep(0.1)
|
||||||
|
else:
|
||||||
|
raise RuntimeError(f"Failed to receive ACK after {max_retries} attempts: {e}") from e
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(
|
logger.error(
|
||||||
f"Sending done sending signal for request {external_req_id} to "
|
f"Sending done sending signal for request {external_req_id} to "
|
||||||
|
|||||||
Reference in New Issue
Block a user