---
license: llama3.2
base_model: meta-llama/Llama-3.2-1B
tags:
- multi-token-prediction
- speculative-decoding
- self-distillation
- mtp
- llama
- consumer-gpu
- rtx-5090
- paper-reproduction
datasets:
- jwkirchenbauer/metamathqa-grouped-split
language:
- en
pipeline_tag: text-generation
library_name: transformers
model-index:
- name: Llama-3.2-1B-MTP-k8
  results:
  - task:
      type: text-generation
      name: GSM8K (8-shot CoT)
    dataset:
      type: gsm8k
      name: GSM8K
    metrics:
    - name: exact_match (flexible)
      type: exact_match
      value: 5.08
    - name: exact_match (strict)
      type: exact_match
      value: 3.03
---
# Llama-3.2-1B-MTP-k8: Multi-Token Prediction on a Single Consumer GPU

This model is a reproduction of **"Multi-Token Prediction via Self-Distillation"** ([arXiv 2602.06019](https://arxiv.org/abs/2602.06019)), adapted for a single NVIDIA RTX 5090 (32GB). The original paper used Llama-3.1-8B on 4x NVIDIA GH200 (384GB total). We scaled it down to Llama-3.2-1B on consumer hardware.
## What is Multi-Token Prediction (MTP)?

Standard language models predict **one token at a time** (autoregressive decoding). MTP trains the model to predict **multiple future tokens at once** using online self-distillation:

1. A **frozen teacher** (the original model) produces soft probability distributions over the vocabulary, which serve as the training targets.
2. A **trainable student** (same architecture) learns to predict the next k tokens at each position.
3. At inference, **ConfAdapt decoding** emits several tokens per step when the model is confident and falls back to single-token decoding when it is uncertain.

The goal is **faster inference with minimal quality loss**.
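
To make the ConfAdapt idea concrete, here is a minimal Python sketch of a confidence-thresholded acceptance rule. It assumes the model has already produced, in one forward pass, a probability distribution for each of the next k positions. It then emits the greedy tokens left to right until one falls below the threshold (0.9 here, matching the "ConfAdapt 90%" setting in the results). The function name `confadapt_accept` and the per-token threshold rule are illustrative assumptions, not the paper's implementation, whose exact acceptance criterion may differ.

```python
from typing import List, Sequence


def confadapt_accept(
    step_probs: Sequence[Sequence[float]], threshold: float = 0.9
) -> List[int]:
    """Hypothetical confidence-adaptive acceptance rule (illustrative only).

    step_probs[j] is an assumed probability distribution over the vocabulary
    for the (j+1)-th future token, all produced by a single forward pass.
    Returns the token ids to emit at this decoding step.
    """
    emitted: List[int] = []
    for j, probs in enumerate(step_probs):
        # Greedy choice for this future position.
        token = max(range(len(probs)), key=lambda i: probs[i])
        confidence = probs[token]
        # The first token is always emitted (plain autoregressive fallback);
        # later tokens are emitted only while the model stays confident.
        if j > 0 and confidence < threshold:
            break
        emitted.append(token)
    return emitted


if __name__ == "__main__":
    # Toy 4-token vocabulary with k = 3 predicted positions.
    probs = [
        [0.02, 0.95, 0.02, 0.01],  # confident: emit token 1
        [0.91, 0.03, 0.03, 0.03],  # confident: emit token 0
        [0.40, 0.30, 0.20, 0.10],  # uncertain: stop here
    ]
    print(confadapt_accept(probs))  # [1, 0]
```

Each accepted run of confident tokens saves forward passes, which is where the speedup comes from. For k=8, this model averages 2.82 emitted tokens per step.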
## Results: GSM8K 8-shot Chain-of-Thought

| Configuration | Exact Match, flexible (± s.e.) | Exact Match, strict (± s.e.) | Time per sample |
|---|---|---|---|
| **Baseline** (Llama-3.2-1B, standard AR) | **7.13%** ± 0.71 | **6.07%** ± 0.66 | ~1.5 s |
| **MTP k=1** (single token, quality check) | 5.23% ± 0.61 | 2.96% ± 0.47 | ~2.4 s |
| **MTP k=8 + ConfAdapt 90%** | 5.08% ± 0.60 | 3.03% ± 0.47 | **~1.3 s** |
### Key Findings

- **ConfAdapt works:** k=8 with ConfAdapt matches k=1 quality (the differences are within one standard error) while being **~1.8x faster** than k=1 (1.3 vs. 2.4 s/sample), emitting an average of 2.82 tokens per step. Against the baseline, the speedup is smaller, about 1.15x (1.3 vs. 1.5 s/sample).
- **Some quality drop was expected:** Accuracy falls from the baseline by about 2 points (flexible) and 3 points (strict). We attribute this to our smaller setup (1B model and 500M training tokens vs. the paper's 8B model and 2B tokens).
- **The core claim holds:** Multi-token decoding via ConfAdapt preserves the MTP model's single-token (k=1) quality while improving throughput, even on a small 1B model.
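
The ± values and speedups above can be sanity-checked with a few lines of arithmetic. This sketch assumes the ± figures are binomial standard errors over the 1,319-problem GSM8K test set. That is our reading, and it reproduces the reported values to two decimals. It also computes speedups as ratios of the approximate per-sample times. The helper names are our own.

```python
import math

GSM8K_TEST_SIZE = 1319  # number of problems in the GSM8K test split


def binomial_stderr_pct(accuracy_pct: float, n: int = GSM8K_TEST_SIZE) -> float:
    """Standard error, in percentage points, of an accuracy measured on n items."""
    p = accuracy_pct / 100.0
    return 100.0 * math.sqrt(p * (1.0 - p) / n)


def speedup(slow_s: float, fast_s: float) -> float:
    """Ratio of per-sample times; a value above 1 means the second config is faster."""
    return slow_s / fast_s


if __name__ == "__main__":
    for name, acc in [("baseline, flexible", 7.13), ("k=8 ConfAdapt, flexible", 5.08)]:
        print(f"{name}: {acc:.2f}% ± {binomial_stderr_pct(acc):.2f}")
    print(f"k=8 ConfAdapt vs. MTP k=1: {speedup(2.4, 1.3):.2f}x")   # 1.85x
    print(f"k=8 ConfAdapt vs. baseline: {speedup(1.5, 1.3):.2f}x")  # 1.15x
```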
## Training Details
### What We Changed from the Paper

| Parameter | Paper (8B / 4x GH200) | Ours (1B / 1x RTX 5090) |
|---|---|---|
| Base model | Llama-3.1-8B | Llama-3.2-1B |
| GPUs | 4x GH200 (96GB each) | 1x RTX 5090 (32GB) |
| FSDP mesh | 1x4 | 1x1 (no FSDP) |
| `k_toks` | Randomized 2-16 across ranks | Fixed 8 |
| Training tokens | 2B | 500M |
| `micro_batch_size` | 32 | 8 |
| `global_batch_size` | 128 | 64 (via gradient accumulation) |
| `mask_region_ct` | 5 | 1 |
| `rollout_multiplier` | 4 | 2 |
| Template | Chat (Instruct tokenizer) | Plain text (base tokenizer) |
### What We Kept the Same

- **Supervision method:** Soft teacher via KL divergence (the paper's recommended self-distillation setup)
- **Dataset:** MetaMathQA (`jwkirchenbauer/metamathqa-grouped-split`)
- **Sequence length:** 160 tokens
- **Peak learning rate:** 1e-5
- **Optimizer:** AdamW with cosine learning-rate decay
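
For reference, the soft-teacher objective compares the frozen teacher's distribution with the student's at each predicted position. The plain-Python sketch below computes a KL divergence from raw logits for one position and averages it over k positions. Several details are our assumptions: the forward direction KL(teacher || student), the absence of a temperature, simple averaging, and the function names. The actual training code operates on batched tensors and may differ in each of these choices.

```python
import math
from typing import List, Sequence


def log_softmax(logits: Sequence[float]) -> List[float]:
    # Subtract the max before exponentiating for numerical stability.
    m = max(logits)
    log_z = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - log_z for x in logits]


def kl_teacher_student(
    teacher_logits: Sequence[float], student_logits: Sequence[float]
) -> float:
    """Forward KL(p_teacher || p_student) for one position (assumed direction)."""
    log_p = log_softmax(teacher_logits)
    log_q = log_softmax(student_logits)
    return sum(math.exp(lp) * (lp - lq) for lp, lq in zip(log_p, log_q))


def mtp_distill_loss(
    teacher_logits_k: Sequence[Sequence[float]],
    student_logits_k: Sequence[Sequence[float]],
) -> float:
    """Hypothetical: mean per-position KL over the k predicted future positions."""
    losses = [kl_teacher_student(t, s) for t, s in zip(teacher_logits_k, student_logits_k)]
    if not losses:
        raise ValueError("need at least one predicted position")
    return sum(losses) / len(losses)


if __name__ == "__main__":
    teacher = [[2.0, 0.5, -1.0], [0.0, 1.0, 0.0]]
    student = [[2.0, 0.5, -1.0], [1.0, 0.0, 0.0]]
    print(mtp_distill_loss(teacher, student))
```

The loss is zero when the student exactly reproduces the teacher's distribution and grows as the student's probability mass drifts away from the teacher's preferred tokens.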
### Training Metrics

- **Total steps:** 48,828
- **Training time:** ~17 hours on RTX 5090
- **Final train loss:** ~0.9
- **Final val loss:** 1.895 (perplexity 6.65)
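
The reported step count follows directly from the configuration. This check assumes that every training sequence is packed to the full 160 tokens and that "500M" means exactly 500,000,000 tokens. It also covers the gradient-accumulation factor implied by the batch sizes. The perplexity is computed as exp(loss), assuming the validation loss is mean per-token cross-entropy in nats.

```python
import math

SEQ_LEN = 160
MICRO_BATCH_SIZE = 8
GLOBAL_BATCH_SIZE = 64
TRAIN_TOKENS = 500_000_000  # assumed exact value of "500M"

tokens_per_step = GLOBAL_BATCH_SIZE * SEQ_LEN             # tokens per optimizer step
total_steps = TRAIN_TOKENS // tokens_per_step              # optimizer steps
grad_accum_steps = GLOBAL_BATCH_SIZE // MICRO_BATCH_SIZE   # micro-batches per step on one GPU
val_perplexity = math.exp(1.895)                           # perplexity from val loss

print(tokens_per_step, total_steps, grad_accum_steps, round(val_perplexity, 2))
# 10240 48828 8 6.65
```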
## Why k=8 Instead of the Paper's Randomized k=2-16?

The paper's approach randomizes k across GPU ranks at each step. With 4 GPUs, a single step might train, for example, k=2, 5, 12, and 16 side by side, so the model learns to handle a range of prediction horizons.

With a single GPU, each step trains only one k value. We chose k=8 as a middle ground: large enough to demonstrate a meaningful multi-token speedup, and small enough to fit in 32GB of VRAM.

This is an important tradeoff: our model is specialized for k=8, while the paper's model generalizes across the whole 2-16 range. A production deployment would likely benefit from the paper's multi-GPU randomized approach.
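
The difference between the two setups can be sketched as a per-rank sampling rule. `sample_k` is a hypothetical helper, not code from the mtp-lm repository. It draws a horizon from 2-16 with a seed derived from the step and rank, so each rank trains a different k on a given step. The single-GPU configuration simply returns the fixed value 8. The uniform distribution over 2-16 is an assumption for illustration.

```python
import random
from typing import Optional

K_MIN, K_MAX = 2, 16  # horizon range used by the paper's setup


def sample_k(step: int, rank: int, fixed_k: Optional[int] = None) -> int:
    """Hypothetical per-rank choice of the prediction horizon k.

    With fixed_k set (our single-GPU run uses 8), every step trains that k.
    Otherwise each (step, rank) pair gets its own reproducible draw from
    [K_MIN, K_MAX] (uniform here, by assumption), so different ranks cover
    different horizons within the same step.
    """
    if fixed_k is not None:
        return fixed_k
    rng = random.Random(step * 1_000_003 + rank)  # deterministic per (step, rank)
    return rng.randint(K_MIN, K_MAX)  # inclusive on both ends


if __name__ == "__main__":
    # Four ranks (as in the paper's 4x GH200 setup) vs. one GPU with a fixed k.
    print([sample_k(step=0, rank=r) for r in range(4)])
    print([sample_k(step=0, rank=0, fixed_k=8)])
```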
## Infrastructure: Running on Consumer Hardware

This reproduction ran entirely on a home Kubernetes cluster:

- **GPU:** NVIDIA RTX 5090 (32GB, Blackwell architecture, sm_120)
- **System:** 16GB RAM, Debian 13
- **Stack:** Kubernetes + containerd + NVIDIA device plugin
- **PyTorch:** Nightly build with CUDA 12.8 (required for Blackwell sm_120 support)
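
Before launching a long run on a Blackwell card, it is worth confirming that the installed PyTorch build actually ships kernels for the GPU. Below is a minimal sketch using PyTorch's public `torch.cuda` helpers (`get_device_capability`, `get_arch_list`). `has_native_kernels` is our own hypothetical helper. The check is deliberately strict: it looks only for native `sm_XY` entries and ignores any PTX JIT fallback.

```python
import importlib.util
from typing import Sequence, Tuple


def has_native_kernels(capability: Tuple[int, int], arch_list: Sequence[str]) -> bool:
    """True if the build's arch list contains native code for this compute capability.

    capability is (major, minor), e.g. (12, 0) for an RTX 5090 (sm_120).
    """
    major, minor = capability
    return f"sm_{major}{minor}" in arch_list


def report() -> None:
    import torch

    print("torch", torch.__version__, "built with CUDA", torch.version.cuda)
    if not torch.cuda.is_available():
        print("CUDA is not available")
        return
    cap = torch.cuda.get_device_capability(0)
    archs = torch.cuda.get_arch_list()
    print(torch.cuda.get_device_name(0), cap, "native kernels:", has_native_kernels(cap, archs))


if __name__ == "__main__":
    if importlib.util.find_spec("torch") is None:
        print("PyTorch is not installed")
    else:
        report()
```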
### Challenges We Solved

1. **Blackwell GPU support:** The RTX 5090 (sm_120) needs PyTorch wheels built against CUDA 12.8 (cu128). We used a nightly cu128 build. Stable cu128 wheels with sm_120 support are also available from PyTorch 2.7 onward.
2. **Single-GPU checkpoint saving:** The original code calls `torch.distributed.all_reduce()` to sync checkpoint state, which raises an error when no process group has been initialized, as in a single-process run. We added a `torch.distributed.is_initialized()` guard.
3. **W&B configuration:** The default config points to the paper authors' W&B organization. Override it with `--wandb.entity=null`.
4. **Hugging Face checkpoint format:** The LitGPT converter outputs `model.pth`, but `transformers` expects `pytorch_model.bin` (or `model.safetensors`).
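
The single-GPU checkpoint fix amounts to skipping the collective when no process group exists. Below is a minimal sketch of that guard pattern. `all_reduce_if_distributed` is a hypothetical helper name, not the function we patched in the original code. The `dist.is_available()` check also covers PyTorch builds compiled without distributed support.

```python
import torch
import torch.distributed as dist


def all_reduce_if_distributed(tensor: torch.Tensor, op=None) -> torch.Tensor:
    """All-reduce `tensor` in place when a process group exists; otherwise a no-op.

    In a plain single-process run, init_process_group() is never called, so
    calling dist.all_reduce directly would raise. `op` defaults to SUM.
    """
    if dist.is_available() and dist.is_initialized():
        dist.all_reduce(tensor, op=dist.ReduceOp.SUM if op is None else op)
    return tensor


if __name__ == "__main__":
    # Single process, no process group: the value passes through unchanged.
    state = torch.tensor([1.0])
    print(all_reduce_if_distributed(state).item())
```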
## Usage

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

model = AutoModelForCausalLM.from_pretrained(
    "celestialcreator/Llama-3.2-1B-MTP-k8",
    trust_remote_code=True,
    torch_dtype="float16",
)
tokenizer = AutoTokenizer.from_pretrained("celestialcreator/Llama-3.2-1B-MTP-k8")

# Standard autoregressive generation (one token per step; works like any Llama model)
inputs = tokenizer("The capital of France is", return_tensors="pt")
outputs = model.generate(**inputs, max_new_tokens=50)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```

For MTP inference with ConfAdapt decoding, use the [mtp-lm evaluation harness fork](https://github.com/jwkirchenbauer/lm-evaluation-harness-mtp-lm).
## Reproduction Guide

Full reproduction instructions with Kubernetes manifests and configs: [GitHub Fork](https://github.com/CelestialCreator/mtp-lm)
## Citation

If you use this model, please cite the original paper:

```bibtex
@article{kirchenbauer2026multitokenpredictionselfdistillation,
  title={Multi-Token Prediction via Self-Distillation},
  author={John Kirchenbauer and Jonas Geiping and Yuxin Wen and Tom Goldstein},
  journal={arXiv preprint arXiv:2602.06019},
  year={2026}
}
```
## Acknowledgments

- Original paper and code by [John Kirchenbauer et al.](https://github.com/jwkirchenbauer/mtp-lm)
- Built with [LitGPT](https://github.com/Lightning-AI/litgpt), [PyTorch](https://pytorch.org/), and [lm-evaluation-harness](https://github.com/EleutherAI/lm-evaluation-harness)