初始化项目,由ModelHub XC社区提供模型
Model: nv-community/AceMath-7B-Instruct Source: Original Platform
This commit is contained in:
600
evaluation/grader.py
Normal file
600
evaluation/grader.py
Normal file
@@ -0,0 +1,600 @@
|
||||
|
||||
"""
|
||||
This script is adapted from Qwen2.5-Math
|
||||
https://github.com/QwenLM/Qwen2.5-Math/blob/main/evaluation/grader.py
|
||||
"""
|
||||
|
||||
import re
|
||||
import regex
|
||||
import multiprocessing
|
||||
from math import isclose
|
||||
from typing import Union
|
||||
from collections import defaultdict
|
||||
|
||||
from sympy import simplify, N
|
||||
from sympy.parsing.sympy_parser import parse_expr
|
||||
from sympy.parsing.latex import parse_latex
|
||||
|
||||
|
||||
def latex2sympy(sympy: str, variable_values=None):
    """Normalise a LaTeX string and convert it through the ``PS`` lexer/parser.

    Args:
        sympy: LaTeX source text. (The parameter name is kept for backward
            compatibility even though it holds LaTeX.)
        variable_values: Optional mapping stored in the module-level
            ``VARIABLE_VALUES``. ``None`` or an empty mapping resets that
            global to a fresh ``{}``.

    Returns:
        ``convert_relation(...)`` of the parsed relation, or a list of such
        results when the input parses as a relation list.

    Side effects:
        Updates the globals ``frac_type`` (only when a fraction command is
        present in the input) and ``VARIABLE_VALUES``.

    NOTE(review): ``MathErrorListener``, ``InputStream``, ``PSLexer``,
    ``CommonTokenStream``, ``PSParser`` and ``convert_relation`` are not
    defined or imported in the visible part of this file; unless they are
    provided elsewhere this function raises ``NameError`` when called.
    """
    global frac_type, VARIABLE_VALUES

    # Remember which fraction command was used; later checks win, so the
    # precedence is \tfrac > \dfrac > \frac.
    for command in (r'\frac', r'\dfrac', r'\tfrac'):
        if command in sympy:
            frac_type = command
    sympy = sympy.replace(r'\dfrac', r'\frac').replace(r'\tfrac', r'\frac')

    # Transpose marker.
    sympy = sympy.replace(r'\mathrm{T}', 'T')
    # Upright differential "d".
    sympy = sympy.replace(r'\mathrm{d}', 'd').replace(r'{\rm d}', 'd')
    # Bracketed matrix environment -> bmatrix.
    sympy = sympy.replace(r'\left[\begin{matrix}', r'\begin{bmatrix}')
    sympy = sympy.replace(r'\end{matrix}\right]', r'\end{bmatrix}')
    # Permutation notation (n)_{k} -> n! / (n - k)!.
    sympy = re.sub(r"\(([a-zA-Z0-9+\-*/\\ ]+?)\)_{([a-zA-Z0-9+\-*/\\ ]+?)}", r"\\frac{(\1)!}{((\1)-(\2))!}", sympy)
    # Layout-only commands and dollar signs become plain spaces.
    for token in (r'\displaystyle', r'\quad', r'\qquad', r'~', r'\,', r'$'):
        sympy = sympy.replace(token, ' ')

    # A ``None`` default avoids sharing one mutable dict between calls.
    VARIABLE_VALUES = variable_values if variable_values else {}

    # One listener collects errors from both the lexer and the parser,
    # replacing the default console listeners.
    matherror = MathErrorListener(sympy)

    lex = PSLexer(InputStream(sympy))
    lex.removeErrorListeners()
    lex.addErrorListener(matherror)

    parser = PSParser(CommonTokenStream(lex))
    parser.removeErrorListeners()
    parser.addErrorListener(matherror)

    tree = parser.math()

    if tree.relation_list():
        # A list of relations: convert every item.
        items = tree.relation_list().relation_list_content().relation()
        return [convert_relation(item) for item in items]

    # A single relation.
    return convert_relation(tree.relation())
