#!/usr/bin/env python3
"""
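Combine GSM8K training data with the GSM-derived examples from MetaMathQA.
The merged examples are converted to chat-style messages, shuffled, and written
to combined_math_train.jsonl as a larger, more diverse training set.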
"""
import json
import re
from datasets import load_dataset
def extract_answer_gsm8k(answer_text):
    """Extract the final numerical answer from a GSM8K solution.

    GSM8K solutions carry their final answer after a ``####`` marker.

    Args:
        answer_text: The raw GSM8K ``answer`` field.

    Returns:
        The number following ``####`` as a string with thousands separators
        removed (``"#### 1,234"`` -> ``"1234"``), or None when no ``####``
        marker followed by a number is present.
    """
    # Require a leading digit, and accept a decimal point only when digits
    # follow it, so a sentence-ending period ("#### 72.") or a stray comma is
    # never captured as (part of) the answer.
    match = re.search(r'####\s*(-?\d[\d,]*(?:\.\d+)?)', answer_text)
    if match is None:
        return None
    return match.group(1).replace(',', '')
def format_reasoning_gsm8k(answer_text):
    """Return the GSM8K worked solution without its final-answer marker.

    The trailing ``#### <number>`` line is dropped, surrounding whitespace is
    trimmed, and then every ``<<...>>`` calculator annotation is deleted.
    """
    without_marker = re.sub(r'####\s*-?[\d,]+\.?\d*\s*$', '', answer_text)
    # Annotations are removed only after trimming, matching the required order.
    return re.sub(r'<<[^>]+>>', '', without_marker.strip())
def extract_answer_metamath(response):
    r"""Extract the final answer from a MetaMathQA solution.

    Patterns are tried in order:
      1. the contents of the last ``\boxed{...}`` (one level of nested braces
         is supported, e.g. ``\boxed{\frac{1}{2}}``), returned stripped;
      2. the number after the last "the answer is" (case-insensitive),
         optionally preceded by ":" and/or "$";
      3. a number after "=" at the very end of the text, optionally followed
         by "dollar"/"dollars" or a period.

    Args:
        response: MetaMathQA ``response`` text.

    Returns:
        The answer as a string (thousands separators removed for the numeric
        patterns 2 and 3), or None if no pattern matches.
    """
    # Optional sign, a leading digit, digits/commas, and a decimal part only
    # when digits follow the point, so a sentence-ending "." is not captured.
    number = r'(-?\d[\d,]*(?:\.\d+)?)'

    # Prefer the LAST occurrence: solutions may restate a given value earlier
    # (e.g. "we know the answer is 16") before stating the actual final answer.
    boxed = re.findall(r'\\boxed\{((?:[^{}]|\{[^{}]*\})+)\}', response)
    if boxed:
        return boxed[-1].strip()

    stated = re.findall(r'the answer is[:\s]*\$?' + number, response, re.IGNORECASE)
    if stated:
        return stated[-1].replace(',', '')

    match = re.search(r'=\s*\$?' + number + r'\s*(?:dollars?|\.)?$', response)
    if match:
        return match.group(1).replace(',', '')
    return None
def create_gsm8k_example(question, answer):
    """Build a chat-format training record from one GSM8K question/answer pair.

    Returns None when the answer text carries no extractable final number;
    otherwise a dict with a two-message ``"messages"`` list (user prompt,
    assistant solution).
    """
    final_answer = extract_answer_gsm8k(answer)
    if final_answer is None:
        return None

    # Assemble the assistant turn piece by piece; the resulting text is the
    # worked reasoning followed by two statements of the final answer.
    assistant_content = (
        "\nLet me solve this step by step.\n"
        f"{format_reasoning_gsm8k(answer)}\n"
        f"Therefore, the answer is {final_answer}.\n"
        f"The answer is {final_answer}"
    )
    user_message = {
        "role": "user",
        "content": f"Solve the following math problem step by step. Show your reasoning and then provide the final answer.\n\n{question}",
    }
    assistant_message = {"role": "assistant", "content": assistant_content}
    return {"messages": [user_message, assistant_message]}
def create_metamath_example(query, response):
    r"""Create a chat-format training example from a MetaMathQA record.

    The assistant turn uses the same template as create_gsm8k_example: the
    cleaned solution text followed by two statements of the final answer.

    Args:
        query: MetaMathQA ``query`` (the question).
        response: MetaMathQA ``response`` (the worked solution).

    Returns:
        A dict ``{"messages": [user, assistant]}``, or None when no final
        answer can be extracted or fewer than 50 characters of reasoning
        remain after the answer statements are removed.
    """
    # Turn literal "\n" escape sequences into real newlines.
    # NOTE(review): this also splits LaTeX commands that start with "\n"
    # (e.g. "\neq", "\nu"); confirm that is acceptable for the GSM subset.
    clean_response = response.replace('\\n', '\n').strip()

    final_answer = extract_answer_metamath(clean_response)
    if final_answer is None:
        return None

    # Remove the boxed answer and everything after it.
    reasoning = re.sub(r'\\boxed\{[^}]+\}.*$', '', clean_response, flags=re.DOTALL).strip()
    # Cut at the LAST "the answer is": cutting at the first occurrence would
    # truncate solutions that mention a (given) answer earlier in the text.
    mentions = [m.start() for m in re.finditer(r'the answer is', reasoning, re.IGNORECASE)]
    if mentions:
        reasoning = reasoning[:mentions[-1]].strip()
    # Match the GSM8K formatting path: drop a trailing "#### <number>" marker
    # and any "<<expr=value>>" calculator annotations, if present.
    reasoning = re.sub(r'####\s*-?[\d,]+\.?\d*\s*$', '', reasoning).strip()
    reasoning = re.sub(r'<<[^>]+>>', '', reasoning)

    # Skip if reasoning is too short
    if len(reasoning) < 50:
        return None

    assistant_content = f"""
Let me solve this step by step.
{reasoning}
Therefore, the answer is {final_answer}.
The answer is {final_answer}"""
    return {
        "messages": [
            {"role": "user", "content": f"Solve the following math problem step by step. Show your reasoning and then provide the final answer.\n\n{query}"},
            {"role": "assistant", "content": assistant_content},
        ]
    }
def main(output_file="combined_math_train.jsonl", seed=42):
    """Build the combined GSM8K + MetaMathQA (GSM subset) training file.

    Loads the GSM8K "main" train split and the MetaMathQA train split. It
    keeps only the MetaMathQA records whose ``type`` contains "GSM" and
    converts every record into a chat-style example. The result is shuffled
    with a fixed seed, written as one JSON object per line, and one GSM8K
    sample is printed.

    Args:
        output_file: Path of the JSONL file to write.
        seed: Seed for the deterministic shuffle.
    """
    import random

    training_data = []

    # Load GSM8K training data
    print("Loading GSM8K dataset...")
    gsm8k = load_dataset("openai/gsm8k", "main", split="train")
    print(f"Loaded {len(gsm8k)} GSM8K examples")

    gsm8k_count = 0
    # Keep the first converted GSM8K example for the final preview. After the
    # shuffle a specific GSM8K item almost never lands in the first few
    # positions, so searching the shuffled list for it usually finds nothing.
    sample_gsm8k = None
    for example in gsm8k:
        formatted = create_gsm8k_example(example['question'], example['answer'])
        if formatted:
            if sample_gsm8k is None:
                sample_gsm8k = formatted
            training_data.append(formatted)
            gsm8k_count += 1
    print(f"Added {gsm8k_count} GSM8K examples")

    # Load MetaMathQA - only GSM-related examples
    print("\nLoading MetaMathQA dataset...")
    metamath = load_dataset("meta-math/MetaMathQA", split="train")
    print(f"Loaded {len(metamath)} MetaMathQA examples")

    metamath_count = 0
    for example in metamath:
        # GSM_Rephrased, GSM_SV, GSM_AnsAug, etc.
        if 'GSM' not in example['type']:
            continue
        formatted = create_metamath_example(example['query'], example['response'])
        if formatted:
            training_data.append(formatted)
            metamath_count += 1
    print(f"Added {metamath_count} MetaMathQA GSM examples")
    print(f"\nTotal training examples: {len(training_data)}")

    # A dedicated Random instance gives the same order as seeding the global
    # generator with the same seed, without mutating global random state.
    random.Random(seed).shuffle(training_data)

    # Save to JSONL
    with open(output_file, 'w', encoding='utf-8') as f:
        f.writelines(json.dumps(item) + '\n' for item in training_data)
    print(f"Saved to {output_file}")

    # Show samples
    print("\n=== Sample GSM8K example ===")
    if sample_gsm8k is not None:
        print(json.dumps(sample_gsm8k, indent=2)[:500])
if __name__ == "__main__":
    # Run the dataset build only when executed as a script, not on import.
    main()