v0.10.1rc1

This commit is contained in:
2025-09-09 09:40:35 +08:00
parent d6f6ef41fe
commit 9149384e03
432 changed files with 84698 additions and 1 deletions

View File

@@ -0,0 +1,96 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
import os
import types
from tests.ut.kv_connector.utils import (create_request, create_scheduler,
create_vllm_config)
from vllm_ascend.distributed.llmdatadist_c_mgr_connector import (
LLMDataDistCMgrConnectorMetadata, LLMDataDistCMgrConnectorWorker, LLMRole)
def test_basic_inferface():
    """Check that scheduling a remote-prefill request yields connector metadata.

    One request spanning two full blocks plus half a block is handed to the
    scheduler. The scheduler output must then carry an
    ``LLMDataDistCMgrConnectorMetadata`` holding exactly that request, and the
    request's local block ids must match, pairwise, the block ids the KV cache
    manager recorded for it.
    """
    config = create_vllm_config()
    sched = create_scheduler(config)

    # Prompt length: 2 full blocks and 1 half block.
    block_size = config.cache_config.block_size
    num_external_full_blocks = 2
    num_tokens = int(block_size * (num_external_full_blocks + 0.5))

    req = create_request(request_id=1,
                         num_tokens=num_tokens,
                         do_remote_prefill=True)
    req_id = req.request_id
    sched.add_request(req)

    # Remote prefill triggers LLMDataDistCMgrConnectorMetadata.
    connector_meta = sched.schedule().kv_connector_metadata
    assert connector_meta is not None
    assert isinstance(connector_meta, LLMDataDistCMgrConnectorMetadata)
    assert len(connector_meta.requests) == 1
    assert req_id in connector_meta.requests

    local_ids = connector_meta.requests[req_id].local_block_ids
    manager = sched.kv_cache_manager.coordinator.single_type_managers[0]
    allocated_blocks = manager.req_to_blocks[req_id]
    for local_id, allocated in zip(local_ids, allocated_blocks):
        assert local_id == allocated.block_id
def test_read_agent_metadata():
    """Check which device IP ``read_agent_metadata`` picks from a rank table.

    The rank table lists two servers with two prefill devices each. For every
    combination of worker IP, TP rank and ``ASCEND_RT_VISIBLE_DEVICES`` value
    exercised below, the returned agent metadata must report the expected
    ``device_ip``.
    """
    rank_table = {
        "version": "1.2",
        "server_count": "2",
        "prefill_device_list": [{
            "server_id": "192.168.1.1",
            "device_id": "0",
            "device_ip": "10.30.0.1",
            "cluster_id": "0",
        }, {
            "server_id": "192.168.1.1",
            "device_id": "1",
            "device_ip": "10.30.0.2",
            "cluster_id": "1",
        }, {
            "server_id": "192.168.1.2",
            "device_id": "0",
            "device_ip": "10.30.0.3",
            "cluster_id": "2",
        }, {
            "server_id": "192.168.1.2",
            "device_id": "1",
            "device_ip": "10.30.0.4",
            "cluster_id": "3",
        }]
    }

    def get_device_ip(worker_local_ip, worker_tp_rank, worker_visible_devices):
        """Return the device IP resolved for a stand-in worker.

        ``ASCEND_RT_VISIBLE_DEVICES`` is overridden only for the duration of
        the call and is put back exactly as it was (including "not set"),
        even when ``read_agent_metadata`` raises.
        """
        env_key = "ASCEND_RT_VISIBLE_DEVICES"
        # None means the variable was absent, which differs from being "".
        previous = os.environ.get(env_key)

        worker = types.SimpleNamespace()
        worker.local_ip = worker_local_ip
        worker.tp_rank = worker_tp_rank
        worker.llm_datadist_role = LLMRole.PROMPT

        os.environ[env_key] = worker_visible_devices
        try:
            agent_metadata = LLMDataDistCMgrConnectorWorker.read_agent_metadata(
                worker, rank_table)
        finally:
            if previous is None:
                os.environ.pop(env_key, None)
            else:
                os.environ[env_key] = previous
        return agent_metadata.device_ip

    assert get_device_ip("192.168.1.1", 0, "0") == "10.30.0.1"
    assert get_device_ip("192.168.1.1", 0, "1") == "10.30.0.2"
    assert get_device_ip("192.168.1.2", 0, "0") == "10.30.0.3"
    assert get_device_ip("192.168.1.2", 0, "1") == "10.30.0.4"
    assert get_device_ip("192.168.1.1", 0, "0,1") == "10.30.0.1"
    assert get_device_ip("192.168.1.1", 1, "0,1") == "10.30.0.2"
    assert get_device_ip("192.168.1.1", 0, "") == "10.30.0.1"
    assert get_device_ip("192.168.1.1", 1, "") == "10.30.0.2"

View File

@@ -0,0 +1,998 @@
import os
import queue
import socket
import sys
import threading
import time
import types
import unittest
from collections import defaultdict, deque
from unittest.mock import MagicMock, patch
import msgspec
import zmq
from vllm.utils import make_zmq_path
# Register a stub ``mooncake.engine`` module in ``sys.modules`` before the
# connector import that follows; its ``TransferEngine`` is a MagicMock.
# NOTE(review): presumably this lets the import below succeed without the real
# mooncake package installed - confirm.
fake_engine = types.ModuleType("mooncake.engine")
fake_engine.TransferEngine = MagicMock() # type: ignore[attr-defined]
sys.modules["mooncake.engine"] = fake_engine
from vllm_ascend.distributed.mooncake_connector import ( # noqa: E402
KVCacheRecvingThread, KVCacheSendingThread, KVCacheTaskTracker,
KVConnectorRole, MooncakeAgentMetadata, MooncakeConnector,
MooncakeConnectorMetadata, MooncakeConnectorScheduler,
MooncakeConnectorWorker, ReqMeta, ensure_zmq_recv, ensure_zmq_send,
group_concurrent_contiguous, string_to_int64_hash, zmq_ctx)
# Message tags that the tests below encode into side-channel messages.
GET_META_MSG = b"get_meta_msg"
DONE_RECVING_MSG = b"done_recving_msg"
class TestKVCacheTaskTrackerInit(unittest.TestCase):
    """Checks the attribute types of a freshly constructed tracker."""

    def test_init_basic_properties(self):
        """A new tracker exposes a lock, a set and a deque."""
        lock_type = type(threading.Lock())
        tracker = KVCacheTaskTracker()
        expected_types = (
            ("done_task_lock", lock_type),
            ("finished_requests", set),
            ("delayed_free_requests", deque),
        )
        for attr_name, attr_type in expected_types:
            self.assertIsInstance(getattr(tracker, attr_name), attr_type)
class TestGetAndClearFinishedSingleRequests(unittest.TestCase):
    """Exercises ``KVCacheTaskTracker.get_and_clear_finished_requests``."""

    def setUp(self):
        tracker = KVCacheTaskTracker()
        tracker.finished_requests = set()
        tracker.done_task_lock = threading.Lock()
        self.tracker = tracker

    def _assert_nothing_left(self):
        """Assert the tracker no longer holds any finished request."""
        self.assertEqual(len(self.tracker.finished_requests), 0)

    def test_empty_requests(self):
        """Draining an empty tracker returns an empty set."""
        drained = self.tracker.get_and_clear_finished_requests()
        self.assertEqual(drained, set())
        self._assert_nothing_left()

    def test_single_request(self):
        """A single finished request is returned and then forgotten."""
        self.tracker.finished_requests = {"req_123"}
        drained = self.tracker.get_and_clear_finished_requests()
        self.assertEqual(drained, {"req_123"})
        self._assert_nothing_left()

    def test_multiple_requests(self):
        """All finished requests are returned together and then forgotten."""
        self.tracker.finished_requests = {"req_1", "req_2", "req_3"}
        drained = self.tracker.get_and_clear_finished_requests()
        self.assertSetEqual(drained, {"req_1", "req_2", "req_3"})
        self._assert_nothing_left()

    @patch("vllm_ascend.distributed.mooncake_connector.logger")
    def test_concurrent_access(self, mock_logger):
        """Of three concurrent drains, exactly one gets the finished set."""
        from concurrent.futures import ThreadPoolExecutor

        self.tracker.finished_requests = {"req_1", "req_2"}
        with ThreadPoolExecutor(max_workers=3) as pool:
            pending = [
                pool.submit(self.tracker.get_and_clear_finished_requests)
                for _ in range(3)
            ]
            outcomes = [task.result() for task in pending]
        non_empty = [outcome for outcome in outcomes if outcome]
        self.assertEqual(len(non_empty), 1)
        self._assert_nothing_left()
class TestKVCacheSendingThreadInit(unittest.TestCase):
    """Checks properties of a KVCacheSendingThread right after construction."""

    def setUp(self):
        self.common_args = dict(
            tp_rank=1,
            decode_tp_size=4,
            local_engine_id='engine_1',
            side_channel_host='localhost',
            side_channel_port=5555,
            metadata=MagicMock(),
            ready_event=threading.Event(),
        )
        self.threads = []

    def tearDown(self):
        for created in self.threads:
            tracker = getattr(created, 'task_tracker', None)
            if hasattr(created, 'task_tracker') and hasattr(tracker, 'socket'):
                tracker.socket.close()
            if hasattr(created, 'is_alive') and created.is_alive():
                created.join(timeout=0.1)

    def _new_thread(self, ctor_args):
        """Build a sending thread and remember it for tearDown."""
        created = KVCacheSendingThread(**ctor_args)
        self.threads.append(created)
        return created

    def test_thread_daemon_property(self):
        """The thread is created as a daemon."""
        self.assertTrue(self._new_thread(self.common_args).daemon)

    def test_thread_name_format(self):
        """The thread carries the expected fixed name."""
        created = self._new_thread(self.common_args)
        self.assertEqual(created.name, "KVCacheSendingThread")

    def test_ready_event_reference(self):
        """The thread keeps the very event object it was given."""
        custom_event = threading.Event()
        ctor_args = {**self.common_args, 'ready_event': custom_event}
        created = self._new_thread(ctor_args)
        self.assertIs(created.ready_event, custom_event)
