Fix qwen2 audio not working bug (#8600)

This commit is contained in:
Binyao Jiang
2025-08-09 00:42:29 -07:00
committed by GitHub
parent d3e67deb1b
commit 7b81f956eb
4 changed files with 59 additions and 12 deletions

View File

@@ -614,8 +614,7 @@ def general_mm_embed_routine(
input_ids: Input token IDs tensor
forward_batch: Batch information for model forward pass
language_model: Base language model to use
image_data_embedding_func: Function to embed image data
audio_data_embedding_func: Function to embed audio data
data_embedding_funcs: A dictionary mapping from modality type to the corresponding embedding function.
placeholder_tokens: Token IDs for multimodal placeholders
**kwargs: Additional arguments passed to language model

View File

@@ -52,7 +52,11 @@ from sglang.srt.managers.mm_utils import (
MultiModalityDataPaddingPatternMultimodalTokens,
general_mm_embed_routine,
)
from sglang.srt.managers.schedule_batch import MultimodalDataItem, MultimodalInputs
from sglang.srt.managers.schedule_batch import (
Modality,
MultimodalDataItem,
MultimodalInputs,
)
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_loader.weight_utils import default_weight_loader
from sglang.srt.models.qwen2 import Qwen2ForCausalLM
@@ -106,15 +110,10 @@ class Qwen2AudioForConditionalGeneration(nn.Module):
self.language_model = Qwen2ForCausalLM(
config.text_config, quant_config, prefix=add_prefix("model", prefix)
)
self.pattern = MultiModalityDataPaddingPatternMultimodalTokens()
def pad_input_ids(self, input_ids: List[int], mm_inputs: MultimodalInputs):
# Get all special token IDs for audio
audio_token_id: int = getattr(
mm_inputs, "audio_token_id", mm_inputs.im_token_id
)
pattern = MultiModalityDataPaddingPatternMultimodalTokens([audio_token_id])
return pattern.pad_input_tokens(input_ids, mm_inputs)
return self.pattern.pad_input_tokens(input_ids, mm_inputs)
def get_audio_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor:
# Extract audio features from input items
@@ -143,7 +142,9 @@ class Qwen2AudioForConditionalGeneration(nn.Module):
input_ids=input_ids,
forward_batch=forward_batch,
language_model=self.language_model,
audio_data_embedding_func=self.get_audio_feature,
data_embedding_funcs={
Modality.AUDIO: self.get_audio_feature,
},
positions=positions,
)