[Bugfix][PD] Auto-clear producer KV cache if no pull notification (#2174)
### What this PR does / why we need it?
This PR addresses a critical issue where Node D (Decode) failures cause
Node P (Prefill) to hang due to inability to release KV cache.
**Trigger Scenarios:**
1. Node D fails mid-inference (e.g., network disconnection)
2. Node D rejects requests at a certain stage (e.g., via API server)
3. Load-test script termination causes Node P or D to abort queued
requests
**Root Cause Analysis:**
1. Currently, Node D sends a "KV cache pull complete, release approved"
message to Node P
2. This message is transmitted via the worker connector. If PD
connection breaks or requests are rejected upstream, Node D cannot send
the message
3. Node P will never release KV cache without receiving this message
**Solution:**
Following the vLLM community's approach (the NIXL connector timeout
mechanism), this PR implements:
- A timeout mechanism with comprehensive warnings
- Updated README documentation
- Reference: vLLM's optimization PR
[#20139](https://github.com/vllm-project/vllm/pull/20139)
### Does this PR introduce _any_ user-facing change?
None
### How was this patch tested?
None
- vLLM version: v0.10.2
- vLLM main:
9607d5eb44
---------
Signed-off-by: underfituu <hzhucong@163.com>
This commit is contained in:
@@ -1,4 +1,5 @@
|
|||||||
import contextlib
|
import contextlib
|
||||||
|
import copy
|
||||||
import json
|
import json
|
||||||
import math
|
import math
|
||||||
import os
|
import os
|
||||||
@@ -17,6 +18,7 @@ import torch
|
|||||||
import zmq
|
import zmq
|
||||||
from llm_datadist import (BlocksCacheKey, CacheDesc, LLMConfig, LLMDataDist,
|
from llm_datadist import (BlocksCacheKey, CacheDesc, LLMConfig, LLMDataDist,
|
||||||
LLMException, LLMRole)
|
LLMException, LLMRole)
|
||||||
|
from vllm import envs
|
||||||
from vllm.config import KVTransferConfig, VllmConfig
|
from vllm.config import KVTransferConfig, VllmConfig
|
||||||
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
|
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
|
||||||
KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole)
|
KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole)
|
||||||
@@ -184,6 +186,7 @@ class LLMDataDistCMgrConnectorScheduler():
|
|||||||
self.port = dp_rank_local * tp_size + envs_ascend.VLLM_ASCEND_LLMDD_RPC_PORT if dp_rank_local is not None else tp_size + envs_ascend.VLLM_ASCEND_LLMDD_RPC_PORT
|
self.port = dp_rank_local * tp_size + envs_ascend.VLLM_ASCEND_LLMDD_RPC_PORT if dp_rank_local is not None else tp_size + envs_ascend.VLLM_ASCEND_LLMDD_RPC_PORT
|
||||||
|
|
||||||
self._reqs_need_recv: dict[str, tuple[Request, list[int]]] = {}
|
self._reqs_need_recv: dict[str, tuple[Request, list[int]]] = {}
|
||||||
|
self._reqs_need_send: dict[str, float] = {}
|
||||||
|
|
||||||
def get_num_new_matched_tokens(
|
def get_num_new_matched_tokens(
|
||||||
self, request: "Request",
|
self, request: "Request",
|
||||||
@@ -248,7 +251,12 @@ class LLMDataDistCMgrConnectorScheduler():
|
|||||||
meta.add_new_req(request_id=req_id,
|
meta.add_new_req(request_id=req_id,
|
||||||
local_block_ids=block_ids,
|
local_block_ids=block_ids,
|
||||||
kv_transfer_params=req.kv_transfer_params)
|
kv_transfer_params=req.kv_transfer_params)
|
||||||
|
|
||||||
|
meta.reqs_to_send = copy.deepcopy(self._reqs_need_send)
|
||||||
|
|
||||||
|
# Clear the list once workers start the transfers
|
||||||
self._reqs_need_recv.clear()
|
self._reqs_need_recv.clear()
|
||||||
|
self._reqs_need_send.clear()
|
||||||
|
|
||||||
return meta
|
return meta
|
||||||
|
|
||||||
@@ -275,6 +283,9 @@ class LLMDataDistCMgrConnectorScheduler():
|
|||||||
if delay_free_blocks:
|
if delay_free_blocks:
|
||||||
logger.info("Delaying free of %d blocks for request %s",
|
logger.info("Delaying free of %d blocks for request %s",
|
||||||
len(computed_block_ids), request.request_id)
|
len(computed_block_ids), request.request_id)
|
||||||
|
# Prefill request on remote. It will be read from D upon completion
|
||||||
|
self._reqs_need_send[request.request_id] = time.perf_counter(
|
||||||
|
) + envs.VLLM_NIXL_ABORT_REQUEST_TIMEOUT
|
||||||
return delay_free_blocks, dict(
|
return delay_free_blocks, dict(
|
||||||
do_remote_prefill=True,
|
do_remote_prefill=True,
|
||||||
do_remote_decode=False,
|
do_remote_decode=False,
|
||||||
@@ -341,6 +352,7 @@ class LLMDataDistCMgrConnectorWorker():
|
|||||||
os.environ["HCCL_DETERMINISTIC"] = "true"
|
os.environ["HCCL_DETERMINISTIC"] = "true"
|
||||||
self.done_receiving_counts: defaultdict[str,
|
self.done_receiving_counts: defaultdict[str,
|
||||||
set[int]] = defaultdict(set)
|
set[int]] = defaultdict(set)
|
||||||
|
self.reqs_to_send: dict[str, float] = {}
|
||||||
|
|
||||||
def listen_for_agent_metadata_req(self, event: threading.Event):
|
def listen_for_agent_metadata_req(self, event: threading.Event):
|
||||||
assert self.local_agent_metadata is not None
|
assert self.local_agent_metadata is not None
|
||||||
@@ -379,7 +391,9 @@ class LLMDataDistCMgrConnectorWorker():
|
|||||||
logger.debug(
|
logger.debug(
|
||||||
f"LLMDataDistCMgrConnectorWorker: Receiving request {finished_req_id} finished"
|
f"LLMDataDistCMgrConnectorWorker: Receiving request {finished_req_id} finished"
|
||||||
)
|
)
|
||||||
self.finished_reqs.add(finished_req_id)
|
if finished_req_id in self.reqs_to_send:
|
||||||
|
self.finished_reqs.add(finished_req_id)
|
||||||
|
del self.reqs_to_send[finished_req_id]
|
||||||
sock.send_multipart(
|
sock.send_multipart(
|
||||||
(identity, b"", b"receiving decode finished"))
|
(identity, b"", b"receiving decode finished"))
|
||||||
else:
|
else:
|
||||||
@@ -582,6 +596,7 @@ class LLMDataDistCMgrConnectorWorker():
|
|||||||
|
|
||||||
for future in futures:
|
for future in futures:
|
||||||
future.add_done_callback(handle_exception)
|
future.add_done_callback(handle_exception)
|
||||||
|
self.reqs_to_send.update(metadata.reqs_to_send)
|
||||||
|
|
||||||
def add_remote_agent(self, metadata: LLMDataDistCMgrAgentMetadata) -> int:
|
def add_remote_agent(self, metadata: LLMDataDistCMgrAgentMetadata) -> int:
|
||||||
assert self.local_agent_metadata is not None
|
assert self.local_agent_metadata is not None
|
||||||
@@ -839,8 +854,19 @@ class LLMDataDistCMgrConnectorWorker():
|
|||||||
self, finished_req_ids: set[str]
|
self, finished_req_ids: set[str]
|
||||||
) -> tuple[Optional[set[str]], Optional[set[str]]]:
|
) -> tuple[Optional[set[str]], Optional[set[str]]]:
|
||||||
"""Get the finished recving and sending requuests."""
|
"""Get the finished recving and sending requuests."""
|
||||||
import copy
|
now = time.perf_counter()
|
||||||
with self.thread_lock:
|
with self.thread_lock:
|
||||||
|
while self.reqs_to_send:
|
||||||
|
req_id, expires = next(iter(self.reqs_to_send.items()))
|
||||||
|
if now < expires:
|
||||||
|
break
|
||||||
|
logger.warning(
|
||||||
|
"Some requests in prefill node fail to receive KV Cache transfer done signal. "
|
||||||
|
"If a greater mean TTFT is acceptable, you can 'export VLLM_NIXL_ABORT_REQUEST_TIMEOUT=600' (10 minutes) to relax the timeout condition. "
|
||||||
|
)
|
||||||
|
if req_id in self.reqs_to_send:
|
||||||
|
self.finished_reqs.add(req_id)
|
||||||
|
del self.reqs_to_send[req_id]
|
||||||
req_ids_to_ret = copy.deepcopy(self.finished_reqs)
|
req_ids_to_ret = copy.deepcopy(self.finished_reqs)
|
||||||
self.finished_reqs.clear()
|
self.finished_reqs.clear()
|
||||||
if self.llm_datadist_role == LLMRole.PROMPT:
|
if self.llm_datadist_role == LLMRole.PROMPT:
|
||||||
@@ -871,4 +897,4 @@ def zmq_ctx(socket_type: Any,
|
|||||||
yield socket
|
yield socket
|
||||||
finally:
|
finally:
|
||||||
if ctx is not None:
|
if ctx is not None:
|
||||||
ctx.destroy(linger=0)
|
ctx.destroy(linger=0)
|
||||||
@@ -19,6 +19,7 @@ import numpy.typing as npt
|
|||||||
import torch
|
import torch
|
||||||
import zmq
|
import zmq
|
||||||
from mooncake.engine import TransferEngine # type: ignore
|
from mooncake.engine import TransferEngine # type: ignore
|
||||||
|
from vllm import envs
|
||||||
from vllm.config import VllmConfig
|
from vllm.config import VllmConfig
|
||||||
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
|
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
|
||||||
KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole)
|
KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole)
|
||||||
@@ -100,7 +101,7 @@ class KVCacheTaskTracker:
|
|||||||
while self.delayed_free_requests:
|
while self.delayed_free_requests:
|
||||||
request_id, delay_start_time = self.delayed_free_requests[0]
|
request_id, delay_start_time = self.delayed_free_requests[0]
|
||||||
if (current_time - delay_start_time
|
if (current_time - delay_start_time
|
||||||
> envs_ascend.VLLM_ASCEND_KVCACHE_DELAY_FREE_TIMEOUT):
|
> envs.VLLM_NIXL_ABORT_REQUEST_TIMEOUT):
|
||||||
self.delayed_free_requests.popleft()
|
self.delayed_free_requests.popleft()
|
||||||
expired_requests.add(request_id)
|
expired_requests.add(request_id)
|
||||||
logger.info("Force freed request: %s", request_id)
|
logger.info("Force freed request: %s", request_id)
|
||||||
|
|||||||
@@ -159,11 +159,6 @@ env_variables: Dict[str, Callable[[], Any]] = {
|
|||||||
# caused by the initialization of the Mooncake connector.
|
# caused by the initialization of the Mooncake connector.
|
||||||
"PHYSICAL_DEVICES":
|
"PHYSICAL_DEVICES":
|
||||||
lambda: os.getenv("PHYSICAL_DEVICES", None),
|
lambda: os.getenv("PHYSICAL_DEVICES", None),
|
||||||
# Timeout (in seconds) for delayed KVCache block release. In the prefill
|
|
||||||
# node, if a request is marked for delayed KV block release and the blocks
|
|
||||||
# are not freed within this timeout, they will be forcibly released.
|
|
||||||
"VLLM_ASCEND_KVCACHE_DELAY_FREE_TIMEOUT":
|
|
||||||
lambda: int(os.getenv("VLLM_ASCEND_KVCACHE_DELAY_FREE_TIMEOUT", 250)),
|
|
||||||
}
|
}
|
||||||
|
|
||||||
# end-env-vars-definition
|
# end-env-vars-definition
|
||||||
@@ -177,4 +172,4 @@ def __getattr__(name: str):
|
|||||||
|
|
||||||
|
|
||||||
def __dir__():
|
def __dir__():
|
||||||
return list(env_variables.keys())
|
return list(env_variables.keys())
|
||||||
Reference in New Issue
Block a user