[BugFix][main] Adapted Qwen3-Next-MTP to chunked prefill (#4770)

### What this PR does / why we need it?
The `-1` padding change comes from
https://github.com/vllm-project/vllm/pull/25743.

Known issue: batched chunked prefill still has bugs.

- vLLM version: v0.12.0
- vLLM main: ad32e3e19c

Signed-off-by: drslark <slarksblood@qq.com>
Co-authored-by: Mengqing Cao <cmq0113@163.com>
This commit is contained in:
drslark
2025-12-10 22:54:24 +08:00
committed by GitHub
parent 490ddf536f
commit 0fb1dc43a1
8 changed files with 646 additions and 28 deletions

View File

@@ -45,7 +45,9 @@ _MTP_MODELS = {
"DeepseekV3ForCausalLM":
("vllm.model_executor.models.deepseek_mtp", "DeepSeekMTP"),
"DeepseekV32ForCausalLM":
("vllm.model_executor.models.deepseek_mtp", "DeepSeekMTP")
("vllm.model_executor.models.deepseek_mtp", "DeepSeekMTP"),
"Qwen3NextForCausalLM":
("vllm.model_executor.models.qwen3_next_mtp", "Qwen3NextMTP")
}
_DEFAULT_FIRST_LAYER = 'model.layers.0.self_attn.attn'
@@ -200,15 +202,16 @@ class MtpProposer(Proposer):
process_weights_after_loading(self.model, draft_model_config,
target_device)
# check if mtp model use main model's embedding and LMhead
main_model = model
if torch.equal(self.model.model.embed_tokens.weight,
main_model.model.embed_tokens.weight):
self.model.model.embed_tokens = main_model.model.embed_tokens
for _, layer_module in self.model.model.layers.items():
if torch.equal(layer_module.shared_head.head.weight,
main_model.lm_head.weight):
layer_module.shared_head.head = main_model.lm_head
if self.vllm_config.model_config.is_deepseek_mla:
# check if mtp model use main model's embedding and LMhead
main_model = model
if torch.equal(self.model.model.embed_tokens.weight,
main_model.model.embed_tokens.weight):
self.model.model.embed_tokens = main_model.model.embed_tokens
for _, layer_module in self.model.model.layers.items():
if torch.equal(layer_module.shared_head.head.weight,
main_model.lm_head.weight):
layer_module.shared_head.head = main_model.lm_head
if self.vllm_config.compilation_config.cudagraph_mode.has_full_cudagraphs(
):