[Refactor] Multimodal data processing for VLM (#6659)
Signed-off-by: Xinyuan Tong <justinning0323@outlook.com>
This commit is contained in:
@@ -503,10 +503,10 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module):
|
||||
pixel_values = torch.cat([item.pixel_values for item in items], dim=0).type(
|
||||
self.visual.dtype
|
||||
)
|
||||
image_grid_thws = torch.concat([item.image_grid_thws for item in items], dim=0)
|
||||
image_grid_thw = torch.concat([item.image_grid_thw for item in items], dim=0)
|
||||
assert pixel_values.dim() == 2, pixel_values.dim()
|
||||
assert image_grid_thws.dim() == 2, image_grid_thws.dim()
|
||||
image_embeds = self.visual(pixel_values, grid_thw=image_grid_thws)
|
||||
assert image_grid_thw.dim() == 2, image_grid_thw.dim()
|
||||
image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw)
|
||||
return image_embeds
|
||||
|
||||
def _process_video_input(self, video_input: Qwen2VLVideoInputs) -> torch.Tensor:
|
||||
|
||||
Reference in New Issue
Block a user