minor: update nightly eval (#1867)
This commit is contained in:
1
.github/workflows/nightly-eval.yml
vendored
1
.github/workflows/nightly-eval.yml
vendored
@@ -32,3 +32,4 @@ jobs:
|
||||
run: |
|
||||
cd test/srt
|
||||
python3 test_nightly_human_eval.py
|
||||
python3 test_nightly_gsm8k_eval.py
|
||||
|
||||
@@ -19,6 +19,35 @@ def parse_models(model_string):
|
||||
return [model.strip() for model in model_string.split(",") if model.strip()]
|
||||
|
||||
|
||||
def launch_server(base_url, model, is_fp8, is_tp2):
|
||||
other_args = ["--log-level-http", "warning", "--trust-remote-code"]
|
||||
if is_fp8:
|
||||
if "Llama-3" in model or "gemma-2" in model:
|
||||
# compressed-tensors
|
||||
other_args.extend(["--kv-cache-dtype", "fp8_e5m2"])
|
||||
elif "Qwen2-72B-Instruct-FP8" in model:
|
||||
# bug
|
||||
other_args.extend(["--quantization", "fp8"])
|
||||
else:
|
||||
other_args.extend(["--quantization", "fp8", "--kv-cache-dtype", "fp8_e5m2"])
|
||||
if is_tp2:
|
||||
other_args.extend(["--tp", "2"])
|
||||
if "DeepSeek" in model:
|
||||
other_args.extend(["--mem-frac", "0.85"])
|
||||
if "AWQ" in model:
|
||||
other_args.extend(["--quantization", "awq"])
|
||||
elif "GPTQ" in model:
|
||||
other_args.extend(["--quantization", "gptq"])
|
||||
|
||||
process = popen_launch_server(
|
||||
model,
|
||||
base_url,
|
||||
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
||||
other_args=other_args,
|
||||
)
|
||||
return process
|
||||
|
||||
|
||||
class TestEvalAccuracyLarge(unittest.TestCase):
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
@@ -38,40 +67,11 @@ class TestEvalAccuracyLarge(unittest.TestCase):
|
||||
if self.process:
|
||||
kill_child_process(self.process.pid, include_self=True)
|
||||
|
||||
def launch_server(self, model, is_fp8, is_tp2):
|
||||
other_args = ["--log-level-http", "warning", "--trust-remote-code"]
|
||||
if is_fp8:
|
||||
if "Llama-3" in model or "gemma-2" in model:
|
||||
# compressed-tensors
|
||||
other_args.extend(["--kv-cache-dtype", "fp8_e5m2"])
|
||||
elif "Qwen2-72B-Instruct-FP8" in model:
|
||||
# bug
|
||||
other_args.extend(["--quantization", "fp8"])
|
||||
else:
|
||||
other_args.extend(
|
||||
["--quantization", "fp8", "--kv-cache-dtype", "fp8_e5m2"]
|
||||
)
|
||||
if is_tp2:
|
||||
other_args.extend(["--tp", "2"])
|
||||
if "DeepSeek" in model:
|
||||
other_args.extend(["--mem-frac", "0.85"])
|
||||
if "AWQ" in model:
|
||||
other_args.extend(["--quantization", "awq"])
|
||||
elif "GPTQ" in model:
|
||||
other_args.extend(["--quantization", "gptq"])
|
||||
|
||||
self.process = popen_launch_server(
|
||||
model,
|
||||
self.base_url,
|
||||
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
||||
other_args=other_args,
|
||||
)
|
||||
|
||||
def test_mgsm_en_all_models(self):
|
||||
for model_group, is_fp8, is_tp2 in self.model_groups:
|
||||
for model in model_group:
|
||||
with self.subTest(model=model):
|
||||
self.launch_server(model, is_fp8, is_tp2)
|
||||
self.process = launch_server(self.base_url, model, is_fp8, is_tp2)
|
||||
|
||||
args = SimpleNamespace(
|
||||
base_url=self.base_url,
|
||||
|
||||
@@ -5,7 +5,7 @@ import subprocess
|
||||
import unittest
|
||||
from types import SimpleNamespace
|
||||
|
||||
from test_nightly_gsm8k_eval import parse_models
|
||||
from test_nightly_gsm8k_eval import launch_server, parse_models
|
||||
|
||||
from sglang.srt.utils import kill_child_process
|
||||
from sglang.test.test_utils import (
|
||||
@@ -39,35 +39,6 @@ class TestEvalAccuracyLarge(unittest.TestCase):
|
||||
if cls.eval_process:
|
||||
kill_child_process(cls.eval_process.pid)
|
||||
|
||||
def launch_server(self, model, is_fp8, is_tp2):
|
||||
other_args = ["--log-level-http", "warning", "--trust-remote-code"]
|
||||
if is_fp8:
|
||||
if "Llama-3" in model or "gemma-2" in model:
|
||||
# compressed-tensors
|
||||
other_args.extend(["--kv-cache-dtype", "fp8_e5m2"])
|
||||
elif "Qwen2-72B-Instruct-FP8" in model:
|
||||
# bug
|
||||
other_args.extend(["--quantization", "fp8"])
|
||||
else:
|
||||
other_args.extend(
|
||||
["--quantization", "fp8", "--kv-cache-dtype", "fp8_e5m2"]
|
||||
)
|
||||
if is_tp2:
|
||||
other_args.extend(["--tp", "2"])
|
||||
if "DeepSeek" in model:
|
||||
other_args.extend(["--mem-frac", "0.85"])
|
||||
if "AWQ" in model:
|
||||
other_args.extend(["--quantization", "awq"])
|
||||
elif "GPTQ" in model:
|
||||
other_args.extend(["--quantization", "gptq"])
|
||||
|
||||
self.process = popen_launch_server(
|
||||
model,
|
||||
self.base_url,
|
||||
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
||||
other_args=other_args,
|
||||
)
|
||||
|
||||
def run_evalplus(self, model):
|
||||
print("Delete evalplus results")
|
||||
shutil.rmtree("evalplus_results", ignore_errors=True)
|
||||
@@ -116,7 +87,9 @@ class TestEvalAccuracyLarge(unittest.TestCase):
|
||||
# NOTE: only Llama for now
|
||||
if "Llama" in model:
|
||||
with self.subTest(model=model):
|
||||
self.launch_server(model, is_fp8, is_tp2)
|
||||
self.process = launch_server(
|
||||
self.base_url, model, is_fp8, is_tp2
|
||||
)
|
||||
self.run_evalplus(model)
|
||||
self.tearDownClass()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user