diff --git a/vllm_ascend/spec_decode/mtp_proposer.py b/vllm_ascend/spec_decode/mtp_proposer.py
index c2a2a845..9f6d7874 100644
--- a/vllm_ascend/spec_decode/mtp_proposer.py
+++ b/vllm_ascend/spec_decode/mtp_proposer.py
@@ -13,6 +13,7 @@ from vllm.model_executor.model_loader import get_model_loader
 from vllm.model_executor.model_loader.utils import \
     process_weights_after_loading
 from vllm.model_executor.models.deepseek_mtp import DeepSeekMTP
+from vllm.model_executor.models.deepseek_v2 import DeepseekV32IndexerCache
 from vllm.model_executor.models.llama_eagle3 import Eagle3LlamaForCausalLM
 from vllm.utils import cdiv
 from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder,
@@ -139,6 +140,9 @@ class MtpProposer(Proposer):
         target_attn_layer_names = set(
             get_layers_from_vllm_config(self.vllm_config,
                                         AttentionLayerBase).keys())
+        target_indexer_layer_names = set(
+            get_layers_from_vllm_config(self.vllm_config,
+                                        DeepseekV32IndexerCache).keys())
         draft_model_config = \
             self.vllm_config.speculative_config.draft_model_config
         target_device = self.vllm_config.device_config.device
@@ -152,6 +156,13 @@ class MtpProposer(Proposer):
         draft_attn_layer_names = (get_layers_from_vllm_config(
             self.vllm_config, AttentionLayerBase).keys() -
                                   target_attn_layer_names)
+        indexer_layers = get_layers_from_vllm_config(self.vllm_config,
+                                                     DeepseekV32IndexerCache)
+        draft_indexer_layer_names = indexer_layers.keys(
+        ) - target_indexer_layer_names
+        # NOTE: Currently we don't have specific attention backend and attention metadata
+        # for deepseek v3.2 indexer, so we just exclude the indexer layers here.
+        draft_attn_layer_names = draft_attn_layer_names - draft_indexer_layer_names
         assert len(draft_attn_layer_names) == 1
         self.attn_layer_name = list(draft_attn_layer_names)