[Feat] Adapted mtp function to Qwen3-next (#3918)

### What this PR does / why we need it?

Adapts the MTP (multi-token prediction) proposer so that it also supports Qwen3-Next.

- vLLM version: v0.11.0
- vLLM main: 83f478bb19

Signed-off-by: drslark <slarksblood@qq.com>
This commit is contained in:
drslark
2025-11-07 16:39:03 +08:00
committed by GitHub
parent 46ef280105
commit 23b785fdfb
10 changed files with 244 additions and 15 deletions

View File

@@ -1,3 +1,4 @@
import importlib
from typing import Optional
import numpy as np
@@ -12,7 +13,6 @@ from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
from vllm.model_executor.model_loader import get_model_loader
from vllm.model_executor.model_loader.utils import \
process_weights_after_loading
from vllm.model_executor.models.deepseek_mtp import DeepSeekMTP
from vllm.model_executor.models.deepseek_v2 import DeepseekV32IndexerCache
from vllm.model_executor.models.llama_eagle3 import Eagle3LlamaForCausalLM
from vllm.utils import cdiv
@@ -42,6 +42,26 @@ logger = init_logger(__name__)
# NOTE(review): presumably the slot id written for padded positions; its use
# is outside this view -- confirm against the callers.
PADDING_SLOT_ID = -1

# Registry of MTP draft models, keyed by architecture name.  Each value is a
# (module path, class name) pair that `_load_model` resolves with importlib.
_MTP_MODELS = {
    "DeepseekV3ForCausalLM": (
        "vllm.model_executor.models.deepseek_mtp",
        "DeepSeekMTP",
    ),
    "Qwen3NextForCausalLM": (
        "vllm_ascend.models.qwen3_next_mtp",
        "CustomQwen3NextMTP",
    ),
}

# Layer name used to pick one entry out of a per-layer attention-metadata
# dict.  Architectures without an entry in `_FIRST_LAYERS` use the default.
_DEFAULT_FIRST_LAYER = "model.layers.0.self_attn.attn"

# NOTE(review): Qwen3-Next points at layer 3 rather than layer 0, presumably
# because its earlier layers do not register a `self_attn.attn` entry --
# confirm against the model definition.
_FIRST_LAYERS = {
    "Qwen3NextForCausalLM": "model.layers.3.self_attn.attn",
}
def _load_model(architecture: str) -> type:
    """Return the MTP draft-model class registered for ``architecture``.

    The (module path, class name) pair is looked up in ``_MTP_MODELS`` and
    the module is imported on demand, so only the module belonging to the
    requested architecture is loaded.

    Args:
        architecture: Architecture name used as the key into
            ``_MTP_MODELS``, e.g. ``"DeepseekV3ForCausalLM"``.

    Returns:
        The model class (not an instance) named by the registry entry.

    Raises:
        ValueError: If ``architecture`` has no entry in ``_MTP_MODELS``.
    """
    if architecture not in _MTP_MODELS:
        # Name the offending value and the accepted ones so that a
        # misconfiguration can be diagnosed from the message alone.
        supported = ", ".join(sorted(_MTP_MODELS))
        raise ValueError(
            f"Invalid architecture for mtp: {architecture!r}. "
            f"Supported architectures: {supported}.")

    module_name, model_name = _MTP_MODELS[architecture]
    module = importlib.import_module(module_name)
    return getattr(module, model_name)
class MtpProposer(Proposer):
@@ -150,9 +170,7 @@ class MtpProposer(Proposer):
with set_default_torch_dtype(
draft_model_config.dtype), set_current_vllm_config(
self.vllm_config):
self.model = DeepSeekMTP(
vllm_config=self.vllm_config).to(target_device)
self._init_mtp_model()
draft_attn_layer_names = (get_layers_from_vllm_config(
self.vllm_config, AttentionLayerBase).keys() -
target_attn_layer_names)
@@ -228,8 +246,7 @@ class MtpProposer(Proposer):
attn_metadata=None,
aux_hidden_states: torch.Tensor = None):
common_attn_metadata = self.runner.spec_decode_common_attn_metadata
if attn_metadata is not None and isinstance(attn_metadata, dict):
attn_metadata = attn_metadata['model.layers.0.self_attn.attn']
attn_metadata = self._get_attn_metadata(attn_metadata)
if self.speculative_config.disable_padded_drafter_batch:
# When padded-batch is disabled, the sampled_token_ids should be
@@ -311,6 +328,20 @@ class MtpProposer(Proposer):
return draft_token_ids
def _init_mtp_model(self):
    """Build the MTP draft model and store it in ``self.model``.

    The model class is resolved by ``_load_model`` from the architecture
    recorded in ``self.vllm_config.model_config``.  The new instance is
    moved to the device given by ``self.vllm_config.device_config``.
    """
    config = self.vllm_config
    arch_name = config.model_config.architecture
    device = config.device_config.device
    model_cls = _load_model(arch_name)
    self.model = model_cls(vllm_config=config).to(device)
def _get_attn_metadata(self, attn_metadata):
    """Reduce a per-layer attention-metadata dict to a single entry.

    Args:
        attn_metadata: Either a dict keyed by layer name, or any other
            value (including ``None``).

    Returns:
        For a dict, the entry stored under the first-layer name of the
        configured architecture (``_FIRST_LAYERS``, falling back to
        ``_DEFAULT_FIRST_LAYER``).  Any other value is returned unchanged.

    Raises:
        KeyError: If the dict has no entry for the selected layer name.
    """
    # Guard clause: ``None`` is not a dict, so it is passed through here too.
    if not isinstance(attn_metadata, dict):
        return attn_metadata

    arch_name = self.vllm_config.model_config.architecture
    first_layer = _FIRST_LAYERS.get(arch_name, _DEFAULT_FIRST_LAYER)
    return attn_metadata[first_layer]
def _prepare_inputs(
self,
common_attn_metadata: CommonAttentionMetadata,