Add a watch dog thread (#1816)

This commit is contained in:
Lianmin Zheng
2024-10-27 02:00:50 -07:00
committed by GitHub
parent 1be853ee69
commit 86fc0d79d0
34 changed files with 99 additions and 56 deletions

View File

@@ -550,4 +550,4 @@ if __name__ == "__main__":
except Exception as e:
raise e
finally:
kill_child_process(os.getpid(), including_parent=False)
kill_child_process()

View File

@@ -15,7 +15,6 @@ import dataclasses
import itertools
import json
import multiprocessing
import os
import time
from typing import Tuple
@@ -70,7 +69,7 @@ def launch_server_internal(server_args):
except Exception as e:
raise e
finally:
kill_child_process(os.getpid(), including_parent=False)
kill_child_process()
def launch_server_process(server_args: ServerArgs):
@@ -176,7 +175,7 @@ def run_benchmark(server_args: ServerArgs, bench_args: BenchArgs):
)
finally:
if proc:
kill_child_process(proc.pid)
kill_child_process(proc.pid, include_self=True)
print(f"\nResults are saved to {bench_args.result_filename}")

View File

@@ -15,4 +15,4 @@ if __name__ == "__main__":
except Exception as e:
raise e
finally:
kill_child_process(os.getpid(), including_parent=False)
kill_child_process()

View File

