Sync from v0.13
This commit is contained in:
119
tests/plugins_tests/test_io_processor_plugins.py
Normal file
119
tests/plugins_tests/test_io_processor_plugins.py
Normal file
@@ -0,0 +1,119 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import base64
|
||||
|
||||
import pytest
|
||||
import requests
|
||||
|
||||
from tests.utils import RemoteOpenAIServer
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.entrypoints.pooling.pooling.protocol import IOProcessorResponse
|
||||
from vllm.plugins.io_processors import get_io_processor
|
||||
|
||||
MODEL_NAME = "ibm-nasa-geospatial/Prithvi-EO-2.0-300M-TL-Sen1Floods11"
|
||||
|
||||
image_url = "https://huggingface.co/christian-pinto/Prithvi-EO-2.0-300M-TL-VLLM/resolve/main/valencia_example_2024-10-26.tiff" # noqa: E501
|
||||
|
||||
|
||||
def test_loading_missing_plugin():
    """Requesting a nonexistent IO-processor plugin must raise ValueError."""
    config = VllmConfig()
    with pytest.raises(ValueError):
        get_io_processor(config, "wrong_plugin")
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
def server():
    """Launch a remote vLLM server wired up with the prithvi_to_tiff plugin."""
    cli_args = [
        "--runner", "pooling",
        "--enforce-eager",
        "--trust-remote-code",
        "--skip-tokenizer-init",
        # Cap the number of parallel requests so the model
        # does not go OOM in CI.
        "--max-num-seqs", "32",
        "--io-processor-plugin", "prithvi_to_tiff",
        "--model-impl", "terratorch",
        "--enable-mm-embeds",
    ]

    with RemoteOpenAIServer(MODEL_NAME, cli_args) as remote:
        yield remote
|
||||
|
||||
|
||||
@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
async def test_prithvi_mae_plugin_online(
    server: RemoteOpenAIServer,
    model_name: str,
):
    """POST an image URL to the pooling endpoint and validate the reply."""
    payload = {
        "data": {
            "data": image_url,
            "data_format": "url",
            "image_format": "tiff",
            "out_data_format": "b64_json",
        },
        "priority": 0,
        "model": model_name,
        "softmax": False,
    }

    http_response = requests.post(
        server.url_for("pooling"),
        json=payload,
    )
    body = http_response.json()

    # The response must parse into the expected protocol type.
    parsed = IOProcessorResponse(**body)
    assert parsed

    # The plugin-specific payload must populate all of these keys.
    plugin_payload = parsed.data
    for key in ("type", "format", "data", "request_id"):
        assert plugin_payload.get(key)

    # We just check that the output is a valid base64 string:
    # b64decode raises (failing the test) if the string is corrupted.
    base64.b64decode(plugin_payload["data"])
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model_name", [MODEL_NAME])
def test_prithvi_mae_plugin_offline(vllm_runner, model_name: str):
    """Exercise the Prithvi IO-processor plugin via the offline encode API."""
    prompt = {
        "data": image_url,
        "data_format": "url",
        "image_format": "tiff",
        "out_data_format": "b64_json",
    }

    with vllm_runner(
        model_name,
        runner="pooling",
        skip_tokenizer_init=True,
        enable_mm_embeds=True,
        trust_remote_code=True,
        enforce_eager=True,
        # Cap the number of parallel requests so the model
        # does not go OOM in CI.
        max_num_seqs=1,
        model_impl="terratorch",
        io_processor_plugin="prithvi_to_tiff",
    ) as runner:
        results = runner.get_llm().encode(prompt, pooling_task="plugin")
        out = results[0].outputs

        # The plugin output must expose all of these attributes.
        for attr in ("type", "format", "data", "request_id"):
            assert hasattr(out, attr)

        # We just check that the output is a valid base64 string:
        # b64decode raises (failing the test) if the string is corrupted.
        base64.b64decode(out.data)
|
||||
47
tests/plugins_tests/test_platform_plugins.py
Normal file
47
tests/plugins_tests/test_platform_plugins.py
Normal file
@@ -0,0 +1,47 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from vllm.plugins import load_general_plugins
|
||||
|
||||
|
||||
def test_platform_plugins():
    """A platform plugin must take effect before vllm.platforms resolves."""
    # Simulate a workload by running one of the bundled examples.
    import os
    import runpy

    repo_root = os.path.dirname(os.path.dirname(os.path.dirname(__file__)))
    example_path = os.path.join(
        repo_root,
        "examples",
        "offline_inference/basic/basic.py",
    )
    runpy.run_path(example_path)

    # Check that the plugin was loaded correctly, i.e. the dummy
    # platform was installed before current_platform was resolved.
    from vllm.platforms import _init_trace, current_platform

    assert current_platform.device_name == "DummyDevice", (
        f"Expected DummyDevice, got {current_platform.device_name}, "
        "possibly because current_platform is imported before the plugin"
        f" is loaded. The first import:\n{_init_trace}"
    )
