初始化项目,由ModelHub XC社区提供模型
Model: prithivMLmods/SmolLM2_135M_Grpo_Gsm8k Source: Original Platform
This commit is contained in:
261
smollm-grpo/SmolLM x Grpo M1.ipynb
Normal file
261
smollm-grpo/SmolLM x Grpo M1.ipynb
Normal file
@@ -0,0 +1,261 @@
|
||||
{
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 0,
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"provenance": [],
|
||||
"gpuType": "T4"
|
||||
},
|
||||
"kernelspec": {
|
||||
"name": "python3",
|
||||
"display_name": "Python 3"
|
||||
},
|
||||
"language_info": {
|
||||
"name": "python"
|
||||
},
|
||||
"accelerator": "GPU"
|
||||
},
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"id": "PFGREH3kAkep"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"!pip install -q git+https://github.com/huggingface/transformers.git git+https://github.com/huggingface/accelerate.git\n",
|
||||
"!pip install -q datasets huggingface-hub trl\n",
|
||||
"!pip install -q git+https://github.com/huggingface/peft.git"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"source": [
|
||||
"import re\n",
|
||||
"import torch\n",
|
||||
"from datasets import load_dataset, Dataset\n",
|
||||
"from transformers import AutoTokenizer, AutoModelForCausalLM\n",
|
||||
"from peft import LoraConfig\n",
|
||||
"from trl import GRPOConfig, GRPOTrainer"
|
||||
],
|
||||
"metadata": {
|
||||
"id": "PTbegv7kAwXd"
|
||||
},
|
||||
"execution_count": 3,
|
||||
"outputs": []
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"source": [
|
||||
"SYSTEM_PROMPT = \"\"\"\n",
|
||||
"Respond in the following format:\n",
|
||||
"<reasoning>\n",
|
||||
"...\n",
|
||||
"</reasoning>\n",
|
||||
"<answer>\n",
|
||||
"...\n",
|
||||
"</answer>\n",
|
||||
"\"\"\"\n",
|
||||
"\n",
|
||||
"XML_COT_FORMAT = \"\"\"\\\n",
|
||||
"<reasoning>\n",
|
||||
"{reasoning}\n",
|
||||
"</reasoning>\n",
|
||||
"<answer>\n",
|
||||
"{answer}\n",
|
||||
"</answer>\n",
|
||||
"\"\"\"\n",
|
||||
"\n",
|
||||
"def extract_xml_answer(text: str) -> str:\n",
|
||||
" answer = text.split(\"<answer>\")[-1]\n",
|
||||
" answer = answer.split(\"</answer>\")[0]\n",
|
||||
" return answer.strip()\n",
|
||||
"\n",
|
||||
"def extract_hash_answer(text: str) -> str | None:\n",
|
||||
" if \"####\" not in text:\n",
|
||||
" return None\n",
|
||||
" return text.split(\"####\")[1].strip()\n",
|
||||
"\n",
|
||||
"# uncomment middle messages for 1-shot prompting\n",
|
||||
"def get_gsm8k_questions(split = \"train\") -> Dataset:\n",
|
||||
" data = load_dataset('openai/gsm8k', 'main')[split] # type: ignore\n",
|
||||
" data = data.map(lambda x: { # type: ignore\n",
|
||||
" 'prompt': [\n",
|
||||
" {'role': 'system', 'content': SYSTEM_PROMPT},\n",
|
||||
" #{'role': 'user', 'content': 'What is the largest single-digit prime number?'},\n",
|
||||
" #{'role': 'assistant', 'content': XML_COT_FORMAT.format(\n",
|
||||
" # reasoning=\"9 is divisble by 3 and 8 is divisible by 2, but 7 is prime.\",\n",
|
||||
" # answer=\"7\"\n",
|
||||
" #)},\n",
|
||||
" {'role': 'user', 'content': x['question']}\n",
|
||||
" ],\n",
|
||||
" 'answer': extract_hash_answer(x['answer'])\n",
|
||||
" }) # type: ignore\n",
|
||||
" return data # type: ignore\n",
|
||||
"\n",
|
||||
"dataset = get_gsm8k_questions()"
|
||||
],
|
||||
"metadata": {
|
||||
"id": "L1h8s7QzAyE0"
|
||||
},
|
||||
"execution_count": null,
|
||||
"outputs": []
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"source": [
|
||||
"def correctness_reward_func(prompts, completions, answer, **kwargs) -> list[float]:\n",
|
||||
" responses = [completion[0]['content'] for completion in completions]\n",
|
||||
" q = prompts[0][-1]['content']\n",
|
||||
" extracted_responses = [extract_xml_answer(r) for r in responses]\n",
|
||||
" print('-'*20, f\"Question:\\n{q}\", f\"\\nAnswer:\\n{answer[0]}\", f\"\\nResponse:\\n{responses[0]}\", f\"\\nExtracted:\\n{extracted_responses[0]}\")\n",
|
||||
" return [2.0 if r == a else 0.0 for r, a in zip(extracted_responses, answer)]\n",
|
||||
"\n",
|
||||
"def int_reward_func(completions, **kwargs) -> list[float]:\n",
|
||||
" responses = [completion[0]['content'] for completion in completions]\n",
|
||||
" extracted_responses = [extract_xml_answer(r) for r in responses]\n",
|
||||
" return [0.5 if r.isdigit() else 0.0 for r in extracted_responses]\n",
|
||||
"\n",
|
||||
"def strict_format_reward_func(completions, **kwargs) -> list[float]:\n",
|
||||
" \"\"\"Reward function that checks if the completion has a specific format.\"\"\"\n",
|
||||
" pattern = r\"^<reasoning>\\n.*?\\n</reasoning>\\n<answer>\\n.*?\\n</answer>\\n$\"\n",
|
||||
" responses = [completion[0][\"content\"] for completion in completions]\n",
|
||||
" matches = [re.match(pattern, r) for r in responses]\n",
|
||||
" return [0.5 if match else 0.0 for match in matches]\n",
|
||||
"\n",
|
||||
"def soft_format_reward_func(completions, **kwargs) -> list[float]:\n",
|
||||
" \"\"\"Reward function that checks if the completion has a specific format.\"\"\"\n",
|
||||
" pattern = r\"<reasoning>.*?</reasoning>\\s*<answer>.*?</answer>\"\n",
|
||||
" responses = [completion[0][\"content\"] for completion in completions]\n",
|
||||
" matches = [re.match(pattern, r) for r in responses]\n",
|
||||
" return [0.5 if match else 0.0 for match in matches]\n",
|
||||
"\n",
|
||||
"def count_xml(text) -> float:\n",
|
||||
" count = 0.0\n",
|
||||
" if text.count(\"<reasoning>\\n\") == 1:\n",
|
||||
" count += 0.125\n",
|
||||
" if text.count(\"\\n</reasoning>\\n\") == 1:\n",
|
||||
" count += 0.125\n",
|
||||
" if text.count(\"\\n<answer>\\n\") == 1:\n",
|
||||
" count += 0.125\n",
|
||||
" count -= len(text.split(\"\\n</answer>\\n\")[-1])*0.001\n",
|
||||
" if text.count(\"\\n</answer>\") == 1:\n",
|
||||
" count += 0.125\n",
|
||||
" count -= (len(text.split(\"\\n</answer>\")[-1]) - 1)*0.001\n",
|
||||
" return count\n",
|
||||
"\n",
|
||||
"def xmlcount_reward_func(completions, **kwargs) -> list[float]:\n",
|
||||
" contents = [completion[0][\"content\"] for completion in completions]\n",
|
||||
" return [count_xml(c) for c in contents]"
|
||||
],
|
||||
"metadata": {
|
||||
"id": "R60IIK5FA3Mu"
|
||||
},
|
||||
"execution_count": 5,
|
||||
"outputs": []
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"source": [
|
||||
"model_name = \"HuggingFaceTB/SmolLM2-135M-Instruct\"\n",
|
||||
"#model_name = \"Qwen/Qwen2.5-1.5B-Instruct\"\n",
|
||||
"\n",
|
||||
"if \"SmolLM2\" in model_name:\n",
|
||||
" output_dir = \"outputs/SmolLM2-135M-GRPO\"\n",
|
||||
" run_name = \"SmolLM2-135M-GRPO\"\n",
|
||||
"else:\n",
|
||||
" output_dir=\"outputs/Qwen-1.5B-GRPO\"\n",
|
||||
" run_name=\"Qwen-1.5B-GRPO-gsm8k\"\n"
|
||||
],
|
||||
"metadata": {
|
||||
"id": "q-u5UG8pDLvD"
|
||||
},
|
||||
"execution_count": 6,
|
||||
"outputs": []
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"source": [
|
||||
"training_args = GRPOConfig(\n",
|
||||
" output_dir=output_dir,\n",
|
||||
" run_name=run_name,\n",
|
||||
" learning_rate=5e-6,\n",
|
||||
" adam_beta1=0.9,\n",
|
||||
" adam_beta2=0.99,\n",
|
||||
" weight_decay=0.1,\n",
|
||||
" warmup_ratio=0.1,\n",
|
||||
" lr_scheduler_type='cosine',\n",
|
||||
" logging_steps=1,\n",
|
||||
" bf16=True,\n",
|
||||
" per_device_train_batch_size=16, # Fixed: Must be divisible by num_generations\n",
|
||||
" gradient_accumulation_steps=4,\n",
|
||||
" num_generations=16, # Kept as is\n",
|
||||
" max_prompt_length=256,\n",
|
||||
" max_completion_length=786,\n",
|
||||
" num_train_epochs=1,\n",
|
||||
" save_steps=100,\n",
|
||||
" max_grad_norm=0.1,\n",
|
||||
" report_to=\"none\",\n",
|
||||
" log_on_each_node=False,\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"peft_config = LoraConfig(\n",
|
||||
" r=16,\n",
|
||||
" lora_alpha=64,\n",
|
||||
" target_modules=[\"q_proj\", \"k_proj\", \"v_proj\", \"o_proj\", \"up_proj\", \"down_proj\", \"gate_proj\"],\n",
|
||||
" task_type=\"CAUSAL_LM\",\n",
|
||||
" lora_dropout=0.05,\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"model = AutoModelForCausalLM.from_pretrained(\n",
|
||||
" model_name,\n",
|
||||
" torch_dtype=torch.bfloat16,\n",
|
||||
" device_map=None\n",
|
||||
").to(\"cuda\")\n",
|
||||
"\n",
|
||||
"tokenizer = AutoTokenizer.from_pretrained(model_name)\n",
|
||||
"tokenizer.pad_token = tokenizer.eos_token"
|
||||
],
|
||||
"metadata": {
|
||||
"id": "vcZOoqyHBIBV"
|
||||
},
|
||||
"execution_count": 10,
|
||||
"outputs": []
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"source": [
|
||||
"trainer = GRPOTrainer(\n",
|
||||
" model=model,\n",
|
||||
" processing_class=tokenizer,\n",
|
||||
" reward_funcs=[\n",
|
||||
" xmlcount_reward_func,\n",
|
||||
" soft_format_reward_func,\n",
|
||||
" strict_format_reward_func,\n",
|
||||
" int_reward_func,\n",
|
||||
" correctness_reward_func],\n",
|
||||
" args=training_args,\n",
|
||||
" train_dataset=dataset,\n",
|
||||
" #peft_config=peft_config\n",
|
||||
")"
|
||||
],
|
||||
"metadata": {
|
||||
"id": "WLxd8ZFmBzvE"
|
||||
},
|
||||
"execution_count": null,
|
||||
"outputs": []
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"source": [
|
||||
"trainer.train()"
|
||||
],
|
||||
"metadata": {
|
||||
"id": "RTbBtLmRIDpY"
|
||||
},
|
||||
"execution_count": null,
|
||||
"outputs": []
|
||||
}
|
||||
]
|
||||
}
|
||||
14621
smollm-grpo/SmolLM_x_Grpo.ipynb
Normal file
14621
smollm-grpo/SmolLM_x_Grpo.ipynb
Normal file
File diff suppressed because one or more lines are too long
Reference in New Issue
Block a user