# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# Copyright 2026 The Qwen team.
# Copyright 2023 The vLLM team.
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only Qwen3-ASR model."""

from collections.abc import Iterable, Mapping, Sequence
from typing import Any, Literal

import numpy as np
import torch
import torch.nn as nn
from transformers.feature_extraction_utils import BatchFeature
from transformers.models.whisper import WhisperFeatureExtractor

from vllm.config import ModelConfig, SpeechToTextConfig, VllmConfig
from vllm.config.multimodal import BaseDummyOptions
from vllm.inputs.data import PromptType, TokensPrompt
from vllm.logger import init_logger
from vllm.model_executor.models.interfaces import (
    MultiModalEmbeddings,
    SupportsMRoPE,
    SupportsMultiModal,
    SupportsPP,
    SupportsTranscription,
)
from vllm.model_executor.models.module_mapping import MultiModelKeys
from vllm.model_executor.models.qwen3 import Qwen3ForCausalLM
from vllm.model_executor.models.qwen3_omni_moe_thinker import (
    Qwen2_5OmniAudioFeatureInputs,
    Qwen3OmniMoeAudioEncoder,
    Qwen3OmniMoeThinkerMultiModalProcessor,
)
from vllm.model_executor.models.utils import (
    AutoWeightsLoader,
    WeightsMapper,
    _merge_multimodal_embeddings,
    maybe_prefix,
)
from vllm.model_executor.models.whisper import ISO639_1_SUPPORTED_LANGS
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (
    AudioItem,
    ModalityData,
    MultiModalDataDict,
    MultiModalFeatureSpec,
    MultiModalFieldConfig,
    MultiModalKwargsItems,
)
from vllm.multimodal.parse import (
    AudioProcessorItems,
    DictEmbeddingItems,
    ModalityDataItems,
    MultiModalDataItems,
    MultiModalDataParser,
)
from vllm.multimodal.processing import (
    BaseDummyInputsBuilder,
    BaseProcessingInfo,
    PromptReplacement,
    PromptUpdate,
)
from vllm.sequence import IntermediateTensors
from vllm.tokenizers import cached_tokenizer_from_config
from vllm.transformers_utils.configs.qwen3_asr import (
    Qwen3ASRConfig,
    Qwen3ASRThinkerConfig,
)
from vllm.transformers_utils.processor import cached_processor_from_config
from vllm.transformers_utils.processors.qwen3_asr import (
    Qwen3ASRProcessor,
)

logger = init_logger(__name__)

# Tag the model emits between the language prefix and the transcription text.
_ASR_TEXT_TAG = "<asr_text>"


def _get_feat_extract_output_lengths(input_lengths: torch.Tensor):
    input_lengths_leave = input_lengths % 100
    feat_lengths = (input_lengths_leave - 1) // 2 + 1
    output_lengths = (
        ((feat_lengths - 1) // 2 + 1 - 1) // 2 + 1 + (input_lengths // 100) * 13
    )
    return output_lengths


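# Note (illustrative, not from the original source): the helper above converts
# mel-spectrogram frame counts into audio-encoder output lengths. With the
# Whisper feature extractor's 10 ms hop, 100 mel frames cover about one second
# of audio and map to 13 encoder tokens, so roughly 13 tokens are produced per
# second of audio. Worked example for input_lengths = 240:
#   leave = 240 % 100 = 40
#   feat  = (40 - 1) // 2 + 1 = 20
#   out   = ((20 - 1) // 2 + 1 - 1) // 2 + 1 + (240 // 100) * 13 = 5 + 26 = 31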


class Qwen3ASRProcessingInfo(BaseProcessingInfo):
    def get_hf_config(self):
        return self.ctx.get_hf_config(Qwen3ASRConfig).thinker_config

    def get_hf_processor(self, **kwargs: object) -> Qwen3ASRProcessor:
        processor = self.ctx.get_hf_processor(
            Qwen3ASRProcessor,
            use_fast=kwargs.pop("use_fast", True),
            **kwargs,
        )
        if not hasattr(processor, "audio_token"):
            processor.audio_token = "<|audio_pad|>"
        return processor

    def get_feature_extractor(self, **kwargs: object) -> WhisperFeatureExtractor:
        hf_processor = self.get_hf_processor(**kwargs)
        feature_extractor = hf_processor.feature_extractor
        assert isinstance(feature_extractor, WhisperFeatureExtractor)
        return feature_extractor

    def get_supported_mm_limits(self) -> Mapping[str, int | None]:
        return {"audio": None}

    def get_data_parser(self) -> MultiModalDataParser:
        feature_extractor = self.get_feature_extractor()
        return Qwen3ASRMultiModalDataParser(
            target_sr=feature_extractor.sampling_rate,
            expected_hidden_size=self._get_expected_hidden_size(),
        )


class Qwen3ASRDummyInputsBuilder(BaseDummyInputsBuilder[Qwen3ASRProcessingInfo]):
    def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
        num_audios = mm_counts.get("audio", 0)
        hf_processor = self.info.get_hf_processor()
        audio_token = hf_processor.audio_token
        return audio_token * num_audios

    def get_dummy_mm_data(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
        mm_options: Mapping[str, BaseDummyOptions],
    ) -> MultiModalDataDict:
        num_audios = mm_counts.get("audio", 0)
        feature_extractor = self.info.get_feature_extractor()

        target_audio_length = (
            min(
                feature_extractor.chunk_length,
                30,
            )
            * feature_extractor.sampling_rate
        )

        audio_overrides = mm_options.get("audio")

        return {
            "audio": self._get_dummy_audios(
                length=target_audio_length,
                num_audios=num_audios,
                overrides=audio_overrides,
            ),
        }


def _qwen3asr_field_config(hf_inputs: Mapping[str, torch.Tensor]):
    audio_feature_lengths = hf_inputs.get("audio_feature_lengths", torch.empty((0,)))
    return dict(
        input_audio_features=MultiModalFieldConfig.flat_from_sizes(
            "audio", audio_feature_lengths, dim=1
        ),
        feature_attention_mask=MultiModalFieldConfig.batched("audio"),
        audio_feature_lengths=MultiModalFieldConfig.batched("audio"),
    )


class Qwen3ASRMultiModalDataParser(MultiModalDataParser):
    def _parse_audio_data(
        self,
        data: dict[str, torch.Tensor] | ModalityData[AudioItem],
    ) -> ModalityDataItems[Any, Any] | None:
        if isinstance(data, dict):
            return DictEmbeddingItems(
                data,
                modality="audio",
                required_fields={"input_audio_features", "audio_feature_lengths"},
                fields_factory=_qwen3asr_field_config,
            )
        return super()._parse_audio_data(data)


class Qwen3ASRMultiModalProcessor(
    Qwen3OmniMoeThinkerMultiModalProcessor,
):
    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        return _qwen3asr_field_config(hf_inputs)

    def _get_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, Any],
        out_mm_kwargs: MultiModalKwargsItems,
    ) -> Sequence[PromptUpdate]:
        processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
        tokenizer = self.info.get_tokenizer()
        vocab = tokenizer.get_vocab()

        audio_token = processor.audio_token
        audio_token_id = vocab[audio_token]

        out_mm_data = out_mm_kwargs.get_data()
        audio_feature_lengths = out_mm_data.get("audio_feature_lengths")
        feature_attention_mask = out_mm_data.get("feature_attention_mask")
        if audio_feature_lengths is None and feature_attention_mask is None:
            audio_output_lengths = []
        elif audio_feature_lengths is not None:
            audio_output_lens = _get_feat_extract_output_lengths(audio_feature_lengths)
            audio_output_lengths = audio_output_lens.tolist()
        elif feature_attention_mask is not None:
            assert isinstance(feature_attention_mask, torch.Tensor)
            audio_output_lens = _get_feat_extract_output_lengths(
                feature_attention_mask.sum(-1)
            )
            audio_output_lengths = audio_output_lens.tolist()

        def get_replacement_qwen2_audio(item_idx: int):
            num_features = audio_output_lengths[item_idx]
            if num_features == 0:
                audios = mm_items.get_items("audio", AudioProcessorItems)
                audio = audios.get(item_idx)
                raise ValueError(
                    f"The audio {audio} (len={len(audio)}) is too short "
                    "to be represented inside the model"
                )

            return [audio_token_id] * num_features

        return [
            PromptReplacement(
                modality="audio",
                target=audio_token,
                replacement=get_replacement_qwen2_audio,
            ),
        ]


@MULTIMODAL_REGISTRY.register_processor(
    Qwen3ASRMultiModalProcessor,
    info=Qwen3ASRProcessingInfo,
    dummy_inputs=Qwen3ASRDummyInputsBuilder,
)
class Qwen3ASRForConditionalGeneration(
    nn.Module,
    SupportsMultiModal,
    SupportsPP,
    SupportsMRoPE,
    SupportsTranscription,
):
    supported_languages = ISO639_1_SUPPORTED_LANGS

    hf_to_vllm_mapper = WeightsMapper(
        orig_to_new_prefix={
            "thinker.lm_head.": "language_model.lm_head.",
            "thinker.model.": "language_model.model.",
            "thinker.": "",
        }
    )

    @classmethod
    def get_placeholder_str(cls, modality: str, i: int) -> str | None:
        if modality.startswith("audio"):
            return "<|audio_start|><|audio_pad|><|audio_end|>"

        raise ValueError("Only audio modality is supported")

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        self.vllm_config = vllm_config  # needed for torch compile forward context
        thinker_config: Qwen3ASRThinkerConfig = (
            vllm_config.model_config.hf_config.thinker_config
        )
        quant_config = vllm_config.quant_config
        multimodal_config = vllm_config.model_config.multimodal_config
        self.config = thinker_config
        self.multimodal_config = multimodal_config
        self.quant_config = quant_config

        with self._mark_tower_model(vllm_config, "audio"):
            self.audio_tower = Qwen3OmniMoeAudioEncoder(
                thinker_config.audio_config,
                prefix=maybe_prefix(prefix, "audio_tower"),
            )

        with self._mark_language_model(vllm_config):
            self.language_model = Qwen3ForCausalLM(
                vllm_config=vllm_config.with_hf_config(
                    thinker_config.text_config, architectures=["Qwen3ForCausalLM"]
                ),
                prefix=maybe_prefix(prefix, "language_model"),
            )

        self.make_empty_intermediate_tensors = (
            self.language_model.make_empty_intermediate_tensors
        )

    def _parse_and_validate_audio_input(
        self, **kwargs: object
    ) -> Qwen2_5OmniAudioFeatureInputs | None:
        input_audio_features = kwargs.pop("input_audio_features", None)
        audio_feature_lengths = kwargs.pop("audio_feature_lengths", None)
        feature_attention_mask = kwargs.pop("feature_attention_mask", None)
        if input_audio_features is None:
            return None
        return Qwen2_5OmniAudioFeatureInputs(
            type="audio_features",
            input_features=input_audio_features,
            audio_feature_lengths=audio_feature_lengths,
            feature_attention_mask=feature_attention_mask,
        )

    def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict:
        mm_input_by_modality = {}

        # Preserve the order of modalities if there are multiple of them
        # from the order of kwargs.
        for input_key in kwargs:
            if (
                input_key in ("input_audio_features",)
                and "audio" not in mm_input_by_modality
            ):
                mm_input_by_modality["audio"] = self._parse_and_validate_audio_input(
                    **kwargs
                )
        return mm_input_by_modality

    def _process_audio_input(
        self,
        audio_input: Qwen2_5OmniAudioFeatureInputs,
        audio_hashes: list[str] | None = None,
        cached_audio_features: torch.Tensor | None = None,
    ) -> torch.Tensor:
        input_features = audio_input["input_features"]
        audio_feature_lengths = audio_input["audio_feature_lengths"]

        audio_output_lengths = _get_feat_extract_output_lengths(audio_feature_lengths)

        audio_features = self.audio_tower(
            input_features.to(self.audio_tower.dtype),
            feature_lens=audio_feature_lengths,
            aftercnn_lens=audio_output_lengths,
        )
        return audio_features.split(audio_output_lengths.tolist())

    def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings | None:
        mm_input_by_modality = self._parse_and_validate_multimodal_inputs(**kwargs)
        if not mm_input_by_modality:
            return []

        # The result multimodal_embeddings is a tuple of tensors, with each
        # tensor corresponding to a multimodal data item (an audio clip).
        multimodal_embeddings: tuple[torch.Tensor, ...] = ()

        # NOTE: It is important to iterate over the keys in this dictionary
        # to preserve the order of the modalities.
        for modality in mm_input_by_modality:
            multimodal_input = mm_input_by_modality[modality]
            if modality == "audio":
                audio_embeddings = self._process_audio_input(multimodal_input)
                multimodal_embeddings += tuple(audio_embeddings)
        return multimodal_embeddings

    def embed_input_ids(
        self,
        input_ids: torch.Tensor,
        multimodal_embeddings: MultiModalEmbeddings | None = None,
        *,
        is_multimodal: torch.Tensor | None = None,
        handle_oov_mm_token: bool = False,
    ) -> torch.Tensor:
        inputs_embeds = self._embed_text_input_ids(
            input_ids,
            self.language_model.embed_input_ids,
            is_multimodal=is_multimodal,
            handle_oov_mm_token=handle_oov_mm_token,
        )

        if multimodal_embeddings is None or len(multimodal_embeddings) == 0:
            return inputs_embeds

        inputs_embeds = _merge_multimodal_embeddings(
            inputs_embeds=inputs_embeds,
            multimodal_embeddings=multimodal_embeddings,
            is_multimodal=is_multimodal,
        )

        return inputs_embeds

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
        **kwargs: object,
    ) -> torch.Tensor | IntermediateTensors:
        if intermediate_tensors is not None:
            inputs_embeds = None

        hidden_states = self.language_model.model(
            input_ids,
            positions,
            intermediate_tensors,
            inputs_embeds=inputs_embeds,
        )
        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor | None:
        return self.language_model.compute_logits(hidden_states)

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        loader = AutoWeightsLoader(
            self,
            skip_prefixes=["talker.", "code2wav."],
        )
        loaded_weights = loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
        return loaded_weights

    def get_mrope_input_positions(
        self,
        input_tokens: list[int],
        mm_features: list[MultiModalFeatureSpec],
    ) -> tuple[torch.Tensor, int]:
        seq_len = len(input_tokens)

        if not mm_features:
            # No audio features, just return linear positions
            llm_positions = (
                torch.arange(seq_len, dtype=torch.long).view(1, -1).expand(3, -1)
            )
            return llm_positions.clone(), 0

        llm_pos_ids_list: list[torch.Tensor] = []
        st = 0

        for mm_feature in sorted(mm_features, key=lambda f: f.mm_position.offset):
            offset = mm_feature.mm_position.offset

            # Get audio feature length from mm_feature data
            audio_feature_length = mm_feature.data["audio_feature_lengths"].data
            if isinstance(audio_feature_length, torch.Tensor):
                audio_feature_length = audio_feature_length.item()

            audio_len = _get_feat_extract_output_lengths(
                torch.tensor(audio_feature_length)
            ).item()

            # Text segment before audio (includes audio_start token)
            text_len = offset - st
            st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0
            text_positions = (
                torch.arange(text_len, dtype=torch.long).view(1, -1).expand(3, -1)
                + st_idx
            )
            llm_pos_ids_list.append(text_positions)
            st_idx = st_idx + text_len

            # Audio token segment
            audio_positions = (
                torch.arange(audio_len, dtype=torch.long).view(1, -1).expand(3, -1)
                + st_idx
            )
            llm_pos_ids_list.append(audio_positions)

            st = offset + audio_len

        # Handle remaining text (includes audio_end and any trailing text)
        if st < seq_len:
            st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0
            text_len = seq_len - st
            final_text_positions = (
                torch.arange(text_len, dtype=torch.long).view(1, -1).expand(3, -1)
                + st_idx
            )
            llm_pos_ids_list.append(final_text_positions)

        llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)
        if llm_positions.shape[1] != seq_len:
            raise RuntimeError("Position ids length mismatch with input ids length")

        mrope_position_delta = (llm_positions.max() + 1 - seq_len).item()
        return llm_positions, mrope_position_delta

    def get_mm_mapping(self) -> MultiModelKeys:
        """
        Get the module prefix in multimodal models
        """
        return MultiModelKeys.from_string_field(
            language_model="language_model",
            tower_model=["audio_tower."],
        )

    @classmethod
    def get_speech_to_text_config(
        cls, model_config: ModelConfig, task_type: str
    ) -> SpeechToTextConfig:
        processor = cached_processor_from_config(model_config)
        feature_extractor: WhisperFeatureExtractor = processor.feature_extractor

        return SpeechToTextConfig(
            max_audio_clip_s=feature_extractor.chunk_length,
            sample_rate=feature_extractor.sampling_rate,
        )

    @classmethod
    def get_generation_prompt(
        cls,
        audio: np.ndarray,
        model_config: ModelConfig,
        stt_config: SpeechToTextConfig,
        language: str | None,
        task_type: Literal["transcribe", "translate"],
        request_prompt: str,
        to_language: str | None,
    ) -> PromptType:
        """Get the generation prompt to be used for transcription requests."""
        tokenizer = cached_tokenizer_from_config(model_config)
        audio_placeholder = cls.get_placeholder_str("audio", 0)

        if task_type not in ("transcribe", "translate"):
            raise ValueError(
                f"Unsupported task_type '{task_type}'. "
                "Supported task types are 'transcribe' and 'translate'."
            )

        full_lang_name_to = cls.supported_languages.get(to_language, to_language)

        if to_language is None:
            prompt = (
                f"<|im_start|>user\n{audio_placeholder}<|im_end|>\n"
                f"<|im_start|>assistant\n"
            )
        else:
            prompt = (
                f"<|im_start|>user\n{audio_placeholder}<|im_end|>\n"
                f"<|im_start|>assistant\nlanguage {full_lang_name_to}{_ASR_TEXT_TAG}"
            )

        prompt_token_ids = tokenizer.encode(prompt)
        return TokensPrompt(
            prompt_token_ids=prompt_token_ids,
            multi_modal_data={"audio": audio},
        )

    @classmethod
    def post_process_output(cls, text: str) -> str:
        """
        Post-process Qwen3-ASR raw output to extract the clean transcription.

        The model outputs in the format "language {lang}<asr_text>{transcription}".
        This method strips the language prefix and the asr_text tag.
        """
        if not text:
            return ""

        if _ASR_TEXT_TAG not in text:
            return text

        # Split on the tag and take the transcription part
        _, text_part = text.rsplit(_ASR_TEXT_TAG, 1)
        return text_part
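

# Illustrative usage (not part of the original file): for a raw decode such as
# f"language English{_ASR_TEXT_TAG}How are you?", post_process_output() returns
# "How are you?". Output that does not contain the tag is returned unchanged.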