Sync from v0.13
This commit is contained in:
0
tests/multimodal/__init__.py
Normal file
0
tests/multimodal/__init__.py
Normal file
BIN
tests/multimodal/assets/corrupted.mp4
Normal file
BIN
tests/multimodal/assets/corrupted.mp4
Normal file
Binary file not shown.
BIN
tests/multimodal/assets/image1.png
Normal file
BIN
tests/multimodal/assets/image1.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 1.8 KiB |
BIN
tests/multimodal/assets/image2.png
Normal file
BIN
tests/multimodal/assets/image2.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 1.8 KiB |
BIN
tests/multimodal/assets/rgba.png
Normal file
BIN
tests/multimodal/assets/rgba.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 219 KiB |
139
tests/multimodal/test_audio.py
Normal file
139
tests/multimodal/test_audio.py
Normal file
@@ -0,0 +1,139 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
# test_audio.py
|
||||
import base64
|
||||
from pathlib import Path
|
||||
from unittest.mock import patch
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from vllm.multimodal.audio import (
|
||||
AudioMediaIO,
|
||||
AudioResampler,
|
||||
resample_audio_librosa,
|
||||
resample_audio_scipy,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def dummy_audio():
|
||||
return np.array([0.0, 0.1, 0.2, 0.3, 0.4], dtype=float)
|
||||
|
||||
|
||||
def test_resample_audio_librosa(dummy_audio):
|
||||
with patch("vllm.multimodal.audio.librosa.resample") as mock_resample:
|
||||
mock_resample.return_value = dummy_audio * 2
|
||||
out = resample_audio_librosa(dummy_audio, orig_sr=44100, target_sr=22050)
|
||||
mock_resample.assert_called_once_with(
|
||||
dummy_audio, orig_sr=44100, target_sr=22050
|
||||
)
|
||||
assert np.all(out == dummy_audio * 2)
|
||||
|
||||
|
||||
def test_resample_audio_scipy(dummy_audio):
|
||||
out_down = resample_audio_scipy(dummy_audio, orig_sr=4, target_sr=2)
|
||||
out_up = resample_audio_scipy(dummy_audio, orig_sr=2, target_sr=4)
|
||||
out_same = resample_audio_scipy(dummy_audio, orig_sr=4, target_sr=4)
|
||||
|
||||
assert len(out_down) == 3
|
||||
assert len(out_up) == 10
|
||||
assert np.all(out_same == dummy_audio)
|
||||
|
||||
|
||||
@pytest.mark.xfail(reason="resample_audio_scipy is buggy for non-integer ratios")
|
||||
def test_resample_audio_scipy_non_integer_ratio(dummy_audio):
|
||||
out = resample_audio_scipy(dummy_audio, orig_sr=5, target_sr=3)
|
||||
|
||||
expected_len = int(round(len(dummy_audio) * 3 / 5))
|
||||
assert len(out) == expected_len
|
||||
|
||||
assert isinstance(out, np.ndarray)
|
||||
assert np.isfinite(out).all()
|
||||
|
||||
|
||||
def test_audio_resampler_librosa_calls_resample(dummy_audio):
|
||||
resampler = AudioResampler(target_sr=22050, method="librosa")
|
||||
with patch("vllm.multimodal.audio.resample_audio_librosa") as mock_resample:
|
||||
mock_resample.return_value = dummy_audio
|
||||
out = resampler.resample(dummy_audio, orig_sr=44100)
|
||||
mock_resample.assert_called_once_with(
|
||||
dummy_audio, orig_sr=44100, target_sr=22050
|
||||
)
|
||||
assert np.all(out == dummy_audio)
|
||||
|
||||
|
||||
def test_audio_resampler_scipy_calls_resample(dummy_audio):
|
||||
resampler = AudioResampler(target_sr=22050, method="scipy")
|
||||
with patch("vllm.multimodal.audio.resample_audio_scipy") as mock_resample:
|
||||
mock_resample.return_value = dummy_audio
|
||||
out = resampler.resample(dummy_audio, orig_sr=44100)
|
||||
mock_resample.assert_called_once_with(
|
||||
dummy_audio, orig_sr=44100, target_sr=22050
|
||||
)
|
||||
assert np.all(out == dummy_audio)
|
||||
|
||||
|
||||
def test_audio_resampler_invalid_method(dummy_audio):
|
||||
resampler = AudioResampler(target_sr=22050, method="invalid")
|
||||
with pytest.raises(ValueError):
|
||||
resampler.resample(dummy_audio, orig_sr=44100)
|
||||
|
||||
|
||||
def test_audio_resampler_no_target_sr(dummy_audio):
|
||||
resampler = AudioResampler(target_sr=None)
|
||||
with pytest.raises(RuntimeError):
|
||||
resampler.resample(dummy_audio, orig_sr=44100)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def dummy_audio_bytes():
|
||||
return b"FAKEAUDIOBYTES"
|
||||
|
||||
|
||||
def test_audio_media_io_load_bytes(dummy_audio_bytes):
|
||||
audio_io = AudioMediaIO()
|
||||
with patch("vllm.multimodal.audio.librosa.load") as mock_load:
|
||||
mock_load.return_value = (np.array([0.1, 0.2]), 16000)
|
||||
out = audio_io.load_bytes(dummy_audio_bytes)
|
||||
mock_load.assert_called_once()
|
||||
assert isinstance(out[0], np.ndarray)
|
||||
assert out[1] == 16000
|
||||
|
||||
|
||||
def test_audio_media_io_load_base64(dummy_audio_bytes):
|
||||
audio_io = AudioMediaIO()
|
||||
encoded = base64.b64encode(dummy_audio_bytes).decode("utf-8")
|
||||
with patch.object(AudioMediaIO, "load_bytes") as mock_load_bytes:
|
||||
mock_load_bytes.return_value = (np.array([0.1, 0.2]), 16000)
|
||||
out = audio_io.load_base64("audio/wav", encoded)
|
||||
mock_load_bytes.assert_called_once()
|
||||
assert isinstance(out[0], np.ndarray)
|
||||
assert out[1] == 16000
|
||||
|
||||
|
||||
def test_audio_media_io_load_file():
|
||||
audio_io = AudioMediaIO()
|
||||
path = Path("/fake/path.wav")
|
||||
with patch("vllm.multimodal.audio.librosa.load") as mock_load:
|
||||
mock_load.return_value = (np.array([0.1, 0.2]), 16000)
|
||||
out = audio_io.load_file(path)
|
||||
mock_load.assert_called_once_with(path, sr=None)
|
||||
assert isinstance(out[0], np.ndarray)
|
||||
assert out[1] == 16000
|
||||
|
||||
|
||||
def test_audio_media_io_encode_base64(dummy_audio):
|
||||
audio_io = AudioMediaIO()
|
||||
media = (dummy_audio, 16000)
|
||||
with patch("vllm.multimodal.audio.soundfile.write") as mock_write:
|
||||
|
||||
def write_to_buffer(buffer, *_args, **_kwargs):
|
||||
buffer.write(b"dummy_wav_data")
|
||||
|
||||
mock_write.side_effect = write_to_buffer
|
||||
|
||||
out = audio_io.encode_base64(media)
|
||||
decoded = base64.b64decode(out)
|
||||
assert decoded == b"dummy_wav_data"
|
||||
mock_write.assert_called_once()
|
||||
520
tests/multimodal/test_cache.py
Normal file
520
tests/multimodal/test_cache.py
Normal file
@@ -0,0 +1,520 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import multiprocessing as mp
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from vllm.config import ModelConfig, ParallelConfig, VllmConfig
|
||||
from vllm.multimodal import MULTIMODAL_REGISTRY
|
||||
from vllm.multimodal.cache import (
|
||||
BaseMultiModalProcessorCache,
|
||||
BaseMultiModalReceiverCache,
|
||||
MultiModalCache,
|
||||
MultiModalProcessorCacheInItem,
|
||||
MultiModalProcessorCacheItem,
|
||||
MultiModalProcessorCacheItemMetadata,
|
||||
MultiModalProcessorSenderCache,
|
||||
MultiModalReceiverCache,
|
||||
ShmObjectStoreReceiverCache,
|
||||
ShmObjectStoreSenderCache,
|
||||
engine_receiver_cache_from_config,
|
||||
processor_cache_from_config,
|
||||
)
|
||||
from vllm.multimodal.hasher import MultiModalHasher
|
||||
from vllm.multimodal.inputs import (
|
||||
MultiModalFieldElem,
|
||||
MultiModalKwargsItem,
|
||||
MultiModalKwargsItems,
|
||||
MultiModalSharedField,
|
||||
)
|
||||
from vllm.multimodal.processing import PromptInsertion
|
||||
from vllm.utils.mem_constants import GiB_bytes, MiB_bytes
|
||||
|
||||
pytestmark = pytest.mark.cpu_test
|
||||
|
||||
|
||||
def _dummy_elem(
|
||||
modality: str,
|
||||
key: str,
|
||||
size: int,
|
||||
*,
|
||||
rng: np.random.RandomState | None = None,
|
||||
):
|
||||
if rng is None:
|
||||
data = torch.empty((size,), dtype=torch.int8)
|
||||
else:
|
||||
data = torch.from_numpy(rng.randint(4, size=(size,), dtype=np.int8))
|
||||
|
||||
return MultiModalFieldElem(
|
||||
modality=modality,
|
||||
key=key,
|
||||
data=data,
|
||||
field=MultiModalSharedField(batch_size=1),
|
||||
)
|
||||
|
||||
|
||||
def _dummy_item(
|
||||
modality: str,
|
||||
size_by_key: dict[str, int],
|
||||
*,
|
||||
rng: np.random.RandomState | None = None,
|
||||
):
|
||||
return MultiModalKwargsItem.from_elems(
|
||||
[_dummy_elem(modality, key, size, rng=rng) for key, size in size_by_key.items()]
|
||||
)
|
||||
|
||||
|
||||
def _dummy_items(
|
||||
size_by_key_modality: dict[str, dict[str, int]],
|
||||
*,
|
||||
rng: np.random.RandomState | None = None,
|
||||
):
|
||||
return MultiModalKwargsItems.from_seq(
|
||||
[
|
||||
_dummy_item(modality, size_by_key, rng=rng)
|
||||
for modality, size_by_key in size_by_key_modality.items()
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("item", "expected_size"),
|
||||
[
|
||||
(_dummy_item("a", {"a1": 100}), 100),
|
||||
(_dummy_item("a", {"a1": 100, "a2": 110}), 210),
|
||||
(_dummy_items({"a": {"a1": 100, "a2": 110}, "b": {"b1": 120, "b2": 130}}), 460), # noqa: E501
|
||||
],
|
||||
)
|
||||
def test_cache_item_size(item, expected_size):
|
||||
cache = MultiModalCache.get_lru_cache(2048, type(item))
|
||||
|
||||
cache[""] = item
|
||||
assert cache.currsize == expected_size
|
||||
|
||||
prompt_update = PromptInsertion("dummy", "target", "insertion").resolve(0)
|
||||
|
||||
cache[""] = MultiModalProcessorCacheItem(item, [prompt_update])
|
||||
assert cache.currsize == expected_size
|
||||
|
||||
cache[""] = MultiModalProcessorCacheItemMetadata(item, [prompt_update])
|
||||
assert cache.currsize == expected_size
|
||||
|
||||
cache[""] = item.get_data()
|
||||
assert cache.currsize == expected_size
|
||||
|
||||
|
||||
def _create_vllm_config(
|
||||
*,
|
||||
mm_processor_cache_gb: float,
|
||||
enable_ipc: bool,
|
||||
):
|
||||
return VllmConfig(
|
||||
model_config=ModelConfig(
|
||||
model="llava-hf/llava-onevision-qwen2-0.5b-ov-hf",
|
||||
mm_processor_cache_gb=mm_processor_cache_gb,
|
||||
),
|
||||
parallel_config=ParallelConfig(data_parallel_size=1 if enable_ipc else 2),
|
||||
)
|
||||
|
||||
|
||||
def _compare_caches(
|
||||
config_0: VllmConfig,
|
||||
config_1: VllmConfig,
|
||||
*,
|
||||
item_capacity: int = 8,
|
||||
hit_rate: float = 0.5,
|
||||
max_items_per_iter: int = 3,
|
||||
is_cached_calls_per_iter: int,
|
||||
n_iter: int = 100,
|
||||
seed: int = 0,
|
||||
):
|
||||
cache_0_p0 = processor_cache_from_config(config_0, MULTIMODAL_REGISTRY)
|
||||
cache_0_p1 = engine_receiver_cache_from_config(config_0, MULTIMODAL_REGISTRY)
|
||||
cache_1_p0 = processor_cache_from_config(config_1, MULTIMODAL_REGISTRY)
|
||||
cache_1_p1 = engine_receiver_cache_from_config(config_1, MULTIMODAL_REGISTRY)
|
||||
|
||||
cache_size_gb = max(
|
||||
config_0.model_config.multimodal_config.mm_processor_cache_gb,
|
||||
config_1.model_config.multimodal_config.mm_processor_cache_gb,
|
||||
)
|
||||
item_size_gb = int(cache_size_gb / item_capacity)
|
||||
|
||||
rng = np.random.RandomState(seed)
|
||||
all_items = [
|
||||
_dummy_item("item", {"key": item_size_gb}, rng=rng)
|
||||
for _ in range(int(item_capacity / hit_rate))
|
||||
]
|
||||
all_hashes = [
|
||||
MultiModalHasher.hash_kwargs(item=item.get_data()) for item in all_items
|
||||
]
|
||||
|
||||
prompt_update = PromptInsertion("dummy", "target", "insertion").resolve(0)
|
||||
|
||||
for it in range(n_iter):
|
||||
num_items_to_select = rng.randint(0, max_items_per_iter)
|
||||
item_idxs_to_select = rng.choice(len(all_items), num_items_to_select)
|
||||
|
||||
selected_items = [all_items[idx] for idx in item_idxs_to_select]
|
||||
selected_hashes = [all_hashes[idx] for idx in item_idxs_to_select]
|
||||
|
||||
if cache_0_p0 is None:
|
||||
cache_0_p0_out = selected_items
|
||||
else:
|
||||
for _ in range(is_cached_calls_per_iter):
|
||||
cache_0_p0.is_cached(selected_hashes)
|
||||
|
||||
cache_0_p0_out = [
|
||||
item
|
||||
for item, _ in cache_0_p0.get_and_update(
|
||||
[(item, [prompt_update]) for item in selected_items],
|
||||
selected_hashes,
|
||||
)
|
||||
]
|
||||
|
||||
if cache_1_p0 is None:
|
||||
cache_1_p0_out = selected_items
|
||||
else:
|
||||
for _ in range(is_cached_calls_per_iter):
|
||||
cache_1_p0.is_cached(selected_hashes)
|
||||
|
||||
cache_1_p0_out = [
|
||||
item
|
||||
for item, _ in cache_1_p0.get_and_update(
|
||||
[(item, [prompt_update]) for item in selected_items],
|
||||
selected_hashes,
|
||||
)
|
||||
]
|
||||
|
||||
if cache_0_p1 is None:
|
||||
cache_0_p1_out = cache_0_p0_out
|
||||
else:
|
||||
cache_0_p1_out = cache_0_p1.get_and_update(cache_0_p0_out, selected_hashes)
|
||||
|
||||
if cache_1_p1 is None:
|
||||
cache_1_p1_out = cache_1_p0_out
|
||||
else:
|
||||
cache_1_p1_out = cache_1_p1.get_and_update(cache_1_p0_out, selected_hashes)
|
||||
|
||||
assert cache_0_p1_out == cache_1_p1_out, f"Failed at {it=}"
|
||||
|
||||
|
||||
@pytest.mark.parametrize("is_cached_calls_per_iter", [1, 2, 3])
|
||||
def test_ipc_enable_disable_consistency(is_cached_calls_per_iter):
|
||||
cache_size_gb = 1 / (1 << 20)
|
||||
|
||||
vllm_config_ipc_enabled = _create_vllm_config(
|
||||
mm_processor_cache_gb=cache_size_gb,
|
||||
enable_ipc=True,
|
||||
)
|
||||
vllm_config_ipc_disabled = _create_vllm_config(
|
||||
mm_processor_cache_gb=0,
|
||||
enable_ipc=False,
|
||||
)
|
||||
vllm_config_cache_disabled = _create_vllm_config(
|
||||
mm_processor_cache_gb=cache_size_gb,
|
||||
enable_ipc=True,
|
||||
)
|
||||
|
||||
_compare_caches(
|
||||
vllm_config_ipc_enabled,
|
||||
vllm_config_ipc_disabled,
|
||||
is_cached_calls_per_iter=is_cached_calls_per_iter,
|
||||
)
|
||||
_compare_caches(
|
||||
vllm_config_ipc_disabled,
|
||||
vllm_config_cache_disabled,
|
||||
is_cached_calls_per_iter=is_cached_calls_per_iter,
|
||||
)
|
||||
_compare_caches(
|
||||
vllm_config_cache_disabled,
|
||||
vllm_config_ipc_enabled,
|
||||
is_cached_calls_per_iter=is_cached_calls_per_iter,
|
||||
)
|
||||
|
||||
|
||||
def _run_test_cache_eviction_lru(
|
||||
p0_cache: BaseMultiModalProcessorCache,
|
||||
p1_cache: BaseMultiModalReceiverCache,
|
||||
base_item_size: int,
|
||||
):
|
||||
request1_hashes = [
|
||||
"image_A",
|
||||
"image_B",
|
||||
"image_C",
|
||||
]
|
||||
request1_items = {
|
||||
h: MultiModalKwargsItem.dummy(h, nbytes=2 * base_item_size)
|
||||
for h in request1_hashes
|
||||
}
|
||||
|
||||
request2_hashes = ["image_D", "image_E", "image_A", "image_C"]
|
||||
request2_items = {
|
||||
h: MultiModalKwargsItem.dummy(h, nbytes=1 * base_item_size)
|
||||
for h in request2_hashes
|
||||
}
|
||||
|
||||
##########################
|
||||
# STEP 1: Request 1 send
|
||||
##########################
|
||||
sender_is_cached_item_req1 = p0_cache.is_cached(request1_hashes)
|
||||
# Cache is empty
|
||||
assert sender_is_cached_item_req1 == [False, False, False]
|
||||
|
||||
# Touch all mm hash for P0 Cache before process
|
||||
for mm_hash in request1_hashes:
|
||||
p0_cache.touch_sender_cache_item(mm_hash)
|
||||
|
||||
###########################
|
||||
# Process request 1 for P0 Cache
|
||||
###########################
|
||||
item_tuple: MultiModalProcessorCacheInItem
|
||||
for i, h in enumerate(request1_hashes):
|
||||
# Use precomputed cache state
|
||||
is_cached = sender_is_cached_item_req1[i]
|
||||
item_tuple = (request1_items[h], []) if not is_cached else None
|
||||
print(f"Request 1: key={h} | cached={is_cached}")
|
||||
|
||||
p0_cache.get_and_update_item(item_tuple, h)
|
||||
|
||||
###########################
|
||||
# Process request 1 for P1 Cache
|
||||
###########################
|
||||
# Touch all mm hash for P1 Cache before process
|
||||
for mm_hash in request1_hashes:
|
||||
p1_cache.touch_receiver_cache_item(mm_hash)
|
||||
|
||||
for h in request1_hashes:
|
||||
p1_cache.get_and_update_item(request1_items[h], h)
|
||||
|
||||
expected_hashes = ["image_A", "image_B", "image_C"]
|
||||
assert list(p0_cache._cache.order) == expected_hashes
|
||||
|
||||
##########################
|
||||
# STEP 2: Request 2 send
|
||||
##########################
|
||||
sender_is_cached_item_req2 = p0_cache.is_cached(request2_hashes)
|
||||
assert sender_is_cached_item_req2 == [False, False, True, True]
|
||||
|
||||
# Touch all mm hash for P0 Cache before process
|
||||
for mm_hash in request2_hashes:
|
||||
p0_cache.touch_sender_cache_item(mm_hash)
|
||||
|
||||
###########################
|
||||
# Process request 2 for P0 Cache
|
||||
###########################
|
||||
for i, h in enumerate(request2_hashes):
|
||||
# Use precomputed cache state again
|
||||
is_cached = sender_is_cached_item_req2[i]
|
||||
item_tuple = (request2_items[h], []) if not is_cached else None
|
||||
print(f"Request 2: key={h} | cached={is_cached}")
|
||||
|
||||
p0_cache.get_and_update_item(item_tuple, h)
|
||||
|
||||
###########################
|
||||
# Process request 2 for P1 Cache
|
||||
###########################
|
||||
|
||||
# Touch all mm hash for P1 Cache before process
|
||||
for mm_hash in request2_hashes:
|
||||
p1_cache.touch_receiver_cache_item(mm_hash)
|
||||
|
||||
for h in request2_hashes:
|
||||
p1_cache.get_and_update_item(request2_items[h], h)
|
||||
|
||||
expected_hashes = ["image_D", "image_E", "image_A", "image_C"]
|
||||
assert list(p0_cache._cache.order) == expected_hashes
|
||||
|
||||
|
||||
def test_cache_eviction_lru_cache():
|
||||
model_config = ModelConfig(
|
||||
model="llava-hf/llava-onevision-qwen2-0.5b-ov-hf",
|
||||
mm_processor_cache_gb=6 / GiB_bytes,
|
||||
)
|
||||
sender_cache = MultiModalProcessorSenderCache(model_config)
|
||||
receiver_cache = MultiModalReceiverCache(model_config)
|
||||
|
||||
_run_test_cache_eviction_lru(sender_cache, receiver_cache, base_item_size=1)
|
||||
|
||||
|
||||
# This test verifies shared-memory cache eviction behavior across processor (p0)
|
||||
# and receiver (p1) caches.
|
||||
# Flow summary:
|
||||
# 1. Request 1 adds images A, B, C — completely filling the cache.
|
||||
# 2. Request 2 tries to add image_G and image_A, but image_G cannot be added because
|
||||
# cache is full and A is protected from eviction — cache remains unchanged.
|
||||
# 3. Request 3 adds image_G, image_H, image_I and image_B
|
||||
# this time, image_A is evicted, freeing 5MB space
|
||||
# and image_G, image_H successfully fits,
|
||||
# image_B is protected from eviction then image_i cannot be added.
|
||||
# This proving normal eviction and reuse behavior.
|
||||
def _run_test_cache_eviction_shm(
|
||||
p0_cache: BaseMultiModalProcessorCache,
|
||||
p1_cache: BaseMultiModalReceiverCache,
|
||||
base_item_size: int,
|
||||
):
|
||||
request1_hashes = ["image_A", "image_B", "image_C"]
|
||||
request1_items = {
|
||||
h: MultiModalKwargsItem.dummy(h, nbytes=5 * base_item_size)
|
||||
for h in request1_hashes
|
||||
}
|
||||
request1_items_p0_result = []
|
||||
|
||||
request2_hashes = ["image_G", "image_A"]
|
||||
request2_items = {
|
||||
h: MultiModalKwargsItem.dummy(
|
||||
h, nbytes=(5 if h in request1_hashes else 2) * base_item_size
|
||||
)
|
||||
for h in request2_hashes
|
||||
}
|
||||
request2_items_p0_result = []
|
||||
|
||||
request3_hashes = ["image_G", "image_H", "image_I", "image_B"]
|
||||
request3_items = {
|
||||
h: MultiModalKwargsItem.dummy(
|
||||
h, nbytes=(5 if h in request1_hashes else 2) * base_item_size
|
||||
)
|
||||
for h in request3_hashes
|
||||
}
|
||||
request3_items_p0_result = []
|
||||
|
||||
##########################
|
||||
# STEP 1: Request 1 send
|
||||
# This will fill up the cache
|
||||
##########################
|
||||
sender_is_cached_item_req1 = p0_cache.is_cached(request1_hashes)
|
||||
# Cache is empty
|
||||
assert sender_is_cached_item_req1 == [False, False, False]
|
||||
|
||||
# Touch all mm hash for P0 Cache before process
|
||||
for mm_hash in request1_hashes:
|
||||
p0_cache.touch_sender_cache_item(mm_hash)
|
||||
|
||||
###########################
|
||||
# Process request 1 for P0 Cache
|
||||
###########################
|
||||
item_tuple: MultiModalProcessorCacheInItem
|
||||
for i, h in enumerate(request1_hashes):
|
||||
# Use precomputed cache state
|
||||
is_cached = sender_is_cached_item_req1[i]
|
||||
item_tuple = (request1_items[h], []) if not is_cached else None
|
||||
print(f"Request 1: key={h} | cached={is_cached}")
|
||||
|
||||
p0_result = p0_cache.get_and_update_item(item_tuple, h)
|
||||
# Only get mm item, ignore prompt update result
|
||||
request1_items_p0_result.append(p0_result[0])
|
||||
|
||||
###########################
|
||||
# Process request 1 for P1 Cache
|
||||
###########################
|
||||
# Touch all mm hash for P1 Cache before process
|
||||
for mm_hash, mm_item in zip(request1_hashes, request1_items_p0_result):
|
||||
p1_cache.touch_receiver_cache_item(mm_hash, mm_item)
|
||||
|
||||
for mm_hash, mm_item in zip(request1_hashes, request1_items_p0_result):
|
||||
p1_cache.get_and_update_item(mm_item, mm_hash)
|
||||
|
||||
expected_hashes = ["image_A", "image_B", "image_C"]
|
||||
assert list(p0_cache._shm_cache.key_index.keys()) == expected_hashes
|
||||
|
||||
##########################
|
||||
# STEP 2: Request 2 send
|
||||
# There is no eviction because image_A is protected
|
||||
# No new item can add to cache
|
||||
##########################
|
||||
sender_is_cached_item_req2 = p0_cache.is_cached(request2_hashes)
|
||||
assert sender_is_cached_item_req2 == [False, True]
|
||||
|
||||
# Touch all mm hash for P0 Cache before process
|
||||
for mm_hash in request2_hashes:
|
||||
p0_cache.touch_sender_cache_item(mm_hash)
|
||||
|
||||
###########################
|
||||
# Process request 2 for P0 Cache
|
||||
###########################
|
||||
for i, h in enumerate(request2_hashes):
|
||||
# Use precomputed cache state again
|
||||
is_cached = sender_is_cached_item_req2[i]
|
||||
item_tuple = (request2_items[h], []) if not is_cached else None
|
||||
print(f"Request 2: key={h} | cached={is_cached}")
|
||||
|
||||
p0_result = p0_cache.get_and_update_item(item_tuple, h)
|
||||
# Only get mm item, ignore prompt update result
|
||||
request2_items_p0_result.append(p0_result[0])
|
||||
|
||||
# image_A cannot be evict then
|
||||
# image_G will fail to allocate anyway and image_A still in cache
|
||||
assert p0_cache.is_cached(request2_hashes) == [False, True]
|
||||
|
||||
###########################
|
||||
# Process request 2 for P1 Cache
|
||||
###########################
|
||||
|
||||
# Touch all mm hash for P1 Cache before process
|
||||
for mm_hash, mm_item in zip(request2_hashes, request2_items_p0_result):
|
||||
p1_cache.touch_receiver_cache_item(mm_hash, mm_item)
|
||||
|
||||
for mm_hash, mm_item in zip(request2_hashes, request2_items_p0_result):
|
||||
p1_cache.get_and_update_item(mm_item, mm_hash)
|
||||
|
||||
# Prove that cache state is unchanged
|
||||
expected_hashes = ["image_A", "image_B", "image_C"]
|
||||
assert list(p0_cache._shm_cache.key_index.keys()) == expected_hashes
|
||||
|
||||
##########################
|
||||
# STEP 3: Request 3 send
|
||||
##########################
|
||||
##### Prove that cache eviction work normally
|
||||
sender_is_cached_item_req3 = p0_cache.is_cached(request3_hashes)
|
||||
assert sender_is_cached_item_req3 == [False, False, False, True]
|
||||
|
||||
# Touch all mm hash for P0 Cache before process
|
||||
for mm_hash in request3_hashes:
|
||||
p0_cache.touch_sender_cache_item(mm_hash)
|
||||
|
||||
###########################
|
||||
# Process request 3 for P0 Cache
|
||||
###########################
|
||||
for i, h in enumerate(request3_hashes):
|
||||
# Use precomputed cache state again
|
||||
is_cached = sender_is_cached_item_req3[i]
|
||||
item_tuple = (request3_items[h], []) if not is_cached else None
|
||||
print(f"Request 3: key={h} | cached={is_cached}")
|
||||
p0_result = p0_cache.get_and_update_item(item_tuple, h)
|
||||
# Only get mm item, ignore prompt update result
|
||||
request3_items_p0_result.append(p0_result[0])
|
||||
|
||||
# image_A got evict and image_G add to cache
|
||||
# image_B is still protected
|
||||
# image_G, image_H fit but image_I cannot fit
|
||||
assert p0_cache.is_cached(request3_hashes) == [True, True, False, True]
|
||||
|
||||
###########################
|
||||
# Process request 3 for P1 Cache
|
||||
###########################
|
||||
|
||||
# Touch all mm hash for P1 Cache before process
|
||||
for mm_hash, mm_item in zip(request3_hashes, request3_items_p0_result):
|
||||
p1_cache.touch_receiver_cache_item(mm_hash, mm_item)
|
||||
|
||||
for mm_hash, mm_item in zip(request3_hashes, request3_items_p0_result):
|
||||
p1_cache.get_and_update_item(mm_item, mm_hash)
|
||||
|
||||
expected_hashes = ["image_B", "image_C", "image_G", "image_H"]
|
||||
assert list(p0_cache._shm_cache.key_index.keys()) == expected_hashes
|
||||
|
||||
|
||||
def test_cache_eviction_shm_cache():
|
||||
vllm_config = VllmConfig(
|
||||
model_config=ModelConfig(
|
||||
model="llava-hf/llava-onevision-qwen2-0.5b-ov-hf",
|
||||
mm_processor_cache_type="shm",
|
||||
mm_shm_cache_max_object_size_mb=6,
|
||||
mm_processor_cache_gb=15.2 * MiB_bytes / GiB_bytes,
|
||||
),
|
||||
)
|
||||
sender_cache = ShmObjectStoreSenderCache(vllm_config)
|
||||
receiver_cache = ShmObjectStoreReceiverCache(vllm_config, mp.Lock())
|
||||
|
||||
_run_test_cache_eviction_shm(sender_cache, receiver_cache, base_item_size=MiB_bytes)
|
||||
95
tests/multimodal/test_hasher.py
Normal file
95
tests/multimodal/test_hasher.py
Normal file
@@ -0,0 +1,95 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import uuid
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
import torch
|
||||
from PIL import Image, ImageDraw
|
||||
|
||||
from vllm.multimodal.hasher import MultiModalHasher
|
||||
|
||||
pytestmark = pytest.mark.cpu_test
|
||||
|
||||
ASSETS_DIR = Path(__file__).parent / "assets"
|
||||
assert ASSETS_DIR.exists()
|
||||
|
||||
|
||||
# NOTE: Images that are the same visually are allowed to have the same hash
|
||||
@pytest.mark.parametrize("mode_pair", [("1", "L"), ("RGBA", "CMYK")])
|
||||
def test_hash_collision_image_mode(mode_pair):
|
||||
mode1, mode2 = mode_pair
|
||||
image1 = Image.new(mode1, size=(10, 10), color=1)
|
||||
image2 = Image.new(mode2, size=(10, 10), color=1)
|
||||
|
||||
hasher = MultiModalHasher
|
||||
assert hasher.hash_kwargs(image=image1) != hasher.hash_kwargs(image=image2)
|
||||
|
||||
|
||||
def test_hash_collision_image_palette():
|
||||
# These images differ only in Image.palette._palette
|
||||
image1 = Image.open(ASSETS_DIR / "image1.png")
|
||||
image2 = Image.open(ASSETS_DIR / "image2.png")
|
||||
|
||||
hasher = MultiModalHasher
|
||||
assert hasher.hash_kwargs(image=image1) != hasher.hash_kwargs(image=image2)
|
||||
|
||||
|
||||
def test_hash_collision_image_transpose():
|
||||
image1 = Image.new("1", size=(10, 20))
|
||||
ImageDraw.Draw(image1).line([(0, 0), (10, 0)])
|
||||
|
||||
image2 = Image.new("1", size=(20, 10))
|
||||
ImageDraw.Draw(image2).line([(0, 0), (0, 10)])
|
||||
|
||||
hasher = MultiModalHasher
|
||||
assert hasher.hash_kwargs(image=image1) != hasher.hash_kwargs(image=image2)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("dtype", [torch.float32, torch.bfloat16])
|
||||
def test_hash_collision_tensor_shape(dtype):
|
||||
# The hash should be different though the data is the same when flattened
|
||||
arr1 = torch.zeros((5, 10, 20, 3), dtype=dtype)
|
||||
arr2 = torch.zeros((10, 20, 5, 3), dtype=dtype)
|
||||
|
||||
hasher = MultiModalHasher
|
||||
assert hasher.hash_kwargs(data=arr1) != hasher.hash_kwargs(data=arr2)
|
||||
|
||||
|
||||
def test_hash_collision_array_shape():
|
||||
# The hash should be different though the data is the same when flattened
|
||||
arr1 = np.zeros((5, 10, 20, 3))
|
||||
arr2 = np.zeros((10, 20, 5, 3))
|
||||
|
||||
hasher = MultiModalHasher
|
||||
assert hasher.hash_kwargs(data=arr1) != hasher.hash_kwargs(data=arr2)
|
||||
|
||||
|
||||
def test_hash_non_contiguous_array():
|
||||
arr = np.arange(24).reshape(4, 6).T
|
||||
assert not arr.flags.c_contiguous
|
||||
|
||||
arr_c = np.ascontiguousarray(arr)
|
||||
assert arr_c.flags.c_contiguous
|
||||
|
||||
hasher = MultiModalHasher
|
||||
# Both should be hashable and produce the same hashes
|
||||
assert hasher.hash_kwargs(data=arr) == hasher.hash_kwargs(data=arr_c)
|
||||
|
||||
|
||||
def test_hash_image_exif_id():
|
||||
# Test that EXIF ImageId tag can be used to store UUID
|
||||
# and the hasher will use that instead of the image data.
|
||||
image1 = image2 = Image.new("1", size=(10, 20))
|
||||
id = uuid.uuid4()
|
||||
image1.getexif()[Image.ExifTags.Base.ImageID] = id
|
||||
image2 = Image.open(ASSETS_DIR / "image1.png")
|
||||
image2.getexif()[Image.ExifTags.Base.ImageID] = "Not a UUID"
|
||||
image2a = Image.open(ASSETS_DIR / "image1.png")
|
||||
|
||||
hasher = MultiModalHasher
|
||||
# first image has UUID in ImageID, so it should hash to that UUID
|
||||
assert hasher.hash_kwargs(image=image1) == hasher.hash_kwargs(image=id.bytes)
|
||||
# second image has non-UUID in ImageID, so it should hash to the image data
|
||||
assert hasher.hash_kwargs(image=image2) == hasher.hash_kwargs(image=image2a)
|
||||
159
tests/multimodal/test_image.py
Normal file
159
tests/multimodal/test_image.py
Normal file
@@ -0,0 +1,159 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
from PIL import Image, ImageChops
|
||||
|
||||
from vllm.multimodal.image import ImageMediaIO, convert_image_mode
|
||||
|
||||
pytestmark = pytest.mark.cpu_test
|
||||
|
||||
ASSETS_DIR = Path(__file__).parent / "assets"
|
||||
assert ASSETS_DIR.exists()
|
||||
|
||||
|
||||
def test_rgb_to_rgb():
|
||||
# Start with an RGB image.
|
||||
original_image = Image.open(ASSETS_DIR / "image1.png").convert("RGB")
|
||||
converted_image = convert_image_mode(original_image, "RGB")
|
||||
|
||||
# RGB to RGB should be a no-op.
|
||||
diff = ImageChops.difference(original_image, converted_image)
|
||||
assert diff.getbbox() is None
|
||||
|
||||
|
||||
def test_rgba_to_rgb():
|
||||
original_image = Image.open(ASSETS_DIR / "rgba.png")
|
||||
original_image_numpy = np.array(original_image)
|
||||
|
||||
converted_image = convert_image_mode(original_image, "RGB")
|
||||
converted_image_numpy = np.array(converted_image)
|
||||
|
||||
for i in range(original_image_numpy.shape[0]):
|
||||
for j in range(original_image_numpy.shape[1]):
|
||||
# Verify that all transparent pixels are converted to white.
|
||||
if original_image_numpy[i][j][3] == 0:
|
||||
assert converted_image_numpy[i][j][0] == 255
|
||||
assert converted_image_numpy[i][j][1] == 255
|
||||
assert converted_image_numpy[i][j][2] == 255
|
||||
|
||||
|
||||
def test_rgba_to_rgb_custom_background(tmp_path):
    """Test RGBA to RGB conversion with custom background colors."""
    # Build a 10x10 solid-red RGBA image whose top-left 5x5 quadrant is
    # fully transparent, then persist it so both the load_file and
    # load_bytes code paths can be exercised.
    rgba_image = Image.new("RGBA", (10, 10), (255, 0, 0, 255))
    for x in range(5):
        for y in range(5):
            rgba_image.putpixel((x, y), (0, 0, 0, 0))  # Fully transparent

    test_image_path = tmp_path / "test_rgba.png"
    rgba_image.save(test_image_path)

    def check(pixels, background_rgb):
        # Transparent quadrant takes the configured background color...
        assert tuple(int(c) for c in pixels[0][0][:3]) == background_rgb
        # ...while opaque pixels stay red.
        assert tuple(int(c) for c in pixels[5][5][:3]) == (255, 0, 0)

    # Test 1: default white background (backward compatibility).
    check(np.array(ImageMediaIO().load_file(test_image_path)), (255, 255, 255))

    # Test 2: custom black background via kwargs.
    check(
        np.array(
            ImageMediaIO(rgba_background_color=(0, 0, 0)).load_file(test_image_path)
        ),
        (0, 0, 0),
    )

    # Test 3: custom blue background supplied as a list.
    blue_pixels = np.array(
        ImageMediaIO(rgba_background_color=[0, 0, 255]).load_file(test_image_path)
    )
    assert tuple(int(c) for c in blue_pixels[0][0][:3]) == (0, 0, 255)

    # Test 4: the load_bytes path honors the background color too.
    image_data = test_image_path.read_bytes()
    green_io = ImageMediaIO(rgba_background_color=(0, 255, 0))
    green_pixels = np.array(green_io.load_bytes(image_data))
    assert tuple(int(c) for c in green_pixels[0][0][:3]) == (0, 255, 0)
|
||||
|
||||
|
||||
def test_rgba_background_color_validation():
    """Test that invalid rgba_background_color values are properly rejected."""
    invalid_values = [
        "255,255,255",          # wrong type: string
        255,                    # wrong type: bare int
        (255, 255),             # too few elements
        (255, 255, 255, 255),   # too many elements
        (255.0, 255.0, 255.0),  # non-integer components
        (255, "255", 255),      # mixed component types
        (256, 255, 255),        # component above 255
        (255, -1, 255),         # negative component
    ]
    for value in invalid_values:
        with pytest.raises(
            ValueError, match="rgba_background_color must be a list or tuple"
        ):
            ImageMediaIO(rgba_background_color=value)

    # Well-formed values must be accepted without raising.
    ImageMediaIO(rgba_background_color=(0, 0, 0))
    ImageMediaIO(rgba_background_color=[255, 255, 255])
    ImageMediaIO(rgba_background_color=(128, 128, 128))
|
||||
1087
tests/multimodal/test_processing.py
Normal file
1087
tests/multimodal/test_processing.py
Normal file
File diff suppressed because it is too large
Load Diff
34
tests/multimodal/test_registry.py
Normal file
34
tests/multimodal/test_registry.py
Normal file
@@ -0,0 +1,34 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""
|
||||
Unit tests for MultiModalRegistry.supports_multimodal_inputs and
|
||||
Qwen2.5-VL visual component loading behavior.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
|
||||
from vllm.multimodal import MULTIMODAL_REGISTRY
|
||||
|
||||
from ..models.utils import build_model_context
|
||||
|
||||
pytestmark = pytest.mark.cpu_test
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    "model_id,limit_mm_per_prompt,expected",
    [
        ("Qwen/Qwen2-0.5B-Instruct", {}, False),
        ("Qwen/Qwen2.5-VL-3B-Instruct", {}, True),
        ("Qwen/Qwen2.5-VL-3B-Instruct", {"image": 0, "video": 0}, False),
        ("Qwen/Qwen2.5-VL-3B-Instruct", {"image": 0}, True),
    ],
)
@pytest.mark.core_model
def test_supports_multimodal_inputs(model_id, limit_mm_per_prompt, expected):
    """supports_multimodal_inputs must reflect the per-prompt modality limits:
    False for a text-only model, True for a VL model unless every modality is
    capped at zero."""
    ctx = build_model_context(model_id, limit_mm_per_prompt=limit_mm_per_prompt)
    result = MULTIMODAL_REGISTRY.supports_multimodal_inputs(ctx.model_config)
    assert result is expected
|
||||
134
tests/multimodal/test_sparse_tensor_validation_unit.py
Normal file
134
tests/multimodal/test_sparse_tensor_validation_unit.py
Normal file
@@ -0,0 +1,134 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""
|
||||
Unit tests for sparse tensor validation.
|
||||
|
||||
Simple, fast unit tests that can run without server fixtures.
|
||||
Run with: pytest tests/multimodal/test_sparse_tensor_validation_unit.py -v
|
||||
"""
|
||||
|
||||
import io
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
|
||||
class TestSparseTensorValidationContextManager:
|
||||
"""Test that torch.sparse.check_sparse_tensor_invariants() works as expected."""
|
||||
|
||||
def test_valid_sparse_tensor_passes(self):
|
||||
"""Valid sparse tensors should pass validation."""
|
||||
indices = torch.tensor([[0, 1], [0, 1]])
|
||||
values = torch.tensor([1.0, 2.0])
|
||||
shape = (2, 2)
|
||||
|
||||
with torch.sparse.check_sparse_tensor_invariants():
|
||||
tensor = torch.sparse_coo_tensor(indices, values, shape)
|
||||
dense = tensor.to_dense()
|
||||
|
||||
assert dense.shape == shape
|
||||
|
||||
def test_out_of_bounds_indices_rejected(self):
|
||||
"""Sparse tensors with out-of-bounds indices should be rejected."""
|
||||
indices = torch.tensor([[5], [5]]) # Out of bounds for 2x2
|
||||
values = torch.tensor([1.0])
|
||||
shape = (2, 2)
|
||||
|
||||
with pytest.raises(RuntimeError) as exc_info: # noqa: SIM117
|
||||
with torch.sparse.check_sparse_tensor_invariants():
|
||||
tensor = torch.sparse_coo_tensor(indices, values, shape)
|
||||
tensor.to_dense()
|
||||
|
||||
assert (
|
||||
"index" in str(exc_info.value).lower()
|
||||
or "bound" in str(exc_info.value).lower()
|
||||
)
|
||||
|
||||
def test_negative_indices_rejected(self):
|
||||
"""Sparse tensors with negative indices should be rejected."""
|
||||
indices = torch.tensor([[-1], [0]])
|
||||
values = torch.tensor([1.0])
|
||||
shape = (2, 2)
|
||||
|
||||
with pytest.raises(RuntimeError): # noqa: SIM117
|
||||
with torch.sparse.check_sparse_tensor_invariants():
|
||||
tensor = torch.sparse_coo_tensor(indices, values, shape)
|
||||
tensor.to_dense()
|
||||
|
||||
def test_without_context_manager_allows_invalid(self):
|
||||
"""
|
||||
WITHOUT validation, invalid tensors may not immediately error.
|
||||
|
||||
This demonstrates the vulnerability: PyTorch 2.8.0+ doesn't validate
|
||||
by default, which can lead to memory corruption.
|
||||
"""
|
||||
indices = torch.tensor([[100], [100]]) # Way out of bounds
|
||||
values = torch.tensor([1.0])
|
||||
shape = (2, 2)
|
||||
|
||||
# Without validation context, this might create an invalid tensor
|
||||
# (actual behavior depends on PyTorch version)
|
||||
tensor = torch.sparse_coo_tensor(indices, values, shape)
|
||||
|
||||
# The tensor object is created, but it's invalid
|
||||
assert tensor.is_sparse
|
||||
|
||||
|
||||
class TestTorchLoadWithValidation:
|
||||
"""Test torch.load() with sparse tensor validation."""
|
||||
|
||||
def test_load_valid_sparse_tensor_with_validation(self):
|
||||
"""Valid sparse tensors should load successfully with validation."""
|
||||
# Create and save a valid sparse tensor
|
||||
indices = torch.tensor([[0, 1], [0, 1]])
|
||||
values = torch.tensor([1.0, 2.0])
|
||||
tensor = torch.sparse_coo_tensor(indices, values, (2, 2))
|
||||
|
||||
buffer = io.BytesIO()
|
||||
torch.save(tensor, buffer)
|
||||
buffer.seek(0)
|
||||
|
||||
# Load with validation
|
||||
with torch.sparse.check_sparse_tensor_invariants():
|
||||
loaded = torch.load(buffer, weights_only=True)
|
||||
dense = loaded.to_dense()
|
||||
|
||||
assert dense.shape == (2, 2)
|
||||
|
||||
def test_load_invalid_sparse_tensor_rejected(self):
|
||||
"""Invalid sparse tensors should be caught when loaded with validation."""
|
||||
# Create an invalid sparse tensor (out of bounds)
|
||||
indices = torch.tensor([[10], [10]])
|
||||
values = torch.tensor([1.0])
|
||||
tensor = torch.sparse_coo_tensor(indices, values, (2, 2))
|
||||
|
||||
buffer = io.BytesIO()
|
||||
torch.save(tensor, buffer)
|
||||
buffer.seek(0)
|
||||
|
||||
# Load with validation - should fail on to_dense()
|
||||
with pytest.raises(RuntimeError): # noqa: SIM117
|
||||
with torch.sparse.check_sparse_tensor_invariants():
|
||||
loaded = torch.load(buffer, weights_only=True)
|
||||
loaded.to_dense()
|
||||
|
||||
def test_load_dense_tensor_unaffected(self):
|
||||
"""Dense tensors should work normally with the validation context."""
|
||||
# Create and save a dense tensor
|
||||
tensor = torch.randn(10, 20)
|
||||
|
||||
buffer = io.BytesIO()
|
||||
torch.save(tensor, buffer)
|
||||
buffer.seek(0)
|
||||
|
||||
# Load with validation (should have no effect on dense tensors)
|
||||
with torch.sparse.check_sparse_tensor_invariants():
|
||||
loaded = torch.load(buffer, weights_only=True)
|
||||
|
||||
assert loaded.shape == (10, 20)
|
||||
assert not loaded.is_sparse
|
||||
|
||||
|
||||
# Entry point for running this module directly (outside a pytest session).
if __name__ == "__main__":
    # Allow running directly for quick testing
    pytest.main([__file__, "-v", "--tb=short"])
|
||||
531
tests/multimodal/test_utils.py
Normal file
531
tests/multimodal/test_utils.py
Normal file
@@ -0,0 +1,531 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import asyncio
|
||||
import base64
|
||||
import mimetypes
|
||||
import os
|
||||
from tempfile import NamedTemporaryFile, TemporaryDirectory
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
import torch
|
||||
from PIL import Image, ImageChops
|
||||
|
||||
from vllm.multimodal.image import convert_image_mode
|
||||
from vllm.multimodal.inputs import PlaceholderRange
|
||||
from vllm.multimodal.utils import MediaConnector, argsort_mm_positions
|
||||
|
||||
# Test different image extensions (JPG/PNG) and formats (gray/RGB/RGBA)
|
||||
TEST_IMAGE_ASSETS = [
|
||||
"2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg", # "https://vllm-public-assets.s3.us-west-2.amazonaws.com/vision_model_images/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg"
|
||||
"Grayscale_8bits_palette_sample_image.png", # "https://vllm-public-assets.s3.us-west-2.amazonaws.com/vision_model_images/Grayscale_8bits_palette_sample_image.png",
|
||||
"1280px-Venn_diagram_rgb.svg.png", # "https://vllm-public-assets.s3.us-west-2.amazonaws.com/vision_model_images/1280px-Venn_diagram_rgb.svg.png",
|
||||
"RGBA_comp.png", # "https://vllm-public-assets.s3.us-west-2.amazonaws.com/vision_model_images/RGBA_comp.png",
|
||||
]
|
||||
|
||||
TEST_VIDEO_URLS = [
|
||||
"https://www.bogotobogo.com/python/OpenCV_Python/images/mean_shift_tracking/slow_traffic_small.mp4",
|
||||
"https://github.com/opencv/opencv/raw/refs/tags/4.12.0/samples/data/vtest.avi",
|
||||
]
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
def url_images(local_asset_server) -> dict[str, Image.Image]:
    """Map each test asset name to the image served by the local asset server."""
    images: dict[str, Image.Image] = {}
    for asset_name in TEST_IMAGE_ASSETS:
        images[asset_name] = local_asset_server.get_image_asset(asset_name)
    return images
|
||||
|
||||
|
||||
def get_supported_suffixes() -> tuple[str, ...]:
    """Return every image file suffix the fetch tests should cover."""
    # We should at least test the file types mentioned in GPT-4 with Vision,
    # plus the additional formats supported by us.
    openai_suffixes = (".png", ".jpeg", ".jpg", ".webp", ".gif")
    extra_suffixes = (".bmp", ".tiff")
    return openai_suffixes + extra_suffixes
|
||||
|
||||
|
||||
def _image_equals(a: Image.Image, b: Image.Image) -> bool:
    # Pixel-wise equality after converting b into a's color mode, so images
    # saved in different modes can still be compared.
    return (np.asarray(a) == np.asarray(convert_image_mode(b, a.mode))).all()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
@pytest.mark.parametrize("image_url", TEST_IMAGE_ASSETS, indirect=True)
async def test_fetch_image_http(image_url: str):
    """Sync and async HTTP image fetches must yield identical pixels."""
    connector = MediaConnector()

    sync_image = connector.fetch_image(image_url)
    async_image = await connector.fetch_image_async(image_url)
    assert _image_equals(sync_image, async_image)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
@pytest.mark.parametrize("raw_image_url", TEST_IMAGE_ASSETS)
@pytest.mark.parametrize("suffix", get_supported_suffixes())
async def test_fetch_image_base64(
    url_images: dict[str, Image.Image], raw_image_url: str, suffix: str
):
    """Round-trip each asset through a base64 data URL for every supported
    suffix, checking that sync and async fetches agree."""
    connector = MediaConnector(
        # Domain restriction should not apply to data URLs.
        allowed_media_domains=[
            "www.bogotobogo.com",
            "github.com",
        ]
    )
    url_image = url_images[raw_image_url]

    # Resolve a MIME type for this suffix, preferring PIL's registry and
    # falling back to the stdlib mimetypes table; skip unknown formats.
    try:
        mime_type = Image.MIME[Image.registered_extensions()[suffix]]
    except KeyError:
        try:
            mime_type = mimetypes.types_map[suffix]
        except KeyError:
            pytest.skip("No MIME type")

    with NamedTemporaryFile(suffix=suffix) as f:
        try:
            url_image.save(f.name)
        except Exception as e:
            # RGBA sources cannot be written as JPEG; skip that combination
            # but re-raise any other save failure.
            if e.args[0] == "cannot write mode RGBA as JPEG":
                pytest.skip("Conversion not supported")

            raise

        base64_image = base64.b64encode(f.read()).decode("utf-8")
        data_url = f"data:{mime_type};base64,{base64_image}"

        data_image_sync = connector.fetch_image(data_url)
        # Only lossless round-trips can be compared pixel-for-pixel.
        if _image_equals(url_image, Image.open(f)):
            assert _image_equals(url_image, data_image_sync)
        else:
            pass  # Lossy format; only check that image can be opened

        data_image_async = await connector.fetch_image_async(data_url)
        assert _image_equals(data_image_sync, data_image_async)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
@pytest.mark.parametrize("image_url", TEST_IMAGE_ASSETS, indirect=True)
async def test_fetch_image_local_files(image_url: str):
    """file:// fetches work inside the allowed path and are rejected outside."""
    connector = MediaConnector()

    with TemporaryDirectory() as temp_dir:
        local_connector = MediaConnector(allowed_local_media_path=temp_dir)

        # Save a copy of the asset inside the allowed directory.
        origin_image = connector.fetch_image(image_url)
        basename = os.path.basename(image_url)
        origin_image.save(
            os.path.join(temp_dir, basename),
            quality=100,
            icc_profile=origin_image.info.get("icc_profile"),
        )

        allowed_url = f"file://{temp_dir}/{basename}"
        escaping_url = f"file://{temp_dir}/../{basename}"

        image_async = await local_connector.fetch_image_async(allowed_url)
        image_sync = local_connector.fetch_image(allowed_url)
        # Both code paths must produce identical pixels.
        assert not ImageChops.difference(image_sync, image_async).getbbox()

        # Path traversal out of the allowed directory is rejected...
        with pytest.raises(ValueError, match="must be a subpath"):
            await local_connector.fetch_image_async(escaping_url)
        # ...and a connector without an allowed path rejects file:// outright.
        with pytest.raises(RuntimeError, match="Cannot load local files"):
            await connector.fetch_image_async(escaping_url)

        with pytest.raises(ValueError, match="must be a subpath"):
            local_connector.fetch_image(escaping_url)
        with pytest.raises(RuntimeError, match="Cannot load local files"):
            connector.fetch_image(escaping_url)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
@pytest.mark.parametrize("image_url", [TEST_IMAGE_ASSETS[0]], indirect=True)
async def test_fetch_image_local_files_with_space_in_name(image_url: str):
    """A file:// URL whose basename contains spaces must still be fetchable."""
    connector = MediaConnector()

    with TemporaryDirectory() as temp_dir:
        local_connector = MediaConnector(allowed_local_media_path=temp_dir)

        origin_image = connector.fetch_image(image_url)
        filename = "file name with space.jpg"
        origin_image.save(
            os.path.join(temp_dir, filename),
            quality=100,
            icc_profile=origin_image.info.get("icc_profile"),
        )

        # BUG FIX: the fetch URLs previously did not reference the file that
        # was just saved; they must point at `filename` (the name containing
        # spaces), which was otherwise unused.
        file_url = f"file://{temp_dir}/{filename}"
        try:
            image_async = await local_connector.fetch_image_async(file_url)
            image_sync = local_connector.fetch_image(file_url)
        except FileNotFoundError as e:
            pytest.fail("Failed to fetch image with space in name: {}".format(e))
        # Check that the images are equal
        assert not ImageChops.difference(image_sync, image_async).getbbox()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_fetch_image_error_conversion():
    """Undecodable image payloads must surface as ValueError, sync and async."""
    connector = MediaConnector()
    # Valid base64, but the decoded bytes are plain text, not an image.
    broken_img = "data:image/png;base64,aGVsbG9fdmxsbV9jb21tdW5pdHkK"

    # PIL.UnidentifiedImageError should be converted to ValueError
    with pytest.raises(ValueError):
        await connector.fetch_image_async(broken_img)

    with pytest.raises(ValueError):
        connector.fetch_image(broken_img)
|
||||
|
||||
|
||||
@pytest.mark.flaky(reruns=3, reruns_delay=5)
@pytest.mark.asyncio
@pytest.mark.parametrize("video_url", TEST_VIDEO_URLS)
@pytest.mark.parametrize("num_frames", [-1, 32, 1800])
async def test_fetch_video_http(video_url: str, num_frames: int):
    """Sync and async HTTP video fetches must agree on frames and metadata."""
    connector = MediaConnector(
        media_io_kwargs={"video": {"num_frames": num_frames}}
    )

    try:
        video_sync, metadata_sync = connector.fetch_video(video_url)
        video_async, metadata_async = await connector.fetch_video_async(video_url)
    except (TimeoutError, asyncio.TimeoutError) as e:
        # Network hiccups in CI should not fail the suite.
        pytest.skip(f"Timeout fetching video (CI network flakiness): {e}")

    assert np.array_equal(video_sync, video_async)
    assert metadata_sync == metadata_async
|
||||
|
||||
|
||||
@pytest.mark.asyncio
@pytest.mark.parametrize("video_url", TEST_VIDEO_URLS)
@pytest.mark.parametrize("max_duration", [1, 60, 1800])
@pytest.mark.parametrize("requested_fps", [2, 24])
async def test_fetch_video_http_with_dynamic_loader(
    video_url: str,
    max_duration: int,
    requested_fps: int,
    monkeypatch: pytest.MonkeyPatch,
):
    """The opencv_dynamic backend must honor duration/fps kwargs and report
    itself in the returned metadata, with sync/async paths agreeing."""
    with monkeypatch.context() as m:
        # Select the dynamic loader backend for the duration of this test.
        m.setenv("VLLM_VIDEO_LOADER_BACKEND", "opencv_dynamic")
        connector = MediaConnector(
            media_io_kwargs={
                "video": {
                    "max_duration": max_duration,
                    "requested_fps": requested_fps,
                }
            }
        )

        video_sync, metadata_sync = connector.fetch_video(video_url)
        video_async, metadata_async = await connector.fetch_video_async(video_url)

        assert np.array_equal(video_sync, video_async)
        assert metadata_sync == metadata_async
        assert metadata_sync["video_backend"] == "opencv_dynamic"
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    "case",
    [
        # Single modality
        ## Internally sorted
        dict(
            mm_positions={
                "image": [
                    PlaceholderRange(offset=0, length=2),
                    PlaceholderRange(offset=3, length=2),
                ]
            },
            expected_modality_idxs=[
                ("image", 0),
                ("image", 1),
            ],
        ),
        ## Internally unsorted
        dict(
            mm_positions={
                "image": [
                    PlaceholderRange(offset=3, length=2),
                    PlaceholderRange(offset=0, length=2),
                ]
            },
            expected_modality_idxs=[
                ("image", 1),
                ("image", 0),
            ],
        ),
        # Two modalities
        ## Internally sorted
        dict(
            mm_positions={
                "image": [
                    PlaceholderRange(offset=7, length=4),
                    PlaceholderRange(offset=11, length=5),
                ],
                "audio": [
                    PlaceholderRange(offset=0, length=2),
                    PlaceholderRange(offset=2, length=3),
                ],
            },
            expected_modality_idxs=[
                ("audio", 0),
                ("audio", 1),
                ("image", 0),
                ("image", 1),
            ],
        ),
        ## Interleaved, internally sorted
        dict(
            mm_positions={
                "image": [
                    PlaceholderRange(offset=0, length=4),
                    PlaceholderRange(offset=8, length=2),
                ],
                "audio": [
                    PlaceholderRange(offset=5, length=2),
                    PlaceholderRange(offset=11, length=4),
                ],
            },
            expected_modality_idxs=[
                ("image", 0),
                ("audio", 0),
                ("image", 1),
                ("audio", 1),
            ],
        ),
        ## Interleaved, internally unsorted
        dict(
            mm_positions={
                "image": [
                    PlaceholderRange(offset=8, length=2),
                    PlaceholderRange(offset=0, length=4),
                ],
                "audio": [
                    PlaceholderRange(offset=11, length=4),
                    PlaceholderRange(offset=5, length=2),
                ],
            },
            expected_modality_idxs=[
                ("image", 1),
                ("audio", 1),
                ("image", 0),
                ("audio", 0),
            ],
        ),
        # Three modalities
        ## Internally sorted
        dict(
            mm_positions={
                "image": [
                    PlaceholderRange(offset=15, length=7),
                    PlaceholderRange(offset=22, length=8),
                ],
                "audio": [
                    PlaceholderRange(offset=0, length=2),
                ],
                "video": [
                    PlaceholderRange(offset=3, length=4),
                    PlaceholderRange(offset=7, length=5),
                    PlaceholderRange(offset=12, length=6),
                ],
            },
            expected_modality_idxs=[
                ("audio", 0),
                ("video", 0),
                ("video", 1),
                ("video", 2),
                ("image", 0),
                ("image", 1),
            ],
        ),
        ## Interleaved, internally sorted
        dict(
            mm_positions={
                "image": [
                    PlaceholderRange(offset=0, length=2),
                    PlaceholderRange(offset=2, length=3),
                    PlaceholderRange(offset=20, length=4),
                ],
                "audio": [
                    PlaceholderRange(offset=5, length=2),
                ],
                "video": [
                    PlaceholderRange(offset=8, length=5),
                ],
            },
            expected_modality_idxs=[
                ("image", 0),
                ("image", 1),
                ("audio", 0),
                ("video", 0),
                ("image", 2),
            ],
        ),
        ## Interleaved, internally unsorted
        dict(
            mm_positions={
                "image": [
                    PlaceholderRange(offset=0, length=2),
                    PlaceholderRange(offset=20, length=4),
                    PlaceholderRange(offset=2, length=3),
                ],
                "audio": [
                    PlaceholderRange(offset=5, length=2),
                ],
                "video": [
                    PlaceholderRange(offset=8, length=5),
                ],
            },
            expected_modality_idxs=[
                ("image", 0),
                ("image", 2),
                ("audio", 0),
                ("video", 0),
                ("image", 1),
            ],
        ),
    ],
)
def test_argsort_mm_positions(case):
    """argsort_mm_positions must return (modality, index) pairs ordered by
    each placeholder's offset in the prompt, across all modalities."""
    mm_positions = case["mm_positions"]
    expected_modality_idxs = case["expected_modality_idxs"]

    modality_idxs = argsort_mm_positions(mm_positions)

    assert modality_idxs == expected_modality_idxs
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    "is_embed,expected",
    [
        (None, 5),
        (torch.tensor([True, True, True, True, True]), 5),
        (torch.tensor([False, False, False, False, False]), 0),
        (torch.tensor([True, False, True, False, True]), 3),
        (torch.tensor([True]), 1),
    ],
)
def test_placeholder_range_get_num_embeds(is_embed, expected):
    """get_num_embeds counts the True entries of is_embed, or the full
    length when no mask is given."""
    length = 5 if is_embed is None else len(is_embed)
    pr = PlaceholderRange(offset=0, length=length, is_embed=is_embed)
    assert pr.get_num_embeds == expected
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    "is_embed,expected",
    [
        (None, None),
        (
            torch.tensor([False, True, False, True, True]),
            torch.tensor([0, 1, 1, 2, 3]),
        ),
        (torch.tensor([True, True, True]), torch.tensor([1, 2, 3])),
    ],
)
def test_placeholder_range_embeds_cumsum(is_embed, expected):
    """embeds_cumsum is the running count of True mask entries, or None
    when there is no mask."""
    length = 5 if is_embed is None else len(is_embed)
    pr = PlaceholderRange(offset=0, length=length, is_embed=is_embed)

    if expected is None:
        assert pr.embeds_cumsum is None
        return

    assert torch.equal(pr.embeds_cumsum, expected)
    # cached_property should return the same object on repeated access
    assert pr.embeds_cumsum is pr.embeds_cumsum
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    "is_embed,start_idx,end_idx,expected",
    [
        (None, 2, 4, (2, 4)),
        (
            torch.tensor([False, True, False, True, True]),
            3,
            5,
            (1, 3),
        ),
        (
            torch.tensor([False, True, False, True, True]),
            0,
            2,
            (0, 1),
        ),
        (
            torch.tensor([True, False, True, False]),
            2,
            2,
            (1, 1),
        ),
    ],
)
def test_placeholder_range_get_embeds_indices_in_range(
    is_embed, start_idx, end_idx, expected
):
    """get_embeds_indices_in_range maps a token-index window onto the
    corresponding window of embedding indices."""
    length = 5 if is_embed is None else len(is_embed)
    pr = PlaceholderRange(offset=0, length=length, is_embed=is_embed)
    assert pr.get_embeds_indices_in_range(start_idx, end_idx) == expected
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    "offset,is_embed,expected",
    [
        (0, None, [(0, 4)]),
        (
            2,
            torch.tensor([False, True, False, True, True]),
            [(3, 3), (5, 6)],
        ),
        (0, torch.tensor([True, True, True, True]), [(0, 3)]),
        (0, torch.tensor([False, False, False, False]), []),
    ],
)
def test_placeholder_range_extract_embeds_range(offset, is_embed, expected):
    """extract_embeds_range yields absolute (start, end) runs of embedded
    positions, shifted by the placeholder's offset in the prompt."""
    length = 5 if is_embed is None else len(is_embed)
    pr = PlaceholderRange(offset=offset, length=length, is_embed=is_embed)
    assert pr.extract_embeds_range() == expected
