Fix possible ZMQ hanging (#1800)

This commit is contained in:
Liangsheng Yin
2024-10-25 23:07:07 -07:00
committed by GitHub
parent 715b16c140
commit 1e8903414a
5 changed files with 46 additions and 24 deletions

View File

@@ -30,6 +30,7 @@ from sglang.srt.managers.scheduler import run_scheduler_process
from sglang.srt.server_args import PortArgs, ServerArgs from sglang.srt.server_args import PortArgs, ServerArgs
from sglang.srt.utils import ( from sglang.srt.utils import (
configure_logger, configure_logger,
get_zmq_socket,
kill_parent_process, kill_parent_process,
suppress_other_loggers, suppress_other_loggers,
) )
@@ -66,8 +67,9 @@ class DataParallelController:
# Init inter-process communication # Init inter-process communication
self.context = zmq.Context(1 + server_args.dp_size) self.context = zmq.Context(1 + server_args.dp_size)
self.recv_from_tokenizer = self.context.socket(zmq.PULL) self.recv_from_tokenizer = get_zmq_socket(
self.recv_from_tokenizer.bind(f"ipc://{port_args.scheduler_input_ipc_name}") self.context, zmq.PULL, port_args.scheduler_input_ipc_name
)
# Dispatch method # Dispatch method
self.round_robin_counter = 0 self.round_robin_counter = 0
@@ -120,8 +122,9 @@ class DataParallelController:
scheduler_procs.append(proc) scheduler_procs.append(proc)
scheduler_pipe_readers.append(reader) scheduler_pipe_readers.append(reader)
send_to = self.context.socket(zmq.PUSH) send_to = get_zmq_socket(
send_to.connect(f"ipc://{port_args.scheduler_input_ipc_name}") self.context, zmq.PUSH, port_args.scheduler_input_ipc_name
)
# Wait for model to finish loading # Wait for model to finish loading
for i in range(len(scheduler_pipe_readers)): for i in range(len(scheduler_pipe_readers)):

View File

@@ -32,7 +32,7 @@ from sglang.srt.managers.io_struct import (
) )
from sglang.srt.managers.schedule_batch import FINISH_MATCHED_STR, FINISH_MATCHED_TOKEN from sglang.srt.managers.schedule_batch import FINISH_MATCHED_STR, FINISH_MATCHED_TOKEN
from sglang.srt.server_args import PortArgs, ServerArgs from sglang.srt.server_args import PortArgs, ServerArgs
from sglang.srt.utils import configure_logger, kill_parent_process from sglang.srt.utils import configure_logger, get_zmq_socket, kill_parent_process
from sglang.utils import find_printable_text, get_exception_traceback from sglang.utils import find_printable_text, get_exception_traceback
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -59,11 +59,12 @@ class DetokenizerManager:
): ):
# Init inter-process communication # Init inter-process communication
context = zmq.Context(2) context = zmq.Context(2)
self.recv_from_scheduler = context.socket(zmq.PULL) self.recv_from_scheduler = get_zmq_socket(
self.recv_from_scheduler.bind(f"ipc://{port_args.detokenizer_ipc_name}") context, zmq.PULL, port_args.detokenizer_ipc_name
)
self.send_to_tokenizer = context.socket(zmq.PUSH) self.send_to_tokenizer = get_zmq_socket(
self.send_to_tokenizer.connect(f"ipc://{port_args.tokenizer_ipc_name}") context, zmq.PUSH, port_args.tokenizer_ipc_name
)
if server_args.skip_tokenizer_init: if server_args.skip_tokenizer_init:
self.tokenizer = None self.tokenizer = None

View File

