Support simple evals in text comparator (#8867)
@@ -1,4 +1,5 @@
 import argparse
+import hashlib
 import json
 from pathlib import Path
 
@@ -13,7 +14,11 @@ Supported inputs:
 
 def main(args):
-    df_input = _transform_df_input(_compute_df_raw(args))
+    if args.data_type == "simple_evals":
+        df_input = _compute_df_input_mode_simple_evals(args)
+    else:
+        df_input = _transform_df_input(_compute_df_raw(args))
 
     assert all(
         c in df_input.columns
         for c in ["category", "trial_index", "prompt_id", "prompt", "output", "correct"]
@@ -37,8 +42,9 @@ def main(args):
                 df_meta=df_meta.to_dicts(),
                 df_good_to_bad=df_good_to_bad.to_dicts(),
                 df_bad_to_good=df_bad_to_good.to_dicts(),
-            )
-        )
+            ),
+            indent=4,
+        ),
     )
 
     if not args.disable_print_details:
@@ -65,19 +71,70 @@ def main(args):
         print(df)
 
 
+def _compute_df_input_mode_simple_evals(args):
+    return pl.concat(
+        [
+            _compute_df_input_one_mode_simple_evals(**info)
+            for info in _get_file_infos(args=args)
+        ]
+    )
+
+
+def _compute_df_input_one_mode_simple_evals(path, category, trial_index):
+    data = json.loads(Path(path).read_text())
+    rows = []
+
+    for single_eval_result in data["metadata"]["single_eval_results"]:
+        prompt = single_eval_result["example_level_metadata"][
+            "actual_queried_prompt_messages"
+        ]
+        score = single_eval_result["score"]
+        assert score in {0.0, 1.0}, f"{score=}"
+
+        row = dict(
+            category=category,
+            trial_index=trial_index,
+            prompt_id=_compute_id_from_object(prompt),
+            prompt=json.dumps(prompt),
+            output=single_eval_result["example_level_metadata"]["response_text"],
+            correct=score == 1.0,
+        )
+        rows.append(row)
+
+    return pl.DataFrame(rows)
+
+
+def _compute_id_from_object(obj):
+    if isinstance(obj, pl.Series):
+        obj = obj.to_list()
+    json_str = json.dumps(obj, sort_keys=True, ensure_ascii=False)
+    return hashlib.sha256(json_str.encode("utf-8")).hexdigest()
+
+
 def _compute_df_raw(args):
     return pl.concat(
         [
-            _read_df_raw(p, category=category, trial_index=i)
-            for category, paths in [
-                ("baseline", args.baseline_path),
-                ("target", args.target_path),
-            ]
-            for i, p in enumerate(paths)
+            _read_df_raw(
+                path=info["path"],
+                category=info["category"],
+                trial_index=info["trial_index"],
+            )
+            for info in _get_file_infos(args=args)
         ]
     )
+
+
+def _get_file_infos(args):
+    return [
+        dict(path=path, category=category, trial_index=trial_index)
+        for category, paths in [
+            ("baseline", args.baseline_path),
+            ("target", args.target_path),
+        ]
+        for trial_index, path in enumerate(paths)
+    ]
 
 
 def _read_df_raw(path: str, category: str, trial_index: int):
     return pl.read_ndjson(path).with_columns(
         category=pl.lit(category), trial_index=trial_index
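
Note: the input layout expected by `--data-type simple_evals` can be read off
`_compute_df_input_one_mode_simple_evals` above. A minimal illustrative file,
showing only the keys the parser actually touches (the values here are made up):

    {
        "metadata": {
            "single_eval_results": [
                {
                    "score": 1.0,
                    "example_level_metadata": {
                        "actual_queried_prompt_messages": [
                            {"role": "user", "content": "What is 2 + 2?"}
                        ],
                        "response_text": "4"
                    }
                }
            ]
        }
    }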
@@ -108,7 +165,9 @@ def _transform_df_input(df: pl.DataFrame):
         print("Transform mode: SGLang bench")
         return df
     else:
-        raise Exception(f"Unknown data: {df.columns}")
+        raise Exception(
+            f"Unknown data: {df.columns}. You may need to set `--data-type` if using e.g. simple_evals."
+        )
 
 
 def _compute_df_meta(df_input: pl.DataFrame):
@@ -127,7 +186,9 @@ def _compute_df_meta(df_input: pl.DataFrame):
 
 
 def _handle_one_prompt(df_one_prompt: pl.DataFrame):
-    assert len(set(df_one_prompt["prompt"])) == 1
+    assert (
+        len(set(_compute_id_from_object(obj) for obj in df_one_prompt["prompt"])) == 1
+    )
 
     df_baseline = df_one_prompt.filter(pl.col("category") == "baseline")
     df_target = df_one_prompt.filter(pl.col("category") == "target")
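
Note: `_compute_id_from_object` hashes a canonical JSON serialization
(`sort_keys=True`), so prompts that are equal up to dict key order map to the
same `prompt_id`; that is the invariant the rewritten assert in
`_handle_one_prompt` relies on. A self-contained sketch (`compute_id` mirrors
the helper added above):

    import hashlib
    import json

    def compute_id(obj):
        # Canonical form: sorted keys, stable text encoding.
        return hashlib.sha256(
            json.dumps(obj, sort_keys=True, ensure_ascii=False).encode("utf-8")
        ).hexdigest()

    a = [{"role": "user", "content": "hi"}]
    b = [{"content": "hi", "role": "user"}]  # same data, different key order
    assert compute_id(a) == compute_id(b)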
@@ -162,6 +223,7 @@ def _compute_str_prefix_len(a: str, b: str) -> int:
 
 if __name__ == "__main__":
     parser = argparse.ArgumentParser(description=_DESCRIPTION)
+    parser.add_argument("--data-type", type=str, default="auto")
     parser.add_argument("--baseline-path", type=str, nargs="+")
     parser.add_argument("--target-path", type=str, nargs="+")
     parser.add_argument(
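
Note: with this change, comparing two sets of simple_evals result files looks
roughly like the following (the script name is illustrative; substitute the
comparator's actual path):

    python text_comparator.py \
        --data-type simple_evals \
        --baseline-path baseline_run0.json baseline_run1.json \
        --target-path target_run0.json target_run1.json

`--data-type` defaults to "auto", which keeps the previous
`_transform_df_input(_compute_df_raw(args))` path for the existing NDJSON inputs.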