Sync from v0.13
This commit is contained in:
6
tests/utils_/__init__.py
Normal file
6
tests/utils_/__init__.py
Normal file
@@ -0,0 +1,6 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""
|
||||
This module is named `utils_` instead of `utils` to avoid obscuring
|
||||
`tests/utils.py`.
|
||||
"""
|
||||
460
tests/utils_/test_argparse_utils.py
Normal file
460
tests/utils_/test_argparse_utils.py
Normal file
@@ -0,0 +1,460 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
# ruff: noqa
|
||||
|
||||
import json
|
||||
import os
|
||||
|
||||
import pytest
|
||||
import yaml
|
||||
from transformers import AutoTokenizer
|
||||
from pydantic import ValidationError
|
||||
|
||||
from vllm.tokenizers.detokenizer_utils import convert_ids_list_to_tokens
|
||||
|
||||
from vllm.utils.argparse_utils import FlexibleArgumentParser
|
||||
from ..utils import flat_product
|
||||
|
||||
|
||||
# Tests for FlexibleArgumentParser
@pytest.fixture
def parser():
    """Build a FlexibleArgumentParser with a representative set of options.

    Covers choice-restricted, plain string, int, boolean flag, and
    JSON-valued (dict) arguments, plus a short-form alias (-cc) and an
    optimization level, so the tests can exercise each parsing path.
    """
    parser = FlexibleArgumentParser()
    parser.add_argument(
        "--image-input-type", choices=["pixel_values", "image_features"]
    )
    parser.add_argument("--model-name")
    parser.add_argument("--batch-size", type=int)
    parser.add_argument("--enable-feature", action="store_true")
    parser.add_argument("--hf-overrides", type=json.loads)
    parser.add_argument("-cc", "--compilation-config", type=json.loads)
    parser.add_argument("--optimization-level", type=int)
    return parser
|
||||
|
||||
|
||||
@pytest.fixture
def parser_with_config():
    """Build a parser mimicking the `vllm serve` CLI, including --config.

    Used by the tests that merge CLI arguments with values loaded from a
    YAML config file.
    """
    parser = FlexibleArgumentParser()
    parser.add_argument("serve")
    parser.add_argument("model_tag", nargs="?")
    parser.add_argument("--model", type=str)
    parser.add_argument("--served-model-name", type=str)
    parser.add_argument("--config", type=str)
    parser.add_argument("--port", type=int)
    parser.add_argument("--tensor-parallel-size", type=int)
    parser.add_argument("--trust-remote-code", action="store_true")
    return parser
|
||||
|
||||
|
||||
def test_underscore_to_dash(parser):
    """Underscores in option names are accepted in place of dashes."""
    args = parser.parse_args(["--image_input_type", "pixel_values"])
    assert args.image_input_type == "pixel_values"
|
||||
|
||||
|
||||
def test_mixed_usage(parser):
    """Underscore and dash spellings may be mixed within one invocation."""
    args = parser.parse_args(
        ["--image_input_type", "image_features", "--model-name", "facebook/opt-125m"]
    )
    assert args.image_input_type == "image_features"
    assert args.model_name == "facebook/opt-125m"
|
||||
|
||||
|
||||
def test_with_equals_sign(parser):
    """`--opt=value` syntax works with underscore-normalized names."""
    args = parser.parse_args(
        ["--image_input_type=pixel_values", "--model-name=facebook/opt-125m"]
    )
    assert args.image_input_type == "pixel_values"
    assert args.model_name == "facebook/opt-125m"
|
||||
|
||||
|
||||
def test_with_int_value(parser):
    """Integer-typed options parse identically with either spelling."""
    args = parser.parse_args(["--batch_size", "32"])
    assert args.batch_size == 32
    args = parser.parse_args(["--batch-size", "32"])
    assert args.batch_size == 32
|
||||
|
||||
|
||||
def test_with_bool_flag(parser):
    """store_true flags parse identically with either spelling."""
    args = parser.parse_args(["--enable_feature"])
    assert args.enable_feature is True
    args = parser.parse_args(["--enable-feature"])
    assert args.enable_feature is True
|
||||
|
||||
|
||||
def test_invalid_choice(parser):
    """A value outside `choices` triggers argparse's SystemExit."""
    with pytest.raises(SystemExit):
        parser.parse_args(["--image_input_type", "invalid_choice"])
|
||||
|
||||
|
||||
def test_missing_required_argument(parser):
    """Omitting a required option triggers argparse's SystemExit."""
    parser.add_argument("--required-arg", required=True)
    with pytest.raises(SystemExit):
        parser.parse_args([])
|
||||
|
||||
|
||||
def test_cli_override_to_config(parser_with_config, cli_config_file):
    """CLI values take precedence over config-file values, regardless of
    whether they appear before or after --config on the command line."""
    args = parser_with_config.parse_args(
        ["serve", "mymodel", "--config", cli_config_file, "--tensor-parallel-size", "3"]
    )
    assert args.tensor_parallel_size == 3
    args = parser_with_config.parse_args(
        ["serve", "mymodel", "--tensor-parallel-size", "3", "--config", cli_config_file]
    )
    assert args.tensor_parallel_size == 3
    # port is only set in the config file, so it comes through untouched
    assert args.port == 12312
    args = parser_with_config.parse_args(
        [
            "serve",
            "mymodel",
            "--tensor-parallel-size",
            "3",
            "--config",
            cli_config_file,
            "--port",
            "666",
        ]
    )
    assert args.tensor_parallel_size == 3
    assert args.port == 666
|
||||
|
||||
|
||||
def test_config_args(parser_with_config, cli_config_file):
    """Values present only in the config file are applied to the namespace."""
    args = parser_with_config.parse_args(
        ["serve", "mymodel", "--config", cli_config_file]
    )
    assert args.tensor_parallel_size == 2
    assert args.trust_remote_code
|
||||
|
||||
|
||||
def test_config_file(parser_with_config):
    """Bad --config values raise: missing file, unsupported extension,
    or a following option token consumed as the file name."""
    # Nonexistent file
    with pytest.raises(FileNotFoundError):
        parser_with_config.parse_args(
            ["serve", "mymodel", "--config", "test_config.yml"]
        )

    # Unsupported extension (.json instead of YAML)
    with pytest.raises(ValueError):
        parser_with_config.parse_args(
            ["serve", "mymodel", "--config", "./data/test_config.json"]
        )

    # --config immediately followed by another option instead of a path
    with pytest.raises(ValueError):
        parser_with_config.parse_args(
            [
                "serve",
                "mymodel",
                "--tensor-parallel-size",
                "3",
                "--config",
                "--batch-size",
                "32",
            ]
        )
|
||||
|
||||
|
||||
def test_no_model_tag(parser_with_config, cli_config_file):
    """Omitting the positional model tag (with no model in config) raises."""
    with pytest.raises(ValueError):
        parser_with_config.parse_args(["serve", "--config", cli_config_file])
