Simplify the process launch code in server.py (#2923)

This commit is contained in:
Lianmin Zheng
2025-01-16 07:52:17 -08:00
committed by GitHub
parent e00e5385e0
commit 93d690617e
2 changed files with 40 additions and 22 deletions

View File

@@ -44,7 +44,6 @@ import uvloop
from fastapi import FastAPI, File, Form, Request, UploadFile from fastapi import FastAPI, File, Form, Request, UploadFile
from fastapi.middleware.cors import CORSMiddleware from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import ORJSONResponse, Response, StreamingResponse from fastapi.responses import ORJSONResponse, Response, StreamingResponse
from uvicorn.config import LOGGING_CONFIG
from sglang.lang.backend.runtime_endpoint import RuntimeEndpoint from sglang.lang.backend.runtime_endpoint import RuntimeEndpoint
from sglang.srt.hf_transformers_utils import get_tokenizer from sglang.srt.hf_transformers_utils import get_tokenizer
@@ -97,6 +96,7 @@ from sglang.srt.utils import (
prepare_model_and_tokenizer, prepare_model_and_tokenizer,
set_prometheus_multiproc_dir, set_prometheus_multiproc_dir,
set_ulimit, set_ulimit,
set_uvicorn_logging_configs,
) )
from sglang.utils import get_exception_traceback from sglang.utils import get_exception_traceback
from sglang.version import __version__ from sglang.version import __version__
@@ -474,13 +474,13 @@ def launch_engine(
server_args.model_path, server_args.tokenizer_path server_args.model_path, server_args.tokenizer_path
) )
memory_saver_adapter = TorchMemorySaverAdapter.create( scheduler_procs = []
enable=server_args.enable_memory_saver
)
if server_args.dp_size == 1: if server_args.dp_size == 1:
# Launch tensor parallel scheduler processes # Launch tensor parallel scheduler processes
scheduler_procs = [] memory_saver_adapter = TorchMemorySaverAdapter.create(
enable=server_args.enable_memory_saver
)
scheduler_pipe_readers = [] scheduler_pipe_readers = []
tp_size_per_node = server_args.tp_size // server_args.nnodes tp_size_per_node = server_args.tp_size // server_args.nnodes
tp_rank_range = range( tp_rank_range = range(
@@ -498,12 +498,6 @@ def launch_engine(
proc.start() proc.start()
scheduler_procs.append(proc) scheduler_procs.append(proc)
scheduler_pipe_readers.append(reader) scheduler_pipe_readers.append(reader)
if server_args.node_rank >= 1:
# For other nodes, they do not need to run tokenizer or detokenizer,
# so they can just wait here.
for proc in scheduler_procs:
proc.join()
else: else:
# Launch the data parallel controller # Launch the data parallel controller
reader, writer = mp.Pipe(duplex=False) reader, writer = mp.Pipe(duplex=False)
@@ -512,8 +506,27 @@ def launch_engine(
target=run_data_parallel_controller_process, target=run_data_parallel_controller_process,
args=(server_args, port_args, writer), args=(server_args, port_args, writer),
) )
with memory_saver_adapter.configure_subprocess(): proc.start()
proc.start() scheduler_procs.append(proc)
if server_args.node_rank >= 1:
# In multi-node cases, non-zero rank nodes do not need to run tokenizer or detokenizer,
# so they can just wait here.
for reader in scheduler_pipe_readers:
data = reader.recv()
assert data["status"] == "ready"
if os.getenv("SGLANG_BLOCK_NONZERO_RANK_CHILDREN") == "0":
# When using `Engine` as a Python API, we don't want to block here.
return
for proc in scheduler_procs:
proc.join()
logger.error(
f"Scheduler or DataParallelController {proc.pid} terminated with {proc.exitcode}"
)
return
# Launch detokenizer process # Launch detokenizer process
detoken_proc = mp.Process( detoken_proc = mp.Process(
@@ -597,14 +610,7 @@ def launch_server(
try: try:
# Update logging configs # Update logging configs
LOGGING_CONFIG["formatters"]["default"][ set_uvicorn_logging_configs()
"fmt"
] = "[%(asctime)s] %(levelprefix)s %(message)s"
LOGGING_CONFIG["formatters"]["default"]["datefmt"] = "%Y-%m-%d %H:%M:%S"
LOGGING_CONFIG["formatters"]["access"][
"fmt"
] = '[%(asctime)s] %(levelprefix)s %(client_addr)s - "%(request_line)s" %(status_code)s'
LOGGING_CONFIG["formatters"]["access"]["datefmt"] = "%Y-%m-%d %H:%M:%S"
# Listen for HTTP requests # Listen for HTTP requests
uvicorn.run( uvicorn.run(

View File

@@ -59,6 +59,7 @@ from triton.runtime.cache import (
default_dump_dir, default_dump_dir,
default_override_dir, default_override_dir,
) )
from uvicorn.config import LOGGING_CONFIG
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -1404,3 +1405,14 @@ def nullable_str(val: str):
if not val or val == "None": if not val or val == "None":
return None return None
return val return val
def set_uvicorn_logging_configs():
    """Put a timestamp at the start of uvicorn's log line formats.

    Mutates the imported ``LOGGING_CONFIG`` dict in place: for both the
    ``default`` and ``access`` formatters, the ``fmt`` and ``datefmt``
    entries are overwritten. Nothing is returned.
    """
    # Both formatters render ``%(asctime)s`` with the same date format.
    timestamp_format = "%Y-%m-%d %H:%M:%S"
    line_formats = {
        "default": "[%(asctime)s] %(levelprefix)s %(message)s",
        "access": '[%(asctime)s] %(levelprefix)s %(client_addr)s - "%(request_line)s" %(status_code)s',
    }

    formatters = LOGGING_CONFIG["formatters"]
    for formatter_name, line_format in line_formats.items():
        formatter = formatters[formatter_name]
        formatter["fmt"] = line_format
        formatter["datefmt"] = timestamp_format