[Refactor] simplify multimodal data processing (#8107)
Signed-off-by: Xinyuan Tong <justinning0323@outlook.com>
This commit is contained in:
@@ -435,7 +435,12 @@ class Phi4MMForCausalLM(nn.Module):
|
||||
dtype = next(self.vision_encoder.parameters()).dtype
|
||||
pixel_values = torch.cat([item.feature for item in items], dim=0).type(dtype)
|
||||
image_attention_mask = torch.cat(
|
||||
[item.image_attention_mask for item in items], dim=0
|
||||
[
|
||||
item.image_attention_mask
|
||||
for item in items
|
||||
if hasattr(item, "image_attention_mask")
|
||||
],
|
||||
dim=0,
|
||||
)
|
||||
image_sizes = torch.cat([item.image_sizes for item in items], dim=0)
|
||||
image_embeds = self.vision_encoder(
|
||||
@@ -456,7 +461,7 @@ class Phi4MMForCausalLM(nn.Module):
|
||||
audio_features=item.feature.to(device).type(dtype),
|
||||
audio_attention_mask=(
|
||||
item.audio_attention_mask.to(device)
|
||||
if item.audio_attention_mask is not None
|
||||
if hasattr(item, "audio_attention_mask")
|
||||
else None
|
||||
),
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user