Add accuracy test to CI: MMLU (#882)
This commit is contained in:
@@ -14,6 +14,7 @@ from sglang.test.test_programs import (
|
||||
test_stream,
|
||||
test_tool_use,
|
||||
)
|
||||
from sglang.test.test_utils import MODEL_NAME_FOR_TEST
|
||||
|
||||
|
||||
class TestSRTBackend(unittest.TestCase):
|
||||
@@ -21,7 +22,7 @@ class TestSRTBackend(unittest.TestCase):
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
cls.backend = sgl.Runtime(model_path="meta-llama/Meta-Llama-3.1-8B-Instruct")
|
||||
cls.backend = sgl.Runtime(model_path=MODEL_NAME_FOR_TEST)
|
||||
sgl.set_default_backend(cls.backend)
|
||||
|
||||
@classmethod
|
||||
|
||||
New file: test/srt/test_eval_accuracy.py (43 lines)
@@ -0,0 +1,43 @@
|
||||
import json
|
||||
import unittest
|
||||
from types import SimpleNamespace
|
||||
|
||||
from sglang.srt.utils import kill_child_process
|
||||
from sglang.test.run_eval import run_eval
|
||||
from sglang.test.test_utils import MODEL_NAME_FOR_TEST, popen_launch_server
|
||||
|
||||
|
||||
class TestAccuracy(unittest.TestCase):
    """End-to-end accuracy check: launch an sglang server and score it on MMLU.

    Launches the server once for the whole class and tears it down afterwards,
    so the (expensive) model load happens a single time.
    """

    @classmethod
    def setUpClass(cls):
        # NOTE(review): hard-coded port — assumes 30000 is free on the CI host;
        # confirm no other test in the suite runs a server concurrently.
        port = 30000

        cls.model = MODEL_NAME_FOR_TEST
        cls.base_url = f"http://localhost:{port}"
        # timeout=300: allow up to 5 minutes for model download/load on CI.
        cls.process = popen_launch_server(cls.model, port, timeout=300)

    @classmethod
    def tearDownClass(cls):
        # Kill the server subprocess tree so CI does not leak a GPU process.
        kill_child_process(cls.process.pid)

    def test_mmlu(self):
        """Run a 20-example MMLU slice and require a score of at least 0.5."""
        # run_eval takes an argparse-style namespace; SimpleNamespace mimics it.
        args = SimpleNamespace(
            base_url=self.base_url,
            model=self.model,
            eval_name="mmlu",
            num_examples=20,
            num_threads=20,
        )

        metrics = run_eval(args)
        # Use a unittest assertion instead of a bare `assert`: it is not
        # stripped under `python -O` and reports the actual score on failure.
        self.assertGreaterEqual(metrics["score"], 0.5)
||||
if __name__ == "__main__":
    # warnings="ignore" keeps the test output readable by suppressing
    # warnings emitted while the server subprocess starts and stops.
    # (Removed the commented-out manual-run snippet that used to follow:
    # dead code; run this file directly to get the same effect.)
    unittest.main(warnings="ignore")
|
||||
@@ -1,47 +1,21 @@
|
||||
import json
|
||||
import subprocess
|
||||
import time
|
||||
import unittest
|
||||
|
||||
import openai
|
||||
import requests
|
||||
|
||||
from sglang.srt.utils import kill_child_process
|
||||
from sglang.test.test_utils import MODEL_NAME_FOR_TEST, popen_launch_server
|
||||
|
||||
|
||||
class TestOpenAIServer(unittest.TestCase):
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
model = "meta-llama/Meta-Llama-3.1-8B-Instruct"
|
||||
port = 30000
|
||||
timeout = 300
|
||||
|
||||
command = [
|
||||
"python3",
|
||||
"-m",
|
||||
"sglang.launch_server",
|
||||
"--model-path",
|
||||
model,
|
||||
"--host",
|
||||
"localhost",
|
||||
"--port",
|
||||
str(port),
|
||||
]
|
||||
cls.process = subprocess.Popen(command, stdout=None, stderr=None)
|
||||
cls.model = MODEL_NAME_FOR_TEST
|
||||
cls.base_url = f"http://localhost:{port}/v1"
|
||||
cls.model = model
|
||||
|
||||
start_time = time.time()
|
||||
while time.time() - start_time < timeout:
|
||||
try:
|
||||
response = requests.get(f"{cls.base_url}/models")
|
||||
if response.status_code == 200:
|
||||
return
|
||||
except requests.RequestException:
|
||||
pass
|
||||
time.sleep(10)
|
||||
raise TimeoutError("Server failed to start within the timeout period.")
|
||||
cls.process = popen_launch_server(cls.model, port, timeout=300)
|
||||
|
||||
@classmethod
|
||||
def tearDownClass(cls):
|
||||
@@ -178,8 +152,6 @@ class TestOpenAIServer(unittest.TestCase):
|
||||
|
||||
is_first = True
|
||||
for response in generator:
|
||||
print(response)
|
||||
|
||||
data = response.choices[0].delta
|
||||
if is_first:
|
||||
data.role == "assistant"
|
||||
|
||||
New file: test/srt/test_srt_endpoint.py (64 lines)
@@ -0,0 +1,64 @@
|
||||
import json
|
||||
import unittest
|
||||
|
||||
import requests
|
||||
|
||||
from sglang.srt.utils import kill_child_process
|
||||
from sglang.test.run_eval import run_eval
|
||||
from sglang.test.test_utils import MODEL_NAME_FOR_TEST, popen_launch_server
|
||||
|
||||
|
||||
class TestSRTEndpoint(unittest.TestCase):
    """Smoke tests for the SRT `/generate` endpoint: plain decode, parallel
    sampling (n>1), and the logprob return options."""

    @classmethod
    def setUpClass(cls):
        # NOTE(review): hard-coded port — assumes 30000 is free on the CI host.
        port = 30000

        cls.model = MODEL_NAME_FOR_TEST
        cls.base_url = f"http://localhost:{port}"
        # timeout=300: allow up to 5 minutes for model download/load on CI.
        cls.process = popen_launch_server(cls.model, port, timeout=300)

    @classmethod
    def tearDownClass(cls):
        # Kill the server subprocess tree so CI does not leak a GPU process.
        kill_child_process(cls.process.pid)

    def run_decode(
        self, return_logprob=False, top_logprobs_num=0, return_text=False, n=1
    ):
        """POST a single `/generate` request and fail the test on HTTP errors.

        Args:
            return_logprob: ask the server to return token logprobs.
            top_logprobs_num: how many top-logprob entries per token (0 = none).
            return_text: include decoded text alongside logprobs.
            n: number of parallel samples to generate.
        """
        response = requests.post(
            self.base_url + "/generate",
            json={
                "text": "The capital of France is",
                "sampling_params": {
                    # Greedy (temperature 0) would make all n samples identical,
                    # so use a nonzero temperature when sampling more than one.
                    "temperature": 0 if n == 1 else 0.5,
                    "max_new_tokens": 32,
                    "n": n,
                },
                "stream": False,
                "return_logprob": return_logprob,
                "top_logprobs_num": top_logprobs_num,
                "return_text_in_logprobs": return_text,
                "logprob_start_len": 0,
            },
        )
        # Fix: the original only printed the payload, so a dead or erroring
        # server still let every test pass. Assert success explicitly.
        self.assertEqual(response.status_code, 200)
        print(json.dumps(response.json()))
        print("=" * 100)

    def test_simple_decode(self):
        self.run_decode()

    def test_parallel_sample(self):
        self.run_decode(n=3)

    def test_logprob(self):
        # Exercise every combination of top-logprob count and text return.
        for top_logprobs_num in [0, 3]:
            for return_text in [True, False]:
                self.run_decode(
                    return_logprob=True,
                    top_logprobs_num=top_logprobs_num,
                    return_text=return_text,
                )
||||
if __name__ == "__main__":
    # Running this file directly executes the whole suite; warnings raised
    # during the run (e.g. from subprocess teardown) are suppressed.
    unittest.main(warnings="ignore")
||||
Reference in New Issue
Block a user