from collections.abc import Iterable
from typing import Optional, Union

import torch
from torch import nn
from transformers import Qwen3Config
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_pp_group
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.model_executor.models.interfaces import SupportsLoRA, SupportsPP
from vllm.model_executor.models.qwen2 import Qwen2Model
from vllm.model_executor.models.qwen3 import Qwen3DecoderLayer
from vllm.model_executor.models.utils import (AutoWeightsLoader,
                                              PPMissingLayer, maybe_prefix)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors

from vllm_ascend.ops.layernorm import AddRMSNormW8A8Quant


class CustomQwen3DecoderLayer(Qwen3DecoderLayer):
    # Qwen3 decoder layer that, when Ascend W8A8 quantization is in use,
    # replaces the plain RMSNorm layers with AddRMSNormW8A8Quant so the
    # residual add, normalization and quantization are fused into one op.

    def __init__(
        self,
        config: Qwen3Config,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__(config=config,
                         cache_config=cache_config,
                         quant_config=quant_config,
                         prefix=prefix)
        if quant_config is None:
            return

        from vllm_ascend.quantization.quant_config import AscendQuantConfig
        from vllm_ascend.quantization.w8a8 import AscendW8A8LinearMethod

        assert isinstance(quant_config, AscendQuantConfig), \
            "Expected quant_config to be an instance of AscendQuantConfig"

        if isinstance(self.self_attn.qkv_proj.quant_method.quant_method,
                      AscendW8A8LinearMethod):
            self.input_layernorm = AddRMSNormW8A8Quant(
                config.hidden_size,
                layer=self.self_attn.qkv_proj,
                eps=config.rms_norm_eps)
        if isinstance(self.mlp.gate_up_proj.quant_method.quant_method,
                      AscendW8A8LinearMethod):
            self.post_attention_layernorm = AddRMSNormW8A8Quant(
                config.hidden_size,
                layer=self.mlp.gate_up_proj,
                eps=config.rms_norm_eps)


ALL_DECODER_LAYER_TYPES = {
    "attention": CustomQwen3DecoderLayer,
}


@support_torch_compile(
    dynamic_arg_dims={
        "input_ids": 0,
        # positions is of shape (3, seq_len) if mrope is enabled for qwen2-vl,
        # otherwise (seq_len, ).
        "positions": -1,
        "intermediate_tensors": 0,
        "inputs_embeds": 0,
    })
class CustomQwen3Model(Qwen2Model):
    # Qwen3 shares the Qwen2 model structure; only the decoder layer type
    # is swapped for the custom Ascend-aware variant above.

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__(vllm_config=vllm_config,
                         prefix=prefix,
                         decoder_layer_type=CustomQwen3DecoderLayer)


class CustomQwen3ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
    # add `CustomQwen3Model` to init self.model
    packed_modules_mapping = {
        "qkv_proj": [
            "q_proj",
            "k_proj",
            "v_proj",
        ],
        "gate_up_proj": [
            "gate_proj",
            "up_proj",
        ],
    }

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        lora_config = vllm_config.lora_config
        self.config = config
        self.lora_config = lora_config
        self.quant_config = quant_config
        self.model = CustomQwen3Model(vllm_config=vllm_config,
                                      prefix=maybe_prefix(prefix, "model"))

        if get_pp_group().is_last_rank:
            if config.tie_word_embeddings:
                self.lm_head = self.model.embed_tokens
            else:
                self.lm_head = ParallelLMHead(config.vocab_size,
                                              config.hidden_size,
                                              quant_config=quant_config,
                                              prefix=maybe_prefix(
                                                  prefix, "lm_head"))
        else:
            self.lm_head = PPMissingLayer()

        self.logits_processor = LogitsProcessor(config.vocab_size)

        self.make_empty_intermediate_tensors = (
            self.model.make_empty_intermediate_tensors)

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        return self.model.get_input_embeddings(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        hidden_states = self.model(input_ids, positions, intermediate_tensors,
                                   inputs_embeds)
        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        logits = self.logits_processor(self.lm_head, hidden_states,
                                       sampling_metadata)
        return logits

    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
        loader = AutoWeightsLoader(
            self,
            skip_prefixes=(["lm_head."]
                           if self.config.tie_word_embeddings else None),
        )
        return loader.load_weights(weights)