This commit is contained in:
2025-10-09 16:47:16 +08:00
parent c8feb4deb5
commit e27e3f16bb
5248 changed files with 1778505 additions and 0 deletions

View File

@@ -0,0 +1,289 @@
# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
from transformers.image_utils import IMAGENET_STANDARD_MEAN, IMAGENET_STANDARD_STD
from transformers.testing_utils import require_torch, require_vision
from transformers.utils import is_torch_available, is_torchvision_available, is_vision_available
from ...test_image_processing_common import ImageProcessingTestMixin, prepare_image_inputs
if is_torch_available():
import torch
if is_vision_available():
from PIL import Image
from transformers import Gemma3ImageProcessor
if is_torchvision_available():
from transformers import Gemma3ImageProcessorFast
class Gemma3ImageProcessingTester:
def __init__(
self,
parent,
batch_size=7,
num_channels=3,
image_size=18,
min_resolution=30,
max_resolution=400,
do_resize=True,
size=None,
do_normalize=True,
image_mean=IMAGENET_STANDARD_MEAN,
image_std=IMAGENET_STANDARD_STD,
do_convert_rgb=True,
do_pan_and_scan=True,
pan_and_scan_min_crop_size=10,
pan_and_scan_max_num_crops=2,
pan_and_scan_min_ratio_to_activate=1.2,
):
super().__init__()
size = size if size is not None else {"height": 18, "width": 18}
self.parent = parent
self.batch_size = batch_size
self.num_channels = num_channels
self.image_size = image_size
self.min_resolution = min_resolution
self.max_resolution = max_resolution
self.do_resize = do_resize
self.size = size
self.do_normalize = do_normalize
self.image_mean = image_mean
self.image_std = image_std
self.do_convert_rgb = do_convert_rgb
self.do_pan_and_scan = do_pan_and_scan
self.pan_and_scan_min_crop_size = pan_and_scan_min_crop_size
self.pan_and_scan_max_num_crops = pan_and_scan_max_num_crops
self.pan_and_scan_min_ratio_to_activate = pan_and_scan_min_ratio_to_activate
def prepare_image_processor_dict(self):
return {
"do_resize": self.do_resize,
"size": self.size,
"do_normalize": self.do_normalize,
"image_mean": self.image_mean,
"image_std": self.image_std,
"do_convert_rgb": self.do_convert_rgb,
"do_pan_and_scan": self.do_pan_and_scan,
"pan_and_scan_min_crop_size": self.pan_and_scan_min_crop_size,
"pan_and_scan_max_num_crops": self.pan_and_scan_max_num_crops,
"pan_and_scan_min_ratio_to_activate": self.pan_and_scan_min_ratio_to_activate,
}
def expected_output_image_shape(self, images):
return self.num_channels, self.size["height"], self.size["width"]
# Copied from tests.models.clip.test_image_processing_clip.CLIPImageProcessingTester.prepare_image_inputs
def prepare_image_inputs(self, equal_resolution=False, numpify=False, torchify=False):
return prepare_image_inputs(
batch_size=self.batch_size,
num_channels=self.num_channels,
min_resolution=self.min_resolution,
max_resolution=self.max_resolution,
equal_resolution=equal_resolution,
numpify=numpify,
torchify=torchify,
)
@require_torch
@require_vision
class Gemma3ImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
image_processing_class = Gemma3ImageProcessor if is_vision_available() else None
fast_image_processing_class = Gemma3ImageProcessorFast if is_torchvision_available() else None
# Copied from tests.models.clip.test_image_processing_clip.CLIPImageProcessingTest.setUp with CLIP->Gemma3
def setUp(self):
super().setUp()
self.image_processor_tester = Gemma3ImageProcessingTester(self)
@property
# Copied from tests.models.clip.test_image_processing_clip.CLIPImageProcessingTest.image_processor_dict
def image_processor_dict(self):
return self.image_processor_tester.prepare_image_processor_dict()
def test_image_processor_properties(self):
for image_processing_class in self.image_processor_list:
image_processing = image_processing_class(**self.image_processor_dict)
self.assertTrue(hasattr(image_processing, "do_resize"))
self.assertTrue(hasattr(image_processing, "size"))
self.assertTrue(hasattr(image_processing, "do_normalize"))
self.assertTrue(hasattr(image_processing, "image_mean"))
self.assertTrue(hasattr(image_processing, "image_std"))
self.assertTrue(hasattr(image_processing, "do_convert_rgb"))
self.assertTrue(hasattr(image_processing, "do_pan_and_scan"))
self.assertTrue(hasattr(image_processing, "pan_and_scan_min_crop_size"))
self.assertTrue(hasattr(image_processing, "pan_and_scan_max_num_crops"))
self.assertTrue(hasattr(image_processing, "pan_and_scan_min_ratio_to_activate"))
def test_image_processor_from_dict_with_kwargs(self):
for image_processing_class in self.image_processor_list:
image_processor = image_processing_class.from_dict(self.image_processor_dict)
self.assertEqual(image_processor.size, {"height": 18, "width": 18})
image_processor = image_processing_class.from_dict(self.image_processor_dict, size=84)
self.assertEqual(image_processor.size, {"height": 84, "width": 84})
def test_without_pan_and_scan(self):
"""
Disable do_pan_and_scan parameter.
"""
for image_processing_class in self.image_processor_list:
# Initialize image_processing
image_processor = image_processing_class.from_dict(self.image_processor_dict, do_pan_and_scan=False)
# create random PIL images
image_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=True)
for image in image_inputs:
self.assertIsInstance(image, Image.Image)
# Test not batched input
encoded_images = image_processor(image_inputs[0], return_tensors="pt").pixel_values
expected_output_image_shape = (1, 3, 18, 18)
self.assertEqual(tuple(encoded_images.shape), expected_output_image_shape)
# Test batched
encoded_images = image_processor(image_inputs, return_tensors="pt").pixel_values
expected_output_image_shape = (7, 3, 18, 18)
self.assertEqual(tuple(encoded_images.shape), expected_output_image_shape)
def test_pan_and_scan(self):
"""
Enables Pan and Scan path by choosing the correct input image resolution. If you are changing
image processor attributes for PaS, please update this test.
"""
for image_processing_class in self.image_processor_list:
# Initialize image_processing
image_processing = image_processing_class(**self.image_processor_dict)
# create random numpy tensors
"""This function prepares a list of PIL images"""
image_inputs = [np.random.randint(255, size=(3, 300, 600), dtype=np.uint8)] * 3
image_inputs = [Image.fromarray(np.moveaxis(x, 0, -1)) for x in image_inputs]
# Test not batched input, 3 images because we have base image + 2 crops
encoded_images = image_processing(image_inputs[0], return_tensors="pt").pixel_values
expected_output_image_shape = (3, 3, 18, 18)
self.assertEqual(tuple(encoded_images.shape), expected_output_image_shape)
# Test batched, 9 images because we have base image + 2 crops per each item
encoded_images = image_processing(image_inputs, return_tensors="pt").pixel_values
expected_output_image_shape = (9, 3, 18, 18)
self.assertEqual(tuple(encoded_images.shape), expected_output_image_shape)
# Test batched unbalanced, 9 images because we have base image + 2 crops per each item
encoded_images = image_processing(
[[image_inputs[0], image_inputs[1]], [image_inputs[2]]], return_tensors="pt"
).pixel_values
expected_output_image_shape = (9, 3, 18, 18)
self.assertEqual(tuple(encoded_images.shape), expected_output_image_shape)
def test_call_pil(self):
for image_processing_class in self.image_processor_list:
# Initialize image_processing
image_processing = image_processing_class(**self.image_processor_dict)
# create random PIL images
image_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=True)
for image in image_inputs:
self.assertIsInstance(image, Image.Image)
# Test not batched input
encoded_images = image_processing(image_inputs[0], return_tensors="pt").pixel_values
expected_output_image_shape = (1, 3, 18, 18)
self.assertEqual(tuple(encoded_images.shape), expected_output_image_shape)
# Test batched
encoded_images = image_processing(image_inputs, return_tensors="pt").pixel_values
expected_output_image_shape = (7, 3, 18, 18)
self.assertEqual(tuple(encoded_images.shape), expected_output_image_shape)
def test_call_numpy(self):
for image_processing_class in self.image_processor_list:
# Initialize image_processing
image_processing = image_processing_class(**self.image_processor_dict)
# create random numpy tensors
image_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=True, numpify=True)
for image in image_inputs:
self.assertIsInstance(image, np.ndarray)
# Test not batched input
encoded_images = image_processing(image_inputs[0], return_tensors="pt").pixel_values
expected_output_image_shape = (1, 3, 18, 18)
self.assertEqual(tuple(encoded_images.shape), expected_output_image_shape)
# Test batched
encoded_images = image_processing(image_inputs, return_tensors="pt").pixel_values
expected_output_image_shape = (7, 3, 18, 18)
self.assertEqual(tuple(encoded_images.shape), expected_output_image_shape)
def test_call_pytorch(self):
for image_processing_class in self.image_processor_list:
# Initialize image_processing
image_processing = image_processing_class(**self.image_processor_dict)
# create random PyTorch tensors
image_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=True, torchify=True)
for image in image_inputs:
self.assertIsInstance(image, torch.Tensor)
# Test not batched input
encoded_images = image_processing(image_inputs[0], return_tensors="pt").pixel_values
expected_output_image_shape = (1, 3, 18, 18)
self.assertEqual(tuple(encoded_images.shape), expected_output_image_shape)
# Test batched
encoded_images = image_processing(image_inputs, return_tensors="pt").pixel_values
expected_output_image_shape = (7, 3, 18, 18)
self.assertEqual(tuple(encoded_images.shape), expected_output_image_shape)
@unittest.skip("Gemma3 doesn't work with 4 channels due to pan and scan method")
def test_call_numpy_4_channels(self):
pass
@require_vision
@require_torch
def test_slow_fast_equivalence_batched_pas(self):
if not self.test_slow_image_processor or not self.test_fast_image_processor:
self.skipTest(reason="Skipping slow/fast equivalence test")
if self.image_processing_class is None or self.fast_image_processing_class is None:
self.skipTest(reason="Skipping slow/fast equivalence test as one of the image processors is not defined")
if hasattr(self.image_processor_tester, "do_center_crop") and self.image_processor_tester.do_center_crop:
self.skipTest(
reason="Skipping as do_center_crop is True and center_crop functions are not equivalent for fast and slow processors"
)
crop_config = {
"do_pan_and_scan": True,
"pan_and_scan_max_num_crops": 448,
"pan_and_scan_min_crop_size": 32,
"pan_and_scan_min_ratio_to_activate": 0.3,
}
image_processor_dict = self.image_processor_dict
image_processor_dict.update(crop_config)
dummy_images = self.image_processor_tester.prepare_image_inputs(equal_resolution=False, torchify=True)
image_processor_slow = self.image_processing_class(**image_processor_dict)
image_processor_fast = self.fast_image_processing_class(**image_processor_dict)
encoding_slow = image_processor_slow(dummy_images, return_tensors="pt")
encoding_fast = image_processor_fast(dummy_images, return_tensors="pt")
torch.testing.assert_close(encoding_slow.num_crops, encoding_fast.num_crops)
self._assert_slow_fast_tensors_equivalence(encoding_slow.pixel_values, encoding_fast.pixel_values)

