Sync from v0.13
This commit is contained in:
0
tests/transformers_utils/__init__.py
Normal file
0
tests/transformers_utils/__init__.py
Normal file
32
tests/transformers_utils/test_config.py
Normal file
32
tests/transformers_utils/test_config.py
Normal file
@@ -0,0 +1,32 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""
|
||||
This test file includes some cases where it is inappropriate to
|
||||
only get the `eos_token_id` from the tokenizer as defined by
|
||||
`vllm.LLMEngine._get_eos_token_id`.
|
||||
"""
|
||||
|
||||
from vllm.tokenizers import get_tokenizer
|
||||
from vllm.transformers_utils.config import try_get_generation_config
|
||||
|
||||
|
||||
def test_get_llama3_eos_token():
    """Check the EOS ids reported for Llama 3.2 by tokenizer and generation config.

    The tokenizer is expected to report a single id, while the generation
    config is expected to list three.
    """
    repo = "meta-llama/Llama-3.2-1B-Instruct"

    # Tokenizer view: one EOS id.
    assert get_tokenizer(repo).eos_token_id == 128009

    # Generation-config view: a list of EOS ids.
    gen_cfg = try_get_generation_config(repo, trust_remote_code=False)
    assert gen_cfg is not None
    assert gen_cfg.eos_token_id == [128001, 128008, 128009]
|
||||
|
||||
|
||||
def test_get_blip2_eos_token():
    """Check the EOS ids reported for BLIP-2 by tokenizer and generation config.

    The two sources are expected to disagree: the tokenizer reports ``2`` and
    the generation config reports ``50118``.
    """
    repo = "Salesforce/blip2-opt-2.7b"

    # Tokenizer view.
    assert get_tokenizer(repo).eos_token_id == 2

    # Generation-config view.
    gen_cfg = try_get_generation_config(repo, trust_remote_code=False)
    assert gen_cfg is not None
    assert gen_cfg.eos_token_id == 50118
|
||||
35
tests/transformers_utils/test_config_parser_registry.py
Normal file
35
tests/transformers_utils/test_config_parser_registry.py
Normal file
@@ -0,0 +1,35 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
from transformers import PretrainedConfig
|
||||
|
||||
from vllm.transformers_utils.config import get_config_parser, register_config_parser
|
||||
from vllm.transformers_utils.config_parser_base import ConfigParserBase
|
||||
|
||||
|
||||
@register_config_parser("custom_config_parser")
class CustomConfigParser(ConfigParserBase):
    """Stub parser registered under ``"custom_config_parser"`` for registry tests."""

    def parse(
        self,
        model: str | Path,
        trust_remote_code: bool,
        revision: str | None = None,
        code_revision: str | None = None,
        **kwargs,
    ) -> tuple[dict, PretrainedConfig]:
        """Always raise: this stub exists only so it can be registered and looked up."""
        raise NotImplementedError()
|
||||
|
||||
|
||||
def test_register_config_parser():
    """Looking up the registered name yields a ``CustomConfigParser`` instance."""
    parser = get_config_parser("custom_config_parser")
    assert isinstance(parser, CustomConfigParser)
|
||||
|
||||
|
||||
def test_invalid_config_parser():
    """Registering a plain class (no ``ConfigParserBase`` parent) raises ValueError."""

    class InvalidConfigParser:
        pass

    # Apply the decorator as an explicit call so that both the factory call
    # and the registration itself run inside the ``pytest.raises`` context.
    with pytest.raises(ValueError):
        register_config_parser("invalid_config_parser")(InvalidConfigParser)
|
||||
66
tests/transformers_utils/test_processor.py
Normal file
66
tests/transformers_utils/test_processor.py
Normal file
@@ -0,0 +1,66 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import importlib
|
||||
|
||||
from transformers.processing_utils import ProcessingKwargs
|
||||
from typing_extensions import Unpack
|
||||
|
||||
from vllm.transformers_utils.processor import (
|
||||
get_processor_kwargs_from_processor,
|
||||
)
|
||||
|
||||
|
||||
class _FakeProcessorKwargs(ProcessingKwargs, total=False):  # type: ignore
    """Empty ``ProcessingKwargs`` subclass used as the kwargs type in these tests."""
|
||||
|
||||
|
||||
def _assert_has_all_expected(keys: set[str]) -> None:
|
||||
# text
|
||||
for k in ("text_pair", "text_target", "text_pair_target"):
|
||||
assert k in keys
|
||||
# image
|
||||
for k in ("do_convert_rgb", "do_resize"):
|
||||
assert k in keys
|
||||
# audio
|
||||
for k in (
|
||||
"fps",
|
||||
"do_sample_frames",
|
||||
"input_data_format",
|
||||
"default_to_square",
|
||||
):
|
||||
assert k in keys
|
||||
# audio
|
||||
for k in ("padding", "return_attention_mask"):
|
||||
assert k in keys
|
||||
|
||||
|
||||
# Path 1: __call__ method has kwargs: Unpack[*ProcessingKwargs]
|
||||
class _ProcWithUnpack:
    """Fake processor whose ``__call__`` annotates ``**kwargs`` with ``Unpack``."""

    def __call__(
        self,
        *args,
        **kwargs: Unpack[_FakeProcessorKwargs],  # type: ignore
    ):
        """Accept anything and return ``None``; only the signature matters here."""
        return None
|
||||
|
||||
|
||||
def test_get_processor_kwargs_from_processor_unpack_path_returns_full_union():
    """Kwargs found via the ``Unpack[...]`` annotation cover every expected group."""
    discovered = get_processor_kwargs_from_processor(_ProcWithUnpack())
    _assert_has_all_expected(discovered)
|
||||
|
||||
|
||||
# ---- Path 2: No Unpack, fallback to scanning *ProcessingKwargs in module ----
|
||||
|
||||
|
||||
class _ProcWithoutUnpack:
|
||||
def __call__(self, *args, **kwargs):
|
||||
return None
|
||||
|
||||
|
||||
def test_get_processor_kwargs_from_processor_module_scan_returns_full_union():
    """Kwargs found without an ``Unpack`` annotation cover every expected group."""
    # Precondition: the module that defines the fake processor (this test
    # module) also exposes ``_FakeProcessorKwargs`` for the fallback to find.
    defining_module = importlib.import_module(_ProcWithoutUnpack.__module__)
    assert hasattr(defining_module, "_FakeProcessorKwargs")

    discovered = get_processor_kwargs_from_processor(_ProcWithoutUnpack())
    _assert_has_all_expected(discovered)
|
||||
62
tests/transformers_utils/test_repo_utils.py
Normal file
62
tests/transformers_utils/test_repo_utils.py
Normal file
@@ -0,0 +1,62 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
from unittest.mock import MagicMock, call, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from vllm.transformers_utils.repo_utils import list_filtered_repo_files
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    "allow_patterns,expected_relative_files",
    [
        (
            ["*.json", "correct*.txt"],
            ["json_file.json", "subfolder/correct.txt", "correct_2.txt"],
        ),
    ],
)
def test_list_filtered_repo_files(
    allow_patterns: list[str], expected_relative_files: list[str]
):
    """Check that ``list_filtered_repo_files`` keeps only files matching the patterns.

    A temporary directory tree stands in for a repository; its file listing is
    fed to the function through a patched ``list_repo_files``.

    Args:
        allow_patterns: Glob patterns passed to ``list_filtered_repo_files``.
        expected_relative_files: Forward-slash relative paths expected back.
    """
    with tempfile.TemporaryDirectory() as tmp_dir:
        # Build the fake repository tree.
        root = Path(tmp_dir)
        (root / "subfolder").mkdir()
        for relative in (
            "json_file.json",
            "correct_2.txt",
            "uncorrect.txt",
            "uncorrect.jpeg",
            "subfolder/correct.txt",
            "subfolder/uncorrect_sub.txt",
        ):
            (root / relative).touch()

        # Render paths with forward slashes so the listing matches the
        # parametrized expectations regardless of the host path separator.
        repo_listing = [
            file.relative_to(root).as_posix()
            for file in root.glob("**/*")
            if file.is_file()
        ]

        # Patch list_repo_files as looked up by the module under test.
        with patch(
            "vllm.transformers_utils.repo_utils.list_repo_files",
            MagicMock(return_value=repo_listing),
        ) as mock_list_repo_files:
            out_files = sorted(
                list_filtered_repo_files(
                    tmp_dir, allow_patterns, "revision", "model", "token"
                )
            )

        assert out_files == sorted(expected_relative_files)
        # Exactly one lookup, with the arguments forwarded as keywords.
        mock_list_repo_files.assert_called_once_with(
            repo_id=tmp_dir,
            revision="revision",
            repo_type="model",
            token="token",
        )
|
||||
178
tests/transformers_utils/test_utils.py
Normal file
178
tests/transformers_utils/test_utils.py
Normal file
@@ -0,0 +1,178 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from pathlib import Path
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from vllm.transformers_utils.gguf_utils import (
|
||||
is_gguf,
|
||||
is_remote_gguf,
|
||||
split_remote_gguf,
|
||||
)
|
||||
from vllm.transformers_utils.utils import (
|
||||
is_cloud_storage,
|
||||
is_gcs,
|
||||
is_s3,
|
||||
)
|
||||
|
||||
|
||||
def test_is_gcs():
    """``is_gcs`` accepts a ``gs://`` URI and rejects S3, local and NFS paths."""
    assert is_gcs("gs://model-path")

    non_gcs_paths = (
        "s3://model-path/path-to-model",
        "/unix/local/path",
        "nfs://nfs-fqdn.local",
    )
    for candidate in non_gcs_paths:
        assert not is_gcs(candidate)
