def get_zmq_socket(context: zmq.Context, socket_type: zmq.SocketType, endpoint: str):
    """Create a PUSH or PULL ZeroMQ socket attached to an IPC endpoint.

    PUSH sockets connect to ``ipc://{endpoint}``; PULL sockets bind to it.
    Both get a high-water mark of 0 and a kernel buffer size (``SNDBUF`` /
    ``RCVBUF``) chosen from the host's memory: 0.5 GiB when the machine has
    more than 32 GiB in total and more than 16 GiB available, otherwise -1
    (libzmq's "leave the OS default unchanged" value).

    Args:
        context: ZeroMQ context used to create the socket.
        socket_type: Either ``zmq.PUSH`` or ``zmq.PULL``.
        endpoint: Path part of the IPC address, without the ``ipc://`` prefix.

    Returns:
        The configured socket.

    Raises:
        ValueError: If ``socket_type`` is neither ``zmq.PUSH`` nor ``zmq.PULL``.
    """
    # Validate first: rejecting a bad type before anything is allocated means
    # no socket is left open (and no memory probe is wasted) on the error path.
    if socket_type not in (zmq.PUSH, zmq.PULL):
        raise ValueError(f"Unsupported socket type: {socket_type}")

    gib = 1024**3
    min_total_bytes = 32 * gib
    min_available_bytes = 16 * gib

    mem = psutil.virtual_memory()
    if mem.total > min_total_bytes and mem.available > min_available_bytes:
        # Plenty of RAM: ask for a large 0.5 GiB kernel buffer.
        buf_size = gib // 2
    else:
        # Memory is tight: keep whatever buffer size the OS picks.
        buf_size = -1

    socket = context.socket(socket_type)
    try:
        if socket_type == zmq.PUSH:
            socket.setsockopt(zmq.SNDHWM, 0)
            socket.setsockopt(zmq.SNDBUF, buf_size)
            socket.connect(f"ipc://{endpoint}")
        else:
            socket.setsockopt(zmq.RCVHWM, 0)
            socket.setsockopt(zmq.RCVBUF, buf_size)
            socket.bind(f"ipc://{endpoint}")
    except Exception:
        # Do not leak the half-configured socket if setup fails
        # (e.g. the IPC path is already bound).
        socket.close()
        raise

    # NOTE(review): the excerpt this function was taken from ends before any
    # return statement; handing the socket back is assumed from the
    # function's name -- confirm against the full file.
    return socket