Feat: Support audio in Phi4-mm model (#8048)

commit b7e951a6db (parent d918ab7985)
Author: Binyao Jiang
Date: 2025-07-18 21:03:53 -07:00
Committed by: GitHub

11 changed files with 3333 additions and 54 deletions

View File

@@ -729,6 +729,7 @@ register_conv_template(
sep="<|end|>",
stop_str="<|end|>",
image_token="<|endoftext10|>",
+ audio_token="<|endoftext11|>",
)
)
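The template gives images and audio distinct placeholder tokens, so the processor can tell the modalities apart during tokenization. A minimal sketch of a rendered Phi4-mm prompt, assuming the usual <|user|>/<|assistant|> role tags (the role tags are an assumption; the placeholder and <|end|> tokens come from the template above):

    # Illustration only: role tags are assumed, not taken from this diff.
    IMAGE_TOKEN = "<|endoftext10|>"
    AUDIO_TOKEN = "<|endoftext11|>"

    prompt = (
        f"<|user|>{IMAGE_TOKEN}{AUDIO_TOKEN}"
        "Describe the image, then transcribe the audio.<|end|>"
        "<|assistant|>"
    )
    print(prompt)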

View File

@@ -239,6 +239,10 @@ class MultimodalDataItem:
# For gemma3n
input_features_mask: Optional[torch.Tensor] = None
+ # For phi4-mm
+ image_attention_mask: Optional[torch.Tensor] = None
+ audio_attention_mask: Optional[torch.Tensor] = None
@staticmethod
def is_empty_list(l):
if l is None:
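These two optional fields carry the processor's per-item attention masks through the batch to the model's feature extractors. A minimal sketch of an audio item using the new field (shapes and values are assumptions for illustration):

    import torch

    from sglang.srt.managers.schedule_batch import Modality, MultimodalDataItem

    # One audio clip with T frames of D-dim features; a boolean mask marks
    # the valid frames (shape chosen for illustration).
    T, D = 500, 80
    item = MultimodalDataItem(
        modality=Modality.AUDIO,
        feature=torch.randn(1, T, D),  # (num_audios, T, D)
        audio_attention_mask=torch.ones(1, T, D, dtype=torch.bool),
    )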

View File

@@ -40,6 +40,7 @@ from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_loader.weight_utils import default_weight_loader
from sglang.srt.models.idefics2 import Idefics2VisionTransformer
from sglang.srt.models.llama import LlamaForCausalLM
+ from sglang.srt.models.phi4mm_audio import AudioEmbedding
logger = logging.getLogger(__name__)
@@ -420,16 +421,49 @@ class Phi4MMForCausalLM(nn.Module):
model_dir=config._name_or_path,
)
+ if isinstance(config.embd_layer["audio_embd_layer"], dict):
+ embedding_config = {
+ "embedding_cls": config.embd_layer["audio_embd_layer"]["embedding_cls"],
+ **config.embd_layer["audio_embd_layer"],
+ }
+ else:
+ embedding_config = {"embedding_cls": config.embd_layer["embedding_cls"]}
+ self.embed_tokens_extend = AudioEmbedding(config, **embedding_config)
def get_image_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor:
dtype = next(self.vision_encoder.parameters()).dtype
pixel_values = torch.cat([item.feature for item in items], dim=0).type(dtype)
- image_attention_mask = torch.cat([item.image_emb_mask for item in items], dim=0)
+ image_attention_mask = torch.cat(
+ [item.image_attention_mask for item in items], dim=0
+ )
image_sizes = torch.cat([item.image_sizes for item in items], dim=0)
image_embeds = self.vision_encoder(
pixel_values, image_sizes, image_attention_mask
)
return torch.cat(image_embeds).type(dtype)
+ def get_audio_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor:
+ # Audio features are nested: the first dim is the example dim
+ # (e.g. multiple examples) and the second dim is the multi-audio dim
+ # (e.g. multiple audios in the same example)
+ embed_tokens_extend_param = next(self.embed_tokens_extend.parameters())
+ device = embed_tokens_extend_param.device
+ dtype = embed_tokens_extend_param.dtype
+ audio_embeds = [
+ self.embed_tokens_extend(
+ # item.feature: (num_audios_in_a_sequence, T, D)
+ # item.audio_attention_mask: (num_audios_in_a_sequence, T, D) BoolTensor or None
+ audio_features=item.feature.to(device).type(dtype),
+ audio_attention_mask=(
+ item.audio_attention_mask.to(device)
+ if item.audio_attention_mask is not None
+ else None
+ ),
+ )
+ for item in items
+ ]
+ return torch.cat(audio_embeds).type(dtype)
def forward(
self,
input_ids: torch.Tensor,
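get_audio_feature moves each item's features onto the embedding module's device and dtype, embeds them, and concatenates the per-item results along the first (token) dimension. A self-contained sketch of that concatenation pattern, with a dummy encoder standing in for AudioEmbedding (all sizes are assumptions):

    import torch

    def dummy_audio_encoder(feats: torch.Tensor) -> torch.Tensor:
        # Stand-in for AudioEmbedding: (num_audios, T, D) -> (num_audios * 8, 3072),
        # i.e. a fixed 8 embedding rows per audio clip (hypothetical sizes).
        return torch.randn(feats.shape[0] * 8, 3072)

    items = [torch.randn(2, 500, 80), torch.randn(1, 300, 80)]  # two requests
    audio_embeds = [dummy_audio_encoder(f) for f in items]
    print(torch.cat(audio_embeds).shape)  # torch.Size([24, 3072])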
@@ -443,6 +477,7 @@ class Phi4MMForCausalLM(nn.Module):
language_model=self.language_model,
data_embedding_funcs={
Modality.IMAGE: self.get_image_feature,
+ Modality.AUDIO: self.get_audio_feature,
},
positions=positions,
)
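forward registers one embedding function per modality, and the shared embedding routine presumably dispatches each multimodal item by looking up its modality in this table. A runnable sketch of that lookup pattern (the stand-in types below are hypothetical):

    from dataclasses import dataclass
    from enum import Enum, auto

    class Modality(Enum):  # stand-in for sglang's Modality
        IMAGE = auto()
        AUDIO = auto()

    @dataclass
    class Item:  # stand-in for MultimodalDataItem
        modality: Modality

    def embed_all(items, data_embedding_funcs):
        # Route each item to the function registered for its modality.
        return [data_embedding_funcs[it.modality]([it]) for it in items]

    funcs = {
        Modality.IMAGE: lambda x: "image-embeds",
        Modality.AUDIO: lambda x: "audio-embeds",
    }
    print(embed_all([Item(Modality.AUDIO), Item(Modality.IMAGE)], funcs))
    # ['audio-embeds', 'image-embeds']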
@@ -464,6 +499,9 @@ class Phi4MMForCausalLM(nn.Module):
(".self_attn.qkv_proj", ".self_attn.v_proj", "v"),
]
prefix_mapping = {
"model.embed_tokens_extend.audio_embed.audio_projection.vision.": "embed_tokens_extend.audio_projection_for_vision.",
"model.embed_tokens_extend.audio_embed.audio_projection.speech.": "embed_tokens_extend.audio_projection.",
"model.embed_tokens_extend.audio_embed.": "embed_tokens_extend.",
"model.embed_tokens_extend.image_embed.": "vision_encoder.",
"model.": "language_model.model.",
}
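The prefix map rewrites checkpoint parameter names onto the module tree above. Listing the specific audio_embed prefixes before the generic "model." rule matters, since the first matching prefix should win. A small sketch of such a first-match rewrite (the helper name is hypothetical):

    def remap_name(name: str, prefix_mapping: dict) -> str:
        # First matching prefix wins, so insertion order encodes specificity.
        for old, new in prefix_mapping.items():
            if name.startswith(old):
                return new + name[len(old):]
        return name

    prefix_mapping = {
        "model.embed_tokens_extend.audio_embed.": "embed_tokens_extend.",
        "model.": "language_model.model.",
    }
    print(remap_name("model.embed_tokens_extend.audio_embed.proj.weight", prefix_mapping))
    # embed_tokens_extend.proj.weight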
@@ -472,7 +510,6 @@ class Phi4MMForCausalLM(nn.Module):
"img_processor.encoder.layers.26",
"img_processor.head",
"img_processor.post_layernorm",
"audio",
]
def _should_skip(name: str) -> bool:
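Since audio weights are now loaded rather than ignored, "audio" drops out of the skip list. The helper's body is not shown in this hunk; a minimal sketch of such a check, assuming simple substring matching against the list above (the real implementation may differ):

    skip_patterns = [
        "img_processor.encoder.layers.26",
        "img_processor.head",
        "img_processor.post_layernorm",
    ]

    def _should_skip(name: str) -> bool:
        # Skip checkpoint tensors that belong to pruned submodules.
        return any(pattern in name for pattern in skip_patterns)

    print(_should_skip("vision_encoder.img_processor.head.mlp.weight"))  # True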

File diff suppressed because it is too large

File diff suppressed because it is too large

View File

@@ -158,6 +158,7 @@ class BaseMultimodalProcessor(ABC):
"pixel_values_videos": Modality.VIDEO,
"image_sizes": Modality.IMAGE,
"image_grid_thw": Modality.IMAGE,
"image_attention_mask": Modality.IMAGE,
"image_emb_mask": Modality.IMAGE,
"image_spatial_crop": Modality.IMAGE,
"tgt_size": Modality.IMAGE,
@@ -170,6 +171,7 @@ class BaseMultimodalProcessor(ABC):
"audio_feature_lens": Modality.AUDIO,
"input_features": Modality.AUDIO,
"input_features_mask": Modality.AUDIO,
"audio_attention_mask": Modality.AUDIO,
# Video-related attributes
"video_grid_thw": Modality.VIDEO,
# Generic attributes that could apply to multiple modalities
@@ -251,7 +253,11 @@ class BaseMultimodalProcessor(ABC):
@staticmethod
def _load_single_item(
- data, modality: Modality, frame_count_limit=None, discard_alpha_channel=True
+ data,
+ modality: Modality,
+ frame_count_limit=None,
+ audio_sample_rate: Optional[int] = None,
+ discard_alpha_channel=True,
):
"""
Load a single multimodal data item.
@@ -268,7 +274,7 @@ class BaseMultimodalProcessor(ABC):
elif modality == Modality.VIDEO:
return load_video(data, frame_count_limit)
elif modality == Modality.AUDIO:
- return load_audio(data)
+ return load_audio(data, audio_sample_rate)
except Exception as e:
raise RuntimeError(f"Error while loading data {data}: {e}")
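With the new parameter, callers can request a target sample rate at load time. A hedged usage sketch (the WAV file is hypothetical; the imports match the modules shown in this diff):

    from sglang.srt.managers.schedule_batch import Modality
    from sglang.srt.multimodal.processors.base_processor import BaseMultimodalProcessor

    with open("speech.wav", "rb") as f:  # hypothetical input file
        audio_bytes = f.read()

    audio = BaseMultimodalProcessor._load_single_item(
        audio_bytes,
        Modality.AUDIO,
        audio_sample_rate=16000,  # e.g. Phi4-mm's expected rate
    )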
@@ -282,6 +288,7 @@ class BaseMultimodalProcessor(ABC):
image_estimated_frames_iter: Optional[iter] = None,
image_scaling_factor: float = 1.0,
max_image_frames: int = 30,
+ audio_sample_rate: Optional[int] = None,
) -> Tuple[List, List]:
"""
Load multimodal data in parallel using iterators.
@@ -324,6 +331,7 @@ class BaseMultimodalProcessor(ABC):
data,
modality,
frame_count_limit,
+ audio_sample_rate,
discard_alpha_channel,
)
)
@@ -352,6 +360,7 @@ class BaseMultimodalProcessor(ABC):
audio_data: Optional[list] = None,
return_text: Optional[bool] = True,
discard_alpha_channel: bool = True,
+ audio_sample_rate: Optional[int] = None,
) -> BaseMultiModalProcessorOutput:
"""
Each video frame or image will be replaced by a single image token
@@ -390,6 +399,7 @@ class BaseMultimodalProcessor(ABC):
multimodal_tokens=multimodal_tokens,
data_iterators=data_iterators,
discard_alpha_channel=discard_alpha_channel,
+ audio_sample_rate=audio_sample_rate,
)
task_info_iter = iter(task_info)
futures_iter = iter(futures)
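Taken together, these hunks thread audio_sample_rate from the public load_mm_data entry point through the parallel task submission down into load_audio. A simplified, self-contained schematic of that flow (stand-in signatures, not the real ones):

    def load_audio(data, sr=None):
        return ("decoded-audio", sr or 16000)  # stand-in for the real loader

    def _load_single_item(data, modality, audio_sample_rate=None):
        if modality == "audio":
            return load_audio(data, audio_sample_rate)

    def load_mm_data(audio_data, audio_sample_rate=None):
        return [_load_single_item(d, "audio", audio_sample_rate) for d in audio_data]

    print(load_mm_data([b"..."], audio_sample_rate=16000))
    # [('decoded-audio', 16000)]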

View File

@@ -1,6 +1,8 @@
import logging
from typing import List, Union
+ from transformers.processing_utils import ProcessorMixin
from sglang.srt.managers.schedule_batch import Modality, MultimodalDataItem
from sglang.srt.models.phi4mm import Phi4MMForCausalLM
from sglang.srt.multimodal.processors.base_processor import (
@@ -10,18 +12,58 @@ from sglang.srt.multimodal.processors.base_processor import (
logger = logging.getLogger(__name__)
- _IMAGE_SPECIAL_TOKEN = "<|endoftext10|>"
- _IMAGE_SPECIAL_TOKEN_ID = 200010
+ # An adapter over the HF Phi-4-MM processor that makes it work with SGLang
+ # Ref: https://huggingface.co/microsoft/Phi-4-multimodal-instruct/blob/main/processing_phi4mm.py#L693
+ class Phi4MMProcessorAdapter(ProcessorMixin):
+ def __init__(self, _processor) -> None:
+ self._processor = _processor
+ def __call__(self, **kwargs):
+ result = self._processor(**kwargs)
+ # Map Hugging Face output keys to SGLang's standard keys
+ key_mapping = {
+ "input_image_embeds": "pixel_values",
+ "input_audio_embeds": "audio_features",
+ "audio_embed_sizes": "audio_feature_lens",
+ }
+ for hf_key, sglang_key in key_mapping.items():
+ if hf_key in result:
+ result[sglang_key] = result[hf_key]
+ # Filter out None or empty tensors from the result.
+ # This prevents the SGLang helper base_processor.collect_mm_items_from_processor_output()
+ # from misclassifying audio content as image content, and vice versa.
+ filtered_result = {
+ k: v
+ for k, v in result.items()
+ if v is not None and (not hasattr(v, "numel") or v.numel() > 0)
+ }
+ return filtered_result
- class Phi4MMImageProcessor(BaseMultimodalProcessor):
+ class Phi4MMMultimodalProcessor(BaseMultimodalProcessor):
models = [Phi4MMForCausalLM]
def __init__(self, hf_config, server_args, _processor):
- super().__init__(hf_config, server_args, _processor)
+ self.processor = Phi4MMProcessorAdapter(_processor)
+ super().__init__(hf_config, server_args, self.processor)
+ # The following constants come from Hugging Face microsoft/Phi-4-multimodal-instruct's processing_phi4mm.py file
+ # ref: https://huggingface.co/microsoft/Phi-4-multimodal-instruct/blob/main/processing_phi4mm.py
+ self.IMAGE_TOKEN = "<|endoftext10|>"
+ self.AUDIO_TOKEN = "<|endoftext11|>"
+ self.IM_TOKEN_ID = 200010
+ self.AUDIO_TOKEN_ID = 200011
+ self.AUDIO_SAMPLE_RATE = 16000
self.multimodal_tokens = MultimodalSpecialTokens(
- image_token=_IMAGE_SPECIAL_TOKEN,
- ).build(_processor)
+ image_token=self.IMAGE_TOKEN,
+ image_token_id=self.IM_TOKEN_ID,
+ audio_token=self.AUDIO_TOKEN,
+ audio_token_id=self.AUDIO_TOKEN_ID,
+ ).build(self.processor)
async def process_mm_data_async(
self,
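The adapter renames the HF processor's output keys to SGLang's standard names and drops None or empty tensors, so collect_mm_items_from_processor_output cannot misclassify one modality as another. A sketch exercising that logic with a stub in place of the real HF processor (the stub and its tensors are hypothetical):

    import torch

    class StubProcessor:  # hypothetical stand-in for the HF Phi-4-MM processor
        def __call__(self, **kwargs):
            return {
                "input_ids": torch.tensor([[1, 2, 3]]),
                "input_audio_embeds": torch.randn(1, 500, 80),
                "input_image_embeds": torch.empty(0),  # empty -> filtered out
            }

    adapter = Phi4MMProcessorAdapter(StubProcessor())
    out = adapter(text="hello")
    print(sorted(out))
    # ['audio_features', 'input_audio_embeds', 'input_ids']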
@@ -32,46 +74,29 @@ class Phi4MMImageProcessor(BaseMultimodalProcessor):
max_req_input_len,
**kwargs,
):
- if audio_data:
- logger.warning(
- "Currently SGLang does not support audio data for Phi4MM. We are working on it. You can file an issue to help us prioritize."
- )
- audio_data = []
base_output = self.load_mm_data(
prompt=input_text,
max_req_input_len=max_req_input_len,
audio_data=audio_data,
image_data=image_data,
multimodal_tokens=self.multimodal_tokens,
- )
- if base_output is None:
- return None
- res = self.process_mm_data(
- input_text=base_output.input_text,
- images=base_output.images,
- audios=base_output.audios,
+ audio_sample_rate=self.AUDIO_SAMPLE_RATE,
)
- input_ids = res["input_ids"].flatten()
- image_offsets = self.get_mm_items_offset(
- input_ids=input_ids,
- mm_token_id=_IMAGE_SPECIAL_TOKEN_ID,
- )
+ if base_output.audios is not None:
+ # Hugging Face microsoft/Phi-4-multimodal-instruct's processing_phi4mm.py requires each audio input to be a tuple of (audio, sample_rate)
+ # ref: https://huggingface.co/microsoft/Phi-4-multimodal-instruct/blob/main/processing_phi4mm.py
+ base_output.audios = [
+ (audio, self.AUDIO_SAMPLE_RATE) for audio in base_output.audios
+ ]
- items = [
- MultimodalDataItem(
- feature=res["input_image_embeds"],
- image_sizes=res["image_sizes"],
- image_emb_mask=res["image_attention_mask"],
- offsets=image_offsets,
- modality=Modality.IMAGE,
- )
- ]
+ mm_items, input_ids, _ = self.process_and_combine_mm_data(
+ base_output, self.multimodal_tokens
+ )
return {
"mm_items": items,
"input_ids": input_ids.tolist(),
"im_token_id": _IMAGE_SPECIAL_TOKEN_ID,
"mm_items": mm_items,
"im_token_id": self.IM_TOKEN_ID,
"audio_token_id": self.AUDIO_TOKEN_ID,
}
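One detail worth noting: the HF Phi-4-MM processor expects each audio as an (array, sample_rate) tuple, which is exactly what the wrapping above produces. A tiny runnable illustration with dummy data:

    import numpy as np

    AUDIO_SAMPLE_RATE = 16000
    audios = [np.zeros(AUDIO_SAMPLE_RATE, dtype=np.float32)]  # 1 s of silence
    audios = [(audio, AUDIO_SAMPLE_RATE) for audio in audios]
    print(type(audios[0]).__name__, audios[0][1])  # tuple 16000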

View File

@@ -691,12 +691,17 @@ def decode_video_base64(video_base64):
) # Return an empty array and size tuple if no frames were found
- def load_audio(audio_file: str, sr: int = 16000, mono: bool = True) -> np.ndarray:
+ def load_audio(
+ audio_file: str, sr: Optional[int] = None, mono: bool = True
+ ) -> np.ndarray:
# Use soundfile here, since librosa uses it under the hood,
# and librosa will drop support for audio loading in the future
import soundfile as sf
from scipy.signal import resample
+ if sr is None:
+ sr = 16000
# Load audio data
if isinstance(audio_file, bytes):
audio, original_sr = sf.read(BytesIO(audio_file))
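After the read, the loader presumably resamples from the file's native rate to the requested sr, which the scipy.signal import above suggests. A hedged sketch of that step (assumed, not taken from this diff):

    import numpy as np
    from scipy.signal import resample

    def resample_audio(audio: np.ndarray, original_sr: int, sr: int) -> np.ndarray:
        # Stretch the signal to the proportional number of samples.
        if original_sr == sr:
            return audio
        return resample(audio, int(round(len(audio) * sr / original_sr)))

    one_second_at_44k = np.zeros(44100, dtype=np.float32)
    print(resample_audio(one_second_at_44k, 44100, 16000).shape)  # (16000,)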