Variance measure for reasoning benchmark (#3677)
This commit is contained in:
@@ -4,6 +4,7 @@ import time
|
|||||||
|
|
||||||
import answer_extraction
|
import answer_extraction
|
||||||
import eval_utils
|
import eval_utils
|
||||||
|
import numpy as np
|
||||||
from datasets import load_dataset
|
from datasets import load_dataset
|
||||||
|
|
||||||
import sglang as sgl
|
import sglang as sgl
|
||||||
@@ -61,26 +62,40 @@ def main(args):
|
|||||||
)
|
)
|
||||||
latency = time.time() - tic
|
latency = time.time() - tic
|
||||||
|
|
||||||
# Extract answers
|
# Extract results and record outcomes in a list.
|
||||||
correct = 0
|
outcomes = []
|
||||||
for i, state in enumerate(states):
|
for i, state in enumerate(states):
|
||||||
try:
|
try:
|
||||||
pred_answer = answer_extraction.extract_math_answer(
|
pred_answer = answer_extraction.extract_math_answer(
|
||||||
questions[i]["question"], state["answer"], "limo"
|
questions[i]["question"], state["answer"], "limo"
|
||||||
)
|
)
|
||||||
gt_answer = str(answers[i]["answer"])
|
gt_answer = str(answers[i]["answer"])
|
||||||
# Use last answer if multiple were extracted
|
|
||||||
pred_answer = (
|
pred_answer = (
|
||||||
pred_answer[-1] if isinstance(pred_answer, list) else pred_answer
|
pred_answer[-1] if isinstance(pred_answer, list) else pred_answer
|
||||||
)
|
)
|
||||||
correct += 1 if eval_utils.math_equal(pred_answer, gt_answer) else 0
|
is_correct = 1 if eval_utils.math_equal(pred_answer, gt_answer) else 0
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"Error extracting answer: {e}")
|
print(f"Error extracting answer: {e}")
|
||||||
pass
|
is_correct = 0
|
||||||
|
|
||||||
# Calculate accuracy
|
outcomes.append(is_correct)
|
||||||
accuracy = correct / len(questions)
|
|
||||||
print(f"Accuracy: {accuracy}")
|
# Calculate overall accuracy using numpy
|
||||||
|
overall_accuracy = np.mean(outcomes)
|
||||||
|
print(f"Overall Accuracy: {overall_accuracy}")
|
||||||
|
|
||||||
|
# Calculate mean standard error over questions if num_tries >= 2
|
||||||
|
if args.num_tries > 1:
|
||||||
|
outcomes_np = np.array(outcomes).reshape(-1, args.num_tries)
|
||||||
|
# Using sample standard deviation with ddof=1
|
||||||
|
std_per_question = np.std(outcomes_np, axis=1, ddof=1)
|
||||||
|
# Compute the standard error for each question: std / sqrt(num_tries)
|
||||||
|
se_per_question = std_per_question / np.sqrt(args.num_tries)
|
||||||
|
mean_se = se_per_question.mean()
|
||||||
|
print(f"Mean Standard Error of Accuracy across questions: {mean_se}")
|
||||||
|
else:
|
||||||
|
mean_se = None
|
||||||
|
print("Not enough samples per question to compute standard error.")
|
||||||
|
|
||||||
# Calculate output throughput
|
# Calculate output throughput
|
||||||
num_output_tokens = sum(
|
num_output_tokens = sum(
|
||||||
@@ -98,7 +113,8 @@ def main(args):
|
|||||||
"task": "limo",
|
"task": "limo",
|
||||||
"backend": args.backend,
|
"backend": args.backend,
|
||||||
"latency": round(latency, 3),
|
"latency": round(latency, 3),
|
||||||
"accuracy": round(accuracy, 3),
|
"overall_accuracy": round(overall_accuracy, 3),
|
||||||
|
"mean_se_accuracy": round(mean_se, 3) if mean_se is not None else None,
|
||||||
"num_requests": len(questions),
|
"num_requests": len(questions),
|
||||||
"other": {
|
"other": {
|
||||||
"num_questions": len(questions),
|
"num_questions": len(questions),
|
||||||
|
|||||||
Reference in New Issue
Block a user