diff --git a/docs/backend/speculative_decoding.ipynb b/docs/backend/speculative_decoding.ipynb index e46453b46..061f415fd 100644 --- a/docs/backend/speculative_decoding.ipynb +++ b/docs/backend/speculative_decoding.ipynb @@ -212,6 +212,64 @@ "terminate_process(server_process)" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### EAGLE-3 Decoding\n", + "\n", + "You can enable EAGLE-3 decoding by setting `--speculative-algorithm EAGLE3` and pointing `--speculative-draft-model-path` at an EAGLE-3 draft model:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "server_process, port = launch_server_cmd(\n", + " \"\"\"\n", + "python3 -m sglang.launch_server --model meta-llama/Llama-3.1-8B-Instruct --speculative-algorithm EAGLE3 \\\n", + " --speculative-draft-model-path jamesliu1/sglang-EAGLE3-Llama-3.1-Instruct-8B --speculative-num-steps 5 \\\n", + " --speculative-eagle-topk 8 --speculative-num-draft-tokens 32 --mem-fraction 0.6 \\\n", + " --cuda-graph-max-bs 2 --dtype float16\n", + "\"\"\"\n", + ")\n", + "\n", + "wait_for_server(f\"http://localhost:{port}\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import openai\n", + "\n", + "client = openai.Client(base_url=f\"http://127.0.0.1:{port}/v1\", api_key=\"None\")\n", + "\n", + "response = client.chat.completions.create(\n", + " model=\"meta-llama/Meta-Llama-3.1-8B-Instruct\",\n", + " messages=[\n", + " {\"role\": \"user\", \"content\": \"List 3 countries and their capitals.\"},\n", + " ],\n", + " temperature=0,\n", + " max_tokens=64,\n", + ")\n", + "\n", + "print_highlight(f\"Response: {response}\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "terminate_process(server_process)" + ] + }, { "cell_type": "markdown", "metadata": {}, @@ -223,6 +281,7 @@ "- Within EAGLE the draft model predicts the next feature vector, i.e. the last hidden state of the original LLM, using the feature sequence $(f_1, ..., f_k)$ and the token sequence $(t_2, ..., t_{k+1})$. \n", "- The next token is then sampled from $p_{k+2}=\\text{LMHead}(f_{k+1})$. Afterwards, the two sequences are extended in a tree style—branching out multiple potential continuations, with the branching factor per step controlled by the `speculative_eagle_topk` parameter—to ensure a more coherent connection of context, and are given as input again.\n", "- EAGLE-2 additionally uses the draft model to evaluate how probable certain branches in the draft tree are, dynamically stopping the expansion of unlikely branches. After the expansion phase, reranking is employed to select only the top `speculative_num_draft_tokens` final nodes as draft tokens.\n", + "- EAGLE-3 removes the feature prediction objective, incorporates low and mid-layer features, and is trained in an on-policy manner.\n", "\n", "This enhances drafting accuracy by operating on the features instead of tokens for more regular inputs and passing the tokens from the next timestep additionaly to minimize randomness effects from sampling. Furthermore the dynamic adjustment of the draft tree and selection of reranked final nodes increases acceptance rate of draft tokens further. For more details see [the paper](https://arxiv.org/abs/2406.16858).
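To make the new EAGLE-3 bullet concrete: in this change the target model captures the hidden states of a low, a middle, and a late decoder layer, concatenates them, and the draft model projects the fused features back down to the model width before its single decoder layer consumes them. The snippet below is a minimal, self-contained sketch of that fusion step only; the layer indices `[2, L // 2, L - 3]` mirror `set_eagle3_layers_to_capture` in this PR, and the shapes are made up for illustration. It is not the actual SGLang execution path.

```python
import torch
from torch import nn

hidden_size, num_layers, num_tokens = 4096, 32, 2

# Layers whose outputs are captured for EAGLE-3 (mirrors set_eagle3_layers_to_capture).
layers_to_capture = [2, num_layers // 2, num_layers - 3]

# Stand-in per-layer hidden states from the target model for `num_tokens` tokens.
aux_hidden_states = [torch.randn(num_tokens, hidden_size) for _ in layers_to_capture]

# Target side: concatenate the captured features along the hidden dimension.
fused = torch.cat(aux_hidden_states, dim=-1)  # shape: (num_tokens, 3 * hidden_size)

# Draft side: project the fused features back to hidden_size before the single
# EAGLE-3 decoder layer consumes them (the `fc` module in llama_eagle3.py).
fc = nn.Linear(3 * hidden_size, hidden_size)
draft_input = fc(fused)  # shape: (num_tokens, hidden_size)
print(draft_input.shape)
```

This is also why the CUDA-graph hidden-state buffer for EAGLE-3 target workers is sized `3 * hidden_size` in `cuda_graph_runner.py` below.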
] diff --git a/python/sglang/srt/layers/logits_processor.py b/python/sglang/srt/layers/logits_processor.py index b398e052d..0f8329ae2 100644 --- a/python/sglang/srt/layers/logits_processor.py +++ b/python/sglang/srt/layers/logits_processor.py @@ -223,16 +223,18 @@ class LogitsProcessor(nn.Module): hidden_states, lm_head: VocabParallelEmbedding, logits_metadata: Union[LogitsMetadata, ForwardBatch], + aux_hidden_states: Optional[torch.Tensor] = None, ) -> LogitsProcessorOutput: if isinstance(logits_metadata, ForwardBatch): logits_metadata = LogitsMetadata.from_forward_batch(logits_metadata) - # Get the last hidden states and last logits for the next token prediction if ( logits_metadata.forward_mode.is_decode_or_idle() or logits_metadata.forward_mode.is_target_verify() ): pruned_states = hidden_states + if aux_hidden_states is not None: + aux_pruned_states = [hidden for hidden in aux_hidden_states] sample_indices = None input_logprob_indices = None elif ( @@ -256,6 +258,8 @@ class LogitsProcessor(nn.Module): - 1 ) pruned_states = hidden_states[last_index] + if aux_hidden_states is not None: + aux_pruned_states = [hidden[last_index] for hidden in aux_hidden_states] sample_indices = None input_logprob_indices = None else: @@ -319,13 +323,27 @@ class LogitsProcessor(nn.Module): hidden_states_to_store: Optional[torch.Tensor] = None if logits_metadata.capture_hidden_mode.need_capture(): if logits_metadata.capture_hidden_mode.is_full(): - hidden_states_to_store = hidden_states + if aux_hidden_states is not None: + aux_hidden_states = torch.cat(aux_hidden_states, dim=-1) + hidden_states_to_store = aux_hidden_states + else: + hidden_states_to_store = hidden_states elif logits_metadata.capture_hidden_mode.is_last(): # Get the last token hidden states. If sample_indices is None, # pruned states only contain the last tokens already. 
- hidden_states_to_store = ( - pruned_states[sample_indices] if sample_indices else pruned_states - ) + if aux_hidden_states is not None: + aux_pruned_states = torch.cat(aux_pruned_states, dim=-1) + hidden_states_to_store = ( + aux_pruned_states[sample_indices] + if sample_indices + else aux_pruned_states + ) + else: + hidden_states_to_store = ( + pruned_states[sample_indices] + if sample_indices + else pruned_states + ) else: assert False, "Should never reach" diff --git a/python/sglang/srt/model_executor/cuda_graph_runner.py b/python/sglang/srt/model_executor/cuda_graph_runner.py index 9add51eef..95a4dd6af 100644 --- a/python/sglang/srt/model_executor/cuda_graph_runner.py +++ b/python/sglang/srt/model_executor/cuda_graph_runner.py @@ -220,7 +220,19 @@ class CudaGraphRunner: self.mrope_positions = torch.zeros((3, self.max_bs), dtype=torch.int64) # Speculative_inference - if model_runner.spec_algorithm.is_eagle(): + if ( + model_runner.spec_algorithm.is_eagle3() + and not model_runner.is_draft_worker + ): + self.hidden_states = torch.zeros( + ( + self.max_num_token, + 3 * self.model_runner.model_config.hidden_size, + ), + dtype=self.model_runner.dtype, + ) + self.model_runner.model.set_eagle3_layers_to_capture() + elif model_runner.spec_algorithm.is_eagle(): self.hidden_states = torch.zeros( (self.max_num_token, self.model_runner.model_config.hidden_size), dtype=self.model_runner.dtype, diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index e90c37ed5..55558e0f6 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -210,6 +210,10 @@ class ModelRunner: self.cuda_graph_runner = None self.init_attention_backend() + # auxiliary hidden capture mode. TODO: expose this to server args? 
+ if self.spec_algorithm.is_eagle3() and not self.is_draft_worker: + self.model.set_eagle3_layers_to_capture() + def model_specific_adjustment(self): server_args = self.server_args diff --git a/python/sglang/srt/models/llama.py b/python/sglang/srt/models/llama.py index 4127bfcdf..873b28e02 100644 --- a/python/sglang/srt/models/llama.py +++ b/python/sglang/srt/models/llama.py @@ -17,7 +17,7 @@ """Inference-only LLaMA model compatible with HuggingFace weights.""" import logging -from typing import Any, Dict, Iterable, Optional, Tuple +from typing import Any, Dict, Iterable, List, Optional, Tuple, Union import torch from torch import nn @@ -285,6 +285,7 @@ class LlamaModel(nn.Module): ) self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.layers_to_capture = [] def forward( self, @@ -292,13 +293,16 @@ class LlamaModel(nn.Module): positions: torch.Tensor, forward_batch: ForwardBatch, input_embeds: torch.Tensor = None, - ) -> torch.Tensor: + ) -> Union[torch.Tensor, Tuple[torch.Tensor, List[torch.Tensor]]]: if input_embeds is None: hidden_states = self.embed_tokens(input_ids) else: hidden_states = input_embeds residual = None + aux_hidden_states = [] for i in range(len(self.layers)): + if i in self.layers_to_capture: + aux_hidden_states.append(hidden_states + residual) layer = self.layers[i] hidden_states, residual = layer( positions, @@ -307,7 +311,11 @@ class LlamaModel(nn.Module): residual, ) hidden_states, _ = self.norm(hidden_states, residual) - return hidden_states + + if len(aux_hidden_states) == 0: + return hidden_states + + return hidden_states, aux_hidden_states # If this function is called, it should always initialize KV cache scale # factors (or else raise an exception). Thus, handled exceptions should @@ -335,7 +343,6 @@ class LlamaModel(nn.Module): class LlamaForCausalLM(nn.Module): - # BitandBytes specific attributes default_bitsandbytes_target_modules = [ ".gate_proj.", @@ -391,6 +398,8 @@ class LlamaForCausalLM(nn.Module): (".gate_up_proj", ".up_proj", 1), ] + self.capture_aux_hidden_states = False + @torch.no_grad() def forward( self, @@ -400,10 +409,19 @@ class LlamaForCausalLM(nn.Module): input_embeds: torch.Tensor = None, get_embedding: bool = False, ) -> LogitsProcessorOutput: - hidden_states = self.model(input_ids, positions, forward_batch, input_embeds) + aux_hidden_states = None + if self.capture_aux_hidden_states: + hidden_states, aux_hidden_states = self.model( + input_ids, positions, forward_batch, input_embeds + ) + else: + hidden_states = self.model( + input_ids, positions, forward_batch, input_embeds + ) + if not get_embedding: return self.logits_processor( - input_ids, hidden_states, self.lm_head, forward_batch + input_ids, hidden_states, self.lm_head, forward_batch, aux_hidden_states ) else: return self.pooler(hidden_states, forward_batch) @@ -586,9 +604,23 @@ class LlamaForCausalLM(nn.Module): torch.cuda.empty_cache() torch.cuda.synchronize() + def get_embed(self): + return self.model.embed_tokens.weight + + def set_embed(self, embed): + del self.model.embed_tokens.weight + self.model.embed_tokens.weight = embed + torch.cuda.empty_cache() + torch.cuda.synchronize() + def load_kv_cache_scales(self, quantization_param_path: str) -> None: self.model.load_kv_cache_scales(quantization_param_path) + def set_eagle3_layers_to_capture(self): + self.capture_aux_hidden_states = True + num_layers = self.config.num_hidden_layers + self.model.layers_to_capture = [2, num_layers // 2, num_layers - 3] + class Phi3ForCausalLM(LlamaForCausalLM): pass diff --git 
a/python/sglang/srt/models/llama_eagle.py b/python/sglang/srt/models/llama_eagle.py index 769ee6736..b04d334bd 100644 --- a/python/sglang/srt/models/llama_eagle.py +++ b/python/sglang/srt/models/llama_eagle.py @@ -134,6 +134,7 @@ class LlamaForCausalLMEagle(LlamaForCausalLM): ) self.logits_processor = LogitsProcessor(config) + self.capture_aux_hidden_states = False def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): for name, loaded_weight in weights: diff --git a/python/sglang/srt/models/llama_eagle3.py b/python/sglang/srt/models/llama_eagle3.py new file mode 100644 index 000000000..56342fe24 --- /dev/null +++ b/python/sglang/srt/models/llama_eagle3.py @@ -0,0 +1,193 @@ +""" +Copyright 2023-2024 SGLang Team +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +from sglang.srt.utils import add_prefix + +# Adapted from +# https://github.com/SafeAILab/EAGLE/blob/main/eagle/model/cnets.py +"""Inference-only LLaMA-EAGLE model compatible with HuggingFace weights.""" + +from typing import Iterable, Optional, Tuple + +import torch +from torch import nn +from transformers import LlamaConfig + +from sglang.srt.layers.layernorm import RMSNorm +from sglang.srt.layers.linear import QKVParallelLinear, RowParallelLinear +from sglang.srt.layers.logits_processor import LogitsProcessor +from sglang.srt.layers.quantization.base_config import QuantizationConfig +from sglang.srt.layers.vocab_parallel_embedding import ( + ParallelLMHead, + VocabParallelEmbedding, +) +from sglang.srt.model_executor.forward_batch_info import ForwardBatch +from sglang.srt.models.llama import LlamaAttention, LlamaDecoderLayer, LlamaForCausalLM + + +class LlamaDecoderLayer(LlamaDecoderLayer): + def __init__( + self, + config: LlamaConfig, + layer_id: int = 0, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__(config, layer_id, quant_config, prefix) + + # override qkv + self.self_attn.qkv_proj = QKVParallelLinear( + 2 * self.hidden_size, + self.self_attn.head_dim, + self.self_attn.total_num_heads, + self.self_attn.total_num_kv_heads, + bias=False, + quant_config=quant_config, + prefix=add_prefix("qkv_proj", prefix), + ) + + self.hidden_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + def forward( + self, + positions: torch.Tensor, + embeds: torch.Tensor, + hidden_states: torch.Tensor, + forward_batch: ForwardBatch, + residual: Optional[torch.Tensor], + ) -> Tuple[torch.Tensor, torch.Tensor]: + + residual = hidden_states + embeds = self.input_layernorm(embeds) + hidden_states = self.hidden_norm(hidden_states) + + hidden_states = torch.cat([embeds, hidden_states], dim=-1) + # Self Attention + hidden_states = self.self_attn( + positions=positions, + hidden_states=hidden_states, + forward_batch=forward_batch, + ) + + hidden_states, residual = self.post_attention_layernorm(hidden_states, residual) + + # Fully Connected + hidden_states = self.mlp(hidden_states) + + return hidden_states, residual + + +class LlamaModel(nn.Module): + def __init__( + self, + config: LlamaConfig, + 
quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.config = config + self.vocab_size = config.vocab_size + self.embed_tokens = VocabParallelEmbedding( + config.vocab_size, + config.hidden_size, + prefix=add_prefix("embed_tokens", prefix), + ) + self.midlayer = LlamaDecoderLayer(config, 0, quant_config, prefix) + self.fc = torch.nn.Linear(config.hidden_size * 3, config.hidden_size) + + self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + forward_batch: ForwardBatch, + input_embeds: torch.Tensor = None, + ) -> torch.Tensor: + if input_embeds is None: + embeds = self.embed_tokens(input_ids) + else: + embeds = input_embeds + + hidden_states = forward_batch.spec_info.hidden_states + if hidden_states.shape[-1] != embeds.shape[-1]: + hidden_states = self.fc(hidden_states) + + residual = None + hidden_states, residual = self.midlayer( + positions, + embeds, + hidden_states, + forward_batch, + residual, + ) + + hidden_states_to_logits, hidden_states_to_aux = self.norm( + hidden_states, residual + ) + + # For draft decode, we capture the hidden state before norm + return hidden_states_to_logits, [hidden_states_to_aux] + + +class LlamaForCausalLMEagle3(LlamaForCausalLM): + def __init__( + self, + config: LlamaConfig, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + nn.Module.__init__(self) + self.config = config + self.quant_config = quant_config + + if self.config.num_hidden_layers != 1: + raise ValueError("EAGLE3 currently only supports 1 layer") + + self.model = LlamaModel( + config, quant_config=quant_config, prefix=add_prefix("model", prefix) + ) + # Llama 3.2 1B Instruct set tie_word_embeddings to True + # Llama 3.1 8B Instruct set tie_word_embeddings to False + if self.config.tie_word_embeddings: + self.lm_head = self.model.embed_tokens + else: + self.lm_head = ParallelLMHead( + config.draft_vocab_size, + config.hidden_size, + quant_config=quant_config, + prefix=add_prefix("lm_head", prefix), + ) + + self.logits_processor = LogitsProcessor(config) + self.capture_aux_hidden_states = True + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + for name, loaded_weight in weights: + if "d2t" in name: + # d2t stores diffs between draft id and target id + self.hot_token_id = loaded_weight + torch.arange(loaded_weight.shape[0]) + + if "d2t" not in name and "t2d" not in name and "lm_head" not in name: + new_name = f"model.{name}" + super().load_weights([(new_name, loaded_weight)]) + elif "lm_head" in name: + super().load_weights([(name, loaded_weight)]) + + def get_hot_token_id(self): + return self.hot_token_id + + +EntryClass = [LlamaForCausalLMEagle3] diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index 77a97b9bc..251618268 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -287,7 +287,10 @@ class ServerArgs: # NEXTN shares the same implementation of EAGLE self.speculative_algorithm = "EAGLE" - if self.speculative_algorithm == "EAGLE": + if ( + self.speculative_algorithm == "EAGLE" + or self.speculative_algorithm == "EAGLE3" + ): if self.max_running_requests is None: self.max_running_requests = 32 self.disable_overlap_schedule = True @@ -779,7 +782,7 @@ class ServerArgs: parser.add_argument( "--speculative-algorithm", type=str, - choices=["EAGLE", "NEXTN"], + choices=["EAGLE", "EAGLE3", "NEXTN"], help="Speculative 
algorithm.", ) parser.add_argument( diff --git a/python/sglang/srt/speculative/eagle_worker.py b/python/sglang/srt/speculative/eagle_worker.py index e2dee9e12..8d29b8fb7 100644 --- a/python/sglang/srt/speculative/eagle_worker.py +++ b/python/sglang/srt/speculative/eagle_worker.py @@ -30,6 +30,7 @@ from sglang.srt.speculative.eagle_utils import ( fast_topk, select_top_k_tokens, ) +from sglang.srt.speculative.spec_info import SpeculativeAlgorithm from sglang.srt.utils import empty_context, get_available_gpu_memory, is_cuda_available if is_cuda_available(): @@ -66,6 +67,9 @@ class EAGLEWorker(TpModelWorker): self.gpu_id = gpu_id self.device = server_args.device self.target_worker = target_worker + self.speculative_algorithm = SpeculativeAlgorithm.from_string( + server_args.speculative_algorithm + ) # Override context length with target model's context length server_args.context_length = target_worker.model_runner.model_config.context_len @@ -81,7 +85,13 @@ class EAGLEWorker(TpModelWorker): ) # Load hot token ids - if server_args.speculative_token_map is not None: + if self.speculative_algorithm.is_eagle3(): + if server_args.speculative_token_map is not None: + logger.warning( + "Speculative token map specified, but EAGLE3 models already have this. Ignoring the specified token map." + ) + self.hot_token_id = None + elif server_args.speculative_token_map is not None: self.hot_token_id = load_token_map(server_args.speculative_token_map) server_args.json_model_override_args = ( f'{{"hot_vocab_size": {len(self.hot_token_id)}}}' @@ -102,13 +112,24 @@ class EAGLEWorker(TpModelWorker): token_to_kv_pool_allocator=self.token_to_kv_pool_allocator, ) - # Share the embedding and lm_head embed, head = self.target_worker.model_runner.model.get_embed_and_head() - if self.hot_token_id is not None: - head = head.clone() - self.hot_token_id = self.hot_token_id.to(head.device) - head.data = head.data[self.hot_token_id] - self.draft_model_runner.model.set_embed_and_head(embed, head) + + if self.speculative_algorithm.is_eagle3(): + # EAGLE3 models don't share lm_head + self.draft_model_runner.model.set_embed(embed) + + # grab hot token ids + self.hot_token_id = self.draft_model_runner.model.get_hot_token_id().to( + embed.device + ) + else: + if self.hot_token_id is not None: + head = head.clone() + self.hot_token_id = self.hot_token_id.to(head.device) + head.data = head.data[self.hot_token_id] + + # Share the embedding and lm_head + self.draft_model_runner.model.set_embed_and_head(embed, head) # Init attention backend and cuda graphs self.draft_model_runner.server_args.disable_cuda_graph = ( diff --git a/python/sglang/srt/speculative/spec_info.py b/python/sglang/srt/speculative/spec_info.py index 4eead0c6b..af556b99c 100644 --- a/python/sglang/srt/speculative/spec_info.py +++ b/python/sglang/srt/speculative/spec_info.py @@ -4,17 +4,22 @@ from enum import IntEnum, auto class SpeculativeAlgorithm(IntEnum): NONE = auto() EAGLE = auto() + EAGLE3 = auto() def is_none(self): return self == SpeculativeAlgorithm.NONE def is_eagle(self): - return self == SpeculativeAlgorithm.EAGLE + return self == SpeculativeAlgorithm.EAGLE or self == SpeculativeAlgorithm.EAGLE3 + + def is_eagle3(self): + return self == SpeculativeAlgorithm.EAGLE3 @staticmethod def from_string(name: str): name_map = { "EAGLE": SpeculativeAlgorithm.EAGLE, + "EAGLE3": SpeculativeAlgorithm.EAGLE3, None: SpeculativeAlgorithm.NONE, } if name is not None: diff --git a/test/srt/test_eagle_infer.py b/test/srt/test_eagle_infer.py index a464c9f24..30c846353 100644 
--- a/test/srt/test_eagle_infer.py +++ b/test/srt/test_eagle_infer.py @@ -164,6 +164,21 @@ class TestEAGLEEngineTokenMap(TestEAGLEEngine): NUM_CONFIGS = 1 +class TestEAGLE3Engine(TestEAGLEEngine): + BASE_CONFIG = { + "model_path": "meta-llama/Llama-3.1-8B-Instruct", + "speculative_draft_model_path": "jamesliu1/sglang-EAGLE3-Llama-3.1-Instruct-8B", + "speculative_algorithm": "EAGLE3", + "speculative_num_steps": 5, + "speculative_eagle_topk": 16, + "speculative_num_draft_tokens": 64, + "mem_fraction_static": 0.7, + "cuda_graph_max_bs": 5, + "dtype": "float16", + } + NUM_CONFIGS = 1 + + class TestEAGLEServer(unittest.TestCase): PROMPTS = [ "[INST] <<SYS>>\\nYou are a helpful assistant.\\n<</SYS>>\\nToday is a sunny day and I like[/INST]"
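As a usage note, the `TestEAGLE3Engine` configuration above can also be exercised offline. The sketch below assumes the offline `sglang.Engine` entry point accepts the same server arguments as keyword parameters (as the existing EAGLE engine tests do); the memory fraction and CUDA-graph batch size are illustrative and should be tuned to the GPU at hand.

```python
import sglang as sgl

# Mirrors TestEAGLE3Engine.BASE_CONFIG; adjust values to your hardware.
llm = sgl.Engine(
    model_path="meta-llama/Llama-3.1-8B-Instruct",
    speculative_draft_model_path="jamesliu1/sglang-EAGLE3-Llama-3.1-Instruct-8B",
    speculative_algorithm="EAGLE3",
    speculative_num_steps=5,
    speculative_eagle_topk=16,
    speculative_num_draft_tokens=64,
    mem_fraction_static=0.7,
    cuda_graph_max_bs=5,
    dtype="float16",
)

# Single prompt, greedy decoding; outputs is a list of dicts with a "text" field.
outputs = llm.generate(
    ["List 3 countries and their capitals."],
    {"temperature": 0, "max_new_tokens": 64},
)
print(outputs[0]["text"])
llm.shutdown()
```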