From 7b02c32679dc3ebe341cc1c8d24372fb49e09bb2 Mon Sep 17 00:00:00 2001
From: Chang Su
Date: Fri, 23 May 2025 18:11:54 -0700
Subject: [PATCH] [Bugfix](gemma3_mm): handle flatten_batch constraint for
 multiple images (#6562)

---
 python/sglang/srt/models/gemma3_mm.py | 21 +++++++++++++++------
 1 file changed, 15 insertions(+), 6 deletions(-)

diff --git a/python/sglang/srt/models/gemma3_mm.py b/python/sglang/srt/models/gemma3_mm.py
index a112417c5..7d776da0d 100644
--- a/python/sglang/srt/models/gemma3_mm.py
+++ b/python/sglang/srt/models/gemma3_mm.py
@@ -288,13 +288,22 @@ class Gemma3ForConditionalGeneration(PreTrainedModel):
                 "MM inputs where only some items are precomputed."
             )
             return torch.concat([item.precomputed_features for item in items])
-        pixel_values = torch.stack(
-            flatten_nested_list([item.pixel_values for item in items]), dim=0
-        )
-        pixel_values = pixel_values.to(device=self.vision_tower.device)
-        pixel_values = pixel_values.to(dtype=self.language_model.dtype())
-        vision_outputs = self.vision_tower(pixel_values=pixel_values)
+        # Process images one by one to handle flatten_batch=True constraint in vision_tower
+        all_pixel_values = flatten_nested_list([item.pixel_values for item in items])
+        vision_outputs_list = []
+
+        for pixel_value in all_pixel_values:
+            # Add batch dimension for single image processing
+            pixel_value_batch = pixel_value.unsqueeze(0)
+            pixel_value_batch = pixel_value_batch.to(device=self.vision_tower.device)
+            pixel_value_batch = pixel_value_batch.to(dtype=self.language_model.dtype())
+
+            vision_output = self.vision_tower(pixel_values=pixel_value_batch)
+            vision_outputs_list.append(vision_output)
+
+        # Concatenate all vision outputs
+        vision_outputs = torch.cat(vision_outputs_list, dim=0)
         image_features = self.multi_modal_projector(vision_outputs)
         return image_features