初始化项目,由ModelHub XC社区提供模型
Model: GoldenGrapeGentleman1/pokemon-showdown-agent-v6 Source: Original Platform
This commit is contained in:
36
.gitattributes
vendored
Normal file
36
.gitattributes
vendored
Normal file
@@ -0,0 +1,36 @@
|
|||||||
|
*.7z filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.arrow filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.bin filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.bz2 filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.ckpt filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.ftz filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.gz filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.h5 filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.joblib filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.lfs.* filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.mlmodel filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.model filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.msgpack filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.npy filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.npz filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.onnx filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.ot filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.parquet filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.pb filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.pickle filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.pkl filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.pt filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.pth filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.rar filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
||||||
|
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.tar filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.tflite filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.tgz filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.wasm filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.xz filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.zip filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.zst filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
||||||
|
tokenizer.json filter=lfs diff=lfs merge=lfs -text
|
||||||
176
README.md
Normal file
176
README.md
Normal file
@@ -0,0 +1,176 @@
|
|||||||
|
---
|
||||||
|
license: apache-2.0
|
||||||
|
base_model: Qwen/Qwen3-4B
|
||||||
|
library_name: transformers
|
||||||
|
pipeline_tag: text-generation
|
||||||
|
tags:
|
||||||
|
- unsloth
|
||||||
|
- trl
|
||||||
|
- sft
|
||||||
|
- qwen3
|
||||||
|
- pokemon-showdown
|
||||||
|
- game-ai
|
||||||
|
- rocm
|
||||||
|
- amd
|
||||||
|
language:
|
||||||
|
- en
|
||||||
|
---
|
||||||
|
|
||||||
|
# Pokemon Showdown Agent v6
|
||||||
|
|
||||||
|
`Pokemon Showdown Agent v6` is a `Qwen/Qwen3-4B` fine-tune for **next-action prediction from raw Pokemon Showdown replay logs**. Given a battle-log prefix and the side it controls, the model is trained to emit a short action command such as `move Earthquake` or `switch Corviknight`.
|
||||||
|
|
||||||
|
This release is the merged checkpoint from the `v6` pipeline built with **Unsloth + TRL + AMD ROCm**. The tutorial version of the workflow uses a much smaller streamed subset for fast reproduction; this model is the larger production-oriented artifact.
|
||||||
|
|
||||||
|
## What makes v6 different
|
||||||
|
|
||||||
|
- It learns directly from messy raw replay logs instead of hand-written state summaries.
|
||||||
|
- It targets a strict action format suitable for agent pipelines: `move [move-name]` or `switch [pokemon-name]`.
|
||||||
|
- It was developed around AMD ROCm workflows, with `bfloat16` recommended for stable inference.
|
||||||
|
|
||||||
|
## Official notebook
|
||||||
|
|
||||||
|
Use the cleaned release notebook `pokemon_showdown_agent_v6_release.ipynb` for the reproducible tutorial flow.
|
||||||
|
|
||||||
|
## Intended use
|
||||||
|
|
||||||
|
Use this model when you want to:
|
||||||
|
|
||||||
|
- Predict the next action from a raw Pokemon Showdown log prefix.
|
||||||
|
- Build a text-only battle agent or evaluation harness.
|
||||||
|
- Study agent alignment from real replay trajectories.
|
||||||
|
|
||||||
|
This model is **not** a full simulator policy by itself. For ladder play or automated battle loops, you still need legality checks, environment wrappers, and battle-state management outside the model.
|
||||||
|
|
||||||
|
## Prompt format
|
||||||
|
|
||||||
|
The model expects a chat-style prompt with:
|
||||||
|
|
||||||
|
- A `system` message specifying which side the model is playing as.
|
||||||
|
- A `user` message containing the raw replay log prefix up to the current turn marker.
|
||||||
|
|
||||||
|
Recommended system prompt:
|
||||||
|
|
||||||
|
```text
|
||||||
|
You are a Pokemon Showdown battle AI. You play as {side}. Given the battle log, output your next action. Format: move <name> OR switch <name>. Append terastallize if you terastallize this turn.
|
||||||
|
```
|
||||||
|
|
||||||
|
## Quick start
|
||||||
|
|
||||||
|
```python
|
||||||
|
import torch
|
||||||
|
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||||
|
|
||||||
|
repo_id = "GoldenGrapeGentleman1/pokemon-showdown-agent-v6"
|
||||||
|
|
||||||
|
tokenizer = AutoTokenizer.from_pretrained(repo_id)
|
||||||
|
model = AutoModelForCausalLM.from_pretrained(
|
||||||
|
repo_id,
|
||||||
|
torch_dtype=torch.bfloat16,
|
||||||
|
device_map="auto",
|
||||||
|
)
|
||||||
|
|
||||||
|
messages = [
|
||||||
|
{
|
||||||
|
"role": "system",
|
||||||
|
"content": (
|
||||||
|
"You are a Pokemon Showdown battle AI. You play as p2. "
|
||||||
|
"Given the battle log, output your next action. "
|
||||||
|
"Format: move <name> OR switch <name>. "
|
||||||
|
"Append terastallize if you terastallize this turn."
|
||||||
|
),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": (
|
||||||
|
"|player|p1|Player1|266|1500\n"
|
||||||
|
"|player|p2|Player2|1|1500\n"
|
||||||
|
"|teamsize|p1|6\n"
|
||||||
|
"|teamsize|p2|6\n"
|
||||||
|
"|gen|9\n"
|
||||||
|
"|tier|[Gen 9] OU\n"
|
||||||
|
"|\n"
|
||||||
|
"|start\n"
|
||||||
|
"|switch|p1a: Garchomp|Garchomp, M|100/100\n"
|
||||||
|
"|switch|p2a: Corviknight|Corviknight, M|100/100\n"
|
||||||
|
"|turn|1\n"
|
||||||
|
"|move|p1a: Garchomp|Earthquake|p2a: Corviknight\n"
|
||||||
|
"|-immune|p2a: Corviknight\n"
|
||||||
|
"|turn|2"
|
||||||
|
),
|
||||||
|
},
|
||||||
|
]
|
||||||
|
|
||||||
|
text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
|
||||||
|
inputs = tokenizer(text, return_tensors="pt").to(model.device)
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
outputs = model.generate(
|
||||||
|
**inputs,
|
||||||
|
max_new_tokens=64,
|
||||||
|
do_sample=False,
|
||||||
|
temperature=0.1,
|
||||||
|
pad_token_id=tokenizer.eos_token_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
decoded = tokenizer.decode(outputs[0], skip_special_tokens=False)
|
||||||
|
response = decoded.split("<|im_start|>assistant\n")[-1].replace("<|im_end|>", "").strip()
|
||||||
|
if response.startswith("<think>"):
|
||||||
|
response = response.split("</think>", 1)[-1].strip()
|
||||||
|
print(response)
|
||||||
|
```
|
||||||
|
|
||||||
|
## Training data
|
||||||
|
|
||||||
|
The full `v6` preprocessing pipeline was built from the public dataset:
|
||||||
|
|
||||||
|
- Source dataset: [`milkkarten/pokemon-showdown-replays-merged`](https://huggingface.co/datasets/milkkarten/pokemon-showdown-replays-merged)
|
||||||
|
|
||||||
|
Project preprocessing summary:
|
||||||
|
|
||||||
|
- `100,000` train games
|
||||||
|
- `10,000` test games
|
||||||
|
- `2,330,115` train samples
|
||||||
|
- `236,349` test samples
|
||||||
|
- `min_rating = 1200`
|
||||||
|
- `max_chars = 12000`
|
||||||
|
|
||||||
|
The companion tutorial notebook uses a smaller streamed subset with a higher rating filter so readers can reproduce the workflow quickly without downloading the full corpus.
|
||||||
|
|
||||||
|
## Training recipe
|
||||||
|
|
||||||
|
- Base model: `Qwen/Qwen3-4B`
|
||||||
|
- Fine-tuning method: LoRA SFT with Unsloth
|
||||||
|
- LoRA rank / alpha: `64 / 128`
|
||||||
|
- Full training context length: up to `4096`
|
||||||
|
- Frameworks: Unsloth, TRL, Transformers, Datasets
|
||||||
|
- Deployment recommendation on AMD: prefer `bfloat16` inference for stability
|
||||||
|
|
||||||
|
## Limitations
|
||||||
|
|
||||||
|
- This is a research checkpoint, not a complete battle engine.
|
||||||
|
- The model can still produce illegal or strategically weak actions.
|
||||||
|
- Prompt wording matters; changing the system format can reduce output reliability.
|
||||||
|
- Included evaluation artifacts are sanity checks, not a full competitive benchmark.
|
||||||
|
|
||||||
|
## Acknowledgements
|
||||||
|
|
||||||
|
- [Unsloth](https://github.com/unslothai/unsloth)
|
||||||
|
- [TRL](https://github.com/huggingface/trl)
|
||||||
|
- [Qwen](https://huggingface.co/Qwen)
|
||||||
|
- [Pokemon Showdown](https://pokemonshowdown.com/)
|
||||||
|
|
||||||
|
## Citation
|
||||||
|
|
||||||
|
If you build on this work, please cite the upstream tooling as well:
|
||||||
|
|
||||||
|
```bibtex
|
||||||
|
@misc{vonwerra2022trl,
|
||||||
|
title = {{TRL: Transformer Reinforcement Learning}},
|
||||||
|
author = {Leandro von Werra and Younes Belkada and Lewis Tunstall and Edward Beeching and Tristan Thrush and Nathan Lambert and Shengyi Huang and Kashif Rasul and Quentin Gallou{\'e}dec},
|
||||||
|
year = 2020,
|
||||||
|
journal = {GitHub repository},
|
||||||
|
publisher = {GitHub},
|
||||||
|
howpublished = {\url{https://github.com/huggingface/trl}}
|
||||||
|
}
|
||||||
|
```
|
||||||
28
added_tokens.json
Normal file
28
added_tokens.json
Normal file
@@ -0,0 +1,28 @@
|
|||||||
|
{
|
||||||
|
"</think>": 151668,
|
||||||
|
"</tool_call>": 151658,
|
||||||
|
"</tool_response>": 151666,
|
||||||
|
"<think>": 151667,
|
||||||
|
"<tool_call>": 151657,
|
||||||
|
"<tool_response>": 151665,
|
||||||
|
"<|box_end|>": 151649,
|
||||||
|
"<|box_start|>": 151648,
|
||||||
|
"<|endoftext|>": 151643,
|
||||||
|
"<|file_sep|>": 151664,
|
||||||
|
"<|fim_middle|>": 151660,
|
||||||
|
"<|fim_pad|>": 151662,
|
||||||
|
"<|fim_prefix|>": 151659,
|
||||||
|
"<|fim_suffix|>": 151661,
|
||||||
|
"<|im_end|>": 151645,
|
||||||
|
"<|im_start|>": 151644,
|
||||||
|
"<|image_pad|>": 151655,
|
||||||
|
"<|object_ref_end|>": 151647,
|
||||||
|
"<|object_ref_start|>": 151646,
|
||||||
|
"<|quad_end|>": 151651,
|
||||||
|
"<|quad_start|>": 151650,
|
||||||
|
"<|repo_name|>": 151663,
|
||||||
|
"<|video_pad|>": 151656,
|
||||||
|
"<|vision_end|>": 151653,
|
||||||
|
"<|vision_pad|>": 151654,
|
||||||
|
"<|vision_start|>": 151652
|
||||||
|
}
|
||||||
144
chat_template.jinja
Normal file
144
chat_template.jinja
Normal file
@@ -0,0 +1,144 @@
|
|||||||
|
|
||||||
|
{%- if tools %}
|
||||||
|
{{- '<|im_start|>system
|
||||||
|
' }}
|
||||||
|
{%- if messages[0].role == 'system' %}
|
||||||
|
{{- messages[0].content + '
|
||||||
|
|
||||||
|
' }}
|
||||||
|
{%- endif %}
|
||||||
|
{{- "# Tools
|
||||||
|
|
||||||
|
You may call one or more functions to assist with the user query.
|
||||||
|
|
||||||
|
You are provided with function signatures within <tools></tools> XML tags:
|
||||||
|
<tools>" }}
|
||||||
|
{%- for tool in tools %}
|
||||||
|
{{- "
|
||||||
|
" }}
|
||||||
|
{{- tool | tojson }}
|
||||||
|
{%- endfor %}
|
||||||
|
{{- "
|
||||||
|
</tools>
|
||||||
|
|
||||||
|
For each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:
|
||||||
|
<tool_call>
|
||||||
|
{\"name\": <function-name>, \"arguments\": <args-json-object>}
|
||||||
|
</tool_call><|im_end|>
|
||||||
|
" }}
|
||||||
|
{%- else %}
|
||||||
|
{%- if messages[0].role == 'system' %}
|
||||||
|
{{- '<|im_start|>system
|
||||||
|
' + messages[0].content + '<|im_end|>
|
||||||
|
' }}
|
||||||
|
{%- endif %}
|
||||||
|
{%- endif %}
|
||||||
|
{%- set ns = namespace(multi_step_tool=true, last_query_index=messages|length - 1) %}
|
||||||
|
{%- for forward_message in messages %}
|
||||||
|
{%- set index = (messages|length - 1) - loop.index0 %}
|
||||||
|
{%- set message = messages[index] %}
|
||||||
|
{%- set current_content = message.content if message.content is not none else '' %}
|
||||||
|
{%- set tool_start = '<tool_response>' %}
|
||||||
|
{%- set tool_start_length = tool_start|length %}
|
||||||
|
{%- set start_of_message = current_content[:tool_start_length] %}
|
||||||
|
{%- set tool_end = '</tool_response>' %}
|
||||||
|
{%- set tool_end_length = tool_end|length %}
|
||||||
|
{%- set start_pos = (current_content|length) - tool_end_length %}
|
||||||
|
{%- if start_pos < 0 %}
|
||||||
|
{%- set start_pos = 0 %}
|
||||||
|
{%- endif %}
|
||||||
|
{%- set end_of_message = current_content[start_pos:] %}
|
||||||
|
{%- if ns.multi_step_tool and message.role == "user" and not(start_of_message == tool_start and end_of_message == tool_end) %}
|
||||||
|
{%- set ns.multi_step_tool = false %}
|
||||||
|
{%- set ns.last_query_index = index %}
|
||||||
|
{%- endif %}
|
||||||
|
{%- endfor %}
|
||||||
|
{%- for message in messages %}
|
||||||
|
{%- if (message.role == "user") or (message.role == "system" and not loop.first) %}
|
||||||
|
{{- '<|im_start|>' + message.role + '
|
||||||
|
' + message.content + '<|im_end|>' + '
|
||||||
|
' }}
|
||||||
|
{%- elif message.role == "assistant" %}
|
||||||
|
{%- set content = message.content %}
|
||||||
|
{%- set reasoning_content = '' %}
|
||||||
|
{%- if message.reasoning_content is defined and message.reasoning_content is not none %}
|
||||||
|
{%- set reasoning_content = message.reasoning_content %}
|
||||||
|
{%- else %}
|
||||||
|
{%- if '</think>' in message.content %}
|
||||||
|
{%- set content = (message.content.split('</think>')|last).lstrip('
|
||||||
|
') %}
|
||||||
|
{%- set reasoning_content = (message.content.split('</think>')|first).rstrip('
|
||||||
|
') %}
|
||||||
|
{%- set reasoning_content = (reasoning_content.split('<think>')|last).lstrip('
|
||||||
|
') %}
|
||||||
|
{%- endif %}
|
||||||
|
{%- endif %}
|
||||||
|
{%- if loop.index0 > ns.last_query_index %}
|
||||||
|
{%- if loop.last or (not loop.last and reasoning_content) %}
|
||||||
|
{{- '<|im_start|>' + message.role + '
|
||||||
|
<think>
|
||||||
|
' + reasoning_content.strip('
|
||||||
|
') + '
|
||||||
|
</think>
|
||||||
|
|
||||||
|
' + content.lstrip('
|
||||||
|
') }}
|
||||||
|
{%- else %}
|
||||||
|
{{- '<|im_start|>' + message.role + '
|
||||||
|
' + content }}
|
||||||
|
{%- endif %}
|
||||||
|
{%- else %}
|
||||||
|
{{- '<|im_start|>' + message.role + '
|
||||||
|
' + content }}
|
||||||
|
{%- endif %}
|
||||||
|
{%- if message.tool_calls %}
|
||||||
|
{%- for tool_call in message.tool_calls %}
|
||||||
|
{%- if (loop.first and content) or (not loop.first) %}
|
||||||
|
{{- '
|
||||||
|
' }}
|
||||||
|
{%- endif %}
|
||||||
|
{%- if tool_call.function %}
|
||||||
|
{%- set tool_call = tool_call.function %}
|
||||||
|
{%- endif %}
|
||||||
|
{{- '<tool_call>
|
||||||
|
{"name": "' }}
|
||||||
|
{{- tool_call.name }}
|
||||||
|
{{- '", "arguments": ' }}
|
||||||
|
{%- if tool_call.arguments is string %}
|
||||||
|
{{- tool_call.arguments }}
|
||||||
|
{%- else %}
|
||||||
|
{{- tool_call.arguments | tojson }}
|
||||||
|
{%- endif %}
|
||||||
|
{{- '}
|
||||||
|
</tool_call>' }}
|
||||||
|
{%- endfor %}
|
||||||
|
{%- endif %}
|
||||||
|
{{- '<|im_end|>
|
||||||
|
' }}
|
||||||
|
{%- elif message.role == "tool" %}
|
||||||
|
{%- if loop.first or (messages[loop.index0 - 1].role != "tool") %}
|
||||||
|
{{- '<|im_start|>user' }}
|
||||||
|
{%- endif %}
|
||||||
|
{{- '
|
||||||
|
<tool_response>
|
||||||
|
' }}
|
||||||
|
{{- message.content }}
|
||||||
|
{{- '
|
||||||
|
</tool_response>' }}
|
||||||
|
{%- if loop.last or (messages[loop.index0 + 1].role != "tool") %}
|
||||||
|
{{- '<|im_end|>
|
||||||
|
' }}
|
||||||
|
{%- endif %}
|
||||||
|
{%- endif %}
|
||||||
|
{%- endfor %}
|
||||||
|
{%- if add_generation_prompt %}
|
||||||
|
{{- '<|im_start|>assistant
|
||||||
|
' }}
|
||||||
|
{%- if enable_thinking is defined and enable_thinking is false %}
|
||||||
|
{{- '<think>
|
||||||
|
|
||||||
|
</think>
|
||||||
|
|
||||||
|
' }}
|
||||||
|
{%- endif %}
|
||||||
|
{%- endif %}
|
||||||
70
config.json
Normal file
70
config.json
Normal file
@@ -0,0 +1,70 @@
|
|||||||
|
{
|
||||||
|
"architectures": [
|
||||||
|
"Qwen3ForCausalLM"
|
||||||
|
],
|
||||||
|
"attention_bias": false,
|
||||||
|
"attention_dropout": 0.0,
|
||||||
|
"torch_dtype": "bfloat16",
|
||||||
|
"eos_token_id": 151645,
|
||||||
|
"head_dim": 128,
|
||||||
|
"hidden_act": "silu",
|
||||||
|
"hidden_size": 2560,
|
||||||
|
"initializer_range": 0.02,
|
||||||
|
"intermediate_size": 9728,
|
||||||
|
"layer_types": [
|
||||||
|
"full_attention",
|
||||||
|
"full_attention",
|
||||||
|
"full_attention",
|
||||||
|
"full_attention",
|
||||||
|
"full_attention",
|
||||||
|
"full_attention",
|
||||||
|
"full_attention",
|
||||||
|
"full_attention",
|
||||||
|
"full_attention",
|
||||||
|
"full_attention",
|
||||||
|
"full_attention",
|
||||||
|
"full_attention",
|
||||||
|
"full_attention",
|
||||||
|
"full_attention",
|
||||||
|
"full_attention",
|
||||||
|
"full_attention",
|
||||||
|
"full_attention",
|
||||||
|
"full_attention",
|
||||||
|
"full_attention",
|
||||||
|
"full_attention",
|
||||||
|
"full_attention",
|
||||||
|
"full_attention",
|
||||||
|
"full_attention",
|
||||||
|
"full_attention",
|
||||||
|
"full_attention",
|
||||||
|
"full_attention",
|
||||||
|
"full_attention",
|
||||||
|
"full_attention",
|
||||||
|
"full_attention",
|
||||||
|
"full_attention",
|
||||||
|
"full_attention",
|
||||||
|
"full_attention",
|
||||||
|
"full_attention",
|
||||||
|
"full_attention",
|
||||||
|
"full_attention",
|
||||||
|
"full_attention"
|
||||||
|
],
|
||||||
|
"max_position_embeddings": 40960,
|
||||||
|
"max_window_layers": 36,
|
||||||
|
"model_type": "qwen3",
|
||||||
|
"num_attention_heads": 32,
|
||||||
|
"num_hidden_layers": 36,
|
||||||
|
"num_key_value_heads": 8,
|
||||||
|
"pad_token_id": 151654,
|
||||||
|
"rms_norm_eps": 1e-06,
|
||||||
|
"rope_scaling": null,
|
||||||
|
"rope_theta": 1000000,
|
||||||
|
"sliding_window": null,
|
||||||
|
"tie_word_embeddings": true,
|
||||||
|
"transformers_version": "4.57.6",
|
||||||
|
"unsloth_fixed": true,
|
||||||
|
"unsloth_version": "2026.1.4",
|
||||||
|
"use_cache": true,
|
||||||
|
"use_sliding_window": false,
|
||||||
|
"vocab_size": 151936
|
||||||
|
}
|
||||||
151388
merges.txt
Normal file
151388
merges.txt
Normal file
File diff suppressed because it is too large
Load Diff
3
model-00001-of-00002.safetensors
Normal file
3
model-00001-of-00002.safetensors
Normal file
@@ -0,0 +1,3 @@
|
|||||||
|
version https://git-lfs.github.com/spec/v1
|
||||||
|
oid sha256:3f614af718e39ac60bb53e6b488a837ddd974bb38707ff2266020447b23a1599
|
||||||
|
size 4967215360
|
||||||
3
model-00002-of-00002.safetensors
Normal file
3
model-00002-of-00002.safetensors
Normal file
@@ -0,0 +1,3 @@
|
|||||||
|
version https://git-lfs.github.com/spec/v1
|
||||||
|
oid sha256:f94aba14456057cf3d20b43a3b0dd76656942fa4a042cfa9e60f4bcc8b1cd6f2
|
||||||
|
size 3077766632
|
||||||
405
model.safetensors.index.json
Normal file
405
model.safetensors.index.json
Normal file
@@ -0,0 +1,405 @@
|
|||||||
|
{
|
||||||
|
"metadata": {
|
||||||
|
"total_size": 8044936192
|
||||||
|
},
|
||||||
|
"weight_map": {
|
||||||
|
"model.embed_tokens.weight": "model-00001-of-00002.safetensors",
|
||||||
|
"model.layers.0.input_layernorm.weight": "model-00001-of-00002.safetensors",
|
||||||
|
"model.layers.0.mlp.down_proj.weight": "model-00001-of-00002.safetensors",
|
||||||
|
"model.layers.0.mlp.gate_proj.weight": "model-00001-of-00002.safetensors",
|
||||||
|
"model.layers.0.mlp.up_proj.weight": "model-00001-of-00002.safetensors",
|
||||||
|
"model.layers.0.post_attention_layernorm.weight": "model-00001-of-00002.safetensors",
|
||||||
|
"model.layers.0.self_attn.k_norm.weight": "model-00001-of-00002.safetensors",
|
||||||
|
"model.layers.0.self_attn.k_proj.weight": "model-00001-of-00002.safetensors",
|
||||||
|
"model.layers.0.self_attn.o_proj.weight": "model-00001-of-00002.safetensors",
|
||||||
|
"model.layers.0.self_attn.q_norm.weight": "model-00001-of-00002.safetensors",
|
||||||
|
"model.layers.0.self_attn.q_proj.weight": "model-00001-of-00002.safetensors",
|
||||||
|
"model.layers.0.self_attn.v_proj.weight": "model-00001-of-00002.safetensors",
|
||||||
|
"model.layers.1.input_layernorm.weight": "model-00001-of-00002.safetensors",
|
||||||
|
"model.layers.1.mlp.down_proj.weight": "model-00001-of-00002.safetensors",
|
||||||
|
"model.layers.1.mlp.gate_proj.weight": "model-00001-of-00002.safetensors",
|
||||||
|
"model.layers.1.mlp.up_proj.weight": "model-00001-of-00002.safetensors",
|
||||||
|
"model.layers.1.post_attention_layernorm.weight": "model-00001-of-00002.safetensors",
|
||||||
|
"model.layers.1.self_attn.k_norm.weight": "model-00001-of-00002.safetensors",
|
||||||
|
"model.layers.1.self_attn.k_proj.weight": "model-00001-of-00002.safetensors",
|
||||||
|
"model.layers.1.self_attn.o_proj.weight": "model-00001-of-00002.safetensors",
|
||||||
|
"model.layers.1.self_attn.q_norm.weight": "model-00001-of-00002.safetensors",
|
||||||
|
"model.layers.1.self_attn.q_proj.weight": "model-00001-of-00002.safetensors",
|
||||||
|
"model.layers.1.self_attn.v_proj.weight": "model-00001-of-00002.safetensors",
|
||||||
|
"model.layers.10.input_layernorm.weight": "model-00001-of-00002.safetensors",
|
||||||
|
"model.layers.10.mlp.down_proj.weight": "model-00001-of-00002.safetensors",
|
||||||
|
"model.layers.10.mlp.gate_proj.weight": "model-00001-of-00002.safetensors",
|
||||||
|
"model.layers.10.mlp.up_proj.weight": "model-00001-of-00002.safetensors",
|
||||||
|
"model.layers.10.post_attention_layernorm.weight": "model-00001-of-00002.safetensors",
|
||||||
|
"model.layers.10.self_attn.k_norm.weight": "model-00001-of-00002.safetensors",
|
||||||
|
"model.layers.10.self_attn.k_proj.weight": "model-00001-of-00002.safetensors",
|
||||||
|
"model.layers.10.self_attn.o_proj.weight": "model-00001-of-00002.safetensors",
|
||||||
|
"model.layers.10.self_attn.q_norm.weight": "model-00001-of-00002.safetensors",
|
||||||
|
"model.layers.10.self_attn.q_proj.weight": "model-00001-of-00002.safetensors",
|
||||||
|
"model.layers.10.self_attn.v_proj.weight": "model-00001-of-00002.safetensors",
|
||||||
|
"model.layers.11.input_layernorm.weight": "model-00001-of-00002.safetensors",
|
||||||
|
"model.layers.11.mlp.down_proj.weight": "model-00001-of-00002.safetensors",
|
||||||
|
"model.layers.11.mlp.gate_proj.weight": "model-00001-of-00002.safetensors",
|
||||||
|
"model.layers.11.mlp.up_proj.weight": "model-00001-of-00002.safetensors",
|
||||||
|
"model.layers.11.post_attention_layernorm.weight": "model-00001-of-00002.safetensors",
|
||||||
|
"model.layers.11.self_attn.k_norm.weight": "model-00001-of-00002.safetensors",
|
||||||
|
"model.layers.11.self_attn.k_proj.weight": "model-00001-of-00002.safetensors",
|
||||||
|
"model.layers.11.self_attn.o_proj.weight": "model-00001-of-00002.safetensors",
|
||||||
|
"model.layers.11.self_attn.q_norm.weight": "model-00001-of-00002.safetensors",
|
||||||
|
"model.layers.11.self_attn.q_proj.weight": "model-00001-of-00002.safetensors",
|
||||||
|
"model.layers.11.self_attn.v_proj.weight": "model-00001-of-00002.safetensors",
|
||||||
|
"model.layers.12.input_layernorm.weight": "model-00001-of-00002.safetensors",
|
||||||
|
"model.layers.12.mlp.down_proj.weight": "model-00001-of-00002.safetensors",
|
||||||
|
"model.layers.12.mlp.gate_proj.weight": "model-00001-of-00002.safetensors",
|
||||||
|
"model.layers.12.mlp.up_proj.weight": "model-00001-of-00002.safetensors",
|
||||||
|
"model.layers.12.post_attention_layernorm.weight": "model-00001-of-00002.safetensors",
|
||||||
|
"model.layers.12.self_attn.k_norm.weight": "model-00001-of-00002.safetensors",
|
||||||
|
"model.layers.12.self_attn.k_proj.weight": "model-00001-of-00002.safetensors",
|
||||||
|
"model.layers.12.self_attn.o_proj.weight": "model-00001-of-00002.safetensors",
|
||||||
|
"model.layers.12.self_attn.q_norm.weight": "model-00001-of-00002.safetensors",
|
||||||
|
"model.layers.12.self_attn.q_proj.weight": "model-00001-of-00002.safetensors",
|
||||||
|
"model.layers.12.self_attn.v_proj.weight": "model-00001-of-00002.safetensors",
|
||||||
|
"model.layers.13.input_layernorm.weight": "model-00001-of-00002.safetensors",
|
||||||
|
"model.layers.13.mlp.down_proj.weight": "model-00001-of-00002.safetensors",
|
||||||
|
"model.layers.13.mlp.gate_proj.weight": "model-00001-of-00002.safetensors",
|
||||||
|
"model.layers.13.mlp.up_proj.weight": "model-00001-of-00002.safetensors",
|
||||||
|
"model.layers.13.post_attention_layernorm.weight": "model-00001-of-00002.safetensors",
|
||||||
|
"model.layers.13.self_attn.k_norm.weight": "model-00001-of-00002.safetensors",
|
||||||
|
"model.layers.13.self_attn.k_proj.weight": "model-00001-of-00002.safetensors",
|
||||||
|
"model.layers.13.self_attn.o_proj.weight": "model-00001-of-00002.safetensors",
|
||||||
|
"model.layers.13.self_attn.q_norm.weight": "model-00001-of-00002.safetensors",
|
||||||
|
"model.layers.13.self_attn.q_proj.weight": "model-00001-of-00002.safetensors",
|
||||||
|
"model.layers.13.self_attn.v_proj.weight": "model-00001-of-00002.safetensors",
|
||||||
|
"model.layers.14.input_layernorm.weight": "model-00001-of-00002.safetensors",
|
||||||
|
"model.layers.14.mlp.down_proj.weight": "model-00001-of-00002.safetensors",
|
||||||
|
"model.layers.14.mlp.gate_proj.weight": "model-00001-of-00002.safetensors",
|
||||||
|
"model.layers.14.mlp.up_proj.weight": "model-00001-of-00002.safetensors",
|
||||||
|
"model.layers.14.post_attention_layernorm.weight": "model-00001-of-00002.safetensors",
|
||||||
|
"model.layers.14.self_attn.k_norm.weight": "model-00001-of-00002.safetensors",
|
||||||
|
"model.layers.14.self_attn.k_proj.weight": "model-00001-of-00002.safetensors",
|
||||||
|
"model.layers.14.self_attn.o_proj.weight": "model-00001-of-00002.safetensors",
|
||||||
|
"model.layers.14.self_attn.q_norm.weight": "model-00001-of-00002.safetensors",
|
||||||
|
"model.layers.14.self_attn.q_proj.weight": "model-00001-of-00002.safetensors",
|
||||||
|
"model.layers.14.self_attn.v_proj.weight": "model-00001-of-00002.safetensors",
|
||||||
|
"model.layers.15.input_layernorm.weight": "model-00001-of-00002.safetensors",
|
||||||
|
"model.layers.15.mlp.down_proj.weight": "model-00001-of-00002.safetensors",
|
||||||
|
"model.layers.15.mlp.gate_proj.weight": "model-00001-of-00002.safetensors",
|
||||||
|
"model.layers.15.mlp.up_proj.weight": "model-00001-of-00002.safetensors",
|
||||||
|
"model.layers.15.post_attention_layernorm.weight": "model-00001-of-00002.safetensors",
|
||||||
|
"model.layers.15.self_attn.k_norm.weight": "model-00001-of-00002.safetensors",
|
||||||
|
"model.layers.15.self_attn.k_proj.weight": "model-00001-of-00002.safetensors",
|
||||||
|
"model.layers.15.self_attn.o_proj.weight": "model-00001-of-00002.safetensors",
|
||||||
|
"model.layers.15.self_attn.q_norm.weight": "model-00001-of-00002.safetensors",
|
||||||
|
"model.layers.15.self_attn.q_proj.weight": "model-00001-of-00002.safetensors",
|
||||||
|
"model.layers.15.self_attn.v_proj.weight": "model-00001-of-00002.safetensors",
|
||||||
|
"model.layers.16.input_layernorm.weight": "model-00001-of-00002.safetensors",
|
||||||
|
"model.layers.16.mlp.down_proj.weight": "model-00001-of-00002.safetensors",
|
||||||
|
"model.layers.16.mlp.gate_proj.weight": "model-00001-of-00002.safetensors",
|
||||||
|
"model.layers.16.mlp.up_proj.weight": "model-00001-of-00002.safetensors",
|
||||||
|
"model.layers.16.post_attention_layernorm.weight": "model-00001-of-00002.safetensors",
|
||||||
|
"model.layers.16.self_attn.k_norm.weight": "model-00001-of-00002.safetensors",
|
||||||
|
"model.layers.16.self_attn.k_proj.weight": "model-00001-of-00002.safetensors",
|
||||||
|
"model.layers.16.self_attn.o_proj.weight": "model-00001-of-00002.safetensors",
|
||||||
|
"model.layers.16.self_attn.q_norm.weight": "model-00001-of-00002.safetensors",
|
||||||
|
"model.layers.16.self_attn.q_proj.weight": "model-00001-of-00002.safetensors",
|
||||||
|
"model.layers.16.self_attn.v_proj.weight": "model-00001-of-00002.safetensors",
|
||||||
|
"model.layers.17.input_layernorm.weight": "model-00001-of-00002.safetensors",
|
||||||
|
"model.layers.17.mlp.down_proj.weight": "model-00001-of-00002.safetensors",
|
||||||
|
"model.layers.17.mlp.gate_proj.weight": "model-00001-of-00002.safetensors",
|
||||||
|
"model.layers.17.mlp.up_proj.weight": "model-00001-of-00002.safetensors",
|
||||||
|
"model.layers.17.post_attention_layernorm.weight": "model-00001-of-00002.safetensors",
|
||||||
|
"model.layers.17.self_attn.k_norm.weight": "model-00001-of-00002.safetensors",
|
||||||
|
"model.layers.17.self_attn.k_proj.weight": "model-00001-of-00002.safetensors",
|
||||||
|
"model.layers.17.self_attn.o_proj.weight": "model-00001-of-00002.safetensors",
|
||||||
|
"model.layers.17.self_attn.q_norm.weight": "model-00001-of-00002.safetensors",
|
||||||
|
"model.layers.17.self_attn.q_proj.weight": "model-00001-of-00002.safetensors",
|
||||||
|
"model.layers.17.self_attn.v_proj.weight": "model-00001-of-00002.safetensors",
|
||||||
|
"model.layers.18.input_layernorm.weight": "model-00001-of-00002.safetensors",
|
||||||
|
"model.layers.18.mlp.down_proj.weight": "model-00001-of-00002.safetensors",
|
||||||
|
"model.layers.18.mlp.gate_proj.weight": "model-00001-of-00002.safetensors",
|
||||||
|
"model.layers.18.mlp.up_proj.weight": "model-00001-of-00002.safetensors",
|
||||||
|
"model.layers.18.post_attention_layernorm.weight": "model-00001-of-00002.safetensors",
|
||||||
|
"model.layers.18.self_attn.k_norm.weight": "model-00001-of-00002.safetensors",
|
||||||
|
"model.layers.18.self_attn.k_proj.weight": "model-00001-of-00002.safetensors",
|
||||||
|
"model.layers.18.self_attn.o_proj.weight": "model-00001-of-00002.safetensors",
|
||||||
|
"model.layers.18.self_attn.q_norm.weight": "model-00001-of-00002.safetensors",
|
||||||
|
"model.layers.18.self_attn.q_proj.weight": "model-00001-of-00002.safetensors",
|
||||||
|
"model.layers.18.self_attn.v_proj.weight": "model-00001-of-00002.safetensors",
|
||||||
|
"model.layers.19.input_layernorm.weight": "model-00001-of-00002.safetensors",
|
||||||
|
"model.layers.19.mlp.down_proj.weight": "model-00001-of-00002.safetensors",
|
||||||
|
"model.layers.19.mlp.gate_proj.weight": "model-00001-of-00002.safetensors",
|
||||||
|
"model.layers.19.mlp.up_proj.weight": "model-00001-of-00002.safetensors",
|
||||||
|
"model.layers.19.post_attention_layernorm.weight": "model-00001-of-00002.safetensors",
|
||||||
|
"model.layers.19.self_attn.k_norm.weight": "model-00001-of-00002.safetensors",
|
||||||
|
"model.layers.19.self_attn.k_proj.weight": "model-00001-of-00002.safetensors",
|
||||||
|
"model.layers.19.self_attn.o_proj.weight": "model-00001-of-00002.safetensors",
|
||||||
|
"model.layers.19.self_attn.q_norm.weight": "model-00001-of-00002.safetensors",
|
||||||
|
"model.layers.19.self_attn.q_proj.weight": "model-00001-of-00002.safetensors",
|
||||||
|
"model.layers.19.self_attn.v_proj.weight": "model-00001-of-00002.safetensors",
|
||||||
|
"model.layers.2.input_layernorm.weight": "model-00001-of-00002.safetensors",
|
||||||
|
"model.layers.2.mlp.down_proj.weight": "model-00001-of-00002.safetensors",
|
||||||
|
"model.layers.2.mlp.gate_proj.weight": "model-00001-of-00002.safetensors",
|
||||||
|
"model.layers.2.mlp.up_proj.weight": "model-00001-of-00002.safetensors",
|
||||||
|
"model.layers.2.post_attention_layernorm.weight": "model-00001-of-00002.safetensors",
|
||||||
|
"model.layers.2.self_attn.k_norm.weight": "model-00001-of-00002.safetensors",
|
||||||
|
"model.layers.2.self_attn.k_proj.weight": "model-00001-of-00002.safetensors",
|
||||||
|
"model.layers.2.self_attn.o_proj.weight": "model-00001-of-00002.safetensors",
|
||||||
|
"model.layers.2.self_attn.q_norm.weight": "model-00001-of-00002.safetensors",
|
||||||
|
"model.layers.2.self_attn.q_proj.weight": "model-00001-of-00002.safetensors",
|
||||||
|
"model.layers.2.self_attn.v_proj.weight": "model-00001-of-00002.safetensors",
|
||||||
|
"model.layers.20.input_layernorm.weight": "model-00002-of-00002.safetensors",
|
||||||
|
"model.layers.20.mlp.down_proj.weight": "model-00002-of-00002.safetensors",
|
||||||
|
"model.layers.20.mlp.gate_proj.weight": "model-00001-of-00002.safetensors",
|
||||||
|
"model.layers.20.mlp.up_proj.weight": "model-00001-of-00002.safetensors",
|
||||||
|
"model.layers.20.post_attention_layernorm.weight": "model-00002-of-00002.safetensors",
|
||||||
|
"model.layers.20.self_attn.k_norm.weight": "model-00001-of-00002.safetensors",
|
||||||
|
"model.layers.20.self_attn.k_proj.weight": "model-00001-of-00002.safetensors",
|
||||||
|
"model.layers.20.self_attn.o_proj.weight": "model-00001-of-00002.safetensors",
|
||||||
|
"model.layers.20.self_attn.q_norm.weight": "model-00001-of-00002.safetensors",
|
||||||
|
"model.layers.20.self_attn.q_proj.weight": "model-00001-of-00002.safetensors",
|
||||||
|
"model.layers.20.self_attn.v_proj.weight": "model-00001-of-00002.safetensors",
|
||||||
|
"model.layers.21.input_layernorm.weight": "model-00002-of-00002.safetensors",
|
||||||
|
"model.layers.21.mlp.down_proj.weight": "model-00002-of-00002.safetensors",
|
||||||
|
"model.layers.21.mlp.gate_proj.weight": "model-00002-of-00002.safetensors",
|
||||||
|
"model.layers.21.mlp.up_proj.weight": "model-00002-of-00002.safetensors",
|
||||||
|
"model.layers.21.post_attention_layernorm.weight": "model-00002-of-00002.safetensors",
|
||||||
|
"model.layers.21.self_attn.k_norm.weight": "model-00002-of-00002.safetensors",
|
||||||
|
"model.layers.21.self_attn.k_proj.weight": "model-00002-of-00002.safetensors",
|
||||||
|
"model.layers.21.self_attn.o_proj.weight": "model-00002-of-00002.safetensors",
|
||||||
|
"model.layers.21.self_attn.q_norm.weight": "model-00002-of-00002.safetensors",
|
||||||
|
"model.layers.21.self_attn.q_proj.weight": "model-00002-of-00002.safetensors",
|
||||||
|
"model.layers.21.self_attn.v_proj.weight": "model-00002-of-00002.safetensors",
|
||||||
|
"model.layers.22.input_layernorm.weight": "model-00002-of-00002.safetensors",
|
||||||
|
"model.layers.22.mlp.down_proj.weight": "model-00002-of-00002.safetensors",
|
||||||
|
"model.layers.22.mlp.gate_proj.weight": "model-00002-of-00002.safetensors",
|
||||||
|
"model.layers.22.mlp.up_proj.weight": "model-00002-of-00002.safetensors",
|
||||||
|
"model.layers.22.post_attention_layernorm.weight": "model-00002-of-00002.safetensors",
|
||||||
|
"model.layers.22.self_attn.k_norm.weight": "model-00002-of-00002.safetensors",
|
||||||
|
"model.layers.22.self_attn.k_proj.weight": "model-00002-of-00002.safetensors",
|
||||||
|
"model.layers.22.self_attn.o_proj.weight": "model-00002-of-00002.safetensors",
|
||||||
|
"model.layers.22.self_attn.q_norm.weight": "model-00002-of-00002.safetensors",
|
||||||
|
"model.layers.22.self_attn.q_proj.weight": "model-00002-of-00002.safetensors",
|
||||||
|
"model.layers.22.self_attn.v_proj.weight": "model-00002-of-00002.safetensors",
|
||||||
|
"model.layers.23.input_layernorm.weight": "model-00002-of-00002.safetensors",
|
||||||
|
"model.layers.23.mlp.down_proj.weight": "model-00002-of-00002.safetensors",
|
||||||
|
"model.layers.23.mlp.gate_proj.weight": "model-00002-of-00002.safetensors",
|
||||||
|
"model.layers.23.mlp.up_proj.weight": "model-00002-of-00002.safetensors",
|
||||||
|
"model.layers.23.post_attention_layernorm.weight": "model-00002-of-00002.safetensors",
|
||||||
|
"model.layers.23.self_attn.k_norm.weight": "model-00002-of-00002.safetensors",
|
||||||
|
"model.layers.23.self_attn.k_proj.weight": "model-00002-of-00002.safetensors",
|
||||||
|
"model.layers.23.self_attn.o_proj.weight": "model-00002-of-00002.safetensors",
|
||||||
|
"model.layers.23.self_attn.q_norm.weight": "model-00002-of-00002.safetensors",
|
||||||
|
"model.layers.23.self_attn.q_proj.weight": "model-00002-of-00002.safetensors",
|
||||||
|
"model.layers.23.self_attn.v_proj.weight": "model-00002-of-00002.safetensors",
|
||||||
|
"model.layers.24.input_layernorm.weight": "model-00002-of-00002.safetensors",
|
||||||
|
"model.layers.24.mlp.down_proj.weight": "model-00002-of-00002.safetensors",
|
||||||
|
"model.layers.24.mlp.gate_proj.weight": "model-00002-of-00002.safetensors",
|
||||||
|
"model.layers.24.mlp.up_proj.weight": "model-00002-of-00002.safetensors",
|
||||||
|
"model.layers.24.post_attention_layernorm.weight": "model-00002-of-00002.safetensors",
|
||||||
|
"model.layers.24.self_attn.k_norm.weight": "model-00002-of-00002.safetensors",
|
||||||
|
"model.layers.24.self_attn.k_proj.weight": "model-00002-of-00002.safetensors",
|
||||||
|
"model.layers.24.self_attn.o_proj.weight": "model-00002-of-00002.safetensors",
|
||||||
|
"model.layers.24.self_attn.q_norm.weight": "model-00002-of-00002.safetensors",
|
||||||
|
"model.layers.24.self_attn.q_proj.weight": "model-00002-of-00002.safetensors",
|
||||||
|
"model.layers.24.self_attn.v_proj.weight": "model-00002-of-00002.safetensors",
|
||||||
|
"model.layers.25.input_layernorm.weight": "model-00002-of-00002.safetensors",
|
||||||
|
"model.layers.25.mlp.down_proj.weight": "model-00002-of-00002.safetensors",
|
||||||
|
"model.layers.25.mlp.gate_proj.weight": "model-00002-of-00002.safetensors",
|
||||||
|
"model.layers.25.mlp.up_proj.weight": "model-00002-of-00002.safetensors",
|
||||||
|
"model.layers.25.post_attention_layernorm.weight": "model-00002-of-00002.safetensors",
|
||||||
|
"model.layers.25.self_attn.k_norm.weight": "model-00002-of-00002.safetensors",
|
||||||
|
"model.layers.25.self_attn.k_proj.weight": "model-00002-of-00002.safetensors",
|
||||||
|
"model.layers.25.self_attn.o_proj.weight": "model-00002-of-00002.safetensors",
|
||||||
|
"model.layers.25.self_attn.q_norm.weight": "model-00002-of-00002.safetensors",
|
||||||
|
"model.layers.25.self_attn.q_proj.weight": "model-00002-of-00002.safetensors",
|
||||||
|
"model.layers.25.self_attn.v_proj.weight": "model-00002-of-00002.safetensors",
|
||||||
|
"model.layers.26.input_layernorm.weight": "model-00002-of-00002.safetensors",
|
||||||
|
"model.layers.26.mlp.down_proj.weight": "model-00002-of-00002.safetensors",
|
||||||
|
"model.layers.26.mlp.gate_proj.weight": "model-00002-of-00002.safetensors",
|
||||||
|
"model.layers.26.mlp.up_proj.weight": "model-00002-of-00002.safetensors",
|
||||||
|
"model.layers.26.post_attention_layernorm.weight": "model-00002-of-00002.safetensors",
|
||||||
|
"model.layers.26.self_attn.k_norm.weight": "model-00002-of-00002.safetensors",
|
||||||
|
"model.layers.26.self_attn.k_proj.weight": "model-00002-of-00002.safetensors",
|
||||||
|
"model.layers.26.self_attn.o_proj.weight": "model-00002-of-00002.safetensors",
|
||||||
|
"model.layers.26.self_attn.q_norm.weight": "model-00002-of-00002.safetensors",
|
||||||
|
"model.layers.26.self_attn.q_proj.weight": "model-00002-of-00002.safetensors",
|
||||||
|
"model.layers.26.self_attn.v_proj.weight": "model-00002-of-00002.safetensors",
|
||||||
|
"model.layers.27.input_layernorm.weight": "model-00002-of-00002.safetensors",
|
||||||
|
"model.layers.27.mlp.down_proj.weight": "model-00002-of-00002.safetensors",
|
||||||
|
"model.layers.27.mlp.gate_proj.weight": "model-00002-of-00002.safetensors",
|
||||||
|
"model.layers.27.mlp.up_proj.weight": "model-00002-of-00002.safetensors",
|
||||||
|
"model.layers.27.post_attention_layernorm.weight": "model-00002-of-00002.safetensors",
|
||||||
|
"model.layers.27.self_attn.k_norm.weight": "model-00002-of-00002.safetensors",
|
||||||
|
"model.layers.27.self_attn.k_proj.weight": "model-00002-of-00002.safetensors",
|
||||||
|
"model.layers.27.self_attn.o_proj.weight": "model-00002-of-00002.safetensors",
|
||||||
|
"model.layers.27.self_attn.q_norm.weight": "model-00002-of-00002.safetensors",
|
||||||
|
"model.layers.27.self_attn.q_proj.weight": "model-00002-of-00002.safetensors",
|
||||||
|
"model.layers.27.self_attn.v_proj.weight": "model-00002-of-00002.safetensors",
|
||||||
|
"model.layers.28.input_layernorm.weight": "model-00002-of-00002.safetensors",
|
||||||
|
"model.layers.28.mlp.down_proj.weight": "model-00002-of-00002.safetensors",
|
||||||
|
"model.layers.28.mlp.gate_proj.weight": "model-00002-of-00002.safetensors",
|
||||||
|
"model.layers.28.mlp.up_proj.weight": "model-00002-of-00002.safetensors",
|
||||||
|
"model.layers.28.post_attention_layernorm.weight": "model-00002-of-00002.safetensors",
|
||||||
|
"model.layers.28.self_attn.k_norm.weight": "model-00002-of-00002.safetensors",
|
||||||
|
"model.layers.28.self_attn.k_proj.weight": "model-00002-of-00002.safetensors",
|
||||||
|
"model.layers.28.self_attn.o_proj.weight": "model-00002-of-00002.safetensors",
|
||||||
|
"model.layers.28.self_attn.q_norm.weight": "model-00002-of-00002.safetensors",
|
||||||
|
"model.layers.28.self_attn.q_proj.weight": "model-00002-of-00002.safetensors",
|
||||||
|
"model.layers.28.self_attn.v_proj.weight": "model-00002-of-00002.safetensors",
|
||||||
|
"model.layers.29.input_layernorm.weight": "model-00002-of-00002.safetensors",
|
||||||
|
"model.layers.29.mlp.down_proj.weight": "model-00002-of-00002.safetensors",
|
||||||
|
"model.layers.29.mlp.gate_proj.weight": "model-00002-of-00002.safetensors",
|
||||||
|
"model.layers.29.mlp.up_proj.weight": "model-00002-of-00002.safetensors",
|
||||||
|
"model.layers.29.post_attention_layernorm.weight": "model-00002-of-00002.safetensors",
|
||||||
|
"model.layers.29.self_attn.k_norm.weight": "model-00002-of-00002.safetensors",
|
||||||
|
"model.layers.29.self_attn.k_proj.weight": "model-00002-of-00002.safetensors",
|
||||||
|
"model.layers.29.self_attn.o_proj.weight": "model-00002-of-00002.safetensors",
|
||||||
|
"model.layers.29.self_attn.q_norm.weight": "model-00002-of-00002.safetensors",
|
||||||
|
"model.layers.29.self_attn.q_proj.weight": "model-00002-of-00002.safetensors",
|
||||||
|
"model.layers.29.self_attn.v_proj.weight": "model-00002-of-00002.safetensors",
|
||||||
|
"model.layers.3.input_layernorm.weight": "model-00001-of-00002.safetensors",
|
||||||
|
"model.layers.3.mlp.down_proj.weight": "model-00001-of-00002.safetensors",
|
||||||
|
"model.layers.3.mlp.gate_proj.weight": "model-00001-of-00002.safetensors",
|
||||||
|
"model.layers.3.mlp.up_proj.weight": "model-00001-of-00002.safetensors",
|
||||||
|
"model.layers.3.post_attention_layernorm.weight": "model-00001-of-00002.safetensors",
|
||||||
|
"model.layers.3.self_attn.k_norm.weight": "model-00001-of-00002.safetensors",
|
||||||
|
"model.layers.3.self_attn.k_proj.weight": "model-00001-of-00002.safetensors",
|
||||||
|
"model.layers.3.self_attn.o_proj.weight": "model-00001-of-00002.safetensors",
|
||||||
|
"model.layers.3.self_attn.q_norm.weight": "model-00001-of-00002.safetensors",
|
||||||
|
"model.layers.3.self_attn.q_proj.weight": "model-00001-of-00002.safetensors",
|
||||||
|
"model.layers.3.self_attn.v_proj.weight": "model-00001-of-00002.safetensors",
|
||||||
|
"model.layers.30.input_layernorm.weight": "model-00002-of-00002.safetensors",
|
||||||
|
"model.layers.30.mlp.down_proj.weight": "model-00002-of-00002.safetensors",
|
||||||
|
"model.layers.30.mlp.gate_proj.weight": "model-00002-of-00002.safetensors",
|
||||||
|
"model.layers.30.mlp.up_proj.weight": "model-00002-of-00002.safetensors",
|
||||||
|
"model.layers.30.post_attention_layernorm.weight": "model-00002-of-00002.safetensors",
|
||||||
|
"model.layers.30.self_attn.k_norm.weight": "model-00002-of-00002.safetensors",
|
||||||
|
"model.layers.30.self_attn.k_proj.weight": "model-00002-of-00002.safetensors",
|
||||||
|
"model.layers.30.self_attn.o_proj.weight": "model-00002-of-00002.safetensors",
|
||||||
|
"model.layers.30.self_attn.q_norm.weight": "model-00002-of-00002.safetensors",
|
||||||
|
"model.layers.30.self_attn.q_proj.weight": "model-00002-of-00002.safetensors",
|
||||||
|
"model.layers.30.self_attn.v_proj.weight": "model-00002-of-00002.safetensors",
|
||||||
|
"model.layers.31.input_layernorm.weight": "model-00002-of-00002.safetensors",
|
||||||
|
"model.layers.31.mlp.down_proj.weight": "model-00002-of-00002.safetensors",
|
||||||
|
"model.layers.31.mlp.gate_proj.weight": "model-00002-of-00002.safetensors",
|
||||||
|
"model.layers.31.mlp.up_proj.weight": "model-00002-of-00002.safetensors",
|
||||||
|
"model.layers.31.post_attention_layernorm.weight": "model-00002-of-00002.safetensors",
|
||||||
|
"model.layers.31.self_attn.k_norm.weight": "model-00002-of-00002.safetensors",
|
||||||
|
"model.layers.31.self_attn.k_proj.weight": "model-00002-of-00002.safetensors",
|
||||||
|
"model.layers.31.self_attn.o_proj.weight": "model-00002-of-00002.safetensors",
|
||||||
|
"model.layers.31.self_attn.q_norm.weight": "model-00002-of-00002.safetensors",
|
||||||
|
"model.layers.31.self_attn.q_proj.weight": "model-00002-of-00002.safetensors",
|
||||||
|
"model.layers.31.self_attn.v_proj.weight": "model-00002-of-00002.safetensors",
|
||||||
|
"model.layers.32.input_layernorm.weight": "model-00002-of-00002.safetensors",
|
||||||
|
"model.layers.32.mlp.down_proj.weight": "model-00002-of-00002.safetensors",
|
||||||
|
"model.layers.32.mlp.gate_proj.weight": "model-00002-of-00002.safetensors",
|
||||||
|
"model.layers.32.mlp.up_proj.weight": "model-00002-of-00002.safetensors",
|
||||||
|
"model.layers.32.post_attention_layernorm.weight": "model-00002-of-00002.safetensors",
|
||||||
|
"model.layers.32.self_attn.k_norm.weight": "model-00002-of-00002.safetensors",
|
||||||
|
"model.layers.32.self_attn.k_proj.weight": "model-00002-of-00002.safetensors",
|
||||||
|
"model.layers.32.self_attn.o_proj.weight": "model-00002-of-00002.safetensors",
|
||||||
|
"model.layers.32.self_attn.q_norm.weight": "model-00002-of-00002.safetensors",
|
||||||
|
"model.layers.32.self_attn.q_proj.weight": "model-00002-of-00002.safetensors",
|
||||||
|
"model.layers.32.self_attn.v_proj.weight": "model-00002-of-00002.safetensors",
|
||||||
|
"model.layers.33.input_layernorm.weight": "model-00002-of-00002.safetensors",
|
||||||
|
"model.layers.33.mlp.down_proj.weight": "model-00002-of-00002.safetensors",
|
||||||
|
"model.layers.33.mlp.gate_proj.weight": "model-00002-of-00002.safetensors",
|
||||||
|
"model.layers.33.mlp.up_proj.weight": "model-00002-of-00002.safetensors",
|
||||||
|
"model.layers.33.post_attention_layernorm.weight": "model-00002-of-00002.safetensors",
|
||||||
|
"model.layers.33.self_attn.k_norm.weight": "model-00002-of-00002.safetensors",
|
||||||
|
"model.layers.33.self_attn.k_proj.weight": "model-00002-of-00002.safetensors",
|
||||||
|
"model.layers.33.self_attn.o_proj.weight": "model-00002-of-00002.safetensors",
|
||||||
|
"model.layers.33.self_attn.q_norm.weight": "model-00002-of-00002.safetensors",
|
||||||
|
"model.layers.33.self_attn.q_proj.weight": "model-00002-of-00002.safetensors",
|
||||||
|
"model.layers.33.self_attn.v_proj.weight": "model-00002-of-00002.safetensors",
|
||||||
|
"model.layers.34.input_layernorm.weight": "model-00002-of-00002.safetensors",
|
||||||
|
"model.layers.34.mlp.down_proj.weight": "model-00002-of-00002.safetensors",
|
||||||
|
"model.layers.34.mlp.gate_proj.weight": "model-00002-of-00002.safetensors",
|
||||||
|
"model.layers.34.mlp.up_proj.weight": "model-00002-of-00002.safetensors",
|
||||||
|
"model.layers.34.post_attention_layernorm.weight": "model-00002-of-00002.safetensors",
|
||||||
|
"model.layers.34.self_attn.k_norm.weight": "model-00002-of-00002.safetensors",
|
||||||
|
"model.layers.34.self_attn.k_proj.weight": "model-00002-of-00002.safetensors",
|
||||||
|
"model.layers.34.self_attn.o_proj.weight": "model-00002-of-00002.safetensors",
|
||||||
|
"model.layers.34.self_attn.q_norm.weight": "model-00002-of-00002.safetensors",
|
||||||
|
"model.layers.34.self_attn.q_proj.weight": "model-00002-of-00002.safetensors",
|
||||||
|
"model.layers.34.self_attn.v_proj.weight": "model-00002-of-00002.safetensors",
|
||||||
|
"model.layers.35.input_layernorm.weight": "model-00002-of-00002.safetensors",
|
||||||
|
"model.layers.35.mlp.down_proj.weight": "model-00002-of-00002.safetensors",
|
||||||
|
"model.layers.35.mlp.gate_proj.weight": "model-00002-of-00002.safetensors",
|
||||||
|
"model.layers.35.mlp.up_proj.weight": "model-00002-of-00002.safetensors",
|
||||||
|
"model.layers.35.post_attention_layernorm.weight": "model-00002-of-00002.safetensors",
|
||||||
|
"model.layers.35.self_attn.k_norm.weight": "model-00002-of-00002.safetensors",
|
||||||
|
"model.layers.35.self_attn.k_proj.weight": "model-00002-of-00002.safetensors",
|
||||||
|
"model.layers.35.self_attn.o_proj.weight": "model-00002-of-00002.safetensors",
|
||||||
|
"model.layers.35.self_attn.q_norm.weight": "model-00002-of-00002.safetensors",
|
||||||
|
"model.layers.35.self_attn.q_proj.weight": "model-00002-of-00002.safetensors",
|
||||||
|
"model.layers.35.self_attn.v_proj.weight": "model-00002-of-00002.safetensors",
|
||||||
|
"model.layers.4.input_layernorm.weight": "model-00001-of-00002.safetensors",
|
||||||
|
"model.layers.4.mlp.down_proj.weight": "model-00001-of-00002.safetensors",
|
||||||
|
"model.layers.4.mlp.gate_proj.weight": "model-00001-of-00002.safetensors",
|
||||||
|
"model.layers.4.mlp.up_proj.weight": "model-00001-of-00002.safetensors",
|
||||||
|
"model.layers.4.post_attention_layernorm.weight": "model-00001-of-00002.safetensors",
|
||||||
|
"model.layers.4.self_attn.k_norm.weight": "model-00001-of-00002.safetensors",
|
||||||
|
"model.layers.4.self_attn.k_proj.weight": "model-00001-of-00002.safetensors",
|
||||||
|
"model.layers.4.self_attn.o_proj.weight": "model-00001-of-00002.safetensors",
|
||||||
|
"model.layers.4.self_attn.q_norm.weight": "model-00001-of-00002.safetensors",
|
||||||
|
"model.layers.4.self_attn.q_proj.weight": "model-00001-of-00002.safetensors",
|
||||||
|
"model.layers.4.self_attn.v_proj.weight": "model-00001-of-00002.safetensors",
|
||||||
|
"model.layers.5.input_layernorm.weight": "model-00001-of-00002.safetensors",
|
||||||
|
"model.layers.5.mlp.down_proj.weight": "model-00001-of-00002.safetensors",
|
||||||
|
"model.layers.5.mlp.gate_proj.weight": "model-00001-of-00002.safetensors",
|
||||||
|
"model.layers.5.mlp.up_proj.weight": "model-00001-of-00002.safetensors",
|
||||||
|
"model.layers.5.post_attention_layernorm.weight": "model-00001-of-00002.safetensors",
|
||||||
|
"model.layers.5.self_attn.k_norm.weight": "model-00001-of-00002.safetensors",
|
||||||
|
"model.layers.5.self_attn.k_proj.weight": "model-00001-of-00002.safetensors",
|
||||||
|
"model.layers.5.self_attn.o_proj.weight": "model-00001-of-00002.safetensors",
|
||||||
|
"model.layers.5.self_attn.q_norm.weight": "model-00001-of-00002.safetensors",
|
||||||
|
"model.layers.5.self_attn.q_proj.weight": "model-00001-of-00002.safetensors",
|
||||||
|
"model.layers.5.self_attn.v_proj.weight": "model-00001-of-00002.safetensors",
|
||||||
|
"model.layers.6.input_layernorm.weight": "model-00001-of-00002.safetensors",
|
||||||
|
"model.layers.6.mlp.down_proj.weight": "model-00001-of-00002.safetensors",
|
||||||
|
"model.layers.6.mlp.gate_proj.weight": "model-00001-of-00002.safetensors",
|
||||||
|
"model.layers.6.mlp.up_proj.weight": "model-00001-of-00002.safetensors",
|
||||||
|
"model.layers.6.post_attention_layernorm.weight": "model-00001-of-00002.safetensors",
|
||||||
|
"model.layers.6.self_attn.k_norm.weight": "model-00001-of-00002.safetensors",
|
||||||
|
"model.layers.6.self_attn.k_proj.weight": "model-00001-of-00002.safetensors",
|
||||||
|
"model.layers.6.self_attn.o_proj.weight": "model-00001-of-00002.safetensors",
|
||||||
|
"model.layers.6.self_attn.q_norm.weight": "model-00001-of-00002.safetensors",
|
||||||
|
"model.layers.6.self_attn.q_proj.weight": "model-00001-of-00002.safetensors",
|
||||||
|
"model.layers.6.self_attn.v_proj.weight": "model-00001-of-00002.safetensors",
|
||||||
|
"model.layers.7.input_layernorm.weight": "model-00001-of-00002.safetensors",
|
||||||
|
"model.layers.7.mlp.down_proj.weight": "model-00001-of-00002.safetensors",
|
||||||
|
"model.layers.7.mlp.gate_proj.weight": "model-00001-of-00002.safetensors",
|
||||||
|
"model.layers.7.mlp.up_proj.weight": "model-00001-of-00002.safetensors",
|
||||||
|
"model.layers.7.post_attention_layernorm.weight": "model-00001-of-00002.safetensors",
|
||||||
|
"model.layers.7.self_attn.k_norm.weight": "model-00001-of-00002.safetensors",
|
||||||
|
"model.layers.7.self_attn.k_proj.weight": "model-00001-of-00002.safetensors",
|
||||||
|
"model.layers.7.self_attn.o_proj.weight": "model-00001-of-00002.safetensors",
|
||||||
|
"model.layers.7.self_attn.q_norm.weight": "model-00001-of-00002.safetensors",
|
||||||
|
"model.layers.7.self_attn.q_proj.weight": "model-00001-of-00002.safetensors",
|
||||||
|
"model.layers.7.self_attn.v_proj.weight": "model-00001-of-00002.safetensors",
|
||||||
|
"model.layers.8.input_layernorm.weight": "model-00001-of-00002.safetensors",
|
||||||
|
"model.layers.8.mlp.down_proj.weight": "model-00001-of-00002.safetensors",
|
||||||
|
"model.layers.8.mlp.gate_proj.weight": "model-00001-of-00002.safetensors",
|
||||||
|
"model.layers.8.mlp.up_proj.weight": "model-00001-of-00002.safetensors",
|
||||||
|
"model.layers.8.post_attention_layernorm.weight": "model-00001-of-00002.safetensors",
|
||||||
|
"model.layers.8.self_attn.k_norm.weight": "model-00001-of-00002.safetensors",
|
||||||
|
"model.layers.8.self_attn.k_proj.weight": "model-00001-of-00002.safetensors",
|
||||||
|
"model.layers.8.self_attn.o_proj.weight": "model-00001-of-00002.safetensors",
|
||||||
|
"model.layers.8.self_attn.q_norm.weight": "model-00001-of-00002.safetensors",
|
||||||
|
"model.layers.8.self_attn.q_proj.weight": "model-00001-of-00002.safetensors",
|
||||||
|
"model.layers.8.self_attn.v_proj.weight": "model-00001-of-00002.safetensors",
|
||||||
|
"model.layers.9.input_layernorm.weight": "model-00001-of-00002.safetensors",
|
||||||
|
"model.layers.9.mlp.down_proj.weight": "model-00001-of-00002.safetensors",
|
||||||
|
"model.layers.9.mlp.gate_proj.weight": "model-00001-of-00002.safetensors",
|
||||||
|
"model.layers.9.mlp.up_proj.weight": "model-00001-of-00002.safetensors",
|
||||||
|
"model.layers.9.post_attention_layernorm.weight": "model-00001-of-00002.safetensors",
|
||||||
|
"model.layers.9.self_attn.k_norm.weight": "model-00001-of-00002.safetensors",
|
||||||
|
"model.layers.9.self_attn.k_proj.weight": "model-00001-of-00002.safetensors",
|
||||||
|
"model.layers.9.self_attn.o_proj.weight": "model-00001-of-00002.safetensors",
|
||||||
|
"model.layers.9.self_attn.q_norm.weight": "model-00001-of-00002.safetensors",
|
||||||
|
"model.layers.9.self_attn.q_proj.weight": "model-00001-of-00002.safetensors",
|
||||||
|
"model.layers.9.self_attn.v_proj.weight": "model-00001-of-00002.safetensors",
|
||||||
|
"model.norm.weight": "model-00002-of-00002.safetensors"
|
||||||
|
}
|
||||||
|
}
|
||||||
669
pokemon_showdown_agent_v6_release.ipynb
Normal file
669
pokemon_showdown_agent_v6_release.ipynb
Normal file
@@ -0,0 +1,669 @@
|
|||||||
|
{
|
||||||
|
"cells": [
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"# Pokemon Showdown Agent v6 on AMD ROCm\n",
|
||||||
|
"\n",
|
||||||
|
"**Author:** Yueyuan (GoldenGrapeGentleman)\n",
|
||||||
|
"\n",
|
||||||
|
"This is a clean release notebook for the `v6` Pokemon Showdown agent workflow. It keeps the original notebook untouched and rebuilds the tutorial from scratch with a stricter engineering standard for AMD ROCm users.\n",
|
||||||
|
"\n",
|
||||||
|
"## What this release notebook fixes\n",
|
||||||
|
"1. Dependency installation happens before any third-party imports.\n",
|
||||||
|
"2. Cache and temp paths fall back safely when `/shared-docker` is unavailable.\n",
|
||||||
|
"3. Data preparation uses streaming and never materializes the full replay corpus locally.\n",
|
||||||
|
"4. The tutorial keeps `load_in_4bit=False` for ROCm inference stability and uses `adamw_torch` instead of 8-bit optimizers.\n",
|
||||||
|
"5. All outputs are intentionally cleared so readers only see results produced in their own environment.\n",
|
||||||
|
"\n",
|
||||||
|
"## Recommended workflow\n",
|
||||||
|
"1. Run the install cell once.\n",
|
||||||
|
"2. If packages were installed, rerun the notebook from the top.\n",
|
||||||
|
"3. Execute the remaining cells in order.\n",
|
||||||
|
"\n",
|
||||||
|
"## Published production model\n",
|
||||||
|
"The full production checkpoint is already available at `GoldenGrapeGentleman1/pokemon-showdown-agent-v6`.\n",
|
||||||
|
"\n",
|
||||||
|
"## Tutorial scope\n",
|
||||||
|
"This notebook intentionally trains on a small streamed subset so the workflow stays teachable, reproducible, and disk-safe on AMD systems."
|
||||||
|
],
|
||||||
|
"id": "18a89f18"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"import os\n",
|
||||||
|
"import platform\n",
|
||||||
|
"import shutil\n",
|
||||||
|
"from pathlib import Path\n",
|
||||||
|
"\n",
|
||||||
|
"\n",
|
||||||
|
"def pick_writable_dir(candidates):\n",
|
||||||
|
" for candidate in candidates:\n",
|
||||||
|
" path = Path(candidate)\n",
|
||||||
|
" try:\n",
|
||||||
|
" path.mkdir(parents=True, exist_ok=True)\n",
|
||||||
|
" probe = path / \".cursor_write_test\"\n",
|
||||||
|
" probe.write_text(\"ok\", encoding=\"utf-8\")\n",
|
||||||
|
" probe.unlink()\n",
|
||||||
|
" return path\n",
|
||||||
|
" except Exception:\n",
|
||||||
|
" continue\n",
|
||||||
|
" raise RuntimeError(f\"No writable directory found in: {candidates}\")\n",
|
||||||
|
"\n",
|
||||||
|
"\n",
|
||||||
|
"cache_root = pick_writable_dir([\n",
|
||||||
|
" \"/shared-docker/.cache/huggingface\",\n",
|
||||||
|
" \"/data/huggingface\",\n",
|
||||||
|
" \"/tmp/pokemon-hf-cache\",\n",
|
||||||
|
"])\n",
|
||||||
|
"tmp_root = pick_writable_dir([\n",
|
||||||
|
" \"/shared-docker/.cache/tmp\",\n",
|
||||||
|
" \"/data/tmp\",\n",
|
||||||
|
" \"/tmp/pokemon-tmp\",\n",
|
||||||
|
"])\n",
|
||||||
|
"\n",
|
||||||
|
"os.environ.setdefault(\"HIP_VISIBLE_DEVICES\", \"0\")\n",
|
||||||
|
"os.environ.setdefault(\"ROCR_VISIBLE_DEVICES\", os.environ[\"HIP_VISIBLE_DEVICES\"])\n",
|
||||||
|
"os.environ.setdefault(\"TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL\", \"1\")\n",
|
||||||
|
"os.environ.setdefault(\"PYTORCH_HIP_ALLOC_CONF\", \"expandable_segments:False\")\n",
|
||||||
|
"os.environ.setdefault(\"UNSLOTH_SKIP_TORCHVISION_CHECK\", \"1\")\n",
|
||||||
|
"os.environ.setdefault(\"HF_HUB_ENABLE_HF_TRANSFER\", \"1\")\n",
|
||||||
|
"\n",
|
||||||
|
"os.environ[\"HF_HOME\"] = str(cache_root)\n",
|
||||||
|
"os.environ[\"HF_DATASETS_CACHE\"] = str(cache_root / \"datasets\")\n",
|
||||||
|
"os.environ[\"TMPDIR\"] = str(tmp_root)\n",
|
||||||
|
"\n",
|
||||||
|
"Path(os.environ[\"HF_DATASETS_CACHE\"]).mkdir(parents=True, exist_ok=True)\n",
|
||||||
|
"Path(os.environ[\"TMPDIR\"]).mkdir(parents=True, exist_ok=True)\n",
|
||||||
|
"\n",
|
||||||
|
"cache_usage = shutil.disk_usage(cache_root)\n",
|
||||||
|
"print(f\"platform={platform.platform()}\")\n",
|
||||||
|
"print(f\"HF_HOME={os.environ['HF_HOME']}\")\n",
|
||||||
|
"print(f\"HF_DATASETS_CACHE={os.environ['HF_DATASETS_CACHE']}\")\n",
|
||||||
|
"print(f\"TMPDIR={os.environ['TMPDIR']}\")\n",
|
||||||
|
"print(f\"HIP_VISIBLE_DEVICES={os.environ['HIP_VISIBLE_DEVICES']}\")\n",
|
||||||
|
"print(f\"free_space_gb={cache_usage.free / (1024 ** 3):.1f}\")"
|
||||||
|
],
|
||||||
|
"execution_count": null,
|
||||||
|
"outputs": [],
|
||||||
|
"id": "2e2615d3"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"## 1. Install the runtime stack\n",
|
||||||
|
"This cell intentionally uses only the Python standard library before installation. It checks for missing packages, reports the versions already present in the environment, and only installs core dependencies when they are missing.\n",
|
||||||
|
"\n",
|
||||||
|
"On a prebuilt AMD ROCm container, preserving the existing working stack is usually safer than forcing package downgrades inside the notebook.\n",
|
||||||
|
"\n",
|
||||||
|
"If packages are installed into a fresh environment, rerun the notebook from the top before continuing."
|
||||||
|
],
|
||||||
|
"id": "bd338c40"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"import importlib.metadata\n",
|
||||||
|
"import importlib.util\n",
|
||||||
|
"import subprocess\n",
|
||||||
|
"import sys\n",
|
||||||
|
"\n",
|
||||||
|
"core_runtime = [\n",
|
||||||
|
" \"transformers\",\n",
|
||||||
|
" \"trl\",\n",
|
||||||
|
" \"datasets\",\n",
|
||||||
|
" \"tokenizers\",\n",
|
||||||
|
" \"huggingface_hub\",\n",
|
||||||
|
" \"accelerate\",\n",
|
||||||
|
" \"peft\",\n",
|
||||||
|
" \"psutil\",\n",
|
||||||
|
" \"sentencepiece\",\n",
|
||||||
|
" \"protobuf\",\n",
|
||||||
|
" \"tyro\",\n",
|
||||||
|
" \"hf_transfer\",\n",
|
||||||
|
" \"einops\",\n",
|
||||||
|
"]\n",
|
||||||
|
"required_modules = [\n",
|
||||||
|
" \"unsloth\",\n",
|
||||||
|
" \"unsloth_zoo\",\n",
|
||||||
|
" \"transformers\",\n",
|
||||||
|
" \"trl\",\n",
|
||||||
|
" \"datasets\",\n",
|
||||||
|
" \"tokenizers\",\n",
|
||||||
|
" \"huggingface_hub\",\n",
|
||||||
|
" \"accelerate\",\n",
|
||||||
|
" \"peft\",\n",
|
||||||
|
" \"psutil\",\n",
|
||||||
|
" \"sentencepiece\",\n",
|
||||||
|
" \"google.protobuf\",\n",
|
||||||
|
" \"tyro\",\n",
|
||||||
|
" \"einops\",\n",
|
||||||
|
"]\n",
|
||||||
|
"\n",
|
||||||
|
"missing = [name for name in required_modules if importlib.util.find_spec(name) is None]\n",
|
||||||
|
"\n",
|
||||||
|
"if missing:\n",
|
||||||
|
" print(f\"Installing missing packages: {missing}\")\n",
|
||||||
|
" subprocess.check_call([\n",
|
||||||
|
" sys.executable,\n",
|
||||||
|
" \"-m\",\n",
|
||||||
|
" \"pip\",\n",
|
||||||
|
" \"install\",\n",
|
||||||
|
" \"--no-cache-dir\",\n",
|
||||||
|
" \"--no-deps\",\n",
|
||||||
|
" \"unsloth\",\n",
|
||||||
|
" \"unsloth_zoo\",\n",
|
||||||
|
" ])\n",
|
||||||
|
" subprocess.check_call([\n",
|
||||||
|
" sys.executable,\n",
|
||||||
|
" \"-m\",\n",
|
||||||
|
" \"pip\",\n",
|
||||||
|
" \"install\",\n",
|
||||||
|
" \"--no-cache-dir\",\n",
|
||||||
|
" *core_runtime,\n",
|
||||||
|
" ])\n",
|
||||||
|
" print(\"Install completed. Rerun the notebook from the top if this is a fresh environment.\")\n",
|
||||||
|
"else:\n",
|
||||||
|
" print(\"All required packages are already available.\")\n",
|
||||||
|
"\n",
|
||||||
|
"print(\"Detected package versions:\")\n",
|
||||||
|
"for package_name in [\"unsloth\", \"transformers\", \"trl\", \"datasets\", \"tokenizers\", \"huggingface_hub\"]:\n",
|
||||||
|
" try:\n",
|
||||||
|
" print(f\" {package_name}={importlib.metadata.version(package_name)}\")\n",
|
||||||
|
" except importlib.metadata.PackageNotFoundError:\n",
|
||||||
|
" print(f\" {package_name}=missing\")"
|
||||||
|
],
|
||||||
|
"execution_count": null,
|
||||||
|
"outputs": [],
|
||||||
|
"id": "fa699cd4"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"## 2. Validate the AMD ROCm runtime\n",
|
||||||
|
"The next cell imports the runtime in an Unsloth-safe order, reports the key package versions, and confirms that the environment can see a ROCm-backed GPU before we touch the model or dataset pipeline."
|
||||||
|
],
|
||||||
|
"id": "f74ebb30"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"import unsloth\n",
|
||||||
|
"import datasets\n",
|
||||||
|
"import huggingface_hub\n",
|
||||||
|
"import torch\n",
|
||||||
|
"import transformers\n",
|
||||||
|
"import trl\n",
|
||||||
|
"\n",
|
||||||
|
"\n",
|
||||||
|
"datasets.config.HF_DATASETS_CACHE = os.environ[\"HF_DATASETS_CACHE\"]\n",
|
||||||
|
"\n",
|
||||||
|
"print(f\"torch={torch.__version__}\")\n",
|
||||||
|
"print(f\"transformers={transformers.__version__}\")\n",
|
||||||
|
"print(f\"trl={trl.__version__}\")\n",
|
||||||
|
"print(f\"datasets={datasets.__version__}\")\n",
|
||||||
|
"print(f\"huggingface_hub={huggingface_hub.__version__}\")\n",
|
||||||
|
"print(f\"unsloth={getattr(unsloth, '__version__', 'unknown')}\")\n",
|
||||||
|
"print(f\"hip_version={getattr(torch.version, 'hip', None)}\")\n",
|
||||||
|
"\n",
|
||||||
|
"if not torch.cuda.is_available():\n",
|
||||||
|
" raise RuntimeError(\"No ROCm-visible GPU found. This notebook expects an AMD GPU environment.\")\n",
|
||||||
|
"\n",
|
||||||
|
"print(f\"gpu_count={torch.cuda.device_count()}\")\n",
|
||||||
|
"print(f\"gpu_name={torch.cuda.get_device_name(0)}\")\n",
|
||||||
|
"print(f\"bf16_supported={torch.cuda.is_bf16_supported()}\")"
|
||||||
|
],
|
||||||
|
"execution_count": null,
|
||||||
|
"outputs": [],
|
||||||
|
"id": "1e4fcdaa"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"## 3. Load Qwen3-4B and attach LoRA adapters\n",
|
||||||
|
"For the public AMD tutorial we keep `load_in_4bit=False` because the goal is a stable, reproducible path on ROCm. The production project can still explore tighter memory modes later, but the release notebook should prefer correctness and stability first."
|
||||||
|
],
|
||||||
|
"id": "d6c38fd1"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"from unsloth import FastLanguageModel\n",
|
||||||
|
"from unsloth.chat_templates import get_chat_template\n",
|
||||||
|
"\n",
|
||||||
|
"max_seq_length = 2048\n",
|
||||||
|
"dtype = torch.bfloat16\n",
|
||||||
|
"load_in_4bit = False\n",
|
||||||
|
"\n",
|
||||||
|
"model, tokenizer = FastLanguageModel.from_pretrained(\n",
|
||||||
|
" model_name=\"Qwen/Qwen3-4B\",\n",
|
||||||
|
" max_seq_length=max_seq_length,\n",
|
||||||
|
" dtype=dtype,\n",
|
||||||
|
" load_in_4bit=load_in_4bit,\n",
|
||||||
|
" attn_implementation=\"sdpa\",\n",
|
||||||
|
")\n",
|
||||||
|
"\n",
|
||||||
|
"tokenizer = get_chat_template(tokenizer, chat_template=\"qwen3\")\n",
|
||||||
|
"\n",
|
||||||
|
"model = FastLanguageModel.get_peft_model(\n",
|
||||||
|
" model,\n",
|
||||||
|
" r=64,\n",
|
||||||
|
" target_modules=[\n",
|
||||||
|
" \"q_proj\",\n",
|
||||||
|
" \"k_proj\",\n",
|
||||||
|
" \"v_proj\",\n",
|
||||||
|
" \"o_proj\",\n",
|
||||||
|
" \"gate_proj\",\n",
|
||||||
|
" \"up_proj\",\n",
|
||||||
|
" \"down_proj\",\n",
|
||||||
|
" ],\n",
|
||||||
|
" lora_alpha=128,\n",
|
||||||
|
" lora_dropout=0,\n",
|
||||||
|
" bias=\"none\",\n",
|
||||||
|
" use_gradient_checkpointing=\"unsloth\",\n",
|
||||||
|
" random_state=3407,\n",
|
||||||
|
" use_rslora=False,\n",
|
||||||
|
" loftq_config=None,\n",
|
||||||
|
")\n",
|
||||||
|
"\n",
|
||||||
|
"trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)\n",
|
||||||
|
"all_params = sum(p.numel() for p in model.parameters())\n",
|
||||||
|
"print(f\"trainable_params={trainable_params / 1e6:.1f}M\")\n",
|
||||||
|
"print(f\"trainable_ratio={100 * trainable_params / all_params:.2f}%\")"
|
||||||
|
],
|
||||||
|
"execution_count": null,
|
||||||
|
"outputs": [],
|
||||||
|
"id": "cbb58828"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"## 4. Stream and format replay logs\n",
|
||||||
|
"This is the heart of the `v6` idea: we train directly from messy Pokemon Showdown replay logs instead of hand-written state summaries. The implementation below stays disk-safe by using `streaming=True` end to end and only keeps the final tutorial subset in memory."
|
||||||
|
],
|
||||||
|
"id": "aa1831b7"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"from datasets import Dataset, load_dataset\n",
|
||||||
|
"\n",
|
||||||
|
"MIN_RATING = 1400\n",
|
||||||
|
"MAX_TRAIN_SAMPLES = 2000\n",
|
||||||
|
"SHUFFLE_BUFFER = 10_000\n",
|
||||||
|
"MAX_LOG_CHARS = 6000\n",
|
||||||
|
"\n",
|
||||||
|
"SYSTEM_TEMPLATE = (\n",
|
||||||
|
" \"You are a Pokemon Showdown battle AI. You play as {side}. \"\n",
|
||||||
|
" \"Given the battle log, output your next action. \"\n",
|
||||||
|
" \"Format: move <name> OR switch <name>. \"\n",
|
||||||
|
" \"Append terastallize if you terastallize this turn.\"\n",
|
||||||
|
")\n",
|
||||||
|
"\n",
|
||||||
|
"\n",
|
||||||
|
"def extract_winner_side(log_text):\n",
|
||||||
|
" winner = None\n",
|
||||||
|
" players = {}\n",
|
||||||
|
" for line in log_text.split(\"\\n\"):\n",
|
||||||
|
" parts = line.split(\"|\")\n",
|
||||||
|
" if len(parts) >= 4 and parts[1] == \"player\":\n",
|
||||||
|
" players[parts[2]] = parts[3]\n",
|
||||||
|
" if len(parts) >= 3 and parts[1] == \"win\":\n",
|
||||||
|
" winner = parts[2]\n",
|
||||||
|
" if not winner:\n",
|
||||||
|
" return None\n",
|
||||||
|
" for side, name in players.items():\n",
|
||||||
|
" if name == winner:\n",
|
||||||
|
" return side\n",
|
||||||
|
" return None\n",
|
||||||
|
"\n",
|
||||||
|
"\n",
|
||||||
|
"def format_sample(example):\n",
|
||||||
|
" log_text = example[\"log\"]\n",
|
||||||
|
" side = extract_winner_side(log_text)\n",
|
||||||
|
" if not side:\n",
|
||||||
|
" return {\"text\": \"\"}\n",
|
||||||
|
"\n",
|
||||||
|
" lines = log_text.strip().split(\"\\n\")\n",
|
||||||
|
" turn_positions = []\n",
|
||||||
|
" for i, line in enumerate(lines):\n",
|
||||||
|
" parts = line.split(\"|\")\n",
|
||||||
|
" if len(parts) >= 3 and parts[1] == \"turn\":\n",
|
||||||
|
" try:\n",
|
||||||
|
" turn_positions.append((int(parts[2]), i))\n",
|
||||||
|
" except ValueError:\n",
|
||||||
|
" pass\n",
|
||||||
|
"\n",
|
||||||
|
" if len(turn_positions) < 2:\n",
|
||||||
|
" return {\"text\": \"\"}\n",
|
||||||
|
"\n",
|
||||||
|
" target_turn_idx = 0 if len(turn_positions) == 2 else 1\n",
|
||||||
|
" _, turn_line_idx = turn_positions[target_turn_idx]\n",
|
||||||
|
" next_turn_line = turn_positions[target_turn_idx + 1][1] if target_turn_idx + 1 < len(turn_positions) else len(lines)\n",
|
||||||
|
"\n",
|
||||||
|
" action = None\n",
|
||||||
|
" for j in range(turn_line_idx + 1, next_turn_line):\n",
|
||||||
|
" parts = lines[j].split(\"|\")\n",
|
||||||
|
" if len(parts) < 4:\n",
|
||||||
|
" continue\n",
|
||||||
|
" if parts[1] == \"move\" and parts[2].startswith(f\"{side}a:\"):\n",
|
||||||
|
" tera = \"\"\n",
|
||||||
|
" start_look = max(0, j - 3)\n",
|
||||||
|
" end_look = min(len(lines), j + 3)\n",
|
||||||
|
" if any(\"terastallize\" in lines[k] and side in lines[k] for k in range(start_look, end_look)):\n",
|
||||||
|
" tera = \" terastallize\"\n",
|
||||||
|
" action = f\"move {parts[3]}{tera}\"\n",
|
||||||
|
" break\n",
|
||||||
|
" if parts[1] == \"switch\" and parts[2].startswith(f\"{side}a:\"):\n",
|
||||||
|
" pokemon = parts[2].split(\": \")[1] if \": \" in parts[2] else parts[2]\n",
|
||||||
|
" action = f\"switch {pokemon}\"\n",
|
||||||
|
" break\n",
|
||||||
|
"\n",
|
||||||
|
" if not action:\n",
|
||||||
|
" return {\"text\": \"\"}\n",
|
||||||
|
"\n",
|
||||||
|
" log_prefix = \"\\n\".join(lines[:turn_line_idx + 1])\n",
|
||||||
|
" if len(log_prefix) > MAX_LOG_CHARS:\n",
|
||||||
|
" return {\"text\": \"\"}\n",
|
||||||
|
"\n",
|
||||||
|
" messages = [\n",
|
||||||
|
" {\"role\": \"system\", \"content\": SYSTEM_TEMPLATE.format(side=side)},\n",
|
||||||
|
" {\"role\": \"user\", \"content\": log_prefix},\n",
|
||||||
|
" {\"role\": \"assistant\", \"content\": action},\n",
|
||||||
|
" ]\n",
|
||||||
|
" text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=False)\n",
|
||||||
|
" return {\"text\": text}\n",
|
||||||
|
"\n",
|
||||||
|
"\n",
|
||||||
|
"print(\"Streaming replay logs without materializing the full corpus...\")\n",
|
||||||
|
"stream = load_dataset(\n",
|
||||||
|
" \"milkkarten/pokemon-showdown-replays-merged\",\n",
|
||||||
|
" split=\"train\",\n",
|
||||||
|
" streaming=True,\n",
|
||||||
|
")\n",
|
||||||
|
"stream = stream.shuffle(seed=3407, buffer_size=SHUFFLE_BUFFER)\n",
|
||||||
|
"stream = stream.filter(\n",
|
||||||
|
" lambda row: \"gen9\" in str(row.get(\"formatid\") or \"\").lower()\n",
|
||||||
|
" and (row.get(\"rating\") or 0) >= MIN_RATING\n",
|
||||||
|
")\n",
|
||||||
|
"\n",
|
||||||
|
"train_samples = []\n",
|
||||||
|
"scanned = 0\n",
|
||||||
|
"for row in stream:\n",
|
||||||
|
" scanned += 1\n",
|
||||||
|
" formatted = format_sample(row)\n",
|
||||||
|
" if formatted[\"text\"]:\n",
|
||||||
|
" train_samples.append(formatted)\n",
|
||||||
|
" if len(train_samples) >= MAX_TRAIN_SAMPLES:\n",
|
||||||
|
" break\n",
|
||||||
|
"\n",
|
||||||
|
"if not train_samples:\n",
|
||||||
|
" raise RuntimeError(\"No training samples were collected. Check dataset access, filters, or disk/network configuration.\")\n",
|
||||||
|
"\n",
|
||||||
|
"train_dataset = Dataset.from_list(train_samples).shuffle(seed=3407)\n",
|
||||||
|
"print(f\"collected_examples={len(train_dataset)}\")\n",
|
||||||
|
"print(f\"replays_scanned={scanned}\")\n",
|
||||||
|
"print(train_dataset[0][\"text\"][:1000])"
|
||||||
|
],
|
||||||
|
"execution_count": null,
|
||||||
|
"outputs": [],
|
||||||
|
"id": "bcb82788"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"## 5. Quick validation SFT\n",
|
||||||
|
"The goal of this step is not to fully optimize ladder strength. It is to show the `Agentic Alignment` moment clearly: after a short 50-step run, the model should stop behaving like a chat bot and start behaving like a strict action emitter."
|
||||||
|
],
|
||||||
|
"id": "6e28c2d3"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"from trl import SFTConfig, SFTTrainer\n",
|
||||||
|
"from unsloth import is_bfloat16_supported\n",
|
||||||
|
"\n",
|
||||||
|
"output_dir = \"outputs/pokemon_v6_release_demo\"\n",
|
||||||
|
"\n",
|
||||||
|
"trainer = SFTTrainer(\n",
|
||||||
|
" model=model,\n",
|
||||||
|
" tokenizer=tokenizer,\n",
|
||||||
|
" train_dataset=train_dataset,\n",
|
||||||
|
" dataset_text_field=\"text\",\n",
|
||||||
|
" max_seq_length=max_seq_length,\n",
|
||||||
|
" packing=False,\n",
|
||||||
|
" args=SFTConfig(\n",
|
||||||
|
" per_device_train_batch_size=2,\n",
|
||||||
|
" gradient_accumulation_steps=4,\n",
|
||||||
|
" warmup_steps=5,\n",
|
||||||
|
" max_steps=50,\n",
|
||||||
|
" learning_rate=2e-4,\n",
|
||||||
|
" fp16=not is_bfloat16_supported(),\n",
|
||||||
|
" bf16=is_bfloat16_supported(),\n",
|
||||||
|
" logging_steps=10,\n",
|
||||||
|
" optim=\"adamw_torch\",\n",
|
||||||
|
" weight_decay=0.01,\n",
|
||||||
|
" lr_scheduler_type=\"linear\",\n",
|
||||||
|
" seed=3407,\n",
|
||||||
|
" output_dir=output_dir,\n",
|
||||||
|
" save_steps=50,\n",
|
||||||
|
" save_total_limit=2,\n",
|
||||||
|
" report_to=\"none\",\n",
|
||||||
|
" ),\n",
|
||||||
|
")\n",
|
||||||
|
"\n",
|
||||||
|
"trainer_stats = trainer.train()\n",
|
||||||
|
"print(trainer_stats)"
|
||||||
|
],
|
||||||
|
"execution_count": null,
|
||||||
|
"outputs": [],
|
||||||
|
"id": "6f813bf3"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"## 6. Inference sanity check\n",
|
||||||
|
"A healthy post-SFT output should be compact and action-shaped. It does not need to be perfect after only 50 steps, but it should usually look like `move ...` or `switch ...` instead of free-form commentary."
|
||||||
|
],
|
||||||
|
"id": "ba5d43d6"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"import re\n",
|
||||||
|
"\n",
|
||||||
|
"model.eval()\n",
|
||||||
|
"FastLanguageModel.for_inference(model)\n",
|
||||||
|
"\n",
|
||||||
|
"sys_msg = (\n",
|
||||||
|
" \"You are a Pokemon Showdown battle AI. You play as p2. \"\n",
|
||||||
|
" \"Given the battle log, output your next action. \"\n",
|
||||||
|
" \"Format: move <name> OR switch <name>. \"\n",
|
||||||
|
" \"Append terastallize if you terastallize this turn.\"\n",
|
||||||
|
")\n",
|
||||||
|
"user_msg = '''|player|p1|Player1|266|1500\n",
|
||||||
|
"|player|p2|Player2|1|1500\n",
|
||||||
|
"|teamsize|p1|6\n",
|
||||||
|
"|teamsize|p2|6\n",
|
||||||
|
"|gen|9\n",
|
||||||
|
"|tier|[Gen 9] OU\n",
|
||||||
|
"|\n",
|
||||||
|
"|start\n",
|
||||||
|
"|switch|p1a: Garchomp|Garchomp, M|100/100\n",
|
||||||
|
"|switch|p2a: Corviknight|Corviknight, M|100/100\n",
|
||||||
|
"|turn|1\n",
|
||||||
|
"|move|p1a: Garchomp|Earthquake|p2a: Corviknight\n",
|
||||||
|
"|-immune|p2a: Corviknight\n",
|
||||||
|
"|turn|2'''\n",
|
||||||
|
"\n",
|
||||||
|
"messages = [\n",
|
||||||
|
" {\"role\": \"system\", \"content\": sys_msg},\n",
|
||||||
|
" {\"role\": \"user\", \"content\": user_msg},\n",
|
||||||
|
"]\n",
|
||||||
|
"\n",
|
||||||
|
"text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)\n",
|
||||||
|
"inputs = tokenizer(text, return_tensors=\"pt\").to(\"cuda\")\n",
|
||||||
|
"\n",
|
||||||
|
"with torch.no_grad():\n",
|
||||||
|
" outputs = model.generate(\n",
|
||||||
|
" **inputs,\n",
|
||||||
|
" max_new_tokens=64,\n",
|
||||||
|
" do_sample=False,\n",
|
||||||
|
" temperature=0.1,\n",
|
||||||
|
" pad_token_id=tokenizer.eos_token_id,\n",
|
||||||
|
" use_cache=False,\n",
|
||||||
|
" )\n",
|
||||||
|
"\n",
|
||||||
|
"full_response = tokenizer.decode(outputs[0], skip_special_tokens=False)\n",
|
||||||
|
"if \"<|im_start|>assistant\\n\" in full_response:\n",
|
||||||
|
" response = full_response.split(\"<|im_start|>assistant\\n\")[-1]\n",
|
||||||
|
"else:\n",
|
||||||
|
" response = tokenizer.decode(outputs[0][inputs.input_ids.shape[-1]:], skip_special_tokens=True)\n",
|
||||||
|
"response = response.replace(\"<|im_end|>\", \"\").strip()\n",
|
||||||
|
"response = re.sub(r\"(?s)^<think>.*?</think>\\s*\", \"\", response).strip()\n",
|
||||||
|
"match = re.search(r\"^(move|switch)\\b\", response, flags=re.IGNORECASE)\n",
|
||||||
|
"\n",
|
||||||
|
"print(\"--- AI AGENT PREDICTION ---\")\n",
|
||||||
|
"print(response)\n",
|
||||||
|
"print(f\"strict_action_format={bool(match)}\")"
|
||||||
|
],
|
||||||
|
"execution_count": null,
|
||||||
|
"outputs": [],
|
||||||
|
"id": "cd99efdd"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"## 7. Export and publish\n",
|
||||||
|
"If you want a standalone merged checkpoint for sharing, export it into the same `outputs/` subtree to keep the project root tidy.\n",
|
||||||
|
"\n",
|
||||||
|
"```python\n",
|
||||||
|
"export_dir = f\"{output_dir}/merged_model\"\n",
|
||||||
|
"model.save_pretrained_merged(export_dir, tokenizer, save_method=\"merged_16bit\")\n",
|
||||||
|
"```\n",
|
||||||
|
"\n",
|
||||||
|
"Suggested Hugging Face commands for a tutorial-only release:\n",
|
||||||
|
"\n",
|
||||||
|
"```bash\n",
|
||||||
|
"hf auth whoami\n",
|
||||||
|
"hf repos create your-username/pokemon-showdown-agent-v6-demo --type model --exist-ok\n",
|
||||||
|
"hf upload-large-folder your-username/pokemon-showdown-agent-v6-demo outputs/pokemon_v6_release_demo/merged_model\n",
|
||||||
|
"hf upload your-username/pokemon-showdown-agent-v6-demo pokemon_agent_demo_notebook_v6_release.ipynb pokemon_showdown_agent_v6_release.ipynb\n",
|
||||||
|
"```\n",
|
||||||
|
"\n",
|
||||||
|
"The full production checkpoint remains available at `GoldenGrapeGentleman1/pokemon-showdown-agent-v6`."
|
||||||
|
],
|
||||||
|
"id": "10de5655"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"## Bonus: GRPO sketch for the next stage\n",
|
||||||
|
"After SFT, you can layer a lightweight GRPO loop on top to reward valid action formatting or better tactical choices. In this ROCm container, the current `trl` build pulls in extra optional dependencies such as `mergekit` and `llm_blender` when importing `GRPOTrainer`.\n",
|
||||||
|
"\n",
|
||||||
|
"To keep the main SFT environment stable, the next cell does **not** auto-install those packages. Instead, it checks whether the optional GRPO stack is available and cleanly skips the setup with an actionable message if it is not.\n",
|
||||||
|
"\n",
|
||||||
|
"Run this section only after the earlier SFT/data cells have completed, because it reuses `train_samples`, `Dataset`, and `is_bfloat16_supported()` from the main notebook flow."
|
||||||
|
],
|
||||||
|
"id": "d0fb4c3e"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"import importlib.util\n",
|
||||||
|
"\n",
|
||||||
|
"grpo_trainer = None\n",
|
||||||
|
"missing_optional = [\n",
|
||||||
|
" package_name\n",
|
||||||
|
" for module_name, package_name in [(\"mergekit\", \"mergekit\"), (\"llm_blender\", \"llm-blender\")]\n",
|
||||||
|
" if importlib.util.find_spec(module_name) is None\n",
|
||||||
|
"]\n",
|
||||||
|
"\n",
|
||||||
|
"if missing_optional:\n",
|
||||||
|
" print(\"Skipping Bonus/GRPO setup because optional packages are missing.\")\n",
|
||||||
|
" print(\"Install them in a separate or disposable environment if you want this section:\")\n",
|
||||||
|
" print(f\" pip install {' '.join(missing_optional)}\")\n",
|
||||||
|
"else:\n",
|
||||||
|
" try:\n",
|
||||||
|
" from trl import GRPOConfig, GRPOTrainer\n",
|
||||||
|
" except Exception as exc:\n",
|
||||||
|
" print(\"Skipping Bonus/GRPO setup because TRL could not import GRPOTrainer cleanly.\")\n",
|
||||||
|
" print(repr(exc))\n",
|
||||||
|
" else:\n",
|
||||||
|
" def format_reward_func(prompts, completions, **kwargs):\n",
|
||||||
|
" rewards = []\n",
|
||||||
|
" for completion in completions:\n",
|
||||||
|
" text = completion[0][\"content\"] if isinstance(completion, list) else str(completion)\n",
|
||||||
|
" cmd_match = re.search(r\"^(move\\s+.+|switch\\s+.+)$\", text.strip(), re.IGNORECASE)\n",
|
||||||
|
" rewards.append(5.0 if cmd_match else -3.0)\n",
|
||||||
|
" return rewards\n",
|
||||||
|
"\n",
|
||||||
|
" grpo_prompts = [\n",
|
||||||
|
" {\"prompt\": sample[\"text\"].split(\"<|im_start|>assistant\")[0] + \"<|im_start|>assistant\\n\"}\n",
|
||||||
|
" for sample in train_samples[:50]\n",
|
||||||
|
" ]\n",
|
||||||
|
" grpo_dataset = Dataset.from_list(grpo_prompts)\n",
|
||||||
|
"\n",
|
||||||
|
" grpo_config = GRPOConfig(\n",
|
||||||
|
" output_dir=\"outputs/pokemon_v6_release_grpo\",\n",
|
||||||
|
" learning_rate=3e-6,\n",
|
||||||
|
" per_device_train_batch_size=1,\n",
|
||||||
|
" gradient_accumulation_steps=4,\n",
|
||||||
|
" num_generations=4,\n",
|
||||||
|
" max_completion_length=128,\n",
|
||||||
|
" temperature=1.3,\n",
|
||||||
|
" max_steps=10,\n",
|
||||||
|
" logging_steps=1,\n",
|
||||||
|
" report_to=\"none\",\n",
|
||||||
|
" fp16=not is_bfloat16_supported(),\n",
|
||||||
|
" bf16=is_bfloat16_supported(),\n",
|
||||||
|
" )\n",
|
||||||
|
"\n",
|
||||||
|
" grpo_trainer = GRPOTrainer(\n",
|
||||||
|
" model=model,\n",
|
||||||
|
" reward_funcs=[format_reward_func],\n",
|
||||||
|
" args=grpo_config,\n",
|
||||||
|
" train_dataset=grpo_dataset,\n",
|
||||||
|
" )\n",
|
||||||
|
" print(\"GRPO setup is ready. Uncomment `grpo_trainer.train()` when desired.\")\n",
|
||||||
|
"\n",
|
||||||
|
"# Uncomment when you want to test the RL extension path.\n",
|
||||||
|
"# if grpo_trainer is not None:\n",
|
||||||
|
"# grpo_trainer.train()"
|
||||||
|
],
|
||||||
|
"execution_count": null,
|
||||||
|
"outputs": [],
|
||||||
|
"id": "208bfacd"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"kernelspec": {
|
||||||
|
"display_name": "Python 3 (ipykernel)",
|
||||||
|
"language": "python",
|
||||||
|
"name": "python3"
|
||||||
|
},
|
||||||
|
"language_info": {
|
||||||
|
"name": "python"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"nbformat": 4,
|
||||||
|
"nbformat_minor": 5
|
||||||
|
}
|
||||||
590
pokemon_showdown_agent_v6_tutorial.ipynb
Normal file
590
pokemon_showdown_agent_v6_tutorial.ipynb
Normal file
@@ -0,0 +1,590 @@
|
|||||||
|
{
|
||||||
|
"cells": [
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"<div class=\"align-center\">\n",
|
||||||
|
"\n",
|
||||||
|
"<a href=\"https://rocm.docs.amd.com/projects/ai-developer-hub/en/latest/index.html\"><img src=\"https://raw.githubusercontent.com/ROCm/gpuaidev/main/docs/images/rocm_logo.png\" alt=\"ROCm AI Developer Hub\" width=\"150\" style=\"display:inline-block; margin-right: 20px;\"></a>\n",
|
||||||
|
"<a href=\"https://unsloth.ai/\"><img src=\"https://github.com/unslothai/unsloth/raw/main/images/unsloth%20new%20logo.png\" width=\"115\" style=\"display:inline-block;\"></a>\n",
|
||||||
|
"\n",
|
||||||
|
"</div>\n",
|
||||||
|
"\n",
|
||||||
|
"<div align=\"center\">\n",
|
||||||
|
"\n",
|
||||||
|
"<a href=\"https://discord.gg/unsloth\"><img src=\"https://github.com/unslothai/unsloth/raw/main/images/Discord button.png\" width=\"145\"></a>\n",
|
||||||
|
"<a href=\"https://unsloth.ai/docs/\"><img src=\"https://github.com/unslothai/unsloth/blob/main/images/documentation%20green%20button.png?raw=true\" width=\"125\"></a>\n",
|
||||||
|
"\n",
|
||||||
|
"</div>\n",
|
||||||
|
"\n",
|
||||||
|
"---\n",
|
||||||
|
"\n",
|
||||||
|
"# Pokemon Showdown Agent v6 with Unsloth on AMD ROCm\n",
|
||||||
|
"\n",
|
||||||
|
"**Author:** Yueyuan (GoldenGrapeGentleman)\n",
|
||||||
|
"\n",
|
||||||
|
"Build a competitive Pokemon Showdown action model from raw replay logs and reproduce the core `v6` workflow with `Qwen/Qwen3-4B`, Unsloth, and AMD ROCm.\n",
|
||||||
|
"\n",
|
||||||
|
"### What you will build\n",
|
||||||
|
"1. **Portable ROCm setup**: Configure a notebook that works on MI300X-class systems and still degrades gracefully on smaller AMD machines.\n",
|
||||||
|
"2. **Disk-safe data preparation**: Stream and filter raw replay logs from Hugging Face without materializing the full 40GB+ corpus locally.\n",
|
||||||
|
"3. **Quick validation SFT**: Run a short 50-step alignment loop and watch the model shift from generic chat behavior to valid `move` / `switch` actions.\n",
|
||||||
|
"4. **Inference sanity check**: Test the tuned model on a real battle prefix and inspect whether it emits a legal next action.\n",
|
||||||
|
"5. **Export and publish**: Package the merged checkpoint and push your tutorial artifact to Hugging Face.\n",
|
||||||
|
"\n",
|
||||||
|
"### Tutorial scope\n",
|
||||||
|
"This notebook intentionally trains on a small streamed subset so the workflow stays teachable, reproducible, and safe on limited disk. The full production `v6` checkpoint can be published separately after training or loaded from Hugging Face once available.\n",
|
||||||
|
"\n",
|
||||||
|
"### Prerequisites\n",
|
||||||
|
"To install Unsloth on your AMD machine, follow the [AMD ROCm installation guide](https://unsloth.ai/docs/get-started/install/amd). This notebook is adapted from the Unsloth notebook ecosystem and keeps the same [LGPL-3.0](https://github.com/unslothai/notebooks?tab=LGPL-3.0-1-ov-file#readme) notebook license context."
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"## 1. AMD ROCm environment setup\n",
|
||||||
|
"When running on AMD GPUs such as MI300X or Radeon PRO W7900, Unsloth relies on ROCm plus PyTorch SDPA kernels for strong throughput. The next cell sets conservative defaults before importing training libraries, prefers large external cache mounts when they exist, and falls back to `/tmp` so the tutorial does not break on machines without `/shared-docker`."
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"import os\n",
|
||||||
|
"from pathlib import Path\n",
|
||||||
|
"\n",
|
||||||
|
"import datasets\n",
|
||||||
|
"\n",
|
||||||
|
"cache_candidates = [\n",
|
||||||
|
" \"/shared-docker/.cache/huggingface\",\n",
|
||||||
|
" \"/data/huggingface\",\n",
|
||||||
|
" \"/tmp/pokemon-hf-cache\",\n",
|
||||||
|
"]\n",
|
||||||
|
"tmp_candidates = [\n",
|
||||||
|
" \"/shared-docker/.cache/tmp\",\n",
|
||||||
|
" \"/data/tmp\",\n",
|
||||||
|
" \"/tmp/pokemon-tmp\",\n",
|
||||||
|
"]\n",
|
||||||
|
"\n",
|
||||||
|
"cache_root = next((path for path in cache_candidates if os.path.isdir(path)), cache_candidates[-1])\n",
|
||||||
|
"tmp_root = next((path for path in tmp_candidates if os.path.isdir(path)), tmp_candidates[-1])\n",
|
||||||
|
"\n",
|
||||||
|
"# Default to the first visible AMD GPU unless the user already pinned a device.\n",
|
||||||
|
"os.environ.setdefault(\"HIP_VISIBLE_DEVICES\", \"0\")\n",
|
||||||
|
"os.environ.setdefault(\"ROCR_VISIBLE_DEVICES\", os.environ[\"HIP_VISIBLE_DEVICES\"])\n",
|
||||||
|
"\n",
|
||||||
|
"os.environ[\"HF_HOME\"] = cache_root\n",
|
||||||
|
"os.environ[\"HF_DATASETS_CACHE\"] = f\"{cache_root}/datasets\"\n",
|
||||||
|
"os.environ[\"TMPDIR\"] = tmp_root\n",
|
||||||
|
"\n",
|
||||||
|
"Path(os.environ[\"HF_HOME\"]).mkdir(parents=True, exist_ok=True)\n",
|
||||||
|
"Path(os.environ[\"HF_DATASETS_CACHE\"]).mkdir(parents=True, exist_ok=True)\n",
|
||||||
|
"Path(os.environ[\"TMPDIR\"]).mkdir(parents=True, exist_ok=True)\n",
|
||||||
|
"\n",
|
||||||
|
"datasets.config.HF_DATASETS_CACHE = os.environ[\"HF_DATASETS_CACHE\"]\n",
|
||||||
|
"\n",
|
||||||
|
"# RDNA3 cards often need the AOTriton flag for Flash Attention via SDPA.\n",
|
||||||
|
"os.environ.setdefault(\"TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL\", \"1\")\n",
|
||||||
|
"os.environ.setdefault(\"PYTORCH_HIP_ALLOC_CONF\", \"expandable_segments:False\")\n",
|
||||||
|
"os.environ.setdefault(\"UNSLOTH_SKIP_TORCHVISION_CHECK\", \"1\")\n",
|
||||||
|
"\n",
|
||||||
|
"print(f\"HF_HOME={os.environ['HF_HOME']}\")\n",
|
||||||
|
"print(f\"HF_DATASETS_CACHE={os.environ['HF_DATASETS_CACHE']}\")\n",
|
||||||
|
"print(f\"TMPDIR={os.environ['TMPDIR']}\")\n"
|
||||||
|
],
|
||||||
|
"execution_count": 10,
|
||||||
|
"outputs": []
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"## 2. Installation\n",
|
||||||
|
"If you are not already inside a prepared ROCm image, install Unsloth and the core training stack. For maximum AMD compatibility, this tutorial avoids depending on optional 8-bit optimizer packages and sticks to the standard PyTorch / Transformers / TRL toolchain."
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"%%capture\n",
|
||||||
|
"import os\n",
|
||||||
|
"\n",
|
||||||
|
"if \"COLAB_\" not in \"\".join(os.environ.keys()):\n",
|
||||||
|
" !pip install --no-cache-dir --no-deps unsloth unsloth_zoo\n",
|
||||||
|
" !pip install --no-cache-dir transformers accelerate peft trl datasets psutil sentencepiece protobuf tyro huggingface_hub hf_transfer einops\n",
|
||||||
|
"else:\n",
|
||||||
|
" pass # Colab / Kaggle usually need their own setup flow."
|
||||||
|
],
|
||||||
|
"execution_count": 11,
|
||||||
|
"outputs": []
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"## 3. Load the base model\n",
|
||||||
|
"We load `Qwen/Qwen3-4B` with a configuration tuned for AMD inference stability. The full production `v6` recipe used longer contexts and more data, but this tutorial keeps the setup compact enough for a short reproducible run."
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"import unsloth\n",
|
||||||
|
"from unsloth import FastLanguageModel\n",
|
||||||
|
"import torch\n",
|
||||||
|
"\n",
|
||||||
|
"max_seq_length = 2048 # Small enough for a fast tutorial, large enough for real replay prefixes.\n",
|
||||||
|
"dtype = torch.bfloat16 # Recommended on AMD for stable training and inference.\n",
|
||||||
|
"load_in_4bit = False # Keep inference stable on ROCm for the public tutorial flow.\n",
|
||||||
|
"\n",
|
||||||
|
"model, tokenizer = FastLanguageModel.from_pretrained(\n",
|
||||||
|
" model_name=\"Qwen/Qwen3-4B\",\n",
|
||||||
|
" max_seq_length=max_seq_length,\n",
|
||||||
|
" dtype=dtype,\n",
|
||||||
|
" load_in_4bit=load_in_4bit,\n",
|
||||||
|
" attn_implementation=\"sdpa\",\n",
|
||||||
|
")"
|
||||||
|
],
|
||||||
|
"execution_count": 12,
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"==((====))== Unsloth 2026.1.4: Fast Qwen3 patching. Transformers: 4.57.6.\n",
|
||||||
|
" \\\\ /| AMD Instinct MI300X VF. Num GPUs = 1. Max memory: 191.688 GB. Platform: Linux.\n",
|
||||||
|
"O^O/ \\_/ \\ Torch: 2.10.0+rocm7.1. ROCm Toolkit: 7.1.25424. Triton: 3.6.0\n",
|
||||||
|
"\\ / Bfloat16 = TRUE. FA [Xformers = None. FA2 = False]\n",
|
||||||
|
" \"-____-\" Free license: http://github.com/unslothai/unsloth\n",
|
||||||
|
"Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored!\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"Loading checkpoint shards: 100%|██████████| 2/2 [00:05<00:00, 2.56s/it]\n"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"We now add LoRA adapters so we only need to update 1 to 10% of all parameters!"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"# 2. Apply LoRA Adapters\n",
|
||||||
|
"model = FastLanguageModel.get_peft_model(\n",
|
||||||
|
" model,\n",
|
||||||
|
" r = 64, # Choose any number > 0 ! Suggested 8, 16, 32, 64, 128\n",
|
||||||
|
" target_modules = [\"q_proj\", \"k_proj\", \"v_proj\", \"o_proj\",\n",
|
||||||
|
" \"gate_proj\", \"up_proj\", \"down_proj\"],\n",
|
||||||
|
" lora_alpha = 128,\n",
|
||||||
|
" lora_dropout = 0, # Supports any, but = 0 is optimized\n",
|
||||||
|
" bias = \"none\", # Supports any, but = \"none\" is optimized\n",
|
||||||
|
" use_gradient_checkpointing = \"unsloth\", # True or \"unsloth\" for very long context\n",
|
||||||
|
" random_state = 3407,\n",
|
||||||
|
" use_rslora = False, # We support rank stabilized LoRA\n",
|
||||||
|
" loftq_config = None, # And LoftQ\n",
|
||||||
|
")"
|
||||||
|
],
|
||||||
|
"execution_count": 13,
|
||||||
|
"outputs": []
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"## 4. Data preparation from raw replay logs\n",
|
||||||
|
"The `v6` idea is simple but demanding: feed the model messy real replay logs instead of hand-written state summaries. For the public tutorial we stream `milkkarten/pokemon-showdown-replays-merged`, filter to stronger Gen 9 games, and build a compact subset entirely in memory so the notebook stays disk-safe."
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"from datasets import Dataset, load_dataset\n",
|
||||||
|
"from unsloth.chat_templates import get_chat_template\n",
|
||||||
|
"\n",
|
||||||
|
"tokenizer = get_chat_template(tokenizer, chat_template=\"qwen3\")\n",
|
||||||
|
"\n",
|
||||||
|
"MIN_RATING = 1400\n",
|
||||||
|
"MAX_TRAIN_SAMPLES = 2000\n",
|
||||||
|
"SHUFFLE_BUFFER = 10_000\n",
|
||||||
|
"MAX_LOG_CHARS = 6000\n",
|
||||||
|
"\n",
|
||||||
|
"SYSTEM_TEMPLATE = (\n",
|
||||||
|
" \"You are a Pokemon Showdown battle AI. You play as {side}. \"\n",
|
||||||
|
" \"Given the battle log, output your next action. \"\n",
|
||||||
|
" \"Format: move <name> OR switch <name>. \"\n",
|
||||||
|
" \"Append terastallize if you terastallize this turn.\"\n",
|
||||||
|
")\n",
|
||||||
|
"\n",
|
||||||
|
"def extract_winner_side(log_text):\n",
|
||||||
|
" winner = None\n",
|
||||||
|
" players = {}\n",
|
||||||
|
" for line in log_text.split(\"\\n\"):\n",
|
||||||
|
" parts = line.split(\"|\")\n",
|
||||||
|
" if len(parts) >= 4 and parts[1] == \"player\":\n",
|
||||||
|
" players[parts[2]] = parts[3]\n",
|
||||||
|
" if len(parts) >= 3 and parts[1] == \"win\":\n",
|
||||||
|
" winner = parts[2]\n",
|
||||||
|
" if not winner:\n",
|
||||||
|
" return None\n",
|
||||||
|
" for side, name in players.items():\n",
|
||||||
|
" if name == winner:\n",
|
||||||
|
" return side\n",
|
||||||
|
" return None\n",
|
||||||
|
"\n",
|
||||||
|
"def format_sample(example):\n",
|
||||||
|
" log_text = example[\"log\"]\n",
|
||||||
|
" side = extract_winner_side(log_text)\n",
|
||||||
|
" if not side:\n",
|
||||||
|
" return {\"text\": \"\"}\n",
|
||||||
|
"\n",
|
||||||
|
" lines = log_text.strip().split(\"\\n\")\n",
|
||||||
|
" turn_positions = []\n",
|
||||||
|
" for i, line in enumerate(lines):\n",
|
||||||
|
" parts = line.split(\"|\")\n",
|
||||||
|
" if len(parts) >= 3 and parts[1] == \"turn\":\n",
|
||||||
|
" try:\n",
|
||||||
|
" turn_positions.append((int(parts[2]), i))\n",
|
||||||
|
" except ValueError:\n",
|
||||||
|
" pass\n",
|
||||||
|
"\n",
|
||||||
|
" if len(turn_positions) < 2:\n",
|
||||||
|
" return {\"text\": \"\"}\n",
|
||||||
|
"\n",
|
||||||
|
" # Keep prompts compact for the tutorial by using the first actionable turn\n",
|
||||||
|
" # after the opening whenever possible.\n",
|
||||||
|
" target_turn_idx = 0 if len(turn_positions) == 2 else 1\n",
|
||||||
|
" _, turn_line_idx = turn_positions[target_turn_idx]\n",
|
||||||
|
" next_turn_line = turn_positions[target_turn_idx + 1][1] if target_turn_idx + 1 < len(turn_positions) else len(lines)\n",
|
||||||
|
"\n",
|
||||||
|
" action = None\n",
|
||||||
|
" for j in range(turn_line_idx + 1, next_turn_line):\n",
|
||||||
|
" parts = lines[j].split(\"|\")\n",
|
||||||
|
" if len(parts) < 4:\n",
|
||||||
|
" continue\n",
|
||||||
|
" if parts[1] == \"move\" and parts[2].startswith(f\"{side}a:\"):\n",
|
||||||
|
" tera = \"\"\n",
|
||||||
|
" start_look = max(0, j - 3)\n",
|
||||||
|
" end_look = min(len(lines), j + 3)\n",
|
||||||
|
" if any(\"terastallize\" in lines[k] and side in lines[k] for k in range(start_look, end_look)):\n",
|
||||||
|
" tera = \" terastallize\"\n",
|
||||||
|
" action = f\"move {parts[3]}{tera}\"\n",
|
||||||
|
" break\n",
|
||||||
|
" if parts[1] == \"switch\" and parts[2].startswith(f\"{side}a:\"):\n",
|
||||||
|
" pokemon = parts[2].split(\": \")[1] if \": \" in parts[2] else parts[2]\n",
|
||||||
|
" action = f\"switch {pokemon}\"\n",
|
||||||
|
" break\n",
|
||||||
|
"\n",
|
||||||
|
" if not action:\n",
|
||||||
|
" return {\"text\": \"\"}\n",
|
||||||
|
"\n",
|
||||||
|
" log_prefix = \"\\n\".join(lines[:turn_line_idx + 1])\n",
|
||||||
|
" if len(log_prefix) > MAX_LOG_CHARS:\n",
|
||||||
|
" return {\"text\": \"\"}\n",
|
||||||
|
"\n",
|
||||||
|
" messages = [\n",
|
||||||
|
" {\"role\": \"system\", \"content\": SYSTEM_TEMPLATE.format(side=side)},\n",
|
||||||
|
" {\"role\": \"user\", \"content\": log_prefix},\n",
|
||||||
|
" {\"role\": \"assistant\", \"content\": action},\n",
|
||||||
|
" ]\n",
|
||||||
|
" text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=False)\n",
|
||||||
|
" return {\"text\": text}\n",
|
||||||
|
"\n",
|
||||||
|
"print(\"Streaming, filtering, and formatting replay logs...\")\n",
|
||||||
|
"dataset = load_dataset(\"milkkarten/pokemon-showdown-replays-merged\", split=\"train\", streaming=True)\n",
|
||||||
|
"dataset = dataset.shuffle(seed=3407, buffer_size=SHUFFLE_BUFFER)\n",
|
||||||
|
"dataset = dataset.filter(lambda x: \"gen9\" in str(x.get(\"formatid\") or \"\") and (x.get(\"rating\") or 0) >= MIN_RATING)\n",
|
||||||
|
"\n",
|
||||||
|
"train_samples = []\n",
|
||||||
|
"scanned = 0\n",
|
||||||
|
"for row in dataset:\n",
|
||||||
|
" scanned += 1\n",
|
||||||
|
" formatted = format_sample(row)\n",
|
||||||
|
" if formatted[\"text\"]:\n",
|
||||||
|
" train_samples.append(formatted)\n",
|
||||||
|
" if len(train_samples) >= MAX_TRAIN_SAMPLES:\n",
|
||||||
|
" break\n",
|
||||||
|
"\n",
|
||||||
|
"train_dataset = Dataset.from_list(train_samples).shuffle(seed=3407)\n",
|
||||||
|
"print(f\"Collected {len(train_dataset)} training examples after scanning {scanned} streamed replays.\")\n",
|
||||||
|
"print(train_dataset[0][\"text\"][:800])"
|
||||||
|
],
|
||||||
|
"execution_count": null,
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"Downloading and parsing dataset...\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"Repo card metadata block was not found. Setting CardData to empty.\n",
|
||||||
|
"[huggingface_hub.repocard|WARNING]Repo card metadata block was not found. Setting CardData to empty.\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"Downloading data: 100%|██████████| 583/583 [00:00<00:00, 3215.89files/s]\n",
|
||||||
|
"Generating train split: 96%|█████████▌| 27837802/29057184 [08:06<00:21, 57181.98 examples/s]\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"output_type": "error",
|
||||||
|
"ename": "DatasetGenerationError",
|
||||||
|
"evalue": "An error occurred while generating the dataset",
|
||||||
|
"traceback": [
|
||||||
|
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
|
||||||
|
"\u001b[0;31mOSError\u001b[0m Traceback (most recent call last)",
|
||||||
|
"File \u001b[0;32m/opt/venv/lib/python3.10/site-packages/datasets/builder.py:1834\u001b[0m, in \u001b[0;36mArrowBasedBuilder._prepare_split_single\u001b[0;34m(self, gen_kwargs, fpath, file_format, max_shard_size, job_id)\u001b[0m\n\u001b[1;32m 1833\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m-> 1834\u001b[0m \u001b[43mwriter\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mwrite_table\u001b[49m\u001b[43m(\u001b[49m\u001b[43mtable\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1835\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m CastError \u001b[38;5;28;01mas\u001b[39;00m cast_error:\n",
|
||||||
|
"File \u001b[0;32m/opt/venv/lib/python3.10/site-packages/datasets/arrow_writer.py:719\u001b[0m, in \u001b[0;36mArrowWriter.write_table\u001b[0;34m(self, pa_table, writer_batch_size)\u001b[0m\n\u001b[1;32m 718\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_num_examples \u001b[38;5;241m+\u001b[39m\u001b[38;5;241m=\u001b[39m pa_table\u001b[38;5;241m.\u001b[39mnum_rows\n\u001b[0;32m--> 719\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mpa_writer\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mwrite_table\u001b[49m\u001b[43m(\u001b[49m\u001b[43mpa_table\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mwriter_batch_size\u001b[49m\u001b[43m)\u001b[49m\n",
|
||||||
|
"File \u001b[0;32m/opt/venv/lib/python3.10/site-packages/pyarrow/ipc.pxi:616\u001b[0m, in \u001b[0;36mpyarrow.lib._CRecordBatchWriter.write_table\u001b[0;34m()\u001b[0m\n",
|
||||||
|
"File \u001b[0;32m/opt/venv/lib/python3.10/site-packages/pyarrow/error.pxi:89\u001b[0m, in \u001b[0;36mpyarrow.lib.check_status\u001b[0;34m()\u001b[0m\n",
|
||||||
|
"File \u001b[0;32m/opt/venv/lib/python3.10/site-packages/fsspec/implementations/local.py:469\u001b[0m, in \u001b[0;36mLocalFileOpener.write\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 468\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;21mwrite\u001b[39m(\u001b[38;5;28mself\u001b[39m, \u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs):\n\u001b[0;32m--> 469\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mf\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mwrite\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
|
||||||
|
"\u001b[0;31mOSError\u001b[0m: [Errno 28] No space left on device",
|
||||||
|
"\nDuring handling of the above exception, another exception occurred:\n",
|
||||||
|
"\u001b[0;31mOSError\u001b[0m Traceback (most recent call last)",
|
||||||
|
"File \u001b[0;32m/opt/venv/lib/python3.10/site-packages/datasets/builder.py:1850\u001b[0m, in \u001b[0;36mArrowBasedBuilder._prepare_split_single\u001b[0;34m(self, gen_kwargs, fpath, file_format, max_shard_size, job_id)\u001b[0m\n\u001b[1;32m 1849\u001b[0m num_shards \u001b[38;5;241m=\u001b[39m shard_id \u001b[38;5;241m+\u001b[39m \u001b[38;5;241m1\u001b[39m\n\u001b[0;32m-> 1850\u001b[0m num_examples, num_bytes \u001b[38;5;241m=\u001b[39m \u001b[43mwriter\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mfinalize\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1851\u001b[0m writer\u001b[38;5;241m.\u001b[39mclose()\n",
|
||||||
|
"File \u001b[0;32m/opt/venv/lib/python3.10/site-packages/datasets/arrow_writer.py:736\u001b[0m, in \u001b[0;36mArrowWriter.finalize\u001b[0;34m(self, close_stream)\u001b[0m\n\u001b[1;32m 735\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m close_stream:\n\u001b[0;32m--> 736\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mstream\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mclose\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 737\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n",
|
||||||
|
"File \u001b[0;32m/opt/venv/lib/python3.10/site-packages/fsspec/implementations/local.py:487\u001b[0m, in \u001b[0;36mLocalFileOpener.close\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 486\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;21mclose\u001b[39m(\u001b[38;5;28mself\u001b[39m):\n\u001b[0;32m--> 487\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mf\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mclose\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n",
|
||||||
|
"\u001b[0;31mOSError\u001b[0m: [Errno 28] No space left on device",
|
||||||
|
"\nThe above exception was the direct cause of the following exception:\n",
|
||||||
|
"\u001b[0;31mDatasetGenerationError\u001b[0m Traceback (most recent call last)",
|
||||||
|
"Cell \u001b[0;32mIn[14], line 82\u001b[0m\n\u001b[1;32m 79\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m {\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mtext\u001b[39m\u001b[38;5;124m\"\u001b[39m: text}\n\u001b[1;32m 81\u001b[0m \u001b[38;5;28mprint\u001b[39m(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mDownloading and parsing dataset...\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[0;32m---> 82\u001b[0m dataset \u001b[38;5;241m=\u001b[39m \u001b[43mload_dataset\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[38;5;124;43mmilkkarten/pokemon-showdown-replays-merged\u001b[39;49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43msplit\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[38;5;124;43mtrain\u001b[39;49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[43m)\u001b[49m\n\u001b[1;32m 84\u001b[0m \u001b[38;5;66;03m# Filter a small subset for demonstration purposes\u001b[39;00m\n\u001b[1;32m 85\u001b[0m dataset \u001b[38;5;241m=\u001b[39m dataset\u001b[38;5;241m.\u001b[39mfilter(\u001b[38;5;28;01mlambda\u001b[39;00m x: \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mgen9\u001b[39m\u001b[38;5;124m'\u001b[39m \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mstr\u001b[39m(x\u001b[38;5;241m.\u001b[39mget(\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mformatid\u001b[39m\u001b[38;5;124m'\u001b[39m)) \u001b[38;5;129;01mand\u001b[39;00m (x\u001b[38;5;241m.\u001b[39mget(\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mrating\u001b[39m\u001b[38;5;124m'\u001b[39m) \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;241m0\u001b[39m) \u001b[38;5;241m>\u001b[39m\u001b[38;5;241m=\u001b[39m \u001b[38;5;241m1400\u001b[39m)\n",
|
||||||
|
"File \u001b[0;32m/opt/venv/lib/python3.10/site-packages/datasets/load.py:1417\u001b[0m, in \u001b[0;36mload_dataset\u001b[0;34m(path, name, data_dir, data_files, split, cache_dir, features, download_config, download_mode, verification_mode, keep_in_memory, save_infos, revision, token, streaming, num_proc, storage_options, **config_kwargs)\u001b[0m\n\u001b[1;32m 1414\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m builder_instance\u001b[38;5;241m.\u001b[39mas_streaming_dataset(split\u001b[38;5;241m=\u001b[39msplit)\n\u001b[1;32m 1416\u001b[0m \u001b[38;5;66;03m# Download and prepare data\u001b[39;00m\n\u001b[0;32m-> 1417\u001b[0m \u001b[43mbuilder_instance\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mdownload_and_prepare\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 1418\u001b[0m \u001b[43m \u001b[49m\u001b[43mdownload_config\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mdownload_config\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1419\u001b[0m \u001b[43m \u001b[49m\u001b[43mdownload_mode\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mdownload_mode\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1420\u001b[0m \u001b[43m \u001b[49m\u001b[43mverification_mode\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mverification_mode\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1421\u001b[0m \u001b[43m \u001b[49m\u001b[43mnum_proc\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mnum_proc\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1422\u001b[0m \u001b[43m \u001b[49m\u001b[43mstorage_options\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mstorage_options\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1423\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1425\u001b[0m \u001b[38;5;66;03m# Build dataset for splits\u001b[39;00m\n\u001b[1;32m 1426\u001b[0m keep_in_memory \u001b[38;5;241m=\u001b[39m (\n\u001b[1;32m 1427\u001b[0m keep_in_memory \u001b[38;5;28;01mif\u001b[39;00m keep_in_memory \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m \u001b[38;5;28;01melse\u001b[39;00m is_small_dataset(builder_instance\u001b[38;5;241m.\u001b[39minfo\u001b[38;5;241m.\u001b[39mdataset_size)\n\u001b[1;32m 1428\u001b[0m )\n",
|
||||||
|
"File \u001b[0;32m/opt/venv/lib/python3.10/site-packages/datasets/builder.py:897\u001b[0m, in \u001b[0;36mDatasetBuilder.download_and_prepare\u001b[0;34m(self, output_dir, download_config, download_mode, verification_mode, dl_manager, base_path, file_format, max_shard_size, num_proc, storage_options, **download_and_prepare_kwargs)\u001b[0m\n\u001b[1;32m 895\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m num_proc \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[1;32m 896\u001b[0m prepare_split_kwargs[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mnum_proc\u001b[39m\u001b[38;5;124m\"\u001b[39m] \u001b[38;5;241m=\u001b[39m num_proc\n\u001b[0;32m--> 897\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_download_and_prepare\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 898\u001b[0m \u001b[43m \u001b[49m\u001b[43mdl_manager\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mdl_manager\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 899\u001b[0m \u001b[43m \u001b[49m\u001b[43mverification_mode\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mverification_mode\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 900\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mprepare_split_kwargs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 901\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mdownload_and_prepare_kwargs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 902\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 903\u001b[0m \u001b[38;5;66;03m# Sync info\u001b[39;00m\n\u001b[1;32m 904\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39minfo\u001b[38;5;241m.\u001b[39mdataset_size \u001b[38;5;241m=\u001b[39m \u001b[38;5;28msum\u001b[39m(split\u001b[38;5;241m.\u001b[39mnum_bytes \u001b[38;5;28;01mfor\u001b[39;00m split \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39minfo\u001b[38;5;241m.\u001b[39msplits\u001b[38;5;241m.\u001b[39mvalues())\n",
|
||||||
|
"File \u001b[0;32m/opt/venv/lib/python3.10/site-packages/datasets/builder.py:973\u001b[0m, in \u001b[0;36mDatasetBuilder._download_and_prepare\u001b[0;34m(self, dl_manager, verification_mode, **prepare_split_kwargs)\u001b[0m\n\u001b[1;32m 969\u001b[0m split_dict\u001b[38;5;241m.\u001b[39madd(split_generator\u001b[38;5;241m.\u001b[39msplit_info)\n\u001b[1;32m 971\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m 972\u001b[0m \u001b[38;5;66;03m# Prepare split will record examples associated to the split\u001b[39;00m\n\u001b[0;32m--> 973\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_prepare_split\u001b[49m\u001b[43m(\u001b[49m\u001b[43msplit_generator\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mprepare_split_kwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 974\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m \u001b[38;5;167;01mOSError\u001b[39;00m \u001b[38;5;28;01mas\u001b[39;00m e:\n\u001b[1;32m 975\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mOSError\u001b[39;00m(\n\u001b[1;32m 976\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mCannot find data file. \u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 977\u001b[0m \u001b[38;5;241m+\u001b[39m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mmanual_download_instructions \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[1;32m 978\u001b[0m \u001b[38;5;241m+\u001b[39m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;130;01m\\n\u001b[39;00m\u001b[38;5;124mOriginal error:\u001b[39m\u001b[38;5;130;01m\\n\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 979\u001b[0m \u001b[38;5;241m+\u001b[39m \u001b[38;5;28mstr\u001b[39m(e)\n\u001b[1;32m 980\u001b[0m ) \u001b[38;5;28;01mfrom\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mNone\u001b[39;00m\n",
|
||||||
|
"File \u001b[0;32m/opt/venv/lib/python3.10/site-packages/datasets/builder.py:1705\u001b[0m, in \u001b[0;36mArrowBasedBuilder._prepare_split\u001b[0;34m(self, split_generator, file_format, num_proc, max_shard_size)\u001b[0m\n\u001b[1;32m 1703\u001b[0m job_id \u001b[38;5;241m=\u001b[39m \u001b[38;5;241m0\u001b[39m\n\u001b[1;32m 1704\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m pbar:\n\u001b[0;32m-> 1705\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m job_id, done, content \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_prepare_split_single(\n\u001b[1;32m 1706\u001b[0m gen_kwargs\u001b[38;5;241m=\u001b[39mgen_kwargs, job_id\u001b[38;5;241m=\u001b[39mjob_id, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39m_prepare_split_args\n\u001b[1;32m 1707\u001b[0m ):\n\u001b[1;32m 1708\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m done:\n\u001b[1;32m 1709\u001b[0m result \u001b[38;5;241m=\u001b[39m content\n",
|
||||||
|
"File \u001b[0;32m/opt/venv/lib/python3.10/site-packages/datasets/builder.py:1861\u001b[0m, in \u001b[0;36mArrowBasedBuilder._prepare_split_single\u001b[0;34m(self, gen_kwargs, fpath, file_format, max_shard_size, job_id)\u001b[0m\n\u001b[1;32m 1859\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(e, DatasetGenerationError):\n\u001b[1;32m 1860\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m\n\u001b[0;32m-> 1861\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m DatasetGenerationError(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mAn error occurred while generating the dataset\u001b[39m\u001b[38;5;124m\"\u001b[39m) \u001b[38;5;28;01mfrom\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;21;01me\u001b[39;00m\n\u001b[1;32m 1863\u001b[0m \u001b[38;5;28;01myield\u001b[39;00m job_id, \u001b[38;5;28;01mTrue\u001b[39;00m, (total_num_examples, total_num_bytes, writer\u001b[38;5;241m.\u001b[39m_features, num_shards, shard_lengths)\n",
|
||||||
|
"\u001b[0;31mDatasetGenerationError\u001b[0m: An error occurred while generating the dataset"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"## 5. Quick validation SFT\n",
|
||||||
|
"Now we run a deliberately short 50-step supervised fine-tuning loop. The point is not to maximize ladder strength in one notebook session, but to make *Agentic Alignment* visible: even a tiny run is often enough to push the model toward strict action formatting instead of free-form commentary."
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"from trl import SFTConfig, SFTTrainer\n",
|
||||||
|
"from unsloth import is_bfloat16_supported\n",
|
||||||
|
"\n",
|
||||||
|
"trainer = SFTTrainer(\n",
|
||||||
|
" model=model,\n",
|
||||||
|
" tokenizer=tokenizer,\n",
|
||||||
|
" train_dataset=train_dataset,\n",
|
||||||
|
" dataset_text_field=\"text\",\n",
|
||||||
|
" max_seq_length=max_seq_length,\n",
|
||||||
|
" packing=False,\n",
|
||||||
|
" args=SFTConfig(\n",
|
||||||
|
" per_device_train_batch_size=2,\n",
|
||||||
|
" gradient_accumulation_steps=4,\n",
|
||||||
|
" warmup_steps=5,\n",
|
||||||
|
" max_steps=50,\n",
|
||||||
|
" learning_rate=2e-4,\n",
|
||||||
|
" fp16=not is_bfloat16_supported(),\n",
|
||||||
|
" bf16=is_bfloat16_supported(),\n",
|
||||||
|
" logging_steps=10,\n",
|
||||||
|
" optim=\"adamw_torch\",\n",
|
||||||
|
" weight_decay=0.01,\n",
|
||||||
|
" lr_scheduler_type=\"linear\",\n",
|
||||||
|
" seed=3407,\n",
|
||||||
|
" output_dir=\"model_output_v6_demo\",\n",
|
||||||
|
" save_steps=50,\n",
|
||||||
|
" save_total_limit=2,\n",
|
||||||
|
" report_to=\"none\",\n",
|
||||||
|
" ),\n",
|
||||||
|
")\n",
|
||||||
|
"\n",
|
||||||
|
"trainer_stats = trainer.train()\n",
|
||||||
|
"\n",
|
||||||
|
"# Optional: export a standalone merged checkpoint for sharing or Hub upload.\n",
|
||||||
|
"# export_dir = \"model_output_v6_demo/merged_model\"\n",
|
||||||
|
"# model.save_pretrained_merged(export_dir, tokenizer, save_method=\"merged_16bit\")"
|
||||||
|
],
|
||||||
|
"execution_count": null,
|
||||||
|
"outputs": []
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"## 6. Inference sanity check\n",
|
||||||
|
"Use the same prompt once before training and once after training if you want to see the alignment jump clearly. A good post-SFT answer should be short, action-oriented, and formatted as a legal `move` or `switch` command."
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"FastLanguageModel.for_inference(model)\n",
|
||||||
|
"\n",
|
||||||
|
"sys_msg = \"You are a Pokemon Showdown battle AI. You play as p2. Given the battle log, output your next action. Format: move <name> OR switch <name>. Append terastallize if you terastallize this turn.\"\n",
|
||||||
|
"user_msg = '''|player|p1|Player1|266|1500\n",
|
||||||
|
"|player|p2|Player2|1|1500\n",
|
||||||
|
"|teamsize|p1|6\n",
|
||||||
|
"|teamsize|p2|6\n",
|
||||||
|
"|gen|9\n",
|
||||||
|
"|tier|[Gen 9] OU\n",
|
||||||
|
"|\n",
|
||||||
|
"|start\n",
|
||||||
|
"|switch|p1a: Garchomp|Garchomp, M|100/100\n",
|
||||||
|
"|switch|p2a: Corviknight|Corviknight, M|100/100\n",
|
||||||
|
"|turn|1\n",
|
||||||
|
"|move|p1a: Garchomp|Earthquake|p2a: Corviknight\n",
|
||||||
|
"|-immune|p2a: Corviknight\n",
|
||||||
|
"|turn|2'''\n",
|
||||||
|
"\n",
|
||||||
|
"messages = [\n",
|
||||||
|
" {\"role\": \"system\", \"content\": sys_msg},\n",
|
||||||
|
" {\"role\": \"user\", \"content\": user_msg},\n",
|
||||||
|
"]\n",
|
||||||
|
"\n",
|
||||||
|
"text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)\n",
|
||||||
|
"inputs = tokenizer(text, return_tensors=\"pt\").to(\"cuda\")\n",
|
||||||
|
"\n",
|
||||||
|
"with torch.no_grad():\n",
|
||||||
|
" outputs = model.generate(\n",
|
||||||
|
" **inputs,\n",
|
||||||
|
" max_new_tokens=64,\n",
|
||||||
|
" temperature=0.1,\n",
|
||||||
|
" do_sample=False,\n",
|
||||||
|
" pad_token_id=tokenizer.eos_token_id,\n",
|
||||||
|
" )\n",
|
||||||
|
"\n",
|
||||||
|
"full_response = tokenizer.decode(outputs[0], skip_special_tokens=False)\n",
|
||||||
|
"if \"<|im_start|>assistant\\n\" in full_response:\n",
|
||||||
|
" response = full_response.split(\"<|im_start|>assistant\\n\")[-1]\n",
|
||||||
|
"else:\n",
|
||||||
|
" response = tokenizer.decode(outputs[0][inputs.input_ids.shape[-1]:], skip_special_tokens=True)\n",
|
||||||
|
"response = response.replace(\"<|im_end|>\", \"\").strip()\n",
|
||||||
|
"\n",
|
||||||
|
"print(\"\\n--- AI AGENT PREDICTION ---\")\n",
|
||||||
|
"print(response)"
|
||||||
|
],
|
||||||
|
"execution_count": null,
|
||||||
|
"outputs": []
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"## 7. Export and publish to Hugging Face\n",
|
||||||
|
"After the 50-step tutorial run, you can optionally merge the adapter into standalone weights and publish the result.\n",
|
||||||
|
"\n",
|
||||||
|
"```python\n",
|
||||||
|
"export_dir = \"model_output_v6_demo/merged_model\"\n",
|
||||||
|
"model.save_pretrained_merged(export_dir, tokenizer, save_method=\"merged_16bit\")\n",
|
||||||
|
"```\n",
|
||||||
|
"\n",
|
||||||
|
"Then upload the merged folder and notebook:\n",
|
||||||
|
"\n",
|
||||||
|
"```bash\n",
|
||||||
|
"hf auth whoami\n",
|
||||||
|
"hf upload-large-folder your-username/pokemon-showdown-agent-v6-demo model_output_v6_demo/merged_model\n",
|
||||||
|
"hf upload your-username/pokemon-showdown-agent-v6-demo pokemon_agent_demo_notebook_v2.ipynb pokemon_agent_demo_notebook_v6.ipynb\n",
|
||||||
|
"```\n",
|
||||||
|
"\n",
|
||||||
|
"The full production checkpoint is now published at `GoldenGrapeGentleman1/pokemon-showdown-agent-v6`, while this notebook remains the lighter tutorial flow for fast reproduction."
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"## Bonus: advanced tactics with GRPO\n",
|
||||||
|
"\n",
|
||||||
|
"Supervised fine-tuning teaches the model how strong players act. To push beyond imitation and reward outcomes or formatting more explicitly, you can add **Group Relative Policy Optimization (GRPO)** on top.\n",
|
||||||
|
"\n",
|
||||||
|
"For Pokemon, a first reward function can stay simple: reward valid action formatting, reward strategically sensible outputs, and penalize clearly bad or impossible commands. The cell below is still conceptual, but it shows how the `v6` agent can extend into RL after the SFT tutorial."
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"from trl import GRPOTrainer, GRPOConfig\n",
|
||||||
|
"import re\n",
|
||||||
|
"\n",
|
||||||
|
"# 1. Define a Reward Function (Example: Check if output contains a valid command)\n",
|
||||||
|
"def format_reward_func(prompts, completions, **kwargs):\n",
|
||||||
|
" rewards = []\n",
|
||||||
|
" for completion in completions:\n",
|
||||||
|
" # TRL completions may be strings or list of dicts depending on version\n",
|
||||||
|
" text = completion[0][\"content\"] if isinstance(completion, list) else str(completion)\n",
|
||||||
|
" \n",
|
||||||
|
" # Reward the model heavily if it successfully follows the format `move X` or `switch X`\n",
|
||||||
|
" cmd_match = re.search(r\"(move\\s+[\\w\\-]+|switch\\s+[\\w\\-]+)\", text, re.IGNORECASE)\n",
|
||||||
|
" if cmd_match:\n",
|
||||||
|
" rewards.append(5.0)\n",
|
||||||
|
" else:\n",
|
||||||
|
" rewards.append(-3.0) # Penalize conversational babbling\n",
|
||||||
|
" \n",
|
||||||
|
" return rewards\n",
|
||||||
|
"\n",
|
||||||
|
"# 2. Prepare GRPO Prompts (We only need the prompts, GRPO will generate its own completions)\n",
|
||||||
|
"grpo_prompts = [{\"prompt\": p[\"text\"].split(\"<|im_start|>assistant\")[0] + \"<|im_start|>assistant\\n\"} for p in train_samples[:50]]\n",
|
||||||
|
"from datasets import Dataset\n",
|
||||||
|
"grpo_dataset = Dataset.from_list(grpo_prompts)\n",
|
||||||
|
"\n",
|
||||||
|
"# 3. Configure GRPO\n",
|
||||||
|
"grpo_config = GRPOConfig(\n",
|
||||||
|
" output_dir = \"grpo_outputs\",\n",
|
||||||
|
" learning_rate = 3e-6,\n",
|
||||||
|
" per_device_train_batch_size = 1,\n",
|
||||||
|
" gradient_accumulation_steps = 4,\n",
|
||||||
|
" num_generations = 4, # Generate 4 different strategies to compare\n",
|
||||||
|
" max_completion_length = 128,\n",
|
||||||
|
" temperature = 1.3, # Encourage exploration\n",
|
||||||
|
" max_steps = 10,\n",
|
||||||
|
" logging_steps = 1,\n",
|
||||||
|
" report_to = \"none\",\n",
|
||||||
|
" fp16 = not is_bfloat16_supported(),\n",
|
||||||
|
" bf16 = is_bfloat16_supported(),\n",
|
||||||
|
")\n",
|
||||||
|
"\n",
|
||||||
|
"grpo_trainer = GRPOTrainer(\n",
|
||||||
|
" model = model,\n",
|
||||||
|
" reward_funcs = [format_reward_func],\n",
|
||||||
|
" args = grpo_config,\n",
|
||||||
|
" train_dataset = grpo_dataset,\n",
|
||||||
|
")\n",
|
||||||
|
"\n",
|
||||||
|
"# Uncomment to run the GRPO loop\n",
|
||||||
|
"# grpo_trainer.train()"
|
||||||
|
],
|
||||||
|
"execution_count": null,
|
||||||
|
"outputs": []
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"kernelspec": {
|
||||||
|
"display_name": "Python 3 (ipykernel)",
|
||||||
|
"language": "python",
|
||||||
|
"name": "python3"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"nbformat": 4,
|
||||||
|
"nbformat_minor": 2
|
||||||
|
}
|
||||||
31
special_tokens_map.json
Normal file
31
special_tokens_map.json
Normal file
@@ -0,0 +1,31 @@
|
|||||||
|
{
|
||||||
|
"additional_special_tokens": [
|
||||||
|
"<|im_start|>",
|
||||||
|
"<|im_end|>",
|
||||||
|
"<|object_ref_start|>",
|
||||||
|
"<|object_ref_end|>",
|
||||||
|
"<|box_start|>",
|
||||||
|
"<|box_end|>",
|
||||||
|
"<|quad_start|>",
|
||||||
|
"<|quad_end|>",
|
||||||
|
"<|vision_start|>",
|
||||||
|
"<|vision_end|>",
|
||||||
|
"<|vision_pad|>",
|
||||||
|
"<|image_pad|>",
|
||||||
|
"<|video_pad|>"
|
||||||
|
],
|
||||||
|
"eos_token": {
|
||||||
|
"content": "<|im_end|>",
|
||||||
|
"lstrip": false,
|
||||||
|
"normalized": false,
|
||||||
|
"rstrip": false,
|
||||||
|
"single_word": false
|
||||||
|
},
|
||||||
|
"pad_token": {
|
||||||
|
"content": "<|vision_pad|>",
|
||||||
|
"lstrip": false,
|
||||||
|
"normalized": false,
|
||||||
|
"rstrip": false,
|
||||||
|
"single_word": false
|
||||||
|
}
|
||||||
|
}
|
||||||
3
tokenizer.json
Normal file
3
tokenizer.json
Normal file
@@ -0,0 +1,3 @@
|
|||||||
|
version https://git-lfs.github.com/spec/v1
|
||||||
|
oid sha256:aeb13307a71acd8fe81861d94ad54ab689df773318809eed3cbe794b4492dae4
|
||||||
|
size 11422654
|
||||||
241
tokenizer_config.json
Normal file
241
tokenizer_config.json
Normal file
@@ -0,0 +1,241 @@
|
|||||||
|
{
|
||||||
|
"add_bos_token": false,
|
||||||
|
"add_prefix_space": false,
|
||||||
|
"added_tokens_decoder": {
|
||||||
|
"151643": {
|
||||||
|
"content": "<|endoftext|>",
|
||||||
|
"lstrip": false,
|
||||||
|
"normalized": false,
|
||||||
|
"rstrip": false,
|
||||||
|
"single_word": false,
|
||||||
|
"special": true
|
||||||
|
},
|
||||||
|
"151644": {
|
||||||
|
"content": "<|im_start|>",
|
||||||
|
"lstrip": false,
|
||||||
|
"normalized": false,
|
||||||
|
"rstrip": false,
|
||||||
|
"single_word": false,
|
||||||
|
"special": true
|
||||||
|
},
|
||||||
|
"151645": {
|
||||||
|
"content": "<|im_end|>",
|
||||||
|
"lstrip": false,
|
||||||
|
"normalized": false,
|
||||||
|
"rstrip": false,
|
||||||
|
"single_word": false,
|
||||||
|
"special": true
|
||||||
|
},
|
||||||
|
"151646": {
|
||||||
|
"content": "<|object_ref_start|>",
|
||||||
|
"lstrip": false,
|
||||||
|
"normalized": false,
|
||||||
|
"rstrip": false,
|
||||||
|
"single_word": false,
|
||||||
|
"special": true
|
||||||
|
},
|
||||||
|
"151647": {
|
||||||
|
"content": "<|object_ref_end|>",
|
||||||
|
"lstrip": false,
|
||||||
|
"normalized": false,
|
||||||
|
"rstrip": false,
|
||||||
|
"single_word": false,
|
||||||
|
"special": true
|
||||||
|
},
|
||||||
|
"151648": {
|
||||||
|
"content": "<|box_start|>",
|
||||||
|
"lstrip": false,
|
||||||
|
"normalized": false,
|
||||||
|
"rstrip": false,
|
||||||
|
"single_word": false,
|
||||||
|
"special": true
|
||||||
|
},
|
||||||
|
"151649": {
|
||||||
|
"content": "<|box_end|>",
|
||||||
|
"lstrip": false,
|
||||||
|
"normalized": false,
|
||||||
|
"rstrip": false,
|
||||||
|
"single_word": false,
|
||||||
|
"special": true
|
||||||
|
},
|
||||||
|
"151650": {
|
||||||
|
"content": "<|quad_start|>",
|
||||||
|
"lstrip": false,
|
||||||
|
"normalized": false,
|
||||||
|
"rstrip": false,
|
||||||
|
"single_word": false,
|
||||||
|
"special": true
|
||||||
|
},
|
||||||
|
"151651": {
|
||||||
|
"content": "<|quad_end|>",
|
||||||
|
"lstrip": false,
|
||||||
|
"normalized": false,
|
||||||
|
"rstrip": false,
|
||||||
|
"single_word": false,
|
||||||
|
"special": true
|
||||||
|
},
|
||||||
|
"151652": {
|
||||||
|
"content": "<|vision_start|>",
|
||||||
|
"lstrip": false,
|
||||||
|
"normalized": false,
|
||||||
|
"rstrip": false,
|
||||||
|
"single_word": false,
|
||||||
|
"special": true
|
||||||
|
},
|
||||||
|
"151653": {
|
||||||
|
"content": "<|vision_end|>",
|
||||||
|
"lstrip": false,
|
||||||
|
"normalized": false,
|
||||||
|
"rstrip": false,
|
||||||
|
"single_word": false,
|
||||||
|
"special": true
|
||||||
|
},
|
||||||
|
"151654": {
|
||||||
|
"content": "<|vision_pad|>",
|
||||||
|
"lstrip": false,
|
||||||
|
"normalized": false,
|
||||||
|
"rstrip": false,
|
||||||
|
"single_word": false,
|
||||||
|
"special": true
|
||||||
|
},
|
||||||
|
"151655": {
|
||||||
|
"content": "<|image_pad|>",
|
||||||
|
"lstrip": false,
|
||||||
|
"normalized": false,
|
||||||
|
"rstrip": false,
|
||||||
|
"single_word": false,
|
||||||
|
"special": true
|
||||||
|
},
|
||||||
|
"151656": {
|
||||||
|
"content": "<|video_pad|>",
|
||||||
|
"lstrip": false,
|
||||||
|
"normalized": false,
|
||||||
|
"rstrip": false,
|
||||||
|
"single_word": false,
|
||||||
|
"special": true
|
||||||
|
},
|
||||||
|
"151657": {
|
||||||
|
"content": "<tool_call>",
|
||||||
|
"lstrip": false,
|
||||||
|
"normalized": false,
|
||||||
|
"rstrip": false,
|
||||||
|
"single_word": false,
|
||||||
|
"special": false
|
||||||
|
},
|
||||||
|
"151658": {
|
||||||
|
"content": "</tool_call>",
|
||||||
|
"lstrip": false,
|
||||||
|
"normalized": false,
|
||||||
|
"rstrip": false,
|
||||||
|
"single_word": false,
|
||||||
|
"special": false
|
||||||
|
},
|
||||||
|
"151659": {
|
||||||
|
"content": "<|fim_prefix|>",
|
||||||
|
"lstrip": false,
|
||||||
|
"normalized": false,
|
||||||
|
"rstrip": false,
|
||||||
|
"single_word": false,
|
||||||
|
"special": false
|
||||||
|
},
|
||||||
|
"151660": {
|
||||||
|
"content": "<|fim_middle|>",
|
||||||
|
"lstrip": false,
|
||||||
|
"normalized": false,
|
||||||
|
"rstrip": false,
|
||||||
|
"single_word": false,
|
||||||
|
"special": false
|
||||||
|
},
|
||||||
|
"151661": {
|
||||||
|
"content": "<|fim_suffix|>",
|
||||||
|
"lstrip": false,
|
||||||
|
"normalized": false,
|
||||||
|
"rstrip": false,
|
||||||
|
"single_word": false,
|
||||||
|
"special": false
|
||||||
|
},
|
||||||
|
"151662": {
|
||||||
|
"content": "<|fim_pad|>",
|
||||||
|
"lstrip": false,
|
||||||
|
"normalized": false,
|
||||||
|
"rstrip": false,
|
||||||
|
"single_word": false,
|
||||||
|
"special": false
|
||||||
|
},
|
||||||
|
"151663": {
|
||||||
|
"content": "<|repo_name|>",
|
||||||
|
"lstrip": false,
|
||||||
|
"normalized": false,
|
||||||
|
"rstrip": false,
|
||||||
|
"single_word": false,
|
||||||
|
"special": false
|
||||||
|
},
|
||||||
|
"151664": {
|
||||||
|
"content": "<|file_sep|>",
|
||||||
|
"lstrip": false,
|
||||||
|
"normalized": false,
|
||||||
|
"rstrip": false,
|
||||||
|
"single_word": false,
|
||||||
|
"special": false
|
||||||
|
},
|
||||||
|
"151665": {
|
||||||
|
"content": "<tool_response>",
|
||||||
|
"lstrip": false,
|
||||||
|
"normalized": false,
|
||||||
|
"rstrip": false,
|
||||||
|
"single_word": false,
|
||||||
|
"special": false
|
||||||
|
},
|
||||||
|
"151666": {
|
||||||
|
"content": "</tool_response>",
|
||||||
|
"lstrip": false,
|
||||||
|
"normalized": false,
|
||||||
|
"rstrip": false,
|
||||||
|
"single_word": false,
|
||||||
|
"special": false
|
||||||
|
},
|
||||||
|
"151667": {
|
||||||
|
"content": "<think>",
|
||||||
|
"lstrip": false,
|
||||||
|
"normalized": false,
|
||||||
|
"rstrip": false,
|
||||||
|
"single_word": false,
|
||||||
|
"special": false
|
||||||
|
},
|
||||||
|
"151668": {
|
||||||
|
"content": "</think>",
|
||||||
|
"lstrip": false,
|
||||||
|
"normalized": false,
|
||||||
|
"rstrip": false,
|
||||||
|
"single_word": false,
|
||||||
|
"special": false
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"additional_special_tokens": [
|
||||||
|
"<|im_start|>",
|
||||||
|
"<|im_end|>",
|
||||||
|
"<|object_ref_start|>",
|
||||||
|
"<|object_ref_end|>",
|
||||||
|
"<|box_start|>",
|
||||||
|
"<|box_end|>",
|
||||||
|
"<|quad_start|>",
|
||||||
|
"<|quad_end|>",
|
||||||
|
"<|vision_start|>",
|
||||||
|
"<|vision_end|>",
|
||||||
|
"<|vision_pad|>",
|
||||||
|
"<|image_pad|>",
|
||||||
|
"<|video_pad|>"
|
||||||
|
],
|
||||||
|
"bos_token": null,
|
||||||
|
"clean_up_tokenization_spaces": false,
|
||||||
|
"eos_token": "<|im_end|>",
|
||||||
|
"errors": "replace",
|
||||||
|
"extra_special_tokens": {},
|
||||||
|
"model_max_length": 40960,
|
||||||
|
"pad_token": "<|vision_pad|>",
|
||||||
|
"padding_side": "left",
|
||||||
|
"split_special_tokens": false,
|
||||||
|
"tokenizer_class": "Qwen2Tokenizer",
|
||||||
|
"unk_token": null,
|
||||||
|
"chat_template": "\n{%- if tools %}\n {{- '<|im_start|>system\n' }}\n {%- if messages[0].role == 'system' %}\n {{- messages[0].content + '\n\n' }}\n {%- endif %}\n {{- \"# Tools\n\nYou may call one or more functions to assist with the user query.\n\nYou are provided with function signatures within <tools></tools> XML tags:\n<tools>\" }}\n {%- for tool in tools %}\n {{- \"\n\" }}\n {{- tool | tojson }}\n {%- endfor %}\n {{- \"\n</tools>\n\nFor each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:\n<tool_call>\n{\\\"name\\\": <function-name>, \\\"arguments\\\": <args-json-object>}\n</tool_call><|im_end|>\n\" }}\n{%- else %}\n {%- if messages[0].role == 'system' %}\n {{- '<|im_start|>system\n' + messages[0].content + '<|im_end|>\n' }}\n {%- endif %}\n{%- endif %}\n{%- set ns = namespace(multi_step_tool=true, last_query_index=messages|length - 1) %}\n{%- for forward_message in messages %}\n {%- set index = (messages|length - 1) - loop.index0 %}\n {%- set message = messages[index] %}\n {%- set current_content = message.content if message.content is not none else '' %}\n {%- set tool_start = '<tool_response>' %}\n {%- set tool_start_length = tool_start|length %}\n {%- set start_of_message = current_content[:tool_start_length] %}\n {%- set tool_end = '</tool_response>' %}\n {%- set tool_end_length = tool_end|length %}\n {%- set start_pos = (current_content|length) - tool_end_length %}\n {%- if start_pos < 0 %}\n {%- set start_pos = 0 %}\n {%- endif %}\n {%- set end_of_message = current_content[start_pos:] %}\n {%- if ns.multi_step_tool and message.role == \"user\" and not(start_of_message == tool_start and end_of_message == tool_end) %}\n {%- set ns.multi_step_tool = false %}\n {%- set ns.last_query_index = index %}\n {%- endif %}\n{%- endfor %}\n{%- for message in messages %}\n {%- if (message.role == \"user\") or (message.role == \"system\" and not loop.first) %}\n {{- '<|im_start|>' + message.role + '\n' + message.content + '<|im_end|>' + '\n' }}\n {%- elif message.role == \"assistant\" %}\n {%- set content = message.content %}\n {%- set reasoning_content = '' %}\n {%- if message.reasoning_content is defined and message.reasoning_content is not none %}\n {%- set reasoning_content = message.reasoning_content %}\n {%- else %}\n {%- if '</think>' in message.content %}\n {%- set content = (message.content.split('</think>')|last).lstrip('\n') %}\n {%- set reasoning_content = (message.content.split('</think>')|first).rstrip('\n') %}\n {%- set reasoning_content = (reasoning_content.split('<think>')|last).lstrip('\n') %}\n {%- endif %}\n {%- endif %}\n {%- if loop.index0 > ns.last_query_index %}\n {%- if loop.last or (not loop.last and reasoning_content) %}\n {{- '<|im_start|>' + message.role + '\n<think>\n' + reasoning_content.strip('\n') + '\n</think>\n\n' + content.lstrip('\n') }}\n {%- else %}\n {{- '<|im_start|>' + message.role + '\n' + content }}\n {%- endif %}\n {%- else %}\n {{- '<|im_start|>' + message.role + '\n' + content }}\n {%- endif %}\n {%- if message.tool_calls %}\n {%- for tool_call in message.tool_calls %}\n {%- if (loop.first and content) or (not loop.first) %}\n {{- '\n' }}\n {%- endif %}\n {%- if tool_call.function %}\n {%- set tool_call = tool_call.function %}\n {%- endif %}\n {{- '<tool_call>\n{\"name\": \"' }}\n {{- tool_call.name }}\n {{- '\", \"arguments\": ' }}\n {%- if tool_call.arguments is string %}\n {{- tool_call.arguments }}\n {%- else %}\n {{- tool_call.arguments | tojson }}\n {%- endif %}\n {{- '}\n</tool_call>' }}\n {%- endfor %}\n {%- endif %}\n {{- '<|im_end|>\n' }}\n {%- elif message.role == \"tool\" %}\n {%- if loop.first or (messages[loop.index0 - 1].role != \"tool\") %}\n {{- '<|im_start|>user' }}\n {%- endif %}\n {{- '\n<tool_response>\n' }}\n {{- message.content }}\n {{- '\n</tool_response>' }}\n {%- if loop.last or (messages[loop.index0 + 1].role != \"tool\") %}\n {{- '<|im_end|>\n' }}\n {%- endif %}\n {%- endif %}\n{%- endfor %}\n{%- if add_generation_prompt %}\n {{- '<|im_start|>assistant\n' }}\n {%- if enable_thinking is defined and enable_thinking is false %}\n {{- '<think>\n\n</think>\n\n' }}\n {%- endif %}\n{%- endif %}\n"
|
||||||
|
}
|
||||||
1
vocab.json
Normal file
1
vocab.json
Normal file
File diff suppressed because one or more lines are too long
Reference in New Issue
Block a user