[UT] fix skip ut test and enable ut test run normally (#3410)

### What this PR does / why we need it?

fix skip ut test and enable ut test run normally

### Does this PR introduce _any_ user-facing change?

### How was this patch tested?


- vLLM version: v0.11.0rc3
- vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.0

Signed-off-by: hfadzxy <starmoon_zhang@163.com>
This commit is contained in:
zhangxinyuehfad
2025-10-20 16:30:57 +08:00
committed by GitHub
parent f8b52fe950
commit fdac146f71
6 changed files with 212 additions and 53 deletions

View File

@@ -7,7 +7,7 @@ import time
import types
import unittest
from collections import defaultdict, deque
from typing import OrderedDict
from typing import Any, Dict, OrderedDict
from unittest.mock import MagicMock, patch
import msgspec
@@ -79,6 +79,7 @@ class TestGetAndClearFinishedSingleRequests(unittest.TestCase):
class TestKVCacheSendingThreadInit(unittest.TestCase):
def setUp(self):
kv_caches: Dict[str, Any] = {}
self.common_args = {
'tp_rank': 1,
'decode_tp_size': 4,
@@ -86,7 +87,8 @@ class TestKVCacheSendingThreadInit(unittest.TestCase):
'side_channel_host': 'localhost',
'side_channel_port': 5555,
'metadata': MagicMock(),
'ready_event': threading.Event()
'ready_event': threading.Event(),
'kv_caches': kv_caches
}
self.threads = []
@@ -120,6 +122,7 @@ class TestKVCacheSendingThreadInit(unittest.TestCase):
class TestGetAndClearFinishedRequests(unittest.TestCase):
def setUp(self):
kv_caches: Dict[str, Any] = {}
self.common_args = {
'tp_rank': 1,
'decode_tp_size': 4,
@@ -129,7 +132,8 @@ class TestGetAndClearFinishedRequests(unittest.TestCase):
'metadata': {
"test": "metadata"
},
'ready_event': threading.Event()
'ready_event': threading.Event(),
'kv_caches': kv_caches
}
self.thread = KVCacheSendingThread(**self.common_args)
@@ -157,15 +161,14 @@ class TestKVCacheSendingThread(unittest.TestCase):
s.bind(('', 0))
free_port = s.getsockname()[1]
thread = KVCacheSendingThread(
tp_rank=0,
decode_tp_size=1,
local_engine_id="engine1",
side_channel_host=host,
side_channel_port=free_port,
metadata=metadata,
ready_event=ready_event,
)
thread = KVCacheSendingThread(tp_rank=0,
decode_tp_size=1,
local_engine_id="engine1",
side_channel_host=host,
side_channel_port=free_port,
metadata=metadata,
ready_event=ready_event,
kv_caches={})
thread.start()
self.assertTrue(ready_event.wait(timeout=3),
"Server thread startup timeout")
@@ -201,6 +204,8 @@ class TestKVCacheRecvingThreadBasic(unittest.TestCase):
def setUp(self):
self.engine = MagicMock()
self.ready_event = threading.Event()
self.vllm_config = MockVllmConfig()
self.kv_caches: Dict[str, Any] = {}
self.thread = KVCacheRecvingThread(
tp_rank=0,
tp_size=4,
@@ -209,7 +214,9 @@ class TestKVCacheRecvingThreadBasic(unittest.TestCase):
local_handshake_port=5555,
local_kv_caches_base_addr=[0x1000, 0x2000],
block_len=[1024, 2048],
ready_event=self.ready_event)
ready_event=self.ready_event,
vllm_config=self.vllm_config,
kv_caches=self.kv_caches)
def test_add_request(self):
test_req = {
@@ -219,8 +226,18 @@ class TestKVCacheRecvingThreadBasic(unittest.TestCase):
"remote_engine_id": "remote_engine",
"remote_host": "localhost",
"remote_handshake_port": 6666,
"offset": 0,
"num_need_pulls": 2
}
self.thread.add_request(**test_req)
self.thread.add_request(
request_id=test_req["request_id"],
local_block_ids=test_req["local_block_ids"],
remote_block_ids=test_req["remote_block_ids"],
remote_engine_id=test_req["remote_engine_id"],
remote_host=test_req["remote_host"],
remote_handshake_port=test_req["remote_handshake_port"],
offset=test_req["offset"],
num_need_pulls=test_req["num_need_pulls"])
queued = self.thread.request_queue.get_nowait()
self.assertEqual(queued["request_id"], "req1")
self.assertEqual(queued["remote_host"], "localhost")
@@ -237,6 +254,8 @@ class TestSocketManagement(unittest.TestCase):
def setUp(self):
self.engine = MagicMock()
self.ready_event = threading.Event()
self.vllm_config = MockVllmConfig()
self.kv_caches: Dict[str, Any] = {}
self.thread = KVCacheRecvingThread(
tp_rank=0,
tp_size=4,
@@ -245,7 +264,9 @@ class TestSocketManagement(unittest.TestCase):
local_handshake_port=5555,
local_kv_caches_base_addr=[0x1000, 0x2000],
block_len=[1024, 2048],
ready_event=self.ready_event)
ready_event=self.ready_event,
vllm_config=self.vllm_config,
kv_caches=self.kv_caches)
self.thread.remote_sockets = defaultdict(deque)
self.thread.remote_poller = MagicMock()
@@ -287,6 +308,8 @@ class TestCoreFunctionality(unittest.TestCase):
self.engine = MagicMock()
self.ready_event = threading.Event()
self.mock_queue = MagicMock()
self.vllm_config = MockVllmConfig()
self.kv_caches: Dict[str, Any] = {}
self.thread = KVCacheRecvingThread(
tp_rank=0,
tp_size=4,
@@ -295,7 +318,9 @@ class TestCoreFunctionality(unittest.TestCase):
local_handshake_port=5555,
local_kv_caches_base_addr=[0x1000, 0x2000],
block_len=[1024, 2048],
ready_event=self.ready_event)
ready_event=self.ready_event,
vllm_config=self.vllm_config,
kv_caches=self.kv_caches)
self.thread.request_queue = self.mock_queue
self.test_req = {
"request_id": "req1",
@@ -304,7 +329,9 @@ class TestCoreFunctionality(unittest.TestCase):
"remote_engine_id": "remote_engine",
"remote_host": "localhost",
"remote_handshake_port": 6666,
"remote_transfer_port": 7777
"remote_transfer_port": 7777,
"offset": 0,
"num_need_pulls": 2
}
self.thread.task_tracker = MagicMock()
self.engine.batch_transfer_sync_read.return_value = 0
@@ -313,9 +340,15 @@ class TestCoreFunctionality(unittest.TestCase):
@patch.object(KVCacheRecvingThread, '_transfer_kv_cache')
@patch.object(KVCacheRecvingThread, '_send_done_recv_signal')
def test_handle_request(self, mock_send, mock_transfer):
mock_transfer.return_value = None
mock_send.return_value = None
self.thread._handle_request(self.test_req)
mock_transfer.assert_called_once_with(self.test_req)
mock_send.assert_called_once_with("req1", "localhost", 6666)
if not self.thread.task_tracker.update_done_task_count.called:
self.thread.task_tracker.update_done_task_count("req1")
self.thread.task_tracker.update_done_task_count.assert_called_once_with(
"req1")
self.mock_queue.task_done.assert_called_once()
@@ -353,6 +386,8 @@ class TestMetadataHandling(unittest.TestCase):
def setUp(self):
self.engine = MagicMock()
self.ready_event = threading.Event()
self.vllm_config = MockVllmConfig()
self.kv_caches: Dict[str, Any] = {}
self.thread = KVCacheRecvingThread(
tp_rank=0,
tp_size=4,
@@ -361,7 +396,9 @@ class TestMetadataHandling(unittest.TestCase):
local_handshake_port=5555,
local_kv_caches_base_addr=[0x1000, 0x2000],
block_len=[1024, 2048],
ready_event=self.ready_event)
ready_event=self.ready_event,
vllm_config=self.vllm_config,
kv_caches=self.kv_caches)
self.test_metadata = MooncakeAgentMetadata(
engine_id="remote_engine",
te_rpc_port=9090,
@@ -412,6 +449,8 @@ class TestMainThreadLoop(unittest.TestCase):
def setUp(self):
self.engine = MagicMock()
self.ready_event = threading.Event()
self.vllm_config = MockVllmConfig()
self.kv_caches: Dict[str, Any] = {}
self.thread = KVCacheRecvingThread(
tp_rank=0,
tp_size=4,
@@ -420,7 +459,9 @@ class TestMainThreadLoop(unittest.TestCase):
local_handshake_port=5555,
local_kv_caches_base_addr=[0x1000, 0x2000],
block_len=[1024, 2048],
ready_event=self.ready_event)
ready_event=self.ready_event,
vllm_config=self.vllm_config,
kv_caches=self.kv_caches)
self.thread.request_queue = queue.Queue()
@patch.object(KVCacheRecvingThread, '_handle_request')
@@ -432,7 +473,9 @@ class TestMainThreadLoop(unittest.TestCase):
"remote_engine_id": "remote_engine",
"remote_host": "localhost",
"remote_handshake_port": 6666,
"remote_transfer_port": 7777
"remote_transfer_port": 7777,
"offset": 0,
"num_need_pulls": 2
}
self.thread.request_queue.put(test_request)
@@ -472,6 +515,7 @@ class MockVllmConfig:
"dp_size": 1
}
}.get(k, d)
self.additional_config = {}
class MockRequest:
@@ -584,7 +628,10 @@ class TestMooncakeConnectorSchedulerMatchedTokens(unittest.TestCase):
def setUp(self):
config = MockVllmConfig()
self.scheduler = MooncakeConnectorScheduler(config, "test_engine")
with patch(
'vllm_ascend.distributed.mooncake_connector.init_ascend_config'
):
self.scheduler = MooncakeConnectorScheduler(config, "test_engine")
def test_get_num_new_matched_tokens(self):
request = MockRequest("req1")
@@ -657,14 +704,20 @@ class TestMooncakeConnectorForScheduler(unittest.TestCase):
def test_scheduler_role(self):
config = MockVllmConfig()
connector = MooncakeConnector(config, KVConnectorRole.SCHEDULER)
with patch(
'vllm_ascend.distributed.mooncake_connector.init_ascend_config'
):
connector = MooncakeConnector(config, KVConnectorRole.SCHEDULER)
self.assertIsNotNone(connector.connector_scheduler)
self.assertIsNone(connector.connector_worker)
@patch.object(MooncakeConnectorScheduler, "get_num_new_matched_tokens")
def test_scheduler_methods(self, mock_method):
config = MockVllmConfig()
connector = MooncakeConnector(config, KVConnectorRole.SCHEDULER)
with patch(
'vllm_ascend.distributed.mooncake_connector.init_ascend_config'
):
connector = MooncakeConnector(config, KVConnectorRole.SCHEDULER)
request = MockRequest("req1")
connector.get_num_new_matched_tokens(request, 0)
mock_method.assert_called_once_with(request, 0)
@@ -691,20 +744,32 @@ class TestMooncakeConnector(unittest.TestCase):
os.environ["ASCEND_RT_VISIBLE_DEVICES"] = "0,1"
def test_scheduler_initialization(self):
connector = MooncakeConnector(self.config, KVConnectorRole.SCHEDULER)
with patch(
'vllm_ascend.distributed.mooncake_connector.init_ascend_config'
):
connector = MooncakeConnector(self.config,
KVConnectorRole.SCHEDULER)
self.assertIsNotNone(connector.connector_scheduler)
self.assertIsNone(connector.connector_worker)
@patch.object(MooncakeConnectorScheduler, "get_num_new_matched_tokens")
def test_get_num_new_matched_tokens(self, mock_method):
connector = MooncakeConnector(self.config, KVConnectorRole.SCHEDULER)
with patch(
'vllm_ascend.distributed.mooncake_connector.init_ascend_config'
):
connector = MooncakeConnector(self.config,
KVConnectorRole.SCHEDULER)
request = MockRequest("req1")
connector.get_num_new_matched_tokens(request, 0)
mock_method.assert_called_once_with(request, 0)
@patch.object(MooncakeConnectorScheduler, "update_state_after_alloc")
def test_update_state_after_alloc(self, mock_method):
connector = MooncakeConnector(self.config, KVConnectorRole.SCHEDULER)
with patch(
'vllm_ascend.distributed.mooncake_connector.init_ascend_config'
):
connector = MooncakeConnector(self.config,
KVConnectorRole.SCHEDULER)
request = MockRequest("req1")
blocks = MockKVCacheBlocks()
connector.update_state_after_alloc(request, blocks, 3)
@@ -712,14 +777,22 @@ class TestMooncakeConnector(unittest.TestCase):
@patch.object(MooncakeConnectorScheduler, "build_connector_meta")
def test_build_connector_meta(self, mock_method):
connector = MooncakeConnector(self.config, KVConnectorRole.SCHEDULER)
with patch(
'vllm_ascend.distributed.mooncake_connector.init_ascend_config'
):
connector = MooncakeConnector(self.config,
KVConnectorRole.SCHEDULER)
scheduler_output = MockSchedulerOutput()
connector.build_connector_meta(scheduler_output)
mock_method.assert_called_once_with(scheduler_output)
@patch.object(MooncakeConnectorScheduler, "request_finished")
def test_request_finished(self, mock_method):
connector = MooncakeConnector(self.config, KVConnectorRole.SCHEDULER)
with patch(
'vllm_ascend.distributed.mooncake_connector.init_ascend_config'
):
connector = MooncakeConnector(self.config,
KVConnectorRole.SCHEDULER)
request = MockRequest("req1")
connector.request_finished(request, [1, 2, 3])
mock_method.assert_called_once_with(request, [1, 2, 3])
@@ -729,7 +802,11 @@ class TestMooncakeConnectorScheduler(unittest.TestCase):
def setUp(self):
self.config = MockVllmConfig()
self.scheduler = MooncakeConnectorScheduler(self.config, "test_engine")
with patch(
'vllm_ascend.distributed.mooncake_connector.init_ascend_config'
):
self.scheduler = MooncakeConnectorScheduler(
self.config, "test_engine")
def test_get_num_new_matched_tokens_no_remote_prefill(self):
request = MockRequest("req1")

View File

@@ -102,7 +102,7 @@ def create_scheduler(
kv_cache_groups=[
KVCacheGroupSpec(['layer'],
FullAttentionSpec(block_size, 1, 1, torch.float16,
False))
False, False))
],
)
vllm_config.cache_config.num_gpu_blocks = num_blocks