add deepseekv3 and llama4

This commit is contained in:
Chranos
2026-02-11 15:24:13 +08:00
parent 2ad23aa8da
commit c584139543
5 changed files with 320 additions and 9 deletions

View File

@@ -1403,6 +1403,18 @@ class SpeculativeConfig:
draft_hf_config = draft_model_config.hf_config
# Detect DeepSeek V3 MTP: same model path with
# num_nextn_predict_layers > 0
num_nextn = getattr(draft_hf_config,
"num_nextn_predict_layers", 0)
if (num_nextn and num_nextn > 0
and getattr(draft_hf_config, "model_type", "")
in ("deepseek_v3",)):
draft_hf_config.model_type = "deepseek_mtp"
draft_hf_config.architectures = ["DeepSeekMTPModel"]
if num_speculative_tokens is None:
num_speculative_tokens = num_nextn
if (num_speculative_tokens is not None
and hasattr(draft_hf_config, "num_lookahead_tokens")):
draft_hf_config.num_lookahead_tokens = num_speculative_tokens
@@ -1421,7 +1433,7 @@ class SpeculativeConfig:
f"{num_speculative_tokens=} was provided.")
if enable_chunked_prefill and draft_hf_config.model_type in (
"medusa", "mlp_speculator", "eagle"):
"medusa", "mlp_speculator", "eagle", "deepseek_mtp"):
raise ValueError(
"Chunked prefill and hidden-state based draft models are "
"not compatible.")

View File

@@ -143,11 +143,14 @@ class RMSNorm(CustomOp):
from vllm import _mlu_ops as mlu_ops
x = x.view(-1, self.weight.data.shape[0])
weight = self.weight.data
if weight.dtype != x.dtype:
weight = weight.to(x.dtype)
if residual is not None:
residual = residual.view(-1, self.weight.data.shape[0])
return mlu_ops.fused_rms_norm(x, residual, self.weight.data, None, None, self.variance_epsilon, True)
return mlu_ops.fused_rms_norm(x, residual, weight, None, None, self.variance_epsilon, True)
else:
return mlu_ops.fused_rms_norm(x, residual, self.weight.data, None, None, self.variance_epsilon, False)
return mlu_ops.fused_rms_norm(x, residual, weight, None, None, self.variance_epsilon, False)
def extra_repr(self) -> str:
s = f"hidden_size={self.weight.data.size(0)}"

View File

