[Feat] MTP support DeepSeekV3.2 (#4465)

### What this PR does / why we need it?
Currently, MTP does not support the DeepSeekV3.2 model. In this PR, we
have enabled this feature.

- vLLM version: v0.11.2
- vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.2

Signed-off-by: ZYang6263 <zy626375@gmail.com>
This commit is contained in:
ZYang6263
2025-12-03 14:24:33 +08:00
committed by GitHub
parent 38bd95229f
commit 7271f0d536

View File

@@ -47,6 +47,8 @@ PADDING_SLOT_ID = -1
_MTP_MODELS = { _MTP_MODELS = {
"DeepseekV3ForCausalLM": "DeepseekV3ForCausalLM":
("vllm.model_executor.models.deepseek_mtp", "DeepSeekMTP"),
"DeepseekV32ForCausalLM":
("vllm.model_executor.models.deepseek_mtp", "DeepSeekMTP") ("vllm.model_executor.models.deepseek_mtp", "DeepSeekMTP")
} }
@@ -813,26 +815,28 @@ class MtpProposer(Proposer):
attn_metadata_i.slot_mapping.fill_(-1) attn_metadata_i.slot_mapping.fill_(-1)
attn_metadata_i.query_start_loc = self.arange[:batch_size + 1] attn_metadata_i.query_start_loc = self.arange[:batch_size + 1]
last_token_indices = self.arange[:batch_size] last_token_indices = self.arange[:batch_size]
if attn_metadata_i.num_decode_tokens != 0: if getattr(attn_metadata_i, "num_decode_tokens", 0):
attn_metadata_i.num_decode_tokens = batch_size attn_metadata_i.num_decode_tokens = batch_size
input_ids = draft_token_ids_list[-1].int() input_ids = draft_token_ids_list[-1].int()
positions += 1 positions += 1
decode_metadata = getattr(attn_metadata_i, "decode", None)
prefill_metadata = getattr(attn_metadata_i, "prefill", None)
# When disable_padded_drafter_batch=False, it should not to be updating these params, maybe. # When disable_padded_drafter_batch=False, it should not to be updating these params, maybe.
if self.speculative_config.disable_padded_drafter_batch or \ if decode_metadata is not None and (self.speculative_config.disable_padded_drafter_batch or \
aclgraph_runtime_mode != CUDAGraphMode.FULL: aclgraph_runtime_mode != CUDAGraphMode.FULL):
attn_metadata_i.decode.actual_seq_lengths_q = attn_metadata_i.query_start_loc[ decode_metadata.actual_seq_lengths_q = attn_metadata_i.query_start_loc[
1:batch_size + 1].tolist() 1:batch_size + 1].tolist()
if aclgraph_runtime_mode == CUDAGraphMode.FULL: if aclgraph_runtime_mode == CUDAGraphMode.FULL:
attn_metadata_i.decode.actual_seq_lengths_q = \ decode_metadata.actual_seq_lengths_q = \
builder.pad_actual_seq_len_q_mtp_disable_pad( builder.pad_actual_seq_len_q_mtp_disable_pad(
graph_pad_size - batch_size, graph_pad_size - batch_size,
batch_size, batch_size,
attn_metadata_i.decode.actual_seq_lengths_q) decode_metadata.actual_seq_lengths_q)
attn_metadata_i.decode.cos = builder.cos_cache[ decode_metadata.cos = builder.cos_cache[
positions[:batch_size]].unsqueeze(1).unsqueeze(2) positions[:batch_size]].unsqueeze(1).unsqueeze(2)
attn_metadata_i.decode.sin = builder.sin_cache[ decode_metadata.sin = builder.sin_cache[
positions[:batch_size]].unsqueeze(1).unsqueeze(2) positions[:batch_size]].unsqueeze(1).unsqueeze(2)
# NOTE(woosuk): We should handle the case where the draft model # NOTE(woosuk): We should handle the case where the draft model
# generates tokens beyond the max model length. Since it is complex # generates tokens beyond the max model length. Since it is complex
@@ -870,32 +874,32 @@ class MtpProposer(Proposer):
self.input_ids[batch_size:num_input_tokens] = 0 self.input_ids[batch_size:num_input_tokens] = 0
self.hidden_states[batch_size:num_input_tokens].fill_(0) self.hidden_states[batch_size:num_input_tokens].fill_(0)
if attn_metadata_i.prefill is not None: if prefill_metadata is not None:
attn_metadata_i.prefill.seq_lens = attn_metadata_i.seq_lens prefill_metadata.seq_lens = attn_metadata_i.seq_lens
attn_metadata_i.prefill.seq_lens_list = attn_metadata_i.prefill.seq_lens.tolist( prefill_metadata.seq_lens_list = prefill_metadata.seq_lens.tolist(
) )
attn_metadata_i.prefill.context_lens = attn_metadata_i.seq_lens prefill_metadata.context_lens = attn_metadata_i.seq_lens
attn_metadata_i.prefill.input_positions = self.positions[: prefill_metadata.input_positions = self.positions[:
num_input_tokens] num_input_tokens]
attn_metadata_i.prefill.max_seq_lens += 1 prefill_metadata.max_seq_lens += 1
attn_metadata_i.prefill.max_seq_lens = min( prefill_metadata.max_seq_lens = min(
attn_metadata_i.prefill.max_seq_lens, prefill_metadata.max_seq_lens,
self.runner.model_config.max_model_len) self.runner.model_config.max_model_len)
if attn_metadata_i.decode is not None: if decode_metadata is not None:
attn_metadata_i.decode.seq_lens = attn_metadata_i.seq_lens decode_metadata.seq_lens = attn_metadata_i.seq_lens
attn_metadata_i.decode.seq_lens_list = attn_metadata_i.decode.seq_lens.tolist( decode_metadata.seq_lens_list = decode_metadata.seq_lens.tolist(
) )
decode_seq_lens_list = attn_metadata_i.decode.seq_lens_list decode_seq_lens_list = decode_metadata.seq_lens_list
if aclgraph_runtime_mode == CUDAGraphMode.FULL and \ if aclgraph_runtime_mode == CUDAGraphMode.FULL and \
self.speculative_config.disable_padded_drafter_batch: self.speculative_config.disable_padded_drafter_batch:
attn_metadata_i.decode.seq_lens_list = decode_seq_lens_list + [ decode_metadata.seq_lens_list = decode_seq_lens_list + [
0 0
] * (graph_pad_size - len(decode_seq_lens_list)) ] * (graph_pad_size - len(decode_seq_lens_list))
attn_metadata_i.decode.input_positions = self.positions[: decode_metadata.input_positions = self.positions[:
num_input_tokens] num_input_tokens]
attn_metadata_i.decode.max_seq_lens += 1 decode_metadata.max_seq_lens += 1
attn_metadata_i.decode.max_seq_lens = min( decode_metadata.max_seq_lens = min(
attn_metadata_i.decode.max_seq_lens, decode_metadata.max_seq_lens,
self.runner.model_config.max_model_len) self.runner.model_config.max_model_len)
# mtp>1: [batch_size, k] # mtp>1: [batch_size, k]