[feat] add small vocab table for eagle's draft model[1]. (#3822)
Co-authored-by: Achazwl <323163497@qq.com> Co-authored-by: Chayenne <zhaochen20@outlook.com>
This commit is contained in:
@@ -146,6 +146,7 @@ Please consult the documentation below to learn more about the parameters you ma
|
||||
* `speculative_num_steps`: How many draft passes we run before verifying.
|
||||
* `speculative_num_draft_tokens`: The number of tokens proposed in a draft.
|
||||
* `speculative_eagle_topk`: The number of top candidates we keep for verification at each step for [Eagle](https://arxiv.org/html/2406.16858v1).
|
||||
* `speculative_token_map`: Optional, the path to the high frequency token list of [FR-Spec](https://arxiv.org/html/2502.14856v1), used for accelerating [Eagle](https://arxiv.org/html/2406.16858v1).
|
||||
|
||||
|
||||
## Double Sparsity
|
||||
|
||||
@@ -26,7 +26,7 @@
|
||||
"source": [
|
||||
"## EAGLE Decoding\n",
|
||||
"\n",
|
||||
"To enable EAGLE-based speculative decoding, specify the draft model (`--speculative-draft`) and the relevant EAGLE parameters:"
|
||||
"To enable EAGLE-based speculative decoding, specify the draft model (`--speculative-draft-model-path`) and the relevant EAGLE parameters:"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -46,8 +46,8 @@
|
||||
"\n",
|
||||
"server_process, port = launch_server_cmd(\n",
|
||||
" \"\"\"\n",
|
||||
"python3 -m sglang.launch_server --model meta-llama/Llama-2-7b-chat-hf --speculative-algo EAGLE \\\n",
|
||||
" --speculative-draft lmzheng/sglang-EAGLE-llama2-chat-7B --speculative-num-steps 5 \\\n",
|
||||
"python3 -m sglang.launch_server --model meta-llama/Llama-2-7b-chat-hf --speculative-algorithm EAGLE \\\n",
|
||||
" --speculative-draft-model-path lmzheng/sglang-EAGLE-llama2-chat-7B --speculative-num-steps 5 \\\n",
|
||||
" --speculative-eagle-topk 8 --speculative-num-draft-tokens 64\n",
|
||||
"\"\"\"\n",
|
||||
")\n",
|
||||
@@ -103,8 +103,8 @@
|
||||
"source": [
|
||||
"server_process, port = launch_server_cmd(\n",
|
||||
" \"\"\"\n",
|
||||
"python3 -m sglang.launch_server --model meta-llama/Llama-2-7b-chat-hf --speculative-algo EAGLE \\\n",
|
||||
" --speculative-draft lmzheng/sglang-EAGLE-llama2-chat-7B --speculative-num-steps 5 \\\n",
|
||||
"python3 -m sglang.launch_server --model meta-llama/Llama-2-7b-chat-hf --speculative-algorithm EAGLE \\\n",
|
||||
" --speculative-draft-model-path lmzheng/sglang-EAGLE-llama2-chat-7B --speculative-num-steps 5 \\\n",
|
||||
" --speculative-eagle-topk 8 --speculative-num-draft-tokens 64 --mem-fraction 0.6 \\\n",
|
||||
" --enable-torch-compile --cuda-graph-max-bs 2\n",
|
||||
"\"\"\"\n",
|
||||
@@ -135,6 +135,77 @@
|
||||
"print_highlight(f\"Response: {response}\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"terminate_process(server_process)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### EAGLE Decoding via Frequency-Ranked Speculative Sampling\n",
|
||||
"\n",
|
||||
"By employing a truncated high-frequency token vocabulary in the draft model, Eagle speculative decoding reduces `lm_head` computational overhead while accelerating the pipeline without quality degradation. For more details, checkout [the paper](https://arxiv.org/pdf/arXiv:2502.14856).\n",
|
||||
"\n",
|
||||
"In our implementation, set `--speculative-token-map` to enable the optimization. You can get the high-frequency token in FR-Spec from [this model](https://huggingface.co/thunlp/LLaMA3-Instruct-8B-FR-Spec). Or you can obtain high-frequency token by directly downloading these token from [this repo](https://github.com/thunlp/FR-Spec/tree/main?tab=readme-ov-file#prepare-fr-spec-vocabulary-subset).\n",
|
||||
"\n",
|
||||
"Thanks for the contribution from [Weilin Zhao](https://github.com/https://github.com/Achazwl) and [Zhousx](https://github.com/Zhou-sx). "
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from sglang.test.test_utils import is_in_ci\n",
|
||||
"\n",
|
||||
"if is_in_ci():\n",
|
||||
" from patch import launch_server_cmd\n",
|
||||
"else:\n",
|
||||
" from sglang.utils import launch_server_cmd\n",
|
||||
"\n",
|
||||
"from sglang.utils import wait_for_server, print_highlight, terminate_process\n",
|
||||
"\n",
|
||||
"server_process, port = launch_server_cmd(\n",
|
||||
" \"\"\"\n",
|
||||
"python3 -m sglang.launch_server --model meta-llama/Meta-Llama-3-8B-Instruct --speculative-algorithm EAGLE \\\n",
|
||||
" --speculative-draft-model-path lmzheng/sglang-EAGLE-LLaMA3-Instruct-8B --speculative-num-steps 5 \\\n",
|
||||
" --speculative-eagle-topk 8 --speculative-num-draft-tokens 64 --speculative-token-map thunlp/LLaMA3-Instruct-8B-FR-Spec/freq_32768.pt \\\n",
|
||||
" --mem-fraction 0.7 --cuda-graph-max-bs 2 --dtype float16 \n",
|
||||
"\"\"\"\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"wait_for_server(f\"http://localhost:{port}\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import openai\n",
|
||||
"\n",
|
||||
"client = openai.Client(base_url=f\"http://127.0.0.1:{port}/v1\", api_key=\"None\")\n",
|
||||
"\n",
|
||||
"response = client.chat.completions.create(\n",
|
||||
" model=\"meta-llama/Meta-Llama-3-8B-Instruct\",\n",
|
||||
" messages=[\n",
|
||||
" {\"role\": \"user\", \"content\": \"List 3 countries and their capitals.\"},\n",
|
||||
" ],\n",
|
||||
" temperature=0,\n",
|
||||
" max_tokens=64,\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"print_highlight(f\"Response: {response}\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
|
||||
@@ -117,9 +117,14 @@ class LlamaForCausalLMEagle(LlamaForCausalLM):
|
||||
if self.config.tie_word_embeddings:
|
||||
self.lm_head = self.model.embed_tokens
|
||||
else:
|
||||
self.lm_head = ParallelLMHead(
|
||||
config.vocab_size, config.hidden_size, quant_config=quant_config
|
||||
)
|
||||
if hasattr(config, "hot_vocab_size"):
|
||||
self.lm_head = ParallelLMHead(
|
||||
config.hot_vocab_size, config.hidden_size, quant_config=quant_config
|
||||
)
|
||||
else:
|
||||
self.lm_head = ParallelLMHead(
|
||||
config.vocab_size, config.hidden_size, quant_config=quant_config
|
||||
)
|
||||
self.logits_processor = LogitsProcessor(config)
|
||||
|
||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||
|
||||
@@ -128,6 +128,7 @@ class ServerArgs:
|
||||
speculative_num_steps: int = 5
|
||||
speculative_eagle_topk: int = 8
|
||||
speculative_num_draft_tokens: int = 64
|
||||
speculative_token_map: Optional[str] = None
|
||||
|
||||
# Double Sparsity
|
||||
enable_double_sparsity: bool = False
|
||||
@@ -751,6 +752,12 @@ class ServerArgs:
|
||||
help="The number of token sampled from draft model in Speculative Decoding.",
|
||||
default=ServerArgs.speculative_num_draft_tokens,
|
||||
)
|
||||
parser.add_argument(
|
||||
"--speculative-token-map",
|
||||
type=str,
|
||||
help="The path of the draft model's small vocab table.",
|
||||
default=ServerArgs.speculative_token_map,
|
||||
)
|
||||
|
||||
# Double Sparsity
|
||||
parser.add_argument(
|
||||
|
||||
@@ -1,8 +1,10 @@
|
||||
import logging
|
||||
import os
|
||||
import time
|
||||
from typing import List, Optional, Union
|
||||
|
||||
import torch
|
||||
from huggingface_hub import snapshot_download
|
||||
|
||||
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
|
||||
from sglang.srt.managers.schedule_batch import Req, ScheduleBatch
|
||||
@@ -44,6 +46,23 @@ class EAGLEWorker(TpModelWorker):
|
||||
# We will capture it later
|
||||
backup_disable_cuda_graph = server_args.disable_cuda_graph
|
||||
server_args.disable_cuda_graph = True
|
||||
|
||||
if server_args.speculative_token_map is not None:
|
||||
if os.path.exists(server_args.speculative_token_map):
|
||||
self.hot_token_id = torch.load(server_args.speculative_token_map)
|
||||
else:
|
||||
cache_dir = snapshot_download(
|
||||
os.path.dirname(server_args.speculative_token_map),
|
||||
ignore_patterns=["*.bin", "*.safetensors"],
|
||||
)
|
||||
file_path = os.path.join(
|
||||
cache_dir, os.path.basename(server_args.speculative_token_map)
|
||||
)
|
||||
self.hot_token_id = torch.load(file_path)
|
||||
server_args.json_model_override_args = (
|
||||
f'{{"hot_vocab_size": {len(self.hot_token_id)}}}'
|
||||
)
|
||||
|
||||
super().__init__(
|
||||
gpu_id=gpu_id,
|
||||
tp_rank=tp_rank,
|
||||
@@ -66,7 +85,21 @@ class EAGLEWorker(TpModelWorker):
|
||||
# Share the embedding and lm_head
|
||||
if not self.speculative_algorithm.is_nextn():
|
||||
embed, head = self.target_worker.model_runner.model.get_embed_and_head()
|
||||
if server_args.speculative_token_map is not None:
|
||||
head = head.clone()
|
||||
self.hot_token_id = torch.tensor(
|
||||
self.hot_token_id, dtype=torch.int32, device=head.device
|
||||
)
|
||||
head.data = head.data[self.hot_token_id]
|
||||
else:
|
||||
self.hot_token_id = None
|
||||
self.model_runner.model.set_embed_and_head(embed, head)
|
||||
else:
|
||||
if server_args.speculative_token_map is not None:
|
||||
raise NotImplementedError(
|
||||
"NEXTN does not support speculative-token-map now"
|
||||
)
|
||||
self.hot_token_id = None
|
||||
self.model_runner.server_args.disable_cuda_graph = backup_disable_cuda_graph
|
||||
|
||||
# Create multi-step attn backends and cuda graph runners
|
||||
@@ -223,6 +256,8 @@ class EAGLEWorker(TpModelWorker):
|
||||
spec_info.topk_index,
|
||||
spec_info.hidden_states,
|
||||
)
|
||||
if self.hot_token_id is not None:
|
||||
topk_index = self.hot_token_id[topk_index]
|
||||
|
||||
# Return values
|
||||
score_list: List[torch.Tensor] = []
|
||||
@@ -262,6 +297,8 @@ class EAGLEWorker(TpModelWorker):
|
||||
)
|
||||
probs = torch.softmax(logits_output.next_token_logits, dim=-1)
|
||||
topk_p, topk_index = fast_topk(probs, self.topk, dim=-1)
|
||||
if self.hot_token_id is not None:
|
||||
topk_index = self.hot_token_id[topk_index]
|
||||
hidden_states = logits_output.hidden_states
|
||||
|
||||
return score_list, token_list, parents_list
|
||||
|
||||
Reference in New Issue
Block a user