Adapt to sglang v0.5.2rc1 on DCU
This commit is contained in:
86
test/lang/test_srt_backend.py
Normal file
86
test/lang/test_srt_backend.py
Normal file
@@ -0,0 +1,86 @@
|
||||
"""
|
||||
Usage:
|
||||
python3 -m unittest test_srt_backend.TestSRTBackend.test_gen_min_new_tokens
|
||||
python3 -m unittest test_srt_backend.TestSRTBackend.test_hellaswag_select
|
||||
"""
|
||||
|
||||
import unittest
|
||||
|
||||
import sglang as sgl
|
||||
from sglang.test.test_programs import (
|
||||
test_decode_int,
|
||||
test_decode_json_regex,
|
||||
test_dtype_gen,
|
||||
test_expert_answer,
|
||||
test_few_shot_qa,
|
||||
test_gen_min_new_tokens,
|
||||
test_hellaswag_select,
|
||||
test_mt_bench,
|
||||
test_parallel_decoding,
|
||||
test_regex,
|
||||
test_select,
|
||||
test_stream,
|
||||
test_tool_use,
|
||||
)
|
||||
from sglang.test.test_utils import DEFAULT_MODEL_NAME_FOR_TEST, CustomTestCase
|
||||
|
||||
|
||||
class TestSRTBackend(CustomTestCase):
    """Run the shared sglang language-program checks against an SRT runtime.

    A single ``sgl.Runtime`` is created for the whole class and registered as
    the default backend. Each test method delegates to the same-named helper
    imported from ``sglang.test.test_programs``.
    """

    # Runtime shared by every test in the class; assigned in setUpClass.
    backend = None

    @classmethod
    def setUpClass(cls):
        """Create the runtime once and register it as the default backend."""
        cls.backend = sgl.Runtime(
            model_path=DEFAULT_MODEL_NAME_FOR_TEST, cuda_graph_max_bs=4
        )
        sgl.set_default_backend(cls.backend)

    @classmethod
    def tearDownClass(cls):
        """Shut down the runtime created in ``setUpClass``."""
        cls.backend.shutdown()

    def test_few_shot_qa(self):
        """Delegate to the shared ``test_few_shot_qa`` program."""
        test_few_shot_qa()

    def test_mt_bench(self):
        """Delegate to the shared ``test_mt_bench`` program."""
        test_mt_bench()

    def test_select(self):
        """Delegate to ``test_select`` with answer checking turned off."""
        test_select(check_answer=False)

    def test_decode_int(self):
        """Delegate to the shared ``test_decode_int`` program."""
        test_decode_int()

    def test_decode_json_regex(self):
        """Delegate to the shared ``test_decode_json_regex`` program."""
        test_decode_json_regex()

    def test_expert_answer(self):
        """Delegate to the shared ``test_expert_answer`` program."""
        test_expert_answer()

    def test_tool_use(self):
        """Delegate to the shared ``test_tool_use`` program."""
        test_tool_use()

    def test_parallel_decoding(self):
        """Delegate to the shared ``test_parallel_decoding`` program."""
        test_parallel_decoding()

    def test_stream(self):
        """Delegate to the shared ``test_stream`` program."""
        test_stream()

    def test_regex(self):
        """Delegate to the shared ``test_regex`` program."""
        test_regex()

    def test_dtype_gen(self):
        """Delegate to the shared ``test_dtype_gen`` program."""
        test_dtype_gen()

    def test_hellaswag_select(self):
        """Run the hellaswag select program twice and check its accuracy.

        The helper returns an ``(accuracy, latency)`` pair; only the accuracy
        is asserted on, so the latency is bound to a throwaway name.
        """
        # Run twice to capture more bugs
        for _ in range(2):
            accuracy, _latency = test_hellaswag_select()
            self.assertGreater(accuracy, 0.60)

    def test_gen_min_new_tokens(self):
        """Delegate to the shared ``test_gen_min_new_tokens`` program."""
        test_gen_min_new_tokens()
|
||||
|
||||
|
||||
# Script entry point: hand control to the unittest runner when this file is
# executed directly rather than imported.
if __name__ == "__main__":
    unittest.main()
|
||||
Reference in New Issue
Block a user