Improve the structure of CI (#911)

This commit is contained in:
Ying Sheng
2024-08-03 23:09:21 -07:00
committed by GitHub
parent 539856455d
commit 995af5a54b
29 changed files with 451 additions and 237 deletions

View File

@@ -20,8 +20,10 @@ dependencies = [
]
[project.optional-dependencies]
srt = ["aiohttp", "fastapi", "hf_transfer", "huggingface_hub", "interegular", "packaging", "pillow",
"psutil", "pydantic", "torch", "uvicorn", "uvloop", "zmq", "vllm==0.5.3.post1", "outlines>=0.0.44", "python-multipart", "jsonlines"]
srt = ["aiohttp", "fastapi", "hf_transfer", "huggingface_hub", "interegular", "jsonlines",
"packaging", "pillow", "psutil", "pydantic", "python-multipart",
"torch", "uvicorn", "uvloop", "zmq",
"vllm==0.5.3.post1", "outlines>=0.0.44"]
openai = ["openai>=1.0", "tiktoken"]
anthropic = ["anthropic>=0.20.0"]
litellm = ["litellm>=1.0.0"]

View File

@@ -10,7 +10,6 @@ import time
from sglang.test.simple_eval_common import (
ChatCompletionSampler,
download_dataset,
make_report,
set_ulimit,
)
@@ -27,14 +26,26 @@ def run_eval(args):
if args.eval_name == "mmlu":
from sglang.test.simple_eval_mmlu import MMLUEval
dataset_path = "mmlu.csv"
filename = "https://openaipublic.blob.core.windows.net/simple-evals/mmlu.csv"
eval_obj = MMLUEval(filename, args.num_examples, args.num_threads)
elif args.eval_name == "math":
from sglang.test.simple_eval_math import MathEval
if not os.path.exists(dataset_path):
download_dataset(
dataset_path,
"https://openaipublic.blob.core.windows.net/simple-evals/mmlu.csv",
)
eval_obj = MMLUEval(dataset_path, args.num_examples, args.num_threads)
equality_checker = ChatCompletionSampler(model="gpt-4-turbo")
filename = (
"https://openaipublic.blob.core.windows.net/simple-evals/math_test.csv"
)
eval_obj = MathEval(
filename, equality_checker, args.num_examples, args.num_threads
)
elif args.eval_name == "gpqa":
from sglang.test.simple_eval_gpqa import GPQAEval
filename = (
"https://openaipublic.blob.core.windows.net/simple-evals/gpqa_diamond.csv"
)
eval_obj = GPQAEval(filename, args.num_examples, args.num_threads)
elif args.eval_name == "humaneval":
from sglang.test.simple_eval_humaneval import HumanEval
@@ -97,7 +108,7 @@ if __name__ == "__main__":
)
parser.add_argument("--eval-name", type=str, default="mmlu")
parser.add_argument("--num-examples", type=int)
parser.add_argument("--num-threads", type=int, default=64)
parser.add_argument("--num-threads", type=int, default=512)
set_ulimit()
args = parser.parse_args()

View File

