""" GRPO training script for palindrome generation. Trains a small LLM (Qwen2.5-0.5B-Instruct) to generate palindromes given a theme using Group Relative Policy Optimization. Reward function: 1. Palindrome accuracy (continuous partial-credit: proportion of matching mirrored character pairs) 2. Length bonus (longer palindromes = more reward) 3. Theme relevance (keyword overlap between theme and generated text) reward = α * palindrome_accuracy + β * length_bonus + γ * theme_relevance """ import re import os from datasets import load_dataset, Dataset from trl import GRPOConfig, GRPOTrainer from transformers import TrainerCallback import trackio # ── Reward function ───────────────────────────────────────────────────── def _normalize(text: str) -> str: """Strip non-alphanumeric and lowercase.""" return re.sub(r'[^a-zA-Z0-9]', '', text).lower() def _palindrome_accuracy(text: str) -> float: """ Continuous palindrome score: proportion of matching mirrored pairs. Perfect palindrome → 1.0 "ractar" (close to racecar) → partial credit Empty / single char → 0.0 or 1.0 respectively """ cleaned = _normalize(text) n = len(cleaned) if n == 0: return 0.0 if n == 1: return 1.0 matches = 0 pairs = n // 2 for i in range(pairs): if cleaned[i] == cleaned[n - 1 - i]: matches += 1 return matches / pairs def _length_bonus(text: str, max_len: int = 80) -> float: """ Normalized length reward: longer palindromes score higher. Capped at max_len characters. """ cleaned = _normalize(text) return min(len(cleaned), max_len) / max_len def _theme_relevance(text: str, theme: str) -> float: """ Simple keyword overlap: does the palindrome contain the theme word or substrings of it? Returns 0.0 - 1.0. """ text_lower = _normalize(text) theme_clean = _normalize(theme) if len(theme_clean) == 0: return 0.0 # Exact match of the full theme word if theme_clean in text_lower: return 1.0 # Partial match: theme word split into n-grams if len(theme_clean) > 3: bigrams = set(theme_clean[i:i + 2] for i in range(len(theme_clean) - 1)) match_count = sum(1 for bg in bigrams if bg in text_lower) return match_count / len(bigrams) return 0.0 def palindrome_reward( prompts, completions, completion_ids, theme=None, **kwargs, ) -> list[float]: """ GRPO reward function for palindrome generation. Args: prompts: list[list[dict]] — conversational prompts completions: list[list[dict]] — generated completions completion_ids: list[list[int]] — token IDs theme: list[str] — theme column from dataset (forwarded via **kwargs) Returns: list[float] — one reward per sample """ # Reward weights (tunable) alpha = 0.6 # palindrome accuracy weight beta = 0.25 # length bonus weight gamma = 0.15 # theme relevance weight rewards = [] for i, completion in enumerate(completions): # Extract text from conversational format if isinstance(completion, list): text = completion[0]["content"].strip() else: text = str(completion).strip() acc = _palindrome_accuracy(text) length = _length_bonus(text) # Theme from dataset column theme_text = "" if theme is not None: # theme list is flattened: [t*n_generations for t in batch_themes] theme_text = theme[i] if i < len(theme) else "" theme_rel = _theme_relevance(text, theme_text) if theme_text else 0.0 reward = alpha * acc + beta * length + gamma * theme_rel rewards.append(reward) return rewards # ── Trackio alert callback ────────────────────────────────────────────── class PalindromeAlertCallback(TrainerCallback): """Log Trackio alerts on key training events.""" def on_train_begin(self, args, state, control, **kwargs): trackio.alert( "Training started", f"Model: Qwen/Qwen2.5-0.5B-Instruct | " f"num_generations={args.num_generations} | " f"temperature={args.temperature} | " f"lr={args.learning_rate} | " f"batch_size={args.per_device_train_batch_size}×" f"{args.gradient_accumulation_steps}", level="INFO", ) def on_train_end(self, args, state, control, **kwargs): trackio.alert( "Training complete", f"Model pushed to {args.hub_model_id}", level="INFO", ) def on_log(self, args, state, control, logs=None, **kwargs): if logs is None: return # Reward monitoring reward = logs.get("reward") if reward is not None: if reward < 0.1 and state.global_step > 50: trackio.alert( "Low reward", f"reward={reward:.4f} at step {state.global_step} — " "model not producing palindromes yet. Consider increasing " "temperature for more exploration.", level="WARN", ) elif 0.1 <= reward < 0.5 and state.global_step > 50: trackio.alert( "Moderate reward", f"reward={reward:.4f} at step {state.global_step} — " "model starting to learn. Continue training.", level="INFO", ) elif reward >= 0.7: trackio.alert( "High reward milestone", f"reward={reward:.4f} at step {state.global_step} — " "model producing good palindromes!", level="INFO", ) # KL divergence monitoring kl = logs.get("kl") if kl is not None and kl > 0.5: trackio.alert( "High KL divergence", f"kl={kl:.4f} at step {state.global_step} — " "policy diverging too far. Consider lowering learning rate.", level="WARN", ) # Loss monitoring loss = logs.get("loss") if loss is not None: if loss > 10: trackio.alert( "High loss", f"loss={loss:.4f} at step {state.global_step} — " "loss spike detected. Check for instability, " "consider reducing learning rate by 0.5×.", level="WARN", ) # ── Main ──────────────────────────────────────────────────────────────── def main(): # Load dataset dataset_path = os.getenv("DATASET_PATH", "SantiagoC/palindrome-prompts") dataset = load_dataset(dataset_path, split="train") print(f"Dataset size: {len(dataset)} samples") print(f"Columns: {dataset.column_names}") print(f"First prompt: {dataset[0]['prompt']}") # Training config training_args = GRPOConfig( output_dir="./palindrome-grpo-output", # ── GRPO specific ── num_generations=4, max_completion_length=128, temperature=0.9, beta=0.0, # No reference model KL penalty scale_rewards=False, # Dr.GRPO recommendation loss_type="dapo", mask_truncated_completions=True, # ── Standard training ── learning_rate=1e-6, num_train_epochs=6, per_device_train_batch_size=2, gradient_accumulation_steps=4, bf16=True, gradient_checkpointing=True, # ── Logging ── logging_steps=10, logging_first_step=True, save_steps=100, save_total_limit=3, disable_tqdm=True, report_to="trackio", # ── Hub push ── push_to_hub=True, hub_model_id=os.getenv("HUB_MODEL_ID", "SantiagoC/palindrome-grpo"), run_name=os.getenv("RUN_NAME", "palindrome-grpo-v1"), project="palindrome-grpo", trackio_space_id=os.getenv("TRACKIO_SPACE_ID", "SantiagoC/mlintern-palindrm"), ) trainer = GRPOTrainer( model="Qwen/Qwen2.5-0.5B-Instruct", reward_funcs=palindrome_reward, args=training_args, train_dataset=dataset, callbacks=[PalindromeAlertCallback()], ) trainer.train() # Final push trainer.save_model() trainer.push_to_hub() if __name__ == "__main__": main()