Sync from v0.13
This commit is contained in:
0
tests/v1/engine/__init__.py
Normal file
0
tests/v1/engine/__init__.py
Normal file
90
tests/v1/engine/conftest.py
Normal file
90
tests/v1/engine/conftest.py
Normal file
@@ -0,0 +1,90 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
from tests.v1.engine.utils import (
|
||||
FULL_STRINGS,
|
||||
NUM_PROMPT_LOGPROBS_UNDER_TEST,
|
||||
NUM_SAMPLE_LOGPROBS_UNDER_TEST,
|
||||
PROMPT_LEN,
|
||||
TOKENIZER_NAME,
|
||||
DummyOutputProcessorTestVectors,
|
||||
generate_dummy_prompt_logprobs_tensors,
|
||||
generate_dummy_sample_logprobs,
|
||||
)
|
||||
from vllm.engine.arg_utils import EngineArgs
|
||||
|
||||
from ...distributed.conftest import publisher_config, random_port # noqa: F401
|
||||
|
||||
EngineCoreSampleLogprobsType = list[tuple[torch.Tensor, torch.Tensor]]
|
||||
EngineCorePromptLogprobsType = tuple[torch.Tensor, torch.Tensor]
|
||||
|
||||
|
||||
def _build_test_vectors_no_logprobs() -> DummyOutputProcessorTestVectors:
    """Generate output processor dummy test vectors, without logprobs.

    Returns:
        DummyOutputProcessorTestVectors instance with empty prompt/sample
        logprob lists (logprobs are injected later by the fixture).
    """
    tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_NAME)
    vllm_config = EngineArgs(model=TOKENIZER_NAME).create_engine_config()
    # Tokenize prompts under test & create dummy generated tokens:
    # the first PROMPT_LEN tokens of each full string play the role of the
    # prompt, the remainder play the role of the generated continuation.
    prompt_tokens = [tokenizer(text).input_ids[:PROMPT_LEN] for text in FULL_STRINGS]
    generation_tokens = [
        tokenizer(text).input_ids[PROMPT_LEN:] for text in FULL_STRINGS
    ]
    # Decode each prompt back to a string. Use a distinct loop variable
    # (`tokens`) rather than shadowing the outer `prompt_tokens` list,
    # which the original comprehension did.
    prompt_strings = [
        tokenizer.decode(tokens, skip_special_tokens=True)
        for tokens in prompt_tokens
    ]
    prompt_strings_len = [len(prompt_string) for prompt_string in prompt_strings]
    return DummyOutputProcessorTestVectors(
        tokenizer=tokenizer,
        vllm_config=vllm_config,
        full_tokens=[tokenizer(text).input_ids for text in FULL_STRINGS],
        prompt_tokens=prompt_tokens,
        generation_tokens=generation_tokens,
        prompt_strings=prompt_strings,
        prompt_strings_len=prompt_strings_len,
        # The "generated" string is whatever of the original text lies
        # beyond the decoded prompt's character length.
        generation_strings=[
            text[prompt_len:]
            for text, prompt_len in zip(FULL_STRINGS, prompt_strings_len)
        ],
        prompt_logprobs=[],
        generation_logprobs=[],
    )
|
||||
|
||||
|
||||
@pytest.fixture
def dummy_test_vectors() -> DummyOutputProcessorTestVectors:
    """Generate output processor dummy test vectors, with logprobs.

    Returns:
        DummyOutputProcessorTestVectors instance with logprobs
    """
    # Build dummy test vectors without logprobs first.
    dtv = _build_test_vectors_no_logprobs()
    # Inject logprobs into the dummy test vectors data structure:
    # one sample-logprobs entry per generated sequence, one prompt-logprobs
    # tensor tuple per prompt.
    dtv.generation_logprobs = [
        generate_dummy_sample_logprobs(
            sampled_tokens_list=tokens_list,
            num_logprobs=NUM_SAMPLE_LOGPROBS_UNDER_TEST,
            tokenizer=dtv.tokenizer,
        )
        for tokens_list in dtv.generation_tokens
    ]
    dtv.prompt_logprobs = [
        generate_dummy_prompt_logprobs_tensors(
            prompt_tokens_list=tokens_list,
            num_logprobs=NUM_PROMPT_LOGPROBS_UNDER_TEST,
            tokenizer=dtv.tokenizer,
        )
        for tokens_list in dtv.prompt_tokens
    ]
    return dtv