|
||||
|
||||
|
||||
def math_answer_cleaning(answer, dataset_name):
    """Strip irrelevant markup and unify the format of an answer string.

    The result is meant to be compared against another cleaned answer.

    Args:
        answer: Raw answer text (LaTeX allowed).
        dataset_name: Dataset identifier; ``"collegemath"``, ``"gsm8k"`` and
            ``"gaokao2023en"`` enable extra dataset-specific clean-ups.

    Returns:
        The cleaned, lower-cased answer string.
    """

    def _is_completely_wrapped_by_text(input_string):
        """Return the inner text of a whole-string ``\\text{...}``, else None.

        Parentheses and commas are removed from the inner text.
        """
        match = re.match(r'^\\text{(.*)}$', input_string)
        if match is None:
            return None
        inner = match.group(1)
        return inner.replace("(", "").replace(")", "").replace(",", "")

    # Drop an enclosing \text{...}; an empty inner text keeps the original.
    unwrapped = _is_completely_wrapped_by_text(answer)
    answer = unwrapped if unwrapped else answer

    # e.g. 5,\!460 -> 5460; 14{,}916 -> 14916; \$4 -> 4
    # Backslashes are written doubled: "\!", "\$", "\c" and "\q" are invalid
    # escape sequences in plain string literals.
    answer = answer.replace(",\\!", "").replace("{,}", "").replace("\\$", "")
    # e.g. \dfrac{3}{2} -> \frac{3}{2}
    answer = answer.replace("dfrac{", "frac{").replace("tfrac{", "frac{")
    # e.g. 121^\circ -> 121
    answer = answer.replace("^\\circ", "")
    answer = answer.replace("^{\\circ}", "")
    # Remove \quad, spaces, real newlines and literal "\n" sequences.
    answer = answer.replace("\\quad", "")
    answer = answer.replace(" ", "")
    answer = answer.replace("\n", "").replace("\\n", "")
    # e.g. 3.54\times10^{10} -> 3.54e10
    answer = re.sub(r'([+-]?\d*\.?\d+)[\\]times10\^{([+-]?\d+)}', r'\1e\2', answer)
    # e.g. 3.54\times10^10 -> 3.54e10
    answer = re.sub(r'([+-]?\d*\.?\d+)[\\]times10\^([+-]?\d+)', r'\1e\2', answer)
    # e.g. 558\,\text{nm} -> 558
    answer = re.sub(r'\\,\\text\{.*?\}', '', answer)
    # e.g. 558\text{nm} -> 558
    answer = re.sub(r'\\text\{.*?\}', '', answer)
    # e.g. 2^{10} -> 2^10
    answer = re.sub(r'(\d+)\^{(\d+)}', r'\1^\2', answer)
    answer = answer.lower()

    if dataset_name == "collegemath":
        # e.g. 558\mathrm{ft} -> 558
        answer = re.sub(r'\\mathrm\{.*?\}', '', answer)
        # Clean noisy "$(...)" fragments.
        answer = re.sub(r'\$\([^)]*\)', '', answer)
        # Each trailing token is removed at most once, in this order.
        if answer.endswith("-"):
            answer = answer[:-1]
        if answer.endswith("."):
            answer = answer[:-1]
        if answer.endswith("hours"):
            answer = answer[:-len("hours")]
        # Keep only what follows the first '=' and then the first ':'.
        if "=" in answer:
            answer = answer.split("=", 1)[1]
        if ":" in answer:
            answer = answer.split(":", 1)[1]
        # \emptyset and \oslash both represent the empty set in LaTeX.
        answer = answer.replace("\\emptyset", "\\oslash")

    if dataset_name == "gsm8k":
        # e.g. 5,600 -> 5600
        answer = answer.replace(',', '')

    if dataset_name == "gaokao2023en":
        # The answer is already lower-cased and space-free at this point, so
        # every unit must be lower-case too (the former mixed-case
        # 'degreesontheBreadusscale' entry could never match).
        unit_strings = ['students', 'dollars', 'boxes', 'feet', 'kilometers', 'meters', 'degreesonthebreadusscale', '$', 'a.m.', 'am', 'minutes']
        for unit in unit_strings:
            answer = answer.replace(unit, "")

    return answer
|
||||
|
||||
|
||||
def extract_final_answer(output):
    """Collect the contents of every ``\\boxed{...}`` in ``output``.

    The pattern tolerates up to two levels of nested braces inside the box.

    Returns:
        A ``(last, all_matches)`` tuple: ``last`` is the content of the final
        box (``None`` when there is no box) and ``all_matches`` lists the
        contents of all boxes in order of appearance.
    """
    boxed_contents = re.findall(
        r"\\boxed\{((?:[^{}]|\{(?:[^{}]|\{[^{}]*\})*\})*)\}",
        output,
        flags=re.DOTALL,
    )
    final_answer = boxed_contents[-1] if boxed_contents else None
    return final_answer, boxed_contents
