Feat: Support audio in Phi4-mm model (#8048)
This commit is contained in:
@@ -729,6 +729,7 @@ register_conv_template(
|
||||
sep="<|end|>",
|
||||
stop_str="<|end|>",
|
||||
image_token="<|endoftext10|>",
|
||||
audio_token="<|endoftext11|>",
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@@ -239,6 +239,10 @@ class MultimodalDataItem:
|
||||
# For gemma3n
|
||||
input_features_mask: Optional[torch.Tensor] = None
|
||||
|
||||
# For phi4-mm
|
||||
image_attention_mask: Optional[torch.Tensor] = None
|
||||
audio_attention_mask: Optional[torch.Tensor] = None
|
||||
|
||||
@staticmethod
|
||||
def is_empty_list(l):
|
||||
if l is None:
|
||||
|
||||
@@ -40,6 +40,7 @@ from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
||||
from sglang.srt.model_loader.weight_utils import default_weight_loader
|
||||
from sglang.srt.models.idefics2 import Idefics2VisionTransformer
|
||||
from sglang.srt.models.llama import LlamaForCausalLM
|
||||
from sglang.srt.models.phi4mm_audio import AudioEmbedding
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -420,16 +421,49 @@ class Phi4MMForCausalLM(nn.Module):
|
||||
model_dir=config._name_or_path,
|
||||
)
|
||||
|
||||
if isinstance(config.embd_layer["audio_embd_layer"], dict):
|
||||
embedding_config = {
|
||||
"embedding_cls": config.embd_layer["audio_embd_layer"]["embedding_cls"],
|
||||
**config.embd_layer["audio_embd_layer"],
|
||||
}
|
||||
else:
|
||||
embedding_config = {"embedding_cls": config.embd_layer["embedding_cls"]}
|
||||
|
||||
self.embed_tokens_extend = AudioEmbedding(config, **embedding_config)
|
||||
|
||||
def get_image_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor:
|
||||
dtype = next(self.vision_encoder.parameters()).dtype
|
||||
pixel_values = torch.cat([item.feature for item in items], dim=0).type(dtype)
|
||||
image_attention_mask = torch.cat([item.image_emb_mask for item in items], dim=0)
|
||||
image_attention_mask = torch.cat(
|
||||
[item.image_attention_mask for item in items], dim=0
|
||||
)
|
||||
image_sizes = torch.cat([item.image_sizes for item in items], dim=0)
|
||||
image_embeds = self.vision_encoder(
|
||||
pixel_values, image_sizes, image_attention_mask
|
||||
)
|
||||
return torch.cat(image_embeds).type(dtype)
|
||||
|
||||
def get_audio_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor:
|
||||
# (e.g. multiple examples) and the second dim is the multi-audio dim
|
||||
# (e.g. multiple audios in the same example)
|
||||
embed_tokens_extend_param = next(self.embed_tokens_extend.parameters())
|
||||
device = embed_tokens_extend_param.device
|
||||
dtype = embed_tokens_extend_param.dtype
|
||||
audio_embeds = [
|
||||
self.embed_tokens_extend(
|
||||
# item.feature: (num_audios_in_a_sequence, T, D)
|
||||
# item.audio_attention_mask: (num_audios_in_a_sequence, T, D) BoolTensor or None
|
||||
audio_features=item.feature.to(device).type(dtype),
|
||||
audio_attention_mask=(
|
||||
item.audio_attention_mask.to(device)
|
||||
if item.audio_attention_mask is not None
|
||||
else None
|
||||
),
|
||||
)
|
||||
for item in items
|
||||
]
|
||||
return torch.cat(audio_embeds).type(dtype)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
@@ -443,6 +477,7 @@ class Phi4MMForCausalLM(nn.Module):
|
||||
language_model=self.language_model,
|
||||
data_embedding_funcs={
|
||||
Modality.IMAGE: self.get_image_feature,
|
||||
Modality.AUDIO: self.get_audio_feature,
|
||||
},
|
||||
positions=positions,
|
||||
)
|
||||
@@ -464,6 +499,9 @@ class Phi4MMForCausalLM(nn.Module):
|
||||
(".self_attn.qkv_proj", ".self_attn.v_proj", "v"),
|
||||
]
|
||||
prefix_mapping = {
|
||||
"model.embed_tokens_extend.audio_embed.audio_projection.vision.": "embed_tokens_extend.audio_projection_for_vision.",
|
||||
"model.embed_tokens_extend.audio_embed.audio_projection.speech.": "embed_tokens_extend.audio_projection.",
|
||||
"model.embed_tokens_extend.audio_embed.": "embed_tokens_extend.",
|
||||
"model.embed_tokens_extend.image_embed.": "vision_encoder.",
|
||||
"model.": "language_model.model.",
|
||||
}
|
||||
@@ -472,7 +510,6 @@ class Phi4MMForCausalLM(nn.Module):
|
||||
"img_processor.encoder.layers.26",
|
||||
"img_processor.head",
|
||||
"img_processor.post_layernorm",
|
||||
"audio",
|
||||
]
|
||||
|
||||
def _should_skip(name: str) -> bool:
|
||||
|
||||
1260
python/sglang/srt/models/phi4mm_audio.py
Normal file
1260
python/sglang/srt/models/phi4mm_audio.py
Normal file
File diff suppressed because it is too large
Load Diff
1917
python/sglang/srt/models/phi4mm_utils.py
Normal file
1917
python/sglang/srt/models/phi4mm_utils.py
Normal file
File diff suppressed because it is too large
Load Diff
@@ -158,6 +158,7 @@ class BaseMultimodalProcessor(ABC):
|
||||
"pixel_values_videos": Modality.VIDEO,
|
||||
"image_sizes": Modality.IMAGE,
|
||||
"image_grid_thw": Modality.IMAGE,
|
||||
"image_attention_mask": Modality.IMAGE,
|
||||
"image_emb_mask": Modality.IMAGE,
|
||||
"image_spatial_crop": Modality.IMAGE,
|
||||
"tgt_size": Modality.IMAGE,
|
||||
@@ -170,6 +171,7 @@ class BaseMultimodalProcessor(ABC):
|
||||
"audio_feature_lens": Modality.AUDIO,
|
||||
"input_features": Modality.AUDIO,
|
||||
"input_features_mask": Modality.AUDIO,
|
||||
"audio_attention_mask": Modality.AUDIO,
|
||||
# Video-related attributes
|
||||
"video_grid_thw": Modality.VIDEO,
|
||||
# Generic attributes that could apply to multiple modalities
|
||||
@@ -251,7 +253,11 @@ class BaseMultimodalProcessor(ABC):
|
||||
|
||||
@staticmethod
|
||||
def _load_single_item(
|
||||
data, modality: Modality, frame_count_limit=None, discard_alpha_channel=True
|
||||
data,
|
||||
modality: Modality,
|
||||
frame_count_limit=None,
|
||||
audio_sample_rate: Optional[int] = None,
|
||||
discard_alpha_channel=True,
|
||||
):
|
||||
"""
|
||||
Load a single multimodal data.
|
||||
@@ -268,7 +274,7 @@ class BaseMultimodalProcessor(ABC):
|
||||
elif modality == Modality.VIDEO:
|
||||
return load_video(data, frame_count_limit)
|
||||
elif modality == Modality.AUDIO:
|
||||
return load_audio(data)
|
||||
return load_audio(data, audio_sample_rate)
|
||||
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"Error while loading data {data}: {e}")
|
||||
@@ -282,6 +288,7 @@ class BaseMultimodalProcessor(ABC):
|
||||
image_estimated_frames_iter: Optional[iter] = None,
|
||||
image_scaling_factor: float = 1.0,
|
||||
max_image_frames: int = 30,
|
||||
audio_sample_rate: Optional[int] = None,
|
||||
) -> Tuple[List, List]:
|
||||
"""
|
||||
load multimodal data parallelly using iterators.
|
||||
@@ -324,6 +331,7 @@ class BaseMultimodalProcessor(ABC):
|
||||
data,
|
||||
modality,
|
||||
frame_count_limit,
|
||||
audio_sample_rate,
|
||||
discard_alpha_channel,
|
||||
)
|
||||
)
|
||||
@@ -352,6 +360,7 @@ class BaseMultimodalProcessor(ABC):
|
||||
audio_data: Optional[list] = None,
|
||||
return_text: Optional[bool] = True,
|
||||
discard_alpha_channel: bool = True,
|
||||
audio_sample_rate: Optional[int] = None,
|
||||
) -> BaseMultiModalProcessorOutput:
|
||||
"""
|
||||
Each frame of video/image will be replaced by a single image token
|
||||
@@ -390,6 +399,7 @@ class BaseMultimodalProcessor(ABC):
|
||||
multimodal_tokens=multimodal_tokens,
|
||||
data_iterators=data_iterators,
|
||||
discard_alpha_channel=discard_alpha_channel,
|
||||
audio_sample_rate=audio_sample_rate,
|
||||
)
|
||||
task_info_iter = iter(task_info)
|
||||
futures_iter = iter(futures)
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
import logging
|
||||
from typing import List, Union
|
||||
|
||||
from transformers.processing_utils import ProcessorMixin
|
||||
|
||||
from sglang.srt.managers.schedule_batch import Modality, MultimodalDataItem
|
||||
from sglang.srt.models.phi4mm import Phi4MMForCausalLM
|
||||
from sglang.srt.multimodal.processors.base_processor import (
|
||||
@@ -10,18 +12,58 @@ from sglang.srt.multimodal.processors.base_processor import (
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_IMAGE_SPECIAL_TOKEN = "<|endoftext10|>"
|
||||
_IMAGE_SPECIAL_TOKEN_ID = 200010
|
||||
|
||||
# It is an adapter of hf phi4 mm processor to make it work for sglang
|
||||
# Ref: https://huggingface.co/microsoft/Phi-4-multimodal-instruct/blob/main/processing_phi4mm.py#L693
|
||||
class Phi4MMProcessorAdapter(ProcessorMixin):
|
||||
def __init__(self, _processor) -> None:
|
||||
self._processor = _processor
|
||||
|
||||
def __call__(self, **kwargs):
|
||||
result = self._processor(**kwargs)
|
||||
|
||||
# Map HuggingFace output keys to sglang standard keys
|
||||
key_mapping = {
|
||||
"input_image_embeds": "pixel_values",
|
||||
"input_audio_embeds": "audio_features",
|
||||
"audio_embed_sizes": "audio_feature_lens",
|
||||
}
|
||||
for hf_key, sglang_key in key_mapping.items():
|
||||
if hf_key in result:
|
||||
result[sglang_key] = result[hf_key]
|
||||
|
||||
# Filter out None or empty tensors from the result.
|
||||
# This prevents the sglang function base_processor.collect_mm_items_from_processor_output()
|
||||
# from misclassifying audio content as image content, and vice versa.
|
||||
filtered_result = {
|
||||
k: v
|
||||
for k, v in result.items()
|
||||
if v is not None and (not hasattr(v, "numel") or v.numel() > 0)
|
||||
}
|
||||
return filtered_result
|
||||
|
||||
|
||||
class Phi4MMImageProcessor(BaseMultimodalProcessor):
|
||||
class Phi4MMMultimodalProcessor(BaseMultimodalProcessor):
|
||||
models = [Phi4MMForCausalLM]
|
||||
|
||||
def __init__(self, hf_config, server_args, _processor):
|
||||
super().__init__(hf_config, server_args, _processor)
|
||||
self.processor = Phi4MMProcessorAdapter(_processor)
|
||||
super().__init__(hf_config, server_args, self.processor)
|
||||
|
||||
# the following CONSTANTS come from hugging-face microsoft/Phi-4-multimodal-instruct's processing_phi4mm.py file
|
||||
# ref: https://huggingface.co/microsoft/Phi-4-multimodal-instruct/blob/main/processing_phi4mm.py
|
||||
self.IMAGE_TOKEN = "<|endoftext10|>"
|
||||
self.AUDIO_TOKEN = "<|endoftext11|>"
|
||||
self.IM_TOKEN_ID = 200010
|
||||
self.AUDIO_TOKEN_ID = 200011
|
||||
self.AUDIO_SAMPLE_RATE = 16000
|
||||
|
||||
self.multimodal_tokens = MultimodalSpecialTokens(
|
||||
image_token=_IMAGE_SPECIAL_TOKEN,
|
||||
).build(_processor)
|
||||
image_token=self.IMAGE_TOKEN,
|
||||
image_token_id=self.IM_TOKEN_ID,
|
||||
audio_token=self.AUDIO_TOKEN,
|
||||
audio_token_id=self.AUDIO_TOKEN_ID,
|
||||
).build(self.processor)
|
||||
|
||||
async def process_mm_data_async(
|
||||
self,
|
||||
@@ -32,46 +74,29 @@ class Phi4MMImageProcessor(BaseMultimodalProcessor):
|
||||
max_req_input_len,
|
||||
**kwargs,
|
||||
):
|
||||
if audio_data:
|
||||
logger.warning(
|
||||
"Currently SGLang does not support audio data for Phi4MM. We are working on it. You can file an issue to help us prioritize."
|
||||
)
|
||||
audio_data = []
|
||||
|
||||
base_output = self.load_mm_data(
|
||||
prompt=input_text,
|
||||
max_req_input_len=max_req_input_len,
|
||||
audio_data=audio_data,
|
||||
image_data=image_data,
|
||||
multimodal_tokens=self.multimodal_tokens,
|
||||
)
|
||||
if base_output is None:
|
||||
return None
|
||||
|
||||
res = self.process_mm_data(
|
||||
input_text=base_output.input_text,
|
||||
images=base_output.images,
|
||||
audios=base_output.audios,
|
||||
audio_sample_rate=self.AUDIO_SAMPLE_RATE,
|
||||
)
|
||||
|
||||
input_ids = res["input_ids"].flatten()
|
||||
image_offsets = self.get_mm_items_offset(
|
||||
input_ids=input_ids,
|
||||
mm_token_id=_IMAGE_SPECIAL_TOKEN_ID,
|
||||
)
|
||||
if base_output.audios is not None:
|
||||
# hugging-face microsoft/Phi-4-multimodal-instruct's processing_phi4mm.py file requires the audio input to be tuple of (audio, sample_rate)
|
||||
# ref: https://huggingface.co/microsoft/Phi-4-multimodal-instruct/blob/main/processing_phi4mm.py
|
||||
base_output.audios = [
|
||||
(audio, self.AUDIO_SAMPLE_RATE) for audio in base_output.audios
|
||||
]
|
||||
|
||||
items = [
|
||||
MultimodalDataItem(
|
||||
feature=res["input_image_embeds"],
|
||||
image_sizes=res["image_sizes"],
|
||||
image_emb_mask=res["image_attention_mask"],
|
||||
offsets=image_offsets,
|
||||
modality=Modality.IMAGE,
|
||||
)
|
||||
]
|
||||
mm_items, input_ids, _ = self.process_and_combine_mm_data(
|
||||
base_output, self.multimodal_tokens
|
||||
)
|
||||
|
||||
return {
|
||||
"mm_items": items,
|
||||
"input_ids": input_ids.tolist(),
|
||||
"im_token_id": _IMAGE_SPECIAL_TOKEN_ID,
|
||||
"mm_items": mm_items,
|
||||
"im_token_id": self.IM_TOKEN_ID,
|
||||
"audio_token_id": self.AUDIO_TOKEN_ID,
|
||||
}
|
||||
|
||||
@@ -691,12 +691,17 @@ def decode_video_base64(video_base64):
|
||||
) # Return an empty array and size tuple if no frames were found
|
||||
|
||||
|
||||
def load_audio(audio_file: str, sr: int = 16000, mono: bool = True) -> np.ndarray:
|
||||
def load_audio(
|
||||
audio_file: str, sr: Optional[int] = None, mono: bool = True
|
||||
) -> np.ndarray:
|
||||
# Use soundfile here, since librosa use it under the hood,
|
||||
# and librosa will not support audio loading in the future
|
||||
import soundfile as sf
|
||||
from scipy.signal import resample
|
||||
|
||||
if sr is None:
|
||||
sr = 16000
|
||||
|
||||
# Load audio data
|
||||
if isinstance(audio_file, bytes):
|
||||
audio, original_sr = sf.read(BytesIO(audio_file))
|
||||
|
||||
Reference in New Issue
Block a user