|
||||
311
tests/v1/engine/test_abort_final_step.py
Normal file
311
tests/v1/engine/test_abort_final_step.py
Normal file
@@ -0,0 +1,311 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
"""
|
||||
Test for the fix in PR #29987: Eagerly abort cancelled final-step requests.
|
||||
|
||||
This test verifies that when a request is aborted during its final execution
|
||||
step (when it would naturally complete), it is properly marked as aborted
|
||||
rather than being treated as normally completed.
|
||||
|
||||
The test uses a dummy KV connector to verify that the connector receives
|
||||
the correct finish status (FINISHED_ABORTED, not FINISHED_LENGTH_CAPPED).
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import tempfile
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from vllm import SamplingParams
|
||||
from vllm.config import KVTransferConfig, VllmConfig
|
||||
from vllm.distributed.kv_transfer.kv_connector.factory import KVConnectorFactory
|
||||
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
|
||||
KVConnectorBase_V1,
|
||||
KVConnectorMetadata,
|
||||
KVConnectorRole,
|
||||
)
|
||||
from vllm.engine.arg_utils import AsyncEngineArgs
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.sampling_params import RequestOutputKind
|
||||
from vllm.utils.torch_utils import set_default_torch_num_threads
|
||||
from vllm.v1.core.sched.output import SchedulerOutput
|
||||
from vllm.v1.engine.async_llm import AsyncLLM
|
||||
from vllm.v1.kv_cache_interface import KVCacheConfig
|
||||
from vllm.v1.request import Request
|
||||
|
||||
if not current_platform.is_cuda():
|
||||
pytest.skip(reason="V1 currently only supported on CUDA.", allow_module_level=True)
|
||||
|
||||
TEXT_PROMPT = "Hello"
|
||||
|
||||
|
||||
class DummyKVConnectorMetadata(KVConnectorMetadata):
    """Dummy metadata for the test connector."""

    def __init__(self):
        # No per-step metadata is needed for this test; keep an empty
        # request list so the object satisfies the metadata interface.
        self.requests: list = []
|
||||
|
||||
|
||||
class DummyKVConnector(KVConnectorBase_V1):
    """
    Dummy KV connector that captures request finish statuses to a file.
    This is used to verify the fix - without the fix, a request aborted
    during its final step would be captured as FINISHED_LENGTH_CAPPED
    instead of FINISHED_ABORTED.

    The connector runs in a separate process, so we write statuses to a file
    that can be read by the test process.
    """

    def __init__(
        self,
        vllm_config: VllmConfig,
        role: KVConnectorRole,
        kv_cache_config: KVCacheConfig | None = None,
    ):
        super().__init__(vllm_config, role, kv_cache_config)
        # Get the status file path from extra config (set by the test via
        # KVTransferConfig.kv_connector_extra_config).
        extra_config = vllm_config.kv_transfer_config.kv_connector_extra_config or {}
        self.status_file = extra_config.get("status_file")
        # Log that we were initialized (lets the test confirm the connector
        # was actually constructed in the engine process).
        if self.status_file:
            try:
                with open(self.status_file, "a") as f:
                    f.write(f"INIT:{role.name}\n")
            except Exception:
                # Best-effort instrumentation only; never fail engine startup.
                pass

    def get_num_new_matched_tokens(
        self,
        request: Request,
        num_computed_tokens: int,
    ) -> tuple[int | None, bool]:
        # No external KV store: report zero externally-matched tokens and
        # no asynchronous loading.
        return (0, False)

    def update_state_after_alloc(
        self,
        request: Request,
        blocks: Any,
        num_external_tokens: int,
    ):
        # No-op: this connector does not track allocation state.
        pass

    def build_connector_meta(
        self, scheduler_output: SchedulerOutput
    ) -> KVConnectorMetadata:
        # Return empty metadata each step; nothing to communicate.
        return DummyKVConnectorMetadata()

    def request_finished(
        self,
        request: Request,
        block_ids: list[int],
    ) -> tuple[bool, dict[str, Any] | None]:
        """Capture the request status when finished by writing to a file."""
        if self.status_file:
            try:
                with open(self.status_file, "a") as f:
                    # Write the status name (e.g., "FINISHED_ABORTED")
                    f.write(f"{request.status.name}\n")
            except Exception as e:
                # Log but don't fail - this is just test instrumentation
                print(f"[DummyKVConnector] Failed to write status: {e}")
        # (False, None): blocks are not held for async transfer.
        return False, None

    def start_load_kv(self, forward_context: Any, **kwargs: Any) -> None:
        # No-op: no KV data is actually transferred by this connector.
        pass

    def wait_for_layer_load(self, layer_name: str) -> None:
        # No-op: nothing to wait for.
        pass

    def save_kv_layer(
        self,
        layer_name: str,
        kv_layer: Any,
        attn_metadata: Any,
        **kwargs: Any,
    ) -> None:
        # No-op: nothing to save.
        pass

    def wait_for_save(self):
        # No-op: nothing to wait for.
        pass
|
||||
|
||||
|
||||
# Register the dummy connector at import time so it can be referenced by
# name ("DummyKVConnector") in the test's KVTransferConfig.
KVConnectorFactory.register_connector(
    "DummyKVConnector", __name__, DummyKVConnector.__name__
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("async_scheduling", [False, True])
@pytest.mark.asyncio
async def test_abort_during_final_step(async_scheduling: bool):
    """
    Test that a request aborted during its final execution step is treated as
    aborted rather than completed.

    This test:
    1. Monkeypatches execute_model to wait for a file to be deleted
    2. Configures a dummy KV connector to capture finish statuses
    3. Starts a request with max_tokens=1 (will complete on first decode step)
    4. Aborts the request, then deletes the file to unblock execute_model
    5. Verifies the KV connector received FINISHED_ABORTED not FINISHED_LENGTH_CAPPED

    See https://github.com/vllm-project/vllm/pull/29987.

    Without the fix, the KV connector would see FINISHED_LENGTH_CAPPED because
    update_from_output() would mark the request as completed before processing
    the abort. This causes KV cache blocks to not be freed properly in
    disaggregated prefill scenarios.

    With the fix, _process_aborts_queue() runs before update_from_output(), so the
    abort takes precedence and the KV connector sees FINISHED_ABORTED.
    """

    # Create three temporary files:
    # 1. ready_file: deleted by execute_model to signal it has started
    # 2. block_file: execute_model waits for this to be deleted
    # 3. status_file: KV connector writes finish statuses here
    # Files are used (rather than in-memory events) because the signalling
    # must cross the test/worker process boundary.
    with tempfile.NamedTemporaryFile(delete=False) as f:
        ready_file = Path(f.name)
    with tempfile.NamedTemporaryFile(delete=False) as f2:
        block_file = Path(f2.name)
    with tempfile.NamedTemporaryFile(delete=False, mode="w") as f3:
        status_file = Path(f3.name)

    try:
        # Get the original execute_model method
        from vllm.v1.worker.gpu_worker import Worker

        original_execute_model = Worker.execute_model

        def execute_model_with_wait(self, scheduler_output):
            # Signal that execute_model has been called by deleting ready_file
            if ready_file.exists():
                ready_file.unlink()

            # Wait for the block file to be deleted (triggered from test after abort)
            # This runs in the worker process (after fork), so we poll the filesystem
            while block_file.exists():
                time.sleep(0.01)
            return original_execute_model(self, scheduler_output)

        # Patch execute_model to inject the wait
        # This happens before the worker process is forked, so the patch applies there
        with patch.object(Worker, "execute_model", execute_model_with_wait):
            request_id = "test-abort-final-step"

            # Configure engine with dummy KV connector
            # Pass the status file path so the connector can write to it
            kv_transfer_config = KVTransferConfig(
                kv_connector="DummyKVConnector",
                kv_role="kv_both",
                kv_connector_extra_config={"status_file": str(status_file)},
            )
            engine_args = AsyncEngineArgs(
                model="meta-llama/Llama-3.2-1B-Instruct",
                enforce_eager=True,
                async_scheduling=async_scheduling,
                kv_transfer_config=kv_transfer_config,
            )

            with set_default_torch_num_threads(1):
                engine = AsyncLLM.from_engine_args(engine_args)

            try:
                # Create a request that will complete after just 1 token,
                # so the very first decode step is also its final step.
                sampling_params = SamplingParams(
                    max_tokens=1,
                    ignore_eos=True,
                    output_kind=RequestOutputKind.DELTA,
                )

                # Start generation in a task
                outputs = []

                async def generate():
                    async for output in engine.generate(
                        request_id=request_id,
                        prompt=TEXT_PROMPT,
                        sampling_params=sampling_params,
                    ):
                        outputs.append(output)

                gen_task = asyncio.create_task(generate())

                # Wait for execute_model to signal it has started (with timeout)
                timeout = 5.0  # 5 second timeout
                start_time = time.time()
                while ready_file.exists():
                    if time.time() - start_time > timeout:
                        raise TimeoutError(
                            "Timeout waiting for execute_model to start. "
                            "The monkeypatch may not be working correctly, "
                            "for example if spawn was used instead of fork."
                        )
                    await asyncio.sleep(0.01)

                # Abort the request while execute_model is blocked
                await engine.abort(request_id)

                # Now unblock execute_model by deleting the file
                # The abort should be processed before the model output
                block_file.unlink()

                # Wait for generation to complete
                await gen_task

                # Give the scheduler a moment to finish cleanup
                await asyncio.sleep(0.1)

                # Verify we got output
                assert len(outputs) > 0, "Should have received at least one output"

                # The final output should have finish_reason="abort"
                final_output = outputs[-1]
                assert final_output.finished, (
                    "Final output should be marked as finished"
                )
                assert final_output.outputs[0].finish_reason == "abort", (
                    f"Expected finish_reason='abort' but got "
                    f"'{final_output.outputs[0].finish_reason}'. "
                )

                with open(status_file) as f4:
                    status_lines = f4.read().strip().split("\n")
                # Filter for actual finish statuses (not INIT or empty lines)
                captured_statuses = [
                    line
                    for line in status_lines
                    if line and line.startswith("FINISHED_")
                ]

                assert len(captured_statuses) >= 1, (
                    f"Expected at least 1 captured finish status, got "
                    f"{len(captured_statuses)}. File content: {status_lines}"
                )

                assert "FINISHED_ABORTED" in captured_statuses, (
                    f"KV connector should see FINISHED_ABORTED but got "
                    f"{captured_statuses}. "
                )

                # Verify cleanup
                assert not engine.output_processor.has_unfinished_requests()

            finally:
                # Shutdown the engine
                engine.shutdown()

    finally:
        # Clean up temporary files if they still exist
        if ready_file.exists():
            ready_file.unlink()
        if block_file.exists():
            block_file.unlink()
        if status_file.exists():
            status_file.unlink()
|
||||
599
tests/v1/engine/test_async_llm.py
Normal file
599
tests/v1/engine/test_async_llm.py
Normal file
@@ -0,0 +1,599 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import asyncio
|
||||
from contextlib import ExitStack
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from vllm import SamplingParams
|
||||
from vllm.assets.image import ImageAsset
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.engine.arg_utils import AsyncEngineArgs
|
||||
from vllm.inputs import PromptType
|
||||
from vllm.outputs import RequestOutput
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.sampling_params import RequestOutputKind
|
||||
from vllm.utils.torch_utils import set_default_torch_num_threads
|
||||
from vllm.v1.engine.async_llm import AsyncLLM
|
||||
from vllm.v1.metrics.loggers import (
|
||||
AggregatedLoggingStatLogger,
|
||||
LoggingStatLogger,
|
||||
PerEngineStatLoggerAdapter,
|
||||
PrometheusStatLogger,
|
||||
)
|
||||
|
||||
if not current_platform.is_cuda():
|
||||
pytest.skip(reason="V1 currently only supported on CUDA.", allow_module_level=True)
|
||||
|
||||
TEXT_ENGINE_ARGS = AsyncEngineArgs(
|
||||
model="meta-llama/Llama-3.2-1B-Instruct",
|
||||
enforce_eager=True,
|
||||
)
|
||||
|
||||
VISION_ENGINE_ARGS = AsyncEngineArgs(
|
||||
model="Qwen/Qwen2-VL-2B-Instruct", enforce_eager=True
|
||||
)
|
||||
|
||||
TEXT_PROMPT = "Hello my name is Robert and"
|
||||
|
||||
VISION_PROMPT_TEMPLATE = (
|
||||
"<|im_start|>system\nYou are a helpful assistant.<|im_end|>"
|
||||
"\n<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>"
|
||||
"What is in the image?<|im_end|>\n"
|
||||
"<|im_start|>assistant\n"
|
||||
)
|
||||
VISION_PROMPT = {
|
||||
"prompt": VISION_PROMPT_TEMPLATE,
|
||||
"multi_modal_data": {"image": ImageAsset("stop_sign").pil_image},
|
||||
}
|
||||
|
||||
|
||||
async def generate(
    engine: AsyncLLM,
    request_id: str,
    prompt: PromptType,
    output_kind: RequestOutputKind,
    max_tokens: int,
    n: int = 1,
    prompt_logprobs: int | None = None,
    cancel_after: int | None = None,
) -> tuple[int, str]:
    """Drive a single request through the engine and count generated tokens.

    Returns a ``(token_count, request_id)`` tuple. If ``cancel_after`` is
    set, the stream is exited early once that many tokens have been seen.
    """
    # Brief pause so cancellation tests have a chance to cancel this task
    # before the request completes.
    await asyncio.sleep(0.2)

    params = SamplingParams(
        max_tokens=max_tokens,
        ignore_eos=True,
        output_kind=output_kind,
        temperature=0.5,
        seed=33,
        n=n,
        prompt_logprobs=prompt_logprobs,
    )

    token_count = 0
    stream = engine.generate(
        request_id=request_id, prompt=prompt, sampling_params=params
    )
    async for chunk in stream:
        step_tokens = sum(len(o.token_ids) for o in chunk.outputs)
        # DELTA outputs carry only the new tokens; other kinds carry the
        # cumulative total, so we overwrite instead of accumulating.
        token_count = (
            token_count + step_tokens
            if output_kind == RequestOutputKind.DELTA
            else step_tokens
        )

        if cancel_after is not None and token_count >= cancel_after:
            return token_count, request_id

        # Yield control to the event loop between chunks.
        await asyncio.sleep(0.0)

    return token_count, request_id
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    "output_kind", [RequestOutputKind.DELTA, RequestOutputKind.FINAL_ONLY]
)
@pytest.mark.parametrize(
    "engine_args,prompt",
    [(TEXT_ENGINE_ARGS, TEXT_PROMPT), (VISION_ENGINE_ARGS, VISION_PROMPT)],
)
@pytest.mark.asyncio
async def test_load(
    output_kind: RequestOutputKind,
    engine_args: AsyncEngineArgs,
    prompt: PromptType,
):
    """Smoke-test many concurrent requests: each of 100 requests must
    generate exactly the expected number of tokens."""
    with ExitStack() as after:
        with set_default_torch_num_threads(1):
            engine = AsyncLLM.from_engine_args(engine_args)
        # Ensure the engine is shut down even if an assertion fails.
        after.callback(engine.shutdown)

        NUM_REQUESTS = 100
        NUM_EXPECTED_TOKENS = 10

        request_ids = [f"request-{i}" for i in range(NUM_REQUESTS)]

        # Create concurrent requests.
        tasks = []
        for request_id in request_ids:
            tasks.append(
                asyncio.create_task(
                    generate(
                        engine, request_id, prompt, output_kind, NUM_EXPECTED_TOKENS
                    )
                )
            )

        # Confirm that we got all the EXPECTED tokens from the requests.
        done, pending = await asyncio.wait(tasks, return_when=asyncio.FIRST_EXCEPTION)
        for task in pending:
            task.cancel()
        for task in done:
            num_generated_tokens, request_id = await task
            assert num_generated_tokens == NUM_EXPECTED_TOKENS, (
                f"{request_id} generated {num_generated_tokens} but "
                f"expected {NUM_EXPECTED_TOKENS}"
            )

        # Nothing should remain tracked once all requests finished.
        assert not engine.output_processor.has_unfinished_requests()
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    "output_kind", [RequestOutputKind.DELTA, RequestOutputKind.FINAL_ONLY]
)
@pytest.mark.parametrize(
    "engine_args,prompt",
    [(TEXT_ENGINE_ARGS, TEXT_PROMPT), (VISION_ENGINE_ARGS, VISION_PROMPT)],
)
@pytest.mark.asyncio
async def test_abort(
    output_kind: RequestOutputKind,
    engine_args: AsyncEngineArgs,
    prompt: PromptType,
):
    """Cancel a subset of in-flight request tasks and verify the cancelled
    ones raise CancelledError while the rest complete unaffected."""
    with ExitStack() as after:
        with set_default_torch_num_threads(1):
            engine = AsyncLLM.from_engine_args(engine_args)
        after.callback(engine.shutdown)

        NUM_REQUESTS = 100
        NUM_EXPECTED_TOKENS = 100
        # Long-running budget for the requests we plan to abort, so they
        # are still generating when cancelled.
        NUM_EXPECTED_TOKENS_LONG = 50000
        REQUEST_IDS_TO_ABORT = range(1, 100, 10)
        PARALLEL_SAMPLE_REQ_IDS = range(1, 100, 15)

        request_ids = [f"request-{i}" for i in range(NUM_REQUESTS)]

        # Create concurrent requests.
        tasks: list[asyncio.Task] = []
        for idx, request_id in enumerate(request_ids):
            max_tokens = (
                NUM_EXPECTED_TOKENS_LONG
                if (idx in REQUEST_IDS_TO_ABORT)
                else NUM_EXPECTED_TOKENS
            )
            n = 3 if idx in PARALLEL_SAMPLE_REQ_IDS else 1
            tasks.append(
                asyncio.create_task(
                    generate(engine, request_id, prompt, output_kind, max_tokens, n)
                )
            )

        # API server cancels requests when they disconnect.
        for idx in REQUEST_IDS_TO_ABORT:
            tasks[idx].cancel()
            await asyncio.sleep(0.1)

        # Confirm the other requests are okay.
        for idx, task in enumerate(tasks):
            # Confirm that it was actually canceled.
            if idx in REQUEST_IDS_TO_ABORT:
                with pytest.raises(asyncio.CancelledError):
                    await task
            else:
                # Otherwise, make sure the request was not impacted.
                num_generated_tokens, request_id = await task
                # Parallel-sampled requests produce n sequences' worth of tokens.
                n = 3 if idx in PARALLEL_SAMPLE_REQ_IDS else 1
                expected_tokens = NUM_EXPECTED_TOKENS * n
                assert num_generated_tokens == expected_tokens, (
                    f"{request_id} generated {num_generated_tokens} but "
                    f"expected {expected_tokens}"
                )

        # Make sure all aborted requests were really aborted.
        assert not engine.output_processor.has_unfinished_requests()

        # Confirm we can do another generation (request id reuse after abort).
        request_id = f"request-{REQUEST_IDS_TO_ABORT[0]}"
        task = asyncio.create_task(
            generate(engine, request_id, prompt, output_kind, NUM_EXPECTED_TOKENS)
        )
        num_generated_tokens, request_id = await task
        assert num_generated_tokens == NUM_EXPECTED_TOKENS
        assert not engine.output_processor.has_unfinished_requests()
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    "output_kind", [RequestOutputKind.DELTA, RequestOutputKind.FINAL_ONLY]
)
@pytest.mark.asyncio
async def test_multi_abort(output_kind: RequestOutputKind):
    """Abort several requests with a single engine.abort() call and verify
    aborted requests return partial results while others finish normally."""
    with ExitStack() as after:
        with set_default_torch_num_threads(1):
            engine = AsyncLLM.from_engine_args(TEXT_ENGINE_ARGS)
        after.callback(engine.shutdown)

        NUM_REQUESTS = 50
        NUM_EXPECTED_TOKENS = 100
        # Large budget so the to-be-aborted requests are still running.
        NUM_EXPECTED_TOKENS_LONG = 50000
        REQUEST_IDS_TO_ABORT = [5, 10, 15, 20, 25]
        PARALLEL_SAMPLE_REQ_IDS = [5, 15, 30, 35]

        request_ids = [f"request-{i}" for i in range(NUM_REQUESTS)]

        # Create concurrent requests.
        tasks: list[asyncio.Task] = []
        for idx, request_id in enumerate(request_ids):
            max_tokens = (
                NUM_EXPECTED_TOKENS_LONG
                if (idx in REQUEST_IDS_TO_ABORT)
                else NUM_EXPECTED_TOKENS
            )
            n = 3 if idx in PARALLEL_SAMPLE_REQ_IDS else 1
            tasks.append(
                asyncio.create_task(
                    generate(
                        engine, request_id, TEXT_PROMPT, output_kind, max_tokens, n
                    )
                )
            )

        # Let requests start
        await asyncio.sleep(0.5)

        # Use multi-abort to abort multiple requests at once
        abort_request_ids = [request_ids[i] for i in REQUEST_IDS_TO_ABORT]
        await engine.abort(abort_request_ids)

        # Wait for all tasks to complete
        results = await asyncio.gather(*tasks, return_exceptions=True)

        # Verify results
        for idx, result in enumerate(results):
            if idx in REQUEST_IDS_TO_ABORT:
                # Aborted requests should return partial results
                assert isinstance(result, tuple), (
                    f"Request {idx} should have completed with partial results"
                )
                num_generated_tokens, request_id = result
                # Should have generated some tokens before abort
                assert num_generated_tokens > 0, (
                    f"Aborted request {request_id} should have generated some tokens"
                )
            else:
                # Non-aborted requests should complete normally
                assert isinstance(result, tuple), (
                    f"Request {idx} should have completed successfully"
                )
                num_generated_tokens, request_id = result
                n = 3 if idx in PARALLEL_SAMPLE_REQ_IDS else 1
                expected_tokens = NUM_EXPECTED_TOKENS * n
                assert num_generated_tokens == expected_tokens, (
                    f"{request_id} generated {num_generated_tokens} but "
                    f"expected {expected_tokens}"
                )

        # Make sure all aborted requests were cleaned up
        assert not engine.output_processor.has_unfinished_requests()
|
||||
|
||||
|
||||
@pytest.mark.parametrize("n", [1, 3])
@pytest.mark.parametrize(
    "engine_args,prompt",
    [(TEXT_ENGINE_ARGS, TEXT_PROMPT), (VISION_ENGINE_ARGS, VISION_PROMPT)],
)
@pytest.mark.asyncio
async def test_finished_flag(
    n: int,
    engine_args: AsyncEngineArgs,
    prompt: PromptType,
):
    """Verify exactly the last streamed output carries finished=True."""
    with ExitStack() as after:
        with set_default_torch_num_threads(1):
            engine = AsyncLLM.from_engine_args(engine_args)
        after.callback(engine.shutdown)

        sampling_params = SamplingParams(
            max_tokens=100,
            output_kind=RequestOutputKind.DELTA,
            temperature=1.0,
            seed=33,
            n=n,
        )
        # Collect every streamed output for the request.
        outputs = [
            out
            async for out in engine.generate(
                request_id="request-33", prompt=prompt, sampling_params=sampling_params
            )
        ]

        # Assert only the last output has the finished flag set
        assert all(not out.finished for out in outputs[:-1])
        assert outputs[-1].finished
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    "engine_args,prompt",
    [(TEXT_ENGINE_ARGS, TEXT_PROMPT), (VISION_ENGINE_ARGS, VISION_PROMPT)],
)
@pytest.mark.asyncio
async def test_mid_stream_cancellation(
    engine_args: AsyncEngineArgs, prompt: PromptType
):
    """Test that requests can be cancelled mid-stream."""
    with ExitStack() as after:
        with set_default_torch_num_threads(1):
            engine = AsyncLLM.from_engine_args(engine_args)
        after.callback(engine.shutdown)

        NUM_REQUESTS = 100
        # Large token budget so every request is still streaming when the
        # early-exit threshold (NUM_EXPECTED_TOKENS) is reached.
        NUM_TOKENS = 1000
        NUM_EXPECTED_TOKENS = 20

        request_ids = [f"request-{i}" for i in range(NUM_REQUESTS)]

        # Create concurrent requests that will be cancelled mid-stream
        tasks = []
        for request_id in request_ids:
            tasks.append(
                asyncio.create_task(
                    generate(
                        engine,
                        request_id,
                        prompt,
                        RequestOutputKind.DELTA,
                        NUM_TOKENS,
                        cancel_after=NUM_EXPECTED_TOKENS,
                    )
                )
            )

        # Wait for all tasks to complete
        results = await asyncio.gather(*tasks)

        # Verify all tasks were cancelled at the expected point
        for num_generated_tokens, request_id in results:
            assert num_generated_tokens == NUM_EXPECTED_TOKENS, (
                f"{request_id} generated {num_generated_tokens} tokens but "
                f"expected to cancel after {NUM_EXPECTED_TOKENS}"
            )

        # Make sure no requests are left hanging
        assert not engine.output_processor.has_unfinished_requests()

        # Confirm we can reuse the request id after the cancellations.
        request_id = request_ids[0]
        task = asyncio.create_task(
            generate(
                engine, request_id, prompt, RequestOutputKind.DELTA, NUM_EXPECTED_TOKENS
            )
        )
        num_generated_tokens, request_id = await task
        assert num_generated_tokens == NUM_EXPECTED_TOKENS
        assert not engine.output_processor.has_unfinished_requests()
|
||||
|
||||
|
||||
class MockLoggingStatLogger(LoggingStatLogger):
    """LoggingStatLogger whose log() is replaced by a MagicMock so tests
    can assert that logging was invoked without emitting output."""

    def __init__(self, vllm_config: VllmConfig, engine_index: int = 0):
        super().__init__(vllm_config, engine_index)
        # Record calls instead of actually logging.
        self.log = MagicMock()
|
||||
|
||||
|
||||
class MockAggregatedStatLogger(AggregatedLoggingStatLogger):
    """AggregatedLoggingStatLogger whose log() is replaced by a MagicMock
    so tests can assert that aggregated logging was invoked."""

    def __init__(self, vllm_config: VllmConfig, engine_indexes: list[int]):
        super().__init__(vllm_config, engine_indexes)
        # Record calls instead of actually logging.
        self.log = MagicMock()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_customize_loggers():
    """Test that we can customize the loggers.

    If a customized logger is provided at the init, it should
    be added to the default loggers.

    Note: the unused ``monkeypatch`` fixture parameter and a leftover
    debug ``print`` were removed.
    """

    with ExitStack() as after:
        with set_default_torch_num_threads(1):
            engine = AsyncLLM.from_engine_args(
                TEXT_ENGINE_ARGS,
                stat_loggers=[MockLoggingStatLogger],
            )
        after.callback(engine.shutdown)

        await engine.do_log_stats()

        stat_loggers = engine.logger_manager.stat_loggers
        # MockLoggingStatLogger + LoggingStatLogger + Prometheus logger
        assert len(stat_loggers) == 3
        # The custom logger is wrapped per engine; its mocked log() must
        # have been invoked exactly once by do_log_stats().
        stat_loggers[0].per_engine_stat_loggers[0].log.assert_called_once()
        assert isinstance(stat_loggers[1], PerEngineStatLoggerAdapter)
        assert isinstance(stat_loggers[1].per_engine_stat_loggers[0], LoggingStatLogger)
        assert isinstance(stat_loggers[2], PrometheusStatLogger)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_customize_aggregated_loggers():
    """Test that we can customize the aggregated loggers.

    If a customized logger is provided at the init, it should
    be added to the default loggers.
    """
    with ExitStack() as after:
        with set_default_torch_num_threads(1):
            engine = AsyncLLM.from_engine_args(
                TEXT_ENGINE_ARGS,
                stat_loggers=[MockLoggingStatLogger, MockAggregatedStatLogger],
            )
        after.callback(engine.shutdown)

        await engine.do_log_stats()

        stat_loggers = engine.logger_manager.stat_loggers
        assert len(stat_loggers) == 4
        # MockLoggingStatLogger + MockAggregatedStatLogger
        # + LoggingStatLogger + PrometheusStatLogger
        # Per-engine custom logger is wrapped; the aggregated one is not.
        stat_loggers[0].per_engine_stat_loggers[0].log.assert_called_once()
        stat_loggers[1].log.assert_called_once()
        assert isinstance(stat_loggers[2], PerEngineStatLoggerAdapter)
        assert isinstance(stat_loggers[2].per_engine_stat_loggers[0], LoggingStatLogger)
        assert isinstance(stat_loggers[3], PrometheusStatLogger)
|
||||
|
||||
|
||||
@pytest.mark.asyncio(scope="module")
async def test_dp_rank_argument():
    """Verify generate() accepts a valid data_parallel_rank and rejects an
    out-of-range one with ValueError (engine runs with DP size 1)."""
    with ExitStack() as after:
        with set_default_torch_num_threads(1):
            engine = AsyncLLM.from_engine_args(TEXT_ENGINE_ARGS)
        after.callback(engine.shutdown)

        sampling_params = SamplingParams(
            max_tokens=100,
            output_kind=RequestOutputKind.DELTA,
            temperature=1.0,
            seed=33,
        )

        # Test with valid DP rank.
        async for _ in engine.generate(
            request_id="request-34",
            prompt=TEXT_PROMPT,
            sampling_params=sampling_params,
            data_parallel_rank=0,
        ):
            pass

        # Test with out-of-range DP rank.
        with pytest.raises(ValueError):
            async for _ in engine.generate(
                request_id="request-35",
                prompt=TEXT_PROMPT,
                sampling_params=sampling_params,
                data_parallel_rank=1,
            ):
                pass
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_check_health():
    """Test that check_health returns normally for healthy engine
    and raises EngineDeadError when the engine is dead.
    """
    from unittest.mock import patch

    from vllm.v1.engine.exceptions import EngineDeadError

    with ExitStack() as after:
        with set_default_torch_num_threads(1):
            engine = AsyncLLM.from_engine_args(TEXT_ENGINE_ARGS)
        after.callback(engine.shutdown)

        # Test 1: Healthy engine should not raise any exception
        await engine.check_health()

        # Test 2: Mock the errored property to simulate a dead engine.
        # patch.object with a property requires new_callable since the
        # attribute lives on the type, not the instance.
        with (
            patch.object(
                type(engine),
                "errored",
                new_callable=lambda: property(lambda self: True),
            ),
            pytest.raises(EngineDeadError),
        ):
            await engine.check_health()

        # Test 3: Verify healthy engine still works after mock
        await engine.check_health()
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    "output_kind", [RequestOutputKind.DELTA, RequestOutputKind.FINAL_ONLY]
)
@pytest.mark.asyncio
async def test_abort_final_output(output_kind: RequestOutputKind):
    """abort() must still deliver a final RequestOutput carrying the
    abort finish reason and consistent token/cache accounting."""

    with ExitStack() as after:
        with set_default_torch_num_threads(1):
            engine = AsyncLLM.from_engine_args(TEXT_ENGINE_ARGS)
        after.callback(engine.shutdown)

        request_id = "test-abort-final-output"

        # A long-running request so there is something to abort.
        params = SamplingParams(
            max_tokens=3000,  # Long enough to allow abort
            ignore_eos=True,
            output_kind=output_kind,
            temperature=0.5,
            seed=42,
        )

        streamed: list[RequestOutput] = []
        task = asyncio.create_task(
            collect_outputs(engine, request_id, TEXT_PROMPT, params, streamed)
        )

        # Give the engine time to produce some tokens, then abort.
        await asyncio.sleep(0.5)
        await engine.abort(request_id)

        # The generator must terminate and hand back a final output.
        final = await task

        assert final is not None
        assert final.finished
        assert len(final.outputs) == 1

        assert final.outputs[0].finish_reason == "abort"
        assert final.outputs[0].stop_reason is None

        # num_cached_tokens must be populated and non-negative.
        assert hasattr(final, "num_cached_tokens")
        assert final.num_cached_tokens >= 0

        if output_kind == RequestOutputKind.DELTA:
            # DELTA mode: intermediate chunks carry the tokens.
            streamed_tokens = sum(len(o.outputs[0].token_ids) for o in streamed)
            assert streamed_tokens > 0
            # This would ordinarily be 0, but could end up > 0 if the
            # final abort is coalesced with another chunk in the output queue.
            assert len(final.outputs[0].token_ids) >= 0
        else:
            # FINAL_ONLY mode: nothing is streamed before the final output.
            assert len(streamed) == 0
            assert len(final.outputs[0].token_ids) > 0

        assert not engine.output_processor.has_unfinished_requests()
|
||||
|
||||
|
||||
async def collect_outputs(
    engine: AsyncLLM,
    request_id: str,
    prompt: PromptType,
    sampling_params: SamplingParams,
    outputs_list: list[RequestOutput],
) -> RequestOutput | None:
    """Drive engine.generate(), appending every unfinished chunk to
    ``outputs_list`` and returning the last output seen (or None if the
    generator produced nothing)."""
    last: RequestOutput | None = None
    stream = engine.generate(
        request_id=request_id, prompt=prompt, sampling_params=sampling_params
    )
    async for chunk in stream:
        if not chunk.finished:
            outputs_list.append(chunk)
        last = chunk
    return last
|
||||
92
tests/v1/engine/test_engine_args.py
Normal file
92
tests/v1/engine/test_engine_args.py
Normal file
@@ -0,0 +1,92 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from argparse import ArgumentError
|
||||
|
||||
import pytest
|
||||
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.engine.arg_utils import EngineArgs
|
||||
from vllm.usage.usage_lib import UsageContext
|
||||
from vllm.utils.argparse_utils import FlexibleArgumentParser
|
||||
from vllm.utils.hashing import _xxhash
|
||||
|
||||
|
||||
def test_prefix_caching_from_cli():
    """Prefix-caching enablement and hash-algorithm selection round-trip
    from CLI flags into the resulting cache config."""
    parser = EngineArgs.add_cli_args(FlexibleArgumentParser())
    args = parser.parse_args([])
    vllm_config = EngineArgs.from_cli_args(args=args).create_engine_config()
    assert vllm_config.cache_config.enable_prefix_caching, (
        "V1 turns on prefix caching by default."
    )

    # Turning it off is possible with a flag.
    args = parser.parse_args(["--no-enable-prefix-caching"])
    vllm_config = EngineArgs.from_cli_args(args=args).create_engine_config()
    assert not vllm_config.cache_config.enable_prefix_caching

    # Turn it on with flag.
    args = parser.parse_args(["--enable-prefix-caching"])
    vllm_config = EngineArgs.from_cli_args(args=args).create_engine_config()
    assert vllm_config.cache_config.enable_prefix_caching

    # The default hash algorithm is "sha256" (the old comment said
    # "builtin", which contradicted the assertion below).
    assert vllm_config.cache_config.prefix_caching_hash_algo == "sha256"

    # set hash algorithm to sha256_cbor
    args = parser.parse_args(["--prefix-caching-hash-algo", "sha256_cbor"])
    vllm_config = EngineArgs.from_cli_args(args=args).create_engine_config()
    assert vllm_config.cache_config.prefix_caching_hash_algo == "sha256_cbor"

    # set hash algorithm to sha256
    args = parser.parse_args(["--prefix-caching-hash-algo", "sha256"])
    vllm_config = EngineArgs.from_cli_args(args=args).create_engine_config()
    assert vllm_config.cache_config.prefix_caching_hash_algo == "sha256"

    # An invalid hash algorithm raises ArgumentError (exit_on_error is
    # disabled so argparse raises instead of calling sys.exit).
    parser.exit_on_error = False
    with pytest.raises(ArgumentError):
        args = parser.parse_args(["--prefix-caching-hash-algo", "invalid"])
|
||||
|
||||
|
||||
@pytest.mark.skipif(_xxhash is None, reason="xxhash not installed")
def test_prefix_caching_xxhash_from_cli():
    """Both xxhash-based prefix-caching hash algorithms are selectable
    via the CLI flag."""
    parser = EngineArgs.add_cli_args(FlexibleArgumentParser())

    for algo in ("xxhash", "xxhash_cbor"):
        cli_args = parser.parse_args(["--prefix-caching-hash-algo", algo])
        config = EngineArgs.from_cli_args(args=cli_args).create_engine_config()
        assert config.cache_config.prefix_caching_hash_algo == algo
|
||||
|
||||
|
||||
def test_defaults_with_usage_context():
    """Scheduler defaults depend on the usage context (LLM class vs.
    OpenAI API server) and on the detected GPU's memory size."""
    engine_args = EngineArgs(model="facebook/opt-125m")
    vllm_config: VllmConfig = engine_args.create_engine_config(UsageContext.LLM_CLASS)

    from vllm.platforms import current_platform
    from vllm.utils.mem_constants import GiB_bytes

    device_memory = current_platform.get_device_total_memory()
    device_name = current_platform.get_device_name().lower()
    # For GPUs like H100, H200, and MI300x with >= 70GB memory the
    # defaults are larger; A100-80GB is explicitly excluded.
    if device_memory >= 70 * GiB_bytes and "a100" not in device_name:
        default_llm_tokens, default_server_tokens, default_max_num_seqs = (
            16384,
            8192,
            1024,
        )
    else:
        default_llm_tokens, default_server_tokens, default_max_num_seqs = (
            8192,
            2048,
            256,
        )

    sched = vllm_config.scheduler_config
    assert sched.max_num_seqs == default_max_num_seqs
    assert sched.max_num_batched_tokens == default_llm_tokens

    engine_args = EngineArgs(model="facebook/opt-125m")
    vllm_config = engine_args.create_engine_config(UsageContext.OPENAI_API_SERVER)
    sched = vllm_config.scheduler_config
    assert sched.max_num_seqs == default_max_num_seqs
    assert sched.max_num_batched_tokens == default_server_tokens
