#!/usr/bin/env python3
"""Evaluate model on MATH-500 dataset (harder math problems)."""

import argparse
import json
import re
from collections import defaultdict
from typing import Optional

# Absolute tolerance used when two answers are compared as plain numbers.
_NUMERIC_TOLERANCE = 1e-6

_BOXED_OPEN_RE = re.compile(r"\\boxed\s*\{")

# Stops at a sentence-ending period (". " or "." at end of line) so that
# decimals such as "3.14" are captured whole.
_ANSWER_IS_RE = re.compile(
    r"[Tt]he (?:final )?answer is[:\s]*(.+?)(?=\.\s|\.?\s*$)",
    re.MULTILINE,
)

_TRAILING_EQUALS_RE = re.compile(r"=\s*([^\n=]+?)\s*$")

# A plain decimal number, optionally signed, optionally with thousands commas.
_NUMBER_RE = re.compile(r"[-+]?(?:(?:\d{1,3}(?:,\d{3})+|\d+)(?:\.\d*)?|\.\d+)")


def _last_boxed(response: str) -> str:
    r"""Return the contents of the last non-empty ``\boxed{...}`` in *response*.

    Braces are balanced by counting, so nested groups such as
    ``\boxed{\frac{1}{2}}`` are returned whole. Returns ``""`` when there is
    no complete, non-empty boxed expression.
    """
    for opener in reversed(list(_BOXED_OPEN_RE.finditer(response))):
        start = opener.end()
        depth = 1
        for pos in range(start, len(response)):
            char = response[pos]
            if char == "{":
                depth += 1
            elif char == "}":
                depth -= 1
                if depth == 0:
                    content = response[start:pos].strip()
                    if content:
                        return content
                    break
        # Unclosed or empty box (e.g. generation was truncated): try the
        # previous one.
    return ""


def extract_answer(response: str) -> str:
    r"""Extract the final answer from model response.

    Strategies are tried in order; the first that yields something wins:

    1. contents of the last ``\boxed{...}`` (nested braces supported),
    2. the text following "The (final) answer is",
    3. whatever follows the last ``=`` at the very end of the response,
    4. the last non-blank line.

    Returns ``""`` when the response contains no non-blank text.
    """
    boxed = _last_boxed(response)
    if boxed:
        return boxed

    answer_match = _ANSWER_IS_RE.search(response)
    if answer_match:
        return answer_match.group(1).strip()

    equals_match = _TRAILING_EQUALS_RE.search(response)
    if equals_match:
        return equals_match.group(1).strip()

    lines = [line.strip() for line in response.strip().split("\n") if line.strip()]
    return lines[-1] if lines else ""


def normalize_answer(answer: str) -> str:
    r"""Normalize answer for comparison.

    Unwraps ``\text{...}`` and ``\mathrm{...}``, drops ``\left`` / ``\right``
    and ``$`` signs, trims surrounding whitespace and lowercases the result.
    """
    answer = answer.strip()
    answer = re.sub(r"\\text\{([^}]*)\}", r"\1", answer)
    answer = re.sub(r"\\mathrm\{([^}]*)\}", r"\1", answer)
    answer = re.sub(r"\\left|\\right", "", answer)
    answer = re.sub(r"\$", "", answer)
    return answer.strip().lower()


def _parse_number(text: str) -> Optional[float]:
    """Return *text* as a float if it is a plain number as a whole, else None.

    A trailing percent sign and thousands separators are tolerated
    ("1,000", "50%"). Anything else -- expressions, fractions, tuples -- is
    rejected rather than having its digits glued together.
    """
    candidate = re.sub(r"\s*\\?%$", "", text.strip())
    if not _NUMBER_RE.fullmatch(candidate):
        return None
    return float(candidate.replace(",", ""))


def _contains_delimited(haystack: str, needle: str) -> bool:
    """Return True if *needle* is the answer part of *haystack*.

    *needle* must sit at the start of *haystack* or right after an ``=``
    ("x = 5"), and be followed by the end of the string or a plain word such
    as a unit ("5 cm"). This is a heuristic: it deliberately rejects digit
    substrings ("1" in "10") and LaTeX juxtaposition ("6" in "6 \\pi").
    """
    pattern = r"(?:^|=)\s*" + re.escape(needle) + r"(?:\s+[a-z]|\s*$)"
    return re.search(pattern, haystack) is not None


