[Feature] Support EAGLE 3 (#4247)
This commit is contained in:
@@ -223,16 +223,18 @@ class LogitsProcessor(nn.Module):
|
||||
hidden_states,
|
||||
lm_head: VocabParallelEmbedding,
|
||||
logits_metadata: Union[LogitsMetadata, ForwardBatch],
|
||||
aux_hidden_states: Optional[torch.Tensor] = None,
|
||||
) -> LogitsProcessorOutput:
|
||||
if isinstance(logits_metadata, ForwardBatch):
|
||||
logits_metadata = LogitsMetadata.from_forward_batch(logits_metadata)
|
||||
|
||||
# Get the last hidden states and last logits for the next token prediction
|
||||
if (
|
||||
logits_metadata.forward_mode.is_decode_or_idle()
|
||||
or logits_metadata.forward_mode.is_target_verify()
|
||||
):
|
||||
pruned_states = hidden_states
|
||||
if aux_hidden_states is not None:
|
||||
aux_pruned_states = [hidden for hidden in aux_hidden_states]
|
||||
sample_indices = None
|
||||
input_logprob_indices = None
|
||||
elif (
|
||||
@@ -256,6 +258,8 @@ class LogitsProcessor(nn.Module):
|
||||
- 1
|
||||
)
|
||||
pruned_states = hidden_states[last_index]
|
||||
if aux_hidden_states is not None:
|
||||
aux_pruned_states = [hidden[last_index] for hidden in aux_hidden_states]
|
||||
sample_indices = None
|
||||
input_logprob_indices = None
|
||||
else:
|
||||
@@ -319,13 +323,27 @@ class LogitsProcessor(nn.Module):
|
||||
hidden_states_to_store: Optional[torch.Tensor] = None
|
||||
if logits_metadata.capture_hidden_mode.need_capture():
|
||||
if logits_metadata.capture_hidden_mode.is_full():
|
||||
hidden_states_to_store = hidden_states
|
||||
if aux_hidden_states is not None:
|
||||
aux_hidden_states = torch.cat(aux_hidden_states, dim=-1)
|
||||
hidden_states_to_store = aux_hidden_states
|
||||
else:
|
||||
hidden_states_to_store = hidden_states
|
||||
elif logits_metadata.capture_hidden_mode.is_last():
|
||||
# Get the last token hidden states. If sample_indices is None,
|
||||
# pruned states only contain the last tokens already.
|
||||
hidden_states_to_store = (
|
||||
pruned_states[sample_indices] if sample_indices else pruned_states
|
||||
)
|
||||
if aux_hidden_states is not None:
|
||||
aux_pruned_states = torch.cat(aux_pruned_states, dim=-1)
|
||||
hidden_states_to_store = (
|
||||
aux_pruned_states[sample_indices]
|
||||
if sample_indices
|
||||
else aux_pruned_states
|
||||
)
|
||||
else:
|
||||
hidden_states_to_store = (
|
||||
pruned_states[sample_indices]
|
||||
if sample_indices
|
||||
else pruned_states
|
||||
)
|
||||
else:
|
||||
assert False, "Should never reach"
|
||||
|
||||
|
||||
@@ -220,7 +220,19 @@ class CudaGraphRunner:
|
||||
self.mrope_positions = torch.zeros((3, self.max_bs), dtype=torch.int64)
|
||||
|
||||
# Speculative_inference
|
||||
if model_runner.spec_algorithm.is_eagle():
|
||||
if (
|
||||
model_runner.spec_algorithm.is_eagle3()
|
||||
and not model_runner.is_draft_worker
|
||||
):
|
||||
self.hidden_states = torch.zeros(
|
||||
(
|
||||
self.max_num_token,
|
||||
3 * self.model_runner.model_config.hidden_size,
|
||||
),
|
||||
dtype=self.model_runner.dtype,
|
||||
)
|
||||
self.model_runner.model.set_eagle3_layers_to_capture()
|
||||
elif model_runner.spec_algorithm.is_eagle():
|
||||
self.hidden_states = torch.zeros(
|
||||
(self.max_num_token, self.model_runner.model_config.hidden_size),
|
||||
dtype=self.model_runner.dtype,
|
||||
|
||||
@@ -210,6 +210,10 @@ class ModelRunner:
|
||||
self.cuda_graph_runner = None
|
||||
self.init_attention_backend()
|
||||
|
||||
# auxiliary hidden capture mode. TODO: expose this to server args?
|
||||
if self.spec_algorithm.is_eagle3() and not self.is_draft_worker:
|
||||
self.model.set_eagle3_layers_to_capture()
|
||||
|
||||
def model_specific_adjustment(self):
|
||||
server_args = self.server_args
|
||||
|
||||
|
||||
@@ -17,7 +17,7 @@
|
||||
"""Inference-only LLaMA model compatible with HuggingFace weights."""
|
||||
|
||||
import logging
|
||||
from typing import Any, Dict, Iterable, Optional, Tuple
|
||||
from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
@@ -285,6 +285,7 @@ class LlamaModel(nn.Module):
|
||||
)
|
||||
|
||||
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
self.layers_to_capture = []
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@@ -292,13 +293,16 @@ class LlamaModel(nn.Module):
|
||||
positions: torch.Tensor,
|
||||
forward_batch: ForwardBatch,
|
||||
input_embeds: torch.Tensor = None,
|
||||
) -> torch.Tensor:
|
||||
) -> Union[torch.Tensor, Tuple[torch.Tensor, List[torch.Tensor]]]:
|
||||
if input_embeds is None:
|
||||
hidden_states = self.embed_tokens(input_ids)
|
||||
else:
|
||||
hidden_states = input_embeds
|
||||
residual = None
|
||||
aux_hidden_states = []
|
||||
for i in range(len(self.layers)):
|
||||
if i in self.layers_to_capture:
|
||||
aux_hidden_states.append(hidden_states + residual)
|
||||
layer = self.layers[i]
|
||||
hidden_states, residual = layer(
|
||||
positions,
|
||||
@@ -307,7 +311,11 @@ class LlamaModel(nn.Module):
|
||||
residual,
|
||||
)
|
||||
hidden_states, _ = self.norm(hidden_states, residual)
|
||||
return hidden_states
|
||||
|
||||
if len(aux_hidden_states) == 0:
|
||||
return hidden_states
|
||||
|
||||
return hidden_states, aux_hidden_states
|
||||
|
||||
# If this function is called, it should always initialize KV cache scale
|
||||
# factors (or else raise an exception). Thus, handled exceptions should
|
||||
@@ -335,7 +343,6 @@ class LlamaModel(nn.Module):
|
||||
|
||||
|
||||
class LlamaForCausalLM(nn.Module):
|
||||
|
||||
# BitandBytes specific attributes
|
||||
default_bitsandbytes_target_modules = [
|
||||
".gate_proj.",
|
||||
@@ -391,6 +398,8 @@ class LlamaForCausalLM(nn.Module):
|
||||
(".gate_up_proj", ".up_proj", 1),
|
||||
]
|
||||
|
||||
self.capture_aux_hidden_states = False
|
||||
|
||||
@torch.no_grad()
|
||||
def forward(
|
||||
self,
|
||||
@@ -400,10 +409,19 @@ class LlamaForCausalLM(nn.Module):
|
||||
input_embeds: torch.Tensor = None,
|
||||
get_embedding: bool = False,
|
||||
) -> LogitsProcessorOutput:
|
||||
hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
|
||||
aux_hidden_states = None
|
||||
if self.capture_aux_hidden_states:
|
||||
hidden_states, aux_hidden_states = self.model(
|
||||
input_ids, positions, forward_batch, input_embeds
|
||||
)
|
||||
else:
|
||||
hidden_states = self.model(
|
||||
input_ids, positions, forward_batch, input_embeds
|
||||
)
|
||||
|
||||
if not get_embedding:
|
||||
return self.logits_processor(
|
||||
input_ids, hidden_states, self.lm_head, forward_batch
|
||||
input_ids, hidden_states, self.lm_head, forward_batch, aux_hidden_states
|
||||
)
|
||||
else:
|
||||
return self.pooler(hidden_states, forward_batch)
|
||||
@@ -586,9 +604,23 @@ class LlamaForCausalLM(nn.Module):
|
||||
torch.cuda.empty_cache()
|
||||
torch.cuda.synchronize()
|
||||
|
||||
def get_embed(self):
|
||||
return self.model.embed_tokens.weight
|
||||
|
||||
def set_embed(self, embed):
|
||||
del self.model.embed_tokens.weight
|
||||
self.model.embed_tokens.weight = embed
|
||||
torch.cuda.empty_cache()
|
||||
torch.cuda.synchronize()
|
||||
|
||||
def load_kv_cache_scales(self, quantization_param_path: str) -> None:
|
||||
self.model.load_kv_cache_scales(quantization_param_path)
|
||||
|
||||
def set_eagle3_layers_to_capture(self):
|
||||
self.capture_aux_hidden_states = True
|
||||
num_layers = self.config.num_hidden_layers
|
||||
self.model.layers_to_capture = [2, num_layers // 2, num_layers - 3]
|
||||
|
||||
|
||||
class Phi3ForCausalLM(LlamaForCausalLM):
|
||||
pass
|
||||
|
||||
@@ -134,6 +134,7 @@ class LlamaForCausalLMEagle(LlamaForCausalLM):
|
||||
)
|
||||
|
||||
self.logits_processor = LogitsProcessor(config)
|
||||
self.capture_aux_hidden_states = False
|
||||
|
||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||
for name, loaded_weight in weights:
|
||||
|
||||
193
python/sglang/srt/models/llama_eagle3.py
Normal file
193
python/sglang/srt/models/llama_eagle3.py
Normal file
@@ -0,0 +1,193 @@
|
||||
"""
|
||||
Copyright 2023-2024 SGLang Team
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
"""
|
||||
|
||||
from sglang.srt.utils import add_prefix
|
||||
|
||||
# Adapted from
|
||||
# https://github.com/SafeAILab/EAGLE/blob/main/eagle/model/cnets.py
|
||||
"""Inference-only LLaMA-EAGLE model compatible with HuggingFace weights."""
|
||||
|
||||
from typing import Iterable, Optional, Tuple
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from transformers import LlamaConfig
|
||||
|
||||
from sglang.srt.layers.layernorm import RMSNorm
|
||||
from sglang.srt.layers.linear import QKVParallelLinear, RowParallelLinear
|
||||
from sglang.srt.layers.logits_processor import LogitsProcessor
|
||||
from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
||||
from sglang.srt.layers.vocab_parallel_embedding import (
|
||||
ParallelLMHead,
|
||||
VocabParallelEmbedding,
|
||||
)
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
||||
from sglang.srt.models.llama import LlamaAttention, LlamaDecoderLayer, LlamaForCausalLM
|
||||
|
||||
|
||||
class LlamaDecoderLayer(LlamaDecoderLayer):
|
||||
def __init__(
|
||||
self,
|
||||
config: LlamaConfig,
|
||||
layer_id: int = 0,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
super().__init__(config, layer_id, quant_config, prefix)
|
||||
|
||||
# override qkv
|
||||
self.self_attn.qkv_proj = QKVParallelLinear(
|
||||
2 * self.hidden_size,
|
||||
self.self_attn.head_dim,
|
||||
self.self_attn.total_num_heads,
|
||||
self.self_attn.total_num_kv_heads,
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
prefix=add_prefix("qkv_proj", prefix),
|
||||
)
|
||||
|
||||
self.hidden_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
embeds: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
forward_batch: ForwardBatch,
|
||||
residual: Optional[torch.Tensor],
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
|
||||
residual = hidden_states
|
||||
embeds = self.input_layernorm(embeds)
|
||||
hidden_states = self.hidden_norm(hidden_states)
|
||||
|
||||
hidden_states = torch.cat([embeds, hidden_states], dim=-1)
|
||||
# Self Attention
|
||||
hidden_states = self.self_attn(
|
||||
positions=positions,
|
||||
hidden_states=hidden_states,
|
||||
forward_batch=forward_batch,
|
||||
)
|
||||
|
||||
hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
|
||||
|
||||
# Fully Connected
|
||||
hidden_states = self.mlp(hidden_states)
|
||||
|
||||
return hidden_states, residual
|
||||
|
||||
|
||||
class LlamaModel(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
config: LlamaConfig,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.vocab_size = config.vocab_size
|
||||
self.embed_tokens = VocabParallelEmbedding(
|
||||
config.vocab_size,
|
||||
config.hidden_size,
|
||||
prefix=add_prefix("embed_tokens", prefix),
|
||||
)
|
||||
self.midlayer = LlamaDecoderLayer(config, 0, quant_config, prefix)
|
||||
self.fc = torch.nn.Linear(config.hidden_size * 3, config.hidden_size)
|
||||
|
||||
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
forward_batch: ForwardBatch,
|
||||
input_embeds: torch.Tensor = None,
|
||||
) -> torch.Tensor:
|
||||
if input_embeds is None:
|
||||
embeds = self.embed_tokens(input_ids)
|
||||
else:
|
||||
embeds = input_embeds
|
||||
|
||||
hidden_states = forward_batch.spec_info.hidden_states
|
||||
if hidden_states.shape[-1] != embeds.shape[-1]:
|
||||
hidden_states = self.fc(hidden_states)
|
||||
|
||||
residual = None
|
||||
hidden_states, residual = self.midlayer(
|
||||
positions,
|
||||
embeds,
|
||||
hidden_states,
|
||||
forward_batch,
|
||||
residual,
|
||||
)
|
||||
|
||||
hidden_states_to_logits, hidden_states_to_aux = self.norm(
|
||||
hidden_states, residual
|
||||
)
|
||||
|
||||
# For draft decode, we capture the hidden state before norm
|
||||
return hidden_states_to_logits, [hidden_states_to_aux]
|
||||
|
||||
|
||||
class LlamaForCausalLMEagle3(LlamaForCausalLM):
|
||||
def __init__(
|
||||
self,
|
||||
config: LlamaConfig,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
nn.Module.__init__(self)
|
||||
self.config = config
|
||||
self.quant_config = quant_config
|
||||
|
||||
if self.config.num_hidden_layers != 1:
|
||||
raise ValueError("EAGLE3 currently only supports 1 layer")
|
||||
|
||||
self.model = LlamaModel(
|
||||
config, quant_config=quant_config, prefix=add_prefix("model", prefix)
|
||||
)
|
||||
# Llama 3.2 1B Instruct set tie_word_embeddings to True
|
||||
# Llama 3.1 8B Instruct set tie_word_embeddings to False
|
||||
if self.config.tie_word_embeddings:
|
||||
self.lm_head = self.model.embed_tokens
|
||||
else:
|
||||
self.lm_head = ParallelLMHead(
|
||||
config.draft_vocab_size,
|
||||
config.hidden_size,
|
||||
quant_config=quant_config,
|
||||
prefix=add_prefix("lm_head", prefix),
|
||||
)
|
||||
|
||||
self.logits_processor = LogitsProcessor(config)
|
||||
self.capture_aux_hidden_states = True
|
||||
|
||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||
for name, loaded_weight in weights:
|
||||
if "d2t" in name:
|
||||
# d2t stores diffs between draft id and target id
|
||||
self.hot_token_id = loaded_weight + torch.arange(loaded_weight.shape[0])
|
||||
|
||||
if "d2t" not in name and "t2d" not in name and "lm_head" not in name:
|
||||
new_name = f"model.{name}"
|
||||
super().load_weights([(new_name, loaded_weight)])
|
||||
elif "lm_head" in name:
|
||||
super().load_weights([(name, loaded_weight)])
|
||||
|
||||
def get_hot_token_id(self):
|
||||
return self.hot_token_id
|
||||
|
||||
|
||||
EntryClass = [LlamaForCausalLMEagle3]
|
||||
@@ -287,7 +287,10 @@ class ServerArgs:
|
||||
# NEXTN shares the same implementation of EAGLE
|
||||
self.speculative_algorithm = "EAGLE"
|
||||
|
||||
if self.speculative_algorithm == "EAGLE":
|
||||
if (
|
||||
self.speculative_algorithm == "EAGLE"
|
||||
or self.speculative_algorithm == "EAGLE3"
|
||||
):
|
||||
if self.max_running_requests is None:
|
||||
self.max_running_requests = 32
|
||||
self.disable_overlap_schedule = True
|
||||
@@ -779,7 +782,7 @@ class ServerArgs:
|
||||
parser.add_argument(
|
||||
"--speculative-algorithm",
|
||||
type=str,
|
||||
choices=["EAGLE", "NEXTN"],
|
||||
choices=["EAGLE", "EAGLE3", "NEXTN"],
|
||||
help="Speculative algorithm.",
|
||||
)
|
||||
parser.add_argument(
|
||||
|
||||
@@ -30,6 +30,7 @@ from sglang.srt.speculative.eagle_utils import (
|
||||
fast_topk,
|
||||
select_top_k_tokens,
|
||||
)
|
||||
from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
|
||||
from sglang.srt.utils import empty_context, get_available_gpu_memory, is_cuda_available
|
||||
|
||||
if is_cuda_available():
|
||||
@@ -66,6 +67,9 @@ class EAGLEWorker(TpModelWorker):
|
||||
self.gpu_id = gpu_id
|
||||
self.device = server_args.device
|
||||
self.target_worker = target_worker
|
||||
self.speculative_algorithm = SpeculativeAlgorithm.from_string(
|
||||
server_args.speculative_algorithm
|
||||
)
|
||||
|
||||
# Override context length with target model's context length
|
||||
server_args.context_length = target_worker.model_runner.model_config.context_len
|
||||
@@ -81,7 +85,13 @@ class EAGLEWorker(TpModelWorker):
|
||||
)
|
||||
|
||||
# Load hot token ids
|
||||
if server_args.speculative_token_map is not None:
|
||||
if self.speculative_algorithm.is_eagle3():
|
||||
if server_args.speculative_token_map is not None:
|
||||
logger.warning(
|
||||
"Speculative token map specified, but EAGLE3 models already have this. Ignoring the specified token map."
|
||||
)
|
||||
self.hot_token_id = None
|
||||
elif server_args.speculative_token_map is not None:
|
||||
self.hot_token_id = load_token_map(server_args.speculative_token_map)
|
||||
server_args.json_model_override_args = (
|
||||
f'{{"hot_vocab_size": {len(self.hot_token_id)}}}'
|
||||
@@ -102,13 +112,24 @@ class EAGLEWorker(TpModelWorker):
|
||||
token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
|
||||
)
|
||||
|
||||
# Share the embedding and lm_head
|
||||
embed, head = self.target_worker.model_runner.model.get_embed_and_head()
|
||||
if self.hot_token_id is not None:
|
||||
head = head.clone()
|
||||
self.hot_token_id = self.hot_token_id.to(head.device)
|
||||
head.data = head.data[self.hot_token_id]
|
||||
self.draft_model_runner.model.set_embed_and_head(embed, head)
|
||||
|
||||
if self.speculative_algorithm.is_eagle3():
|
||||
# EAGLE3 models don't share lm_head
|
||||
self.draft_model_runner.model.set_embed(embed)
|
||||
|
||||
# grab hot token ids
|
||||
self.hot_token_id = self.draft_model_runner.model.get_hot_token_id().to(
|
||||
embed.device
|
||||
)
|
||||
else:
|
||||
if self.hot_token_id is not None:
|
||||
head = head.clone()
|
||||
self.hot_token_id = self.hot_token_id.to(head.device)
|
||||
head.data = head.data[self.hot_token_id]
|
||||
|
||||
# Share the embedding and lm_head
|
||||
self.draft_model_runner.model.set_embed_and_head(embed, head)
|
||||
|
||||
# Init attention backend and cuda graphs
|
||||
self.draft_model_runner.server_args.disable_cuda_graph = (
|
||||
|
||||
@@ -4,17 +4,22 @@ from enum import IntEnum, auto
|
||||
class SpeculativeAlgorithm(IntEnum):
|
||||
NONE = auto()
|
||||
EAGLE = auto()
|
||||
EAGLE3 = auto()
|
||||
|
||||
def is_none(self):
|
||||
return self == SpeculativeAlgorithm.NONE
|
||||
|
||||
def is_eagle(self):
|
||||
return self == SpeculativeAlgorithm.EAGLE
|
||||
return self == SpeculativeAlgorithm.EAGLE or self == SpeculativeAlgorithm.EAGLE3
|
||||
|
||||
def is_eagle3(self):
|
||||
return self == SpeculativeAlgorithm.EAGLE3
|
||||
|
||||
@staticmethod
|
||||
def from_string(name: str):
|
||||
name_map = {
|
||||
"EAGLE": SpeculativeAlgorithm.EAGLE,
|
||||
"EAGLE3": SpeculativeAlgorithm.EAGLE3,
|
||||
None: SpeculativeAlgorithm.NONE,
|
||||
}
|
||||
if name is not None:
|
||||
|
||||
Reference in New Issue
Block a user