add model: qwen2-audio (#7596)

This commit is contained in:
Leng Yue
2025-07-04 21:13:10 -07:00
committed by GitHub
parent da3890e82a
commit 8364608930
5 changed files with 330 additions and 0 deletions

View File

@@ -0,0 +1,94 @@
import re
from typing import List, Union
import torch
from sglang.srt.managers.multimodal_processors.base_processor import (
BaseMultimodalProcessor,
MultimodalSpecialTokens,
)
from sglang.srt.managers.schedule_batch import Modality, MultimodalDataItem
from sglang.srt.models.qwen2_audio import Qwen2AudioForConditionalGeneration
class Qwen2AudioMultimodalProcessor(BaseMultimodalProcessor):
models = [Qwen2AudioForConditionalGeneration]
def __init__(self, hf_config, server_args, _processor):
super().__init__(hf_config, server_args, _processor)
self.AUDIO_TOKEN = "<|audio_bos|><|AUDIO|><|audio_eos|>"
self.AUDIO_TOKEN_REGEX = re.compile(
r"<\|audio_bos\|>(?:<\|AUDIO\|>)+<\|audio_eos\|>"
)
async def process_mm_data_async(
self,
image_data: List[Union[str, bytes]],
input_text,
request_obj,
max_req_input_len,
**kwargs,
):
audio_data = request_obj.audio_data
if not isinstance(audio_data, list):
audio_data = [audio_data]
base_output = self.load_mm_data(
prompt=input_text,
max_req_input_len=max_req_input_len,
audio_data=audio_data,
multimodal_tokens=MultimodalSpecialTokens(
audio_token=self.AUDIO_TOKEN,
audio_token_regex=self.AUDIO_TOKEN_REGEX,
),
)
if base_output is None:
return None
res = self.process_mm_data(
input_text=base_output.input_text,
audio=base_output.audios,
)
# Collect special token ids
tokenizer = self._processor.tokenizer
audio_start_id = tokenizer.convert_tokens_to_ids("<|audio_bos|>")
audio_token_id = tokenizer.convert_tokens_to_ids("<|AUDIO|>")
audio_end_id = tokenizer.convert_tokens_to_ids("<|audio_eos|>")
items = []
input_ids = res["input_ids"].flatten()
if (
"input_features" in res
and res["input_features"] is not None
and len(res["input_features"]) != 0
):
if audio_start_id is not None and audio_end_id is not None:
audio_offsets = self.get_mm_items_offset_by_pair(
input_ids=input_ids,
mm_start_id=audio_start_id,
mm_end_id=audio_end_id,
)
else:
audio_offsets = None
input_lengths = res["feature_attention_mask"].sum(dim=-1)
input_lengths = (input_lengths - 1) // 2 + 1
output_lengths = (input_lengths - 2) // 2 + 1
item = MultimodalDataItem(
audio_features=res["input_features"],
audio_feature_lens=output_lengths,
audio_offsets=audio_offsets,
modality=Modality.AUDIO,
)
items += [item]
return {
"mm_items": items,
"input_ids": input_ids.tolist(),
"audio_start_id": audio_start_id,
"audio_token_id": audio_token_id,
"audio_end_id": audio_end_id,
}