adapt to sglang v0.5.2rc1 on dcu
This commit is contained in:
206
benchmark/reasoning_benchmark/eval_utils.py
Normal file
206
benchmark/reasoning_benchmark/eval_utils.py
Normal file
@@ -0,0 +1,206 @@
|
||||
# Adapted from https://github.com/deepseek-ai/DeepSeek-Math/blob/main/evaluation/eval/eval_utils.py
|
||||
|
||||
from math import isclose
|
||||
|
||||
import regex
|
||||
from sympy import N, simplify
|
||||
from sympy.parsing.latex import parse_latex
|
||||
from sympy.parsing.sympy_parser import parse_expr
|
||||
|
||||
|
||||
def parse_digits(num):
|
||||
# format: 234.23 || 23%
|
||||
num = regex.sub(",", "", str(num))
|
||||
try:
|
||||
return float(num)
|
||||
except:
|
||||
if num.endswith("%"):
|
||||
num = num[:-1]
|
||||
if num.endswith("\\"):
|
||||
num = num[:-1]
|
||||
try:
|
||||
return float(num) / 100
|
||||
except:
|
||||
pass
|
||||
return None
|
||||
|
||||
|
||||
def is_digit(num):
|
||||
# paired with parse_digits
|
||||
return parse_digits(num) is not None
|
||||
|
||||
|
||||
def symbolic_equal(a, b):
|
||||
def _parse(s):
|
||||
for f in [parse_latex, parse_expr]:
|
||||
try:
|
||||
return f(s)
|
||||
except:
|
||||
pass
|
||||
return s
|
||||
|
||||
a = _parse(a)
|
||||
b = _parse(b)
|
||||
|
||||
try:
|
||||
if simplify(a - b) == 0:
|
||||
return True
|
||||
except:
|
||||
pass
|
||||
|
||||
try:
|
||||
if isclose(N(a), N(b), abs_tol=1e-3):
|
||||
return True
|
||||
except:
|
||||
pass
|
||||
return False
|
||||
|
||||
|
||||
def math_equal(prediction, reference, include_percentage=True, is_close=True):
|
||||
"""
|
||||
Exact match of math if and only if:
|
||||
1. numerical equal: both can convert to float and are equal
|
||||
2. symbolic equal: both can convert to sympy expression and are equal
|
||||
"""
|
||||
if str(prediction) == str(reference):
|
||||
return True
|
||||
|
||||
try: # 1. numerical equal
|
||||
if is_digit(prediction) and is_digit(reference):
|
||||
prediction = parse_digits(prediction)
|
||||
reference = parse_digits(reference)
|
||||
# number questions
|
||||
if include_percentage:
|
||||
gt_result = [reference / 100, reference, reference * 100]
|
||||
else:
|
||||
gt_result = [reference]
|
||||
for item in gt_result:
|
||||
try:
|
||||
if is_close:
|
||||
if isclose(item, prediction, abs_tol=1e-3):
|
||||
return True
|
||||
else:
|
||||
if item == prediction:
|
||||
return True
|
||||
except Exception:
|
||||
continue
|
||||
return False
|
||||
except:
|
||||
pass
|
||||
|
||||
if not prediction and prediction not in [0, False]:
|
||||
return False
|
||||
|
||||
# 2. symbolic equal
|
||||
reference = str(reference).strip()
|
||||
prediction = str(prediction).strip()
|
||||
|
||||
if (
|
||||
regex.match(r"(\(|\[).+(\)|\])", prediction) is not None
|
||||
and regex.match(r"(\(|\[).+(\)|\])", reference) is not None
|
||||
):
|
||||
pred_parts = prediction[1:-1].split(",")
|
||||
ref_parts = reference[1:-1].split(",")
|
||||
if len(pred_parts) == len(ref_parts):
|
||||
if all(
|
||||
[
|
||||
math_equal(
|
||||
pred_parts[i], ref_parts[i], include_percentage, is_close
|
||||
)
|
||||
for i in range(len(pred_parts))
|
||||
]
|
||||
):
|
||||
return True
|
||||
|
||||
# Add back matrix comparison
|
||||
if (
|
||||
(
|
||||
prediction.startswith("\\begin{pmatrix}")
|
||||
or prediction.startswith("\\begin{bmatrix}")
|
||||
)
|
||||
and (
|
||||
prediction.endswith("\\end{pmatrix}")
|
||||
or prediction.endswith("\\end{bmatrix}")
|
||||
)
|
||||
and (
|
||||
reference.startswith("\\begin{pmatrix}")
|
||||
or reference.startswith("\\begin{bmatrix}")
|
||||
)
|
||||
and (
|
||||
reference.endswith("\\end{pmatrix}") or reference.endswith("\\end{bmatrix}")
|
||||
)
|
||||
):
|
||||
pred_lines = [
|
||||
line.strip()
|
||||
for line in prediction[
|
||||
len("\\begin{pmatrix}") : -len("\\end{pmatrix}")
|
||||
].split("\\\\")
|
||||
if line.strip()
|
||||
]
|
||||
ref_lines = [
|
||||
line.strip()
|
||||
for line in reference[
|
||||
len("\\begin{pmatrix}") : -len("\\end{pmatrix}")
|
||||
].split("\\\\")
|
||||
if line.strip()
|
||||
]
|
||||
matched = True
|
||||
if len(pred_lines) == len(ref_lines):
|
||||
for pred_line, ref_line in zip(pred_lines, ref_lines):
|
||||
pred_parts = pred_line.split("&")
|
||||
ref_parts = ref_line.split("&")
|
||||
if len(pred_parts) == len(ref_parts):
|
||||
if not all(
|
||||
[
|
||||
math_equal(
|
||||
pred_parts[i],
|
||||
ref_parts[i],
|
||||
include_percentage,
|
||||
is_close,
|
||||
)
|
||||
for i in range(len(pred_parts))
|
||||
]
|
||||
):
|
||||
matched = False
|
||||
break
|
||||
else:
|
||||
matched = False
|
||||
if not matched:
|
||||
break
|
||||
else:
|
||||
matched = False
|
||||
if matched:
|
||||
return True
|
||||
|
||||
# Add back equation comparison
|
||||
if prediction.count("=") == 1 and reference.count("=") == 1:
|
||||
pred = prediction.split("=")
|
||||
pred = f"{pred[0].strip()} - ({pred[1].strip()})"
|
||||
ref = reference.split("=")
|
||||
ref = f"{ref[0].strip()} - ({ref[1].strip()})"
|
||||
if symbolic_equal(pred, ref) or symbolic_equal(f"-({pred})", ref):
|
||||
return True
|
||||
elif (
|
||||
prediction.count("=") == 1
|
||||
and len(prediction.split("=")[0].strip()) <= 2
|
||||
and "=" not in reference
|
||||
):
|
||||
if math_equal(
|
||||
prediction.split("=")[1], reference, include_percentage, is_close
|
||||
):
|
||||
return True
|
||||
elif (
|
||||
reference.count("=") == 1
|
||||
and len(reference.split("=")[0].strip()) <= 2
|
||||
and "=" not in prediction
|
||||
):
|
||||
if math_equal(
|
||||
prediction, reference.split("=")[1], include_percentage, is_close
|
||||
):
|
||||
return True
|
||||
|
||||
# symbolic equal with sympy
|
||||
if symbolic_equal(prediction, reference):
|
||||
return True
|
||||
|
||||
return False
|
||||
Reference in New Issue
Block a user