|
||||
|
||||
|
||||
def test_is_s3():
    """``is_s3`` accepts an ``s3://`` URI and rejects GCS, local and NFS paths."""
    assert is_s3("s3://model-path/path-to-model")

    non_s3_paths = (
        "gs://model-path",
        "/unix/local/path",
        "nfs://nfs-fqdn.local",
    )
    for candidate in non_s3_paths:
        assert not is_s3(candidate)
|
||||
|
||||
|
||||
def test_is_cloud_storage():
    """``is_cloud_storage`` accepts GCS and S3 URIs and rejects local and NFS paths."""
    for cloud_path in ("gs://model-path", "s3://model-path/path-to-model"):
        assert is_cloud_storage(cloud_path)

    for other_path in ("/unix/local/path", "nfs://nfs-fqdn.local"):
        assert not is_cloud_storage(other_path)
|
||||
|
||||
|
||||
class TestIsRemoteGGUF:
    """Checks ``is_remote_gguf`` against accepted and rejected model strings."""

    def test_is_remote_gguf_with_colon_and_slash(self):
        """``repo_id:quant_type`` is accepted only with a valid quant type."""
        accepted = (
            "unsloth/Qwen3-0.6B-GGUF:IQ1_S",
            "user/repo:Q2_K",
            "repo/model:Q4_K",
            "repo/model:Q8_0",
        )
        for model in accepted:
            assert is_remote_gguf(model)

        # Same shape, but the part after the colon is not a valid quant type.
        rejected = (
            "repo/model:quant",
            "repo/model:INVALID",
            "repo/model:invalid_type",
        )
        for model in rejected:
            assert not is_remote_gguf(model)

    def test_is_remote_gguf_without_colon(self):
        """A repo id with no ``:quant_type`` suffix is rejected."""
        for model in ("repo/model", "unsloth/Qwen3-0.6B-GGUF"):
            assert not is_remote_gguf(model)

    def test_is_remote_gguf_without_slash(self):
        """A name with no slash is rejected, even with a valid quant type."""
        for model in ("model.gguf", "model:IQ1_S", "model:quant"):
            assert not is_remote_gguf(model)

    def test_is_remote_gguf_local_path(self):
        """Local file paths are rejected."""
        for model in ("/path/to/model.gguf", "./model.gguf"):
            assert not is_remote_gguf(model)

    def test_is_remote_gguf_with_path_object(self):
        """``Path`` inputs are handled like their string form."""
        remote = Path("unsloth/Qwen3-0.6B-GGUF:IQ1_S")
        plain = Path("repo/model")
        assert is_remote_gguf(remote)
        assert not is_remote_gguf(plain)

    def test_is_remote_gguf_with_http_https(self):
        """HTTP/HTTPS URLs are rejected even when they end in a valid quant type."""
        urls = (
            "http://example.com/repo/model:IQ1_S",
            "https://huggingface.co/repo/model:Q2_K",
            "http://repo/model:Q4_K",
            "https://repo/model:Q8_0",
        )
        for url in urls:
            assert not is_remote_gguf(url)

    def test_is_remote_gguf_with_cloud_storage(self):
        """``s3://`` and ``gs://`` paths are rejected even with a valid quant type."""
        cloud_paths = (
            "s3://bucket/repo/model:IQ1_S",
            "gs://bucket/repo/model:Q2_K",
            "s3://repo/model:Q4_K",
            "gs://repo/model:Q8_0",
        )
        for cloud_path in cloud_paths:
            assert not is_remote_gguf(cloud_path)
