[Feat] Adapted mtp function to Qwen3-next (#3918)
### What this PR does / why we need it?
Adapts mtp function to Qwen3-next.
- vLLM version: v0.11.0
- vLLM main:
83f478bb19
Signed-off-by: drslark <slarksblood@qq.com>
This commit is contained in:
@@ -35,6 +35,10 @@ def register_model():
|
||||
"PanguProMoEForCausalLM",
|
||||
"vllm_ascend.torchair.models.torchair_pangu_moe:PanguProMoEForCausalLM"
|
||||
)
|
||||
|
||||
ModelRegistry.register_model(
|
||||
"Qwen3NextForCausalLM",
|
||||
"vllm_ascend.models.qwen3_next:CustomQwen3NextForCausalLM")
|
||||
|
||||
ModelRegistry.register_model(
|
||||
"Qwen3NextMTP", "vllm_ascend.models.qwen3_next_mtp:CustomQwen3NextMTP")
|
||||
|
||||
@@ -260,6 +260,24 @@ class CustomQwen3NextGatedDeltaNet(Qwen3NextGatedDeltaNet, MambaBase):
|
||||
mixed_qkv_spec = None
|
||||
mixed_qkv_non_spec = mixed_qkv
|
||||
|
||||
# 2.1: process the multi-query part
|
||||
if spec_sequence_masks is not None:
|
||||
mixed_qkv_spec = mixed_qkv_spec.view(
|
||||
attn_metadata.num_spec_decodes, -1, mixed_qkv_spec.size(-1))
|
||||
mixed_qkv_spec = rearrange(mixed_qkv_spec, 'b l d -> b d l')
|
||||
mixed_qkv_spec = causal_conv1d_update(
|
||||
mixed_qkv_spec,
|
||||
conv_state,
|
||||
conv_weights,
|
||||
self.conv1d.bias,
|
||||
self.activation,
|
||||
conv_state_indices=spec_state_indices_tensor[:, 0]
|
||||
[:attn_metadata.num_spec_decodes],
|
||||
num_accepted_tokens=num_accepted_tokens,
|
||||
validate_data=False,
|
||||
)
|
||||
mixed_qkv_spec = rearrange(mixed_qkv_spec, 'b d l -> (b l) d')
|
||||
|
||||
# 2.2: process the remaining part
|
||||
if attn_metadata.num_prefills > 0:
|
||||
# - "cache_indices" updates the conv_state cache in positions
|
||||
|
||||
109
vllm_ascend/models/qwen3_next_mtp.py
Normal file
109
vllm_ascend/models/qwen3_next_mtp.py
Normal file
@@ -0,0 +1,109 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""Inference-only Qwen3Next MTP model."""
|
||||
import torch
|
||||
from vllm.compilation.decorators import support_torch_compile
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.model_executor.layers.linear import ColumnParallelLinear
|
||||
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
|
||||
from vllm.model_executor.models.interfaces import SupportsPP
|
||||
from vllm.model_executor.models.qwen3_next_mtp import (
|
||||
Qwen3NextMTP, Qwen3NextMultiTokenPredictor)
|
||||
from vllm.model_executor.models.utils import (
|
||||
make_empty_intermediate_tensors_factory, maybe_prefix)
|
||||
from vllm.transformers_utils.configs import Qwen3NextConfig
|
||||
|
||||
from vllm_ascend.models.qwen3_next import (CustomQwen3NextDecoderLayer,
|
||||
Qwen3NextRMSNorm)
|
||||
|
||||
|
||||
@support_torch_compile
class CustomQwen3NextMultiTokenPredictor(Qwen3NextMultiTokenPredictor):
    """Ascend-adapted multi-token predictor (MTP) for Qwen3Next.

    Re-runs the full member construction of the upstream predictor so that
    the decoder layers are the Ascend-custom ``CustomQwen3NextDecoderLayer``
    instead of the stock vLLM ones. Everything else (embedding, fc fusion
    layer, norms) mirrors the upstream layout so checkpoint weight names
    still resolve.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        # Deliberately skip Qwen3NextMultiTokenPredictor.__init__ (it would
        # build stock decoder layers); initialize from its parent instead
        # and rebuild all members below.
        super(Qwen3NextMultiTokenPredictor, self).__init__()

        model_config = vllm_config.model_config
        quant_config = vllm_config.quant_config
        lora_config = vllm_config.lora_config
        config: Qwen3NextConfig = model_config.hf_config

        self.config = config
        # Extra vocab rows reserved for LoRA adapters (0 when LoRA is off).
        lora_vocab = ((lora_config.lora_extra_vocab_size *
                       (lora_config.max_loras or 1)) if lora_config else 0)
        self.vocab_size = config.vocab_size + lora_vocab
        self.org_vocab_size = config.vocab_size

        # MTP layers are numbered after the main model's hidden layers so
        # their checkpoint names do not collide with the base decoder.
        self.mtp_start_layer_idx = config.num_hidden_layers
        self.num_mtp_layers = getattr(config, "num_nextn_predict_layers", 1)

        self.embed_tokens = VocabParallelEmbedding(
            self.vocab_size,
            config.hidden_size,
            org_num_embeddings=config.vocab_size,
        )

        # Fuses [hidden_states; embedding] (2*hidden) back to hidden_size
        # before the MTP decoder layer.
        self.fc = ColumnParallelLinear(self.config.hidden_size * 2,
                                       self.config.hidden_size,
                                       gather_output=True,
                                       bias=False,
                                       return_bias=False,
                                       quant_config=quant_config,
                                       prefix=f'{prefix}.fc')

        # Use the old-style MTP layer name to avoid an exception in vLLM.
        self.layers = torch.nn.ModuleList(
            CustomQwen3NextDecoderLayer(
                vllm_config,
                layer_type="full_attention",
                prefix=f'{prefix}.layers.{self.mtp_start_layer_idx + idx}',
            ) for idx in range(self.num_mtp_layers))

        self.make_empty_intermediate_tensors = (
            make_empty_intermediate_tensors_factory(
                ["hidden_states", "residual"], config.hidden_size))

        # Final norm plus the two pre-fusion norms applied to the previous
        # hidden states and to the token embeddings before `self.fc`.
        self.norm = Qwen3NextRMSNorm(config.hidden_size,
                                     eps=config.rms_norm_eps)
        self.pre_fc_norm_hidden = Qwen3NextRMSNorm(config.hidden_size,
                                                   eps=config.rms_norm_eps)
        self.pre_fc_norm_embedding = Qwen3NextRMSNorm(config.hidden_size,
                                                      eps=config.rms_norm_eps)
