forked from EngineX-Ascend/enginex-ascend-910-vllm
init v0.11.0rc0
This commit is contained in:
@@ -7,6 +7,7 @@ import time
|
||||
import types
|
||||
import unittest
|
||||
from collections import defaultdict, deque
|
||||
from typing import OrderedDict
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import msgspec
|
||||
@@ -34,7 +35,7 @@ class TestKVCacheTaskTrackerInit(unittest.TestCase):
|
||||
tracker = KVCacheTaskTracker()
|
||||
self.assertIsInstance(tracker.done_task_lock, type(threading.Lock()))
|
||||
self.assertIsInstance(tracker.finished_requests, set)
|
||||
self.assertIsInstance(tracker.delayed_free_requests, deque)
|
||||
self.assertIsInstance(tracker.delayed_free_requests, OrderedDict)
|
||||
|
||||
|
||||
class TestGetAndClearFinishedSingleRequests(unittest.TestCase):
|
||||
@@ -495,18 +496,42 @@ class TestKVCacheTaskTracker(unittest.TestCase):
|
||||
def test_update_done_task_count(self):
    """Check the tracker's collections across add_delayed_request / update_done_task_count.

    Steps exercised, in order:

    1. All three collections on ``self.tracker`` start empty.
    2. ``add_delayed_request("req_1", t)`` stores ``t`` under ``"req_1"`` in
       ``delayed_free_requests`` and leaves ``record_finished_requests`` empty.
    3. ``update_done_task_count("req_1")`` puts ``"req_1"`` in
       ``finished_requests`` and empties ``delayed_free_requests``.
    4. ``update_done_task_count("req_2")`` for an id that was never delayed
       adds it to both ``finished_requests`` and ``record_finished_requests``.
    """
    # Step 1: nothing finished, delayed, or recorded yet.
    self.assertEqual(len(self.tracker.finished_requests), 0)
    self.assertEqual(len(self.tracker.delayed_free_requests), 0)
    self.assertEqual(len(self.tracker.record_finished_requests), 0)

    # Step 2: register a delayed request.
    current_time = time.time()
    self.tracker.add_delayed_request("req_1", current_time)
    result = self.tracker.delayed_free_requests
    result_record = self.tracker.record_finished_requests
    self.assertEqual(len(result), 1)
    # The entry is looked up by request id; the older positional check
    # (result[0]) does not apply to a mapping keyed by request id.
    self.assertEqual(result["req_1"], current_time)
    self.assertEqual(len(result_record), 0)

    # Step 3: finishing the delayed request clears its delayed entry.
    self.tracker.update_done_task_count("req_1")
    result_finished = self.tracker.finished_requests
    result_delayed = self.tracker.delayed_free_requests
    result_record = self.tracker.record_finished_requests
    self.assertEqual(result_finished, {"req_1"})
    self.assertEqual(len(result_delayed), 0)
    self.assertEqual(len(result_record), 0)

    # Step 4: finishing a request that was never delayed gets recorded.
    self.tracker.update_done_task_count("req_2")
    result_finished = self.tracker.finished_requests
    result_delayed = self.tracker.delayed_free_requests
    result_record = self.tracker.record_finished_requests
    self.assertEqual(result_finished, {"req_1", "req_2"})
    self.assertEqual(len(result_delayed), 0)
    self.assertEqual(len(result_record), 1)
    self.assertEqual(result_record, {"req_2"})
|
||||
|
||||
def test_updtate_add_delayed_request(self) -> None:
    """Check add_delayed_request for an id that was already reported done.

    ``update_done_task_count("req2")`` is expected to leave one entry in
    ``record_finished_requests``. A later ``add_delayed_request`` for the
    same id is expected to leave both ``delayed_free_requests`` and
    ``record_finished_requests`` empty.
    """
    request_id = "req2"

    # Completion arrives first: exactly one record is expected.
    self.tracker.update_done_task_count(request_id)
    self.assertEqual(len(self.tracker.record_finished_requests), 1)

    # The delayed-free registration arrives afterwards.
    self.tracker.add_delayed_request(request_id, time.time())
    delayed_after = self.tracker.delayed_free_requests
    recorded_after = self.tracker.record_finished_requests
    self.assertEqual(len(delayed_after), 0)
    self.assertEqual(len(recorded_after), 0)
|
||||
|
||||
def test_retrieve_expired_requests(self):
|
||||
current_time = time.time()
|
||||
@@ -518,7 +543,7 @@ class TestKVCacheTaskTracker(unittest.TestCase):
|
||||
})
|
||||
result_delay = self.tracker.delayed_free_requests
|
||||
self.assertEqual(len(result_delay), 1)
|
||||
self.assertEqual(result_delay[0], ("req_2", current_time))
|
||||
self.assertIn("req_2", result_delay)
|
||||
|
||||
def test_duplicate_task_update(self):
|
||||
self.tracker.update_done_task_count("req1")
|
||||
@@ -961,6 +986,46 @@ class TestMooncakeConnectorWorker(unittest.TestCase):
|
||||
for p in self.patches:
|
||||
p.stop() # type: ignore
|
||||
|
||||
def test_worker_use_ascend_direct(self):
    """Build a MooncakeConnectorWorker with ``use_ascend_direct`` on and off.

    Each flag value runs as its own sub-test against a fully mocked config,
    with ``get_tensor_model_parallel_rank``, ``get_tp_group`` and ``get_ip``
    patched in ``vllm_ascend.distributed.mooncake_connector``. The only
    assertion is that construction yields a non-None worker.
    """
    for use_ascend_direct in [True, False]:
        with self.subTest(use_ascend_direct=use_ascend_direct):

            def extra_config_lookup(k, d):
                # Stand-in for get_from_extra_config: answer from a fixed
                # table, falling back to the caller-supplied default.
                table = {
                    "prefill": {
                        "tp_size": 2,
                        "dp_size": 1
                    },
                    "decode": {
                        "tp_size": 2,
                        "dp_size": 1
                    },
                    "use_ascend_direct": use_ascend_direct,
                }
                return table.get(k, d)

            config = MagicMock()

            kv_cfg = MagicMock()
            config.kv_transfer_config = kv_cfg
            kv_cfg.get_from_extra_config.side_effect = extra_config_lookup

            parallel_cfg = MagicMock()
            config.parallel_config = parallel_cfg
            parallel_cfg.tensor_parallel_size = 2
            parallel_cfg.data_parallel_rank_local = 0
            parallel_cfg.data_parallel_size_local = 1

            kv_cfg.kv_port = 8000
            kv_cfg.kv_role = 'worker'

            # Patchers are created up front and entered together, in the
            # same order as the original nested blocks.
            rank_patch = patch(
                "vllm_ascend.distributed.mooncake_connector.get_tensor_model_parallel_rank",
                return_value=0)
            tp_group_patch = patch(
                "vllm_ascend.distributed.mooncake_connector.get_tp_group",
                return_value=None)
            ip_patch = patch(
                "vllm_ascend.distributed.mooncake_connector.get_ip",
                return_value="127.0.0.1")

            with rank_patch, tp_group_patch, ip_patch:
                worker = MooncakeConnectorWorker(config, self.engine_id)
                self.assertIsNotNone(worker)
|
||||
|
||||
def test_register_kv_caches_producer(self):
|
||||
worker = MooncakeConnectorWorker(self.vllm_config, self.engine_id)
|
||||
worker.register_kv_caches(self.kv_caches)
|
||||
|
||||
Reference in New Issue
Block a user