Files
sglang/python/sglang/test/simple_eval_common.py

467 lines
12 KiB
Python
Raw Normal View History

2024-08-01 21:20:17 -07:00
# Adapted from https://github.com/openai/simple-evals/
import os
import resource
import time
from collections import defaultdict
from dataclasses import dataclass, field
from multiprocessing.pool import ThreadPool
2024-08-16 02:13:00 -07:00
from typing import Any, Dict, List, Optional, Tuple
2024-08-01 21:20:17 -07:00
2024-08-02 00:47:23 -07:00
import httpx
2024-08-01 21:20:17 -07:00
import jinja2
import numpy as np
import openai
import requests
from openai import OpenAI
from tqdm import tqdm
# Minimal, generic system prompt.
OPENAI_SYSTEM_MESSAGE_API = "You are a helpful assistant."

# ChatGPT-style system prompt with a fixed knowledge cutoff and current date.
OPENAI_SYSTEM_MESSAGE_CHATGPT = (
    "You are ChatGPT, a large language model trained by OpenAI, based on the GPT-4 architecture."
    "\nKnowledge cutoff: 2023-12"
    "\nCurrent date: 2024-04-01"
)
# A single chat message as a plain dict. Keys read in this file: "role",
# "content" and, optionally, "variant".
Message = Dict[str, Any] # keys role, content
# An ordered conversation (list of messages).
MessageList = List[Message]
2024-08-01 21:20:17 -07:00
class SamplerBase:
    """
    Interface for anything that turns a conversation into a reply.

    A sampler is either the model under evaluation or a helper used while
    grading. Concrete samplers override ``__call__``.
    """

    def __call__(self, message_list: MessageList) -> str:
        """Return the reply to ``message_list``; subclasses must override."""
        raise NotImplementedError
@dataclass
class EvalResult:
    """
    Outcome of a whole evaluation run, usually aggregated over many samples.

    Attributes (in constructor order):
        score: Headline metric, or None when no sample produced a score.
        metrics: Additional named metrics.
        htmls: Per-sample HTML snippets.
        convos: Per-sample sampled conversations.
    """

    score: Optional[float]  # headline metric
    metrics: Optional[Dict[str, float]]  # secondary metrics keyed by name
    htmls: List[str]  # valid HTML strings
    convos: List[MessageList]  # conversations that were sampled
@dataclass
class SingleEvalResult:
    """
    Outcome of grading one sample.

    Attributes (in constructor order):
        score: Score for this sample, or None if it was not scored.
        metrics: Extra per-sample metrics; each instance gets its own dict.
        html: Optional HTML rendering of the sample.
        convo: Optional sampled conversation.
    """

    score: Optional[float]
    metrics: Dict[str, float] = field(default_factory=dict)
    html: Optional[str] = None
    convo: Optional[MessageList] = None  # conversation that was sampled
class Eval:
    """
    Interface for an evaluation: call it with a sampler to obtain an EvalResult.

    Concrete evaluations override ``__call__``.
    """

    def __call__(self, sampler: SamplerBase) -> EvalResult:
        """Run the evaluation using ``sampler``; subclasses must override."""
        raise NotImplementedError
class LargerHttpxClient(httpx.Client):
    """
    ``httpx.Client`` with a 3600-second timeout and a 3600-connection pool.

    These are much larger than httpx's defaults, so many long-running requests
    can share one client. ChatCompletionSampler passes it to ``OpenAI``.
    """

    def __init__(self):
        super().__init__(
            timeout=httpx.Timeout(3600),
            limits=httpx.Limits(
                max_connections=3600,
                max_keepalive_connections=3600,
            ),
        )
