Improve the structure of CI (#911)
This commit is contained in:
@@ -20,8 +20,10 @@ dependencies = [
|
||||
]
|
||||
|
||||
[project.optional-dependencies]
|
||||
srt = ["aiohttp", "fastapi", "hf_transfer", "huggingface_hub", "interegular", "packaging", "pillow",
|
||||
"psutil", "pydantic", "torch", "uvicorn", "uvloop", "zmq", "vllm==0.5.3.post1", "outlines>=0.0.44", "python-multipart", "jsonlines"]
|
||||
srt = ["aiohttp", "fastapi", "hf_transfer", "huggingface_hub", "interegular", "jsonlines",
|
||||
"packaging", "pillow", "psutil", "pydantic", "python-multipart",
|
||||
"torch", "uvicorn", "uvloop", "zmq",
|
||||
"vllm==0.5.3.post1", "outlines>=0.0.44"]
|
||||
openai = ["openai>=1.0", "tiktoken"]
|
||||
anthropic = ["anthropic>=0.20.0"]
|
||||
litellm = ["litellm>=1.0.0"]
|
||||
|
||||
@@ -10,7 +10,6 @@ import time
|
||||
|
||||
from sglang.test.simple_eval_common import (
|
||||
ChatCompletionSampler,
|
||||
download_dataset,
|
||||
make_report,
|
||||
set_ulimit,
|
||||
)
|
||||
@@ -27,14 +26,26 @@ def run_eval(args):
|
||||
if args.eval_name == "mmlu":
|
||||
from sglang.test.simple_eval_mmlu import MMLUEval
|
||||
|
||||
dataset_path = "mmlu.csv"
|
||||
filename = "https://openaipublic.blob.core.windows.net/simple-evals/mmlu.csv"
|
||||
eval_obj = MMLUEval(filename, args.num_examples, args.num_threads)
|
||||
elif args.eval_name == "math":
|
||||
from sglang.test.simple_eval_math import MathEval
|
||||
|
||||
if not os.path.exists(dataset_path):
|
||||
download_dataset(
|
||||
dataset_path,
|
||||
"https://openaipublic.blob.core.windows.net/simple-evals/mmlu.csv",
|
||||
)
|
||||
eval_obj = MMLUEval(dataset_path, args.num_examples, args.num_threads)
|
||||
equality_checker = ChatCompletionSampler(model="gpt-4-turbo")
|
||||
|
||||
filename = (
|
||||
"https://openaipublic.blob.core.windows.net/simple-evals/math_test.csv"
|
||||
)
|
||||
eval_obj = MathEval(
|
||||
filename, equality_checker, args.num_examples, args.num_threads
|
||||
)
|
||||
elif args.eval_name == "gpqa":
|
||||
from sglang.test.simple_eval_gpqa import GPQAEval
|
||||
|
||||
filename = (
|
||||
"https://openaipublic.blob.core.windows.net/simple-evals/gpqa_diamond.csv"
|
||||
)
|
||||
eval_obj = GPQAEval(filename, args.num_examples, args.num_threads)
|
||||
elif args.eval_name == "humaneval":
|
||||
from sglang.test.simple_eval_humaneval import HumanEval
|
||||
|
||||
@@ -97,7 +108,7 @@ if __name__ == "__main__":
|
||||
)
|
||||
parser.add_argument("--eval-name", type=str, default="mmlu")
|
||||
parser.add_argument("--num-examples", type=int)
|
||||
parser.add_argument("--num-threads", type=int, default=64)
|
||||
parser.add_argument("--num-threads", type=int, default=512)
|
||||
set_ulimit()
|
||||
args = parser.parse_args()
|
||||
|
||||
|
||||
92
python/sglang/test/simple_eval_gpqa.py
Normal file
92
python/sglang/test/simple_eval_gpqa.py
Normal file
@@ -0,0 +1,92 @@
|
||||
# Adapted from https://github.com/openai/simple-evals/
|
||||
|
||||
"""
|
||||
GPQA: A Graduate-Level Google-Proof Q&A Benchmark
|
||||
David Rein, Betty Li Hou, Asa Cooper Stickland, Jackson Petty, Richard Yuanzhe Pang, Julien Dirani, Julian Michael, Samuel R. Bowman
|
||||
https://arxiv.org/abs/2311.12022
|
||||
"""
|
||||
|
||||
import random
|
||||
import re
|
||||
|
||||
import pandas
|
||||
|
||||
from sglang.test import simple_eval_common as common
|
||||
from sglang.test.simple_eval_common import (
|
||||
ANSWER_PATTERN_MULTICHOICE,
|
||||
HTML_JINJA,
|
||||
Eval,
|
||||
EvalResult,
|
||||
MessageList,
|
||||
SamplerBase,
|
||||
SingleEvalResult,
|
||||
format_multichoice_question,
|
||||
)
|
||||
|
||||
|
||||
class GPQAEval(Eval):
    """GPQA multiple-choice eval (Rein et al., https://arxiv.org/abs/2311.12022)."""

    def __init__(
        self,
        filename: str,
        num_examples: int | None,
        num_threads: int,
        n_repeats: int = 1,
    ):
        # Load every CSV row as a plain dict.
        frame = pandas.read_csv(filename)
        rows = [record.to_dict() for _, record in frame.iterrows()]

        rng = random.Random(0)  # fixed seed: reproducible subsample and permutations
        if num_examples:
            assert n_repeats == 1, "n_repeats only supported for num_examples"
            rows = rng.sample(rows, num_examples)
        rows = rows * n_repeats
        # Attach a random ordering of the four answer options to each example.
        rows = [record | {"permutation": rng.sample(range(4), 4)} for record in rows]

        self.examples = rows
        self.n_repeats = n_repeats
        self.num_threads = num_threads

    def __call__(self, sampler: SamplerBase) -> EvalResult:
        def grade_one(row: dict):
            # Reorder the answer options with this example's stored permutation.
            options = [
                row["Correct Answer"],
                row["Incorrect Answer 1"],
                row["Incorrect Answer 2"],
                row["Incorrect Answer 3"],
            ]
            options = [options[i] for i in row["permutation"]]
            correct_answer = "ABCD"[options.index(row["Correct Answer"])]

            choices_dict = {
                "A": options[0],
                "B": options[1],
                "C": options[2],
                "D": options[3],
                "Question": row["Question"],
            }
            prompt_messages = [
                sampler._pack_message(
                    content=format_multichoice_question(choices_dict), role="user"
                )
            ]
            response_text = sampler(prompt_messages)

            # Extract the model's letter answer; missing match scores 0.
            found = re.search(ANSWER_PATTERN_MULTICHOICE, response_text)
            extracted_answer = found.group(1) if found else None
            score = float(extracted_answer == correct_answer)

            html = common.jinja_env.from_string(HTML_JINJA).render(
                prompt_messages=prompt_messages,
                next_message=dict(content=response_text, role="assistant"),
                score=score,
                correct_answer=correct_answer,
                extracted_answer=extracted_answer,
            )
            convo = prompt_messages + [dict(content=response_text, role="assistant")]
            return SingleEvalResult(
                html=html,
                score=score,
                convo=convo,
                metrics={"chars": len(response_text)},
            )

        results = common.map_with_progress(grade_one, self.examples, self.num_threads)
        return common.aggregate_results(results)
