Files
2025-10-09 16:47:16 +08:00

353 lines
13 KiB
Python

# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Testing suite for the PyTorch ColPali model."""
import collections
import gc
import re
import unittest
from typing import ClassVar
import pytest
import torch
from datasets import load_dataset
from tests.test_configuration_common import ConfigTester
from tests.test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor
from transformers import (
is_torch_available,
)
from transformers.models.colpali.configuration_colpali import ColPaliConfig
from transformers.models.colpali.modeling_colpali import ColPaliForRetrieval, ColPaliForRetrievalOutput
from transformers.models.colpali.processing_colpali import ColPaliProcessor
from transformers.testing_utils import (
backend_empty_cache,
require_torch,
require_vision,
slow,
torch_device,
)
if is_torch_available():
import torch
from transformers.pytorch_utils import id_tensor_storage
class ColPaliForRetrievalModelTester:
def __init__(
self,
parent,
ignore_index=-100,
image_token_index=0,
projector_hidden_act="gelu",
seq_length=25,
vision_feature_select_strategy="default",
vision_feature_layer=-1,
projection_dim=32,
text_config={
"model_type": "gemma",
"seq_length": 128,
"is_training": True,
"use_token_type_ids": False,
"use_labels": True,
"vocab_size": 99,
"hidden_size": 32,
"num_hidden_layers": 2,
"num_attention_heads": 4,
"num_key_value_heads": 1,
"head_dim": 8,
"intermediate_size": 37,
"hidden_activation": "gelu_pytorch_tanh",
"hidden_dropout_prob": 0.1,
"attention_probs_dropout_prob": 0.1,
"max_position_embeddings": 512,
"type_vocab_size": 16,
"type_sequence_label_size": 2,
"initializer_range": 0.02,
"num_labels": 3,
"num_choices": 4,
"pad_token_id": 1,
},
is_training=False,
vision_config={
"use_labels": True,
"image_size": 20,
"patch_size": 5,
"num_image_tokens": 4,
"num_channels": 3,
"is_training": True,
"hidden_size": 32,
"projection_dim": 32,
"num_key_value_heads": 1,
"num_hidden_layers": 2,
"num_attention_heads": 4,
"intermediate_size": 37,
"dropout": 0.1,
"attention_dropout": 0.1,
"initializer_range": 0.02,
},
use_cache=False,
embedding_dim=128,
):
self.parent = parent
self.ignore_index = ignore_index
# `image_token_index` is set to 0 to pass "resize_embeddings" test, do not modify
self.image_token_index = image_token_index
self.projector_hidden_act = projector_hidden_act
self.vision_feature_select_strategy = vision_feature_select_strategy
self.vision_feature_layer = vision_feature_layer
self.text_config = text_config
self.vision_config = vision_config
self.seq_length = seq_length
self.projection_dim = projection_dim
self.pad_token_id = text_config["pad_token_id"]
self.num_hidden_layers = text_config["num_hidden_layers"]
self.vocab_size = text_config["vocab_size"]
self.hidden_size = text_config["hidden_size"]
self.num_attention_heads = text_config["num_attention_heads"]
self.is_training = is_training
self.batch_size = 3
self.num_channels = vision_config["num_channels"]
self.image_size = vision_config["image_size"]
self.encoder_seq_length = seq_length
self.use_cache = use_cache
self.embedding_dim = embedding_dim
self.vlm_config = {
"model_type": "paligemma",
"text_config": self.text_config,
"vision_config": self.vision_config,
"ignore_index": self.ignore_index,
"image_token_index": self.image_token_index,
"projector_hidden_act": self.projector_hidden_act,
"projection_dim": self.projection_dim,
"vision_feature_select_strategy": self.vision_feature_select_strategy,
"vision_feature_layer": self.vision_feature_layer,
}
def get_config(self):
return ColPaliConfig(
vlm_config=self.vlm_config,
embedding_dim=self.embedding_dim,
)
def prepare_config_and_inputs(self):
pixel_values = floats_tensor(
[
self.batch_size,
self.vision_config["num_channels"],
self.vision_config["image_size"],
self.vision_config["image_size"],
]
)
config = self.get_config()
return config, pixel_values
def prepare_config_and_inputs_for_common(self):
config_and_inputs = self.prepare_config_and_inputs()
config, pixel_values = config_and_inputs
input_ids = ids_tensor([self.batch_size, self.seq_length], config.vlm_config.text_config.vocab_size - 1) + 1
attention_mask = input_ids.ne(1).to(torch_device)
# set the 16 first tokens to be image, and ensure that no other tokens are image tokens
# do not change this unless you modified image size or patch size
input_ids[input_ids == config.vlm_config.image_token_index] = self.pad_token_id
input_ids[:, :16] = config.vlm_config.image_token_index
inputs_dict = {
"pixel_values": pixel_values,
"input_ids": input_ids,
"attention_mask": attention_mask,
"labels": input_ids,
"token_type_ids": torch.zeros_like(input_ids),
}
return config, inputs_dict
@require_torch
class ColPaliForRetrievalModelTest(ModelTesterMixin, unittest.TestCase):
"""
Model tester for `ColPaliForRetrieval`.
"""
all_model_classes = (ColPaliForRetrieval,) if is_torch_available() else ()
fx_compatible = False
test_torchscript = False
test_pruning = False
test_resize_embeddings = True
test_head_masking = False
additional_model_inputs = ["token_type_ids"]
def setUp(self):
self.model_tester = ColPaliForRetrievalModelTester(self)
self.config_tester = ConfigTester(self, config_class=ColPaliConfig, has_text_modality=False)
@slow
@require_vision
def test_colpali_forward_inputs(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
model = model_class(config)
model.to(torch_device)
model.eval()
inputs = self._prepare_for_class(inputs_dict, model_class)
with torch.no_grad():
outputs = model(**inputs, return_dict=True)
self.assertIsInstance(outputs, ColPaliForRetrievalOutput)
# ColPali uses a VLM internally which has its state dict keys renames with `conversion_mapping`
# This test is written assuming that `_tied_weights_keys` are not going to be renamed, thus we
# overwrite it. NOTE: ColPali inference/save/load works without issues, it is the testcase
# that makes general assumptions
def test_tied_weights_keys(self):
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
config.get_text_config().tie_word_embeddings = True
for model_class in self.all_model_classes:
model_tied = model_class(config)
ptrs = collections.defaultdict(list)
for name, tensor in model_tied.state_dict().items():
ptrs[id_tensor_storage(tensor)].append(name)
# These are all the pointers of shared tensors.
tied_params = [names for _, names in ptrs.items() if len(names) > 1]
tied_weight_keys = model_tied._tied_weights_keys if model_tied._tied_weights_keys is not None else []
# Detect we get a hit for each key
for key in tied_weight_keys:
key = key.replace(".language_model", "") # remove 'language_model' prefix
is_tied_key = any(re.search(key, p) for group in tied_params for p in group)
self.assertTrue(is_tied_key, f"{key} is not a tied weight key for {model_class}.")
# Removed tied weights found from tied params -> there should only be one left after
for key in tied_weight_keys:
key = key.replace(".language_model", "") # remove 'language_model' prefix
for i in range(len(tied_params)):
tied_params[i] = [p for p in tied_params[i] if re.search(key, p) is None]
tied_params = [group for group in tied_params if len(group) > 1]
self.assertListEqual(
tied_params,
[],
f"Missing `_tied_weights_keys` for {model_class}: add all of {tied_params} except one.",
)
@unittest.skip(
reason="This architecture seem to not compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124"
)
def test_training_gradient_checkpointing(self):
pass
@unittest.skip(
reason="This architecture seem to not compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124"
)
def test_training_gradient_checkpointing_use_reentrant(self):
pass
@unittest.skip(
reason="This architecture seem to not compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124"
)
def test_training_gradient_checkpointing_use_reentrant_false(self):
pass
@unittest.skip(
reason="From PaliGemma: Some undefined behavior encountered with test versions of this model. Skip for now."
)
def test_model_parallelism(self):
pass
@unittest.skip(reason="PaliGemma's SigLip encoder uses a non-standard initialization scheme")
def test_initialization(self):
pass
# TODO extend valid outputs to include this test @Molbap
@unittest.skip(reason="PaliGemma has currently one output format.")
def test_model_outputs_equivalence(self):
pass
@unittest.skip(reason="Pass because ColPali requires `attention_mask is not None`")
def test_sdpa_can_dispatch_on_flash(self):
pass
@unittest.skip(reason="Pass because ColPali requires `attention_mask is not None`")
@pytest.mark.torch_compile_test
def test_sdpa_can_compile_dynamic(self):
pass
@require_torch
class ColPaliModelIntegrationTest(unittest.TestCase):
model_name: ClassVar[str] = "vidore/colpali-v1.2-hf"
def setUp(self):
self.processor = ColPaliProcessor.from_pretrained(self.model_name)
def tearDown(self):
gc.collect()
backend_empty_cache(torch_device)
@slow
def test_model_integration_test(self):
"""
Test if the model is able to retrieve the correct pages for a small and easy dataset.
"""
model = ColPaliForRetrieval.from_pretrained(
self.model_name,
dtype=torch.bfloat16,
device_map=torch_device,
).eval()
# Load the test dataset
ds = load_dataset("hf-internal-testing/document-visual-retrieval-test", split="test")
# Preprocess the examples
batch_images = self.processor(images=ds["image"]).to(torch_device)
batch_queries = self.processor(text=ds["query"]).to(torch_device)
# Run inference
with torch.inference_mode():
image_embeddings = model(**batch_images).embeddings
query_embeddings = model(**batch_queries).embeddings
# Compute retrieval scores
scores = self.processor.score_retrieval(
query_embeddings=query_embeddings,
passage_embeddings=image_embeddings,
) # (num_queries, num_passages)
assert scores.ndim == 2, f"Expected 2D tensor, got {scores.ndim}"
assert scores.shape == (len(ds), len(ds)), f"Expected shape {(len(ds), len(ds))}, got {scores.shape}"
# Check if the maximum scores per row are in the diagonal of the matrix score
self.assertTrue((scores.argmax(axis=1) == torch.arange(len(ds), device=scores.device)).all())
# Further validation: fine-grained check, with a hardcoded score from the original implementation
expected_scores = torch.tensor(
[
[15.5625, 6.5938, 14.4375],
[12.2500, 16.2500, 11.0000],
[15.0625, 11.7500, 21.0000],
],
dtype=scores.dtype,
)
assert torch.allclose(scores, expected_scores, atol=1), f"Expected scores {expected_scores}, got {scores}"