Improve process creation (#1534)
@@ -43,20 +43,14 @@ from fastapi.responses import JSONResponse, Response, StreamingResponse
 from sglang.lang.backend.runtime_endpoint import RuntimeEndpoint
 from sglang.srt.constrained import disable_cache
 from sglang.srt.hf_transformers_utils import get_tokenizer
-from sglang.srt.managers.controller_multi import (
-    start_controller_process as start_controller_process_multi,
-)
-from sglang.srt.managers.controller_single import launch_tp_servers
-from sglang.srt.managers.controller_single import (
-    start_controller_process as start_controller_process_single,
-)
-from sglang.srt.managers.detokenizer_manager import start_detokenizer_process
+from sglang.srt.managers.detokenizer_manager import run_detokenizer_process
 from sglang.srt.managers.io_struct import (
     EmbeddingReqInput,
     GenerateReqInput,
     RewardReqInput,
     UpdateWeightReqInput,
 )
+from sglang.srt.managers.scheduler import run_scheduler_process
 from sglang.srt.managers.tokenizer_manager import TokenizerManager
 from sglang.srt.openai_api.adapter import (
     load_chat_template_for_openai_api,
@@ -82,8 +76,7 @@ from sglang.srt.utils import (
     is_hip,
     kill_child_process,
     maybe_set_triton_cache_manager,
-    prepare_model,
-    prepare_tokenizer,
+    prepare_model_and_tokenizer,
     set_ulimit,
 )
 from sglang.utils import get_exception_traceback
@@ -303,8 +296,8 @@ def launch_server(
     """Launch an HTTP server."""
     global tokenizer_manager
 
+    # Configure global environment
     configure_logger(server_args)
-
     server_args.check_server_args()
     _set_envs_and_config(server_args)
 
@@ -317,81 +310,60 @@ def launch_server(
     ports = server_args.additional_ports
     port_args = PortArgs(
         tokenizer_port=ports[0],
-        controller_port=ports[1],
+        scheduler_port=ports[1],
         detokenizer_port=ports[2],
         nccl_ports=ports[3:],
     )
     logger.info(f"{server_args=}")
 
-    # Use model from www.modelscope.cn, first download the model.
-    server_args.model_path = prepare_model(server_args.model_path)
-    server_args.tokenizer_path = prepare_tokenizer(server_args.tokenizer_path)
-
-    # Launch processes for multi-node tensor parallelism
-    if server_args.nnodes > 1 and server_args.node_rank != 0:
-        tp_size_local = server_args.tp_size // server_args.nnodes
-        gpu_ids = [i for _ in range(server_args.nnodes) for i in range(tp_size_local)]
-        tp_rank_range = list(
-            range(
-                server_args.node_rank * tp_size_local,
-                (server_args.node_rank + 1) * tp_size_local,
-            )
-        )
-        procs = launch_tp_servers(
-            gpu_ids,
-            tp_rank_range,
-            server_args,
-            ports[3],
-        )
-
-        try:
-            for p in procs:
-                p.join()
-        finally:
-            kill_child_process(os.getpid(), including_parent=False)
-        return
-
-    # Launch processes
-    pipe_controller_reader, pipe_controller_writer = mp.Pipe(duplex=False)
-
-    if server_args.dp_size == 1:
-        start_controller_process = start_controller_process_single
-    else:
-        start_controller_process = start_controller_process_multi
-    proc_controller = mp.Process(
-        target=start_controller_process,
-        args=(server_args, port_args, pipe_controller_writer),
+    # If using model from www.modelscope.cn, first download the model.
+    server_args.model_path, server_args.tokenizer_path = prepare_model_and_tokenizer(
+        server_args.model_path, server_args.tokenizer_path
     )
-    proc_controller.start()
 
-    pipe_detoken_reader, pipe_detoken_writer = mp.Pipe(duplex=False)
-    proc_detoken = mp.Process(
-        target=start_detokenizer_process,
+    # Launch tensor parallel scheduler processes
+    scheduler_procs = []
+    scheduler_pipe_readers = []
+    tp_size_per_node = server_args.tp_size // server_args.nnodes
+    tp_rank_range = range(
+        tp_size_per_node * server_args.node_rank,
+        tp_size_per_node * (server_args.node_rank + 1),
+    )
+    for tp_rank in tp_rank_range:
+        reader, writer = mp.Pipe(duplex=False)
+        gpu_id = tp_rank % tp_size_per_node
+        proc = mp.Process(
+            target=run_scheduler_process,
+            args=(server_args, port_args, gpu_id, tp_rank, writer),
+        )
+        proc.start()
+        scheduler_procs.append(proc)
+        scheduler_pipe_readers.append(reader)
+
+    if server_args.node_rank >= 1:
+        # For other nodes, they do not need to run tokenizer or detokenizer,
+        # so they can just wait here.
+        while True:
+            pass
+
+    # Launch detokenizer process
+    detoken_proc = mp.Process(
+        target=run_detokenizer_process,
         args=(
             server_args,
             port_args,
-            pipe_detoken_writer,
         ),
     )
-    proc_detoken.start()
+    detoken_proc.start()
 
+    # Launch tokenizer process
     tokenizer_manager = TokenizerManager(server_args, port_args)
     if server_args.chat_template:
         load_chat_template_for_openai_api(tokenizer_manager, server_args.chat_template)
 
-    # Wait for the model to finish loading
-    controller_init_state = pipe_controller_reader.recv()
-    detoken_init_state = pipe_detoken_reader.recv()
-
-    if controller_init_state != "init ok" or detoken_init_state != "init ok":
-        proc_controller.kill()
-        proc_detoken.kill()
-        raise RuntimeError(
-            "Initialization failed. "
-            f"controller_init_state: {controller_init_state}, "
-            f"detoken_init_state: {detoken_init_state}"
-        )
-    assert proc_controller.is_alive() and proc_detoken.is_alive()
+    # Wait for model to finish loading
+    for i in range(len(scheduler_pipe_readers)):
+        scheduler_pipe_readers[i].recv()
 
     # Add api key authorization
     if server_args.api_key:
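The refactored launch path above repeats one pattern per tensor-parallel rank: create a one-way mp.Pipe, start the worker process, and later block on recv() until every child reports that its model finished loading. Below is a minimal, self-contained sketch of that pattern; run_worker and its "ready" message are illustrative stand-ins, not code from the repository.

import multiprocessing as mp


def run_worker(rank: int, writer):
    # Stand-in for run_scheduler_process: load the model, then report readiness.
    # A real worker would also receive server_args / port_args here.
    writer.send(f"rank {rank} ready")


if __name__ == "__main__":
    mp.set_start_method("spawn", force=True)

    procs, readers = [], []
    for rank in range(4):
        reader, writer = mp.Pipe(duplex=False)
        proc = mp.Process(target=run_worker, args=(rank, writer))
        proc.start()
        procs.append(proc)
        readers.append(reader)

    # Block until every worker has reported in, mirroring the
    # scheduler_pipe_readers[i].recv() loop in the new launch_server.
    for reader in readers:
        print(reader.recv())

    for proc in procs:
        proc.join()

Unlike the old controller path, the waiting loop in the new code no longer inspects the message it receives, so a failing rank has to surface its error inside the child process (or break the pipe) rather than through an "init ok" check in the parent.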
@@ -404,7 +376,7 @@ def launch_server(
     t.start()
 
     try:
-        # Listen for requests
+        # Listen for HTTP requests
         uvicorn.run(
             app,
             host=server_args.host,
@@ -451,9 +423,7 @@ def _set_envs_and_config(server_args: ServerArgs):
             "at https://docs.flashinfer.ai/installation.html.",
         )
 
-    if is_hip():
-        # to figure out a better method of not using fork later
-        mp.set_start_method("spawn", force=True)
+    mp.set_start_method("spawn", force=True)
 
 
 def _wait_and_warmup(server_args, pipe_finish_writer, pid):
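This hunk makes spawn the multiprocessing start method on every platform instead of only on HIP, which avoids forking a parent that may already hold CUDA contexts or background threads (the removed comment already pointed in that direction). A small sketch of the idiom, with a hypothetical child function:

import multiprocessing as mp


def child(msg: str):
    # Under the spawn start method this runs in a fresh interpreter, so it does
    # not inherit CUDA contexts, locks, or threads from the parent process.
    print(msg)


if __name__ == "__main__":
    # force=True overrides any start method another import may have set already.
    mp.set_start_method("spawn", force=True)
    proc = mp.Process(target=child, args=("hello from a spawned process",))
    proc.start()
    proc.join()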
@@ -517,7 +487,7 @@ def _wait_and_warmup(server_args, pipe_finish_writer, pid):
 
     logger.info("The server is fired up and ready to roll!")
     if pipe_finish_writer is not None:
-        pipe_finish_writer.send("init ok")
+        pipe_finish_writer.send("ready")
 
 
 class Runtime:
@@ -564,7 +534,7 @@ class Runtime:
         except EOFError:
             init_state = ""
 
-        if init_state != "init ok":
+        if init_state != "ready":
             self.shutdown()
             raise RuntimeError(
                 "Initialization failed. Please see the error messages above."
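Both ends of the readiness handshake now agree on the literal string "ready": _wait_and_warmup sends it through pipe_finish_writer, and Runtime refuses to come up unless it reads that exact value. A condensed sketch of the handshake, with a placeholder serve function standing in for the real server process:

import multiprocessing as mp


def serve(writer):
    # Placeholder for the server process: after warmup it reports readiness.
    writer.send("ready")


if __name__ == "__main__":
    mp.set_start_method("spawn", force=True)
    reader, writer = mp.Pipe(duplex=False)
    proc = mp.Process(target=serve, args=(writer,))
    proc.start()

    try:
        init_state = reader.recv()
    except EOFError:
        # The child exited before sending anything.
        init_state = ""

    if init_state != "ready":
        proc.kill()
        raise RuntimeError("Initialization failed. Please see the error messages above.")
    proc.join()

The EOFError branch covers a child that dies before sending anything; the parent then sees an empty init_state and tears the process down, matching the Runtime logic above.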