diff --git a/benchmark/reasoning_benchmark/bench_sglang.py b/benchmark/reasoning_benchmark/bench_sglang.py
index de0f13971..c83204960 100644
--- a/benchmark/reasoning_benchmark/bench_sglang.py
+++ b/benchmark/reasoning_benchmark/bench_sglang.py
@@ -4,6 +4,7 @@ import time
 
 import answer_extraction
 import eval_utils
+import numpy as np
 from datasets import load_dataset
 
 import sglang as sgl
@@ -61,26 +62,40 @@ def main(args):
     )
     latency = time.time() - tic
 
-    # Extract answers
-    correct = 0
+    # Extract results and record outcomes in a list.
+    outcomes = []
     for i, state in enumerate(states):
         try:
             pred_answer = answer_extraction.extract_math_answer(
                 questions[i]["question"], state["answer"], "limo"
             )
             gt_answer = str(answers[i]["answer"])
-            # Use last answer if multiple were extracted
             pred_answer = (
                 pred_answer[-1] if isinstance(pred_answer, list) else pred_answer
             )
-            correct += 1 if eval_utils.math_equal(pred_answer, gt_answer) else 0
+            is_correct = 1 if eval_utils.math_equal(pred_answer, gt_answer) else 0
         except Exception as e:
             print(f"Error extracting answer: {e}")
-            pass
+            is_correct = 0
 
-    # Calculate accuracy
-    accuracy = correct / len(questions)
-    print(f"Accuracy: {accuracy}")
+        outcomes.append(is_correct)
+
+    # Calculate overall accuracy using numpy
+    overall_accuracy = np.mean(outcomes)
+    print(f"Overall Accuracy: {overall_accuracy}")
+
+    # Calculate mean standard error over questions if num_tries >= 2
+    if args.num_tries > 1:
+        outcomes_np = np.array(outcomes).reshape(-1, args.num_tries)
+        # Using sample standard deviation with ddof=1
+        std_per_question = np.std(outcomes_np, axis=1, ddof=1)
+        # Compute the standard error for each question: std / sqrt(num_tries)
+        se_per_question = std_per_question / np.sqrt(args.num_tries)
+        mean_se = se_per_question.mean()
+        print(f"Mean Standard Error of Accuracy across questions: {mean_se}")
+    else:
+        mean_se = None
+        print("Not enough samples per question to compute standard error.")
 
     # Calculate output throughput
     num_output_tokens = sum(
@@ -98,7 +113,8 @@ def main(args):
         "task": "limo",
         "backend": args.backend,
         "latency": round(latency, 3),
-        "accuracy": round(accuracy, 3),
+        "overall_accuracy": round(overall_accuracy, 3),
+        "mean_se_accuracy": round(mean_se, 3) if mean_se is not None else None,
         "num_requests": len(questions),
         "other": {
             "num_questions": len(questions),
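
Note: the mean-standard-error computation added in the second hunk reduces to a few lines of numpy. The standalone sketch below walks through it on made-up outcome data (the values and the names `outcomes` and `num_tries` are illustrative only, not taken from the benchmark). Like the patch, it assumes the `num_tries` repeats of each question land contiguously in `outcomes`, which is what the reshape relies on.

import numpy as np

# Illustrative 0/1 outcomes for 3 questions x 4 tries each.
outcomes = [
    1, 1, 0, 1,  # question 0: 3/4 correct
    0, 0, 0, 0,  # question 1: 0/4 correct
    1, 1, 1, 1,  # question 2: 4/4 correct
]
num_tries = 4

# Overall accuracy pools every (question, try) pair equally.
print(np.mean(outcomes))  # 0.666...

# One row per question, one column per try.
outcomes_np = np.array(outcomes).reshape(-1, num_tries)

# Sample standard deviation (ddof=1) over each question's tries, then the
# standard error of that question's mean accuracy: std / sqrt(num_tries).
std_per_question = np.std(outcomes_np, axis=1, ddof=1)
se_per_question = std_per_question / np.sqrt(num_tries)
print(se_per_question)         # [0.25 0.   0.  ]
print(se_per_question.mean())  # 0.0833...

One property worth knowing: a question answered identically on every try (all correct or all incorrect) contributes a standard error of exactly 0, so the reported mean SE reflects only the questions where the model's outcome varies across tries.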