diff --git a/test/srt/test_throughput.py b/test/srt/test_throughput.py deleted file mode 100644 index 8948e22d7..000000000 --- a/test/srt/test_throughput.py +++ /dev/null @@ -1,61 +0,0 @@ -import json -import unittest - -import requests - -from sglang.srt.utils import kill_child_process -from sglang.test.test_utils import DEFAULT_MODEL_NAME_FOR_TEST, popen_launch_server - - -class TestSRTEndpoint(unittest.TestCase): - - @classmethod - def setUpClass(cls): - cls.model = DEFAULT_MODEL_NAME_FOR_TEST - cls.base_url = "http://127.0.0.1:8157" - cls.process = popen_launch_server(cls.model, cls.base_url, timeout=300) - - @classmethod - def tearDownClass(cls): - kill_child_process(cls.process.pid) - - def run_decode( - self, return_logprob=False, top_logprobs_num=0, return_text=False, n=1 - ): - response = requests.post( - self.base_url + "/generate", - json={ - "text": "The capital of France is", - "sampling_params": { - "temperature": 0 if n == 1 else 0.5, - "max_new_tokens": 32, - "n": n, - }, - "stream": False, - "return_logprob": return_logprob, - "top_logprobs_num": top_logprobs_num, - "return_text_in_logprobs": return_text, - "logprob_start_len": 0, - }, - ) - print(json.dumps(response.json())) - print("=" * 100) - - def test_simple_decode(self): - self.run_decode() - - def test_parallel_sample(self): - self.run_decode(n=3) - - def test_logprob(self): - for top_logprobs_num in [0, 3]: - for return_text in [True, False]: - self.run_decode( - return_logprob=True, - top_logprobs_num=top_logprobs_num, - return_text=return_text, - ) - - -if __name__ == "__main__": - unittest.main()