pokemon-showdown-agent-v6/pokemon_showdown_agent_v6_release.ipynb
ModelHub XC 37b4d3a899 Initialize project; model provided by the ModelHub XC community
Model: GoldenGrapeGentleman1/pokemon-showdown-agent-v6
Source: Original Platform
2026-04-18 08:31:48 +08:00

{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Pokemon Showdown Agent v6 on AMD ROCm\n",
"\n",
"**Author:** Yueyuan (GoldenGrapeGentleman)\n",
"\n",
"This is a clean release notebook for the `v6` Pokemon Showdown agent workflow. It keeps the original notebook untouched and rebuilds the tutorial from scratch with a stricter engineering standard for AMD ROCm users.\n",
"\n",
"## What this release notebook fixes\n",
"1. Dependency installation happens before any third-party imports.\n",
"2. Cache and temp paths fall back safely when `/shared-docker` is unavailable.\n",
"3. Data preparation uses streaming and never materializes the full replay corpus locally.\n",
"4. The tutorial keeps `load_in_4bit=False` for ROCm inference stability and uses `adamw_torch` instead of 8-bit optimizers.\n",
"5. All outputs are intentionally cleared so readers only see results produced in their own environment.\n",
"\n",
"## Recommended workflow\n",
"1. Run the install cell once.\n",
"2. If packages were installed, rerun the notebook from the top.\n",
"3. Execute the remaining cells in order.\n",
"\n",
"## Published production model\n",
"The full production checkpoint is already available at `GoldenGrapeGentleman1/pokemon-showdown-agent-v6`.\n",
"\n",
"## Tutorial scope\n",
"This notebook intentionally trains on a small streamed subset so the workflow stays teachable, reproducible, and disk-safe on AMD systems."
],
"id": "18a89f18"
},
{
"cell_type": "code",
"metadata": {},
"source": [
"import os\n",
"import platform\n",
"import shutil\n",
"from pathlib import Path\n",
"\n",
"\n",
"def pick_writable_dir(candidates):\n",
" for candidate in candidates:\n",
" path = Path(candidate)\n",
" try:\n",
" path.mkdir(parents=True, exist_ok=True)\n",
" probe = path / \".cursor_write_test\"\n",
" probe.write_text(\"ok\", encoding=\"utf-8\")\n",
" probe.unlink()\n",
" return path\n",
" except Exception:\n",
" continue\n",
" raise RuntimeError(f\"No writable directory found in: {candidates}\")\n",
"\n",
"\n",
"cache_root = pick_writable_dir([\n",
" \"/shared-docker/.cache/huggingface\",\n",
" \"/data/huggingface\",\n",
" \"/tmp/pokemon-hf-cache\",\n",
"])\n",
"tmp_root = pick_writable_dir([\n",
" \"/shared-docker/.cache/tmp\",\n",
" \"/data/tmp\",\n",
" \"/tmp/pokemon-tmp\",\n",
"])\n",
"\n",
"os.environ.setdefault(\"HIP_VISIBLE_DEVICES\", \"0\")\n",
"os.environ.setdefault(\"ROCR_VISIBLE_DEVICES\", os.environ[\"HIP_VISIBLE_DEVICES\"])\n",
"os.environ.setdefault(\"TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL\", \"1\")\n",
"os.environ.setdefault(\"PYTORCH_HIP_ALLOC_CONF\", \"expandable_segments:False\")\n",
"os.environ.setdefault(\"UNSLOTH_SKIP_TORCHVISION_CHECK\", \"1\")\n",
"os.environ.setdefault(\"HF_HUB_ENABLE_HF_TRANSFER\", \"1\")\n",
"\n",
"os.environ[\"HF_HOME\"] = str(cache_root)\n",
"os.environ[\"HF_DATASETS_CACHE\"] = str(cache_root / \"datasets\")\n",
"os.environ[\"TMPDIR\"] = str(tmp_root)\n",
"\n",
"Path(os.environ[\"HF_DATASETS_CACHE\"]).mkdir(parents=True, exist_ok=True)\n",
"Path(os.environ[\"TMPDIR\"]).mkdir(parents=True, exist_ok=True)\n",
"\n",
"cache_usage = shutil.disk_usage(cache_root)\n",
"print(f\"platform={platform.platform()}\")\n",
"print(f\"HF_HOME={os.environ['HF_HOME']}\")\n",
"print(f\"HF_DATASETS_CACHE={os.environ['HF_DATASETS_CACHE']}\")\n",
"print(f\"TMPDIR={os.environ['TMPDIR']}\")\n",
"print(f\"HIP_VISIBLE_DEVICES={os.environ['HIP_VISIBLE_DEVICES']}\")\n",
"print(f\"free_space_gb={cache_usage.free / (1024 ** 3):.1f}\")"
],
"execution_count": null,
"outputs": [],
"id": "2e2615d3"
},
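{
"cell_type": "markdown",
"metadata": {},
"source": [
"The fallback in `pick_writable_dir` can be exercised without touching `/shared-docker`. The sketch below is a standalone copy of the same idea; the name `first_writable_dir` is hypothetical and is not used elsewhere in this notebook. Its first candidate sits under a regular file, so `mkdir` fails even when running as root, and the function moves on to the next candidate.\n",
"\n",
"```python\n",
"import tempfile\n",
"from pathlib import Path\n",
"\n",
"\n",
"def first_writable_dir(candidates):\n",
"    \"\"\"Return the first candidate that can be created and written to (hypothetical helper).\"\"\"\n",
"    for candidate in candidates:\n",
"        path = Path(candidate)\n",
"        try:\n",
"            path.mkdir(parents=True, exist_ok=True)\n",
"            probe = path / \".write_probe\"\n",
"            probe.write_text(\"ok\", encoding=\"utf-8\")\n",
"            probe.unlink()\n",
"            return path\n",
"        except OSError:\n",
"            continue\n",
"    raise RuntimeError(f\"No writable directory found in: {candidates}\")\n",
"\n",
"\n",
"with tempfile.TemporaryDirectory() as scratch:\n",
"    blocker = Path(scratch) / \"not_a_dir\"\n",
"    blocker.write_text(\"x\", encoding=\"utf-8\")  # a regular file, so nothing can be created under it\n",
"    chosen = first_writable_dir([blocker / \"cache\", Path(scratch) / \"fallback\"])\n",
"    print(chosen.name)  # fallback\n",
"```\n",
"\n",
"Catching `OSError` instead of a bare `Exception` is a choice made for this sketch: it still covers permission errors and read-only mounts while letting unrelated bugs surface."
],
"id": "c41f2a01"
},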
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 1. Install the runtime stack\n",
"This cell intentionally uses only the Python standard library before installation. It checks for missing packages, reports the versions already present in the environment, and only installs core dependencies when they are missing.\n",
"\n",
"On a prebuilt AMD ROCm container, preserving the existing working stack is usually safer than forcing package downgrades inside the notebook.\n",
"\n",
"If packages are installed into a fresh environment, rerun the notebook from the top before continuing."
],
"id": "bd338c40"
},
{
"cell_type": "code",
"metadata": {},
"source": [
"import importlib.metadata\n",
"import importlib.util\n",
"import subprocess\n",
"import sys\n",
"\n",
"core_runtime = [\n",
" \"transformers\",\n",
" \"trl\",\n",
" \"datasets\",\n",
" \"tokenizers\",\n",
" \"huggingface_hub\",\n",
" \"accelerate\",\n",
" \"peft\",\n",
" \"psutil\",\n",
" \"sentencepiece\",\n",
" \"protobuf\",\n",
" \"tyro\",\n",
" \"hf_transfer\",\n",
" \"einops\",\n",
"]\n",
"required_modules = [\n",
" \"unsloth\",\n",
" \"unsloth_zoo\",\n",
" \"transformers\",\n",
" \"trl\",\n",
" \"datasets\",\n",
" \"tokenizers\",\n",
" \"huggingface_hub\",\n",
" \"accelerate\",\n",
" \"peft\",\n",
" \"psutil\",\n",
" \"sentencepiece\",\n",
" \"google.protobuf\",\n",
" \"tyro\",\n",
" \"einops\",\n",
"]\n",
"\n",
"missing = [name for name in required_modules if importlib.util.find_spec(name) is None]\n",
"\n",
"if missing:\n",
" print(f\"Installing missing packages: {missing}\")\n",
" subprocess.check_call([\n",
" sys.executable,\n",
" \"-m\",\n",
" \"pip\",\n",
" \"install\",\n",
" \"--no-cache-dir\",\n",
" \"--no-deps\",\n",
" \"unsloth\",\n",
" \"unsloth_zoo\",\n",
" ])\n",
" subprocess.check_call([\n",
" sys.executable,\n",
" \"-m\",\n",
" \"pip\",\n",
" \"install\",\n",
" \"--no-cache-dir\",\n",
" *core_runtime,\n",
" ])\n",
" print(\"Install completed. Rerun the notebook from the top if this is a fresh environment.\")\n",
"else:\n",
" print(\"All required packages are already available.\")\n",
"\n",
"print(\"Detected package versions:\")\n",
"for package_name in [\"unsloth\", \"transformers\", \"trl\", \"datasets\", \"tokenizers\", \"huggingface_hub\"]:\n",
" try:\n",
" print(f\" {package_name}={importlib.metadata.version(package_name)}\")\n",
" except importlib.metadata.PackageNotFoundError:\n",
" print(f\" {package_name}=missing\")"
],
"execution_count": null,
"outputs": [],
"id": "fa699cd4"
},
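{
"cell_type": "markdown",
"metadata": {},
"source": [
"The install cell keeps two lists because import names and pip distribution names differ (`google.protobuf` is installed as `protobuf`). One pitfall with the check: `importlib.util.find_spec` imports the parent package of a dotted name first, so `find_spec(\"google.protobuf\")` raises `ModuleNotFoundError` instead of returning `None` when `google` is absent. A minimal sketch of a check that tolerates this, using a hypothetical helper `missing_modules` and an illustrative name map:\n",
"\n",
"```python\n",
"import importlib.util\n",
"\n",
"# Illustrative import-name -> pip-name map; names not listed are assumed to match.\n",
"MODULE_TO_DIST = {\"google.protobuf\": \"protobuf\"}\n",
"\n",
"\n",
"def missing_modules(module_names):\n",
"    \"\"\"Return the import names that cannot be located in the current environment.\"\"\"\n",
"    missing = []\n",
"    for name in module_names:\n",
"        try:\n",
"            spec = importlib.util.find_spec(name)\n",
"        except ModuleNotFoundError:\n",
"            # Raised when a parent package of a dotted name does not exist.\n",
"            spec = None\n",
"        if spec is None:\n",
"            missing.append(name)\n",
"    return missing\n",
"\n",
"\n",
"print(missing_modules([\"json\", \"definitely_not_installed.sub\"]))  # ['definitely_not_installed.sub']\n",
"print([MODULE_TO_DIST.get(name, name) for name in [\"google.protobuf\", \"tyro\"]])  # ['protobuf', 'tyro']\n",
"```"
],
"id": "c41f2a02"
},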
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 2. Validate the AMD ROCm runtime\n",
"The next cell imports the runtime in an Unsloth-safe order, reports the key package versions, and confirms that the environment can see a ROCm-backed GPU before we touch the model or dataset pipeline."
],
"id": "f74ebb30"
},
{
"cell_type": "code",
"metadata": {},
"source": [
"import unsloth\n",
"import datasets\n",
"import huggingface_hub\n",
"import torch\n",
"import transformers\n",
"import trl\n",
"\n",
"\n",
"datasets.config.HF_DATASETS_CACHE = os.environ[\"HF_DATASETS_CACHE\"]\n",
"\n",
"print(f\"torch={torch.__version__}\")\n",
"print(f\"transformers={transformers.__version__}\")\n",
"print(f\"trl={trl.__version__}\")\n",
"print(f\"datasets={datasets.__version__}\")\n",
"print(f\"huggingface_hub={huggingface_hub.__version__}\")\n",
"print(f\"unsloth={getattr(unsloth, '__version__', 'unknown')}\")\n",
"print(f\"hip_version={getattr(torch.version, 'hip', None)}\")\n",
"\n",
"if not torch.cuda.is_available():\n",
" raise RuntimeError(\"No ROCm-visible GPU found. This notebook expects an AMD GPU environment.\")\n",
"\n",
"print(f\"gpu_count={torch.cuda.device_count()}\")\n",
"print(f\"gpu_name={torch.cuda.get_device_name(0)}\")\n",
"print(f\"bf16_supported={torch.cuda.is_bf16_supported()}\")"
],
"execution_count": null,
"outputs": [],
"id": "1e4fcdaa"
},
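{
"cell_type": "markdown",
"metadata": {},
"source": [
"ROCm builds of PyTorch expose the HIP backend through the `torch.cuda` API, which is why the cell above calls `torch.cuda.is_available()` on AMD hardware. The build type is visible in `torch.version.hip`, which holds a version string on ROCm builds and is `None` otherwise. The sketch below classifies a build from `torch.version.hip` and `torch.version.cuda`; `describe_torch_build` is a hypothetical helper, and it takes plain strings (the version values shown are illustrative) so it runs without a GPU.\n",
"\n",
"```python\n",
"def describe_torch_build(hip_version, cuda_version):\n",
"    \"\"\"Classify a PyTorch build from the torch.version.hip / torch.version.cuda values.\"\"\"\n",
"    if hip_version:\n",
"        return f\"rocm (hip {hip_version})\"\n",
"    if cuda_version:\n",
"        return f\"cuda ({cuda_version})\"\n",
"    return \"cpu-only\"\n",
"\n",
"\n",
"# Illustrative version strings; real values depend on your installation.\n",
"print(describe_torch_build(\"6.2.41133\", None))  # rocm (hip 6.2.41133)\n",
"print(describe_torch_build(None, \"12.4\"))       # cuda (12.4)\n",
"print(describe_torch_build(None, None))         # cpu-only\n",
"```\n",
"\n",
"Inside this notebook you would call it as `describe_torch_build(getattr(torch.version, \"hip\", None), torch.version.cuda)`."
],
"id": "c41f2a03"
},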
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 3. Load Qwen3-4B and attach LoRA adapters\n",
"For the public AMD tutorial we keep `load_in_4bit=False` because the goal is a stable, reproducible path on ROCm. The production project can still explore tighter memory modes later, but the release notebook should prefer correctness and stability first."
],
"id": "d6c38fd1"
},
{
"cell_type": "code",
"metadata": {},
"source": [
"from unsloth import FastLanguageModel\n",
"from unsloth.chat_templates import get_chat_template\n",
"\n",
"max_seq_length = 2048\n",
"dtype = torch.bfloat16\n",
"load_in_4bit = False\n",
"\n",
"model, tokenizer = FastLanguageModel.from_pretrained(\n",
" model_name=\"Qwen/Qwen3-4B\",\n",
" max_seq_length=max_seq_length,\n",
" dtype=dtype,\n",
" load_in_4bit=load_in_4bit,\n",
" attn_implementation=\"sdpa\",\n",
")\n",
"\n",
"tokenizer = get_chat_template(tokenizer, chat_template=\"qwen3\")\n",
"\n",
"model = FastLanguageModel.get_peft_model(\n",
" model,\n",
" r=64,\n",
" target_modules=[\n",
" \"q_proj\",\n",
" \"k_proj\",\n",
" \"v_proj\",\n",
" \"o_proj\",\n",
" \"gate_proj\",\n",
" \"up_proj\",\n",
" \"down_proj\",\n",
" ],\n",
" lora_alpha=128,\n",
" lora_dropout=0,\n",
" bias=\"none\",\n",
" use_gradient_checkpointing=\"unsloth\",\n",
" random_state=3407,\n",
" use_rslora=False,\n",
" loftq_config=None,\n",
")\n",
"\n",
"trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)\n",
"all_params = sum(p.numel() for p in model.parameters())\n",
"print(f\"trainable_params={trainable_params / 1e6:.1f}M\")\n",
"print(f\"trainable_ratio={100 * trainable_params / all_params:.2f}%\")"
],
"execution_count": null,
"outputs": [],
"id": "cbb58828"
},
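{
"cell_type": "markdown",
"metadata": {},
"source": [
"The `trainable_params` value printed above comes from the LoRA factors: each adapted linear layer with input width `d_in` and output width `d_out` gains two matrices, A (`r x d_in`) and B (`d_out x r`), i.e. `r * (d_in + d_out)` parameters. The sketch below does that arithmetic for `r=64` over the seven target modules. The layer shapes are assumptions taken from the published Qwen3-4B configuration (hidden size 2560, MLP size 9728, 36 layers, 32 query heads and 8 key/value heads of width 128); check them against `model.config` before relying on the number. It also shows the adapter scaling: standard LoRA multiplies the update by `lora_alpha / r`, whereas rsLoRA (`use_rslora=True`) uses `lora_alpha / sqrt(r)`.\n",
"\n",
"```python\n",
"import math\n",
"\n",
"# Assumed Qwen3-4B shapes; verify against model.config in your environment.\n",
"hidden, intermediate, layers = 2560, 9728, 36\n",
"q_heads, kv_heads, head_dim = 32, 8, 128\n",
"r, lora_alpha = 64, 128\n",
"\n",
"# (d_in, d_out) for each LoRA target module in one decoder layer.\n",
"shapes = {\n",
"    \"q_proj\": (hidden, q_heads * head_dim),\n",
"    \"k_proj\": (hidden, kv_heads * head_dim),\n",
"    \"v_proj\": (hidden, kv_heads * head_dim),\n",
"    \"o_proj\": (q_heads * head_dim, hidden),\n",
"    \"gate_proj\": (hidden, intermediate),\n",
"    \"up_proj\": (hidden, intermediate),\n",
"    \"down_proj\": (intermediate, hidden),\n",
"}\n",
"\n",
"per_layer = sum(r * (d_in + d_out) for d_in, d_out in shapes.values())\n",
"total = per_layer * layers\n",
"print(f\"lora_params={total:,} (~{total / 1e6:.1f}M)\")  # lora_params=132,120,576 (~132.1M)\n",
"print(f\"lora_scaling={lora_alpha / r}\")  # lora_scaling=2.0\n",
"print(f\"rslora_scaling={lora_alpha / math.sqrt(r)}\")  # rslora_scaling=16.0\n",
"```"
],
"id": "c41f2a04"
},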
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 4. Stream and format replay logs\n",
"This is the heart of the `v6` idea: we train directly from messy Pokemon Showdown replay logs instead of hand-written state summaries. The implementation below stays disk-safe by using `streaming=True` end to end and only keeps the final tutorial subset in memory."
],
"id": "aa1831b7"
},
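{
"cell_type": "markdown",
"metadata": {},
"source": [
"`stream.shuffle(seed=3407, buffer_size=SHUFFLE_BUFFER)` cannot permute a stream it never holds in full, so it approximates a shuffle: it fills a fixed-size buffer and repeatedly emits a random element from it while refilling from the stream (per the Hugging Face Datasets documentation, it also shuffles the order of the data shards). The plain-Python sketch below illustrates only the buffer part of that idea; `buffered_shuffle` is a hypothetical name and this is not the library's code.\n",
"\n",
"```python\n",
"import random\n",
"\n",
"\n",
"def buffered_shuffle(iterable, buffer_size, seed=0):\n",
"    \"\"\"Yield items in approximately shuffled order using a bounded buffer.\"\"\"\n",
"    rng = random.Random(seed)\n",
"    buffer = []\n",
"    for item in iterable:\n",
"        if len(buffer) < buffer_size:\n",
"            buffer.append(item)\n",
"            continue\n",
"        # Buffer is full: emit a random slot and put the new item in its place.\n",
"        idx = rng.randrange(buffer_size)\n",
"        yield buffer[idx]\n",
"        buffer[idx] = item\n",
"    # Stream exhausted: drain the remaining items in random order.\n",
"    rng.shuffle(buffer)\n",
"    yield from buffer\n",
"\n",
"\n",
"out = list(buffered_shuffle(range(20), buffer_size=5, seed=3407))\n",
"print(sorted(out) == list(range(20)))  # True: every item appears exactly once\n",
"print(out)\n",
"```\n",
"\n",
"Because an item cannot be emitted before it enters the buffer, no element moves more than `buffer_size - 1` positions earlier than its place in the stream. A larger `SHUFFLE_BUFFER` therefore mixes a sorted stream better, at the cost of memory and start-up latency."
],
"id": "c41f2a05"
},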
{
"cell_type": "code",
"metadata": {},
"source": [
"from datasets import Dataset, load_dataset\n",
"\n",
"MIN_RATING = 1400\n",
"MAX_TRAIN_SAMPLES = 2000\n",
"SHUFFLE_BUFFER = 10_000\n",
"MAX_LOG_CHARS = 6000\n",
"\n",
"SYSTEM_TEMPLATE = (\n",
" \"You are a Pokemon Showdown battle AI. You play as {side}. \"\n",
" \"Given the battle log, output your next action. \"\n",
" \"Format: move <name> OR switch <name>. \"\n",
" \"Append terastallize if you terastallize this turn.\"\n",
")\n",
"\n",
"\n",
"def extract_winner_side(log_text):\n",
" winner = None\n",
" players = {}\n",
" for line in log_text.split(\"\\n\"):\n",
" parts = line.split(\"|\")\n",
" if len(parts) >= 4 and parts[1] == \"player\":\n",
" players[parts[2]] = parts[3]\n",
" if len(parts) >= 3 and parts[1] == \"win\":\n",
" winner = parts[2]\n",
" if not winner:\n",
" return None\n",
" for side, name in players.items():\n",
" if name == winner:\n",
" return side\n",
" return None\n",
"\n",
"\n",
"def format_sample(example):\n",
" log_text = example[\"log\"]\n",
" side = extract_winner_side(log_text)\n",
" if not side:\n",
" return {\"text\": \"\"}\n",
"\n",
" lines = log_text.strip().split(\"\\n\")\n",
" turn_positions = []\n",
" for i, line in enumerate(lines):\n",
" parts = line.split(\"|\")\n",
" if len(parts) >= 3 and parts[1] == \"turn\":\n",
" try:\n",
" turn_positions.append((int(parts[2]), i))\n",
" except ValueError:\n",
" pass\n",
"\n",
" if len(turn_positions) < 2:\n",
" return {\"text\": \"\"}\n",
"\n",
" target_turn_idx = 0 if len(turn_positions) == 2 else 1\n",
" _, turn_line_idx = turn_positions[target_turn_idx]\n",
" next_turn_line = turn_positions[target_turn_idx + 1][1] if target_turn_idx + 1 < len(turn_positions) else len(lines)\n",
"\n",
" action = None\n",
" for j in range(turn_line_idx + 1, next_turn_line):\n",
" parts = lines[j].split(\"|\")\n",
" if len(parts) < 4:\n",
" continue\n",
" if parts[1] == \"move\" and parts[2].startswith(f\"{side}a:\"):\n",
" tera = \"\"\n",
" start_look = max(0, j - 3)\n",
" end_look = min(len(lines), j + 3)\n",
" if any(\"terastallize\" in lines[k] and side in lines[k] for k in range(start_look, end_look)):\n",
" tera = \" terastallize\"\n",
" action = f\"move {parts[3]}{tera}\"\n",
" break\n",
" if parts[1] == \"switch\" and parts[2].startswith(f\"{side}a:\"):\n",
" pokemon = parts[2].split(\": \")[1] if \": \" in parts[2] else parts[2]\n",
" action = f\"switch {pokemon}\"\n",
" break\n",
"\n",
" if not action:\n",
" return {\"text\": \"\"}\n",
"\n",
" log_prefix = \"\\n\".join(lines[:turn_line_idx + 1])\n",
" if len(log_prefix) > MAX_LOG_CHARS:\n",
" return {\"text\": \"\"}\n",
"\n",
" messages = [\n",
" {\"role\": \"system\", \"content\": SYSTEM_TEMPLATE.format(side=side)},\n",
" {\"role\": \"user\", \"content\": log_prefix},\n",
" {\"role\": \"assistant\", \"content\": action},\n",
" ]\n",
" text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=False)\n",
" return {\"text\": text}\n",
"\n",
"\n",
"print(\"Streaming replay logs without materializing the full corpus...\")\n",
"stream = load_dataset(\n",
" \"milkkarten/pokemon-showdown-replays-merged\",\n",
" split=\"train\",\n",
" streaming=True,\n",
")\n",
"stream = stream.shuffle(seed=3407, buffer_size=SHUFFLE_BUFFER)\n",
"stream = stream.filter(\n",
" lambda row: \"gen9\" in str(row.get(\"formatid\") or \"\").lower()\n",
" and (row.get(\"rating\") or 0) >= MIN_RATING\n",
")\n",
"\n",
"train_samples = []\n",
"scanned = 0\n",
"for row in stream:\n",
" scanned += 1\n",
" formatted = format_sample(row)\n",
" if formatted[\"text\"]:\n",
" train_samples.append(formatted)\n",
" if len(train_samples) >= MAX_TRAIN_SAMPLES:\n",
" break\n",
"\n",
"if not train_samples:\n",
" raise RuntimeError(\"No training samples were collected. Check dataset access, filters, or disk/network configuration.\")\n",
"\n",
"train_dataset = Dataset.from_list(train_samples).shuffle(seed=3407)\n",
"print(f\"collected_examples={len(train_dataset)}\")\n",
"print(f\"replays_scanned={scanned}\")\n",
"print(train_dataset[0][\"text\"][:1000])"
],
"execution_count": null,
"outputs": [],
"id": "bcb82788"
},
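{
"cell_type": "markdown",
"metadata": {},
"source": [
"Showdown replay logs use a line-based protocol in which every message starts with `|` and fields are `|`-separated. Splitting on `|` therefore leaves an empty string at index 0, which is why `format_sample` reads the message type from `parts[1]` and the acting Pokemon from `parts[2]`. The self-contained sketch below applies the same turn-selection rule to a hypothetical three-turn log fragment: with more than two `|turn|` markers it targets the second one, and the label is the first `move` or `switch` by the chosen side's `a` slot before the next marker. `pick_action` is an illustrative name, and it omits the terastallization check.\n",
"\n",
"```python\n",
"LOG = '''|player|p1|Alice|1|1500\n",
"|player|p2|Bob|2|1500\n",
"|start\n",
"|switch|p1a: Garchomp|Garchomp, M|100/100\n",
"|switch|p2a: Corviknight|Corviknight, M|100/100\n",
"|turn|1\n",
"|move|p1a: Garchomp|Earthquake|p2a: Corviknight\n",
"|move|p2a: Corviknight|Brave Bird|p1a: Garchomp\n",
"|turn|2\n",
"|switch|p2a: Gholdengo|Gholdengo|100/100\n",
"|move|p1a: Garchomp|Dragon Claw|p2a: Gholdengo\n",
"|turn|3\n",
"|win|Bob'''\n",
"\n",
"\n",
"def pick_action(log_text, side):\n",
"    split_lines = [line.split(\"|\") for line in log_text.split(\"\\n\")]\n",
"    # (turn number, line index) for every |turn| marker.\n",
"    turns = [(int(parts[2]), i) for i, parts in enumerate(split_lines)\n",
"             if len(parts) >= 3 and parts[1] == \"turn\"]\n",
"    if len(turns) < 2:\n",
"        return None\n",
"    # Same rule as format_sample: first marker if there are exactly two, else the second.\n",
"    target = 0 if len(turns) == 2 else 1\n",
"    start, end = turns[target][1], turns[target + 1][1]\n",
"    for parts in split_lines[start + 1:end]:\n",
"        if len(parts) >= 4 and parts[2].startswith(f\"{side}a:\"):\n",
"            if parts[1] == \"move\":\n",
"                return turns[target][0], f\"move {parts[3]}\"\n",
"            if parts[1] == \"switch\":\n",
"                return turns[target][0], f\"switch {parts[2].split(': ', 1)[1]}\"\n",
"    return None\n",
"\n",
"\n",
"print(\"|move|p1a: Garchomp|Earthquake\".split(\"|\"))  # ['', 'move', 'p1a: Garchomp', 'Earthquake']\n",
"print(pick_action(LOG, \"p2\"))  # (2, 'switch Gholdengo')\n",
"print(pick_action(LOG, \"p1\"))  # (2, 'move Dragon Claw')\n",
"```\n",
"\n",
"In the real pipeline only the winner's side is labeled; in this fragment that is `p2`, because `|win|Bob` matches `|player|p2|Bob`."
],
"id": "c41f2a06"
},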
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 5. Quick validation SFT\n",
"The goal of this step is not to fully optimize ladder strength. It is to show the `Agentic Alignment` moment clearly: after a short 50-step run, the model should stop behaving like a chat bot and start behaving like a strict action emitter."
],
"id": "6e28c2d3"
},
{
"cell_type": "code",
"metadata": {},
"source": [
"from trl import SFTConfig, SFTTrainer\n",
"from unsloth import is_bfloat16_supported\n",
"\n",
"output_dir = \"outputs/pokemon_v6_release_demo\"\n",
"\n",
"trainer = SFTTrainer(\n",
" model=model,\n",
" tokenizer=tokenizer,\n",
" train_dataset=train_dataset,\n",
" dataset_text_field=\"text\",\n",
" max_seq_length=max_seq_length,\n",
" packing=False,\n",
" args=SFTConfig(\n",
" per_device_train_batch_size=2,\n",
" gradient_accumulation_steps=4,\n",
" warmup_steps=5,\n",
" max_steps=50,\n",
" learning_rate=2e-4,\n",
" fp16=not is_bfloat16_supported(),\n",
" bf16=is_bfloat16_supported(),\n",
" logging_steps=10,\n",
" optim=\"adamw_torch\",\n",
" weight_decay=0.01,\n",
" lr_scheduler_type=\"linear\",\n",
" seed=3407,\n",
" output_dir=output_dir,\n",
" save_steps=50,\n",
" save_total_limit=2,\n",
" report_to=\"none\",\n",
" ),\n",
")\n",
"\n",
"trainer_stats = trainer.train()\n",
"print(trainer_stats)"
],
"execution_count": null,
"outputs": [],
"id": "6f813bf3"
},
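{
"cell_type": "markdown",
"metadata": {},
"source": [
"For a sense of scale, the configuration above processes `per_device_train_batch_size x gradient_accumulation_steps` sequences per optimizer step on one GPU, so 50 steps cover 400 sequences, one fifth of the 2,000-example subset. The sketch below only does that arithmetic (a single process is assumed; with several GPUs, multiply by the number of processes) and traces the learning-rate shape that `lr_scheduler_type=\"linear\"` with `warmup_steps=5` selects, i.e. linear warmup then linear decay as in `transformers.get_linear_schedule_with_warmup`. `linear_lr` is a hypothetical helper that mirrors that shape; treat it as an approximation.\n",
"\n",
"```python\n",
"per_device_batch, grad_accum, num_processes = 2, 4, 1\n",
"max_steps, warmup_steps, peak_lr = 50, 5, 2e-4\n",
"collected = 2000  # MAX_TRAIN_SAMPLES, assuming the stream filled the quota\n",
"\n",
"effective_batch = per_device_batch * grad_accum * num_processes\n",
"seen = effective_batch * max_steps\n",
"print(effective_batch, seen, seen / collected)  # 8 400 0.2\n",
"\n",
"\n",
"def linear_lr(step):\n",
"    \"\"\"Linear warmup to peak_lr, then linear decay to zero at max_steps.\"\"\"\n",
"    if step < warmup_steps:\n",
"        return peak_lr * step / warmup_steps\n",
"    return peak_lr * max(0.0, (max_steps - step) / (max_steps - warmup_steps))\n",
"\n",
"\n",
"print([round(linear_lr(s), 6) for s in (0, 2, 5, 25, 50)])\n",
"```"
],
"id": "c41f2a07"
},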
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 6. Inference sanity check\n",
"A healthy post-SFT output should be compact and action-shaped. It does not need to be perfect after only 50 steps, but it should usually look like `move ...` or `switch ...` instead of free-form commentary."
],
"id": "ba5d43d6"
},
{
"cell_type": "code",
"metadata": {},
"source": [
"import re\n",
"\n",
"model.eval()\n",
"FastLanguageModel.for_inference(model)\n",
"\n",
"sys_msg = (\n",
" \"You are a Pokemon Showdown battle AI. You play as p2. \"\n",
" \"Given the battle log, output your next action. \"\n",
" \"Format: move <name> OR switch <name>. \"\n",
" \"Append terastallize if you terastallize this turn.\"\n",
")\n",
"user_msg = '''|player|p1|Player1|266|1500\n",
"|player|p2|Player2|1|1500\n",
"|teamsize|p1|6\n",
"|teamsize|p2|6\n",
"|gen|9\n",
"|tier|[Gen 9] OU\n",
"|\n",
"|start\n",
"|switch|p1a: Garchomp|Garchomp, M|100/100\n",
"|switch|p2a: Corviknight|Corviknight, M|100/100\n",
"|turn|1\n",
"|move|p1a: Garchomp|Earthquake|p2a: Corviknight\n",
"|-immune|p2a: Corviknight\n",
"|turn|2'''\n",
"\n",
"messages = [\n",
" {\"role\": \"system\", \"content\": sys_msg},\n",
" {\"role\": \"user\", \"content\": user_msg},\n",
"]\n",
"\n",
"text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)\n",
"inputs = tokenizer(text, return_tensors=\"pt\").to(\"cuda\")\n",
"\n",
"with torch.no_grad():\n",
" outputs = model.generate(\n",
" **inputs,\n",
" max_new_tokens=64,\n",
" do_sample=False,\n",
" temperature=0.1,\n",
" pad_token_id=tokenizer.eos_token_id,\n",
" use_cache=False,\n",
" )\n",
"\n",
"full_response = tokenizer.decode(outputs[0], skip_special_tokens=False)\n",
"if \"<|im_start|>assistant\\n\" in full_response:\n",
" response = full_response.split(\"<|im_start|>assistant\\n\")[-1]\n",
"else:\n",
" response = tokenizer.decode(outputs[0][inputs.input_ids.shape[-1]:], skip_special_tokens=True)\n",
"response = response.replace(\"<|im_end|>\", \"\").strip()\n",
"response = re.sub(r\"(?s)^<think>.*?</think>\\s*\", \"\", response).strip()\n",
"match = re.search(r\"^(move|switch)\\b\", response, flags=re.IGNORECASE)\n",
"\n",
"print(\"--- AI AGENT PREDICTION ---\")\n",
"print(response)\n",
"print(f\"strict_action_format={bool(match)}\")"
],
"execution_count": null,
"outputs": [],
"id": "cd99efdd"
},
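{
"cell_type": "markdown",
"metadata": {},
"source": [
"The `strict_action_format` check above only tests the first word. To score a batch of generations more strictly, a small parser can also require a name and split off the optional `terastallize` suffix that the system prompt asks for. The sketch below is one such parser; `parse_action` is a hypothetical helper, and its grammar is inferred from the system prompt rather than from any Showdown API.\n",
"\n",
"```python\n",
"import re\n",
"\n",
"# Grammar inferred from the system prompt: \"move <name>\" or \"switch <name>\",\n",
"# optionally followed by \" terastallize\".\n",
"ACTION_RE = re.compile(\n",
"    r\"^(?P<kind>move|switch)\\s+(?P<name>.+?)(?P<tera>\\s+terastallize)?$\",\n",
"    re.IGNORECASE,\n",
")\n",
"\n",
"\n",
"def parse_action(text):\n",
"    \"\"\"Return (kind, name, tera) for a single-line action, or None if malformed.\"\"\"\n",
"    match = ACTION_RE.match(text.strip())\n",
"    if match is None:\n",
"        return None\n",
"    return match[\"kind\"].lower(), match[\"name\"], match[\"tera\"] is not None\n",
"\n",
"\n",
"print(parse_action(\"move Earthquake\"))               # ('move', 'Earthquake', False)\n",
"print(parse_action(\"move Tera Blast terastallize\"))  # ('move', 'Tera Blast', True)\n",
"print(parse_action(\"I think Garchomp should attack\"))  # None\n",
"print(parse_action(\"move\"))                          # None\n",
"```\n",
"\n",
"Because `.` does not match a newline and the pattern is anchored at both ends, multi-line answers (for example an action followed by commentary) are rejected as malformed."
],
"id": "c41f2a08"
},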
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 7. Export and publish\n",
"If you want a standalone merged checkpoint for sharing, export it into the same `outputs/` subtree to keep the project root tidy.\n",
"\n",
"```python\n",
"export_dir = f\"{output_dir}/merged_model\"\n",
"model.save_pretrained_merged(export_dir, tokenizer, save_method=\"merged_16bit\")\n",
"```\n",
"\n",
"Suggested Hugging Face commands for a tutorial-only release:\n",
"\n",
"```bash\n",
"hf auth whoami\n",
"hf repos create your-username/pokemon-showdown-agent-v6-demo --type model --exist-ok\n",
"hf upload-large-folder your-username/pokemon-showdown-agent-v6-demo outputs/pokemon_v6_release_demo/merged_model\n",
"hf upload your-username/pokemon-showdown-agent-v6-demo pokemon_agent_demo_notebook_v6_release.ipynb pokemon_showdown_agent_v6_release.ipynb\n",
"```\n",
"\n",
"The full production checkpoint remains available at `GoldenGrapeGentleman1/pokemon-showdown-agent-v6`."
],
"id": "10de5655"
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Bonus: GRPO sketch for the next stage\n",
"After SFT, you can layer a lightweight GRPO loop on top to reward valid action formatting or better tactical choices. In this ROCm container, the current `trl` build pulls in extra optional dependencies such as `mergekit` and `llm_blender` when importing `GRPOTrainer`.\n",
"\n",
"To keep the main SFT environment stable, the next cell does **not** auto-install those packages. Instead, it checks whether the optional GRPO stack is available and cleanly skips the setup with an actionable message if it is not.\n",
"\n",
"Run this section only after the earlier SFT/data cells have completed, because it reuses `train_samples`, `Dataset`, and `is_bfloat16_supported()` from the main notebook flow."
],
"id": "d0fb4c3e"
},
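{
"cell_type": "markdown",
"metadata": {},
"source": [
"GRPO turns rewards into a learning signal by comparing completions of the same prompt: for each group of `num_generations` samples it subtracts the group mean reward and, by default, divides by the group's standard deviation. The sketch below computes those group-relative advantages for the `+5 / -3` format reward used in this section. `group_advantages` is a hypothetical helper; it uses the population standard deviation and a small epsilon for simplicity, and TRL's exact normalization (sample vs population std, optional scaling switches) depends on the version, so read it as an illustration.\n",
"\n",
"```python\n",
"import statistics\n",
"\n",
"\n",
"def group_advantages(rewards, group_size, eps=1e-4):\n",
"    \"\"\"Normalize rewards within consecutive groups of `group_size` completions.\"\"\"\n",
"    advantages = []\n",
"    for start in range(0, len(rewards), group_size):\n",
"        group = rewards[start:start + group_size]\n",
"        mean = statistics.fmean(group)\n",
"        std = statistics.pstdev(group)  # population std; an assumption for this sketch\n",
"        advantages.extend((r - mean) / (std + eps) for r in group)\n",
"    return advantages\n",
"\n",
"\n",
"# Two prompts x num_generations=4, scored with the +5 / -3 format reward.\n",
"rewards = [5.0, 5.0, -3.0, -3.0,   # mixed group: well-formed samples are pushed up\n",
"           5.0, 5.0, 5.0, 5.0]     # uniform group: zero advantage\n",
"print([round(a, 3) for a in group_advantages(rewards, group_size=4)])\n",
"```\n",
"\n",
"When every completion in a group earns the same reward, all of its advantages are zero and that prompt contributes no policy-gradient signal. This tends to happen with a pure format reward once SFT has already taught the format, which is one reason to add rewards for tactical quality later."
],
"id": "c41f2a09"
},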
{
"cell_type": "code",
"metadata": {},
"source": [
"import importlib.util\n",
"\n",
"grpo_trainer = None\n",
"missing_optional = [\n",
" package_name\n",
" for module_name, package_name in [(\"mergekit\", \"mergekit\"), (\"llm_blender\", \"llm-blender\")]\n",
" if importlib.util.find_spec(module_name) is None\n",
"]\n",
"\n",
"if missing_optional:\n",
" print(\"Skipping Bonus/GRPO setup because optional packages are missing.\")\n",
" print(\"Install them in a separate or disposable environment if you want this section:\")\n",
" print(f\" pip install {' '.join(missing_optional)}\")\n",
"else:\n",
" try:\n",
" from trl import GRPOConfig, GRPOTrainer\n",
" except Exception as exc:\n",
" print(\"Skipping Bonus/GRPO setup because TRL could not import GRPOTrainer cleanly.\")\n",
" print(repr(exc))\n",
" else:\n",
" def format_reward_func(prompts, completions, **kwargs):\n",
" rewards = []\n",
" for completion in completions:\n",
" text = completion[0][\"content\"] if isinstance(completion, list) else str(completion)\n",
" cmd_match = re.search(r\"^(move\\s+.+|switch\\s+.+)$\", text.strip(), re.IGNORECASE)\n",
" rewards.append(5.0 if cmd_match else -3.0)\n",
" return rewards\n",
"\n",
" grpo_prompts = [\n",
" {\"prompt\": sample[\"text\"].split(\"<|im_start|>assistant\")[0] + \"<|im_start|>assistant\\n\"}\n",
" for sample in train_samples[:50]\n",
" ]\n",
" grpo_dataset = Dataset.from_list(grpo_prompts)\n",
"\n",
" grpo_config = GRPOConfig(\n",
" output_dir=\"outputs/pokemon_v6_release_grpo\",\n",
" learning_rate=3e-6,\n",
" per_device_train_batch_size=1,\n",
" gradient_accumulation_steps=4,\n",
" num_generations=4,\n",
" max_completion_length=128,\n",
" temperature=1.3,\n",
" max_steps=10,\n",
" logging_steps=1,\n",
" report_to=\"none\",\n",
" fp16=not is_bfloat16_supported(),\n",
" bf16=is_bfloat16_supported(),\n",
" )\n",
"\n",
" grpo_trainer = GRPOTrainer(\n",
" model=model,\n",
" reward_funcs=[format_reward_func],\n",
" args=grpo_config,\n",
" train_dataset=grpo_dataset,\n",
" )\n",
" print(\"GRPO setup is ready. Uncomment `grpo_trainer.train()` when desired.\")\n",
"\n",
"# Uncomment when you want to test the RL extension path.\n",
"# if grpo_trainer is not None:\n",
"# grpo_trainer.train()"
],
"execution_count": null,
"outputs": [],
"id": "208bfacd"
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"name": "python"
}
},
"nbformat": 4,
"nbformat_minor": 5
}