Files
sglang/python/sglang/srt/debug_utils/text_comparator.py

173 lines
5.4 KiB
Python

import argparse
import json
from pathlib import Path
import polars as pl
_DESCRIPTION = """Compare and find differences to benchmark outputs.
Supported inputs:
* The samples jsonl from `lm_eval --log_samples --output_path FOLDER_NAME`
* The output from `gsm8k/bench_sglang.py --raw-result-file FILE_NAME` (or mmlu)
"""
def main(args):
df_input = _transform_df_input(_compute_df_raw(args))
assert all(
c in df_input.columns
for c in ["category", "trial_index", "prompt_id", "prompt", "output", "correct"]
)
df_meta = _compute_df_meta(df_input)
df_correctness_per_trial = df_input.group_by(
"category", "trial_index", maintain_order=True
).agg(pl.col("correct").mean())
df_correctness_delta = (
df_meta.group_by("correctness_delta").len().sort("correctness_delta")
)
df_good_to_bad = df_meta.filter(pl.col("correctness_delta") < 0)
df_bad_to_good = df_meta.filter(pl.col("correctness_delta") > 0)
print(f"Dump output to {args.output_path}")
Path(args.output_path).write_text(
json.dumps(
dict(
df_meta=df_meta.to_dicts(),
df_good_to_bad=df_good_to_bad.to_dicts(),
df_bad_to_good=df_bad_to_good.to_dicts(),
)
)
)
if not args.disable_print_details:
with pl.Config(
fmt_str_lengths=10000,
tbl_cols=-1,
tbl_rows=-1,
tbl_width_chars=-1,
tbl_formatting="UTF8_FULL",
):
print("====== Correctness per trial ======")
print(df_correctness_per_trial)
print(
"====== Correctness Delta (-1.0 means all-right becomes all-wrong) ======"
)
print(df_correctness_delta)
for name, df in [
("Good->Bad", df_good_to_bad),
("Bad->Good", df_bad_to_good),
]:
print(f"====== Concrete Examples: {name} ======")
print(df)
def _compute_df_raw(args):
return pl.concat(
[
_read_df_raw(p, category=category, trial_index=i)
for category, paths in [
("baseline", args.baseline_path),
("target", args.target_path),
]
for i, p in enumerate(paths)
]
)
def _read_df_raw(path: str, category: str, trial_index: int):
return pl.read_ndjson(path).with_columns(
category=pl.lit(category), trial_index=trial_index
)
def _transform_df_input(df: pl.DataFrame):
if "doc_id" in df.columns:
print("Transform mode: lm_eval")
filter_names = df["filter"].unique(maintain_order=True).to_list()
if len(filter_names) > 1:
filter_name = filter_names[0]
print(f"Choose {filter_name=} among {filter_names}")
df = df.filter(pl.col("filter") == filter_name)
df = df.select(
pl.col("category"),
pl.col("trial_index"),
prompt_id=pl.col("doc_id"),
prompt=pl.col("arguments").struct.field("gen_args_0").struct.field("arg_0"),
output=pl.col("resps").list.get(0).list.get(0),
correct=pl.col("exact_match").cast(bool),
)
return df
elif "prompt_id" in df.columns:
print("Transform mode: SGLang bench")
return df
else:
raise Exception(f"Unknown data: {df.columns}")
def _compute_df_meta(df_input: pl.DataFrame):
df_input = df_input.sort("prompt_id", "category", "trial_index")
df_meta = pl.DataFrame(
[
_handle_one_prompt(df_one_prompt)
for df_one_prompt in df_input.partition_by("prompt_id", maintain_order=True)
]
)
df_meta = df_meta.with_columns(
correctness_delta=pl.col("correctness_target") - pl.col("correctness_baseline"),
)
df_meta = df_meta.sort("correctness_delta", "output_same_prefix_len")
return df_meta
def _handle_one_prompt(df_one_prompt: pl.DataFrame):
assert len(set(df_one_prompt["prompt"])) == 1
df_baseline = df_one_prompt.filter(pl.col("category") == "baseline")
df_target = df_one_prompt.filter(pl.col("category") == "target")
outputs_baseline = df_baseline["output"].to_list()
outputs_target = df_target["output"].to_list()
output_same_prefix_len = max(
_compute_str_prefix_len(output_baseline, output_target)
for output_baseline in outputs_baseline
for output_target in outputs_target
)
return dict(
prompt_id=df_one_prompt[0, "prompt_id"],
correctness_baseline=df_baseline["correct"].mean(),
correctness_target=df_target["correct"].mean(),
output_same_prefix_len=output_same_prefix_len,
prompt=df_one_prompt[0, "prompt"],
outputs_baseline=outputs_baseline,
outputs_target=outputs_target,
)
def _compute_str_prefix_len(a: str, b: str) -> int:
min_len = min(len(a), len(b))
for i in range(min_len):
if a[i] != b[i]:
return i
return min_len
if __name__ == "__main__":
parser = argparse.ArgumentParser(description=_DESCRIPTION)
parser.add_argument("--baseline-path", type=str, nargs="+")
parser.add_argument("--target-path", type=str, nargs="+")
parser.add_argument(
"--output-path", type=str, default="/tmp/text_comparator_output.json"
)
parser.add_argument("--disable-print-details", action="store_true")
args = parser.parse_args()
main(args)