init v0.11.0rc0

This commit is contained in:
2025-10-14 10:38:28 +08:00
parent 67afd0ea78
commit 66dc16f966
278 changed files with 28130 additions and 11708 deletions

View File

@@ -4,23 +4,20 @@ import vllm_ascend.envs as envs_ascend
def register_model():
from .deepseek_dbo import CustomDeepseekDBOForCausalLM # noqa: F401
from .deepseek_mtp import CustomDeepSeekMTP # noqa: F401
from .deepseek_v2 import CustomDeepseekV2ForCausalLM # noqa: F401
from .deepseek_v3 import CustomDeepseekV3ForCausalLM # noqa: F401
from .qwen2_5_vl import \
AscendQwen2_5_VLForConditionalGeneration # noqa: F401
from .qwen2_vl import AscendQwen2VLForConditionalGeneration # noqa: F401
from .qwen3 import CustomQwen3ForCausalLM # noqa: F401
ModelRegistry.register_model(
"DeepSeekMTPModel",
"vllm_ascend.models.deepseek_mtp:CustomDeepSeekMTP")
ModelRegistry.register_model(
"Qwen2VLForConditionalGeneration",
"vllm_ascend.models.qwen2_vl:AscendQwen2VLForConditionalGeneration")
ModelRegistry.register_model(
"Qwen3VLMoeForConditionalGeneration",
"vllm_ascend.models.qwen2_5_vl_without_padding:AscendQwen3VLMoeForConditionalGeneration"
)
ModelRegistry.register_model(
"Qwen3VLForConditionalGeneration",
"vllm_ascend.models.qwen2_5_vl_without_padding:AscendQwen3VLForConditionalGeneration"
)
if envs_ascend.USE_OPTIMIZED_MODEL:
ModelRegistry.register_model(
"Qwen2_5_VLForConditionalGeneration",
@@ -32,30 +29,32 @@ def register_model():
"vllm_ascend.models.qwen2_5_vl_without_padding:AscendQwen2_5_VLForConditionalGeneration_Without_Padding"
)
if envs_ascend.VLLM_ASCEND_ENABLE_DBO:
ModelRegistry.register_model(
"DeepseekV2ForCausalLM",
"vllm_ascend.models.deepseek_dbo:CustomDeepseekDBOForCausalLM")
ModelRegistry.register_model(
"DeepseekV2ForCausalLM",
"vllm_ascend.models.deepseek_v2:CustomDeepseekV2ForCausalLM")
ModelRegistry.register_model(
"DeepseekV3ForCausalLM",
"vllm_ascend.models.deepseek_dbo:CustomDeepseekDBOForCausalLM")
else:
ModelRegistry.register_model(
"DeepseekV2ForCausalLM",
"vllm_ascend.models.deepseek_v2:CustomDeepseekV2ForCausalLM")
ModelRegistry.register_model(
"DeepseekV3ForCausalLM",
"vllm_ascend.models.deepseek_v2:CustomDeepseekV3ForCausalLM")
ModelRegistry.register_model(
"DeepseekV3ForCausalLM",
"vllm_ascend.models.deepseek_v3:CustomDeepseekV3ForCausalLM")
ModelRegistry.register_model(
"DeepseekV32ForCausalLM",
"vllm_ascend.models.deepseek_v2:CustomDeepseekV3ForCausalLM")
ModelRegistry.register_model(
"DeepSeekMTPModel",
"vllm_ascend.models.deepseek_mtp:CustomDeepSeekMTP")
ModelRegistry.register_model(
"Qwen3MoeForCausalLM",
"vllm_ascend.models.qwen3_moe:CustomQwen3MoeForCausalLM")
ModelRegistry.register_model(
"Qwen3ForCausalLM", "vllm_ascend.models.qwen3:CustomQwen3ForCausalLM")
# There is no PanguProMoEForCausalLM in vLLM, so we should register it before vLLM config initialization
# to make sure the model can be loaded correctly. This register step can be removed once vLLM support PanguProMoEForCausalLM.
ModelRegistry.register_model(
"PanguProMoEForCausalLM",
"vllm_ascend.models.pangu_moe:PanguProMoEForCausalLM")
"vllm_ascend.torchair.models.torchair_pangu_moe:PanguProMoEForCausalLM"
)
ModelRegistry.register_model(
"Qwen3NextForCausalLM",
"vllm_ascend.models.qwen3_next:CustomQwen3NextForCausalLM")

File diff suppressed because it is too large Load Diff

View File