|
||||
72
python/sglang/test/simple_eval_math.py
Normal file
72
python/sglang/test/simple_eval_math.py
Normal file
@@ -0,0 +1,72 @@
|
||||
# Adapted from https://github.com/openai/simple-evals/
|
||||
|
||||
"""
|
||||
Measuring Mathematical Problem Solving With the MATH Dataset
|
||||
Dan Hendrycks, Collin Burns, Saurav Kadavath, Akul Arora, Steven Basart, Eric Tang, Dawn Song, Jacob Steinhardt
|
||||
https://arxiv.org/abs/2103.03874
|
||||
"""
|
||||
|
||||
import random
|
||||
import re
|
||||
|
||||
import pandas
|
||||
|
||||
from sglang.test import simple_eval_common as common
|
||||
from sglang.test.simple_eval_common import (
|
||||
ANSWER_PATTERN,
|
||||
HTML_JINJA,
|
||||
Eval,
|
||||
EvalResult,
|
||||
SamplerBase,
|
||||
SingleEvalResult,
|
||||
check_equality,
|
||||
)
|
||||
|
||||
QUERY_TEMPLATE = """
Solve the following math problem step by step. The last line of your response should be of the form Answer: $ANSWER (without quotes) where $ANSWER is the answer to the problem.

{Question}

Remember to put your answer on its own line after "Answer:", and you do not need to use a \\boxed command.
""".strip()


