Support simple evals in text comparator (#8867)
This commit is contained in:
@@ -1,4 +1,5 @@
|
||||
import argparse
|
||||
import hashlib
|
||||
import json
|
||||
from pathlib import Path
|
||||
|
||||
@@ -13,7 +14,11 @@ Supported inputs:
|
||||
|
||||
|
||||
def main(args):
|
||||
df_input = _transform_df_input(_compute_df_raw(args))
|
||||
if args.data_type == "simple_evals":
|
||||
df_input = _compute_df_input_mode_simple_evals(args)
|
||||
else:
|
||||
df_input = _transform_df_input(_compute_df_raw(args))
|
||||
|
||||
assert all(
|
||||
c in df_input.columns
|
||||
for c in ["category", "trial_index", "prompt_id", "prompt", "output", "correct"]
|
||||
@@ -37,8 +42,9 @@ def main(args):
|
||||
df_meta=df_meta.to_dicts(),
|
||||
df_good_to_bad=df_good_to_bad.to_dicts(),
|
||||
df_bad_to_good=df_bad_to_good.to_dicts(),
|
||||
)
|
||||
)
|
||||
),
|
||||
indent=4,
|
||||
),
|
||||
)
|
||||
|
||||
if not args.disable_print_details:
|
||||
@@ -65,19 +71,70 @@ def main(args):
|
||||
print(df)
|
||||
|
||||
|
||||
def _compute_df_input_mode_simple_evals(args):
|
||||
return pl.concat(
|
||||
[
|
||||
_compute_df_input_one_mode_simple_evals(**info)
|
||||
for info in _get_file_infos(args=args)
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
def _compute_df_input_one_mode_simple_evals(path, category, trial_index):
|
||||
data = json.loads(Path(path).read_text())
|
||||
rows = []
|
||||
|
||||
for single_eval_result in data["metadata"]["single_eval_results"]:
|
||||
prompt = single_eval_result["example_level_metadata"][
|
||||
"actual_queried_prompt_messages"
|
||||
]
|
||||
score = single_eval_result["score"]
|
||||
assert score in {0.0, 1.0}, f"{score=}"
|
||||
|
||||
row = dict(
|
||||
category=category,
|
||||
trial_index=trial_index,
|
||||
prompt_id=_compute_id_from_object(prompt),
|
||||
prompt=json.dumps(prompt),
|
||||
output=single_eval_result["example_level_metadata"]["response_text"],
|
||||
correct=score == 1.0,
|
||||
)
|
||||
rows.append(row)
|
||||
|
||||
return pl.DataFrame(rows)
|
||||
|
||||
|
||||
def _compute_id_from_object(obj):
|
||||
if isinstance(obj, pl.Series):
|
||||
obj = obj.to_list()
|
||||
json_str = json.dumps(obj, sort_keys=True, ensure_ascii=False)
|
||||
return hashlib.sha256(json_str.encode("utf-8")).hexdigest()
|
||||
|
||||
|
||||
def _compute_df_raw(args):
|
||||
return pl.concat(
|
||||
[
|
||||
_read_df_raw(p, category=category, trial_index=i)
|
||||
for category, paths in [
|
||||
("baseline", args.baseline_path),
|
||||
("target", args.target_path),
|
||||
]
|
||||
for i, p in enumerate(paths)
|
||||
_read_df_raw(
|
||||
path=info["path"],
|
||||
category=info["category"],
|
||||
trial_index=info["trial_index"],
|
||||
)
|
||||
for info in _get_file_infos(args=args)
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
def _get_file_infos(args):
|
||||
return [
|
||||
dict(path=path, category=category, trial_index=trial_index)
|
||||
for category, paths in [
|
||||
("baseline", args.baseline_path),
|
||||
("target", args.target_path),
|
||||
]
|
||||
for trial_index, path in enumerate(paths)
|
||||
]
|
||||
|
||||
|
||||
def _read_df_raw(path: str, category: str, trial_index: int):
|
||||
return pl.read_ndjson(path).with_columns(
|
||||
category=pl.lit(category), trial_index=trial_index
|
||||
@@ -108,7 +165,9 @@ def _transform_df_input(df: pl.DataFrame):
|
||||
print("Transform mode: SGLang bench")
|
||||
return df
|
||||
else:
|
||||
raise Exception(f"Unknown data: {df.columns}")
|
||||
raise Exception(
|
||||
f"Unknown data: {df.columns}. You may need to set `--data-type` if using e.g. simple_evals."
|
||||
)
|
||||
|
||||
|
||||
def _compute_df_meta(df_input: pl.DataFrame):
|
||||
@@ -127,7 +186,9 @@ def _compute_df_meta(df_input: pl.DataFrame):
|
||||
|
||||
|
||||
def _handle_one_prompt(df_one_prompt: pl.DataFrame):
|
||||
assert len(set(df_one_prompt["prompt"])) == 1
|
||||
assert (
|
||||
len(set(_compute_id_from_object(obj) for obj in df_one_prompt["prompt"])) == 1
|
||||
)
|
||||
|
||||
df_baseline = df_one_prompt.filter(pl.col("category") == "baseline")
|
||||
df_target = df_one_prompt.filter(pl.col("category") == "target")
|
||||
@@ -162,6 +223,7 @@ def _compute_str_prefix_len(a: str, b: str) -> int:
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description=_DESCRIPTION)
|
||||
parser.add_argument("--data-type", type=str, default="auto")
|
||||
parser.add_argument("--baseline-path", type=str, nargs="+")
|
||||
parser.add_argument("--target-path", type=str, nargs="+")
|
||||
parser.add_argument(
|
||||
|
||||
Reference in New Issue
Block a user