94 lines
2.7 KiB
Python
94 lines
2.7 KiB
Python
#!/usr/bin/env python3
"""
Simple GSM8K evaluation script for SmolLM3-3B-GSM8K-SFT

Requirements:
    pip install transformers datasets vllm

Usage:
    python evaluate_gsm8k.py --model HuggingFaceTB/SmolLM3-3B-GSM8K-SFT --samples 100
"""
|
|
|
|
import argparse
import re
from decimal import Decimal, InvalidOperation

from datasets import load_dataset
from vllm import LLM, SamplingParams
|
|
|
|
def extract_answer(text: str) -> str:
    """Extract the final numerical answer from a model completion.

    Three strategies are tried in order; the first one that finds a number
    wins:

    1. A GSM8K-style ``#### <number>`` marker.
    2. An "answer is <number>" phrase (matched case-insensitively).
    3. The last number that appears anywhere in the text.

    Args:
        text: Raw text generated by the model.

    Returns:
        The number as a string with thousands separators (commas) removed,
        or an empty string when no number is found.
    """
    # The number must start with a digit. The looser ``[\d,]+`` also matches
    # bare commas in ordinary prose, so in the last-number fallback a comma
    # following the real answer ("5 apples, I think.") would be picked as the
    # "last number" and the function would return an empty string.
    number = r'-?\d[\d,]*(?:\.\d+)?'

    # Look for #### followed by number
    match = re.search(r'####\s*(' + number + r')', text)
    if match:
        return match.group(1).replace(',', '')

    # Look for "answer is X" pattern
    match = re.search(r'answer is[:\s]*(' + number + r')', text.lower())
    if match:
        return match.group(1).replace(',', '')

    # Fall back to the last number mentioned anywhere in the text
    numbers = re.findall(number, text)
    if numbers:
        return numbers[-1].replace(',', '')

    return ""
|
|
|
|
def extract_gold_answer(text: str) -> str:
    """Pull the reference answer out of a GSM8K solution string.

    The number is read from the first ``#### <number>`` marker found in
    *text*; commas used as thousands separators are dropped.

    Returns:
        The number as a string, or an empty string when the marker is
        absent.
    """
    marker = re.search(r'####\s*(-?[\d,]+(?:\.\d+)?)', text)
    if marker is None:
        return ""

    raw_number = marker.group(1)
    return raw_number.replace(',', '')
|
|
|
|
def _answers_match(pred: str, gold: str) -> bool:
    """Return True when *pred* and *gold* denote the same number.

    Plain string equality would reject numerically equal answers that are
    merely formatted differently (for example "18.00" versus "18"), so the
    values are compared as decimals when the strings differ.
    """
    if not pred or not gold:
        # A failed extraction is never correct, even if both sides are empty.
        return False
    if pred == gold:
        return True
    try:
        return Decimal(pred) == Decimal(gold)
    except InvalidOperation:
        # Not a parseable number (e.g. a stray "-"); treat as a mismatch.
        return False


def main():
    """Evaluate a model on the GSM8K test split and print its accuracy.

    Command-line arguments:
        --model:   Model name or path passed to vLLM.
        --samples: Number of test examples to evaluate, taken from the
                   start of the split; must be positive.

    Each question is sent as a single user message through the tokenizer's
    chat template, completions are generated greedily (temperature 0.0), and
    the extracted prediction is compared against the gold answer.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--model", type=str, default="HuggingFaceTB/SmolLM3-3B-GSM8K-SFT")
    parser.add_argument("--samples", type=int, default=100)
    args = parser.parse_args()

    # Reject a non-positive sample count before the model is loaded: it would
    # select zero examples and the accuracy computation would divide by zero.
    if args.samples <= 0:
        parser.error("--samples must be a positive integer")

    print(f"Loading model: {args.model}")
    llm = LLM(
        model=args.model,
        trust_remote_code=True,
        gpu_memory_utilization=0.9,
    )

    tokenizer = llm.get_tokenizer()
    sampling_params = SamplingParams(
        max_tokens=512,
        temperature=0.0,
        stop=["<|im_end|>"],
    )

    print("Loading GSM8K test set...")
    dataset = load_dataset("openai/gsm8k", "main", split="test")
    dataset = dataset.select(range(min(args.samples, len(dataset))))

    # Prepare prompts
    prompts = [
        tokenizer.apply_chat_template(
            [{"role": "user", "content": example["question"]}],
            tokenize=False,
            add_generation_prompt=True,
        )
        for example in dataset
    ]

    # Generate
    print(f"Evaluating {len(prompts)} samples...")
    outputs = llm.generate(prompts, sampling_params)

    # Score
    correct = sum(
        _answers_match(
            extract_answer(output.outputs[0].text),
            extract_gold_answer(example["answer"]),
        )
        for example, output in zip(dataset, outputs)
    )

    total = len(dataset)
    accuracy = correct / total * 100
    print("\nResults:")
    print(f" Correct: {correct}/{total}")
    print(f" Accuracy: {accuracy:.1f}%")
|
|
|
|
# Script entry point: run the evaluation only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    main()
|