[P/D]mooncake_connector adapted to 0.10.1 (#2664)

### What this PR does / why we need it?
In vllm version 0.10.1, a new KVOutputAggregator was added to the
executor, moving aggregation to the
executor(https://github.com/vllm-project/vllm/pull/19555). This caused
mooncake_connector to break. This change aims to fix this bug and also
adds a policy to forcibly release the KV cache when the prefill node
times out.

This PR is currently linked to a PR in vllm
(https://github.com/vllm-project/vllm/pull/23917). The vllm PR aims to
modify the finish and send count confirmation in heterogeneous TP
situations.

The reason for deleting many UTs is that a lot of communication codes
have been deleted, so the UT as a whole will appear more concise.

- vLLM version: v0.10.1.1
- vLLM main:
fa4311d85f

---------

Signed-off-by: baxingpiaochong <771405853@qq.com>
This commit is contained in:
baxingpiaochong
2025-09-04 08:22:10 +08:00
committed by GitHub
parent 07d44ade19
commit df88a2ecc8
3 changed files with 130 additions and 319 deletions

View File

@@ -12,7 +12,6 @@ from unittest.mock import MagicMock, patch
import msgspec
import zmq
from vllm.utils import make_zmq_path
from zmq import Context # type: ignore
fake_engine = types.ModuleType("mooncake.engine")
fake_engine.TransferEngine = MagicMock() # type: ignore[attr-defined]
@@ -32,193 +31,17 @@ DONE_RECVING_MSG = b"done_recving_msg"
class TestKVCacheTaskTrackerInit(unittest.TestCase):
def test_init_basic_properties(self):
tracker = KVCacheTaskTracker(tp_rank=1,
local_engine_id="engine1",
target_count=10)
self.assertEqual(tracker.tp_rank, 1)
self.assertEqual(tracker.local_engine_id, "engine1")
self.assertEqual(tracker.target_count, 10)
tracker = KVCacheTaskTracker()
self.assertIsInstance(tracker.done_task_lock, type(threading.Lock()))
self.assertIsInstance(tracker.done_task_counts, defaultdict)
self.assertIsInstance(tracker.finished_requests, set)
def test_socket_path_generation(self):
tracker = KVCacheTaskTracker(tp_rank=1,
local_engine_id="engine42",
target_count=1)
self.assertEqual(tracker.socket_path,
"ipc:///tmp/vllm_mooncake_connector_engine42.ipc")
@patch("vllm_ascend.distributed.mooncake_connector.threading.Thread")
def test_tp_rank_zero_initialization(self, mock_thread):
tracker = KVCacheTaskTracker(tp_rank=0,
local_engine_id="test",
target_count=1)
mock_thread.assert_called_once_with(
target=tracker._listen_for_completion_signals,
daemon=True,
name="KVCacheTaskTrackerListenerThread")
mock_thread.return_value.start.assert_called_once()
self.assertIsNone(tracker.socket)
self.assertTrue(tracker.listener.daemon)
@patch("vllm_ascend.distributed.mooncake_connector.make_zmq_socket")
@patch("vllm_ascend.distributed.mooncake_connector.logger")
def test_tp_rank_non_zero_initialization(self, mock_logger,
mock_make_zmq_socket):
mock_socket = MagicMock()
mock_make_zmq_socket.return_value = mock_socket
tracker = KVCacheTaskTracker(tp_rank=1,
local_engine_id="test",
target_count=1)
mock_make_zmq_socket.assert_called_once_with(
ctx=unittest.mock.ANY,
path="ipc:///tmp/vllm_mooncake_connector_test.ipc",
socket_type=zmq.PUSH, # type: ignore
bind=False)
mock_logger.info.assert_called_once_with(
"Connecting to transfer socket at %s",
"ipc:///tmp/vllm_mooncake_connector_test.ipc")
self.assertIsNone(tracker.listener)
self.assertEqual(tracker.socket, mock_socket)
class TestKVCacheTaskTrackerListenMethod(unittest.TestCase):
def setUp(self):
self.tp_rank = 0
self.local_engine_id = "test_engine_ut"
self.target_count = 3
self.tracker = KVCacheTaskTracker(self.tp_rank, self.local_engine_id,
self.target_count)
self.original_listen = self.tracker._listen_for_completion_signals
def tearDown(self):
self.tracker._listen_for_completion_signals = self.original_listen
Context.instance().term()
time.sleep(0.1)
def test_normal_message_processing(self):
listener_thread = threading.Thread(
target=self.tracker._listen_for_completion_signals, daemon=True)
listener_thread.start()
time.sleep(0.2)
test_messages = [("request_001", 1), ("request_001", 2),
("request_002", 0), ("request_003", 1)]
ctx = Context()
sender_socket = ctx.socket(zmq.PUSH) # type: ignore
sender_socket.connect(self.tracker.socket_path)
for msg in test_messages:
sender_socket.send_pyobj(msg)
time.sleep(0.05)
sender_socket.close()
time.sleep(0.2)
with self.tracker.done_task_lock:
self.assertEqual(len(self.tracker.done_task_counts["request_001"]),
2)
self.assertIn(1, self.tracker.done_task_counts["request_001"])
self.assertIn(2, self.tracker.done_task_counts["request_001"])
self.assertEqual(len(self.tracker.done_task_counts["request_002"]),
1)
self.assertIn(0, self.tracker.done_task_counts["request_002"])
self.assertEqual(len(self.tracker.done_task_counts["request_003"]),
1)
self.assertIn(1, self.tracker.done_task_counts["request_003"])
@patch("vllm_ascend.distributed.mooncake_connector.make_zmq_socket",
autospec=True)
def test_listen_with_timeout(self, mock_make_socket):
mock_socket = MagicMock()
def mock_recv():
start = time.time()
while time.time() - start < 0.5:
time.sleep(0.01)
return ("req1", 0)
mock_socket.recv_pyobj = mock_recv
mock_make_socket.return_value = mock_socket
test_thread = threading.Thread(
target=self.tracker._listen_for_completion_signals, daemon=True)
test_thread.start()
test_thread.join(timeout=1.0)
mock_make_socket.assert_called_once()
class TestKVCacheTaskTrackerTP(unittest.TestCase):
def setUp(self):
self.local_engine_id = "test_engine"
self.target_count = 3
def test_update_done_task_count_tp_rank_0(self):
tracker = KVCacheTaskTracker(tp_rank=0,
local_engine_id=self.local_engine_id,
target_count=self.target_count)
test_request_id = "test_req_001"
test_tp_rank = 1
tracker.update_done_task_count(test_request_id, test_tp_rank)
with tracker.done_task_lock:
self.assertEqual(len(tracker.done_task_counts[test_request_id]), 1)
self.assertIn(test_tp_rank,
tracker.done_task_counts[test_request_id])
@patch("vllm_ascend.distributed.mooncake_connector.make_zmq_socket",
autospec=True)
def test_update_done_task_count_non_zero_tp(self, mock_make_socket):
mock_socket = MagicMock()
mock_make_socket.return_value = mock_socket
tracker = KVCacheTaskTracker(tp_rank=1,
local_engine_id=self.local_engine_id,
target_count=self.target_count)
test_request_id = "test_req_002"
test_tp_rank = 1
tracker.update_done_task_count(test_request_id, test_tp_rank)
mock_socket.send_pyobj.assert_called_once_with(
(test_request_id, test_tp_rank))
with tracker.done_task_lock:
self.assertNotIn(test_request_id, tracker.done_task_counts)
@patch("vllm_ascend.distributed.mooncake_connector.logger", autospec=True)
@patch("vllm_ascend.distributed.mooncake_connector.make_zmq_socket",
autospec=True)
def test_update_done_task_count_logging(self, mock_make_socket,
mock_logger):
mock_socket = MagicMock()
mock_make_socket.return_value = mock_socket
tracker = KVCacheTaskTracker(tp_rank=2,
local_engine_id=self.local_engine_id,
target_count=self.target_count)
test_request_id = "test_req_003"
tracker.update_done_task_count(test_request_id, 2)
mock_logger.debug.assert_called_once_with(
"Sent done signal for request %s to tp 0", test_request_id)
@patch("vllm_ascend.distributed.mooncake_connector.make_zmq_socket",
autospec=True)
def test_update_multiple_calls(self, mock_make_socket):
mock_socket = MagicMock()
mock_make_socket.return_value = mock_socket
tracker = KVCacheTaskTracker(tp_rank=1,
local_engine_id=self.local_engine_id,
target_count=self.target_count)
test_data = [("req1", 1), ("req1", 1), ("req2", 1)]
for req_id, rank in test_data:
tracker.update_done_task_count(req_id, rank)
self.assertEqual(mock_socket.send_pyobj.call_count, 3)
mock_socket.send_pyobj.assert_called_with(("req2", 1))
self.assertIsInstance(tracker.delayed_free_requests, deque)
class TestGetAndClearFinishedSingleRequests(unittest.TestCase):
def setUp(self):
self.tracker = KVCacheTaskTracker(tp_rank=0,
local_engine_id="test",
target_count=3)
self.tracker = KVCacheTaskTracker()
self.tracker.finished_requests = set()
self.tracker.done_task_counts = defaultdict(set)
self.tracker.done_task_lock = threading.Lock()
def test_empty_requests(self):
@@ -251,14 +74,6 @@ class TestGetAndClearFinishedSingleRequests(unittest.TestCase):
self.assertEqual(sum(1 for r in results if r), 1)
self.assertEqual(len(self.tracker.finished_requests), 0)
def test_after_increment(self):
self.tracker._increment_task_count("req_123", 0)
self.tracker._increment_task_count("req_123", 1)
self.tracker._increment_task_count("req_123", 2)
result = self.tracker.get_and_clear_finished_requests()
self.assertEqual(result, {"req_123"})
self.assertEqual(self.tracker.get_and_clear_finished_requests(), set())
class TestKVCacheSendingThreadInit(unittest.TestCase):
@@ -282,47 +97,6 @@ class TestKVCacheSendingThreadInit(unittest.TestCase):
if hasattr(thread, 'is_alive') and thread.is_alive():
thread.join(timeout=0.1)
@patch('vllm_ascend.distributed.mooncake_connector.KVCacheTaskTracker')
def test_initialization_basic(self, mock_tracker):
thread = KVCacheSendingThread(**self.common_args)
self.threads.append(thread)
self.assertEqual(thread.tp_rank, 1)
self.assertEqual(thread.decode_tp_size, 4)
self.assertEqual(thread.local_engine_id, 'engine_1')
mock_tracker.assert_called_once()
args = mock_tracker.call_args[0]
kwargs = mock_tracker.call_args[1]
if args:
self.assertEqual(args[0], 1)
self.assertEqual(args[1], 'engine_1')
self.assertEqual(args[2], 4)
else:
self.assertEqual(kwargs['tp_rank'], 1)
self.assertEqual(kwargs['local_engine_id'], 'engine_1')
self.assertEqual(kwargs['target_count'], 4)
@patch('vllm_ascend.distributed.mooncake_connector.KVCacheTaskTracker')
def test_task_tracker_initialization(self, mock_tracker):
args = self.common_args.copy()
args.update({
'tp_rank': 2,
'decode_tp_size': 8,
'local_engine_id': 'engine_2'
})
thread = KVCacheSendingThread(**args)
self.threads.append(thread)
mock_tracker.assert_called_once()
call_args = mock_tracker.call_args[0]
call_kwargs = mock_tracker.call_args[1]
if call_args:
self.assertEqual(call_args[0], 2)
self.assertEqual(call_args[1], 'engine_2')
self.assertEqual(call_args[2], 8)
else:
self.assertEqual(call_kwargs['tp_rank'], 2)
self.assertEqual(call_kwargs['local_engine_id'], 'engine_2')
self.assertEqual(call_kwargs['target_count'], 8)
def test_thread_daemon_property(self):
thread = KVCacheSendingThread(**self.common_args)
self.threads.append(thread)
@@ -542,7 +316,7 @@ class TestCoreFunctionality(unittest.TestCase):
mock_transfer.assert_called_once_with(self.test_req)
mock_send.assert_called_once_with("req1", "localhost", 6666)
self.thread.task_tracker.update_done_task_count.assert_called_once_with(
"req1", self.thread.tp_rank)
"req1")
self.mock_queue.task_done.assert_called_once()
@patch.object(KVCacheRecvingThread, '_get_remote_metadata')
@@ -675,9 +449,11 @@ class TestMainThreadLoop(unittest.TestCase):
class MockVllmConfig:
def __init__(self):
self.model_config = MagicMock()
self.parallel_config = MagicMock()
self.cache_config = MagicMock()
self.kv_transfer_config = MagicMock()
self.model_config.use_mla = True
self.parallel_config.tensor_parallel_size = 2
self.parallel_config.data_parallel_rank_local = 0
self.parallel_config.data_parallel_size_local = 1
@@ -714,28 +490,40 @@ class MockRequest:
class TestKVCacheTaskTracker(unittest.TestCase):
def setUp(self):
self.tracker = KVCacheTaskTracker(tp_rank=0,
local_engine_id="test_engine",
target_count=2)
self.tracker = KVCacheTaskTracker()
def test_update_task_count(self):
self.assertEqual(len(self.tracker.done_task_counts), 0)
def test_update_done_task_count(self):
self.assertEqual(len(self.tracker.finished_requests), 0)
self.assertEqual(len(self.tracker.delayed_free_requests), 0)
self.tracker.update_done_task_count("req1", 0)
self.tracker.update_done_task_count("req1", 1)
current_time = time.time()
self.tracker.add_delayed_request("req_1", current_time)
result = self.tracker.delayed_free_requests
self.assertEqual(len(result), 1)
self.assertEqual(result[0], ("req_1", current_time))
self.assertEqual(len(self.tracker.finished_requests), 1)
self.assertTrue("req1" in self.tracker.finished_requests)
self.tracker.update_done_task_count("req_1")
result_finished = self.tracker.finished_requests
result_delayed = self.tracker.delayed_free_requests
self.assertEqual(result_finished, {"req_1"})
self.assertEqual(len(result_delayed), 0)
finished = self.tracker.get_and_clear_finished_requests()
self.assertEqual(finished, {"req1"})
self.assertEqual(len(self.tracker.finished_requests), 0)
def test_retrieve_expired_requests(self):
current_time = time.time()
self.tracker.add_delayed_request("req_1", current_time - 600)
self.tracker.add_delayed_request("req_2", current_time)
result = self.tracker._retrieve_expired_requests()
self.assertEqual(result, {
"req_1",
})
result_delay = self.tracker.delayed_free_requests
self.assertEqual(len(result_delay), 1)
self.assertEqual(result_delay[0], ("req_2", current_time))
def test_duplicate_task_update(self):
self.tracker.update_done_task_count("req1", 0)
self.tracker.update_done_task_count("req1", 0)
self.tracker.update_done_task_count("req1", 1)
self.tracker.update_done_task_count("req1")
self.tracker.update_done_task_count("req1")
self.tracker.update_done_task_count("req1")
finished = self.tracker.get_and_clear_finished_requests()
self.assertEqual(finished, {"req1"})
@@ -745,6 +533,9 @@ class TestMooncakeConnectorMetadata(unittest.TestCase):
def test_add_new_req(self):
meta = MooncakeConnectorMetadata()
self.assertEqual(len(meta.requests), 0)
self.assertEqual(len(meta.requests_to_send), 0)
meta.add_new_req(request_id="req1",
local_block_ids=[1, 2, 3],
kv_transfer_params={
@@ -802,6 +593,10 @@ class TestMooncakeConnectorSchedulerMatchedTokens(unittest.TestCase):
self.assertEqual(meta.requests["req1"].remote_block_ids, [1, 2, 3])
self.assertEqual(len(self.scheduler._reqs_need_recv), 0)
def test_get_finished_count(self):
count = self.scheduler.get_finished_count()
self.assertEqual(count, 2)
class TestHelperFunctions(unittest.TestCase):