diff --git a/vllm-v0.6.2/vllm/config.py b/vllm-v0.6.2/vllm/config.py index fdac329..dddb754 100644 --- a/vllm-v0.6.2/vllm/config.py +++ b/vllm-v0.6.2/vllm/config.py @@ -1403,6 +1403,18 @@ class SpeculativeConfig: draft_hf_config = draft_model_config.hf_config + # Detect DeepSeek V3 MTP: same model path with + # num_nextn_predict_layers > 0 + num_nextn = getattr(draft_hf_config, + "num_nextn_predict_layers", 0) + if (num_nextn and num_nextn > 0 + and getattr(draft_hf_config, "model_type", "") + in ("deepseek_v3",)): + draft_hf_config.model_type = "deepseek_mtp" + draft_hf_config.architectures = ["DeepSeekMTPModel"] + if num_speculative_tokens is None: + num_speculative_tokens = num_nextn + if (num_speculative_tokens is not None and hasattr(draft_hf_config, "num_lookahead_tokens")): draft_hf_config.num_lookahead_tokens = num_speculative_tokens @@ -1421,7 +1433,7 @@ class SpeculativeConfig: f"{num_speculative_tokens=} was provided.") if enable_chunked_prefill and draft_hf_config.model_type in ( - "medusa", "mlp_speculator", "eagle"): + "medusa", "mlp_speculator", "eagle", "deepseek_mtp"): raise ValueError( "Chunked prefill and hidden-state based draft models are " "not compatible.") diff --git a/vllm-v0.6.2/vllm/model_executor/layers/layernorm.py b/vllm-v0.6.2/vllm/model_executor/layers/layernorm.py index ec72499..914bb11 100644 --- a/vllm-v0.6.2/vllm/model_executor/layers/layernorm.py +++ b/vllm-v0.6.2/vllm/model_executor/layers/layernorm.py @@ -143,11 +143,14 @@ class RMSNorm(CustomOp): from vllm import _mlu_ops as mlu_ops x = x.view(-1, self.weight.data.shape[0]) + weight = self.weight.data + if weight.dtype != x.dtype: + weight = weight.to(x.dtype) if residual is not None: residual = residual.view(-1, self.weight.data.shape[0]) - return mlu_ops.fused_rms_norm(x, residual, self.weight.data, None, None, self.variance_epsilon, True) + return mlu_ops.fused_rms_norm(x, residual, weight, None, None, self.variance_epsilon, True) else: - return mlu_ops.fused_rms_norm(x, residual, self.weight.data, None, None, self.variance_epsilon, False) + return mlu_ops.fused_rms_norm(x, residual, weight, None, None, self.variance_epsilon, False) def extra_repr(self) -> str: s = f"hidden_size={self.weight.data.size(0)}" diff --git a/vllm-v0.6.2/vllm/model_executor/models/deepseek_mtp.py b/vllm-v0.6.2/vllm/model_executor/models/deepseek_mtp.py new file mode 100644 index 0000000..50e4ad9 --- /dev/null +++ b/vllm-v0.6.2/vllm/model_executor/models/deepseek_mtp.py @@ -0,0 +1,288 @@ +"""Inference-only DeepSeek V3 Multi-Token Prediction (MTP) model.""" +import re +from typing import Iterable, List, Optional, Tuple + +import torch +import torch.nn as nn + +from vllm.attention.backends.abstract import AttentionMetadata +from vllm.config import CacheConfig, VllmConfig +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler +from vllm.model_executor.layers.vocab_parallel_embedding import ( + ParallelLMHead, VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.sequence import IntermediateTensors + +from .deepseek_v2 import DeepseekV2DecoderLayer +from .utils import maybe_prefix + + +class SharedHead(nn.Module): + """Shared head for MTP: norm + lm_head.""" + + def __init__(self, config, prefix: str = ""): + super().__init__() + self.norm = RMSNorm(config.hidden_size, + eps=getattr(config, "rms_norm_eps", 1e-6)) + self.head = ParallelLMHead( + config.vocab_size, + config.hidden_size, + ) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + return self.norm(hidden_states) + + +class DeepSeekMultiTokenPredictorLayer(nn.Module): + """Single MTP layer: enorm + hnorm + eh_proj + shared_head + mtp_block.""" + + def __init__( + self, + config, + layer_idx: int, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): + super().__init__() + self.enorm = RMSNorm(config.hidden_size, + eps=getattr(config, "rms_norm_eps", 1e-6)) + self.hnorm = RMSNorm(config.hidden_size, + eps=getattr(config, "rms_norm_eps", 1e-6)) + self.eh_proj = nn.Linear(config.hidden_size * 2, + config.hidden_size, + bias=False) + self.shared_head = SharedHead(config, + prefix=f"{prefix}.shared_head") + # Reuse DeepseekV2DecoderLayer (MLU hijack auto-applies) + self.mtp_block = DeepseekV2DecoderLayer( + config, + prefix=f"model.layers.{layer_idx}", + cache_config=cache_config, + quant_config=quant_config, + ) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + previous_hidden_states: torch.Tensor, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, + inputs_embeds: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + assert inputs_embeds is not None + # Mask inputs at position 0 + inputs_embeds = torch.where( + positions.unsqueeze(-1) == 0, 0, inputs_embeds) + inputs_embeds = self.enorm(inputs_embeds) + previous_hidden_states = self.hnorm(previous_hidden_states) + + hidden_states = self.eh_proj( + torch.cat([inputs_embeds, previous_hidden_states], dim=-1)) + + hidden_states, residual = self.mtp_block( + positions, hidden_states, kv_caches[0], attn_metadata, + residual=None) + hidden_states = residual + hidden_states + return hidden_states + + +def get_spec_layer_idx_from_weight_name(config, weight_name: str): + """Check if weight belongs to a speculative (MTP) layer. + Returns the layer index if so, None otherwise.""" + num_nextn = getattr(config, "num_nextn_predict_layers", 0) + if num_nextn and num_nextn > 0: + layer_idx = config.num_hidden_layers + for i in range(num_nextn): + if weight_name.startswith(f"model.layers.{layer_idx + i}."): + return layer_idx + i + return None + + +def _rewrite_spec_layer_name(config, spec_layer: int, name: str) -> str: + """Rewrite weight name for MTP layer. + Add .mtp_block for transformer block weights, + rename shared weights to top level.""" + spec_layer_weight_names = [ + "embed_tokens", "enorm", "hnorm", "eh_proj", "shared_head", + ] + shared_weight_names = ["embed_tokens"] + spec_layer_weight = False + shared_weight = False + for weight_name in spec_layer_weight_names: + if weight_name in name: + spec_layer_weight = True + if weight_name in shared_weight_names: + shared_weight = True + break + if not spec_layer_weight: + # Transformer block weights -> add .mtp_block prefix + name = name.replace( + f"model.layers.{spec_layer}.", + f"model.layers.{spec_layer}.mtp_block.") + elif shared_weight: + # Shared weights -> top level + name = name.replace(f"model.layers.{spec_layer}.", "model.") + return name + + +class DeepSeekMTP(nn.Module): + """DeepSeek V3 Multi-Token Prediction draft model. + Uses hidden states from the target model to predict the next token + via a single additional decoder layer.""" + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + config = vllm_config.model_config.hf_config + cache_config = vllm_config.cache_config + quant_config = vllm_config.quant_config + self.config = config + + self.embed_tokens = VocabParallelEmbedding( + config.vocab_size, + config.hidden_size, + ) + + self.mtp_start_layer_idx = config.num_hidden_layers + num_mtp = getattr(config, "num_nextn_predict_layers", 1) + + self.layers = nn.ModuleDict() + for i in range(num_mtp): + layer_idx = self.mtp_start_layer_idx + i + self.layers[str(layer_idx)] = DeepSeekMultiTokenPredictorLayer( + config=config, + layer_idx=layer_idx, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"model.layers.{layer_idx}", + ) + + self.logits_processor = LogitsProcessor(config.vocab_size) + self.sampler = get_sampler() + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, + previous_hidden_states: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + ) -> torch.Tensor: + inputs_embeds = self.embed_tokens(input_ids) + # Use the first MTP layer (DeepSeek V3 only has 1) + layer = self.layers[str(self.mtp_start_layer_idx)] + hidden_states = layer( + input_ids, positions, previous_hidden_states, + kv_caches, attn_metadata, inputs_embeds) + return hidden_states + + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[torch.Tensor]: + layer = self.layers[str(self.mtp_start_layer_idx)] + normed = layer.shared_head(hidden_states) + logits = self.logits_processor( + layer.shared_head.head, normed, sampling_metadata) + return logits + + def sample( + self, + logits: Optional[torch.Tensor], + sampling_metadata: SamplingMetadata, + ) -> Optional[SamplerOutput]: + next_tokens = self.sampler(logits, sampling_metadata) + return next_tokens + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + # MLU SparseMoeMlp needs pack_params() before loading + try: + from vllm_mlu.model_executor.layers.sparse_moe_mlp import ( + SparseMoeMlp) + for name, m in self.named_modules(): + if isinstance(m, SparseMoeMlp): + m.pack_params() + except ImportError: + pass + + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("gate_up_proj", "gate_proj", 0), + ("gate_up_proj", "up_proj", 1), + ] + + params_dict = dict(self.named_parameters()) + + for name, loaded_weight in weights: + if "rotary_emb.inv_freq" in name: + continue + + # Only load MTP layer weights + spec_layer = get_spec_layer_idx_from_weight_name( + self.config, name) + if spec_layer is None: + continue + + # Rewrite weight name for MTP structure + name = _rewrite_spec_layer_name( + self.config, spec_layer, name) + + # Only load shared weights (embed_tokens) from first + # MTP layer, per DeepSeek V3 Technical Report + if (spec_layer != self.mtp_start_layer_idx + and ".layers" not in name): + continue + + self._load_single_weight( + name, loaded_weight, stacked_params_mapping, + params_dict) + + def _load_single_weight( + self, + name: str, + loaded_weight: torch.Tensor, + stacked_params_mapping: List[Tuple[str, str, int]], + params_dict: dict, + ): + for (param_name, weight_name, shard_id) in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + # Skip expert weights not in params_dict + if (("mlp.experts." in name + or "mlp.shared_experts." in name + or "mlp.shared_expert_gate." in name + or "e_score_correction_bias" in name) + and name not in params_dict): + return + if name.endswith(".bias") and name not in params_dict: + return + if name not in params_dict: + return + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + return + + # Non-stacked weights + if name.endswith(".bias") and name not in params_dict: + return + if (("mlp.experts." in name + or "mlp.shared_experts." in name + or "mlp.shared_expert_gate." in name + or "e_score_correction_bias" in name) + and name not in params_dict): + return + if name not in params_dict: + return + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) diff --git a/vllm-v0.6.2/vllm/model_executor/models/registry.py b/vllm-v0.6.2/vllm/model_executor/models/registry.py index dde8f02..399567c 100644 --- a/vllm-v0.6.2/vllm/model_executor/models/registry.py +++ b/vllm-v0.6.2/vllm/model_executor/models/registry.py @@ -166,6 +166,7 @@ _SPECULATIVE_DECODING_MODELS = { "EAGLEModel": ("eagle", "EAGLE"), "MedusaModel": ("medusa", "Medusa"), "MLPSpeculatorPreTrainedModel": ("mlp_speculator", "MLPSpeculator"), + "DeepSeekMTPModel": ("deepseek_mtp", "DeepSeekMTP"), } # Transformers backend models - wrapper classes for custom HuggingFace models diff --git a/vllm-v0.6.2/vllm/worker/mlu_worker.py b/vllm-v0.6.2/vllm/worker/mlu_worker.py index 78e36db..53fdf1c 100644 --- a/vllm-v0.6.2/vllm/worker/mlu_worker.py +++ b/vllm-v0.6.2/vllm/worker/mlu_worker.py @@ -59,12 +59,19 @@ class MLUWorker(Worker): # mlp_speculator speculative_config = self.speculative_config model_config = self.model_config - speculative_args = {} if speculative_config is None \ - or (speculative_config.draft_model_config.model == - model_config.model) \ - or (speculative_config.draft_model_config.hf_config.model_type - not in ["medusa", "mlp_speculator", "eagle"]) \ - else {"return_hidden_states": True} + is_mtp = (speculative_config is not None + and getattr( + speculative_config.draft_model_config.hf_config, + "model_type", None) == "deepseek_mtp") + speculative_args = ( + {"return_hidden_states": True} if is_mtp else + ({} if speculative_config is None + or (speculative_config.draft_model_config.model == + model_config.model) + or (speculative_config.draft_model_config.hf_config.model_type + not in ["medusa", "mlp_speculator", "eagle"]) + else {"return_hidden_states": True}) + ) ModelRunnerClass: Type[MLUModelRunnerBase] = MLUModelRunner if model_runner_cls is not None: