[PD] Fix no cache connect for receiver (#5534)
This commit is contained in:
@@ -387,6 +387,10 @@ class MooncakeKVSender(BaseKVSender):
|
||||
|
||||
|
||||
class MooncakeKVReceiver(BaseKVReceiver):
|
||||
_ctx = zmq.Context()
|
||||
_socket_cache = {}
|
||||
_socket_locks = {}
|
||||
_global_lock = threading.Lock()
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -436,11 +440,15 @@ class MooncakeKVReceiver(BaseKVReceiver):
|
||||
logger.error(f"Error fetching prefill info from bootstrap: {e}")
|
||||
return None
|
||||
|
||||
@cache
|
||||
def _connect(self, endpoint: str):
|
||||
socket = zmq.Context().socket(zmq.PUSH)
|
||||
socket.connect(endpoint)
|
||||
return socket
|
||||
@classmethod
|
||||
def _connect(cls, endpoint: str):
|
||||
with cls._global_lock:
|
||||
if endpoint not in cls._socket_cache:
|
||||
sock = cls._ctx.socket(zmq.PUSH)
|
||||
sock.connect(endpoint)
|
||||
cls._socket_cache[endpoint] = sock
|
||||
cls._socket_locks[endpoint] = threading.Lock()
|
||||
return cls._socket_cache[endpoint], cls._socket_locks[endpoint]
|
||||
|
||||
def init(self, kv_indices: npt.NDArray[np.int64], aux_index: Optional[int] = None):
|
||||
self.prefill_server_url = (
|
||||
@@ -456,18 +464,20 @@ class MooncakeKVReceiver(BaseKVReceiver):
|
||||
packed_aux_data_ptrs = b"".join(
|
||||
struct.pack("Q", ptr) for ptr in self.kv_mgr.kv_args.aux_data_ptrs
|
||||
)
|
||||
self._connect("tcp://" + self.prefill_server_url).send_multipart(
|
||||
[
|
||||
str(self.bootstrap_room).encode("ascii"),
|
||||
get_local_ip_by_remote().encode("ascii"),
|
||||
str(self.kv_mgr.rank_port).encode("ascii"),
|
||||
self.session_id.encode("ascii"),
|
||||
packed_kv_data_ptrs,
|
||||
kv_indices.tobytes(),
|
||||
packed_aux_data_ptrs,
|
||||
str(aux_index).encode("ascii"),
|
||||
]
|
||||
)
|
||||
sock, lock = self._connect("tcp://" + self.prefill_server_url)
|
||||
with lock:
|
||||
sock.send_multipart(
|
||||
[
|
||||
str(self.bootstrap_room).encode("ascii"),
|
||||
get_local_ip_by_remote().encode("ascii"),
|
||||
str(self.kv_mgr.rank_port).encode("ascii"),
|
||||
self.session_id.encode("ascii"),
|
||||
packed_kv_data_ptrs,
|
||||
kv_indices.tobytes(),
|
||||
packed_aux_data_ptrs,
|
||||
str(aux_index).encode("ascii"),
|
||||
]
|
||||
)
|
||||
|
||||
def poll(self) -> KVPoll:
|
||||
return self.kv_mgr.check_status(self.bootstrap_room)
|
||||
|
||||
Reference in New Issue
Block a user