Files
SmolLM3-3B-GSM8K-SFT/training/evaluate_gsm8k.py
ModelHub XC 99f42b03e6 初始化项目,由ModelHub XC社区提供模型
Model: HuggingFaceTB/SmolLM3-3B-GSM8K-SFT
Source: Original Platform
2026-06-04 13:01:17 +08:00

94 lines
2.7 KiB
Python

#!/usr/bin/env python3
"""
Simple GSM8K evaluation script for SmolLM3-3B-GSM8K-SFT
Requirements:
pip install transformers datasets vllm
Usage:
python evaluate_gsm8k.py --model HuggingFaceTB/SmolLM3-3B-GSM8K-SFT --samples 100
"""
import argparse
import re
from datasets import load_dataset
from vllm import LLM, SamplingParams
def extract_answer(text: str) -> str:
    """Extract the numerical answer from model output.

    Three strategies are tried in order, and the first that succeeds wins:

    1. A GSM8K-style ``#### <number>`` marker.
    2. An "answer is <number>" phrase (case-insensitive).
    3. The last number that appears anywhere in the text.

    Args:
        text: Raw text generated by the model.

    Returns:
        The number as a string with thousands separators (commas) removed,
        or an empty string when the text contains no number.
    """
    # A number must START with a digit. The previous pattern ``[\d,]+``
    # also matched a lone comma, so ordinary prose punctuation after the
    # real answer could be picked up as the "last number" and reduced to "".
    number = r'-?\d[\d,]*(?:\.\d+)?'

    # Look for #### followed by number
    match = re.search(r'####\s*(' + number + r')', text)
    if match is None:
        # Look for "answer is X" pattern
        match = re.search(r'answer is[:\s]*(' + number + r')', text.lower())
    if match is not None:
        return match.group(1).replace(',', '')

    # Extract last number
    numbers = re.findall(number, text)
    if numbers:
        return numbers[-1].replace(',', '')
    return ""
def extract_gold_answer(text: str) -> str:
    """Pull the reference answer out of a GSM8K solution string.

    The value is whatever number follows the ``####`` marker, with commas
    stripped. An empty string is returned when no such marker-plus-number
    is present.
    """
    found = re.search(r'####\s*(-?[\d,]+(?:\.\d+)?)', text)
    if found is None:
        return ""
    raw_value = found.group(1)
    return raw_value.replace(',', '')
def _answers_match(pred: str, gold: str) -> bool:
    """Return True when ``pred`` and ``gold`` denote the same number.

    Identical strings always match. Otherwise both are parsed as exact
    decimals so that equivalent spellings such as "18", "18.0" and "18.00"
    compare equal; strings that are not valid numbers never match.
    """
    from decimal import Decimal, InvalidOperation

    if pred == gold:
        return True
    try:
        return Decimal(pred) == Decimal(gold)
    except InvalidOperation:
        return False


def main():
    """Evaluate a model on the GSM8K test split and print its accuracy.

    Command-line arguments:
        --model:   model name or path passed to vLLM.
        --samples: number of test examples to evaluate (must be >= 1;
                   capped at the size of the test split).

    Exits with an argparse usage error when ``--samples`` is not positive.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--model", type=str, default="HuggingFaceTB/SmolLM3-3B-GSM8K-SFT")
    parser.add_argument("--samples", type=int, default=100)
    args = parser.parse_args()
    # Reject a non-positive sample count up front: it would select an empty
    # dataset and only fail with ZeroDivisionError after the model is loaded.
    if args.samples < 1:
        parser.error("--samples must be a positive integer")

    print(f"Loading model: {args.model}")
    llm = LLM(
        model=args.model,
        trust_remote_code=True,
        gpu_memory_utilization=0.9,
    )
    tokenizer = llm.get_tokenizer()
    sampling_params = SamplingParams(
        max_tokens=512,
        temperature=0.0,
        stop=["<|im_end|>"],
    )

    print("Loading GSM8K test set...")
    dataset = load_dataset("openai/gsm8k", "main", split="test")
    dataset = dataset.select(range(min(args.samples, len(dataset))))

    # Prepare prompts: each question becomes a single-turn chat rendered
    # through the model's own chat template.
    prompts = [
        tokenizer.apply_chat_template(
            [{"role": "user", "content": example["question"]}],
            tokenize=False,
            add_generation_prompt=True,
        )
        for example in dataset
    ]

    # Generate
    print(f"Evaluating {len(prompts)} samples...")
    outputs = llm.generate(prompts, sampling_params)

    # Score: compare numerically so "18.0" is accepted for a gold "18".
    correct = sum(
        _answers_match(
            extract_answer(output.outputs[0].text),
            extract_gold_answer(example["answer"]),
        )
        for example, output in zip(dataset, outputs)
    )
    total = len(dataset)
    accuracy = correct / total * 100

    print("\nResults:")
    print(f" Correct: {correct}/{total}")
    print(f" Accuracy: {accuracy:.1f}%")


if __name__ == "__main__":
    main()