|
||||
|
||||
|
||||
@pytest.mark.asyncio
@pytest.mark.parametrize("video_url", TEST_VIDEO_URLS)
@pytest.mark.parametrize("num_frames", [-1, 32, 1800])
async def test_allowed_media_domains(video_url: str, num_frames: int):
    """URLs from allow-listed domains fetch normally; others raise ValueError."""
    connector = MediaConnector(
        media_io_kwargs={"video": {"num_frames": num_frames}},
        allowed_media_domains=[
            "www.bogotobogo.com",
            "github.com",
        ],
    )

    # Allow-listed hosts fetch fine and sync/async paths agree.
    video_sync, metadata_sync = connector.fetch_video(video_url)
    video_async, metadata_async = await connector.fetch_video_async(video_url)
    assert np.array_equal(video_sync, video_async)
    assert metadata_sync == metadata_async

    # A host outside the allow-list is rejected on both paths.
    disallowed_url = "https://upload.wikimedia.org/wikipedia/commons/4/47/PNG_transparency_demonstration_1.png"
    with pytest.raises(ValueError):
        _, _ = connector.fetch_video(disallowed_url)

    with pytest.raises(ValueError):
        _, _ = await connector.fetch_video_async(disallowed_url)
|
||||
301
tests/multimodal/test_video.py
Normal file
301
tests/multimodal/test_video.py
Normal file
@@ -0,0 +1,301 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
import numpy.typing as npt
|
||||
import pytest
|
||||
from PIL import Image
|
||||
|
||||
from vllm.assets.base import get_vllm_public_assets
|
||||
from vllm.assets.video import video_to_ndarrays, video_to_pil_images_list
|
||||
from vllm.multimodal.image import ImageMediaIO
|
||||
from vllm.multimodal.video import VIDEO_LOADER_REGISTRY, VideoLoader, VideoMediaIO
|
||||
|
||||
from .utils import cosine_similarity, create_video_from_image, normalize_image
|
||||
|
||||
pytestmark = pytest.mark.cpu_test
|
||||
|
||||
ASSETS_DIR = Path(__file__).parent / "assets"
|
||||
NUM_FRAMES = 10
|
||||
FAKE_OUTPUT_1 = np.random.rand(NUM_FRAMES, 1280, 720, 3)
|
||||
FAKE_OUTPUT_2 = np.random.rand(NUM_FRAMES, 1280, 720, 3)
|
||||
|
||||
|
||||
@VIDEO_LOADER_REGISTRY.register("test_video_loader_1")
class TestVideoLoader1(VideoLoader):
    # Stub loader: ignores its input and always returns FAKE_OUTPUT_1,
    # letting registry-dispatch tests identify which loader was invoked.
    @classmethod
    def load_bytes(cls, data: bytes, num_frames: int = -1) -> npt.NDArray:
        return FAKE_OUTPUT_1
|
||||
|
||||
|
||||
@VIDEO_LOADER_REGISTRY.register("test_video_loader_2")
class TestVideoLoader2(VideoLoader):
    # Stub loader: ignores its input and always returns FAKE_OUTPUT_2,
    # distinguishable from TestVideoLoader1 in registry-dispatch tests.
    @classmethod
    def load_bytes(cls, data: bytes, num_frames: int = -1) -> npt.NDArray:
        return FAKE_OUTPUT_2
|
||||
|
||||
|
||||
def test_video_loader_registry():
    """Loaders registered by name must be retrievable and dispatch correctly."""
    for loader_name, expected_output in [
        ("test_video_loader_1", FAKE_OUTPUT_1),
        ("test_video_loader_2", FAKE_OUTPUT_2),
    ]:
        custom_loader = VIDEO_LOADER_REGISTRY.load(loader_name)
        np.testing.assert_array_equal(
            custom_loader.load_bytes(b"test"), expected_output
        )
|
||||
|
||||
|
||||
def test_video_loader_type_doesnt_exist():
    # Looking up an unregistered loader name must fail loudly.
    with pytest.raises(AssertionError):
        VIDEO_LOADER_REGISTRY.load("non_existing_video_loader")
|
||||
|
||||
|
||||
@VIDEO_LOADER_REGISTRY.register("assert_10_frames_1_fps")
class Assert10Frames1FPSVideoLoader(VideoLoader):
    # Stub loader that asserts the kwargs plumbed through VideoMediaIO:
    # exactly num_frames=10 and fps=1.0; anything else fails the assertion.
    @classmethod
    def load_bytes(
        cls, data: bytes, num_frames: int = -1, fps: float = -1.0, **kwargs
    ) -> npt.NDArray:
        assert num_frames == 10, "bad num_frames"
        assert fps == 1.0, "bad fps"
        return FAKE_OUTPUT_2
