diff --git a/docs/backend/speculative_decoding.ipynb b/docs/backend/speculative_decoding.ipynb index e9f007ca2..e46453b46 100644 --- a/docs/backend/speculative_decoding.ipynb +++ b/docs/backend/speculative_decoding.ipynb @@ -150,6 +150,68 @@ "terminate_process(server_process)" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### EAGLE Decoding via Frequency-Ranked Speculative Sampling\n", + "\n", + "By employing a truncated high-frequency token vocabulary in the draft model, Eagle speculative decoding reduces `lm_head` computational overhead while accelerating the pipeline without quality degradation. For more details, checkout [the paper](https://arxiv.org/pdf/arXiv:2502.14856).\n", + "\n", + "In our implementation, set `--speculative-token-map` to enable the optimization. You can get the high-frequency token in FR-Spec from [this model](https://huggingface.co/thunlp/LLaMA3-Instruct-8B-FR-Spec). Or you can obtain high-frequency token by directly downloading these token from [this repo](https://github.com/thunlp/FR-Spec/tree/main?tab=readme-ov-file#prepare-fr-spec-vocabulary-subset).\n", + "\n", + "Thanks for the contribution from [Weilin Zhao](https://github.com/https://github.com/Achazwl) and [Zhousx](https://github.com/Zhou-sx). " + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "server_process, port = launch_server_cmd(\n", + " \"\"\"\n", + "python3 -m sglang.launch_server --model meta-llama/Meta-Llama-3-8B-Instruct --speculative-algorithm EAGLE \\\n", + " --speculative-draft-model-path lmsys/sglang-EAGLE-LLaMA3-Instruct-8B --speculative-num-steps 5 \\\n", + " --speculative-eagle-topk 8 --speculative-num-draft-tokens 64 --speculative-token-map thunlp/LLaMA3-Instruct-8B-FR-Spec/freq_32768.pt \\\n", + " --mem-fraction 0.7 --cuda-graph-max-bs 2 --dtype float16 \n", + "\"\"\"\n", + ")\n", + "\n", + "wait_for_server(f\"http://localhost:{port}\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import openai\n", + "\n", + "client = openai.Client(base_url=f\"http://127.0.0.1:{port}/v1\", api_key=\"None\")\n", + "\n", + "response = client.chat.completions.create(\n", + " model=\"meta-llama/Meta-Llama-3-8B-Instruct\",\n", + " messages=[\n", + " {\"role\": \"user\", \"content\": \"List 3 countries and their capitals.\"},\n", + " ],\n", + " temperature=0,\n", + " max_tokens=64,\n", + ")\n", + "\n", + "print_highlight(f\"Response: {response}\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "terminate_process(server_process)" + ] + }, { "cell_type": "markdown", "metadata": {},