Improve multi-node stability (#1171)

This commit is contained in:
Lianmin Zheng
2024-08-20 22:35:05 -07:00
committed by GitHub
parent cd10654e7e
commit bea2bb9eea
11 changed files with 94 additions and 76 deletions

View File

@@ -24,7 +24,6 @@ import json
import logging
import multiprocessing as mp
import os
import sys
import threading
import time
from http import HTTPStatus
@@ -301,27 +300,29 @@ def launch_server(
server_args.tokenizer_path = prepare_tokenizer(server_args.tokenizer_path)
# Launch processes for multi-node tensor parallelism
if server_args.nnodes > 1:
if server_args.node_rank != 0:
tp_size_local = server_args.tp_size // server_args.nnodes
gpu_ids = [
i for _ in range(server_args.nnodes) for i in range(tp_size_local)
]
tp_rank_range = list(
range(
server_args.node_rank * tp_size_local,
(server_args.node_rank + 1) * tp_size_local,
)
if server_args.nnodes > 1 and server_args.node_rank != 0:
tp_size_local = server_args.tp_size // server_args.nnodes
gpu_ids = [i for _ in range(server_args.nnodes) for i in range(tp_size_local)]
tp_rank_range = list(
range(
server_args.node_rank * tp_size_local,
(server_args.node_rank + 1) * tp_size_local,
)
procs = launch_tp_servers(
gpu_ids,
tp_rank_range,
server_args,
ports[3],
model_overide_args,
)
while True:
pass
)
procs = launch_tp_servers(
gpu_ids,
tp_rank_range,
server_args,
ports[3],
model_overide_args,
)
try:
for p in procs:
p.join()
finally:
kill_child_process(os.getpid(), including_parent=False)
return
# Launch processes
tokenizer_manager = TokenizerManager(server_args, port_args, model_overide_args)
@@ -356,15 +357,11 @@ def launch_server(
if controller_init_state != "init ok" or detoken_init_state != "init ok":
proc_controller.kill()
proc_detoken.kill()
print(
f"Initialization failed. controller_init_state: {controller_init_state}",
flush=True,
raise RuntimeError(
"Initialization failed. "
f"controller_init_state: {controller_init_state}, "
f"detoken_init_state: {detoken_init_state}"
)
print(
f"Initialization failed. detoken_init_state: {detoken_init_state}",
flush=True,
)
sys.exit(1)
assert proc_controller.is_alive() and proc_detoken.is_alive()
# Add api key authorization
@@ -373,12 +370,12 @@ def launch_server(
# Send a warmup request
t = threading.Thread(
target=_wait_and_warmup, args=(server_args, pipe_finish_writer)
target=_wait_and_warmup, args=(server_args, pipe_finish_writer, os.getpid())
)
t.start()
# Listen for requests
try:
# Listen for requests
uvicorn.run(
app,
host=server_args.host,
@@ -426,7 +423,7 @@ def _set_envs_and_config(server_args: ServerArgs):
)
def _wait_and_warmup(server_args, pipe_finish_writer):
def _wait_and_warmup(server_args, pipe_finish_writer, pid):
headers = {}
url = server_args.url()
if server_args.api_key:
@@ -449,8 +446,9 @@ def _wait_and_warmup(server_args, pipe_finish_writer):
if not success:
if pipe_finish_writer is not None:
pipe_finish_writer.send(last_traceback)
print(f"Initialization failed. warmup error: {last_traceback}", flush=True)
sys.exit(1)
logger.error(f"Initialization failed. warmup error: {last_traceback}")
kill_child_process(pid, including_parent=False)
return
# Send a warmup request
request_name = "/generate" if model_info["is_generation"] else "/encode"
@@ -475,12 +473,13 @@ def _wait_and_warmup(server_args, pipe_finish_writer):
timeout=600,
)
assert res.status_code == 200, f"{res}"
except Exception as e:
except Exception:
last_traceback = get_exception_traceback()
if pipe_finish_writer is not None:
pipe_finish_writer.send(last_traceback)
print(f"Initialization failed. warmup error: {last_traceback}", flush=True)
sys.exit(1)
logger.error(f"Initialization failed. warmup error: {last_traceback}")
kill_child_process(pid, including_parent=False)
return
logger.info("The server is fired up and ready to roll!")
if pipe_finish_writer is not None: