Add accuracy test to CI: MMLU (#882)
@@ -21,7 +21,7 @@ import sys
 import time
 import traceback
 import warnings
-from argparse import ArgumentParser as FlexibleArgumentParser
+from argparse import ArgumentParser
 from dataclasses import dataclass, field
 from datetime import datetime
 from typing import AsyncGenerator, List, Optional, Tuple, Union
@@ -868,14 +868,12 @@ def set_ulimit(target_soft_limit=65535):

 if __name__ == "__main__":
-    parser = FlexibleArgumentParser(
-        description="Benchmark the online serving throughput."
-    )
+    parser = ArgumentParser(description="Benchmark the online serving throughput.")
     parser.add_argument(
         "--backend",
         type=str,
         required=True,
         choices=list(ASYNC_REQUEST_FUNCS.keys()),
         default="sglang",
         help="Must specify a backend, depending on the LLM Inference Engine.",
     )
     parser.add_argument(
python/sglang/test/run_eval.py (new file, 99 lines)
@@ -0,0 +1,99 @@
"""
Usage:
python3 -m sglang.test.run_eval --port 30000 --eval-name mmlu --num-examples 10
"""

import argparse
import json
import os
import time

from sglang.test.simple_eval_common import (
    ChatCompletionSampler,
    download_dataset,
    make_report,
    set_ulimit,
)
from sglang.test.simple_eval_mmlu import MMLUEval


def run_eval(args):
    if "OPENAI_API_KEY" not in os.environ:
        os.environ["OPENAI_API_KEY"] = "EMPTY"

    base_url = (
        f"{args.base_url}/v1" if args.base_url else f"http://{args.host}:{args.port}/v1"
    )

    if args.eval_name == "mmlu":
        dataset_path = "mmlu.csv"

        if not os.path.exists(dataset_path):
            download_dataset(
                dataset_path,
                "https://openaipublic.blob.core.windows.net/simple-evals/mmlu.csv",
            )
        eval_obj = MMLUEval(dataset_path, args.num_examples, args.num_threads)
    else:
        raise ValueError(f"Invalid eval name: {args.eval_name}")

    sampler = ChatCompletionSampler(
        model=args.model,
        max_tokens=2048,
        base_url=base_url,
    )

    # Run eval
    tic = time.time()
    result = eval_obj(sampler)
    latency = time.time() - tic

    # Dump reports
    metrics = result.metrics | {"score": result.score}
    file_stem = f"mmlu_{sampler.model.replace('/', '_')}"
    report_filename = f"/tmp/{file_stem}.html"
    print(f"Writing report to {report_filename}")
    with open(report_filename, "w") as fh:
        fh.write(make_report(result))
    print(metrics)
    result_filename = f"/tmp/{file_stem}.json"
    print(f"Writing results to {result_filename}")
    with open(result_filename, "w") as f:
        f.write(json.dumps(metrics, indent=2))

    # Print results
    print(f"Total latency: {latency:.3f} s")
    print(f"Score: {metrics['score']:.3f}")

    return metrics


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--base-url",
        type=str,
        default=None,
        help="Server or API base URL, if not using HTTP host and port.",
    )
    parser.add_argument(
        "--host", type=str, default="0.0.0.0", help="Host of the server."
    )
    parser.add_argument(
        "--port",
        type=int,
        help="Port of the server. If not set, the inference engine's default port is used.",
    )
    parser.add_argument(
        "--model",
        type=str,
        help="Name or path of the model. If not set, it is queried from the server via /v1/models.",
    )
    parser.add_argument("--eval-name", type=str, default="mmlu")
    parser.add_argument("--num-examples", type=int)
    parser.add_argument("--num-threads", type=int, default=64)
    set_ulimit()
    args = parser.parse_args()

    run_eval(args)
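Note that run_eval returns its metrics dict, so a CI job can gate on the score directly instead of scraping stdout. A minimal sketch, assuming a server is already listening on localhost:30000; the 0.6 floor and the 64/32 example/thread counts are illustrative, not values from this commit:

import argparse

from sglang.test.run_eval import run_eval

# Assumes a running server on localhost:30000; threshold is illustrative.
args = argparse.Namespace(
    base_url=None,
    host="localhost",
    port=30000,
    model=None,  # None -> queried from the server via /v1/models
    eval_name="mmlu",
    num_examples=64,
    num_threads=32,
)
metrics = run_eval(args)
assert metrics["score"] >= 0.6, f"MMLU score too low: {metrics['score']:.3f}"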
python/sglang/test/simple_eval_common.py (new file, 456 lines)
@@ -0,0 +1,456 @@
# Adapted from https://github.com/openai/simple-evals/

import base64
import os
import resource
import time
from collections import defaultdict
from dataclasses import dataclass, field
from multiprocessing.pool import ThreadPool
from typing import Any

import jinja2
import numpy as np
import openai
import requests
from openai import OpenAI
from tqdm import tqdm

OPENAI_SYSTEM_MESSAGE_API = "You are a helpful assistant."
OPENAI_SYSTEM_MESSAGE_CHATGPT = (
    "You are ChatGPT, a large language model trained by OpenAI, based on the GPT-4 architecture."
    + "\nKnowledge cutoff: 2023-12\nCurrent date: 2024-04-01"
)


Message = dict[str, Any]  # keys: role, content
MessageList = list[Message]


class SamplerBase:
    """
    Base class for defining a sampling model, which can be evaluated
    or used as part of the grading process.
    """

    def __call__(self, message_list: MessageList) -> str:
        raise NotImplementedError()


@dataclass
class EvalResult:
    """
    Result of running an evaluation (usually consisting of many samples).
    """

    score: float | None  # top-line metric
    metrics: dict[str, float] | None  # other metrics
    htmls: list[str]  # strings of valid HTML
    convos: list[MessageList]  # sampled conversations


@dataclass
class SingleEvalResult:
    """
    Result of evaluating a single sample.
    """

    score: float | None
    metrics: dict[str, float] = field(default_factory=dict)
    html: str | None = None
    convo: MessageList | None = None  # sampled conversation


class Eval:
    """
    Base class for defining an evaluation.
    """

    def __call__(self, sampler: SamplerBase) -> EvalResult:
        raise NotImplementedError()


class ChatCompletionSampler(SamplerBase):
    """
    Sample from OpenAI's chat completion API.
    """

    def __init__(
        self,
        base_url: str | None = None,
        model: str | None = None,
        system_message: str | None = None,
        temperature: float = 0.0,
        max_tokens: int = 2048,
    ):
        self.client = OpenAI(base_url=base_url)

        if model is None:
            model = self.client.models.list().data[0].id

        self.model = model
        self.system_message = system_message
        self.temperature = temperature
        self.max_tokens = max_tokens
        self.image_format = "url"

    def _handle_image(
        self,
        image: str,
        encoding: str = "base64",
        format: str = "png",
        fovea: int = 768,
    ):
        new_image = {
            "type": "image_url",
            "image_url": {
                "url": f"data:image/{format};{encoding},{image}",
            },
        }
        return new_image

    def _handle_text(self, text: str):
        return {"type": "text", "text": text}

    def _pack_message(self, role: str, content: Any):
        return {"role": str(role), "content": content}

    def __call__(self, message_list: MessageList) -> str:
        if self.system_message:
            message_list = [
                self._pack_message("system", self.system_message)
            ] + message_list
        trial = 0
        while True:
            try:
                response = self.client.chat.completions.create(
                    model=self.model,
                    messages=message_list,
                    temperature=self.temperature,
                    max_tokens=self.max_tokens,
                )
                return response.choices[0].message.content
            # NOTE: BadRequestError is triggered once for MMMU; please uncomment if you are rerunning MMMU
            except openai.BadRequestError as e:
                print("Bad Request Error", e)
                return ""
            except Exception as e:
                exception_backoff = 2**trial  # exponential backoff
                print(
                    f"Rate limit exception, so wait and retry {trial} after {exception_backoff} sec",
                    e,
                )
                time.sleep(exception_backoff)
                trial += 1
            # unknown errors shall throw an exception


QUERY_TEMPLATE_MULTICHOICE = """
Answer the following multiple choice question. The last line of your response should be of the following format: 'Answer: $LETTER' (without quotes) where LETTER is one of ABCD. Think step by step before answering.

{Question}

A) {A}
B) {B}
C) {C}
D) {D}
""".strip()

ANSWER_PATTERN_MULTICHOICE = r"(?i)Answer\s*:\s*([A-D])"
ANSWER_PATTERN = r"(?i)Answer\s*:\s*([^\n]+)"


EQUALITY_TEMPLATE = r"""
Look at the following two expressions (answers to a math problem) and judge whether they are equivalent. Only perform trivial simplifications.

Examples:

Expression 1: $2x+3$
Expression 2: $3+2x$

Yes

Expression 1: 3/2
Expression 2: 1.5

Yes

Expression 1: $x^2+2x+1$
Expression 2: $y^2+2y+1$

No

Expression 1: $x^2+2x+1$
Expression 2: $(x+1)^2$

Yes

Expression 1: 3245/5
Expression 2: 649

No
(these are actually equal, don't mark them equivalent if you need to do nontrivial simplifications)

Expression 1: 2/(-3)
Expression 2: -2/3

Yes
(trivial simplifications are allowed)

Expression 1: 72 degrees
Expression 2: 72

Yes
(give benefit of the doubt to units)

Expression 1: 64
Expression 2: 64 square feet

Yes
(give benefit of the doubt to units)

---

YOUR TASK


Respond with only "Yes" or "No" (without quotes). Do not include a rationale.

Expression 1: %(expression1)s
Expression 2: %(expression2)s
""".strip()


HTML_JINJA = """
<h3>Prompt conversation</h3>
{% for message in prompt_messages %}
{{ message_to_html(message) | safe }}
{% endfor %}
<h3>Sampled message</h3>
{{ message_to_html(next_message) | safe }}
<h3>Results</h3>
<p>Correct Answer: {{ correct_answer }}</p>
<p>Extracted Answer: {{ extracted_answer }}</p>
<p>Score: {{ score }}</p>
"""


def format_multichoice_question(row):
    return QUERY_TEMPLATE_MULTICHOICE.format(**row)


def check_equality(sampler: SamplerBase, expr1: str, expr2: str):
    prompt = EQUALITY_TEMPLATE % {"expression1": expr1, "expression2": expr2}
    response = sampler([dict(content=prompt, role="user")])
    return response.lower().strip() == "yes"


def _compute_stat(values: list, stat: str):
    if stat == "mean":
        return np.mean(values)
    elif stat == "std":
        return np.std(values)
    elif stat == "min":
        return np.min(values)
    elif stat == "max":
        return np.max(values)
    else:
        raise ValueError(f"Unknown {stat=}")


def aggregate_results(
    single_eval_results: list[SingleEvalResult],
    default_stats: tuple[str] = ("mean", "std"),
    name2stats: dict[str, tuple[str]] | None = None,
) -> EvalResult:
    """
    Aggregate results from multiple evaluations into a single EvalResult.
    """
    name2stats = name2stats or {}
    name2values = defaultdict(list)
    htmls = []
    convos = []
    for single_eval_result in single_eval_results:
        for name, value in single_eval_result.metrics.items():
            name2values[name].append(value)
        if single_eval_result.score is not None:
            name2values["score"].append(single_eval_result.score)
        htmls.append(single_eval_result.html)
        convos.append(single_eval_result.convo)
    final_metrics = {}
    for name, values in name2values.items():
        stats = name2stats.get(name, default_stats)
        for stat in stats:
            key = name if stat == "mean" else f"{name}:{stat}"
            final_metrics[key] = _compute_stat(values, stat)
    return EvalResult(
        score=final_metrics.pop("score", None),
        metrics=final_metrics,
        htmls=htmls,
        convos=convos,
    )


def map_with_progress(f: callable, xs: list[Any], num_threads: int):
    """
    Apply f to each element of xs, using a ThreadPool, and show progress.
    """
    if os.getenv("debug"):
        return list(map(f, tqdm(xs, total=len(xs))))
    else:
        with ThreadPool(min(num_threads, len(xs))) as pool:
            return list(tqdm(pool.imap(f, xs), total=len(xs)))


jinja_env = jinja2.Environment(
    loader=jinja2.BaseLoader(),
    undefined=jinja2.StrictUndefined,
    autoescape=jinja2.select_autoescape(["html", "xml"]),
)
_message_template = """
<div class="message {{ role }}">
    <div class="role">
        {{ role }}
        {% if variant %}<span class="variant">({{ variant }})</span>{% endif %}
    </div>
    <div class="content">
        <pre>{{ content }}</pre>
    </div>
</div>
"""


def message_to_html(message: Message) -> str:
    """
    Generate HTML snippet (inside a <div>) for a message.
    """
    return jinja_env.from_string(_message_template).render(
        role=message["role"],
        content=message["content"],
        variant=message.get("variant", None),
    )


jinja_env.globals["message_to_html"] = message_to_html


_report_template = """<!DOCTYPE html>
<html>
    <head>
        <style>
            .message {
                padding: 8px 16px;
                margin-bottom: 8px;
                border-radius: 4px;
            }
            .message.user {
                background-color: #B2DFDB;
                color: #00695C;
            }
            .message.assistant {
                background-color: #B39DDB;
                color: #4527A0;
            }
            .message.system {
                background-color: #EEEEEE;
                color: #212121;
            }
            .role {
                font-weight: bold;
                margin-bottom: 4px;
            }
            .variant {
                color: #795548;
            }
            table, th, td {
                border: 1px solid black;
            }
            pre {
                white-space: pre-wrap;
            }
        </style>
    </head>
    <body>
    {% if metrics %}
    <h1>Metrics</h1>
    <table>
    <tr>
        <th>Metric</th>
        <th>Value</th>
    </tr>
    <tr>
        <td><b>Score</b></td>
        <td>{{ score | float | round(3) }}</td>
    </tr>
    {% for name, value in metrics.items() %}
    <tr>
        <td>{{ name }}</td>
        <td>{{ value }}</td>
    </tr>
    {% endfor %}
    </table>
    {% endif %}
    <h1>Examples</h1>
    {% for html in htmls %}
    {{ html | safe }}
    <hr>
    {% endfor %}
    </body>
</html>
"""


def make_report(eval_result: EvalResult) -> str:
    """
    Create a standalone HTML report from an EvalResult.
    """
    return jinja_env.from_string(_report_template).render(
        score=eval_result.score,
        metrics=eval_result.metrics,
        htmls=eval_result.htmls,
    )


def make_report_from_example_htmls(htmls: list[str]):
    """
    Create a standalone HTML report from a list of example htmls.
    """
    return jinja_env.from_string(_report_template).render(
        score=None, metrics={}, htmls=htmls
    )


def download_dataset(path, url):
    print(f"Downloading dataset {path} from {url}")
    try:
        response = requests.get(url, stream=True)
        response.raise_for_status()

        total_size = int(response.headers.get("content-length", 0))
        block_size = 8192

        with open(path, "wb") as f, tqdm(
            desc="Downloading",
            total=total_size,
            unit="iB",
            unit_scale=True,
            unit_divisor=1024,
        ) as progress_bar:
            for data in response.iter_content(block_size):
                size = f.write(data)
                progress_bar.update(size)

        print(f"Dataset downloaded and saved to {path}")
    except requests.RequestException as e:
        raise Exception(f"Failed to download dataset: {e}")


def set_ulimit(target_soft_limit=65535):
    resource_type = resource.RLIMIT_NOFILE
    current_soft, current_hard = resource.getrlimit(resource_type)

    if current_soft < target_soft_limit:
        try:
            resource.setrlimit(resource_type, (target_soft_limit, current_hard))
        except ValueError as e:
            print(f"Failed to set RLIMIT_NOFILE: {e}")
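These helpers compose: ANSWER_PATTERN_MULTICHOICE pulls the final letter out of a free-form reply, and aggregate_results folds per-sample scores into one EvalResult. A small self-contained sketch (the reply text and scores are made up for illustration):

import re

from sglang.test.simple_eval_common import (
    ANSWER_PATTERN_MULTICHOICE,
    SingleEvalResult,
    aggregate_results,
)

# Extract the final choice the way the MMLU eval does.
reply = "B and D contradict the premise, so...\nAnswer: C"
match = re.search(ANSWER_PATTERN_MULTICHOICE, reply)
assert match and match.group(1) == "C"

# Fold two fabricated per-sample results into a single EvalResult.
results = [
    SingleEvalResult(score=1.0, metrics={"stem": 1.0}),
    SingleEvalResult(score=0.0, metrics={"stem": 0.0}),
]
agg = aggregate_results(results)
print(agg.score, agg.metrics)  # 0.5 {'stem': 0.5, 'stem:std': 0.5, 'score:std': 0.5}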
python/sglang/test/simple_eval_mmlu.py (new file, 120 lines)
@@ -0,0 +1,120 @@
# Adapted from https://github.com/openai/simple-evals/

"""
Measuring Massive Multitask Language Understanding
Dan Hendrycks, Collin Burns, Steven Basart, Andy Zou, Mantas Mazeika, Dawn Song, Jacob Steinhardt
https://arxiv.org/abs/2009.03300
"""

import random
import re

import pandas

from sglang.test import simple_eval_common as common
from sglang.test.simple_eval_common import (
    ANSWER_PATTERN_MULTICHOICE,
    HTML_JINJA,
    Eval,
    EvalResult,
    SamplerBase,
    SingleEvalResult,
    format_multichoice_question,
)

subject2category = {
    "abstract_algebra": "stem",
    "anatomy": "other",
    "astronomy": "stem",
    "business_ethics": "other",
    "clinical_knowledge": "other",
    "college_biology": "stem",
    "college_chemistry": "stem",
    "college_computer_science": "stem",
    "college_mathematics": "stem",
    "college_medicine": "other",
    "college_physics": "stem",
    "computer_security": "stem",
    "conceptual_physics": "stem",
    "econometrics": "social_sciences",
    "electrical_engineering": "stem",
    "elementary_mathematics": "stem",
    "formal_logic": "humanities",
    "global_facts": "other",
    "high_school_biology": "stem",
    "high_school_chemistry": "stem",
    "high_school_computer_science": "stem",
    "high_school_european_history": "humanities",
    "high_school_geography": "social_sciences",
    "high_school_government_and_politics": "social_sciences",
    "high_school_macroeconomics": "social_sciences",
    "high_school_mathematics": "stem",
    "high_school_microeconomics": "social_sciences",
    "high_school_physics": "stem",
    "high_school_psychology": "social_sciences",
    "high_school_statistics": "stem",
    "high_school_us_history": "humanities",
    "high_school_world_history": "humanities",
    "human_aging": "other",
    "human_sexuality": "social_sciences",
    "international_law": "humanities",
    "jurisprudence": "humanities",
    "logical_fallacies": "humanities",
    "machine_learning": "stem",
    "management": "other",
    "marketing": "other",
    "medical_genetics": "other",
    "miscellaneous": "other",
    "moral_disputes": "humanities",
    "moral_scenarios": "humanities",
    "nutrition": "other",
    "philosophy": "humanities",
    "prehistory": "humanities",
    "professional_accounting": "other",
    "professional_law": "humanities",
    "professional_medicine": "other",
    "professional_psychology": "social_sciences",
    "public_relations": "social_sciences",
    "security_studies": "social_sciences",
    "sociology": "social_sciences",
    "us_foreign_policy": "social_sciences",
    "virology": "other",
    "world_religions": "humanities",
}


class MMLUEval(Eval):
    def __init__(self, filename: str, num_examples: int | None, num_threads: int):
        df = pandas.read_csv(filename)
        examples = [row.to_dict() for _, row in df.iterrows()]
        if num_examples:
            examples = random.Random(0).sample(examples, num_examples)
        self.examples = examples
        self.num_threads = num_threads

    def __call__(self, sampler: SamplerBase) -> EvalResult:
        def fn(row: dict):
            prompt_messages = [
                sampler._pack_message(
                    content=format_multichoice_question(row), role="user"
                )
            ]
            response_text = sampler(prompt_messages)
            match = re.search(ANSWER_PATTERN_MULTICHOICE, response_text)
            extracted_answer = match.group(1) if match else None
            score = 1.0 if extracted_answer == row["Answer"] else 0.0
            html = common.jinja_env.from_string(HTML_JINJA).render(
                prompt_messages=prompt_messages,
                next_message=dict(content=response_text, role="assistant"),
                score=score,
                correct_answer=row["Answer"],
                extracted_answer=extracted_answer,
            )
            convo = prompt_messages + [dict(content=response_text, role="assistant")]
            category = subject2category.get(row["Subject"], "other")
            return SingleEvalResult(
                html=html, score=score, metrics={category: score}, convo=convo
            )

        results = common.map_with_progress(fn, self.examples, self.num_threads)
        return common.aggregate_results(results)
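MMLUEval only needs an object that satisfies the SamplerBase call signature, so it can be exercised offline without a server. A sketch with a stub sampler that always answers "A"; the two-row CSV is fabricated for illustration:

import pandas

from sglang.test.simple_eval_mmlu import MMLUEval

# Fabricated two-question dataset in the mmlu.csv column layout.
rows = [
    {"Question": "1+1=?", "A": "2", "B": "3", "C": "4", "D": "5",
     "Answer": "A", "Subject": "elementary_mathematics"},
    {"Question": "Capital of France?", "A": "Rome", "B": "Paris",
     "C": "Oslo", "D": "Cairo", "Answer": "B", "Subject": "global_facts"},
]
pandas.DataFrame(rows).to_csv("/tmp/mini_mmlu.csv", index=False)

class StubSampler:
    """Always replies 'Answer: A'; enough to drive the eval end to end."""
    def _pack_message(self, role, content):
        return {"role": role, "content": content}
    def __call__(self, message_list):
        return "Answer: A"

result = MMLUEval("/tmp/mini_mmlu.csv", num_examples=None, num_threads=1)(StubSampler())
print(result.score)  # 0.5: the stub gets only the first question right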
@@ -1,46 +0,0 @@
from sglang.srt.conversation import generate_chat_conv
from sglang.srt.managers.openai_api.protocol import (
    ChatCompletionMessageContentImagePart,
    ChatCompletionMessageContentImageURL,
    ChatCompletionMessageContentTextPart,
    ChatCompletionMessageGenericParam,
    ChatCompletionMessageUserParam,
    ChatCompletionRequest,
)


def test_chat_completion_to_conv_image():
    """Test that we can convert a chat image request to a convo."""
    request = ChatCompletionRequest(
        model="default",
        messages=[
            ChatCompletionMessageGenericParam(
                role="system", content="You are a helpful AI assistant"
            ),
            ChatCompletionMessageUserParam(
                role="user",
                content=[
                    ChatCompletionMessageContentTextPart(
                        type="text", text="Describe this image"
                    ),
                    ChatCompletionMessageContentImagePart(
                        type="image_url",
                        image_url=ChatCompletionMessageContentImageURL(
                            url="https://someurl.com"
                        ),
                    ),
                ],
            ),
        ],
    )
    conv = generate_chat_conv(request, "vicuna_v1.1")
    assert conv.messages == [
        ["USER", "Describe this image<image>"],
        ["ASSISTANT", None],
    ]
    assert conv.system_message == "You are a helpful AI assistant"
    assert conv.image_data == ["https://someurl.com"]


if __name__ == "__main__":
    test_chat_completion_to_conv_image()
@@ -1,51 +0,0 @@
from sglang.srt.managers.openai_api.protocol import (
    ChatCompletionMessageContentImagePart,
    ChatCompletionMessageContentImageURL,
    ChatCompletionMessageContentTextPart,
    ChatCompletionMessageGenericParam,
    ChatCompletionMessageUserParam,
    ChatCompletionRequest,
)


def test_chat_completion_request_image():
    """Test that Chat Completion Requests with images can be converted."""

    image_request = {
        "model": "default",
        "messages": [
            {"role": "system", "content": "You are a helpful AI assistant"},
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": "Describe this image"},
                    {"type": "image_url", "image_url": {"url": "https://someurl.com"}},
                ],
            },
        ],
        "temperature": 0,
        "max_tokens": 64,
    }
    request = ChatCompletionRequest(**image_request)
    assert len(request.messages) == 2
    assert request.messages[0] == ChatCompletionMessageGenericParam(
        role="system", content="You are a helpful AI assistant"
    )
    assert request.messages[1] == ChatCompletionMessageUserParam(
        role="user",
        content=[
            ChatCompletionMessageContentTextPart(
                type="text", text="Describe this image"
            ),
            ChatCompletionMessageContentImagePart(
                type="image_url",
                image_url=ChatCompletionMessageContentImageURL(
                    url="https://someurl.com"
                ),
            ),
        ],
    )


if __name__ == "__main__":
    test_chat_completion_request_image()
@@ -1,6 +1,8 @@
 """Common utilities for testing and benchmarking"""

 import asyncio
+import subprocess
+import time
 from functools import partial

 import numpy as np
@@ -11,6 +13,8 @@ from sglang.lang.backend.openai import OpenAI
 from sglang.lang.backend.runtime_endpoint import RuntimeEndpoint
 from sglang.utils import get_exception_traceback

+MODEL_NAME_FOR_TEST = "meta-llama/Meta-Llama-3.1-8B-Instruct"
+

 def call_generate_lightllm(prompt, temperature, max_tokens, stop=None, url=None):
     assert url is not None
@@ -379,3 +383,31 @@ def get_call_select(args):
         raise

     return func
+
+
+def popen_launch_server(model, port, timeout, *args):
+    command = [
+        "python3",
+        "-m",
+        "sglang.launch_server",
+        "--model-path",
+        model,
+        "--host",
+        "localhost",
+        "--port",
+        str(port),
+        *args,
+    ]
+    process = subprocess.Popen(command, stdout=None, stderr=None)
+    base_url = f"http://localhost:{port}/v1"
+
+    start_time = time.time()
+    while time.time() - start_time < timeout:
+        try:
+            response = requests.get(f"{base_url}/models")
+            if response.status_code == 200:
+                return process
+        except requests.RequestException:
+            pass
+        time.sleep(10)
+    raise TimeoutError("Server failed to start within the timeout period.")
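popen_launch_server plus run_eval is the whole CI recipe this commit enables. A minimal sketch, assuming these helpers land in sglang.test.test_utils (the modified file's path is not shown in this diff) and that enough GPU memory is available; the 600 s timeout and 0.6 floor are illustrative, not values from the commit:

import argparse

from sglang.test.run_eval import run_eval
from sglang.test.test_utils import MODEL_NAME_FOR_TEST, popen_launch_server

# Launch a server, wait until /v1/models responds, then gate on MMLU accuracy.
process = popen_launch_server(MODEL_NAME_FOR_TEST, 30000, 600)
try:
    args = argparse.Namespace(
        base_url=None, host="localhost", port=30000,
        model=None, eval_name="mmlu", num_examples=64, num_threads=32,
    )
    metrics = run_eval(args)
    assert metrics["score"] >= 0.6  # illustrative accuracy floor
finally:
    process.kill()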