Sync from v0.13
tests/v1/kv_connector/unit/__init__.py | 0 (new file)
tests/v1/kv_connector/unit/test_backwards_compatibility.py | 275 (new file)
@@ -0,0 +1,275 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Unit tests for backwards compatibility with external KV connector implementations.

This test ensures that external connectors (loaded via kv_connector_module_path)
implemented with the old signature continue to work:
- Old signature: __init__(self, vllm_config, role)
- New signature: __init__(self, vllm_config, role, kv_cache_config)
"""

from typing import TYPE_CHECKING
from unittest.mock import patch

import pytest

from vllm.attention.backends.abstract import AttentionMetadata
from vllm.distributed.kv_transfer.kv_connector.factory import KVConnectorFactory
from vllm.distributed.kv_transfer.kv_connector.v1 import (
    KVConnectorBase_V1,
    KVConnectorRole,
)
from vllm.v1.core.sched.output import SchedulerOutput

from .utils import create_scheduler, create_vllm_config

if TYPE_CHECKING:
    from vllm.config import VllmConfig
    from vllm.forward_context import ForwardContext
    from vllm.v1.core.kv_cache_manager import KVCacheBlocks
    from vllm.v1.kv_cache_interface import KVCacheConfig
    from vllm.v1.request import Request


class OldStyleTestConnector(KVConnectorBase_V1):
    """
    Test connector using the old signature with 2 required arguments.
    This simulates external connectors that haven't been updated yet.
    """

    def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole):
        # Old-style call to super().__init__ with only 2 arguments
        super().__init__(vllm_config=vllm_config, role=role)

    def get_num_new_matched_tokens(
        self, request: "Request", num_computed_tokens: int
    ) -> tuple[int | None, bool]:
        return 0, False

    def update_state_after_alloc(
        self,
        request: "Request",
        blocks: "KVCacheBlocks",
        num_external_tokens: int,
    ):
        pass

    def build_connector_meta(self, scheduler_output: SchedulerOutput):
        return None

    def start_load_kv(self, forward_context: "ForwardContext", **kwargs) -> None:
        pass

    def wait_for_layer_load(self, layer_name: str) -> None:
        pass

    def save_kv_layer(
        self,
        layer_name: str,
        kv_layer,
        attn_metadata: AttentionMetadata,
        **kwargs,
    ) -> None:
        pass

    def wait_for_save(self):
        pass


class NewStyleTestConnector(KVConnectorBase_V1):
    """
    Test connector using the new signature with 3 required arguments.
    """

    def __init__(
        self,
        vllm_config: "VllmConfig",
        role: KVConnectorRole,
        kv_cache_config: "KVCacheConfig",
    ):
        # New-style call to super().__init__ with all 3 arguments
        super().__init__(
            vllm_config=vllm_config, role=role, kv_cache_config=kv_cache_config
        )

    def get_num_new_matched_tokens(
        self, request: "Request", num_computed_tokens: int
    ) -> tuple[int | None, bool]:
        return 0, False

    def update_state_after_alloc(
        self,
        request: "Request",
        blocks: "KVCacheBlocks",
        num_external_tokens: int,
    ):
        pass

    def build_connector_meta(self, scheduler_output: SchedulerOutput):
        return None

    def start_load_kv(self, forward_context: "ForwardContext", **kwargs) -> None:
        pass

    def wait_for_layer_load(self, layer_name: str) -> None:
        pass

    def save_kv_layer(
        self,
        layer_name: str,
        kv_layer,
        attn_metadata: AttentionMetadata,
        **kwargs,
    ) -> None:
        pass

    def wait_for_save(self):
        pass


@pytest.mark.parametrize("role", [KVConnectorRole.SCHEDULER, KVConnectorRole.WORKER])
def test_external_old_signature_factory_instantiation(role):
    """
    Test that external connectors with old signature (2 required args) loaded
    via kv_connector_module_path are correctly instantiated with backwards
    compatibility support.
    """
    vllm_config = create_vllm_config()
    vllm_config.kv_transfer_config.kv_connector = "OldStyleTestConnector"
    vllm_config.kv_transfer_config.kv_connector_module_path = (
        "tests.v1.kv_connector.unit.test_backwards_compatibility"
    )

    scheduler = create_scheduler(vllm_config)
    kv_cache_config = scheduler.kv_cache_config

    connector = KVConnectorFactory.create_connector(vllm_config, role, kv_cache_config)

    assert connector is not None
    assert isinstance(connector, OldStyleTestConnector)
    assert connector.role == role
    assert connector._kv_cache_config is None


@pytest.mark.parametrize("role", [KVConnectorRole.SCHEDULER, KVConnectorRole.WORKER])
def test_external_new_signature_factory_instantiation(role):
    """
    Test that external connectors with new signature (3 required args) loaded
    via kv_connector_module_path are correctly instantiated.
    """
    vllm_config = create_vllm_config()
    vllm_config.kv_transfer_config.kv_connector = "NewStyleTestConnector"
    vllm_config.kv_transfer_config.kv_connector_module_path = (
        "tests.v1.kv_connector.unit.test_backwards_compatibility"
    )

    scheduler = create_scheduler(vllm_config)
    kv_cache_config = scheduler.kv_cache_config

    connector = KVConnectorFactory.create_connector(vllm_config, role, kv_cache_config)

    assert connector is not None
    assert isinstance(connector, NewStyleTestConnector)
    assert connector.role == role
    assert connector._kv_cache_config is not None
    assert connector._kv_cache_config == kv_cache_config


@pytest.mark.parametrize("role", [KVConnectorRole.SCHEDULER, KVConnectorRole.WORKER])
def test_old_signature_super_init(role):
    """
    Test that old-style connectors can call super().__init__() without the
    kv_cache_config parameter.
    """
    vllm_config = create_vllm_config()

    connector = OldStyleTestConnector(vllm_config, role)

    assert connector is not None
    assert connector.role == role
    assert connector._kv_cache_config is None


def test_old_signature_super_init_with_kwargs():
    """
    Test that old-style connectors can call super().__init__() with keyword
    arguments in different orders.
    """
    vllm_config = create_vllm_config()

    # Test with vllm_config= and role= kwargs
    connector1 = OldStyleTestConnector(
        vllm_config=vllm_config, role=KVConnectorRole.SCHEDULER
    )
    assert connector1 is not None
    assert connector1._kv_cache_config is None

    # Test with role= and vllm_config= in reversed order
    connector2 = OldStyleTestConnector(
        role=KVConnectorRole.WORKER, vllm_config=vllm_config
    )
    assert connector2 is not None
    assert connector2._kv_cache_config is None


def test_internal_connector_uses_new_signature():
    """
    Test that internal connectors (registered in the factory) always use the
    new signature and get kv_cache_config.
    """
    from vllm.distributed.kv_transfer.kv_connector.v1.example_connector import (
        ExampleConnector,
    )

    vllm_config = create_vllm_config()
    vllm_config.kv_transfer_config.kv_connector = "ExampleConnector"

    scheduler = create_scheduler(vllm_config)
    kv_cache_config = scheduler.kv_cache_config

    connector = KVConnectorFactory.create_connector(
        vllm_config, KVConnectorRole.SCHEDULER, kv_cache_config
    )

    assert connector is not None
    assert isinstance(connector, ExampleConnector)
    assert connector._kv_cache_config is not None
    assert connector._kv_cache_config == kv_cache_config


def test_signature_detection_with_mocking():
    """
    Test that the factory correctly applies the compat_sig flag returned from
    _get_connector_class_with_compat.
    """
    vllm_config = create_vllm_config()
    scheduler = create_scheduler(vllm_config)
    kv_cache_config = scheduler.kv_cache_config

    # Mock _get_connector_class_with_compat to return old-style connector
    with patch.object(
        KVConnectorFactory,
        "_get_connector_class_with_compat",
        return_value=(OldStyleTestConnector, True),
    ):
        old_connector = KVConnectorFactory.create_connector(
            vllm_config, KVConnectorRole.SCHEDULER, kv_cache_config
        )
        assert old_connector is not None
        assert isinstance(old_connector, OldStyleTestConnector)
        assert old_connector._kv_cache_config is None

    # Mock _get_connector_class_with_compat to return new-style connector
    with patch.object(
        KVConnectorFactory,
        "_get_connector_class_with_compat",
        return_value=(NewStyleTestConnector, False),
    ):
        new_connector = KVConnectorFactory.create_connector(
            vllm_config, KVConnectorRole.SCHEDULER, kv_cache_config
        )
        assert new_connector is not None
        assert isinstance(new_connector, NewStyleTestConnector)
        assert new_connector._kv_cache_config is not None
        assert new_connector._kv_cache_config == kv_cache_config
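
Note on the mechanism under test: the factory-side compat path exercised above hinges on detecting how many required arguments an external connector's __init__ takes. The detection logic itself is not part of this diff; the following is a minimal sketch of one way it could work, with a hypothetical helper name and inspect-based counting as assumptions:

import inspect


def _has_old_init_signature(connector_cls) -> bool:
    # Hypothetical sketch, not the vLLM implementation: count the required
    # positional parameters of __init__, excluding `self`. Old-style
    # connectors take (vllm_config, role); new-style connectors also take
    # kv_cache_config.
    required = [
        param
        for name, param in inspect.signature(connector_cls.__init__).parameters.items()
        if name != "self"
        and param.default is inspect.Parameter.empty
        and param.kind in (param.POSITIONAL_ONLY, param.POSITIONAL_OR_KEYWORD)
    ]
    return len(required) == 2


# e.g. _has_old_init_signature(OldStyleTestConnector) -> True
#      _has_old_init_signature(NewStyleTestConnector) -> False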
tests/v1/kv_connector/unit/test_cache_pollution_prevention.py | 163 (new file)
@@ -0,0 +1,163 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

"""
test that invalid blocks are evicted from prefix cache to prevent pollution.

verifies that when sync-loading fails, invalid blocks are removed from the
prefix cache hash table so future requests cannot match and reuse corrupted data.
"""

from collections.abc import Callable
from unittest.mock import Mock

import pytest

from vllm.v1.core.sched.scheduler import Scheduler
from vllm.v1.request import Request, RequestStatus

from .utils import (
    create_model_runner_output,
    create_request,
    create_scheduler,
    create_vllm_config,
)

pytestmark = pytest.mark.cpu_test


def _make_get_num_new_matched_tokens(
    req_num_new_matched_tokens: dict[str, int],
    async_load: bool,
) -> Callable[[Request, int], tuple[int, bool]]:
    def get_num_new_matched_tokens(request: Request, _: int) -> tuple[int, bool]:
        value = req_num_new_matched_tokens.get(request.request_id, 0)
        return value, async_load

    return get_num_new_matched_tokens


@pytest.fixture
def fail_scheduler():
    """scheduler with kv_load_failure_policy='fail'"""
    vllm_config = create_vllm_config()
    vllm_config.kv_transfer_config.kv_load_failure_policy = "fail"
    return create_scheduler(vllm_config)


def test_invalid_blocks_evicted_prevents_cache_pollution(
    fail_scheduler: Scheduler,
):
    """
    verify invalid blocks are evicted to prevent future cache hits.

    scenario:
    1. request 1 loads externally-computed blocks (sync mode)
    2. some blocks fail to load and are marked invalid
    3. with fail policy, invalid blocks should be evicted from prefix cache
    4. request is marked as FINISHED_ERROR
    """
    num_prompt_blocks = 100
    num_external_computed_blocks = 99
    invalid_block_idx = 50

    num_prompt_tokens = num_prompt_blocks * fail_scheduler.block_size
    num_external_computed_tokens = (
        num_external_computed_blocks * fail_scheduler.block_size
    )

    # request 1: will have invalid blocks
    request1 = create_request(num_tokens=num_prompt_tokens, request_id=1)
    fail_scheduler.add_request(request=request1)

    req_num_new_matched_tokens = {
        request1.request_id: num_external_computed_tokens,
    }

    # mock connector indicating sync load
    fail_scheduler.connector = Mock()
    fail_scheduler.connector.get_num_new_matched_tokens.side_effect = (
        _make_get_num_new_matched_tokens(req_num_new_matched_tokens, False)
    )
    fail_scheduler.connector.request_finished.return_value = (False, None)
    fail_scheduler.connector.take_events.return_value = ()

    scheduler_output = fail_scheduler.schedule()

    # request should be running with sync KV load
    assert len(fail_scheduler.running) == 1
    assert request1.status == RequestStatus.RUNNING

    # get allocated block IDs
    req_block_ids = scheduler_output.scheduled_new_reqs[0].block_ids[0]
    invalid_block_id = req_block_ids[invalid_block_idx]
    invalid_block_ids = {invalid_block_id}

    # get the block object to verify eviction later
    block = fail_scheduler.kv_cache_manager.block_pool.blocks[invalid_block_id]

    # cache the blocks to simulate they've been computed and cached
    # (in real scenario blocks would be cached after compute)
    fail_scheduler.kv_cache_manager.cache_blocks(request1, num_external_computed_tokens)

    # verify block has a hash (is cached) before reporting invalid blocks
    assert block.block_hash is not None, (
        f"block {invalid_block_id} should be cached (have a hash) before "
        f"eviction test, but hash is None"
    )

    # report invalid blocks
    model_runner_output = create_model_runner_output(
        [request1],
        invalid_block_ids=invalid_block_ids,
        use_eos=False,
    )

    fail_scheduler.update_from_output(scheduler_output, model_runner_output)

    # verify request finished with error (fail policy)
    assert request1.status == RequestStatus.FINISHED_ERROR

    # critical assertion: invalid block and all subsequent blocks should be evicted
    # all blocks from invalid_block_idx onwards become invalid since they were
    # computed based on the failed block
    for idx in range(invalid_block_idx, len(req_block_ids)):
        block_id = req_block_ids[idx]
        block_obj = fail_scheduler.kv_cache_manager.block_pool.blocks[block_id]
        assert block_obj.block_hash is None, (
            f"block {block_id} at index {idx} should have been evicted "
            f"(hash reset to None), but hash is {block_obj.block_hash}. "
            f"All blocks from index {invalid_block_idx} onwards should be evicted "
            f"since they depend on the invalid block at index {invalid_block_idx}."
        )

    # verify cache contains exactly the valid blocks (before first affected block)
    # and none of the invalid blocks (from first affected block onwards)

    # valid blocks: all blocks before invalid_block_idx should be cached
    for idx in range(invalid_block_idx):
        block_id = req_block_ids[idx]
        block_obj = fail_scheduler.kv_cache_manager.block_pool.blocks[block_id]
        assert block_obj.block_hash is not None, (
            f"valid block {block_id} at index {idx} should still be cached "
            f"(have a hash), but hash is None. Only blocks from index "
            f"{invalid_block_idx} onwards should be evicted."
        )

    # invalid blocks: verify they're not in the cached_block_hash_to_block map
    cached_blocks = (
        fail_scheduler.kv_cache_manager.block_pool.cached_block_hash_to_block
    )
    cached_block_ids = {
        b.block_id
        for blocks_val in cached_blocks._cache.values()
        for b in (
            [blocks_val] if not isinstance(blocks_val, dict) else blocks_val.values()
        )
    }

    for idx in range(invalid_block_idx, len(req_block_ids)):
        block_id = req_block_ids[idx]
        assert block_id not in cached_block_ids, (
            f"invalid block {block_id} at index {idx} should not be in cache hash table"
        )
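
For readers tracing the assertions above: "evicted" here means a block loses its hash and its entry in the pool's hash-to-block map, so prefix matching can never hit it again. A schematic restatement of that invariant, treating the map as a plain dict for simplicity (the actual BlockPool internals are not shown in this diff):

def _evict_block(blocks: dict, cached_block_hash_to_block: dict, block_id: int) -> None:
    # Schematic only: mirrors the two properties the test checks after a
    # failed sync load: the block's hash is reset to None and its lookup
    # entry is removed so future requests cannot match it.
    block = blocks[block_id]
    if block.block_hash is not None:
        cached_block_hash_to_block.pop(block.block_hash, None)
        block.block_hash = None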
tests/v1/kv_connector/unit/test_config.py | 65 (new file)
@@ -0,0 +1,65 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

"""Tests for KV cache offloading configuration."""

import pytest

from vllm.config import CacheConfig, KVTransferConfig, ParallelConfig, VllmConfig

pytestmark = pytest.mark.cpu_test


@pytest.mark.parametrize(
    "kv_offloading_backend,kv_offloading_size,tp,pp,expected_backend,expected_bytes",
    [
        ("native", 4.0, 1, 1, "OffloadingConnector", 4.0 * (1 << 30)),
        # bytes per rank: 8.0 GiB / (2 * 2) = 2.0 GiB
        ("native", 8.0, 2, 2, "OffloadingConnector", 8.0 * (1 << 30) / 4),
        ("lmcache", 4.0, 1, 1, "LMCacheConnectorV1", 4.0),
        # size per rank: 8.0 GiB / (2 * 2) = 2.0 GiB
        ("lmcache", 8.0, 2, 2, "LMCacheConnectorV1", 2.0),
        (None, None, 1, 1, None, None),
    ],
)
def test_kv_connector(
    kv_offloading_backend, kv_offloading_size, tp, pp, expected_backend, expected_bytes
):
    kv_transfer_config = (
        KVTransferConfig(kv_connector_extra_config={"existing_key": "existing_value"})
        if expected_backend is not None
        else None
    )

    vllm_config = VllmConfig(
        cache_config=CacheConfig(
            kv_offloading_backend=kv_offloading_backend,
            kv_offloading_size=kv_offloading_size,
        ),
        kv_transfer_config=kv_transfer_config,
        parallel_config=ParallelConfig(
            tensor_parallel_size=tp, pipeline_parallel_size=pp
        ),
    )

    # No KV transfer config expected
    if expected_backend is None:
        assert vllm_config.kv_transfer_config is None
        return

    kv_transfer_config = vllm_config.kv_transfer_config
    kv_connector_extra_config = kv_transfer_config.kv_connector_extra_config

    assert kv_transfer_config.kv_connector == expected_backend
    assert kv_transfer_config.kv_role == "kv_both"

    if kv_offloading_backend == "native":
        assert kv_connector_extra_config["kv_bytes_per_rank"] == expected_bytes
        assert kv_connector_extra_config["num_cpu_blocks"] == 0
        # Existing config should be preserved
        assert kv_connector_extra_config["existing_key"] == "existing_value"
    elif kv_offloading_backend == "lmcache":
        assert kv_connector_extra_config["lmcache.local_cpu"] is True
        assert kv_connector_extra_config["lmcache.max_local_cpu_size"] == expected_bytes
        # Existing config should be replaced
        assert "existing_key" not in kv_connector_extra_config
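
The per-rank expectations in the cases above all reduce to splitting the configured offload budget evenly across the tensor- and pipeline-parallel ranks; stated on its own (helper name hypothetical):

def _offload_size_per_rank(total_size_gib: float, tp: int, pp: int) -> float:
    # Each of the tp * pp ranks holds an equal slice of the offload budget.
    return total_size_gib / (tp * pp)


assert _offload_size_per_rank(8.0, 2, 2) == 2.0  # the 2.0 GiB cases above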
tests/v1/kv_connector/unit/test_decode_bench_connector.py | 415 (new file)
@@ -0,0 +1,415 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Unit tests for DecodeBenchConnector.

Tests the functionality of the DecodeBenchConnector which fills KV cache
with dummy values for decode performance benchmarking.
"""

import pytest
import torch

from vllm import SamplingParams
from vllm.config import KVTransferConfig
from vllm.distributed.kv_transfer.kv_connector.v1 import KVConnectorRole

# ruff: noqa: E501
from vllm.distributed.kv_transfer.kv_connector.v1.decode_bench_connector import (
    DecodeBenchConnector,
    DecodeBenchConnectorMetadata,
)
from vllm.forward_context import ForwardContext
from vllm.utils.hashing import sha256
from vllm.v1.core.kv_cache_utils import get_request_block_hasher, init_none_hash
from vllm.v1.core.sched.scheduler import Scheduler
from vllm.v1.request import Request

from .utils import (
    EOS_TOKEN_ID,
    create_model_runner_output,
    create_scheduler,
    create_vllm_config,
)


class DecodeBenchTestRunner:
    """Test runner for DecodeBenchConnector."""

    def __init__(self, block_size: int, num_gpu_blocks: int):
        self.block_size = block_size
        self.num_gpu_blocks = num_gpu_blocks

        self.req_id = -1

        # Create vllm config with DecodeBenchConnector
        vllm_config = create_vllm_config(
            block_size=block_size, max_num_batched_tokens=1000
        )
        vllm_config.kv_transfer_config = KVTransferConfig(
            kv_connector="DecodeBenchConnector",
            kv_role="kv_both",
        )

        self.vllm_config = vllm_config
        self.scheduler: Scheduler = create_scheduler(
            vllm_config, num_blocks=num_gpu_blocks
        )

        # Create worker-side connector
        self.worker_connector = DecodeBenchConnector(
            vllm_config, KVConnectorRole.WORKER
        )

        # Create dummy KV caches for testing
        # Shape: [num_blocks, 2, num_heads, block_size, head_dim]
        # Using simplified shape for testing
        num_heads = 4
        head_dim = 64
        self.kv_caches = {
            f"layer_{i}": torch.zeros(
                num_gpu_blocks, 2, num_heads, block_size, head_dim
            )
            for i in range(2)  # 2 layers for testing
        }

        # Register KV caches with worker connector
        self.worker_connector.register_kv_caches(self.kv_caches)

        # Extract scheduler-side connector
        scheduler_connector = self.scheduler.connector
        assert scheduler_connector is not None
        assert isinstance(scheduler_connector, DecodeBenchConnector)
        self.scheduler_connector: DecodeBenchConnector = scheduler_connector

        init_none_hash(sha256)
        self._block_hasher = get_request_block_hasher(block_size, sha256)

        self._dummy_ctx: ForwardContext = ForwardContext(
            no_compile_layers={}, attn_metadata={}, virtual_engine=0
        )

    def new_request(self, token_ids: list[int]) -> Request:
        """Create a new request with given token IDs."""
        self.req_id += 1

        req = Request(
            request_id=str(self.req_id),
            prompt_token_ids=token_ids,
            sampling_params=SamplingParams(max_tokens=100),
            pooling_params=None,
            eos_token_id=EOS_TOKEN_ID,
            block_hasher=self._block_hasher,
        )

        self.scheduler.add_request(req)
        return req

    def run_single_step(self, token_id: int = 0):
        """Run a single scheduler + worker step."""
        scheduler_output = self.scheduler.schedule()

        # Get connector metadata
        kv_connector_metadata = scheduler_output.kv_connector_metadata
        assert kv_connector_metadata is not None
        assert isinstance(kv_connector_metadata, DecodeBenchConnectorMetadata)

        # Bind metadata and load KV
        self.worker_connector.bind_connector_metadata(kv_connector_metadata)
        self.worker_connector.start_load_kv(self._dummy_ctx)

        if scheduler_output.total_num_scheduled_tokens > 0:
            self.worker_connector.wait_for_save()

        self.worker_connector.clear_connector_metadata()

        # Create model runner output
        model_runner_output = create_model_runner_output(
            reqs=self.scheduler.running,
            token_id=token_id,
        )

        self.scheduler.update_from_output(scheduler_output, model_runner_output)

        return scheduler_output, kv_connector_metadata


def test_decode_bench_connector_basic():
    """Test basic functionality of DecodeBenchConnector."""
    block_size = 16
    num_gpu_blocks = 100

    runner = DecodeBenchTestRunner(block_size=block_size, num_gpu_blocks=num_gpu_blocks)

    # Create a request with multiple blocks worth of tokens
    num_tokens = block_size * 3  # 3 blocks
    token_ids = [1] * num_tokens

    req = runner.new_request(token_ids)

    # Run first step - should fill KV cache with dummy values
    scheduler_output, metadata = runner.run_single_step()

    # Check that get_num_new_matched_tokens returned correct value
    # Should be num_tokens - 1 (all except the last token for decode)
    expected_fill_tokens = num_tokens - 1

    # Check metadata has the request to fill
    assert len(metadata.reqs_to_fill) == 1
    assert req.request_id in metadata.reqs_to_fill

    block_ids_per_group, num_tokens_to_fill = metadata.reqs_to_fill[req.request_id]
    assert num_tokens_to_fill == expected_fill_tokens

    # For standard attention, there's only one group
    assert len(block_ids_per_group) == 1
    block_ids = block_ids_per_group[0]

    # Calculate expected number of blocks
    expected_num_blocks = (expected_fill_tokens + block_size - 1) // block_size
    assert len(block_ids) == expected_num_blocks

    # Verify KV caches were filled with constant value
    for layer_name, kv_cache in runner.kv_caches.items():
        for block_id in block_ids:
            # Check that the block was filled
            block_data = kv_cache[block_id]
            # Should be filled with constant value 0.015
            assert torch.allclose(block_data, torch.tensor(0.015))


def test_decode_bench_connector_no_refill():
    """Test that DecodeBenchConnector only fills once per request."""
    block_size = 16
    num_gpu_blocks = 100

    runner = DecodeBenchTestRunner(block_size=block_size, num_gpu_blocks=num_gpu_blocks)

    # Create a request
    num_tokens = block_size * 2
    token_ids = [1] * num_tokens

    runner.new_request(token_ids)

    # Run first step - should fill KV cache
    _, metadata1 = runner.run_single_step()
    assert len(metadata1.reqs_to_fill) == 1

    # Run second step - should NOT fill again (already filled)
    _, metadata2 = runner.run_single_step()
    assert len(metadata2.reqs_to_fill) == 0


def test_decode_bench_connector_single_token():
    """Test DecodeBenchConnector with single token request."""
    block_size = 16
    num_gpu_blocks = 100

    runner = DecodeBenchTestRunner(block_size=block_size, num_gpu_blocks=num_gpu_blocks)

    # Create a request with just 1 token
    # Should not fill anything (need at least 2 tokens: 1 to fill, 1 to decode)
    token_ids = [1]

    runner.new_request(token_ids)

    # Run step - should NOT fill KV cache
    _, metadata = runner.run_single_step()
    assert len(metadata.reqs_to_fill) == 0


def test_decode_bench_connector_two_tokens():
    """Test DecodeBenchConnector with two token request."""
    block_size = 16
    num_gpu_blocks = 100

    runner = DecodeBenchTestRunner(block_size=block_size, num_gpu_blocks=num_gpu_blocks)

    # Create a request with 2 tokens
    # Should fill 1 token (first token), decode the second
    token_ids = [1, 2]

    req = runner.new_request(token_ids)

    # Run step
    _, metadata = runner.run_single_step()

    assert len(metadata.reqs_to_fill) == 1
    assert req.request_id in metadata.reqs_to_fill

    block_ids_per_group, num_tokens_to_fill = metadata.reqs_to_fill[req.request_id]
    assert num_tokens_to_fill == 1
    # For standard attention, there's only one group
    assert len(block_ids_per_group) == 1
    assert len(block_ids_per_group[0]) == 1  # 1 token needs 1 block


def test_decode_bench_connector_large_context():
    """Test DecodeBenchConnector with large context size."""
    block_size = 16
    num_gpu_blocks = 1000

    runner = DecodeBenchTestRunner(block_size=block_size, num_gpu_blocks=num_gpu_blocks)

    # Create a request with many blocks
    num_blocks = 20
    num_tokens = block_size * num_blocks
    token_ids = list(range(num_tokens))

    req = runner.new_request(token_ids)

    # Run step
    _, metadata = runner.run_single_step()

    assert len(metadata.reqs_to_fill) == 1
    assert req.request_id in metadata.reqs_to_fill

    block_ids_per_group, num_tokens_to_fill = metadata.reqs_to_fill[req.request_id]

    # Should fill all tokens except the last one
    expected_fill_tokens = num_tokens - 1
    assert num_tokens_to_fill == expected_fill_tokens

    # For standard attention, there's only one group
    assert len(block_ids_per_group) == 1
    block_ids = block_ids_per_group[0]

    # Calculate expected number of blocks
    expected_num_blocks = (expected_fill_tokens + block_size - 1) // block_size
    assert len(block_ids) == expected_num_blocks

    # Verify blocks were filled
    for layer_name, kv_cache in runner.kv_caches.items():
        for block_id in block_ids:
            block_data = kv_cache[block_id]
            assert torch.allclose(block_data, torch.tensor(0.015))


def test_decode_bench_connector_multiple_requests():
    """Test DecodeBenchConnector with multiple sequential requests."""
    block_size = 16
    num_gpu_blocks = 100

    runner = DecodeBenchTestRunner(block_size=block_size, num_gpu_blocks=num_gpu_blocks)

    # First request
    req1 = runner.new_request([1] * (block_size * 2))
    _, metadata1 = runner.run_single_step()

    assert len(metadata1.reqs_to_fill) == 1
    assert req1.request_id in metadata1.reqs_to_fill

    # Complete first request
    while runner.scheduler.running:
        runner.run_single_step()

        # Add EOS to finish
        scheduler_output = runner.scheduler.schedule()
        model_runner_output = create_model_runner_output(
            reqs=runner.scheduler.running,
            token_id=EOS_TOKEN_ID,
            use_eos=True,
        )
        runner.scheduler.update_from_output(scheduler_output, model_runner_output)

    # Second request - should also get filled
    req2 = runner.new_request([2] * (block_size * 3))
    _, metadata2 = runner.run_single_step()

    assert len(metadata2.reqs_to_fill) == 1
    assert req2.request_id in metadata2.reqs_to_fill

    # Different request should have different metadata
    _, num_tokens1 = metadata1.reqs_to_fill[req1.request_id]
    _, num_tokens2 = metadata2.reqs_to_fill[req2.request_id]

    assert num_tokens1 == block_size * 2 - 1
    assert num_tokens2 == block_size * 3 - 1


def test_decode_bench_connector_partial_block():
    """Test DecodeBenchConnector with partial block filling."""
    block_size = 16
    num_gpu_blocks = 100

    runner = DecodeBenchTestRunner(block_size=block_size, num_gpu_blocks=num_gpu_blocks)

    # Create a request that doesn't align to block boundaries
    # e.g., 2.5 blocks worth of tokens
    num_tokens = block_size * 2 + block_size // 2
    token_ids = [1] * num_tokens

    req = runner.new_request(token_ids)

    # Run step
    _, metadata = runner.run_single_step()

    assert len(metadata.reqs_to_fill) == 1
    assert req.request_id in metadata.reqs_to_fill

    block_ids_per_group, num_tokens_to_fill = metadata.reqs_to_fill[req.request_id]

    # Should fill all tokens except the last one
    expected_fill_tokens = num_tokens - 1
    assert num_tokens_to_fill == expected_fill_tokens

    # For standard attention, there's only one group
    assert len(block_ids_per_group) == 1
    block_ids = block_ids_per_group[0]

    # Should allocate 3 blocks to hold the partial data
    expected_num_blocks = 3
    assert len(block_ids) == expected_num_blocks


def test_decode_bench_connector_concurrent_requests():
    """Test DecodeBenchConnector with multiple concurrent requests in the same batch."""
    block_size = 16
    num_gpu_blocks = 1000

    runner = DecodeBenchTestRunner(block_size=block_size, num_gpu_blocks=num_gpu_blocks)

    # Create multiple requests that will be batched together
    req1 = runner.new_request([1] * (block_size * 2))
    req2 = runner.new_request([2] * (block_size * 3))
    req3 = runner.new_request([3] * (block_size * 1))

    # Run first step - all requests should be filled concurrently
    _, metadata = runner.run_single_step()

    # All three requests should be in the metadata
    assert len(metadata.reqs_to_fill) == 3
    assert req1.request_id in metadata.reqs_to_fill
    assert req2.request_id in metadata.reqs_to_fill
    assert req3.request_id in metadata.reqs_to_fill

    # Verify each request has correct fill info
    block_ids_per_group1, num_tokens1 = metadata.reqs_to_fill[req1.request_id]
    block_ids_per_group2, num_tokens2 = metadata.reqs_to_fill[req2.request_id]
    block_ids_per_group3, num_tokens3 = metadata.reqs_to_fill[req3.request_id]

    # Verify token counts (all tokens except last one)
    assert num_tokens1 == block_size * 2 - 1
    assert num_tokens2 == block_size * 3 - 1
    assert num_tokens3 == block_size * 1 - 1

    # Verify block counts for each request
    assert len(block_ids_per_group1[0]) == 2  # 2 blocks
    assert len(block_ids_per_group2[0]) == 3  # 3 blocks
    assert len(block_ids_per_group3[0]) == 1  # 1 block

    # Verify all blocks are filled in KV cache
    for req_id, (block_ids_per_group, _) in metadata.reqs_to_fill.items():
        block_ids = block_ids_per_group[0]
        for layer_name, kv_cache in runner.kv_caches.items():
            for block_id in block_ids:
                block_data = kv_cache[block_id]
                assert torch.allclose(block_data, torch.tensor(0.015))

    # Run second step - should NOT fill again (already filled)
    _, metadata2 = runner.run_single_step()
    assert len(metadata2.reqs_to_fill) == 0


if __name__ == "__main__":
    pytest.main([__file__, "-v"])
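
The fill that the tests above assert on amounts to writing one constant into every allocated block of every layer's KV cache. A minimal sketch of that operation (the connector's actual fill path is not part of this diff):

import torch


def _fill_blocks(kv_caches: dict[str, torch.Tensor], block_ids: list[int]) -> None:
    # Sketch only: set each allocated block to the constant 0.015 that the
    # tests compare against with torch.allclose.
    for kv_cache in kv_caches.values():
        for block_id in block_ids:
            kv_cache[block_id].fill_(0.015)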
tests/v1/kv_connector/unit/test_error_propagation.py | 147 (new file)
@@ -0,0 +1,147 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from collections.abc import Callable
from unittest.mock import Mock

import pytest

from vllm.v1.core.sched.scheduler import Scheduler
from vllm.v1.request import FinishReason, Request, RequestStatus

from .utils import (
    create_model_runner_output,
    create_request,
    create_scheduler,
    create_vllm_config,
)

pytestmark = pytest.mark.cpu_test


def _make_get_num_new_matched_tokens(
    req_num_new_matched_tokens: dict[str, int],
    async_load: bool,
) -> Callable[[Request, int], tuple[int, bool]]:
    def get_num_new_matched_tokens(request: Request, _: int) -> tuple[int, bool]:
        value = req_num_new_matched_tokens.get(request.request_id, 0)
        return value, async_load

    return get_num_new_matched_tokens


@pytest.fixture
def fail_scheduler():
    """scheduler with kv_load_failure_policy='fail'"""
    vllm_config = create_vllm_config()
    vllm_config.kv_transfer_config.kv_load_failure_policy = "fail"
    return create_scheduler(vllm_config)


def test_error_propagation_sync_load(fail_scheduler: Scheduler):
    """test invalid_block_ids with fail policy -> FINISHED_ERROR (sync load)"""
    num_prompt_blocks = 100
    num_external_computed_blocks = 99
    invalid_block_idx = 50

    num_prompt_tokens = num_prompt_blocks * fail_scheduler.block_size
    num_external_computed_tokens = (
        num_external_computed_blocks * fail_scheduler.block_size
    )

    request = create_request(num_tokens=num_prompt_tokens)
    fail_scheduler.add_request(request=request)

    req_num_new_matched_tokens = {
        request.request_id: num_external_computed_tokens,
    }

    fail_scheduler.connector = Mock()
    fail_scheduler.connector.get_num_new_matched_tokens.side_effect = (
        _make_get_num_new_matched_tokens(req_num_new_matched_tokens, False)
    )
    fail_scheduler.connector.request_finished.return_value = (False, None)
    fail_scheduler.connector.take_events.return_value = ()

    scheduler_output = fail_scheduler.schedule()

    assert len(fail_scheduler.running) == 1
    assert len(scheduler_output.scheduled_new_reqs) == 1
    assert fail_scheduler.connector.get_num_new_matched_tokens.call_count == 1

    req_block_ids = scheduler_output.scheduled_new_reqs[0].block_ids[0]
    invalid_block_ids = {req_block_ids[invalid_block_idx]}
    model_runner_output = create_model_runner_output(
        [request],
        invalid_block_ids=invalid_block_ids,
        use_eos=True,
    )

    outputs = fail_scheduler.update_from_output(scheduler_output, model_runner_output)

    assert request.status == RequestStatus.FINISHED_ERROR
    assert request.get_finished_reason() == FinishReason.ERROR

    assert len(outputs) == 1
    engine_outputs = next(iter(outputs.values()))
    assert len(engine_outputs.outputs) == 1
    output = engine_outputs.outputs[0]
    assert output.request_id == request.request_id
    assert output.finish_reason == FinishReason.ERROR

    assert len(fail_scheduler.running) == 0


def test_error_propagation_async_load(fail_scheduler: Scheduler):
    """test invalid_block_ids with fail policy -> FINISHED_ERROR (async load)"""
    num_prompt_blocks = 100
    num_external_computed_blocks = 99
    invalid_block_idx = 50

    num_prompt_tokens = num_prompt_blocks * fail_scheduler.block_size
    num_external_computed_tokens = (
        num_external_computed_blocks * fail_scheduler.block_size
    )

    request = create_request(num_tokens=num_prompt_tokens)
    fail_scheduler.add_request(request=request)

    req_num_new_matched_tokens = {
        request.request_id: num_external_computed_tokens,
    }

    fail_scheduler.connector = Mock()
    fail_scheduler.connector.get_num_new_matched_tokens.side_effect = (
        _make_get_num_new_matched_tokens(req_num_new_matched_tokens, True)
    )
    fail_scheduler.connector.request_finished.return_value = (False, None)
    fail_scheduler.connector.take_events.return_value = ()

    scheduler_output = fail_scheduler.schedule()

    assert len(fail_scheduler.waiting) == 1
    assert request.status == RequestStatus.WAITING_FOR_REMOTE_KVS
    assert request.num_computed_tokens == 0

    (req_block_ids,) = fail_scheduler.kv_cache_manager.get_block_ids(request.request_id)
    invalid_block_ids = {req_block_ids[invalid_block_idx]}
    model_runner_output = create_model_runner_output(
        reqs=[],
        finished_recving=set(),
        invalid_block_ids=invalid_block_ids,
        use_eos=True,
    )

    outputs = fail_scheduler.update_from_output(scheduler_output, model_runner_output)

    assert request.status == RequestStatus.FINISHED_ERROR
    assert request.get_finished_reason() == FinishReason.ERROR

    assert len(outputs) == 1
    engine_outputs = next(iter(outputs.values()))
    assert len(engine_outputs.outputs) == 1
    output = engine_outputs.outputs[0]
    assert output.request_id == request.request_id
    assert output.finish_reason == FinishReason.ERROR

    assert len(fail_scheduler.waiting) == 0
tests/v1/kv_connector/unit/test_example_connector.py | 256 (new file)
@@ -0,0 +1,256 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import asdict
from typing import NamedTuple

import pytest
from PIL import Image

from vllm import LLM, EngineArgs, SamplingParams
from vllm.assets.image import ImageAsset
from vllm.config import KVTransferConfig
from vllm.multimodal.utils import encode_image_base64
from vllm.platforms import current_platform

MODEL_NAME = "RedHatAI/Qwen2.5-VL-3B-Instruct-quantized.w8a8"

SAMPLING_PARAMS = SamplingParams(temperature=0.0, top_k=1, max_tokens=128)

TEXT_PROMPTS = [
    "What's in the image(s)? Around 30 words. What's special in 2nd image?",
    "The future of AI is",
]


class InputCase(NamedTuple):
    text: str
    img: list[Image.Image]
    expected_len: int
    info: str


def _check_path_len(path):
    """Return the current number of entries under the path."""
    return len(list(path.iterdir()))


def _list_path(path):
    """Return the list of folder names (generated hashes) under the path."""
    return list(path.iterdir())


def run_test(
    tmp_path,
    processor,
    llm: LLM,
    question: str,
    image_urls: list[Image.Image],
    expected_len: int,
    info: str,
):
    """
    One individual test that processes the prompt and output based on 1 set of
    inputs, then checks whether the number of entries in the storage path
    matches the expected length.

    `info` describes the details or purpose of the individual test.
    """
    print(f"***info: {info}***")
    print(f"**Expected storage path length after llm generate: {expected_len}**")
    process_prompt(processor, llm, question, image_urls)

    print(f"Path matched expected length: {_check_path_len(tmp_path)}")
    print(f"Hashes under the storage path: {_list_path(tmp_path)}")

    assert _check_path_len(tmp_path) == expected_len, (
        f"Expect storage path length {expected_len}; "
        f"but ended up with {_check_path_len(tmp_path)} instead. "
        f"Info: {info}"
    )


def process_prompt(processor, llm: LLM, question: str, image_urls: list[Image.Image]):
    """
    Form the prompt based on the text and image input, then have the LLM
    generate output.
    """
    placeholders = [
        {
            "type": "image_url",
            "image_url": {"url": f"data:image;base64,{encode_image_base64(image_pil)}"},
        }
        for image_pil in image_urls
    ]

    messages = [
        {"role": "system", "content": "You are a helpful assistant."},
        {
            "role": "user",
            "content": [
                *placeholders,
                {"type": "text", "text": question},
            ],
        },
    ]

    prompt = processor.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )

    outputs = llm.generate(
        {
            "prompt": prompt,
            **({"multi_modal_data": {"image": [*image_urls]}} if image_urls else {}),
        },
        sampling_params=SAMPLING_PARAMS,
    )

    print("-" * 50)
    print("Output:")
    for o in outputs:
        generated_text = o.outputs[0].text
        print(generated_text)
    print("-" * 50)


@pytest.mark.skipif(
    current_platform.is_rocm(),
    reason=(
        "hipErrorLaunchFailure when running this test, see issue: "
        "https://github.com/ROCm/pytorch/issues/2822"
    ),
)
def test_shared_storage_connector_hashes(tmp_path):
    """
    Tests that ExampleConnector saves KV to the storage locations with hashes
    that are unique for inputs with identical text but different images (of
    the same size), or the same multiple images in different orders.
    """
    # Using tmp_path as the storage path to store KV
    print(f"KV storage path at: {str(tmp_path)}")

    # Configure the ExampleConnector
    kv_transfer_config = KVTransferConfig(
        kv_connector="ExampleConnector",
        kv_role="kv_both",
        kv_connector_extra_config={"shared_storage_path": str(tmp_path)},
    )

    engine_args = EngineArgs(
        model=MODEL_NAME,
        max_model_len=8192,
        max_num_seqs=1,
        gpu_memory_utilization=0.4,
        enforce_eager=True,
        kv_transfer_config=kv_transfer_config,
        limit_mm_per_prompt={"image": 2},
    )

    # don't put this import at the top level
    # it will call torch.cuda.device_count()
    from transformers import AutoProcessor  # noqa: F401

    # Create processor to handle the chat prompt
    processor = AutoProcessor.from_pretrained(MODEL_NAME)

    # Prepare images for the tests
    # Resize to the same size to check hash correctness
    image_1 = ImageAsset("stop_sign").pil_image.resize((1280, 720))
    image_2 = ImageAsset("cherry_blossom").pil_image.resize((1280, 720))

    # Make sure that they are not the same picture
    assert image_1 != image_2, "The images should not be identical"

    # Create the LLM instance
    engine_args = asdict(engine_args)
    llm = LLM(**engine_args)

    # Prepare the input cases
    input_cases = [
        InputCase(
            text=TEXT_PROMPTS[0],
            img=[image_1],
            expected_len=1,
            info="image_1 single input the first time.",
        ),
        InputCase(
            text=TEXT_PROMPTS[0],
            img=[image_2],
            expected_len=2,
            info=(
                "image_2 single input the first time. "
                "It is the same pixel size as image_1, yet it "
                "should form a new unique hash."
            ),
        ),
        InputCase(
            text=TEXT_PROMPTS[0],
            img=[image_1],
            expected_len=2,
            info=(
                "image_1 single input the 2nd time. "
                "It should not form another new hash."
            ),
        ),
        InputCase(
            text=TEXT_PROMPTS[0],
            img=[image_2],
            expected_len=2,
            info=(
                "image_2 single input the 2nd time. "
                "It should not form another new hash."
            ),
        ),
        InputCase(
            text=TEXT_PROMPTS[0],
            img=[image_1, image_2],
            expected_len=3,
            info="image_1 with image_2 input the first time.",
        ),
        InputCase(
            text=TEXT_PROMPTS[0],
            img=[image_2, image_1],
            expected_len=4,
            info="The image order is swapped. Should form a new hash.",
        ),
        InputCase(
            text=TEXT_PROMPTS[0],
            img=[image_1, image_2],
            expected_len=4,
            info=(
                "[image_1, image_2] input the 2nd time. "
                "It should not form another new hash."
            ),
        ),
        InputCase(
            text=TEXT_PROMPTS[0],
            img=[image_2, image_1],
            expected_len=4,
            info=(
                "[image_2, image_1] input the 2nd time. "
                "It should not form another new hash."
            ),
        ),
        InputCase(
            text=TEXT_PROMPTS[0],
            img=[],
            expected_len=5,
            info="Pure text input test as a case-control",
        ),
        InputCase(
            text=TEXT_PROMPTS[0],
            img=[],
            expected_len=5,
            info="Identical pure text input as a case-control",
        ),
        InputCase(
            text=TEXT_PROMPTS[1],
            img=[],
            expected_len=6,
            info="Another pure text input as a case-control",
        ),
    ]

    # Run tests
    for case_id, (text, img, expected_len, info) in enumerate(input_cases):
        print("\n", "=" * 25, f"Below running input case: {case_id}", "=" * 25)
        run_test(tmp_path, processor, llm, text, img, expected_len, info)

    print("All tests passed successfully!")
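
The uniqueness properties checked above (same text with different images forms a new hash, swapped image order forms a new hash, identical input forms none) follow from hashing image content together with the text, in order. A toy illustration of that property, purely illustrative and not the connector's real hashing scheme:

import hashlib


def _toy_kv_key(prompt: str, images: list[bytes]) -> str:
    # Order-sensitive digest: an identical (text, images) pair reproduces the
    # key, while swapping image order or changing any image's bytes yields a
    # different one.
    digest = hashlib.sha256(prompt.encode("utf-8"))
    for image_bytes in images:
        digest.update(image_bytes)
    return digest.hexdigest()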
tests/v1/kv_connector/unit/test_invalid_blocks_correctness.py | 454 (new file)
@@ -0,0 +1,454 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

"""
Tests for correctness in invalid block handling.

These tests verify correct behavior in three scenarios:
1. Sync recompute case: Blocks should not be freed for running requests
   that need to recompute invalid blocks
2. Sync fail case: Invalid blocks must be evicted from cache when request fails
3. Async recompute case: Invalid blocks should not be cached after transfer
"""

from collections.abc import Callable
from unittest.mock import Mock

import pytest

from vllm.v1.core.sched.scheduler import Scheduler
from vllm.v1.request import FinishReason, Request, RequestStatus

from .utils import (
    create_model_runner_output,
    create_request,
    create_scheduler,
    create_vllm_config,
)

pytestmark = pytest.mark.cpu_test


def _make_get_num_new_matched_tokens(
    req_num_new_matched_tokens: dict[str, int],
    async_load: bool,
) -> Callable[[Request, int], tuple[int, bool]]:
    def get_num_new_matched_tokens(request: Request, _: int) -> tuple[int, bool]:
        value = req_num_new_matched_tokens.get(request.request_id, 0)
        return value, async_load

    return get_num_new_matched_tokens


@pytest.fixture
def fail_scheduler():
    """scheduler with kv_load_failure_policy='fail'"""
    vllm_config = create_vllm_config()
    vllm_config.kv_transfer_config.kv_load_failure_policy = "fail"
    return create_scheduler(vllm_config)


@pytest.fixture
def recompute_scheduler():
    """scheduler with kv_load_failure_policy='recompute'"""
    vllm_config = create_vllm_config()
    vllm_config.kv_transfer_config.kv_load_failure_policy = "recompute"
    return create_scheduler(vllm_config)


def test_sync_recompute_blocks_not_freed_for_running_requests(
    recompute_scheduler: Scheduler,
):
    """
    Test sync recompute case - blocks must not be freed for running requests.

    When a running request has invalid blocks and retry_policy is 'recompute':
    1. Request should remain in RUNNING state
    2. num_computed_tokens should be truncated to invalid block boundary
    3. Blocks should NOT be freed (request still needs them for recomputation)
    4. Request should remain in scheduler.requests and scheduler.running
    """
    num_prompt_blocks = 100
    num_external_computed_blocks = 99
    invalid_block_idx = 50

    num_prompt_tokens = num_prompt_blocks * recompute_scheduler.block_size
    num_external_computed_tokens = (
        num_external_computed_blocks * recompute_scheduler.block_size
    )

    request = create_request(num_tokens=num_prompt_tokens)
    recompute_scheduler.add_request(request=request)

    req_num_new_matched_tokens = {
        request.request_id: num_external_computed_tokens,
    }

    # mock connector indicating sync load
    recompute_scheduler.connector = Mock()
    recompute_scheduler.connector.get_num_new_matched_tokens.side_effect = (
        _make_get_num_new_matched_tokens(req_num_new_matched_tokens, False)
    )
    recompute_scheduler.connector.request_finished.return_value = (False, None)
    recompute_scheduler.connector.take_events.return_value = ()

    scheduler_output = recompute_scheduler.schedule()

    # request should be running with sync KV load
    assert len(recompute_scheduler.running) == 1
    assert len(scheduler_output.scheduled_new_reqs) == 1
    assert request.status == RequestStatus.RUNNING

    # get the allocated block IDs before invalid blocks are reported
    req_block_ids = scheduler_output.scheduled_new_reqs[0].block_ids[0]
    invalid_block_ids = {req_block_ids[invalid_block_idx]}

    # store original num_computed_tokens for comparison
    original_num_computed_tokens = request.num_computed_tokens

    model_runner_output = create_model_runner_output(
        [request],
        invalid_block_ids=invalid_block_ids,
        use_eos=False,  # not finished - should continue running
    )

    outputs = recompute_scheduler.update_from_output(
        scheduler_output, model_runner_output
    )

    # critical assertions for recompute case:

    # 1. request should still be RUNNING (not finished, not aborted)
    assert request.status == RequestStatus.RUNNING, (
        f"Request should remain RUNNING for recompute, got {request.status}"
    )

    # 2. num_computed_tokens should be truncated to first invalid block
    expected_truncated_tokens = invalid_block_idx * recompute_scheduler.block_size
    assert request.num_computed_tokens == expected_truncated_tokens, (
        f"num_computed_tokens should be truncated to {expected_truncated_tokens}, "
        f"got {request.num_computed_tokens}"
    )
    assert request.num_computed_tokens < original_num_computed_tokens, (
        "num_computed_tokens should be reduced after invalid block detection"
    )

    # 3. no output should be generated (request is still running)
    # the request should be skipped in the output loop
    assert len(outputs) == 0 or request.request_id not in [
        out.request_id for outs in outputs.values() for out in outs.outputs
    ], "No output should be generated for recompute requests"

    # 4. request should still be in running queue
    assert request in recompute_scheduler.running, (
        "Request should remain in running queue for recomputation"
    )

    # 5. request should still be in scheduler.requests (not deleted)
    assert request.request_id in recompute_scheduler.requests, (
        "Request should not be deleted from scheduler.requests"
    )

    # 6. blocks should NOT be freed - verify blocks are still allocated
    try:
        allocated_blocks = recompute_scheduler.kv_cache_manager.get_block_ids(
            request.request_id
        )
        assert allocated_blocks is not None
        assert len(allocated_blocks[0]) > 0, (
            "Blocks should still be allocated for recomputation"
        )
    except KeyError:
        pytest.fail(
            "Blocks were freed incorrectly! Running requests need their blocks "
            "to recompute invalid portions."
        )

    # 7. verify request can be rescheduled in next step
    scheduler_output_2 = recompute_scheduler.schedule()

    # request should appear in the new schedule to recompute invalid blocks
    scheduled_req_ids = [
        req.request_id for req in scheduler_output_2.scheduled_new_reqs
    ]
    if scheduler_output_2.num_scheduled_tokens:
        scheduled_req_ids.extend(scheduler_output_2.num_scheduled_tokens.keys())

    assert (
        request.request_id in scheduled_req_ids or len(recompute_scheduler.running) > 0
    ), "Request should be reschedulable for recomputation"


def test_sync_fail_invalid_blocks_evicted(fail_scheduler: Scheduler):
    """
    Test sync fail case - invalid blocks must be evicted from cache.

    When a request fails with policy='fail' and has invalid blocks from sync loading:
    1. Request should be finished with FINISHED_ERROR
    2. Invalid blocks should be evicted from the KV cache
    3. Valid blocks (if shared) should remain in cache
    4. Future requests should not reuse the invalid blocks

    This test verifies that invalid blocks are properly evicted to prevent
    cache corruption and reuse of invalid data.
    """
    num_prompt_blocks = 100
    num_external_computed_blocks = 99
    invalid_block_idx = 50

    num_prompt_tokens = num_prompt_blocks * fail_scheduler.block_size
    num_external_computed_tokens = (
        num_external_computed_blocks * fail_scheduler.block_size
    )

    request = create_request(num_tokens=num_prompt_tokens)
    fail_scheduler.add_request(request=request)

    req_num_new_matched_tokens = {
        request.request_id: num_external_computed_tokens,
    }

    # mock connector indicating sync load
    fail_scheduler.connector = Mock()
    fail_scheduler.connector.get_num_new_matched_tokens.side_effect = (
        _make_get_num_new_matched_tokens(req_num_new_matched_tokens, False)
    )
    fail_scheduler.connector.request_finished.return_value = (False, None)
    fail_scheduler.connector.take_events.return_value = ()

    scheduler_output = fail_scheduler.schedule()

    # request should be running with sync KV load
    assert len(fail_scheduler.running) == 1
    assert request.status == RequestStatus.RUNNING

    # get allocated block IDs
    req_block_ids = scheduler_output.scheduled_new_reqs[0].block_ids[0]
    invalid_block_id = req_block_ids[invalid_block_idx]
    invalid_block_ids = {invalid_block_id}

    # verify the block is in the block pool before we report it as invalid
    block = fail_scheduler.kv_cache_manager.block_pool.blocks[invalid_block_id]
    assert block is not None

    # report invalid blocks - request should fail
    model_runner_output = create_model_runner_output(
        [request],
        invalid_block_ids=invalid_block_ids,
        use_eos=True,
    )

    outputs = fail_scheduler.update_from_output(scheduler_output, model_runner_output)

    # verify request is finished with error
    assert request.status == RequestStatus.FINISHED_ERROR
    assert request.get_finished_reason() == FinishReason.ERROR

    # verify output is generated
    assert len(outputs) == 1
    engine_outputs = next(iter(outputs.values()))
    assert len(engine_outputs.outputs) == 1
    output = engine_outputs.outputs[0]
    assert output.request_id == request.request_id
    assert output.finish_reason == FinishReason.ERROR

    # verify the request was removed from scheduler
    assert request.request_id not in fail_scheduler.requests
    assert len(fail_scheduler.running) == 0

    # critical: verify invalid block was actually freed from cache
    # this is the key assertion - the invalid block should no longer be
    # tracked by the KV cache manager for this request
    # if it's still there, a future request could reuse the invalid data
    try:
        block_ids = fail_scheduler.kv_cache_manager.get_block_ids(request.request_id)
        # if we get here, check if blocks were actually freed
        if block_ids is not None and len(block_ids[0]) > 0:
            pytest.fail(
                f"Invalid blocks still tracked for finished request! "
                f"Request {request.request_id} should have been freed but "
                f"still has {len(block_ids[0])} blocks allocated."
            )
        # blocks list exists but is empty - this is fine, they were freed
|
||||
except KeyError:
|
||||
# expected - request completely removed from tracking
|
||||
pass
|
||||
|
||||
# critical: verify invalid block was evicted from prefix cache
|
||||
# the block should no longer have a hash (hash is reset on eviction)
|
||||
assert block.block_hash is None, (
|
||||
f"Invalid block {invalid_block_id} should have been evicted from cache "
|
||||
f"(hash should be None), but hash is still {block.block_hash}"
|
||||
)
|
||||
|
||||
|
||||
def test_async_recompute_blocks_not_cached_when_invalid(
|
||||
recompute_scheduler: Scheduler,
|
||||
):
|
||||
"""
|
||||
Test async recompute case - invalid blocks not cached after transfer.
|
||||
|
||||
When async KV loading has invalid blocks and retry_policy is 'recompute':
|
||||
1. Blocks are allocated but not cached yet
|
||||
2. When async transfer completes, only valid blocks should be cached
|
||||
3. Invalid blocks should never enter the prefix cache
|
||||
|
||||
    This test verifies correctness: the failed_recving_kv_req_ids protection
    ensures that only valid blocks are cached when the transfer completes, and
    that we only evict blocks from the cache that are already hashed in the
    block table.
"""
|
||||
from unittest.mock import patch
|
||||
|
||||
num_prompt_blocks = 100
|
||||
num_external_computed_blocks = 99
|
||||
invalid_block_idx = 50
|
||||
|
||||
num_prompt_tokens = num_prompt_blocks * recompute_scheduler.block_size
|
||||
num_external_computed_tokens = (
|
||||
num_external_computed_blocks * recompute_scheduler.block_size
|
||||
)
|
||||
|
||||
request = create_request(num_tokens=num_prompt_tokens)
|
||||
recompute_scheduler.add_request(request=request)
|
||||
|
||||
req_num_new_matched_tokens = {
|
||||
request.request_id: num_external_computed_tokens,
|
||||
}
|
||||
|
||||
# mock connector indicating async load
|
||||
recompute_scheduler.connector = Mock()
|
||||
recompute_scheduler.connector.get_num_new_matched_tokens.side_effect = (
|
||||
_make_get_num_new_matched_tokens(req_num_new_matched_tokens, True)
|
||||
)
|
||||
recompute_scheduler.connector.request_finished.return_value = (False, None)
|
||||
recompute_scheduler.connector.take_events.return_value = ()
|
||||
|
||||
scheduler_output = recompute_scheduler.schedule()
|
||||
|
||||
# request should be waiting for remote KVs
|
||||
assert len(recompute_scheduler.waiting) == 1
|
||||
assert request.status == RequestStatus.WAITING_FOR_REMOTE_KVS
|
||||
assert request.num_computed_tokens == 0
|
||||
|
||||
# get the allocated block IDs
|
||||
(req_block_ids,) = recompute_scheduler.kv_cache_manager.get_block_ids(
|
||||
request.request_id
|
||||
)
|
||||
invalid_block_id = req_block_ids[invalid_block_idx]
|
||||
invalid_block_ids = {invalid_block_id}
|
||||
|
||||
# get the block object to verify it's not cached yet and stays uncached
|
||||
block = recompute_scheduler.kv_cache_manager.block_pool.blocks[invalid_block_id]
|
||||
|
||||
# verify block has no hash before invalid blocks are reported
|
||||
assert block.block_hash is None, (
|
||||
"Async loading blocks should not be cached yet (no hash)"
|
||||
)
|
||||
|
||||
# report invalid blocks (transfer not finished yet)
|
||||
model_runner_output = create_model_runner_output(
|
||||
reqs=[],
|
||||
finished_recving=None, # transfer NOT finished
|
||||
invalid_block_ids=invalid_block_ids,
|
||||
use_eos=False,
|
||||
)
|
||||
|
||||
# critical: spy on evict_blocks to verify it's NOT called for async blocks
|
||||
original_evict_blocks = recompute_scheduler.kv_cache_manager.evict_blocks
|
||||
evict_blocks_calls = []
|
||||
|
||||
def evict_blocks_spy(block_ids):
|
||||
evict_blocks_calls.append(set(block_ids))
|
||||
return original_evict_blocks(block_ids)
|
||||
|
||||
with patch.object(
|
||||
recompute_scheduler.kv_cache_manager, "evict_blocks", evict_blocks_spy
|
||||
):
|
||||
recompute_scheduler.update_from_output(scheduler_output, model_runner_output)
|
||||
|
||||
# verify evict_blocks was NOT called (async blocks excluded from eviction)
|
||||
assert len(evict_blocks_calls) == 0, (
|
||||
f"evict_blocks should not be called for async-only invalid blocks, "
|
||||
f"but was called {len(evict_blocks_calls)} time(s) with {evict_blocks_calls}"
|
||||
)
|
||||
|
||||
# request should still be waiting (not finished with error due to recompute policy)
|
||||
assert request.status == RequestStatus.WAITING_FOR_REMOTE_KVS
|
||||
assert request.request_id in recompute_scheduler.failed_recving_kv_req_ids
|
||||
|
||||
# verify num_computed_tokens was truncated to before invalid block
|
||||
expected_valid_tokens = invalid_block_idx * recompute_scheduler.block_size
|
||||
assert request.num_computed_tokens == expected_valid_tokens
|
||||
|
||||
# verify invalid block still has no hash (was not evicted)
|
||||
assert block.block_hash is None, (
|
||||
f"Async loading blocks shouldn't be cached or evicted. "
|
||||
f"Block {invalid_block_id} hash should be None but is {block.block_hash}"
|
||||
)
|
||||
|
||||
# now simulate async transfer completing
|
||||
model_runner_output_2 = create_model_runner_output(
|
||||
reqs=[],
|
||||
finished_recving={request.request_id},
|
||||
invalid_block_ids=None,
|
||||
use_eos=False,
|
||||
)
|
||||
|
||||
recompute_scheduler.update_from_output(scheduler_output, model_runner_output_2)
|
||||
|
||||
# verify request is now marked as finished receiving and ready to be processed
|
||||
assert request.request_id in recompute_scheduler.finished_recving_kv_req_ids
|
||||
assert request.request_id in recompute_scheduler.failed_recving_kv_req_ids
|
||||
|
||||
# critical: verify invalid block still has no hash before recompute
|
||||
# the async transfer invalid data was never cached
|
||||
assert block.block_hash is None, (
|
||||
f"Invalid block {invalid_block_id} should not be cached before recompute "
|
||||
f"(hash should be None), but hash is {block.block_hash}"
|
||||
)
|
||||
|
||||
# critical end-to-end test: spy on cache_blocks to verify it's called with
|
||||
# the truncated num_computed_tokens value
|
||||
original_cache_blocks = recompute_scheduler.kv_cache_manager.cache_blocks
|
||||
cache_blocks_calls = []
|
||||
|
||||
def cache_blocks_spy(req, num_tokens):
|
||||
cache_blocks_calls.append((req.request_id, num_tokens))
|
||||
return original_cache_blocks(req, num_tokens)
|
||||
|
||||
with patch.object(
|
||||
recompute_scheduler.kv_cache_manager, "cache_blocks", cache_blocks_spy
|
||||
):
|
||||
# call schedule() again - this triggers _update_waiting_for_remote_kv()
|
||||
# which should call cache_blocks with the truncated value
|
||||
recompute_scheduler.schedule()
|
||||
|
||||
# verify cache_blocks was called with the truncated value
|
||||
assert len(cache_blocks_calls) == 1, (
|
||||
f"cache_blocks should be called exactly once, "
|
||||
f"got {len(cache_blocks_calls)} calls"
|
||||
)
|
||||
cached_req_id, cached_num_tokens = cache_blocks_calls[0]
|
||||
assert cached_req_id == request.request_id
|
||||
assert cached_num_tokens == expected_valid_tokens, (
|
||||
f"cache_blocks should be called with truncated value {expected_valid_tokens}, "
|
||||
f"but was called with {cached_num_tokens}"
|
||||
)
|
||||
|
||||
# request should now be RUNNING (scheduled immediately after transfer completes)
|
||||
# the flow is: WAITING_FOR_REMOTE_KVS -> WAITING -> RUNNING in same schedule() call
|
||||
assert request.status == RequestStatus.RUNNING
|
||||
|
||||
# num_computed_tokens should be >= expected_valid_tokens because the scheduler
|
||||
# will schedule additional new tokens (up to max_num_batched_tokens) for the request
|
||||
assert request.num_computed_tokens >= expected_valid_tokens, (
|
||||
f"num_computed_tokens should be at least {expected_valid_tokens}, "
|
||||
f"got {request.num_computed_tokens}"
|
||||
)
|
||||
|
||||
# request should no longer be in the failed/finished receiving sets
|
||||
assert request.request_id not in recompute_scheduler.failed_recving_kv_req_ids
|
||||
assert request.request_id not in recompute_scheduler.finished_recving_kv_req_ids
|
||||
|
||||
# request should be in the running queue
|
||||
assert request in recompute_scheduler.running
|
||||
60
tests/v1/kv_connector/unit/test_kv_connector_lifecyle.py
Normal file
@@ -0,0 +1,60 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from vllm.distributed.kv_transfer.kv_connector.v1.example_connector import (  # noqa: E501
    ExampleConnectorMetadata,
)
from vllm.distributed.kv_transfer.kv_transfer_state import (
    ensure_kv_transfer_initialized,
    get_kv_transfer_group,
)
from vllm.v1.core.sched.output import CachedRequestData, SchedulerOutput
from vllm.v1.worker.kv_connector_model_runner_mixin import KVConnectorModelRunnerMixin

# Importing utils registers TestExampleConnector with the factory
from .utils import create_vllm_config


def _make_empty_scheduler_output():
    return SchedulerOutput(
        scheduled_new_reqs=[],
        scheduled_cached_reqs=CachedRequestData.make_empty(),
        num_scheduled_tokens={},
        total_num_scheduled_tokens=0,
        scheduled_spec_decode_tokens={},
        scheduled_encoder_inputs={},
        num_common_prefix_blocks=[],
        finished_req_ids=set(),
        free_encoder_mm_hashes=[],
        kv_connector_metadata=ExampleConnectorMetadata(),
    )


def test_kv_connector_mixin_clears_metadata():
    vllm_config = create_vllm_config()
    vllm_config.kv_transfer_config.kv_connector = "TestExampleConnector"
    vllm_config.kv_transfer_config.kv_role = "kv_both"
    vllm_config.kv_transfer_config.kv_connector_extra_config["name"] = "unit"

    # Initialize the global connector instance
    ensure_kv_transfer_initialized(vllm_config)

    try:
        # Minimal scheduler output with empty metadata; mixin should still
        # bind/clear metadata even if no loads happen
        scheduler_output = _make_empty_scheduler_output()

        # Invoke the no-forward path which uses the mixin context manager
        KVConnectorModelRunnerMixin.kv_connector_no_forward(
            scheduler_output, vllm_config
        )

        # Verify clear_connector_metadata was called on the connector
        connector = get_kv_transfer_group()
        assert connector._connector_metadata is None
        # Test connector wrapper records method calls
        assert connector.call_record.get("bind_connector_metadata", 0) == 1
        assert connector.call_record.get("clear_connector_metadata", 0) == 1
    finally:
        # Ensure we clean up the global connector between tests
        KVConnectorModelRunnerMixin.ensure_kv_transfer_shutdown()
335
tests/v1/kv_connector/unit/test_kv_load_failure_recovery.py
Normal file
@@ -0,0 +1,335 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from collections.abc import Callable
from unittest.mock import Mock

import pytest

from vllm.v1.core.sched.scheduler import Scheduler
from vllm.v1.request import Request, RequestStatus

from .utils import (
    create_model_runner_output,
    create_request,
    create_scheduler,
    create_vllm_config,
)


def _make_get_num_new_matched_tokens(
    req_num_new_matched_tokens: dict[str, int],
    async_load,
) -> Callable[[Request, int], tuple[int, bool]]:
    def get_num_new_matched_tokens(request: Request, _: int) -> tuple[int, bool]:
        value = req_num_new_matched_tokens.get(request.request_id, 0)
        return value, async_load

    return get_num_new_matched_tokens
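

# The factory above stands in for the connector's get_num_new_matched_tokens
# hook: for each request it reports how many externally computed tokens the
# mocked KV source claims to have, plus whether they load asynchronously
# (the tests below wire it up via Mock(...).side_effect).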


@pytest.fixture
def scheduler():
    vllm_config = create_vllm_config()
    return create_scheduler(vllm_config)


@pytest.mark.parametrize(
    "num_prompt_blocks,num_external_computed_blocks,invalid_block_idxs",
    [
        (100, 99, {0, 98}),
        (100, 99, {50, 98}),
        (100, 99, {98}),
    ],
)
def test_async_load_failure(
    scheduler: Scheduler,
    num_prompt_blocks: int,
    num_external_computed_blocks: int,
    invalid_block_idxs: set[int],
):
    assert num_prompt_blocks >= num_external_computed_blocks

    num_prompt_tokens = num_prompt_blocks * scheduler.block_size
    num_external_computed_tokens = num_external_computed_blocks * scheduler.block_size

    request1 = create_request(num_tokens=num_prompt_tokens)
    scheduler.add_request(request=request1)
    request2 = create_request(num_tokens=num_prompt_tokens)
    scheduler.add_request(request=request2)
    request3 = create_request(num_tokens=num_prompt_tokens)
    scheduler.add_request(request=request3)

    # Mock KV connector method.
    # req_id -> num_external_computed_tokens
    req_num_new_matched_tokens = {
        request1.request_id: num_external_computed_tokens,
        request2.request_id: num_external_computed_tokens,
        request3.request_id: num_external_computed_tokens,
    }

    scheduler.connector = Mock()
    scheduler.connector.get_num_new_matched_tokens.side_effect = (
        _make_get_num_new_matched_tokens(req_num_new_matched_tokens, async_load=True)
    )
    scheduler.connector.take_events.return_value = ()

    scheduler_output = scheduler.schedule()

    assert len(scheduler.waiting) == 3
    for request in scheduler.waiting:
        assert request.num_computed_tokens == 0
        assert request.status == RequestStatus.WAITING_FOR_REMOTE_KVS
    assert scheduler.connector.get_num_new_matched_tokens.call_count == 3

    # Simulate a failure in loading some of request2's blocks.
    (req2_block_ids,) = scheduler.kv_cache_manager.get_block_ids(request2.request_id)
    invalid_block_ids = {req2_block_ids[i] for i in invalid_block_idxs}
    model_runner_output = create_model_runner_output(
        reqs=[],
        finished_recving={request1.request_id, request3.request_id},
        invalid_block_ids=invalid_block_ids,
        use_eos=True,
    )

    scheduler.update_from_output(scheduler_output, model_runner_output)

    min_invalid_block_idx = min(invalid_block_idxs)

    assert len(scheduler.waiting) == 3
    for request in scheduler.waiting:
        if request.request_id == request2.request_id:
            assert request.num_computed_tokens == (
                min_invalid_block_idx * scheduler.block_size
            )
        else:
            assert request.num_computed_tokens == 0
        assert request.status == RequestStatus.WAITING_FOR_REMOTE_KVS
    assert scheduler.failed_recving_kv_req_ids == {request2.request_id}
    assert scheduler.connector.get_num_new_matched_tokens.call_count == 3


@pytest.mark.parametrize(
    "num_prompt_blocks,num_external_computed_blocks,invalid_block_idxs",
    [
        (100, 99, {0, 98}),
        (100, 99, {50, 98}),
        (100, 99, {98}),
    ],
)
def test_sync_load_failure(
    scheduler: Scheduler,
    num_prompt_blocks: int,
    num_external_computed_blocks: int,
    invalid_block_idxs: set[int],
):
    assert num_prompt_blocks >= num_external_computed_blocks

    num_prompt_tokens = num_prompt_blocks * scheduler.block_size
    num_external_computed_tokens = num_external_computed_blocks * scheduler.block_size

    request1 = create_request(num_tokens=num_prompt_tokens)
    scheduler.add_request(request=request1)
    request2 = create_request(num_tokens=num_prompt_tokens)
    scheduler.add_request(request=request2)
    request3 = create_request(num_tokens=num_prompt_tokens)
    scheduler.add_request(request=request3)

    # Mock KV connector method.
    # req_id -> num_external_computed_tokens
    req_num_new_matched_tokens = {
        request1.request_id: num_external_computed_tokens,
        request2.request_id: num_external_computed_tokens,
        request3.request_id: num_external_computed_tokens,
    }

    scheduler.connector = Mock()
    scheduler.connector.get_num_new_matched_tokens.side_effect = (
        _make_get_num_new_matched_tokens(req_num_new_matched_tokens, async_load=False)
    )
    scheduler.connector.request_finished.return_value = (False, None)
    scheduler.connector.take_events.return_value = ()

    scheduler_output = scheduler.schedule()

    # req_id -> num_computed_tokens
    expected_computed_tokens = {
        request1.request_id: num_external_computed_tokens,
        request2.request_id: num_external_computed_tokens,
        request3.request_id: num_external_computed_tokens,
    }

    assert len(scheduler.running) == 3
    assert len(scheduler_output.scheduled_new_reqs) == 3
    for request in scheduler_output.scheduled_new_reqs:
        assert request.num_computed_tokens == expected_computed_tokens[request.req_id]
    assert scheduler.connector.get_num_new_matched_tokens.call_count == 3

    # Simulate a failure in loading some of request2's blocks.
    req2_block_ids = scheduler_output.scheduled_new_reqs[1].block_ids[0]
    invalid_block_ids = {req2_block_ids[i] for i in invalid_block_idxs}
    model_runner_output = create_model_runner_output(
        [request1, request2, request3],
        invalid_block_ids=invalid_block_ids,
        use_eos=True,
    )

    scheduler.update_from_output(scheduler_output, model_runner_output)

    assert len(scheduler.running) == 1
    assert scheduler.running[0].request_id == request2.request_id
    assert scheduler.running[0].num_computed_tokens == (
        min(invalid_block_idxs) * scheduler.block_size
    )
    assert scheduler.connector.get_num_new_matched_tokens.call_count == 3
    assert scheduler.connector.request_finished.call_count == 2


@pytest.mark.parametrize(
    "num_prompt_blocks,"
    "num_external_computed_blocks,"
    "num_common_prefix_blocks,"
    "invalid_block_idxs",
    [
        (100, 99, 50, {0, 49}),
        (100, 99, 50, {25, 49}),
        (100, 99, 50, {49}),
    ],
)
def test_sync_load_failure_with_shared_blocks(
    scheduler: Scheduler,
    num_prompt_blocks: int,
    num_external_computed_blocks: int,
    num_common_prefix_blocks: int,
    invalid_block_idxs: set[int],
):
    assert num_prompt_blocks >= num_external_computed_blocks >= num_common_prefix_blocks

    num_prompt_tokens = num_prompt_blocks * scheduler.block_size
    num_external_computed_tokens = num_external_computed_blocks * scheduler.block_size
    common_prefix_len = num_common_prefix_blocks * scheduler.block_size

    request1 = create_request(
        num_tokens=num_prompt_tokens, common_prefix_len=common_prefix_len
    )
    scheduler.add_request(request=request1)
    request2 = create_request(
        num_tokens=num_prompt_tokens, common_prefix_len=common_prefix_len
    )
    scheduler.add_request(request=request2)

    # Mock KV connector method.
    # req_id -> num_external_computed_tokens
    req_num_new_matched_tokens = {
        request1.request_id: num_external_computed_tokens,
    }

    scheduler.connector = Mock()
    scheduler.connector.get_num_new_matched_tokens.side_effect = (
        _make_get_num_new_matched_tokens(req_num_new_matched_tokens, async_load=False)
    )
    scheduler.connector.take_events.return_value = ()

    scheduler_output = scheduler.schedule()

    # req_id -> num_computed_tokens
    expected_computed_tokens = {
        request1.request_id: num_external_computed_tokens,
        request2.request_id: common_prefix_len,
    }

    assert len(scheduler.running) == 2
    assert len(scheduler_output.scheduled_new_reqs) == 2
    for request in scheduler_output.scheduled_new_reqs:
        assert request.num_computed_tokens == expected_computed_tokens[request.req_id]
    assert scheduler.connector.get_num_new_matched_tokens.call_count == 2

    # Simulate a failure in loading some of the shared blocks.
    req1_block_ids = scheduler_output.scheduled_new_reqs[0].block_ids[0]
    invalid_block_ids = {req1_block_ids[i] for i in invalid_block_idxs}
    model_runner_output = create_model_runner_output(
        [request1, request2], invalid_block_ids=invalid_block_ids, use_eos=True
    )

    scheduler.update_from_output(scheduler_output, model_runner_output)

    # req_id -> num_computed_tokens
    # all the common prefix blocks will be computed by request1
    expected_computed_tokens = {
        request1.request_id: min(invalid_block_idxs) * scheduler.block_size,
        request2.request_id: common_prefix_len,
    }

    assert len(scheduler.running) == 2
    for request in scheduler.running:
        assert (
            request.num_computed_tokens == expected_computed_tokens[request.request_id]
        )
    assert scheduler.connector.get_num_new_matched_tokens.call_count == 2


@pytest.mark.parametrize(
    "num_prompt_blocks,num_external_computed_blocks,invalid_block_idxs",
    [
        (100, 99, {0, 50, 98}),
        (100, 99, {98, 50, 0}),
    ],
)
def test_async_progressive_load_failure(
    scheduler: Scheduler,
    num_prompt_blocks: int,
    num_external_computed_blocks: int,
    invalid_block_idxs: set[int],
):
    assert num_prompt_blocks >= num_external_computed_blocks

    num_prompt_tokens = num_prompt_blocks * scheduler.block_size
    num_external_computed_tokens = num_external_computed_blocks * scheduler.block_size

    request = create_request(num_tokens=num_prompt_tokens)
    scheduler.add_request(request=request)

    # Mock KV connector method.
    # req_id -> num_external_computed_tokens
    req_num_new_matched_tokens = {
        request.request_id: num_external_computed_tokens,
    }

    scheduler.connector = Mock()
    scheduler.connector.get_num_new_matched_tokens.side_effect = (
        _make_get_num_new_matched_tokens(req_num_new_matched_tokens, async_load=True)
    )
    scheduler.connector.take_events.return_value = ()

    scheduler_output = scheduler.schedule()

    assert len(scheduler.waiting) == 1
    assert scheduler.waiting.peek_request().request_id == request.request_id
    assert request.num_computed_tokens == 0
    assert request.status == RequestStatus.WAITING_FOR_REMOTE_KVS
    assert scheduler.connector.get_num_new_matched_tokens.call_count == 1

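    # Sentinel: start one past the largest possible index; the loop below
    # lowers it to the smallest invalid block index reported so far.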
    min_invalid_block_idx = max(invalid_block_idxs) + 1
    # Simulate failures when progressively loading request blocks.
    for invalid_block_idx in invalid_block_idxs:
        (req_block_ids,) = scheduler.kv_cache_manager.get_block_ids(request.request_id)
        invalid_block_ids = {req_block_ids[invalid_block_idx]}
        model_runner_output = create_model_runner_output(
            reqs=[],
            finished_recving=set(),
            invalid_block_ids=invalid_block_ids,
            use_eos=True,
        )

        scheduler.update_from_output(scheduler_output, model_runner_output)

        min_invalid_block_idx = min(min_invalid_block_idx, invalid_block_idx)

        assert len(scheduler.waiting) == 1
        assert scheduler.waiting.peek_request().request_id == request.request_id
        assert request.num_computed_tokens == (
            min_invalid_block_idx * scheduler.block_size
        )
        assert request.status == RequestStatus.WAITING_FOR_REMOTE_KVS
        assert scheduler.failed_recving_kv_req_ids == {request.request_id}
        assert scheduler.connector.get_num_new_matched_tokens.call_count == 1
756
tests/v1/kv_connector/unit/test_lmcache_connector.py
Normal file
@@ -0,0 +1,756 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from unittest.mock import MagicMock

import pytest

from vllm.distributed.kv_events import BlockStored
from vllm.distributed.kv_transfer.kv_connector.v1.lmcache_connector import (
    LMCacheConnectorV1,
    LMCacheKVEvents,
)
from vllm.v1.outputs import KVConnectorOutput


@pytest.fixture
def mock_lmcache_engine_event():
    """Create a mock event object that mimics what the lmcache engine returns."""

    class MockEvent:
        def __init__(
            self,
            block_hashes,
            parent_block_hash,
            token_ids,
            lora_id,
            block_size,
            medium,
        ):
            self.block_hashes = block_hashes
            self.parent_block_hash = parent_block_hash
            self.token_ids = token_ids
            self.lora_id = lora_id
            self.block_size = block_size
            self.medium = medium

    return MockEvent(
        block_hashes=["hash1", "hash2"],
        parent_block_hash="parent_hash",
        token_ids=[1, 2, 3, 4],
        lora_id=None,
        block_size=16,
        medium="GPU",
    )


@pytest.fixture
def mock_connector():
    """Create a mock LMCacheConnectorV1 instance with mocked dependencies."""
    connector = MagicMock(spec=LMCacheConnectorV1)
    connector._kv_cache_events = None
    connector._lmcache_engine = MagicMock()

    # Make the methods use the real implementation
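    # (Accessing the unbound function's __get__ yields a bound method, so
    # calling these attributes on the MagicMock runs the real
    # LMCacheConnectorV1 logic against the mock's state.)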
    connector.get_kv_connector_kv_cache_events = (
        LMCacheConnectorV1.get_kv_connector_kv_cache_events.__get__(
            connector, LMCacheConnectorV1
        )
    )
    connector.update_connector_output = (
        LMCacheConnectorV1.update_connector_output.__get__(
            connector, LMCacheConnectorV1
        )
    )
    connector.take_events = LMCacheConnectorV1.take_events.__get__(
        connector, LMCacheConnectorV1
    )

    return connector


class TestGetKVConnectorKVCacheEvents:
    """Test get_kv_connector_kv_cache_events method."""

    def test_returns_none_when_no_events(self, mock_connector):
        """Test that None is returned when lmcache engine has no events."""
        mock_connector._lmcache_engine.get_kv_events.return_value = None

        result = mock_connector.get_kv_connector_kv_cache_events()

        assert result is None
        mock_connector._lmcache_engine.get_kv_events.assert_called_once()

    def test_returns_none_when_empty_list(self, mock_connector):
        """Test that None is returned when lmcache engine returns empty list."""
        mock_connector._lmcache_engine.get_kv_events.return_value = []

        result = mock_connector.get_kv_connector_kv_cache_events()

        assert result is None

    def test_converts_single_event(self, mock_connector, mock_lmcache_engine_event):
        """Test conversion of a single event from lmcache engine format."""
        mock_connector._lmcache_engine.get_kv_events.return_value = [
            mock_lmcache_engine_event
        ]

        result = mock_connector.get_kv_connector_kv_cache_events()

        assert result is not None
        assert isinstance(result, LMCacheKVEvents)
        assert result.get_number_of_workers() == 1

        events = result.get_all_events()
        assert len(events) == 1
        assert isinstance(events[0], BlockStored)
        assert events[0].block_hashes == ["hash1", "hash2"]
        assert events[0].parent_block_hash == "parent_hash"
        assert events[0].token_ids == [1, 2, 3, 4]
        assert events[0].lora_id is None
        assert events[0].block_size == 16
        assert events[0].medium == "GPU"

    def test_converts_multiple_events(self, mock_connector):
        """Test conversion of multiple events from lmcache engine format."""

        class MockEvent:
            def __init__(self, i):
                self.block_hashes = [f"hash{i}"]
                self.parent_block_hash = f"parent{i}"
                self.token_ids = [i]
                self.lora_id = None
                self.block_size = 16
                self.medium = "GPU"

        events = [MockEvent(i) for i in range(5)]
        mock_connector._lmcache_engine.get_kv_events.return_value = events

        result = mock_connector.get_kv_connector_kv_cache_events()

        assert result is not None
        assert isinstance(result, LMCacheKVEvents)

        converted_events = result.get_all_events()
        assert len(converted_events) == 5

        for i, event in enumerate(converted_events):
            assert isinstance(event, BlockStored)
            assert event.block_hashes == [f"hash{i}"]
            assert event.parent_block_hash == f"parent{i}"
            assert event.token_ids == [i]

    def test_preserves_event_attributes(self, mock_connector):
        """Test that all event attributes are correctly preserved."""

        class MockEventWithLora:
            def __init__(self):
                self.block_hashes = ["hash_a", "hash_b", "hash_c"]
                self.parent_block_hash = "parent_xyz"
                self.token_ids = [100, 200, 300]
                self.lora_id = 42
                self.block_size = 32
                self.medium = "DISK"

        mock_connector._lmcache_engine.get_kv_events.return_value = [
            MockEventWithLora()
        ]

        result = mock_connector.get_kv_connector_kv_cache_events()

        events = result.get_all_events()
        event = events[0]

        assert event.block_hashes == ["hash_a", "hash_b", "hash_c"]
        assert event.parent_block_hash == "parent_xyz"
        assert event.token_ids == [100, 200, 300]
        assert event.lora_id == 42
        assert event.block_size == 32
        assert event.medium == "DISK"

    def test_handles_none_parent_block_hash(self, mock_connector):
        """Test handling of events with None parent_block_hash."""

        class MockEventNoParent:
            def __init__(self):
                self.block_hashes = ["hash1"]
                self.parent_block_hash = None
                self.token_ids = [1, 2]
                self.lora_id = None
                self.block_size = 16
                self.medium = "GPU"

        mock_connector._lmcache_engine.get_kv_events.return_value = [
            MockEventNoParent()
        ]

        result = mock_connector.get_kv_connector_kv_cache_events()

        events = result.get_all_events()
        assert events[0].parent_block_hash is None


class TestUpdateConnectorOutput:
    """Test update_connector_output method."""

    def test_does_nothing_when_kv_cache_events_is_none(self, mock_connector):
        """Test that method returns early when kv_cache_events is None."""
        connector_output = KVConnectorOutput(kv_cache_events=None)

        mock_connector.update_connector_output(connector_output)

        assert mock_connector._kv_cache_events is None

    def test_does_nothing_when_kv_cache_events_is_not_lmcache_kv_events(
        self, mock_connector
    ):
        """Test that method returns early when kv_cache_events is not
        LMCacheKVEvents."""
        # Create a mock object that is not LMCacheKVEvents
        fake_events = MagicMock()
        connector_output = KVConnectorOutput(kv_cache_events=fake_events)

        mock_connector.update_connector_output(connector_output)

        assert mock_connector._kv_cache_events is None

    def test_sets_kv_cache_events_when_none(self, mock_connector):
        """Test that _kv_cache_events is set when it was None."""
        kv_events = LMCacheKVEvents(num_workers=1)
        event = BlockStored(
            block_hashes=["hash1"],
            parent_block_hash=None,
            token_ids=[1, 2],
            block_size=16,
            lora_id=None,
            medium="GPU",
        )
        kv_events.add_events([event])

        connector_output = KVConnectorOutput(kv_cache_events=kv_events)

        mock_connector.update_connector_output(connector_output)

        assert mock_connector._kv_cache_events is kv_events

    def test_adds_events_when_kv_cache_events_already_exists(self, mock_connector):
        """Test that events are added when _kv_cache_events already exists."""
        # Set up existing events
        existing_events = LMCacheKVEvents(num_workers=2)
        event1 = BlockStored(
            block_hashes=["hash1"],
            parent_block_hash=None,
            token_ids=[1],
            block_size=16,
            lora_id=None,
            medium="GPU",
        )
        existing_events.add_events([event1])
        existing_events.add_events([event1])  # Simulate 2 workers reporting

        mock_connector._kv_cache_events = existing_events

        # Create new events to add
        new_events = LMCacheKVEvents(num_workers=1)
        event2 = BlockStored(
            block_hashes=["hash2"],
            parent_block_hash=None,
            token_ids=[2],
            block_size=16,
            lora_id=None,
            medium="GPU",
        )
        new_events.add_events([event2])

        connector_output = KVConnectorOutput(kv_cache_events=new_events)

        mock_connector.update_connector_output(connector_output)

        # Check that events were added
        all_events = mock_connector._kv_cache_events.get_all_events()
        assert len(all_events) == 3  # 2 from existing + 1 from new
        assert event1 in all_events
        assert event2 in all_events

    def test_increments_workers_when_kv_cache_events_already_exists(
        self, mock_connector
    ):
        """Test that worker count is incremented correctly."""
        # Set up existing events with 2 workers
        existing_events = LMCacheKVEvents(num_workers=2)
        mock_connector._kv_cache_events = existing_events

        # Create new events from 3 workers
        new_events = LMCacheKVEvents(num_workers=3)
        event = BlockStored(
            block_hashes=["hash1"],
            parent_block_hash=None,
            token_ids=[1],
            block_size=16,
            lora_id=None,
            medium="GPU",
        )
        new_events.add_events([event])

        connector_output = KVConnectorOutput(kv_cache_events=new_events)

        mock_connector.update_connector_output(connector_output)

        # Worker count should be 2 + 3 = 5
        assert mock_connector._kv_cache_events.get_number_of_workers() == 5

    def test_multiple_updates(self, mock_connector):
        """Test multiple consecutive updates."""
        # First update
        events1 = LMCacheKVEvents(num_workers=1)
        event1 = BlockStored(
            block_hashes=["hash1"],
            parent_block_hash=None,
            token_ids=[1],
            block_size=16,
            lora_id=None,
            medium="GPU",
        )
        events1.add_events([event1])
        output1 = KVConnectorOutput(kv_cache_events=events1)
        mock_connector.update_connector_output(output1)

        # Second update
        events2 = LMCacheKVEvents(num_workers=2)
        event2 = BlockStored(
            block_hashes=["hash2"],
            parent_block_hash=None,
            token_ids=[2],
            block_size=16,
            lora_id=None,
            medium="GPU",
        )
        events2.add_events([event2])
        output2 = KVConnectorOutput(kv_cache_events=events2)
        mock_connector.update_connector_output(output2)

        # Third update
        events3 = LMCacheKVEvents(num_workers=1)
        event3 = BlockStored(
            block_hashes=["hash3"],
            parent_block_hash=None,
            token_ids=[3],
            block_size=16,
            lora_id=None,
            medium="GPU",
        )
        events3.add_events([event3])
        output3 = KVConnectorOutput(kv_cache_events=events3)
        mock_connector.update_connector_output(output3)

        # Check final state
        all_events = mock_connector._kv_cache_events.get_all_events()
        assert len(all_events) == 3
        assert mock_connector._kv_cache_events.get_number_of_workers() == 4  # 1+2+1

    def test_updates_with_empty_events(self, mock_connector):
        """Test updating with empty event lists."""
        # First update with actual events
        events1 = LMCacheKVEvents(num_workers=1)
        event1 = BlockStored(
            block_hashes=["hash1"],
            parent_block_hash=None,
            token_ids=[1],
            block_size=16,
            lora_id=None,
            medium="GPU",
        )
        events1.add_events([event1])
        output1 = KVConnectorOutput(kv_cache_events=events1)
        mock_connector.update_connector_output(output1)

        # Second update with empty events
        events2 = LMCacheKVEvents(num_workers=2)
        # No events added
        output2 = KVConnectorOutput(kv_cache_events=events2)
        mock_connector.update_connector_output(output2)

        # Should still have the original event
        all_events = mock_connector._kv_cache_events.get_all_events()
        assert len(all_events) == 1
        assert mock_connector._kv_cache_events.get_number_of_workers() == 3


class TestTakeEvents:
    """Test take_events method."""

    def test_yields_nothing_when_kv_cache_events_is_none(self, mock_connector):
        """Test that nothing is yielded when _kv_cache_events is None."""
        mock_connector._kv_cache_events = None

        events = list(mock_connector.take_events())

        assert events == []

    def test_yields_events_and_clears(self, mock_connector):
        """Test that events are yielded and then cleared."""
        # Set up events
        kv_events = LMCacheKVEvents(num_workers=1)
        event1 = BlockStored(
            block_hashes=["hash1"],
            parent_block_hash=None,
            token_ids=[1],
            block_size=16,
            lora_id=None,
            medium="GPU",
        )
        event2 = BlockStored(
            block_hashes=["hash2"],
            parent_block_hash=None,
            token_ids=[2],
            block_size=16,
            lora_id=None,
            medium="GPU",
        )
        kv_events.add_events([event1, event2])
        mock_connector._kv_cache_events = kv_events

        # Take events
        events = list(mock_connector.take_events())

        # Check that events were yielded
        assert len(events) == 2
        assert event1 in events
        assert event2 in events

        # Check that _kv_cache_events was cleared
        assert mock_connector._kv_cache_events is None

    def test_aggregates_before_yielding(self, mock_connector):
        """Test that events are aggregated before yielding."""
        # Set up events from multiple workers
        kv_events = LMCacheKVEvents(num_workers=3)
        common_event = BlockStored(
            block_hashes=["hash_common"],
            parent_block_hash=None,
            token_ids=[1],
            block_size=16,
            lora_id=None,
            medium="GPU",
        )
        uncommon_event = BlockStored(
            block_hashes=["hash_uncommon"],
            parent_block_hash=None,
            token_ids=[2],
            block_size=16,
            lora_id=None,
            medium="GPU",
        )

        # All 3 workers report common_event
        kv_events.add_events([common_event])
        kv_events.add_events([common_event])
        kv_events.add_events([common_event])

        # Only 1 worker reports uncommon_event
        kv_events.add_events([uncommon_event])

        mock_connector._kv_cache_events = kv_events

        # Take events
        events = list(mock_connector.take_events())

        # Only the common event should be yielded
        assert len(events) == 1
        assert events[0] == common_event

    def test_multiple_take_events_calls(self, mock_connector):
        """Test calling take_events multiple times."""
        # First call with events
        kv_events1 = LMCacheKVEvents(num_workers=1)
        event1 = BlockStored(
            block_hashes=["hash1"],
            parent_block_hash=None,
            token_ids=[1],
            block_size=16,
            lora_id=None,
            medium="GPU",
        )
        kv_events1.add_events([event1])
        mock_connector._kv_cache_events = kv_events1

        events1 = list(mock_connector.take_events())
        assert len(events1) == 1
        assert events1[0] == event1
        assert mock_connector._kv_cache_events is None

        # Second call with no events
        events2 = list(mock_connector.take_events())
        assert events2 == []

        # Third call after adding new events
        kv_events2 = LMCacheKVEvents(num_workers=1)
        event2 = BlockStored(
            block_hashes=["hash2"],
            parent_block_hash=None,
            token_ids=[2],
            block_size=16,
            lora_id=None,
            medium="GPU",
        )
        kv_events2.add_events([event2])
        mock_connector._kv_cache_events = kv_events2

        events3 = list(mock_connector.take_events())
        assert len(events3) == 1
        assert events3[0] == event2

    def test_yields_empty_after_aggregation_removes_all(self, mock_connector):
        """Test that nothing is yielded if aggregation removes all events."""
        # Set up events from 2 workers with no common events
        kv_events = LMCacheKVEvents(num_workers=2)
        event1 = BlockStored(
            block_hashes=["hash1"],
            parent_block_hash=None,
            token_ids=[1],
            block_size=16,
            lora_id=None,
            medium="GPU",
        )
        event2 = BlockStored(
            block_hashes=["hash2"],
            parent_block_hash=None,
            token_ids=[2],
            block_size=16,
            lora_id=None,
            medium="GPU",
        )

        # Worker 1 reports event1
        kv_events.add_events([event1])
        # Worker 2 reports event2
        kv_events.add_events([event2])

        mock_connector._kv_cache_events = kv_events

        # Take events
        events = list(mock_connector.take_events())

        # No common events, so nothing should be yielded
        assert events == []
        assert mock_connector._kv_cache_events is None


class TestIntegrationScenarios:
    """Test integration scenarios."""

    def test_full_workflow(self, mock_connector, mock_lmcache_engine_event):
        """Test a complete workflow from getting events to taking them."""
        # Step 1: Get events from lmcache engine
        mock_connector._lmcache_engine.get_kv_events.return_value = [
            mock_lmcache_engine_event
        ]
        kv_events = mock_connector.get_kv_connector_kv_cache_events()

        assert kv_events is not None
        assert len(kv_events.get_all_events()) == 1

        # Step 2: Update connector output (simulate receiving from worker)
        output1 = KVConnectorOutput(kv_cache_events=kv_events)
        mock_connector.update_connector_output(output1)

        assert mock_connector._kv_cache_events is not None

        # Step 3: Take events
        taken_events = list(mock_connector.take_events())

        assert len(taken_events) == 1
        assert mock_connector._kv_cache_events is None

    def test_multiple_workers_workflow(self, mock_connector):
        """Test workflow with multiple workers."""

        class MockEvent:
            def __init__(self, hash_val):
                self.block_hashes = [hash_val]
                self.parent_block_hash = None
                self.token_ids = [1]
                self.lora_id = None
                self.block_size = 16
                self.medium = "GPU"

        # Worker 1
        mock_connector._lmcache_engine.get_kv_events.return_value = [
            MockEvent("hash_common"),
            MockEvent("hash_worker1"),
        ]
        kv_events1 = mock_connector.get_kv_connector_kv_cache_events()
        output1 = KVConnectorOutput(kv_cache_events=kv_events1)
        mock_connector.update_connector_output(output1)

        # Worker 2
        mock_connector._lmcache_engine.get_kv_events.return_value = [
            MockEvent("hash_common"),
            MockEvent("hash_worker2"),
        ]
        kv_events2 = mock_connector.get_kv_connector_kv_cache_events()
        output2 = KVConnectorOutput(kv_cache_events=kv_events2)
        mock_connector.update_connector_output(output2)

        # Take events (should only get common events)
        taken_events = list(mock_connector.take_events())

        # With aggregation, only events reported by both workers should be present
        # In this case, hash_common was reported by both
        event_hashes = [e.block_hashes[0] for e in taken_events]
        assert "hash_common" in event_hashes

    def test_empty_workflow(self, mock_connector):
        """Test workflow when there are no events at any stage."""
        # Get events returns None
        mock_connector._lmcache_engine.get_kv_events.return_value = None
        kv_events = mock_connector.get_kv_connector_kv_cache_events()

        assert kv_events is None

        # Update with None
        output = KVConnectorOutput(kv_cache_events=None)
        mock_connector.update_connector_output(output)

        # Take events
        taken_events = list(mock_connector.take_events())

        assert taken_events == []
        assert mock_connector._kv_cache_events is None

    def test_repeated_cycles(self, mock_connector):
        """Test multiple cycles of the complete workflow."""

        class MockEvent:
            def __init__(self, cycle_num):
                self.block_hashes = [f"hash_cycle_{cycle_num}"]
                self.parent_block_hash = None
                self.token_ids = [cycle_num]
                self.lora_id = None
                self.block_size = 16
                self.medium = "GPU"

        for cycle in range(3):
            # Get events
            mock_connector._lmcache_engine.get_kv_events.return_value = [
                MockEvent(cycle)
            ]
            kv_events = mock_connector.get_kv_connector_kv_cache_events()

            # Update
            output = KVConnectorOutput(kv_cache_events=kv_events)
            mock_connector.update_connector_output(output)

            # Take
            taken_events = list(mock_connector.take_events())

            # Verify
            assert len(taken_events) == 1
            assert taken_events[0].block_hashes[0] == f"hash_cycle_{cycle}"
            assert mock_connector._kv_cache_events is None

    def test_lmcache_kv_events_aggregation(self):
        """
        Test LMCacheKVEvents aggregation across TP ranks using
        KVOutputAggregator (used by MultiprocExecutor).
        """
        from vllm.distributed.kv_transfer.kv_connector.utils import KVOutputAggregator
        from vllm.v1.outputs import ModelRunnerOutput

        # Create KVOutputAggregator for 3 workers (simulating TP=3)
        aggregator = KVOutputAggregator(expected_finished_count=3)

        # Define common and unique events
        common_event = BlockStored(
            block_hashes=["hash_common"],
            parent_block_hash="parent_common",
            token_ids=[1, 2, 3],
            block_size=16,
            lora_id=None,
            medium="GPU",
        )

        worker1_unique_event = BlockStored(
            block_hashes=["hash_worker1"],
            parent_block_hash="parent_w1",
            token_ids=[4, 5],
            block_size=16,
            lora_id=None,
            medium="GPU",
        )

        worker2_unique_event = BlockStored(
            block_hashes=["hash_worker2"],
            parent_block_hash="parent_w2",
            token_ids=[6, 7],
            block_size=16,
            lora_id=None,
            medium="GPU",
        )

        worker3_unique_event = BlockStored(
            block_hashes=["hash_worker3"],
            parent_block_hash="parent_w3",
            token_ids=[8, 9],
            block_size=16,
            lora_id=None,
            medium="GPU",
        )

        # Create events for each worker
        # Worker 0: reports common event and its unique event
        worker0_events = LMCacheKVEvents(num_workers=1)
        worker0_events.add_events([common_event, worker1_unique_event])

        # Worker 1: reports common event and its unique event
        worker1_events = LMCacheKVEvents(num_workers=1)
        worker1_events.add_events([common_event, worker2_unique_event])

        # Worker 2: reports common event and its unique event
        worker2_events = LMCacheKVEvents(num_workers=1)
        worker2_events.add_events([common_event, worker3_unique_event])

        # Create ModelRunnerOutput instances for each worker
        worker_outputs = []
        for i, worker_events in enumerate(
            [worker0_events, worker1_events, worker2_events]
        ):
            output = ModelRunnerOutput(
                req_ids=[f"req_{i}"],
                req_id_to_index={f"req_{i}": 0},
                sampled_token_ids=[[123]],  # dummy token
                logprobs=None,
                prompt_logprobs_dict={},
                pooler_output=[None],
                kv_connector_output=KVConnectorOutput(
                    finished_sending=set([f"req_{i}_send"])
                    if i < 2
                    else None,  # Workers 0,1 finished sending
                    finished_recving=set([f"req_{i}_recv"])
                    if i > 0
                    else None,  # Workers 1,2 finished receiving
                    kv_cache_events=worker_events,
                ),
            )
            worker_outputs.append(output)

        # Use the real aggregation mechanism (like MultiprocExecutor.execute_model)
        aggregated_output = aggregator.aggregate(worker_outputs, output_rank=0)
        kv_cache_events = aggregated_output.kv_connector_output.kv_cache_events

        assert isinstance(kv_cache_events, LMCacheKVEvents)

        # After aggregation, events should be combined from all workers
        # The aggregator doesn't automatically aggregate events, so we need to call
        # aggregate() to get only common events
        kv_cache_events.aggregate()
        aggregated_events = kv_cache_events.get_all_events()

        # Only the common event should remain after aggregation
        # because it's the only event reported by all 3 workers
        assert len(aggregated_events) == 1
        assert aggregated_events[0] == common_event

        # Verify the common event properties
        assert aggregated_events[0].block_hashes == ["hash_common"]
        assert aggregated_events[0].parent_block_hash == "parent_common"
        assert aggregated_events[0].token_ids == [1, 2, 3]
228
tests/v1/kv_connector/unit/test_lmcache_integration.py
Normal file
@@ -0,0 +1,228 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

# NOTE: if your PR has broken one of the tests here (sorry),
# kindly patch the corresponding integration in
# /vllm/distributed/kv_transfer/kv_connector/v1/lmcache_integration/vllm_v1_adapter.py
# or reach out to @aposataC for assistance

# Assumption vs. Correctness Tests:
# these unit tests do *not* test correctness of LMCache-side or vLLM-side logic;
# they exist to ensure that the assumptions LMCache makes about vLLM's
# interface remain stable

import pytest

from vllm.platforms import current_platform


def assumes(obj, attr, is_callable=False, is_instance_of=None):
    import inspect
    from dataclasses import is_dataclass

    assumption_msg = (
        f"LMCache connector currently assumes that {obj} has a(n) {attr} attribute"
    )
    if hasattr(obj, attr):
        attr_value = getattr(obj, attr)
    elif is_dataclass(obj) and attr in getattr(obj, "__dataclass_fields__", {}):
        field = obj.__dataclass_fields__[attr]
        field_type = field.type
        origin = getattr(field_type, "__origin__", None)
        if origin is not None:
            field_type = origin
        attr_value = field_type
    else:
        raise AssertionError(assumption_msg)
    if is_callable:
        assumption_msg += f" and that {obj}.{attr} is a callable"
        assert callable(attr_value), assumption_msg
    if is_instance_of:
        assumption_msg += f" and that {obj}.{attr} is an instance of {is_instance_of}"
        if isinstance(attr_value, property):
            fget = attr_value.fget
            assert fget is not None, f"Property {obj}.{attr} has no fget"
            sig = inspect.signature(fget)
            ret_anno = sig.return_annotation
            assert ret_anno is not inspect._empty, (
                f"Property {obj}.{attr} has no return annotation"
            )
            assert ret_anno == is_instance_of, assumption_msg
        else:
            if isinstance(attr_value, type):
                assert attr_value is is_instance_of, assumption_msg
            else:
                assert isinstance(attr_value, is_instance_of), assumption_msg


@pytest.mark.skipif(
    current_platform.is_rocm(), reason="Requires libcudart.so, not available on ROCm"
)
def test_multimodal_interface():
    # protect against interface changes
    from vllm.multimodal.inputs import PlaceholderRange

    assumes(PlaceholderRange, "offset")
    assumes(PlaceholderRange, "length")


@pytest.mark.skipif(
    current_platform.is_rocm(), reason="Requires libcudart.so, not available on ROCm"
)
def test_config_interface():
    # protect against interface changes
    from vllm.config import VllmConfig
    from vllm.config.cache import CacheConfig
    from vllm.config.kv_transfer import KVTransferConfig
    from vllm.config.model import ModelConfig
    from vllm.config.parallel import ParallelConfig

    assumes(VllmConfig, "model_config")
    assumes(VllmConfig, "cache_config")
    assumes(VllmConfig, "parallel_config")
    assumes(VllmConfig, "kv_transfer_config")

    assumes(KVTransferConfig, "kv_role")
    assumes(KVTransferConfig, "kv_connector_extra_config")

    assumes(ModelConfig, "use_mla", is_instance_of=bool)
    assumes(ModelConfig, "dtype")
    assumes(ModelConfig, "max_model_len")
    assumes(ModelConfig, "get_vocab_size", is_callable=True)
    assumes(ModelConfig, "get_num_attention_heads", is_callable=True)
    assumes(ModelConfig, "get_num_kv_heads", is_callable=True)
    assumes(ModelConfig, "get_head_size", is_callable=True)
    assumes(ModelConfig, "get_num_layers", is_callable=True)
    assumes(ModelConfig, "get_num_kv_heads", is_callable=True)
    assumes(ModelConfig, "model")

    assumes(ParallelConfig, "world_size")
    assumes(ParallelConfig, "rank")
    assumes(ParallelConfig, "tensor_parallel_size")
    assumes(ParallelConfig, "pipeline_parallel_size")
    assumes(ParallelConfig, "data_parallel_size_local")
    assumes(ParallelConfig, "data_parallel_rank_local")

    assumes(CacheConfig, "cache_dtype")
    assumes(CacheConfig, "block_size")
    assumes(CacheConfig, "gpu_memory_utilization")

    # kv metadata minimal case
    from vllm.utils.torch_utils import get_kv_cache_torch_dtype

    model_config = ModelConfig(dtype="bfloat16")
    parallel_config = ParallelConfig()
    cache_config = CacheConfig(cache_dtype="bfloat16")
    kv_dtype = get_kv_cache_torch_dtype(cache_config.cache_dtype, model_config.dtype)
    use_mla = False
    chunk_size = 256
    num_layer = model_config.get_num_layers(parallel_config)
    num_kv_head = model_config.get_num_kv_heads(parallel_config)
    head_size = model_config.get_head_size()
kv_shape = (num_layer, 1 if use_mla else 2, chunk_size, num_kv_head, head_size)
|
||||
|
||||
# dummy lmcache metadata creation example
|
||||
_ = (
|
||||
model_config.model,
|
||||
parallel_config.world_size,
|
||||
parallel_config.rank,
|
||||
"vllm",
|
||||
kv_dtype,
|
||||
kv_shape,
|
||||
use_mla,
|
||||
)
|
||||
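
    # For a sense of scale (illustrative numbers only, not asserted anywhere):
    # a hypothetical 16-layer model with 8 KV heads and head size 64 would give
    #     kv_shape = (16, 2, 256, 8, 64)
    # i.e. (layers, K and V, tokens per chunk, kv heads, head dim).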


@pytest.mark.skipif(
    current_platform.is_rocm(), reason="Requires libcudart.so, not available on ROCm"
)
def test_request_interface():
    # protect against interface changes
    from types import NoneType

    from vllm.sampling_params import SamplingParams
    from vllm.v1.request import Request

    req = Request(
        request_id="test_request",
        prompt_token_ids=[1, 2, 3],
        sampling_params=SamplingParams(max_tokens=10),
        pooling_params=None,
        eos_token_id=100,
        lora_request=None,
    )
    assumes(req, "mm_features", is_instance_of=(list, NoneType))
    assumes(req, "request_id")
    assumes(req, "priority")
    assumes(req, "prompt_token_ids")
    assumes(req, "sampling_params")
    assumes(req, "num_tokens")
    assumes(req, "kv_transfer_params", is_instance_of=(dict, NoneType))

    from vllm.multimodal.inputs import MultiModalFeatureSpec

    assumes(MultiModalFeatureSpec, "identifier")
    assumes(MultiModalFeatureSpec, "mm_position")


def test_new_request_interface():
    # protect against interface changes
    from vllm.v1.core.sched.output import NewRequestData

    assumes(NewRequestData, "req_id")
    assumes(NewRequestData, "block_ids")
    assumes(NewRequestData, "prompt_token_ids")
    assumes(NewRequestData, "sampling_params")


def test_sampling_params_interface():
    # protect against interface changes
    from vllm.sampling_params import SamplingParams

    assumes(SamplingParams, "extra_args")

    # dummy example use case in LMCache
    kv_transfer_params = {
        "lmcache.tag.user": "example_user_1",
        "lmcache.ttl": 60,
    }
    sampling_params = SamplingParams(
        extra_args={"kv_transfer_params": kv_transfer_params}
    )
    assert sampling_params.extra_args["kv_transfer_params"] == kv_transfer_params


def test_tp_interface():
    # protect against interface changes
    import inspect

    from vllm.distributed.parallel_state import get_tp_group

    sig = inspect.signature(get_tp_group)
    GroupCoordinator = sig.return_annotation

    assumes(GroupCoordinator, "broadcast", is_callable=True)
    assumes(GroupCoordinator, "broadcast_object", is_callable=True)


def test_forward_context_interface():
    # protect against interface changes
    from vllm.forward_context import ForwardContext

    assumes(ForwardContext, "no_compile_layers", is_instance_of=dict)
    assumes(ForwardContext, "virtual_engine")
    assumes(ForwardContext, "attn_metadata")


def test_scheduler_output_interface():
    # protect against interface changes
    from vllm.v1.core.sched.output import SchedulerOutput

    assumes(SchedulerOutput, "finished_req_ids")
    assumes(SchedulerOutput, "scheduled_new_reqs", is_instance_of=list)
    assumes(SchedulerOutput, "num_scheduled_tokens", is_instance_of=dict)
    assumes(SchedulerOutput, "scheduled_cached_reqs")

    from vllm.v1.core.sched.output import CachedRequestData

    assumes(CachedRequestData, "req_ids", is_instance_of=list)
    assumes(CachedRequestData, "new_block_ids", is_instance_of=list)
603
tests/v1/kv_connector/unit/test_multi_connector.py
Normal file
@@ -0,0 +1,603 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import filecmp
import shutil
import tempfile
from pathlib import Path
from typing import Any

import pytest

from vllm import LLM, SamplingParams
from vllm.config import KVTransferConfig
from vllm.distributed.kv_transfer.kv_connector.factory import KVConnectorFactory
from vllm.distributed.kv_transfer.kv_connector.v1.base import KVConnectorBase_V1
from vllm.distributed.kv_transfer.kv_connector.v1.metrics import KVConnectorStats
from vllm.distributed.kv_transfer.kv_connector.v1.multi_connector import (
    MultiConnector,
    MultiKVConnectorStats,
)
from vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector import (
    NixlKVConnectorStats,
)
from vllm.platforms import current_platform

MODEL_NAME = "meta-llama/Llama-3.2-1B-Instruct"

PROMPT_CONTEXT = "Hi " * 100
PROMPTS = [
    PROMPT_CONTEXT + "Hello, my name is",
    PROMPT_CONTEXT + "The capital of France is",
]

SAMPLING_PARAMS = SamplingParams(temperature=0, max_tokens=20)


# Test connector with custom stats for testing MultiConnector
class MockConnectorStats(KVConnectorStats):
    """Mock stats class for testing."""

    pass


class MockConnector(KVConnectorBase_V1):
    """Mock connector that implements build_kv_connector_stats for testing."""

    @classmethod
    def build_kv_connector_stats(
        cls, data: dict[str, Any] | None = None
    ) -> KVConnectorStats | None:
        return MockConnectorStats(data=data) if data is not None else None


# Register the mock connector
KVConnectorFactory.register_connector("MockConnector", __name__, MockConnector.__name__)
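

# A sketch of the stats round-trip the tests below pin down (all names are
# defined above): MultiConnector looks each key up in the KVConnectorFactory
# registry and delegates to that connector's build_kv_connector_stats
# classmethod, which is why MockConnector is registered before the tests run.
#
#     serialized = {"MockConnector": {"data": {"mock_field": [1]}}}
#     stats = MultiConnector.build_kv_connector_stats(data=serialized)
#     assert isinstance(stats.data["MockConnector"], MockConnectorStats)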


# Helper function to compare directories recursively
def _compare_directories(dir1: Path, dir2: Path) -> bool:
    """Compares two directories recursively for identical content."""
    dcmp = filecmp.dircmp(dir1, dir2)
    if dcmp.left_only or dcmp.right_only or dcmp.diff_files:
        print(f"Differences found between {dir1} and {dir2}:")
        print(f"  Left only: {dcmp.left_only}")
        print(f"  Right only: {dcmp.right_only}")
        print(f"  Different files: {dcmp.diff_files}")
        return False
    for sub_dir in dcmp.common_dirs:
        if not _compare_directories(dir1 / sub_dir, dir2 / sub_dir):
            return False
    return True


@pytest.mark.skipif(
    current_platform.is_rocm(),
    reason=(
        "hipErrorLaunchFailure when running this test, see issue:"
        "https://github.com/ROCm/pytorch/issues/2822"
    ),
)
def test_multi_example_connector_consistency():
    """
    Tests that MultiConnector with two ExampleConnectors saves
    identical KV cache data to separate storage locations.
    """
    storage_1_path = Path("storage_1/")
    storage_2_path = Path("storage_2/")
    shutil.rmtree(storage_1_path, ignore_errors=True)
    shutil.rmtree(storage_2_path, ignore_errors=True)
    storage_1_path.mkdir()
    storage_2_path.mkdir()

    # Configure MultiConnector with two ExampleConnectors
    kv_transfer_config = KVTransferConfig(
        kv_connector="MultiConnector",
        kv_role="kv_both",
        kv_connector_extra_config={
            "connectors": [
                {
                    "kv_connector": "TestExampleConnector",
                    "kv_role": "kv_both",
                    "kv_connector_extra_config": {
                        "shared_storage_path": str(storage_1_path),
                        "name": "storage1",
                    },
                    "kv_connector_module_path": "tests.v1.kv_connector.unit.utils",
                },
                {
                    "kv_connector": "TestExampleConnector",
                    "kv_role": "kv_both",
                    "kv_connector_extra_config": {
                        "shared_storage_path": str(storage_2_path),
                        "name": "storage2",
                    },
                    "kv_connector_module_path": "tests.v1.kv_connector.unit.utils",
                },
            ]
        },
    )

    llm = LLM(
        model=MODEL_NAME,
        enforce_eager=True,
        gpu_memory_utilization=0.5,
        kv_transfer_config=kv_transfer_config,
    )
    # Run generation - this should trigger saving KV cache
    _ = llm.generate(PROMPTS, SAMPLING_PARAMS)

    # --- Verification ---

    # Check that both storage directories were populated
    local_subdirs = list(storage_1_path.iterdir())
    external_subdirs = list(storage_2_path.iterdir())

    assert len(local_subdirs) > 0, (
        f"Local storage path {storage_1_path} is empty after generation."
    )
    assert len(external_subdirs) > 0, (
        f"External storage path {storage_2_path} is empty after generation."
    )
    assert len(local_subdirs) == len(external_subdirs), (
        f"Mismatch in number of cache entries: "
        f"Local={len(local_subdirs)}, External={len(external_subdirs)}"
    )

    # The subdirectories should correspond to the prompt hashes
    # Since prompts are the same, the hash directories should be the same name
    local_subdir_names = sorted([d.name for d in local_subdirs])
    external_subdir_names = sorted([d.name for d in external_subdirs])
    assert local_subdir_names == external_subdir_names, (
        "Cache directory names do not match between local and external storage"
    )

    # Compare the contents of each corresponding cache directory
    for subdir_name in local_subdir_names:
        print(f"Comparing contents of cache directory: {subdir_name}")
        assert _compare_directories(
            storage_1_path / subdir_name, storage_2_path / subdir_name
        ), (
            f"Contents differ for cache directory '{subdir_name}' between "
            f"{storage_1_path} and {storage_2_path}"
        )

    events = get_connector_events()
    # get_num_new_matched_tokens and update_state_after_alloc will be called
    # on each connector in turn.
    assert events["storage1-SCHEDULER"][:3] == [
        "get_num_new_matched_tokens 0",
        "update_state_after_alloc num_blocks=[0] 0",
        "build_connector_meta",
    ]
    assert events["storage1-WORKER"][:5] == [
        "register_kv_caches",
        "bind_connector_metadata",
        "start_load_kv",
        "wait_for_layer_load",
        "save_kv_layer",
    ]
    assert events["storage2-SCHEDULER"][:3] == [
        "get_num_new_matched_tokens 0",
        "update_state_after_alloc num_blocks=[0] 0",
        "build_connector_meta",
    ]
    assert events["storage2-WORKER"][:5] == [
        "register_kv_caches",
        "bind_connector_metadata",
        "start_load_kv",
        "wait_for_layer_load",
        "save_kv_layer",
    ]

    # Reset prefix cache or else we'll just get the tokens back from there.
    llm.reset_prefix_cache()

    # Run generation again - this should trigger loading from the first
    # connector.
    _ = llm.generate(PROMPTS, SAMPLING_PARAMS)

    events = get_connector_events()
    # get_num_new_matched_tokens will return new tokens from the first
    # connector, so update_state_after_alloc will be called with allocated
    # blocks on that one but with zero blocks for the others (the first
    # nonzero match is chosen).
    assert events["storage1-SCHEDULER"][:3] == [
        "get_num_new_matched_tokens 0",
        "update_state_after_alloc num_blocks=[7] 96",
        "build_connector_meta",
    ]
    assert events["storage2-SCHEDULER"][:3] == [
        "get_num_new_matched_tokens 0",
        "update_state_after_alloc num_blocks=[0] 0",
        "build_connector_meta",
    ]

    # Delete storage1 connector state
    shutil.rmtree(storage_1_path)

    # Reset prefix cache or else we'll just get the tokens back from there.
    llm.reset_prefix_cache()

    # Run generation again - this time loading should come from the second
    # connector.
    _ = llm.generate(PROMPTS, SAMPLING_PARAMS)

    events = get_connector_events()
    # get_num_new_matched_tokens will be called for both connectors; it
    # returns 0 from the first connector, but the second connector should have
    # a hit, so update_state_after_alloc will only be called with allocated
    # blocks for the second connector.
    assert events["storage1-SCHEDULER"][:3] == [
        "get_num_new_matched_tokens 0",
        "update_state_after_alloc num_blocks=[0] 0",
        "build_connector_meta",
    ]
    assert events["storage2-SCHEDULER"][:3] == [
        "get_num_new_matched_tokens 0",
        "update_state_after_alloc num_blocks=[7] 96",
        "build_connector_meta",
    ]

    # Clean up
    shutil.rmtree(storage_1_path)
    shutil.rmtree(storage_2_path)


def get_connector_events() -> dict[str, list[str]]:
    # Read in connector events and reset the files.
    import glob

    event_files = glob.glob(tempfile.gettempdir() + "/connector_*_events.log")
    connector_events = {}
    for fname in event_files:
        name = fname.split("connector_")[1].split("_events.log")[0]
        try:
            with open(fname, "r+") as f:
                connector_events[name] = [line.strip() for line in f if line.strip()]
                f.truncate(0)
        except Exception as e:
            print(f"[ERROR] Could not read connector events for {name}: {e}")

    return connector_events
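

# For reference, the TestExampleConnector test double (from
# tests.v1.kv_connector.unit.utils) appends one line per connector call to
# files matching the glob above, e.g. on a typical Linux host:
#
#     /tmp/connector_storage1-SCHEDULER_events.log
#     /tmp/connector_storage1-WORKER_events.log
#
# so the dict returned here maps keys such as "storage1-SCHEDULER" to lists
# like ["get_num_new_matched_tokens 0", ...].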


def test_engine_id_conflict():
    configs = [KVTransferConfig() for _ in range(2)]
    ids = [config.engine_id for config in configs]
    assert ids[0] != ids[1], (
        f"Engine IDs should be different for different configs. Got {ids}"
    )


class TestMultiConnectorStats:
    """Tests for MultiConnector stats reconstruction and operations."""

    def test_build_kv_connector_stats_with_none(self):
        """Test that build_kv_connector_stats returns empty stats when given None."""
        stats = MultiConnector.build_kv_connector_stats(data=None)

        assert stats is not None
        assert isinstance(stats, MultiKVConnectorStats)
        assert len(stats.data) == 0
        assert stats.is_empty()

    def test_build_kv_connector_stats_with_empty_dict(self):
        """Test that build_kv_connector_stats returns empty stats with empty dict."""
        stats = MultiConnector.build_kv_connector_stats(data={})

        assert stats is not None
        assert isinstance(stats, MultiKVConnectorStats)
        assert len(stats.data) == 0
        assert stats.is_empty()

    def test_build_kv_connector_stats_reconstructs_nixl_stats(self):
        """Test that NixlConnector stats are properly reconstructed with
        correct data."""
        serialized_data = {
            "NixlConnector": {
                "data": {
                    "transfer_duration": [1.5, 2.3],
                    "post_duration": [0.1, 0.2],
                    "bytes_transferred": [1024, 2048],
                    "num_descriptors": [10, 20],
                    "num_failed_transfers": [],
                    "num_failed_notifications": [],
                }
            }
        }

        stats = MultiConnector.build_kv_connector_stats(data=serialized_data)

        assert "NixlConnector" in stats.data
        nixl_stats = stats.data["NixlConnector"]
        assert isinstance(nixl_stats, NixlKVConnectorStats)
        assert nixl_stats.data["transfer_duration"] == [1.5, 2.3]
        assert nixl_stats.data["post_duration"] == [0.1, 0.2]
        assert nixl_stats.data["bytes_transferred"] == [1024, 2048]
        assert nixl_stats.data["num_descriptors"] == [10, 20]

    def test_build_kv_connector_stats_with_multiple_connectors(self):
        """Test reconstruction with multiple connector types that have custom stats."""
        serialized_data = {
            "NixlConnector": {
                "data": {
                    "transfer_duration": [1.5],
                    "post_duration": [0.1],
                    "bytes_transferred": [1024],
                    "num_descriptors": [10],
                    "num_failed_transfers": [],
                    "num_failed_notifications": [],
                }
            },
            "MockConnector": {"data": {"mock_field": [1, 2, 3]}},
        }

        stats = MultiConnector.build_kv_connector_stats(data=serialized_data)

        assert stats is not None
        assert isinstance(stats, MultiKVConnectorStats)
        # Both connectors should be reconstructed
        assert len(stats.data) == 2
        assert "NixlConnector" in stats.data
        assert "MockConnector" in stats.data
        assert isinstance(stats.data["NixlConnector"], NixlKVConnectorStats)
        assert isinstance(stats.data["MockConnector"], MockConnectorStats)
        # Verify data is preserved
        assert stats.data["MockConnector"].data == {"mock_field": [1, 2, 3]}

    def test_build_kv_connector_stats_raises_error_for_unknown_connector(self):
        """Test that unknown connectors raise an error."""
        serialized_data = {
            "UnknownConnector": {"data": {"some_field": [1, 2, 3]}},
            "NixlConnector": {
                "data": {
                    "transfer_duration": [1.5],
                    "post_duration": [0.1],
                    "bytes_transferred": [1024],
                    "num_descriptors": [10],
                    "num_failed_transfers": [],
                    "num_failed_notifications": [],
                }
            },
        }

        with pytest.raises(
            ValueError, match="Connector 'UnknownConnector' is not registered."
        ):
            MultiConnector.build_kv_connector_stats(data=serialized_data)

    def test_build_kv_connector_stats_with_already_instantiated_objects(self):
        """Test that already-instantiated stats objects are preserved (same process)."""
        # This simulates the in-process case where stats are not serialized
        nixl_stats = NixlKVConnectorStats(
            data={
                "transfer_duration": [1.5],
                "post_duration": [0.1],
                "bytes_transferred": [1024],
                "num_descriptors": [10],
                "num_failed_transfers": [],
                "num_failed_notifications": [],
            }
        )
        mock_stats = MockConnectorStats(data={"mock_field": [1, 2, 3]})

        data_with_objects = {
            "NixlConnector": nixl_stats,
            "MockConnector": mock_stats,
        }

        stats = MultiConnector.build_kv_connector_stats(data=data_with_objects)

        assert stats is not None
        assert isinstance(stats, MultiKVConnectorStats)
        assert len(stats.data) == 2
        # Verify objects are preserved as-is
        assert stats.data["NixlConnector"] is nixl_stats
        assert stats.data["MockConnector"] is mock_stats

    def test_build_kv_connector_stats_with_mixed_objects_and_dicts(self):
        """Test handling mixed already-instantiated and serialized stats."""
        # This can happen during transition or partial serialization
        nixl_stats = NixlKVConnectorStats(
            data={
                "transfer_duration": [1.5],
                "post_duration": [0.1],
                "bytes_transferred": [1024],
                "num_descriptors": [10],
                "num_failed_transfers": [],
                "num_failed_notifications": [],
            }
        )

        mixed_data = {
            "NixlConnector": nixl_stats,  # Already instantiated
            "MockConnector": {"data": {"mock_field": [1, 2, 3]}},  # Serialized
        }

        stats = MultiConnector.build_kv_connector_stats(data=mixed_data)

        assert stats is not None
        assert isinstance(stats, MultiKVConnectorStats)
        assert len(stats.data) == 2
        # Instantiated object preserved
        assert stats.data["NixlConnector"] is nixl_stats
        # Serialized object reconstructed
        assert isinstance(stats.data["MockConnector"], MockConnectorStats)
        assert stats.data["MockConnector"].data == {"mock_field": [1, 2, 3]}

    def test_build_kv_connector_stats_skips_connectors_without_custom_stats(self):
        """Test that connectors without custom stats (return None) are skipped."""
        # ExampleConnector doesn't override build_kv_connector_stats,
        # so it returns None and should be skipped
        serialized_data = {
            "NixlConnector": {
                "data": {
                    "transfer_duration": [1.5],
                    "post_duration": [0.1],
                    "bytes_transferred": [1024],
                    "num_descriptors": [10],
                    "num_failed_transfers": [],
                    "num_failed_notifications": [],
                }
            },
            "ExampleConnector": {"data": {"some_field": [1, 2, 3]}},
        }

        stats = MultiConnector.build_kv_connector_stats(data=serialized_data)

        assert stats is not None
        assert isinstance(stats, MultiKVConnectorStats)
        # Only NixlConnector should be reconstructed
        assert len(stats.data) == 1
        assert "NixlConnector" in stats.data
        assert isinstance(stats.data["NixlConnector"], NixlKVConnectorStats)
        # ExampleConnector should be skipped (returns None)
        assert "ExampleConnector" not in stats.data

    def test_build_kv_connector_stats_handles_malformed_data(self):
        """Test that malformed data raises appropriate errors."""
        serialized_data = {
            "NixlConnector": {"wrong_field": {"transfer_duration": [1.5]}}
        }

        with pytest.raises(AssertionError, match="Expected a dict with a 'data' field"):
            MultiConnector.build_kv_connector_stats(data=serialized_data)

    def test_aggregate_same_connector(self):
        """Test aggregating stats from the same connector type."""
        stats1 = MultiKVConnectorStats(
            data={
                "NixlConnector": NixlKVConnectorStats(
                    data={
                        "transfer_duration": [1.0],
                        "post_duration": [0.1],
                        "bytes_transferred": [1024],
                        "num_descriptors": [10],
                        "num_failed_transfers": [],
                        "num_failed_notifications": [],
                    }
                )
            }
        )

        stats2 = MultiKVConnectorStats(
            data={
                "NixlConnector": NixlKVConnectorStats(
                    data={
                        "transfer_duration": [2.0],
                        "post_duration": [0.2],
                        "bytes_transferred": [2048],
                        "num_descriptors": [20],
                        "num_failed_transfers": [],
                        "num_failed_notifications": [],
                    }
                )
            }
        )

        result = stats1.aggregate(stats2)

        assert result is stats1  # Should return self
        assert "NixlConnector" in result.data
        nixl_stats = result.data["NixlConnector"]
        assert nixl_stats.data["transfer_duration"] == [1.0, 2.0]
        assert nixl_stats.data["post_duration"] == [0.1, 0.2]
        assert nixl_stats.data["bytes_transferred"] == [1024, 2048]
        assert nixl_stats.data["num_descriptors"] == [10, 20]

    def test_aggregate_new_connector(self):
        """Test aggregating stats when a new connector type appears."""
        from vllm.distributed.kv_transfer.kv_connector.v1.metrics import (
            KVConnectorStats,
        )

        stats1 = MultiKVConnectorStats(
            data={
                "NixlConnector": NixlKVConnectorStats(
                    data={
                        "transfer_duration": [1.0],
                        "post_duration": [0.1],
                        "bytes_transferred": [1024],
                        "num_descriptors": [10],
                        "num_failed_transfers": [],
                        "num_failed_notifications": [],
                    }
                )
            }
        )

        stats2 = MultiKVConnectorStats(
            data={"ExampleConnector": KVConnectorStats(data={"field": [1, 2]})}
        )

        result = stats1.aggregate(stats2)

        assert "NixlConnector" in result.data
        assert "ExampleConnector" in result.data

    def test_reduce(self):
        """Test that reduce() correctly reduces all nested connector stats."""
        stats = MultiKVConnectorStats(
            data={
                "NixlConnector": NixlKVConnectorStats(
                    data={
                        "transfer_duration": [1.0, 2.0],
                        "post_duration": [0.1, 0.2],
                        "bytes_transferred": [1024, 2048],
                        "num_descriptors": [10, 20],
                        "num_failed_transfers": [],
                        "num_failed_notifications": [],
                    }
                )
            }
        )

        reduced = stats.reduce()

        assert "NixlConnector" in reduced
        assert isinstance(reduced["NixlConnector"], dict)
        # Check that the stats were reduced (should have aggregated values)
        assert "Num successful transfers" in reduced["NixlConnector"]
        assert reduced["NixlConnector"]["Num successful transfers"] == 2

    def test_reset(self):
        """Test that reset() resets all nested connector stats."""
        stats = MultiKVConnectorStats(
            data={
                "NixlConnector": NixlKVConnectorStats(
                    data={
                        "transfer_duration": [1.0, 2.0],
                        "post_duration": [0.1, 0.2],
                        "bytes_transferred": [1024, 2048],
                        "num_descriptors": [10, 20],
                        "num_failed_transfers": [],
                        "num_failed_notifications": [],
                    }
                )
            }
        )

        assert not stats.is_empty()

        stats.reset()

        # After reset, stats should be empty
        assert stats.is_empty()
        nixl_stats = stats.data["NixlConnector"]
        assert len(nixl_stats.data["transfer_duration"]) == 0

    def test_is_empty_with_multiple_connectors(self):
        """Test is_empty() returns correct value with multiple connectors."""
        # All empty
        stats = MultiKVConnectorStats(
            data={
                "NixlConnector": NixlKVConnectorStats(data={}),
            }
        )
        # Initialize empty stats
        stats.data["NixlConnector"].reset()
        assert stats.is_empty()

        # One non-empty
        stats.data["NixlConnector"].data["transfer_duration"].append(1.0)
        assert not stats.is_empty()
1791
tests/v1/kv_connector/unit/test_nixl_connector.py
Normal file
File diff suppressed because it is too large
534
tests/v1/kv_connector/unit/test_offloading_connector.py
Normal file
@@ -0,0 +1,534 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import copy
from collections.abc import Iterable, Iterator
from dataclasses import dataclass
from typing import Any
from unittest.mock import MagicMock

import pytest
import torch

from vllm import SamplingParams
from vllm.config import KVTransferConfig, VllmConfig
from vllm.distributed.kv_events import BlockRemoved, BlockStored
from vllm.distributed.kv_transfer.kv_connector.v1 import KVConnectorRole
from vllm.distributed.kv_transfer.kv_connector.v1.offloading_connector import (
    OffloadingConnector,
    OffloadingConnectorMetadata,
)
from vllm.forward_context import ForwardContext
from vllm.utils.hashing import sha256
from vllm.v1.attention.backends.flash_attn import FlashAttentionBackend
from vllm.v1.core.kv_cache_utils import (
    BlockHash,
    get_request_block_hasher,
    init_none_hash,
)
from vllm.v1.core.sched.scheduler import Scheduler
from vllm.v1.kv_offload.abstract import (
    LoadStoreSpec,
    OffloadingEvent,
    OffloadingManager,
    PrepareStoreOutput,
)
from vllm.v1.kv_offload.mediums import GPULoadStoreSpec
from vllm.v1.kv_offload.spec import OffloadingSpec
from vllm.v1.kv_offload.worker.worker import (
    OffloadingHandler,
    TransferResult,
    TransferSpec,
)
from vllm.v1.outputs import EMPTY_MODEL_RUNNER_OUTPUT, KVConnectorOutput
from vllm.v1.request import Request

from .utils import (
    EOS_TOKEN_ID,
    create_model_runner_output,
    create_scheduler,
    create_vllm_config,
)


class MockLoadStoreSpec(LoadStoreSpec):
    def __init__(self, block_hashes: Iterable[BlockHash]):
        self.block_hashes: list[BlockHash] = list(block_hashes)

    @staticmethod
    def medium() -> str:
        return "Mock"

    def __repr__(self) -> str:
        return repr(self.block_hashes)


class MockOffloadingHandler(OffloadingHandler):
    def __init__(self):
        self.completed_transfers: list[TransferResult] = []
        self.completed_specs: list[TransferSpec] = []

    def get_finished(self) -> list[TransferResult]:
        finished = self.completed_transfers
        self.completed_transfers = []
        return finished

    def transfer_async(self, job_id: int, spec: TransferSpec) -> bool:
        self.completed_specs.append(spec)
        self.completed_transfers.append((job_id, True))
        return True


class MockOffloadingSpec(OffloadingSpec):
    def __init__(self, vllm_config: VllmConfig):
        super().__init__(vllm_config)

        self.manager = MagicMock(spec=OffloadingManager)
        self.manager.lookup.return_value = 0
        self.manager.prepare_load = lambda block_hashes: (
            MockLoadStoreSpec(block_hashes)
        )
        self.handler = MockOffloadingHandler()

    def get_manager(self) -> OffloadingManager:
        return self.manager

    def get_handlers(
        self, _, __
    ) -> Iterator[tuple[type[LoadStoreSpec], type[LoadStoreSpec], OffloadingHandler]]:
        yield GPULoadStoreSpec, MockLoadStoreSpec, self.handler
        yield MockLoadStoreSpec, GPULoadStoreSpec, self.handler

    def get_completed_transfers(self) -> list[TransferSpec]:
        specs = self.handler.completed_specs
        self.handler.completed_specs = []
        return specs


@dataclass
class TransferSummary:
    gpu_block_indices: list[int]
    offload_addresses: list[Any]


class RequestRunner:
    def __init__(
        self, offloaded_block_size: int, gpu_block_size: int, num_gpu_blocks: int
    ):
        self.offloaded_block_size: int = offloaded_block_size
        self.gpu_block_size: int = gpu_block_size
        self.num_gpu_blocks: int = num_gpu_blocks

        self.req_id: int = -1

        vllm_config = create_vllm_config(
            block_size=gpu_block_size, max_num_batched_tokens=1000
        )
        vllm_config.kv_transfer_config = KVTransferConfig(
            kv_connector="OffloadingConnector",
            kv_role="kv_both",
            kv_connector_extra_config={
                "spec_name": "MockOffloadingSpec",
                "spec_module_path": "tests.v1.kv_connector.unit.test_offloading_connector",  # noqa: E501
                "block_size": offloaded_block_size,
            },
        )

        self.scheduler: Scheduler = create_scheduler(
            vllm_config, num_blocks=num_gpu_blocks
        )
        self.worker_connector = OffloadingConnector(vllm_config, KVConnectorRole.WORKER)

        # register worker kv_caches to enable OffloadingWorker creations
        self.worker_connector.register_cross_layers_kv_cache(
            kv_cache=torch.empty(0),
            attn_backend=FlashAttentionBackend,
        )

        # extract connector of scheduler
        scheduler_connector = self.scheduler.connector
        assert scheduler_connector is not None
        assert isinstance(scheduler_connector, OffloadingConnector)
        self.scheduler_connector: OffloadingConnector = scheduler_connector

        # extract mocked OffloadingManager of scheduler connector
        connector_scheduler = scheduler_connector.connector_scheduler
        assert connector_scheduler is not None
        manager = connector_scheduler.manager
        assert isinstance(manager, MagicMock)
        self.manager: MagicMock = manager

        assert connector_scheduler.gpu_block_size == gpu_block_size
        assert connector_scheduler.offloaded_block_size == offloaded_block_size

        # extract OffloadingSpec of worker_connector
        connector_worker = self.worker_connector.connector_worker
        assert connector_worker is not None
        offloading_spec = connector_worker.spec
        assert isinstance(offloading_spec, MockOffloadingSpec)
        self.offloading_spec: MockOffloadingSpec = offloading_spec

        # mapping (offloading address) -> gpu_block_index
        self.offloaded: dict[Any, int] = {}

        self.pending_loads_count: int = 0
        self.pending_stores_count: int = 0

        self.completed_loads: list[TransferSummary] = []
        self.completed_stores: list[TransferSummary] = []

        # maps {block_id: block_offset}
        self.gpu_block_index: dict[int, int] = {}

        init_none_hash(sha256)
        self._block_hasher = get_request_block_hasher(gpu_block_size, sha256)

        self._dummy_ctx: ForwardContext = ForwardContext(
            no_compile_layers={}, attn_metadata={}, virtual_engine=0
        )

    def new_request(self, token_ids: list[int]):
        assert not self.scheduler.requests
        self.req_id += 1

        req = Request(
            request_id=str(self.req_id),
            prompt_token_ids=token_ids,
            sampling_params=SamplingParams(max_tokens=1000),
            pooling_params=None,
            eos_token_id=EOS_TOKEN_ID,
            block_hasher=self._block_hasher,
        )

        self.scheduler.add_request(req)

    def _wait_for_transfers(self):
        block_size_factor = self.offloaded_block_size // self.gpu_block_size

        while self.pending_loads_count or self.pending_stores_count:
            for transfer_spec in self.offloading_spec.get_completed_transfers():
                src_spec, dst_spec = transfer_spec

                if isinstance(src_spec, GPULoadStoreSpec):
                    store = True
                    gpu_spec = src_spec
                    offload_spec = dst_spec
                else:
                    store = False
                    gpu_spec = dst_spec
                    offload_spec = src_spec

                assert isinstance(offload_spec, MockLoadStoreSpec)
                assert isinstance(gpu_spec, GPULoadStoreSpec)

                gpu_block_indices: list[int] = []
                for block_id in gpu_spec.block_ids:
                    gpu_block_indices.append(self.gpu_block_index[block_id.item()])

                # list of (block_hash, sub_block_offset)
                offload_addresses: list[Any] = []
                for block_hash in offload_spec.block_hashes:
                    for sub_block_idx in range(block_size_factor):
                        offload_addresses.append((block_hash, sub_block_idx))

                if store:
                    assert len(gpu_block_indices) == len(offload_addresses)

                    self.completed_stores.append(
                        TransferSummary(gpu_block_indices, offload_addresses)
                    )
                    self.pending_stores_count -= 1
                else:
                    remainder_sub_block_count = len(offload_addresses) - len(
                        gpu_block_indices
                    )
                    assert remainder_sub_block_count >= 0
                    assert remainder_sub_block_count < block_size_factor
                    offload_addresses = offload_addresses[remainder_sub_block_count:]

                    self.completed_loads.append(
                        TransferSummary(gpu_block_indices, offload_addresses)
                    )
                    self.pending_loads_count -= 1

    def _update_gpu_block_idx(self):
        for blocks in self.scheduler.kv_cache_manager.coordinator.single_type_managers[
            0
        ].req_to_blocks.values():
            for block_idx, block in enumerate(blocks):
                self.gpu_block_index[block.block_id] = block_idx

    def _run(self, decoded_tokens: list[int]):
        """
        Runs multiple engine (scheduler + worker) steps.
        Assumes a single request is running.

        Args:
            decoded_tokens: the tokens to yield at each step.
        """

        tokens_iter = iter(decoded_tokens)
        token_id = next(tokens_iter, None)
        while token_id is not None:
            assert self.scheduler.requests

            scheduler_output = self.scheduler.schedule()
            self._update_gpu_block_idx()

            kv_connector_metadata = scheduler_output.kv_connector_metadata
            assert kv_connector_metadata is not None
            assert isinstance(kv_connector_metadata, OffloadingConnectorMetadata)

            self.pending_loads_count += len(kv_connector_metadata.reqs_to_load)
            self.pending_stores_count += len(kv_connector_metadata.reqs_to_store)

            self.worker_connector.bind_connector_metadata(kv_connector_metadata)
            self.worker_connector.start_load_kv(self._dummy_ctx)

            if scheduler_output.total_num_scheduled_tokens > 0:
                self.worker_connector.wait_for_save()

            finished_sending, finished_recving = self.worker_connector.get_finished(
                scheduler_output.finished_req_ids
            )

            self.worker_connector.clear_connector_metadata()

            model_runner_output = create_model_runner_output(
                reqs=self.scheduler.running,
                finished_sending=finished_sending,
                finished_recving=finished_recving,
                token_id=token_id,
            )

            if self.scheduler.running:
                token_id = next(tokens_iter, None)

            self.scheduler.update_from_output(scheduler_output, model_runner_output)

            self._wait_for_transfers()

        # run one more step to process the finished stores
        if EOS_TOKEN_ID in decoded_tokens:
            assert not self.scheduler.running

            while self.scheduler.requests:
                scheduler_output = self.scheduler.schedule()

                finished_sending, finished_recving = self.worker_connector.get_finished(
                    scheduler_output.finished_req_ids
                )

                assert not finished_recving

                model_runner_output = copy.deepcopy(EMPTY_MODEL_RUNNER_OUTPUT)
                model_runner_output.kv_connector_output = KVConnectorOutput(
                    finished_sending=finished_sending
                )

                self.scheduler.update_from_output(scheduler_output, model_runner_output)

    def run(
        self,
        decoded_tokens: list[int],
        expected_stored_gpu_block_indexes: tuple[int, ...] = (),
        expected_loaded_gpu_block_indexes: tuple[int, ...] = (),
    ):
        """
        Runs multiple engine (scheduler + worker) steps.
        Assumes a single request is running.

        Args:
            decoded_tokens: the tokens to yield at each step.
            expected_stored_gpu_block_indexes: GPU block indexes
                that are expected to be written during the run.
            expected_loaded_gpu_block_indexes: GPU block indexes
                that are expected to be loaded during the run.
        """

        self.manager.reset_mock()
        self._run(decoded_tokens)

        loaded_gpu_block_indexes: set[int] = set()
        for transfer in self.completed_loads:
            for gpu_block_idx, offloaded_address in zip(
                transfer.gpu_block_indices, transfer.offload_addresses
            ):
                loaded_gpu_block_indexes.add(gpu_block_idx)
                assert gpu_block_idx == self.offloaded[offloaded_address]

        assert set(expected_loaded_gpu_block_indexes) == loaded_gpu_block_indexes
        self.completed_loads.clear()

        stored_gpu_block_indexes: set[int] = set()
        for transfer in self.completed_stores:
            for gpu_block_idx, offloaded_address in zip(
                transfer.gpu_block_indices, transfer.offload_addresses
            ):
                stored_gpu_block_indexes.add(gpu_block_idx)
                self.offloaded[offloaded_address] = gpu_block_idx

        assert set(expected_stored_gpu_block_indexes) == stored_gpu_block_indexes
        self.completed_stores.clear()
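

# The step protocol exercised by RequestRunner._run above mirrors one engine
# iteration: schedule() -> bind_connector_metadata() -> start_load_kv() ->
# wait_for_save() -> get_finished() -> clear_connector_metadata() ->
# update_from_output(), with _wait_for_transfers() draining the mock handler.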


@pytest.fixture
def request_runner():
    runners = []

    def runner_factory(offloaded_block_size, gpu_block_size, num_gpu_blocks):
        runner = RequestRunner(
            offloaded_block_size=offloaded_block_size,
            gpu_block_size=gpu_block_size,
            num_gpu_blocks=num_gpu_blocks,
        )
        runners.append(runner)
        return runner

    yield runner_factory  # pass factory to the test


def generate_store_output(block_hashes: Iterable[BlockHash]):
    block_hashes = list(block_hashes)
    return PrepareStoreOutput(
        block_hashes_to_store=list(block_hashes),
        store_spec=MockLoadStoreSpec(block_hashes),
        block_hashes_evicted=[],
    )
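

# Block-size arithmetic used by the test below: with offloaded_block_size=12
# and gpu_block_size=4, block_size_factor is 3, so offloaded block k covers
# GPU blocks 3k, 3k+1, 3k+2. Storing only the middle offloaded block of a
# 3-block prompt therefore touches GPU block indexes (3, 4, 5).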


def test_offloading_connector(request_runner):
    offloaded_block_size = 12
    gpu_block_size = 4
    num_gpu_blocks = 100
    block_size_factor = offloaded_block_size // gpu_block_size

    runner = request_runner(
        offloaded_block_size=offloaded_block_size,
        gpu_block_size=gpu_block_size,
        num_gpu_blocks=num_gpu_blocks,
    )

    # 3 blocks, store just the middle block (skip first and last)
    # blocks = [0, 1, 2], [3, 4, 5], [6, 7, 8]
    runner.new_request(token_ids=[0] * offloaded_block_size * 3)
    runner.manager.prepare_store.side_effect = (
        lambda block_hashes: generate_store_output(list(block_hashes)[1:2])
    )
    runner.run(decoded_tokens=[0], expected_stored_gpu_block_indexes=(3, 4, 5))

    # add block missing 1 token -> no offload
    runner.run(decoded_tokens=[0] * (offloaded_block_size - 1))
    runner.manager.prepare_store.assert_not_called()

    # +1 token -> single block, fail prepare_store
    runner.manager.prepare_store.side_effect = lambda block_hashes: None
    runner.run(decoded_tokens=[0])
    runner.manager.prepare_store.assert_called()

    # 1 more block, now set block_hashes_to_store = []
    runner.manager.prepare_store.side_effect = (
        lambda block_hashes: generate_store_output([])
    )
    runner.run(decoded_tokens=[0] * offloaded_block_size)

    # 1 more block, now check touch was called with all 6 blocks
    runner.manager.prepare_store.side_effect = (
        lambda block_hashes: generate_store_output(block_hashes)
    )
    runner.run(
        decoded_tokens=[0] * offloaded_block_size,
        expected_stored_gpu_block_indexes=(15, 16, 17),
    )
    runner.manager.touch.assert_called()
    block_hashes1 = list(runner.manager.touch.call_args.args[0])
    assert len(block_hashes1) == 6

    # terminate request
    runner.run(decoded_tokens=[EOS_TOKEN_ID])

    # create a new request differing only on the last token
    runner.new_request(token_ids=[0] * (offloaded_block_size * 6 - 1) + [1])
    runner.run(
        decoded_tokens=[0],
        expected_stored_gpu_block_indexes=tuple(range(6 * block_size_factor)),
    )
    runner.manager.touch.assert_called()
    block_hashes2 = list(runner.manager.touch.call_args.args[0])
    assert len(block_hashes2) == 6

    # verify hashes are the same, except for the last block
    assert block_hashes1[:5] == block_hashes2[:5]
    assert block_hashes1[5] != block_hashes2[5]

    # terminate request
    runner.run(decoded_tokens=[EOS_TOKEN_ID])

    # full_block_tokens - num_computed_tokens < offloaded_block_size
    runner.new_request(
        token_ids=[0] * gpu_block_size + [1] * (offloaded_block_size - gpu_block_size)
    )
    runner.manager.prepare_store.side_effect = (
        lambda block_hashes: generate_store_output([])
    )
    runner.run(decoded_tokens=[EOS_TOKEN_ID])
    runner.manager.lookup.assert_not_called()

    # single block lookup with no hits
    runner.new_request(token_ids=[1] * offloaded_block_size)
    runner.manager.prepare_store.side_effect = (
        lambda block_hashes: generate_store_output([])
    )
    runner.run(decoded_tokens=[EOS_TOKEN_ID])
    runner.manager.lookup.assert_called()
    assert len(list(runner.manager.lookup.call_args.args[0])) == 1

    # single block lookup with a hit
    runner.scheduler.reset_prefix_cache()
    runner.new_request(token_ids=[0] * offloaded_block_size)
    runner.manager.prepare_store.side_effect = (
        lambda block_hashes: generate_store_output([])
    )
    runner.manager.lookup.return_value = 1
    runner.run(
        decoded_tokens=[EOS_TOKEN_ID], expected_loaded_gpu_block_indexes=(0, 1, 2)
    )

    # single block lookup with a hit in a middle block
    runner.new_request(
        token_ids=[0] * offloaded_block_size * 2 + [1] * offloaded_block_size
    )
    runner.manager.prepare_store.side_effect = (
        lambda block_hashes: generate_store_output([])
    )
    runner.manager.lookup.return_value = 1
    runner.run(
        decoded_tokens=[EOS_TOKEN_ID], expected_loaded_gpu_block_indexes=(3, 4, 5)
    )

    # test take_events
    def to_hashes(int_hashes: list[int]) -> list[BlockHash]:
        return [BlockHash(str(i).encode()) for i in int_hashes]

    def take_events() -> Iterable[OffloadingEvent]:
        yield OffloadingEvent(
            block_hashes=to_hashes([1, 2, 3]), block_size=16, medium="A", removed=False
        )
        yield OffloadingEvent(
            block_hashes=to_hashes([4, 5, 6]), block_size=32, medium="B", removed=True
        )

    runner.manager.take_events.side_effect = take_events
    events = list(runner.scheduler_connector.take_events())
    assert len(events) == 2
    event = events[0]
    assert isinstance(event, BlockStored)
    assert event.block_hashes == to_hashes([1, 2, 3])
    assert event.block_size == 16
    assert event.medium == "A"
    assert event.token_ids == []
    assert event.parent_block_hash is None
    assert event.lora_id is None
    event = events[1]
    assert isinstance(event, BlockRemoved)
    assert event.block_hashes == to_hashes([4, 5, 6])
    assert event.medium == "B"
122
tests/v1/kv_connector/unit/test_output_aggregator.py
Normal file
@@ -0,0 +1,122 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import pytest

from vllm.distributed.kv_transfer.kv_connector.utils import KVOutputAggregator
from vllm.v1.outputs import KVConnectorOutput, ModelRunnerOutput

pytestmark = pytest.mark.cpu_test


class DummyModelRunnerOutput(ModelRunnerOutput):
    def __init__(
        self,
        finished_sending: set[str] | None = None,
        finished_recving: set[str] | None = None,
        invalid_block_ids: set[int] | None = None,
        expected_finished_count: int = 0,
    ):
        self.kv_connector_output = KVConnectorOutput(
            finished_sending=finished_sending,
            finished_recving=finished_recving,
            invalid_block_ids=invalid_block_ids or set(),
            expected_finished_count=expected_finished_count,
        )

    def __repr__(self):
        return (
            f"DummyModelRunnerOutput("
            f"finished_sending={self.kv_connector_output.finished_sending},"
            f"finished_recving={self.kv_connector_output.finished_recving},"
            f"invalid_block_ids={self.kv_connector_output.invalid_block_ids})"
        )
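

# Aggregation semantics demonstrated below: finished_sending/finished_recving
# are counted across workers (a request is surfaced only once the expected
# number of workers has reported it), while invalid_block_ids are unioned,
# e.g. {3, 4} and {4, 5} aggregate to {3, 4, 5}.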


def test_aggregate_workers_output():
    aggregator = KVOutputAggregator(expected_finished_count=2)

    output1 = DummyModelRunnerOutput()
    output2 = DummyModelRunnerOutput()

    aggregated = aggregator.aggregate([output1, output2])

    assert aggregated is output1
    aggregated = aggregated.kv_connector_output
    assert aggregated.finished_sending is None
    assert aggregated.finished_recving is None
    assert not aggregated.invalid_block_ids

    output1 = DummyModelRunnerOutput(
        finished_sending={"req1"}, finished_recving={"req2"}
    )
    output2 = DummyModelRunnerOutput(invalid_block_ids={1})

    aggregated = aggregator.aggregate([output1, output2])

    assert aggregated is output1
    aggregated = aggregated.kv_connector_output
    assert aggregated.finished_sending is None
    assert aggregated.finished_recving is None
    assert aggregated.invalid_block_ids == {1}

    output1 = DummyModelRunnerOutput(invalid_block_ids={2})
    output2 = DummyModelRunnerOutput(finished_sending={"req1"})

    aggregated = aggregator.aggregate([output1, output2])

    assert aggregated is output1
    aggregated = aggregated.kv_connector_output
    assert aggregated.finished_sending == {"req1"}
    assert aggregated.finished_recving is None
    assert aggregated.invalid_block_ids == {2}

    output1 = DummyModelRunnerOutput(invalid_block_ids={3, 4})
    output2 = DummyModelRunnerOutput(
        finished_recving={"req2"}, invalid_block_ids={4, 5}
    )

    aggregated = aggregator.aggregate([output1, output2])

    assert aggregated is output1
    aggregated = aggregated.kv_connector_output
    assert aggregated.finished_sending is None
    assert aggregated.finished_recving == {"req2"}
    assert aggregated.invalid_block_ids == {3, 4, 5}


def test_aggregate_workers_output_with_expected_finished_count():
    # We create the aggregator expecting to collect from 4 workers
    aggregator = KVOutputAggregator(expected_finished_count=4)
    assert aggregator._expected_finished_count == 4
    # A request arrives while the default expected finished count applies
    output1 = DummyModelRunnerOutput(finished_sending={"req1"})
    aggregated = aggregator.aggregate([output1])
    # still expecting to collect from 4 workers
    assert aggregator._send_remaining_count["req1"] == 3
    assert not aggregated.kv_connector_output.finished_sending
    assert not aggregated.kv_connector_output.finished_recving

    # Workers discover that in this setup they only need to collect
    # from 2
    output1 = DummyModelRunnerOutput(
        finished_sending={"req1"}, expected_finished_count=2
    )
    output2 = DummyModelRunnerOutput(
        finished_recving={"req2"}, expected_finished_count=2
    )
    output3 = DummyModelRunnerOutput(finished_recving={"req2"})
    # Req2 only needs 2 acks
    aggregated = aggregator.aggregate([output1, output2, output3])
    assert aggregated.kv_connector_output.expected_finished_count == 2

    assert not aggregated.kv_connector_output.finished_sending

    # Req2 is finished
    assert "req2" not in aggregator._recv_remaining_count
    assert aggregated.kv_connector_output.finished_recving == {"req2"}

    # Req1 is still waiting for 2 more acks (the updated
    # expected_finished_count does not apply retroactively)
    # NOTE: This is to showcase dynamic update. Workers are responsible for
    # ensuring "req1" termination in this case
    assert aggregator._send_remaining_count["req1"] == 2
262
tests/v1/kv_connector/unit/test_remote_decode_lifecycle.py
Normal file
@@ -0,0 +1,262 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import copy

import pytest

from vllm.v1.outputs import EMPTY_MODEL_RUNNER_OUTPUT, KVConnectorOutput
from vllm.v1.request import FinishReason, RequestStatus

from .utils import (
    assert_scheduler_empty,
    create_model_runner_output,
    create_request,
    create_scheduler,
    create_vllm_config,
)

pytestmark = pytest.mark.cpu_test


def test_basic_lifecycle():
    """Test lifecycle of a Remote Decode request."""

    vllm_config = create_vllm_config()
    scheduler = create_scheduler(vllm_config)

    # 2 Full Blocks and 1 Half Block.
    BLOCK_SIZE = vllm_config.cache_config.block_size
    NUM_EXTERNAL_FULL_BLOCKS = 2
    NUM_TOKENS = int(BLOCK_SIZE * (NUM_EXTERNAL_FULL_BLOCKS + 0.5))

    request = create_request(
        request_id=1,
        block_size=BLOCK_SIZE,
        max_tokens=1,
        num_tokens=NUM_TOKENS,
        do_remote_decode=True,
    )

    scheduler.add_request(request)
    request_id = request.request_id

    # STEP (1): Prefill.
    # (1a): schedule()
    scheduler_output = scheduler.schedule()
    assert len(scheduler.requests) == 1
    assert len(scheduler.running) == 1
    assert len(scheduler_output.scheduled_new_reqs) == 1

    # (1b): execute_model()
    model_runner_output = create_model_runner_output(reqs=[request])

    # (1c): update_from_output()
    engine_core_outputs = scheduler.update_from_output(
        scheduler_output, model_runner_output
    )

    # Ensure the request is finished after 1 token.
    assert request.is_finished()
    assert request.status == RequestStatus.FINISHED_LENGTH_CAPPED
    output = engine_core_outputs[0].outputs[0]
    assert output.finish_reason == FinishReason.LENGTH
    assert output.kv_transfer_params is not None

    # Request freed in Scheduler and in Persistent Batch ...
    assert request_id in scheduler.finished_req_ids
    assert len(scheduler.running) == 0
    assert len(scheduler.waiting) == 0

    # ... but blocks should not be freed.
    assert len(scheduler.requests) == 1
    blocks = scheduler.kv_cache_manager.coordinator.single_type_managers[
        0
    ].req_to_blocks[request_id]
    for block in blocks:
        assert block.ref_cnt == 1

    # STEP (2): Send Finished to PB.
    # (2a): schedule() - pass finished request to PB.
    scheduler_output = scheduler.schedule()
    assert len(scheduler.requests) == 1
    assert len(scheduler.running) == 0
    assert len(scheduler_output.finished_req_ids) == 1
    assert request_id in scheduler_output.finished_req_ids
    assert len(scheduler_output.scheduled_new_reqs) == 0
    assert scheduler_output.scheduled_cached_reqs.num_reqs == 0
    assert len(scheduler.finished_req_ids) == 0

    # (2b): execute_model()
    model_runner_output = EMPTY_MODEL_RUNNER_OUTPUT

    # (2c): update_from_output()
    scheduler.update_from_output(scheduler_output, model_runner_output)

    # STEP (3): Finished sending.
    # (3a): schedule() - nothing new to schedule; wait for the send to finish.
    scheduler_output = scheduler.schedule()
    assert len(scheduler.requests) == 1
    assert len(scheduler.running) == 0
    assert len(scheduler_output.finished_req_ids) == 0
    assert len(scheduler_output.scheduled_new_reqs) == 0
    assert scheduler_output.scheduled_cached_reqs.num_reqs == 0
    assert len(scheduler.finished_req_ids) == 0

    # (3b): execute_model()
    model_runner_output = copy.deepcopy(EMPTY_MODEL_RUNNER_OUTPUT)
    model_runner_output.kv_connector_output = KVConnectorOutput(
        finished_sending={request_id}
    )

    # (3c): update_from_output()
    scheduler.update_from_output(scheduler_output, model_runner_output)

    # Confirm we do not have any memory leaks after req lifecycle.
    assert_scheduler_empty(scheduler)
|
||||
|
||||
def test_short_prompt_lifecycle():
    """Test lifecycle of a Remote Decode request with a short prompt."""

    vllm_config = create_vllm_config()
    scheduler = create_scheduler(vllm_config)

    # Not enough tokens for a full block.
    BLOCK_SIZE = vllm_config.cache_config.block_size
    NUM_TOKENS = BLOCK_SIZE // 2
    request = create_request(
        request_id=1,
        block_size=BLOCK_SIZE,
        max_tokens=1,
        num_tokens=NUM_TOKENS,
        do_remote_decode=True,
    )

    scheduler.add_request(request)

    # STEP (1): Prefill.
    # (1a): schedule()
    scheduler_output = scheduler.schedule()
    assert len(scheduler.requests) == 1
    assert len(scheduler.running) == 1
    assert len(scheduler_output.scheduled_new_reqs) == 1

    # (1b): execute_model()
    model_runner_output = create_model_runner_output(reqs=[request])

    # (1c): update_from_output()
    # Even though tokens < block_size, there will be a KV xfer for the
    # partial block.
    eco = scheduler.update_from_output(scheduler_output, model_runner_output)
    kv_transfer_params = eco[0].outputs[0].kv_transfer_params

    assert len(kv_transfer_params["remote_block_ids"]) == 1

    # Confirm we do not have any memory leaks after req lifecycle.
    # We need to mark sending finished to clear data for the persistent batch.
    scheduler_output = scheduler.schedule()
    # Use create_model_runner_output to pass kv_connector_output along.
    model_runner_output = create_model_runner_output(
        reqs=[request], finished_sending={request.request_id}
    )
    scheduler.update_from_output(scheduler_output, model_runner_output)
    assert_scheduler_empty(scheduler)


def test_prefix_cache_lifecycle():
    """Test that remote decode params still work with a prefix cache hit."""

    vllm_config = create_vllm_config()
    scheduler = create_scheduler(vllm_config)

    # Prime the KVCache.
    BLOCK_SIZE = vllm_config.cache_config.block_size
    NUM_EXTERNAL_FULL_BLOCKS = 3
    NUM_TOKENS = int(BLOCK_SIZE * (NUM_EXTERNAL_FULL_BLOCKS + 0.5))

    request_normal = create_request(
        request_id=1, block_size=BLOCK_SIZE, num_tokens=NUM_TOKENS
    )

    scheduler.add_request(request_normal)
    scheduler_output = scheduler.schedule()
    model_runner_output = create_model_runner_output(
        reqs=[request_normal], use_eos=True
    )
    scheduler.update_from_output(scheduler_output, model_runner_output)
    scheduler_output = scheduler.schedule()
    scheduler.update_from_output(scheduler_output, EMPTY_MODEL_RUNNER_OUTPUT)

    #####################
    # Actual Test: confirm we send all blocks.

    # Step (1): Send the KV Transfer.
    NUM_EXTERNAL_FULL_BLOCKS -= 1
    NUM_TOKENS = int(BLOCK_SIZE * (NUM_EXTERNAL_FULL_BLOCKS + 0.5))

    request_remote = create_request(
        request_id=1,
        block_size=BLOCK_SIZE,
        num_tokens=NUM_TOKENS,
        do_remote_decode=True,
    )

    scheduler.add_request(request_remote)
    scheduler_output = scheduler.schedule()
    model_runner_output = create_model_runner_output(reqs=[request_remote])
    eco = scheduler.update_from_output(scheduler_output, model_runner_output)
    kv_transfer_params = eco[0].outputs[0].kv_transfer_params

    # Ensure we send all block ids, including the partial blocks,
    # even if there is a cache hit.
    assert len(kv_transfer_params["remote_block_ids"]) == (
        NUM_EXTERNAL_FULL_BLOCKS + 1
    )

    # STEP (2): Ensure it is freed.
    scheduler_output = scheduler.schedule()
    model_runner_output = copy.deepcopy(EMPTY_MODEL_RUNNER_OUTPUT)
    model_runner_output.kv_connector_output = KVConnectorOutput(
        finished_sending={request_remote.request_id}
    )
    scheduler.update_from_output(scheduler_output, model_runner_output)
    assert_scheduler_empty(scheduler)


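# Aborting mid-transfer must not release blocks immediately: the remote
# decode instance may still be reading them, so the free is deferred until
# the connector reports finished_sending.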
def test_abort_during_kv_transfer():
    """Test that aborting a request does not release blocks for remote decode."""

    vllm_config = create_vllm_config()
    scheduler = create_scheduler(vllm_config)

    # Prime the KVCache.
    BLOCK_SIZE = vllm_config.cache_config.block_size
    NUM_EXTERNAL_FULL_BLOCKS = 2
    NUM_TOKENS = int(BLOCK_SIZE * (NUM_EXTERNAL_FULL_BLOCKS + 0.5))

    request = create_request(
        request_id=1,
        block_size=BLOCK_SIZE,
        num_tokens=NUM_TOKENS,
        do_remote_decode=True,
    )

    scheduler.add_request(request)
    scheduler_output = scheduler.schedule()
    model_runner_output = create_model_runner_output(reqs=[request])
    scheduler.update_from_output(scheduler_output, model_runner_output)
    scheduler_output = scheduler.schedule()
    scheduler.update_from_output(scheduler_output, EMPTY_MODEL_RUNNER_OUTPUT)

    # Request removed from PB but blocks should not be freed.
    assert len(scheduler.requests) == 1

    # Abort the request, and check that the blocks are still not freed.
    scheduler.finish_requests([request.request_id], RequestStatus.FINISHED_ABORTED)
    assert len(scheduler.requests) == 1

    # Simulate a finished-sending notification.
    scheduler_output = scheduler.schedule()
    model_runner_output = copy.deepcopy(EMPTY_MODEL_RUNNER_OUTPUT)
    model_runner_output.kv_connector_output = KVConnectorOutput(
        finished_sending={request.request_id}
    )
    scheduler.update_from_output(scheduler_output, model_runner_output)
    assert_scheduler_empty(scheduler)
577
tests/v1/kv_connector/unit/test_remote_prefill_lifecycle.py
Normal file
577
tests/v1/kv_connector/unit/test_remote_prefill_lifecycle.py
Normal file
@@ -0,0 +1,577 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import copy

import pytest

from vllm.v1.outputs import EMPTY_MODEL_RUNNER_OUTPUT, KVConnectorOutput
from vllm.v1.request import FinishReason, RequestStatus

from .utils import (
    assert_scheduler_empty,
    create_model_runner_output,
    create_request,
    create_scheduler,
    create_vllm_config,
)

pytestmark = pytest.mark.cpu_test


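# A remote-prefill request enters the scheduler in WAITING_FOR_REMOTE_KVS:
# blocks are allocated up front (unhashed, so they cannot be prefix-cache
# hits), and the request only becomes schedulable once the connector reports
# finished_recving.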
def test_basic_lifecycle():
    """Test lifecycle of a remote prefill."""

    vllm_config = create_vllm_config()
    scheduler = create_scheduler(vllm_config)

    # 2 Full Blocks and 1 Half Block.
    BLOCK_SIZE = vllm_config.cache_config.block_size
    NUM_EXTERNAL_FULL_BLOCKS = 2
    NUM_TOKENS = int(BLOCK_SIZE * (NUM_EXTERNAL_FULL_BLOCKS + 0.5))
    START_FREE_BLOCK_QUEUE_SIZE = (
        scheduler.kv_cache_manager.block_pool.free_block_queue.num_free_blocks
    )

    request = create_request(
        request_id=1,
        block_size=BLOCK_SIZE,
        num_tokens=NUM_TOKENS,
        do_remote_prefill=True,
    )

    scheduler.add_request(request)
    request_id = request.request_id

    # STEP (1):
    # (1a): schedule()
    scheduler_output = scheduler.schedule()

    # Nothing running and empty scheduler output.
    assert len(scheduler.running) == 0
    assert len(scheduler_output.scheduled_new_reqs) == 0
    assert scheduler_output.scheduled_cached_reqs.num_reqs == 0
    assert len(scheduler_output.num_scheduled_tokens) == 0
    assert scheduler_output.total_num_scheduled_tokens == 0

    # Req waiting for KVs with no computed/scheduled toks ...
    assert len(scheduler.waiting) == 1
    assert request in scheduler.waiting
    assert request.status == RequestStatus.WAITING_FOR_REMOTE_KVS
    assert request.num_computed_tokens == 0

    # ... but should have (uncached) blocks allocated to it.
    block_pool = scheduler.kv_cache_manager.block_pool
    assert block_pool.free_block_queue.num_free_blocks < START_FREE_BLOCK_QUEUE_SIZE
    assert len(block_pool.cached_block_hash_to_block) == 0
    blocks = scheduler.kv_cache_manager.coordinator.single_type_managers[
        0
    ].req_to_blocks[request_id]
    for block in blocks:
        assert block._block_hash is None

    # (1b): forward()
    model_runner_output = EMPTY_MODEL_RUNNER_OUTPUT

    # (1c): update_from_output()
    engine_core_outputs = scheduler.update_from_output(
        scheduler_output, model_runner_output
    )
    assert not engine_core_outputs or not engine_core_outputs[0].outputs

    # STEP (2):
    # (2a): schedule(): nothing happens!
    scheduler_output = scheduler.schedule()
    assert len(scheduler.waiting) == 1
    assert len(scheduler.running) == 0

    # (2b): forward(): request finishes recv.
    model_runner_output = copy.deepcopy(EMPTY_MODEL_RUNNER_OUTPUT)
    model_runner_output.kv_connector_output = KVConnectorOutput(
        finished_recving={request_id}
    )

    # (2c): update_from_output():
    engine_core_outputs = scheduler.update_from_output(
        scheduler_output, model_runner_output
    )
    assert len(scheduler.waiting) == 1
    assert request_id in scheduler.finished_recving_kv_req_ids

    # STEP (3):
    # (3a): schedule(): this should actually schedule.
    scheduler_output = scheduler.schedule()
    assert len(scheduler.running) == 1

    # Confirm the blocks are actually allocated.
    num_hashed_blocks = 0
    blocks = scheduler.kv_cache_manager.coordinator.single_type_managers[
        0
    ].req_to_blocks[request_id]
    for block in blocks:
        assert block.ref_cnt == 1
        num_hashed_blocks += 1 if block._block_hash is not None else 0
    assert num_hashed_blocks == NUM_EXTERNAL_FULL_BLOCKS

    # Confirm the rest of the prompt is scheduled in this step.
    scheduled_req = scheduler_output.scheduled_new_reqs[0]
    num_scheduled_tokens = scheduler_output.num_scheduled_tokens[request_id]
    num_computed_tokens = scheduled_req.num_computed_tokens
    total_prompt_tokens = len(scheduled_req.prompt_token_ids)
    assert num_scheduled_tokens == total_prompt_tokens - num_computed_tokens

    # (3b): execute_model()
    model_runner_output = create_model_runner_output([request])
    # (3c): update_from_output()
    scheduler.update_from_output(scheduler_output, model_runner_output)

    # Step (4): Hit EOS.
    scheduler_output = scheduler.schedule()
    model_runner_output = create_model_runner_output([request], use_eos=True)
    engine_core_outputs = scheduler.update_from_output(
        scheduler_output, model_runner_output
    )
    scheduler.schedule()

    outputs = engine_core_outputs[0].outputs
    assert len(outputs) == 1
    output = outputs[0]
    assert output.finish_reason == FinishReason.STOP
    assert_scheduler_empty(scheduler)


def test_interleaved_lifecycle():
    """Test that remote prefills work well with other requests."""

    vllm_config = create_vllm_config()
    scheduler = create_scheduler(vllm_config)

    # 2 Full Blocks and 1 Half Block.
    BLOCK_SIZE = vllm_config.cache_config.block_size
    NUM_EXTERNAL_FULL_BLOCKS = 2
    NUM_TOKENS = int(BLOCK_SIZE * (NUM_EXTERNAL_FULL_BLOCKS + 0.5))

    request_remote = create_request(
        request_id=1,
        block_size=BLOCK_SIZE,
        num_tokens=NUM_TOKENS,
        do_remote_prefill=True,
    )
    request_local_a = create_request(
        request_id=2,
        block_size=BLOCK_SIZE,
        num_tokens=NUM_TOKENS,
    )
    request_local_b = create_request(
        request_id=3,
        block_size=BLOCK_SIZE,
        num_tokens=NUM_TOKENS,
    )

    # STEP 1: Regular request is running.
    scheduler.add_request(request_local_a)
    scheduler_output = scheduler.schedule()
    assert len(scheduler.running) == 1

    model_runner_output = create_model_runner_output([request_local_a])
    scheduler.update_from_output(scheduler_output, model_runner_output)

    # STEP 2: Add a local and a remote request.
    scheduler.add_request(request_local_b)
    scheduler.add_request(request_remote)
    scheduler_output = scheduler.schedule()
    assert len(scheduler.running) == 2
    assert len(scheduler.waiting) == 1
    assert len(scheduler_output.scheduled_new_reqs) == 1
    assert scheduler_output.scheduled_cached_reqs.num_reqs == 1

    model_runner_output = create_model_runner_output(
        [request_local_a, request_local_b]
    )
    scheduler.update_from_output(scheduler_output, model_runner_output)

    # STEP 3: Continue running; KVs have not arrived yet.
    scheduler_output = scheduler.schedule()
    assert len(scheduler.running) == 2
    assert len(scheduler.waiting) == 1
    assert len(scheduler_output.scheduled_new_reqs) == 0
    assert scheduler_output.scheduled_cached_reqs.num_reqs == 2

    model_runner_output = create_model_runner_output(
        reqs=[request_local_a, request_local_b]
    )
    scheduler.update_from_output(scheduler_output, model_runner_output)
    assert len(scheduler.running) == 2
    assert len(scheduler.waiting) == 1
    assert len(scheduler_output.scheduled_new_reqs) == 0
    assert scheduler_output.scheduled_cached_reqs.num_reqs == 2

    # STEP 4: KVs arrive.
    scheduler_output = scheduler.schedule()
    assert len(scheduler.running) == 2
    assert len(scheduler.waiting) == 1
    assert len(scheduler_output.scheduled_new_reqs) == 0
    assert scheduler_output.scheduled_cached_reqs.num_reqs == 2

    model_runner_output = create_model_runner_output(
        [request_local_a, request_local_b],
        finished_recving={request_remote.request_id},
    )
    scheduler.update_from_output(scheduler_output, model_runner_output)

    # STEP 5: RECVed KVs are sent to the ModelRunner.
    scheduler_output = scheduler.schedule()
    assert len(scheduler.running) == 3
    assert len(scheduler.waiting) == 0
    assert len(scheduler_output.scheduled_new_reqs) == 1
    assert scheduler_output.scheduled_cached_reqs.num_reqs == 2

    model_runner_output = create_model_runner_output(
        [request_local_a, request_local_b, request_remote]
    )
    scheduler.update_from_output(scheduler_output, model_runner_output)

    # STEP 6: Hit EOS and free.
    scheduler_output = scheduler.schedule()
    model_runner_output = create_model_runner_output(
        [request_local_a, request_local_b, request_remote],
        use_eos=True,
    )
    scheduler.update_from_output(scheduler_output, model_runner_output)
    scheduler.schedule()
    assert_scheduler_empty(scheduler)


def test_no_spurious_prefix_caching():
    """
    With P/D, blocks can be allocated but uncomputed for
    multiple engine steps. This test confirms that we do
    not accidentally have cache hits against uncomputed
    blocks.
    """

    vllm_config = create_vllm_config()
    scheduler = create_scheduler(vllm_config)

    # 2 and a half full external blocks.
    BLOCK_SIZE = vllm_config.cache_config.block_size
    NUM_EXTERNAL_FULL_BLOCKS = 2
    NUM_TOKENS = int(BLOCK_SIZE * (NUM_EXTERNAL_FULL_BLOCKS + 0.5))

    # Both of these requests have prompts like [1, 1, 1, 1, 1, ...]
    request_remote = create_request(
        request_id=1,
        block_size=BLOCK_SIZE,
        num_tokens=NUM_TOKENS,
        common_prefix_len=NUM_TOKENS,
        do_remote_prefill=True,
    )

    request_local = create_request(
        request_id=2,
        block_size=BLOCK_SIZE,
        num_tokens=NUM_TOKENS,
        common_prefix_len=NUM_TOKENS,
        do_remote_prefill=False,
    )

    # Schedule the remote prefill request. This should not
    # cause any blocks to be cached.
    scheduler.add_request(request_remote)
    scheduler_output = scheduler.schedule()
    scheduler.update_from_output(scheduler_output, EMPTY_MODEL_RUNNER_OUTPUT)
    assert len(scheduler.waiting) == 1

    # Schedule the local prefill request. This should
    # cause blocks to be cached, but separately from the
    # (identical-prompt) remote request's uncomputed blocks.
    scheduler.add_request(request_local)
    scheduler_output = scheduler.schedule()
    assert len(scheduler.running) == 1
    assert len(scheduler.waiting) == 1

    local_blocks = scheduler.kv_cache_manager.coordinator.single_type_managers[
        0
    ].req_to_blocks[request_local.request_id]
    remote_blocks = scheduler.kv_cache_manager.coordinator.single_type_managers[
        0
    ].req_to_blocks[request_remote.request_id]

    # Local should have cached blocks (but not all, due to preallocate).
    num_hashed_blocks = 0
    for block in local_blocks:
        assert block.ref_cnt == 1
        num_hashed_blocks += 1 if block._block_hash is not None else 0
    assert num_hashed_blocks > 0

    # Remote blocks should not be cached.
    for block in remote_blocks:
        assert block.ref_cnt == 1
        assert block._block_hash is None


def test_full_block_prompt():
    """Test that we handle a prompt that is the full block size."""

    vllm_config = create_vllm_config()
    scheduler = create_scheduler(vllm_config)

    # 2 Full Blocks.
    BLOCK_SIZE = vllm_config.cache_config.block_size
    NUM_EXTERNAL_FULL_BLOCKS = 2
    NUM_TOKENS = int(BLOCK_SIZE * NUM_EXTERNAL_FULL_BLOCKS)

    request = create_request(
        request_id=1,
        block_size=BLOCK_SIZE,
        num_tokens=NUM_TOKENS,
        do_remote_prefill=True,
    )

    scheduler.add_request(request)
    request_id = request.request_id

    # STEP (1): Initialize a recv.
    scheduler_output = scheduler.schedule()
    # All blocks should be allocated.
    num_blocks = len(
        scheduler.kv_cache_manager.coordinator.single_type_managers[0].req_to_blocks[
            request_id
        ]
    )
    assert num_blocks == NUM_EXTERNAL_FULL_BLOCKS
    model_runner_output = EMPTY_MODEL_RUNNER_OUTPUT
    scheduler.update_from_output(scheduler_output, model_runner_output)

    # STEP (2): Recv.
    scheduler_output = scheduler.schedule()
    model_runner_output = copy.deepcopy(EMPTY_MODEL_RUNNER_OUTPUT)
    model_runner_output.kv_connector_output = KVConnectorOutput(
        finished_recving={request_id}
    )
    scheduler.update_from_output(scheduler_output, model_runner_output)
    assert len(scheduler.waiting) == 1
    assert request_id in scheduler.finished_recving_kv_req_ids

    # STEP (3): Run as usual.
    scheduler_output = scheduler.schedule()

    # We need to recompute the final token of the prompt to generate
    # the first new token, so we should not have a new block.
    num_blocks = len(
        scheduler.kv_cache_manager.coordinator.single_type_managers[0].req_to_blocks[
            request_id
        ]
    )
    assert num_blocks == NUM_EXTERNAL_FULL_BLOCKS
    assert scheduler_output.scheduled_new_reqs[0].num_computed_tokens == NUM_TOKENS - 1
    assert scheduler_output.num_scheduled_tokens[request_id] == 1

    model_runner_output = create_model_runner_output([request])
    scheduler.update_from_output(scheduler_output, model_runner_output)

    # Step (4): Hit EOS.
    scheduler_output = scheduler.schedule()
    model_runner_output = create_model_runner_output([request], use_eos=True)
    engine_core_outputs = scheduler.update_from_output(
        scheduler_output, model_runner_output
    )
    scheduler.schedule()

    outputs = engine_core_outputs[0].outputs
    assert len(outputs) == 1
    output = outputs[0]
    assert output.finish_reason == FinishReason.STOP
    assert_scheduler_empty(scheduler)


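# Block accounting for the next test: 6 blocks total minus the KVCacheManager's
# 1 null block leaves 5 usable. The local request holds 3 (2 prompt + 1
# decode), so only 2 remain for the remote request's 2 prompt blocks, and its
# first decode token cannot get a block until the local request finishes.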
def test_cannot_schedule_after_recv():
    """
    Test that we can handle no schedule after recv due to not
    enough remaining KV blocks.
    """

    # NOTE: the KVCacheManager will use 1 null block.
    # So there are 5 total working blocks.
    TOTAL_NUM_BLOCKS = 6
    vllm_config = create_vllm_config()
    scheduler = create_scheduler(vllm_config, num_blocks=TOTAL_NUM_BLOCKS)

    # Prime the KVCache.
    NUM_PROMPT_BLOCKS = 2
    BLOCK_SIZE = vllm_config.cache_config.block_size
    # Prompt will use 2 blocks + 1 block after we schedule.
    NUM_TOKENS_LOCAL = int(BLOCK_SIZE * NUM_PROMPT_BLOCKS)
    NUM_TOKENS_REMOTE = int(BLOCK_SIZE * NUM_PROMPT_BLOCKS)

    request_normal = create_request(
        request_id=1, block_size=BLOCK_SIZE, num_tokens=NUM_TOKENS_LOCAL
    )
    request_remote = create_request(
        request_id=2,
        block_size=BLOCK_SIZE,
        num_tokens=NUM_TOKENS_REMOTE,
        do_remote_prefill=True,
    )

    # STEP 1: 3 blocks are in use (2 for prompt, 1 for decode).
    scheduler.add_request(request_normal)
    scheduler_output = scheduler.schedule()
    model_runner_output = create_model_runner_output(reqs=[request_normal])
    scheduler.update_from_output(scheduler_output, model_runner_output)
    assert len(scheduler.running) == 1
    assert len(scheduler.waiting) == 0

    # STEP 2: 5 blocks are in use (2 new for remote blocks).
    scheduler.add_request(request_remote)
    scheduler_output = scheduler.schedule()
    model_runner_output = create_model_runner_output(reqs=[request_normal])
    scheduler.update_from_output(scheduler_output, model_runner_output)
    assert len(scheduler.running) == 1
    assert len(scheduler.waiting) == 1

    # STEP 3: Finish recving (5 blocks in use).
    scheduler_output = scheduler.schedule()
    model_runner_output = create_model_runner_output(
        reqs=[request_normal], finished_recving={request_remote.request_id}
    )
    scheduler.update_from_output(scheduler_output, model_runner_output)
    assert len(scheduler.running) == 1
    assert len(scheduler.waiting) == 1

    # STEP 4: Try to schedule; the remote request is put on the running list
    # because the transfer is completed.
    scheduler_output = scheduler.schedule()
    model_runner_output = create_model_runner_output(
        reqs=[request_normal, request_remote]
    )
    scheduler.update_from_output(scheduler_output, model_runner_output)
    assert len(scheduler.running) == 2
    assert len(scheduler.waiting) == 0

    # STEP 5: The remote request is put back on the waiting list
    # because it needs a new block to hold its generated token.
    scheduler_output = scheduler.schedule()
    model_runner_output = create_model_runner_output(reqs=[request_normal])
    scheduler.update_from_output(scheduler_output, model_runner_output)
    assert len(scheduler.running) == 1
    assert len(scheduler.waiting) == 1

    # STEP 6: Finish the normal request and free its blocks.
    scheduler_output = scheduler.schedule()
    model_runner_output = create_model_runner_output(
        reqs=[request_normal], use_eos=True
    )
    scheduler.update_from_output(scheduler_output, model_runner_output)
    assert len(scheduler.running) == 0
    assert len(scheduler.waiting) == 1

    # STEP 7: Now we can schedule (with 2 blocks computed);
    # the request is retrieved from the preempted list.
    scheduler_output = scheduler.schedule()
    model_runner_output = create_model_runner_output(reqs=[request_remote])
    assert (
        scheduler_output.scheduled_cached_reqs.num_computed_tokens[0]
        == NUM_PROMPT_BLOCKS * BLOCK_SIZE
    )
    scheduler.update_from_output(scheduler_output, model_runner_output)
    assert len(scheduler.running) == 1
    assert len(scheduler.waiting) == 0

    # STEP 8: Free everything.
    scheduler_output = scheduler.schedule()
    model_runner_output = create_model_runner_output(
        reqs=[request_remote], use_eos=True
    )
    scheduler.update_from_output(scheduler_output, model_runner_output)
    _ = scheduler.schedule()
    assert_scheduler_empty(scheduler)


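# Same 5-usable-block setup, but here the remote prompt needs 3 blocks
# (2.5 blocks rounded up) while only 2 are free, so the KV transfer itself
# cannot even start until the local request finishes and frees its blocks.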
def test_cannot_recv():
    """
    Test that we can handle not scheduling a KV block transfer due to not
    enough remaining KV blocks.
    """

    # NOTE: the KVCacheManager will use 1 null block.
    # So there are 5 total working blocks.
    TOTAL_NUM_BLOCKS = 6
    vllm_config = create_vllm_config()
    scheduler = create_scheduler(vllm_config, num_blocks=TOTAL_NUM_BLOCKS)

    # Prime the KVCache.
    NUM_PROMPT_BLOCKS = 2
    BLOCK_SIZE = vllm_config.cache_config.block_size
    # Prompt will use 2 blocks + 1 block after we schedule.
    NUM_TOKENS_LOCAL = int(BLOCK_SIZE * NUM_PROMPT_BLOCKS)
    NUM_TOKENS_REMOTE = int(BLOCK_SIZE * (NUM_PROMPT_BLOCKS + 0.5))

    request_normal = create_request(
        request_id=1, block_size=BLOCK_SIZE, num_tokens=NUM_TOKENS_LOCAL
    )
    request_remote = create_request(
        request_id=2,
        block_size=BLOCK_SIZE,
        num_tokens=NUM_TOKENS_REMOTE,
        do_remote_prefill=True,
    )

    # STEP 1: 3 blocks are in use (2 for prompt, 1 for decode).
    scheduler.add_request(request_normal)
    scheduler_output = scheduler.schedule()
    model_runner_output = create_model_runner_output(reqs=[request_normal])
    scheduler.update_from_output(scheduler_output, model_runner_output)
    assert len(scheduler.running) == 1
    assert len(scheduler.waiting) == 0

    # STEP 2: 3 blocks are in use; the remote request needs 3 new blocks
    # but only 2 are available.
    scheduler.add_request(request_remote)
    scheduler_output = scheduler.schedule()
    model_runner_output = create_model_runner_output(reqs=[request_normal])
    scheduler.update_from_output(scheduler_output, model_runner_output)
    assert len(scheduler.running) == 1
    assert len(scheduler.waiting) == 1
    # Should not have a KV transfer in progress.
    assert request_remote.status != RequestStatus.WAITING_FOR_REMOTE_KVS

    # STEP 3: Finish the normal request and free its blocks.
    scheduler_output = scheduler.schedule()
    model_runner_output = create_model_runner_output(
        reqs=[request_normal], use_eos=True
    )
    scheduler.update_from_output(scheduler_output, model_runner_output)
    assert len(scheduler.running) == 0
    assert len(scheduler.waiting) == 1

    # STEP 4: Now we can initiate the KV transfer (with 2 blocks computed).
    scheduler_output = scheduler.schedule()
    model_runner_output = create_model_runner_output(reqs=[])
    scheduler.update_from_output(scheduler_output, model_runner_output)
    assert len(scheduler.running) == 0
    assert len(scheduler.waiting) == 1
    assert request_remote.status == RequestStatus.WAITING_FOR_REMOTE_KVS

    # STEP 5: Finish recving (5 blocks in use).
    scheduler_output = scheduler.schedule()
    model_runner_output = create_model_runner_output(
        reqs=[], finished_recving={request_remote.request_id}
    )
    scheduler.update_from_output(scheduler_output, model_runner_output)
    assert len(scheduler.running) == 0
    assert len(scheduler.waiting) == 1

    # STEP 6: Schedule the remote request.
    scheduler_output = scheduler.schedule()
    model_runner_output = create_model_runner_output(reqs=[request_remote])
    scheduler.update_from_output(scheduler_output, model_runner_output)
    assert len(scheduler.running) == 1
    assert len(scheduler.waiting) == 0

    # STEP 7: Free everything.
    scheduler_output = scheduler.schedule()
    model_runner_output = create_model_runner_output(
        reqs=[request_remote], use_eos=True
    )
    scheduler.update_from_output(scheduler_output, model_runner_output)
    _ = scheduler.schedule()
    assert_scheduler_empty(scheduler)
402
tests/v1/kv_connector/unit/utils.py
Normal file
402
tests/v1/kv_connector/unit/utils.py
Normal file
@@ -0,0 +1,402 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import tempfile
from collections import defaultdict
from collections.abc import Callable
from dataclasses import dataclass
from itertools import chain, count
from typing import Any

import torch

from vllm import SamplingParams
from vllm.config import (
    CacheConfig,
    DeviceConfig,
    KVTransferConfig,
    ModelConfig,
    SchedulerConfig,
    VllmConfig,
)
from vllm.distributed.kv_transfer.kv_connector.factory import KVConnectorFactory
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
    KVConnectorBase_V1,
    KVConnectorMetadata,
    KVConnectorRole,
)
from vllm.distributed.kv_transfer.kv_connector.v1.example_connector import (  # noqa
    ExampleConnector,
)
from vllm.utils.hashing import sha256
from vllm.v1.core.kv_cache_manager import KVCacheBlocks
from vllm.v1.core.kv_cache_utils import get_request_block_hasher, init_none_hash
from vllm.v1.core.sched.scheduler import Scheduler, SchedulerOutput
from vllm.v1.kv_cache_interface import (
    FullAttentionSpec,
    KVCacheConfig,
    KVCacheGroupSpec,
)
from vllm.v1.outputs import KVConnectorOutput, ModelRunnerOutput
from vllm.v1.request import Request
from vllm.v1.structured_output import StructuredOutputManager

EOS_TOKEN_ID = 50256


def assert_scheduler_empty(scheduler: Scheduler):
    """Confirm the scheduler is "empty" - i.e. no leaks."""
    # Scheduler Metadata.
    assert len(scheduler.requests) == 0
    assert len(scheduler.waiting) == 0
    assert len(scheduler.running) == 0
    assert len(scheduler.finished_req_ids) == 0
    assert len(scheduler.finished_recving_kv_req_ids) == 0

    # EncoderCacheManager.
    assert len(scheduler.encoder_cache_manager.freed) == 0
    assert len(scheduler.encoder_cache_manager.cached) == 0

    # KVCache Manager.
    assert (
        len(
            scheduler.kv_cache_manager.coordinator.single_type_managers[0].req_to_blocks
        )
        == 0
    )
    assert (
        len(
            scheduler.kv_cache_manager.coordinator.single_type_managers[
                0
            ].num_cached_block
        )
        == 0
    )
    num_free_blocks = (
        scheduler.kv_cache_manager.block_pool.free_block_queue.num_free_blocks
    )
    assert num_free_blocks == (scheduler.kv_cache_manager.block_pool.num_gpu_blocks - 1)

    # NOTE(rob): just the ref count on blocks will be 0. The hash
    # value, etc. will remain since we lazily evict for the prefix cache.
    for block in scheduler.kv_cache_manager.block_pool.blocks:
        assert block.ref_cnt == 0


def create_vllm_config(
    model: str = "facebook/opt-125m",
    max_num_seqs: int = 16,
    max_num_batched_tokens: int = 64,
    block_size: int = 16,
    max_model_len: int = 10000,
    enable_chunked_prefill: bool = True,
    enable_permute_local_kv: bool = False,
    kv_connector_extra_config: dict[str, Any] | None = None,
    dtype: str = "float16",
    cache_dtype: str = "auto",
    hf_overrides: dict[str, Any] | None = None,
) -> VllmConfig:
    """Initialize VllmConfig for testing."""
    model_config = ModelConfig(
        model=model,
        trust_remote_code=True,
        dtype=dtype,
        seed=42,
        hf_overrides=hf_overrides or {},
    )
    scheduler_config = SchedulerConfig(
        max_num_seqs=max_num_seqs,
        max_num_batched_tokens=max_num_batched_tokens,
        max_model_len=max_model_len,
        enable_chunked_prefill=enable_chunked_prefill,
        is_encoder_decoder=model_config.is_encoder_decoder,
    )
    # Cache config, optionally force APC.
    cache_config = CacheConfig(
        block_size=block_size,
        gpu_memory_utilization=0.9,
        swap_space=0,
        cache_dtype=cache_dtype,
        enable_prefix_caching=True,
    )
    kv_transfer_config = KVTransferConfig(
        kv_connector="NixlConnector",
        kv_role="kv_both",
        enable_permute_local_kv=enable_permute_local_kv,
        kv_connector_extra_config=kv_connector_extra_config or {},
    )
    return VllmConfig(
        scheduler_config=scheduler_config,
        model_config=model_config,
        cache_config=cache_config,
        kv_transfer_config=kv_transfer_config,
        device_config=DeviceConfig("cpu"),
    )


def create_scheduler(
    vllm_config: VllmConfig,
    num_blocks: int = 10000,
) -> Scheduler:
    """Initialize Scheduler for testing."""
    block_size = vllm_config.cache_config.block_size
    kv_cache_config = KVCacheConfig(
        num_blocks=num_blocks,  # A large number of blocks to hold all requests
        kv_cache_tensors=[],
        kv_cache_groups=[
            KVCacheGroupSpec(
                ["layer"], FullAttentionSpec(block_size, 1, 1, torch.float32, False)
            )
        ],
    )
    vllm_config.cache_config.num_gpu_blocks = num_blocks
    return Scheduler(
        vllm_config=vllm_config,
        kv_cache_config=kv_cache_config,
        log_stats=True,
        structured_output_manager=StructuredOutputManager(vllm_config),
        block_size=block_size,
    )


_request_count = count(1)
_none_hash_initialized = False


def create_request(
    request_id: int | None = None,
    num_tokens: int = 10,
    common_prefix_len: int = 0,
    max_tokens: int = 16,
    do_remote_decode: bool = False,
    do_remote_prefill: bool = False,
    num_remote_blocks: int = 3,
    block_size: int = 16,
    hash_fn: Callable = sha256,
) -> Request:
    """Make a dummy request for testing."""
    assert num_tokens >= common_prefix_len >= 0

    if request_id is None:
        request_id = next(_request_count)

    global _none_hash_initialized
    if not _none_hash_initialized:
        init_none_hash(hash_fn)
        _none_hash_initialized = True

    kv_transfer_params: dict[str, Any] | None = None

    if do_remote_decode:
        assert not do_remote_prefill
        kv_transfer_params = dict(do_remote_prefill=False, do_remote_decode=True)
    elif do_remote_prefill:
        kv_transfer_params = dict(
            do_remote_prefill=True,
            do_remote_decode=False,
            remote_engine_id="my-engine-id",
            remote_request_id=f"prefill-{request_id}",
            remote_block_ids=list(range(num_remote_blocks)),
            remote_host="my-host",
            remote_port=1234,
        )

    max_tokens = 1 if do_remote_decode else max_tokens
    sampling_params = SamplingParams(max_tokens=max_tokens)

    common_prefix = [1] * common_prefix_len if common_prefix_len > 0 else []
    suffix = [i * request_id for i in range(num_tokens - common_prefix_len)]
    prompt_token_ids = common_prefix + suffix

    req = Request(
        request_id=f"id-{request_id}",
        prompt_token_ids=prompt_token_ids,
        sampling_params=sampling_params,
        pooling_params=None,
        mm_features=None,
        eos_token_id=EOS_TOKEN_ID,
        block_hasher=get_request_block_hasher(block_size, hash_fn),
    )
    req.kv_transfer_params = kv_transfer_params
    return req


def create_model_runner_output(
    reqs: list[Request],
    finished_sending: set[str] | None = None,
    finished_recving: set[str] | None = None,
    invalid_block_ids: set[int] | None = None,
    use_eos: bool = False,
    token_id: int = 0,
) -> ModelRunnerOutput:
    """Make dummy model runner output for testing."""

    # Make request data.
    req_ids = [req.request_id for req in reqs]
    req_id_to_index = {req_id: idx for idx, req_id in enumerate(req_ids)}

    # Make sampled tokens.
    sampled_token = EOS_TOKEN_ID if use_eos else token_id
    sampled_token_ids = [[sampled_token] for _ in req_ids]

    kv_connector_output = (
        None
        if (
            finished_sending is None
            and finished_recving is None
            and invalid_block_ids is None
        )
        else KVConnectorOutput(
            finished_sending=finished_sending,
            finished_recving=finished_recving,
            invalid_block_ids=invalid_block_ids or set(),
        )
    )

    # Make output data structure.
    return ModelRunnerOutput(
        req_ids=req_ids,
        req_id_to_index=req_id_to_index,
        sampled_token_ids=sampled_token_ids,
        logprobs=None,
        prompt_logprobs_dict={},
        pooler_output=None,
        kv_connector_output=kv_connector_output,
    )


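# TestExampleConnector wraps a real ExampleConnector: __getattribute__
# delegates every interface call to the wrapped instance while recording the
# call (plus selected args) to a per-connector temp file, so tests running
# the engine in a separate process can read the events back and assert on
# connector activity.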
class TestExampleConnector(ExampleConnector):
    def __init__(self, config: VllmConfig, role, kv_cache_config):
        self.name = config.kv_transfer_config.kv_connector_extra_config["name"]
        self._connector = ExampleConnector(config, role)
        self.call_record: dict[str, int] = defaultdict(int)
        # Use a unique temp file per connector.
        self._event_file = (
            tempfile.gettempdir()
            + f"/connector_{self.name}-{self.role.name}_events.log"
        )
        # Start with an empty file.
        with open(self._event_file, "w") as _:
            pass

    def __getattribute__(self, name):
        if name in (
            "_connector",
            "call_record",
            "name",
            "_event_file",
            "__class__",
            "__dict__",
            "__getattribute__",
            "__init__",
        ):  # avoid recursion
            return object.__getattribute__(self, name)
        if not hasattr(self._connector, name):
            return object.__getattribute__(self, name)
        attr = getattr(self._connector, name)

        # Intercept calls to the connector interface and write an event
        # for each one to a file, which can be read back in the main test proc.
        if callable(attr):

            def wrapper(*args, **kwargs):
                self.call_record[name] += 1

                # Include args that we're interested in.
                to_log = [name]
                for arg in args:
                    if isinstance(arg, int):
                        to_log.append(str(arg))
                    elif isinstance(arg, KVCacheBlocks):
                        to_log.append(f"num_blocks={[len(b) for b in arg.blocks]}")

                # Log the event as a line to the file.
                try:
                    with open(self._event_file, "a") as f:
                        f.write(" ".join(to_log) + "\n")
                except Exception as e:
                    print(f"[ERROR] Could not log event {name} for {self.name}: {e}")
                return attr(*args, **kwargs)

            return wrapper
        return attr


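# MockKVConnector lets scheduler tests simulate external KV cache hits: the
# number of matched tokens and whether the load is async are read from
# kv_connector_extra_config, and build_connector_meta records which requests
# were newly scheduled or resumed in each step.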
@dataclass(frozen=True)
class MockKVConfig:
    matched_tokens: int = 0
    is_async: bool = False


class MockKVConnectorMetadata(KVConnectorMetadata):
    def __init__(self):
        # Scheduler tests check metadata.requests.
        self.requests: list = []


class MockKVConnector(KVConnectorBase_V1):
    """Mock KV connector for scheduler tests, supporting both sync and async mode."""

    def __init__(
        self,
        vllm_config: VllmConfig,
        role: KVConnectorRole,
        kv_cache_config: KVCacheConfig | None = None,
    ):
        super().__init__(vllm_config, role, kv_cache_config)
        extra_config = self._kv_transfer_config.kv_connector_extra_config
        self.config = MockKVConfig(
            matched_tokens=extra_config["matched_tokens"],
            is_async=extra_config["is_async"],
        )

    def get_num_new_matched_tokens(
        self,
        request: Request,
        num_computed_tokens: int,
    ) -> tuple[int | None, bool]:
        return (self.config.matched_tokens, self.config.is_async)

    def update_state_after_alloc(
        self,
        request: Request,
        blocks: KVCacheBlocks,
        num_external_tokens: int,
    ):
        pass

    def build_connector_meta(
        self, scheduler_output: SchedulerOutput
    ) -> KVConnectorMetadata:
        metadata = MockKVConnectorMetadata()
        cached_reqs = scheduler_output.scheduled_cached_reqs
        for req_id in chain(
            (req.req_id for req in scheduler_output.scheduled_new_reqs),
            (
                req_id
                for req_id in cached_reqs.req_ids
                if req_id in cached_reqs.resumed_req_ids
            ),
        ):
            metadata.requests.append({"req_id": req_id})
        return metadata

    def start_load_kv(self, kv_caches, finished_req_ids):
        pass

    def wait_for_layer_load(self, layer_name):
        pass

    def save_kv_layer(self, layer_name, kv_layer, attn_metadata, **kwargs):
        pass

    def wait_for_save(self):
        pass


KVConnectorFactory.register_connector(
    "TestExampleConnector", __name__, TestExampleConnector.__name__
)

KVConnectorFactory.register_connector(
    "MockKVConnector", __name__, MockKVConnector.__name__
)
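# Usage note (illustrative, not taken from a specific test): tests can select
# these connectors by name via a KVTransferConfig, passing the extra-config
# keys that MockKVConnector reads above, e.g.
#   KVTransferConfig(
#       kv_connector="MockKVConnector",
#       kv_role="kv_both",
#       kv_connector_extra_config={"matched_tokens": 16, "is_async": True},
#   )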