LLMdatadist connector adapt the distributed KV aggregation (#2718)
### What this PR does / why we need it?
LLMdatadist connector adapts the distributed KV aggregation for the main
branch. The P node is changed from returning "finish sending" only when
TP0 responds, to returning "finish sending" as soon as each NPU finishes.
The D node will send a finish-receive signal to the corresponding TP
rank of the P node.
### Does this PR introduce _any_ user-facing change?
### How was this patch tested?
gsm8k test
2*A3 1P 1D
P: dp2 tp8, D: dp4 tp4
P: dp2 tp8, D: dp2 tp8
- vLLM version: main
- vLLM main:
cc99baf14d
Signed-off-by: liziyu <liziyu16@huawei.com>
This commit is contained in:
1
.github/workflows/vllm_ascend_test_pd.yaml
vendored
1
.github/workflows/vllm_ascend_test_pd.yaml
vendored
@@ -108,4 +108,5 @@ jobs:
|
|||||||
|
|
||||||
- name: Run vllm-project/vllm-ascend PD Disaggregation edge test
|
- name: Run vllm-project/vllm-ascend PD Disaggregation edge test
|
||||||
run: |
|
run: |
|
||||||
|
git config --global --add safe.directory/__w/vllm-ascend/vllm-ascend
|
||||||
bash tests/e2e/pd_disaggreate/run_edge_case_test.sh
|
bash tests/e2e/pd_disaggreate/run_edge_case_test.sh
|
||||||
@@ -42,7 +42,7 @@ export DISAGGREGATED_PREFILL_RANK_TABLE_PATH=/vllm-workspace/vllm-ascend/example
|
|||||||
export OMP_PROC_BIND=false
|
export OMP_PROC_BIND=false
|
||||||
export OMP_NUM_THREADS=100
|
export OMP_NUM_THREADS=100
|
||||||
export VLLM_USE_V1=1
|
export VLLM_USE_V1=1
|
||||||
export VLLM_LLMDD_RPC_PORT=5559
|
export VLLM_ASCEND_LLMDD_RPC_PORT=5559
|
||||||
|
|
||||||
vllm serve /models/deepseek_r1_w8a8 \
|
vllm serve /models/deepseek_r1_w8a8 \
|
||||||
--host 0.0.0.0 \
|
--host 0.0.0.0 \
|
||||||
@@ -85,7 +85,7 @@ export DISAGGREGATED_PREFILL_RANK_TABLE_PATH=/vllm-workspace/vllm-ascend/example
|
|||||||
export OMP_PROC_BIND=false
|
export OMP_PROC_BIND=false
|
||||||
export OMP_NUM_THREADS=100
|
export OMP_NUM_THREADS=100
|
||||||
export VLLM_USE_V1=1
|
export VLLM_USE_V1=1
|
||||||
export VLLM_LLMDD_RPC_PORT=5659
|
export VLLM_ASCEND_LLMDD_RPC_PORT=5659
|
||||||
|
|
||||||
vllm serve /models/deepseek_r1_w8a8 \
|
vllm serve /models/deepseek_r1_w8a8 \
|
||||||
--host 0.0.0.0 \
|
--host 0.0.0.0 \
|
||||||
@@ -131,7 +131,7 @@ export DISAGGREGATED_PREFILL_RANK_TABLE_PATH=/vllm-workspace/vllm-ascend/example
|
|||||||
export OMP_PROC_BIND=false
|
export OMP_PROC_BIND=false
|
||||||
export OMP_NUM_THREADS=100
|
export OMP_NUM_THREADS=100
|
||||||
export VLLM_USE_V1=1
|
export VLLM_USE_V1=1
|
||||||
export VLLM_LLMDD_RPC_PORT=5759
|
export VLLM_ASCEND_LLMDD_RPC_PORT=5759
|
||||||
|
|
||||||
vllm serve /models/deepseek_r1_w8a8 \
|
vllm serve /models/deepseek_r1_w8a8 \
|
||||||
--host 0.0.0.0 \
|
--host 0.0.0.0 \
|
||||||
@@ -173,7 +173,7 @@ export DISAGGREGATED_PREFILL_RANK_TABLE_PATH=/vllm-workspace/vllm-ascend/example
|
|||||||
export OMP_PROC_BIND=false
|
export OMP_PROC_BIND=false
|
||||||
export OMP_NUM_THREADS=100
|
export OMP_NUM_THREADS=100
|
||||||
export VLLM_USE_V1=1
|
export VLLM_USE_V1=1
|
||||||
export VLLM_LLMDD_RPC_PORT=5859
|
export VLLM_ASCEND_LLMDD_RPC_PORT=5859
|
||||||
|
|
||||||
vllm serve /models/deepseek_r1_w8a8 \
|
vllm serve /models/deepseek_r1_w8a8 \
|
||||||
--host 0.0.0.0 \
|
--host 0.0.0.0 \
|
||||||
|
|||||||
@@ -375,16 +375,11 @@ class LLMDataDistCMgrConnectorWorker():
|
|||||||
)
|
)
|
||||||
elif event_msg == LLMDataDistCMgrEvent.ReqForFinished:
|
elif event_msg == LLMDataDistCMgrEvent.ReqForFinished:
|
||||||
finished_req_id = decode_msg[0]
|
finished_req_id = decode_msg[0]
|
||||||
decode_tp_rank = decode_msg[1]
|
|
||||||
decode_tp_size = decode_msg[2]
|
|
||||||
with self.thread_lock:
|
with self.thread_lock:
|
||||||
if self._increment_task_count(finished_req_id,
|
logger.debug(
|
||||||
decode_tp_rank,
|
f"LLMDataDistCMgrConnectorWorker: Receiving request {finished_req_id} finished"
|
||||||
decode_tp_size):
|
)
|
||||||
logger.debug(
|
self.finished_reqs.add(finished_req_id)
|
||||||
f"LLMDataDistCMgrConnectorWorker: Receiving request {finished_req_id} finished"
|
|
||||||
)
|
|
||||||
self.finished_reqs.add(finished_req_id)
|
|
||||||
sock.send_multipart(
|
sock.send_multipart(
|
||||||
(identity, b"", b"receiving decode finished"))
|
(identity, b"", b"receiving decode finished"))
|
||||||
else:
|
else:
|
||||||
@@ -392,24 +387,6 @@ class LLMDataDistCMgrConnectorWorker():
|
|||||||
f"LLMDataDistCMgrConnectorWorker: Receiving unexpected request event {event_msg} from remote !"
|
f"LLMDataDistCMgrConnectorWorker: Receiving unexpected request event {event_msg} from remote !"
|
||||||
)
|
)
|
||||||
|
|
||||||
def _increment_task_count(self, request_id: str, tp_rank: int,
|
|
||||||
decode_tp_size: int):
|
|
||||||
if request_id not in self.done_receiving_counts:
|
|
||||||
self.done_receiving_counts[request_id] = set()
|
|
||||||
if tp_rank in self.done_receiving_counts[request_id]:
|
|
||||||
logger.warning(
|
|
||||||
f"Received duplicate done signal for request {request_id} "
|
|
||||||
f"from tp rank {tp_rank}. Ignoring.")
|
|
||||||
return False
|
|
||||||
self.done_receiving_counts[request_id].add(tp_rank)
|
|
||||||
if len(self.done_receiving_counts[request_id]) == decode_tp_size:
|
|
||||||
self.done_receiving_counts.pop(request_id)
|
|
||||||
logger.info("All transfers completed for request: "
|
|
||||||
f"{request_id}. Total ranks: "
|
|
||||||
f"{decode_tp_size}.")
|
|
||||||
return True
|
|
||||||
return False
|
|
||||||
|
|
||||||
def init_llm_datadist(self):
|
def init_llm_datadist(self):
|
||||||
assert self.local_agent_metadata is not None
|
assert self.local_agent_metadata is not None
|
||||||
llm_config = LLMConfig()
|
llm_config = LLMConfig()
|
||||||
@@ -767,24 +744,24 @@ class LLMDataDistCMgrConnectorWorker():
|
|||||||
cluster_id = self.add_remote_agent(metadata)
|
cluster_id = self.add_remote_agent(metadata)
|
||||||
return cluster_id
|
return cluster_id
|
||||||
|
|
||||||
def send_finish_to_remote(self, host: str, port: int, request_id):
|
def send_finish_to_remote(self, host: str, ports: list[int], request_id):
|
||||||
url = f"tcp://{host}:{port}"
|
for port in ports:
|
||||||
logger.debug(f"Sending finished to remote: {url}")
|
url = f"tcp://{host}:{port}"
|
||||||
msg_encoder = msgspec.msgpack.Encoder()
|
logger.debug(f"Sending finished to remote: {url}")
|
||||||
msg_send = msg_encoder.encode([
|
msg_encoder = msgspec.msgpack.Encoder()
|
||||||
LLMDataDistCMgrEvent.ReqForFinished,
|
msg_send = msg_encoder.encode(
|
||||||
[request_id, self.tp_rank, self.tp_size]
|
[LLMDataDistCMgrEvent.ReqForFinished, [request_id]])
|
||||||
])
|
with zmq_ctx(zmq.REQ, url) as sock: # type: ignore[attr-defined]
|
||||||
with zmq_ctx(zmq.REQ, url) as sock: # type: ignore[attr-defined]
|
try:
|
||||||
try:
|
sock.send(msg_send)
|
||||||
sock.send(msg_send)
|
logger.debug(
|
||||||
logger.debug(
|
f"Request id {request_id} finished message send to remote {url}"
|
||||||
f"Request id {request_id} finished message send to remote {url}"
|
)
|
||||||
)
|
_ = sock.recv()
|
||||||
_ = sock.recv()
|
except Exception as e:
|
||||||
except Exception as e:
|
logger.error(
|
||||||
logger.error(
|
f"Failed to send reqest_id {request_id} to prefill: {e}"
|
||||||
f"Failed to send reqest_id {request_id} to prefill: {e}")
|
)
|
||||||
|
|
||||||
def _read_blocks(
|
def _read_blocks(
|
||||||
self,
|
self,
|
||||||
@@ -851,7 +828,10 @@ class LLMDataDistCMgrConnectorWorker():
|
|||||||
raise RuntimeError(
|
raise RuntimeError(
|
||||||
"LLMDataDistCMgrConnectorWorker: Timeout during pull_blocks, you can try to increase the sync_kv_timeout config or checking your connect status"
|
"LLMDataDistCMgrConnectorWorker: Timeout during pull_blocks, you can try to increase the sync_kv_timeout config or checking your connect status"
|
||||||
)
|
)
|
||||||
self.send_finish_to_remote(remote_ip, remote_port, request_id)
|
remote_ports = list(
|
||||||
|
range(remote_port + self.tp_rank,
|
||||||
|
remote_port + int(remote_tp_size), self.tp_size))
|
||||||
|
self.send_finish_to_remote(remote_ip, remote_ports, request_id)
|
||||||
with self.thread_lock:
|
with self.thread_lock:
|
||||||
self.finished_reqs.add(request_id)
|
self.finished_reqs.add(request_id)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user