Files
palindrome-grpo/train_grpo.py

273 lines
8.7 KiB
Python
Raw Normal View History

"""
GRPO training script for palindrome generation.

Trains a small LLM (Qwen2.5-0.5B-Instruct) to generate palindromes on a
given theme using Group Relative Policy Optimization (GRPO).

Reward components:
  1. Palindrome accuracy: continuous partial credit, i.e. the proportion of
     mirrored character pairs that match.
  2. Length bonus: longer palindromes earn more reward (capped).
  3. Theme relevance: keyword overlap between the theme and the generated text.

  reward = α * palindrome_accuracy + β * length_bonus + γ * theme_relevance

Note: β above is the length-bonus weight of the reward, not GRPOConfig's
``beta`` (the KL-penalty coefficient).
"""
import re
import os
from datasets import load_dataset, Dataset
from trl import GRPOConfig, GRPOTrainer
from transformers import TrainerCallback
import trackio
# ── Reward function ─────────────────────────────────────────────────────
def _normalize(text: str) -> str:
    """Return *text* lowercased, keeping only ASCII letters and digits.

    Whitespace, punctuation and every non-ASCII character (accented letters
    included) are dropped. Filtering happens before lowercasing, so
    characters whose lowercase form is ASCII are still removed.
    """
    kept = "".join(ch for ch in text if ch.isascii() and ch.isalnum())
    return kept.lower()
def _palindrome_accuracy(text: str) -> float:
    """Score how close *text* is to a palindrome, in [0.0, 1.0].

    The text is normalized first (see ``_normalize``). Each character of the
    first half is compared with its mirror in the second half; the score is
    the fraction of those mirrored pairs that match. The middle character of
    an odd-length string is never compared.

    - perfect palindrome -> 1.0
    - "ractar" (close to "racecar") -> partial credit (2/3)
    - empty after normalization -> 0.0
    - single character -> 1.0
    """
    letters = _normalize(text)
    size = len(letters)
    if size < 2:
        return 1.0 if size == 1 else 0.0

    half = size // 2
    # zip stops after `half` pairs: letters[i] vs letters[size - 1 - i].
    hits = sum(front == back for front, back in zip(letters[:half], letters[::-1]))
    return hits / half
def _length_bonus(text: str, max_len: int = 80) -> float:
    """Reward longer outputs linearly, saturating at *max_len* characters.

    Length is measured on the normalized text (see ``_normalize``), so
    spaces and punctuation do not count. Anything at or beyond *max_len*
    characters scores 1.0.
    """
    effective = len(_normalize(text))
    if effective > max_len:
        effective = max_len
    return effective / max_len
def _theme_relevance(text: str, theme: str) -> float:
    """Estimate how strongly *text* references *theme*, in [0.0, 1.0].

    Both strings are normalized (see ``_normalize``) before comparison.

    - empty theme -> 0.0
    - normalized theme appears verbatim in the text -> 1.0
    - theme longer than 3 characters -> fraction of the theme's distinct
      character bigrams found anywhere in the text
    - otherwise (short theme, no exact hit) -> 0.0
    """
    haystack = _normalize(text)
    needle = _normalize(theme)
    if not needle:
        return 0.0
    if needle in haystack:
        return 1.0
    if len(needle) <= 3:
        return 0.0

    pairs = {first + second for first, second in zip(needle, needle[1:])}
    found = sum(pair in haystack for pair in pairs)
    return found / len(pairs)
def _completion_text(completion) -> str:
    """Extract the stripped generated text from a single completion.

    Conversational completions are a list of message dicts; the ``content``
    of the first message is used. An empty list or a ``None``/missing
    content yields ``""`` instead of raising. Any other value goes through
    ``str``.
    """
    if isinstance(completion, list):
        if not completion:
            return ""
        first = completion[0]
        content = first.get("content") if isinstance(first, dict) else first
        return "" if content is None else str(content).strip()
    return str(completion).strip()


