Sync from v0.13
257
tests/v1/kv_connector/nixl_integration/run_accuracy_test.sh
Executable file
@@ -0,0 +1,257 @@
#!/bin/bash
set -xe

# Parse command line arguments
KV_BUFFER_DEVICE="cuda"  # Default to cuda
while [[ $# -gt 0 ]]; do
  case $1 in
    --kv_buffer_device)
      KV_BUFFER_DEVICE="$2"
      shift 2
      ;;
    *)
      echo "Unknown option $1"
      echo "Usage: $0 [--kv_buffer_device <cuda|cpu>]"
      exit 1
      ;;
  esac
done

echo "Running accuracy tests with kv_buffer_device=$KV_BUFFER_DEVICE"

DECODER_KV_LAYOUT=${DECODER_KV_LAYOUT:-"HND"}  # Default to HND, optional NHD
if [[ "$DECODER_KV_LAYOUT" == "NHD" ]]; then
  KV_CONFIG_HETERO_LAYOUT=',"enable_permute_local_kv":"True"'
else
  KV_CONFIG_HETERO_LAYOUT=''
fi

# Build the kv-transfer-config once
if [[ "$KV_BUFFER_DEVICE" == "cuda" ]]; then
  KV_CONFIG='{"kv_connector":"NixlConnector","kv_role":"kv_both"'${KV_CONFIG_HETERO_LAYOUT}'}'
else
  KV_CONFIG="{\"kv_connector\":\"NixlConnector\",\"kv_role\":\"kv_both\",\"kv_buffer_device\":\"$KV_BUFFER_DEVICE\""${KV_CONFIG_HETERO_LAYOUT}"}"
fi

# Models to run
MODEL_NAMES=${MODEL_NAMES:-}
if [[ -n "$MODEL_NAMES" ]]; then
  MODELS=("$MODEL_NAMES")
else
  MODELS=(
    "Qwen/Qwen3-0.6B"
  )
fi

# Number of prefill and decode instances to create
NUM_PREFILL_INSTANCES=${NUM_PREFILL_INSTANCES:-1}  # Default to 1
NUM_DECODE_INSTANCES=${NUM_DECODE_INSTANCES:-1}  # Default to 1
PREFILLER_TP_SIZE=${PREFILLER_TP_SIZE:-1}
DECODER_TP_SIZE=${DECODER_TP_SIZE:-1}
GPU_MEMORY_UTILIZATION=${GPU_MEMORY_UTILIZATION:-0.2}
PREFILL_BLOCK_SIZE=${PREFILL_BLOCK_SIZE:-128}
DECODE_BLOCK_SIZE=${DECODE_BLOCK_SIZE:-128}

# Find the git repository root directory
GIT_ROOT=$(git rev-parse --show-toplevel)

SMI_BIN=$(which nvidia-smi || which rocm-smi || echo "")

# Trap the SIGINT signal (triggered by Ctrl+C)
trap 'kill $(jobs -pr)' SIGINT SIGTERM EXIT

# Waits for vLLM to start.
wait_for_server() {
  local port=$1
  timeout 1200 bash -c "
    until curl -s localhost:${port}/v1/completions > /dev/null; do
      sleep 1
    done" && return 0 || return 1
}

# Function to clean up previous instances
cleanup_instances() {
  echo "Cleaning up any running vLLM instances..."
  pkill -f "vllm serve" || true
  sleep 2
}

# Helper to get model-specific arguments for deepseek
get_model_args() {
  local model_name=$1
  local extra_args=""

  if [[ "$model_name" == "deepseek-ai/deepseek-vl2-tiny" ]]; then
    extra_args="--hf_overrides '{\"architectures\": [\"DeepseekVLV2ForCausalLM\"]}' --trust-remote-code"
  fi

  echo "$extra_args"
}

get_num_gpus() {
  if [[ "$SMI_BIN" == *"nvidia"* ]]; then
    echo "$($SMI_BIN --query-gpu=name --format=csv,noheader | wc -l)"
  elif [[ "$SMI_BIN" == *"rocm"* ]]; then
    echo "$($SMI_BIN -l | grep GPU | wc -l)"
  else
    # Works for non-CUDA platforms: assume at least one device
    # and let the system decide which card to use.
    echo "1"
  fi
}

# Function to run tests for a specific model
run_tests_for_model() {
  local model_name=$1
  echo "================================"
  echo "Testing model: $model_name"
  echo "================================"

  # Get model-specific arguments
  local model_args=$(get_model_args "$model_name")

  # Arrays to store all hosts and ports
  PREFILL_HOSTS=()
  PREFILL_PORTS=()
  DECODE_HOSTS=()
  DECODE_PORTS=()

  # Start prefill instances
  for i in $(seq 0 $((NUM_PREFILL_INSTANCES-1))); do
    # Calculate GPU ID - we'll distribute across available GPUs
    GPU_ID=$((i % $(get_num_gpus)))
    NEXT_GPU=${GPU_ID}
    # If PREFILLER_TP_SIZE is more than 1, claim the next GPUs as well
    for (( j=1; j < PREFILLER_TP_SIZE; j++ )); do
      NEXT_GPU=$(((GPU_ID + j) % $(get_num_gpus)))
      GPU_ID="${GPU_ID},${NEXT_GPU}"
    done

    # Calculate port number (base port + instance number)
    PORT=$((8100 + i))
    # Calculate side channel port. Avoid a clash with TP workers.
    SIDE_CHANNEL_PORT=$((5559 + i))

    echo "Starting prefill instance $i on GPU $GPU_ID, port $PORT"

    # Build the command with or without model-specific args
    BASE_CMD="CUDA_VISIBLE_DEVICES=$GPU_ID \
      VLLM_KV_CACHE_LAYOUT='HND' \
      UCX_NET_DEVICES=all \
      VLLM_NIXL_SIDE_CHANNEL_PORT=$SIDE_CHANNEL_PORT \
      vllm serve $model_name \
      --port $PORT \
      --enforce-eager \
      --block-size ${PREFILL_BLOCK_SIZE} \
      --gpu-memory-utilization $GPU_MEMORY_UTILIZATION \
      --tensor-parallel-size $PREFILLER_TP_SIZE \
      --kv-transfer-config '$KV_CONFIG'"

    if [ -n "$model_args" ]; then
      FULL_CMD="$BASE_CMD $model_args"
    else
      FULL_CMD="$BASE_CMD"
    fi

    eval "$FULL_CMD &"

    # Store host and port for proxy configuration
    PREFILL_HOSTS+=("localhost")
    PREFILL_PORTS+=($PORT)
  done

  # Start decode instances
  for i in $(seq 0 $((NUM_DECODE_INSTANCES-1))); do
    # Calculate GPU ID - we'll distribute across available GPUs, starting from after prefill GPUs
    GPU_ID=$(((i + NEXT_GPU + 1) % $(get_num_gpus)))
    # If DECODER_TP_SIZE is more than 1, claim the next GPUs as well
    for (( j=1; j < DECODER_TP_SIZE; j++ )); do
      NEXT_GPU=$(((GPU_ID + j) % $(get_num_gpus)))
      GPU_ID="${GPU_ID},${NEXT_GPU}"
    done
    # Calculate port number (base port + instance number)
    PORT=$((8200 + i))
    # Calculate side channel port
    SIDE_CHANNEL_PORT=$((5659 + i * $DECODER_TP_SIZE))

    echo "Starting decode instance $i on GPU $GPU_ID, port $PORT"

    # Build the command with or without model-specific args
    BASE_CMD="CUDA_VISIBLE_DEVICES=$GPU_ID \
      VLLM_KV_CACHE_LAYOUT=$DECODER_KV_LAYOUT \
      UCX_NET_DEVICES=all \
      VLLM_NIXL_SIDE_CHANNEL_PORT=$SIDE_CHANNEL_PORT \
      vllm serve $model_name \
      --port $PORT \
      --enforce-eager \
      --block-size ${DECODE_BLOCK_SIZE} \
      --gpu-memory-utilization $GPU_MEMORY_UTILIZATION \
      --kv-transfer-config '$KV_CONFIG'"

    # DP-EP attention mode
    if [[ -z "$DP_EP" ]]; then
      BASE_CMD="${BASE_CMD} --tensor-parallel-size $DECODER_TP_SIZE"
    else
      echo "DP-EP Attention enabled, deploying with dp=$DECODER_TP_SIZE and tp=1"
      BASE_CMD="${BASE_CMD} --data-parallel-size $DECODER_TP_SIZE \
        --tensor-parallel-size 1 --enable-expert-parallel"
    fi

    if [ -n "$model_args" ]; then
      FULL_CMD="$BASE_CMD $model_args"
    else
      FULL_CMD="$BASE_CMD"
    fi

    eval "$FULL_CMD &"

    # Store host and port for proxy configuration
    DECODE_HOSTS+=("localhost")
    DECODE_PORTS+=($PORT)
  done

  # Wait for all instances to start
  for PORT in "${PREFILL_PORTS[@]}"; do
    echo "Waiting for prefill instance on port $PORT to start..."
    wait_for_server $PORT
  done

  for PORT in "${DECODE_PORTS[@]}"; do
    echo "Waiting for decode instance on port $PORT to start..."
    wait_for_server $PORT
  done

  # Build the command for the proxy server with all the hosts and ports
  PROXY_CMD="python3 ${GIT_ROOT}/tests/v1/kv_connector/nixl_integration/toy_proxy_server.py --port 8192"

  # Add all prefill hosts and ports
  PROXY_CMD+=" --prefiller-hosts ${PREFILL_HOSTS[@]}"
  PROXY_CMD+=" --prefiller-ports ${PREFILL_PORTS[@]}"

  # Add all decode hosts and ports
  PROXY_CMD+=" --decoder-hosts ${DECODE_HOSTS[@]}"
  PROXY_CMD+=" --decoder-ports ${DECODE_PORTS[@]}"

  # Start the proxy server
  echo "Starting proxy server with command: $PROXY_CMD"
  $PROXY_CMD &

  # Wait for the proxy to start
  sleep 5

  # Run lm eval for this model
  echo "Running tests for $model_name"
  TEST_MODEL=$model_name python3 -m pytest -s -x ${GIT_ROOT}/tests/v1/kv_connector/nixl_integration/test_accuracy.py

  # Clean up before running next model
  cleanup_instances
  sleep 3
}

# Run tests for each model
for model in "${MODELS[@]}"; do
  run_tests_for_model "$model"
done

echo "All tests completed!"
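For reference, a typical invocation of run_accuracy_test.sh looks like the following; the concrete values are illustrative and not part of the commit, and the script only assumes a repo checkout, the vllm CLI on PATH, and enough GPUs for the requested TP sizes:

# default: one prefill and one decode instance of Qwen/Qwen3-0.6B, KV buffer on the GPU
bash tests/v1/kv_connector/nixl_integration/run_accuracy_test.sh

# stage the NIXL KV buffer in host memory and use TP=2 decoders
DECODER_TP_SIZE=2 bash tests/v1/kv_connector/nixl_integration/run_accuracy_test.sh --kv_buffer_device cpu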
148
tests/v1/kv_connector/nixl_integration/run_edge_case_test.sh
Executable file
@@ -0,0 +1,148 @@
#!/bin/bash
set -xe

# Parse command line arguments
KV_BUFFER_DEVICE="cuda"  # Default to cuda
PREFILL_GPU_ID=4  # Default GPU IDs
DECODE_GPU_ID=5
while [[ $# -gt 0 ]]; do
  case $1 in
    --kv_buffer_device)
      KV_BUFFER_DEVICE="$2"
      shift 2
      ;;
    *)
      echo "Unknown option $1"
      echo "Usage: $0 [--kv_buffer_device <cuda|cpu>]"
      exit 1
      ;;
  esac
done

echo "Running edge case tests with kv_buffer_device=$KV_BUFFER_DEVICE (GPUs: $PREFILL_GPU_ID, $DECODE_GPU_ID)"

# Build the kv-transfer-config once
if [[ "$KV_BUFFER_DEVICE" == "cuda" ]]; then
  KV_CONFIG='{"kv_connector":"NixlConnector","kv_role":"kv_both"}'
else
  KV_CONFIG="{\"kv_connector\":\"NixlConnector\",\"kv_role\":\"kv_both\",\"kv_buffer_device\":\"$KV_BUFFER_DEVICE\"}"
fi

# Models to run
MODELS=(
  "Qwen/Qwen3-0.6B"
)

# Find the git repository root directory
GIT_ROOT=$(git rev-parse --show-toplevel)

# Trap the SIGINT signal (triggered by Ctrl+C)
trap 'kill $(jobs -pr)' SIGINT SIGTERM EXIT

# Waits for vLLM to start.
wait_for_server() {
  local port=$1
  timeout 1200 bash -c "
    until curl -s localhost:${port}/v1/completions > /dev/null; do
      sleep 1
    done" && return 0 || return 1
}

# Function to clean up previous instances
cleanup_instances() {
  echo "Cleaning up any running vLLM instances..."
  pkill -f "vllm serve" || true
  sleep 2
}

# Helper to get model-specific arguments for deepseek
get_model_args() {
  local model_name=$1
  local extra_args=""

  if [[ "$model_name" == "deepseek-ai/deepseek-vl2-tiny" ]]; then
    extra_args="--hf_overrides '{\"architectures\": [\"DeepseekVLV2ForCausalLM\"]}' --trust-remote-code"
  fi

  echo "$extra_args"
}


# Function to run tests for a specific model
run_tests_for_model() {
  local model_name=$1
  echo "================================"
  echo "Testing model: $model_name"
  echo "================================"

  # Get model-specific arguments
  local model_args=$(get_model_args "$model_name")

  # Start prefill instance
  PREFILL_PORT=8001

  BASE_CMD="CUDA_VISIBLE_DEVICES=$PREFILL_GPU_ID VLLM_NIXL_SIDE_CHANNEL_PORT=5559 vllm serve $model_name \
    --port $PREFILL_PORT \
    --enforce-eager \
    --gpu-memory-utilization 0.2 \
    --kv-transfer-config '$KV_CONFIG'"

  if [ -n "$model_args" ]; then
    FULL_CMD="$BASE_CMD $model_args"
  else
    FULL_CMD="$BASE_CMD"
  fi

  eval "$FULL_CMD &"

  # Start decode instance
  DECODE_PORT=8002

  # Build the command with or without model-specific args
  BASE_CMD="CUDA_VISIBLE_DEVICES=$DECODE_GPU_ID VLLM_NIXL_SIDE_CHANNEL_PORT=6000 vllm serve $model_name \
    --port $DECODE_PORT \
    --enforce-eager \
    --gpu-memory-utilization 0.2 \
    --kv-transfer-config '$KV_CONFIG'"

  if [ -n "$model_args" ]; then
    FULL_CMD="$BASE_CMD $model_args"
  else
    FULL_CMD="$BASE_CMD"
  fi

  eval "$FULL_CMD &"

  # Wait for all instances to start
  echo "Waiting for prefill instance on port $PREFILL_PORT to start..."
  wait_for_server $PREFILL_PORT
  echo "Waiting for decode instance on port $DECODE_PORT to start..."
  wait_for_server $DECODE_PORT

  # Build the command for the proxy server with all the hosts and ports
  PROXY_PORT=8192
  PROXY_CMD="python ${GIT_ROOT}/tests/v1/kv_connector/nixl_integration/toy_proxy_server.py --port $PROXY_PORT"
  PROXY_CMD+=" --prefiller-ports ${PREFILL_PORT}"
  PROXY_CMD+=" --decoder-ports ${DECODE_PORT}"
  # Start the proxy server
  echo "Starting proxy server with command: $PROXY_CMD"
  $PROXY_CMD &

  # Wait for the proxy to start
  sleep 5

  # Run lm eval for this model
  echo "Running tests for $model_name"
  PREFILL_PORT=$PREFILL_PORT DECODE_PORT=$DECODE_PORT PROXY_PORT=$PROXY_PORT python -m pytest -s -v ${GIT_ROOT}/tests/v1/kv_connector/nixl_integration/test_edge_cases.py

  # Clean up before running next model
  cleanup_instances
  sleep 3
}

# Run tests for each model
for model in "${MODELS[@]}"; do
  run_tests_for_model "$model"
done

echo "All tests completed!"
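While run_edge_case_test.sh has the servers up, the toy proxy on port 8192 speaks the standard completions API and can be exercised directly; a hedged example, with the prompt and token count chosen arbitrarily:

curl -s http://localhost:8192/v1/completions \
  -H 'Content-Type: application/json' \
  -d '{"model": "Qwen/Qwen3-0.6B", "prompt": "Red Hat is ", "max_tokens": 16, "temperature": 0}'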
156
tests/v1/kv_connector/nixl_integration/run_tpu_disagg_accuracy_test.sh
Normal file
@@ -0,0 +1,156 @@
#!/bin/bash
set -xe

# Hosts / ports
PREFILL_HOST=${PREFILL_HOST:-"localhost"}
PREFILL_PORT=${PREFILL_PORT:-8100}
PREFILL_NIXL_SIDE_PORT=${PREFILL_NIXL_SIDE_PORT:-5577}
DECODE_HOST=${DECODE_HOST:-"localhost"}
DECODE_PORT=${DECODE_PORT:-8200}
PROXY_HOST=${PROXY_HOST:-"localhost"}
PROXY_PORT=${PROXY_PORT:-8192}
BASELINE_HOST=${BASELINE_HOST:-"localhost"}
BASELINE_PORT=${BASELINE_PORT:-9290}


# Model to run.
MODEL_NAME=${MODEL_NAME:-"meta-llama/Llama-3.2-3B-Instruct"}
MAX_MODEL_LEN=${MAX_MODEL_LEN:-1024}
BLOCK_SIZE=${BLOCK_SIZE:-32}


# Execution env
GIT_ROOT=$(git rev-parse --show-toplevel)
EXP_ROOT="${GIT_ROOT}/tests/v1/kv_connector/nixl_integration"
CONDA_PATH=${CONDA_PATH:-"/home/${USER}/anaconda3"}
CONDA_ENV_NAME=${CONDA_ENV_NAME:-"nixl"}

OUTPUT_FILE=${OUTPUT_FILE:-"${EXP_ROOT}/.tpu_accuracy_test_outputs.txt"}

# Trap the SIGINT signal (triggered by Ctrl+C)
trap 'kill $(jobs -pr)' SIGINT SIGTERM EXIT


# Waits for vLLM server to start.
wait_for_server() {
  local host=$1
  local port=$2
  timeout 1200 bash -c "
    until curl -s ${host}:${port}/v1/completions > /dev/null; do
      sleep 1
    done" && return 0 || return 1
}

# Cleanup function
cleanup() {
  echo "Caught Ctrl+C, cleaning up..."
  # Cleanup commands
  pgrep python | xargs kill -9 || true
  # pkill -f python || true
  echo "Cleanup complete. Exiting."
}

launch_baseline() {
  BASELINE_BASE_CMD="source ${CONDA_PATH}/bin/activate ${CONDA_ENV_NAME};
    VLLM_LOGGING_LEVEL=DEBUG \
    PJRT_DEVICE=TPU \
    VLLM_WORKER_MULTIPROC_METHOD=spawn \
    VLLM_ENABLE_V1_MULTIPROCESSING=0 vllm serve $MODEL_NAME \
    --host ${BASELINE_HOST} \
    --port ${BASELINE_PORT} \
    --max-model-len ${MAX_MODEL_LEN} \
    --seed 42 \
    --block-size ${BLOCK_SIZE} \
    --gpu-memory-utilization 0.5 \
    --enforce-eager"
  echo ${BASELINE_BASE_CMD}
  ssh -tt ${BASELINE_HOST} "${BASELINE_BASE_CMD}" &
}

launch_pd() {
  PREFILL_BASE_CMD="source ${CONDA_PATH}/bin/activate ${CONDA_ENV_NAME};
    UCX_TLS=tcp \
    VLLM_MULTIPROC_EXECUTE_MODEL_TIMEOUT_S=200 \
    VLLM_LOGGING_LEVEL=DEBUG \
    VLLM_NIXL_SIDE_CHANNEL_HOST=${PREFILL_HOST} \
    VLLM_NIXL_SIDE_CHANNEL_PORT=${PREFILL_NIXL_SIDE_PORT} \
    PJRT_DEVICE=TPU \
    VLLM_WORKER_MULTIPROC_METHOD=spawn \
    VLLM_ENABLE_V1_MULTIPROCESSING=0 vllm serve $MODEL_NAME \
    --host ${PREFILL_HOST} \
    --port ${PREFILL_PORT} \
    --max-model-len ${MAX_MODEL_LEN} \
    --seed 42 \
    --block-size ${BLOCK_SIZE} \
    --enforce-eager \
    --gpu-memory-utilization 0.5 \
    --kv-transfer-config '{\"kv_connector\":\"NixlConnector\",\"kv_role\":\"kv_both\",\"kv_buffer_device\":\"cpu\"}'"


  DECODE_BASE_CMD="source ${CONDA_PATH}/bin/activate ${CONDA_ENV_NAME};
    UCX_TLS=tcp \
    VLLM_MULTIPROC_EXECUTE_MODEL_TIMEOUT_S=200 \
    VLLM_LOGGING_LEVEL=DEBUG \
    PJRT_DEVICE=TPU \
    VLLM_WORKER_MULTIPROC_METHOD=spawn \
    VLLM_ENABLE_V1_MULTIPROCESSING=0 vllm serve $MODEL_NAME \
    --host ${DECODE_HOST} \
    --port ${DECODE_PORT} \
    --max-model-len ${MAX_MODEL_LEN} \
    --seed 42 \
    --block-size ${BLOCK_SIZE} \
    --enforce-eager \
    --gpu-memory-utilization 0.5 \
    --kv-transfer-config '{\"kv_connector\":\"NixlConnector\",\"kv_role\":\"kv_both\",\"kv_buffer_device\":\"cpu\"}'"

  echo ${PREFILL_BASE_CMD}
  echo ${DECODE_BASE_CMD}
  sleep 2

  # Execute on hosts
  ssh -tt ${PREFILL_HOST} "${PREFILL_BASE_CMD}" &
  ssh -tt ${DECODE_HOST} "${DECODE_BASE_CMD}" &
  sleep 1
  wait_for_server ${PREFILL_HOST} ${PREFILL_PORT}
  sleep 1
  wait_for_server ${DECODE_HOST} ${DECODE_PORT}
  sleep 1
}

launch_pd_proxy(){
  PROXY_BASE_CMD="source ${CONDA_PATH}/bin/activate ${CONDA_ENV_NAME};
    python3 ${EXP_ROOT}/toy_proxy_server.py \
    --prefiller-host ${PREFILL_HOST} --prefiller-port ${PREFILL_PORT} \
    --decoder-host ${DECODE_HOST} --decoder-port ${DECODE_PORT} \
    --host=${PROXY_HOST} --port ${PROXY_PORT}"
  echo ${PROXY_BASE_CMD}
  ssh -tt ${PROXY_HOST} "${PROXY_BASE_CMD}" &
}

run_tests(){
  local service_url=$1
  local mode=$2
  python3 ${EXP_ROOT}/test_disagg_accuracy.py --service_url=${service_url} --model_name=${MODEL_NAME} --mode=${mode} --file_name=${OUTPUT_FILE}
}


# Run the non-disagg. baseline & save outputs
launch_baseline
sleep 2
wait_for_server ${BASELINE_HOST} ${BASELINE_PORT}
run_tests "http://${BASELINE_HOST}:${BASELINE_PORT}" "baseline"
cleanup
sleep 10


# Run disagg. & do an exact match with the outputs from the baseline
launch_pd
launch_pd_proxy
sleep 10
run_tests "http://${PROXY_HOST}:${PROXY_PORT}" "disagg"
echo "-----P/D success----"

rm ${OUTPUT_FILE}
cleanup

exit 0
124
tests/v1/kv_connector/nixl_integration/run_tpu_edge_case_test.sh
Normal file
@@ -0,0 +1,124 @@
#!/bin/bash
set -xe

# Hosts / ports
PREFILL_HOST=${PREFILL_HOST:-"localhost"}
PREFILL_PORT=${PREFILL_PORT:-8100}
PREFILL_NIXL_SIDE_PORT=${PREFILL_NIXL_SIDE_PORT:-5577}
DECODE_HOST=${DECODE_HOST:-"localhost"}
DECODE_PORT=${DECODE_PORT:-8200}
PROXY_HOST=${PROXY_HOST:-"localhost"}
PROXY_PORT=${PROXY_PORT:-8192}
BASELINE_HOST=${BASELINE_HOST:-"localhost"}
BASELINE_PORT=${BASELINE_PORT:-9290}


# Model to run.
MODEL_NAME=${MODEL_NAME:-"meta-llama/Llama-3.2-3B-Instruct"}
MAX_MODEL_LEN=${MAX_MODEL_LEN:-1024}
BLOCK_SIZE=${BLOCK_SIZE:-32}


# Execution env
GIT_ROOT=$(git rev-parse --show-toplevel)
EXP_ROOT="${GIT_ROOT}/tests/v1/kv_connector/nixl_integration"
CONDA_PATH=${CONDA_PATH:-"/home/${USER}/anaconda3"}
CONDA_ENV_NAME=${CONDA_ENV_NAME:-"nixl"}

OUTPUT_FILE=${OUTPUT_FILE:-"${EXP_ROOT}/.tpu_accuracy_test_outputs.txt"}

# Trap the SIGINT signal (triggered by Ctrl+C)
trap 'kill $(jobs -pr)' SIGINT SIGTERM EXIT

# Waits for vLLM server to start.
wait_for_server() {
  local host=$1
  local port=$2
  timeout 1200 bash -c "
    until curl -s ${host}:${port}/v1/completions > /dev/null; do
      sleep 1
    done" && return 0 || return 1
}

# Cleanup function
cleanup() {
  echo "Caught Ctrl+C, cleaning up..."
  # Cleanup commands
  pgrep python | xargs kill -9 || true
  # pkill -f python || true
  echo "Cleanup complete. Exiting."
}


launch_pd() {
  PREFILL_BASE_CMD="source ${CONDA_PATH}/bin/activate ${CONDA_ENV_NAME};
    UCX_TLS=tcp \
    VLLM_MULTIPROC_EXECUTE_MODEL_TIMEOUT_S=200 \
    VLLM_LOGGING_LEVEL=DEBUG \
    VLLM_NIXL_SIDE_CHANNEL_HOST=${PREFILL_HOST} \
    VLLM_NIXL_SIDE_CHANNEL_PORT=${PREFILL_NIXL_SIDE_PORT} \
    PJRT_DEVICE=TPU \
    VLLM_WORKER_MULTIPROC_METHOD=spawn \
    VLLM_ENABLE_V1_MULTIPROCESSING=0 vllm serve $MODEL_NAME \
    --host ${PREFILL_HOST} \
    --port ${PREFILL_PORT} \
    --max-model-len ${MAX_MODEL_LEN} \
    --seed 42 \
    --block-size ${BLOCK_SIZE} \
    --enforce-eager \
    --gpu-memory-utilization 0.5 \
    --kv-transfer-config '{\"kv_connector\":\"NixlConnector\",\"kv_role\":\"kv_both\",\"kv_buffer_device\":\"cpu\"}'"


  DECODE_BASE_CMD="source ${CONDA_PATH}/bin/activate ${CONDA_ENV_NAME};
    UCX_TLS=tcp \
    VLLM_MULTIPROC_EXECUTE_MODEL_TIMEOUT_S=200 \
    VLLM_LOGGING_LEVEL=DEBUG \
    PJRT_DEVICE=TPU \
    VLLM_WORKER_MULTIPROC_METHOD=spawn \
    VLLM_ENABLE_V1_MULTIPROCESSING=0 vllm serve $MODEL_NAME \
    --host ${DECODE_HOST} \
    --port ${DECODE_PORT} \
    --max-model-len ${MAX_MODEL_LEN} \
    --seed 42 \
    --block-size ${BLOCK_SIZE} \
    --enforce-eager \
    --gpu-memory-utilization 0.5 \
    --kv-transfer-config '{\"kv_connector\":\"NixlConnector\",\"kv_role\":\"kv_both\",\"kv_buffer_device\":\"cpu\"}'"

  echo ${PREFILL_BASE_CMD}
  echo ${DECODE_BASE_CMD}
  sleep 2

  # Execute on hosts
  ssh -tt ${PREFILL_HOST} "${PREFILL_BASE_CMD}" &
  ssh -tt ${DECODE_HOST} "${DECODE_BASE_CMD}" &
  sleep 1
  wait_for_server ${PREFILL_HOST} ${PREFILL_PORT}
  sleep 1
  wait_for_server ${DECODE_HOST} ${DECODE_PORT}
  sleep 1
}

launch_pd_proxy(){
  PROXY_BASE_CMD="source ${CONDA_PATH}/bin/activate ${CONDA_ENV_NAME};
    python3 ${EXP_ROOT}/toy_proxy_server.py \
    --prefiller-host ${PREFILL_HOST} --prefiller-port ${PREFILL_PORT} \
    --decoder-host ${DECODE_HOST} --decoder-port ${DECODE_PORT} \
    --host=${PROXY_HOST} --port ${PROXY_PORT}"
  echo ${PROXY_BASE_CMD}
  ssh -tt ${PROXY_HOST} "${PROXY_BASE_CMD}" &
}


# Run disagg. & check the edge cases against the prefill/decode workers
launch_pd
launch_pd_proxy
sleep 10

PREFILL_HOST=${PREFILL_HOST} \
PREFILL_PORT=${PREFILL_PORT} \
DECODE_HOST=${DECODE_HOST} \
DECODE_PORT=${DECODE_PORT} \
PROXY_HOST=${PROXY_HOST} \
PROXY_PORT=${PROXY_PORT} python -m pytest -s -v ${GIT_ROOT}/tests/v1/kv_connector/nixl_integration/test_edge_cases.py
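run_tpu_edge_case_test.sh is configured entirely through the environment variables above; an illustrative remote run (the hostnames are placeholders, not values from the commit):

PREFILL_HOST=tpu-worker-0 DECODE_HOST=tpu-worker-1 \
  bash tests/v1/kv_connector/nixl_integration/run_tpu_edge_case_test.sh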
71
tests/v1/kv_connector/nixl_integration/test_accuracy.py
Normal file
@@ -0,0 +1,71 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import os

import lm_eval
import openai

BASE_URL = "http://localhost:8192/v1"
NUM_CONCURRENT = 100
TASK = "gsm8k"
FILTER = "exact_match,strict-match"
RTOL = 0.03

# Model-specific expected values
EXPECTED_VALUES = {
    "Qwen/Qwen3-0.6B": 0.41,
    "deepseek-ai/deepseek-vl2-small": 0.59,
    "deepseek-ai/deepseek-vl2-tiny": 0.19,
    "deepseek-ai/DeepSeek-V2-Lite-Chat": 0.65,
}

SIMPLE_PROMPT = (
    "The best part about working on vLLM is that I got to meet so many people across "
    "various different organizations like UCB, Google, and Meta which means",
)

# Get model name from environment variable
MODEL_NAME = os.environ.get("TEST_MODEL", "Qwen/Qwen3-0.6B")


def run_simple_prompt():
    client = openai.OpenAI(api_key="EMPTY", base_url=BASE_URL)
    completion = client.completions.create(model=MODEL_NAME, prompt=SIMPLE_PROMPT)

    print("-" * 50)
    print(f"Completion results for {MODEL_NAME}:")
    print(completion)
    print("-" * 50)


def test_accuracy():
    """Run the end to end accuracy test."""
    run_simple_prompt()

    model_args = (
        f"model={MODEL_NAME},"
        f"base_url={BASE_URL}/completions,"
        f"num_concurrent={NUM_CONCURRENT},tokenized_requests=False"
    )

    results = lm_eval.simple_evaluate(
        model="local-completions",
        model_args=model_args,
        tasks=TASK,
    )

    measured_value = results["results"][TASK][FILTER]
    expected_value = EXPECTED_VALUES.get(MODEL_NAME)

    if expected_value is None:
        print(
            f"Warning: No expected value found for {MODEL_NAME}. "
            "Skipping accuracy check."
        )
        print(f"Measured value: {measured_value}")
        return

    assert (
        measured_value - RTOL < expected_value
        and measured_value + RTOL > expected_value
    ), f"Expected: {expected_value} | Measured: {measured_value}"
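test_accuracy.py is normally launched by run_accuracy_test.sh, but it can be pointed at any already-running OpenAI-compatible deployment on localhost:8192; a sketch of a standalone run (assumes lm_eval with the gsm8k task is installed):

TEST_MODEL=Qwen/Qwen3-0.6B python3 -m pytest -s -x tests/v1/kv_connector/nixl_integration/test_accuracy.py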
180
tests/v1/kv_connector/nixl_integration/test_disagg_accuracy.py
Normal file
@@ -0,0 +1,180 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import argparse
import json
import os
import time

import openai
import requests

MAX_OUTPUT_LEN = 30

SAMPLE_PROMPTS = (
    "Red Hat is the best company in the world to work for because it works on "
    "open source software, which means that all the contributions are "
    "delivered to the community. As a result, when working on projects like "
    "vLLM we are able to meet many amazing people from various organizations "
    "like AMD, Google, NVIDIA, ",
    "We hold these truths to be self-evident, that all men are created equal, "
    "that they are endowed by their Creator with certain unalienable Rights, "
    "that among these are Life, Liberty and the pursuit of Happiness.--That "
    "to secure these rights, Governments are instituted among Men, deriving "
    "their just powers from the consent of the governed, ",
)


def check_vllm_server(url: str, timeout=5, retries=3) -> bool:
    """
    Checks if the vLLM server is ready by sending a GET request to the
    /health endpoint.

    Args:
        url (str): The base URL of the vLLM server.
        timeout (int): Timeout in seconds for the request.
        retries (int): Number of retries if the server is not ready.

    Returns:
        bool: True if the server is ready, False otherwise.
    """
    for attempt in range(retries):
        try:
            response = requests.get(url, timeout=timeout)
            if response.status_code == 200:
                return True
            else:
                print(
                    f"Attempt {attempt + 1}: Server returned status code "
                    f"{response.status_code}"
                )
        except requests.exceptions.RequestException as e:
            print(f"Attempt {attempt + 1}: Error connecting to server: {e}")
        time.sleep(1)  # Wait before retrying
    return False


def run_simple_prompt(
    base_url: str, model_name: str, input_prompt: str, use_chat_endpoint: bool
) -> str:
    client = openai.OpenAI(api_key="EMPTY", base_url=base_url)
    if use_chat_endpoint:
        completion = client.chat.completions.create(
            model=model_name,
            messages=[
                {"role": "user", "content": [{"type": "text", "text": input_prompt}]}
            ],
            max_completion_tokens=MAX_OUTPUT_LEN,
            temperature=0.0,
            seed=42,
        )
        return completion.choices[0].message.content
    else:
        completion = client.completions.create(
            model=model_name,
            prompt=input_prompt,
            max_tokens=MAX_OUTPUT_LEN,
            temperature=0.0,
            seed=42,
        )

        return completion.choices[0].text


def main():
    """
    Run the sample prompts against a vLLM endpoint. In baseline mode the
    outputs are saved to file_name; in disagg mode the outputs are compared
    against that saved baseline.
    """
    parser = argparse.ArgumentParser(description="vLLM client script")

    parser.add_argument(
        "--service_url",
        type=str,
        required=True,
        help="The vLLM service URL.",
    )

    parser.add_argument(
        "--model_name",
        type=str,
        required=True,
        help="model_name",
    )

    parser.add_argument(
        "--mode",
        type=str,
        default="baseline",
        help="mode: baseline==non-disagg, or disagg",
    )

    parser.add_argument(
        "--file_name",
        type=str,
        default=".vllm_output.txt",
        help="the file that saves the output tokens",
    )

    args = parser.parse_args()

    for arg in vars(args):
        print(f"{arg}: {getattr(args, arg)}")

    if args.mode == "baseline":
        # non-disagg
        health_check_url = f"{args.service_url}/health"
    else:
        # disagg proxy
        health_check_url = f"{args.service_url}/healthcheck"
        if not os.path.exists(args.file_name):
            raise ValueError(
                f"In disagg mode, the output file {args.file_name} from the "
                "non-disagg. baseline does not exist."
            )

    service_url = f"{args.service_url}/v1"

    if not check_vllm_server(health_check_url):
        raise RuntimeError(f"vllm server: {args.service_url} is not ready yet!")

    output_strs = dict()
    for i, prompt in enumerate(SAMPLE_PROMPTS):
        use_chat_endpoint = i % 2 == 1
        output_str = run_simple_prompt(
            base_url=service_url,
            model_name=args.model_name,
            input_prompt=prompt,
            use_chat_endpoint=use_chat_endpoint,
        )
        print(f"Prompt: {prompt}, output: {output_str}")
        output_strs[prompt] = output_str

    if args.mode == "baseline":
        # baseline: save outputs
        try:
            with open(args.file_name, "w") as json_file:
                json.dump(output_strs, json_file, indent=4)
        except OSError as e:
            print(f"Error writing to file: {e}")
            raise
    else:
        # disagg: verify outputs against the baseline
        baseline_outputs = None
        try:
            with open(args.file_name) as json_file:
                baseline_outputs = json.load(json_file)
        except OSError as e:
            print(f"Error reading from file: {e}")
            raise
        assert isinstance(baseline_outputs, dict)
        assert len(baseline_outputs) == len(output_strs)
        for prompt, output in baseline_outputs.items():
            assert prompt in output_strs, f"{prompt} not included"
            assert output == output_strs[prompt], (
                f"baseline_output: {output} != PD output: {output_strs[prompt]}"
            )


if __name__ == "__main__":
    main()
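For a manual run outside the TPU script, test_disagg_accuracy.py takes the same flags the script passes; an illustrative baseline invocation (the URL, model, and file name simply mirror the defaults used above):

python3 tests/v1/kv_connector/nixl_integration/test_disagg_accuracy.py \
  --service_url http://localhost:9290 \
  --model_name meta-llama/Llama-3.2-3B-Instruct \
  --mode baseline \
  --file_name .tpu_accuracy_test_outputs.txt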
80
tests/v1/kv_connector/nixl_integration/test_edge_cases.py
Normal file
@@ -0,0 +1,80 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import os

import openai

PREFILL_HOST = os.getenv("PREFILL_HOST", "localhost")
PREFILL_PORT = os.getenv("PREFILL_PORT", None)
DECODE_HOST = os.getenv("DECODE_HOST", "localhost")
DECODE_PORT = os.getenv("DECODE_PORT", None)
PROXY_HOST = os.getenv("PROXY_HOST", "localhost")
PROXY_PORT = os.getenv("PROXY_PORT", None)

if PREFILL_PORT is None or DECODE_PORT is None or PROXY_PORT is None:
    raise ValueError("Please set the PREFILL_PORT, DECODE_PORT, and PROXY_PORT.")

LONG_PROMPT = "Red Hat is the best company in the world to work for because it works on open source software, which means that all the contributions are delivered to the community. As a result, when working on projects like vLLM we are able to meet many amazing people from various organizations like AMD, Google, NVIDIA, "  # noqa: E501
PROMPT = "Red Hat is the best company in the world to work for because it works on open source software, which means that all the contributions are delivered to the community. As a result,"  # noqa: E501
SHORT_PROMPT = "Red Hat is "


def test_edge_cases():
    # Set the OpenAI API key and base URL
    decode_client = openai.OpenAI(
        api_key="MY_KEY",
        base_url=f"http://{DECODE_HOST}:{DECODE_PORT}/v1",
    )
    prefill_client = openai.OpenAI(
        api_key="MY_KEY",
        base_url=f"http://{PREFILL_HOST}:{PREFILL_PORT}/v1",
    )
    proxy_client = openai.OpenAI(
        api_key="MY_KEY",
        base_url=f"http://{PROXY_HOST}:{PROXY_PORT}/v1",
    )

    # Get the list of models
    models = decode_client.models.list()
    MODEL = models.data[0].id

    # (1) Check that we can handle a very short prompt,
    # less than the length of the block size.
    completion = proxy_client.completions.create(
        model=MODEL, prompt=SHORT_PROMPT, temperature=0
    )
    proxy_response = completion.choices[0].text
    completion = prefill_client.completions.create(
        model=MODEL, prompt=SHORT_PROMPT, temperature=0
    )
    prefill_response = completion.choices[0].text
    print(f"SMALL PROMPT: {proxy_response=}")
    assert proxy_response == prefill_response

    # (2) Check that we can handle a full prefix cache
    # hit on the D worker but not on the P worker.
    # (2a): prime the D worker.
    completion = decode_client.completions.create(
        model=MODEL, prompt=PROMPT, temperature=0
    )
    decode_response = completion.choices[0].text
    # (2b): send via the P/D setup
    completion = proxy_client.completions.create(
        model=MODEL, prompt=PROMPT, temperature=0
    )
    proxy_response = completion.choices[0].text
    print(f"FULL CACHE HIT: {proxy_response=}")
    assert proxy_response == decode_response

    # (3) Check that we can handle a partial prefix cache
    # hit on the D worker.
    completion = proxy_client.completions.create(
        model=MODEL, prompt=LONG_PROMPT, temperature=0
    )
    proxy_response = completion.choices[0].text
    completion = prefill_client.completions.create(
        model=MODEL, prompt=LONG_PROMPT, temperature=0
    )
    prefill_response = completion.choices[0].text
    print(f"PARTIAL CACHE HIT: {proxy_response=}")
    assert proxy_response == prefill_response
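test_edge_cases.py expects the three endpoints as environment variables, exactly as run_edge_case_test.sh exports them; a standalone sketch reusing the same ports that script chooses:

PREFILL_PORT=8001 DECODE_PORT=8002 PROXY_PORT=8192 \
  python -m pytest -s -v tests/v1/kv_connector/nixl_integration/test_edge_cases.py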
283
tests/v1/kv_connector/nixl_integration/toy_proxy_server.py
Normal file
@@ -0,0 +1,283 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import argparse
import itertools
import logging
import os
import uuid
from contextlib import asynccontextmanager

import httpx
from fastapi import FastAPI, Request
from fastapi.responses import StreamingResponse

logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)


@asynccontextmanager
async def lifespan(app: FastAPI):
    """
    Lifespan context manager to handle startup and shutdown events.
    """
    # Startup: Initialize client pools for prefiller and decoder services
    app.state.prefill_clients = []
    app.state.decode_clients = []

    # Create prefill clients
    for i, (host, port) in enumerate(global_args.prefiller_instances):
        prefiller_base_url = f"http://{host}:{port}/v1"
        app.state.prefill_clients.append(
            {
                "client": httpx.AsyncClient(
                    timeout=None,
                    base_url=prefiller_base_url,
                    limits=httpx.Limits(
                        max_connections=None,
                        max_keepalive_connections=None,
                    ),
                ),
                "host": host,
                "port": port,
                "id": i,
            }
        )

    # Create decode clients
    for i, (host, port) in enumerate(global_args.decoder_instances):
        decoder_base_url = f"http://{host}:{port}/v1"
        app.state.decode_clients.append(
            {
                "client": httpx.AsyncClient(
                    timeout=None,
                    base_url=decoder_base_url,
                    limits=httpx.Limits(
                        max_connections=None,
                        max_keepalive_connections=None,
                    ),
                ),
                "host": host,
                "port": port,
                "id": i,
            }
        )

    # Initialize round-robin iterators
    app.state.prefill_iterator = itertools.cycle(range(len(app.state.prefill_clients)))
    app.state.decode_iterator = itertools.cycle(range(len(app.state.decode_clients)))

    print(
        f"Initialized {len(app.state.prefill_clients)} prefill clients "
        f"and {len(app.state.decode_clients)} decode clients."
    )

    yield

    # Shutdown: Close all clients
    for client_info in app.state.prefill_clients:
        await client_info["client"].aclose()

    for client_info in app.state.decode_clients:
        await client_info["client"].aclose()


# Update FastAPI app initialization to use lifespan
app = FastAPI(lifespan=lifespan)


def parse_args():
    parser = argparse.ArgumentParser()

    parser.add_argument("--port", type=int, default=8000)
    # Always use 127.0.0.1 as localhost binds to IPv6 which is blocked on CI
    parser.add_argument("--host", type=str, default="127.0.0.1")

    # For prefiller instances
    parser.add_argument(
        "--prefiller-hosts",
        "--prefiller-host",
        type=str,
        nargs="+",
        default=["localhost"],
    )
    parser.add_argument(
        "--prefiller-ports", "--prefiller-port", type=int, nargs="+", default=[8100]
    )

    # For decoder instances
    parser.add_argument(
        "--decoder-hosts", "--decoder-host", type=str, nargs="+", default=["localhost"]
    )
    parser.add_argument(
        "--decoder-ports", "--decoder-port", type=int, nargs="+", default=[8200]
    )

    args = parser.parse_args()

    # Validate and pair hosts with ports
    if len(args.prefiller_hosts) != len(args.prefiller_ports):
        raise ValueError(
            "Number of prefiller hosts must match number of prefiller ports"
        )

    if len(args.decoder_hosts) != len(args.decoder_ports):
        raise ValueError("Number of decoder hosts must match number of decoder ports")

    # Create tuples of (host, port) for each service type
    args.prefiller_instances = list(zip(args.prefiller_hosts, args.prefiller_ports))
    args.decoder_instances = list(zip(args.decoder_hosts, args.decoder_ports))

    return args


def get_next_client(app, service_type: str):
    """
    Get the next client in round-robin fashion.

    Args:
        app: The FastAPI app instance
        service_type: Either 'prefill' or 'decode'

    Returns:
        The next client to use
    """
    if service_type == "prefill":
        client_idx = next(app.state.prefill_iterator)
        return app.state.prefill_clients[client_idx]
    elif service_type == "decode":
        client_idx = next(app.state.decode_iterator)
        return app.state.decode_clients[client_idx]
    else:
        raise ValueError(f"Unknown service type: {service_type}")


async def send_request_to_service(
    client_info: dict, endpoint: str, req_data: dict, request_id: str
):
    """
    Send a request to a service using a client from the pool.
    """
    req_data = req_data.copy()
    req_data["kv_transfer_params"] = {
        "do_remote_decode": True,
        "do_remote_prefill": False,
        "remote_engine_id": None,
        "remote_block_ids": None,
        "remote_host": None,
        "remote_port": None,
    }
    req_data["stream"] = False
    req_data["max_tokens"] = 1
    if "max_completion_tokens" in req_data:
        req_data["max_completion_tokens"] = 1
    if "stream_options" in req_data:
        del req_data["stream_options"]
    headers = {
        "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}",
        "X-Request-Id": request_id,
    }

    response = await client_info["client"].post(
        endpoint, json=req_data, headers=headers
    )
    response.raise_for_status()

    # Read/consume the response body to release the connection;
    # otherwise httpx raises a ReadError later.
    await response.aread()

    return response


async def stream_service_response(
    client_info: dict, endpoint: str, req_data: dict, request_id: str
):
    """
    Asynchronously stream response from a service using a client from the pool.
    """
    headers = {
        "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}",
        "X-Request-Id": request_id,
    }

    async with client_info["client"].stream(
        "POST", endpoint, json=req_data, headers=headers
    ) as response:
        response.raise_for_status()
        async for chunk in response.aiter_bytes():
            yield chunk


async def _handle_completions(api: str, request: Request):
    try:
        req_data = await request.json()
        request_id = str(uuid.uuid4())

        # Get the next prefill client in round-robin fashion
        prefill_client_info = get_next_client(request.app, "prefill")

        # Send request to prefill service
        response = await send_request_to_service(
            prefill_client_info, api, req_data, request_id
        )

        # Extract the needed fields
        response_json = response.json()
        await response.aclose()  # CRITICAL: Release connection back to pool
        kv_transfer_params = response_json.get("kv_transfer_params", {})
        if kv_transfer_params:
            req_data["kv_transfer_params"] = kv_transfer_params

        # Get the next decode client in round-robin fashion
        decode_client_info = get_next_client(request.app, "decode")

        logger.debug("Using %s %s", prefill_client_info, decode_client_info)

        # Stream response from decode service
        async def generate_stream():
            async for chunk in stream_service_response(
                decode_client_info, api, req_data, request_id=request_id
            ):
                yield chunk

        return StreamingResponse(generate_stream(), media_type="application/json")

    except Exception as e:
        import sys
        import traceback

        exc_info = sys.exc_info()
        print(f"Error occurred in disagg prefill proxy server - {api} endpoint")
        print(e)
        print("".join(traceback.format_exception(*exc_info)))
        raise


@app.post("/v1/completions")
async def handle_completions(request: Request):
    return await _handle_completions("/completions", request)


@app.post("/v1/chat/completions")
async def handle_chat_completions(request: Request):
    return await _handle_completions("/chat/completions", request)


@app.get("/healthcheck")
async def healthcheck():
    """Simple endpoint to check if the server is running."""
    return {
        "status": "ok",
        "prefill_instances": len(app.state.prefill_clients),
        "decode_instances": len(app.state.decode_clients),
    }


if __name__ == "__main__":
    global global_args
    global_args = parse_args()

    import uvicorn

    uvicorn.run(app, host=global_args.host, port=global_args.port)
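The proxy can also be started by hand in front of existing prefill/decode servers; a minimal sketch using the default test ports (all values are illustrative):

python3 tests/v1/kv_connector/nixl_integration/toy_proxy_server.py --port 8192 \
  --prefiller-hosts localhost --prefiller-ports 8100 \
  --decoder-hosts localhost --decoder-ports 8200

# instance counts are reported on the health endpoint
curl -s http://localhost:8192/healthcheck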
41
tests/v1/kv_connector/nixl_integration/tp_config_sweep_accuracy_test.sh
Executable file
@@ -0,0 +1,41 @@
#!/usr/bin/env bash
set -euo pipefail

# Utility to run integration tests sequentially with varying TP configurations.
SCRIPT="v1/kv_connector/nixl_integration/run_accuracy_test.sh"

# Define test configurations
configs=(
  "GPU_MEMORY_UTILIZATION=0.6 PREFILLER_TP_SIZE=2 DECODER_TP_SIZE=2"
  "GPU_MEMORY_UTILIZATION=0.6 PREFILLER_TP_SIZE=1 DECODER_TP_SIZE=2"
  "GPU_MEMORY_UTILIZATION=0.8 MODEL_NAMES=deepseek-ai/deepseek-vl2-tiny" # MLA case
  "GPU_MEMORY_UTILIZATION=0.8 PREFILLER_TP_SIZE=1 DECODER_TP_SIZE=2 MODEL_NAMES=deepseek-ai/deepseek-vl2-tiny"
  "DP_EP=1 GPU_MEMORY_UTILIZATION=0.8 PREFILLER_TP_SIZE=1 DECODER_TP_SIZE=2 MODEL_NAMES=deepseek-ai/deepseek-vl2-tiny" # MLA+P-TP1, D-DPEP=2 (TP=1)
)

run_tests() {
  local label=$1
  local extra_env=$2

  echo "=== Running tests (${label}) ==="
  for cfg in "${configs[@]}"; do
    echo "-> Running with ${cfg} ${extra_env:+and ${extra_env}}"
    # Use 'env' to safely set variables without eval
    if ! env ${extra_env} ${cfg} bash "${SCRIPT}"; then
      echo "❌ Test failed for config: ${cfg} ${extra_env:+(${extra_env})}"
      exit 1
    fi
  done
  echo "✅ All ${label} tests passed!"
}

# Run tests
run_tests "default backend" ""

# Check if FLASHINFER is set (non-empty)
if [[ -n "${FLASHINFER:-}" ]]; then
  echo "FLASHINFER is set, rerunning with VLLM_ATTENTION_BACKEND=FLASHINFER"
  run_tests "FLASHINFER backend" "VLLM_ATTENTION_BACKEND=FLASHINFER"
else
  echo "FLASHINFER not set, skipping FLASHINFER runs."
fi
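SCRIPT is a path relative to the tests/ directory, so the sweep is presumably run from there; an illustrative invocation that also triggers the FLASHINFER pass:

cd tests
FLASHINFER=1 bash v1/kv_connector/nixl_integration/tp_config_sweep_accuracy_test.sh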