class TestGetAndClearFinishedRequests(unittest.TestCase):
    """Checks that the sending thread forwards to its task tracker."""

    def setUp(self):
        self.common_args = dict(
            tp_rank=1,
            decode_tp_size=4,
            local_engine_id='engine_1',
            side_channel_host='localhost',
            side_channel_port=5555,
            metadata={"test": "metadata"},
            ready_event=threading.Event(),
        )
        self.thread = KVCacheSendingThread(**self.common_args)

    def test_get_and_clear_finished_requests(self):
        """The thread returns whatever the tracker method returns."""
        expected_requests = {'req1', 'req2'}
        with patch.object(KVCacheTaskTracker,
                          'get_and_clear_finished_requests') as tracker_call:
            tracker_call.return_value = expected_requests
            actual = self.thread.get_and_clear_finished_requests()
            tracker_call.assert_called_once()
            self.assertEqual(actual, expected_requests)
class TestKVCacheSendingThread(unittest.TestCase):
    """End-to-end check of the sending thread's side channel over real zmq."""

    def test_run_handles_get_meta_and_done_recv_msgs(self):
        """The running thread answers a metadata query and a done-receiving notice.

        A DEALER socket sends ``GET_META_MSG`` and expects the agent metadata
        back, then sends ``DONE_RECVING_MSG`` for one request and expects an
        ``ACK``; afterwards the request id must be in the tracker's finished
        set.
        """
        ready_event = threading.Event()
        metadata = MooncakeAgentMetadata(
            engine_id="engine1",
            te_rpc_port=9090,
            kv_caches_base_addr=[12345678],
            num_blocks=2,
        )
        host = "127.0.0.1"
        # Ask the OS for a currently free port; it is released before reuse.
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
            s.bind(('', 0))
            free_port = s.getsockname()[1]

        thread = KVCacheSendingThread(
            tp_rank=0,
            decode_tp_size=1,
            local_engine_id="engine1",
            side_channel_host=host,
            side_channel_port=free_port,
            metadata=metadata,
            ready_event=ready_event,
        )
        thread.start()
        self.assertTrue(ready_event.wait(timeout=3),
                        "Server thread startup timeout")

        context = zmq.Context()  # type: ignore
        sock = context.socket(zmq.DEALER)  # type: ignore
        try:
            # Bound the wait so a silent server fails the test instead of
            # hanging it forever.
            sock.setsockopt(zmq.RCVTIMEO, 3000)  # type: ignore
            sock.connect(f"tcp://{host}:{free_port}")
            encoder = msgspec.msgpack.Encoder()
            decoder = msgspec.msgpack.Decoder(type=MooncakeAgentMetadata)

            sock.send_multipart([b"", encoder.encode((GET_META_MSG, ))])
            frames = sock.recv_multipart()
            self.assertEqual(frames[0], b"")
            meta = decoder.decode(frames[1])
            self.assertEqual(meta.engine_id, "engine1")
            self.assertEqual(meta.kv_caches_base_addr, [12345678])
            self.assertEqual(meta.num_blocks, 2)

            req_id = "request_42"
            sock.send_multipart(
                [b"", encoder.encode((DONE_RECVING_MSG, req_id, 0))])
            frames = sock.recv_multipart()
            self.assertEqual(frames[0], b"")
            self.assertEqual(frames[1], b"ACK")
            self.assertIn(req_id, thread.task_tracker.finished_requests)
        finally:
            # Always release zmq resources, even when an assertion failed;
            # linger=0 keeps context.term() from blocking on unsent frames.
            sock.close(linger=0)
            context.term()
class TestKVCacheRecvingThreadBasic(unittest.TestCase):
    """Basic queueing and tracker-forwarding checks for the receiving thread."""

    def setUp(self):
        self.engine = MagicMock()
        self.ready_event = threading.Event()
        ctor_args = dict(
            tp_rank=0,
            tp_size=4,
            engine=self.engine,
            local_engine_id="local_engine",
            local_handshake_port=5555,
            local_kv_caches_base_addr=[0x1000, 0x2000],
            block_len=[1024, 2048],
            ready_event=self.ready_event,
        )
        self.thread = KVCacheRecvingThread(**ctor_args)

    def test_add_request(self):
        """A request added to the thread shows up on its request queue."""
        self.thread.add_request(
            request_id="req1",
            local_block_ids=[1, 2],
            remote_block_ids=[3, 4],
            remote_engine_id="remote_engine",
            remote_host="localhost",
            remote_handshake_port=6666,
        )
        pending = self.thread.request_queue.get_nowait()
        self.assertEqual(pending["request_id"], "req1")
        self.assertEqual(pending["remote_host"], "localhost")

    def test_get_finished_requests(self):
        """The thread returns whatever the tracker method returns."""
        with patch.object(KVCacheTaskTracker,
                          'get_and_clear_finished_requests') as tracker_call:
            tracker_call.return_value = {"req1", "req2"}
            finished = self.thread.get_and_clear_finished_requests()
            self.assertEqual(finished, {"req1", "req2"})
class TestSocketManagement(unittest.TestCase):
    """Checks how the receiving thread obtains and pools remote sockets."""

    def setUp(self):
        self.engine = MagicMock()
        self.ready_event = threading.Event()
        ctor_args = dict(
            tp_rank=0,
            tp_size=4,
            engine=self.engine,
            local_engine_id="local_engine",
            local_handshake_port=5555,
            local_kv_caches_base_addr=[0x1000, 0x2000],
            block_len=[1024, 2048],
            ready_event=self.ready_event,
        )
        self.thread = KVCacheRecvingThread(**ctor_args)
        self.thread.remote_sockets = defaultdict(deque)
        self.thread.remote_poller = MagicMock()

    @patch('vllm_ascend.distributed.mooncake_connector.zmq.Context')
    @patch('vllm_ascend.distributed.mooncake_connector.make_zmq_socket')
    def test_get_remote_socket(self, mock_make_socket, mock_context):
        """With an empty pool a REQ socket is created and registered for POLLIN."""
        created_sock = MagicMock()
        mock_make_socket.return_value = created_sock

        obtained = self.thread._get_remote_socket("test_host", 12345)

        self.assertEqual(obtained, created_sock)
        mock_make_socket.assert_called_once()
        _, make_kwargs = mock_make_socket.call_args
        self.assertEqual(make_kwargs.get('path'), 'tcp://test_host:12345')
        self.assertEqual(make_kwargs.get('socket_type'),
                         zmq.REQ)  # type: ignore
        self.assertFalse(make_kwargs.get('bind', True))
        self.thread.remote_poller.register.assert_called_with(
            created_sock, zmq.POLLIN)  # type: ignore

    def test_return_socket_to_pool(self):
        """A returned socket is pooled under its path without re-registering."""
        returned_sock = MagicMock()
        host, port = "test_host", 12345
        pool_key = make_zmq_path("tcp", host, port)

        self.thread._return_remote_socket(returned_sock, host, port)

        pooled = self.thread.remote_sockets[pool_key]
        self.assertEqual(len(pooled), 1)
        self.assertEqual(pooled[0], returned_sock)
        self.thread.remote_poller.register.assert_not_called()
class TestCoreFunctionality(unittest.TestCase):
    """Checks request handling and KV cache transfer on the receiving thread."""

    def setUp(self):
        self.engine = MagicMock()
        self.ready_event = threading.Event()
        self.mock_queue = MagicMock()
        ctor_args = dict(
            tp_rank=0,
            tp_size=4,
            engine=self.engine,
            local_engine_id="local_engine",
            local_handshake_port=5555,
            local_kv_caches_base_addr=[0x1000, 0x2000],
            block_len=[1024, 2048],
            ready_event=self.ready_event,
        )
        self.thread = KVCacheRecvingThread(**ctor_args)
        self.thread.request_queue = self.mock_queue
        self.test_req = dict(
            request_id="req1",
            local_block_ids=[1, 2],
            remote_block_ids=[3, 4],
            remote_engine_id="remote_engine",
            remote_host="localhost",
            remote_handshake_port=6666,
            remote_transfer_port=7777,
        )
        self.thread.task_tracker = MagicMock()
        # A zero return lets the transfer tests below proceed without error.
        self.engine.batch_transfer_sync_read.return_value = 0
        self.thread.remote_te_port = {"remote_engine": {6666: 7777}}

    def _seed_remote_base_addrs(self):
        """Pre-populate base addresses for the remote engine used in the tests."""
        self.thread.kv_caches_base_addr["remote_engine"] = {
            6666: [0x3000, 0x4000]
        }

    @patch.object(KVCacheRecvingThread, '_transfer_kv_cache')
    @patch.object(KVCacheRecvingThread, '_send_done_recv_signal')
    def test_handle_request(self, mock_send, mock_transfer):
        """Handling a request transfers, signals, records and marks task done."""
        self.thread._handle_request(self.test_req)

        mock_transfer.assert_called_once_with(self.test_req)
        mock_send.assert_called_once_with("req1", "localhost", 6666)
        tracker = self.thread.task_tracker
        tracker.update_done_task_count.assert_called_once_with("req1")
        self.mock_queue.task_done.assert_called_once()

    @patch.object(KVCacheRecvingThread, '_get_remote_metadata')
    def test_transfer_kv_cache(self, mock_get_meta):
        """With known base addresses one batched read is issued, no metadata fetch."""
        self._seed_remote_base_addrs()

        self.thread._transfer_kv_cache(self.test_req)

        read_call = self.engine.batch_transfer_sync_read
        read_call.assert_called_once()
        positional, _ = read_call.call_args
        self.assertEqual(positional[0], "localhost:7777")
        for index in (1, 2, 3):
            self.assertIsInstance(positional[index], list)
        self.assertEqual(len(positional[1]), len(positional[2]))
        self.assertEqual(len(positional[1]), len(positional[3]))
        mock_get_meta.assert_not_called()

    def test_transfer_kv_cache_failure(self):
        """A negative engine return value surfaces as RuntimeError."""
        self.engine.batch_transfer_sync_read.return_value = -1
        self._seed_remote_base_addrs()
        with self.assertRaises(RuntimeError):
            self.thread._transfer_kv_cache(self.test_req)
