#!/usr/bin/env python3
"""
Combine GSM8K training data with MetaMathQA GSM-related examples.
This creates a larger, more diverse training set.
"""

import json
import random
import re

# A plain number: optional sign, digits with optional thousands separators,
# and a fractional part only when digits follow the dot. Requiring digits
# after the dot keeps a sentence-ending period ("... is 42.") out of the
# captured answer.
_NUMBER = r'-?\d[\d,]*(?:\.\d+)?'

_GSM8K_ANSWER_RE = re.compile(r'####\s*(' + _NUMBER + r')')
_ANSWER_IS_RE = re.compile(
    r'the answer is[:\s]*\$?(' + _NUMBER + r')', re.IGNORECASE
)
_TRAILING_EQUALS_RE = re.compile(
    r'=\s*\$?(' + _NUMBER + r')\s*(?:dollars?|\.)?$'
)

_BOXED_OPEN = '\\boxed{'

# Instruction placed in front of every problem statement.
_USER_PROMPT = (
    "Solve the following math problem step by step. "
    "Show your reasoning and then provide the final answer.\n\n"
)

# NOTE(review): the separators between the sections of this template are
# kept character-for-character (single spaces). If line breaks or extra
# markup were intended between the reasoning and the final answer, confirm
# against the target chat format and restore them here.
_ASSISTANT_TEMPLATE = (
    " Let me solve this step by step. {reasoning} "
    "Therefore, the answer is {final_answer}. "
    "The answer is {final_answer}"
)


def _first_boxed_content(text):
    """Return the raw content of the first non-empty ``\\boxed{...}`` in *text*.

    Braces are matched by depth so nested groups such as
    ``\\boxed{\\frac{1}{2}}`` are returned whole. Returns None when there is
    no ``\\boxed{`` with a balanced, non-empty body.
    """
    start = text.find(_BOXED_OPEN)
    while start != -1:
        body_start = start + len(_BOXED_OPEN)
        depth = 1
        for pos in range(body_start, len(text)):
            char = text[pos]
            if char == '{':
                depth += 1
            elif char == '}':
                depth -= 1
                if depth == 0:
                    if pos > body_start:
                        return text[body_start:pos]
                    break  # empty \boxed{}: try the next occurrence
        start = text.find(_BOXED_OPEN, start + 1)
    return None


def _build_example(problem, reasoning, final_answer):
    """Assemble the two-message chat record shared by both data sources."""
    assistant_content = _ASSISTANT_TEMPLATE.format(
        reasoning=reasoning, final_answer=final_answer
    )
    return {
        "messages": [
            {"role": "user", "content": f"{_USER_PROMPT}{problem}"},
            {"role": "assistant", "content": assistant_content},
        ]
    }


def extract_answer_gsm8k(answer_text):
    """Extract the final numerical answer from GSM8K answer format.

    Looks for the ``#### <number>`` marker and returns the number as a string
    with thousands separators removed, or None when the marker is absent.
    """
    match = _GSM8K_ANSWER_RE.search(answer_text)
    if match:
        return match.group(1).replace(',', '')
    return None


def format_reasoning_gsm8k(answer_text):
    """Convert GSM8K step-by-step format to thinking format.

    Drops the trailing ``#### <number>`` marker and removes the inline
    ``<<...>>`` annotations, returning the remaining text.
    """
    reasoning = re.sub(r'####\s*-?[\d,]+\.?\d*\s*$', '', answer_text).strip()
    reasoning = re.sub(r'<<[^>]+>>', '', reasoning)
    return reasoning


def extract_answer_metamath(response):
    """Extract answer from MetaMathQA format (usually ends with boxed answer).

    Patterns are tried in order: the first non-empty ``\\boxed{...}`` (returned
    stripped, otherwise verbatim), then "the answer is X", then a trailing
    "= X". The two numeric patterns return the number without thousands
    separators. Returns None when nothing matches.
    """
    boxed = _first_boxed_content(response)
    if boxed is not None:
        return boxed.strip()

    match = _ANSWER_IS_RE.search(response)
    if match:
        return match.group(1).replace(',', '')

    # "= X" at the very end, optionally followed by "dollar(s)" or a period.
    match = _TRAILING_EQUALS_RE.search(response)
    if match:
        return match.group(1).replace(',', '')

    return None