def palindrome_reward(
    prompts,
    completions,
    completion_ids,
    theme=None,
    **kwargs,
) -> list[float]:
    """GRPO reward function for palindrome generation.

    reward = alpha * palindrome_accuracy + beta * length_bonus
             + gamma * theme_relevance

    Args:
        prompts: conversational prompts (unused; part of the reward signature).
        completions: one generated completion per sample, either a list of
            message dicts (conversational format) or plain text.
        completion_ids: completion token IDs (unused; part of the signature).
        theme: per-completion theme strings from the dataset's ``theme``
            column, expected to line up index-by-index with *completions*;
            ``None`` when the column is absent. Missing or empty entries
            score 0 theme relevance.
        **kwargs: any other forwarded columns/extras, ignored.

    Returns:
        One float reward per completion.
    """
    # Reward weights (tunable). `beta` here is unrelated to GRPOConfig.beta
    # (the KL-penalty coefficient).
    alpha = 0.6   # palindrome accuracy weight
    beta = 0.25   # length bonus weight
    gamma = 0.15  # theme relevance weight

    rewards = []
    for i, completion in enumerate(completions):
        text = _completion_text(completion)
        acc = _palindrome_accuracy(text)
        length = _length_bonus(text)

        # Fall back to "no theme" if the list is shorter than the completions.
        theme_text = theme[i] if theme is not None and i < len(theme) else ""
        theme_rel = _theme_relevance(text, theme_text) if theme_text else 0.0

        rewards.append(alpha * acc + beta * length + gamma * theme_rel)
    return rewards
# ── Trackio alert callback ──────────────────────────────────────────────
class PalindromeAlertCallback(TrainerCallback):
    """Send Trackio alerts on key training events.

    Alerts are sent at the start and end of training, and from ``on_log``
    when the logged ``reward``, ``kl`` or ``loss`` values cross fixed
    thresholds.

    - Only the world main process sends alerts, so multi-process runs do
      not emit duplicates.
    - The "High reward milestone" alert fires at most once per training run
      rather than on every log step.

    Args:
        model_name: model identifier shown in the "Training started" alert.
    """

    def __init__(self, model_name: str = "Qwen/Qwen2.5-0.5B-Instruct"):
        super().__init__()
        self.model_name = model_name
        # Becomes True once the high-reward milestone has been reported.
        self._high_reward_alerted = False

    def on_train_begin(self, args, state, control, **kwargs):
        """Announce the run and its key GRPO hyperparameters."""
        self._high_reward_alerted = False
        if not state.is_world_process_zero:
            return
        trackio.alert(
            "Training started",
            f"Model: {self.model_name} | "
            f"num_generations={args.num_generations} | "
            f"temperature={args.temperature} | "
            f"lr={args.learning_rate} | "
            f"batch_size={args.per_device_train_batch_size}×"
            f"{args.gradient_accumulation_steps}",
            level="INFO",
        )

    def on_train_end(self, args, state, control, **kwargs):
        """Report completion and where the final model is headed."""
        if not state.is_world_process_zero:
            return
        # In this script the final save/push runs after train() returns, so
        # do not claim the model has already been pushed.
        if args.push_to_hub:
            destination = f"Model will be pushed to {args.hub_model_id}"
        else:
            destination = f"Model output dir: {args.output_dir}"
        trackio.alert("Training complete", destination, level="INFO")

    def on_log(self, args, state, control, logs=None, **kwargs):
        """Inspect logged metrics and alert on reward, KL and loss thresholds."""
        if logs is None or not state.is_world_process_zero:
            return
        step = state.global_step

        # Reward monitoring: low/moderate alerts only after a 50-step warm-up.
        reward = logs.get("reward")
        if reward is not None:
            if reward < 0.1 and step > 50:
                trackio.alert(
                    "Low reward",
                    f"reward={reward:.4f} at step {step} — "
                    "model not producing palindromes yet. Consider increasing "
                    "temperature for more exploration.",
                    level="WARN",
                )
            elif 0.1 <= reward < 0.5 and step > 50:
                trackio.alert(
                    "Moderate reward",
                    f"reward={reward:.4f} at step {step} — "
                    "model starting to learn. Continue training.",
                    level="INFO",
                )
            elif reward >= 0.7 and not self._high_reward_alerted:
                # A milestone is reported once, not on every later log step.
                self._high_reward_alerted = True
                trackio.alert(
                    "High reward milestone",
                    f"reward={reward:.4f} at step {step} — "
                    "model producing good palindromes!",
                    level="INFO",
                )

        # KL divergence monitoring (the key may be absent from logs).
        kl = logs.get("kl")
        if kl is not None and kl > 0.5:
            trackio.alert(
                "High KL divergence",
                f"kl={kl:.4f} at step {step} — "
                "policy diverging too far. Consider lowering learning rate.",
                level="WARN",
            )

        # Loss monitoring.
        loss = logs.get("loss")
        if loss is not None and loss > 10:
            trackio.alert(
                "High loss",
                f"loss={loss:.4f} at step {step} — "
                "loss spike detected. Check for instability, "
                "consider reducing learning rate by 0.5×.",
                level="WARN",
            )