|
||||
599
tests/v1/engine/test_engine_core.py
Normal file
599
tests/v1/engine/test_engine_core.py
Normal file
@@ -0,0 +1,599 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import copy
|
||||
import time
|
||||
import uuid
|
||||
from concurrent.futures import Future, ThreadPoolExecutor
|
||||
|
||||
import pytest
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
from vllm import SamplingParams
|
||||
from vllm.config import (
|
||||
CacheConfig,
|
||||
ECTransferConfig,
|
||||
KVTransferConfig,
|
||||
ModelConfig,
|
||||
SchedulerConfig,
|
||||
VllmConfig,
|
||||
)
|
||||
from vllm.engine.arg_utils import EngineArgs
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils.torch_utils import set_default_torch_num_threads
|
||||
from vllm.v1.engine import EngineCoreRequest
|
||||
from vllm.v1.engine.core import EngineCore
|
||||
from vllm.v1.executor.abstract import Executor
|
||||
from vllm.v1.executor.uniproc_executor import UniProcExecutor
|
||||
from vllm.v1.kv_cache_interface import KVCacheConfig
|
||||
from vllm.v1.outputs import ModelRunnerOutput
|
||||
|
||||
from ...utils import create_new_process_for_each_test, multi_gpu_test
|
||||
|
||||
if not current_platform.is_cuda():
|
||||
pytest.skip(reason="V1 currently only supported on CUDA.", allow_module_level=True)
|
||||
|
||||
MODEL_NAME = "hmellor/tiny-random-LlamaForCausalLM"
|
||||
TOKENIZER = AutoTokenizer.from_pretrained(MODEL_NAME)
|
||||
# test_engine_core_concurrent_batches assumes exactly 12 tokens per prompt.
|
||||
# Adjust prompt if changing model to maintain 12-token length.
|
||||
PROMPT = "I am Gyoubu Masataka Oniwa"
|
||||
PROMPT_TOKENS = TOKENIZER(PROMPT).input_ids
|
||||
|
||||
|
||||
def make_request() -> EngineCoreRequest:
    """Build a fresh EngineCoreRequest for the shared test prompt with a
    unique request_id and default sampling parameters."""
    return EngineCoreRequest(
        request_id=str(uuid.uuid4()),
        prompt_token_ids=PROMPT_TOKENS,
        sampling_params=SamplingParams(),
        pooling_params=None,
        mm_features=None,
        eos_token_id=None,
        lora_request=None,
        cache_salt=None,
        data_parallel_rank=None,
        arrival_time=time.time(),
    )
|
||||
|
||||
|
||||
@create_new_process_for_each_test()
def test_engine_core():
    """Exercise the EngineCore request lifecycle end to end: scheduling
    transitions on add/step, running requests to completion, aborting
    single and multiple requests, and reusing a duplicate request_id
    after the first request has finished."""
    engine_args = EngineArgs(model=MODEL_NAME)
    vllm_config = engine_args.create_engine_config()
    executor_class = Executor.get_class(vllm_config)

    with set_default_torch_num_threads(1):
        engine_core = EngineCore(
            vllm_config=vllm_config, executor_class=executor_class, log_stats=True
        )
    """Test basic request lifecycle."""

    # First request: lands in the waiting queue until the next step.
    engine_core.add_request(*engine_core.preprocess_add_request(make_request()))
    assert len(engine_core.scheduler.waiting) == 1
    assert len(engine_core.scheduler.running) == 0

    # One step promotes it to running.
    _ = engine_core.step_fn()
    assert len(engine_core.scheduler.waiting) == 0
    assert len(engine_core.scheduler.running) == 1

    # Second request.
    engine_core.add_request(*engine_core.preprocess_add_request(make_request()))
    assert len(engine_core.scheduler.waiting) == 1
    assert len(engine_core.scheduler.running) == 1

    _ = engine_core.step_fn()
    assert len(engine_core.scheduler.waiting) == 0
    assert len(engine_core.scheduler.running) == 2

    # Add two requests in a row.
    engine_core.add_request(*engine_core.preprocess_add_request(make_request()))
    engine_core.add_request(*engine_core.preprocess_add_request(make_request()))
    assert len(engine_core.scheduler.waiting) == 2
    assert len(engine_core.scheduler.running) == 2

    _ = engine_core.step_fn()
    assert len(engine_core.scheduler.waiting) == 0
    assert len(engine_core.scheduler.running) == 4

    # Loop through until they are all done (step_fn()[0] maps client
    # index -> EngineCoreOutputs; empty outputs means nothing left).
    while (outs := engine_core.step_fn()[0].get(0)) and outs.outputs:
        pass

    assert len(engine_core.scheduler.waiting) == 0
    assert len(engine_core.scheduler.running) == 0
    """Test abort cycle."""

    # Basic abort.
    req = make_request()
    request_id = req.request_id

    engine_core.add_request(*engine_core.preprocess_add_request(req))
    assert len(engine_core.scheduler.waiting) == 1
    assert len(engine_core.scheduler.running) == 0
    assert engine_core.scheduler.has_unfinished_requests()
    assert not engine_core.scheduler.has_finished_requests()

    _ = engine_core.step_fn()
    assert len(engine_core.scheduler.waiting) == 0
    assert len(engine_core.scheduler.running) == 1
    assert engine_core.scheduler.has_unfinished_requests()
    assert not engine_core.scheduler.has_finished_requests()

    # Aborting moves the request out of running and marks it finished.
    engine_core.abort_requests([request_id])
    assert len(engine_core.scheduler.waiting) == 0
    assert len(engine_core.scheduler.running) == 0
    assert not engine_core.scheduler.has_unfinished_requests()
    assert engine_core.scheduler.has_finished_requests()

    # A subsequent step drains the finished-request bookkeeping.
    _ = engine_core.step_fn()
    assert not engine_core.scheduler.has_unfinished_requests()
    assert not engine_core.scheduler.has_finished_requests()

    # Add, step, abort 1 of the 3.
    req0 = make_request()
    req1 = make_request()
    req2 = make_request()

    engine_core.add_request(*engine_core.preprocess_add_request(req0))
    engine_core.add_request(*engine_core.preprocess_add_request(req1))
    assert len(engine_core.scheduler.waiting) == 2
    assert len(engine_core.scheduler.running) == 0

    _ = engine_core.step_fn()
    assert len(engine_core.scheduler.waiting) == 0
    assert len(engine_core.scheduler.running) == 2

    engine_core.add_request(*engine_core.preprocess_add_request(req2))
    assert len(engine_core.scheduler.waiting) == 1
    assert len(engine_core.scheduler.running) == 2

    _ = engine_core.step_fn()
    assert len(engine_core.scheduler.waiting) == 0
    assert len(engine_core.scheduler.running) == 3

    # Abort just one.
    engine_core.abort_requests([req1.request_id])
    assert len(engine_core.scheduler.waiting) == 0
    assert len(engine_core.scheduler.running) == 2

    _ = engine_core.step_fn()
    assert len(engine_core.scheduler.waiting) == 0
    assert len(engine_core.scheduler.running) == 2

    # Abort the other requests at the same time.
    engine_core.abort_requests([req2.request_id, req0.request_id])
    assert len(engine_core.scheduler.waiting) == 0
    assert len(engine_core.scheduler.running) == 0

    # Sending duplicate requests with same request_id: the second add
    # only happens after the first has fully finished, so it must work.
    req0 = make_request()
    req1 = make_request()
    req0.request_id = req1.request_id = "test"
    engine_core.add_request(*engine_core.preprocess_add_request(req0))

    while engine_core.scheduler.has_requests():
        engine_core.step_fn()

    engine_core.add_request(*engine_core.preprocess_add_request(req1))
    while engine_core.scheduler.has_requests():
        engine_core.step_fn()

    assert len(engine_core.scheduler.waiting) == 0
    assert len(engine_core.scheduler.running) == 0
|
||||
|
||||
|
||||
@create_new_process_for_each_test()
def test_engine_core_advanced_sampling():
    """
    A basic end-to-end test to verify that the engine functions correctly
    when additional sampling parameters, such as top_p, min_tokens, and
    presence_penalty, are set.
    """
    # Set up the EngineCore.
    engine_args = EngineArgs(model=MODEL_NAME)
    vllm_config = engine_args.create_engine_config()
    executor_class = Executor.get_class(vllm_config)

    with set_default_torch_num_threads(1):
        engine_core = EngineCore(
            vllm_config=vllm_config, executor_class=executor_class, log_stats=True
        )

    def _drain_and_verify():
        # Request is waiting, then run it to completion and confirm the
        # scheduler is empty afterwards.
        assert len(engine_core.scheduler.waiting) == 1
        assert len(engine_core.scheduler.running) == 0
        while engine_core.scheduler.has_requests():
            engine_core.step_fn()
        assert len(engine_core.scheduler.waiting) == 0
        assert len(engine_core.scheduler.running) == 0

    # First request: penalty/stop-token parameters.
    penalty_req: EngineCoreRequest = make_request()
    penalty_req.sampling_params = SamplingParams(
        min_tokens=4,
        presence_penalty=1.0,
        frequency_penalty=1.0,
        repetition_penalty=0.1,
        stop_token_ids=[1001, 1002],
    )
    engine_core.add_request(*engine_core.preprocess_add_request(penalty_req))
    _drain_and_verify()

    # Second request: nucleus/top-k sampling parameters.
    topk_req = make_request()
    topk_req.sampling_params = SamplingParams(
        top_p=0.99,
        top_k=50,
    )
    engine_core.add_request(*engine_core.preprocess_add_request(topk_req))
    _drain_and_verify()