|
||||
|
||||
|
||||
def round_number(answer):
    """Round small-magnitude numeric answers to two significant digits.

    Values that parse as a float with ``abs(value) < 1`` are re-rendered with
    the ``.2g`` format (e.g. ``"5.56e-10"`` -> ``"5.6e-10"``) and returned as
    a string. Anything else is returned unchanged.

    The magnitude test uses ``abs()``: a plain ``value < 1`` test would also
    round large negative numbers, so that ``"-123"`` and ``"-124"`` would
    both become ``"-1.2e+02"`` and compare equal.
    """
    try:
        value = float(answer)
    except (TypeError, ValueError, OverflowError):
        # Not a number: leave the answer as it is.
        return answer

    if abs(value) < 1:
        return f"{value:.2g}"
    return answer
|
||||
|
||||
|
||||
def choice_answer_clean(pred: str):
    """Reduce a multiple-choice prediction to its final option letter.

    Surrounding newlines, a trailing '.', '/' and a leading ':' are removed
    first. If the upper-cased text contains stand-alone letters ``A``-``E``,
    the last one is returned; otherwise the trimmed text itself is returned.
    """
    trimmed = pred.strip("\n").rstrip(".").rstrip("/").strip(" ").lstrip(":")
    letters = re.findall(r"\b(A|B|C|D|E)\b", trimmed.upper())
    if letters:
        choice = letters[-1]
    else:
        choice = trimmed.strip().strip(".")
    # Strip a trailing period / slash once more.
    return choice.rstrip(".").rstrip("/")
|
||||
|
||||
|
||||
def parse_digits(num):
    """Parse ``num`` as a float, accepting thousands commas and percentages.

    Args:
        num: Any value; it is converted with ``str()`` and commas are removed.

    Returns:
        The float value; for text ending in ``%`` (optionally written as
        ``\\%``) the value divided by 100; ``None`` when parsing fails.
    """
    text = str(num).replace(",", "")
    try:
        return float(text)
    except ValueError:
        pass

    if not text.endswith("%"):
        return None

    # Percentage: drop the '%' and a LaTeX escape backslash before it.
    text = text[:-1]
    if text.endswith("\\"):
        text = text[:-1]
    try:
        return float(text) / 100
    except ValueError:
        return None
|
||||
|
||||
|
||||
def is_digit(num):
    """Return True when ``parse_digits`` can turn ``num`` into a number."""
    parsed = parse_digits(num)
    return parsed is not None
|
||||
|
||||
|
||||
def str_to_pmatrix(input_str):
    """Convert brace groups such as ``{1,2}`` into LaTeX ``pmatrix`` text.

    Every ``{...,...}`` match becomes
    ``\\begin{pmatrix}...\\end{pmatrix}`` with its commas turned into the
    LaTeX row separator ``\\\\`` (two backslashes). The pmatrix comparison in
    ``math_equal`` splits rows on that same two-backslash separator, so a
    single backslash here would never produce matching rows.

    Returns:
        The converted groups joined by ``", "`` (an empty string when the
        input contains no brace group with a comma).
    """
    groups = re.findall(r"\{.*,.*\}", input_str.strip())
    return ", ".join(
        r"\begin{pmatrix}" + group.strip("{}").replace(",", r"\\") + r"\end{pmatrix}"
        for group in groups
    )
