Crash the server correctly during error (#2231)

This commit is contained in:
Lianmin Zheng
2024-11-28 00:22:39 -08:00
committed by GitHub
parent db674e3d24
commit d4fc1a70e3
46 changed files with 147 additions and 139 deletions

View File

@@ -23,6 +23,8 @@ import json
import logging
import multiprocessing as mp
import os
import signal
import sys
import threading
import time
from http import HTTPStatus
@@ -79,7 +81,7 @@ from sglang.srt.utils import (
configure_logger,
delete_directory,
is_port_available,
kill_child_process,
kill_process_tree,
maybe_set_triton_cache_manager,
prepare_model_and_tokenizer,
set_prometheus_multiproc_dir,
@@ -572,6 +574,15 @@ def _set_envs_and_config(server_args: ServerArgs):
"at https://docs.flashinfer.ai/installation.html.",
)
# Register the signal handler.
# The child processes will send SIGQUIT to this process when any error happens.
# This process then cleans up the whole process tree.
def sigquit_handler(signum, frame):
kill_process_tree(os.getpid())
signal.signal(signal.SIGQUIT, sigquit_handler)
# Set mp start method
mp.set_start_method("spawn", force=True)
@@ -598,7 +609,7 @@ def _wait_and_warmup(server_args, pipe_finish_writer):
if pipe_finish_writer is not None:
pipe_finish_writer.send(last_traceback)
logger.error(f"Initialization failed. warmup error: {last_traceback}")
kill_child_process(include_self=True)
kill_process_tree(os.getpid())
return
model_info = res.json()
@@ -631,7 +642,7 @@ def _wait_and_warmup(server_args, pipe_finish_writer):
if pipe_finish_writer is not None:
pipe_finish_writer.send(last_traceback)
logger.error(f"Initialization failed. warmup error: {last_traceback}")
kill_child_process(include_self=True)
kill_process_tree(os.getpid())
return
# logger.info(f"{res.json()=}")
@@ -700,7 +711,7 @@ class Runtime:
def shutdown(self):
if self.pid is not None:
kill_child_process(self.pid, include_self=True)
kill_process_tree(self.pid)
self.pid = None
def cache_prefix(self, prefix: str):
@@ -924,7 +935,7 @@ class Engine:
return ret
def shutdown(self):
kill_child_process()
kill_process_tree(os.getpid(), include_parent=False)
def get_tokenizer(self):
global tokenizer_manager