193 lines
5.1 KiB
Markdown
193 lines
5.1 KiB
Markdown
---
|
|
language:
|
|
- it
|
|
- en
|
|
license: gemma
|
|
tags:
|
|
- ghigliottina
|
|
- italian
|
|
- grpo
|
|
- lora-merged
|
|
- transformers
|
|
library_name: transformers
|
|
pipeline_tag: text-generation
|
|
base_model: google/gemma-3-1b-it
|
|
---
|
|
|
|
# gemma-3-1b-it-ghigliottina-grpo-merged-ckpt564
|
|
|
|
Modello **merged** (base + LoRA) per task stile *Ghigliottina* in italiano.
|
|
|
|
## Cosa contiene questo repo
|
|
|
|
Questo repository HF contiene il risultato del merge di:
|
|
|
|
- **Base model:** `google/gemma-3-1b-it`
|
|
- **Adapter:** `outputs/gemma-3-1b-grpo-train-v2-3ep/checkpoint-564`
|
|
- **Metodo merge:** `peft.PeftModel.merge_and_unload()`
|
|
|
|
Il risultato è un modello standalone (`model.safetensors`) caricabile direttamente con `transformers` senza dover applicare un adapter separato.
|
|
|
|
---
|
|
|
|
## Cosa è stato fatto (pipeline di training)
|
|
|
|
Nel progetto locale `python_project/grpo-training` è stata eseguita questa pipeline GRPO:
|
|
|
|
1. **Dataset preparation**
|
|
- parsing XML e conversione in JSONL (`data_prep.py`)
|
|
- conversione dei prompt in formato **bullet-only** (indizi a elenco puntato)
|
|
- split train/holdout dedicati
|
|
|
|
2. **Prompting setup**
|
|
- system prompt orientato alla Ghigliottina con formato target:
|
|
- `<think>...</think>`
|
|
- `soluzione: <parola>.`
|
|
- vincoli su output finale (no testo extra, soluzione diversa dagli indizi)
|
|
|
|
3. **Training GRPO su Gemma 3 1B IT**
|
|
- base model: `google/gemma-3-1b-it`
|
|
- LoRA training con Unsloth + TRL/GRPO
|
|
- reward shaping multi-componente:
|
|
- format rewards (strict/soft + xmlcount)
|
|
- exact match
|
|
- embedding similarity
|
|
- reasoning rewards (steps + coverage)
|
|
- penalità (missing final answer, extra text after solution, solution-in-clues)
|
|
|
|
4. **Monitoring & eval**
|
|
- tracking TensorBoard reward-by-reward
|
|
- eval holdout a temperature 1.0 e 0.0 (con parser allineato al formato)
|
|
|
|
5. **Publishing (questo repo)**
|
|
- merge base + adapter (`checkpoint-564`) con `merge_and_unload`
|
|
- salvataggio modello merged + tokenizer
|
|
- smoke test di caricamento
|
|
- pubblicazione su Hugging Face
|
|
|
|
Script principali usati:
|
|
|
|
- `train_grpo.py`
|
|
- `reward_functions.py`
|
|
- `eval_accuracy.py`
|
|
- `plot_rewards.py`
|
|
- `merge_and_push_hf.py`
|
|
|
|
Comando usato per il merge + publish:
|
|
|
|
```bash
|
|
python merge_and_push_hf.py \
|
|
--adapter-path ./outputs/gemma-3-1b-grpo-train-v2-3ep/checkpoint-564 \
|
|
--out-dir ./outputs/merged_test_checkpoint564 \
|
|
--hub-repo descansodj/gemma-3-1b-it-ghigliottina-grpo-merged-ckpt564 \
|
|
--push \
|
|
--device cpu
|
|
```
|
|
|
|
---
|
|
|
|
## Prompting atteso (training format)
|
|
|
|
Il training era orientato a questo formato:
|
|
|
|
```text
|
|
<think>
|
|
...
|
|
</think>
|
|
soluzione: <parola>.
|
|
```
|
|
|
|
Input user tipico (bullet-only):
|
|
|
|
```text
|
|
- gatto
|
|
- amico
|
|
- compagnia
|
|
- pelo
|
|
- caccia
|
|
```
|
|
|
|
> Nota: essendo un checkpoint intermedio mergiato, il modello può ancora produrre output non perfettamente strict.
|
|
|
|
---
|
|
|
|
## Come usarlo (Transformers)
|
|
|
|
```python
|
|
from transformers import AutoTokenizer, AutoModelForCausalLM
|
|
import torch
|
|
|
|
repo_id = "descansodj/gemma-3-1b-it-ghigliottina-grpo-merged-ckpt564"
|
|
|
|
tokenizer = AutoTokenizer.from_pretrained(repo_id)
|
|
model = AutoModelForCausalLM.from_pretrained(
|
|
repo_id,
|
|
torch_dtype=torch.float16,
|
|
device_map="auto",
|
|
)
|
|
|
|
system_prompt = '''Sei un assistente che risolve il gioco della Ghigliottina, in cui bisogna trovare la “parola comune”.
|
|
Ti verranno fornite 5 parole-indizio: devi trovare UNA sola parola che è collegata a tutte le 5 parole-indizio.
|
|
|
|
Regole di output (obbligatorie):
|
|
1) Prima ragiona a lungo nei tag:
|
|
<think>
|
|
...
|
|
</think>
|
|
2) Nel ragionamento valuta più ipotesi di soluzione e scarta quelle deboli.
|
|
3) Dopo </think> scrivi esattamente: "soluzione: <parola>.\n" (in minuscolo, senza punteggiatura nella parola).
|
|
4) La parola soluzione deve essere diversa da tutti gli indizi forniti.
|
|
5) Non aggiungere testo dopo la parola finale.
|
|
|
|
Esempio SOLO di formato (non del contenuto):
|
|
"""
|
|
<think>
|
|
valuto più ipotesi di soluzione e scarto quelle deboli...
|
|
</think>
|
|
soluzione: <parola>.
|
|
"""
|
|
'''
|
|
|
|
messages = [
|
|
{"role": "system", "content": system_prompt},
|
|
{"role": "user", "content": "\n- gatto\n- amico\n- compagnia\n- pelo\n- caccia"},
|
|
]
|
|
|
|
text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
|
|
inputs = tokenizer(text, return_tensors="pt").to(model.device)
|
|
|
|
with torch.no_grad():
|
|
out = model.generate(
|
|
**inputs,
|
|
do_sample=True,
|
|
temperature=1.0,
|
|
top_p=0.95,
|
|
top_k=64,
|
|
num_beams=1,
|
|
max_new_tokens=256,
|
|
pad_token_id=tokenizer.eos_token_id,
|
|
eos_token_id=tokenizer.eos_token_id,
|
|
)
|
|
|
|
completion = tokenizer.decode(out[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True)
|
|
print(completion)
|
|
```
|
|
|
|
---
|
|
|
|
## Limitazioni note
|
|
|
|
- `exact_match` non ancora alto al checkpoint usato.
|
|
- `strict_format` non sempre rispettato.
|
|
- Modello utile come baseline merged; non è la versione finale “best”.
|
|
|
|
---
|
|
|
|
## Provenienza e riproducibilità
|
|
|
|
- Progetto: `python_project/grpo-training`
|
|
- Training GRPO: Gemma 3 1B IT + reward shaping custom (format/reasoning/embedding)
|
|
- Checkpoint mergeato: `checkpoint-564`
|
|
|
|
Se vuoi una release più stabile, conviene pubblicare anche il merge da adapter finale e confrontare i due repo HF.
|