Minor tool for comparison of benchmark results (#7974)

This commit is contained in:
fzyzcjy
2025-07-27 15:27:50 +08:00
committed by GitHub
parent ed0fdbf35b
commit 62222bd27e
4 changed files with 222 additions and 0 deletions

View File

@@ -10,6 +10,7 @@ import numpy as np
from sglang.api import set_default_backend
from sglang.test.test_utils import (
add_common_sglang_args_and_parse,
dump_bench_raw_result,
select_sglang_backend,
)
from sglang.utils import download_and_cache_file, dump_state_text, read_jsonl
@@ -115,6 +116,12 @@ def main(args):
# Dump results
dump_state_text(f"tmp_output_{args.backend}.txt", states)
dump_bench_raw_result(
path=args.raw_result_file,
states=states,
preds=preds,
labels=labels,
)
with open(args.result_file, "a") as fout:
value = {

View File

@@ -9,6 +9,7 @@ import tiktoken
from sglang.test.test_utils import (
add_common_sglang_args_and_parse,
dump_bench_raw_result,
select_sglang_backend,
)
@@ -142,6 +143,13 @@ def main(args):
assert pt == len(cors)
weighted_acc = np.mean(cors)
dump_bench_raw_result(
path=args.raw_result_file,
states=states,
preds=preds,
labels=labels,
)
# Print results
print("Total latency: {:.3f}".format(latency))
print("Average accuracy: {:.3f}".format(weighted_acc))

View File

@@ -0,0 +1,172 @@
import argparse
import json
from pathlib import Path
import polars as pl
_DESCRIPTION = """Compare and find differences to benchmark outputs.
Supported inputs:
* The samples jsonl from `lm_eval --log_samples --output_path FOLDER_NAME`
* The output from `gsm8k/bench_sglang.py --raw-result-file FILE_NAME` (or mmlu)
"""
def main(args):
df_input = _transform_df_input(_compute_df_raw(args))
assert all(
c in df_input.columns
for c in ["category", "trial_index", "prompt_id", "prompt", "output", "correct"]
)
df_meta = _compute_df_meta(df_input)
df_correctness_per_trial = df_input.group_by(
"category", "trial_index", maintain_order=True
).agg(pl.col("correct").mean())
df_correctness_delta = (
df_meta.group_by("correctness_delta").len().sort("correctness_delta")
)
df_good_to_bad = df_meta.filter(pl.col("correctness_delta") < 0)
df_bad_to_good = df_meta.filter(pl.col("correctness_delta") > 0)
print(f"Dump output to {args.output_path}")
Path(args.output_path).write_text(
json.dumps(
dict(
df_meta=df_meta.to_dicts(),
df_good_to_bad=df_good_to_bad.to_dicts(),
df_bad_to_good=df_bad_to_good.to_dicts(),
)
)
)
if not args.disable_print_details:
with pl.Config(
fmt_str_lengths=10000,
tbl_cols=-1,
tbl_rows=-1,
tbl_width_chars=-1,
tbl_formatting="UTF8_FULL",
):
print("====== Correctness per trial ======")
print(df_correctness_per_trial)
print(
"====== Correctness Delta (-1.0 means all-right becomes all-wrong) ======"
)
print(df_correctness_delta)
for name, df in [
("Good->Bad", df_good_to_bad),
("Bad->Good", df_bad_to_good),
]:
print(f"====== Concrete Examples: {name} ======")
print(df)
def _compute_df_raw(args):
return pl.concat(
[
_read_df_raw(p, category=category, trial_index=i)
for category, paths in [
("baseline", args.baseline_path),
("target", args.target_path),
]
for i, p in enumerate(paths)
]
)
def _read_df_raw(path: str, category: str, trial_index: int):
return pl.read_ndjson(path).with_columns(
category=pl.lit(category), trial_index=trial_index
)
def _transform_df_input(df: pl.DataFrame):
if "doc_id" in df.columns:
print("Transform mode: lm_eval")
filter_names = df["filter"].unique(maintain_order=True).to_list()
if len(filter_names) > 1:
filter_name = filter_names[0]
print(f"Choose {filter_name=} among {filter_names}")
df = df.filter(pl.col("filter") == filter_name)
df = df.select(
pl.col("category"),
pl.col("trial_index"),
prompt_id=pl.col("doc_id"),
prompt=pl.col("arguments").struct.field("gen_args_0").struct.field("arg_0"),
output=pl.col("resps").list.get(0).list.get(0),
correct=pl.col("exact_match").cast(bool),
)
return df
elif "prompt_id" in df.columns:
print("Transform mode: SGLang bench")
return df
else:
raise Exception(f"Unknown data: {df.columns}")
def _compute_df_meta(df_input: pl.DataFrame):
df_input = df_input.sort("prompt_id", "category", "trial_index")
df_meta = pl.DataFrame(
[
_handle_one_prompt(df_one_prompt)
for df_one_prompt in df_input.partition_by("prompt_id", maintain_order=True)
]
)
df_meta = df_meta.with_columns(
correctness_delta=pl.col("correctness_target") - pl.col("correctness_baseline"),
)
df_meta = df_meta.sort("correctness_delta", "output_same_prefix_len")
return df_meta
def _handle_one_prompt(df_one_prompt: pl.DataFrame):
assert len(set(df_one_prompt["prompt"])) == 1
df_baseline = df_one_prompt.filter(pl.col("category") == "baseline")
df_target = df_one_prompt.filter(pl.col("category") == "target")
outputs_baseline = df_baseline["output"].to_list()
outputs_target = df_target["output"].to_list()
output_same_prefix_len = max(
_compute_str_prefix_len(output_baseline, output_target)
for output_baseline in outputs_baseline
for output_target in outputs_target
)
return dict(
prompt_id=df_one_prompt[0, "prompt_id"],
correctness_baseline=df_baseline["correct"].mean(),
correctness_target=df_target["correct"].mean(),
output_same_prefix_len=output_same_prefix_len,
prompt=df_one_prompt[0, "prompt"],
outputs_baseline=outputs_baseline,
outputs_target=outputs_target,
)
def _compute_str_prefix_len(a: str, b: str) -> int:
min_len = min(len(a), len(b))
for i in range(min_len):
if a[i] != b[i]:
return i
return min_len
if __name__ == "__main__":
parser = argparse.ArgumentParser(description=_DESCRIPTION)
parser.add_argument("--baseline-path", type=str, nargs="+")
parser.add_argument("--target-path", type=str, nargs="+")
parser.add_argument(
"--output-path", type=str, default="/tmp/text_comparator_output.json"
)
parser.add_argument("--disable-print-details", action="store_true")
args = parser.parse_args()
main(args)

View File

@@ -15,6 +15,7 @@ import unittest
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass
from functools import partial
from pathlib import Path
from types import SimpleNamespace
from typing import Awaitable, Callable, List, Optional, Tuple
@@ -27,6 +28,7 @@ from sglang.bench_serving import run_benchmark
from sglang.global_config import global_config
from sglang.lang.backend.openai import OpenAI
from sglang.lang.backend.runtime_endpoint import RuntimeEndpoint
from sglang.lang.interpreter import ProgramState
from sglang.srt.utils import (
get_bool_env_var,
get_device,
@@ -348,6 +350,7 @@ def add_common_sglang_args_and_parse(parser: argparse.ArgumentParser):
help="Device type (auto/cuda/rocm/cpu). Auto will detect available platforms",
)
parser.add_argument("--result-file", type=str, default="result.jsonl")
parser.add_argument("--raw-result-file", type=str)
args = parser.parse_args()
return args
@@ -1309,3 +1312,35 @@ class CustomTestCase(unittest.TestCase):
lambda: super(CustomTestCase, self)._callTestMethod(method),
max_retry=max_retry,
)
def dump_bench_raw_result(
path: str,
states,
preds,
labels,
):
if not path:
return
rows = []
for i in range(len(states)):
state = states[i]
output = state["answer"]
prompt = _ensure_remove_suffix(state.text(), output)
rows.append(
dict(
prompt_id=i,
prompt=prompt,
output=output,
correct=bool(preds[i] == labels[i]),
)
)
print(f"BenchRawResultDumper save results to {path}")
Path(path).write_text("\n".join(json.dumps(row) for row in rows))
def _ensure_remove_suffix(text: str, suffix: str):
assert text.endswith(suffix)
return text.removesuffix(suffix)