Support precomputed multimodal features for Qwen-VL and Gemma3 models. (#6136)

Co-authored-by: Yury Sulsky <ysulsky@tesla.com>
This commit is contained in:
Yury Sulsky
2025-05-16 12:26:15 -07:00
committed by GitHub
parent c23a7072b6
commit f19a9204cd
14 changed files with 592 additions and 125 deletions

View File

@@ -497,6 +497,12 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module):
return pattern.pad_input_tokens(input_ids, mm_inputs)
def get_image_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor:
if any(item.precomputed_features is not None for item in items):
if not all(item.precomputed_features is not None for item in items):
raise NotImplementedError(
"MM inputs where only some items are precomputed."
)
return torch.concat([item.precomputed_features for item in items])
# in qwen-vl, last dim is the same
pixel_values = torch.cat([item.pixel_values for item in items], dim=0).type(
self.visual.dtype