Improve multi-node stability (#1171)
This commit is contained in:
@@ -24,7 +24,6 @@ import json
|
||||
import logging
|
||||
import multiprocessing as mp
|
||||
import os
|
||||
import sys
|
||||
import threading
|
||||
import time
|
||||
from http import HTTPStatus
|
||||
@@ -301,27 +300,29 @@ def launch_server(
|
||||
server_args.tokenizer_path = prepare_tokenizer(server_args.tokenizer_path)
|
||||
|
||||
# Launch processes for multi-node tensor parallelism
|
||||
if server_args.nnodes > 1:
|
||||
if server_args.node_rank != 0:
|
||||
tp_size_local = server_args.tp_size // server_args.nnodes
|
||||
gpu_ids = [
|
||||
i for _ in range(server_args.nnodes) for i in range(tp_size_local)
|
||||
]
|
||||
tp_rank_range = list(
|
||||
range(
|
||||
server_args.node_rank * tp_size_local,
|
||||
(server_args.node_rank + 1) * tp_size_local,
|
||||
)
|
||||
if server_args.nnodes > 1 and server_args.node_rank != 0:
|
||||
tp_size_local = server_args.tp_size // server_args.nnodes
|
||||
gpu_ids = [i for _ in range(server_args.nnodes) for i in range(tp_size_local)]
|
||||
tp_rank_range = list(
|
||||
range(
|
||||
server_args.node_rank * tp_size_local,
|
||||
(server_args.node_rank + 1) * tp_size_local,
|
||||
)
|
||||
procs = launch_tp_servers(
|
||||
gpu_ids,
|
||||
tp_rank_range,
|
||||
server_args,
|
||||
ports[3],
|
||||
model_overide_args,
|
||||
)
|
||||
while True:
|
||||
pass
|
||||
)
|
||||
procs = launch_tp_servers(
|
||||
gpu_ids,
|
||||
tp_rank_range,
|
||||
server_args,
|
||||
ports[3],
|
||||
model_overide_args,
|
||||
)
|
||||
|
||||
try:
|
||||
for p in procs:
|
||||
p.join()
|
||||
finally:
|
||||
kill_child_process(os.getpid(), including_parent=False)
|
||||
return
|
||||
|
||||
# Launch processes
|
||||
tokenizer_manager = TokenizerManager(server_args, port_args, model_overide_args)
|
||||
@@ -356,15 +357,11 @@ def launch_server(
|
||||
if controller_init_state != "init ok" or detoken_init_state != "init ok":
|
||||
proc_controller.kill()
|
||||
proc_detoken.kill()
|
||||
print(
|
||||
f"Initialization failed. controller_init_state: {controller_init_state}",
|
||||
flush=True,
|
||||
raise RuntimeError(
|
||||
"Initialization failed. "
|
||||
f"controller_init_state: {controller_init_state}, "
|
||||
f"detoken_init_state: {detoken_init_state}"
|
||||
)
|
||||
print(
|
||||
f"Initialization failed. detoken_init_state: {detoken_init_state}",
|
||||
flush=True,
|
||||
)
|
||||
sys.exit(1)
|
||||
assert proc_controller.is_alive() and proc_detoken.is_alive()
|
||||
|
||||
# Add api key authorization
|
||||
@@ -373,12 +370,12 @@ def launch_server(
|
||||
|
||||
# Send a warmup request
|
||||
t = threading.Thread(
|
||||
target=_wait_and_warmup, args=(server_args, pipe_finish_writer)
|
||||
target=_wait_and_warmup, args=(server_args, pipe_finish_writer, os.getpid())
|
||||
)
|
||||
t.start()
|
||||
|
||||
# Listen for requests
|
||||
try:
|
||||
# Listen for requests
|
||||
uvicorn.run(
|
||||
app,
|
||||
host=server_args.host,
|
||||
@@ -426,7 +423,7 @@ def _set_envs_and_config(server_args: ServerArgs):
|
||||
)
|
||||
|
||||
|
||||
def _wait_and_warmup(server_args, pipe_finish_writer):
|
||||
def _wait_and_warmup(server_args, pipe_finish_writer, pid):
|
||||
headers = {}
|
||||
url = server_args.url()
|
||||
if server_args.api_key:
|
||||
@@ -449,8 +446,9 @@ def _wait_and_warmup(server_args, pipe_finish_writer):
|
||||
if not success:
|
||||
if pipe_finish_writer is not None:
|
||||
pipe_finish_writer.send(last_traceback)
|
||||
print(f"Initialization failed. warmup error: {last_traceback}", flush=True)
|
||||
sys.exit(1)
|
||||
logger.error(f"Initialization failed. warmup error: {last_traceback}")
|
||||
kill_child_process(pid, including_parent=False)
|
||||
return
|
||||
|
||||
# Send a warmup request
|
||||
request_name = "/generate" if model_info["is_generation"] else "/encode"
|
||||
@@ -475,12 +473,13 @@ def _wait_and_warmup(server_args, pipe_finish_writer):
|
||||
timeout=600,
|
||||
)
|
||||
assert res.status_code == 200, f"{res}"
|
||||
except Exception as e:
|
||||
except Exception:
|
||||
last_traceback = get_exception_traceback()
|
||||
if pipe_finish_writer is not None:
|
||||
pipe_finish_writer.send(last_traceback)
|
||||
print(f"Initialization failed. warmup error: {last_traceback}", flush=True)
|
||||
sys.exit(1)
|
||||
logger.error(f"Initialization failed. warmup error: {last_traceback}")
|
||||
kill_child_process(pid, including_parent=False)
|
||||
return
|
||||
|
||||
logger.info("The server is fired up and ready to roll!")
|
||||
if pipe_finish_writer is not None:
|
||||
|
||||
Reference in New Issue
Block a user