[BugFix][P/D][0.18.0]Add a retry mechanism to prevent packet loss (#8167)

### What this PR does / why we need it?
Add a retry mechanism to prevent packet loss

Signed-off-by: wangxiaoteng <wangxiaoteng@huawei.com>
This commit is contained in:
wangxiaoteng888
2026-04-12 23:30:45 +08:00
committed by GitHub
parent 4adc6a68f5
commit 39c071a0f5

View File

@@ -541,12 +541,12 @@ class KVCacheRecvingLayerThread(threading.Thread):
self.done_requests = set() self.done_requests = set()
return finished_requests return finished_requests
def update_task(self, req_id, trans_count): def update_task(self, req_id, trans_count, side_channel_path):
with self.lock: with self.lock:
if req_id not in self.task_tracker: if req_id not in self.task_tracker:
self.task_tracker[req_id] = 0 self.task_tracker[req_id] = set()
self.task_tracker[req_id] += 1 self.task_tracker[req_id].add(side_channel_path)
if self.task_tracker[req_id] == trans_count: if len(self.task_tracker[req_id]) == trans_count:
self.task_tracker.pop(req_id) self.task_tracker.pop(req_id)
self.done_requests.add(req_id) self.done_requests.add(req_id)
@@ -581,7 +581,8 @@ class KVCacheRecvingLayerThread(threading.Thread):
logger.debug("Got DONE_RECVING_MSG for request %s", msg[1]) logger.debug("Got DONE_RECVING_MSG for request %s", msg[1])
request_id = msg[1] request_id = msg[1]
trans_count = msg[2] trans_count = msg[2]
self.update_task(request_id, trans_count) side_channel_path = msg[3]
self.update_task(request_id, trans_count, side_channel_path)
sock.send_multipart((identity, b"", b"ACK")) sock.send_multipart((identity, b"", b"ACK"))
else: else:
logger.error("Connection listener got unexpected message %s", msg) logger.error("Connection listener got unexpected message %s", msg)
@@ -1760,24 +1761,39 @@ class MooncakeLayerwiseConnectorWorker:
try: try:
path = make_zmq_path("tcp", req_meta.remote_host, req_meta.remote_port) path = make_zmq_path("tcp", req_meta.remote_host, req_meta.remote_port)
msg_encoder = msgspec.msgpack.Encoder() msg_encoder = msgspec.msgpack.Encoder()
encoded_data = msg_encoder.encode((DONE_SENDING_MSG, external_req_id, req_meta.trans_count[group_idx])) side_channel_path = f"{self.side_channel_host}:{self.handshake_port}"
with zmq_ctx(zmq.REQ, path) as sock: # type: ignore encoded_data = msg_encoder.encode(
ensure_zmq_send(sock, encoded_data, f"{req_meta.remote_host}:{req_meta.remote_port}") (DONE_SENDING_MSG, external_req_id, req_meta.trans_count[group_idx], side_channel_path)
# Avoid blocking forever waiting for the REQ/ACK response. )
sock.setsockopt(zmq.RCVTIMEO, int(self.timeout * 1000)) # type: ignore max_retries = 3
for attempt in range(1, max_retries + 1):
try: try:
ack = sock.recv() with zmq_ctx(zmq.REQ, path) as sock: # type: ignore
except zmq.Again: # type: ignore sock.setsockopt(zmq.SNDTIMEO, int(self.timeout * 1000))
logger.warning( ensure_zmq_send(sock, encoded_data, f"{req_meta.remote_host}:{req_meta.remote_port}")
"Timeout waiting ACK for request %s from %s:%d (timeout=%.3fs)", if not sock.poll(int(self.timeout * 1000), zmq.POLLIN): # type: ignore[attr-defined]
external_req_id, raise TimeoutError(
req_meta.remote_host, f"Timed out waiting for ACK from {req_meta.remote_host}:{req_meta.remote_port}"
req_meta.remote_port, )
self.timeout, ack = sock.recv()
) if ack != b"ACK":
return raise ValueError(f"Unexpected ACK response: {ack}")
if ack != b"ACK": return
raise ValueError(f"Unexpected ACK response: {ack}") except Exception as e:
if attempt < max_retries:
logger.warning(
"Failed to send done sending signal for request %s to %s:%d on attempt %d/%d: %s. "
"Retrying...",
external_req_id,
req_meta.remote_host,
req_meta.remote_port,
attempt,
max_retries,
e,
)
time.sleep(0.1)
else:
raise RuntimeError(f"Failed to receive ACK after {max_retries} attempts: {e}") from e
except Exception as e: except Exception as e:
logger.error( logger.error(
f"Sending done sending signal for request {external_req_id} to " f"Sending done sending signal for request {external_req_id} to "