160 lines
5.1 KiB
Python
160 lines
5.1 KiB
Python
|
|
#!/usr/bin/env python3
|
||
|
|
"""Evaluate model on MATH-500 dataset (harder math problems)."""
|
||
|
|
|
||
|
|
import argparse
import json
import re
from collections import defaultdict

from datasets import load_dataset
from vllm import LLM, SamplingParams
|
||
|
|
|
||
|
|
def extract_answer(response: str) -> str:
    """Extract the final answer from a model response.

    Strategies are tried in order and the first one that yields a
    non-empty string wins:

    1. The contents of the *last* ``\\boxed{...}`` in the response, with
       nested braces handled (e.g. ``\\boxed{\\frac{1}{2}}``).
    2. The text following the first "The answer is" / "The final answer is",
       up to the end of the sentence or line.
    3. The text after the last ``=`` at the very end of the response.
    4. The last non-blank line.

    Args:
        response: Raw text generated by the model.

    Returns:
        The extracted answer with surrounding whitespace removed, or an
        empty string when the response contains no non-blank text.
    """
    boxed = _last_boxed_content(response)
    if boxed:
        return boxed

    # Stop at a newline or at a period that ends a sentence; a period that is
    # followed by a digit (as in "3.14") is kept as part of the answer.
    answer_match = re.search(
        r'[Tt]he (?:final )?answer is[:\s]*(.+?)(?=\.(?:\s|$)|\n|$)',
        response,
    )
    if answer_match:
        stated = answer_match.group(1).strip()
        if stated:
            return stated

    # Trailing "= X" at the very end of the response.
    equals_match = re.search(r'=\s*([^\n=]+?)\s*$', response)
    if equals_match:
        return equals_match.group(1).strip()

    # Fallback: the last non-blank line.
    lines = [line.strip() for line in response.strip().split('\n') if line.strip()]
    return lines[-1] if lines else ""


def _last_boxed_content(text: str) -> str:
    """Return the brace-balanced contents of the last non-empty ``\\boxed{...}``.

    Occurrences are scanned from the end of ``text`` backwards; an occurrence
    whose braces never close, or whose contents are blank, is skipped.
    Returns an empty string when no usable occurrence exists.
    """
    marker = "\\boxed{"
    start = text.rfind(marker)
    while start != -1:
        begin = start + len(marker)
        depth = 1
        for pos in range(begin, len(text)):
            char = text[pos]
            if char == "{":
                depth += 1
            elif char == "}":
                depth -= 1
                if depth == 0:
                    content = text[begin:pos].strip()
                    if content:
                        return content
                    break
        # Try the previous occurrence, if any.
        start = text.rfind(marker, 0, start)
    return ""
|
||
|
|
|
||
|
|
def normalize_answer(answer: str) -> str:
    """Normalize an answer string for comparison.

    The following formatting is removed, then the result is stripped and
    lower-cased:

    * ``\\text{...}`` and ``\\mathrm{...}`` wrappers (their contents are kept),
    * the ``\\left`` / ``\\right`` delimiter-sizing commands,
    * dollar signs, whether bare (``$``) or escaped (``\\$``).

    Args:
        answer: Raw answer text (model output or reference answer).

    Returns:
        The normalized, lower-cased answer.
    """
    answer = answer.strip()
    answer = re.sub(r'\\text\{([^}]*)\}', r'\1', answer)
    answer = re.sub(r'\\mathrm\{([^}]*)\}', r'\1', answer)
    # Only drop \left / \right when they are complete commands; without the
    # lookahead, \leftarrow and \rightarrow would be mangled into "arrow".
    answer = re.sub(r'\\(?:left|right)(?![A-Za-z])', '', answer)
    # Remove the escaping backslash together with the dollar sign so that
    # "\$5" becomes "5" rather than "\5".
    answer = re.sub(r'\\?\$', '', answer)
    return answer.strip().lower()
|
||
|
|
|
||
|
|
def answers_match(predicted: str, expected: str) -> bool:
    """Check whether a predicted answer matches the expected one.

    Both strings are first passed through ``normalize_answer``. They are
    considered a match when, in order:

    1. the normalized strings are identical; or
    2. both parse as plain decimal numbers and differ by less than ``1e-6``
       (two numbers that differ are a definite mismatch); or
    3. the shorter string appears inside the longer one as a standalone
       token, e.g. ``"5"`` inside ``"x = 5."``.

    An empty prediction or expectation never matches. The containment rule is
    a deliberately lenient heuristic and can still accept wrong answers in
    free text such as ``"3 or 4"``.

    Args:
        predicted: Answer extracted from the model response.
        expected: Reference answer.

    Returns:
        ``True`` if the answers are judged equivalent, ``False`` otherwise.
    """
    pred_norm = normalize_answer(predicted)
    exp_norm = normalize_answer(expected)

    # Without this guard an empty prediction would be a substring of every
    # expected answer and count as correct.
    if not pred_norm or not exp_norm:
        return False

    if pred_norm == exp_norm:
        return True

    pred_num = _parse_number(pred_norm)
    exp_num = _parse_number(exp_norm)
    if pred_num is not None and exp_num is not None:
        # Two plain numbers: the numeric comparison is authoritative, so that
        # "1" is not accepted for "10" by the containment rule below.
        return abs(pred_num - exp_num) < 1e-6

    return (
        _contains_as_token(pred_norm, exp_norm)
        or _contains_as_token(exp_norm, pred_norm)
    )


