[Model][2/N] Remove deepseek_mtp modeling. (#3561)
This PR is step 2 of the deepseek model refactoring and removes deepseek_mtp.

- vLLM version: v0.11.0rc3
- vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.0

---------

Signed-off-by: whx-sjtu <2952154980@qq.com>
This commit is contained in:
@@ -11,6 +11,7 @@ from vllm.attention.backends.abstract import (AttentionBackend,
|
||||
from vllm.config import VllmConfig, get_current_vllm_config
|
||||
from vllm.distributed import get_tensor_model_parallel_world_size
|
||||
from vllm.forward_context import ForwardContext, get_forward_context
|
||||
from vllm.logger import logger
|
||||
from vllm.model_executor.layers.linear import (LinearBase,
|
||||
UnquantizedLinearMethod)
|
||||
from vllm.utils import cdiv, round_down
|
||||
@@ -29,6 +30,7 @@ from vllm_ascend.multistream.base import MSAttentionMetadataSplitConfig
|
||||
from vllm_ascend.multistream.context import get_multistream_comm_context
|
||||
from vllm_ascend.multistream.ms_split import model_input_split_v1_mla_attn
|
||||
from vllm_ascend.ops.weight_prefetch import maybe_npu_prefetch
|
||||
from vllm_ascend.quantization.w8a8 import AscendW8A8LinearMethod
|
||||
from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_ND, ACL_FORMAT_FRACTAL_NZ,
|
||||
is_enable_nz)
|
||||
from vllm_ascend.worker.npu_input_batch import InputBatch
|
||||
@@ -557,6 +559,7 @@ class AscendMLAImpl(MLAAttentionImpl):
|
||||
self.prefill_mask = None
|
||||
|
||||
self.speculative_config = vllm_config.speculative_config
|
||||
self.enable_mlapo = envs.VLLM_ASCEND_ENABLE_MLAPO
|
||||
|
||||
def _v_up_proj(self, x):
|
||||
if self.W_UV.shape[0] * self.W_UV.shape[1] < 65536:
|
||||
@@ -654,7 +657,17 @@ class AscendMLAImpl(MLAAttentionImpl):
|
||||
# self.W_UV.data = torch_npu.npu_format_cast(self.W_UV.data, 29)
|
||||
# self.W_UK_T.data = torch_npu.npu_format_cast(self.W_UK_T.data, 29)
|
||||
|
||||
if envs.VLLM_ASCEND_ENABLE_MLAPO:
|
||||
# Currently mlapo only supports W8A8 quantization in MLA scenario
|
||||
# TODO(whx): modify this limitation when mlapo supports floating point
|
||||
if self.fused_qkv_a_proj is None or not isinstance(
|
||||
getattr(self.fused_qkv_a_proj.quant_method, 'quant_method',
|
||||
None), AscendW8A8LinearMethod):
|
||||
self.enable_mlapo = False
|
||||
logger.warning(
|
||||
"Currently mlapo only supports W8A8 quantization in MLA scenario."
|
||||
"Some layers in your model are not quantized with W8A8,"
|
||||
"thus mlapo is disabled for these layers.")
|
||||
if self.enable_mlapo:
|
||||
self._process_weights_for_fused_mlapo(act_dtype)
|
||||
|
||||
def _process_weights_for_fused_mlapo(self, act_dtype: torch.dtype):
|
||||
@@ -1229,7 +1242,7 @@ class AscendMLAImpl(MLAAttentionImpl):
|
||||
|
||||
# MLA Preprocess
|
||||
forward_context = get_forward_context()
|
||||
if (envs.VLLM_ASCEND_ENABLE_MLAPO and
|
||||
if (self.enable_mlapo and
|
||||
(attn_metadata is None or not forward_context.with_prefill)):
|
||||
decode_preprocess_res, prefill_preprocess_res = self._mla_decode_preprocess(
|
||||
hidden_states, kv_cache, attn_metadata)
|
||||
|
||||
@@ -33,10 +33,6 @@ def register_model():
|
||||
"DeepseekV32ForCausalLM",
|
||||
"vllm_ascend.models.deepseek_v3_2:CustomDeepseekV3ForCausalLM")
|
||||
|
||||
ModelRegistry.register_model(
|
||||
"DeepSeekMTPModel",
|
||||
"vllm_ascend.models.deepseek_mtp:CustomDeepSeekMTP")
|
||||
|
||||
# There is no PanguProMoEForCausalLM in vLLM, so we should register it before vLLM config initialization
|
||||
# to make sure the model can be loaded correctly. This register step can be removed once vLLM support PanguProMoEForCausalLM.
|
||||
ModelRegistry.register_model(
|
||||
|
||||
@@ -1,209 +0,0 @@
|
||||
#
|
||||
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||||
# Adapted from vllm/model_executor/models/deepseek_mtp.py
|
||||
# Copyright 2023 The vLLM team.
|
||||
#
|
||||
# This file is a part of the vllm-ascend project.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import List, Optional
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from transformers import PretrainedConfig
|
||||
from vllm.attention.backends.abstract import AttentionMetadata
|
||||
from vllm.compilation.decorators import support_torch_compile
|
||||
from vllm.config import (CacheConfig, ModelConfig, VllmConfig,
|
||||
get_current_vllm_config)
|
||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
ParallelLMHead, VocabParallelEmbedding)
|
||||
from vllm.model_executor.models.deepseek_mtp import (
|
||||
DeepSeekMTP, DeepSeekMultiTokenPredictor, DeepSeekMultiTokenPredictorLayer,
|
||||
SharedHead)
|
||||
from vllm.model_executor.models.deepseek_v2 import DeepseekV2DecoderLayer
|
||||
from vllm.model_executor.models.utils import maybe_prefix
|
||||
from vllm.sequence import IntermediateTensors
|
||||
|
||||
|
||||
class CustomDeepSeekShareHead(SharedHead):
|
||||
|
||||
def __init__(self,
|
||||
config: PretrainedConfig,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "") -> None:
|
||||
nn.Module.__init__(self)
|
||||
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
self.head = ParallelLMHead(config.vocab_size,
|
||||
config.hidden_size,
|
||||
quant_config=quant_config,
|
||||
prefix=maybe_prefix(prefix, "head"))
|
||||
|
||||
|
||||
class CustomDeepSeekMultiTokenPredictorLayer(DeepSeekMultiTokenPredictorLayer):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: PretrainedConfig,
|
||||
prefix: str,
|
||||
model_config: ModelConfig,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
) -> None:
|
||||
nn.Module.__init__(self)
|
||||
vllm_config = get_current_vllm_config()
|
||||
|
||||
self.enorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
self.hnorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
self.eh_proj = nn.Linear(config.hidden_size * 2,
|
||||
config.hidden_size,
|
||||
bias=False)
|
||||
self.shared_head = CustomDeepSeekShareHead(config=config,
|
||||
quant_config=quant_config,
|
||||
prefix=maybe_prefix(
|
||||
prefix, "shared_head"))
|
||||
self.mtp_block = DeepseekV2DecoderLayer(vllm_config=vllm_config,
|
||||
prefix=prefix)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
kv_cache: torch.Tensor,
|
||||
attn_metadata: AttentionMetadata,
|
||||
previous_hidden_states: torch.Tensor,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
spec_step_index: int = 0,
|
||||
) -> torch.Tensor:
|
||||
assert inputs_embeds is not None
|
||||
inputs_embeds = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
|
||||
inputs_embeds, True)
|
||||
# masking inputs at position 0, as not needed by MTP
|
||||
inputs_embeds = torch.where((positions == 0).unsqueeze(-1),
|
||||
torch.zeros_like(inputs_embeds),
|
||||
inputs_embeds)
|
||||
inputs_embeds = self.enorm(inputs_embeds)
|
||||
previous_hidden_states = self.hnorm(previous_hidden_states)
|
||||
|
||||
hidden_states = self.eh_proj(
|
||||
torch.cat([inputs_embeds, previous_hidden_states], dim=-1))
|
||||
|
||||
hidden_states, residual = self.mtp_block(positions=positions,
|
||||
hidden_states=hidden_states,
|
||||
residual=None)
|
||||
hidden_states = residual + hidden_states
|
||||
return hidden_states
|
||||
|
||||
|
||||
class CustomDeepSeekMultiTokenPredictor(DeepSeekMultiTokenPredictor):
|
||||
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
nn.Module.__init__(self)
|
||||
config = vllm_config.model_config.hf_config
|
||||
self.mtp_start_layer_idx = config.num_hidden_layers
|
||||
self.num_mtp_layers = config.num_nextn_predict_layers
|
||||
# to map the exact layer index from weights
|
||||
self.layers = torch.nn.ModuleDict({
|
||||
str(idx):
|
||||
CustomDeepSeekMultiTokenPredictorLayer(
|
||||
config,
|
||||
f"{prefix}.layers.{idx}",
|
||||
model_config=vllm_config.model_config,
|
||||
cache_config=vllm_config.cache_config,
|
||||
quant_config=vllm_config.quant_config,
|
||||
)
|
||||
for idx in range(self.mtp_start_layer_idx,
|
||||
self.mtp_start_layer_idx + self.num_mtp_layers)
|
||||
})
|
||||
self.embed_tokens = VocabParallelEmbedding(
|
||||
config.vocab_size,
|
||||
config.hidden_size,
|
||||
)
|
||||
|
||||
# Note: torch._dynamo.exc.Unsupported: builtin: str
|
||||
self.layers_list = [
|
||||
self.layers[str(idx)]
|
||||
for idx in range(self.mtp_start_layer_idx,
|
||||
self.mtp_start_layer_idx + self.num_mtp_layers)
|
||||
]
|
||||
self.logits_processor = LogitsProcessor(config.vocab_size)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
kv_caches: torch.Tensor,
|
||||
attn_metadata: AttentionMetadata,
|
||||
previous_hidden_states: torch.Tensor,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
spec_step_idx: int = 0,
|
||||
) -> torch.Tensor:
|
||||
if inputs_embeds is None:
|
||||
inputs_embeds = self.embed_tokens(input_ids)
|
||||
current_step_idx = (spec_step_idx % self.num_mtp_layers)
|
||||
step_kv_cache = kv_caches[
|
||||
current_step_idx] if kv_caches is not None else None
|
||||
return self.layers_list[current_step_idx](
|
||||
input_ids,
|
||||
positions,
|
||||
step_kv_cache,
|
||||
attn_metadata,
|
||||
previous_hidden_states,
|
||||
inputs_embeds,
|
||||
current_step_idx,
|
||||
)
|
||||
|
||||
def compute_logits(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
sampling_metadata=None, # type: ignore
|
||||
spec_step_idx: int = 0,
|
||||
) -> torch.Tensor:
|
||||
current_step_idx = (spec_step_idx % self.num_mtp_layers)
|
||||
mtp_layer = self.layers_list[current_step_idx]
|
||||
logits = self.logits_processor(mtp_layer.shared_head.head,
|
||||
mtp_layer.shared_head(hidden_states),
|
||||
sampling_metadata)
|
||||
return logits
|
||||
|
||||
|
||||
@support_torch_compile
|
||||
class CustomDeepSeekMTP(DeepSeekMTP):
|
||||
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
nn.Module.__init__(self)
|
||||
self.config = vllm_config.model_config.hf_config
|
||||
self.model = CustomDeepSeekMultiTokenPredictor(vllm_config=vllm_config,
|
||||
prefix=maybe_prefix(
|
||||
prefix, "model"))
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
kv_caches: Optional[List[torch.Tensor]] = None,
|
||||
attn_metadata: Optional[AttentionMetadata] = None,
|
||||
previous_hidden_states: Optional[torch.Tensor] = None,
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
spec_step_idx: int = 0,
|
||||
) -> torch.Tensor:
|
||||
hidden_states = self.model(input_ids, positions, kv_caches,
|
||||
attn_metadata, previous_hidden_states,
|
||||
inputs_embeds, spec_step_idx)
|
||||
hidden_states = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
|
||||
hidden_states, True)
|
||||
return hidden_states
|
||||
@@ -56,6 +56,14 @@ class AscendQuantConfig(QuantizationConfig):
|
||||
def __init__(self, quant_config: Dict[str, Any]):
|
||||
super().__init__()
|
||||
self.quant_description = quant_config
|
||||
# TODO(whx): remove this adaptation after adding "shared_head"
|
||||
# to prefix of DeepSeekShareHead in vLLM.
|
||||
extra_quant_dict = {}
|
||||
for k in self.quant_description.keys():
|
||||
if "shared_head" in k:
|
||||
new_k = k.replace(".shared_head.", ".")
|
||||
extra_quant_dict[new_k] = self.quant_description[k]
|
||||
self.quant_description.update(extra_quant_dict)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return "AscendQuantConfig:\n" + super().__repr__()
|
||||
|
||||
@@ -11,6 +11,7 @@ from vllm.forward_context import BatchDescriptor, get_forward_context
|
||||
from vllm.model_executor.model_loader import get_model_loader
|
||||
from vllm.model_executor.model_loader.utils import (
|
||||
process_weights_after_loading, set_default_torch_dtype)
|
||||
from vllm.model_executor.models.deepseek_mtp import DeepSeekMTP
|
||||
from vllm.v1.core.sched.output import SchedulerOutput
|
||||
from vllm.v1.sample.metadata import SamplingMetadata
|
||||
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
|
||||
@@ -18,7 +19,6 @@ from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
|
||||
from vllm_ascend.ascend_config import get_ascend_config
|
||||
from vllm_ascend.ascend_forward_context import set_ascend_forward_context
|
||||
from vllm_ascend.attention.utils import AscendCommonAttentionMetadata
|
||||
from vllm_ascend.models.deepseek_mtp import CustomDeepSeekMTP
|
||||
from vllm_ascend.spec_decode.interface import Proposer, SpecDcodeType
|
||||
from vllm_ascend.torchair.models.torchair_deepseek_mtp import \
|
||||
TorchairDeepSeekMTP
|
||||
@@ -86,7 +86,7 @@ class MtpProposer(Proposer):
|
||||
self.model = TorchairDeepSeekMTP(
|
||||
vllm_config=self.vllm_config).to(target_device)
|
||||
else:
|
||||
self.model = CustomDeepSeekMTP(
|
||||
self.model = DeepSeekMTP(
|
||||
vllm_config=self.vllm_config).to(target_device)
|
||||
|
||||
draft_attn_layer_names = (
|
||||
@@ -184,7 +184,7 @@ class MtpProposer(Proposer):
|
||||
else:
|
||||
self.model(input_ids=input_ids,
|
||||
positions=positions,
|
||||
previous_hidden_states=previous_hidden_states)
|
||||
hidden_states=previous_hidden_states)
|
||||
if with_prefill:
|
||||
break
|
||||
|
||||
@@ -470,9 +470,8 @@ class MtpProposer(Proposer):
|
||||
hidden_states = self.model(
|
||||
input_ids=self.input_ids[:num_input_tokens],
|
||||
positions=self.positions[:num_input_tokens],
|
||||
previous_hidden_states=self.
|
||||
hidden_states[:num_input_tokens],
|
||||
kv_caches=self.runner.kv_caches[-1:])
|
||||
hidden_states=self.hidden_states[:num_input_tokens]
|
||||
)
|
||||
|
||||
num_indices = last_token_indices.shape[0]
|
||||
if lmhead_tp_enable():
|
||||
@@ -485,7 +484,7 @@ class MtpProposer(Proposer):
|
||||
(0, max_num_reqs_across_dp - num_indices))
|
||||
|
||||
sample_hidden_states = hidden_states[last_token_indices]
|
||||
logits = self.model.compute_logits(sample_hidden_states, None)
|
||||
logits = self.model.compute_logits(sample_hidden_states)
|
||||
if lmhead_tp_enable() and num_indices < logits.shape[0]:
|
||||
logits = logits[:num_indices]
|
||||
draft_token_ids = logits.argmax(dim=-1)
|
||||
|
||||
Reference in New Issue
Block a user