|
||||
|
||||
|
||||
def test_dict_args(parser):
    """Dotted option syntax (`--opt.key value`) builds nested dicts, with
    JSON-style type detection, `=` support, underscore/dash normalization
    of the option name (but not of keys), and `+`-appended list values."""
    args = [
        "--model-name=something.something",
        "--hf-overrides.key1",
        "val1",
        # Test nesting
        "--hf-overrides.key2.key3",
        "val2",
        "--hf-overrides.key2.key4",
        "val3",
        # Test compile config and compilation mode
        "-cc.use_inductor_graph_partition=true",
        "-cc.backend",
        "custom",
        "-O1",
        # Test = sign
        "--hf-overrides.key5=val4",
        # Test underscore to dash conversion
        "--hf_overrides.key_6",
        "val5",
        "--hf_overrides.key-7.key_8",
        "val6",
        # Test data type detection
        "--hf_overrides.key9",
        "100",
        "--hf_overrides.key10",
        "100.0",
        "--hf_overrides.key11",
        "true",
        "--hf_overrides.key12.key13",
        "null",
        # Test '-' and '.' in value
        "--hf_overrides.key14.key15",
        "-minus.and.dot",
        # Test array values
        "-cc.custom_ops+",
        "-quant_fp8",
        "-cc.custom_ops+=+silu_mul,-rms_norm",
    ]
    parsed_args = parser.parse_args(args)
    assert parsed_args.model_name == "something.something"
    assert parsed_args.hf_overrides == {
        "key1": "val1",
        "key2": {
            "key3": "val2",
            "key4": "val3",
        },
        "key5": "val4",
        "key_6": "val5",
        "key-7": {
            "key_8": "val6",
        },
        "key9": 100,
        "key10": 100.0,
        "key11": True,
        "key12": {
            "key13": None,
        },
        "key14": {
            "key15": "-minus.and.dot",
        },
    }
    assert parsed_args.optimization_level == 1
    assert parsed_args.compilation_config == {
        "use_inductor_graph_partition": True,
        "backend": "custom",
        "custom_ops": ["-quant_fp8", "+silu_mul", "-rms_norm"],
    }
|
||||
|
||||
|
||||
def test_duplicate_dict_args(caplog_vllm, parser):
    """Duplicated dotted keys keep the last value and log one warning
    naming the duplicated options."""
    args = [
        "--model-name=something.something",
        "--hf-overrides.key1",
        "val1",
        "--hf-overrides.key1",
        "val2",
        "-O1",
        "-cc.mode",
        "2",
        "-O3",
    ]

    parsed_args = parser.parse_args(args)
    # Should be the last value
    assert parsed_args.hf_overrides == {"key1": "val2"}
    assert parsed_args.optimization_level == 3
    assert parsed_args.compilation_config == {"mode": 2}

    # A single warning record mentions every duplicated option
    assert len(caplog_vllm.records) == 1
    assert "duplicate" in caplog_vllm.text
    assert "--hf-overrides.key1" in caplog_vllm.text
    assert "--optimization-level" in caplog_vllm.text
|
||||
|
||||
|
||||
def test_model_specification(
    parser_with_config, cli_config_file, cli_config_file_with_model
):
    """How the model may be specified: positional tag beats config-file
    model; `--model` is currently tolerated (back-compat) but ignored."""
    # Test model in CLI takes precedence over config
    args = parser_with_config.parse_args(
        ["serve", "cli-model", "--config", cli_config_file_with_model]
    )
    assert args.model_tag == "cli-model"
    assert args.served_model_name == "mymodel"

    # Test model from config file works
    args = parser_with_config.parse_args(
        [
            "serve",
            "--config",
            cli_config_file_with_model,
        ]
    )
    assert args.model == "config-model"
    assert args.served_model_name == "mymodel"

    # Test no model specified anywhere raises error
    with pytest.raises(ValueError, match="No model specified!"):
        parser_with_config.parse_args(["serve", "--config", cli_config_file])

    # Test using --model option raises error
    # with pytest.raises(
    #     ValueError,
    #     match=
    #     ("With `vllm serve`, you should provide the model as a positional "
    #      "argument or in a config file instead of via the `--model` option."),
    # ):
    #     parser_with_config.parse_args(['serve', '--model', 'my-model'])

    # Test using --model option back-compatibility
    # (when back-compatibility ends, the above test should be uncommented
    # and the below test should be removed)
    args = parser_with_config.parse_args(
        [
            "serve",
            "--tensor-parallel-size",
            "2",
            "--model",
            "my-model",
            "--trust-remote-code",
            "--port",
            "8001",
        ]
    )
    assert args.model is None
    assert args.tensor_parallel_size == 2
    assert args.trust_remote_code is True
    assert args.port == 8001

    args = parser_with_config.parse_args(
        [
            "serve",
            "--tensor-parallel-size=2",
            "--model=my-model",
            "--trust-remote-code",
            "--port=8001",
        ]
    )
    assert args.model is None
    assert args.tensor_parallel_size == 2
    assert args.trust_remote_code is True
    assert args.port == 8001

    # Test other config values are preserved
    args = parser_with_config.parse_args(
        [
            "serve",
            "cli-model",
            "--config",
            cli_config_file_with_model,
        ]
    )
    assert args.tensor_parallel_size == 2
    assert args.trust_remote_code is True
    assert args.port == 12312
|
||||
|
||||
|
||||
def test_convert_ids_list_to_tokens():
    """convert_ids_list_to_tokens decodes the BPE space marker (Ġ) that
    the raw tokenizer leaves in place."""
    tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-1.5B-Instruct")
    token_ids = tokenizer.encode("Hello, world!")
    # token_ids = [9707, 11, 1879, 0]
    assert tokenizer.convert_ids_to_tokens(token_ids) == ["Hello", ",", "Ġworld", "!"]
    tokens = convert_ids_list_to_tokens(tokenizer, token_ids)
    assert tokens == ["Hello", ",", " world", "!"]
|
||||
|
||||
|
||||
def test_load_config_file(tmp_path):
    """load_config_file flattens a YAML mapping into an argv-style list:
    True becomes a bare flag, lists expand to repeated values, and
    scalars are stringified."""
    # Define the configuration data
    config_data = {
        "enable-logging": True,
        "list-arg": ["item1", "item2"],
        "port": 12323,
        "tensor-parallel-size": 4,
    }

    # Write the configuration data to a temporary YAML file
    config_file_path = tmp_path / "config.yaml"
    with open(config_file_path, "w") as config_file:
        yaml.dump(config_data, config_file)

    # Initialize the parser
    parser = FlexibleArgumentParser()

    # Call the function with the temporary file path
    processed_args = parser.load_config_file(str(config_file_path))

    # Expected output
    expected_args = [
        "--enable-logging",
        "--list-arg",
        "item1",
        "item2",
        "--port",
        "12323",
        "--tensor-parallel-size",
        "4",
    ]

    # Assert that the processed arguments match the expected output
    assert processed_args == expected_args
    os.remove(str(config_file_path))
|
||||
|
||||
|
||||
def test_compilation_mode_string_values(parser):
    """Test that -cc.mode accepts both integer and string mode values.

    String values are passed through verbatim (in any case); conversion
    to the enum happens later, in CompilationConfig's validator.
    """
    args = parser.parse_args(["-cc.mode", "0"])
    assert args.compilation_config == {"mode": 0}

    args = parser.parse_args(["-O3"])
    assert args.optimization_level == 3

    args = parser.parse_args(["-cc.mode=NONE"])
    assert args.compilation_config == {"mode": "NONE"}

    args = parser.parse_args(["-cc.mode", "STOCK_TORCH_COMPILE"])
    assert args.compilation_config == {"mode": "STOCK_TORCH_COMPILE"}

    args = parser.parse_args(["-cc.mode=DYNAMO_TRACE_ONCE"])
    assert args.compilation_config == {"mode": "DYNAMO_TRACE_ONCE"}

    args = parser.parse_args(["-cc.mode", "VLLM_COMPILE"])
    assert args.compilation_config == {"mode": "VLLM_COMPILE"}

    # Lowercase strings are also passed through unchanged
    args = parser.parse_args(["-cc.mode=none"])
    assert args.compilation_config == {"mode": "none"}

    args = parser.parse_args(["-cc.mode=vllm_compile"])
    assert args.compilation_config == {"mode": "vllm_compile"}
