From 93d690617e2d0dce582a68f8037f98b7e168d72f Mon Sep 17 00:00:00 2001 From: Lianmin Zheng Date: Thu, 16 Jan 2025 07:52:17 -0800 Subject: [PATCH] Simplify the process launch code in server.py (#2923) --- python/sglang/srt/server.py | 50 +++++++++++++++++++++---------------- python/sglang/srt/utils.py | 12 +++++++++ 2 files changed, 40 insertions(+), 22 deletions(-) diff --git a/python/sglang/srt/server.py b/python/sglang/srt/server.py index 6b180039e..af0f2a08d 100644 --- a/python/sglang/srt/server.py +++ b/python/sglang/srt/server.py @@ -44,7 +44,6 @@ import uvloop from fastapi import FastAPI, File, Form, Request, UploadFile from fastapi.middleware.cors import CORSMiddleware from fastapi.responses import ORJSONResponse, Response, StreamingResponse -from uvicorn.config import LOGGING_CONFIG from sglang.lang.backend.runtime_endpoint import RuntimeEndpoint from sglang.srt.hf_transformers_utils import get_tokenizer @@ -97,6 +96,7 @@ from sglang.srt.utils import ( prepare_model_and_tokenizer, set_prometheus_multiproc_dir, set_ulimit, + set_uvicorn_logging_configs, ) from sglang.utils import get_exception_traceback from sglang.version import __version__ @@ -474,13 +474,13 @@ def launch_engine( server_args.model_path, server_args.tokenizer_path ) - memory_saver_adapter = TorchMemorySaverAdapter.create( - enable=server_args.enable_memory_saver - ) - + scheduler_procs = [] if server_args.dp_size == 1: # Launch tensor parallel scheduler processes - scheduler_procs = [] + memory_saver_adapter = TorchMemorySaverAdapter.create( + enable=server_args.enable_memory_saver + ) + scheduler_pipe_readers = [] tp_size_per_node = server_args.tp_size // server_args.nnodes tp_rank_range = range( @@ -498,12 +498,6 @@ def launch_engine( proc.start() scheduler_procs.append(proc) scheduler_pipe_readers.append(reader) - - if server_args.node_rank >= 1: - # For other nodes, they do not need to run tokenizer or detokenizer, - # so they can just wait here. 
-        for proc in scheduler_procs:
-            proc.join()
     else:
         # Launch the data parallel controller
         reader, writer = mp.Pipe(duplex=False)
@@ -512,8 +506,27 @@ def launch_engine(
             target=run_data_parallel_controller_process,
             args=(server_args, port_args, writer),
         )
-        with memory_saver_adapter.configure_subprocess():
-            proc.start()
+        proc.start()
+        scheduler_procs.append(proc)
+
+    if server_args.node_rank >= 1:
+        # In multi-node cases, non-zero rank nodes do not need to run tokenizer or detokenizer,
+        # so they can just wait here.
+
+        for reader in scheduler_pipe_readers:
+            data = reader.recv()
+            assert data["status"] == "ready"
+
+        if os.getenv("SGLANG_BLOCK_NONZERO_RANK_CHILDREN") == "0":
+            # When using `Engine` as a Python API, we don't want to block here.
+            return
+
+        for proc in scheduler_procs:
+            proc.join()
+            logger.error(
+                f"Scheduler or DataParallelController {proc.pid} terminated with {proc.exitcode}"
+            )
+        return

     # Launch detokenizer process
     detoken_proc = mp.Process(
@@ -597,14 +610,7 @@ def launch_server(

     try:
         # Update logging configs
-        LOGGING_CONFIG["formatters"]["default"][
-            "fmt"
-        ] = "[%(asctime)s] %(levelprefix)s %(message)s"
-        LOGGING_CONFIG["formatters"]["default"]["datefmt"] = "%Y-%m-%d %H:%M:%S"
-        LOGGING_CONFIG["formatters"]["access"][
-            "fmt"
-        ] = '[%(asctime)s] %(levelprefix)s %(client_addr)s - "%(request_line)s" %(status_code)s'
-        LOGGING_CONFIG["formatters"]["access"]["datefmt"] = "%Y-%m-%d %H:%M:%S"
+        set_uvicorn_logging_configs()

         # Listen for HTTP requests
         uvicorn.run(
diff --git a/python/sglang/srt/utils.py b/python/sglang/srt/utils.py
index c521e002f..583dd92e1 100644
--- a/python/sglang/srt/utils.py
+++ b/python/sglang/srt/utils.py
@@ -59,6 +59,7 @@ from triton.runtime.cache import (
     default_dump_dir,
     default_override_dir,
 )
+from uvicorn.config import LOGGING_CONFIG

 logger = logging.getLogger(__name__)

@@ -1404,3 +1405,14 @@ def nullable_str(val: str):
     if not val or val == "None":
         return None
     return val
+
+
+def set_uvicorn_logging_configs():
+    LOGGING_CONFIG["formatters"]["default"][
+        "fmt"
+    ] = "[%(asctime)s] %(levelprefix)s %(message)s"
+    LOGGING_CONFIG["formatters"]["default"]["datefmt"] = "%Y-%m-%d %H:%M:%S"
+    LOGGING_CONFIG["formatters"]["access"][
+        "fmt"
+    ] = '[%(asctime)s] %(levelprefix)s %(client_addr)s - "%(request_line)s" %(status_code)s'
+    LOGGING_CONFIG["formatters"]["access"]["datefmt"] = "%Y-%m-%d %H:%M:%S"