def _parse_number(text: str) -> "float | None":
    """Parse ``text`` as a plain decimal number, or return ``None``.

    A leading backslash, a trailing ``%`` / ``\\%`` and thousands separators
    are tolerated. Anything else (fractions, expressions, tuples) is rejected
    rather than having its digits concatenated.
    """
    cleaned = text.strip().lstrip('\\').rstrip('\\%').strip()
    # Drop commas only where they act as thousands separators ("1,000").
    cleaned = re.sub(r'(?<=\d),(?=\d{3}(?!\d))', '', cleaned)
    if re.fullmatch(r'[-+]?(?:\d+\.?\d*|\.\d+)', cleaned):
        return float(cleaned)
    return None


def _contains_as_token(haystack: str, needle: str) -> bool:
    """Return True if ``needle`` occurs in the longer ``haystack`` as a token.

    The occurrence must start at the beginning of the string or right after
    whitespace, ``=`` or ``:``, and must end at the end of the string, before
    whitespace, or before sentence punctuation.
    """
    if len(needle) >= len(haystack):
        return False
    pattern = (
        r'(?:^|(?<=[\s=:]))'
        + re.escape(needle)
        + r'(?=$|\s|[.;](?:\s|$))'
    )
    return re.search(pattern, haystack) is not None
|
||
|
|
|
||
|
|
def main():
    """Evaluate a model on MATH-500 and print accuracy summaries.

    Command-line options:
        --model-path: Model path or name passed to vLLM (default "final_model").
        --limit: Evaluate only the first N problems (default 100); 0 evaluates
            the whole test split.

    Steps: load the "HuggingFaceH4/MATH-500" test split, build one chat-style
    prompt per problem, generate with temperature 0, extract and compare the
    answers, then print overall accuracy plus breakdowns by level and subject.

    Raises:
        RuntimeError: If the number of generated outputs differs from the
            number of prompts.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--model-path", type=str, default="final_model",
                        help="model path or name to load with vLLM")
    parser.add_argument("--limit", type=int, default=100,
                        help="evaluate only the first N problems (0 = all)")
    args = parser.parse_args()

    # A negative limit would select nothing and end in a division by zero
    # after the model had already been loaded; reject it up front.
    if args.limit < 0:
        parser.error("--limit must be non-negative (0 evaluates the full dataset)")

    print("Loading MATH-500 dataset...")
    dataset = load_dataset("HuggingFaceH4/MATH-500", split="test")

    if args.limit:
        dataset = dataset.select(range(min(args.limit, len(dataset))))

    total = len(dataset)
    if total == 0:
        print("No problems to evaluate.")
        return

    print(f"Evaluating {total} problems...")

    # Load model
    print(f"Loading model from {args.model_path}...")
    llm = LLM(
        model=args.model_path,
        dtype="bfloat16",
        max_model_len=4096,
        gpu_memory_utilization=0.9,
    )

    sampling_params = SamplingParams(
        temperature=0,
        max_tokens=2048,
        stop=["<|im_end|>", "<|endoftext|>"],
    )

    # Prepare prompts
    prompts = [_build_prompt(item["problem"]) for item in dataset]

    # Generate
    print("Generating responses...")
    outputs = llm.generate(prompts, sampling_params)
    if len(outputs) != total:
        # zip() below would silently truncate and skew every statistic.
        raise RuntimeError(
            f"Expected {total} generations but received {len(outputs)}"
        )

    # Evaluate
    correct = 0
    results_by_level = defaultdict(lambda: {"correct": 0, "total": 0})
    results_by_subject = defaultdict(lambda: {"correct": 0, "total": 0})

    for done, (item, output) in enumerate(zip(dataset, outputs), start=1):
        response = output.outputs[0].text
        predicted = extract_answer(response)
        is_correct = answers_match(predicted, item["answer"])
        if is_correct:
            correct += 1

        # Track by level and by subject
        for tally in (results_by_level[item["level"]],
                      results_by_subject[item["subject"]]):
            tally["total"] += 1
            if is_correct:
                tally["correct"] += 1

        if done % 20 == 0:
            print(f"Progress: {done}/{total}, Accuracy so far: {correct/done*100:.1f}%")

    # Print results
    accuracy = correct / total * 100
    print(f"\n{'='*60}")
    print(f"MATH-500 Results ({total} problems)")
    print(f"{'='*60}")
    print(f"Overall Accuracy: {accuracy:.1f}% ({correct}/{total})")

    _print_breakdown("By Level", results_by_level)
    _print_breakdown("By Subject", results_by_subject)


def _build_prompt(problem: str) -> str:
    """Wrap a problem statement in the chat markup used for generation."""
    return f"<|im_start|>user\n{problem}<|im_end|>\n<|im_start|>assistant\n"


def _print_breakdown(title: str, tallies) -> None:
    """Print per-key accuracy for a mapping of key -> correct/total counts."""
    print(f"\n{title}:")
    for key in sorted(tallies):
        stats = tallies[key]
        acc = stats["correct"] / stats["total"] * 100
        print(f" {key}: {acc:.1f}% ({stats['correct']}/{stats['total']})")
|
||
|
|
|
||
|
|
# Run the evaluation only when this file is executed as a script, not when
# it is imported as a module.
if __name__ == "__main__":
    main()
|