class ChatCompletionSampler(SamplerBase):
    """
    Sample from an OpenAI-compatible chat completion API.

    Args:
        base_url: Base URL of the API server, passed to ``OpenAI``.
        model: Model name to request. When None, the id of the first model
            returned by the server's model listing is used.
        system_message: Optional system prompt prepended to every request.
        temperature: Sampling temperature used for every request.
        max_tokens: Completion token limit used for every request.
        max_retries: How many times to retry after an exception other than
            ``openai.BadRequestError`` before re-raising it. None (the
            default) retries indefinitely, as before.
        max_backoff: Upper bound, in seconds, on the exponential back-off
            sleep between retries, so a long outage cannot produce hour- or
            day-long sleeps.
    """

    def __init__(
        self,
        base_url: Optional[str] = None,
        model: Optional[str] = None,
        system_message: Optional[str] = None,
        temperature: float = 0.0,
        max_tokens: int = 2048,
        max_retries: Optional[int] = None,
        max_backoff: float = 60.0,
    ):
        self.client = OpenAI(base_url=base_url, http_client=LargerHttpxClient())
        if model is None:
            # No explicit model: use the first one the server advertises.
            model = self.client.models.list().data[0].id
        self.model = model
        self.system_message = system_message
        self.temperature = temperature
        self.max_tokens = max_tokens
        self.max_retries = max_retries
        self.max_backoff = max_backoff
        self.image_format = "url"

    def _handle_image(
        self,
        image: str,
        encoding: str = "base64",
        format: str = "png",
        fovea: int = 768,
    ):
        """
        Wrap an already-encoded image as an ``image_url`` content part.

        The image is embedded as a data URL. ``fovea`` is accepted for
        interface compatibility but is not used.
        """
        new_image = {
            "type": "image_url",
            "image_url": {
                "url": f"data:image/{format};{encoding},{image}",
            },
        }
        return new_image

    def _handle_text(self, text: str):
        """Wrap ``text`` as a ``text`` content part."""
        return {"type": "text", "text": text}

    def _pack_message(self, role: str, content: Any):
        """Build a chat message dict from a role and its content."""
        return {"role": str(role), "content": content}

    def __call__(self, message_list: MessageList) -> str:
        """
        Send ``message_list`` and return the reply text.

        The system message, if any, is prepended to the conversation. Returns
        an empty string when the server rejects the request with
        ``openai.BadRequestError`` or the reply has no text content. Any other
        exception is retried with capped exponential back-off. Once
        ``max_retries`` retries have been spent, the last exception is
        re-raised.
        """
        if self.system_message:
            message_list = [
                self._pack_message("system", self.system_message)
            ] + message_list
        trial = 0
        while True:
            try:
                response = self.client.chat.completions.create(
                    model=self.model,
                    messages=message_list,
                    temperature=self.temperature,
                    max_tokens=self.max_tokens,
                )
                # The SDK types ``message.content`` as optional; normalise a
                # missing reply to "" so callers always receive a str.
                return response.choices[0].message.content or ""
            except openai.BadRequestError as e:
                # A rejected request (observed e.g. on MMMU) will not succeed
                # on retry, so it is scored as an empty reply instead.
                print("Bad Request Error", e)
                return ""
            except Exception as e:
                # Unknown errors are re-raised once the retry budget is spent.
                if self.max_retries is not None and trial >= self.max_retries:
                    raise
                # Exponential back-off, capped so the sleep stays bounded.
                exception_backoff = min(2**trial, self.max_backoff)
                print(
                    f"Rate limit exception so wait and retry {trial} after {exception_backoff} sec",
                    e,
                )
                time.sleep(exception_backoff)
                trial += 1