|
||||
|
||||
|
||||
class TestSplitRemoteGGUF:
    """Checks ``split_remote_gguf`` on valid and invalid model strings."""

    def test_split_remote_gguf_valid(self):
        """A valid ``repo_id:quant_type`` string splits into its two parts."""
        cases = (
            ("unsloth/Qwen3-0.6B-GGUF:IQ1_S", "unsloth/Qwen3-0.6B-GGUF", "IQ1_S"),
            ("repo/model:Q2_K", "repo/model", "Q2_K"),
        )
        for model, expected_repo, expected_quant in cases:
            repo_id, quant_type = split_remote_gguf(model)
            assert repo_id == expected_repo
            assert quant_type == expected_quant

    def test_split_remote_gguf_with_path_object(self):
        """A ``Path`` input splits the same way as its string form."""
        model = Path("unsloth/Qwen3-0.6B-GGUF:IQ1_S")
        repo_id, quant_type = split_remote_gguf(model)
        assert repo_id == "unsloth/Qwen3-0.6B-GGUF"
        assert quant_type == "IQ1_S"

    def test_split_remote_gguf_invalid(self):
        """Inputs that are not remote GGUF references raise ``ValueError``."""
        invalid_models = (
            "repo/model",  # no colon
            "repo/model:INVALID_TYPE",  # invalid quant type
            "http://repo/model:IQ1_S",  # HTTP URL
            "s3://bucket/repo/model:Q2_K",  # cloud storage path
        )
        for model in invalid_models:
            with pytest.raises(ValueError, match="Wrong GGUF model"):
                split_remote_gguf(model)