@@ -0,0 +1,288 @@
"""Inference-only DeepSeek V3 Multi-Token Prediction (MTP) model."""
import re
from typing import Iterable, List, Optional, Tuple
import torch
import torch.nn as nn
from vllm.attention.backends.abstract import AttentionMetadata
from vllm.config import CacheConfig, VllmConfig
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors
from .deepseek_v2 import DeepseekV2DecoderLayer
from .utils import maybe_prefix
class SharedHead(nn.Module):
"""Shared head for MTP: norm + lm_head."""
def __init__(self, config, prefix: str = ""):
super().__init__()
self.norm = RMSNorm(config.hidden_size,
eps=getattr(config, "rms_norm_eps", 1e-6))
self.head = ParallelLMHead(
config.vocab_size,
config.hidden_size,
)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
return self.norm(hidden_states)
class DeepSeekMultiTokenPredictorLayer(nn.Module):
"""Single MTP layer: enorm + hnorm + eh_proj + shared_head + mtp_block."""
def __init__(
self,
config,
layer_idx: int,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
):
super().__init__()
self.enorm = RMSNorm(config.hidden_size,
eps=getattr(config, "rms_norm_eps", 1e-6))
self.hnorm = RMSNorm(config.hidden_size,
eps=getattr(config, "rms_norm_eps", 1e-6))
self.eh_proj = nn.Linear(config.hidden_size * 2,
config.hidden_size,
bias=False)
self.shared_head = SharedHead(config,
prefix=f"{prefix}.shared_head")
# Reuse DeepseekV2DecoderLayer (MLU hijack auto-applies)
self.mtp_block = DeepseekV2DecoderLayer(
config,
prefix=f"model.layers.{layer_idx}",
cache_config=cache_config,
quant_config=quant_config,
)
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
previous_hidden_states: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
inputs_embeds: Optional[torch.Tensor] = None,
) -> torch.Tensor:
assert inputs_embeds is not None
# Mask inputs at position 0
inputs_embeds = torch.where(
positions.unsqueeze(-1) == 0, 0, inputs_embeds)
inputs_embeds = self.enorm(inputs_embeds)
previous_hidden_states = self.hnorm(previous_hidden_states)
hidden_states = self.eh_proj(
torch.cat([inputs_embeds, previous_hidden_states], dim=-1))
hidden_states, residual = self.mtp_block(
positions, hidden_states, kv_caches[0], attn_metadata,
residual=None)
hidden_states = residual + hidden_states
return hidden_states
def get_spec_layer_idx_from_weight_name(config, weight_name: str):
"""Check if weight belongs to a speculative (MTP) layer.
Returns the layer index if so, None otherwise."""
num_nextn = getattr(config, "num_nextn_predict_layers", 0)
if num_nextn and num_nextn > 0:
layer_idx = config.num_hidden_layers
for i in range(num_nextn):
if weight_name.startswith(f"model.layers.{layer_idx + i}."):
return layer_idx + i
return None
def _rewrite_spec_layer_name(config, spec_layer: int, name: str) -> str:
"""Rewrite weight name for MTP layer.
Add .mtp_block for transformer block weights,
rename shared weights to top level."""
spec_layer_weight_names = [
"embed_tokens", "enorm", "hnorm", "eh_proj", "shared_head",
]
shared_weight_names = ["embed_tokens"]
spec_layer_weight = False
shared_weight = False
for weight_name in spec_layer_weight_names:
if weight_name in name:
spec_layer_weight = True
if weight_name in shared_weight_names:
shared_weight = True
break
if not spec_layer_weight:
# Transformer block weights -> add .mtp_block prefix
name = name.replace(
f"model.layers.{spec_layer}.",
f"model.layers.{spec_layer}.mtp_block.")
elif shared_weight:
# Shared weights -> top level
name = name.replace(f"model.layers.{spec_layer}.", "model.")
return name
class DeepSeekMTP(nn.Module):
"""DeepSeek V3 Multi-Token Prediction draft model.
Uses hidden states from the target model to predict the next token
via a single additional decoder layer."""
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
self.config = config
self.embed_tokens = VocabParallelEmbedding(
config.vocab_size,
config.hidden_size,
)
self.mtp_start_layer_idx = config.num_hidden_layers
num_mtp = getattr(config, "num_nextn_predict_layers", 1)
self.layers = nn.ModuleDict()
for i in range(num_mtp):
layer_idx = self.mtp_start_layer_idx + i
self.layers[str(layer_idx)] = DeepSeekMultiTokenPredictorLayer(
config=config,
layer_idx=layer_idx,
cache_config=cache_config,
quant_config=quant_config,
prefix=f"model.layers.{layer_idx}",
)
self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = get_sampler()
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
previous_hidden_states: torch.Tensor,
intermediate_tensors: Optional[IntermediateTensors] = None,
) -> torch.Tensor:
inputs_embeds = self.embed_tokens(input_ids)
# Use the first MTP layer (DeepSeek V3 only has 1)
layer = self.layers[str(self.mtp_start_layer_idx)]
hidden_states = layer(
input_ids, positions, previous_hidden_states,
kv_caches, attn_metadata, inputs_embeds)
return hidden_states
def compute_logits(
self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]:
layer = self.layers[str(self.mtp_start_layer_idx)]
normed = layer.shared_head(hidden_states)
logits = self.logits_processor(
layer.shared_head.head, normed, sampling_metadata)
return logits
def sample(
self,
logits: Optional[torch.Tensor],
sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]:
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
# MLU SparseMoeMlp needs pack_params() before loading
try:
from vllm_mlu.model_executor.layers.sparse_moe_mlp import (
SparseMoeMlp)
for name, m in self.named_modules():
if isinstance(m, SparseMoeMlp):
m.pack_params()
except ImportError:
pass
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("gate_up_proj", "gate_proj", 0),
("gate_up_proj", "up_proj", 1),
]
params_dict = dict(self.named_parameters())
for name, loaded_weight in weights:
if "rotary_emb.inv_freq" in name:
continue
# Only load MTP layer weights
spec_layer = get_spec_layer_idx_from_weight_name(
self.config, name)
if spec_layer is None:
continue
# Rewrite weight name for MTP structure
name = _rewrite_spec_layer_name(
self.config, spec_layer, name)
# Only load shared weights (embed_tokens) from first
# MTP layer, per DeepSeek V3 Technical Report
if (spec_layer != self.mtp_start_layer_idx
and ".layers" not in name):
continue
self._load_single_weight(
name, loaded_weight, stacked_params_mapping,
params_dict)
def _load_single_weight(
self,
name: str,
loaded_weight: torch.Tensor,
stacked_params_mapping: List[Tuple[str, str, int]],
params_dict: dict,
):
for (param_name, weight_name, shard_id) in stacked_params_mapping:
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)
# Skip expert weights not in params_dict
if (("mlp.experts." in name
or "mlp.shared_experts." in name
or "mlp.shared_expert_gate." in name
or "e_score_correction_bias" in name)
and name not in params_dict):
return
if name.endswith(".bias") and name not in params_dict:
return
if name not in params_dict:
return
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
return
# Non-stacked weights
if name.endswith(".bias") and name not in params_dict:
return
if (("mlp.experts." in name
or "mlp.shared_experts." in name
or "mlp.shared_expert_gate." in name
or "e_score_correction_bias" in name)
and name not in params_dict):
return
if name not in params_dict:
return
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)

View File

@@ -166,6 +166,7 @@ _SPECULATIVE_DECODING_MODELS = {
"EAGLEModel": ("eagle", "EAGLE"),
"MedusaModel": ("medusa", "Medusa"),
"MLPSpeculatorPreTrainedModel": ("mlp_speculator", "MLPSpeculator"),
"DeepSeekMTPModel": ("deepseek_mtp", "DeepSeekMTP"),
}
# Transformers backend models - wrapper classes for custom HuggingFace models

View File

@@ -59,12 +59,19 @@ class MLUWorker(Worker):
# mlp_speculator
speculative_config = self.speculative_config
model_config = self.model_config
speculative_args = {} if speculative_config is None \
or (speculative_config.draft_model_config.model ==
model_config.model) \
or (speculative_config.draft_model_config.hf_config.model_type
not in ["medusa", "mlp_speculator", "eagle"]) \
else {"return_hidden_states": True}
is_mtp = (speculative_config is not None
and getattr(
speculative_config.draft_model_config.hf_config,
"model_type", None) == "deepseek_mtp")
speculative_args = (
{"return_hidden_states": True} if is_mtp else
({} if speculative_config is None
or (speculative_config.draft_model_config.model ==
model_config.model)
or (speculative_config.draft_model_config.hf_config.model_type
not in ["medusa", "mlp_speculator", "eagle"])
else {"return_hidden_states": True})
)
ModelRunnerClass: Type[MLUModelRunnerBase] = MLUModelRunner
if model_runner_cls is not None: