{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "
\n", "\n", "\n", "\n", "---\n", "\n", "# Pokemon Showdown Agent v6 with Unsloth on AMD ROCm\n", "\n", "**Author:** Yueyuan (GoldenGrapeGentleman)\n", "\n", "Build a competitive Pokemon Showdown action model from raw replay logs and reproduce the core `v6` workflow with `Qwen/Qwen3-4B`, Unsloth, and AMD ROCm.\n", "\n", "### What you will build\n", "1. **Portable ROCm setup**: Configure a notebook that works on MI300X-class systems and still degrades gracefully on smaller AMD machines.\n", "2. **Disk-safe data preparation**: Stream and filter raw replay logs from Hugging Face without materializing the full 40GB+ corpus locally.\n", "3. **Quick validation SFT**: Run a short 50-step alignment loop and watch the model shift from generic chat behavior to valid `move` / `switch` actions.\n", "4. **Inference sanity check**: Test the tuned model on a real battle prefix and inspect whether it emits a legal next action.\n", "5. **Export and publish**: Package the merged checkpoint and push your tutorial artifact to Hugging Face.\n", "\n", "### Tutorial scope\n", "This notebook intentionally trains on a small streamed subset so the workflow stays teachable, reproducible, and safe on limited disk. The full production `v6` checkpoint can be published separately after training or loaded from Hugging Face once available.\n", "\n", "### Prerequisites\n", "To install Unsloth on your AMD machine, follow the [AMD ROCm installation guide](https://unsloth.ai/docs/get-started/install/amd). This notebook is adapted from the Unsloth notebook ecosystem and keeps the same [LGPL-3.0](https://github.com/unslothai/notebooks?tab=LGPL-3.0-1-ov-file#readme) notebook license context." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 1. AMD ROCm environment setup\n", "When running on AMD GPUs such as MI300X or Radeon PRO W7900, Unsloth relies on ROCm plus PyTorch SDPA kernels for strong throughput. 
The next cell sets conservative defaults before importing training libraries, prefers large external cache mounts when they exist, and falls back to `/tmp` so the tutorial does not break on machines without `/shared-docker`." ] }, { "cell_type": "code", "metadata": {}, "source": [ "import os\n", "from pathlib import Path\n", "\n", "cache_candidates = [\n", " \"/shared-docker/.cache/huggingface\",\n", " \"/data/huggingface\",\n", " \"/tmp/pokemon-hf-cache\",\n", "]\n", "tmp_candidates = [\n", " \"/shared-docker/.cache/tmp\",\n", " \"/data/tmp\",\n", " \"/tmp/pokemon-tmp\",\n", "]\n", "\n", "# Use the first candidate that already exists; otherwise fall back to the /tmp entry.\n", "cache_root = next((path for path in cache_candidates if os.path.isdir(path)), cache_candidates[-1])\n", "tmp_root = next((path for path in tmp_candidates if os.path.isdir(path)), tmp_candidates[-1])\n", "\n", "# Default to the first visible AMD GPU unless the user already pinned a device.\n", "os.environ.setdefault(\"HIP_VISIBLE_DEVICES\", \"0\")\n", "os.environ.setdefault(\"ROCR_VISIBLE_DEVICES\", os.environ[\"HIP_VISIBLE_DEVICES\"])\n", "\n", "os.environ[\"HF_HOME\"] = cache_root\n", "os.environ[\"HF_DATASETS_CACHE\"] = f\"{cache_root}/datasets\"\n", "os.environ[\"TMPDIR\"] = tmp_root\n", "\n", "Path(os.environ[\"HF_HOME\"]).mkdir(parents=True, exist_ok=True)\n", "Path(os.environ[\"HF_DATASETS_CACHE\"]).mkdir(parents=True, exist_ok=True)\n", "Path(os.environ[\"TMPDIR\"]).mkdir(parents=True, exist_ok=True)\n", "\n", "# Import datasets only after the cache variables are set: Hugging Face libraries read them at import time.\n", "import datasets\n", "\n", "# Also patch the live config in case datasets was already imported earlier in this kernel.\n", "datasets.config.HF_DATASETS_CACHE = os.environ[\"HF_DATASETS_CACHE\"]\n", "\n", "# RDNA3 cards often need the AOTriton flag for Flash Attention via SDPA.\n", "os.environ.setdefault(\"TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL\", \"1\")\n", "os.environ.setdefault(\"PYTORCH_HIP_ALLOC_CONF\", \"expandable_segments:False\")\n", "os.environ.setdefault(\"UNSLOTH_SKIP_TORCHVISION_CHECK\", \"1\")\n", "\n", "print(f\"HF_HOME={os.environ['HF_HOME']}\")\n", "print(f\"HF_DATASETS_CACHE={os.environ['HF_DATASETS_CACHE']}\")\n", 
"print(f\"TMPDIR={os.environ['TMPDIR']}\")\n" ], "execution_count": 10, "outputs": [] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 2. Installation\n", "If you are not already inside a prepared ROCm image, install Unsloth and the core training stack. For maximum AMD compatibility, this tutorial avoids depending on optional 8-bit optimizer packages and sticks to the standard PyTorch / Transformers / TRL toolchain." ] }, { "cell_type": "code", "metadata": {}, "source": [ "%%capture\n", "import os\n", "\n", "if \"COLAB_\" not in \"\".join(os.environ.keys()):\n", " !pip install --no-cache-dir --no-deps unsloth unsloth_zoo\n", " !pip install --no-cache-dir transformers accelerate peft trl datasets psutil sentencepiece protobuf tyro huggingface_hub hf_transfer einops\n", "else:\n", " pass # Colab / Kaggle usually need their own setup flow." ], "execution_count": 11, "outputs": [] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 3. Load the base model\n", "We load `Qwen/Qwen3-4B` with a configuration tuned for AMD inference stability. The full production `v6` recipe used longer contexts and more data, but this tutorial keeps the setup compact enough for a short reproducible run." ] }, { "cell_type": "code", "metadata": {}, "source": [ "import unsloth\n", "from unsloth import FastLanguageModel\n", "import torch\n", "\n", "max_seq_length = 2048 # Small enough for a fast tutorial, large enough for real replay prefixes.\n", "dtype = torch.bfloat16 # Recommended on AMD for stable training and inference.\n", "load_in_4bit = False # Keep inference stable on ROCm for the public tutorial flow.\n", "\n", "model, tokenizer = FastLanguageModel.from_pretrained(\n", " model_name=\"Qwen/Qwen3-4B\",\n", " max_seq_length=max_seq_length,\n", " dtype=dtype,\n", " load_in_4bit=load_in_4bit,\n", " attn_implementation=\"sdpa\",\n", ")" ], "execution_count": 12, "outputs": [ { "output_type": "stream", "text": [ "==((====))== Unsloth 2026.1.4: Fast Qwen3 patching. 
Transformers: 4.57.6.\n", " \\\\ /| AMD Instinct MI300X VF. Num GPUs = 1. Max memory: 191.688 GB. Platform: Linux.\n", "O^O/ \\_/ \\ Torch: 2.10.0+rocm7.1. ROCm Toolkit: 7.1.25424. Triton: 3.6.0\n", "\\ / Bfloat16 = TRUE. FA [Xformers = None. FA2 = False]\n", " \"-____-\" Free license: http://github.com/unslothai/unsloth\n", "Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored!\n" ] }, { "output_type": "stream", "text": [ "Loading checkpoint shards: 100%|██████████| 2/2 [00:05<00:00, 2.56s/it]\n" ] } ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We now add LoRA adapters so we only need to update 1 to 10% of all parameters!" ] }, { "cell_type": "code", "metadata": {}, "source": [ "# 2. Apply LoRA Adapters\n", "model = FastLanguageModel.get_peft_model(\n", " model,\n", " r = 64, # Choose any number > 0 ! Suggested 8, 16, 32, 64, 128\n", " target_modules = [\"q_proj\", \"k_proj\", \"v_proj\", \"o_proj\",\n", " \"gate_proj\", \"up_proj\", \"down_proj\"],\n", " lora_alpha = 128,\n", " lora_dropout = 0, # Supports any, but = 0 is optimized\n", " bias = \"none\", # Supports any, but = \"none\" is optimized\n", " use_gradient_checkpointing = \"unsloth\", # True or \"unsloth\" for very long context\n", " random_state = 3407,\n", " use_rslora = False, # We support rank stabilized LoRA\n", " loftq_config = None, # And LoftQ\n", ")" ], "execution_count": 13, "outputs": [] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 4. Data preparation from raw replay logs\n", "The `v6` idea is simple but demanding: feed the model messy real replay logs instead of hand-written state summaries. For the public tutorial we stream `milkkarten/pokemon-showdown-replays-merged`, filter to stronger Gen 9 games, and build a compact subset entirely in memory so the notebook stays disk-safe." 
] }, { "cell_type": "code", "metadata": {}, "source": [ "from datasets import Dataset, load_dataset\n", "from unsloth.chat_templates import get_chat_template\n", "\n", "tokenizer = get_chat_template(tokenizer, chat_template=\"qwen3\")\n", "\n", "MIN_RATING = 1400\n", "MAX_TRAIN_SAMPLES = 2000\n", "SHUFFLE_BUFFER = 10_000\n", "MAX_LOG_CHARS = 6000\n", "\n", "SYSTEM_TEMPLATE = (\n", " \"You are a Pokemon Showdown battle AI. You play as {side}. \"\n", " \"Given the battle log, output your next action. \"\n", " \"Format: move