[Feat] Adapted mtp function to Qwen3-next (#3918)

### What this PR does / why we need it?

Adapts mtp function to Qwen3-next.

- vLLM version: v0.11.0
- vLLM main:
83f478bb19

Signed-off-by: drslark <slarksblood@qq.com>
This commit is contained in:
drslark
2025-11-07 16:39:03 +08:00
committed by GitHub
parent 46ef280105
commit 23b785fdfb
10 changed files with 244 additions and 15 deletions

View File

@@ -20,10 +20,17 @@
Run `pytest tests/e2e/multicard/test_qwen3_next.py`.
"""
import os
from unittest.mock import patch
from tests.e2e.conftest import VllmRunner
# NZ will cause precision error in Qwen3-Next
# When it is fixed, this set-up can be removed
_IS_ENABLE_NZ = "VLLM_ASCEND_ENABLE_NZ"
@patch.dict(os.environ, {_IS_ENABLE_NZ: "0"})
def test_models_distributed_Qwen3_NEXT_TP4():
example_prompts = [
"Hello, my name is",
@@ -36,8 +43,10 @@ def test_models_distributed_Qwen3_NEXT_TP4():
distributed_executor_backend="mp",
enforce_eager=True) as vllm_model:
vllm_model.generate_greedy(example_prompts, max_tokens)
del vllm_model
@patch.dict(os.environ, {_IS_ENABLE_NZ: "0"})
def test_models_distributed_Qwen3_NEXT_TP4_FULL_DECODE_ONLY():
example_prompts = [
"Hello, my name is",
@@ -54,3 +63,50 @@ def test_models_distributed_Qwen3_NEXT_TP4_FULL_DECODE_ONLY():
"cudagraph_capture_sizes": [1, 8, 24, 48, 60]
}) as vllm_model:
vllm_model.generate_greedy(example_prompts, max_tokens)
del vllm_model
@patch.dict(os.environ, {_IS_ENABLE_NZ: "0"})
def test_models_distributed_Qwen3_NEXT_MTP_TP4_SIMILARITY():
example_prompts = [
"Hello, my name is",
"The president of the United States is",
"The capital of France is",
"The future of AI is",
]
max_tokens = 20
with VllmRunner("Qwen/Qwen3-Next-80B-A3B-Instruct",
tensor_parallel_size=4,
max_model_len=4096,
gpu_memory_utilization=0.8,
distributed_executor_backend="mp") as vllm_model:
ref_outputs = vllm_model.generate_greedy(example_prompts, max_tokens)
del vllm_model
with VllmRunner("Qwen/Qwen3-Next-80B-A3B-Instruct",
tensor_parallel_size=4,
max_model_len=4096,
gpu_memory_utilization=0.8,
distributed_executor_backend="mp",
speculative_config={
"method": "qwen3_next_mtp",
"num_speculative_tokens": 1
}) as spec_vllm_model:
spec_outputs = spec_vllm_model.generate_greedy(example_prompts,
max_tokens)
del spec_vllm_model
matches = 0
misses = 0
for ref_output, spec_output in zip(ref_outputs, spec_outputs):
ref_token_ids = ref_output[0]
spec_token_ids = spec_output[0]
if ref_token_ids == spec_token_ids[:len(ref_token_ids)]:
matches += 1
else:
misses += 1
print(f"ref_output: {ref_output[1]}")
print(f"spec_output: {spec_output[1]}")
assert matches > int(0.66 * len(ref_outputs))

View File

@@ -77,6 +77,7 @@ class TestAscendAttentionMetadataBuilder(TestBase):
mock_get_dcp_group.return_value = dcp_group
self.mock_vllm_config = MagicMock()
self.mock_vllm_config.speculative_config = None
self.mock_vllm_config.model_config.max_model_len = 640
self.mock_vllm_config.cache_config.block_size = 64
self.mock_vllm_config.compilation_config.cudagraph_mode = None

View File

@@ -252,6 +252,17 @@ class AscendAttentionMetadataBuilder:
self.dcp_rank = get_decode_context_model_parallel_rank(
) if self.dcp_size > 1 else 0
self.speculative_config = vllm_config.speculative_config
self.decode_threshold = 1
if self.speculative_config:
spec_token_num = self.speculative_config.num_speculative_tokens
self.decode_threshold += spec_token_num
assert self.decode_threshold <= 16, f"decode_threshold exceeded \
npu_fused_infer_attention_score TND layout's limit of 16, \
got {self.decode_threshold}"
AscendAttentionMetadataBuilder.reorder_batch_threshold = self.decode_threshold
def reorder_batch(self, input_batch,
scheduler_output: "SchedulerOutput") -> bool:
return False

View File

@@ -35,6 +35,10 @@ def register_model():
"PanguProMoEForCausalLM",
"vllm_ascend.torchair.models.torchair_pangu_moe:PanguProMoEForCausalLM"
)
ModelRegistry.register_model(
"Qwen3NextForCausalLM",
"vllm_ascend.models.qwen3_next:CustomQwen3NextForCausalLM")
ModelRegistry.register_model(
"Qwen3NextMTP", "vllm_ascend.models.qwen3_next_mtp:CustomQwen3NextMTP")

View File

@@ -260,6 +260,24 @@ class CustomQwen3NextGatedDeltaNet(Qwen3NextGatedDeltaNet, MambaBase):
mixed_qkv_spec = None
mixed_qkv_non_spec = mixed_qkv
# 2.1: process the mutli-query part
if spec_sequence_masks is not None:
mixed_qkv_spec = mixed_qkv_spec.view(
attn_metadata.num_spec_decodes, -1, mixed_qkv_spec.size(-1))
mixed_qkv_spec = rearrange(mixed_qkv_spec, 'b l d -> b d l')
mixed_qkv_spec = causal_conv1d_update(
mixed_qkv_spec,
conv_state,
conv_weights,
self.conv1d.bias,
self.activation,
conv_state_indices=spec_state_indices_tensor[:, 0]
[:attn_metadata.num_spec_decodes],
num_accepted_tokens=num_accepted_tokens,
validate_data=False,
)
mixed_qkv_spec = rearrange(mixed_qkv_spec, 'b d l -> (b l) d')
# 2.2: process the remaining part
if attn_metadata.num_prefills > 0:
# - "cache_indices" updates the conv_state cache in positions

View File

@@ -0,0 +1,109 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Inference-only Qwen3Next MTP model."""
import torch
from vllm.compilation.decorators import support_torch_compile
from vllm.config import VllmConfig
from vllm.model_executor.layers.linear import ColumnParallelLinear
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.vocab_parallel_embedding import (
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.models.interfaces import SupportsPP
from vllm.model_executor.models.qwen3_next_mtp import (
Qwen3NextMTP, Qwen3NextMultiTokenPredictor)
from vllm.model_executor.models.utils import (
make_empty_intermediate_tensors_factory, maybe_prefix)
from vllm.transformers_utils.configs import Qwen3NextConfig
from vllm_ascend.models.qwen3_next import (CustomQwen3NextDecoderLayer,
Qwen3NextRMSNorm)
@support_torch_compile
class CustomQwen3NextMultiTokenPredictor(Qwen3NextMultiTokenPredictor):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super(Qwen3NextMultiTokenPredictor, self).__init__()
model_config = vllm_config.model_config
quant_config = vllm_config.quant_config
lora_config = vllm_config.lora_config
config: Qwen3NextConfig = model_config.hf_config
self.config = config
lora_vocab = ((lora_config.lora_extra_vocab_size *
(lora_config.max_loras or 1)) if lora_config else 0)
self.vocab_size = config.vocab_size + lora_vocab
self.org_vocab_size = config.vocab_size
self.mtp_start_layer_idx = config.num_hidden_layers
self.num_mtp_layers = getattr(config, "num_nextn_predict_layers", 1)
self.embed_tokens = VocabParallelEmbedding(
self.vocab_size,
config.hidden_size,
org_num_embeddings=config.vocab_size,
)
self.fc = ColumnParallelLinear(self.config.hidden_size * 2,
self.config.hidden_size,
gather_output=True,
bias=False,
return_bias=False,
quant_config=quant_config,
prefix=f'{prefix}.fc')
# use old version mtp layer name to avoid a exception in vllm
self.layers = torch.nn.ModuleList(
CustomQwen3NextDecoderLayer(
vllm_config,
layer_type="full_attention",
prefix=f'{prefix}.layers.{self.mtp_start_layer_idx + idx}',
) for idx in range(self.num_mtp_layers))
self.make_empty_intermediate_tensors = (
make_empty_intermediate_tensors_factory(
["hidden_states", "residual"], config.hidden_size))
self.norm = Qwen3NextRMSNorm(config.hidden_size,
eps=config.rms_norm_eps)
self.pre_fc_norm_hidden = Qwen3NextRMSNorm(config.hidden_size,
eps=config.rms_norm_eps)
self.pre_fc_norm_embedding = Qwen3NextRMSNorm(config.hidden_size,
eps=config.rms_norm_eps)
@support_torch_compile
class CustomQwen3NextMTP(Qwen3NextMTP, SupportsPP):
packed_modules_mapping = {
"qkv_proj": [
"q_proj",
"k_proj",
"v_proj",
],
"gate_up_proj": ["up_proj", "down_proj"]
}
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
config = vllm_config.model_config.hf_config
self.vllm_config = vllm_config
cache_config = vllm_config.cache_config
assert not cache_config.enable_prefix_caching, \
"Qwen3NextMTP currently does not support prefix caching"
self.quant_config = vllm_config.quant_config
super(Qwen3NextMTP, self).__init__()
self.config = config
self.model = CustomQwen3NextMultiTokenPredictor(
vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model"))
self.unpadded_vocab_size = config.vocab_size
self.lm_head = ParallelLMHead(self.unpadded_vocab_size,
config.hidden_size,
org_num_embeddings=config.vocab_size,
padding_size=DEFAULT_VOCAB_PADDING_SIZE,
prefix=maybe_prefix(prefix, "lm_head"))
self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
config.vocab_size)
self.make_empty_intermediate_tensors = (
self.model.make_empty_intermediate_tensors)

View File

@@ -55,7 +55,7 @@ def causal_conv1d_ref(
final_states = F.pad(x, (width - 1 - x.shape[-1], 0)).to(
dtype_in) # (batch, dim, width - 1)
if final_states_out is not None:
final_states_out.copy_(final_states)
final_states_out[..., :(width - 1)].copy_(final_states)
else:
final_states_out = final_states
out = (out if activation is None else F.silu(out)).to(dtype=dtype_in)

View File

@@ -29,9 +29,9 @@ def get_spec_decode_method(method,
is_torchair_graph=False):
if method == "ngram":
return NgramProposer(vllm_config, device, runner)
elif method in ["eagle", "eagle3"]:
elif method in ("eagle", "eagle3"):
return EagleProposer(vllm_config, device, runner)
elif method == 'deepseek_mtp':
elif method in ('deepseek_mtp', 'qwen3_next_mtp'):
if is_torchair_graph:
return TorchairMtpProposer(vllm_config, device, runner)
return MtpProposer(vllm_config, device, runner)

View File

@@ -1,3 +1,4 @@
import importlib
from typing import Optional
import numpy as np
@@ -12,7 +13,6 @@ from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
from vllm.model_executor.model_loader import get_model_loader
from vllm.model_executor.model_loader.utils import \
process_weights_after_loading
from vllm.model_executor.models.deepseek_mtp import DeepSeekMTP
from vllm.model_executor.models.deepseek_v2 import DeepseekV32IndexerCache
from vllm.model_executor.models.llama_eagle3 import Eagle3LlamaForCausalLM
from vllm.utils import cdiv
@@ -42,6 +42,26 @@ logger = init_logger(__name__)
PADDING_SLOT_ID = -1
_MTP_MODELS = {
"DeepseekV3ForCausalLM":
("vllm.model_executor.models.deepseek_mtp", "DeepSeekMTP"),
"Qwen3NextForCausalLM":
("vllm_ascend.models.qwen3_next_mtp", "CustomQwen3NextMTP")
}
_DEFAULT_FIRST_LAYER = 'model.layers.0.self_attn.attn'
_FIRST_LAYERS = {"Qwen3NextForCausalLM": 'model.layers.3.self_attn.attn'}
def _load_model(architecture):
if architecture not in _MTP_MODELS:
raise ValueError("Invalid architecture for mtp.")
module_name, model_name = _MTP_MODELS[architecture]
module = importlib.import_module(module_name)
model = getattr(module, model_name)
return model
class MtpProposer(Proposer):
@@ -150,9 +170,7 @@ class MtpProposer(Proposer):
with set_default_torch_dtype(
draft_model_config.dtype), set_current_vllm_config(
self.vllm_config):
self.model = DeepSeekMTP(
vllm_config=self.vllm_config).to(target_device)
self._init_mtp_model()
draft_attn_layer_names = (get_layers_from_vllm_config(
self.vllm_config, AttentionLayerBase).keys() -
target_attn_layer_names)
@@ -228,8 +246,7 @@ class MtpProposer(Proposer):
attn_metadata=None,
aux_hidden_states: torch.Tensor = None):
common_attn_metadata = self.runner.spec_decode_common_attn_metadata
if attn_metadata is not None and isinstance(attn_metadata, dict):
attn_metadata = attn_metadata['model.layers.0.self_attn.attn']
attn_metadata = self._get_attn_metadata(attn_metadata)
if self.speculative_config.disable_padded_drafter_batch:
# When padded-batch is disabled, the sampled_token_ids should be
@@ -311,6 +328,20 @@ class MtpProposer(Proposer):
return draft_token_ids
def _init_mtp_model(self):
architecture = self.vllm_config.model_config.architecture
target_device = self.vllm_config.device_config.device
model = _load_model(architecture)
self.model = model(vllm_config=self.vllm_config).to(target_device)
def _get_attn_metadata(self, attn_metadata):
if attn_metadata is not None and isinstance(attn_metadata, dict):
architecture = self.vllm_config.model_config.architecture
layer_name = _FIRST_LAYERS.get(architecture, _DEFAULT_FIRST_LAYER)
attn_metadata = attn_metadata[layer_name]
return attn_metadata
def _prepare_inputs(
self,
common_attn_metadata: CommonAttentionMetadata,

View File

@@ -1852,7 +1852,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
extra_attn_metadata_args = dict(
num_accepted_tokens=self.num_accepted_tokens.
gpu[:num_reqs],
num_draft_tokens=self.num_draft_tokens.
num_decode_draft_tokens_cpu=self.num_draft_tokens.
gpu[:num_reqs],
)
attn_metadata_i = builder.build(
@@ -1948,11 +1948,10 @@ class NPUModelRunner(LoRAModelRunnerMixin):
attn_state = AscendAttentionState.SpecDecoding
# Speculative decoding.
elif np.all(num_valid_tokens == 1):
if self.drafter and (self.drafter.name == SpecDcodeType.EAGLE
or self.drafter.name == SpecDcodeType.EAGLE3):
attn_state = AscendAttentionState.ChunkedPrefill
else:
if self.speculative_config and self.speculative_config.method == 'deepseek_mtp':
attn_state = AscendAttentionState.SpecDecoding
else:
attn_state = AscendAttentionState.ChunkedPrefill
# splitfuse
elif not ascend_config.ascend_scheduler_config.enabled or self.chunked_prefill_enabled:
attn_state = AscendAttentionState.ChunkedPrefill
@@ -2548,7 +2547,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
with ProfileExecuteDuration().capture_async("Draft"):
if self.speculative_config:
use_padded_batch_for_eagle = self.speculative_config and \
self.speculative_config.method == "deepseek_mtp" and \
self.speculative_config.method in ("deepseek_mtp", "qwen3_next_mtp") and \
not self.speculative_config.disable_padded_drafter_batch
if use_padded_batch_for_eagle:
# EAGLE speculative decoding can use the GPU sampled tokens