|
||||
|
||||
|
||||
def test_compilation_config_mode_validator():
    """Test that CompilationConfig.mode field validator converts strings to integers."""
    from vllm.config.compilation import CompilationConfig, CompilationMode

    # Integer values map directly to enum members
    config = CompilationConfig(mode=0)
    assert config.mode == CompilationMode.NONE

    config = CompilationConfig(mode=3)
    assert config.mode == CompilationMode.VLLM_COMPILE

    # Uppercase enum names are accepted
    config = CompilationConfig(mode="NONE")
    assert config.mode == CompilationMode.NONE

    config = CompilationConfig(mode="STOCK_TORCH_COMPILE")
    assert config.mode == CompilationMode.STOCK_TORCH_COMPILE

    config = CompilationConfig(mode="DYNAMO_TRACE_ONCE")
    assert config.mode == CompilationMode.DYNAMO_TRACE_ONCE

    config = CompilationConfig(mode="VLLM_COMPILE")
    assert config.mode == CompilationMode.VLLM_COMPILE

    # Case-insensitive: lowercase names also resolve
    config = CompilationConfig(mode="none")
    assert config.mode == CompilationMode.NONE

    config = CompilationConfig(mode="vllm_compile")
    assert config.mode == CompilationMode.VLLM_COMPILE

    # Unknown names are rejected by pydantic validation
    with pytest.raises(ValidationError, match="Invalid compilation mode"):
        CompilationConfig(mode="INVALID_MODE")
|
||||
|
||||
|
||||
def test_flat_product():
    """flat_product behaves like itertools.product but splices tuple
    elements into the result instead of nesting them."""
    # Check regular itertools.product behavior
    result1 = list(flat_product([1, 2, 3], ["a", "b"]))
    assert result1 == [
        (1, "a"),
        (1, "b"),
        (2, "a"),
        (2, "b"),
        (3, "a"),
        (3, "b"),
    ]

    # check that the tuples get flattened
    result2 = list(flat_product([(1, 2), (3, 4)], ["a", "b"], [(5, 6)]))
    assert result2 == [
        (1, 2, "a", 5, 6),
        (1, 2, "b", 5, 6),
        (3, 4, "a", 5, 6),
        (3, 4, "b", 5, 6),
    ]
|
||||
42
tests/utils_/test_async_utils.py
Normal file
42
tests/utils_/test_async_utils.py
Normal file
@@ -0,0 +1,42 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import asyncio
|
||||
from collections.abc import AsyncIterator
|
||||
|
||||
import pytest
|
||||
|
||||
from vllm.utils.async_utils import merge_async_iterators
|
||||
|
||||
|
||||
async def _mock_async_iterator(idx: int):
    """Yield a tagged item every 0.1s forever; report when cancelled."""
    try:
        while True:
            yield f"item from iterator {idx}"
            await asyncio.sleep(0.1)
    except asyncio.CancelledError:
        print(f"iterator {idx} cancelled")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_merge_async_iterators():
    """Cancelling the consumer of merge_async_iterators cancels every
    underlying iterator."""
    iterators = [_mock_async_iterator(i) for i in range(3)]
    merged_iterator = merge_async_iterators(*iterators)

    async def stream_output(generator: AsyncIterator[tuple[int, str]]):
        async for idx, output in generator:
            print(f"idx: {idx}, output: {output}")

    task = asyncio.create_task(stream_output(merged_iterator))
    await asyncio.sleep(0.5)
    task.cancel()
    with pytest.raises(asyncio.CancelledError):
        await task

    # After cancellation each source iterator must be exhausted
    for iterator in iterators:
        try:
            await asyncio.wait_for(anext(iterator), 1)
        except StopAsyncIteration:
            # All iterators should be cancelled and print this message.
            print("Iterator was cancelled normally")
        except (Exception, asyncio.CancelledError) as e:
            raise AssertionError() from e
|
||||
125
tests/utils_/test_cache.py
Normal file
125
tests/utils_/test_cache.py
Normal file
@@ -0,0 +1,125 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from vllm.utils.cache import CacheInfo, LRUCache
|
||||
|
||||
|
||||
class TestLRUCache(LRUCache):
    """LRUCache subclass that counts evictions/removals via _on_remove."""

    def _on_remove(self, key, value):
        # Lazily initialize so no __init__ override is needed
        if not hasattr(self, "_remove_counter"):
            self._remove_counter = 0
        self._remove_counter += 1
|
||||
|
||||
|
||||
def test_lru_cache():
    """Exercise LRUCache: put/get, eviction order, hit/total statistics
    (cumulative and delta), pop, remove_oldest, clear, and the mapping
    protocol (__setitem__/__getitem__/__delitem__/__contains__)."""
    cache = TestLRUCache(3)
    assert cache.stat() == CacheInfo(hits=0, total=0)
    assert cache.stat(delta=True) == CacheInfo(hits=0, total=0)

    cache.put(1, 1)
    assert len(cache) == 1

    cache.put(1, 1)
    assert len(cache) == 1

    cache.put(2, 2)
    assert len(cache) == 2

    cache.put(3, 3)
    assert len(cache) == 3
    assert set(cache.cache) == {1, 2, 3}

    # Capacity exceeded: oldest entry (1) is evicted
    cache.put(4, 4)
    assert len(cache) == 3
    assert set(cache.cache) == {2, 3, 4}
    assert cache._remove_counter == 1

    assert cache.get(2) == 2
    assert cache.stat() == CacheInfo(hits=1, total=1)
    assert cache.stat(delta=True) == CacheInfo(hits=1, total=1)

    # __getitem__ also counts as a lookup
    assert cache[2] == 2
    assert cache.stat() == CacheInfo(hits=2, total=2)
    assert cache.stat(delta=True) == CacheInfo(hits=1, total=1)

    # 2 was just touched, so 3 is now the oldest and gets evicted
    cache.put(5, 5)
    assert set(cache.cache) == {2, 4, 5}
    assert cache._remove_counter == 2

    assert cache.pop(5) == 5
    assert len(cache) == 2
    assert set(cache.cache) == {2, 4}
    assert cache._remove_counter == 3

    # Miss: counted in total but not hits
    assert cache.get(-1) is None
    assert cache.stat() == CacheInfo(hits=2, total=3)
    assert cache.stat(delta=True) == CacheInfo(hits=0, total=1)

    # Popping/getting a missing key is a no-op for contents and counter
    cache.pop(10)
    assert len(cache) == 2
    assert set(cache.cache) == {2, 4}
    assert cache._remove_counter == 3

    cache.get(10)
    assert len(cache) == 2
    assert set(cache.cache) == {2, 4}
    assert cache._remove_counter == 3

    cache.put(6, 6)
    assert len(cache) == 3
    assert set(cache.cache) == {2, 4, 6}
    assert 2 in cache
    assert 4 in cache
    assert 6 in cache

    cache.remove_oldest()
    assert len(cache) == 2
    assert set(cache.cache) == {2, 6}
    assert cache._remove_counter == 4

    # clear() removes both remaining entries and resets statistics
    cache.clear()
    assert len(cache) == 0
    assert cache._remove_counter == 6
    assert cache.stat() == CacheInfo(hits=0, total=0)
    assert cache.stat(delta=True) == CacheInfo(hits=0, total=0)

    cache._remove_counter = 0

    # Repeat the put/evict scenario through the mapping protocol
    cache[1] = 1
    assert len(cache) == 1

    cache[1] = 1
    assert len(cache) == 1

    cache[2] = 2
    assert len(cache) == 2

    cache[3] = 3
    assert len(cache) == 3
    assert set(cache.cache) == {1, 2, 3}

    cache[4] = 4
    assert len(cache) == 3
    assert set(cache.cache) == {2, 3, 4}
    assert cache._remove_counter == 1
    assert cache[2] == 2

    cache[5] = 5
    assert set(cache.cache) == {2, 4, 5}
    assert cache._remove_counter == 2

    del cache[5]
    assert len(cache) == 2
    assert set(cache.cache) == {2, 4}
    assert cache._remove_counter == 3

    cache.pop(10)
    assert len(cache) == 2
    assert set(cache.cache) == {2, 4}
    assert cache._remove_counter == 3

    cache[6] = 6
    assert len(cache) == 3
    assert set(cache.cache) == {2, 4, 6}
    assert 2 in cache
    assert 4 in cache
    assert 6 in cache
