Files
sglang/python/sglang/srt/managers/multimodal_processors/qwen_audio.py

95 lines
3.1 KiB
Python

import re
from typing import List, Union
import torch
from sglang.srt.managers.multimodal_processors.base_processor import (
BaseMultimodalProcessor,
MultimodalSpecialTokens,
)
from sglang.srt.managers.schedule_batch import Modality, MultimodalDataItem
from sglang.srt.models.qwen2_audio import Qwen2AudioForConditionalGeneration
class Qwen2AudioMultimodalProcessor(BaseMultimodalProcessor):
    """Audio preprocessing for Qwen2-Audio requests.

    Locates audio placeholders (``<|audio_bos|>...<|audio_eos|>``) in the
    prompt via ``load_mm_data``, runs the wrapped processor on the prompt plus
    the loaded audio, and packs the resulting audio features into
    ``MultimodalDataItem`` objects alongside the tokenized prompt.
    """

    # Model class(es) this processor is associated with.
    models = [Qwen2AudioForConditionalGeneration]

    def __init__(self, hf_config, server_args, _processor):
        super().__init__(hf_config, server_args, _processor)
        # Placeholder for one audio clip: a single <|AUDIO|> between markers.
        self.AUDIO_TOKEN = "<|audio_bos|><|AUDIO|><|audio_eos|>"
        # Matches a placeholder with one or more <|AUDIO|> tokens between the
        # bos/eos markers, i.e. both the single-token and repeated forms.
        self.AUDIO_TOKEN_REGEX = re.compile(
            r"<\|audio_bos\|>(?:<\|AUDIO\|>)+<\|audio_eos\|>"
        )

    async def process_mm_data_async(
        self,
        image_data: List[Union[str, bytes]],
        input_text,
        request_obj,
        max_req_input_len,
        **kwargs,
    ):
        """Turn a request's audio and prompt into model-ready multimodal inputs.

        Args:
            image_data: Not referenced by this processor.
            input_text: Prompt handed to ``load_mm_data``.
            request_obj: Request object; only its ``audio_data`` attribute is
                read. A non-list value is wrapped in a one-element list.
            max_req_input_len: Forwarded unchanged to ``load_mm_data``.
            **kwargs: Ignored.

        Returns:
            ``None`` when ``load_mm_data`` returns ``None``. Otherwise a dict
            with these keys:

            - ``mm_items``: empty, or a single AUDIO-modality
              ``MultimodalDataItem`` that carries all of the processor's
              ``input_features``.
            - ``input_ids``: the flattened token ids as a Python list.
            - ``audio_start_id``, ``audio_token_id``, ``audio_end_id``: the
              tokenizer ids of ``<|audio_bos|>``, ``<|AUDIO|>`` and
              ``<|audio_eos|>``.
        """
        raw_audio = request_obj.audio_data
        # NOTE(review): a missing value (None) becomes [None] here, not [].
        # Confirm load_mm_data tolerates that when the prompt has no audio.
        audio_data = raw_audio if isinstance(raw_audio, list) else [raw_audio]

        special_tokens = MultimodalSpecialTokens(
            audio_token=self.AUDIO_TOKEN,
            audio_token_regex=self.AUDIO_TOKEN_REGEX,
        )
        loaded = self.load_mm_data(
            prompt=input_text,
            max_req_input_len=max_req_input_len,
            audio_data=audio_data,
            multimodal_tokens=special_tokens,
        )
        if loaded is None:
            return None

        processed = self.process_mm_data(
            input_text=loaded.input_text,
            audio=loaded.audios,
        )

        # Ids of the three audio marker tokens: begin, body, end (in that order).
        to_id = self._processor.tokenizer.convert_tokens_to_ids
        audio_start_id, audio_token_id, audio_end_id = (
            to_id(marker)
            for marker in ("<|audio_bos|>", "<|AUDIO|>", "<|audio_eos|>")
        )

        input_ids = processed["input_ids"].flatten()

        features = (
            processed["input_features"] if "input_features" in processed else None
        )

        mm_items = []
        if features is not None and len(features) != 0:
            # Span offsets of each bos..eos pair; skipped if a marker id is missing.
            if audio_start_id is None or audio_end_id is None:
                audio_offsets = None
            else:
                audio_offsets = self.get_mm_items_offset_by_pair(
                    input_ids=input_ids,
                    mm_start_id=audio_start_id,
                    mm_end_id=audio_end_id,
                )

            # Per-clip frame counts, presumably from a 0/1 feature mask.
            frame_counts = processed["feature_attention_mask"].sum(dim=-1)
            # Two successive length reductions. These appear to mirror the
            # encoder's output-length formula in HF Qwen2-Audio
            # (_get_feat_extract_output_lengths); confirm if the encoder changes.
            conv_lengths = (frame_counts - 1) // 2 + 1
            output_lengths = (conv_lengths - 2) // 2 + 1

            mm_items.append(
                MultimodalDataItem(
                    feature=features,
                    audio_feature_lens=output_lengths,
                    audio_offsets=audio_offsets,
                    modality=Modality.AUDIO,
                )
            )

        return {
            "mm_items": mm_items,
            "input_ids": input_ids.tolist(),
            "audio_start_id": audio_start_id,
            "audio_token_id": audio_token_id,
            "audio_end_id": audio_end_id,
        }