convert test_deterministic into unit tests (#11095)
Signed-off-by: Alex Chi Z <iskyzh@gmail.com> Co-authored-by: Baizhou Zhang <sobereddiezhang@gmail.com>
This commit is contained in:
@@ -39,12 +39,15 @@ class BenchArgs:
|
||||
profile_steps: int = 3
|
||||
profile_by_stage: bool = False
|
||||
test_mode: str = "single"
|
||||
n_trials: int = 50
|
||||
n_start: int = 1
|
||||
|
||||
@staticmethod
|
||||
def add_cli_args(parser: argparse.ArgumentParser):
|
||||
parser.add_argument("--host", type=str, default=BenchArgs.host)
|
||||
parser.add_argument("--port", type=int, default=BenchArgs.port)
|
||||
parser.add_argument("--n-trials", type=int, default=50)
|
||||
parser.add_argument("--n-trials", type=int, default=BenchArgs.n_trials)
|
||||
parser.add_argument("--n-start", type=int, default=BenchArgs.n_start)
|
||||
parser.add_argument("--temperature", type=float, default=BenchArgs.temperature)
|
||||
parser.add_argument(
|
||||
"--sampling-seed", type=int, default=BenchArgs.sampling_seed
|
||||
@@ -238,6 +241,8 @@ def test_deterministic(args):
|
||||
texts.append(text)
|
||||
|
||||
print(f"Total samples: {len(texts)}, Unique samples: {len(set(texts))}")
|
||||
return [len(set(texts))]
|
||||
|
||||
elif args.test_mode == "mixed":
|
||||
# In mixed mode, we send a mixture of two short prompts and one long prompt in the same batch with batch size ranging from 1 to n_trials.
|
||||
output_prompt_1 = []
|
||||
@@ -264,13 +269,19 @@ def test_deterministic(args):
|
||||
f"Long prompt: total samples: {len(output_long_prompt)}, Unique samples: {len(set(output_long_prompt))}"
|
||||
)
|
||||
|
||||
return [
|
||||
len(set(output_prompt_1)),
|
||||
len(set(output_prompt_2)),
|
||||
len(set(output_long_prompt)),
|
||||
]
|
||||
|
||||
elif args.test_mode == "prefix":
|
||||
# In prefix mode, we create prompts from the same long prompt, with different lengths of common prefix.
|
||||
len_prefix = [1, 511, 2048, 4097]
|
||||
num_prompts = len(len_prefix)
|
||||
outputs = {i: [] for i in range(4)}
|
||||
prompts = [LONG_PROMPT[: len_prefix[i]] for i in range(4)]
|
||||
for i in range(1, args.n_trials + 1):
|
||||
for i in range(args.n_start, args.n_start + args.n_trials):
|
||||
batch_size = i
|
||||
ret_dict = send_prefix(args, batch_size, prompts)
|
||||
msg = f"Testing Trial {i} with batch size {batch_size},"
|
||||
@@ -285,6 +296,11 @@ def test_deterministic(args):
|
||||
f"Prompt {i} with prefix length {len_prefix[i]}: total samples: {len(outputs[i])}, Unique samples: {len(set(outputs[i]))}"
|
||||
)
|
||||
|
||||
results = []
|
||||
for i in range(num_prompts):
|
||||
results.append(len(set(outputs[i])))
|
||||
return results
|
||||
|
||||
else:
|
||||
raise ValueError(f"Invalid test mode: {args.test_mode}")
|
||||
|
||||
|
||||
81
python/sglang/test/test_deterministic_utils.py
Normal file
81
python/sglang/test/test_deterministic_utils.py
Normal file
@@ -0,0 +1,81 @@
|
||||
import time
|
||||
import unittest
|
||||
|
||||
import requests
|
||||
|
||||
from sglang.srt.utils import kill_process_tree
|
||||
from sglang.test.test_deterministic import BenchArgs, test_deterministic
|
||||
from sglang.test.test_utils import (
|
||||
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
||||
DEFAULT_URL_FOR_TEST,
|
||||
CustomTestCase,
|
||||
popen_launch_server,
|
||||
)
|
||||
|
||||
DEFAULT_MODEL = "Qwen/Qwen3-8B"
|
||||
COMMON_SERVER_ARGS = [
|
||||
"--trust-remote-code",
|
||||
"--cuda-graph-max-bs",
|
||||
"32",
|
||||
"--enable-deterministic-inference",
|
||||
]
|
||||
|
||||
|
||||
class TestDeterministicBase(CustomTestCase):
|
||||
@classmethod
|
||||
def get_server_args(cls):
|
||||
return COMMON_SERVER_ARGS
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
cls.model = DEFAULT_MODEL
|
||||
cls.base_url = DEFAULT_URL_FOR_TEST
|
||||
if "--attention-backend" not in cls.get_server_args():
|
||||
raise unittest.SkipTest("Skip the base test class")
|
||||
|
||||
cls.process = popen_launch_server(
|
||||
cls.model,
|
||||
cls.base_url,
|
||||
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
||||
other_args=cls.get_server_args(),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def tearDownClass(cls):
|
||||
kill_process_tree(cls.process.pid)
|
||||
|
||||
def _extract_host_and_port(self, url):
|
||||
return url.split("://")[-1].split(":")[0], int(url.split(":")[-1])
|
||||
|
||||
def test_single(self):
|
||||
args = BenchArgs()
|
||||
url = DEFAULT_URL_FOR_TEST
|
||||
args.host, args.port = self._extract_host_and_port(url)
|
||||
args.test_mode = "single"
|
||||
args.n_start = 10
|
||||
args.n_trials = 20
|
||||
results = test_deterministic(args)
|
||||
for result in results:
|
||||
assert result == 1
|
||||
|
||||
def test_mixed(self):
|
||||
args = BenchArgs()
|
||||
url = DEFAULT_URL_FOR_TEST
|
||||
args.host, args.port = self._extract_host_and_port(url)
|
||||
args.test_mode = "mixed"
|
||||
args.n_start = 10
|
||||
args.n_trials = 20
|
||||
results = test_deterministic(args)
|
||||
for result in results:
|
||||
assert result == 1
|
||||
|
||||
def test_prefix(self):
|
||||
args = BenchArgs()
|
||||
url = DEFAULT_URL_FOR_TEST
|
||||
args.host, args.port = self._extract_host_and_port(url)
|
||||
args.test_mode = "prefix"
|
||||
args.n_start = 10
|
||||
args.n_trials = 10
|
||||
results = test_deterministic(args)
|
||||
for result in results:
|
||||
assert result == 1
|
||||
Reference in New Issue
Block a user