diff --git a/python/sglang/test/test_utils.py b/python/sglang/test/test_utils.py index 9f6aa68ab..ac19d9370 100644 --- a/python/sglang/test/test_utils.py +++ b/python/sglang/test/test_utils.py @@ -460,24 +460,25 @@ def run_with_timeout( return ret_value[0] +def run_one_file(filename, out_queue): + print(f"\n\nRun {filename}\n\n") + ret = unittest.main(module=None, argv=["", "-vb"] + [filename]) + + def run_unittest_files(files: List[str], timeout_per_file: float): tic = time.time() success = True for filename in files: + out_queue = multiprocessing.Queue() + p = multiprocessing.Process(target=run_one_file, args=(filename, out_queue)) - def func(): - print(f"\n\nRun {filename}\n\n") - ret = unittest.main(module=None, argv=["", "-vb"] + [filename]) - - p = multiprocessing.Process(target=func) - - def run_one_file(): + def run_process(): p.start() p.join() try: - run_with_timeout(run_one_file, timeout=timeout_per_file) + run_with_timeout(run_process, timeout=timeout_per_file) if p.exitcode != 0: success = False break diff --git a/test/srt/models/test_embedding_models.py b/test/srt/models/test_embedding_models.py index ecb3e7576..8a43255b7 100644 --- a/test/srt/models/test_embedding_models.py +++ b/test/srt/models/test_embedding_models.py @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. """ +import multiprocessing as mp import unittest import torch @@ -71,4 +72,9 @@ class TestEmbeddingModels(unittest.TestCase): if __name__ == "__main__": + try: + mp.set_start_method("spawn") + except RuntimeError: + pass + unittest.main(warnings="ignore") diff --git a/test/srt/models/test_generation_models.py b/test/srt/models/test_generation_models.py index 7e7e401d2..4e49c0a5b 100644 --- a/test/srt/models/test_generation_models.py +++ b/test/srt/models/test_generation_models.py @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. """ +import multiprocessing as mp import unittest import torch @@ -108,13 +109,6 @@ class TestGenerationModels(unittest.TestCase): ), f"Not all ROUGE-L scores are greater than {rouge_threshold}" def test_prefill_logits_and_output_strs(self): - import multiprocessing as mp - - try: - mp.set_start_method("spawn") - except RuntimeError: - pass - for ( model, tp_size, @@ -137,4 +131,9 @@ class TestGenerationModels(unittest.TestCase): if __name__ == "__main__": + try: + mp.set_start_method("spawn") + except RuntimeError: + pass + unittest.main(warnings="ignore") diff --git a/test/srt/run_suite.py b/test/srt/run_suite.py index e8edbb550..3756d3ddf 100644 --- a/test/srt/run_suite.py +++ b/test/srt/run_suite.py @@ -1,5 +1,6 @@ import argparse import glob +import multiprocessing as mp from sglang.test.test_utils import run_unittest_files @@ -54,5 +55,10 @@ if __name__ == "__main__": else: files = suites[args.suite] + try: + mp.set_start_method("spawn") + except RuntimeError: + pass + exit_code = run_unittest_files(files, args.timeout_per_file) exit(exit_code)