v0.10.1rc1
This commit is contained in:
96
tests/ut/kv_connector/test_llmdatadist_connector.py
Normal file
96
tests/ut/kv_connector/test_llmdatadist_connector.py
Normal file
@@ -0,0 +1,96 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||||
|
||||
import os
|
||||
import types
|
||||
|
||||
from tests.ut.kv_connector.utils import (create_request, create_scheduler,
|
||||
create_vllm_config)
|
||||
from vllm_ascend.distributed.llmdatadist_c_mgr_connector import (
|
||||
LLMDataDistCMgrConnectorMetadata, LLMDataDistCMgrConnectorWorker, LLMRole)
|
||||
|
||||
|
||||
def test_basic_inferface():
|
||||
"""Unit test for basic LLMDataDistCMgrConnector interface functionality."""
|
||||
|
||||
vllm_config = create_vllm_config()
|
||||
scheduler = create_scheduler(vllm_config)
|
||||
|
||||
# 2 Full Blocks and 1 Half Block.
|
||||
BLOCK_SIZE = vllm_config.cache_config.block_size
|
||||
NUM_EXTERNAL_FULL_BLOCKS = 2
|
||||
NUM_TOKENS = int(BLOCK_SIZE * (NUM_EXTERNAL_FULL_BLOCKS + 0.5))
|
||||
|
||||
request = create_request(request_id=1,
|
||||
num_tokens=NUM_TOKENS,
|
||||
do_remote_prefill=True)
|
||||
request_id = request.request_id
|
||||
|
||||
scheduler.add_request(request)
|
||||
|
||||
# Remote Prefill, triggers LLMDataDistCMgrConnectorMetadata.
|
||||
scheduler_output = scheduler.schedule()
|
||||
kv_connector_metadata = scheduler_output.kv_connector_metadata
|
||||
assert kv_connector_metadata is not None
|
||||
assert isinstance(kv_connector_metadata, LLMDataDistCMgrConnectorMetadata)
|
||||
|
||||
assert len(kv_connector_metadata.requests) == 1
|
||||
assert request_id in kv_connector_metadata.requests
|
||||
req_meta = kv_connector_metadata.requests[request_id]
|
||||
|
||||
for block_id, block in zip(
|
||||
req_meta.local_block_ids, scheduler.kv_cache_manager.coordinator.
|
||||
single_type_managers[0].req_to_blocks[request_id]):
|
||||
assert block_id == block.block_id
|
||||
|
||||
|
||||
def test_read_agent_metadata():
|
||||
rank_table = {
|
||||
"version":
|
||||
"1.2",
|
||||
"server_count":
|
||||
"2",
|
||||
"prefill_device_list": [{
|
||||
"server_id": "192.168.1.1",
|
||||
"device_id": "0",
|
||||
"device_ip": "10.30.0.1",
|
||||
"cluster_id": "0",
|
||||
}, {
|
||||
"server_id": "192.168.1.1",
|
||||
"device_id": "1",
|
||||
"device_ip": "10.30.0.2",
|
||||
"cluster_id": "1",
|
||||
}, {
|
||||
"server_id": "192.168.1.2",
|
||||
"device_id": "0",
|
||||
"device_ip": "10.30.0.3",
|
||||
"cluster_id": "2",
|
||||
}, {
|
||||
"server_id": "192.168.1.2",
|
||||
"device_id": "1",
|
||||
"device_ip": "10.30.0.4",
|
||||
"cluster_id": "3",
|
||||
}]
|
||||
}
|
||||
|
||||
def get_device_ip(worker_local_ip, worker_tp_rank, worker_visible_devices):
|
||||
old_visible_devices = os.environ.get("ASCEND_RT_VISIBLE_DEVICES", "")
|
||||
worker = types.SimpleNamespace()
|
||||
worker.local_ip = worker_local_ip
|
||||
worker.tp_rank = worker_tp_rank
|
||||
worker.llm_datadist_role = LLMRole.PROMPT
|
||||
os.environ["ASCEND_RT_VISIBLE_DEVICES"] = worker_visible_devices
|
||||
agent_metadata = LLMDataDistCMgrConnectorWorker.read_agent_metadata(
|
||||
worker, rank_table)
|
||||
os.environ["ASCEND_RT_VISIBLE_DEVICES"] = old_visible_devices
|
||||
return agent_metadata.device_ip
|
||||
|
||||
assert get_device_ip("192.168.1.1", 0, "0") == "10.30.0.1"
|
||||
assert get_device_ip("192.168.1.1", 0, "1") == "10.30.0.2"
|
||||
assert get_device_ip("192.168.1.2", 0, "0") == "10.30.0.3"
|
||||
assert get_device_ip("192.168.1.2", 0, "1") == "10.30.0.4"
|
||||
assert get_device_ip("192.168.1.1", 0, "0,1") == "10.30.0.1"
|
||||
assert get_device_ip("192.168.1.1", 1, "0,1") == "10.30.0.2"
|
||||
assert get_device_ip("192.168.1.1", 0, "") == "10.30.0.1"
|
||||
assert get_device_ip("192.168.1.1", 1, "") == "10.30.0.2"
|
||||
998
tests/ut/kv_connector/test_mooncake_connector.py
Normal file
998
tests/ut/kv_connector/test_mooncake_connector.py
Normal file
@@ -0,0 +1,998 @@
|
||||
import os
|
||||
import queue
|
||||
import socket
|
||||
import sys
|
||||
import threading
|
||||
import time
|
||||
import types
|
||||
import unittest
|
||||
from collections import defaultdict, deque
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import msgspec
|
||||
import zmq
|
||||
from vllm.utils import make_zmq_path
|
||||
|
||||
fake_engine = types.ModuleType("mooncake.engine")
|
||||
fake_engine.TransferEngine = MagicMock() # type: ignore[attr-defined]
|
||||
sys.modules["mooncake.engine"] = fake_engine
|
||||
|
||||
from vllm_ascend.distributed.mooncake_connector import ( # noqa: E402
|
||||
KVCacheRecvingThread, KVCacheSendingThread, KVCacheTaskTracker,
|
||||
KVConnectorRole, MooncakeAgentMetadata, MooncakeConnector,
|
||||
MooncakeConnectorMetadata, MooncakeConnectorScheduler,
|
||||
MooncakeConnectorWorker, ReqMeta, ensure_zmq_recv, ensure_zmq_send,
|
||||
group_concurrent_contiguous, string_to_int64_hash, zmq_ctx)
|
||||
|
||||
GET_META_MSG = b"get_meta_msg"
|
||||
DONE_RECVING_MSG = b"done_recving_msg"
|
||||
|
||||
|
||||
class TestKVCacheTaskTrackerInit(unittest.TestCase):
|
||||
|
||||
def test_init_basic_properties(self):
|
||||
tracker = KVCacheTaskTracker()
|
||||
self.assertIsInstance(tracker.done_task_lock, type(threading.Lock()))
|
||||
self.assertIsInstance(tracker.finished_requests, set)
|
||||
self.assertIsInstance(tracker.delayed_free_requests, deque)
|
||||
|
||||
|
||||
class TestGetAndClearFinishedSingleRequests(unittest.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
self.tracker = KVCacheTaskTracker()
|
||||
self.tracker.finished_requests = set()
|
||||
self.tracker.done_task_lock = threading.Lock()
|
||||
|
||||
def test_empty_requests(self):
|
||||
result = self.tracker.get_and_clear_finished_requests()
|
||||
self.assertEqual(result, set())
|
||||
self.assertEqual(len(self.tracker.finished_requests), 0)
|
||||
|
||||
def test_single_request(self):
|
||||
self.tracker.finished_requests = {"req_123"}
|
||||
result = self.tracker.get_and_clear_finished_requests()
|
||||
self.assertEqual(result, {"req_123"})
|
||||
self.assertEqual(len(self.tracker.finished_requests), 0)
|
||||
|
||||
def test_multiple_requests(self):
|
||||
self.tracker.finished_requests = {"req_1", "req_2", "req_3"}
|
||||
result = self.tracker.get_and_clear_finished_requests()
|
||||
self.assertSetEqual(result, {"req_1", "req_2", "req_3"})
|
||||
self.assertEqual(len(self.tracker.finished_requests), 0)
|
||||
|
||||
@patch("vllm_ascend.distributed.mooncake_connector.logger")
|
||||
def test_concurrent_access(self, mock_logger):
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
self.tracker.finished_requests = {"req_1", "req_2"}
|
||||
with ThreadPoolExecutor(max_workers=3) as executor:
|
||||
futures = [
|
||||
executor.submit(self.tracker.get_and_clear_finished_requests)
|
||||
for _ in range(3)
|
||||
]
|
||||
results = [f.result() for f in futures]
|
||||
self.assertEqual(sum(1 for r in results if r), 1)
|
||||
self.assertEqual(len(self.tracker.finished_requests), 0)
|
||||
|
||||
|
||||
class TestKVCacheSendingThreadInit(unittest.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
self.common_args = {
|
||||
'tp_rank': 1,
|
||||
'decode_tp_size': 4,
|
||||
'local_engine_id': 'engine_1',
|
||||
'side_channel_host': 'localhost',
|
||||
'side_channel_port': 5555,
|
||||
'metadata': MagicMock(),
|
||||
'ready_event': threading.Event()
|
||||
}
|
||||
self.threads = []
|
||||
|
||||
def tearDown(self):
|
||||
for thread in self.threads:
|
||||
if hasattr(thread, 'task_tracker') and hasattr(
|
||||
thread.task_tracker, 'socket'):
|
||||
thread.task_tracker.socket.close()
|
||||
if hasattr(thread, 'is_alive') and thread.is_alive():
|
||||
thread.join(timeout=0.1)
|
||||
|
||||
def test_thread_daemon_property(self):
|
||||
thread = KVCacheSendingThread(**self.common_args)
|
||||
self.threads.append(thread)
|
||||
self.assertTrue(thread.daemon)
|
||||
|
||||
def test_thread_name_format(self):
|
||||
thread = KVCacheSendingThread(**self.common_args)
|
||||
self.threads.append(thread)
|
||||
self.assertEqual(thread.name, "KVCacheSendingThread")
|
||||
|
||||
def test_ready_event_reference(self):
|
||||
custom_event = threading.Event()
|
||||
args = self.common_args.copy()
|
||||
args['ready_event'] = custom_event
|
||||
thread = KVCacheSendingThread(**args)
|
||||
self.threads.append(thread)
|
||||
self.assertIs(thread.ready_event, custom_event)
|
||||
|
||||
|
||||
class TestGetAndClearFinishedRequests(unittest.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
self.common_args = {
|
||||
'tp_rank': 1,
|
||||
'decode_tp_size': 4,
|
||||
'local_engine_id': 'engine_1',
|
||||
'side_channel_host': 'localhost',
|
||||
'side_channel_port': 5555,
|
||||
'metadata': {
|
||||
"test": "metadata"
|
||||
},
|
||||
'ready_event': threading.Event()
|
||||
}
|
||||
self.thread = KVCacheSendingThread(**self.common_args)
|
||||
|
||||
@patch.object(KVCacheTaskTracker, 'get_and_clear_finished_requests')
|
||||
def test_get_and_clear_finished_requests(self, mock_get_clear):
|
||||
expected_requests = {'req1', 'req2'}
|
||||
mock_get_clear.return_value = expected_requests
|
||||
result = self.thread.get_and_clear_finished_requests()
|
||||
mock_get_clear.assert_called_once()
|
||||
self.assertEqual(result, expected_requests)
|
||||
|
||||
|
||||
class TestKVCacheSendingThread(unittest.TestCase):
|
||||
|
||||
def test_run_handles_get_meta_and_done_recv_msgs(self):
|
||||
ready_event = threading.Event()
|
||||
metadata = MooncakeAgentMetadata(
|
||||
engine_id="engine1",
|
||||
te_rpc_port=9090,
|
||||
kv_caches_base_addr=[12345678],
|
||||
num_blocks=2,
|
||||
)
|
||||
host = "127.0.0.1"
|
||||
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
|
||||
s.bind(('', 0))
|
||||
free_port = s.getsockname()[1]
|
||||
|
||||
thread = KVCacheSendingThread(
|
||||
tp_rank=0,
|
||||
decode_tp_size=1,
|
||||
local_engine_id="engine1",
|
||||
side_channel_host=host,
|
||||
side_channel_port=free_port,
|
||||
metadata=metadata,
|
||||
ready_event=ready_event,
|
||||
)
|
||||
thread.start()
|
||||
self.assertTrue(ready_event.wait(timeout=3),
|
||||
"Server thread startup timeout")
|
||||
|
||||
context = zmq.Context() # type: ignore
|
||||
sock = context.socket(zmq.DEALER) # type: ignore
|
||||
sock.connect(f"tcp://{host}:{free_port}")
|
||||
encoder = msgspec.msgpack.Encoder()
|
||||
decoder = msgspec.msgpack.Decoder(type=MooncakeAgentMetadata)
|
||||
|
||||
sock.send_multipart([b"", encoder.encode((GET_META_MSG, ))])
|
||||
frames = sock.recv_multipart()
|
||||
self.assertEqual(frames[0], b"")
|
||||
meta = decoder.decode(frames[1])
|
||||
self.assertEqual(meta.engine_id, "engine1")
|
||||
self.assertEqual(meta.kv_caches_base_addr, [12345678])
|
||||
self.assertEqual(meta.num_blocks, 2)
|
||||
|
||||
req_id = "request_42"
|
||||
sock.send_multipart(
|
||||
[b"", encoder.encode((DONE_RECVING_MSG, req_id, 0))])
|
||||
frames = sock.recv_multipart()
|
||||
self.assertEqual(frames[0], b"")
|
||||
self.assertEqual(frames[1], b"ACK")
|
||||
self.assertIn(req_id, thread.task_tracker.finished_requests)
|
||||
|
||||
sock.close()
|
||||
context.term()
|
||||
|
||||
|
||||
class TestKVCacheRecvingThreadBasic(unittest.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
self.engine = MagicMock()
|
||||
self.ready_event = threading.Event()
|
||||
self.thread = KVCacheRecvingThread(
|
||||
tp_rank=0,
|
||||
tp_size=4,
|
||||
engine=self.engine,
|
||||
local_engine_id="local_engine",
|
||||
local_handshake_port=5555,
|
||||
local_kv_caches_base_addr=[0x1000, 0x2000],
|
||||
block_len=[1024, 2048],
|
||||
ready_event=self.ready_event)
|
||||
|
||||
def test_add_request(self):
|
||||
test_req = {
|
||||
"request_id": "req1",
|
||||
"local_block_ids": [1, 2],
|
||||
"remote_block_ids": [3, 4],
|
||||
"remote_engine_id": "remote_engine",
|
||||
"remote_host": "localhost",
|
||||
"remote_handshake_port": 6666,
|
||||
}
|
||||
self.thread.add_request(**test_req)
|
||||
queued = self.thread.request_queue.get_nowait()
|
||||
self.assertEqual(queued["request_id"], "req1")
|
||||
self.assertEqual(queued["remote_host"], "localhost")
|
||||
|
||||
@patch.object(KVCacheTaskTracker, 'get_and_clear_finished_requests')
|
||||
def test_get_finished_requests(self, mock_tracker):
|
||||
mock_tracker.return_value = {"req1", "req2"}
|
||||
result = self.thread.get_and_clear_finished_requests()
|
||||
self.assertEqual(result, {"req1", "req2"})
|
||||
|
||||
|
||||
class TestSocketManagement(unittest.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
self.engine = MagicMock()
|
||||
self.ready_event = threading.Event()
|
||||
self.thread = KVCacheRecvingThread(
|
||||
tp_rank=0,
|
||||
tp_size=4,
|
||||
engine=self.engine,
|
||||
local_engine_id="local_engine",
|
||||
local_handshake_port=5555,
|
||||
local_kv_caches_base_addr=[0x1000, 0x2000],
|
||||
block_len=[1024, 2048],
|
||||
ready_event=self.ready_event)
|
||||
self.thread.remote_sockets = defaultdict(deque)
|
||||
self.thread.remote_poller = MagicMock()
|
||||
|
||||
@patch('vllm_ascend.distributed.mooncake_connector.zmq.Context')
|
||||
@patch('vllm_ascend.distributed.mooncake_connector.make_zmq_socket')
|
||||
def test_get_remote_socket(self, mock_make_socket, mock_context):
|
||||
mock_sock = MagicMock()
|
||||
mock_make_socket.return_value = mock_sock
|
||||
test_host = "test_host"
|
||||
test_port = 12345
|
||||
|
||||
sock = self.thread._get_remote_socket(test_host, test_port)
|
||||
|
||||
self.assertEqual(sock, mock_sock)
|
||||
mock_make_socket.assert_called_once()
|
||||
args, kwargs = mock_make_socket.call_args
|
||||
self.assertEqual(kwargs.get('path'), 'tcp://test_host:12345')
|
||||
self.assertEqual(kwargs.get('socket_type'), zmq.REQ) # type: ignore
|
||||
self.assertFalse(kwargs.get('bind', True))
|
||||
self.thread.remote_poller.register.assert_called_with(
|
||||
mock_sock, zmq.POLLIN) # type: ignore
|
||||
|
||||
def test_return_socket_to_pool(self):
|
||||
mock_sock = MagicMock()
|
||||
test_host = "test_host"
|
||||
test_port = 12345
|
||||
test_path = make_zmq_path("tcp", test_host, test_port)
|
||||
|
||||
self.thread._return_remote_socket(mock_sock, test_host, test_port)
|
||||
|
||||
self.assertEqual(len(self.thread.remote_sockets[test_path]), 1)
|
||||
self.assertEqual(self.thread.remote_sockets[test_path][0], mock_sock)
|
||||
self.thread.remote_poller.register.assert_not_called()
|
||||
|
||||
|
||||
class TestCoreFunctionality(unittest.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
self.engine = MagicMock()
|
||||
self.ready_event = threading.Event()
|
||||
self.mock_queue = MagicMock()
|
||||
self.thread = KVCacheRecvingThread(
|
||||
tp_rank=0,
|
||||
tp_size=4,
|
||||
engine=self.engine,
|
||||
local_engine_id="local_engine",
|
||||
local_handshake_port=5555,
|
||||
local_kv_caches_base_addr=[0x1000, 0x2000],
|
||||
block_len=[1024, 2048],
|
||||
ready_event=self.ready_event)
|
||||
self.thread.request_queue = self.mock_queue
|
||||
self.test_req = {
|
||||
"request_id": "req1",
|
||||
"local_block_ids": [1, 2],
|
||||
"remote_block_ids": [3, 4],
|
||||
"remote_engine_id": "remote_engine",
|
||||
"remote_host": "localhost",
|
||||
"remote_handshake_port": 6666,
|
||||
"remote_transfer_port": 7777
|
||||
}
|
||||
self.thread.task_tracker = MagicMock()
|
||||
self.engine.batch_transfer_sync_read.return_value = 0
|
||||
self.thread.remote_te_port = {"remote_engine": {6666: 7777}}
|
||||
|
||||
@patch.object(KVCacheRecvingThread, '_transfer_kv_cache')
|
||||
@patch.object(KVCacheRecvingThread, '_send_done_recv_signal')
|
||||
def test_handle_request(self, mock_send, mock_transfer):
|
||||
self.thread._handle_request(self.test_req)
|
||||
mock_transfer.assert_called_once_with(self.test_req)
|
||||
mock_send.assert_called_once_with("req1", "localhost", 6666)
|
||||
self.thread.task_tracker.update_done_task_count.assert_called_once_with(
|
||||
"req1")
|
||||
self.mock_queue.task_done.assert_called_once()
|
||||
|
||||
@patch.object(KVCacheRecvingThread, '_get_remote_metadata')
|
||||
def test_transfer_kv_cache(self, mock_get_meta):
|
||||
self.thread.kv_caches_base_addr["remote_engine"] = {
|
||||
6666: [0x3000, 0x4000]
|
||||
}
|
||||
|
||||
self.thread._transfer_kv_cache(self.test_req)
|
||||
|
||||
self.engine.batch_transfer_sync_read.assert_called_once()
|
||||
call_args, call_kwargs = self.engine.batch_transfer_sync_read.call_args
|
||||
self.assertEqual(call_args[0], "localhost:7777")
|
||||
self.assertIsInstance(call_args[1], list)
|
||||
self.assertIsInstance(call_args[2], list)
|
||||
self.assertIsInstance(call_args[3], list)
|
||||
self.assertEqual(len(call_args[1]), len(call_args[2]))
|
||||
self.assertEqual(len(call_args[1]), len(call_args[3]))
|
||||
mock_get_meta.assert_not_called()
|
||||
|
||||
def test_transfer_kv_cache_failure(self):
|
||||
self.engine.batch_transfer_sync_read.return_value = -1
|
||||
self.thread.kv_caches_base_addr["remote_engine"] = {
|
||||
6666: [0x3000, 0x4000]
|
||||
}
|
||||
|
||||
with self.assertRaises(RuntimeError):
|
||||
self.thread._transfer_kv_cache(self.test_req)
|
||||
|
||||
|
||||
class TestMetadataHandling(unittest.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
self.engine = MagicMock()
|
||||
self.ready_event = threading.Event()
|
||||
self.thread = KVCacheRecvingThread(
|
||||
tp_rank=0,
|
||||
tp_size=4,
|
||||
engine=self.engine,
|
||||
local_engine_id="local_engine",
|
||||
local_handshake_port=5555,
|
||||
local_kv_caches_base_addr=[0x1000, 0x2000],
|
||||
block_len=[1024, 2048],
|
||||
ready_event=self.ready_event)
|
||||
self.test_metadata = MooncakeAgentMetadata(
|
||||
engine_id="remote_engine",
|
||||
te_rpc_port=9090,
|
||||
kv_caches_base_addr=[0x3000, 0x4000],
|
||||
num_blocks=2)
|
||||
|
||||
@patch('vllm_ascend.distributed.mooncake_connector.ensure_zmq_send')
|
||||
@patch('vllm_ascend.distributed.mooncake_connector.ensure_zmq_recv')
|
||||
def test_get_remote_metadata_success(self, mock_recv, mock_send):
|
||||
mock_recv.return_value = msgspec.msgpack.encode(self.test_metadata)
|
||||
|
||||
with patch.object(self.thread, '_get_remote_socket') as mock_get_socket, \
|
||||
patch.object(self.thread, '_return_remote_socket') as mock_return_socket:
|
||||
mock_socket = MagicMock()
|
||||
mock_get_socket.return_value = mock_socket
|
||||
|
||||
self.thread._get_remote_metadata("host1", 5555)
|
||||
|
||||
mock_get_socket.assert_called_once_with("host1", 5555)
|
||||
mock_return_socket.assert_called_once_with(mock_socket, "host1",
|
||||
5555)
|
||||
mock_send.assert_called_once_with(
|
||||
mock_socket, self.thread.encoder.encode((GET_META_MSG, "")))
|
||||
mock_recv.assert_called_once_with(mock_socket,
|
||||
self.thread.remote_poller)
|
||||
self.assertEqual(
|
||||
self.thread.kv_caches_base_addr["remote_engine"][5555],
|
||||
[0x3000, 0x4000])
|
||||
|
||||
@patch('vllm_ascend.distributed.mooncake_connector.ensure_zmq_send')
|
||||
@patch('vllm_ascend.distributed.mooncake_connector.ensure_zmq_recv',
|
||||
side_effect=Exception("Network error"))
|
||||
def test_get_remote_metadata_failure(self, mock_recv, mock_send):
|
||||
with patch.object(self.thread, '_get_remote_socket') as mock_get_socket, \
|
||||
patch.object(self.thread, '_return_remote_socket') as mock_return_socket:
|
||||
mock_socket = MagicMock()
|
||||
mock_get_socket.return_value = mock_socket
|
||||
|
||||
with self.assertRaises(Exception) as context:
|
||||
self.thread._get_remote_metadata("host1", 5555)
|
||||
|
||||
self.assertEqual(str(context.exception), "Network error")
|
||||
mock_return_socket.assert_called_once()
|
||||
|
||||
|
||||
class TestMainThreadLoop(unittest.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
self.engine = MagicMock()
|
||||
self.ready_event = threading.Event()
|
||||
self.thread = KVCacheRecvingThread(
|
||||
tp_rank=0,
|
||||
tp_size=4,
|
||||
engine=self.engine,
|
||||
local_engine_id="local_engine",
|
||||
local_handshake_port=5555,
|
||||
local_kv_caches_base_addr=[0x1000, 0x2000],
|
||||
block_len=[1024, 2048],
|
||||
ready_event=self.ready_event)
|
||||
self.thread.request_queue = queue.Queue()
|
||||
|
||||
@patch.object(KVCacheRecvingThread, '_handle_request')
|
||||
def test_run_loop_normal(self, mock_handle):
|
||||
test_request = {
|
||||
"request_id": "req1",
|
||||
"local_block_ids": [1, 2],
|
||||
"remote_block_ids": [3, 4],
|
||||
"remote_engine_id": "remote_engine",
|
||||
"remote_host": "localhost",
|
||||
"remote_handshake_port": 6666,
|
||||
"remote_transfer_port": 7777
|
||||
}
|
||||
|
||||
self.thread.request_queue.put(test_request)
|
||||
self.thread.request_queue.put(None)
|
||||
|
||||
self.thread.start()
|
||||
time.sleep(0.1)
|
||||
self.thread.join(timeout=1.0)
|
||||
|
||||
self.assertTrue(self.thread.ready_event.is_set())
|
||||
mock_handle.assert_called_once_with(test_request)
|
||||
self.assertTrue(self.thread.request_queue.empty())
|
||||
|
||||
|
||||
class MockVllmConfig:
|
||||
|
||||
def __init__(self):
|
||||
self.model_config = MagicMock()
|
||||
self.parallel_config = MagicMock()
|
||||
self.cache_config = MagicMock()
|
||||
self.kv_transfer_config = MagicMock()
|
||||
self.model_config.use_mla = True
|
||||
self.parallel_config.tensor_parallel_size = 2
|
||||
self.parallel_config.data_parallel_rank_local = 0
|
||||
self.parallel_config.data_parallel_size_local = 1
|
||||
self.cache_config.block_size = 16
|
||||
self.kv_transfer_config.kv_port = 5000
|
||||
self.kv_transfer_config.kv_role = 'kv_producer'
|
||||
self.kv_transfer_config.get_from_extra_config = MagicMock()
|
||||
self.kv_transfer_config.get_from_extra_config.side_effect = lambda k, d: {
|
||||
"prefill": {
|
||||
"tp_size": 2,
|
||||
"dp_size": 1
|
||||
},
|
||||
"decode": {
|
||||
"tp_size": 2,
|
||||
"dp_size": 1
|
||||
}
|
||||
}.get(k, d)
|
||||
|
||||
|
||||
class MockRequest:
|
||||
|
||||
def __init__(self,
|
||||
request_id,
|
||||
prompt_token_ids=None,
|
||||
kv_transfer_params=None,
|
||||
status=None):
|
||||
self.request_id = request_id
|
||||
self.prompt_token_ids = prompt_token_ids or [1, 2, 3, 4]
|
||||
self.kv_transfer_params = kv_transfer_params or {}
|
||||
self.status = status or "running"
|
||||
self.output_token_ids = [101, 102]
|
||||
|
||||
|
||||
class TestKVCacheTaskTracker(unittest.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
self.tracker = KVCacheTaskTracker()
|
||||
|
||||
def test_update_done_task_count(self):
|
||||
self.assertEqual(len(self.tracker.finished_requests), 0)
|
||||
self.assertEqual(len(self.tracker.delayed_free_requests), 0)
|
||||
|
||||
current_time = time.time()
|
||||
self.tracker.add_delayed_request("req_1", current_time)
|
||||
result = self.tracker.delayed_free_requests
|
||||
self.assertEqual(len(result), 1)
|
||||
self.assertEqual(result[0], ("req_1", current_time))
|
||||
|
||||
self.tracker.update_done_task_count("req_1")
|
||||
result_finished = self.tracker.finished_requests
|
||||
result_delayed = self.tracker.delayed_free_requests
|
||||
self.assertEqual(result_finished, {"req_1"})
|
||||
self.assertEqual(len(result_delayed), 0)
|
||||
|
||||
def test_retrieve_expired_requests(self):
|
||||
current_time = time.time()
|
||||
self.tracker.add_delayed_request("req_1", current_time - 600)
|
||||
self.tracker.add_delayed_request("req_2", current_time)
|
||||
result = self.tracker._retrieve_expired_requests()
|
||||
self.assertEqual(result, {
|
||||
"req_1",
|
||||
})
|
||||
result_delay = self.tracker.delayed_free_requests
|
||||
self.assertEqual(len(result_delay), 1)
|
||||
self.assertEqual(result_delay[0], ("req_2", current_time))
|
||||
|
||||
def test_duplicate_task_update(self):
|
||||
self.tracker.update_done_task_count("req1")
|
||||
self.tracker.update_done_task_count("req1")
|
||||
self.tracker.update_done_task_count("req1")
|
||||
|
||||
finished = self.tracker.get_and_clear_finished_requests()
|
||||
self.assertEqual(finished, {"req1"})
|
||||
|
||||
|
||||
class TestMooncakeConnectorMetadata(unittest.TestCase):
|
||||
|
||||
def test_add_new_req(self):
|
||||
meta = MooncakeConnectorMetadata()
|
||||
self.assertEqual(len(meta.requests), 0)
|
||||
self.assertEqual(len(meta.requests_to_send), 0)
|
||||
|
||||
meta.add_new_req(request_id="req1",
|
||||
local_block_ids=[1, 2, 3],
|
||||
kv_transfer_params={
|
||||
"remote_block_ids": [4, 5, 6],
|
||||
"remote_engine_id": "remote_engine",
|
||||
"remote_host": "localhost",
|
||||
"remote_port": 5000
|
||||
})
|
||||
|
||||
self.assertEqual(len(meta.requests), 1)
|
||||
req_meta = meta.requests["req1"]
|
||||
self.assertIsInstance(req_meta, ReqMeta)
|
||||
self.assertEqual(req_meta.local_block_ids, [1, 2, 3])
|
||||
self.assertEqual(req_meta.remote_block_ids, [4, 5, 6])
|
||||
self.assertEqual(req_meta.remote_engine_id, "remote_engine")
|
||||
self.assertEqual(req_meta.remote_host, "localhost")
|
||||
self.assertEqual(req_meta.remote_port, 5000)
|
||||
|
||||
|
||||
class TestMooncakeConnectorSchedulerMatchedTokens(unittest.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
config = MockVllmConfig()
|
||||
self.scheduler = MooncakeConnectorScheduler(config, "test_engine")
|
||||
|
||||
def test_get_num_new_matched_tokens(self):
|
||||
request = MockRequest("req1")
|
||||
tokens, async_flag = self.scheduler.get_num_new_matched_tokens(
|
||||
request, 0)
|
||||
self.assertEqual(tokens, 0)
|
||||
self.assertFalse(async_flag)
|
||||
|
||||
request.kv_transfer_params = {"do_remote_prefill": True}
|
||||
tokens, async_flag = self.scheduler.get_num_new_matched_tokens(
|
||||
request, 0)
|
||||
self.assertEqual(tokens, 3)
|
||||
self.assertTrue(async_flag)
|
||||
|
||||
def test_build_connector_meta(self):
|
||||
request = MockRequest("req1")
|
||||
blocks_mock = MagicMock()
|
||||
blocks_mock.get_unhashed_block_ids.return_value = [4, 5, 6]
|
||||
self.scheduler._reqs_need_recv["req1"] = (request, [4, 5, 6])
|
||||
request.kv_transfer_params = {
|
||||
"remote_block_ids": [1, 2, 3],
|
||||
"remote_engine_id": "remote",
|
||||
"remote_host": "localhost",
|
||||
"remote_port": 5000
|
||||
}
|
||||
|
||||
meta = self.scheduler.build_connector_meta(MagicMock())
|
||||
self.assertIsInstance(meta, MooncakeConnectorMetadata)
|
||||
self.assertEqual(len(meta.requests), 1)
|
||||
self.assertEqual(meta.requests["req1"].local_block_ids, [4, 5, 6])
|
||||
self.assertEqual(meta.requests["req1"].remote_block_ids, [1, 2, 3])
|
||||
self.assertEqual(len(self.scheduler._reqs_need_recv), 0)
|
||||
|
||||
def test_get_finished_count(self):
|
||||
count = self.scheduler.get_finished_count()
|
||||
self.assertEqual(count, 2)
|
||||
|
||||
|
||||
class TestHelperFunctions(unittest.TestCase):
|
||||
|
||||
def test_group_concurrent_contiguous(self):
|
||||
src: list[int] = [1, 2, 3, 5, 6]
|
||||
dst: list[int] = [10, 11, 12, 14, 15]
|
||||
|
||||
src_groups, dst_groups = group_concurrent_contiguous(src, dst)
|
||||
|
||||
self.assertEqual(len(src_groups), 2)
|
||||
self.assertEqual(src_groups[0], [1, 2, 3])
|
||||
self.assertEqual(src_groups[1], [5, 6])
|
||||
self.assertEqual(dst_groups[0], [10, 11, 12])
|
||||
self.assertEqual(dst_groups[1], [14, 15])
|
||||
|
||||
def test_group_concurrent_contiguous_empty(self):
|
||||
src: list[int] = []
|
||||
dst: list[int] = []
|
||||
src_groups, dst_groups = group_concurrent_contiguous(src, dst)
|
||||
self.assertEqual(src_groups, [])
|
||||
self.assertEqual(dst_groups, [])
|
||||
|
||||
def test_string_to_int64_hash(self):
|
||||
hash1 = string_to_int64_hash("test_string")
|
||||
hash2 = string_to_int64_hash("test_string")
|
||||
self.assertEqual(hash1, hash2)
|
||||
|
||||
hash3 = string_to_int64_hash("different_string")
|
||||
self.assertNotEqual(hash1, hash3)
|
||||
|
||||
|
||||
class TestMooncakeConnectorForScheduler(unittest.TestCase):
|
||||
|
||||
def test_scheduler_role(self):
|
||||
config = MockVllmConfig()
|
||||
connector = MooncakeConnector(config, KVConnectorRole.SCHEDULER)
|
||||
self.assertIsNotNone(connector.connector_scheduler)
|
||||
self.assertIsNone(connector.connector_worker)
|
||||
|
||||
@patch.object(MooncakeConnectorScheduler, "get_num_new_matched_tokens")
|
||||
def test_scheduler_methods(self, mock_method):
|
||||
config = MockVllmConfig()
|
||||
connector = MooncakeConnector(config, KVConnectorRole.SCHEDULER)
|
||||
request = MockRequest("req1")
|
||||
connector.get_num_new_matched_tokens(request, 0)
|
||||
mock_method.assert_called_once_with(request, 0)
|
||||
|
||||
|
||||
class MockKVCacheBlocks:
|
||||
|
||||
def get_unhashed_block_ids(self):
|
||||
return [4, 5, 6]
|
||||
|
||||
|
||||
class MockSchedulerOutput:
|
||||
pass
|
||||
|
||||
|
||||
class MockForwardContext:
|
||||
pass
|
||||
|
||||
|
||||
class TestMooncakeConnector(unittest.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
self.config = MockVllmConfig()
|
||||
os.environ["ASCEND_RT_VISIBLE_DEVICES"] = "0,1"
|
||||
|
||||
def test_scheduler_initialization(self):
|
||||
connector = MooncakeConnector(self.config, KVConnectorRole.SCHEDULER)
|
||||
self.assertIsNotNone(connector.connector_scheduler)
|
||||
self.assertIsNone(connector.connector_worker)
|
||||
|
||||
@patch.object(MooncakeConnectorScheduler, "get_num_new_matched_tokens")
|
||||
def test_get_num_new_matched_tokens(self, mock_method):
|
||||
connector = MooncakeConnector(self.config, KVConnectorRole.SCHEDULER)
|
||||
request = MockRequest("req1")
|
||||
connector.get_num_new_matched_tokens(request, 0)
|
||||
mock_method.assert_called_once_with(request, 0)
|
||||
|
||||
@patch.object(MooncakeConnectorScheduler, "update_state_after_alloc")
|
||||
def test_update_state_after_alloc(self, mock_method):
|
||||
connector = MooncakeConnector(self.config, KVConnectorRole.SCHEDULER)
|
||||
request = MockRequest("req1")
|
||||
blocks = MockKVCacheBlocks()
|
||||
connector.update_state_after_alloc(request, blocks, 3)
|
||||
mock_method.assert_called_once_with(request, blocks, 3)
|
||||
|
||||
@patch.object(MooncakeConnectorScheduler, "build_connector_meta")
|
||||
def test_build_connector_meta(self, mock_method):
|
||||
connector = MooncakeConnector(self.config, KVConnectorRole.SCHEDULER)
|
||||
scheduler_output = MockSchedulerOutput()
|
||||
connector.build_connector_meta(scheduler_output)
|
||||
mock_method.assert_called_once_with(scheduler_output)
|
||||
|
||||
@patch.object(MooncakeConnectorScheduler, "request_finished")
|
||||
def test_request_finished(self, mock_method):
|
||||
connector = MooncakeConnector(self.config, KVConnectorRole.SCHEDULER)
|
||||
request = MockRequest("req1")
|
||||
connector.request_finished(request, [1, 2, 3])
|
||||
mock_method.assert_called_once_with(request, [1, 2, 3])
|
||||
|
||||
|
||||
class TestMooncakeConnectorScheduler(unittest.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
self.config = MockVllmConfig()
|
||||
self.scheduler = MooncakeConnectorScheduler(self.config, "test_engine")
|
||||
|
||||
def test_get_num_new_matched_tokens_no_remote_prefill(self):
|
||||
request = MockRequest("req1")
|
||||
tokens, async_flag = self.scheduler.get_num_new_matched_tokens(
|
||||
request, 0)
|
||||
self.assertEqual(tokens, 0)
|
||||
self.assertFalse(async_flag)
|
||||
|
||||
def test_get_num_new_matched_tokens_with_remote_prefill(self):
|
||||
request = MockRequest("req1",
|
||||
kv_transfer_params={"do_remote_prefill": True})
|
||||
tokens, async_flag = self.scheduler.get_num_new_matched_tokens(
|
||||
request, 0)
|
||||
self.assertEqual(tokens, 3)
|
||||
self.assertTrue(async_flag)
|
||||
|
||||
def test_update_state_after_alloc_no_remote_prefill(self):
|
||||
request = MockRequest("req1")
|
||||
blocks = MagicMock()
|
||||
self.scheduler.update_state_after_alloc(request, blocks, 0)
|
||||
self.assertEqual(len(self.scheduler._reqs_need_recv), 0)
|
||||
|
||||
def test_update_state_after_alloc_with_remote_prefill(self):
|
||||
request = MockRequest("req1",
|
||||
kv_transfer_params={
|
||||
"do_remote_prefill": True,
|
||||
"remote_block_ids": [1, 2, 3],
|
||||
"remote_engine_id": "remote",
|
||||
"remote_host": "localhost",
|
||||
"remote_port": 5000
|
||||
})
|
||||
blocks = MockKVCacheBlocks()
|
||||
self.scheduler.update_state_after_alloc(request, blocks, 3)
|
||||
self.assertEqual(len(self.scheduler._reqs_need_recv), 1)
|
||||
self.assertEqual(self.scheduler._reqs_need_recv["req1"][0], request)
|
||||
self.assertEqual(self.scheduler._reqs_need_recv["req1"][1], [4, 5, 6])
|
||||
|
||||
def test_request_finished_no_remote_decode(self):
|
||||
request = MockRequest("req1")
|
||||
delay_free, params = self.scheduler.request_finished(
|
||||
request, [1, 2, 3])
|
||||
self.assertFalse(delay_free)
|
||||
self.assertIsNone(params)
|
||||
|
||||
|
||||
class TestUtils(unittest.TestCase):
|
||||
|
||||
def test_string_to_int64_hash(self):
|
||||
h1 = string_to_int64_hash("hello")
|
||||
h2 = string_to_int64_hash("hello")
|
||||
h3 = string_to_int64_hash("world")
|
||||
self.assertEqual(h1, h2)
|
||||
self.assertNotEqual(h1, h3)
|
||||
self.assertIsInstance(h1, int)
|
||||
|
||||
def test_group_concurrent_contiguous(self):
|
||||
src: list[int] = [1, 2, 3, 5, 6]
|
||||
dst: list[int] = [10, 11, 12, 20, 21]
|
||||
src_g, dst_g = group_concurrent_contiguous(src, dst)
|
||||
self.assertEqual(src_g, [[1, 2, 3], [5, 6]])
|
||||
self.assertEqual(dst_g, [[10, 11, 12], [20, 21]])
|
||||
|
||||
def test_group_empty(self):
|
||||
src_g, dst_g = group_concurrent_contiguous([], [])
|
||||
self.assertEqual(src_g, [])
|
||||
self.assertEqual(dst_g, [])
|
||||
|
||||
def test_zmq_ctx_invalid_type(self):
|
||||
with self.assertRaises(ValueError):
|
||||
with zmq_ctx("INVALID", "tcp://127.0.0.1:5555"):
|
||||
pass
|
||||
|
||||
@patch("vllm_ascend.distributed.mooncake_connector.make_zmq_socket")
|
||||
def test_zmq_ctx_ok(self, mock_make_socket):
|
||||
mock_socket = MagicMock()
|
||||
mock_make_socket.return_value = mock_socket
|
||||
with zmq_ctx(zmq.REQ, "tcp://localhost:1234") as s: # type: ignore
|
||||
self.assertEqual(s, mock_socket)
|
||||
|
||||
@patch("vllm_ascend.distributed.mooncake_connector.logger")
|
||||
def test_ensure_zmq_send_success(self, mock_logger):
|
||||
mock_socket = MagicMock()
|
||||
ensure_zmq_send(mock_socket, b"hello")
|
||||
mock_socket.send.assert_called_once_with(b"hello")
|
||||
|
||||
@patch("vllm_ascend.distributed.mooncake_connector.logger")
|
||||
def test_ensure_zmq_send_retry_and_fail(self, mock_logger):
|
||||
mock_socket = MagicMock()
|
||||
mock_socket.send.side_effect = zmq.ZMQError( # type: ignore
|
||||
"send failed")
|
||||
with self.assertRaises(RuntimeError):
|
||||
ensure_zmq_send(mock_socket, b"hello", max_retries=2)
|
||||
self.assertEqual(mock_socket.send.call_count, 2)
|
||||
|
||||
@patch("vllm_ascend.distributed.mooncake_connector.logger")
|
||||
def test_ensure_zmq_recv_success(self, mock_logger):
|
||||
mock_socket = MagicMock()
|
||||
mock_socket.recv.return_value = b"response"
|
||||
mock_poller = MagicMock()
|
||||
mock_poller.poll.return_value = [
|
||||
(mock_socket, zmq.POLLIN) # type: ignore
|
||||
]
|
||||
data = ensure_zmq_recv(mock_socket, mock_poller)
|
||||
self.assertEqual(data, b"response")
|
||||
|
||||
@patch("vllm_ascend.distributed.mooncake_connector.logger")
|
||||
def test_ensure_zmq_recv_timeout_and_fail(self, mock_logger):
|
||||
mock_socket = MagicMock()
|
||||
mock_poller = MagicMock()
|
||||
mock_poller.poll.return_value = []
|
||||
with self.assertRaises(RuntimeError):
|
||||
ensure_zmq_recv(mock_socket,
|
||||
mock_poller,
|
||||
timeout=0.01,
|
||||
max_retries=2)
|
||||
|
||||
|
||||
class MockMooncakeAgentMetadata:
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
pass
|
||||
|
||||
|
||||
class MockMooncakeConnectorMetadata:
|
||||
|
||||
def __init__(self):
|
||||
self.requests = {}
|
||||
|
||||
|
||||
class MockKVCacheSendingThread(threading.Thread):
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__()
|
||||
self.daemon = True
|
||||
self._finished_requests = set()
|
||||
|
||||
def get_and_clear_finished_requests(self):
|
||||
return self._finished_requests
|
||||
|
||||
def start(self):
|
||||
pass
|
||||
|
||||
|
||||
class MockKVCacheRecvingThread(threading.Thread):
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__()
|
||||
self.daemon = True
|
||||
self._finished_requests = set()
|
||||
self.add_request = MagicMock()
|
||||
|
||||
def get_and_clear_finished_requests(self):
|
||||
return self._finished_requests
|
||||
|
||||
def start(self):
|
||||
pass
|
||||
|
||||
|
||||
class MockTensor:
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
self.size = MagicMock(return_value=(10, 16, 8, 16))
|
||||
self.element_size = MagicMock(return_value=4)
|
||||
self.shape = (10, 16, 8, 16)
|
||||
self.data_ptr = MagicMock(return_value=0x1000)
|
||||
|
||||
|
||||
mock_envs_ascend = MagicMock()
|
||||
mock_envs_ascend.MOONCAKE_CONNECTOR_PROTOCOL = "mock_protocol"
|
||||
|
||||
mock_logger = MagicMock()
|
||||
|
||||
|
||||
class MockTransferEngine:
|
||||
|
||||
def initialize(self, *args, **kwargs):
|
||||
return 0
|
||||
|
||||
def register_memory(self, *args, **kwargs):
|
||||
return 1
|
||||
|
||||
|
||||
class MockEnvsAscend:
|
||||
MOONCAKE_CONNECTOR_PROTOCOL = "mock_protocol"
|
||||
PHYSICAL_DEVICES = "10,11"
|
||||
|
||||
|
||||
def mock_get_tensor_model_parallel_rank():
|
||||
return 0
|
||||
|
||||
|
||||
def mock_get_tp_group():
|
||||
return MagicMock()
|
||||
|
||||
|
||||
def mock_get_ip():
|
||||
return "127.0.0.1"
|
||||
|
||||
|
||||
def mock_string_to_int64_hash(s):
|
||||
return hash(s)
|
||||
|
||||
|
||||
class TestMooncakeConnectorWorker(unittest.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
self.envs_ascend_mock = MockEnvsAscend()
|
||||
self.mock_transfer_engine = MagicMock()
|
||||
self.mock_transfer_engine.get_rpc_port.return_value = 9090
|
||||
self.mock_transfer_engine.initialize.return_value = 0
|
||||
self.mock_transfer_engine.register_memory.return_value = 0
|
||||
|
||||
self.patches = [
|
||||
patch('os.getenv', return_value="10,11"),
|
||||
patch('torch.Tensor.size', return_value=(10, 16, 8, 16)),
|
||||
patch('torch.Tensor.element_size', return_value=4),
|
||||
patch('torch.Tensor.data_ptr', return_value=0x1000),
|
||||
patch('math.prod', return_value=128),
|
||||
patch('random.Random'),
|
||||
patch(
|
||||
'vllm_ascend.distributed.mooncake_connector.get_tensor_model_parallel_rank',
|
||||
mock_get_tensor_model_parallel_rank),
|
||||
patch('vllm_ascend.distributed.mooncake_connector.get_tp_group',
|
||||
mock_get_tp_group),
|
||||
patch('vllm_ascend.distributed.mooncake_connector.get_ip',
|
||||
mock_get_ip),
|
||||
patch(
|
||||
'vllm_ascend.distributed.mooncake_connector.string_to_int64_hash',
|
||||
mock_string_to_int64_hash),
|
||||
patch('vllm_ascend.distributed.mooncake_connector.TransferEngine',
|
||||
return_value=self.mock_transfer_engine),
|
||||
patch(
|
||||
'vllm_ascend.distributed.mooncake_connector.KVCacheSendingThread',
|
||||
MagicMock()),
|
||||
patch(
|
||||
'vllm_ascend.distributed.mooncake_connector.KVCacheRecvingThread',
|
||||
MagicMock()),
|
||||
patch('vllm_ascend.distributed.mooncake_connector.logger',
|
||||
MagicMock()),
|
||||
patch('vllm_ascend.distributed.mooncake_connector.threading.Event',
|
||||
MagicMock()),
|
||||
patch.dict('sys.modules',
|
||||
{'vllm_ascend.envs': self.envs_ascend_mock}),
|
||||
]
|
||||
|
||||
for p in self.patches:
|
||||
p.start() # type: ignore
|
||||
|
||||
self.vllm_config = MockVllmConfig()
|
||||
self.engine_id = "test_engine"
|
||||
self.kv_caches = {"layer1": (MagicMock(), MagicMock())}
|
||||
|
||||
def tearDown(self):
|
||||
for p in self.patches:
|
||||
p.stop() # type: ignore
|
||||
|
||||
def test_register_kv_caches_producer(self):
|
||||
worker = MooncakeConnectorWorker(self.vllm_config, self.engine_id)
|
||||
worker.register_kv_caches(self.kv_caches)
|
||||
self.assertEqual(len(worker.kv_caches), 1)
|
||||
self.assertIsNotNone(worker.kv_send_thread)
|
||||
self.assertIsNone(worker.kv_recv_thread)
|
||||
|
||||
def test_register_kv_caches_consumer(self):
|
||||
self.vllm_config.kv_transfer_config.kv_role = 'kv_consumer'
|
||||
worker = MooncakeConnectorWorker(self.vllm_config, self.engine_id)
|
||||
worker.register_kv_caches(self.kv_caches)
|
||||
self.assertIsNone(worker.kv_send_thread)
|
||||
self.assertIsNotNone(worker.kv_recv_thread)
|
||||
|
||||
def test_register_kv_caches_mla_case(self):
|
||||
mla_cache1 = MagicMock()
|
||||
mla_cache1.size.return_value = (10, 16, 1, 16)
|
||||
mla_cache2 = MagicMock()
|
||||
mla_cache2.size.return_value = (10, 16, 1, 8)
|
||||
mla_caches = {"layer1": (mla_cache1, mla_cache2)}
|
||||
|
||||
worker = MooncakeConnectorWorker(self.vllm_config, self.engine_id)
|
||||
worker.register_kv_caches(mla_caches)
|
||||
self.assertTrue(worker.use_mla)
|
||||
self.assertEqual(len(worker.block_len), 2)
|
||||
|
||||
def test_device_id_selection_with_physical_devices(self):
|
||||
# Test with physical devices set
|
||||
worker = MooncakeConnectorWorker(self.vllm_config, self.engine_id)
|
||||
# Default tp_rank is 0, so device_id should be 10
|
||||
self.assertEqual(worker.device_id, 10)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
169
tests/ut/kv_connector/test_remote_decode_lifecycle.py
Normal file
169
tests/ut/kv_connector/test_remote_decode_lifecycle.py
Normal file
@@ -0,0 +1,169 @@
|
||||
#
|
||||
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||||
# Copyright 2023 The vLLM team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# This file is a part of the vllm-ascend project.
|
||||
# Adapted from vllm-project/vllm/blob/main/tests/conftest.py
|
||||
#
|
||||
import copy
|
||||
|
||||
from vllm.v1.outputs import EMPTY_MODEL_RUNNER_OUTPUT
|
||||
from vllm.v1.request import FinishReason, RequestStatus
|
||||
|
||||
from tests.ut.kv_connector.utils import (assert_scheduler_empty,
|
||||
create_model_runner_output,
|
||||
create_request, create_scheduler,
|
||||
create_vllm_config)
|
||||
|
||||
|
||||
def test_basic_lifecycle():
|
||||
"""Test lifecycle of a Remote Decode request."""
|
||||
|
||||
vllm_config = create_vllm_config()
|
||||
scheduler = create_scheduler(vllm_config)
|
||||
|
||||
# 2 Full Blocks and 1 Half Block.
|
||||
BLOCK_SIZE = vllm_config.cache_config.block_size
|
||||
NUM_EXTERNAL_FULL_BLOCKS = 2
|
||||
NUM_TOKENS = int(BLOCK_SIZE * (NUM_EXTERNAL_FULL_BLOCKS + 0.5))
|
||||
|
||||
request = create_request(request_id=1,
|
||||
max_tokens=1,
|
||||
num_tokens=NUM_TOKENS,
|
||||
do_remote_decode=True)
|
||||
|
||||
scheduler.add_request(request)
|
||||
request_id = request.request_id
|
||||
|
||||
# STEP (1): Prefill.
|
||||
# (1a): schedule()
|
||||
scheduler_output = scheduler.schedule()
|
||||
assert len(scheduler.running) == 1
|
||||
assert len(scheduler_output.scheduled_new_reqs) == 1
|
||||
|
||||
# (1b): execute_model()
|
||||
model_runner_output = create_model_runner_output(reqs=[request])
|
||||
|
||||
# (1c): update_from_output()
|
||||
engine_core_outputs = scheduler.update_from_output(scheduler_output,
|
||||
model_runner_output)
|
||||
|
||||
# Ensure the request is finished after 1 tokens.
|
||||
assert request.is_finished()
|
||||
assert request.status == RequestStatus.FINISHED_LENGTH_CAPPED
|
||||
output = engine_core_outputs[0].outputs[0]
|
||||
assert output.finish_reason == FinishReason.LENGTH
|
||||
assert output.kv_transfer_params is not None
|
||||
|
||||
# Request freed in Scheduler and blocks should be freed
|
||||
assert request_id in scheduler.finished_req_ids
|
||||
assert len(scheduler.running) == 0
|
||||
assert len(scheduler.waiting) == 0
|
||||
|
||||
# ... but blocks should not be freed.
|
||||
blocks = scheduler.kv_cache_manager.coordinator.single_type_managers[
|
||||
0].req_to_blocks[request_id]
|
||||
for block in blocks:
|
||||
assert block.ref_cnt == 1
|
||||
|
||||
scheduler_output = scheduler.schedule()
|
||||
assert len(scheduler.running) == 0
|
||||
assert len(scheduler_output.finished_req_ids) == 1
|
||||
assert request_id in scheduler_output.finished_req_ids
|
||||
assert len(scheduler_output.scheduled_new_reqs) == 0
|
||||
assert scheduler_output.scheduled_cached_reqs.num_reqs == 0
|
||||
assert len(scheduler.finished_req_ids) == 0
|
||||
|
||||
# (2b): execute_model()
|
||||
model_runner_output = EMPTY_MODEL_RUNNER_OUTPUT
|
||||
|
||||
# (2c): update_from_output()
|
||||
scheduler.update_from_output(scheduler_output, model_runner_output)
|
||||
|
||||
# STEP (3): Finished sending.
|
||||
# (3a): schedule() - pass finished request to PB.
|
||||
scheduler_output = scheduler.schedule()
|
||||
assert len(scheduler.running) == 0
|
||||
assert len(scheduler_output.finished_req_ids) == 0
|
||||
assert len(scheduler_output.scheduled_new_reqs) == 0
|
||||
assert scheduler_output.scheduled_cached_reqs.num_reqs == 0
|
||||
assert len(scheduler.finished_req_ids) == 0
|
||||
|
||||
# (3b): execute_model()
|
||||
model_runner_output = copy.deepcopy(EMPTY_MODEL_RUNNER_OUTPUT)
|
||||
from vllm.v1.worker.kv_connector_model_runner_mixin import \
|
||||
KVConnectorOutput # type: ignore # noqa
|
||||
model_runner_output.kv_connector_output = KVConnectorOutput(
|
||||
finished_sending=[request_id])
|
||||
|
||||
# (3c): update_from_output()
|
||||
scheduler.update_from_output(scheduler_output, model_runner_output)
|
||||
|
||||
# Confirm we do not have any memory leaks after req lifecycle.
|
||||
assert_scheduler_empty(scheduler)
|
||||
|
||||
|
||||
def test_prefix_cache_lifecycle():
|
||||
"""Test that remote decode params still works with a prefix cache hit."""
|
||||
|
||||
vllm_config = create_vllm_config()
|
||||
scheduler = create_scheduler(vllm_config)
|
||||
|
||||
# Prime the KVCache.
|
||||
BLOCK_SIZE = vllm_config.cache_config.block_size
|
||||
NUM_EXTERNAL_FULL_BLOCKS = 3
|
||||
NUM_TOKENS = int(BLOCK_SIZE * (NUM_EXTERNAL_FULL_BLOCKS + 0.5))
|
||||
|
||||
request_remote_a = create_request(request_id=1, num_tokens=NUM_TOKENS)
|
||||
|
||||
scheduler.add_request(request_remote_a)
|
||||
scheduler_output = scheduler.schedule()
|
||||
model_runner_output = create_model_runner_output(reqs=[request_remote_a],
|
||||
use_eos=True)
|
||||
scheduler.update_from_output(scheduler_output, model_runner_output)
|
||||
scheduler.schedule()
|
||||
scheduler.update_from_output(scheduler_output, EMPTY_MODEL_RUNNER_OUTPUT)
|
||||
|
||||
#####################
|
||||
# Actual Test: confirm we send all blocks.
|
||||
|
||||
# Step (1): Send the KV Transfer.
|
||||
NUM_EXTERNAL_FULL_BLOCKS -= 1
|
||||
NUM_TOKENS = int(BLOCK_SIZE * (NUM_EXTERNAL_FULL_BLOCKS + 0.5))
|
||||
|
||||
request_remote = create_request(request_id=1,
|
||||
num_tokens=NUM_TOKENS,
|
||||
do_remote_decode=True)
|
||||
|
||||
scheduler.add_request(request_remote)
|
||||
scheduler_output = scheduler.schedule()
|
||||
model_runner_output = create_model_runner_output(reqs=[request_remote])
|
||||
eco = scheduler.update_from_output(scheduler_output, model_runner_output)
|
||||
kv_transfer_params = eco[0].outputs[0].kv_transfer_params
|
||||
# Ensure we send all block ids, even if there is a cache hit.
|
||||
assert (len(
|
||||
kv_transfer_params["remote_block_ids"]) == (NUM_EXTERNAL_FULL_BLOCKS +
|
||||
1))
|
||||
|
||||
# STEP (2): Ensure it is freed.
|
||||
scheduler_output = scheduler.schedule()
|
||||
scheduler.schedule()
|
||||
model_runner_output = copy.deepcopy(EMPTY_MODEL_RUNNER_OUTPUT)
|
||||
from vllm.v1.worker.kv_connector_model_runner_mixin import \
|
||||
KVConnectorOutput # noqa
|
||||
model_runner_output.kv_connector_output = KVConnectorOutput(
|
||||
finished_sending=[request_remote.request_id])
|
||||
scheduler.update_from_output(scheduler_output, model_runner_output)
|
||||
_ = scheduler.schedule()
|
||||
assert_scheduler_empty(scheduler)
|
||||
239
tests/ut/kv_connector/test_remote_prefill_lifecycle.py
Normal file
239
tests/ut/kv_connector/test_remote_prefill_lifecycle.py
Normal file
@@ -0,0 +1,239 @@
|
||||
#
|
||||
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||||
# Copyright 2023 The vLLM team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# This file is a part of the vllm-ascend project.
|
||||
# Adapted from vllm-project/vllm/blob/main/tests/conftest.py
|
||||
#
|
||||
import copy
|
||||
|
||||
from vllm.v1.outputs import EMPTY_MODEL_RUNNER_OUTPUT
|
||||
from vllm.v1.request import RequestStatus
|
||||
|
||||
from tests.ut.kv_connector.utils import (assert_scheduler_empty,
|
||||
create_model_runner_output,
|
||||
create_request, create_scheduler,
|
||||
create_vllm_config)
|
||||
|
||||
|
||||
def test_basic_lifecycle():
|
||||
"""Test lifecycle of a remote prefill."""
|
||||
|
||||
vllm_config = create_vllm_config()
|
||||
scheduler = create_scheduler(vllm_config)
|
||||
|
||||
# 2 Full Blocks and 1 Half Block.
|
||||
BLOCK_SIZE = vllm_config.cache_config.block_size
|
||||
NUM_EXTERNAL_FULL_BLOCKS = 2
|
||||
NUM_TOKENS = int(BLOCK_SIZE * (NUM_EXTERNAL_FULL_BLOCKS + 0.5))
|
||||
START_FREE_BLOCK_QUEUE_SIZE = (
|
||||
scheduler.kv_cache_manager.block_pool.free_block_queue.num_free_blocks)
|
||||
|
||||
request = create_request(request_id=1,
|
||||
num_tokens=NUM_TOKENS,
|
||||
do_remote_prefill=True,
|
||||
block_size=BLOCK_SIZE)
|
||||
|
||||
scheduler.add_request(request)
|
||||
request_id = request.request_id
|
||||
|
||||
# STEP (1):
|
||||
# (1a): schedule()
|
||||
scheduler_output = scheduler.schedule()
|
||||
|
||||
# Nothing running and empty scheduler output.
|
||||
assert len(scheduler.running) == 0
|
||||
assert len(scheduler_output.scheduled_new_reqs) == 0
|
||||
assert scheduler_output.scheduled_cached_reqs.num_reqs == 0
|
||||
assert len(scheduler_output.num_scheduled_tokens) == 0
|
||||
assert scheduler_output.total_num_scheduled_tokens == 0
|
||||
|
||||
# Req waiting for KVs with no computed/scheduled toks ...
|
||||
assert len(scheduler.waiting) == 1
|
||||
assert request in scheduler.waiting
|
||||
assert (request.status == RequestStatus.WAITING_FOR_REMOTE_KVS)
|
||||
assert (request.num_computed_tokens == 0)
|
||||
|
||||
# ... but should have (uncached) blocks allocated to it.
|
||||
block_pool = scheduler.kv_cache_manager.block_pool
|
||||
assert (block_pool.free_block_queue.num_free_blocks
|
||||
< START_FREE_BLOCK_QUEUE_SIZE)
|
||||
assert len(block_pool.cached_block_hash_to_block) == 0
|
||||
blocks = scheduler.kv_cache_manager.coordinator.single_type_managers[
|
||||
0].req_to_blocks[request_id]
|
||||
for block in blocks:
|
||||
assert block._block_hash is None
|
||||
|
||||
# (1b): forward()
|
||||
model_runner_output = EMPTY_MODEL_RUNNER_OUTPUT
|
||||
|
||||
# (1c): update_from_output()
|
||||
engine_core_outputs = scheduler.update_from_output(scheduler_output,
|
||||
model_runner_output)
|
||||
assert not engine_core_outputs or not engine_core_outputs[0].outputs
|
||||
|
||||
# STEP (2):
|
||||
# (2a): schedule(): nothing happens!
|
||||
scheduler_output = scheduler.schedule()
|
||||
assert len(scheduler.waiting) == 1
|
||||
assert len(scheduler.running) == 0
|
||||
|
||||
# (2b): forward(): request finishes recv.
|
||||
model_runner_output = copy.deepcopy(EMPTY_MODEL_RUNNER_OUTPUT)
|
||||
from vllm.v1.worker.kv_connector_model_runner_mixin import \
|
||||
KVConnectorOutput # type: ignore # noqa
|
||||
model_runner_output.kv_connector_output = KVConnectorOutput(
|
||||
finished_recving=[request_id])
|
||||
|
||||
# (2c): update_from_output():
|
||||
engine_core_outputs = scheduler.update_from_output(scheduler_output,
|
||||
model_runner_output)
|
||||
assert len(scheduler.waiting) == 1
|
||||
assert (request_id in scheduler.finished_recving_kv_req_ids)
|
||||
|
||||
# STEP (3):
|
||||
# (3a): schedule(): this should actually schedule.
|
||||
scheduler_output = scheduler.schedule()
|
||||
assert len(scheduler.running) == 1
|
||||
|
||||
# Confirm the block are actually allocated.
|
||||
num_hashed_blocks = 0
|
||||
blocks = scheduler.kv_cache_manager.coordinator.single_type_managers[
|
||||
0].req_to_blocks[request_id]
|
||||
for block in blocks:
|
||||
assert block.ref_cnt == 1
|
||||
num_hashed_blocks += (1 if block._block_hash is not None else 0)
|
||||
assert num_hashed_blocks == NUM_EXTERNAL_FULL_BLOCKS
|
||||
|
||||
# Confirm the rest of the prompt is scheduled in this step.
|
||||
scheduled_req = scheduler_output.scheduled_new_reqs[0]
|
||||
num_scheduled_tokens = scheduler_output.num_scheduled_tokens[request_id]
|
||||
num_computed_tokens = scheduled_req.num_computed_tokens
|
||||
total_prompt_tokens = len(scheduled_req.prompt_token_ids)
|
||||
assert (num_scheduled_tokens == total_prompt_tokens - num_computed_tokens)
|
||||
|
||||
# (3b): execute_model()
|
||||
model_runner_output = create_model_runner_output([request])
|
||||
# (3c): update_from_output()
|
||||
scheduler.update_from_output(scheduler_output, model_runner_output)
|
||||
|
||||
# Step (4): Hit EOS.
|
||||
scheduler_output = scheduler.schedule()
|
||||
model_runner_output = create_model_runner_output([request], use_eos=True)
|
||||
engine_core_outputs = scheduler.update_from_output(scheduler_output,
|
||||
model_runner_output)
|
||||
scheduler.schedule()
|
||||
|
||||
assert_scheduler_empty(scheduler)
|
||||
|
||||
|
||||
def test_no_spurious_prefix_caching():
|
||||
"""
|
||||
With P/D, blocks can be allocated but uncomputed for
|
||||
multiple engine steps. This test confirms that we do
|
||||
not accidentally have cache hits against uncomputed
|
||||
blocks.
|
||||
"""
|
||||
|
||||
vllm_config = create_vllm_config()
|
||||
scheduler = create_scheduler(vllm_config)
|
||||
|
||||
# 2 and a half full external blocks.
|
||||
BLOCK_SIZE = vllm_config.cache_config.block_size
|
||||
NUM_EXTERNAL_FULL_BLOCKS = 2
|
||||
NUM_TOKENS = int(BLOCK_SIZE * (NUM_EXTERNAL_FULL_BLOCKS + 0.5))
|
||||
|
||||
# Both of these requests have prompts like [1,1,1,1,1, ...]
|
||||
request_remote = create_request(
|
||||
request_id=1,
|
||||
num_tokens=NUM_TOKENS,
|
||||
do_remote_prefill=True,
|
||||
use_all_1s_for_prompt_tokens=True,
|
||||
)
|
||||
|
||||
# Schedule the remote prefill request. This should not
|
||||
# cause any blocks to be cached.
|
||||
scheduler.add_request(request_remote)
|
||||
scheduler_output = scheduler.schedule()
|
||||
scheduler.update_from_output(scheduler_output, EMPTY_MODEL_RUNNER_OUTPUT)
|
||||
assert len(scheduler.waiting) == 1
|
||||
|
||||
remote_blocks = scheduler.kv_cache_manager.coordinator.single_type_managers[
|
||||
0].req_to_blocks[request_remote.request_id]
|
||||
|
||||
# Remote blocks should not be cached.
|
||||
for block in remote_blocks:
|
||||
assert block.ref_cnt == 1
|
||||
assert block._block_hash is None
|
||||
|
||||
|
||||
def test_full_block_prompt():
|
||||
"""Test that we handle a prompt that is the full block size."""
|
||||
|
||||
vllm_config = create_vllm_config()
|
||||
scheduler = create_scheduler(vllm_config)
|
||||
|
||||
# 2 Full Blocks and 1 Half Block.
|
||||
BLOCK_SIZE = vllm_config.cache_config.block_size
|
||||
NUM_EXTERNAL_FULL_BLOCKS = 2
|
||||
NUM_TOKENS = int(BLOCK_SIZE * NUM_EXTERNAL_FULL_BLOCKS)
|
||||
|
||||
request = create_request(request_id=1,
|
||||
num_tokens=NUM_TOKENS,
|
||||
do_remote_prefill=True)
|
||||
|
||||
scheduler.add_request(request)
|
||||
request_id = request.request_id
|
||||
|
||||
# STEP (1): Initialize a recv.
|
||||
scheduler_output = scheduler.schedule()
|
||||
# All blocks should be allocated.
|
||||
num_blocks = len(scheduler.kv_cache_manager.coordinator.
|
||||
single_type_managers[0].req_to_blocks[request_id])
|
||||
assert num_blocks == NUM_EXTERNAL_FULL_BLOCKS
|
||||
model_runner_output = EMPTY_MODEL_RUNNER_OUTPUT
|
||||
scheduler.update_from_output(scheduler_output, model_runner_output)
|
||||
|
||||
# # STEP (2): Recv.
|
||||
scheduler_output = scheduler.schedule()
|
||||
model_runner_output = copy.deepcopy(EMPTY_MODEL_RUNNER_OUTPUT)
|
||||
from vllm.v1.worker.kv_connector_model_runner_mixin import \
|
||||
KVConnectorOutput # type: ignore # noqa
|
||||
model_runner_output.kv_connector_output = KVConnectorOutput(
|
||||
finished_recving=[request_id])
|
||||
scheduler.update_from_output(scheduler_output, model_runner_output)
|
||||
assert len(scheduler.waiting) == 1
|
||||
assert (request_id in scheduler.finished_recving_kv_req_ids)
|
||||
|
||||
# # STEP (3): Run as usual.
|
||||
scheduler_output = scheduler.schedule()
|
||||
|
||||
# We need to recompute the final token of the prompt to generate
|
||||
# the first new token, so we should not have a new block.
|
||||
num_blocks = len(scheduler.kv_cache_manager.coordinator.
|
||||
single_type_managers[0].req_to_blocks[request_id])
|
||||
assert num_blocks == NUM_EXTERNAL_FULL_BLOCKS
|
||||
assert (scheduler_output.scheduled_new_reqs[0].num_computed_tokens ==
|
||||
NUM_TOKENS - 1)
|
||||
assert (scheduler_output.num_scheduled_tokens[request_id] == 1)
|
||||
|
||||
model_runner_output = create_model_runner_output([request])
|
||||
scheduler.update_from_output(scheduler_output, model_runner_output)
|
||||
|
||||
# # Step (4): Hit EOS.
|
||||
scheduler_output = scheduler.schedule()
|
||||
model_runner_output = create_model_runner_output([request], use_eos=True)
|
||||
scheduler.schedule()
|
||||
|
||||
assert_scheduler_empty(scheduler)
|
||||
233
tests/ut/kv_connector/utils.py
Normal file
233
tests/ut/kv_connector/utils.py
Normal file
@@ -0,0 +1,233 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# This code is from: https://github.com/vllm-project/vllm/tests/v1/kv_connector/unit/utils.py
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||||
|
||||
import os
|
||||
from typing import Any, Optional
|
||||
|
||||
import torch
|
||||
from vllm import SamplingParams
|
||||
from vllm.config import (CacheConfig, DeviceConfig, KVTransferConfig,
|
||||
ModelConfig, SchedulerConfig, VllmConfig)
|
||||
from vllm.v1.core.kv_cache_utils import (get_request_block_hasher,
|
||||
init_none_hash)
|
||||
from vllm.v1.core.sched.scheduler import Scheduler
|
||||
from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
|
||||
KVCacheGroupSpec)
|
||||
from vllm.v1.outputs import ModelRunnerOutput
|
||||
from vllm.v1.request import Request
|
||||
from vllm.v1.structured_output import StructuredOutputManager
|
||||
|
||||
from vllm_ascend.utils import vllm_version_is
|
||||
|
||||
EOS_TOKEN_ID = 50256
|
||||
os.environ["VLLM_USE_V1"] = "1"
|
||||
|
||||
|
||||
def assert_scheduler_empty(scheduler: Scheduler):
|
||||
"""Confirm the scheduler is "empty" - i.e. no leaks."""
|
||||
# Scheduler Metadata.
|
||||
assert len(scheduler.requests) == 0
|
||||
assert len(scheduler.waiting) == 0
|
||||
assert len(scheduler.running) == 0
|
||||
assert len(scheduler.finished_req_ids) == 0
|
||||
assert len(scheduler.finished_recving_kv_req_ids) == 0
|
||||
|
||||
# EncoderCacheManager.
|
||||
assert len(scheduler.encoder_cache_manager.freed) == 0
|
||||
assert len(scheduler.encoder_cache_manager.cached) == 0
|
||||
|
||||
# KVCache Manager.
|
||||
assert len(scheduler.kv_cache_manager.coordinator.single_type_managers[0].
|
||||
req_to_blocks) == 0
|
||||
assert len(scheduler.kv_cache_manager.coordinator.single_type_managers[0].
|
||||
num_cached_block) == 0
|
||||
num_free_blocks = (
|
||||
scheduler.kv_cache_manager.block_pool.free_block_queue.num_free_blocks)
|
||||
assert num_free_blocks == (
|
||||
scheduler.kv_cache_manager.block_pool.num_gpu_blocks - 1)
|
||||
|
||||
# NOTE(rob): just the ref count on blocks will be 0. The hash
|
||||
# value, etc will remain since we lazily evict for prefix cache.
|
||||
for block in scheduler.kv_cache_manager.block_pool.blocks:
|
||||
assert block.ref_cnt == 0
|
||||
|
||||
|
||||
def create_vllm_config(
|
||||
max_num_seqs: int = 16,
|
||||
max_num_batched_tokens: int = 1024,
|
||||
block_size: int = 128,
|
||||
) -> VllmConfig:
|
||||
"""Initialize VllmConfig For Testing."""
|
||||
scheduler_config = SchedulerConfig(
|
||||
max_num_seqs=max_num_seqs,
|
||||
max_num_batched_tokens=max_num_batched_tokens,
|
||||
max_model_len=max_num_batched_tokens,
|
||||
)
|
||||
fake_weight_path = os.path.join(os.path.dirname(__file__), "..",
|
||||
"fake_weight")
|
||||
model_config = ModelConfig(
|
||||
model=fake_weight_path,
|
||||
skip_tokenizer_init=True,
|
||||
)
|
||||
# Cache config, optionally force APC
|
||||
cache_config = CacheConfig(
|
||||
block_size=block_size,
|
||||
gpu_memory_utilization=0.9,
|
||||
swap_space=0,
|
||||
cache_dtype="auto",
|
||||
enable_prefix_caching=True,
|
||||
)
|
||||
kv_transfer_config = KVTransferConfig(
|
||||
kv_connector="LLMDataDistCMgrConnector",
|
||||
kv_role="kv_both",
|
||||
kv_connector_module_path=
|
||||
"vllm_ascend.distributed.llmdatadist_c_mgr_connector")
|
||||
return VllmConfig(scheduler_config=scheduler_config,
|
||||
model_config=model_config,
|
||||
cache_config=cache_config,
|
||||
kv_transfer_config=kv_transfer_config,
|
||||
device_config=DeviceConfig("cpu"))
|
||||
|
||||
|
||||
def create_scheduler(
|
||||
vllm_config: VllmConfig,
|
||||
num_blocks: int = 10000,
|
||||
) -> Scheduler:
|
||||
"""Initialize Scheduler For Testing."""
|
||||
block_size = vllm_config.cache_config.block_size
|
||||
kv_cache_config = KVCacheConfig(
|
||||
num_blocks=num_blocks, # A large number of blocks to hold all requests
|
||||
kv_cache_tensors=[],
|
||||
kv_cache_groups=[
|
||||
KVCacheGroupSpec(['layer'],
|
||||
FullAttentionSpec(block_size, 1, 1, torch.float16,
|
||||
False))
|
||||
],
|
||||
)
|
||||
vllm_config.cache_config.num_gpu_blocks = num_blocks
|
||||
return Scheduler(
|
||||
vllm_config=vllm_config,
|
||||
kv_cache_config=kv_cache_config,
|
||||
log_stats=True,
|
||||
structured_output_manager=StructuredOutputManager(vllm_config),
|
||||
)
|
||||
|
||||
|
||||
_none_hash_initialized = False
|
||||
|
||||
|
||||
def create_request(
|
||||
request_id: int,
|
||||
num_tokens: int = 10,
|
||||
max_tokens: int = 128,
|
||||
do_remote_decode: bool = False,
|
||||
do_remote_prefill: bool = False,
|
||||
use_all_1s_for_prompt_tokens: bool = False,
|
||||
num_remote_blocks: int = 3,
|
||||
block_size: int = 16,
|
||||
) -> Request:
|
||||
"""Make dummy request for testing."""
|
||||
global _none_hash_initialized
|
||||
if not _none_hash_initialized:
|
||||
init_none_hash(hash)
|
||||
_none_hash_initialized = True
|
||||
|
||||
block_hasher = get_request_block_hasher(block_size, hash)
|
||||
|
||||
kv_transfer_params: Optional[dict[str, Any]] = None
|
||||
|
||||
if do_remote_decode:
|
||||
assert not do_remote_prefill
|
||||
kv_transfer_params = dict(do_remote_prefill=False,
|
||||
do_remote_decode=True)
|
||||
elif do_remote_prefill:
|
||||
kv_transfer_params = dict(do_remote_prefill=True,
|
||||
do_remote_decode=False,
|
||||
remote_engine_id="my-engine-id",
|
||||
remote_block_ids=list(
|
||||
range(num_remote_blocks)),
|
||||
remote_host="my-host",
|
||||
remote_port=1234,
|
||||
remote_tp_size=1)
|
||||
|
||||
max_tokens = 1 if do_remote_decode else max_tokens
|
||||
sampling_params = SamplingParams(max_tokens=max_tokens)
|
||||
|
||||
if use_all_1s_for_prompt_tokens:
|
||||
prompt_token_ids = [1] * num_tokens
|
||||
else:
|
||||
prompt_token_ids = [i * request_id for i in range(num_tokens)]
|
||||
|
||||
if vllm_version_is("0.10.1.1") or vllm_version_is("0.10.1"):
|
||||
req = Request(
|
||||
request_id=f"id-{request_id}",
|
||||
prompt_token_ids=prompt_token_ids,
|
||||
sampling_params=sampling_params,
|
||||
multi_modal_kwargs=None,
|
||||
multi_modal_placeholders=None,
|
||||
multi_modal_hashes=None,
|
||||
pooling_params=[],
|
||||
eos_token_id=EOS_TOKEN_ID,
|
||||
block_hasher=block_hasher,
|
||||
)
|
||||
else:
|
||||
req = Request(
|
||||
request_id=f"id-{request_id}",
|
||||
prompt_token_ids=prompt_token_ids,
|
||||
sampling_params=sampling_params,
|
||||
pooling_params=[],
|
||||
eos_token_id=EOS_TOKEN_ID,
|
||||
block_hasher=block_hasher,
|
||||
)
|
||||
req.kv_transfer_params = kv_transfer_params
|
||||
return req
|
||||
|
||||
|
||||
def create_model_runner_output(
|
||||
reqs: list[Request],
|
||||
finished_sending: Optional[list[str]] = None,
|
||||
finished_recving: Optional[list[str]] = None,
|
||||
use_eos: bool = False,
|
||||
) -> ModelRunnerOutput:
|
||||
"""Make dummy model runner output for testing."""
|
||||
|
||||
# Make request data.
|
||||
req_ids = [req.request_id for req in reqs]
|
||||
req_id_to_index = {req_id: idx for idx, req_id in enumerate(req_ids)}
|
||||
|
||||
# Make sampled tokens.
|
||||
sampled_token = EOS_TOKEN_ID if use_eos else 0
|
||||
sampled_token_ids = [[sampled_token] for _ in req_ids]
|
||||
|
||||
# Make output data structure.
|
||||
extra_args = {}
|
||||
from vllm.v1.worker.kv_connector_model_runner_mixin import \
|
||||
KVConnectorOutput # type: ignore # noqa
|
||||
kv_connector_output = KVConnectorOutput(finished_sending=finished_sending,
|
||||
finished_recving=finished_recving)
|
||||
extra_args = {"kv_connector_output": kv_connector_output}
|
||||
if vllm_version_is("0.10.1.1") or vllm_version_is("0.10.1"):
|
||||
model_runner_output = ModelRunnerOutput(
|
||||
req_ids=req_ids,
|
||||
req_id_to_index=req_id_to_index,
|
||||
sampled_token_ids=sampled_token_ids,
|
||||
spec_token_ids=None,
|
||||
logprobs=None,
|
||||
prompt_logprobs_dict={},
|
||||
pooler_output=[],
|
||||
**extra_args,
|
||||
)
|
||||
else:
|
||||
model_runner_output = ModelRunnerOutput(
|
||||
req_ids=req_ids,
|
||||
req_id_to_index=req_id_to_index,
|
||||
sampled_token_ids=sampled_token_ids,
|
||||
logprobs=None,
|
||||
prompt_logprobs_dict={},
|
||||
pooler_output=[],
|
||||
**extra_args,
|
||||
)
|
||||
|
||||
return model_runner_output
|
||||
Reference in New Issue
Block a user