From 33b54e7c40a58d64c988c067e3e607f96a04ae58 Mon Sep 17 00:00:00 2001
From: Ke Bao
Date: Sat, 14 Sep 2024 23:15:30 +0800
Subject: [PATCH] Add pytorch sampling backend ut (#1425)

---
 test/srt/run_suite.py                     |  1 +
 test/srt/test_pytorch_sampling_backend.py | 51 +++++++++++++++++++++++++++
 2 files changed, 52 insertions(+)
 create mode 100644 test/srt/test_pytorch_sampling_backend.py

diff --git a/test/srt/run_suite.py b/test/srt/run_suite.py
index d5982844c..943c50144 100644
--- a/test/srt/run_suite.py
+++ b/test/srt/run_suite.py
@@ -17,6 +17,7 @@ suites = {
         "test_skip_tokenizer_init.py",
         "test_torch_compile.py",
         "test_triton_attn_backend.py",
+        "test_pytorch_sampling_backend.py",
         "test_update_weights.py",
         "test_vision_openai_server.py",
         "test_server_args.py",
diff --git a/test/srt/test_pytorch_sampling_backend.py b/test/srt/test_pytorch_sampling_backend.py
new file mode 100644
index 000000000..ddd744149
--- /dev/null
+++ b/test/srt/test_pytorch_sampling_backend.py
@@ -0,0 +1,51 @@
+import unittest
+from types import SimpleNamespace
+
+from sglang.srt.utils import kill_child_process
+from sglang.test.run_eval import run_eval
+from sglang.test.test_utils import (
+    DEFAULT_MODEL_NAME_FOR_TEST,
+    DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
+    DEFAULT_URL_FOR_TEST,
+    popen_launch_server,
+)
+
+
+class TestPyTorchSamplingBackend(unittest.TestCase):
+    """MMLU eval run with ``--sampling-backend pytorch`` as launch args."""
+
+    @classmethod
+    def setUpClass(cls):
+        """Call popen_launch_server once and keep its result for the class."""
+        cls.model = DEFAULT_MODEL_NAME_FOR_TEST
+        cls.base_url = DEFAULT_URL_FOR_TEST
+        cls.process = popen_launch_server(
+            cls.model,
+            cls.base_url,
+            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
+            other_args=["--sampling-backend", "pytorch"],
+        )
+
+    @classmethod
+    def tearDownClass(cls):
+        """Pass the pid of the process from setUpClass to kill_child_process."""
+        kill_child_process(cls.process.pid)
+
+    def test_mmlu(self):
+        """Run the mmlu eval on 64 examples and require a score >= 0.65."""
+        args = SimpleNamespace(
+            base_url=self.base_url,
+            model=self.model,
+            eval_name="mmlu",
+            num_examples=64,
+            num_threads=32,
+        )
+
+        metrics = run_eval(args)
+        # Unlike a bare assert, this is not stripped under ``python -O`` and
+        # the failure message reports the actual score.
+        self.assertGreaterEqual(metrics["score"], 0.65)
+
+
+if __name__ == "__main__":
+    unittest.main()