|
||||
|
||||
|
||||
@create_new_process_for_each_test()
def test_engine_core_concurrent_batches():
    """
    Test that the engine can handle multiple concurrent batches.
    """

    def make_request_with_max_tokens(req_id: str, max_tokens: int) -> EngineCoreRequest:
        # Fresh request with a fixed id and a capped generation length.
        request = make_request()
        request.request_id = req_id
        request.sampling_params.max_tokens = max_tokens
        return request

    class DummyExecutor(UniProcExecutor):
        # Executor that runs model work on a single background thread so
        # two batches can be in flight concurrently from the engine's view.
        def initialize_from_config(self, kv_cache_configs: list[KVCacheConfig]) -> None:
            super().initialize_from_config(kv_cache_configs)

            # Create a thread pool with a single worker
            self.thread_pool = ThreadPoolExecutor(max_workers=1)

        def execute_model(
            self,
            scheduler_output,
            non_block=False,
        ) -> Future[ModelRunnerOutput | None]:
            """Make execute_model non-blocking."""

            # DummyExecutor used only for testing async case.
            assert non_block

            def _execute():
                output = self.collective_rpc("execute_model", args=(scheduler_output,))
                # Make a copy because output[0] may be reused
                # by the next batch.
                return copy.deepcopy(output[0])

            # Use the thread pool instead of creating a new thread
            return self.thread_pool.submit(_execute)

        def sample_tokens(
            self, grammar_output, non_block=False
        ) -> Future[ModelRunnerOutput]:
            """Make sample_tokens non-blocking."""

            # DummyExecutor used only for testing async case.
            assert non_block

            def _execute():
                output = self.collective_rpc("sample_tokens", args=(grammar_output,))
                # Make a copy because output[0] may be reused
                # by the next batch.
                return copy.deepcopy(output[0])

            # Use the thread pool instead of creating a new thread
            return self.thread_pool.submit(_execute)

        @property
        def max_concurrent_batches(self) -> int:
            # Allows the engine to pipeline two batches.
            return 2

        def shutdown(self):
            if hasattr(self, "thread_pool"):
                self.thread_pool.shutdown(wait=False)

    engine_args = EngineArgs(
        model=MODEL_NAME,
        # To test concurrent batches.
        max_num_seqs=2,
        # Avoid all requests being scheduled once.
        enable_prefix_caching=False,
        max_num_batched_tokens=10,
        # Reduce startup time.
        enforce_eager=True,
        # Test concurrent batch behaviour independently of async scheduling.
        async_scheduling=False,
    )
    vllm_config = engine_args.create_engine_config()
    with set_default_torch_num_threads(1):
        engine_core = EngineCore(
            vllm_config=vllm_config, log_stats=False, executor_class=DummyExecutor
        )
    assert engine_core.batch_queue is not None

    # Add two requests in a row. Each request have 12 prompt tokens.
    req0 = make_request_with_max_tokens("0", 5)
    engine_core.add_request(*engine_core.preprocess_add_request(req0))
    req1 = make_request_with_max_tokens("1", 5)
    engine_core.add_request(*engine_core.preprocess_add_request(req1))

    # Schedule Batch 1: (10, req0)
    assert engine_core.step_with_batch_queue()[0] is None
    assert len(engine_core.batch_queue) == 1
    scheduler_output = engine_core.batch_queue[-1][1]
    assert scheduler_output.num_scheduled_tokens["0"] == 10
    # num_computed_tokens should have been updated immediately.
    assert engine_core.scheduler.requests[req0.request_id].num_computed_tokens == 10

    # Schedule Batch 2: (2, req0), (8, req1)
    assert engine_core.step_with_batch_queue()[0] == {}
    assert len(engine_core.batch_queue) == 1
    scheduler_output = engine_core.batch_queue[-1][1]
    assert scheduler_output.num_scheduled_tokens["0"] == 2
    assert scheduler_output.num_scheduled_tokens["1"] == 8
    # num_computed_tokens should have been updated immediately.
    assert engine_core.scheduler.requests["0"].num_computed_tokens == 12
    assert engine_core.scheduler.requests["1"].num_computed_tokens == 8

    assert engine_core.scheduler.get_num_unfinished_requests() == 2

    # Finish Batch 1 and schedule Batch 3: (4, req1).
    # Note that req0 cannot be scheduled
    # because it is in the decoding stage now.
    engine_core.step_with_batch_queue()
    assert len(engine_core.batch_queue) == 1
    scheduler_output = engine_core.batch_queue[-1][1]
    assert scheduler_output.num_scheduled_tokens["1"] == 4

    # Finish Batch 2. Get first token of req0.
    # Schedule Batch 4: (1, req0).
    output = engine_core.step_with_batch_queue()[0].get(0)
    assert output is not None
    assert len(output.outputs) == 1
    assert engine_core.scheduler.requests[req0.request_id].num_tokens == 13
    scheduler_output = engine_core.batch_queue[-1][1]
    assert scheduler_output.num_scheduled_tokens["0"] == 1

    # Finish Batch 3. Get first token of req1. Schedule Batch 5: (1, req1).
    output = engine_core.step_with_batch_queue()[0].get(0)
    assert output is not None
    assert len(output.outputs) == 1
    assert engine_core.scheduler.requests[req1.request_id].num_tokens == 13
    scheduler_output = engine_core.batch_queue[-1][1]
    assert scheduler_output.num_scheduled_tokens["1"] == 1

    # Loop until req0 is finished.
    req_id = 0
    expected_num_tokens = [
        engine_core.scheduler.requests["0"].num_tokens + 1,
        engine_core.scheduler.requests["1"].num_tokens + 1,
    ]
    while engine_core.scheduler.get_num_unfinished_requests() == 2:
        output = engine_core.step_with_batch_queue()[0]
        # Every step consumes an output.
        assert output is not None
        assert len(output[0].outputs) == 1
        # NOTE(review): `req_id` is an int here, but scheduler.requests is
        # keyed by the string ids "0"/"1" (see the lookups above), so this
        # membership test appears to be always False and the assertion
        # below never runs — confirm, and use str(req_id) if the check is
        # intended to fire.
        if req_id in engine_core.scheduler.requests:
            assert (
                engine_core.scheduler.requests[req_id].num_tokens
                == expected_num_tokens[req_id]
            )
            expected_num_tokens[req_id] += 1
        req_id = (req_id + 1) % 2
|
||||
|
||||
|
||||
@multi_gpu_test(num_gpus=2)
def test_engine_core_tp():
    """
    Test engine can initialize worker in tp properly
    """
    # Set up the EngineCore with tensor parallelism across 2 GPUs.
    engine_args = EngineArgs(
        model=MODEL_NAME,
        tensor_parallel_size=2,
        # Reduce startup time.
        enforce_eager=True,
    )
    vllm_config = engine_args.create_engine_config()
    executor_class = Executor.get_class(vllm_config)

    with set_default_torch_num_threads(1):
        engine_core = EngineCore(
            vllm_config=vllm_config, executor_class=executor_class, log_stats=True
        )

    def _read_cache_field(worker, key: str):
        # Runs on each worker via collective_rpc.
        return getattr(worker.cache_config, key)

    # Every worker must have both block counts populated after init.
    for field in ("num_gpu_blocks", "num_cpu_blocks"):
        values = engine_core.collective_rpc(_read_cache_field, args=(field,))
        assert all(v is not None for v in values)
|
||||
|
||||
|
||||
@create_new_process_for_each_test()
def test_engine_core_invalid_request_id_type():
    """Test that engine raises TypeError for non-string request_id."""
    engine_args = EngineArgs(model=MODEL_NAME)
    vllm_config = engine_args.create_engine_config()
    executor_class = Executor.get_class(vllm_config)

    with set_default_torch_num_threads(1):
        engine_core = EngineCore(
            vllm_config=vllm_config, executor_class=executor_class, log_stats=True
        )

    # Each invalid id type must be rejected with a TypeError naming it.
    invalid_ids = [
        # UUID object (common mistake)
        (uuid.uuid4(), "request_id must be a string, got.*UUID"),
        (12345, "request_id must be a string, got.*int"),
        (None, "request_id must be a string, got.*NoneType"),
    ]
    for bad_id, pattern in invalid_ids:
        bad_request = make_request()
        bad_request.request_id = bad_id
        with pytest.raises(TypeError, match=pattern):
            engine_core.add_request(*engine_core.preprocess_add_request(bad_request))

    # Verify engine is still functional after errors
    valid_request = make_request()
    engine_core.add_request(*engine_core.preprocess_add_request(valid_request))
    assert len(engine_core.scheduler.waiting) == 1
    assert len(engine_core.scheduler.running) == 0
|
||||
|
||||
|
||||
@create_new_process_for_each_test()
@pytest.mark.parametrize(
    ("ec_role", "gpu_memory_utilization", "enable_prefix_caching"),
    [
        ("ec_producer", 0.01, False),
        # NOTE: ec_producer never allows prefix caching
        ("ec_consumer", 0.7, True),
        ("ec_consumer", 0.7, False),
    ],
)
@pytest.mark.parametrize("use_kv_connector", [False, True])
def test_encoder_instance_zero_kv_cache(
    ec_role: str,
    gpu_memory_utilization: float,
    enable_prefix_caching: bool,
    use_kv_connector: bool,
):
    """EPD (Encoder-Prefill-Decode) Encoder-cache-specific tests

    This test verifies encoder-only instance initializes with 0 KV cache blocks.
    Under EPD disagg mode, Encoder instances (EC producer role) only execute
    vision encoder, so they don't need KV cache for text generation.
    """
    # Form vllm config
    model_config = ModelConfig(
        model="llava-hf/llava-1.5-7b-hf",  # Multimodal model
        enforce_eager=True,
        trust_remote_code=True,
        dtype="float16",
        seed=42,
    )
    scheduler_config = SchedulerConfig(
        max_num_seqs=10,
        max_num_batched_tokens=512,
        max_model_len=512,
        disable_hybrid_kv_cache_manager=True,
        is_encoder_decoder=model_config.is_encoder_decoder,
    )
    cache_config = CacheConfig(
        block_size=16,
        gpu_memory_utilization=gpu_memory_utilization,
        swap_space=0,
        cache_dtype="auto",
        enable_prefix_caching=enable_prefix_caching,
    )
    # KV connector is optional so both disagg configurations are covered.
    kv_transfer_config = (
        KVTransferConfig(
            kv_connector="ExampleConnector",
            kv_role="kv_both",
            kv_connector_extra_config={"shared_storage_path": "local_storage"},
        )
        if use_kv_connector
        else None
    )
    ec_transfer_config = ECTransferConfig(
        ec_connector="ECExampleConnector",
        ec_role=ec_role,
        ec_connector_extra_config={"shared_storage_path": "/tmp/ec_test_encoder"},
    )

    vllm_config = VllmConfig(
        model_config=model_config,
        cache_config=cache_config,
        scheduler_config=scheduler_config,
        kv_transfer_config=kv_transfer_config,
        ec_transfer_config=ec_transfer_config,
    )

    executor_class = Executor.get_class(vllm_config)
    print(f"executor_class: {executor_class}")

    with set_default_torch_num_threads(1):
        engine_core = EngineCore(
            vllm_config=vllm_config, executor_class=executor_class, log_stats=True
        )

    # Check encoder cache manager exists
    assert engine_core.scheduler.encoder_cache_manager is not None, (
        "encoder_cache_manager should exist"
    )

    if ec_role == "ec_producer":
        # Check 1: producer holds no usable KV blocks.
        # NOTE: num_blocks=1 as BlockPool always needs a null_block.
        kv_cache_config = engine_core.scheduler.kv_cache_manager.kv_cache_config
        print(f"kv_cache_config: {kv_cache_config}")
        assert kv_cache_config.num_blocks == 1, (
            f"ec_producer should only have 1 KV blocks, "
            f"got {kv_cache_config.num_blocks}"
        )

        # Check 2: kv_cache_groups should be empty
        assert len(kv_cache_config.kv_cache_groups) == 0, (
            f"ec_producer should have 0 KV cache groups, "
            f"got {len(kv_cache_config.kv_cache_groups)}"
        )

        # Check 3: kv_cache_tensors should be empty
        assert len(kv_cache_config.kv_cache_tensors) == 0, (
            f"Encoder instance should have 0 KV cache tensors, "
            f"got {len(kv_cache_config.kv_cache_tensors)}"
        )

        # Check 4: Verify EC connector is initialized and is producer
        assert engine_core.scheduler.ec_connector is not None, (
            "Encoder instance should have EC connector"
        )
        assert engine_core.scheduler.ec_connector.is_producer, (
            "Encoder instance EC connector should be producer"
        )

        # Check 5: Verify chunked prefill is disabled
        assert not vllm_config.scheduler_config.enable_chunked_prefill, (
            "Encoder instance should disable chunked prefill (no KV cache)"
        )

    elif ec_role == "ec_consumer":
        # Check 1: num_blocks should be > 1
        kv_cache_config = engine_core.scheduler.kv_cache_manager.kv_cache_config
        print(f"kv_cache_config: {kv_cache_config}")
        assert kv_cache_config.num_blocks > 1, (
            f"ec_consumer should have >1 KV blocks, got {kv_cache_config.num_blocks}"
        )

        # Check 2: kv_cache_groups should NOT be empty
        assert len(kv_cache_config.kv_cache_groups) > 0, (
            f"ec_consumer should have KV cache groups, "
            f"got {len(kv_cache_config.kv_cache_groups)}"
        )

        # Check 3: Verify EC connector is consumer
        assert engine_core.scheduler.ec_connector is not None, (
            "Consumer instance should have EC connector"
        )
        assert not engine_core.scheduler.ec_connector.is_producer, (
            "Consumer instance EC connector should be consumer"
        )
|
||||
851
tests/v1/engine/test_engine_core_client.py
Normal file
851
tests/v1/engine/test_engine_core_client.py
Normal file
@@ -0,0 +1,851 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
import signal
|
||||
import time
|
||||
import uuid
|
||||
from dataclasses import dataclass
|
||||
from threading import Thread
|
||||
from typing import Any
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
from tests.utils import multi_gpu_test
|
||||
from vllm import SamplingParams
|
||||
from vllm.distributed.kv_events import BlockStored, KVEventBatch, ZmqEventPublisher
|
||||
from vllm.engine.arg_utils import EngineArgs
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.usage.usage_lib import UsageContext
|
||||
from vllm.utils.torch_utils import set_default_torch_num_threads
|
||||
from vllm.v1.engine import EngineCoreRequest
|
||||
from vllm.v1.engine.core import EngineCore
|
||||
from vllm.v1.engine.core_client import AsyncMPClient, EngineCoreClient, SyncMPClient
|
||||
from vllm.v1.engine.utils import CoreEngineProcManager
|
||||
from vllm.v1.executor.abstract import Executor
|
||||
|
||||
from ...distributed.conftest import MockSubscriber
|
||||
from ...utils import create_new_process_for_each_test
|
||||
|
||||
# These tests exercise the GPU-backed V1 engine core, so skip the entire
# module on non-CUDA platforms.
if not current_platform.is_cuda():
    pytest.skip(reason="V1 currently only supported on CUDA.", allow_module_level=True)

# Shared fixtures: a small instruct model plus one fixed prompt tokenized
# once at import time so every test reuses the same token ids.
MODEL_NAME = "meta-llama/Llama-3.2-1B-Instruct"
TOKENIZER = AutoTokenizer.from_pretrained(MODEL_NAME)
PROMPT = "Hello my name is Robert and I love quantization kernels"
PROMPT_TOKENS = TOKENIZER(PROMPT).input_ids
|
||||
|
||||
|
||||
def make_request(
    params: SamplingParams, prompt_tokens_ids: list[int] | None = None
) -> EngineCoreRequest:
    """Build a minimal EngineCoreRequest with a fresh uuid request id.

    Falls back to the module-level PROMPT_TOKENS when no prompt token ids
    (or an empty list) are supplied.
    """
    token_ids = prompt_tokens_ids if prompt_tokens_ids else PROMPT_TOKENS
    return EngineCoreRequest(
        request_id=str(uuid.uuid4()),
        prompt_token_ids=token_ids,
        mm_features=None,
        sampling_params=params,
        pooling_params=None,
        eos_token_id=None,
        arrival_time=time.time(),
        lora_request=None,
        cache_salt=None,
        data_parallel_rank=None,
    )
|
||||
|
||||
|
||||
def loop_until_done(client: EngineCoreClient, outputs: dict):
    """Drain engine outputs into `outputs` (request_id -> list of outputs)
    until a non-empty batch arrives in which every request is finished."""
    while True:
        batch = client.get_output().outputs

        # Empty polls carry no information; keep waiting.
        if not batch:
            continue

        finished_flags = []
        for request_output in batch:
            outputs[request_output.request_id].append(request_output)
            finished_flags.append(request_output.finished)

        # Stop once every output in this batch reports completion.
        if all(finished_flags):
            return
|
||||
|
||||
|
||||
async def loop_until_done_async(client: EngineCoreClient, outputs: dict):
    """Async variant of loop_until_done: drain engine outputs into
    `outputs` until a non-empty batch arrives whose outputs are all
    finished."""
    while True:
        batch = (await client.get_output_async()).outputs

        # Empty polls carry no information; keep waiting.
        if not batch:
            continue

        finished_flags = []
        for request_output in batch:
            outputs[request_output.request_id].append(request_output)
            finished_flags.append(request_output.finished)

        # Stop once every output in this batch reports completion.
        if all(finished_flags):
            return
|
||||
|
||||
|
||||
async def loop_until_fully_done_async(client: EngineCoreClient, outputs: dict):
    """Drain engine outputs until every request id tracked in `outputs`
    has received a final (finished) output, polling with a short sleep.

    Unlike loop_until_done_async, completion is judged across all tracked
    requests, not just the ones present in the latest batch.
    """
    while True:
        batch = (await client.get_output_async()).outputs

        # Empty polls carry no information; keep waiting.
        if not batch:
            continue

        # Record this batch's outputs.
        for request_output in batch:
            outputs[request_output.request_id].append(request_output)

        # Done once every tracked request's latest output is finished.
        if all(history and history[-1].finished for history in outputs.values()):
            return

        await asyncio.sleep(0.1)
