diff --git a/vllm_ascend/distributed/llmdatadist_c_mgr_connector.py b/vllm_ascend/distributed/llmdatadist_c_mgr_connector.py index ddbc3c0..3782bc7 100644 --- a/vllm_ascend/distributed/llmdatadist_c_mgr_connector.py +++ b/vllm_ascend/distributed/llmdatadist_c_mgr_connector.py @@ -1,4 +1,5 @@ import contextlib +import copy import json import math import os @@ -17,6 +18,7 @@ import torch import zmq from llm_datadist import (BlocksCacheKey, CacheDesc, LLMConfig, LLMDataDist, LLMException, LLMRole) +from vllm import envs from vllm.config import KVTransferConfig, VllmConfig from vllm.distributed.kv_transfer.kv_connector.v1.base import ( KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole) @@ -184,6 +186,7 @@ class LLMDataDistCMgrConnectorScheduler(): self.port = dp_rank_local * tp_size + envs_ascend.VLLM_ASCEND_LLMDD_RPC_PORT if dp_rank_local is not None else tp_size + envs_ascend.VLLM_ASCEND_LLMDD_RPC_PORT self._reqs_need_recv: dict[str, tuple[Request, list[int]]] = {} + self._reqs_need_send: dict[str, float] = {} def get_num_new_matched_tokens( self, request: "Request", @@ -248,7 +251,12 @@ class LLMDataDistCMgrConnectorScheduler(): meta.add_new_req(request_id=req_id, local_block_ids=block_ids, kv_transfer_params=req.kv_transfer_params) + + meta.reqs_to_send = copy.deepcopy(self._reqs_need_send) + + # Clear the list once workers start the transfers self._reqs_need_recv.clear() + self._reqs_need_send.clear() return meta @@ -275,6 +283,9 @@ class LLMDataDistCMgrConnectorScheduler(): if delay_free_blocks: logger.info("Delaying free of %d blocks for request %s", len(computed_block_ids), request.request_id) + # Prefill request on remote. It will be read from D upon completion + self._reqs_need_send[request.request_id] = time.perf_counter( + ) + envs.VLLM_NIXL_ABORT_REQUEST_TIMEOUT return delay_free_blocks, dict( do_remote_prefill=True, do_remote_decode=False, @@ -341,6 +352,7 @@ class LLMDataDistCMgrConnectorWorker(): os.environ["HCCL_DETERMINISTIC"] = "true" self.done_receiving_counts: defaultdict[str, set[int]] = defaultdict(set) + self.reqs_to_send: dict[str, float] = {} def listen_for_agent_metadata_req(self, event: threading.Event): assert self.local_agent_metadata is not None @@ -379,7 +391,9 @@ class LLMDataDistCMgrConnectorWorker(): logger.debug( f"LLMDataDistCMgrConnectorWorker: Receiving request {finished_req_id} finished" ) - self.finished_reqs.add(finished_req_id) + if finished_req_id in self.reqs_to_send: + self.finished_reqs.add(finished_req_id) + del self.reqs_to_send[finished_req_id] sock.send_multipart( (identity, b"", b"receiving decode finished")) else: @@ -582,6 +596,7 @@ class LLMDataDistCMgrConnectorWorker(): for future in futures: future.add_done_callback(handle_exception) + self.reqs_to_send.update(metadata.reqs_to_send) def add_remote_agent(self, metadata: LLMDataDistCMgrAgentMetadata) -> int: assert self.local_agent_metadata is not None @@ -839,8 +854,19 @@ class LLMDataDistCMgrConnectorWorker(): self, finished_req_ids: set[str] ) -> tuple[Optional[set[str]], Optional[set[str]]]: """Get the finished recving and sending requuests.""" - import copy + now = time.perf_counter() with self.thread_lock: + while self.reqs_to_send: + req_id, expires = next(iter(self.reqs_to_send.items())) + if now < expires: + break + logger.warning( + "Some requests in prefill node fail to receive KV Cache transfer done signal. " + "If a greater mean TTFT is acceptable, you can 'export VLLM_NIXL_ABORT_REQUEST_TIMEOUT=600' (10 minutes) to relax the timeout condition. " + ) + if req_id in self.reqs_to_send: + self.finished_reqs.add(req_id) + del self.reqs_to_send[req_id] req_ids_to_ret = copy.deepcopy(self.finished_reqs) self.finished_reqs.clear() if self.llm_datadist_role == LLMRole.PROMPT: @@ -871,4 +897,4 @@ def zmq_ctx(socket_type: Any, yield socket finally: if ctx is not None: - ctx.destroy(linger=0) + ctx.destroy(linger=0) \ No newline at end of file diff --git a/vllm_ascend/distributed/mooncake_connector.py b/vllm_ascend/distributed/mooncake_connector.py index 75753f0..8c37ad4 100644 --- a/vllm_ascend/distributed/mooncake_connector.py +++ b/vllm_ascend/distributed/mooncake_connector.py @@ -19,6 +19,7 @@ import numpy.typing as npt import torch import zmq from mooncake.engine import TransferEngine # type: ignore +from vllm import envs from vllm.config import VllmConfig from vllm.distributed.kv_transfer.kv_connector.v1.base import ( KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole) @@ -100,7 +101,7 @@ class KVCacheTaskTracker: while self.delayed_free_requests: request_id, delay_start_time = self.delayed_free_requests[0] if (current_time - delay_start_time - > envs_ascend.VLLM_ASCEND_KVCACHE_DELAY_FREE_TIMEOUT): + > envs.VLLM_NIXL_ABORT_REQUEST_TIMEOUT): self.delayed_free_requests.popleft() expired_requests.add(request_id) logger.info("Force freed request: %s", request_id) diff --git a/vllm_ascend/envs.py b/vllm_ascend/envs.py index 5792c83..61df5e1 100644 --- a/vllm_ascend/envs.py +++ b/vllm_ascend/envs.py @@ -159,11 +159,6 @@ env_variables: Dict[str, Callable[[], Any]] = { # caused by the initialization of the Mooncake connector. "PHYSICAL_DEVICES": lambda: os.getenv("PHYSICAL_DEVICES", None), - # Timeout (in seconds) for delayed KVCache block release. In the prefill - # node, if a request is marked for delayed KV block release and the blocks - # are not freed within this timeout, they will be forcibly released. - "VLLM_ASCEND_KVCACHE_DELAY_FREE_TIMEOUT": - lambda: int(os.getenv("VLLM_ASCEND_KVCACHE_DELAY_FREE_TIMEOUT", 250)), } # end-env-vars-definition @@ -177,4 +172,4 @@ def __getattr__(name: str): def __dir__(): - return list(env_variables.keys()) + return list(env_variables.keys()) \ No newline at end of file