[Feat] MTP support DeepSeekV3.2 (#4465)

### What this PR does / why we need it?
Currently, MTP does not support the DeepSeekV3.2 model. In this PR, we
have enabled this feature.

- vLLM version: v0.11.2
- vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.2

Signed-off-by: ZYang6263 <zy626375@gmail.com>
This commit is contained in:
ZYang6263
2025-12-03 14:24:33 +08:00
committed by GitHub
parent 38bd95229f
commit 7271f0d536

View File

@@ -47,6 +47,8 @@ PADDING_SLOT_ID = -1
_MTP_MODELS = {
"DeepseekV3ForCausalLM":
("vllm.model_executor.models.deepseek_mtp", "DeepSeekMTP"),
"DeepseekV32ForCausalLM":
("vllm.model_executor.models.deepseek_mtp", "DeepSeekMTP")
}
@@ -813,26 +815,28 @@ class MtpProposer(Proposer):
attn_metadata_i.slot_mapping.fill_(-1)
attn_metadata_i.query_start_loc = self.arange[:batch_size + 1]
last_token_indices = self.arange[:batch_size]
if attn_metadata_i.num_decode_tokens != 0:
if getattr(attn_metadata_i, "num_decode_tokens", 0):
attn_metadata_i.num_decode_tokens = batch_size
input_ids = draft_token_ids_list[-1].int()
positions += 1
decode_metadata = getattr(attn_metadata_i, "decode", None)
prefill_metadata = getattr(attn_metadata_i, "prefill", None)
# When disable_padded_drafter_batch=False, it should not to be updating these params, maybe.
if self.speculative_config.disable_padded_drafter_batch or \
aclgraph_runtime_mode != CUDAGraphMode.FULL:
attn_metadata_i.decode.actual_seq_lengths_q = attn_metadata_i.query_start_loc[
if decode_metadata is not None and (self.speculative_config.disable_padded_drafter_batch or \
aclgraph_runtime_mode != CUDAGraphMode.FULL):
decode_metadata.actual_seq_lengths_q = attn_metadata_i.query_start_loc[
1:batch_size + 1].tolist()
if aclgraph_runtime_mode == CUDAGraphMode.FULL:
attn_metadata_i.decode.actual_seq_lengths_q = \
decode_metadata.actual_seq_lengths_q = \
builder.pad_actual_seq_len_q_mtp_disable_pad(
graph_pad_size - batch_size,
batch_size,
attn_metadata_i.decode.actual_seq_lengths_q)
attn_metadata_i.decode.cos = builder.cos_cache[
decode_metadata.actual_seq_lengths_q)
decode_metadata.cos = builder.cos_cache[
positions[:batch_size]].unsqueeze(1).unsqueeze(2)
attn_metadata_i.decode.sin = builder.sin_cache[
decode_metadata.sin = builder.sin_cache[
positions[:batch_size]].unsqueeze(1).unsqueeze(2)
# NOTE(woosuk): We should handle the case where the draft model
# generates tokens beyond the max model length. Since it is complex
@@ -870,32 +874,32 @@ class MtpProposer(Proposer):
self.input_ids[batch_size:num_input_tokens] = 0
self.hidden_states[batch_size:num_input_tokens].fill_(0)
if attn_metadata_i.prefill is not None:
attn_metadata_i.prefill.seq_lens = attn_metadata_i.seq_lens
attn_metadata_i.prefill.seq_lens_list = attn_metadata_i.prefill.seq_lens.tolist(
if prefill_metadata is not None:
prefill_metadata.seq_lens = attn_metadata_i.seq_lens
prefill_metadata.seq_lens_list = prefill_metadata.seq_lens.tolist(
)
attn_metadata_i.prefill.context_lens = attn_metadata_i.seq_lens
attn_metadata_i.prefill.input_positions = self.positions[:
num_input_tokens]
attn_metadata_i.prefill.max_seq_lens += 1
attn_metadata_i.prefill.max_seq_lens = min(
attn_metadata_i.prefill.max_seq_lens,
prefill_metadata.context_lens = attn_metadata_i.seq_lens
prefill_metadata.input_positions = self.positions[:
num_input_tokens]
prefill_metadata.max_seq_lens += 1
prefill_metadata.max_seq_lens = min(
prefill_metadata.max_seq_lens,
self.runner.model_config.max_model_len)
if attn_metadata_i.decode is not None:
attn_metadata_i.decode.seq_lens = attn_metadata_i.seq_lens
attn_metadata_i.decode.seq_lens_list = attn_metadata_i.decode.seq_lens.tolist(
if decode_metadata is not None:
decode_metadata.seq_lens = attn_metadata_i.seq_lens
decode_metadata.seq_lens_list = decode_metadata.seq_lens.tolist(
)
decode_seq_lens_list = attn_metadata_i.decode.seq_lens_list
decode_seq_lens_list = decode_metadata.seq_lens_list
if aclgraph_runtime_mode == CUDAGraphMode.FULL and \
self.speculative_config.disable_padded_drafter_batch:
attn_metadata_i.decode.seq_lens_list = decode_seq_lens_list + [
decode_metadata.seq_lens_list = decode_seq_lens_list + [
0
] * (graph_pad_size - len(decode_seq_lens_list))
attn_metadata_i.decode.input_positions = self.positions[:
num_input_tokens]
attn_metadata_i.decode.max_seq_lens += 1
attn_metadata_i.decode.max_seq_lens = min(
attn_metadata_i.decode.max_seq_lens,
decode_metadata.input_positions = self.positions[:
num_input_tokens]
decode_metadata.max_seq_lens += 1
decode_metadata.max_seq_lens = min(
decode_metadata.max_seq_lens,
self.runner.model_config.max_model_len)
# mtp>1: [batch_size, k]