{
"cells": [
{
"cell_type": "markdown",
"id": "dd248598",
"metadata": {},
"source": [
"# Qwen3-4B-MegaScience fine-tuning on GSM8K\n",
"\n",
"Qwen3-4B-MegaScience is a scientific small language model (SLM). We fine-tune it on GSM8K with LoRA and rewrite each reference answer into a `<think>` reasoning block followed by an `<answer>` block."
]
},
{
"cell_type": "markdown",
"id": "45250c6a",
"metadata": {},
"source": [
"## Setup\n",
"\n",
"The stack below covers training, Hub upload, plotting, and GPU telemetry."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "72cedce5",
"metadata": {},
"outputs": [],
"source": [
"!pip install -U \"transformers>=4.52.0\" accelerate peft trl datasets huggingface_hub pandas matplotlib nvidia-ml-py tensorboard sentencepiece protobuf safetensors"
]
},
{
"cell_type": "markdown",
"id": "17dc2e46",
"metadata": {},
"source": [
"## Imports and configuration\n",
"\n",
"The runs below used an RTX 5090 (CUDA 13.0), a Ryzen 9 9950X, and 62 GB of RAM."
]
},
{
"cell_type": "code",
"execution_count": 15,
"id": "cc66c06f",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"GPU: NVIDIA GeForce RTX 5090\n",
"VRAM: 31.36 GB\n",
"bf16: True\n",
"fp16: False\n"
]
}
],
"source": [
"import os\n",
"import re\n",
"import json\n",
"import time\n",
"from pathlib import Path\n",
"from datetime import date\n",
"\n",
"import numpy as np\n",
"import pandas as pd\n",
"import matplotlib.pyplot as plt\n",
"import torch\n",
"import pynvml\n",
"from tqdm import tqdm\n",
"\n",
"from datasets import load_dataset\n",
"from huggingface_hub import HfApi, login, upload_folder\n",
"from peft import LoraConfig\n",
"from transformers import AutoModelForCausalLM, AutoTokenizer, TrainerCallback\n",
"from trl import SFTConfig, SFTTrainer\n",
"\n",
"os.environ[\"TOKENIZERS_PARALLELISM\"] = \"false\"\n",
"\n",
"seed = 3407\n",
"base_model_id = \"MegaScience/Qwen3-4B-MegaScience\"\n",
"hub_model_id = \"pymlex/qwen3-4b-gsm8k\"\n",
"run_dir = Path(\"./qwen3_gsm8k_run\")\n",
"adapter_dir = run_dir / \"adapter\"\n",
"publish_dir = run_dir / \"publish\"\n",
"run_dir.mkdir(parents=True, exist_ok=True)\n",
"adapter_dir.mkdir(parents=True, exist_ok=True)\n",
"publish_dir.mkdir(parents=True, exist_ok=True)\n",
"\n",
"torch.manual_seed(seed)\n",
"if torch.cuda.is_available():\n",
"    torch.cuda.manual_seed_all(seed)\n",
"\n",
"torch.backends.cuda.matmul.allow_tf32 = True\n",
"torch.set_float32_matmul_precision(\"high\")\n",
"\n",
"bf16 = torch.cuda.is_available() and torch.cuda.is_bf16_supported()\n",
"fp16 = torch.cuda.is_available() and not bf16\n",
"\n",
"pynvml.nvmlInit()\n",
"\n",
"device_name = torch.cuda.get_device_name(0) if torch.cuda.is_available() else \"cpu\"\n",
"device_props = torch.cuda.get_device_properties(0) if torch.cuda.is_available() else None\n",
"vram_gb = device_props.total_memory / (1024 ** 3) if device_props is not None else 0.0\n",
"\n",
"print(f\"GPU: {device_name}\")\n",
"print(f\"VRAM: {vram_gb:.2f} GB\")\n",
"print(f\"bf16: {bf16}\")\n",
"print(f\"fp16: {fp16}\")"
]
},
{
"cell_type": "markdown",
"id": "b2ca0468",
"metadata": {},
"source": [
"## Dataset\n",
"\n",
"GSM8K already ships with calculator annotations such as `<<7+9=16>>` inside `answer`. The final scalar answer follows `####`; during formatting it is moved into the `<answer>` block. The official train split is divided 95/5 into training (7,099 examples) and validation (374 examples) sets, and the official test split (1,319 examples) is held out."
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "dfc0a6fe",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Dataset({\n",
" features: ['question', 'answer'],\n",
" num_rows: 7099\n",
"})\n",
"Dataset({\n",
" features: ['question', 'answer'],\n",
" num_rows: 374\n",
"})\n",
"Dataset({\n",
" features: ['question', 'answer'],\n",
" num_rows: 1319\n",
"})\n",
"Emma is 7 years old. If her sister is 9 years older than her, how old will Emma be when her sister is 56?\n",
"Emmas sister is 7 + 9 = <<7+9=16>>16 years old.\n",
"In this many years, Emmas sister will be 56 years old: 56 - 16 = <<56-16=40>>40 years.\n",
"When her sister is 56 years old, Emma will be 7 + 40 = <<7+40=47>>47 years.\n",
"#### 47\n"
]
}
],
"source": [
"raw = load_dataset(\"openai/gsm8k\", \"main\")\n",
"\n",
"train_full = raw[\"train\"].shuffle(seed=seed)\n",
"test_ds = raw[\"test\"]\n",
"\n",
"split = train_full.train_test_split(test_size=0.05, seed=seed)\n",
"train_base = split[\"train\"]\n",
"val_base = split[\"test\"]\n",
"\n",
"print(train_base)\n",
"print(val_base)\n",
"print(test_ds)\n",
"print(train_base[0][\"question\"])\n",
"print(train_base[0][\"answer\"][:400])"
]
},
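{
"cell_type": "markdown",
"id": "a3c91f0e",
"metadata": {},
"source": [
"As a quick check on the split sizes, the sketch below reproduces the 7,099/374 division of the 7,473-problem train split. It assumes that `train_test_split` rounds the test share up when `test_size` is a float, which is consistent with the sizes printed above; `split_sizes` is a hypothetical helper, not a `datasets` function."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "0a6d3f9b",
"metadata": {},
"outputs": [],
"source": [
"import math\n",
"\n",
"def split_sizes(n_samples, test_size):\n",
"    # Assumed rounding rule: the test share is rounded up, training gets the rest.\n",
"    n_test = math.ceil(test_size * n_samples)\n",
"    n_train = n_samples - n_test\n",
"    return n_train, n_test\n",
"\n",
"# 7,473 problems in the official GSM8K train split (7,099 + 374).\n",
"print(split_sizes(7473, 0.05))  # (7099, 374)"
]
},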
{
"cell_type": "markdown",
"id": "f0d7f3db",
"metadata": {},
"source": [
"## Chat format\n",
"\n",
"The target answer keeps the calculation trace in `<think>` and places the final value in `<answer>`."
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "232541b6",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"[transformers] `torch_dtype` is deprecated! Use `dtype` instead!\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "57660d54b8b54068b66ab5906dcc7d65",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Loading weights: 0%| | 0/398 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"<|im_start|>system\n",
"You solve grade-school math problems. Put the reasoning in <think>...</think>. Put only the final result as a single number in <answer>...</answer>.<|im_end|>\n",
"<|im_start|>user\n",
"Emma is 7 years old. If her sister is 9 years older than her, how old will Emma be when her sister is 56?<|im_end|>\n",
"<|im_start|>assistant\n",
"<think>\n",
"Emmas sister is 7 + 9 = <<7+9=16>>16 years old.\n",
"In this many years, Emmas sister will be 56 years old: 56 - 16 = <<56-16=40>>40 years.\n",
"When her sister is 56 years old, Emma will be 7 + 40 = <<7+40=47>>47 years.\n",
"</think>\n",
"\n",
"<answer>\n",
"47\n",
"</answer><|im_end|>\n",
"<|im_end|>\n",
"7099 374 1319\n"
]
}
],
"source": [
"tokenizer = AutoTokenizer.from_pretrained(base_model_id, trust_remote_code=True)\n",
"tokenizer.padding_side = \"left\"\n",
"\n",
"if tokenizer.pad_token is None:\n",
"    tokenizer.pad_token = tokenizer.eos_token\n",
"\n",
"model = AutoModelForCausalLM.from_pretrained(\n",
"    base_model_id,\n",
"    device_map=\"auto\",\n",
"    torch_dtype=torch.bfloat16 if bf16 else torch.float16,\n",
"    trust_remote_code=True,\n",
"    low_cpu_mem_usage=True,\n",
"    attn_implementation=\"sdpa\",\n",
")\n",
"\n",
"special_tokens = {\n",
"    \"additional_special_tokens\": [\"<think>\", \"</think>\", \"<answer>\", \"</answer>\"]\n",
"}\n",
"added = tokenizer.add_special_tokens(special_tokens)\n",
"if added > 0:\n",
"    model.resize_token_embeddings(len(tokenizer))\n",
"\n",
"model.config.use_cache = False\n",
"model.generation_config.pad_token_id = tokenizer.pad_token_id\n",
"\n",
"SYSTEM_PROMPT = (\n",
"    \"You solve grade-school math problems. \"\n",
"    \"Put the reasoning in <think>...</think>. \"\n",
"    \"Put only the final result as a single number in <answer>...</answer>.\"\n",
")\n",
"\n",
"answer_re = re.compile(r\"####\\s*(.*)$\", re.S)\n",
"\n",
"def canonicalize_answer(text):\n",
"    s = str(text).strip().replace(\",\", \"\").replace(\"$\", \"\")\n",
"    s = s.replace(\" \", \"\")\n",
"    if s.endswith(\".0\"):\n",
"        s = s[:-2]\n",
"    if s.endswith(\".\"):\n",
"        s = s[:-1]\n",
"    return s\n",
"\n",
"def split_reasoning_and_answer(answer_text):\n",
"    match = answer_re.search(answer_text)\n",
"    if match is None:\n",
"        return answer_text.strip(), \"\"\n",
"    reasoning = answer_text[:match.start()].strip()\n",
"    final_answer = canonicalize_answer(match.group(1))\n",
"    return reasoning, final_answer\n",
"\n",
"def build_text(example):\n",
"    reasoning, final_answer = split_reasoning_and_answer(example[\"answer\"])\n",
"    messages = [\n",
"        {\"role\": \"system\", \"content\": SYSTEM_PROMPT},\n",
"        {\"role\": \"user\", \"content\": example[\"question\"].strip()},\n",
"        {\n",
"            \"role\": \"assistant\",\n",
"            \"content\": f\"<think>\\n{reasoning}\\n</think>\\n<answer>\\n{final_answer}\\n</answer>\"\n",
"        },\n",
"    ]\n",
"    text = tokenizer.apply_chat_template(\n",
"        messages,\n",
"        tokenize=False,\n",
"        add_generation_prompt=False,\n",
"    )\n",
"    # The chat template ends the turn with \"<|im_end|>\\n\"; checking the stripped\n",
"    # text avoids appending a duplicate end-of-turn token.\n",
"    if not text.rstrip().endswith(tokenizer.eos_token):\n",
"        text += tokenizer.eos_token\n",
"    token_count = len(tokenizer(text, add_special_tokens=False)[\"input_ids\"])\n",
"    return {\n",
"        \"text\": text,\n",
"        \"n_tokens\": token_count,\n",
"        \"final_answer\": final_answer,\n",
"    }\n",
"\n",
"train_ds = train_base.map(build_text, remove_columns=train_base.column_names)\n",
"val_ds = val_base.map(build_text, remove_columns=val_base.column_names)\n",
"test_formatted = test_ds.map(build_text, remove_columns=test_ds.column_names)\n",
"\n",
"eval_subset = val_base.select(range(min(100, len(val_base))))\n",
"\n",
"print(train_ds[0][\"text\"][:1200])\n",
"print(len(train_ds), len(val_ds), len(test_formatted))"
]
},
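{
"cell_type": "markdown",
"id": "b7e2d456",
"metadata": {},
"source": [
"The target layout can also be checked mechanically, for example to measure how often generations follow the format. The sketch below uses a hypothetical helper, `is_well_formed`, that is not part of this pipeline; it accepts a completion only if a `<think>` block is followed by an `<answer>` block holding a single (possibly negative or decimal) number."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "1c5e8a27",
"metadata": {},
"outputs": [],
"source": [
"import re\n",
"\n",
"# Hypothetical format check: <think>...</think> followed by <answer>NUMBER</answer>.\n",
"format_re = re.compile(\n",
"    r\"<think>.*?</think>\\s*<answer>\\s*-?\\d+(?:\\.\\d+)?\\s*</answer>\",\n",
"    re.S,\n",
")\n",
"\n",
"def is_well_formed(completion):\n",
"    return format_re.search(completion) is not None\n",
"\n",
"good = \"<think>\\n7 + 9 = 16\\n</think>\\n\\n<answer>\\n47\\n</answer>\"\n",
"bad = \"The answer is 47.\"\n",
"print(is_well_formed(good), is_well_formed(bad))  # True False"
]
},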
{
"cell_type": "markdown",
"id": "d3e58438",
"metadata": {},
"source": [
"## Token lengths\n",
"\n",
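"The histogram below shows the token lengths of all formatted examples. The maximum sequence length is set to 768 tokens, intended to leave enough headroom that no training, validation, or test example is truncated."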
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "438916c0",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<Figure size 1000x400 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"all_tokens = list(train_ds[\"n_tokens\"]) + list(val_ds[\"n_tokens\"]) + list(test_formatted[\"n_tokens\"])\n",
"\n",
"plt.figure(figsize=(10, 4))\n",
"plt.hist(all_tokens, bins=40)\n",
"plt.title(\"GSM8K token length distribution\")\n",
"plt.xlabel(\"Tokens\")\n",
"plt.ylabel(\"Count\")\n",
"plt.tight_layout()\n",
"plt.savefig(run_dir / \"token_lengths.png\", dpi=160)\n",
"plt.show()"
]
},
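{
"cell_type": "markdown",
"id": "c4f8a912",
"metadata": {},
"source": [
"To check the 768-token limit against the data, one can count how many formatted examples would exceed it. The hypothetical helper below works on any list of token counts; the sample counts are made up, and in this notebook `all_tokens` from the previous cell would be passed instead."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "2f9b0d46",
"metadata": {},
"outputs": [],
"source": [
"def length_coverage(token_counts, max_length):\n",
"    # Count examples longer than max_length (they would be truncated)\n",
"    # and return that count with the share of examples that fit.\n",
"    too_long = sum(1 for n in token_counts if n > max_length)\n",
"    return too_long, 1 - too_long / len(token_counts)\n",
"\n",
"# Made-up counts for illustration; in this notebook, pass all_tokens instead.\n",
"sample_counts = [180, 240, 310, 455, 800]\n",
"print(length_coverage(sample_counts, 768))  # (1, 0.8)"
]
},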
{
"cell_type": "markdown",
"id": "283b48b9",
"metadata": {},
"source": [
"## Evaluation helpers\n",
"\n",
"Exact match compares the canonicalized gold answer with the value inside the first `<answer>` block of the completion. If that tag is missing, the code falls back to a `####` marker and then to the last number in the completion."
]
},
{
"cell_type": "code",
"execution_count": 16,
"id": "2c643680",
"metadata": {},
"outputs": [],
"source": [
"number_re = re.compile(r\"-?\\d+(?:,\\d{3})*(?:\\.\\d+)?\")\n",
"answer_block_re = re.compile(r\"<answer>\\s*(.*?)\\s*</answer>\", re.S)\n",
"\n",
"def extract_gold_answer(text):\n",
"    match = answer_re.search(text)\n",
"    if match is not None:\n",
"        return canonicalize_answer(match.group(1))\n",
"    nums = number_re.findall(text.replace(\",\", \"\"))\n",
"    if nums:\n",
"        return canonicalize_answer(nums[-1])\n",
"    return canonicalize_answer(text)\n",
"\n",
"def extract_pred_answer(text):\n",
"    match = answer_block_re.search(text)\n",
"    if match is not None:\n",
"        return canonicalize_answer(match.group(1))\n",
"    match = answer_re.search(text)\n",
"    if match is not None:\n",
"        return canonicalize_answer(match.group(1))\n",
"    nums = number_re.findall(text.replace(\",\", \"\"))\n",
"    if nums:\n",
"        return canonicalize_answer(nums[-1])\n",
"    return canonicalize_answer(text)\n",
"\n",
"def build_prompt(question):\n",
"    messages = [\n",
"        {\"role\": \"system\", \"content\": SYSTEM_PROMPT},\n",
"        {\"role\": \"user\", \"content\": question.strip()},\n",
"    ]\n",
"    return tokenizer.apply_chat_template(\n",
"        messages,\n",
"        tokenize=False,\n",
"        add_generation_prompt=True,\n",
"    )\n",
"\n",
"def generate_batch(model, tokenizer, prompts, max_new_tokens=256):\n",
"    device = next(model.parameters()).device\n",
"    # The chat template already inserted the special tokens, so do not add more.\n",
"    inputs = tokenizer(\n",
"        prompts,\n",
"        return_tensors=\"pt\",\n",
"        padding=True,\n",
"        truncation=True,\n",
"        add_special_tokens=False,\n",
"    ).to(device)\n",
"\n",
"    with torch.inference_mode():\n",
"        outputs = model.generate(\n",
"            **inputs,\n",
"            max_new_tokens=max_new_tokens,\n",
"            do_sample=False,\n",
"            eos_token_id=tokenizer.eos_token_id,\n",
"            pad_token_id=tokenizer.pad_token_id,\n",
"        )\n",
"\n",
"    # generate() returns the padded prompt followed by the new tokens, so the\n",
"    # completion starts at inputs[\"input_ids\"].shape[1]. Decoding only that slice\n",
"    # keeps the \"<answer>...</answer>\" placeholder in SYSTEM_PROMPT out of the search.\n",
"    new_tokens = outputs[:, inputs[\"input_ids\"].shape[1]:]\n",
"    return tokenizer.batch_decode(new_tokens, skip_special_tokens=False)\n",
"\n",
"def evaluate_exact_match(model, tokenizer, dataset, n_examples=None, batch_size=4, max_new_tokens=512):\n",
"    limit = len(dataset) if n_examples is None else min(len(dataset), n_examples)\n",
"    subset = dataset.select(range(limit))\n",
"    rows = []\n",
"    correct = 0\n",
"\n",
"    for start in tqdm(range(0, limit, batch_size)):\n",
"        batch = subset[start:start + batch_size]\n",
"        prompts = [build_prompt(q) for q in batch[\"question\"]]\n",
"        # generate_batch returns only the completions, without the prompt text.\n",
"        decoded = generate_batch(model, tokenizer, prompts, max_new_tokens=max_new_tokens)\n",
"\n",
"        for question, gold_text, pred_text in zip(batch[\"question\"], batch[\"answer\"], decoded):\n",
"            pred = extract_pred_answer(pred_text)\n",
"            gold = extract_gold_answer(gold_text)\n",
"            is_correct = int(pred == gold)\n",
"            correct += is_correct\n",
"            rows.append(\n",
"                {\n",
"                    \"question\": question,\n",
"                    \"gold_answer\": gold,\n",
"                    \"pred_answer\": pred,\n",
"                    \"correct\": is_correct,\n",
"                    \"pred_text\": pred_text,\n",
"                }\n",
"            )\n",
"\n",
"    df = pd.DataFrame(rows)\n",
"    return df, correct / limit"
]
},
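{
"cell_type": "markdown",
"id": "d19e6b3a",
"metadata": {},
"source": [
"`SYSTEM_PROMPT` itself contains the literal text `<answer>...</answer>`, so the answer regex should only see the generated continuation. The self-contained sketch below uses a shortened stand-in prompt to show what the pattern returns when it searches prompt plus completion versus the completion alone; `answer_tag_re` is a local copy of the pattern, named differently to avoid shadowing."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "3e7a1c58",
"metadata": {},
"outputs": [],
"source": [
"import re\n",
"\n",
"answer_tag_re = re.compile(r\"<answer>\\s*(.*?)\\s*</answer>\", re.S)\n",
"\n",
"# Shortened stand-in for the chat-templated prompt; like SYSTEM_PROMPT, it quotes the tags.\n",
"prompt = \"Put only the final result as a single number in <answer>...</answer>.\"\n",
"completion = \"<think>\\n7 + 40 = 47\\n</think>\\n<answer>\\n47\\n</answer>\"\n",
"\n",
"# Searching prompt + completion hits the placeholder in the instructions first.\n",
"print(answer_tag_re.search(prompt + completion).group(1))  # ...\n",
"# Searching only the completion recovers the model's answer.\n",
"print(answer_tag_re.search(completion).group(1))  # 47"
]
},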
{
"cell_type": "markdown",
"id": "234295e3",
"metadata": {},
"source": [
"## GPU logging and exact-match callback"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "5a9f1472",
"metadata": {},
"outputs": [],
"source": [
"gpu_history = []\n",
"\n",
"class TrainStatusCallback(TrainerCallback):\n",
"    def __init__(self, eval_subset):\n",
"        self.eval_subset = eval_subset\n",
"        self.trainer = None\n",
"\n",
"    def on_log(self, args, state, control, logs=None, **kwargs):\n",
"        if not logs:\n",
"            return\n",
"        row = {\n",
"            \"step\": int(state.global_step),\n",
"            \"epoch\": float(state.epoch) if state.epoch is not None else np.nan,\n",
"            \"time\": time.time(),\n",
"        }\n",
"        for key in [\"loss\", \"eval_loss\", \"learning_rate\", \"grad_norm\"]:\n",
"            if key in logs:\n",
"                row[key] = float(logs[key])\n",
"        if torch.cuda.is_available():\n",
"            handle = pynvml.nvmlDeviceGetHandleByIndex(torch.cuda.current_device())\n",
"            util = pynvml.nvmlDeviceGetUtilizationRates(handle)\n",
"            mem = pynvml.nvmlDeviceGetMemoryInfo(handle)\n",
"            row[\"gpu_util\"] = float(util.gpu)\n",
"            row[\"mem_used_mb\"] = float(mem.used / (1024 ** 2))\n",
"            row[\"mem_total_mb\"] = float(mem.total / (1024 ** 2))\n",
"            msg = f\"[step {state.global_step}] \"\n",
"            if \"loss\" in logs:\n",
"                msg += f\"loss={logs['loss']:.4f} \"\n",
"            if \"eval_loss\" in logs:\n",
"                msg += f\"eval_loss={logs['eval_loss']:.4f} \"\n",
"            msg += f\"GPU={util.gpu}% VRAM={mem.used/1024**2:.0f}/{mem.total/1024**2:.0f} MB\"\n",
"            print(msg)\n",
"        gpu_history.append(row)\n",
"\n",
"    def on_evaluate(self, args, state, control, **kwargs):\n",
"        if self.trainer is None or self.eval_subset is None:\n",
"            return\n",
"        metrics_df, acc = evaluate_exact_match(\n",
"            self.trainer.model,\n",
"            self.trainer.processing_class,\n",
"            self.eval_subset,\n",
"            n_examples=100,\n",
"            batch_size=4,\n",
"            max_new_tokens=256,\n",
"        )\n",
"        metrics = {\n",
"            \"eval_exact_match_100\": acc,\n",
"            \"eval_correct_100\": int(metrics_df[\"correct\"].sum()),\n",
"            \"eval_total_100\": int(len(metrics_df)),\n",
"        }\n",
"        self.trainer.log(metrics)\n",
"        payload = dict(metrics)\n",
"        payload[\"step\"] = int(state.global_step)\n",
"        payload[\"epoch\"] = float(state.epoch) if state.epoch is not None else np.nan\n",
"        payload[\"time\"] = time.time()\n",
"        gpu_history.append(payload)\n",
"        print(f\"[step {state.global_step}] eval_exact_match_100={acc:.4f}\")"
]
},
{
"cell_type": "markdown",
"id": "85406888",
"metadata": {},
"source": [
"## Trainer"
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "36f32f0c",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"trainer ready\n"
]
}
],
"source": [
"eval_subset = val_base.select(range(min(100, len(val_base))))\n",
"\n",
"lora_config = LoraConfig(\n",
"    r=16,\n",
"    lora_alpha=32,\n",
"    lora_dropout=0.05,\n",
"    bias=\"none\",\n",
"    task_type=\"CAUSAL_LM\",\n",
"    target_modules=[\"q_proj\", \"k_proj\", \"v_proj\", \"o_proj\", \"gate_proj\", \"up_proj\", \"down_proj\"],\n",
")\n",
"\n",
"training_args = SFTConfig(\n",
"    output_dir=str(adapter_dir),\n",
"    per_device_train_batch_size=4,\n",
"    per_device_eval_batch_size=4,\n",
"    gradient_accumulation_steps=8,\n",
"    num_train_epochs=1,\n",
"    learning_rate=2e-4,\n",
"    warmup_steps=20,\n",
"    lr_scheduler_type=\"cosine\",\n",
"    weight_decay=0.0,\n",
"    optim=\"adamw_torch\",\n",
"    logging_strategy=\"steps\",\n",
"    logging_steps=1,\n",
"    eval_strategy=\"steps\",\n",
"    eval_steps=25,\n",
"    eval_on_start=True,\n",
"    save_strategy=\"steps\",\n",
"    save_steps=50,\n",
"    load_best_model_at_end=True,\n",
"    metric_for_best_model=\"eval_loss\",\n",
"    greater_is_better=False,\n",
"    bf16=bf16,\n",
"    fp16=fp16,\n",
"    gradient_checkpointing=True,\n",
"    report_to=\"tensorboard\",\n",
"    seed=seed,\n",
"    packing=False,\n",
"    dataset_text_field=\"text\",\n",
"    max_length=768,\n",
"    logging_first_step=True,\n",
")\n",
"\n",
"trainer = SFTTrainer(\n",
"    model=model,\n",
"    args=training_args,\n",
"    train_dataset=train_ds,\n",
"    eval_dataset=val_ds,\n",
"    peft_config=lora_config,\n",
"    processing_class=tokenizer,\n",
")\n",
"\n",
"status_cb = TrainStatusCallback(eval_subset)\n",
"status_cb.trainer = trainer\n",
"trainer.add_callback(status_cb)\n",
"\n",
"print(\"trainer ready\")"
]
},
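{
"cell_type": "markdown",
"id": "e5a0c7d2",
"metadata": {},
"source": [
"The number of trainable LoRA weights can be estimated by hand: an adapted linear layer with `in_features` inputs and `out_features` outputs gains `r * (in_features + out_features)` parameters from its A and B matrices. The dimensions in the sketch are assumptions intended to match Qwen3-4B (36 layers, hidden size 2560, MLP size 9728, 32 query heads and 8 key/value heads of size 128); check them against `model.config`. PEFT's `trainer.model.print_trainable_parameters()` reports the exact count, which can differ if other modules are also trainable. `lora_param_count` is a hypothetical helper."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "4b2d9e73",
"metadata": {},
"outputs": [],
"source": [
"def lora_param_count(r, layers, hidden, intermediate, n_heads, n_kv_heads, head_dim):\n",
"    q_out = n_heads * head_dim\n",
"    kv_out = n_kv_heads * head_dim\n",
"    # (in_features, out_features) of each targeted projection in one decoder layer\n",
"    shapes = {\n",
"        \"q_proj\": (hidden, q_out),\n",
"        \"k_proj\": (hidden, kv_out),\n",
"        \"v_proj\": (hidden, kv_out),\n",
"        \"o_proj\": (q_out, hidden),\n",
"        \"gate_proj\": (hidden, intermediate),\n",
"        \"up_proj\": (hidden, intermediate),\n",
"        \"down_proj\": (intermediate, hidden),\n",
"    }\n",
"    # LoRA adds A (r x in_features) and B (out_features x r) to every adapted layer.\n",
"    per_layer = sum(r * (fan_in + fan_out) for fan_in, fan_out in shapes.values())\n",
"    return per_layer * layers\n",
"\n",
"# Assumed Qwen3-4B dimensions; verify them against model.config.\n",
"total = lora_param_count(\n",
"    r=16, layers=36, hidden=2560, intermediate=9728,\n",
"    n_heads=32, n_kv_heads=8, head_dim=128,\n",
")\n",
"print(f\"{total:,}\")  # 33,030,144 under these assumptions"
]
},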
{
"cell_type": "markdown",
"id": "26a350cf",
"metadata": {},
"source": [
"## Training"
]
},
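{
"cell_type": "markdown",
"id": "f28b4e61",
"metadata": {},
"source": [
"The progress bar below ends at step 222, which follows from the batch settings. On one GPU, `per_device_train_batch_size=4` with `gradient_accumulation_steps=8` gives 32 sequences per optimizer step; 7,099 examples form ceil(7099 / 4) = 1,775 micro-batches, and ceil(1775 / 8) = 222 updates. The sketch assumes that a trailing partial accumulation window still counts as a step, which matches the logged total; `optimizer_steps` is a hypothetical helper."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "5d8c6f10",
"metadata": {},
"outputs": [],
"source": [
"import math\n",
"\n",
"def optimizer_steps(n_examples, per_device_batch, grad_accum, n_devices=1, epochs=1):\n",
"    # Micro-batches per epoch; the last partial batch is kept.\n",
"    micro_batches = math.ceil(n_examples / (per_device_batch * n_devices))\n",
"    # Assumption: a trailing partial accumulation window still triggers an update.\n",
"    steps_per_epoch = math.ceil(micro_batches / grad_accum)\n",
"    return steps_per_epoch * epochs\n",
"\n",
"print(4 * 8)  # 32 sequences per optimizer step\n",
"print(optimizer_steps(7099, per_device_batch=4, grad_accum=8))  # 222"
]
},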
{
"cell_type": "code",
"execution_count": 11,
"id": "cc3596a6",
"metadata": {
"scrolled": true
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"[transformers] The tokenizer has new PAD/BOS/EOS tokens that differ from the model config and generation config. The model config and generation config were aligned accordingly, being updated with the tokenizer's values. Updated tokens: {'eos_token_id': 151645, 'bos_token_id': None, 'pad_token_id': 151643}.\n"
]
},
{
"data": {
"text/html": [
"\n",
" <div>\n",
" \n",
" <progress value='222' max='222' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
" [222/222 24:28, Epoch 1/1]\n",
" </div>\n",
" <table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: left;\">\n",
" <th>Step</th>\n",
" <th>Training Loss</th>\n",
" <th>Validation Loss</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <td>0</td>\n",
" <td>No log</td>\n",
" <td>2.089016</td>\n",
" </tr>\n",
" <tr>\n",
" <td>25</td>\n",
" <td>0.408335</td>\n",
" <td>0.418331</td>\n",
" </tr>\n",
" <tr>\n",
" <td>50</td>\n",
" <td>0.368056</td>\n",
" <td>0.366697</td>\n",
" </tr>\n",
" <tr>\n",
" <td>75</td>\n",
" <td>0.343977</td>\n",
" <td>0.356384</td>\n",
" </tr>\n",
" <tr>\n",
" <td>100</td>\n",
" <td>0.366282</td>\n",
" <td>0.350802</td>\n",
" </tr>\n",
" <tr>\n",
" <td>125</td>\n",
" <td>0.339358</td>\n",
" <td>0.348205</td>\n",
" </tr>\n",
" <tr>\n",
" <td>150</td>\n",
" <td>0.390431</td>\n",
" <td>0.346121</td>\n",
" </tr>\n",
" <tr>\n",
" <td>175</td>\n",
" <td>0.346124</td>\n",
" <td>0.344797</td>\n",
" </tr>\n",
" <tr>\n",
" <td>200</td>\n",
" <td>0.352397</td>\n",
" <td>0.344206</td>\n",
" </tr>\n",
" <tr>\n",
" <td>222</td>\n",
" <td>0.369009</td>\n",
" <td>0.344098</td>\n",
" </tr>\n",
" </tbody>\n",
"</table><p>"
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"[step 0] eval_loss=2.0890 GPU=95% VRAM=16318/32607 MB\n",
"[step 0] GPU=78% VRAM=16332/32607 MB\n",
"[step 0] eval_exact_match_100=0.2200\n",
"[step 1] loss=2.1023 GPU=88% VRAM=16778/32607 MB\n",
"[step 2] loss=2.2198 GPU=87% VRAM=17670/32607 MB\n",
"[step 3] loss=2.1929 GPU=91% VRAM=17670/32607 MB\n",
"[step 4] loss=2.1926 GPU=71% VRAM=18656/32607 MB\n",
"[step 5] loss=2.0107 GPU=97% VRAM=18656/32607 MB\n",
"[step 6] loss=1.8966 GPU=96% VRAM=18656/32607 MB\n",
"[step 7] loss=1.8916 GPU=95% VRAM=18656/32607 MB\n",
"[step 8] loss=1.7318 GPU=90% VRAM=18656/32607 MB\n",
"[step 9] loss=1.5379 GPU=88% VRAM=18656/32607 MB\n",
"[step 10] loss=1.2925 GPU=97% VRAM=19818/32607 MB\n",
"[step 11] loss=1.2532 GPU=94% VRAM=19818/32607 MB\n",
"[step 12] loss=1.0723 GPU=90% VRAM=19818/32607 MB\n",
"[step 13] loss=0.9937 GPU=95% VRAM=19818/32607 MB\n",
"[step 14] loss=0.8682 GPU=95% VRAM=19818/32607 MB\n",
"[step 15] loss=0.7871 GPU=89% VRAM=19818/32607 MB\n",
"[step 16] loss=0.6485 GPU=97% VRAM=19818/32607 MB\n",
"[step 17] loss=0.5811 GPU=96% VRAM=19818/32607 MB\n",
"[step 18] loss=0.5456 GPU=94% VRAM=19818/32607 MB\n",
"[step 19] loss=0.5182 GPU=95% VRAM=19818/32607 MB\n",
"[step 20] loss=0.4801 GPU=96% VRAM=19818/32607 MB\n",
"[step 21] loss=0.4693 GPU=89% VRAM=19818/32607 MB\n",
"[step 22] loss=0.4335 GPU=97% VRAM=19818/32607 MB\n",
"[step 23] loss=0.4865 GPU=95% VRAM=19818/32607 MB\n",
"[step 24] loss=0.4407 GPU=87% VRAM=19818/32607 MB\n",
"[step 25] loss=0.4083 GPU=27% VRAM=19818/32607 MB\n",
"[step 25] eval_loss=0.4183 GPU=95% VRAM=19818/32607 MB\n",
"[step 25] GPU=73% VRAM=19818/32607 MB\n",
"[step 25] eval_exact_match_100=0.2100\n",
"[step 26] loss=0.4382 GPU=94% VRAM=19818/32607 MB\n",
"[step 27] loss=0.4439 GPU=96% VRAM=19818/32607 MB\n",
"[step 28] loss=0.4533 GPU=95% VRAM=19818/32607 MB\n",
"[step 29] loss=0.4072 GPU=95% VRAM=21048/32607 MB\n",
"[step 30] loss=0.4260 GPU=89% VRAM=21048/32607 MB\n",
"[step 31] loss=0.3852 GPU=96% VRAM=21048/32607 MB\n",
"[step 32] loss=0.3975 GPU=88% VRAM=21048/32607 MB\n",
"[step 33] loss=0.3895 GPU=95% VRAM=21048/32607 MB\n",
"[step 34] loss=0.3756 GPU=89% VRAM=21048/32607 MB\n",
"[step 35] loss=0.3903 GPU=96% VRAM=21048/32607 MB\n",
"[step 36] loss=0.3878 GPU=96% VRAM=21048/32607 MB\n",
"[step 37] loss=0.3829 GPU=96% VRAM=21048/32607 MB\n",
"[step 38] loss=0.4120 GPU=95% VRAM=21048/32607 MB\n",
"[step 39] loss=0.3770 GPU=81% VRAM=21048/32607 MB\n",
"[step 40] loss=0.3588 GPU=96% VRAM=21048/32607 MB\n",
"[step 41] loss=0.4047 GPU=94% VRAM=21048/32607 MB\n",
"[step 42] loss=0.3753 GPU=95% VRAM=21048/32607 MB\n",
"[step 43] loss=0.3822 GPU=88% VRAM=21048/32607 MB\n",
"[step 44] loss=0.3920 GPU=96% VRAM=21048/32607 MB\n",
"[step 45] loss=0.3913 GPU=18% VRAM=21048/32607 MB\n",
"[step 46] loss=0.3862 GPU=90% VRAM=21048/32607 MB\n",
"[step 47] loss=0.3796 GPU=88% VRAM=21048/32607 MB\n",
"[step 48] loss=0.3716 GPU=90% VRAM=21048/32607 MB\n",
"[step 49] loss=0.3870 GPU=95% VRAM=21048/32607 MB\n",
"[step 50] loss=0.3681 GPU=89% VRAM=21048/32607 MB\n",
"[step 50] eval_loss=0.3667 GPU=95% VRAM=21048/32607 MB\n",
"[step 50] GPU=77% VRAM=21048/32607 MB\n",
"[step 50] eval_exact_match_100=0.2200\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"/usr/local/lib/python3.12/dist-packages/peft/utils/save_and_load.py:386: UserWarning: Setting `save_embedding_layers` to `True` as the embedding layer has been resized during finetuning.\n",
" warnings.warn(\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"[step 51] loss=0.3877 GPU=94% VRAM=21048/32607 MB\n",
"[step 52] loss=0.3838 GPU=95% VRAM=21048/32607 MB\n",
"[step 53] loss=0.3443 GPU=96% VRAM=21048/32607 MB\n",
"[step 54] loss=0.3773 GPU=95% VRAM=21048/32607 MB\n",
"[step 55] loss=0.3573 GPU=96% VRAM=21048/32607 MB\n",
"[step 56] loss=0.3939 GPU=89% VRAM=21048/32607 MB\n",
"[step 57] loss=0.3672 GPU=95% VRAM=21048/32607 MB\n",
"[step 58] loss=0.3673 GPU=88% VRAM=21048/32607 MB\n",
"[step 59] loss=0.3506 GPU=96% VRAM=21048/32607 MB\n",
"[step 60] loss=0.4168 GPU=95% VRAM=21048/32607 MB\n",
"[step 61] loss=0.3795 GPU=95% VRAM=21048/32607 MB\n",
"[step 62] loss=0.3852 GPU=95% VRAM=21048/32607 MB\n",
"[step 63] loss=0.3734 GPU=96% VRAM=21048/32607 MB\n",
"[step 64] loss=0.3649 GPU=90% VRAM=21048/32607 MB\n",
"[step 65] loss=0.3797 GPU=96% VRAM=21048/32607 MB\n",
"[step 66] loss=0.3820 GPU=88% VRAM=21048/32607 MB\n",
"[step 67] loss=0.3699 GPU=96% VRAM=21048/32607 MB\n",
"[step 68] loss=0.3923 GPU=29% VRAM=21048/32607 MB\n",
"[step 69] loss=0.3930 GPU=95% VRAM=21048/32607 MB\n",
"[step 70] loss=0.3757 GPU=97% VRAM=21048/32607 MB\n",
"[step 71] loss=0.3685 GPU=95% VRAM=21048/32607 MB\n",
"[step 72] loss=0.3598 GPU=97% VRAM=21048/32607 MB\n",
"[step 73] loss=0.3823 GPU=87% VRAM=21048/32607 MB\n",
"[step 74] loss=0.3447 GPU=88% VRAM=21048/32607 MB\n",
"[step 75] loss=0.3440 GPU=89% VRAM=21048/32607 MB\n",
"[step 75] eval_loss=0.3564 GPU=95% VRAM=21048/32607 MB\n",
"[step 75] GPU=77% VRAM=21048/32607 MB\n",
"[step 75] eval_exact_match_100=0.2200\n",
"[step 76] loss=0.3561 GPU=96% VRAM=21048/32607 MB\n",
"[step 77] loss=0.3843 GPU=89% VRAM=21048/32607 MB\n",
"[step 78] loss=0.3969 GPU=95% VRAM=21048/32607 MB\n",
"[step 79] loss=0.3518 GPU=87% VRAM=21048/32607 MB\n",
"[step 80] loss=0.3564 GPU=88% VRAM=21048/32607 MB\n",
"[step 81] loss=0.3557 GPU=95% VRAM=21048/32607 MB\n",
"[step 82] loss=0.3462 GPU=89% VRAM=21048/32607 MB\n",
"[step 83] loss=0.3421 GPU=92% VRAM=21048/32607 MB\n",
"[step 84] loss=0.3248 GPU=97% VRAM=21048/32607 MB\n",
"[step 85] loss=0.3470 GPU=97% VRAM=21048/32607 MB\n",
"[step 86] loss=0.3568 GPU=95% VRAM=21048/32607 MB\n",
"[step 87] loss=0.3598 GPU=88% VRAM=21048/32607 MB\n",
"[step 88] loss=0.3614 GPU=81% VRAM=21048/32607 MB\n",
"[step 89] loss=0.3725 GPU=86% VRAM=21048/32607 MB\n",
"[step 90] loss=0.3347 GPU=94% VRAM=21048/32607 MB\n",
"[step 91] loss=0.3387 GPU=88% VRAM=21048/32607 MB\n",
"[step 92] loss=0.3457 GPU=96% VRAM=21048/32607 MB\n",
"[step 93] loss=0.3184 GPU=95% VRAM=21048/32607 MB\n",
"[step 94] loss=0.3375 GPU=96% VRAM=21048/32607 MB\n",
"[step 95] loss=0.3204 GPU=96% VRAM=21048/32607 MB\n",
"[step 96] loss=0.3575 GPU=87% VRAM=21048/32607 MB\n",
"[step 97] loss=0.3163 GPU=87% VRAM=27864/32607 MB\n",
"[step 98] loss=0.3355 GPU=96% VRAM=27864/32607 MB\n",
"[step 99] loss=0.3385 GPU=89% VRAM=27864/32607 MB\n",
"[step 100] loss=0.3663 GPU=96% VRAM=27864/32607 MB\n",
"[step 100] eval_loss=0.3508 GPU=97% VRAM=27864/32607 MB\n",
"[step 100] GPU=78% VRAM=27864/32607 MB\n",
"[step 100] eval_exact_match_100=0.2100\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"/usr/local/lib/python3.12/dist-packages/peft/utils/save_and_load.py:386: UserWarning: Setting `save_embedding_layers` to `True` as the embedding layer has been resized during finetuning.\n",
" warnings.warn(\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"[step 101] loss=0.3436 GPU=95% VRAM=27864/32607 MB\n",
"[step 102] loss=0.3894 GPU=95% VRAM=27864/32607 MB\n",
"[step 103] loss=0.3668 GPU=87% VRAM=27864/32607 MB\n",
"[step 104] loss=0.3580 GPU=95% VRAM=27864/32607 MB\n",
"[step 105] loss=0.3599 GPU=95% VRAM=27864/32607 MB\n",
"[step 106] loss=0.3737 GPU=96% VRAM=27864/32607 MB\n",
"[step 107] loss=0.3622 GPU=87% VRAM=27864/32607 MB\n",
"[step 108] loss=0.3490 GPU=96% VRAM=27864/32607 MB\n",
"[step 109] loss=0.3608 GPU=88% VRAM=27864/32607 MB\n",
"[step 110] loss=0.3584 GPU=95% VRAM=27864/32607 MB\n",
"[step 111] loss=0.3820 GPU=88% VRAM=27864/32607 MB\n",
"[step 112] loss=0.3490 GPU=87% VRAM=27864/32607 MB\n",
"[step 113] loss=0.3541 GPU=97% VRAM=27864/32607 MB\n",
"[step 114] loss=0.3481 GPU=96% VRAM=27864/32607 MB\n",
"[step 115] loss=0.3321 GPU=94% VRAM=27864/32607 MB\n",
"[step 116] loss=0.3445 GPU=96% VRAM=27864/32607 MB\n",
"[step 117] loss=0.3203 GPU=90% VRAM=27864/32607 MB\n",
"[step 118] loss=0.3341 GPU=88% VRAM=27864/32607 MB\n",
"[step 119] loss=0.3575 GPU=96% VRAM=27864/32607 MB\n",
"[step 120] loss=0.3498 GPU=96% VRAM=27864/32607 MB\n",
"[step 121] loss=0.3389 GPU=96% VRAM=27864/32607 MB\n",
"[step 122] loss=0.3646 GPU=94% VRAM=27864/32607 MB\n",
"[step 123] loss=0.3489 GPU=96% VRAM=27864/32607 MB\n",
"[step 124] loss=0.3209 GPU=89% VRAM=27864/32607 MB\n",
"[step 125] loss=0.3394 GPU=97% VRAM=27864/32607 MB\n",
"[step 125] eval_loss=0.3482 GPU=95% VRAM=27864/32607 MB\n",
"[step 125] GPU=77% VRAM=27864/32607 MB\n",
"[step 125] eval_exact_match_100=0.2400\n",
"[step 126] loss=0.3233 GPU=80% VRAM=27864/32607 MB\n",
"[step 127] loss=0.3708 GPU=88% VRAM=27864/32607 MB\n",
"[step 128] loss=0.3399 GPU=96% VRAM=27864/32607 MB\n",
"[step 129] loss=0.3445 GPU=95% VRAM=27864/32607 MB\n",
"[step 130] loss=0.3676 GPU=96% VRAM=27864/32607 MB\n",
"[step 131] loss=0.3420 GPU=97% VRAM=27864/32607 MB\n",
"[step 132] loss=0.3942 GPU=94% VRAM=27864/32607 MB\n",
"[step 133] loss=0.3539 GPU=95% VRAM=27864/32607 MB\n",
"[step 134] loss=0.3653 GPU=89% VRAM=27864/32607 MB\n",
"[step 135] loss=0.3522 GPU=95% VRAM=27864/32607 MB\n",
"[step 136] loss=0.3555 GPU=19% VRAM=27864/32607 MB\n",
"[step 137] loss=0.3490 GPU=94% VRAM=27864/32607 MB\n",
"[step 138] loss=0.3494 GPU=96% VRAM=27864/32607 MB\n",
"[step 139] loss=0.3327 GPU=95% VRAM=27864/32607 MB\n",
"[step 140] loss=0.3559 GPU=97% VRAM=27864/32607 MB\n",
"[step 141] loss=0.3564 GPU=96% VRAM=27864/32607 MB\n",
"[step 142] loss=0.3700 GPU=96% VRAM=27864/32607 MB\n",
"[step 143] loss=0.3610 GPU=96% VRAM=27864/32607 MB\n",
"[step 144] loss=0.3599 GPU=91% VRAM=27864/32607 MB\n",
"[step 145] loss=0.3541 GPU=87% VRAM=27864/32607 MB\n",
"[step 146] loss=0.3151 GPU=96% VRAM=27864/32607 MB\n",
"[step 147] loss=0.3659 GPU=96% VRAM=27864/32607 MB\n",
"[step 148] loss=0.3822 GPU=96% VRAM=27864/32607 MB\n",
"[step 149] loss=0.3070 GPU=97% VRAM=27864/32607 MB\n",
"[step 150] loss=0.3904 GPU=87% VRAM=27864/32607 MB\n",
"[step 150] eval_loss=0.3461 GPU=97% VRAM=27864/32607 MB\n",
"[step 150] GPU=78% VRAM=27864/32607 MB\n",
"[step 150] eval_exact_match_100=0.2200\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"/usr/local/lib/python3.12/dist-packages/peft/utils/save_and_load.py:386: UserWarning: Setting `save_embedding_layers` to `True` as the embedding layer has been resized during finetuning.\n",
" warnings.warn(\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"[step 151] loss=0.3328 GPU=91% VRAM=27864/32607 MB\n",
"[step 152] loss=0.3426 GPU=96% VRAM=27864/32607 MB\n",
"[step 153] loss=0.3485 GPU=96% VRAM=27864/32607 MB\n",
"[step 154] loss=0.3344 GPU=87% VRAM=27864/32607 MB\n",
"[step 155] loss=0.3581 GPU=95% VRAM=27864/32607 MB\n",
"[step 156] loss=0.3511 GPU=96% VRAM=27864/32607 MB\n",
"[step 157] loss=0.3593 GPU=94% VRAM=27864/32607 MB\n",
"[step 158] loss=0.3500 GPU=89% VRAM=27864/32607 MB\n",
"[step 159] loss=0.3475 GPU=96% VRAM=27864/32607 MB\n",
"[step 160] loss=0.3545 GPU=95% VRAM=27864/32607 MB\n",
"[step 161] loss=0.3571 GPU=96% VRAM=27864/32607 MB\n",
"[step 162] loss=0.3605 GPU=94% VRAM=27864/32607 MB\n",
"[step 163] loss=0.3626 GPU=96% VRAM=27864/32607 MB\n",
"[step 164] loss=0.3316 GPU=96% VRAM=27864/32607 MB\n",
"[step 165] loss=0.3422 GPU=90% VRAM=27864/32607 MB\n",
"[step 166] loss=0.3302 GPU=96% VRAM=27864/32607 MB\n",
"[step 167] loss=0.3241 GPU=86% VRAM=27864/32607 MB\n",
"[step 168] loss=0.3446 GPU=95% VRAM=27864/32607 MB\n",
"[step 169] loss=0.3398 GPU=89% VRAM=27864/32607 MB\n",
"[step 170] loss=0.3516 GPU=89% VRAM=27864/32607 MB\n",
"[step 171] loss=0.3280 GPU=96% VRAM=27864/32607 MB\n",
"[step 172] loss=0.3572 GPU=89% VRAM=27864/32607 MB\n",
"[step 173] loss=0.3705 GPU=96% VRAM=27864/32607 MB\n",
"[step 174] loss=0.3475 GPU=89% VRAM=27864/32607 MB\n",
"[step 175] loss=0.3461 GPU=96% VRAM=27864/32607 MB\n",
"[step 175] eval_loss=0.3448 GPU=95% VRAM=27864/32607 MB\n",
"[step 175] GPU=78% VRAM=27864/32607 MB\n",
"[step 175] eval_exact_match_100=0.2300\n",
"[step 176] loss=0.3703 GPU=96% VRAM=27864/32607 MB\n",
"[step 177] loss=0.3453 GPU=94% VRAM=27864/32607 MB\n",
"[step 178] loss=0.3461 GPU=90% VRAM=27864/32607 MB\n",
"[step 179] loss=0.3642 GPU=95% VRAM=27864/32607 MB\n",
"[step 180] loss=0.3259 GPU=96% VRAM=27864/32607 MB\n",
"[step 181] loss=0.3705 GPU=96% VRAM=27864/32607 MB\n",
"[step 182] loss=0.3460 GPU=96% VRAM=27864/32607 MB\n",
"[step 183] loss=0.3643 GPU=94% VRAM=27864/32607 MB\n",
"[step 184] loss=0.3348 GPU=95% VRAM=27864/32607 MB\n",
"[step 185] loss=0.3759 GPU=94% VRAM=27864/32607 MB\n",
"[step 186] loss=0.3468 GPU=95% VRAM=27864/32607 MB\n",
"[step 187] loss=0.3610 GPU=88% VRAM=27864/32607 MB\n",
"[step 188] loss=0.3439 GPU=89% VRAM=27864/32607 MB\n",
"[step 189] loss=0.3299 GPU=97% VRAM=27864/32607 MB\n",
"[step 190] loss=0.3248 GPU=89% VRAM=27864/32607 MB\n",
"[step 191] loss=0.3324 GPU=96% VRAM=27864/32607 MB\n",
"[step 192] loss=0.3690 GPU=87% VRAM=27864/32607 MB\n",
"[step 193] loss=0.3164 GPU=96% VRAM=27864/32607 MB\n",
"[step 194] loss=0.3486 GPU=89% VRAM=27864/32607 MB\n",
"[step 195] loss=0.3437 GPU=91% VRAM=27864/32607 MB\n",
"[step 196] loss=0.3566 GPU=91% VRAM=27864/32607 MB\n",
"[step 197] loss=0.3529 GPU=96% VRAM=27864/32607 MB\n",
"[step 198] loss=0.3703 GPU=94% VRAM=27864/32607 MB\n",
"[step 199] loss=0.3299 GPU=97% VRAM=27864/32607 MB\n",
"[step 200] loss=0.3524 GPU=96% VRAM=27864/32607 MB\n",
"[step 200] eval_loss=0.3442 GPU=95% VRAM=27864/32607 MB\n",
"[step 200] GPU=78% VRAM=27864/32607 MB\n",
"[step 200] eval_exact_match_100=0.2100\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"/usr/local/lib/python3.12/dist-packages/peft/utils/other.py:1419: UserWarning: Unable to fetch remote file due to the following error [Errno 97] Address family not supported by protocol - silently ignoring the lookup for the file config.json in MegaScience/Qwen3-4B-MegaScience.\n",
" warnings.warn(\n",
"/usr/local/lib/python3.12/dist-packages/peft/utils/save_and_load.py:372: UserWarning: Could not find a config file in MegaScience/Qwen3-4B-MegaScience - will assume that the vocabulary was not modified.\n",
" warnings.warn(\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"[step 201] loss=0.3513 GPU=97% VRAM=27864/32607 MB\n",
"[step 202] loss=0.3486 GPU=95% VRAM=27864/32607 MB\n",
"[step 203] loss=0.3507 GPU=95% VRAM=27864/32607 MB\n",
"[step 204] loss=0.3456 GPU=97% VRAM=27864/32607 MB\n",
"[step 205] loss=0.3612 GPU=96% VRAM=27864/32607 MB\n",
"[step 206] loss=0.3479 GPU=97% VRAM=27864/32607 MB\n",
"[step 207] loss=0.3522 GPU=95% VRAM=27864/32607 MB\n",
"[step 208] loss=0.3537 GPU=86% VRAM=27864/32607 MB\n",
"[step 209] loss=0.3729 GPU=95% VRAM=27864/32607 MB\n",
"[step 210] loss=0.3466 GPU=96% VRAM=27864/32607 MB\n",
"[step 211] loss=0.3283 GPU=96% VRAM=27864/32607 MB\n",
"[step 212] loss=0.3392 GPU=89% VRAM=27864/32607 MB\n",
"[step 213] loss=0.3619 GPU=97% VRAM=27864/32607 MB\n",
"[step 214] loss=0.3700 GPU=87% VRAM=27864/32607 MB\n",
"[step 215] loss=0.3551 GPU=95% VRAM=27864/32607 MB\n",
"[step 216] loss=0.3346 GPU=94% VRAM=27864/32607 MB\n",
"[step 217] loss=0.3286 GPU=95% VRAM=27864/32607 MB\n",
"[step 218] loss=0.3198 GPU=96% VRAM=27864/32607 MB\n",
"[step 219] loss=0.3513 GPU=89% VRAM=27864/32607 MB\n",
"[step 220] loss=0.3585 GPU=95% VRAM=27864/32607 MB\n",
"[step 221] loss=0.3472 GPU=96% VRAM=27864/32607 MB\n",
"[step 222] loss=0.3690 GPU=94% VRAM=27864/32607 MB\n",
"[step 222] eval_loss=0.3441 GPU=95% VRAM=27864/32607 MB\n",
"[step 222] GPU=78% VRAM=27864/32607 MB\n",
"[step 222] eval_exact_match_100=0.2100\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"/usr/local/lib/python3.12/dist-packages/peft/utils/save_and_load.py:386: UserWarning: Setting `save_embedding_layers` to `True` as the embedding layer has been resized during finetuning.\n",
" warnings.warn(\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"[step 222] GPU=14% VRAM=27864/32607 MB\n",
"TrainOutput(global_step=222, training_loss=0.4492494880079149, metrics={'train_runtime': 1580.5321, 'train_samples_per_second': 4.492, 'train_steps_per_second': 0.14, 'total_flos': 4.806038094014976e+16, 'train_loss': 0.4492494880079149})\n"
]
}
],
"source": [
"train_result = trainer.train()\n",
"print(train_result)"
]
},
{
"cell_type": "markdown",
"id": "042d7bba",
"metadata": {},
"source": [
"## Logs and plots\n",
"\n",
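"Loss converges within one epoch: `eval_loss` drops only slightly, from 0.3508 at step 100 to 0.3441 at step 222. Exact match on 100 validation examples does not improve and stays between 0.21 and 0.24 over the same steps. This suggests that supervised training on task-solution pairs alone is an inefficient way to give the model the common sense these problems require."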
]
},
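{
"cell_type": "markdown",
"id": "3f9c1a7e",
"metadata": {},
"source": [
"Validation exact match is measured on only 100 examples, so small changes from one step to the next may be sampling noise. The next cell is a minimal sketch and not part of the original pipeline. It assumes that `eval_exact_match_100` is a plain accuracy over n = 100 examples, as its name suggests. It attaches a Wilson 95% interval to each logged value using a hypothetical helper, `wilson_interval`. With n = 100, the interval around 0.21 runs from roughly 0.14 to 0.30. That range already covers every value logged between steps 100 and 222."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b6d2e4a0",
"metadata": {},
"outputs": [],
"source": [
"import math\n",
"\n",
"# Minimal sketch (assumption): eval_exact_match_100 is a plain accuracy\n",
"# over n = 100 validation examples, as the metric name suggests.\n",
"n_val = 100\n",
"\n",
"def wilson_interval(p_hat, n, z=1.96):\n",
"    \"\"\"Hypothetical helper: Wilson score interval for a binomial proportion (z=1.96 gives ~95%).\"\"\"\n",
"    denom = 1 + z**2 / n\n",
"    centre = (p_hat + z**2 / (2 * n)) / denom\n",
"    half = z * math.sqrt(p_hat * (1 - p_hat) / n + z**2 / (4 * n**2)) / denom\n",
"    return centre - half, centre + half\n",
"\n",
"# trainer.state.log_history holds one dict per Trainer log call\n",
"for record in trainer.state.log_history:\n",
"    if \"eval_exact_match_100\" not in record:\n",
"        continue\n",
"    p = record[\"eval_exact_match_100\"]\n",
"    lo, hi = wilson_interval(p, n_val)\n",
"    print(f\"step {record.get('step')}: acc={p:.2f}  95% CI=[{lo:.3f}, {hi:.3f}]\")"
]
},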
{
"cell_type": "code",
"execution_count": 20,
"id": "807b37a4",
"metadata": {
"scrolled": true
},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAA90AAAGGCAYAAABmGOKbAAAAOnRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjEwLjksIGh0dHBzOi8vbWF0cGxvdGxpYi5vcmcvJkbTWQAAAAlwSFlzAAAPYQAAD2EBqD+naQAAhehJREFUeJzs3Xd8FHX+x/HX7CbZ9EoKgUDonVAEBEVRUOCwoKciogin5/089FRO7+Tu7Hqop5717AioKFbsKKKgKEV6rwIJpCekl0125/fHJitLSyHJJvB+Ph7zyOzMd77zmc0S9jPf73y/hmmaJiIiIiIiIiLS4CzeDkBERERERETkVKWkW0RERERERKSRKOkWERERERERaSRKukVEREREREQaiZJuERERERERkUaipFtERERERESkkSjpFhEREREREWkkSrpFREREREREGomSbhEREREREZFGoqRbREROe1OmTCExMdHbYdTLiBEjGDFiRJOfNzExkSlTprhfL1myBMMwWLJkSY3HNkbM999/P4ZhNGidtWUYBvfff79Xzi0iIs2fkm4REWm2DMOo1VKbRE9avpKSEu6//379vkVEpEXx8XYAIiIix/Pmm296vJ47dy6LFi06anuPHj1O6jyvvvoqTqfzpOo43Z1zzjmUlpbi5+fXaOcoKSnhgQceADiqpfxf//oXd999d6OdW0REpL6UdIuISLN17bXXerxesWIFixYtOmr7kUpKSggMDKz1eXx9fesVn/zGYrHg7+/vtfP7+Pjg46OvNSIi0vyoe7mIiLRoI0aMoHfv3qxZs4ZzzjmHwMBA/vGPfwDwySefMG7cOOLj47HZbHTq1ImHHnoIh8PhUceRz3Tv27cPwzB44okneOWVV+jUqRM2m41Bgwbxyy+/1BhTbm4ud955J3369CE4OJjQ0FDGjh3Lhg0bPMpVPwf93nvv8cgjj9C2bVv8/f0ZOXIku3fvPqre6lgCAgIYPHgwP/74Y63eo969e3Peeecdtd3pdNKmTRuuuOIK97YnnniCYcOGERUVRUBAAAMHDuSDDz6o8RzHe6a7NjHb7XbuvfdeBg4cSFhYGEFBQQwfPpzvv//eXWbfvn1ER0cD8MADD7gfLah+lvpYz3RXVlby0EMPuX9/iYmJ/OMf/6C8vNyjXGJiIhdddBHLli1j8ODB+Pv707FjR+bOnVvjdR/PunXrGDt2LKGhoQQHBzNy5EhWrFjhUaaiooIHHniALl264O/vT1RUFGeffTaLFi1yl0lPT2fq1Km0bdsWm81G69atufTSS9m3b1+9YxMRkaalW8IiItLi5eTkMHbsWK6++mquvfZaYmNjAZg9ezbBwcFMnz6d4OBgvvvuO+69914KCgr4z3/+U2O98+bNo7CwkD/96U8YhsHjjz/O5Zdfzq+//nrC1vFff/2VBQsWcOWVV9KhQwcyMjJ4+eWXOffcc9m6dSvx8fEe5R999FEsFgt33nkn+fn5PP7440yaNImVK1e6y7z++uv86U9/YtiwYdx+++38+uuvXHLJJURGRpKQkHDC65gwYQL3338/6enpxMXFubcvW7aM1NRUrr76ave2Z555hksuuYRJkyZht9t59913ufLKK/n8888ZN25cje/Z4Wobc0FBAa+99hoTJ07kj3/8I4WFhbz++uuMHj2aVatW0a9fP6Kjo3nxxRe5+eabueyyy7j88ssB6Nu373HPf+ONNzJnzhyuuOIK/vrXv7Jy5UpmzpzJtm3b+Pjjjz3K7t69myuuuIIbbriB66+/nlmzZjFlyhQGDhxIr1696nTdW7ZsYfjw4YSGhvK3v/0NX19fXn75ZUaMGMHSpUsZMmQI4LpRMHPmTG688UYGDx5MQUEBq1evZu3atVxwwQUA/P73v2fLli3ceuutJCYmkpmZyaJFi0hOTm6xg/+JiJx2TBERkRZi2rRp5pH/dZ177rkmYL700ktHlS8pKTlq25/+9CczMD
DQLCsrc2+7/vrrzfbt27tf79271wTMqKgoMzc31739k08+MQHzs88+O2GcZWVlpsPh8Ni2d+9e02azmQ8++KB72/fff28CZo8ePczy8nL39meeecYEzE2bNpmmaZp2u92MiYkx+/Xr51HulVdeMQHz3HPPPWE8O3bsMAHzueee89j+5z//2QwODvZ4n458z+x2u9m7d2/z/PPP99jevn178/rrrz/qWr7//vs6x1xZWelRxjRN89ChQ2ZsbKz5hz/8wb0tKyvLBMz77rvvqGu87777PD4b69evNwHzxhtv9Ch35513moD53XffeVwLYP7www/ubZmZmabNZjP/+te/HnWuIx0Z0/jx400/Pz9zz5497m2pqalmSEiIec4557i3JSUlmePGjTtuvYcOHTIB8z//+U+NMYiISPOl7uUiItLi2Ww2pk6detT2gIAA93phYSHZ2dkMHz6ckpIStm/fXmO9EyZMICIiwv16+PDhgKslu6Z4LBbXf7EOh4OcnByCg4Pp1q0ba9euPar81KlTPQYgO/I8q1evJjMzk//7v//zKDdlyhTCwsJqvI6uXbvSr18/5s+f797mcDj44IMPuPjiiz3ep8PXDx06RH5+PsOHDz9m3CdSl5itVqu7jNPpJDc3l8rKSs4444w6n7fal19+CcD06dM9tv/1r38F4IsvvvDY3rNnT/f7DhAdHU23bt1q/F0fyeFw8M033zB+/Hg6duzo3t66dWuuueYali1bRkFBAQDh4eFs2bKFXbt2HbOugIAA/Pz8WLJkCYcOHapTHCIi0nwo6RYRkRavTZs2xxw1e8uWLVx22WWEhYURGhpKdHS0exC2/Pz8Gutt166dx+vqBLymBMjpdPLf//6XLl26YLPZaNWqFdHR0WzcuPGY563pPPv37wegS5cuHuV8fX09ErsTmTBhAj/99BMHDx4EXM9gZ2ZmMmHCBI9yn3/+OWeeeSb+/v5ERka6u3XX5v06XF1jnjNnDn379nU/2xwdHc0XX3xR5/Mefn6LxULnzp09tsfFxREeHu6Or9qRvwNw/R7qmuxmZWVRUlJCt27djtrXo0cPnE4nKSkpADz44IPk5eXRtWtX+vTpw1133cXGjRvd5W02G4899hhfffUVsbGxnHPOOTz++OOkp6fXKSYREfEuJd0iItLiHd46Wy0vL49zzz2XDRs28OCDD/LZZ5+xaNEiHnvsMYBaTRFmtVqPud00zRMe9+9//5vp06dzzjnn8NZbb/H111+zaNEievXqdczz1vc8dTFhwgRM0+T9998H4L333iMsLIwxY8a4y/z4449ccskl+Pv787///Y8vv/ySRYsWcc011zRoLEd66623mDJlCp06deL1119n4cKFLFq0iPPPP/+kp3I7cnC142mK38GRzjnnHPbs2cOsWbPo3bs3r732GgMGDOC1115zl7n99tvZuXMnM2fOxN/fn3vuuYcePXqwbt26RotLREQalpJuERE5JS1ZsoScnBxmz57NbbfdxkUXXcSoUaM8uos3lg8++IDzzjuP119/nauvvpoLL7yQUaNGkZeXV6/62rdvD3BUN+SKigr27t1bqzo6dOjA4MGDmT9/PpWVlXz00UeMHz8em83mLvPhhx/i7+/P119/zR/+8AfGjh3LqFGjGj3mDz74gI4dO/LRRx9x3XXXMXr0aEaNGkVZWZlHudom0NXndzqdR50/IyODvLw8d3wNLTo6msDAQHbs2HHUvu3bt2OxWDwGkYuMjGTq1Km88847pKSk0LdvX/eI7NU6derEX//6V7755hs2b96M3W7nySefbJT4RUSk4SnpFhGRU1J1y+XhLZV2u53//e9/TXLuI1tI33//fXfX7ro644wziI6O5qWXXsJut7u3z549u06J/IQJE1ixYgWzZs0iOzv7qK7lVqsVwzA8plTbt28fCxYsaNSYj/W7WrlyJcuXL/coVz33em2u+Xe/+x0ATz/9tMf2p556CqDOI7HXlt
Vq5cILL+STTz7xmNYrIyODefPmcfbZZxMaGgq4Rt0/XHBwMJ07d3ZPaVZSUnLUjYdOnToREhJy1LRnIiLSfGnKMBE
"text/plain": [
"<Figure size 1000x400 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAA90AAAGGCAYAAABmGOKbAAAAOnRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjEwLjksIGh0dHBzOi8vbWF0cGxvdGxpYi5vcmcvJkbTWQAAAAlwSFlzAAAPYQAAD2EBqD+naQAAqzhJREFUeJzs3Xd8VFX6x/HPzKT3BEIgtNAhlIQuIB1FVBArdkBXXcV1FV1/rrtrQVfXujZWbBTbCrqIioh0kCKQkNBbICEhIQnpJCFt5v7+mGQgJIEQQgp836/XvDK5c+69z0ySyTz3nPMck2EYBiIiIiIiIiJS68z1HYCIiIiIiIjIpUpJt4iIiIiIiMhFoqRbRERERERE5CJR0i0iIiIiIiJykSjpFhEREREREblIlHSLiIiIiIiIXCRKukVEREREREQuEiXdIiIiIiIiIheJkm4RERERERGRi0RJt4iI1KkpU6YQEhJS32HUyIgRIxgxYkSdnzckJIQpU6Y4vl+zZg0mk4k1a9acc9+LEfMLL7yAyWSq1WNWl8lk4oUXXqiXc1eltl+PuLg4TCYTc+fOrbVjNlSN+f1ARKS6lHSLiAhgT2aqc6tOoieNX35+Pi+88IJ+3iIiIhfIqb4DEBGRhuGLL74o9/3nn3/O8uXLK2zv1q3bBZ3nk08+wWazXdAxLnfDhg3j5MmTuLi4XLRz5Ofn8+KLLwJU6Cn/+9//zjPPPHPRzt3Y6PUQEZGzUdItIiIA3H333eW+//3331m+fHmF7WfKz8/Hw8Oj2udxdnauUXxyitlsxs3Nrd7O7+TkhJOTPkKU0eshIiJno+HlIiJSbSNGjKBHjx5ERkYybNgwPDw8ePbZZwH44YcfuO666wgODsbV1ZUOHTrw0ksvYbVayx3jzDmcZfNX33zzTT7++GM6dOiAq6sr/fv3Z+vWreeMKSMjg6eeeoqePXvi5eWFj48P48aNY/v27eXalc2DXrBgAf/85z9p1aoVbm5ujB49mpiYmArHLYvF3d2dAQMG8Ntvv1XrNerRowcjR46ssN1ms9GyZUtuueUWx7Y333yTwYMH06RJE9zd3enbty/ffffdOc9R1Zzu6sRcVFTEc889R9++ffH19cXT05OhQ4eyevVqR5u4uDgCAwMBePHFFx1TC8rmUlc2h7mkpISXXnrJ8fMLCQnh2WefpbCwsFy7kJAQrr/+etavX8+AAQNwc3Ojffv2fP755+d83lWJiopi3Lhx+Pj44OXlxejRo/n999/LtSkuLubFF1+kU6dOuLm50aRJE6688kqWL1/uaJOcnMzUqVNp1aoVrq6utGjRghtuuIG4uLiznr+y18NkMvHoo4+yaNEievTogaurK927d2fp0qU1fp6rVq1i6NCheHp64ufnxw033MDevXsrtFuzZg39+vXDzc2NDh068NFHH1V73vnBgwe5+eabad68OW5ubrRq1Yrbb7+d7Ozscu2+/PJLBgwYgIeHB/7+/gwbNoxly5Y5Hq/u+0FlbDYb77zzDt27d8fNzY2goCAeeughMjMzq/EqiYg0PLosKyIi5yU9PZ1x48Zx++23c/fddxMUFATA3Llz8fLyYvr06Xh5ebFq1Sqee+45cnJyeOONN8553K+//poTJ07w0EMPYTKZeP3117nppps4fPjwWXvHDx8+zKJFi7j11ltp164dKSkpfPTRRwwfPpw9e/YQHBxcrv2//vUvzGYzTz31FNnZ2bz++uvcddddbN682dHms88+46GHHmLw4ME8/vjjHD58mAkTJhAQEEDr1q3P+jwmTZrECy+8QHJyMs2bN3dsX79+PUlJSdx+++2Obe+++y4TJkzgrrvuoqioiG+++YZbb72VxYsXc911153zNTtddWPOycnh008/5Y477uCBBx7gxIkTfPbZZ4wdO5YtW7YQHh5OYGAgH374IQ8//DA33ngjN910EwC9evWq8vx/+MMfmDdvHrfccgtPPv
kkmzdv5tVXX2Xv3r18//335drGxMRwyy23cP/99zN58mRmz57NlClT6Nu3L927dz+v5717926GDh2Kj48PTz/9NM7Oznz00UeMGDGCtWvXMnDgQMCeGL/66qv84Q9/YMCAAeTk5BAREcG2bdu46qqrALj55pvZvXs3f/rTnwgJCSE1NZXly5cTHx9fo2Jf69evZ+HChTzyyCN4e3vz3nvvcfPNNxMfH0+TJk3O61grVqxg3LhxtG/fnhdeeIGTJ0/y/vvvM2TIELZt2+aILyoqimuuuYYWLVrw4osvYrVamTFjhuMiytkUFRUxduxYCgsL+dOf/kTz5s1JTExk8eLFZGVl4evrC9gvxLzwwgsMHjyYGTNm4OLiwubNm1m1ahVXX301cGHvBw899BBz585l6tSpPPbYY8TGxvLBBx8QFRXFhg0bNFpGRBofQ0REpBLTpk0zzvw3MXz4cAMwZs2aVaF9fn5+hW0PPfSQ4eHhYRQUFDi2TZ482Wjbtq3j+9jYWAMwmjRpYmRkZDi2//DDDwZg/PTTT2eNs6CgwLBareW2xcbGGq6ursaMGTMc21avXm0ARrdu3YzCwkLH9nfffdcAjJ07dxqGYRhFRUVGs2bNjPDw8HLtPv74YwMwhg8fftZ49u/fbwDG+++/X277I488Ynh5eZV7nc58zYqKiowePXoYo0aNKre9bdu2xuTJkys8l9WrV593zCUlJeXaGIZhZGZmGkFBQcZ9993n2Hb8+HEDMJ5//vkKz/H5558v97sRHR1tAMYf/vCHcu2eeuopAzBWrVpV7rkAxrp16xzbUlNTDVdXV+PJJ5+scK4znRnTxIkTDRcXF+PQoUOObUlJSYa3t7cxbNgwx7awsDDjuuuuq/K4mZmZBmC88cYb54zhTGe+HmVxuri4GDExMY5t27dvr/R340xlfxNz5sxxbAsPDzeaNWtmpKenlzue2Ww27r33Xse28ePHGx4eHkZiYqJj28GDBw0nJ6cKMZ4pKirKAIxvv/22yjYHDx40zGazceONN1b4u7PZbI77NX0/+O233wzA+Oqrr8rtu3Tp0kq3i4g0BhpeLiIi58XV1ZWpU6dW2O7u7u64f+LECdLS0hg6dCj5+fns27fvnMedNGkS/v7+ju+HDh0K2HuyzxWP2Wz/d2a1WklPT8fLy4suXbqwbdu2Cu2nTp1argDZmeeJiIggNTWVP/7xj+XaTZkyxdHTdzadO3cmPDyc+fPnO7ZZrVa+++47xo8fX+51Ov1+ZmYm2dnZDB06tNK4z+Z8YrZYLI42NpuNjIwMSkpK6Nev33mft8ySJUsAmD59erntTz75JAA///xzue2hoaGO1x0gMDCQLl26nPNnfSar1cqyZcuYOHEi7du3d2xv0aIFd955J+vXrycnJwcAPz8/du/ezcGDBys9lru7Oy4uLqxZs6bWhjGPGTOGDh06OL7v1asXPj4+5/08jx07RnR0NFOmTCEgIKDc8a666irH62+1WlmxYgUTJ04sN8KjY8eOjBs37pznKftd+fXXX8nPz6+0zaJFi7DZbDz33HOOv7sypw9fr+n7wbfffouvry9XXXUVaWlpjlvfvn3x8vIqNw1CRKSxUNItIiLnpWXLlpVWzd69ezc33ngjvr6++Pj4EBgY6CjCduZ80Mq0adOm3PdlCfi5EiCbzca///1vOnXqhKurK02bNiUwMJAdO3ZUet5znefIkSMAdOrUqVw7Z2fncond2UyaNIkNGzaQmJgI2OfYpqamMmnSpHLtFi9ezBVXXIGbmxsBAQGOYd3Veb1Od74xz5s3j169ejnmNgcGBvLzzz+f93lPP7/ZbKZjx47ltjdv3hw/Pz9HfGXO/BmA/edwvsnu8ePHyc/Pp0uXLhUe69atGzabjYSEBABmzJhBVlYWnTt3pmfPnvzlL39hx44djvaurq689tpr/PLLLwQFBTFs2DBef/11kpOTzyum09XW8yx7/ap6nmlpaeTl5ZGams
rJkycr/ByASredqV27dkyfPp1PP/2Upk2bMnbsWGbOnFnu9+LQoUOYzWZCQ0PPeqyavh8cPHiQ7OxsmjVrRmBgYLl
"text/plain": [
"<Figure size 1000x400 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAA90AAAGGCAYAAABmGOKbAAAAOnRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjEwLjksIGh0dHBzOi8vbWF0cGxvdGxpYi5vcmcvJkbTWQAAAAlwSFlzAAAPYQAAD2EBqD+naQAASTRJREFUeJzt3Xl8FOXhx/Hv5toEckDIRZAbBCmnAWJQxCMYLWBBREAqEBFaBRVStYDcKlc9qHJqETyKUqiIVYvSKFJ/jSIgCCqIKIJCAgRJQoAcu8/vD8iSzSaQQIYl4fN+vfYV9plnZp6ZnV32u88zMzZjjBEAAAAAAKh0Pt5uAAAAAAAA1RWhGwAAAAAAixC6AQAAAACwCKEbAAAAAACLELoBAAAAALAIoRsAAAAAAIsQugEAAAAAsAihGwAAAAAAixC6AQAAAACwCKEbAKqJPXv2yGazaenSpa6yKVOmyGazlWt+m82mKVOmVGqbbrjhBt1www2VukxcmoqOtcOHD3u7KbjElPbZBACXE0I3AHjB7bffrho1aignJ6fMOoMGDVJAQIAyMzMvYssq7ptvvtGUKVO0Z88ebzelSti/f7+mTJmiLVu2eLspbo4fP64pU6Zo3bp13m6KpTZs2KAHHnhAcXFx8vf3P+ePUosXL9ZVV12lwMBANW/eXC+88EKp9X755RfdddddqlWrlkJDQ/W73/1OP/zwgxWbAACoYgjdAOAFgwYN0okTJ7Rq1apSpx8/flyrV6/Wrbfeqjp16pz3eiZMmKATJ06c9/zl8c0332jq1Kmlhu4PP/xQH374oaXrr2r279+vqVOnXpKhe+rUqdU+dL///vv629/+JpvNpiZNmpy17qJFi3TffffpN7/5jV544QUlJCTooYce0qxZs9zqHTt2TDfeeKM++eQTjR8/XlOnTtWXX36pbt26XfI/mgEArEfoBgAvuP322xUSEqJly5aVOn316tXKzc3VoEGDLmg9fn5+CgwMvKBlXIiAgAAFBAR4bf1ASffff7+ysrK0ceNGde/evcx6J06c0OOPP64ePXpo5cqVGj58uF599VUNGjRITzzxhH799VdX3fnz52vXrl1699139dhjj2nMmDH68MMPdeDAAT3zzDMXY7MAAJcwQjcAeEFQUJDuuOMOpaam6uDBgx7Tly1bppCQEN1+++06cuSIHnnkEbVp00bBwcEKDQ3Vbbfdpq1bt55zPaWd052Xl6cxY8YoMjLStY6ff/7ZY96ffvpJDzzwgFq0aKGgoCDVqVNH/fr1c+vRXrp0qfr16ydJuvHGG2Wz2WSz2Vy9paWd033w4EENGzZM0dHRCgwMVLt27fTKK6+41Sk6B/Tpp5/Wiy++qKZNm8put6tTp0764osvzrndknT06FGNHj1a9evXl91uV7NmzTRr1iw5nU5JkjFGN954oyIjI91eg/z8fLVp00ZNmzZVbm5uufdF8fWOGTNGjRo1kt1u1xVXXKHBgwfr8OHDWrdunTp16iRJSk5Odu2vs53rWvQafvfdd/r973+vsLAwRUZGauLEiTLGaN++ffrd736n0NBQxcTEeIS8/Px8TZo0SXFxcQoLC1PNmjXVtWtXffzxx277OzIyUpI0depUV7uKn+O/Y8cO3XXXXYqMjFRQUJBatGihxx9/vNTtHzp0qGrVqqWwsDAlJyfr+PHjZ3+xTluxYoXi4uIUFBSkiIgI/f73v9cvv/ziVmfo0KEKDg7WL7/8ot69eys4OFiRkZF65JFH5HA4zrmO6OhoBQUFnbPexx9/rMzMTD3wwANu5SNHjlRubq7ee+89V9nKlSvVqVMn12srSS1bttTNN9+sf/zjH+dclyS9/vrrrm0PDw/XgAEDtG/fPtf0JUuWyGaz6eWXX3abb/r06bLZbHr//fddZU8//bS6dOmiOnXqKCgoSHFxcVq5cqXHOm02m0aNGqUVK1aoVatWCgoKUkJCgrZt2ybpVE9/s2
bNFBgYqBtuuMHjeL/hhhvUunVrbdq0SV26dFFQUJAaN26shQsXlmubd+zYoTvvvFPh4eEKDAxUx44d9c4777jVKSgo0NSpU9W8eXMFBgaqTp06uu6667R27dpyrQMALgkGAOAVH374oZFkXnjhBbfyzMxM4+/vbwYPHmyMMeaLL74wTZs2NWPHjjWLFi0y06ZNM/Xq1TNhYWHml19+cc33448/GklmyZIlrrLJkyebkh/1v//9740kc/fdd5u5c+eaO+64w7Rt29ZIMpMnT3bVW7FihWnXrp2ZNGmSefHFF8348eNN7dq1TcOGDU1ubq4xxpjdu3ebhx56yEgy48ePN6+99pp57bXXTHp6ujHGmG7duplu3bq5lnn8+HFz1VVXGX9/fzNmzBjz/PPPm65duxpJZs6cOR7b0qFDB9OsWTMza9YsM3v2bBMREWGuuOIKk5+ff9Z9m5uba9q2bWvq1Kljxo8fbxYuXGgGDx5sbDabefjhh131fvjhBxMcHGz69OnjKhs7dqyx2Wzmk08+qdC+MMaYnJwc07p1a+Pr62uGDx9uFixYYJ544gnTqVMn8+WXX5r09HQzbdo0I8mMGDHCtb92795d5rYUvYbt27c3AwcONPPnzzc9evQwksyzzz5rWrRoYe6//34zf/58c+211xpJbm0/dOiQqVu3rklJSTELFiwws2fPNi1atDD+/v7myy+/NMYYc+zYMbNgwQIjyfTp08fVrq1btxpjjNm6dasJDQ01derUMePGjTOLFi0yjz32mGnTpo1HOzt06GDuuOMOM3/+fHPfffcZSeaxxx476+tljDFLliwxkkynTp3Mc889Z8aOHWuCgoJMo0aNzK+//uqqN2TIEBMYGGh+85vfmHvvvdcsWLDA9O3b10gy8+fPP+d6ihs5cqTH+6PIk08+aSSZjIwMt/K8vDzj4+NjUlJSjDHGOBwOY7fbzf333++xjAkTJhhJJjs7+6ztePLJJ43NZjP9+/c38+fPN1OnTjUREREe296zZ08TFhZm9u7da4wx5quvvjIBAQFm2LBhbsu74oorzAMPPGDmzp1rnn32WdO5c2cjybz77rtu9SSZtm3bmvr165uZM2eamTNnmrCwMNOgQQMzd+5c06pVK/PMM8+YCRMmmICAAHPjjTe6zd+tWzcTGxtroqKizKhRo8zzzz9vrrvuOiPJLF682FWvtM+m7du3m7CwMNOqVSsza9YsM3fuXHP99dcbm81m3nrrLVe98ePHG5vNZoYPH25eeukl88wzz5iBAweamTNnnnWfAsClhNANAF5SWFho6tataxISEtzKFy5caCSZDz74wBhjzMmTJ43D4XCr8+OPPxq73W6mTZvmVnau0L1lyxYjyTzwwANuy7v77rs9Qvfx48c92pyWlmYkmVdffdVVtmLFCiPJfPzxxx71S4buOXPmGEnm9ddfd5Xl5+ebhIQEExwc7AonRdtSp04dc+TIEVfd1atXG0nmX//6l8e6inviiSdMzZo1zXfffedWPnbsWOPr6+sKLcYYs2jRIlebPvvsM+Pr62tGjx7tNl9598WkSZOMJLfQUMTpdBpjTv2IUvJ1Opui13DEiBGussLCQnPFFVcYm83mFj5+/fVXExQUZIYMGeJWNy8vz22Zv/76q4mOjjb33nuvq+zQoUMex0CR66+/3oSEhJiffvqp1G0q3s7iyzTGmD59+pg6deqcdRvz8/NNVFSUad26tTlx4oSr/N133zWSzKRJk1xlQ4YMMZLcjn1jjOnQoYOJi4s763pKOlvoHjlypPH19S11WmRkpBkwYIAx5sx+K9keY4yZN2+ekWR27NhRZhv27NljfH19zVNPPeVWvm3bNuPn5+dWfuDAARMeHm66d+9u8vLyTIcOHUyDBg1MVlaW27wlj9f8/HzTunVrc9NNN7mVSzJ2u938+OOPrrKi90NMTIzbjwXjxo0zktzqduvWzUgyzzzzjKssLy/PtG/f3kRFRbl+HC
vts+nmm282bdq0MSdPnnSVOZ1O06VLF9O8eXNXWbt27UyPHj3K2n0AUCUwvBwAvMTX11cDBgxQWlqa27DNZcuWKTo
"text/plain": [
"<Figure size 1000x400 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAA90AAAGGCAYAAABmGOKbAAAAOnRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjEwLjksIGh0dHBzOi8vbWF0cGxvdGxpYi5vcmcvJkbTWQAAAAlwSFlzAAAPYQAAD2EBqD+naQAApmdJREFUeJzs3XeYU1X6B/DvTZneGZgCwzB0kCZFigUV7HUF2+Jv7RV27a6uYluVtezqWlDXdcXdFXftuhYsgFhoAkrvMLQpwAzTW8r9/ZGcm5ubm8zNJCGTzPfzPPMoyUzm5s69J+ec9z3vkWRZlkFEREREREREYWeK9gEQERERERERxSsOuomIiIiIiIgihINuIiIiIiIiogjhoJuIiIiIiIgoQjjoJiIiIiIiIooQDrqJiIiIiIiIIoSDbiIiIiIiIqII4aCbiIiIiIiIKEI46CYiIiIiIiKKEA66iYiISFefPn1w1VVXKf/+9ttvIUkSvv32W+Wxq666Cn369Dnqxxat30tERBQsDrqJiIg6aPfu3Zg1axYGDhyIlJQUpKSkYOjQoZg5cybWrVvn9b0PP/wwJElSvsT3PvDAA6irq/P5vsOHD+v+zmHDhuHkk08O23tYunQpHn74YdTU1ITtNcOlrKwMDz/8MH755ZdoHwoREVGHWaJ9AERERLHo008/xaWXXgqLxYIZM2Zg5MiRMJlM2LJlCz744AO8/PLL2L17N4qLi71+7uWXX0ZaWhoaGhrw1Vdf4fHHH8eiRYvw448/QpKko/4+li5dikceeQRXXXUVsrKyvJ7bunUrTKbA8/OvvfYanE5nRI6trKwMjzzyCPr06YNRo0Ydtd9LREQUThx0ExERBWnnzp247LLLUFxcjIULF6KgoMDr+SeffBJz587VHbBOnz4dubm5AICbbroJ06ZNwwcffIDly5dj4sSJR+X4jUpMTGz3e6xW61E4ks7ze4mIiILF9HIiIqIgPfXUU2hsbMQbb7zhM+AGAIvFgt/97ncoKipq97VOPfVUAK5U9XApLS2FJEmYN2+ez3OSJOHhhx8G4Eplv/vuuwEAJSUlSup7aWkpAN813Xq0a6tPPvlkrzR69Zc4nurqatx1110YPnw40tLSkJGRgbPOOgtr165VXufbb7/FuHHjAABXX321z2vorelubGzEnXfeiaKiIiQmJmLQoEF45plnIMuyzzmYNWsWPvroIwwbNgyJiYk45phjsGDBgoDvlYiIqCMY6SYiIgrSp59+iv79+2P8+PEhv9bOnTsBAN26dQv5tYJ10UUXYdu2bXj77bfx7LPPKhH47t27d/g177//flx33XVej/373//Gl19+iR49egAAdu3ahY8++ggXX3wxSkpKUFlZiVdffRWTJ0/Gpk2bUFhYiCFDhuDRRx/Fgw8+iBtuuAEnnngiAGDSpEm6v1eWZZx//vlYvHgxrr32WowaNQpffvkl7r77bhw4cADPPvus1/f/8MMP+OCDD3DLLbcgPT0dzz//PKZNm4a9e/dG5W9BRETxi4NuIiKiINTV1aGsrAwXXnihz3M1NTWw2+3Kv1NTU5GcnOz1PdXV1QCgrOmeO3cu8vLylEHl0TRixAiMHj0ab7/9Ni688MKwVAM/7bTTvP69dOlSLFq0CNdccw3OPvtsAMDw4cOxbds2r/T7//u//8PgwYPx+uuvY/bs2cjLy8NZZ52FBx98EBMnTsQVV1wR8Pd+8sknWLRoER577DHcf//9AICZM2fi4osvxl//+lfMmjUL/fr1U75/8+bN2LRpk/LYKaecgpEjR+Ltt9/GrFmzQj4PREREAtPLiYiIgiAqjaelpfk8d/LJJ6N79+7K10svveTzPYMGDUL37t1RUlKCG2+8Ef3798dnn32GlJSUiB/70VZRUYHp06dj1KhRmDt3rvJ4YmKiMuB2OByoqqpCWloaBg0ahDVr1nTod33++ecwm8343e9+5/X4nXfeCVmW8c
UXX3g9PnXqVK9B+IgRI5CRkYFdu3Z16PcTERH5w0g3ERFRENLT0wG4ItVar776Kurr61FZWek3Mvv+++8jIyMDVqsVvXr18hr4GRWNKufBstvtuOSSS+BwOPDBBx94FWVzOp3461//irlz52L37t1wOBzKcx1N7d6zZw8KCwuVv48wZMgQ5Xm13r17+7xGdnY2jhw50qHfT0RE5A8H3UREREHIzMxEQUEBNmzY4POcWOMtCpHpOemkk5S103qSkpIAAM3NzbrPNzU1Kd/jj79BuXpwG2l33303li1bhm+++Qa9evXyeu6JJ57A7Nmzcc011+CPf/wjcnJyYDKZcNtttx21bcDMZrPu49qia0RERKFiejkREVGQzjnnHOzYsQMrV64M+2uLfb23bt3q81xTUxP27dvns/e3VnZ2NgDXGnM1bbQXiEzU/D//+Q+ee+45PPPMM5g8ebLP8++99x5OOeUUvP7667jssstw+umnY+rUqT7HG8yxFRcXo6ysDPX19V6Pb9myRXmeiIgoGjjoJiIiCtI999yDlJQUXHPNNaisrPR5PpRo6ZQpU5CQkICXX37ZJ+r7t7/9DXa7HWeddVbA18jIyEBubi6+++47r8fV66qF1NRUAL4D9I7asGEDrrvuOlxxxRW49dZbdb/HbDb7nKN3330XBw4c6PCxnX322XA4HHjxxRe9Hn/22WchSVK754yIiChSmF5OREQUpAEDBmD+/Pm4/PLLMWjQIMyYMQMjR46ELMvYvXs35s+fD5PJ5JNWbUSPHj3w4IMP4oEHHsBJJ52E888/HykpKVi6dCnefvttnH766TjvvPPafZ3rrrsOf/rTn3Dddddh7Nix+O6777Bt2zaf7xszZgwA11Zfl112GaxWK8477zxlwBusq6++GoArjf7f//6313OTJk1C3759ce655+LRRx/F1VdfjUmTJmH9+vV466230LdvX6/v79evH7KysvDKK68gPT0dqampGD9+PEpKSnx+73nnnYdTTjkF999/P0pLSzFy5Eh89dVX+Pjjj3Hbbbd1aO08ERFROHDQTURE1AEXXHAB1q9fjz//+c/46quv8I9//AOSJKG4uBjnnHMObrrpJowcObJDr33//fejT58+ePHFF/Hoo4/CbrejpKQEjzzyCH7/+997bbXlz4MPPohDhw7hvffewzvvvIOzzjoLX3zxhbJXtjBu3Dj88Y9/xCuvvIIFCxbA6XRi9+7dHR50Hzp0CI2Njbjhhht8nnvjjTfQt29f/OEPf0BjYyPmz5+P//73vxg9ejQ+++wz3HvvvV7fb7Va8eabb+K+++7DTTfdBLvdjjfeeEN30G0ymfDJJ5/gwQcfxH//+1+88cYb6NOnD55++mnceeedHXovRERE4SDJrBhCREREREREFBFc001EREREREQUIRx0ExEREREREUUIB91EREREREREEcJBNxEREREREVGEcNBNREREREREFCEcdBMRERERERFFCPfpBuB0OlFWVob09HRIkhTtwyEiIiIiIqJOTpZl1NfXo7CwECZTgHi2HEVLliyRzz33XLmgoEAGIH/44YdezzudTnn27Nlyfn6+nJSUJE+ZMkXetm2b1/dUVVXJv/71r+X09HQ5MzNTvuaaa+T6+vqgjmPfvn0yAH7xi1/84he/+MUvfvGLX/ziF7+C+tq3b1/A8WZUI92NjY0YOXIkrrnmGlx00UU+zz/11FN4/vnn8eabb6KkpASzZ8/GGWecgU2bNiEpKQkAMGPGDJSXl+Prr7+GzWbD1VdfjRtuuAHz5883fBzp6ekAgH379iEjIyM8b46IiIiIiIjiVl1dHYqKipTxpD+SLMvyUTqmgCRJwocffogLL7wQACDLMgoLC3HnnXfirrvuAgDU1tYiLy8P8+bNw2WXXYbNmzdj6NCh+OmnnzB27FgAwIIFC3D22Wdj//79KCwsNPS76+rqkJmZidraWg66iYiIiIiIqF1Gx5GdtpDa7t
27UVFRgalTpyqPZWZmYvz48Vi2bBkAYNmyZcjKylIG3AAwdepUmEwmrFixwu9rt7a2oq6uzuuLiIiIiIiIKNw67aC
"text/plain": [
"<Figure size 1000x400 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAA90AAAGGCAYAAABmGOKbAAAAOnRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjEwLjksIGh0dHBzOi8vbWF0cGxvdGxpYi5vcmcvJkbTWQAAAAlwSFlzAAAPYQAAD2EBqD+naQAATY5JREFUeJzt3XtcVHX+x/H3cEdlUFRAApW0VUnTNFNyU0sCC9vcqNRc81pp4K7Sprl5qfZiWW03TWvbVX9bdnF3NZPSWBCsxCzLUktKo7DFEVJhBJXbnN8fNicnyDszyLyejwePbc75zJnPmcdZv3z43iyGYRgCAAAAAADnnY+nEwAAAAAAoKmi6AYAAAAAoIFQdAMAAAAA0EAougEAAAAAaCAU3QAAAAAANBCKbgAAAAAAGghFNwAAAAAADYSiGwAAAACABkLRDQAAAABAA6HoBgAAcJNx48apY8eOnk4DAOBGFN0AAPyMgoICpaWl6Re/+IWaNWumZs2aKS4uTqmpqfrss89cYh988EFZLBbzxxk7e/Zs2e32OnHff/99vZ/ZvXt3DR48uCFvCwAAuJGfpxMAAKAxWrt2rUaMGCE/Pz+NHj1aPXv2lI+Pj3bt2qX//Oc/Wrx4sQoKCtShQweX9y1evFgtWrRQeXm53nnnHf35z39Wdna23n//fVksFg/dDQAA8BSKbgAAfmLPnj0aOXKkOnTooKysLLVr187l/KOPPqrnnntOPj51B4zdcsstatOmjSRp8uTJSklJ0X/+8x9t3rxZ8fHxbsm/qaipqZHD4VBAQICnUwEA4KwxvBwAgJ9YsGCBKioqtHTp0joFtyT5+fnpt7/9rWJiYk55rWuvvVbS8aHq55PFYlFaWppWrlypuLg4BQcHKz4+Xtu3b5ckPf/88+rcubOCgoI0ePBgffPNN3Wu8cEHH2jo0KEKDQ1Vs2bNNGjQIL3//vsuMc7h8F9++aV+85vfKDQ0VG3bttWcOXNkGIb27t2rm266SVarVZGRkXriiSfqfE5xcbEmTpyoiIgIBQUFqWfPnlq+fLlLzDfffCOLxaLHH39cTz31lDp16qTAwEBt2bJFzZs31+9+97s61/3uu+/k6+ur+fPn/+z3lJOTI4vFopycnHo/b9myZeYxm82m8ePHKzo6WoGBgWrXrp1uuummOt/d22+/rauvvlrNmzdXSEiIkpOTtXPnzjqfvXr1anXv3l1BQUHq3r27Vq1a9bN5AgCaLnq6AQD4ibVr16pz587q16/fOV9rz549kqTWrVuf87V+6t1339WaNWuUmpoqSZo/f76GDRumGTNm6LnnntM999yjQ4cOacGCBZowYYKys7PN92ZnZ+v6669Xnz59NG/ePPn4+Gjp0qW69tpr9e677+rKK690+awRI0aoW7dueuSRR5SRkaE//elPCgsL0/PPP69rr71Wjz76qF5++WX9/ve/V9++fTVw4EBJ0tGjRzV48GDt3r1baWlpio2N1cqVKzVu3DiVlpbWKaaXLl2qY8eO6a677lJgYKDat2+vX//613rttdf017/+Vb6+vmbsK6+8IsMwNHr06PPyfaakpGjnzp2aOnWqOnbsqOLiYmVmZqqwsNBc/Oyf//ynxo4dq6SkJD366KM6cuSIFi9erF/+8pf65JNPzLh33nlHKSkpiouL0/z583XgwAGzoAcAeBkDAACYysrKDEnG8OHD65w7dOiQUVJSYv4cOXLEPDdv3jxDkpGfn2+UlJQYBQUFxvPPP28EBgYaERERRkVFhUtcSUlJvZ9/6aWXGoMGDTplnpKMwMBAo6CgwDz2/PPPG5KMyMhIw263m8dnzZplSDJjHQ6HcckllxhJSUmGw+Ew444cOWLExsYa1113XZ37uuuuu8xjNTU1RnR0tGGxWIxHHnnE5fsJDg42xo4dax576qmnDEnGSy+9ZB6rqqoy4uPjjRYtWph5FhQUGJIMq9VqFBcXu9zr+vXrDUnG22
+/7XL8sssuO+V3tWHDBkOSsWHDBpfjzs9bunSpmbsk47HHHvvZax0+fNho2bKlceedd7oct9lsRmhoqMvxXr16Ge3atTNKS0vNY++8844hyejQocNJcwYANC0MLwcA4ATOlcZbtGhR59zgwYPVtm1b82fRokV1Yrp06aK2bdsqNjZWd999tzp37qyMjAw1a9bsvOc6ZMgQl+2nnD3zKSkpCgkJqXP866+/liRt27ZNX331lW6//XYdOHBA33//vb7//ntVVFRoyJAh2rhxoxwOh8tnTZo0yfxvX19fXXHFFTIMQxMnTjSPt2zZUl26dDE/R5LeeustRUZGatSoUeYxf39//fa3v1V5eblyc3NdPiclJUVt27Z1OZaQkKCoqCi9/PLL5rEdO3bos88+029+85vT+7JOITg4WAEBAcrJydGhQ4fqjcnMzFRpaalGjRplfmfff/+9fH191a9fP23YsEGStG/fPm3btk1jx45VaGio+f7rrrtOcXFx5yVfAMCFg+HlAACcwFmslpeX1zn3/PPP6/Dhw9q/f//PFnv//ve/ZbVa5e/vr+joaHXq1OmMczjdVc7bt2/v8tpZ4P10rrnzuLOY/OqrryRJY8eO/dlrl5WVqVWrVif9rKCgIHPRuBOPHzhwwHz97bff6pJLLqmz6Fy3bt3M8yeKjY2tk4uPj49Gjx6txYsX68iRI2rWrJlefvllBQUF6dZbb/3ZezgTgYGBevTRR3XvvfcqIiJC/fv317Bhw3THHXcoMjJS0o/fm3Oe/k9ZrVaXe7rkkkvqxHTp0kUff/zxeckZAHBhoOgGAOAEoaGhateunXbs2FHnnLPHuL5FyZwGDhxYpxA9UVBQkKTjc53rc+TIETPmVE6c33w6xw3DkCSzF/uxxx5Tr1696o39aU9/fdc81eecjeDg4HqP33HHHXrssce0evVqjRo1SitWrNCwYcNcepLr83N/wKitra1zbNq0abrxxhu1evVqrV+/XnPmzNH8+fOVnZ2tyy+/3Pze/vnPf5qF+In8/Pi1CgBQF60DAAA/kZycrBdffFFbtmyps6DYuXLu652fn1+nR/rIkSPau3evEhMTz+tn/pSz991qtSohIaFBP6tDhw767LPP5HA4XHq7d+3aZZ4/Hd27d9fll1+ul19+WdHR0SosLNSzzz57yvc5e+tLS0tdjv+0h92pU6dOuvfee3Xvvffqq6++Uq9evfTEE0/opZdeMr+38PDwk35vznty9oyfKD8//5Q5AwCaFuZ0AwDwEzNmzFCzZs00YcIE7d+/v875c+nJHTJkiAICArR48eI686ZfeOEF1dTU6Prrrz/r65+OPn36qFOnTnr88cfrHUZfUlJy3j7rhhtukM1m02uvvWYeq6mp0bPPPqsWLVpo0KBBp32tMWPG6J133tFTTz2l1q1bn9b31KFDB/n6+mrjxo0ux5977jmX10eOHNGxY8dcjnXq1EkhISGqrKyUJCUlJclqteovf/mLqqur63yW83tr166devXqpeXLl6usrMw8n5mZqc8///z0bhYA0GTQ0w0AwE9ccsklWrFihUaNGqUuXbpo9OjR6tmzpwzDUEFBgVasWCEfH5+z2v4pPDxcc+fO1ezZszVw4ED96le/UrNmzbRp0ya98sorSkxM1I033tgAd/UjHx8fvfjii7r++ut16aWXavz48brooov0v//9Txs2bJDVatWbb755Xj7rrrvu0vPPP69x48Zp69at6tixo/71r3/p/fff11NPPeWy4Nup3H777ZoxY4ZWrVqlKVOmyN/f/5TvCQ0N1a233qpnn31WFotFnTp10tq1a1VcXOwS9+WXX2rIkCG67bbbFBcXJz8/P61atUr79+/XyJEjJR0fGbB48WKNGTNGvXv31siRI9W2bVsVFhYqIyNDAwYM0MKFCyUd374tOTlZv/zlLzVhwgQdPHhQzz77rC699NJ6/9ABAGi6KLoBAKjHTTfdpO3bt+uJJ57QO+
+8o3/84x+yWCzq0KGDkpOTNXnyZPXs2fOsrv3AAw+oY8eOWrhwoR5++GHV1NQoNjZWDz30kGbOnFln0bGGMHjwYOX
"text/plain": [
"<Figure size 1000x400 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"log_history = pd.DataFrame(trainer.state.log_history)\n",
"gpu_df = pd.DataFrame(gpu_history)\n",
"\n",
"log_history.to_csv(run_dir / \"trainer_log_history.csv\", index=False)\n",
"gpu_df.to_csv(run_dir / \"gpu_history.csv\", index=False)\n",
"\n",
"train_logs = log_history[log_history[\"loss\"].notna()].copy()\n",
"eval_logs = log_history[log_history[\"eval_loss\"].notna()].copy()\n",
"\n",
"plt.figure(figsize=(10, 4))\n",
"if not train_logs.empty:\n",
" plt.plot(train_logs[\"step\"], train_logs[\"loss\"])\n",
"if not eval_logs.empty:\n",
" plt.plot(eval_logs[\"step\"], eval_logs[\"eval_loss\"])\n",
"plt.title(\"Train and validation loss\")\n",
"plt.xlabel(\"Step\")\n",
"plt.ylabel(\"Loss\")\n",
"plt.legend([\"train\", \"eval\"])\n",
"plt.tight_layout()\n",
"plt.savefig(run_dir / \"loss_curve.png\", dpi=160)\n",
"plt.show()\n",
"\n",
"plt.figure(figsize=(10, 4))\n",
"if not train_logs.empty:\n",
" plt.semilogy(train_logs[\"step\"], train_logs[\"loss\"])\n",
"if not eval_logs.empty:\n",
" plt.semilogy(eval_logs[\"step\"], eval_logs[\"eval_loss\"])\n",
"plt.title(\"Train and validation loss in log scale\")\n",
"plt.xlabel(\"Step\")\n",
"plt.ylabel(\"Loss\")\n",
"plt.legend([\"train\", \"eval\"])\n",
"plt.tight_layout()\n",
"plt.savefig(run_dir / \"loss_curve_log.png\", dpi=160)\n",
"plt.show()\n",
"\n",
"acc_logs = log_history[log_history[\"eval_exact_match_100\"].notna()].copy() if \"eval_exact_match_100\" in log_history.columns else pd.DataFrame()\n",
"if not acc_logs.empty:\n",
" plt.figure(figsize=(10, 4))\n",
" plt.plot(acc_logs[\"step\"], acc_logs[\"eval_exact_match_100\"])\n",
" plt.title(\"Validation exact match on 100 examples\")\n",
" plt.xlabel(\"Step\")\n",
" plt.ylabel(\"Accuracy\")\n",
" plt.ylim(0.0, 1.0)\n",
" plt.tight_layout()\n",
" plt.savefig(run_dir / \"val_exact_match.png\", dpi=160)\n",
" plt.show()\n",
"\n",
"if not gpu_df.empty and \"gpu_util\" in gpu_df.columns:\n",
" plt.figure(figsize=(10, 4))\n",
" plt.plot(gpu_df[\"step\"], gpu_df[\"gpu_util\"])\n",
" plt.title(\"GPU utilization\")\n",
" plt.xlabel(\"Step\")\n",
" plt.ylabel(\"GPU %\")\n",
" plt.tight_layout()\n",
" plt.savefig(run_dir / \"gpu_utilization.png\", dpi=160)\n",
" plt.show()\n",
"\n",
"if not gpu_df.empty and \"mem_used_mb\" in gpu_df.columns:\n",
" plt.figure(figsize=(10, 4))\n",
" plt.plot(gpu_df[\"step\"], gpu_df[\"mem_used_mb\"])\n",
" plt.title(\"GPU memory used\")\n",
" plt.xlabel(\"Step\")\n",
" plt.ylabel(\"MB\")\n",
" plt.tight_layout()\n",
" plt.savefig(run_dir / \"gpu_memory_used.png\", dpi=160)\n",
" plt.show()"
]
},
{
"cell_type": "markdown",
"id": "78f101f6",
"metadata": {},
"source": [
"## Full test generation"
]
},
{
"cell_type": "code",
"execution_count": 17,
"id": "d17ef25d",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 165/165 [10:16<00:00, 3.73s/it]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"test exact match: 0.0955\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"\n"
]
},
{
"data": {
"text/plain": [
"<Figure size 1000x400 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"merged_model = trainer.model.merge_and_unload()\n",
"merged_model.eval()\n",
"\n",
"test_eval_df, test_exact_match = evaluate_exact_match(\n",
" merged_model,\n",
" tokenizer,\n",
" raw[\"test\"],\n",
" n_examples=None,\n",
" batch_size=8,\n",
" max_new_tokens=512,\n",
")\n",
"\n",
"print(f\"test exact match: {test_exact_match:.4f}\")\n",
"test_eval_df.to_csv(run_dir / \"gsm8k_test_predictions.csv\", index=False)\n",
"\n",
"plt.figure(figsize=(10, 4))\n",
"plt.hist(test_eval_df[\"correct\"].astype(int), bins=2)\n",
"plt.title(\"GSM8K test exact match\")\n",
"plt.xlabel(\"Correct\")\n",
"plt.ylabel(\"Count\")\n",
"plt.tight_layout()\n",
"plt.savefig(run_dir / \"test_exact_match_hist.png\", dpi=160)\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"id": "4a5c83b4",
"metadata": {},
"source": [
"## Publish files"
]
},
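  {
   "cell_type": "markdown",
   "id": "9d41e6b8",
   "metadata": {},
   "source": [
    "The README written below reports perplexity next to each loss. The hardcoded values are consistent with perplexity being the exponential of the mean per-token cross-entropy loss $\\mathcal{L}$:\n",
    "\n",
    "$$\\mathrm{PPL} = \\exp(\\mathcal{L}), \\qquad \\exp(0.3690) \\approx 1.4463, \\qquad \\exp(0.3441) \\approx 1.4107$$"
   ]
  },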
{
"cell_type": "code",
"execution_count": 23,
"id": "ed612f1b",
"metadata": {},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "883914a08e4d4d16b677473e251b5df9",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Writing model shards: 0%| | 0/5 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"---\n",
"library_name: transformers\n",
"base_model: MegaScience/Qwen3-4B-MegaScience\n",
"datasets:\n",
" - openai/gsm8k\n",
"tags:\n",
" - qwen3\n",
" - gsm8k\n",
" - lora\n",
" - reasoning\n",
" - math\n",
" - benchmark\n",
"license: apache-2.0\n",
"---\n",
"\n",
"# Qwen3-4B-MegaScience GSM8K fine-tune\n",
"\n",
"## Overview\n",
"\n",
"`MegaScience/Qwen3-4B-MegaScience` is a 4B Qwen3 modification trained on the MegaScience dataset. We fine-tuned it on GSM8K, which contains common-sense tasks for school students.\n",
"\n",
"## Dataset\n",
"\n",
"The official GSM8K train split is used for training and validation. We modify it by adding <think> and <answer> tokens so the model formats the answer more correctly.\n",
"\n",
"## Training setup\n",
"\n",
"Base model: `MegaScience/Qwen3-4B-MegaScience`. \n",
"Fine-tuning style: SFT with LoRA. \n",
"Precision: bf16. \n",
"Max length: 768. \n",
"Epochs: 1. \n",
"Train / validation split: 95 / 5 from the official train split.\n",
"\n",
"## Results\n",
"\n",
"Validation loss: **0.3690**\n",
"Validation perplexity: **\n",
"- dataset:\n",
" id: openai/gsm8k\n",
" task_id: gsm8k\n",
" config: main\n",
" split: test\n",
" value: 0.095527\n",
" date: \"2026-05-09\"\n",
" notes: \"greedy, no-tools, local eval\"\n",
"\n"
]
}
],
"source": [
"merged_model.save_pretrained(publish_dir, safe_serialization=True, max_shard_size=\"2GB\")\n",
"tokenizer.save_pretrained(publish_dir)\n",
"\n",
"val_acc_100 = np.nan\n",
"if \"eval_exact_match_100\" in log_history.columns:\n",
" values = log_history[\"eval_exact_match_100\"].dropna()\n",
" if not values.empty:\n",
" val_acc_100 = float(values.iloc[-1])\n",
"\n",
"readme_text = f'''---\n",
"library_name: transformers\n",
"base_model: {base_model_id}\n",
"datasets:\n",
" - openai/gsm8k\n",
"tags:\n",
" - qwen3\n",
" - gsm8k\n",
" - lora\n",
" - reasoning\n",
" - math\n",
" - benchmark\n",
"license: apache-2.0\n",
"---\n",
"\n",
"# Qwen3-4B-MegaScience GSM8K fine-tune\n",
"\n",
"## Overview\n",
"\n",
    "`MegaScience/Qwen3-4B-MegaScience` is a 4B-parameter Qwen3 model trained on the MegaScience dataset. We fine-tuned it on GSM8K, a benchmark of grade-school math word problems.\n",
"\n",
"## Dataset\n",
"\n",
    "The official GSM8K train split is used for training and validation. Each reference solution is rewritten so the reasoning sits inside `<think>` tags and the final number inside `<answer>` tags, which gives the model a consistent, machine-parseable answer format.\n",
"\n",
"## Training setup\n",
"\n",
    "Base model: `{base_model_id}`.  \n",
    "Fine-tuning style: SFT with LoRA.  \n",
    "Precision: {\"bf16\" if bf16 else \"fp16\"}.  \n",
    "Max length: 768.  \n",
    "Epochs: 1.  \n",
    "Train / validation split: 95 / 5 from the official train split.\n",
    "\n",
    "## Results\n",
    "\n",
    "Validation loss: **0.3690**  \n",
    "Validation perplexity: **1.4463**  \n",
    "Test loss: **0.3441**  \n",
    "Test perplexity: **1.4107**  \n",
    "Test exact match: **{test_exact_match:.4f}**\n",
"\n",
"## Inference\n",
"\n",
"Use the same chat format as training. The answer goes inside `<answer>`.\n",
"'''\n",
"(publish_dir / \"README.md\").write_text(readme_text, encoding=\"utf-8\")\n",
"\n",
"eval_dir = publish_dir / \".eval_results\"\n",
"eval_dir.mkdir(parents=True, exist_ok=True)\n",
"eval_yaml = f'''- dataset:\n",
" id: openai/gsm8k\n",
" task_id: gsm8k\n",
" config: main\n",
" split: test\n",
" value: {test_exact_match:.6f}\n",
" date: \"{date.today().isoformat()}\"\n",
" notes: \"greedy, no-tools, local eval\"\n",
"'''\n",
"(eval_dir / \"gsm8k.yaml\").write_text(eval_yaml, encoding=\"utf-8\")\n",
"\n",
"training_meta = {\n",
" \"base_model_id\": base_model_id,\n",
" \"hub_model_id\": hub_model_id,\n",
" \"seed\": seed,\n",
" \"bf16\": bf16,\n",
" \"fp16\": fp16,\n",
"}\n",
"(publish_dir / \"training_meta.json\").write_text(json.dumps(training_meta, indent=2, ensure_ascii=False), encoding=\"utf-8\")\n",
"\n",
"print(readme_text[:900])\n",
"print(eval_yaml)"
]
},
{
"cell_type": "markdown",
"id": "15b2569e",
"metadata": {},
"source": [
"## Hub upload\n",
"\n",
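    "The token must already be available in the `HF_TOKEN` environment variable; if it is not set, `login()` asks for one interactively. Uploading the whole `publish_dir` in one commit keeps the merged weights, tokenizer, README, and benchmark metadata together."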
]
},
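  {
   "cell_type": "markdown",
   "id": "c27e5b04",
   "metadata": {},
   "source": [
    "Before calling `upload_folder`, it can help to confirm that `publish_dir` actually holds what the previous cell wrote. A minimal sketch follows. The expected-file list is an assumption: `README.md`, `training_meta.json`, and `.eval_results/gsm8k.yaml` come from the previous cell, while `config.json` and the `*.safetensors` shards are what `save_pretrained` normally writes. `missing_publish_files` is a hypothetical helper.\n",
    "\n",
    "```python\n",
    "from pathlib import Path\n",
    "\n",
    "# Assumed minimal contents of the publish folder.\n",
    "EXPECTED_FILES = [\n",
    "    'config.json',\n",
    "    'README.md',\n",
    "    'training_meta.json',\n",
    "    '.eval_results/gsm8k.yaml',\n",
    "]\n",
    "\n",
    "\n",
    "def missing_publish_files(folder):\n",
    "    # Return the expected entries that are absent, plus a marker if no\n",
    "    # safetensors weight file was written.\n",
    "    folder = Path(folder)\n",
    "    missing = [name for name in EXPECTED_FILES if not (folder / name).is_file()]\n",
    "    if not any(folder.glob('*.safetensors')):\n",
    "        missing.append('*.safetensors')\n",
    "    return missing\n",
    "```\n",
    "\n",
    "For example, `assert not missing_publish_files(publish_dir)` placed before `upload_folder` stops early instead of pushing an incomplete repository."
   ]
  },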
{
"cell_type": "code",
"execution_count": 24,
"id": "6920a3f9",
"metadata": {},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "1e19e9620bdf490b901a93fa213dbb9c",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Processing Files (0 / 0): | | 0.00B / 0.00B "
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "dfc5d43431c44e7eb9d3ef10283bfc8f",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"New Data Upload: | | 0.00B / 0.00B "
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"pymlex/qwen3-4b-gsm8k-think-answer\n"
]
}
],
"source": [
    "# Read the token from the environment rather than hardcoding it in the notebook.\n",
    "hf_token = os.environ.get(\"HF_TOKEN\")\n",
    "if not hf_token:\n",
" login()\n",
"else:\n",
" login(token=hf_token)\n",
"\n",
"api = HfApi()\n",
"api.create_repo(hub_model_id, repo_type=\"model\", exist_ok=True)\n",
"\n",
"upload_folder(\n",
" repo_id=hub_model_id,\n",
" folder_path=str(publish_dir),\n",
" repo_type=\"model\",\n",
" commit_message=\"Qwen3-4B-MegaScience fine-tuned on GSM8K\",\n",
")\n",
"\n",
"print(hub_model_id)"
]
},
{
"cell_type": "markdown",
"id": "5ac569bb",
"metadata": {},
"source": [
"## Inference"
]
},
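  {
   "cell_type": "markdown",
   "id": "6a8d13f2",
   "metadata": {},
   "source": [
    "`solve_question` expects the model to answer in the training layout: reasoning inside `<think>` and the final number inside `<answer>`. The training template is built in an earlier cell, so the helper below is only a hypothetical sketch of that rewrite; `to_think_answer` is an illustrative name, and the newline placement is an assumption inferred from the sample output, not taken from the training code.\n",
    "\n",
    "```python\n",
    "def to_think_answer(gsm8k_answer):\n",
    "    # GSM8K solutions end with '#### <final number>'; everything before\n",
    "    # the last marker is treated as the reasoning.\n",
    "    reasoning, _, final = gsm8k_answer.rpartition('####')\n",
    "    return (\n",
    "        '<think>\\n' + reasoning.strip() + '\\n</think>\\n\\n'\n",
    "        + '<answer>\\n' + final.strip() + '\\n</answer>'\n",
    "    )\n",
    "\n",
    "\n",
    "example = 'She sells 16 - 3 - 4 = 9 eggs.\\nShe makes 9 * 2 = 18 dollars.\\n#### 18'\n",
    "print(to_think_answer(example))\n",
    "```\n",
    "\n",
    "If the input has no `####`, `rpartition` returns empty reasoning and treats the whole string as the answer. Comparing this target with the generation printed by the next cell shows quickly whether the model has learned to emit the `<answer>` tag."
   ]
  },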
{
"cell_type": "code",
"execution_count": 26,
"id": "ee55fbd7",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Janets ducks lay 16 eggs per day. She eats three for breakfast every morning and bakes muffins for her friends every day with four. She sells the remainder at the farmers' market daily for $2 per fresh duck egg. How much in dollars does she make every day at the farmers' market?\n",
"<think>\n",
"She sells 16 - 3 - 4 = <<16-3-4=9>>9 eggs each day at the farmers' market.\n",
"She makes 9 * 2 = $<<9*2=18>>18 every day at the farmers' market.\n",
"</think>\n",
"\n",
"<think>\n",
"18\n",
"<think><|im_end|>\n"
]
}
],
"source": [
"def solve_question(model, tokenizer, question, max_new_tokens=512):\n",
" prompt = build_prompt(question)\n",
" inputs = tokenizer(prompt, return_tensors=\"pt\").to(model.device)\n",
" prompt_len = inputs[\"input_ids\"].shape[-1]\n",
"\n",
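    "    # \"</answer>\" can only serve as a stop id if it is a single token in the vocabulary.\n",
    "    end_answer_id = tokenizer.convert_tokens_to_ids(\"</answer>\")\n",
    "    eos_id = tokenizer.eos_token_id\n",
    "    stop_ids = [eos_id]\n",
    "    if end_answer_id is not None and end_answer_id != tokenizer.unk_token_id:\n",
    "        stop_ids.append(end_answer_id)\n",
    "\n",
    "    with torch.inference_mode():\n",
    "        output = model.generate(\n",
    "            **inputs,\n",
    "            max_new_tokens=max_new_tokens,\n",
    "            do_sample=False,\n",
    "            eos_token_id=stop_ids,\n",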
" pad_token_id=tokenizer.pad_token_id,\n",
" )\n",
"\n",
" new_tokens = output[0][prompt_len:]\n",
" text = tokenizer.decode(new_tokens, skip_special_tokens=False)\n",
"\n",
" if \"</answer>\" in text:\n",
" text = text.split(\"</answer>\")[0] + \"</answer>\"\n",
"\n",
" return text.strip()\n",
"\n",
"sample_question = raw[\"test\"][0][\"question\"]\n",
"sample_output = solve_question(merged_model, tokenizer, sample_question)\n",
"print(sample_question)\n",
"print(sample_output)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.3"
}
},
"nbformat": 4,
"nbformat_minor": 5
}