[Model][2/N] Remove deepseek_mtp modeling. (#3561)

This PR is step 2 of the DeepSeek model refactoring and removes deepseek_mtp.

- vLLM version: v0.11.0rc3
- vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.0

---------

Signed-off-by: whx-sjtu <2952154980@qq.com>
Author: whx
Date: 2025-10-21 20:17:09 +08:00 (committed by GitHub)
parent ffb42a8daa
commit 220df60c61
7 changed files with 38 additions and 422 deletions

View File

@@ -11,6 +11,7 @@ from vllm.attention.backends.abstract import (AttentionBackend,
from vllm.config import VllmConfig, get_current_vllm_config
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.forward_context import ForwardContext, get_forward_context
+from vllm.logger import logger
from vllm.model_executor.layers.linear import (LinearBase,
UnquantizedLinearMethod)
from vllm.utils import cdiv, round_down
@@ -29,6 +30,7 @@ from vllm_ascend.multistream.base import MSAttentionMetadataSplitConfig
from vllm_ascend.multistream.context import get_multistream_comm_context
from vllm_ascend.multistream.ms_split import model_input_split_v1_mla_attn
from vllm_ascend.ops.weight_prefetch import maybe_npu_prefetch
+from vllm_ascend.quantization.w8a8 import AscendW8A8LinearMethod
from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_ND, ACL_FORMAT_FRACTAL_NZ,
is_enable_nz)
from vllm_ascend.worker.npu_input_batch import InputBatch
@@ -557,6 +559,7 @@ class AscendMLAImpl(MLAAttentionImpl):
self.prefill_mask = None
self.speculative_config = vllm_config.speculative_config
+self.enable_mlapo = envs.VLLM_ASCEND_ENABLE_MLAPO
def _v_up_proj(self, x):
if self.W_UV.shape[0] * self.W_UV.shape[1] < 65536:
@@ -654,7 +657,17 @@ class AscendMLAImpl(MLAAttentionImpl):
# self.W_UV.data = torch_npu.npu_format_cast(self.W_UV.data, 29)
# self.W_UK_T.data = torch_npu.npu_format_cast(self.W_UK_T.data, 29)
if envs.VLLM_ASCEND_ENABLE_MLAPO:
+# Currently mlapo only supports W8A8 quantization in the MLA scenario.
+# TODO(whx): modify this limitation when mlapo supports floating point.
+if self.fused_qkv_a_proj is None or not isinstance(
+getattr(self.fused_qkv_a_proj.quant_method, 'quant_method',
+None), AscendW8A8LinearMethod):
+self.enable_mlapo = False
+logger.warning(
+"Currently mlapo only supports W8A8 quantization in the MLA scenario. "
+"Some layers in your model are not quantized with W8A8, "
+"thus mlapo is disabled for these layers.")
+if self.enable_mlapo:
self._process_weights_for_fused_mlapo(act_dtype)
def _process_weights_for_fused_mlapo(self, act_dtype: torch.dtype):
@@ -1229,7 +1242,7 @@ class AscendMLAImpl(MLAAttentionImpl):
# MLA Preprocess
forward_context = get_forward_context()
-if (envs.VLLM_ASCEND_ENABLE_MLAPO and
+if (self.enable_mlapo and
(attn_metadata is None or not forward_context.with_prefill)):
decode_preprocess_res, prefill_preprocess_res = self._mla_decode_preprocess(
hidden_states, kv_cache, attn_metadata)
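
The hunk above turns the global VLLM_ASCEND_ENABLE_MLAPO switch into a per-layer decision: the fused MLAPO path stays enabled only when the layer's fused qkv_a projection is actually W8A8-quantized. A minimal sketch of that check-and-fall-back pattern, using SimpleNamespace stand-ins and a hypothetical resolve_mlapo helper rather than the real vllm-ascend types:

from types import SimpleNamespace

class W8A8Method:  # hypothetical stand-in for AscendW8A8LinearMethod
    pass

def resolve_mlapo(fused_qkv_a_proj, requested: bool) -> bool:
    """Per-layer gate mirroring the diff: keep the fused MLAPO path only
    when the layer exists and its inner quant method is W8A8."""
    if not requested or fused_qkv_a_proj is None:
        return False
    # Same lookup shape as the diff: the projection's quant_method wraps
    # the actual linear method in an inner .quant_method attribute.
    inner = getattr(fused_qkv_a_proj.quant_method, "quant_method", None)
    return isinstance(inner, W8A8Method)

# A W8A8-quantized layer keeps the fused path; a float layer falls back.
w8a8_layer = SimpleNamespace(
    quant_method=SimpleNamespace(quant_method=W8A8Method()))
float_layer = SimpleNamespace(quant_method=object())

assert resolve_mlapo(w8a8_layer, requested=True) is True
assert resolve_mlapo(float_layer, requested=True) is False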

View File

@@ -33,10 +33,6 @@ def register_model():
"DeepseekV32ForCausalLM",
"vllm_ascend.models.deepseek_v3_2:CustomDeepseekV3ForCausalLM")
-ModelRegistry.register_model(
-"DeepSeekMTPModel",
-"vllm_ascend.models.deepseek_mtp:CustomDeepSeekMTP")
# There is no PanguProMoEForCausalLM in vLLM, so we should register it before vLLM config initialization
# to make sure the model can be loaded correctly. This register step can be removed once vLLM supports PanguProMoEForCausalLM.
ModelRegistry.register_model(
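
Dropping this registration is safe because vLLM already maps the DeepSeekMTPModel architecture to its in-tree DeepSeekMTP. A quick sanity check, assuming the ModelRegistry.get_supported_archs() API of recent vLLM releases:

from vllm import ModelRegistry

# The MTP architecture should still resolve without the plugin override,
# since vLLM registers DeepSeekMTP in-tree under the same name.
assert "DeepSeekMTPModel" in ModelRegistry.get_supported_archs()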

View File

@@ -1,209 +0,0 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# Adapted from vllm/model_executor/models/deepseek_mtp.py
# Copyright 2023 The vLLM team.
#
# This file is a part of the vllm-ascend project.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List, Optional
import torch
import torch.nn as nn
from transformers import PretrainedConfig
from vllm.attention.backends.abstract import AttentionMetadata
from vllm.compilation.decorators import support_torch_compile
from vllm.config import (CacheConfig, ModelConfig, VllmConfig,
get_current_vllm_config)
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.models.deepseek_mtp import (
DeepSeekMTP, DeepSeekMultiTokenPredictor, DeepSeekMultiTokenPredictorLayer,
SharedHead)
from vllm.model_executor.models.deepseek_v2 import DeepseekV2DecoderLayer
from vllm.model_executor.models.utils import maybe_prefix
from vllm.sequence import IntermediateTensors
class CustomDeepSeekShareHead(SharedHead):
def __init__(self,
config: PretrainedConfig,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "") -> None:
nn.Module.__init__(self)
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.head = ParallelLMHead(config.vocab_size,
config.hidden_size,
quant_config=quant_config,
prefix=maybe_prefix(prefix, "head"))
class CustomDeepSeekMultiTokenPredictorLayer(DeepSeekMultiTokenPredictorLayer):
def __init__(
self,
config: PretrainedConfig,
prefix: str,
model_config: ModelConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
) -> None:
nn.Module.__init__(self)
vllm_config = get_current_vllm_config()
self.enorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.hnorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.eh_proj = nn.Linear(config.hidden_size * 2,
config.hidden_size,
bias=False)
self.shared_head = CustomDeepSeekShareHead(config=config,
quant_config=quant_config,
prefix=maybe_prefix(
prefix, "shared_head"))
self.mtp_block = DeepseekV2DecoderLayer(vllm_config=vllm_config,
prefix=prefix)
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
previous_hidden_states: torch.Tensor,
inputs_embeds: Optional[torch.Tensor] = None,
spec_step_index: int = 0,
) -> torch.Tensor:
assert inputs_embeds is not None
inputs_embeds = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
inputs_embeds, True)
# masking inputs at position 0, as not needed by MTP
inputs_embeds = torch.where((positions == 0).unsqueeze(-1),
torch.zeros_like(inputs_embeds),
inputs_embeds)
inputs_embeds = self.enorm(inputs_embeds)
previous_hidden_states = self.hnorm(previous_hidden_states)
hidden_states = self.eh_proj(
torch.cat([inputs_embeds, previous_hidden_states], dim=-1))
hidden_states, residual = self.mtp_block(positions=positions,
hidden_states=hidden_states,
residual=None)
hidden_states = residual + hidden_states
return hidden_states
class CustomDeepSeekMultiTokenPredictor(DeepSeekMultiTokenPredictor):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
nn.Module.__init__(self)
config = vllm_config.model_config.hf_config
self.mtp_start_layer_idx = config.num_hidden_layers
self.num_mtp_layers = config.num_nextn_predict_layers
# to map the exact layer index from weights
self.layers = torch.nn.ModuleDict({
str(idx):
CustomDeepSeekMultiTokenPredictorLayer(
config,
f"{prefix}.layers.{idx}",
model_config=vllm_config.model_config,
cache_config=vllm_config.cache_config,
quant_config=vllm_config.quant_config,
)
for idx in range(self.mtp_start_layer_idx,
self.mtp_start_layer_idx + self.num_mtp_layers)
})
self.embed_tokens = VocabParallelEmbedding(
config.vocab_size,
config.hidden_size,
)
# Note: torch._dynamo.exc.Unsupported: builtin: str
self.layers_list = [
self.layers[str(idx)]
for idx in range(self.mtp_start_layer_idx,
self.mtp_start_layer_idx + self.num_mtp_layers)
]
self.logits_processor = LogitsProcessor(config.vocab_size)
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: torch.Tensor,
attn_metadata: AttentionMetadata,
previous_hidden_states: torch.Tensor,
inputs_embeds: Optional[torch.Tensor] = None,
spec_step_idx: int = 0,
) -> torch.Tensor:
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)
current_step_idx = (spec_step_idx % self.num_mtp_layers)
step_kv_cache = kv_caches[
current_step_idx] if kv_caches is not None else None
return self.layers_list[current_step_idx](
input_ids,
positions,
step_kv_cache,
attn_metadata,
previous_hidden_states,
inputs_embeds,
current_step_idx,
)
def compute_logits(
self,
hidden_states: torch.Tensor,
sampling_metadata=None, # type: ignore
spec_step_idx: int = 0,
) -> torch.Tensor:
current_step_idx = (spec_step_idx % self.num_mtp_layers)
mtp_layer = self.layers_list[current_step_idx]
logits = self.logits_processor(mtp_layer.shared_head.head,
mtp_layer.shared_head(hidden_states),
sampling_metadata)
return logits
@support_torch_compile
class CustomDeepSeekMTP(DeepSeekMTP):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
nn.Module.__init__(self)
self.config = vllm_config.model_config.hf_config
self.model = CustomDeepSeekMultiTokenPredictor(vllm_config=vllm_config,
prefix=maybe_prefix(
prefix, "model"))
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: Optional[List[torch.Tensor]] = None,
attn_metadata: Optional[AttentionMetadata] = None,
previous_hidden_states: Optional[torch.Tensor] = None,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
spec_step_idx: int = 0,
) -> torch.Tensor:
hidden_states = self.model(input_ids, positions, kv_caches,
attn_metadata, previous_hidden_states,
inputs_embeds, spec_step_idx)
hidden_states = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
hidden_states, True)
return hidden_states
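
The deleted forward pass lives on in vLLM's in-tree DeepSeekMTP: token embeddings and the previous step's hidden states are RMS-normalized separately, concatenated, and projected back to hidden_size, with position 0 zeroed because it has no preceding token to condition on. A minimal tensor-level sketch of that computation, with illustrative sizes and assuming PyTorch >= 2.4 for nn.RMSNorm:

import torch
import torch.nn as nn

H, T = 16, 4                          # illustrative hidden size / token count
enorm = nn.RMSNorm(H)                 # norm over token embeddings
hnorm = nn.RMSNorm(H)                 # norm over previous hidden states
eh_proj = nn.Linear(2 * H, H, bias=False)

positions = torch.arange(T)
inputs_embeds = torch.randn(T, H)
previous_hidden_states = torch.randn(T, H)

# Position 0 has no previous token to condition on, so its embedding is
# zeroed, mirroring the torch.where() mask in the deleted forward().
inputs_embeds = torch.where((positions == 0).unsqueeze(-1),
                            torch.zeros_like(inputs_embeds), inputs_embeds)
hidden_states = eh_proj(
    torch.cat([enorm(inputs_embeds), hnorm(previous_hidden_states)], dim=-1))
assert hidden_states.shape == (T, H)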

View File

@@ -56,6 +56,14 @@ class AscendQuantConfig(QuantizationConfig):
def __init__(self, quant_config: Dict[str, Any]):
super().__init__()
self.quant_description = quant_config
+# TODO(whx): remove this adaptation after adding "shared_head"
+# to the prefix of DeepSeekShareHead in vLLM.
+extra_quant_dict = {}
+for k in self.quant_description.keys():
+if "shared_head" in k:
+new_k = k.replace(".shared_head.", ".")
+extra_quant_dict[new_k] = self.quant_description[k]
+self.quant_description.update(extra_quant_dict)
def __repr__(self) -> str:
return "AscendQuantConfig:\n" + super().__repr__()

View File

@@ -11,6 +11,7 @@ from vllm.forward_context import BatchDescriptor, get_forward_context
from vllm.model_executor.model_loader import get_model_loader
from vllm.model_executor.model_loader.utils import (
process_weights_after_loading, set_default_torch_dtype)
+from vllm.model_executor.models.deepseek_mtp import DeepSeekMTP
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
@@ -18,7 +19,6 @@ from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
from vllm_ascend.ascend_config import get_ascend_config
from vllm_ascend.ascend_forward_context import set_ascend_forward_context
from vllm_ascend.attention.utils import AscendCommonAttentionMetadata
-from vllm_ascend.models.deepseek_mtp import CustomDeepSeekMTP
from vllm_ascend.spec_decode.interface import Proposer, SpecDcodeType
from vllm_ascend.torchair.models.torchair_deepseek_mtp import \
TorchairDeepSeekMTP
@@ -86,7 +86,7 @@ class MtpProposer(Proposer):
self.model = TorchairDeepSeekMTP(
vllm_config=self.vllm_config).to(target_device)
else:
-self.model = CustomDeepSeekMTP(
+self.model = DeepSeekMTP(
vllm_config=self.vllm_config).to(target_device)
draft_attn_layer_names = (
@@ -184,7 +184,7 @@ class MtpProposer(Proposer):
else:
self.model(input_ids=input_ids,
positions=positions,
-previous_hidden_states=previous_hidden_states)
+hidden_states=previous_hidden_states)
if with_prefill:
break
@@ -470,9 +470,8 @@ class MtpProposer(Proposer):
hidden_states = self.model(
input_ids=self.input_ids[:num_input_tokens],
positions=self.positions[:num_input_tokens],
-previous_hidden_states=self.
-hidden_states[:num_input_tokens],
-kv_caches=self.runner.kv_caches[-1:])
+hidden_states=self.hidden_states[:num_input_tokens]
+)
num_indices = last_token_indices.shape[0]
if lmhead_tp_enable():
@@ -485,7 +484,7 @@ class MtpProposer(Proposer):
(0, max_num_reqs_across_dp - num_indices))
sample_hidden_states = hidden_states[last_token_indices]
-logits = self.model.compute_logits(sample_hidden_states, None)
+logits = self.model.compute_logits(sample_hidden_states)
if lmhead_tp_enable() and num_indices < logits.shape[0]:
logits = logits[:num_indices]
draft_token_ids = logits.argmax(dim=-1)
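
The remaining call-site changes in this file track the in-tree model's interface: forward() now receives the previous step's activations through a hidden_states keyword (with no explicit kv_caches argument), and compute_logits() no longer takes sampling metadata. A hedged mini-sketch of the adapted call pattern; DraftModel is a stand-in, not the real vLLM class:

import torch

class DraftModel:
    """Hypothetical stand-in for the in-tree DeepSeekMTP interface."""

    def __init__(self, hidden_size: int, vocab_size: int):
        self.lm_head = torch.randn(hidden_size, vocab_size)

    def __call__(self, input_ids, positions, hidden_states):
        # new keyword: hidden_states (was previous_hidden_states);
        # kv_caches is no longer passed explicitly by the proposer
        return hidden_states

    def compute_logits(self, hidden_states):
        # new signature: no sampling_metadata argument
        return hidden_states @ self.lm_head

model = DraftModel(hidden_size=16, vocab_size=8)
h = model(input_ids=torch.zeros(4, dtype=torch.long),
          positions=torch.arange(4),
          hidden_states=torch.randn(4, 16))
logits = model.compute_logits(h[-1:])
draft_token_ids = logits.argmax(dim=-1)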