[Feat] MTP support DeepSeekV3.2 (#4465)
### What this PR does / why we need it?

Currently, MTP does not support the DeepSeekV3.2 model. In this PR, we have enabled this feature.

- vLLM version: v0.11.2
- vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.2

Signed-off-by: ZYang6263 <zy626375@gmail.com>
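For reviewers who want to try this end to end, a minimal usage sketch follows. It assumes the generic vLLM `speculative_config` interface for MTP draft models; the method string, the config keys, and the `deepseek-ai/DeepSeek-V3.2-Exp` checkpoint name are illustrative assumptions rather than something defined by this PR, so check them against the vLLM/vllm-ascend version you run.

```python
from vllm import LLM, SamplingParams

# Illustrative sketch only: the model name and speculative_config keys below are
# assumptions, not part of this PR; verify them against your vLLM version.
llm = LLM(
    model="deepseek-ai/DeepSeek-V3.2-Exp",
    speculative_config={
        "method": "deepseek_mtp",     # MTP draft weights shipped inside the checkpoint
        "num_speculative_tokens": 1,  # one draft token proposed per step
    },
)

outputs = llm.generate(["The capital of France is"],
                       SamplingParams(temperature=0.0, max_tokens=32))
print(outputs[0].outputs[0].text)
```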
@@ -47,6 +47,8 @@ PADDING_SLOT_ID = -1

 _MTP_MODELS = {
     "DeepseekV3ForCausalLM":
     ("vllm.model_executor.models.deepseek_mtp", "DeepSeekMTP"),
+    "DeepseekV32ForCausalLM":
+    ("vllm.model_executor.models.deepseek_mtp", "DeepSeekMTP")
 }

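The hunk above registers the `DeepseekV32ForCausalLM` architecture against the same `DeepSeekMTP` draft-model class already used for DeepSeek-V3. As a minimal sketch of how a (module path, class name) registry like this can be resolved lazily, see below; the helper name `resolve_mtp_model` is hypothetical and not part of this PR or of vLLM.

```python
import importlib

# Same shape as the registry in the hunk above: architecture -> (module path, class name).
_MTP_MODELS = {
    "DeepseekV3ForCausalLM":
    ("vllm.model_executor.models.deepseek_mtp", "DeepSeekMTP"),
    "DeepseekV32ForCausalLM":
    ("vllm.model_executor.models.deepseek_mtp", "DeepSeekMTP"),
}


def resolve_mtp_model(architecture: str):
    """Hypothetical helper: import and return the MTP draft-model class for an architecture."""
    module_path, class_name = _MTP_MODELS[architecture]
    module = importlib.import_module(module_path)
    return getattr(module, class_name)
```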
@@ -813,26 +815,28 @@ class MtpProposer(Proposer):
             attn_metadata_i.slot_mapping.fill_(-1)
             attn_metadata_i.query_start_loc = self.arange[:batch_size + 1]
             last_token_indices = self.arange[:batch_size]
-            if attn_metadata_i.num_decode_tokens != 0:
+            if getattr(attn_metadata_i, "num_decode_tokens", 0):
                 attn_metadata_i.num_decode_tokens = batch_size

             input_ids = draft_token_ids_list[-1].int()
             positions += 1

+            decode_metadata = getattr(attn_metadata_i, "decode", None)
+            prefill_metadata = getattr(attn_metadata_i, "prefill", None)
             # When disable_padded_drafter_batch=False, these params may not need to be updated.
-            if self.speculative_config.disable_padded_drafter_batch or \
-                aclgraph_runtime_mode != CUDAGraphMode.FULL:
-                attn_metadata_i.decode.actual_seq_lengths_q = attn_metadata_i.query_start_loc[
+            if decode_metadata is not None and (self.speculative_config.disable_padded_drafter_batch or \
+                    aclgraph_runtime_mode != CUDAGraphMode.FULL):
+                decode_metadata.actual_seq_lengths_q = attn_metadata_i.query_start_loc[
                     1:batch_size + 1].tolist()
                 if aclgraph_runtime_mode == CUDAGraphMode.FULL:
-                    attn_metadata_i.decode.actual_seq_lengths_q = \
+                    decode_metadata.actual_seq_lengths_q = \
                         builder.pad_actual_seq_len_q_mtp_disable_pad(
                             graph_pad_size - batch_size,
                             batch_size,
-                            attn_metadata_i.decode.actual_seq_lengths_q)
-            attn_metadata_i.decode.cos = builder.cos_cache[
+                            decode_metadata.actual_seq_lengths_q)
+            decode_metadata.cos = builder.cos_cache[
                 positions[:batch_size]].unsqueeze(1).unsqueeze(2)
-            attn_metadata_i.decode.sin = builder.sin_cache[
+            decode_metadata.sin = builder.sin_cache[
                 positions[:batch_size]].unsqueeze(1).unsqueeze(2)
         # NOTE(woosuk): We should handle the case where the draft model
         # generates tokens beyond the max model length. Since it is complex
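The hunk above replaces direct attribute access with `getattr(..., None)` plus a `None` check, so attention-metadata variants that lack a `decode`/`prefill` sub-struct or a `num_decode_tokens` field (as can happen on the DeepSeekV3.2 path) no longer raise `AttributeError`. Below is a self-contained sketch of that pattern; the `AttnMetadata`/`DecodeMetadata` classes are illustrative stand-ins, not the real vLLM metadata types.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class DecodeMetadata:
    # Illustrative stand-in for the per-decode metadata touched in the hunk above.
    actual_seq_lengths_q: Optional[list] = None


@dataclass
class AttnMetadata:
    # Some attention backends may not populate these sub-structs at all.
    decode: Optional[DecodeMetadata] = None
    prefill: Optional[object] = None


def update_decode_metadata(meta: AttnMetadata, seq_lens_q: list) -> None:
    # getattr with a default tolerates metadata classes that lack the field entirely,
    # and the None check skips the update when the sub-struct is absent.
    decode_metadata = getattr(meta, "decode", None)
    if decode_metadata is not None:
        decode_metadata.actual_seq_lengths_q = seq_lens_q


update_decode_metadata(AttnMetadata(), [1, 2, 3])                  # no-op, no AttributeError
update_decode_metadata(AttnMetadata(decode=DecodeMetadata()), [1, 2, 3])
```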
@@ -870,32 +874,32 @@ class MtpProposer(Proposer):
                 self.input_ids[batch_size:num_input_tokens] = 0
                 self.hidden_states[batch_size:num_input_tokens].fill_(0)

-            if attn_metadata_i.prefill is not None:
-                attn_metadata_i.prefill.seq_lens = attn_metadata_i.seq_lens
-                attn_metadata_i.prefill.seq_lens_list = attn_metadata_i.prefill.seq_lens.tolist(
+            if prefill_metadata is not None:
+                prefill_metadata.seq_lens = attn_metadata_i.seq_lens
+                prefill_metadata.seq_lens_list = prefill_metadata.seq_lens.tolist(
                 )
-                attn_metadata_i.prefill.context_lens = attn_metadata_i.seq_lens
-                attn_metadata_i.prefill.input_positions = self.positions[:
-                                                                          num_input_tokens]
-                attn_metadata_i.prefill.max_seq_lens += 1
-                attn_metadata_i.prefill.max_seq_lens = min(
-                    attn_metadata_i.prefill.max_seq_lens,
+                prefill_metadata.context_lens = attn_metadata_i.seq_lens
+                prefill_metadata.input_positions = self.positions[:
+                                                                  num_input_tokens]
+                prefill_metadata.max_seq_lens += 1
+                prefill_metadata.max_seq_lens = min(
+                    prefill_metadata.max_seq_lens,
                     self.runner.model_config.max_model_len)
-            if attn_metadata_i.decode is not None:
-                attn_metadata_i.decode.seq_lens = attn_metadata_i.seq_lens
-                attn_metadata_i.decode.seq_lens_list = attn_metadata_i.decode.seq_lens.tolist(
+            if decode_metadata is not None:
+                decode_metadata.seq_lens = attn_metadata_i.seq_lens
+                decode_metadata.seq_lens_list = decode_metadata.seq_lens.tolist(
                 )
-                decode_seq_lens_list = attn_metadata_i.decode.seq_lens_list
+                decode_seq_lens_list = decode_metadata.seq_lens_list
                 if aclgraph_runtime_mode == CUDAGraphMode.FULL and \
                     self.speculative_config.disable_padded_drafter_batch:
-                    attn_metadata_i.decode.seq_lens_list = decode_seq_lens_list + [
+                    decode_metadata.seq_lens_list = decode_seq_lens_list + [
                         0
                     ] * (graph_pad_size - len(decode_seq_lens_list))
-                attn_metadata_i.decode.input_positions = self.positions[:
-                                                                        num_input_tokens]
-                attn_metadata_i.decode.max_seq_lens += 1
-                attn_metadata_i.decode.max_seq_lens = min(
-                    attn_metadata_i.decode.max_seq_lens,
+                decode_metadata.input_positions = self.positions[:
+                                                                 num_input_tokens]
+                decode_metadata.max_seq_lens += 1
+                decode_metadata.max_seq_lens = min(
+                    decode_metadata.max_seq_lens,
                     self.runner.model_config.max_model_len)

         # mtp>1: [batch_size, k]
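For the full-graph branch touched above: when ACL graph capture runs in `CUDAGraphMode.FULL` with padded drafter batches disabled, the decode `seq_lens_list` is extended with zeros out to the captured graph size, so the replayed graph always sees a list of fixed length. A standalone illustration of that padding step (the function name is chosen for illustration and is not the vllm-ascend API):

```python
def pad_seq_lens_for_graph(seq_lens_list: list[int], graph_pad_size: int) -> list[int]:
    """Illustrative mirror of the zero-padding applied to decode_metadata.seq_lens_list."""
    assert graph_pad_size >= len(seq_lens_list), "graph must be captured for at least this batch"
    return seq_lens_list + [0] * (graph_pad_size - len(seq_lens_list))


print(pad_seq_lens_for_graph([5, 7, 3], 8))  # -> [5, 7, 3, 0, 0, 0, 0, 0]
```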