diff --git a/python/sglang/srt/debug_utils/text_comparator.py b/python/sglang/srt/debug_utils/text_comparator.py
index 5917fcfb6..3a6df19b9 100644
--- a/python/sglang/srt/debug_utils/text_comparator.py
+++ b/python/sglang/srt/debug_utils/text_comparator.py
@@ -1,4 +1,5 @@
 import argparse
+import hashlib
 import json
 from pathlib import Path
 
@@ -13,7 +14,11 @@ Supported inputs:
 
 
 def main(args):
-    df_input = _transform_df_input(_compute_df_raw(args))
+    if args.data_type == "simple_evals":
+        df_input = _compute_df_input_mode_simple_evals(args)
+    else:
+        df_input = _transform_df_input(_compute_df_raw(args))
+
     assert all(
         c in df_input.columns
         for c in ["category", "trial_index", "prompt_id", "prompt", "output", "correct"]
@@ -37,8 +42,9 @@ def main(args):
                 df_meta=df_meta.to_dicts(),
                 df_good_to_bad=df_good_to_bad.to_dicts(),
                 df_bad_to_good=df_bad_to_good.to_dicts(),
-            )
-        )
+            ),
+            indent=4,
+        ),
     )
 
     if not args.disable_print_details:
@@ -65,19 +71,70 @@ def main(args):
         print(df)
 
 
+def _compute_df_input_mode_simple_evals(args):
+    return pl.concat(
+        [
+            _compute_df_input_one_mode_simple_evals(**info)
+            for info in _get_file_infos(args=args)
+        ]
+    )
+
+
+def _compute_df_input_one_mode_simple_evals(path, category, trial_index):
+    data = json.loads(Path(path).read_text())
+    rows = []
+
+    for single_eval_result in data["metadata"]["single_eval_results"]:
+        prompt = single_eval_result["example_level_metadata"][
+            "actual_queried_prompt_messages"
+        ]
+        score = single_eval_result["score"]
+        assert score in {0.0, 1.0}, f"{score=}"
+
+        row = dict(
+            category=category,
+            trial_index=trial_index,
+            prompt_id=_compute_id_from_object(prompt),
+            prompt=json.dumps(prompt),
+            output=single_eval_result["example_level_metadata"]["response_text"],
+            correct=score == 1.0,
+        )
+        rows.append(row)
+
+    return pl.DataFrame(rows)
+
+
+def _compute_id_from_object(obj):
+    if isinstance(obj, pl.Series):
+        obj = obj.to_list()
+    json_str = json.dumps(obj, sort_keys=True, ensure_ascii=False)
+    return hashlib.sha256(json_str.encode("utf-8")).hexdigest()
+
+
 def _compute_df_raw(args):
     return pl.concat(
         [
-            _read_df_raw(p, category=category, trial_index=i)
-            for category, paths in [
-                ("baseline", args.baseline_path),
-                ("target", args.target_path),
-            ]
-            for i, p in enumerate(paths)
+            _read_df_raw(
+                path=info["path"],
+                category=info["category"],
+                trial_index=info["trial_index"],
+            )
+            for info in _get_file_infos(args=args)
         ]
     )
 
 
+def _get_file_infos(args):
+    return [
+        dict(path=path, category=category, trial_index=trial_index)
+        for category, paths in [
+            ("baseline", args.baseline_path),
+            ("target", args.target_path),
+        ]
+        for trial_index, path in enumerate(paths)
+    ]
+
+
 def _read_df_raw(path: str, category: str, trial_index: int):
     return pl.read_ndjson(path).with_columns(
         category=pl.lit(category), trial_index=trial_index
     )
@@ -108,7 +165,9 @@ def _transform_df_input(df: pl.DataFrame):
         print("Transform mode: SGLang bench")
         return df
     else:
-        raise Exception(f"Unknown data: {df.columns}")
+        raise Exception(
+            f"Unknown data: {df.columns}. You may need to set `--data-type` if using e.g. simple_evals."
+        )
 
 
 def _compute_df_meta(df_input: pl.DataFrame):
@@ -127,7 +186,9 @@ def _compute_df_meta(df_input: pl.DataFrame):
 
 
 def _handle_one_prompt(df_one_prompt: pl.DataFrame):
-    assert len(set(df_one_prompt["prompt"])) == 1
+    assert (
+        len(set(_compute_id_from_object(obj) for obj in df_one_prompt["prompt"])) == 1
+    )
 
     df_baseline = df_one_prompt.filter(pl.col("category") == "baseline")
     df_target = df_one_prompt.filter(pl.col("category") == "target")
@@ -162,6 +223,7 @@ def _compute_str_prefix_len(a: str, b: str) -> int:
 
 if __name__ == "__main__":
     parser = argparse.ArgumentParser(description=_DESCRIPTION)
+    parser.add_argument("--data-type", type=str, default="auto")
     parser.add_argument("--baseline-path", type=str, nargs="+")
    parser.add_argument("--target-path", type=str, nargs="+")
     parser.add_argument(
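
Example invocation (a sketch, not part of the patch: the JSON file names are hypothetical, but `--data-type`, `--baseline-path`, and `--target-path` all match the argparse definitions added above, and `python -m` assumes the sglang package is importable):

    python -m sglang.srt.debug_utils.text_comparator \
        --data-type simple_evals \
        --baseline-path baseline_trial0.json baseline_trial1.json \
        --target-path target_trial0.json

With the default `--data-type auto`, the tool keeps the pre-existing behavior: it goes through `_transform_df_input(_compute_df_raw(args))` and auto-detects the input format, raising the extended "Unknown data" error if detection fails.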