Fix qwen2 audio not working bug (#8600)
This commit is contained in:
@@ -614,8 +614,7 @@ def general_mm_embed_routine(
|
|||||||
input_ids: Input token IDs tensor
|
input_ids: Input token IDs tensor
|
||||||
forward_batch: Batch information for model forward pass
|
forward_batch: Batch information for model forward pass
|
||||||
language_model: Base language model to use
|
language_model: Base language model to use
|
||||||
image_data_embedding_func: Function to embed image data
|
data_embedding_funcs: A dictionary mapping from modality type to the corresponding embedding function.
|
||||||
audio_data_embedding_func: Function to embed audio data
|
|
||||||
placeholder_tokens: Token IDs for multimodal placeholders
|
placeholder_tokens: Token IDs for multimodal placeholders
|
||||||
**kwargs: Additional arguments passed to language model
|
**kwargs: Additional arguments passed to language model
|
||||||
|
|
||||||
|
|||||||
@@ -52,7 +52,11 @@ from sglang.srt.managers.mm_utils import (
|
|||||||
MultiModalityDataPaddingPatternMultimodalTokens,
|
MultiModalityDataPaddingPatternMultimodalTokens,
|
||||||
general_mm_embed_routine,
|
general_mm_embed_routine,
|
||||||
)
|
)
|
||||||
from sglang.srt.managers.schedule_batch import MultimodalDataItem, MultimodalInputs
|
from sglang.srt.managers.schedule_batch import (
|
||||||
|
Modality,
|
||||||
|
MultimodalDataItem,
|
||||||
|
MultimodalInputs,
|
||||||
|
)
|
||||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
||||||
from sglang.srt.model_loader.weight_utils import default_weight_loader
|
from sglang.srt.model_loader.weight_utils import default_weight_loader
|
||||||
from sglang.srt.models.qwen2 import Qwen2ForCausalLM
|
from sglang.srt.models.qwen2 import Qwen2ForCausalLM
|
||||||
@@ -106,15 +110,10 @@ class Qwen2AudioForConditionalGeneration(nn.Module):
|
|||||||
self.language_model = Qwen2ForCausalLM(
|
self.language_model = Qwen2ForCausalLM(
|
||||||
config.text_config, quant_config, prefix=add_prefix("model", prefix)
|
config.text_config, quant_config, prefix=add_prefix("model", prefix)
|
||||||
)
|
)
|
||||||
|
self.pattern = MultiModalityDataPaddingPatternMultimodalTokens()
|
||||||
|
|
||||||
def pad_input_ids(self, input_ids: List[int], mm_inputs: MultimodalInputs):
|
def pad_input_ids(self, input_ids: List[int], mm_inputs: MultimodalInputs):
|
||||||
# Get all special token IDs for audio
|
return self.pattern.pad_input_tokens(input_ids, mm_inputs)
|
||||||
audio_token_id: int = getattr(
|
|
||||||
mm_inputs, "audio_token_id", mm_inputs.im_token_id
|
|
||||||
)
|
|
||||||
|
|
||||||
pattern = MultiModalityDataPaddingPatternMultimodalTokens([audio_token_id])
|
|
||||||
return pattern.pad_input_tokens(input_ids, mm_inputs)
|
|
||||||
|
|
||||||
def get_audio_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor:
|
def get_audio_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor:
|
||||||
# Extract audio features from input items
|
# Extract audio features from input items
|
||||||
@@ -143,7 +142,9 @@ class Qwen2AudioForConditionalGeneration(nn.Module):
|
|||||||
input_ids=input_ids,
|
input_ids=input_ids,
|
||||||
forward_batch=forward_batch,
|
forward_batch=forward_batch,
|
||||||
language_model=self.language_model,
|
language_model=self.language_model,
|
||||||
audio_data_embedding_func=self.get_audio_feature,
|
data_embedding_funcs={
|
||||||
|
Modality.AUDIO: self.get_audio_feature,
|
||||||
|
},
|
||||||
positions=positions,
|
positions=positions,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -190,6 +190,53 @@ class TestGemma3nServer(TestOpenAIVisionServer):
|
|||||||
# self._test_audio_ambient_completion()
|
# self._test_audio_ambient_completion()
|
||||||
|
|
||||||
|
|
||||||
|
class TestQwen2AudioServer(TestOpenAIVisionServer):
|
||||||
|
@classmethod
|
||||||
|
def setUpClass(cls):
|
||||||
|
cls.model = "Qwen/Qwen2-Audio-7B-Instruct"
|
||||||
|
cls.base_url = DEFAULT_URL_FOR_TEST
|
||||||
|
cls.api_key = "sk-123456"
|
||||||
|
cls.process = popen_launch_server(
|
||||||
|
cls.model,
|
||||||
|
cls.base_url,
|
||||||
|
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
||||||
|
other_args=[
|
||||||
|
"--trust-remote-code",
|
||||||
|
"--mem-fraction-static",
|
||||||
|
"0.70",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
cls.base_url += "/v1"
|
||||||
|
|
||||||
|
def test_audio_chat_completion(self):
|
||||||
|
self._test_audio_speech_completion()
|
||||||
|
self._test_audio_ambient_completion()
|
||||||
|
|
||||||
|
# Qwen2Audio does not support image
|
||||||
|
def test_single_image_chat_completion(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
# Qwen2Audio does not support image
|
||||||
|
def test_multi_turn_chat_completion(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
# Qwen2Audio does not support image
|
||||||
|
def test_multi_images_chat_completion(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
# Qwen2Audio does not support image
|
||||||
|
def test_video_images_chat_completion(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
# Qwen2Audio does not support image
|
||||||
|
def test_regex(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
# Qwen2Audio does not support image
|
||||||
|
def test_mixed_batch(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
class TestKimiVLServer(TestOpenAIVisionServer):
|
class TestKimiVLServer(TestOpenAIVisionServer):
|
||||||
@classmethod
|
@classmethod
|
||||||
def setUpClass(cls):
|
def setUpClass(cls):
|
||||||
|
|||||||
@@ -547,7 +547,7 @@ class TestOpenAIVisionServer(CustomTestCase):
|
|||||||
# bird song
|
# bird song
|
||||||
audio_response = self.get_audio_response(
|
audio_response = self.get_audio_response(
|
||||||
AUDIO_BIRD_SONG_URL,
|
AUDIO_BIRD_SONG_URL,
|
||||||
"Please listen to the audio snippet carefully and transcribe the content.",
|
"Please listen to the audio snippet carefully and transcribe the content in English.",
|
||||||
"ambient",
|
"ambient",
|
||||||
)
|
)
|
||||||
assert "bird" in audio_response
|
assert "bird" in audio_response
|
||||||
|
|||||||
Reference in New Issue
Block a user