[Feat] Adapted mtp function to Qwen3-next (#3918)
### What this PR does / why we need it?
Adapts mtp function to Qwen3-next.
- vLLM version: v0.11.0
- vLLM main:
83f478bb19
Signed-off-by: drslark <slarksblood@qq.com>
This commit is contained in:
@@ -1,3 +1,4 @@
|
||||
import importlib
|
||||
from typing import Optional
|
||||
|
||||
import numpy as np
|
||||
@@ -12,7 +13,6 @@ from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
|
||||
from vllm.model_executor.model_loader import get_model_loader
|
||||
from vllm.model_executor.model_loader.utils import \
|
||||
process_weights_after_loading
|
||||
from vllm.model_executor.models.deepseek_mtp import DeepSeekMTP
|
||||
from vllm.model_executor.models.deepseek_v2 import DeepseekV32IndexerCache
|
||||
from vllm.model_executor.models.llama_eagle3 import Eagle3LlamaForCausalLM
|
||||
from vllm.utils import cdiv
|
||||
@@ -42,6 +42,26 @@ logger = init_logger(__name__)
|
||||
|
||||
PADDING_SLOT_ID = -1
|
||||
|
||||
_MTP_MODELS = {
|
||||
"DeepseekV3ForCausalLM":
|
||||
("vllm.model_executor.models.deepseek_mtp", "DeepSeekMTP"),
|
||||
"Qwen3NextForCausalLM":
|
||||
("vllm_ascend.models.qwen3_next_mtp", "CustomQwen3NextMTP")
|
||||
}
|
||||
|
||||
_DEFAULT_FIRST_LAYER = 'model.layers.0.self_attn.attn'
|
||||
|
||||
_FIRST_LAYERS = {"Qwen3NextForCausalLM": 'model.layers.3.self_attn.attn'}
|
||||
|
||||
|
||||
def _load_model(architecture):
|
||||
if architecture not in _MTP_MODELS:
|
||||
raise ValueError("Invalid architecture for mtp.")
|
||||
module_name, model_name = _MTP_MODELS[architecture]
|
||||
module = importlib.import_module(module_name)
|
||||
model = getattr(module, model_name)
|
||||
return model
|
||||
|
||||
|
||||
class MtpProposer(Proposer):
|
||||
|
||||
@@ -150,9 +170,7 @@ class MtpProposer(Proposer):
|
||||
with set_default_torch_dtype(
|
||||
draft_model_config.dtype), set_current_vllm_config(
|
||||
self.vllm_config):
|
||||
self.model = DeepSeekMTP(
|
||||
vllm_config=self.vllm_config).to(target_device)
|
||||
|
||||
self._init_mtp_model()
|
||||
draft_attn_layer_names = (get_layers_from_vllm_config(
|
||||
self.vllm_config, AttentionLayerBase).keys() -
|
||||
target_attn_layer_names)
|
||||
@@ -228,8 +246,7 @@ class MtpProposer(Proposer):
|
||||
attn_metadata=None,
|
||||
aux_hidden_states: torch.Tensor = None):
|
||||
common_attn_metadata = self.runner.spec_decode_common_attn_metadata
|
||||
if attn_metadata is not None and isinstance(attn_metadata, dict):
|
||||
attn_metadata = attn_metadata['model.layers.0.self_attn.attn']
|
||||
attn_metadata = self._get_attn_metadata(attn_metadata)
|
||||
|
||||
if self.speculative_config.disable_padded_drafter_batch:
|
||||
# When padded-batch is disabled, the sampled_token_ids should be
|
||||
@@ -311,6 +328,20 @@ class MtpProposer(Proposer):
|
||||
|
||||
return draft_token_ids
|
||||
|
||||
def _init_mtp_model(self):
|
||||
architecture = self.vllm_config.model_config.architecture
|
||||
target_device = self.vllm_config.device_config.device
|
||||
model = _load_model(architecture)
|
||||
self.model = model(vllm_config=self.vllm_config).to(target_device)
|
||||
|
||||
def _get_attn_metadata(self, attn_metadata):
|
||||
if attn_metadata is not None and isinstance(attn_metadata, dict):
|
||||
architecture = self.vllm_config.model_config.architecture
|
||||
layer_name = _FIRST_LAYERS.get(architecture, _DEFAULT_FIRST_LAYER)
|
||||
attn_metadata = attn_metadata[layer_name]
|
||||
|
||||
return attn_metadata
|
||||
|
||||
def _prepare_inputs(
|
||||
self,
|
||||
common_attn_metadata: CommonAttentionMetadata,
|
||||
|
||||
Reference in New Issue
Block a user