[CI] Fix the issue of unit test hanging (#1211)

This commit is contained in:
Ying Sheng
2024-08-25 16:21:37 -07:00
committed by GitHub
parent ab4990e4bf
commit 308d024092
4 changed files with 27 additions and 15 deletions

View File

@@ -460,24 +460,25 @@ def run_with_timeout(
return ret_value[0] return ret_value[0]
def run_one_file(filename, out_queue):
print(f"\n\nRun {filename}\n\n")
ret = unittest.main(module=None, argv=["", "-vb"] + [filename])
def run_unittest_files(files: List[str], timeout_per_file: float): def run_unittest_files(files: List[str], timeout_per_file: float):
tic = time.time() tic = time.time()
success = True success = True
for filename in files: for filename in files:
out_queue = multiprocessing.Queue()
p = multiprocessing.Process(target=run_one_file, args=(filename, out_queue))
def func(): def run_process():
print(f"\n\nRun {filename}\n\n")
ret = unittest.main(module=None, argv=["", "-vb"] + [filename])
p = multiprocessing.Process(target=func)
def run_one_file():
p.start() p.start()
p.join() p.join()
try: try:
run_with_timeout(run_one_file, timeout=timeout_per_file) run_with_timeout(run_process, timeout=timeout_per_file)
if p.exitcode != 0: if p.exitcode != 0:
success = False success = False
break break

View File

@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. limitations under the License.
""" """
import multiprocessing as mp
import unittest import unittest
import torch import torch
@@ -71,4 +72,9 @@ class TestEmbeddingModels(unittest.TestCase):
if __name__ == "__main__": if __name__ == "__main__":
try:
mp.set_start_method("spawn")
except RuntimeError:
pass
unittest.main(warnings="ignore") unittest.main(warnings="ignore")

View File

@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. limitations under the License.
""" """
import multiprocessing as mp
import unittest import unittest
import torch import torch
@@ -108,13 +109,6 @@ class TestGenerationModels(unittest.TestCase):
), f"Not all ROUGE-L scores are greater than {rouge_threshold}" ), f"Not all ROUGE-L scores are greater than {rouge_threshold}"
def test_prefill_logits_and_output_strs(self): def test_prefill_logits_and_output_strs(self):
import multiprocessing as mp
try:
mp.set_start_method("spawn")
except RuntimeError:
pass
for ( for (
model, model,
tp_size, tp_size,
@@ -137,4 +131,9 @@ class TestGenerationModels(unittest.TestCase):
if __name__ == "__main__": if __name__ == "__main__":
try:
mp.set_start_method("spawn")
except RuntimeError:
pass
unittest.main(warnings="ignore") unittest.main(warnings="ignore")

View File

@@ -1,5 +1,6 @@
import argparse import argparse
import glob import glob
import multiprocessing as mp
from sglang.test.test_utils import run_unittest_files from sglang.test.test_utils import run_unittest_files
@@ -54,5 +55,10 @@ if __name__ == "__main__":
else: else:
files = suites[args.suite] files = suites[args.suite]
try:
mp.set_start_method("spawn")
except RuntimeError:
pass
exit_code = run_unittest_files(files, args.timeout_per_file) exit_code = run_unittest_files(files, args.timeout_per_file)
exit(exit_code) exit(exit_code)