{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"<div class=\"align-center\">\n",
"\n",
"<a href=\"https://rocm.docs.amd.com/projects/ai-developer-hub/en/latest/index.html\"><img src=\"https://raw.githubusercontent.com/ROCm/gpuaidev/main/docs/images/rocm_logo.png\" alt=\"ROCm AI Developer Hub\" width=\"150\" style=\"display:inline-block; margin-right: 20px;\"></a>\n",
"<a href=\"https://unsloth.ai/\"><img src=\"https://github.com/unslothai/unsloth/raw/main/images/unsloth%20new%20logo.png\" width=\"115\" style=\"display:inline-block;\"></a>\n",
"\n",
"</div>\n",
"\n",
"<div align=\"center\">\n",
"\n",
"<a href=\"https://discord.gg/unsloth\"><img src=\"https://github.com/unslothai/unsloth/raw/main/images/Discord button.png\" width=\"145\"></a>\n",
"<a href=\"https://unsloth.ai/docs/\"><img src=\"https://github.com/unslothai/unsloth/blob/main/images/documentation%20green%20button.png?raw=true\" width=\"125\"></a>\n",
"\n",
"</div>\n",
"\n",
"---\n",
"\n",
"# Pokemon Showdown Agent v6 with Unsloth on AMD ROCm\n",
"\n",
"**Author:** Yueyuan (GoldenGrapeGentleman)\n",
"\n",
"Build a competitive Pokemon Showdown action model from raw replay logs and reproduce the core `v6` workflow with `Qwen/Qwen3-4B`, Unsloth, and AMD ROCm.\n",
"\n",
"### What you will build\n",
"1. **Portable ROCm setup**: Configure a notebook that works on MI300X-class systems and still degrades gracefully on smaller AMD machines.\n",
"2. **Disk-safe data preparation**: Stream and filter raw replay logs from Hugging Face without materializing the full 40GB+ corpus locally.\n",
"3. **Quick validation SFT**: Run a short 50-step alignment loop and watch the model shift from generic chat behavior to valid `move` / `switch` actions.\n",
"4. **Inference sanity check**: Test the tuned model on a real battle prefix and inspect whether it emits a legal next action.\n",
"5. **Export and publish**: Package the merged checkpoint and push your tutorial artifact to Hugging Face.\n",
"\n",
"### Tutorial scope\n",
"This notebook intentionally trains on a small streamed subset so the workflow stays teachable, reproducible, and safe on limited disk. The full production `v6` checkpoint can be published separately after training or loaded from Hugging Face once available.\n",
"\n",
"### Prerequisites\n",
"To install Unsloth on your AMD machine, follow the [AMD ROCm installation guide](https://unsloth.ai/docs/get-started/install/amd). This notebook is adapted from the Unsloth notebook ecosystem and keeps the same [LGPL-3.0](https://github.com/unslothai/notebooks?tab=LGPL-3.0-1-ov-file#readme) notebook license context."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 1. AMD ROCm environment setup\n",
"When running on AMD GPUs such as MI300X or Radeon PRO W7900, Unsloth relies on ROCm plus PyTorch SDPA kernels for strong throughput. The next cell sets conservative defaults before importing training libraries, prefers large external cache mounts when they exist, and falls back to `/tmp` so the tutorial does not break on machines without `/shared-docker`."
]
},
{
"cell_type": "code",
"metadata": {},
"source": [
"import os\n",
"from pathlib import Path\n",
"\n",
"import datasets\n",
"\n",
"cache_candidates = [\n",
" \"/shared-docker/.cache/huggingface\",\n",
" \"/data/huggingface\",\n",
" \"/tmp/pokemon-hf-cache\",\n",
"]\n",
"tmp_candidates = [\n",
" \"/shared-docker/.cache/tmp\",\n",
" \"/data/tmp\",\n",
" \"/tmp/pokemon-tmp\",\n",
"]\n",
"\n",
"cache_root = next((path for path in cache_candidates if os.path.isdir(path)), cache_candidates[-1])\n",
"tmp_root = next((path for path in tmp_candidates if os.path.isdir(path)), tmp_candidates[-1])\n",
"\n",
"# Default to the first visible AMD GPU unless the user already pinned a device.\n",
"os.environ.setdefault(\"HIP_VISIBLE_DEVICES\", \"0\")\n",
"os.environ.setdefault(\"ROCR_VISIBLE_DEVICES\", os.environ[\"HIP_VISIBLE_DEVICES\"])\n",
"\n",
"os.environ[\"HF_HOME\"] = cache_root\n",
"os.environ[\"HF_DATASETS_CACHE\"] = f\"{cache_root}/datasets\"\n",
"os.environ[\"TMPDIR\"] = tmp_root\n",
"\n",
"Path(os.environ[\"HF_HOME\"]).mkdir(parents=True, exist_ok=True)\n",
"Path(os.environ[\"HF_DATASETS_CACHE\"]).mkdir(parents=True, exist_ok=True)\n",
"Path(os.environ[\"TMPDIR\"]).mkdir(parents=True, exist_ok=True)\n",
"\n",
"datasets.config.HF_DATASETS_CACHE = os.environ[\"HF_DATASETS_CACHE\"]\n",
"\n",
"# RDNA3 cards often need the AOTriton flag for Flash Attention via SDPA.\n",
"os.environ.setdefault(\"TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL\", \"1\")\n",
"os.environ.setdefault(\"PYTORCH_HIP_ALLOC_CONF\", \"expandable_segments:False\")\n",
"os.environ.setdefault(\"UNSLOTH_SKIP_TORCHVISION_CHECK\", \"1\")\n",
"\n",
"print(f\"HF_HOME={os.environ['HF_HOME']}\")\n",
"print(f\"HF_DATASETS_CACHE={os.environ['HF_DATASETS_CACHE']}\")\n",
"print(f\"TMPDIR={os.environ['TMPDIR']}\")\n"
],
"execution_count": 10,
"outputs": []
},
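{
"cell_type": "markdown",
"metadata": {},
"source": [
"Optional: check free disk space. This is a small sketch (not part of the original `v6` recipe) that reports how much room is left on the cache and temp directories chosen above, using only the Python standard library. The `MIN_FREE_GB` threshold is an arbitrary assumption; adjust it for your machine."
]
},
{
"cell_type": "code",
"metadata": {},
"source": [
"import shutil\n",
"\n",
"MIN_FREE_GB = 20  # Hypothetical threshold for this sketch; tune it for your setup.\n",
"\n",
"# Report free space for the directories selected in the previous cell.\n",
"for label, path in ((\"HF cache\", cache_root), (\"TMPDIR\", tmp_root)):\n",
"    usage = shutil.disk_usage(path)\n",
"    free_gb = usage.free / 1e9\n",
"    status = \"ok\" if free_gb >= MIN_FREE_GB else \"LOW\"\n",
"    print(f\"{label}: {path} -> {free_gb:,.1f} GB free of {usage.total / 1e9:,.1f} GB [{status}]\")"
],
"execution_count": null,
"outputs": []
},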
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 2. Installation\n",
"If you are not already inside a prepared ROCm image, install Unsloth and the core training stack. For maximum AMD compatibility, this tutorial avoids depending on optional 8-bit optimizer packages and sticks to the standard PyTorch / Transformers / TRL toolchain."
]
},
{
"cell_type": "code",
"metadata": {},
"source": [
"%%capture\n",
"import os\n",
"\n",
"if \"COLAB_\" not in \"\".join(os.environ.keys()):\n",
" !pip install --no-cache-dir --no-deps unsloth unsloth_zoo\n",
" !pip install --no-cache-dir transformers accelerate peft trl datasets psutil sentencepiece protobuf tyro huggingface_hub hf_transfer einops\n",
"else:\n",
" pass # Colab / Kaggle usually need their own setup flow."
],
"execution_count": 11,
"outputs": []
},
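{
"cell_type": "markdown",
"metadata": {},
"source": [
"Optional: confirm that PyTorch can see the AMD GPU before loading the model. On ROCm builds of PyTorch, `torch.version.hip` is set and AMD GPUs are exposed through the `torch.cuda` API, so `torch.cuda.is_available()` should return `True`. This check is a minimal sketch and not part of the original workflow."
]
},
{
"cell_type": "code",
"metadata": {},
"source": [
"import torch\n",
"\n",
"print(f\"PyTorch {torch.__version__}, HIP runtime: {torch.version.hip}\")\n",
"if torch.cuda.is_available():\n",
"    # ROCm devices are listed through the torch.cuda namespace.\n",
"    for idx in range(torch.cuda.device_count()):\n",
"        props = torch.cuda.get_device_properties(idx)\n",
"        print(f\"GPU {idx}: {props.name}, {props.total_memory / 1e9:.1f} GB\")\n",
"else:\n",
"    print(\"No GPU visible: check HIP_VISIBLE_DEVICES and the ROCm driver installation.\")"
],
"execution_count": null,
"outputs": []
},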
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 3. Load the base model\n",
"We load `Qwen/Qwen3-4B` with a configuration tuned for AMD inference stability. The full production `v6` recipe used longer contexts and more data, but this tutorial keeps the setup compact enough for a short reproducible run."
]
},
{
"cell_type": "code",
"metadata": {},
"source": [
"import unsloth\n",
"from unsloth import FastLanguageModel\n",
"import torch\n",
"\n",
"max_seq_length = 2048 # Small enough for a fast tutorial, large enough for real replay prefixes.\n",
"dtype = torch.bfloat16 # Recommended on AMD for stable training and inference.\n",
"load_in_4bit = False # Keep inference stable on ROCm for the public tutorial flow.\n",
"\n",
"model, tokenizer = FastLanguageModel.from_pretrained(\n",
" model_name=\"Qwen/Qwen3-4B\",\n",
" max_seq_length=max_seq_length,\n",
" dtype=dtype,\n",
" load_in_4bit=load_in_4bit,\n",
" attn_implementation=\"sdpa\",\n",
")"
],
"execution_count": 12,
"outputs": [
{
"output_type": "stream",
"text": [
"==((====))== Unsloth 2026.1.4: Fast Qwen3 patching. Transformers: 4.57.6.\n",
" \\\\ /| AMD Instinct MI300X VF. Num GPUs = 1. Max memory: 191.688 GB. Platform: Linux.\n",
"O^O/ \\_/ \\ Torch: 2.10.0+rocm7.1. ROCm Toolkit: 7.1.25424. Triton: 3.6.0\n",
"\\ / Bfloat16 = TRUE. FA [Xformers = None. FA2 = False]\n",
" \"-____-\" Free license: http://github.com/unslothai/unsloth\n",
"Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored!\n"
]
},
{
"output_type": "stream",
"text": [
"Loading checkpoint shards: 100%|██████████| 2/2 [00:05<00:00, 2.56s/it]\n"
]
}
]
},
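{
"cell_type": "markdown",
"metadata": {},
"source": [
"A quick back-of-the-envelope check on the loaded weights: in bf16 every parameter takes 2 bytes, so a model with roughly 4 billion parameters needs about 4e9 × 2 bytes ≈ 8 GB for the weights alone, before activations, gradients, optimizer state, and the KV cache. The sketch below derives the same figure from the model's own parameters and compares it with the memory PyTorch reports as allocated on the GPU."
]
},
{
"cell_type": "code",
"metadata": {},
"source": [
"# Count parameters and the bytes they occupy given their actual dtypes.\n",
"num_params = sum(p.numel() for p in model.parameters())\n",
"weight_bytes = sum(p.numel() * p.element_size() for p in model.parameters())\n",
"\n",
"print(f\"Parameters: {num_params / 1e9:.2f}B\")\n",
"print(f\"Weight memory from dtypes: {weight_bytes / 1e9:.2f} GB\")\n",
"if torch.cuda.is_available():\n",
"    print(f\"GPU memory allocated by PyTorch: {torch.cuda.memory_allocated() / 1e9:.2f} GB\")"
],
"execution_count": null,
"outputs": []
},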
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We now add LoRA adapters so we only need to update 1 to 10% of all parameters!"
]
},
{
"cell_type": "code",
"metadata": {},
"source": [
"# 2. Apply LoRA Adapters\n",
"model = FastLanguageModel.get_peft_model(\n",
" model,\n",
" r = 64, # Choose any number > 0 ! Suggested 8, 16, 32, 64, 128\n",
" target_modules = [\"q_proj\", \"k_proj\", \"v_proj\", \"o_proj\",\n",
" \"gate_proj\", \"up_proj\", \"down_proj\"],\n",
" lora_alpha = 128,\n",
" lora_dropout = 0, # Supports any, but = 0 is optimized\n",
" bias = \"none\", # Supports any, but = \"none\" is optimized\n",
" use_gradient_checkpointing = \"unsloth\", # True or \"unsloth\" for very long context\n",
" random_state = 3407,\n",
" use_rslora = False, # We support rank stabilized LoRA\n",
" loftq_config = None, # And LoftQ\n",
")"
],
"execution_count": 13,
"outputs": []
},
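{
"cell_type": "markdown",
"metadata": {},
"source": [
"To check the 1 to 10% figure for this exact configuration (rank 64 on all seven projection types), count trainable versus total parameters. PEFT models also provide `model.print_trainable_parameters()`; the manual count below makes the arithmetic explicit."
]
},
{
"cell_type": "code",
"metadata": {},
"source": [
"# With bias=\"none\", only the LoRA A/B matrices should have requires_grad=True.\n",
"trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)\n",
"total = sum(p.numel() for p in model.parameters())\n",
"print(f\"Trainable: {trainable:,} / {total:,} ({100 * trainable / total:.2f}%)\")"
],
"execution_count": null,
"outputs": []
},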
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 4. Data preparation from raw replay logs\n",
"The `v6` idea is simple but demanding: feed the model messy real replay logs instead of hand-written state summaries. For the public tutorial we stream `milkkarten/pokemon-showdown-replays-merged`, filter to stronger Gen 9 games, and build a compact subset entirely in memory so the notebook stays disk-safe."
]
},
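{
"cell_type": "markdown",
"metadata": {},
"source": [
"Before building the pipeline, it can help to peek at a single streamed row. The preprocessing below assumes each row exposes `formatid`, `rating`, and `log` fields; this sketch prints them for the first row so you can confirm the schema. Streaming reads data lazily, so only a small part of the corpus is fetched."
]
},
{
"cell_type": "code",
"metadata": {},
"source": [
"from datasets import load_dataset\n",
"\n",
"peek_stream = load_dataset(\"milkkarten/pokemon-showdown-replays-merged\", split=\"train\", streaming=True)\n",
"first_row = next(iter(peek_stream.take(1)))  # IterableDataset.take(n) yields the first n rows\n",
"\n",
"print(\"Columns:\", sorted(first_row.keys()))\n",
"print(\"formatid:\", first_row.get(\"formatid\"), \"| rating:\", first_row.get(\"rating\"))\n",
"print(str(first_row.get(\"log\", \"\"))[:500])"
],
"execution_count": null,
"outputs": []
},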
{
"cell_type": "code",
"metadata": {},
"source": [
"from datasets import Dataset, load_dataset\n",
"from unsloth.chat_templates import get_chat_template\n",
"\n",
"tokenizer = get_chat_template(tokenizer, chat_template=\"qwen3\")\n",
"\n",
"MIN_RATING = 1400\n",
"MAX_TRAIN_SAMPLES = 2000\n",
"SHUFFLE_BUFFER = 10_000\n",
"MAX_LOG_CHARS = 6000\n",
"\n",
"SYSTEM_TEMPLATE = (\n",
" \"You are a Pokemon Showdown battle AI. You play as {side}. \"\n",
" \"Given the battle log, output your next action. \"\n",
" \"Format: move <name> OR switch <name>. \"\n",
" \"Append terastallize if you terastallize this turn.\"\n",
")\n",
"\n",
"def extract_winner_side(log_text):\n",
" winner = None\n",
" players = {}\n",
" for line in log_text.split(\"\\n\"):\n",
" parts = line.split(\"|\")\n",
" if len(parts) >= 4 and parts[1] == \"player\":\n",
" players[parts[2]] = parts[3]\n",
" if len(parts) >= 3 and parts[1] == \"win\":\n",
" winner = parts[2]\n",
" if not winner:\n",
" return None\n",
" for side, name in players.items():\n",
" if name == winner:\n",
" return side\n",
" return None\n",
"\n",
"def format_sample(example):\n",
" log_text = example[\"log\"]\n",
" side = extract_winner_side(log_text)\n",
" if not side:\n",
" return {\"text\": \"\"}\n",
"\n",
" lines = log_text.strip().split(\"\\n\")\n",
" turn_positions = []\n",
" for i, line in enumerate(lines):\n",
" parts = line.split(\"|\")\n",
" if len(parts) >= 3 and parts[1] == \"turn\":\n",
" try:\n",
" turn_positions.append((int(parts[2]), i))\n",
" except ValueError:\n",
" pass\n",
"\n",
" if len(turn_positions) < 2:\n",
" return {\"text\": \"\"}\n",
"\n",
" # Keep prompts compact for the tutorial by using the first actionable turn\n",
" # after the opening whenever possible.\n",
" target_turn_idx = 0 if len(turn_positions) == 2 else 1\n",
" _, turn_line_idx = turn_positions[target_turn_idx]\n",
" next_turn_line = turn_positions[target_turn_idx + 1][1] if target_turn_idx + 1 < len(turn_positions) else len(lines)\n",
"\n",
" action = None\n",
" for j in range(turn_line_idx + 1, next_turn_line):\n",
" parts = lines[j].split(\"|\")\n",
" if len(parts) < 4:\n",
" continue\n",
" if parts[1] == \"move\" and parts[2].startswith(f\"{side}a:\"):\n",
" tera = \"\"\n",
" start_look = max(0, j - 3)\n",
" end_look = min(len(lines), j + 3)\n",
" if any(\"terastallize\" in lines[k] and side in lines[k] for k in range(start_look, end_look)):\n",
" tera = \" terastallize\"\n",
" action = f\"move {parts[3]}{tera}\"\n",
" break\n",
" if parts[1] == \"switch\" and parts[2].startswith(f\"{side}a:\"):\n",
" pokemon = parts[2].split(\": \")[1] if \": \" in parts[2] else parts[2]\n",
" action = f\"switch {pokemon}\"\n",
" break\n",
"\n",
" if not action:\n",
" return {\"text\": \"\"}\n",
"\n",
" log_prefix = \"\\n\".join(lines[:turn_line_idx + 1])\n",
" if len(log_prefix) > MAX_LOG_CHARS:\n",
" return {\"text\": \"\"}\n",
"\n",
" messages = [\n",
" {\"role\": \"system\", \"content\": SYSTEM_TEMPLATE.format(side=side)},\n",
" {\"role\": \"user\", \"content\": log_prefix},\n",
" {\"role\": \"assistant\", \"content\": action},\n",
" ]\n",
" text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=False)\n",
" return {\"text\": text}\n",
"\n",
"print(\"Streaming, filtering, and formatting replay logs...\")\n",
"dataset = load_dataset(\"milkkarten/pokemon-showdown-replays-merged\", split=\"train\", streaming=True)\n",
"dataset = dataset.shuffle(seed=3407, buffer_size=SHUFFLE_BUFFER)\n",
"dataset = dataset.filter(lambda x: \"gen9\" in str(x.get(\"formatid\") or \"\") and (x.get(\"rating\") or 0) >= MIN_RATING)\n",
"\n",
"train_samples = []\n",
"scanned = 0\n",
"for row in dataset:\n",
" scanned += 1\n",
" formatted = format_sample(row)\n",
" if formatted[\"text\"]:\n",
" train_samples.append(formatted)\n",
" if len(train_samples) >= MAX_TRAIN_SAMPLES:\n",
" break\n",
"\n",
"train_dataset = Dataset.from_list(train_samples).shuffle(seed=3407)\n",
"print(f\"Collected {len(train_dataset)} training examples after scanning {scanned} streamed replays.\")\n",
"print(train_dataset[0][\"text\"][:800])"
],
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"text": [
"Downloading and parsing dataset...\n"
]
},
{
"output_type": "stream",
"text": [
"Repo card metadata block was not found. Setting CardData to empty.\n",
"[huggingface_hub.repocard|WARNING]Repo card metadata block was not found. Setting CardData to empty.\n"
]
},
{
"output_type": "stream",
"text": [
"Downloading data: 100%|██████████| 583/583 [00:00<00:00, 3215.89files/s]\n",
"Generating train split: 96%|█████████▌| 27837802/29057184 [08:06<00:21, 57181.98 examples/s]\n"
]
},
{
"output_type": "error",
"ename": "DatasetGenerationError",
"evalue": "An error occurred while generating the dataset",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mOSError\u001b[0m Traceback (most recent call last)",
"File \u001b[0;32m/opt/venv/lib/python3.10/site-packages/datasets/builder.py:1834\u001b[0m, in \u001b[0;36mArrowBasedBuilder._prepare_split_single\u001b[0;34m(self, gen_kwargs, fpath, file_format, max_shard_size, job_id)\u001b[0m\n\u001b[1;32m 1833\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m-> 1834\u001b[0m \u001b[43mwriter\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mwrite_table\u001b[49m\u001b[43m(\u001b[49m\u001b[43mtable\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1835\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m CastError \u001b[38;5;28;01mas\u001b[39;00m cast_error:\n",
"File \u001b[0;32m/opt/venv/lib/python3.10/site-packages/datasets/arrow_writer.py:719\u001b[0m, in \u001b[0;36mArrowWriter.write_table\u001b[0;34m(self, pa_table, writer_batch_size)\u001b[0m\n\u001b[1;32m 718\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_num_examples \u001b[38;5;241m+\u001b[39m\u001b[38;5;241m=\u001b[39m pa_table\u001b[38;5;241m.\u001b[39mnum_rows\n\u001b[0;32m--> 719\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mpa_writer\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mwrite_table\u001b[49m\u001b[43m(\u001b[49m\u001b[43mpa_table\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mwriter_batch_size\u001b[49m\u001b[43m)\u001b[49m\n",
"File \u001b[0;32m/opt/venv/lib/python3.10/site-packages/pyarrow/ipc.pxi:616\u001b[0m, in \u001b[0;36mpyarrow.lib._CRecordBatchWriter.write_table\u001b[0;34m()\u001b[0m\n",
"File \u001b[0;32m/opt/venv/lib/python3.10/site-packages/pyarrow/error.pxi:89\u001b[0m, in \u001b[0;36mpyarrow.lib.check_status\u001b[0;34m()\u001b[0m\n",
"File \u001b[0;32m/opt/venv/lib/python3.10/site-packages/fsspec/implementations/local.py:469\u001b[0m, in \u001b[0;36mLocalFileOpener.write\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 468\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;21mwrite\u001b[39m(\u001b[38;5;28mself\u001b[39m, \u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs):\n\u001b[0;32m--> 469\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mf\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mwrite\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
"\u001b[0;31mOSError\u001b[0m: [Errno 28] No space left on device",
"\nDuring handling of the above exception, another exception occurred:\n",
"\u001b[0;31mOSError\u001b[0m Traceback (most recent call last)",
"File \u001b[0;32m/opt/venv/lib/python3.10/site-packages/datasets/builder.py:1850\u001b[0m, in \u001b[0;36mArrowBasedBuilder._prepare_split_single\u001b[0;34m(self, gen_kwargs, fpath, file_format, max_shard_size, job_id)\u001b[0m\n\u001b[1;32m 1849\u001b[0m num_shards \u001b[38;5;241m=\u001b[39m shard_id \u001b[38;5;241m+\u001b[39m \u001b[38;5;241m1\u001b[39m\n\u001b[0;32m-> 1850\u001b[0m num_examples, num_bytes \u001b[38;5;241m=\u001b[39m \u001b[43mwriter\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mfinalize\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1851\u001b[0m writer\u001b[38;5;241m.\u001b[39mclose()\n",
"File \u001b[0;32m/opt/venv/lib/python3.10/site-packages/datasets/arrow_writer.py:736\u001b[0m, in \u001b[0;36mArrowWriter.finalize\u001b[0;34m(self, close_stream)\u001b[0m\n\u001b[1;32m 735\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m close_stream:\n\u001b[0;32m--> 736\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mstream\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mclose\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 737\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n",
"File \u001b[0;32m/opt/venv/lib/python3.10/site-packages/fsspec/implementations/local.py:487\u001b[0m, in \u001b[0;36mLocalFileOpener.close\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 486\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;21mclose\u001b[39m(\u001b[38;5;28mself\u001b[39m):\n\u001b[0;32m--> 487\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mf\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mclose\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n",
"\u001b[0;31mOSError\u001b[0m: [Errno 28] No space left on device",
"\nThe above exception was the direct cause of the following exception:\n",
"\u001b[0;31mDatasetGenerationError\u001b[0m Traceback (most recent call last)",
"Cell \u001b[0;32mIn[14], line 82\u001b[0m\n\u001b[1;32m 79\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m {\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mtext\u001b[39m\u001b[38;5;124m\"\u001b[39m: text}\n\u001b[1;32m 81\u001b[0m \u001b[38;5;28mprint\u001b[39m(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mDownloading and parsing dataset...\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[0;32m---> 82\u001b[0m dataset \u001b[38;5;241m=\u001b[39m \u001b[43mload_dataset\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[38;5;124;43mmilkkarten/pokemon-showdown-replays-merged\u001b[39;49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43msplit\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[38;5;124;43mtrain\u001b[39;49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[43m)\u001b[49m\n\u001b[1;32m 84\u001b[0m \u001b[38;5;66;03m# Filter a small subset for demonstration purposes\u001b[39;00m\n\u001b[1;32m 85\u001b[0m dataset \u001b[38;5;241m=\u001b[39m dataset\u001b[38;5;241m.\u001b[39mfilter(\u001b[38;5;28;01mlambda\u001b[39;00m x: \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mgen9\u001b[39m\u001b[38;5;124m'\u001b[39m \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mstr\u001b[39m(x\u001b[38;5;241m.\u001b[39mget(\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mformatid\u001b[39m\u001b[38;5;124m'\u001b[39m)) \u001b[38;5;129;01mand\u001b[39;00m (x\u001b[38;5;241m.\u001b[39mget(\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mrating\u001b[39m\u001b[38;5;124m'\u001b[39m) \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;241m0\u001b[39m) \u001b[38;5;241m>\u001b[39m\u001b[38;5;241m=\u001b[39m \u001b[38;5;241m1400\u001b[39m)\n",
"File \u001b[0;32m/opt/venv/lib/python3.10/site-packages/datasets/load.py:1417\u001b[0m, in \u001b[0;36mload_dataset\u001b[0;34m(path, name, data_dir, data_files, split, cache_dir, features, download_config, download_mode, verification_mode, keep_in_memory, save_infos, revision, token, streaming, num_proc, storage_options, **config_kwargs)\u001b[0m\n\u001b[1;32m 1414\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m builder_instance\u001b[38;5;241m.\u001b[39mas_streaming_dataset(split\u001b[38;5;241m=\u001b[39msplit)\n\u001b[1;32m 1416\u001b[0m \u001b[38;5;66;03m# Download and prepare data\u001b[39;00m\n\u001b[0;32m-> 1417\u001b[0m \u001b[43mbuilder_instance\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mdownload_and_prepare\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 1418\u001b[0m \u001b[43m \u001b[49m\u001b[43mdownload_config\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mdownload_config\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1419\u001b[0m \u001b[43m \u001b[49m\u001b[43mdownload_mode\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mdownload_mode\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1420\u001b[0m \u001b[43m \u001b[49m\u001b[43mverification_mode\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mverification_mode\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1421\u001b[0m \u001b[43m \u001b[49m\u001b[43mnum_proc\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mnum_proc\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1422\u001b[0m \u001b[43m \u001b[49m\u001b[43mstorage_options\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mstorage_options\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1423\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1425\u001b[0m \u001b[38;5;66;03m# Build dataset for splits\u001b[39;00m\n\u001b[1;32m 1426\u001b[0m keep_in_memory \u001b[38;5;241m=\u001b[39m (\n\u001b[1;32m 1427\u001b[0m keep_in_memory \u001b[38;5;28;01mif\u001b[39;00m keep_in_memory \u001b[38;5;129;01mis\u001b[39;00m 
\u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m \u001b[38;5;28;01melse\u001b[39;00m is_small_dataset(builder_instance\u001b[38;5;241m.\u001b[39minfo\u001b[38;5;241m.\u001b[39mdataset_size)\n\u001b[1;32m 1428\u001b[0m )\n",
"File \u001b[0;32m/opt/venv/lib/python3.10/site-packages/datasets/builder.py:897\u001b[0m, in \u001b[0;36mDatasetBuilder.download_and_prepare\u001b[0;34m(self, output_dir, download_config, download_mode, verification_mode, dl_manager, base_path, file_format, max_shard_size, num_proc, storage_options, **download_and_prepare_kwargs)\u001b[0m\n\u001b[1;32m 895\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m num_proc \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[1;32m 896\u001b[0m prepare_split_kwargs[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mnum_proc\u001b[39m\u001b[38;5;124m\"\u001b[39m] \u001b[38;5;241m=\u001b[39m num_proc\n\u001b[0;32m--> 897\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_download_and_prepare\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 898\u001b[0m \u001b[43m \u001b[49m\u001b[43mdl_manager\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mdl_manager\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 899\u001b[0m \u001b[43m \u001b[49m\u001b[43mverification_mode\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mverification_mode\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 900\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mprepare_split_kwargs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 901\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mdownload_and_prepare_kwargs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 902\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 903\u001b[0m \u001b[38;5;66;03m# Sync info\u001b[39;00m\n\u001b[1;32m 904\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39minfo\u001b[38;5;241m.\u001b[39mdataset_size \u001b[38;5;241m=\u001b[39m \u001b[38;5;28msum\u001b[39m(split\u001b[38;5;241m.\u001b[39mnum_bytes \u001b[38;5;28;01mfor\u001b[39;00m split \u001b[38;5;129;01min\u001b[39;00m 
\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39minfo\u001b[38;5;241m.\u001b[39msplits\u001b[38;5;241m.\u001b[39mvalues())\n",
"File \u001b[0;32m/opt/venv/lib/python3.10/site-packages/datasets/builder.py:973\u001b[0m, in \u001b[0;36mDatasetBuilder._download_and_prepare\u001b[0;34m(self, dl_manager, verification_mode, **prepare_split_kwargs)\u001b[0m\n\u001b[1;32m 969\u001b[0m split_dict\u001b[38;5;241m.\u001b[39madd(split_generator\u001b[38;5;241m.\u001b[39msplit_info)\n\u001b[1;32m 971\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m 972\u001b[0m \u001b[38;5;66;03m# Prepare split will record examples associated to the split\u001b[39;00m\n\u001b[0;32m--> 973\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_prepare_split\u001b[49m\u001b[43m(\u001b[49m\u001b[43msplit_generator\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mprepare_split_kwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 974\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m \u001b[38;5;167;01mOSError\u001b[39;00m \u001b[38;5;28;01mas\u001b[39;00m e:\n\u001b[1;32m 975\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mOSError\u001b[39;00m(\n\u001b[1;32m 976\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mCannot find data file. \u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 977\u001b[0m \u001b[38;5;241m+\u001b[39m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mmanual_download_instructions \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[1;32m 978\u001b[0m \u001b[38;5;241m+\u001b[39m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;130;01m\\n\u001b[39;00m\u001b[38;5;124mOriginal error:\u001b[39m\u001b[38;5;130;01m\\n\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 979\u001b[0m \u001b[38;5;241m+\u001b[39m \u001b[38;5;28mstr\u001b[39m(e)\n\u001b[1;32m 980\u001b[0m ) \u001b[38;5;28;01mfrom\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mNone\u001b[39;00m\n",
"File \u001b[0;32m/opt/venv/lib/python3.10/site-packages/datasets/builder.py:1705\u001b[0m, in \u001b[0;36mArrowBasedBuilder._prepare_split\u001b[0;34m(self, split_generator, file_format, num_proc, max_shard_size)\u001b[0m\n\u001b[1;32m 1703\u001b[0m job_id \u001b[38;5;241m=\u001b[39m \u001b[38;5;241m0\u001b[39m\n\u001b[1;32m 1704\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m pbar:\n\u001b[0;32m-> 1705\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m job_id, done, content \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_prepare_split_single(\n\u001b[1;32m 1706\u001b[0m gen_kwargs\u001b[38;5;241m=\u001b[39mgen_kwargs, job_id\u001b[38;5;241m=\u001b[39mjob_id, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39m_prepare_split_args\n\u001b[1;32m 1707\u001b[0m ):\n\u001b[1;32m 1708\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m done:\n\u001b[1;32m 1709\u001b[0m result \u001b[38;5;241m=\u001b[39m content\n",
"File \u001b[0;32m/opt/venv/lib/python3.10/site-packages/datasets/builder.py:1861\u001b[0m, in \u001b[0;36mArrowBasedBuilder._prepare_split_single\u001b[0;34m(self, gen_kwargs, fpath, file_format, max_shard_size, job_id)\u001b[0m\n\u001b[1;32m 1859\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(e, DatasetGenerationError):\n\u001b[1;32m 1860\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m\n\u001b[0;32m-> 1861\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m DatasetGenerationError(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mAn error occurred while generating the dataset\u001b[39m\u001b[38;5;124m\"\u001b[39m) \u001b[38;5;28;01mfrom\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;21;01me\u001b[39;00m\n\u001b[1;32m 1863\u001b[0m \u001b[38;5;28;01myield\u001b[39;00m job_id, \u001b[38;5;28;01mTrue\u001b[39;00m, (total_num_examples, total_num_bytes, writer\u001b[38;5;241m.\u001b[39m_features, num_shards, shard_lengths)\n",
"\u001b[0;31mDatasetGenerationError\u001b[0m: An error occurred while generating the dataset"
]
}
]
},
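{
"cell_type": "markdown",
"metadata": {},
"source": [
"To see exactly what `format_sample` extracts, run it on a tiny hand-written log. The toy replay below is made up for illustration (players, Pokemon, and moves are hypothetical): Bob (p2) wins and the log has two turns, so the label should be p2's first action on turn 1, `move Brave Bird`. The second part of the cell is a rough tally of the action types in the collected `train_samples`, reading the text after the last assistant header."
]
},
{
"cell_type": "code",
"metadata": {},
"source": [
"from collections import Counter\n",
"\n",
"# Hypothetical two-turn replay in Showdown's |-separated protocol.\n",
"toy_log = \"\\n\".join([\n",
"    \"|player|p1|Alice|1|1500\",\n",
"    \"|player|p2|Bob|2|1500\",\n",
"    \"|gen|9\",\n",
"    \"|start\",\n",
"    \"|switch|p1a: Garchomp|Garchomp, M|100/100\",\n",
"    \"|switch|p2a: Corviknight|Corviknight, M|100/100\",\n",
"    \"|turn|1\",\n",
"    \"|move|p1a: Garchomp|Earthquake|p2a: Corviknight\",\n",
"    \"|-immune|p2a: Corviknight\",\n",
"    \"|move|p2a: Corviknight|Brave Bird|p1a: Garchomp\",\n",
"    \"|turn|2\",\n",
"    \"|win|Bob\",\n",
"])\n",
"print(\"Winner side:\", extract_winner_side(toy_log))\n",
"toy_text = format_sample({\"log\": toy_log})[\"text\"]\n",
"print(\"Assistant target:\", toy_text.split(\"<|im_start|>assistant\")[-1].strip())\n",
"\n",
"def action_kind(example_text):\n",
"    # Rough heuristic: inspect only the assistant reply at the end of the chat text.\n",
"    reply = example_text.split(\"<|im_start|>assistant\")[-1]\n",
"    if \"switch \" in reply:\n",
"        return \"switch\"\n",
"    if \"terastallize\" in reply:\n",
"        return \"move + terastallize\"\n",
"    return \"move\"\n",
"\n",
"print(Counter(action_kind(sample[\"text\"]) for sample in train_samples))"
],
"execution_count": null,
"outputs": []
},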
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 5. Quick validation SFT\n",
"Now we run a deliberately short 50-step supervised fine-tuning loop. The point is not to maximize ladder strength in one notebook session, but to make *Agentic Alignment* visible: even a tiny run is often enough to push the model toward strict action formatting instead of free-form commentary."
]
},
{
"cell_type": "code",
"metadata": {},
"source": [
"from trl import SFTConfig, SFTTrainer\n",
"from unsloth import is_bfloat16_supported\n",
"\n",
"trainer = SFTTrainer(\n",
" model=model,\n",
" tokenizer=tokenizer,\n",
" train_dataset=train_dataset,\n",
" dataset_text_field=\"text\",\n",
" max_seq_length=max_seq_length,\n",
" packing=False,\n",
" args=SFTConfig(\n",
" per_device_train_batch_size=2,\n",
" gradient_accumulation_steps=4,\n",
" warmup_steps=5,\n",
" max_steps=50,\n",
" learning_rate=2e-4,\n",
" fp16=not is_bfloat16_supported(),\n",
" bf16=is_bfloat16_supported(),\n",
" logging_steps=10,\n",
" optim=\"adamw_torch\",\n",
" weight_decay=0.01,\n",
" lr_scheduler_type=\"linear\",\n",
" seed=3407,\n",
" output_dir=\"model_output_v6_demo\",\n",
" save_steps=50,\n",
" save_total_limit=2,\n",
" report_to=\"none\",\n",
" ),\n",
")\n",
"\n",
"trainer_stats = trainer.train()\n",
"\n",
"# Optional: export a standalone merged checkpoint for sharing or Hub upload.\n",
"# export_dir = \"model_output_v6_demo/merged_model\"\n",
"# model.save_pretrained_merged(export_dir, tokenizer, save_method=\"merged_16bit\")"
],
"execution_count": null,
"outputs": []
},
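{
"cell_type": "markdown",
"metadata": {},
"source": [
"How much data does the 50-step run actually see? Each optimizer step consumes `per_device_train_batch_size × gradient_accumulation_steps × number of GPUs` examples, so on one GPU that is 2 × 4 = 8 examples per step and 8 × 50 = 400 examples in total, about a fifth of the data if all 2,000 samples were collected. The sketch below derives the same numbers from the trainer's own arguments and prints the metrics returned by `trainer.train()`."
]
},
{
"cell_type": "code",
"metadata": {},
"source": [
"train_args = trainer.args\n",
"# world_size is the number of training processes (1 on a single GPU).\n",
"examples_per_step = train_args.per_device_train_batch_size * train_args.gradient_accumulation_steps * train_args.world_size\n",
"examples_seen = examples_per_step * train_args.max_steps\n",
"\n",
"print(f\"Examples per optimizer step: {examples_per_step}\")\n",
"print(f\"Examples seen in {train_args.max_steps} steps: {examples_seen} \"\n",
"      f\"({examples_seen / len(train_dataset):.2f} epochs of {len(train_dataset)} samples)\")\n",
"print(trainer_stats.metrics)"
],
"execution_count": null,
"outputs": []
},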
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 6. Inference sanity check\n",
"Use the same prompt once before training and once after training if you want to see the alignment jump clearly. A good post-SFT answer should be short, action-oriented, and formatted as a legal `move` or `switch` command."
]
},
{
"cell_type": "code",
"metadata": {},
"source": [
"FastLanguageModel.for_inference(model)\n",
"\n",
"sys_msg = \"You are a Pokemon Showdown battle AI. You play as p2. Given the battle log, output your next action. Format: move <name> OR switch <name>. Append terastallize if you terastallize this turn.\"\n",
"user_msg = '''|player|p1|Player1|266|1500\n",
"|player|p2|Player2|1|1500\n",
"|teamsize|p1|6\n",
"|teamsize|p2|6\n",
"|gen|9\n",
"|tier|[Gen 9] OU\n",
"|\n",
"|start\n",
"|switch|p1a: Garchomp|Garchomp, M|100/100\n",
"|switch|p2a: Corviknight|Corviknight, M|100/100\n",
"|turn|1\n",
"|move|p1a: Garchomp|Earthquake|p2a: Corviknight\n",
"|-immune|p2a: Corviknight\n",
"|turn|2'''\n",
"\n",
"messages = [\n",
" {\"role\": \"system\", \"content\": sys_msg},\n",
" {\"role\": \"user\", \"content\": user_msg},\n",
"]\n",
"\n",
"text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)\n",
"inputs = tokenizer(text, return_tensors=\"pt\").to(\"cuda\")\n",
"\n",
"with torch.no_grad():\n",
" outputs = model.generate(\n",
" **inputs,\n",
" max_new_tokens=64,\n",
" temperature=0.1,\n",
" do_sample=False,\n",
" pad_token_id=tokenizer.eos_token_id,\n",
" )\n",
"\n",
"full_response = tokenizer.decode(outputs[0], skip_special_tokens=False)\n",
"if \"<|im_start|>assistant\\n\" in full_response:\n",
" response = full_response.split(\"<|im_start|>assistant\\n\")[-1]\n",
"else:\n",
" response = tokenizer.decode(outputs[0][inputs.input_ids.shape[-1]:], skip_special_tokens=True)\n",
"response = response.replace(\"<|im_end|>\", \"\").strip()\n",
"\n",
"print(\"\\n--- AI AGENT PREDICTION ---\")\n",
"print(response)"
],
"execution_count": null,
"outputs": []
},
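{
"cell_type": "markdown",
"metadata": {},
"source": [
"To check the prediction automatically, a small parser can confirm that the reply has the requested `move <name>` / `switch <name>` shape and, as a rough legality proxy, that a `switch` does not target the Pokemon already on the field. The helpers below are hypothetical and only know what the log reveals; a real legality check would need Showdown's battle request, which lists the available moves and bench."
]
},
{
"cell_type": "code",
"metadata": {},
"source": [
"import re\n",
"\n",
"def parse_action(reply):\n",
"    \"\"\"Hypothetical helper: return (kind, target, terastallize) or None if malformed.\"\"\"\n",
"    match = re.fullmatch(r\"(move|switch)\\s+(.+?)(\\s+terastallize)?\", reply.strip(), re.IGNORECASE)\n",
"    if match is None:\n",
"        return None\n",
"    return match.group(1).lower(), match.group(2).strip(), match.group(3) is not None\n",
"\n",
"def active_pokemon(log_text, side):\n",
"    \"\"\"Hypothetical helper: nickname of the last Pokemon that entered the field for `side`.\"\"\"\n",
"    active = None\n",
"    for line in log_text.split(\"\\n\"):\n",
"        parts = line.split(\"|\")\n",
"        if len(parts) >= 3 and parts[1] in (\"switch\", \"drag\") and parts[2].startswith(f\"{side}a: \"):\n",
"            active = parts[2].split(\": \", 1)[1]\n",
"    return active\n",
"\n",
"parsed = parse_action(response)\n",
"if parsed is None:\n",
"    print(\"Not a well-formed action:\", repr(response))\n",
"else:\n",
"    kind, target, tera = parsed\n",
"    # The sample prompt above plays as p2.\n",
"    if kind == \"switch\" and target.lower() == (active_pokemon(user_msg, \"p2\") or \"\").lower():\n",
"        print(f\"Illegal: {target} is already active.\")\n",
"    else:\n",
"        print(f\"Well-formed action: {kind} {target}\" + (\" (terastallize)\" if tera else \"\"))"
],
"execution_count": null,
"outputs": []
},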
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 7. Export and publish to Hugging Face\n",
"After the 50-step tutorial run, you can optionally merge the adapter into standalone weights and publish the result.\n",
"\n",
"```python\n",
"export_dir = \"model_output_v6_demo/merged_model\"\n",
"model.save_pretrained_merged(export_dir, tokenizer, save_method=\"merged_16bit\")\n",
"```\n",
"\n",
"Then upload the merged folder and notebook:\n",
"\n",
"```bash\n",
"hf auth whoami\n",
"hf upload-large-folder your-username/pokemon-showdown-agent-v6-demo model_output_v6_demo/merged_model\n",
"hf upload your-username/pokemon-showdown-agent-v6-demo pokemon_agent_demo_notebook_v2.ipynb pokemon_agent_demo_notebook_v6.ipynb\n",
"```\n",
"\n",
"The full production checkpoint is now published at `GoldenGrapeGentleman1/pokemon-showdown-agent-v6`, while this notebook remains the lighter tutorial flow for fast reproduction."
]
},
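{
"cell_type": "markdown",
"metadata": {},
"source": [
"If you prefer to stay in Python, `huggingface_hub` provides equivalents of the CLI commands above. This is a minimal sketch that assumes you are already logged in (for example with `hf auth login`), that the merged folder from `save_pretrained_merged` exists, and that this notebook is saved as `pokemon_showdown_agent_v6_tutorial.ipynb`; the repo id is a placeholder. Nothing is uploaded until you set `PUSH_TO_HUB = True`."
]
},
{
"cell_type": "code",
"metadata": {},
"source": [
"from huggingface_hub import HfApi\n",
"\n",
"PUSH_TO_HUB = False  # Flip to True once the merged checkpoint exists and you are logged in.\n",
"repo_id = \"your-username/pokemon-showdown-agent-v6-demo\"  # Placeholder repo id\n",
"export_dir = \"model_output_v6_demo/merged_model\"\n",
"notebook_path = \"pokemon_showdown_agent_v6_tutorial.ipynb\"  # Assumed local filename\n",
"\n",
"if PUSH_TO_HUB:\n",
"    api = HfApi()\n",
"    api.create_repo(repo_id, repo_type=\"model\", exist_ok=True)\n",
"    # upload_large_folder splits the work into resumable chunks, which suits multi-GB checkpoints.\n",
"    api.upload_large_folder(repo_id=repo_id, folder_path=export_dir, repo_type=\"model\")\n",
"    api.upload_file(path_or_fileobj=notebook_path, path_in_repo=notebook_path, repo_id=repo_id)"
],
"execution_count": null,
"outputs": []
},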
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Bonus: advanced tactics with GRPO\n",
"\n",
"Supervised fine-tuning teaches the model how strong players act. To push beyond imitation and reward outcomes or formatting more explicitly, you can add **Group Relative Policy Optimization (GRPO)** on top.\n",
"\n",
"For Pokemon, a first reward function can stay simple: reward valid action formatting, reward strategically sensible outputs, and penalize clearly bad or impossible commands. The cell below is still conceptual, but it shows how the `v6` agent can extend into RL after the SFT tutorial."
]
},
{
"cell_type": "code",
"metadata": {},
"source": [
"from trl import GRPOTrainer, GRPOConfig\n",
"import re\n",
"\n",
"# 1. Define a Reward Function (Example: Check if output contains a valid command)\n",
"def format_reward_func(prompts, completions, **kwargs):\n",
" rewards = []\n",
" for completion in completions:\n",
" # TRL completions may be strings or list of dicts depending on version\n",
" text = completion[0][\"content\"] if isinstance(completion, list) else str(completion)\n",
" \n",
" # Reward the model heavily if it successfully follows the format `move X` or `switch X`\n",
" cmd_match = re.search(r\"(move\\s+[\\w\\-]+|switch\\s+[\\w\\-]+)\", text, re.IGNORECASE)\n",
" if cmd_match:\n",
" rewards.append(5.0)\n",
" else:\n",
" rewards.append(-3.0) # Penalize conversational babbling\n",
" \n",
" return rewards\n",
"\n",
"# 2. Prepare GRPO Prompts (We only need the prompts, GRPO will generate its own completions)\n",
"grpo_prompts = [{\"prompt\": p[\"text\"].split(\"<|im_start|>assistant\")[0] + \"<|im_start|>assistant\\n\"} for p in train_samples[:50]]\n",
"from datasets import Dataset\n",
"grpo_dataset = Dataset.from_list(grpo_prompts)\n",
"\n",
"# 3. Configure GRPO\n",
"grpo_config = GRPOConfig(\n",
" output_dir = \"grpo_outputs\",\n",
" learning_rate = 3e-6,\n",
" per_device_train_batch_size = 1,\n",
" gradient_accumulation_steps = 4,\n",
" num_generations = 4, # Generate 4 different strategies to compare\n",
" max_completion_length = 128,\n",
" temperature = 1.3, # Encourage exploration\n",
" max_steps = 10,\n",
" logging_steps = 1,\n",
" report_to = \"none\",\n",
" fp16 = not is_bfloat16_supported(),\n",
" bf16 = is_bfloat16_supported(),\n",
")\n",
"\n",
"grpo_trainer = GRPOTrainer(\n",
" model = model,\n",
" reward_funcs = [format_reward_func],\n",
" args = grpo_config,\n",
" train_dataset = grpo_dataset,\n",
")\n",
"\n",
"# Uncomment to run the GRPO loop\n",
"# grpo_trainer.train()"
],
"execution_count": null,
"outputs": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
}
},
"nbformat": 4,
"nbformat_minor": 2
}