[Refactor]Refactor of vllm_ascend/distributed module (#5719)

### What this PR does / why we need it?
Based on the RFC:https://github.com/vllm-project/vllm-ascend/issues/5604

This PR is a refactoring of vllm_ascend/distributed, moving all
kv_transfer realtaed codes into a dedicated folder, which has already
been done in vLLM

### Does this PR introduce _any_ user-facing change?
NA

### How was this patch tested?


- vLLM version: v0.13.0
- vLLM main:
2f4e6548ef

---------

Signed-off-by: lty <linhebiwen@gmail.com>
This commit is contained in:
lty
2026-01-15 08:57:40 +08:00
committed by GitHub
parent f34b3b8ee9
commit 295018ec0f
56 changed files with 300 additions and 293 deletions

View File

@@ -23,21 +23,24 @@ _mock_pp_group = MagicMock(rank_in_group=0, world_size=1)
_mock_tp_group = MagicMock(rank_in_group=0, world_size=4)
_mock_pcp_group = MagicMock(rank_in_group=0, world_size=1)
_mock_dcp_group = MagicMock(rank_in_group=0, world_size=1)
patch('vllm_ascend.distributed.mooncake_connector.get_pp_group',
return_value=_mock_pp_group).start()
patch('vllm_ascend.distributed.mooncake_connector.get_tp_group',
return_value=_mock_tp_group).start()
patch(
'vllm_ascend.distributed.mooncake_connector.get_tensor_model_parallel_world_size',
'vllm_ascend.distributed.kv_transfer.kv_p2p.mooncake_connector.get_pp_group',
return_value=_mock_pp_group).start()
patch(
'vllm_ascend.distributed.kv_transfer.kv_p2p.mooncake_connector.get_tp_group',
return_value=_mock_tp_group).start()
patch(
'vllm_ascend.distributed.kv_transfer.kv_p2p.mooncake_connector.get_tensor_model_parallel_world_size',
return_value=4).start()
patch(
'vllm_ascend.distributed.mooncake_connector.get_tensor_model_parallel_rank',
'vllm_ascend.distributed.kv_transfer.kv_p2p.mooncake_connector.get_tensor_model_parallel_rank',
return_value=0).start()
patch('vllm_ascend.distributed.mooncake_connector.get_pcp_group',
return_value=_mock_pcp_group).start()
patch(
'vllm_ascend.distributed.kv_transfer.kv_p2p.mooncake_connector.get_pcp_group',
return_value=_mock_pcp_group).start()
patch('vllm.distributed.parallel_state._DCP', _mock_dcp_group).start()
from vllm_ascend.distributed.mooncake_connector import ( # noqa: E402
from vllm_ascend.distributed.kv_transfer.kv_p2p.mooncake_connector import ( # noqa: E402
KVCacheRecvingThread, KVCacheSendingThread, KVCacheTaskTracker,
KVConnectorRole, MooncakeAgentMetadata, MooncakeConnector,
MooncakeConnectorMetadata, MooncakeConnectorScheduler,
@@ -81,7 +84,8 @@ class TestGetAndClearFinishedSingleRequests(unittest.TestCase):
self.assertSetEqual(result, {"req_1", "req_2", "req_3"})
self.assertEqual(len(self.tracker.finished_requests), 0)
@patch("vllm_ascend.distributed.mooncake_connector.logger")
@patch(
"vllm_ascend.distributed.kv_transfer.kv_p2p.mooncake_connector.logger")
def test_concurrent_access(self, mock_logger):
from concurrent.futures import ThreadPoolExecutor
self.tracker.finished_requests = {"req_1", "req_2"}
@@ -307,8 +311,12 @@ class TestSocketManagement(unittest.TestCase):
self.thread.remote_sockets = defaultdict(deque)
self.thread.remote_poller = MagicMock()
@patch('vllm_ascend.distributed.mooncake_connector.zmq.Context')
@patch('vllm_ascend.distributed.mooncake_connector.make_zmq_socket')
@patch(
'vllm_ascend.distributed.kv_transfer.kv_p2p.mooncake_connector.zmq.Context'
)
@patch(
'vllm_ascend.distributed.kv_transfer.kv_p2p.mooncake_connector.make_zmq_socket'
)
def test_get_remote_socket(self, mock_make_socket, mock_context):
mock_sock = MagicMock()
mock_make_socket.return_value = mock_sock
@@ -402,7 +410,7 @@ class TestCoreFunctionality(unittest.TestCase):
@patch.object(KVCacheRecvingThread, '_get_remote_metadata')
def test_transfer_kv_cache(self, mock_get_meta):
with patch(
'vllm_ascend.distributed.mooncake_connector.get_ascend_config'
'vllm_ascend.distributed.kv_transfer.kv_p2p.mooncake_connector.get_ascend_config'
) as mock_config:
mock_config.return_value.enable_kv_nz = False
self.thread.kv_caches_base_addr["remote_engine"] = {
@@ -456,8 +464,12 @@ class TestMetadataHandling(unittest.TestCase):
kv_caches_base_addr=[0x3000, 0x4000],
num_blocks=2)
@patch('vllm_ascend.distributed.mooncake_connector.ensure_zmq_send')
@patch('vllm_ascend.distributed.mooncake_connector.ensure_zmq_recv')
@patch(
'vllm_ascend.distributed.kv_transfer.kv_p2p.mooncake_connector.ensure_zmq_send'
)
@patch(
'vllm_ascend.distributed.kv_transfer.kv_p2p.mooncake_connector.ensure_zmq_recv'
)
def test_get_remote_metadata_success(self, mock_recv, mock_send):
mock_recv.return_value = msgspec.msgpack.encode(self.test_metadata)
@@ -479,9 +491,12 @@ class TestMetadataHandling(unittest.TestCase):
self.thread.kv_caches_base_addr["remote_engine"][5555],
[0x3000, 0x4000])
@patch('vllm_ascend.distributed.mooncake_connector.ensure_zmq_send')
@patch('vllm_ascend.distributed.mooncake_connector.ensure_zmq_recv',
side_effect=Exception("Network error"))
@patch(
'vllm_ascend.distributed.kv_transfer.kv_p2p.mooncake_connector.ensure_zmq_send'
)
@patch(
'vllm_ascend.distributed.kv_transfer.kv_p2p.mooncake_connector.ensure_zmq_recv',
side_effect=Exception("Network error"))
def test_get_remote_metadata_failure(self, mock_recv, mock_send):
with patch.object(self.thread, '_get_remote_socket') as mock_get_socket, \
patch.object(self.thread, '_return_remote_socket') as mock_return_socket:
@@ -694,10 +709,10 @@ class TestMooncakeConnectorSchedulerMatchedTokens(unittest.TestCase):
def setUp(self):
config = MockVllmConfig()
self.p1 = patch(
'vllm_ascend.distributed.mooncake_connector.init_ascend_config',
'vllm_ascend.distributed.kv_transfer.kv_p2p.mooncake_connector.init_ascend_config',
new=MagicMock())
self.p2 = patch(
'vllm_ascend.distributed.mooncake_connector.get_ascend_config',
'vllm_ascend.distributed.kv_transfer.kv_p2p.mooncake_connector.get_ascend_config',
new=MagicMock(return_value=MagicMock()))
self.p1.start()
self.p2.start()
@@ -775,9 +790,9 @@ class TestMooncakeConnectorForScheduler(unittest.TestCase):
def test_scheduler_role(self):
config = MockVllmConfig()
with patch(
'vllm_ascend.distributed.mooncake_connector.init_ascend_config'
'vllm_ascend.distributed.kv_transfer.kv_p2p.mooncake_connector.init_ascend_config'
), patch(
'vllm_ascend.distributed.mooncake_connector.get_ascend_config',
'vllm_ascend.distributed.kv_transfer.kv_p2p.mooncake_connector.get_ascend_config',
return_value=MagicMock()):
connector = MooncakeConnector(config, KVConnectorRole.SCHEDULER)
self.assertIsNotNone(connector.connector_scheduler)
@@ -787,9 +802,9 @@ class TestMooncakeConnectorForScheduler(unittest.TestCase):
def test_scheduler_methods(self, mock_method):
config = MockVllmConfig()
with patch(
'vllm_ascend.distributed.mooncake_connector.init_ascend_config'
'vllm_ascend.distributed.kv_transfer.kv_p2p.mooncake_connector.init_ascend_config'
), patch(
'vllm_ascend.distributed.mooncake_connector.get_ascend_config',
'vllm_ascend.distributed.kv_transfer.kv_p2p.mooncake_connector.get_ascend_config',
return_value=MagicMock()):
connector = MooncakeConnector(config, KVConnectorRole.SCHEDULER)
request = MockRequest("req1")
@@ -819,9 +834,9 @@ class TestMooncakeConnector(unittest.TestCase):
def test_scheduler_initialization(self):
with patch(
'vllm_ascend.distributed.mooncake_connector.init_ascend_config'
'vllm_ascend.distributed.kv_transfer.kv_p2p.mooncake_connector.init_ascend_config'
), patch(
'vllm_ascend.distributed.mooncake_connector.get_ascend_config',
'vllm_ascend.distributed.kv_transfer.kv_p2p.mooncake_connector.get_ascend_config',
return_value=MagicMock()):
connector = MooncakeConnector(self.config,
KVConnectorRole.SCHEDULER)
@@ -831,9 +846,9 @@ class TestMooncakeConnector(unittest.TestCase):
@patch.object(MooncakeConnectorScheduler, "get_num_new_matched_tokens")
def test_get_num_new_matched_tokens(self, mock_method):
with patch(
'vllm_ascend.distributed.mooncake_connector.init_ascend_config'
'vllm_ascend.distributed.kv_transfer.kv_p2p.mooncake_connector.init_ascend_config'
), patch(
'vllm_ascend.distributed.mooncake_connector.get_ascend_config',
'vllm_ascend.distributed.kv_transfer.kv_p2p.mooncake_connector.get_ascend_config',
return_value=MagicMock()):
connector = MooncakeConnector(self.config,
KVConnectorRole.SCHEDULER)
@@ -844,9 +859,9 @@ class TestMooncakeConnector(unittest.TestCase):
@patch.object(MooncakeConnectorScheduler, "update_state_after_alloc")
def test_update_state_after_alloc(self, mock_method):
with patch(
'vllm_ascend.distributed.mooncake_connector.init_ascend_config'
'vllm_ascend.distributed.kv_transfer.kv_p2p.mooncake_connector.init_ascend_config'
), patch(
'vllm_ascend.distributed.mooncake_connector.get_ascend_config',
'vllm_ascend.distributed.kv_transfer.kv_p2p.mooncake_connector.get_ascend_config',
return_value=MagicMock()):
connector = MooncakeConnector(self.config,
KVConnectorRole.SCHEDULER)
@@ -858,9 +873,9 @@ class TestMooncakeConnector(unittest.TestCase):
@patch.object(MooncakeConnectorScheduler, "build_connector_meta")
def test_build_connector_meta(self, mock_method):
with patch(
'vllm_ascend.distributed.mooncake_connector.init_ascend_config'
'vllm_ascend.distributed.kv_transfer.kv_p2p.mooncake_connector.init_ascend_config'
), patch(
'vllm_ascend.distributed.mooncake_connector.get_ascend_config',
'vllm_ascend.distributed.kv_transfer.kv_p2p.mooncake_connector.get_ascend_config',
return_value=MagicMock()):
connector = MooncakeConnector(self.config,
KVConnectorRole.SCHEDULER)
@@ -871,9 +886,9 @@ class TestMooncakeConnector(unittest.TestCase):
@patch.object(MooncakeConnectorScheduler, "request_finished")
def test_request_finished(self, mock_method):
with patch(
'vllm_ascend.distributed.mooncake_connector.init_ascend_config'
'vllm_ascend.distributed.kv_transfer.kv_p2p.mooncake_connector.init_ascend_config'
), patch(
'vllm_ascend.distributed.mooncake_connector.get_ascend_config',
'vllm_ascend.distributed.kv_transfer.kv_p2p.mooncake_connector.get_ascend_config',
return_value=MagicMock()):
connector = MooncakeConnector(self.config,
KVConnectorRole.SCHEDULER)
@@ -887,9 +902,9 @@ class TestMooncakeConnectorScheduler(unittest.TestCase):
def setUp(self):
self.config = MockVllmConfig()
with patch(
'vllm_ascend.distributed.mooncake_connector.init_ascend_config'
'vllm_ascend.distributed.kv_transfer.kv_p2p.mooncake_connector.init_ascend_config'
), patch(
'vllm_ascend.distributed.mooncake_connector.get_ascend_config',
'vllm_ascend.distributed.kv_transfer.kv_p2p.mooncake_connector.get_ascend_config',
return_value=MagicMock()):
self.scheduler = MooncakeConnectorScheduler(
self.config, "test_engine")
@@ -965,20 +980,24 @@ class TestUtils(unittest.TestCase):
with zmq_ctx("INVALID", "tcp://127.0.0.1:5555"):
pass
@patch("vllm_ascend.distributed.mooncake_connector.make_zmq_socket")
@patch(
"vllm_ascend.distributed.kv_transfer.kv_p2p.mooncake_connector.make_zmq_socket"
)
def test_zmq_ctx_ok(self, mock_make_socket):
mock_socket = MagicMock()
mock_make_socket.return_value = mock_socket
with zmq_ctx(zmq.REQ, "tcp://localhost:1234") as s: # type: ignore
self.assertEqual(s, mock_socket)
@patch("vllm_ascend.distributed.mooncake_connector.logger")
@patch(
"vllm_ascend.distributed.kv_transfer.kv_p2p.mooncake_connector.logger")
def test_ensure_zmq_send_success(self, mock_logger):
mock_socket = MagicMock()
ensure_zmq_send(mock_socket, b"hello")
mock_socket.send.assert_called_once_with(b"hello")
@patch("vllm_ascend.distributed.mooncake_connector.logger")
@patch(
"vllm_ascend.distributed.kv_transfer.kv_p2p.mooncake_connector.logger")
def test_ensure_zmq_send_retry_and_fail(self, mock_logger):
mock_socket = MagicMock()
mock_socket.send.side_effect = zmq.ZMQError( # type: ignore
@@ -987,7 +1006,8 @@ class TestUtils(unittest.TestCase):
ensure_zmq_send(mock_socket, b"hello", max_retries=2)
self.assertEqual(mock_socket.send.call_count, 2)
@patch("vllm_ascend.distributed.mooncake_connector.logger")
@patch(
"vllm_ascend.distributed.kv_transfer.kv_p2p.mooncake_connector.logger")
def test_ensure_zmq_recv_success(self, mock_logger):
mock_socket = MagicMock()
mock_socket.recv.return_value = b"response"
@@ -998,7 +1018,8 @@ class TestUtils(unittest.TestCase):
data = ensure_zmq_recv(mock_socket, mock_poller)
self.assertEqual(data, b"response")
@patch("vllm_ascend.distributed.mooncake_connector.logger")
@patch(
"vllm_ascend.distributed.kv_transfer.kv_p2p.mooncake_connector.logger")
def test_ensure_zmq_recv_timeout_and_fail(self, mock_logger):
mock_socket = MagicMock()
mock_poller = MagicMock()
@@ -1106,35 +1127,40 @@ class TestMooncakeConnectorWorker(unittest.TestCase):
patch('torch.Tensor.data_ptr', return_value=0x1000),
patch('math.prod', return_value=128),
patch(
'vllm_ascend.distributed.mooncake_connector.get_tensor_model_parallel_rank',
'vllm_ascend.distributed.kv_transfer.kv_p2p.mooncake_connector.get_tensor_model_parallel_rank',
mock_get_tensor_model_parallel_rank),
patch('vllm_ascend.distributed.mooncake_connector.get_tp_group',
mock_get_tp_group),
patch('vllm_ascend.distributed.mooncake_connector.get_pp_group',
return_value=_mock_pp_group),
patch('vllm_ascend.distributed.mooncake_connector.get_ip',
mock_get_ip),
patch(
'vllm_ascend.distributed.mooncake_connector.string_to_int64_hash',
'vllm_ascend.distributed.kv_transfer.kv_p2p.mooncake_connector.get_tp_group',
mock_get_tp_group),
patch(
'vllm_ascend.distributed.kv_transfer.kv_p2p.mooncake_connector.get_pp_group',
return_value=_mock_pp_group),
patch(
'vllm_ascend.distributed.kv_transfer.kv_p2p.mooncake_connector.get_ip',
mock_get_ip),
patch(
'vllm_ascend.distributed.kv_transfer.kv_p2p.mooncake_connector.string_to_int64_hash',
mock_string_to_int64_hash),
patch(
'vllm_ascend.distributed.mooncake_connector.global_te.get_transfer_engine',
'vllm_ascend.distributed.kv_transfer.kv_p2p.mooncake_connector.global_te.get_transfer_engine',
return_value=self.mock_transfer_engine),
patch(
'vllm_ascend.distributed.mooncake_connector.global_te.register_buffer',
'vllm_ascend.distributed.kv_transfer.kv_p2p.mooncake_connector.global_te.register_buffer',
return_value=None),
patch(
'vllm_ascend.distributed.mooncake_connector.KVCacheSendingThread',
'vllm_ascend.distributed.kv_transfer.kv_p2p.mooncake_connector.KVCacheSendingThread',
MagicMock()),
patch(
'vllm_ascend.distributed.mooncake_connector.KVCacheRecvingThread',
'vllm_ascend.distributed.kv_transfer.kv_p2p.mooncake_connector.KVCacheRecvingThread',
MagicMock()),
patch('vllm_ascend.distributed.mooncake_connector.logger',
MagicMock()),
patch('vllm_ascend.distributed.mooncake_connector.threading.Event',
MagicMock()),
patch(
'vllm_ascend.distributed.mooncake_connector.get_ascend_config',
'vllm_ascend.distributed.kv_transfer.kv_p2p.mooncake_connector.logger',
MagicMock()),
patch(
'vllm_ascend.distributed.kv_transfer.kv_p2p.mooncake_connector.threading.Event',
MagicMock()),
patch(
'vllm_ascend.distributed.kv_transfer.kv_p2p.mooncake_connector.get_ascend_config',
return_value=MagicMock()),
]
@@ -1186,7 +1212,8 @@ class TestMooncakeConnectorWorker(unittest.TestCase):
def get_tp_rank(prefill_tp_size: int, prefill_pp_size: int,
decode_tp_size: int, num_kv_heads: int,
tp_num_need_pulls: int, is_deepseek_mla: bool):
with patch('vllm_ascend.distributed.mooncake_connector.get_ascend_config',
with patch(
'vllm_ascend.distributed.kv_transfer.kv_p2p.mooncake_connector.get_ascend_config',
return_value=MagicMock()), \
patch.object(self.vllm_config.kv_transfer_config, 'get_from_extra_config',
side_effect=lambda k, d=None: {

View File

@@ -15,7 +15,7 @@ fake_engine = types.ModuleType("mooncake.engine")
fake_engine.TransferEngine = MagicMock() # type: ignore[attr-defined]
sys.modules["mooncake.engine"] = fake_engine
from vllm_ascend.distributed.mooncake_layerwise_connector import ( # noqa: E402
from vllm_ascend.distributed.kv_transfer.kv_p2p.mooncake_layerwise_connector import ( # noqa: E402
KVCacheRecvingLayerThread, KVCacheSendingLayerThread, KVConnectorRole,
MooncakeAgentMetadata, MooncakeLayerwiseConnector,
MooncakeLayerwiseConnectorMetadata, MooncakeLayerwiseConnectorScheduler,
@@ -81,19 +81,20 @@ class TestKVCacheSendingLayerThread(unittest.TestCase):
chunk_finish=False)
@patch(
"vllm_ascend.distributed.mooncake_layerwise_connector.npu_stream_switch",
"vllm_ascend.distributed.kv_transfer.kv_p2p.mooncake_layerwise_connector.npu_stream_switch",
side_effect=lambda *_args, **_kwargs: contextlib.nullcontext())
@patch(
"vllm_ascend.distributed.mooncake_layerwise_connector.torch.Tensor.data_ptr",
"vllm_ascend.distributed.kv_transfer.kv_p2p.mooncake_layerwise_connector.torch.Tensor.data_ptr",
autospec=True,
return_value=0x200000)
@patch("vllm_ascend.distributed.mooncake_layerwise_connector.align_memory",
side_effect=lambda x, _align: x)
@patch(
"vllm_ascend.distributed.mooncake_layerwise_connector.torch.npu.synchronize"
"vllm_ascend.distributed.kv_transfer.kv_p2p.mooncake_layerwise_connector.align_memory",
side_effect=lambda x, _align: x)
@patch(
"vllm_ascend.distributed.kv_transfer.kv_p2p.mooncake_layerwise_connector.torch.npu.synchronize"
)
@patch(
"vllm_ascend.distributed.mooncake_layerwise_connector.group_concurrent_contiguous"
"vllm_ascend.distributed.kv_transfer.kv_p2p.mooncake_layerwise_connector.group_concurrent_contiguous"
)
def test_transfer_pd_gt1_uses_buffers_and_calls_engine(
self, mock_group, _mock_sync, _mock_align, _mock_dataptr,
@@ -171,10 +172,10 @@ class TestKVCacheSendingLayerThread(unittest.TestCase):
self.engine.batch_transfer_sync_write.assert_not_called()
@patch(
"vllm_ascend.distributed.mooncake_layerwise_connector.group_concurrent_contiguous",
"vllm_ascend.distributed.kv_transfer.kv_p2p.mooncake_layerwise_connector.group_concurrent_contiguous",
side_effect=group_concurrent_contiguous)
@patch(
"vllm_ascend.distributed.mooncake_layerwise_connector.torch.npu.synchronize"
"vllm_ascend.distributed.kv_transfer.kv_p2p.mooncake_layerwise_connector.torch.npu.synchronize"
)
def test_callback_invoked_on_final_layer(self, _mock_sync, _mock_group):
@@ -250,21 +251,27 @@ class TestKVCacheRecvingLayerThread(unittest.TestCase):
self.assertNotIn("reqX", th.task_tracker)
self.assertIn("reqX", th.done_requests)
@patch("vllm_ascend.distributed.mooncake_layerwise_connector.logger")
@patch("vllm_ascend.distributed.mooncake_layerwise_connector.get_ip",
return_value="127.0.0.1")
@patch(
"vllm_ascend.distributed.mooncake_layerwise_connector.make_zmq_socket")
"vllm_ascend.distributed.kv_transfer.kv_p2p.mooncake_layerwise_connector.logger"
)
@patch(
"vllm_ascend.distributed.mooncake_layerwise_connector.make_zmq_path",
"vllm_ascend.distributed.kv_transfer.kv_p2p.mooncake_layerwise_connector.get_ip",
return_value="127.0.0.1")
@patch(
"vllm_ascend.distributed.kv_transfer.kv_p2p.mooncake_layerwise_connector.make_zmq_socket"
)
@patch(
"vllm_ascend.distributed.kv_transfer.kv_p2p.mooncake_layerwise_connector.make_zmq_path",
side_effect=lambda proto, host, port: f"{proto}://{host}:{port}")
@patch(
"vllm_ascend.distributed.mooncake_layerwise_connector.msgspec.msgpack.Decoder"
"vllm_ascend.distributed.kv_transfer.kv_p2p.mooncake_layerwise_connector.msgspec.msgpack.Decoder"
)
@patch(
"vllm_ascend.distributed.mooncake_layerwise_connector.msgspec.msgpack.Encoder"
"vllm_ascend.distributed.kv_transfer.kv_p2p.mooncake_layerwise_connector.msgspec.msgpack.Encoder"
)
@patch(
"vllm_ascend.distributed.kv_transfer.kv_p2p.mooncake_layerwise_connector.zmq_ctx"
)
@patch("vllm_ascend.distributed.mooncake_layerwise_connector.zmq_ctx")
def test_run_loop_handles_meta_done_invalid_unexpected_and_ack(
self, mock_zmq_ctx, mock_Encoder, mock_Decoder, _mock_make_path,
_mock_make_sock, _mock_get_ip, mock_logger):
@@ -330,16 +337,21 @@ class TestKVCacheRecvingLayerThread(unittest.TestCase):
finished = th.get_and_clear_finished_requests()
self.assertIn("reqA", finished)
@patch("vllm_ascend.distributed.mooncake_layerwise_connector.logger")
@patch("vllm_ascend.distributed.mooncake_layerwise_connector.get_ip",
return_value="127.0.0.1")
@patch(
"vllm_ascend.distributed.mooncake_layerwise_connector.msgspec.msgpack.Decoder"
"vllm_ascend.distributed.kv_transfer.kv_p2p.mooncake_layerwise_connector.logger"
)
@patch(
"vllm_ascend.distributed.mooncake_layerwise_connector.msgspec.msgpack.Encoder"
"vllm_ascend.distributed.kv_transfer.kv_p2p.mooncake_layerwise_connector.get_ip",
return_value="127.0.0.1")
@patch(
"vllm_ascend.distributed.kv_transfer.kv_p2p.mooncake_layerwise_connector.msgspec.msgpack.Decoder"
)
@patch(
"vllm_ascend.distributed.kv_transfer.kv_p2p.mooncake_layerwise_connector.msgspec.msgpack.Encoder"
)
@patch(
"vllm_ascend.distributed.kv_transfer.kv_p2p.mooncake_layerwise_connector.zmq_ctx"
)
@patch("vllm_ascend.distributed.mooncake_layerwise_connector.zmq_ctx")
def test_run_loop_pd_head_ratio_gt1_requires_multiple_done(
self, mock_zmq_ctx, mock_Encoder, mock_Decoder, _mock_get_ip,
_mock_logger):
@@ -623,7 +635,7 @@ class TestMooncakeLayerwiseConnectorScheduler_More(unittest.TestCase):
self.assertEqual(len(meta.requests), 0)
@patch(
"vllm_ascend.distributed.mooncake_layerwise_connector.group_concurrent_contiguous"
"vllm_ascend.distributed.kv_transfer.kv_p2p.mooncake_layerwise_connector.group_concurrent_contiguous"
)
def test_build_connector_meta_emits_when_tokens_reach_total(
self, mock_group_concurrent_contiguous):
@@ -707,20 +719,25 @@ class TestHelperFunctions(unittest.TestCase):
pass
@patch(
"vllm_ascend.distributed.mooncake_layerwise_connector.make_zmq_socket")
"vllm_ascend.distributed.kv_transfer.kv_p2p.mooncake_layerwise_connector.make_zmq_socket"
)
def test_zmq_ctx_ok(self, mock_make_socket):
mock_socket = MagicMock()
mock_make_socket.return_value = mock_socket
with zmq_ctx(zmq.REQ, "tcp://localhost:1234") as s: # type: ignore
self.assertEqual(s, mock_socket)
@patch("vllm_ascend.distributed.mooncake_layerwise_connector.logger")
@patch(
"vllm_ascend.distributed.kv_transfer.kv_p2p.mooncake_layerwise_connector.logger"
)
def test_ensure_zmq_send_success(self, _):
mock_socket = MagicMock()
ensure_zmq_send(mock_socket, b"hello")
mock_socket.send.assert_called_once_with(b"hello")
@patch("vllm_ascend.distributed.mooncake_layerwise_connector.logger")
@patch(
"vllm_ascend.distributed.kv_transfer.kv_p2p.mooncake_layerwise_connector.logger"
)
def test_ensure_zmq_send_retry_and_fail(self, _):
mock_socket = MagicMock()
mock_socket.send.side_effect = zmq.ZMQError( # type: ignore
@@ -729,7 +746,9 @@ class TestHelperFunctions(unittest.TestCase):
ensure_zmq_send(mock_socket, b"hello", max_retries=2)
self.assertEqual(mock_socket.send.call_count, 2)
@patch("vllm_ascend.distributed.mooncake_layerwise_connector.logger")
@patch(
"vllm_ascend.distributed.kv_transfer.kv_p2p.mooncake_layerwise_connector.logger"
)
def test_ensure_zmq_recv_success(self, _):
mock_socket = MagicMock()
mock_socket.recv.return_value = b"response"
@@ -740,7 +759,9 @@ class TestHelperFunctions(unittest.TestCase):
data = ensure_zmq_recv(mock_socket, mock_poller)
self.assertEqual(data, b"response")
@patch("vllm_ascend.distributed.mooncake_layerwise_connector.logger")
@patch(
"vllm_ascend.distributed.kv_transfer.kv_p2p.mooncake_layerwise_connector.logger"
)
def test_ensure_zmq_recv_timeout_and_fail(self, _):
mock_socket = MagicMock()
mock_poller = MagicMock()
@@ -849,37 +870,37 @@ class TestMooncakeLayerwiseConnectorWorker(unittest.TestCase):
patch('math.prod', return_value=128),
patch('random.Random'),
patch(
'vllm_ascend.distributed.mooncake_layerwise_connector.get_tensor_model_parallel_rank',
'vllm_ascend.distributed.kv_transfer.kv_p2p.mooncake_layerwise_connector.get_tensor_model_parallel_rank',
return_value=0),
patch(
'vllm_ascend.distributed.mooncake_layerwise_connector.get_tp_group',
'vllm_ascend.distributed.kv_transfer.kv_p2p.mooncake_layerwise_connector.get_tp_group',
return_value=None),
patch(
'vllm_ascend.distributed.mooncake_layerwise_connector.get_ip',
'vllm_ascend.distributed.kv_transfer.kv_p2p.mooncake_layerwise_connector.get_ip',
return_value="127.0.0.1"),
patch(
'vllm_ascend.distributed.mooncake_layerwise_connector.string_to_int64_hash',
'vllm_ascend.distributed.kv_transfer.kv_p2p.mooncake_layerwise_connector.string_to_int64_hash',
side_effect=lambda s: hash(s)),
patch(
'vllm_ascend.distributed.mooncake_layerwise_connector.global_te.get_transfer_engine',
'vllm_ascend.distributed.kv_transfer.kv_p2p.mooncake_layerwise_connector.global_te.get_transfer_engine',
return_value=self.mock_transfer_engine),
patch(
'vllm_ascend.distributed.mooncake_layerwise_connector.global_te.register_buffer',
'vllm_ascend.distributed.kv_transfer.kv_p2p.mooncake_layerwise_connector.global_te.register_buffer',
return_value=None),
patch(
'vllm_ascend.distributed.mooncake_layerwise_connector.KVCacheSendingLayerThread',
'vllm_ascend.distributed.kv_transfer.kv_p2p.mooncake_layerwise_connector.KVCacheSendingLayerThread',
MagicMock()),
patch(
'vllm_ascend.distributed.mooncake_layerwise_connector.KVCacheRecvingLayerThread',
'vllm_ascend.distributed.kv_transfer.kv_p2p.mooncake_layerwise_connector.KVCacheRecvingLayerThread',
MagicMock()),
patch(
'vllm_ascend.distributed.mooncake_layerwise_connector.logger',
'vllm_ascend.distributed.kv_transfer.kv_p2p.mooncake_layerwise_connector.logger',
MagicMock()),
patch(
'vllm_ascend.distributed.mooncake_layerwise_connector.threading.Event',
'vllm_ascend.distributed.kv_transfer.kv_p2p.mooncake_layerwise_connector.threading.Event',
MagicMock()),
patch(
'vllm_ascend.distributed.mooncake_layerwise_connector.get_ascend_config',
'vllm_ascend.distributed.kv_transfer.kv_p2p.mooncake_layerwise_connector.get_ascend_config',
return_value=SimpleNamespace(pd_tp_ratio=1,
num_head_replica=1,
pd_head_ratio=1)),

View File

@@ -79,8 +79,7 @@ def create_vllm_config(
)
kv_transfer_config = KVTransferConfig(
kv_connector="MooncakeConnectorV1",
kv_role="kv_both",
kv_connector_module_path="vllm_ascend.distributed.mooncake_connector")
kv_role="kv_both")
return VllmConfig(scheduler_config=scheduler_config,
model_config=model_config,
cache_config=cache_config,