class TestMetadataHandling(unittest.TestCase):
    """Checks fetching remote agent metadata on the receiving thread."""

    def setUp(self):
        self.engine = MagicMock()
        self.ready_event = threading.Event()
        ctor_args = dict(
            tp_rank=0,
            tp_size=4,
            engine=self.engine,
            local_engine_id="local_engine",
            local_handshake_port=5555,
            local_kv_caches_base_addr=[0x1000, 0x2000],
            block_len=[1024, 2048],
            ready_event=self.ready_event,
        )
        self.thread = KVCacheRecvingThread(**ctor_args)
        self.test_metadata = MooncakeAgentMetadata(
            engine_id="remote_engine",
            te_rpc_port=9090,
            kv_caches_base_addr=[0x3000, 0x4000],
            num_blocks=2)

    @patch('vllm_ascend.distributed.mooncake_connector.ensure_zmq_send')
    @patch('vllm_ascend.distributed.mooncake_connector.ensure_zmq_recv')
    def test_get_remote_metadata_success(self, mock_recv, mock_send):
        """A successful fetch stores the remote base addresses and returns the socket."""
        mock_recv.return_value = msgspec.msgpack.encode(self.test_metadata)
        remote_sock = MagicMock()
        with patch.object(self.thread, '_get_remote_socket') as take_socket:
            with patch.object(self.thread,
                              '_return_remote_socket') as give_back:
                take_socket.return_value = remote_sock

                self.thread._get_remote_metadata("host1", 5555)

                take_socket.assert_called_once_with("host1", 5555)
                give_back.assert_called_once_with(remote_sock, "host1", 5555)
                expected_payload = self.thread.encoder.encode(
                    (GET_META_MSG, ""))
                mock_send.assert_called_once_with(remote_sock,
                                                  expected_payload)
                mock_recv.assert_called_once_with(remote_sock,
                                                  self.thread.remote_poller)
                stored = self.thread.kv_caches_base_addr["remote_engine"]
                self.assertEqual(stored[5555], [0x3000, 0x4000])

    @patch('vllm_ascend.distributed.mooncake_connector.ensure_zmq_send')
    @patch('vllm_ascend.distributed.mooncake_connector.ensure_zmq_recv',
           side_effect=Exception("Network error"))
    def test_get_remote_metadata_failure(self, mock_recv, mock_send):
        """A receive error propagates and the socket is still returned."""
        remote_sock = MagicMock()
        with patch.object(self.thread, '_get_remote_socket') as take_socket:
            with patch.object(self.thread,
                              '_return_remote_socket') as give_back:
                take_socket.return_value = remote_sock

                with self.assertRaises(Exception) as raised:
                    self.thread._get_remote_metadata("host1", 5555)

                self.assertEqual(str(raised.exception), "Network error")
                give_back.assert_called_once()
class TestMainThreadLoop(unittest.TestCase):
    """Runs the receiving thread's loop against a real queue."""

    def setUp(self):
        self.engine = MagicMock()
        self.ready_event = threading.Event()
        ctor_args = dict(
            tp_rank=0,
            tp_size=4,
            engine=self.engine,
            local_engine_id="local_engine",
            local_handshake_port=5555,
            local_kv_caches_base_addr=[0x1000, 0x2000],
            block_len=[1024, 2048],
            ready_event=self.ready_event,
        )
        self.thread = KVCacheRecvingThread(**ctor_args)
        self.thread.request_queue = queue.Queue()

    @patch.object(KVCacheRecvingThread, '_handle_request')
    def test_run_loop_normal(self, mock_handle):
        """One queued request followed by ``None`` is handled exactly once."""
        queued_request = dict(
            request_id="req1",
            local_block_ids=[1, 2],
            remote_block_ids=[3, 4],
            remote_engine_id="remote_engine",
            remote_host="localhost",
            remote_handshake_port=6666,
            remote_transfer_port=7777,
        )
        work_queue = self.thread.request_queue
        work_queue.put(queued_request)
        work_queue.put(None)

        self.thread.start()
        time.sleep(0.1)
        self.thread.join(timeout=1.0)

        self.assertTrue(self.thread.ready_event.is_set())
        mock_handle.assert_called_once_with(queued_request)
        self.assertTrue(self.thread.request_queue.empty())
class MockVllmConfig:
    """Config stand-in whose sections are MagicMocks with a few fixed values."""

    def __init__(self):
        model_cfg = MagicMock()
        model_cfg.use_mla = True

        parallel_cfg = MagicMock()
        parallel_cfg.tensor_parallel_size = 2
        parallel_cfg.data_parallel_rank_local = 0
        parallel_cfg.data_parallel_size_local = 1

        cache_cfg = MagicMock()
        cache_cfg.block_size = 16

        def lookup_extra_config(k, d):
            # Rebuilt on every call so each caller gets fresh dict objects.
            extra = {
                "prefill": {
                    "tp_size": 2,
                    "dp_size": 1
                },
                "decode": {
                    "tp_size": 2,
                    "dp_size": 1
                }
            }
            return extra.get(k, d)

        transfer_cfg = MagicMock()
        transfer_cfg.kv_port = 5000
        transfer_cfg.kv_role = 'kv_producer'
        transfer_cfg.get_from_extra_config = MagicMock(
            side_effect=lookup_extra_config)

        self.model_config = model_cfg
        self.parallel_config = parallel_cfg
        self.cache_config = cache_cfg
        self.kv_transfer_config = transfer_cfg
class MockRequest:
    """Minimal request stand-in carrying only the attributes the tests read.

    Args:
        request_id: Identifier stored as ``request_id``.
        prompt_token_ids: Prompt tokens; ``[1, 2, 3, 4]`` when omitted.
        kv_transfer_params: KV transfer parameters; ``{}`` when omitted.
        status: Request status; ``"running"`` when omitted.

    Defaults apply only when an argument is ``None``, so an explicitly passed
    empty list or dict is kept as given rather than silently replaced.
    """

    def __init__(self,
                 request_id,
                 prompt_token_ids=None,
                 kv_transfer_params=None,
                 status=None):
        self.request_id = request_id
        self.prompt_token_ids = ([1, 2, 3, 4] if prompt_token_ids is None else
                                 prompt_token_ids)
        self.kv_transfer_params = ({} if kv_transfer_params is None else
                                   kv_transfer_params)
        self.status = "running" if status is None else status
        self.output_token_ids = [101, 102]
class TestKVCacheTaskTracker(unittest.TestCase):
    """Checks delayed-free bookkeeping and done-task updates on the tracker."""

    def setUp(self):
        self.tracker = KVCacheTaskTracker()

    def test_update_done_task_count(self):
        """Finishing a delayed request moves it to the finished set."""
        tracker = self.tracker
        self.assertEqual(len(tracker.finished_requests), 0)
        self.assertEqual(len(tracker.delayed_free_requests), 0)

        now = time.time()
        tracker.add_delayed_request("req_1", now)
        delayed = tracker.delayed_free_requests
        self.assertEqual(len(delayed), 1)
        self.assertEqual(delayed[0], ("req_1", now))

        tracker.update_done_task_count("req_1")
        self.assertEqual(tracker.finished_requests, {"req_1"})
        self.assertEqual(len(tracker.delayed_free_requests), 0)

    def test_retrieve_expired_requests(self):
        """Only the request stamped 600 seconds in the past is reported expired."""
        tracker = self.tracker
        now = time.time()
        tracker.add_delayed_request("req_1", now - 600)
        tracker.add_delayed_request("req_2", now)

        expired = tracker._retrieve_expired_requests()

        self.assertEqual(expired, {"req_1"})
        remaining = tracker.delayed_free_requests
        self.assertEqual(len(remaining), 1)
        self.assertEqual(remaining[0], ("req_2", now))

    def test_duplicate_task_update(self):
        """Repeated updates for one request yield a single finished entry."""
        for _ in range(3):
            self.tracker.update_done_task_count("req1")
        drained = self.tracker.get_and_clear_finished_requests()
        self.assertEqual(drained, {"req1"})
class TestMooncakeConnectorMetadata(unittest.TestCase):
    """Checks that connector metadata records new requests."""

    def test_add_new_req(self):
        """``add_new_req`` stores a ReqMeta built from the transfer params."""
        connector_meta = MooncakeConnectorMetadata()
        self.assertEqual(len(connector_meta.requests), 0)
        self.assertEqual(len(connector_meta.requests_to_send), 0)

        transfer_params = {
            "remote_block_ids": [4, 5, 6],
            "remote_engine_id": "remote_engine",
            "remote_host": "localhost",
            "remote_port": 5000
        }
        connector_meta.add_new_req(request_id="req1",
                                   local_block_ids=[1, 2, 3],
                                   kv_transfer_params=transfer_params)

        self.assertEqual(len(connector_meta.requests), 1)
        stored = connector_meta.requests["req1"]
        self.assertIsInstance(stored, ReqMeta)
        self.assertEqual(stored.local_block_ids, [1, 2, 3])
        self.assertEqual(stored.remote_block_ids, [4, 5, 6])
        self.assertEqual(stored.remote_engine_id, "remote_engine")
        self.assertEqual(stored.remote_host, "localhost")
        self.assertEqual(stored.remote_port, 5000)
