Feat: Support audio in Phi4-mm model (#8048)
@@ -40,6 +40,7 @@ from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.models.idefics2 import Idefics2VisionTransformer
 from sglang.srt.models.llama import LlamaForCausalLM
+from sglang.srt.models.phi4mm_audio import AudioEmbedding
 
 logger = logging.getLogger(__name__)
 
@@ -420,16 +421,49 @@ class Phi4MMForCausalLM(nn.Module):
             model_dir=config._name_or_path,
         )
 
+        if isinstance(config.embd_layer["audio_embd_layer"], dict):
+            embedding_config = {
+                "embedding_cls": config.embd_layer["audio_embd_layer"]["embedding_cls"],
+                **config.embd_layer["audio_embd_layer"],
+            }
+        else:
+            embedding_config = {"embedding_cls": config.embd_layer["embedding_cls"]}
+
+        self.embed_tokens_extend = AudioEmbedding(config, **embedding_config)
+
     def get_image_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor:
         dtype = next(self.vision_encoder.parameters()).dtype
         pixel_values = torch.cat([item.feature for item in items], dim=0).type(dtype)
-        image_attention_mask = torch.cat([item.image_emb_mask for item in items], dim=0)
+        image_attention_mask = torch.cat(
+            [item.image_attention_mask for item in items], dim=0
+        )
         image_sizes = torch.cat([item.image_sizes for item in items], dim=0)
         image_embeds = self.vision_encoder(
             pixel_values, image_sizes, image_attention_mask
         )
         return torch.cat(image_embeds).type(dtype)
 
+    def get_audio_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor:
+        # The first dim is the example dim (e.g. multiple examples) and the
+        # second dim is the multi-audio dim (e.g. multiple audios in the same example)
+        embed_tokens_extend_param = next(self.embed_tokens_extend.parameters())
+        device = embed_tokens_extend_param.device
+        dtype = embed_tokens_extend_param.dtype
+        audio_embeds = [
+            self.embed_tokens_extend(
+                # item.feature: (num_audios_in_a_sequence, T, D)
+                # item.audio_attention_mask: (num_audios_in_a_sequence, T, D) BoolTensor or None
+                audio_features=item.feature.to(device).type(dtype),
+                audio_attention_mask=(
+                    item.audio_attention_mask.to(device)
+                    if item.audio_attention_mask is not None
+                    else None
+                ),
+            )
+            for item in items
+        ]
+        return torch.cat(audio_embeds).type(dtype)
+
     def forward(
         self,
         input_ids: torch.Tensor,
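Side note on the embedding_config branch above: the explicit "embedding_cls" entry is immediately overwritten by the **-splat of the same sub-dict, so in the dict case the result is simply the audio sub-dict itself. A minimal, self-contained sketch with hypothetical config values (keys other than embedding_cls are illustrative, not taken from this diff):

# Hypothetical stand-in for config.embd_layer in a Phi-4-multimodal-style
# config; only "embedding_cls" appears in the diff above.
embd_layer = {
    "audio_embd_layer": {
        "embedding_cls": "audio",
        "projection_cls": "mlp",  # illustrative extra key
    }
}

audio_cfg = embd_layer["audio_embd_layer"]
if isinstance(audio_cfg, dict):
    # The explicit "embedding_cls" key is redundant: the **-splat of
    # audio_cfg carries the same key and overwrites it.
    embedding_config = {"embedding_cls": audio_cfg["embedding_cls"], **audio_cfg}
else:
    embedding_config = {"embedding_cls": embd_layer["embedding_cls"]}

print(embedding_config)  # {'embedding_cls': 'audio', 'projection_cls': 'mlp'}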
@@ -443,6 +477,7 @@ class Phi4MMForCausalLM(nn.Module):
             language_model=self.language_model,
             data_embedding_funcs={
                 Modality.IMAGE: self.get_image_feature,
+                Modality.AUDIO: self.get_audio_feature,
             },
             positions=positions,
        )
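For orientation, data_embedding_funcs is a per-modality dispatch table consumed by sglang's shared multimodal embedding routine; with this change, AUDIO items route to get_audio_feature. A schematic sketch of that dispatch pattern (this is not sglang's actual routine, just the shape of the idea):

from enum import Enum, auto
from typing import Callable, Dict, List

import torch


class Modality(Enum):  # stand-in for sglang's Modality enum
    IMAGE = auto()
    AUDIO = auto()


def embed_items(
    items_by_modality: Dict[Modality, List],
    data_embedding_funcs: Dict[Modality, Callable[[List], torch.Tensor]],
) -> Dict[Modality, torch.Tensor]:
    # Call the registered embedder for each modality present in the batch.
    return {
        modality: data_embedding_funcs[modality](items)
        for modality, items in items_by_modality.items()
        if modality in data_embedding_funcs
    }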
@@ -464,6 +499,9 @@ class Phi4MMForCausalLM(nn.Module):
             (".self_attn.qkv_proj", ".self_attn.v_proj", "v"),
         ]
         prefix_mapping = {
+            "model.embed_tokens_extend.audio_embed.audio_projection.vision.": "embed_tokens_extend.audio_projection_for_vision.",
+            "model.embed_tokens_extend.audio_embed.audio_projection.speech.": "embed_tokens_extend.audio_projection.",
+            "model.embed_tokens_extend.audio_embed.": "embed_tokens_extend.",
             "model.embed_tokens_extend.image_embed.": "vision_encoder.",
             "model.": "language_model.model.",
         }
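The prefix_mapping entries are ordered most-specific first, which matters because the catch-all "model." prefix would otherwise match every checkpoint name. A small illustration of applying the table (the remap helper itself is hypothetical, not part of this diff):

prefix_mapping = {
    "model.embed_tokens_extend.audio_embed.audio_projection.vision.": "embed_tokens_extend.audio_projection_for_vision.",
    "model.embed_tokens_extend.audio_embed.audio_projection.speech.": "embed_tokens_extend.audio_projection.",
    "model.embed_tokens_extend.audio_embed.": "embed_tokens_extend.",
    "model.embed_tokens_extend.image_embed.": "vision_encoder.",
    "model.": "language_model.model.",
}


def remap(name: str) -> str:
    # Hypothetical helper: apply the first (most specific) matching prefix.
    # Dicts preserve insertion order, so "model." is tried last.
    for old, new in prefix_mapping.items():
        if name.startswith(old):
            return new + name[len(old):]
    return name


assert (
    remap("model.embed_tokens_extend.audio_embed.audio_projection.speech.0.weight")
    == "embed_tokens_extend.audio_projection.0.weight"
)
assert (
    remap("model.layers.0.mlp.gate_proj.weight")
    == "language_model.model.layers.0.mlp.gate_proj.weight"
)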
@@ -472,7 +510,6 @@ class Phi4MMForCausalLM(nn.Module):
             "img_processor.encoder.layers.26",
             "img_processor.head",
             "img_processor.post_layernorm",
-            "audio",
         ]
 
         def _should_skip(name: str) -> bool:
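The hunk ends at the _should_skip signature. A plausible body, assuming it simply tests the weight name against the skip list shown above (note the removed "audio" entry: audio weights are now loaded instead of skipped):

skipped_submodules = [  # hypothetical name for the list shown in the hunk
    "img_processor.encoder.layers.26",
    "img_processor.head",
    "img_processor.post_layernorm",
]


def _should_skip(name: str) -> bool:
    # Assumed behavior: skip any checkpoint weight whose name references
    # an unused submodule; "audio" is gone from the list after this PR.
    return any(sub in name for sub in skipped_submodules)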