@@ -0,0 +1,92 @@
# Adapted from https://github.com/openai/simple-evals/
"""
GPQA: A Graduate-Level Google-Proof Q&A Benchmark
David Rein, Betty Li Hou, Asa Cooper Stickland, Jackson Petty, Richard Yuanzhe Pang, Julien Dirani, Julian Michael, Samuel R. Bowman
https://arxiv.org/abs/2311.12022
"""
import random
import re
import pandas
from sglang.test import simple_eval_common as common
from sglang.test.simple_eval_common import (
ANSWER_PATTERN_MULTICHOICE,
HTML_JINJA,
Eval,
EvalResult,
MessageList,
SamplerBase,
SingleEvalResult,
format_multichoice_question,
)
class GPQAEval(Eval):
    """Runs the GPQA multiple-choice benchmark against a chat sampler.

    Each example's four answer choices are shuffled with a per-example
    permutation (seeded RNG), so the correct letter varies per question
    and runs are reproducible.
    """

    def __init__(
        self,
        filename: str,
        num_examples: int | None,
        num_threads: int,
        n_repeats: int = 1,
    ) -> None:
        # filename: CSV in the openai/simple-evals GPQA format (columns
        #   "Question", "Correct Answer", "Incorrect Answer 1..3").
        # num_examples: optional subsample size; falsy means the full set.
        # num_threads: parallelism passed to map_with_progress.
        # n_repeats: how many times each example is evaluated.
        df = pandas.read_csv(filename)
        examples = [row.to_dict() for _, row in df.iterrows()]
        # Fixed seed keeps both the subsample and the per-example choice
        # permutations reproducible across runs.
        rng = random.Random(0)
        if num_examples:
            assert n_repeats == 1, "n_repeats only supported for num_examples"
            examples = rng.sample(examples, num_examples)
        examples = examples * n_repeats
        # Attach a random ordering of the four answer slots to every example.
        examples = [
            example | {"permutation": rng.sample(range(4), 4)} for example in examples
        ]
        self.examples = examples
        self.n_repeats = n_repeats
        self.num_threads = num_threads

    def __call__(self, sampler: SamplerBase) -> EvalResult:
        def fn(row: dict) -> SingleEvalResult:
            # Reorder the answers by the precomputed permutation, then
            # recover which letter now labels the correct one.
            choices = [
                row["Correct Answer"],
                row["Incorrect Answer 1"],
                row["Incorrect Answer 2"],
                row["Incorrect Answer 3"],
            ]
            choices = [choices[i] for i in row["permutation"]]
            correct_index = choices.index(row["Correct Answer"])
            correct_answer = "ABCD"[correct_index]
            choices_dict = dict(
                A=choices[0],
                B=choices[1],
                C=choices[2],
                D=choices[3],
                Question=row["Question"],
            )
            prompt_messages = [
                sampler._pack_message(
                    content=format_multichoice_question(choices_dict), role="user"
                )
            ]
            response_text = sampler(prompt_messages)
            # Extract the model's chosen letter; no match scores 0.
            match = re.search(ANSWER_PATTERN_MULTICHOICE, response_text)
            extracted_answer = match.group(1) if match else None
            score = 1.0 if extracted_answer == correct_answer else 0.0
            # Render a per-example HTML report snippet for the eval output.
            html = common.jinja_env.from_string(HTML_JINJA).render(
                prompt_messages=prompt_messages,
                next_message=dict(content=response_text, role="assistant"),
                score=score,
                correct_answer=correct_answer,
                extracted_answer=extracted_answer,
            )
            convo = prompt_messages + [dict(content=response_text, role="assistant")]
            return SingleEvalResult(
                html=html,
                score=score,
                convo=convo,
                metrics={"chars": len(response_text)},
            )

        results = common.map_with_progress(fn, self.examples, self.num_threads)
        return common.aggregate_results(results)

View File

