init
This commit is contained in:
0
transformers/tests/models/sam2_video/__init__.py
Normal file
0
transformers/tests/models/sam2_video/__init__.py
Normal file
546
transformers/tests/models/sam2_video/test_modeling_sam2_video.py
Normal file
546
transformers/tests/models/sam2_video/test_modeling_sam2_video.py
Normal file
@@ -0,0 +1,546 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Testing suite for the PyTorch SAM2 model."""
|
||||
|
||||
import gc
|
||||
import unittest
|
||||
|
||||
import requests
|
||||
|
||||
from transformers.testing_utils import (
|
||||
backend_empty_cache,
|
||||
slow,
|
||||
torch_device,
|
||||
)
|
||||
from transformers.utils import is_torch_available, is_vision_available
|
||||
from transformers.video_utils import load_video
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
|
||||
from transformers import Sam2VideoModel, Sam2VideoProcessor
|
||||
|
||||
|
||||
if is_vision_available():
|
||||
from PIL import Image
|
||||
|
||||
|
||||
def prepare_image():
    """Fetch the truck test fixture and return it as an RGB PIL image."""
    img_url = "https://huggingface.co/datasets/hf-internal-testing/sam2-fixtures/resolve/main/truck.jpg"
    response = requests.get(img_url, stream=True)
    return Image.open(response.raw).convert("RGB")
|
||||
|
||||
|
||||
def prepare_groceries_image():
    """Fetch the groceries test fixture and return it as an RGB PIL image."""
    img_url = "https://huggingface.co/datasets/hf-internal-testing/sam2-fixtures/resolve/main/groceries.jpg"
    response = requests.get(img_url, stream=True)
    return Image.open(response.raw).convert("RGB")
|
||||
|
||||
|
||||
def prepare_dog_img():
    """Fetch the dog documentation image and return it as an RGB PIL image."""
    img_url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/model_doc/dog-sam.png"
    response = requests.get(img_url, stream=True)
    return Image.open(response.raw).convert("RGB")
|
||||
|
||||
|
||||
def prepare_video():
    """Fetch the bedroom test video and return its decoded frames."""
    video_url = "https://huggingface.co/datasets/hf-internal-testing/sam2-fixtures/resolve/main/bedroom.mp4"
    frames, _ = load_video(video_url)
    return frames
|
||||
|
||||
|
||||
@slow
class Sam2VideoModelIntegrationTest(unittest.TestCase):
    """Integration tests for the PyTorch SAM2 video model.

    Each test loads the ``facebook/sam2.1-hiera-tiny`` checkpoint, prompts a
    short fixture video (points / boxes / masks), and compares the predicted
    mask logits against recorded reference values.
    """

    def setUp(self):
        super().setUp()
        # float32 keeps the recorded reference logits comparable across devices
        self.video_model = Sam2VideoModel.from_pretrained("facebook/sam2.1-hiera-tiny").to(torch.float32)
        self.processor = Sam2VideoProcessor.from_pretrained("facebook/sam2.1-hiera-tiny")
        self.video_model.to(torch_device)
        self.video_model.eval()

    def tearDown(self):
        super().tearDown()
        # clean-up as much as possible GPU memory occupied by PyTorch
        gc.collect()
        backend_empty_cache(torch_device)

    def _propagate_and_stack(self, inference_session, raw_video, **propagate_kwargs):
        """Propagate masks through the video (2 extra frames) and stack the
        post-processed per-frame masks into one ``(num_frames, ...)`` tensor."""
        frames = []
        for sam2_video_output in self.video_model.propagate_in_video_iterator(
            inference_session=inference_session,
            max_frame_num_to_track=2,
            **propagate_kwargs,
        ):
            video_res_masks = self.processor.post_process_masks(
                [sam2_video_output.pred_masks], [raw_video.shape[-3:-1]], binarize=False
            )[0]
            frames.append(video_res_masks)
        return torch.stack(frames, dim=0)

    def test_inference_mask_generation_video_one_point(self):
        raw_video = prepare_video()
        inference_session = self.processor.init_video_session(video=raw_video, inference_device=torch_device)
        ann_frame_idx = 0  # the frame index we interact with
        ann_obj_id = 1  # give a unique id to each object we interact with (it can be any integers)

        self.processor.add_inputs_to_inference_session(
            inference_session=inference_session,
            frame_idx=ann_frame_idx,
            obj_ids=ann_obj_id,
            input_points=[[[[210, 350]]]],
            input_labels=[[[1]]],
        )
        outputs = self.video_model(inference_session=inference_session, frame_idx=ann_frame_idx)
        low_res_masks = outputs.pred_masks
        self.assertEqual(low_res_masks.shape, (1, 1, 256, 256))
        video_res_masks = self.processor.post_process_masks(
            [low_res_masks], [raw_video.shape[-3:-1]], binarize=False
        )[0]
        self.assertEqual(video_res_masks.shape, (1, 1, raw_video.shape[-3], raw_video.shape[-2]))
        torch.testing.assert_close(
            video_res_masks[0, 0, :3, :3],
            torch.tensor(
                [[-21.4113, -21.4113, -22.9687], [-23.3090, -23.3090, -24.2606], [-27.5705, -27.5705, -27.1616]]
            ).to(torch_device),
            atol=1e-4,
            rtol=1e-4,
        )

        # test propagate in video frames
        frames = self._propagate_and_stack(inference_session, raw_video)
        self.assertEqual(frames.shape, (3, 1, 1, raw_video.shape[-3], raw_video.shape[-2]))
        torch.testing.assert_close(
            frames[:3, :, :, :2, :2],
            torch.tensor(
                [
                    [[[[-21.4113, -21.4113], [-23.3090, -23.3090]]]],
                    [[[[-20.1003, -20.1003], [-21.2294, -21.2294]]]],
                    [[[[-19.9619, -19.9619], [-21.3060, -21.3060]]]],
                ],
            ).to(torch_device),
            atol=1e-4,
            rtol=1e-4,
        )

    def test_inference_mask_generation_video_one_point_propagate_in_video_directly(self):
        raw_video = prepare_video()
        inference_session = self.processor.init_video_session(video=raw_video, inference_device=torch_device)
        ann_frame_idx = 0  # the frame index we interact with
        ann_obj_id = 1  # give a unique id to each object we interact with (it can be any integers)

        self.processor.add_inputs_to_inference_session(
            inference_session=inference_session,
            frame_idx=ann_frame_idx,
            obj_ids=ann_obj_id,
            input_points=[[[[210, 350]]]],
            input_labels=[[[1]]],
        )
        # propagate directly without first running a single-frame forward pass
        frames = self._propagate_and_stack(inference_session, raw_video, start_frame_idx=ann_frame_idx)
        self.assertEqual(frames.shape, (3, 1, 1, raw_video.shape[-3], raw_video.shape[-2]))
        torch.testing.assert_close(
            frames[:3, :, :, :2, :2],
            torch.tensor(
                [
                    [[[[-21.4113, -21.4113], [-23.3090, -23.3090]]]],
                    [[[[-20.1003, -20.1003], [-21.2294, -21.2294]]]],
                    [[[[-19.9619, -19.9619], [-21.3060, -21.3060]]]],
                ]
            ).to(torch_device),
            atol=1e-4,
            rtol=1e-4,
        )

    def test_inference_mask_generation_video_multi_points(self):
        raw_video = prepare_video()
        inference_session = self.processor.init_video_session(video=raw_video, inference_device=torch_device)
        ann_frame_idx = 0  # the frame index we interact with
        ann_obj_id = 1  # give a unique id to each object we interact with (it can be any integers)

        self.processor.add_inputs_to_inference_session(
            inference_session=inference_session,
            frame_idx=ann_frame_idx,
            obj_ids=ann_obj_id,
            input_points=[[[[210, 350], [250, 220]]]],
            input_labels=[[[1, 1]]],
        )
        outputs = self.video_model(inference_session=inference_session, frame_idx=ann_frame_idx)
        low_res_masks = outputs.pred_masks
        video_res_masks = self.processor.post_process_masks(
            [low_res_masks], [raw_video.shape[-3:-1]], binarize=False
        )[0]
        self.assertEqual(low_res_masks.shape, (1, 1, 256, 256))
        self.assertEqual(video_res_masks.shape, (1, 1, raw_video.shape[-3], raw_video.shape[-2]))
        torch.testing.assert_close(
            video_res_masks[0, 0, :3, :3],
            torch.tensor(
                [[-11.1487, -11.1487, -11.4202], [-11.6522, -11.6522, -11.8057], [-12.7829, -12.7829, -12.6715]]
            ).to(torch_device),
            atol=1e-4,
            rtol=1e-4,
        )

        # test propagate in video frames
        frames = self._propagate_and_stack(inference_session, raw_video, start_frame_idx=ann_frame_idx)
        self.assertEqual(frames.shape, (3, 1, 1, raw_video.shape[-3], raw_video.shape[-2]))
        # higher tolerance due to errors propagating from frame to frame
        torch.testing.assert_close(
            frames[:3, :, :, :2, :2],
            torch.tensor(
                [
                    [[[[-11.1487, -11.1487], [-11.6522, -11.6522]]]],
                    [[[[-15.3821, -15.3821], [-16.0333, -16.0333]]]],
                    [[[[-15.4855, -15.4855], [-16.4230, -16.4230]]]],
                ]
            ).to(torch_device),
            atol=1e-2,
            rtol=1e-2,
        )

    def test_inference_mask_generation_video_one_bb(self):
        raw_video = prepare_video()
        inference_session = self.processor.init_video_session(video=raw_video, inference_device=torch_device)
        ann_frame_idx = 0  # the frame index we interact with
        ann_obj_id = 1  # give a unique id to each object we interact with (it can be any integers)

        self.processor.add_inputs_to_inference_session(
            inference_session=inference_session,
            frame_idx=ann_frame_idx,
            obj_ids=ann_obj_id,
            input_boxes=[[[300, 0, 500, 400]]],
        )
        outputs = self.video_model(inference_session=inference_session, frame_idx=ann_frame_idx)
        low_res_masks = outputs.pred_masks
        video_res_masks = self.processor.post_process_masks(
            [low_res_masks], [raw_video.shape[-3:-1]], binarize=False
        )[0]
        self.assertEqual(low_res_masks.shape, (1, 1, 256, 256))
        self.assertEqual(video_res_masks.shape, (1, 1, raw_video.shape[-3], raw_video.shape[-2]))
        torch.testing.assert_close(
            video_res_masks[0, 0, :3, :3],
            torch.tensor(
                [[-13.1427, -13.1427, -13.6418], [-13.7753, -13.7753, -14.1144], [-15.1957, -15.1957, -15.1757]]
            ).to(torch_device),
            atol=1e-4,
            rtol=1e-4,
        )

        # test propagate in video frames
        frames = self._propagate_and_stack(inference_session, raw_video, start_frame_idx=ann_frame_idx)
        self.assertEqual(frames.shape, (3, 1, 1, raw_video.shape[-3], raw_video.shape[-2]))
        # higher tolerance due to errors propagating from frame to frame
        torch.testing.assert_close(
            frames[:3, :, :, :2, :2],
            torch.tensor(
                [
                    [[[[-13.1427, -13.1427], [-13.7753, -13.7753]]]],
                    [[[[-14.9998, -14.9998], [-15.7086, -15.7086]]]],
                    [[[[-15.4558, -15.4558], [-16.1649, -16.1649]]]],
                ]
            ).to(torch_device),
            atol=1e-2,
            rtol=1e-2,
        )

    def test_inference_mask_generation_video_one_point_one_bb(self):
        raw_video = prepare_video()
        inference_session = self.processor.init_video_session(video=raw_video, inference_device=torch_device)
        ann_frame_idx = 0  # the frame index we interact with
        ann_obj_id = 1  # give a unique id to each object we interact with (it can be any integers)

        self.processor.add_inputs_to_inference_session(
            inference_session=inference_session,
            frame_idx=ann_frame_idx,
            obj_ids=ann_obj_id,
            input_boxes=[[[300, 0, 500, 400]]],
            input_points=[[[[460, 60]]]],
            input_labels=[[[1]]],
        )
        outputs = self.video_model(inference_session=inference_session, frame_idx=ann_frame_idx)
        low_res_masks = outputs.pred_masks
        video_res_masks = self.processor.post_process_masks(
            [low_res_masks], [raw_video.shape[-3:-1]], binarize=False
        )[0]
        self.assertEqual(low_res_masks.shape, (1, 1, 256, 256))
        self.assertEqual(video_res_masks.shape, (1, 1, raw_video.shape[-3], raw_video.shape[-2]))
        torch.testing.assert_close(
            video_res_masks[0, 0, :3, :3],
            torch.tensor(
                [[-12.3525, -12.3525, -12.8907], [-13.0608, -13.0608, -13.4079], [-14.6511, -14.6511, -14.5694]]
            ).to(torch_device),
            atol=1e-4,
            rtol=1e-4,
        )

        # test propagate in video frames
        frames = self._propagate_and_stack(inference_session, raw_video, start_frame_idx=ann_frame_idx)
        self.assertEqual(frames.shape, (3, 1, 1, raw_video.shape[-3], raw_video.shape[-2]))
        # higher tolerance due to errors propagating from frame to frame
        torch.testing.assert_close(
            frames[:3, :, :, :2, :2],
            torch.tensor(
                [
                    [[[[-12.3525, -12.3525], [-13.0608, -13.0608]]]],
                    [[[[-15.8181, -15.8181], [-16.4163, -16.4163]]]],
                    [[[[-15.8900, -15.8900], [-16.5953, -16.5953]]]],
                ]
            ).to(torch_device),
            atol=1e-2,
            rtol=1e-2,
        )

    def test_inference_mask_generation_video_multi_objects_multi_points(self):
        raw_video = prepare_video()
        inference_session = self.processor.init_video_session(video=raw_video, inference_device=torch_device)
        ann_frame_idx = 0  # the frame index we interact with
        ann_obj_ids = [2, 3]  # give a unique id to each object we interact with (it can be any integers)

        self.processor.add_inputs_to_inference_session(
            inference_session=inference_session,
            frame_idx=ann_frame_idx,
            obj_ids=ann_obj_ids,
            input_points=[[[[200, 300], [230, 250], [275, 175]], [[400, 150]]]],
            input_labels=[[[1, 1, 0], [1]]],
        )
        outputs = self.video_model(inference_session=inference_session, frame_idx=ann_frame_idx)
        low_res_masks = outputs.pred_masks
        video_res_masks = self.processor.post_process_masks(
            [low_res_masks], [raw_video.shape[-3:-1]], binarize=False
        )[0]
        self.assertEqual(low_res_masks.shape, (2, 1, 256, 256))
        self.assertEqual(video_res_masks.shape, (2, 1, raw_video.shape[-3], raw_video.shape[-2]))
        torch.testing.assert_close(
            video_res_masks[:, 0, :2, :2],  # first object
            torch.tensor(
                [[[-12.6294, -12.6294], [-13.3659, -13.3659]], [[-20.3319, -20.3319], [-22.0491, -22.0491]]]
            ).to(torch_device),
            atol=1e-4,
            rtol=1e-4,
        )

        # test propagate in video frames
        frames = self._propagate_and_stack(inference_session, raw_video, start_frame_idx=ann_frame_idx)
        self.assertEqual(frames.shape, (3, 2, 1, raw_video.shape[-3], raw_video.shape[-2]))
        torch.testing.assert_close(
            frames[:3, :, :, :2, :2],
            torch.tensor(
                [
                    [[[[-12.6294, -12.6294], [-13.3659, -13.3659]]], [[[-20.3319, -20.3319], [-22.0491, -22.0491]]]],
                    [[[[-18.5249, -18.5249], [-19.5830, -19.5830]]], [[[-17.5537, -17.5537], [-19.2259, -19.2259]]]],
                    [[[[-14.2722, -14.2722], [-15.4622, -15.4622]]], [[[-18.3185, -18.3185], [-20.0314, -20.0314]]]],
                ]
            ).to(torch_device),
            atol=1e-4,
            rtol=1e-4,
        )

    def test_inference_mask_generation_video_batched_bb(self):
        raw_video = prepare_video()
        inference_session = self.processor.init_video_session(video=raw_video, inference_device=torch_device)
        ann_frame_idx = 0  # the frame index we interact with
        ann_obj_ids = [2, 3]  # give a unique id to each object we interact with (it can be any integers)

        self.processor.add_inputs_to_inference_session(
            inference_session=inference_session,
            frame_idx=ann_frame_idx,
            obj_ids=ann_obj_ids,
            input_boxes=[[[300, 0, 500, 400], [400, 0, 600, 400]]],
        )

        frames = self._propagate_and_stack(inference_session, raw_video, start_frame_idx=ann_frame_idx)
        self.assertEqual(frames.shape, (3, 2, 1, raw_video.shape[-3], raw_video.shape[-2]))
        torch.testing.assert_close(
            frames[:3, :, :, :2, :2],
            torch.tensor(
                [
                    [[[[-13.1427, -13.1427], [-13.7753, -13.7753]]], [[[-8.4576, -8.4576], [-8.7329, -8.7329]]]],
                    [[[[-14.9998, -14.9998], [-15.7086, -15.7086]]], [[[-9.2998, -9.2998], [-9.8947, -9.8947]]]],
                    [[[[-15.4558, -15.4558], [-16.1649, -16.1649]]], [[[-10.4880, -10.4880], [-11.2098, -11.2098]]]],
                ]
            ).to(torch_device),
            atol=1e-4,
            rtol=1e-4,
        )

    def test_inference_propagate_video_from_mask_input(self):
        raw_video = prepare_video()
        inference_session = self.processor.init_video_session(video=raw_video, inference_device=torch_device)
        ann_frame_idx = 0  # the frame index we interact with
        ann_obj_id = 1  # give a unique id to each object we interact with (it can be any integers)

        # get input_mask from a first point-prompted forward pass
        self.processor.add_inputs_to_inference_session(
            inference_session=inference_session,
            frame_idx=ann_frame_idx,
            obj_ids=ann_obj_id,
            input_points=[[[[210, 350], [250, 220]]]],
            input_labels=[[[1, 1]]],
        )
        sam2_video_output = self.video_model(inference_session=inference_session, frame_idx=ann_frame_idx)

        # set mask as input
        self.processor.add_inputs_to_inference_session(
            inference_session=inference_session,
            frame_idx=ann_frame_idx,
            obj_ids=ann_obj_id,
            input_masks=self.processor.post_process_masks(
                [sam2_video_output.pred_masks], [raw_video.shape[-3:-1]], binarize=False
            )[0],
        )
        sam2_video_output = self.video_model(inference_session=inference_session, frame_idx=ann_frame_idx)
        low_res_masks = sam2_video_output.pred_masks
        self.assertEqual(low_res_masks.shape, (1, 1, 256, 256))
        video_res_masks = self.processor.post_process_masks(
            [low_res_masks], [raw_video.shape[-3:-1]], binarize=False
        )[0]
        self.assertEqual(video_res_masks.shape, (1, 1, raw_video.shape[-3], raw_video.shape[-2]))
        torch.testing.assert_close(
            video_res_masks[0, 0, :3, :3],
            torch.tensor(
                [[-10.0000, -10.0000, -10.0000], [-10.0000, -10.0000, -10.0000], [-10.0000, -10.0000, -10.0000]]
            ).to(torch_device),
            atol=1e-4,
            rtol=1e-4,
        )

        # test propagate in video frames
        frames = self._propagate_and_stack(inference_session, raw_video, start_frame_idx=ann_frame_idx)
        self.assertEqual(frames.shape, (3, 1, 1, raw_video.shape[-3], raw_video.shape[-2]))
        torch.testing.assert_close(
            frames[:3, :, :, :2, :2],
            torch.tensor(
                [
                    [[[[-10.0000, -10.0000], [-10.0000, -10.0000]]]],
                    [[[[-18.4807, -18.4807], [-19.1966, -19.1966]]]],
                    [[[[-20.0512, -20.0512], [-20.9110, -20.9110]]]],
                ],
            ).to(torch_device),
            atol=1e-4,
            rtol=1e-4,
        )

    def test_inference_propagate_on_streamed_video(self):
        raw_video = prepare_video()

        # session initialized without a video: frames are fed one at a time
        inference_session = self.processor.init_video_session(inference_device=torch_device)
        video_res_masks = []
        max_frame_num_to_track = 3
        for frame_idx, frame in enumerate(raw_video):
            if frame_idx >= max_frame_num_to_track:
                break
            inputs = self.processor(images=frame, device=torch_device, return_tensors="pt")
            if frame_idx == 0:
                # prompt only the first frame; later frames rely on propagation
                self.processor.add_inputs_to_inference_session(
                    inference_session,
                    frame_idx=0,
                    obj_ids=1,
                    input_points=[[[[210, 350], [250, 220]]]],
                    input_labels=[[[1, 1]]],
                    original_size=inputs.original_sizes[0],
                )
            sam2_video_output = self.video_model(inference_session=inference_session, frame=inputs.pixel_values[0])
            video_res_masks.append(
                self.processor.post_process_masks(
                    [sam2_video_output.pred_masks], inputs.original_sizes, binarize=False
                )[0]
            )

        video_res_masks = torch.stack(video_res_masks, dim=0)
        self.assertEqual(
            video_res_masks.shape, (max_frame_num_to_track, 1, 1, raw_video.shape[-3], raw_video.shape[-2])
        )
        # higher tolerance due to errors propagating from frame to frame
        torch.testing.assert_close(
            video_res_masks[:3, :, :, :2, :2],
            torch.tensor(
                [
                    [[[[-11.1487, -11.1487], [-11.6522, -11.6522]]]],
                    [[[[-15.3821, -15.3821], [-16.0333, -16.0333]]]],
                    [[[[-15.4855, -15.4855], [-16.4230, -16.4230]]]],
                ]
            ).to(torch_device),
            atol=1e-2,
            rtol=1e-2,
        )
|
||||
@@ -0,0 +1,153 @@
|
||||
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import shutil
|
||||
import tempfile
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
|
||||
from transformers.testing_utils import (
|
||||
require_torch,
|
||||
require_torchvision,
|
||||
require_vision,
|
||||
)
|
||||
from transformers.utils import is_torch_available, is_vision_available
|
||||
|
||||
|
||||
if is_vision_available():
|
||||
from transformers import AutoProcessor, Sam2ImageProcessorFast, Sam2VideoProcessor, Sam2VideoVideoProcessor
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
|
||||
|
||||
@require_vision
@require_torchvision
class Sam2ProcessorTest(unittest.TestCase):
    """Tests for ``Sam2VideoProcessor`` (image-processor + video-processor composite)."""

    def setUp(self):
        # Save a default processor to a temp dir so tests can round-trip it.
        self.tmpdirname = tempfile.mkdtemp()
        image_processor = Sam2ImageProcessorFast()
        video_processor = Sam2VideoVideoProcessor()
        processor = Sam2VideoProcessor(image_processor, video_processor)
        processor.save_pretrained(self.tmpdirname)

    def get_image_processor(self, **kwargs):
        """Reload the image processor from the temp dir, forwarding overrides."""
        return AutoProcessor.from_pretrained(self.tmpdirname, **kwargs).image_processor

    def get_video_processor(self, **kwargs):
        """Reload the video processor from the temp dir, forwarding overrides."""
        return AutoProcessor.from_pretrained(self.tmpdirname, **kwargs).video_processor

    def tearDown(self):
        shutil.rmtree(self.tmpdirname)

    def prepare_image_inputs(self):
        """Return a batch of one random uint8 image tensor of shape (1, 3, 30, 400)."""
        return torch.randint(0, 256, size=(1, 3, 30, 400), dtype=torch.uint8)

    def prepare_mask_inputs(self):
        """Return a batch of one random uint8 segmentation map of shape (1, 30, 400)."""
        return torch.randint(0, 256, size=(1, 30, 400), dtype=torch.uint8)

    def test_save_load_pretrained_additional_features(self):
        image_processor = self.get_image_processor()
        video_processor = self.get_video_processor()

        processor = Sam2VideoProcessor(image_processor=image_processor, video_processor=video_processor)
        processor.save_pretrained(self.tmpdirname)

        # kwargs passed at load time must reach the underlying image processor
        image_processor_add_kwargs = self.get_image_processor(do_normalize=False, padding_value=1.0)

        processor = Sam2VideoProcessor.from_pretrained(self.tmpdirname, do_normalize=False, padding_value=1.0)

        self.assertEqual(processor.image_processor.to_json_string(), image_processor_add_kwargs.to_json_string())
        self.assertIsInstance(processor.image_processor, Sam2ImageProcessorFast)
        self.assertIsInstance(processor.video_processor, Sam2VideoVideoProcessor)

    def test_image_processor_no_masks(self):
        image_processor = self.get_image_processor()
        video_processor = self.get_video_processor()

        processor = Sam2VideoProcessor(image_processor=image_processor, video_processor=video_processor)

        image_input = self.prepare_image_inputs()

        input_feat_extract = image_processor(image_input)
        input_processor = processor(images=image_input)

        # the processor must delegate image handling verbatim to the image processor
        for key in input_feat_extract.keys():
            if key == "pixel_values":
                for input_feat_extract_item, input_processor_item in zip(
                    input_feat_extract[key], input_processor[key]
                ):
                    np.testing.assert_array_equal(input_feat_extract_item, input_processor_item)
            else:
                self.assertEqual(input_feat_extract[key], input_processor[key])

        for image in input_feat_extract.pixel_values:
            self.assertEqual(image.shape, (3, 1024, 1024))

        for original_size in input_feat_extract.original_sizes:
            np.testing.assert_array_equal(original_size, np.array([30, 400]))

    def test_image_processor_with_masks(self):
        image_processor = self.get_image_processor()
        video_processor = self.get_video_processor()

        processor = Sam2VideoProcessor(image_processor=image_processor, video_processor=video_processor)

        image_input = self.prepare_image_inputs()
        mask_input = self.prepare_mask_inputs()

        input_feat_extract = image_processor(images=image_input, segmentation_maps=mask_input, return_tensors="pt")
        input_processor = processor(images=image_input, segmentation_maps=mask_input, return_tensors="pt")

        for key in input_feat_extract.keys():
            self.assertAlmostEqual(input_feat_extract[key].sum(), input_processor[key].sum(), delta=1e-2)

        for label in input_feat_extract.labels:
            self.assertEqual(label.shape, (256, 256))

    @require_torch
    def test_post_process_masks(self):
        image_processor = self.get_image_processor()
        video_processor = self.get_video_processor()

        processor = Sam2VideoProcessor(image_processor=image_processor, video_processor=video_processor)
        dummy_masks = [torch.ones((1, 3, 5, 5))]

        original_sizes = [[1764, 2646]]

        # original sizes as a plain list
        masks = processor.post_process_masks(dummy_masks, original_sizes)
        self.assertEqual(masks[0].shape, (1, 3, 1764, 2646))

        # original sizes as a torch tensor
        masks = processor.post_process_masks(dummy_masks, torch.tensor(original_sizes))
        self.assertEqual(masks[0].shape, (1, 3, 1764, 2646))

        # should also work with np
        dummy_masks = [np.ones((1, 3, 5, 5))]
        masks = processor.post_process_masks(dummy_masks, np.array(original_sizes))

        self.assertEqual(masks[0].shape, (1, 3, 1764, 2646))

        # non-array masks are rejected
        dummy_masks = [[1, 0], [0, 1]]
        with self.assertRaises(ValueError):
            processor.post_process_masks(dummy_masks, np.array(original_sizes))
|
||||
Reference in New Issue
Block a user