|
||||
|
||||
|
||||
# Dummy utility function to monkey-patch into engine core.
|
||||
def echo(self, msg: str, err_msg: str | None = None, sleep: float | None = None) -> str:
|
||||
print(f"echo util function called: {msg}, {err_msg}")
|
||||
if sleep is not None:
|
||||
time.sleep(sleep)
|
||||
if err_msg is not None:
|
||||
raise ValueError(err_msg)
|
||||
return msg
|
||||
|
||||
|
||||
@create_new_process_for_each_test()
@pytest.mark.parametrize("multiprocessing_mode", [True, False])
def test_engine_core_client(
    monkeypatch: pytest.MonkeyPatch, multiprocessing_mode: bool
):
    """Exercise the sync EngineCoreClient end to end: the normal request
    cycle, aborting in-flight requests, aborting an already-finished
    request, and (multiprocessing only) utility-method invocation."""
    with monkeypatch.context() as m:
        # Monkey-patch core engine utility function to test.
        m.setattr(EngineCore, "echo", echo, raising=False)

        engine_args = EngineArgs(model=MODEL_NAME, enforce_eager=True)
        vllm_config = engine_args.create_engine_config(UsageContext.UNKNOWN_CONTEXT)
        executor_class = Executor.get_class(vllm_config)

        # Limit torch threads during engine startup to avoid CPU
        # oversubscription in CI.
        with set_default_torch_num_threads(1):
            client = EngineCoreClient.make_client(
                multiprocess_mode=multiprocessing_mode,
                asyncio_mode=False,
                vllm_config=vllm_config,
                executor_class=executor_class,
                log_stats=False,
            )

        MAX_TOKENS = 20
        params = SamplingParams(max_tokens=MAX_TOKENS)
        """Normal Request Cycle."""
        requests = [make_request(params) for _ in range(10)]
        request_ids = [req.request_id for req in requests]

        # Add requests to the engine.
        for request in requests:
            client.add_request(request)
            time.sleep(0.01)

        outputs: dict[str, list] = {req_id: [] for req_id in request_ids}
        loop_until_done(client, outputs)

        # Every request should run to the max_tokens limit.
        for req_id in request_ids:
            assert len(outputs[req_id]) == MAX_TOKENS, (
                f"{outputs[req_id]=}, {MAX_TOKENS=}"
            )
        """Abort Request Cycle."""

        # Note: this code pathway will only work for multiprocessing
        # since we have to call get_output() explicitly

        # Add requests to the engine.
        for idx, request in enumerate(requests):
            client.add_request(request)
            time.sleep(0.01)
            if idx % 2 == 0:
                client.abort_requests([request.request_id])

        outputs = {req_id: [] for req_id in request_ids}
        loop_until_done(client, outputs)

        # Aborted (even-index) requests must stop early; the rest run to
        # the max_tokens limit.
        for idx, req_id in enumerate(request_ids):
            if idx % 2 == 0:
                assert len(outputs[req_id]) < MAX_TOKENS, (
                    f"{len(outputs[req_id])=}, {MAX_TOKENS=}"
                )
            else:
                assert len(outputs[req_id]) == MAX_TOKENS, (
                    f"{len(outputs[req_id])=}, {MAX_TOKENS=}"
                )
        """Abort after request is finished."""

        # Note: this code pathway will only work for multiprocessing
        # since we have to call get_output() explicitly

        # Aborting a request that has already completed should be a
        # harmless no-op.
        request = requests[0]
        client.add_request(request)
        time.sleep(10.0)

        client.abort_requests([request.request_id])

        if multiprocessing_mode:
            """Utility method invocation"""

            core_client: SyncMPClient = client

            result = core_client.call_utility("echo", "testarg")
            assert result == "testarg"

            # A raising utility method should surface as a client-side error.
            with pytest.raises(Exception) as e_info:
                core_client.call_utility("echo", None, "help!")

            assert str(e_info.value) == "Call to echo method failed: help!"
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="function")
async def test_engine_core_client_asyncio(monkeypatch: pytest.MonkeyPatch):
    """Async-mode counterpart of test_engine_core_client: normal and abort
    request cycles plus utility-method invocation (including cancellation)
    over the AsyncMPClient."""
    with monkeypatch.context() as m:
        # Monkey-patch core engine utility function to test.
        m.setattr(EngineCore, "echo", echo, raising=False)

        engine_args = EngineArgs(model=MODEL_NAME, enforce_eager=True)
        vllm_config = engine_args.create_engine_config(
            usage_context=UsageContext.UNKNOWN_CONTEXT
        )
        executor_class = Executor.get_class(vllm_config)

        # Limit torch threads during engine startup to avoid CPU
        # oversubscription in CI.
        with set_default_torch_num_threads(1):
            client = EngineCoreClient.make_client(
                multiprocess_mode=True,
                asyncio_mode=True,
                vllm_config=vllm_config,
                executor_class=executor_class,
                log_stats=True,
            )

        try:
            MAX_TOKENS = 20
            params = SamplingParams(max_tokens=MAX_TOKENS)
            """Normal Request Cycle."""

            requests = [make_request(params) for _ in range(10)]
            request_ids = [req.request_id for req in requests]

            # Add requests to the engine.
            for request in requests:
                await client.add_request_async(request)
                await asyncio.sleep(0.01)

            outputs: dict[str, list] = {req_id: [] for req_id in request_ids}
            await loop_until_done_async(client, outputs)

            # Every request should run to the max_tokens limit.
            for req_id in request_ids:
                assert len(outputs[req_id]) == MAX_TOKENS, (
                    f"{outputs[req_id]=}, {MAX_TOKENS=}"
                )
            """Abort Request Cycle."""

            # Add requests to the engine.
            for idx, request in enumerate(requests):
                await client.add_request_async(request)
                await asyncio.sleep(0.01)
                if idx % 2 == 0:
                    await client.abort_requests_async([request.request_id])

            outputs = {req_id: [] for req_id in request_ids}
            await loop_until_done_async(client, outputs)

            # Aborted (even-index) requests stop early; the rest complete.
            for idx, req_id in enumerate(request_ids):
                if idx % 2 == 0:
                    assert len(outputs[req_id]) < MAX_TOKENS, (
                        f"{len(outputs[req_id])=}, {MAX_TOKENS=}"
                    )
                else:
                    assert len(outputs[req_id]) == MAX_TOKENS, (
                        f"{len(outputs[req_id])=}, {MAX_TOKENS=}"
                    )
            """Utility method invocation"""

            core_client: AsyncMPClient = client

            result = await core_client.call_utility_async("echo", "testarg")
            assert result == "testarg"

            # A raising utility method should surface as a client-side error.
            with pytest.raises(Exception) as e_info:
                await core_client.call_utility_async("echo", None, "help!")

            assert str(e_info.value) == "Call to echo method failed: help!"

            # Test that cancelling the utility call doesn't destabilize the
            # engine.
            util_task = asyncio.create_task(
                core_client.call_utility_async("echo", "testarg2", None, 0.5)
            )  # sleep for 0.5 sec
            await asyncio.sleep(0.05)
            cancelled = util_task.cancel()
            assert cancelled

            # Ensure client is still functional. The engine runs utility
            # methods in a single thread so this request won't be processed
            # until the cancelled sleeping one is complete.
            result = await asyncio.wait_for(
                core_client.call_utility_async("echo", "testarg3"), timeout=1.0
            )
            assert result == "testarg3"
        finally:
            client.shutdown()
|
||||
|
||||
|
||||
@dataclass
class MyDataclass:
    # Minimal custom payload type used to verify that engine utility
    # methods can return non-msgspec-native types across the
    # client/engine-core serialization boundary.
    message: str
|
||||
|
||||
|
||||
# Dummy utility function to monkey-patch into engine core.
def echo_dc(
    self,
    msg: str,
    return_list: bool = False,
) -> MyDataclass | list[MyDataclass]:
    """Echo `msg` wrapped in MyDataclass, optionally as a list of three.

    Returning a custom dataclass exercises the special handling that makes
    non-native return types work with msgspec; a None msg yields None.
    """
    print(f"echo dc util function called: {msg}")
    wrapped = MyDataclass(msg) if msg is not None else None
    if return_list:
        return [wrapped] * 3
    return wrapped
|
||||
|
||||
|
||||
# Dummy utility function to test dict serialization with custom types.
def echo_dc_dict(
    self,
    msg: str,
    return_dict: bool = False,
) -> MyDataclass | dict[str, MyDataclass]:
    """Echo `msg` wrapped in MyDataclass, optionally as a three-key dict.

    Verifies serialization of dicts whose values are custom types;
    a None msg yields None values.
    """
    print(f"echo dc dict util function called: {msg}")
    wrapped = MyDataclass(msg) if msg is not None else None
    if not return_dict:
        return wrapped
    return {key: wrapped for key in ("key1", "key2", "key3")}
|
||||
|
||||
|
||||
# Dummy utility function to test nested structures with custom types.
def echo_dc_nested(
    self,
    msg: str,
    structure_type: str = "list_of_dicts",
) -> Any:
    """Echo `msg` wrapped in MyDataclass inside one of several nested
    container shapes, exercising recursive serialization of custom types.

    Unrecognized `structure_type` values fall through to the bare value.
    """
    print(f"echo dc nested util function called: {msg}, structure: {structure_type}")
    wrapped = None if msg is None else MyDataclass(msg)

    if structure_type == "list_of_dicts":  # noqa
        # List of dicts: [{"a": val, "b": val}, {"c": val, "d": val}]
        return [{"a": wrapped, "b": wrapped}, {"c": wrapped, "d": wrapped}]
    if structure_type == "dict_of_lists":
        # Dict of lists: {"list1": [val, val], "list2": [val, val]}
        return {"list1": [wrapped, wrapped], "list2": [wrapped, wrapped]}
    if structure_type == "deep_nested":
        # Deeply nested: {"outer": [{"inner": [val, val]}, {"inner": [val]}]}
        return {"outer": [{"inner": [wrapped, wrapped]}, {"inner": [wrapped]}]}
    return wrapped
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="function")
async def test_engine_core_client_util_method_custom_return(
    monkeypatch: pytest.MonkeyPatch,
):
    """Verify utility methods can return a custom dataclass, a list of
    them, and None / list-of-None across the serialization boundary."""
    with monkeypatch.context() as m:
        # Must set insecure serialization to allow returning custom types.
        m.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1")

        # Monkey-patch core engine utility function to test.
        m.setattr(EngineCore, "echo_dc", echo_dc, raising=False)

        engine_args = EngineArgs(model=MODEL_NAME, enforce_eager=True)
        vllm_config = engine_args.create_engine_config(
            usage_context=UsageContext.UNKNOWN_CONTEXT
        )
        executor_class = Executor.get_class(vllm_config)

        with set_default_torch_num_threads(1):
            client = EngineCoreClient.make_client(
                multiprocess_mode=True,
                asyncio_mode=True,
                vllm_config=vllm_config,
                executor_class=executor_class,
                log_stats=True,
            )

        try:
            # Test utility method returning custom / non-native data type.
            core_client: AsyncMPClient = client

            result = await core_client.call_utility_async("echo_dc", "testarg2", False)
            assert isinstance(result, MyDataclass) and result.message == "testarg2"
            result = await core_client.call_utility_async("echo_dc", "testarg2", True)
            assert isinstance(result, list) and all(
                isinstance(r, MyDataclass) and r.message == "testarg2" for r in result
            )

            # Test returning None and list of Nones
            result = await core_client.call_utility_async("echo_dc", None, False)
            assert result is None
            result = await core_client.call_utility_async("echo_dc", None, True)
            assert isinstance(result, list) and all(r is None for r in result)

        finally:
            client.shutdown()
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="function")
async def test_engine_core_client_util_method_custom_dict_return(
    monkeypatch: pytest.MonkeyPatch,
):
    """Verify utility methods can return dicts whose values are a custom
    dataclass (or None) across the serialization boundary."""
    with monkeypatch.context() as m:
        # Must set insecure serialization to allow returning custom types.
        m.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1")

        # Monkey-patch core engine utility function to test.
        m.setattr(EngineCore, "echo_dc_dict", echo_dc_dict, raising=False)

        engine_args = EngineArgs(model=MODEL_NAME, enforce_eager=True)
        vllm_config = engine_args.create_engine_config(
            usage_context=UsageContext.UNKNOWN_CONTEXT
        )
        executor_class = Executor.get_class(vllm_config)

        with set_default_torch_num_threads(1):
            client = EngineCoreClient.make_client(
                multiprocess_mode=True,
                asyncio_mode=True,
                vllm_config=vllm_config,
                executor_class=executor_class,
                log_stats=True,
            )

        try:
            # Test utility method returning custom / non-native data type.
            core_client: AsyncMPClient = client

            # Test single object return
            result = await core_client.call_utility_async(
                "echo_dc_dict", "testarg3", False
            )
            assert isinstance(result, MyDataclass) and result.message == "testarg3"

            # Test dict return with custom value types
            result = await core_client.call_utility_async(
                "echo_dc_dict", "testarg3", True
            )
            assert isinstance(result, dict) and len(result) == 3
            for key, val in result.items():
                assert key in ["key1", "key2", "key3"]
                assert isinstance(val, MyDataclass) and val.message == "testarg3"

            # Test returning dict with None values
            result = await core_client.call_utility_async("echo_dc_dict", None, True)
            assert isinstance(result, dict) and len(result) == 3
            for key, val in result.items():
                assert key in ["key1", "key2", "key3"]
                assert val is None

        finally:
            client.shutdown()
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="function")
async def test_engine_core_client_util_method_nested_structures(
    monkeypatch: pytest.MonkeyPatch,
):
    """Verify utility-method serialization of custom types nested inside
    lists of dicts, dicts of lists, deeply nested containers, and nested
    structures holding None values."""
    with monkeypatch.context() as m:
        # Must set insecure serialization to allow returning custom types.
        m.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1")

        # Monkey-patch core engine utility function to test.
        m.setattr(EngineCore, "echo_dc_nested", echo_dc_nested, raising=False)

        engine_args = EngineArgs(model=MODEL_NAME, enforce_eager=True)
        vllm_config = engine_args.create_engine_config(
            usage_context=UsageContext.UNKNOWN_CONTEXT
        )
        executor_class = Executor.get_class(vllm_config)

        with set_default_torch_num_threads(1):
            client = EngineCoreClient.make_client(
                multiprocess_mode=True,
                asyncio_mode=True,
                vllm_config=vllm_config,
                executor_class=executor_class,
                log_stats=True,
            )

        try:
            core_client: AsyncMPClient = client

            # Test list of dicts: [{"a": val, "b": val}, {"c": val, "d": val}]
            result = await core_client.call_utility_async(
                "echo_dc_nested", "nested1", "list_of_dicts"
            )
            assert isinstance(result, list) and len(result) == 2
            for i, item in enumerate(result):
                assert isinstance(item, dict)
                if i == 0:
                    assert "a" in item and "b" in item
                    assert (
                        isinstance(item["a"], MyDataclass)
                        and item["a"].message == "nested1"
                    )
                    assert (
                        isinstance(item["b"], MyDataclass)
                        and item["b"].message == "nested1"
                    )
                else:
                    assert "c" in item and "d" in item
                    assert (
                        isinstance(item["c"], MyDataclass)
                        and item["c"].message == "nested1"
                    )
                    assert (
                        isinstance(item["d"], MyDataclass)
                        and item["d"].message == "nested1"
                    )

            # Test dict of lists: {"list1": [val, val], "list2": [val, val]}
            result = await core_client.call_utility_async(
                "echo_dc_nested", "nested2", "dict_of_lists"
            )
            assert isinstance(result, dict) and len(result) == 2
            assert "list1" in result and "list2" in result
            for key, lst in result.items():
                assert isinstance(lst, list) and len(lst) == 2
                for item in lst:
                    assert isinstance(item, MyDataclass) and item.message == "nested2"

            # Test deeply nested: {"outer": [{"inner": [val, val]},
            # {"inner": [val]}]}
            result = await core_client.call_utility_async(
                "echo_dc_nested", "nested3", "deep_nested"
            )
            assert isinstance(result, dict) and "outer" in result
            outer_list = result["outer"]
            assert isinstance(outer_list, list) and len(outer_list) == 2

            # First dict in outer list should have "inner" with 2 items
            inner_dict1 = outer_list[0]
            assert isinstance(inner_dict1, dict) and "inner" in inner_dict1
            inner_list1 = inner_dict1["inner"]
            assert isinstance(inner_list1, list) and len(inner_list1) == 2
            for item in inner_list1:
                assert isinstance(item, MyDataclass) and item.message == "nested3"

            # Second dict in outer list should have "inner" with 1 item
            inner_dict2 = outer_list[1]
            assert isinstance(inner_dict2, dict) and "inner" in inner_dict2
            inner_list2 = inner_dict2["inner"]
            assert isinstance(inner_list2, list) and len(inner_list2) == 1
            assert (
                isinstance(inner_list2[0], MyDataclass)
                and inner_list2[0].message == "nested3"
            )

            # Test with None values in nested structures
            result = await core_client.call_utility_async(
                "echo_dc_nested", None, "list_of_dicts"
            )
            assert isinstance(result, list) and len(result) == 2
            for item in result:
                assert isinstance(item, dict)
                for val in item.values():
                    assert val is None

        finally:
            client.shutdown()
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    "multiprocessing_mode,publisher_config",
    [(True, "tcp"), (False, "inproc")],
    indirect=["publisher_config"],
)
def test_kv_cache_events(
    multiprocessing_mode: bool,
    publisher_config,
):
    """End-to-end check that prefix-cache block storage emits a single
    BlockStored KV event carrying the expected hashes, block size and
    token ids, in both multiprocess (tcp) and in-process (inproc) modes."""
    block_size = 16
    num_blocks = 2

    engine_args = EngineArgs(
        model=MODEL_NAME,
        enforce_eager=True,
        enable_prefix_caching=True,
        block_size=block_size,
    )
    engine_args.kv_events_config = publisher_config

    vllm_config = engine_args.create_engine_config(UsageContext.UNKNOWN_CONTEXT)

    executor_class = Executor.get_class(vllm_config)
    with set_default_torch_num_threads(1):
        client = EngineCoreClient.make_client(
            multiprocess_mode=multiprocessing_mode,
            asyncio_mode=False,
            vllm_config=vllm_config,
            executor_class=executor_class,
            log_stats=False,
        )
    # Subscribe to the publisher's endpoint (wildcard bind -> loopback).
    endpoint = publisher_config.endpoint.replace("*", "127.0.0.1")
    subscriber = MockSubscriber(
        endpoint, topic=publisher_config.topic, decode_type=KVEventBatch
    )

    try:
        # Prompt exactly fills `num_blocks` KV blocks.
        custom_tokens = list(range(num_blocks * block_size))
        sampling_params = SamplingParams(max_tokens=1)
        request = make_request(sampling_params, custom_tokens)
        client.add_request(request)

        outputs: dict[str, list] = {request.request_id: []}
        loop_until_done(client, outputs)

        result = subscriber.receive_one(timeout=1000)
        assert result is not None, "No message received"

        seq, received = result

        assert seq == 0, "Sequence number mismatch"
        assert len(received.events) == 1, "We should have exactly one BlockStored event"
        event = received.events[0]
        assert isinstance(event, BlockStored), "We should have a BlockStored event"
        # Fix: messages below were previously hard-coded/copy-pasted;
        # derive them from the parameterized values instead.
        assert len(event.block_hashes) == num_blocks, (
            f"BlockStored event should have {num_blocks} block_hashes, "
            f"got {len(event.block_hashes)}"
        )
        assert event.block_size == block_size, (
            "Block size should be the same as the block size"
        )
        assert event.parent_block_hash is None, "Parent block hash should be None"
        assert event.lora_id is None, "Lora id should be None"
        assert len(event.token_ids) == num_blocks * block_size, (
            f"Expected {num_blocks * block_size} token ids, "
            f"got {len(event.token_ids)}"
        )
        assert event.token_ids == custom_tokens, (
            "Token ids should be the same as the custom tokens"
        )
    finally:
        client.shutdown()
        subscriber.close()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
@pytest.mark.parametrize(
    "multiprocessing_mode,publisher_config",
    [(True, "tcp")],
    indirect=["publisher_config"],
)
@multi_gpu_test(num_gpus=4)
async def test_kv_cache_events_dp(
    multiprocessing_mode: bool,
    publisher_config,
):
    """With DP=2 x TP=2, verify KV cache events are published by both data
    parallel ranks, each on its own port-offset ZMQ endpoint."""
    block_size = 16
    num_blocks = 2
    dp_size = 2
    tp_size = 2

    engine_args = EngineArgs(
        model=MODEL_NAME,
        enforce_eager=True,
        enable_prefix_caching=True,
        data_parallel_size=dp_size,
        tensor_parallel_size=tp_size,
        block_size=block_size,
    )
    engine_args.kv_events_config = publisher_config

    vllm_config = engine_args.create_engine_config(UsageContext.UNKNOWN_CONTEXT)

    executor_class = Executor.get_class(vllm_config)
    with set_default_torch_num_threads(1):
        client = EngineCoreClient.make_client(
            multiprocess_mode=multiprocessing_mode,
            asyncio_mode=True,
            vllm_config=vllm_config,
            executor_class=executor_class,
            log_stats=False,
        )
    # Give the engine processes a moment to come up before subscribing.
    await asyncio.sleep(1)

    # Build endpoints for all DP ranks
    base_endpoint = publisher_config.endpoint.replace("*", "127.0.0.1")
    endpoints = []
    for i in range(dp_size):
        offset_endpoint = ZmqEventPublisher.offset_endpoint_port(base_endpoint, i)
        endpoints.append(offset_endpoint)

    subscriber = MockSubscriber(
        endpoints, topic=publisher_config.topic, decode_type=KVEventBatch
    )

    try:
        custom_tokens = list(range(num_blocks * block_size))
        sampling_params = SamplingParams(max_tokens=1)
        all_request_ids = []

        # Create and add 25 requests
        # NOTE: attempts to force routing to both dp groups but can be flaky
        for i in range(25):
            await asyncio.sleep(0.01)
            request = make_request(sampling_params, custom_tokens)
            await client.add_request_async(request)
            all_request_ids.append(request.request_id)

        await asyncio.sleep(0.1)

        # Initialize outputs dict for all requests
        outputs: dict[str, list] = {req_id: [] for req_id in all_request_ids}

        print("processing requests...")
        await asyncio.wait_for(
            loop_until_fully_done_async(client, outputs), timeout=20.0
        )

        # Receive from subscriber until no more messages
        print("collecting results...")
        results = []
        while True:
            result = subscriber.receive_one(timeout=1)
            print(result)
            if result is None:
                break
            results.append(result)

        # Collect all events and data_parallel_ranks from all results
        all_dp_ranks = [received.data_parallel_rank for (_, received) in results]
        unique_dps = set(all_dp_ranks)
        assert len(unique_dps) == 2, (
            f"Expected 2 unique data_parallel_ranks, got {len(unique_dps)}"
        )

    finally:
        client.shutdown()
        subscriber.close()
|
||||
|
||||
|
||||
@pytest.mark.timeout(20)
def test_startup_failure(monkeypatch: pytest.MonkeyPatch):
    """Kill the engine core process while it is starting and verify the
    client raises 'Engine core initialization failed' rather than hanging."""
    with monkeypatch.context() as m, pytest.raises(Exception) as e_info:
        # Monkey-patch to extract core process pid while it's starting.
        core_proc_pid = [None]
        cepm_ctor = CoreEngineProcManager.__init__

        def patched_cepm_ctor(self: CoreEngineProcManager, *args, **kwargs):
            cepm_ctor(self, *args, **kwargs)
            core_proc_pid[0] = self.processes[0].pid

        m.setattr(CoreEngineProcManager, "__init__", patched_cepm_ctor)

        t = time.time()
        engine_args = EngineArgs(model=MODEL_NAME)
        vllm_config = engine_args.create_engine_config(
            usage_context=UsageContext.UNKNOWN_CONTEXT
        )
        executor_class = Executor.get_class(vllm_config)
        print(f"VllmConfig creation took {time.time() - t:.2f} seconds.")

        # Start another thread to wait for engine core process to start
        # and kill it - simulate fatal uncaught process exit.

        def kill_first_child():
            # Poll until the patched ctor has published the child pid.
            while (child_pid := core_proc_pid[0]) is None:
                time.sleep(0.5)
            print(f"Killing child core process {child_pid}")
            assert isinstance(child_pid, int)
            os.kill(child_pid, signal.SIGKILL)

        Thread(target=kill_first_child, daemon=True).start()

        # Expected to raise once the killed child fails initialization.
        _core_client = EngineCoreClient.make_client(
            multiprocess_mode=True,
            asyncio_mode=True,
            vllm_config=vllm_config,
            executor_class=executor_class,
            log_stats=True,
        )

    assert "Engine core initialization failed" in str(e_info.value)
|
||||
|
||||
|
||||
@create_new_process_for_each_test()
def test_engine_core_proc_instantiation_cuda_empty(monkeypatch: pytest.MonkeyPatch):
    """
    Test that EngineCoreProc can be instantiated when CUDA_VISIBLE_DEVICES
    is empty. This ensures the engine frontend does not need access to GPUs.
    """

    from vllm.v1.engine.core import EngineCoreProc
    from vllm.v1.executor.abstract import Executor

    # Create a simple mock executor instead of a complex custom class
    mock_executor_class = MagicMock(spec=Executor)

    def create_mock_executor(vllm_config):
        # Per-call executor mock exposing only what EngineCoreProc init
        # actually touches.
        mock_executor = MagicMock()

        # Only implement the methods that are actually called during init
        from vllm.v1.kv_cache_interface import FullAttentionSpec

        mock_spec = FullAttentionSpec(
            block_size=16, num_kv_heads=1, head_size=64, dtype=torch.float16
        )

        mock_executor.get_kv_cache_specs.return_value = [{"default": mock_spec}]
        mock_executor.determine_available_memory.return_value = [1024 * 1024 * 1024]
        mock_executor.initialize_from_config.return_value = None
        mock_executor.max_concurrent_batches = 1

        return mock_executor

    mock_executor_class.side_effect = create_mock_executor

    with monkeypatch.context() as m:
        m.setenv("CUDA_VISIBLE_DEVICES", "")  # No CUDA devices

        from vllm.v1.engine.utils import EngineZmqAddresses

        def mock_startup_handshake(
            self, handshake_socket, local_client, headless, parallel_config
        ):
            # Skip the real ZMQ handshake; return fixed loopback addresses.
            return EngineZmqAddresses(
                inputs=["tcp://127.0.0.1:5555"],
                outputs=["tcp://127.0.0.1:5556"],
                coordinator_input=None,
                coordinator_output=None,
            )

        # Background processes are not important here
        m.setattr(EngineCoreProc, "startup_handshake", mock_startup_handshake)

        vllm_config = EngineArgs(
            model="deepseek-ai/DeepSeek-V2-Lite", trust_remote_code=True
        ).create_engine_config()
        engine_core_proc = EngineCoreProc(
            vllm_config=vllm_config,
            local_client=True,
            handshake_address="tcp://127.0.0.1:12345",
            executor_class=mock_executor_class,
            log_stats=False,
            engine_index=0,
        )

        engine_core_proc.shutdown()