|
||||
|
||||
|
||||
def math_equal(
    prediction: Union[bool, float, str],
    reference: Union[float, str],
    include_percentage: bool = True,
    is_close: bool = True,
    timeout: bool = False,
) -> bool:
    """Decide whether ``prediction`` matches ``reference`` mathematically.

    Checks are tried in this order and the first success returns ``True``:

    1. case-insensitive string equality,
    2. multiple-choice letter match (reference is one of ``A``-``E``),
    3. ``fraction_equal``,
    4. numerical equality (``round_number`` / ``parse_digits``),
    5. structural comparisons (brackets, tuples/intervals, matrices,
       equations),
    6. ``symbolic_equal``.

    Args:
        prediction: Predicted answer.
        reference: Ground-truth answer.
        include_percentage: Also accept ``reference / 100`` and
            ``reference * 100`` in the numerical comparison.
        is_close: Use ``numeric_equal`` (tolerant) rather than ``==`` for
            numbers.
        timeout: Run the final symbolic check through ``call_with_timeout``.

    Returns:
        ``True`` when any check succeeds, else ``False``.
    """
    if prediction is None or reference is None:
        return False

    # str() has to be applied *before* .strip(): the annotations allow
    # bool/float values, which have no string methods.
    if str(prediction).strip().lower() == str(reference).strip().lower():
        return True
    if (
        reference in ["A", "B", "C", "D", "E"]
        and choice_answer_clean(str(prediction)) == reference
    ):
        return True

    # fraction equal
    if fraction_equal(str(prediction), str(reference)):
        return True

    try:  # numerical equal
        if round_number(prediction) == round_number(reference):
            return True
        if is_digit(prediction) and is_digit(reference):
            prediction = parse_digits(prediction)
            reference = parse_digits(reference)
            # number questions
            if include_percentage:
                gt_result = [reference / 100, reference, reference * 100]
            else:
                gt_result = [reference]
            for item in gt_result:
                try:
                    if is_close:
                        if numeric_equal(prediction, item):
                            return True
                    elif item == prediction:
                        return True
                except Exception:
                    continue
            # Both sides are plain numbers and none of the candidates matched.
            return False
    except Exception:
        # Best effort: fall through to the symbolic comparisons.
        pass

    if not prediction and prediction not in [0, False]:
        return False

    # symbolic equal
    reference = str(reference).strip()
    prediction = str(prediction).strip()

    ## pmatrix (amps)
    if "pmatrix" in prediction and "pmatrix" not in reference:
        reference = str_to_pmatrix(reference)

    ## deal with [], (), {}
    pred_str, ref_str = prediction, reference
    if (
        prediction.startswith("[")
        and prediction.endswith("]")
        and not reference.startswith("(")
    ) or (
        prediction.startswith("(")
        and prediction.endswith(")")
        and not reference.startswith("[")
    ):
        pred_str = pred_str.strip("[]()")
        ref_str = ref_str.strip("[]()")
    for s in ["{", "}", "(", ")"]:
        ref_str = ref_str.replace(s, "")
        pred_str = pred_str.replace(s, "")
    if pred_str.lower() == ref_str.lower():
        return True

    ## [a, b] vs. [c, d], return a==c and b==d
    if (
        regex.match(r"(\(|\[).+(\)|\])", prediction) is not None
        and regex.match(r"(\(|\[).+(\)|\])", reference) is not None
    ):
        pred_parts = prediction[1:-1].split(",")
        ref_parts = reference[1:-1].split(",")
        if len(pred_parts) == len(ref_parts) and all(
            math_equal(pred_part, ref_part, include_percentage, is_close)
            for pred_part, ref_part in zip(pred_parts, ref_parts)
        ):
            return True

    ## matrices: compare row by row, cell by cell
    matrix_begin = ("\\begin{pmatrix}", "\\begin{bmatrix}")
    matrix_end = ("\\end{pmatrix}", "\\end{bmatrix}")
    if (
        prediction.startswith(matrix_begin)
        and prediction.endswith(matrix_end)
        and reference.startswith(matrix_begin)
        and reference.endswith(matrix_end)
    ):
        # The pmatrix and bmatrix delimiters have the same length, so one
        # slice removes either pair.
        body = slice(len("\\begin{pmatrix}"), -len("\\end{pmatrix}"))
        pred_lines = [
            line.strip() for line in prediction[body].split("\\\\") if line.strip()
        ]
        ref_lines = [
            line.strip() for line in reference[body].split("\\\\") if line.strip()
        ]
        matched = len(pred_lines) == len(ref_lines)
        if matched:
            for pred_line, ref_line in zip(pred_lines, ref_lines):
                pred_parts = pred_line.split("&")
                ref_parts = ref_line.split("&")
                if len(pred_parts) != len(ref_parts) or not all(
                    math_equal(pred_part, ref_part, include_percentage, is_close)
                    for pred_part, ref_part in zip(pred_parts, ref_parts)
                ):
                    matched = False
                    break
        if matched:
            return True

    if prediction.count("=") == 1 and reference.count("=") == 1:
        # Two equations: compare "lhs - (rhs)" up to an overall sign.
        pred = prediction.split("=")
        pred = f"{pred[0].strip()} - ({pred[1].strip()})"
        ref = reference.split("=")
        ref = f"{ref[0].strip()} - ({ref[1].strip()})"
        if symbolic_equal(pred, ref) or symbolic_equal(f"-({pred})", ref):
            return True
    elif (
        prediction.count("=") == 1
        and len(prediction.split("=")[0].strip()) <= 2
        and "=" not in reference
    ):
        # "x = value" against a bare value: compare the right-hand side.
        if math_equal(
            prediction.split("=")[1], reference, include_percentage, is_close
        ):
            return True
    elif (
        reference.count("=") == 1
        and len(reference.split("=")[0].strip()) <= 2
        and "=" not in prediction
    ):
        if math_equal(
            prediction, reference.split("=")[1], include_percentage, is_close
        ):
            return True

    # symbolic equal with sympy
    if timeout:
        if call_with_timeout(symbolic_equal_process, prediction, reference):
            return True
    elif symbolic_equal(prediction, reference):
        return True

    return False