# Prompt for four-option multiple-choice questions. It is filled with
# str.format using the keys Question, A, B, C and D (see
# format_multichoice_question); the model is told to end with "Answer: $LETTER".
QUERY_TEMPLATE_MULTICHOICE = """
Answer the following multiple choice question. The last line of your response should be of the following format: 'Answer: $LETTER' (without quotes) where LETTER is one of ABCD. Think step by step before answering.
{Question}
A) {A}
B) {B}
C) {C}
D) {D}
""".strip()
# Captures the chosen option letter after "Answer:". Because of (?i), both the
# word "Answer" and the letter match case-insensitively (so a-d also match).
ANSWER_PATTERN_MULTICHOICE = r"(?i)Answer\s*:\s*([A-D])"
# Captures free-form answer text after "Answer:" up to the end of that line.
ANSWER_PATTERN = r"(?i)Answer\s*:\s*([^\n]+)"
# Grader prompt asking a model whether two answers are equivalent, with
# few-shot examples. It uses %-style named placeholders (expression1,
# expression2); check_equality fills it and treats a "Yes" reply as a match.
EQUALITY_TEMPLATE = r"""
Look at the following two expressions (answers to a math problem) and judge whether they are equivalent. Only perform trivial simplifications
Examples:
Expression 1: $2x+3$
Expression 2: $3+2x$
Yes
Expression 1: 3/2
Expression 2: 1.5
Yes
Expression 1: $x^2+2x+1$
Expression 2: $y^2+2y+1$
No
Expression 1: $x^2+2x+1$
Expression 2: $(x+1)^2$
Yes
Expression 1: 3245/5
Expression 2: 649
No
(these are actually equal, don't mark them equivalent if you need to do nontrivial simplifications)
Expression 1: 2/(-3)
Expression 2: -2/3
Yes
(trivial simplifications are allowed)
Expression 1: 72 degrees
Expression 2: 72
Yes
(give benefit of the doubt to units)
Expression 1: 64
Expression 2: 64 square feet
Yes
(give benefit of the doubt to units)
---
YOUR TASK
Respond with only "Yes" or "No" (without quotes). Do not include a rationale.
Expression 1: %(expression1)s
Expression 2: %(expression2)s
""".strip()
# Jinja template for one graded sample: the prompt conversation, the sampled
# reply and the grading outcome. Context variables: prompt_messages,
# next_message, correct_answer, extracted_answer, score. It calls
# message_to_html, so it must be rendered through an environment exposing
# that function as a global (jinja_env registers it).
HTML_JINJA = """
<h3>Prompt conversation</h3>
{% for message in prompt_messages %}
{{ message_to_html(message) | safe }}
{% endfor %}
<h3>Sampled message</h3>
{{ message_to_html(next_message) | safe }}
<h3>Results</h3>
<p>Correct Answer: {{ correct_answer }}</p>
<p>Extracted Answer: {{ extracted_answer }}</p>
<p>Score: {{ score }}</p>
"""
def format_multichoice_question(row):
    """
    Fill QUERY_TEMPLATE_MULTICHOICE from ``row``.

    ``row`` is unpacked as keyword arguments, so it must be mapping-like and
    provide at least the "Question", "A", "B", "C" and "D" keys.
    """
    prompt = QUERY_TEMPLATE_MULTICHOICE.format(**row)
    return prompt
def check_equality(sampler: SamplerBase, expr1: str, expr2: str):
    """
    Ask ``sampler`` whether two math answers are equivalent.

    The expressions are substituted into EQUALITY_TEMPLATE and sent as a single
    user message. The reply counts as a match only if it equals "yes" after
    lower-casing and stripping surrounding whitespace.

    Returns:
        True if the sampler answered "Yes"; otherwise False, including when
        the sampler returned an empty or None reply.
    """
    prompt = EQUALITY_TEMPLATE % {"expression1": expr1, "expression2": expr2}
    response = sampler([dict(content=prompt, role="user")])
    # Samplers are typed to return str, but a chat API can yield None content;
    # treat that as "No" instead of crashing on .lower().
    return (response or "").lower().strip() == "yes"
def _compute_stat(values: list, stat: str):
if stat == "mean":
return np.mean(values)
elif stat == "std":
return np.std(values)
elif stat == "min":
return np.min(values)
elif stat == "max":
return np.max(values)
else:
raise ValueError(f"Unknown {stat =}")
def aggregate_results(
    single_eval_results: List[SingleEvalResult],
    default_stats: Tuple[str, ...] = ("mean", "std"),
    name2stats: Optional[Dict[str, Tuple[str, ...]]] = None,
) -> EvalResult:
    """
    Aggregate results from multiple evaluations into a single EvalResult.

    Every per-sample metric, plus each per-sample ``score`` that is not None,
    is collected into a list and reduced with each requested statistic. A
    "mean" statistic is stored under the bare metric name; any other
    statistic is stored as "<name>:<stat>".

    Args:
        single_eval_results: Per-sample results to aggregate.
        default_stats: Statistic names used for metrics with no entry in
            ``name2stats`` (any number of names, hence ``Tuple[str, ...]``).
        name2stats: Optional per-metric override of the statistics to compute.

    Returns:
        An EvalResult whose ``score`` is the value stored under the bare key
        "score" (the mean score when "mean" is among the score's stats,
        otherwise None). Its ``metrics`` hold all other computed statistics,
        and its ``htmls``/``convos`` list every sample's ``html``/``convo`` in
        input order.
    """
    name2stats = name2stats or {}
    name2values = defaultdict(list)
    htmls = []
    convos = []
    for single_eval_result in single_eval_results:
        for name, value in single_eval_result.metrics.items():
            name2values[name].append(value)
        if single_eval_result.score is not None:
            name2values["score"].append(single_eval_result.score)
        htmls.append(single_eval_result.html)
        convos.append(single_eval_result.convo)
    final_metrics = {}
    for name, values in name2values.items():
        stats = name2stats.get(name, default_stats)
        for stat in stats:
            # "mean" keeps the bare metric name; other stats get a suffix.
            key = name if stat == "mean" else f"{name}:{stat}"
            final_metrics[key] = _compute_stat(values, stat)
    return EvalResult(
        score=final_metrics.pop("score", None),
        metrics=final_metrics,
        htmls=htmls,
        convos=convos,
    )
