---
license: other
language:
- en
library_name: transformers
tags:
- llama
- llama-3
- causal-lm
- clinical
- grpo
- lora
- merged-adapter
- transformers
base_model: meta-llama/Llama-3.2-3B-Instruct
pipeline_tag: text-generation
---

# LLMasRNN GRPO Policy Epoch 001 Merged

This repository contains a merged causal language model produced by merging a LoRA adapter trained in the `LLMasRNN` project into the base model `meta-llama/Llama-3.2-3B-Instruct`.

This is the merged checkpoint for:

- Project: `LLMasRNN`
- Training stage: `GRPO Phase 1`
- Epoch artifact: `training/artifacts/epoch_001`
- Base model: `meta-llama/Llama-3.2-3B-Instruct`
- Output repo type: merged full weights

## What This Model Is

The policy model is intended for longitudinal clinical prediction workflows in the `LLMasRNN` project. In this training phase, the model is optimized as a memory-update / policy head using GRPO-style reinforcement learning over trajectory rollouts and rubric-based rewards.

The weights uploaded in this repository are not the adapter alone; they are the full merged model weights created from:

- Base model weights from `meta-llama/Llama-3.2-3B-Instruct`
- A LoRA adapter saved at `training/artifacts/epoch_001/policy_lora`

## Training Summary

The training run used a custom GRPO training loop implemented in `training/train/policy_trainer.py` and configured by `training/configs/grpo_phase1.yaml`.

High-level pipeline for one epoch (see the sketch below):

1. Collect RLN trajectories over the training split.
2. Sample rollout candidates from the policy.
3. Score candidates with a fixed rubric judge.
4. Train the LoRA policy with a custom GRPO loss.

This merged checkpoint corresponds to the first completed epoch.
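
For orientation, the sketch below shows the shape of that per-epoch loop. All function and method names here are hypothetical stand-ins; the real implementation is the custom trainer in `training/train/policy_trainer.py`.

```python
# Hypothetical sketch of one GRPO epoch; names do not match the project code.
def run_grpo_epoch(policy, rubric_judge, train_states, group_size=8, inner_steps=4):
    # Steps 1-3: walk the trajectory states, sample a group of rollout
    # candidates per state, and score every candidate with the fixed judge.
    groups = []
    for state in train_states:
        rollouts = [policy.sample(state) for _ in range(group_size)]
        rewards = [rubric_judge.score(state, rollout) for rollout in rollouts]
        groups.append((state, rollouts, rewards))

    # Step 4: run the inner GRPO updates on the LoRA policy.
    for _ in range(inner_steps):
        for state, rollouts, rewards in groups:
            policy.grpo_update(state, rollouts, rewards)
```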

## Base Model

- Base model: `meta-llama/Llama-3.2-3B-Instruct`
- Architecture: causal LM
- Total parameters after merge: `3,221,924,864`
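
If you want to verify the parameter count locally, a quick check after loading the merged weights (repo id as in the Loading section below):

```python
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("jinrui123/llamasrnn-grpo-epoch001-merged")
total = sum(p.numel() for p in model.parameters())
print(f"total parameters: {total:,}")  # expected: 3,221,924,864
```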

## LoRA Configuration

The policy was trained as a LoRA adapter before the merge.

- PEFT type: `LoRA`
- Task type: `CAUSAL_LM`
- Rank `r`: `16`
- LoRA alpha: `32`
- LoRA dropout: `0.05`
- Target modules:
  - `q_proj`
  - `k_proj`
  - `v_proj`
  - `o_proj`
- Trainable parameters before merge: `9,175,040`
- Trainable fraction before merge: `0.2848%`
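
For reference, a PEFT `LoraConfig` matching these hyperparameters would look roughly like the sketch below; it is reconstructed from the values above, not copied from the project code.

```python
from peft import LoraConfig

lora_config = LoraConfig(
    task_type="CAUSAL_LM",
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
)
```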

## GRPO / RL Training Configuration

From `training/configs/grpo_phase1.yaml`:

- Framework: custom GRPO implementation
- KL coefficient: `0.05`
- Clip range: `0.2`
- Importance clip: `5.0`
- Inner steps per epoch: `4`
- Batch size: `4` trajectory groups per update step
- Rollouts per state / group size `G`: `8`
- Learning rate: `5e-6`
- Gradient clipping: `1.0`
- Warmup steps: `50`
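
The loss itself lives in the custom trainer; purely as an illustration of the general technique, a GRPO-style objective using the hyperparameters above (group-normalized advantages, ratio clipping at `0.2`, importance clipping at `5.0`, KL coefficient `0.05`) can be sketched as follows. This is not the project's implementation.

```python
import torch

def grpo_loss(logp_new, logp_old, logp_ref, rewards,
              kl_coef=0.05, clip_range=0.2, importance_clip=5.0):
    """Sketch of a GRPO-style loss for one rollout group.

    logp_new / logp_old / logp_ref: per-rollout sequence log-probs under the
    current policy, the behavior policy, and the frozen reference model.
    rewards: per-rollout scalar rewards from the judges.
    """
    # Group-relative advantages: normalize rewards within the rollout group.
    advantages = (rewards - rewards.mean()) / (rewards.std() + 1e-6)

    # Importance ratios between current and behavior policy, clipped for stability.
    ratio = torch.exp(logp_new - logp_old).clamp(max=importance_clip)

    # PPO-style clipped surrogate objective.
    unclipped = ratio * advantages
    clipped = torch.clamp(ratio, 1 - clip_range, 1 + clip_range) * advantages
    policy_loss = -torch.min(unclipped, clipped).mean()

    # Approximate KL penalty against the reference model.
    kl = (logp_new - logp_ref).mean()
    return policy_loss + kl_coef * kl
```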

## Sampling and Reward Setup

### Policy rollout sampling

- Temperature: `0.8`
- Top-p: `0.95`
- Max new tokens: `512`
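
The rollout backend for the policy is not specified in this card; as an illustration with the Transformers API, these parameters correspond to a `generate` call like the following (the prompt is made up):

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

repo_id = "jinrui123/llamasrnn-grpo-epoch001-merged"
tokenizer = AutoTokenizer.from_pretrained(repo_id)
model = AutoModelForCausalLM.from_pretrained(repo_id)

inputs = tokenizer("Summarize the patient's latest visit.", return_tensors="pt")
outputs = model.generate(
    **inputs,
    do_sample=True,      # stochastic rollout sampling
    temperature=0.8,     # policy rollout temperature
    top_p=0.95,          # nucleus sampling cutoff
    max_new_tokens=512,  # rollout length cap
)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```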

### Predictor model

- Model: `meta-llama/Llama-3.2-3B-Instruct`
- Backend: `vLLM`
- Max model length: `8192`

### RLN judge

- Model: `jinrui123/sft_llama3.2_3b_merged`
- Backend: `vLLM`
- Max model length: `8192`
- Temperature: `0.3`
- Max tokens: `1536`

### Rubric judge

- Model path during training: `/data/jf44684/TrainingDataParepation/models/RubricARM-8B-Judge`
- Backend: `vLLM`
- Temperature: `0.3`
- Max tokens: `1536`
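
All three auxiliary models are served with vLLM. As an illustration (prompt construction and exact engine options are project-specific and not shown in this card), the RLN judge could be instantiated like this:

```python
from vllm import LLM, SamplingParams

# RLN judge served with vLLM, per the settings above.
judge = LLM(model="jinrui123/sft_llama3.2_3b_merged", max_model_len=8192)
judge_params = SamplingParams(temperature=0.3, max_tokens=1536)

outputs = judge.generate(["<judge prompt goes here>"], judge_params)
print(outputs[0].outputs[0].text)
```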

### Reward composition

- Downstream accuracy bonus enabled: `true`
- Epsilon accuracy bonus: `0.2`
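
The card does not spell out how these two terms are combined; purely as an illustration, one plausible reading is a rubric score plus an epsilon-sized bonus when the downstream prediction is correct:

```python
def combined_reward(rubric_score: float, downstream_correct: bool,
                    accuracy_bonus: float = 0.2) -> float:
    # Hypothetical composition: rubric reward plus an epsilon-sized bonus when
    # the downstream accuracy check passes. The real shaping is defined by the
    # project's reward code, not by this sketch.
    return rubric_score + (accuracy_bonus if downstream_correct else 0.0)
```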

## Data / Run Size

From the saved epoch metadata:

- Training split: `data/splits/cleaned_df_train_100.json`
- Validation split: `data/splits/cleaned_df_val_100.json`
- Number of collected trajectory steps: `390`
- Number of scored rollout groups: `390`
- Debug mode: `false`

## Epoch 001 Metrics

Saved in `training/artifacts/epoch_001/meta.json`:

- Loss: `0.436767578125`
- Policy loss: `0.435302734375`
- KL: `0.0322265625`
- Mean absolute advantage: `0.82421875`
- Mean ratio: `0.8505859375`
- Mean reward: `0.7509765625`

These are training-time epoch averages over the 4 inner GRPO update steps for epoch 1.

## Merge Details

This repository was created by:

1. Loading the base model `meta-llama/Llama-3.2-3B-Instruct`
2. Loading the saved LoRA adapter from `training/artifacts/epoch_001/policy_lora`
3. Calling `merge_and_unload()` with PEFT
4. Saving the resulting merged full model weights

The original adapter checkpoint remains the more storage-efficient representation for continued adapter-based training.
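
A minimal sketch of that merge with Transformers and PEFT (adapter path as in the artifacts above; dtype, device placement, and output paths are illustrative):

```python
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel

base_id = "meta-llama/Llama-3.2-3B-Instruct"
adapter_path = "training/artifacts/epoch_001/policy_lora"

base = AutoModelForCausalLM.from_pretrained(base_id)
model = PeftModel.from_pretrained(base, adapter_path)

# Fold the LoRA deltas into the base weights and drop the adapter wrappers.
merged = model.merge_and_unload()

merged.save_pretrained("llamasrnn-grpo-epoch001-merged")
AutoTokenizer.from_pretrained(base_id).save_pretrained("llamasrnn-grpo-epoch001-merged")
```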

## Intended Use

This checkpoint is intended for research and experimentation within the `LLMasRNN` project setting:

- longitudinal clinical prediction
- diagnosis prediction conditioned on evolving patient summaries
- RL / GRPO policy experiments
- ablation or evaluation of merged policy checkpoints

It is not validated for clinical deployment, medical decision support in production, or unsupervised real-world medical use.

## Limitations

- This is an epoch-1 checkpoint, not a converged final model.
- The training objective is project-specific and depends on custom reward shaping and judge models.
- Training and evaluation were performed on internal project data and configuration rather than a standardized public benchmark release.
- Merged weights inherit the usage constraints and limitations of the base Llama model.

## Loading

Example with Transformers:

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

repo_id = "jinrui123/llamasrnn-grpo-epoch001-merged"

tokenizer = AutoTokenizer.from_pretrained(repo_id)
model = AutoModelForCausalLM.from_pretrained(repo_id)
```
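
Since the base model is an Instruct checkpoint, the usual Llama 3.2 chat template applies. Continuing from the snippet above, a short generation example (the prompt is illustrative):

```python
messages = [
    {"role": "user", "content": "List three differential diagnoses for acute chest pain."}
]
inputs = tokenizer.apply_chat_template(messages, add_generation_prompt=True, return_tensors="pt")
outputs = model.generate(inputs, max_new_tokens=256, do_sample=True, temperature=0.8, top_p=0.95)
# Decode only the newly generated tokens, skipping the prompt.
print(tokenizer.decode(outputs[0][inputs.shape[-1]:], skip_special_tokens=True))
```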

## Notes on Licensing

This repository contains merged weights derived from `meta-llama/Llama-3.2-3B-Instruct`. Use and redistribution must comply with the license and access terms of the original base model.