初始化项目,由ModelHub XC社区提供模型
Model: HuggingFaceTB/qwen3-1.7b-gsm8k-sft Source: Original Platform
This commit is contained in:
154
scripts/prepare_combined_data.py
Normal file
154
scripts/prepare_combined_data.py
Normal file
@@ -0,0 +1,154 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Combine GSM8K training data with MetaMathQA GSM-related examples.
|
||||
This creates a larger, more diverse training set.
|
||||
"""
|
||||
|
||||
import json
|
||||
import re
|
||||
from datasets import load_dataset
|
||||
|
||||
def extract_answer_gsm8k(answer_text):
|
||||
"""Extract the final numerical answer from GSM8K answer format."""
|
||||
match = re.search(r'####\s*(-?[\d,]+\.?\d*)', answer_text)
|
||||
if match:
|
||||
return match.group(1).replace(',', '')
|
||||
return None
|
||||
|
||||
def format_reasoning_gsm8k(answer_text):
|
||||
"""Convert GSM8K step-by-step format to thinking format."""
|
||||
reasoning = re.sub(r'####\s*-?[\d,]+\.?\d*\s*$', '', answer_text).strip()
|
||||
reasoning = re.sub(r'<<[^>]+>>', '', reasoning)
|
||||
return reasoning
|
||||
|
||||
def extract_answer_metamath(response):
|
||||
"""Extract answer from MetaMathQA format (usually ends with boxed answer)."""
|
||||
# Try to find boxed answer
|
||||
match = re.search(r'\\boxed\{([^}]+)\}', response)
|
||||
if match:
|
||||
return match.group(1).strip()
|
||||
# Try to find "the answer is X" pattern
|
||||
match = re.search(r'the answer is[:\s]*\$?(-?[\d,]+\.?\d*)', response, re.IGNORECASE)
|
||||
if match:
|
||||
return match.group(1).replace(',', '')
|
||||
# Try to find "= X" at the end
|
||||
match = re.search(r'=\s*\$?(-?[\d,]+\.?\d*)\s*(?:dollars?|\.)?$', response)
|
||||
if match:
|
||||
return match.group(1).replace(',', '')
|
||||
return None
|
||||
|
||||
def create_gsm8k_example(question, answer):
|
||||
"""Create a training example from GSM8K format."""
|
||||
final_answer = extract_answer_gsm8k(answer)
|
||||
reasoning = format_reasoning_gsm8k(answer)
|
||||
|
||||
if final_answer is None:
|
||||
return None
|
||||
|
||||
assistant_content = f"""<think>
|
||||
Let me solve this step by step.
|
||||
|
||||
{reasoning}
|
||||
|
||||
Therefore, the answer is {final_answer}.
|
||||
</think>
|
||||
|
||||
The answer is {final_answer}"""
|
||||
|
||||
return {
|
||||
"messages": [
|
||||
{"role": "user", "content": f"Solve the following math problem step by step. Show your reasoning and then provide the final answer.\n\n{question}"},
|
||||
{"role": "assistant", "content": assistant_content}
|
||||
]
|
||||
}
|
||||
|
||||
def create_metamath_example(query, response):
|
||||
"""Create a training example from MetaMathQA format."""
|
||||
# Clean up the response - remove LaTeX formatting artifacts
|
||||
clean_response = response.replace('\\n', '\n').strip()
|
||||
|
||||
# Extract the answer
|
||||
final_answer = extract_answer_metamath(clean_response)
|
||||
if final_answer is None:
|
||||
return None
|
||||
|
||||
# Remove the boxed answer and everything after for reasoning
|
||||
reasoning = re.sub(r'\\boxed\{[^}]+\}.*$', '', clean_response, flags=re.DOTALL).strip()
|
||||
reasoning = re.sub(r'The answer is.*$', '', reasoning, flags=re.IGNORECASE | re.DOTALL).strip()
|
||||
|
||||
# Skip if reasoning is too short
|
||||
if len(reasoning) < 50:
|
||||
return None
|
||||
|
||||
assistant_content = f"""<think>
|
||||
Let me solve this step by step.
|
||||
|
||||
{reasoning}
|
||||
|
||||
Therefore, the answer is {final_answer}.
|
||||
</think>
|
||||
|
||||
The answer is {final_answer}"""
|
||||
|
||||
return {
|
||||
"messages": [
|
||||
{"role": "user", "content": f"Solve the following math problem step by step. Show your reasoning and then provide the final answer.\n\n{query}"},
|
||||
{"role": "assistant", "content": assistant_content}
|
||||
]
|
||||
}
|
||||
|
||||
def main():
|
||||
training_data = []
|
||||
|
||||
# Load GSM8K training data
|
||||
print("Loading GSM8K dataset...")
|
||||
gsm8k = load_dataset("openai/gsm8k", "main", split="train")
|
||||
print(f"Loaded {len(gsm8k)} GSM8K examples")
|
||||
|
||||
gsm8k_count = 0
|
||||
for example in gsm8k:
|
||||
formatted = create_gsm8k_example(example['question'], example['answer'])
|
||||
if formatted:
|
||||
training_data.append(formatted)
|
||||
gsm8k_count += 1
|
||||
print(f"Added {gsm8k_count} GSM8K examples")
|
||||
|
||||
# Load MetaMathQA - only GSM-related examples
|
||||
print("\nLoading MetaMathQA dataset...")
|
||||
metamath = load_dataset("meta-math/MetaMathQA", split="train")
|
||||
print(f"Loaded {len(metamath)} MetaMathQA examples")
|
||||
|
||||
# Filter for GSM-related examples only
|
||||
metamath_count = 0
|
||||
for example in metamath:
|
||||
if 'GSM' in example['type']: # GSM_Rephrased, GSM_SV, GSM_AnsAug, etc.
|
||||
formatted = create_metamath_example(example['query'], example['response'])
|
||||
if formatted:
|
||||
training_data.append(formatted)
|
||||
metamath_count += 1
|
||||
print(f"Added {metamath_count} MetaMathQA GSM examples")
|
||||
|
||||
print(f"\nTotal training examples: {len(training_data)}")
|
||||
|
||||
# Shuffle the data
|
||||
import random
|
||||
random.seed(42)
|
||||
random.shuffle(training_data)
|
||||
|
||||
# Save to JSONL
|
||||
output_file = "combined_math_train.jsonl"
|
||||
with open(output_file, 'w') as f:
|
||||
for item in training_data:
|
||||
f.write(json.dumps(item) + '\n')
|
||||
|
||||
print(f"Saved to {output_file}")
|
||||
|
||||
# Show samples
|
||||
print("\n=== Sample GSM8K example ===")
|
||||
for item in training_data[:10]:
|
||||
if "Natalia" in item['messages'][0]['content']:
|
||||
print(json.dumps(item, indent=2)[:500])
|
||||
break
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
Reference in New Issue
Block a user