Sync from v0.13
This commit is contained in:
0
tests/engine/__init__.py
Normal file
0
tests/engine/__init__.py
Normal file
@@ -1,270 +0,0 @@
|
||||
import random
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
from transformers import PreTrainedTokenizer
|
||||
|
||||
from tests.core.utils import create_seq_group
|
||||
from vllm.core.scheduler import Scheduler
|
||||
from vllm.engine.output_processor.multi_step import MultiStepOutputProcessor
|
||||
from vllm.engine.output_processor.stop_checker import StopChecker
|
||||
from vllm.sampling_params import SamplingParams
|
||||
from vllm.sequence import (Logprob, SequenceGroupOutput, SequenceOutput,
|
||||
SequenceStatus)
|
||||
from vllm.transformers_utils.detokenizer import Detokenizer
|
||||
from vllm.utils import Counter
|
||||
|
||||
|
||||
@pytest.mark.parametrize("seq_output_len", [128])
|
||||
@pytest.mark.parametrize("num_new_tokens", [1, 12])
|
||||
@pytest.mark.skip_global_cleanup
|
||||
def test_appends_token_ids(num_new_tokens: int, seq_output_len: int):
|
||||
"""Verify multi-step decoding appends token ids correctly.
|
||||
|
||||
We append token ids and verify all the token ids were appended correctly.
|
||||
Note that ignore_eos=True.
|
||||
"""
|
||||
detokenizer = MagicMock(spec=Detokenizer)
|
||||
scheduler = MagicMock(spec=Scheduler)
|
||||
stop_checker = MagicMock(spec=StopChecker)
|
||||
seq_counter = Counter()
|
||||
|
||||
output_processor = MultiStepOutputProcessor(
|
||||
detokenizer=detokenizer,
|
||||
scheduler=scheduler,
|
||||
seq_counter=seq_counter,
|
||||
get_tokenizer_for_seq=lambda _: mock_tokenizer(),
|
||||
stop_checker=stop_checker,
|
||||
)
|
||||
|
||||
seq_group = create_seq_group(
|
||||
seq_prompt_len=1024,
|
||||
seq_output_lens=[seq_output_len],
|
||||
sampling_params=SamplingParams(max_tokens=seq_output_len +
|
||||
num_new_tokens,
|
||||
ignore_eos=True),
|
||||
)
|
||||
|
||||
seq = seq_group.get_seqs()[0]
|
||||
seq.status = SequenceStatus.RUNNING
|
||||
|
||||
new_token_ids = list(range(num_new_tokens))
|
||||
|
||||
outputs = [
|
||||
SequenceGroupOutput(
|
||||
samples=[
|
||||
SequenceOutput(
|
||||
parent_seq_id=seq.seq_id,
|
||||
output_token=output_token,
|
||||
logprobs={output_token: Logprob(0.0)},
|
||||
)
|
||||
],
|
||||
prompt_logprobs=None,
|
||||
) for output_token in new_token_ids
|
||||
]
|
||||
|
||||
assert seq.get_token_ids()[-len(new_token_ids):] != new_token_ids
|
||||
output_processor.process_outputs(seq_group, outputs)
|
||||
assert seq.get_token_ids()[-len(new_token_ids):] == new_token_ids
|
||||
|
||||
|
||||
@pytest.mark.parametrize("seq_prompt_len", [1024])
|
||||
@pytest.mark.parametrize("seq_output_len", [128])
|
||||
@pytest.mark.parametrize("num_new_tokens", [5, 6, 7, 8])
|
||||
@pytest.mark.parametrize("max_tokens", [128 + 3])
|
||||
@pytest.mark.skip_global_cleanup
|
||||
def test_respects_max_tokens(num_new_tokens: int, seq_prompt_len: int,
|
||||
seq_output_len: int, max_tokens: int):
|
||||
"""Verify tokens after max_tokens are dropped and not appended to the
|
||||
sequence.
|
||||
"""
|
||||
detokenizer = MagicMock(spec=Detokenizer)
|
||||
scheduler = MagicMock(spec=Scheduler)
|
||||
stop_checker = MagicMock(spec=StopChecker)
|
||||
seq_counter = Counter()
|
||||
|
||||
output_processor = MultiStepOutputProcessor(
|
||||
detokenizer=detokenizer,
|
||||
scheduler=scheduler,
|
||||
seq_counter=seq_counter,
|
||||
get_tokenizer_for_seq=lambda _: mock_tokenizer(),
|
||||
stop_checker=stop_checker,
|
||||
)
|
||||
|
||||
seq_group = create_seq_group(
|
||||
seq_prompt_len=seq_prompt_len,
|
||||
seq_output_lens=[seq_output_len],
|
||||
sampling_params=SamplingParams(max_tokens=max_tokens, ),
|
||||
)
|
||||
|
||||
seq = seq_group.get_seqs()[0]
|
||||
seq.status = SequenceStatus.RUNNING
|
||||
|
||||
new_token_ids = list(range(num_new_tokens))
|
||||
|
||||
outputs = [
|
||||
SequenceGroupOutput(
|
||||
samples=[
|
||||
SequenceOutput(
|
||||
parent_seq_id=seq.seq_id,
|
||||
output_token=output_token,
|
||||
logprobs={output_token: Logprob(0.0)},
|
||||
)
|
||||
],
|
||||
prompt_logprobs=None,
|
||||
) for output_token in new_token_ids
|
||||
]
|
||||
|
||||
assert seq.get_len() == seq_prompt_len + seq_output_len
|
||||
output_processor.process_outputs(seq_group, outputs)
|
||||
|
||||
# Expect the processed sequence to not go over max tokens in len.
|
||||
assert seq.get_len() == seq_prompt_len + max_tokens
|
||||
|
||||
# Expect the correct tokens were appended.
|
||||
expected_appended_tokens = new_token_ids[:max_tokens - seq_output_len]
|
||||
assert seq.get_token_ids(
|
||||
)[-len(expected_appended_tokens):] == expected_appended_tokens
|
||||
|
||||
|
||||
@pytest.mark.parametrize("seq_prompt_len", [1024])
|
||||
@pytest.mark.parametrize("seq_output_len", [128])
|
||||
@pytest.mark.parametrize("num_new_tokens", [12])
|
||||
@pytest.mark.parametrize("seed", list(range(6)))
|
||||
@pytest.mark.skip_global_cleanup
|
||||
def test_respects_eos_token_id(num_new_tokens: int, seq_prompt_len: int,
|
||||
seq_output_len: int, seed: int):
|
||||
"""Verify the eos token id is included in the sequence, but subsequent
|
||||
tokens are dropped (not appended to sequence).
|
||||
"""
|
||||
random.seed(seed)
|
||||
detokenizer = MagicMock(spec=Detokenizer)
|
||||
scheduler = MagicMock(spec=Scheduler)
|
||||
stop_checker = MagicMock(spec=StopChecker)
|
||||
seq_counter = Counter()
|
||||
|
||||
eos_token_id = 100
|
||||
|
||||
output_processor = MultiStepOutputProcessor(
|
||||
detokenizer=detokenizer,
|
||||
scheduler=scheduler,
|
||||
seq_counter=seq_counter,
|
||||
get_tokenizer_for_seq=lambda _: mock_tokenizer(eos_token_id),
|
||||
stop_checker=stop_checker,
|
||||
)
|
||||
|
||||
seq_group = create_seq_group(
|
||||
seq_prompt_len=seq_prompt_len,
|
||||
seq_output_lens=[seq_output_len],
|
||||
sampling_params=SamplingParams(
|
||||
# Ensure enough space.
|
||||
max_tokens=seq_output_len + num_new_tokens, ),
|
||||
)
|
||||
|
||||
seq = seq_group.get_seqs()[0]
|
||||
seq.status = SequenceStatus.RUNNING
|
||||
|
||||
new_token_ids = list(range(num_new_tokens))
|
||||
assert eos_token_id not in new_token_ids
|
||||
eos_index = random.randint(0, len(new_token_ids) - 1)
|
||||
new_token_ids[eos_index] = eos_token_id
|
||||
|
||||
outputs = [
|
||||
SequenceGroupOutput(
|
||||
samples=[
|
||||
SequenceOutput(
|
||||
parent_seq_id=seq.seq_id,
|
||||
output_token=output_token,
|
||||
logprobs={output_token: Logprob(0.0)},
|
||||
)
|
||||
],
|
||||
prompt_logprobs=None,
|
||||
) for output_token in new_token_ids
|
||||
]
|
||||
|
||||
assert seq.get_len() == seq_prompt_len + seq_output_len
|
||||
output_processor.process_outputs(seq_group, outputs)
|
||||
|
||||
# Expect the processed sequence to not go beyond provided eos.
|
||||
assert seq.get_len() == seq_prompt_len + seq_output_len + (eos_index + 1)
|
||||
|
||||
# Expect the correct tokens were appended.
|
||||
expected_appended_tokens = new_token_ids[:eos_index + 1]
|
||||
assert seq.get_token_ids(
|
||||
)[-len(expected_appended_tokens):] == expected_appended_tokens
|
||||
|
||||
|
||||
@pytest.mark.parametrize("seq_prompt_len", [1024])
|
||||
@pytest.mark.parametrize("seq_output_len", [128])
|
||||
@pytest.mark.parametrize("num_new_tokens", [12])
|
||||
@pytest.mark.parametrize("seed", list(range(6)))
|
||||
@pytest.mark.skip_global_cleanup
|
||||
def test_ignores_eos_token_id(num_new_tokens: int, seq_prompt_len: int,
|
||||
seq_output_len: int, seed: int):
|
||||
"""When sampling parameters dictate that we should ignore the eos token id,
|
||||
ensure all token ids are appended even if the eos token id is emitted.
|
||||
"""
|
||||
random.seed(seed)
|
||||
detokenizer = MagicMock(spec=Detokenizer)
|
||||
scheduler = MagicMock(spec=Scheduler)
|
||||
stop_checker = MagicMock(spec=StopChecker)
|
||||
seq_counter = Counter()
|
||||
|
||||
eos_token_id = 100
|
||||
|
||||
output_processor = MultiStepOutputProcessor(
|
||||
detokenizer=detokenizer,
|
||||
scheduler=scheduler,
|
||||
seq_counter=seq_counter,
|
||||
get_tokenizer_for_seq=lambda _: mock_tokenizer(eos_token_id),
|
||||
stop_checker=stop_checker,
|
||||
)
|
||||
|
||||
seq_group = create_seq_group(
|
||||
seq_prompt_len=seq_prompt_len,
|
||||
seq_output_lens=[seq_output_len],
|
||||
sampling_params=SamplingParams(
|
||||
# Ensure enough space.
|
||||
max_tokens=seq_output_len + num_new_tokens,
|
||||
ignore_eos=True,
|
||||
),
|
||||
)
|
||||
|
||||
seq = seq_group.get_seqs()[0]
|
||||
seq.status = SequenceStatus.RUNNING
|
||||
|
||||
new_token_ids = list(range(num_new_tokens))
|
||||
assert eos_token_id not in new_token_ids
|
||||
eos_index = random.randint(0, len(new_token_ids) - 1)
|
||||
new_token_ids[eos_index] = eos_token_id
|
||||
|
||||
outputs = [
|
||||
SequenceGroupOutput(
|
||||
samples=[
|
||||
SequenceOutput(
|
||||
parent_seq_id=seq.seq_id,
|
||||
output_token=output_token,
|
||||
logprobs={output_token: Logprob(0.0)},
|
||||
)
|
||||
],
|
||||
prompt_logprobs=None,
|
||||
) for output_token in new_token_ids
|
||||
]
|
||||
|
||||
assert seq.get_len() == seq_prompt_len + seq_output_len
|
||||
output_processor.process_outputs(seq_group, outputs)
|
||||
|
||||
# Expect the processed sequence to go beyond eos.
|
||||
assert seq.get_len() == seq_prompt_len + seq_output_len + num_new_tokens
|
||||
|
||||
# Expect the correct tokens were appended.
|
||||
expected_appended_tokens = new_token_ids[:seq_output_len + num_new_tokens -
|
||||
seq_output_len]
|
||||
assert seq.get_token_ids(
|
||||
)[-len(expected_appended_tokens):] == expected_appended_tokens
|
||||
|
||||
|
||||
def mock_tokenizer(eos_token_id=1000):
|
||||
tokenizer = MagicMock(spec=PreTrainedTokenizer)
|
||||
tokenizer.eos_token_id = eos_token_id
|
||||
return tokenizer
|
||||
384
tests/engine/test_arg_utils.py
Normal file
384
tests/engine/test_arg_utils.py
Normal file
@@ -0,0 +1,384 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import json
|
||||
from argparse import ArgumentError
|
||||
from contextlib import nullcontext
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Annotated, Literal
|
||||
|
||||
import pytest
|
||||
|
||||
from vllm.config import CompilationConfig, config
|
||||
from vllm.engine.arg_utils import (
|
||||
EngineArgs,
|
||||
contains_type,
|
||||
get_kwargs,
|
||||
get_type,
|
||||
get_type_hints,
|
||||
is_not_builtin,
|
||||
is_type,
|
||||
literal_to_kwargs,
|
||||
optional_type,
|
||||
parse_type,
|
||||
)
|
||||
from vllm.utils.argparse_utils import FlexibleArgumentParser
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("type", "value", "expected"),
|
||||
[
|
||||
(int, "42", 42),
|
||||
(float, "3.14", 3.14),
|
||||
(str, "Hello World!", "Hello World!"),
|
||||
(json.loads, '{"foo":1,"bar":2}', {"foo": 1, "bar": 2}),
|
||||
],
|
||||
)
|
||||
def test_parse_type(type, value, expected):
|
||||
parse_type_func = parse_type(type)
|
||||
assert parse_type_func(value) == expected
|
||||
|
||||
|
||||
def test_optional_type():
|
||||
optional_type_func = optional_type(int)
|
||||
assert optional_type_func("None") is None
|
||||
assert optional_type_func("42") == 42
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("type_hint", "type", "expected"),
|
||||
[
|
||||
(int, int, True),
|
||||
(int, float, False),
|
||||
(list[int], list, True),
|
||||
(list[int], tuple, False),
|
||||
(Literal[0, 1], Literal, True),
|
||||
],
|
||||
)
|
||||
def test_is_type(type_hint, type, expected):
|
||||
assert is_type(type_hint, type) == expected
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("type_hints", "type", "expected"),
|
||||
[
|
||||
({float, int}, int, True),
|
||||
({int, tuple}, int, True),
|
||||
({int, tuple[int]}, int, True),
|
||||
({int, tuple[int, ...]}, int, True),
|
||||
({int, tuple[int]}, float, False),
|
||||
({int, tuple[int, ...]}, float, False),
|
||||
({str, Literal["x", "y"]}, Literal, True),
|
||||
],
|
||||
)
|
||||
def test_contains_type(type_hints, type, expected):
|
||||
assert contains_type(type_hints, type) == expected
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("type_hints", "type", "expected"),
|
||||
[
|
||||
({int, float}, int, int),
|
||||
({int, float}, str, None),
|
||||
({str, Literal["x", "y"]}, Literal, Literal["x", "y"]),
|
||||
],
|
||||
)
|
||||
def test_get_type(type_hints, type, expected):
|
||||
assert get_type(type_hints, type) == expected
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("type_hints", "expected"),
|
||||
[
|
||||
({Literal[1, 2]}, {"type": int, "choices": [1, 2]}),
|
||||
({str, Literal["x", "y"]}, {"type": str, "metavar": ["x", "y"]}),
|
||||
({Literal[1, "a"]}, Exception),
|
||||
],
|
||||
)
|
||||
def test_literal_to_kwargs(type_hints, expected):
|
||||
context = nullcontext()
|
||||
if expected is Exception:
|
||||
context = pytest.raises(expected)
|
||||
with context:
|
||||
assert literal_to_kwargs(type_hints) == expected
|
||||
|
||||
|
||||
@config
|
||||
@dataclass
|
||||
class NestedConfig:
|
||||
field: int = 1
|
||||
"""field"""
|
||||
|
||||
|
||||
@config
|
||||
@dataclass
|
||||
class DummyConfig:
|
||||
regular_bool: bool = True
|
||||
"""Regular bool with default True"""
|
||||
optional_bool: bool | None = None
|
||||
"""Optional bool with default None"""
|
||||
optional_literal: Literal["x", "y"] | None = None
|
||||
"""Optional literal with default None"""
|
||||
tuple_n: tuple[int, ...] = field(default_factory=lambda: (1, 2, 3))
|
||||
"""Tuple with variable length"""
|
||||
tuple_2: tuple[int, int] = field(default_factory=lambda: (1, 2))
|
||||
"""Tuple with fixed length"""
|
||||
list_n: list[int] = field(default_factory=lambda: [1, 2, 3])
|
||||
"""List with variable length"""
|
||||
list_literal: list[Literal[1, 2]] = field(default_factory=list)
|
||||
"""List with literal choices"""
|
||||
list_union: list[str | type[object]] = field(default_factory=list)
|
||||
"""List with union type"""
|
||||
set_n: set[int] = field(default_factory=lambda: {1, 2, 3})
|
||||
"""Set with variable length"""
|
||||
literal_literal: Literal[Literal[1], Literal[2]] = 1
|
||||
"""Literal of literals with default 1"""
|
||||
json_tip: dict = field(default_factory=dict)
|
||||
"""Dict which will be JSON in CLI"""
|
||||
nested_config: NestedConfig = field(default_factory=NestedConfig)
|
||||
"""Nested config"""
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("type_hint", "expected"),
|
||||
[
|
||||
(int, False),
|
||||
(DummyConfig, True),
|
||||
],
|
||||
)
|
||||
def test_is_not_builtin(type_hint, expected):
|
||||
assert is_not_builtin(type_hint) == expected
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("type_hint", "expected"),
|
||||
[
|
||||
(Annotated[int, "annotation"], {int}),
|
||||
(int | None, {int, type(None)}),
|
||||
(Annotated[int | None, "annotation"], {int, type(None)}),
|
||||
(Annotated[int, "annotation"] | None, {int, type(None)}),
|
||||
],
|
||||
ids=["Annotated", "or_None", "Annotated_or_None", "or_None_Annotated"],
|
||||
)
|
||||
def test_get_type_hints(type_hint, expected):
|
||||
assert get_type_hints(type_hint) == expected
|
||||
|
||||
|
||||
def test_get_kwargs():
|
||||
kwargs = get_kwargs(DummyConfig)
|
||||
print(kwargs)
|
||||
|
||||
# bools should not have their type set
|
||||
assert kwargs["regular_bool"].get("type") is None
|
||||
assert kwargs["optional_bool"].get("type") is None
|
||||
# optional literals should have None as a choice
|
||||
assert kwargs["optional_literal"]["choices"] == ["x", "y", "None"]
|
||||
# tuples should have the correct nargs
|
||||
assert kwargs["tuple_n"]["nargs"] == "+"
|
||||
assert kwargs["tuple_2"]["nargs"] == 2
|
||||
# lists should work
|
||||
assert kwargs["list_n"]["type"] is int
|
||||
assert kwargs["list_n"]["nargs"] == "+"
|
||||
# lists with literals should have the correct choices
|
||||
assert kwargs["list_literal"]["type"] is int
|
||||
assert kwargs["list_literal"]["nargs"] == "+"
|
||||
assert kwargs["list_literal"]["choices"] == [1, 2]
|
||||
# lists with unions should become str type.
|
||||
# If not, we cannot know which type to use for parsing
|
||||
assert kwargs["list_union"]["type"] is str
|
||||
# sets should work like lists
|
||||
assert kwargs["set_n"]["type"] is int
|
||||
assert kwargs["set_n"]["nargs"] == "+"
|
||||
# literals of literals should have merged choices
|
||||
assert kwargs["literal_literal"]["choices"] == [1, 2]
|
||||
# dict should have json tip in help
|
||||
json_tip = "Should either be a valid JSON string or JSON keys"
|
||||
assert json_tip in kwargs["json_tip"]["help"]
|
||||
# nested config should construct the nested config
|
||||
assert kwargs["nested_config"]["type"]('{"field": 2}') == NestedConfig(2)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("arg", "expected"),
|
||||
[
|
||||
(None, dict()),
|
||||
('{"video": {"num_frames": 123} }', {"video": {"num_frames": 123}}),
|
||||
(
|
||||
'{"video": {"num_frames": 123, "fps": 1.0, "foo": "bar"}, "image": {"foo": "bar"} }', # noqa
|
||||
{
|
||||
"video": {"num_frames": 123, "fps": 1.0, "foo": "bar"},
|
||||
"image": {"foo": "bar"},
|
||||
},
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_media_io_kwargs_parser(arg, expected):
|
||||
parser = EngineArgs.add_cli_args(FlexibleArgumentParser())
|
||||
if arg is None:
|
||||
args = parser.parse_args([])
|
||||
else:
|
||||
args = parser.parse_args(["--media-io-kwargs", arg])
|
||||
|
||||
assert args.media_io_kwargs == expected
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("args", "expected"),
|
||||
[
|
||||
(["-O", "1"], "1"),
|
||||
(["-O", "2"], "2"),
|
||||
(["-O", "3"], "3"),
|
||||
(["-O0"], "0"),
|
||||
(["-O1"], "1"),
|
||||
(["-O2"], "2"),
|
||||
(["-O3"], "3"),
|
||||
],
|
||||
)
|
||||
def test_optimization_level(args, expected):
|
||||
"""
|
||||
Test space-separated optimization levels (-O 1, -O 2, -O 3) map to
|
||||
optimization_level.
|
||||
"""
|
||||
parser = EngineArgs.add_cli_args(FlexibleArgumentParser())
|
||||
parsed_args = parser.parse_args(args)
|
||||
assert parsed_args.optimization_level == expected
|
||||
assert parsed_args.compilation_config.mode is None
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("args", "expected"),
|
||||
[
|
||||
(["-cc.mode=0"], 0),
|
||||
(["-cc.mode=1"], 1),
|
||||
(["-cc.mode=2"], 2),
|
||||
(["-cc.mode=3"], 3),
|
||||
],
|
||||
)
|
||||
def test_mode_parser(args, expected):
|
||||
"""
|
||||
Test compilation config modes (-cc.mode=int) map to compilation_config.
|
||||
"""
|
||||
parser = EngineArgs.add_cli_args(FlexibleArgumentParser())
|
||||
parsed_args = parser.parse_args(args)
|
||||
assert parsed_args.compilation_config.mode == expected
|
||||
|
||||
|
||||
def test_compilation_config():
|
||||
parser = EngineArgs.add_cli_args(FlexibleArgumentParser())
|
||||
|
||||
# default value
|
||||
args = parser.parse_args([])
|
||||
assert args.compilation_config == CompilationConfig()
|
||||
|
||||
# set to string form of a dict
|
||||
args = parser.parse_args(
|
||||
[
|
||||
"-cc",
|
||||
'{"mode": 3, "cudagraph_capture_sizes": [1, 2, 4, 8], "backend": "eager"}',
|
||||
]
|
||||
)
|
||||
assert (
|
||||
args.compilation_config.mode == 3
|
||||
and args.compilation_config.cudagraph_capture_sizes == [1, 2, 4, 8]
|
||||
and args.compilation_config.backend == "eager"
|
||||
)
|
||||
|
||||
# set to string form of a dict
|
||||
args = parser.parse_args(
|
||||
[
|
||||
"--compilation-config="
|
||||
'{"mode": 3, "cudagraph_capture_sizes": [1, 2, 4, 8], '
|
||||
'"backend": "inductor"}',
|
||||
]
|
||||
)
|
||||
assert (
|
||||
args.compilation_config.mode == 3
|
||||
and args.compilation_config.cudagraph_capture_sizes == [1, 2, 4, 8]
|
||||
and args.compilation_config.backend == "inductor"
|
||||
)
|
||||
|
||||
|
||||
def test_prefix_cache_default():
|
||||
parser = EngineArgs.add_cli_args(FlexibleArgumentParser())
|
||||
args = parser.parse_args([])
|
||||
|
||||
# should be None by default (depends on model).
|
||||
engine_args = EngineArgs.from_cli_args(args=args)
|
||||
assert engine_args.enable_prefix_caching is None
|
||||
|
||||
# with flag to turn it on.
|
||||
args = parser.parse_args(["--enable-prefix-caching"])
|
||||
engine_args = EngineArgs.from_cli_args(args=args)
|
||||
assert engine_args.enable_prefix_caching
|
||||
|
||||
# with disable flag to turn it off.
|
||||
args = parser.parse_args(["--no-enable-prefix-caching"])
|
||||
engine_args = EngineArgs.from_cli_args(args=args)
|
||||
assert not engine_args.enable_prefix_caching
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("arg", "expected", "option"),
|
||||
[
|
||||
(None, None, "mm-processor-kwargs"),
|
||||
("{}", {}, "mm-processor-kwargs"),
|
||||
('{"num_crops": 4}', {"num_crops": 4}, "mm-processor-kwargs"),
|
||||
('{"foo": {"bar": "baz"}}', {"foo": {"bar": "baz"}}, "mm-processor-kwargs"),
|
||||
],
|
||||
)
|
||||
def test_composite_arg_parser(arg, expected, option):
|
||||
parser = EngineArgs.add_cli_args(FlexibleArgumentParser())
|
||||
if arg is None:
|
||||
args = parser.parse_args([])
|
||||
else:
|
||||
args = parser.parse_args([f"--{option}", arg])
|
||||
assert getattr(args, option.replace("-", "_")) == expected
|
||||
|
||||
|
||||
def test_human_readable_model_len():
|
||||
# `exit_on_error` disabled to test invalid values below
|
||||
parser = EngineArgs.add_cli_args(FlexibleArgumentParser(exit_on_error=False))
|
||||
|
||||
args = parser.parse_args([])
|
||||
assert args.max_model_len is None
|
||||
|
||||
args = parser.parse_args(["--max-model-len", "1024"])
|
||||
assert args.max_model_len == 1024
|
||||
|
||||
# Lower
|
||||
args = parser.parse_args(["--max-model-len", "1m"])
|
||||
assert args.max_model_len == 1_000_000
|
||||
args = parser.parse_args(["--max-model-len", "10k"])
|
||||
assert args.max_model_len == 10_000
|
||||
args = parser.parse_args(["--max-model-len", "2g"])
|
||||
assert args.max_model_len == 2_000_000_000
|
||||
args = parser.parse_args(["--max-model-len", "2t"])
|
||||
assert args.max_model_len == 2_000_000_000_000
|
||||
|
||||
# Capital
|
||||
args = parser.parse_args(["--max-model-len", "3K"])
|
||||
assert args.max_model_len == 2**10 * 3
|
||||
args = parser.parse_args(["--max-model-len", "10M"])
|
||||
assert args.max_model_len == 2**20 * 10
|
||||
args = parser.parse_args(["--max-model-len", "4G"])
|
||||
assert args.max_model_len == 2**30 * 4
|
||||
args = parser.parse_args(["--max-model-len", "4T"])
|
||||
assert args.max_model_len == 2**40 * 4
|
||||
|
||||
# Decimal values
|
||||
args = parser.parse_args(["--max-model-len", "10.2k"])
|
||||
assert args.max_model_len == 10200
|
||||
# ..truncated to the nearest int
|
||||
args = parser.parse_args(["--max-model-len", "10.2123451234567k"])
|
||||
assert args.max_model_len == 10212
|
||||
args = parser.parse_args(["--max-model-len", "10.2123451234567m"])
|
||||
assert args.max_model_len == 10212345
|
||||
args = parser.parse_args(["--max-model-len", "10.2123451234567g"])
|
||||
assert args.max_model_len == 10212345123
|
||||
args = parser.parse_args(["--max-model-len", "10.2123451234567t"])
|
||||
assert args.max_model_len == 10212345123456
|
||||
|
||||
# Invalid (do not allow decimals with binary multipliers)
|
||||
for invalid in ["1a", "pwd", "10.24", "1.23M", "1.22T"]:
|
||||
with pytest.raises(ArgumentError):
|
||||
parser.parse_args(["--max-model-len", invalid])
|
||||
@@ -1,34 +0,0 @@
|
||||
import pytest
|
||||
|
||||
from vllm.engine.arg_utils import EngineArgs
|
||||
from vllm.engine.llm_engine import LLMEngine
|
||||
from vllm.sampling_params import SamplingParams
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model", ["facebook/opt-125m"])
|
||||
@pytest.mark.parametrize("block_size", [16])
|
||||
def test_computed_prefix_blocks(model: str, block_size: int):
|
||||
# This test checks if we are able to run the engine to completion
|
||||
# without triggering asserts.
|
||||
# We are in a scenario where all blocks from the second request's prompt
|
||||
# are full and already computed when the second request arrives.
|
||||
prompt = (
|
||||
"You are a helpful assistant. How do I build a car from cardboard and "
|
||||
"paper clips? Is there an easy to follow video tutorial available "
|
||||
"online for free?")
|
||||
prompt2 = (
|
||||
" Please recommend to me some resources where I can learn not only to "
|
||||
"handle technical difficulties of building a car, but also "
|
||||
"decoration.")
|
||||
|
||||
engine_args = EngineArgs(model=model,
|
||||
block_size=block_size,
|
||||
enable_prefix_caching=True)
|
||||
|
||||
engine = LLMEngine.from_engine_args(engine_args)
|
||||
sampling_params = SamplingParams()
|
||||
|
||||
engine.add_request("0", prompt + prompt2, sampling_params)
|
||||
engine.step()
|
||||
engine.add_request("1", prompt, sampling_params)
|
||||
engine.step()
|
||||
@@ -1,32 +0,0 @@
|
||||
import pytest
|
||||
|
||||
from vllm.entrypoints.llm import LLM
|
||||
from vllm.sampling_params import SamplingParams
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model", ["facebook/opt-125m"])
|
||||
def test_computed_prefix_blocks(model: str):
|
||||
# This test checks if the engine generates completions both with and
|
||||
# without optional detokenization, that detokenization includes text
|
||||
# and no-detokenization doesn't, and that both completions have the same
|
||||
# token_ids.
|
||||
prompt = (
|
||||
"You are a helpful assistant. How do I build a car from cardboard and "
|
||||
"paper clips? Is there an easy to follow video tutorial available "
|
||||
"online for free?")
|
||||
|
||||
llm = LLM(model=model)
|
||||
sampling_params = SamplingParams(max_tokens=10,
|
||||
temperature=0.0,
|
||||
detokenize=False)
|
||||
|
||||
outputs_no_detokenization = llm.generate(prompt,
|
||||
sampling_params)[0].outputs[0]
|
||||
sampling_params.detokenize = True
|
||||
outputs_with_detokenization = llm.generate(prompt,
|
||||
sampling_params)[0].outputs[0]
|
||||
|
||||
assert outputs_no_detokenization.text == ''
|
||||
assert outputs_with_detokenization.text != ''
|
||||
assert outputs_no_detokenization.token_ids == \
|
||||
outputs_with_detokenization.token_ids
|
||||
@@ -1,176 +0,0 @@
|
||||
import asyncio
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from functools import partial
|
||||
from time import sleep
|
||||
from typing import Any, List, Tuple
|
||||
|
||||
import pytest
|
||||
|
||||
from vllm.executor.multiproc_worker_utils import (ProcessWorkerWrapper,
|
||||
ResultHandler, WorkerMonitor)
|
||||
|
||||
|
||||
class DummyWorker:
|
||||
"""Dummy version of vllm.worker.worker.Worker"""
|
||||
|
||||
def __init__(self, rank: int):
|
||||
self.rank = rank
|
||||
|
||||
def worker_method(self, worker_input: Any) -> Tuple[int, Any]:
|
||||
sleep(0.05)
|
||||
|
||||
if isinstance(worker_input, Exception):
|
||||
# simulate error case
|
||||
raise worker_input
|
||||
|
||||
return self.rank, input
|
||||
|
||||
|
||||
def _start_workers() -> Tuple[List[ProcessWorkerWrapper], WorkerMonitor]:
|
||||
result_handler = ResultHandler()
|
||||
workers = [
|
||||
ProcessWorkerWrapper(result_handler, partial(DummyWorker, rank=rank))
|
||||
for rank in range(8)
|
||||
]
|
||||
|
||||
worker_monitor = WorkerMonitor(workers, result_handler)
|
||||
assert not worker_monitor.is_alive()
|
||||
|
||||
result_handler.start()
|
||||
worker_monitor.start()
|
||||
assert worker_monitor.is_alive()
|
||||
|
||||
return workers, worker_monitor
|
||||
|
||||
|
||||
def test_local_workers() -> None:
|
||||
"""Test workers with sync task submission"""
|
||||
|
||||
workers, worker_monitor = _start_workers()
|
||||
|
||||
def execute_workers(worker_input: str) -> None:
|
||||
worker_outputs = [
|
||||
worker.execute_method("worker_method", worker_input)
|
||||
for worker in workers
|
||||
]
|
||||
|
||||
for rank, output in enumerate(worker_outputs):
|
||||
assert output.get() == (rank, input)
|
||||
|
||||
executor = ThreadPoolExecutor(max_workers=4)
|
||||
|
||||
# Test concurrent submission from different threads
|
||||
futures = [
|
||||
executor.submit(partial(execute_workers, f"thread {thread_num}"))
|
||||
for thread_num in range(4)
|
||||
]
|
||||
|
||||
for future in futures:
|
||||
future.result()
|
||||
|
||||
# Test error case
|
||||
exception = ValueError("fake error")
|
||||
result = workers[0].execute_method("worker_method", exception)
|
||||
try:
|
||||
result.get()
|
||||
pytest.fail("task should have failed")
|
||||
except Exception as e:
|
||||
assert isinstance(e, ValueError)
|
||||
assert str(e) == "fake error"
|
||||
|
||||
# Test cleanup when a worker fails
|
||||
assert worker_monitor.is_alive()
|
||||
workers[3].process.kill()
|
||||
|
||||
# Other workers should get shut down here
|
||||
worker_monitor.join(2)
|
||||
|
||||
# Ensure everything is stopped
|
||||
assert not worker_monitor.is_alive()
|
||||
assert all(not worker.process.is_alive() for worker in workers)
|
||||
|
||||
# Further attempts to submit tasks should fail
|
||||
try:
|
||||
_result = workers[0].execute_method("worker_method", "test")
|
||||
pytest.fail("task should fail once workers have been shut down")
|
||||
except Exception as e:
|
||||
assert isinstance(e, ChildProcessError)
|
||||
|
||||
|
||||
def test_local_workers_clean_shutdown() -> None:
|
||||
"""Test clean shutdown"""
|
||||
|
||||
workers, worker_monitor = _start_workers()
|
||||
|
||||
assert worker_monitor.is_alive()
|
||||
assert all(worker.process.is_alive() for worker in workers)
|
||||
|
||||
# Clean shutdown
|
||||
worker_monitor.close()
|
||||
|
||||
worker_monitor.join(5)
|
||||
|
||||
# Ensure everything is stopped
|
||||
assert not worker_monitor.is_alive()
|
||||
assert all(not worker.process.is_alive() for worker in workers)
|
||||
|
||||
# Further attempts to submit tasks should fail
|
||||
try:
|
||||
_result = workers[0].execute_method("worker_method", "test")
|
||||
pytest.fail("task should fail once workers have been shut down")
|
||||
except Exception as e:
|
||||
assert isinstance(e, ChildProcessError)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_local_workers_async() -> None:
|
||||
"""Test local workers with async task submission"""
|
||||
|
||||
workers, worker_monitor = _start_workers()
|
||||
|
||||
async def execute_workers(worker_input: str) -> None:
|
||||
worker_coros = [
|
||||
worker.execute_method_async("worker_method", worker_input)
|
||||
for worker in workers
|
||||
]
|
||||
|
||||
results = await asyncio.gather(*worker_coros)
|
||||
for rank, result in enumerate(results):
|
||||
assert result == (rank, input)
|
||||
|
||||
tasks = [
|
||||
asyncio.create_task(execute_workers(f"task {task_num}"))
|
||||
for task_num in range(4)
|
||||
]
|
||||
|
||||
for task in tasks:
|
||||
await task
|
||||
|
||||
# Test error case
|
||||
exception = ValueError("fake error")
|
||||
try:
|
||||
_result = await workers[0].execute_method_async(
|
||||
"worker_method", exception)
|
||||
pytest.fail("task should have failed")
|
||||
except Exception as e:
|
||||
assert isinstance(e, ValueError)
|
||||
assert str(e) == "fake error"
|
||||
|
||||
# Test cleanup when a worker fails
|
||||
assert worker_monitor.is_alive()
|
||||
workers[3].process.kill()
|
||||
|
||||
# Other workers should get shut down here
|
||||
worker_monitor.join(2)
|
||||
|
||||
# Ensure everything is stopped
|
||||
assert not worker_monitor.is_alive()
|
||||
assert all(not worker.process.is_alive() for worker in workers)
|
||||
|
||||
# Further attempts to submit tasks should fail
|
||||
try:
|
||||
_result = await workers[0].execute_method_async(
|
||||
"worker_method", "test")
|
||||
pytest.fail("task should fail once workers have been shut down")
|
||||
except Exception as e:
|
||||
assert isinstance(e, ChildProcessError)
|
||||
33
tests/engine/test_short_mm_context.py
Normal file
33
tests/engine/test_short_mm_context.py
Normal file
@@ -0,0 +1,33 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import pytest
|
||||
|
||||
from ..conftest import IMAGE_ASSETS
|
||||
|
||||
HF_IMAGE_PROMPTS = IMAGE_ASSETS.prompts(
|
||||
{
|
||||
"stop_sign": "USER: <image>\nWhat's the content of the image?\nASSISTANT:",
|
||||
"cherry_blossom": "USER: <image>\nWhat is the season?\nASSISTANT:",
|
||||
}
|
||||
)
|
||||
|
||||
models = ["llava-hf/llava-1.5-7b-hf"]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model", models)
|
||||
def test_context_length_too_short(vllm_runner, image_assets, model):
|
||||
images = [asset.pil_image for asset in image_assets]
|
||||
|
||||
with pytest.raises(ValueError, match="longer than the maximum model length"):
|
||||
vllm_model = vllm_runner(
|
||||
model,
|
||||
max_model_len=128, # LLaVA has a feature size of 576
|
||||
enforce_eager=True,
|
||||
load_format="dummy",
|
||||
)
|
||||
|
||||
with vllm_model:
|
||||
vllm_model.generate_greedy(
|
||||
[HF_IMAGE_PROMPTS[0]], max_tokens=1, images=[images[0]]
|
||||
)
|
||||
@@ -1,23 +0,0 @@
|
||||
import pytest
|
||||
|
||||
from vllm.entrypoints.llm import LLM
|
||||
from vllm.sampling_params import SamplingParams
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model", ["facebook/opt-125m"])
|
||||
def test_skip_tokenizer_initialization(model: str):
|
||||
# This test checks if the flag skip_tokenizer_init skips the initialization
|
||||
# of tokenizer and detokenizer. The generated output is expected to contain
|
||||
# token ids.
|
||||
llm = LLM(model=model, skip_tokenizer_init=True)
|
||||
sampling_params = SamplingParams(prompt_logprobs=True, detokenize=True)
|
||||
with pytest.raises(ValueError) as err:
|
||||
llm.generate("abc", sampling_params)
|
||||
assert "prompts must be None if" in str(err.value)
|
||||
outputs = llm.generate(prompt_token_ids=[[1, 2, 3]],
|
||||
sampling_params=sampling_params)
|
||||
assert len(outputs) > 0
|
||||
completions = outputs[0].outputs
|
||||
assert len(completions) > 0
|
||||
assert completions[0].text == ""
|
||||
assert completions[0].token_ids
|
||||
@@ -1,59 +0,0 @@
|
||||
"""Test the different finish_reason="stop" situations during generation:
|
||||
1. One of the provided stop strings
|
||||
2. One of the provided stop tokens
|
||||
3. The EOS token
|
||||
|
||||
Run `pytest tests/engine/test_stop_reason.py`.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import transformers
|
||||
|
||||
from vllm import SamplingParams
|
||||
|
||||
MODEL = "facebook/opt-350m"
|
||||
STOP_STR = "."
|
||||
SEED = 42
|
||||
MAX_TOKENS = 1024
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def vllm_model(vllm_runner):
|
||||
vllm_model = vllm_runner(MODEL)
|
||||
yield vllm_model
|
||||
del vllm_model
|
||||
|
||||
|
||||
def test_stop_reason(vllm_model, example_prompts):
|
||||
tokenizer = transformers.AutoTokenizer.from_pretrained(MODEL)
|
||||
stop_token_id = tokenizer.convert_tokens_to_ids(STOP_STR)
|
||||
llm = vllm_model.model
|
||||
|
||||
# test stop token
|
||||
outputs = llm.generate(example_prompts,
|
||||
sampling_params=SamplingParams(
|
||||
seed=SEED,
|
||||
max_tokens=MAX_TOKENS,
|
||||
stop_token_ids=[stop_token_id]))
|
||||
for output in outputs:
|
||||
output = output.outputs[0]
|
||||
assert output.finish_reason == "stop"
|
||||
assert output.stop_reason == stop_token_id
|
||||
|
||||
# test stop string
|
||||
outputs = llm.generate(example_prompts,
|
||||
sampling_params=SamplingParams(
|
||||
seed=SEED, max_tokens=MAX_TOKENS, stop="."))
|
||||
for output in outputs:
|
||||
output = output.outputs[0]
|
||||
assert output.finish_reason == "stop"
|
||||
assert output.stop_reason == STOP_STR
|
||||
|
||||
# test EOS token
|
||||
outputs = llm.generate(example_prompts,
|
||||
sampling_params=SamplingParams(
|
||||
seed=SEED, max_tokens=MAX_TOKENS))
|
||||
for output in outputs:
|
||||
output = output.outputs[0]
|
||||
assert output.finish_reason == "length" or (
|
||||
output.finish_reason == "stop" and output.stop_reason is None)
|
||||
@@ -1,111 +0,0 @@
|
||||
from typing import Any, List, Optional
|
||||
|
||||
import pytest
|
||||
|
||||
from vllm import CompletionOutput, LLMEngine, SamplingParams
|
||||
|
||||
MODEL = "meta-llama/llama-2-7b-hf"
|
||||
MAX_TOKENS = 200
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def vllm_model(vllm_runner):
|
||||
return vllm_runner(MODEL)
|
||||
|
||||
|
||||
@pytest.mark.skip_global_cleanup
|
||||
def test_stop_basic(vllm_model):
|
||||
_test_stopping(vllm_model.model.llm_engine,
|
||||
stop=["."],
|
||||
include_in_output=False,
|
||||
expected_output="VLLM is a 100% volunteer organization",
|
||||
expected_reason=".")
|
||||
|
||||
_test_stopping(vllm_model.model.llm_engine,
|
||||
stop=["."],
|
||||
include_in_output=True,
|
||||
expected_output="VLLM is a 100% volunteer organization.",
|
||||
expected_reason=".")
|
||||
|
||||
|
||||
@pytest.mark.skip_global_cleanup
|
||||
def test_stop_multi_tokens(vllm_model):
|
||||
_test_stopping(
|
||||
vllm_model.model.llm_engine,
|
||||
stop=["group of peo", "short"],
|
||||
include_in_output=False,
|
||||
expected_output="VLLM is a 100% volunteer organization. We are a ",
|
||||
expected_reason="group of peo")
|
||||
|
||||
_test_stopping(
|
||||
vllm_model.model.llm_engine,
|
||||
stop=["group of peo", "short"],
|
||||
include_in_output=True,
|
||||
expected_output=
|
||||
"VLLM is a 100% volunteer organization. We are a group of peo",
|
||||
expected_reason="group of peo")
|
||||
|
||||
|
||||
@pytest.mark.skip_global_cleanup
|
||||
def test_stop_partial_token(vllm_model):
|
||||
_test_stopping(vllm_model.model.llm_engine,
|
||||
stop=["gani"],
|
||||
include_in_output=False,
|
||||
expected_output="VLLM is a 100% volunteer or",
|
||||
expected_reason="gani")
|
||||
|
||||
_test_stopping(vllm_model.model.llm_engine,
|
||||
stop=["gani"],
|
||||
include_in_output=True,
|
||||
expected_output="VLLM is a 100% volunteer organi",
|
||||
expected_reason="gani")
|
||||
|
||||
|
||||
@pytest.mark.skip_global_cleanup
|
||||
def test_stop_token_id(vllm_model):
|
||||
# token id 13013 => " organization"
|
||||
|
||||
_test_stopping(vllm_model.model.llm_engine,
|
||||
stop_token_ids=[13013],
|
||||
include_in_output=False,
|
||||
expected_output="VLLM is a 100% volunteer",
|
||||
expected_reason=13013)
|
||||
|
||||
_test_stopping(vllm_model.model.llm_engine,
|
||||
stop_token_ids=[13013],
|
||||
include_in_output=True,
|
||||
expected_output="VLLM is a 100% volunteer organization",
|
||||
expected_reason=13013)
|
||||
|
||||
|
||||
def _test_stopping(llm_engine: LLMEngine,
|
||||
expected_output: str,
|
||||
expected_reason: Any,
|
||||
stop: Optional[List[str]] = None,
|
||||
stop_token_ids: Optional[List[int]] = None,
|
||||
include_in_output: bool = False) -> None:
|
||||
llm_engine.add_request(
|
||||
"id", "A story about vLLM:\n",
|
||||
SamplingParams(
|
||||
temperature=0.0,
|
||||
max_tokens=MAX_TOKENS,
|
||||
stop=stop,
|
||||
stop_token_ids=stop_token_ids,
|
||||
include_stop_str_in_output=include_in_output,
|
||||
), None)
|
||||
|
||||
output: Optional[CompletionOutput] = None
|
||||
output_text = ""
|
||||
stop_reason = None
|
||||
while llm_engine.has_unfinished_requests():
|
||||
(request_output, ) = llm_engine.step()
|
||||
(output, ) = request_output.outputs
|
||||
|
||||
# Ensure we don't backtrack
|
||||
assert output.text.startswith(output_text)
|
||||
output_text = output.text
|
||||
stop_reason = output.stop_reason
|
||||
|
||||
assert output is not None
|
||||
assert output_text == expected_output
|
||||
assert stop_reason == expected_reason
|
||||
Reference in New Issue
Block a user