|
||||
|
||||
|
||||
def test_video_media_io_kwargs(monkeypatch: pytest.MonkeyPatch):
    """VideoMediaIO must forward num_frames/fps kwargs to the active loader."""
    with monkeypatch.context() as m:
        m.setenv("VLLM_VIDEO_LOADER_BACKEND", "assert_10_frames_1_fps")
        imageio = ImageMediaIO()

        # Exact expected kwargs: the loader's assertions pass.
        _ = VideoMediaIO(imageio, num_frames=10, fps=1.0).load_bytes(b"test")

        # Unknown extra kwargs are tolerated alongside the expected ones.
        _ = VideoMediaIO(
            imageio, num_frames=10, fps=1.0, not_used="not_used"
        ).load_bytes(b"test")

        # No kwargs at all: loader sees its defaults and trips on num_frames.
        with pytest.raises(AssertionError, match="bad num_frames"):
            VideoMediaIO(imageio).load_bytes(b"test")

        # Wrong num_frames value.
        with pytest.raises(AssertionError, match="bad num_frames"):
            VideoMediaIO(imageio, num_frames=9, fps=1.0).load_bytes(b"test")

        # Wrong fps value.
        with pytest.raises(AssertionError, match="bad fps"):
            VideoMediaIO(imageio, num_frames=10, fps=2.0).load_bytes(b"test")
