Files
enginex-mlu370-vllm/vllm-v0.6.2/vllm/model_executor/models/deepseek_mtp.py
2026-02-11 17:47:15 +08:00

292 lines
11 KiB
Python

"""Inference-only DeepSeek V3 Multi-Token Prediction (MTP) model."""
from typing import Iterable, List, Optional, Tuple
import torch
import torch.nn as nn
from vllm.attention.backends.abstract import AttentionMetadata
from vllm.config import CacheConfig, VllmConfig
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors
from .deepseek_v2 import DeepseekV2DecoderLayer
class SharedHead(nn.Module):
    """Output head of an MTP layer: an RMSNorm plus a vocab-parallel LM head.

    ``forward`` applies only the norm. The ``head`` projection is not invoked
    here; ``DeepSeekMTP.compute_logits`` passes it to the logits processor.
    ``prefix`` is accepted but currently unused.
    """

    def __init__(self, config, prefix: str = ""):
        super().__init__()
        norm_eps = getattr(config, "rms_norm_eps", 1e-6)
        self.norm = RMSNorm(config.hidden_size, eps=norm_eps)
        self.head = ParallelLMHead(config.vocab_size, config.hidden_size)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Return ``hidden_states`` after the shared RMSNorm."""
        normalized = self.norm(hidden_states)
        return normalized
class DeepSeekMultiTokenPredictorLayer(nn.Module):
    """A single DeepSeek MTP layer.

    Pipeline in ``forward``:
      1. normalise the token embeddings (``enorm``) and the incoming hidden
         states (``hnorm``);
      2. concatenate them on the last dim and project back to
         ``hidden_size`` with the bias-free ``eh_proj``;
      3. run one ``DeepseekV2DecoderLayer`` (``mtp_block``).
    ``shared_head`` is owned here but not applied in ``forward``.
    """

    def __init__(
        self,
        config,
        layer_idx: int,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        hidden_size = config.hidden_size
        norm_eps = getattr(config, "rms_norm_eps", 1e-6)

        self.enorm = RMSNorm(hidden_size, eps=norm_eps)
        self.hnorm = RMSNorm(hidden_size, eps=norm_eps)
        # Maps [embeds ; hidden] (2 * hidden_size features) to hidden_size.
        self.eh_proj = nn.Linear(2 * hidden_size, hidden_size, bias=False)
        self.shared_head = SharedHead(config, prefix=f"{prefix}.shared_head")
        # Built with a main-model style prefix ("model.layers.<layer_idx>")
        # rather than "<prefix>.mtp_block"; presumably DeepseekV2DecoderLayer
        # derives its layer index from the prefix tail — confirm before
        # changing. The MLU hijack of DeepseekV2DecoderLayer applies
        # automatically.
        self.mtp_block = DeepseekV2DecoderLayer(
            config,
            prefix=f"model.layers.{layer_idx}",
            cache_config=cache_config,
            quant_config=quant_config,
        )

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        previous_hidden_states: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Fuse embeddings with previous hidden states and run ``mtp_block``.

        ``input_ids`` is unused; ``inputs_embeds`` must be supplied. Only
        ``kv_caches[0]`` is passed to the decoder layer. Returns the decoder
        output with its residual added back.
        """
        assert inputs_embeds is not None
        # Zero the embeddings of tokens sitting at position 0.
        at_first_position = (positions == 0).unsqueeze(-1)
        masked_embeds = torch.where(at_first_position, 0, inputs_embeds)

        fused = self.eh_proj(
            torch.cat(
                [self.enorm(masked_embeds),
                 self.hnorm(previous_hidden_states)],
                dim=-1,
            ))

        block_out, residual = self.mtp_block(
            positions, fused, kv_caches[0], attn_metadata, residual=None)
        return residual + block_out
def get_spec_layer_idx_from_weight_name(config, weight_name: str):
    """Return the MTP (speculative) layer index ``weight_name`` belongs to.

    MTP layers are numbered right after the main model's layers, i.e.
    ``num_hidden_layers`` up to
    ``num_hidden_layers + num_nextn_predict_layers - 1``. A weight belongs to
    layer ``idx`` when its name starts with ``"model.layers.<idx>."``.

    Returns:
        The matching layer index, or ``None`` when the weight is not part of
        an MTP layer or the config declares no MTP layers (a missing or
        falsy ``num_nextn_predict_layers`` counts as zero).
    """
    num_nextn = getattr(config, "num_nextn_predict_layers", 0)
    if not (num_nextn and num_nextn > 0):
        return None

    first_mtp = config.num_hidden_layers
    candidates = range(first_mtp, first_mtp + num_nextn)
    return next(
        (idx for idx in candidates
         if weight_name.startswith(f"model.layers.{idx}.")),
        None,
    )
