[Feature] Support EAGLE 3 (#4247)
python/sglang/srt/models/llama_eagle3.py (new file, 193 lines)
@@ -0,0 +1,193 @@
"""
Copyright 2023-2024 SGLang Team
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

# Adapted from
# https://github.com/SafeAILab/EAGLE/blob/main/eagle/model/cnets.py
"""Inference-only LLaMA-EAGLE model compatible with HuggingFace weights."""

from typing import Iterable, Optional, Tuple

import torch
from torch import nn
from transformers import LlamaConfig

from sglang.srt.layers.layernorm import RMSNorm
from sglang.srt.layers.linear import QKVParallelLinear, RowParallelLinear
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.vocab_parallel_embedding import (
    ParallelLMHead,
    VocabParallelEmbedding,
)
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.models.llama import LlamaAttention, LlamaDecoderLayer, LlamaForCausalLM
from sglang.srt.utils import add_prefix


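# Note: this intentionally shadows the imported LlamaDecoderLayer name; the
# base class is sglang's stock LlamaDecoderLayer, and only the QKV input
# width and an extra hidden-state norm differ.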
class LlamaDecoderLayer(LlamaDecoderLayer):
    def __init__(
        self,
        config: LlamaConfig,
        layer_id: int = 0,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__(config, layer_id, quant_config, prefix)

        # Override qkv_proj: the layer attends over the concatenation of the
        # token embedding and the incoming hidden state, so the QKV input
        # width is 2 * hidden_size rather than hidden_size.
        self.self_attn.qkv_proj = QKVParallelLinear(
            2 * self.hidden_size,
            self.self_attn.head_dim,
            self.self_attn.total_num_heads,
            self.self_attn.total_num_kv_heads,
            bias=False,
            quant_config=quant_config,
            prefix=add_prefix("qkv_proj", prefix),
        )

        self.hidden_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
        positions: torch.Tensor,
        embeds: torch.Tensor,
        hidden_states: torch.Tensor,
        forward_batch: ForwardBatch,
        residual: Optional[torch.Tensor],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        residual = hidden_states
        embeds = self.input_layernorm(embeds)
        hidden_states = self.hidden_norm(hidden_states)

        # embeds and hidden_states are each [num_tokens, hidden_size]; the
        # concatenation is projected back to hidden_size by the widened
        # qkv_proj inside self_attn.
        hidden_states = torch.cat([embeds, hidden_states], dim=-1)

        # Self Attention
        hidden_states = self.self_attn(
            positions=positions,
            hidden_states=hidden_states,
            forward_batch=forward_batch,
        )

        hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)

        # Fully Connected
        hidden_states = self.mlp(hidden_states)

        return hidden_states, residual


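# The draft model proper: token embeddings, a single fused decoder layer
# ("midlayer"), and an fc that projects the target model's concatenated
# auxiliary hidden states (three capture layers, hence hidden_size * 3)
# down to hidden_size on the first draft step.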
class LlamaModel(nn.Module):
    def __init__(
        self,
        config: LlamaConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.config = config
        self.vocab_size = config.vocab_size
        self.embed_tokens = VocabParallelEmbedding(
            config.vocab_size,
            config.hidden_size,
            prefix=add_prefix("embed_tokens", prefix),
        )
        self.midlayer = LlamaDecoderLayer(config, 0, quant_config, prefix)
        self.fc = torch.nn.Linear(config.hidden_size * 3, config.hidden_size)

        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        forward_batch: ForwardBatch,
        input_embeds: torch.Tensor = None,
    ) -> torch.Tensor:
        if input_embeds is None:
            embeds = self.embed_tokens(input_ids)
        else:
            embeds = input_embeds

        hidden_states = forward_batch.spec_info.hidden_states
        if hidden_states.shape[-1] != embeds.shape[-1]:
            hidden_states = self.fc(hidden_states)

        residual = None
        hidden_states, residual = self.midlayer(
            positions,
            embeds,
            hidden_states,
            forward_batch,
            residual,
        )

        hidden_states_to_logits, hidden_states_to_aux = self.norm(
            hidden_states, residual
        )

        # For draft decode, we capture the hidden state before the final norm:
        # norm(x, residual) returns (normed, x + residual), so logits come from
        # the normed output while the pre-norm sum is kept as the auxiliary
        # hidden state for subsequent draft steps.
        return hidden_states_to_logits, [hidden_states_to_aux]


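# Entry-point class registered with sglang's model loader. The draft lm_head
# scores only a reduced "hot" vocabulary of size config.draft_vocab_size;
# load_weights uses the checkpoint's d2t table to map those draft ids back
# to the target model's full vocabulary.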
class LlamaForCausalLMEagle3(LlamaForCausalLM):
    def __init__(
        self,
        config: LlamaConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        nn.Module.__init__(self)
        self.config = config
        self.quant_config = quant_config

        if self.config.num_hidden_layers != 1:
            raise ValueError("EAGLE3 currently only supports 1 layer")

        self.model = LlamaModel(
            config, quant_config=quant_config, prefix=add_prefix("model", prefix)
        )
        # Llama 3.2 1B Instruct sets tie_word_embeddings to True;
        # Llama 3.1 8B Instruct sets it to False.
        if self.config.tie_word_embeddings:
            self.lm_head = self.model.embed_tokens
        else:
            self.lm_head = ParallelLMHead(
                config.draft_vocab_size,
                config.hidden_size,
                quant_config=quant_config,
                prefix=add_prefix("lm_head", prefix),
            )

        self.logits_processor = LogitsProcessor(config)
        self.capture_aux_hidden_states = True

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        for name, loaded_weight in weights:
            if "d2t" in name:
                # d2t stores the offset from each draft token id to its
                # target token id; adding arange recovers the target ids.
                self.hot_token_id = loaded_weight + torch.arange(loaded_weight.shape[0])
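                # e.g. (hypothetical values): d2t = [0, 5, 9] maps draft ids
                # [0, 1, 2] to target ids [0, 6, 11].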

            # t2d entries are intentionally skipped; everything else is
            # forwarded to the base LlamaForCausalLM loader.
            if "d2t" not in name and "t2d" not in name and "lm_head" not in name:
                new_name = f"model.{name}"
                super().load_weights([(new_name, loaded_weight)])
            elif "lm_head" in name:
                super().load_weights([(name, loaded_weight)])

    def get_hot_token_id(self):
        return self.hot_token_id


EntryClass = [LlamaForCausalLMEagle3]
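
For context: a draft checkpoint in this format is served through sglang's
speculative-decoding flags. A hypothetical launch (flag names follow the
existing EAGLE integration; the draft-model path and tuning values here are
illustrative assumptions, not part of this commit):

    python3 -m sglang.launch_server \
        --model-path meta-llama/Llama-3.1-8B-Instruct \
        --speculative-algorithm EAGLE3 \
        --speculative-draft-model-path <path-to-eagle3-draft-checkpoint> \
        --speculative-num-steps 5 \
        --speculative-eagle-topk 8 \
        --speculative-num-draft-tokens 32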