From 9b4c4497356d7cd7d8211f53eefd6dd933e12117 Mon Sep 17 00:00:00 2001 From: Alex Chi Z <4198311+skyzh@users.noreply.github.com> Date: Tue, 7 Oct 2025 05:33:11 +0200 Subject: [PATCH] convert test_deterministic into unit tests (#11095) Signed-off-by: Alex Chi Z Co-authored-by: Baizhou Zhang --- python/sglang/test/test_deterministic.py | 20 ++++- .../sglang/test/test_deterministic_utils.py | 81 +++++++++++++++++++ test/srt/run_suite.py | 1 + test/srt/test_deterministic.py | 70 ++++++++++++++++ 4 files changed, 170 insertions(+), 2 deletions(-) create mode 100644 python/sglang/test/test_deterministic_utils.py create mode 100644 test/srt/test_deterministic.py diff --git a/python/sglang/test/test_deterministic.py b/python/sglang/test/test_deterministic.py index 3f56b6539..8c513cb6a 100644 --- a/python/sglang/test/test_deterministic.py +++ b/python/sglang/test/test_deterministic.py @@ -39,12 +39,15 @@ class BenchArgs: profile_steps: int = 3 profile_by_stage: bool = False test_mode: str = "single" + n_trials: int = 50 + n_start: int = 1 @staticmethod def add_cli_args(parser: argparse.ArgumentParser): parser.add_argument("--host", type=str, default=BenchArgs.host) parser.add_argument("--port", type=int, default=BenchArgs.port) - parser.add_argument("--n-trials", type=int, default=50) + parser.add_argument("--n-trials", type=int, default=BenchArgs.n_trials) + parser.add_argument("--n-start", type=int, default=BenchArgs.n_start) parser.add_argument("--temperature", type=float, default=BenchArgs.temperature) parser.add_argument( "--sampling-seed", type=int, default=BenchArgs.sampling_seed @@ -238,6 +241,8 @@ def test_deterministic(args): texts.append(text) print(f"Total samples: {len(texts)}, Unique samples: {len(set(texts))}") + return [len(set(texts))] + elif args.test_mode == "mixed": # In mixed mode, we send a mixture of two short prompts and one long prompt in the same batch with batch size ranging from 1 to n_trials. 
output_prompt_1 = [] @@ -264,13 +269,19 @@ def test_deterministic(args): f"Long prompt: total samples: {len(output_long_prompt)}, Unique samples: {len(set(output_long_prompt))}" ) + return [ + len(set(output_prompt_1)), + len(set(output_prompt_2)), + len(set(output_long_prompt)), + ] + elif args.test_mode == "prefix": # In prefix mode, we create prompts from the same long prompt, with different lengths of common prefix. len_prefix = [1, 511, 2048, 4097] num_prompts = len(len_prefix) outputs = {i: [] for i in range(4)} prompts = [LONG_PROMPT[: len_prefix[i]] for i in range(4)] - for i in range(1, args.n_trials + 1): + for i in range(args.n_start, args.n_start + args.n_trials): batch_size = i ret_dict = send_prefix(args, batch_size, prompts) msg = f"Testing Trial {i} with batch size {batch_size}," @@ -285,6 +296,11 @@ def test_deterministic(args): f"Prompt {i} with prefix length {len_prefix[i]}: total samples: {len(outputs[i])}, Unique samples: {len(set(outputs[i]))}" ) + results = [] + for i in range(num_prompts): + results.append(len(set(outputs[i]))) + return results + else: raise ValueError(f"Invalid test mode: {args.test_mode}") diff --git a/python/sglang/test/test_deterministic_utils.py b/python/sglang/test/test_deterministic_utils.py new file mode 100644 index 000000000..0c1607686 --- /dev/null +++ b/python/sglang/test/test_deterministic_utils.py @@ -0,0 +1,81 @@ +import time +import unittest + +import requests + +from sglang.srt.utils import kill_process_tree +from sglang.test.test_deterministic import BenchArgs, test_deterministic +from sglang.test.test_utils import ( + DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + DEFAULT_URL_FOR_TEST, + CustomTestCase, + popen_launch_server, +) + +DEFAULT_MODEL = "Qwen/Qwen3-8B" +COMMON_SERVER_ARGS = [ + "--trust-remote-code", + "--cuda-graph-max-bs", + "32", + "--enable-deterministic-inference", +] + + +class TestDeterministicBase(CustomTestCase): + @classmethod + def get_server_args(cls): + return COMMON_SERVER_ARGS + + 
@classmethod + def setUpClass(cls): + cls.model = DEFAULT_MODEL + cls.base_url = DEFAULT_URL_FOR_TEST + if "--attention-backend" not in cls.get_server_args(): + raise unittest.SkipTest("Skip the base test class") + + cls.process = popen_launch_server( + cls.model, + cls.base_url, + timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + other_args=cls.get_server_args(), + ) + + @classmethod + def tearDownClass(cls): + kill_process_tree(cls.process.pid) + + def _extract_host_and_port(self, url): + return url.split("://")[-1].split(":")[0], int(url.split(":")[-1]) + + def test_single(self): + args = BenchArgs() + url = DEFAULT_URL_FOR_TEST + args.host, args.port = self._extract_host_and_port(url) + args.test_mode = "single" + args.n_start = 10 + args.n_trials = 20 + results = test_deterministic(args) + for result in results: + assert result == 1 + + def test_mixed(self): + args = BenchArgs() + url = DEFAULT_URL_FOR_TEST + args.host, args.port = self._extract_host_and_port(url) + args.test_mode = "mixed" + args.n_start = 10 + args.n_trials = 20 + results = test_deterministic(args) + for result in results: + assert result == 1 + + def test_prefix(self): + args = BenchArgs() + url = DEFAULT_URL_FOR_TEST + args.host, args.port = self._extract_host_and_port(url) + args.test_mode = "prefix" + args.n_start = 10 + args.n_trials = 10 + results = test_deterministic(args) + for result in results: + assert result == 1 diff --git a/test/srt/run_suite.py b/test/srt/run_suite.py index cab0b527d..4d97e97c0 100644 --- a/test/srt/run_suite.py +++ b/test/srt/run_suite.py @@ -67,6 +67,7 @@ suites = { TestFile("test_abort.py", 51), TestFile("test_create_kvindices.py", 2), TestFile("test_chunked_prefill.py", 313), + TestFile("test_deterministic.py", 300), TestFile("test_eagle_infer_a.py", 370), TestFile("test_eagle_infer_b.py", 700), TestFile("test_ebnf_constrained.py", 108), diff --git a/test/srt/test_deterministic.py b/test/srt/test_deterministic.py new file mode 100644 index 000000000..f0fcc426b 
--- /dev/null +++ b/test/srt/test_deterministic.py @@ -0,0 +1,70 @@ +""" +Usage: +cd test/srt +python3 -m unittest test_deterministic.TestDeterministic.TESTCASE + +Note that there is also `python/sglang/test/test_deterministic.py` as an interactive test. We are converting that +test into unit tests so that it's easily reproducible in CI. +""" + +import unittest + +from sglang.srt.utils import kill_process_tree +from sglang.test.test_deterministic import BenchArgs, test_deterministic +from sglang.test.test_deterministic_utils import ( + COMMON_SERVER_ARGS, + DEFAULT_MODEL, + TestDeterministicBase, +) +from sglang.test.test_utils import ( + DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + DEFAULT_URL_FOR_TEST, + CustomTestCase, + popen_launch_server, +) + + +class TestFlashinferDeterministic(TestDeterministicBase): + # Test with flashinfer attention backend + @classmethod + def get_server_args(cls): + args = COMMON_SERVER_ARGS + args.extend( + [ + "--attention-backend", + "flashinfer", + ] + ) + return args + + +class TestFa3Deterministic(TestDeterministicBase): + # Test with fa3 attention backend + @classmethod + def get_server_args(cls): + args = COMMON_SERVER_ARGS + args.extend( + [ + "--attention-backend", + "fa3", + ] + ) + return args + + +class TestTritonDeterministic(TestDeterministicBase): + # Test with triton attention backend + @classmethod + def get_server_args(cls): + args = COMMON_SERVER_ARGS + args.extend( + [ + "--attention-backend", + "triton", + ] + ) + return args + + +if __name__ == "__main__": + unittest.main()