LLMdatadist connector adapt the distributed KV aggregation (#2718)

### What this PR does / why we need it?
Adapt the LLMDataDist connector to the distributed KV aggregation on the
main branch. Previously the P node returned "finish sending" only when TP0
responded; now each NPU returns "finish sending" as soon as it receives the
finish signal. The D node sends a "finish receive" signal to the
corresponding TP rank of the P node.
### Does this PR introduce _any_ user-facing change?

### How was this patch tested?
gsm8k test on 2*A3, 1P 1D, with two configurations:
- P: dp2 tp8, D: dp4 tp4
- P: dp2 tp8, D: dp2 tp8


- vLLM version: main
- vLLM main:
cc99baf14d

Signed-off-by: liziyu <liziyu16@huawei.com>
This commit is contained in:
liziyu
2025-09-11 11:37:41 +08:00
committed by GitHub
parent c2fdd4b8bc
commit 5691104249
3 changed files with 31 additions and 50 deletions

View File

@@ -108,4 +108,5 @@ jobs:
- name: Run vllm-project/vllm-ascend PD Disaggregation edge test
run: |
git config --global --add safe.directory/__w/vllm-ascend/vllm-ascend
bash tests/e2e/pd_disaggreate/run_edge_case_test.sh

View File

@@ -42,7 +42,7 @@ export DISAGGREGATED_PREFILL_RANK_TABLE_PATH=/vllm-workspace/vllm-ascend/example
export OMP_PROC_BIND=false
export OMP_NUM_THREADS=100
export VLLM_USE_V1=1
export VLLM_LLMDD_RPC_PORT=5559
export VLLM_ASCEND_LLMDD_RPC_PORT=5559
vllm serve /models/deepseek_r1_w8a8 \
--host 0.0.0.0 \
@@ -85,7 +85,7 @@ export DISAGGREGATED_PREFILL_RANK_TABLE_PATH=/vllm-workspace/vllm-ascend/example
export OMP_PROC_BIND=false
export OMP_NUM_THREADS=100
export VLLM_USE_V1=1
export VLLM_LLMDD_RPC_PORT=5659
export VLLM_ASCEND_LLMDD_RPC_PORT=5659
vllm serve /models/deepseek_r1_w8a8 \
--host 0.0.0.0 \
@@ -131,7 +131,7 @@ export DISAGGREGATED_PREFILL_RANK_TABLE_PATH=/vllm-workspace/vllm-ascend/example
export OMP_PROC_BIND=false
export OMP_NUM_THREADS=100
export VLLM_USE_V1=1
export VLLM_LLMDD_RPC_PORT=5759
export VLLM_ASCEND_LLMDD_RPC_PORT=5759
vllm serve /models/deepseek_r1_w8a8 \
--host 0.0.0.0 \
@@ -173,7 +173,7 @@ export DISAGGREGATED_PREFILL_RANK_TABLE_PATH=/vllm-workspace/vllm-ascend/example
export OMP_PROC_BIND=false
export OMP_NUM_THREADS=100
export VLLM_USE_V1=1
export VLLM_LLMDD_RPC_PORT=5859
export VLLM_ASCEND_LLMDD_RPC_PORT=5859
vllm serve /models/deepseek_r1_w8a8 \
--host 0.0.0.0 \

View File

@@ -375,16 +375,11 @@ class LLMDataDistCMgrConnectorWorker():
)
elif event_msg == LLMDataDistCMgrEvent.ReqForFinished:
finished_req_id = decode_msg[0]
decode_tp_rank = decode_msg[1]
decode_tp_size = decode_msg[2]
with self.thread_lock:
if self._increment_task_count(finished_req_id,
decode_tp_rank,
decode_tp_size):
logger.debug(
f"LLMDataDistCMgrConnectorWorker: Receiving request {finished_req_id} finished"
)
self.finished_reqs.add(finished_req_id)
logger.debug(
f"LLMDataDistCMgrConnectorWorker: Receiving request {finished_req_id} finished"
)
self.finished_reqs.add(finished_req_id)
sock.send_multipart(
(identity, b"", b"receiving decode finished"))
else:
@@ -392,24 +387,6 @@ class LLMDataDistCMgrConnectorWorker():
f"LLMDataDistCMgrConnectorWorker: Receiving unexpected request event {event_msg} from remote !"
)
def _increment_task_count(self, request_id: str, tp_rank: int,
                          decode_tp_size: int):
    """Record a "done" signal from one decode TP rank for a request.

    Args:
        request_id: Request the signal belongs to.
        tp_rank: TP rank that reported completion.
        decode_tp_size: Number of distinct ranks that must report before
            the request counts as complete.

    Returns:
        True only on the call that brings the number of distinct ranks
        recorded for ``request_id`` to ``decode_tp_size``; the request's
        entry in ``self.done_receiving_counts`` is removed at that point.
        False for a duplicate signal (logged and ignored) or while ranks
        are still outstanding.
    """
    seen_ranks = self.done_receiving_counts.setdefault(request_id, set())
    if tp_rank in seen_ranks:
        logger.warning(
            f"Received duplicate done signal for request {request_id} "
            f"from tp rank {tp_rank}. Ignoring.")
        return False
    seen_ranks.add(tp_rank)
    if len(seen_ranks) != decode_tp_size:
        return False
    # Every expected rank has reported: drop the bookkeeping entry.
    self.done_receiving_counts.pop(request_id)
    logger.info("All transfers completed for request: "
                f"{request_id}. Total ranks: "
                f"{decode_tp_size}.")
    return True
