Fix possible ZMQ hanging (#1800)
This commit is contained in:
@@ -30,6 +30,7 @@ from sglang.srt.managers.scheduler import run_scheduler_process
|
|||||||
from sglang.srt.server_args import PortArgs, ServerArgs
|
from sglang.srt.server_args import PortArgs, ServerArgs
|
||||||
from sglang.srt.utils import (
|
from sglang.srt.utils import (
|
||||||
configure_logger,
|
configure_logger,
|
||||||
|
get_zmq_socket,
|
||||||
kill_parent_process,
|
kill_parent_process,
|
||||||
suppress_other_loggers,
|
suppress_other_loggers,
|
||||||
)
|
)
|
||||||
@@ -66,8 +67,9 @@ class DataParallelController:
|
|||||||
|
|
||||||
# Init inter-process communication
|
# Init inter-process communication
|
||||||
self.context = zmq.Context(1 + server_args.dp_size)
|
self.context = zmq.Context(1 + server_args.dp_size)
|
||||||
self.recv_from_tokenizer = self.context.socket(zmq.PULL)
|
self.recv_from_tokenizer = get_zmq_socket(
|
||||||
self.recv_from_tokenizer.bind(f"ipc://{port_args.scheduler_input_ipc_name}")
|
self.context, zmq.PULL, port_args.scheduler_input_ipc_name
|
||||||
|
)
|
||||||
|
|
||||||
# Dispatch method
|
# Dispatch method
|
||||||
self.round_robin_counter = 0
|
self.round_robin_counter = 0
|
||||||
@@ -120,8 +122,9 @@ class DataParallelController:
|
|||||||
scheduler_procs.append(proc)
|
scheduler_procs.append(proc)
|
||||||
scheduler_pipe_readers.append(reader)
|
scheduler_pipe_readers.append(reader)
|
||||||
|
|
||||||
send_to = self.context.socket(zmq.PUSH)
|
send_to = get_zmq_socket(
|
||||||
send_to.connect(f"ipc://{port_args.scheduler_input_ipc_name}")
|
self.context, zmq.PUSH, port_args.scheduler_input_ipc_name
|
||||||
|
)
|
||||||
|
|
||||||
# Wait for model to finish loading
|
# Wait for model to finish loading
|
||||||
for i in range(len(scheduler_pipe_readers)):
|
for i in range(len(scheduler_pipe_readers)):
|
||||||
|
|||||||
@@ -32,7 +32,7 @@ from sglang.srt.managers.io_struct import (
|
|||||||
)
|
)
|
||||||
from sglang.srt.managers.schedule_batch import FINISH_MATCHED_STR, FINISH_MATCHED_TOKEN
|
from sglang.srt.managers.schedule_batch import FINISH_MATCHED_STR, FINISH_MATCHED_TOKEN
|
||||||
from sglang.srt.server_args import PortArgs, ServerArgs
|
from sglang.srt.server_args import PortArgs, ServerArgs
|
||||||
from sglang.srt.utils import configure_logger, kill_parent_process
|
from sglang.srt.utils import configure_logger, get_zmq_socket, kill_parent_process
|
||||||
from sglang.utils import find_printable_text, get_exception_traceback
|
from sglang.utils import find_printable_text, get_exception_traceback
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@@ -59,11 +59,12 @@ class DetokenizerManager:
|
|||||||
):
|
):
|
||||||
# Init inter-process communication
|
# Init inter-process communication
|
||||||
context = zmq.Context(2)
|
context = zmq.Context(2)
|
||||||
self.recv_from_scheduler = context.socket(zmq.PULL)
|
self.recv_from_scheduler = get_zmq_socket(
|
||||||
self.recv_from_scheduler.bind(f"ipc://{port_args.detokenizer_ipc_name}")
|
context, zmq.PULL, port_args.detokenizer_ipc_name
|
||||||
|
)
|
||||||
self.send_to_tokenizer = context.socket(zmq.PUSH)
|
self.send_to_tokenizer = get_zmq_socket(
|
||||||
self.send_to_tokenizer.connect(f"ipc://{port_args.tokenizer_ipc_name}")
|
context, zmq.PUSH, port_args.tokenizer_ipc_name
|
||||||
|
)
|
||||||
|
|
||||||
if server_args.skip_tokenizer_init:
|
if server_args.skip_tokenizer_init:
|
||||||
self.tokenizer = None
|
self.tokenizer = None
|
||||||
|
|||||||
@@ -67,6 +67,7 @@ from sglang.srt.server_args import PortArgs, ServerArgs
|
|||||||
from sglang.srt.utils import (
|
from sglang.srt.utils import (
|
||||||
broadcast_pyobj,
|
broadcast_pyobj,
|
||||||
configure_logger,
|
configure_logger,
|
||||||
|
get_zmq_socket,
|
||||||
is_generation_model,
|
is_generation_model,
|
||||||
is_multimodal_model,
|
is_multimodal_model,
|
||||||
kill_parent_process,
|
kill_parent_process,
|
||||||
@@ -110,20 +111,19 @@ class Scheduler:
|
|||||||
context = zmq.Context(2)
|
context = zmq.Context(2)
|
||||||
|
|
||||||
if self.tp_rank == 0:
|
if self.tp_rank == 0:
|
||||||
self.recv_from_tokenizer = context.socket(zmq.PULL)
|
self.recv_from_tokenizer = get_zmq_socket(
|
||||||
self.recv_from_tokenizer.bind(f"ipc://{port_args.scheduler_input_ipc_name}")
|
context, zmq.PULL, port_args.scheduler_input_ipc_name
|
||||||
|
)
|
||||||
|
|
||||||
if server_args.skip_tokenizer_init:
|
if server_args.skip_tokenizer_init:
|
||||||
# Directly send to the tokenizer/api
|
# Directly send to the tokenizer/api
|
||||||
self.send_to_detokenizer = context.socket(zmq.PUSH)
|
self.send_to_detokenizer = get_zmq_socket(
|
||||||
self.send_to_detokenizer.connect(
|
context, zmq.PUSH, port_args.tokenizer_ipc_name
|
||||||
f"ipc://{port_args.tokenizer_ipc_name}"
|
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
# Send to the detokenizer
|
# Send to the detokenizer
|
||||||
self.send_to_detokenizer = context.socket(zmq.PUSH)
|
self.send_to_detokenizer = get_zmq_socket(
|
||||||
self.send_to_detokenizer.connect(
|
context, zmq.PUSH, port_args.detokenizer_ipc_name
|
||||||
f"ipc://{port_args.detokenizer_ipc_name}"
|
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
self.recv_from_tokenizer = None
|
self.recv_from_tokenizer = None
|
||||||
|
|||||||
@@ -58,7 +58,7 @@ from sglang.srt.managers.io_struct import (
|
|||||||
)
|
)
|
||||||
from sglang.srt.sampling.sampling_params import SamplingParams
|
from sglang.srt.sampling.sampling_params import SamplingParams
|
||||||
from sglang.srt.server_args import PortArgs, ServerArgs
|
from sglang.srt.server_args import PortArgs, ServerArgs
|
||||||
from sglang.srt.utils import is_generation_model, is_multimodal_model
|
from sglang.srt.utils import get_zmq_socket, is_generation_model, is_multimodal_model
|
||||||
|
|
||||||
asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
|
asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
|
||||||
|
|
||||||
@@ -86,11 +86,12 @@ class TokenizerManager:
|
|||||||
|
|
||||||
# Init inter-process communication
|
# Init inter-process communication
|
||||||
context = zmq.asyncio.Context(2)
|
context = zmq.asyncio.Context(2)
|
||||||
self.recv_from_detokenizer = context.socket(zmq.PULL)
|
self.recv_from_detokenizer = get_zmq_socket(
|
||||||
self.recv_from_detokenizer.bind(f"ipc://{port_args.tokenizer_ipc_name}")
|
context, zmq.PULL, port_args.tokenizer_ipc_name
|
||||||
|
)
|
||||||
self.send_to_scheduler = context.socket(zmq.PUSH)
|
self.send_to_scheduler = get_zmq_socket(
|
||||||
self.send_to_scheduler.connect(f"ipc://{port_args.scheduler_input_ipc_name}")
|
context, zmq.PUSH, port_args.scheduler_input_ipc_name
|
||||||
|
)
|
||||||
|
|
||||||
# Read model args
|
# Read model args
|
||||||
self.model_path = server_args.model_path
|
self.model_path = server_args.model_path
|
||||||
|
|||||||
@@ -35,6 +35,7 @@ import psutil
|
|||||||
import requests
|
import requests
|
||||||
import torch
|
import torch
|
||||||
import torch.distributed as dist
|
import torch.distributed as dist
|
||||||
|
import zmq
|
||||||
from fastapi.responses import ORJSONResponse
|
from fastapi.responses import ORJSONResponse
|
||||||
from packaging import version as pkg_version
|
from packaging import version as pkg_version
|
||||||
from torch import nn
|
from torch import nn
|
||||||
@@ -720,3 +721,19 @@ def first_rank_print(*args, **kwargs):
|
|||||||
print(*args, **kwargs)
|
print(*args, **kwargs)
|
||||||
else:
|
else:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
def get_zmq_socket(context: zmq.Context, socket_type: zmq.SocketType, endpoint: str):
|
||||||
|
socket = context.socket(socket_type)
|
||||||
|
if socket_type == zmq.PUSH:
|
||||||
|
socket.setsockopt(zmq.SNDHWM, 0)
|
||||||
|
socket.setsockopt(zmq.SNDBUF, 100000000)
|
||||||
|
socket.connect(f"ipc://{endpoint}")
|
||||||
|
elif socket_type == zmq.PULL:
|
||||||
|
socket.setsockopt(zmq.RCVHWM, 0)
|
||||||
|
socket.setsockopt(zmq.RCVBUF, 100000000)
|
||||||
|
socket.bind(f"ipc://{endpoint}")
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unsupported socket type: {socket_type}")
|
||||||
|
|
||||||
|
return socket
|
||||||
|
|||||||
Reference in New Issue
Block a user