# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Iterable

import torch
import torch.nn as nn

from vllm.compilation.decorators import support_torch_compile
from vllm.config import VllmConfig
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.vocab_parallel_embedding import (
    ParallelLMHead,
    VocabParallelEmbedding,
)
from vllm.model_executor.model_loader.weight_utils import (
    default_weight_loader,
    maybe_remap_kv_scale_name,
)
from vllm.model_executor.models.deepseek_v2 import (
    DeepseekV2DecoderLayer,
    DeepseekV3ForCausalLM,
)

from .utils import AutoWeightsLoader, maybe_prefix, process_eagle_weight


@support_torch_compile
class DeepseekV2Model(nn.Module):
    """Draft-model backbone for EAGLE speculative decoding with DeepSeek.

    The target model's hidden states and the embeddings of the draft input
    tokens are each RMS-normalized, concatenated, and projected back to the
    hidden size by ``fc`` before being run through the draft decoder layers.
    """

    def __init__(
        self,
        *,
        vllm_config: VllmConfig,
        prefix: str = "",
        start_layer_id: int = 0,
    ) -> None:
        super().__init__()
        self.config = vllm_config.speculative_config.draft_model_config.hf_config
        quant_config = vllm_config.quant_config
        self.vocab_size = self.config.vocab_size

        self.embed_tokens = VocabParallelEmbedding(
            self.config.vocab_size,
            self.config.hidden_size,
            quant_config=quant_config,
            prefix=maybe_prefix(prefix, "embed_tokens"),
        )

        # Draft layers are indexed starting at start_layer_id so their
        # prefixes do not clash with the target model's layers.
        self.layers = nn.ModuleList(
            [
                DeepseekV2DecoderLayer(
                    vllm_config,
                    prefix=maybe_prefix(prefix, f"layers.{i + start_layer_id}"),
                    config=self.config,
                )
                for i in range(self.config.num_hidden_layers)
            ]
        )

        self.fc = nn.Linear(
            self.config.model.hidden_size * 2,
            self.config.model.hidden_size,
            bias=False,
        )

        self.enorm = RMSNorm(self.config.hidden_size, eps=self.config.rms_norm_eps)
        self.hnorm = RMSNorm(self.config.hidden_size, eps=self.config.rms_norm_eps)
        self.norm = RMSNorm(self.config.hidden_size, eps=self.config.rms_norm_eps)

    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
        return self.embed_tokens(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        input_embeds = self.embed_tokens(input_ids)
        # Fuse the normalized draft-token embeddings with the normalized
        # target hidden states, then project back to the hidden size.
        inputs = torch.cat(
            [self.enorm(input_embeds), self.hnorm(hidden_states)], dim=-1
        )
        hidden_states = self.fc(inputs)
        residual = None
        for layer in self.layers:
            hidden_states, residual = layer(
                positions,
                hidden_states,
                residual,
            )
        hidden_states, _ = self.norm(hidden_states, residual)
        return hidden_states, hidden_states

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
            ("fused_qkv_a_proj", "q_a_proj", 0),
            ("fused_qkv_a_proj", "kv_a_proj_with_mqa", 1),
        ]

        # Params for weights, fp8 weight scales, fp8 activation scales
        # (param_name, weight_name, expert_id, shard_id)
        expert_params_mapping = FusedMoE.make_expert_params_mapping(
            ckpt_gate_proj_name="gate_proj",
            ckpt_down_proj_name="down_proj",
            ckpt_up_proj_name="up_proj",
            num_experts=self.config.n_routed_experts,
        )

        params_dict = dict(self.named_parameters())
        loaded_params: set[str] = set()
        for name, loaded_weight in weights:
            if "rotary_emb.inv_freq" in name:
                continue

            for param_name, weight_name, shard_id in stacked_params_mapping:
                # Skip non-stacked layers and experts (experts handled below).
                if weight_name not in name:
                    continue
                # We have mlp.experts[0].gate_proj in the checkpoint.
                # Since we handle the experts below in expert_params_mapping,
                # we need to skip here BEFORE we update the name, otherwise
                # name will be updated to mlp.experts[0].gate_up_proj, which
                # will then be updated below in expert_params_mapping
                # for mlp.experts[0].gate_gate_up_proj, which breaks load.
                if ("mlp.experts." in name) and name not in params_dict:
                    continue
                name_mapped = name.replace(weight_name, param_name)

                # QKV fusion is optional: fall back to normal weight loading
                # if it is not enabled; if the fused parameter is present,
                # update the name to point at it.
                if (
                    param_name == "fused_qkv_a_proj"
                ) and name_mapped not in params_dict:
                    continue
                else:
                    name = name_mapped

                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                for mapping in expert_params_mapping:
                    param_name, weight_name, expert_id, shard_id = mapping
                    if weight_name not in name:
                        continue
                    name = name.replace(weight_name, param_name)
                    param = params_dict[name]
                    weight_loader = param.weight_loader
                    weight_loader(
                        param,
                        loaded_weight,
                        name,
                        shard_id=shard_id,
                        expert_id=expert_id,
                    )
                    break
                else:
                    # Skip loading extra bias for GPTQ models.
                    if name.endswith(".bias") and name not in params_dict:
                        continue

                    # Remapping the name of FP8 kv-scale.
                    name = maybe_remap_kv_scale_name(name, params_dict)
                    if name is None:
                        continue

                    param = params_dict[name]
                    weight_loader = getattr(
                        param, "weight_loader", default_weight_loader
                    )
                    weight_loader(param, loaded_weight)
            loaded_params.add(name)
        return loaded_params


class EagleDeepseekV3ForCausalLM(DeepseekV3ForCausalLM):
    """EAGLE draft model: a DeepseekV2Model trunk plus an LM head.

    The draft decoder layers are numbered after the target model's layers
    (via ``start_layer_id``) so the two sets of layer prefixes stay distinct.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        nn.Module.__init__(self)
        self.config = vllm_config.speculative_config.draft_model_config.hf_config
        quant_config = vllm_config.quant_config
        target_layer_num = vllm_config.model_config.get_num_layers(
            vllm_config.parallel_config
        )
        self.model = DeepseekV2Model(
            vllm_config=vllm_config, prefix="model", start_layer_id=target_layer_num
        )

        self.lm_head = ParallelLMHead(
            self.config.vocab_size,
            self.config.hidden_size,
            quant_config=quant_config,
            prefix=maybe_prefix(prefix, "lm_head"),
        )
        logit_scale = getattr(self.config, "logit_scale", 1.0)
        self.logits_processor = LogitsProcessor(
            self.config.vocab_size, scale=logit_scale
        )

        # Set MoE hyperparameters
        self.num_moe_layers = self.config.num_hidden_layers
        self.set_moe_parameters()

    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
        return self.model.embed_input_ids(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        inputs_embeds: torch.Tensor | None = None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        if inputs_embeds is not None:
            raise NotImplementedError(
                f"{type(self).__name__} does not support multimodal inputs yet."
            )
        return self.model(input_ids, positions, hidden_states)

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor | None:
        logits = self.logits_processor(self.lm_head, hidden_states)
        return logits

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
        def transform(inputs):
            # Checkpoint names other than lm_head live under the "model."
            # prefix expected by AutoWeightsLoader.
            name, loaded_weight = inputs
            if "lm_head" not in name:
                name = "model." + name
            process_eagle_weight(self, name)
            return name, loaded_weight

        loader = AutoWeightsLoader(
            self,
            skip_prefixes=None,
        )
        loader.load_weights(map(transform, weights))