def answers_match(predicted: str, expected: str) -> bool:
    """Check if answers match (with some tolerance).

    After normalization the answers match when they are identical, when both
    are plain numbers within ``_NUMERIC_TOLERANCE`` of each other, or when one
    is the other with only a ``lhs =`` prefix and/or a unit-word suffix.
    An empty prediction never matches a non-empty expected answer.
    """
    pred_norm = normalize_answer(predicted)
    exp_norm = normalize_answer(expected)

    if pred_norm == exp_norm:
        return True

    # An empty string is a substring of everything; without this guard a
    # blank response would be scored as correct.
    if not pred_norm or not exp_norm:
        return False

    pred_num = _parse_number(pred_norm)
    exp_num = _parse_number(exp_norm)
    if pred_num is not None and exp_num is not None:
        if abs(pred_num - exp_num) < _NUMERIC_TOLERANCE:
            return True

    return _contains_delimited(pred_norm, exp_norm) or _contains_delimited(
        exp_norm, pred_norm
    )


def _print_breakdown(title: str, stats_by_key: dict) -> None:
    """Print per-key accuracy from a ``{key: {"correct", "total"}}`` mapping."""
    print(f"\n{title}:")
    for key in sorted(stats_by_key):
        stats = stats_by_key[key]
        acc = stats["correct"] / stats["total"] * 100
        print(f" {key}: {acc:.1f}% ({stats['correct']}/{stats['total']})")


def main():
    """Run the evaluation and print overall, per-level and per-subject accuracy."""
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument("--model-path", type=str, default="final_model")
    parser.add_argument(
        "--limit",
        type=int,
        default=100,
        help="maximum number of problems to evaluate; 0 evaluates all of them",
    )
    args = parser.parse_args()
    if args.limit < 0:
        parser.error("--limit must be >= 0")

    # Imported here rather than at module level so that the scoring helpers
    # above can be imported without the inference stack installed, and so
    # that --help / argument errors return immediately.
    from datasets import load_dataset
    from vllm import LLM, SamplingParams

    print("Loading MATH-500 dataset...")
    dataset = load_dataset("HuggingFaceH4/MATH-500", split="test")
    if args.limit:
        dataset = dataset.select(range(min(args.limit, len(dataset))))

    total = len(dataset)
    if total == 0:
        print("No problems to evaluate.")
        return
    print(f"Evaluating {total} problems...")

    # Load model
    print(f"Loading model from {args.model_path}...")
    llm = LLM(
        model=args.model_path,
        dtype="bfloat16",
        max_model_len=4096,
        gpu_memory_utilization=0.9,
    )
    sampling_params = SamplingParams(
        temperature=0,
        max_tokens=2048,
        stop=["<|im_end|>", "<|endoftext|>"],
    )

    # NOTE(review): the chat template is hard-coded; this assumes the model
    # was trained with these <|im_start|>/<|im_end|> markers -- confirm.
    prompts = [
        f"<|im_start|>user\n{item['problem']}<|im_end|>\n<|im_start|>assistant\n"
        for item in dataset
    ]

    # Generate
    print("Generating responses...")
    outputs = llm.generate(prompts, sampling_params)

    # Evaluate
    correct = 0
    results_by_level = defaultdict(lambda: {"correct": 0, "total": 0})
    results_by_subject = defaultdict(lambda: {"correct": 0, "total": 0})

    for done, (item, output) in enumerate(zip(dataset, outputs), start=1):
        predicted = extract_answer(output.outputs[0].text)
        is_correct = answers_match(predicted, item["answer"])
        correct += is_correct

        for bucket in (
            results_by_level[item["level"]],
            results_by_subject[item["subject"]],
        ):
            bucket["total"] += 1
            bucket["correct"] += is_correct

        if done % 20 == 0:
            print(
                f"Progress: {done}/{total}, "
                f"Accuracy so far: {correct/done*100:.1f}%"
            )

    # Print results
    accuracy = correct / total * 100
    print(f"\n{'='*60}")
    print(f"MATH-500 Results ({total} problems)")
    print(f"{'='*60}")
    print(f"Overall Accuracy: {accuracy:.1f}% ({correct}/{total})")

    _print_breakdown("By Level", results_by_level)
    _print_breakdown("By Subject", results_by_subject)


if __name__ == "__main__":
    main()