From 866347a6210d2b5ec45a28d791b0e403cef71cb9 Mon Sep 17 00:00:00 2001
From: zzhxxx
Date: Mon, 8 Dec 2025 10:33:29 +0800
Subject: [PATCH] DeepSeek MTP model uses the lm_head and embedding from the
 main model (#2790)

### What this PR does / why we need it?
The DeepSeek technical report notes that the embedding and LM head of the MTP layer are shared with the main model, but the current implementation independently loads a full copy of each. In the DeepSeek-R1 model, each of these weight matrices is 129280 x 7168 in fp16, i.e. about 1.72 GB. This PR makes the MTP layer reuse the main model's lm_head and embedding, saving about 3.45 GB of GPU memory in the pure-DP scenario.

The loading process still first creates temporary tensors for the MTP layer's embedding and LM head. After loading, torch.equal is called to check whether each temporary matrix is identical to the main model's; if so, the main model's module is reused and the temporary tensor is released.

- vLLM version: v0.12.0
- vLLM main: https://github.com/vllm-project/vllm/commit/ad32e3e19ccf0526cb6744a5fed09a138a5fb2f9

Signed-off-by: zzhx1
Co-authored-by: wangxiyuan
---
 vllm_ascend/spec_decode/mtp_proposer.py | 10 ++++++++++
 1 file changed, 10 insertions(+)

diff --git a/vllm_ascend/spec_decode/mtp_proposer.py b/vllm_ascend/spec_decode/mtp_proposer.py
index d340cdbc..08930190 100644
--- a/vllm_ascend/spec_decode/mtp_proposer.py
+++ b/vllm_ascend/spec_decode/mtp_proposer.py
@@ -200,6 +200,16 @@ class MtpProposer(Proposer):
         process_weights_after_loading(self.model, draft_model_config,
                                       target_device)
 
+        # Check if the MTP model can reuse the main model's embedding and lm_head
+        main_model = model
+        if torch.equal(self.model.model.embed_tokens.weight,
+                       main_model.model.embed_tokens.weight):
+            self.model.model.embed_tokens = main_model.model.embed_tokens
+        for _, layer_module in self.model.model.layers.items():
+            if torch.equal(layer_module.shared_head.head.weight,
+                           main_model.lm_head.weight):
+                layer_module.shared_head.head = main_model.lm_head
+
         if self.vllm_config.compilation_config.cudagraph_mode.has_full_cudagraphs(
         ):
             self.update_stream: torch.npu.Stream = torch.npu.Stream()
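
For readers outside the vllm-ascend codebase, the sharing pattern above can be reproduced in isolation. Below is a minimal sketch of the check-then-reuse idea; the `ToyMain`/`ToyDraft` module names are hypothetical and not part of the patch, and only the `torch.equal` comparison and the module reassignment mirror the actual change:

```python
# Minimal sketch of the weight-sharing pattern from this patch.
# ToyMain/ToyDraft are hypothetical stand-ins for the main and MTP models.
import torch
import torch.nn as nn


class ToyMain(nn.Module):
    def __init__(self, vocab: int = 1000, hidden: int = 64):
        super().__init__()
        self.embed_tokens = nn.Embedding(vocab, hidden)
        self.lm_head = nn.Linear(hidden, vocab, bias=False)


class ToyDraft(nn.Module):
    def __init__(self, main: ToyMain):
        super().__init__()
        # Simulate the draft loader producing its own (identical) copies,
        # as the current MTP weight loading does.
        self.embed_tokens = nn.Embedding(*main.embed_tokens.weight.shape)
        self.embed_tokens.load_state_dict(main.embed_tokens.state_dict())
        self.head = nn.Linear(main.lm_head.in_features,
                              main.lm_head.out_features, bias=False)
        self.head.load_state_dict(main.lm_head.state_dict())


main = ToyMain()
draft = ToyDraft(main)

# Reuse the main model's modules when the weights are bit-identical.
# Reassigning the attribute drops the draft's only reference to its
# temporary tensor, so it can be garbage-collected and the memory freed.
if torch.equal(draft.embed_tokens.weight, main.embed_tokens.weight):
    draft.embed_tokens = main.embed_tokens
if torch.equal(draft.head.weight, main.lm_head.weight):
    draft.head = main.lm_head

# Both modules now point at the same underlying storage.
assert draft.embed_tokens.weight.data_ptr() == main.embed_tokens.weight.data_ptr()
assert draft.head.weight.data_ptr() == main.lm_head.weight.data_ptr()
```

The `torch.equal` guard is what keeps the optimization safe: if a checkpoint ever shipped an MTP head that genuinely differs from the main model's, the comparison fails and the draft keeps its own weights.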