forked from EngineX-Cambricon/enginex-mlu370-vllm
add deepseekv3 and llama4
This commit is contained in:
@@ -1403,6 +1403,18 @@ class SpeculativeConfig:
|
||||
|
||||
draft_hf_config = draft_model_config.hf_config
|
||||
|
||||
# Detect DeepSeek V3 MTP: same model path with
|
||||
# num_nextn_predict_layers > 0
|
||||
num_nextn = getattr(draft_hf_config,
|
||||
"num_nextn_predict_layers", 0)
|
||||
if (num_nextn and num_nextn > 0
|
||||
and getattr(draft_hf_config, "model_type", "")
|
||||
in ("deepseek_v3",)):
|
||||
draft_hf_config.model_type = "deepseek_mtp"
|
||||
draft_hf_config.architectures = ["DeepSeekMTPModel"]
|
||||
if num_speculative_tokens is None:
|
||||
num_speculative_tokens = num_nextn
|
||||
|
||||
if (num_speculative_tokens is not None
|
||||
and hasattr(draft_hf_config, "num_lookahead_tokens")):
|
||||
draft_hf_config.num_lookahead_tokens = num_speculative_tokens
|
||||
@@ -1421,7 +1433,7 @@ class SpeculativeConfig:
|
||||
f"{num_speculative_tokens=} was provided.")
|
||||
|
||||
if enable_chunked_prefill and draft_hf_config.model_type in (
|
||||
"medusa", "mlp_speculator", "eagle"):
|
||||
"medusa", "mlp_speculator", "eagle", "deepseek_mtp"):
|
||||
raise ValueError(
|
||||
"Chunked prefill and hidden-state based draft models are "
|
||||
"not compatible.")
|
||||
|
||||
@@ -143,11 +143,14 @@ class RMSNorm(CustomOp):
|
||||
from vllm import _mlu_ops as mlu_ops
|
||||
|
||||
x = x.view(-1, self.weight.data.shape[0])
|
||||
weight = self.weight.data
|
||||
if weight.dtype != x.dtype:
|
||||
weight = weight.to(x.dtype)
|
||||
if residual is not None:
|
||||
residual = residual.view(-1, self.weight.data.shape[0])
|
||||
return mlu_ops.fused_rms_norm(x, residual, self.weight.data, None, None, self.variance_epsilon, True)
|
||||
return mlu_ops.fused_rms_norm(x, residual, weight, None, None, self.variance_epsilon, True)
|
||||
else:
|
||||
return mlu_ops.fused_rms_norm(x, residual, self.weight.data, None, None, self.variance_epsilon, False)
|
||||
return mlu_ops.fused_rms_norm(x, residual, weight, None, None, self.variance_epsilon, False)
|
||||
|
||||
def extra_repr(self) -> str:
|
||||
s = f"hidden_size={self.weight.data.size(0)}"
|
||||
|
||||
288
vllm-v0.6.2/vllm/model_executor/models/deepseek_mtp.py
Normal file
288
vllm-v0.6.2/vllm/model_executor/models/deepseek_mtp.py
Normal file
@@ -0,0 +1,288 @@
|
||||
"""Inference-only DeepSeek V3 Multi-Token Prediction (MTP) model."""
|
||||
import re
|
||||
from typing import Iterable, List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from vllm.attention.backends.abstract import AttentionMetadata
|
||||
from vllm.config import CacheConfig, VllmConfig
|
||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
ParallelLMHead, VocabParallelEmbedding)
|
||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||
from vllm.sequence import IntermediateTensors
|
||||
|
||||
from .deepseek_v2 import DeepseekV2DecoderLayer
|
||||
from .utils import maybe_prefix
|
||||
|
||||
|
||||
class SharedHead(nn.Module):
|
||||
"""Shared head for MTP: norm + lm_head."""
|
||||
|
||||
def __init__(self, config, prefix: str = ""):
|
||||
super().__init__()
|
||||
self.norm = RMSNorm(config.hidden_size,
|
||||
eps=getattr(config, "rms_norm_eps", 1e-6))
|
||||
self.head = ParallelLMHead(
|
||||
config.vocab_size,
|
||||
config.hidden_size,
|
||||
)
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
return self.norm(hidden_states)
|
||||
|
||||
|
||||
class DeepSeekMultiTokenPredictorLayer(nn.Module):
|
||||
"""Single MTP layer: enorm + hnorm + eh_proj + shared_head + mtp_block."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config,
|
||||
layer_idx: int,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
):
|
||||
super().__init__()
|
||||
self.enorm = RMSNorm(config.hidden_size,
|
||||
eps=getattr(config, "rms_norm_eps", 1e-6))
|
||||
self.hnorm = RMSNorm(config.hidden_size,
|
||||
eps=getattr(config, "rms_norm_eps", 1e-6))
|
||||
self.eh_proj = nn.Linear(config.hidden_size * 2,
|
||||
config.hidden_size,
|
||||
bias=False)
|
||||
self.shared_head = SharedHead(config,
|
||||
prefix=f"{prefix}.shared_head")
|
||||
# Reuse DeepseekV2DecoderLayer (MLU hijack auto-applies)
|
||||
self.mtp_block = DeepseekV2DecoderLayer(
|
||||
config,
|
||||
prefix=f"model.layers.{layer_idx}",
|
||||
cache_config=cache_config,
|
||||
quant_config=quant_config,
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
previous_hidden_states: torch.Tensor,
|
||||
kv_caches: List[torch.Tensor],
|
||||
attn_metadata: AttentionMetadata,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
assert inputs_embeds is not None
|
||||
# Mask inputs at position 0
|
||||
inputs_embeds = torch.where(
|
||||
positions.unsqueeze(-1) == 0, 0, inputs_embeds)
|
||||
inputs_embeds = self.enorm(inputs_embeds)
|
||||
previous_hidden_states = self.hnorm(previous_hidden_states)
|
||||
|
||||
hidden_states = self.eh_proj(
|
||||
torch.cat([inputs_embeds, previous_hidden_states], dim=-1))
|
||||
|
||||
hidden_states, residual = self.mtp_block(
|
||||
positions, hidden_states, kv_caches[0], attn_metadata,
|
||||
residual=None)
|
||||
hidden_states = residual + hidden_states
|
||||
return hidden_states
|
||||
|
||||
|
||||
def get_spec_layer_idx_from_weight_name(config, weight_name: str):
|
||||
"""Check if weight belongs to a speculative (MTP) layer.
|
||||
Returns the layer index if so, None otherwise."""
|
||||
num_nextn = getattr(config, "num_nextn_predict_layers", 0)
|
||||
if num_nextn and num_nextn > 0:
|
||||
layer_idx = config.num_hidden_layers
|
||||
for i in range(num_nextn):
|
||||
if weight_name.startswith(f"model.layers.{layer_idx + i}."):
|
||||
return layer_idx + i
|
||||
return None
|
||||
|
||||
|
||||
def _rewrite_spec_layer_name(config, spec_layer: int, name: str) -> str:
|
||||
"""Rewrite weight name for MTP layer.
|
||||
Add .mtp_block for transformer block weights,
|
||||
rename shared weights to top level."""
|
||||
spec_layer_weight_names = [
|
||||
"embed_tokens", "enorm", "hnorm", "eh_proj", "shared_head",
|
||||
]
|
||||
shared_weight_names = ["embed_tokens"]
|
||||
spec_layer_weight = False
|
||||
shared_weight = False
|
||||
for weight_name in spec_layer_weight_names:
|
||||
if weight_name in name:
|
||||
spec_layer_weight = True
|
||||
if weight_name in shared_weight_names:
|
||||
shared_weight = True
|
||||
break
|
||||
if not spec_layer_weight:
|
||||
# Transformer block weights -> add .mtp_block prefix
|
||||
name = name.replace(
|
||||
f"model.layers.{spec_layer}.",
|
||||
f"model.layers.{spec_layer}.mtp_block.")
|
||||
elif shared_weight:
|
||||
# Shared weights -> top level
|
||||
name = name.replace(f"model.layers.{spec_layer}.", "model.")
|
||||
return name
|
||||
|
||||
|
||||
class DeepSeekMTP(nn.Module):
|
||||
"""DeepSeek V3 Multi-Token Prediction draft model.
|
||||
Uses hidden states from the target model to predict the next token
|
||||
via a single additional decoder layer."""
|
||||
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
super().__init__()
|
||||
config = vllm_config.model_config.hf_config
|
||||
cache_config = vllm_config.cache_config
|
||||
quant_config = vllm_config.quant_config
|
||||
self.config = config
|
||||
|
||||
self.embed_tokens = VocabParallelEmbedding(
|
||||
config.vocab_size,
|
||||
config.hidden_size,
|
||||
)
|
||||
|
||||
self.mtp_start_layer_idx = config.num_hidden_layers
|
||||
num_mtp = getattr(config, "num_nextn_predict_layers", 1)
|
||||
|
||||
self.layers = nn.ModuleDict()
|
||||
for i in range(num_mtp):
|
||||
layer_idx = self.mtp_start_layer_idx + i
|
||||
self.layers[str(layer_idx)] = DeepSeekMultiTokenPredictorLayer(
|
||||
config=config,
|
||||
layer_idx=layer_idx,
|
||||
cache_config=cache_config,
|
||||
quant_config=quant_config,
|
||||
prefix=f"model.layers.{layer_idx}",
|
||||
)
|
||||
|
||||
self.logits_processor = LogitsProcessor(config.vocab_size)
|
||||
self.sampler = get_sampler()
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
kv_caches: List[torch.Tensor],
|
||||
attn_metadata: AttentionMetadata,
|
||||
previous_hidden_states: torch.Tensor,
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
) -> torch.Tensor:
|
||||
inputs_embeds = self.embed_tokens(input_ids)
|
||||
# Use the first MTP layer (DeepSeek V3 only has 1)
|
||||
layer = self.layers[str(self.mtp_start_layer_idx)]
|
||||
hidden_states = layer(
|
||||
input_ids, positions, previous_hidden_states,
|
||||
kv_caches, attn_metadata, inputs_embeds)
|
||||
return hidden_states
|
||||
|
||||
def compute_logits(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
sampling_metadata: SamplingMetadata,
|
||||
) -> Optional[torch.Tensor]:
|
||||
layer = self.layers[str(self.mtp_start_layer_idx)]
|
||||
normed = layer.shared_head(hidden_states)
|
||||
logits = self.logits_processor(
|
||||
layer.shared_head.head, normed, sampling_metadata)
|
||||
return logits
|
||||
|
||||
def sample(
|
||||
self,
|
||||
logits: Optional[torch.Tensor],
|
||||
sampling_metadata: SamplingMetadata,
|
||||
) -> Optional[SamplerOutput]:
|
||||
next_tokens = self.sampler(logits, sampling_metadata)
|
||||
return next_tokens
|
||||
|
||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||
# MLU SparseMoeMlp needs pack_params() before loading
|
||||
try:
|
||||
from vllm_mlu.model_executor.layers.sparse_moe_mlp import (
|
||||
SparseMoeMlp)
|
||||
for name, m in self.named_modules():
|
||||
if isinstance(m, SparseMoeMlp):
|
||||
m.pack_params()
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
stacked_params_mapping = [
|
||||
# (param_name, shard_name, shard_id)
|
||||
("gate_up_proj", "gate_proj", 0),
|
||||
("gate_up_proj", "up_proj", 1),
|
||||
]
|
||||
|
||||
params_dict = dict(self.named_parameters())
|
||||
|
||||
for name, loaded_weight in weights:
|
||||
if "rotary_emb.inv_freq" in name:
|
||||
continue
|
||||
|
||||
# Only load MTP layer weights
|
||||
spec_layer = get_spec_layer_idx_from_weight_name(
|
||||
self.config, name)
|
||||
if spec_layer is None:
|
||||
continue
|
||||
|
||||
# Rewrite weight name for MTP structure
|
||||
name = _rewrite_spec_layer_name(
|
||||
self.config, spec_layer, name)
|
||||
|
||||
# Only load shared weights (embed_tokens) from first
|
||||
# MTP layer, per DeepSeek V3 Technical Report
|
||||
if (spec_layer != self.mtp_start_layer_idx
|
||||
and ".layers" not in name):
|
||||
continue
|
||||
|
||||
self._load_single_weight(
|
||||
name, loaded_weight, stacked_params_mapping,
|
||||
params_dict)
|
||||
|
||||
def _load_single_weight(
|
||||
self,
|
||||
name: str,
|
||||
loaded_weight: torch.Tensor,
|
||||
stacked_params_mapping: List[Tuple[str, str, int]],
|
||||
params_dict: dict,
|
||||
):
|
||||
for (param_name, weight_name, shard_id) in stacked_params_mapping:
|
||||
if weight_name not in name:
|
||||
continue
|
||||
name = name.replace(weight_name, param_name)
|
||||
# Skip expert weights not in params_dict
|
||||
if (("mlp.experts." in name
|
||||
or "mlp.shared_experts." in name
|
||||
or "mlp.shared_expert_gate." in name
|
||||
or "e_score_correction_bias" in name)
|
||||
and name not in params_dict):
|
||||
return
|
||||
if name.endswith(".bias") and name not in params_dict:
|
||||
return
|
||||
if name not in params_dict:
|
||||
return
|
||||
param = params_dict[name]
|
||||
weight_loader = param.weight_loader
|
||||
weight_loader(param, loaded_weight, shard_id)
|
||||
return
|
||||
|
||||
# Non-stacked weights
|
||||
if name.endswith(".bias") and name not in params_dict:
|
||||
return
|
||||
if (("mlp.experts." in name
|
||||
or "mlp.shared_experts." in name
|
||||
or "mlp.shared_expert_gate." in name
|
||||
or "e_score_correction_bias" in name)
|
||||
and name not in params_dict):
|
||||
return
|
||||
if name not in params_dict:
|
||||
return
|
||||
param = params_dict[name]
|
||||
weight_loader = getattr(param, "weight_loader",
|
||||
default_weight_loader)
|
||||
weight_loader(param, loaded_weight)
|
||||
@@ -166,6 +166,7 @@ _SPECULATIVE_DECODING_MODELS = {
|
||||
"EAGLEModel": ("eagle", "EAGLE"),
|
||||
"MedusaModel": ("medusa", "Medusa"),
|
||||
"MLPSpeculatorPreTrainedModel": ("mlp_speculator", "MLPSpeculator"),
|
||||
"DeepSeekMTPModel": ("deepseek_mtp", "DeepSeekMTP"),
|
||||
}
|
||||
|
||||
# Transformers backend models - wrapper classes for custom HuggingFace models
|
||||
|
||||
@@ -59,12 +59,19 @@ class MLUWorker(Worker):
|
||||
# mlp_speculator
|
||||
speculative_config = self.speculative_config
|
||||
model_config = self.model_config
|
||||
speculative_args = {} if speculative_config is None \
|
||||
or (speculative_config.draft_model_config.model ==
|
||||
model_config.model) \
|
||||
or (speculative_config.draft_model_config.hf_config.model_type
|
||||
not in ["medusa", "mlp_speculator", "eagle"]) \
|
||||
else {"return_hidden_states": True}
|
||||
is_mtp = (speculative_config is not None
|
||||
and getattr(
|
||||
speculative_config.draft_model_config.hf_config,
|
||||
"model_type", None) == "deepseek_mtp")
|
||||
speculative_args = (
|
||||
{"return_hidden_states": True} if is_mtp else
|
||||
({} if speculative_config is None
|
||||
or (speculative_config.draft_model_config.model ==
|
||||
model_config.model)
|
||||
or (speculative_config.draft_model_config.hf_config.model_type
|
||||
not in ["medusa", "mlp_speculator", "eagle"])
|
||||
else {"return_hidden_states": True})
|
||||
)
|
||||
|
||||
ModelRunnerClass: Type[MLUModelRunnerBase] = MLUModelRunner
|
||||
if model_runner_cls is not None:
|
||||
|
||||
Reference in New Issue
Block a user