|
||||
|
||||
|
||||
@support_torch_compile
class CustomQwen3NextMTP(Qwen3NextMTP, SupportsPP):
    """Ascend-adapted Qwen3Next MTP (speculative-decoding draft) model.

    Rebuilds the upstream ``Qwen3NextMTP`` members so that the predictor
    uses :class:`CustomQwen3NextMultiTokenPredictor` (Ascend decoder
    layers) while keeping the external interface — weight names, lm_head,
    logits processing — identical to upstream.
    """

    # Maps each fused vLLM module to the HF checkpoint sub-projections it
    # packs, so the weight loader can stitch the shards together.
    # FIX: gate_up_proj fuses gate_proj + up_proj; down_proj is a separate
    # row-parallel layer (different output dim) and is never packed — the
    # previous ["up_proj", "down_proj"] entry mis-routed shards.
    packed_modules_mapping = {
        "qkv_proj": [
            "q_proj",
            "k_proj",
            "v_proj",
        ],
        "gate_up_proj": ["gate_proj", "up_proj"],
    }

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        config = vllm_config.model_config.hf_config
        self.vllm_config = vllm_config
        cache_config = vllm_config.cache_config
        assert not cache_config.enable_prefix_caching, \
            "Qwen3NextMTP currently does not support prefix caching"

        self.quant_config = vllm_config.quant_config

        # Deliberately skip Qwen3NextMTP.__init__ (it would build the stock
        # predictor); initialize from its parent and rebuild members here.
        super(Qwen3NextMTP, self).__init__()
        self.config = config
        self.model = CustomQwen3NextMultiTokenPredictor(
            vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model"))
        self.unpadded_vocab_size = config.vocab_size
        self.lm_head = ParallelLMHead(self.unpadded_vocab_size,
                                      config.hidden_size,
                                      org_num_embeddings=config.vocab_size,
                                      padding_size=DEFAULT_VOCAB_PADDING_SIZE,
                                      prefix=maybe_prefix(prefix, "lm_head"))
        self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
                                                config.vocab_size)
        self.make_empty_intermediate_tensors = (
            self.model.make_empty_intermediate_tensors)
|
||||
Reference in New Issue
Block a user