diff --git a/docs/backend/server_arguments.md b/docs/backend/server_arguments.md
index 7a614b61a..db8bb9514 100644
--- a/docs/backend/server_arguments.md
+++ b/docs/backend/server_arguments.md
@@ -146,6 +146,7 @@ Please consult the documentation below to learn more about the parameters you ma
 * `speculative_num_steps`: How many draft passes we run before verifying.
 * `speculative_num_draft_tokens`: The number of tokens proposed in a draft.
 * `speculative_eagle_topk`: The number of top candidates we keep for verification at each step for [Eagle](https://arxiv.org/html/2406.16858v1).
+* `speculative_token_map`: Optional. The path to the high-frequency token list of [FR-Spec](https://arxiv.org/html/2502.14856v1), used to accelerate [Eagle](https://arxiv.org/html/2406.16858v1).
 
 ## Double Sparsity
 
diff --git a/docs/backend/speculative_decoding.ipynb b/docs/backend/speculative_decoding.ipynb
index c7bd7fb31..2ca1c84fb 100644
--- a/docs/backend/speculative_decoding.ipynb
+++ b/docs/backend/speculative_decoding.ipynb
@@ -26,7 +26,7 @@
    "source": [
     "## EAGLE Decoding\n",
     "\n",
-    "To enable EAGLE-based speculative decoding, specify the draft model (`--speculative-draft`) and the relevant EAGLE parameters:"
+    "To enable EAGLE-based speculative decoding, specify the draft model (`--speculative-draft-model-path`) and the relevant EAGLE parameters:"
    ]
   },
   {
@@ -46,8 +46,8 @@
     "\n",
     "server_process, port = launch_server_cmd(\n",
     "    \"\"\"\n",
-    "python3 -m sglang.launch_server --model meta-llama/Llama-2-7b-chat-hf --speculative-algo EAGLE \\\n",
-    "    --speculative-draft lmzheng/sglang-EAGLE-llama2-chat-7B --speculative-num-steps 5 \\\n",
+    "python3 -m sglang.launch_server --model meta-llama/Llama-2-7b-chat-hf --speculative-algorithm EAGLE \\\n",
+    "    --speculative-draft-model-path lmzheng/sglang-EAGLE-llama2-chat-7B --speculative-num-steps 5 \\\n",
     "    --speculative-eagle-topk 8 --speculative-num-draft-tokens 64\n",
     "\"\"\"\n",
     ")\n",
@@ -103,8 +103,8 @@
    "source": [
     "server_process, port = launch_server_cmd(\n",
     "    \"\"\"\n",
-    "python3 -m sglang.launch_server --model meta-llama/Llama-2-7b-chat-hf --speculative-algo EAGLE \\\n",
-    "    --speculative-draft lmzheng/sglang-EAGLE-llama2-chat-7B --speculative-num-steps 5 \\\n",
+    "python3 -m sglang.launch_server --model meta-llama/Llama-2-7b-chat-hf --speculative-algorithm EAGLE \\\n",
+    "    --speculative-draft-model-path lmzheng/sglang-EAGLE-llama2-chat-7B --speculative-num-steps 5 \\\n",
     "    --speculative-eagle-topk 8 --speculative-num-draft-tokens 64 --mem-fraction 0.6 \\\n",
     "    --enable-torch-compile --cuda-graph-max-bs 2\n",
     "\"\"\"\n",
     ")\n",
@@ -135,6 +135,77 @@
     "print_highlight(f\"Response: {response}\")"
    ]
   },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "terminate_process(server_process)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "### EAGLE Decoding via Frequency-Ranked Speculative Sampling\n",
+    "\n",
+    "By restricting the draft model to a truncated vocabulary of high-frequency tokens, FR-Spec reduces the `lm_head` computation overhead and speeds up EAGLE speculative decoding without degrading output quality. For more details, check out [the paper](https://arxiv.org/pdf/2502.14856).\n",
+    "\n",
+    "In our implementation, set `--speculative-token-map` to enable this optimization. You can get a ready-made FR-Spec high-frequency token list from [this model](https://huggingface.co/thunlp/LLaMA3-Instruct-8B-FR-Spec), or download one directly from [this repo](https://github.com/thunlp/FR-Spec/tree/main?tab=readme-ov-file#prepare-fr-spec-vocabulary-subset).\n",
+    "\n",
+    "Thanks to [Weilin Zhao](https://github.com/Achazwl) and [Zhousx](https://github.com/Zhou-sx) for the contribution."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from sglang.test.test_utils import is_in_ci\n",
+    "\n",
+    "if is_in_ci():\n",
+    "    from patch import launch_server_cmd\n",
+    "else:\n",
+    "    from sglang.utils import launch_server_cmd\n",
+    "\n",
+    "from sglang.utils import wait_for_server, print_highlight, terminate_process\n",
+    "\n",
+    "server_process, port = launch_server_cmd(\n",
+    "    \"\"\"\n",
+    "python3 -m sglang.launch_server --model meta-llama/Meta-Llama-3-8B-Instruct --speculative-algorithm EAGLE \\\n",
+    "    --speculative-draft-model-path lmzheng/sglang-EAGLE-LLaMA3-Instruct-8B --speculative-num-steps 5 \\\n",
+    "    --speculative-eagle-topk 8 --speculative-num-draft-tokens 64 --speculative-token-map thunlp/LLaMA3-Instruct-8B-FR-Spec/freq_32768.pt \\\n",
+    "    --mem-fraction 0.7 --cuda-graph-max-bs 2 --dtype float16\n",
+    "\"\"\"\n",
+    ")\n",
+    "\n",
+    "wait_for_server(f\"http://localhost:{port}\")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import openai\n",
+    "\n",
+    "client = openai.Client(base_url=f\"http://127.0.0.1:{port}/v1\", api_key=\"None\")\n",
+    "\n",
+    "response = client.chat.completions.create(\n",
+    "    model=\"meta-llama/Meta-Llama-3-8B-Instruct\",\n",
+    "    messages=[\n",
+    "        {\"role\": \"user\", \"content\": \"List 3 countries and their capitals.\"},\n",
+    "    ],\n",
+    "    temperature=0,\n",
+    "    max_tokens=64,\n",
+    ")\n",
+    "\n",
+    "print_highlight(f\"Response: {response}\")"
+   ]
+  },
   {
    "cell_type": "code",
    "execution_count": null,
diff --git a/python/sglang/srt/models/llama_eagle.py b/python/sglang/srt/models/llama_eagle.py
index 09bfbb170..4f34e625e 100644
--- a/python/sglang/srt/models/llama_eagle.py
+++ b/python/sglang/srt/models/llama_eagle.py
@@ -117,9 +117,14 @@ class LlamaForCausalLMEagle(LlamaForCausalLM):
         if self.config.tie_word_embeddings:
             self.lm_head = self.model.embed_tokens
         else:
-            self.lm_head = ParallelLMHead(
-                config.vocab_size, config.hidden_size, quant_config=quant_config
-            )
+            if hasattr(config, "hot_vocab_size"):
+                self.lm_head = ParallelLMHead(
+                    config.hot_vocab_size, config.hidden_size, quant_config=quant_config
+                )
+            else:
+                self.lm_head = ParallelLMHead(
+                    config.vocab_size, config.hidden_size, quant_config=quant_config
+                )
 
         self.logits_processor = LogitsProcessor(config)
 
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py
index ddb10e390..296400fa6 100644
--- a/python/sglang/srt/server_args.py
+++ b/python/sglang/srt/server_args.py
@@ -128,6 +128,7 @@ class ServerArgs:
     speculative_num_steps: int = 5
     speculative_eagle_topk: int = 8
     speculative_num_draft_tokens: int = 64
+    speculative_token_map: Optional[str] = None
 
     # Double Sparsity
     enable_double_sparsity: bool = False
@@ -751,6 +752,12 @@
             help="The number of token sampled from draft model in Speculative Decoding.",
             default=ServerArgs.speculative_num_draft_tokens,
         )
+        parser.add_argument(
+            "--speculative-token-map",
+            type=str,
+            help="The path to the draft model's small vocabulary table (an FR-Spec high-frequency token list).",
+            default=ServerArgs.speculative_token_map,
+        )
 
         # Double Sparsity
         parser.add_argument(
diff --git a/python/sglang/srt/speculative/eagle_worker.py b/python/sglang/srt/speculative/eagle_worker.py
index eb8e839f9..7639bd999 100644
--- a/python/sglang/srt/speculative/eagle_worker.py
+++ b/python/sglang/srt/speculative/eagle_worker.py
@@ -1,8 +1,10 @@
 import logging
+import os
 import time
 from typing import List, Optional, Union
 
 import torch
+from huggingface_hub import snapshot_download
 
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
 from sglang.srt.managers.schedule_batch import Req, ScheduleBatch
@@ -44,6 +46,23 @@ class EAGLEWorker(TpModelWorker):
         # We will capture it later
         backup_disable_cuda_graph = server_args.disable_cuda_graph
         server_args.disable_cuda_graph = True
+
+        if server_args.speculative_token_map is not None:
+            if os.path.exists(server_args.speculative_token_map):
+                self.hot_token_id = torch.load(server_args.speculative_token_map)
+            else:
+                cache_dir = snapshot_download(
+                    os.path.dirname(server_args.speculative_token_map),
+                    ignore_patterns=["*.bin", "*.safetensors"],
+                )
+                file_path = os.path.join(
+                    cache_dir, os.path.basename(server_args.speculative_token_map)
+                )
+                self.hot_token_id = torch.load(file_path)
+            server_args.json_model_override_args = (
+                f'{{"hot_vocab_size": {len(self.hot_token_id)}}}'
+            )
+
         super().__init__(
             gpu_id=gpu_id,
             tp_rank=tp_rank,
@@ -66,7 +85,21 @@ class EAGLEWorker(TpModelWorker):
         # Share the embedding and lm_head
         if not self.speculative_algorithm.is_nextn():
             embed, head = self.target_worker.model_runner.model.get_embed_and_head()
+            if server_args.speculative_token_map is not None:
+                head = head.clone()
+                self.hot_token_id = torch.tensor(
+                    self.hot_token_id, dtype=torch.int32, device=head.device
+                )
+                head.data = head.data[self.hot_token_id]
+            else:
+                self.hot_token_id = None
             self.model_runner.model.set_embed_and_head(embed, head)
+        else:
+            if server_args.speculative_token_map is not None:
+                raise NotImplementedError(
+                    "NEXTN does not support speculative-token-map yet"
+                )
+            self.hot_token_id = None
         self.model_runner.server_args.disable_cuda_graph = backup_disable_cuda_graph
 
         # Create multi-step attn backends and cuda graph runners
@@ -223,6 +256,8 @@
             spec_info.topk_index,
             spec_info.hidden_states,
         )
+        if self.hot_token_id is not None:
+            topk_index = self.hot_token_id[topk_index]
 
         # Return values
         score_list: List[torch.Tensor] = []
@@ -262,6 +297,8 @@
         )
         probs = torch.softmax(logits_output.next_token_logits, dim=-1)
         topk_p, topk_index = fast_topk(probs, self.topk, dim=-1)
+        if self.hot_token_id is not None:
+            topk_index = self.hot_token_id[topk_index]
 
         hidden_states = logits_output.hidden_states
 
         return score_list, token_list, parents_list
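
A few short sketches follow for anyone reviewing the FR-Spec path; they are illustrations, not part of the patch. First, the token-map resolution performed in `EAGLEWorker.__init__`, restated as a standalone helper. `load_token_map` is a hypothetical name, and the map is assumed to be a tensor of token ids such as `freq_32768.pt`:

```python
import os

import torch
from huggingface_hub import snapshot_download


def load_token_map(token_map_path: str) -> torch.Tensor:
    """Resolve --speculative-token-map the same way EAGLEWorker does.

    Accepts either a local file path or "<repo_id>/<filename>" on the
    Hugging Face Hub, e.g. "thunlp/LLaMA3-Instruct-8B-FR-Spec/freq_32768.pt".
    """
    if os.path.exists(token_map_path):
        return torch.load(token_map_path)
    # Treat everything before the filename as a Hub repo id; skip the
    # model weights since only the token map file is needed.
    cache_dir = snapshot_download(
        os.path.dirname(token_map_path),
        ignore_patterns=["*.bin", "*.safetensors"],
    )
    return torch.load(os.path.join(cache_dir, os.path.basename(token_map_path)))
```

After loading, the worker advertises the truncated vocabulary to the draft model by setting `json_model_override_args` to `{"hot_vocab_size": len(hot_token_id)}`, as the hunk above shows.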
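
On the draft side, the new `hasattr` branch in `llama_eagle.py` simply sizes the head by `hot_vocab_size` whenever that override is present. An equivalent, more compact restatement of the decision (the helper name is illustrative):

```python
def draft_head_vocab_size(config) -> int:
    # hot_vocab_size exists only when --speculative-token-map was set and
    # EAGLEWorker injected it through json_model_override_args.
    return getattr(config, "hot_vocab_size", config.vocab_size)
```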
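
Finally, the end-to-end effect of the hot-token map is easy to check with toy tensors: slice the target model's head down to the hot rows, draft over the truncated vocabulary, then translate the top-k indices back to full-vocabulary token ids exactly as `eagle_worker.py` does. A minimal sketch with made-up sizes (all names here are illustrative):

```python
import torch

# Toy dimensions, chosen small for illustration; the real values would be
# e.g. vocab_size=128256 and hot_vocab_size=32768 for Llama-3-8B.
vocab_size, hot_vocab_size, hidden_size, topk = 1024, 256, 64, 8

# Stand-in for the FR-Spec token map: ids of the high-frequency tokens.
hot_token_id = torch.randperm(vocab_size)[:hot_vocab_size]

# Mirror of `head.data = head.data[self.hot_token_id]`: the draft lm_head
# keeps only the rows of the target head that belong to hot tokens.
full_head = torch.randn(vocab_size, hidden_size)
hot_head = full_head[hot_token_id]

# Draft step: logits and top-k are computed over the truncated vocabulary.
hidden = torch.randn(1, hidden_size)
probs = torch.softmax(hidden @ hot_head.T, dim=-1)  # (1, hot_vocab_size)
topk_p, topk_index = torch.topk(probs, topk, dim=-1)

# Mirror of `topk_index = self.hot_token_id[topk_index]`: indices into the
# truncated vocabulary are translated back to real token ids before the
# target model verifies them.
real_token_ids = hot_token_id[topk_index]
assert real_token_ids.shape == (1, topk)
```

Since only row selection and index translation are involved, verification is unchanged; the draft model just never proposes tokens outside the hot set.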