292 lines
11 KiB
Python
292 lines
11 KiB
Python
"""Inference-only DeepSeek V3 Multi-Token Prediction (MTP) model."""
|
|
from typing import Iterable, List, Optional, Tuple
|
|
|
|
import torch
|
|
import torch.nn as nn
|
|
|
|
from vllm.attention.backends.abstract import AttentionMetadata
|
|
from vllm.config import CacheConfig, VllmConfig
|
|
from vllm.model_executor.layers.layernorm import RMSNorm
|
|
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
|
from vllm.model_executor.layers.quantization import QuantizationConfig
|
|
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
|
|
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
|
ParallelLMHead, VocabParallelEmbedding)
|
|
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
|
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
|
from vllm.sequence import IntermediateTensors
|
|
|
|
from .deepseek_v2 import DeepseekV2DecoderLayer
|
|
|
|
|
|
class SharedHead(nn.Module):
    """Output head of an MTP layer: a final RMSNorm plus an LM head.

    ``forward`` only applies the RMSNorm. The vocabulary projection
    ``self.head`` is not invoked here; callers hand it to a logits
    processor together with the normalised hidden states.

    Args:
        config: HF config providing ``hidden_size``, ``vocab_size`` and
            optionally ``rms_norm_eps`` (defaults to 1e-6).
        prefix: Accepted for interface compatibility; not used.
    """

    def __init__(self, config, prefix: str = ""):
        super().__init__()
        hidden_size = config.hidden_size
        rms_eps = getattr(config, "rms_norm_eps", 1e-6)
        self.norm = RMSNorm(hidden_size, eps=rms_eps)
        self.head = ParallelLMHead(config.vocab_size, hidden_size)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Return ``hidden_states`` after the head's RMSNorm."""
        normed = self.norm(hidden_states)
        return normed