|
||||
31
tests/utils_/test_collection_utils.py
Normal file
31
tests/utils_/test_collection_utils.py
Normal file
@@ -0,0 +1,31 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import pytest
|
||||
|
||||
from vllm.utils.collection_utils import swap_dict_values
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    "obj,key1,key2",
    [
        # Tests for both keys exist
        ({1: "a", 2: "b"}, 1, 2),
        # Tests for one key does not exist
        ({1: "a", 2: "b"}, 1, 3),
        # Tests for both keys do not exist
        ({1: "a", 2: "b"}, 3, 4),
    ],
)
def test_swap_dict_values(obj, key1, key2):
    """swap_dict_values exchanges two keys' values in place; a missing
    key causes the counterpart key to be removed rather than set."""
    original_obj = obj.copy()

    swap_dict_values(obj, key1, key2)

    if key1 in original_obj:
        assert obj[key2] == original_obj[key1]
    else:
        assert key2 not in obj
    if key2 in original_obj:
        assert obj[key1] == original_obj[key2]
    else:
        assert key1 not in obj
|
||||
97
tests/utils_/test_func_utils.py
Normal file
97
tests/utils_/test_func_utils.py
Normal file
@@ -0,0 +1,97 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
# ruff: noqa
|
||||
|
||||
import pytest
|
||||
|
||||
from vllm.utils.func_utils import deprecate_kwargs, supports_kw
|
||||
|
||||
from ..utils import error_on_warning
|
||||
|
||||
|
||||
def test_deprecate_kwargs_always():
    """is_deprecated=True warns for the old kwarg, not the new one."""

    @deprecate_kwargs("old_arg", is_deprecated=True)
    def dummy(*, old_arg: object = None, new_arg: object = None):
        pass

    with pytest.warns(DeprecationWarning, match="'old_arg'"):
        dummy(old_arg=1)

    with error_on_warning(DeprecationWarning):
        dummy(new_arg=1)
|
||||
|
||||
|
||||
def test_deprecate_kwargs_never():
    """is_deprecated=False suppresses the warning entirely."""

    @deprecate_kwargs("old_arg", is_deprecated=False)
    def dummy(*, old_arg: object = None, new_arg: object = None):
        pass

    with error_on_warning(DeprecationWarning):
        dummy(old_arg=1)

    with error_on_warning(DeprecationWarning):
        dummy(new_arg=1)
|
||||
|
||||
|
||||
def test_deprecate_kwargs_dynamic():
    """A callable is_deprecated is re-evaluated at each call, so toggling
    the flag changes warning behavior after decoration."""
    is_deprecated = True

    @deprecate_kwargs("old_arg", is_deprecated=lambda: is_deprecated)
    def dummy(*, old_arg: object = None, new_arg: object = None):
        pass

    with pytest.warns(DeprecationWarning, match="'old_arg'"):
        dummy(old_arg=1)

    with error_on_warning(DeprecationWarning):
        dummy(new_arg=1)

    # Flip the flag: the closure sees the new value on the next call
    is_deprecated = False

    with error_on_warning(DeprecationWarning):
        dummy(old_arg=1)

    with error_on_warning(DeprecationWarning):
        dummy(new_arg=1)
|
||||
|
||||
|
||||
def test_deprecate_kwargs_additional_message():
    """additional_message text is included in the emitted warning."""

    @deprecate_kwargs("old_arg", is_deprecated=True, additional_message="abcd")
    def dummy(*, old_arg: object = None, new_arg: object = None):
        pass

    with pytest.warns(DeprecationWarning, match="abcd"):
        dummy(old_arg=1)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    ("callable", "kw_name", "requires_kw_only", "allow_var_kwargs", "is_supported"),
    [
        # Tests for positional argument support
        (lambda foo: None, "foo", True, True, False),
        (lambda foo: None, "foo", False, True, True),
        # Tests for positional or keyword / keyword only
        (lambda foo=100: None, "foo", True, True, False),
        (lambda *, foo: None, "foo", False, True, True),
        # Tests to make sure the names of variadic params are NOT supported
        (lambda *args: None, "args", False, True, False),
        (lambda **kwargs: None, "kwargs", False, True, False),
        # Tests for if we allow var kwargs to add support
        (lambda foo: None, "something_else", False, True, False),
        (lambda foo, **kwargs: None, "something_else", False, True, True),
        (lambda foo, **kwargs: None, "kwargs", True, True, False),
        (lambda foo, **kwargs: None, "foo", True, True, False),
    ],
)
def test_supports_kw(
    callable, kw_name, requires_kw_only, allow_var_kwargs, is_supported
):
    """supports_kw correctly classifies whether a callable accepts a
    given keyword, across parameter kinds and the var-kwargs option."""
    assert (
        supports_kw(
            callable=callable,
            kw_name=kw_name,
            requires_kw_only=requires_kw_only,
            allow_var_kwargs=allow_var_kwargs,
        )
        == is_supported
    )
|
||||
85
tests/utils_/test_gc_utils.py
Normal file
85
tests/utils_/test_gc_utils.py
Normal file
@@ -0,0 +1,85 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
from vllm.utils.gc_utils import (
|
||||
GCDebugConfig,
|
||||
_compute_detailed_type,
|
||||
_compute_top_gc_collected_objects,
|
||||
)
|
||||
|
||||
|
||||
@dataclass
class Normal:
    """Plain dataclass with no __len__, used to test type reporting."""

    v: int
|
||||
|
||||
|
||||
@dataclass
class ListWrapper:
    """Dataclass exposing __len__, used to test size reporting."""

    vs: list[int]

    def __len__(self) -> int:
        # Delegate to the wrapped list so len(wrapper) == len(wrapper.vs)
        return len(self.vs)
|
||||
|
||||
|
||||
def test_compute_detailed_type():
    """_compute_detailed_type appends a (size:N) suffix only for objects
    that define __len__."""
    assert (
        _compute_detailed_type(Normal(v=8))
        == "<class 'tests.utils_.test_gc_utils.Normal'>"
    )

    assert _compute_detailed_type([1, 2, 3]) == "<class 'list'>(size:3)"
    assert _compute_detailed_type({4, 5}) == "<class 'set'>(size:2)"
    assert _compute_detailed_type({6: 7}) == "<class 'dict'>(size:1)"
    assert (
        _compute_detailed_type(ListWrapper(vs=[]))
        == "<class 'tests.utils_.test_gc_utils.ListWrapper'>(size:0)"
    )
|
||||
|
||||
|
||||
def test_compute_top_gc_collected_objects():
    """_compute_top_gc_collected_objects groups objects by detailed type
    and renders the top-N counts, one per line, most frequent first."""
    objects: list[Any] = [
        [1, 2, 3],
        [4, 5, 6],
        [7, 8, 9],
        [10, 11, 12],
        {13, 14},
        {15: 16, 17: 18},
        Normal(v=19),
        Normal(v=20),
        Normal(v=21),
    ]
    # Non-positive top yields an empty report
    assert _compute_top_gc_collected_objects(objects, top=-1) == ""
    assert _compute_top_gc_collected_objects(objects, top=0) == ""
    assert (
        _compute_top_gc_collected_objects(objects, top=1)
        == "  4:<class 'list'>(size:3)"
    )
    assert _compute_top_gc_collected_objects(objects, top=2) == "\n".join(
        [
            "  4:<class 'list'>(size:3)",
            "  3:<class 'tests.utils_.test_gc_utils.Normal'>",
        ]
    )
    assert _compute_top_gc_collected_objects(objects, top=3) == "\n".join(
        [
            "  4:<class 'list'>(size:3)",
            "  3:<class 'tests.utils_.test_gc_utils.Normal'>",
            "  1:<class 'set'>(size:2)",
        ]
    )