@@ -0,0 +1,72 @@
# Adapted from https://github.com/openai/simple-evals/
"""
Measuring Mathematical Problem Solving With the MATH Dataset
Dan Hendrycks, Collin Burns, Saurav Kadavath, Akul Arora, Steven Basart, Eric Tang, Dawn Song, Jacob Steinhardt
https://arxiv.org/abs/2103.03874
"""
import random
import re
import pandas
from sglang.test import simple_eval_common as common
from sglang.test.simple_eval_common import (
ANSWER_PATTERN,
HTML_JINJA,
Eval,
EvalResult,
SamplerBase,
SingleEvalResult,
check_equality,
)
# Prompt template: asks the model to solve step by step and finish with a
# final "Answer: ..." line, which ANSWER_PATTERN later extracts.
QUERY_TEMPLATE = """
Solve the following math problem step by step. The last line of your response should be of the form Answer: $ANSWER (without quotes) where $ANSWER is the answer to the problem.
{Question}
Remember to put your answer on its own line after "Answer:", and you do not need to use a \\boxed command.
""".strip()
class MathEval(Eval):
    """Runs the MATH benchmark, delegating answer grading to a sampler.

    Answer comparison is not string equality: check_equality consults the
    equality-checker sampler (model-graded) to judge the extracted answer
    against the reference answer.
    """

    def __init__(
        self,
        filename: str,
        equality_checker: SamplerBase,
        num_examples: int | None,
        num_threads: int,
    ) -> None:
        # filename: CSV in the openai/simple-evals MATH format; rows are
        #   formatted into QUERY_TEMPLATE and graded via their "Answer" column.
        # equality_checker: sampler passed to check_equality for grading.
        # num_examples: optional subsample size; falsy means the full set.
        # num_threads: parallelism passed to map_with_progress.
        df = pandas.read_csv(filename)
        examples = [row.to_dict() for _, row in df.iterrows()]
        # Fixed seed keeps the subsample reproducible across runs.
        if num_examples:
            examples = random.Random(0).sample(examples, num_examples)
        self.examples = examples
        self.equality_checker = equality_checker
        self.num_threads = num_threads

    def __call__(self, sampler: SamplerBase) -> EvalResult:
        def fn(row: dict) -> SingleEvalResult:
            prompt_messages = [
                sampler._pack_message(content=QUERY_TEMPLATE.format(**row), role="user")
            ]
            response_text = sampler(prompt_messages)
            # Pull the final "Answer: ..." value out of the response;
            # no match grades as incorrect (check_equality gets None).
            match = re.search(ANSWER_PATTERN, response_text)
            extracted_answer = match.group(1) if match else None
            score = float(
                check_equality(self.equality_checker, row["Answer"], extracted_answer)
            )
            # Render a per-example HTML report snippet for the eval output.
            html = common.jinja_env.from_string(HTML_JINJA).render(
                prompt_messages=prompt_messages,
                next_message=dict(content=response_text, role="assistant"),
                score=score,
                correct_answer=row["Answer"],
                extracted_answer=extracted_answer,
            )
            convo = prompt_messages + [dict(content=response_text, role="assistant")]
            return SingleEvalResult(html=html, score=score, convo=convo)

        results = common.map_with_progress(fn, self.examples, self.num_threads)
        return common.aggregate_results(results)

View File

