[BugFix][main] Adapted Qwen3-Next-MTP to chunked prefill (#4770)
### What this PR does / why we need it?
The pad `-1` modification is from
https://github.com/vllm-project/vllm/pull/25743.
It still has bugs for batched chunked prefill.
- vLLM version: v0.12.0
- vLLM main:
ad32e3e19c
Signed-off-by: drslark <slarksblood@qq.com>
Co-authored-by: Mengqing Cao <cmq0113@163.com>
This commit is contained in:
@@ -45,7 +45,9 @@ _MTP_MODELS = {
|
||||
"DeepseekV3ForCausalLM":
|
||||
("vllm.model_executor.models.deepseek_mtp", "DeepSeekMTP"),
|
||||
"DeepseekV32ForCausalLM":
|
||||
("vllm.model_executor.models.deepseek_mtp", "DeepSeekMTP")
|
||||
("vllm.model_executor.models.deepseek_mtp", "DeepSeekMTP"),
|
||||
"Qwen3NextForCausalLM":
|
||||
("vllm.model_executor.models.qwen3_next_mtp", "Qwen3NextMTP")
|
||||
}
|
||||
|
||||
_DEFAULT_FIRST_LAYER = 'model.layers.0.self_attn.attn'
|
||||
@@ -200,15 +202,16 @@ class MtpProposer(Proposer):
|
||||
process_weights_after_loading(self.model, draft_model_config,
|
||||
target_device)
|
||||
|
||||
# check if mtp model use main model's embedding and LMhead
|
||||
main_model = model
|
||||
if torch.equal(self.model.model.embed_tokens.weight,
|
||||
main_model.model.embed_tokens.weight):
|
||||
self.model.model.embed_tokens = main_model.model.embed_tokens
|
||||
for _, layer_module in self.model.model.layers.items():
|
||||
if torch.equal(layer_module.shared_head.head.weight,
|
||||
main_model.lm_head.weight):
|
||||
layer_module.shared_head.head = main_model.lm_head
|
||||
if self.vllm_config.model_config.is_deepseek_mla:
|
||||
# check if mtp model use main model's embedding and LMhead
|
||||
main_model = model
|
||||
if torch.equal(self.model.model.embed_tokens.weight,
|
||||
main_model.model.embed_tokens.weight):
|
||||
self.model.model.embed_tokens = main_model.model.embed_tokens
|
||||
for _, layer_module in self.model.model.layers.items():
|
||||
if torch.equal(layer_module.shared_head.head.weight,
|
||||
main_model.lm_head.weight):
|
||||
layer_module.shared_head.head = main_model.lm_head
|
||||
|
||||
if self.vllm_config.compilation_config.cudagraph_mode.has_full_cudagraphs(
|
||||
):
|
||||
|
||||
Reference in New Issue
Block a user