diff --git a/vllm_ascend/spec_decode/mtp_proposer.py b/vllm_ascend/spec_decode/mtp_proposer.py index a6a43b4d..d6bf784f 100644 --- a/vllm_ascend/spec_decode/mtp_proposer.py +++ b/vllm_ascend/spec_decode/mtp_proposer.py @@ -47,6 +47,8 @@ PADDING_SLOT_ID = -1 _MTP_MODELS = { "DeepseekV3ForCausalLM": + ("vllm.model_executor.models.deepseek_mtp", "DeepSeekMTP"), + "DeepseekV32ForCausalLM": ("vllm.model_executor.models.deepseek_mtp", "DeepSeekMTP") } @@ -813,26 +815,28 @@ class MtpProposer(Proposer): attn_metadata_i.slot_mapping.fill_(-1) attn_metadata_i.query_start_loc = self.arange[:batch_size + 1] last_token_indices = self.arange[:batch_size] - if attn_metadata_i.num_decode_tokens != 0: + if getattr(attn_metadata_i, "num_decode_tokens", 0): attn_metadata_i.num_decode_tokens = batch_size input_ids = draft_token_ids_list[-1].int() positions += 1 + decode_metadata = getattr(attn_metadata_i, "decode", None) + prefill_metadata = getattr(attn_metadata_i, "prefill", None) # When disable_padded_drafter_batch=False, it should not to be updating these params, maybe. - if self.speculative_config.disable_padded_drafter_batch or \ - aclgraph_runtime_mode != CUDAGraphMode.FULL: - attn_metadata_i.decode.actual_seq_lengths_q = attn_metadata_i.query_start_loc[ + if decode_metadata is not None and (self.speculative_config.disable_padded_drafter_batch or \ + aclgraph_runtime_mode != CUDAGraphMode.FULL): + decode_metadata.actual_seq_lengths_q = attn_metadata_i.query_start_loc[ 1:batch_size + 1].tolist() if aclgraph_runtime_mode == CUDAGraphMode.FULL: - attn_metadata_i.decode.actual_seq_lengths_q = \ + decode_metadata.actual_seq_lengths_q = \ builder.pad_actual_seq_len_q_mtp_disable_pad( graph_pad_size - batch_size, batch_size, - attn_metadata_i.decode.actual_seq_lengths_q) - attn_metadata_i.decode.cos = builder.cos_cache[ + decode_metadata.actual_seq_lengths_q) + decode_metadata.cos = builder.cos_cache[ positions[:batch_size]].unsqueeze(1).unsqueeze(2) - attn_metadata_i.decode.sin = builder.sin_cache[ + decode_metadata.sin = builder.sin_cache[ positions[:batch_size]].unsqueeze(1).unsqueeze(2) # NOTE(woosuk): We should handle the case where the draft model # generates tokens beyond the max model length. Since it is complex @@ -870,32 +874,32 @@ class MtpProposer(Proposer): self.input_ids[batch_size:num_input_tokens] = 0 self.hidden_states[batch_size:num_input_tokens].fill_(0) - if attn_metadata_i.prefill is not None: - attn_metadata_i.prefill.seq_lens = attn_metadata_i.seq_lens - attn_metadata_i.prefill.seq_lens_list = attn_metadata_i.prefill.seq_lens.tolist( + if prefill_metadata is not None: + prefill_metadata.seq_lens = attn_metadata_i.seq_lens + prefill_metadata.seq_lens_list = prefill_metadata.seq_lens.tolist( ) - attn_metadata_i.prefill.context_lens = attn_metadata_i.seq_lens - attn_metadata_i.prefill.input_positions = self.positions[: - num_input_tokens] - attn_metadata_i.prefill.max_seq_lens += 1 - attn_metadata_i.prefill.max_seq_lens = min( - attn_metadata_i.prefill.max_seq_lens, + prefill_metadata.context_lens = attn_metadata_i.seq_lens + prefill_metadata.input_positions = self.positions[: + num_input_tokens] + prefill_metadata.max_seq_lens += 1 + prefill_metadata.max_seq_lens = min( + prefill_metadata.max_seq_lens, self.runner.model_config.max_model_len) - if attn_metadata_i.decode is not None: - attn_metadata_i.decode.seq_lens = attn_metadata_i.seq_lens - attn_metadata_i.decode.seq_lens_list = attn_metadata_i.decode.seq_lens.tolist( + if decode_metadata is not None: + decode_metadata.seq_lens = attn_metadata_i.seq_lens + decode_metadata.seq_lens_list = decode_metadata.seq_lens.tolist( ) - decode_seq_lens_list = attn_metadata_i.decode.seq_lens_list + decode_seq_lens_list = decode_metadata.seq_lens_list if aclgraph_runtime_mode == CUDAGraphMode.FULL and \ self.speculative_config.disable_padded_drafter_batch: - attn_metadata_i.decode.seq_lens_list = decode_seq_lens_list + [ + decode_metadata.seq_lens_list = decode_seq_lens_list + [ 0 ] * (graph_pad_size - len(decode_seq_lens_list)) - attn_metadata_i.decode.input_positions = self.positions[: - num_input_tokens] - attn_metadata_i.decode.max_seq_lens += 1 - attn_metadata_i.decode.max_seq_lens = min( - attn_metadata_i.decode.max_seq_lens, + decode_metadata.input_positions = self.positions[: + num_input_tokens] + decode_metadata.max_seq_lens += 1 + decode_metadata.max_seq_lens = min( + decode_metadata.max_seq_lens, self.runner.model_config.max_model_len) # mtp>1: [batch_size, k]