[P/D]mooncake_connector adapted to 0.10.1 (#2664)
### What this PR does / why we need it?
In vllm version 0.10.1, a new KVOutputAggregator was added to the
executor, moving aggregation to the
executor(https://github.com/vllm-project/vllm/pull/19555). This caused
mooncake_connector to break. This change aims to fix this bug and also
adds a policy to forcibly release the KV cache when the prefill node
times out.
This PR is currently linked to a PR in vllm
(https://github.com/vllm-project/vllm/pull/23917). The vllm PR aims to
modify the finish and send count confirmation in heterogeneous TP
situations.
The reason for deleting many UTs is that a lot of communication codes
have been deleted, so the UT as a whole will appear more concise.
- vLLM version: v0.10.1.1
- vLLM main:
fa4311d85f
---------
Signed-off-by: baxingpiaochong <771405853@qq.com>
This commit is contained in:
@@ -12,7 +12,6 @@ from unittest.mock import MagicMock, patch
|
||||
import msgspec
|
||||
import zmq
|
||||
from vllm.utils import make_zmq_path
|
||||
from zmq import Context # type: ignore
|
||||
|
||||
fake_engine = types.ModuleType("mooncake.engine")
|
||||
fake_engine.TransferEngine = MagicMock() # type: ignore[attr-defined]
|
||||
@@ -32,193 +31,17 @@ DONE_RECVING_MSG = b"done_recving_msg"
|
||||
class TestKVCacheTaskTrackerInit(unittest.TestCase):
|
||||
|
||||
def test_init_basic_properties(self):
|
||||
tracker = KVCacheTaskTracker(tp_rank=1,
|
||||
local_engine_id="engine1",
|
||||
target_count=10)
|
||||
self.assertEqual(tracker.tp_rank, 1)
|
||||
self.assertEqual(tracker.local_engine_id, "engine1")
|
||||
self.assertEqual(tracker.target_count, 10)
|
||||
tracker = KVCacheTaskTracker()
|
||||
self.assertIsInstance(tracker.done_task_lock, type(threading.Lock()))
|
||||
self.assertIsInstance(tracker.done_task_counts, defaultdict)
|
||||
self.assertIsInstance(tracker.finished_requests, set)
|
||||
|
||||
def test_socket_path_generation(self):
|
||||
tracker = KVCacheTaskTracker(tp_rank=1,
|
||||
local_engine_id="engine42",
|
||||
target_count=1)
|
||||
self.assertEqual(tracker.socket_path,
|
||||
"ipc:///tmp/vllm_mooncake_connector_engine42.ipc")
|
||||
|
||||
@patch("vllm_ascend.distributed.mooncake_connector.threading.Thread")
|
||||
def test_tp_rank_zero_initialization(self, mock_thread):
|
||||
tracker = KVCacheTaskTracker(tp_rank=0,
|
||||
local_engine_id="test",
|
||||
target_count=1)
|
||||
mock_thread.assert_called_once_with(
|
||||
target=tracker._listen_for_completion_signals,
|
||||
daemon=True,
|
||||
name="KVCacheTaskTrackerListenerThread")
|
||||
mock_thread.return_value.start.assert_called_once()
|
||||
self.assertIsNone(tracker.socket)
|
||||
self.assertTrue(tracker.listener.daemon)
|
||||
|
||||
@patch("vllm_ascend.distributed.mooncake_connector.make_zmq_socket")
|
||||
@patch("vllm_ascend.distributed.mooncake_connector.logger")
|
||||
def test_tp_rank_non_zero_initialization(self, mock_logger,
|
||||
mock_make_zmq_socket):
|
||||
mock_socket = MagicMock()
|
||||
mock_make_zmq_socket.return_value = mock_socket
|
||||
tracker = KVCacheTaskTracker(tp_rank=1,
|
||||
local_engine_id="test",
|
||||
target_count=1)
|
||||
mock_make_zmq_socket.assert_called_once_with(
|
||||
ctx=unittest.mock.ANY,
|
||||
path="ipc:///tmp/vllm_mooncake_connector_test.ipc",
|
||||
socket_type=zmq.PUSH, # type: ignore
|
||||
bind=False)
|
||||
mock_logger.info.assert_called_once_with(
|
||||
"Connecting to transfer socket at %s",
|
||||
"ipc:///tmp/vllm_mooncake_connector_test.ipc")
|
||||
self.assertIsNone(tracker.listener)
|
||||
self.assertEqual(tracker.socket, mock_socket)
|
||||
|
||||
|
||||
class TestKVCacheTaskTrackerListenMethod(unittest.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
self.tp_rank = 0
|
||||
self.local_engine_id = "test_engine_ut"
|
||||
self.target_count = 3
|
||||
self.tracker = KVCacheTaskTracker(self.tp_rank, self.local_engine_id,
|
||||
self.target_count)
|
||||
self.original_listen = self.tracker._listen_for_completion_signals
|
||||
|
||||
def tearDown(self):
|
||||
self.tracker._listen_for_completion_signals = self.original_listen
|
||||
Context.instance().term()
|
||||
time.sleep(0.1)
|
||||
|
||||
def test_normal_message_processing(self):
|
||||
listener_thread = threading.Thread(
|
||||
target=self.tracker._listen_for_completion_signals, daemon=True)
|
||||
listener_thread.start()
|
||||
time.sleep(0.2)
|
||||
test_messages = [("request_001", 1), ("request_001", 2),
|
||||
("request_002", 0), ("request_003", 1)]
|
||||
ctx = Context()
|
||||
sender_socket = ctx.socket(zmq.PUSH) # type: ignore
|
||||
sender_socket.connect(self.tracker.socket_path)
|
||||
for msg in test_messages:
|
||||
sender_socket.send_pyobj(msg)
|
||||
time.sleep(0.05)
|
||||
sender_socket.close()
|
||||
time.sleep(0.2)
|
||||
|
||||
with self.tracker.done_task_lock:
|
||||
self.assertEqual(len(self.tracker.done_task_counts["request_001"]),
|
||||
2)
|
||||
self.assertIn(1, self.tracker.done_task_counts["request_001"])
|
||||
self.assertIn(2, self.tracker.done_task_counts["request_001"])
|
||||
self.assertEqual(len(self.tracker.done_task_counts["request_002"]),
|
||||
1)
|
||||
self.assertIn(0, self.tracker.done_task_counts["request_002"])
|
||||
self.assertEqual(len(self.tracker.done_task_counts["request_003"]),
|
||||
1)
|
||||
self.assertIn(1, self.tracker.done_task_counts["request_003"])
|
||||
|
||||
@patch("vllm_ascend.distributed.mooncake_connector.make_zmq_socket",
|
||||
autospec=True)
|
||||
def test_listen_with_timeout(self, mock_make_socket):
|
||||
mock_socket = MagicMock()
|
||||
|
||||
def mock_recv():
|
||||
start = time.time()
|
||||
while time.time() - start < 0.5:
|
||||
time.sleep(0.01)
|
||||
return ("req1", 0)
|
||||
|
||||
mock_socket.recv_pyobj = mock_recv
|
||||
mock_make_socket.return_value = mock_socket
|
||||
|
||||
test_thread = threading.Thread(
|
||||
target=self.tracker._listen_for_completion_signals, daemon=True)
|
||||
test_thread.start()
|
||||
test_thread.join(timeout=1.0)
|
||||
mock_make_socket.assert_called_once()
|
||||
|
||||
|
||||
class TestKVCacheTaskTrackerTP(unittest.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
self.local_engine_id = "test_engine"
|
||||
self.target_count = 3
|
||||
|
||||
def test_update_done_task_count_tp_rank_0(self):
|
||||
tracker = KVCacheTaskTracker(tp_rank=0,
|
||||
local_engine_id=self.local_engine_id,
|
||||
target_count=self.target_count)
|
||||
test_request_id = "test_req_001"
|
||||
test_tp_rank = 1
|
||||
tracker.update_done_task_count(test_request_id, test_tp_rank)
|
||||
with tracker.done_task_lock:
|
||||
self.assertEqual(len(tracker.done_task_counts[test_request_id]), 1)
|
||||
self.assertIn(test_tp_rank,
|
||||
tracker.done_task_counts[test_request_id])
|
||||
|
||||
@patch("vllm_ascend.distributed.mooncake_connector.make_zmq_socket",
|
||||
autospec=True)
|
||||
def test_update_done_task_count_non_zero_tp(self, mock_make_socket):
|
||||
mock_socket = MagicMock()
|
||||
mock_make_socket.return_value = mock_socket
|
||||
tracker = KVCacheTaskTracker(tp_rank=1,
|
||||
local_engine_id=self.local_engine_id,
|
||||
target_count=self.target_count)
|
||||
test_request_id = "test_req_002"
|
||||
test_tp_rank = 1
|
||||
tracker.update_done_task_count(test_request_id, test_tp_rank)
|
||||
mock_socket.send_pyobj.assert_called_once_with(
|
||||
(test_request_id, test_tp_rank))
|
||||
with tracker.done_task_lock:
|
||||
self.assertNotIn(test_request_id, tracker.done_task_counts)
|
||||
|
||||
@patch("vllm_ascend.distributed.mooncake_connector.logger", autospec=True)
|
||||
@patch("vllm_ascend.distributed.mooncake_connector.make_zmq_socket",
|
||||
autospec=True)
|
||||
def test_update_done_task_count_logging(self, mock_make_socket,
|
||||
mock_logger):
|
||||
mock_socket = MagicMock()
|
||||
mock_make_socket.return_value = mock_socket
|
||||
tracker = KVCacheTaskTracker(tp_rank=2,
|
||||
local_engine_id=self.local_engine_id,
|
||||
target_count=self.target_count)
|
||||
test_request_id = "test_req_003"
|
||||
tracker.update_done_task_count(test_request_id, 2)
|
||||
mock_logger.debug.assert_called_once_with(
|
||||
"Sent done signal for request %s to tp 0", test_request_id)
|
||||
|
||||
@patch("vllm_ascend.distributed.mooncake_connector.make_zmq_socket",
|
||||
autospec=True)
|
||||
def test_update_multiple_calls(self, mock_make_socket):
|
||||
mock_socket = MagicMock()
|
||||
mock_make_socket.return_value = mock_socket
|
||||
tracker = KVCacheTaskTracker(tp_rank=1,
|
||||
local_engine_id=self.local_engine_id,
|
||||
target_count=self.target_count)
|
||||
test_data = [("req1", 1), ("req1", 1), ("req2", 1)]
|
||||
for req_id, rank in test_data:
|
||||
tracker.update_done_task_count(req_id, rank)
|
||||
self.assertEqual(mock_socket.send_pyobj.call_count, 3)
|
||||
mock_socket.send_pyobj.assert_called_with(("req2", 1))
|
||||
self.assertIsInstance(tracker.delayed_free_requests, deque)
|
||||
|
||||
|
||||
class TestGetAndClearFinishedSingleRequests(unittest.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
self.tracker = KVCacheTaskTracker(tp_rank=0,
|
||||
local_engine_id="test",
|
||||
target_count=3)
|
||||
self.tracker = KVCacheTaskTracker()
|
||||
self.tracker.finished_requests = set()
|
||||
self.tracker.done_task_counts = defaultdict(set)
|
||||
self.tracker.done_task_lock = threading.Lock()
|
||||
|
||||
def test_empty_requests(self):
|
||||
@@ -251,14 +74,6 @@ class TestGetAndClearFinishedSingleRequests(unittest.TestCase):
|
||||
self.assertEqual(sum(1 for r in results if r), 1)
|
||||
self.assertEqual(len(self.tracker.finished_requests), 0)
|
||||
|
||||
def test_after_increment(self):
|
||||
self.tracker._increment_task_count("req_123", 0)
|
||||
self.tracker._increment_task_count("req_123", 1)
|
||||
self.tracker._increment_task_count("req_123", 2)
|
||||
result = self.tracker.get_and_clear_finished_requests()
|
||||
self.assertEqual(result, {"req_123"})
|
||||
self.assertEqual(self.tracker.get_and_clear_finished_requests(), set())
|
||||
|
||||
|
||||
class TestKVCacheSendingThreadInit(unittest.TestCase):
|
||||
|
||||
@@ -282,47 +97,6 @@ class TestKVCacheSendingThreadInit(unittest.TestCase):
|
||||
if hasattr(thread, 'is_alive') and thread.is_alive():
|
||||
thread.join(timeout=0.1)
|
||||
|
||||
@patch('vllm_ascend.distributed.mooncake_connector.KVCacheTaskTracker')
|
||||
def test_initialization_basic(self, mock_tracker):
|
||||
thread = KVCacheSendingThread(**self.common_args)
|
||||
self.threads.append(thread)
|
||||
self.assertEqual(thread.tp_rank, 1)
|
||||
self.assertEqual(thread.decode_tp_size, 4)
|
||||
self.assertEqual(thread.local_engine_id, 'engine_1')
|
||||
mock_tracker.assert_called_once()
|
||||
args = mock_tracker.call_args[0]
|
||||
kwargs = mock_tracker.call_args[1]
|
||||
if args:
|
||||
self.assertEqual(args[0], 1)
|
||||
self.assertEqual(args[1], 'engine_1')
|
||||
self.assertEqual(args[2], 4)
|
||||
else:
|
||||
self.assertEqual(kwargs['tp_rank'], 1)
|
||||
self.assertEqual(kwargs['local_engine_id'], 'engine_1')
|
||||
self.assertEqual(kwargs['target_count'], 4)
|
||||
|
||||
@patch('vllm_ascend.distributed.mooncake_connector.KVCacheTaskTracker')
|
||||
def test_task_tracker_initialization(self, mock_tracker):
|
||||
args = self.common_args.copy()
|
||||
args.update({
|
||||
'tp_rank': 2,
|
||||
'decode_tp_size': 8,
|
||||
'local_engine_id': 'engine_2'
|
||||
})
|
||||
thread = KVCacheSendingThread(**args)
|
||||
self.threads.append(thread)
|
||||
mock_tracker.assert_called_once()
|
||||
call_args = mock_tracker.call_args[0]
|
||||
call_kwargs = mock_tracker.call_args[1]
|
||||
if call_args:
|
||||
self.assertEqual(call_args[0], 2)
|
||||
self.assertEqual(call_args[1], 'engine_2')
|
||||
self.assertEqual(call_args[2], 8)
|
||||
else:
|
||||
self.assertEqual(call_kwargs['tp_rank'], 2)
|
||||
self.assertEqual(call_kwargs['local_engine_id'], 'engine_2')
|
||||
self.assertEqual(call_kwargs['target_count'], 8)
|
||||
|
||||
def test_thread_daemon_property(self):
|
||||
thread = KVCacheSendingThread(**self.common_args)
|
||||
self.threads.append(thread)
|
||||
@@ -542,7 +316,7 @@ class TestCoreFunctionality(unittest.TestCase):
|
||||
mock_transfer.assert_called_once_with(self.test_req)
|
||||
mock_send.assert_called_once_with("req1", "localhost", 6666)
|
||||
self.thread.task_tracker.update_done_task_count.assert_called_once_with(
|
||||
"req1", self.thread.tp_rank)
|
||||
"req1")
|
||||
self.mock_queue.task_done.assert_called_once()
|
||||
|
||||
@patch.object(KVCacheRecvingThread, '_get_remote_metadata')
|
||||
@@ -675,9 +449,11 @@ class TestMainThreadLoop(unittest.TestCase):
|
||||
class MockVllmConfig:
|
||||
|
||||
def __init__(self):
|
||||
self.model_config = MagicMock()
|
||||
self.parallel_config = MagicMock()
|
||||
self.cache_config = MagicMock()
|
||||
self.kv_transfer_config = MagicMock()
|
||||
self.model_config.use_mla = True
|
||||
self.parallel_config.tensor_parallel_size = 2
|
||||
self.parallel_config.data_parallel_rank_local = 0
|
||||
self.parallel_config.data_parallel_size_local = 1
|
||||
@@ -714,28 +490,40 @@ class MockRequest:
|
||||
class TestKVCacheTaskTracker(unittest.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
self.tracker = KVCacheTaskTracker(tp_rank=0,
|
||||
local_engine_id="test_engine",
|
||||
target_count=2)
|
||||
self.tracker = KVCacheTaskTracker()
|
||||
|
||||
def test_update_task_count(self):
|
||||
self.assertEqual(len(self.tracker.done_task_counts), 0)
|
||||
def test_update_done_task_count(self):
|
||||
self.assertEqual(len(self.tracker.finished_requests), 0)
|
||||
self.assertEqual(len(self.tracker.delayed_free_requests), 0)
|
||||
|
||||
self.tracker.update_done_task_count("req1", 0)
|
||||
self.tracker.update_done_task_count("req1", 1)
|
||||
current_time = time.time()
|
||||
self.tracker.add_delayed_request("req_1", current_time)
|
||||
result = self.tracker.delayed_free_requests
|
||||
self.assertEqual(len(result), 1)
|
||||
self.assertEqual(result[0], ("req_1", current_time))
|
||||
|
||||
self.assertEqual(len(self.tracker.finished_requests), 1)
|
||||
self.assertTrue("req1" in self.tracker.finished_requests)
|
||||
self.tracker.update_done_task_count("req_1")
|
||||
result_finished = self.tracker.finished_requests
|
||||
result_delayed = self.tracker.delayed_free_requests
|
||||
self.assertEqual(result_finished, {"req_1"})
|
||||
self.assertEqual(len(result_delayed), 0)
|
||||
|
||||
finished = self.tracker.get_and_clear_finished_requests()
|
||||
self.assertEqual(finished, {"req1"})
|
||||
self.assertEqual(len(self.tracker.finished_requests), 0)
|
||||
def test_retrieve_expired_requests(self):
|
||||
current_time = time.time()
|
||||
self.tracker.add_delayed_request("req_1", current_time - 600)
|
||||
self.tracker.add_delayed_request("req_2", current_time)
|
||||
result = self.tracker._retrieve_expired_requests()
|
||||
self.assertEqual(result, {
|
||||
"req_1",
|
||||
})
|
||||
result_delay = self.tracker.delayed_free_requests
|
||||
self.assertEqual(len(result_delay), 1)
|
||||
self.assertEqual(result_delay[0], ("req_2", current_time))
|
||||
|
||||
def test_duplicate_task_update(self):
|
||||
self.tracker.update_done_task_count("req1", 0)
|
||||
self.tracker.update_done_task_count("req1", 0)
|
||||
self.tracker.update_done_task_count("req1", 1)
|
||||
self.tracker.update_done_task_count("req1")
|
||||
self.tracker.update_done_task_count("req1")
|
||||
self.tracker.update_done_task_count("req1")
|
||||
|
||||
finished = self.tracker.get_and_clear_finished_requests()
|
||||
self.assertEqual(finished, {"req1"})
|
||||
@@ -745,6 +533,9 @@ class TestMooncakeConnectorMetadata(unittest.TestCase):
|
||||
|
||||
def test_add_new_req(self):
|
||||
meta = MooncakeConnectorMetadata()
|
||||
self.assertEqual(len(meta.requests), 0)
|
||||
self.assertEqual(len(meta.requests_to_send), 0)
|
||||
|
||||
meta.add_new_req(request_id="req1",
|
||||
local_block_ids=[1, 2, 3],
|
||||
kv_transfer_params={
|
||||
@@ -802,6 +593,10 @@ class TestMooncakeConnectorSchedulerMatchedTokens(unittest.TestCase):
|
||||
self.assertEqual(meta.requests["req1"].remote_block_ids, [1, 2, 3])
|
||||
self.assertEqual(len(self.scheduler._reqs_need_recv), 0)
|
||||
|
||||
def test_get_finished_count(self):
|
||||
count = self.scheduler.get_finished_count()
|
||||
self.assertEqual(count, 2)
|
||||
|
||||
|
||||
class TestHelperFunctions(unittest.TestCase):
|
||||
|
||||
|
||||
@@ -58,74 +58,21 @@ class ReqMeta:
|
||||
|
||||
class KVCacheTaskTracker:
|
||||
|
||||
def __init__(self, tp_rank: int, local_engine_id: str, target_count: int):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.tp_rank = tp_rank
|
||||
self.local_engine_id = local_engine_id
|
||||
self.target_count = target_count
|
||||
|
||||
self.done_task_lock = threading.Lock()
|
||||
self.done_task_counts: defaultdict[str, set[int]] = defaultdict(set)
|
||||
self.finished_requests: set[str] = set()
|
||||
# Only used in prefill node. Tracks requests whose kv blocks freeing is
|
||||
# intentionally delayed. Each entry is a tuple of (request_id,
|
||||
# timestamp). If a request remains in this queue for too long, it will
|
||||
# be force-freed.
|
||||
self.delayed_free_requests: deque[Tuple[str, float]] = deque()
|
||||
|
||||
self.socket_path = \
|
||||
f"ipc:///tmp/vllm_mooncake_connector_{self.local_engine_id}.ipc"
|
||||
if tp_rank == 0:
|
||||
self.listener = threading.Thread(
|
||||
target=self._listen_for_completion_signals,
|
||||
daemon=True,
|
||||
name="KVCacheTaskTrackerListenerThread")
|
||||
self.listener.start()
|
||||
self.socket = None
|
||||
else:
|
||||
self.listener = None # type: ignore
|
||||
self.socket = make_zmq_socket(
|
||||
ctx=zmq.Context(), # type: ignore
|
||||
path=self.socket_path,
|
||||
socket_type=zmq.PUSH, # type: ignore
|
||||
bind=False)
|
||||
logger.info("Connecting to transfer socket at %s",
|
||||
self.socket_path)
|
||||
|
||||
def _listen_for_completion_signals(self):
|
||||
socket = make_zmq_socket(
|
||||
ctx=zmq.Context(), # type: ignore
|
||||
path=self.socket_path,
|
||||
socket_type=zmq.PULL, # type: ignore
|
||||
bind=True)
|
||||
logger.info("Listening for completion signals on %s", self.socket_path)
|
||||
|
||||
while True:
|
||||
try:
|
||||
done_request_id, tp_rank = socket.recv_pyobj()
|
||||
logger.debug("Received completion notification for request: "
|
||||
f"{done_request_id} from tp rank {tp_rank}")
|
||||
self._increment_task_count(done_request_id, tp_rank)
|
||||
except Exception as e:
|
||||
logger.error(f"Error in run_busy_loop: {e}")
|
||||
|
||||
def update_done_task_count(self, request_id: str, tp_rank: int):
|
||||
if self.tp_rank == 0:
|
||||
self._increment_task_count(request_id, tp_rank)
|
||||
else:
|
||||
self.socket.send_pyobj((request_id, tp_rank)) # type: ignore
|
||||
logger.debug("Sent done signal for request %s to tp 0", request_id)
|
||||
|
||||
def _increment_task_count(self, request_id: str, tp_rank: int):
|
||||
def update_done_task_count(self, request_id: str):
|
||||
with self.done_task_lock:
|
||||
if tp_rank in self.done_task_counts[request_id]:
|
||||
logger.warning(
|
||||
f"Received duplicate done signal for request {request_id} "
|
||||
f"from tp rank {tp_rank}. Ignoring.")
|
||||
return
|
||||
|
||||
self.done_task_counts[request_id].add(tp_rank)
|
||||
if len(self.done_task_counts[request_id]) == self.target_count:
|
||||
self.finished_requests.add(request_id)
|
||||
self.done_task_counts.pop(request_id)
|
||||
logger.info("All transfers completed for request: "
|
||||
f"{request_id}. Total ranks: "
|
||||
f"{self.target_count}.")
|
||||
self.finished_requests.add(request_id)
|
||||
self._remove_delayed_requests(request_id)
|
||||
|
||||
def get_and_clear_finished_requests(self) -> set[str]:
|
||||
"""
|
||||
@@ -135,9 +82,37 @@ class KVCacheTaskTracker:
|
||||
"""
|
||||
with self.done_task_lock:
|
||||
finished_requests = self.finished_requests.copy()
|
||||
expired_requests = self._retrieve_expired_requests()
|
||||
finished_requests.update(expired_requests)
|
||||
self.finished_requests.clear()
|
||||
return finished_requests
|
||||
|
||||
def add_delayed_request(self, request_id: str, delay_start_time: float):
|
||||
"""Add a delayed free request."""
|
||||
with self.done_task_lock:
|
||||
self.delayed_free_requests.append((request_id, delay_start_time))
|
||||
|
||||
def _retrieve_expired_requests(self):
|
||||
"""Retrieve all expired delayed requests."""
|
||||
expired_requests: set[str] = set()
|
||||
# Free delayed requests if they exceed the timeout
|
||||
current_time = time.time()
|
||||
while self.delayed_free_requests:
|
||||
request_id, delay_start_time = self.delayed_free_requests[0]
|
||||
if (current_time - delay_start_time
|
||||
> envs_ascend.VLLM_ASCEND_KVCACHE_DELAY_FREE_TIMEOUT):
|
||||
self.delayed_free_requests.popleft()
|
||||
expired_requests.add(request_id)
|
||||
logger.info("Force freed request: %s", request_id)
|
||||
else:
|
||||
break
|
||||
return expired_requests
|
||||
|
||||
def _remove_delayed_requests(self, request_id: str):
|
||||
"""Remove all delayed free requests matching the given request_id."""
|
||||
self.delayed_free_requests = deque(
|
||||
(r, t) for r, t in self.delayed_free_requests if r != request_id)
|
||||
|
||||
|
||||
class KVCacheSendingThread(threading.Thread):
|
||||
|
||||
@@ -154,9 +129,7 @@ class KVCacheSendingThread(threading.Thread):
|
||||
self.metadata = metadata
|
||||
self.ready_event = ready_event
|
||||
|
||||
self.task_tracker = KVCacheTaskTracker(self.tp_rank,
|
||||
self.local_engine_id,
|
||||
self.decode_tp_size)
|
||||
self.task_tracker = KVCacheTaskTracker()
|
||||
|
||||
def get_and_clear_finished_requests(self) -> set[str]:
|
||||
"""
|
||||
@@ -166,6 +139,10 @@ class KVCacheSendingThread(threading.Thread):
|
||||
"""
|
||||
return self.task_tracker.get_and_clear_finished_requests()
|
||||
|
||||
def add_delayed_request(self, request_id: str, delay_start_time: float):
|
||||
return self.task_tracker.add_delayed_request(request_id,
|
||||
delay_start_time)
|
||||
|
||||
def run(self):
|
||||
"""Run the thread to handle KV cache transfer requests."""
|
||||
|
||||
@@ -204,9 +181,8 @@ class KVCacheSendingThread(threading.Thread):
|
||||
elif msg[0] == DONE_RECVING_MSG:
|
||||
logger.debug("Got DONE_RECVING_MSG for request %s",
|
||||
msg[1])
|
||||
request_id, decode_tp_rank = msg[1], msg[2]
|
||||
self.task_tracker.update_done_task_count(
|
||||
request_id, decode_tp_rank)
|
||||
request_id = msg[1]
|
||||
self.task_tracker.update_done_task_count(request_id)
|
||||
# Acknowledge the request completion.
|
||||
while True:
|
||||
try:
|
||||
@@ -259,9 +235,7 @@ class KVCacheRecvingThread(threading.Thread):
|
||||
# TODO(jianzs): make this configurable
|
||||
self.executor = ThreadPoolExecutor(max_workers=32)
|
||||
|
||||
self.task_tracker = KVCacheTaskTracker(self.tp_rank,
|
||||
self.local_engine_id,
|
||||
self.tp_size)
|
||||
self.task_tracker = KVCacheTaskTracker()
|
||||
|
||||
self.encoder = msgspec.msgpack.Encoder()
|
||||
self.decoder = msgspec.msgpack.Decoder(MooncakeAgentMetadata)
|
||||
@@ -323,7 +297,7 @@ class KVCacheRecvingThread(threading.Thread):
|
||||
logger.error("Failed to transfer KV cache for request "
|
||||
f"{request_id}: {e}")
|
||||
finally:
|
||||
self.task_tracker.update_done_task_count(request_id, self.tp_rank)
|
||||
self.task_tracker.update_done_task_count(request_id)
|
||||
# Always send the done signal to the remote host to ensure proper
|
||||
# resource cleanup. Failing to do so may cause a memory leak on the
|
||||
# remote host.
|
||||
@@ -422,8 +396,7 @@ class KVCacheRecvingThread(threading.Thread):
|
||||
sock: Optional[zmq.Socket] = None # type: ignore
|
||||
try:
|
||||
sock = self._get_remote_socket(remote_host, remote_handshake_port)
|
||||
data_bytes = self.encoder.encode(
|
||||
(DONE_RECVING_MSG, request_id, self.tp_rank))
|
||||
data_bytes = self.encoder.encode((DONE_RECVING_MSG, request_id))
|
||||
ensure_zmq_send(sock, data_bytes)
|
||||
resp = ensure_zmq_recv(sock,
|
||||
self.remote_poller,
|
||||
@@ -479,6 +452,7 @@ class MooncakeConnectorMetadata(KVConnectorMetadata):
|
||||
|
||||
def __init__(self):
|
||||
self.requests: dict[str, ReqMeta] = {}
|
||||
self.requests_to_send: dict[str, float] = {}
|
||||
|
||||
def add_new_req(
|
||||
self,
|
||||
@@ -543,6 +517,10 @@ class MooncakeConnector(KVConnectorBase_V1):
|
||||
assert self.connector_scheduler is not None
|
||||
return self.connector_scheduler.request_finished(request, block_ids)
|
||||
|
||||
def get_finished_count(self) -> Optional[int]:
|
||||
assert self.connector_scheduler is not None
|
||||
return self.connector_scheduler.get_finished_count()
|
||||
|
||||
############################################################
|
||||
# Worker Side Methods
|
||||
############################################################
|
||||
@@ -599,6 +577,7 @@ class MooncakeConnectorScheduler:
|
||||
# New requests are added by update_state_after_alloc in
|
||||
# the scheduler. Used to make metadata passed to Worker.
|
||||
self._reqs_need_recv: dict[str, tuple[Request, list[int]]] = {}
|
||||
self._reqs_need_send: dict[str, float] = {}
|
||||
|
||||
def get_num_new_matched_tokens(
|
||||
self, request: "Request",
|
||||
@@ -684,6 +663,8 @@ class MooncakeConnectorScheduler:
|
||||
|
||||
# Clear the list once workers start the transfers
|
||||
self._reqs_need_recv.clear()
|
||||
meta.requests_to_send = self._reqs_need_send
|
||||
self._reqs_need_send = {}
|
||||
|
||||
return meta
|
||||
|
||||
@@ -711,6 +692,8 @@ class MooncakeConnectorScheduler:
|
||||
if delay_free_blocks:
|
||||
logger.info("Delaying free of %d blocks for request %s",
|
||||
len(computed_block_ids), request.request_id)
|
||||
self._reqs_need_send[request.request_id] = time.time()
|
||||
|
||||
return delay_free_blocks, dict(
|
||||
do_remote_prefill=True,
|
||||
do_remote_decode=False,
|
||||
@@ -720,6 +703,27 @@ class MooncakeConnectorScheduler:
|
||||
remote_port=self.side_channel_port,
|
||||
)
|
||||
|
||||
def get_finished_count(self) -> Optional[int]:
|
||||
prefill_parallel_config: dict[
|
||||
str,
|
||||
Any] = self.vllm_config.kv_transfer_config.get_from_extra_config(
|
||||
"prefill", {})
|
||||
|
||||
assert "tp_size" in prefill_parallel_config.keys()
|
||||
self._prefill_tp_size = prefill_parallel_config["tp_size"]
|
||||
decode_parallel_config: dict[
|
||||
str,
|
||||
Any] = self.vllm_config.kv_transfer_config.get_from_extra_config(
|
||||
"decode", {})
|
||||
assert "tp_size" in decode_parallel_config.keys()
|
||||
self._decode_tp_size = decode_parallel_config["tp_size"]
|
||||
|
||||
if self.vllm_config.model_config.use_mla:
|
||||
return self._decode_tp_size
|
||||
else:
|
||||
# TODO support mha and gqa
|
||||
return None
|
||||
|
||||
|
||||
class MooncakeConnectorWorker:
|
||||
"""Implementation of Worker side methods"""
|
||||
@@ -737,6 +741,7 @@ class MooncakeConnectorWorker:
|
||||
self.engine = TransferEngine()
|
||||
|
||||
# Metadata.
|
||||
self.vllm_config = vllm_config
|
||||
self.engine_id = engine_id
|
||||
self.tp_rank = get_tensor_model_parallel_rank()
|
||||
self.tp_size = vllm_config.parallel_config.tensor_parallel_size
|
||||
@@ -946,6 +951,12 @@ class MooncakeConnectorWorker:
|
||||
remote_handshake_port=remote_handshake_port,
|
||||
)
|
||||
|
||||
if self.kv_send_thread is not None:
|
||||
for req_id, delay_start_time in metadata.requests_to_send.items():
|
||||
if self.tp_rank in self._get_remote_tp_ranks_for_req(req_id):
|
||||
self.kv_send_thread.add_delayed_request(
|
||||
req_id, delay_start_time)
|
||||
|
||||
def _get_remote_tp_rank(self, req_id: str) -> int:
|
||||
return self._get_remote_tp_ranks_for_req(req_id)[self.tp_rank]
|
||||
|
||||
|
||||
@@ -139,6 +139,11 @@ env_variables: Dict[str, Callable[[], Any]] = {
|
||||
# caused by the initialization of the Mooncake connector.
|
||||
"PHYSICAL_DEVICES":
|
||||
lambda: os.getenv("PHYSICAL_DEVICES", None),
|
||||
# Timeout (in seconds) for delayed KVCache block release. In the prefill
|
||||
# node, if a request is marked for delayed KV block release and the blocks
|
||||
# are not freed within this timeout, they will be forcibly released.
|
||||
"VLLM_ASCEND_KVCACHE_DELAY_FREE_TIMEOUT":
|
||||
lambda: int(os.getenv("VLLM_ASCEND_KVCACHE_DELAY_FREE_TIMEOUT", 250)),
|
||||
}
|
||||
|
||||
# end-env-vars-definition
|
||||
|
||||
Reference in New Issue
Block a user