Add a watchdog thread (#1816)

This commit is contained in:
Lianmin Zheng
2024-10-27 02:00:50 -07:00
committed by GitHub
parent 1be853ee69
commit 86fc0d79d0
34 changed files with 99 additions and 56 deletions

View File

@@ -550,4 +550,4 @@ if __name__ == "__main__":
except Exception as e:
raise e
finally:
kill_child_process(os.getpid(), including_parent=False)
kill_child_process()

View File

@@ -15,7 +15,6 @@ import dataclasses
import itertools
import json
import multiprocessing
import os
import time
from typing import Tuple
@@ -70,7 +69,7 @@ def launch_server_internal(server_args):
except Exception as e:
raise e
finally:
kill_child_process(os.getpid(), including_parent=False)
kill_child_process()
def launch_server_process(server_args: ServerArgs):
@@ -176,7 +175,7 @@ def run_benchmark(server_args: ServerArgs, bench_args: BenchArgs):
)
finally:
if proc:
kill_child_process(proc.pid)
kill_child_process(proc.pid, include_self=True)
print(f"\nResults are saved to {bench_args.result_filename}")

View File

@@ -15,4 +15,4 @@ if __name__ == "__main__":
except Exception as e:
raise e
finally:
kill_child_process(os.getpid(), including_parent=False)
kill_child_process()

View File

@@ -18,6 +18,7 @@ limitations under the License.
import json
import logging
import os
import threading
import time
import warnings
from collections import deque
@@ -222,10 +223,11 @@ class Scheduler:
self.waiting_queue: List[Req] = []
self.running_batch: Optional[ScheduleBatch] = None
self.cur_batch: Optional[ScheduleBatch] = None
self.decode_forward_ct = 0
self.stream_interval = server_args.stream_interval
self.forward_ct = 0
self.forward_ct_decode = 0
self.num_generated_tokens = 0
self.last_stats_tic = time.time()
self.stream_interval = server_args.stream_interval
# Init chunked prefill
self.chunked_prefill_size = server_args.chunked_prefill_size
@@ -272,6 +274,11 @@ class Scheduler:
self.batch_is_full = False
# Init watchdog thread
self.watchdog_timeout = server_args.watchdog_timeout
t = threading.Thread(target=self.watchdog_thread, daemon=True)
t.start()
# Init profiler
if os.getenv("SGLANG_TORCH_PROFILER_DIR", "") == "":
self.profiler = None
@@ -289,6 +296,23 @@ class Scheduler:
with_stack=True,
)
def watchdog_thread(self):
    """Kill the server when a scheduled batch stops making forward progress.

    Polls ``self.forward_ct`` every ``self.watchdog_timeout / 2`` seconds.
    If a batch is current (``self.cur_batch is not None``) and the counter
    has not moved for more than ``self.watchdog_timeout`` seconds, an error
    is logged and ``kill_parent_process()`` is called.

    The loop only exits on a timeout, so this method is meant to run on
    its own background thread and never returns normally.
    """
    self.watchdog_last_forward_ct = 0
    self.watchdog_last_time = time.time()

    while True:
        # Read the clock once so the comparison and the bookkeeping below
        # agree on what "now" is.
        now = time.time()

        is_idle = self.cur_batch is None
        made_progress = self.watchdog_last_forward_ct != self.forward_ct

        if is_idle or made_progress:
            # Refresh the baseline while idle as well as on progress.
            # Otherwise the timestamp goes stale during a long idle period,
            # and a check that lands right after a new batch is assigned
            # but before forward_ct is incremented would report a spurious
            # timeout.
            self.watchdog_last_forward_ct = self.forward_ct
            self.watchdog_last_time = now
        elif now > self.watchdog_last_time + self.watchdog_timeout:
            logger.error(f"Watchdog timeout ({self.watchdog_timeout=})")
            break

        time.sleep(self.watchdog_timeout / 2)

    # Only reached after a timeout was detected.
    kill_parent_process()