class TestMooncakeConnectorSchedulerMatchedTokens(unittest.TestCase):
    """Checks token matching and metadata building on the scheduler side."""

    def setUp(self):
        config = MockVllmConfig()
        self.scheduler = MooncakeConnectorScheduler(config, "test_engine")

    def test_get_num_new_matched_tokens(self):
        """Without remote prefill nothing matches; with it, 3 tokens load async."""
        request = MockRequest("req1")
        tokens, async_flag = self.scheduler.get_num_new_matched_tokens(
            request, 0)
        self.assertEqual(tokens, 0)
        self.assertFalse(async_flag)

        request.kv_transfer_params = {"do_remote_prefill": True}
        tokens, async_flag = self.scheduler.get_num_new_matched_tokens(
            request, 0)
        self.assertEqual(tokens, 3)
        self.assertTrue(async_flag)

    def test_build_connector_meta(self):
        """Pending receives become request metadata and the pending map empties."""
        request = MockRequest("req1")
        self.scheduler._reqs_need_recv["req1"] = (request, [4, 5, 6])
        request.kv_transfer_params = {
            "remote_block_ids": [1, 2, 3],
            "remote_engine_id": "remote",
            "remote_host": "localhost",
            "remote_port": 5000
        }

        meta = self.scheduler.build_connector_meta(MagicMock())

        self.assertIsInstance(meta, MooncakeConnectorMetadata)
        self.assertEqual(len(meta.requests), 1)
        self.assertEqual(meta.requests["req1"].local_block_ids, [4, 5, 6])
        self.assertEqual(meta.requests["req1"].remote_block_ids, [1, 2, 3])
        self.assertEqual(len(self.scheduler._reqs_need_recv), 0)

    def test_get_finished_count(self):
        """With the mock config the finished count is 2."""
        count = self.scheduler.get_finished_count()
        self.assertEqual(count, 2)
class TestHelperFunctions(unittest.TestCase):
    """Checks the block-grouping and string-hash helpers."""

    def test_group_concurrent_contiguous(self):
        """Runs that are contiguous in both lists are grouped together."""
        src: list[int] = [1, 2, 3, 5, 6]
        dst: list[int] = [10, 11, 12, 14, 15]

        src_groups, dst_groups = group_concurrent_contiguous(src, dst)

        self.assertEqual(len(src_groups), 2)
        first_src, second_src = src_groups[0], src_groups[1]
        first_dst, second_dst = dst_groups[0], dst_groups[1]
        self.assertEqual(first_src, [1, 2, 3])
        self.assertEqual(second_src, [5, 6])
        self.assertEqual(first_dst, [10, 11, 12])
        self.assertEqual(second_dst, [14, 15])

    def test_group_concurrent_contiguous_empty(self):
        """Empty inputs produce empty group lists."""
        src: list[int] = []
        dst: list[int] = []
        grouped = group_concurrent_contiguous(src, dst)
        src_groups, dst_groups = grouped
        self.assertEqual(src_groups, [])
        self.assertEqual(dst_groups, [])

    def test_string_to_int64_hash(self):
        """Equal strings hash equally; the two different strings used here differ."""
        first = string_to_int64_hash("test_string")
        repeat = string_to_int64_hash("test_string")
        self.assertEqual(first, repeat)
        other = string_to_int64_hash("different_string")
        self.assertNotEqual(first, other)
class TestMooncakeConnectorForScheduler(unittest.TestCase):
    """Checks a MooncakeConnector created with the scheduler role."""

    def test_scheduler_role(self):
        """The scheduler role builds a scheduler part and no worker part."""
        connector = MooncakeConnector(MockVllmConfig(),
                                      KVConnectorRole.SCHEDULER)
        self.assertIsNotNone(connector.connector_scheduler)
        self.assertIsNone(connector.connector_worker)

    def test_scheduler_methods(self):
        """``get_num_new_matched_tokens`` is forwarded to the scheduler part."""
        with patch.object(MooncakeConnectorScheduler,
                          "get_num_new_matched_tokens") as forwarded:
            connector = MooncakeConnector(MockVllmConfig(),
                                          KVConnectorRole.SCHEDULER)
            request = MockRequest("req1")
            connector.get_num_new_matched_tokens(request, 0)
            forwarded.assert_called_once_with(request, 0)
class MockKVCacheBlocks:
    """Block-container stand-in that always reports the same block ids."""

    def get_unhashed_block_ids(self):
        """Return a new list ``[4, 5, 6]`` on every call."""
        block_ids = [4, 5, 6]
        return block_ids
class MockSchedulerOutput:
    """Empty placeholder passed where a scheduler output object is expected."""
class MockForwardContext:
    """Empty placeholder class; it defines no attributes or methods."""
class TestMooncakeConnector(unittest.TestCase):
    """Checks that a scheduler-role MooncakeConnector forwards its calls."""

    def setUp(self):
        self.config = MockVllmConfig()
        # Override the variable through patch.dict so the previous value (or
        # its absence) is restored after each test instead of leaking into
        # the rest of the test run.
        env_patch = patch.dict(os.environ,
                               {"ASCEND_RT_VISIBLE_DEVICES": "0,1"})
        env_patch.start()
        self.addCleanup(env_patch.stop)

    def _scheduler_connector(self):
        """Build a connector with the scheduler role from the mock config."""
        return MooncakeConnector(self.config, KVConnectorRole.SCHEDULER)

    def test_scheduler_initialization(self):
        """The scheduler role builds a scheduler part and no worker part."""
        connector = self._scheduler_connector()
        self.assertIsNotNone(connector.connector_scheduler)
        self.assertIsNone(connector.connector_worker)

    @patch.object(MooncakeConnectorScheduler, "get_num_new_matched_tokens")
    def test_get_num_new_matched_tokens(self, mock_method):
        """``get_num_new_matched_tokens`` is forwarded unchanged."""
        connector = self._scheduler_connector()
        request = MockRequest("req1")
        connector.get_num_new_matched_tokens(request, 0)
        mock_method.assert_called_once_with(request, 0)

    @patch.object(MooncakeConnectorScheduler, "update_state_after_alloc")
    def test_update_state_after_alloc(self, mock_method):
        """``update_state_after_alloc`` is forwarded unchanged."""
        connector = self._scheduler_connector()
        request = MockRequest("req1")
        blocks = MockKVCacheBlocks()
        connector.update_state_after_alloc(request, blocks, 3)
        mock_method.assert_called_once_with(request, blocks, 3)

    @patch.object(MooncakeConnectorScheduler, "build_connector_meta")
    def test_build_connector_meta(self, mock_method):
        """``build_connector_meta`` is forwarded unchanged."""
        connector = self._scheduler_connector()
        scheduler_output = MockSchedulerOutput()
        connector.build_connector_meta(scheduler_output)
        mock_method.assert_called_once_with(scheduler_output)

    @patch.object(MooncakeConnectorScheduler, "request_finished")
    def test_request_finished(self, mock_method):
        """``request_finished`` is forwarded unchanged."""
        connector = self._scheduler_connector()
        request = MockRequest("req1")
        connector.request_finished(request, [1, 2, 3])
        mock_method.assert_called_once_with(request, [1, 2, 3])
class TestMooncakeConnectorScheduler(unittest.TestCase):
    """Checks scheduler-side behaviour with and without remote prefill."""

    def setUp(self):
        self.config = MockVllmConfig()
        self.scheduler = MooncakeConnectorScheduler(self.config, "test_engine")

    def test_get_num_new_matched_tokens_no_remote_prefill(self):
        """A plain request matches no tokens and is not asynchronous."""
        plain_request = MockRequest("req1")
        matched = self.scheduler.get_num_new_matched_tokens(plain_request, 0)
        tokens, async_flag = matched
        self.assertEqual(tokens, 0)
        self.assertFalse(async_flag)

    def test_get_num_new_matched_tokens_with_remote_prefill(self):
        """A remote-prefill request matches 3 tokens and loads asynchronously."""
        prefill_request = MockRequest(
            "req1", kv_transfer_params={"do_remote_prefill": True})
        matched = self.scheduler.get_num_new_matched_tokens(prefill_request, 0)
        tokens, async_flag = matched
        self.assertEqual(tokens, 3)
        self.assertTrue(async_flag)

    def test_update_state_after_alloc_no_remote_prefill(self):
        """A plain request is not queued for receiving."""
        plain_request = MockRequest("req1")
        self.scheduler.update_state_after_alloc(plain_request, MagicMock(), 0)
        self.assertEqual(len(self.scheduler._reqs_need_recv), 0)

    def test_update_state_after_alloc_with_remote_prefill(self):
        """A remote-prefill request is queued with its unhashed block ids."""
        transfer_params = {
            "do_remote_prefill": True,
            "remote_block_ids": [1, 2, 3],
            "remote_engine_id": "remote",
            "remote_host": "localhost",
            "remote_port": 5000
        }
        prefill_request = MockRequest("req1",
                                      kv_transfer_params=transfer_params)
        allocated = MockKVCacheBlocks()

        self.scheduler.update_state_after_alloc(prefill_request, allocated, 3)

        self.assertEqual(len(self.scheduler._reqs_need_recv), 1)
        self.assertEqual(self.scheduler._reqs_need_recv["req1"][0],
                         prefill_request)
        self.assertEqual(self.scheduler._reqs_need_recv["req1"][1], [4, 5, 6])

    def test_request_finished_no_remote_decode(self):
        """A plain finished request is neither delayed nor given params."""
        plain_request = MockRequest("req1")
        outcome = self.scheduler.request_finished(plain_request, [1, 2, 3])
        delay_free, params = outcome
        self.assertFalse(delay_free)
        self.assertIsNone(params)
