[P/D]mooncake_connector adapted to 0.10.1 (#2664)

### What this PR does / why we need it?
In vllm version 0.10.1, a new KVOutputAggregator was added to the
executor, moving aggregation to the
executor(https://github.com/vllm-project/vllm/pull/19555). This caused
mooncake_connector to break. This change aims to fix this bug and also
adds a policy to forcibly release the KV cache when the prefill node
times out.

This PR is currently linked to a PR in vllm
(https://github.com/vllm-project/vllm/pull/23917). The vllm PR aims to
modify the finish and send count confirmation in heterogeneous TP
situations.

The reason for deleting many UTs is that a lot of communication codes
have been deleted, so the UT as a whole will appear more concise.

- vLLM version: v0.10.1.1
- vLLM main:
fa4311d85f

---------

Signed-off-by: baxingpiaochong <771405853@qq.com>
This commit is contained in:
baxingpiaochong
2025-09-04 08:22:10 +08:00
committed by GitHub
parent 07d44ade19
commit df88a2ecc8
3 changed files with 130 additions and 319 deletions

View File

@@ -58,74 +58,21 @@ class ReqMeta:
class KVCacheTaskTracker:
def __init__(self, tp_rank: int, local_engine_id: str, target_count: int):
def __init__(self):
super().__init__()
self.tp_rank = tp_rank
self.local_engine_id = local_engine_id
self.target_count = target_count
self.done_task_lock = threading.Lock()
self.done_task_counts: defaultdict[str, set[int]] = defaultdict(set)
self.finished_requests: set[str] = set()
# Only used in prefill node. Tracks requests whose kv blocks freeing is
# intentionally delayed. Each entry is a tuple of (request_id,
# timestamp). If a request remains in this queue for too long, it will
# be force-freed.
self.delayed_free_requests: deque[Tuple[str, float]] = deque()
self.socket_path = \
f"ipc:///tmp/vllm_mooncake_connector_{self.local_engine_id}.ipc"
if tp_rank == 0:
self.listener = threading.Thread(
target=self._listen_for_completion_signals,
daemon=True,
name="KVCacheTaskTrackerListenerThread")
self.listener.start()
self.socket = None
else:
self.listener = None # type: ignore
self.socket = make_zmq_socket(
ctx=zmq.Context(), # type: ignore
path=self.socket_path,
socket_type=zmq.PUSH, # type: ignore
bind=False)
logger.info("Connecting to transfer socket at %s",
self.socket_path)
def _listen_for_completion_signals(self):
socket = make_zmq_socket(
ctx=zmq.Context(), # type: ignore
path=self.socket_path,
socket_type=zmq.PULL, # type: ignore
bind=True)
logger.info("Listening for completion signals on %s", self.socket_path)
while True:
try:
done_request_id, tp_rank = socket.recv_pyobj()
logger.debug("Received completion notification for request: "
f"{done_request_id} from tp rank {tp_rank}")
self._increment_task_count(done_request_id, tp_rank)
except Exception as e:
logger.error(f"Error in run_busy_loop: {e}")
def update_done_task_count(self, request_id: str, tp_rank: int):
if self.tp_rank == 0:
self._increment_task_count(request_id, tp_rank)
else:
self.socket.send_pyobj((request_id, tp_rank)) # type: ignore
logger.debug("Sent done signal for request %s to tp 0", request_id)
def _increment_task_count(self, request_id: str, tp_rank: int):
def update_done_task_count(self, request_id: str):
with self.done_task_lock:
if tp_rank in self.done_task_counts[request_id]:
logger.warning(
f"Received duplicate done signal for request {request_id} "
f"from tp rank {tp_rank}. Ignoring.")
return
self.done_task_counts[request_id].add(tp_rank)
if len(self.done_task_counts[request_id]) == self.target_count:
self.finished_requests.add(request_id)
self.done_task_counts.pop(request_id)
logger.info("All transfers completed for request: "
f"{request_id}. Total ranks: "
f"{self.target_count}.")
self.finished_requests.add(request_id)
self._remove_delayed_requests(request_id)
def get_and_clear_finished_requests(self) -> set[str]:
"""
@@ -135,9 +82,37 @@ class KVCacheTaskTracker:
"""
with self.done_task_lock:
finished_requests = self.finished_requests.copy()
expired_requests = self._retrieve_expired_requests()
finished_requests.update(expired_requests)
self.finished_requests.clear()
return finished_requests
def add_delayed_request(self, request_id: str, delay_start_time: float):
"""Add a delayed free request."""
with self.done_task_lock:
self.delayed_free_requests.append((request_id, delay_start_time))
def _retrieve_expired_requests(self):
"""Retrieve all expired delayed requests."""
expired_requests: set[str] = set()
# Free delayed requests if they exceed the timeout
current_time = time.time()
while self.delayed_free_requests:
request_id, delay_start_time = self.delayed_free_requests[0]
if (current_time - delay_start_time
> envs_ascend.VLLM_ASCEND_KVCACHE_DELAY_FREE_TIMEOUT):
self.delayed_free_requests.popleft()
expired_requests.add(request_id)
logger.info("Force freed request: %s", request_id)
else:
break
return expired_requests
def _remove_delayed_requests(self, request_id: str):
"""Remove all delayed free requests matching the given request_id."""
self.delayed_free_requests = deque(
(r, t) for r, t in self.delayed_free_requests if r != request_id)
class KVCacheSendingThread(threading.Thread):
@@ -154,9 +129,7 @@ class KVCacheSendingThread(threading.Thread):
self.metadata = metadata
self.ready_event = ready_event
self.task_tracker = KVCacheTaskTracker(self.tp_rank,
self.local_engine_id,
self.decode_tp_size)
self.task_tracker = KVCacheTaskTracker()
def get_and_clear_finished_requests(self) -> set[str]:
"""
@@ -166,6 +139,10 @@ class KVCacheSendingThread(threading.Thread):
"""
return self.task_tracker.get_and_clear_finished_requests()
def add_delayed_request(self, request_id: str, delay_start_time: float):
return self.task_tracker.add_delayed_request(request_id,
delay_start_time)
def run(self):
"""Run the thread to handle KV cache transfer requests."""
@@ -204,9 +181,8 @@ class KVCacheSendingThread(threading.Thread):
elif msg[0] == DONE_RECVING_MSG:
logger.debug("Got DONE_RECVING_MSG for request %s",
msg[1])
request_id, decode_tp_rank = msg[1], msg[2]
self.task_tracker.update_done_task_count(
request_id, decode_tp_rank)
request_id = msg[1]
self.task_tracker.update_done_task_count(request_id)
# Acknowledge the request completion.
while True:
try:
@@ -259,9 +235,7 @@ class KVCacheRecvingThread(threading.Thread):
# TODO(jianzs): make this configurable
self.executor = ThreadPoolExecutor(max_workers=32)
self.task_tracker = KVCacheTaskTracker(self.tp_rank,
self.local_engine_id,
self.tp_size)
self.task_tracker = KVCacheTaskTracker()
self.encoder = msgspec.msgpack.Encoder()
self.decoder = msgspec.msgpack.Decoder(MooncakeAgentMetadata)
@@ -323,7 +297,7 @@ class KVCacheRecvingThread(threading.Thread):
logger.error("Failed to transfer KV cache for request "
f"{request_id}: {e}")
finally:
self.task_tracker.update_done_task_count(request_id, self.tp_rank)
self.task_tracker.update_done_task_count(request_id)
# Always send the done signal to the remote host to ensure proper
# resource cleanup. Failing to do so may cause a memory leak on the
# remote host.
@@ -422,8 +396,7 @@ class KVCacheRecvingThread(threading.Thread):
sock: Optional[zmq.Socket] = None # type: ignore
try:
sock = self._get_remote_socket(remote_host, remote_handshake_port)
data_bytes = self.encoder.encode(
(DONE_RECVING_MSG, request_id, self.tp_rank))
data_bytes = self.encoder.encode((DONE_RECVING_MSG, request_id))
ensure_zmq_send(sock, data_bytes)
resp = ensure_zmq_recv(sock,
self.remote_poller,
@@ -479,6 +452,7 @@ class MooncakeConnectorMetadata(KVConnectorMetadata):
def __init__(self):
self.requests: dict[str, ReqMeta] = {}
self.requests_to_send: dict[str, float] = {}
def add_new_req(
self,
@@ -543,6 +517,10 @@ class MooncakeConnector(KVConnectorBase_V1):
assert self.connector_scheduler is not None
return self.connector_scheduler.request_finished(request, block_ids)
def get_finished_count(self) -> Optional[int]:
assert self.connector_scheduler is not None
return self.connector_scheduler.get_finished_count()
############################################################
# Worker Side Methods
############################################################
@@ -599,6 +577,7 @@ class MooncakeConnectorScheduler:
# New requests are added by update_state_after_alloc in
# the scheduler. Used to make metadata passed to Worker.
self._reqs_need_recv: dict[str, tuple[Request, list[int]]] = {}
self._reqs_need_send: dict[str, float] = {}
def get_num_new_matched_tokens(
self, request: "Request",
@@ -684,6 +663,8 @@ class MooncakeConnectorScheduler:
# Clear the list once workers start the transfers
self._reqs_need_recv.clear()
meta.requests_to_send = self._reqs_need_send
self._reqs_need_send = {}
return meta
@@ -711,6 +692,8 @@ class MooncakeConnectorScheduler:
if delay_free_blocks:
logger.info("Delaying free of %d blocks for request %s",
len(computed_block_ids), request.request_id)
self._reqs_need_send[request.request_id] = time.time()
return delay_free_blocks, dict(
do_remote_prefill=True,
do_remote_decode=False,
@@ -720,6 +703,27 @@ class MooncakeConnectorScheduler:
remote_port=self.side_channel_port,
)
def get_finished_count(self) -> Optional[int]:
prefill_parallel_config: dict[
str,
Any] = self.vllm_config.kv_transfer_config.get_from_extra_config(
"prefill", {})
assert "tp_size" in prefill_parallel_config.keys()
self._prefill_tp_size = prefill_parallel_config["tp_size"]
decode_parallel_config: dict[
str,
Any] = self.vllm_config.kv_transfer_config.get_from_extra_config(
"decode", {})
assert "tp_size" in decode_parallel_config.keys()
self._decode_tp_size = decode_parallel_config["tp_size"]
if self.vllm_config.model_config.use_mla:
return self._decode_tp_size
else:
# TODO support mha and gqa
return None
class MooncakeConnectorWorker:
"""Implementation of Worker side methods"""
@@ -737,6 +741,7 @@ class MooncakeConnectorWorker:
self.engine = TransferEngine()
# Metadata.
self.vllm_config = vllm_config
self.engine_id = engine_id
self.tp_rank = get_tensor_model_parallel_rank()
self.tp_size = vllm_config.parallel_config.tensor_parallel_size
@@ -946,6 +951,12 @@ class MooncakeConnectorWorker:
remote_handshake_port=remote_handshake_port,
)
if self.kv_send_thread is not None:
for req_id, delay_start_time in metadata.requests_to_send.items():
if self.tp_rank in self._get_remote_tp_ranks_for_req(req_id):
self.kv_send_thread.add_delayed_request(
req_id, delay_start_time)
def _get_remote_tp_rank(self, req_id: str) -> int:
return self._get_remote_tp_ranks_for_req(req_id)[self.tp_rank]