def _rewrite_spec_layer_name(config, spec_layer: int, name: str) -> str:
    """Map a checkpoint name of MTP layer ``spec_layer`` onto the MTP module.

    The first of ``embed_tokens``, ``enorm``, ``hnorm``, ``eh_proj``,
    ``shared_head`` found as a substring of ``name`` decides the rewrite:

    * none found -> transformer-block weight: ``model.layers.<i>.`` becomes
      ``model.layers.<i>.mtp_block.``;
    * ``embed_tokens`` -> shared weight: ``model.layers.<i>.`` becomes
      ``model.``;
    * any other marker -> MTP-specific weight, returned unchanged.

    ``config`` is unused.
    """
    layer_prefix = f"model.layers.{spec_layer}."
    mtp_markers = ("embed_tokens", "enorm", "hnorm", "eh_proj", "shared_head")
    marker = next((m for m in mtp_markers if m in name), None)

    if marker is None:
        return name.replace(layer_prefix, f"{layer_prefix}mtp_block.")
    if marker == "embed_tokens":
        return name.replace(layer_prefix, "model.")
    return name
class DeepSeekMTP(nn.Module):
    """DeepSeek V3 Multi-Token Prediction (MTP) draft model.

    Uses the hidden states from the target model (``previous_hidden_states``)
    together with the next input tokens to predict one more token via a
    single additional decoder layer.

    Attributes:
        embed_tokens: Token embedding, loaded from the first MTP layer's
            ``embed_tokens`` checkpoint entry.
        layers: ``ModuleDict`` of ``DeepSeekMultiTokenPredictorLayer`` keyed
            by the absolute layer index (as a string), starting at
            ``config.num_hidden_layers``.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        config = vllm_config.model_config.hf_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config
        self.config = config

        self.embed_tokens = VocabParallelEmbedding(
            config.vocab_size,
            config.hidden_size,
        )

        # MTP layers are numbered directly after the main model's layers.
        self.mtp_start_layer_idx = config.num_hidden_layers
        # NOTE(review): a config without ``num_nextn_predict_layers`` builds
        # one layer here, while get_spec_layer_idx_from_weight_name treats the
        # missing field as 0 and matches no weights; load_weights warns when
        # nothing gets loaded.
        num_mtp = getattr(config, "num_nextn_predict_layers", 1)
        self.layers = nn.ModuleDict()
        for i in range(num_mtp):
            layer_idx = self.mtp_start_layer_idx + i
            self.layers[str(layer_idx)] = DeepSeekMultiTokenPredictorLayer(
                config=config,
                layer_idx=layer_idx,
                cache_config=cache_config,
                quant_config=quant_config,
                prefix=f"model.layers.{layer_idx}",
            )

        self.logits_processor = LogitsProcessor(config.vocab_size)
        self.sampler = get_sampler()

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        previous_hidden_states: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
    ) -> torch.Tensor:
        """Embed ``input_ids`` and run the first MTP layer on them.

        ``intermediate_tensors`` is unused here.
        """
        inputs_embeds = self.embed_tokens(input_ids)
        # Use the first MTP layer (DeepSeek V3 only has 1).
        layer = self.layers[str(self.mtp_start_layer_idx)]
        hidden_states = layer(
            input_ids, positions, previous_hidden_states,
            kv_caches, attn_metadata, inputs_embeds)
        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        """Normalise with the first layer's shared head and compute logits.

        The first MTP layer's ``shared_head.head`` is the LM head passed to
        the logits processor.
        """
        layer = self.layers[str(self.mtp_start_layer_idx)]
        normed = layer.shared_head(hidden_states)
        logits = self.logits_processor(
            layer.shared_head.head, normed, sampling_metadata)
        return logits

    def sample(
        self,
        logits: Optional[torch.Tensor],
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        """Sample next tokens from ``logits`` with the configured sampler."""
        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        """Load the MTP weights out of a full DeepSeek checkpoint stream.

        Entries that do not belong to an MTP layer (per
        ``get_spec_layer_idx_from_weight_name``) and ``rotary_emb.inv_freq``
        buffers are ignored. Matching names are rewritten onto this module's
        parameter names:

        * ``model.layers.<i>.<block weight>`` -> ``layers.<i>.mtp_block.<...>``
        * ``model.layers.<i>.embed_tokens.*`` -> ``embed_tokens.*``
          (taken from the first MTP layer only)
        * other MTP weights (``enorm``, ``hnorm``, ``eh_proj``,
          ``shared_head``) -> ``layers.<i>.<...>``

        Weights without a matching parameter are skipped. A warning is logged
        for skipped names outside the expected-optional set (see
        ``_is_optional_weight``), and when no MTP weight was loaded at all.
        """
        import logging
        logger = logging.getLogger(__name__)

        # MLU SparseMoeMlp needs pack_params() before loading.
        self._pack_moe_params()

        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
        ]
        params_dict = dict(self.named_parameters())
        num_loaded = 0
        unexpected_skips = []
        for name, loaded_weight in weights:
            if "rotary_emb.inv_freq" in name:
                continue
            # Only load MTP layer weights.
            spec_layer = get_spec_layer_idx_from_weight_name(
                self.config, name)
            if spec_layer is None:
                continue
            # Rewrite weight name for the MTP module structure.
            name = _rewrite_spec_layer_name(self.config, spec_layer, name)
            # Only load shared weights (embed_tokens) from the first
            # MTP layer, per the DeepSeek V3 Technical Report.
            if (spec_layer != self.mtp_start_layer_idx
                    and ".layers" not in name):
                continue
            # DeepSeekMTP holds embed_tokens and layers directly (no
            # ``.model`` wrapper), so strip the "model." prefix.
            if name.startswith("model."):
                name = name[len("model."):]
            if self._load_single_weight(
                    name, loaded_weight, stacked_params_mapping,
                    params_dict):
                num_loaded += 1
            elif not self._is_optional_weight(name):
                unexpected_skips.append(name)

        if unexpected_skips:
            logger.warning(
                "DeepSeekMTP: %d MTP checkpoint weight(s) had no matching "
                "parameter and were skipped (first few: %s)",
                len(unexpected_skips), ", ".join(unexpected_skips[:10]))
        if num_loaded == 0:
            logger.warning(
                "DeepSeekMTP: no MTP weights were loaded; check "
                "num_nextn_predict_layers and the checkpoint layer names.")

    def _pack_moe_params(self) -> None:
        """Call ``pack_params()`` on every MLU ``SparseMoeMlp`` sub-module.

        No-op when ``vllm_mlu`` cannot be imported. Only the import is
        guarded, so errors raised by ``pack_params()`` itself propagate.
        """
        try:
            from vllm_mlu.model_executor.layers.sparse_moe_mlp import (
                SparseMoeMlp)
        except ImportError:
            return
        for module in self.modules():
            if isinstance(module, SparseMoeMlp):
                module.pack_params()

    @staticmethod
    def _is_optional_weight(name: str) -> bool:
        """Whether a missing parameter for ``name`` should not be reported.

        Covers expert, shared-expert, shared-expert-gate and
        ``e_score_correction_bias`` tensors, plus ``.bias`` tensors. These
        are the names the loader explicitly tolerates being absent from
        ``params_dict``.
        """
        return (name.endswith(".bias")
                or "mlp.experts." in name
                or "mlp.shared_experts." in name
                or "mlp.shared_expert_gate." in name
                or "e_score_correction_bias" in name)

    def _load_single_weight(
        self,
        name: str,
        loaded_weight: torch.Tensor,
        stacked_params_mapping: List[Tuple[str, str, int]],
        params_dict: dict,
    ) -> bool:
        """Load one (already renamed) checkpoint tensor into its parameter.

        Names containing a stacked shard name (e.g. ``gate_proj``) are mapped
        to the fused parameter (e.g. ``gate_up_proj``) and loaded with that
        parameter's ``weight_loader`` and the shard id. Other names are
        loaded with the parameter's ``weight_loader``, falling back to
        ``default_weight_loader``.

        Returns:
            True if the tensor was loaded, False if no matching parameter
            exists in ``params_dict`` (the tensor is then skipped).
        """
        for param_name, weight_name, shard_id in stacked_params_mapping:
            if weight_name not in name:
                continue
            stacked_name = name.replace(weight_name, param_name)
            param = params_dict.get(stacked_name)
            if param is None:
                return False
            param.weight_loader(param, loaded_weight, shard_id)
            return True

        # Non-stacked weights.
        param = params_dict.get(name)
        if param is None:
            return False
        weight_loader = getattr(param, "weight_loader",
                                default_weight_loader)
        weight_loader(param, loaded_weight)
        return True