# ── Main ────────────────────────────────────────────────────────────────
def main():
    """Load the prompt dataset, train with GRPO, then save/push the final model.

    Environment overrides: DATASET_PATH, HUB_MODEL_ID, RUN_NAME,
    TRACKIO_SPACE_ID.

    Raises:
        ValueError: if the training split is empty.
    """
    # Load dataset
    dataset_path = os.getenv("DATASET_PATH", "SantiagoC/palindrome-prompts")
    dataset = load_dataset(dataset_path, split="train")
    print(f"Dataset size: {len(dataset)} samples")
    print(f"Columns: {dataset.column_names}")
    if len(dataset) == 0:
        raise ValueError(f"Dataset {dataset_path!r} has no training samples")
    print(f"First prompt: {dataset[0]['prompt']}")
    if "theme" not in dataset.column_names:
        # palindrome_reward only receives `theme` when the column exists.
        print("Warning: no 'theme' column; theme relevance reward will be 0.")

    # Training config
    training_args = GRPOConfig(
        output_dir="./palindrome-grpo-output",
        # ── GRPO specific ──
        num_generations=4,
        max_completion_length=128,
        temperature=0.9,
        beta=0.0,  # No reference model KL penalty
        scale_rewards=False,  # Dr.GRPO recommendation
        loss_type="dapo",
        mask_truncated_completions=True,
        # ── Standard training ──
        learning_rate=1e-6,
        num_train_epochs=6,
        per_device_train_batch_size=2,
        gradient_accumulation_steps=4,
        bf16=True,
        gradient_checkpointing=True,
        # ── Logging ──
        logging_steps=10,
        logging_first_step=True,
        save_steps=100,
        save_total_limit=3,
        disable_tqdm=True,
        report_to="trackio",
        # ── Hub push ──
        push_to_hub=True,
        hub_model_id=os.getenv("HUB_MODEL_ID", "SantiagoC/palindrome-grpo"),
        run_name=os.getenv("RUN_NAME", "palindrome-grpo-v1"),
        project="palindrome-grpo",
        trackio_space_id=os.getenv("TRACKIO_SPACE_ID", "SantiagoC/mlintern-palindrm"),
    )

    trainer = GRPOTrainer(
        model="Qwen/Qwen2.5-0.5B-Instruct",
        reward_funcs=palindrome_reward,
        args=training_args,
        train_dataset=dataset,
        callbacks=[PalindromeAlertCallback()],
    )
    trainer.train()

    # Final save/push. Trainer.push_to_hub() saves the model to output_dir
    # before uploading, and Trainer.save_model() itself pushes when
    # args.push_to_hub is set, so calling both would upload the weights twice.
    if training_args.push_to_hub:
        trainer.push_to_hub()
    else:
        trainer.save_model()
# Script entry point: start training only when this file is executed
# directly, not when it is imported (e.g. to reuse palindrome_reward).
if __name__ == "__main__":
    main()