From 83646089301a83b55845e48e93be0cefb6a21a7b Mon Sep 17 00:00:00 2001
From: Leng Yue
Date: Fri, 4 Jul 2025 21:13:10 -0700
Subject: [PATCH] add model: qwen2-audio (#7596)

---
 python/sglang/srt/configs/model_config.py     |   1 +
 python/sglang/srt/conversation.py             |  34 +++
 .../multimodal_processors/qwen_audio.py       | 113 +++++++++
 python/sglang/srt/models/qwen2.py             |   1 +
 python/sglang/srt/models/qwen2_audio.py       | 226 ++++++++++++++++++
 5 files changed, 375 insertions(+)
 create mode 100644 python/sglang/srt/managers/multimodal_processors/qwen_audio.py
 create mode 100644 python/sglang/srt/models/qwen2_audio.py

diff --git a/python/sglang/srt/configs/model_config.py b/python/sglang/srt/configs/model_config.py
index 6f202db6f..25104bd80 100644
--- a/python/sglang/srt/configs/model_config.py
+++ b/python/sglang/srt/configs/model_config.py
@@ -593,6 +593,7 @@ multimodal_model_archs = [
     "Mistral3ForConditionalGeneration",
     "MultiModalityCausalLM",
     "MllamaForConditionalGeneration",
+    "Qwen2AudioForConditionalGeneration",
     "Qwen2VLForConditionalGeneration",
     "Qwen2_5_VLForConditionalGeneration",
     "KimiVLForConditionalGeneration",
diff --git a/python/sglang/srt/conversation.py b/python/sglang/srt/conversation.py
index ad292bbd9..4dd368a15 100644
--- a/python/sglang/srt/conversation.py
+++ b/python/sglang/srt/conversation.py
@@ -59,6 +59,7 @@ class SeparatorStyle(IntEnum):
     METAMATH = auto()
     DeepSeekVL2 = auto()
     QWEN2_VL_EMBED = auto()
+    QWEN2_AUDIO = auto()
     GEMMA3 = auto()
     MPT = auto()
 
@@ -350,6 +351,23 @@ class Conversation:
                 else:
                     ret += role
             return ret
+        elif self.sep_style == SeparatorStyle.QWEN2_AUDIO:
+            ret = "" if system_prompt == "" else system_prompt + self.sep
+
+            counter = 1
+            for role, message in self.messages:
+                if message:
+                    while self.audio_token in message:
+                        message = message.replace(
+                            self.audio_token, self.audio_token.format(idx=counter), 1
+                        )
+                        counter += 1
+
+                    ret += role + "\n" + message + self.sep
+                else:
+                    ret += role + "\n"
+
+            return ret
         else:
             raise ValueError(f"Invalid style: {self.sep_style}")
 
@@ -904,6 +922,20 @@ register_conv_template(
 )
 
 
+register_conv_template(
+    Conversation(
+        name="qwen2-audio",
+        system_template="<|im_start|>system\n{system_message}",
+        system_message="You are a helpful assistant.",
+        roles=("<|im_start|>user", "<|im_start|>assistant"),
+        sep="<|im_end|>\n",
+        sep_style=SeparatorStyle.QWEN2_AUDIO,
+        stop_str=["<|im_end|>"],
+        audio_token="Audio {idx}: <|audio_bos|><|AUDIO|><|audio_eos|>\n",
+    )
+)
+
+
 @register_conv_template_matching_function
 def match_internvl(model_path: str):
     if re.search(r"internvl2_5", model_path, re.IGNORECASE):
@@ -956,6 +988,8 @@ def match_qwen_chat_ml(model_path: str):
         return "gme-qwen2-vl"
     if re.search(r"qwen.*vl", model_path, re.IGNORECASE):
         return "qwen2-vl"
+    if re.search(r"qwen.*audio", model_path, re.IGNORECASE):
+        return "qwen2-audio"
     if re.search(
         r"llava-v1\.6-34b|llava-v1\.6-yi-34b|llava-next-video-34b|llava-onevision-qwen2",
         model_path,
diff --git a/python/sglang/srt/managers/multimodal_processors/qwen_audio.py b/python/sglang/srt/managers/multimodal_processors/qwen_audio.py
new file mode 100644
index 000000000..0558b5f5a
--- /dev/null
+++ b/python/sglang/srt/managers/multimodal_processors/qwen_audio.py
@@ -0,0 +1,113 @@
+import re
+from typing import List, Union
+
+import torch
+
+from sglang.srt.managers.multimodal_processors.base_processor import (
+    BaseMultimodalProcessor,
+    MultimodalSpecialTokens,
+)
+from sglang.srt.managers.schedule_batch import Modality, MultimodalDataItem
+from sglang.srt.models.qwen2_audio import Qwen2AudioForConditionalGeneration
+
+
+class Qwen2AudioMultimodalProcessor(BaseMultimodalProcessor):
+    """Build Qwen2-Audio model inputs from a prompt and its audio clips."""
+
+    models = [Qwen2AudioForConditionalGeneration]
+
+    def __init__(self, hf_config, server_args, _processor):
+        super().__init__(hf_config, server_args, _processor)
+        # Placeholder passed to load_mm_data as the audio token.
+        self.AUDIO_TOKEN = "<|audio_bos|><|AUDIO|><|audio_eos|>"
+        # Matches the placeholder with one or more <|AUDIO|> tokens inside.
+        self.AUDIO_TOKEN_REGEX = re.compile(
+            r"<\|audio_bos\|>(?:<\|AUDIO\|>)+<\|audio_eos\|>"
+        )
+
+    async def process_mm_data_async(
+        self,
+        image_data: List[Union[str, bytes]],
+        input_text,
+        request_obj,
+        max_req_input_len,
+        **kwargs,
+    ):
+        """Load the request's audio and assemble the multimodal input dict.
+
+        ``image_data`` and ``kwargs`` are accepted but not read; the audio is
+        taken from ``request_obj.audio_data``.
+
+        Returns ``None`` if ``load_mm_data`` returns ``None``; otherwise a dict
+        holding ``mm_items``, ``input_ids`` (flattened, as a list) and the ids
+        of the ``<|audio_bos|>``, ``<|AUDIO|>`` and ``<|audio_eos|>`` tokens.
+        """
+        audio_data = request_obj.audio_data
+        if audio_data is None:
+            # No audio attached: pass an empty list instead of ``[None]``.
+            audio_data = []
+        elif not isinstance(audio_data, list):
+            audio_data = [audio_data]
+
+        base_output = self.load_mm_data(
+            prompt=input_text,
+            max_req_input_len=max_req_input_len,
+            audio_data=audio_data,
+            multimodal_tokens=MultimodalSpecialTokens(
+                audio_token=self.AUDIO_TOKEN,
+                audio_token_regex=self.AUDIO_TOKEN_REGEX,
+            ),
+        )
+        if base_output is None:
+            return None
+
+        res = self.process_mm_data(
+            input_text=base_output.input_text,
+            audio=base_output.audios,
+        )
+
+        # Collect special token ids
+        tokenizer = self._processor.tokenizer
+        audio_start_id = tokenizer.convert_tokens_to_ids("<|audio_bos|>")
+        audio_token_id = tokenizer.convert_tokens_to_ids("<|AUDIO|>")
+        audio_end_id = tokenizer.convert_tokens_to_ids("<|audio_eos|>")
+
+        items = []
+        input_ids = res["input_ids"].flatten()
+
+        if (
+            "input_features" in res
+            and res["input_features"] is not None
+            and len(res["input_features"]) != 0
+        ):
+            if audio_start_id is not None and audio_end_id is not None:
+                audio_offsets = self.get_mm_items_offset_by_pair(
+                    input_ids=input_ids,
+                    mm_start_id=audio_start_id,
+                    mm_end_id=audio_end_id,
+                )
+            else:
+                audio_offsets = None
+
+            # Sum the feature attention mask over its last dim, then apply two
+            # length reductions (presumably mirroring the audio encoder's
+            # downsampling; confirm against the HF Qwen2-Audio implementation).
+            input_lengths = res["feature_attention_mask"].sum(dim=-1)
+            input_lengths = (input_lengths - 1) // 2 + 1
+            output_lengths = (input_lengths - 2) // 2 + 1
+
+            item = MultimodalDataItem(
+                audio_features=res["input_features"],
+                audio_feature_lens=output_lengths,
+                audio_offsets=audio_offsets,
+                modality=Modality.AUDIO,
+            )
+            items.append(item)
+
+        return {
+            "mm_items": items,
+            "input_ids": input_ids.tolist(),
+            "audio_start_id": audio_start_id,
+            "audio_token_id": audio_token_id,
+            "audio_end_id": audio_end_id,
+        }
diff --git a/python/sglang/srt/models/qwen2.py b/python/sglang/srt/models/qwen2.py
index d0608129a..987204d83 100644
--- a/python/sglang/srt/models/qwen2.py
+++ b/python/sglang/srt/models/qwen2.py
@@ -425,6 +425,7 @@ class Qwen2ForCausalLM(nn.Module):
                     quant_config=quant_config,
                     prefix=add_prefix("lm_head", prefix),
                 )
+
         else:
             # ranks other than the last rank will have a placeholder layer
             self.lm_head = PPMissingLayer()
diff --git a/python/sglang/srt/models/qwen2_audio.py b/python/sglang/srt/models/qwen2_audio.py
new file mode 100644
index 000000000..53e087496
--- /dev/null
+++ b/python/sglang/srt/models/qwen2_audio.py
@@ -0,0 +1,226 @@
+# coding=utf-8
+# Adapted from
+# https://github.com/huggingface/transformers/blob/1d45d90e5d1552eccb6d8cc9b7bba283ccefb808/src/transformers/models/qwen2_audio/modeling_qwen2_audio.py
+# Copyright 2024 The Qwen team.
+# Copyright 2023 The vLLM team.
+# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
+#
+# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
+# and OPT implementations in this library. It has been modified from its
+# original forms to accommodate minor architectural differences compared
+# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Inference-only Qwen2-Audio model compatible with HuggingFace weights."""
+import logging
+import math
+from functools import lru_cache, partial
+from typing import Any, Iterable, List, Optional, Tuple, Type, TypedDict
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from einops import rearrange
+from transformers import AutoTokenizer, Qwen2AudioEncoderConfig, Qwen2Config
+from transformers.activations import ACT2FN
+from transformers.models.qwen2_audio.configuration_qwen2_audio import Qwen2AudioConfig
+from transformers.models.qwen2_audio.modeling_qwen2_audio import (
+    Qwen2AudioEncoder,
+    Qwen2AudioMultiModalProjector,
+)
+
+from sglang.srt.hf_transformers_utils import get_processor
+from sglang.srt.layers.activation import QuickGELU
+from sglang.srt.layers.attention.vision import VisionAttention
+from sglang.srt.layers.linear import ColumnParallelLinear, RowParallelLinear
+from sglang.srt.layers.logits_processor import LogitsProcessor
+from sglang.srt.layers.pooler import Pooler, PoolingType
+from sglang.srt.layers.quantization.base_config import QuantizationConfig
+from sglang.srt.layers.utils import get_layer_id
+from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead
+from sglang.srt.managers.mm_utils import (
+    MultiModalityDataPaddingPatternMultimodalTokens,
+    general_mm_embed_routine,
+)
+from sglang.srt.managers.schedule_batch import MultimodalDataItem, MultimodalInputs
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.model_loader.weight_utils import default_weight_loader
+from sglang.srt.models.qwen2 import Qwen2ForCausalLM
+from sglang.srt.utils import add_prefix
+
+logger = logging.getLogger(__name__)
+
+
+class Qwen2AudioForConditionalGeneration(nn.Module):
+    """Qwen2-Audio: an audio encoder and projector in front of a Qwen2 LM."""
+
+    # BitandBytes specific attributes
+    default_bitsandbytes_target_modules = [
+        ".gate_proj.",
+        ".down_proj.",
+        ".up_proj.",
+        ".q_proj.",
+        ".k_proj.",
+        ".v_proj.",
+        ".o_proj.",
+    ]
+    bitsandbytes_stacked_params_mapping = {
+        # shard_name, weight_name, index
+        "q_proj": ("qkv_proj", 0),
+        "k_proj": ("qkv_proj", 1),
+        "v_proj": ("qkv_proj", 2),
+        "gate_proj": ("gate_up_proj", 0),
+        "up_proj": ("gate_up_proj", 1),
+    }
+
+    def __init__(
+        self,
+        config: Qwen2AudioConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ) -> None:
+        super().__init__()
+
+        self.config = config
+
+        # Create default sub-configs when the top-level config lacks them.
+        if getattr(self.config, "audio_config", None) is None:
+            self.config.audio_config = Qwen2AudioEncoderConfig(
+                self.config._name_or_path
+            )
+
+        if getattr(self.config, "text_config", None) is None:
+            self.config.text_config = Qwen2Config(self.config._name_or_path)
+
+        self.audio_tower = Qwen2AudioEncoder(
+            config.audio_config,
+        )
+        self.multi_modal_projector = Qwen2AudioMultiModalProjector(config)
+        self.language_model = Qwen2ForCausalLM(
+            config.text_config, quant_config, prefix=add_prefix("model", prefix)
+        )
+
+    def pad_input_ids(self, input_ids: List[int], mm_inputs: MultimodalInputs):
+        """Pad ``input_ids`` using a pattern keyed on the audio token id."""
+        # Use im_token_id when mm_inputs has no audio_token_id attribute.
+        audio_token_id: int = getattr(
+            mm_inputs, "audio_token_id", mm_inputs.im_token_id
+        )
+
+        pattern = MultiModalityDataPaddingPatternMultimodalTokens([audio_token_id])
+        return pattern.pad_input_tokens(input_ids, mm_inputs)
+
+    def get_audio_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor:
+        """Encode the audio in ``items`` and return the concatenated embeddings.
+
+        Each entry of the projected encoder output is cut to the matching
+        ``audio_feature_lens`` value before concatenation along dim 0.
+        """
+        # Extract audio features from input items
+        input_features = torch.cat([item.audio_features for item in items], dim=0).type(
+            self.audio_tower.dtype
+        )
+
+        audio_embeds = self.audio_tower(input_features).last_hidden_state
+        audio_embeds = self.multi_modal_projector(audio_embeds)
+
+        audio_feature_lens = torch.cat([item.audio_feature_lens for item in items])
+        # Convert all lengths at once instead of calling ``.item()`` per entry.
+        trimmed = [
+            embed[:length]
+            for length, embed in zip(audio_feature_lens.tolist(), audio_embeds)
+        ]
+
+        return torch.cat(trimmed, dim=0)
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        forward_batch: ForwardBatch,
+        **kwargs: Any,
+    ) -> torch.Tensor:
+        """Run ``general_mm_embed_routine`` with the audio embedding function.
+
+        ``kwargs`` is accepted but not used.
+        """
+        hidden_states = general_mm_embed_routine(
+            input_ids=input_ids,
+            forward_batch=forward_batch,
+            language_model=self.language_model,
+            audio_data_embedding_func=self.get_audio_feature,
+            positions=positions,
+        )
+
+        return hidden_states
+
+    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+        """Load checkpoint tensors into this module's parameters.
+
+        q/k/v and gate/up shards outside ``audio_tower`` are routed into the
+        fused ``qkv_proj`` / ``gate_up_proj`` parameters; everything else is
+        loaded by name. A non-bias name with no parameter raises ``KeyError``.
+        """
+        stacked_params_mapping = [
+            # (param_name, shard_name, shard_id)
+            ("qkv_proj", "q_proj", "q"),
+            ("qkv_proj", "k_proj", "k"),
+            ("qkv_proj", "v_proj", "v"),
+            ("gate_up_proj", "gate_proj", 0),
+            ("gate_up_proj", "up_proj", 1),
+        ]
+        params_dict = dict(self.named_parameters(remove_duplicate=False))
+
+        for name, loaded_weight in weights:
+            if "rotary_emb.inv_freq" in name:
+                continue
+            if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name:
+                # Models trained using ColossalAI may include these tensors in
+                # the checkpoint. Skip them.
+                continue
+
+            if self.config.text_config.tie_word_embeddings and "lm_head.weight" in name:
+                continue
+
+            for param_name, weight_name, shard_id in stacked_params_mapping:
+                if weight_name not in name or "audio_tower" in name:
+                    continue
+                name_tmp = name.replace(weight_name, param_name)
+
+                # Skip loading extra bias for GPTQ models.
+                if name_tmp.endswith(".bias") and name_tmp not in params_dict:
+                    continue
+                param = params_dict[name_tmp]
+                weight_loader = param.weight_loader
+                weight_loader(param, loaded_weight, shard_id)
+                break
+            else:
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
+                try:
+                    param = params_dict[name]
+                except KeyError:
+                    # Log through the module logger instead of printing.
+                    logger.error(
+                        "No parameter for weight %r; known parameters: %s",
+                        name,
+                        list(params_dict),
+                    )
+                    raise
+
+                weight_loader = getattr(param, "weight_loader", default_weight_loader)
+                weight_loader(param, loaded_weight)
+
+
+EntryClass = Qwen2AudioForConditionalGeneration