From 5691104249bbee7648e8cfc1466a96c092a8d76d Mon Sep 17 00:00:00 2001 From: liziyu <56102866+liziyu179@users.noreply.github.com> Date: Thu, 11 Sep 2025 11:37:41 +0800 Subject: [PATCH] LLMdatadist connector adapt the distributed KV aggregation (#2718) ### What this PR does / why we need it? LLMdatadist connector adapt the distributed KV aggregation for the main branch. Change the P node from returning "finish sending" only when TP0 responds to returning "finish sending" as soon as each NPU receives it. The D node will send a finish receive signal to the corresponding tp rank of the P node. ### Does this PR introduce _any_ user-facing change? ### How was this patch tested? gsm8k test 2*A3 1P 1D P: dp2 tp8 D:dp 4 tp4 P: dp2 tp8 D:dp 2 tp8 - vLLM version: main - vLLM main: https://github.com/vllm-project/vllm/commit/cc99baf14dacc2497d0c5ed84e076ef2c37f6a4d Signed-off-by: liziyu --- .github/workflows/vllm_ascend_test_pd.yaml | 1 + examples/disaggregated_prefill_v1/README.md | 8 +-- .../llmdatadist_c_mgr_connector.py | 72 +++++++------------ 3 files changed, 31 insertions(+), 50 deletions(-) diff --git a/.github/workflows/vllm_ascend_test_pd.yaml b/.github/workflows/vllm_ascend_test_pd.yaml index a86ba60..fee06be 100644 --- a/.github/workflows/vllm_ascend_test_pd.yaml +++ b/.github/workflows/vllm_ascend_test_pd.yaml @@ -108,4 +108,5 @@ jobs: - name: Run vllm-project/vllm-ascend PD Disaggregation edge test run: | + git config --global --add safe.directory /__w/vllm-ascend/vllm-ascend bash tests/e2e/pd_disaggreate/run_edge_case_test.sh \ No newline at end of file diff --git a/examples/disaggregated_prefill_v1/README.md b/examples/disaggregated_prefill_v1/README.md index eec8924..c42cace 100644 --- a/examples/disaggregated_prefill_v1/README.md +++ b/examples/disaggregated_prefill_v1/README.md @@ -42,7 +42,7 @@ export DISAGGREGATED_PREFILL_RANK_TABLE_PATH=/vllm-workspace/vllm-ascend/example export OMP_PROC_BIND=false export OMP_NUM_THREADS=100 export VLLM_USE_V1=1 -export 
VLLM_LLMDD_RPC_PORT=5559 +export VLLM_ASCEND_LLMDD_RPC_PORT=5559 vllm serve /models/deepseek_r1_w8a8 \ --host 0.0.0.0 \ @@ -85,7 +85,7 @@ export DISAGGREGATED_PREFILL_RANK_TABLE_PATH=/vllm-workspace/vllm-ascend/example export OMP_PROC_BIND=false export OMP_NUM_THREADS=100 export VLLM_USE_V1=1 -export VLLM_LLMDD_RPC_PORT=5659 +export VLLM_ASCEND_LLMDD_RPC_PORT=5659 vllm serve /models/deepseek_r1_w8a8 \ --host 0.0.0.0 \ @@ -131,7 +131,7 @@ export DISAGGREGATED_PREFILL_RANK_TABLE_PATH=/vllm-workspace/vllm-ascend/example export OMP_PROC_BIND=false export OMP_NUM_THREADS=100 export VLLM_USE_V1=1 -export VLLM_LLMDD_RPC_PORT=5759 +export VLLM_ASCEND_LLMDD_RPC_PORT=5759 vllm serve /models/deepseek_r1_w8a8 \ --host 0.0.0.0 \ @@ -173,7 +173,7 @@ export DISAGGREGATED_PREFILL_RANK_TABLE_PATH=/vllm-workspace/vllm-ascend/example export OMP_PROC_BIND=false export OMP_NUM_THREADS=100 export VLLM_USE_V1=1 -export VLLM_LLMDD_RPC_PORT=5859 +export VLLM_ASCEND_LLMDD_RPC_PORT=5859 vllm serve /models/deepseek_r1_w8a8 \ --host 0.0.0.0 \ diff --git a/vllm_ascend/distributed/llmdatadist_c_mgr_connector.py b/vllm_ascend/distributed/llmdatadist_c_mgr_connector.py index fe6617a..ddbc3c0 100644 --- a/vllm_ascend/distributed/llmdatadist_c_mgr_connector.py +++ b/vllm_ascend/distributed/llmdatadist_c_mgr_connector.py @@ -375,16 +375,11 @@ class LLMDataDistCMgrConnectorWorker(): ) elif event_msg == LLMDataDistCMgrEvent.ReqForFinished: finished_req_id = decode_msg[0] - decode_tp_rank = decode_msg[1] - decode_tp_size = decode_msg[2] with self.thread_lock: - if self._increment_task_count(finished_req_id, - decode_tp_rank, - decode_tp_size): - logger.debug( - f"LLMDataDistCMgrConnectorWorker: Receiving request {finished_req_id} finished" - ) - self.finished_reqs.add(finished_req_id) + logger.debug( + f"LLMDataDistCMgrConnectorWorker: Receiving request {finished_req_id} finished" + ) + self.finished_reqs.add(finished_req_id) sock.send_multipart( (identity, b"", b"receiving decode finished")) else: @@ 
-392,24 +387,6 @@ class LLMDataDistCMgrConnectorWorker(): f"LLMDataDistCMgrConnectorWorker: Receiving unexpected request event {event_msg} from remote !" ) - def _increment_task_count(self, request_id: str, tp_rank: int, - decode_tp_size: int): - if request_id not in self.done_receiving_counts: - self.done_receiving_counts[request_id] = set() - if tp_rank in self.done_receiving_counts[request_id]: - logger.warning( - f"Received duplicate done signal for request {request_id} " - f"from tp rank {tp_rank}. Ignoring.") - return False - self.done_receiving_counts[request_id].add(tp_rank) - if len(self.done_receiving_counts[request_id]) == decode_tp_size: - self.done_receiving_counts.pop(request_id) - logger.info("All transfers completed for request: " - f"{request_id}. Total ranks: " - f"{decode_tp_size}.") - return True - return False - def init_llm_datadist(self): assert self.local_agent_metadata is not None llm_config = LLMConfig() @@ -767,24 +744,24 @@ class LLMDataDistCMgrConnectorWorker(): cluster_id = self.add_remote_agent(metadata) return cluster_id - def send_finish_to_remote(self, host: str, port: int, request_id): - url = f"tcp://{host}:{port}" - logger.debug(f"Sending finished to remote: {url}") - msg_encoder = msgspec.msgpack.Encoder() - msg_send = msg_encoder.encode([ - LLMDataDistCMgrEvent.ReqForFinished, - [request_id, self.tp_rank, self.tp_size] - ]) - with zmq_ctx(zmq.REQ, url) as sock: # type: ignore[attr-defined] - try: - sock.send(msg_send) - logger.debug( - f"Request id {request_id} finished message send to remote {url}" - ) - _ = sock.recv() - except Exception as e: - logger.error( - f"Failed to send reqest_id {request_id} to prefill: {e}") + def send_finish_to_remote(self, host: str, ports: list[int], request_id): + for port in ports: + url = f"tcp://{host}:{port}" + logger.debug(f"Sending finished to remote: {url}") + msg_encoder = msgspec.msgpack.Encoder() + msg_send = msg_encoder.encode( + [LLMDataDistCMgrEvent.ReqForFinished, [request_id]]) + 
with zmq_ctx(zmq.REQ, url) as sock: # type: ignore[attr-defined] + try: + sock.send(msg_send) + logger.debug( + f"Request id {request_id} finished message send to remote {url}" + ) + _ = sock.recv() + except Exception as e: + logger.error( + f"Failed to send request_id {request_id} to prefill: {e}" + ) def _read_blocks( self, @@ -851,7 +828,10 @@ class LLMDataDistCMgrConnectorWorker(): raise RuntimeError( "LLMDataDistCMgrConnectorWorker: Timeout during pull_blocks, you can try to increase the sync_kv_timeout config or checking your connect status" ) - self.send_finish_to_remote(remote_ip, remote_port, request_id) + remote_ports = list( + range(remote_port + self.tp_rank, + remote_port + int(remote_tp_size), self.tp_size)) + self.send_finish_to_remote(remote_ip, remote_ports, request_id) with self.thread_lock: self.finished_reqs.add(request_id)