# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from collections.abc import Iterable

import torch
import torch.nn as nn
from transformers import LlamaConfig

from vllm.compilation.decorators import support_torch_compile
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.model_executor.layers.linear import ReplicatedLinear
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
from vllm.model_executor.model_loader.weight_utils import (
    default_weight_loader,
    maybe_remap_kv_scale_name,
)
from vllm.model_executor.models.llama import LlamaDecoderLayer, LlamaForCausalLM

from .utils import (
    AutoWeightsLoader,
    get_draft_quant_config,
    maybe_prefix,
    process_eagle_weight,
)

logger = init_logger(__name__)


class LlamaDecoderLayer(LlamaDecoderLayer):
    def __init__(
        self,
        vllm_config: VllmConfig,
        disable_input_layernorm: bool,
        prefix: str = "",
        config: LlamaConfig | None = None,
    ) -> None:
        super().__init__(vllm_config, prefix=prefix, config=config)

        # Skip the input_layernorm
        # https://github.com/SafeAILab/EAGLE/blob/35c78f6cdc19a73e05cf5c330b4c358dad970c6a/eagle/model/cnets.py#L427
        if disable_input_layernorm:
            del self.input_layernorm
            self.input_layernorm = nn.Identity()

    def get_quant_config(self, vllm_config: VllmConfig) -> QuantizationConfig | None:
        """Use drafter's quantization config instead of verifier's."""
        return get_draft_quant_config(vllm_config)


@support_torch_compile
class LlamaModel(nn.Module):
    def __init__(
        self,
        *,
        vllm_config: VllmConfig,
        prefix: str = "",
        start_layer_id: int = 0,
    ) -> None:
        super().__init__()
        self.config = vllm_config.speculative_config.draft_model_config.hf_config
        self.vocab_size = self.config.vocab_size

        # Get drafter's quantization config
        self.quant_config = get_draft_quant_config(vllm_config)

        self.embed_tokens = VocabParallelEmbedding(
            self.config.vocab_size,
            self.config.hidden_size,
            prefix=maybe_prefix(prefix, "embed_tokens"),
        )
        self.layers = nn.ModuleList(
            [
                LlamaDecoderLayer(
                    vllm_config,
                    i == 0,
                    prefix=maybe_prefix(prefix, f"layers.{i + start_layer_id}"),
                    config=self.config,
                )
                for i in range(self.config.num_hidden_layers)
            ]
        )
        self.fc = ReplicatedLinear(
            input_size=self.config.hidden_size * 2,
            output_size=self.config.hidden_size,
            bias=False,
            params_dtype=vllm_config.model_config.dtype,
            quant_config=self.quant_config,
            prefix=maybe_prefix(prefix, "fc"),
            return_bias=False,
        )

    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
        return self.embed_tokens(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        input_embeds = self.embed_tokens(input_ids)
        hidden_states = self.fc(torch.cat((input_embeds, hidden_states), dim=-1))
        residual = None
        for layer in self.layers:
            hidden_states, residual = layer(
                positions,
                hidden_states,
                residual,
            )
        hidden_states = hidden_states + residual
        return hidden_states, hidden_states

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            (".qkv_proj", ".q_proj", "q"),
            (".qkv_proj", ".k_proj", "k"),
            (".qkv_proj", ".v_proj", "v"),
            (".gate_up_proj", ".gate_proj", 0),
            (".gate_up_proj", ".up_proj", 1),
        ]
        params_dict = dict(self.named_parameters())
        loaded_params: set[str] = set()
        for name, loaded_weight in weights:
            # Handle kv cache quantization scales
            if self.quant_config is not None and (
                scale_name := self.quant_config.get_cache_scale(name)
            ):
                # Loading kv cache quantization scales
                param = params_dict[scale_name]
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                loaded_weight = (
                    loaded_weight if loaded_weight.dim() == 0 else loaded_weight[0]
                )
                weight_loader(param, loaded_weight)
                loaded_params.add(scale_name)
                continue

            # Remapping the name of FP8 kv-scale
            if "scale" in name:
                name = maybe_remap_kv_scale_name(name, params_dict)
                if name is None:
                    continue

            for param_name, weight_name, shard_id in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                weight_loader(param, loaded_weight)
            loaded_params.add(name)
        return loaded_params


class EagleLlamaForCausalLM(LlamaForCausalLM):
    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        nn.Module.__init__(self)
        self.config = vllm_config.speculative_config.draft_model_config.hf_config

        # Ensure draft_vocab_size is set;
        # default to the base vocab size when absent
        if getattr(self.config, "draft_vocab_size", None) is None:
            base_vocab_size = getattr(self.config, "vocab_size", None)
            self.config.draft_vocab_size = base_vocab_size

        target_layer_num = vllm_config.model_config.get_num_layers(
            vllm_config.parallel_config
        )
        self.model = LlamaModel(
            vllm_config=vllm_config, prefix="model", start_layer_id=target_layer_num
        )

        logit_scale = getattr(self.config, "logit_scale", 1.0)
        self.logits_processor = LogitsProcessor(
            self.config.vocab_size, scale=logit_scale
        )

    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
        return self.model.embed_input_ids(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        inputs_embeds: torch.Tensor | None = None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        if inputs_embeds is not None:
            raise NotImplementedError(
                f"{type(self).__name__} does not support multimodal inputs yet."
            )
        return self.model(input_ids, positions, hidden_states)

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
        def transform(inputs):
            name, loaded_weight = inputs
            if "lm_head" not in name:
                name = "model." + name
            process_eagle_weight(self, name)
            return name, loaded_weight

        loader = AutoWeightsLoader(
            self,
            skip_prefixes=None,
        )

        loader.load_weights(map(transform, weights))