@@ -23,22 +23,20 @@ import torch
import torch.nn as nn
from transformers import PretrainedConfig
from vllm.attention.backends.abstract import AttentionMetadata
from vllm.config import CacheConfig, ModelConfig, VllmConfig
from vllm.config import (CacheConfig, ModelConfig, VllmConfig,
get_current_vllm_config)
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.sampler import get_sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.models.deepseek_mtp import (
DeepSeekMTP, DeepSeekMultiTokenPredictor, DeepSeekMultiTokenPredictorLayer,
SharedHead)
from vllm.model_executor.models.deepseek_v2 import DeepseekV2DecoderLayer
from vllm.model_executor.models.utils import maybe_prefix
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors
from .deepseek_v2 import CustomDeepseekV2DecoderLayer
class CustomDeepSeekShareHead(SharedHead):
@@ -65,6 +63,7 @@ class CustomDeepSeekMultiTokenPredictorLayer(DeepSeekMultiTokenPredictorLayer):
quant_config: Optional[QuantizationConfig] = None,
) -> None:
nn.Module.__init__(self)
vllm_config = get_current_vllm_config()
self.enorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.hnorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
@@ -75,10 +74,8 @@ class CustomDeepSeekMultiTokenPredictorLayer(DeepSeekMultiTokenPredictorLayer):
quant_config=quant_config,
prefix=maybe_prefix(
prefix, "shared_head"))
self.mtp_block = CustomDeepseekV2DecoderLayer(config, prefix,
model_config,
cache_config,
quant_config)
self.mtp_block = DeepseekV2DecoderLayer(vllm_config=vllm_config,
prefix=prefix)
def forward(
self,
@@ -103,8 +100,6 @@ class CustomDeepSeekMultiTokenPredictorLayer(DeepSeekMultiTokenPredictorLayer):
hidden_states, residual = self.mtp_block(positions=positions,
hidden_states=hidden_states,
kv_cache=kv_cache,
attn_metadata=attn_metadata,
residual=None)
hidden_states = residual + hidden_states
return hidden_states
@@ -171,7 +166,7 @@ class CustomDeepSeekMultiTokenPredictor(DeepSeekMultiTokenPredictor):
def compute_logits(
self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
sampling_metadata=None, # type: ignore
spec_step_idx: int = 0,
) -> torch.Tensor:
current_step_idx = (spec_step_idx % self.num_mtp_layers)
@@ -183,14 +178,6 @@ class CustomDeepSeekMultiTokenPredictor(DeepSeekMultiTokenPredictor):
class CustomDeepSeekMTP(DeepSeekMTP):
# NOTE 1.The quantized MTP layer of deepseek on the NPU is not quantized;
# NOTE 2.The description file generated by the current msmodelslim tool does not have
# MTP layer info. Please manually add it and set the value to FLOAT.
packed_modules_mapping = {
"gate_up_proj": ["gate_proj", "up_proj"],
"experts":
["experts.0.gate_proj", "experts.0.up_proj", "experts.0.down_proj"]
}
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
nn.Module.__init__(self)
@@ -199,8 +186,6 @@ class CustomDeepSeekMTP(DeepSeekMTP):
prefix=maybe_prefix(
prefix, "model"))
self.sampler = get_sampler()
def forward(
self,
input_ids: torch.Tensor,
@@ -215,4 +200,4 @@ class CustomDeepSeekMTP(DeepSeekMTP):
hidden_states = self.model(input_ids, positions, kv_caches,
attn_metadata, previous_hidden_states,
inputs_embeds, spec_step_idx)
return hidden_states
return hidden_states

File diff suppressed because it is too large Load Diff

View File

@@ -1,27 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# Copyright 2023 The vLLM team.
# Copyright 2023 DeepSeek-AI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from vllm_ascend.models.deepseek_v2 import CustomDeepseekV2ForCausalLM
class CustomDeepseekV3ForCausalLM(CustomDeepseekV2ForCausalLM):
pass

View File

View File

@@ -0,0 +1,180 @@
# SPDX-License-Identifier: Apache-2.0
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# Copyright 2023 The vLLM team.
# Copyright 2023 DeepSeek-AI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from typing import Optional
import torch
from torch import nn
from vllm.attention import Attention, AttentionMetadata
from vllm.config import CacheConfig, get_current_vllm_config
from vllm.forward_context import ForwardContext, get_forward_context
from vllm.model_executor.layers.mla import MultiHeadLatentAttention
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.utils import direct_register_custom_op
@dataclass
class AscendMLAModules:
q_a_proj: Optional[torch.nn.Module]
q_a_layernorm: Optional[torch.nn.Module]
q_proj: Optional[torch.nn.Module]
kv_a_proj_with_mqa: torch.nn.Module
kv_a_layernorm: torch.nn.Module
kv_b_proj: torch.nn.Module
o_proj: torch.nn.Module
rotary_emb: torch.nn.Module
class AscendMultiHeadLatentAttention(MultiHeadLatentAttention):
def __init__(
self,
hidden_size: int,
enable_shared_expert_dp: bool,
debug_layer_idx: int,
first_k_dense_replace: int,
tp_size: int,
mla_modules: AscendMLAModules,
num_local_heads: int,
scaling: float,
layers: int,
kv_lora_rank: int,
qk_rope_head_dim: int,
q_lora_rank: Optional[int],
qk_nope_head_dim: int,
qk_head_dim: int,
v_head_dim: int,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
nn.Module.__init__(self)
self.hidden_size = hidden_size
self.enable_shared_expert_dp = enable_shared_expert_dp
self.debug_layer_idx = debug_layer_idx
self.first_k_dense_replace = first_k_dense_replace
self.tp_size = tp_size
self.num_local_heads = num_local_heads
self.layers = layers
self.kv_lora_rank = kv_lora_rank
self.qk_rope_head_dim = qk_rope_head_dim
self.q_lora_rank = q_lora_rank
self.qk_nope_head_dim = qk_nope_head_dim
self.qk_head_dim = qk_head_dim
self.v_head_dim = v_head_dim
self.prefix = prefix
self.mla_attn = Attention(
num_heads=self.num_local_heads,
head_size=self.kv_lora_rank + self.qk_rope_head_dim,
scale=scaling,
num_kv_heads=1,
cache_config=cache_config,
quant_config=quant_config,
prefix=f"{prefix}.attn",
use_mla=True,
# MLA Args
q_lora_rank=self.q_lora_rank,
kv_lora_rank=self.kv_lora_rank,
qk_nope_head_dim=self.qk_nope_head_dim,
qk_rope_head_dim=self.qk_rope_head_dim,
qk_head_dim=self.qk_head_dim,
v_head_dim=self.v_head_dim,
rotary_emb=mla_modules.rotary_emb,
q_a_proj=mla_modules.q_a_proj,
q_a_layernorm=mla_modules.q_a_layernorm,
q_proj=mla_modules.q_proj,
kv_a_proj_with_mqa=mla_modules.kv_a_proj_with_mqa,
kv_a_layernorm=mla_modules.kv_a_layernorm,
kv_b_proj=mla_modules.kv_b_proj,
o_proj=mla_modules.o_proj,
)
compilation_config = get_current_vllm_config().compilation_config
if prefix in compilation_config.static_forward_context:
raise ValueError(f"Duplicate layer name: {prefix}")
compilation_config.static_forward_context[prefix] = self
def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
kv_cache: Optional[torch.Tensor] = None,
attn_metadata: Optional[AttentionMetadata] = None) -> torch.Tensor:
num_tokens = hidden_states.shape[0]
need_gather_q_kv = False
if self.enable_shared_expert_dp and self.debug_layer_idx > self.first_k_dense_replace and self.debug_layer_idx < self.layers:
# Simulate all gather to calculate output shape
num_tokens = num_tokens * self.tp_size
need_gather_q_kv = True
if not self.enable_shared_expert_dp or self.debug_layer_idx < self.first_k_dense_replace:
output_shape = hidden_states.shape
else:
rows = num_tokens // self.tp_size
if num_tokens % self.tp_size:
rows += 1
output_shape = (rows, hidden_states.shape[1])
# FIXME: This does not seem right, should make sure the buffer is fixed
output = torch.empty(output_shape,
dtype=hidden_states.dtype,
device=hidden_states.device)
torch.ops.vllm.mla_forward(hidden_states, need_gather_q_kv, output,
self.prefix)
output = output.view(-1, output_shape[-1])
return output
def mla_forward(
hidden_states: torch.Tensor,
need_gather_q_kv: bool,
output: torch.Tensor,
layer_name: str,
) -> None:
forward_context: ForwardContext = get_forward_context()
self = forward_context.no_compile_layers[layer_name]
if forward_context.attn_metadata:
attn_metadata = forward_context.attn_metadata[self.mla_attn.layer_name]
else:
attn_metadata = forward_context.attn_metadata
kv_cache = self.mla_attn.kv_cache[forward_context.virtual_engine]
self.mla_attn.impl.forward(self.mla_attn.layer_name, hidden_states,
kv_cache, attn_metadata, need_gather_q_kv,
output)
return
def mla_forward_fake(
hidden_states: torch.Tensor,
need_gather_q_kv: bool,
output: torch.Tensor,
layer_name: str,
) -> None:
return
direct_register_custom_op(
op_name="mla_forward",
op_func=mla_forward,
mutates_args=["output"],
fake_impl=mla_forward_fake,
dispatch_key="PrivateUse1",
)

View File

@@ -0,0 +1,233 @@
# SPDX-License-Identifier: Apache-2.0
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# Copyright 2023 The vLLM team.
# Copyright 2023 DeepSeek-AI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from typing import Optional
import torch
from torch import nn
from vllm.attention import Attention, AttentionMetadata
from vllm.config import CacheConfig, get_current_vllm_config
from vllm.forward_context import ForwardContext, get_forward_context
from vllm.model_executor.layers.linear import ReplicatedLinear
from vllm.model_executor.layers.mla import MultiHeadLatentAttention
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.utils import direct_register_custom_op
@dataclass
class AscendSFAModules:
q_a_proj: Optional[torch.nn.Module]
q_a_layernorm: Optional[torch.nn.Module]
q_proj: Optional[torch.nn.Module]
kv_a_proj_with_mqa: torch.nn.Module
kv_a_layernorm: torch.nn.Module
kv_b_proj: torch.nn.Module
o_proj: torch.nn.Module
rotary_emb: torch.nn.Module
indexer: torch.nn.Module
class AscendSparseFlashAttention(MultiHeadLatentAttention):
def __init__(
self,
hidden_size: int,
enable_shared_expert_dp: bool,
debug_layer_idx: int,
first_k_dense_replace: int,
tp_size: int,
sfa_modules: AscendSFAModules,
num_local_heads: int,
scaling: float,
layers: int,
kv_lora_rank: int,
qk_rope_head_dim: int,
q_lora_rank: Optional[int],
qk_nope_head_dim: int,
qk_head_dim: int,
v_head_dim: int,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
nn.Module.__init__(self)
self.hidden_size = hidden_size
self.enable_shared_expert_dp = enable_shared_expert_dp
self.debug_layer_idx = debug_layer_idx
self.first_k_dense_replace = first_k_dense_replace
self.tp_size = tp_size
self.num_local_heads = num_local_heads
self.layers = layers
self.kv_lora_rank = kv_lora_rank
self.qk_rope_head_dim = qk_rope_head_dim
self.q_lora_rank = q_lora_rank
self.qk_nope_head_dim = qk_nope_head_dim
self.qk_head_dim = qk_head_dim
self.v_head_dim = v_head_dim
self.prefix = prefix
self.sfa_attn = Attention(
num_heads=self.num_local_heads,
head_size=self.kv_lora_rank + self.qk_rope_head_dim,
scale=scaling,
num_kv_heads=1,
cache_config=cache_config,
quant_config=quant_config,
prefix=f"{prefix}.attn",
use_mla=True,
use_sfa=True,
# SFA Args
q_lora_rank=self.q_lora_rank,
kv_lora_rank=self.kv_lora_rank,
qk_nope_head_dim=self.qk_nope_head_dim,
qk_rope_head_dim=self.qk_rope_head_dim,
qk_head_dim=self.qk_head_dim,
v_head_dim=self.v_head_dim,
rotary_emb=sfa_modules.rotary_emb,
q_a_proj=sfa_modules.q_a_proj,
q_a_layernorm=sfa_modules.q_a_layernorm,
q_proj=sfa_modules.q_proj,
kv_a_proj_with_mqa=sfa_modules.kv_a_proj_with_mqa,
kv_a_layernorm=sfa_modules.kv_a_layernorm,
kv_b_proj=sfa_modules.kv_b_proj,
o_proj=sfa_modules.o_proj,
indexer=sfa_modules.indexer)
compilation_config = get_current_vllm_config().compilation_config
if prefix in compilation_config.static_forward_context:
raise ValueError(f"Duplicate layer name: {prefix}")
compilation_config.static_forward_context[prefix] = self
def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
kv_cache: Optional[torch.Tensor] = None,
attn_metadata: Optional[AttentionMetadata] = None) -> torch.Tensor:
num_tokens = hidden_states.shape[0]
need_gather_q_kv = False
if self.enable_shared_expert_dp and self.debug_layer_idx > self.first_k_dense_replace and self.debug_layer_idx < self.layers:
# Simulate all gather to calculate output shape
num_tokens = num_tokens * self.tp_size
need_gather_q_kv = True
if not self.enable_shared_expert_dp or self.debug_layer_idx < self.first_k_dense_replace:
output_shape = hidden_states.shape
else:
rows = num_tokens // self.tp_size
if num_tokens % self.tp_size:
rows += 1
output_shape = (rows, hidden_states.shape[1])
# FIXME: This does not seem right, should make sure the buffer is fixed
output = torch.empty(output_shape,
dtype=hidden_states.dtype,
device=hidden_states.device)
torch.ops.vllm.sfa_forward(hidden_states, need_gather_q_kv, output,
self.prefix)
output = output.view(-1, output_shape[-1])
return output
def sfa_forward(
hidden_states: torch.Tensor,
need_gather_q_kv: bool,
output: torch.Tensor,
layer_name: str,
) -> None:
forward_context: ForwardContext = get_forward_context()
self = forward_context.no_compile_layers[layer_name]
if forward_context.attn_metadata:
attn_metadata = forward_context.attn_metadata[self.sfa_attn.layer_name]
else:
attn_metadata = forward_context.attn_metadata
kv_cache = self.sfa_attn.kv_cache[forward_context.virtual_engine]
self.sfa_attn.impl.forward(hidden_states, kv_cache, attn_metadata,
need_gather_q_kv, output)
return
class Indexer(nn.Module):
def __init__(self,
config,
dim: int = 7168,
n_heads: int = 64,
head_dim: int = 128,
index_topk: int = 2048,
q_lora_rank: int = 1536,
rope_head_dim: int = 64,
quant_config: Optional[QuantizationConfig] = None,
prefix: Optional[str] = ""):
super().__init__()
self.dim: int = dim # 7168
self.n_heads: int = n_heads # 64
self.head_dim: int = head_dim # 128
self.rope_head_dim: int = rope_head_dim # 64
self.index_topk: int = index_topk # 2048
self.q_lora_rank: int = q_lora_rank # 1536
self.wq_b = ReplicatedLinear(
self.q_lora_rank,
self.n_heads * self.head_dim,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.wq_b",
return_bias=False,
)
self.wk = ReplicatedLinear(
self.dim,
self.head_dim,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.wk",
return_bias=False,
)
self.weights_proj = ReplicatedLinear(
self.dim,
self.n_heads,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.weights_proj",
return_bias=False,
)
self.k_norm = nn.LayerNorm(self.head_dim)
self.softmax_scale = self.head_dim**-0.5
def forward(self):
return
def sfa_forward_fake(
hidden_states: torch.Tensor,
need_gather_q_kv: bool,
output: torch.Tensor,
layer_name: str,
) -> None:
return
direct_register_custom_op(
op_name="sfa_forward",
op_func=sfa_forward,
mutates_args=["output"],
fake_impl=sfa_forward_fake,
dispatch_key="PrivateUse1",
)

File diff suppressed because it is too large Load Diff

View File

@@ -42,6 +42,8 @@ from vllm.model_executor.models.qwen2_5_vl import (
from vllm.model_executor.models.utils import maybe_prefix
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm_ascend.utils import vllm_version_is
MIN_PAD_SIZE = 64 # min_size to pad weight
MAX_PAD_SIZE = 128 # max_size to pad weight
@@ -291,6 +293,40 @@ class AscendQwen2_5_VisionTransformer(Qwen2_5_VisionTransformer):
self.hidden_size, -1)
return out_weight
def pad_qkv_weight_scale_offset(self, data):
reshaped_data = data.reshape(
-1, 3, self.origin_hidden_size_per_attention_head, 1)
data1 = reshaped_data[:, :, :self.
half_origin_hidden_size_per_attention_head, :]
data2 = reshaped_data[:, :, self.
half_origin_hidden_size_per_attention_head:, :]
data1_paded = torch.nn.functional.pad(
data1, (0, 0, 0, self.half_pad_hidden_size_per_attention_head, 0,
0, 0, 0))
data2_paded = torch.nn.functional.pad(
data2, (0, 0, 0, self.half_pad_hidden_size_per_attention_head, 0,
0, 0, 0))
res = torch.cat([data1_paded, data2_paded], dim=2)
res = res.reshape(-1, 1)
return res
def pad_qkv_deq_scale_quant_bias(self, data):
reshaped_data = data.reshape(
-1, 3, self.origin_hidden_size_per_attention_head)
data1 = reshaped_data[:, :, :self.
half_origin_hidden_size_per_attention_head]
data2 = reshaped_data[:, :,
self.half_origin_hidden_size_per_attention_head:]
data1_paded = torch.nn.functional.pad(
data1, (0, self.half_pad_hidden_size_per_attention_head))
data2_paded = torch.nn.functional.pad(
data2, (0, self.half_pad_hidden_size_per_attention_head))
res = torch.cat([data1_paded, data2_paded], dim=2)
res = res.reshape(-1)
return res
def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
stacked_params_mapping: list[tuple[str, str, Union[str, int]]] = [
@@ -318,11 +354,23 @@ class AscendQwen2_5_VisionTransformer(Qwen2_5_VisionTransformer):
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
if ("attn.proj.weight" in name) and self.enable_pad:
if ("attn.proj.weight_scale" in name or
"attn.proj.weight_offset" in name) and self.enable_pad:
continue
elif ("attn.proj.deq_scale" in name
or "attn.proj.quant_bias" in name) and self.enable_pad:
continue
elif ("attn.qkv.weight_scale" in name
or "attn.qkv.weight_offset" in name) and self.enable_pad:
param.data = self.pad_qkv_weight_scale_offset(param.data)
elif ("attn.qkv.deq_scale" in name
or "attn.qkv.quant_bias" in name) and self.enable_pad:
param.data = self.pad_qkv_deq_scale_quant_bias(param.data)
elif ("attn.proj.weight" in name) and self.enable_pad:
param.data = self.pad_proj_weight(param.data)
if ("attn.qkv.weight" in name) and self.enable_pad:
elif ("attn.qkv.weight" in name) and self.enable_pad:
param.data = self.pad_qkv_weight(param.data)
if ("attn.qkv.bias" in name) and self.enable_pad:
elif ("attn.qkv.bias" in name) and self.enable_pad:
param.data = self.pad_qkv_bias(param.data)
loaded_params.add(name)
return loaded_params
@@ -450,12 +498,20 @@ class AscendQwen2_5_VLForConditionalGeneration(
super().__init__(vllm_config=vllm_config, prefix=prefix)
config: Qwen2_5_VLConfig = vllm_config.model_config.hf_config
quant_config = vllm_config.quant_config
self.visual = AscendQwen2_5_VisionTransformer(
vision_config=config.vision_config,
norm_eps=getattr(config, "rms_norm_eps", 1e-6),
quant_config=self._maybe_ignore_quant_config(quant_config),
prefix=maybe_prefix(prefix, "visual"),
)
if vllm_version_is("0.10.2"):
self.visual = AscendQwen2_5_VisionTransformer(
vision_config=config.vision_config,
norm_eps=getattr(config, "rms_norm_eps", 1e-6),
quant_config=self._maybe_ignore_quant_config(quant_config),
prefix=maybe_prefix(prefix, "visual"),
)
else:
self.visual = AscendQwen2_5_VisionTransformer(
vision_config=config.vision_config,
norm_eps=getattr(config, "rms_norm_eps", 1e-6),
quant_config=quant_config,
prefix=maybe_prefix(prefix, "visual"),
)
def _process_image_input(self, image_input) -> tuple[torch.Tensor, ...]:

View File

@@ -1,6 +1,5 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# Adapted from vllm/model_executor/models/qwen2_5_vl.py
# Copyright 2023 The vLLM team.
#
# This file is a part of the vllm-ascend project.
@@ -27,10 +26,19 @@ import torch_npu
from einops import rearrange
from transformers.models.qwen2_5_vl.configuration_qwen2_5_vl import (
Qwen2_5_VLConfig, Qwen2_5_VLVisionConfig)
try:
from transformers.models.qwen3_vl.configuration_qwen3_vl import \
Qwen3VLConfig
from transformers.models.qwen3_vl_moe.configuration_qwen3_vl_moe import \
Qwen3VLMoeConfig
except ImportError:
pass
from vllm.config import VllmConfig
from vllm.distributed import parallel_state
from vllm.distributed import utils as dist_utils
from vllm.model_executor.layers.activation import get_act_and_mul_fn
from vllm.model_executor.layers.activation import (_ACTIVATION_REGISTRY,
get_act_and_mul_fn)
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.models.qwen2_5_vl import (
@@ -38,10 +46,29 @@ from vllm.model_executor.models.qwen2_5_vl import (
Qwen2_5_VisionTransformer, Qwen2_5_VLDummyInputsBuilder,
Qwen2_5_VLForConditionalGeneration, Qwen2_5_VLMultiModalProcessor,
Qwen2_5_VLProcessingInfo)
from vllm.model_executor.models.utils import maybe_prefix
try:
from vllm.model_executor.models.qwen3_vl import (
Qwen3_VisionBlock, Qwen3_VisionPatchEmbed, Qwen3_VisionTransformer,
Qwen3VLDummyInputsBuilder, Qwen3VLForConditionalGeneration,
Qwen3VLMultiModalProcessor, Qwen3VLProcessingInfo)
from vllm.model_executor.models.qwen3_vl_moe import (
Qwen3VLMoeForConditionalGeneration, Qwen3VLMoeProcessingInfo)
except ImportError:
Qwen3_VisionBlock = object
Qwen3_VisionPatchEmbed = object
Qwen3_VisionTransformer = object
Qwen3VLDummyInputsBuilder = object
Qwen3VLForConditionalGeneration = object
Qwen3VLMultiModalProcessor = object
Qwen3VLProcessingInfo = object
Qwen3VLMoeForConditionalGeneration = object
Qwen3VLMoeProcessingInfo = object
from vllm.model_executor.models.utils import WeightsMapper, maybe_prefix
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm_ascend.models.qwen2_5_vl import AscendQwen2_5_VisionRotaryEmbedding
from vllm_ascend.utils import vllm_version_is
class AscendQwen2_5_VisionAttention_Without_Padding(Qwen2_5_VisionAttention):
@@ -112,16 +139,14 @@ class AscendQwen2_5_VisionAttention_Without_Padding(Qwen2_5_VisionAttention):
class AscendQwen2_5_VisionBlock_Without_Padding(Qwen2_5_VisionBlock):
def __init__(
self,
dim: int,
num_heads: int,
mlp_hidden_dim: int,
act_fn: Callable[[torch.Tensor], torch.Tensor] = F.silu,
norm_layer: Optional[Callable[[int], nn.Module]] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
def __init__(self,
dim: int,
num_heads: int,
mlp_hidden_dim: int,
act_fn: Callable[[torch.Tensor], torch.Tensor] = F.silu,
norm_layer: Optional[Callable[[int], nn.Module]] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "") -> None:
super().__init__(dim, num_heads, mlp_hidden_dim, act_fn, norm_layer,
quant_config, prefix)
self.attn = AscendQwen2_5_VisionAttention_Without_Padding(
@@ -321,6 +346,133 @@ class AscendQwen2_5_VisionTransformer_Without_Padding(Qwen2_5_VisionTransformer
return x
class AscendQwen3_VisionPatchEmbed(Qwen3_VisionPatchEmbed):
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = x.matmul(
self.proj.weight.data.view(self.hidden_size, -1).transpose(0, 1))
x = x + self.proj.bias
return x
class AscendQwen3_VisionBlock(Qwen3_VisionBlock):
def __init__(
self,
dim: int,
num_heads: int,
mlp_hidden_dim: int,
act_fn: Callable[[torch.Tensor], torch.Tensor] = F.silu,
norm_layer: Optional[Callable[[int], nn.Module]] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
use_data_parallel: bool = False,
) -> None:
super().__init__(dim, num_heads, mlp_hidden_dim, act_fn, norm_layer,
quant_config, prefix, use_data_parallel)
self.attn = AscendQwen2_5_VisionAttention_Without_Padding(
embed_dim=dim,
num_heads=num_heads,
projection_size=dim,
quant_config=quant_config,
prefix=f"{prefix}.attn")
def forward(self, x: torch.Tensor, cu_seqlens: torch.Tensor,
cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
x = x + self.attn(
self.norm1(x), cu_seqlens=cu_seqlens, cos=cos, sin=sin)
x = x + self.mlp(self.norm2(x))
return x
class AscendQwen3_VisionTransformer(Qwen3_VisionTransformer):
def __init__(
self,
vision_config,
norm_eps: float = 1e-6,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
use_data_parallel: bool = False,
) -> None:
super().__init__(vision_config, norm_eps, quant_config, prefix,
use_data_parallel)
norm_layer = partial(nn.LayerNorm, eps=norm_eps)
self.patch_embed = AscendQwen3_VisionPatchEmbed(
patch_size=self.patch_size,
temporal_patch_size=self.temporal_patch_size,
in_channels=vision_config.in_channels,
hidden_size=self.hidden_size,
)
self.blocks = nn.ModuleList([
AscendQwen3_VisionBlock(
dim=self.hidden_size,
num_heads=self.num_heads,
mlp_hidden_dim=vision_config.intermediate_size,
act_fn=_ACTIVATION_REGISTRY[vision_config.hidden_act],
norm_layer=norm_layer,
quant_config=quant_config,
prefix=f"{prefix}.blocks.{layer_idx}")
for layer_idx in range(vision_config.depth)
])
self.hidden_size_per_attention_head = dist_utils.divide(
self.hidden_size, self.num_heads)
def cal_cos_sin(self, rotary_pos_emb):
cos = rotary_pos_emb.cos() # [seqlen, rotary_dim / 2]
sin = rotary_pos_emb.sin()
cos_new = torch.cat((cos, cos), dim=-1)
sin_new = torch.cat((sin, sin), dim=-1)
cos_new = cos_new.reshape(1, -1, 1,
self.hidden_size_per_attention_head)
sin_new = sin_new.reshape(1, -1, 1,
self.hidden_size_per_attention_head)
return cos_new, sin_new
def forward(
self,
x: torch.Tensor,
grid_thw: list[list[int]],
) -> torch.Tensor:
hidden_states = x.to(device=self.device, dtype=self.dtype)
hidden_states = self.patch_embed(hidden_states)
pos_embeds = self.fast_pos_embed_interpolate(grid_thw)
hidden_states = hidden_states + pos_embeds
rotary_pos_emb = self.rot_pos_emb(grid_thw)
grid_thw_tensor = torch.tensor(grid_thw,
device=self.device,
dtype=torch.int32)
cu_seqlens = torch.repeat_interleave(
grid_thw_tensor[:, 1] * grid_thw_tensor[:, 2],
grid_thw_tensor[:, 0]).cpu().to(torch.int32)
cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)
hidden_states = hidden_states.unsqueeze(1)
rotary_pos_emb = rotary_pos_emb.to(hidden_states.device)
cos, sin = self.cal_cos_sin(rotary_pos_emb)
deepstack_feature_lists = []
for layer_num, blk in enumerate(self.blocks):
hidden_states = blk(hidden_states,
cu_seqlens=cu_seqlens,
cos=cos,
sin=sin)
if layer_num in self.deepstack_visual_indexes:
deepstack_merger_idx = self.deepstack_visual_indexes.index(
layer_num)
deepstack_feature = self.deepstack_merger_list[
deepstack_merger_idx](hidden_states)
deepstack_feature_lists.append(deepstack_feature)
hidden_states = self.merger(hidden_states)
hidden_states = torch.cat(
[hidden_states] + deepstack_feature_lists,
dim=1) # [seq_len, hidden_size * (1 + depth_of_deepstack)]
return hidden_states
@MULTIMODAL_REGISTRY.register_processor(
Qwen2_5_VLMultiModalProcessor,
info=Qwen2_5_VLProcessingInfo,
@@ -332,12 +484,20 @@ class AscendQwen2_5_VLForConditionalGeneration_Without_Padding(
super().__init__(vllm_config=vllm_config, prefix=prefix)
config: Qwen2_5_VLConfig = vllm_config.model_config.hf_config
quant_config = vllm_config.quant_config
self.visual = AscendQwen2_5_VisionTransformer_Without_Padding(
vision_config=config.vision_config,
norm_eps=getattr(config, "rms_norm_eps", 1e-6),
quant_config=self._maybe_ignore_quant_config(quant_config),
prefix=maybe_prefix(prefix, "visual"),
)
if vllm_version_is("0.10.2"):
self.visual = AscendQwen2_5_VisionTransformer_Without_Padding(
vision_config=config.vision_config,
norm_eps=getattr(config, "rms_norm_eps", 1e-6),
quant_config=self._maybe_ignore_quant_config(quant_config),
prefix=maybe_prefix(prefix, "visual"),
)
else:
self.visual = AscendQwen2_5_VisionTransformer_Without_Padding(
vision_config=config.vision_config,
norm_eps=getattr(config, "rms_norm_eps", 1e-6),
quant_config=quant_config,
prefix=maybe_prefix(prefix, "visual"),
)
def _process_image_input(self, image_input) -> tuple[torch.Tensor, ...]:
@@ -371,3 +531,101 @@ class AscendQwen2_5_VLForConditionalGeneration_Without_Padding(
merge_size = self.visual.spatial_merge_size
sizes = grid_thw.prod(-1) // merge_size // merge_size
return video_embeds.split(sizes.tolist())
@MULTIMODAL_REGISTRY.register_processor(Qwen3VLMultiModalProcessor,
info=Qwen3VLProcessingInfo,
dummy_inputs=Qwen3VLDummyInputsBuilder)
class AscendQwen3VLForConditionalGeneration(Qwen3VLForConditionalGeneration):
packed_modules_mapping = {
"qkv_proj": [
"q_proj",
"k_proj",
"v_proj",
],
"gate_up_proj": [
"gate_proj",
"up_proj",
],
}
supports_encoder_tp_data = True
# To ensure correct weight loading and mapping.
hf_to_vllm_mapper = WeightsMapper(
orig_to_new_prefix={
"model.visual.": "visual.",
"lm_head.": "language_model.lm_head.",
"model.language_model.": "language_model.model.",
})
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__(vllm_config=vllm_config, prefix=prefix)
config: Qwen3VLConfig = vllm_config.model_config.hf_config
quant_config = vllm_config.quant_config
if vllm_version_is("0.10.2"):
self.visual = AscendQwen3_VisionTransformer(
config.vision_config,
norm_eps=getattr(config, "rms_norm_eps", 1e-6),
quant_config=self._maybe_ignore_quant_config(quant_config),
prefix=maybe_prefix(prefix, "visual"),
use_data_parallel=self.use_data_parallel)
else:
self.visual = AscendQwen3_VisionTransformer(
config.vision_config,
norm_eps=getattr(config, "rms_norm_eps", 1e-6),
quant_config=quant_config,
prefix=maybe_prefix(prefix, "visual"),
use_data_parallel=self.use_data_parallel)
@MULTIMODAL_REGISTRY.register_processor(Qwen3VLMultiModalProcessor,
info=Qwen3VLMoeProcessingInfo,
dummy_inputs=Qwen3VLDummyInputsBuilder)
class AscendQwen3VLMoeForConditionalGeneration(
Qwen3VLMoeForConditionalGeneration):
packed_modules_mapping = {
"qkv_proj": [
"q_proj",
"k_proj",
"v_proj",
],
"gate_up_proj": [
"gate_proj",
"up_proj",
],
}
supports_encoder_tp_data = True
# To ensure correct weight loading and mapping.
hf_to_vllm_mapper = WeightsMapper(
orig_to_new_prefix={
"model.visual.": "visual.",
"lm_head.": "language_model.lm_head.",
"model.language_model.": "language_model.model.",
})
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__(vllm_config=vllm_config, prefix=prefix)
config: Qwen3VLMoeConfig = vllm_config.model_config.hf_config
quant_config = vllm_config.quant_config
multimodal_config = vllm_config.model_config.multimodal_config
self.multimodal_config = multimodal_config
self.use_data_parallel = multimodal_config.mm_encoder_tp_mode == "data"
if vllm_version_is("0.10.2"):
self.visual = AscendQwen3_VisionTransformer(
config.vision_config,
norm_eps=getattr(config, "rms_norm_eps", 1e-6),
quant_config=self._maybe_ignore_quant_config(quant_config),
prefix=maybe_prefix(prefix, "visual"),
use_data_parallel=self.use_data_parallel,
)
else:
self.visual = AscendQwen3_VisionTransformer(
config.vision_config,
norm_eps=getattr(config, "rms_norm_eps", 1e-6),
quant_config=quant_config,
prefix=maybe_prefix(prefix, "visual"),
use_data_parallel=self.use_data_parallel,
)

View File

@@ -40,6 +40,8 @@ from vllm.model_executor.models.qwen2_vl import (
from vllm.model_executor.models.utils import maybe_prefix
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm_ascend.utils import vllm_version_is
MIN_PAD_SIZE = 64 # min_size to pad weight
MAX_PAD_SIZE = 128 # max_size to pad weight
@@ -343,10 +345,18 @@ class AscendQwen2VLForConditionalGeneration(Qwen2VLForConditionalGeneration):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__(vllm_config=vllm_config, prefix=prefix)
self.visual = AscendQwen2VisionTransformer(
self.config.vision_config,
norm_eps=getattr(self.config, "rms_norm_eps", 1e-6),
quant_config=self._maybe_ignore_quant_config(
vllm_config.quant_config),
prefix=maybe_prefix(prefix, "visual"),
)
if vllm_version_is("0.10.2"):
self.visual = AscendQwen2VisionTransformer(
self.config.vision_config,
norm_eps=getattr(self.config, "rms_norm_eps", 1e-6),
quant_config=self._maybe_ignore_quant_config(
vllm_config.quant_config),
prefix=maybe_prefix(prefix, "visual"),
)
else:
self.visual = AscendQwen2VisionTransformer(
self.config.vision_config,
norm_eps=getattr(self.config, "rms_norm_eps", 1e-6),
quant_config=vllm_config.quant_config,
prefix=maybe_prefix(prefix, "visual"),
)

View File

@@ -1,156 +0,0 @@
from collections.abc import Iterable
from typing import Optional, Union
import torch
from torch import nn
from transformers import Qwen3Config
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_pp_group
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.model_executor.models.interfaces import SupportsLoRA, SupportsPP
from vllm.model_executor.models.qwen2 import Qwen2Model
from vllm.model_executor.models.qwen3 import Qwen3DecoderLayer
from vllm.model_executor.models.utils import (AutoWeightsLoader,
PPMissingLayer, maybe_prefix)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors
from vllm_ascend.ops.layernorm import AddRMSNormW8A8Quant
class CustomQwen3DecoderLayer(Qwen3DecoderLayer):
def __init__(
self,
config: Qwen3Config,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
super().__init__(config=config,
cache_config=cache_config,
quant_config=quant_config,
prefix=prefix)
if quant_config is None:
return
from vllm_ascend.quantization.quant_config import AscendQuantConfig
from vllm_ascend.quantization.w8a8 import AscendW8A8LinearMethod
assert isinstance(quant_config, AscendQuantConfig), \
"Expected quant_config to be an instance of AscendQuantConfig"
if isinstance(self.self_attn.qkv_proj.quant_method.quant_method,
AscendW8A8LinearMethod):
self.input_layernorm = AddRMSNormW8A8Quant(
config.hidden_size,
layer=self.self_attn.qkv_proj,
eps=config.rms_norm_eps)
if isinstance(self.mlp.gate_up_proj.quant_method.quant_method,
AscendW8A8LinearMethod):
self.post_attention_layernorm = AddRMSNormW8A8Quant(
config.hidden_size,
layer=self.mlp.gate_up_proj,
eps=config.rms_norm_eps)
ALL_DECODER_LAYER_TYPES = {
"attention": CustomQwen3DecoderLayer,
}
@support_torch_compile(
dynamic_arg_dims={
"input_ids": 0,
# positions is of shape (3, seq_len) if mrope is enabled for qwen2-vl,
# otherwise (seq_len, ).
"positions": -1,
"intermediate_tensors": 0,
"inputs_embeds": 0,
})
class CustomQwen3Model(Qwen2Model):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__(vllm_config=vllm_config,
prefix=prefix,
decoder_layer_type=CustomQwen3DecoderLayer)
class CustomQwen3ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
# add `CustomQwen3Model` to init self.model
packed_modules_mapping = {
"qkv_proj": [
"q_proj",
"k_proj",
"v_proj",
],
"gate_up_proj": [
"gate_proj",
"up_proj",
],
}
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
quant_config = vllm_config.quant_config
lora_config = vllm_config.lora_config
self.config = config
self.lora_config = lora_config
self.quant_config = quant_config
self.model = CustomQwen3Model(vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "model"))
if get_pp_group().is_last_rank:
if config.tie_word_embeddings:
self.lm_head = self.model.embed_tokens
else:
self.lm_head = ParallelLMHead(config.vocab_size,
config.hidden_size,
quant_config=quant_config,
prefix=maybe_prefix(
prefix, "lm_head"))
else:
self.lm_head = PPMissingLayer()
self.logits_processor = LogitsProcessor(config.vocab_size)
self.make_empty_intermediate_tensors = (
self.model.make_empty_intermediate_tensors)
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.get_input_embeddings(input_ids)
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]:
hidden_states = self.model(input_ids, positions, intermediate_tensors,
inputs_embeds)
return hidden_states
def compute_logits(
self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]:
logits = self.logits_processor(self.lm_head, hidden_states,
sampling_metadata)
return logits
def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
loader = AutoWeightsLoader(
self,
skip_prefixes=(["lm_head."]
if self.config.tie_word_embeddings else None),
)
return loader.load_weights(weights)

View File

@@ -17,14 +17,14 @@
# Adapted from vllm/model_executor/models/qwen3_moe.py
# This file is a part of the vllm-ascend project.
from typing import Optional, Union
from typing import Optional
import torch
from torch import nn
from transformers import PretrainedConfig
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, CompilationLevel, VllmConfig
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.distributed.parallel_state import (get_dp_group, get_ep_group,
get_tp_group)
from vllm.forward_context import get_forward_context
@@ -45,11 +45,8 @@ from vllm.model_executor.models.qwen3_moe import (Qwen3MoeAttention,
from vllm.model_executor.models.utils import (
PPMissingLayer, extract_layer_index,
make_empty_intermediate_tensors_factory, make_layers, maybe_prefix)
from vllm.sequence import IntermediateTensors
from vllm_ascend.ops.fused_moe import AscendFusedMoE
from vllm_ascend.ops.sequence_parallel import (MetadataForPadding,
init_metadata_for_sp)
from vllm_ascend.utils import vllm_version_is
@@ -101,7 +98,6 @@ class CustomSparseMoeBlock(Qwen3MoeSparseMoeBlock):
self,
hidden_states,
attn_metadata=None,
_metadata_for_padding: Optional[MetadataForPadding] = None,
):
if attn_metadata is None:
attn_metadata = get_forward_context().attn_metadata
@@ -120,7 +116,6 @@ class CustomSparseMoeBlock(Qwen3MoeSparseMoeBlock):
top_k=self.top_k,
enable_force_load_balance=enable_force_load_balance,
shared_experts=None,
_metadata_for_padding=_metadata_for_padding,
)
return hidden_states
@@ -175,9 +170,14 @@ class CustomQwen3MoeDecoderLayer(Qwen3MoeDecoderLayer):
quant_config=quant_config,
prefix=f"{prefix}.mlp")
else:
self.mlp = Qwen3MoeSparseMoeBlock(config=config,
quant_config=quant_config,
prefix=f"{prefix}.mlp")
if vllm_version_is("0.10.2"):
self.mlp = Qwen3MoeSparseMoeBlock(
config=config,
quant_config=quant_config,
prefix=f"{prefix}.mlp")
else:
self.mlp = Qwen3MoeSparseMoeBlock(vllm_config=vllm_config,
prefix=f"{prefix}.mlp")
else:
self.mlp = Qwen3MoeMLP(hidden_size=config.hidden_size,
intermediate_size=config.intermediate_size,
@@ -189,60 +189,6 @@ class CustomQwen3MoeDecoderLayer(Qwen3MoeDecoderLayer):
self.post_attention_layernorm = RMSNorm(config.hidden_size,
eps=config.rms_norm_eps)
self.enable_sequence_parallelism = (
vllm_config.compilation_config.pass_config.
enable_sequence_parallelism if vllm_config is not None else False)
def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
residual: Optional[torch.Tensor],
_metadata_for_padding: Optional[MetadataForPadding] = None,
) -> torch.Tensor:
# To prevent precision issues during the decoder phase when only prefilling enables SP
if not self.enable_sequence_parallelism:
self.self_attn.o_proj.reduce_results = True
else:
self.self_attn.o_proj.reduce_results = not _metadata_for_padding.not_dummy_and_is_prefill if _metadata_for_padding is not None else True
# Self Attention
if residual is None:
residual = hidden_states
if _metadata_for_padding and _metadata_for_padding.not_dummy_and_is_prefill:
residual = _metadata_for_padding.padding_slice(residual)
hidden_states = self.input_layernorm(hidden_states)
else:
hidden_states, residual = self.input_layernorm(
hidden_states, residual)
if _metadata_for_padding and _metadata_for_padding.not_dummy_and_is_prefill:
hidden_states = _metadata_for_padding.allgather_unpadding_aligned(
hidden_states)
hidden_states = self.self_attn(
positions=positions,
hidden_states=hidden_states,
)
if _metadata_for_padding and _metadata_for_padding.not_dummy_and_is_prefill:
hidden_states = _metadata_for_padding.padding_aligned_reduce_scatter(
hidden_states)
# Fully Connected
hidden_states, residual = self.post_attention_layernorm(
hidden_states, residual)
if not self.use_aclgraph:
hidden_states = self.mlp(
hidden_states, _metadata_for_padding=_metadata_for_padding)
else:
hidden_states = self.mlp(hidden_states)
return hidden_states, residual
@support_torch_compile
class CustomQwen3MoeModel(Qwen3MoeModel):
@@ -254,11 +200,8 @@ class CustomQwen3MoeModel(Qwen3MoeModel):
quant_config = vllm_config.quant_config
parallel_config = vllm_config.parallel_config
if vllm_version_is("0.10.1.1") or vllm_version_is("0.10.1"):
self.num_redundant_experts = parallel_config.num_redundant_experts
else:
eplb_config = parallel_config.eplb_config
self.num_redundant_experts = eplb_config.num_redundant_experts
eplb_config = parallel_config.eplb_config
self.num_redundant_experts = eplb_config.num_redundant_experts
self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size
self.config = config
@@ -281,60 +224,8 @@ class CustomQwen3MoeModel(Qwen3MoeModel):
make_empty_intermediate_tensors_factory(
["hidden_states", "residual"], config.hidden_size))
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
_metadata_for_padding: Optional[MetadataForPadding] = None,
) -> Union[torch.Tensor, IntermediateTensors]:
if get_pp_group().is_first_rank:
if inputs_embeds is not None:
hidden_states = inputs_embeds
else:
hidden_states = self.get_input_embeddings(input_ids)
residual = None
else:
assert intermediate_tensors is not None
hidden_states = intermediate_tensors["hidden_states"]
residual = intermediate_tensors["residual"]
for i in range(self.start_layer, self.end_layer):
layer = self.layers[i]
hidden_states, residual = layer(
positions,
hidden_states,
residual,
_metadata_for_padding=_metadata_for_padding)
if not get_pp_group().is_last_rank:
return IntermediateTensors({
"hidden_states": hidden_states,
"residual": residual
})
hidden_states, _ = self.norm(hidden_states, residual)
if _metadata_for_padding and _metadata_for_padding.not_dummy_and_is_prefill:
hidden_states = _metadata_for_padding.allgather_unpadding_aligned(
hidden_states)
return hidden_states
class CustomQwen3MoeForCausalLM(Qwen3MoeForCausalLM):
packed_modules_mapping = {
"qkv_proj": [
"q_proj",
"k_proj",
"v_proj",
],
"gate_up_proj": [
"gate_proj",
"up_proj",
],
"experts":
["experts.0.gate_proj", "experts.0.up_proj", "experts.0.down_proj"],
}
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
nn.Module.__init__(self)
@@ -357,7 +248,6 @@ class CustomQwen3MoeForCausalLM(Qwen3MoeForCausalLM):
self.make_empty_intermediate_tensors = (
self.model.make_empty_intermediate_tensors)
self.enable_sequence_parallelism = vllm_config.compilation_config.pass_config.enable_sequence_parallelism
# Set MoE hyperparameters
self.expert_weights: list[torch.Tensor] = []
@@ -378,16 +268,3 @@ class CustomQwen3MoeForCausalLM(Qwen3MoeForCausalLM):
self.num_moe_layers = len(self.moe_layers)
self.num_expert_groups = 1
self.num_shared_experts = 0
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]:
_metadata_for_padding = init_metadata_for_sp(
input_ids, self.enable_sequence_parallelism)
hidden_states = self.model(input_ids, positions, intermediate_tensors,
inputs_embeds, _metadata_for_padding)
return hidden_states

View File

@@ -0,0 +1,676 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# mypy: ignore-errors
"""Inference-only Qwen3Next model."""
from collections.abc import Iterable
from typing import Optional
import torch
from einops import rearrange
from torch import nn
from transformers.activations import ACT2FN
from vllm import envs
from vllm.attention import AttentionBackend, AttentionMetadata
from vllm.compilation.decorators import support_torch_compile
from vllm.config import (CacheConfig, ModelConfig, SpeculativeConfig,
VllmConfig, get_current_vllm_config)
from vllm.distributed import (divide, get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size)
from vllm.forward_context import get_forward_context
from vllm.model_executor.layers.fla.ops import RMSNormGated
from vllm.model_executor.layers.fla.ops.chunk import chunk_gated_delta_rule
from vllm.model_executor.layers.fla.ops.fused_recurrent import \
fused_recurrent_gated_delta_rule
from vllm.model_executor.layers.fused_moe import FusedMoE
# yapf conflicts with isort for this block
# yapf: disable
from vllm.model_executor.layers.layernorm import \
GemmaRMSNorm as Qwen3NextRMSNorm
# yapf: enable
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
MergedColumnParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.mamba.abstract import MambaBase
from vllm.model_executor.layers.mamba.mamba_mixer2 import \
mamba_v2_sharded_weight_loader
from vllm.model_executor.layers.mamba.mamba_utils import (
MambaStateDtypeCalculator, MambaStateShapeCalculator)
from vllm.model_executor.layers.mamba.ops.causal_conv1d import (
causal_conv1d_fn, causal_conv1d_update)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.vocab_parallel_embedding import (
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import (
default_weight_loader, sharded_weight_loader)
from vllm.model_executor.models.qwen2_moe import Qwen2MoeMLP as Qwen3NextMLP
from vllm.model_executor.models.utils import (
PPMissingLayer, extract_layer_index, is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers, maybe_prefix)
from vllm.model_executor.utils import set_weight_attrs
from vllm.transformers_utils.configs import Qwen3NextConfig
from vllm.v1.attention.backends.gdn_attn import GDNAttentionMetadata
from vllm.model_executor.models.qwen3_next import ( # isort: skip
Qwen3NextAttention, Qwen3NextDecoderLayer, Qwen3NextForCausalLM,
Qwen3NextGatedDeltaNet, Qwen3NextModel, Qwen3NextSparseMoeBlock,
fused_gdn_gating)
class CustomQwen3NextGatedDeltaNet(Qwen3NextGatedDeltaNet, MambaBase):
@property
def mamba_type(self) -> str:
return "linear_attention"
def get_attn_backend(self) -> type["AttentionBackend"]:
from vllm.v1.attention.backends.gdn_attn import GDNAttentionBackend
return GDNAttentionBackend
def get_state_dtype(self) -> tuple[torch.dtype, torch.dtype]:
return MambaStateDtypeCalculator.gated_delta_net_state_dtype(
self.model_config.dtype, self.cache_config.mamba_cache_dtype)
def get_state_shape(self) -> tuple[tuple[int, ...], tuple[int, ...]]:
return MambaStateShapeCalculator.gated_delta_net_state_shape(
self.tp_size, self.num_k_heads, self.num_v_heads, self.head_k_dim,
self.head_v_dim, self.conv_kernel_size, self.num_spec)
def __init__(
self,
config: Qwen3NextConfig,
model_config: Optional[ModelConfig] = None,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
speculative_config: Optional[SpeculativeConfig] = None,
prefix: str = "",
) -> None:
nn.Module.__init__(self)
self.tp_size = get_tensor_model_parallel_world_size()
self.tp_rank = get_tensor_model_parallel_rank()
self.hidden_size = config.hidden_size
self.num_v_heads = config.linear_num_value_heads
self.num_k_heads = config.linear_num_key_heads
self.head_k_dim = config.linear_key_head_dim
self.head_v_dim = config.linear_value_head_dim
self.key_dim = self.head_k_dim * self.num_k_heads
self.value_dim = self.head_v_dim * self.num_v_heads
self.conv_kernel_size = config.linear_conv_kernel_dim
self.layer_idx = extract_layer_index(prefix)
self.activation = config.hidden_act
self.act = ACT2FN[config.hidden_act]
self.layer_norm_epsilon = config.rms_norm_eps
self.prefix = prefix
self.config = config
self.model_config = model_config
self.cache_config = cache_config
self.quant_config = quant_config
self.speculative_config = speculative_config
self.num_spec = (self.speculative_config.num_speculative_tokens
if self.speculative_config else 0)
# QKV
self.conv_dim = self.key_dim * 2 + self.value_dim
self.conv1d = ColumnParallelLinear(
input_size=self.conv_kernel_size,
output_size=self.conv_dim,
bias=False,
prefix=f"{prefix}.conv1d",
)
self.conv1d.weight.data = self.conv1d.weight.data.unsqueeze(1)
# projection of the input hidden states
self.projection_size_qkvz = self.key_dim * 2 + self.value_dim * 2
self.projection_size_ba = self.num_v_heads * 2
self.in_proj = MergedColumnParallelLinear(
input_size=self.hidden_size,
output_sizes=[self.projection_size_qkvz, self.projection_size_ba],
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.in_proj",
)
query_key_settings = (self.key_dim, 0, False)
value_settings = (self.value_dim, 0, False)
delattr(self.conv1d.weight, "weight_loader")
set_weight_attrs(
self.conv1d.weight, {
"weight_loader":
mamba_v2_sharded_weight_loader([
query_key_settings,
query_key_settings,
value_settings,
], self.tp_size, self.tp_rank)
})
# selective projection used to make dt, B and C input dependent
# time step projection (discretization)
# instantiate once and copy inv_dt in init_weights of PretrainedModel
self.dt_bias = nn.Parameter(
torch.ones(self.num_v_heads // self.tp_size), )
self.A_log = nn.Parameter(
torch.empty(
divide(self.num_v_heads, self.tp_size),
dtype=torch.float32,
))
set_weight_attrs(self.A_log,
{"weight_loader": sharded_weight_loader(0)})
set_weight_attrs(self.dt_bias,
{"weight_loader": sharded_weight_loader(0)})
self.norm = RMSNormGated(
self.head_v_dim,
eps=self.layer_norm_epsilon,
norm_before_gate=True,
device="npu",
)
self.out_proj = RowParallelLinear(self.value_dim,
self.hidden_size,
bias=False,
input_is_parallel=True,
quant_config=quant_config,
prefix=f"{prefix}.out_proj")
compilation_config = get_current_vllm_config().compilation_config
if prefix in compilation_config.static_forward_context:
raise ValueError(f"Duplicate layer name: {prefix}")
compilation_config.static_forward_context[prefix] = self
def _forward(
self,
hidden_states: torch.Tensor,
output: torch.Tensor,
):
forward_context = get_forward_context()
attn_metadata: AttentionMetadata = forward_context.attn_metadata
if attn_metadata is None:
# V1 profile run
return
assert isinstance(attn_metadata, dict)
attn_metadata = attn_metadata[self.prefix]
assert isinstance(attn_metadata, GDNAttentionMetadata)
has_initial_state = attn_metadata.has_initial_state
spec_query_start_loc = attn_metadata.spec_query_start_loc
non_spec_query_start_loc = attn_metadata.non_spec_query_start_loc
spec_sequence_masks = attn_metadata.spec_sequence_masks
spec_token_masks = attn_metadata.spec_token_masks
spec_state_indices_tensor = attn_metadata.spec_state_indices_tensor # noqa: E501
non_spec_state_indices_tensor = attn_metadata.non_spec_state_indices_tensor # noqa: E501
self_kv_cache = self.kv_cache[forward_context.virtual_engine]
conv_state = self_kv_cache[0].transpose(-1, -2)
ssm_state = self_kv_cache[1]
num_actual_tokens = (attn_metadata.num_prefill_tokens +
attn_metadata.num_decode_tokens +
attn_metadata.num_spec_decode_tokens)
num_accepted_tokens = attn_metadata.num_accepted_tokens
# 1. Set up dimensions for reshapes later
projected_states, _ = self.in_proj(hidden_states[:num_actual_tokens])
if spec_token_masks is not None:
spec_token_masks = spec_token_masks[:num_actual_tokens]
projected_states_qkvz, projected_states_ba = torch.split(
projected_states,
[
self.projection_size_qkvz // self.tp_size,
self.projection_size_ba // self.tp_size
],
dim=-1,
)
query, key, value, z, b, a = self.fix_query_key_value_ordering(
projected_states_qkvz, projected_states_ba)
query, key, value = map(lambda x: rearrange(x, 'l p d -> l (p d)'),
(query, key, value))
mixed_qkv = torch.cat((query, key, value), dim=-1)
# 2. Convolution sequence transformation
conv_weights = self.conv1d.weight.view(self.conv1d.weight.size(0),
self.conv1d.weight.size(2))
if spec_sequence_masks is not None:
if (attn_metadata.num_prefills == 0
and attn_metadata.num_decodes == 0):
mixed_qkv_spec = mixed_qkv
mixed_qkv_non_spec = None
else:
mixed_qkv_spec = mixed_qkv[spec_token_masks]
mixed_qkv_non_spec = mixed_qkv[~spec_token_masks]
else:
mixed_qkv_spec = None
mixed_qkv_non_spec = mixed_qkv
# 2.2: process the remaining part
if attn_metadata.num_prefills > 0:
# - "cache_indices" updates the conv_state cache in positions
# pointed to by "mamba_cache_params.state_indices_tensor"
mixed_qkv_non_spec = causal_conv1d_fn(
mixed_qkv_non_spec.transpose(0, 1),
conv_weights,
self.conv1d.bias,
activation=self.activation,
conv_states=conv_state,
has_initial_state=has_initial_state,
cache_indices=non_spec_state_indices_tensor,
query_start_loc=non_spec_query_start_loc,
).transpose(0, 1)
elif attn_metadata.num_decodes > 0:
mixed_qkv_non_spec = causal_conv1d_update(
mixed_qkv_non_spec,
conv_state,
conv_weights,
self.conv1d.bias,
self.activation,
conv_state_indices=non_spec_state_indices_tensor[:attn_metadata
.num_decodes],
# validate_data=True,
)
else:
mixed_qkv_non_spec = None
query_spec, key_spec, value_spec = self.rearrange_mixed_qkv(
mixed_qkv_spec)
query_non_spec, key_non_spec, value_non_spec = self.rearrange_mixed_qkv(
mixed_qkv_non_spec)
beta = b.sigmoid()
g = fused_gdn_gating(self.A_log, a, self.dt_bias)
g, beta = map(lambda x: rearrange(x, 'l d -> 1 l d'), (g, beta))
if spec_sequence_masks is not None:
if (attn_metadata.num_prefills == 0
and attn_metadata.num_decodes == 0):
g_spec = g
beta_spec = beta
g_non_spec = None
beta_non_spec = None
else:
g_spec = g[:, spec_token_masks]
beta_spec = beta[:, spec_token_masks]
g_non_spec = g[:, ~spec_token_masks]
beta_non_spec = beta[:, ~spec_token_masks]
else:
g_spec = None
beta_spec = None
g_non_spec = g
beta_non_spec = beta
# 3. Recurrent attention
# 3.1: process the mutlti-query part
if spec_sequence_masks is not None:
core_attn_out_spec, last_recurrent_state = (
fused_recurrent_gated_delta_rule(
q=query_spec,
k=key_spec,
v=value_spec,
g=g_spec,
beta=beta_spec,
initial_state=ssm_state,
inplace_final_state=True,
cu_seqlens=spec_query_start_loc[:attn_metadata.
num_spec_decodes + 1],
ssm_state_indices=spec_state_indices_tensor,
num_accepted_tokens=num_accepted_tokens,
use_qk_l2norm_in_kernel=True,
))
else:
core_attn_out_spec, last_recurrent_state = None, None
# 3.2: process the remaining part
if attn_metadata.num_prefills > 0:
initial_state = ssm_state[
non_spec_state_indices_tensor].contiguous()
initial_state[~has_initial_state, ...] = 0
batch_size = initial_state.shape[0]
core_attn_out = []
last_recurrent_state = []
for b_idx in range(batch_size):
start, end = non_spec_query_start_loc[
b_idx], non_spec_query_start_loc[b_idx + 1]
cur_q = query_non_spec[:, start:end, ...]
cur_k = key_non_spec[:, start:end, ...]
cur_v = value_non_spec[:, start:end, ...]
cur_g = g_non_spec[:, start:end, ...]
cur_b = beta_non_spec[:, start:end, ...]
cur_state = initial_state[b_idx].unsqueeze(0)
(
cur_core_attn_out_non_spec,
cur_last_recurrent_state,
) = chunk_gated_delta_rule(
query=cur_q,
key=cur_k,
value=cur_v,
g=cur_g,
beta=cur_b,
initial_state=cur_state,
output_final_state=True,
use_qk_l2norm_in_kernel=True,
)
core_attn_out.append(cur_core_attn_out_non_spec)
last_recurrent_state.append(cur_last_recurrent_state)
tar_dtype = core_attn_out[0].dtype
tar_device = core_attn_out[0].device
tar_shape = list(core_attn_out[0].shape)
tar_shape[1] = non_spec_query_start_loc[-1]
core_attn_out_non_spec = torch.empty(tar_shape,
dtype=tar_dtype,
device=tar_device)
for b_idx in range(batch_size):
cur_core_attn_out = core_attn_out[b_idx]
start, end = non_spec_query_start_loc[
b_idx], non_spec_query_start_loc[b_idx + 1]
core_attn_out_non_spec[:, start:end, ...] = cur_core_attn_out
last_recurrent_state = torch.cat(last_recurrent_state, dim=0)
# Init cache
ssm_state[non_spec_state_indices_tensor] = last_recurrent_state.to(
ssm_state.dtype)
elif attn_metadata.num_decodes > 0:
core_attn_out_non_spec, last_recurrent_state = (
fused_recurrent_gated_delta_rule(
q=query_non_spec,
k=key_non_spec,
v=value_non_spec,
g=g_non_spec,
beta=beta_non_spec,
initial_state=ssm_state,
inplace_final_state=True,
cu_seqlens=non_spec_query_start_loc[:attn_metadata.
num_decodes + 1],
ssm_state_indices=non_spec_state_indices_tensor,
use_qk_l2norm_in_kernel=True,
))
else:
core_attn_out_non_spec, last_recurrent_state = None, None
# Merge core attention output
if (spec_sequence_masks is not None
and core_attn_out_non_spec is not None):
core_attn_out = torch.empty(
(1, num_actual_tokens, *core_attn_out_spec.shape[2:]),
dtype=core_attn_out_non_spec.dtype,
device=core_attn_out_non_spec.device,
)
core_attn_out[:, spec_token_masks] = core_attn_out_spec
core_attn_out[:, ~spec_token_masks] = core_attn_out_non_spec
elif spec_sequence_masks is not None:
core_attn_out = core_attn_out_spec
else:
core_attn_out = core_attn_out_non_spec
z_shape_og = z.shape
# reshape input data into 2D tensor
core_attn_out = core_attn_out.reshape(-1, core_attn_out.shape[-1])
z = z.reshape(-1, z.shape[-1])
core_attn_out = self.norm(core_attn_out, z)
core_attn_out = core_attn_out.reshape(z_shape_og)
core_attn_out = rearrange(core_attn_out, '... h d -> ... (h d)')
output[:num_actual_tokens], _ = self.out_proj(core_attn_out)
class CustomQwen3NextDecoderLayer(Qwen3NextDecoderLayer):
def __init__(
self,
vllm_config: VllmConfig,
layer_type: str,
prefix: str = "",
) -> None:
nn.Module.__init__(self)
config = vllm_config.model_config.hf_config
model_config = vllm_config.model_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
speculative_config = vllm_config.speculative_config
self.layer_type = layer_type
self.layer_idx = extract_layer_index(prefix)
if self.layer_type == "linear_attention":
self.linear_attn = CustomQwen3NextGatedDeltaNet(
config,
model_config=model_config,
cache_config=cache_config,
quant_config=quant_config,
speculative_config=speculative_config,
prefix=f'{prefix}.linear_attn')
elif self.layer_type == "full_attention":
self.self_attn = Qwen3NextAttention(
config,
model_config=model_config,
cache_config=cache_config,
quant_config=quant_config,
prefix=f'{prefix}.self_attn',
)
else:
raise ValueError(f"Invalid layer_type {self.layer_type}")
mlp_only_layers = ([] if not hasattr(config, "mlp_only_layers") else
config.mlp_only_layers)
if (self.layer_idx not in mlp_only_layers) and (
config.num_experts > 0 and
(self.layer_idx + 1) % config.decoder_sparse_step == 0):
self.mlp = Qwen3NextSparseMoeBlock(vllm_config=vllm_config,
prefix=f"{prefix}.mlp")
else:
self.mlp = Qwen3NextMLP(
hidden_size=config.hidden_size,
intermediate_size=config.intermediate_size,
hidden_act=config.hidden_act,
quant_config=quant_config,
)
self.input_layernorm = Qwen3NextRMSNorm(config.hidden_size,
eps=config.rms_norm_eps)
self.post_attention_layernorm = Qwen3NextRMSNorm(
config.hidden_size, eps=config.rms_norm_eps)
self.layer_scale = getattr(config, "layer_scale", False)
if self.layer_scale:
self.attn_layer_scale = torch.nn.Parameter(
torch.zeros(
1,
1,
config.hidden_size,
dtype=config.torch_dtype,
), )
self.ffn_layer_scale = torch.nn.Parameter(
torch.zeros(
1,
1,
config.hidden_size,
dtype=config.torch_dtype,
), )
@support_torch_compile
class CustomQwen3NextModel(Qwen3NextModel):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
nn.Module.__init__(self)
config: Qwen3NextConfig = vllm_config.model_config.hf_config
parallel_config = vllm_config.parallel_config
lora_config = vllm_config.lora_config
eplb_config = parallel_config.eplb_config
self.num_redundant_experts = eplb_config.num_redundant_experts
self.config = config
lora_vocab = ((lora_config.lora_extra_vocab_size *
(lora_config.max_loras or 1)) if lora_config else 0)
self.vocab_size = config.vocab_size + lora_vocab
self.embed_tokens = VocabParallelEmbedding(
self.vocab_size,
config.hidden_size,
org_num_embeddings=config.vocab_size,
)
def get_layer(prefix: str):
return CustomQwen3NextDecoderLayer(
vllm_config,
layer_type=config.layer_types[extract_layer_index(prefix)],
prefix=prefix,
)
self.start_layer, self.end_layer, self.layers = make_layers(
config.num_hidden_layers, get_layer, prefix=f"{prefix}.layers")
self.make_empty_intermediate_tensors = (
make_empty_intermediate_tensors_factory(
["hidden_states", "residual"], config.hidden_size))
self.norm = Qwen3NextRMSNorm(config.hidden_size,
eps=config.rms_norm_eps)
def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"),
("qkv_proj", "k_proj", "k"),
("qkv_proj", "v_proj", "v"),
("gate_up_proj", "gate_proj", 0),
("gate_up_proj", "up_proj", 1),
("in_proj", "in_proj_qkvz", 0),
("in_proj", "in_proj_ba", 1),
]
params_dict = dict(self.named_parameters())
loaded_params: set[str] = set()
expert_params_mapping = self.get_expert_mapping()
for name, loaded_weight in weights:
if "rotary_emb.inv_freq" in name:
continue
if name.startswith("mtp."):
continue
for param_name, weight_name, shard_id in stacked_params_mapping:
if weight_name not in name:
continue
if "mlp.experts" in name:
continue
name = name.replace(weight_name, param_name)
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
# Skip layers on other devices.
if is_pp_missing_parameter(name, self):
continue
# name = apply_attn_prefix(name, params_dict)
if name not in params_dict:
continue
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
break
else:
for mapping in expert_params_mapping:
param_name, weight_name, expert_id, shard_id = mapping
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)
# Skip layers on other devices.
if is_pp_missing_parameter(name, self):
continue
# Skip loading extra bias for GPTQ models.
if ((name.endswith(".bias") or name.endswith("_bias"))
and name not in params_dict):
continue
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param,
loaded_weight,
name,
shard_id=shard_id,
expert_id=expert_id)
break
else:
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
loaded_params.add(name)
return loaded_params
class CustomQwen3NextForCausalLM(Qwen3NextForCausalLM):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
nn.Module.__init__(self)
config = vllm_config.model_config.hf_config
self.vllm_config = vllm_config
self.model_config = vllm_config.model_config
cache_config = vllm_config.cache_config
lora_config = vllm_config.lora_config
scheduler_config = vllm_config.scheduler_config
assert not cache_config.enable_prefix_caching, \
"Qwen3Next currently does not support prefix caching"
assert envs.VLLM_USE_V1, "Qwen3Next requires VLLM_USE_V1"
self.quant_config = vllm_config.quant_config
self.config = config
self.scheduler_config = scheduler_config
self.model = CustomQwen3NextModel(vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "model"))
self.unpadded_vocab_size = config.vocab_size
if lora_config:
self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
self.lm_head = ParallelLMHead(
self.unpadded_vocab_size,
config.hidden_size,
org_num_embeddings=config.vocab_size,
padding_size=DEFAULT_VOCAB_PADDING_SIZE
# We need bigger padding if using lora for kernel
# compatibility
if not lora_config else lora_config.lora_vocab_padding_size,
)
self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
config.vocab_size)
self.make_empty_intermediate_tensors = (
self.model.make_empty_intermediate_tensors)
# Set MoE hyperparameters
self.expert_weights = []
self.moe_layers: list[FusedMoE] = []
example_layer = None
for layer in self.model.layers:
if isinstance(layer, PPMissingLayer):
continue
assert isinstance(layer, Qwen3NextDecoderLayer)
if isinstance(layer.mlp, Qwen3NextSparseMoeBlock):
example_layer = layer.mlp
self.moe_layers.append(layer.mlp.experts)
if example_layer is None:
raise RuntimeError("No Qwen3Next layer found in the model.layers.")
self.num_moe_layers = len(self.moe_layers)
self.num_expert_groups = 1
self.num_shared_experts = 0
self.num_logical_experts = example_layer.n_logical_experts
self.num_physical_experts = example_layer.n_physical_experts
self.num_local_physical_experts = example_layer.n_local_physical_experts
self.num_routed_experts = example_layer.n_routed_experts
self.num_redundant_experts = example_layer.n_redundant_experts