class TestUtils(unittest.TestCase):
    """Checks the hashing, grouping and zmq helper functions."""

    def test_string_to_int64_hash(self):
        """The hash is an int, repeatable, and differs for the two inputs used."""
        hello_hash = string_to_int64_hash("hello")
        hello_again = string_to_int64_hash("hello")
        world_hash = string_to_int64_hash("world")
        self.assertEqual(hello_hash, hello_again)
        self.assertNotEqual(hello_hash, world_hash)
        self.assertIsInstance(hello_hash, int)

    def test_group_concurrent_contiguous(self):
        """A gap in the source ids splits both lists at the same position."""
        src: list[int] = [1, 2, 3, 5, 6]
        dst: list[int] = [10, 11, 12, 20, 21]
        grouped_src, grouped_dst = group_concurrent_contiguous(src, dst)
        self.assertEqual(grouped_src, [[1, 2, 3], [5, 6]])
        self.assertEqual(grouped_dst, [[10, 11, 12], [20, 21]])

    def test_group_empty(self):
        """Empty inputs produce empty group lists."""
        grouped_src, grouped_dst = group_concurrent_contiguous([], [])
        self.assertEqual(grouped_src, [])
        self.assertEqual(grouped_dst, [])

    def test_zmq_ctx_invalid_type(self):
        """An unsupported socket type raises ValueError."""
        with self.assertRaises(ValueError), \
                zmq_ctx("INVALID", "tcp://127.0.0.1:5555"):
            pass

    @patch("vllm_ascend.distributed.mooncake_connector.make_zmq_socket")
    def test_zmq_ctx_ok(self, mock_make_socket):
        """The context manager yields the socket built by ``make_zmq_socket``."""
        built_socket = MagicMock()
        mock_make_socket.return_value = built_socket
        with zmq_ctx(zmq.REQ,  # type: ignore
                     "tcp://localhost:1234") as yielded:
            self.assertEqual(yielded, built_socket)

    @patch("vllm_ascend.distributed.mooncake_connector.logger")
    def test_ensure_zmq_send_success(self, mock_logger):
        """A working socket receives the payload exactly once."""
        working_socket = MagicMock()
        ensure_zmq_send(working_socket, b"hello")
        working_socket.send.assert_called_once_with(b"hello")

    @patch("vllm_ascend.distributed.mooncake_connector.logger")
    def test_ensure_zmq_send_retry_and_fail(self, mock_logger):
        """A socket that always fails is tried ``max_retries`` times, then raises."""
        failing_socket = MagicMock()
        send_error = zmq.ZMQError("send failed")  # type: ignore
        failing_socket.send.side_effect = send_error
        with self.assertRaises(RuntimeError):
            ensure_zmq_send(failing_socket, b"hello", max_retries=2)
        self.assertEqual(failing_socket.send.call_count, 2)

    @patch("vllm_ascend.distributed.mooncake_connector.logger")
    def test_ensure_zmq_recv_success(self, mock_logger):
        """When the poller reports the socket readable, its data is returned."""
        readable_socket = MagicMock()
        readable_socket.recv.return_value = b"response"
        poller = MagicMock()
        ready_event = (readable_socket, zmq.POLLIN)  # type: ignore
        poller.poll.return_value = [ready_event]
        received = ensure_zmq_recv(readable_socket, poller)
        self.assertEqual(received, b"response")

    @patch("vllm_ascend.distributed.mooncake_connector.logger")
    def test_ensure_zmq_recv_timeout_and_fail(self, mock_logger):
        """A poller that never reports readiness leads to RuntimeError."""
        silent_socket = MagicMock()
        poller = MagicMock()
        poller.poll.return_value = []
        recv_options = dict(timeout=0.01, max_retries=2)
        with self.assertRaises(RuntimeError):
            ensure_zmq_recv(silent_socket, poller, **recv_options)
class MockMooncakeAgentMetadata:
    """Agent-metadata stand-in that accepts and discards any keyword arguments."""

    def __init__(self, **kwargs):
        # Keyword arguments are intentionally ignored; nothing is stored.
        return None
class MockMooncakeConnectorMetadata:
    """Connector-metadata stand-in that starts with no requests."""

    def __init__(self):
        # Each instance gets its own fresh, empty mapping.
        self.requests = dict()
class MockKVCacheSendingThread(threading.Thread):
    """Sending-thread stand-in that never actually runs.

    Constructor arguments are accepted and ignored; ``start`` is a no-op so
    no OS thread is spawned.
    """

    def __init__(self, *args, **kwargs):
        super().__init__()
        self.daemon = True
        self._finished_requests = set()

    def get_and_clear_finished_requests(self):
        """Return the finished request ids collected so far and reset them.

        Clearing after the read matches what this file's tests assert for
        ``KVCacheTaskTracker.get_and_clear_finished_requests``, so a second
        call does not report the same ids again.
        """
        finished = self._finished_requests
        self._finished_requests = set()
        return finished

    def start(self):
        """Do nothing instead of launching a thread."""
        pass
class MockKVCacheRecvingThread(threading.Thread):
    """Receiving-thread stand-in that never actually runs.

    Constructor arguments are accepted and ignored, ``add_request`` is a
    MagicMock so calls can be inspected, and ``start`` is a no-op so no OS
    thread is spawned.
    """

    def __init__(self, *args, **kwargs):
        super().__init__()
        self.daemon = True
        self._finished_requests = set()
        self.add_request = MagicMock()

    def get_and_clear_finished_requests(self):
        """Return the finished request ids collected so far and reset them.

        Clearing after the read matches what this file's tests assert for
        ``KVCacheTaskTracker.get_and_clear_finished_requests``, so a second
        call does not report the same ids again.
        """
        finished = self._finished_requests
        self._finished_requests = set()
        return finished

    def start(self):
        """Do nothing instead of launching a thread."""
        pass
class MockTensor:
    """Tensor stand-in exposing fixed size, element size, shape and data pointer."""

    def __init__(self, *args, **kwargs):
        fixed_shape = (10, 16, 8, 16)
        self.size = MagicMock(return_value=fixed_shape)
        self.element_size = MagicMock(return_value=4)
        self.shape = fixed_shape
        self.data_ptr = MagicMock(return_value=0x1000)
# Module-level mocks: an env stand-in with a fixed protocol value, and a logger.
mock_envs_ascend = MagicMock(MOONCAKE_CONNECTOR_PROTOCOL="mock_protocol")
mock_logger = MagicMock()
class MockTransferEngine:
    """Transfer-engine stand-in with fixed return codes."""

    def initialize(self, *args, **kwargs):
        """Accept any arguments and return 0."""
        status = 0
        return status

    def register_memory(self, *args, **kwargs):
        """Accept any arguments and return 1."""
        status = 1
        return status
class MockEnvsAscend:
    """Env-module stand-in exposing two fixed class attributes."""

    # Comma-separated device list; the worker test below expects device id 10.
    PHYSICAL_DEVICES = "10,11"
    MOONCAKE_CONNECTOR_PROTOCOL = "mock_protocol"
def mock_get_tensor_model_parallel_rank():
    """Return 0 as the tensor-parallel rank used by the patched tests."""
    rank = 0
    return rank
def mock_get_tp_group():
    """Return a fresh MagicMock standing in for the TP group."""
    tp_group = MagicMock()
    return tp_group
def mock_get_ip():
    """Return the loopback address as the local IP."""
    loopback = "127.0.0.1"
    return loopback
def mock_string_to_int64_hash(s):
    """Return the built-in ``hash`` of ``s`` in place of the real hash helper."""
    digest = hash(s)
    return digest
class TestMooncakeConnectorWorker(unittest.TestCase):
    """Checks MooncakeConnectorWorker with its collaborators patched out.

    ``setUp`` replaces the transfer engine, both KV cache threads, the
    parallel-state helpers, the logger and several ``torch``/stdlib functions
    with mocks, so the worker can be constructed without real devices.
    """

    def setUp(self):
        self.envs_ascend_mock = MockEnvsAscend()
        self.mock_transfer_engine = MagicMock()
        self.mock_transfer_engine.get_rpc_port.return_value = 9090
        self.mock_transfer_engine.initialize.return_value = 0
        self.mock_transfer_engine.register_memory.return_value = 0
        self.patches = [
            patch('os.getenv', return_value="10,11"),
            patch('torch.Tensor.size', return_value=(10, 16, 8, 16)),
            patch('torch.Tensor.element_size', return_value=4),
            patch('torch.Tensor.data_ptr', return_value=0x1000),
            patch('math.prod', return_value=128),
            patch('random.Random'),
            patch(
                'vllm_ascend.distributed.mooncake_connector.get_tensor_model_parallel_rank',
                mock_get_tensor_model_parallel_rank),
            patch('vllm_ascend.distributed.mooncake_connector.get_tp_group',
                  mock_get_tp_group),
            patch('vllm_ascend.distributed.mooncake_connector.get_ip',
                  mock_get_ip),
            patch(
                'vllm_ascend.distributed.mooncake_connector.string_to_int64_hash',
                mock_string_to_int64_hash),
            patch('vllm_ascend.distributed.mooncake_connector.TransferEngine',
                  return_value=self.mock_transfer_engine),
            patch(
                'vllm_ascend.distributed.mooncake_connector.KVCacheSendingThread',
                MagicMock()),
            patch(
                'vllm_ascend.distributed.mooncake_connector.KVCacheRecvingThread',
                MagicMock()),
            patch('vllm_ascend.distributed.mooncake_connector.logger',
                  MagicMock()),
            patch('vllm_ascend.distributed.mooncake_connector.threading.Event',
                  MagicMock()),
            patch.dict('sys.modules',
                       {'vllm_ascend.envs': self.envs_ascend_mock}),
        ]
        for p in self.patches:
            p.start()  # type: ignore
            # Register the undo right away: tearDown is not run when setUp
            # raises, so a later patch failing to start must not leave the
            # earlier ones (e.g. the global os.getenv patch) active.
            self.addCleanup(p.stop)  # type: ignore
        self.vllm_config = MockVllmConfig()
        self.engine_id = "test_engine"
        self.kv_caches = {"layer1": (MagicMock(), MagicMock())}

    def test_register_kv_caches_producer(self):
        """With the producer role only the sending thread is created."""
        worker = MooncakeConnectorWorker(self.vllm_config, self.engine_id)
        worker.register_kv_caches(self.kv_caches)
        self.assertEqual(len(worker.kv_caches), 1)
        self.assertIsNotNone(worker.kv_send_thread)
        self.assertIsNone(worker.kv_recv_thread)

    def test_register_kv_caches_consumer(self):
        """With the consumer role only the receiving thread is created."""
        self.vllm_config.kv_transfer_config.kv_role = 'kv_consumer'
        worker = MooncakeConnectorWorker(self.vllm_config, self.engine_id)
        worker.register_kv_caches(self.kv_caches)
        self.assertIsNone(worker.kv_send_thread)
        self.assertIsNotNone(worker.kv_recv_thread)

    def test_register_kv_caches_mla_case(self):
        """Two differently sized caches per layer give MLA mode and two block lengths."""
        mla_cache1 = MagicMock()
        mla_cache1.size.return_value = (10, 16, 1, 16)
        mla_cache2 = MagicMock()
        mla_cache2.size.return_value = (10, 16, 1, 8)
        mla_caches = {"layer1": (mla_cache1, mla_cache2)}
        worker = MooncakeConnectorWorker(self.vllm_config, self.engine_id)
        worker.register_kv_caches(mla_caches)
        self.assertTrue(worker.use_mla)
        self.assertEqual(len(worker.block_len), 2)

    def test_device_id_selection_with_physical_devices(self):
        """With physical devices "10,11" and TP rank 0 the device id is 10."""
        # Test with physical devices set
        worker = MooncakeConnectorWorker(self.vllm_config, self.engine_id)
        # Default tp_rank is 0, so device_id should be 10
        self.assertEqual(worker.device_id, 10)