@@ -67,6 +67,7 @@ from sglang.srt.server_args import PortArgs, ServerArgs
from sglang.srt.utils import ( from sglang.srt.utils import (
broadcast_pyobj, broadcast_pyobj,
configure_logger, configure_logger,
get_zmq_socket,
is_generation_model, is_generation_model,
is_multimodal_model, is_multimodal_model,
kill_parent_process, kill_parent_process,
@@ -110,20 +111,19 @@ class Scheduler:
context = zmq.Context(2) context = zmq.Context(2)
if self.tp_rank == 0: if self.tp_rank == 0:
self.recv_from_tokenizer = context.socket(zmq.PULL) self.recv_from_tokenizer = get_zmq_socket(
self.recv_from_tokenizer.bind(f"ipc://{port_args.scheduler_input_ipc_name}") context, zmq.PULL, port_args.scheduler_input_ipc_name
)
if server_args.skip_tokenizer_init: if server_args.skip_tokenizer_init:
# Directly send to the tokenizer/api # Directly send to the tokenizer/api
self.send_to_detokenizer = context.socket(zmq.PUSH) self.send_to_detokenizer = get_zmq_socket(
self.send_to_detokenizer.connect( context, zmq.PUSH, port_args.tokenizer_ipc_name
f"ipc://{port_args.tokenizer_ipc_name}"
) )
else: else:
# Send to the detokenizer # Send to the detokenizer
self.send_to_detokenizer = context.socket(zmq.PUSH) self.send_to_detokenizer = get_zmq_socket(
self.send_to_detokenizer.connect( context, zmq.PUSH, port_args.detokenizer_ipc_name
f"ipc://{port_args.detokenizer_ipc_name}"
) )
else: else:
self.recv_from_tokenizer = None self.recv_from_tokenizer = None

View File

@@ -58,7 +58,7 @@ from sglang.srt.managers.io_struct import (
) )
from sglang.srt.sampling.sampling_params import SamplingParams from sglang.srt.sampling.sampling_params import SamplingParams
from sglang.srt.server_args import PortArgs, ServerArgs from sglang.srt.server_args import PortArgs, ServerArgs
from sglang.srt.utils import is_generation_model, is_multimodal_model from sglang.srt.utils import get_zmq_socket, is_generation_model, is_multimodal_model
asyncio.set_event_loop_policy(uvloop.EventLoopPolicy()) asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
@@ -86,11 +86,12 @@ class TokenizerManager:
# Init inter-process communication # Init inter-process communication
context = zmq.asyncio.Context(2) context = zmq.asyncio.Context(2)
self.recv_from_detokenizer = context.socket(zmq.PULL) self.recv_from_detokenizer = get_zmq_socket(
self.recv_from_detokenizer.bind(f"ipc://{port_args.tokenizer_ipc_name}") context, zmq.PULL, port_args.tokenizer_ipc_name
)
self.send_to_scheduler = context.socket(zmq.PUSH) self.send_to_scheduler = get_zmq_socket(
self.send_to_scheduler.connect(f"ipc://{port_args.scheduler_input_ipc_name}") context, zmq.PUSH, port_args.scheduler_input_ipc_name
)
# Read model args # Read model args
self.model_path = server_args.model_path self.model_path = server_args.model_path

View File

@@ -35,6 +35,7 @@ import psutil
import requests import requests
import torch import torch
import torch.distributed as dist import torch.distributed as dist
import zmq
from fastapi.responses import ORJSONResponse from fastapi.responses import ORJSONResponse
from packaging import version as pkg_version from packaging import version as pkg_version
from torch import nn from torch import nn
@@ -720,3 +721,19 @@ def first_rank_print(*args, **kwargs):
print(*args, **kwargs) print(*args, **kwargs)
else: else:
pass pass
def get_zmq_socket(context: zmq.Context, socket_type: zmq.SocketType, endpoint: str):
socket = context.socket(socket_type)
if socket_type == zmq.PUSH:
socket.setsockopt(zmq.SNDHWM, 0)
socket.setsockopt(zmq.SNDBUF, 100000000)
socket.connect(f"ipc://{endpoint}")
elif socket_type == zmq.PULL:
socket.setsockopt(zmq.RCVHWM, 0)
socket.setsockopt(zmq.RCVBUF, 100000000)
socket.bind(f"ipc://{endpoint}")
else:
raise ValueError(f"Unsupported socket type: {socket_type}")
return socket