Add accuracy test to CI: MMLU (#882)
.github/workflows/e2e-test.yml

@@ -18,7 +18,7 @@ concurrency:
   cancel-in-progress: true
 
 jobs:
-  pr-e2e-test:
+  e2e-test:
     runs-on: self-hosted
 
     env:
@@ -38,7 +38,7 @@ jobs:
           pip install flashinfer -i https://flashinfer.ai/whl/cu121/torch2.3/ --force-reinstall
           pip install --upgrade transformers
 
-      - name: Benchmark Serving
+      - name: Benchmark Serving Throughput
         run: |
           cd /data/zhyncs/venv && source ./bin/activate && cd -
           python3 -m sglang.launch_server --model meta-llama/Meta-Llama-3.1-8B-Instruct --port 8413 --disable-radix-cache &
.github/workflows/unit-test.yml

@@ -59,3 +59,10 @@ jobs:
 
           cd test/srt
           python3 test_openai_server.py
+
+      - name: Test Accuracy
+        run: |
+          cd /data/zhyncs/venv && source ./bin/activate && cd -
+
+          cd test/srt
+          python3 test_eval_accuracy.py
python/sglang/bench_serving.py

@@ -21,7 +21,7 @@ import sys
 import time
 import traceback
 import warnings
-from argparse import ArgumentParser as FlexibleArgumentParser
+from argparse import ArgumentParser
 from dataclasses import dataclass, field
 from datetime import datetime
 from typing import AsyncGenerator, List, Optional, Tuple, Union
@@ -868,14 +868,12 @@ def set_ulimit(target_soft_limit=65535):
 
 
 if __name__ == "__main__":
-    parser = FlexibleArgumentParser(
-        description="Benchmark the online serving throughput."
-    )
+    parser = ArgumentParser(description="Benchmark the online serving throughput.")
     parser.add_argument(
         "--backend",
         type=str,
-        required=True,
         choices=list(ASYNC_REQUEST_FUNCS.keys()),
+        default="sglang",
         help="Must specify a backend, depending on the LLM Inference Engine.",
     )
     parser.add_argument(
python/sglang/test/run_eval.py (new file, 99 lines)

"""
Usage:
python3 -m sglang.test.run_eval --port 30000 --eval-name mmlu --num-examples 10
"""

import argparse
import json
import os
import time

from sglang.test.simple_eval_common import (
    ChatCompletionSampler,
    download_dataset,
    make_report,
    set_ulimit,
)
from sglang.test.simple_eval_mmlu import MMLUEval


def run_eval(args):
    if "OPENAI_API_KEY" not in os.environ:
        os.environ["OPENAI_API_KEY"] = "EMPTY"

    base_url = (
        f"{args.base_url}/v1" if args.base_url else f"http://{args.host}:{args.port}/v1"
    )

    if args.eval_name == "mmlu":
        dataset_path = "mmlu.csv"

        if not os.path.exists(dataset_path):
            download_dataset(
                dataset_path,
                "https://openaipublic.blob.core.windows.net/simple-evals/mmlu.csv",
            )
        eval_obj = MMLUEval(dataset_path, args.num_examples, args.num_threads)
    else:
        raise ValueError(f"Invalid eval name: {args.eval_name}")

    sampler = ChatCompletionSampler(
        model=args.model,
        max_tokens=2048,
        base_url=base_url,
    )

    # Run eval
    tic = time.time()
    result = eval_obj(sampler)
    latency = time.time() - tic

    # Dump reports
    metrics = result.metrics | {"score": result.score}
    file_stem = f"mmlu_{sampler.model.replace('/', '_')}"
    report_filename = f"/tmp/{file_stem}.html"
    print(f"Writing report to {report_filename}")
    with open(report_filename, "w") as fh:
        fh.write(make_report(result))
    print(metrics)
    result_filename = f"/tmp/{file_stem}.json"
    with open(result_filename, "w") as f:
        f.write(json.dumps(metrics, indent=2))
    print(f"Writing results to {result_filename}")

    # Print results
    print(f"Total latency: {latency:.3f} s")
    print(f"Score: {metrics['score']:.3f}")

    return metrics


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--base-url",
        type=str,
        default=None,
        help="Server or API base url if not using http host and port.",
    )
    parser.add_argument(
        "--host", type=str, default="0.0.0.0", help="Default host is 0.0.0.0."
    )
    parser.add_argument(
        "--port",
        type=int,
        help="If not set, the default port is configured according to its default value for different LLM Inference Engines.",
    )
    parser.add_argument(
        "--model",
        type=str,
        help="Name or path of the model. If not set, the default model will request /v1/models for conf.",
    )
    parser.add_argument("--eval-name", type=str, default="mmlu")
    parser.add_argument("--num-examples", type=int)
    parser.add_argument("--num-threads", type=int, default=64)
    set_ulimit()
    args = parser.parse_args()

    run_eval(args)
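As a quick illustration of the new entry point, here is a minimal sketch of driving run_eval programmatically, assuming an sglang server is already listening on port 30000 (the port and counts are illustrative; this mirrors how test/srt/test_eval_accuracy.py below builds its arguments):

from types import SimpleNamespace

from sglang.test.run_eval import run_eval

args = SimpleNamespace(
    base_url="http://localhost:30000",
    model=None,  # resolved by querying /v1/models when unset
    eval_name="mmlu",
    num_examples=10,
    num_threads=16,
)
metrics = run_eval(args)
print(metrics["score"])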
python/sglang/test/simple_eval_common.py (new file, 456 lines)

# Adapted from https://github.com/openai/simple-evals/

import base64
import os
import resource
import time
from collections import defaultdict
from dataclasses import dataclass, field
from multiprocessing.pool import ThreadPool
from typing import Any

import jinja2
import numpy as np
import openai
import requests
from openai import OpenAI
from tqdm import tqdm

OPENAI_SYSTEM_MESSAGE_API = "You are a helpful assistant."
OPENAI_SYSTEM_MESSAGE_CHATGPT = (
    "You are ChatGPT, a large language model trained by OpenAI, based on the GPT-4 architecture."
    + "\nKnowledge cutoff: 2023-12\nCurrent date: 2024-04-01"
)


Message = dict[str, Any]  # keys role, content
MessageList = list[Message]


class SamplerBase:
    """
    Base class for defining a sampling model, which can be evaluated
    or used as part of the grading process.
    """

    def __call__(self, message_list: MessageList) -> str:
        raise NotImplementedError()


@dataclass
class EvalResult:
    """
    Result of running an evaluation (usually consisting of many samples).
    """

    score: float | None  # top-line metric
    metrics: dict[str, float] | None  # other metrics
    htmls: list[str]  # strings of valid HTML
    convos: list[MessageList]  # sampled conversations


@dataclass
class SingleEvalResult:
    """
    Result of evaluating a single sample.
    """

    score: float | None
    metrics: dict[str, float] = field(default_factory=dict)
    html: str | None = None
    convo: MessageList | None = None  # sampled conversation


class Eval:
    """
    Base class for defining an evaluation.
    """

    def __call__(self, sampler: SamplerBase) -> EvalResult:
        raise NotImplementedError()


class ChatCompletionSampler(SamplerBase):
    """
    Sample from OpenAI's chat completion API.
    """

    def __init__(
        self,
        base_url: str = None,
        model: str | None = None,
        system_message: str | None = None,
        temperature: float = 0.0,
        max_tokens: int = 2048,
    ):
        self.client = OpenAI(base_url=base_url)

        if model is None:
            model = self.client.models.list().data[0].id

        self.model = model
        self.system_message = system_message
        self.temperature = temperature
        self.max_tokens = max_tokens
        self.image_format = "url"

    def _handle_image(
        self,
        image: str,
        encoding: str = "base64",
        format: str = "png",
        fovea: int = 768,
    ):
        new_image = {
            "type": "image_url",
            "image_url": {
                "url": f"data:image/{format};{encoding},{image}",
            },
        }
        return new_image

    def _handle_text(self, text: str):
        return {"type": "text", "text": text}

    def _pack_message(self, role: str, content: Any):
        return {"role": str(role), "content": content}

    def __call__(self, message_list: MessageList) -> str:
        if self.system_message:
            message_list = [
                self._pack_message("system", self.system_message)
            ] + message_list
        trial = 0
        while True:
            try:
                response = self.client.chat.completions.create(
                    model=self.model,
                    messages=message_list,
                    temperature=self.temperature,
                    max_tokens=self.max_tokens,
                )
                return response.choices[0].message.content
            # NOTE: BadRequestError is triggered once for MMMU; please uncomment if you are rerunning MMMU
            except openai.BadRequestError as e:
                print("Bad Request Error", e)
                return ""
            except Exception as e:
                exception_backoff = 2**trial  # exponential backoff
                print(
                    f"Rate limit exception so wait and retry {trial} after {exception_backoff} sec",
                    e,
                )
                time.sleep(exception_backoff)
                trial += 1
        # unknown error shall throw exception
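A minimal usage sketch for the sampler above (illustrative, assuming an OpenAI-compatible endpoint is already serving on localhost:30000):

import os

from sglang.test.simple_eval_common import ChatCompletionSampler

os.environ.setdefault("OPENAI_API_KEY", "EMPTY")  # the client needs a key even for local servers
sampler = ChatCompletionSampler(base_url="http://localhost:30000/v1")  # model auto-resolved via /v1/models
print(sampler([{"role": "user", "content": "What is 2 + 2?"}]))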
QUERY_TEMPLATE_MULTICHOICE = """
Answer the following multiple choice question. The last line of your response should be of the following format: 'Answer: $LETTER' (without quotes) where LETTER is one of ABCD. Think step by step before answering.

{Question}

A) {A}
B) {B}
C) {C}
D) {D}
""".strip()

ANSWER_PATTERN_MULTICHOICE = r"(?i)Answer\s*:\s*([A-D])"
ANSWER_PATTERN = r"(?i)Answer\s*:\s*([^\n]+)"
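To see how the multiple-choice pattern is applied, a standalone illustration (not part of the commit):

import re

ANSWER_PATTERN_MULTICHOICE = r"(?i)Answer\s*:\s*([A-D])"

response_text = "The slope of y = x^2 at x = 1 is 2.\nAnswer: C"
match = re.search(ANSWER_PATTERN_MULTICHOICE, response_text)
print(match.group(1) if match else None)  # prints: C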
EQUALITY_TEMPLATE = r"""
Look at the following two expressions (answers to a math problem) and judge whether they are equivalent. Only perform trivial simplifications

Examples:

Expression 1: $2x+3$
Expression 2: $3+2x$

Yes

Expression 1: 3/2
Expression 2: 1.5

Yes

Expression 1: $x^2+2x+1$
Expression 2: $y^2+2y+1$

No

Expression 1: $x^2+2x+1$
Expression 2: $(x+1)^2$

Yes

Expression 1: 3245/5
Expression 2: 649

No
(these are actually equal, don't mark them equivalent if you need to do nontrivial simplifications)

Expression 1: 2/(-3)
Expression 2: -2/3

Yes
(trivial simplifications are allowed)

Expression 1: 72 degrees
Expression 2: 72

Yes
(give benefit of the doubt to units)

Expression 1: 64
Expression 2: 64 square feet

Yes
(give benefit of the doubt to units)

---

YOUR TASK


Respond with only "Yes" or "No" (without quotes). Do not include a rationale.

Expression 1: %(expression1)s
Expression 2: %(expression2)s
""".strip()


HTML_JINJA = """
<h3>Prompt conversation</h3>
{% for message in prompt_messages %}
{{ message_to_html(message) | safe }}
{% endfor %}
<h3>Sampled message</h3>
{{ message_to_html(next_message) | safe }}
<h3>Results</h3>
<p>Correct Answer: {{ correct_answer }}</p>
<p>Extracted Answer: {{ extracted_answer }}</p>
<p>Score: {{ score }}</p>
"""


def format_multichoice_question(row):
    return QUERY_TEMPLATE_MULTICHOICE.format(**row)


def check_equality(sampler: SamplerBase, expr1: str, expr2: str):
    prompt = EQUALITY_TEMPLATE % {"expression1": expr1, "expression2": expr2}
    response = sampler([dict(content=prompt, role="user")])
    return response.lower().strip() == "yes"


def _compute_stat(values: list, stat: str):
    if stat == "mean":
        return np.mean(values)
    elif stat == "std":
        return np.std(values)
    elif stat == "min":
        return np.min(values)
    elif stat == "max":
        return np.max(values)
    else:
        raise ValueError(f"Unknown {stat =}")


def aggregate_results(
    single_eval_results: list[SingleEvalResult],
    default_stats: tuple[str] = ("mean", "std"),
    name2stats: dict[str, tuple[str]] | None = None,
) -> EvalResult:
    """
    Aggregate results from multiple evaluations into a single EvalResult.
    """
    name2stats = name2stats or {}
    name2values = defaultdict(list)
    htmls = []
    convos = []
    for single_eval_result in single_eval_results:
        for name, value in single_eval_result.metrics.items():
            name2values[name].append(value)
        if single_eval_result.score is not None:
            name2values["score"].append(single_eval_result.score)
        htmls.append(single_eval_result.html)
        convos.append(single_eval_result.convo)
    final_metrics = {}
    for name, values in name2values.items():
        stats = name2stats.get(name, default_stats)
        for stat in stats:
            key = name if stat == "mean" else f"{name}:{stat}"
            final_metrics[key] = _compute_stat(values, stat)
    return EvalResult(
        score=final_metrics.pop("score", None),
        metrics=final_metrics,
        htmls=htmls,
        convos=convos,
    )
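A small worked example of aggregate_results (illustrative): the mean of each metric keeps the bare name, other stats get a ":stat" suffix, and "score" is popped out to the top level.

from sglang.test.simple_eval_common import SingleEvalResult, aggregate_results

results = [
    SingleEvalResult(score=1.0, metrics={"stem": 1.0}),
    SingleEvalResult(score=0.0, metrics={"stem": 0.0}),
]
agg = aggregate_results(results)
print(agg.score)    # 0.5
print(agg.metrics)  # {'stem': 0.5, 'stem:std': 0.5, 'score:std': 0.5}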
def map_with_progress(f: callable, xs: list[Any], num_threads: int):
    """
    Apply f to each element of xs, using a ThreadPool, and show progress.
    """
    if os.getenv("debug"):
        return list(map(f, tqdm(xs, total=len(xs))))
    else:
        with ThreadPool(min(num_threads, len(xs))) as pool:
            return list(tqdm(pool.imap(f, xs), total=len(xs)))


jinja_env = jinja2.Environment(
    loader=jinja2.BaseLoader(),
    undefined=jinja2.StrictUndefined,
    autoescape=jinja2.select_autoescape(["html", "xml"]),
)
_message_template = """
<div class="message {{ role }}">
    <div class="role">
        {{ role }}
        {% if variant %}<span class="variant">({{ variant }})</span>{% endif %}
    </div>
    <div class="content">
        <pre>{{ content }}</pre>
    </div>
</div>
"""


def message_to_html(message: Message) -> str:
    """
    Generate HTML snippet (inside a <div>) for a message.
    """
    return jinja_env.from_string(_message_template).render(
        role=message["role"],
        content=message["content"],
        variant=message.get("variant", None),
    )


jinja_env.globals["message_to_html"] = message_to_html


_report_template = """<!DOCTYPE html>
<html>
    <head>
        <style>
            .message {
                padding: 8px 16px;
                margin-bottom: 8px;
                border-radius: 4px;
            }
            .message.user {
                background-color: #B2DFDB;
                color: #00695C;
            }
            .message.assistant {
                background-color: #B39DDB;
                color: #4527A0;
            }
            .message.system {
                background-color: #EEEEEE;
                color: #212121;
            }
            .role {
                font-weight: bold;
                margin-bottom: 4px;
            }
            .variant {
                color: #795548;
            }
            table, th, td {
                border: 1px solid black;
            }
            pre {
                white-space: pre-wrap;
            }
        </style>
    </head>
    <body>
    {% if metrics %}
    <h1>Metrics</h1>
    <table>
        <tr>
            <th>Metric</th>
            <th>Value</th>
        </tr>
        <tr>
            <td><b>Score</b></td>
            <td>{{ score | float | round(3) }}</td>
        </tr>
        {% for name, value in metrics.items() %}
        <tr>
            <td>{{ name }}</td>
            <td>{{ value }}</td>
        </tr>
        {% endfor %}
    </table>
    {% endif %}
    <h1>Examples</h1>
    {% for html in htmls %}
    {{ html | safe }}
    <hr>
    {% endfor %}
    </body>
</html>
"""


def make_report(eval_result: EvalResult) -> str:
    """
    Create a standalone HTML report from an EvalResult.
    """
    return jinja_env.from_string(_report_template).render(
        score=eval_result.score,
        metrics=eval_result.metrics,
        htmls=eval_result.htmls,
    )


def make_report_from_example_htmls(htmls: list[str]):
    """
    Create a standalone HTML report from a list of example htmls.
    """
    return jinja_env.from_string(_report_template).render(
        score=None, metrics={}, htmls=htmls
    )


def download_dataset(path, url):
    print(f"Downloading dataset {path} from {url}")
    try:
        response = requests.get(url, stream=True)
        response.raise_for_status()

        total_size = int(response.headers.get("content-length", 0))
        block_size = 8192

        with open(path, "wb") as f, tqdm(
            desc="Downloading",
            total=total_size,
            unit="iB",
            unit_scale=True,
            unit_divisor=1024,
        ) as progress_bar:
            for data in response.iter_content(block_size):
                size = f.write(data)
                progress_bar.update(size)

        print(f"Dataset downloaded and saved to {path}")
    except requests.RequestException as e:
        raise Exception(f"Failed to download dataset: {e}")


def set_ulimit(target_soft_limit=65535):
    resource_type = resource.RLIMIT_NOFILE
    current_soft, current_hard = resource.getrlimit(resource_type)

    if current_soft < target_soft_limit:
        try:
            resource.setrlimit(resource_type, (target_soft_limit, current_hard))
        except ValueError as e:
            print(f"Fail to set RLIMIT_NOFILE: {e}")
python/sglang/test/simple_eval_mmlu.py (new file, 120 lines)

# Adapted from https://github.com/openai/simple-evals/

"""
Measuring Massive Multitask Language Understanding
Dan Hendrycks, Collin Burns, Steven Basart, Andy Zou, Mantas Mazeika, Dawn Song, Jacob Steinhardt
https://arxiv.org/abs/2009.03300
"""

import random
import re

import pandas

from sglang.test import simple_eval_common as common
from sglang.test.simple_eval_common import (
    ANSWER_PATTERN_MULTICHOICE,
    HTML_JINJA,
    Eval,
    EvalResult,
    SamplerBase,
    SingleEvalResult,
    format_multichoice_question,
)

subject2category = {
    "abstract_algebra": "stem",
    "anatomy": "other",
    "astronomy": "stem",
    "business_ethics": "other",
    "clinical_knowledge": "other",
    "college_biology": "stem",
    "college_chemistry": "stem",
    "college_computer_science": "stem",
    "college_mathematics": "stem",
    "college_medicine": "other",
    "college_physics": "stem",
    "computer_security": "stem",
    "conceptual_physics": "stem",
    "econometrics": "social_sciences",
    "electrical_engineering": "stem",
    "elementary_mathematics": "stem",
    "formal_logic": "humanities",
    "global_facts": "other",
    "high_school_biology": "stem",
    "high_school_chemistry": "stem",
    "high_school_computer_science": "stem",
    "high_school_european_history": "humanities",
    "high_school_geography": "social_sciences",
    "high_school_government_and_politics": "social_sciences",
    "high_school_macroeconomics": "social_sciences",
    "high_school_mathematics": "stem",
    "high_school_microeconomics": "social_sciences",
    "high_school_physics": "stem",
    "high_school_psychology": "social_sciences",
    "high_school_statistics": "stem",
    "high_school_us_history": "humanities",
    "high_school_world_history": "humanities",
    "human_aging": "other",
    "human_sexuality": "social_sciences",
    "international_law": "humanities",
    "jurisprudence": "humanities",
    "logical_fallacies": "humanities",
    "machine_learning": "stem",
    "management": "other",
    "marketing": "other",
    "medical_genetics": "other",
    "miscellaneous": "other",
    "moral_disputes": "humanities",
    "moral_scenarios": "humanities",
    "nutrition": "other",
    "philosophy": "humanities",
    "prehistory": "humanities",
    "professional_accounting": "other",
    "professional_law": "humanities",
    "professional_medicine": "other",
    "professional_psychology": "social_sciences",
    "public_relations": "social_sciences",
    "security_studies": "social_sciences",
    "sociology": "social_sciences",
    "us_foreign_policy": "social_sciences",
    "virology": "other",
    "world_religions": "humanities",
}


class MMLUEval(Eval):
    def __init__(self, filename: str, num_examples: int | None, num_threads: int):
        df = pandas.read_csv(filename)
        examples = [row.to_dict() for _, row in df.iterrows()]
        if num_examples:
            examples = random.Random(0).sample(examples, num_examples)
        self.examples = examples
        self.num_threads = num_threads

    def __call__(self, sampler: SamplerBase) -> EvalResult:
        def fn(row: dict):
            prompt_messages = [
                sampler._pack_message(
                    content=format_multichoice_question(row), role="user"
                )
            ]
            response_text = sampler(prompt_messages)
            match = re.search(ANSWER_PATTERN_MULTICHOICE, response_text)
            extracted_answer = match.group(1) if match else None
            score = 1.0 if extracted_answer == row["Answer"] else 0.0
            html = common.jinja_env.from_string(HTML_JINJA).render(
                prompt_messages=prompt_messages,
                next_message=dict(content=response_text, role="assistant"),
                score=score,
                correct_answer=row["Answer"],
                extracted_answer=extracted_answer,
            )
            convo = prompt_messages + [dict(content=response_text, role="assistant")]
            category = subject2category.get(row["Subject"], "other")
            return SingleEvalResult(
                html=html, score=score, metrics={category: score}, convo=convo
            )

        results = common.map_with_progress(fn, self.examples, self.num_threads)
        return common.aggregate_results(results)
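To exercise MMLUEval without a live server, a hedged end-to-end sketch using a stub sampler and a one-row CSV in the simple-evals mmlu.csv layout (Question, A-D, Answer, Subject); the stub class and temp path are illustrative, not part of the commit:

import csv

from sglang.test.simple_eval_common import SamplerBase
from sglang.test.simple_eval_mmlu import MMLUEval


class StubSampler(SamplerBase):
    # Stands in for ChatCompletionSampler; always answers B.
    def _pack_message(self, role, content):
        return {"role": role, "content": content}

    def __call__(self, message_list):
        return "Answer: B"


with open("/tmp/tiny_mmlu.csv", "w", newline="") as f:
    writer = csv.DictWriter(f, fieldnames=["Question", "A", "B", "C", "D", "Answer", "Subject"])
    writer.writeheader()
    writer.writerow({
        "Question": "What is 2 + 2?",
        "A": "3", "B": "4", "C": "5", "D": "6",
        "Answer": "B",
        "Subject": "elementary_mathematics",
    })

result = MMLUEval("/tmp/tiny_mmlu.csv", num_examples=None, num_threads=1)(StubSampler())
print(result.score)  # 1.0: the stub's fixed "Answer: B" matches the row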
(deleted file, 46 lines)

@@ -1,46 +0,0 @@
-from sglang.srt.conversation import generate_chat_conv
-from sglang.srt.managers.openai_api.protocol import (
-    ChatCompletionMessageContentImagePart,
-    ChatCompletionMessageContentImageURL,
-    ChatCompletionMessageContentTextPart,
-    ChatCompletionMessageGenericParam,
-    ChatCompletionMessageUserParam,
-    ChatCompletionRequest,
-)
-
-
-def test_chat_completion_to_conv_image():
-    """Test that we can convert a chat image request to a convo"""
-    request = ChatCompletionRequest(
-        model="default",
-        messages=[
-            ChatCompletionMessageGenericParam(
-                role="system", content="You are a helpful AI assistant"
-            ),
-            ChatCompletionMessageUserParam(
-                role="user",
-                content=[
-                    ChatCompletionMessageContentTextPart(
-                        type="text", text="Describe this image"
-                    ),
-                    ChatCompletionMessageContentImagePart(
-                        type="image_url",
-                        image_url=ChatCompletionMessageContentImageURL(
-                            url="https://someurl.com"
-                        ),
-                    ),
-                ],
-            ),
-        ],
-    )
-    conv = generate_chat_conv(request, "vicuna_v1.1")
-    assert conv.messages == [
-        ["USER", "Describe this image<image>"],
-        ["ASSISTANT", None],
-    ]
-    assert conv.system_message == "You are a helpful AI assistant"
-    assert conv.image_data == ["https://someurl.com"]
-
-
-if __name__ == "__main__":
-    test_chat_completion_to_conv_image()
(deleted file, 51 lines)

@@ -1,51 +0,0 @@
-from sglang.srt.managers.openai_api.protocol import (
-    ChatCompletionMessageContentImagePart,
-    ChatCompletionMessageContentImageURL,
-    ChatCompletionMessageContentTextPart,
-    ChatCompletionMessageGenericParam,
-    ChatCompletionMessageUserParam,
-    ChatCompletionRequest,
-)
-
-
-def test_chat_completion_request_image():
-    """Test that Chat Completion Requests with images can be converted."""
-
-    image_request = {
-        "model": "default",
-        "messages": [
-            {"role": "system", "content": "You are a helpful AI assistant"},
-            {
-                "role": "user",
-                "content": [
-                    {"type": "text", "text": "Describe this image"},
-                    {"type": "image_url", "image_url": {"url": "https://someurl.com"}},
-                ],
-            },
-        ],
-        "temperature": 0,
-        "max_tokens": 64,
-    }
-    request = ChatCompletionRequest(**image_request)
-    assert len(request.messages) == 2
-    assert request.messages[0] == ChatCompletionMessageGenericParam(
-        role="system", content="You are a helpful AI assistant"
-    )
-    assert request.messages[1] == ChatCompletionMessageUserParam(
-        role="user",
-        content=[
-            ChatCompletionMessageContentTextPart(
-                type="text", text="Describe this image"
-            ),
-            ChatCompletionMessageContentImagePart(
-                type="image_url",
-                image_url=ChatCompletionMessageContentImageURL(
-                    url="https://someurl.com"
-                ),
-            ),
-        ],
-    )
-
-
-if __name__ == "__main__":
-    test_chat_completion_request_image()
python/sglang/test/test_utils.py

@@ -1,6 +1,8 @@
 """Common utilities for testing and benchmarking"""
 
 import asyncio
+import subprocess
+import time
 from functools import partial
 
 import numpy as np
@@ -11,6 +13,8 @@ from sglang.lang.backend.openai import OpenAI
 from sglang.lang.backend.runtime_endpoint import RuntimeEndpoint
 from sglang.utils import get_exception_traceback
 
+MODEL_NAME_FOR_TEST = "meta-llama/Meta-Llama-3.1-8B-Instruct"
+
 
 def call_generate_lightllm(prompt, temperature, max_tokens, stop=None, url=None):
     assert url is not None
@@ -379,3 +383,31 @@ def get_call_select(args):
             raise
 
     return func
+
+
+def popen_launch_server(model, port, timeout, *args):
+    command = [
+        "python3",
+        "-m",
+        "sglang.launch_server",
+        "--model-path",
+        model,
+        "--host",
+        "localhost",
+        "--port",
+        str(port),
+        *args,
+    ]
+    process = subprocess.Popen(command, stdout=None, stderr=None)
+    base_url = f"http://localhost:{port}/v1"
+
+    start_time = time.time()
+    while time.time() - start_time < timeout:
+        try:
+            response = requests.get(f"{base_url}/models")
+            if response.status_code == 200:
+                return process
+        except requests.RequestException:
+            pass
+        time.sleep(10)
+    raise TimeoutError("Server failed to start within the timeout period.")
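The new popen_launch_server helper is what the refactored tests below rely on; a minimal sketch of using it outside unittest (illustrative):

from sglang.srt.utils import kill_child_process
from sglang.test.test_utils import MODEL_NAME_FOR_TEST, popen_launch_server

process = popen_launch_server(MODEL_NAME_FOR_TEST, 30000, 300)  # model, port, timeout
try:
    pass  # issue requests against http://localhost:30000 here
finally:
    kill_child_process(process.pid)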
test/srt/test_srt_backend.py

@@ -14,6 +14,7 @@ from sglang.test.test_programs import (
     test_stream,
     test_tool_use,
 )
+from sglang.test.test_utils import MODEL_NAME_FOR_TEST
 
 
 class TestSRTBackend(unittest.TestCase):
@@ -21,7 +22,7 @@ class TestSRTBackend(unittest.TestCase):
 
     @classmethod
     def setUpClass(cls):
-        cls.backend = sgl.Runtime(model_path="meta-llama/Meta-Llama-3.1-8B-Instruct")
+        cls.backend = sgl.Runtime(model_path=MODEL_NAME_FOR_TEST)
         sgl.set_default_backend(cls.backend)
 
     @classmethod
test/srt/test_eval_accuracy.py (new file, 43 lines)

import json
import unittest
from types import SimpleNamespace

from sglang.srt.utils import kill_child_process
from sglang.test.run_eval import run_eval
from sglang.test.test_utils import MODEL_NAME_FOR_TEST, popen_launch_server


class TestAccuracy(unittest.TestCase):

    @classmethod
    def setUpClass(cls):
        port = 30000

        cls.model = MODEL_NAME_FOR_TEST
        cls.base_url = f"http://localhost:{port}"
        cls.process = popen_launch_server(cls.model, port, timeout=300)

    @classmethod
    def tearDownClass(cls):
        kill_child_process(cls.process.pid)

    def test_mmlu(self):
        args = SimpleNamespace(
            base_url=self.base_url,
            model=self.model,
            eval_name="mmlu",
            num_examples=20,
            num_threads=20,
        )

        metrics = run_eval(args)
        assert metrics["score"] >= 0.5


if __name__ == "__main__":
    unittest.main(warnings="ignore")

    # t = TestAccuracy()
    # t.setUpClass()
    # t.test_mmlu()
    # t.tearDownClass()
test/srt/test_openai_server.py

@@ -1,47 +1,21 @@
 import json
-import subprocess
-import time
 import unittest
 
 import openai
-import requests
 
 from sglang.srt.utils import kill_child_process
+from sglang.test.test_utils import MODEL_NAME_FOR_TEST, popen_launch_server
 
 
 class TestOpenAIServer(unittest.TestCase):
 
     @classmethod
     def setUpClass(cls):
-        model = "meta-llama/Meta-Llama-3.1-8B-Instruct"
         port = 30000
-        timeout = 300
 
-        command = [
-            "python3",
-            "-m",
-            "sglang.launch_server",
-            "--model-path",
-            model,
-            "--host",
-            "localhost",
-            "--port",
-            str(port),
-        ]
-        cls.process = subprocess.Popen(command, stdout=None, stderr=None)
-        cls.base_url = f"http://localhost:{port}/v1"
-        cls.model = model
-
-        start_time = time.time()
-        while time.time() - start_time < timeout:
-            try:
-                response = requests.get(f"{cls.base_url}/models")
-                if response.status_code == 200:
-                    return
-            except requests.RequestException:
-                pass
-            time.sleep(10)
-        raise TimeoutError("Server failed to start within the timeout period.")
+        cls.model = MODEL_NAME_FOR_TEST
+        cls.base_url = f"http://localhost:{port}/v1"
+        cls.process = popen_launch_server(cls.model, port, timeout=300)
 
     @classmethod
     def tearDownClass(cls):
@@ -178,8 +152,6 @@ class TestOpenAIServer(unittest.TestCase):
 
         is_first = True
         for response in generator:
-            print(response)
-
             data = response.choices[0].delta
             if is_first:
                 data.role == "assistant"
test/srt/test_srt_endpoint.py (new file, 64 lines)

import json
import unittest

import requests

from sglang.srt.utils import kill_child_process
from sglang.test.run_eval import run_eval
from sglang.test.test_utils import MODEL_NAME_FOR_TEST, popen_launch_server


class TestSRTEndpoint(unittest.TestCase):

    @classmethod
    def setUpClass(cls):
        port = 30000

        cls.model = MODEL_NAME_FOR_TEST
        cls.base_url = f"http://localhost:{port}"
        cls.process = popen_launch_server(cls.model, port, timeout=300)

    @classmethod
    def tearDownClass(cls):
        kill_child_process(cls.process.pid)

    def run_decode(
        self, return_logprob=False, top_logprobs_num=0, return_text=False, n=1
    ):
        response = requests.post(
            self.base_url + "/generate",
            json={
                "text": "The capital of France is",
                "sampling_params": {
                    "temperature": 0 if n == 1 else 0.5,
                    "max_new_tokens": 32,
                    "n": n,
                },
                "stream": False,
                "return_logprob": return_logprob,
                "top_logprobs_num": top_logprobs_num,
                "return_text_in_logprobs": return_text,
                "logprob_start_len": 0,
            },
        )
        print(json.dumps(response.json()))
        print("=" * 100)

    def test_simple_decode(self):
        self.run_decode()

    def test_parallel_sample(self):
        self.run_decode(n=3)

    def test_logprob(self):
        for top_logprobs_num in [0, 3]:
            for return_text in [True, False]:
                self.run_decode(
                    return_logprob=True,
                    top_logprobs_num=top_logprobs_num,
                    return_text=return_text,
                )


if __name__ == "__main__":
    unittest.main(warnings="ignore")