|
||||
|
||||
|
||||
def _assert_frames_close_to_image(frames, ref_rgb: npt.NDArray) -> None:
    """Assert each frame is visually near-identical to the reference image.

    Compares per-pixel cosine similarity along the channel axis; a tiny
    fraction of NaNs is tolerated (zero-norm pixels produce NaN similarity).
    """
    for frame in frames:
        sim = cosine_similarity(normalize_image(np.array(frame)), ref_rgb)
        assert np.sum(np.isnan(sim)) / sim.size < 0.001
        assert np.nanmean(sim) > 0.99


@pytest.mark.parametrize("is_color", [True, False])
@pytest.mark.parametrize("fourcc, ext", [("mp4v", "mp4"), ("XVID", "avi")])
def test_opencv_video_io_colorspace(is_color: bool, fourcc: str, ext: str):
    """
    Test all functions that use OpenCV for video I/O return RGB format.

    Both RGB and grayscale videos are tested.
    """
    image_path = get_vllm_public_assets(
        filename="stop_sign.jpg", s3_prefix="vision_model_images"
    )
    image = Image.open(image_path)
    with tempfile.TemporaryDirectory() as tmpdir:
        if not is_color:
            # Re-save the image as grayscale so the video is written from a
            # single-channel source; the comparison reference is its
            # gray-as-RGB conversion.
            image_path = f"{tmpdir}/test_grayscale_image.png"
            image = image.convert("L")
            image.save(image_path)
            image = image.convert("RGB")
        video_path = f"{tmpdir}/test_RGB_video.{ext}"
        create_video_from_image(
            image_path,
            video_path,
            num_frames=2,
            is_color=is_color,
            fourcc=fourcc,
        )

        # Normalize the reference once; all three decode paths compare to it.
        ref_rgb = normalize_image(np.array(image))

        _assert_frames_close_to_image(video_to_ndarrays(video_path), ref_rgb)
        _assert_frames_close_to_image(
            video_to_pil_images_list(video_path), ref_rgb
        )

        io_frames, _ = VideoMediaIO(ImageMediaIO()).load_file(Path(video_path))
        _assert_frames_close_to_image(io_frames, ref_rgb)
