[Refactor] simplify multimodal data processing (#8107)

Signed-off-by: Xinyuan Tong <justinning0323@outlook.com>
This commit is contained in:
Xinyuan Tong
2025-07-20 21:43:09 -07:00
committed by GitHub
parent c9e8613c97
commit 8430bfe3e9
30 changed files with 297 additions and 421 deletions

View File

@@ -435,7 +435,12 @@ class Phi4MMForCausalLM(nn.Module):
dtype = next(self.vision_encoder.parameters()).dtype
pixel_values = torch.cat([item.feature for item in items], dim=0).type(dtype)
image_attention_mask = torch.cat(
- [item.image_attention_mask for item in items], dim=0
+ [
+ item.image_attention_mask
+ for item in items
+ if hasattr(item, "image_attention_mask")
+ ],
+ dim=0,
)
image_sizes = torch.cat([item.image_sizes for item in items], dim=0)
image_embeds = self.vision_encoder(
@@ -456,7 +461,7 @@ class Phi4MMForCausalLM(nn.Module):
audio_features=item.feature.to(device).type(dtype),
audio_attention_mask=(
item.audio_attention_mask.to(device)
- if item.audio_attention_mask is not None
+ if hasattr(item, "audio_attention_mask")
else None
),
)