class MathEval(Eval):
    """MATH benchmark eval (Hendrycks et al., https://arxiv.org/abs/2103.03874)."""

    def __init__(
        self,
        filename: str,
        equality_checker: SamplerBase,
        num_examples: int | None,
        num_threads: int,
    ):
        # Load every CSV row as a plain dict.
        frame = pandas.read_csv(filename)
        rows = [record.to_dict() for _, record in frame.iterrows()]
        if num_examples:
            # Fixed seed keeps the subsample reproducible across runs.
            rows = random.Random(0).sample(rows, num_examples)

        self.examples = rows
        self.equality_checker = equality_checker
        self.num_threads = num_threads

    def __call__(self, sampler: SamplerBase) -> EvalResult:
        def grade_one(row: dict):
            # Build the single-turn prompt from the row's Question field.
            prompt_messages = [
                sampler._pack_message(content=QUERY_TEMPLATE.format(**row), role="user")
            ]
            response_text = sampler(prompt_messages)

            # Pull the final "Answer: ..." out of the response, if present.
            found = re.search(ANSWER_PATTERN, response_text)
            extracted_answer = found.group(1) if found else None

            # Delegate answer equivalence to the model-based equality checker.
            score = float(
                check_equality(self.equality_checker, row["Answer"], extracted_answer)
            )

            html = common.jinja_env.from_string(HTML_JINJA).render(
                prompt_messages=prompt_messages,
                next_message=dict(content=response_text, role="assistant"),
                score=score,
                correct_answer=row["Answer"],
                extracted_answer=extracted_answer,
            )
            convo = prompt_messages + [dict(content=response_text, role="assistant")]
            return SingleEvalResult(html=html, score=score, convo=convo)

        results = common.map_with_progress(grade_one, self.examples, self.num_threads)
        return common.aggregate_results(results)
