155 lines
4.9 KiB
Python
155 lines
4.9 KiB
Python
|
|
#!/usr/bin/env python3
|
||
|
|
"""
|
||
|
|
Combine GSM8K training data with MetaMathQA GSM-related examples.
|
||
|
|
This creates a larger, more diverse training set.
|
||
|
|
"""
|
||
|
|
|
||
|
|
import json
|
||
|
|
import re
|
||
|
|
from datasets import load_dataset
|
||
|
|
|
||
|
|
def extract_answer_gsm8k(answer_text):
    """Extract the final numerical answer from a GSM8K solution string.

    The answer is the number that follows the ``####`` marker. Thousands
    separators are removed.

    Args:
        answer_text: The raw GSM8K ``answer`` field.

    Returns:
        The answer as a string (e.g. ``"1234"`` or ``"-2.5"``), or ``None`` if
        no ``####`` marker followed by a number is present.
    """
    # Require a leading digit so a stray "," is never captured on its own. Only
    # accept a decimal point when digits follow it, so "#### 5." gives "5", not "5.".
    match = re.search(r'####\s*(-?\d[\d,]*(?:\.\d+)?)', answer_text)
    if match is None:
        return None
    return match.group(1).replace(',', '')
|
||
|
|
|
||
|
|
def format_reasoning_gsm8k(answer_text):
    """Turn a GSM8K solution into plain step-by-step reasoning text.

    Removes the trailing ``#### <number>`` answer and strips the surrounding
    whitespace. It then deletes every ``<<...>>`` span (GSM8K's inline
    calculator annotations) from the remaining text.
    """
    without_final_line = re.sub(r'####\s*-?[\d,]+\.?\d*\s*$', '', answer_text)
    # Whitespace is stripped before the annotations are removed, not after.
    return re.sub(r'<<[^>]+>>', '', without_final_line.strip())
|
||
|
|
|
||
|
|
def extract_answer_metamath(response):
    r"""Extract the final answer from a MetaMathQA response.

    Strategies are tried in order and the first hit wins:

    1. The brace-balanced, stripped contents of the first non-empty
       ``\boxed{...}``, so ``\boxed{\frac{1}{2}}`` yields ``\frac{1}{2}``
       rather than ``\frac{1``.
    2. The number after the *last* "the answer is" (case-insensitive),
       allowing a ``:``/whitespace/``$`` prefix. Taking the last mention keeps
       an earlier passing mention in the reasoning from being mistaken for
       the final answer.
    3. A number after ``=`` at the very end of the text, optionally followed
       by "dollar"/"dollars" or a period.

    For strategies 2 and 3, thousands separators are removed and a trailing
    sentence period is not treated as part of the number.

    Args:
        response: The MetaMathQA response text.

    Returns:
        The answer as a string, or ``None`` if no strategy matches.
    """
    boxed = _first_boxed_content(response)
    if boxed is not None:
        return boxed

    # Require a leading digit (so a lone "," is never captured) and only accept
    # a decimal point when digits follow it (so "= 5." gives "5", not "5.").
    number = r'(-?\d[\d,]*(?:\.\d+)?)'

    stated = re.findall(r'the answer is[:\s]*\$?' + number, response, re.IGNORECASE)
    if stated:
        return stated[-1].replace(',', '')

    match = re.search(r'=\s*\$?' + number + r'\s*(?:dollars?|\.)?$', response)
    if match:
        return match.group(1).replace(',', '')

    return None


def _first_boxed_content(text):
    r"""Return the stripped contents of the first non-empty ``\boxed{...}`` in *text*.

    Braces are matched by nesting depth, so inner groups such as
    ``\frac{1}{2}`` stay intact. Returns ``None`` when no ``\boxed{`` has
    balanced, non-empty contents.
    """
    for marker in re.finditer(r'\\boxed\{', text):
        depth = 1
        start = marker.end()
        for pos in range(start, len(text)):
            char = text[pos]
            if char == '{':
                depth += 1
            elif char == '}':
                depth -= 1
                if depth == 0:
                    content = text[start:pos].strip()
                    if content:
                        return content
                    break  # empty box: try the next \boxed{ occurrence
    return None
|
||
|
|
|
||
|
|
def create_gsm8k_example(question, answer):
    """Build one chat-format training record from a GSM8K question/answer pair.

    Args:
        question: The GSM8K problem statement.
        answer: The raw GSM8K solution text, ending in a ``#### <number>`` line.

    Returns:
        ``{"messages": [user_turn, assistant_turn]}``. The user turn asks for a
        step-by-step solution to *question*. The assistant turn places the
        cleaned reasoning inside ``<think>`` tags and then restates the final
        answer. Returns ``None`` when no final answer can be extracted.
    """
    final_answer = extract_answer_gsm8k(answer)
    if final_answer is None:
        # Without a final answer the record cannot be supervised; skip it.
        return None
    reasoning = format_reasoning_gsm8k(answer)

    prompt = (
        "Solve the following math problem step by step. "
        "Show your reasoning and then provide the final answer.\n\n"
        f"{question}"
    )
    assistant_content = (
        "<think>\n"
        "Let me solve this step by step.\n"
        "\n"
        f"{reasoning}\n"
        "\n"
        f"Therefore, the answer is {final_answer}.\n"
        "</think>\n"
        "\n"
        f"The answer is {final_answer}"
    )
    user_turn = {"role": "user", "content": prompt}
    assistant_turn = {"role": "assistant", "content": assistant_content}
    return {"messages": [user_turn, assistant_turn]}