@torch.inference_mode()
def event_loop_normal(self):
"""A normal blocking scheduler loop."""
@@ -299,6 +323,7 @@ class Scheduler:
self.process_input_requests(recv_reqs)
batch = self.get_next_batch_to_run()
self.cur_batch = batch
if batch:
result = self.run_batch(batch)
@@ -746,6 +771,8 @@ class Scheduler:
def run_batch(self, batch: ScheduleBatch):
"""Run a batch."""
self.forward_ct += 1
if self.is_generation:
if batch.forward_mode.is_decode() or batch.extend_num_tokens != 0:
model_worker_batch = batch.get_model_worker_batch()
@@ -778,6 +805,7 @@ class Scheduler:
self.process_batch_result_prefill(batch, result)
def process_batch_result_prefill(self, batch: ScheduleBatch, result):
if self.is_generation:
logits_output, next_token_ids, bid = result
@@ -890,8 +918,8 @@ class Scheduler:
self.token_to_kv_pool.free_group_end()
self.decode_forward_ct = (self.decode_forward_ct + 1) % (1 << 30)
if self.tp_rank == 0 and self.decode_forward_ct % 40 == 0:
self.forward_ct_decode = (self.forward_ct_decode + 1) % (1 << 30)
if self.tp_rank == 0 and self.forward_ct_decode % 40 == 0:
self.print_decode_stats()
def add_logprob_return_values(
@@ -984,7 +1012,7 @@ class Scheduler:
else: # embedding or reward model
output_embeddings = []
is_stream_iter = self.decode_forward_ct % self.stream_interval == 0
is_stream_iter = self.forward_ct_decode % self.stream_interval == 0
for req in reqs:
if req.finished() or (

View File

@@ -441,7 +441,7 @@ def launch_server(
# Send a warmup request
t = threading.Thread(
target=_wait_and_warmup, args=(server_args, pipe_finish_writer, os.getpid())
target=_wait_and_warmup, args=(server_args, pipe_finish_writer)
)
t.start()
@@ -496,7 +496,7 @@ def _set_envs_and_config(server_args: ServerArgs):
mp.set_start_method("spawn", force=True)
def _wait_and_warmup(server_args, pipe_finish_writer, pid):
def _wait_and_warmup(server_args, pipe_finish_writer):
headers = {}
url = server_args.url()
if server_args.api_key:
@@ -519,7 +519,7 @@ def _wait_and_warmup(server_args, pipe_finish_writer, pid):
if pipe_finish_writer is not None:
pipe_finish_writer.send(last_traceback)
logger.error(f"Initialization failed. warmup error: {last_traceback}")
kill_child_process(pid, including_parent=False)
kill_child_process(include_self=True)
return
model_info = res.json()
@@ -551,7 +551,7 @@ def _wait_and_warmup(server_args, pipe_finish_writer, pid):
if pipe_finish_writer is not None:
pipe_finish_writer.send(last_traceback)
logger.error(f"Initialization failed. warmup error: {last_traceback}")
kill_child_process(pid, including_parent=False)
kill_child_process(include_self=True)
return
# logger.info(f"{res.json()=}")
@@ -617,7 +617,7 @@ class Runtime:
def shutdown(self):
if self.pid is not None:
kill_child_process(self.pid)
kill_child_process(self.pid, include_self=True)
self.pid = None
def cache_prefix(self, prefix: str):
@@ -834,7 +834,7 @@ class Engine:
return ret
def shutdown(self):
kill_child_process(os.getpid(), including_parent=False)
kill_child_process(include_self=True)
def get_tokenizer(self):
global tokenizer_manager

View File

@@ -74,6 +74,7 @@ class ServerArgs:
api_key: Optional[str] = None
file_storage_pth: str = "SGLang_storage"
enable_cache_report: bool = False
watchdog_timeout: float = 600
# Data parallelism
dp_size: int = 1
@@ -429,6 +430,12 @@ class ServerArgs:
action="store_true",
help="Return number of cached tokens in usage.prompt_tokens_details for each openai request.",
)
parser.add_argument(
"--watchdog-timeout",
type=float,
default=ServerArgs.watchdog_timeout,
help="Set watchdog timeout in seconds. If a forward batch takes longer than this, the server will crash to prevent hanging.",
)
# Data parallelism
parser.add_argument(

View File

@@ -398,17 +398,26 @@ def kill_parent_process():
"""Kill the parent process and all children of the parent process."""
current_process = psutil.Process()
parent_process = current_process.parent()
kill_child_process(parent_process.pid, skip_pid=current_process.pid)
def kill_child_process(pid, including_parent=True, skip_pid=None):
"""Kill the process and all its children process."""
kill_child_process(
parent_process.pid, include_self=True, skip_pid=current_process.pid
)
try:
parent = psutil.Process(pid)
current_process.kill()
except psutil.NoSuchProcess:
pass
def kill_child_process(pid=None, include_self=False, skip_pid=None):
"""Kill the process and all its children process."""
if pid is None:
pid = os.getpid()
try:
itself = psutil.Process(pid)
except psutil.NoSuchProcess:
return
children = parent.children(recursive=True)
children = itself.children(recursive=True)
for child in children:
if child.pid == skip_pid:
continue
@@ -417,9 +426,9 @@ def kill_child_process(pid, including_parent=True, skip_pid=None):
except psutil.NoSuchProcess:
pass
if including_parent:
if include_self:
try:
parent.kill()
itself.kill()
except psutil.NoSuchProcess:
pass

View File

@@ -495,7 +495,7 @@ def run_unittest_files(files: List[str], timeout_per_file: float):
)
assert ret_code == 0
except TimeoutError:
kill_child_process(process.pid)
kill_child_process(process.pid, include_self=True)
time.sleep(5)
print(
f"\nTimeout after {timeout_per_file} seconds when running {filename}\n",
@@ -563,7 +563,7 @@ def run_bench_serving(
try:
res = run_benchmark(args)
finally:
kill_child_process(process.pid)
kill_child_process(process.pid, include_self=True)
assert res["completed"] == num_prompts
return res
@@ -596,7 +596,7 @@ def run_bench_latency(model, other_args):
lastline = output.split("\n")[-3]
output_throughput = float(lastline.split(" ")[-2])
finally:
kill_child_process(process.pid)
kill_child_process(process.pid, include_self=True)
return output_throughput
@@ -707,8 +707,8 @@ def run_mmlu_test(
pass
# Clean up everything
kill_child_process(process.pid)
kill_child_process(process.pid)
kill_child_process(process.pid, include_self=True)
kill_child_process(process.pid, include_self=True)
stdout.close()
stderr.close()
if os.path.exists(STDOUT_FILENAME):