[Refactor] Multimodal data processing for VLM (#6659)

Signed-off-by: Xinyuan Tong <justinning0323@outlook.com>
This commit is contained in:
Xinyuan Tong
2025-06-04 11:22:33 -07:00
committed by GitHub
parent bd75690f4e
commit cf9815ba69
11 changed files with 248 additions and 167 deletions

View File

@@ -503,10 +503,10 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module):
pixel_values = torch.cat([item.pixel_values for item in items], dim=0).type(
self.visual.dtype
)
image_grid_thws = torch.concat([item.image_grid_thws for item in items], dim=0)
image_grid_thw = torch.concat([item.image_grid_thw for item in items], dim=0)
assert pixel_values.dim() == 2, pixel_values.dim()
assert image_grid_thws.dim() == 2, image_grid_thws.dim()
image_embeds = self.visual(pixel_values, grid_thw=image_grid_thws)
assert image_grid_thw.dim() == 2, image_grid_thw.dim()
image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw)
return image_embeds
def _process_video_input(self, video_input: Qwen2VLVideoInputs) -> torch.Tensor: