[Model][2/N] Remove deepseek_mtp modeling. (#3561)

This PR is step 2 of the DeepSeek model refactoring and removes the
deepseek_mtp modeling code.

- vLLM version: v0.11.0rc3
- vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.0

---------

Signed-off-by: whx-sjtu <2952154980@qq.com>
This commit is contained in:
whx
2025-10-21 20:17:09 +08:00
committed by GitHub
parent ffb42a8daa
commit 220df60c61
7 changed files with 38 additions and 422 deletions

View File

@@ -11,6 +11,7 @@ from vllm.forward_context import BatchDescriptor, get_forward_context
from vllm.model_executor.model_loader import get_model_loader
from vllm.model_executor.model_loader.utils import (
process_weights_after_loading, set_default_torch_dtype)
from vllm.model_executor.models.deepseek_mtp import DeepSeekMTP
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
@@ -18,7 +19,6 @@ from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
from vllm_ascend.ascend_config import get_ascend_config
from vllm_ascend.ascend_forward_context import set_ascend_forward_context
from vllm_ascend.attention.utils import AscendCommonAttentionMetadata
from vllm_ascend.models.deepseek_mtp import CustomDeepSeekMTP
from vllm_ascend.spec_decode.interface import Proposer, SpecDcodeType
from vllm_ascend.torchair.models.torchair_deepseek_mtp import \
TorchairDeepSeekMTP
@@ -86,7 +86,7 @@ class MtpProposer(Proposer):
self.model = TorchairDeepSeekMTP(
vllm_config=self.vllm_config).to(target_device)
else:
self.model = CustomDeepSeekMTP(
self.model = DeepSeekMTP(
vllm_config=self.vllm_config).to(target_device)
draft_attn_layer_names = (
@@ -184,7 +184,7 @@ class MtpProposer(Proposer):
else:
self.model(input_ids=input_ids,
positions=positions,
previous_hidden_states=previous_hidden_states)
hidden_states=previous_hidden_states)
if with_prefill:
break
@@ -470,9 +470,8 @@ class MtpProposer(Proposer):
hidden_states = self.model(
input_ids=self.input_ids[:num_input_tokens],
positions=self.positions[:num_input_tokens],
previous_hidden_states=self.
hidden_states[:num_input_tokens],
kv_caches=self.runner.kv_caches[-1:])
hidden_states=self.hidden_states[:num_input_tokens]
)
num_indices = last_token_indices.shape[0]
if lmhead_tp_enable():
@@ -485,7 +484,7 @@ class MtpProposer(Proposer):
(0, max_num_reqs_across_dp - num_indices))
sample_hidden_states = hidden_states[last_token_indices]
logits = self.model.compute_logits(sample_hidden_states, None)
logits = self.model.compute_logits(sample_hidden_states)
if lmhead_tp_enable() and num_indices < logits.shape[0]:
logits = logits[:num_indices]
draft_token_ids = logits.argmax(dim=-1)