From 03ca2b26ca9ab6b9a12f021b0595a726ee35e223 Mon Sep 17 00:00:00 2001
From: Chao Lei
Date: Mon, 18 Aug 2025 14:30:07 +0800
Subject: [PATCH] [P/D] Mooncake Connector for v1 distributed (#1568)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

### What this PR does / why we need it?
This PR adopts the Mooncake TransferEngine for KV cache registration and implements pull_blocks-style disaggregated prefill.

### Does this PR introduce any user-facing change?
No

### Dependencies
1. CANN dependencies
   Using Mooncake TransferEngine with Ascend Transport requires CANN version 8.2.RC1 or higher (see Mooncake [#502](https://github.com/kvcache-ai/Mooncake/pull/502) for details).
2. vllm-ascend
   This PR depends on changes introduced by #950 (modifications to `model_runner_v1`) and #1361 (updates to `schedule`), both of which have been merged into the `v0.9.1-dev` branch and are expected to land in `main` shortly.

### How was this patch tested?
- vLLM version: v0.10.0
- vLLM main: https://github.com/vllm-project/vllm/commit/1c859a1387286cf650c3bc24fdeac706b97999e8

---------

Signed-off-by: leichao.lc
Co-authored-by: jianzs
Co-authored-by: zzy-ContiLearn <1831242919@qq.com>
Co-authored-by: fems14 <1804143737@qq.com>
Co-authored-by: Dreamerleader <2270923832@qq.com>
Co-authored-by: chris668899 <15105191595@126.com>
Co-authored-by: Pz1116
---
 .../mooncake_connector_deployment_guide.md       |  163 +++
 .../kv_connector/test_mooncake_connector.py      | 1198 +++++++++++++++++
 vllm_ascend/distributed/__init__.py              |    4 +
 vllm_ascend/distributed/mooncake_connector.py    | 1050 +++++++++++++++
 4 files changed, 2415 insertions(+)
 create mode 100644 examples/disaggregated_prefill_v1/mooncake_connector_deployment_guide.md
 create mode 100644 tests/ut/kv_connector/test_mooncake_connector.py
 create mode 100644 vllm_ascend/distributed/mooncake_connector.py

diff --git a/examples/disaggregated_prefill_v1/mooncake_connector_deployment_guide.md b/examples/disaggregated_prefill_v1/mooncake_connector_deployment_guide.md
new file mode 100644
index 0000000..614eca5
--- /dev/null
+++ b/examples/disaggregated_prefill_v1/mooncake_connector_deployment_guide.md
@@ -0,0 +1,163 @@
# Mooncake Connector Deployment Guide

## Environment Dependencies

 * Software:
   * Python >= 3.9, < 3.12
   * CANN >= 8.2.rc1
   * PyTorch >= 2.7.1, torch-npu >= 2.7.1.dev20250724
   * vLLM (same version as vllm-ascend)
   * mooncake-transfer-engine (reference documentation: https://github.com/kvcache-ai/Mooncake/blob/main/doc/zh/ascend_transport.md)

The vLLM version must match the vllm-ascend main branch at the same point in time. For example, as of 2025/07/30,
the versions are:

 * vllm: v0.10.1
 * vllm-ascend: v0.10.1rc1

## Run

### 1. Run `prefill` Node

```
bash run_prefill.sh
```

Content of the `run_prefill.sh` script:

```
export HCCL_EXEC_TIMEOUT=204
export HCCL_CONNECT_TIMEOUT=120
export HCCL_IF_IP=localhost
export GLOO_SOCKET_IFNAME="xxxxxx"
export TP_SOCKET_IFNAME="xxxxxx"
export HCCL_SOCKET_IFNAME="xxxxxx"
export ASCEND_RT_VISIBLE_DEVICES=0,1,2,3

vllm serve "/xxxxx/DeepSeek-V2-Lite-Chat" \
  --host localhost \
  --port 8100 \
  --tensor-parallel-size 2 \
  --seed 1024 \
  --max-model-len 2000 \
  --max-num-batched-tokens 2000 \
  --trust-remote-code \
  --enforce-eager \
  --data-parallel-size 2 \
  --data-parallel-address localhost \
  --data-parallel-rpc-port 9100 \
  --gpu-memory-utilization 0.8 \
  --kv-transfer-config \
  '{"kv_connector": "MooncakeConnectorV1",
    "kv_buffer_device": "npu",
    "kv_role": "kv_producer",
    "kv_parallel_size": 1,
    "kv_port": "20001",
    "engine_id": "0",
    "kv_rank": 0,
    "kv_connector_module_path": "vllm_ascend.distributed.mooncake_connector",
    "kv_connector_extra_config": {
      "prefill": {
        "dp_size": 2,
        "tp_size": 2
      },
      "decode": {
        "dp_size": 2,
        "tp_size": 2
      }
    }
  }'
```

Explanation of the settings:

* `HCCL_EXEC_TIMEOUT`, `HCCL_CONNECT_TIMEOUT`, and `HCCL_IF_IP` are HCCL-related settings.
* Set `GLOO_SOCKET_IFNAME`, `TP_SOCKET_IFNAME`, and `HCCL_SOCKET_IFNAME` to the name of the NIC to use.
* `ASCEND_RT_VISIBLE_DEVICES` selects the cards this node runs on; the total number of cards must equal `dp_size * tp_size`.
* `/xxxxx/DeepSeek-V2-Lite-Chat` is the path of the model to serve.
* `--host`: the IP address the node listens on.
* `--port`: the port the node listens on; it must match the corresponding port passed to the proxy in step 3.
* `--seed`, `--max-model-len`, and `--max-num-batched-tokens`: basic model configuration; set them according to your deployment.
* `--tensor-parallel-size`: the TP size.
* `--data-parallel-size`: the DP size.
* `--data-parallel-address`: the IP address used for data parallelism; set it to the IP address of the node. `--data-parallel-rpc-port`: the RPC port for communication within the DP group.
* `--trust-remote-code`: allow loading the model code shipped with the local checkpoint.
* `--enforce-eager`: disable graph mode and run the model eagerly.
* `--gpu-memory-utilization`: the fraction of each card's device memory that vLLM may use.
* `--kv-transfer-config`: the KV connector configuration. `kv_connector` must be `MooncakeConnectorV1`, `kv_connector_module_path` must be `vllm_ascend.distributed.mooncake_connector`, and `kv_buffer_device` must be `npu` so that KV buffers live on the NPU cards. Set `kv_role` to `kv_producer` on the p (prefill) node and `kv_consumer` on the d (decode) node, `kv_parallel_size` to 1, and `kv_port` to the port the node uses for KV transfer. Set `engine_id` and `kv_rank` to 0 on the p node and to 1 on the d node. Finally, describe the parallel layout of both the p and d nodes in `kv_connector_extra_config`, matching their `--tensor-parallel-size` and `--data-parallel-size`. The per-role differences are summarized right after this list.
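
For quick reference, these are the only `--kv-transfer-config` fields that differ between the two roles in this example (taken from the two scripts in this guide); every other field is identical on both nodes:

```
field        p node (prefill)    d node (decode)
kv_role      "kv_producer"       "kv_consumer"
kv_port      "20001"             "20002"
engine_id    "0"                 "1"
kv_rank      0                   1
```
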
### 2. Run `decode` Node

```
bash run_decode.sh
```

Content of the `run_decode.sh` script:

```
export HCCL_EXEC_TIMEOUT=204
export HCCL_CONNECT_TIMEOUT=120
export HCCL_IF_IP=localhost
export GLOO_SOCKET_IFNAME="xxxxxx"
export TP_SOCKET_IFNAME="xxxxxx"
export HCCL_SOCKET_IFNAME="xxxxxx"
export ASCEND_RT_VISIBLE_DEVICES=4,5,6,7

vllm serve "/xxxxx/DeepSeek-V2-Lite-Chat" \
  --host localhost \
  --port 8200 \
  --tensor-parallel-size 2 \
  --seed 1024 \
  --max-model-len 2000 \
  --max-num-batched-tokens 2000 \
  --trust-remote-code \
  --enforce-eager \
  --data-parallel-size 2 \
  --data-parallel-address localhost \
  --data-parallel-rpc-port 9100 \
  --gpu-memory-utilization 0.8 \
  --kv-transfer-config \
  '{"kv_connector": "MooncakeConnectorV1",
    "kv_buffer_device": "npu",
    "kv_role": "kv_consumer",
    "kv_parallel_size": 1,
    "kv_port": "20002",
    "engine_id": "1",
    "kv_rank": 1,
    "kv_connector_module_path": "vllm_ascend.distributed.mooncake_connector",
    "kv_connector_extra_config": {
      "prefill": {
        "dp_size": 2,
        "tp_size": 2
      },
      "decode": {
        "dp_size": 2,
        "tp_size": 2
      }
    }
  }'
```

### 3. Start the Proxy Server

```
cd /vllm-ascend/examples/disaggregate_prefill_v1/
python load_balance_proxy_server_example.py --host localhost --prefiller-hosts host1 host2 --prefiller-ports 8100 8101 --decoder-hosts host3 host4 --decoder-ports 8200 8201
```

* `--host`: the address the proxy listens on; the host targeted by the curl command in step 4 must match it. By default the proxy serves on port 8000.
* `--prefiller-hosts`: the IP addresses of all p nodes. In an xPyD (x prefill, y decode) scenario, append the address of every prefill instance to this option, separated by spaces.
* `--prefiller-ports`: the ports of all p nodes, i.e. the `--port` values used to start the vLLM services in step 1, listed in the same order and separated by spaces. The list must map one-to-one onto the addresses in `--prefiller-hosts`.
* `--decoder-hosts`: the IP addresses of all d nodes, in the same format as `--prefiller-hosts`.
* `--decoder-ports`: the ports of all d nodes, i.e. the `--port` values from step 2, separated by spaces and mapping one-to-one onto `--decoder-hosts`. The example below shows how the lists pair up.
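
With the example command above, the host and port lists pair up positionally (the hosts `host1` to `host4` are placeholders):

```
--prefiller-hosts host1 host2   --prefiller-ports 8100 8101   ->  prefillers at host1:8100, host2:8101
--decoder-hosts   host3 host4   --decoder-ports   8200 8201   ->  decoders   at host3:8200, host4:8201
```
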
+ + +### 4. Run Inference + +Set the IP address in the inference file to the actual IP address. Set the model variable to the path of the model. Ensure that the path is the same as that in the shell script. + +``` +curl -s http://localhost:8000/v1/completions -H "Content-Type: application/json" -d '{ +"model": "model_path", +"prompt": "Given the accelerating impacts of climate change—including rising sea levels, increasing frequency of extreme weather events, loss of biodiversity, and adverse effects on agriculture and human health—there is an urgent need for a robust, globally coordinated response. However, international efforts are complicated by a range of factors: economic disparities between high-income and low-income countries, differing levels of industrialization, varying access to clean energy technologies, and divergent political systems that influence climate policy implementation. In this context, how can global agreements like the Paris Accord be redesigned or strengthened to not only encourage but effectively enforce emission reduction targets? Furthermore, what mechanisms can be introduced to promote fair and transparent technology transfer, provide adequate financial support for climate adaptation in vulnerable regions, and hold nations accountable without exacerbating existing geopolitical tensions or disproportionately burdening those with historically lower emissions?", +"max_tokens": 256 +}' +``` \ No newline at end of file diff --git a/tests/ut/kv_connector/test_mooncake_connector.py b/tests/ut/kv_connector/test_mooncake_connector.py new file mode 100644 index 0000000..9bca0dc --- /dev/null +++ b/tests/ut/kv_connector/test_mooncake_connector.py @@ -0,0 +1,1198 @@ +import os +import queue +import socket +import sys +import threading +import time +import types +import unittest +from collections import defaultdict, deque +from unittest.mock import MagicMock, patch + +import msgspec +import zmq +from vllm.utils import make_zmq_path +from zmq import Context # type: ignore + +fake_engine = types.ModuleType("mooncake.engine") +fake_engine.TransferEngine = MagicMock() # type: ignore[attr-defined] +sys.modules["mooncake.engine"] = fake_engine + +from vllm_ascend.distributed.mooncake_connector import ( # noqa: E402 + KVCacheRecvingThread, KVCacheSendingThread, KVCacheTaskTracker, + KVConnectorRole, MooncakeAgentMetadata, MooncakeConnector, + MooncakeConnectorMetadata, MooncakeConnectorScheduler, + MooncakeConnectorWorker, ReqMeta, ensure_zmq_recv, ensure_zmq_send, + group_concurrent_contiguous, string_to_int64_hash, zmq_ctx) + +GET_META_MSG = b"get_meta_msg" +DONE_RECVING_MSG = b"done_recving_msg" + + +class TestKVCacheTaskTrackerInit(unittest.TestCase): + + def test_init_basic_properties(self): + tracker = KVCacheTaskTracker(tp_rank=1, + local_engine_id="engine1", + target_count=10) + self.assertEqual(tracker.tp_rank, 1) + self.assertEqual(tracker.local_engine_id, "engine1") + self.assertEqual(tracker.target_count, 10) + self.assertIsInstance(tracker.done_task_lock, type(threading.Lock())) + self.assertIsInstance(tracker.done_task_counts, defaultdict) + self.assertIsInstance(tracker.finished_requests, set) + + def test_socket_path_generation(self): + tracker = KVCacheTaskTracker(tp_rank=1, + local_engine_id="engine42", + target_count=1) + self.assertEqual(tracker.socket_path, + "ipc:///tmp/vllm_mooncake_connector_engine42.ipc") + + @patch("vllm_ascend.distributed.mooncake_connector.threading.Thread") + def test_tp_rank_zero_initialization(self, mock_thread): + tracker = 
KVCacheTaskTracker(tp_rank=0, + local_engine_id="test", + target_count=1) + mock_thread.assert_called_once_with( + target=tracker._listen_for_completion_signals, + daemon=True, + name="KVCacheTaskTrackerListenerThread") + mock_thread.return_value.start.assert_called_once() + self.assertIsNone(tracker.socket) + self.assertTrue(tracker.listener.daemon) + + @patch("vllm_ascend.distributed.mooncake_connector.make_zmq_socket") + @patch("vllm_ascend.distributed.mooncake_connector.logger") + def test_tp_rank_non_zero_initialization(self, mock_logger, + mock_make_zmq_socket): + mock_socket = MagicMock() + mock_make_zmq_socket.return_value = mock_socket + tracker = KVCacheTaskTracker(tp_rank=1, + local_engine_id="test", + target_count=1) + mock_make_zmq_socket.assert_called_once_with( + ctx=unittest.mock.ANY, + path="ipc:///tmp/vllm_mooncake_connector_test.ipc", + socket_type=zmq.PUSH, # type: ignore + bind=False) + mock_logger.info.assert_called_once_with( + "Connecting to transfer socket at %s", + "ipc:///tmp/vllm_mooncake_connector_test.ipc") + self.assertIsNone(tracker.listener) + self.assertEqual(tracker.socket, mock_socket) + + +class TestKVCacheTaskTrackerListenMethod(unittest.TestCase): + + def setUp(self): + self.tp_rank = 0 + self.local_engine_id = "test_engine_ut" + self.target_count = 3 + self.tracker = KVCacheTaskTracker(self.tp_rank, self.local_engine_id, + self.target_count) + self.original_listen = self.tracker._listen_for_completion_signals + + def tearDown(self): + self.tracker._listen_for_completion_signals = self.original_listen + Context.instance().term() + time.sleep(0.1) + + def test_normal_message_processing(self): + listener_thread = threading.Thread( + target=self.tracker._listen_for_completion_signals, daemon=True) + listener_thread.start() + time.sleep(0.2) + test_messages = [("request_001", 1), ("request_001", 2), + ("request_002", 0), ("request_003", 1)] + ctx = Context() + sender_socket = ctx.socket(zmq.PUSH) # type: ignore + sender_socket.connect(self.tracker.socket_path) + for msg in test_messages: + sender_socket.send_pyobj(msg) + time.sleep(0.05) + sender_socket.close() + time.sleep(0.2) + + with self.tracker.done_task_lock: + self.assertEqual(len(self.tracker.done_task_counts["request_001"]), + 2) + self.assertIn(1, self.tracker.done_task_counts["request_001"]) + self.assertIn(2, self.tracker.done_task_counts["request_001"]) + self.assertEqual(len(self.tracker.done_task_counts["request_002"]), + 1) + self.assertIn(0, self.tracker.done_task_counts["request_002"]) + self.assertEqual(len(self.tracker.done_task_counts["request_003"]), + 1) + self.assertIn(1, self.tracker.done_task_counts["request_003"]) + + @patch("vllm_ascend.distributed.mooncake_connector.make_zmq_socket", + autospec=True) + def test_listen_with_timeout(self, mock_make_socket): + mock_socket = MagicMock() + + def mock_recv(): + start = time.time() + while time.time() - start < 0.5: + time.sleep(0.01) + return ("req1", 0) + + mock_socket.recv_pyobj = mock_recv + mock_make_socket.return_value = mock_socket + + test_thread = threading.Thread( + target=self.tracker._listen_for_completion_signals, daemon=True) + test_thread.start() + test_thread.join(timeout=1.0) + mock_make_socket.assert_called_once() + + +class TestKVCacheTaskTrackerTP(unittest.TestCase): + + def setUp(self): + self.local_engine_id = "test_engine" + self.target_count = 3 + + def test_update_done_task_count_tp_rank_0(self): + tracker = KVCacheTaskTracker(tp_rank=0, + local_engine_id=self.local_engine_id, + 
target_count=self.target_count) + test_request_id = "test_req_001" + test_tp_rank = 1 + tracker.update_done_task_count(test_request_id, test_tp_rank) + with tracker.done_task_lock: + self.assertEqual(len(tracker.done_task_counts[test_request_id]), 1) + self.assertIn(test_tp_rank, + tracker.done_task_counts[test_request_id]) + + @patch("vllm_ascend.distributed.mooncake_connector.make_zmq_socket", + autospec=True) + def test_update_done_task_count_non_zero_tp(self, mock_make_socket): + mock_socket = MagicMock() + mock_make_socket.return_value = mock_socket + tracker = KVCacheTaskTracker(tp_rank=1, + local_engine_id=self.local_engine_id, + target_count=self.target_count) + test_request_id = "test_req_002" + test_tp_rank = 1 + tracker.update_done_task_count(test_request_id, test_tp_rank) + mock_socket.send_pyobj.assert_called_once_with( + (test_request_id, test_tp_rank)) + with tracker.done_task_lock: + self.assertNotIn(test_request_id, tracker.done_task_counts) + + @patch("vllm_ascend.distributed.mooncake_connector.logger", autospec=True) + @patch("vllm_ascend.distributed.mooncake_connector.make_zmq_socket", + autospec=True) + def test_update_done_task_count_logging(self, mock_make_socket, + mock_logger): + mock_socket = MagicMock() + mock_make_socket.return_value = mock_socket + tracker = KVCacheTaskTracker(tp_rank=2, + local_engine_id=self.local_engine_id, + target_count=self.target_count) + test_request_id = "test_req_003" + tracker.update_done_task_count(test_request_id, 2) + mock_logger.debug.assert_called_once_with( + "Sent done signal for request %s to tp 0", test_request_id) + + @patch("vllm_ascend.distributed.mooncake_connector.make_zmq_socket", + autospec=True) + def test_update_multiple_calls(self, mock_make_socket): + mock_socket = MagicMock() + mock_make_socket.return_value = mock_socket + tracker = KVCacheTaskTracker(tp_rank=1, + local_engine_id=self.local_engine_id, + target_count=self.target_count) + test_data = [("req1", 1), ("req1", 1), ("req2", 1)] + for req_id, rank in test_data: + tracker.update_done_task_count(req_id, rank) + self.assertEqual(mock_socket.send_pyobj.call_count, 3) + mock_socket.send_pyobj.assert_called_with(("req2", 1)) + + +class TestGetAndClearFinishedSingleRequests(unittest.TestCase): + + def setUp(self): + self.tracker = KVCacheTaskTracker(tp_rank=0, + local_engine_id="test", + target_count=3) + self.tracker.finished_requests = set() + self.tracker.done_task_counts = defaultdict(set) + self.tracker.done_task_lock = threading.Lock() + + def test_empty_requests(self): + result = self.tracker.get_and_clear_finished_requests() + self.assertEqual(result, set()) + self.assertEqual(len(self.tracker.finished_requests), 0) + + def test_single_request(self): + self.tracker.finished_requests = {"req_123"} + result = self.tracker.get_and_clear_finished_requests() + self.assertEqual(result, {"req_123"}) + self.assertEqual(len(self.tracker.finished_requests), 0) + + def test_multiple_requests(self): + self.tracker.finished_requests = {"req_1", "req_2", "req_3"} + result = self.tracker.get_and_clear_finished_requests() + self.assertSetEqual(result, {"req_1", "req_2", "req_3"}) + self.assertEqual(len(self.tracker.finished_requests), 0) + + @patch("vllm_ascend.distributed.mooncake_connector.logger") + def test_concurrent_access(self, mock_logger): + from concurrent.futures import ThreadPoolExecutor + self.tracker.finished_requests = {"req_1", "req_2"} + with ThreadPoolExecutor(max_workers=3) as executor: + futures = [ + 
executor.submit(self.tracker.get_and_clear_finished_requests) + for _ in range(3) + ] + results = [f.result() for f in futures] + self.assertEqual(sum(1 for r in results if r), 1) + self.assertEqual(len(self.tracker.finished_requests), 0) + + def test_after_increment(self): + self.tracker._increment_task_count("req_123", 0) + self.tracker._increment_task_count("req_123", 1) + self.tracker._increment_task_count("req_123", 2) + result = self.tracker.get_and_clear_finished_requests() + self.assertEqual(result, {"req_123"}) + self.assertEqual(self.tracker.get_and_clear_finished_requests(), set()) + + +class TestKVCacheSendingThreadInit(unittest.TestCase): + + def setUp(self): + self.common_args = { + 'tp_rank': 1, + 'decode_tp_size': 4, + 'local_engine_id': 'engine_1', + 'side_channel_host': 'localhost', + 'side_channel_port': 5555, + 'metadata': MagicMock(), + 'ready_event': threading.Event() + } + self.threads = [] + + def tearDown(self): + for thread in self.threads: + if hasattr(thread, 'task_tracker') and hasattr( + thread.task_tracker, 'socket'): + thread.task_tracker.socket.close() + if hasattr(thread, 'is_alive') and thread.is_alive(): + thread.join(timeout=0.1) + + @patch('vllm_ascend.distributed.mooncake_connector.KVCacheTaskTracker') + def test_initialization_basic(self, mock_tracker): + thread = KVCacheSendingThread(**self.common_args) + self.threads.append(thread) + self.assertEqual(thread.tp_rank, 1) + self.assertEqual(thread.decode_tp_size, 4) + self.assertEqual(thread.local_engine_id, 'engine_1') + mock_tracker.assert_called_once() + args = mock_tracker.call_args[0] + kwargs = mock_tracker.call_args[1] + if args: + self.assertEqual(args[0], 1) + self.assertEqual(args[1], 'engine_1') + self.assertEqual(args[2], 4) + else: + self.assertEqual(kwargs['tp_rank'], 1) + self.assertEqual(kwargs['local_engine_id'], 'engine_1') + self.assertEqual(kwargs['target_count'], 4) + + @patch('vllm_ascend.distributed.mooncake_connector.KVCacheTaskTracker') + def test_task_tracker_initialization(self, mock_tracker): + args = self.common_args.copy() + args.update({ + 'tp_rank': 2, + 'decode_tp_size': 8, + 'local_engine_id': 'engine_2' + }) + thread = KVCacheSendingThread(**args) + self.threads.append(thread) + mock_tracker.assert_called_once() + call_args = mock_tracker.call_args[0] + call_kwargs = mock_tracker.call_args[1] + if call_args: + self.assertEqual(call_args[0], 2) + self.assertEqual(call_args[1], 'engine_2') + self.assertEqual(call_args[2], 8) + else: + self.assertEqual(call_kwargs['tp_rank'], 2) + self.assertEqual(call_kwargs['local_engine_id'], 'engine_2') + self.assertEqual(call_kwargs['target_count'], 8) + + def test_thread_daemon_property(self): + thread = KVCacheSendingThread(**self.common_args) + self.threads.append(thread) + self.assertTrue(thread.daemon) + + def test_thread_name_format(self): + thread = KVCacheSendingThread(**self.common_args) + self.threads.append(thread) + self.assertEqual(thread.name, "KVCacheSendingThread") + + def test_ready_event_reference(self): + custom_event = threading.Event() + args = self.common_args.copy() + args['ready_event'] = custom_event + thread = KVCacheSendingThread(**args) + self.threads.append(thread) + self.assertIs(thread.ready_event, custom_event) + + +class TestGetAndClearFinishedRequests(unittest.TestCase): + + def setUp(self): + self.common_args = { + 'tp_rank': 1, + 'decode_tp_size': 4, + 'local_engine_id': 'engine_1', + 'side_channel_host': 'localhost', + 'side_channel_port': 5555, + 'metadata': { + "test": "metadata" + }, + 
'ready_event': threading.Event() + } + self.thread = KVCacheSendingThread(**self.common_args) + + @patch.object(KVCacheTaskTracker, 'get_and_clear_finished_requests') + def test_get_and_clear_finished_requests(self, mock_get_clear): + expected_requests = {'req1', 'req2'} + mock_get_clear.return_value = expected_requests + result = self.thread.get_and_clear_finished_requests() + mock_get_clear.assert_called_once() + self.assertEqual(result, expected_requests) + + +class TestKVCacheSendingThread(unittest.TestCase): + + def test_run_handles_get_meta_and_done_recv_msgs(self): + ready_event = threading.Event() + metadata = MooncakeAgentMetadata( + engine_id="engine1", + te_rpc_port=9090, + kv_caches_base_addr=[12345678], + num_blocks=2, + ) + host = "127.0.0.1" + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + s.bind(('', 0)) + free_port = s.getsockname()[1] + + thread = KVCacheSendingThread( + tp_rank=0, + decode_tp_size=1, + local_engine_id="engine1", + side_channel_host=host, + side_channel_port=free_port, + metadata=metadata, + ready_event=ready_event, + ) + thread.start() + self.assertTrue(ready_event.wait(timeout=3), + "Server thread startup timeout") + + context = zmq.Context() # type: ignore + sock = context.socket(zmq.DEALER) # type: ignore + sock.connect(f"tcp://{host}:{free_port}") + encoder = msgspec.msgpack.Encoder() + decoder = msgspec.msgpack.Decoder(type=MooncakeAgentMetadata) + + sock.send_multipart([b"", encoder.encode((GET_META_MSG, ))]) + frames = sock.recv_multipart() + self.assertEqual(frames[0], b"") + meta = decoder.decode(frames[1]) + self.assertEqual(meta.engine_id, "engine1") + self.assertEqual(meta.kv_caches_base_addr, [12345678]) + self.assertEqual(meta.num_blocks, 2) + + req_id = "request_42" + sock.send_multipart( + [b"", encoder.encode((DONE_RECVING_MSG, req_id, 0))]) + frames = sock.recv_multipart() + self.assertEqual(frames[0], b"") + self.assertEqual(frames[1], b"ACK") + self.assertIn(req_id, thread.task_tracker.finished_requests) + + sock.close() + context.term() + + +class TestKVCacheRecvingThreadBasic(unittest.TestCase): + + def setUp(self): + self.engine = MagicMock() + self.ready_event = threading.Event() + self.thread = KVCacheRecvingThread( + tp_rank=0, + tp_size=4, + engine=self.engine, + local_engine_id="local_engine", + local_handshake_port=5555, + local_kv_caches_base_addr=[0x1000, 0x2000], + block_len=[1024, 2048], + ready_event=self.ready_event) + + def test_add_request(self): + test_req = { + "request_id": "req1", + "local_block_ids": [1, 2], + "remote_block_ids": [3, 4], + "remote_engine_id": "remote_engine", + "remote_host": "localhost", + "remote_handshake_port": 6666, + } + self.thread.add_request(**test_req) + queued = self.thread.request_queue.get_nowait() + self.assertEqual(queued["request_id"], "req1") + self.assertEqual(queued["remote_host"], "localhost") + + @patch.object(KVCacheTaskTracker, 'get_and_clear_finished_requests') + def test_get_finished_requests(self, mock_tracker): + mock_tracker.return_value = {"req1", "req2"} + result = self.thread.get_and_clear_finished_requests() + self.assertEqual(result, {"req1", "req2"}) + + +class TestSocketManagement(unittest.TestCase): + + def setUp(self): + self.engine = MagicMock() + self.ready_event = threading.Event() + self.thread = KVCacheRecvingThread( + tp_rank=0, + tp_size=4, + engine=self.engine, + local_engine_id="local_engine", + local_handshake_port=5555, + local_kv_caches_base_addr=[0x1000, 0x2000], + block_len=[1024, 2048], + ready_event=self.ready_event) + 
self.thread.remote_sockets = defaultdict(deque) + self.thread.remote_poller = MagicMock() + + @patch('vllm_ascend.distributed.mooncake_connector.zmq.Context') + @patch('vllm_ascend.distributed.mooncake_connector.make_zmq_socket') + def test_get_remote_socket(self, mock_make_socket, mock_context): + mock_sock = MagicMock() + mock_make_socket.return_value = mock_sock + test_host = "test_host" + test_port = 12345 + + sock = self.thread._get_remote_socket(test_host, test_port) + + self.assertEqual(sock, mock_sock) + mock_make_socket.assert_called_once() + args, kwargs = mock_make_socket.call_args + self.assertEqual(kwargs.get('path'), 'tcp://test_host:12345') + self.assertEqual(kwargs.get('socket_type'), zmq.REQ) # type: ignore + self.assertFalse(kwargs.get('bind', True)) + self.thread.remote_poller.register.assert_called_with( + mock_sock, zmq.POLLIN) # type: ignore + + def test_return_socket_to_pool(self): + mock_sock = MagicMock() + test_host = "test_host" + test_port = 12345 + test_path = make_zmq_path("tcp", test_host, test_port) + + self.thread._return_remote_socket(mock_sock, test_host, test_port) + + self.assertEqual(len(self.thread.remote_sockets[test_path]), 1) + self.assertEqual(self.thread.remote_sockets[test_path][0], mock_sock) + self.thread.remote_poller.register.assert_not_called() + + +class TestCoreFunctionality(unittest.TestCase): + + def setUp(self): + self.engine = MagicMock() + self.ready_event = threading.Event() + self.mock_queue = MagicMock() + self.thread = KVCacheRecvingThread( + tp_rank=0, + tp_size=4, + engine=self.engine, + local_engine_id="local_engine", + local_handshake_port=5555, + local_kv_caches_base_addr=[0x1000, 0x2000], + block_len=[1024, 2048], + ready_event=self.ready_event) + self.thread.request_queue = self.mock_queue + self.test_req = { + "request_id": "req1", + "local_block_ids": [1, 2], + "remote_block_ids": [3, 4], + "remote_engine_id": "remote_engine", + "remote_host": "localhost", + "remote_handshake_port": 6666, + "remote_transfer_port": 7777 + } + self.thread.task_tracker = MagicMock() + self.engine.batch_transfer_sync_read.return_value = 0 + self.thread.remote_te_port = {"remote_engine": {6666: 7777}} + + @patch.object(KVCacheRecvingThread, '_transfer_kv_cache') + @patch.object(KVCacheRecvingThread, '_send_done_recv_signal') + def test_handle_request(self, mock_send, mock_transfer): + self.thread._handle_request(self.test_req) + mock_transfer.assert_called_once_with(self.test_req) + mock_send.assert_called_once_with("req1", "localhost", 6666) + self.thread.task_tracker.update_done_task_count.assert_called_once_with( + "req1", self.thread.tp_rank) + self.mock_queue.task_done.assert_called_once() + + @patch.object(KVCacheRecvingThread, '_get_remote_metadata') + def test_transfer_kv_cache(self, mock_get_meta): + self.thread.kv_caches_base_addr["remote_engine"] = { + 6666: [0x3000, 0x4000] + } + + self.thread._transfer_kv_cache(self.test_req) + + self.engine.batch_transfer_sync_read.assert_called_once() + call_args, call_kwargs = self.engine.batch_transfer_sync_read.call_args + self.assertEqual(call_args[0], "localhost:7777") + self.assertIsInstance(call_args[1], list) + self.assertIsInstance(call_args[2], list) + self.assertIsInstance(call_args[3], list) + self.assertEqual(len(call_args[1]), len(call_args[2])) + self.assertEqual(len(call_args[1]), len(call_args[3])) + mock_get_meta.assert_not_called() + + def test_transfer_kv_cache_failure(self): + self.engine.batch_transfer_sync_read.return_value = -1 + 
self.thread.kv_caches_base_addr["remote_engine"] = { + 6666: [0x3000, 0x4000] + } + + with self.assertRaises(RuntimeError): + self.thread._transfer_kv_cache(self.test_req) + + +class TestMetadataHandling(unittest.TestCase): + + def setUp(self): + self.engine = MagicMock() + self.ready_event = threading.Event() + self.thread = KVCacheRecvingThread( + tp_rank=0, + tp_size=4, + engine=self.engine, + local_engine_id="local_engine", + local_handshake_port=5555, + local_kv_caches_base_addr=[0x1000, 0x2000], + block_len=[1024, 2048], + ready_event=self.ready_event) + self.test_metadata = MooncakeAgentMetadata( + engine_id="remote_engine", + te_rpc_port=9090, + kv_caches_base_addr=[0x3000, 0x4000], + num_blocks=2) + + @patch('vllm_ascend.distributed.mooncake_connector.ensure_zmq_send') + @patch('vllm_ascend.distributed.mooncake_connector.ensure_zmq_recv') + def test_get_remote_metadata_success(self, mock_recv, mock_send): + mock_recv.return_value = msgspec.msgpack.encode(self.test_metadata) + + with patch.object(self.thread, '_get_remote_socket') as mock_get_socket, \ + patch.object(self.thread, '_return_remote_socket') as mock_return_socket: + mock_socket = MagicMock() + mock_get_socket.return_value = mock_socket + + self.thread._get_remote_metadata("host1", 5555) + + mock_get_socket.assert_called_once_with("host1", 5555) + mock_return_socket.assert_called_once_with(mock_socket, "host1", + 5555) + mock_send.assert_called_once_with( + mock_socket, self.thread.encoder.encode((GET_META_MSG, ""))) + mock_recv.assert_called_once_with(mock_socket, + self.thread.remote_poller) + self.assertEqual( + self.thread.kv_caches_base_addr["remote_engine"][5555], + [0x3000, 0x4000]) + + @patch('vllm_ascend.distributed.mooncake_connector.ensure_zmq_send') + @patch('vllm_ascend.distributed.mooncake_connector.ensure_zmq_recv', + side_effect=Exception("Network error")) + def test_get_remote_metadata_failure(self, mock_recv, mock_send): + with patch.object(self.thread, '_get_remote_socket') as mock_get_socket, \ + patch.object(self.thread, '_return_remote_socket') as mock_return_socket: + mock_socket = MagicMock() + mock_get_socket.return_value = mock_socket + + with self.assertRaises(Exception) as context: + self.thread._get_remote_metadata("host1", 5555) + + self.assertEqual(str(context.exception), "Network error") + mock_return_socket.assert_called_once() + + +class TestMainThreadLoop(unittest.TestCase): + + def setUp(self): + self.engine = MagicMock() + self.ready_event = threading.Event() + self.thread = KVCacheRecvingThread( + tp_rank=0, + tp_size=4, + engine=self.engine, + local_engine_id="local_engine", + local_handshake_port=5555, + local_kv_caches_base_addr=[0x1000, 0x2000], + block_len=[1024, 2048], + ready_event=self.ready_event) + self.thread.request_queue = queue.Queue() + + @patch.object(KVCacheRecvingThread, '_handle_request') + def test_run_loop_normal(self, mock_handle): + test_request = { + "request_id": "req1", + "local_block_ids": [1, 2], + "remote_block_ids": [3, 4], + "remote_engine_id": "remote_engine", + "remote_host": "localhost", + "remote_handshake_port": 6666, + "remote_transfer_port": 7777 + } + + self.thread.request_queue.put(test_request) + self.thread.request_queue.put(None) + + self.thread.start() + time.sleep(0.1) + self.thread.join(timeout=1.0) + + self.assertTrue(self.thread.ready_event.is_set()) + mock_handle.assert_called_once_with(test_request) + self.assertTrue(self.thread.request_queue.empty()) + + +class MockVllmConfig: + + def __init__(self): + self.parallel_config = 
MagicMock() + self.cache_config = MagicMock() + self.kv_transfer_config = MagicMock() + self.parallel_config.tensor_parallel_size = 2 + self.parallel_config.data_parallel_rank_local = 0 + self.parallel_config.data_parallel_size_local = 1 + self.cache_config.block_size = 16 + self.kv_transfer_config.kv_port = 5000 + self.kv_transfer_config.kv_role = 'kv_producer' + self.kv_transfer_config.get_from_extra_config = MagicMock() + self.kv_transfer_config.get_from_extra_config.side_effect = lambda k, d: { + "prefill": { + "tp_size": 2, + "dp_size": 1 + }, + "decode": { + "tp_size": 2, + "dp_size": 1 + } + }.get(k, d) + + +class MockRequest: + + def __init__(self, + request_id, + prompt_token_ids=None, + kv_transfer_params=None, + status=None): + self.request_id = request_id + self.prompt_token_ids = prompt_token_ids or [1, 2, 3, 4] + self.kv_transfer_params = kv_transfer_params or {} + self.status = status or "running" + self.output_token_ids = [101, 102] + + +class TestKVCacheTaskTracker(unittest.TestCase): + + def setUp(self): + self.tracker = KVCacheTaskTracker(tp_rank=0, + local_engine_id="test_engine", + target_count=2) + + def test_update_task_count(self): + self.assertEqual(len(self.tracker.done_task_counts), 0) + self.assertEqual(len(self.tracker.finished_requests), 0) + + self.tracker.update_done_task_count("req1", 0) + self.tracker.update_done_task_count("req1", 1) + + self.assertEqual(len(self.tracker.finished_requests), 1) + self.assertTrue("req1" in self.tracker.finished_requests) + + finished = self.tracker.get_and_clear_finished_requests() + self.assertEqual(finished, {"req1"}) + self.assertEqual(len(self.tracker.finished_requests), 0) + + def test_duplicate_task_update(self): + self.tracker.update_done_task_count("req1", 0) + self.tracker.update_done_task_count("req1", 0) + self.tracker.update_done_task_count("req1", 1) + + finished = self.tracker.get_and_clear_finished_requests() + self.assertEqual(finished, {"req1"}) + + +class TestMooncakeConnectorMetadata(unittest.TestCase): + + def test_add_new_req(self): + meta = MooncakeConnectorMetadata() + meta.add_new_req(request_id="req1", + local_block_ids=[1, 2, 3], + kv_transfer_params={ + "remote_block_ids": [4, 5, 6], + "remote_engine_id": "remote_engine", + "remote_host": "localhost", + "remote_port": 5000 + }) + + self.assertEqual(len(meta.requests), 1) + req_meta = meta.requests["req1"] + self.assertIsInstance(req_meta, ReqMeta) + self.assertEqual(req_meta.local_block_ids, [1, 2, 3]) + self.assertEqual(req_meta.remote_block_ids, [4, 5, 6]) + self.assertEqual(req_meta.remote_engine_id, "remote_engine") + self.assertEqual(req_meta.remote_host, "localhost") + self.assertEqual(req_meta.remote_port, 5000) + + +class TestMooncakeConnectorSchedulerMatchedTokens(unittest.TestCase): + + def setUp(self): + config = MockVllmConfig() + self.scheduler = MooncakeConnectorScheduler(config, "test_engine") + + def test_get_num_new_matched_tokens(self): + request = MockRequest("req1") + tokens, async_flag = self.scheduler.get_num_new_matched_tokens( + request, 0) + self.assertEqual(tokens, 0) + self.assertFalse(async_flag) + + request.kv_transfer_params = {"do_remote_prefill": True} + tokens, async_flag = self.scheduler.get_num_new_matched_tokens( + request, 0) + self.assertEqual(tokens, 3) + self.assertTrue(async_flag) + + def test_build_connector_meta(self): + request = MockRequest("req1") + blocks_mock = MagicMock() + blocks_mock.get_unhashed_block_ids.return_value = [4, 5, 6] + self.scheduler._reqs_need_recv["req1"] = (request, [4, 5, 6]) + 
request.kv_transfer_params = { + "remote_block_ids": [1, 2, 3], + "remote_engine_id": "remote", + "remote_host": "localhost", + "remote_port": 5000 + } + + meta = self.scheduler.build_connector_meta(MagicMock()) + self.assertIsInstance(meta, MooncakeConnectorMetadata) + self.assertEqual(len(meta.requests), 1) + self.assertEqual(meta.requests["req1"].local_block_ids, [4, 5, 6]) + self.assertEqual(meta.requests["req1"].remote_block_ids, [1, 2, 3]) + self.assertEqual(len(self.scheduler._reqs_need_recv), 0) + + +class TestHelperFunctions(unittest.TestCase): + + def test_group_concurrent_contiguous(self): + src: list[int] = [1, 2, 3, 5, 6] + dst: list[int] = [10, 11, 12, 14, 15] + + src_groups, dst_groups = group_concurrent_contiguous(src, dst) + + self.assertEqual(len(src_groups), 2) + self.assertEqual(src_groups[0], [1, 2, 3]) + self.assertEqual(src_groups[1], [5, 6]) + self.assertEqual(dst_groups[0], [10, 11, 12]) + self.assertEqual(dst_groups[1], [14, 15]) + + def test_group_concurrent_contiguous_empty(self): + src: list[int] = [] + dst: list[int] = [] + src_groups, dst_groups = group_concurrent_contiguous(src, dst) + self.assertEqual(src_groups, []) + self.assertEqual(dst_groups, []) + + def test_string_to_int64_hash(self): + hash1 = string_to_int64_hash("test_string") + hash2 = string_to_int64_hash("test_string") + self.assertEqual(hash1, hash2) + + hash3 = string_to_int64_hash("different_string") + self.assertNotEqual(hash1, hash3) + + +class TestMooncakeConnectorForScheduler(unittest.TestCase): + + def test_scheduler_role(self): + config = MockVllmConfig() + connector = MooncakeConnector(config, KVConnectorRole.SCHEDULER) + self.assertIsNotNone(connector.connector_scheduler) + self.assertIsNone(connector.connector_worker) + + @patch.object(MooncakeConnectorScheduler, "get_num_new_matched_tokens") + def test_scheduler_methods(self, mock_method): + config = MockVllmConfig() + connector = MooncakeConnector(config, KVConnectorRole.SCHEDULER) + request = MockRequest("req1") + connector.get_num_new_matched_tokens(request, 0) + mock_method.assert_called_once_with(request, 0) + + +class MockKVCacheBlocks: + + def get_unhashed_block_ids(self): + return [4, 5, 6] + + +class MockSchedulerOutput: + pass + + +class MockForwardContext: + pass + + +class TestMooncakeConnector(unittest.TestCase): + + def setUp(self): + self.config = MockVllmConfig() + os.environ["ASCEND_RT_VISIBLE_DEVICES"] = "0,1" + + def test_scheduler_initialization(self): + connector = MooncakeConnector(self.config, KVConnectorRole.SCHEDULER) + self.assertIsNotNone(connector.connector_scheduler) + self.assertIsNone(connector.connector_worker) + + @patch.object(MooncakeConnectorScheduler, "get_num_new_matched_tokens") + def test_get_num_new_matched_tokens(self, mock_method): + connector = MooncakeConnector(self.config, KVConnectorRole.SCHEDULER) + request = MockRequest("req1") + connector.get_num_new_matched_tokens(request, 0) + mock_method.assert_called_once_with(request, 0) + + @patch.object(MooncakeConnectorScheduler, "update_state_after_alloc") + def test_update_state_after_alloc(self, mock_method): + connector = MooncakeConnector(self.config, KVConnectorRole.SCHEDULER) + request = MockRequest("req1") + blocks = MockKVCacheBlocks() + connector.update_state_after_alloc(request, blocks, 3) + mock_method.assert_called_once_with(request, blocks, 3) + + @patch.object(MooncakeConnectorScheduler, "build_connector_meta") + def test_build_connector_meta(self, mock_method): + connector = MooncakeConnector(self.config, 
KVConnectorRole.SCHEDULER) + scheduler_output = MockSchedulerOutput() + connector.build_connector_meta(scheduler_output) + mock_method.assert_called_once_with(scheduler_output) + + @patch.object(MooncakeConnectorScheduler, "request_finished") + def test_request_finished(self, mock_method): + connector = MooncakeConnector(self.config, KVConnectorRole.SCHEDULER) + request = MockRequest("req1") + connector.request_finished(request, [1, 2, 3]) + mock_method.assert_called_once_with(request, [1, 2, 3]) + + +class TestMooncakeConnectorScheduler(unittest.TestCase): + + def setUp(self): + self.config = MockVllmConfig() + self.scheduler = MooncakeConnectorScheduler(self.config, "test_engine") + + def test_get_num_new_matched_tokens_no_remote_prefill(self): + request = MockRequest("req1") + tokens, async_flag = self.scheduler.get_num_new_matched_tokens( + request, 0) + self.assertEqual(tokens, 0) + self.assertFalse(async_flag) + + def test_get_num_new_matched_tokens_with_remote_prefill(self): + request = MockRequest("req1", + kv_transfer_params={"do_remote_prefill": True}) + tokens, async_flag = self.scheduler.get_num_new_matched_tokens( + request, 0) + self.assertEqual(tokens, 3) + self.assertTrue(async_flag) + + def test_update_state_after_alloc_no_remote_prefill(self): + request = MockRequest("req1") + blocks = MagicMock() + self.scheduler.update_state_after_alloc(request, blocks, 0) + self.assertEqual(len(self.scheduler._reqs_need_recv), 0) + + def test_update_state_after_alloc_with_remote_prefill(self): + request = MockRequest("req1", + kv_transfer_params={ + "do_remote_prefill": True, + "remote_block_ids": [1, 2, 3], + "remote_engine_id": "remote", + "remote_host": "localhost", + "remote_port": 5000 + }) + blocks = MockKVCacheBlocks() + self.scheduler.update_state_after_alloc(request, blocks, 3) + self.assertEqual(len(self.scheduler._reqs_need_recv), 1) + self.assertEqual(self.scheduler._reqs_need_recv["req1"][0], request) + self.assertEqual(self.scheduler._reqs_need_recv["req1"][1], [4, 5, 6]) + + def test_request_finished_no_remote_decode(self): + request = MockRequest("req1") + delay_free, params = self.scheduler.request_finished( + request, [1, 2, 3]) + self.assertFalse(delay_free) + self.assertIsNone(params) + + +class TestUtils(unittest.TestCase): + + def test_string_to_int64_hash(self): + h1 = string_to_int64_hash("hello") + h2 = string_to_int64_hash("hello") + h3 = string_to_int64_hash("world") + self.assertEqual(h1, h2) + self.assertNotEqual(h1, h3) + self.assertIsInstance(h1, int) + + def test_group_concurrent_contiguous(self): + src: list[int] = [1, 2, 3, 5, 6] + dst: list[int] = [10, 11, 12, 20, 21] + src_g, dst_g = group_concurrent_contiguous(src, dst) + self.assertEqual(src_g, [[1, 2, 3], [5, 6]]) + self.assertEqual(dst_g, [[10, 11, 12], [20, 21]]) + + def test_group_empty(self): + src_g, dst_g = group_concurrent_contiguous([], []) + self.assertEqual(src_g, []) + self.assertEqual(dst_g, []) + + def test_zmq_ctx_invalid_type(self): + with self.assertRaises(ValueError): + with zmq_ctx("INVALID", "tcp://127.0.0.1:5555"): + pass + + @patch("vllm_ascend.distributed.mooncake_connector.make_zmq_socket") + def test_zmq_ctx_ok(self, mock_make_socket): + mock_socket = MagicMock() + mock_make_socket.return_value = mock_socket + with zmq_ctx(zmq.REQ, "tcp://localhost:1234") as s: # type: ignore + self.assertEqual(s, mock_socket) + + @patch("vllm_ascend.distributed.mooncake_connector.logger") + def test_ensure_zmq_send_success(self, mock_logger): + mock_socket = MagicMock() + 
ensure_zmq_send(mock_socket, b"hello") + mock_socket.send.assert_called_once_with(b"hello") + + @patch("vllm_ascend.distributed.mooncake_connector.logger") + def test_ensure_zmq_send_retry_and_fail(self, mock_logger): + mock_socket = MagicMock() + mock_socket.send.side_effect = zmq.ZMQError( # type: ignore + "send failed") + with self.assertRaises(RuntimeError): + ensure_zmq_send(mock_socket, b"hello", max_retries=2) + self.assertEqual(mock_socket.send.call_count, 2) + + @patch("vllm_ascend.distributed.mooncake_connector.logger") + def test_ensure_zmq_recv_success(self, mock_logger): + mock_socket = MagicMock() + mock_socket.recv.return_value = b"response" + mock_poller = MagicMock() + mock_poller.poll.return_value = [ + (mock_socket, zmq.POLLIN) # type: ignore + ] + data = ensure_zmq_recv(mock_socket, mock_poller) + self.assertEqual(data, b"response") + + @patch("vllm_ascend.distributed.mooncake_connector.logger") + def test_ensure_zmq_recv_timeout_and_fail(self, mock_logger): + mock_socket = MagicMock() + mock_poller = MagicMock() + mock_poller.poll.return_value = [] + with self.assertRaises(RuntimeError): + ensure_zmq_recv(mock_socket, + mock_poller, + timeout=0.01, + max_retries=2) + + +class MockMooncakeAgentMetadata: + + def __init__(self, **kwargs): + pass + + +class MockMooncakeConnectorMetadata: + + def __init__(self): + self.requests = {} + + +class MockKVCacheSendingThread(threading.Thread): + + def __init__(self, *args, **kwargs): + super().__init__() + self.daemon = True + self._finished_requests = set() + + def get_and_clear_finished_requests(self): + return self._finished_requests + + def start(self): + pass + + +class MockKVCacheRecvingThread(threading.Thread): + + def __init__(self, *args, **kwargs): + super().__init__() + self.daemon = True + self._finished_requests = set() + self.add_request = MagicMock() + + def get_and_clear_finished_requests(self): + return self._finished_requests + + def start(self): + pass + + +class MockTensor: + + def __init__(self, *args, **kwargs): + self.size = MagicMock(return_value=(10, 16, 8, 16)) + self.element_size = MagicMock(return_value=4) + self.shape = (10, 16, 8, 16) + self.data_ptr = MagicMock(return_value=0x1000) + + +mock_envs_ascend = MagicMock() +mock_envs_ascend.MOONCAKE_CONNECTOR_PROTOCOL = "mock_protocol" + +mock_logger = MagicMock() + + +class MockTransferEngine: + + def initialize(self, *args, **kwargs): + return 0 + + def register_memory(self, *args, **kwargs): + return 1 + + +class MockEnvsAscend: + MOONCAKE_CONNECTOR_PROTOCOL = "mock_protocol" + + +def mock_get_tensor_model_parallel_rank(): + return 0 + + +def mock_get_tp_group(): + return MagicMock() + + +def mock_get_ip(): + return "127.0.0.1" + + +def mock_string_to_int64_hash(s): + return hash(s) + + +class TestMooncakeConnectorWorker(unittest.TestCase): + + def setUp(self): + self.envs_ascend_mock = MockEnvsAscend() + self.mock_transfer_engine = MagicMock() + self.mock_transfer_engine.get_rpc_port.return_value = 9090 + self.mock_transfer_engine.initialize.return_value = 0 + self.mock_transfer_engine.register_memory.return_value = 0 + + self.patches = [ + patch('os.getenv', return_value="0,1"), + patch('torch.Tensor.size', return_value=(10, 16, 8, 16)), + patch('torch.Tensor.element_size', return_value=4), + patch('torch.Tensor.data_ptr', return_value=0x1000), + patch('math.prod', return_value=128), + patch('random.Random'), + patch( + 'vllm_ascend.distributed.mooncake_connector.get_tensor_model_parallel_rank', + mock_get_tensor_model_parallel_rank), + 
patch('vllm_ascend.distributed.mooncake_connector.get_tp_group', + mock_get_tp_group), + patch('vllm_ascend.distributed.mooncake_connector.get_ip', + mock_get_ip), + patch( + 'vllm_ascend.distributed.mooncake_connector.string_to_int64_hash', + mock_string_to_int64_hash), + patch('vllm_ascend.distributed.mooncake_connector.TransferEngine', + return_value=self.mock_transfer_engine), + patch( + 'vllm_ascend.distributed.mooncake_connector.KVCacheSendingThread', + MagicMock()), + patch( + 'vllm_ascend.distributed.mooncake_connector.KVCacheRecvingThread', + MagicMock()), + patch('vllm_ascend.distributed.mooncake_connector.logger', + MagicMock()), + patch('vllm_ascend.distributed.mooncake_connector.threading.Event', + MagicMock()), + patch.dict('sys.modules', + {'vllm_ascend.envs': self.envs_ascend_mock}), + patch('vllm_ascend.distributed.mooncake_connector.envs_ascend', + self.envs_ascend_mock), + ] + + for p in self.patches: + p.start() # type: ignore + + self.vllm_config = MockVllmConfig() + self.engine_id = "test_engine" + self.kv_caches = {"layer1": (MagicMock(), MagicMock())} + + def tearDown(self): + for p in self.patches: + p.stop() # type: ignore + + def test_register_kv_caches_producer(self): + worker = MooncakeConnectorWorker(self.vllm_config, self.engine_id) + worker.register_kv_caches(self.kv_caches) + self.assertEqual(len(worker.kv_caches), 1) + self.assertIsNotNone(worker.kv_send_thread) + self.assertIsNone(worker.kv_recv_thread) + + def test_register_kv_caches_consumer(self): + self.vllm_config.kv_transfer_config.kv_role = 'kv_consumer' + worker = MooncakeConnectorWorker(self.vllm_config, self.engine_id) + worker.register_kv_caches(self.kv_caches) + self.assertIsNone(worker.kv_send_thread) + self.assertIsNotNone(worker.kv_recv_thread) + + def test_register_kv_caches_mla_case(self): + mla_cache1 = MagicMock() + mla_cache1.size.return_value = (10, 16, 1, 16) + mla_cache2 = MagicMock() + mla_cache2.size.return_value = (10, 16, 1, 8) + mla_caches = {"layer1": (mla_cache1, mla_cache2)} + + worker = MooncakeConnectorWorker(self.vllm_config, self.engine_id) + worker.register_kv_caches(mla_caches) + self.assertTrue(worker.use_mla) + self.assertEqual(len(worker.block_len), 2) + + +if __name__ == '__main__': + unittest.main() diff --git a/vllm_ascend/distributed/__init__.py b/vllm_ascend/distributed/__init__.py index ebe8694..458b814 100644 --- a/vllm_ascend/distributed/__init__.py +++ b/vllm_ascend/distributed/__init__.py @@ -22,3 +22,7 @@ KVConnectorFactory.register_connector( "LLMDataDistCMgrConnector", "vllm_ascend.distributed.llmdatadist_c_mgr_connector", "LLMDataDistCMgrConnector") + +KVConnectorFactory.register_connector( + "MooncakeConnectorV1", "vllm_ascend.distributed.mooncake_connector", + "MooncakeConnector") diff --git a/vllm_ascend/distributed/mooncake_connector.py b/vllm_ascend/distributed/mooncake_connector.py new file mode 100644 index 0000000..e223877 --- /dev/null +++ b/vllm_ascend/distributed/mooncake_connector.py @@ -0,0 +1,1050 @@ +# SPDX-License-Identifier: Apache-2.0 +import contextlib +import hashlib +import math +import os +import queue +import random +import struct +import threading +import time +from collections import defaultdict, deque +from collections.abc import Iterator +from concurrent.futures import ThreadPoolExecutor +from dataclasses import dataclass +from typing import TYPE_CHECKING, Any, List, Optional, Tuple + +import msgspec +import numpy as np +import numpy.typing as npt +import torch +import zmq +from mooncake.engine import TransferEngine # type: 
ignore +from vllm.config import VllmConfig +from vllm.distributed.kv_transfer.kv_connector.v1.base import ( + KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole) +from vllm.distributed.parallel_state import (get_tensor_model_parallel_rank, + get_tp_group) +from vllm.utils import get_ip, logger, make_zmq_path, make_zmq_socket +from vllm.v1.core.sched.output import SchedulerOutput +from vllm.v1.request import RequestStatus + +if TYPE_CHECKING: + from vllm.attention.backends.abstract import AttentionMetadata + from vllm.forward_context import ForwardContext + from vllm.v1.core.kv_cache_manager import KVCacheBlocks + from vllm.v1.request import Request + +GET_META_MSG = b"get_meta_msg" +DONE_RECVING_MSG = b"done_recving_msg" + + +class MooncakeAgentMetadata(msgspec.Struct, omit_defaults=True, dict=True): + engine_id: str + te_rpc_port: int + kv_caches_base_addr: list[int] + num_blocks: int + + +@dataclass +class ReqMeta: + local_block_ids: list[int] + remote_block_ids: list[int] + remote_host: str + remote_port: int + remote_engine_id: str + + +class KVCacheTaskTracker: + + def __init__(self, tp_rank: int, local_engine_id: str, target_count: int): + super().__init__() + self.tp_rank = tp_rank + self.local_engine_id = local_engine_id + self.target_count = target_count + + self.done_task_lock = threading.Lock() + self.done_task_counts: defaultdict[str, set[int]] = defaultdict(set) + self.finished_requests: set[str] = set() + + self.socket_path = \ + f"ipc:///tmp/vllm_mooncake_connector_{self.local_engine_id}.ipc" + if tp_rank == 0: + self.listener = threading.Thread( + target=self._listen_for_completion_signals, + daemon=True, + name="KVCacheTaskTrackerListenerThread") + self.listener.start() + self.socket = None + else: + self.listener = None # type: ignore + self.socket = make_zmq_socket( + ctx=zmq.Context(), # type: ignore + path=self.socket_path, + socket_type=zmq.PUSH, # type: ignore + bind=False) + logger.info("Connecting to transfer socket at %s", + self.socket_path) + + def _listen_for_completion_signals(self): + socket = make_zmq_socket( + ctx=zmq.Context(), # type: ignore + path=self.socket_path, + socket_type=zmq.PULL, # type: ignore + bind=True) + logger.info("Listening for completion signals on %s", self.socket_path) + + while True: + try: + done_request_id, tp_rank = socket.recv_pyobj() + logger.debug("Received completion notification for request: " + f"{done_request_id} from tp rank {tp_rank}") + self._increment_task_count(done_request_id, tp_rank) + except Exception as e: + logger.error(f"Error in run_busy_loop: {e}") + + def update_done_task_count(self, request_id: str, tp_rank: int): + if self.tp_rank == 0: + self._increment_task_count(request_id, tp_rank) + else: + self.socket.send_pyobj((request_id, tp_rank)) # type: ignore + logger.debug("Sent done signal for request %s to tp 0", request_id) + + def _increment_task_count(self, request_id: str, tp_rank: int): + with self.done_task_lock: + if tp_rank in self.done_task_counts[request_id]: + logger.warning( + f"Received duplicate done signal for request {request_id} " + f"from tp rank {tp_rank}. Ignoring.") + return + + self.done_task_counts[request_id].add(tp_rank) + if len(self.done_task_counts[request_id]) == self.target_count: + self.finished_requests.add(request_id) + self.done_task_counts.pop(request_id) + logger.info("All transfers completed for request: " + f"{request_id}. 
Total ranks: " + f"{self.target_count}.") + + def get_and_clear_finished_requests(self) -> set[str]: + """ + Get and clear the requests that have been completed. + Returns: + A set of request IDs that have been completed. + """ + with self.done_task_lock: + finished_requests = self.finished_requests.copy() + self.finished_requests.clear() + return finished_requests + + +class KVCacheSendingThread(threading.Thread): + + def __init__(self, tp_rank: int, decode_tp_size: int, local_engine_id: str, + side_channel_host: str, side_channel_port: int, + metadata: MooncakeAgentMetadata, + ready_event: threading.Event): + super().__init__(daemon=True, name="KVCacheSendingThread") + self.tp_rank = tp_rank + self.decode_tp_size = decode_tp_size + self.local_engine_id = local_engine_id + self.side_channel_host = side_channel_host + self.side_channel_port = side_channel_port + self.metadata = metadata + self.ready_event = ready_event + + self.task_tracker = KVCacheTaskTracker(self.tp_rank, + self.local_engine_id, + self.decode_tp_size) + + def get_and_clear_finished_requests(self) -> set[str]: + """ + Get and clear the requests that have been completed. + Returns: + A set of request IDs that have been completed. + """ + return self.task_tracker.get_and_clear_finished_requests() + + def run(self): + """Run the thread to handle KV cache transfer requests.""" + + encoder = msgspec.msgpack.Encoder() + encoded_data = encoder.encode(self.metadata) + size_in_bytes = len(encoded_data) + logger.debug("Size of encoded MooncakeAgentMetadata: %s bytes", + str(size_in_bytes)) + + # Listen for new requests for metadata. + # NOTE(rob): we need each rank to have a unique port. This hack to keeps + # us moving. We will switch when moving to etcd or where we have a + # single ZMQ socket in the scheduler. + handshake_port = self.side_channel_port + self.tp_rank + path = make_zmq_path("tcp", self.side_channel_host, handshake_port) + logger.info("Starting listening on path: %s", path) + with zmq_ctx(zmq.ROUTER, path) as sock: # type: ignore + self.ready_event.set() + decoder = msgspec.msgpack.Decoder(type=tuple) + while True: + try: + frames = sock.recv_multipart() + if len(frames) < 2: + logger.error("Invalid message format: %s", frames) + continue + + identity = frames[0] + payload = [f for f in frames[1:] if f != b""] + if len(payload) != 1: + logger.error("Invalid message format: %s", frames) + continue + + msg = decoder.decode(payload[0]) + if msg[0] == GET_META_MSG: + sock.send_multipart((identity, b"", encoded_data)) + elif msg[0] == DONE_RECVING_MSG: + logger.debug("Got DONE_RECVING_MSG for request %s", + msg[1]) + request_id, decode_tp_rank = msg[1], msg[2] + self.task_tracker.update_done_task_count( + request_id, decode_tp_rank) + # Acknowledge the request completion. + while True: + try: + # Send ACK to the sender. + sock.send_multipart( + (identity, b"", b"ACK"), + flags=zmq.NOBLOCK) # type: ignore + break + except zmq.Again: # type: ignore + # If the socket is not ready, retry sending. 
+ logger.debug( + "Socket not ready, retrying to send ACK for " + "request %s", msg[1]) + time.sleep(0.01) + else: + logger.error( + "Connection listener got unexpected message %s", + msg) + except Exception as e: + logger.error("Connection listener got exception %s: %s", + type(e), e) + + +class KVCacheRecvingThread(threading.Thread): + + def __init__(self, tp_rank: int, tp_size: int, engine: TransferEngine, + local_engine_id: str, local_handshake_port: int, + local_kv_caches_base_addr: list[int], block_len: list[int], + ready_event: threading.Event): + super().__init__(daemon=True, name="KVCacheRecvingThread") + self.tp_rank = tp_rank + self.tp_size = tp_size + + self.local_engine_id = local_engine_id + self.local_handshake_port = local_handshake_port + self.engine = engine + self.ready_event = ready_event + + self.kv_caches_base_addr: dict[str, dict[int, list[int]]] = \ + defaultdict(dict) + self.kv_caches_base_addr[local_engine_id][local_handshake_port] = \ + local_kv_caches_base_addr + self.remote_te_port: dict[str, dict[int, int]] = \ + defaultdict(dict) + self.block_len = block_len + # TODO(jianzs): find a better way to detect MLA. + self.use_mla = len(block_len) == 2 + + self.request_queue: queue.Queue[Any] = queue.Queue() + # TODO(jianzs): make this configurable + self.executor = ThreadPoolExecutor(max_workers=32) + + self.task_tracker = KVCacheTaskTracker(self.tp_rank, + self.local_engine_id, + self.tp_size) + + self.encoder = msgspec.msgpack.Encoder() + self.decoder = msgspec.msgpack.Decoder(MooncakeAgentMetadata) + self.remote_sockets_lock = threading.Lock() + self.remote_sockets: dict[ # type: ignore + str, deque[zmq.Socket]] = defaultdict( # type: ignore + deque) + self.remote_poller = zmq.Poller() # type: ignore + self.timeout = 1.0 # seconds + + def add_request(self, request_id: str, local_block_ids: list[int], + remote_block_ids: list[int], remote_engine_id: str, + remote_host: str, remote_handshake_port: int): + """Add a new request to the queue for processing.""" + logger.debug(f"Adding request {request_id} to the queue.") + self.request_queue.put({ + "request_id": request_id, + "local_block_ids": local_block_ids, + "remote_block_ids": remote_block_ids, + "remote_engine_id": remote_engine_id, + "remote_host": remote_host, + "remote_handshake_port": remote_handshake_port, + }) + + def get_and_clear_finished_requests(self) -> set[str]: + """ + Get and clear the requests that have been completed. + Returns: + A set of request IDs that have been completed. 
+ """ + return self.task_tracker.get_and_clear_finished_requests() + + def run(self): + """Run the thread to handle KV cache transfer requests.""" + self.ready_event.set() + while True: + try: + request_data = self.request_queue.get() + if request_data is None: + logger.warning("Received a None request!") + self.request_queue.task_done() + continue + self._handle_request(request_data) + except Exception as e: + logger.error(f"Error in KVCacheTransferThread: {e}") + + def _handle_request(self, req_meta: dict[str, Any]): + request_id = req_meta["request_id"] + remote_host = req_meta["remote_host"] + remote_handshake_port = req_meta["remote_handshake_port"] + + try: + logger.debug( + f"Starting to transfer KV cache for request {request_id}.") + self._transfer_kv_cache(req_meta) + logger.debug( + f"Finished transferring KV cache for request {request_id}.") + except Exception as e: + logger.error("Failed to transfer KV cache for request " + f"{request_id}: {e}") + finally: + self.task_tracker.update_done_task_count(request_id, self.tp_rank) + # Always send the done signal to the remote host to ensure proper + # resource cleanup. Failing to do so may cause a memory leak on the + # remote host. + self._send_done_recv_signal(request_id, remote_host, + remote_handshake_port) + self.request_queue.task_done() + + def _transfer_kv_cache(self, req_meta: dict[str, Any]): + """Handle a KV cache transfer request.""" + request_id = req_meta["request_id"] + remote_block_ids = req_meta["remote_block_ids"] + local_block_ids = req_meta["local_block_ids"] + remote_engine_id = req_meta["remote_engine_id"] + remote_host = req_meta["remote_host"] + remote_handshake_port = req_meta["remote_handshake_port"] + + # Full prefix cache hit: do not need to read remote blocks, just notify + # P worker that we have the blocks we need. + if len(local_block_ids) == 0: + return + + # Check if we have the remote metadata cached. 
+ if remote_engine_id not in self.kv_caches_base_addr or \ + remote_handshake_port not in self.kv_caches_base_addr[remote_engine_id]: + self._get_remote_metadata(remote_host, remote_handshake_port) + + grouped_remote_block_ids, grouped_local_block_ids = \ + group_concurrent_contiguous(remote_block_ids, local_block_ids) + remote_kv_caches_base_addrs = \ + self.kv_caches_base_addr[remote_engine_id][remote_handshake_port] + local_kv_caches_base_addrs = \ + self.kv_caches_base_addr[self.local_engine_id][self.local_handshake_port] + + req_start_time = time.perf_counter() + num_transfer_groups = len(grouped_remote_block_ids) + num_blocks = len(local_block_ids) + + remote_transfer_port = self.remote_te_port[remote_engine_id][ + remote_handshake_port] + session_id = f"{remote_host}:{remote_transfer_port}" + src_list, dst_list, length_list = [], [], [] + for k, (src_layer_base_addr, dst_layer_base_addr) in enumerate( + zip(local_kv_caches_base_addrs, remote_kv_caches_base_addrs)): + block_len = (self.block_len[k % 2] + if self.use_mla else self.block_len[0]) + for i, remote_block_id in enumerate(grouped_remote_block_ids): + local_block_ids = grouped_local_block_ids[i] + src = src_layer_base_addr + local_block_ids[0] * block_len + dst = dst_layer_base_addr + remote_block_id[0] * block_len + length = len(local_block_ids) * block_len + src_list.append(src) + dst_list.append(dst) + length_list.append(length) + ret = self.engine.batch_transfer_sync_read(session_id, src_list, + dst_list, length_list) + if ret < 0: + logger.error("Mooncake transfer failed for request %s", + req_meta["request_id"]) + raise RuntimeError(f"Mooncake transfer failed, ret: {ret}") + + req_end_time = time.perf_counter() + req_transfer_elapsed = (req_end_time - req_start_time) * 1000 + logger.info( + "KV cache transfer for request %s took %.2f ms (%d groups," + " %d blocks).", request_id, req_transfer_elapsed, + num_transfer_groups, num_blocks) + + def _get_remote_metadata(self, remote_host: str, + remote_handshake_port: int) -> None: + """Get the metadata from the remote host.""" + sock: Optional[zmq.Socket] = None # type: ignore + try: + sock = self._get_remote_socket(remote_host, remote_handshake_port) + ensure_zmq_send(sock, self.encoder.encode((GET_META_MSG, ""))) + metadata_bytes = ensure_zmq_recv(sock, self.remote_poller) + agent_meta = self.decoder.decode(metadata_bytes) + engine_id = agent_meta.engine_id + assert engine_id != self.local_engine_id, ( + f"Conflict engine id {engine_id} with local engine id " + f"{self.local_engine_id}.") + self.kv_caches_base_addr[engine_id][remote_handshake_port] = \ + agent_meta.kv_caches_base_addr + self.remote_te_port[engine_id][remote_handshake_port] = \ + agent_meta.te_rpc_port + finally: + if sock is not None: + self._return_remote_socket(sock, remote_host, + remote_handshake_port) + logger.debug("Returned socket to pool for %s:%d", remote_host, + remote_handshake_port) + + def _send_done_recv_signal(self, request_id: str, remote_host: str, + remote_handshake_port: int): + logger.debug("Sending done recving signal for request %s to %s:%d", + request_id, remote_host, remote_handshake_port) + sock: Optional[zmq.Socket] = None # type: ignore + try: + sock = self._get_remote_socket(remote_host, remote_handshake_port) + data_bytes = self.encoder.encode( + (DONE_RECVING_MSG, request_id, self.tp_rank)) + ensure_zmq_send(sock, data_bytes) + resp = ensure_zmq_recv(sock, + self.remote_poller, + timeout=self.timeout) + logger.debug( + f"Received response for request {request_id}: 
{resp.decode('utf-8')}" + ) + if resp != b"ACK": + logger.error("Failed to receive ACK for request %s from %s:%d", + request_id, remote_host, remote_handshake_port) + raise RuntimeError( + f"Failed to receive ACK, resp: {resp.decode('utf-8')}") + finally: + if sock is not None: + self._return_remote_socket(sock, remote_host, + remote_handshake_port) + logger.debug("Returned socket to pool for %s:%d", remote_host, + remote_handshake_port) + + def _get_remote_socket( + self, remote_host: str, + remote_handshake_port: int) -> zmq.Socket: # type: ignore + """Get a socket to the remote host.""" + remote_path = make_zmq_path("tcp", remote_host, remote_handshake_port) + with self.remote_sockets_lock: + if self.remote_sockets[remote_path]: + return self.remote_sockets[remote_path].popleft() + + ctx = zmq.Context() # type: ignore + sock = make_zmq_socket( + ctx=ctx, + path=remote_path, + socket_type=zmq.REQ, # type: ignore + bind=False) + sock.setsockopt( + zmq.SNDTIMEO, # type: ignore + int(self.timeout * 1000)) + self.remote_poller.register(sock, zmq.POLLIN) # type: ignore + return sock + + def _return_remote_socket( + self, + sock: zmq.Socket, # type: ignore + remote_host: str, + remote_handshake_port: int) -> None: + """Return the remote socket to the pool.""" + remote_path = make_zmq_path("tcp", remote_host, remote_handshake_port) + with self.remote_sockets_lock: + self.remote_sockets[remote_path].append(sock) + + +class MooncakeConnectorMetadata(KVConnectorMetadata): + + def __init__(self): + self.requests: dict[str, ReqMeta] = {} + + def add_new_req( + self, + request_id: str, + local_block_ids: list[int], + kv_transfer_params: dict[str, Any], + ): + self.requests[request_id] = ReqMeta( + local_block_ids=local_block_ids, + remote_block_ids=kv_transfer_params["remote_block_ids"], + remote_engine_id=kv_transfer_params["remote_engine_id"], + remote_host=kv_transfer_params["remote_host"], + remote_port=kv_transfer_params["remote_port"], + ) + + +class MooncakeConnector(KVConnectorBase_V1): + + def __init__(self, vllm_config: VllmConfig, role: KVConnectorRole): + assert vllm_config.kv_transfer_config is not None + self.engine_id = vllm_config.kv_transfer_config.engine_id + + if role == KVConnectorRole.SCHEDULER: + self.connector_scheduler: Optional[MooncakeConnectorScheduler] = \ + MooncakeConnectorScheduler(vllm_config, str(self.engine_id)) + self.connector_worker: Optional[MooncakeConnectorWorker] = None + elif role == KVConnectorRole.WORKER: + self.connector_scheduler = None + self.connector_worker = MooncakeConnectorWorker( + vllm_config, str(self.engine_id)) + + ############################################################ + # Scheduler Side Methods + ############################################################ + + def get_num_new_matched_tokens( + self, request: "Request", + num_computed_tokens: int) -> tuple[int, bool]: + assert self.connector_scheduler is not None + return self.connector_scheduler.get_num_new_matched_tokens( + request, num_computed_tokens) + + def update_state_after_alloc(self, request: "Request", + blocks: "KVCacheBlocks", + num_external_tokens: int): + assert self.connector_scheduler is not None + return self.connector_scheduler.update_state_after_alloc( + request, blocks, num_external_tokens) + + def build_connector_meta( + self, + scheduler_output: SchedulerOutput, + ) -> KVConnectorMetadata: + assert self.connector_scheduler is not None + return self.connector_scheduler.build_connector_meta(scheduler_output) + + def request_finished( + self, + request: "Request", + 
block_ids: list[int], + ) -> tuple[bool, Optional[dict[str, Any]]]: + assert self.connector_scheduler is not None + return self.connector_scheduler.request_finished(request, block_ids) + + ############################################################ + # Worker Side Methods + ############################################################ + def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): + assert self.connector_worker is not None + self.connector_worker.register_kv_caches(kv_caches) + + def get_finished(self, + finished_req_ids: set[str]) -> tuple[set[str], set[str]]: + """Get the finished recving and sending requests.""" + assert self.connector_worker is not None + return self.connector_worker.get_finished() + + def start_load_kv(self, forward_context: "ForwardContext", + **kwargs) -> None: + assert self.connector_worker is not None + assert isinstance(self._connector_metadata, MooncakeConnectorMetadata) + self.connector_worker.start_load_kv(self._connector_metadata) + + def wait_for_layer_load(self, layer_name: str) -> None: + """MooncakeConnector does not do layerwise saving.""" + pass + + def save_kv_layer(self, layer_name: str, kv_layer: torch.Tensor, + attn_metadata: "AttentionMetadata", **kwargs) -> None: + """MooncakeConnector does not save explicitly.""" + pass + + def wait_for_save(self): + """MooncakeConnector does not save explicitly.""" + pass + + +class MooncakeConnectorScheduler: + """Implementation of Scheduler side methods""" + + def __init__(self, vllm_config: VllmConfig, engine_id: str): + self.vllm_config = vllm_config + self.block_size = vllm_config.cache_config.block_size + self.engine_id = engine_id + logger.info("Initializing Mooncake Scheduler %s", engine_id) + + self.side_channel_host = get_ip() + self.max_device_id = vllm_config.parallel_config.tensor_parallel_size * \ + vllm_config.parallel_config.data_parallel_size + + # Handshake base port + self.side_channel_port = ( + vllm_config.kv_transfer_config.kv_port + + vllm_config.parallel_config.data_parallel_rank_local * + vllm_config.parallel_config.tensor_parallel_size) + + # Requests that need to start recv. + # New requests are added by update_state_after_alloc in + # the scheduler. Used to make metadata passed to Worker. + self._reqs_need_recv: dict[str, tuple[Request, list[int]]] = {} + + def get_num_new_matched_tokens( + self, request: "Request", + num_computed_tokens: int) -> tuple[int, bool]: + """ + For remote prefill, pull all prompt blocks from remote + asynchronously relative to engine execution. + + Args: + request (Request): the request object. + num_computed_tokens (int): the number of locally + computed tokens for this request + Returns: + * the number of tokens that can be loaded from the + external KV cache beyond what is already computed. + * true if the external KV cache tokens will be loaded + asynchronously (between scheduler steps). + """ + + params = request.kv_transfer_params + logger.debug( + "MooncakeConnector get_num_new_matched_tokens: " + "num_computed_tokens=%s, kv_transfer_params=%s", + num_computed_tokens, params) + + if params is not None and params.get("do_remote_prefill"): + assert num_computed_tokens == 0, "Currently only support " \ + "prefill with num_computed_tokens == 0." + # Assume that the request's KV cache is already fully prefilled and + # can be fetched entirely from the prefill node. + count = max(len(request.prompt_token_ids) - 1, 0) + if count > 0: + return count, True + + # No remote prefill for this request. 
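+        # (Reached when do_remote_prefill is unset or the prompt has at most
+        # one token, in which case there is nothing to pull from the prefill
+        # node.)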
+        return 0, False
+
+    def update_state_after_alloc(self, request: "Request",
+                                 blocks: "KVCacheBlocks",
+                                 num_external_tokens: int):
+
+        params = request.kv_transfer_params
+        logger.debug(
+            "MooncakeConnector update_state_after_alloc: "
+            "num_external_tokens=%s, kv_transfer_params=%s",
+            num_external_tokens, params)
+
+        if params is not None and params.get("do_remote_prefill"):
+            if params.get("remote_block_ids"):
+                if all(p in params for p in ("remote_engine_id", "remote_host",
+                                             "remote_port")):
+                    # Get unhashed blocks to pull from remote.
+                    local_block_ids = (blocks.get_unhashed_block_ids()
+                                       if num_external_tokens > 0 else [])
+                    self._reqs_need_recv[request.request_id] = (
+                        request, local_block_ids)
+                else:
+                    logger.warning(
+                        "Got invalid KVTransferParams: %s. This "
+                        "request will not utilize KVTransfer", params)
+            else:
+                assert num_external_tokens == 0
+            # Only trigger 1 KV transfer per request.
+            params["do_remote_prefill"] = False
+
+    def build_connector_meta(
+        self,
+        scheduler_output: SchedulerOutput,
+    ) -> KVConnectorMetadata:
+        meta = MooncakeConnectorMetadata()
+
+        # Loop through scheduled reqs and convert to ReqMeta.
+        for req_id, (req, block_ids) in self._reqs_need_recv.items():
+            assert req.kv_transfer_params is not None
+            # For the case where there are no remote blocks to pull
+            # (block_ids is empty), we don't need to schedule
+            # an async read on the worker side.
+            meta.add_new_req(
+                request_id=req_id,
+                local_block_ids=block_ids,
+                kv_transfer_params=req.kv_transfer_params,
+            )
+
+        # Clear the list once workers start the transfers.
+        self._reqs_need_recv.clear()
+
+        return meta
+
+    def request_finished(
+        self,
+        request: "Request",
+        block_ids: list[int],
+    ) -> tuple[bool, Optional[dict[str, Any]]]:
+        """
+        Once a request is finished, determine whether request blocks
+        should be freed now or will be sent asynchronously and freed later.
+        """
+
+        params = request.kv_transfer_params
+        logger.debug(
+            "MooncakeConnector request_finished, request_status=%s, "
+            "kv_transfer_params=%s", request.status, params)
+
+        if (params is None or not params.get("do_remote_decode")
+                or request.status != RequestStatus.FINISHED_LENGTH_CAPPED):
+            return False, None
+
+        computed_block_ids = block_ids
+        delay_free_blocks = len(computed_block_ids) > 0
+        if delay_free_blocks:
+            logger.info("Delaying free of %d blocks for request %s",
+                        len(computed_block_ids), request.request_id)
+        return delay_free_blocks, dict(
+            do_remote_prefill=True,
+            do_remote_decode=False,
+            remote_block_ids=computed_block_ids,
+            remote_engine_id=self.engine_id,
+            remote_host=self.side_channel_host,
+            remote_port=self.side_channel_port,
+        )
+
+
+class MooncakeConnectorWorker:
+    """Implementation of Worker side methods"""
+
+    def __init__(self, vllm_config: VllmConfig, engine_id: str):
+        self._get_prefill_decode_size(vllm_config)
+        if self._prefill_tp_size < self._decode_tp_size:
+            raise ValueError(
+                f"prefill_tp_size: {self._prefill_tp_size} must be greater than"
+                f" or equal to the decode_tp_size: {self._decode_tp_size}")
+
+        if TransferEngine is None:
+            raise RuntimeError("mooncake is not available")
+        logger.info("Initializing Mooncake worker %s", engine_id)
+        self.engine = TransferEngine()
+
+        # Metadata.
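+        # Side-channel port layout: each DP replica claims a contiguous range
+        # of tp_size ports starting at kv_port, and worker tp_rank i serves
+        # handshakes on side_channel_port + i (see the handshake base port
+        # computation below).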
+ self.engine_id = engine_id + self.tp_rank = get_tensor_model_parallel_rank() + self.tp_size = vllm_config.parallel_config.tensor_parallel_size + self.tp_group = get_tp_group() + self.dp_rank = vllm_config.parallel_config.data_parallel_rank_local + self.dp_size = vllm_config.parallel_config.data_parallel_size_local + self.kv_caches: dict[str, torch.Tensor] = {} + self.side_channel_host = get_ip() + self.max_device_id = self.tp_size * self.dp_size + self.kv_role = vllm_config.kv_transfer_config.kv_role + + # Handshake base port + self.side_channel_port = ( + vllm_config.kv_transfer_config.kv_port + + vllm_config.parallel_config.data_parallel_rank_local * + vllm_config.parallel_config.tensor_parallel_size) + self.handshake_port = self.side_channel_port + self.tp_rank + self.sockets: dict = {} + + # get tp device id + # TODO(kw): https://github.com/vllm-project/vllm-ascend/pull/940 + # introducing some changes + device_ids_str = os.getenv("ASCEND_RT_VISIBLE_DEVICES", None) + if device_ids_str is None: + device_ids = list( + range(self.dp_rank * self.tp_size, + (self.dp_rank + 1) * self.tp_size)) + else: + device_ids = list(map(int, device_ids_str.split(','))) + assert len(device_ids) > self.tp_rank # type: ignore + self.device_id = device_ids[self.tp_rank] # type: ignore + + self._initialize( + hostname=self.side_channel_host + ':' + '0' + ':' + 'npu_' \ + + str(self.device_id), + device_name=None) + self.te_rpc_port = self.engine.get_rpc_port() + + # Background thread for sending or receiving KV caches. + self.kv_send_thread: Optional[KVCacheSendingThread] = None + self.kv_recv_thread: Optional[KVCacheRecvingThread] = None + + self.vllm_config = vllm_config + self.block_size = vllm_config.cache_config.block_size + + def _get_prefill_decode_size(self, vllm_config: VllmConfig): + # get prefill tp and dp size from extra config + prefill_parallel_config: dict[ + str, Any] = vllm_config.kv_transfer_config.get_from_extra_config( + "prefill", {}) + + assert "tp_size" in prefill_parallel_config.keys() + self._prefill_tp_size = prefill_parallel_config["tp_size"] + + assert "dp_size" in prefill_parallel_config.keys() + self._prefill_dp_size = prefill_parallel_config["dp_size"] + + # get decode tp and dp size from extra config + decode_parallel_config: dict[ + str, Any] = vllm_config.kv_transfer_config.get_from_extra_config( + "decode", {}) + assert "tp_size" in decode_parallel_config.keys() + self._decode_tp_size = decode_parallel_config["tp_size"] + assert "dp_size" in decode_parallel_config.keys() + self._decode_dp_size = decode_parallel_config["dp_size"] + + def _initialize( + self, + hostname: str, + device_name: Optional[str], + ) -> None: + """Initialize the mooncake instance.""" + device_name = device_name if device_name is not None else "" + ret_value = self.engine.initialize(hostname, "P2PHANDSHAKE", "ascend", + device_name) + if ret_value != 0: + raise RuntimeError( + f"Mooncake initialization failed with ret_value: {ret_value}") + + def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): + """Register the KV Cache data.""" + + _, first_kv_cache_tuple = next(iter(kv_caches.items())) + first_kv_cache = first_kv_cache_tuple[0] + + # TODO(tms): Find a more robust way to detect and handle MLA + self.use_mla = first_kv_cache_tuple[0].size( + -1) != first_kv_cache_tuple[1].size(-1) + if self.use_mla: + # MLA case.[num_block, block_size, 1, hidden_dim] + self.num_blocks = first_kv_cache.shape[0] + block_rank = 3 # [block_size, latent_dim] + block_shape_norm = 
first_kv_cache_tuple[0].shape[-block_rank:] + block_shape_pe = first_kv_cache_tuple[1].shape[-block_rank:] + self.block_len = [ + first_kv_cache[0].element_size() * math.prod(block_shape_norm), + first_kv_cache[1].element_size() * math.prod(block_shape_pe) + ] + logger.info( + "num_blocks: %s, block_shape_norm: %s, block_shape_pe: %s", + self.num_blocks, block_shape_norm, block_shape_pe) + else: + # [num_block, block_size, num_head, hidden_dim] + self.num_blocks = first_kv_cache.shape[0] + kv_elem_size = first_kv_cache.element_size() + block_rank = 3 # [block_size, kv_heads, head_dim] + block_shape = first_kv_cache.shape[-block_rank:] + self.block_len = [kv_elem_size * math.prod(block_shape)] + logger.info("num_blocks: %s, block_shape: %s", self.num_blocks, + block_shape) + + logger.info("Registering KV_Caches. use_mla: %s, shape %s", + self.use_mla, first_kv_cache.shape) + + self.kv_caches = kv_caches + kv_caches_base_addr = [] + for cache_or_caches in kv_caches.values(): + # Normalize to always be a list of caches + if self.use_mla: + for i, cache in enumerate(cache_or_caches, 0): + base_addr = cache.data_ptr() + region_len = self.num_blocks * self.block_len[i % 2] + kv_caches_base_addr.append(base_addr) + self._register(base_addr, region_len) + else: + cache_list = [cache_or_caches + ] if self.use_mla else cache_or_caches + for cache in cache_list: + base_addr = cache.data_ptr() + region_len = self.num_blocks * self.block_len[0] + kv_caches_base_addr.append(base_addr) + self._register(base_addr, region_len) + + # After KV Caches registered, start the sending or receiving thread. + metadata = MooncakeAgentMetadata( + engine_id=self.engine_id, + te_rpc_port=self.te_rpc_port, + kv_caches_base_addr=kv_caches_base_addr, + num_blocks=self.num_blocks, + ) + + ready_event = threading.Event() + if self.kv_role == 'kv_producer': + self.kv_send_thread = KVCacheSendingThread(self.tp_rank, + self._decode_tp_size, + self.engine_id, + self.side_channel_host, + self.side_channel_port, + metadata, ready_event) + self.kv_send_thread.start() + else: + self.kv_recv_thread = KVCacheRecvingThread( + self.tp_rank, self.tp_size, self.engine, self.engine_id, + self.handshake_port, kv_caches_base_addr, self.block_len, + ready_event) + self.kv_recv_thread.start() + ready_event.wait() + + def _register(self, ptr, length): + logger.info( + "Registering KV cache: ptr=0x%x, length=%d, num_blocks=%d, " + "block_lens=%s", ptr, length, self.num_blocks, self.block_len) + ret_value = self.engine.register_memory(ptr, length) + if ret_value != 0: + raise RuntimeError("Mooncake memory registration failed.") + + def get_finished(self) -> tuple[set[str], set[str]]: + done_sending = ( + self.kv_send_thread. + get_and_clear_finished_requests( # type: ignore[union-attr] + ) if self.kv_role == 'kv_producer' else set()) + done_recving = ( + self.kv_recv_thread. + get_and_clear_finished_requests( # type: ignore[union-attr] + ) if self.kv_role == 'kv_consumer' else set()) + if self.tp_rank == 0: + logger.debug( + "Number of completed KV cache send requests: %d, receive " + "requests: %d", len(done_sending), len(done_recving)) + return done_sending, done_recving + + def start_load_kv(self, metadata: MooncakeConnectorMetadata): + """Start loading KV blocks from remote engine.""" + for req_id, meta in metadata.requests.items(): + logger.debug( + "start_load_kv for request %s from remote engine %s. " + "Num local_block_ids: %s. Num remote_block_ids: %s. 
", req_id, + meta.remote_engine_id, len(meta.local_block_ids), + len(meta.remote_block_ids)) + + remote_handshake_port = meta.remote_port + \ + self._get_remote_tp_rank(req_id) + self.kv_recv_thread.add_request( # type: ignore[union-attr] + request_id=req_id, + local_block_ids=meta.local_block_ids, + remote_block_ids=meta.remote_block_ids, + remote_engine_id=meta.remote_engine_id, + remote_host=meta.remote_host, + remote_handshake_port=remote_handshake_port, + ) + + def _get_remote_tp_rank(self, req_id: str) -> int: + return self._get_remote_tp_ranks_for_req(req_id)[self.tp_rank] + + def _get_remote_tp_ranks_for_req(self, req_id: str) -> list[int]: + if self._prefill_tp_size == self._decode_tp_size: + return list(range(self._prefill_tp_size)) + + seed = string_to_int64_hash(req_id) + rand = random.Random(seed) + sampled_nums = rand.sample(range(self._prefill_tp_size), + self._decode_tp_size) + return sampled_nums + + +@contextlib.contextmanager +def zmq_ctx(socket_type: Any, + addr: str) -> Iterator[zmq.Socket]: # type: ignore + """Context manager for a ZMQ socket""" + + if socket_type not in (zmq.ROUTER, zmq.REQ, zmq.DEALER): # type: ignore + raise ValueError(f"Unexpected socket type: {socket_type}") + + ctx: Optional[zmq.Context] = None # type: ignore + try: + ctx = zmq.Context() # type: ignore + yield make_zmq_socket(ctx=ctx, + path=addr, + socket_type=socket_type, + bind=socket_type == zmq.ROUTER) # type: ignore + finally: + if ctx is not None: + ctx.destroy(linger=0) + + +def group_concurrent_contiguous( + src: List[int], dst: List[int] +) -> Tuple[List[npt.NDArray[np.int64]], List[npt.NDArray[np.int64]]]: + """Vectorised NumPy implementation.""" + src_indices: npt.NDArray[np.int64] = np.array(src, dtype=np.int64) + dst_indices: npt.NDArray[np.int64] = np.array(dst, dtype=np.int64) + + if src_indices.size == 0: + return [], [] + + brk = np.where((np.diff(src_indices) != 1) + | (np.diff(dst_indices) != 1))[0] + 1 + src_groups = np.split(src_indices, brk) + dst_groups = np.split(dst_indices, brk) + + src_groups = [g.tolist() for g in src_groups] + dst_groups = [g.tolist() for g in dst_groups] + + return src_groups, dst_groups + + +def string_to_int64_hash(input_str): + """ + Hash the string using SHA-256 and convert it into an int64 integer. + """ + hashed_bytes = hashlib.sha256(input_str.encode("utf-8")).digest() + trunked_bytes = hashed_bytes[:8] + uint64_value = struct.unpack(" 0: + logger.warning( + f"Send failed: {e}, retrying... ({retries_left} " + "attempts left)") + time.sleep(0.1) + else: + logger.error(f"Send failed after all retries: {e}") + raise RuntimeError(f"Failed to send data after {max_retries} " + f"retries: {e}") + + +def ensure_zmq_recv( + socket: zmq.Socket, # type: ignore + poller: zmq.Poller, # type: ignore + timeout: float = 1.0, + max_retries: int = 3) -> bytes: + retries_left = max_retries + while True: + try: + if dict(poller.poll(int(timeout * 1000))): # milliseconds + data = socket.recv() + return data + else: + raise zmq.ZMQError("Receive timeout") # type: ignore + except zmq.ZMQError as e: # type: ignore + retries_left -= 1 + if retries_left > 0: + logger.warning(f"Receive failed: {e}, retrying... " + f"({retries_left} attempts left)") + time.sleep(0.1) + else: + logger.error(f"Receive failed after all retries: {e}") + raise RuntimeError( + f"Failed to receive data after {max_retries} " + f"retries: {e}")