From 13f1357ef000e0d9dcf6f13dd178d809126f3ac7 Mon Sep 17 00:00:00 2001 From: Lianmin Zheng Date: Sun, 22 Sep 2024 02:21:05 -0700 Subject: [PATCH] Add a unit test for data parallelism (#1489) --- .github/workflows/pr-test.yml | 8 +++++- test/srt/test_data_parallelism.py | 44 +++++++++++++++++++++++++++++++ 2 files changed, 51 insertions(+), 1 deletion(-) create mode 100644 test/srt/test_data_parallelism.py diff --git a/.github/workflows/pr-test.yml b/.github/workflows/pr-test.yml index 6173fad7a..fcd54fdba 100644 --- a/.github/workflows/pr-test.yml +++ b/.github/workflows/pr-test.yml @@ -228,7 +228,7 @@ jobs: cd human-eval pip install -e . - - name: Evaluate Accuracy + - name: Evaluate Accuracy (TP=2) timeout-minutes: 20 run: | cd test/srt @@ -240,6 +240,12 @@ jobs: cd test/srt python3 test_mla.py + - name: Evaluate Data Parallelism Accuracy (DP=2) + timeout-minutes: 10 + run: | + cd test/srt + python3 test_data_parallelism.py + finish: needs: [ unit-test-frontend, unit-test-backend-part-1, unit-test-backend-part-2, unit-test-backend-part-3, diff --git a/test/srt/test_data_parallelism.py b/test/srt/test_data_parallelism.py new file mode 100644 index 000000000..a921a6b57 --- /dev/null +++ b/test/srt/test_data_parallelism.py @@ -0,0 +1,44 @@ +import unittest +from types import SimpleNamespace + +from sglang.srt.utils import kill_child_process +from sglang.test.run_eval import run_eval +from sglang.test.test_utils import ( + DEFAULT_MODEL_NAME_FOR_TEST, + DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + DEFAULT_URL_FOR_TEST, + popen_launch_server, +) + + +class TestDataParallelism(unittest.TestCase): + @classmethod + def setUpClass(cls): + cls.model = DEFAULT_MODEL_NAME_FOR_TEST + cls.base_url = DEFAULT_URL_FOR_TEST + cls.process = popen_launch_server( + cls.model, + cls.base_url, + timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + other_args=["--dp", "2"], + ) + + @classmethod + def tearDownClass(cls): + kill_child_process(cls.process.pid) + + def test_mmlu(self): + args = 
SimpleNamespace( + base_url=self.base_url, + model=self.model, + eval_name="mmlu", + num_examples=64, + num_threads=32, + ) + + metrics = run_eval(args) + assert metrics["score"] >= 0.65 + + +if __name__ == "__main__": + unittest.main()