---
license: llama3.2
base_model: meta-llama/Llama-3.2-1B
tags:
- multi-token-prediction
- speculative-decoding
- self-distillation
- mtp
- llama
- consumer-gpu
- rtx-5090
- paper-reproduction
datasets:
- jwkirchenbauer/metamathqa-grouped-split
language:
- en
pipeline_tag: text-generation
library_name: transformers
model-index:
- name: Llama-3.2-1B-MTP-k8
  results:
  - task:
      type: text-generation
      name: GSM8K (8-shot CoT)
    dataset:
      type: gsm8k
      name: GSM8K
    metrics:
    - name: exact_match (flexible)
      type: exact_match
      value: 5.08
    - name: exact_match (strict)
      type: exact_match
      value: 3.03
---
# Llama-3.2-1B-MTP-k8: Multi-Token Prediction on a Single Consumer GPU
This is a reproduction of **"Multi-Token Prediction via Self-Distillation"** ([arXiv 2602.06019](https://arxiv.org/abs/2602.06019)) adapted for a single NVIDIA RTX 5090 (32GB). The original paper used 4x NVIDIA GH200 (384GB total) with Llama-3.1-8B. We scaled it down to Llama-3.2-1B on consumer hardware.
## What is Multi-Token Prediction (MTP)?
Standard language models predict **one token at a time** (autoregressive decoding). MTP trains the model to predict **multiple future tokens simultaneously** using online self-distillation:
1. A **frozen teacher** (the original model) generates soft probability distributions
2. A **trainable student** (same architecture) learns to predict k future tokens at each position
3. At inference, **ConfAdapt decoding** emits multiple tokens when the model is confident, falling back to single-token when uncertain

The goal: **faster inference with minimal quality loss** compared with decoding the same model one token at a time.
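
Below is a minimal sketch of a ConfAdapt-style acceptance rule for one decoding step. It rests on three assumptions. First, "confidence" means the top-1 probability at each of the k predicted positions. Second, the threshold is 0.9, matching the "ConfAdapt 90%" setting in the results table. Third, tokens are accepted as a prefix, stopping at the first low-confidence position. The paper's exact rule may differ, and `confadapt_accept` is a hypothetical name, not a function from the mtp-lm code.

```python
from typing import List, Sequence


def confadapt_accept(probs: Sequence[Sequence[float]], threshold: float = 0.9) -> List[int]:
    """Pick which of the k predicted tokens to emit in one decoding step.

    probs[i] is the model's distribution over the vocabulary for the i-th
    future position (i = 0 .. k-1). The first token is always emitted (the
    single-token fallback); each later token is emitted only while its top-1
    probability stays at or above `threshold`. Assumed rule, not the paper's code.
    """
    emitted: List[int] = []
    for i, dist in enumerate(probs):
        token = max(range(len(dist)), key=dist.__getitem__)  # greedy top-1 id
        if i > 0 and dist[token] < threshold:
            break  # the first low-confidence position ends this step
        emitted.append(token)
    return emitted


# Toy k=4 step over a 2-token vocabulary
step_probs = [[0.2, 0.8], [0.05, 0.95], [0.6, 0.4], [0.99, 0.01]]
print(confadapt_accept(step_probs))  # -> [1, 1]
```

In this example, position 3 is never reached: the prefix already stopped at position 2, which is below the threshold. Averaging `len(confadapt_accept(...))` over a generation gives the "tokens emitted per step" figure reported below.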
## Results: GSM8K 8-shot Chain-of-Thought
| Configuration | Exact Match, flexible (± SE) | Exact Match, strict (± SE) | Time per sample (lower is better) |
|---|---|---|---|
| **Baseline** (Llama-3.2-1B, standard AR) | **7.13%** ± 0.71 | **6.07%** ± 0.66 | ~1.5 s/sample |
| **MTP k=1** (single token, quality check) | 5.23% ± 0.61 | 2.96% ± 0.47 | ~2.4 s/sample |
| **MTP k=8 + ConfAdapt 90%** | 5.08% ± 0.60 | 3.03% ± 0.47 | **~1.3 s/sample** |
### Key Findings
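- **ConfAdapt works:** k=8 with ConfAdapt matches k=1 quality (5.08% vs 5.23% flexible, 3.03% vs 2.96% strict, both within one standard error). It also runs **~1.8x faster** per sample (~1.3 s vs ~2.4 s), emitting 2.82 tokens per step on average.
- **Quality drop is expected:** Accuracy falls ~2 percentage points (flexible) and ~3 points (strict) below the baseline. Most of that drop already appears at k=1, so it comes from MTP training rather than from multi-token decoding. This is consistent with our smaller setup: a 1B model and 500M training tokens, versus the paper's 8B model and 2B tokens.
- **The core claim holds:** Relative to k=1 decoding of the same model, multi-token decoding via ConfAdapt preserves generation quality while improving throughput, even on a 1B model.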
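
The ± values in the results table appear to be the binomial standard error over the 1,319-problem GSM8K test set. lm-evaluation-harness computes this from the sample variance, and the formula below reproduces every ± in the table. Building on that, a rough z-score shows which gaps stand out from the noise. The sketch treats the two runs as independent and ignores the fact that both answer the same problems, so read it as a rough guide, not a formal test. The helper names are illustrative.

```python
import math

GSM8K_TEST_SIZE = 1319  # problems in the GSM8K test split


def stderr(acc_pct: float, n: int = GSM8K_TEST_SIZE) -> float:
    """Standard error of an accuracy (in percentage points) over n pass/fail items.

    Uses the sample variance with an n - 1 denominator.
    """
    p = acc_pct / 100
    return 100 * math.sqrt(p * (1 - p) / (n - 1))


def z_score(a_pct: float, b_pct: float) -> float:
    """Rough z-score for the gap between two accuracies, treated as independent."""
    return (a_pct - b_pct) / math.hypot(stderr(a_pct), stderr(b_pct))


print(f"{stderr(7.13):.2f}")         # -> 0.71 (baseline, flexible; matches the table)
print(f"{z_score(5.23, 5.08):.2f}")  # -> 0.17 (k=1 vs k=8 + ConfAdapt, flexible)
print(f"{z_score(7.13, 5.08):.2f}")  # -> 2.20 (baseline vs k=8 + ConfAdapt, flexible)
```

By this measure, the k=1 vs k=8 gap is well inside the noise, while the gap to the baseline is about two standard errors.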
## Training Details
### What We Changed from the Paper
| Parameter | Paper (8B / 4x GH200) | Ours (1B / 1x RTX 5090) |
|---|---|---|
| Base model | Llama-3.1-8B | Llama-3.2-1B |
| GPUs | 4x GH200 (96GB each) | 1x RTX 5090 (32GB) |
| FSDP mesh | 1x4 | 1x1 (no FSDP) |
| k_toks | Randomized 2-16 across ranks | Fixed 8 |
| Training tokens | 2B | 500M |
| micro_batch_size | 32 | 8 |
| global_batch_size | 128 | 64 (grad accumulation) |
| mask_region_ct | 5 | 1 |
| rollout_multiplier | 4 | 2 |
| Template | Chat (Instruct tokenizer) | Plain text (base tokenizer) |
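
The `global_batch_size` of 64 is reached by accumulating gradients over 64 / 8 = 8 micro-batches. Below is a minimal PyTorch sketch of that pattern. It assumes equal-size micro-batches and a mean-reduced loss. `train_step` is a hypothetical helper, not code taken from the mtp-lm repository.

```python
import torch
from torch import nn

MICRO_BATCH_SIZE = 8
GLOBAL_BATCH_SIZE = 64
ACCUM_STEPS = GLOBAL_BATCH_SIZE // MICRO_BATCH_SIZE  # 8 micro-batches per optimizer step


def train_step(model: nn.Module, optimizer: torch.optim.Optimizer, micro_batches, loss_fn) -> float:
    """Run one optimizer step over ACCUM_STEPS equal-size micro-batches.

    Dividing each micro-batch loss by ACCUM_STEPS makes the summed gradients
    equal the gradient of the mean loss over the whole global batch.
    """
    optimizer.zero_grad(set_to_none=True)
    total_loss = 0.0
    for inputs, targets in micro_batches:
        loss = loss_fn(model(inputs), targets) / ACCUM_STEPS
        loss.backward()  # gradients add up in .grad across micro-batches
        total_loss += loss.item()
    optimizer.step()  # a single update for the whole global batch
    return total_loss  # mean loss over the global batch
```

Only one micro-batch's activations are alive at a time. Accumulation therefore trades extra wall-clock time for a lower peak memory footprint than a true batch of 64.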
### What We Kept the Same
- **Supervision method:** Soft teacher via KL divergence (paper's recommended self-distillation)
- **Dataset:** MetaMathQA (`jwkirchenbauer/metamathqa-grouped-split`)
- **Sequence length:** 160 tokens
- **Peak learning rate:** 1e-5
- **Optimizer:** AdamW with cosine decay
### Training Metrics
- **Total steps:** 48,828
- **Training time:** ~17 hours on RTX 5090
- **Final train loss:** ~0.9
- **Final val loss:** 1.895 (perplexity 6.65)
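
As a sanity check, the reported step count and perplexity follow from the numbers above. This relies on two assumptions. First, each optimizer step consumes `global_batch_size × sequence length` tokens; this is an assumption, but it reproduces the 48,828 steps exactly. Second, perplexity is `exp(val loss)`. `total_steps` is an illustrative helper.

```python
import math

SEQ_LEN = 160  # tokens per training sequence (same as the paper)


def total_steps(train_tokens: int, global_batch_size: int, seq_len: int = SEQ_LEN) -> int:
    """Optimizer steps needed to consume the token budget (partial last step dropped)."""
    return train_tokens // (global_batch_size * seq_len)


ours = total_steps(500_000_000, 64)      # 10,240 tokens per step
paper = total_steps(2_000_000_000, 128)  # same assumption applied to the paper's setup
val_ppl = math.exp(1.895)

print(ours, paper, round(val_ppl, 2))  # -> 48828 97656 6.65
```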
## Why k=8 Instead of the Paper's Randomized k=2-16?
The paper's approach randomizes k across GPU ranks each step. With 4 GPUs, each rank draws its own k, so a single global batch might mix, for example, k=2, 5, 12, and 16. The model thereby learns to handle any prediction horizon in the 2-16 range.

With a single GPU, we can only train one k value per step. We chose k=8 as a middle ground — large enough to demonstrate meaningful multi-token speedup, small enough to fit in 32GB VRAM.

This is an important tradeoff: our model is specialized for k=8, while the paper's model is trained across the whole 2-16 range of k. A production deployment would likely benefit from the paper's multi-GPU randomized approach.
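
To make the difference concrete, here is a hypothetical sketch of per-rank horizon sampling. `sample_k` and its seeding scheme are illustrative assumptions, not the mtp-lm implementation. Each rank draws its own k in [2, 16] every step, whereas our single-GPU run pins k to 8.

```python
import random

K_MIN, K_MAX = 2, 16
FIXED_K = 8  # our single-GPU setting


def sample_k(step: int, rank: int, seed: int = 0) -> int:
    """Hypothetical per-rank horizon sampler.

    Seeding on (seed, step, rank) keeps each draw reproducible while letting
    different ranks pick different k values in the same step.
    """
    rng = random.Random(f"{seed}-{step}-{rank}")
    return rng.randint(K_MIN, K_MAX)  # inclusive on both ends


if __name__ == "__main__":
    world_size = 4  # the paper's 4x GH200 setup
    for step in range(3):
        ks = [sample_k(step, rank) for rank in range(world_size)]
        print(f"step {step}: multi-GPU k per rank = {ks}, single-GPU k = {FIXED_K}")
```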
## Infrastructure: Running on Consumer Hardware
This reproduction ran entirely on a home Kubernetes cluster:
- **GPU:** NVIDIA RTX 5090 (32GB, Blackwell architecture / sm_120)
- **System:** 16GB RAM, Debian 13
- **Stack:** Kubernetes + containerd + NVIDIA device plugin
- **PyTorch:** Nightly build with CUDA 12.8 (required for Blackwell sm_120 support)
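
A quick way to check whether an installed PyTorch build can actually target the RTX 5090 is to look for `sm_120` among the CUDA architectures the binary was compiled for. Below is a minimal sketch using standard `torch.cuda` calls. The helper name `blackwell_ready` is ours.

```python
import torch


def blackwell_ready() -> bool:
    """True if a CUDA GPU is visible and this PyTorch build ships sm_120 kernels."""
    # get_arch_list() lists the CUDA architectures this binary was compiled for.
    return torch.cuda.is_available() and "sm_120" in torch.cuda.get_arch_list()


if __name__ == "__main__":
    print("torch", torch.__version__, "built with CUDA", torch.version.cuda)
    if torch.cuda.is_available():
        # Compute capability (12, 0) corresponds to sm_120, e.g. an RTX 5090
        print("device capability:", torch.cuda.get_device_capability(0))
    print("sm_120 kernels available:", blackwell_ready())
```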
### Challenges We Solved
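1. **Blackwell GPU support:** The RTX 5090 (sm_120) requires a PyTorch build compiled against CUDA 12.8 (cu128). We used a nightly cu128 build. Stable cu128 wheels from PyTorch 2.7 onward also include sm_120 kernels, so a nightly is likely no longer necessary.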
2. **Single-GPU checkpoint saving:** The original code uses `torch.distributed.all_reduce()` for checkpoint state sync, which crashes when distributed is not initialized. We added an `is_initialized()` guard
3. **W&B configuration:** Default config points to the paper authors' organization. Override with `--wandb.entity=null`
4. **HuggingFace checkpoint format:** The litgpt converter outputs `model.pth`, but transformers looks for `pytorch_model.bin` (or `model.safetensors`), so the converted state dict has to be saved under that name.
## Usage
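```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "celestialcreator/Llama-3.2-1B-MTP-k8"
device = "cuda" if torch.cuda.is_available() else "cpu"

tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    trust_remote_code=True,
    # fp16 on GPU; fp32 on CPU, where fp16 support is limited
    torch_dtype=torch.float16 if device == "cuda" else torch.float32,
).to(device)

# Standard generation (one token per step, works like any Llama model)
inputs = tokenizer("The capital of France is", return_tensors="pt").to(device)
outputs = model.generate(
    **inputs,
    max_new_tokens=50,
    pad_token_id=tokenizer.eos_token_id,  # the base Llama tokenizer has no pad token
)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```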
For MTP inference with ConfAdapt decoding, use the [mtp-lm evaluation harness fork](https://github.com/jwkirchenbauer/lm-evaluation-harness-mtp-lm).
## Reproduction Guide
Full reproduction instructions with Kubernetes manifests and configs: [GitHub Fork](https://github.com/CelestialCreator/mtp-lm)
## Citation
If you use this model, please cite the original paper:
```bibtex
@article{kirchenbauer2025multitokenpredictionselfdistillation,
  title={Multi-Token Prediction via Self-Distillation},
  author={John Kirchenbauer and Jonas Geiping and Yuxin Wen and Tom Goldstein},
  journal={arXiv preprint arXiv:2602.06019},
  year={2026}
}
```
## Acknowledgments
- Original paper and code by [John Kirchenbauer et al.](https://github.com/jwkirchenbauer/mtp-lm)
- Built with [LitGPT](https://github.com/Lightning-AI/litgpt), [PyTorch](https://pytorch.org/), and [lm-evaluation-harness](https://github.com/EleutherAI/lm-evaluation-harness)