Clean up unit tests (#1020)

This commit is contained in:
Lianmin Zheng
2024-08-10 15:09:03 -07:00
committed by GitHub
parent b68c4c073b
commit 54fb1c80c0
24 changed files with 82 additions and 157 deletions

View File

@@ -1,26 +1,32 @@
# Run Unit Tests
## Test Frontend Language
SGLang uses the built-in library [unittest](https://docs.python.org/3/library/unittest.html) as the testing framework.
## Test Backend Runtime
```bash
cd sglang/test/srt
# Run a single file
python3 test_srt_endpoint.py
# Run a single test
python3 -m unittest test_srt_endpoint.TestSRTEndpoint.test_simple_decode
# Run a suite with multiple files
python3 run_suite.py --suite minimal
```
## Test Frontend Language
```bash
cd sglang/test/lang
export OPENAI_API_KEY=sk-*****
# Run a single file
python3 test_openai_backend.py
# Run a suite
# Run a single test
python3 -m unittest test_openai_backend.TestOpenAIBackend.test_few_shot_qa
# Run a suite with multiple files
python3 run_suite.py --suite minimal
```
## Test Backend Runtime
```
cd sglang/test/srt
# Run a single file
python3 test_eval_accuracy.py
# Run a suite
python3 run_suite.py --suite minimal
```

View File

@@ -21,11 +21,4 @@ class TestAnthropicBackend(unittest.TestCase):
if __name__ == "__main__":
unittest.main(warnings="ignore")
# from sglang.global_config import global_config
# global_config.verbosity = 2
# t = TestAnthropicBackend()
# t.setUpClass()
# t.test_mt_bench()
unittest.main()

View File

@@ -48,8 +48,4 @@ class TestBind(unittest.TestCase):
if __name__ == "__main__":
unittest.main(warnings="ignore")
# t = TestBind()
# t.setUpClass()
# t.test_cache()
unittest.main()

View File

@@ -87,9 +87,4 @@ class TestChoices(unittest.TestCase):
if __name__ == "__main__":
unittest.main(warnings="ignore")
# t = TestChoices()
# t.test_token_length_normalized()
# t.test_greedy_token_selection()
# t.test_unconditional_likelihood_normalized()
unittest.main()

View File

@@ -21,4 +21,4 @@ class TestAnthropicBackend(unittest.TestCase):
if __name__ == "__main__":
unittest.main(warnings="ignore")
unittest.main()

View File

@@ -88,11 +88,4 @@ class TestOpenAIBackend(unittest.TestCase):
if __name__ == "__main__":
unittest.main(warnings="ignore")
# from sglang.global_config import global_config
# global_config.verbosity = 2
# t = TestOpenAIBackend()
# t.setUpClass()
# t.test_stream()
unittest.main()

View File

@@ -61,12 +61,4 @@ class TestSRTBackend(unittest.TestCase):
if __name__ == "__main__":
unittest.main(warnings="ignore")
# from sglang.global_config import global_config
# global_config.verbosity = 2
# t = TestSRTBackend()
# t.setUpClass()
# t.test_few_shot_qa()
# t.tearDownClass()
unittest.main()

View File

@@ -125,7 +125,4 @@ class TestTracing(unittest.TestCase):
if __name__ == "__main__":
unittest.main(warnings="ignore")
# t = TestTracing()
# t.test_multi_function()
unittest.main()

View File

@@ -14,26 +14,22 @@ from sglang.test.test_programs import (
class TestVertexAIBackend(unittest.TestCase):
backend = None
chat_backend = None
chat_vision_backend = None
@classmethod
def setUpClass(cls):
cls.backend = VertexAI("gemini-pro")
cls.chat_backend = VertexAI("gemini-pro")
cls.chat_vision_backend = VertexAI("gemini-pro-vision")
cls.backend = VertexAI("gemini-1.5-pro-001")
def test_few_shot_qa(self):
set_default_backend(self.backend)
test_few_shot_qa()
def test_mt_bench(self):
set_default_backend(self.chat_backend)
set_default_backend(self.backend)
test_mt_bench()
def test_expert_answer(self):
set_default_backend(self.backend)
test_expert_answer()
test_expert_answer(check_answer=False)
def test_parallel_decoding(self):
set_default_backend(self.backend)
@@ -44,7 +40,7 @@ class TestVertexAIBackend(unittest.TestCase):
test_parallel_encoding()
def test_image_qa(self):
set_default_backend(self.chat_vision_backend)
set_default_backend(self.backend)
test_image_qa()
def test_stream(self):
@@ -53,11 +49,4 @@ class TestVertexAIBackend(unittest.TestCase):
if __name__ == "__main__":
unittest.main(warnings="ignore")
# from sglang.global_config import global_config
# global_config.verbosity = 2
# t = TestVertexAIBackend()
# t.setUpClass()
# t.test_stream()
unittest.main()

View File

@@ -6,9 +6,9 @@ from sglang.test.test_utils import run_unittest_files
suites = {
"minimal": [
"test_eval_accuracy.py",
"test_embedding_openai_server.py",
"test_openai_server.py",
"test_vision_openai_server.py",
"test_embedding_openai_server.py",
"test_chunked_prefill.py",
"test_torch_compile.py",
"test_models_from_modelscope.py",

View File

@@ -37,9 +37,4 @@ class TestAccuracy(unittest.TestCase):
if __name__ == "__main__":
unittest.main(warnings="ignore")
# t = TestAccuracy()
# t.setUpClass()
# t.test_mmlu()
# t.tearDownClass()
unittest.main()

View File

@@ -1,11 +1,8 @@
import json
import time
import unittest
import openai
from sglang.srt.hf_transformers_utils import get_tokenizer
from sglang.srt.openai_api.protocol import EmbeddingObject
from sglang.srt.utils import kill_child_process
from sglang.test.test_utils import popen_launch_server
@@ -65,12 +62,12 @@ class TestOpenAIServer(unittest.TestCase):
), f"{response.usage.total_tokens} vs {num_prompt_tokens}"
def run_batch(self):
# FIXME not implemented
# FIXME: not implemented
pass
def test_embedding(self):
# TODO the fields of encoding_format, dimensions, user are skipped
# TODO support use_list_input
# TODO: the fields of encoding_format, dimensions, user are skipped
# TODO: support use_list_input
for use_list_input in [False, True]:
for token_input in [False, True]:
self.run_embedding(use_list_input, token_input)
@@ -80,9 +77,4 @@ class TestOpenAIServer(unittest.TestCase):
if __name__ == "__main__":
unittest.main(warnings="ignore")
# t = TestOpenAIServer()
# t.setUpClass()
# t.test_embedding()
# t.tearDownClass()
unittest.main()

View File

@@ -32,9 +32,4 @@ class TestAccuracy(unittest.TestCase):
if __name__ == "__main__":
unittest.main(warnings="ignore")
# t = TestAccuracy()
# t.setUpClass()
# t.test_mmlu()
# t.tearDownClass()
unittest.main()

View File

@@ -44,4 +44,4 @@ class TestDownloadFromModelScope(unittest.TestCase):
if __name__ == "__main__":
unittest.main(warnings="ignore")
unittest.main()

View File

@@ -399,9 +399,4 @@ class TestOpenAIServer(unittest.TestCase):
if __name__ == "__main__":
unittest.main(warnings="ignore")
# t = TestOpenAIServer()
# t.setUpClass()
# t.test_completion()
# t.tearDownClass()
unittest.main()

View File

@@ -1,18 +1,13 @@
import json
import os
import sys
import unittest
import requests
from sglang.srt.utils import kill_child_process
from sglang.test.run_eval import run_eval
from sglang.test.test_utils import DEFAULT_MODEL_NAME_FOR_TEST, popen_launch_server
# os.environ["CUDA_VISIBLE_DEVICES"] = "1"
class TestSRTEndpoint(unittest.TestCase):
class TestSkipTokenizerInit(unittest.TestCase):
@classmethod
def setUpClass(cls):
@@ -26,9 +21,7 @@ class TestSRTEndpoint(unittest.TestCase):
def tearDownClass(cls):
kill_child_process(cls.process.pid)
def run_decode(
self, return_logprob=False, top_logprobs_num=0, return_text=False, n=1
):
def run_decode(self, return_logprob=False, top_logprobs_num=0, n=1):
response = requests.post(
self.base_url + "/generate",
json={
@@ -50,7 +43,6 @@ class TestSRTEndpoint(unittest.TestCase):
"stream": False,
"return_logprob": return_logprob,
"top_logprobs_num": top_logprobs_num,
"return_text_in_logprobs": return_text,
"logprob_start_len": 0,
},
)
@@ -65,13 +57,11 @@ class TestSRTEndpoint(unittest.TestCase):
def test_logprob(self):
for top_logprobs_num in [0, 3]:
for return_text in [False, False]:
self.run_decode(
return_logprob=True,
top_logprobs_num=top_logprobs_num,
return_text=return_text,
)
self.run_decode(
return_logprob=True,
top_logprobs_num=top_logprobs_num,
)
if __name__ == "__main__":
unittest.main(warnings="ignore")
unittest.main()

View File

@@ -4,7 +4,6 @@ import unittest
import requests
from sglang.srt.utils import kill_child_process
from sglang.test.run_eval import run_eval
from sglang.test.test_utils import DEFAULT_MODEL_NAME_FOR_TEST, popen_launch_server
@@ -59,4 +58,4 @@ class TestSRTEndpoint(unittest.TestCase):
if __name__ == "__main__":
unittest.main(warnings="ignore")
unittest.main()

View File

@@ -34,9 +34,4 @@ class TestAccuracy(unittest.TestCase):
if __name__ == "__main__":
unittest.main(warnings="ignore")
# t = TestAccuracy()
# t.setUpClass()
# t.test_mmlu()
# t.tearDownClass()
unittest.main()

View File

@@ -113,9 +113,4 @@ class TestOpenAIVisionServer(unittest.TestCase):
if __name__ == "__main__":
unittest.main(warnings="ignore")
# t = TestOpenAIVisionServer()
# t.setUpClass()
# t.test_chat_completion()
# t.tearDownClass()
unittest.main()