Set ZMQ buffer size heuristic (#1801)
This commit is contained in:
@@ -724,14 +724,22 @@ def first_rank_print(*args, **kwargs):
|
|||||||
|
|
||||||
|
|
||||||
def get_zmq_socket(context: zmq.Context, socket_type: zmq.SocketType, endpoint: str):
|
def get_zmq_socket(context: zmq.Context, socket_type: zmq.SocketType, endpoint: str):
|
||||||
|
mem = psutil.virtual_memory()
|
||||||
|
total_mem = mem.total / 1024**3
|
||||||
|
available_mem = mem.available / 1024**3
|
||||||
|
if total_mem > 32 and available_mem > 16:
|
||||||
|
buf_size = int(0.5 * 1024**3)
|
||||||
|
else:
|
||||||
|
buf_size = -1
|
||||||
|
|
||||||
socket = context.socket(socket_type)
|
socket = context.socket(socket_type)
|
||||||
if socket_type == zmq.PUSH:
|
if socket_type == zmq.PUSH:
|
||||||
socket.setsockopt(zmq.SNDHWM, 0)
|
socket.setsockopt(zmq.SNDHWM, 0)
|
||||||
socket.setsockopt(zmq.SNDBUF, 100000000)
|
socket.setsockopt(zmq.SNDBUF, buf_size)
|
||||||
socket.connect(f"ipc://{endpoint}")
|
socket.connect(f"ipc://{endpoint}")
|
||||||
elif socket_type == zmq.PULL:
|
elif socket_type == zmq.PULL:
|
||||||
socket.setsockopt(zmq.RCVHWM, 0)
|
socket.setsockopt(zmq.RCVHWM, 0)
|
||||||
socket.setsockopt(zmq.RCVBUF, 100000000)
|
socket.setsockopt(zmq.RCVBUF, buf_size)
|
||||||
socket.bind(f"ipc://{endpoint}")
|
socket.bind(f"ipc://{endpoint}")
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Unsupported socket type: {socket_type}")
|
raise ValueError(f"Unsupported socket type: {socket_type}")
|
||||||
|
|||||||
Reference in New Issue
Block a user