#!/usr/bin/env python3
"""
Simple GSM8K evaluation script for SmolLM3-3B-GSM8K-SFT

Requirements:
    pip install transformers datasets vllm

Usage:
    python evaluate_gsm8k.py --model HuggingFaceTB/SmolLM3-3B-GSM8K-SFT --samples 100
"""

import argparse
import re
from decimal import Decimal, InvalidOperation

# A number must start with a digit (after an optional minus sign). The looser
# ``[\d,]+`` form also matches a lone "," and lets sentence punctuation win
# the "last number" fallback.
_NUMBER = r"-?\d[\d,]*(?:\.\d+)?"

# "#### 42" final-answer marker; an optional currency sign is tolerated.
_FINAL_ANSWER_RE = re.compile(r"####\s*\$?\s*(" + _NUMBER + r")")
# "answer is 42" / "answer is: $42"; applied to lower-cased text.
_ANSWER_IS_RE = re.compile(r"answer is[:\s]*\$?\s*(" + _NUMBER + r")")
_ANY_NUMBER_RE = re.compile(_NUMBER)


def _strip_commas(number: str) -> str:
    """Remove thousands separators from a matched number string."""
    return number.replace(",", "")


def extract_answer(text: str) -> str:
    """Extract the numerical answer from model output.

    Patterns are tried in order of decreasing reliability:

    1. a ``#### <number>`` marker,
    2. an ``answer is <number>`` phrase (case-insensitive),
    3. the last number appearing anywhere in the text.

    Args:
        text: Raw text generated by the model.

    Returns:
        The matched number as a string with commas removed, or ``""`` when
        the text contains no number.
    """
    match = _FINAL_ANSWER_RE.search(text)
    if match:
        return _strip_commas(match.group(1))

    match = _ANSWER_IS_RE.search(text.lower())
    if match:
        return _strip_commas(match.group(1))

    numbers = _ANY_NUMBER_RE.findall(text)
    if numbers:
        return _strip_commas(numbers[-1])

    return ""


def extract_gold_answer(text: str) -> str:
    """Extract the gold answer from a GSM8K ``answer`` field.

    Only the ``#### <number>`` marker is accepted; there is no fallback.

    Args:
        text: The reference solution text.

    Returns:
        The number after ``####`` with commas removed, or ``""`` when the
        marker is absent.
    """
    match = _FINAL_ANSWER_RE.search(text)
    if match:
        return _strip_commas(match.group(1))
    return ""


def answers_match(pred: str, gold: str) -> bool:
    """Return True when ``pred`` and ``gold`` denote the same number.

    The comparison is numeric, so ``"18.0"`` matches ``"18"``. An empty
    string on either side never matches: a prediction that could not be
    parsed must not be scored as correct.

    Args:
        pred: Answer string extracted from the model output.
        gold: Answer string extracted from the reference solution.
    """
    if not pred or not gold:
        return False
    try:
        return Decimal(pred) == Decimal(gold)
    except InvalidOperation:
        # Not parseable as a number; fall back to exact string equality.
        return pred == gold


def main() -> None:
    """Run the evaluation: load the model, generate answers, print accuracy."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--model", type=str, default="HuggingFaceTB/SmolLM3-3B-GSM8K-SFT")
    parser.add_argument("--samples", type=int, default=100)
    args = parser.parse_args()

    # Reject this before loading the model: an empty selection would
    # otherwise end in a ZeroDivisionError when computing accuracy.
    if args.samples <= 0:
        parser.error("--samples must be a positive integer")

    # Imported here rather than at module level so that --help and argument
    # errors do not require importing these packages, and so the parsing
    # helpers above can be imported without them installed.
    from datasets import load_dataset
    from vllm import LLM, SamplingParams

    print(f"Loading model: {args.model}")
    llm = LLM(
        model=args.model,
        trust_remote_code=True,
        gpu_memory_utilization=0.9,
    )
    tokenizer = llm.get_tokenizer()

    sampling_params = SamplingParams(
        max_tokens=512,
        temperature=0.0,
        stop=["<|im_end|>"],
    )

    print("Loading GSM8K test set...")
    dataset = load_dataset("openai/gsm8k", "main", split="test")
    dataset = dataset.select(range(min(args.samples, len(dataset))))

    # Prepare prompts: one single-turn chat per question.
    prompts = [
        tokenizer.apply_chat_template(
            [{"role": "user", "content": example["question"]}],
            tokenize=False,
            add_generation_prompt=True,
        )
        for example in dataset
    ]

    # Generate
    print(f"Evaluating {len(prompts)} samples...")
    outputs = llm.generate(prompts, sampling_params)

    # Score. Pairing by position relies on ``outputs`` being in prompt order.
    correct = sum(
        answers_match(
            extract_answer(output.outputs[0].text),
            extract_gold_answer(example["answer"]),
        )
        for example, output in zip(dataset, outputs)
    )

    total = len(dataset)
    accuracy = correct / total * 100 if total else 0.0
    print("\nResults:")
    print(f"  Correct: {correct}/{total}")
    print(f"  Accuracy: {accuracy:.1f}%")


if __name__ == "__main__":
    main()