# Allow running this test module directly as a script.
if __name__ == '__main__':
    unittest.main()

View File

@@ -0,0 +1,169 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# Copyright 2023 The vLLM team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-ascend project.
# Adapted from vllm-project/vllm/blob/main/tests/conftest.py
#
import copy
from vllm.v1.outputs import EMPTY_MODEL_RUNNER_OUTPUT
from vllm.v1.request import FinishReason, RequestStatus
from tests.ut.kv_connector.utils import (assert_scheduler_empty,
create_model_runner_output,
create_request, create_scheduler,
create_vllm_config)
def test_basic_lifecycle():
    """Test lifecycle of a Remote Decode request.

    The request is prefilled in one step and finishes on its length cap
    (``max_tokens=1``). Its blocks must stay referenced until a later model
    runner output reports the request in ``finished_sending``; after that
    the scheduler must be completely empty.
    """
    vllm_config = create_vllm_config()
    scheduler = create_scheduler(vllm_config)
    # 2 Full Blocks and 1 Half Block.
    BLOCK_SIZE = vllm_config.cache_config.block_size
    NUM_EXTERNAL_FULL_BLOCKS = 2
    NUM_TOKENS = int(BLOCK_SIZE * (NUM_EXTERNAL_FULL_BLOCKS + 0.5))
    request = create_request(request_id=1,
                             max_tokens=1,
                             num_tokens=NUM_TOKENS,
                             do_remote_decode=True)
    scheduler.add_request(request)
    request_id = request.request_id
    # STEP (1): Prefill.
    # (1a): schedule()
    scheduler_output = scheduler.schedule()
    assert len(scheduler.running) == 1
    assert len(scheduler_output.scheduled_new_reqs) == 1
    # (1b): execute_model()
    model_runner_output = create_model_runner_output(reqs=[request])
    # (1c): update_from_output()
    engine_core_outputs = scheduler.update_from_output(scheduler_output,
                                                       model_runner_output)
    # Ensure the request is finished after 1 token.
    assert request.is_finished()
    assert request.status == RequestStatus.FINISHED_LENGTH_CAPPED
    output = engine_core_outputs[0].outputs[0]
    assert output.finish_reason == FinishReason.LENGTH
    assert output.kv_transfer_params is not None
    # The request is recorded as finished and has left the running and
    # waiting queues ...
    assert request_id in scheduler.finished_req_ids
    assert len(scheduler.running) == 0
    assert len(scheduler.waiting) == 0
    # ... but blocks should not be freed.
    blocks = scheduler.kv_cache_manager.coordinator.single_type_managers[
        0].req_to_blocks[request_id]
    for block in blocks:
        assert block.ref_cnt == 1
    # STEP (2): Report the finished request.
    # (2a): schedule() - the finished request id is handed out in the
    # scheduler output and cleared from the scheduler.
    scheduler_output = scheduler.schedule()
    assert len(scheduler.running) == 0
    assert len(scheduler_output.finished_req_ids) == 1
    assert request_id in scheduler_output.finished_req_ids
    assert len(scheduler_output.scheduled_new_reqs) == 0
    assert scheduler_output.scheduled_cached_reqs.num_reqs == 0
    assert len(scheduler.finished_req_ids) == 0
    # (2b): execute_model()
    model_runner_output = EMPTY_MODEL_RUNNER_OUTPUT
    # (2c): update_from_output()
    scheduler.update_from_output(scheduler_output, model_runner_output)
    # STEP (3): Finished sending.
    # (3a): schedule() - nothing is scheduled or reported any more.
    scheduler_output = scheduler.schedule()
    assert len(scheduler.running) == 0
    assert len(scheduler_output.finished_req_ids) == 0
    assert len(scheduler_output.scheduled_new_reqs) == 0
    assert scheduler_output.scheduled_cached_reqs.num_reqs == 0
    assert len(scheduler.finished_req_ids) == 0
    # (3b): execute_model() - the connector reports the send as finished.
    # Work on a deep copy so the shared EMPTY_MODEL_RUNNER_OUTPUT object is
    # not mutated.
    model_runner_output = copy.deepcopy(EMPTY_MODEL_RUNNER_OUTPUT)
    from vllm.v1.worker.kv_connector_model_runner_mixin import \
        KVConnectorOutput  # type: ignore # noqa
    model_runner_output.kv_connector_output = KVConnectorOutput(
        finished_sending=[request_id])
    # (3c): update_from_output()
    scheduler.update_from_output(scheduler_output, model_runner_output)
    # Confirm we do not have any memory leaks after req lifecycle.
    assert_scheduler_empty(scheduler)
def test_prefix_cache_lifecycle():
    """Test that remote decode params still work with a prefix cache hit.

    A plain request first primes the KV cache. A shorter remote-decode
    request created with the same ``request_id`` (the prompt tokens produced
    by ``create_request`` depend only on that id) is then run, and its
    ``kv_transfer_params`` must list every block of the request.
    """
    vllm_config = create_vllm_config()
    scheduler = create_scheduler(vllm_config)

    # Prime the KVCache with an ordinary (non-remote) request that hits EOS.
    BLOCK_SIZE = vllm_config.cache_config.block_size
    NUM_EXTERNAL_FULL_BLOCKS = 3
    NUM_TOKENS = int(BLOCK_SIZE * (NUM_EXTERNAL_FULL_BLOCKS + 0.5))
    request_normal = create_request(request_id=1, num_tokens=NUM_TOKENS)
    scheduler.add_request(request_normal)
    scheduler_output = scheduler.schedule()
    model_runner_output = create_model_runner_output(reqs=[request_normal],
                                                     use_eos=True)
    scheduler.update_from_output(scheduler_output, model_runner_output)
    # Pair the follow-up update with the output of *this* schedule() call
    # rather than replaying the previous step's scheduler output.
    scheduler_output = scheduler.schedule()
    scheduler.update_from_output(scheduler_output, EMPTY_MODEL_RUNNER_OUTPUT)

    #####################
    # Actual Test: confirm we send all blocks.
    # STEP (1): Send the KV Transfer.
    NUM_EXTERNAL_FULL_BLOCKS -= 1
    NUM_TOKENS = int(BLOCK_SIZE * (NUM_EXTERNAL_FULL_BLOCKS + 0.5))
    request_remote = create_request(request_id=1,
                                    num_tokens=NUM_TOKENS,
                                    do_remote_decode=True)
    scheduler.add_request(request_remote)
    scheduler_output = scheduler.schedule()
    model_runner_output = create_model_runner_output(reqs=[request_remote])
    eco = scheduler.update_from_output(scheduler_output, model_runner_output)
    kv_transfer_params = eco[0].outputs[0].kv_transfer_params
    # Ensure we send all block ids, even if there is a cache hit.
    assert (len(
        kv_transfer_params["remote_block_ids"]) == (NUM_EXTERNAL_FULL_BLOCKS +
                                                    1))

    # STEP (2): Ensure it is freed.
    scheduler_output = scheduler.schedule()
    # NOTE(review): this extra schedule() discards its result and the update
    # below uses the output captured just above; kept as in the original --
    # confirm whether the second call is intentional.
    scheduler.schedule()
    # Deep copy so the shared EMPTY_MODEL_RUNNER_OUTPUT is not mutated.
    model_runner_output = copy.deepcopy(EMPTY_MODEL_RUNNER_OUTPUT)
    from vllm.v1.worker.kv_connector_model_runner_mixin import \
        KVConnectorOutput  # noqa
    model_runner_output.kv_connector_output = KVConnectorOutput(
        finished_sending=[request_remote.request_id])
    scheduler.update_from_output(scheduler_output, model_runner_output)
    _ = scheduler.schedule()
    assert_scheduler_empty(scheduler)

View File

