274 lines
11 KiB
Markdown
274 lines
11 KiB
Markdown
|
|
---
|
|||
|
|
language:
|
|||
|
|
- ja
|
|||
|
|
- en
|
|||
|
|
license:
|
|||
|
|
- llama3.1
|
|||
|
|
- gemma
|
|||
|
|
library_name: transformers
|
|||
|
|
base_model:
|
|||
|
|
- DataPilot/ArrowCanaria-Llama-8B-SFT-v0.1
|
|||
|
|
tags:
|
|||
|
|
- llama3
|
|||
|
|
- rlhf
|
|||
|
|
- grpo
|
|||
|
|
- dapo
|
|||
|
|
- japanese
|
|||
|
|
- aituber
|
|||
|
|
- roleplay
|
|||
|
|
- chat
|
|||
|
|
pipeline_tag: text-generation
|
|||
|
|
---
|
|||
|
|
|
|||
|
|
# ArrowCanaria-Llama-8B-RL-v0.1
|
|||
|
|
|
|||
|
|
## モデル概要
|
|||
|
|
|
|||
|
|
**ArrowCanaria-Llama-8B-RL-v0.1** は、[ArrowCanaria-Llama-8B-SFT-v0.1](https://huggingface.co/DataPilot/ArrowCanaria-Llama-8B-SFT-v0.1) に対して RLHF(Reinforcement Learning from Human Feedback)を適用し、応答品質をさらに向上させたAItuber向け日本語特化モデルです。
|
|||
|
|
|
|||
|
|
SFTモデルは高品質なデータで学習されていますが、モデルの応答が「データに含まれる平均的な応答」に収束してしまう傾向があります。本モデルでは、外部報酬モデル(Reward Model)によるフィードバックを用いた強化学習を適用することで、**共感・傾聴の品質**や**知識応答の正確性・分かりやすさ**を、SFTの水準からさらに引き上げています。
|
|||
|
|
|
|||
|
|
強化学習アルゴリズムには **GRPO(Group Relative Policy Optimization)** を採用し、DAPO損失関数による安定した最適化を実現しています。相談応答と知識応答の2フェーズで段階的にRLHFを行うことで、SFTで獲得した雑談・RP・キャラクター対話能力を保持しつつ、応答の質を選択的に向上させています。
|
|||
|
|
|
|||
|
|
### 想定ユースケース
|
|||
|
|
|
|||
|
|
- **AItuber / AI VTuber**: 配信中のリスナーとの雑談・コメント対応
|
|||
|
|
- **チャットボット**: 自然な日本語での日常会話・悩み相談
|
|||
|
|
- **ロールプレイ**: キャラクターを演じた対話・創作支援
|
|||
|
|
- **汎用アシスタント**: 知識応答・推論・ツール呼び出しを含む幅広いタスク
|
|||
|
|
|
|||
|
|
### 主な特徴
|
|||
|
|
|
|||
|
|
- 🗣️ **自然な日本語応答** — 定型文や翻訳調を排した、人間らしい対話。RLHFによりさらに自然さが向上
|
|||
|
|
- 💬 **高い雑談・相談性能** — 共感・傾聴・具体的助言の品質が報酬モデルによるフィードバックで最適化済み
|
|||
|
|
- 🎭 **RP・キャラクター対話** — SFTで獲得した一貫した人格・感情表現を保持
|
|||
|
|
- 🧠 **推論力・知識応答** — 正確で分かりやすい知識応答をRLHFで強化
|
|||
|
|
- 🔧 **Tool Use / RAG** — Function Calling・検索拡張生成に対応
|
|||
|
|
- ✍️ **クリエイティブ表現** — 文学的な比喩・暗喩・情景描写など、豊かな日本語表現力
|
|||
|
|
|
|||
|
|
### モデル仕様
|
|||
|
|
|
|||
|
|
| 項目 | 詳細 |
|
|||
|
|
|---|---|
|
|||
|
|
| **モデル名** | `DataPilot/ArrowCanaria-Llama-8B-RL-v0.1` |
|
|||
|
|
| **SFTモデル** | `DataPilot/ArrowCanaria-Llama-8B-SFT-v0.1` |
|
|||
|
|
| **ベースモデル** | `tokyotech-llm/Llama-3.1-Swallow-8B-Instruct-v0.5` |
|
|||
|
|
| **アーキテクチャ** | Llama 3.1 (Transformer decoder-only) |
|
|||
|
|
| **パラメータ数** | 8B |
|
|||
|
|
| **RLHFアルゴリズム** | GRPO + DAPO損失 |
|
|||
|
|
| **RLHFデータ量** | 3,600件(相談 1,600件 + 知識QA 2,000件) |
|
|||
|
|
| **コンテキスト長** | 4,096 tokens |
|
|||
|
|
| **精度** | BF16 |
|
|||
|
|
| **ライセンス** | Llama 3.1 Community License + Gemma Terms of use License|
|
|||
|
|
|
|||
|
|
---
|
|||
|
|
|
|||
|
|
## 推論方法
|
|||
|
|
|
|||
|
|
### 🤗 Transformers
|
|||
|
|
|
|||
|
|
```python
|
|||
|
|
from transformers import AutoModelForCausalLM, AutoTokenizer
|
|||
|
|
|
|||
|
|
model_name = "DataPilot/ArrowCanaria-Llama-8B-RL-v0.1"
|
|||
|
|
|
|||
|
|
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
|||
|
|
model = AutoModelForCausalLM.from_pretrained(
|
|||
|
|
model_name,
|
|||
|
|
torch_dtype="bfloat16",
|
|||
|
|
device_map="auto",
|
|||
|
|
)
|
|||
|
|
|
|||
|
|
messages = [
|
|||
|
|
{"role": "system", "content": "あなたは親しみやすいAIアシスタントです。自然な日本語で会話してください。"},
|
|||
|
|
{"role": "user", "content": "最近ちょっと疲れてるんだよね。何かリフレッシュできる方法ない?"},
|
|||
|
|
]
|
|||
|
|
|
|||
|
|
input_ids = tokenizer.apply_chat_template(
|
|||
|
|
messages,
|
|||
|
|
add_generation_prompt=True,
|
|||
|
|
return_tensors="pt",
|
|||
|
|
).to(model.device)
|
|||
|
|
|
|||
|
|
outputs = model.generate(
|
|||
|
|
input_ids,
|
|||
|
|
max_new_tokens=512,
|
|||
|
|
temperature=0.7,
|
|||
|
|
top_p=0.9,
|
|||
|
|
repetition_penalty=1.05,
|
|||
|
|
do_sample=True,
|
|||
|
|
)
|
|||
|
|
|
|||
|
|
response = tokenizer.decode(outputs[0][input_ids.shape[-1]:], skip_special_tokens=True)
|
|||
|
|
print(response)
|
|||
|
|
```
|
|||
|
|
|
|||
|
|
### ⚡ vLLM
|
|||
|
|
|
|||
|
|
```bash
|
|||
|
|
# vLLM サーバーの起動
|
|||
|
|
vllm serve DataPilot/ArrowCanaria-Llama-8B-RL-v0.1 \
|
|||
|
|
--dtype bfloat16 \
|
|||
|
|
--max-model-len 4096 \
|
|||
|
|
--host 0.0.0.0 \
|
|||
|
|
--port 8000
|
|||
|
|
```
|
|||
|
|
|
|||
|
|
```bash
|
|||
|
|
# リクエスト例(curl)
|
|||
|
|
curl http://localhost:8000/v1/chat/completions \
|
|||
|
|
-H "Content-Type: application/json" \
|
|||
|
|
-d '{
|
|||
|
|
"model": "DataPilot/ArrowCanaria-Llama-8B-RL-v0.1",
|
|||
|
|
"messages": [
|
|||
|
|
{"role": "system", "content": "あなたは親しみやすいAIアシスタントです。自然な日本語で会話してください。"},
|
|||
|
|
{"role": "user", "content": "最近ちょっと疲れてるんだよね。何かリフレッシュできる方法ない?"}
|
|||
|
|
],
|
|||
|
|
"temperature": 0.7,
|
|||
|
|
"max_tokens": 512
|
|||
|
|
}'
|
|||
|
|
```
|
|||
|
|
|
|||
|
|
---
|
|||
|
|
|
|||
|
|
## データ概要
|
|||
|
|
|
|||
|
|
### SFTデータ(ArrowCanaria-Llama-8B-SFT-v0.1 で使用)
|
|||
|
|
|
|||
|
|
SFTモデルは、独自の合成データ生成フレームワーク **[SDG_LOOM](https://zenn.dev/holy_fox/articles/6c318b7e9e0b55)** を活用して作成された **175,000件超** の高品質合成データセットで学習されています。
|
|||
|
|
|
|||
|
|
| 言語 | 件数 | 比率 |
|
|||
|
|
|---|---|---|
|
|||
|
|
| 日本語 | 124,000件 | 70.9% |
|
|||
|
|
| 英語 | 51,000件 | 29.1% |
|
|||
|
|
| **合計** | **175,000件超** | **100%** |
|
|||
|
|
|
|||
|
|
SFTデータの内訳:知識応答・雑談/相談・Tool Use・RAG・推論・RP・AItuber RP・クリエイティブ表現
|
|||
|
|
|
|||
|
|
### RLHFデータ
|
|||
|
|
|
|||
|
|
RLHFでは、以下の2つのデータセットを用いて強化学習を実施しました。
|
|||
|
|
|
|||
|
|
| フェーズ | データセット | 件数 | 目的 |
|
|||
|
|
|---|---|---|---|
|
|||
|
|
| **Phase 1** | `内製EQデータセット` | 1,600件 | 相談応答の共感・傾聴品質の最適化 |
|
|||
|
|
| **Phase 2** | `DataPilot/Zero_SFT_Ja_v3.5` | 2,000件 | 知識応答の正確性・分かりやすさの最適化 |
|
|||
|
|
| | **合計** | **3,600件** | |
|
|||
|
|
|
|||
|
|
RLHFでは学習データの回答(assistant部分)は使用せず、ユーザーの質問(prompt)のみを用いてモデルに応答を生成させ、外部報酬モデルのスコアに基づいて方策を最適化しています。
|
|||
|
|
|
|||
|
|
---
|
|||
|
|
|
|||
|
|
## 学習概要
|
|||
|
|
|
|||
|
|
### モデル作成パイプライン
|
|||
|
|
|
|||
|
|
本モデルは以下の4段階で構築されています:
|
|||
|
|
|
|||
|
|
```
|
|||
|
|
1. CPT(Continual Pre-Training)
|
|||
|
|
tokyotech-llm/Llama-3.1-Swallow-8B-Instruct-v0.5
|
|||
|
|
+ 10k フィルタリング済み高品質日本語合成小説データセット
|
|||
|
|
↓
|
|||
|
|
2. Chat Vector マージ(Mergekit)
|
|||
|
|
CPTモデルに元モデルの Chat Vector をマージし対話能力を復元
|
|||
|
|
↓
|
|||
|
|
3. SFT(Supervised Fine-Tuning)
|
|||
|
|
175,000件超の合成データセットで LoRA による
|
|||
|
|
3フェーズカリキュラム学習 → ArrowCanaria-Llama-8B-SFT-v0.1
|
|||
|
|
↓
|
|||
|
|
4. RLHF(Reinforcement Learning from Human Feedback)
|
|||
|
|
GRPO + 外部報酬モデルによる
|
|||
|
|
2フェーズ強化学習 → ArrowCanaria-Llama-8B-RL-v0.1 ★本モデル
|
|||
|
|
```
|
|||
|
|
|
|||
|
|
### RLHF 学習設計
|
|||
|
|
|
|||
|
|
#### アルゴリズム: GRPO + DAPO
|
|||
|
|
|
|||
|
|
本モデルでは **GRPO(Group Relative Policy Optimization)** を採用しています。GRPOは各プロンプトに対して複数の応答候補を生成し、それらの報酬スコアの相対的な優劣(アドバンテージ)に基づいて方策を更新するアルゴリズムです。PPOと異なりCritic(価値関数)を必要としないため、メモリ効率が高く安定した学習が可能です。
|
|||
|
|
|
|||
|
|
損失関数には **DAPO(Direct Advantage Policy Optimization)** を使用しています。DAPOはGRPOの改良版で、KLペナルティを用いずクリッピングのみで方策の更新幅を制御するため、ハイパーパラメータのチューニングが簡易で学習が安定します。
|
|||
|
|
|
|||
|
|
#### 2フェーズ構成
|
|||
|
|
|
|||
|
|
```
|
|||
|
|
Phase 1: 相談 RLHF ──▶ Phase 2: 知識QA RLHF
|
|||
|
|
1,600件 2,000件
|
|||
|
|
「共感・傾聴の最適化」 「正確さ・分かりやすさの最適化」
|
|||
|
|
```
|
|||
|
|
|
|||
|
|
| フェーズ | データ件数 | 目的 | 学習率 |
|
|||
|
|
|---|---|---|---|
|
|||
|
|
| **Phase 1: 相談 RLHF** | 1,600件 | 共感・傾聴・具体的助言の品質を RM フィードバックで最適化 | `5e-6` |
|
|||
|
|
| **Phase 2: 知識QA RLHF** | 2,000件 | 知識応答の正確性・分かりやすさを最適化。Phase 1 の品質を保持 | `3e-6` |
|
|||
|
|
|
|||
|
|
#### 学習率の設計思想
|
|||
|
|
|
|||
|
|
```
|
|||
|
|
P1 (5e-6) > P2 (3e-6)
|
|||
|
|
```
|
|||
|
|
|
|||
|
|
- **Phase 1 (`5e-6`)**: 相談応答はモデルのEQ中核能力であり、積極的に最適化
|
|||
|
|
- **Phase 2 (`3e-6`)**: Phase 1 で獲得した相談品質を保護しつつ、知識応答を控えめに改善。RLHFの2段階目は過学習リスクが高いため慎重に設定
|
|||
|
|
|
|||
|
|
#### 報酬設計
|
|||
|
|
|
|||
|
|
| 報酬コンポーネント | 設定 | 説明 |
|
|||
|
|
|---|---|---|
|
|||
|
|
| **外部報酬モデル(RM)** | weight: 1.0 | HTTPサーバーとして稼働する外部RMが応答品質をスコアリング |
|
|||
|
|
| **報酬シェーピング** | tanh (scale: 3.0) | 生のRMスコアをtanhで非線形変換し、外れ値の影響を抑制 |
|
|||
|
|
| **過剰長ペナルティ** | weight: 0.12 | 冗長な応答にソフトペナルティを付与し、簡潔さを促進 |
|
|||
|
|
|
|||
|
|
#### 生成パラメータ(GRPO候補生成)
|
|||
|
|
|
|||
|
|
| パラメータ | 値 |
|
|||
|
|
|---|---|
|
|||
|
|
| **num_generations** | 8(各プロンプトに対して8候補を生成して比較) |
|
|||
|
|
| **temperature** | 0.75 |
|
|||
|
|
| **top_p** | 0.95 |
|
|||
|
|
| **top_k** | 40 |
|
|||
|
|
| **max_seq_length** | 2,048 tokens |
|
|||
|
|
| **min_completion_length** | 256 tokens |
|
|||
|
|
|
|||
|
|
#### 主な学習パラメータ
|
|||
|
|
|
|||
|
|
| パラメータ | Phase 1 | Phase 2 |
|
|||
|
|
|---|---|---|
|
|||
|
|
| **学習率** | 5e-6 | 3e-6 |
|
|||
|
|
| **Warmup ratio** | 0.05 | 0.08 |
|
|||
|
|
| **Weight decay** | 0.01 | 0.01 |
|
|||
|
|
| **実効バッチサイズ** | 4(batch=1 × accum=4) | 4(batch=1 × accum=4) |
|
|||
|
|
| **Max steps** | 400 | 500 |
|
|||
|
|
| **LoRA rank (r)** | 32 | 32 |
|
|||
|
|
| **損失関数** | DAPO | DAPO |
|
|||
|
|
| **Beta(KLペナルティ)** | 0.0 | 0.0 |
|
|||
|
|
| **Epsilon(クリッピング)** | 0.2 | 0.2 |
|
|||
|
|
| **学習フレームワーク** | Unsloth + TRL `GRPOTrainer` + vLLM | 同左 |
|
|||
|
|
|
|||
|
|
#### Phase 2 の工夫
|
|||
|
|
|
|||
|
|
- **Warmup ratio を 0.08 に増加**(Phase 1 は 0.05):既にRLHF済みのモデルへの追加学習のため、急激なパラメータ変動を抑制
|
|||
|
|
- **学習率を Phase 1 の 60%(3e-6)に低減**:Phase 1 で獲得した相談品質を最大限保護しつつ、知識応答品質を向上
|
|||
|
|
- **`--force_system_prompt` の適用**:データセット固有のシステムプロンプトを統一し、一貫した応答スタイルを維持
|
|||
|
|
|
|||
|
|
---
|
|||
|
|
|
|||
|
|
## 引用
|
|||
|
|
|
|||
|
|
本モデルを利用した場合は、以下の情報を引用・参照していただけると幸いです:
|
|||
|
|
|
|||
|
|
```
|
|||
|
|
DataPilot/ArrowCanaria-Llama-8B-RL-v0.1
|
|||
|
|
SFT model: DataPilot/ArrowCanaria-Llama-8B-SFT-v0.1
|
|||
|
|
Base model: tokyotech-llm/Llama-3.1-Swallow-8B-Instruct-v0.5
|
|||
|
|
```
|
|||
|
|
|
|||
|
|
## 謝辞
|
|||
|
|
|
|||
|
|
- [tokyotech-llm/Llama-3.1-Swallow-8B-Instruct-v0.5](https://huggingface.co/tokyotech-llm/Llama-3.1-Swallow-8B-Instruct-v0.5) — ベースモデル
|
|||
|
|
- [Unsloth](https://github.com/unslothai/unsloth) — 高速 LoRA 学習フレームワーク
|
|||
|
|
- [TRL](https://github.com/huggingface/trl) — GRPOTrainer
|
|||
|
|
- [vLLM](https://github.com/vllm-project/vllm) — GRPO候補生成エンジン
|
|||
|
|
- [SDG_LOOM](https://zenn.dev/holy_fox/articles/6c318b7e9e0b55) — 合成データ生成フレームワーク(SFTデータ作成に使用)
|