|
||||
196
tests/v1/engine/test_fast_incdec_prefix_err.py
Normal file
196
tests/v1/engine/test_fast_incdec_prefix_err.py
Normal file
@@ -0,0 +1,196 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
from vllm.sampling_params import SamplingParams
|
||||
from vllm.v1.engine import EngineCoreRequest
|
||||
from vllm.v1.engine.detokenizer import IncrementalDetokenizer
|
||||
|
||||
# ruff: noqa: E501
|
||||
|
||||
|
||||
def test_fast_inc_detok_invalid_utf8_err_case():
|
||||
"""
|
||||
Test edge case where tokenizer can produce non-monotonic,
|
||||
invalid UTF-8 output, which breaks the internal state of
|
||||
tokenizers' DecodeStream.
|
||||
See https://github.com/vllm-project/vllm/issues/17448.
|
||||
|
||||
Thanks to reproducer from @fpaupier:
|
||||
https://gist.github.com/fpaupier/0ed1375bd7633c5be6c894b1c7ac1be3.
|
||||
"""
|
||||
tokenizer = AutoTokenizer.from_pretrained("google/gemma-3-1b-it")
|
||||
|
||||
# Create a test request
|
||||
prompt_token_ids = [107, 4606, 236787, 107]
|
||||
params = SamplingParams(skip_special_tokens=True)
|
||||
request = EngineCoreRequest(
|
||||
request_id="test",
|
||||
prompt_token_ids=prompt_token_ids,
|
||||
mm_features=None,
|
||||
sampling_params=params,
|
||||
pooling_params=None,
|
||||
eos_token_id=None,
|
||||
arrival_time=0.0,
|
||||
lora_request=None,
|
||||
cache_salt=None,
|
||||
data_parallel_rank=None,
|
||||
)
|
||||
|
||||
detokenizer = IncrementalDetokenizer.from_new_request(tokenizer, request)
|
||||
|
||||
assert detokenizer.__class__.__name__ == "FastIncrementalDetokenizer", (
|
||||
"Should use FastIncrementalDetokenizer by default"
|
||||
)
|
||||
|
||||
# Process tokens incrementally
|
||||
test_tokens = [
|
||||
236840,
|
||||
107,
|
||||
138,
|
||||
236782,
|
||||
107,
|
||||
140,
|
||||
236775,
|
||||
6265,
|
||||
1083,
|
||||
623,
|
||||
121908,
|
||||
147418,
|
||||
827,
|
||||
107,
|
||||
140,
|
||||
236775,
|
||||
6265,
|
||||
236779,
|
||||
2084,
|
||||
1083,
|
||||
623,
|
||||
203292,
|
||||
827,
|
||||
107,
|
||||
140,
|
||||
236775,
|
||||
6265,
|
||||
236779,
|
||||
7777,
|
||||
1083,
|
||||
623,
|
||||
121908,
|
||||
147418,
|
||||
569,
|
||||
537,
|
||||
236789,
|
||||
65880,
|
||||
569,
|
||||
537,
|
||||
236789,
|
||||
62580,
|
||||
853,
|
||||
115693,
|
||||
210118,
|
||||
35178,
|
||||
16055,
|
||||
1270,
|
||||
759,
|
||||
215817,
|
||||
4758,
|
||||
1925,
|
||||
1117,
|
||||
827,
|
||||
107,
|
||||
140,
|
||||
236775,
|
||||
5654,
|
||||
1083,
|
||||
623,
|
||||
110733,
|
||||
46291,
|
||||
827,
|
||||
107,
|
||||
140,
|
||||
236775,
|
||||
5654,
|
||||
236779,
|
||||
2084,
|
||||
1083,
|
||||
623,
|
||||
136955,
|
||||
56731,
|
||||
827,
|
||||
107,
|
||||
140,
|
||||
236775,
|
||||
5654,
|
||||
236779,
|
||||
7777,
|
||||
1083,
|
||||
623,
|
||||
194776,
|
||||
2947,
|
||||
496,
|
||||
109811,
|
||||
1608,
|
||||
890,
|
||||
215817,
|
||||
4758,
|
||||
1925,
|
||||
1117,
|
||||
2789,
|
||||
432,
|
||||
398,
|
||||
602,
|
||||
31118,
|
||||
569,
|
||||
124866,
|
||||
134772,
|
||||
509,
|
||||
19478,
|
||||
1640,
|
||||
33779,
|
||||
236743,
|
||||
236770,
|
||||
236819,
|
||||
236825,
|
||||
236771,
|
||||
432,
|
||||
398,
|
||||
432,
|
||||
237167,
|
||||
827,
|
||||
107,
|
||||
140,
|
||||
236775,
|
||||
77984,
|
||||
1083,
|
||||
623,
|
||||
2709,
|
||||
236745,
|
||||
2555,
|
||||
513,
|
||||
236789,
|
||||
602,
|
||||
31118,
|
||||
569,
|
||||
]
|
||||
|
||||
output = ""
|
||||
for i, token_id in enumerate(test_tokens):
|
||||
detokenizer.update([token_id], False)
|
||||
|
||||
finished = i == len(test_tokens) - 1
|
||||
output += detokenizer.get_next_output_text(finished, delta=True)
|
||||
|
||||
assert (
|
||||
output
|
||||
== r"""[
|
||||
{
|
||||
"source": "Résultats",
|
||||
"source_type": "CONCEPT",
|
||||
"source_description": "Résultats de l'analyse de l'impact des opérations israéliennes sur la frontière libanaise",
|
||||
"target": "Israël",
|
||||
"target_type": "ORGANIZATION",
|
||||
"target_description": "Pays qui a obtenu à sa frontière libanaise « un niveau de calme inédit depuis les années 1960 »",
|
||||
"relationship": "Obtention d'un niveau de"""
|
||||
)
|
||||
54
tests/v1/engine/test_init_error_messaging.py
Normal file
54
tests/v1/engine/test_init_error_messaging.py
Normal file
@@ -0,0 +1,54 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import pytest
|
||||
|
||||
from vllm.v1.core.kv_cache_utils import check_enough_kv_cache_memory
|
||||
from vllm.v1.kv_cache_interface import FullAttentionSpec
|
||||
|
||||
|
||||
def test_kv_cache_oom_no_memory():
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
config = MagicMock()
|
||||
config.model_config.max_model_len = 2048
|
||||
|
||||
spec = {
|
||||
"layer_0": FullAttentionSpec(
|
||||
block_size=16,
|
||||
num_kv_heads=8,
|
||||
head_size=128,
|
||||
dtype="float16",
|
||||
)
|
||||
}
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
check_enough_kv_cache_memory(config, spec, 0)
|
||||
|
||||
|
||||
def test_kv_cache_oom_insufficient_memory(monkeypatch):
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
config = MagicMock()
|
||||
config.model_config.max_model_len = 2048
|
||||
config.cache_config.block_size = 16
|
||||
config.parallel_config.tensor_parallel_size = 1
|
||||
config.parallel_config.pipeline_parallel_size = 1
|
||||
config.parallel_config.decode_context_parallel_size = 1
|
||||
|
||||
monkeypatch.setattr(
|
||||
"vllm.v1.core.kv_cache_utils.max_memory_usage_bytes",
|
||||
lambda c, s: 100 * 1024**3, # 100 GiB
|
||||
)
|
||||
|
||||
spec = {
|
||||
"layer_0": FullAttentionSpec(
|
||||
block_size=16,
|
||||
num_kv_heads=8,
|
||||
head_size=128,
|
||||
dtype="float16",
|
||||
)
|
||||
}
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
check_enough_kv_cache_memory(config, spec, 1024**3) # 1 GiB
|
||||
237
tests/v1/engine/test_llm_engine.py
Normal file
237
tests/v1/engine/test_llm_engine.py
Normal file
@@ -0,0 +1,237 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import random
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import pytest
|
||||
|
||||
from vllm import LLM
|
||||
from vllm.sampling_params import SamplingParams, StructuredOutputsParams
|
||||
from vllm.v1.metrics.reader import Counter, Gauge, Histogram, Metric, Vector
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from tests.conftest import VllmRunner
|
||||
else:
|
||||
VllmRunner = object
|
||||
|
||||
MODEL = "facebook/opt-125m"
|
||||
DTYPE = "half"
|
||||
|
||||
|
||||
def _vllm_model(
|
||||
apc: bool,
|
||||
vllm_runner: type[VllmRunner],
|
||||
*,
|
||||
skip_tokenizer_init: bool = False,
|
||||
):
|
||||
"""Set up VllmRunner instance."""
|
||||
return vllm_runner(
|
||||
MODEL,
|
||||
dtype=DTYPE,
|
||||
max_model_len=128,
|
||||
enforce_eager=True,
|
||||
enable_prefix_caching=apc,
|
||||
gpu_memory_utilization=0.5,
|
||||
skip_tokenizer_init=skip_tokenizer_init,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture(
|
||||
# Function scope decouples tests & allows
|
||||
# env var adjustment via monkeypatch
|
||||
scope="function",
|
||||
# Prefix caching
|
||||
params=[False, True],
|
||||
)
|
||||
def vllm_model(vllm_runner, request):
|
||||
"""VllmRunner test fixture parameterized by APC True/False."""
|
||||
with _vllm_model(request.param, vllm_runner) as vllm_model:
|
||||
yield vllm_model
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def vllm_model_apc(vllm_runner):
|
||||
"""VllmRunner test fixture with APC."""
|
||||
with _vllm_model(True, vllm_runner) as vllm_model:
|
||||
yield vllm_model
|
||||
|
||||
|
||||
@pytest.fixture(
|
||||
# Function scope decouples tests & allows
|
||||
# env var adjustment via monkeypatch
|
||||
scope="function",
|
||||
# Prefix caching
|
||||
params=[False, True],
|
||||
)
|
||||
def vllm_model_skip_tokenizer_init(vllm_runner, request):
|
||||
"""VllmRunner test fixture with APC."""
|
||||
with _vllm_model(
|
||||
request.param,
|
||||
vllm_runner,
|
||||
skip_tokenizer_init=True,
|
||||
) as vllm_model:
|
||||
yield vllm_model
|
||||
|
||||
|
||||
def _get_test_sampling_params(
|
||||
prompt_list: list[str],
|
||||
seed: int | None = 42,
|
||||
structured_outputs: bool = False,
|
||||
) -> tuple[list[SamplingParams], list[int]]:
|
||||
"""Generate random sampling params for a batch."""
|
||||
|
||||
def get_mostly_n_gt1() -> int:
|
||||
r"""Mostly n \in [2,20], ~1/3 n=1"""
|
||||
x = random.randint(0, 28)
|
||||
if x < 10:
|
||||
return 1
|
||||
else:
|
||||
return x - 8
|
||||
|
||||
n_list = [get_mostly_n_gt1() for _ in range(len(prompt_list))]
|
||||
# High temperature to maximize the chance of unique completions
|
||||
return [
|
||||
SamplingParams(
|
||||
temperature=0.95,
|
||||
top_p=0.95,
|
||||
n=n,
|
||||
seed=seed,
|
||||
structured_outputs=StructuredOutputsParams(regex="[0-9]+")
|
||||
if structured_outputs
|
||||
else None,
|
||||
)
|
||||
for n in n_list
|
||||
], n_list
|
||||
|
||||
|
||||
def test_compatibility_with_skip_tokenizer_init(
|
||||
vllm_model_skip_tokenizer_init: VllmRunner,
|
||||
example_prompts: list[str],
|
||||
):
|
||||
# Case 1: Structured output request should raise an error.
|
||||
sampling_params_list, _ = _get_test_sampling_params(
|
||||
example_prompts,
|
||||
structured_outputs=True,
|
||||
)
|
||||
llm: LLM = vllm_model_skip_tokenizer_init.llm
|
||||
with pytest.raises(ValueError):
|
||||
_ = llm.generate(example_prompts, sampling_params_list)
|
||||
|
||||
|
||||
def test_parallel_sampling(vllm_model, example_prompts) -> None:
|
||||
"""Test passes if parallel sampling `n>1` yields `n` unique completions.
|
||||
|
||||
Args:
|
||||
vllm_model: VllmRunner instance under test.
|
||||
example_prompt: test fixture providing prompts for testing.
|
||||
"""
|
||||
sampling_params_list, n_list = _get_test_sampling_params(example_prompts)
|
||||
llm: LLM = vllm_model.llm
|
||||
outputs = llm.generate(example_prompts, sampling_params_list)
|
||||
|
||||
# Validate each request response
|
||||
for out, n in zip(outputs, n_list):
|
||||
completion_counts: dict[str, int] = {}
|
||||
# Assert correct number of completions
|
||||
assert len(out.outputs) == n, f"{len(out.outputs)} completions; {n} expected."
|
||||
for idx in range(n):
|
||||
comp = out.outputs[idx]
|
||||
# Assert correct completion indices
|
||||
assert comp.index == idx, f"Index {comp.index}; expected {idx}."
|
||||
text = comp.text
|
||||
completion_counts[text] = completion_counts.get(text, 0) + 1
|
||||
# Assert unique completions
|
||||
if len(completion_counts) != n:
|
||||
repeats = {txt: num for (txt, num) in completion_counts.items() if num > 1}
|
||||
raise AssertionError(
|
||||
f"{len(completion_counts)} unique completions; expected"
|
||||
f" {n}. Repeats: {repeats}"
|
||||
)
|
||||
|
||||
|
||||
def test_engine_metrics(vllm_runner, example_prompts):
|
||||
max_tokens = 100
|
||||
# Use spec decoding to test num_accepted_tokens_per_pos
|
||||
speculative_config = {
|
||||
"method": "ngram",
|
||||
"prompt_lookup_max": 5,
|
||||
"prompt_lookup_min": 3,
|
||||
"num_speculative_tokens": 5,
|
||||
}
|
||||
|
||||
with vllm_runner(
|
||||
MODEL,
|
||||
speculative_config=speculative_config,
|
||||
disable_log_stats=False,
|
||||
) as vllm_model:
|
||||
llm: LLM = vllm_model.llm
|
||||
sampling_params = SamplingParams(temperature=0.0, max_tokens=max_tokens)
|
||||
outputs = llm.generate(example_prompts, sampling_params)
|
||||
|
||||
n_prompts = len(example_prompts)
|
||||
assert len(outputs) == n_prompts
|
||||
|
||||
total_tokens = 0
|
||||
for out in outputs:
|
||||
assert len(out.outputs) == 1
|
||||
total_tokens += len(out.outputs[0].token_ids)
|
||||
assert total_tokens == max_tokens * n_prompts
|
||||
|
||||
metrics = llm.get_metrics()
|
||||
|
||||
def find_metric(name) -> list[Metric]:
|
||||
found = []
|
||||
for metric in metrics:
|
||||
if metric.name == name:
|
||||
found.append(metric)
|
||||
return found
|
||||
|
||||
num_requests_running = find_metric("vllm:num_requests_running")
|
||||
assert len(num_requests_running) == 1
|
||||
assert isinstance(num_requests_running[0], Gauge)
|
||||
assert num_requests_running[0].value == 0.0
|
||||
|
||||
generation_tokens = find_metric("vllm:generation_tokens")
|
||||
assert len(generation_tokens) == 1
|
||||
assert isinstance(generation_tokens[0], Counter)
|
||||
assert generation_tokens[0].value == total_tokens
|
||||
|
||||
request_generation_tokens = find_metric("vllm:request_generation_tokens")
|
||||
assert len(request_generation_tokens) == 1
|
||||
assert isinstance(request_generation_tokens[0], Histogram)
|
||||
assert "+Inf" in request_generation_tokens[0].buckets
|
||||
assert request_generation_tokens[0].buckets["+Inf"] == n_prompts
|
||||
assert request_generation_tokens[0].count == n_prompts
|
||||
assert request_generation_tokens[0].sum == total_tokens
|
||||
|
||||
num_accepted_tokens_per_pos = find_metric(
|
||||
"vllm:spec_decode_num_accepted_tokens_per_pos"
|
||||
)
|
||||
assert len(num_accepted_tokens_per_pos) == 1
|
||||
assert isinstance(num_accepted_tokens_per_pos[0], Vector)
|
||||
assert len(num_accepted_tokens_per_pos[0].values) == 5
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model", ["meta-llama/Llama-3.2-1B-Instruct"])
|
||||
def test_skip_tokenizer_initialization(model: str):
|
||||
# This test checks if the flag skip_tokenizer_init skips the initialization
|
||||
# of tokenizer and detokenizer. The generated output is expected to contain
|
||||
# token ids.
|
||||
llm = LLM(
|
||||
model=model,
|
||||
skip_tokenizer_init=True,
|
||||
enforce_eager=True,
|
||||
)
|
||||
sampling_params = SamplingParams(prompt_logprobs=True, detokenize=True)
|
||||
|
||||
with pytest.raises(ValueError, match="cannot pass text prompts when"):
|
||||
llm.generate("abc", sampling_params)
|
||||
|
||||
outputs = llm.generate(
|
||||
{"prompt_token_ids": [1, 2, 3]}, sampling_params=sampling_params
|
||||
)
|
||||
assert len(outputs) > 0
|
||||
completions = outputs[0].outputs
|
||||
assert len(completions) > 0
|
||||
assert completions[0].text == ""
|
||||
assert completions[0].token_ids
|
||||
1272
tests/v1/engine/test_output_processor.py
Normal file
1272
tests/v1/engine/test_output_processor.py
Normal file
File diff suppressed because it is too large
Load Diff
103
tests/v1/engine/test_parallel_sampling.py
Normal file
103
tests/v1/engine/test_parallel_sampling.py
Normal file
@@ -0,0 +1,103 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from vllm import SamplingParams
|
||||
from vllm.outputs import CompletionOutput
|
||||
from vllm.sampling_params import RequestOutputKind
|
||||
from vllm.v1.engine.parallel_sampling import ParentRequest
|
||||
|
||||
|
||||
def test_parent_request_to_output_stream() -> None:
|
||||
parent_request = ParentRequest("parent_id", SamplingParams(n=2))
|
||||
parent_request.child_requests = {"child_id_0", "child_id_1"}
|
||||
output_0 = CompletionOutput(
|
||||
index=0, text="child 0", token_ids=[], cumulative_logprob=None, logprobs=None
|
||||
)
|
||||
output_1 = CompletionOutput(
|
||||
index=1, text="child 1", token_ids=[], cumulative_logprob=None, logprobs=None
|
||||
)
|
||||
# Request not finished
|
||||
assert ("parent_id", [output_0], False) == parent_request.get_outputs(
|
||||
"child_id_0", output_0
|
||||
)
|
||||
assert ("parent_id", [output_1], False) == parent_request.get_outputs(
|
||||
"child_id_1", output_1
|
||||
)
|
||||
assert ("parent_id", [output_0], False) == parent_request.get_outputs(
|
||||
"child_id_0", output_0
|
||||
)
|
||||
assert ("parent_id", [output_1], False) == parent_request.get_outputs(
|
||||
"child_id_1", output_1
|
||||
)
|
||||
|
||||
# output_1 finished
|
||||
output_1.finish_reason = "ended"
|
||||
assert ("parent_id", [output_0], False) == parent_request.get_outputs(
|
||||
"child_id_0", output_0
|
||||
)
|
||||
assert ("parent_id", [output_1], False) == parent_request.get_outputs(
|
||||
"child_id_1", output_1
|
||||
)
|
||||
# Finished output_1 had already returned, DO NOT returned again
|
||||
assert ("parent_id", [output_0], False) == parent_request.get_outputs(
|
||||
"child_id_0", output_0
|
||||
)
|
||||
assert parent_request.get_outputs("child_id_1", output_1) == (
|
||||
"parent_id",
|
||||
[],
|
||||
False,
|
||||
)
|
||||
|
||||
# output_0 finished
|
||||
output_0.finish_reason = "ended"
|
||||
assert ("parent_id", [output_0], True) == parent_request.get_outputs(
|
||||
"child_id_0", output_0
|
||||
)
|
||||
assert parent_request.get_outputs("child_id_1", output_1) == ("parent_id", [], True)
|
||||
# Finished output_0 had already returned, DO NOT returned again
|
||||
assert parent_request.get_outputs("child_id_0", output_0) == ("parent_id", [], True)
|
||||
assert parent_request.get_outputs("child_id_1", output_1) == ("parent_id", [], True)
|
||||
|
||||
|
||||
def test_parent_request_to_output_final_only() -> None:
|
||||
parent_request = ParentRequest(
|
||||
"parent_id", SamplingParams(n=2, output_kind=RequestOutputKind.FINAL_ONLY)
|
||||
)
|
||||
parent_request.child_requests = {"child_id_0", "child_id_1"}
|
||||
output_0 = CompletionOutput(
|
||||
index=0, text="child 0", token_ids=[], cumulative_logprob=None, logprobs=None
|
||||
)
|
||||
output_1 = CompletionOutput(
|
||||
index=1, text="child 1", token_ids=[], cumulative_logprob=None, logprobs=None
|
||||
)
|
||||
# Request not finished, return nothing
|
||||
assert parent_request.get_outputs("child_id_0", output_0) == (
|
||||
"parent_id",
|
||||
[],
|
||||
False,
|
||||
)
|
||||
assert parent_request.get_outputs("child_id_1", output_1) == (
|
||||
"parent_id",
|
||||
[],
|
||||
False,
|
||||
)
|
||||
# output_1 finished, but outputs won't be returned until all child requests finished
|
||||
output_1.finish_reason = "ended"
|
||||
assert parent_request.get_outputs("child_id_0", output_0) == (
|
||||
"parent_id",
|
||||
[],
|
||||
False,
|
||||
)
|
||||
assert parent_request.get_outputs("child_id_1", output_1) == (
|
||||
"parent_id",
|
||||
[],
|
||||
False,
|
||||
)
|
||||
# output_0 finished, as all child requests finished, the output would be returned
|
||||
output_0.finish_reason = "ended"
|
||||
assert ("parent_id", [output_0, output_1], True) == parent_request.get_outputs(
|
||||
"child_id_0", output_0
|
||||
)
|
||||
assert ("parent_id", [output_0, output_1], True) == parent_request.get_outputs(
|
||||
"child_id_1", output_1
|
||||
)
|
||||
202
tests/v1/engine/test_process_multi_modal_uuids.py
Normal file
202
tests/v1/engine/test_process_multi_modal_uuids.py
Normal file
@@ -0,0 +1,202 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import pytest
|
||||
|
||||
from vllm.assets.image import ImageAsset
|
||||
from vllm.assets.video import VideoAsset
|
||||
from vllm.config import CacheConfig, DeviceConfig, ModelConfig, VllmConfig
|
||||
from vllm.sampling_params import SamplingParams
|
||||
from vllm.v1.engine import input_processor as input_processor_mod
|
||||
from vllm.v1.engine.input_processor import InputProcessor
|
||||
|
||||
cherry_pil_image = ImageAsset("cherry_blossom").pil_image
|
||||
stop_pil_image = ImageAsset("stop_sign").pil_image
|
||||
baby_reading_np_ndarrays = VideoAsset("baby_reading").np_ndarrays
|
||||
|
||||
|
||||
def _mock_input_processor(
|
||||
monkeypatch, *, mm_cache_gb: float = 4.0, enable_prefix_caching: bool = True
|
||||
) -> InputProcessor:
|
||||
"""
|
||||
Create a Processor instance with minimal configuration suitable for unit
|
||||
tests without accessing external resources.
|
||||
"""
|
||||
monkeypatch.setattr(
|
||||
ModelConfig, "try_get_generation_config", lambda self: {}, raising=True
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
ModelConfig, "__post_init__", lambda self, *args: None, raising=True
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
ModelConfig,
|
||||
"verify_with_parallel_config",
|
||||
lambda self, parallel_config: None,
|
||||
raising=True,
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
input_processor_mod,
|
||||
"processor_cache_from_config",
|
||||
lambda vllm_config, mm_registry: None,
|
||||
raising=True,
|
||||
)
|
||||
|
||||
monkeypatch.setattr(VllmConfig, "__post_init__", lambda self: None, raising=True)
|
||||
|
||||
model_config = ModelConfig(
|
||||
skip_tokenizer_init=True,
|
||||
max_model_len=128,
|
||||
mm_processor_cache_gb=mm_cache_gb,
|
||||
generation_config="vllm",
|
||||
tokenizer="dummy",
|
||||
)
|
||||
|
||||
# Minimal multimodal_config to satisfy references in
|
||||
# Processor.process_inputs.
|
||||
class _MockMMConfig:
|
||||
def __init__(self, gb: float):
|
||||
self.mm_processor_cache_gb = gb
|
||||
|
||||
model_config.multimodal_config = _MockMMConfig(mm_cache_gb) # type: ignore[attr-defined]
|
||||
vllm_config = VllmConfig(
|
||||
model_config=model_config,
|
||||
cache_config=CacheConfig(enable_prefix_caching=enable_prefix_caching),
|
||||
device_config=DeviceConfig(device="cpu"),
|
||||
)
|
||||
|
||||
return InputProcessor(vllm_config, tokenizer=None)
|
||||
|
||||
|
||||
def test_multi_modal_uuids_length_mismatch_raises(monkeypatch):
|
||||
input_processor = _mock_input_processor(monkeypatch)
|
||||
|
||||
prompt = {
|
||||
"prompt": "USER: <image>\nDescribe\nASSISTANT:",
|
||||
"multi_modal_data": {"image": [cherry_pil_image, stop_pil_image]},
|
||||
# Mismatch: 2 items but only 1 uuid provided
|
||||
"multi_modal_uuids": {"image": ["hash_cherry"]},
|
||||
}
|
||||
|
||||
with pytest.raises(ValueError, match="must have same length as data"):
|
||||
input_processor.process_inputs(
|
||||
request_id="req-1",
|
||||
prompt=prompt, # type: ignore[arg-type]
|
||||
params=SamplingParams(),
|
||||
)
|
||||
|
||||
|
||||
def test_multi_modal_uuids_missing_modality_raises(monkeypatch):
|
||||
input_processor = _mock_input_processor(monkeypatch)
|
||||
|
||||
prompt = {
|
||||
"prompt": "USER: <image><video>\nDescribe\nASSISTANT:",
|
||||
# Two modalities provided in data
|
||||
"multi_modal_data": {
|
||||
"image": [cherry_pil_image],
|
||||
"video": [baby_reading_np_ndarrays],
|
||||
},
|
||||
# Only image uuids provided; video missing should raise
|
||||
"multi_modal_uuids": {"image": ["hash_cherry"]},
|
||||
}
|
||||
|
||||
with pytest.raises(ValueError, match="must be provided if multi_modal_data"):
|
||||
input_processor.process_inputs(
|
||||
request_id="req-2",
|
||||
prompt=prompt, # type: ignore[arg-type]
|
||||
params=SamplingParams(),
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"mm_cache_gb, enable_prefix_caching",
|
||||
[
|
||||
(4.0, True), # default behavior
|
||||
(4.0, False), # prefix caching disabled
|
||||
(0.0, True), # processor cache disabled
|
||||
],
|
||||
)
|
||||
def test_multi_modal_uuids_accepts_none_and_passes_through(
|
||||
monkeypatch, mm_cache_gb: float, enable_prefix_caching: bool
|
||||
):
|
||||
input_processor = _mock_input_processor(
|
||||
monkeypatch,
|
||||
mm_cache_gb=mm_cache_gb,
|
||||
enable_prefix_caching=enable_prefix_caching,
|
||||
)
|
||||
|
||||
# Capture the overrides passed to InputPreprocessor.preprocess
|
||||
captured: dict[str, object] = {}
|
||||
|
||||
def fake_preprocess(
|
||||
prompt, *, tokenization_kwargs=None, lora_request=None, mm_uuids=None
|
||||
):
|
||||
captured["mm_uuids"] = mm_uuids
|
||||
# Minimal processed inputs for decoder-only flow
|
||||
return {"type": "token", "prompt_token_ids": [1]}
|
||||
|
||||
# Monkeypatch only the bound preprocess method on this instance
|
||||
monkeypatch.setattr(
|
||||
input_processor.input_preprocessor, "preprocess", fake_preprocess, raising=True
|
||||
)
|
||||
|
||||
# Use a consistent two-image scenario across all configurations
|
||||
mm_uuids = {"image": [None, "hash_stop"], "video": None}
|
||||
prompt = {
|
||||
"prompt": "USER: <image><image>\nTwo images\nASSISTANT:",
|
||||
"multi_modal_data": {
|
||||
"image": [cherry_pil_image, stop_pil_image],
|
||||
"video": baby_reading_np_ndarrays,
|
||||
},
|
||||
"multi_modal_uuids": mm_uuids,
|
||||
}
|
||||
|
||||
input_processor.process_inputs(
|
||||
request_id="req-3",
|
||||
prompt=prompt, # type: ignore[arg-type]
|
||||
params=SamplingParams(),
|
||||
)
|
||||
|
||||
assert captured["mm_uuids"] == mm_uuids
|
||||
|
||||
|
||||
def test_multi_modal_uuids_ignored_when_caching_disabled(monkeypatch):
|
||||
# When both processor cache is 0 and prefix caching disabled, the
|
||||
# processor builds overrides from request id instead of using user UUIDs.
|
||||
input_processor = _mock_input_processor(
|
||||
monkeypatch, mm_cache_gb=0.0, enable_prefix_caching=False
|
||||
)
|
||||
|
||||
captured: dict[str, object] = {}
|
||||
|
||||
def fake_preprocess(
|
||||
prompt, *, tokenization_kwargs=None, lora_request=None, mm_uuids=None
|
||||
):
|
||||
captured["mm_uuids"] = mm_uuids
|
||||
return {"type": "token", "prompt_token_ids": [1]}
|
||||
|
||||
monkeypatch.setattr(
|
||||
input_processor.input_preprocessor, "preprocess", fake_preprocess, raising=True
|
||||
)
|
||||
|
||||
request_id = "req-42"
|
||||
mm_uuids = {"image": ["hash_cherry", "hash_stop"], "video": "hash_video"}
|
||||
prompt = {
|
||||
"prompt": "USER: <image><image><video>\nDescribe\nASSISTANT:",
|
||||
"multi_modal_data": {
|
||||
"image": [cherry_pil_image, stop_pil_image],
|
||||
"video": baby_reading_np_ndarrays,
|
||||
},
|
||||
"multi_modal_uuids": mm_uuids,
|
||||
}
|
||||
|
||||
input_processor.process_inputs(
|
||||
request_id=request_id,
|
||||
prompt=prompt, # type: ignore[arg-type]
|
||||
params=SamplingParams(),
|
||||
)
|
||||
|
||||
# Expect request-id-based overrides are passed through
|
||||
assert captured["mm_uuids"] == {
|
||||
"image": [f"{request_id}-image-0", f"{request_id}-image-1"],
|
||||
"video": [f"{request_id}-video-0"],
|
||||
}
|
||||
407
tests/v1/engine/utils.py
Normal file
407
tests/v1/engine/utils.py
Normal file
@@ -0,0 +1,407 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import random
|
||||
from dataclasses import dataclass
|
||||
from typing import TypeAlias
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
|
||||
|
||||
from vllm.engine.arg_utils import EngineArgs
|
||||
from vllm.v1.engine import EngineCoreOutput, FinishReason
|
||||
from vllm.v1.outputs import LogprobsLists, LogprobsTensors
|
||||
|
||||
GeneralTokenizerType: TypeAlias = PreTrainedTokenizer | PreTrainedTokenizerFast
|
||||
|
||||
# Number of sample logprobs to request when testing sample logprobs
|
||||
NUM_SAMPLE_LOGPROBS_UNDER_TEST = 5
|
||||
# Number of prompt logprobs to request when testing prompt logprobs
|
||||
NUM_PROMPT_LOGPROBS_UNDER_TEST = 7
|
||||
|
||||
TOKENIZER_NAME = "meta-llama/Llama-3.2-1B"
|
||||
|
||||
FULL_STRINGS = [
|
||||
"My name is Robert from Neural Magic and I love working on vLLM so much!",
|
||||
"Red Hat is the best open source company by far across Linux, K8s, and AI.",
|
||||
"Nick is the name of my brother in addition to my colleague from Red Hat.",
|
||||
]
|
||||
STOP_STRINGS = ["I love working on", "company by far", "brother in"]
|
||||
PROMPT_LEN = 5
|
||||
|
||||
random.seed(42)
|
||||
|
||||
|
||||
def _create_random_top_logprob_test_vector(
|
||||
num_logprobs: int,
|
||||
lower: float,
|
||||
upper: float,
|
||||
) -> torch.Tensor:
|
||||
"""Create a random vector of top logprob float values.
|
||||
|
||||
Use to create fake sample logprobs for testing.
|
||||
|
||||
Note that a real production scenario would require
|
||||
logprobs to be sorted in descending order, something
|
||||
which is omitted in this function.
|
||||
|
||||
Args:
|
||||
num_logprobs: number of top logprobs
|
||||
lower: lower range of logprob float values
|
||||
upper: upper range of logprob float values
|
||||
|
||||
Returns:
|
||||
1D length-`num_logprobs` torch Tensor of float logprob values
|
||||
"""
|
||||
return torch.rand(num_logprobs) * (upper - lower) + lower
|
||||
|
||||
|
||||
def _create_random_top_logprob_test_matrix(
|
||||
shape: tuple,
|
||||
lower: float,
|
||||
upper: float,
|
||||
) -> torch.Tensor:
|
||||
"""Create a random matrix of top logprob float values.
|
||||
|
||||
Use to create fake prompt logprobs for testing.
|
||||
|
||||
Note that a real production scenario would require
|
||||
logprobs to be sorted in descending order along rows,
|
||||
something which is omitted in this function.
|
||||
|
||||
Args:
|
||||
shape: (num_tokens,num_logprobs) tuple representing
|
||||
matrix shape
|
||||
lower: lower range of logprob float values
|
||||
upper: upper range of logprob float values
|
||||
|
||||
Returns:
|
||||
2D num_tokens x num_logprobs torch Tensor of float logprob values
|
||||
"""
|
||||
return torch.rand(*shape) * (upper - lower) + lower
|
||||
|
||||
|
||||
def _create_random_top_token_test_vector(
|
||||
num_logprobs: int,
|
||||
lower: int,
|
||||
upper: int,
|
||||
sampled_token_id: int,
|
||||
adjust_num_logprobs: bool = True,
|
||||
) -> tuple[torch.Tensor, int]:
|
||||
"""Create a random vector of top logprob token indices
|
||||
|
||||
Use to create fake sample logprobs for testing. The sampled token
|
||||
ID must always be one of the top logprobs, which this dummy test
|
||||
vector generator enforces. OpenAI API
|
||||
compatible engines must be able to return an additional sample
|
||||
logprob for the sampled token if the sampled token was not
|
||||
among the top sample logprobs; `adjust_num_logprobs` emulates
|
||||
this behavior by increasing the vector length by 1 if
|
||||
`adjust_num_logprobs` is set.
|
||||
|
||||
Args:
|
||||
num_logprobs: number of top logprobs
|
||||
lower: lower range of token ids
|
||||
upper: upper range of token ids
|
||||
sampled_token_id: the token actually sampled
|
||||
adjust_num_logprobs: if True, emulate situation where sampled
|
||||
token logprob must be injected into top
|
||||
logprobs
|
||||
|
||||
Returns:
|
||||
1D length-x torch Tensor of token ids where x is
|
||||
`num_logprobs+1` if `adjust_num_logprobs` and
|
||||
`num_logprobs` otherwise
|
||||
sampled_token_rank: the rank of sampled_token_id in the vocab
|
||||
vector when sorted in descending order by
|
||||
logprob
|
||||
"""
|
||||
|
||||
# Calculate the final number of logprobs required
|
||||
total_logprobs = num_logprobs + 1 if adjust_num_logprobs else num_logprobs
|
||||
|
||||
# Generate random indices using torch
|
||||
choice_tensor = torch.randperm(upper - lower)[:total_logprobs] + lower
|
||||
|
||||
# Ensure the sampled token ID is included in the tensor
|
||||
choice_tensor[0] = sampled_token_id
|
||||
|
||||
# Check if the sampled_token_id occurs in choice_tensor[1:]
|
||||
if sampled_token_id in choice_tensor[1:]:
|
||||
sampled_token_rank = (
|
||||
(choice_tensor[1:] == sampled_token_id).nonzero(as_tuple=True)[0].item()
|
||||
)
|
||||
else:
|
||||
# If not found, assign a random int between num_logprobs and 50700
|
||||
sampled_token_rank = random.randint(num_logprobs, 50700)
|
||||
|
||||
return choice_tensor, sampled_token_rank
|
||||
|
||||
|
||||
def _create_random_top_token_test_matrix(
|
||||
shape: tuple[int, int],
|
||||
lower: int,
|
||||
upper: int,
|
||||
tokens_list: list[int],
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
"""Create a random matrix of top logprob token indices
|
||||
|
||||
Use to create fake prompt logprobs for testing.
|
||||
|
||||
Token ids are generated randomly and sampled without
|
||||
replacement.
|
||||
|
||||
Args:
|
||||
shape: (num_tokens, num_logprobs) tuple representing
|
||||
matrix shape
|
||||
lower: lower range of token ids
|
||||
upper: upper range of token ids
|
||||
|
||||
Returns:
|
||||
tuple containing:
|
||||
- 2D num_tokens x num_logprobs+1 torch Tensor of token ids
|
||||
- 1D tensor of ranks of prompt tokens in their respective
|
||||
rows, or random values
|
||||
"""
|
||||
num_elements = shape[0] * shape[1]
|
||||
choice_tensor = torch.randperm(upper - lower)[:num_elements] + lower
|
||||
matrix = torch.cat(
|
||||
(
|
||||
torch.tensor(tokens_list, dtype=torch.int).unsqueeze(-1),
|
||||
choice_tensor.view(shape),
|
||||
),
|
||||
dim=1,
|
||||
)
|
||||
|
||||
# Initialize the tensor for storing the ranks
|
||||
prompt_token_ranks = torch.empty(shape[0], dtype=torch.int)
|
||||
|
||||
# Iterate over each row to check presence of
|
||||
# tokens_list[rdx] and determine its index
|
||||
for rdx in range(shape[0]):
|
||||
row = matrix[rdx, 1:] # Skip the first column as it contains the token list
|
||||
token_index = (row == tokens_list[rdx]).nonzero(as_tuple=True)[0]
|
||||
if token_index.numel() > 0:
|
||||
prompt_token_ranks[rdx] = token_index.item()
|
||||
else:
|
||||
prompt_token_ranks[rdx] = random.randint(shape[1], 50700)
|
||||
|
||||
return matrix, prompt_token_ranks
|
||||
|
||||
|
||||
def decode_token(
    tok_id: int,
    tokenizer: PreTrainedTokenizer,
) -> str:
    """Mimic single-token detokenization for test comparisons.

    Args:
        tok_id: id of the token to detokenize
        tokenizer: tokenizer used to map the id to its token string

    Returns:
        string representation of the token
    """
    token_str = tokenizer.convert_ids_to_tokens(tok_id)
    return token_str
|
||||
|
||||
|
||||
def generate_dummy_sample_logprobs(
    sampled_tokens_list: list,
    num_logprobs: int,
    tokenizer: PreTrainedTokenizer,
) -> list[tuple[list[int], list[float], int]]:
    """Generate dummy sample logprobs

    Generate a test data structure which imitates the list of sample logprobs
    which would be assembled in the engine core during decode phase.

    Args:
        sampled_tokens_list: list of sampled tokens
        num_logprobs: return `num_logprobs` or `num_logprobs+1` logprobs per token
        tokenizer: model tokenizer to use for detokenization

    Returns
        list of (top token ids vector, logprobs vector, sampled token rank)
        Python lists tuples; in each tuple the logprobs and top token ids
        vectors have the same length which is either `num_logprobs` or
        `num_logprobs+1`. Sampled token rank is the rank (index+1) of the
        sampled token within the vocab vector when sorted by logprob in
        descending order.
    """
    res = []
    for sampled_token_id in sampled_tokens_list:
        (
            token_vector,
            sampled_token_rank,
        ) = _create_random_top_token_test_vector(
            num_logprobs, 0, len(tokenizer.vocab) - 1, sampled_token_id
        )

        res.append(
            (
                token_vector,
                _create_random_top_logprob_test_vector(num_logprobs + 1, -100, 0),
                sampled_token_rank,
            )
        )

    # Convert tensors in the list tuples to Python lists. NOTE: each tuple
    # in `res` is (token ids, logprobs, rank) -- the unpacking names below
    # were previously swapped (log_probs/token_ids), which did not change
    # the emitted order but made the code misleading.
    res_list_format = [
        (token_ids_tensor.tolist(), logprobs_tensor.tolist(), sampled_token_rank)
        for token_ids_tensor, logprobs_tensor, sampled_token_rank in res
    ]

    return res_list_format
|
||||
|
||||
|
||||
def generate_dummy_prompt_logprobs_tensors(
    prompt_tokens_list: list,
    num_logprobs: int,
    tokenizer: PreTrainedTokenizer,
) -> LogprobsTensors:
    """Generate dummy prompt logprobs tensors

    Generate a test data structure which imitates the torch Tensors of prompt
    logprobs which would be assembled in the engine core during chunked
    prefill.

    Args:
        prompt_tokens_list: list of prompt tokens
        num_logprobs: return `num_logprobs` logprobs per token
        tokenizer: model tokenizer to use for detokenization

    Returns
        LogprobsTensors of (top token ids matrix, logprobs matrix,
        prompt token ranks vector); both matrices have dimensions
        (num_prompt_tokens - 1) x (num_logprobs + 1), since column 0
        holds the actual prompt token and no logprobs are produced
        for the first prompt token.
    """
    # For now, assume the whole prompt is processed in one chunk; thus,
    # the number of non-`None` prompt logprobs is `len(prompt_tokens_list)-1`.
    # Prior to injecting `None` at the beginning of prompt logprobs (which
    # happens later in the detokenizer, not here), the prompt logprobs in
    # the ith position are predicting the probability distribution of the
    # prompt token in (i+1)st position. Thus, we concat
    # `prompt_tokens_list[1:]` to the dummy token ids, just as the engine
    # would.
    num_prompt_logprobs = len(prompt_tokens_list) - 1
    (
        token_vector,
        prompt_token_ranks,
    ) = _create_random_top_token_test_matrix(
        (num_prompt_logprobs, num_logprobs),
        0,
        len(tokenizer.vocab) - 1,
        prompt_tokens_list[1:],
    )
    return LogprobsTensors(
        token_vector,
        _create_random_top_logprob_test_matrix(
            (num_prompt_logprobs, num_logprobs + 1), -100, 0
        ),
        prompt_token_ranks,
    )
|
||||
|
||||
|
||||
@dataclass
class DummyOutputProcessorTestVectors:
    """Dummy test vectors for output processor tests"""

    # Tokenizer used to produce the token id lists below
    tokenizer: GeneralTokenizerType
    # NOTE(review): annotated as EngineArgs, but the conftest populates it
    # from EngineArgs(...).create_engine_config() -- presumably the
    # annotation should be the engine config type; confirm against callers.
    vllm_config: EngineArgs
    full_tokens: list[list[int]]  # Prompt + generated tokens
    prompt_tokens: list[list[int]]
    generation_tokens: list[list[int]]
    # Each request is associated with a tuple of
    # (top tokens, top logprobs, ranks) prompt logprobs tensors
    prompt_logprobs: list[LogprobsTensors]
    # Each request is associated with a sample logprobs; a request's
    # sample logprobs are a list of (top tokens, top logprobs, ranks)
    # sample logprobs tensors at each sequence position
    generation_logprobs: list[list[tuple[list[int], list[float], int]]]
    # Detokenized prompt strings and their character lengths
    prompt_strings: list[str]
    prompt_strings_len: list[int]
    generation_strings: list[str]
|
||||
|
||||
|
||||
class MockEngineCore:
    """Mock engine core outputs from premade token lists.

    Each call to get_outputs() emits one new token per unfinished request,
    optionally attaching sample and/or prompt logprobs, imitating what the
    real engine core would produce step by step during decode.
    """

    def __init__(
        self,
        tokens_list: list[list[int]],
        # For each request, for each sampled token offset,
        # a tuple of
        # (list of topk token ids, list of sample logprob vals, rank)
        generated_logprobs_raw: list[list[tuple[list[int], list[float], int]]]
        | None = None,
        # For each request, a tuple of
        # (prompt logprob val matrix, prompt logprob tok id matrix);
        # each matrix has dimensions
        # (num prompt toks) x (num prompt logprobs+1)
        prompt_logprobs_raw: list[LogprobsTensors] | None = None,
        eos_token_id: int | None = None,
        stop_token_ids: list[int] | None = None,
        ignore_eos: bool = False,
    ) -> None:
        self.num_requests = len(tokens_list)
        self.tokens_list = tokens_list
        # Offset of the next token to emit for every request (shared
        # across requests; finished requests stop emitting).
        self.current_idx = 0
        self.generated_logprobs_raw = generated_logprobs_raw
        self.do_logprobs = generated_logprobs_raw is not None
        self.prompt_logprobs_raw = prompt_logprobs_raw
        self.do_prompt_logprobs = prompt_logprobs_raw is not None
        # Tracks which requests have hit a finish condition
        self.request_finished = [False for _ in range(self.num_requests)]
        self.eos_token_id = eos_token_id
        self.stop_token_ids = stop_token_ids
        self.ignore_eos = ignore_eos

    def get_outputs(self) -> list[EngineCoreOutput]:
        """Emit the next EngineCoreOutput for each unfinished request.

        Finish conditions are checked in order, with later checks
        overriding earlier ones: last token in the list -> LENGTH;
        EOS token (unless ignore_eos) -> STOP; stop token -> STOP
        with stop_reason set to the token id.
        """
        do_logprobs = self.do_logprobs
        do_prompt_logprobs = self.do_prompt_logprobs
        token_idx = self.current_idx

        outputs = []
        for req_idx, token_ids in enumerate(self.tokens_list):
            if not self.request_finished[req_idx]:
                if do_logprobs:
                    assert self.generated_logprobs_raw is not None
                    (logprobs_token_ids_, logprobs_, sampled_token_ranks_) = (
                        self.generated_logprobs_raw[req_idx][token_idx]
                    )
                    # Wrap single-position logprobs as 1-row arrays, as the
                    # engine core does for a one-token decode step.
                    logprobs = LogprobsLists(
                        np.array([logprobs_token_ids_]),
                        np.array([logprobs_]),
                        np.array([sampled_token_ranks_]),
                    )
                else:
                    logprobs = None
                if do_prompt_logprobs:
                    # Prompt logprobs are only attached on the first step,
                    # mimicking the whole prompt prefilling in one chunk.
                    if self.current_idx == 0:
                        assert self.prompt_logprobs_raw is not None
                        prompt_logprobs = self.prompt_logprobs_raw[req_idx]
                    else:
                        prompt_logprobs = None
                else:
                    prompt_logprobs = None
                new_token_id = token_ids[token_idx]
                output = EngineCoreOutput(
                    request_id=f"request-{req_idx}",
                    new_token_ids=[new_token_id],
                    new_logprobs=logprobs,
                    new_prompt_logprobs_tensors=prompt_logprobs,
                )
                if token_idx == len(token_ids) - 1:
                    output.finish_reason = FinishReason.LENGTH
                    self.request_finished[req_idx] = True
                if not self.ignore_eos and new_token_id == self.eos_token_id:
                    output.finish_reason = FinishReason.STOP
                    self.request_finished[req_idx] = True
                if new_token_id in (self.stop_token_ids or ()):
                    output.finish_reason = FinishReason.STOP
                    output.stop_reason = new_token_id
                    self.request_finished[req_idx] = True
                outputs.append(output)

        self.current_idx += 1
        return outputs
|
||||
Reference in New Issue
Block a user