|
||||
|
||||
|
||||
def test_gc_debug_config():
    """GCDebugConfig parses its env-style string: falsy values disable,
    "1" enables with defaults, and a JSON object sets top_objects."""
    assert not GCDebugConfig(None).enabled
    assert not GCDebugConfig("").enabled
    assert not GCDebugConfig("0").enabled

    config = GCDebugConfig("1")
    assert config.enabled
    assert config.top_objects == -1

    config = GCDebugConfig('{"top_objects":5}')
    assert config.enabled
    assert config.top_objects == 5
|
||||
25
tests/utils_/test_hashing.py
Normal file
25
tests/utils_/test_hashing.py
Normal file
@@ -0,0 +1,25 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import hashlib
|
||||
import pickle
|
||||
|
||||
import pytest
|
||||
|
||||
from vllm.utils.hashing import sha256
|
||||
|
||||
|
||||
@pytest.mark.parametrize("input", [(), ("abc",), (None,), (None, bool, [1, 2, 3])])
def test_sha256(input: tuple):
    """sha256 hashes the pickled input deterministically and is
    sensitive to input changes."""
    digest = sha256(input)
    assert digest is not None
    assert isinstance(digest, bytes)
    assert digest != b""

    # Matches hashing the highest-protocol pickle of the input directly
    input_bytes = pickle.dumps(input, protocol=pickle.HIGHEST_PROTOCOL)
    assert digest == hashlib.sha256(input_bytes).digest()

    # hashing again, returns the same value
    assert digest == sha256(input)

    # hashing different input, returns different value
    assert digest != sha256(input + (1,))
|
||||
46
tests/utils_/test_import_utils.py
Normal file
46
tests/utils_/test_import_utils.py
Normal file
@@ -0,0 +1,46 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import pytest
|
||||
|
||||
from vllm.utils.import_utils import PlaceholderModule
|
||||
|
||||
|
||||
def _raises_module_not_found():
    """Shared context manager: expect the placeholder's import error."""
    return pytest.raises(ModuleNotFoundError, match="No module named")
|
||||
|
||||
|
||||
def test_placeholder_module_error_handling():
    """PlaceholderModule defers its ModuleNotFoundError until the
    placeholder (or an attribute of it) is actually used; repr/str and
    placeholder_attr() themselves never raise."""
    placeholder = PlaceholderModule("placeholder_1234")

    with _raises_module_not_found():
        int(placeholder)

    with _raises_module_not_found():
        placeholder()

    with _raises_module_not_found():
        _ = placeholder.some_attr

    with _raises_module_not_found():
        # Test conflict with internal __name attribute
        _ = placeholder.name

    # OK to print the placeholder or use it in a f-string
    _ = repr(placeholder)
    _ = str(placeholder)

    # No error yet; only error when it is used downstream
    placeholder_attr = placeholder.placeholder_attr("attr")

    with _raises_module_not_found():
        int(placeholder_attr)

    with _raises_module_not_found():
        placeholder_attr()

    with _raises_module_not_found():
        _ = placeholder_attr.some_attr

    with _raises_module_not_found():
        # Test conflict with internal __module attribute
        _ = placeholder_attr.module
|
||||
32
tests/utils_/test_jsontree.py
Normal file
32
tests/utils_/test_jsontree.py
Normal file
@@ -0,0 +1,32 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from vllm.utils.jsontree import json_count_leaves
|
||||
|
||||
|
||||
def test_json_count_leaves():
    """Test json_count_leaves function from jsontree utility."""

    # Single leaf values
    assert json_count_leaves(42) == 1
    assert json_count_leaves("hello") == 1
    assert json_count_leaves(None) == 1

    # Empty containers contain no leaves
    assert json_count_leaves([]) == 0
    assert json_count_leaves({}) == 0
    assert json_count_leaves(()) == 0

    # Flat structures
    assert json_count_leaves([1, 2, 3]) == 3
    assert json_count_leaves({"a": 1, "b": 2}) == 2
    assert json_count_leaves((1, 2, 3)) == 3

    # Nested structures count leaves recursively
    nested_dict = {"a": 1, "b": {"c": 2, "d": 3}}
    assert json_count_leaves(nested_dict) == 3

    nested_list = [1, [2, 3], 4]
    assert json_count_leaves(nested_list) == 4

    mixed_nested = {"list": [1, 2], "dict": {"x": 3}, "value": 4}
    assert json_count_leaves(mixed_nested) == 4
|
||||
63
tests/utils_/test_mem_utils.py
Normal file
63
tests/utils_/test_mem_utils.py
Normal file
@@ -0,0 +1,63 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import torch
|
||||
from vllm_test_utils.monitor import monitor
|
||||
|
||||
from vllm.utils.mem_utils import MemorySnapshot, memory_profiling
|
||||
|
||||
from ..utils import create_new_process_for_each_test
|
||||
|
||||
|
||||
@create_new_process_for_each_test()
def test_memory_profiling():
    """End-to-end check of ``memory_profiling``.

    Allocations made outside torch (raw cudaMalloc) must be attributed to
    the non-torch increase, while torch allocations show up in the torch
    peak. Runs in a fresh process so CUDA state from other tests cannot
    skew the baseline.
    """
    # Fake out some model loading + inference memory usage to test profiling
    # Memory used by other processes will show up as cuda usage outside of torch
    from vllm.distributed.device_communicators.cuda_wrapper import CudaRTLibrary

    lib = CudaRTLibrary()
    # 512 MiB allocation outside of this instance
    handle1 = lib.cudaMalloc(512 * 1024 * 1024)

    # Snapshot AFTER the external allocation so it is part of the baseline.
    baseline_snapshot = MemorySnapshot()

    # load weights

    weights = torch.randn(128, 1024, 1024, device="cuda", dtype=torch.float32)

    weights_memory = 128 * 1024 * 1024 * 4  # 512 MiB

    def measure_current_non_torch():
        # Total device usage minus torch's reserved pool = non-torch memory.
        free, total = torch.cuda.mem_get_info()
        current_used = total - free
        current_torch = torch.cuda.memory_reserved()
        current_non_torch = current_used - current_torch
        return current_non_torch

    with (
        memory_profiling(
            baseline_snapshot=baseline_snapshot, weights_memory=weights_memory
        ) as result,
        monitor(measure_current_non_torch) as monitored_values,
    ):
        # make a memory spike, 1 GiB
        spike = torch.randn(256, 1024, 1024, device="cuda", dtype=torch.float32)
        del spike

        # Add some extra non-torch memory 256 MiB (simulate NCCL)
        handle2 = lib.cudaMalloc(256 * 1024 * 1024)

    # this is an analytic value, it is exact,
    # we only have 256 MiB non-torch memory increase
    measured_diff = monitored_values.values[-1] - monitored_values.values[0]
    assert measured_diff == 256 * 1024 * 1024

    # Check that the memory usage is within 5% of the expected values
    # 5% tolerance is caused by cuda runtime.
    # we cannot control cuda runtime in the granularity of bytes,
    # which causes a small error (<10 MiB in practice)
    non_torch_ratio = result.non_torch_increase / (256 * 1024 * 1024)  # noqa
    assert abs(non_torch_ratio - 1) <= 0.05
    assert result.torch_peak_increase == 1024 * 1024 * 1024
    # Release everything so the fresh process exits with no device memory held.
    del weights
    lib.cudaFree(handle1)
    lib.cudaFree(handle2)
