diff --git a/python/sglang/srt/managers/data_parallel_controller.py b/python/sglang/srt/managers/data_parallel_controller.py
index 450c62020..dca3d4f01 100644
--- a/python/sglang/srt/managers/data_parallel_controller.py
+++ b/python/sglang/srt/managers/data_parallel_controller.py
@@ -30,6 +30,7 @@ from sglang.srt.managers.scheduler import run_scheduler_process
 from sglang.srt.server_args import PortArgs, ServerArgs
 from sglang.srt.utils import (
     configure_logger,
+    get_zmq_socket,
     kill_parent_process,
     suppress_other_loggers,
 )
@@ -66,8 +67,9 @@ class DataParallelController:
 
         # Init inter-process communication
         self.context = zmq.Context(1 + server_args.dp_size)
-        self.recv_from_tokenizer = self.context.socket(zmq.PULL)
-        self.recv_from_tokenizer.bind(f"ipc://{port_args.scheduler_input_ipc_name}")
+        self.recv_from_tokenizer = get_zmq_socket(
+            self.context, zmq.PULL, port_args.scheduler_input_ipc_name
+        )
 
         # Dispatch method
         self.round_robin_counter = 0
@@ -120,8 +122,9 @@ class DataParallelController:
             scheduler_procs.append(proc)
             scheduler_pipe_readers.append(reader)
 
-        send_to = self.context.socket(zmq.PUSH)
-        send_to.connect(f"ipc://{port_args.scheduler_input_ipc_name}")
+        send_to = get_zmq_socket(
+            self.context, zmq.PUSH, port_args.scheduler_input_ipc_name
+        )
 
         # Wait for model to finish loading
         for i in range(len(scheduler_pipe_readers)):
diff --git a/python/sglang/srt/managers/detokenizer_manager.py b/python/sglang/srt/managers/detokenizer_manager.py
index caa5b611e..0387124df 100644
--- a/python/sglang/srt/managers/detokenizer_manager.py
+++ b/python/sglang/srt/managers/detokenizer_manager.py
@@ -32,7 +32,7 @@ from sglang.srt.managers.io_struct import (
 )
 from sglang.srt.managers.schedule_batch import FINISH_MATCHED_STR, FINISH_MATCHED_TOKEN
 from sglang.srt.server_args import PortArgs, ServerArgs
-from sglang.srt.utils import configure_logger, kill_parent_process
+from sglang.srt.utils import configure_logger, get_zmq_socket, kill_parent_process
 from sglang.utils import find_printable_text, get_exception_traceback
 
 logger = logging.getLogger(__name__)
@@ -59,11 +59,12 @@ class DetokenizerManager:
     ):
         # Init inter-process communication
         context = zmq.Context(2)
-        self.recv_from_scheduler = context.socket(zmq.PULL)
-        self.recv_from_scheduler.bind(f"ipc://{port_args.detokenizer_ipc_name}")
-
-        self.send_to_tokenizer = context.socket(zmq.PUSH)
-        self.send_to_tokenizer.connect(f"ipc://{port_args.tokenizer_ipc_name}")
+        self.recv_from_scheduler = get_zmq_socket(
+            context, zmq.PULL, port_args.detokenizer_ipc_name
+        )
+        self.send_to_tokenizer = get_zmq_socket(
+            context, zmq.PUSH, port_args.tokenizer_ipc_name
+        )
 
         if server_args.skip_tokenizer_init:
             self.tokenizer = None
diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py
index ce5ddd7c7..c3f679198 100644
--- a/python/sglang/srt/managers/scheduler.py
+++ b/python/sglang/srt/managers/scheduler.py
@@ -67,6 +67,7 @@ from sglang.srt.server_args import PortArgs, ServerArgs
 from sglang.srt.utils import (
     broadcast_pyobj,
     configure_logger,
+    get_zmq_socket,
     is_generation_model,
     is_multimodal_model,
     kill_parent_process,
@@ -110,20 +111,19 @@ class Scheduler:
         context = zmq.Context(2)
 
         if self.tp_rank == 0:
-            self.recv_from_tokenizer = context.socket(zmq.PULL)
-            self.recv_from_tokenizer.bind(f"ipc://{port_args.scheduler_input_ipc_name}")
+            self.recv_from_tokenizer = get_zmq_socket(
+                context, zmq.PULL, port_args.scheduler_input_ipc_name
+            )
 
             if server_args.skip_tokenizer_init:
                 # Directly send to the tokenizer/api
-                self.send_to_detokenizer = context.socket(zmq.PUSH)
-                self.send_to_detokenizer.connect(
-                    f"ipc://{port_args.tokenizer_ipc_name}"
+                self.send_to_detokenizer = get_zmq_socket(
+                    context, zmq.PUSH, port_args.tokenizer_ipc_name
                 )
             else:
                 # Send to the detokenizer
-                self.send_to_detokenizer = context.socket(zmq.PUSH)
-                self.send_to_detokenizer.connect(
-                    f"ipc://{port_args.detokenizer_ipc_name}"
+                self.send_to_detokenizer = get_zmq_socket(
+                    context, zmq.PUSH, port_args.detokenizer_ipc_name
                 )
         else:
             self.recv_from_tokenizer = None
diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py
index 585d5d8ce..347e7ad1d 100644
--- a/python/sglang/srt/managers/tokenizer_manager.py
+++ b/python/sglang/srt/managers/tokenizer_manager.py
@@ -58,7 +58,7 @@ from sglang.srt.managers.io_struct import (
 )
 from sglang.srt.sampling.sampling_params import SamplingParams
 from sglang.srt.server_args import PortArgs, ServerArgs
-from sglang.srt.utils import is_generation_model, is_multimodal_model
+from sglang.srt.utils import get_zmq_socket, is_generation_model, is_multimodal_model
 
 asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
 
@@ -86,11 +86,12 @@ class TokenizerManager:
 
         # Init inter-process communication
         context = zmq.asyncio.Context(2)
-        self.recv_from_detokenizer = context.socket(zmq.PULL)
-        self.recv_from_detokenizer.bind(f"ipc://{port_args.tokenizer_ipc_name}")
-
-        self.send_to_scheduler = context.socket(zmq.PUSH)
-        self.send_to_scheduler.connect(f"ipc://{port_args.scheduler_input_ipc_name}")
+        self.recv_from_detokenizer = get_zmq_socket(
+            context, zmq.PULL, port_args.tokenizer_ipc_name
+        )
+        self.send_to_scheduler = get_zmq_socket(
+            context, zmq.PUSH, port_args.scheduler_input_ipc_name
+        )
 
         # Read model args
         self.model_path = server_args.model_path
diff --git a/python/sglang/srt/utils.py b/python/sglang/srt/utils.py
index 69aea52ac..20bf6b264 100644
--- a/python/sglang/srt/utils.py
+++ b/python/sglang/srt/utils.py
@@ -35,6 +35,7 @@ import psutil
 import requests
 import torch
 import torch.distributed as dist
+import zmq
 from fastapi.responses import ORJSONResponse
 from packaging import version as pkg_version
 from torch import nn
@@ -720,3 +721,45 @@ def first_rank_print(*args, **kwargs):
         print(*args, **kwargs)
     else:
         pass
+
+
+def get_zmq_socket(context: zmq.Context, socket_type: zmq.SocketType, endpoint: str):
+    """Create a PUSH or PULL ZeroMQ socket attached to an IPC endpoint.
+
+    PUSH sockets connect to ``ipc://{endpoint}`` and PULL sockets bind to it.
+    The high-water mark is set to 0 (no limit) and the kernel buffer size to
+    100000000 bytes in the matching direction (send for PUSH, receive for PULL).
+
+    Args:
+        context: ZeroMQ context used to create the socket.
+        socket_type: Either ``zmq.PUSH`` or ``zmq.PULL``.
+        endpoint: IPC endpoint name, without the ``ipc://`` prefix.
+
+    Returns:
+        The configured socket, already connected (PUSH) or bound (PULL).
+
+    Raises:
+        ValueError: If ``socket_type`` is neither ``zmq.PUSH`` nor ``zmq.PULL``.
+    """
+    # Reject unsupported types before creating anything, so the error path
+    # never leaves an open socket behind on the context.
+    if socket_type not in (zmq.PUSH, zmq.PULL):
+        raise ValueError(f"Unsupported socket type: {socket_type}")
+
+    buf_size = 100000000  # bytes; used for SNDBUF / RCVBUF
+    sock = context.socket(socket_type)
+    try:
+        if socket_type == zmq.PUSH:
+            sock.setsockopt(zmq.SNDHWM, 0)
+            sock.setsockopt(zmq.SNDBUF, buf_size)
+            sock.connect(f"ipc://{endpoint}")
+        else:
+            sock.setsockopt(zmq.RCVHWM, 0)
+            sock.setsockopt(zmq.RCVBUF, buf_size)
+            sock.bind(f"ipc://{endpoint}")
+    except zmq.ZMQError:
+        # Do not leak the socket if configuration, connect or bind fails.
+        sock.close()
+        raise
+
+    return sock