diff --git a/python/sglang/srt/models/llava.py b/python/sglang/srt/models/llava.py
index 72107e0eb..b35e902c0 100644
--- a/python/sglang/srt/models/llava.py
+++ b/python/sglang/srt/models/llava.py
@@ -112,24 +112,29 @@ class LlavaLlamaForCausalLM(nn.Module):
             need_vision = need_vision & has_pixel
 
             if need_vision.any():
-                pixel_values = torch.tensor(
-                    np.array([pixel_values[i] for i in range(bs) if need_vision[i]]),
-                    device=self.vision_tower.device,
-                )
+                # Keep only the entries of the requests that still need vision.
+                pixel_values = [pixel_values[i] for i in range(bs) if need_vision[i]]
+                image_sizes = [image_sizes[i] for i in range(bs) if need_vision[i]]
 
                 ########## Encode Image ########
 
-                if pixel_values.ndim == 5:
+                if pixel_values[0].ndim == 4:
                     # llava-hd: BS, num_patch, C=3, H=336, W=336, num_patch obtained from process_images
-                    concat_images = torch.cat(
-                        [image for image in pixel_values], dim=0
-                    )  # ndim=4
+                    # Join every image's patches along axis 0 (result has ndim=4)
+                    # so encode_images is called once; split_sizes undoes this below.
+                    concat_images = torch.tensor(
+                        np.concatenate(pixel_values, axis=0),
+                        device=self.vision_tower.device,
+                    )
                     image_features = self.encode_images(concat_images)
                     split_sizes = [image.shape[0] for image in pixel_values]
                     image_features = torch.split(image_features, split_sizes, dim=0)
                     # hd image_features: BS, num_patch, 576, 4096
                 else:
                     # normal pixel: BS, C=3, H=336, W=336
+                    pixel_values = torch.tensor(
+                        np.array(pixel_values), device=self.vision_tower.device
+                    )
                     image_features = self.encode_images(pixel_values)
                     # image_features: BS, 576, 4096
 