Simplify the process launch code in server.py (#2923)
This commit is contained in:
@@ -44,7 +44,6 @@ import uvloop
|
|||||||
from fastapi import FastAPI, File, Form, Request, UploadFile
|
from fastapi import FastAPI, File, Form, Request, UploadFile
|
||||||
from fastapi.middleware.cors import CORSMiddleware
|
from fastapi.middleware.cors import CORSMiddleware
|
||||||
from fastapi.responses import ORJSONResponse, Response, StreamingResponse
|
from fastapi.responses import ORJSONResponse, Response, StreamingResponse
|
||||||
from uvicorn.config import LOGGING_CONFIG
|
|
||||||
|
|
||||||
from sglang.lang.backend.runtime_endpoint import RuntimeEndpoint
|
from sglang.lang.backend.runtime_endpoint import RuntimeEndpoint
|
||||||
from sglang.srt.hf_transformers_utils import get_tokenizer
|
from sglang.srt.hf_transformers_utils import get_tokenizer
|
||||||
@@ -97,6 +96,7 @@ from sglang.srt.utils import (
|
|||||||
prepare_model_and_tokenizer,
|
prepare_model_and_tokenizer,
|
||||||
set_prometheus_multiproc_dir,
|
set_prometheus_multiproc_dir,
|
||||||
set_ulimit,
|
set_ulimit,
|
||||||
|
set_uvicorn_logging_configs,
|
||||||
)
|
)
|
||||||
from sglang.utils import get_exception_traceback
|
from sglang.utils import get_exception_traceback
|
||||||
from sglang.version import __version__
|
from sglang.version import __version__
|
||||||
@@ -474,13 +474,13 @@ def launch_engine(
|
|||||||
server_args.model_path, server_args.tokenizer_path
|
server_args.model_path, server_args.tokenizer_path
|
||||||
)
|
)
|
||||||
|
|
||||||
memory_saver_adapter = TorchMemorySaverAdapter.create(
|
scheduler_procs = []
|
||||||
enable=server_args.enable_memory_saver
|
|
||||||
)
|
|
||||||
|
|
||||||
if server_args.dp_size == 1:
|
if server_args.dp_size == 1:
|
||||||
# Launch tensor parallel scheduler processes
|
# Launch tensor parallel scheduler processes
|
||||||
scheduler_procs = []
|
memory_saver_adapter = TorchMemorySaverAdapter.create(
|
||||||
|
enable=server_args.enable_memory_saver
|
||||||
|
)
|
||||||
|
|
||||||
scheduler_pipe_readers = []
|
scheduler_pipe_readers = []
|
||||||
tp_size_per_node = server_args.tp_size // server_args.nnodes
|
tp_size_per_node = server_args.tp_size // server_args.nnodes
|
||||||
tp_rank_range = range(
|
tp_rank_range = range(
|
||||||
@@ -498,12 +498,6 @@ def launch_engine(
|
|||||||
proc.start()
|
proc.start()
|
||||||
scheduler_procs.append(proc)
|
scheduler_procs.append(proc)
|
||||||
scheduler_pipe_readers.append(reader)
|
scheduler_pipe_readers.append(reader)
|
||||||
|
|
||||||
if server_args.node_rank >= 1:
|
|
||||||
# For other nodes, they do not need to run tokenizer or detokenizer,
|
|
||||||
# so they can just wait here.
|
|
||||||
for proc in scheduler_procs:
|
|
||||||
proc.join()
|
|
||||||
else:
|
else:
|
||||||
# Launch the data parallel controller
|
# Launch the data parallel controller
|
||||||
reader, writer = mp.Pipe(duplex=False)
|
reader, writer = mp.Pipe(duplex=False)
|
||||||
@@ -512,8 +506,27 @@ def launch_engine(
|
|||||||
target=run_data_parallel_controller_process,
|
target=run_data_parallel_controller_process,
|
||||||
args=(server_args, port_args, writer),
|
args=(server_args, port_args, writer),
|
||||||
)
|
)
|
||||||
with memory_saver_adapter.configure_subprocess():
|
proc.start()
|
||||||
proc.start()
|
scheduler_procs.append(proc)
|
||||||
|
|
||||||
|
if server_args.node_rank >= 1:
|
||||||
|
# In multi-node cases, non-zero rank nodes do not need to run tokenizer or detokenizer,
|
||||||
|
# so they can just wait here.
|
||||||
|
|
||||||
|
for reader in scheduler_pipe_readers:
|
||||||
|
data = reader.recv()
|
||||||
|
assert data["status"] == "ready"
|
||||||
|
|
||||||
|
if os.getenv("SGLANG_BLOCK_NONZERO_RANK_CHILDREN") == "0":
|
||||||
|
# When using `Engine` as a Python API, we don't want to block here.
|
||||||
|
return
|
||||||
|
|
||||||
|
for proc in scheduler_procs:
|
||||||
|
proc.join()
|
||||||
|
logger.error(
|
||||||
|
f"Scheduler or DataParallelController {proc.pid} terminated with {proc.exitcode}"
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
# Launch detokenizer process
|
# Launch detokenizer process
|
||||||
detoken_proc = mp.Process(
|
detoken_proc = mp.Process(
|
||||||
@@ -597,14 +610,7 @@ def launch_server(
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
# Update logging configs
|
# Update logging configs
|
||||||
LOGGING_CONFIG["formatters"]["default"][
|
set_uvicorn_logging_configs()
|
||||||
"fmt"
|
|
||||||
] = "[%(asctime)s] %(levelprefix)s %(message)s"
|
|
||||||
LOGGING_CONFIG["formatters"]["default"]["datefmt"] = "%Y-%m-%d %H:%M:%S"
|
|
||||||
LOGGING_CONFIG["formatters"]["access"][
|
|
||||||
"fmt"
|
|
||||||
] = '[%(asctime)s] %(levelprefix)s %(client_addr)s - "%(request_line)s" %(status_code)s'
|
|
||||||
LOGGING_CONFIG["formatters"]["access"]["datefmt"] = "%Y-%m-%d %H:%M:%S"
|
|
||||||
|
|
||||||
# Listen for HTTP requests
|
# Listen for HTTP requests
|
||||||
uvicorn.run(
|
uvicorn.run(
|
||||||
|
|||||||
@@ -59,6 +59,7 @@ from triton.runtime.cache import (
|
|||||||
default_dump_dir,
|
default_dump_dir,
|
||||||
default_override_dir,
|
default_override_dir,
|
||||||
)
|
)
|
||||||
|
from uvicorn.config import LOGGING_CONFIG
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -1404,3 +1405,14 @@ def nullable_str(val: str):
|
|||||||
if not val or val == "None":
|
if not val or val == "None":
|
||||||
return None
|
return None
|
||||||
return val
|
return val
|
||||||
|
|
||||||
|
|
||||||
|
def set_uvicorn_logging_configs():
|
||||||
|
LOGGING_CONFIG["formatters"]["default"][
|
||||||
|
"fmt"
|
||||||
|
] = "[%(asctime)s] %(levelprefix)s %(message)s"
|
||||||
|
LOGGING_CONFIG["formatters"]["default"]["datefmt"] = "%Y-%m-%d %H:%M:%S"
|
||||||
|
LOGGING_CONFIG["formatters"]["access"][
|
||||||
|
"fmt"
|
||||||
|
] = '[%(asctime)s] %(levelprefix)s %(client_addr)s - "%(request_line)s" %(status_code)s'
|
||||||
|
LOGGING_CONFIG["formatters"]["access"]["datefmt"] = "%Y-%m-%d %H:%M:%S"
|
||||||
|
|||||||
Reference in New Issue
Block a user