From 86fc0d79d0b564fba1c313feafd15323ba731418 Mon Sep 17 00:00:00 2001 From: Lianmin Zheng Date: Sun, 27 Oct 2024 02:00:50 -0700 Subject: [PATCH] Add a watch dog thread (#1816) --- python/sglang/bench_latency.py | 2 +- python/sglang/bench_server_latency.py | 5 +-- python/sglang/launch_server.py | 2 +- python/sglang/srt/managers/scheduler.py | 38 ++++++++++++++++--- python/sglang/srt/server.py | 12 +++--- python/sglang/srt/server_args.py | 7 ++++ python/sglang/srt/utils.py | 27 ++++++++----- python/sglang/test/test_utils.py | 10 ++--- .../test_srt_endpoint_with_penalizers.py | 2 +- test/srt/test_cache_report.py | 2 +- test/srt/test_data_parallelism.py | 2 +- test/srt/test_double_sparsity.py | 2 +- test/srt/test_embedding_openai_server.py | 2 +- test/srt/test_eval_accuracy_large.py | 2 +- ...est_eval_accuracy_large_chunked_prefill.py | 2 +- ...al_accuracy_large_mixed_chunked_prefill.py | 2 +- test/srt/test_eval_accuracy_mini.py | 2 +- test/srt/test_json_constrained.py | 2 +- test/srt/test_large_max_new_tokens.py | 2 +- test/srt/test_matched_stop.py | 2 +- test/srt/test_mla.py | 2 +- test/srt/test_mla_fp8.py | 2 +- test/srt/test_moe_eval_accuracy_large.py | 2 +- test/srt/test_nightly_gsm8k_eval.py | 2 +- test/srt/test_openai_server.py | 2 +- test/srt/test_pytorch_sampling_backend.py | 2 +- test/srt/test_retract_decode.py | 2 +- test/srt/test_skip_tokenizer_init.py | 2 +- test/srt/test_srt_endpoint.py | 2 +- test/srt/test_torch_compile.py | 2 +- test/srt/test_torchao.py | 2 +- test/srt/test_triton_attn_backend.py | 2 +- test/srt/test_update_weights.py | 2 +- test/srt/test_vision_openai_server.py | 2 +- 34 files changed, 99 insertions(+), 56 deletions(-) diff --git a/python/sglang/bench_latency.py b/python/sglang/bench_latency.py index 43cb7bc3f..d97b641ea 100644 --- a/python/sglang/bench_latency.py +++ b/python/sglang/bench_latency.py @@ -550,4 +550,4 @@ if __name__ == "__main__": except Exception as e: raise e finally: - kill_child_process(os.getpid(), including_parent=False) + kill_child_process() diff --git a/python/sglang/bench_server_latency.py b/python/sglang/bench_server_latency.py index 57506913f..f76682c9f 100644 --- a/python/sglang/bench_server_latency.py +++ b/python/sglang/bench_server_latency.py @@ -15,7 +15,6 @@ import dataclasses import itertools import json import multiprocessing -import os import time from typing import Tuple @@ -70,7 +69,7 @@ def launch_server_internal(server_args): except Exception as e: raise e finally: - kill_child_process(os.getpid(), including_parent=False) + kill_child_process() def launch_server_process(server_args: ServerArgs): @@ -176,7 +175,7 @@ def run_benchmark(server_args: ServerArgs, bench_args: BenchArgs): ) finally: if proc: - kill_child_process(proc.pid) + kill_child_process(proc.pid, include_self=True) print(f"\nResults are saved to {bench_args.result_filename}") diff --git a/python/sglang/launch_server.py b/python/sglang/launch_server.py index ce4cb07c2..57f1dd10e 100644 --- a/python/sglang/launch_server.py +++ b/python/sglang/launch_server.py @@ -15,4 +15,4 @@ if __name__ == "__main__": except Exception as e: raise e finally: - kill_child_process(os.getpid(), including_parent=False) + kill_child_process() diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index 4677568c4..f876847e1 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -18,6 +18,7 @@ limitations under the License. import json import logging import os +import threading import time import warnings from collections import deque @@ -222,10 +223,11 @@ class Scheduler: self.waiting_queue: List[Req] = [] self.running_batch: Optional[ScheduleBatch] = None self.cur_batch: Optional[ScheduleBatch] = None - self.decode_forward_ct = 0 - self.stream_interval = server_args.stream_interval + self.forward_ct = 0 + self.forward_ct_decode = 0 self.num_generated_tokens = 0 self.last_stats_tic = time.time() + self.stream_interval = server_args.stream_interval # Init chunked prefill self.chunked_prefill_size = server_args.chunked_prefill_size @@ -272,6 +274,11 @@ class Scheduler: self.batch_is_full = False + # Init watchdog thread + self.watchdog_timeout = server_args.watchdog_timeout + t = threading.Thread(target=self.watchdog_thread, daemon=True) + t.start() + # Init profiler if os.getenv("SGLANG_TORCH_PROFILER_DIR", "") == "": self.profiler = None @@ -289,6 +296,23 @@ class Scheduler: with_stack=True, ) + def watchdog_thread(self): + self.watchdog_last_forward_ct = 0 + self.watchdog_last_time = time.time() + + while True: + if self.cur_batch is not None: + if self.watchdog_last_forward_ct == self.forward_ct: + if time.time() > self.watchdog_last_time + self.watchdog_timeout: + logger.error(f"Watchdog timeout ({self.watchdog_timeout=})") + break + else: + self.watchdog_last_forward_ct = self.forward_ct + self.watchdog_last_time = time.time() + time.sleep(self.watchdog_timeout / 2) + + kill_parent_process() + @torch.inference_mode() def event_loop_normal(self): """A normal blocking scheduler loop.""" @@ -299,6 +323,7 @@ class Scheduler: self.process_input_requests(recv_reqs) batch = self.get_next_batch_to_run() + self.cur_batch = batch if batch: result = self.run_batch(batch) @@ -746,6 +771,8 @@ class Scheduler: def run_batch(self, batch: ScheduleBatch): """Run a batch.""" + self.forward_ct += 1 + if self.is_generation: if batch.forward_mode.is_decode() or batch.extend_num_tokens != 0: model_worker_batch = batch.get_model_worker_batch() @@ -778,6 +805,7 @@ class Scheduler: self.process_batch_result_prefill(batch, result) def process_batch_result_prefill(self, batch: ScheduleBatch, result): + if self.is_generation: logits_output, next_token_ids, bid = result @@ -890,8 +918,8 @@ class Scheduler: self.token_to_kv_pool.free_group_end() - self.decode_forward_ct = (self.decode_forward_ct + 1) % (1 << 30) - if self.tp_rank == 0 and self.decode_forward_ct % 40 == 0: + self.forward_ct_decode = (self.forward_ct_decode + 1) % (1 << 30) + if self.tp_rank == 0 and self.forward_ct_decode % 40 == 0: self.print_decode_stats() def add_logprob_return_values( @@ -984,7 +1012,7 @@ class Scheduler: else: # embedding or reward model output_embeddings = [] - is_stream_iter = self.decode_forward_ct % self.stream_interval == 0 + is_stream_iter = self.forward_ct_decode % self.stream_interval == 0 for req in reqs: if req.finished() or ( diff --git a/python/sglang/srt/server.py b/python/sglang/srt/server.py index 8912c5583..7dffcc4a2 100644 --- a/python/sglang/srt/server.py +++ b/python/sglang/srt/server.py @@ -441,7 +441,7 @@ def launch_server( # Send a warmup request t = threading.Thread( - target=_wait_and_warmup, args=(server_args, pipe_finish_writer, os.getpid()) + target=_wait_and_warmup, args=(server_args, pipe_finish_writer) ) t.start() @@ -496,7 +496,7 @@ def _set_envs_and_config(server_args: ServerArgs): mp.set_start_method("spawn", force=True) -def _wait_and_warmup(server_args, pipe_finish_writer, pid): +def _wait_and_warmup(server_args, pipe_finish_writer): headers = {} url = server_args.url() if server_args.api_key: @@ -519,7 +519,7 @@ def _wait_and_warmup(server_args, pipe_finish_writer, pid): if pipe_finish_writer is not None: pipe_finish_writer.send(last_traceback) logger.error(f"Initialization failed. warmup error: {last_traceback}") - kill_child_process(pid, including_parent=False) + kill_child_process(include_self=True) return model_info = res.json() @@ -551,7 +551,7 @@ def _wait_and_warmup(server_args, pipe_finish_writer, pid): if pipe_finish_writer is not None: pipe_finish_writer.send(last_traceback) logger.error(f"Initialization failed. warmup error: {last_traceback}") - kill_child_process(pid, including_parent=False) + kill_child_process(include_self=True) return # logger.info(f"{res.json()=}") @@ -617,7 +617,7 @@ class Runtime: def shutdown(self): if self.pid is not None: - kill_child_process(self.pid) + kill_child_process(self.pid, include_self=True) self.pid = None def cache_prefix(self, prefix: str): @@ -834,7 +834,7 @@ class Engine: return ret def shutdown(self): - kill_child_process(os.getpid(), including_parent=False) + kill_child_process(include_self=True) def get_tokenizer(self): global tokenizer_manager diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index 753debb66..7d23cb8bd 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -74,6 +74,7 @@ class ServerArgs: api_key: Optional[str] = None file_storage_pth: str = "SGLang_storage" enable_cache_report: bool = False + watchdog_timeout: float = 600 # Data parallelism dp_size: int = 1 @@ -429,6 +430,12 @@ class ServerArgs: action="store_true", help="Return number of cached tokens in usage.prompt_tokens_details for each openai request.", ) + parser.add_argument( + "--watchdog-timeout", + type=float, + default=ServerArgs.watchdog_timeout, + help="Set watchdog timeout in seconds. If a forward batch takes longer than this, the server will crash to prevent hanging.", + ) # Data parallelism parser.add_argument( diff --git a/python/sglang/srt/utils.py b/python/sglang/srt/utils.py index 6ad39647f..2be3a298e 100644 --- a/python/sglang/srt/utils.py +++ b/python/sglang/srt/utils.py @@ -398,17 +398,26 @@ def kill_parent_process(): """Kill the parent process and all children of the parent process.""" current_process = psutil.Process() parent_process = current_process.parent() - kill_child_process(parent_process.pid, skip_pid=current_process.pid) - - -def kill_child_process(pid, including_parent=True, skip_pid=None): - """Kill the process and all its children process.""" + kill_child_process( + parent_process.pid, include_self=True, skip_pid=current_process.pid + ) try: - parent = psutil.Process(pid) + current_process.kill() + except psutil.NoSuchProcess: + pass + + +def kill_child_process(pid=None, include_self=False, skip_pid=None): + """Kill the process and all its children process.""" + if pid is None: + pid = os.getpid() + + try: + itself = psutil.Process(pid) except psutil.NoSuchProcess: return - children = parent.children(recursive=True) + children = itself.children(recursive=True) for child in children: if child.pid == skip_pid: continue @@ -417,9 +426,9 @@ def kill_child_process(pid, including_parent=True, skip_pid=None): except psutil.NoSuchProcess: pass - if including_parent: + if include_self: try: - parent.kill() + itself.kill() except psutil.NoSuchProcess: pass diff --git a/python/sglang/test/test_utils.py b/python/sglang/test/test_utils.py index 4a5a894c0..d6a4c1a29 100644 --- a/python/sglang/test/test_utils.py +++ b/python/sglang/test/test_utils.py @@ -495,7 +495,7 @@ def run_unittest_files(files: List[str], timeout_per_file: float): ) assert ret_code == 0 except TimeoutError: - kill_child_process(process.pid) + kill_child_process(process.pid, include_self=True) time.sleep(5) print( f"\nTimeout after {timeout_per_file} seconds when running {filename}\n", @@ -563,7 +563,7 @@ def run_bench_serving( try: res = run_benchmark(args) finally: - kill_child_process(process.pid) + kill_child_process(process.pid, include_self=True) assert res["completed"] == num_prompts return res @@ -596,7 +596,7 @@ def run_bench_latency(model, other_args): lastline = output.split("\n")[-3] output_throughput = float(lastline.split(" ")[-2]) finally: - kill_child_process(process.pid) + kill_child_process(process.pid, include_self=True) return output_throughput @@ -707,8 +707,8 @@ def run_mmlu_test( pass # Clean up everything - kill_child_process(process.pid) - kill_child_process(process.pid) + kill_child_process(process.pid, include_self=True) + kill_child_process(process.pid, include_self=True) stdout.close() stderr.close() if os.path.exists(STDOUT_FILENAME): diff --git a/test/srt/sampling/penaltylib/test_srt_endpoint_with_penalizers.py b/test/srt/sampling/penaltylib/test_srt_endpoint_with_penalizers.py index e3496102c..689d52a1c 100644 --- a/test/srt/sampling/penaltylib/test_srt_endpoint_with_penalizers.py +++ b/test/srt/sampling/penaltylib/test_srt_endpoint_with_penalizers.py @@ -31,7 +31,7 @@ class TestBatchPenalizerE2E(unittest.TestCase): @classmethod def tearDownClass(cls): - kill_child_process(cls.process.pid) + kill_child_process(cls.process.pid, include_self=True) def run_decode( self, diff --git a/test/srt/test_cache_report.py b/test/srt/test_cache_report.py index 1d8e9a4a0..dfc140d58 100644 --- a/test/srt/test_cache_report.py +++ b/test/srt/test_cache_report.py @@ -45,7 +45,7 @@ class TestCacheReport(unittest.TestCase): @classmethod def tearDownClass(cls): - kill_child_process(cls.process.pid) + kill_child_process(cls.process.pid, include_self=True) def run_decode(self, return_logprob=False, top_logprobs_num=0, n=1): response = requests.post( diff --git a/test/srt/test_data_parallelism.py b/test/srt/test_data_parallelism.py index a921a6b57..5f17994a2 100644 --- a/test/srt/test_data_parallelism.py +++ b/test/srt/test_data_parallelism.py @@ -25,7 +25,7 @@ class TestDataParallelism(unittest.TestCase): @classmethod def tearDownClass(cls): - kill_child_process(cls.process.pid) + kill_child_process(cls.process.pid, include_self=True) def test_mmlu(self): args = SimpleNamespace( diff --git a/test/srt/test_double_sparsity.py b/test/srt/test_double_sparsity.py index 0f2f572eb..14ee4de3c 100644 --- a/test/srt/test_double_sparsity.py +++ b/test/srt/test_double_sparsity.py @@ -43,7 +43,7 @@ class TestDoubleSparsity(unittest.TestCase): @classmethod def tearDownClass(cls): - kill_child_process(cls.process.pid) + kill_child_process(cls.process.pid, include_self=True) def test_mmlu(self): args = SimpleNamespace( diff --git a/test/srt/test_embedding_openai_server.py b/test/srt/test_embedding_openai_server.py index 45f7850da..666297c65 100644 --- a/test/srt/test_embedding_openai_server.py +++ b/test/srt/test_embedding_openai_server.py @@ -28,7 +28,7 @@ class TestOpenAIServer(unittest.TestCase): @classmethod def tearDownClass(cls): - kill_child_process(cls.process.pid) + kill_child_process(cls.process.pid, include_self=True) def run_embedding(self, use_list_input, token_input): client = openai.Client(api_key=self.api_key, base_url=self.base_url) diff --git a/test/srt/test_eval_accuracy_large.py b/test/srt/test_eval_accuracy_large.py index 0b95f435c..000910cf2 100644 --- a/test/srt/test_eval_accuracy_large.py +++ b/test/srt/test_eval_accuracy_large.py @@ -30,7 +30,7 @@ class TestEvalAccuracyLarge(unittest.TestCase): @classmethod def tearDownClass(cls): - kill_child_process(cls.process.pid) + kill_child_process(cls.process.pid, include_self=True) def test_mmlu(self): args = SimpleNamespace( diff --git a/test/srt/test_eval_accuracy_large_chunked_prefill.py b/test/srt/test_eval_accuracy_large_chunked_prefill.py index 02df2a7f5..2e9ff59cd 100644 --- a/test/srt/test_eval_accuracy_large_chunked_prefill.py +++ b/test/srt/test_eval_accuracy_large_chunked_prefill.py @@ -25,7 +25,7 @@ class TestEvalAccuracyLargeChunkedPrefill(unittest.TestCase): @classmethod def tearDownClass(cls): - kill_child_process(cls.process.pid) + kill_child_process(cls.process.pid, include_self=True) def test_mmlu(self): args = SimpleNamespace( diff --git a/test/srt/test_eval_accuracy_large_mixed_chunked_prefill.py b/test/srt/test_eval_accuracy_large_mixed_chunked_prefill.py index 8ba71e5c8..0fb08e64f 100644 --- a/test/srt/test_eval_accuracy_large_mixed_chunked_prefill.py +++ b/test/srt/test_eval_accuracy_large_mixed_chunked_prefill.py @@ -31,7 +31,7 @@ class TestEvalAccuracyLargeChunkedPrefill(unittest.TestCase): @classmethod def tearDownClass(cls): - kill_child_process(cls.process.pid) + kill_child_process(cls.process.pid, include_self=True) def test_mmlu(self): args = SimpleNamespace( diff --git a/test/srt/test_eval_accuracy_mini.py b/test/srt/test_eval_accuracy_mini.py index ee977a636..fa15c1181 100644 --- a/test/srt/test_eval_accuracy_mini.py +++ b/test/srt/test_eval_accuracy_mini.py @@ -22,7 +22,7 @@ class TestEvalAccuracyMini(unittest.TestCase): @classmethod def tearDownClass(cls): - kill_child_process(cls.process.pid) + kill_child_process(cls.process.pid, include_self=True) def test_mmlu(self): args = SimpleNamespace( diff --git a/test/srt/test_json_constrained.py b/test/srt/test_json_constrained.py index c054d7234..88368fba8 100644 --- a/test/srt/test_json_constrained.py +++ b/test/srt/test_json_constrained.py @@ -41,7 +41,7 @@ class TestJSONConstrained(unittest.TestCase): @classmethod def tearDownClass(cls): - kill_child_process(cls.process.pid) + kill_child_process(cls.process.pid, include_self=True) def run_decode(self, json_schema, return_logprob=False, top_logprobs_num=0, n=1): response = requests.post( diff --git a/test/srt/test_large_max_new_tokens.py b/test/srt/test_large_max_new_tokens.py index 24c011c75..ea9c20e5c 100644 --- a/test/srt/test_large_max_new_tokens.py +++ b/test/srt/test_large_max_new_tokens.py @@ -42,7 +42,7 @@ class TestLargeMaxNewTokens(unittest.TestCase): @classmethod def tearDownClass(cls): - kill_child_process(cls.process.pid) + kill_child_process(cls.process.pid, include_self=True) cls.stdout.close() cls.stderr.close() os.remove("stdout.txt") diff --git a/test/srt/test_matched_stop.py b/test/srt/test_matched_stop.py index a3399687d..df37fa13c 100644 --- a/test/srt/test_matched_stop.py +++ b/test/srt/test_matched_stop.py @@ -32,7 +32,7 @@ class TestMatchedStop(unittest.TestCase): @classmethod def tearDownClass(cls): - kill_child_process(cls.process.pid) + kill_child_process(cls.process.pid, include_self=True) def run_completions_generation( self, diff --git a/test/srt/test_mla.py b/test/srt/test_mla.py index 13b0aa2d8..796655adb 100644 --- a/test/srt/test_mla.py +++ b/test/srt/test_mla.py @@ -25,7 +25,7 @@ class TestMLA(unittest.TestCase): @classmethod def tearDownClass(cls): - kill_child_process(cls.process.pid) + kill_child_process(cls.process.pid, include_self=True) def test_mmlu(self): args = SimpleNamespace( diff --git a/test/srt/test_mla_fp8.py b/test/srt/test_mla_fp8.py index 37275d696..5091759a9 100644 --- a/test/srt/test_mla_fp8.py +++ b/test/srt/test_mla_fp8.py @@ -31,7 +31,7 @@ class TestMLA(unittest.TestCase): @classmethod def tearDownClass(cls): - kill_child_process(cls.process.pid) + kill_child_process(cls.process.pid, include_self=True) def test_mgsm_en(self): args = SimpleNamespace( diff --git a/test/srt/test_moe_eval_accuracy_large.py b/test/srt/test_moe_eval_accuracy_large.py index 5f2560526..401a47ce2 100644 --- a/test/srt/test_moe_eval_accuracy_large.py +++ b/test/srt/test_moe_eval_accuracy_large.py @@ -35,7 +35,7 @@ class TestMoEEvalAccuracyLarge(unittest.TestCase): @classmethod def tearDownClass(cls): - kill_child_process(cls.process.pid) + kill_child_process(cls.process.pid, include_self=True) def test_mmlu(self): args = SimpleNamespace( diff --git a/test/srt/test_nightly_gsm8k_eval.py b/test/srt/test_nightly_gsm8k_eval.py index 8b8e0e16b..b035db52b 100644 --- a/test/srt/test_nightly_gsm8k_eval.py +++ b/test/srt/test_nightly_gsm8k_eval.py @@ -36,7 +36,7 @@ class TestEvalAccuracyLarge(unittest.TestCase): def tearDown(self): if self.process: - kill_child_process(self.process.pid) + kill_child_process(self.process.pid, include_self=True) def launch_server(self, model, is_fp8, is_tp2): other_args = ["--log-level-http", "warning", "--trust-remote-code"] diff --git a/test/srt/test_openai_server.py b/test/srt/test_openai_server.py index 5afe9b0b1..d3e21d04b 100644 --- a/test/srt/test_openai_server.py +++ b/test/srt/test_openai_server.py @@ -31,7 +31,7 @@ class TestOpenAIServer(unittest.TestCase): @classmethod def tearDownClass(cls): - kill_child_process(cls.process.pid) + kill_child_process(cls.process.pid, include_self=True) def run_completion( self, echo, logprobs, use_list_input, parallel_sample_num, token_input diff --git a/test/srt/test_pytorch_sampling_backend.py b/test/srt/test_pytorch_sampling_backend.py index 5dbb9ae2b..f7affa8ac 100644 --- a/test/srt/test_pytorch_sampling_backend.py +++ b/test/srt/test_pytorch_sampling_backend.py @@ -27,7 +27,7 @@ class TestPyTorchSamplingBackend(unittest.TestCase): @classmethod def tearDownClass(cls): - kill_child_process(cls.process.pid) + kill_child_process(cls.process.pid, include_self=True) def test_mmlu(self): args = SimpleNamespace( diff --git a/test/srt/test_retract_decode.py b/test/srt/test_retract_decode.py index b16fd5163..20352e729 100644 --- a/test/srt/test_retract_decode.py +++ b/test/srt/test_retract_decode.py @@ -22,7 +22,7 @@ class TestRetractDecode(unittest.TestCase): @classmethod def tearDownClass(cls): - kill_child_process(cls.process.pid) + kill_child_process(cls.process.pid, include_self=True) def test_mmlu(self): args = SimpleNamespace( diff --git a/test/srt/test_skip_tokenizer_init.py b/test/srt/test_skip_tokenizer_init.py index 3a8c34c16..a5dcde4a2 100644 --- a/test/srt/test_skip_tokenizer_init.py +++ b/test/srt/test_skip_tokenizer_init.py @@ -26,7 +26,7 @@ class TestSkipTokenizerInit(unittest.TestCase): @classmethod def tearDownClass(cls): - kill_child_process(cls.process.pid) + kill_child_process(cls.process.pid, include_self=True) def run_decode(self, return_logprob=False, top_logprobs_num=0, n=1): max_new_tokens = 32 diff --git a/test/srt/test_srt_endpoint.py b/test/srt/test_srt_endpoint.py index c4c8e844d..e1b5318c0 100644 --- a/test/srt/test_srt_endpoint.py +++ b/test/srt/test_srt_endpoint.py @@ -27,7 +27,7 @@ class TestSRTEndpoint(unittest.TestCase): @classmethod def tearDownClass(cls): - kill_child_process(cls.process.pid) + kill_child_process(cls.process.pid, include_self=True) def run_decode( self, diff --git a/test/srt/test_torch_compile.py b/test/srt/test_torch_compile.py index 40f47d6b6..f5f4b602e 100644 --- a/test/srt/test_torch_compile.py +++ b/test/srt/test_torch_compile.py @@ -27,7 +27,7 @@ class TestTorchCompile(unittest.TestCase): @classmethod def tearDownClass(cls): - kill_child_process(cls.process.pid) + kill_child_process(cls.process.pid, include_self=True) def test_mmlu(self): args = SimpleNamespace( diff --git a/test/srt/test_torchao.py b/test/srt/test_torchao.py index 8b5ce58ed..765567136 100644 --- a/test/srt/test_torchao.py +++ b/test/srt/test_torchao.py @@ -27,7 +27,7 @@ class TestTorchCompile(unittest.TestCase): @classmethod def tearDownClass(cls): - kill_child_process(cls.process.pid) + kill_child_process(cls.process.pid, include_self=True) def test_mmlu(self): args = SimpleNamespace( diff --git a/test/srt/test_triton_attn_backend.py b/test/srt/test_triton_attn_backend.py index 55df1951f..2a6fe17bd 100644 --- a/test/srt/test_triton_attn_backend.py +++ b/test/srt/test_triton_attn_backend.py @@ -50,7 +50,7 @@ class TestTritonAttnBackend(unittest.TestCase): metrics = run_eval(args) assert metrics["score"] >= 0.65 finally: - kill_child_process(process.pid) + kill_child_process(process.pid, include_self=True) if __name__ == "__main__": diff --git a/test/srt/test_update_weights.py b/test/srt/test_update_weights.py index 73c3cc706..c3cde0f14 100644 --- a/test/srt/test_update_weights.py +++ b/test/srt/test_update_weights.py @@ -23,7 +23,7 @@ class TestUpdateWeights(unittest.TestCase): @classmethod def tearDownClass(cls): - kill_child_process(cls.process.pid) + kill_child_process(cls.process.pid, include_self=True) def run_decode(self): response = requests.post( diff --git a/test/srt/test_vision_openai_server.py b/test/srt/test_vision_openai_server.py index bf8f9d277..f44bc98e2 100644 --- a/test/srt/test_vision_openai_server.py +++ b/test/srt/test_vision_openai_server.py @@ -45,7 +45,7 @@ class TestOpenAIVisionServer(unittest.TestCase): @classmethod def tearDownClass(cls): - kill_child_process(cls.process.pid) + kill_child_process(cls.process.pid, include_self=True) def test_chat_completion(self): client = openai.Client(api_key=self.api_key, base_url=self.base_url)