Benchmark for reasoning models (#3532)
Co-authored-by: Chayenne <zhaochen20@outlook.com>
This commit is contained in:
54
benchmark/reasoning_benchmark/README.md
Normal file
54
benchmark/reasoning_benchmark/README.md
Normal file
@@ -0,0 +1,54 @@
|
|||||||
|
# Run benchmark
|
||||||
|
|
||||||
|
This benchmark is primarily intended to be used with reasoning models like `DeepSeek-R1` and its distilled models like `DeepSeek-R1-Distill-Qwen-1.5B`. Please use
|
||||||
|
|
||||||
|
```bash
|
||||||
|
pip install antlr4-python3-runtime
|
||||||
|
```
|
||||||
|
|
||||||
|
for `parse_latex` which we use for symbolic equality check.
|
||||||
|
|
||||||
|
## Benchmark sglang
|
||||||
|
|
||||||
|
1. Launch the Server
|
||||||
|
|
||||||
|
```bash
|
||||||
|
python3 -m sglang.launch_server --model-path deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B --port 30000
|
||||||
|
```
|
||||||
|
|
||||||
|
Note that depending on the GPU this benchmark will take quiet some time. To employ data parallelism please use:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
python3 -m sglang_router.launch_server --model-path deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B --port 30000 --dp-size 4
|
||||||
|
```
|
||||||
|
|
||||||
|
2. Benchmarking
|
||||||
|
|
||||||
|
We use [suggested](https://github.com/deepseek-ai/DeepSeek-R1) parameters of `temperature=0.6`, `top_p=.95`, `max_new_tokens=32768`. The command line argument `num-tries` can be used to evaluate the model multiple times on the same question. We use the suggested `64` from the repo for AIME 2024. For LIMO, we use `8` as the number of tries due to the size of the dataset.
|
||||||
|
|
||||||
|
By default evaluate on LIMO dataset.
|
||||||
|
|
||||||
|
```bash
|
||||||
|
python3 bench_sglang.py --parallel 256 --num-tries 64 --port 30000
|
||||||
|
```
|
||||||
|
|
||||||
|
Evaluate on AIME 2024 dataset.
|
||||||
|
|
||||||
|
```bash
|
||||||
|
python3 bench_sglang.py --parallel 256 --port 30000 --data-path Maxwell-Jia/AIME_2024 --question-key Problem --answer-key Answer --num-tries 64
|
||||||
|
```
|
||||||
|
|
||||||
|
Evaluate on [AIME 2025 I dataset](https://huggingface.co/datasets/opencompass/AIME2025). For benchmark result see [here](https://matharena.ai/).
|
||||||
|
|
||||||
|
```bash
|
||||||
|
python3 bench_sglang.py --parallel 256 --port 30000 --data-path opencompass/AIME2025 --question-key question --answer-key answer --num-tries 64
|
||||||
|
```
|
||||||
|
|
||||||
|
|
||||||
|
## Results
|
||||||
|
|
||||||
|
| Dataset | Num Tries | Accuracy | Reference |
|
||||||
|
|------------|-----------|----------|-----------|
|
||||||
|
| LIMO | 8 | 47.7% | ? |
|
||||||
|
| AIME 2024 | 64 | 33.2% | 28.9% |
|
||||||
|
| AIME 2025 I| 64 | 29.9% | 25.0% |
|
||||||
269
benchmark/reasoning_benchmark/answer_extraction.py
Normal file
269
benchmark/reasoning_benchmark/answer_extraction.py
Normal file
@@ -0,0 +1,269 @@
|
|||||||
|
# Adapted from https://github.com/deepseek-ai/DeepSeek-Math/blob/main/evaluation/data_processing/answer_extraction.py
|
||||||
|
|
||||||
|
import re
|
||||||
|
|
||||||
|
import regex
|
||||||
|
|
||||||
|
|
||||||
|
def _fix_fracs(string):
|
||||||
|
substrs = string.split("\\frac")
|
||||||
|
new_str = substrs[0]
|
||||||
|
if len(substrs) > 1:
|
||||||
|
substrs = substrs[1:]
|
||||||
|
for substr in substrs:
|
||||||
|
new_str += "\\frac"
|
||||||
|
if len(substr) > 0 and substr[0] == "{":
|
||||||
|
new_str += substr
|
||||||
|
else:
|
||||||
|
try:
|
||||||
|
assert len(substr) >= 2
|
||||||
|
except:
|
||||||
|
return string
|
||||||
|
a = substr[0]
|
||||||
|
b = substr[1]
|
||||||
|
if b != "{":
|
||||||
|
if len(substr) > 2:
|
||||||
|
post_substr = substr[2:]
|
||||||
|
new_str += "{" + a + "}{" + b + "}" + post_substr
|
||||||
|
else:
|
||||||
|
new_str += "{" + a + "}{" + b + "}"
|
||||||
|
else:
|
||||||
|
if len(substr) > 2:
|
||||||
|
post_substr = substr[2:]
|
||||||
|
new_str += "{" + a + "}" + b + post_substr
|
||||||
|
else:
|
||||||
|
new_str += "{" + a + "}" + b
|
||||||
|
string = new_str
|
||||||
|
return string
|
||||||
|
|
||||||
|
|
||||||
|
def _fix_a_slash_b(string):
|
||||||
|
if len(string.split("/")) != 2:
|
||||||
|
return string
|
||||||
|
a = string.split("/")[0]
|
||||||
|
b = string.split("/")[1]
|
||||||
|
try:
|
||||||
|
if "sqrt" not in a:
|
||||||
|
a = int(a)
|
||||||
|
if "sqrt" not in b:
|
||||||
|
b = int(b)
|
||||||
|
assert string == "{}/{}".format(a, b)
|
||||||
|
new_string = "\\frac{" + str(a) + "}{" + str(b) + "}"
|
||||||
|
return new_string
|
||||||
|
except:
|
||||||
|
return string
|
||||||
|
|
||||||
|
|
||||||
|
def _fix_sqrt(string):
|
||||||
|
_string = re.sub(r"\\sqrt(-?[0-9.a-zA-Z]+)", r"\\sqrt{\1}", string)
|
||||||
|
_string = re.sub(r"\\sqrt\s+(\w+)$", r"\\sqrt{\1}", _string)
|
||||||
|
return _string
|
||||||
|
|
||||||
|
|
||||||
|
def _fix_tan(string):
|
||||||
|
_string = re.sub(r"\\tan(-?[0-9.a-zA-Z]+)", r"\\tan{\1}", string)
|
||||||
|
_string = re.sub(r"\\tan\s+(\w+)$", r"\\tan{\1}", _string)
|
||||||
|
return _string
|
||||||
|
|
||||||
|
|
||||||
|
def strip_string(string):
|
||||||
|
string = str(string).strip()
|
||||||
|
# linebreaks
|
||||||
|
string = string.replace("\n", "")
|
||||||
|
|
||||||
|
# right "."
|
||||||
|
string = string.rstrip(".")
|
||||||
|
|
||||||
|
# remove inverse spaces
|
||||||
|
string = string.replace("\\!", "")
|
||||||
|
# string = string.replace("\\ ", "")
|
||||||
|
|
||||||
|
# replace \\ with \
|
||||||
|
# string = string.replace("\\\\", "\\")
|
||||||
|
# string = string.replace("\\\\", "\\")
|
||||||
|
|
||||||
|
if string.startswith("\\text{") and string.endswith("}"):
|
||||||
|
string = string.split("{", 1)[1][:-1]
|
||||||
|
|
||||||
|
# replace tfrac and dfrac with frac
|
||||||
|
string = string.replace("tfrac", "frac")
|
||||||
|
string = string.replace("dfrac", "frac")
|
||||||
|
string = string.replace("cfrac", "frac")
|
||||||
|
|
||||||
|
# remove \left and \right
|
||||||
|
string = string.replace("\\left", "")
|
||||||
|
string = string.replace("\\right", "")
|
||||||
|
|
||||||
|
# Remove unit: miles, dollars if after is not none
|
||||||
|
_string = re.sub(r"\\text{.*?}$", "", string).strip()
|
||||||
|
if _string != "" and _string != string:
|
||||||
|
# print("Warning: unit not removed: '{}' -> '{}'".format(string, _string))
|
||||||
|
string = _string
|
||||||
|
|
||||||
|
# Remove circ (degrees)
|
||||||
|
string = string.replace("^{\\circ}", "").strip()
|
||||||
|
string = string.replace("^\\circ", "").strip()
|
||||||
|
|
||||||
|
string = regex.sub(r"\{(c|m)?m\}(\^(2|3))?", "", string).strip()
|
||||||
|
string = regex.sub(r"p\.m\.$", "", string).strip()
|
||||||
|
string = regex.sub(r"(\d)\s*t$", r"\1", string).strip()
|
||||||
|
|
||||||
|
# remove dollar signs
|
||||||
|
string = string.replace("\\$", "")
|
||||||
|
string = string.replace("$", "")
|
||||||
|
|
||||||
|
# string = string.replace("\\text", "")
|
||||||
|
string = string.replace("x\\in", "")
|
||||||
|
|
||||||
|
# remove percentage
|
||||||
|
string = string.replace("\\%", "%")
|
||||||
|
string = string.replace("\%", "%")
|
||||||
|
# string = string.replace("%", "")
|
||||||
|
|
||||||
|
# " 0." equivalent to " ." and "{0." equivalent to "{." Alternatively, add "0" if "." is the start of the string
|
||||||
|
string = string.replace(" .", " 0.")
|
||||||
|
string = string.replace("{.", "{0.")
|
||||||
|
|
||||||
|
# cdot
|
||||||
|
string = string.replace("\\cdot", "")
|
||||||
|
|
||||||
|
# inf
|
||||||
|
string = string.replace("infinity", "\\infty")
|
||||||
|
if "\\infty" not in string:
|
||||||
|
string = string.replace("inf", "\\infty")
|
||||||
|
string = string.replace("+\\inity", "\\infty")
|
||||||
|
|
||||||
|
# and
|
||||||
|
# string = string.replace("and", "")
|
||||||
|
string = string.replace("\\mathbf", "")
|
||||||
|
string = string.replace("\\mathrm", "")
|
||||||
|
|
||||||
|
# use regex to remove \mbox{...}
|
||||||
|
string = re.sub(r"\\mbox{.*?}", "", string)
|
||||||
|
|
||||||
|
# quote
|
||||||
|
string.replace("'", "")
|
||||||
|
string.replace('"', "")
|
||||||
|
|
||||||
|
# i, j
|
||||||
|
if "j" in string and "i" not in string:
|
||||||
|
string = string.replace("j", "i")
|
||||||
|
|
||||||
|
# replace a.000b where b is not number or b is end, with ab, use regex
|
||||||
|
string = re.sub(r"(\d+)\.0+([^\d])", r"\1\2", string)
|
||||||
|
string = re.sub(r"(\d+)\.0+$", r"\1", string)
|
||||||
|
|
||||||
|
# if empty, return empty string
|
||||||
|
if len(string) == 0:
|
||||||
|
return string
|
||||||
|
if string[0] == ".":
|
||||||
|
string = "0" + string
|
||||||
|
|
||||||
|
# to consider: get rid of e.g. "k = " or "q = " at beginning
|
||||||
|
# if len(string.split("=")) == 2:
|
||||||
|
# if len(string.split("=")[0]) <= 2:
|
||||||
|
# string = string.split("=")[1]
|
||||||
|
|
||||||
|
string = _fix_sqrt(string)
|
||||||
|
string = _fix_tan(string)
|
||||||
|
string = string.replace(" ", "")
|
||||||
|
|
||||||
|
# \frac1b or \frac12 --> \frac{1}{b} and \frac{1}{2}, etc. Even works with \frac1{72} (but not \frac{72}1). Also does a/b --> \\frac{a}{b}
|
||||||
|
string = _fix_fracs(string)
|
||||||
|
|
||||||
|
# NOTE: X/Y changed to \frac{X}{Y} in dataset, but in simple cases fix in case the model output is X/Y
|
||||||
|
string = _fix_a_slash_b(string)
|
||||||
|
|
||||||
|
string = regex.sub(r"(\\|,|\.)+$", "", string)
|
||||||
|
|
||||||
|
return string
|
||||||
|
|
||||||
|
|
||||||
|
def extract_boxed_answers(text):
|
||||||
|
answers = []
|
||||||
|
for piece in text.split("boxed{")[1:]:
|
||||||
|
n = 0
|
||||||
|
for i in range(len(piece)):
|
||||||
|
if piece[i] == "{":
|
||||||
|
n += 1
|
||||||
|
elif piece[i] == "}":
|
||||||
|
n -= 1
|
||||||
|
if n < 0:
|
||||||
|
if i + 1 < len(piece) and piece[i + 1] == "%":
|
||||||
|
answers.append(piece[: i + 1])
|
||||||
|
else:
|
||||||
|
answers.append(piece[:i])
|
||||||
|
break
|
||||||
|
return answers
|
||||||
|
|
||||||
|
|
||||||
|
def extract_program_output(pred_str):
|
||||||
|
"""
|
||||||
|
extract output between the last ```output\n...\n```
|
||||||
|
"""
|
||||||
|
if "```output" not in pred_str:
|
||||||
|
return ""
|
||||||
|
if "```output" in pred_str:
|
||||||
|
pred_str = pred_str.split("```output")[-1]
|
||||||
|
if "```" in pred_str:
|
||||||
|
pred_str = pred_str.split("```")[0]
|
||||||
|
output = pred_str.strip()
|
||||||
|
return output
|
||||||
|
|
||||||
|
|
||||||
|
def extract_answer(pred_str, exhaust=False):
|
||||||
|
pred = []
|
||||||
|
if "final answer is $" in pred_str and "$. I hope" in pred_str:
|
||||||
|
tmp = pred_str.split("final answer is $", 1)[1]
|
||||||
|
pred = [tmp.split("$. I hope", 1)[0].strip()]
|
||||||
|
elif "boxed" in pred_str:
|
||||||
|
pred = extract_boxed_answers(pred_str)
|
||||||
|
elif "he answer is" in pred_str:
|
||||||
|
pred = [pred_str.split("he answer is")[-1].strip()]
|
||||||
|
else:
|
||||||
|
program_output = extract_program_output(pred_str)
|
||||||
|
if program_output != "":
|
||||||
|
# fall back to program
|
||||||
|
pred.append(program_output)
|
||||||
|
else: # use the last number
|
||||||
|
pattern = "-?\d*\.?\d+"
|
||||||
|
ans = re.findall(pattern, pred_str.replace(",", ""))
|
||||||
|
if len(ans) >= 1:
|
||||||
|
ans = ans[-1]
|
||||||
|
else:
|
||||||
|
ans = ""
|
||||||
|
if ans:
|
||||||
|
pred.append(ans)
|
||||||
|
|
||||||
|
# multiple line
|
||||||
|
_pred = []
|
||||||
|
for ans in pred:
|
||||||
|
ans = ans.strip().split("\n")[0]
|
||||||
|
ans = ans.lstrip(":")
|
||||||
|
ans = ans.rstrip(".")
|
||||||
|
ans = ans.rstrip("/")
|
||||||
|
ans = strip_string(ans)
|
||||||
|
_pred.append(ans)
|
||||||
|
if exhaust:
|
||||||
|
return _pred
|
||||||
|
else:
|
||||||
|
return _pred[-1] if _pred else ""
|
||||||
|
|
||||||
|
|
||||||
|
def extract_math_answer(question, reasoning, task):
|
||||||
|
answer = []
|
||||||
|
for ans in extract_answer(reasoning, exhaust=True):
|
||||||
|
if "separated by commas" in question and all(ch not in ans for ch in "()[]"):
|
||||||
|
answer.extend([a.strip() for a in ans.split(",")])
|
||||||
|
elif regex.search(r"\\text\{\s*and\s*\}", ans):
|
||||||
|
answer.extend(
|
||||||
|
[
|
||||||
|
a.strip()
|
||||||
|
for a in regex.sub(r"\\text\{\s*and\s*\}", "[SEP]", ans).split(
|
||||||
|
"[SEP]"
|
||||||
|
)
|
||||||
|
]
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
answer.append(ans.strip())
|
||||||
|
return answer
|
||||||
119
benchmark/reasoning_benchmark/bench_sglang.py
Normal file
119
benchmark/reasoning_benchmark/bench_sglang.py
Normal file
@@ -0,0 +1,119 @@
|
|||||||
|
import argparse
|
||||||
|
import json
|
||||||
|
import time
|
||||||
|
|
||||||
|
import answer_extraction
|
||||||
|
import eval_utils
|
||||||
|
from datasets import load_dataset
|
||||||
|
|
||||||
|
import sglang as sgl
|
||||||
|
from sglang.test.test_utils import (
|
||||||
|
add_common_sglang_args_and_parse,
|
||||||
|
select_sglang_backend,
|
||||||
|
)
|
||||||
|
from sglang.utils import dump_state_text
|
||||||
|
|
||||||
|
|
||||||
|
@sgl.function
|
||||||
|
def reasoning_gen(s, question: str):
|
||||||
|
s += sgl.user(
|
||||||
|
question
|
||||||
|
+ "\nPlease reason step by step, and put your final answer within \boxed{}."
|
||||||
|
)
|
||||||
|
s += sgl.assistant(
|
||||||
|
sgl.gen(
|
||||||
|
"answer",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def convert_dataset(path: str, question_key: str, answer_key: str, num_tries: int):
|
||||||
|
raw_dataset = load_dataset(path)
|
||||||
|
questions = []
|
||||||
|
answers = []
|
||||||
|
for data in raw_dataset["train"]:
|
||||||
|
question = data[question_key]
|
||||||
|
answer = data[answer_key]
|
||||||
|
for _ in range(num_tries):
|
||||||
|
questions.append({"question": question})
|
||||||
|
answers.append({"answer": answer})
|
||||||
|
return questions, answers
|
||||||
|
|
||||||
|
|
||||||
|
def main(args):
|
||||||
|
# Select backend
|
||||||
|
sgl.set_default_backend(select_sglang_backend(args))
|
||||||
|
|
||||||
|
# Get dataset
|
||||||
|
questions, answers = convert_dataset(
|
||||||
|
args.data_path, args.question_key, args.answer_key, args.num_tries
|
||||||
|
)
|
||||||
|
|
||||||
|
# Run requests
|
||||||
|
tic = time.time()
|
||||||
|
states = reasoning_gen.run_batch(
|
||||||
|
questions,
|
||||||
|
num_threads=args.parallel,
|
||||||
|
progress_bar=True,
|
||||||
|
temperature=0.6,
|
||||||
|
max_new_tokens=32768,
|
||||||
|
top_p=0.95,
|
||||||
|
)
|
||||||
|
latency = time.time() - tic
|
||||||
|
|
||||||
|
# Extract answers
|
||||||
|
correct = 0
|
||||||
|
for i, state in enumerate(states):
|
||||||
|
try:
|
||||||
|
pred_answer = answer_extraction.extract_math_answer(
|
||||||
|
questions[i]["question"], state["answer"], "limo"
|
||||||
|
)
|
||||||
|
gt_answer = str(answers[i]["answer"])
|
||||||
|
# Use last answer if multiple were extracted
|
||||||
|
pred_answer = (
|
||||||
|
pred_answer[-1] if isinstance(pred_answer, list) else pred_answer
|
||||||
|
)
|
||||||
|
correct += 1 if eval_utils.math_equal(pred_answer, gt_answer) else 0
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error extracting answer: {e}")
|
||||||
|
pass
|
||||||
|
|
||||||
|
# Calculate accuracy
|
||||||
|
accuracy = correct / len(questions)
|
||||||
|
print(f"Accuracy: {accuracy}")
|
||||||
|
|
||||||
|
# Calculate output throughput
|
||||||
|
num_output_tokens = sum(
|
||||||
|
s.get_meta_info("answer")["completion_tokens"] for s in states
|
||||||
|
)
|
||||||
|
output_throughput = num_output_tokens / latency
|
||||||
|
print(f"Output throughput: {output_throughput} token/s")
|
||||||
|
|
||||||
|
# Dump results
|
||||||
|
dump_state_text(f"tmp_output_{args.backend}.txt", states)
|
||||||
|
|
||||||
|
# Write results
|
||||||
|
with open(args.result_file, "a") as fout:
|
||||||
|
value = {
|
||||||
|
"task": "limo",
|
||||||
|
"backend": args.backend,
|
||||||
|
"latency": round(latency, 3),
|
||||||
|
"accuracy": round(accuracy, 3),
|
||||||
|
"num_requests": len(questions),
|
||||||
|
"other": {
|
||||||
|
"num_questions": len(questions),
|
||||||
|
"parallel": args.parallel,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
fout.write(json.dumps(value) + "\n")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
parser.add_argument("--data-path", type=str, default="GAIR/LIMO")
|
||||||
|
parser.add_argument("--question-key", type=str, default="question")
|
||||||
|
parser.add_argument("--answer-key", type=str, default="answer")
|
||||||
|
parser.add_argument("--num-tries", type=int, default=1)
|
||||||
|
add_common_sglang_args_and_parse(parser)
|
||||||
|
args = parser.parse_args()
|
||||||
|
main(args)
|
||||||
206
benchmark/reasoning_benchmark/eval_utils.py
Normal file
206
benchmark/reasoning_benchmark/eval_utils.py
Normal file
@@ -0,0 +1,206 @@
|
|||||||
|
# Adapted from https://github.com/deepseek-ai/DeepSeek-Math/blob/main/evaluation/eval/eval_utils.py
|
||||||
|
|
||||||
|
from math import isclose
|
||||||
|
|
||||||
|
import regex
|
||||||
|
from sympy import N, simplify
|
||||||
|
from sympy.parsing.latex import parse_latex
|
||||||
|
from sympy.parsing.sympy_parser import parse_expr
|
||||||
|
|
||||||
|
|
||||||
|
def parse_digits(num):
|
||||||
|
# format: 234.23 || 23%
|
||||||
|
num = regex.sub(",", "", str(num))
|
||||||
|
try:
|
||||||
|
return float(num)
|
||||||
|
except:
|
||||||
|
if num.endswith("%"):
|
||||||
|
num = num[:-1]
|
||||||
|
if num.endswith("\\"):
|
||||||
|
num = num[:-1]
|
||||||
|
try:
|
||||||
|
return float(num) / 100
|
||||||
|
except:
|
||||||
|
pass
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def is_digit(num):
|
||||||
|
# paired with parse_digits
|
||||||
|
return parse_digits(num) is not None
|
||||||
|
|
||||||
|
|
||||||
|
def symbolic_equal(a, b):
|
||||||
|
def _parse(s):
|
||||||
|
for f in [parse_latex, parse_expr]:
|
||||||
|
try:
|
||||||
|
return f(s)
|
||||||
|
except:
|
||||||
|
pass
|
||||||
|
return s
|
||||||
|
|
||||||
|
a = _parse(a)
|
||||||
|
b = _parse(b)
|
||||||
|
|
||||||
|
try:
|
||||||
|
if simplify(a - b) == 0:
|
||||||
|
return True
|
||||||
|
except:
|
||||||
|
pass
|
||||||
|
|
||||||
|
try:
|
||||||
|
if isclose(N(a), N(b), abs_tol=1e-3):
|
||||||
|
return True
|
||||||
|
except:
|
||||||
|
pass
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
def math_equal(prediction, reference, include_percentage=True, is_close=True):
|
||||||
|
"""
|
||||||
|
Exact match of math if and only if:
|
||||||
|
1. numerical equal: both can convert to float and are equal
|
||||||
|
2. symbolic equal: both can convert to sympy expression and are equal
|
||||||
|
"""
|
||||||
|
if str(prediction) == str(reference):
|
||||||
|
return True
|
||||||
|
|
||||||
|
try: # 1. numerical equal
|
||||||
|
if is_digit(prediction) and is_digit(reference):
|
||||||
|
prediction = parse_digits(prediction)
|
||||||
|
reference = parse_digits(reference)
|
||||||
|
# number questions
|
||||||
|
if include_percentage:
|
||||||
|
gt_result = [reference / 100, reference, reference * 100]
|
||||||
|
else:
|
||||||
|
gt_result = [reference]
|
||||||
|
for item in gt_result:
|
||||||
|
try:
|
||||||
|
if is_close:
|
||||||
|
if isclose(item, prediction, abs_tol=1e-3):
|
||||||
|
return True
|
||||||
|
else:
|
||||||
|
if item == prediction:
|
||||||
|
return True
|
||||||
|
except Exception:
|
||||||
|
continue
|
||||||
|
return False
|
||||||
|
except:
|
||||||
|
pass
|
||||||
|
|
||||||
|
if not prediction and prediction not in [0, False]:
|
||||||
|
return False
|
||||||
|
|
||||||
|
# 2. symbolic equal
|
||||||
|
reference = str(reference).strip()
|
||||||
|
prediction = str(prediction).strip()
|
||||||
|
|
||||||
|
if (
|
||||||
|
regex.match(r"(\(|\[).+(\)|\])", prediction) is not None
|
||||||
|
and regex.match(r"(\(|\[).+(\)|\])", reference) is not None
|
||||||
|
):
|
||||||
|
pred_parts = prediction[1:-1].split(",")
|
||||||
|
ref_parts = reference[1:-1].split(",")
|
||||||
|
if len(pred_parts) == len(ref_parts):
|
||||||
|
if all(
|
||||||
|
[
|
||||||
|
math_equal(
|
||||||
|
pred_parts[i], ref_parts[i], include_percentage, is_close
|
||||||
|
)
|
||||||
|
for i in range(len(pred_parts))
|
||||||
|
]
|
||||||
|
):
|
||||||
|
return True
|
||||||
|
|
||||||
|
# Add back matrix comparison
|
||||||
|
if (
|
||||||
|
(
|
||||||
|
prediction.startswith("\\begin{pmatrix}")
|
||||||
|
or prediction.startswith("\\begin{bmatrix}")
|
||||||
|
)
|
||||||
|
and (
|
||||||
|
prediction.endswith("\\end{pmatrix}")
|
||||||
|
or prediction.endswith("\\end{bmatrix}")
|
||||||
|
)
|
||||||
|
and (
|
||||||
|
reference.startswith("\\begin{pmatrix}")
|
||||||
|
or reference.startswith("\\begin{bmatrix}")
|
||||||
|
)
|
||||||
|
and (
|
||||||
|
reference.endswith("\\end{pmatrix}") or reference.endswith("\\end{bmatrix}")
|
||||||
|
)
|
||||||
|
):
|
||||||
|
pred_lines = [
|
||||||
|
line.strip()
|
||||||
|
for line in prediction[
|
||||||
|
len("\\begin{pmatrix}") : -len("\\end{pmatrix}")
|
||||||
|
].split("\\\\")
|
||||||
|
if line.strip()
|
||||||
|
]
|
||||||
|
ref_lines = [
|
||||||
|
line.strip()
|
||||||
|
for line in reference[
|
||||||
|
len("\\begin{pmatrix}") : -len("\\end{pmatrix}")
|
||||||
|
].split("\\\\")
|
||||||
|
if line.strip()
|
||||||
|
]
|
||||||
|
matched = True
|
||||||
|
if len(pred_lines) == len(ref_lines):
|
||||||
|
for pred_line, ref_line in zip(pred_lines, ref_lines):
|
||||||
|
pred_parts = pred_line.split("&")
|
||||||
|
ref_parts = ref_line.split("&")
|
||||||
|
if len(pred_parts) == len(ref_parts):
|
||||||
|
if not all(
|
||||||
|
[
|
||||||
|
math_equal(
|
||||||
|
pred_parts[i],
|
||||||
|
ref_parts[i],
|
||||||
|
include_percentage,
|
||||||
|
is_close,
|
||||||
|
)
|
||||||
|
for i in range(len(pred_parts))
|
||||||
|
]
|
||||||
|
):
|
||||||
|
matched = False
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
matched = False
|
||||||
|
if not matched:
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
matched = False
|
||||||
|
if matched:
|
||||||
|
return True
|
||||||
|
|
||||||
|
# Add back equation comparison
|
||||||
|
if prediction.count("=") == 1 and reference.count("=") == 1:
|
||||||
|
pred = prediction.split("=")
|
||||||
|
pred = f"{pred[0].strip()} - ({pred[1].strip()})"
|
||||||
|
ref = reference.split("=")
|
||||||
|
ref = f"{ref[0].strip()} - ({ref[1].strip()})"
|
||||||
|
if symbolic_equal(pred, ref) or symbolic_equal(f"-({pred})", ref):
|
||||||
|
return True
|
||||||
|
elif (
|
||||||
|
prediction.count("=") == 1
|
||||||
|
and len(prediction.split("=")[0].strip()) <= 2
|
||||||
|
and "=" not in reference
|
||||||
|
):
|
||||||
|
if math_equal(
|
||||||
|
prediction.split("=")[1], reference, include_percentage, is_close
|
||||||
|
):
|
||||||
|
return True
|
||||||
|
elif (
|
||||||
|
reference.count("=") == 1
|
||||||
|
and len(reference.split("=")[0].strip()) <= 2
|
||||||
|
and "=" not in prediction
|
||||||
|
):
|
||||||
|
if math_equal(
|
||||||
|
prediction, reference.split("=")[1], include_percentage, is_close
|
||||||
|
):
|
||||||
|
return True
|
||||||
|
|
||||||
|
# symbolic equal with sympy
|
||||||
|
if symbolic_equal(prediction, reference):
|
||||||
|
return True
|
||||||
|
|
||||||
|
return False
|
||||||
Reference in New Issue
Block a user