|
||||
|
||||
|
||||
def numeric_equal(prediction: float, reference: float):
    """Return True when the two numbers agree within a relative tolerance of 1e-4.

    The relative tolerance has a significant impact on the result of the
    synthesized GSM-Hard dataset.
    """
    relative_tolerance = 1e-4
    return isclose(reference, prediction, rel_tol=relative_tolerance)
|
||||
|
||||
|
||||
def fraction_equal(prediction, reference):
    """Compare two answers after rewriting ``\\frac{a}{b}`` as ``(a/b)``.

    Returns ``True`` when the rewritten strings are identical, or when both
    evaluate as Python expressions to equal values; otherwise ``False``.
    """

    def _calculate_numbers(input_string):
        """Evaluate ``input_string`` as a Python expression; None on failure."""
        # SECURITY NOTE(review): eval() executes arbitrary Python. If the
        # strings can come from untrusted text (e.g. model output), this
        # should be replaced by a restricted arithmetic evaluator.
        try:
            return eval(input_string)
        except (Exception, SystemExit):
            # SystemExit is included on purpose: an evaluated "exit()" must
            # not terminate the grading process.
            return None

    reference = re.sub(r'\\frac{(.*?)}{(.*?)}', r'(\1/\2)', reference)
    prediction = re.sub(r'\\frac{(.*?)}{(.*?)}', r'(\1/\2)', prediction)

    if reference == prediction:
        return True

    reference = _calculate_numbers(reference)
    prediction = _calculate_numbers(prediction)

    # Compare against None explicitly: a plain truthiness test would reject
    # answers whose value is 0.
    if reference is not None and reference == prediction:
        return True

    return False
|
||||
|
||||
def symbolic_equal(a, b):
    """Best-effort symbolic comparison of two expressions.

    Both inputs are parsed with ``parse_latex``, ``parse_expr`` and
    ``latex2sympy`` (first success wins; the raw input is kept when every
    parser fails). The parsed values are then compared in turn by: string or
    direct equality, ``equals`` / ``simplify``, equation sides, numeric
    value, and matrices rounded to 3 decimals.

    Every stage is wrapped in ``except Exception`` so a failing stage just
    falls through to the next one, while ``KeyboardInterrupt`` and
    ``SystemExit`` are no longer swallowed.

    Returns:
        ``True`` as soon as one stage reports equality, else ``False``.
    """

    def _parse(s):
        # Try the text with doubled backslashes collapsed first, then the
        # raw text. When both are identical a failing parser is run only
        # once instead of twice.
        candidates = [s]
        if isinstance(s, str) and "\\\\" in s:
            candidates.insert(0, s.replace("\\\\", "\\"))
        for parser in (parse_latex, parse_expr, latex2sympy):
            for candidate in candidates:
                try:
                    return parser(candidate)
                except Exception:
                    continue
        return s

    a = _parse(a)
    b = _parse(b)

    # direct equal
    try:
        if str(a) == str(b) or a == b:
            return True
    except Exception:
        pass

    # simplify equal
    try:
        if a.equals(b) or simplify(a - b) == 0:
            return True
    except Exception:
        pass

    # equation equal: compare |lhs - rhs| of both sides
    try:
        if (abs(a.lhs - a.rhs)).equals(abs(b.lhs - b.rhs)):
            return True
    except Exception:
        pass

    # numeric value equal
    try:
        if numeric_equal(float(N(a)), float(N(b))):
            return True
    except Exception:
        pass

    # matrix: same shape and equal entries after rounding to 3 decimals
    try:
        if a.shape == b.shape:
            rounded_a = a.applyfunc(lambda x: round(x, 3))
            rounded_b = b.applyfunc(lambda x: round(x, 3))
            if rounded_a.equals(rounded_b):
                return True
    except Exception:
        pass

    return False
|
||||
|
||||
|
||||
def symbolic_equal_process(a, b, output_queue):
    """Run ``symbolic_equal(a, b)`` and put its result on ``output_queue``."""
    output_queue.put(symbolic_equal(a, b))
