[Feat] Adapted mtp function to Qwen3-next (#3918)
### What this PR does / why we need it?
Adapts mtp function to Qwen3-next.
- vLLM version: v0.11.0
- vLLM main:
83f478bb19
Signed-off-by: drslark <slarksblood@qq.com>
This commit is contained in:
@@ -20,10 +20,17 @@
|
||||
|
||||
Run `pytest tests/e2e/multicard/test_qwen3_next.py`.
|
||||
"""
|
||||
import os
|
||||
from unittest.mock import patch
|
||||
|
||||
from tests.e2e.conftest import VllmRunner
|
||||
|
||||
# NZ will cause precision error in Qwen3-Next
|
||||
# When it is fixed, this set-up can be removed
|
||||
_IS_ENABLE_NZ = "VLLM_ASCEND_ENABLE_NZ"
|
||||
|
||||
|
||||
@patch.dict(os.environ, {_IS_ENABLE_NZ: "0"})
|
||||
def test_models_distributed_Qwen3_NEXT_TP4():
|
||||
example_prompts = [
|
||||
"Hello, my name is",
|
||||
@@ -36,8 +43,10 @@ def test_models_distributed_Qwen3_NEXT_TP4():
|
||||
distributed_executor_backend="mp",
|
||||
enforce_eager=True) as vllm_model:
|
||||
vllm_model.generate_greedy(example_prompts, max_tokens)
|
||||
del vllm_model
|
||||
|
||||
|
||||
@patch.dict(os.environ, {_IS_ENABLE_NZ: "0"})
|
||||
def test_models_distributed_Qwen3_NEXT_TP4_FULL_DECODE_ONLY():
|
||||
example_prompts = [
|
||||
"Hello, my name is",
|
||||
@@ -54,3 +63,50 @@ def test_models_distributed_Qwen3_NEXT_TP4_FULL_DECODE_ONLY():
|
||||
"cudagraph_capture_sizes": [1, 8, 24, 48, 60]
|
||||
}) as vllm_model:
|
||||
vllm_model.generate_greedy(example_prompts, max_tokens)
|
||||
del vllm_model
|
||||
|
||||
|
||||
@patch.dict(os.environ, {_IS_ENABLE_NZ: "0"})
|
||||
def test_models_distributed_Qwen3_NEXT_MTP_TP4_SIMILARITY():
|
||||
example_prompts = [
|
||||
"Hello, my name is",
|
||||
"The president of the United States is",
|
||||
"The capital of France is",
|
||||
"The future of AI is",
|
||||
]
|
||||
max_tokens = 20
|
||||
|
||||
with VllmRunner("Qwen/Qwen3-Next-80B-A3B-Instruct",
|
||||
tensor_parallel_size=4,
|
||||
max_model_len=4096,
|
||||
gpu_memory_utilization=0.8,
|
||||
distributed_executor_backend="mp") as vllm_model:
|
||||
ref_outputs = vllm_model.generate_greedy(example_prompts, max_tokens)
|
||||
del vllm_model
|
||||
|
||||
with VllmRunner("Qwen/Qwen3-Next-80B-A3B-Instruct",
|
||||
tensor_parallel_size=4,
|
||||
max_model_len=4096,
|
||||
gpu_memory_utilization=0.8,
|
||||
distributed_executor_backend="mp",
|
||||
speculative_config={
|
||||
"method": "qwen3_next_mtp",
|
||||
"num_speculative_tokens": 1
|
||||
}) as spec_vllm_model:
|
||||
spec_outputs = spec_vllm_model.generate_greedy(example_prompts,
|
||||
max_tokens)
|
||||
del spec_vllm_model
|
||||
|
||||
matches = 0
|
||||
misses = 0
|
||||
for ref_output, spec_output in zip(ref_outputs, spec_outputs):
|
||||
ref_token_ids = ref_output[0]
|
||||
spec_token_ids = spec_output[0]
|
||||
if ref_token_ids == spec_token_ids[:len(ref_token_ids)]:
|
||||
matches += 1
|
||||
else:
|
||||
misses += 1
|
||||
print(f"ref_output: {ref_output[1]}")
|
||||
print(f"spec_output: {spec_output[1]}")
|
||||
|
||||
assert matches > int(0.66 * len(ref_outputs))
|
||||
|
||||
@@ -77,6 +77,7 @@ class TestAscendAttentionMetadataBuilder(TestBase):
|
||||
mock_get_dcp_group.return_value = dcp_group
|
||||
|
||||
self.mock_vllm_config = MagicMock()
|
||||
self.mock_vllm_config.speculative_config = None
|
||||
self.mock_vllm_config.model_config.max_model_len = 640
|
||||
self.mock_vllm_config.cache_config.block_size = 64
|
||||
self.mock_vllm_config.compilation_config.cudagraph_mode = None
|
||||
|
||||
@@ -252,6 +252,17 @@ class AscendAttentionMetadataBuilder:
|
||||
self.dcp_rank = get_decode_context_model_parallel_rank(
|
||||
) if self.dcp_size > 1 else 0
|
||||
|
||||
self.speculative_config = vllm_config.speculative_config
|
||||
self.decode_threshold = 1
|
||||
if self.speculative_config:
|
||||
spec_token_num = self.speculative_config.num_speculative_tokens
|
||||
self.decode_threshold += spec_token_num
|
||||
assert self.decode_threshold <= 16, f"decode_threshold exceeded \
|
||||
npu_fused_infer_attention_score TND layout's limit of 16, \
|
||||
got {self.decode_threshold}"
|
||||
|
||||
AscendAttentionMetadataBuilder.reorder_batch_threshold = self.decode_threshold
|
||||
|
||||
def reorder_batch(self, input_batch,
|
||||
scheduler_output: "SchedulerOutput") -> bool:
|
||||
return False
|
||||
|
||||
@@ -35,6 +35,10 @@ def register_model():
|
||||
"PanguProMoEForCausalLM",
|
||||
"vllm_ascend.torchair.models.torchair_pangu_moe:PanguProMoEForCausalLM"
|
||||
)
|
||||
|
||||
ModelRegistry.register_model(
|
||||
"Qwen3NextForCausalLM",
|
||||
"vllm_ascend.models.qwen3_next:CustomQwen3NextForCausalLM")
|
||||
|
||||
ModelRegistry.register_model(
|
||||
"Qwen3NextMTP", "vllm_ascend.models.qwen3_next_mtp:CustomQwen3NextMTP")
|
||||
|
||||
@@ -260,6 +260,24 @@ class CustomQwen3NextGatedDeltaNet(Qwen3NextGatedDeltaNet, MambaBase):
|
||||
mixed_qkv_spec = None
|
||||
mixed_qkv_non_spec = mixed_qkv
|
||||
|
||||
# 2.1: process the mutli-query part
|
||||
if spec_sequence_masks is not None:
|
||||
mixed_qkv_spec = mixed_qkv_spec.view(
|
||||
attn_metadata.num_spec_decodes, -1, mixed_qkv_spec.size(-1))
|
||||
mixed_qkv_spec = rearrange(mixed_qkv_spec, 'b l d -> b d l')
|
||||
mixed_qkv_spec = causal_conv1d_update(
|
||||
mixed_qkv_spec,
|
||||
conv_state,
|
||||
conv_weights,
|
||||
self.conv1d.bias,
|
||||
self.activation,
|
||||
conv_state_indices=spec_state_indices_tensor[:, 0]
|
||||
[:attn_metadata.num_spec_decodes],
|
||||
num_accepted_tokens=num_accepted_tokens,
|
||||
validate_data=False,
|
||||
)
|
||||
mixed_qkv_spec = rearrange(mixed_qkv_spec, 'b d l -> (b l) d')
|
||||
|
||||
# 2.2: process the remaining part
|
||||
if attn_metadata.num_prefills > 0:
|
||||
# - "cache_indices" updates the conv_state cache in positions
|
||||
|
||||
109
vllm_ascend/models/qwen3_next_mtp.py
Normal file
109
vllm_ascend/models/qwen3_next_mtp.py
Normal file
@@ -0,0 +1,109 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""Inference-only Qwen3Next MTP model."""
|
||||
import torch
|
||||
from vllm.compilation.decorators import support_torch_compile
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.model_executor.layers.linear import ColumnParallelLinear
|
||||
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
|
||||
from vllm.model_executor.models.interfaces import SupportsPP
|
||||
from vllm.model_executor.models.qwen3_next_mtp import (
|
||||
Qwen3NextMTP, Qwen3NextMultiTokenPredictor)
|
||||
from vllm.model_executor.models.utils import (
|
||||
make_empty_intermediate_tensors_factory, maybe_prefix)
|
||||
from vllm.transformers_utils.configs import Qwen3NextConfig
|
||||
|
||||
from vllm_ascend.models.qwen3_next import (CustomQwen3NextDecoderLayer,
|
||||
Qwen3NextRMSNorm)
|
||||
|
||||
|
||||
@support_torch_compile
|
||||
class CustomQwen3NextMultiTokenPredictor(Qwen3NextMultiTokenPredictor):
|
||||
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
super(Qwen3NextMultiTokenPredictor, self).__init__()
|
||||
|
||||
model_config = vllm_config.model_config
|
||||
quant_config = vllm_config.quant_config
|
||||
lora_config = vllm_config.lora_config
|
||||
config: Qwen3NextConfig = model_config.hf_config
|
||||
|
||||
self.config = config
|
||||
lora_vocab = ((lora_config.lora_extra_vocab_size *
|
||||
(lora_config.max_loras or 1)) if lora_config else 0)
|
||||
self.vocab_size = config.vocab_size + lora_vocab
|
||||
self.org_vocab_size = config.vocab_size
|
||||
|
||||
self.mtp_start_layer_idx = config.num_hidden_layers
|
||||
self.num_mtp_layers = getattr(config, "num_nextn_predict_layers", 1)
|
||||
|
||||
self.embed_tokens = VocabParallelEmbedding(
|
||||
self.vocab_size,
|
||||
config.hidden_size,
|
||||
org_num_embeddings=config.vocab_size,
|
||||
)
|
||||
|
||||
self.fc = ColumnParallelLinear(self.config.hidden_size * 2,
|
||||
self.config.hidden_size,
|
||||
gather_output=True,
|
||||
bias=False,
|
||||
return_bias=False,
|
||||
quant_config=quant_config,
|
||||
prefix=f'{prefix}.fc')
|
||||
|
||||
# use old version mtp layer name to avoid a exception in vllm
|
||||
self.layers = torch.nn.ModuleList(
|
||||
CustomQwen3NextDecoderLayer(
|
||||
vllm_config,
|
||||
layer_type="full_attention",
|
||||
prefix=f'{prefix}.layers.{self.mtp_start_layer_idx + idx}',
|
||||
) for idx in range(self.num_mtp_layers))
|
||||
|
||||
self.make_empty_intermediate_tensors = (
|
||||
make_empty_intermediate_tensors_factory(
|
||||
["hidden_states", "residual"], config.hidden_size))
|
||||
|
||||
self.norm = Qwen3NextRMSNorm(config.hidden_size,
|
||||
eps=config.rms_norm_eps)
|
||||
self.pre_fc_norm_hidden = Qwen3NextRMSNorm(config.hidden_size,
|
||||
eps=config.rms_norm_eps)
|
||||
self.pre_fc_norm_embedding = Qwen3NextRMSNorm(config.hidden_size,
|
||||
eps=config.rms_norm_eps)
|
||||
|
||||
|
||||
@support_torch_compile
|
||||
class CustomQwen3NextMTP(Qwen3NextMTP, SupportsPP):
|
||||
packed_modules_mapping = {
|
||||
"qkv_proj": [
|
||||
"q_proj",
|
||||
"k_proj",
|
||||
"v_proj",
|
||||
],
|
||||
"gate_up_proj": ["up_proj", "down_proj"]
|
||||
}
|
||||
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
config = vllm_config.model_config.hf_config
|
||||
self.vllm_config = vllm_config
|
||||
cache_config = vllm_config.cache_config
|
||||
assert not cache_config.enable_prefix_caching, \
|
||||
"Qwen3NextMTP currently does not support prefix caching"
|
||||
|
||||
self.quant_config = vllm_config.quant_config
|
||||
|
||||
super(Qwen3NextMTP, self).__init__()
|
||||
self.config = config
|
||||
self.model = CustomQwen3NextMultiTokenPredictor(
|
||||
vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model"))
|
||||
self.unpadded_vocab_size = config.vocab_size
|
||||
self.lm_head = ParallelLMHead(self.unpadded_vocab_size,
|
||||
config.hidden_size,
|
||||
org_num_embeddings=config.vocab_size,
|
||||
padding_size=DEFAULT_VOCAB_PADDING_SIZE,
|
||||
prefix=maybe_prefix(prefix, "lm_head"))
|
||||
self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
|
||||
config.vocab_size)
|
||||
self.make_empty_intermediate_tensors = (
|
||||
self.model.make_empty_intermediate_tensors)
|
||||
@@ -55,7 +55,7 @@ def causal_conv1d_ref(
|
||||
final_states = F.pad(x, (width - 1 - x.shape[-1], 0)).to(
|
||||
dtype_in) # (batch, dim, width - 1)
|
||||
if final_states_out is not None:
|
||||
final_states_out.copy_(final_states)
|
||||
final_states_out[..., :(width - 1)].copy_(final_states)
|
||||
else:
|
||||
final_states_out = final_states
|
||||
out = (out if activation is None else F.silu(out)).to(dtype=dtype_in)
|
||||
|
||||
@@ -29,9 +29,9 @@ def get_spec_decode_method(method,
|
||||
is_torchair_graph=False):
|
||||
if method == "ngram":
|
||||
return NgramProposer(vllm_config, device, runner)
|
||||
elif method in ["eagle", "eagle3"]:
|
||||
elif method in ("eagle", "eagle3"):
|
||||
return EagleProposer(vllm_config, device, runner)
|
||||
elif method == 'deepseek_mtp':
|
||||
elif method in ('deepseek_mtp', 'qwen3_next_mtp'):
|
||||
if is_torchair_graph:
|
||||
return TorchairMtpProposer(vllm_config, device, runner)
|
||||
return MtpProposer(vllm_config, device, runner)
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import importlib
|
||||
from typing import Optional
|
||||
|
||||
import numpy as np
|
||||
@@ -12,7 +13,6 @@ from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
|
||||
from vllm.model_executor.model_loader import get_model_loader
|
||||
from vllm.model_executor.model_loader.utils import \
|
||||
process_weights_after_loading
|
||||
from vllm.model_executor.models.deepseek_mtp import DeepSeekMTP
|
||||
from vllm.model_executor.models.deepseek_v2 import DeepseekV32IndexerCache
|
||||
from vllm.model_executor.models.llama_eagle3 import Eagle3LlamaForCausalLM
|
||||
from vllm.utils import cdiv
|
||||
@@ -42,6 +42,26 @@ logger = init_logger(__name__)
|
||||
|
||||
PADDING_SLOT_ID = -1
|
||||
|
||||
_MTP_MODELS = {
|
||||
"DeepseekV3ForCausalLM":
|
||||
("vllm.model_executor.models.deepseek_mtp", "DeepSeekMTP"),
|
||||
"Qwen3NextForCausalLM":
|
||||
("vllm_ascend.models.qwen3_next_mtp", "CustomQwen3NextMTP")
|
||||
}
|
||||
|
||||
_DEFAULT_FIRST_LAYER = 'model.layers.0.self_attn.attn'
|
||||
|
||||
_FIRST_LAYERS = {"Qwen3NextForCausalLM": 'model.layers.3.self_attn.attn'}
|
||||
|
||||
|
||||
def _load_model(architecture):
|
||||
if architecture not in _MTP_MODELS:
|
||||
raise ValueError("Invalid architecture for mtp.")
|
||||
module_name, model_name = _MTP_MODELS[architecture]
|
||||
module = importlib.import_module(module_name)
|
||||
model = getattr(module, model_name)
|
||||
return model
|
||||
|
||||
|
||||
class MtpProposer(Proposer):
|
||||
|
||||
@@ -150,9 +170,7 @@ class MtpProposer(Proposer):
|
||||
with set_default_torch_dtype(
|
||||
draft_model_config.dtype), set_current_vllm_config(
|
||||
self.vllm_config):
|
||||
self.model = DeepSeekMTP(
|
||||
vllm_config=self.vllm_config).to(target_device)
|
||||
|
||||
self._init_mtp_model()
|
||||
draft_attn_layer_names = (get_layers_from_vllm_config(
|
||||
self.vllm_config, AttentionLayerBase).keys() -
|
||||
target_attn_layer_names)
|
||||
@@ -228,8 +246,7 @@ class MtpProposer(Proposer):
|
||||
attn_metadata=None,
|
||||
aux_hidden_states: torch.Tensor = None):
|
||||
common_attn_metadata = self.runner.spec_decode_common_attn_metadata
|
||||
if attn_metadata is not None and isinstance(attn_metadata, dict):
|
||||
attn_metadata = attn_metadata['model.layers.0.self_attn.attn']
|
||||
attn_metadata = self._get_attn_metadata(attn_metadata)
|
||||
|
||||
if self.speculative_config.disable_padded_drafter_batch:
|
||||
# When padded-batch is disabled, the sampled_token_ids should be
|
||||
@@ -311,6 +328,20 @@ class MtpProposer(Proposer):
|
||||
|
||||
return draft_token_ids
|
||||
|
||||
def _init_mtp_model(self):
|
||||
architecture = self.vllm_config.model_config.architecture
|
||||
target_device = self.vllm_config.device_config.device
|
||||
model = _load_model(architecture)
|
||||
self.model = model(vllm_config=self.vllm_config).to(target_device)
|
||||
|
||||
def _get_attn_metadata(self, attn_metadata):
|
||||
if attn_metadata is not None and isinstance(attn_metadata, dict):
|
||||
architecture = self.vllm_config.model_config.architecture
|
||||
layer_name = _FIRST_LAYERS.get(architecture, _DEFAULT_FIRST_LAYER)
|
||||
attn_metadata = attn_metadata[layer_name]
|
||||
|
||||
return attn_metadata
|
||||
|
||||
def _prepare_inputs(
|
||||
self,
|
||||
common_attn_metadata: CommonAttentionMetadata,
|
||||
|
||||
@@ -1852,7 +1852,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
||||
extra_attn_metadata_args = dict(
|
||||
num_accepted_tokens=self.num_accepted_tokens.
|
||||
gpu[:num_reqs],
|
||||
num_draft_tokens=self.num_draft_tokens.
|
||||
num_decode_draft_tokens_cpu=self.num_draft_tokens.
|
||||
gpu[:num_reqs],
|
||||
)
|
||||
attn_metadata_i = builder.build(
|
||||
@@ -1948,11 +1948,10 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
||||
attn_state = AscendAttentionState.SpecDecoding
|
||||
# Speculative decoding.
|
||||
elif np.all(num_valid_tokens == 1):
|
||||
if self.drafter and (self.drafter.name == SpecDcodeType.EAGLE
|
||||
or self.drafter.name == SpecDcodeType.EAGLE3):
|
||||
attn_state = AscendAttentionState.ChunkedPrefill
|
||||
else:
|
||||
if self.speculative_config and self.speculative_config.method == 'deepseek_mtp':
|
||||
attn_state = AscendAttentionState.SpecDecoding
|
||||
else:
|
||||
attn_state = AscendAttentionState.ChunkedPrefill
|
||||
# splitfuse
|
||||
elif not ascend_config.ascend_scheduler_config.enabled or self.chunked_prefill_enabled:
|
||||
attn_state = AscendAttentionState.ChunkedPrefill
|
||||
@@ -2548,7 +2547,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
||||
with ProfileExecuteDuration().capture_async("Draft"):
|
||||
if self.speculative_config:
|
||||
use_padded_batch_for_eagle = self.speculative_config and \
|
||||
self.speculative_config.method == "deepseek_mtp" and \
|
||||
self.speculative_config.method in ("deepseek_mtp", "qwen3_next_mtp") and \
|
||||
not self.speculative_config.disable_padded_drafter_batch
|
||||
if use_padded_batch_for_eagle:
|
||||
# EAGLE speculative decoding can use the GPU sampled tokens
|
||||
|
||||
Reference in New Issue
Block a user