From 0427416b59d11958a63f2ed344af3c5141d8e835 Mon Sep 17 00:00:00 2001 From: Lianmin Zheng Date: Thu, 16 Jan 2025 14:36:07 -0800 Subject: [PATCH] Fix zmq binding (#2930) Co-authored-by: Chunyuan WU --- .../sglang/srt/managers/data_parallel_controller.py | 3 ++- python/sglang/srt/managers/detokenizer_manager.py | 4 ++-- python/sglang/srt/managers/scheduler.py | 8 ++++---- python/sglang/srt/managers/tokenizer_manager.py | 4 ++-- python/sglang/srt/utils.py | 11 ++++++++--- 5 files changed, 18 insertions(+), 12 deletions(-) diff --git a/python/sglang/srt/managers/data_parallel_controller.py b/python/sglang/srt/managers/data_parallel_controller.py index c4ebbb3cf..4f57ac5b2 100644 --- a/python/sglang/srt/managers/data_parallel_controller.py +++ b/python/sglang/srt/managers/data_parallel_controller.py @@ -66,7 +66,7 @@ class DataParallelController: self.context = zmq.Context(1 + server_args.dp_size) if server_args.node_rank == 0: self.recv_from_tokenizer = get_zmq_socket( - self.context, zmq.PULL, port_args.scheduler_input_ipc_name + self.context, zmq.PULL, port_args.scheduler_input_ipc_name, False ) # Dispatch method @@ -93,6 +93,7 @@ class DataParallelController: self.context, zmq.PUSH, dp_port_args[dp_rank].scheduler_input_ipc_name, + True, ) def launch_dp_schedulers(self, server_args, port_args): diff --git a/python/sglang/srt/managers/detokenizer_manager.py b/python/sglang/srt/managers/detokenizer_manager.py index 7a0f7b0d5..f0605ee1f 100644 --- a/python/sglang/srt/managers/detokenizer_manager.py +++ b/python/sglang/srt/managers/detokenizer_manager.py @@ -58,10 +58,10 @@ class DetokenizerManager: # Init inter-process communication context = zmq.Context(2) self.recv_from_scheduler = get_zmq_socket( - context, zmq.PULL, port_args.detokenizer_ipc_name + context, zmq.PULL, port_args.detokenizer_ipc_name, True ) self.send_to_tokenizer = get_zmq_socket( - context, zmq.PUSH, port_args.tokenizer_ipc_name + context, zmq.PUSH, port_args.tokenizer_ipc_name, False ) if server_args.skip_tokenizer_init: diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index 9bebbcd92..e00bd980f 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -162,21 +162,21 @@ class Scheduler: if self.attn_tp_rank == 0: self.recv_from_tokenizer = get_zmq_socket( - context, zmq.PULL, port_args.scheduler_input_ipc_name + context, zmq.PULL, port_args.scheduler_input_ipc_name, False ) self.send_to_tokenizer = get_zmq_socket( - context, zmq.PUSH, port_args.tokenizer_ipc_name + context, zmq.PUSH, port_args.tokenizer_ipc_name, False ) if server_args.skip_tokenizer_init: # Directly send to the TokenizerManager self.send_to_detokenizer = get_zmq_socket( - context, zmq.PUSH, port_args.tokenizer_ipc_name + context, zmq.PUSH, port_args.tokenizer_ipc_name, False ) else: # Send to the DetokenizerManager self.send_to_detokenizer = get_zmq_socket( - context, zmq.PUSH, port_args.detokenizer_ipc_name + context, zmq.PUSH, port_args.detokenizer_ipc_name, False ) else: self.recv_from_tokenizer = None diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py index 4e120f3a9..230d4f8d0 100644 --- a/python/sglang/srt/managers/tokenizer_manager.py +++ b/python/sglang/srt/managers/tokenizer_manager.py @@ -119,10 +119,10 @@ class TokenizerManager: # Init inter-process communication context = zmq.asyncio.Context(2) self.recv_from_detokenizer = get_zmq_socket( - context, zmq.PULL, port_args.tokenizer_ipc_name + context, zmq.PULL, port_args.tokenizer_ipc_name, True ) self.send_to_scheduler = get_zmq_socket( - context, zmq.PUSH, port_args.scheduler_input_ipc_name + context, zmq.PUSH, port_args.scheduler_input_ipc_name, True ) # Read model args diff --git a/python/sglang/srt/utils.py b/python/sglang/srt/utils.py index f1603ec0e..0813fa248 100644 --- a/python/sglang/srt/utils.py +++ b/python/sglang/srt/utils.py @@ -789,7 +789,9 @@ def first_rank_print(*args, **kwargs): pass -def get_zmq_socket(context: zmq.Context, socket_type: zmq.SocketType, endpoint: str): +def get_zmq_socket( + context: zmq.Context, socket_type: zmq.SocketType, endpoint: str, bind: bool +): mem = psutil.virtual_memory() total_mem = mem.total / 1024**3 available_mem = mem.available / 1024**3 @@ -802,14 +804,17 @@ def get_zmq_socket(context: zmq.Context, socket_type: zmq.SocketType, endpoint: if socket_type == zmq.PUSH: socket.setsockopt(zmq.SNDHWM, 0) socket.setsockopt(zmq.SNDBUF, buf_size) - socket.connect(endpoint) elif socket_type == zmq.PULL: socket.setsockopt(zmq.RCVHWM, 0) socket.setsockopt(zmq.RCVBUF, buf_size) - socket.bind(endpoint) else: raise ValueError(f"Unsupported socket type: {socket_type}") + if bind: + socket.bind(endpoint) + else: + socket.connect(endpoint) + return socket