Simplify the process launch code in server.py (#2923)
This commit is contained in:
@@ -44,7 +44,6 @@ import uvloop
|
||||
from fastapi import FastAPI, File, Form, Request, UploadFile
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.responses import ORJSONResponse, Response, StreamingResponse
|
||||
from uvicorn.config import LOGGING_CONFIG
|
||||
|
||||
from sglang.lang.backend.runtime_endpoint import RuntimeEndpoint
|
||||
from sglang.srt.hf_transformers_utils import get_tokenizer
|
||||
@@ -97,6 +96,7 @@ from sglang.srt.utils import (
|
||||
prepare_model_and_tokenizer,
|
||||
set_prometheus_multiproc_dir,
|
||||
set_ulimit,
|
||||
set_uvicorn_logging_configs,
|
||||
)
|
||||
from sglang.utils import get_exception_traceback
|
||||
from sglang.version import __version__
|
||||
@@ -474,13 +474,13 @@ def launch_engine(
|
||||
server_args.model_path, server_args.tokenizer_path
|
||||
)
|
||||
|
||||
memory_saver_adapter = TorchMemorySaverAdapter.create(
|
||||
enable=server_args.enable_memory_saver
|
||||
)
|
||||
|
||||
scheduler_procs = []
|
||||
if server_args.dp_size == 1:
|
||||
# Launch tensor parallel scheduler processes
|
||||
scheduler_procs = []
|
||||
memory_saver_adapter = TorchMemorySaverAdapter.create(
|
||||
enable=server_args.enable_memory_saver
|
||||
)
|
||||
|
||||
scheduler_pipe_readers = []
|
||||
tp_size_per_node = server_args.tp_size // server_args.nnodes
|
||||
tp_rank_range = range(
|
||||
@@ -498,12 +498,6 @@ def launch_engine(
|
||||
proc.start()
|
||||
scheduler_procs.append(proc)
|
||||
scheduler_pipe_readers.append(reader)
|
||||
|
||||
if server_args.node_rank >= 1:
|
||||
# For other nodes, they do not need to run tokenizer or detokenizer,
|
||||
# so they can just wait here.
|
||||
for proc in scheduler_procs:
|
||||
proc.join()
|
||||
else:
|
||||
# Launch the data parallel controller
|
||||
reader, writer = mp.Pipe(duplex=False)
|
||||
@@ -512,8 +506,27 @@ def launch_engine(
|
||||
target=run_data_parallel_controller_process,
|
||||
args=(server_args, port_args, writer),
|
||||
)
|
||||
with memory_saver_adapter.configure_subprocess():
|
||||
proc.start()
|
||||
proc.start()
|
||||
scheduler_procs.append(proc)
|
||||
|
||||
if server_args.node_rank >= 1:
|
||||
# In multi-node cases, non-zero rank nodes do not need to run tokenizer or detokenizer,
|
||||
# so they can just wait here.
|
||||
|
||||
for reader in scheduler_pipe_readers:
|
||||
data = reader.recv()
|
||||
assert data["status"] == "ready"
|
||||
|
||||
if os.getenv("SGLANG_BLOCK_NONZERO_RANK_CHILDREN") == "0":
|
||||
# When using `Engine` as a Python API, we don't want to block here.
|
||||
return
|
||||
|
||||
for proc in scheduler_procs:
|
||||
proc.join()
|
||||
logger.error(
|
||||
f"Scheduler or DataParallelController {proc.pid} terminated with {proc.exitcode}"
|
||||
)
|
||||
return
|
||||
|
||||
# Launch detokenizer process
|
||||
detoken_proc = mp.Process(
|
||||
@@ -597,14 +610,7 @@ def launch_server(
|
||||
|
||||
try:
|
||||
# Update logging configs
|
||||
LOGGING_CONFIG["formatters"]["default"][
|
||||
"fmt"
|
||||
] = "[%(asctime)s] %(levelprefix)s %(message)s"
|
||||
LOGGING_CONFIG["formatters"]["default"]["datefmt"] = "%Y-%m-%d %H:%M:%S"
|
||||
LOGGING_CONFIG["formatters"]["access"][
|
||||
"fmt"
|
||||
] = '[%(asctime)s] %(levelprefix)s %(client_addr)s - "%(request_line)s" %(status_code)s'
|
||||
LOGGING_CONFIG["formatters"]["access"]["datefmt"] = "%Y-%m-%d %H:%M:%S"
|
||||
set_uvicorn_logging_configs()
|
||||
|
||||
# Listen for HTTP requests
|
||||
uvicorn.run(
|
||||
|
||||
@@ -59,6 +59,7 @@ from triton.runtime.cache import (
|
||||
default_dump_dir,
|
||||
default_override_dir,
|
||||
)
|
||||
from uvicorn.config import LOGGING_CONFIG
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -1404,3 +1405,14 @@ def nullable_str(val: str):
|
||||
if not val or val == "None":
|
||||
return None
|
||||
return val
|
||||
|
||||
|
||||
def set_uvicorn_logging_configs():
|
||||
LOGGING_CONFIG["formatters"]["default"][
|
||||
"fmt"
|
||||
] = "[%(asctime)s] %(levelprefix)s %(message)s"
|
||||
LOGGING_CONFIG["formatters"]["default"]["datefmt"] = "%Y-%m-%d %H:%M:%S"
|
||||
LOGGING_CONFIG["formatters"]["access"][
|
||||
"fmt"
|
||||
] = '[%(asctime)s] %(levelprefix)s %(client_addr)s - "%(request_line)s" %(status_code)s'
|
||||
LOGGING_CONFIG["formatters"]["access"]["datefmt"] = "%Y-%m-%d %H:%M:%S"
|
||||
|
||||
Reference in New Issue
Block a user