diff --git a/benchmark/gsm8k/bench_sglang.py b/benchmark/gsm8k/bench_sglang.py
index 05ac0beb1..fe15c015a 100644
--- a/benchmark/gsm8k/bench_sglang.py
+++ b/benchmark/gsm8k/bench_sglang.py
@@ -10,6 +10,7 @@ import numpy as np
 from sglang.api import set_default_backend
 from sglang.test.test_utils import (
     add_common_sglang_args_and_parse,
+    dump_bench_raw_result,
     select_sglang_backend,
 )
 from sglang.utils import download_and_cache_file, dump_state_text, read_jsonl
@@ -115,6 +116,12 @@ def main(args):
 
     # Dump results
     dump_state_text(f"tmp_output_{args.backend}.txt", states)
+    dump_bench_raw_result(
+        path=args.raw_result_file,
+        states=states,
+        preds=preds,
+        labels=labels,
+    )
 
     with open(args.result_file, "a") as fout:
         value = {
diff --git a/benchmark/mmlu/bench_sglang.py b/benchmark/mmlu/bench_sglang.py
index 0bae7b6e4..23057be4a 100644
--- a/benchmark/mmlu/bench_sglang.py
+++ b/benchmark/mmlu/bench_sglang.py
@@ -9,6 +9,7 @@ import tiktoken
 
 from sglang.test.test_utils import (
     add_common_sglang_args_and_parse,
+    dump_bench_raw_result,
     select_sglang_backend,
 )
 
@@ -142,6 +143,13 @@ def main(args):
     assert pt == len(cors)
     weighted_acc = np.mean(cors)
+
+    dump_bench_raw_result(
+        path=args.raw_result_file,
+        states=states,
+        preds=preds,
+        labels=labels,
+    )
 
     # Print results
     print("Total latency: {:.3f}".format(latency))
     print("Average accuracy: {:.3f}".format(weighted_acc))
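For context: with `--raw-result-file` set, both benchmark scripts emit one JSON object per line with the fields `prompt_id`, `prompt`, `output`, and `correct`, which is the schema the comparator below consumes in its "SGLang bench" mode. A minimal sketch of reading such a file back; the path is a hypothetical example, but the field names match `dump_bench_raw_result`:

```python
# Sketch: inspect a --raw-result-file dump (path below is illustrative).
import json

with open("raw_result.jsonl") as f:
    for line in f:
        row = json.loads(line)
        print(row["prompt_id"], row["correct"], repr(row["output"][:40]))
```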
diff --git a/python/sglang/srt/debug_utils/text_comparator.py b/python/sglang/srt/debug_utils/text_comparator.py
new file mode 100644
index 000000000..5917fcfb6
--- /dev/null
+++ b/python/sglang/srt/debug_utils/text_comparator.py
@@ -0,0 +1,172 @@
+import argparse
+import json
+from pathlib import Path
+
+import polars as pl
+
+_DESCRIPTION = """Compare and find differences between benchmark outputs.
+
+Supported inputs:
+* The samples jsonl from `lm_eval --log_samples --output_path FOLDER_NAME`
+* The output from `gsm8k/bench_sglang.py --raw-result-file FILE_NAME` (or mmlu)
+"""
+
+
+def main(args):
+    df_input = _transform_df_input(_compute_df_raw(args))
+    assert all(
+        c in df_input.columns
+        for c in ["category", "trial_index", "prompt_id", "prompt", "output", "correct"]
+    )
+
+    df_meta = _compute_df_meta(df_input)
+
+    df_correctness_per_trial = df_input.group_by(
+        "category", "trial_index", maintain_order=True
+    ).agg(pl.col("correct").mean())
+    df_correctness_delta = (
+        df_meta.group_by("correctness_delta").len().sort("correctness_delta")
+    )
+    df_good_to_bad = df_meta.filter(pl.col("correctness_delta") < 0)
+    df_bad_to_good = df_meta.filter(pl.col("correctness_delta") > 0)
+
+    print(f"Dump output to {args.output_path}")
+    Path(args.output_path).write_text(
+        json.dumps(
+            dict(
+                df_meta=df_meta.to_dicts(),
+                df_good_to_bad=df_good_to_bad.to_dicts(),
+                df_bad_to_good=df_bad_to_good.to_dicts(),
+            )
+        )
+    )
+
+    if not args.disable_print_details:
+        with pl.Config(
+            fmt_str_lengths=10000,
+            tbl_cols=-1,
+            tbl_rows=-1,
+            tbl_width_chars=-1,
+            tbl_formatting="UTF8_FULL",
+        ):
+            print("====== Correctness per trial ======")
+            print(df_correctness_per_trial)
+
+            print(
+                "====== Correctness Delta (-1.0 means all-right becomes all-wrong) ======"
+            )
+            print(df_correctness_delta)
+
+            for name, df in [
+                ("Good->Bad", df_good_to_bad),
+                ("Bad->Good", df_bad_to_good),
+            ]:
+                print(f"====== Concrete Examples: {name} ======")
+                print(df)
+
+
+def _compute_df_raw(args):
+    return pl.concat(
+        [
+            _read_df_raw(p, category=category, trial_index=i)
+            for category, paths in [
+                ("baseline", args.baseline_path),
+                ("target", args.target_path),
+            ]
+            for i, p in enumerate(paths)
+        ]
+    )
+
+
+def _read_df_raw(path: str, category: str, trial_index: int):
+    return pl.read_ndjson(path).with_columns(
+        category=pl.lit(category), trial_index=trial_index
+    )
+
+
+def _transform_df_input(df: pl.DataFrame):
+    if "doc_id" in df.columns:
+        print("Transform mode: lm_eval")
+
+        filter_names = df["filter"].unique(maintain_order=True).to_list()
+        if len(filter_names) > 1:
+            filter_name = filter_names[0]
+            print(f"Choose {filter_name=} among {filter_names}")
+            df = df.filter(pl.col("filter") == filter_name)
+
+        df = df.select(
+            pl.col("category"),
+            pl.col("trial_index"),
+            prompt_id=pl.col("doc_id"),
+            prompt=pl.col("arguments").struct.field("gen_args_0").struct.field("arg_0"),
+            output=pl.col("resps").list.get(0).list.get(0),
+            correct=pl.col("exact_match").cast(bool),
+        )
+
+        return df
+    elif "prompt_id" in df.columns:
+        print("Transform mode: SGLang bench")
+        return df
+    else:
+        raise Exception(f"Unknown data: {df.columns}")
+
+
+def _compute_df_meta(df_input: pl.DataFrame):
+    df_input = df_input.sort("prompt_id", "category", "trial_index")
+    df_meta = pl.DataFrame(
+        [
+            _handle_one_prompt(df_one_prompt)
+            for df_one_prompt in df_input.partition_by("prompt_id", maintain_order=True)
+        ]
+    )
+    df_meta = df_meta.with_columns(
+        correctness_delta=pl.col("correctness_target") - pl.col("correctness_baseline"),
+    )
+    df_meta = df_meta.sort("correctness_delta", "output_same_prefix_len")
+    return df_meta
+
+
+def _handle_one_prompt(df_one_prompt: pl.DataFrame):
+    assert len(set(df_one_prompt["prompt"])) == 1
+
+    df_baseline = df_one_prompt.filter(pl.col("category") == "baseline")
+    df_target = df_one_prompt.filter(pl.col("category") == "target")
+
+    outputs_baseline = df_baseline["output"].to_list()
+    outputs_target = df_target["output"].to_list()
+
+    output_same_prefix_len = max(
+        _compute_str_prefix_len(output_baseline, output_target)
+        for output_baseline in outputs_baseline
+        for output_target in outputs_target
+    )
+
+    return dict(
+        prompt_id=df_one_prompt[0, "prompt_id"],
+        correctness_baseline=df_baseline["correct"].mean(),
+        correctness_target=df_target["correct"].mean(),
+        output_same_prefix_len=output_same_prefix_len,
+        prompt=df_one_prompt[0, "prompt"],
+        outputs_baseline=outputs_baseline,
+        outputs_target=outputs_target,
+    )
+
+
+def _compute_str_prefix_len(a: str, b: str) -> int:
+    min_len = min(len(a), len(b))
+    for i in range(min_len):
+        if a[i] != b[i]:
+            return i
+    return min_len
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(description=_DESCRIPTION)
+    parser.add_argument("--baseline-path", type=str, nargs="+")
+    parser.add_argument("--target-path", type=str, nargs="+")
+    parser.add_argument(
+        "--output-path", type=str, default="/tmp/text_comparator_output.json"
+    )
+    parser.add_argument("--disable-print-details", action="store_true")
+    args = parser.parse_args()
+    main(args)
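A minimal sketch of the comparator's core bookkeeping on a handcrafted toy input (one prompt, one baseline trial, one target trial); a prompt whose baseline answer is right and whose target answer is wrong lands in the Good->Bad bucket with `correctness_delta = -1.0`:

```python
# Toy check of _compute_df_meta; the input frame mimics the schema asserted in main().
import polars as pl

from sglang.srt.debug_utils.text_comparator import _compute_df_meta

df_input = pl.DataFrame(
    {
        "category": ["baseline", "target"],
        "trial_index": [0, 0],
        "prompt_id": [0, 0],
        "prompt": ["Q: 1+1?", "Q: 1+1?"],
        "output": ["2", "3"],
        "correct": [True, False],
    }
)
df_meta = _compute_df_meta(df_input)
print(df_meta["correctness_delta"].to_list())  # [-1.0] -> Good->Bad
print(df_meta["output_same_prefix_len"].to_list())  # [0]: "2" vs "3" share no prefix
```

Run from the command line, the same comparison goes end to end, e.g. `python -m sglang.srt.debug_utils.text_comparator --baseline-path a.jsonl --target-path b.jsonl` (paths illustrative).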
df_target["output"].to_list() + + output_same_prefix_len = max( + _compute_str_prefix_len(output_baseline, output_target) + for output_baseline in outputs_baseline + for output_target in outputs_target + ) + + return dict( + prompt_id=df_one_prompt[0, "prompt_id"], + correctness_baseline=df_baseline["correct"].mean(), + correctness_target=df_target["correct"].mean(), + output_same_prefix_len=output_same_prefix_len, + prompt=df_one_prompt[0, "prompt"], + outputs_baseline=outputs_baseline, + outputs_target=outputs_target, + ) + + +def _compute_str_prefix_len(a: str, b: str) -> int: + min_len = min(len(a), len(b)) + for i in range(min_len): + if a[i] != b[i]: + return i + return min_len + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description=_DESCRIPTION) + parser.add_argument("--baseline-path", type=str, nargs="+") + parser.add_argument("--target-path", type=str, nargs="+") + parser.add_argument( + "--output-path", type=str, default="/tmp/text_comparator_output.json" + ) + parser.add_argument("--disable-print-details", action="store_true") + args = parser.parse_args() + main(args) diff --git a/python/sglang/test/test_utils.py b/python/sglang/test/test_utils.py index 057bc5eb9..65d989eab 100644 --- a/python/sglang/test/test_utils.py +++ b/python/sglang/test/test_utils.py @@ -15,6 +15,7 @@ import unittest from concurrent.futures import ThreadPoolExecutor from dataclasses import dataclass from functools import partial +from pathlib import Path from types import SimpleNamespace from typing import Awaitable, Callable, List, Optional, Tuple @@ -27,6 +28,7 @@ from sglang.bench_serving import run_benchmark from sglang.global_config import global_config from sglang.lang.backend.openai import OpenAI from sglang.lang.backend.runtime_endpoint import RuntimeEndpoint +from sglang.lang.interpreter import ProgramState from sglang.srt.utils import ( get_bool_env_var, get_device, @@ -348,6 +350,7 @@ def add_common_sglang_args_and_parse(parser: argparse.ArgumentParser): help="Device type (auto/cuda/rocm/cpu). Auto will detect available platforms", ) parser.add_argument("--result-file", type=str, default="result.jsonl") + parser.add_argument("--raw-result-file", type=str) args = parser.parse_args() return args @@ -1309,3 +1312,35 @@ class CustomTestCase(unittest.TestCase): lambda: super(CustomTestCase, self)._callTestMethod(method), max_retry=max_retry, ) + + +def dump_bench_raw_result( + path: str, + states, + preds, + labels, +): + if not path: + return + + rows = [] + for i in range(len(states)): + state = states[i] + output = state["answer"] + prompt = _ensure_remove_suffix(state.text(), output) + rows.append( + dict( + prompt_id=i, + prompt=prompt, + output=output, + correct=bool(preds[i] == labels[i]), + ) + ) + + print(f"BenchRawResultDumper save results to {path}") + Path(path).write_text("\n".join(json.dumps(row) for row in rows)) + + +def _ensure_remove_suffix(text: str, suffix: str): + assert text.endswith(suffix) + return text.removesuffix(suffix)