diff --git a/python/sglang/srt/models/llava.py b/python/sglang/srt/models/llava.py
index 9e20a726a..9c1e529f0 100644
--- a/python/sglang/srt/models/llava.py
+++ b/python/sglang/srt/models/llava.py
@@ -185,11 +185,14 @@ class LlavaBaseForCausalLM(nn.Module):
         new_image_features = []
         height = width = self.num_patches_per_side
         for image_idx, image_feature in enumerate(image_features):
-            if modalities_list[image_idx] == 1:
+            if modalities_list[image_idx] == "image":
                 image_aspect_ratio = (
                     self.config.image_aspect_ratio
                 )  # single image
-            else:
+            elif (
+                modalities_list[image_idx] == "multi-images"
+                or modalities_list[image_idx] == "video"
+            ):
                 image_aspect_ratio = "pad"  # multi image
                 # image_aspect_ratio = (
                 #     "anyres" if len(image_sizes[image_idx]) == 1 else "pad"
@@ -319,6 +322,21 @@ class LlavaBaseForCausalLM(nn.Module):
                 .transpose(1, 2)
                 .contiguous()
             )  # N, C, H*W
+            if "unpad" in self.mm_patch_merge_type:
+                image_feature = torch.cat(
+                    (
+                        image_feature,
+                        # Expand to (bs, 1, hidden_dim) and concat at the end of the image tokens
+                        self.language_model.model.image_newline[
+                            None, None
+                        ].expand(
+                            image_feature.shape[0],
+                            1,
+                            image_feature.shape[-1],
+                        ),
+                    ),
+                    dim=1,
+                )
             new_image_features.append(image_feature)
         image_features = new_image_features