[BugFix][P/D][0.18.0]Add a retry mechanism to prevent packet loss (#8167)
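### What this PR does / why we need it?

Add a retry mechanism to prevent packet loss. The sender now attempts the done-sending notification up to 3 times, opening a fresh REQ socket per attempt and sleeping 0.1 s between attempts, and raises an error if no valid ACK is received after the last attempt. The notification also carries the sender's side channel path (`host:port`); the receiver records the set of distinct paths per request instead of a plain counter, so a retried notification from the same sender is not counted twice.

Signed-off-by: wangxiaoteng <wangxiaoteng@huawei.com>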
This commit is contained in:
@@ -541,12 +541,12 @@ class KVCacheRecvingLayerThread(threading.Thread):
|
||||
self.done_requests = set()
|
||||
return finished_requests
|
||||
|
||||
def update_task(self, req_id, trans_count):
|
||||
def update_task(self, req_id, trans_count, side_channel_path):
|
||||
with self.lock:
|
||||
if req_id not in self.task_tracker:
|
||||
self.task_tracker[req_id] = 0
|
||||
self.task_tracker[req_id] += 1
|
||||
if self.task_tracker[req_id] == trans_count:
|
||||
self.task_tracker[req_id] = set()
|
||||
self.task_tracker[req_id].add(side_channel_path)
|
||||
if len(self.task_tracker[req_id]) == trans_count:
|
||||
self.task_tracker.pop(req_id)
|
||||
self.done_requests.add(req_id)
|
||||
|
||||
@@ -581,7 +581,8 @@ class KVCacheRecvingLayerThread(threading.Thread):
|
||||
logger.debug("Got DONE_RECVING_MSG for request %s", msg[1])
|
||||
request_id = msg[1]
|
||||
trans_count = msg[2]
|
||||
self.update_task(request_id, trans_count)
|
||||
side_channel_path = msg[3]
|
||||
self.update_task(request_id, trans_count, side_channel_path)
|
||||
sock.send_multipart((identity, b"", b"ACK"))
|
||||
else:
|
||||
logger.error("Connection listener got unexpected message %s", msg)
|
||||
@@ -1760,24 +1761,39 @@ class MooncakeLayerwiseConnectorWorker:
|
||||
try:
|
||||
path = make_zmq_path("tcp", req_meta.remote_host, req_meta.remote_port)
|
||||
msg_encoder = msgspec.msgpack.Encoder()
|
||||
encoded_data = msg_encoder.encode((DONE_SENDING_MSG, external_req_id, req_meta.trans_count[group_idx]))
|
||||
with zmq_ctx(zmq.REQ, path) as sock: # type: ignore
|
||||
ensure_zmq_send(sock, encoded_data, f"{req_meta.remote_host}:{req_meta.remote_port}")
|
||||
# Avoid blocking forever waiting for the REQ/ACK response.
|
||||
sock.setsockopt(zmq.RCVTIMEO, int(self.timeout * 1000)) # type: ignore
|
||||
side_channel_path = f"{self.side_channel_host}:{self.handshake_port}"
|
||||
encoded_data = msg_encoder.encode(
|
||||
(DONE_SENDING_MSG, external_req_id, req_meta.trans_count[group_idx], side_channel_path)
|
||||
)
|
||||
max_retries = 3
|
||||
for attempt in range(1, max_retries + 1):
|
||||
try:
|
||||
ack = sock.recv()
|
||||
except zmq.Again: # type: ignore
|
||||
logger.warning(
|
||||
"Timeout waiting ACK for request %s from %s:%d (timeout=%.3fs)",
|
||||
external_req_id,
|
||||
req_meta.remote_host,
|
||||
req_meta.remote_port,
|
||||
self.timeout,
|
||||
)
|
||||
return
|
||||
if ack != b"ACK":
|
||||
raise ValueError(f"Unexpected ACK response: {ack}")
|
||||
with zmq_ctx(zmq.REQ, path) as sock: # type: ignore
|
||||
sock.setsockopt(zmq.SNDTIMEO, int(self.timeout * 1000))
|
||||
ensure_zmq_send(sock, encoded_data, f"{req_meta.remote_host}:{req_meta.remote_port}")
|
||||
if not sock.poll(int(self.timeout * 1000), zmq.POLLIN): # type: ignore[attr-defined]
|
||||
raise TimeoutError(
|
||||
f"Timed out waiting for ACK from {req_meta.remote_host}:{req_meta.remote_port}"
|
||||
)
|
||||
ack = sock.recv()
|
||||
if ack != b"ACK":
|
||||
raise ValueError(f"Unexpected ACK response: {ack}")
|
||||
return
|
||||
except Exception as e:
|
||||
if attempt < max_retries:
|
||||
logger.warning(
|
||||
"Failed to send done sending signal for request %s to %s:%d on attempt %d/%d: %s. "
|
||||
"Retrying...",
|
||||
external_req_id,
|
||||
req_meta.remote_host,
|
||||
req_meta.remote_port,
|
||||
attempt,
|
||||
max_retries,
|
||||
e,
|
||||
)
|
||||
time.sleep(0.1)
|
||||
else:
|
||||
raise RuntimeError(f"Failed to receive ACK after {max_retries} attempts: {e}") from e
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Sending done sending signal for request {external_req_id} to "
|
||||
|
||||
Reference in New Issue
Block a user