Files
palindrome-grpo/train_grpo.py
ModelHub XC 3ab2d41b26 初始化项目,由ModelHub XC社区提供模型
Model: SantiagoC/palindrome-grpo
Source: Original Platform
2026-06-15 23:56:17 +08:00

273 lines
8.7 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
GRPO training script for palindrome generation.
Trains a small LLM (Qwen2.5-0.5B-Instruct) to generate palindromes
given a theme using Group Relative Policy Optimization.
Reward function:
1. Palindrome accuracy (continuous partial-credit: proportion of matching
mirrored character pairs)
2. Length bonus (longer palindromes = more reward)
3. Theme relevance (keyword overlap between theme and generated text)
reward = α * palindrome_accuracy + β * length_bonus + γ * theme_relevance
"""
import re
import os
from datasets import load_dataset, Dataset
from trl import GRPOConfig, GRPOTrainer
from transformers import TrainerCallback
import trackio
# ── Reward function ─────────────────────────────────────────────────────
def _normalize(text: str) -> str:
"""Strip non-alphanumeric and lowercase."""
return re.sub(r'[^a-zA-Z0-9]', '', text).lower()
def _palindrome_accuracy(text: str) -> float:
    """
    Score how close *text* is to a palindrome, as a fraction in [0, 1].

    The text is first normalized with ``_normalize``. Each character in the
    first half is compared with its mirror in the second half (the middle
    character of an odd-length string is ignored). The score is the share of
    mirrored pairs that match:

    * perfect palindrome -> 1.0
    * "ractar" (close to "racecar") -> 2/3, since the pair c/t differs
    * empty after normalization -> 0.0; a single character -> 1.0
    """
    chars = _normalize(text)
    size = len(chars)
    if size < 2:
        return 0.0 if size == 0 else 1.0

    half = size // 2
    # zip stops after `half` items, pairing chars[k] with chars[size - 1 - k].
    hits = sum(front == back for front, back in zip(chars[:half], reversed(chars)))
    return hits / half
def _length_bonus(text: str, max_len: int = 80) -> float:
    """
    Reward longer outputs: normalized character count divided by *max_len*.

    The count comes from ``_normalize(text)`` and is capped at *max_len*, so
    the result is in [0, 1] for a positive *max_len*. Note that this measures
    length only; it does not check whether the text is a palindrome.
    """
    n_chars = len(_normalize(text))
    capped = n_chars if n_chars < max_len else max_len
    return capped / max_len
def _theme_relevance(text: str, theme: str) -> float:
    """
    Estimate how well *text* reflects *theme*, as a value in [0, 1].

    Both strings are normalized with ``_normalize`` before comparison.

    * Empty theme after normalization -> 0.0.
    * Whole theme occurs in the text -> 1.0.
    * Theme longer than 3 characters -> fraction of the theme's distinct
      character bigrams that appear somewhere in the text.
    * Otherwise (short theme, no exact match) -> 0.0.
    """
    haystack = _normalize(text)
    needle = _normalize(theme)
    if not needle:
        return 0.0
    if needle in haystack:
        return 1.0
    # Short themes have too few bigrams for partial credit to be meaningful.
    if len(needle) <= 3:
        return 0.0

    bigrams = {needle[k:k + 2] for k in range(len(needle) - 1)}
    hits = sum(bg in haystack for bg in bigrams)
    return hits / len(bigrams)
def palindrome_reward(
    prompts,
    completions,
    completion_ids,
    theme=None,
    **kwargs,
) -> list[float]:
    """
    GRPO reward function for palindrome generation.

    Each completion is scored as a weighted sum of three components, each in
    [0, 1]:
    0.6 * palindrome accuracy + 0.25 * length bonus + 0.15 * theme relevance.

    Args:
        prompts: conversational prompts (not used for scoring).
        completions: generated completions. Each one is either a list of chat
            messages, where the first message's ``"content"`` holds the text,
            or any object whose ``str()`` is the text.
        completion_ids: token IDs of the completions (not used for scoring).
        theme: themes from the dataset's ``theme`` column, indexed in
            lockstep with *completions* (``theme[i]`` goes with
            ``completions[i]``). A missing list, an out-of-range index or an
            empty entry yields zero theme relevance.
        **kwargs: any other keyword arguments from the trainer (ignored).
    Returns:
        list[float]: one reward per completion, in input order.
    """
    # Component weights (tunable). Named w_* so they are not confused with the
    # KL coefficient ``beta`` in GRPOConfig.
    w_accuracy = 0.6
    w_length = 0.25
    w_theme = 0.15
    themes = [] if theme is None else theme

    scores = []
    for idx, completion in enumerate(completions):
        # Conversational completions arrive as a list of message dicts.
        if isinstance(completion, list):
            raw = completion[0]["content"]
        else:
            raw = str(completion)
        text = raw.strip()

        accuracy = _palindrome_accuracy(text)
        length_score = _length_bonus(text)

        theme_text = themes[idx] if idx < len(themes) else ""
        relevance = _theme_relevance(text, theme_text) if theme_text else 0.0

        scores.append(w_accuracy * accuracy + w_length * length_score + w_theme * relevance)
    return scores
# ── Trackio alert callback ──────────────────────────────────────────────
class PalindromeAlertCallback(TrainerCallback):
    """Log Trackio alerts on key training events.

    Hooks used:
      * ``on_train_begin`` / ``on_train_end``: one INFO alert each.
      * ``on_log``: inspects the ``reward``, ``kl`` and ``loss`` entries of
        each log dict and raises WARN/INFO alerts when they cross the
        thresholds in the ``_check_*`` helpers.
    """

    def __init__(self):
        super().__init__()
        # The high-reward milestone is a one-off event; this flag keeps it
        # from being re-sent on every logging step while reward stays >= 0.7.
        self._high_reward_alerted = False

    def on_train_begin(self, args, state, control, **kwargs):
        # The model name is hard-coded; keep it in sync with the model that
        # main() passes to GRPOTrainer.
        trackio.alert(
            "Training started",
            f"Model: Qwen/Qwen2.5-0.5B-Instruct | "
            f"num_generations={args.num_generations} | "
            f"temperature={args.temperature} | "
            f"lr={args.learning_rate} | "
            f"batch_size={args.per_device_train_batch_size}×"
            f"{args.gradient_accumulation_steps}",
            level="INFO",
        )

    def on_train_end(self, args, state, control, **kwargs):
        # This hook runs inside trainer.train(), i.e. before main() performs
        # the final upload, so the message must not claim the push is done.
        trackio.alert(
            "Training complete",
            f"Training finished at step {state.global_step}; "
            f"final model will be pushed to {args.hub_model_id}",
            level="INFO",
        )

    def on_log(self, args, state, control, logs=None, **kwargs):
        """Check the freshly logged metrics and raise alerts if needed."""
        if logs is None:
            return
        step = state.global_step
        self._check_reward(logs.get("reward"), step)
        self._check_kl(logs.get("kl"), step)
        self._check_loss(logs.get("loss"), step)

    def _check_reward(self, reward, step):
        """Alert on the logged ``reward`` value (no-op when it is absent)."""
        if reward is None:
            return
        # The low/moderate bands are only checked after a warm-up of 50
        # steps. No alert is raised for 0.5 <= reward < 0.7.
        if reward < 0.1 and step > 50:
            trackio.alert(
                "Low reward",
                f"reward={reward:.4f} at step {step}: "
                "model not producing palindromes yet. Consider increasing "
                "temperature for more exploration.",
                level="WARN",
            )
        elif 0.1 <= reward < 0.5 and step > 50:
            trackio.alert(
                "Moderate reward",
                f"reward={reward:.4f} at step {step}: "
                "model starting to learn. Continue training.",
                level="INFO",
            )
        elif reward >= 0.7 and not self._high_reward_alerted:
            self._high_reward_alerted = True
            trackio.alert(
                "High reward milestone",
                f"reward={reward:.4f} at step {step}: "
                "model producing good palindromes!",
                level="INFO",
            )

    def _check_kl(self, kl, step):
        """Warn when the logged ``kl`` value exceeds 0.5 (no-op if absent)."""
        # NOTE(review): main() sets beta=0.0; the trainer may then not log a
        # "kl" entry at all, in which case this never fires. Confirm.
        if kl is not None and kl > 0.5:
            trackio.alert(
                "High KL divergence",
                f"kl={kl:.4f} at step {step}: "
                "policy diverging too far. Consider lowering learning rate.",
                level="WARN",
            )

    def _check_loss(self, loss, step):
        """Warn when the logged ``loss`` value exceeds 10 (no-op if absent)."""
        if loss is not None and loss > 10:
            trackio.alert(
                "High loss",
                f"loss={loss:.4f} at step {step}: "
                "loss spike detected. Check for instability, "
                "consider reducing learning rate by 0.5×.",
                level="WARN",
            )
# ── Main ────────────────────────────────────────────────────────────────
def main():
    """Load the prompt dataset, run GRPO training and publish the final model.

    Environment overrides (defaults in parentheses):
        DATASET_PATH      dataset repo/path ("SantiagoC/palindrome-prompts")
        HUB_MODEL_ID      Hub repo for the trained model ("SantiagoC/palindrome-grpo")
        RUN_NAME          experiment run name ("palindrome-grpo-v1")
        TRACKIO_SPACE_ID  Trackio Space for metrics ("SantiagoC/mlintern-palindrm")

    Raises:
        ValueError: if the dataset's "train" split is empty.
    """
    # Load dataset
    dataset_path = os.getenv("DATASET_PATH", "SantiagoC/palindrome-prompts")
    dataset = load_dataset(dataset_path, split="train")
    print(f"Dataset size: {len(dataset)} samples")
    if len(dataset) == 0:
        # Fail with a clear message instead of an IndexError on dataset[0].
        raise ValueError(f"Dataset {dataset_path!r} has no rows in its 'train' split")
    print(f"Columns: {dataset.column_names}")
    print(f"First prompt: {dataset[0]['prompt']}")

    # Training config
    training_args = GRPOConfig(
        output_dir="./palindrome-grpo-output",
        # ── GRPO specific ──
        num_generations=4,
        max_completion_length=128,
        temperature=0.9,
        beta=0.0,  # No reference model KL penalty
        scale_rewards=False,  # Dr.GRPO recommendation
        loss_type="dapo",
        mask_truncated_completions=True,
        # ── Standard training ──
        learning_rate=1e-6,
        num_train_epochs=6,
        per_device_train_batch_size=2,
        gradient_accumulation_steps=4,
        bf16=True,
        gradient_checkpointing=True,
        # ── Logging ──
        logging_steps=10,
        logging_first_step=True,
        save_steps=100,
        save_total_limit=3,
        disable_tqdm=True,
        report_to="trackio",
        # ── Hub push ──
        push_to_hub=True,
        hub_model_id=os.getenv("HUB_MODEL_ID", "SantiagoC/palindrome-grpo"),
        run_name=os.getenv("RUN_NAME", "palindrome-grpo-v1"),
        project="palindrome-grpo",
        trackio_space_id=os.getenv("TRACKIO_SPACE_ID", "SantiagoC/mlintern-palindrm"),
    )

    trainer = GRPOTrainer(
        model="Qwen/Qwen2.5-0.5B-Instruct",
        reward_funcs=palindrome_reward,
        args=training_args,
        train_dataset=dataset,
        callbacks=[PalindromeAlertCallback()],
    )
    trainer.train()

    # Publish the final model.
    # - Trainer.push_to_hub() saves the model itself before uploading.
    # - Trainer.save_model() already uploads whenever push_to_hub=True.
    # Calling both would therefore upload the final weights twice, so only
    # save locally when Hub pushing is disabled.
    if training_args.push_to_hub:
        trainer.push_to_hub()
    else:
        trainer.save_model()


if __name__ == "__main__":
    main()