[Feat] MTP support DeepSeekV3.2 (#4465)
### What this PR does / why we need it?

Currently, MTP does not support the DeepSeekV3.2 model. This PR enables that support.

- vLLM version: v0.11.2
- vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.2

Signed-off-by: ZYang6263 <zy626375@gmail.com>
This commit is contained in:
@@ -47,6 +47,8 @@ PADDING_SLOT_ID = -1
|
|||||||
|
|
||||||
_MTP_MODELS = {
|
_MTP_MODELS = {
|
||||||
"DeepseekV3ForCausalLM":
|
"DeepseekV3ForCausalLM":
|
||||||
|
("vllm.model_executor.models.deepseek_mtp", "DeepSeekMTP"),
|
||||||
|
"DeepseekV32ForCausalLM":
|
||||||
("vllm.model_executor.models.deepseek_mtp", "DeepSeekMTP")
|
("vllm.model_executor.models.deepseek_mtp", "DeepSeekMTP")
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -813,26 +815,28 @@ class MtpProposer(Proposer):
|
|||||||
attn_metadata_i.slot_mapping.fill_(-1)
|
attn_metadata_i.slot_mapping.fill_(-1)
|
||||||
attn_metadata_i.query_start_loc = self.arange[:batch_size + 1]
|
attn_metadata_i.query_start_loc = self.arange[:batch_size + 1]
|
||||||
last_token_indices = self.arange[:batch_size]
|
last_token_indices = self.arange[:batch_size]
|
||||||
if attn_metadata_i.num_decode_tokens != 0:
|
if getattr(attn_metadata_i, "num_decode_tokens", 0):
|
||||||
attn_metadata_i.num_decode_tokens = batch_size
|
attn_metadata_i.num_decode_tokens = batch_size
|
||||||
|
|
||||||
input_ids = draft_token_ids_list[-1].int()
|
input_ids = draft_token_ids_list[-1].int()
|
||||||
positions += 1
|
positions += 1
|
||||||
|
|
||||||
|
decode_metadata = getattr(attn_metadata_i, "decode", None)
|
||||||
|
prefill_metadata = getattr(attn_metadata_i, "prefill", None)
|
||||||
# When disable_padded_drafter_batch=False, it should not to be updating these params, maybe.
|
# When disable_padded_drafter_batch=False, it should not to be updating these params, maybe.
|
||||||
if self.speculative_config.disable_padded_drafter_batch or \
|
if decode_metadata is not None and (self.speculative_config.disable_padded_drafter_batch or \
|
||||||
aclgraph_runtime_mode != CUDAGraphMode.FULL:
|
aclgraph_runtime_mode != CUDAGraphMode.FULL):
|
||||||
attn_metadata_i.decode.actual_seq_lengths_q = attn_metadata_i.query_start_loc[
|
decode_metadata.actual_seq_lengths_q = attn_metadata_i.query_start_loc[
|
||||||
1:batch_size + 1].tolist()
|
1:batch_size + 1].tolist()
|
||||||
if aclgraph_runtime_mode == CUDAGraphMode.FULL:
|
if aclgraph_runtime_mode == CUDAGraphMode.FULL:
|
||||||
attn_metadata_i.decode.actual_seq_lengths_q = \
|
decode_metadata.actual_seq_lengths_q = \
|
||||||
builder.pad_actual_seq_len_q_mtp_disable_pad(
|
builder.pad_actual_seq_len_q_mtp_disable_pad(
|
||||||
graph_pad_size - batch_size,
|
graph_pad_size - batch_size,
|
||||||
batch_size,
|
batch_size,
|
||||||
attn_metadata_i.decode.actual_seq_lengths_q)
|
decode_metadata.actual_seq_lengths_q)
|
||||||
attn_metadata_i.decode.cos = builder.cos_cache[
|
decode_metadata.cos = builder.cos_cache[
|
||||||
positions[:batch_size]].unsqueeze(1).unsqueeze(2)
|
positions[:batch_size]].unsqueeze(1).unsqueeze(2)
|
||||||
attn_metadata_i.decode.sin = builder.sin_cache[
|
decode_metadata.sin = builder.sin_cache[
|
||||||
positions[:batch_size]].unsqueeze(1).unsqueeze(2)
|
positions[:batch_size]].unsqueeze(1).unsqueeze(2)
|
||||||
# NOTE(woosuk): We should handle the case where the draft model
|
# NOTE(woosuk): We should handle the case where the draft model
|
||||||
# generates tokens beyond the max model length. Since it is complex
|
# generates tokens beyond the max model length. Since it is complex
|
||||||
@@ -870,32 +874,32 @@ class MtpProposer(Proposer):
|
|||||||
self.input_ids[batch_size:num_input_tokens] = 0
|
self.input_ids[batch_size:num_input_tokens] = 0
|
||||||
self.hidden_states[batch_size:num_input_tokens].fill_(0)
|
self.hidden_states[batch_size:num_input_tokens].fill_(0)
|
||||||
|
|
||||||
if attn_metadata_i.prefill is not None:
|
if prefill_metadata is not None:
|
||||||
attn_metadata_i.prefill.seq_lens = attn_metadata_i.seq_lens
|
prefill_metadata.seq_lens = attn_metadata_i.seq_lens
|
||||||
attn_metadata_i.prefill.seq_lens_list = attn_metadata_i.prefill.seq_lens.tolist(
|
prefill_metadata.seq_lens_list = prefill_metadata.seq_lens.tolist(
|
||||||
)
|
)
|
||||||
attn_metadata_i.prefill.context_lens = attn_metadata_i.seq_lens
|
prefill_metadata.context_lens = attn_metadata_i.seq_lens
|
||||||
attn_metadata_i.prefill.input_positions = self.positions[:
|
prefill_metadata.input_positions = self.positions[:
|
||||||
num_input_tokens]
|
num_input_tokens]
|
||||||
attn_metadata_i.prefill.max_seq_lens += 1
|
prefill_metadata.max_seq_lens += 1
|
||||||
attn_metadata_i.prefill.max_seq_lens = min(
|
prefill_metadata.max_seq_lens = min(
|
||||||
attn_metadata_i.prefill.max_seq_lens,
|
prefill_metadata.max_seq_lens,
|
||||||
self.runner.model_config.max_model_len)
|
self.runner.model_config.max_model_len)
|
||||||
if attn_metadata_i.decode is not None:
|
if decode_metadata is not None:
|
||||||
attn_metadata_i.decode.seq_lens = attn_metadata_i.seq_lens
|
decode_metadata.seq_lens = attn_metadata_i.seq_lens
|
||||||
attn_metadata_i.decode.seq_lens_list = attn_metadata_i.decode.seq_lens.tolist(
|
decode_metadata.seq_lens_list = decode_metadata.seq_lens.tolist(
|
||||||
)
|
)
|
||||||
decode_seq_lens_list = attn_metadata_i.decode.seq_lens_list
|
decode_seq_lens_list = decode_metadata.seq_lens_list
|
||||||
if aclgraph_runtime_mode == CUDAGraphMode.FULL and \
|
if aclgraph_runtime_mode == CUDAGraphMode.FULL and \
|
||||||
self.speculative_config.disable_padded_drafter_batch:
|
self.speculative_config.disable_padded_drafter_batch:
|
||||||
attn_metadata_i.decode.seq_lens_list = decode_seq_lens_list + [
|
decode_metadata.seq_lens_list = decode_seq_lens_list + [
|
||||||
0
|
0
|
||||||
] * (graph_pad_size - len(decode_seq_lens_list))
|
] * (graph_pad_size - len(decode_seq_lens_list))
|
||||||
attn_metadata_i.decode.input_positions = self.positions[:
|
decode_metadata.input_positions = self.positions[:
|
||||||
num_input_tokens]
|
num_input_tokens]
|
||||||
attn_metadata_i.decode.max_seq_lens += 1
|
decode_metadata.max_seq_lens += 1
|
||||||
attn_metadata_i.decode.max_seq_lens = min(
|
decode_metadata.max_seq_lens = min(
|
||||||
attn_metadata_i.decode.max_seq_lens,
|
decode_metadata.max_seq_lens,
|
||||||
self.runner.model_config.max_model_len)
|
self.runner.model_config.max_model_len)
|
||||||
|
|
||||||
# mtp>1: [batch_size, k]
|
# mtp>1: [batch_size, k]
|
||||||
|
|||||||
Reference in New Issue
Block a user