Add a watch dog thread (#1816)
This commit is contained in:
@@ -550,4 +550,4 @@ if __name__ == "__main__":
|
||||
except Exception as e:
|
||||
raise e
|
||||
finally:
|
||||
kill_child_process(os.getpid(), including_parent=False)
|
||||
kill_child_process()
|
||||
|
||||
@@ -15,7 +15,6 @@ import dataclasses
|
||||
import itertools
|
||||
import json
|
||||
import multiprocessing
|
||||
import os
|
||||
import time
|
||||
from typing import Tuple
|
||||
|
||||
@@ -70,7 +69,7 @@ def launch_server_internal(server_args):
|
||||
except Exception as e:
|
||||
raise e
|
||||
finally:
|
||||
kill_child_process(os.getpid(), including_parent=False)
|
||||
kill_child_process()
|
||||
|
||||
|
||||
def launch_server_process(server_args: ServerArgs):
|
||||
@@ -176,7 +175,7 @@ def run_benchmark(server_args: ServerArgs, bench_args: BenchArgs):
|
||||
)
|
||||
finally:
|
||||
if proc:
|
||||
kill_child_process(proc.pid)
|
||||
kill_child_process(proc.pid, include_self=True)
|
||||
|
||||
print(f"\nResults are saved to {bench_args.result_filename}")
|
||||
|
||||
|
||||
@@ -15,4 +15,4 @@ if __name__ == "__main__":
|
||||
except Exception as e:
|
||||
raise e
|
||||
finally:
|
||||
kill_child_process(os.getpid(), including_parent=False)
|
||||
kill_child_process()
|
||||
|
||||
@@ -18,6 +18,7 @@ limitations under the License.
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import threading
|
||||
import time
|
||||
import warnings
|
||||
from collections import deque
|
||||
@@ -222,10 +223,11 @@ class Scheduler:
|
||||
self.waiting_queue: List[Req] = []
|
||||
self.running_batch: Optional[ScheduleBatch] = None
|
||||
self.cur_batch: Optional[ScheduleBatch] = None
|
||||
self.decode_forward_ct = 0
|
||||
self.stream_interval = server_args.stream_interval
|
||||
self.forward_ct = 0
|
||||
self.forward_ct_decode = 0
|
||||
self.num_generated_tokens = 0
|
||||
self.last_stats_tic = time.time()
|
||||
self.stream_interval = server_args.stream_interval
|
||||
|
||||
# Init chunked prefill
|
||||
self.chunked_prefill_size = server_args.chunked_prefill_size
|
||||
@@ -272,6 +274,11 @@ class Scheduler:
|
||||
|
||||
self.batch_is_full = False
|
||||
|
||||
# Init watchdog thread
|
||||
self.watchdog_timeout = server_args.watchdog_timeout
|
||||
t = threading.Thread(target=self.watchdog_thread, daemon=True)
|
||||
t.start()
|
||||
|
||||
# Init profiler
|
||||
if os.getenv("SGLANG_TORCH_PROFILER_DIR", "") == "":
|
||||
self.profiler = None
|
||||
@@ -289,6 +296,23 @@ class Scheduler:
|
||||
with_stack=True,
|
||||
)
|
||||
|
||||
def watchdog_thread(self):
|
||||
self.watchdog_last_forward_ct = 0
|
||||
self.watchdog_last_time = time.time()
|
||||
|
||||
while True:
|
||||
if self.cur_batch is not None:
|
||||
if self.watchdog_last_forward_ct == self.forward_ct:
|
||||
if time.time() > self.watchdog_last_time + self.watchdog_timeout:
|
||||
logger.error(f"Watchdog timeout ({self.watchdog_timeout=})")
|
||||
break
|
||||
else:
|
||||
self.watchdog_last_forward_ct = self.forward_ct
|
||||
self.watchdog_last_time = time.time()
|
||||
time.sleep(self.watchdog_timeout / 2)
|
||||
|
||||
kill_parent_process()
|
||||
|
||||
@torch.inference_mode()
|
||||
def event_loop_normal(self):
|
||||
"""A normal blocking scheduler loop."""
|
||||
@@ -299,6 +323,7 @@ class Scheduler:
|
||||
self.process_input_requests(recv_reqs)
|
||||
|
||||
batch = self.get_next_batch_to_run()
|
||||
self.cur_batch = batch
|
||||
|
||||
if batch:
|
||||
result = self.run_batch(batch)
|
||||
@@ -746,6 +771,8 @@ class Scheduler:
|
||||
|
||||
def run_batch(self, batch: ScheduleBatch):
|
||||
"""Run a batch."""
|
||||
self.forward_ct += 1
|
||||
|
||||
if self.is_generation:
|
||||
if batch.forward_mode.is_decode() or batch.extend_num_tokens != 0:
|
||||
model_worker_batch = batch.get_model_worker_batch()
|
||||
@@ -778,6 +805,7 @@ class Scheduler:
|
||||
self.process_batch_result_prefill(batch, result)
|
||||
|
||||
def process_batch_result_prefill(self, batch: ScheduleBatch, result):
|
||||
|
||||
if self.is_generation:
|
||||
logits_output, next_token_ids, bid = result
|
||||
|
||||
@@ -890,8 +918,8 @@ class Scheduler:
|
||||
|
||||
self.token_to_kv_pool.free_group_end()
|
||||
|
||||
self.decode_forward_ct = (self.decode_forward_ct + 1) % (1 << 30)
|
||||
if self.tp_rank == 0 and self.decode_forward_ct % 40 == 0:
|
||||
self.forward_ct_decode = (self.forward_ct_decode + 1) % (1 << 30)
|
||||
if self.tp_rank == 0 and self.forward_ct_decode % 40 == 0:
|
||||
self.print_decode_stats()
|
||||
|
||||
def add_logprob_return_values(
|
||||
@@ -984,7 +1012,7 @@ class Scheduler:
|
||||
else: # embedding or reward model
|
||||
output_embeddings = []
|
||||
|
||||
is_stream_iter = self.decode_forward_ct % self.stream_interval == 0
|
||||
is_stream_iter = self.forward_ct_decode % self.stream_interval == 0
|
||||
|
||||
for req in reqs:
|
||||
if req.finished() or (
|
||||
|
||||
@@ -441,7 +441,7 @@ def launch_server(
|
||||
|
||||
# Send a warmup request
|
||||
t = threading.Thread(
|
||||
target=_wait_and_warmup, args=(server_args, pipe_finish_writer, os.getpid())
|
||||
target=_wait_and_warmup, args=(server_args, pipe_finish_writer)
|
||||
)
|
||||
t.start()
|
||||
|
||||
@@ -496,7 +496,7 @@ def _set_envs_and_config(server_args: ServerArgs):
|
||||
mp.set_start_method("spawn", force=True)
|
||||
|
||||
|
||||
def _wait_and_warmup(server_args, pipe_finish_writer, pid):
|
||||
def _wait_and_warmup(server_args, pipe_finish_writer):
|
||||
headers = {}
|
||||
url = server_args.url()
|
||||
if server_args.api_key:
|
||||
@@ -519,7 +519,7 @@ def _wait_and_warmup(server_args, pipe_finish_writer, pid):
|
||||
if pipe_finish_writer is not None:
|
||||
pipe_finish_writer.send(last_traceback)
|
||||
logger.error(f"Initialization failed. warmup error: {last_traceback}")
|
||||
kill_child_process(pid, including_parent=False)
|
||||
kill_child_process(include_self=True)
|
||||
return
|
||||
|
||||
model_info = res.json()
|
||||
@@ -551,7 +551,7 @@ def _wait_and_warmup(server_args, pipe_finish_writer, pid):
|
||||
if pipe_finish_writer is not None:
|
||||
pipe_finish_writer.send(last_traceback)
|
||||
logger.error(f"Initialization failed. warmup error: {last_traceback}")
|
||||
kill_child_process(pid, including_parent=False)
|
||||
kill_child_process(include_self=True)
|
||||
return
|
||||
|
||||
# logger.info(f"{res.json()=}")
|
||||
@@ -617,7 +617,7 @@ class Runtime:
|
||||
|
||||
def shutdown(self):
|
||||
if self.pid is not None:
|
||||
kill_child_process(self.pid)
|
||||
kill_child_process(self.pid, include_self=True)
|
||||
self.pid = None
|
||||
|
||||
def cache_prefix(self, prefix: str):
|
||||
@@ -834,7 +834,7 @@ class Engine:
|
||||
return ret
|
||||
|
||||
def shutdown(self):
|
||||
kill_child_process(os.getpid(), including_parent=False)
|
||||
kill_child_process(include_self=True)
|
||||
|
||||
def get_tokenizer(self):
|
||||
global tokenizer_manager
|
||||
|
||||
@@ -74,6 +74,7 @@ class ServerArgs:
|
||||
api_key: Optional[str] = None
|
||||
file_storage_pth: str = "SGLang_storage"
|
||||
enable_cache_report: bool = False
|
||||
watchdog_timeout: float = 600
|
||||
|
||||
# Data parallelism
|
||||
dp_size: int = 1
|
||||
@@ -429,6 +430,12 @@ class ServerArgs:
|
||||
action="store_true",
|
||||
help="Return number of cached tokens in usage.prompt_tokens_details for each openai request.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--watchdog-timeout",
|
||||
type=float,
|
||||
default=ServerArgs.watchdog_timeout,
|
||||
help="Set watchdog timeout in seconds. If a forward batch takes longer than this, the server will crash to prevent hanging.",
|
||||
)
|
||||
|
||||
# Data parallelism
|
||||
parser.add_argument(
|
||||
|
||||
@@ -398,17 +398,26 @@ def kill_parent_process():
|
||||
"""Kill the parent process and all children of the parent process."""
|
||||
current_process = psutil.Process()
|
||||
parent_process = current_process.parent()
|
||||
kill_child_process(parent_process.pid, skip_pid=current_process.pid)
|
||||
|
||||
|
||||
def kill_child_process(pid, including_parent=True, skip_pid=None):
|
||||
"""Kill the process and all its children process."""
|
||||
kill_child_process(
|
||||
parent_process.pid, include_self=True, skip_pid=current_process.pid
|
||||
)
|
||||
try:
|
||||
parent = psutil.Process(pid)
|
||||
current_process.kill()
|
||||
except psutil.NoSuchProcess:
|
||||
pass
|
||||
|
||||
|
||||
def kill_child_process(pid=None, include_self=False, skip_pid=None):
|
||||
"""Kill the process and all its children process."""
|
||||
if pid is None:
|
||||
pid = os.getpid()
|
||||
|
||||
try:
|
||||
itself = psutil.Process(pid)
|
||||
except psutil.NoSuchProcess:
|
||||
return
|
||||
|
||||
children = parent.children(recursive=True)
|
||||
children = itself.children(recursive=True)
|
||||
for child in children:
|
||||
if child.pid == skip_pid:
|
||||
continue
|
||||
@@ -417,9 +426,9 @@ def kill_child_process(pid, including_parent=True, skip_pid=None):
|
||||
except psutil.NoSuchProcess:
|
||||
pass
|
||||
|
||||
if including_parent:
|
||||
if include_self:
|
||||
try:
|
||||
parent.kill()
|
||||
itself.kill()
|
||||
except psutil.NoSuchProcess:
|
||||
pass
|
||||
|
||||
|
||||
@@ -495,7 +495,7 @@ def run_unittest_files(files: List[str], timeout_per_file: float):
|
||||
)
|
||||
assert ret_code == 0
|
||||
except TimeoutError:
|
||||
kill_child_process(process.pid)
|
||||
kill_child_process(process.pid, include_self=True)
|
||||
time.sleep(5)
|
||||
print(
|
||||
f"\nTimeout after {timeout_per_file} seconds when running {filename}\n",
|
||||
@@ -563,7 +563,7 @@ def run_bench_serving(
|
||||
try:
|
||||
res = run_benchmark(args)
|
||||
finally:
|
||||
kill_child_process(process.pid)
|
||||
kill_child_process(process.pid, include_self=True)
|
||||
|
||||
assert res["completed"] == num_prompts
|
||||
return res
|
||||
@@ -596,7 +596,7 @@ def run_bench_latency(model, other_args):
|
||||
lastline = output.split("\n")[-3]
|
||||
output_throughput = float(lastline.split(" ")[-2])
|
||||
finally:
|
||||
kill_child_process(process.pid)
|
||||
kill_child_process(process.pid, include_self=True)
|
||||
|
||||
return output_throughput
|
||||
|
||||
@@ -707,8 +707,8 @@ def run_mmlu_test(
|
||||
pass
|
||||
|
||||
# Clean up everything
|
||||
kill_child_process(process.pid)
|
||||
kill_child_process(process.pid)
|
||||
kill_child_process(process.pid, include_self=True)
|
||||
kill_child_process(process.pid, include_self=True)
|
||||
stdout.close()
|
||||
stderr.close()
|
||||
if os.path.exists(STDOUT_FILENAME):
|
||||
|
||||
@@ -31,7 +31,7 @@ class TestBatchPenalizerE2E(unittest.TestCase):
|
||||
|
||||
@classmethod
|
||||
def tearDownClass(cls):
|
||||
kill_child_process(cls.process.pid)
|
||||
kill_child_process(cls.process.pid, include_self=True)
|
||||
|
||||
def run_decode(
|
||||
self,
|
||||
|
||||
@@ -45,7 +45,7 @@ class TestCacheReport(unittest.TestCase):
|
||||
|
||||
@classmethod
|
||||
def tearDownClass(cls):
|
||||
kill_child_process(cls.process.pid)
|
||||
kill_child_process(cls.process.pid, include_self=True)
|
||||
|
||||
def run_decode(self, return_logprob=False, top_logprobs_num=0, n=1):
|
||||
response = requests.post(
|
||||
|
||||
@@ -25,7 +25,7 @@ class TestDataParallelism(unittest.TestCase):
|
||||
|
||||
@classmethod
|
||||
def tearDownClass(cls):
|
||||
kill_child_process(cls.process.pid)
|
||||
kill_child_process(cls.process.pid, include_self=True)
|
||||
|
||||
def test_mmlu(self):
|
||||
args = SimpleNamespace(
|
||||
|
||||
@@ -43,7 +43,7 @@ class TestDoubleSparsity(unittest.TestCase):
|
||||
|
||||
@classmethod
|
||||
def tearDownClass(cls):
|
||||
kill_child_process(cls.process.pid)
|
||||
kill_child_process(cls.process.pid, include_self=True)
|
||||
|
||||
def test_mmlu(self):
|
||||
args = SimpleNamespace(
|
||||
|
||||
@@ -28,7 +28,7 @@ class TestOpenAIServer(unittest.TestCase):
|
||||
|
||||
@classmethod
|
||||
def tearDownClass(cls):
|
||||
kill_child_process(cls.process.pid)
|
||||
kill_child_process(cls.process.pid, include_self=True)
|
||||
|
||||
def run_embedding(self, use_list_input, token_input):
|
||||
client = openai.Client(api_key=self.api_key, base_url=self.base_url)
|
||||
|
||||
@@ -30,7 +30,7 @@ class TestEvalAccuracyLarge(unittest.TestCase):
|
||||
|
||||
@classmethod
|
||||
def tearDownClass(cls):
|
||||
kill_child_process(cls.process.pid)
|
||||
kill_child_process(cls.process.pid, include_self=True)
|
||||
|
||||
def test_mmlu(self):
|
||||
args = SimpleNamespace(
|
||||
|
||||
@@ -25,7 +25,7 @@ class TestEvalAccuracyLargeChunkedPrefill(unittest.TestCase):
|
||||
|
||||
@classmethod
|
||||
def tearDownClass(cls):
|
||||
kill_child_process(cls.process.pid)
|
||||
kill_child_process(cls.process.pid, include_self=True)
|
||||
|
||||
def test_mmlu(self):
|
||||
args = SimpleNamespace(
|
||||
|
||||
@@ -31,7 +31,7 @@ class TestEvalAccuracyLargeChunkedPrefill(unittest.TestCase):
|
||||
|
||||
@classmethod
|
||||
def tearDownClass(cls):
|
||||
kill_child_process(cls.process.pid)
|
||||
kill_child_process(cls.process.pid, include_self=True)
|
||||
|
||||
def test_mmlu(self):
|
||||
args = SimpleNamespace(
|
||||
|
||||
@@ -22,7 +22,7 @@ class TestEvalAccuracyMini(unittest.TestCase):
|
||||
|
||||
@classmethod
|
||||
def tearDownClass(cls):
|
||||
kill_child_process(cls.process.pid)
|
||||
kill_child_process(cls.process.pid, include_self=True)
|
||||
|
||||
def test_mmlu(self):
|
||||
args = SimpleNamespace(
|
||||
|
||||
@@ -41,7 +41,7 @@ class TestJSONConstrained(unittest.TestCase):
|
||||
|
||||
@classmethod
|
||||
def tearDownClass(cls):
|
||||
kill_child_process(cls.process.pid)
|
||||
kill_child_process(cls.process.pid, include_self=True)
|
||||
|
||||
def run_decode(self, json_schema, return_logprob=False, top_logprobs_num=0, n=1):
|
||||
response = requests.post(
|
||||
|
||||
@@ -42,7 +42,7 @@ class TestLargeMaxNewTokens(unittest.TestCase):
|
||||
|
||||
@classmethod
|
||||
def tearDownClass(cls):
|
||||
kill_child_process(cls.process.pid)
|
||||
kill_child_process(cls.process.pid, include_self=True)
|
||||
cls.stdout.close()
|
||||
cls.stderr.close()
|
||||
os.remove("stdout.txt")
|
||||
|
||||
@@ -32,7 +32,7 @@ class TestMatchedStop(unittest.TestCase):
|
||||
|
||||
@classmethod
|
||||
def tearDownClass(cls):
|
||||
kill_child_process(cls.process.pid)
|
||||
kill_child_process(cls.process.pid, include_self=True)
|
||||
|
||||
def run_completions_generation(
|
||||
self,
|
||||
|
||||
@@ -25,7 +25,7 @@ class TestMLA(unittest.TestCase):
|
||||
|
||||
@classmethod
|
||||
def tearDownClass(cls):
|
||||
kill_child_process(cls.process.pid)
|
||||
kill_child_process(cls.process.pid, include_self=True)
|
||||
|
||||
def test_mmlu(self):
|
||||
args = SimpleNamespace(
|
||||
|
||||
@@ -31,7 +31,7 @@ class TestMLA(unittest.TestCase):
|
||||
|
||||
@classmethod
|
||||
def tearDownClass(cls):
|
||||
kill_child_process(cls.process.pid)
|
||||
kill_child_process(cls.process.pid, include_self=True)
|
||||
|
||||
def test_mgsm_en(self):
|
||||
args = SimpleNamespace(
|
||||
|
||||
@@ -35,7 +35,7 @@ class TestMoEEvalAccuracyLarge(unittest.TestCase):
|
||||
|
||||
@classmethod
|
||||
def tearDownClass(cls):
|
||||
kill_child_process(cls.process.pid)
|
||||
kill_child_process(cls.process.pid, include_self=True)
|
||||
|
||||
def test_mmlu(self):
|
||||
args = SimpleNamespace(
|
||||
|
||||
@@ -36,7 +36,7 @@ class TestEvalAccuracyLarge(unittest.TestCase):
|
||||
|
||||
def tearDown(self):
|
||||
if self.process:
|
||||
kill_child_process(self.process.pid)
|
||||
kill_child_process(self.process.pid, include_self=True)
|
||||
|
||||
def launch_server(self, model, is_fp8, is_tp2):
|
||||
other_args = ["--log-level-http", "warning", "--trust-remote-code"]
|
||||
|
||||
@@ -31,7 +31,7 @@ class TestOpenAIServer(unittest.TestCase):
|
||||
|
||||
@classmethod
|
||||
def tearDownClass(cls):
|
||||
kill_child_process(cls.process.pid)
|
||||
kill_child_process(cls.process.pid, include_self=True)
|
||||
|
||||
def run_completion(
|
||||
self, echo, logprobs, use_list_input, parallel_sample_num, token_input
|
||||
|
||||
@@ -27,7 +27,7 @@ class TestPyTorchSamplingBackend(unittest.TestCase):
|
||||
|
||||
@classmethod
|
||||
def tearDownClass(cls):
|
||||
kill_child_process(cls.process.pid)
|
||||
kill_child_process(cls.process.pid, include_self=True)
|
||||
|
||||
def test_mmlu(self):
|
||||
args = SimpleNamespace(
|
||||
|
||||
@@ -22,7 +22,7 @@ class TestRetractDecode(unittest.TestCase):
|
||||
|
||||
@classmethod
|
||||
def tearDownClass(cls):
|
||||
kill_child_process(cls.process.pid)
|
||||
kill_child_process(cls.process.pid, include_self=True)
|
||||
|
||||
def test_mmlu(self):
|
||||
args = SimpleNamespace(
|
||||
|
||||
@@ -26,7 +26,7 @@ class TestSkipTokenizerInit(unittest.TestCase):
|
||||
|
||||
@classmethod
|
||||
def tearDownClass(cls):
|
||||
kill_child_process(cls.process.pid)
|
||||
kill_child_process(cls.process.pid, include_self=True)
|
||||
|
||||
def run_decode(self, return_logprob=False, top_logprobs_num=0, n=1):
|
||||
max_new_tokens = 32
|
||||
|
||||
@@ -27,7 +27,7 @@ class TestSRTEndpoint(unittest.TestCase):
|
||||
|
||||
@classmethod
|
||||
def tearDownClass(cls):
|
||||
kill_child_process(cls.process.pid)
|
||||
kill_child_process(cls.process.pid, include_self=True)
|
||||
|
||||
def run_decode(
|
||||
self,
|
||||
|
||||
@@ -27,7 +27,7 @@ class TestTorchCompile(unittest.TestCase):
|
||||
|
||||
@classmethod
|
||||
def tearDownClass(cls):
|
||||
kill_child_process(cls.process.pid)
|
||||
kill_child_process(cls.process.pid, include_self=True)
|
||||
|
||||
def test_mmlu(self):
|
||||
args = SimpleNamespace(
|
||||
|
||||
@@ -27,7 +27,7 @@ class TestTorchCompile(unittest.TestCase):
|
||||
|
||||
@classmethod
|
||||
def tearDownClass(cls):
|
||||
kill_child_process(cls.process.pid)
|
||||
kill_child_process(cls.process.pid, include_self=True)
|
||||
|
||||
def test_mmlu(self):
|
||||
args = SimpleNamespace(
|
||||
|
||||
@@ -50,7 +50,7 @@ class TestTritonAttnBackend(unittest.TestCase):
|
||||
metrics = run_eval(args)
|
||||
assert metrics["score"] >= 0.65
|
||||
finally:
|
||||
kill_child_process(process.pid)
|
||||
kill_child_process(process.pid, include_self=True)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@@ -23,7 +23,7 @@ class TestUpdateWeights(unittest.TestCase):
|
||||
|
||||
@classmethod
|
||||
def tearDownClass(cls):
|
||||
kill_child_process(cls.process.pid)
|
||||
kill_child_process(cls.process.pid, include_self=True)
|
||||
|
||||
def run_decode(self):
|
||||
response = requests.post(
|
||||
|
||||
@@ -45,7 +45,7 @@ class TestOpenAIVisionServer(unittest.TestCase):
|
||||
|
||||
@classmethod
|
||||
def tearDownClass(cls):
|
||||
kill_child_process(cls.process.pid)
|
||||
kill_child_process(cls.process.pid, include_self=True)
|
||||
|
||||
def test_chat_completion(self):
|
||||
client = openai.Client(api_key=self.api_key, base_url=self.base_url)
|
||||
|
||||
Reference in New Issue
Block a user