Minor tool for comparison of benchmark results (#7974)
@@ -10,6 +10,7 @@ import numpy as np
from sglang.api import set_default_backend
from sglang.test.test_utils import (
    add_common_sglang_args_and_parse,
    dump_bench_raw_result,
    select_sglang_backend,
)
from sglang.utils import download_and_cache_file, dump_state_text, read_jsonl
@@ -115,6 +116,12 @@ def main(args):

    # Dump results
    dump_state_text(f"tmp_output_{args.backend}.txt", states)
    dump_bench_raw_result(
        path=args.raw_result_file,
        states=states,
        preds=preds,
        labels=labels,
    )

    with open(args.result_file, "a") as fout:
        value = {
@@ -9,6 +9,7 @@ import tiktoken

from sglang.test.test_utils import (
    add_common_sglang_args_and_parse,
    dump_bench_raw_result,
    select_sglang_backend,
)
@@ -142,6 +143,13 @@ def main(args):
    assert pt == len(cors)
    weighted_acc = np.mean(cors)

    dump_bench_raw_result(
        path=args.raw_result_file,
        states=states,
        preds=preds,
        labels=labels,
    )

    # Print results
    print("Total latency: {:.3f}".format(latency))
    print("Average accuracy: {:.3f}".format(weighted_acc))
python/sglang/srt/debug_utils/text_comparator.py (new file, 172 lines)
@@ -0,0 +1,172 @@
import argparse
import json
from pathlib import Path

import polars as pl

_DESCRIPTION = """Compare benchmark outputs and find their differences.

Supported inputs:
* The samples jsonl from `lm_eval --log_samples --output_path FOLDER_NAME`
* The output from `gsm8k/bench_sglang.py --raw-result-file FILE_NAME` (or mmlu)
"""


def main(args):
    df_input = _transform_df_input(_compute_df_raw(args))
    assert all(
        c in df_input.columns
        for c in ["category", "trial_index", "prompt_id", "prompt", "output", "correct"]
    )

    df_meta = _compute_df_meta(df_input)

    df_correctness_per_trial = df_input.group_by(
        "category", "trial_index", maintain_order=True
    ).agg(pl.col("correct").mean())
    df_correctness_delta = (
        df_meta.group_by("correctness_delta").len().sort("correctness_delta")
    )
    df_good_to_bad = df_meta.filter(pl.col("correctness_delta") < 0)
    df_bad_to_good = df_meta.filter(pl.col("correctness_delta") > 0)

    print(f"Dump output to {args.output_path}")
    Path(args.output_path).write_text(
        json.dumps(
            dict(
                df_meta=df_meta.to_dicts(),
                df_good_to_bad=df_good_to_bad.to_dicts(),
                df_bad_to_good=df_bad_to_good.to_dicts(),
            )
        )
    )

    if not args.disable_print_details:
        with pl.Config(
            fmt_str_lengths=10000,
            tbl_cols=-1,
            tbl_rows=-1,
            tbl_width_chars=-1,
            tbl_formatting="UTF8_FULL",
        ):
            print("====== Correctness per trial ======")
            print(df_correctness_per_trial)

            print(
                "====== Correctness Delta (-1.0 means all-right becomes all-wrong) ======"
            )
            print(df_correctness_delta)

            for name, df in [
                ("Good->Bad", df_good_to_bad),
                ("Bad->Good", df_bad_to_good),
            ]:
                print(f"====== Concrete Examples: {name} ======")
                print(df)


def _compute_df_raw(args):
    return pl.concat(
        [
            _read_df_raw(p, category=category, trial_index=i)
            for category, paths in [
                ("baseline", args.baseline_path),
                ("target", args.target_path),
            ]
            for i, p in enumerate(paths)
        ]
    )


def _read_df_raw(path: str, category: str, trial_index: int):
    return pl.read_ndjson(path).with_columns(
        category=pl.lit(category), trial_index=trial_index
    )


def _transform_df_input(df: pl.DataFrame):
    if "doc_id" in df.columns:
        print("Transform mode: lm_eval")

        filter_names = df["filter"].unique(maintain_order=True).to_list()
        if len(filter_names) > 1:
            filter_name = filter_names[0]
            print(f"Choose {filter_name=} among {filter_names}")
            df = df.filter(pl.col("filter") == filter_name)

        df = df.select(
            pl.col("category"),
            pl.col("trial_index"),
            prompt_id=pl.col("doc_id"),
            prompt=pl.col("arguments").struct.field("gen_args_0").struct.field("arg_0"),
            output=pl.col("resps").list.get(0).list.get(0),
            correct=pl.col("exact_match").cast(bool),
        )

        return df
    elif "prompt_id" in df.columns:
        print("Transform mode: SGLang bench")
        return df
    else:
        raise Exception(f"Unknown data: {df.columns}")
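
# Sketch of an lm_eval sample row, abridged to the fields read above (illustrative
# values; real rows from `lm_eval --log_samples` carry additional keys):
# {"doc_id": 0, "filter": "flexible-extract",
#  "arguments": {"gen_args_0": {"arg_0": "Question: ..."}},
#  "resps": [["The answer is 4."]], "exact_match": 1.0}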


def _compute_df_meta(df_input: pl.DataFrame):
    df_input = df_input.sort("prompt_id", "category", "trial_index")
    df_meta = pl.DataFrame(
        [
            _handle_one_prompt(df_one_prompt)
            for df_one_prompt in df_input.partition_by("prompt_id", maintain_order=True)
        ]
    )
    df_meta = df_meta.with_columns(
        correctness_delta=pl.col("correctness_target") - pl.col("correctness_baseline"),
    )
    df_meta = df_meta.sort("correctness_delta", "output_same_prefix_len")
    return df_meta


def _handle_one_prompt(df_one_prompt: pl.DataFrame):
    assert len(set(df_one_prompt["prompt"])) == 1

    df_baseline = df_one_prompt.filter(pl.col("category") == "baseline")
    df_target = df_one_prompt.filter(pl.col("category") == "target")

    outputs_baseline = df_baseline["output"].to_list()
    outputs_target = df_target["output"].to_list()

    output_same_prefix_len = max(
        _compute_str_prefix_len(output_baseline, output_target)
        for output_baseline in outputs_baseline
        for output_target in outputs_target
    )

    return dict(
        prompt_id=df_one_prompt[0, "prompt_id"],
        correctness_baseline=df_baseline["correct"].mean(),
        correctness_target=df_target["correct"].mean(),
        output_same_prefix_len=output_same_prefix_len,
        prompt=df_one_prompt[0, "prompt"],
        outputs_baseline=outputs_baseline,
        outputs_target=outputs_target,
    )


def _compute_str_prefix_len(a: str, b: str) -> int:
    min_len = min(len(a), len(b))
    for i in range(min_len):
        if a[i] != b[i]:
            return i
    return min_len


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description=_DESCRIPTION)
    parser.add_argument("--baseline-path", type=str, nargs="+")
    parser.add_argument("--target-path", type=str, nargs="+")
    parser.add_argument(
        "--output-path", type=str, default="/tmp/text_comparator_output.json"
    )
    parser.add_argument("--disable-print-details", action="store_true")
    args = parser.parse_args()
    main(args)
@@ -15,6 +15,7 @@ import unittest
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass
from functools import partial
from pathlib import Path
from types import SimpleNamespace
from typing import Awaitable, Callable, List, Optional, Tuple
@@ -27,6 +28,7 @@ from sglang.bench_serving import run_benchmark
from sglang.global_config import global_config
from sglang.lang.backend.openai import OpenAI
from sglang.lang.backend.runtime_endpoint import RuntimeEndpoint
from sglang.lang.interpreter import ProgramState
from sglang.srt.utils import (
    get_bool_env_var,
    get_device,
@@ -348,6 +350,7 @@ def add_common_sglang_args_and_parse(parser: argparse.ArgumentParser):
        help="Device type (auto/cuda/rocm/cpu). Auto will detect available platforms",
    )
    parser.add_argument("--result-file", type=str, default="result.jsonl")
    parser.add_argument("--raw-result-file", type=str)
    args = parser.parse_args()

    return args
@@ -1309,3 +1312,35 @@ class CustomTestCase(unittest.TestCase):
            lambda: super(CustomTestCase, self)._callTestMethod(method),
            max_retry=max_retry,
        )


def dump_bench_raw_result(
    path: str,
    states,
    preds,
    labels,
):
    if not path:
        return

    rows = []
    for i in range(len(states)):
        state = states[i]
        output = state["answer"]
        prompt = _ensure_remove_suffix(state.text(), output)
        rows.append(
            dict(
                prompt_id=i,
                prompt=prompt,
                output=output,
                correct=bool(preds[i] == labels[i]),
            )
        )

    print(f"BenchRawResultDumper save results to {path}")
    Path(path).write_text("\n".join(json.dumps(row) for row in rows))
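
# A minimal sketch of one line of the raw result file written above (illustrative
# values; the keys mirror the dict constructed in the loop):
# {"prompt_id": 0, "prompt": "Question: ...", "output": "The answer is 4.", "correct": true}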


def _ensure_remove_suffix(text: str, suffix: str):
    assert text.endswith(suffix)
    return text.removesuffix(suffix)