|
||||
|
||||
|
||||
def test_video_backend_handles_broken_frames(monkeypatch: pytest.MonkeyPatch):
    """
    Regression test: the OpenCV backend must handle a video containing
    broken frames (assets/corrupted.mp4) gracefully — no crash, and the
    returned metadata must stay consistent with the frames actually loaded.
    """
    with monkeypatch.context() as m:
        m.setenv("VLLM_VIDEO_LOADER_BACKEND", "opencv")

        # Read the pre-corrupted video that contains broken frames.
        video_data = (ASSETS_DIR / "corrupted.mp4").read_bytes()

        loader = VIDEO_LOADER_REGISTRY.load("opencv")
        frames, metadata = loader.load_bytes(video_data, num_frames=-1)

        num_loaded = frames.shape[0]
        num_indices = len(metadata["frames_indices"])

        # Metadata consistency: exactly one index per loaded frame.
        assert num_loaded == num_indices, (
            f"Frames array size must equal frames_indices length. "
            f"Got {num_loaded} frames but {num_indices} indices"
        )

        # Broken frames were skipped, so fewer frames than the container total.
        assert num_loaded < metadata["total_num_frames"], (
            f"Should load fewer frames than total due to broken frames. "
            f"Expected fewer than {metadata['total_num_frames']} frames, "
            f"but loaded {num_loaded} frames"
        )
