[PD] Fix no cache connect for receiver (#5534)
This commit is contained in:
@@ -387,6 +387,10 @@ class MooncakeKVSender(BaseKVSender):
|
|||||||
|
|
||||||
|
|
||||||
class MooncakeKVReceiver(BaseKVReceiver):
|
class MooncakeKVReceiver(BaseKVReceiver):
|
||||||
|
_ctx = zmq.Context()
|
||||||
|
_socket_cache = {}
|
||||||
|
_socket_locks = {}
|
||||||
|
_global_lock = threading.Lock()
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
@@ -436,11 +440,15 @@ class MooncakeKVReceiver(BaseKVReceiver):
|
|||||||
logger.error(f"Error fetching prefill info from bootstrap: {e}")
|
logger.error(f"Error fetching prefill info from bootstrap: {e}")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
@cache
|
@classmethod
|
||||||
def _connect(self, endpoint: str):
|
def _connect(cls, endpoint: str):
|
||||||
socket = zmq.Context().socket(zmq.PUSH)
|
with cls._global_lock:
|
||||||
socket.connect(endpoint)
|
if endpoint not in cls._socket_cache:
|
||||||
return socket
|
sock = cls._ctx.socket(zmq.PUSH)
|
||||||
|
sock.connect(endpoint)
|
||||||
|
cls._socket_cache[endpoint] = sock
|
||||||
|
cls._socket_locks[endpoint] = threading.Lock()
|
||||||
|
return cls._socket_cache[endpoint], cls._socket_locks[endpoint]
|
||||||
|
|
||||||
def init(self, kv_indices: npt.NDArray[np.int64], aux_index: Optional[int] = None):
|
def init(self, kv_indices: npt.NDArray[np.int64], aux_index: Optional[int] = None):
|
||||||
self.prefill_server_url = (
|
self.prefill_server_url = (
|
||||||
@@ -456,7 +464,9 @@ class MooncakeKVReceiver(BaseKVReceiver):
|
|||||||
packed_aux_data_ptrs = b"".join(
|
packed_aux_data_ptrs = b"".join(
|
||||||
struct.pack("Q", ptr) for ptr in self.kv_mgr.kv_args.aux_data_ptrs
|
struct.pack("Q", ptr) for ptr in self.kv_mgr.kv_args.aux_data_ptrs
|
||||||
)
|
)
|
||||||
self._connect("tcp://" + self.prefill_server_url).send_multipart(
|
sock, lock = self._connect("tcp://" + self.prefill_server_url)
|
||||||
|
with lock:
|
||||||
|
sock.send_multipart(
|
||||||
[
|
[
|
||||||
str(self.bootstrap_room).encode("ascii"),
|
str(self.bootstrap_room).encode("ascii"),
|
||||||
get_local_ip_by_remote().encode("ascii"),
|
get_local_ip_by_remote().encode("ascii"),
|
||||||
|
|||||||
Reference in New Issue
Block a user