def create_gsm8k_example(question, answer):
    """Create a training example from GSM8K format.

    Returns a ``{"messages": [...]}`` dict, or None when *answer* has no
    ``#### <number>`` marker.
    """
    final_answer = extract_answer_gsm8k(answer)
    if final_answer is None:
        return None

    return _build_example(
        question, format_reasoning_gsm8k(answer), final_answer
    )


def create_metamath_example(query, response):
    """Create a training example from MetaMathQA format.

    Returns a ``{"messages": [...]}`` dict, or None when no answer can be
    extracted or the remaining reasoning is shorter than 50 characters.
    """
    # Turn literal backslash-n sequences into real newlines.
    # NOTE(review): this also rewrites LaTeX commands that start with "\n"
    # (e.g. "\neq") - confirm the data actually needs this replacement.
    clean_response = response.replace('\\n', '\n').strip()

    final_answer = extract_answer_metamath(clean_response)
    if final_answer is None:
        return None

    # Keep only the text before the boxed answer / "The answer is" sentence.
    reasoning = re.sub(
        r'\\boxed\{[^}]+\}.*$', '', clean_response, flags=re.DOTALL
    ).strip()
    reasoning = re.sub(
        r'The answer is.*$', '', reasoning, flags=re.IGNORECASE | re.DOTALL
    ).strip()

    # Skip if reasoning is too short
    if len(reasoning) < 50:
        return None

    return _build_example(query, reasoning, final_answer)


def main():
    """Build, shuffle and save the combined training set as JSONL."""
    # Imported here so the pure formatting helpers above can be imported
    # (and tested) without the `datasets` package installed.
    from datasets import load_dataset

    # Load GSM8K training data
    print("Loading GSM8K dataset...")
    gsm8k = load_dataset("openai/gsm8k", "main", split="train")
    print(f"Loaded {len(gsm8k)} GSM8K examples")

    gsm8k_examples = []
    for example in gsm8k:
        formatted = create_gsm8k_example(example['question'], example['answer'])
        if formatted:
            gsm8k_examples.append(formatted)
    print(f"Added {len(gsm8k_examples)} GSM8K examples")

    # Load MetaMathQA - only GSM-related examples
    print("\nLoading MetaMathQA dataset...")
    metamath = load_dataset("meta-math/MetaMathQA", split="train")
    print(f"Loaded {len(metamath)} MetaMathQA examples")

    metamath_examples = []
    for example in metamath:
        if 'GSM' not in example['type']:  # GSM_Rephrased, GSM_SV, GSM_AnsAug, etc.
            continue
        formatted = create_metamath_example(example['query'], example['response'])
        if formatted:
            metamath_examples.append(formatted)
    print(f"Added {len(metamath_examples)} MetaMathQA GSM examples")

    training_data = gsm8k_examples + metamath_examples
    print(f"\nTotal training examples: {len(training_data)}")

    # Deterministic shuffle; a dedicated generator seeded with 42 yields the
    # same order as seeding the module-level one, without touching global
    # random state.
    random.Random(42).shuffle(training_data)

    # Save to JSONL
    output_file = "combined_math_train.jsonl"
    with open(output_file, 'w', encoding='utf-8') as f:
        f.writelines(json.dumps(item) + '\n' for item in training_data)
    print(f"Saved to {output_file}")

    # Show a sample. The record is taken from the GSM8K list built before the
    # shuffle: searching the first few shuffled records for one specific
    # problem almost never finds it and leaves the header with nothing under it.
    print("\n=== Sample GSM8K example ===")
    if gsm8k_examples:
        print(json.dumps(gsm8k_examples[0], indent=2)[:500])


if __name__ == "__main__":
    main()