|
||||
|
||||
|
||||
def test_oot_custom_op(monkeypatch: pytest.MonkeyPatch):
    """An out-of-tree custom op must replace the stock RotaryEmbedding."""
    # Loading the general plugins registers the out-of-tree custom op.
    load_general_plugins()
    from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding

    layer = RotaryEmbedding(16, 16, 16, 16, True, torch.float16)
    cls_name = layer.__class__.__name__
    assert cls_name == "DummyRotaryEmbedding", (
        f"Expected DummyRotaryEmbedding, got {cls_name}, "
        "possibly because the custom op is not registered correctly."
    )
    assert hasattr(layer, "addition_config"), (
        "Expected DummyRotaryEmbedding to have an 'addition_config' attribute, "
        "which is set by the custom op."
    )
|
||||
36
tests/plugins_tests/test_scheduler_plugins.py
Normal file
36
tests/plugins_tests/test_scheduler_plugins.py
Normal file
@@ -0,0 +1,36 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import pytest
|
||||
|
||||
from vllm.engine.arg_utils import EngineArgs
|
||||
from vllm.sampling_params import SamplingParams
|
||||
from vllm.v1.core.sched.scheduler import Scheduler
|
||||
from vllm.v1.engine.llm_engine import LLMEngine
|
||||
|
||||
|
||||
class DummyV1Scheduler(Scheduler):
    """Scheduler stub that always fails, proving the plugin class is used."""

    def schedule(self):
        # The sibling test asserts on this exact message.
        message = "Exception raised by DummyV1Scheduler"
        raise Exception(message)
|
||||
|
||||
|
||||
def test_scheduler_plugins_v1(monkeypatch: pytest.MonkeyPatch):
    """A custom scheduler_cls must actually be used by the V1 engine."""
    with monkeypatch.context() as m:
        # Explicitly turn off engine multiprocessing so the scheduler
        # runs in this process and its exception propagates here.
        m.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0")

        with pytest.raises(Exception) as exc_info:
            args = EngineArgs(
                model="facebook/opt-125m",
                enforce_eager=True,  # reduce test time
                scheduler_cls=DummyV1Scheduler,
            )
            engine = LLMEngine.from_engine_args(engine_args=args)

            params = SamplingParams(max_tokens=1)
            engine.add_request("0", "foo", params)
            engine.step()

        assert str(exc_info.value) == "Exception raised by DummyV1Scheduler"
|
||||
76
tests/plugins_tests/test_stats_logger_plugins.py
Normal file
76
tests/plugins_tests/test_stats_logger_plugins.py
Normal file
@@ -0,0 +1,76 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import pytest
|
||||
from dummy_stat_logger.dummy_stat_logger import DummyStatLogger
|
||||
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.engine.arg_utils import AsyncEngineArgs
|
||||
from vllm.v1.engine.async_llm import AsyncLLM
|
||||
from vllm.v1.metrics.loggers import load_stat_logger_plugin_factories
|
||||
|
||||
|
||||
def test_stat_logger_plugin_is_discovered(monkeypatch: pytest.MonkeyPatch):
    """The dummy stat-logger entry point must be discovered and usable."""
    with monkeypatch.context() as m:
        m.setenv("VLLM_PLUGINS", "dummy_stat_logger")

        discovered = load_stat_logger_plugin_factories()
        assert len(discovered) == 1, f"Expected 1 factory, got {len(discovered)}"
        assert discovered[0] is DummyStatLogger, (
            f"Expected DummyStatLogger class, got {discovered[0]}"
        )

        # Instantiating the factory must yield the plugin's logger type.
        instance = discovered[0](VllmConfig())
        assert isinstance(instance, DummyStatLogger)
|
||||
|
||||
|
||||
def test_no_plugins_loaded_if_env_empty(monkeypatch: pytest.MonkeyPatch):
    """An empty VLLM_PLUGINS env var must yield no stat-logger factories."""
    with monkeypatch.context() as m:
        m.setenv("VLLM_PLUGINS", "")
        assert load_stat_logger_plugin_factories() == []
|
||||
|
||||
|
||||
def test_invalid_stat_logger_plugin_raises(monkeypatch: pytest.MonkeyPatch):
    """A plugin object that is not a StatLoggerBase subclass must be rejected."""

    # Stand-in for the entry-point loader: returns a non-StatLoggerBase object.
    def bogus_loader(group: str):
        assert group == "vllm.stat_logger_plugins"
        return {"bad": object()}

    with monkeypatch.context() as m:
        m.setattr(
            "vllm.v1.metrics.loggers.load_plugins_by_group",
            bogus_loader,
        )
        expected = "Stat logger plugin 'bad' must be a subclass of StatLoggerBase"
        with pytest.raises(TypeError, match=expected):
            load_stat_logger_plugin_factories()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_stat_logger_plugin_integration_with_engine(
    monkeypatch: pytest.MonkeyPatch,
):
    """End-to-end: the engine must pick up the dummy stat-logger plugin."""
    with monkeypatch.context() as m:
        m.setenv("VLLM_PLUGINS", "dummy_stat_logger")

        engine = AsyncLLM.from_engine_args(
            engine_args=AsyncEngineArgs(
                model="facebook/opt-125m",
                enforce_eager=True,  # reduce test time
                disable_log_stats=True,  # disable default loggers
            )
        )

        loggers = engine.logger_manager.stat_loggers
        assert len(loggers) == 2
        per_engine = loggers[0].per_engine_stat_loggers
        assert len(per_engine) == 1
        assert isinstance(per_engine[0], DummyStatLogger)

        engine.shutdown()
|
||||
Reference in New Issue
Block a user