---
language:
- en
license: llama3.2
base_model: meta-llama/Llama-3.2-1B
tags:
- llama-3.2
- fine-tuned
- sft
- dpo
- content-safety
- aegis
- trl
- peft
- lora
- rlhf
library_name: transformers
pipeline_tag: text-generation
datasets:
- nvidia/Aegis-AI-Content-Safety-Dataset-2.0
widget:
- text: "What is artificial intelligence?"
  example_title: "AI Question"
- text: "How can I learn programming?"
  example_title: "Learning Question"
- text: "Explain quantum computing in simple terms."
  example_title: "Complex Topic"
---

# Llama-3.2-1B-Aegis-SFT-DPO

<div align="center">
<strong>Fine-tuned Llama 3.2 1B for Content-Safe Instruction Following</strong>
</div>

<br>

This model is a fine-tuned version of [meta-llama/Llama-3.2-1B](https://huggingface.co/meta-llama/Llama-3.2-1B) trained with a **two-stage approach**:

1. **Supervised Fine-Tuning (SFT)**: teaches the model to follow instructions
2. **Direct Preference Optimization (DPO)**: aligns the model with human preferences for safety

## 🎯 Model Description

- **Base Model**: meta-llama/Llama-3.2-1B
- **Fine-tuning Method**: SFT followed by DPO (RLHF-style preference alignment without a separate reward model)
- **Dataset**: [nvidia/Aegis-AI-Content-Safety-Dataset-2.0](https://huggingface.co/datasets/nvidia/Aegis-AI-Content-Safety-Dataset-2.0)
- **Training Samples**: 500
- **Focus**: Content safety and responsible AI responses
- **Adaptation Method**: Parameter-efficient fine-tuning (LoRA)
- **Model Size**: ~1B parameters
- **Quantization**: 4-bit during training; released as full-precision weights with the LoRA adapters merged

## 🚀 Quick Start

```python
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch

# Load model and tokenizer
model_name = "ahczhg/Llama-3.2-1B-Aegis-SFT-DPO"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.bfloat16,
    device_map="auto"
)

# Prepare messages
messages = [
    {"role": "user", "content": "What is artificial intelligence?"}
]

# Apply chat template and generate
inputs = tokenizer.apply_chat_template(
    messages,
    add_generation_prompt=True,
    return_tensors="pt"
).to(model.device)

outputs = model.generate(
    inputs,
    max_new_tokens=256,
    temperature=0.7,
    top_p=0.9,
    do_sample=True,
    pad_token_id=tokenizer.eos_token_id
)

# Decode only the newly generated tokens (skip the echoed prompt)
response = tokenizer.decode(outputs[0][inputs.shape[-1]:], skip_special_tokens=True)
print(response)
```

## 📊 Training Details

### Dataset Information

- **Source**: NVIDIA Aegis AI Content Safety Dataset 2.0
- **Total Samples Used**: 500
- **SFT Split**: 400 samples (80%)
- **DPO Split**: 100 samples (20%)
- **Data Filtering**: Removed redacted prompts and invalid entries
- **Format**: Conversational pairs with safety labels

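The filtering and 80/20 split described above can be sketched in plain Python. This is a hypothetical illustration, not the author's preprocessing code: the `prompt`/`response` field names, the literal `"REDACTED"` marker for redacted prompts, treating a missing prompt or response as an "invalid entry", and the seeded shuffle are all assumptions.

```python
import random

def filter_and_split(records, n_total=500, sft_fraction=0.8, seed=42):
    """Drop redacted/invalid rows, keep n_total, and split them into SFT and DPO subsets.

    Assumes each record is a dict with "prompt" and "response" keys and that
    redacted prompts are stored as the literal string "REDACTED" (hypothetical).
    """
    clean = [
        r for r in records
        if r.get("prompt") and r["prompt"].strip() != "REDACTED"
        and r.get("response")  # a missing or empty response counts as invalid
    ]
    rng = random.Random(seed)  # seeded so the split is reproducible
    rng.shuffle(clean)
    clean = clean[:n_total]
    n_sft = int(len(clean) * sft_fraction)
    return clean[:n_sft], clean[n_sft:]

# Toy records standing in for dataset rows (not real Aegis data)
toy = [{"prompt": f"question {i}", "response": f"answer {i}"} for i in range(600)]
toy += [{"prompt": "REDACTED", "response": "x"}, {"prompt": "q", "response": None}]

sft_rows, dpo_rows = filter_and_split(toy)
print(len(sft_rows), len(dpo_rows))  # prints: 400 100
```

Shuffling before slicing keeps the SFT and DPO subsets disjoint while avoiding any ordering bias in the source file.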
### Training Methodology

This model follows a two-stage approach similar in spirit to RLHF (Reinforcement Learning from Human Feedback), inspired by [AMD's Instella-3B-Instruct](https://huggingface.co/amd/Instella-3B-Instruct):

#### Stage 1: Supervised Fine-Tuning (SFT)

Teaches the model to follow the instruction format and generate appropriate responses.

**Hyperparameters**:
```yaml
Epochs: 2
Batch Size: 1
Gradient Accumulation: 8
Effective Batch Size: 8
Learning Rate: 1e-5
Optimizer: AdamW
LR Scheduler: Cosine
Warmup Steps: 100
Weight Decay: 0.1
Max Gradient Norm: 1.0
Precision: BF16
Gradient Checkpointing: True
```

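As a sanity check on the SFT hyperparameters, the number of optimizer steps follows from the sample count and effective batch size. The sketch below assumes all 400 SFT samples are used each epoch and that the scheduler has the linear-warmup-then-cosine shape of Hugging Face's `get_cosine_schedule_with_warmup`. Under those assumptions the run has about 100 optimizer steps, so 100 warmup steps would span the whole run and the cosine decay would never take effect. The 10-step warmup is a hypothetical comparison, not the author's setting.

```python
import math

def cosine_with_warmup(step, warmup_steps, total_steps):
    """LR multiplier with linear warmup then cosine decay to zero, following the
    formula used by transformers' get_cosine_schedule_with_warmup (num_cycles=0.5)."""
    if step < warmup_steps:
        return step / max(1, warmup_steps)
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    return max(0.0, 0.5 * (1.0 + math.cos(math.pi * progress)))

# SFT settings from the table above: 400 samples, batch size 1 x gradient accumulation 8, 2 epochs
samples, effective_batch, epochs = 400, 1 * 8, 2
total_steps = math.ceil(samples / effective_batch) * epochs
print(total_steps)  # 100

peak_lr = 1e-5
lrs_100 = [peak_lr * cosine_with_warmup(s, 100, total_steps) for s in range(total_steps)]
lrs_10 = [peak_lr * cosine_with_warmup(s, 10, total_steps) for s in range(total_steps)]
print(f"{max(lrs_100):.2e}")  # 9.90e-06: warmup still in progress at the last step
print(f"{max(lrs_10):.2e}")   # 1.00e-05: a 10-step warmup reaches the peak early
```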
#### Stage 2: Direct Preference Optimization (DPO)

Optimizes the model to prefer safe, helpful responses over problematic ones using preference learning.

**Hyperparameters**:
```yaml
Epochs: 1
Batch Size: 1
Gradient Accumulation: 8
Effective Batch Size: 8
Learning Rate: 5e-7
Beta (DPO): 0.1
Max Prompt Length: 512
Max Sequence Length: 1024
Optimizer: AdamW
LR Scheduler: Cosine
Warmup Ratio: 10%
Precision: BF16
```

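For reference, the per-pair DPO objective compares how far the policy has moved the log-probability of the chosen (safe) response away from the frozen reference model, versus the rejected one, scaled by β (0.1 above). The sketch below computes the standard sigmoid form of the loss (TRL's default `loss_type`) from summed sequence log-probabilities. The log-probability values are made-up inputs, and the sketch does not reproduce TRL's batching, masking, or alternative loss variants.

```python
import math

def dpo_loss(policy_chosen, policy_rejected, ref_chosen, ref_rejected, beta=0.1):
    """Sigmoid DPO loss for one preference pair.

    Each argument is the summed log-probability of a full response given the prompt,
    under either the policy being trained or the frozen reference (SFT) model.
    """
    chosen_logratio = policy_chosen - ref_chosen        # how much the policy upweights the chosen answer
    rejected_logratio = policy_rejected - ref_rejected  # ... and the rejected one
    margin = beta * (chosen_logratio - rejected_logratio)
    # -log(sigmoid(margin)) == softplus(-margin), computed in an overflow-safe way
    return max(-margin, 0.0) + math.log1p(math.exp(-abs(margin)))

# At the start of DPO the policy equals the reference, so the margin is 0 and the loss is log(2)
print(round(dpo_loss(-50.0, -60.0, -50.0, -60.0), 4))  # 0.6931

# Raising the chosen log-prob by 5 nats and lowering the rejected one by 5 reduces the loss
print(round(dpo_loss(-45.0, -65.0, -50.0, -60.0), 4))  # 0.3133
```

A small β such as 0.1 keeps the margin, and therefore the gradient signal, modest, which discourages the policy from drifting far from the SFT model.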
### LoRA Configuration

Parameter-efficient fine-tuning using Low-Rank Adaptation:

```yaml
Rank (r): 8
Alpha: 16
Dropout: 0.05
Target Modules:
  - q_proj
  - k_proj
  - v_proj
  - o_proj
Bias: none
Task Type: CAUSAL_LM
Trainable Parameters: ~0.5% of total
```

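The trainable-parameter share can be checked by hand. LoRA adds two matrices, A (r × in) and B (out × r), to each targeted linear layer, which is r·(in + out) parameters per layer. The sketch below takes some dimensions from the public Llama-3.2-1B config that this card does not list, so treat them as assumptions: 8 key-value heads (so `k_proj`/`v_proj` output 512 features), an MLP width of 8192, and a base total of 1,235,814,400 parameters. With only the four attention projections listed above, the count comes to roughly 1.7M (~0.14%). A figure near ~0.5% would instead correspond to also adapting the `gate_proj`, `up_proj`, and `down_proj` MLP layers.

```python
# Dimensions assumed from the public Llama-3.2-1B config (not stated in this card)
HIDDEN, INTERMEDIATE, LAYERS = 2048, 8192, 16
N_HEADS, N_KV_HEADS = 32, 8
HEAD_DIM = HIDDEN // N_HEADS        # 64
KV_DIM = N_KV_HEADS * HEAD_DIM      # 512 (grouped-query attention)
TOTAL_BASE_PARAMS = 1_235_814_400   # assumed total for the base model

# (in_features, out_features) of each linear layer inside one decoder block
SHAPES = {
    "q_proj": (HIDDEN, HIDDEN), "k_proj": (HIDDEN, KV_DIM),
    "v_proj": (HIDDEN, KV_DIM), "o_proj": (HIDDEN, HIDDEN),
    "gate_proj": (HIDDEN, INTERMEDIATE), "up_proj": (HIDDEN, INTERMEDIATE),
    "down_proj": (INTERMEDIATE, HIDDEN),
}

def lora_params(targets, r=8):
    """LoRA adds A (r x in) and B (out x r) per targeted layer: r * (in + out) params."""
    per_layer = sum(r * (SHAPES[t][0] + SHAPES[t][1]) for t in targets)
    return per_layer * LAYERS

attn_only = lora_params(["q_proj", "k_proj", "v_proj", "o_proj"])
all_linear = lora_params(list(SHAPES))
for name, n in [("attention only", attn_only), ("all linear layers", all_linear)]:
    print(f"{name}: {n:,} params ({100 * n / TOTAL_BASE_PARAMS:.2f}%)")
# attention only: 1,703,936 params (0.14%)
# all linear layers: 5,636,096 params (0.46%)
```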
### Training Infrastructure

- **Platform**: Google Colab
- **GPU**: NVIDIA T4 (16GB VRAM)
- **Training Quantization**: 4-bit NF4 with double quantization
- **Gradient Checkpointing**: Enabled for memory efficiency
- **Final Model Format**: Full precision (merged LoRA adapters)
- **Total Training Time**: ~30-50 minutes

## 💻 Advanced Usage

### Multi-turn Conversation

```python
messages = [
    {"role": "user", "content": "What is machine learning?"},
    {"role": "assistant", "content": "Machine learning is a subset of AI..."},
    {"role": "user", "content": "Can you give me an example?"}
]

inputs = tokenizer.apply_chat_template(messages, add_generation_prompt=True, return_tensors="pt").to(model.device)
outputs = model.generate(inputs, max_new_tokens=256, temperature=0.7, top_p=0.9, do_sample=True, pad_token_id=tokenizer.eos_token_id)
# Print only the reply to the last user turn
print(tokenizer.decode(outputs[0][inputs.shape[-1]:], skip_special_tokens=True))
```

### Streaming Generation

```python
from transformers import TextIteratorStreamer
from threading import Thread

# Reuses `tokenizer`, `model`, and `inputs` from the examples above.
# skip_prompt=True streams only the generated text, not the echoed prompt.
streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

generation_kwargs = dict(
    inputs=inputs,
    max_new_tokens=512,
    temperature=0.7,
    top_p=0.9,
    do_sample=True,
    streamer=streamer,
    pad_token_id=tokenizer.eos_token_id
)

# Run generation in a background thread so text can be consumed as it arrives
thread = Thread(target=model.generate, kwargs=generation_kwargs)
thread.start()

for new_text in streamer:
    print(new_text, end="", flush=True)

thread.join()
```

### Batch Inference

```python
prompts = [
    "Explain neural networks",
    "What is deep learning?",
    "How does backpropagation work?"
]

messages_batch = [[{"role": "user", "content": p}] for p in prompts]

# Llama 3 tokenizers ship without a pad token, and decoder-only models
# should be padded on the left so each prompt ends right where generation starts
tokenizer.pad_token = tokenizer.pad_token or tokenizer.eos_token
tokenizer.padding_side = "left"

# Tokenize all at once; return_dict=True also returns the attention mask
inputs = tokenizer.apply_chat_template(
    messages_batch,
    add_generation_prompt=True,
    return_tensors="pt",
    return_dict=True,
    padding=True
).to(model.device)

# Generate (do_sample=True is required for temperature to take effect)
outputs = model.generate(
    **inputs,
    max_new_tokens=200,
    temperature=0.7,
    do_sample=True,
    pad_token_id=tokenizer.pad_token_id
)

# Decode only the newly generated tokens of each sequence
prompt_len = inputs["input_ids"].shape[-1]
for output in outputs:
    print(tokenizer.decode(output[prompt_len:], skip_special_tokens=True))
    print("-" * 80)
```

### Custom Generation Parameters

```python
# More creative
outputs = model.generate(
    inputs,
    max_new_tokens=512,
    temperature=0.9,  # Higher = more creative
    top_p=0.95,
    top_k=50,
    do_sample=True,
    repetition_penalty=1.1
)

# More focused/deterministic
outputs = model.generate(
    inputs,
    max_new_tokens=256,
    temperature=0.3,  # Lower = more focused
    top_p=0.85,
    do_sample=True,
    repetition_penalty=1.05
)
```

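To see what `temperature` and `top_p` do to the next-token distribution, the toy sketch below applies them to four made-up logits. Temperature divides the logits before the softmax. Nucleus (top-p) filtering keeps the smallest set of highest-probability tokens whose cumulative probability reaches `top_p`, then renormalizes. This is a simplified illustration, not the exact `transformers` logits-processor pipeline, which for example also applies `top_k` and `repetition_penalty`.

```python
import math

def sampling_distribution(logits, temperature=1.0, top_p=1.0):
    """Return the renormalized probabilities left after temperature scaling and top-p filtering."""
    scaled = [x / temperature for x in logits]
    m = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in scaled]
    total = sum(exps)
    probs = [e / total for e in exps]

    # Keep the highest-probability tokens until their cumulative mass reaches top_p
    order = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    kept, cumulative = set(), 0.0
    for i in order:
        kept.add(i)
        cumulative += probs[i]
        if cumulative >= top_p:
            break
    mass = sum(probs[i] for i in kept)
    return [probs[i] / mass if i in kept else 0.0 for i in range(len(probs))]

logits = [2.0, 1.0, 0.5, -1.0]  # toy scores for four candidate tokens
for t in (0.3, 0.7, 0.9):
    dist = sampling_distribution(logits, temperature=t, top_p=0.9)
    print(t, [round(p, 3) for p in dist])
```

With `top_p=0.9` held fixed, only the top token survives at temperature 0.3, two tokens remain at 0.7, and three at 0.9, which is why higher temperatures produce more varied text.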
## 🎨 Chat Template Format

The model uses Llama 3.2's official chat format with special tokens:

```
<|start_header_id|>user<|end_header_id|>

Your question here<|eot_id|><|start_header_id|>assistant<|end_header_id|>

Model response here<|eot_id|>
```

The tokenizer's `apply_chat_template` method handles this automatically.

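The layout above can be reproduced with plain string formatting, which is useful for checking what the model actually sees. `render_llama3_prompt` is a hypothetical helper, not part of any library, and it follows only the format shown here. The official Llama 3 templates also prepend the `<|begin_of_text|>` token and may insert a system header, so in practice rely on `apply_chat_template` with the tokenizer shipped with this model.

```python
def render_llama3_prompt(messages, add_generation_prompt=True):
    """Render chat messages in the header/eot layout shown above (hypothetical helper)."""
    parts = []
    for msg in messages:
        # Each turn: role header, a blank line, the content, then the end-of-turn token
        parts.append(
            f"<|start_header_id|>{msg['role']}<|end_header_id|>\n\n{msg['content']}<|eot_id|>"
        )
    if add_generation_prompt:
        # Open an assistant header so the model continues with its reply
        parts.append("<|start_header_id|>assistant<|end_header_id|>\n\n")
    return "".join(parts)

prompt = render_llama3_prompt([{"role": "user", "content": "Your question here"}])
print(prompt)
```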
## 📈 Intended Use Cases

### ✅ Recommended Applications

- **Educational Tools**: Safe, informative responses for learning
- **Content Safety Research**: Studying AI alignment and safety
- **Prototype Development**: Building conversational AI systems
- **Instruction Following**: General-purpose task completion
- **Safe Text Generation**: Content-aware generation tasks

### ❌ Out-of-Scope Use

- **Production Systems**: Without additional safety validation
- **High-Stakes Decisions**: Medical, legal, financial advice
- **Unsupervised Deployment**: Without human oversight
- **Harmful Content**: Generating dangerous or illegal content
- **Critical Infrastructure**: Without extensive testing

## ⚠️ Limitations and Considerations

### Known Limitations

1. **Training Data**: Only 500 samples; more data could improve performance
2. **Language**: Primarily English-focused, with limited multilingual capability
3. **Context Length**: DPO training used a maximum sequence length of 1024 tokens; behavior on longer inputs has not been validated, although the base model supports a much longer context
4. **Model Size**: At ~1B parameters, reasoning and knowledge are more limited than in larger models
5. **Safety Bounds**: Fine-tuned for safety but not perfect; it can still produce unsafe or incorrect outputs
6. **Domain Knowledge**: Limited to the base model's knowledge and its training data cutoff

### Biases and Ethical Considerations

- Inherits biases from the base Llama 3.2 model
- Safety fine-tuning may make responses overly conservative
- The content safety dataset has its own biases
- Not suitable for all cultural contexts without adaptation
- Should be tested thoroughly before deployment

### Performance Notes

- **Speed**: ~10-20 tokens/second on a T4 GPU
- **Memory**: ~4GB VRAM in BF16, ~2GB with 4-bit quantization
- **Best For**: General instruction following with safety awareness
- **Trade-offs**: Safety focus may reduce creativity in some cases

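The memory figures above are dominated by the weights. The back-of-the-envelope check below assumes the public Llama-3.2-1B shapes (about 1.24B parameters, ~263M of them in the embedding matrix, which is tied with the LM head). It also assumes that 4-bit loading quantizes only the linear layers inside the decoder blocks while embeddings and norms stay in BF16, and it ignores the small overhead of NF4 quantization constants. KV cache, activations, and the CUDA context account for the rest of the quoted VRAM.

```python
# Assumed Llama-3.2-1B shapes (public config); embeddings are tied with the LM head
VOCAB, HIDDEN, INTERMEDIATE, LAYERS, KV_DIM = 128_256, 2048, 8192, 16, 512

embedding = VOCAB * HIDDEN
linear_per_layer = (
    2 * HIDDEN * HIDDEN            # q_proj, o_proj
    + 2 * HIDDEN * KV_DIM          # k_proj, v_proj
    + 3 * HIDDEN * INTERMEDIATE    # gate_proj, up_proj, down_proj
)
linear = linear_per_layer * LAYERS
norms = (2 * LAYERS + 1) * HIDDEN  # two RMSNorms per layer plus the final norm

GB = 1e9
bf16 = 2 * (embedding + linear + norms) / GB
# 4-bit: linear weights at 0.5 bytes/param, embeddings and norms kept in BF16
nf4 = (0.5 * linear + 2 * (embedding + norms)) / GB
print(f"BF16 weights: ~{bf16:.2f} GB, 4-bit weights: ~{nf4:.2f} GB")
# BF16 weights: ~2.47 GB, 4-bit weights: ~1.01 GB
```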
## 🔬 Evaluation

### Qualitative Assessment

The model has been tested on:

- ✅ General knowledge questions
- ✅ Instruction following tasks
- ✅ Content safety scenarios
- ✅ Multi-turn conversations
- ✅ Edge cases and adversarial prompts

### Sample Outputs

*(Coming soon.)*

### Comparison to Base Model

The comparison below is qualitative; no benchmark scores are reported.

| Metric | Base Llama 3.2 | This Model | Improvement |
|--------|---------------|------------|-------------|
| Safety Awareness | Baseline | Enhanced | +Safety Focus |
| Instruction Following | Good | Better | +SFT Training |
| Response Quality | High | High | +DPO Alignment |

## 🛠️ Technical Details

### Model Architecture

- **Base**: Llama 3.2 1B
- **Vocabulary**: 128,256 tokens
- **Hidden Size**: 2048
- **Layers**: 16
- **Attention Heads**: 32 query heads, 8 key-value heads (grouped-query attention)
- **Parameters**: ~1.23B total, ~6M trainable (LoRA)

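With grouped-query attention, generation caches keys and values only for the key-value heads. The sketch below estimates the KV-cache footprint, assuming the public Llama-3.2-1B values of 8 key-value heads with head dimension 64 (2048 / 32) and a BF16 cache. It illustrates how generation memory grows with context length on top of the weights; 131,072 tokens corresponds to the base model's 128K maximum context.

```python
def kv_cache_bytes(seq_len, batch=1, layers=16, n_kv_heads=8, head_dim=64, bytes_per_value=2):
    """Size of the key/value cache: one K and one V vector per layer, KV head, and token."""
    return 2 * layers * n_kv_heads * head_dim * bytes_per_value * seq_len * batch

MiB = 2 ** 20
print(kv_cache_bytes(1) // 1024, "KiB per token")               # 32 KiB per token
print(kv_cache_bytes(1024) / MiB, "MiB at 1,024 tokens")         # 32.0
print(kv_cache_bytes(131_072) / MiB, "MiB at 131,072 tokens")    # 4096.0
```

Because the cache scales linearly with both sequence length and batch size, long prompts or large batches can outweigh the ~1B-parameter weights themselves.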
### Training Efficiency

- **Trainable Params**: ~0.5% of total (LoRA adapters)
- **Memory During Training**: ~8GB VRAM (4-bit quantization)
- **Training Time**: ~40 minutes total (SFT + DPO)
- **Hardware Cost**: Free tier Google Colab (T4 GPU)

### Optimization Techniques

- ✅ 4-bit NF4 quantization
- ✅ Gradient checkpointing
- ✅ LoRA parameter-efficient fine-tuning
- ✅ Gradient accumulation
- ✅ BF16 mixed precision
- ✅ Optimized memory management

## 🙏 Acknowledgments

- **Base Model**: Meta's Llama 3.2 team for the foundation model
- **Dataset**: NVIDIA for the Aegis AI Content Safety Dataset
- **Methodology**: AMD for the Instella training approach inspiration
- **Frameworks**:
  - Hugging Face Transformers, TRL, PEFT, Datasets
  - PyTorch team
- Google Colab for compute resources

## 📄 License

This model is licensed under the **Llama 3.2 Community License**:

- Commercial use allowed, subject to the license's restrictions
- Attribution required (e.g., prominently displaying "Built with Llama")
- AI models created, trained, or fine-tuned with Llama 3.2 materials or outputs and then distributed must include "Llama" at the beginning of their name
- Full license: https://huggingface.co/meta-llama/Llama-3.2-1B

## 📚 Citations

### This Model

```bibtex
@misc{llama_3.2_1b_aegis_sft_dpo,
  author       = {Community Contributor},
  title        = {Llama-3.2-1B-Aegis-SFT-DPO: Content-Safe Fine-tuned Llama 3.2},
  year         = {2025},
  publisher    = {HuggingFace},
  howpublished = {\url{https://huggingface.co/ahczhg/Llama-3.2-1B-Aegis-SFT-DPO}}
}
```

### Base Model

```bibtex
@misc{llama32,
  author = {{Meta AI}},
  title  = {Llama-3.2-1B},
  year   = {2024},
  url    = {https://huggingface.co/meta-llama/Llama-3.2-1B}
}
```

### Dataset

```bibtex
@misc{aegis_dataset,
  author = {NVIDIA},
  title  = {Aegis AI Content Safety Dataset 2.0},
  year   = {2024},
  url    = {https://huggingface.co/datasets/nvidia/Aegis-AI-Content-Safety-Dataset-2.0}
}
```

## 🔗 Links

- **Base Model**: [meta-llama/Llama-3.2-1B](https://huggingface.co/meta-llama/Llama-3.2-1B)
- **Dataset**: [nvidia/Aegis-AI-Content-Safety-Dataset-2.0](https://huggingface.co/datasets/nvidia/Aegis-AI-Content-Safety-Dataset-2.0)
- **TRL Library**: [Hugging Face TRL](https://github.com/huggingface/trl)
- **PEFT Library**: [Hugging Face PEFT](https://github.com/huggingface/peft)

## 📞 Feedback & Support

Found an issue or have suggestions? Please:

- Open an issue on the model repository
- Report safety concerns immediately
- Share your use cases and results

---

**Model Card Version**: 1.0
**Last Updated**: 2025-11-15
**Training Date**: 2025-11-15

**Framework Versions**:
- 🤗 Transformers: `4.57.1`
- 🔥 PyTorch: `2.8.0+cu126`
- 🎯 TRL: `0.25.1`
- 🔧 PEFT: `0.17.1`
- 📊 Datasets: `4.0.0`

**Compute**:
- Platform: Google Colab
- GPU: NVIDIA T4 (16GB)
- Training Duration: ~40-50 minutes
- Carbon Footprint: Minimal (under an hour on a single T4 GPU)

---

<div align="center">
<sub>Built with ❤️ using Hugging Face libraries | Trained on Google Colab | Released under Llama 3.2 License</sub>
</div>

[ko-fi.com/ahczhg](https://ko-fi.com/ahczhg)