126
tests/utils_/test_network_utils.py
Normal file
126
tests/utils_/test_network_utils.py
Normal file
@@ -0,0 +1,126 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import socket
|
||||
|
||||
import pytest
|
||||
import zmq
|
||||
|
||||
from vllm.utils.network_utils import (
|
||||
get_open_port,
|
||||
get_tcp_uri,
|
||||
join_host_port,
|
||||
make_zmq_path,
|
||||
make_zmq_socket,
|
||||
split_host_port,
|
||||
split_zmq_path,
|
||||
)
|
||||
|
||||
|
||||
def test_get_open_port(monkeypatch: pytest.MonkeyPatch):
    """get_open_port must keep yielding usable ports even when VLLM_PORT
    pins a preferred one and earlier ports are already bound."""
    with monkeypatch.context() as ctx:
        ctx.setenv("VLLM_PORT", "5678")
        # make sure we can get multiple ports, even if the env var is set
        with (
            socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock_a,
            socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock_b,
            socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock_c,
        ):
            # Each bind occupies the returned port, forcing the next call
            # to hand out a different one.
            sock_a.bind(("localhost", get_open_port()))
            sock_b.bind(("localhost", get_open_port()))
            sock_c.bind(("localhost", get_open_port()))
@pytest.mark.parametrize(
    "path,expected",
    [
        ("ipc://some_path", ("ipc", "some_path", "")),
        ("tcp://127.0.0.1:5555", ("tcp", "127.0.0.1", "5555")),
        ("tcp://[::1]:5555", ("tcp", "::1", "5555")),  # IPv6 address
        ("inproc://some_identifier", ("inproc", "some_identifier", "")),
    ],
)
def test_split_zmq_path(path, expected):
    """split_zmq_path returns a (scheme, host, port) triple; port is empty
    for schemes that have none."""
    scheme, host, port = split_zmq_path(path)
    assert (scheme, host, port) == expected
|
||||
|
||||
@pytest.mark.parametrize(
    "invalid_path",
    [
        "invalid_path",  # Missing scheme
        "tcp://127.0.0.1",  # Missing port
        "tcp://[::1]",  # Missing port for IPv6
        "tcp://:5555",  # Missing host
    ],
)
def test_split_zmq_path_invalid(invalid_path):
    """Every malformed ZMQ path must be rejected with ValueError."""
    with pytest.raises(ValueError):
        _ = split_zmq_path(invalid_path)
|
||||
|
||||
def test_make_zmq_socket_ipv6():
    """make_zmq_socket must enable the ZMQ IPV6 option when given an
    IPv6 address in the path.

    Fix over the original: socket/context cleanup now runs in ``finally``
    blocks, so a failing assertion no longer leaks the zmq socket and
    leaves ``ctx.term()`` uncalled (which can hang later tests).
    """
    # Check if IPv6 is supported by trying to create an IPv6 socket
    try:
        sock = socket.socket(socket.AF_INET6, socket.SOCK_STREAM)
        sock.close()
    except OSError:
        pytest.skip("IPv6 is not supported on this system")

    ctx = zmq.Context()
    try:
        ipv6_path = "tcp://[::]:5555"  # IPv6 loopback address
        socket_type = zmq.REP  # Example socket type

        # Create the socket
        zsock: zmq.Socket = make_zmq_socket(ctx, ipv6_path, socket_type)
        try:
            # Verify that the IPV6 option is set
            assert zsock.getsockopt(zmq.IPV6) == 1, (
                "IPV6 option should be enabled for IPv6 addresses"
            )
        finally:
            zsock.close()
    finally:
        ctx.term()
|
||||
|
||||
def test_make_zmq_path():
    """make_zmq_path joins scheme/host/port, bracketing IPv6 hosts."""
    expectations = {
        ("tcp", "127.0.0.1", "5555"): "tcp://127.0.0.1:5555",
        ("tcp", "::1", "5555"): "tcp://[::1]:5555",
    }
    for (scheme, host, port), uri in expectations.items():
        assert make_zmq_path(scheme, host, port) == uri
|
||||
|
||||
def test_get_tcp_uri():
    """get_tcp_uri formats host + port as a tcp:// URI, bracketing IPv6."""
    cases = [
        ("127.0.0.1", "tcp://127.0.0.1:5555"),
        ("::1", "tcp://[::1]:5555"),
    ]
    for host, expected in cases:
        assert get_tcp_uri(host, 5555) == expected
|
||||
|
||||
def test_split_host_port():
    """split_host_port parses 'host:port' for IPv4 and bracketed IPv6,
    raising on malformed input."""
    # valid ipv4
    assert split_host_port("127.0.0.1:5555") == ("127.0.0.1", 5555)

    # invalid ipv4: multi colon, trailing colon, no colon, non-int port
    for bad_ipv4 in (
        "127.0.0.1::5555",
        "127.0.0.1:5555:",
        "127.0.0.15555",
        "127.0.0.1:5555a",
    ):
        with pytest.raises(ValueError):
            split_host_port(bad_ipv4)

    # valid ipv6
    assert split_host_port("[::1]:5555") == ("::1", 5555)

    # invalid ipv6
    with pytest.raises(ValueError):
        split_host_port("[::1]::5555")  # multi colon
    # a bracketed host with no colon at all currently raises IndexError
    with pytest.raises(IndexError):
        split_host_port("[::1]5555")  # no colon
    with pytest.raises(ValueError):
        split_host_port("[::1]:5555a")  # none int port
|
||||
|
||||
def test_join_host_port():
    """join_host_port formats host:port, bracketing IPv6 hosts."""
    for host, expected in (
        ("127.0.0.1", "127.0.0.1:5555"),
        ("::1", "[::1]:5555"),
    ):
        assert join_host_port(host, 5555) == expected