def map_with_progress(f: callable, xs: List[Any], num_threads: int):
    """
    Apply f to each element of xs, using a ThreadPool, and show progress.

    Results are returned as a list in the same order as ``xs``. When the
    ``debug`` environment variable is set to a non-empty value, the work runs
    sequentially in the calling thread instead.

    Args:
        f: Function applied to each element.
        xs: Input items; an empty sequence yields an empty list.
        num_threads: Maximum number of worker threads. The pool never has
            more threads than ``len(xs)`` and always has at least one.

    Returns:
        The list ``[f(x) for x in xs]``.
    """
    # ThreadPool(0) raises ValueError, so short-circuit empty input. len() is
    # used rather than truthiness so array-like inputs also work.
    if len(xs) == 0:
        return []
    if os.getenv("debug"):
        return list(map(f, tqdm(xs, total=len(xs))))
    # Clamp to at least one worker so a non-positive num_threads cannot crash.
    with ThreadPool(max(1, min(num_threads, len(xs)))) as pool:
        return list(tqdm(pool.imap(f, xs), total=len(xs)))
# Shared Jinja environment for the HTML templates in this module.
# - StrictUndefined: a missing template variable raises instead of rendering
#   as empty text.
# - select_autoescape(...): autoescaping; for templates built with
#   from_string, jinja2 applies its default_for_string setting. Pre-rendered
#   HTML is therefore inserted with ``| safe``.
jinja_env = jinja2.Environment(
    autoescape=jinja2.select_autoescape(["html", "xml"]),
    undefined=jinja2.StrictUndefined,
    loader=jinja2.BaseLoader(),
)
_message_template = """
<div class="message {{ role }}">
<div class="role">
{{ role }}
2024-08-01 21:20:17 -07:00
{% if variant %}<span class="variant">({{ variant }})</span>{% endif %}
</div>
<div class="content">
<pre>{{ content }}</pre>
</div>
</div>
"""
def message_to_html(message: Message) -> str:
    """
    Render one chat message as an HTML ``<div>`` snippet.

    ``message`` must provide "role" and "content"; an optional "variant"
    entry is shown next to the role.
    """
    fields = {
        "role": message["role"],
        "content": message["content"],
        "variant": message.get("variant"),
    }
    template = jinja_env.from_string(_message_template)
    return template.render(**fields)