def init_llm_datadist(self):
assert self.local_agent_metadata is not None
llm_config = LLMConfig()
@@ -767,24 +744,24 @@ class LLMDataDistCMgrConnectorWorker():
cluster_id = self.add_remote_agent(metadata)
return cluster_id
def send_finish_to_remote(self, host: str, port: int, request_id):
    """Tell the remote endpoint at ``host:port`` that a request is finished.

    Sends a msgpack-encoded ``ReqForFinished`` event carrying the request
    id together with this worker's ``tp_rank`` and ``tp_size``, then waits
    on ``sock.recv()``. Exceptions raised while sending or receiving are
    logged rather than propagated.

    Args:
        host: Remote host name or IP address.
        port: Remote TCP port.
        request_id: Identifier of the finished request.
    """
    url = f"tcp://{host}:{port}"
    logger.debug(f"Sending finished to remote: {url}")
    payload = [request_id, self.tp_rank, self.tp_size]
    msg_send = msgspec.msgpack.Encoder().encode(
        [LLMDataDistCMgrEvent.ReqForFinished, payload])
    with zmq_ctx(zmq.REQ, url) as sock:  # type: ignore[attr-defined]
        try:
            sock.send(msg_send)
            logger.debug(
                f"Request id {request_id} finished message send to remote {url}"
            )
            # The reply content is discarded.
            _ = sock.recv()
        except Exception as e:
            logger.error(
                f"Failed to send reqest_id {request_id} to prefill: {e}")
def send_finish_to_remote(self, host: str, ports: list[int], request_id):
    """Tell every listed port on ``host`` that a request is finished.

    For each port, opens a REQ socket to ``tcp://host:port``, sends a
    msgpack-encoded ``ReqForFinished`` event carrying ``request_id`` and
    waits on ``sock.recv()``. An exception raised while sending to or
    receiving from one port is logged and does not stop the remaining
    ports from being notified.

    Args:
        host: Remote host name or IP address.
        ports: Remote TCP ports to notify, in order.
        request_id: Identifier of the finished request.
    """
    # The message does not depend on the port, so build it once instead of
    # re-creating the encoder and re-encoding on every iteration.
    msg_send = msgspec.msgpack.Encoder().encode(
        [LLMDataDistCMgrEvent.ReqForFinished, [request_id]])
    for port in ports:
        url = f"tcp://{host}:{port}"
        logger.debug(f"Sending finished to remote: {url}")
        with zmq_ctx(zmq.REQ, url) as sock:  # type: ignore[attr-defined]
            try:
                sock.send(msg_send)
                logger.debug(
                    f"Request id {request_id} finished message send to remote {url}"
                )
                # The reply content is discarded.
                _ = sock.recv()
            except Exception as e:
                logger.error(
                    f"Failed to send reqest_id {request_id} to prefill: {e}"
                )
def _read_blocks(
self,
@@ -851,7 +828,10 @@ class LLMDataDistCMgrConnectorWorker():
raise RuntimeError(
"LLMDataDistCMgrConnectorWorker: Timeout during pull_blocks, you can try to increase the sync_kv_timeout config or checking your connect status"
)
self.send_finish_to_remote(remote_ip, remote_port, request_id)
remote_ports = list(
range(remote_port + self.tp_rank,
remote_port + int(remote_tp_size), self.tp_size))
self.send_finish_to_remote(remote_ip, remote_ports, request_id)
with self.thread_lock:
self.finished_reqs.add(request_id)