@@ -0,0 +1,239 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# Copyright 2023 The vLLM team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-ascend project.
# Adapted from vllm-project/vllm/blob/main/tests/conftest.py
#
import copy
from vllm.v1.outputs import EMPTY_MODEL_RUNNER_OUTPUT
from vllm.v1.request import RequestStatus
from tests.ut.kv_connector.utils import (assert_scheduler_empty,
create_model_runner_output,
create_request, create_scheduler,
create_vllm_config)
def test_basic_lifecycle():
    """Test lifecycle of a remote prefill.

    The request first waits in ``WAITING_FOR_REMOTE_KVS`` with blocks
    allocated but unhashed. Once a model runner output reports it in
    ``finished_recving`` it is scheduled normally, and after an EOS sample
    the scheduler must be completely empty.
    """
    vllm_config = create_vllm_config()
    scheduler = create_scheduler(vllm_config)
    # 2 Full Blocks and 1 Half Block.
    BLOCK_SIZE = vllm_config.cache_config.block_size
    NUM_EXTERNAL_FULL_BLOCKS = 2
    NUM_TOKENS = int(BLOCK_SIZE * (NUM_EXTERNAL_FULL_BLOCKS + 0.5))
    # Free-block count before anything is scheduled, used as the baseline
    # for the allocation check in step (1).
    START_FREE_BLOCK_QUEUE_SIZE = (
        scheduler.kv_cache_manager.block_pool.free_block_queue.num_free_blocks)
    request = create_request(request_id=1,
                             num_tokens=NUM_TOKENS,
                             do_remote_prefill=True,
                             block_size=BLOCK_SIZE)
    scheduler.add_request(request)
    request_id = request.request_id
    # STEP (1):
    # (1a): schedule()
    scheduler_output = scheduler.schedule()
    # Nothing running and empty scheduler output.
    assert len(scheduler.running) == 0
    assert len(scheduler_output.scheduled_new_reqs) == 0
    assert scheduler_output.scheduled_cached_reqs.num_reqs == 0
    assert len(scheduler_output.num_scheduled_tokens) == 0
    assert scheduler_output.total_num_scheduled_tokens == 0
    # Req waiting for KVs with no computed/scheduled toks ...
    assert len(scheduler.waiting) == 1
    assert request in scheduler.waiting
    assert (request.status == RequestStatus.WAITING_FOR_REMOTE_KVS)
    assert (request.num_computed_tokens == 0)
    # ... but should have (uncached) blocks allocated to it.
    block_pool = scheduler.kv_cache_manager.block_pool
    assert (block_pool.free_block_queue.num_free_blocks
            < START_FREE_BLOCK_QUEUE_SIZE)
    assert len(block_pool.cached_block_hash_to_block) == 0
    blocks = scheduler.kv_cache_manager.coordinator.single_type_managers[
        0].req_to_blocks[request_id]
    for block in blocks:
        assert block._block_hash is None
    # (1b): forward()
    model_runner_output = EMPTY_MODEL_RUNNER_OUTPUT
    # (1c): update_from_output()
    engine_core_outputs = scheduler.update_from_output(scheduler_output,
                                                       model_runner_output)
    assert not engine_core_outputs or not engine_core_outputs[0].outputs
    # STEP (2):
    # (2a): schedule(): nothing happens!
    scheduler_output = scheduler.schedule()
    assert len(scheduler.waiting) == 1
    assert len(scheduler.running) == 0
    # (2b): forward(): request finishes recv.
    # Deep copy so the shared EMPTY_MODEL_RUNNER_OUTPUT is not mutated.
    model_runner_output = copy.deepcopy(EMPTY_MODEL_RUNNER_OUTPUT)
    from vllm.v1.worker.kv_connector_model_runner_mixin import \
        KVConnectorOutput  # type: ignore # noqa
    model_runner_output.kv_connector_output = KVConnectorOutput(
        finished_recving=[request_id])
    # (2c): update_from_output(): the request is still waiting, but is now
    # recorded as having received its KVs.
    engine_core_outputs = scheduler.update_from_output(scheduler_output,
                                                       model_runner_output)
    assert len(scheduler.waiting) == 1
    assert (request_id in scheduler.finished_recving_kv_req_ids)
    # STEP (3):
    # (3a): schedule(): this should actually schedule.
    scheduler_output = scheduler.schedule()
    assert len(scheduler.running) == 1
    # Confirm the blocks are actually allocated, and that exactly the full
    # blocks carry a hash.
    num_hashed_blocks = 0
    blocks = scheduler.kv_cache_manager.coordinator.single_type_managers[
        0].req_to_blocks[request_id]
    for block in blocks:
        assert block.ref_cnt == 1
        num_hashed_blocks += (1 if block._block_hash is not None else 0)
    assert num_hashed_blocks == NUM_EXTERNAL_FULL_BLOCKS
    # Confirm the rest of the prompt is scheduled in this step.
    scheduled_req = scheduler_output.scheduled_new_reqs[0]
    num_scheduled_tokens = scheduler_output.num_scheduled_tokens[request_id]
    num_computed_tokens = scheduled_req.num_computed_tokens
    total_prompt_tokens = len(scheduled_req.prompt_token_ids)
    assert (num_scheduled_tokens == total_prompt_tokens - num_computed_tokens)
    # (3b): execute_model()
    model_runner_output = create_model_runner_output([request])
    # (3c): update_from_output()
    scheduler.update_from_output(scheduler_output, model_runner_output)
    # STEP (4): Hit EOS.
    scheduler_output = scheduler.schedule()
    model_runner_output = create_model_runner_output([request], use_eos=True)
    engine_core_outputs = scheduler.update_from_output(scheduler_output,
                                                       model_runner_output)
    scheduler.schedule()
    # Confirm nothing is left behind after the request lifecycle.
    assert_scheduler_empty(scheduler)
def test_no_spurious_prefix_caching():
    """
    Blocks of a remote-prefill request stay allocated but uncomputed
    across engine steps. Check that, after one such step, none of
    those blocks carries a block hash, so nothing can get a prefix
    cache hit against them.
    """
    config = create_vllm_config()
    sched = create_scheduler(config)

    # Prompt length: two full blocks plus half a block.
    tokens_per_block = config.cache_config.block_size
    full_blocks = 2
    prompt_len = int(tokens_per_block * (full_blocks + 0.5))

    # The prompt is all ones: [1, 1, 1, 1, 1, ...]
    request_remote = create_request(request_id=1,
                                    num_tokens=prompt_len,
                                    do_remote_prefill=True,
                                    use_all_1s_for_prompt_tokens=True)

    # One engine step with an empty model runner output. The request must
    # still be waiting afterwards.
    sched.add_request(request_remote)
    step_output = sched.schedule()
    sched.update_from_output(step_output, EMPTY_MODEL_RUNNER_OUTPUT)
    assert len(sched.waiting) == 1

    # Each allocated block is referenced once and has no hash.
    first_manager = sched.kv_cache_manager.coordinator.single_type_managers[0]
    allocated = first_manager.req_to_blocks[request_remote.request_id]
    for blk in allocated:
        assert blk.ref_cnt == 1
        assert blk._block_hash is None
def test_full_block_prompt():
    """Test that we handle a prompt that is the full block size.

    The prompt is an exact multiple of the block size. After the KVs are
    reported as received, only the final prompt token is scheduled and no
    additional block is allocated. The request then hits EOS and the
    scheduler must end up empty.
    """
    vllm_config = create_vllm_config()
    scheduler = create_scheduler(vllm_config)
    # Exactly 2 full blocks, with no partial trailing block.
    BLOCK_SIZE = vllm_config.cache_config.block_size
    NUM_EXTERNAL_FULL_BLOCKS = 2
    NUM_TOKENS = int(BLOCK_SIZE * NUM_EXTERNAL_FULL_BLOCKS)
    request = create_request(request_id=1,
                             num_tokens=NUM_TOKENS,
                             do_remote_prefill=True)
    scheduler.add_request(request)
    request_id = request.request_id

    # STEP (1): Initialize a recv.
    scheduler_output = scheduler.schedule()
    # All blocks should be allocated.
    num_blocks = len(scheduler.kv_cache_manager.coordinator.
                     single_type_managers[0].req_to_blocks[request_id])
    assert num_blocks == NUM_EXTERNAL_FULL_BLOCKS
    model_runner_output = EMPTY_MODEL_RUNNER_OUTPUT
    scheduler.update_from_output(scheduler_output, model_runner_output)

    # STEP (2): Recv.
    scheduler_output = scheduler.schedule()
    # Deep copy so the shared EMPTY_MODEL_RUNNER_OUTPUT is not mutated.
    model_runner_output = copy.deepcopy(EMPTY_MODEL_RUNNER_OUTPUT)
    from vllm.v1.worker.kv_connector_model_runner_mixin import \
        KVConnectorOutput  # type: ignore # noqa
    model_runner_output.kv_connector_output = KVConnectorOutput(
        finished_recving=[request_id])
    scheduler.update_from_output(scheduler_output, model_runner_output)
    assert len(scheduler.waiting) == 1
    assert (request_id in scheduler.finished_recving_kv_req_ids)

    # STEP (3): Run as usual.
    scheduler_output = scheduler.schedule()
    # We need to recompute the final token of the prompt to generate
    # the first new token, so we should not have a new block.
    num_blocks = len(scheduler.kv_cache_manager.coordinator.
                     single_type_managers[0].req_to_blocks[request_id])
    assert num_blocks == NUM_EXTERNAL_FULL_BLOCKS
    assert (scheduler_output.scheduled_new_reqs[0].num_computed_tokens ==
            NUM_TOKENS - 1)
    assert (scheduler_output.num_scheduled_tokens[request_id] == 1)
    model_runner_output = create_model_runner_output([request])
    scheduler.update_from_output(scheduler_output, model_runner_output)

    # STEP (4): Hit EOS.
    scheduler_output = scheduler.schedule()
    model_runner_output = create_model_runner_output([request], use_eos=True)
    # Feed the EOS sample back to the scheduler. Without this call the
    # output built above is never consumed and the request is never
    # finished before the emptiness check below.
    scheduler.update_from_output(scheduler_output, model_runner_output)
    scheduler.schedule()
    assert_scheduler_empty(scheduler)

View File