@@ -1,9 +1,14 @@
"""Common utilities for testing and benchmarking"""
import argparse
import asyncio
import multiprocessing
import subprocess
import threading
import time
import unittest
from functools import partial
from typing import Callable, Optional
import numpy as np
import requests
@@ -247,7 +252,7 @@ async def call_select_lmql(context, choices, temperature=0, max_len=4096, model=
return choices.index(answer)
def add_common_other_args_and_parse(parser):
def add_common_other_args_and_parse(parser: argparse.ArgumentParser):
parser.add_argument("--parallel", type=int, default=64)
parser.add_argument("--host", type=str, default="http://127.0.0.1")
parser.add_argument("--port", type=int, default=None)
@@ -286,7 +291,7 @@ def add_common_other_args_and_parse(parser):
return args
def add_common_sglang_args_and_parse(parser):
def add_common_sglang_args_and_parse(parser: argparse.ArgumentParser):
parser.add_argument("--parallel", type=int, default=64)
parser.add_argument("--host", type=str, default="http://127.0.0.1")
parser.add_argument("--port", type=int, default=30000)
@@ -296,7 +301,7 @@ def add_common_sglang_args_and_parse(parser):
return args
def select_sglang_backend(args):
def select_sglang_backend(args: argparse.Namespace):
if args.backend.startswith("srt"):
if args.backend == "srt-no-parallel":
global_config.enable_parallel_decoding = False
@@ -309,7 +314,7 @@ def select_sglang_backend(args):
return backend
def _get_call_generate(args):
def _get_call_generate(args: argparse.Namespace):
if args.backend == "lightllm":
return partial(call_generate_lightllm, url=f"{args.host}:{args.port}/generate")
elif args.backend == "vllm":
@@ -336,7 +341,7 @@ def _get_call_generate(args):
raise ValueError(f"Invalid backend: {args.backend}")
def _get_call_select(args):
def _get_call_select(args: argparse.Namespace):
if args.backend == "lightllm":
return partial(call_select_lightllm, url=f"{args.host}:{args.port}/generate")
elif args.backend == "vllm":
@@ -359,7 +364,7 @@ def _get_call_select(args):
raise ValueError(f"Invalid backend: {args.backend}")
def get_call_generate(args):
def get_call_generate(args: argparse.Namespace):
call_generate = _get_call_generate(args)
def func(*args, **kwargs):
@@ -372,7 +377,7 @@ def get_call_generate(args):
return func
def get_call_select(args):
def get_call_select(args: argparse.Namespace):
call_select = _get_call_select(args)
def func(*args, **kwargs):
@@ -385,7 +390,12 @@ def get_call_select(args):
return func
def popen_launch_server(model, port, timeout, *args):
def popen_launch_server(
model: str, base_url: str, timeout: float, other_args: tuple = ()
):
_, host, port = base_url.split(":")
host = host[2:]
command = [
"python3",
"-m",
@@ -393,21 +403,81 @@ def popen_launch_server(model, port, timeout, *args):
"--model-path",
model,
"--host",
"localhost",
host,
"--port",
str(port),
*args,
port,
*other_args,
]
process = subprocess.Popen(command, stdout=None, stderr=None)
base_url = f"http://localhost:{port}/v1"
start_time = time.time()
while time.time() - start_time < timeout:
try:
response = requests.get(f"{base_url}/models")
response = requests.get(f"{base_url}/v1/models")
if response.status_code == 200:
return process
except requests.RequestException:
pass
time.sleep(10)
raise TimeoutError("Server failed to start within the timeout period.")
def run_with_timeout(
    func: Callable,
    args: tuple = (),
    kwargs: Optional[dict] = None,
    timeout: Optional[float] = None,
):
    """Run ``func(*args, **kwargs)`` in a worker thread with a timeout.

    Args:
        func: the callable to execute.
        args: positional arguments forwarded to ``func``.
        kwargs: keyword arguments forwarded to ``func`` (None means none).
        timeout: seconds to wait for completion; None waits forever.

    Returns:
        Whatever ``func`` returns.

    Raises:
        TimeoutError: if ``func`` does not finish within ``timeout`` seconds.
        RuntimeError: if ``func`` raised (the original exception is chained)
            or produced no result.
    """
    ret_value = []
    errors = []

    def _target_func():
        try:
            ret_value.append(func(*args, **(kwargs or {})))
        except BaseException as e:  # noqa: BLE001 - re-raised below
            errors.append(e)

    # daemon=True: a worker stuck past the timeout must not keep the
    # interpreter alive after the caller has given up on it.
    t = threading.Thread(target=_target_func, daemon=True)
    t.start()
    t.join(timeout=timeout)
    if t.is_alive():
        raise TimeoutError(f"Function did not finish within {timeout} seconds")
    if errors:
        # Preserve the RuntimeError contract for callers, but keep the real
        # failure visible via exception chaining.
        raise RuntimeError("Function raised an exception") from errors[0]
    if not ret_value:
        raise RuntimeError("Function returned no value")
    return ret_value[0]
def run_unittest_files(files: list[str], timeout_per_file: float):
    """Run each unittest file in its own subprocess with a per-file timeout.

    Stops at the first failing or timed-out file.

    Args:
        files: paths of unittest files to run sequentially.
        timeout_per_file: max seconds allowed for each file.

    Returns:
        0 if every file passes, -1 on the first failure or timeout.
    """
    tic = time.time()
    success = True
    for filename in files:

        def func():
            print(f"\n\nRun {filename}\n\n")
            # unittest.main sets the child's exit status; the parent reads
            # it back through p.exitcode.
            unittest.main(module=None, argv=["", "-vb"] + [filename])

        # A fresh process per file isolates test side effects and lets us
        # kill a hung file without tearing down the runner.
        p = multiprocessing.Process(target=func)

        def run_one_file():
            p.start()
            p.join()

        try:
            run_with_timeout(run_one_file, timeout=timeout_per_file)
            if p.exitcode != 0:
                success = False
                break
        except TimeoutError:
            p.terminate()
            time.sleep(5)  # give the terminated child time to clean up
            print(
                f"\nTimeout after {timeout_per_file} seconds "
                f"when running {filename}\n"
            )
            # Fall through to the summary below and return -1; returning
            # False here would read as exit code 0 (success) to callers.
            success = False
            break

    if success:
        print(f"Success. Time elapsed: {time.time() - tic:.2f}s")
    else:
        print(f"Fail. Time elapsed: {time.time() - tic:.2f}s")

    return 0 if success else -1

View File

@@ -12,6 +12,7 @@ import urllib.request
from concurrent.futures import ThreadPoolExecutor
from io import BytesIO
from json import dumps
from typing import Union
import numpy as np
import requests
@@ -25,7 +26,7 @@ def get_exception_traceback():
return err_str
def is_same_type(values):
def is_same_type(values: list):
"""Return whether the elements in values are of the same type."""
if len(values) <= 1:
return True
@@ -45,7 +46,7 @@ def read_jsonl(filename: str):
return rets
def dump_state_text(filename, states, mode="w"):
def dump_state_text(filename: str, states: list, mode: str = "w"):
"""Dump program state in a text file."""
from sglang.lang.interpreter import ProgramState
@@ -105,7 +106,7 @@ def http_request(
return HttpResponse(e)
def encode_image_base64(image_path):
def encode_image_base64(image_path: Union[str, bytes]):
"""Encode an image in base64."""
if isinstance(image_path, str):
with open(image_path, "rb") as image_file:
@@ -144,7 +145,7 @@ def encode_frame(frame):
return frame_bytes
def encode_video_base64(video_path, num_frames=16):
def encode_video_base64(video_path: str, num_frames: int = 16):
import cv2 # pip install opencv-python-headless
cap = cv2.VideoCapture(video_path)
@@ -190,7 +191,7 @@ def encode_video_base64(video_path, num_frames=16):
return video_base64
def _is_chinese_char(cp):
def _is_chinese_char(cp: int):
"""Checks whether CP is the codepoint of a CJK character."""
# This defines a "chinese character" as anything in the CJK Unicode block:
# https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block)
@@ -215,7 +216,7 @@ def _is_chinese_char(cp):
return False
def find_printable_text(text):
def find_printable_text(text: str):
"""Returns the longest printable substring of text that contains only entire words."""
# Borrowed from https://github.com/huggingface/transformers/blob/061580c82c2db1de9139528243e105953793f7a2/src/transformers/generation/streamers.py#L99
@@ -234,26 +235,7 @@ def find_printable_text(text):
return text[: text.rfind(" ") + 1]
def run_with_timeout(func, args=(), kwargs=None, timeout=None):
"""Run a function with timeout."""
ret_value = []
def _target_func():
ret_value.append(func(*args, **(kwargs or {})))
t = threading.Thread(target=_target_func)
t.start()
t.join(timeout=timeout)
if t.is_alive():
raise TimeoutError()
if not ret_value:
raise RuntimeError()
return ret_value[0]
def graceful_registry(sub_module_name):
def graceful_registry(sub_module_name: str):
def graceful_shutdown(signum, frame):
logger.info(
f"{sub_module_name} Received signal to shutdown. Performing graceful shutdown..."
@@ -265,7 +247,9 @@ def graceful_registry(sub_module_name):
class LazyImport:
def __init__(self, module_name, class_name):
"""Lazy import to make `import sglang` run faster."""
def __init__(self, module_name: str, class_name: str):
self.module_name = module_name
self.class_name = class_name
self._module = None
@@ -276,7 +260,7 @@ class LazyImport:
self._module = getattr(module, self.class_name)
return self._module
def __getattr__(self, name):
def __getattr__(self, name: str):
module = self._load()
return getattr(module, name)