# Make message_to_html callable from any template built with jinja_env
# (HTML_JINJA calls it for each message).
jinja_env.globals["message_to_html"] = message_to_html
# Jinja template for a standalone HTML report page. Context variables:
# score, metrics (mapping of name -> value) and htmls (pre-rendered example
# snippets, inserted unescaped via ``| safe``). The metrics table, including
# the Score row, is rendered only when ``metrics`` is truthy.
# NOTE(review): jinja2's ``float`` filter returns 0.0 when conversion fails,
# so a None score combined with non-empty metrics would display as 0.0.
_report_template = """<!DOCTYPE html>
<html>
<head>
<style>
.message {
padding: 8px 16px;
margin-bottom: 8px;
border-radius: 4px;
}
.message.user {
background-color: #B2DFDB;
color: #00695C;
}
.message.assistant {
background-color: #B39DDB;
color: #4527A0;
}
.message.system {
background-color: #EEEEEE;
color: #212121;
}
.role {
font-weight: bold;
margin-bottom: 4px;
}
.variant {
color: #795548;
}
table, th, td {
border: 1px solid black;
}
pre {
white-space: pre-wrap;
}
</style>
</head>
<body>
{% if metrics %}
<h1>Metrics</h1>
<table>
<tr>
<th>Metric</th>
<th>Value</th>
</tr>
<tr>
<td><b>Score</b></td>
<td>{{ score | float | round(3) }}</td>
</tr>
{% for name, value in metrics.items() %}
<tr>
<td>{{ name }}</td>
<td>{{ value }}</td>
</tr>
{% endfor %}
</table>
{% endif %}
<h1>Examples</h1>
{% for html in htmls %}
{{ html | safe }}
<hr>
{% endfor %}
</body>
</html>
"""
def make_report(eval_result: EvalResult) -> str:
    """
    Render ``eval_result`` as a standalone HTML page.

    The page contains a metrics table (when ``eval_result.metrics`` is
    non-empty) followed by every example snippet in ``eval_result.htmls``.
    """
    context = dict(
        score=eval_result.score,
        metrics=eval_result.metrics,
        htmls=eval_result.htmls,
    )
    template = jinja_env.from_string(_report_template)
    return template.render(**context)
def make_report_from_example_htmls(htmls: List[str]):
    """
    Render a standalone HTML page that shows only the given example snippets.

    Empty metrics are passed, so the metrics table is omitted from the page.
    """
    template = jinja_env.from_string(_report_template)
    return template.render(htmls=htmls, metrics={}, score=None)
def download_dataset(path, url):
    """
    Download ``url`` into the local file ``path``, showing a progress bar.

    The response body is streamed in 8 KiB chunks. If the download fails
    after ``path`` has been opened for writing, the partial file is deleted.
    This prevents a later run from mistaking a truncated file for a complete
    dataset. A pre-existing file is left alone when the request fails before
    writing starts.

    Raises:
        Exception: if the request fails (HTTP error status, connection
            problem, interrupted stream, ...). The original
            ``requests.RequestException`` is chained as ``__cause__``.
    """
    print(f"Downloading dataset {path} from {url}")
    file_opened = False
    try:
        # ``with`` releases the streamed connection even when an error occurs.
        with requests.get(url, stream=True) as response:
            response.raise_for_status()
            total_size = int(response.headers.get("content-length", 0))
            block_size = 8192
            with open(path, "wb") as f, tqdm(
                desc="Downloading",
                total=total_size,
                unit="iB",
                unit_scale=True,
                unit_divisor=1024,
            ) as progress_bar:
                file_opened = True
                for data in response.iter_content(block_size):
                    size = f.write(data)
                    progress_bar.update(size)
        print(f"Dataset downloaded and saved to {path}")
    except requests.RequestException as e:
        if file_opened:
            try:
                os.remove(path)
            except OSError:
                pass  # best effort: the download error below is what matters
        raise Exception(f"Failed to download dataset: {e}") from e
def set_ulimit(target_soft_limit=65535):
    """
    Raise the soft limit on open file descriptors (RLIMIT_NOFILE).

    The soft limit is raised toward ``target_soft_limit`` but never above the
    current hard limit, and it is never lowered. In particular, an
    already-unlimited soft limit is left untouched. Failures are reported with
    a print rather than raised.
    """
    resource_type = resource.RLIMIT_NOFILE
    current_soft, current_hard = resource.getrlimit(resource_type)
    # RLIM_INFINITY can be a negative sentinel (e.g. -1 on Linux), so compare
    # against it explicitly; a plain "<" would wrongly *lower* an unlimited soft
    # limit.
    if current_soft == resource.RLIM_INFINITY:
        return
    # The soft limit may not exceed the hard limit. Clamp instead of
    # requesting an impossible value and failing outright.
    if current_hard == resource.RLIM_INFINITY:
        new_soft = target_soft_limit
    else:
        new_soft = min(target_soft_limit, current_hard)
    if current_soft < new_soft:
        try:
            resource.setrlimit(resource_type, (new_soft, current_hard))
        except (ValueError, OSError) as e:
            print(f"Fail to set RLIMIT_NOFILE: {e}")