diff --git a/.github/workflows/nightly-eval.yml b/.github/workflows/nightly-eval.yml index 13911f989..a39786611 100644 --- a/.github/workflows/nightly-eval.yml +++ b/.github/workflows/nightly-eval.yml @@ -32,3 +32,4 @@ jobs: run: | cd test/srt python3 test_nightly_human_eval.py + python3 test_nightly_gsm8k_eval.py diff --git a/test/srt/test_nightly_gsm8k_eval.py b/test/srt/test_nightly_gsm8k_eval.py index b035db52b..49ef46169 100644 --- a/test/srt/test_nightly_gsm8k_eval.py +++ b/test/srt/test_nightly_gsm8k_eval.py @@ -19,6 +19,35 @@ def parse_models(model_string): return [model.strip() for model in model_string.split(",") if model.strip()] +def launch_server(base_url, model, is_fp8, is_tp2): + other_args = ["--log-level-http", "warning", "--trust-remote-code"] + if is_fp8: + if "Llama-3" in model or "gemma-2" in model: + # compressed-tensors + other_args.extend(["--kv-cache-dtype", "fp8_e5m2"]) + elif "Qwen2-72B-Instruct-FP8" in model: + # bug + other_args.extend(["--quantization", "fp8"]) + else: + other_args.extend(["--quantization", "fp8", "--kv-cache-dtype", "fp8_e5m2"]) + if is_tp2: + other_args.extend(["--tp", "2"]) + if "DeepSeek" in model: + other_args.extend(["--mem-frac", "0.85"]) + if "AWQ" in model: + other_args.extend(["--quantization", "awq"]) + elif "GPTQ" in model: + other_args.extend(["--quantization", "gptq"]) + + process = popen_launch_server( + model, + base_url, + timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + other_args=other_args, + ) + return process + + class TestEvalAccuracyLarge(unittest.TestCase): @classmethod def setUpClass(cls): @@ -38,40 +67,11 @@ class TestEvalAccuracyLarge(unittest.TestCase): if self.process: kill_child_process(self.process.pid, include_self=True) - def launch_server(self, model, is_fp8, is_tp2): - other_args = ["--log-level-http", "warning", "--trust-remote-code"] - if is_fp8: - if "Llama-3" in model or "gemma-2" in model: - # compressed-tensors - other_args.extend(["--kv-cache-dtype", "fp8_e5m2"]) - elif "Qwen2-72B-Instruct-FP8" in model: - # bug - other_args.extend(["--quantization", "fp8"]) - else: - other_args.extend( - ["--quantization", "fp8", "--kv-cache-dtype", "fp8_e5m2"] - ) - if is_tp2: - other_args.extend(["--tp", "2"]) - if "DeepSeek" in model: - other_args.extend(["--mem-frac", "0.85"]) - if "AWQ" in model: - other_args.extend(["--quantization", "awq"]) - elif "GPTQ" in model: - other_args.extend(["--quantization", "gptq"]) - - self.process = popen_launch_server( - model, - self.base_url, - timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, - other_args=other_args, - ) - def test_mgsm_en_all_models(self): for model_group, is_fp8, is_tp2 in self.model_groups: for model in model_group: with self.subTest(model=model): - self.launch_server(model, is_fp8, is_tp2) + self.process = launch_server(self.base_url, model, is_fp8, is_tp2) args = SimpleNamespace( base_url=self.base_url, diff --git a/test/srt/test_nightly_human_eval.py b/test/srt/test_nightly_human_eval.py index 6d2ecee50..9028b1045 100644 --- a/test/srt/test_nightly_human_eval.py +++ b/test/srt/test_nightly_human_eval.py @@ -5,7 +5,7 @@ import subprocess import unittest from types import SimpleNamespace -from test_nightly_gsm8k_eval import parse_models +from test_nightly_gsm8k_eval import launch_server, parse_models from sglang.srt.utils import kill_child_process from sglang.test.test_utils import ( @@ -39,35 +39,6 @@ class TestEvalAccuracyLarge(unittest.TestCase): if cls.eval_process: kill_child_process(cls.eval_process.pid) - def launch_server(self, model, is_fp8, is_tp2): - other_args = ["--log-level-http", "warning", "--trust-remote-code"] - if is_fp8: - if "Llama-3" in model or "gemma-2" in model: - # compressed-tensors - other_args.extend(["--kv-cache-dtype", "fp8_e5m2"]) - elif "Qwen2-72B-Instruct-FP8" in model: - # bug - other_args.extend(["--quantization", "fp8"]) - else: - other_args.extend( - ["--quantization", "fp8", "--kv-cache-dtype", "fp8_e5m2"] - ) - if is_tp2: - other_args.extend(["--tp", "2"]) - if "DeepSeek" in model: - other_args.extend(["--mem-frac", "0.85"]) - if "AWQ" in model: - other_args.extend(["--quantization", "awq"]) - elif "GPTQ" in model: - other_args.extend(["--quantization", "gptq"]) - - self.process = popen_launch_server( - model, - self.base_url, - timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, - other_args=other_args, - ) - def run_evalplus(self, model): print("Delete evalplus results") shutil.rmtree("evalplus_results", ignore_errors=True) @@ -116,7 +87,9 @@ class TestEvalAccuracyLarge(unittest.TestCase): # NOTE: only Llama for now if "Llama" in model: with self.subTest(model=model): - self.launch_server(model, is_fp8, is_tp2) + self.process = launch_server( + self.base_url, model, is_fp8, is_tp2 + ) self.run_evalplus(model) self.tearDownClass()