[Bugfix](gemma3_mm): handle flatten_batch constraint for multiple images (#6562)
@@ -288,13 +288,22 @@ class Gemma3ForConditionalGeneration(PreTrainedModel):
                     "MM inputs where only some items are precomputed."
                 )
             return torch.concat([item.precomputed_features for item in items])
-        pixel_values = torch.stack(
-            flatten_nested_list([item.pixel_values for item in items]), dim=0
-        )
-        pixel_values = pixel_values.to(device=self.vision_tower.device)
-        pixel_values = pixel_values.to(dtype=self.language_model.dtype())
-
-        vision_outputs = self.vision_tower(pixel_values=pixel_values)
+        # Process images one by one to handle flatten_batch=True constraint in vision_tower
+        all_pixel_values = flatten_nested_list([item.pixel_values for item in items])
+        vision_outputs_list = []
+
+        for pixel_value in all_pixel_values:
+            # Add batch dimension for single image processing
+            pixel_value_batch = pixel_value.unsqueeze(0)
+            pixel_value_batch = pixel_value_batch.to(device=self.vision_tower.device)
+            pixel_value_batch = pixel_value_batch.to(dtype=self.language_model.dtype())
+
+            vision_output = self.vision_tower(pixel_values=pixel_value_batch)
+            vision_outputs_list.append(vision_output)
+
+        # Concatenate all vision outputs
+        vision_outputs = torch.cat(vision_outputs_list, dim=0)
         image_features = self.multi_modal_projector(vision_outputs)
         return image_features
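
For context, below is a minimal, runnable sketch of the per-image pattern the fix adopts. The DummyVisionTower class, its patches/dim parameters, and the crude pooling inside it are hypothetical stand-ins invented for illustration (the real model uses a SigLIP tower); only the unsqueeze(0) / loop / torch.cat pattern mirrors the diff above. It assumes only PyTorch.

    import torch

    class DummyVisionTower(torch.nn.Module):
        """Hypothetical stand-in that, like a tower built with
        flatten_batch=True, expects one image per forward call
        rather than a pre-stacked batch."""

        def __init__(self, patches: int = 4, dim: int = 8):
            super().__init__()
            self.proj = torch.nn.Linear(3, dim)
            self.patches = patches

        def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
            assert pixel_values.shape[0] == 1, "feed one image per call"
            # (1, 3, H, W) -> (1, 3), crudely pooled for the sketch
            pooled = pixel_values.mean(dim=(2, 3))
            feats = self.proj(pooled)  # (1, dim)
            # (1, patches, dim): fake per-patch features
            return feats.unsqueeze(1).expand(-1, self.patches, -1)

    tower = DummyVisionTower()
    # Images may even differ in spatial size; the loop handles each alone.
    images = [torch.rand(3, 16, 16), torch.rand(3, 24, 24)]

    outputs = []
    for image in images:
        batch = image.unsqueeze(0)  # add the batch dimension
        outputs.append(tower(pixel_values=batch))

    vision_outputs = torch.cat(outputs, dim=0)
    print(vision_outputs.shape)  # torch.Size([2, 4, 8])

Stacking these two images with torch.stack would fail outright (mismatched shapes), and even same-sized images would violate the one-image-per-call constraint; the loop sidesteps both while torch.cat restores the (num_images, ...) batch layout downstream code expects.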