70 lines
2.0 KiB
Python
70 lines
2.0 KiB
Python
import unittest
|
|
|
|
from sglang.srt.utils import kill_process_tree
|
|
from sglang.test.test_deterministic import BenchArgs, test_deterministic
|
|
from sglang.test.test_utils import (
|
|
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
|
DEFAULT_URL_FOR_TEST,
|
|
CustomTestCase,
|
|
popen_launch_server,
|
|
)
|
|
|
|
# Model identifier used as the default model for the launched test server.
DEFAULT_MODEL = "Qwen/Qwen3-8B"

# Command-line arguments shared by every deterministic-inference server launch.
# NOTE(review): "32" is presumably the value of the preceding
# "--cuda-graph-max-bs" flag; confirm against the server's CLI.
COMMON_SERVER_ARGS = [
    "--trust-remote-code",
    "--cuda-graph-max-bs",
    "32",
    "--enable-deterministic-inference",
]
|
|
|
|
|
|
class TestDeterministicBase(CustomTestCase):
    """Shared test logic for deterministic-inference server tests.

    The class is skipped unless ``get_server_args()`` contains
    ``"--attention-backend"``; with the default arguments it does not, so the
    tests only run in subclasses whose ``get_server_args`` adds that flag.
    """

    @classmethod
    def get_server_args(cls):
        """Return the extra command-line arguments for the launched server.

        A copy of ``COMMON_SERVER_ARGS`` is returned so that a caller that
        extends the result does not modify the shared module-level list.
        """
        return list(COMMON_SERVER_ARGS)

    @classmethod
    def setUpClass(cls):
        """Launch the server once for all tests of the class.

        Raises:
            unittest.SkipTest: if the server arguments do not contain
                ``"--attention-backend"``.
        """
        cls.model = DEFAULT_MODEL
        cls.base_url = DEFAULT_URL_FOR_TEST

        server_args = cls.get_server_args()
        if "--attention-backend" not in server_args:
            raise unittest.SkipTest("Skip the base test class")

        cls.process = popen_launch_server(
            cls.model,
            cls.base_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            other_args=server_args,
        )

    @classmethod
    def tearDownClass(cls):
        """Kill the server process tree started in ``setUpClass``."""
        kill_process_tree(cls.process.pid)

    def _extract_host_and_port(self, url):
        """Split a ``[scheme://]host:port[/path]`` URL into ``(host, port)``.

        Returns:
            A ``(host, port)`` tuple where ``host`` is a string and ``port``
            an int.

        Raises:
            ValueError: if the URL carries no numeric port.
        """
        # Drop the scheme and anything after the authority so that a trailing
        # slash or path does not end up in the port text.
        authority = url.split("://")[-1].split("/", 1)[0]
        # Split on the LAST colon so a bracketed IPv6 host stays intact.
        host, _, port_text = authority.rpartition(":")
        return host, int(port_text)

    def _run_deterministic_check(
        self, test_mode, n_trials, n_start=10, temperature=0.5
    ):
        """Run ``test_deterministic`` in one mode and check every result is 1.

        All ``BenchArgs`` fields, including the temperature, are set BEFORE
        ``test_deterministic`` is called so that they apply to the run.
        """
        args = BenchArgs()
        # Target the URL the server was launched on in setUpClass.
        args.host, args.port = self._extract_host_and_port(self.base_url)
        args.test_mode = test_mode
        args.n_start = n_start
        args.n_trials = n_trials
        args.temperature = temperature  # test for deterministic sampling

        results = test_deterministic(args)
        for index, result in enumerate(results):
            assert result == 1, (
                f"test_mode={test_mode!r}: result[{index}] is {result!r}, "
                "expected 1"
            )

    def test_single(self):
        """Check the "single" test mode with 20 trials."""
        self._run_deterministic_check(test_mode="single", n_trials=20)

    def test_prefix(self):
        """Check the "prefix" test mode with 10 trials."""
        self._run_deterministic_check(test_mode="prefix", n_trials=10)
|