{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Speculative Decoding\n",
"\n",
"SGLang now provides an EAGLE-based speculative decoding option. The implementation aims to maximize speed and efficiency, and it is considered among the fastest in open-source LLM engines.\n",
"\n",
"To run the following tests or benchmarks, you also need to install [**cutex**](https://pypi.org/project/cutex/):\n",
"> ```bash\n",
"> pip install cutex\n",
"> ```\n",
"\n",
"### Performance Highlights\n",
"\n",
"- Official EAGLE code ([SafeAILab/EAGLE](https://github.com/SafeAILab/EAGLE)): ~200 tokens/s\n",
"- Standard SGLang Decoding: ~156 tokens/s\n",
"- EAGLE Decoding in SGLang: ~297 tokens/s\n",
"- EAGLE Decoding in SGLang (w/ `torch.compile`): ~316 tokens/s\n",
"\n",
"All benchmarks below were run on a single H100."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## EAGLE Decoding\n",
"\n",
"To enable EAGLE-based speculative decoding, specify the draft model (`--speculative-draft`) and the relevant EAGLE parameters:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Launch a server with EAGLE speculative decoding enabled\n",
"from sglang.utils import (\n",
"    execute_shell_command,\n",
"    wait_for_server,\n",
"    terminate_process,\n",
"    print_highlight,\n",
")\n",
"\n",
"server_process = execute_shell_command(\n",
"    \"\"\"\n",
"python3 -m sglang.launch_server --model meta-llama/Llama-2-7b-chat-hf --speculative-algo EAGLE \\\n",
"    --speculative-draft lmzheng/sglang-EAGLE-llama2-chat-7B --speculative-num-steps 5 \\\n",
"    --speculative-eagle-topk 8 --speculative-num-draft-tokens 64 --mem-fraction 0.7 --port=30020\n",
"\"\"\"\n",
")\n",
"\n",
"wait_for_server(\"http://localhost:30020\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import openai\n",
"\n",
"client = openai.Client(base_url=\"http://127.0.0.1:30020/v1\", api_key=\"None\")\n",
"\n",
"response = client.chat.completions.create(\n",
"    model=\"meta-llama/Llama-2-7b-chat-hf\",\n",
"    messages=[\n",
"        {\"role\": \"user\", \"content\": \"List 3 countries and their capitals.\"},\n",
"    ],\n",
"    temperature=0,\n",
"    max_tokens=64,\n",
")\n",
"\n",
"print_highlight(f\"Response: {response}\")"
]
},
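{
"cell_type": "markdown",
"metadata": {},
"source": [
"Optionally, you can stream the response instead. The cell below is a minimal sketch that reuses the same client and assumes the server's OpenAI-compatible `/v1/chat/completions` endpoint accepts `stream=True`; the prompt and model are the same as in the cell above."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Minimal streaming sketch (assumes the OpenAI-compatible endpoint above\n",
"# supports stream=True); collects the streamed deltas and prints them.\n",
"stream = client.chat.completions.create(\n",
"    model=\"meta-llama/Llama-2-7b-chat-hf\",\n",
"    messages=[\n",
"        {\"role\": \"user\", \"content\": \"List 3 countries and their capitals.\"},\n",
"    ],\n",
"    temperature=0,\n",
"    max_tokens=64,\n",
"    stream=True,\n",
")\n",
"\n",
"chunks = []\n",
"for chunk in stream:\n",
"    if chunk.choices and chunk.choices[0].delta.content:\n",
"        chunks.append(chunk.choices[0].delta.content)\n",
"\n",
"print_highlight(\"Streamed response: \" + \"\".join(chunks))"
]
},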
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"terminate_process(server_process)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### EAGLE Decoding with `torch.compile`\n",
"\n",
"You can also enable `torch.compile` for further optimizations and optionally set `--cuda-graph-max-bs`:\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"server_process = execute_shell_command(\n",
"    \"\"\"\n",
"python3 -m sglang.launch_server --model meta-llama/Llama-2-7b-chat-hf --speculative-algo EAGLE \\\n",
"    --speculative-draft lmzheng/sglang-EAGLE-llama2-chat-7B --speculative-num-steps 5 \\\n",
"    --speculative-eagle-topk 8 --speculative-num-draft-tokens 64 --mem-fraction 0.7 \\\n",
"    --enable-torch-compile --cuda-graph-max-bs 2 --port=30020\n",
"\"\"\"\n",
")\n",
"\n",
"wait_for_server(\"http://localhost:30020\")"
]
},
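{
"cell_type": "markdown",
"metadata": {},
"source": [
"As an optional sanity check before benchmarking, the cell below sends one short request to the freshly launched server. It simply reuses the OpenAI-compatible endpoint shown earlier with the same model that was launched above."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Optional sanity check against the torch.compile-enabled server,\n",
"# reusing the OpenAI-compatible endpoint demonstrated earlier.\n",
"import openai\n",
"\n",
"client = openai.Client(base_url=\"http://127.0.0.1:30020/v1\", api_key=\"None\")\n",
"\n",
"response = client.chat.completions.create(\n",
"    model=\"meta-llama/Llama-2-7b-chat-hf\",\n",
"    messages=[{\"role\": \"user\", \"content\": \"Say hello in one short sentence.\"}],\n",
"    temperature=0,\n",
"    max_tokens=32,\n",
")\n",
"\n",
"print_highlight(f\"Sanity check: {response.choices[0].message.content}\")"
]
},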
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Benchmark Script\n",
"\n",
"The following code example shows how to measure the decoding speed when generating tokens:\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import time\n",
"import requests\n",
"\n",
"# Time a single /generate request and report the decoding speed in tokens/s\n",
"tic = time.time()\n",
"response = requests.post(\n",
"    \"http://localhost:30020/generate\",\n",
"    json={\n",
"        \"text\": \"[INST] Give me a simple FastAPI server. Show the python code. [/INST]\",\n",
"        \"sampling_params\": {\n",
"            \"temperature\": 0,\n",
"            \"max_new_tokens\": 256,\n",
"        },\n",
"    },\n",
")\n",
"latency = time.time() - tic\n",
"ret = response.json()\n",
"completion_text = ret[\"text\"]\n",
"speed = ret[\"meta_info\"][\"completion_tokens\"] / latency\n",
"\n",
"print_highlight(completion_text)\n",
"print_highlight(f\"speed: {speed:.2f} token/s\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"terminate_process(server_process)"
]
}
],
"metadata": {
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3"
}
},
"nbformat": 4,
"nbformat_minor": 2
}