|
||||
|
||||
|
||||
@VIDEO_LOADER_REGISTRY.register("test_video_backend_override_1")
class TestVideoBackendOverride1(VideoLoader):
    """Test loader that returns FAKE_OUTPUT_1 to verify backend selection."""

    @classmethod
    def load_bytes(
        cls, data: bytes, num_frames: int = -1, **kwargs
    ) -> tuple[npt.NDArray, dict]:
        # Tag the metadata so tests can tell which backend produced it.
        metadata = {"video_backend": "test_video_backend_override_1"}
        return FAKE_OUTPUT_1, metadata
|
||||
|
||||
|
||||
@VIDEO_LOADER_REGISTRY.register("test_video_backend_override_2")
class TestVideoBackendOverride2(VideoLoader):
    """Test loader that returns FAKE_OUTPUT_2 to verify backend selection."""

    @classmethod
    def load_bytes(
        cls, data: bytes, num_frames: int = -1, **kwargs
    ) -> tuple[npt.NDArray, dict]:
        # Tag the metadata so tests can tell which backend produced it.
        metadata = {"video_backend": "test_video_backend_override_2"}
        return FAKE_OUTPUT_2, metadata
|
||||
|
||||
|
||||
def test_video_media_io_backend_kwarg_override(monkeypatch: pytest.MonkeyPatch):
    """
    The video_backend kwarg must override the VLLM_VIDEO_LOADER_BACKEND
    environment variable, letting a request pick a different backend via
    --media-io-kwargs without touching the global env var (useful when a
    plugin sets a default backend).
    """
    with monkeypatch.context() as m:
        m.setenv("VLLM_VIDEO_LOADER_BACKEND", "test_video_backend_override_1")
        imageio = ImageMediaIO()

        # Default behaviour: the env-var backend is selected.
        frames, metadata = VideoMediaIO(imageio, num_frames=10).load_bytes(
            b"test"
        )
        np.testing.assert_array_equal(frames, FAKE_OUTPUT_1)
        assert metadata["video_backend"] == "test_video_backend_override_1"

        # Explicit kwarg: overrides the env var.
        videoio = VideoMediaIO(
            imageio, num_frames=10, video_backend="test_video_backend_override_2"
        )
        frames, metadata = videoio.load_bytes(b"test")
        np.testing.assert_array_equal(frames, FAKE_OUTPUT_2)
        assert metadata["video_backend"] == "test_video_backend_override_2"