|
||||
|
||||
|
||||
def math_equal_process(prediction, reference, output_queue):
    """Run ``math_equal`` with ``timeout=True`` and put its result on ``output_queue``."""
    output_queue.put(math_equal(prediction, reference, timeout=True))
|
||||
|
||||
|
||||
def call_with_timeout(func, *args, timeout=1, **kwargs):
    """Run ``func`` in a separate process and return the value it reports.

    ``func`` is called as ``func(*args, output_queue, **kwargs)`` and is
    expected to ``put`` exactly one result on ``output_queue``.

    Args:
        func: Callable executed in the child process.
        *args: Positional arguments passed before the queue.
        timeout: Seconds to wait for the child before terminating it.
        **kwargs: Keyword arguments forwarded to ``func``.

    Returns:
        The value ``func`` put on the queue, or ``False`` when the child
        timed out, crashed, or finished without reporting a result.
    """
    # Local import: raised by Queue.get(timeout=...) when nothing arrives.
    from queue import Empty

    output_queue = multiprocessing.Queue()
    process = multiprocessing.Process(
        target=func, args=args + (output_queue,), kwargs=kwargs
    )
    process.start()
    process.join(timeout)

    if process.is_alive():
        # Still running after the deadline: kill it and report failure.
        process.terminate()
        process.join()
        return False

    if process.exitcode != 0:
        # The child died with an exception; nothing was reported, so an
        # unconditional get() would block forever.
        return False

    try:
        return output_queue.get(timeout=1)
    except Empty:
        # The child exited cleanly but never put a result.
        return False
|
||||
|
||||
|
||||
def check_correctness_of_multiple_answer_cases(prediction, reference, all_matches):
    """Fallback comparison for references that may hold two answers.

    Args:
        prediction: Cleaned final answer.
        reference: Cleaned ground truth; a ``,``, ``or`` or ``and`` in it is
            treated as a separator between two expected answers.
        all_matches: Every boxed answer found in the model output; the
            second-to-last entry is used as the other predicted answer.

    Returns:
        ``True`` when the answers agree, ``False`` otherwise.
    """

    def _plain(text):
        # Ignore commas and dollar signs for the quick comparison.
        return text.replace(",", "").replace("$", "")

    if _plain(prediction) == _plain(reference):
        return True

    # The part after the last '=' has to agree before anything else is tried.
    last_prediction = prediction.split("=")[-1]
    if last_prediction != reference.split("=")[-1].replace("$", ""):
        return False

    has_multiple_answers = "," in reference or "or" in reference or "and" in reference
    if not has_multiple_answers:
        return True

    # Two expected answers need at least two boxed predictions.
    if len(all_matches) <= 1:
        return False

    previous_prediction = all_matches[-2].split("=")[-1]

    gold = reference.replace("$", "")
    if "or" in gold:
        separator = "or"
    elif "and" in gold:
        separator = "and"
    else:
        separator = ","
    gold_parts = gold.split(separator, 1)
    gold_last = gold_parts[-1].split("=")[-1]
    gold_first = gold_parts[-2].split("=")[-1]

    # Accept the two predictions in either order.
    if math_equal(last_prediction, gold_last) and math_equal(previous_prediction, gold_first):
        return True
    if math_equal(previous_prediction, gold_last) and math_equal(last_prediction, gold_first):
        return True
    return False
|
||||
|
||||
|
||||
def is_equal(model_output, reference, dataset_name):
    """Grade a model output against the reference answer.

    The last ``\\boxed{...}`` answer is extracted from ``model_output``; it
    and ``reference`` are cleaned with ``math_answer_cleaning`` and compared
    via ``math_equal_process`` under ``call_with_timeout``. For the
    ``"collegemath"`` dataset a failed comparison is retried with
    ``check_correctness_of_multiple_answer_cases``.

    Returns:
        ``True`` when the answer is judged correct, otherwise ``False``
        (including when no boxed answer exists or ``reference`` is ``None``).
    """
    boxed_answer, boxed_all = extract_final_answer(model_output)
    if boxed_answer is None or reference is None:
        return False

    cleaned_answer = math_answer_cleaning(boxed_answer, dataset_name)
    cleaned_reference = math_answer_cleaning(reference, dataset_name)

    if call_with_timeout(math_equal_process, cleaned_answer, cleaned_reference):
        return True

    if dataset_name != "collegemath":
        return False
    return check_correctness_of_multiple_answer_cases(cleaned_answer, cleaned_reference, boxed_all)
|
||||
Reference in New Issue
Block a user