---
language:
- en
library_name: transformers
pipeline_tag: text-generation
tags:
- qwen2
- supervised-fine-tuning
- alignment
- sparsemax
- transformers
---

# Qwen2-7B-TS2

Training with Sparsemax+, Testing with Softmax

This model is a supervised fine-tuned variant of `Qwen2-7B`, trained with our TS^2 objective.

TS^2 is designed to improve alignment stability and mitigate token-level probability collapse during fine-tuning by incorporating entropy-aware adaptive weighting into the training objective.

For more details, see our ICLR 2026 paper **"TS^2: Training with Sparsemax+, Testing with Softmax for Accurate and Diverse LLM Fine-Tuning"** on [OpenReview](https://openreview.net/forum?id=CylRqa82Rk).

## Model Description

- Base model: `Qwen2-7B`
- Training method: Sparsemax+
- Objective: token-level entropy-aware TS^2-style regularization
- Framework: PyTorch + Hugging Face Transformers
- Precision: bfloat16

Instead of applying uniform likelihood maximization across all tokens, as in standard supervised fine-tuning, this model introduces an adaptive weighting mechanism that dynamically adjusts training emphasis based on predictive entropy.
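As a concrete sketch of what entropy-aware token weighting can look like, here is a minimal NumPy illustration. The weighting form and the function name `entropy_weighted_nll` are hypothetical, chosen only to convey the idea; the exact TS^2 objective is defined in the paper.

```python
import numpy as np

def softmax(z):
    """Numerically stable softmax over the last axis."""
    z = z - z.max(axis=-1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)

def entropy_weighted_nll(logits, targets):
    """Hypothetical entropy-aware loss (illustrative, not the TS^2 objective):
    each token's negative log-likelihood is weighted by the model's normalized
    predictive entropy, so training emphasis shifts toward uncertain tokens
    instead of further sharpening tokens the model is already confident about.

    logits: (seq_len, vocab_size) float array; targets: (seq_len,) int array.
    """
    p = softmax(logits)
    logp = np.log(p)
    # Normalized entropy in [0, 1]: 1 for a uniform prediction, near 0 for a peaked one.
    entropy = -(p * logp).sum(axis=-1) / np.log(logits.shape[-1])
    nll = -logp[np.arange(len(targets)), targets]
    return (entropy * nll).sum() / entropy.sum()
```

Under this toy weighting, a token the model predicts near-uniformly contributes far more to the loss than a token it already predicts with high confidence.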

This design is motivated by observations that overconfident likelihood-based training may lead to:

- degeneration of token diversity
- inference-time mode collapse
- reduced generalization under distribution shift

TS^2 modifies the training objective to improve both accuracy and diversity.
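For intuition about the training-time activation, here is a minimal NumPy sketch of the standard sparsemax transform (a Euclidean projection of the logits onto the probability simplex); the Sparsemax+ variant used by TS^2 is described in the paper.

```python
import numpy as np

def sparsemax(z):
    """Standard sparsemax: project logits onto the probability simplex.

    Unlike softmax, the output can contain exact zeros, so probability
    mass concentrates on the highest-scoring tokens.
    """
    z = np.asarray(z, dtype=float)
    z_sorted = np.sort(z)[::-1]
    k = np.arange(1, z.size + 1)
    cumsum = np.cumsum(z_sorted)
    # Support size: the largest k with 1 + k * z_(k) > sum of the top-k logits.
    k_max = k[1 + k * z_sorted > cumsum][-1]
    tau = (cumsum[k_max - 1] - 1.0) / k_max  # threshold subtracted from all logits
    return np.maximum(z - tau, 0.0)
```

For example, `sparsemax(np.array([1.0, 1.0, -2.0]))` returns `[0.5, 0.5, 0.0]`: the low-scoring entry is zeroed out entirely, whereas softmax would still assign it positive probability.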
## Usage

```python
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

tokenizer = AutoTokenizer.from_pretrained("xzybit/qwen2-7b-ts2")
model = AutoModelForCausalLM.from_pretrained(
    "xzybit/qwen2-7b-ts2",
    torch_dtype=torch.bfloat16,  # matches the training precision listed above
    device_map="auto",
)

# Example prompt; adjust decoding parameters to taste.
prompt = "Explain sparsemax in one sentence."
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
outputs = model.generate(**inputs, max_new_tokens=128)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```