|
||||
|
||||
|
||||
class TestIsGGUF:
    """Checks ``is_gguf`` for local files, remote references and edge cases."""

    @patch("vllm.transformers_utils.gguf_utils.check_gguf_file", return_value=True)
    def test_is_gguf_with_local_file(self, mock_check_gguf):
        """Local paths count as GGUF when ``check_gguf_file`` is patched to True."""
        for local_path in ("/path/to/model.gguf", "./model.gguf"):
            assert is_gguf(local_path)

    def test_is_gguf_with_remote_gguf(self):
        """``repo_id:quant_type`` counts as GGUF only with a valid quant type."""
        valid_remote = (
            "unsloth/Qwen3-0.6B-GGUF:IQ1_S",
            "repo/model:Q2_K",
            "repo/model:Q4_K",
        )
        for model in valid_remote:
            assert is_gguf(model)

        # Invalid quant types are expected to yield False.
        for model in ("repo/model:quant", "repo/model:INVALID"):
            assert not is_gguf(model)

    @patch("vllm.transformers_utils.gguf_utils.check_gguf_file", return_value=False)
    def test_is_gguf_false(self, mock_check_gguf):
        """Plain model names are not GGUF when ``check_gguf_file`` is patched to False."""
        for model in ("unsloth/Qwen3-0.6B", "repo/model", "model"):
            assert not is_gguf(model)

    def test_is_gguf_edge_cases(self):
        """Edge-case inputs are all expected to be reported as non-GGUF."""
        edge_cases = (
            "",  # empty string
            "model:IQ1_S",  # colon but no slash, valid quant type
            "repo/model",  # slash but no colon
            "http://repo/model:IQ1_S",  # HTTP URL
            "https://repo/model:Q2_K",  # HTTPS URL
            "s3://bucket/repo/model:IQ1_S",  # S3 path
            "gs://bucket/repo/model:Q2_K",  # GCS path
        )
        for model in edge_cases:
            assert not is_gguf(model)
|
||||
Reference in New Issue
Block a user