[Bugfix][PD] Auto-clear producer KV cache if no pull notification (#2174)
### What this PR does / why we need it?
This PR addresses a critical issue where failures on Node D (the decode node) cause
Node P (the prefill node) to hang because it can never release the corresponding KV cache.
**Trigger Scenarios:**
1. Node D fails mid-inference (e.g., a network disconnection)
2. Node D rejects requests at some stage (e.g., at the API server)
3. A load-test script is terminated, causing Node P or Node D to abort queued
requests
**Root Cause Analysis:**
1. Currently, Node D sends a "KV cache pull complete, release approved"
message to Node P
2. This message is transmitted via the worker connector; if the P/D
connection breaks or requests are rejected upstream, Node D never sends
the message
3. Without that message, Node P never releases the corresponding KV cache blocks
**Solution:**
Following the vLLM community's approach (the NIXL connector's timeout
mechanism), this PR implements the following (a minimal sketch follows the list):
- A timeout on the prefill side that force-frees KV cache blocks whose pull notification never arrives, with an explicit warning when this happens
- Updated README documentation
- Reference: vLLM's optimization PR
[#20139](https://github.com/vllm-project/vllm/pull/20139)
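
Below is a minimal, self-contained sketch of the timeout bookkeeping described above. It is illustrative only: the class name `ProducerKVTracker`, its method names, and the 120-second default are hypothetical, while the `reqs_to_send` / `finished_reqs` idea and the force-free warning mirror the connector changes in the diff below (which key off `VLLM_NIXL_ABORT_REQUEST_TIMEOUT`).

```python
import threading
import time

# Hypothetical, simplified sketch of the mechanism -- NOT the actual connector
# code. The prefill side records a deadline per request and force-frees any
# request whose "KV cache pulled" notification never arrives from the decode node.
class ProducerKVTracker:

    def __init__(self, abort_timeout_s: float = 120.0):  # default is an assumption
        self._lock = threading.Lock()
        # request_id -> absolute deadline (perf_counter seconds) by which the
        # decode node must confirm it has pulled the KV cache.
        self._reqs_to_send: dict[str, float] = {}
        self._finished: set[str] = set()
        self._abort_timeout_s = abort_timeout_s

    def delay_free(self, request_id: str) -> None:
        """Prefill done; keep the KV blocks until D pulls them or the deadline passes."""
        with self._lock:
            self._reqs_to_send[request_id] = time.perf_counter() + self._abort_timeout_s

    def on_pull_done(self, request_id: str) -> None:
        """Decode node reported that it finished pulling the KV cache."""
        with self._lock:
            if self._reqs_to_send.pop(request_id, None) is not None:
                self._finished.add(request_id)

    def collect_finished(self) -> set[str]:
        """Return requests whose blocks may be freed, force-expiring overdue ones."""
        now = time.perf_counter()
        with self._lock:
            for request_id, deadline in list(self._reqs_to_send.items()):
                if now >= deadline:
                    # Mirrors the warning added in this PR: the pull notification
                    # never arrived, so the blocks are released anyway.
                    print(f"WARNING: no pull notification for {request_id}; force freeing")
                    self._finished.add(request_id)
                    del self._reqs_to_send[request_id]
            done, self._finished = set(self._finished), set()
        return done


if __name__ == "__main__":
    tracker = ProducerKVTracker(abort_timeout_s=0.1)
    tracker.delay_free("req-1")    # decode node never responds for this one
    tracker.delay_free("req-2")
    tracker.on_pull_done("req-2")  # normal path: decode confirms the pull
    time.sleep(0.2)
    print(tracker.collect_finished())  # both req-1 and req-2 after the timeout sweep
```

The actual change performs the same sweep inside the worker's `get_finished()` under its thread lock, as shown in the diff below.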
### Does this PR introduce _any_ user-facing change?
None
### How was this patch tested?
None
- vLLM version: v0.10.2
- vLLM main: 9607d5eb44
---------
Signed-off-by: underfituu <hzhucong@163.com>
```diff
@@ -1,4 +1,5 @@
import contextlib
import copy
import json
import math
import os
@@ -17,6 +18,7 @@ import torch
import zmq
from llm_datadist import (BlocksCacheKey, CacheDesc, LLMConfig, LLMDataDist,
                          LLMException, LLMRole)
from vllm import envs
from vllm.config import KVTransferConfig, VllmConfig
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
    KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole)
@@ -184,6 +186,7 @@ class LLMDataDistCMgrConnectorScheduler():
        self.port = dp_rank_local * tp_size + envs_ascend.VLLM_ASCEND_LLMDD_RPC_PORT if dp_rank_local is not None else tp_size + envs_ascend.VLLM_ASCEND_LLMDD_RPC_PORT

        self._reqs_need_recv: dict[str, tuple[Request, list[int]]] = {}
        self._reqs_need_send: dict[str, float] = {}

    def get_num_new_matched_tokens(
            self, request: "Request",
@@ -248,7 +251,12 @@ class LLMDataDistCMgrConnectorScheduler():
            meta.add_new_req(request_id=req_id,
                             local_block_ids=block_ids,
                             kv_transfer_params=req.kv_transfer_params)

        meta.reqs_to_send = copy.deepcopy(self._reqs_need_send)

        # Clear the list once workers start the transfers
        self._reqs_need_recv.clear()
        self._reqs_need_send.clear()

        return meta
@@ -275,6 +283,9 @@ class LLMDataDistCMgrConnectorScheduler():
        if delay_free_blocks:
            logger.info("Delaying free of %d blocks for request %s",
                        len(computed_block_ids), request.request_id)
            # Prefill request on remote. It will be read from D upon completion
            self._reqs_need_send[request.request_id] = time.perf_counter(
            ) + envs.VLLM_NIXL_ABORT_REQUEST_TIMEOUT
        return delay_free_blocks, dict(
            do_remote_prefill=True,
            do_remote_decode=False,
@@ -341,6 +352,7 @@ class LLMDataDistCMgrConnectorWorker():
        os.environ["HCCL_DETERMINISTIC"] = "true"
        self.done_receiving_counts: defaultdict[str,
                                                set[int]] = defaultdict(set)
        self.reqs_to_send: dict[str, float] = {}

    def listen_for_agent_metadata_req(self, event: threading.Event):
        assert self.local_agent_metadata is not None
@@ -379,7 +391,9 @@ class LLMDataDistCMgrConnectorWorker():
                logger.debug(
                    f"LLMDataDistCMgrConnectorWorker: Receiving request {finished_req_id} finished"
                )
                self.finished_reqs.add(finished_req_id)
                if finished_req_id in self.reqs_to_send:
                    self.finished_reqs.add(finished_req_id)
                    del self.reqs_to_send[finished_req_id]
                sock.send_multipart(
                    (identity, b"", b"receiving decode finished"))
            else:
@@ -582,6 +596,7 @@ class LLMDataDistCMgrConnectorWorker():

        for future in futures:
            future.add_done_callback(handle_exception)
        self.reqs_to_send.update(metadata.reqs_to_send)

    def add_remote_agent(self, metadata: LLMDataDistCMgrAgentMetadata) -> int:
        assert self.local_agent_metadata is not None
@@ -839,8 +854,19 @@ class LLMDataDistCMgrConnectorWorker():
            self, finished_req_ids: set[str]
    ) -> tuple[Optional[set[str]], Optional[set[str]]]:
        """Get the finished recving and sending requuests."""
        import copy
        now = time.perf_counter()
        with self.thread_lock:
            while self.reqs_to_send:
                req_id, expires = next(iter(self.reqs_to_send.items()))
                if now < expires:
                    break
                logger.warning(
                    "Some requests in prefill node fail to receive KV Cache transfer done signal. "
                    "If a greater mean TTFT is acceptable, you can 'export VLLM_NIXL_ABORT_REQUEST_TIMEOUT=600' (10 minutes) to relax the timeout condition. "
                )
                if req_id in self.reqs_to_send:
                    self.finished_reqs.add(req_id)
                    del self.reqs_to_send[req_id]
            req_ids_to_ret = copy.deepcopy(self.finished_reqs)
            self.finished_reqs.clear()
            if self.llm_datadist_role == LLMRole.PROMPT:
@@ -871,4 +897,4 @@ def zmq_ctx(socket_type: Any,
        yield socket
    finally:
        if ctx is not None:
            ctx.destroy(linger=0)
        ctx.destroy(linger=0)
@@ -19,6 +19,7 @@ import numpy.typing as npt
import torch
import zmq
from mooncake.engine import TransferEngine  # type: ignore
from vllm import envs
from vllm.config import VllmConfig
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
    KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole)
@@ -100,7 +101,7 @@ class KVCacheTaskTracker:
            while self.delayed_free_requests:
                request_id, delay_start_time = self.delayed_free_requests[0]
                if (current_time - delay_start_time
                        > envs_ascend.VLLM_ASCEND_KVCACHE_DELAY_FREE_TIMEOUT):
                        > envs.VLLM_NIXL_ABORT_REQUEST_TIMEOUT):
                    self.delayed_free_requests.popleft()
                    expired_requests.add(request_id)
                    logger.info("Force freed request: %s", request_id)

@@ -159,11 +159,6 @@ env_variables: Dict[str, Callable[[], Any]] = {
    # caused by the initialization of the Mooncake connector.
    "PHYSICAL_DEVICES":
    lambda: os.getenv("PHYSICAL_DEVICES", None),
    # Timeout (in seconds) for delayed KVCache block release. In the prefill
    # node, if a request is marked for delayed KV block release and the blocks
    # are not freed within this timeout, they will be forcibly released.
    "VLLM_ASCEND_KVCACHE_DELAY_FREE_TIMEOUT":
    lambda: int(os.getenv("VLLM_ASCEND_KVCACHE_DELAY_FREE_TIMEOUT", 250)),
}

# end-env-vars-definition
@@ -177,4 +172,4 @@ def __getattr__(name: str):


def __dir__():
    return list(env_variables.keys())
    return list(env_variables.keys())
```