@@ -0,0 +1,233 @@
# SPDX-License-Identifier: Apache-2.0
# This code is from: https://github.com/vllm-project/vllm/tests/v1/kv_connector/unit/utils.py
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
import os
from typing import Any, Optional
import torch
from vllm import SamplingParams
from vllm.config import (CacheConfig, DeviceConfig, KVTransferConfig,
ModelConfig, SchedulerConfig, VllmConfig)
from vllm.v1.core.kv_cache_utils import (get_request_block_hasher,
init_none_hash)
from vllm.v1.core.sched.scheduler import Scheduler
from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
KVCacheGroupSpec)
from vllm.v1.outputs import ModelRunnerOutput
from vllm.v1.request import Request
from vllm.v1.structured_output import StructuredOutputManager
from vllm_ascend.utils import vllm_version_is
# Token id passed to Request as eos_token_id by create_request() and sampled
# by create_model_runner_output() when use_eos=True.
EOS_TOKEN_ID = 50256
# Set as a side effect of importing this module.
# NOTE(review): presumably selects the vLLM V1 code path -- confirm.
os.environ["VLLM_USE_V1"] = "1"
def assert_scheduler_empty(scheduler: Scheduler):
    """Assert that *scheduler* holds no leftover state, i.e. nothing leaked.

    Raises:
        AssertionError: if any request, cache entry or block reference is
            still tracked.
    """
    # Request bookkeeping kept on the scheduler itself.
    for attr in ("requests", "waiting", "running", "finished_req_ids",
                 "finished_recving_kv_req_ids"):
        assert len(getattr(scheduler, attr)) == 0

    # Encoder cache manager.
    encoder_cache = scheduler.encoder_cache_manager
    assert len(encoder_cache.freed) == 0
    assert len(encoder_cache.cached) == 0

    # KV cache manager: only the first single-type manager is inspected.
    kv_manager = scheduler.kv_cache_manager
    first_manager = kv_manager.coordinator.single_type_managers[0]
    assert len(first_manager.req_to_blocks) == 0
    assert len(first_manager.num_cached_block) == 0

    # Every block except one must be back in the free queue.
    pool = kv_manager.block_pool
    assert pool.free_block_queue.num_free_blocks == (pool.num_gpu_blocks - 1)

    # NOTE(rob): only the ref counts are expected to be 0. The hash
    # value, etc. will remain since we lazily evict for prefix cache.
    for blk in pool.blocks:
        assert blk.ref_cnt == 0
def create_vllm_config(
    max_num_seqs: int = 16,
    max_num_batched_tokens: int = 1024,
    block_size: int = 128,
) -> VllmConfig:
    """Build the ``VllmConfig`` used by the KV-connector unit tests.

    Args:
        max_num_seqs: Forwarded to ``SchedulerConfig``.
        max_num_batched_tokens: Forwarded to ``SchedulerConfig``, where it
            is also used as ``max_model_len``.
        block_size: Forwarded to ``CacheConfig``.

    Returns:
        A config on the ``"cpu"`` device with prefix caching enabled and the
        ``LLMDataDistCMgrConnector`` KV connector in the ``kv_both`` role.
    """
    # The model path points at the "fake_weight" directory one level above
    # this file.
    weights_dir = os.path.join(os.path.dirname(__file__), "..", "fake_weight")
    connector_module = "vllm_ascend.distributed.llmdatadist_c_mgr_connector"
    return VllmConfig(
        scheduler_config=SchedulerConfig(
            max_num_seqs=max_num_seqs,
            max_num_batched_tokens=max_num_batched_tokens,
            max_model_len=max_num_batched_tokens,
        ),
        model_config=ModelConfig(
            model=weights_dir,
            skip_tokenizer_init=True,
        ),
        # Prefix caching is always switched on for these tests.
        cache_config=CacheConfig(
            block_size=block_size,
            gpu_memory_utilization=0.9,
            swap_space=0,
            cache_dtype="auto",
            enable_prefix_caching=True,
        ),
        kv_transfer_config=KVTransferConfig(
            kv_connector="LLMDataDistCMgrConnector",
            kv_role="kv_both",
            kv_connector_module_path=connector_module,
        ),
        device_config=DeviceConfig("cpu"),
    )
def create_scheduler(
    vllm_config: VllmConfig,
    num_blocks: int = 10000,
) -> Scheduler:
    """Build a ``Scheduler`` with a single full-attention KV cache group.

    Args:
        vllm_config: Config to schedule with. Its
            ``cache_config.num_gpu_blocks`` is overwritten with
            ``num_blocks`` as a side effect.
        num_blocks: Number of KV cache blocks; the default is large enough
            to hold all requests.

    Returns:
        A ``Scheduler`` with ``log_stats=True`` and a fresh
        ``StructuredOutputManager``.
    """
    attention_spec = FullAttentionSpec(vllm_config.cache_config.block_size, 1,
                                       1, torch.float16, False)
    cache_layout = KVCacheConfig(
        num_blocks=num_blocks,
        kv_cache_tensors=[],
        kv_cache_groups=[KVCacheGroupSpec(['layer'], attention_spec)],
    )
    vllm_config.cache_config.num_gpu_blocks = num_blocks
    output_manager = StructuredOutputManager(vllm_config)
    return Scheduler(
        vllm_config=vllm_config,
        kv_cache_config=cache_layout,
        log_stats=True,
        structured_output_manager=output_manager,
    )
# Flipped to True by create_request() after its one-time init_none_hash() call.
_none_hash_initialized = False


def create_request(
    request_id: int,
    num_tokens: int = 10,
    max_tokens: int = 128,
    do_remote_decode: bool = False,
    do_remote_prefill: bool = False,
    use_all_1s_for_prompt_tokens: bool = False,
    num_remote_blocks: int = 3,
    block_size: int = 16,
) -> Request:
    """Build a dummy ``Request`` for the scheduler tests.

    Args:
        request_id: Numeric id; the request is named ``"id-<request_id>"``
            and, unless all-ones prompts are requested, prompt token ``i``
            is ``i * request_id``.
        num_tokens: Prompt length.
        max_tokens: Sampling limit; forced to 1 when ``do_remote_decode``.
        do_remote_decode: Attach remote-decode ``kv_transfer_params``.
            Mutually exclusive with ``do_remote_prefill``.
        do_remote_prefill: Attach remote-prefill ``kv_transfer_params``
            with placeholder engine id, host, port and block ids.
        use_all_1s_for_prompt_tokens: Use ``[1] * num_tokens`` as prompt.
        num_remote_blocks: Number of ids in ``remote_block_ids``.
        block_size: Block size given to the request block hasher. Its
            default (16) differs from the default ``block_size`` (128) of
            ``create_vllm_config``.

    Returns:
        The request, with ``kv_transfer_params`` set (``None`` when neither
        remote flag is given).
    """
    global _none_hash_initialized
    if not _none_hash_initialized:
        # One-time initialisation, done with the built-in ``hash``.
        init_none_hash(hash)
        _none_hash_initialized = True
    block_hasher = get_request_block_hasher(block_size, hash)

    kv_transfer_params: Optional[dict[str, Any]] = None
    if do_remote_decode:
        assert not do_remote_prefill
        kv_transfer_params = {
            "do_remote_prefill": False,
            "do_remote_decode": True,
        }
    elif do_remote_prefill:
        kv_transfer_params = {
            "do_remote_prefill": True,
            "do_remote_decode": False,
            "remote_engine_id": "my-engine-id",
            "remote_block_ids": list(range(num_remote_blocks)),
            "remote_host": "my-host",
            "remote_port": 1234,
            "remote_tp_size": 1,
        }

    if do_remote_decode:
        max_tokens = 1
    sampling_params = SamplingParams(max_tokens=max_tokens)

    prompt_token_ids = ([1] * num_tokens if use_all_1s_for_prompt_tokens else
                        [i * request_id for i in range(num_tokens)])

    request_kwargs: dict[str, Any] = {
        "request_id": f"id-{request_id}",
        "prompt_token_ids": prompt_token_ids,
        "sampling_params": sampling_params,
        "pooling_params": [],
        "eos_token_id": EOS_TOKEN_ID,
        "block_hasher": block_hasher,
    }
    # Only for these vLLM versions are the multi-modal arguments passed.
    if any(vllm_version_is(version) for version in ("0.10.1.1", "0.10.1")):
        request_kwargs.update(
            multi_modal_kwargs=None,
            multi_modal_placeholders=None,
            multi_modal_hashes=None,
        )
    req = Request(**request_kwargs)
    req.kv_transfer_params = kv_transfer_params
    return req
def create_model_runner_output(
    reqs: list[Request],
    finished_sending: Optional[list[str]] = None,
    finished_recving: Optional[list[str]] = None,
    use_eos: bool = False,
) -> ModelRunnerOutput:
    """Build a dummy ``ModelRunnerOutput`` with one sampled token per request.

    Args:
        reqs: Requests to report; their order defines ``req_id_to_index``.
        finished_sending: Forwarded to ``KVConnectorOutput``.
        finished_recving: Forwarded to ``KVConnectorOutput``.
        use_eos: Sample ``EOS_TOKEN_ID`` for every request instead of ``0``.

    Returns:
        The output object, always carrying a ``kv_connector_output``.
    """
    # Per-request data.
    req_ids = [req.request_id for req in reqs]
    req_id_to_index = dict(zip(req_ids, range(len(req_ids))))

    # One single-token sample per request.
    token = EOS_TOKEN_ID if use_eos else 0
    sampled_token_ids = [[token] for _ in req_ids]

    from vllm.v1.worker.kv_connector_model_runner_mixin import \
        KVConnectorOutput  # type: ignore # noqa
    connector_output = KVConnectorOutput(finished_sending=finished_sending,
                                         finished_recving=finished_recving)

    output_kwargs = {
        "req_ids": req_ids,
        "req_id_to_index": req_id_to_index,
        "sampled_token_ids": sampled_token_ids,
        "logprobs": None,
        "prompt_logprobs_dict": {},
        "pooler_output": [],
        "kv_connector_output": connector_output,
    }
    # Only for these vLLM versions is ``spec_token_ids`` passed.
    if any(vllm_version_is(version) for version in ("0.10.1.1", "0.10.1")):
        output_kwargs["spec_token_ids"] = None
    model_runner_output = ModelRunnerOutput(**output_kwargs)
    return model_runner_output