v0.10.1rc1

examples/disaggregated_prefill_v1/README.md (new file, 246 lines)
# Disaggregated Prefill-Decode Deployment Guide

## Overview

This document provides instructions for running a disaggregated vLLM-ascend service with separate prefill and decode stages across 4 nodes, using 16 Ascend NPUs for two prefill nodes (P1/P2) and 16 Ascend NPUs for two decode nodes (D1/D2).

## Prerequisites

- Ascend NPU environment with vLLM 0.9.1 installed
- Network interfaces configured for distributed communication (e.g., eth0)
- Model weights located at `/models/deepseek_r1_w8a8`
## Rank table generation

The rank table is a JSON file that maps Ascend NPU ranks to nodes. Run the following command on every node to generate a rank table covering all nodes, with 16 cards for prefill and 16 cards for decode:

```shell
cd /vllm-workspace/vllm-ascend/examples/disaggregated_prefill_v1/
bash gen_ranktable.sh --ips 172.19.32.175 172.19.241.49 172.19.123.51 172.19.190.36 \
  --npus-per-node 8 --network-card-name eth0 --prefill-device-cnt 16 --decode-device-cnt 16
```

The rank table will be generated at `/vllm-workspace/vllm-ascend/examples/disaggregated_prefill_v1/ranktable.json`.
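Once the file exists on each node, a quick sanity check can confirm it matches the requested split. A minimal sketch, where `check_ranktable` is a hypothetical helper (not part of this repo); the field names follow `gen_ranktable.py` in this commit, and the sample table is illustrative (a real one lists 16 devices per role):

```python
import json
import os
import tempfile


def check_ranktable(path, prefill_cnt, decode_cnt):
    """Verify the generated rank table has the expected device counts."""
    with open(path) as f:
        table = json.load(f)
    assert table["status"] == "completed"
    assert len(table["prefill_device_list"]) == prefill_cnt, "prefill count mismatch"
    assert len(table["decode_device_list"]) == decode_cnt, "decode count mismatch"
    return table["server_count"]


# Illustrative table with 2 prefill and 2 decode devices.
sample = {
    "version": "1.2",
    "server_count": "4",
    "prefill_device_list": [{"device_id": "0"}, {"device_id": "1"}],
    "decode_device_list": [{"device_id": "2"}, {"device_id": "3"}],
    "status": "completed",
}
with tempfile.NamedTemporaryFile("w", suffix=".json", delete=False) as f:
    json.dump(sample, f)
print(check_ranktable(f.name, 2, 2))  # prints: 4
os.remove(f.name)
```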

## Start the disaggregated vLLM-ascend service

For demonstration purposes, we will use the quantized version of DeepSeek-R1. Recommended parallelization strategies:

- P-node: DP2-TP8-EP16 (Data Parallelism 2, Tensor Parallelism 8, Expert Parallelism 16)
- D-node: DP4-TP4-EP16 (Data Parallelism 4, Tensor Parallelism 4, Expert Parallelism 16)
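As a quick consistency check on these strategies: DP × TP must equal the number of NPUs assigned to each role (16 here, i.e. 2 nodes × 8 NPUs), and expert parallelism spans all of them. A small sketch of that arithmetic:

```python
# Each role gets 2 nodes x 8 NPUs = 16 devices.
configs = {
    "P-node": {"dp": 2, "tp": 8, "ep": 16},
    "D-node": {"dp": 4, "tp": 4, "ep": 16},
}
for role, c in configs.items():
    devices = c["dp"] * c["tp"]
    assert devices == 16, f"{role}: DP*TP must cover all 16 NPUs"
    assert c["ep"] == devices  # expert parallelism spans every device
    print(role, "->", devices, "NPUs")
```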

### Execution sequence

- The 4 configured node IPs are: 172.19.32.175, 172.19.241.49, 172.19.123.51, 172.19.190.36
- Start prefill on node 1 (P1)
- Start prefill on node 2 (P2)
- Start decode on node 3 (D1)
- Start decode on node 4 (D2)
- Start the proxy server on node 1

Run prefill server P1 on the first node:

```shell
export HCCL_IF_IP=172.19.32.175  # node IP
export GLOO_SOCKET_IFNAME="eth0"  # network card name
export TP_SOCKET_IFNAME="eth0"
export HCCL_SOCKET_IFNAME="eth0"
export DISAGGREGATED_PREFILL_RANK_TABLE_PATH=/vllm-workspace/vllm-ascend/examples/disaggregated_prefill_v1/ranktable.json
export OMP_PROC_BIND=false
export OMP_NUM_THREADS=100
export VLLM_USE_V1=1
export VLLM_LLMDD_RPC_PORT=5559

vllm serve /models/deepseek_r1_w8a8 \
  --host 0.0.0.0 \
  --port 20002 \
  --data-parallel-size 2 \
  --data-parallel-size-local 1 \
  --api-server-count 2 \
  --data-parallel-address 172.19.32.175 \
  --data-parallel-rpc-port 13356 \
  --tensor-parallel-size 8 \
  --enable-expert-parallel \
  --seed 1024 \
  --served-model-name deepseek \
  --max-model-len 32768 \
  --max-num-batched-tokens 32768 \
  --max-num-seqs 256 \
  --trust-remote-code \
  --enforce-eager \
  --gpu-memory-utilization 0.9 \
  --kv-transfer-config \
  '{"kv_connector": "LLMDataDistCMgrConnector",
    "kv_buffer_device": "npu",
    "kv_role": "kv_producer",
    "kv_parallel_size": 1,
    "kv_port": "20001",
    "engine_id": "0",
    "kv_connector_module_path": "vllm_ascend.distributed.llmdatadist_c_mgr_connector"
  }' \
  --additional-config \
  '{"chunked_prefill_for_mla":true}'
```

Run prefill server P2 on the second node:

```shell
export HCCL_IF_IP=172.19.241.49
export GLOO_SOCKET_IFNAME="eth0"
export TP_SOCKET_IFNAME="eth0"
export HCCL_SOCKET_IFNAME="eth0"
export DISAGGREGATED_PREFILL_RANK_TABLE_PATH=/vllm-workspace/vllm-ascend/examples/disaggregated_prefill_v1/ranktable.json
export OMP_PROC_BIND=false
export OMP_NUM_THREADS=100
export VLLM_USE_V1=1
export VLLM_LLMDD_RPC_PORT=5659

vllm serve /models/deepseek_r1_w8a8 \
  --host 0.0.0.0 \
  --port 20002 \
  --headless \
  --data-parallel-size 2 \
  --data-parallel-start-rank 1 \
  --data-parallel-size-local 1 \
  --data-parallel-address 172.19.32.175 \
  --data-parallel-rpc-port 13356 \
  --tensor-parallel-size 8 \
  --enable-expert-parallel \
  --seed 1024 \
  --served-model-name deepseek \
  --max-model-len 32768 \
  --max-num-batched-tokens 32768 \
  --max-num-seqs 256 \
  --trust-remote-code \
  --enforce-eager \
  --gpu-memory-utilization 0.9 \
  --kv-transfer-config \
  '{"kv_connector": "LLMDataDistCMgrConnector",
    "kv_buffer_device": "npu",
    "kv_role": "kv_producer",
    "kv_parallel_size": 1,
    "kv_port": "20001",
    "engine_id": "0",
    "kv_connector_module_path": "vllm_ascend.distributed.llmdatadist_c_mgr_connector"
  }' \
  --additional-config \
  '{"chunked_prefill_for_mla":true}'
```

Run decode server D1 on the third node:

* On the D node, the `max-num-batched-tokens` parameter can be set to a smaller value, since the D node processes at most `max-num-seqs` batches concurrently. Because the `profile_run` only needs to handle `max-num-seqs` sequences at a time, we can safely set `max-num-batched-tokens` equal to `max-num-seqs`. This optimization helps reduce activation memory consumption.
```shell
export HCCL_IF_IP=172.19.123.51
export GLOO_SOCKET_IFNAME="eth0"
export TP_SOCKET_IFNAME="eth0"
export HCCL_SOCKET_IFNAME="eth0"
export DISAGGREGATED_PREFILL_RANK_TABLE_PATH=/vllm-workspace/vllm-ascend/examples/disaggregated_prefill_v1/ranktable.json
export OMP_PROC_BIND=false
export OMP_NUM_THREADS=100
export VLLM_USE_V1=1
export VLLM_LLMDD_RPC_PORT=5759

vllm serve /models/deepseek_r1_w8a8 \
  --host 0.0.0.0 \
  --port 20002 \
  --data-parallel-size 4 \
  --data-parallel-size-local 2 \
  --api-server-count 2 \
  --data-parallel-address 172.19.123.51 \
  --data-parallel-rpc-port 13356 \
  --tensor-parallel-size 4 \
  --enable-expert-parallel \
  --seed 1024 \
  --served-model-name deepseek \
  --max-model-len 32768 \
  --max-num-batched-tokens 256 \
  --max-num-seqs 256 \
  --trust-remote-code \
  --gpu-memory-utilization 0.9 \
  --kv-transfer-config \
  '{"kv_connector": "LLMDataDistCMgrConnector",
    "kv_buffer_device": "npu",
    "kv_role": "kv_consumer",
    "kv_parallel_size": 1,
    "kv_port": "20001",
    "engine_id": "0",
    "kv_connector_module_path": "vllm_ascend.distributed.llmdatadist_c_mgr_connector"
  }' \
  --additional-config \
  '{"torchair_graph_config": {"enabled":true}}'
```

Run decode server D2 on the fourth node:

```shell
export HCCL_IF_IP=172.19.190.36
export GLOO_SOCKET_IFNAME="eth0"
export TP_SOCKET_IFNAME="eth0"
export HCCL_SOCKET_IFNAME="eth0"
export DISAGGREGATED_PREFILL_RANK_TABLE_PATH=/vllm-workspace/vllm-ascend/examples/disaggregated_prefill_v1/ranktable.json
export OMP_PROC_BIND=false
export OMP_NUM_THREADS=100
export VLLM_USE_V1=1
export VLLM_LLMDD_RPC_PORT=5859

vllm serve /models/deepseek_r1_w8a8 \
  --host 0.0.0.0 \
  --port 20002 \
  --headless \
  --data-parallel-size 4 \
  --data-parallel-start-rank 2 \
  --data-parallel-size-local 2 \
  --data-parallel-address 172.19.123.51 \
  --data-parallel-rpc-port 13356 \
  --tensor-parallel-size 4 \
  --enable-expert-parallel \
  --seed 1024 \
  --served-model-name deepseek \
  --max-model-len 32768 \
  --max-num-batched-tokens 256 \
  --max-num-seqs 256 \
  --trust-remote-code \
  --gpu-memory-utilization 0.9 \
  --kv-transfer-config \
  '{"kv_connector": "LLMDataDistCMgrConnector",
    "kv_buffer_device": "npu",
    "kv_role": "kv_consumer",
    "kv_parallel_size": 1,
    "kv_port": "20001",
    "engine_id": "0",
    "kv_connector_module_path": "vllm_ascend.distributed.llmdatadist_c_mgr_connector"
  }' \
  --additional-config \
  '{"torchair_graph_config": {"enabled":true}}'
```

Run the proxy server on the first node:

```shell
cd /vllm-workspace/vllm-ascend/examples/disaggregated_prefill_v1
python toy_proxy_server.py --host 172.19.32.175 --port 1025 --prefiller-hosts 172.19.241.49 --prefiller-ports 20002 --decoder-hosts 172.19.123.51 --decoder-ports 20002
```
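Before sending traffic, you can poll the proxy's `/healthcheck` endpoint (the proxy server's tutorial comments describe this endpoint). A minimal stdlib-only sketch; `proxy_healthy` is a hypothetical helper, and the example URL matches the command above:

```python
import json
import urllib.request


def proxy_healthy(base_url: str) -> bool:
    """Return True if the proxy's /healthcheck endpoint responds with JSON."""
    try:
        with urllib.request.urlopen(f"{base_url}/healthcheck", timeout=5) as r:
            return r.status == 200 and isinstance(json.load(r), dict)
    except OSError:
        return False


# Example: proxy_healthy("http://172.19.32.175:1025")
```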

## Verification

Check service health using the proxy server endpoint:

```shell
curl http://localhost:1025/v1/completions \
  -H "Content-Type: application/json" \
  -d '{
    "model": "deepseek",
    "prompt": "Who are you?",
    "max_tokens": 100,
    "temperature": 0
  }'
```
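The same check can be scripted. A stdlib-only sketch equivalent to the curl call above; `build_payload` and `complete` are hypothetical helpers, and the model name and endpoint match this deployment:

```python
import json
import urllib.request


def build_payload(prompt: str, max_tokens: int = 100) -> bytes:
    """Encode an OpenAI-style completion request body."""
    return json.dumps({
        "model": "deepseek",
        "prompt": prompt,
        "max_tokens": max_tokens,
        "temperature": 0,
    }).encode()


def complete(base_url: str, prompt: str) -> str:
    """POST the completion request to the proxy and return the generated text."""
    req = urllib.request.Request(f"{base_url}/v1/completions",
                                 data=build_payload(prompt),
                                 headers={"Content-Type": "application/json"})
    with urllib.request.urlopen(req, timeout=300) as r:
        return json.load(r)["choices"][0]["text"]


# Example: complete("http://localhost:1025", "Who are you?")
```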

## Performance

Test performance with the vLLM benchmark script:

```shell
cd /vllm-workspace/vllm/benchmarks
python3 benchmark_serving.py \
  --backend vllm \
  --dataset-name random \
  --random-input-len 4096 \
  --random-output-len 1536 \
  --num-prompts 256 \
  --ignore-eos \
  --model deepseek \
  --tokenizer /models/deepseek_r1_w8a8 \
  --host localhost \
  --port 1025 \
  --endpoint /v1/completions \
  --max-concurrency 4 \
  --request-rate 4
```
examples/disaggregated_prefill_v1/gen_ranktable.py (new file, 122 lines)

import argparse
import json
import os

import torch.distributed as dist

from vllm_ascend.utils import AscendSocVersion, init_ascend_soc_version, get_ascend_soc_version

parser = argparse.ArgumentParser(
    description="Arguments of rank table generator", )
parser.add_argument("--local-host", type=str, required=True, help="local ip")
parser.add_argument("--prefill-device-cnt",
                    type=int,
                    required=True,
                    help="number of prefill devices")
parser.add_argument("--decode-device-cnt",
                    type=int,
                    required=True,
                    help="number of decode devices")
args = parser.parse_args()
local_host = args.local_host
prefill_device_cnt = args.prefill_device_cnt
decode_device_cnt = args.decode_device_cnt

print("enter py")

hccn_tool_path = os.environ.get("HCCN_TOOL_PATH",
                                "/usr/local/Ascend/driver/tools/hccn_tool")
master_addr = os.environ.get("MASTER_ADDR")
master_port = os.environ.get("MASTER_PORT")
rank = os.environ.get("RANK")
local_rank = os.environ.get("LOCAL_RANK")
# This variable is set by torchrun,
# and is different from WORLD_SIZE in gen_rank_table.sh.
world_size = os.environ.get("WORLD_SIZE")

init_ascend_soc_version()
soc_info = get_ascend_soc_version()


def get_cmd_stdout(cmd):
    import subprocess
    return subprocess.run(cmd, capture_output=True,
                          shell=True).stdout.decode("utf-8").strip()


print(f"local_host: {local_host}")
print("gen ranktable.json")

num_cards = get_cmd_stdout("npu-smi info -l | grep \"Total Count\"").split(
    ":")[1].strip()
num_cards = int(num_cards)
chips_per_card = get_cmd_stdout("npu-smi info -l | grep \"Chip Count\"").split(
    "\n")[0].split(":")[1].strip()
chips_per_card = int(chips_per_card)

# generate local device list for local rank 0, and gather it to all ranks
local_device_list: list[dict[str, str]] = list()
if local_rank == "0":
    super_pod_id = "0"
    for card_id in range(num_cards):
        for chip_id in range(chips_per_card):
            device_id = card_id * chips_per_card + chip_id
            if soc_info == AscendSocVersion.A3:
                device_ip = get_cmd_stdout(
                    f"{hccn_tool_path} -i {device_id} -vnic -g | grep ipaddr"
                ).split(":")[1].strip()
                super_device_id = get_cmd_stdout(
                    f"npu-smi info -t spod-info -i {card_id} -c {chip_id} | grep SDID"
                ).split(":")[1].strip()
                super_pod_id = get_cmd_stdout(
                    f"npu-smi info -t spod-info -i {card_id} -c {chip_id} | grep \"Super Pod ID\""
                ).split(":")[1].strip()
            else:
                device_ip = get_cmd_stdout(
                    f"{hccn_tool_path} -i {device_id} -ip -g | grep ipaddr"
                ).split(":")[1].strip()

            device_info = {
                "server_id": local_host,
                "device_id": str(device_id),
                "device_ip": str(device_ip),
            }
            if soc_info == AscendSocVersion.A3:
                device_info.update({
                    "super_pod_id": str(super_pod_id),
                    "super_device_id": str(super_device_id)
                })
            local_device_list.append(device_info)

dist.init_process_group(backend=dist.Backend.GLOO)
global_device_list = [None] * dist.get_world_size()
dist.all_gather_object(global_device_list, local_device_list)
global_device_list = [
    device_info for device_list in global_device_list
    for device_info in device_list  # type: ignore[attr-defined]
]
cnt = 1
for device_info in global_device_list:  # type: ignore[assignment]
    device_info["cluster_id"] = str(cnt)
    cnt += 1
assert (prefill_device_cnt + decode_device_cnt) <= len(global_device_list), \
    "prefill_device_cnt + decode_device_cnt must be less than or equal to number of all devices in cluster"
ranktable = {
    "version": "1.2",
    "server_count": str(world_size),
    "prefill_device_list": global_device_list[:prefill_device_cnt],
    "decode_device_list":
    global_device_list[prefill_device_cnt:prefill_device_cnt +
                       decode_device_cnt],
    "status": "completed"
}

if local_rank == '0':
    with open("ranktable.json", "w") as f:
        json.dump(ranktable, f, indent=4)

print("gen ranktable.json done")
examples/disaggregated_prefill_v1/gen_ranktable.sh (new file, 79 lines)

#!/bin/bash

source /usr/local/Ascend/ascend-toolkit/set_env.sh
export LD_LIBRARY_PATH=/usr/local/Ascend/ascend-toolkit/latest/opp/vendors/customize/op_api/lib/:${LD_LIBRARY_PATH}

NPUS_PER_NODE=8
while [[ $# -gt 0 ]]; do
    case "$1" in
        --ips)
            shift
            while [[ $# -gt 0 && ! "$1" == --* ]]; do
                IPs+=("$1")
                shift
            done
            ;;
        --npus-per-node)
            shift
            NPUS_PER_NODE="$1"
            shift
            ;;
        --network-card-name)
            shift
            NETWORK_CARD_NAME="$1"
            shift
            ;;
        --prefill-device-cnt)
            shift
            PREFILL_DEVICE_CNT="$1"
            shift
            ;;
        --decode-device-cnt)
            shift
            DECODE_DEVICE_CNT="$1"
            shift
            ;;
    esac
done
LOCAL_HOSTS=($(hostname -I))
LOCAL_HOST="127.0.0.1"
MASTER_ADDR=${IPs[0]}
MASTER_PORT=6657
NNODES=${#IPs[@]}
# Left empty so the check below fails when this host is not in --ips.
NODE_RANK=""
for i in "${!IPs[@]}"; do
    ip="${IPs[$i]}"
    for local_host in "${LOCAL_HOSTS[@]}"; do
        if [[ "$local_host" == "$ip" ]]; then
            LOCAL_HOST=$local_host
            NODE_RANK=$i
            break 2
        fi
    done
done

if [[ $NODE_RANK == "" ]]; then
    echo "[Error] para \"NODE_RANK\" must be defined"
    exit 1
fi

WORLD_SIZE=$(($NPUS_PER_NODE * $NNODES))
RANKSTART=$(($NPUS_PER_NODE * $NODE_RANK))

echo "========>param:"
echo "LOCAL_HOST: $LOCAL_HOST"
echo "WORLD_SIZE: $WORLD_SIZE"
echo "RANKSTART: $RANKSTART"
echo "NNODES: $NNODES"
echo "NODE_RANK: $NODE_RANK"
echo "==============="

if [[ -n "${GEN_RANKTABLE}" || ! -e ${PWD}/ranktable.json ]]; then
    GLOO_SOCKET_IFNAME=$NETWORK_CARD_NAME torchrun \
        --nproc_per_node 1 \
        --nnodes ${NNODES} \
        --node_rank ${NODE_RANK} \
        --master_addr ${MASTER_ADDR} \
        --master_port ${MASTER_PORT} \
        gen_ranktable.py --local-host $LOCAL_HOST --prefill-device-cnt $PREFILL_DEVICE_CNT --decode-device-cnt $DECODE_DEVICE_CNT
fi
@@ -0,0 +1,546 @@
# Adapted from https://github.com/vllm-project/vllm/tests/v1/kv_connector/nixl_integration/toy_proxy_server.py

# SPDX-License-Identifier: Apache-2.0
#
# Tutorial: Using the Load Balance Proxy Server Example
#
# This proxy server is designed to distribute requests between multiple
# "prefiller" and "decoder" backend servers for large language model inference.
# It is useful for scaling out inference workloads and balancing load across
# multiple backend instances.
#
# Features:
# - Load balances requests to multiple prefiller and decoder servers.
# - Supports OpenAI-compatible /v1/completions and /v1/chat/completions endpoints.
# - Streams responses from backend servers to clients.
#
# Prerequisites:
# - Python 3.8+
# - Install dependencies:
#     pip install fastapi httpx uvicorn vllm
#
# Step 1: Start Your Backend Servers
# ----------------------------------
# You need to have at least one prefiller and one decoder backend running.
# These can be mock servers or actual vLLM servers.
#
# For testing, you can use the provided mock server:
#
#   vllm serve --host 0.0.0.0 --port 8100 ...  # Prefiller 1
#   vllm serve --host 0.0.0.0 --port 8101 ...  # Prefiller 2
#   vllm serve --host 0.0.0.0 --port 8200 ...  # Decoder 1
#   vllm serve --host 0.0.0.0 --port 8201 ...  # Decoder 2
#
# Step 2: Start the Proxy Server
# ------------------------------
# Run the proxy server, specifying the host/port for each prefiller and decoder:
#
#   python load_balance_proxy_server_example.py \
#       --host 0.0.0.0 --port 9000 \
#       --prefiller-hosts 127.0.0.1 127.0.0.1 \
#       --prefiller-ports 8100 8101 \
#       --decoder-hosts 127.0.0.1 127.0.0.1 \
#       --decoder-ports 8200 8201
#
# This will start the proxy on port 9000, load balancing between two prefiller
# and two decoder servers.
#
# Step 3: Send a Request to the Proxy
# -----------------------------------
# You can now send OpenAI-compatible requests to the proxy. For example:
#
#   curl -X POST http://localhost:9000/v1/completions \
#       -H "Content-Type: application/json" \
#       -d '{
#             "model": "your-model",
#             "prompt": "The quick brown fox jumps over the lazy dog",
#             "max_tokens": 16
#           }'
#
# Or for chat completions:
#
#   curl -X POST http://localhost:9000/v1/chat/completions \
#       -H "Content-Type: application/json" \
#       -d '{
#             "model": "your-model",
#             "messages": [{"role": "user", "content": "Hello!"}],
#             "max_tokens": 16
#           }'
#
# Step 4: Health Check
# --------------------
# To check if the proxy is running and see how many backend instances are
# connected, use:
#
#   curl http://localhost:9000/healthcheck
#
# This will return a JSON object with the status and the number of prefiller
# and decoder instances.
#
# Notes:
# - You can scale the number of prefiller and decoder servers as needed.
# - The proxy routes each request to the least-loaded backend (see the
#   priority queues below).
# - For production, ensure your backend servers are robust and secure.
#
# For more details, see the code and comments in this file.


import argparse
import asyncio
import functools
import heapq
import os
import sys
import uuid
from contextlib import asynccontextmanager
from typing import List

import httpx
from fastapi import FastAPI, Request
from fastapi.responses import StreamingResponse
from vllm.logger import init_logger

logger = init_logger(__name__)

# Add uvloop for faster event loop if available
try:
    import uvloop
    asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
except ImportError:
    pass


class ServerState:

    def __init__(self, host, port):
        self.host = host
        self.port = port
        self.url = f'http://{host}:{port}/v1'
        self.client = httpx.AsyncClient(timeout=None,
                                        base_url=self.url,
                                        limits=httpx.Limits(
                                            max_connections=100000,
                                            max_keepalive_connections=100000))
        self.active_tokens = 0
        self.active_kv_cache = 0  # Only for prefiller
        self.active_requests = 0  # Number of active requests
        self.aborted_requests = set()  # Track aborted requests
        # Removed individual server lock - will use global locks instead


class ProxyState:

    def __init__(self, prefiller_instances, decoder_instances):
        self.prefillers: List[ServerState] = [
            ServerState(h, p) for h, p in prefiller_instances
        ]
        self.decoders: List[ServerState] = [
            ServerState(h, p) for h, p in decoder_instances
        ]
        self.req_to_prefiller = {}
        self.req_id_lock = asyncio.Lock()
        # Removed selection locks - no longer needed for synchronous methods

        # Initialize priority queues for efficient server selection
        # Each entry is (priority_score, server_index, server_reference)
        # Lower priority score = higher priority (less loaded)
        self.prefiller_heap = [(0, i, server)
                               for i, server in enumerate(self.prefillers)]
        self.decoder_heap = [(0, i, server)
                             for i, server in enumerate(self.decoders)]
        heapq.heapify(self.prefiller_heap)
        heapq.heapify(self.decoder_heap)

    def _update_prefiller_priority(self, server_idx: int):
        """Update the priority of a prefiller server in the heap."""
        server = self.prefillers[server_idx]
        # Priority based on active_tokens and active_kv_cache
        priority = server.active_tokens + server.active_kv_cache * 0.3
        # Remove old entry, restore the heap invariant, then add the new one
        self.prefiller_heap = [(p, i, s) for p, i, s in self.prefiller_heap
                               if i != server_idx]
        heapq.heapify(self.prefiller_heap)
        heapq.heappush(self.prefiller_heap,
                       (priority, server_idx, server))  # type: ignore

    def _update_decoder_priority(self, server_idx: int):
        """Update the priority of a decoder server in the heap."""
        server = self.decoders[server_idx]
        priority = server.active_tokens
        # Remove old entry, restore the heap invariant, then add the new one
        self.decoder_heap = [(p, i, s) for p, i, s in self.decoder_heap
                             if i != server_idx]
        heapq.heapify(self.decoder_heap)
        heapq.heappush(self.decoder_heap,
                       (priority, server_idx, server))  # type: ignore

    def abort_prefiller_request(self, server_idx: int,
                                request_id):  # Changed to synchronous
        """
        Mark a request as aborted. This helps release the kv cache on the
        prefiller node.
        """
        # No lock needed - atomic operation
        self.prefillers[server_idx].aborted_requests.add(request_id)

    def aquire_aborted_prefiller_requests(
            self, server_idx: int):  # Changed to synchronous
        """
        Get the set of aborted requests and clear it.
        This is used to release the kv cache on the prefiller node.
        """
        # No lock needed - atomic operation
        aborted_requests = self.prefillers[server_idx].aborted_requests.copy()
        self.prefillers[server_idx].aborted_requests.clear()
        return aborted_requests

    async def next_req_id(self):
        async with self.req_id_lock:
            return str(uuid.uuid4())

    def select_prefiller(self, token_count):  # Changed to synchronous
        # No lock needed - entire function is atomic
        if not self.prefiller_heap:
            raise RuntimeError("No prefiller servers available")

        priority, chosen, server = heapq.heappop(self.prefiller_heap)

        # Update the chosen server atomically
        self.prefillers[chosen].active_tokens += token_count
        self.prefillers[chosen].active_kv_cache += token_count

        # Update priority and re-add to heap
        self._update_prefiller_priority(chosen)

        return chosen

    def release_prefiller(self, idx, token_count):  # Changed to synchronous
        # No lock needed - atomic operation
        self.prefillers[idx].active_tokens -= token_count
        # Update priority queue after releasing
        self._update_prefiller_priority(idx)

    def release_prefiller_kv(self, idx, token_count):  # Changed to synchronous
        # No lock needed - atomic operation
        if self.prefillers[idx].active_kv_cache > 0:
            self.prefillers[idx].active_kv_cache -= token_count
            # Update priority queue after releasing
            self._update_prefiller_priority(idx)

    def select_decoder(self, token_count):  # Changed to synchronous
        # No lock needed - entire function is atomic
        if not self.decoder_heap:
            raise RuntimeError("No decoder servers available")

        priority, chosen, server = heapq.heappop(self.decoder_heap)

        # Update the chosen server atomically
        self.decoders[chosen].active_tokens += token_count

        # Update priority and re-add to heap
        self._update_decoder_priority(chosen)

        return chosen

    def release_decoder(self, idx, token_count):  # Changed to synchronous
        # No lock needed - atomic operation
        self.decoders[idx].active_tokens -= token_count
        # Update priority queue after releasing
        self._update_decoder_priority(idx)

    # Omni_infer's calculate_input_scores function
    def calculate_prefill_scores(self, request_length: int) -> float:
        length_score = request_length / 4.0
        input_score = length_score * 0.0345 + 120.0745
        return input_score

    def calculate_decode_scores(self, request_length: int) -> float:
        return request_length


proxy_state = None


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--port", type=int, default=8000)
    parser.add_argument("--host", type=str, default="localhost")
    parser.add_argument("--prefiller-hosts",
                        type=str,
                        nargs="+",
                        default=["localhost"])
    parser.add_argument("--prefiller-ports",
                        type=int,
                        nargs="+",
                        default=[8001])
    parser.add_argument("--decoder-hosts",
                        type=str,
                        nargs="+",
                        default=["localhost"])
    parser.add_argument("--decoder-ports", type=int, nargs="+", default=[8002])
    parser.add_argument("--max-retries",
                        type=int,
                        default=3,
                        help="Maximum number of retries for HTTP requests")
    parser.add_argument(
        "--retry-delay",
        type=float,
        default=0.001,
        help="Base delay (seconds) for exponential backoff retries")
    args = parser.parse_args()
    if len(args.prefiller_hosts) != len(args.prefiller_ports):
        raise ValueError(
            "Number of prefiller hosts must match number of prefiller ports")
    if len(args.decoder_hosts) != len(args.decoder_ports):
        raise ValueError(
            "Number of decoder hosts must match number of decoder ports")
    args.prefiller_instances = list(
        zip(args.prefiller_hosts, args.prefiller_ports))
    args.decoder_instances = list(zip(args.decoder_hosts, args.decoder_ports))
    return args


@asynccontextmanager
async def lifespan(app: FastAPI):
    global proxy_state
    proxy_state = ProxyState(global_args.prefiller_instances,
                             global_args.decoder_instances)
    print(
        f"Initialized {len(proxy_state.prefillers)} prefill clients and {len(proxy_state.decoders)} decode clients."
    )
    yield
    for p in proxy_state.prefillers:
        await p.client.aclose()
    for d in proxy_state.decoders:
        await d.client.aclose()


async def listen_for_disconnect(request: Request) -> None:
    """Return if a disconnect message is received"""
    while True:
        message = await request.receive()
        if message["type"] == "http.disconnect":
            break


def with_cancellation(handler_func):

    @functools.wraps(handler_func)
    async def wrapper(*args, **kwargs):
        request = kwargs["request"]
        handler_task = asyncio.create_task(handler_func(*args, **kwargs))
        cancellation_task = asyncio.create_task(listen_for_disconnect(request))
        done, pending = await asyncio.wait([handler_task, cancellation_task],
                                           return_when=asyncio.FIRST_COMPLETED)
        for task in pending:
            task.cancel()
        if handler_task in done:
            return handler_task.result()
        return None

    return wrapper


app = FastAPI(lifespan=lifespan)


async def send_request_to_service(client: httpx.AsyncClient,
                                  prefiller_id: int,
                                  endpoint: str,
                                  req_data: dict,
                                  request_id: str,
                                  max_retries: int = 3,
                                  base_delay: float = 0.2):
    aborted_requests = proxy_state.aquire_aborted_prefiller_requests(
        prefiller_id)
    req_data = req_data.copy()
    req_data['kv_transfer_params'] = {
        "do_remote_decode": True,
        "do_remote_prefill": False,
        "remote_engine_id": None,
        "remote_block_ids": None,
        "remote_host": None,
        "remote_port": None,
        "aborted_request": list(aborted_requests),
    }
    req_data["stream"] = False
    req_data["max_tokens"] = 1
    if "stream_options" in req_data:
        del req_data["stream_options"]
    headers = {
        "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}",
        "X-Request-Id": request_id
    }
    last_exc = None
    for attempt in range(1, max_retries + 1):
        try:
            response = await client.post(endpoint,
                                         json=req_data,
                                         headers=headers)
            response.raise_for_status()
            return response
        except (httpx.RequestError, httpx.HTTPStatusError) as e:
            logger.warning(
                f"Attempt {attempt} failed for {endpoint}: {str(e)}")
            last_exc = e
            if attempt < max_retries:
                await asyncio.sleep(base_delay * (2**(attempt - 1)))
            else:
                logger.error(
                    f"All {max_retries} attempts failed for {endpoint}.")
                raise last_exc


async def stream_service_response_with_retry(client: httpx.AsyncClient,
                                             endpoint: str,
                                             req_data: dict,
                                             request_id: str,
                                             max_retries: int = 3,
                                             base_delay: float = 0.2):
    headers = {
        "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}",
        "X-Request-Id": request_id
    }
    for attempt in range(1, max_retries + 1):
        first_chunk_sent = False
        try:
            async with client.stream("POST",
                                     endpoint,
                                     json=req_data,
                                     headers=headers) as response:
                response.raise_for_status()
                async for chunk in response.aiter_bytes():
                    first_chunk_sent = True
                    yield chunk
                return  # Success, exit after streaming
        except (httpx.RequestError, httpx.HTTPStatusError) as e:
            if attempt < max_retries:
                logger.warning(
                    f"Attempt {attempt} failed for streaming {endpoint}: {str(e)}"
                )
                await asyncio.sleep(base_delay * (2**(attempt - 1)))
            else:
                logger.error(
                    f"All {max_retries} attempts failed for streaming {endpoint}."
                )
                raise e
        except Exception as e:
            # If any chunk has been sent, do not retry, just log and drop
            if first_chunk_sent:
                logger.error(
                    f"Streaming to client interrupted after response started: {str(e)}"
                )
                return
            else:
                if attempt < max_retries:
                    logger.warning(
                        f"Attempt {attempt} failed for streaming {endpoint}: {str(e)}"
                    )
                    await asyncio.sleep(base_delay * (2**(attempt - 1)))
                else:
                    logger.error(
                        f"All {max_retries} attempts failed for streaming {endpoint}."
                    )
                    raise e


async def _handle_completions(api: str, request: Request):
    try:
        req_data = await request.json()
        req_body = await request.body()
        request_length = len(req_body)
        prefiller_score = proxy_state.calculate_prefill_scores(request_length)
        logger.debug(
            f"Request length: {request_length}, Prefiller score: {prefiller_score}"
        )
        request_id = await proxy_state.next_req_id()
        # Select prefiller
        prefiller_idx = proxy_state.select_prefiller(prefiller_score)
        prefiller = proxy_state.prefillers[prefiller_idx]
        # Send request to prefiller
        response = await send_request_to_service(
            prefiller.client,
            prefiller_idx,
            api,
            req_data,
            request_id,
            max_retries=global_args.max_retries,
            base_delay=global_args.retry_delay)
        proxy_state.release_prefiller(prefiller_idx, prefiller_score)
        response_json = response.json()
        kv_transfer_params = response_json.get('kv_transfer_params', {})
        if kv_transfer_params:
            req_data["kv_transfer_params"] = kv_transfer_params
        # Select decoder
        decoder_score = proxy_state.calculate_decode_scores(request_length)
        logger.debug("Decoder score: %f", decoder_score)
        # Use the prefiller's kv_transfer_params to select decoder
        decoder_idx = proxy_state.select_decoder(decoder_score)
        decoder = proxy_state.decoders[decoder_idx]
        logger.debug("Using %s %s", prefiller.url, decoder.url)
        # Stream response from decoder
        released_kv = False

        async def generate_stream():
            nonlocal released_kv
            # Only one await per chunk, minimal logic in loop
            try:
                async for chunk in stream_service_response_with_retry(
                        decoder.client,
                        api,
                        req_data,
                        request_id=request_id,
                        max_retries=global_args.max_retries,
                        base_delay=global_args.retry_delay):
                    if not released_kv and chunk:
                        proxy_state.release_prefiller_kv(
                            prefiller_idx, prefiller_score)
                        released_kv = True
                    yield chunk
            except Exception as e:
                logger.error(
                    f"Error during streaming from decoder {decoder.url}: {str(e)}; "
                    f"the aborted request {request_id} will be routed to the "
                    "target prefiller when a new request is dispatched to it.")
                proxy_state.abort_prefiller_request(prefiller_idx, request_id)
                proxy_state.release_prefiller_kv(prefiller_idx,
                                                 prefiller_score)

            # After streaming is done, release tokens
            proxy_state.release_decoder(decoder_idx, decoder_score)

        return StreamingResponse(generate_stream(),
                                 media_type="application/json")
    except Exception as e:
        import traceback
        exc_info = sys.exc_info()
        print("Error occurred in disagg prefill proxy server"
              f" - {api} endpoint")
        print(e)
        print("".join(traceback.format_exception(*exc_info)))
        raise


@app.post("/v1/completions")
@with_cancellation
async def handle_completions(request: Request):
    return await _handle_completions("/completions", request)


@app.post("/v1/chat/completions")
@with_cancellation
async def handle_chat_completions(request: Request):
    return await _handle_completions("/chat/completions", request)


@app.get("/healthcheck")
async def healthcheck():
    return {
        "status": "ok",
        "prefill_instances": len(proxy_state.prefillers),
        "decode_instances": len(proxy_state.decoders)
    }


if __name__ == '__main__':
    global global_args
    global_args = parse_args()
    import uvicorn
    uvicorn.run(app, host=global_args.host, port=global_args.port)

# Mooncake Connector Deployment Guide

## Environmental Dependencies

* Software:
  * Python >= 3.9, < 3.12
  * CANN >= 8.2.rc1
  * PyTorch >= 2.7.1, torch-npu >= 2.7.1.dev20250724
  * vLLM (same version as vllm-ascend)
  * mooncake-transfer-engine, reference documentation: https://github.com/kvcache-ai/Mooncake/blob/main/doc/zh/ascend_transport.md

The vLLM version must match the one required by the vllm-ascend main branch. As of 2025/07/30, the versions are:

* vllm: v0.10.1
* vllm-ascend: v0.10.1rc1

## Run

### 1. Run `prefill` Node

```
bash run_prefill.sh
```

Content of the `run_prefill.sh` script:
```
export HCCL_EXEC_TIMEOUT=204
export HCCL_CONNECT_TIMEOUT=120
export HCCL_IF_IP=localhost
export GLOO_SOCKET_IFNAME="xxxxxx"
export TP_SOCKET_IFNAME="xxxxxx"
export HCCL_SOCKET_IFNAME="xxxxxx"
export ASCEND_RT_VISIBLE_DEVICES=0,1,2,3
export PHYSICAL_DEVICES=$(ls /dev/davinci* 2>/dev/null | grep -o '[0-9]\+' | sort -n | paste -sd',' -)

vllm serve "/xxxxx/DeepSeek-V2-Lite-Chat" \
    --host localhost \
    --port 8100 \
    --tensor-parallel-size 2 \
    --seed 1024 \
    --max-model-len 2000 \
    --max-num-batched-tokens 2000 \
    --trust-remote-code \
    --enforce-eager \
    --data-parallel-size 2 \
    --data-parallel-address localhost \
    --data-parallel-rpc-port 9100 \
    --gpu-memory-utilization 0.8 \
    --kv-transfer-config \
    '{"kv_connector": "MooncakeConnectorV1",
      "kv_buffer_device": "npu",
      "kv_role": "kv_producer",
      "kv_parallel_size": 1,
      "kv_port": "20001",
      "engine_id": "0",
      "kv_rank": 0,
      "kv_connector_module_path": "vllm_ascend.distributed.mooncake_connector",
      "kv_connector_extra_config": {
          "prefill": {
              "dp_size": 2,
              "tp_size": 2
          },
          "decode": {
              "dp_size": 2,
              "tp_size": 2
          }
      }
    }'
```

`HCCL_EXEC_TIMEOUT`, `HCCL_CONNECT_TIMEOUT`, and `HCCL_IF_IP` are HCCL-related settings.<br>
Set `GLOO_SOCKET_IFNAME`, `TP_SOCKET_IFNAME`, and `HCCL_SOCKET_IFNAME` to the name of the NIC used for distributed communication.<br>
`ASCEND_RT_VISIBLE_DEVICES` selects the NPU cards this node uses; the number of cards must equal `dp_size * tp_size`.<br>
`/xxxxx/DeepSeek-V2-Lite-Chat` is the path of the model to serve.<br>
`--host`: the IP address this node listens on.<br>
`--port`: the port this node listens on; it must match the corresponding entry in the proxy's `--prefiller-ports` (step 3).<br>
`--seed`, `--max-model-len`, and `--max-num-batched-tokens` are basic model settings; adjust them to your deployment.<br>
`--tensor-parallel-size`: the TP size.<br>
`--data-parallel-size`: the DP size.<br>
`--data-parallel-address`: the IP address used by the DP group; set it to this node's IP.<br>
`--data-parallel-rpc-port`: the RPC port for communication within the DP group.<br>
`--trust-remote-code`: allows loading model code from the local path.<br>
`--enforce-eager`: disables graph mode.<br>
`--gpu-memory-utilization`: the fraction of device memory each card may use.<br>
`--kv-transfer-config`: `kv_connector` and `kv_connector_module_path` select the Mooncake connector, and `kv_buffer_device` is set to `npu` so KV buffers live on the NPU. Set `kv_role` to `kv_producer` on the prefill node and `kv_consumer` on the decode node, `kv_parallel_size` to 1, and `kv_port` to the port the connector should use. Set `engine_id` and `kv_rank` to 0 on the prefill node and to 1 on the decode node. In `kv_connector_extra_config`, set the `prefill` and `decode` parallel strategies to match each node's `--tensor-parallel-size` and `--data-parallel-size`.<br>
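As a quick sanity check before launching, the device-count constraint described above can be verified with a short snippet (a minimal sketch with illustrative values copied from `run_prefill.sh`; the variable names are not part of the deployment scripts):

```python
# Sanity check (illustrative values): the number of devices listed in
# ASCEND_RT_VISIBLE_DEVICES must equal dp_size * tp_size for the node.
ascend_rt_visible_devices = "0,1,2,3"  # value of ASCEND_RT_VISIBLE_DEVICES
dp_size, tp_size = 2, 2  # --data-parallel-size, --tensor-parallel-size

devices = ascend_rt_visible_devices.split(",")
assert len(devices) == dp_size * tp_size, (
    f"expected {dp_size * tp_size} devices, got {len(devices)}")
print(f"OK: {len(devices)} devices for DP{dp_size} x TP{tp_size}")
# → OK: 4 devices for DP2 x TP2
```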

### 2. Run `decode` Node

```
bash run_decode.sh
```

Content of the `run_decode.sh` script:
```
export HCCL_EXEC_TIMEOUT=204
export HCCL_CONNECT_TIMEOUT=120
export HCCL_IF_IP=localhost
export GLOO_SOCKET_IFNAME="xxxxxx"
export TP_SOCKET_IFNAME="xxxxxx"
export HCCL_SOCKET_IFNAME="xxxxxx"
export ASCEND_RT_VISIBLE_DEVICES=4,5,6,7
export PHYSICAL_DEVICES=$(ls /dev/davinci* 2>/dev/null | grep -o '[0-9]\+' | sort -n | paste -sd',' -)

vllm serve "/xxxxx/DeepSeek-V2-Lite-Chat" \
    --host localhost \
    --port 8200 \
    --tensor-parallel-size 2 \
    --seed 1024 \
    --max-model-len 2000 \
    --max-num-batched-tokens 2000 \
    --trust-remote-code \
    --enforce-eager \
    --data-parallel-size 2 \
    --data-parallel-address localhost \
    --data-parallel-rpc-port 9100 \
    --gpu-memory-utilization 0.8 \
    --kv-transfer-config \
    '{"kv_connector": "MooncakeConnectorV1",
      "kv_buffer_device": "npu",
      "kv_role": "kv_consumer",
      "kv_parallel_size": 1,
      "kv_port": "20002",
      "engine_id": "1",
      "kv_rank": 1,
      "kv_connector_module_path": "vllm_ascend.distributed.mooncake_connector",
      "kv_connector_extra_config": {
          "prefill": {
              "dp_size": 2,
              "tp_size": 2
          },
          "decode": {
              "dp_size": 2,
              "tp_size": 2
          }
      }
    }'
```

### 3. Start the Proxy Server

```
cd /vllm-ascend/examples/disaggregated_prefill_v1/
python load_balance_proxy_server_example.py --host localhost --prefiller-hosts host1 host2 --prefiller-ports 8100 8101 --decoder-hosts host3 host4 --decoder-ports 8200 8201
```

`--host`: the address the proxy listens on; the host used in the curl command in step 4 must match it. The proxy's default port is 8000.<br>
`--prefiller-hosts`: the IP addresses of all prefill nodes, separated by spaces; in an xPyD scenario, append additional addresses to the list.<br>
`--prefiller-ports`: the ports of all prefill nodes, i.e. the `--port` values used when starting vLLM in step 1, separated by spaces and in one-to-one correspondence with `--prefiller-hosts`.<br>
`--decoder-hosts`: the IP addresses of all decode nodes, separated by spaces; in an xPyD scenario, append additional addresses to the list.<br>
`--decoder-ports`: the ports of all decode nodes, i.e. the `--port` values used when starting vLLM in step 2, separated by spaces and in one-to-one correspondence with `--decoder-hosts`.<br>
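The host and port lists are positional: the i-th port belongs to the i-th host. How such paired lists resolve into per-instance endpoints can be sketched as follows (illustrative only, not the proxy's actual internals):

```python
# Pair the positional host/port lists into endpoint URLs, mirroring how
# --prefiller-hosts and --prefiller-ports correspond one-to-one.
prefiller_hosts = ["host1", "host2"]
prefiller_ports = [8100, 8101]

assert len(prefiller_hosts) == len(prefiller_ports), \
    "hosts and ports must correspond one-to-one"

endpoints = [
    f"http://{host}:{port}/v1"
    for host, port in zip(prefiller_hosts, prefiller_ports)
]
print(endpoints)  # → ['http://host1:8100/v1', 'http://host2:8101/v1']
```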

### 4. Run Inference

Replace `localhost` in the command below with the proxy's actual address, and set `"model"` to the model path used in the shell scripts above.

```
curl -s http://localhost:8000/v1/completions -H "Content-Type: application/json" -d '{
    "model": "model_path",
    "prompt": "Given the accelerating impacts of climate change—including rising sea levels, increasing frequency of extreme weather events, loss of biodiversity, and adverse effects on agriculture and human health—there is an urgent need for a robust, globally coordinated response. However, international efforts are complicated by a range of factors: economic disparities between high-income and low-income countries, differing levels of industrialization, varying access to clean energy technologies, and divergent political systems that influence climate policy implementation. In this context, how can global agreements like the Paris Accord be redesigned or strengthened to not only encourage but effectively enforce emission reduction targets? Furthermore, what mechanisms can be introduced to promote fair and transparent technology transfer, provide adequate financial support for climate adaptation in vulnerable regions, and hold nations accountable without exacerbating existing geopolitical tensions or disproportionately burdening those with historically lower emissions?",
    "max_tokens": 256
}'
```
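The same request can be issued from Python using only the standard library (a sketch; the proxy address and `"model"` value are placeholders, as in the curl command):

```python
import json
from urllib import request as urlrequest

payload = {
    "model": "model_path",  # placeholder: use the model path from the scripts
    "prompt": "What is disaggregated prefill?",
    "max_tokens": 256,
}
req = urlrequest.Request(
    "http://localhost:8000/v1/completions",  # the proxy started in step 3
    data=json.dumps(payload).encode("utf-8"),
    headers={"Content-Type": "application/json"},
)
# Uncomment once the proxy is running:
# with urlrequest.urlopen(req) as resp:
#     print(json.load(resp))
print(req.get_method(), req.full_url)  # → POST http://localhost:8000/v1/completions
```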

examples/disaggregated_prefill_v1/run_server.sh
export HCCL_IF_IP=141.61.39.117
export GLOO_SOCKET_IFNAME="enp48s3u1u1"
export TP_SOCKET_IFNAME="enp48s3u1u1"
export HCCL_SOCKET_IFNAME="enp48s3u1u1"
export DISAGGREGATED_PREFILL_RANK_TABLE_PATH=path-to-rank-table

export OMP_PROC_BIND=false
export OMP_NUM_THREADS=100

export VLLM_USE_V1=1

vllm serve model_path \
    --host 0.0.0.0 \
    --port 20002 \
    --tensor-parallel-size 1 \
    --seed 1024 \
    --served-model-name dsv3 \
    --max-model-len 2000 \
    --max-num-batched-tokens 2000 \
    --trust-remote-code \
    --gpu-memory-utilization 0.9 \
    --kv-transfer-config \
    '{"kv_connector": "LLMDataDistCMgrConnector",
      "kv_buffer_device": "npu",
      "kv_role": "kv_consumer",
      "kv_parallel_size": 1,
      "kv_port": "20001",
      "engine_id": 0,
      "kv_connector_module_path": "vllm_ascend.distributed.llmdatadist_connector_v1_a3"
    }' \
    --additional-config \
    '{"enable_graph_mode": "True"}'