LLMdatadist connector adapt the distributed KV aggregation (#2718)
### What this PR does / why we need it?
The LLMdatadist connector adapts the distributed KV aggregation for the main
branch. The P node is changed from returning "finish sending" only when TP0
responds to returning "finish sending" as soon as each NPU finishes receiving.
The D node then sends a finish-receive signal to the corresponding tp
rank of the P node.
### Does this PR introduce _any_ user-facing change?
### How was this patch tested?
gsm8k test
2*A3 1P 1D
P: dp2 tp8, D: dp4 tp4
P: dp2 tp8, D: dp2 tp8
- vLLM version: main
- vLLM main:
cc99baf14d
Signed-off-by: liziyu <liziyu16@huawei.com>
This commit is contained in:
1
.github/workflows/vllm_ascend_test_pd.yaml
vendored
1
.github/workflows/vllm_ascend_test_pd.yaml
vendored
@@ -108,4 +108,5 @@ jobs:
|
||||
|
||||
- name: Run vllm-project/vllm-ascend PD Disaggregation edge test
|
||||
run: |
|
||||
git config --global --add safe.directory/__w/vllm-ascend/vllm-ascend
|
||||
bash tests/e2e/pd_disaggreate/run_edge_case_test.sh
|
||||
@@ -42,7 +42,7 @@ export DISAGGREGATED_PREFILL_RANK_TABLE_PATH=/vllm-workspace/vllm-ascend/example
|
||||
export OMP_PROC_BIND=false
|
||||
export OMP_NUM_THREADS=100
|
||||
export VLLM_USE_V1=1
|
||||
export VLLM_LLMDD_RPC_PORT=5559
|
||||
export VLLM_ASCEND_LLMDD_RPC_PORT=5559
|
||||
|
||||
vllm serve /models/deepseek_r1_w8a8 \
|
||||
--host 0.0.0.0 \
|
||||
@@ -85,7 +85,7 @@ export DISAGGREGATED_PREFILL_RANK_TABLE_PATH=/vllm-workspace/vllm-ascend/example
|
||||
export OMP_PROC_BIND=false
|
||||
export OMP_NUM_THREADS=100
|
||||
export VLLM_USE_V1=1
|
||||
export VLLM_LLMDD_RPC_PORT=5659
|
||||
export VLLM_ASCEND_LLMDD_RPC_PORT=5659
|
||||
|
||||
vllm serve /models/deepseek_r1_w8a8 \
|
||||
--host 0.0.0.0 \
|
||||
@@ -131,7 +131,7 @@ export DISAGGREGATED_PREFILL_RANK_TABLE_PATH=/vllm-workspace/vllm-ascend/example
|
||||
export OMP_PROC_BIND=false
|
||||
export OMP_NUM_THREADS=100
|
||||
export VLLM_USE_V1=1
|
||||
export VLLM_LLMDD_RPC_PORT=5759
|
||||
export VLLM_ASCEND_LLMDD_RPC_PORT=5759
|
||||
|
||||
vllm serve /models/deepseek_r1_w8a8 \
|
||||
--host 0.0.0.0 \
|
||||
@@ -173,7 +173,7 @@ export DISAGGREGATED_PREFILL_RANK_TABLE_PATH=/vllm-workspace/vllm-ascend/example
|
||||
export OMP_PROC_BIND=false
|
||||
export OMP_NUM_THREADS=100
|
||||
export VLLM_USE_V1=1
|
||||
export VLLM_LLMDD_RPC_PORT=5859
|
||||
export VLLM_ASCEND_LLMDD_RPC_PORT=5859
|
||||
|
||||
vllm serve /models/deepseek_r1_w8a8 \
|
||||
--host 0.0.0.0 \
|
||||
|
||||
@@ -375,16 +375,11 @@ class LLMDataDistCMgrConnectorWorker():
|
||||
)
|
||||
elif event_msg == LLMDataDistCMgrEvent.ReqForFinished:
|
||||
finished_req_id = decode_msg[0]
|
||||
decode_tp_rank = decode_msg[1]
|
||||
decode_tp_size = decode_msg[2]
|
||||
with self.thread_lock:
|
||||
if self._increment_task_count(finished_req_id,
|
||||
decode_tp_rank,
|
||||
decode_tp_size):
|
||||
logger.debug(
|
||||
f"LLMDataDistCMgrConnectorWorker: Receiving request {finished_req_id} finished"
|
||||
)
|
||||
self.finished_reqs.add(finished_req_id)
|
||||
logger.debug(
|
||||
f"LLMDataDistCMgrConnectorWorker: Receiving request {finished_req_id} finished"
|
||||
)
|
||||
self.finished_reqs.add(finished_req_id)
|
||||
sock.send_multipart(
|
||||
(identity, b"", b"receiving decode finished"))
|
||||
else:
|
||||
@@ -392,24 +387,6 @@ class LLMDataDistCMgrConnectorWorker():
|
||||
f"LLMDataDistCMgrConnectorWorker: Receiving unexpected request event {event_msg} from remote !"
|
||||
)
|
||||
|
||||
def _increment_task_count(self, request_id: str, tp_rank: int,
|
||||
decode_tp_size: int):
|
||||
if request_id not in self.done_receiving_counts:
|
||||
self.done_receiving_counts[request_id] = set()
|
||||
if tp_rank in self.done_receiving_counts[request_id]:
|
||||
logger.warning(
|
||||
f"Received duplicate done signal for request {request_id} "
|
||||
f"from tp rank {tp_rank}. Ignoring.")
|
||||
return False
|
||||
self.done_receiving_counts[request_id].add(tp_rank)
|
||||
if len(self.done_receiving_counts[request_id]) == decode_tp_size:
|
||||
self.done_receiving_counts.pop(request_id)
|
||||
logger.info("All transfers completed for request: "
|
||||
f"{request_id}. Total ranks: "
|
||||
f"{decode_tp_size}.")
|
||||
return True
|
||||
return False
|
||||
|
||||
def init_llm_datadist(self):
|
||||
assert self.local_agent_metadata is not None
|
||||
llm_config = LLMConfig()
|
||||
@@ -767,24 +744,24 @@ class LLMDataDistCMgrConnectorWorker():
|
||||
cluster_id = self.add_remote_agent(metadata)
|
||||
return cluster_id
|
||||
|
||||
def send_finish_to_remote(self, host: str, port: int, request_id):
|
||||
url = f"tcp://{host}:{port}"
|
||||
logger.debug(f"Sending finished to remote: {url}")
|
||||
msg_encoder = msgspec.msgpack.Encoder()
|
||||
msg_send = msg_encoder.encode([
|
||||
LLMDataDistCMgrEvent.ReqForFinished,
|
||||
[request_id, self.tp_rank, self.tp_size]
|
||||
])
|
||||
with zmq_ctx(zmq.REQ, url) as sock: # type: ignore[attr-defined]
|
||||
try:
|
||||
sock.send(msg_send)
|
||||
logger.debug(
|
||||
f"Request id {request_id} finished message send to remote {url}"
|
||||
)
|
||||
_ = sock.recv()
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Failed to send reqest_id {request_id} to prefill: {e}")
|
||||
def send_finish_to_remote(self, host: str, ports: list[int], request_id):
|
||||
for port in ports:
|
||||
url = f"tcp://{host}:{port}"
|
||||
logger.debug(f"Sending finished to remote: {url}")
|
||||
msg_encoder = msgspec.msgpack.Encoder()
|
||||
msg_send = msg_encoder.encode(
|
||||
[LLMDataDistCMgrEvent.ReqForFinished, [request_id]])
|
||||
with zmq_ctx(zmq.REQ, url) as sock: # type: ignore[attr-defined]
|
||||
try:
|
||||
sock.send(msg_send)
|
||||
logger.debug(
|
||||
f"Request id {request_id} finished message send to remote {url}"
|
||||
)
|
||||
_ = sock.recv()
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Failed to send reqest_id {request_id} to prefill: {e}"
|
||||
)
|
||||
|
||||
def _read_blocks(
|
||||
self,
|
||||
@@ -851,7 +828,10 @@ class LLMDataDistCMgrConnectorWorker():
|
||||
raise RuntimeError(
|
||||
"LLMDataDistCMgrConnectorWorker: Timeout during pull_blocks, you can try to increase the sync_kv_timeout config or checking your connect status"
|
||||
)
|
||||
self.send_finish_to_remote(remote_ip, remote_port, request_id)
|
||||
remote_ports = list(
|
||||
range(remote_port + self.tp_rank,
|
||||
remote_port + int(remote_tp_size), self.tp_size))
|
||||
self.send_finish_to_remote(remote_ip, remote_ports, request_id)
|
||||
with self.thread_lock:
|
||||
self.finished_reqs.add(request_id)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user