#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# Copyright 2023 The vLLM team.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Adapted from vllm/model_executor/models/qwen3_moe.py
# This file is a part of the vllm-ascend project.

from typing import Optional

from torch import nn
from transformers import PretrainedConfig
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.vocab_parallel_embedding import (
    ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.models.qwen3_moe import (Qwen3MoeAttention,
                                                  Qwen3MoeDecoderLayer,
                                                  Qwen3MoeForCausalLM,
                                                  Qwen3MoeMLP, Qwen3MoeModel)
from vllm.model_executor.models.utils import (
    extract_layer_index, make_empty_intermediate_tensors_factory, make_layers,
    maybe_prefix)

from vllm_ascend.ops.fused_moe import AscendSparseMoeBlock
from vllm_ascend.platform import VllmConfig


class CustomQwen3MoeDecoderLayer(Qwen3MoeDecoderLayer):
    """Qwen3-MoE decoder layer that uses ``AscendSparseMoeBlock`` for MoE.

    The constructor builds the attention block, the feed-forward block and
    the two RMSNorm layers itself instead of delegating to
    ``Qwen3MoeDecoderLayer.__init__``; everything else (e.g. ``forward``) is
    inherited from the parent class.

    Args:
        config: Hugging Face model configuration.
        cache_config: Optional cache configuration forwarded to attention.
        quant_config: Optional quantization configuration forwarded to the
            attention and feed-forward blocks.
        prefix: Dotted module-name prefix of this layer; the layer index is
            extracted from it.
    """

    def __init__(
        self,
        config: PretrainedConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        # Initialise only nn.Module: the parent constructor is bypassed on
        # purpose because all submodules are created below.
        nn.Module.__init__(self)
        self.hidden_size = config.hidden_size

        self.self_attn = Qwen3MoeAttention(
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
            num_kv_heads=config.num_key_value_heads,
            rope_theta=getattr(config, "rope_theta", 10000),
            rope_scaling=getattr(config, "rope_scaling", None),
            max_position_embeddings=getattr(config,
                                            "max_position_embeddings", 8192),
            rms_norm_eps=config.rms_norm_eps,
            qkv_bias=getattr(config, "attention_bias", False),
            head_dim=getattr(config, "head_dim", None),
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.self_attn",
        )

        # Layers listed in `mlp_only_layers` in the config always get the
        # dense MLP; a missing attribute means no layer is forced dense.
        layer_idx = extract_layer_index(prefix)
        dense_only_layers = getattr(config, "mlp_only_layers", [])
        use_sparse_moe = (
            layer_idx not in dense_only_layers and config.num_experts > 0
            and (layer_idx + 1) % config.decoder_sparse_step == 0)

        mlp_prefix = f"{prefix}.mlp"
        if use_sparse_moe:
            self.mlp = AscendSparseMoeBlock(
                config=config,
                quant_config=quant_config,
                prefix=mlp_prefix,
            )
        else:
            self.mlp = Qwen3MoeMLP(
                hidden_size=config.hidden_size,
                intermediate_size=config.intermediate_size,
                hidden_act=config.hidden_act,
                quant_config=quant_config,
                prefix=mlp_prefix,
            )

        self.input_layernorm = RMSNorm(config.hidden_size,
                                       eps=config.rms_norm_eps)
        self.post_attention_layernorm = RMSNorm(config.hidden_size,
                                                eps=config.rms_norm_eps)


@support_torch_compile
class CustomQwen3MoeModel(Qwen3MoeModel):
    """Qwen3-MoE backbone built from ``CustomQwen3MoeDecoderLayer`` layers.

    Args:
        vllm_config: Configuration object providing ``model_config``,
            ``cache_config`` and ``quant_config``.
        prefix: Dotted module-name prefix used for the submodules.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        # Parent constructor is bypassed; submodules are created here.
        nn.Module.__init__(self)

        hf_config = vllm_config.model_config.hf_config
        cache_cfg = vllm_config.cache_config
        quant_cfg = vllm_config.quant_config

        self.padding_idx = hf_config.pad_token_id
        self.vocab_size = hf_config.vocab_size
        self.config = hf_config

        self.embed_tokens = VocabParallelEmbedding(
            hf_config.vocab_size,
            hf_config.hidden_size,
            prefix=f"{prefix}.embed_tokens",
        )

        # The parameter must stay named `prefix` in case `make_layers`
        # passes the per-layer prefix by keyword.
        def build_layer(prefix: str) -> CustomQwen3MoeDecoderLayer:
            return CustomQwen3MoeDecoderLayer(
                config=hf_config,
                cache_config=cache_cfg,
                quant_config=quant_cfg,
                prefix=prefix,
            )

        layer_range = make_layers(
            hf_config.num_hidden_layers,
            build_layer,
            prefix=f"{prefix}.layers",
        )
        self.start_layer, self.end_layer, self.layers = layer_range

        self.norm = RMSNorm(hf_config.hidden_size, eps=hf_config.rms_norm_eps)
        self.make_empty_intermediate_tensors = (
            make_empty_intermediate_tensors_factory(
                ["hidden_states", "residual"], hf_config.hidden_size))


class CustomQwen3MoeForCausalLM(Qwen3MoeForCausalLM):
    """Causal-LM wrapper that wires ``CustomQwen3MoeModel`` to an LM head.

    Args:
        vllm_config: Configuration object providing ``model_config`` and
            ``quant_config``.
        prefix: Dotted module-name prefix; the backbone is registered under
            ``maybe_prefix(prefix, "model")``.
    """

    # Maps each packed module name to the per-projection names it groups.
    packed_modules_mapping = {
        "qkv_proj": ["q_proj", "k_proj", "v_proj"],
        "gate_up_proj": ["gate_proj", "up_proj"],
        "experts": [
            "experts.0.gate_proj",
            "experts.0.up_proj",
            "experts.0.down_proj",
        ],
    }

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        # Parent constructor is bypassed; submodules are created here.
        nn.Module.__init__(self)

        hf_config = vllm_config.model_config.hf_config
        quant_cfg = vllm_config.quant_config
        self.config = hf_config
        self.quant_config = quant_cfg

        self.model = CustomQwen3MoeModel(
            vllm_config=vllm_config,
            prefix=maybe_prefix(prefix, "model"),
        )
        self.lm_head = ParallelLMHead(
            hf_config.vocab_size,
            hf_config.hidden_size,
            quant_config=quant_cfg,
        )
        # With tied embeddings the LM head reuses the input embedding weight.
        if hf_config.tie_word_embeddings:
            self.lm_head.weight = self.model.embed_tokens.weight

        self.logits_processor = LogitsProcessor(hf_config.vocab_size)
        self.make_empty_intermediate_tensors = (
            self.model.make_empty_intermediate_tensors)