Disaggregate prefill for kv cache register style (#950)
### What this PR does / why we need it?
This PR adopt `LLMDataDist` for kv cache register and `pull_blocks`
style disaggregate prefill implementation. The interface implementation
mainly follows the design of NIXL PR
https://github.com/vllm-project/vllm/pull/17751/files#diff-7eaad0b7dee0626bf29d10081b0f0c5e3ea15a4af97e7b182a4e0d35f8346953
.
This PR can be test with the following step:
- Generate the rank table for all machine.
- execute`toy_proxy.py` to launch the disaggregate prefill proxy server,
specify the prefill ip, port and the decode ip, port
- Run the prefill server and decode server.
- send the request to the disaggregate prefill proxy
### Does this PR introduce _any_ user-facing change?
### How was this patch tested?
- vLLM version: v0.9.2
- vLLM main:
8d0a01a5f2
---------
Signed-off-by: ganyi <pleaplusone.gy@gmail.com>
Signed-off-by: machenglong <machenglong_yewu@cmss.chinamobile.com>
Signed-off-by: liziyu179 <3475441767@qq.com>
Signed-off-by: underfitc <hucong24@huawei.com>
Signed-off-by: zouyida2052 <zouyida@huawei.com>
Signed-off-by: liziyu <liziyu16@huawei.com>
Signed-off-by: underfituu <hzhucong@163.com>
Co-authored-by: machenglong <machenglong_yewu@cmss.chinamobile.com>
Co-authored-by: liziyu179 <3475441767@qq.com>
Co-authored-by: underfitc <hucong24@huawei.com>
Co-authored-by: zouyida2052 <zouyida@huawei.com>
Co-authored-by: liziyu <liziyu16@huawei.com>
Co-authored-by: underfituu <hzhucong@163.com>
This commit is contained in:
230
examples/disaggregated_prefill_v1/README.md
Normal file
230
examples/disaggregated_prefill_v1/README.md
Normal file
@@ -0,0 +1,230 @@
|
||||
# Disaggregated Prefill-Decode Deployment Guide
|
||||
|
||||
## Overview
|
||||
This demo document provides instructions for running a disaggregated vLLM-ascend service with separate prefill and decode stages across 4 nodes, uses 16 Ascend NPUs for two prefill nodes (P1/P2) and 16 Ascend NPUS for two decode nodes (D1/D2).
|
||||
|
||||
## Prerequisites
|
||||
- Ascend NPU environment with vLLM 0.9.1 installed
|
||||
- Network interfaces configured for distributed communication (eg: eth0)
|
||||
- Model weights located at `/data01/deepseek_r1_w8a8_zhw`
|
||||
|
||||
## Rank table generation
|
||||
The rank table is a JSON file that specifies the mapping of Ascend NPU ranks to nodes. The following command generates a rank table for all nodes with 16 cards prefill and 16 cards decode:
|
||||
|
||||
Run the following command on every node to generate the rank table:
|
||||
```shell
|
||||
cd vllm-ascend/examples/disaggregate_prefill_v1/
|
||||
bash gen_ranktable.sh --ips 172.19.32.175 172.19.241.49 172.19.123.51 172.19.190.36 \
|
||||
--npus-per-node 8 --network-card-name enp189s0f0 --prefill-device-cnt 16 --decode-device-cnt 16
|
||||
```
|
||||
Rank table will generated at `/vllm-workspace/vllm-ascend/examples/disaggregate_prefill_v1/ranktable.json`
|
||||
|
||||
## Start disaggregated vLLM-ascend service
|
||||
Execution Sequence
|
||||
- 4 configured node ip are: 172.19.32.175 172.19.241.49 172.19.123.51 172.19.190.36
|
||||
- Start Prefill on Node 1 (P1)
|
||||
- Start Prefill on Node 2 (P2)
|
||||
- Start Decode on Node 1 (D1)
|
||||
- Start Decode on Node 2 (D2)
|
||||
- Start proxy server on Node1
|
||||
|
||||
* Run prefill server P1 on first node
|
||||
```shell
|
||||
export HCCL_IF_IP=172.19.32.175 # node ip
|
||||
export GLOO_SOCKET_IFNAME="eth0" # network card name
|
||||
export TP_SOCKET_IFNAME="eth0"
|
||||
export HCCL_SOCKET_IFNAME="eth0"
|
||||
export DISAGGREGATED_PREFILL_RANK_TABLE_PATH=/vllm-workspace/vllm-ascend/examples/disaggregate_prefill_v1/ranktable.json
|
||||
export OMP_PROC_BIND=false
|
||||
export OMP_NUM_THREADS=100
|
||||
export VLLM_USE_V1=1
|
||||
vllm serve /data01/deepseek_r1_w8a8_zhw \
|
||||
--host 0.0.0.0 \
|
||||
--port 20002 \
|
||||
--data-parallel-size 2 \
|
||||
--data-parallel-size-local 1 \
|
||||
--api-server-count 2 \
|
||||
--data-parallel-address 172.19.32.175 \
|
||||
--data-parallel-rpc-port 13356 \
|
||||
--tensor-parallel-size 8 \
|
||||
--no-enable-prefix-caching \
|
||||
--seed 1024 \
|
||||
--served-model-name deepseek \
|
||||
--max-model-len 6144 \
|
||||
--max-num-batched-tokens 6144 \
|
||||
--trust-remote-code \
|
||||
--enforce-eager \
|
||||
--gpu-memory-utilization 0.9 \
|
||||
--kv-transfer-config \
|
||||
'{"kv_connector": "LLMDataDistCMgrConnector",
|
||||
"kv_buffer_device": "npu",
|
||||
"kv_role": "kv_producer",
|
||||
"kv_parallel_size": 1,
|
||||
"kv_port": "20001",
|
||||
"engine_id": "0",
|
||||
"kv_connector_module_path": "vllm_ascend.distributed.llmdatadist_c_mgr_connector"
|
||||
}' \
|
||||
--additional-config \
|
||||
'{"torchair_graph_config": {"enabled": false, "enable_multistream_shared_expert": false}, "ascend_scheduler_config":{"enabled":false}}'
|
||||
```
|
||||
|
||||
* Run prefill server P2 on second node
|
||||
```shell
|
||||
export HCCL_IF_IP=172.19.241.49
|
||||
export GLOO_SOCKET_IFNAME="eth0"
|
||||
export TP_SOCKET_IFNAME="eth0"
|
||||
export HCCL_SOCKET_IFNAME="eth0"
|
||||
export DISAGGREGATED_PREFILL_RANK_TABLE_PATH=/vllm-workspace/vllm-ascend/examples/disaggregate_prefill_v1/ranktable.json
|
||||
export OMP_PROC_BIND=false
|
||||
export OMP_NUM_THREADS=100
|
||||
export VLLM_USE_V1=1
|
||||
vllm serve /data01/deepseek_r1_w8a8_zhw \
|
||||
--host 0.0.0.0 \
|
||||
--port 20002 \
|
||||
--headless \
|
||||
--data-parallel-size 2 \
|
||||
--data-parallel-start-rank 1 \
|
||||
--data-parallel-size-local 1 \
|
||||
--data-parallel-address 172.19.32.175 \
|
||||
--data-parallel-rpc-port 13356 \
|
||||
--tensor-parallel-size 8 \
|
||||
--no-enable-prefix-caching \
|
||||
--seed 1024 \
|
||||
--served-model-name deepseek \
|
||||
--max-model-len 6144 \
|
||||
--max-num-batched-tokens 6144 \
|
||||
--trust-remote-code \
|
||||
--enforce-eager \
|
||||
--gpu-memory-utilization 0.9 \
|
||||
--kv-transfer-config \
|
||||
'{"kv_connector": "LLMDataDistCMgrConnector",
|
||||
"kv_buffer_device": "npu",
|
||||
"kv_role": "kv_producer",
|
||||
"kv_parallel_size": 1,
|
||||
"kv_port": "20001",
|
||||
"engine_id": "0",
|
||||
"kv_connector_module_path": "vllm_ascend.distributed.llmdatadist_c_mgr_connector"
|
||||
}' \
|
||||
--additional-config \
|
||||
'{"torchair_graph_config": {"enabled": false, "enable_multistream_shared_expert": false}, "ascend_scheduler_config":{"enabled":false}}'
|
||||
```
|
||||
|
||||
* Run decode server d1 on third node
|
||||
```shell
|
||||
export HCCL_IF_IP=172.19.123.51
|
||||
export GLOO_SOCKET_IFNAME="eth0"
|
||||
export TP_SOCKET_IFNAME="eth0"
|
||||
export HCCL_SOCKET_IFNAME="eth0"
|
||||
export DISAGGREGATED_PREFILL_RANK_TABLE_PATH=/vllm-workspace/vllm-ascend/examples/disaggregate_prefill_v1/ranktable.json
|
||||
export OMP_PROC_BIND=false
|
||||
export OMP_NUM_THREADS=100
|
||||
export VLLM_USE_V1=1
|
||||
vllm serve /data01/deepseek_r1_w8a8_zhw \
|
||||
--host 0.0.0.0 \
|
||||
--port 20002 \
|
||||
--data-parallel-size 2 \
|
||||
--data-parallel-size-local 1 \
|
||||
--api-server-count 2 \
|
||||
--data-parallel-address 172.19.123.51 \
|
||||
--data-parallel-rpc-port 13356 \
|
||||
--tensor-parallel-size 8 \
|
||||
--no-enable-prefix-caching \
|
||||
--seed 1024 \
|
||||
--served-model-name deepseek \
|
||||
--max-model-len 6144 \
|
||||
--max-num-batched-tokens 6144 \
|
||||
--trust-remote-code \
|
||||
--enforce-eager \
|
||||
--gpu-memory-utilization 0.9 \
|
||||
--kv-transfer-config \
|
||||
'{"kv_connector": "LLMDataDistCMgrConnector",
|
||||
"kv_buffer_device": "npu",
|
||||
"kv_role": "kv_consumer",
|
||||
"kv_parallel_size": 1,
|
||||
"kv_port": "20001",
|
||||
"engine_id": "0",
|
||||
"kv_connector_module_path": "vllm_ascend.distributed.llmdatadist_c_mgr_connector"
|
||||
}' \
|
||||
--additional-config \
|
||||
'{"torchair_graph_config": {"enabled": false, "enable_multistream_shared_expert": false}, "ascend_scheduler_config":{"enabled":false}}'
|
||||
```
|
||||
|
||||
* Run decode server d2 on last node
|
||||
```shell
|
||||
export HCCL_IF_IP=172.19.190.36
|
||||
export GLOO_SOCKET_IFNAME="eth0"
|
||||
export TP_SOCKET_IFNAME="eth0"
|
||||
export HCCL_SOCKET_IFNAME="eth0"
|
||||
export DISAGGREGATED_PREFILL_RANK_TABLE_PATH=/vllm-workspace/vllm-ascend/examples/disaggregate_prefill_v1/ranktable.json
|
||||
export OMP_PROC_BIND=false
|
||||
export OMP_NUM_THREADS=100
|
||||
export VLLM_USE_V1=1
|
||||
vllm serve /data01/deepseek_r1_w8a8_zhw \
|
||||
--host 0.0.0.0 \
|
||||
--port 20002 \
|
||||
--headless \
|
||||
--data-parallel-size 2 \
|
||||
--data-parallel-start-rank 1 \
|
||||
--data-parallel-size-local 1 \
|
||||
--data-parallel-address 172.19.123.51 \
|
||||
--data-parallel-rpc-port 13356 \
|
||||
--tensor-parallel-size 8 \
|
||||
--no-enable-prefix-caching \
|
||||
--seed 1024 \
|
||||
--served-model-name deepseek \
|
||||
--max-model-len 6144 \
|
||||
--max-num-batched-tokens 6144 \
|
||||
--trust-remote-code \
|
||||
--enforce-eager \
|
||||
--gpu-memory-utilization 0.9 \
|
||||
--kv-transfer-config \
|
||||
'{"kv_connector": "LLMDataDistCMgrConnector",
|
||||
"kv_buffer_device": "npu",
|
||||
"kv_role": "kv_consumer",
|
||||
"kv_parallel_size": 1,
|
||||
"kv_port": "20001",
|
||||
"engine_id": "0",
|
||||
"kv_connector_module_path": "vllm_ascend.distributed.llmdatadist_c_mgr_connector"
|
||||
}' \
|
||||
--additional-config \
|
||||
'{"torchair_graph_config": {"enabled": false, "enable_multistream_shared_expert": false}, "ascend_scheduler_config":{"enabled":false}}'
|
||||
```
|
||||
|
||||
* Run proxy server on the first node
|
||||
```shell
|
||||
cd /vllm-workspace/vllm-ascend/examples/disaggregate_prefill_v1
|
||||
python toy_proxy_server.py --host 172.19.32.175 --port 1025 --prefiller-hosts 172.19.241.49 --prefiller-port 20002 --decoder-hosts 172.19.123.51 --decoder-ports 20002
|
||||
```
|
||||
|
||||
* Verification
|
||||
Check service health using the proxy server endpoint:
|
||||
```shell
|
||||
curl http://localhost:1025/v1/completions \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{
|
||||
"model": "deepseek",
|
||||
"prompt": "Who are you?",
|
||||
"max_tokens": 100,
|
||||
"temperature": 0
|
||||
}'
|
||||
```
|
||||
|
||||
* Performance
|
||||
Test performance with vllm benchmark
|
||||
```shell
|
||||
cd /vllm-workspace/vllm/benchmarks
|
||||
python3 benchmark_serving.py \
|
||||
--backend vllm \
|
||||
--dataset-name random \
|
||||
--random-input-len 4096 \
|
||||
--random-output-len 1536 \
|
||||
--num-prompts 256 \
|
||||
--ignore-eos \
|
||||
--model deepseek \
|
||||
--tokenizer /data01/deepseek_r1_w8a8_zhw \
|
||||
--host localhost \
|
||||
--port 8000 \
|
||||
--endpoint /v1/completions \
|
||||
--max-concurrency 4 \
|
||||
--request-rate 4
|
||||
```
|
||||
120
examples/disaggregated_prefill_v1/gen_ranktable.py
Normal file
120
examples/disaggregated_prefill_v1/gen_ranktable.py
Normal file
@@ -0,0 +1,120 @@
|
||||
import argparse
|
||||
import json
|
||||
import os
|
||||
|
||||
import torch.distributed as dist
|
||||
|
||||
from vllm_ascend.soc_info import NPUSocInfo
|
||||
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Arguments of rank table generator", )
|
||||
parser.add_argument("--local-host", type=str, required=True, help="local ip")
|
||||
parser.add_argument("--prefill-device-cnt",
|
||||
type=int,
|
||||
required=True,
|
||||
help="number of prefill devices")
|
||||
parser.add_argument("--decode-device-cnt",
|
||||
type=int,
|
||||
required=True,
|
||||
help="number of decode devices")
|
||||
args = parser.parse_args()
|
||||
local_host = args.local_host
|
||||
prefill_device_cnt = args.prefill_device_cnt
|
||||
decode_device_cnt = args.decode_device_cnt
|
||||
|
||||
print("enter py")
|
||||
|
||||
hccn_tool_path = os.environ.get("HCCN_TOOL_PATH",
|
||||
"/usr/local/Ascend/driver/tools/hccn_tool")
|
||||
master_addr = os.environ.get("MASTER_ADDR")
|
||||
master_port = os.environ.get("MASTER_PORT")
|
||||
rank = os.environ.get("RANK")
|
||||
local_rank = os.environ.get("LOCAL_RANK")
|
||||
# This variable is set by torchrun,
|
||||
# and is different from WORLD_SIZE in gen_rank_table.sh.
|
||||
world_size = os.environ.get("WORLD_SIZE")
|
||||
soc_info = NPUSocInfo()
|
||||
|
||||
|
||||
def get_cmd_stdout(cmd):
|
||||
import subprocess
|
||||
return subprocess.run(cmd, capture_output=True,
|
||||
shell=True).stdout.decode("utf-8").strip()
|
||||
|
||||
|
||||
print(f"local_host: {local_host}")
|
||||
print("gen ranktable.json")
|
||||
|
||||
num_cards = get_cmd_stdout("npu-smi info -l | grep \"Total Count\"").split(
|
||||
":")[1].strip()
|
||||
num_cards = int(num_cards)
|
||||
chips_per_card = get_cmd_stdout("npu-smi info -l | grep \"Chip Count\"").split(
|
||||
"\n")[0].split(":")[1].strip()
|
||||
chips_per_card = int(chips_per_card)
|
||||
|
||||
# generate local device list for local rank 0, and gather it to all ranks
|
||||
local_device_list: list[dict[str, str]] = list()
|
||||
if local_rank == "0":
|
||||
super_pod_id = "0"
|
||||
for card_id in range(num_cards):
|
||||
for chip_id in range(chips_per_card):
|
||||
device_id = card_id * chips_per_card + chip_id
|
||||
if soc_info.is_a3:
|
||||
device_ip = get_cmd_stdout(
|
||||
f"{hccn_tool_path} -i {device_id} -vnic -g | grep ipaddr"
|
||||
).split(":")[1].strip()
|
||||
super_device_id = get_cmd_stdout(
|
||||
f"npu-smi info -t spod-info -i {card_id} -c {chip_id} | grep SDID"
|
||||
).split(":")[1].strip()
|
||||
super_pod_id = get_cmd_stdout(
|
||||
f"npu-smi info -t spod-info -i {card_id} -c {chip_id} | grep \"Super Pod ID\""
|
||||
).split(":")[1].strip()
|
||||
else:
|
||||
device_ip = get_cmd_stdout(
|
||||
f"{hccn_tool_path} -i {device_id} -ip -g | grep ipaddr"
|
||||
).split(":")[1].strip()
|
||||
|
||||
device_info = {
|
||||
"server_id": local_host,
|
||||
"device_id": str(device_id),
|
||||
"device_ip": str(device_ip),
|
||||
}
|
||||
if soc_info.is_a3:
|
||||
device_info.update({
|
||||
"super_pod_id": str(super_pod_id),
|
||||
"super_device_id": str(super_device_id)
|
||||
})
|
||||
local_device_list.append(device_info)
|
||||
|
||||
dist.init_process_group(backend=dist.Backend.GLOO)
|
||||
global_device_list = [None] * dist.get_world_size()
|
||||
dist.all_gather_object(global_device_list, local_device_list)
|
||||
global_device_list = [
|
||||
device_info for device_list in global_device_list
|
||||
for device_info in device_list # type: ignore[attr-defined]
|
||||
]
|
||||
cnt = 1
|
||||
for device_info in global_device_list: # type: ignore[assignment]
|
||||
device_info["cluster_id"] = str(cnt)
|
||||
cnt += 1
|
||||
assert (prefill_device_cnt + decode_device_cnt) <= len(global_device_list), \
|
||||
"prefill_device_cnt + decode_device_cnt must be less than or equal to number of all devices in cluster"
|
||||
ranktable = {
|
||||
"version":
|
||||
"1.2",
|
||||
"server_count":
|
||||
str(world_size),
|
||||
"prefill_device_list":
|
||||
global_device_list[:prefill_device_cnt],
|
||||
"decode_device_list":
|
||||
global_device_list[prefill_device_cnt:prefill_device_cnt +
|
||||
decode_device_cnt],
|
||||
"status":
|
||||
"completed"
|
||||
}
|
||||
|
||||
if local_rank == '0':
|
||||
with open("ranktable.json", "w") as f:
|
||||
json.dump(ranktable, f, indent=4)
|
||||
|
||||
print("gen ranktable.json done")
|
||||
79
examples/disaggregated_prefill_v1/gen_ranktable.sh
Normal file
79
examples/disaggregated_prefill_v1/gen_ranktable.sh
Normal file
@@ -0,0 +1,79 @@
|
||||
#!/bin/bash
|
||||
|
||||
source /usr/local/Ascend/ascend-toolkit/set_env.sh
|
||||
export LD_LIBRARY_PATH=/usr/local/Ascend/ascend-toolkit/latest/opp/vendors/customize/op_api/lib/:${LD_LIBRARY_PATH}
|
||||
|
||||
NPUS_PER_NODE=8
|
||||
while [[ $# -gt 0 ]]; do
|
||||
case "$1" in
|
||||
--ips)
|
||||
shift
|
||||
while [[ $# -gt 0 && ! "$1" == --* ]]; do
|
||||
IPs+=("$1")
|
||||
shift
|
||||
done
|
||||
;;
|
||||
--npus-per-node)
|
||||
shift
|
||||
NPUS_PER_NODE="$1"
|
||||
shift
|
||||
;;
|
||||
--network-card-name)
|
||||
shift
|
||||
NETWORK_CARD_NAME="$1"
|
||||
shift
|
||||
;;
|
||||
--prefill-device-cnt)
|
||||
shift
|
||||
PREFILL_DEVICE_CNT="$1"
|
||||
shift
|
||||
;;
|
||||
--decode-device-cnt)
|
||||
shift
|
||||
DECODE_DEVICE_CNT="$1"
|
||||
shift
|
||||
;;
|
||||
esac
|
||||
done
|
||||
LOCAL_HOSTS=($(hostname -I))
|
||||
LOCAL_HOST="127.0.0.1"
|
||||
MASTER_ADDR=${IPs[0]}
|
||||
MASTER_PORT=6657
|
||||
NNODES=${#IPs[@]}
|
||||
NODE_RANK="8"
|
||||
for i in "${!IPs[@]}"; do
|
||||
ip="${IPs[$i]}"
|
||||
for local_host in "${LOCAL_HOSTS[@]}"; do
|
||||
if [[ "$local_host" == "$ip" ]]; then
|
||||
LOCAL_HOST=$local_host
|
||||
NODE_RANK=$i
|
||||
break 2
|
||||
fi
|
||||
done
|
||||
done
|
||||
|
||||
if [[ $NODE_RANK == "" ]];then
|
||||
echo "[Error] para \"NODE_RANK\" must be defined"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
WORLD_SIZE=$(($NPUS_PER_NODE * $NNODES))
|
||||
RANKSTART=`expr $NPUS_PER_NODE \* $NODE_RANK`
|
||||
|
||||
echo "========>param:"
|
||||
echo "LOCAL_HOST": $LOCAL_HOST
|
||||
echo "WORLD_SIZE: " $WORLD_SIZE
|
||||
echo "RANKSTART": $RANKSTART
|
||||
echo "NNODES": $NNODES
|
||||
echo "NODE_RANK": $NODE_RANK
|
||||
echo "==============="
|
||||
|
||||
if [[ -n "${GEN_RANKTABLE}" || ! -e ${PWD}/ranktable.json ]]; then
|
||||
GLOO_SOCKET_IFNAME=$NETWORK_CARD_NAME torchrun \
|
||||
--nproc_per_node 1 \
|
||||
--nnodes ${NNODES} \
|
||||
--node_rank ${NODE_RANK} \
|
||||
--master_addr ${MASTER_ADDR} \
|
||||
--master_port ${MASTER_PORT} \
|
||||
gen_ranktable.py --local-host $LOCAL_HOST --prefill-device-cnt $PREFILL_DEVICE_CNT --decode-device-cnt $DECODE_DEVICE_CNT
|
||||
fi
|
||||
32
examples/disaggregated_prefill_v1/run_server.sh
Normal file
32
examples/disaggregated_prefill_v1/run_server.sh
Normal file
@@ -0,0 +1,32 @@
|
||||
export HCCL_IF_IP=141.61.39.117
|
||||
export GLOO_SOCKET_IFNAME="enp48s3u1u1"
|
||||
export TP_SOCKET_IFNAME="enp48s3u1u1"
|
||||
export HCCL_SOCKET_IFNAME="enp48s3u1u1"
|
||||
export DISAGGREGATED_PREFILL_RANK_TABLE_PATH=path-to-rank-table
|
||||
|
||||
export OMP_PROC_BIND=false
|
||||
export OMP_NUM_THREADS=100
|
||||
|
||||
export VLLM_USE_V1=1
|
||||
|
||||
vllm serve model_path \
|
||||
--host 0.0.0.0 \
|
||||
--port 20002 \
|
||||
--tensor-parallel-size 1\
|
||||
--seed 1024 \
|
||||
--served-model-name dsv3 \
|
||||
--max-model-len 2000 \
|
||||
---max-num-batched-tokens 2000 \
|
||||
--trust-remote-code \
|
||||
--gpu-memory-utilization 0.9 \
|
||||
--kv-transfer-config \
|
||||
'{"kv_connector": "LLMDataDistCMgrConnector",
|
||||
"kv_buffer_device": "npu",
|
||||
"kv_role": "kv_consumer",
|
||||
"kv_parallel_size": 1,
|
||||
"kv_port": "20001",
|
||||
"engine_id": 0,
|
||||
"kv_connector_module_path": "vllm_ascend.distributed.llmdatadist_connector_v1_a3"
|
||||
}' \
|
||||
--additional-config \
|
||||
'{"enable_graph_mode": "True"}'\
|
||||
275
examples/disaggregated_prefill_v1/toy_proxy_server.py
Normal file
275
examples/disaggregated_prefill_v1/toy_proxy_server.py
Normal file
@@ -0,0 +1,275 @@
|
||||
# Adapted from https://github.com/vllm-project/vllm/tests/v1/kv_connector/nixl_integration/toy_proxy_server.py
|
||||
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
import argparse
|
||||
import itertools
|
||||
import os
|
||||
import uuid
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
import httpx
|
||||
from fastapi import FastAPI, Request
|
||||
from fastapi.responses import StreamingResponse
|
||||
from vllm.logger import init_logger
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
"""
|
||||
Lifespan context manager to handle startup and shutdown events.
|
||||
"""
|
||||
# Startup: Initialize client pools for prefiller and decoder services
|
||||
app.state.prefill_clients = []
|
||||
app.state.decode_clients = []
|
||||
limit = httpx.Limits(max_connections=100000,
|
||||
max_keepalive_connections=100000)
|
||||
|
||||
# Create prefill clients
|
||||
for i, (host, port) in enumerate(global_args.prefiller_instances):
|
||||
prefiller_base_url = f'http://{host}:{port}/v1'
|
||||
app.state.prefill_clients.append({
|
||||
'client':
|
||||
httpx.AsyncClient(timeout=None,
|
||||
base_url=prefiller_base_url,
|
||||
limits=limit),
|
||||
'host':
|
||||
host,
|
||||
'port':
|
||||
port,
|
||||
'id':
|
||||
i
|
||||
})
|
||||
|
||||
# Create decode clients
|
||||
for i, (host, port) in enumerate(global_args.decoder_instances):
|
||||
decoder_base_url = f'http://{host}:{port}/v1'
|
||||
app.state.decode_clients.append({
|
||||
'client':
|
||||
httpx.AsyncClient(timeout=None,
|
||||
base_url=decoder_base_url,
|
||||
limits=limit),
|
||||
'host':
|
||||
host,
|
||||
'port':
|
||||
port,
|
||||
'id':
|
||||
i
|
||||
})
|
||||
|
||||
# Initialize round-robin iterators
|
||||
app.state.prefill_iterator = itertools.cycle(
|
||||
range(len(app.state.prefill_clients)))
|
||||
app.state.decode_iterator = itertools.cycle(
|
||||
range(len(app.state.decode_clients)))
|
||||
|
||||
print(f"Initialized {len(app.state.prefill_clients)} prefill clients "
|
||||
f"and {len(app.state.decode_clients)} decode clients.")
|
||||
|
||||
yield
|
||||
|
||||
# Shutdown: Close all clients
|
||||
for client_info in app.state.prefill_clients:
|
||||
await client_info['client'].aclose()
|
||||
|
||||
for client_info in app.state.decode_clients:
|
||||
await client_info['client'].aclose()
|
||||
|
||||
|
||||
# Update FastAPI app initialization to use lifespan
|
||||
app = FastAPI(lifespan=lifespan)
|
||||
|
||||
|
||||
def parse_args():
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
parser.add_argument("--port", type=int, default=8000)
|
||||
parser.add_argument("--host", type=str, default="localhost")
|
||||
|
||||
# For prefiller instances
|
||||
parser.add_argument("--prefiller-hosts",
|
||||
"--prefiller-host",
|
||||
type=str,
|
||||
nargs="+",
|
||||
default=["localhost"])
|
||||
parser.add_argument("--prefiller-ports",
|
||||
"--prefiller-port",
|
||||
type=int,
|
||||
nargs="+",
|
||||
default=[8100])
|
||||
|
||||
# For decoder instances
|
||||
parser.add_argument("--decoder-hosts",
|
||||
"--decoder-host",
|
||||
type=str,
|
||||
nargs="+",
|
||||
default=["localhost"])
|
||||
parser.add_argument("--decoder-ports",
|
||||
"--decoder-port",
|
||||
type=int,
|
||||
nargs="+",
|
||||
default=[8200])
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
# Validate and pair hosts with ports
|
||||
if len(args.prefiller_hosts) != len(args.prefiller_ports):
|
||||
raise ValueError(
|
||||
"Number of prefiller hosts must match number of prefiller ports")
|
||||
|
||||
if len(args.decoder_hosts) != len(args.decoder_ports):
|
||||
raise ValueError(
|
||||
"Number of decoder hosts must match number of decoder ports")
|
||||
|
||||
# Create tuples of (host, port) for each service type
|
||||
args.prefiller_instances = list(
|
||||
zip(args.prefiller_hosts, args.prefiller_ports))
|
||||
args.decoder_instances = list(zip(args.decoder_hosts, args.decoder_ports))
|
||||
|
||||
return args
|
||||
|
||||
|
||||
def get_next_client(app, service_type: str):
|
||||
"""
|
||||
Get the next client in round-robin fashion.
|
||||
|
||||
Args:
|
||||
app: The FastAPI app instance
|
||||
service_type: Either 'prefill' or 'decode'
|
||||
|
||||
Returns:
|
||||
The next client to use
|
||||
"""
|
||||
if service_type == 'prefill':
|
||||
client_idx = next(app.state.prefill_iterator)
|
||||
return app.state.prefill_clients[client_idx]
|
||||
elif service_type == 'decode':
|
||||
client_idx = next(app.state.decode_iterator)
|
||||
return app.state.decode_clients[client_idx]
|
||||
else:
|
||||
raise ValueError(f"Unknown service type: {service_type}")
|
||||
|
||||
|
||||
async def send_request_to_service(client_info: dict, endpoint: str,
|
||||
req_data: dict, request_id: str):
|
||||
"""
|
||||
Send a request to a service using a client from the pool.
|
||||
"""
|
||||
req_data = req_data.copy()
|
||||
req_data['kv_transfer_params'] = {
|
||||
"do_remote_decode": True,
|
||||
"do_remote_prefill": False,
|
||||
"remote_engine_id": None,
|
||||
"remote_block_ids": None,
|
||||
"remote_host": None,
|
||||
"remote_port": None
|
||||
}
|
||||
req_data["stream"] = False
|
||||
req_data["max_tokens"] = 1
|
||||
if "stream_options" in req_data:
|
||||
del req_data["stream_options"]
|
||||
headers = {
|
||||
"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}",
|
||||
"X-Request-Id": request_id
|
||||
}
|
||||
|
||||
response = await client_info['client'].post(endpoint,
|
||||
json=req_data,
|
||||
headers=headers)
|
||||
response.raise_for_status()
|
||||
|
||||
return response
|
||||
|
||||
|
||||
async def stream_service_response(client_info: dict, endpoint: str,
|
||||
req_data: dict, request_id: str):
|
||||
"""
|
||||
Asynchronously stream response from a service using a client from the pool.
|
||||
"""
|
||||
headers = {
|
||||
"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}",
|
||||
"X-Request-Id": request_id
|
||||
}
|
||||
|
||||
async with client_info['client'].stream("POST",
|
||||
endpoint,
|
||||
json=req_data,
|
||||
headers=headers) as response:
|
||||
response.raise_for_status()
|
||||
async for chunk in response.aiter_bytes():
|
||||
yield chunk
|
||||
|
||||
|
||||
async def _handle_completions(api: str, request: Request):
|
||||
try:
|
||||
req_data = await request.json()
|
||||
request_id = str(uuid.uuid4())
|
||||
|
||||
# Get the next prefill client in round-robin fashion
|
||||
prefill_client_info = get_next_client(request.app, 'prefill')
|
||||
|
||||
# Send request to prefill service
|
||||
response = await send_request_to_service(prefill_client_info, api,
|
||||
req_data, request_id)
|
||||
|
||||
# Extract the needed fields
|
||||
response_json = response.json()
|
||||
kv_transfer_params = response_json.get('kv_transfer_params', {})
|
||||
if kv_transfer_params:
|
||||
req_data["kv_transfer_params"] = kv_transfer_params
|
||||
|
||||
# Get the next decode client in round-robin fashion
|
||||
decode_client_info = get_next_client(request.app, 'decode')
|
||||
|
||||
logger.debug("Using %s %s", prefill_client_info, decode_client_info)
|
||||
|
||||
# Stream response from decode service
|
||||
async def generate_stream():
|
||||
async for chunk in stream_service_response(decode_client_info,
|
||||
api,
|
||||
req_data,
|
||||
request_id=request_id):
|
||||
yield chunk
|
||||
|
||||
return StreamingResponse(generate_stream(),
|
||||
media_type="application/json")
|
||||
|
||||
except Exception as e:
|
||||
import sys
|
||||
import traceback
|
||||
exc_info = sys.exc_info()
|
||||
print("Error occurred in disagg prefill proxy server"
|
||||
f" - {api} endpoint")
|
||||
print(e)
|
||||
print("".join(traceback.format_exception(*exc_info)))
|
||||
raise
|
||||
|
||||
|
||||
@app.post("/v1/completions")
|
||||
async def handle_completions(request: Request):
|
||||
return await _handle_completions("/completions", request)
|
||||
|
||||
|
||||
@app.post("/v1/chat/completions")
|
||||
async def handle_chat_completions(request: Request):
|
||||
return await _handle_completions("/chat/completions", request)
|
||||
|
||||
|
||||
@app.get("/healthcheck")
|
||||
async def healthcheck():
|
||||
"""Simple endpoint to check if the server is running."""
|
||||
return {
|
||||
"status": "ok",
|
||||
"prefill_instances": len(app.state.prefill_clients),
|
||||
"decode_instances": len(app.state.decode_clients)
|
||||
}
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
global global_args
|
||||
global_args = parse_args()
|
||||
|
||||
import uvicorn
|
||||
uvicorn.run(app, host=global_args.host, port=global_args.port)
|
||||
Reference in New Issue
Block a user