diff --git a/python/sglang/srt/models/gemma3_mm.py b/python/sglang/srt/models/gemma3_mm.py
index a112417c5..7d776da0d 100644
--- a/python/sglang/srt/models/gemma3_mm.py
+++ b/python/sglang/srt/models/gemma3_mm.py
@@ -288,13 +288,22 @@ class Gemma3ForConditionalGeneration(PreTrainedModel):
                     "MM inputs where only some items are precomputed."
                 )
             return torch.concat([item.precomputed_features for item in items])
 
-        pixel_values = torch.stack(
-            flatten_nested_list([item.pixel_values for item in items]), dim=0
-        )
-        pixel_values = pixel_values.to(device=self.vision_tower.device)
-        pixel_values = pixel_values.to(dtype=self.language_model.dtype())
-        vision_outputs = self.vision_tower(pixel_values=pixel_values)
+        # Process images one by one to handle flatten_batch=True constraint in vision_tower
+        all_pixel_values = flatten_nested_list([item.pixel_values for item in items])
+        vision_outputs_list = []
+
+        for pixel_value in all_pixel_values:
+            # Add batch dimension for single image processing
+            pixel_value_batch = pixel_value.unsqueeze(0)
+            pixel_value_batch = pixel_value_batch.to(device=self.vision_tower.device)
+            pixel_value_batch = pixel_value_batch.to(dtype=self.language_model.dtype())
+
+            vision_output = self.vision_tower(pixel_values=pixel_value_batch)
+            vision_outputs_list.append(vision_output)
+
+        # Concatenate all vision outputs
+        vision_outputs = torch.cat(vision_outputs_list, dim=0)
         image_features = self.multi_modal_projector(vision_outputs)
         return image_features
 
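
For context, here is a minimal runnable sketch of the pattern this change adopts: run the vision tower on one image at a time, each wrapped as a batch of one via `unsqueeze(0)`, then concatenate the per-image outputs with `torch.cat`. `ToyVisionTower` and all shapes below are hypothetical stand-ins, not sglang's `SiglipVisionModel` or its real API; the sketch only demonstrates that the loop-and-concatenate form reproduces a single batched call when batching is possible.

```python
import torch
import torch.nn as nn


class ToyVisionTower(nn.Module):
    """Hypothetical stand-in for a vision tower; not sglang's SiglipVisionModel."""

    def __init__(self, embed_dim: int = 16, num_patches: int = 4):
        super().__init__()
        self.proj = nn.Linear(3, embed_dim)
        self.num_patches = num_patches

    def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
        # (batch, 3, H, W) -> (batch, num_patches, embed_dim)
        pooled = pixel_values.mean(dim=(2, 3))  # crude global pool over H, W
        return self.proj(pooled).unsqueeze(1).expand(-1, self.num_patches, -1)


tower = ToyVisionTower()
images = [torch.randn(3, 8, 8) for _ in range(3)]  # single images, no batch dim

# Pattern from the diff: one forward pass per image, each as a batch of one.
outputs = []
for img in images:
    outputs.append(tower(img.unsqueeze(0)))  # unsqueeze(0) adds the batch dim
features = torch.cat(outputs, dim=0)  # (num_images, num_patches, embed_dim)

# Sanity check: matches a single batched call when batching is supported.
assert torch.allclose(features, tower(torch.stack(images, dim=0)))
print(features.shape)  # torch.Size([3, 4, 16])
```

The trade-off is one tower forward per image instead of one per batch, which adds Python-loop overhead but sidesteps the `flatten_batch=True` constraint noted in the new comment.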