初始化项目,由ModelHub XC社区提供模型
Model: HuggingFaceTB/SmolLM3-3B-GSM8K-SFT Source: Original Platform
This commit is contained in:
93
training/evaluate_gsm8k.py
Normal file
93
training/evaluate_gsm8k.py
Normal file
@@ -0,0 +1,93 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Simple GSM8K evaluation script for SmolLM3-3B-GSM8K-SFT

Requirements:
    pip install transformers datasets vllm

Usage:
    python evaluate_gsm8k.py --model HuggingFaceTB/SmolLM3-3B-GSM8K-SFT --samples 100
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import re
|
||||
from datasets import load_dataset
|
||||
from vllm import LLM, SamplingParams
|
||||
|
||||
def extract_answer(text: str) -> str:
    """Extract the final numerical answer from model output.

    Three strategies are tried in order and the first one that finds a
    number wins:

    1. a number directly after a ``####`` marker,
    2. a number after the phrase "answer is" (matched on the lower-cased
       text, optionally followed by a colon and/or whitespace),
    3. the last number that appears anywhere in the text.

    Args:
        text: Raw text generated by the model.

    Returns:
        The number as a string with comma separators removed, or an empty
        string when the text contains no number.
    """
    # A number must start with a digit (after an optional minus sign).
    # Without that requirement a bare comma would count as a "number":
    # "So 5 apples, then." would yield "," as the last number and the
    # function would return "" instead of "5".
    number = r'-?\d[\d,]*(?:\.\d+)?'

    # Look for #### followed by a number
    match = re.search(r'####\s*(' + number + r')', text)
    if match is None:
        # Look for the "answer is X" pattern
        match = re.search(r'answer is[:\s]*(' + number + r')', text.lower())
    if match is not None:
        return match.group(1).replace(',', '')

    # Fall back to the last number in the text
    numbers = re.findall(number, text)
    return numbers[-1].replace(',', '') if numbers else ""
|
||||
|
||||
def extract_gold_answer(text: str) -> str:
    """Return the reference answer that follows ``####`` in a GSM8K solution.

    Comma separators are removed from the captured value. An empty string
    is returned when no ``####`` marker followed by a value is found.
    """
    found = re.search(r'####\s*(-?[\d,]+(?:\.\d+)?)', text)
    if found is None:
        return ""
    captured = found.group(1)
    return captured.replace(',', '')
|
||||
|
||||
def main():
    """Evaluate a model on the GSM8K test split and print its accuracy.

    Command-line options:
        --model: model name or path passed to vLLM
            (default: HuggingFaceTB/SmolLM3-3B-GSM8K-SFT).
        --samples: number of examples to evaluate, taken from the start of
            the test split and capped at its size (default: 100). Must be
            at least 1.

    Each question is sent as a single user message rendered through the
    tokenizer's chat template. A prediction counts as correct only when the
    extracted answer string equals the extracted gold answer string.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--model", type=str, default="HuggingFaceTB/SmolLM3-3B-GSM8K-SFT")
    parser.add_argument("--samples", type=int, default=100)
    args = parser.parse_args()

    # Reject a non-positive sample count before the model is loaded:
    # it would select an empty dataset and make the accuracy computation
    # below divide by zero.
    if args.samples < 1:
        parser.error("--samples must be a positive integer")

    print(f"Loading model: {args.model}")
    llm = LLM(
        model=args.model,
        trust_remote_code=True,
        gpu_memory_utilization=0.9,
    )

    tokenizer = llm.get_tokenizer()
    # temperature=0.0 requests greedy sampling in vLLM; generation stops at
    # the chat end-of-turn marker or after 512 new tokens.
    sampling_params = SamplingParams(
        max_tokens=512,
        temperature=0.0,
        stop=["<|im_end|>"],
    )

    print("Loading GSM8K test set...")
    dataset = load_dataset("openai/gsm8k", "main", split="test")
    dataset = dataset.select(range(min(args.samples, len(dataset))))

    # Prepare prompts: one chat-formatted prompt per question
    prompts = [
        tokenizer.apply_chat_template(
            [{"role": "user", "content": example["question"]}],
            tokenize=False,
            add_generation_prompt=True,
        )
        for example in dataset
    ]

    # Generate
    print(f"Evaluating {len(prompts)} samples...")
    outputs = llm.generate(prompts, sampling_params)

    # Score. Pairing by position assumes llm.generate returns one output per
    # prompt, in prompt order.
    # NOTE: comparison is an exact string match, so "18.0" and "18" are
    # counted as different answers.
    correct = 0
    for example, output in zip(dataset, outputs):
        pred = extract_answer(output.outputs[0].text)
        gold = extract_gold_answer(example["answer"])
        if pred == gold:
            correct += 1

    total = len(dataset)
    accuracy = correct / total * 100
    print("\nResults:")
    print(f" Correct: {correct}/{total}")
    print(f" Accuracy: {accuracy:.1f}%")
|
||||
|
||||
# Run the evaluation only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
|
||||
Reference in New Issue
Block a user