|
||
|
|
|
||
|
|
def create_metamath_example(query, response):
    r"""Build one chat-format training record from a MetaMathQA query/response pair.

    The response is cleaned and its final answer is extracted with
    ``extract_answer_metamath``. The boxed answer and the closing
    "The answer is ..." statement are then removed, so only the worked
    reasoning goes inside the ``<think>`` block.

    Args:
        query: The MetaMathQA ``query`` text (the problem statement).
        response: The MetaMathQA ``response`` text (the worked solution).

    Returns:
        ``{"messages": [user_turn, assistant_turn]}`` in the same shape built
        for GSM8K records. Returns ``None`` if no final answer can be extracted
        or if the remaining reasoning is shorter than 50 characters.
    """
    # Turn literal "\n" escape sequences into real newlines. Leave "\n"
    # followed by a lowercase letter alone: that is how LaTeX commands such as
    # \neq, \nu, \not and \newline are spelled, and a blind replace would turn
    # "\neq" into a line break followed by "eq".
    clean_response = re.sub(r'\\n(?![a-z])', '\n', response).strip()

    final_answer = extract_answer_metamath(clean_response)
    if final_answer is None:
        return None

    # Remove the boxed answer and everything after it.
    reasoning = re.sub(r'\\boxed\{[^}]+\}.*$', '', clean_response, flags=re.DOTALL).strip()
    # Remove the closing "The answer is ..." statement. Cut at the LAST mention
    # so an earlier "the answer is" inside the reasoning does not truncate the
    # worked steps.
    answer_mentions = list(re.finditer(r'The answer is', reasoning, flags=re.IGNORECASE))
    if answer_mentions:
        reasoning = reasoning[:answer_mentions[-1].start()]
    reasoning = reasoning.strip()

    # Very short leftovers are not useful as step-by-step supervision.
    if len(reasoning) < 50:
        return None

    prompt = (
        "Solve the following math problem step by step. "
        "Show your reasoning and then provide the final answer.\n\n"
        f"{query}"
    )
    assistant_content = (
        "<think>\n"
        "Let me solve this step by step.\n"
        "\n"
        f"{reasoning}\n"
        "\n"
        f"Therefore, the answer is {final_answer}.\n"
        "</think>\n"
        "\n"
        f"The answer is {final_answer}"
    )
    return {
        "messages": [
            {"role": "user", "content": prompt},
            {"role": "assistant", "content": assistant_content},
        ]
    }
|
||
|
|
|
||
|
|
def _collect_gsm8k_examples():
    """Load the GSM8K train split and return the training records built from it.

    Rows for which ``create_gsm8k_example`` returns a falsy value are skipped.
    Progress is printed to stdout.
    """
    print("Loading GSM8K dataset...")
    gsm8k = load_dataset("openai/gsm8k", "main", split="train")
    print(f"Loaded {len(gsm8k)} GSM8K examples")

    records = []
    for example in gsm8k:
        formatted = create_gsm8k_example(example['question'], example['answer'])
        if formatted:
            records.append(formatted)
    print(f"Added {len(records)} GSM8K examples")
    return records


def _collect_metamath_gsm_examples():
    """Load MetaMathQA and return records built from its GSM-derived rows only.

    A row is kept when its ``type`` field contains "GSM" (e.g. GSM_Rephrased,
    GSM_SV, GSM_AnsAug) and ``create_metamath_example`` returns a truthy record.
    Progress is printed to stdout.
    """
    print("\nLoading MetaMathQA dataset...")
    metamath = load_dataset("meta-math/MetaMathQA", split="train")
    print(f"Loaded {len(metamath)} MetaMathQA examples")

    records = []
    for example in metamath:
        if 'GSM' not in example['type']:
            continue
        formatted = create_metamath_example(example['query'], example['response'])
        if formatted:
            records.append(formatted)
    print(f"Added {len(records)} MetaMathQA GSM examples")
    return records


def main(output_file="combined_math_train.jsonl", seed=42):
    """Combine GSM8K and MetaMathQA GSM examples, shuffle them, and write JSONL.

    Args:
        output_file: Destination path; one JSON record is written per line.
        seed: Seed for the shuffle, so runs are reproducible.
    """
    import random

    gsm8k_records = _collect_gsm8k_examples()
    # Keep one GSM8K record for the preview printed at the end. Once the
    # sources are mixed and shuffled, the head of the list is unlikely to hold
    # a GSM8K item, so searching it afterwards usually finds nothing.
    gsm8k_sample = gsm8k_records[0] if gsm8k_records else None

    training_data = gsm8k_records + _collect_metamath_gsm_examples()
    print(f"\nTotal training examples: {len(training_data)}")

    # A private seeded RNG yields the same permutation as random.seed(seed)
    # followed by random.shuffle, without mutating the global random state.
    random.Random(seed).shuffle(training_data)

    with open(output_file, 'w', encoding='utf-8') as f:
        for item in training_data:
            f.write(json.dumps(item) + '\n')
    print(f"Saved to {output_file}")

    print("\n=== Sample GSM8K example ===")
    if gsm8k_sample is not None:
        print(json.dumps(gsm8k_sample, indent=2)[:500])
|
||
|
|
|
||
|
|
if __name__ == "__main__":
    # Build the dataset only when run as a script, never on import.
    main()
|