View File

@@ -0,0 +1,830 @@
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Testing suite for the PyTorch Gemma3 model."""
import logging
import tempfile
import unittest
import pytest
from parameterized import parameterized
from transformers import (
AutoModelForCausalLM,
AutoTokenizer,
Gemma3Config,
Gemma3TextConfig,
is_torch_available,
)
from transformers.testing_utils import (
Expectations,
cleanup,
is_flash_attn_2_available,
require_deterministic_for_xpu,
require_flash_attn,
require_read_token,
require_torch,
require_torch_accelerator,
require_torch_large_accelerator,
slow,
torch_device,
)
from ...generation.test_utils import GenerationTesterMixin
from ...models.gemma.test_modeling_gemma import GemmaModelTester
from ...test_configuration_common import ConfigTester
from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor
if is_torch_available():
import torch
from transformers import (
Gemma3ForCausalLM,
Gemma3ForConditionalGeneration,
Gemma3ForSequenceClassification,
Gemma3Model,
Gemma3Processor,
Gemma3TextForSequenceClassification,
Gemma3TextModel,
)
from transformers.pytorch_utils import is_torch_greater_or_equal
class Gemma3ModelTester(GemmaModelTester):
if is_torch_available():
config_class = Gemma3TextConfig
model_class = Gemma3TextModel
for_causal_lm_class = Gemma3ForCausalLM
@require_torch
class Gemma3ModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
all_model_classes = (
(Gemma3TextModel, Gemma3ForCausalLM, Gemma3TextForSequenceClassification) if is_torch_available() else ()
)
all_generative_model_classes = (Gemma3ForCausalLM,) if is_torch_available() else ()
test_headmasking = False
test_pruning = False
_is_stateful = True
model_split_percents = [0.5, 0.6]
def setUp(self):
self.model_tester = Gemma3ModelTester(self)
self.config_tester = ConfigTester(self, config_class=Gemma3Config, hidden_size=37)
@unittest.skip("Gemma3 applies key/query norm which doesn't work with packing")
def test_flash_attention_2_padding_matches_padding_free_with_position_ids(self):
pass
@unittest.skip("Gemma3 applies key/query norm which doesn't work with packing")
def test_flash_attention_2_padding_matches_padding_free_with_position_ids_and_fa_kwargs(self):
pass
@unittest.skip("Gemma3 applies key/query norm which doesn't work with packing")
def test_eager_padding_matches_padding_free_with_position_ids(self):
pass
@unittest.skip("Gemma3 applies key/query norm which doesn't work with packing")
def test_sdpa_padding_matches_padding_free_with_position_ids(self):
pass
@unittest.skip(
"Gemma3 has no base model prefix which causes issues when loading base model from saved task model checkpoint"
)
def test_load_with_mismatched_shapes(self):
pass
def test_generation_beyond_sliding_window_tiny_model(self):
"""Test generation with a tiny randomly initialised model whose input length is larger than the `sliding_window`.
The model is configured with both `full_attention` and `sliding_attention` layers to make sure the hybrid cache
and mask slicing logic is covered.
"""
config = Gemma3TextConfig.from_pretrained("hf-internal-testing/tiny-random-Gemma3ForCausalLM")
config.attn_implementation = "eager"
config.layer_types = ["full_attention", "sliding_attention"]
config.sliding_window = 8
config.max_position_embeddings = 128
model = AutoModelForCausalLM.from_pretrained(
"hf-internal-testing/tiny-random-Gemma3ForCausalLM", config=config
).to(torch_device)
input_len = 10
input_ids = torch.tensor(
[
[42300, 241087, 255445, 81315, 193760, 184471, 67719, 98191, 210651, 124725],
[102294, 205314, 226646, 62020, 60245, 68025, 251839, 114053, 4695, 175511],
],
device=torch_device,
)
attention_mask = torch.ones_like(input_ids).to(torch_device)
with torch.no_grad():
_ = model.generate(
input_ids,
attention_mask=attention_mask,
max_new_tokens=1,
do_sample=False,
use_cache=True,
cache_implementation="hybrid",
)
# 2 generations are needed to trigger https://github.com/huggingface/transformers/issues/39711
# Since it requires model._cache to have been previously initialized
output = model.generate(
input_ids,
attention_mask=attention_mask,
max_new_tokens=5,
do_sample=False,
use_cache=True,
cache_implementation="hybrid",
)
generated_sequences = output[:, input_len:].cpu()
EXPECTED_OUTPUT = torch.tensor([[90109, 90109, 90109, 83191, 83191], [246901, 69832, 69832, 69832, 62288]])
torch.testing.assert_close(generated_sequences, EXPECTED_OUTPUT)
def test_gemma3_text_sequence_classification_model(self):
"""Test the text-only sequence classification model."""
config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
config.num_labels = 3
input_ids = input_dict["input_ids"]
attention_mask = input_ids.ne(1).to(torch_device)
sequence_labels = ids_tensor([self.model_tester.batch_size], self.model_tester.num_labels)
model = Gemma3TextForSequenceClassification(config)
model.to(torch_device)
model.eval()
result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels)
self.assertEqual(result.logits.shape, (self.model_tester.batch_size, config.num_labels))
class Gemma3Vision2TextModelTester:
def __init__(
self,
parent,
mm_tokens_per_image=2,
image_token_index=4,
boi_token_index=5,
eoi_token_index=6,
seq_length=25,
is_training=True,
vision_config={
"use_labels": True,
"image_size": 20,
"patch_size": 5,
"num_channels": 3,
"is_training": True,
"hidden_size": 32,
"num_key_value_heads": 1,
"num_hidden_layers": 2,
"num_attention_heads": 4,
"intermediate_size": 37,
"dropout": 0.1,
"attention_dropout": 0.1,
"initializer_range": 0.02,
},
use_cache=False,
):
self.parent = parent
# `image_token_index` is set to 0 to pass "resize_embeddings" test, do not modify
self.mm_tokens_per_image = mm_tokens_per_image
self.image_token_index = image_token_index
self.boi_token_index = boi_token_index
self.eoi_token_index = eoi_token_index
self.llm_tester = Gemma3ModelTester(self.parent)
self.text_config = self.llm_tester.get_config()
self.vision_config = vision_config
self.seq_length = seq_length
self.pad_token_id = self.text_config.pad_token_id
self.num_hidden_layers = self.text_config.num_hidden_layers
self.vocab_size = self.text_config.vocab_size
self.hidden_size = self.text_config.hidden_size
self.num_attention_heads = self.text_config.num_attention_heads
self.is_training = is_training
self.batch_size = 3
self.num_channels = vision_config["num_channels"]
self.image_size = vision_config["image_size"]
self.encoder_seq_length = seq_length
self.use_cache = use_cache
def get_config(self):
return Gemma3Config(
text_config=self.text_config,
vision_config=self.vision_config,
image_token_index=self.image_token_index,
boi_token_index=self.boi_token_index,
eoi_token_index=self.eoi_token_index,
mm_tokens_per_image=self.mm_tokens_per_image,
)
def prepare_config_and_inputs(self):
pixel_values = floats_tensor(
[
self.batch_size,
self.vision_config["num_channels"],
self.vision_config["image_size"],
self.vision_config["image_size"],
]
)
config = self.get_config()
return config, pixel_values
def prepare_config_and_inputs_for_common(self):
config_and_inputs = self.prepare_config_and_inputs()
config, pixel_values = config_and_inputs
input_ids = ids_tensor([self.batch_size, self.seq_length], config.text_config.vocab_size - 1) + 1
attention_mask = input_ids.ne(self.pad_token_id).to(torch_device)
# set the 3 first tokens to be image, and ensure that no other tokens are image tokens
# do not change this unless you modified image size or patch size
input_ids[input_ids == config.image_token_index] = self.pad_token_id
input_ids[:, :1] = config.image_token_index
token_type_ids = torch.zeros_like(input_ids)
token_type_ids[input_ids == config.image_token_index] = 1
inputs_dict = {
"pixel_values": pixel_values,
"input_ids": input_ids,
"attention_mask": attention_mask,
"token_type_ids": token_type_ids,
}
return config, inputs_dict
@require_torch
class Gemma3Vision2TextModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
all_model_classes = (
(
Gemma3Model,
Gemma3ForConditionalGeneration,
Gemma3ForSequenceClassification,
)
if is_torch_available()
else ()
)
all_generative_model_classes = (Gemma3ForConditionalGeneration,) if is_torch_available() else ()
test_headmasking = False
test_pruning = False
test_missing_keys = False
_is_stateful = True
model_split_percents = [0.5, 0.6]
additional_model_inputs = ["token_type_ids"]
# MP works but offload doesn't work when the SigLIP MultiheadAttention is offloaded
# TODO: One potential solution would be to add to set preload_module_classes = ["SiglipMultiheadAttentionPoolingHead"]
# in the dispatch_model function
test_cpu_offload = False
test_disk_offload_safetensors = False
test_disk_offload_bin = False
def setUp(self):
self.model_tester = Gemma3Vision2TextModelTester(self)
self.config_tester = ConfigTester(self, config_class=Gemma3Config, hidden_size=37)
def test_bidirectional_image_attention(self):
"""
Tests that each image can attend to itself bidirectionally. However an image
cannot attend to future images, even within the same batch.
"""
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
config._attn_implementation = "eager"
model = Gemma3Model(config).to(torch_device)
# First let's pass inputs without change which is one image per text and manipulate
# `token_type_ids` to make sure bidirectional mask is applied where it has to be
inputs_dict["token_type_ids"] = torch.zeros_like(inputs_dict["token_type_ids"])
inputs_dict["token_type_ids"][:, :4] = 1 # unmask first 4 tokens
with torch.no_grad():
out = model(**inputs_dict, output_attentions=True)
# We expect a non-causal mask on first 4 tokens, thus no zeros
for attention in out.attentions:
self.assertTrue((attention[..., :4, :4] != 0).all().item())
# Now when removing `token_type_ids`, we will get simple causal mask
inputs_dict["token_type_ids"][:, :4] = 0 # mask back first 4 tokens
with torch.no_grad():
out = model(**inputs_dict, output_attentions=True)
# We expect a causal mask on first 4 tokens, thus no zeros
for attention in out.attentions:
self.assertFalse((attention[..., :4, :4] != 0).all().item())
# Let's add two "images" per text, first one spanning 4 tokens and last one 3 tokens
inputs_dict["token_type_ids"][:, :4] = 1
inputs_dict["token_type_ids"][:, 7:10] = 1
with torch.no_grad():
out = model(**inputs_dict, output_attentions=True)
for attention in out.attentions:
self.assertTrue((attention[..., :4, :4] != 0).all().item())
self.assertTrue((attention[..., 7:10, 7:10] != 0).all().item())
# We expect a non-causal mask only within same image and no looking ahead to the future
self.assertTrue((attention[..., :4, 7:10] == 0).all().item())
@unittest.skip(reason="SiglipVisionModel (vision backbone) does not support standalone training")
def test_training_gradient_checkpointing(self):
pass
@unittest.skip(reason="SiglipVisionModel (vision backbone) does not support standalone training")
def test_training_gradient_checkpointing_use_reentrant(self):
pass
@unittest.skip(reason="SiglipVisionModel (vision backbone) does not support standalone training")
def test_training_gradient_checkpointing_use_reentrant_false(self):
pass
@unittest.skip(
reason="Siglip (vision backbone) uses the same initialization scheme as the Flax original implementation"
)
def test_initialization(self):
pass
@unittest.skip("Loading nested configs with overwritten `kwargs` isn't supported yet, FIXME @raushan.")
def test_load_with_mismatched_shapes(self):
pass
def test_automodelforcausallm(self):
"""
Regression test for #36741/#36917 -- make sure `AutoModelForCausalLM` works with a Gemma3 config, i.e. that
`AutoModelForCausalLM.from_pretrained` pulls the text config before loading the model
"""
config = self.model_tester.get_config()
model = Gemma3ForConditionalGeneration(config)
with tempfile.TemporaryDirectory() as tmp_dir:
model.save_pretrained(tmp_dir)
for_causal_lm = AutoModelForCausalLM.from_pretrained(tmp_dir)
self.assertIsInstance(for_causal_lm, Gemma3ForConditionalGeneration)
@slow
@require_torch_accelerator
@require_read_token
class Gemma3IntegrationTest(unittest.TestCase):
def setUp(self):
self.processor = Gemma3Processor.from_pretrained("google/gemma-3-4b-it", padding_side="left")
url = "https://huggingface.co/datasets/hf-internal-testing/fixtures-captioning/resolve/main/cow_beach_1.png"
self.messages = [
{"role": "system", "content": [{"type": "text", "text": "You are a helpful assistant."}]},
{
"role": "user",
"content": [
{"type": "image", "url": url},
{"type": "text", "text": "What is shown in this image?"},
],
},
]
def tearDown(self):
cleanup(torch_device, gc_collect=True)
@require_deterministic_for_xpu
def test_model_4b_bf16(self):
model_id = "google/gemma-3-4b-it"
model = Gemma3ForConditionalGeneration.from_pretrained(model_id, dtype=torch.bfloat16).to(torch_device)
inputs = self.processor.apply_chat_template(
self.messages,
tokenize=True,
return_dict=True,
return_tensors="pt",
add_generation_prompt=True,
).to(torch_device)
# cache_implementation="hybrid" an in the original transformers implementation
output = model.generate(**inputs, max_new_tokens=30, do_sample=False, cache_implementation="hybrid")
output_text = self.processor.batch_decode(output, skip_special_tokens=True)
EXPECTED_TEXTS = Expectations(
{
("xpu", 3): ['user\nYou are a helpful assistant.\n\n\n\n\n\nWhat is shown in this image?\nmodel\nCertainly! \n\nThe image shows a brown and white cow standing on a sandy beach with turquoise water in the background. It looks like a lovely,'],
("cuda", (8, 0)): ['user\nYou are a helpful assistant.\n\n\n\n\n\nWhat is shown in this image?\nmodel\nCertainly! \n\nThe image shows a brown cow standing on a sandy beach with clear turquoise water and a blue sky in the background. It looks like'],
("cuda", (8, 6)): ['user\nYou are a helpful assistant.\n\n\n\n\n\nWhat is shown in this image?\nmodel\nCertainly! \n\nThe image shows a brown cow standing on a sandy beach with clear blue water and a blue sky in the background. It looks like'],
("rocm", (9, 4)): ['user\nYou are a helpful assistant.\n\n\n\n\n\nWhat is shown in this image?\nmodel\nCertainly! \n\nThe image shows a brown cow standing on a sandy beach with turquoise water and a blue sky in the background. It looks like a'],
("rocm", (9, 5)): ['user\nYou are a helpful assistant.\n\n\n\n\n\nWhat is shown in this image?\nmodel\nCertainly! \n\nThe image shows a brown and white cow standing on a sandy beach with turquoise water and a distant coastline in the background. It looks'],
}
) # fmt: skip
EXPECTED_TEXT = EXPECTED_TEXTS.get_expectation()
self.assertEqual(output_text, EXPECTED_TEXT)
@require_torch_large_accelerator
@require_deterministic_for_xpu
def test_model_4b_batch(self):
model_id = "google/gemma-3-4b-it"
model = Gemma3ForConditionalGeneration.from_pretrained(model_id, dtype=torch.bfloat16).to(torch_device)
messages_2 = [
{"role": "system", "content": [{"type": "text", "text": "You are a helpful assistant."}]},
{
"role": "user",
"content": [
{
"type": "image",
"url": "https://huggingface.co/datasets/hf-internal-testing/fixtures-captioning/resolve/main/cow_beach_1.png",
},
{
"type": "image",
"url": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/australia.jpg",
},
{"type": "text", "text": "Are these images identical?"},
],
},
]
inputs = self.processor.apply_chat_template(
[self.messages, messages_2],
tokenize=True,
return_dict=True,
return_tensors="pt",
padding=True,
add_generation_prompt=True,
).to(torch_device)
# cache_implementation="hybrid" an in the original transformers implementation
output = model.generate(**inputs, max_new_tokens=30, do_sample=False, cache_implementation="hybrid")
output_text = self.processor.batch_decode(output, skip_special_tokens=True)
EXPECTED_TEXTS = Expectations(
{
("xpu", 3):
[
'user\nYou are a helpful assistant.\n\n\n\n\n\nWhat is shown in this image?\nmodel\nCertainly! \n\nThe image shows a brown and white cow standing on a sandy beach next to a turquoise ocean. It looks like a very sunny and',
'user\nYou are a helpful assistant.\n\n\n\n\n\n\n\n\n\nAre these images identical?\nmodel\nNo, these images are not identical. They depict very different scenes:\n\n* **Image 1** shows a cow standing on a beach.',
],
("cuda", (8,0)):
[
'user\nYou are a helpful assistant.\n\n\n\n\n\nWhat is shown in this image?\nmodel\nCertainly! \n\nThe image shows a brown cow standing on a sandy beach with clear turquoise water and a blue sky in the background. It looks like',
"user\nYou are a helpful assistant.\n\n\n\n\n\n\n\n\n\nAre these images identical?\nmodel\nNo, these images are not identical. \n\nHere's a breakdown of the differences:\n\n* **Image 1:** Shows a brown"
],
("cuda", (8,6)):
[
'user\nYou are a helpful assistant.\n\n\n\n\n\nWhat is shown in this image?\nmodel\nCertainly! \n\nThe image shows a brown cow standing on a sandy beach with clear blue water and a blue sky in the background. It looks like',
"user\nYou are a helpful assistant.\n\n\n\n\n\n\n\n\n\nAre these images identical?\nmodel\nNo, these images are not identical. \n\nHere's a breakdown of the differences:\n\n* **Image 1:** Shows a brown"
],
("rocm", (9, 4)):
[
'user\nYou are a helpful assistant.\n\n\n\n\n\nWhat is shown in this image?\nmodel\nCertainly! \n\nThe image shows a brown cow standing on a sandy beach with clear turquoise water and a blue sky in the background. It looks like',
"user\nYou are a helpful assistant.\n\n\n\n\n\n\n\n\n\nAre these images identical?\nmodel\nNo, these images are not identical. \n\nHere's a breakdown of the differences:\n\n* **Image 1:** Shows a cow"
],
("rocm", (9, 5)):
[
'user\nYou are a helpful assistant.\n\n\n\n\n\nWhat is shown in this image?\nmodel\nCertainly! \n\nThe image shows a brown and white cow standing on a sandy beach next to a turquoise ocean. There are some clouds in the blue',
'user\nYou are a helpful assistant.\n\n\n\n\n\n\n\n\n\nAre these images identical?\nmodel\nNo, these images are not identical. They depict very different scenes. \n\n* **Image 1** shows a cow standing on a beach',
],
}
) # fmt: skip
EXPECTED_TEXT = EXPECTED_TEXTS.get_expectation()
self.assertEqual(output_text, EXPECTED_TEXT)
@require_torch_large_accelerator
def test_model_4b_crops(self):
model_id = "google/gemma-3-4b-it"
model = Gemma3ForConditionalGeneration.from_pretrained(model_id, dtype=torch.bfloat16).to(torch_device)
crop_config = {
"images_kwargs": {
"do_pan_and_scan": True,
"pan_and_scan_max_num_crops": 448,
"pan_and_scan_min_crop_size": 32,
"pan_and_scan_min_ratio_to_activate": 0.3,
}
}
inputs = self.processor.apply_chat_template(
self.messages,
tokenize=True,
return_dict=True,
return_tensors="pt",
add_generation_prompt=True,
**crop_config,
).to(torch_device)
# cache_implementation="hybrid" an in the original transformers implementation
output = model.generate(**inputs, max_new_tokens=30, do_sample=False, cache_implementation="hybrid")
output_text = self.processor.batch_decode(output, skip_special_tokens=True)
EXPECTED_NUM_IMAGES = 3 # one for the origin image and two crops of images
EXPECTED_TEXTS = Expectations(
{
("xpu", 3): ['user\nYou are a helpful assistant.\n\nHere is the original image \n\n\n\n and here are some crops to help you see better \n\n\n\n \n\n\n\nWhat is shown in this image?\nmodel\nThe image shows a brown cow standing on a sandy beach next to a turquoise ocean. There are clouds in the blue sky above.'],
("cuda", 7): [],
("cuda", (8, 6)): ["user\nYou are a helpful assistant.\n\nHere is the original image \n\n\n\n and here are some crops to help you see better \n\n\n\n \n\n\n\nWhat is shown in this image?\nmodel\nThe image shows a brown cow standing on a sandy beach next to a turquoise ocean. There's a clear blue sky with some white clouds above."],
("cuda", (8, 0)): ["user\nYou are a helpful assistant.\n\nHere is the original image \n\n\n\n and here are some crops to help you see better \n\n\n\n \n\n\n\nWhat is shown in this image?\nmodel\nThe image shows a brown cow standing on a sandy beach next to a turquoise ocean. There's a blue sky with some white clouds in the background"],
("rocm", (9, 4)): ["user\nYou are a helpful assistant.\n\nHere is the original image \n\n\n\n and here are some crops to help you see better \n\n\n\n \n\n\n\nWhat is shown in this image?\nmodel\nThe image shows a brown cow standing on a sandy beach next to a turquoise ocean. There's a bright blue sky with some white clouds in the"],
("rocm", (9, 5)): ["user\nYou are a helpful assistant.\n\nHere is the original image \n\n\n\n and here are some crops to help you see better \n\n\n\n \n\n\n\nWhat is shown in this image?\nmodel\nThe image shows a brown cow standing on a sandy beach next to a turquoise ocean. There's a blue sky with some white clouds in the background"]
}
) # fmt: skip
EXPECTED_TEXT = EXPECTED_TEXTS.get_expectation()
self.assertEqual(len(inputs["pixel_values"]), EXPECTED_NUM_IMAGES)
print(f"Generated text: {output_text}")
self.assertEqual(output_text, EXPECTED_TEXT)
@require_torch_large_accelerator
@require_deterministic_for_xpu
def test_model_4b_batch_crops(self):
model_id = "google/gemma-3-4b-it"
model = Gemma3ForConditionalGeneration.from_pretrained(model_id, dtype=torch.bfloat16).to(torch_device)
crop_config = {
"images_kwargs": {
"do_pan_and_scan": True,
"pan_and_scan_max_num_crops": 448,
"pan_and_scan_min_crop_size": 32,
"pan_and_scan_min_ratio_to_activate": 0.3,
}
}
messages_2 = [
{"role": "system", "content": [{"type": "text", "text": "You are a helpful assistant."}]},
{
"role": "user",
"content": [
{
"type": "image",
"url": "https://huggingface.co/datasets/hf-internal-testing/fixtures-captioning/resolve/main/cow_beach_1.png",
},
{
"type": "image",
"url": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/australia.jpg",
},
{"type": "text", "text": "Are these images identical?"},
],
},
]
inputs = self.processor.apply_chat_template(
[self.messages, messages_2],
tokenize=True,
return_dict=True,
return_tensors="pt",
padding=True,
add_generation_prompt=True,
**crop_config,
).to(torch_device)
# cache_implementation="hybrid" an in the original transformers implementation
output = model.generate(**inputs, max_new_tokens=30, do_sample=False, cache_implementation="hybrid")
output_text = self.processor.batch_decode(output, skip_special_tokens=True)
EXPECTED_NUM_IMAGES = 9 # 3 * (one for the origin image and two crops of images) = 9
EXPECTED_TEXTS = Expectations(
{
("xpu", 3): [
'user\nYou are a helpful assistant.\n\nHere is the original image \n\n\n\n and here are some crops to help you see better \n\n\n\n \n\n\n\nWhat is shown in this image?\nmodel\nThe image shows a brown cow standing on a sandy beach next to a turquoise ocean. There are clouds in the blue sky above.',
'user\nYou are a helpful assistant.\n\nHere is the original image \n\n\n\n and here are some crops to help you see better \n\n\n\n \n\n\n\nHere is the original image \n\n\n\n and here are some crops to help you see better \n\n\n\n \n\n\n\nAre these images identical?\nmodel\nNo, the images are not identical. \n\nThe first image shows a cow on a beach, while the second image shows a street scene with a',
],
("cuda", 7): [],
("cuda", (8,0)): [
"user\nYou are a helpful assistant.\n\nHere is the original image \n\n\n\n and here are some crops to help you see better \n\n\n\n \n\n\n\nWhat is shown in this image?\nmodel\nThe image shows a brown cow standing on a sandy beach next to a turquoise ocean. There's a blue sky with some white clouds in the background",
'user\nYou are a helpful assistant.\n\nHere is the original image \n\n\n\n and here are some crops to help you see better \n\n\n\n \n\n\n\nHere is the original image \n\n\n\n and here are some crops to help you see better \n\n\n\n \n\n\n\nAre these images identical?\nmodel\nNo, the images are not identical. \n\nThe first image shows a cow on a beach, while the second image shows a street scene with a'
],
("cuda", (8, 6)): [
"user\nYou are a helpful assistant.\n\nHere is the original image \n\n\n\n and here are some crops to help you see better \n\n\n\n \n\n\n\nWhat is shown in this image?\nmodel\nThe image shows a brown cow standing on a sandy beach next to a turquoise ocean. There's a bright blue sky with some white clouds in the",
'user\nYou are a helpful assistant.\n\nHere is the original image \n\n\n\n and here are some crops to help you see better \n\n\n\n \n\n\n\nHere is the original image \n\n\n\n and here are some crops to help you see better \n\n\n\n \n\n\n\nAre these images identical?\nmodel\nNo, the images are not identical. \n\nThe first image shows a cow on a beach, while the second image shows a street scene with a'
],
("rocm", (9, 4)) : [
"user\nYou are a helpful assistant.\n\nHere is the original image \n\n\n\n and here are some crops to help you see better \n\n\n\n \n\n\n\nWhat is shown in this image?\nmodel\nThe image shows a brown cow standing on a sandy beach next to a turquoise ocean. There's a bright blue sky with some white clouds in the",
'user\nYou are a helpful assistant.\n\nHere is the original image \n\n\n\n and here are some crops to help you see better \n\n\n\n \n\n\n\nHere is the original image \n\n\n\n and here are some crops to help you see better \n\n\n\n \n\n\n\nAre these images identical?\nmodel\nNo, the images are not identical. \n\nThe first image shows a cow on a beach, while the second image shows a street scene with a'
],
("rocm", (9, 5)) : [
'user\nYou are a helpful assistant.\n\nHere is the original image \n\n\n\n and here are some crops to help you see better \n\n\n\n \n\n\n\nWhat is shown in this image?\nmodel\nThe image shows a brown cow standing on a sandy beach next to a turquoise ocean. There are clouds in the blue sky above.',
'user\nYou are a helpful assistant.\n\nHere is the original image \n\n\n\n and here are some crops to help you see better \n\n\n\n \n\n\n\nHere is the original image \n\n\n\n and here are some crops to help you see better \n\n\n\n \n\n\n\nAre these images identical?\nmodel\nNo, the images are not identical. \n\nThe first image shows a cow on a beach, while the second image shows a street scene with a',
],
}
) # fmt: skip
EXPECTED_TEXT = EXPECTED_TEXTS.get_expectation()
self.assertEqual(len(inputs["pixel_values"]), EXPECTED_NUM_IMAGES)
self.assertEqual(output_text, EXPECTED_TEXT)
@require_torch_large_accelerator
def test_model_4b_multiimage(self):
model_id = "google/gemma-3-4b-it"
model = Gemma3ForConditionalGeneration.from_pretrained(model_id, dtype=torch.bfloat16).to(torch_device)
messages = [
{"role": "system", "content": [{"type": "text", "text": "You are a helpful assistant."}]},
{
"role": "user",
"content": [
{
"type": "image",
"url": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/australia.jpg",
},
{"type": "text", "text": "What do you see here?"},
],
},
]
inputs = self.processor.apply_chat_template(
messages,
tokenize=True,
return_dict=True,
return_tensors="pt",
padding=True,
add_generation_prompt=True,
).to(torch_device)
# cache_implementation="hybrid" an in the original transformers implementation
output = model.generate(**inputs, max_new_tokens=30, do_sample=False, cache_implementation="hybrid")
output_text = self.processor.batch_decode(output, skip_special_tokens=True)
EXPECTED_TEXTS = Expectations(
{
("xpu", 3): ["user\nYou are a helpful assistant.\n\n\n\n\n\nWhat do you see here?\nmodel\nOkay, let's break down what I see in this image!\n\nHere's a description of the scene:\n\n* **Chinese Arch"],
("cuda", 7): [],
("cuda", (8, 0)): ["user\nYou are a helpful assistant.\n\n\n\n\n\nWhat do you see here?\nmodel\nOkay, let's break down what I see in this image:\n\n**Overall Scene:**\n\nIt looks like a street scene in a vibrant,"],
("cuda", (8, 6)): ["user\nYou are a helpful assistant.\n\n\n\n\n\nWhat do you see here?\nmodel\nOkay, let's break down what I see in this image:\n\n**Overall Scene:**\n\nIt appears to be a street scene in a city"],
("rocm", (9, 4)): ["user\nYou are a helpful assistant.\n\n\n\n\n\nWhat do you see here?\nmodel\nOkay, let's break down what I see in this image:\n\n**Overall Scene:**\n\nIt appears to be a street scene in a vibrant"],
("rocm", (9, 5)): ["user\nYou are a helpful assistant.\n\n\n\n\n\nWhat do you see here?\nmodel\nOkay, let's break down what I see in this image:\n\n**Main Features:**\n\n* **Chinese Archway:** The most prominent"],
}
) # fmt: skip
EXPECTED_TEXT = EXPECTED_TEXTS.get_expectation()
self.assertEqual(output_text, EXPECTED_TEXT)
@require_deterministic_for_xpu
def test_model_1b_text_only(self):
model_id = "google/gemma-3-1b-it"
model = Gemma3ForCausalLM.from_pretrained(model_id, dtype=torch.bfloat16).to(torch_device)
tokenizer = AutoTokenizer.from_pretrained(model_id, padding_side="left")
inputs = tokenizer("Write a poem about Machine Learning.", return_tensors="pt").to(torch_device)
# cache_implementation="hybrid" an in the original transformers implementation
output = model.generate(**inputs, max_new_tokens=30, do_sample=False, cache_implementation="hybrid")
output_text = tokenizer.batch_decode(output, skip_special_tokens=True)
EXPECTED_TEXTS = Expectations(
{
("xpu", 3): ['Write a poem about Machine Learning.\n\n---\n\nThe data flows, a river deep,\nWith patterns hidden, secrets sleep.\nA neural net, a watchful eye,\nLearning'],
("cuda", 7): ['Write a poem about Machine Learning.\n\n---\n\nThe data flows, a silent stream,\nInto the neural net, a waking dream.\nAlgorithms hum, a coded grace,\n'],
("cuda", 8): ['Write a poem about Machine Learning.\n\n---\n\nThe data flows, a silent stream,\nInto the neural net, a waking dream.\nAlgorithms hum, a coded grace,\n'],
("rocm", (9, 4)): ['Write a poem about Machine Learning.\n\n---\n\nThe data flows, a silent stream,\nInto the neural net, a waking dream.\nAlgorithms hum, a coded grace,\n'],
("rocm", (9, 5)): ['Write a poem about Machine Learning.\n\n---\n\nThe data flows, a river deep,\nWith patterns hidden, secrets sleep.\nA neural net, a watchful eye,\nLearning'],
}
) # fmt: skip
EXPECTED_TEXT = EXPECTED_TEXTS.get_expectation()
self.assertEqual(output_text, EXPECTED_TEXT)
# TODO: raushan FA2 generates gibberish for no reason, check later
@require_flash_attn
@require_torch_large_accelerator
@pytest.mark.flash_attn_test
def test_model_4b_flash_attn(self):
model_id = "google/gemma-3-4b-it"
model = Gemma3ForConditionalGeneration.from_pretrained(
model_id, dtype=torch.bfloat16, attn_implementation="flash_attention_2"
).to(torch_device)
inputs = self.processor.apply_chat_template(
self.messages,
tokenize=True,
return_dict=True,
return_tensors="pt",
add_generation_prompt=True,
).to(torch_device)
# cache_implementation="hybrid" an in the original transformers implementation
output = model.generate(**inputs, max_new_tokens=30, do_sample=False, cache_implementation="hybrid")
output_text = self.processor.batch_decode(output, skip_special_tokens=True)
EXPECTED_TEXTS = Expectations(
{
("xpu", 3): ['user\nYou are a helpful assistant.\n\n\n\n\n\nWhat is shown in this image?\nmodel\nThe image shows a brown and white cow standing on a sandy beach with turquoise water and a distant island in the background. It looks like a sunny day'],
("cuda", 7): [],
("cuda", 8): ['user\nYou are a helpful assistant.\n\n\n\n\n\nWhat is shown in this image?\nmodel\nThe image shows a brown and white cow standing on a sandy beach with turquoise water and a distant island in the background. It looks like a sunny day'],
("rocm", (9, 4)): ['user\nYou are a helpful assistant.\n\n\n\n\n\nWhat is shown in this image?\nmodel\nThe image shows a brown and white cow standing on a sandy beach with turquoise water and a distant island in the background. It looks like a sunny day'],
("rocm", (9, 5)): ['user\nYou are a helpful assistant.\n\n\n\n\n\nWhat is shown in this image?\nmodel\nThe image shows a brown and white cow standing on a sandy beach with a turquoise ocean and a distant island in the background. It looks like a sunny'],
}
) # fmt: skip
EXPECTED_TEXT = EXPECTED_TEXTS.get_expectation()
self.assertEqual(output_text, EXPECTED_TEXT)
@parameterized.expand([("flash_attention_2",), ("sdpa",), ("eager",)])
def test_generation_beyond_sliding_window(self, attn_implementation: str):
"""Test that we can correctly generate beyond the sliding window. This is non trivial as
we need to correctly slice the attention mask in all cases (because we use a hybrid cache).
Outputs for every attention functions should be coherent and identical.
"""
model_id = "google/gemma-3-1b-it"
if attn_implementation == "flash_attention_2" and not is_flash_attn_2_available():
self.skipTest("FlashAttention2 is required for this test.")
input_text = [
"This is a nice place. " * 800 + "I really enjoy the scenery,", # This is larger than 4096 tokens
"A list of colors: red, blue", # This will almost all be padding tokens
]
tokenizer = AutoTokenizer.from_pretrained(model_id, padding="left")
inputs = tokenizer(input_text, padding=True, return_tensors="pt").to(torch_device)
model = AutoModelForCausalLM.from_pretrained(
model_id, attn_implementation=attn_implementation, dtype=torch.float16
).to(torch_device)
# Make sure prefill is larger than sliding window
input_size = inputs.input_ids.shape[-1]
self.assertTrue(input_size > model.config.sliding_window)
out = model.generate(**inputs, max_new_tokens=20, do_sample=False, cache_implementation="static")[
:, input_size:
]
output_text = tokenizer.batch_decode(out)
EXPECTED_COMPLETIONS = [
" and I'm going to take a walk.\n\nI really enjoy the scenery, and I'",
", green, yellow, orange, purple, brown, black, white, gray.\n\nI'",
]
self.assertEqual(output_text, EXPECTED_COMPLETIONS)
@pytest.mark.torch_export_test
def test_export_text_only_with_hybrid_cache(self):
if not is_torch_greater_or_equal("2.6.0"):
self.skipTest(reason="This test requires torch >= 2.6 to run.")
from transformers.integrations.executorch import TorchExportableModuleForDecoderOnlyLM
model_id = "google/gemma-3-1b-it"
model = AutoModelForCausalLM.from_pretrained(model_id)
self.assertEqual(model.config.cache_implementation, "hybrid")
# Export + hybrid cache
model.eval()
exportable_module = TorchExportableModuleForDecoderOnlyLM(model, batch_size=1, max_cache_len=1024)
exported_program = exportable_module.export(
input_ids=torch.tensor([[1]], dtype=torch.long, device=model.device),
cache_position=torch.tensor([0], dtype=torch.long, device=model.device),
)
logging.info(f"\nExported program: {exported_program}")
# Test generation with the exported model
prompt = "What is the capital of France?"
max_new_tokens_to_generate = 20
# Generate text with the exported model
tokenizer = AutoTokenizer.from_pretrained(model_id)
export_generated_text = TorchExportableModuleForDecoderOnlyLM.generate(
exported_program, tokenizer, prompt, max_new_tokens=max_new_tokens_to_generate
)
logging.info(f"\nExport generated texts: '{export_generated_text}'")
input_text = tokenizer(prompt, return_tensors="pt")
with torch.no_grad():
eager_outputs = model.generate(
**input_text,
max_new_tokens=max_new_tokens_to_generate,
do_sample=False, # Use greedy decoding to match the exported model
cache_implementation="hybrid",
)
eager_generated_text = tokenizer.decode(eager_outputs[0], skip_special_tokens=True)
logging.info(f"\nEager generated texts: '{eager_generated_text}'")
self.assertEqual(export_generated_text, eager_generated_text)
def test_dynamic_sliding_window_is_default(self):
"""
Test that the dynamic sliding window cache (added in #40039) is the default cache implementation for Gemma3
models, despite the fact that Hub checkpoints may have `cache_implementation="hybrid"` (static sliding window).
"""
model_id = "google/gemma-3-1b-it"
model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto")
# the default cache is static sliding window
self.assertEqual(model.config.cache_implementation, "hybrid")
self.assertEqual(model.generation_config.cache_implementation, "hybrid")
tokenizer = AutoTokenizer.from_pretrained(model_id)
prompt = "What is the capital of France?"
model_inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
forward_outputs = model(**model_inputs)
self.assertIn("DynamicSlidingWindowLayer", str(forward_outputs.past_key_values))
generate_outputs = model.generate(
**model_inputs, max_new_tokens=2, do_sample=False, return_dict_in_generate=True
)
self.assertIn("DynamicSlidingWindowLayer", str(generate_outputs.past_key_values))
# If we manually specify the cache implementation = "hybrid", it will use the static sliding window cache
generate_outputs = model.generate(
**model_inputs,
max_new_tokens=2,
do_sample=False,
return_dict_in_generate=True,
cache_implementation="hybrid",
)
self.assertNotIn("DynamicSlidingWindowLayer", str(generate_outputs.past_key_values))

View File

@@ -0,0 +1,164 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import shutil
import tempfile
import unittest
from typing import Optional
from transformers import Gemma3Processor, GemmaTokenizer
from transformers.testing_utils import get_tests_dir, require_vision
from transformers.utils import is_vision_available
from ...test_processing_common import ProcessorTesterMixin
if is_vision_available():
from transformers import Gemma3ImageProcessor
SAMPLE_VOCAB = get_tests_dir("fixtures/test_sentencepiece.model")
@require_vision
class Gemma3ProcessorTest(ProcessorTesterMixin, unittest.TestCase):
processor_class = Gemma3Processor
@classmethod
def setUpClass(cls):
cls.tmpdirname = tempfile.mkdtemp()
gemma3_image_processor_kwargs = {
"do_pan_and_scan": True,
"pan_and_scan_min_crop_size": 256,
"pan_and_scan_max_num_crops": 4,
"pan_and_scan_min_ratio_to_activate": 1.2,
}
image_processor = Gemma3ImageProcessor.from_pretrained(
"google/siglip-so400m-patch14-384", **gemma3_image_processor_kwargs
)
extra_special_tokens = {
"image_token": "<image_soft_token>",
"boi_token": "<start_of_image>",
"eoi_token": "<end_of_image>",
}
tokenizer = GemmaTokenizer(SAMPLE_VOCAB, keep_accents=True, extra_special_tokens=extra_special_tokens)
processor_kwargs = cls.prepare_processor_dict()
processor = Gemma3Processor(image_processor=image_processor, tokenizer=tokenizer, **processor_kwargs)
processor.save_pretrained(cls.tmpdirname)
cls.image_token = processor.boi_token
# Copied from tests.models.llava.test_processing_llava.LlavaProcessorTest.test_get_num_vision_tokens
def test_get_num_vision_tokens(self):
"Tests general functionality of the helper used internally in vLLM"
processor = self.get_processor()
output = processor._get_num_multimodal_tokens(image_sizes=[(100, 100), (300, 100), (500, 30)])
self.assertTrue("num_image_tokens" in output)
self.assertEqual(len(output["num_image_tokens"]), 3)
self.assertTrue("num_image_patches" in output)
self.assertEqual(len(output["num_image_patches"]), 3)
@classmethod
def tearDownClass(cls):
shutil.rmtree(cls.tmpdirname, ignore_errors=True)
# TODO: raushan or arthur: add the real chat template
@staticmethod
def prepare_processor_dict():
return {
"chat_template": "{{ bos_token }}\n{%- if messages[0]['role'] == 'system' -%}\n {%- set first_user_prefix = messages[0]['content'][0]['text'] + '\n\n' -%}\n {%- set loop_messages = messages[1:] -%}\n{%- else -%}\n {%- set first_user_prefix = \"\" -%}\n {%- set loop_messages = messages -%}\n{%- endif -%}\n{%- for message in loop_messages -%}\n {%- if (message['role'] == 'user') != (loop.index0 % 2 == 0) -%}\n {{ raise_exception(\"Conversation roles must alternate user/assistant/user/assistant/...\") }}\n {%- endif -%}\n {%- if (message['role'] == 'assistant') -%}\n {%- set role = \"model\" -%}\n {%- else -%}\n {%- set role = message['role'] -%}\n {%- endif -%}\n {{ '<start_of_turn>' + role + '\n' + (first_user_prefix if loop.first else \"\") }}\n {%- if message['content'] is string -%}\n {{ message['content'] | trim }}\n {%- elif message['content'] is iterable -%}\n {%- for item in message['content'] -%}\n {%- if item['type'] == 'image' -%}\n {{ '<start_of_image>' }}\n {%- elif item['type'] == 'text' -%}\n {{ item['text'] | trim }}\n {%- endif -%}\n {%- endfor -%}\n {%- else -%}\n {{ raise_exception(\"Invalid content type\") }}\n {%- endif -%}\n {{ '<end_of_turn>\n' }}\n{%- endfor -%}\n{%- if add_generation_prompt -%}\n {{'<start_of_turn>model\n'}}\n{%- endif -%}\n", "image_seq_length": 3,
} # fmt: skip
# Override as Gemma3 needs images to be an explicitly nested batch
def prepare_image_inputs(self, batch_size: Optional[int] = None):
"""This function prepares a list of PIL images for testing"""
images = super().prepare_image_inputs(batch_size)
if isinstance(images, (list, tuple)):
images = [[image] for image in images]
return images
def test_text_with_image_tokens(self):
image_processor = self.get_component("image_processor")
tokenizer = self.get_component("tokenizer")
processor = self.processor_class(tokenizer=tokenizer, image_processor=image_processor)
text_multi_images = f"{processor.boi_token}{processor.boi_token}Dummy text!"
text_single_image = f"{processor.boi_token}Dummy text!"
text_no_image = "Dummy text!"
image = self.prepare_image_inputs()
# If text has no image tokens, image should be `None`
with self.assertRaises(ValueError):
_ = processor(text=text_no_image, images=image, return_tensors="np")
# We can't be sure what is users intention: if user wants one image per text OR two images for first text and no image for second text
with self.assertRaises(ValueError):
_ = processor(text=[text_single_image, text_single_image], images=[image, image], return_tensors="np")
# The users is expected to be explicit about which image belong to which text by nesting the images list
out_multiimages = processor(text=text_multi_images, images=[image, image], return_tensors="np")
out_batch_oneimage = processor(
text=[text_single_image, text_single_image], images=[[image], [image]], return_tensors="np"
)
self.assertListEqual(
out_batch_oneimage[self.images_input_name].tolist(), out_multiimages[self.images_input_name].tolist()
)
def test_pan_and_scan(self):
processor_components = self.prepare_components()
processor_kwargs = self.prepare_processor_dict()
processor = self.processor_class(**processor_components, **processor_kwargs)
input_str = self.prepare_text_inputs(modalities="image")
image_input = self.prepare_image_inputs()
inputs = processor(
text=input_str,
images=image_input,
return_tensors="np",
do_pan_and_scan=True,
image_seq_length=2,
pan_and_scan_min_crop_size=10,
)
# base image + 4 crops
self.assertEqual(len(inputs[self.images_input_name]), 5)
self.assertEqual(len(inputs[self.text_input_name][0]), 67)
def test_special_mm_token_truncation(self):
"""Tests that special vision tokens do not get truncated when `truncation=True` is set."""
processor = self.get_processor()
input_str = self.prepare_text_inputs(batch_size=2, modalities="image")
image_input = self.prepare_image_inputs(batch_size=2)
_ = processor(
text=input_str,
images=image_input,
return_tensors="pt",
truncation=None,
padding=True,
)
with self.assertRaises(ValueError):
_ = processor(
text=input_str,
images=image_input,
return_tensors="pt",
truncation=True,
padding=True,
max_length=5,
)