40
tests/utils_/test_serial_utils.py
Normal file
40
tests/utils_/test_serial_utils.py
Normal file
@@ -0,0 +1,40 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from tests.models.utils import check_embeddings_close
|
||||
from vllm.utils.serial_utils import (
|
||||
EMBED_DTYPE_TO_TORCH_DTYPE,
|
||||
ENDIANNESS,
|
||||
binary2tensor,
|
||||
tensor2binary,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("endianness", ENDIANNESS)
|
||||
@pytest.mark.parametrize("embed_dtype", EMBED_DTYPE_TO_TORCH_DTYPE.keys())
|
||||
@torch.inference_mode()
|
||||
def test_encode_and_decode(embed_dtype: str, endianness: str):
|
||||
for i in range(10):
|
||||
tensor = torch.rand(2, 3, 5, 7, 11, 13, device="cpu", dtype=torch.float32)
|
||||
shape = tensor.shape
|
||||
binary = tensor2binary(tensor, embed_dtype, endianness)
|
||||
new_tensor = binary2tensor(binary, shape, embed_dtype, endianness).to(
|
||||
torch.float32
|
||||
)
|
||||
|
||||
if embed_dtype in ["float32", "float16"]:
|
||||
torch.testing.assert_close(tensor, new_tensor, atol=0.001, rtol=0.001)
|
||||
elif embed_dtype == "bfloat16":
|
||||
torch.testing.assert_close(tensor, new_tensor, atol=0.01, rtol=0.01)
|
||||
else: # for fp8
|
||||
torch.testing.assert_close(tensor, new_tensor, atol=0.1, rtol=0.1)
|
||||
|
||||
check_embeddings_close(
|
||||
embeddings_0_lst=tensor.view(1, -1),
|
||||
embeddings_1_lst=new_tensor.view(1, -1),
|
||||
name_0="gt",
|
||||
name_1="new",
|
||||
tol=1e-2,
|
||||
)
|
||||
19
tests/utils_/test_system_utils.py
Normal file
19
tests/utils_/test_system_utils.py
Normal file
@@ -0,0 +1,19 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
|
||||
from vllm.utils.system_utils import unique_filepath
|
||||
|
||||
|
||||
def test_unique_filepath():
    """unique_filepath must hand out 10 distinct, non-clobbering paths when
    called repeatedly with the same naming function.

    Fixes over the original: uses ``TemporaryDirectory`` so the scratch
    directory is removed after the test (bare ``mkdtemp`` leaked it), and
    replaces the lambda assignment (ruff E731) with a named function.
    """
    with tempfile.TemporaryDirectory() as temp_dir:

        def path_fn(i: int) -> Path:
            # Candidate naming scheme unique_filepath probes with.
            return Path(temp_dir) / f"file_{i}.txt"

        paths: set[Path] = set()
        for _ in range(10):
            path = unique_filepath(path_fn)
            # Creating the file forces the next call to pick a new name.
            path.write_text("test")
            paths.add(path)

        assert len(paths) == 10
        assert len(list(Path(temp_dir).glob("*.txt"))) == 10
||||
203
tests/utils_/test_tensor_schema.py
Normal file
203
tests/utils_/test_tensor_schema.py
Normal file
@@ -0,0 +1,203 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from vllm.model_executor.models.glm4_1v import Glm4vImageEmbeddingInputs
|
||||
from vllm.model_executor.models.granite_speech import GraniteSpeechAudioInputs
|
||||
from vllm.model_executor.models.hyperclovax_vision import HCXVisionVideoPixelInputs
|
||||
from vllm.model_executor.models.phi3v import Phi3VImagePixelInputs
|
||||
|
||||
|
||||
def test_tensor_schema_valid_tensor():
    """A batched pixel tensor with the expected rank and dims validates."""
    pixels = torch.randn(16, 64, 3, 32, 32)
    sizes = torch.randint(0, 256, (16, 2))
    Phi3VImagePixelInputs(pixel_values=pixels, image_sizes=sizes)
|
||||
|
||||
|
||||
def test_tensor_schema_optional_fields():
    """Optional fields may be passed as None or omitted entirely."""
    # Explicit None for the optional field.
    Phi3VImagePixelInputs(
        pixel_values=torch.randn(16, 64, 3, 32, 32),
        image_sizes=None,
    )
    # Optional field omitted entirely.
    Phi3VImagePixelInputs(pixel_values=torch.randn(16, 64, 3, 32, 32))
|
||||
|
||||
|
||||
def test_tensor_schema_constant_dim_failure():
    """A wrong constant dimension (channels != 3) is rejected."""
    bad_pixels = torch.randn(16, 64, 4, 32, 32)  # dim[2] = 4
    with pytest.raises(ValueError, match="dim\\[2\\] expected 3, got 4"):
        Phi3VImagePixelInputs(
            pixel_values=bad_pixels,
            image_sizes=torch.randint(0, 256, (16, 2)),
        )
|
||||
|
||||
|
||||
def test_tensor_schema_invalid_types_in_list():
    """A non-tensor element inside a tensor list fails the type check."""
    mixed_list = [
        torch.randn(64, 3, 32, 32),
        "not_a_tensor",
        torch.randn(64, 3, 32, 32),
    ]
    with pytest.raises(TypeError, match="is not one of the expected types"):
        Phi3VImagePixelInputs(
            pixel_values=mixed_list,
            image_sizes=torch.randint(0, 256, (3, 2)),
        )
|
||||
|
||||
|
||||
def test_tensor_schema_rank_mismatch():
    """A tensor with too few dimensions fails the rank check."""
    low_rank = torch.randn(16, 64, 3)  # rank 3, schema expects rank 5
    with pytest.raises(ValueError, match="has rank 3 but expected 5"):
        Phi3VImagePixelInputs(
            pixel_values=low_rank,
            image_sizes=torch.randint(0, 256, (16, 2)),
        )
|
||||
|
||||
|
||||
def test_tensor_schema_missing_required_field():
    """Omitting a required field raises a descriptive error."""
    with pytest.raises(ValueError, match="Required field 'pixel_values' is missing"):
        Phi3VImagePixelInputs(image_sizes=torch.randint(0, 256, (16, 2)))
|
||||
|
||||
|
||||
def test_tensor_schema_symbolic_dim_mismatch():
    """Symbolic dims shared across fields must agree ('bn' here)."""
    with pytest.raises(ValueError, match="expected 'bn'=12, got 16"):
        Phi3VImagePixelInputs(
            pixel_values=torch.randn(12, 64, 3, 32, 32),  # binds bn = 12
            image_sizes=torch.randint(0, 256, (16, 2)),  # implies bn = 16
        )
|
||||
|
||||
|
||||
def test_tensor_schema_list_tensor_valid():
    """A list of per-image tensors is accepted in place of a batch tensor."""
    per_image = [torch.randn(64, 3, 32, 32) for _ in range(16)]
    Phi3VImagePixelInputs(
        pixel_values=per_image,
        image_sizes=torch.randint(0, 256, (16, 2)),
    )
|
||||
|
||||
|
||||
def test_tensor_schema_variable_patch_counts_valid():
    """Per-image tensors may differ along the variable patch dim (p)."""
    # Each image has a different number of patches; shape is (p, 3, 32, 32).
    patch_counts = (16, 32, 64)
    Phi3VImagePixelInputs(
        pixel_values=[torch.randn(p, 3, 32, 32) for p in patch_counts],
        image_sizes=torch.randint(0, 256, (3, 2)),  # bn = 3
    )
|
||||
|
||||
|
||||
def test_tensor_schema_tuple_tensor_valid():
    """A tuple of tensors validates the same way a list does."""
    per_image = tuple(torch.randn(64, 3, 32, 32) for _ in range(16))
    Phi3VImagePixelInputs(
        pixel_values=per_image,
        image_sizes=torch.randint(0, 256, (16, 2)),
    )
|
||||
|
||||
|
||||
def test_tensor_schema_double_nested_tensors():
    """Doubly nested sequences of tensors (videos -> frames) validate."""
    four_frames = torch.rand(4, 3, 32, 32)
    two_frames = torch.rand(2, 3, 32, 32)
    HCXVisionVideoPixelInputs(
        pixel_values_videos=(
            [four_frames, two_frames, four_frames],
            [two_frames],
            [four_frames, two_frames],
        )
    )
|
||||
|
||||
|
||||
def test_tensor_schema_inconsistent_shapes_in_list():
    """List elements whose non-variable dims disagree are rejected."""
    inconsistent = [
        torch.randn(64, 3, 32, 32),
        torch.randn(64, 3, 16, 16),  # spatial dims differ from the rest
        *(torch.randn(64, 3, 32, 32) for _ in range(14)),
    ]
    with pytest.raises(ValueError, match="contains inconsistent shapes"):
        Phi3VImagePixelInputs(
            pixel_values=inconsistent,
            image_sizes=torch.randint(0, 256, (16, 2)),
        )
|
||||
|
||||
|
||||
def test_tensor_schema_empty_list():
    """An empty sequence of tensors is rejected outright."""
    with pytest.raises(ValueError, match="is an empty sequence"):
        Phi3VImagePixelInputs(
            pixel_values=[],
            image_sizes=torch.randint(0, 256, (0, 2)),
        )
|
||||
|
||||
|
||||
def test_tensor_schema_validation_disabled_skips_shape_check():
    """With validate=False, shape checks are skipped entirely."""
    # This would normally fail (dim[2] should be 3, not 4), but must NOT
    # raise because validation is turned off.
    Phi3VImagePixelInputs(
        pixel_values=torch.randn(16, 64, 4, 32, 32),
        image_sizes=torch.randint(0, 256, (16, 2)),
        validate=False,
    )
|
||||
|
||||
|
||||
def test_tensor_schema_with_valid_resolve_binding_dims():
    """resolve_bindings pin symbolic dims; matching shapes validate."""
    Phi3VImagePixelInputs(
        pixel_values=torch.randn(16, 64, 3, 336, 336),  # h=336, w=336
        image_sizes=torch.randint(0, 256, (16, 2)),
        resolve_bindings={"h": 336, "w": 336},
    )
|
||||
|
||||
|
||||
def test_tensor_schema_with_invalid_resolve_binding_dims():
    """Shapes that contradict resolve_bindings are rejected."""
    undersized = torch.randn(16, 64, 3, 36, 36)  # h=36, w=36
    # Should raise because 'h' and 'w' don't match resolve bindings.
    with pytest.raises(ValueError, match="dim\\[3\\] expected 336, got 36"):
        Phi3VImagePixelInputs(
            pixel_values=undersized,
            image_sizes=torch.randint(0, 256, (16, 2)),
            resolve_bindings={"h": 336, "w": 336},
        )
|
||||
|
||||
|
||||
def test_tensor_schema_with_list_of_symbolic_dim():
    """A plain list can bind a symbolic dim via its length (len == b)."""
    GraniteSpeechAudioInputs(
        input_features=torch.randn(3, 10, 160),  # (b=3, fi=10, 160)
        input_features_mask=torch.randn(3, 8),  # (b=3, fo=8)
        audio_embed_sizes=[8, 8, 8],  # len = b = 3
    )
|
||||
|
||||
|
||||
def test_tensor_schema_with_list_of_symbolic_dim_mismatch_in_length():
    """A list whose length disagrees with the bound symbolic dim fails."""
    with pytest.raises(ValueError, match="expected 'b'=4, got 3"):
        GraniteSpeechAudioInputs(
            input_features=torch.randn(4, 10, 160),  # (b=4, fi=10, 160)
            input_features_mask=torch.randn(4, 8),  # (b=4, fo=8)
            audio_embed_sizes=[8, 8, 8],  # len = 3, but b = 4
        )
|
||||
|
||||
|
||||
def test_valid_tensor_schema_with_static_last_dim():
    """A static last dimension (3 for grid_thw) validates when correct."""
    Glm4vImageEmbeddingInputs(
        image_embeds=torch.randn(256, 1024),
        image_grid_thw=torch.randint(0, 4, (2, 3)),
    )
|
||||
|
||||
|
||||
def test_invalid_tensor_schema_with_static_last_dim():
    """A mismatched static last dimension is rejected."""
    wrong_grid = torch.randint(0, 4, (2, 4))  # last dim should be 3
    with pytest.raises(ValueError, match="dim\\[1\\] expected 3, got 4"):
        Glm4vImageEmbeddingInputs(
            image_embeds=torch.randn(256, 1024),
            image_grid_thw=wrong_grid,
        )
|
||||
128
tests/utils_/test_torch_utils.py
Normal file
128
tests/utils_/test_torch_utils.py
Normal file
@@ -0,0 +1,128 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from vllm.utils.torch_utils import (
|
||||
common_broadcastable_dtype,
|
||||
current_stream,
|
||||
is_lossless_cast,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    ("src_dtype", "tgt_dtype", "expected_result"),
    [
        # Different precision_levels
        (torch.bool, torch.int8, True),
        (torch.bool, torch.float16, True),
        (torch.bool, torch.complex32, True),
        (torch.int64, torch.bool, False),
        (torch.int64, torch.float16, True),
        (torch.int64, torch.complex32, True),
        (torch.float64, torch.bool, False),
        (torch.float64, torch.int8, False),
        (torch.float64, torch.complex32, True),
        (torch.complex128, torch.bool, False),
        (torch.complex128, torch.int8, False),
        (torch.complex128, torch.float16, False),
        # precision_level=0
        (torch.bool, torch.bool, True),
        # precision_level=1
        (torch.int8, torch.int16, True),
        (torch.int16, torch.int8, False),
        (torch.uint8, torch.int8, False),
        (torch.int8, torch.uint8, False),
        # precision_level=2
        (torch.float16, torch.float32, True),
        (torch.float32, torch.float16, False),
        (torch.bfloat16, torch.float32, True),
        (torch.float32, torch.bfloat16, False),
        # precision_level=3
        (torch.complex32, torch.complex64, True),
        (torch.complex64, torch.complex32, False),
    ],
)
def test_is_lossless_cast(src_dtype, tgt_dtype, expected_result):
    """is_lossless_cast must report True only when every value of
    `src_dtype` is exactly representable in `tgt_dtype` — i.e. casts up
    the precision-level ladder (bool < int < float < complex) or to a
    wider type within the same level, per the table above.
    """
    assert is_lossless_cast(src_dtype, tgt_dtype) == expected_result
||||
|
||||
|
||||
@pytest.mark.parametrize(
    ("dtypes", "expected_result"),
    [
        ([torch.bool], torch.bool),
        ([torch.bool, torch.int8], torch.int8),
        ([torch.bool, torch.int8, torch.float16], torch.float16),
        ([torch.bool, torch.int8, torch.float16, torch.complex32], torch.complex32),  # noqa: E501
    ],
)
def test_common_broadcastable_dtype(dtypes, expected_result):
    """The common dtype is the one every listed dtype can losslessly
    broadcast to — the highest-precision entry in each case above."""
    result = common_broadcastable_dtype(dtypes)
    assert result == expected_result
||||
|
||||
|
||||
def _test_stream_thread(main_expected_stream: torch.cuda.Stream):
    """Helper: assert current_stream() is isolated per-thread.

    Spawns a child thread that enters its own CUDA stream context and,
    while the child is still inside that context, checks that the main
    thread's ``current_stream()`` is neither the child's stream nor
    different from ``main_expected_stream``.
    """
    import threading

    # Stream the child thread activates; must never leak to the main thread.
    child_stream = torch.cuda.Stream()
    thread_stream_ready = threading.Event()
    thread_can_exit = threading.Event()

    def child_thread_func():
        with torch.cuda.stream(child_stream):
            # Signal that the child's stream context is active, then stay
            # inside it until the main thread finishes its checks.
            thread_stream_ready.set()
            thread_can_exit.wait(timeout=10)

    child_thread = threading.Thread(target=child_thread_func)
    child_thread.start()

    try:
        assert thread_stream_ready.wait(timeout=5), (
            "Child thread failed to enter stream context in time"
        )

        # Query the main thread's stream WHILE the child is in its context.
        main_current_stream = current_stream()

        assert main_current_stream != child_stream, (
            "Main thread's current_stream was contaminated by child thread"
        )
        assert main_current_stream == main_expected_stream, (
            f"Main thread's stream changed unexpectedly. "
            f"Expected {main_expected_stream}, got {main_current_stream}"
        )

        # Release the child so it can leave the stream context and exit.
        thread_can_exit.set()

    finally:
        child_thread.join(timeout=5)
        if child_thread.is_alive():
            pytest.fail("Child thread failed to exit properly")
||||
|
||||
|
||||
def test_current_stream_multithread():
    """current_stream() must be stable across calls and unaffected by other
    threads entering their own stream contexts.

    Platform split per the branches below: on ROCm a dedicated (non-default)
    stream is expected; on CUDA the first call should return the default
    stream.
    """
    from vllm.platforms import current_platform

    if not torch.cuda.is_available():
        pytest.skip("CUDA not available")

    if current_platform.is_rocm():
        main_dedicated_stream = current_stream()

        # cuda_stream == 0 would mean the default stream (0x0) was reused.
        assert main_dedicated_stream.cuda_stream != 0, (
            "ROCm should create a dedicated stream, not use default stream (0x0)"
        )

        # Repeated calls must hand back the SAME stream object/handle.
        main_stream_again = current_stream()
        assert main_stream_again == main_dedicated_stream, (
            "Multiple calls to current_stream should return the same dedicated stream"
        )

        _test_stream_thread(main_dedicated_stream)
    else:
        main_default_stream = torch.cuda.default_stream()
        main_initial_stream = current_stream()

        assert main_initial_stream == main_default_stream, (
            "First call to current_stream should return default stream on CUDA"
        )

        _test_stream_thread(main_default_stream)
|
||||
Reference in New Issue
Block a user