@@ -18,6 +18,7 @@ limitations under the License.
import json
import logging
import os
import threading
import time
import warnings
from collections import deque
@@ -222,10 +223,11 @@ class Scheduler:
self.waiting_queue: List[Req] = []
self.running_batch: Optional[ScheduleBatch] = None
self.cur_batch: Optional[ScheduleBatch] = None
self.decode_forward_ct = 0
self.stream_interval = server_args.stream_interval
self.forward_ct = 0
self.forward_ct_decode = 0
self.num_generated_tokens = 0
self.last_stats_tic = time.time()
self.stream_interval = server_args.stream_interval
# Init chunked prefill
self.chunked_prefill_size = server_args.chunked_prefill_size
@@ -272,6 +274,11 @@ class Scheduler:
self.batch_is_full = False
# Init watchdog thread
self.watchdog_timeout = server_args.watchdog_timeout
t = threading.Thread(target=self.watchdog_thread, daemon=True)
t.start()
# Init profiler
if os.getenv("SGLANG_TORCH_PROFILER_DIR", "") == "":
self.profiler = None
@@ -289,6 +296,23 @@ class Scheduler:
with_stack=True,
)
def watchdog_thread(self):
self.watchdog_last_forward_ct = 0
self.watchdog_last_time = time.time()
while True:
if self.cur_batch is not None:
if self.watchdog_last_forward_ct == self.forward_ct:
if time.time() > self.watchdog_last_time + self.watchdog_timeout:
logger.error(f"Watchdog timeout ({self.watchdog_timeout=})")
break
else:
self.watchdog_last_forward_ct = self.forward_ct
self.watchdog_last_time = time.time()
time.sleep(self.watchdog_timeout / 2)
kill_parent_process()
@torch.inference_mode()
def event_loop_normal(self):
"""A normal blocking scheduler loop."""
@@ -299,6 +323,7 @@ class Scheduler:
self.process_input_requests(recv_reqs)
batch = self.get_next_batch_to_run()
self.cur_batch = batch
if batch:
result = self.run_batch(batch)
@@ -746,6 +771,8 @@ class Scheduler:
def run_batch(self, batch: ScheduleBatch):
"""Run a batch."""
self.forward_ct += 1
if self.is_generation:
if batch.forward_mode.is_decode() or batch.extend_num_tokens != 0:
model_worker_batch = batch.get_model_worker_batch()
@@ -778,6 +805,7 @@ class Scheduler:
self.process_batch_result_prefill(batch, result)
def process_batch_result_prefill(self, batch: ScheduleBatch, result):
if self.is_generation:
logits_output, next_token_ids, bid = result
@@ -890,8 +918,8 @@ class Scheduler:
self.token_to_kv_pool.free_group_end()
self.decode_forward_ct = (self.decode_forward_ct + 1) % (1 << 30)
if self.tp_rank == 0 and self.decode_forward_ct % 40 == 0:
self.forward_ct_decode = (self.forward_ct_decode + 1) % (1 << 30)
if self.tp_rank == 0 and self.forward_ct_decode % 40 == 0:
self.print_decode_stats()
def add_logprob_return_values(
@@ -984,7 +1012,7 @@ class Scheduler:
else: # embedding or reward model
output_embeddings = []
is_stream_iter = self.decode_forward_ct % self.stream_interval == 0
is_stream_iter = self.forward_ct_decode % self.stream_interval == 0
for req in reqs:
if req.finished() or (

View File

@@ -441,7 +441,7 @@ def launch_server(
# Send a warmup request
t = threading.Thread(
target=_wait_and_warmup, args=(server_args, pipe_finish_writer, os.getpid())
target=_wait_and_warmup, args=(server_args, pipe_finish_writer)
)
t.start()
@@ -496,7 +496,7 @@ def _set_envs_and_config(server_args: ServerArgs):
mp.set_start_method("spawn", force=True)
def _wait_and_warmup(server_args, pipe_finish_writer, pid):
def _wait_and_warmup(server_args, pipe_finish_writer):
headers = {}
url = server_args.url()
if server_args.api_key:
@@ -519,7 +519,7 @@ def _wait_and_warmup(server_args, pipe_finish_writer, pid):
if pipe_finish_writer is not None:
pipe_finish_writer.send(last_traceback)
logger.error(f"Initialization failed. warmup error: {last_traceback}")
kill_child_process(pid, including_parent=False)
kill_child_process(include_self=True)
return
model_info = res.json()
@@ -551,7 +551,7 @@ def _wait_and_warmup(server_args, pipe_finish_writer, pid):
if pipe_finish_writer is not None:
pipe_finish_writer.send(last_traceback)
logger.error(f"Initialization failed. warmup error: {last_traceback}")
kill_child_process(pid, including_parent=False)
kill_child_process(include_self=True)
return
# logger.info(f"{res.json()=}")
@@ -617,7 +617,7 @@ class Runtime:
def shutdown(self):
if self.pid is not None:
kill_child_process(self.pid)
kill_child_process(self.pid, include_self=True)
self.pid = None
def cache_prefix(self, prefix: str):
@@ -834,7 +834,7 @@ class Engine:
return ret
def shutdown(self):
kill_child_process(os.getpid(), including_parent=False)
kill_child_process(include_self=True)
def get_tokenizer(self):
global tokenizer_manager

View File

@@ -74,6 +74,7 @@ class ServerArgs:
api_key: Optional[str] = None
file_storage_pth: str = "SGLang_storage"
enable_cache_report: bool = False
watchdog_timeout: float = 600
# Data parallelism
dp_size: int = 1
@@ -429,6 +430,12 @@ class ServerArgs:
action="store_true",
help="Return number of cached tokens in usage.prompt_tokens_details for each openai request.",
)
parser.add_argument(
"--watchdog-timeout",
type=float,
default=ServerArgs.watchdog_timeout,
help="Set watchdog timeout in seconds. If a forward batch takes longer than this, the server will crash to prevent hanging.",
)
# Data parallelism
parser.add_argument(

View File

@@ -398,17 +398,26 @@ def kill_parent_process():
"""Kill the parent process and all children of the parent process."""
current_process = psutil.Process()
parent_process = current_process.parent()
kill_child_process(parent_process.pid, skip_pid=current_process.pid)
def kill_child_process(pid, including_parent=True, skip_pid=None):
"""Kill the process and all its children process."""
kill_child_process(
parent_process.pid, include_self=True, skip_pid=current_process.pid
)
try:
parent = psutil.Process(pid)
current_process.kill()
except psutil.NoSuchProcess:
pass
def kill_child_process(pid=None, include_self=False, skip_pid=None):
"""Kill the process and all its children process."""
if pid is None:
pid = os.getpid()
try:
itself = psutil.Process(pid)
except psutil.NoSuchProcess:
return
children = parent.children(recursive=True)
children = itself.children(recursive=True)
for child in children:
if child.pid == skip_pid:
continue
@@ -417,9 +426,9 @@ def kill_child_process(pid, including_parent=True, skip_pid=None):
except psutil.NoSuchProcess:
pass
if including_parent:
if include_self:
try:
parent.kill()
itself.kill()
except psutil.NoSuchProcess:
pass

View File

@@ -495,7 +495,7 @@ def run_unittest_files(files: List[str], timeout_per_file: float):
)
assert ret_code == 0
except TimeoutError:
kill_child_process(process.pid)
kill_child_process(process.pid, include_self=True)
time.sleep(5)
print(
f"\nTimeout after {timeout_per_file} seconds when running {filename}\n",
@@ -563,7 +563,7 @@ def run_bench_serving(
try:
res = run_benchmark(args)
finally:
kill_child_process(process.pid)
kill_child_process(process.pid, include_self=True)
assert res["completed"] == num_prompts
return res
@@ -596,7 +596,7 @@ def run_bench_latency(model, other_args):
lastline = output.split("\n")[-3]
output_throughput = float(lastline.split(" ")[-2])
finally:
kill_child_process(process.pid)
kill_child_process(process.pid, include_self=True)
return output_throughput
@@ -707,8 +707,8 @@ def run_mmlu_test(
pass
# Clean up everything
kill_child_process(process.pid)
kill_child_process(process.pid)
kill_child_process(process.pid, include_self=True)
kill_child_process(process.pid, include_self=True)
stdout.close()
stderr.close()
if os.path.exists(STDOUT_FILENAME):

View File

@@ -31,7 +31,7 @@ class TestBatchPenalizerE2E(unittest.TestCase):
@classmethod
def tearDownClass(cls):
kill_child_process(cls.process.pid)
kill_child_process(cls.process.pid, include_self=True)
def run_decode(
self,

View File

@@ -45,7 +45,7 @@ class TestCacheReport(unittest.TestCase):
@classmethod
def tearDownClass(cls):
kill_child_process(cls.process.pid)
kill_child_process(cls.process.pid, include_self=True)
def run_decode(self, return_logprob=False, top_logprobs_num=0, n=1):
response = requests.post(

View File

@@ -25,7 +25,7 @@ class TestDataParallelism(unittest.TestCase):
@classmethod
def tearDownClass(cls):
kill_child_process(cls.process.pid)
kill_child_process(cls.process.pid, include_self=True)
def test_mmlu(self):
args = SimpleNamespace(

View File

@@ -43,7 +43,7 @@ class TestDoubleSparsity(unittest.TestCase):
@classmethod
def tearDownClass(cls):
kill_child_process(cls.process.pid)
kill_child_process(cls.process.pid, include_self=True)
def test_mmlu(self):
args = SimpleNamespace(

View File

@@ -28,7 +28,7 @@ class TestOpenAIServer(unittest.TestCase):
@classmethod
def tearDownClass(cls):
kill_child_process(cls.process.pid)
kill_child_process(cls.process.pid, include_self=True)
def run_embedding(self, use_list_input, token_input):
client = openai.Client(api_key=self.api_key, base_url=self.base_url)

View File

@@ -30,7 +30,7 @@ class TestEvalAccuracyLarge(unittest.TestCase):
@classmethod
def tearDownClass(cls):
kill_child_process(cls.process.pid)
kill_child_process(cls.process.pid, include_self=True)
def test_mmlu(self):
args = SimpleNamespace(

View File

@@ -25,7 +25,7 @@ class TestEvalAccuracyLargeChunkedPrefill(unittest.TestCase):
@classmethod
def tearDownClass(cls):
kill_child_process(cls.process.pid)
kill_child_process(cls.process.pid, include_self=True)
def test_mmlu(self):
args = SimpleNamespace(

View File

@@ -31,7 +31,7 @@ class TestEvalAccuracyLargeChunkedPrefill(unittest.TestCase):
@classmethod
def tearDownClass(cls):
kill_child_process(cls.process.pid)
kill_child_process(cls.process.pid, include_self=True)
def test_mmlu(self):
args = SimpleNamespace(

View File

@@ -22,7 +22,7 @@ class TestEvalAccuracyMini(unittest.TestCase):
@classmethod
def tearDownClass(cls):
kill_child_process(cls.process.pid)
kill_child_process(cls.process.pid, include_self=True)
def test_mmlu(self):
args = SimpleNamespace(

View File

@@ -41,7 +41,7 @@ class TestJSONConstrained(unittest.TestCase):
@classmethod
def tearDownClass(cls):
kill_child_process(cls.process.pid)
kill_child_process(cls.process.pid, include_self=True)
def run_decode(self, json_schema, return_logprob=False, top_logprobs_num=0, n=1):
response = requests.post(

View File

@@ -42,7 +42,7 @@ class TestLargeMaxNewTokens(unittest.TestCase):
@classmethod
def tearDownClass(cls):
kill_child_process(cls.process.pid)
kill_child_process(cls.process.pid, include_self=True)
cls.stdout.close()
cls.stderr.close()
os.remove("stdout.txt")

View File

@@ -32,7 +32,7 @@ class TestMatchedStop(unittest.TestCase):
@classmethod
def tearDownClass(cls):
kill_child_process(cls.process.pid)
kill_child_process(cls.process.pid, include_self=True)
def run_completions_generation(
self,

View File

@@ -25,7 +25,7 @@ class TestMLA(unittest.TestCase):
@classmethod
def tearDownClass(cls):
kill_child_process(cls.process.pid)
kill_child_process(cls.process.pid, include_self=True)
def test_mmlu(self):
args = SimpleNamespace(

View File

@@ -31,7 +31,7 @@ class TestMLA(unittest.TestCase):
@classmethod
def tearDownClass(cls):
kill_child_process(cls.process.pid)
kill_child_process(cls.process.pid, include_self=True)
def test_mgsm_en(self):
args = SimpleNamespace(

View File

@@ -35,7 +35,7 @@ class TestMoEEvalAccuracyLarge(unittest.TestCase):
@classmethod
def tearDownClass(cls):
kill_child_process(cls.process.pid)
kill_child_process(cls.process.pid, include_self=True)
def test_mmlu(self):
args = SimpleNamespace(

View File

@@ -36,7 +36,7 @@ class TestEvalAccuracyLarge(unittest.TestCase):
def tearDown(self):
if self.process:
kill_child_process(self.process.pid)
kill_child_process(self.process.pid, include_self=True)
def launch_server(self, model, is_fp8, is_tp2):
other_args = ["--log-level-http", "warning", "--trust-remote-code"]

View File

@@ -31,7 +31,7 @@ class TestOpenAIServer(unittest.TestCase):
@classmethod
def tearDownClass(cls):
kill_child_process(cls.process.pid)
kill_child_process(cls.process.pid, include_self=True)
def run_completion(
self, echo, logprobs, use_list_input, parallel_sample_num, token_input

View File

@@ -27,7 +27,7 @@ class TestPyTorchSamplingBackend(unittest.TestCase):
@classmethod
def tearDownClass(cls):
kill_child_process(cls.process.pid)
kill_child_process(cls.process.pid, include_self=True)
def test_mmlu(self):
args = SimpleNamespace(

View File

@@ -22,7 +22,7 @@ class TestRetractDecode(unittest.TestCase):
@classmethod
def tearDownClass(cls):
kill_child_process(cls.process.pid)
kill_child_process(cls.process.pid, include_self=True)
def test_mmlu(self):
args = SimpleNamespace(

View File

@@ -26,7 +26,7 @@ class TestSkipTokenizerInit(unittest.TestCase):
@classmethod
def tearDownClass(cls):
kill_child_process(cls.process.pid)
kill_child_process(cls.process.pid, include_self=True)
def run_decode(self, return_logprob=False, top_logprobs_num=0, n=1):
max_new_tokens = 32

View File

@@ -27,7 +27,7 @@ class TestSRTEndpoint(unittest.TestCase):
@classmethod
def tearDownClass(cls):
kill_child_process(cls.process.pid)
kill_child_process(cls.process.pid, include_self=True)
def run_decode(
self,

View File

@@ -27,7 +27,7 @@ class TestTorchCompile(unittest.TestCase):
@classmethod
def tearDownClass(cls):
kill_child_process(cls.process.pid)
kill_child_process(cls.process.pid, include_self=True)
def test_mmlu(self):
args = SimpleNamespace(

View File

@@ -27,7 +27,7 @@ class TestTorchCompile(unittest.TestCase):
@classmethod
def tearDownClass(cls):
kill_child_process(cls.process.pid)
kill_child_process(cls.process.pid, include_self=True)
def test_mmlu(self):
args = SimpleNamespace(

View File

@@ -50,7 +50,7 @@ class TestTritonAttnBackend(unittest.TestCase):
metrics = run_eval(args)
assert metrics["score"] >= 0.65
finally:
kill_child_process(process.pid)
kill_child_process(process.pid, include_self=True)
if __name__ == "__main__":

View File

@@ -23,7 +23,7 @@ class TestUpdateWeights(unittest.TestCase):
@classmethod
def tearDownClass(cls):
kill_child_process(cls.process.pid)
kill_child_process(cls.process.pid, include_self=True)
def run_decode(self):
response = requests.post(

View File

@@ -45,7 +45,7 @@ class TestOpenAIVisionServer(unittest.TestCase):
@classmethod
def tearDownClass(cls):
kill_child_process(cls.process.pid)
kill_child_process(cls.process.pid, include_self=True)
def test_chat_completion(self):
client = openai.Client(api_key=self.api_key, base_url=self.base_url)