|
||||
|
||||
|
||||
def test_video_media_io_backend_kwarg_not_passed_to_loader(
    monkeypatch: pytest.MonkeyPatch,
):
    """
    The video_backend kwarg must be consumed by VideoMediaIO and never
    forwarded to the loader's load_bytes, while every other kwarg still
    passes through unchanged.
    """

    @VIDEO_LOADER_REGISTRY.register("test_reject_video_backend_kwarg")
    class RejectVideoBackendKwargLoader(VideoLoader):
        """Test loader that fails if video_backend is passed through."""

        @classmethod
        def load_bytes(
            cls, data: bytes, num_frames: int = -1, **kwargs
        ) -> tuple[npt.NDArray, dict]:
            # video_backend must have been popped by VideoMediaIO already.
            if "video_backend" in kwargs:
                raise AssertionError(
                    "video_backend should be consumed by VideoMediaIO, "
                    "not passed to loader"
                )
            return FAKE_OUTPUT_1, {"received_kwargs": list(kwargs.keys())}

    with monkeypatch.context() as m:
        m.setenv("VLLM_VIDEO_LOADER_BACKEND", "test_reject_video_backend_kwarg")

        videoio = VideoMediaIO(
            ImageMediaIO(),
            num_frames=10,
            video_backend="test_reject_video_backend_kwarg",
            other_kwarg="should_pass_through",
        )

        # Must not raise: video_backend was stripped before forwarding.
        frames, metadata = videoio.load_bytes(b"test")
        np.testing.assert_array_equal(frames, FAKE_OUTPUT_1)
        # Unrelated kwargs still reach the loader.
        assert "other_kwarg" in metadata["received_kwargs"]
|
||||
|
||||
|
||||
def test_video_media_io_backend_env_var_fallback(monkeypatch: pytest.MonkeyPatch):
    """
    When video_backend is explicitly None or simply not provided,
    VideoMediaIO must fall back to the VLLM_VIDEO_LOADER_BACKEND env var.
    """
    with monkeypatch.context() as m:
        m.setenv("VLLM_VIDEO_LOADER_BACKEND", "test_video_backend_override_2")
        imageio = ImageMediaIO()

        # Both "explicit None" and "kwarg absent" must behave identically.
        for extra_kwargs in ({"video_backend": None}, {}):
            videoio = VideoMediaIO(imageio, num_frames=10, **extra_kwargs)
            frames, metadata = videoio.load_bytes(b"test")
            np.testing.assert_array_equal(frames, FAKE_OUTPUT_2)
            assert metadata["video_backend"] == "test_video_backend_override_2"
|
||||
78
tests/multimodal/utils.py
Normal file
78
tests/multimodal/utils.py
Normal file
@@ -0,0 +1,78 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
import numpy.typing as npt
|
||||
from PIL import Image
|
||||
|
||||
|
||||
def random_image(rng: np.random.RandomState, min_wh: int, max_wh: int):
    """Return a random RGB PIL image whose side lengths lie in [min_wh, max_wh)."""
    # Draw the two dimensions first, then the pixel data, so the RNG
    # consumption order is fixed.
    dim_a, dim_b = rng.randint(min_wh, max_wh, size=(2,))
    pixels = rng.randint(0, 255, size=(dim_a, dim_b, 3), dtype=np.uint8)
    return Image.fromarray(pixels)
|
||||
|
||||
|
||||
def random_video(
    rng: np.random.RandomState,
    min_frames: int,
    max_frames: int,
    min_wh: int,
    max_wh: int,
):
    """Return a random uint8 video array of shape (frames, dim_a, dim_b, 3).

    Frame count is drawn from [min_frames, max_frames) and each spatial
    dimension from [min_wh, max_wh).
    """
    frame_count = rng.randint(min_frames, max_frames)
    dim_a, dim_b = rng.randint(min_wh, max_wh, size=(2,))
    shape = (frame_count, dim_a, dim_b, 3)
    return rng.randint(0, 255, size=shape, dtype=np.uint8)
|
||||
|
||||
|
||||
def random_audio(
    rng: np.random.RandomState,
    min_len: int,
    max_len: int,
    sr: int,
):
    """Return (waveform, sr): random samples with length in [min_len, max_len)."""
    num_samples = rng.randint(min_len, max_len)
    waveform = rng.rand(num_samples)
    return waveform, sr
|
||||
|
||||
|
||||
def create_video_from_image(
    image_path: str,
    video_path: str,
    num_frames: int = 10,
    fps: float = 1.0,
    is_color: bool = True,
    fourcc: str = "mp4v",
):
    """Write a video at ``video_path`` that repeats one image ``num_frames`` times.

    Args:
        image_path: Path of the source image (read with OpenCV, BGR).
        video_path: Destination path for the encoded video.
        num_frames: How many copies of the frame to write.
        fps: Frame rate passed to the writer.
        is_color: When False, the frame is converted to grayscale and the
            writer is opened with ``isColor=False``.
        fourcc: Four-character codec code (e.g. "mp4v", "XVID").

    Returns:
        ``video_path`` (unchanged), for caller convenience.
    """
    frame = cv2.imread(image_path)
    if not is_color:
        # Single-channel frames for a grayscale (isColor=False) writer.
        frame = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
        height, width = frame.shape
    else:
        height, width, _ = frame.shape

    video_writer = cv2.VideoWriter(
        video_path,
        cv2.VideoWriter_fourcc(*fourcc),
        fps,
        (width, height),
        isColor=is_color,
    )
    try:
        for _ in range(num_frames):
            video_writer.write(frame)
    finally:
        # Always release the writer so the container is finalized (and the
        # OS handle freed) even if a write fails part-way through.
        video_writer.release()
    return video_path
|
||||
|
||||
|
||||
def cosine_similarity(A: npt.NDArray, B: npt.NDArray, axis: int = -1) -> npt.NDArray:
    """Compute cosine similarity between two vectors along ``axis``.

    NOTE: zero-norm inputs divide by zero and yield NaN/inf per numpy rules.
    """
    dot = np.sum(A * B, axis=axis)
    norm_product = np.linalg.norm(A, axis=axis) * np.linalg.norm(B, axis=axis)
    return dot / norm_product
|
||||
|
||||
|
||||
def normalize_image(image: npt.NDArray) -> npt.NDArray:
    """Scale 8-bit pixel values into float32 values in [0, 1]."""
    as_float = np.asarray(image).astype(np.float32)
    return as_float / 255.0
|
||||
Reference in New Issue
Block a user