From 59dd090f1c25461ebe371dc7debad59a425795e7 Mon Sep 17 00:00:00 2001 From: ybyang <10629930+whybeyoung@users.noreply.github.com> Date: Sat, 19 Apr 2025 14:55:28 +0800 Subject: [PATCH] [PD] Fix no cache connect for recevier (#5534) --- .../srt/disaggregation/mooncake/conn.py | 44 ++++++++++++------- 1 file changed, 27 insertions(+), 17 deletions(-) diff --git a/python/sglang/srt/disaggregation/mooncake/conn.py b/python/sglang/srt/disaggregation/mooncake/conn.py index c97222454..ef9e127c0 100644 --- a/python/sglang/srt/disaggregation/mooncake/conn.py +++ b/python/sglang/srt/disaggregation/mooncake/conn.py @@ -387,6 +387,10 @@ class MooncakeKVSender(BaseKVSender): class MooncakeKVReceiver(BaseKVReceiver): + _ctx = zmq.Context() + _socket_cache = {} + _socket_locks = {} + _global_lock = threading.Lock() def __init__( self, @@ -436,11 +440,15 @@ class MooncakeKVReceiver(BaseKVReceiver): logger.error(f"Error fetching prefill info from bootstrap: {e}") return None - @cache - def _connect(self, endpoint: str): - socket = zmq.Context().socket(zmq.PUSH) - socket.connect(endpoint) - return socket + @classmethod + def _connect(cls, endpoint: str): + with cls._global_lock: + if endpoint not in cls._socket_cache: + sock = cls._ctx.socket(zmq.PUSH) + sock.connect(endpoint) + cls._socket_cache[endpoint] = sock + cls._socket_locks[endpoint] = threading.Lock() + return cls._socket_cache[endpoint], cls._socket_locks[endpoint] def init(self, kv_indices: npt.NDArray[np.int64], aux_index: Optional[int] = None): self.prefill_server_url = ( @@ -456,18 +464,20 @@ class MooncakeKVReceiver(BaseKVReceiver): packed_aux_data_ptrs = b"".join( struct.pack("Q", ptr) for ptr in self.kv_mgr.kv_args.aux_data_ptrs ) - self._connect("tcp://" + self.prefill_server_url).send_multipart( - [ - str(self.bootstrap_room).encode("ascii"), - get_local_ip_by_remote().encode("ascii"), - str(self.kv_mgr.rank_port).encode("ascii"), - self.session_id.encode("ascii"), - packed_kv_data_ptrs, - kv_indices.tobytes(), - packed_aux_data_ptrs, - str(aux_index).encode("ascii"), - ] - ) + sock, lock = self._connect("tcp://" + self.prefill_server_url) + with lock: + sock.send_multipart( + [ + str(self.bootstrap_room).encode("ascii"), + get_local_ip_by_remote().encode("ascii"), + str(self.kv_mgr.rank_port).encode("ascii"), + self.session_id.encode("ascii"), + packed_kv_data_ptrs, + kv_indices.tobytes(), + packed_aux_data_ptrs, + str(aux_index).encode("ascii"), + ] + ) def poll(self) -> KVPoll: return self.kv_mgr.check_status(self.bootstrap_room)