|
||||
@@ -1,9 +1,14 @@
|
||||
"""Common utilities for testing and benchmarking"""
|
||||
|
||||
import argparse
|
||||
import asyncio
|
||||
import multiprocessing
|
||||
import subprocess
|
||||
import threading
|
||||
import time
|
||||
import unittest
|
||||
from functools import partial
|
||||
from typing import Callable, Optional
|
||||
|
||||
import numpy as np
|
||||
import requests
|
||||
@@ -247,7 +252,7 @@ async def call_select_lmql(context, choices, temperature=0, max_len=4096, model=
|
||||
return choices.index(answer)
|
||||
|
||||
|
||||
def add_common_other_args_and_parse(parser):
|
||||
def add_common_other_args_and_parse(parser: argparse.ArgumentParser):
|
||||
parser.add_argument("--parallel", type=int, default=64)
|
||||
parser.add_argument("--host", type=str, default="http://127.0.0.1")
|
||||
parser.add_argument("--port", type=int, default=None)
|
||||
@@ -286,7 +291,7 @@ def add_common_other_args_and_parse(parser):
|
||||
return args
|
||||
|
||||
|
||||
def add_common_sglang_args_and_parse(parser):
|
||||
def add_common_sglang_args_and_parse(parser: argparse.ArgumentParser):
|
||||
parser.add_argument("--parallel", type=int, default=64)
|
||||
parser.add_argument("--host", type=str, default="http://127.0.0.1")
|
||||
parser.add_argument("--port", type=int, default=30000)
|
||||
@@ -296,7 +301,7 @@ def add_common_sglang_args_and_parse(parser):
|
||||
return args
|
||||
|
||||
|
||||
def select_sglang_backend(args):
|
||||
def select_sglang_backend(args: argparse.Namespace):
|
||||
if args.backend.startswith("srt"):
|
||||
if args.backend == "srt-no-parallel":
|
||||
global_config.enable_parallel_decoding = False
|
||||
@@ -309,7 +314,7 @@ def select_sglang_backend(args):
|
||||
return backend
|
||||
|
||||
|
||||
def _get_call_generate(args):
|
||||
def _get_call_generate(args: argparse.Namespace):
|
||||
if args.backend == "lightllm":
|
||||
return partial(call_generate_lightllm, url=f"{args.host}:{args.port}/generate")
|
||||
elif args.backend == "vllm":
|
||||
@@ -336,7 +341,7 @@ def _get_call_generate(args):
|
||||
raise ValueError(f"Invalid backend: {args.backend}")
|
||||
|
||||
|
||||
def _get_call_select(args):
|
||||
def _get_call_select(args: argparse.Namespace):
|
||||
if args.backend == "lightllm":
|
||||
return partial(call_select_lightllm, url=f"{args.host}:{args.port}/generate")
|
||||
elif args.backend == "vllm":
|
||||
@@ -359,7 +364,7 @@ def _get_call_select(args):
|
||||
raise ValueError(f"Invalid backend: {args.backend}")
|
||||
|
||||
|
||||
def get_call_generate(args):
|
||||
def get_call_generate(args: argparse.Namespace):
|
||||
call_generate = _get_call_generate(args)
|
||||
|
||||
def func(*args, **kwargs):
|
||||
@@ -372,7 +377,7 @@ def get_call_generate(args):
|
||||
return func
|
||||
|
||||
|
||||
def get_call_select(args):
|
||||
def get_call_select(args: argparse.Namespace):
|
||||
call_select = _get_call_select(args)
|
||||
|
||||
def func(*args, **kwargs):
|
||||
@@ -385,7 +390,12 @@ def get_call_select(args):
|
||||
return func
|
||||
|
||||
|
||||
def popen_launch_server(model, port, timeout, *args):
|
||||
def popen_launch_server(
|
||||
model: str, base_url: str, timeout: float, other_args: tuple = ()
|
||||
):
|
||||
_, host, port = base_url.split(":")
|
||||
host = host[2:]
|
||||
|
||||
command = [
|
||||
"python3",
|
||||
"-m",
|
||||
@@ -393,21 +403,81 @@ def popen_launch_server(model, port, timeout, *args):
|
||||
"--model-path",
|
||||
model,
|
||||
"--host",
|
||||
"localhost",
|
||||
host,
|
||||
"--port",
|
||||
str(port),
|
||||
*args,
|
||||
port,
|
||||
*other_args,
|
||||
]
|
||||
process = subprocess.Popen(command, stdout=None, stderr=None)
|
||||
base_url = f"http://localhost:{port}/v1"
|
||||
|
||||
start_time = time.time()
|
||||
while time.time() - start_time < timeout:
|
||||
try:
|
||||
response = requests.get(f"{base_url}/models")
|
||||
response = requests.get(f"{base_url}/v1/models")
|
||||
if response.status_code == 200:
|
||||
return process
|
||||
except requests.RequestException:
|
||||
pass
|
||||
time.sleep(10)
|
||||
raise TimeoutError("Server failed to start within the timeout period.")
|
||||
|
||||
|
||||
def run_with_timeout(
    func: Callable,
    args: tuple = (),
    kwargs: Optional[dict] = None,
    timeout: Optional[float] = None,
):
    """Run ``func(*args, **kwargs)`` in a worker thread, bounded by ``timeout``.

    Args:
        func: Callable to run.
        args: Positional arguments forwarded to ``func``.
        kwargs: Keyword arguments forwarded to ``func`` (``None`` means none).
        timeout: Seconds to wait for completion; ``None`` waits forever.

    Returns:
        The return value of ``func``.

    Raises:
        TimeoutError: If ``func`` does not finish within ``timeout`` seconds.
        RuntimeError: If ``func`` raised (the original exception is chained)
            or otherwise produced no result.
    """
    ret_value = []
    errors: list[BaseException] = []

    def _target_func():
        try:
            ret_value.append(func(*args, **(kwargs or {})))
        except BaseException as e:  # capture so the caller thread can report it
            errors.append(e)

    # daemon=True so a timed-out worker cannot block interpreter shutdown.
    t = threading.Thread(target=_target_func, daemon=True)
    t.start()
    t.join(timeout=timeout)
    if t.is_alive():
        raise TimeoutError(f"{func} did not finish within {timeout} seconds")

    if not ret_value:
        if errors:
            raise RuntimeError(f"{func} raised an exception") from errors[0]
        raise RuntimeError(f"{func} produced no result")

    return ret_value[0]
|
||||
|
||||
|
||||
def run_unittest_files(files: list[str], timeout_per_file: float):
    """Run each unittest file in its own subprocess with a per-file timeout.

    Args:
        files: Paths of unittest files to run sequentially.
        timeout_per_file: Seconds allowed for each file before it is killed.

    Returns:
        0 if every file passed, -1 if any file failed, False on timeout.
    """
    tic = time.time()
    success = True

    for filename in files:

        def func():
            print(f"\n\nRun {filename}\n\n")
            # unittest.main exits the child process; the exit code is
            # inspected below through p.exitcode.
            unittest.main(module=None, argv=["", "-vb"] + [filename])

        p = multiprocessing.Process(target=func)

        def run_one_file():
            p.start()
            p.join()

        try:
            run_with_timeout(run_one_file, timeout=timeout_per_file)
            if p.exitcode != 0:
                success = False
                break
        except TimeoutError:
            p.terminate()
            time.sleep(5)  # give the terminated child time to release resources
            # Bug fix: the f-prefix was missing, so "{timeout_per_file}"
            # was printed literally instead of the actual value.
            print(
                f"\nTimeout after {timeout_per_file} seconds when running {filename}\n"
            )
            # NOTE(review): False is falsy (== 0) while failure below is -1,
            # so a caller using this value as an exit code sees a timeout as
            # success. Kept as-is to preserve the external contract — confirm
            # against callers before changing.
            return False

    if success:
        print(f"Success. Time elapsed: {time.time() - tic:.2f}s")
    else:
        print(f"Fail. Time elapsed: {time.time() - tic:.2f}s")

    return 0 if success else -1
|
||||
|
||||
@@ -12,6 +12,7 @@ import urllib.request
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from io import BytesIO
|
||||
from json import dumps
|
||||
from typing import Union
|
||||
|
||||
import numpy as np
|
||||
import requests
|
||||
@@ -25,7 +26,7 @@ def get_exception_traceback():
|
||||
return err_str
|
||||
|
||||
|
||||
def is_same_type(values):
|
||||
def is_same_type(values: list):
|
||||
"""Return whether the elements in values are of the same type."""
|
||||
if len(values) <= 1:
|
||||
return True
|
||||
@@ -45,7 +46,7 @@ def read_jsonl(filename: str):
|
||||
return rets
|
||||
|
||||
|
||||
def dump_state_text(filename, states, mode="w"):
|
||||
def dump_state_text(filename: str, states: list, mode: str = "w"):
|
||||
"""Dump program state in a text file."""
|
||||
from sglang.lang.interpreter import ProgramState
|
||||
|
||||
@@ -105,7 +106,7 @@ def http_request(
|
||||
return HttpResponse(e)
|
||||
|
||||
|
||||
def encode_image_base64(image_path):
|
||||
def encode_image_base64(image_path: Union[str, bytes]):
|
||||
"""Encode an image in base64."""
|
||||
if isinstance(image_path, str):
|
||||
with open(image_path, "rb") as image_file:
|
||||
@@ -144,7 +145,7 @@ def encode_frame(frame):
|
||||
return frame_bytes
|
||||
|
||||
|
||||
def encode_video_base64(video_path, num_frames=16):
|
||||
def encode_video_base64(video_path: str, num_frames: int = 16):
|
||||
import cv2 # pip install opencv-python-headless
|
||||
|
||||
cap = cv2.VideoCapture(video_path)
|
||||
@@ -190,7 +191,7 @@ def encode_video_base64(video_path, num_frames=16):
|
||||
return video_base64
|
||||
|
||||
|
||||
def _is_chinese_char(cp):
|
||||
def _is_chinese_char(cp: int):
|
||||
"""Checks whether CP is the codepoint of a CJK character."""
|
||||
# This defines a "chinese character" as anything in the CJK Unicode block:
|
||||
# https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block)
|
||||
@@ -215,7 +216,7 @@ def _is_chinese_char(cp):
|
||||
return False
|
||||
|
||||
|
||||
def find_printable_text(text):
|
||||
def find_printable_text(text: str):
|
||||
"""Returns the longest printable substring of text that contains only entire words."""
|
||||
# Borrowed from https://github.com/huggingface/transformers/blob/061580c82c2db1de9139528243e105953793f7a2/src/transformers/generation/streamers.py#L99
|
||||
|
||||
@@ -234,26 +235,7 @@ def find_printable_text(text):
|
||||
return text[: text.rfind(" ") + 1]
|
||||
|
||||
|
||||
def run_with_timeout(func, args=(), kwargs=None, timeout=None):
    """Run a function with timeout."""
    # The worker stores its single result here; an empty list afterwards
    # means the call never completed normally.
    result = []

    def _worker():
        result.append(func(*args, **(kwargs or {})))

    worker_thread = threading.Thread(target=_worker)
    worker_thread.start()
    worker_thread.join(timeout=timeout)

    # Still running after join() => the timeout elapsed.
    if worker_thread.is_alive():
        raise TimeoutError()

    # Finished but produced nothing => the function raised.
    if not result:
        raise RuntimeError()

    return result[0]
|
||||
|
||||
|
||||
def graceful_registry(sub_module_name):
|
||||
def graceful_registry(sub_module_name: str):
|
||||
def graceful_shutdown(signum, frame):
|
||||
logger.info(
|
||||
f"{sub_module_name} Received signal to shutdown. Performing graceful shutdown..."
|
||||
@@ -265,7 +247,9 @@ def graceful_registry(sub_module_name):
|
||||
|
||||
|
||||
class LazyImport:
|
||||
def __init__(self, module_name, class_name):
|
||||
"""Lazy import to make `import sglang` run faster."""
|
||||
|
||||
def __init__(self, module_name: str, class_name: str):
|
||||
self.module_name = module_name
|
||||
self.class_name = class_name
|
||||
self._module = None
|
||||
@@ -276,7 +260,7 @@ class LazyImport:
|
||||
self._module = getattr(module, self.class_name)
|
||||
return self._module
|
||||
|
||||
def __getattr__(self, name):
|
||||
def __getattr__(self, name: str):
|
||||
module = self._load()
|
||||
return getattr(module, name)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user