diff --git a/vllm_ascend/spec_decode/mtp_proposer.py b/vllm_ascend/spec_decode/mtp_proposer.py index d340cdbc..08930190 100644 --- a/vllm_ascend/spec_decode/mtp_proposer.py +++ b/vllm_ascend/spec_decode/mtp_proposer.py @@ -200,6 +200,16 @@ class MtpProposer(Proposer): process_weights_after_loading(self.model, draft_model_config, target_device) + # share the main model's embedding and LM head when MTP weights match + main_model = model + if torch.equal(self.model.model.embed_tokens.weight, + main_model.model.embed_tokens.weight): + self.model.model.embed_tokens = main_model.model.embed_tokens + for _, layer_module in self.model.model.layers.items(): + if torch.equal(layer_module.shared_head.head.weight, + main_model.lm_head.weight): + layer_module.shared_head.head = main_model.lm_head + if self.vllm_config.compilation_config.cudagraph_mode.has_full_cudagraphs( ): self.update_stream: torch.npu.Stream = torch.npu.Stream()