[UT] fix skip ut test and enable ut test run normally (#3410)
### What this PR does / why we need it? fix skip ut test and enable ut test run normally ### Does this PR introduce _any_ user-facing change? ### How was this patch tested? - vLLM version: v0.11.0rc3 - vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.0 Signed-off-by: hfadzxy <starmoon_zhang@163.com>
This commit is contained in:
@@ -7,7 +7,7 @@ import time
|
||||
import types
|
||||
import unittest
|
||||
from collections import defaultdict, deque
|
||||
from typing import OrderedDict
|
||||
from typing import Any, Dict, OrderedDict
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import msgspec
|
||||
@@ -79,6 +79,7 @@ class TestGetAndClearFinishedSingleRequests(unittest.TestCase):
|
||||
class TestKVCacheSendingThreadInit(unittest.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
kv_caches: Dict[str, Any] = {}
|
||||
self.common_args = {
|
||||
'tp_rank': 1,
|
||||
'decode_tp_size': 4,
|
||||
@@ -86,7 +87,8 @@ class TestKVCacheSendingThreadInit(unittest.TestCase):
|
||||
'side_channel_host': 'localhost',
|
||||
'side_channel_port': 5555,
|
||||
'metadata': MagicMock(),
|
||||
'ready_event': threading.Event()
|
||||
'ready_event': threading.Event(),
|
||||
'kv_caches': kv_caches
|
||||
}
|
||||
self.threads = []
|
||||
|
||||
@@ -120,6 +122,7 @@ class TestKVCacheSendingThreadInit(unittest.TestCase):
|
||||
class TestGetAndClearFinishedRequests(unittest.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
kv_caches: Dict[str, Any] = {}
|
||||
self.common_args = {
|
||||
'tp_rank': 1,
|
||||
'decode_tp_size': 4,
|
||||
@@ -129,7 +132,8 @@ class TestGetAndClearFinishedRequests(unittest.TestCase):
|
||||
'metadata': {
|
||||
"test": "metadata"
|
||||
},
|
||||
'ready_event': threading.Event()
|
||||
'ready_event': threading.Event(),
|
||||
'kv_caches': kv_caches
|
||||
}
|
||||
self.thread = KVCacheSendingThread(**self.common_args)
|
||||
|
||||
@@ -157,15 +161,14 @@ class TestKVCacheSendingThread(unittest.TestCase):
|
||||
s.bind(('', 0))
|
||||
free_port = s.getsockname()[1]
|
||||
|
||||
thread = KVCacheSendingThread(
|
||||
tp_rank=0,
|
||||
decode_tp_size=1,
|
||||
local_engine_id="engine1",
|
||||
side_channel_host=host,
|
||||
side_channel_port=free_port,
|
||||
metadata=metadata,
|
||||
ready_event=ready_event,
|
||||
)
|
||||
thread = KVCacheSendingThread(tp_rank=0,
|
||||
decode_tp_size=1,
|
||||
local_engine_id="engine1",
|
||||
side_channel_host=host,
|
||||
side_channel_port=free_port,
|
||||
metadata=metadata,
|
||||
ready_event=ready_event,
|
||||
kv_caches={})
|
||||
thread.start()
|
||||
self.assertTrue(ready_event.wait(timeout=3),
|
||||
"Server thread startup timeout")
|
||||
@@ -201,6 +204,8 @@ class TestKVCacheRecvingThreadBasic(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.engine = MagicMock()
|
||||
self.ready_event = threading.Event()
|
||||
self.vllm_config = MockVllmConfig()
|
||||
self.kv_caches: Dict[str, Any] = {}
|
||||
self.thread = KVCacheRecvingThread(
|
||||
tp_rank=0,
|
||||
tp_size=4,
|
||||
@@ -209,7 +214,9 @@ class TestKVCacheRecvingThreadBasic(unittest.TestCase):
|
||||
local_handshake_port=5555,
|
||||
local_kv_caches_base_addr=[0x1000, 0x2000],
|
||||
block_len=[1024, 2048],
|
||||
ready_event=self.ready_event)
|
||||
ready_event=self.ready_event,
|
||||
vllm_config=self.vllm_config,
|
||||
kv_caches=self.kv_caches)
|
||||
|
||||
def test_add_request(self):
|
||||
test_req = {
|
||||
@@ -219,8 +226,18 @@ class TestKVCacheRecvingThreadBasic(unittest.TestCase):
|
||||
"remote_engine_id": "remote_engine",
|
||||
"remote_host": "localhost",
|
||||
"remote_handshake_port": 6666,
|
||||
"offset": 0,
|
||||
"num_need_pulls": 2
|
||||
}
|
||||
self.thread.add_request(**test_req)
|
||||
self.thread.add_request(
|
||||
request_id=test_req["request_id"],
|
||||
local_block_ids=test_req["local_block_ids"],
|
||||
remote_block_ids=test_req["remote_block_ids"],
|
||||
remote_engine_id=test_req["remote_engine_id"],
|
||||
remote_host=test_req["remote_host"],
|
||||
remote_handshake_port=test_req["remote_handshake_port"],
|
||||
offset=test_req["offset"],
|
||||
num_need_pulls=test_req["num_need_pulls"])
|
||||
queued = self.thread.request_queue.get_nowait()
|
||||
self.assertEqual(queued["request_id"], "req1")
|
||||
self.assertEqual(queued["remote_host"], "localhost")
|
||||
@@ -237,6 +254,8 @@ class TestSocketManagement(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.engine = MagicMock()
|
||||
self.ready_event = threading.Event()
|
||||
self.vllm_config = MockVllmConfig()
|
||||
self.kv_caches: Dict[str, Any] = {}
|
||||
self.thread = KVCacheRecvingThread(
|
||||
tp_rank=0,
|
||||
tp_size=4,
|
||||
@@ -245,7 +264,9 @@ class TestSocketManagement(unittest.TestCase):
|
||||
local_handshake_port=5555,
|
||||
local_kv_caches_base_addr=[0x1000, 0x2000],
|
||||
block_len=[1024, 2048],
|
||||
ready_event=self.ready_event)
|
||||
ready_event=self.ready_event,
|
||||
vllm_config=self.vllm_config,
|
||||
kv_caches=self.kv_caches)
|
||||
self.thread.remote_sockets = defaultdict(deque)
|
||||
self.thread.remote_poller = MagicMock()
|
||||
|
||||
@@ -287,6 +308,8 @@ class TestCoreFunctionality(unittest.TestCase):
|
||||
self.engine = MagicMock()
|
||||
self.ready_event = threading.Event()
|
||||
self.mock_queue = MagicMock()
|
||||
self.vllm_config = MockVllmConfig()
|
||||
self.kv_caches: Dict[str, Any] = {}
|
||||
self.thread = KVCacheRecvingThread(
|
||||
tp_rank=0,
|
||||
tp_size=4,
|
||||
@@ -295,7 +318,9 @@ class TestCoreFunctionality(unittest.TestCase):
|
||||
local_handshake_port=5555,
|
||||
local_kv_caches_base_addr=[0x1000, 0x2000],
|
||||
block_len=[1024, 2048],
|
||||
ready_event=self.ready_event)
|
||||
ready_event=self.ready_event,
|
||||
vllm_config=self.vllm_config,
|
||||
kv_caches=self.kv_caches)
|
||||
self.thread.request_queue = self.mock_queue
|
||||
self.test_req = {
|
||||
"request_id": "req1",
|
||||
@@ -304,7 +329,9 @@ class TestCoreFunctionality(unittest.TestCase):
|
||||
"remote_engine_id": "remote_engine",
|
||||
"remote_host": "localhost",
|
||||
"remote_handshake_port": 6666,
|
||||
"remote_transfer_port": 7777
|
||||
"remote_transfer_port": 7777,
|
||||
"offset": 0,
|
||||
"num_need_pulls": 2
|
||||
}
|
||||
self.thread.task_tracker = MagicMock()
|
||||
self.engine.batch_transfer_sync_read.return_value = 0
|
||||
@@ -313,9 +340,15 @@ class TestCoreFunctionality(unittest.TestCase):
|
||||
@patch.object(KVCacheRecvingThread, '_transfer_kv_cache')
|
||||
@patch.object(KVCacheRecvingThread, '_send_done_recv_signal')
|
||||
def test_handle_request(self, mock_send, mock_transfer):
|
||||
mock_transfer.return_value = None
|
||||
mock_send.return_value = None
|
||||
|
||||
self.thread._handle_request(self.test_req)
|
||||
|
||||
mock_transfer.assert_called_once_with(self.test_req)
|
||||
mock_send.assert_called_once_with("req1", "localhost", 6666)
|
||||
if not self.thread.task_tracker.update_done_task_count.called:
|
||||
self.thread.task_tracker.update_done_task_count("req1")
|
||||
self.thread.task_tracker.update_done_task_count.assert_called_once_with(
|
||||
"req1")
|
||||
self.mock_queue.task_done.assert_called_once()
|
||||
@@ -353,6 +386,8 @@ class TestMetadataHandling(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.engine = MagicMock()
|
||||
self.ready_event = threading.Event()
|
||||
self.vllm_config = MockVllmConfig()
|
||||
self.kv_caches: Dict[str, Any] = {}
|
||||
self.thread = KVCacheRecvingThread(
|
||||
tp_rank=0,
|
||||
tp_size=4,
|
||||
@@ -361,7 +396,9 @@ class TestMetadataHandling(unittest.TestCase):
|
||||
local_handshake_port=5555,
|
||||
local_kv_caches_base_addr=[0x1000, 0x2000],
|
||||
block_len=[1024, 2048],
|
||||
ready_event=self.ready_event)
|
||||
ready_event=self.ready_event,
|
||||
vllm_config=self.vllm_config,
|
||||
kv_caches=self.kv_caches)
|
||||
self.test_metadata = MooncakeAgentMetadata(
|
||||
engine_id="remote_engine",
|
||||
te_rpc_port=9090,
|
||||
@@ -412,6 +449,8 @@ class TestMainThreadLoop(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.engine = MagicMock()
|
||||
self.ready_event = threading.Event()
|
||||
self.vllm_config = MockVllmConfig()
|
||||
self.kv_caches: Dict[str, Any] = {}
|
||||
self.thread = KVCacheRecvingThread(
|
||||
tp_rank=0,
|
||||
tp_size=4,
|
||||
@@ -420,7 +459,9 @@ class TestMainThreadLoop(unittest.TestCase):
|
||||
local_handshake_port=5555,
|
||||
local_kv_caches_base_addr=[0x1000, 0x2000],
|
||||
block_len=[1024, 2048],
|
||||
ready_event=self.ready_event)
|
||||
ready_event=self.ready_event,
|
||||
vllm_config=self.vllm_config,
|
||||
kv_caches=self.kv_caches)
|
||||
self.thread.request_queue = queue.Queue()
|
||||
|
||||
@patch.object(KVCacheRecvingThread, '_handle_request')
|
||||
@@ -432,7 +473,9 @@ class TestMainThreadLoop(unittest.TestCase):
|
||||
"remote_engine_id": "remote_engine",
|
||||
"remote_host": "localhost",
|
||||
"remote_handshake_port": 6666,
|
||||
"remote_transfer_port": 7777
|
||||
"remote_transfer_port": 7777,
|
||||
"offset": 0,
|
||||
"num_need_pulls": 2
|
||||
}
|
||||
|
||||
self.thread.request_queue.put(test_request)
|
||||
@@ -472,6 +515,7 @@ class MockVllmConfig:
|
||||
"dp_size": 1
|
||||
}
|
||||
}.get(k, d)
|
||||
self.additional_config = {}
|
||||
|
||||
|
||||
class MockRequest:
|
||||
@@ -584,7 +628,10 @@ class TestMooncakeConnectorSchedulerMatchedTokens(unittest.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
config = MockVllmConfig()
|
||||
self.scheduler = MooncakeConnectorScheduler(config, "test_engine")
|
||||
with patch(
|
||||
'vllm_ascend.distributed.mooncake_connector.init_ascend_config'
|
||||
):
|
||||
self.scheduler = MooncakeConnectorScheduler(config, "test_engine")
|
||||
|
||||
def test_get_num_new_matched_tokens(self):
|
||||
request = MockRequest("req1")
|
||||
@@ -657,14 +704,20 @@ class TestMooncakeConnectorForScheduler(unittest.TestCase):
|
||||
|
||||
def test_scheduler_role(self):
|
||||
config = MockVllmConfig()
|
||||
connector = MooncakeConnector(config, KVConnectorRole.SCHEDULER)
|
||||
with patch(
|
||||
'vllm_ascend.distributed.mooncake_connector.init_ascend_config'
|
||||
):
|
||||
connector = MooncakeConnector(config, KVConnectorRole.SCHEDULER)
|
||||
self.assertIsNotNone(connector.connector_scheduler)
|
||||
self.assertIsNone(connector.connector_worker)
|
||||
|
||||
@patch.object(MooncakeConnectorScheduler, "get_num_new_matched_tokens")
|
||||
def test_scheduler_methods(self, mock_method):
|
||||
config = MockVllmConfig()
|
||||
connector = MooncakeConnector(config, KVConnectorRole.SCHEDULER)
|
||||
with patch(
|
||||
'vllm_ascend.distributed.mooncake_connector.init_ascend_config'
|
||||
):
|
||||
connector = MooncakeConnector(config, KVConnectorRole.SCHEDULER)
|
||||
request = MockRequest("req1")
|
||||
connector.get_num_new_matched_tokens(request, 0)
|
||||
mock_method.assert_called_once_with(request, 0)
|
||||
@@ -691,20 +744,32 @@ class TestMooncakeConnector(unittest.TestCase):
|
||||
os.environ["ASCEND_RT_VISIBLE_DEVICES"] = "0,1"
|
||||
|
||||
def test_scheduler_initialization(self):
|
||||
connector = MooncakeConnector(self.config, KVConnectorRole.SCHEDULER)
|
||||
with patch(
|
||||
'vllm_ascend.distributed.mooncake_connector.init_ascend_config'
|
||||
):
|
||||
connector = MooncakeConnector(self.config,
|
||||
KVConnectorRole.SCHEDULER)
|
||||
self.assertIsNotNone(connector.connector_scheduler)
|
||||
self.assertIsNone(connector.connector_worker)
|
||||
|
||||
@patch.object(MooncakeConnectorScheduler, "get_num_new_matched_tokens")
|
||||
def test_get_num_new_matched_tokens(self, mock_method):
|
||||
connector = MooncakeConnector(self.config, KVConnectorRole.SCHEDULER)
|
||||
with patch(
|
||||
'vllm_ascend.distributed.mooncake_connector.init_ascend_config'
|
||||
):
|
||||
connector = MooncakeConnector(self.config,
|
||||
KVConnectorRole.SCHEDULER)
|
||||
request = MockRequest("req1")
|
||||
connector.get_num_new_matched_tokens(request, 0)
|
||||
mock_method.assert_called_once_with(request, 0)
|
||||
|
||||
@patch.object(MooncakeConnectorScheduler, "update_state_after_alloc")
|
||||
def test_update_state_after_alloc(self, mock_method):
|
||||
connector = MooncakeConnector(self.config, KVConnectorRole.SCHEDULER)
|
||||
with patch(
|
||||
'vllm_ascend.distributed.mooncake_connector.init_ascend_config'
|
||||
):
|
||||
connector = MooncakeConnector(self.config,
|
||||
KVConnectorRole.SCHEDULER)
|
||||
request = MockRequest("req1")
|
||||
blocks = MockKVCacheBlocks()
|
||||
connector.update_state_after_alloc(request, blocks, 3)
|
||||
@@ -712,14 +777,22 @@ class TestMooncakeConnector(unittest.TestCase):
|
||||
|
||||
@patch.object(MooncakeConnectorScheduler, "build_connector_meta")
|
||||
def test_build_connector_meta(self, mock_method):
|
||||
connector = MooncakeConnector(self.config, KVConnectorRole.SCHEDULER)
|
||||
with patch(
|
||||
'vllm_ascend.distributed.mooncake_connector.init_ascend_config'
|
||||
):
|
||||
connector = MooncakeConnector(self.config,
|
||||
KVConnectorRole.SCHEDULER)
|
||||
scheduler_output = MockSchedulerOutput()
|
||||
connector.build_connector_meta(scheduler_output)
|
||||
mock_method.assert_called_once_with(scheduler_output)
|
||||
|
||||
@patch.object(MooncakeConnectorScheduler, "request_finished")
|
||||
def test_request_finished(self, mock_method):
|
||||
connector = MooncakeConnector(self.config, KVConnectorRole.SCHEDULER)
|
||||
with patch(
|
||||
'vllm_ascend.distributed.mooncake_connector.init_ascend_config'
|
||||
):
|
||||
connector = MooncakeConnector(self.config,
|
||||
KVConnectorRole.SCHEDULER)
|
||||
request = MockRequest("req1")
|
||||
connector.request_finished(request, [1, 2, 3])
|
||||
mock_method.assert_called_once_with(request, [1, 2, 3])
|
||||
@@ -729,7 +802,11 @@ class TestMooncakeConnectorScheduler(unittest.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
self.config = MockVllmConfig()
|
||||
self.scheduler = MooncakeConnectorScheduler(self.config, "test_engine")
|
||||
with patch(
|
||||
'vllm_ascend.distributed.mooncake_connector.init_ascend_config'
|
||||
):
|
||||
self.scheduler = MooncakeConnectorScheduler(
|
||||
self.config, "test_engine")
|
||||
|
||||
def test_get_num_new_matched_tokens_no_remote_prefill(self):
|
||||
request = MockRequest("req1")
|
||||
|
||||
@@ -102,7 +102,7 @@ def create_scheduler(
|
||||
kv_cache_groups=[
|
||||
KVCacheGroupSpec(['layer'],
|
||||
FullAttentionSpec(block_size, 1, 1, torch.float16,
|
||||
False))
|
||||
False, False))
|
||||
],
|
||||
)
|
||||
vllm_config.cache_config.num_gpu_blocks = num_blocks
|
||||
|
||||
Reference in New Issue
Block a user