|
|
|
|
|
|
class DeepSeekMultiTokenPredictorLayer(nn.Module):
    """Single MTP layer: enorm + hnorm + eh_proj + shared_head + mtp_block.

    The token embeddings and the incoming hidden states are each
    RMS-normalised, concatenated along the last dimension, projected back to
    ``hidden_size`` by ``eh_proj`` and passed through one decoder layer.
    """

    def __init__(
        self,
        config,
        layer_idx: int,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        rms_eps = getattr(config, "rms_norm_eps", 1e-6)
        # Normalises the embeddings of the input tokens.
        self.enorm = RMSNorm(config.hidden_size, eps=rms_eps)
        # Normalises the hidden states handed in by the caller.
        self.hnorm = RMSNorm(config.hidden_size, eps=rms_eps)
        # Projects concat([embeds, hidden]) (2 * hidden_size) -> hidden_size.
        self.eh_proj = nn.Linear(config.hidden_size * 2,
                                 config.hidden_size,
                                 bias=False)
        self.shared_head = SharedHead(config,
                                      prefix=f"{prefix}.shared_head")
        # Reuse DeepseekV2DecoderLayer (MLU hijack auto-applies).
        # NOTE(review): the block's prefix is built from ``layer_idx`` rather
        # than ``prefix``; the two coincide for the DeepSeekMTP caller.
        self.mtp_block = DeepseekV2DecoderLayer(
            config,
            prefix=f"model.layers.{layer_idx}",
            cache_config=cache_config,
            quant_config=quant_config,
        )

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        previous_hidden_states: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Run one MTP step.

        Args:
            input_ids: Not used by this layer; embeddings must be supplied
                through ``inputs_embeds`` (this layer owns no embedding
                table).
            positions: Token positions; embeddings of tokens at position 0
                are zeroed before normalisation.
            previous_hidden_states: Hidden states to fuse with the token
                embeddings (normalised by ``hnorm``).
            kv_caches: KV caches; only ``kv_caches[0]`` is used.
            attn_metadata: Attention metadata forwarded to the decoder block.
            inputs_embeds: Token embeddings. Required.

        Returns:
            Decoder-block output with its residual added back. The result is
            not normalised; ``shared_head`` applies the final norm.

        Raises:
            ValueError: If ``inputs_embeds`` is None.
        """
        # Explicit check instead of ``assert`` so it survives ``python -O``.
        if inputs_embeds is None:
            raise ValueError(
                "DeepSeekMultiTokenPredictorLayer.forward requires "
                "inputs_embeds; it has no embedding table of its own.")

        # Mask inputs at position 0. The per-token mask is broadcast over the
        # hidden dimension; torch.where returns a new tensor, so the
        # caller's inputs_embeds is not modified.
        inputs_embeds = torch.where(
            positions.unsqueeze(-1) == 0, 0, inputs_embeds)
        inputs_embeds = self.enorm(inputs_embeds)
        previous_hidden_states = self.hnorm(previous_hidden_states)

        hidden_states = self.eh_proj(
            torch.cat([inputs_embeds, previous_hidden_states], dim=-1))

        hidden_states, residual = self.mtp_block(
            positions, hidden_states, kv_caches[0], attn_metadata,
            residual=None)
        # The decoder layer returns (output, residual) separately; fold them
        # back together before the shared head.
        hidden_states = residual + hidden_states
        return hidden_states
|
|
|
|
|
|
def get_spec_layer_idx_from_weight_name(config, weight_name: str):
    """Return the MTP (speculative) layer index that ``weight_name`` belongs to.

    MTP layers are numbered directly after the regular decoder layers, i.e.
    ``num_hidden_layers`` up to
    ``num_hidden_layers + num_nextn_predict_layers - 1``. A weight belongs to
    layer ``k`` when its name starts with ``"model.layers.{k}."``.

    Returns:
        The layer index, or None if the weight is not in an MTP layer or the
        config declares no MTP layers (attribute missing, falsy or <= 0).
    """
    num_nextn = getattr(config, "num_nextn_predict_layers", 0)
    if not (num_nextn and num_nextn > 0):
        return None

    first_mtp_idx = config.num_hidden_layers
    for candidate in range(first_mtp_idx, first_mtp_idx + num_nextn):
        # Trailing dot keeps e.g. layer 61 from matching "model.layers.610.".
        if weight_name.startswith(f"model.layers.{candidate}."):
            return candidate
    return None
|
|
|
|
|
|
def _rewrite_spec_layer_name(config, spec_layer: int, name: str) -> str:
    """Map a checkpoint weight name of MTP layer ``spec_layer`` onto the
    module layout used here.

    * Names containing ``embed_tokens`` are shared weights and are hoisted to
      the top level (``model.layers.{spec_layer}.`` -> ``model.``).
    * Names containing ``enorm``, ``hnorm``, ``eh_proj`` or ``shared_head``
      are MTP-specific and are returned unchanged.
    * Everything else belongs to the transformer block and gains a
      ``mtp_block.`` segment right after the layer prefix.

    ``config`` is not used; it is kept for signature compatibility.
    """
    layer_prefix = f"model.layers.{spec_layer}."

    # Shared weight -> top level.
    if "embed_tokens" in name:
        return name.replace(layer_prefix, "model.")

    # MTP-specific weights keep their original name.
    mtp_only_names = ("enorm", "hnorm", "eh_proj", "shared_head")
    if any(part in name for part in mtp_only_names):
        return name

    # Transformer block weights -> add .mtp_block prefix.
    return name.replace(layer_prefix, layer_prefix + "mtp_block.")
|
|
|
|
|
|
class DeepSeekMTP(nn.Module):
    """DeepSeek V3 Multi-Token Prediction draft model.

    Uses hidden states from the target model to predict the next token
    via an additional decoder layer. One ``DeepSeekMultiTokenPredictorLayer``
    is built per ``num_nextn_predict_layers``, keyed by its absolute layer
    index (starting at ``num_hidden_layers``). Only the first one is used by
    ``forward`` and ``compute_logits``.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        config = vllm_config.model_config.hf_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config
        self.config = config

        # Shared with the target model; loaded from the first MTP layer's
        # checkpoint entry (see load_weights).
        self.embed_tokens = VocabParallelEmbedding(
            config.vocab_size,
            config.hidden_size,
        )

        # MTP layers are numbered after the target model's decoder layers.
        self.mtp_start_layer_idx = config.num_hidden_layers
        num_mtp = getattr(config, "num_nextn_predict_layers", 1)
        # Fail early: with no MTP layers, forward() could only fail later
        # with an opaque KeyError.
        if num_mtp is None or num_mtp < 1:
            raise ValueError(
                "DeepSeekMTP requires num_nextn_predict_layers >= 1, got "
                f"{num_mtp!r}")

        self.layers = nn.ModuleDict()
        for i in range(num_mtp):
            layer_idx = self.mtp_start_layer_idx + i
            self.layers[str(layer_idx)] = DeepSeekMultiTokenPredictorLayer(
                config=config,
                layer_idx=layer_idx,
                cache_config=cache_config,
                quant_config=quant_config,
                prefix=f"model.layers.{layer_idx}",
            )

        self.logits_processor = LogitsProcessor(config.vocab_size)
        self.sampler = get_sampler()

    def _first_mtp_layer(self) -> "DeepSeekMultiTokenPredictorLayer":
        """Return the MTP layer at index ``mtp_start_layer_idx``."""
        return self.layers[str(self.mtp_start_layer_idx)]

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        previous_hidden_states: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
    ) -> torch.Tensor:
        """Embed ``input_ids`` and run the first MTP layer.

        ``intermediate_tensors`` is accepted for interface compatibility and
        is not used.
        """
        inputs_embeds = self.embed_tokens(input_ids)
        # Use the first MTP layer (DeepSeek V3 only has 1)
        return self._first_mtp_layer()(
            input_ids, positions, previous_hidden_states,
            kv_caches, attn_metadata, inputs_embeds)

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        """Normalise with the first layer's shared head and project to logits."""
        shared_head = self._first_mtp_layer().shared_head
        normed = shared_head(hidden_states)
        return self.logits_processor(
            shared_head.head, normed, sampling_metadata)

    def sample(
        self,
        logits: Optional[torch.Tensor],
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        """Sample next tokens from ``logits`` with the configured sampler."""
        return self.sampler(logits, sampling_metadata)

    def _pack_mlu_moe_params(self) -> None:
        """Call ``pack_params()`` on every MLU ``SparseMoeMlp`` submodule.

        The MLU plugin is optional; when it cannot be imported this is a
        no-op. Only the import is guarded, so errors raised by
        ``pack_params()`` itself are no longer swallowed.
        """
        try:
            from vllm_mlu.model_executor.layers.sparse_moe_mlp import (
                SparseMoeMlp)
        except ImportError:
            return
        for module in self.modules():
            if isinstance(module, SparseMoeMlp):
                module.pack_params()

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        """Load the MTP-layer weights out of a full checkpoint stream.

        Only weights whose name falls inside an MTP layer are used. Names are
        rewritten to this module's layout (``mtp_block`` segment, hoisted
        ``embed_tokens``). Weights with no matching parameter are skipped.
        """
        # MLU SparseMoeMlp needs pack_params() before loading
        self._pack_mlu_moe_params()

        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
        ]
        # Built after packing, in case packing changes the parameter set.
        params_dict = dict(self.named_parameters())

        for name, loaded_weight in weights:
            if "rotary_emb.inv_freq" in name:
                continue

            # Only load MTP layer weights
            spec_layer = get_spec_layer_idx_from_weight_name(
                self.config, name)
            if spec_layer is None:
                continue

            # Rewrite weight name for MTP structure
            name = _rewrite_spec_layer_name(self.config, spec_layer, name)

            # Only load shared weights (embed_tokens) from first
            # MTP layer, per DeepSeek V3 Technical Report
            if (spec_layer != self.mtp_start_layer_idx
                    and ".layers" not in name):
                continue

            # Strip "model." prefix since DeepSeekMTP holds
            # embed_tokens and layers directly (no .model wrapper)
            if name.startswith("model."):
                name = name[len("model."):]

            self._load_single_weight(
                name, loaded_weight, stacked_params_mapping, params_dict)

    def _load_single_weight(
        self,
        name: str,
        loaded_weight: torch.Tensor,
        stacked_params_mapping: List[Tuple[str, str, int]],
        params_dict: dict,
    ):
        """Copy one (already renamed) checkpoint tensor into its parameter.

        A tensor whose name contains a shard name from
        ``stacked_params_mapping`` is loaded into the fused parameter via its
        ``weight_loader`` with the shard id. Any other tensor uses the
        parameter's ``weight_loader`` or ``default_weight_loader``.

        Tensors whose (mapped) name has no parameter are skipped silently.
        This covers, for example, expert or bias tensors that this build
        does not instantiate.
        """
        for param_name, shard_name, shard_id in stacked_params_mapping:
            if shard_name not in name:
                continue
            param = params_dict.get(name.replace(shard_name, param_name))
            if param is None:
                return
            param.weight_loader(param, loaded_weight, shard_id)
            return

        # Non-stacked weights
        param = params_dict.get(name)
        if param is None:
            return
        weight_loader = getattr(param, "weight_loader",
                                default_weight_loader)
        weight_loader(param, loaded_weight)
|