---
license: llama3.2
base_model: meta-llama/Llama-3.2-1B
tags:
- multi-token-prediction
- speculative-decoding
- self-distillation
- mtp
- llama
- consumer-gpu
- rtx-5090
- paper-reproduction
datasets:
- jwkirchenbauer/metamathqa-grouped-split
language:
- en
pipeline_tag: text-generation
library_name: transformers
model-index:
- name: Llama-3.2-1B-MTP-k8
  results:
  - task:
      type: text-generation
      name: GSM8K (8-shot CoT)
    dataset:
      type: gsm8k
      name: GSM8K
    metrics:
    - name: exact_match (flexible)
      type: exact_match
      value: 5.08
    - name: exact_match (strict)
      type: exact_match
      value: 3.03
---

# Llama-3.2-1B-MTP-k8: Multi-Token Prediction on a Single Consumer GPU

This is a reproduction of **"Multi-Token Prediction via Self-Distillation"** ([arXiv 2602.06019](https://arxiv.org/abs/2602.06019)) adapted for a single NVIDIA RTX 5090 (32GB). The original paper used 4x NVIDIA GH200 (384GB total) with Llama-3.1-8B. We scaled it down to Llama-3.2-1B on consumer hardware.

## What is Multi-Token Prediction (MTP)?

Standard language models predict **one token at a time** (autoregressive decoding). MTP trains the model to predict **multiple future tokens simultaneously** using online self-distillation:

1. A **frozen teacher** (the original model) generates soft probability distributions
2. A **trainable student** (same architecture) learns to predict k future tokens at each position
3. At inference, **ConfAdapt decoding** emits multiple tokens when the model is confident, falling back to single-token when uncertain

The result: **faster inference with minimal quality loss**.
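
The confidence-gated emission in step 3 can be illustrated with a small toy. This is only a sketch of the general idea, not the paper's ConfAdapt implementation: the function names, the rule "always emit the first token, then keep emitting while the top-1 probability stays at or above the threshold", and the `toy_propose` stand-in for the model are all assumptions made for illustration.

```python
from typing import Callable, List, Sequence, Tuple

# A proposer maps the current token prefix to (k proposed tokens, k confidences).
Proposer = Callable[[List[int]], Tuple[List[int], List[float]]]


def accept_confident_prefix(confidences: Sequence[float], threshold: float = 0.9) -> int:
    """Return how many of the proposed tokens to emit in this step.

    Assumed rule (hypothetical): the first token is always emitted, which is
    the single-token fallback; each later token is emitted only while its
    confidence stays at or above the threshold.
    """
    if not confidences:
        raise ValueError("need at least one proposed token")
    accepted = 1
    for conf in confidences[1:]:
        if conf < threshold:
            break
        accepted += 1
    return accepted


def confadapt_generate(
    propose: Proposer, prompt: List[int], max_new_tokens: int, threshold: float = 0.9
) -> Tuple[List[int], int]:
    """Generate tokens, emitting a variable number per forward step."""
    out = list(prompt)
    steps = 0
    while len(out) - len(prompt) < max_new_tokens:
        tokens, confidences = propose(out)
        n = accept_confident_prefix(confidences, threshold)
        # Never overshoot the requested length.
        n = min(n, max_new_tokens - (len(out) - len(prompt)))
        out.extend(tokens[:n])
        steps += 1
    return out[len(prompt):], steps


def toy_propose(prefix: List[int], k: int = 4) -> Tuple[List[int], List[float]]:
    """Hypothetical stand-in for the model: counts upward from the last token
    and is unsure about every multiple of 4."""
    last = prefix[-1]
    tokens = [last + i for i in range(1, k + 1)]
    confidences = [0.5 if t % 4 == 0 else 0.99 for t in tokens]
    return tokens, confidences


if __name__ == "__main__":
    generated, steps = confadapt_generate(toy_propose, prompt=[0], max_new_tokens=8)
    print(generated, steps, len(generated) / steps)
```

The ratio of generated tokens to steps in this toy plays the same role as the "tokens emitted per step" figure reported in the results: the higher it is, the fewer forward passes are needed for the same output length.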
## Results: GSM8K 8-shot Chain-of-Thought

| Configuration | Exact Match (flexible) | Exact Match (strict) | Time per sample |
|---|---|---|---|
| **Baseline** (Llama-3.2-1B, standard AR) | **7.13%** ± 0.71 | **6.07%** ± 0.66 | ~1.5 s/sample |
| **MTP k=1** (single token, quality check) | 5.23% ± 0.61 | 2.96% ± 0.47 | ~2.4 s/sample |
| **MTP k=8 + ConfAdapt 90%** | 5.08% ± 0.60 | 3.03% ± 0.47 | **~1.3 s/sample** |

### Key Findings

- **ConfAdapt works:** k=8 with ConfAdapt matches k=1 quality (within the reported error) while being **1.8x faster** than k=1 (~1.3 vs ~2.4 s/sample, avg 2.82 tokens emitted per step). Against the unmodified baseline (~1.5 s/sample) the speedup is smaller, about 1.15x
- **Quality drop is expected:** The drop from baseline, about 2 percentage points on flexible match and 3 on strict match, is consistent with our smaller setup (1B model, 500M training tokens vs paper's 8B model, 2B tokens)
- **The core claim holds:** Multi-token decoding via ConfAdapt preserves the quality of the MTP-trained model's single-token decoding while reducing time per sample, even on a tiny 1B model

## Training Details

### What We Changed from the Paper

| Parameter | Paper (8B / 4x GH200) | Ours (1B / 1x RTX 5090) |
|---|---|---|
| Base model | Llama-3.1-8B | Llama-3.2-1B |
| GPUs | 4x GH200 (96GB each) | 1x RTX 5090 (32GB) |
| FSDP mesh | 1x4 | 1x1 (no FSDP) |
| k_toks | Randomized 2-16 across ranks | Fixed 8 |
| Training tokens | 2B | 500M |
| micro_batch_size | 32 | 8 |
| global_batch_size | 128 | 64 (grad accumulation) |
| mask_region_ct | 5 | 1 |
| rollout_multiplier | 4 | 2 |
| Template | Chat (Instruct tokenizer) | Plain text (base tokenizer) |

### What We Kept the Same

- **Supervision method:** Soft teacher via KL divergence (paper's recommended self-distillation)
- **Dataset:** MetaMathQA (`jwkirchenbauer/metamathqa-grouped-split`)
- **Sequence length:** 160 tokens
- **Peak learning rate:** 1e-5
- **Optimizer:** AdamW with cosine decay

### Training Metrics

- **Total steps:** 48,828 (at 64 sequences of 160 tokens per step, this is about 500M tokens)
- **Training time:** ~17 hours on RTX 5090
- **Final train loss:** ~0.9
- **Final val loss:** 1.895 (perplexity 6.65)

## Why k=8 Instead of the Paper's Randomized k=2-16?

The paper's approach randomizes k across GPU ranks each step. With 4 GPUs, the model might see, for example, k=2, k=5, k=12, and k=16 simultaneously in a single batch, learning to handle any prediction horizon. With a single GPU, we can only train one k value per step.

We chose k=8 as a middle ground: large enough to demonstrate meaningful multi-token speedup, small enough to fit in 32GB VRAM.

This is an important tradeoff: our model is specialized for k=8, while the paper's model generalizes across all k values. A production deployment would benefit from the paper's multi-GPU randomized approach.

## Infrastructure: Running on Consumer Hardware

This reproduction ran entirely on a home Kubernetes cluster:

- **GPU:** NVIDIA RTX 5090 (32GB, Blackwell architecture / sm_120)
- **System:** 16GB RAM, Debian 13
- **Stack:** Kubernetes + containerd + NVIDIA device plugin
- **PyTorch:** Nightly build with CUDA 12.8 (required for Blackwell sm_120 support)

### Challenges We Solved

1. **Blackwell GPU support:** RTX 5090 (sm_120) required a PyTorch nightly build with cu128; stable releases did not include sm_120 at the time of this run
2. **Single-GPU checkpoint saving:** The original code uses `torch.distributed.all_reduce()` for checkpoint state sync, which crashes when the distributed process group is not initialized. We added a `torch.distributed.is_initialized()` guard so the call is skipped on a single GPU
3. **W&B configuration:** The default config points to the paper authors' organization. Override it with `--wandb.entity=null`
4. **HuggingFace checkpoint format:** The litgpt converter outputs `model.pth`, but transformers expects `pytorch_model.bin`

## Usage

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

model = AutoModelForCausalLM.from_pretrained(
    "celestialcreator/Llama-3.2-1B-MTP-k8",
    trust_remote_code=True,
    torch_dtype="float16",
)
tokenizer = AutoTokenizer.from_pretrained("celestialcreator/Llama-3.2-1B-MTP-k8")

# Standard generation (single token, works like any Llama model)
# Move the inputs to whichever device the model was loaded on.
inputs = tokenizer("The capital of France is", return_tensors="pt").to(model.device)
outputs = model.generate(**inputs, max_new_tokens=50)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```

For MTP inference with ConfAdapt decoding, use the [mtp-lm evaluation harness fork](https://github.com/jwkirchenbauer/lm-evaluation-harness-mtp-lm).

## Reproduction Guide

Full reproduction instructions with Kubernetes manifests and configs: [GitHub Fork](https://github.com/CelestialCreator/mtp-lm)

## Citation

If you use this model, please cite the original paper:

```bibtex
@article{kirchenbauer2025multitokenpredictionselfdistillation,
  title={Multi-Token Prediction via Self-Distillation},
  author={John Kirchenbauer and Jonas Geiping and Yuxin Wen and Tom Goldstein},
  journal={arXiv preprint arXiv:2602.06019},
  year={2026}
}
```

## Acknowledgments

- Original paper and code by [John Kirchenbauer et al.](https://github.com/jwkirchenbauer/mtp-lm)
- Built with [LitGPT](https://github.com/Lightning-AI/litgpt), [PyTorch](https://pytorch.org/), and [lm-evaluation-harness](https://github.com/EleutherAI/lm-evaluation-harness)