Sync from v0.13
This commit is contained in:
0
tests/v1/kv_connector/__init__.py
Normal file
0
tests/v1/kv_connector/__init__.py
Normal file
257
tests/v1/kv_connector/nixl_integration/run_accuracy_test.sh
Executable file
257
tests/v1/kv_connector/nixl_integration/run_accuracy_test.sh
Executable file
@@ -0,0 +1,257 @@
|
||||
#!/bin/bash
|
||||
set -xe
|
||||
|
||||
# Parse command line arguments
|
||||
KV_BUFFER_DEVICE="cuda" # Default to cuda
|
||||
while [[ $# -gt 0 ]]; do
|
||||
case $1 in
|
||||
--kv_buffer_device)
|
||||
KV_BUFFER_DEVICE="$2"
|
||||
shift 2
|
||||
;;
|
||||
*)
|
||||
echo "Unknown option $1"
|
||||
echo "Usage: $0 [--kv_buffer_device <cuda|cpu>]"
|
||||
exit 1
|
||||
;;
|
||||
esac
|
||||
done
|
||||
|
||||
echo "Running accuracy tests with kv_buffer_device=$KV_BUFFER_DEVICE"
|
||||
|
||||
DECODER_KV_LAYOUT=${DECODER_KV_LAYOUT:-"HND"} # Default to HND, optional NHD
|
||||
if [[ "$DECODER_KV_LAYOUT" == "NHD" ]]; then
|
||||
KV_CONFIG_HETERO_LAYOUT=',"enable_permute_local_kv":"True"'
|
||||
else
|
||||
KV_CONFIG_HETERO_LAYOUT=''
|
||||
fi
|
||||
|
||||
# Build the kv-transfer-config once
|
||||
if [[ "$KV_BUFFER_DEVICE" == "cuda" ]]; then
|
||||
KV_CONFIG='{"kv_connector":"NixlConnector","kv_role":"kv_both"'${KV_CONFIG_HETERO_LAYOUT}'}'
|
||||
else
|
||||
KV_CONFIG="{\"kv_connector\":\"NixlConnector\",\"kv_role\":\"kv_both\",\"kv_buffer_device\":\"$KV_BUFFER_DEVICE\""${KV_CONFIG_HETERO_LAYOUT}"}"
|
||||
fi
|
||||
|
||||
# Models to run
|
||||
MODEL_NAMES=${MODEL_NAMES:-}
|
||||
if [[ -n "$MODEL_NAMES" ]]; then
|
||||
MODELS=("$MODEL_NAMES")
|
||||
else
|
||||
MODELS=(
|
||||
"Qwen/Qwen3-0.6B"
|
||||
)
|
||||
fi
|
||||
|
||||
# Number of prefill and decode instances to create
|
||||
NUM_PREFILL_INSTANCES=${NUM_PREFILL_INSTANCES:-1} # Default to 1
|
||||
NUM_DECODE_INSTANCES=${NUM_DECODE_INSTANCES:-1} # Default to 1
|
||||
PREFILLER_TP_SIZE=${PREFILLER_TP_SIZE:-1}
|
||||
DECODER_TP_SIZE=${DECODER_TP_SIZE:-1}
|
||||
GPU_MEMORY_UTILIZATION=${GPU_MEMORY_UTILIZATION:-0.2}
|
||||
PREFILL_BLOCK_SIZE=${PREFILL_BLOCK_SIZE:-128}
|
||||
DECODE_BLOCK_SIZE=${DECODE_BLOCK_SIZE:-128}
|
||||
|
||||
# Find the git repository root directory
|
||||
GIT_ROOT=$(git rev-parse --show-toplevel)
|
||||
|
||||
SMI_BIN=$(which nvidia-smi || which rocm-smi || echo "")
|
||||
|
||||
# Trap the SIGINT signal (triggered by Ctrl+C)
|
||||
trap 'kill $(jobs -pr)' SIGINT SIGTERM EXIT
|
||||
|
||||
# Waits for vLLM to start.
|
||||
wait_for_server() {
|
||||
local port=$1
|
||||
timeout 1200 bash -c "
|
||||
until curl -s localhost:${port}/v1/completions > /dev/null; do
|
||||
sleep 1
|
||||
done" && return 0 || return 1
|
||||
}
|
||||
|
||||
# Function to clean up previous instances
|
||||
cleanup_instances() {
|
||||
echo "Cleaning up any running vLLM instances..."
|
||||
pkill -f "vllm serve" || true
|
||||
sleep 2
|
||||
}
|
||||
|
||||
# Handle to get model-specific arguments for deepseek
|
||||
get_model_args() {
|
||||
local model_name=$1
|
||||
local extra_args=""
|
||||
|
||||
if [[ "$model_name" == "deepseek-ai/deepseek-vl2-tiny" ]]; then
|
||||
extra_args="--hf_overrides '{\"architectures\": [\"DeepseekVLV2ForCausalLM\"]}' --trust-remote-code"
|
||||
fi
|
||||
|
||||
echo "$extra_args"
|
||||
}
|
||||
|
||||
get_num_gpus() {
|
||||
if [[ "$SMI_BIN" == *"nvidia"* ]]; then
|
||||
echo "$($SMI_BIN --query-gpu=name --format=csv,noheader | wc -l)"
|
||||
elif [[ "$SMI_BIN" == *"rocm"* ]]; then
|
||||
echo "$($SMI_BIN -l | grep GPU | wc -l)"
|
||||
else
|
||||
# works for non-cuda platforms,
|
||||
# assuming at least 1 device and
|
||||
# let system to decide which card to use
|
||||
echo "1"
|
||||
fi
|
||||
}
|
||||
|
||||
# Function to run tests for a specific model
|
||||
run_tests_for_model() {
|
||||
local model_name=$1
|
||||
echo "================================"
|
||||
echo "Testing model: $model_name"
|
||||
echo "================================"
|
||||
|
||||
# Get model-specific arguments
|
||||
local model_args=$(get_model_args "$model_name")
|
||||
|
||||
# Arrays to store all hosts and ports
|
||||
PREFILL_HOSTS=()
|
||||
PREFILL_PORTS=()
|
||||
DECODE_HOSTS=()
|
||||
DECODE_PORTS=()
|
||||
|
||||
# Start prefill instances
|
||||
for i in $(seq 0 $((NUM_PREFILL_INSTANCES-1))); do
|
||||
# Calculate GPU ID - we'll distribute across available GPUs
|
||||
GPU_ID=$((i % $(get_num_gpus)))
|
||||
NEXT_GPU=${GPU_ID}
|
||||
# If PREFILLER_TP_SIZE is more than 1
|
||||
for (( j=1; j < PREFILLER_TP_SIZE; j++ )); do
|
||||
NEXT_GPU=$(((GPU_ID + j) % $(get_num_gpus)))
|
||||
GPU_ID="${GPU_ID},${NEXT_GPU}"
|
||||
done
|
||||
|
||||
# Calculate port number (base port + instance number)
|
||||
PORT=$((8100 + i))
|
||||
# Calculate side channel port. Avoid clash with with TP workers.
|
||||
SIDE_CHANNEL_PORT=$((5559 + i))
|
||||
|
||||
echo "Starting prefill instance $i on GPU $GPU_ID, port $PORT"
|
||||
|
||||
# Build the command with or without model-specific args
|
||||
BASE_CMD="CUDA_VISIBLE_DEVICES=$GPU_ID \
|
||||
VLLM_KV_CACHE_LAYOUT='HND' \
|
||||
UCX_NET_DEVICES=all \
|
||||
VLLM_NIXL_SIDE_CHANNEL_PORT=$SIDE_CHANNEL_PORT \
|
||||
vllm serve $model_name \
|
||||
--port $PORT \
|
||||
--enforce-eager \
|
||||
--block-size ${PREFILL_BLOCK_SIZE} \
|
||||
--gpu-memory-utilization $GPU_MEMORY_UTILIZATION \
|
||||
--tensor-parallel-size $PREFILLER_TP_SIZE \
|
||||
--kv-transfer-config '$KV_CONFIG'"
|
||||
|
||||
if [ -n "$model_args" ]; then
|
||||
FULL_CMD="$BASE_CMD $model_args"
|
||||
else
|
||||
FULL_CMD="$BASE_CMD"
|
||||
fi
|
||||
|
||||
eval "$FULL_CMD &"
|
||||
|
||||
# Store host and port for proxy configuration
|
||||
PREFILL_HOSTS+=("localhost")
|
||||
PREFILL_PORTS+=($PORT)
|
||||
done
|
||||
|
||||
# Start decode instances
|
||||
for i in $(seq 0 $((NUM_DECODE_INSTANCES-1))); do
|
||||
# Calculate GPU ID - we'll distribute across available GPUs, starting from after prefill GPUs
|
||||
GPU_ID=$(((i + NEXT_GPU + 1) % $(get_num_gpus)))
|
||||
# If DECODER_TP_SIZE is more than 1
|
||||
for (( j=1; j < DECODER_TP_SIZE; j++ )); do
|
||||
NEXT_GPU=$(((GPU_ID + j) % $(get_num_gpus)))
|
||||
GPU_ID="${GPU_ID},${NEXT_GPU}"
|
||||
done
|
||||
# Calculate port number (base port + instance number)
|
||||
PORT=$((8200 + i))
|
||||
# Calculate side channel port
|
||||
SIDE_CHANNEL_PORT=$((5659 + i * $DECODER_TP_SIZE))
|
||||
|
||||
echo "Starting decode instance $i on GPU $GPU_ID, port $PORT"
|
||||
|
||||
# Build the command with or without model-specific args
|
||||
BASE_CMD="CUDA_VISIBLE_DEVICES=$GPU_ID \
|
||||
VLLM_KV_CACHE_LAYOUT=$DECODER_KV_LAYOUT \
|
||||
UCX_NET_DEVICES=all \
|
||||
VLLM_NIXL_SIDE_CHANNEL_PORT=$SIDE_CHANNEL_PORT \
|
||||
vllm serve $model_name \
|
||||
--port $PORT \
|
||||
--enforce-eager \
|
||||
--block-size ${DECODE_BLOCK_SIZE} \
|
||||
--gpu-memory-utilization $GPU_MEMORY_UTILIZATION \
|
||||
--kv-transfer-config '$KV_CONFIG'"
|
||||
|
||||
# DP-EP attention mode
|
||||
if [[ -z "$DP_EP" ]]; then
|
||||
BASE_CMD="${BASE_CMD} --tensor-parallel-size $DECODER_TP_SIZE"
|
||||
else
|
||||
echo "DP-EP Attention enabled, deploying with dp=DECODER_TP_SIZE and tp=1"
|
||||
BASE_CMD="${BASE_CMD} --data-parallel-size $DECODER_TP_SIZE \
|
||||
--tensor-parallel-size 1 --enable-expert-parallel"
|
||||
fi
|
||||
|
||||
if [ -n "$model_args" ]; then
|
||||
FULL_CMD="$BASE_CMD $model_args"
|
||||
else
|
||||
FULL_CMD="$BASE_CMD"
|
||||
fi
|
||||
|
||||
eval "$FULL_CMD &"
|
||||
|
||||
# Store host and port for proxy configuration
|
||||
DECODE_HOSTS+=("localhost")
|
||||
DECODE_PORTS+=($PORT)
|
||||
done
|
||||
|
||||
# Wait for all instances to start
|
||||
for PORT in "${PREFILL_PORTS[@]}"; do
|
||||
echo "Waiting for prefill instance on port $PORT to start..."
|
||||
wait_for_server $PORT
|
||||
done
|
||||
|
||||
for PORT in "${DECODE_PORTS[@]}"; do
|
||||
echo "Waiting for decode instance on port $PORT to start..."
|
||||
wait_for_server $PORT
|
||||
done
|
||||
|
||||
# Build the command for the proxy server with all the hosts and ports
|
||||
PROXY_CMD="python3 ${GIT_ROOT}/tests/v1/kv_connector/nixl_integration/toy_proxy_server.py --port 8192"
|
||||
|
||||
# Add all prefill hosts and ports
|
||||
PROXY_CMD+=" --prefiller-hosts ${PREFILL_HOSTS[@]}"
|
||||
PROXY_CMD+=" --prefiller-ports ${PREFILL_PORTS[@]}"
|
||||
|
||||
# Add all decode hosts and ports
|
||||
PROXY_CMD+=" --decoder-hosts ${DECODE_HOSTS[@]}"
|
||||
PROXY_CMD+=" --decoder-ports ${DECODE_PORTS[@]}"
|
||||
|
||||
# Start the proxy server
|
||||
echo "Starting proxy server with command: $PROXY_CMD"
|
||||
$PROXY_CMD &
|
||||
|
||||
# Wait for the proxy to start
|
||||
sleep 5
|
||||
|
||||
# Run lm eval for this model
|
||||
echo "Running tests for $model_name"
|
||||
TEST_MODEL=$model_name python3 -m pytest -s -x ${GIT_ROOT}/tests/v1/kv_connector/nixl_integration/test_accuracy.py
|
||||
|
||||
# Clean up before running next model
|
||||
cleanup_instances
|
||||
sleep 3
|
||||
}
|
||||
|
||||
# Run tests for each model
|
||||
for model in "${MODELS[@]}"; do
|
||||
run_tests_for_model "$model"
|
||||
done
|
||||
|
||||
echo "All tests completed!"
|
||||
148
tests/v1/kv_connector/nixl_integration/run_edge_case_test.sh
Executable file
148
tests/v1/kv_connector/nixl_integration/run_edge_case_test.sh
Executable file
@@ -0,0 +1,148 @@
|
||||
#!/bin/bash
|
||||
set -xe
|
||||
|
||||
# Parse command line arguments
|
||||
KV_BUFFER_DEVICE="cuda" # Default to cuda
|
||||
PREFILL_GPU_ID=4 # Default GPU IDs
|
||||
DECODE_GPU_ID=5
|
||||
while [[ $# -gt 0 ]]; do
|
||||
case $1 in
|
||||
--kv_buffer_device)
|
||||
KV_BUFFER_DEVICE="$2"
|
||||
shift 2
|
||||
;;
|
||||
*)
|
||||
echo "Unknown option $1"
|
||||
echo "Usage: $0 [--kv_buffer_device <cuda|cpu>]"
|
||||
exit 1
|
||||
;;
|
||||
esac
|
||||
done
|
||||
|
||||
echo "Running edge case tests with kv_buffer_device=$KV_BUFFER_DEVICE (GPUs: $PREFILL_GPU_ID, $DECODE_GPU_ID)"
|
||||
|
||||
# Build the kv-transfer-config once
|
||||
if [[ "$KV_BUFFER_DEVICE" == "cuda" ]]; then
|
||||
KV_CONFIG='{"kv_connector":"NixlConnector","kv_role":"kv_both"}'
|
||||
else
|
||||
KV_CONFIG="{\"kv_connector\":\"NixlConnector\",\"kv_role\":\"kv_both\",\"kv_buffer_device\":\"$KV_BUFFER_DEVICE\"}"
|
||||
fi
|
||||
|
||||
# Models to run
|
||||
MODELS=(
|
||||
"Qwen/Qwen3-0.6B"
|
||||
)
|
||||
|
||||
# Find the git repository root directory
|
||||
GIT_ROOT=$(git rev-parse --show-toplevel)
|
||||
|
||||
# Trap the SIGINT signal (triggered by Ctrl+C)
|
||||
trap 'kill $(jobs -pr)' SIGINT SIGTERM EXIT
|
||||
|
||||
# Waits for vLLM to start.
|
||||
wait_for_server() {
|
||||
local port=$1
|
||||
timeout 1200 bash -c "
|
||||
until curl -s localhost:${port}/v1/completions > /dev/null; do
|
||||
sleep 1
|
||||
done" && return 0 || return 1
|
||||
}
|
||||
|
||||
# Function to clean up previous instances
|
||||
cleanup_instances() {
|
||||
echo "Cleaning up any running vLLM instances..."
|
||||
pkill -f "vllm serve" || true
|
||||
sleep 2
|
||||
}
|
||||
|
||||
# Handle to get model-specific arguments for deepseek
|
||||
get_model_args() {
|
||||
local model_name=$1
|
||||
local extra_args=""
|
||||
|
||||
if [[ "$model_name" == "deepseek-ai/deepseek-vl2-tiny" ]]; then
|
||||
extra_args="--hf_overrides '{\"architectures\": [\"DeepseekVLV2ForCausalLM\"]}' --trust-remote-code"
|
||||
fi
|
||||
|
||||
echo "$extra_args"
|
||||
}
|
||||
|
||||
|
||||
# Function to run tests for a specific model
|
||||
run_tests_for_model() {
|
||||
local model_name=$1
|
||||
echo "================================"
|
||||
echo "Testing model: $model_name"
|
||||
echo "================================"
|
||||
|
||||
# Get model-specific arguments
|
||||
local model_args=$(get_model_args "$model_name")
|
||||
|
||||
# Start prefill instance
|
||||
PREFILL_PORT=8001
|
||||
|
||||
BASE_CMD="CUDA_VISIBLE_DEVICES=$PREFILL_GPU_ID VLLM_NIXL_SIDE_CHANNEL_PORT=5559 vllm serve $model_name \
|
||||
--port $PREFILL_PORT \
|
||||
--enforce-eager \
|
||||
--gpu-memory-utilization 0.2 \
|
||||
--kv-transfer-config '$KV_CONFIG'"
|
||||
|
||||
if [ -n "$model_args" ]; then
|
||||
FULL_CMD="$BASE_CMD $model_args"
|
||||
else
|
||||
FULL_CMD="$BASE_CMD"
|
||||
fi
|
||||
|
||||
eval "$FULL_CMD &"
|
||||
|
||||
# Start decode instance
|
||||
DECODE_PORT=8002
|
||||
|
||||
# Build the command with or without model-specific args
|
||||
BASE_CMD="CUDA_VISIBLE_DEVICES=$DECODE_GPU_ID VLLM_NIXL_SIDE_CHANNEL_PORT=6000 vllm serve $model_name \
|
||||
--port $DECODE_PORT \
|
||||
--enforce-eager \
|
||||
--gpu-memory-utilization 0.2 \
|
||||
--kv-transfer-config '$KV_CONFIG'"
|
||||
|
||||
if [ -n "$model_args" ]; then
|
||||
FULL_CMD="$BASE_CMD $model_args"
|
||||
else
|
||||
FULL_CMD="$BASE_CMD"
|
||||
fi
|
||||
|
||||
eval "$FULL_CMD &"
|
||||
|
||||
# Wait for all instances to start
|
||||
echo "Waiting for prefill instance on port $PORT to start..."
|
||||
wait_for_server $PREFILL_PORT
|
||||
echo "Waiting for decode instance on port $PORT to start..."
|
||||
wait_for_server $DECODE_PORT
|
||||
|
||||
# Build the command for the proxy server with all the hosts and ports
|
||||
PROXY_PORT=8192
|
||||
PROXY_CMD="python ${GIT_ROOT}/tests/v1/kv_connector/nixl_integration/toy_proxy_server.py --port $PROXY_PORT"
|
||||
PROXY_CMD+=" --prefiller-ports ${PREFILL_PORT}"
|
||||
PROXY_CMD+=" --decoder-ports ${DECODE_PORT}"
|
||||
# Start the proxy server
|
||||
echo "Starting proxy server with command: $PROXY_CMD"
|
||||
$PROXY_CMD &
|
||||
|
||||
# Wait for the proxy to start
|
||||
sleep 5
|
||||
|
||||
# Run lm eval for this model
|
||||
echo "Running tests for $model_name"
|
||||
PREFILL_PORT=$PREFILL_PORT DECODE_PORT=$DECODE_PORT PROXY_PORT=$PROXY_PORT python -m pytest -s -v ${GIT_ROOT}/tests/v1/kv_connector/nixl_integration/test_edge_cases.py
|
||||
|
||||
# Clean up before running next model
|
||||
cleanup_instances
|
||||
sleep 3
|
||||
}
|
||||
|
||||
# Run tests for each model
|
||||
for model in "${MODELS[@]}"; do
|
||||
run_tests_for_model "$model"
|
||||
done
|
||||
|
||||
echo "All tests completed!"
|
||||
@@ -0,0 +1,156 @@
|
||||
#!/bin/bash
|
||||
set -xe
|
||||
|
||||
# Hosts / ports
|
||||
PREFILL_HOST=${PREFILL_HOST:-"localhost"}
|
||||
PREFILL_PORT=${PREFILL_PORT:-8100}
|
||||
PREFILL_NIXL_SIDE_PORT=${PREFILL_NIXL_SIDE_PORT:-5577}
|
||||
DECODE_HOST=${DECODE_HOST:-"localhost"}
|
||||
DECODE_PORT=${DECODE_PORT:-8200}
|
||||
PROXY_HOST=${PROXY_HOST:-"localhost"}
|
||||
PROXY_PORT=${PROXY_PORT:-8192}
|
||||
BASELINE_HOST=${BASELINE_HOST:-"localhost"}
|
||||
BASELINE_PORT=${BASELINE_PORT:-9290}
|
||||
|
||||
|
||||
# Model to run.
|
||||
MODEL_NAME=${MODEL_NAME:-"meta-llama/Llama-3.2-3B-Instruct"}
|
||||
MAX_MODEL_LEN=${MAX_MODEL_LEN:-1024}
|
||||
BLOCK_SIZE=${BLOCK_SIZE:-32}
|
||||
|
||||
|
||||
# execution env
|
||||
GIT_ROOT=$(git rev-parse --show-toplevel)
|
||||
EXP_ROOT="${GIT_ROOT}/tests/v1/kv_connector/nixl_integration"
|
||||
CONDA_PATH=${CONDA_PATH:-"/home/${USER}/anaconda3"}
|
||||
CONDA_ENV_NAME=${CONDA_ENV_NAME:-"nixl"}
|
||||
|
||||
OUTPUT_FILE=${OUTPUT_FILE:-"${EXP_ROOT}/.tpu_accuracy_test_outputs.txt"}
|
||||
|
||||
# Trap the SIGINT signal (triggered by Ctrl+C)
|
||||
trap 'kill $(jobs -pr)' SIGINT SIGTERM EXIT
|
||||
|
||||
|
||||
# Waits for vLLM server to start.
|
||||
wait_for_server() {
|
||||
local host=$1
|
||||
local port=$2
|
||||
timeout 1200 bash -c "
|
||||
until curl -s ${host}:${port}/v1/completions > /dev/null; do
|
||||
sleep 1
|
||||
done" && return 0 || return 1
|
||||
}
|
||||
|
||||
# Cleanup function
|
||||
cleanup() {
|
||||
echo "Caught Ctrl+C, cleaning up..."
|
||||
# Cleanup commands
|
||||
pgrep python | xargs kill -9 || true
|
||||
# pkill -f python || true
|
||||
echo "Cleanup complete. Exiting."
|
||||
}
|
||||
|
||||
launch_baseline() {
|
||||
BASELINE_BASE_CMD="source ${CONDA_PATH}/bin/activate ${CONDA_ENV_NAME};
|
||||
VLLM_LOGGING_LEVEL=DEBUG \
|
||||
PJRT_DEVICE=TPU \
|
||||
VLLM_WORKER_MULTIPROC_METHOD=spawn \
|
||||
VLLM_ENABLE_V1_MULTIPROCESSING=0 vllm serve $MODEL_NAME \
|
||||
--host ${BASELINE_HOST} \
|
||||
--port ${BASELINE_PORT} \
|
||||
--max-model-len ${MAX_MODEL_LEN}\
|
||||
--seed 42 \
|
||||
--block-size ${BLOCK_SIZE} \
|
||||
--gpu-memory-utilization 0.5 \
|
||||
--enforce-eager"
|
||||
echo ${BASELINE_BASE_CMD}
|
||||
ssh -tt ${BASELINE_HOST} "${BASELINE_BASE_CMD}" &
|
||||
}
|
||||
|
||||
launch_pd() {
|
||||
PREFILL_BASE_CMD="source ${CONDA_PATH}/bin/activate ${CONDA_ENV_NAME};
|
||||
UCX_TLS=tcp \
|
||||
VLLM_MULTIPROC_EXECUTE_MODEL_TIMEOUT_S=200 \
|
||||
VLLM_LOGGING_LEVEL=DEBUG \
|
||||
VLLM_NIXL_SIDE_CHANNEL_HOST=${PREFILL_HOST} \
|
||||
VLLM_NIXL_SIDE_CHANNEL_PORT=${PREFILL_NIXL_SIDE_PORT} \
|
||||
PJRT_DEVICE=TPU \
|
||||
VLLM_WORKER_MULTIPROC_METHOD=spawn \
|
||||
VLLM_ENABLE_V1_MULTIPROCESSING=0 vllm serve $MODEL_NAME \
|
||||
--host ${PREFILL_HOST} \
|
||||
--port ${PREFILL_PORT} \
|
||||
--max-model-len ${MAX_MODEL_LEN}\
|
||||
--seed 42 \
|
||||
--block-size ${BLOCK_SIZE} \
|
||||
--enforce-eager \
|
||||
--gpu-memory-utilization 0.5 \
|
||||
--kv-transfer-config '{\"kv_connector\":\"NixlConnector\",\"kv_role\":\"kv_both\",\"kv_buffer_device\":\"cpu\"}'"
|
||||
|
||||
|
||||
DECODE_BASE_CMD="source ${CONDA_PATH}/bin/activate ${CONDA_ENV_NAME};
|
||||
UCX_TLS=tcp \
|
||||
VLLM_MULTIPROC_EXECUTE_MODEL_TIMEOUT_S=200 \
|
||||
VLLM_LOGGING_LEVEL=DEBUG \
|
||||
PJRT_DEVICE=TPU \
|
||||
VLLM_WORKER_MULTIPROC_METHOD=spawn \
|
||||
VLLM_ENABLE_V1_MULTIPROCESSING=0 vllm serve $MODEL_NAME \
|
||||
--host ${DECODE_HOST} \
|
||||
--port ${DECODE_PORT} \
|
||||
--max-model-len ${MAX_MODEL_LEN}\
|
||||
--seed 42 \
|
||||
--block-size ${BLOCK_SIZE} \
|
||||
--enforce-eager \
|
||||
--gpu-memory-utilization 0.5 \
|
||||
--kv-transfer-config '{\"kv_connector\":\"NixlConnector\",\"kv_role\":\"kv_both\",\"kv_buffer_device\":\"cpu\"}'"
|
||||
|
||||
echo ${PREFILL_BASE_CMD}
|
||||
echo ${DECODE_BASE_CMD}
|
||||
sleep 2
|
||||
|
||||
# execute on hosts
|
||||
ssh -tt ${PREFILL_HOST} "${PREFILL_BASE_CMD}" &
|
||||
ssh -tt ${DECODE_HOST} "${DECODE_BASE_CMD}" &
|
||||
sleep 1
|
||||
wait_for_server ${PREFILL_HOST} ${PREFILL_PORT}
|
||||
sleep 1
|
||||
wait_for_server ${DECODE_HOST} ${DECODE_PORT}
|
||||
sleep 1
|
||||
}
|
||||
|
||||
launch_pd_proxy(){
|
||||
PROXY_BASE_CMD="source ${CONDA_PATH}/bin/activate ${CONDA_ENV_NAME};
|
||||
python3 ${EXP_ROOT}/toy_proxy_server.py \
|
||||
--prefiller-host ${PREFILL_HOST} --prefiller-port ${PREFILL_PORT} \
|
||||
--decoder-host ${DECODE_HOST} --decoder-port ${DECODE_PORT} \
|
||||
--host=${PROXY_HOST} --port ${PROXY_PORT}"
|
||||
echo ${PROXY_BASE_CMD}
|
||||
ssh -tt ${PROXY_HOST} "${PROXY_BASE_CMD}" &
|
||||
}
|
||||
|
||||
run_tests(){
|
||||
local service_url=$1
|
||||
local mode=$2
|
||||
python3 ${EXP_ROOT}/test_disagg_accuracy.py --service_url=${service_url} --model_name=${MODEL_NAME} --mode=${mode} --file_name=${OUTPUT_FILE}
|
||||
}
|
||||
|
||||
|
||||
# run non-disagg. baseline & save outputs
|
||||
launch_baseline
|
||||
sleep 2
|
||||
wait_for_server ${BASELINE_HOST} ${BASELINE_PORT}
|
||||
run_tests "http://${BASELINE_HOST}:${BASELINE_PORT}" "baseline"
|
||||
cleanup
|
||||
sleep 10
|
||||
|
||||
|
||||
# run disagg. & do exact-match with the outputs from baseline
|
||||
launch_pd
|
||||
launch_pd_proxy
|
||||
sleep 10
|
||||
run_tests "http://${PROXY_HOST}:${PROXY_PORT}" "disagg"
|
||||
echo "-----P/D success----"
|
||||
|
||||
rm ${OUTPUT_FILE}
|
||||
cleanup
|
||||
|
||||
exit 0
|
||||
124
tests/v1/kv_connector/nixl_integration/run_tpu_edge_case_test.sh
Normal file
124
tests/v1/kv_connector/nixl_integration/run_tpu_edge_case_test.sh
Normal file
@@ -0,0 +1,124 @@
|
||||
#!/bin/bash
|
||||
set -xe
|
||||
|
||||
# Hosts / ports
|
||||
PREFILL_HOST=${PREFILL_HOST:-"localhost"}
|
||||
PREFILL_PORT=${PREFILL_PORT:-8100}
|
||||
PREFILL_NIXL_SIDE_PORT=${PREFILL_NIXL_SIDE_PORT:-5577}
|
||||
DECODE_HOST=${DECODE_HOST:-"localhost"}
|
||||
DECODE_PORT=${DECODE_PORT:-8200}
|
||||
PROXY_HOST=${PROXY_HOST:-"localhost"}
|
||||
PROXY_PORT=${PROXY_PORT:-8192}
|
||||
BASELINE_HOST=${BASELINE_HOST:-"localhost"}
|
||||
BASELINE_PORT=${BASELINE_PORT:-9290}
|
||||
|
||||
|
||||
# Model to run.
|
||||
MODEL_NAME=${MODEL_NAME:-"meta-llama/Llama-3.2-3B-Instruct"}
|
||||
MAX_MODEL_LEN=${MAX_MODEL_LEN:-1024}
|
||||
BLOCK_SIZE=${BLOCK_SIZE:-32}
|
||||
|
||||
|
||||
# execution env
|
||||
GIT_ROOT=$(git rev-parse --show-toplevel)
|
||||
EXP_ROOT="${GIT_ROOT}/tests/v1/kv_connector/nixl_integration"
|
||||
CONDA_PATH=${CONDA_PATH:-"/home/${USER}/anaconda3"}
|
||||
CONDA_ENV_NAME=${CONDA_ENV_NAME:-"nixl"}
|
||||
|
||||
OUTPUT_FILE=${OUTPUT_FILE:-"${EXP_ROOT}/.tpu_accuracy_test_outputs.txt"}
|
||||
|
||||
# Trap the SIGINT signal (triggered by Ctrl+C)
|
||||
trap 'kill $(jobs -pr)' SIGINT SIGTERM EXIT
|
||||
|
||||
# Waits for vLLM server to start.
|
||||
wait_for_server() {
|
||||
local host=$1
|
||||
local port=$2
|
||||
timeout 1200 bash -c "
|
||||
until curl -s ${host}:${port}/v1/completions > /dev/null; do
|
||||
sleep 1
|
||||
done" && return 0 || return 1
|
||||
}
|
||||
|
||||
# Cleanup function
|
||||
cleanup() {
|
||||
echo "Caught Ctrl+C, cleaning up..."
|
||||
# Cleanup commands
|
||||
pgrep python | xargs kill -9 || true
|
||||
# pkill -f python || true
|
||||
echo "Cleanup complete. Exiting."
|
||||
}
|
||||
|
||||
|
||||
launch_pd() {
|
||||
PREFILL_BASE_CMD="source ${CONDA_PATH}/bin/activate ${CONDA_ENV_NAME};
|
||||
UCX_TLS=tcp \
|
||||
VLLM_MULTIPROC_EXECUTE_MODEL_TIMEOUT_S=200 \
|
||||
VLLM_LOGGING_LEVEL=DEBUG \
|
||||
VLLM_NIXL_SIDE_CHANNEL_HOST=${PREFILL_HOST} \
|
||||
VLLM_NIXL_SIDE_CHANNEL_PORT=${PREFILL_NIXL_SIDE_PORT} \
|
||||
PJRT_DEVICE=TPU \
|
||||
VLLM_WORKER_MULTIPROC_METHOD=spawn \
|
||||
VLLM_ENABLE_V1_MULTIPROCESSING=0 vllm serve $MODEL_NAME \
|
||||
--host ${PREFILL_HOST} \
|
||||
--port ${PREFILL_PORT} \
|
||||
--max-model-len ${MAX_MODEL_LEN}\
|
||||
--seed 42 \
|
||||
--block-size ${BLOCK_SIZE} \
|
||||
--enforce-eager \
|
||||
--gpu-memory-utilization 0.5 \
|
||||
--kv-transfer-config '{\"kv_connector\":\"NixlConnector\",\"kv_role\":\"kv_both\",\"kv_buffer_device\":\"cpu\"}'"
|
||||
|
||||
|
||||
DECODE_BASE_CMD="source ${CONDA_PATH}/bin/activate ${CONDA_ENV_NAME};
|
||||
UCX_TLS=tcp \
|
||||
VLLM_MULTIPROC_EXECUTE_MODEL_TIMEOUT_S=200 \
|
||||
VLLM_LOGGING_LEVEL=DEBUG \
|
||||
PJRT_DEVICE=TPU \
|
||||
VLLM_WORKER_MULTIPROC_METHOD=spawn \
|
||||
VLLM_ENABLE_V1_MULTIPROCESSING=0 vllm serve $MODEL_NAME \
|
||||
--host ${DECODE_HOST} \
|
||||
--port ${DECODE_PORT} \
|
||||
--max-model-len ${MAX_MODEL_LEN}\
|
||||
--seed 42 \
|
||||
--block-size ${BLOCK_SIZE} \
|
||||
--enforce-eager \
|
||||
--gpu-memory-utilization 0.5 \
|
||||
--kv-transfer-config '{\"kv_connector\":\"NixlConnector\",\"kv_role\":\"kv_both\",\"kv_buffer_device\":\"cpu\"}'"
|
||||
|
||||
echo ${PREFILL_BASE_CMD}
|
||||
echo ${DECODE_BASE_CMD}
|
||||
sleep 2
|
||||
|
||||
# execute on hosts
|
||||
ssh -tt ${PREFILL_HOST} "${PREFILL_BASE_CMD}" &
|
||||
ssh -tt ${DECODE_HOST} "${DECODE_BASE_CMD}" &
|
||||
sleep 1
|
||||
wait_for_server ${PREFILL_HOST} ${PREFILL_PORT}
|
||||
sleep 1
|
||||
wait_for_server ${DECODE_HOST} ${DECODE_PORT}
|
||||
sleep 1
|
||||
}
|
||||
|
||||
launch_pd_proxy(){
|
||||
PROXY_BASE_CMD="source ${CONDA_PATH}/bin/activate ${CONDA_ENV_NAME};
|
||||
python3 ${EXP_ROOT}/toy_proxy_server.py \
|
||||
--prefiller-host ${PREFILL_HOST} --prefiller-port ${PREFILL_PORT} \
|
||||
--decoder-host ${DECODE_HOST} --decoder-port ${DECODE_PORT} \
|
||||
--host=${PROXY_HOST} --port ${PROXY_PORT}"
|
||||
echo ${PROXY_BASE_CMD}
|
||||
ssh -tt ${PROXY_HOST} "${PROXY_BASE_CMD}" &
|
||||
}
|
||||
|
||||
|
||||
# run disagg. & do exact-match with the outputs from baseline
|
||||
launch_pd
|
||||
launch_pd_proxy
|
||||
sleep 10
|
||||
|
||||
PREFILL_HOST=${PREFILL_HOST} \
|
||||
PREFILL_PORT=${PREFILL_PORT} \
|
||||
DECODE_HOST=${DECODE_HOST} \
|
||||
DECODE_PORT=${DECODE_PORT} \
|
||||
PROXY_HOST=${PROXY_HOST} \
|
||||
PROXY_PORT=${PROXY_PORT} python -m pytest -s -v ${GIT_ROOT}/tests/v1/kv_connector/nixl_integration/test_edge_cases.py
|
||||
71
tests/v1/kv_connector/nixl_integration/test_accuracy.py
Normal file
71
tests/v1/kv_connector/nixl_integration/test_accuracy.py
Normal file
@@ -0,0 +1,71 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import os
|
||||
|
||||
import lm_eval
|
||||
import openai
|
||||
|
||||
BASE_URL = "http://localhost:8192/v1"
|
||||
NUM_CONCURRENT = 100
|
||||
TASK = "gsm8k"
|
||||
FILTER = "exact_match,strict-match"
|
||||
RTOL = 0.03
|
||||
|
||||
# Model-specific expected values
|
||||
EXPECTED_VALUES = {
|
||||
"Qwen/Qwen3-0.6B": 0.41,
|
||||
"deepseek-ai/deepseek-vl2-small": 0.59,
|
||||
"deepseek-ai/deepseek-vl2-tiny": 0.19,
|
||||
"deepseek-ai/DeepSeek-V2-Lite-Chat": 0.65,
|
||||
}
|
||||
|
||||
SIMPLE_PROMPT = (
|
||||
"The best part about working on vLLM is that I got to meet so many people across "
|
||||
"various different organizations like UCB, Google, and Meta which means",
|
||||
)
|
||||
|
||||
# Get model name from environment variable
|
||||
MODEL_NAME = os.environ.get("TEST_MODEL", "Qwen/Qwen3-0.6B")
|
||||
|
||||
|
||||
def run_simple_prompt():
|
||||
client = openai.OpenAI(api_key="EMPTY", base_url=BASE_URL)
|
||||
completion = client.completions.create(model=MODEL_NAME, prompt=SIMPLE_PROMPT)
|
||||
|
||||
print("-" * 50)
|
||||
print(f"Completion results for {MODEL_NAME}:")
|
||||
print(completion)
|
||||
print("-" * 50)
|
||||
|
||||
|
||||
def test_accuracy():
|
||||
"""Run the end to end accuracy test."""
|
||||
run_simple_prompt()
|
||||
|
||||
model_args = (
|
||||
f"model={MODEL_NAME},"
|
||||
f"base_url={BASE_URL}/completions,"
|
||||
f"num_concurrent={NUM_CONCURRENT},tokenized_requests=False"
|
||||
)
|
||||
|
||||
results = lm_eval.simple_evaluate(
|
||||
model="local-completions",
|
||||
model_args=model_args,
|
||||
tasks=TASK,
|
||||
)
|
||||
|
||||
measured_value = results["results"][TASK][FILTER]
|
||||
expected_value = EXPECTED_VALUES.get(MODEL_NAME)
|
||||
|
||||
if expected_value is None:
|
||||
print(
|
||||
f"Warning: No expected value found for {MODEL_NAME}. "
|
||||
"Skipping accuracy check."
|
||||
)
|
||||
print(f"Measured value: {measured_value}")
|
||||
return
|
||||
|
||||
assert (
|
||||
measured_value - RTOL < expected_value
|
||||
and measured_value + RTOL > expected_value
|
||||
), f"Expected: {expected_value} | Measured: {measured_value}"
|
||||
180
tests/v1/kv_connector/nixl_integration/test_disagg_accuracy.py
Normal file
180
tests/v1/kv_connector/nixl_integration/test_disagg_accuracy.py
Normal file
@@ -0,0 +1,180 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import argparse
|
||||
import json
|
||||
import os
|
||||
import time
|
||||
|
||||
import openai
|
||||
import requests
|
||||
|
||||
MAX_OUTPUT_LEN = 30
|
||||
|
||||
SAMPLE_PROMPTS = (
|
||||
"Red Hat is the best company in the world to work for because it works on "
|
||||
"open source software, which means that all the contributions are "
|
||||
"delivered to the community. As a result, when working on projects like "
|
||||
"vLLM we are able to meet many amazing people from various organizations "
|
||||
"like AMD, Google, NVIDIA, ",
|
||||
"We hold these truths to be self-evident, that all men are created equal, "
|
||||
"that they are endowed by their Creator with certain unalienable Rights, "
|
||||
"that among these are Life, Liberty and the pursuit of Happiness.--That "
|
||||
"to secure these rights, Governments are instituted among Men, deriving "
|
||||
"their just powers from the consent of the governed, ",
|
||||
)
|
||||
|
||||
|
||||
def check_vllm_server(url: str, timeout=5, retries=3) -> bool:
|
||||
"""
|
||||
Checks if the vLLM server is ready by sending a GET request to the
|
||||
/health endpoint.
|
||||
|
||||
Args:
|
||||
url (str): The base URL of the vLLM server.
|
||||
timeout (int): Timeout in seconds for the request.
|
||||
retries (int): Number of retries if the server is not ready.
|
||||
|
||||
Returns:
|
||||
bool: True if the server is ready, False otherwise.
|
||||
"""
|
||||
for attempt in range(retries):
|
||||
try:
|
||||
response = requests.get(url, timeout=timeout)
|
||||
if response.status_code == 200:
|
||||
return True
|
||||
else:
|
||||
print(
|
||||
f"Attempt {attempt + 1}: Server returned status code "
|
||||
"{response.status_code}"
|
||||
)
|
||||
except requests.exceptions.RequestException as e:
|
||||
print(f"Attempt {attempt + 1}: Error connecting to server: {e}")
|
||||
time.sleep(1) # Wait before retrying
|
||||
return False
|
||||
|
||||
|
||||
def run_simple_prompt(
|
||||
base_url: str, model_name: str, input_prompt: str, use_chat_endpoint: bool
|
||||
) -> str:
|
||||
client = openai.OpenAI(api_key="EMPTY", base_url=base_url)
|
||||
if use_chat_endpoint:
|
||||
completion = client.chat.completions.create(
|
||||
model=model_name,
|
||||
messages=[
|
||||
{"role": "user", "content": [{"type": "text", "text": input_prompt}]}
|
||||
],
|
||||
max_completion_tokens=MAX_OUTPUT_LEN,
|
||||
temperature=0.0,
|
||||
seed=42,
|
||||
)
|
||||
return completion.choices[0].message.content
|
||||
else:
|
||||
completion = client.completions.create(
|
||||
model=model_name,
|
||||
prompt=input_prompt,
|
||||
max_tokens=MAX_OUTPUT_LEN,
|
||||
temperature=0.0,
|
||||
seed=42,
|
||||
)
|
||||
|
||||
return completion.choices[0].text
|
||||
|
||||
|
||||
def main():
|
||||
"""
|
||||
This script demonstrates how to accept two optional string arguments
|
||||
("service_url" and "file_name") from the command line, each with a
|
||||
default value of an empty string, using the argparse module.
|
||||
"""
|
||||
parser = argparse.ArgumentParser(description="vLLM client script")
|
||||
|
||||
parser.add_argument(
|
||||
"--service_url", # Name of the first argument
|
||||
type=str,
|
||||
required=True,
|
||||
help="The vLLM service URL.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--model_name", # Name of the first argument
|
||||
type=str,
|
||||
required=True,
|
||||
help="model_name",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--mode", # Name of the second argument
|
||||
type=str,
|
||||
default="baseline",
|
||||
help="mode: baseline==non-disagg, or disagg",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--file_name", # Name of the second argument
|
||||
type=str,
|
||||
default=".vllm_output.txt",
|
||||
help="the file that saves the output tokens ",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
for arg in vars(args):
|
||||
print(f"{arg}: {getattr(args, arg)}")
|
||||
|
||||
if args.mode == "baseline":
|
||||
# non-disagg
|
||||
health_check_url = f"{args.service_url}/health"
|
||||
else:
|
||||
# disagg proxy
|
||||
health_check_url = f"{args.service_url}/healthcheck"
|
||||
if not os.path.exists(args.file_name):
|
||||
raise ValueError(
|
||||
f"In disagg mode, the output file {args.file_name} from "
|
||||
"non-disagg. baseline does not exist."
|
||||
)
|
||||
|
||||
service_url = f"{args.service_url}/v1"
|
||||
|
||||
if not check_vllm_server(health_check_url):
|
||||
raise RuntimeError(f"vllm server: {args.service_url} is not ready yet!")
|
||||
|
||||
output_strs = dict()
|
||||
for i, prompt in enumerate(SAMPLE_PROMPTS):
|
||||
use_chat_endpoint = i % 2 == 1
|
||||
output_str = run_simple_prompt(
|
||||
base_url=service_url,
|
||||
model_name=args.model_name,
|
||||
input_prompt=prompt,
|
||||
use_chat_endpoint=use_chat_endpoint,
|
||||
)
|
||||
print(f"Prompt: {prompt}, output: {output_str}")
|
||||
output_strs[prompt] = output_str
|
||||
|
||||
if args.mode == "baseline":
|
||||
# baseline: save outputs
|
||||
try:
|
||||
with open(args.file_name, "w") as json_file:
|
||||
json.dump(output_strs, json_file, indent=4)
|
||||
except OSError as e:
|
||||
print(f"Error writing to file: {e}")
|
||||
raise
|
||||
else:
|
||||
# disagg. verify outputs
|
||||
baseline_outputs = None
|
||||
try:
|
||||
with open(args.file_name) as json_file:
|
||||
baseline_outputs = json.load(json_file)
|
||||
except OSError as e:
|
||||
print(f"Error writing to file: {e}")
|
||||
raise
|
||||
assert isinstance(baseline_outputs, dict)
|
||||
assert len(baseline_outputs) == len(output_strs)
|
||||
for prompt, output in baseline_outputs.items():
|
||||
assert prompt in output_strs, f"{prompt} not included"
|
||||
assert output == output_strs[prompt], (
|
||||
f"baseline_output: {output} != PD output: {output_strs[prompt]}"
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
80
tests/v1/kv_connector/nixl_integration/test_edge_cases.py
Normal file
80
tests/v1/kv_connector/nixl_integration/test_edge_cases.py
Normal file
@@ -0,0 +1,80 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import os
|
||||
|
||||
import openai
|
||||
|
||||
# Endpoints for the three servers under test. Hosts default to localhost;
# the ports must be supplied by the test harness via the environment.
PREFILL_HOST = os.getenv("PREFILL_HOST", "localhost")
PREFILL_PORT = os.getenv("PREFILL_PORT", None)
DECODE_HOST = os.getenv("DECODE_HOST", "localhost")
DECODE_PORT = os.getenv("DECODE_PORT", None)
PROXY_HOST = os.getenv("PROXY_HOST", "localhost")
PROXY_PORT = os.getenv("PROXY_PORT", None)

# Fail fast at import time if the harness did not configure all three ports.
if PREFILL_PORT is None or DECODE_PORT is None or PROXY_PORT is None:
    raise ValueError("Please set the PREFILL_PORT, DECODE_PORT, and PROXY_PORT.")

# Three prompts sharing a common prefix, used to exercise prefix-cache edge
# cases (long / medium / shorter-than-a-block variants of the same text).
LONG_PROMPT = "Red Hat is the best company in the world to work for because it works on open source software, which means that all the contributions are delivered to the community. As a result, when working on projects like vLLM we are able to meet many amazing people from various organizations like AMD, Google, NVIDIA, "  # noqa: E501
PROMPT = "Red Hat is the best company in the world to work for because it works on open source software, which means that all the contributions are delivered to the community. As a result,"  # noqa: E501
SHORT_PROMPT = "Red Hat is "
|
||||
|
||||
|
||||
def test_edge_cases():
    """End-to-end sanity checks for the NIXL P/D disaggregated setup.

    Talks to three OpenAI-compatible servers (prefill worker, decode worker,
    and the P/D proxy) and asserts that greedy (temperature=0) completions
    from the proxy match single-worker baselines for three edge cases:
    a sub-block-size prompt, a full prefix-cache hit on the decoder, and a
    partial prefix-cache hit on the decoder.
    """
    # Set the OpenAI API key and base URL
    decode_client = openai.OpenAI(
        api_key="MY_KEY",
        base_url=f"http://{DECODE_HOST}:{DECODE_PORT}/v1",
    )
    prefill_client = openai.OpenAI(
        api_key="MY_KEY",
        base_url=f"http://{PREFILL_HOST}:{PREFILL_PORT}/v1",
    )
    proxy_client = openai.OpenAI(
        api_key="MY_KEY",
        base_url=f"http://{PROXY_HOST}:{PROXY_PORT}/v1",
    )

    # Get the list of models; the first served model is used for every
    # request (all three servers are expected to serve the same model).
    models = decode_client.models.list()
    MODEL = models.data[0].id

    # (1) Check that we can handle a very short prompt,
    # less than the length of the block size.
    completion = proxy_client.completions.create(
        model=MODEL, prompt=SHORT_PROMPT, temperature=0
    )
    proxy_response = completion.choices[0].text
    completion = prefill_client.completions.create(
        model=MODEL, prompt=SHORT_PROMPT, temperature=0
    )
    prefill_response = completion.choices[0].text
    print(f"SMALL PROMPT: {proxy_response=}")
    assert proxy_response == prefill_response

    # (2) Check that we can handle a full prefix cache
    # hit on the D worker but not on the P worker.
    # (2a): prime the D worker.
    completion = decode_client.completions.create(
        model=MODEL, prompt=PROMPT, temperature=0
    )
    decode_response = completion.choices[0].text
    # (2b): send via the P/D setup
    completion = proxy_client.completions.create(
        model=MODEL, prompt=PROMPT, temperature=0
    )
    proxy_response = completion.choices[0].text
    print(f"FULL CACHE HIT: {proxy_response=}")
    assert proxy_response == decode_response

    # (3) Check that we can handle a partial prefix cache
    # hit on the D worker.
    completion = proxy_client.completions.create(
        model=MODEL, prompt=LONG_PROMPT, temperature=0
    )
    proxy_response = completion.choices[0].text
    completion = prefill_client.completions.create(
        model=MODEL, prompt=LONG_PROMPT, temperature=0
    )
    prefill_response = completion.choices[0].text
    print(f"PARTIAL CACHE HIT: {proxy_response=}")
    assert proxy_response == prefill_response
|
||||
283
tests/v1/kv_connector/nixl_integration/toy_proxy_server.py
Normal file
283
tests/v1/kv_connector/nixl_integration/toy_proxy_server.py
Normal file
@@ -0,0 +1,283 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import argparse
|
||||
import itertools
|
||||
import logging
|
||||
import os
|
||||
import uuid
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
import httpx
|
||||
from fastapi import FastAPI, Request
|
||||
from fastapi.responses import StreamingResponse
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
logger.setLevel(logging.DEBUG)
|
||||
|
||||
|
||||
def _build_client_pool(instances):
    """Create one httpx client entry per (host, port) backend instance.

    Each entry carries the AsyncClient plus host/port/index metadata that
    the proxy logs when routing a request. Extracted to remove the
    previously duplicated prefill/decode construction loops.
    """
    pool = []
    for idx, (host, port) in enumerate(instances):
        pool.append(
            {
                "client": httpx.AsyncClient(
                    timeout=None,
                    base_url=f"http://{host}:{port}/v1",
                    # Unbounded limits: the proxy must never be the
                    # throttling point in these integration tests.
                    limits=httpx.Limits(
                        max_connections=None,
                        max_keepalive_connections=None,
                    ),
                ),
                "host": host,
                "port": port,
                "id": idx,
            }
        )
    return pool


@asynccontextmanager
async def lifespan(app: FastAPI):
    """
    Lifespan context manager to handle startup and shutdown events.

    On startup, builds one httpx client pool per service type plus the
    round-robin iterators used to pick a client; on shutdown, closes
    every client.
    """
    # Startup: Initialize client pools for prefiller and decoder services
    app.state.prefill_clients = _build_client_pool(global_args.prefiller_instances)
    app.state.decode_clients = _build_client_pool(global_args.decoder_instances)

    # Initialize round-robin iterators
    app.state.prefill_iterator = itertools.cycle(range(len(app.state.prefill_clients)))
    app.state.decode_iterator = itertools.cycle(range(len(app.state.decode_clients)))

    print(
        f"Initialized {len(app.state.prefill_clients)} prefill clients "
        f"and {len(app.state.decode_clients)} decode clients."
    )

    yield

    # Shutdown: Close all clients
    for client_info in app.state.prefill_clients + app.state.decode_clients:
        await client_info["client"].aclose()
|
||||
|
||||
|
||||
# Update FastAPI app initialization to use lifespan
|
||||
app = FastAPI(lifespan=lifespan)
|
||||
|
||||
|
||||
def parse_args():
    """Parse proxy CLI options and pair backend hosts with ports.

    Returns the argparse namespace augmented with ``prefiller_instances``
    and ``decoder_instances``: lists of (host, port) tuples, one per
    backend server.

    Raises:
        ValueError: if the number of hosts and ports differ for a backend.
    """
    parser = argparse.ArgumentParser()

    parser.add_argument("--port", type=int, default=8000)
    # Always use 127.0.0.1 as localhost binds to IPv6 which is blocked on CI
    parser.add_argument("--host", type=str, default="127.0.0.1")

    # Register the symmetric prefiller/decoder host+port options from a
    # single spec table instead of four hand-written add_argument calls.
    backend_defaults = (
        ("prefiller", "localhost", 8100),
        ("decoder", "localhost", 8200),
    )
    for name, default_host, default_port in backend_defaults:
        parser.add_argument(
            f"--{name}-hosts",
            f"--{name}-host",
            type=str,
            nargs="+",
            default=[default_host],
        )
        parser.add_argument(
            f"--{name}-ports",
            f"--{name}-port",
            type=int,
            nargs="+",
            default=[default_port],
        )

    args = parser.parse_args()

    # Validate host/port pairing, then attach (host, port) tuples per backend.
    for name, _, _ in backend_defaults:
        hosts = getattr(args, f"{name}_hosts")
        ports = getattr(args, f"{name}_ports")
        if len(hosts) != len(ports):
            raise ValueError(
                f"Number of {name} hosts must match number of {name} ports"
            )
        setattr(args, f"{name}_instances", list(zip(hosts, ports)))

    return args
|
||||
|
||||
|
||||
def get_next_client(app, service_type: str):
    """Pick the next backend client for *service_type* by round-robin.

    Args:
        app: The FastAPI app instance holding the client pools.
        service_type: Either 'prefill' or 'decode'.

    Returns:
        The next client-info dict from the corresponding pool.

    Raises:
        ValueError: if *service_type* names no known pool.
    """
    if service_type == "prefill":
        pool, rotation = app.state.prefill_clients, app.state.prefill_iterator
    elif service_type == "decode":
        pool, rotation = app.state.decode_clients, app.state.decode_iterator
    else:
        raise ValueError(f"Unknown service type: {service_type}")
    return pool[next(rotation)]
|
||||
|
||||
|
||||
async def send_request_to_service(
    client_info: dict, endpoint: str, req_data: dict, request_id: str
):
    """
    Issue one prefill request against a pooled backend and return the
    fully-read response.

    The payload is copied and rewritten so the backend performs a
    remote-decode prefill producing exactly one token, without streaming.
    """
    payload = dict(req_data)
    payload["kv_transfer_params"] = {
        "do_remote_decode": True,
        "do_remote_prefill": False,
        "remote_engine_id": None,
        "remote_block_ids": None,
        "remote_host": None,
        "remote_port": None,
    }
    payload["stream"] = False
    payload["max_tokens"] = 1
    if "max_completion_tokens" in payload:
        payload["max_completion_tokens"] = 1
    payload.pop("stream_options", None)

    headers = {
        "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}",
        "X-Request-Id": request_id,
    }

    response = await client_info["client"].post(
        endpoint, json=payload, headers=headers
    )
    response.raise_for_status()

    # read/consume the response body to release the connection
    # otherwise, it would http.ReadError
    await response.aread()

    return response
|
||||
|
||||
|
||||
async def stream_service_response(
    client_info: dict, endpoint: str, req_data: dict, request_id: str
):
    """
    Forward a request to a pooled backend and yield the raw response bytes
    as they arrive (used to proxy the decode stream back to the caller).
    """
    headers = {
        "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}",
        "X-Request-Id": request_id,
    }

    stream_ctx = client_info["client"].stream(
        "POST", endpoint, json=req_data, headers=headers
    )
    async with stream_ctx as response:
        response.raise_for_status()
        async for piece in response.aiter_bytes():
            yield piece
|
||||
|
||||
|
||||
async def _handle_completions(api: str, request: Request):
    """Proxy one completions-style request through the P/D pipeline.

    Runs the prefill on a round-robin prefill backend, hands the returned
    kv_transfer_params to a round-robin decode backend, and streams the
    decode backend's response to the caller.
    """
    try:
        req_data = await request.json()
        request_id = str(uuid.uuid4())

        # Get the next prefill client in round-robin fashion
        prefill_client_info = get_next_client(request.app, "prefill")

        # Send request to prefill service
        response = await send_request_to_service(
            prefill_client_info, api, req_data, request_id
        )

        # Extract the needed fields
        response_json = response.json()
        await response.aclose()  # CRITICAL: Release connection back to pool
        # Pass the prefill worker's KV-transfer handle to the decoder so it
        # can pull the prefilled KV cache.
        kv_transfer_params = response_json.get("kv_transfer_params", {})
        if kv_transfer_params:
            req_data["kv_transfer_params"] = kv_transfer_params

        # Get the next decode client in round-robin fashion
        decode_client_info = get_next_client(request.app, "decode")

        logger.debug("Using %s %s", prefill_client_info, decode_client_info)

        # Stream response from decode service
        async def generate_stream():
            async for chunk in stream_service_response(
                decode_client_info, api, req_data, request_id=request_id
            ):
                yield chunk

        return StreamingResponse(generate_stream(), media_type="application/json")

    except Exception as e:
        # Print the full traceback before re-raising so failures are visible
        # in the test harness output; FastAPI turns the re-raise into a 500.
        import sys
        import traceback

        exc_info = sys.exc_info()
        print(f"Error occurred in disagg prefill proxy server - {api} endpoint")
        print(e)
        print("".join(traceback.format_exception(*exc_info)))
        raise
|
||||
|
||||
|
||||
@app.post("/v1/completions")
async def handle_completions(request: Request):
    # Thin route wrapper: delegate to the shared P/D proxy pipeline.
    return await _handle_completions("/completions", request)
|
||||
|
||||
|
||||
@app.post("/v1/chat/completions")
async def handle_chat_completions(request: Request):
    # Thin route wrapper: delegate to the shared P/D proxy pipeline.
    return await _handle_completions("/chat/completions", request)
|
||||
|
||||
|
||||
@app.get("/healthcheck")
async def healthcheck():
    """Liveness probe that also reports the size of each backend pool."""
    pools = app.state
    report = {"status": "ok"}
    report["prefill_instances"] = len(pools.prefill_clients)
    report["decode_instances"] = len(pools.decode_clients)
    return report
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # NOTE: the original `global global_args` here was a no-op — a `global`
    # statement at module scope does nothing; module-level assignment is
    # already enough for lifespan() to see global_args.
    global_args = parse_args()

    # Imported lazily so importing this module (e.g. from tests) does not
    # require uvicorn to be installed.
    import uvicorn

    uvicorn.run(app, host=global_args.host, port=global_args.port)
|
||||
41
tests/v1/kv_connector/nixl_integration/tp_config_sweep_accuracy_test.sh
Executable file
41
tests/v1/kv_connector/nixl_integration/tp_config_sweep_accuracy_test.sh
Executable file
@@ -0,0 +1,41 @@
|
||||
#!/usr/bin/env bash
# Sweep a set of tensor-parallel (TP) configurations through the NIXL
# accuracy integration test, failing fast on the first broken config.
set -euo pipefail

# Utility to run integration tests sequentially with varying TP configurations.
SCRIPT="v1/kv_connector/nixl_integration/run_accuracy_test.sh"

# Define test configurations
configs=(
    "GPU_MEMORY_UTILIZATION=0.6 PREFILLER_TP_SIZE=2 DECODER_TP_SIZE=2"
    "GPU_MEMORY_UTILIZATION=0.6 PREFILLER_TP_SIZE=1 DECODER_TP_SIZE=2"
    "GPU_MEMORY_UTILIZATION=0.8 MODEL_NAMES=deepseek-ai/deepseek-vl2-tiny" # MLA case
    "GPU_MEMORY_UTILIZATION=0.8 PREFILLER_TP_SIZE=1 DECODER_TP_SIZE=2 MODEL_NAMES=deepseek-ai/deepseek-vl2-tiny"
    "DP_EP=1 GPU_MEMORY_UTILIZATION=0.8 PREFILLER_TP_SIZE=1 DECODER_TP_SIZE=2 MODEL_NAMES=deepseek-ai/deepseek-vl2-tiny" # MLA+P-TP1, D-DPEP=2 (TP=1)
)

# Run every config once under the given label / extra environment.
# $1: human-readable label for logging.
# $2: extra VAR=VALUE assignments prepended to each run (may be empty).
run_tests() {
    local label=$1
    local extra_env=$2

    echo "=== Running tests (${label}) ==="
    for cfg in "${configs[@]}"; do
        echo "-> Running with ${cfg} ${extra_env:+and ${extra_env}}"
        # Use 'env' to safely set variables without eval
        # (${extra_env}/${cfg} are intentionally unquoted so each VAR=VALUE
        # word splits into its own argument to env)
        if ! env ${extra_env} ${cfg} bash "${SCRIPT}"; then
            echo "❌ Test failed for config: ${cfg} ${extra_env:+(${extra_env})}"
            exit 1
        fi
    done
    echo "✅ All ${label} tests passed!"
}

# Run tests
run_tests "default backend" ""

# Check if FLASHINFER is set (non-empty)
if [[ -n "${FLASHINFER:-}" ]]; then
    echo "FLASHINFER is set, rerunning with VLLM_ATTENTION_BACKEND=FLASHINFER"
    run_tests "FLASHINFER backend" "VLLM_ATTENTION_BACKEND=FLASHINFER"
else
    echo "FLASHINFER not set, skipping FLASHINFER runs."
fi
|
||||
0
tests/v1/kv_connector/unit/__init__.py
Normal file
0
tests/v1/kv_connector/unit/__init__.py
Normal file
275
tests/v1/kv_connector/unit/test_backwards_compatibility.py
Normal file
275
tests/v1/kv_connector/unit/test_backwards_compatibility.py
Normal file
@@ -0,0 +1,275 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""
|
||||
Unit tests for backwards compatibility with external KV connector implementations.
|
||||
|
||||
This test ensures that external connectors (loaded via kv_connector_module_path)
|
||||
implemented with the old signature continue to work:
|
||||
- Old signature: __init__(self, vllm_config, role)
|
||||
- New signature: __init__(self, vllm_config, role, kv_cache_config)
|
||||
"""
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from vllm.attention.backends.abstract import AttentionMetadata
|
||||
from vllm.distributed.kv_transfer.kv_connector.factory import KVConnectorFactory
|
||||
from vllm.distributed.kv_transfer.kv_connector.v1 import (
|
||||
KVConnectorBase_V1,
|
||||
KVConnectorRole,
|
||||
)
|
||||
from vllm.v1.core.sched.output import SchedulerOutput
|
||||
|
||||
from .utils import create_scheduler, create_vllm_config
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.forward_context import ForwardContext
|
||||
from vllm.v1.core.kv_cache_manager import KVCacheBlocks
|
||||
from vllm.v1.kv_cache_interface import KVCacheConfig
|
||||
from vllm.v1.request import Request
|
||||
|
||||
|
||||
class OldStyleTestConnector(KVConnectorBase_V1):
    """
    Test connector using the old signature with 2 required arguments.
    This simulates external connectors that haven't been updated yet.

    All interface methods below are minimal no-ops: the tests only care
    about construction semantics, not actual KV transfer behavior.
    """

    def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole):
        # Old-style call to super().__init__ with only 2 arguments
        super().__init__(vllm_config=vllm_config, role=role)

    def get_num_new_matched_tokens(
        self, request: "Request", num_computed_tokens: int
    ) -> tuple[int | None, bool]:
        # Stub: report zero externally matched tokens.
        return 0, False

    def update_state_after_alloc(
        self,
        request: "Request",
        blocks: "KVCacheBlocks",
        num_external_tokens: int,
    ):
        pass

    def build_connector_meta(self, scheduler_output: SchedulerOutput):
        return None

    def start_load_kv(self, forward_context: "ForwardContext", **kwargs) -> None:
        pass

    def wait_for_layer_load(self, layer_name: str) -> None:
        pass

    def save_kv_layer(
        self,
        layer_name: str,
        kv_layer,
        attn_metadata: AttentionMetadata,
        **kwargs,
    ) -> None:
        pass

    def wait_for_save(self):
        pass
|
||||
|
||||
|
||||
class NewStyleTestConnector(KVConnectorBase_V1):
    """
    Test connector using the new signature with 3 required arguments.

    All interface methods below are minimal no-ops: the tests only care
    about construction semantics, not actual KV transfer behavior.
    """

    def __init__(
        self,
        vllm_config: "VllmConfig",
        role: KVConnectorRole,
        kv_cache_config: "KVCacheConfig",
    ):
        # New-style call to super().__init__ with all 3 arguments
        super().__init__(
            vllm_config=vllm_config, role=role, kv_cache_config=kv_cache_config
        )

    def get_num_new_matched_tokens(
        self, request: "Request", num_computed_tokens: int
    ) -> tuple[int | None, bool]:
        # Stub: report zero externally matched tokens.
        return 0, False

    def update_state_after_alloc(
        self,
        request: "Request",
        blocks: "KVCacheBlocks",
        num_external_tokens: int,
    ):
        pass

    def build_connector_meta(self, scheduler_output: SchedulerOutput):
        return None

    def start_load_kv(self, forward_context: "ForwardContext", **kwargs) -> None:
        pass

    def wait_for_layer_load(self, layer_name: str) -> None:
        pass

    def save_kv_layer(
        self,
        layer_name: str,
        kv_layer,
        attn_metadata: AttentionMetadata,
        **kwargs,
    ) -> None:
        pass

    def wait_for_save(self):
        pass
|
||||
|
||||
|
||||
@pytest.mark.parametrize("role", [KVConnectorRole.SCHEDULER, KVConnectorRole.WORKER])
def test_external_old_signature_factory_instantiation(role):
    """
    Test that external connectors with old signature (2 required args) loaded
    via kv_connector_module_path are correctly instantiated with backwards
    compatibility support.
    """
    vllm_config = create_vllm_config()
    # Point the factory at this very test module so OldStyleTestConnector is
    # loaded through the external-module code path.
    vllm_config.kv_transfer_config.kv_connector = "OldStyleTestConnector"
    vllm_config.kv_transfer_config.kv_connector_module_path = (
        "tests.v1.kv_connector.unit.test_backwards_compatibility"
    )

    scheduler = create_scheduler(vllm_config)
    kv_cache_config = scheduler.kv_cache_config

    connector = KVConnectorFactory.create_connector(vllm_config, role, kv_cache_config)

    assert connector is not None
    assert isinstance(connector, OldStyleTestConnector)
    assert connector.role == role
    # Old-signature connectors never receive the kv_cache_config.
    assert connector._kv_cache_config is None
|
||||
|
||||
|
||||
@pytest.mark.parametrize("role", [KVConnectorRole.SCHEDULER, KVConnectorRole.WORKER])
def test_external_new_signature_factory_instantiation(role):
    """
    Test that external connectors with new signature (3 required args) loaded
    via kv_connector_module_path are correctly instantiated.
    """
    vllm_config = create_vllm_config()
    # Point the factory at this very test module so NewStyleTestConnector is
    # loaded through the external-module code path.
    vllm_config.kv_transfer_config.kv_connector = "NewStyleTestConnector"
    vllm_config.kv_transfer_config.kv_connector_module_path = (
        "tests.v1.kv_connector.unit.test_backwards_compatibility"
    )

    scheduler = create_scheduler(vllm_config)
    kv_cache_config = scheduler.kv_cache_config

    connector = KVConnectorFactory.create_connector(vllm_config, role, kv_cache_config)

    assert connector is not None
    assert isinstance(connector, NewStyleTestConnector)
    assert connector.role == role
    # New-signature connectors must receive the scheduler's kv_cache_config.
    assert connector._kv_cache_config is not None
    assert connector._kv_cache_config == kv_cache_config
|
||||
|
||||
|
||||
@pytest.mark.parametrize("role", [KVConnectorRole.SCHEDULER, KVConnectorRole.WORKER])
def test_old_signature_super_init(role):
    """
    Test that old-style connectors can call super().__init__() without
    kv_cache_config parameter.
    """
    vllm_config = create_vllm_config()

    # Direct instantiation (no factory) using the legacy two-argument form.
    connector = OldStyleTestConnector(vllm_config, role)

    assert connector is not None
    assert connector.role == role
    # The base class must default kv_cache_config to None for legacy callers.
    assert connector._kv_cache_config is None
|
||||
|
||||
|
||||
def test_old_signature_super_init_with_kwargs():
    """
    Test that old-style connectors can call super().__init__() with keyword
    arguments in different orders.
    """
    vllm_config = create_vllm_config()

    # Construct with keywords in declaration order, then reversed order;
    # both must succeed and leave kv_cache_config unset.
    for ctor_kwargs in (
        {"vllm_config": vllm_config, "role": KVConnectorRole.SCHEDULER},
        {"role": KVConnectorRole.WORKER, "vllm_config": vllm_config},
    ):
        connector = OldStyleTestConnector(**ctor_kwargs)
        assert connector is not None
        assert connector._kv_cache_config is None
|
||||
|
||||
|
||||
def test_internal_connector_uses_new_signature():
    """
    Test that internal connectors (registered in factory) always use the new
    signature and get kv_cache_config.
    """
    # Imported lazily inside the test to keep module import lightweight.
    from vllm.distributed.kv_transfer.kv_connector.v1.example_connector import (
        ExampleConnector,
    )

    vllm_config = create_vllm_config()
    vllm_config.kv_transfer_config.kv_connector = "ExampleConnector"

    scheduler = create_scheduler(vllm_config)
    kv_cache_config = scheduler.kv_cache_config

    connector = KVConnectorFactory.create_connector(
        vllm_config, KVConnectorRole.SCHEDULER, kv_cache_config
    )

    assert connector is not None
    assert isinstance(connector, ExampleConnector)
    # Factory-registered (internal) connectors always get kv_cache_config.
    assert connector._kv_cache_config is not None
    assert connector._kv_cache_config == kv_cache_config
|
||||
|
||||
|
||||
def test_signature_detection_with_mocking():
    """
    Test that the factory correctly applies compat_sig flag returned from
    _get_connector_class_with_compat.
    """
    vllm_config = create_vllm_config()
    scheduler = create_scheduler(vllm_config)
    kv_cache_config = scheduler.kv_cache_config

    # Mock _get_connector_class_with_compat to return old-style connector
    # (compat flag True): the factory must construct it WITHOUT
    # kv_cache_config.
    with patch.object(
        KVConnectorFactory,
        "_get_connector_class_with_compat",
        return_value=(OldStyleTestConnector, True),
    ):
        old_connector = KVConnectorFactory.create_connector(
            vllm_config, KVConnectorRole.SCHEDULER, kv_cache_config
        )
        assert old_connector is not None
        assert isinstance(old_connector, OldStyleTestConnector)
        assert old_connector._kv_cache_config is None

    # Mock _get_connector_class_with_compat to return new-style connector
    # (compat flag False): the factory must pass kv_cache_config through.
    with patch.object(
        KVConnectorFactory,
        "_get_connector_class_with_compat",
        return_value=(NewStyleTestConnector, False),
    ):
        new_connector = KVConnectorFactory.create_connector(
            vllm_config, KVConnectorRole.SCHEDULER, kv_cache_config
        )
        assert new_connector is not None
        assert isinstance(new_connector, NewStyleTestConnector)
        assert new_connector._kv_cache_config is not None
        assert new_connector._kv_cache_config == kv_cache_config
|
||||
163
tests/v1/kv_connector/unit/test_cache_pollution_prevention.py
Normal file
163
tests/v1/kv_connector/unit/test_cache_pollution_prevention.py
Normal file
@@ -0,0 +1,163 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
"""
|
||||
test that invalid blocks are evicted from prefix cache to prevent pollution.
|
||||
|
||||
verifies that when sync-loading fails, invalid blocks are removed from the
|
||||
prefix cache hash table so future requests cannot match and reuse corrupted data.
|
||||
"""
|
||||
|
||||
from collections.abc import Callable
|
||||
from unittest.mock import Mock
|
||||
|
||||
import pytest
|
||||
|
||||
from vllm.v1.core.sched.scheduler import Scheduler
|
||||
from vllm.v1.request import Request, RequestStatus
|
||||
|
||||
from .utils import (
|
||||
create_model_runner_output,
|
||||
create_request,
|
||||
create_scheduler,
|
||||
create_vllm_config,
|
||||
)
|
||||
|
||||
pytestmark = pytest.mark.cpu_test
|
||||
|
||||
|
||||
def _make_get_num_new_matched_tokens(
    req_num_new_matched_tokens: dict[str, int],
    async_load: bool,
) -> Callable[[Request, int], tuple[int, bool]]:
    """Build a stub connector hook.

    The returned callable reports the configured per-request matched-token
    count (0 when the request id is absent) together with a fixed
    async-load flag.
    """

    def _stub(request: Request, _: int) -> tuple[int, bool]:
        matched = req_num_new_matched_tokens.get(request.request_id, 0)
        return matched, async_load

    return _stub
|
||||
|
||||
|
||||
@pytest.fixture
def fail_scheduler():
    """scheduler with kv_load_failure_policy='fail'

    Used by tests that expect a KV-load failure to finish the request with
    an error (FINISHED_ERROR) rather than retrying the load.
    """
    vllm_config = create_vllm_config()
    vllm_config.kv_transfer_config.kv_load_failure_policy = "fail"
    return create_scheduler(vllm_config)
|
||||
|
||||
|
||||
def test_invalid_blocks_evicted_prevents_cache_pollution(
    fail_scheduler: Scheduler,
):
    """
    verify invalid blocks are evicted to prevent future cache hits.

    scenario:
    1. request 1 loads externally-computed blocks (sync mode)
    2. some blocks fail to load and are marked invalid
    3. with fail policy, invalid blocks should be evicted from prefix cache
    4. request is marked as FINISHED_ERROR
    """
    # Sizes in blocks; token counts are derived from the scheduler block size.
    num_prompt_blocks = 100
    num_external_computed_blocks = 99
    invalid_block_idx = 50

    num_prompt_tokens = num_prompt_blocks * fail_scheduler.block_size
    num_external_computed_tokens = (
        num_external_computed_blocks * fail_scheduler.block_size
    )

    # request 1: will have invalid blocks
    request1 = create_request(num_tokens=num_prompt_tokens, request_id=1)
    fail_scheduler.add_request(request=request1)

    req_num_new_matched_tokens = {
        request1.request_id: num_external_computed_tokens,
    }

    # mock connector indicating sync load (async_load=False)
    fail_scheduler.connector = Mock()
    fail_scheduler.connector.get_num_new_matched_tokens.side_effect = (
        _make_get_num_new_matched_tokens(req_num_new_matched_tokens, False)
    )
    fail_scheduler.connector.request_finished.return_value = (False, None)
    fail_scheduler.connector.take_events.return_value = ()

    scheduler_output = fail_scheduler.schedule()

    # request should be running with sync KV load
    assert len(fail_scheduler.running) == 1
    assert request1.status == RequestStatus.RUNNING

    # get allocated block IDs
    req_block_ids = scheduler_output.scheduled_new_reqs[0].block_ids[0]
    invalid_block_id = req_block_ids[invalid_block_idx]
    invalid_block_ids = {invalid_block_id}

    # get the block object to verify eviction later
    block = fail_scheduler.kv_cache_manager.block_pool.blocks[invalid_block_id]

    # cache the blocks to simulate they've been computed and cached
    # (in real scenario blocks would be cached after compute)
    fail_scheduler.kv_cache_manager.cache_blocks(request1, num_external_computed_tokens)

    # verify block has a hash (is cached) before reporting invalid blocks
    assert block.block_hash is not None, (
        f"block {invalid_block_id} should be cached (have a hash) before "
        f"eviction test, but hash is None"
    )

    # report invalid blocks via the model runner output
    model_runner_output = create_model_runner_output(
        [request1],
        invalid_block_ids=invalid_block_ids,
        use_eos=False,
    )

    fail_scheduler.update_from_output(scheduler_output, model_runner_output)

    # verify request finished with error (fail policy)
    assert request1.status == RequestStatus.FINISHED_ERROR

    # critical assertion: invalid block and all subsequent blocks should be evicted
    # all blocks from invalid_block_idx onwards become invalid since they were
    # computed based on the failed block
    for idx in range(invalid_block_idx, len(req_block_ids)):
        block_id = req_block_ids[idx]
        block_obj = fail_scheduler.kv_cache_manager.block_pool.blocks[block_id]
        assert block_obj.block_hash is None, (
            f"block {block_id} at index {idx} should have been evicted "
            f"(hash reset to None), but hash is {block_obj.block_hash}. "
            f"All blocks from index {invalid_block_idx} onwards should be evicted "
            f"since they depend on the invalid block at index {invalid_block_idx}."
        )

    # verify cache contains exactly the valid blocks (before first affected block)
    # and none of the invalid blocks (from first affected block onwards)

    # valid blocks: all blocks before invalid_block_idx should be cached
    for idx in range(invalid_block_idx):
        block_id = req_block_ids[idx]
        block_obj = fail_scheduler.kv_cache_manager.block_pool.blocks[block_id]
        assert block_obj.block_hash is not None, (
            f"valid block {block_id} at index {idx} should still be cached "
            f"(have a hash), but hash is None. Only blocks from index "
            f"{invalid_block_idx} onwards should be evicted."
        )

    # invalid blocks: verify they're not in the cached_block_hash_to_block map
    cached_blocks = (
        fail_scheduler.kv_cache_manager.block_pool.cached_block_hash_to_block
    )
    # NOTE(review): the map's values are treated here as either a single
    # block or a dict of blocks keyed by block id — confirm against the
    # BlockPool implementation if this ever breaks.
    cached_block_ids = {
        b.block_id
        for blocks_val in cached_blocks._cache.values()
        for b in (
            [blocks_val] if not isinstance(blocks_val, dict) else blocks_val.values()
        )
    }

    for idx in range(invalid_block_idx, len(req_block_ids)):
        block_id = req_block_ids[idx]
        assert block_id not in cached_block_ids, (
            f"invalid block {block_id} at index {idx} should not be in cache hash table"
        )
|
||||
65
tests/v1/kv_connector/unit/test_config.py
Normal file
65
tests/v1/kv_connector/unit/test_config.py
Normal file
@@ -0,0 +1,65 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
"""Tests for KV cache offloading configuration."""
|
||||
|
||||
import pytest
|
||||
|
||||
from vllm.config import CacheConfig, KVTransferConfig, ParallelConfig, VllmConfig
|
||||
|
||||
pytestmark = pytest.mark.cpu_test
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    "kv_offloading_backend,kv_offloading_size,tp,pp,expected_backend,expected_bytes",
    [
        ("native", 4.0, 1, 1, "OffloadingConnector", 4.0 * (1 << 30)),
        # bytes per rank: 8.0 GiB / (2 * 2) = 2.0 GiB
        ("native", 8.0, 2, 2, "OffloadingConnector", 8.0 * (1 << 30) / 4),
        ("lmcache", 4.0, 1, 1, "LMCacheConnectorV1", 4.0),
        # size per rank: 8.0 GiB / (2 * 2) = 2.0 GiB
        ("lmcache", 8.0, 2, 2, "LMCacheConnectorV1", 2.0),
        (None, None, 1, 1, None, None),
    ],
)
def test_kv_connector(
    kv_offloading_backend, kv_offloading_size, tp, pp, expected_backend, expected_bytes
):
    """Check that cache-config KV offloading settings materialize the right
    transfer connector and per-rank sizing on the resulting VllmConfig."""
    if expected_backend is not None:
        initial_transfer_config = KVTransferConfig(
            kv_connector_extra_config={"existing_key": "existing_value"}
        )
    else:
        initial_transfer_config = None

    vllm_config = VllmConfig(
        cache_config=CacheConfig(
            kv_offloading_backend=kv_offloading_backend,
            kv_offloading_size=kv_offloading_size,
        ),
        kv_transfer_config=initial_transfer_config,
        parallel_config=ParallelConfig(
            tensor_parallel_size=tp, pipeline_parallel_size=pp
        ),
    )

    if expected_backend is None:
        # Offloading disabled: no transfer config should be synthesized.
        assert vllm_config.kv_transfer_config is None
        return

    transfer_config = vllm_config.kv_transfer_config
    extra = transfer_config.kv_connector_extra_config

    assert transfer_config.kv_connector == expected_backend
    assert transfer_config.kv_role == "kv_both"

    if kv_offloading_backend == "native":
        assert extra["kv_bytes_per_rank"] == expected_bytes
        assert extra["num_cpu_blocks"] == 0
        # The pre-existing extra config entry must be preserved.
        assert extra["existing_key"] == "existing_value"
    elif kv_offloading_backend == "lmcache":
        assert extra["lmcache.local_cpu"] is True
        assert extra["lmcache.max_local_cpu_size"] == expected_bytes
        # For lmcache the extra config is rebuilt, dropping old entries.
        assert "existing_key" not in extra
|
||||
415
tests/v1/kv_connector/unit/test_decode_bench_connector.py
Normal file
415
tests/v1/kv_connector/unit/test_decode_bench_connector.py
Normal file
@@ -0,0 +1,415 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""
|
||||
Unit tests for DecodeBenchConnector.
|
||||
|
||||
Tests the functionality of the DecodeBenchConnector which fills KV cache
|
||||
with dummy values for decode performance benchmarking.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from vllm import SamplingParams
|
||||
from vllm.config import KVTransferConfig
|
||||
from vllm.distributed.kv_transfer.kv_connector.v1 import KVConnectorRole
|
||||
|
||||
# ruff: noqa: E501
|
||||
from vllm.distributed.kv_transfer.kv_connector.v1.decode_bench_connector import (
|
||||
DecodeBenchConnector,
|
||||
DecodeBenchConnectorMetadata,
|
||||
)
|
||||
from vllm.forward_context import ForwardContext
|
||||
from vllm.utils.hashing import sha256
|
||||
from vllm.v1.core.kv_cache_utils import get_request_block_hasher, init_none_hash
|
||||
from vllm.v1.core.sched.scheduler import Scheduler
|
||||
from vllm.v1.request import Request
|
||||
|
||||
from .utils import (
|
||||
EOS_TOKEN_ID,
|
||||
create_model_runner_output,
|
||||
create_scheduler,
|
||||
create_vllm_config,
|
||||
)
|
||||
|
||||
|
||||
class DecodeBenchTestRunner:
    """Test runner for DecodeBenchConnector.

    Wires a real scheduler (which creates the scheduler-side connector from
    the kv_transfer_config) together with a standalone worker-side connector
    and small dummy KV cache tensors, so tests can step the scheduler/worker
    loop and inspect what the connector filled.
    """

    def __init__(self, block_size: int, num_gpu_blocks: int):
        """Build the scheduler, the worker connector, and dummy KV caches.

        Args:
            block_size: tokens per KV cache block.
            num_gpu_blocks: number of blocks in each per-layer cache tensor.
        """
        self.block_size = block_size
        self.num_gpu_blocks = num_gpu_blocks

        # Monotonic counter used to mint unique request IDs ("0" is first).
        self.req_id = -1

        # Create vllm config with DecodeBenchConnector
        vllm_config = create_vllm_config(
            block_size=block_size, max_num_batched_tokens=1000
        )
        vllm_config.kv_transfer_config = KVTransferConfig(
            kv_connector="DecodeBenchConnector",
            kv_role="kv_both",
        )

        self.vllm_config = vllm_config
        self.scheduler: Scheduler = create_scheduler(
            vllm_config, num_blocks=num_gpu_blocks
        )

        # Create worker-side connector
        self.worker_connector = DecodeBenchConnector(
            vllm_config, KVConnectorRole.WORKER
        )

        # Create dummy KV caches for testing
        # Shape: [num_blocks, 2, num_heads, block_size, head_dim]
        # Using simplified shape for testing
        num_heads = 4
        head_dim = 64
        self.kv_caches = {
            f"layer_{i}": torch.zeros(
                num_gpu_blocks, 2, num_heads, block_size, head_dim
            )
            for i in range(2)  # 2 layers for testing
        }

        # Register KV caches with worker connector
        self.worker_connector.register_kv_caches(self.kv_caches)

        # Extract scheduler-side connector (instantiated by create_scheduler
        # from the kv_transfer_config set above).
        scheduler_connector = self.scheduler.connector
        assert scheduler_connector is not None
        assert isinstance(scheduler_connector, DecodeBenchConnector)
        self.scheduler_connector: DecodeBenchConnector = scheduler_connector

        init_none_hash(sha256)
        self._block_hasher = get_request_block_hasher(block_size, sha256)

        # Minimal forward context: the connector load path only needs a
        # ForwardContext-shaped object here, not real attention metadata.
        self._dummy_ctx: ForwardContext = ForwardContext(
            no_compile_layers={}, attn_metadata={}, virtual_engine=0
        )

    def new_request(self, token_ids: list[int]) -> Request:
        """Create a new request with given token IDs, add it to the
        scheduler, and return it."""
        self.req_id += 1

        req = Request(
            request_id=str(self.req_id),
            prompt_token_ids=token_ids,
            sampling_params=SamplingParams(max_tokens=100),
            pooling_params=None,
            eos_token_id=EOS_TOKEN_ID,
            block_hasher=self._block_hasher,
        )

        self.scheduler.add_request(req)
        return req

    def run_single_step(self, token_id: int = 0):
        """Run a single scheduler + worker step.

        Schedules, drives the worker-side connector through its metadata
        lifecycle (bind -> load -> save -> clear), then feeds a synthetic
        model-runner output (built from the running queue, sampling
        `token_id`) back into the scheduler.

        Returns:
            (scheduler_output, kv_connector_metadata) for this step.
        """
        scheduler_output = self.scheduler.schedule()

        # Get connector metadata
        kv_connector_metadata = scheduler_output.kv_connector_metadata
        assert kv_connector_metadata is not None
        assert isinstance(kv_connector_metadata, DecodeBenchConnectorMetadata)

        # Bind metadata and load KV
        self.worker_connector.bind_connector_metadata(kv_connector_metadata)
        self.worker_connector.start_load_kv(self._dummy_ctx)

        # Only wait for saves when something was actually scheduled.
        if scheduler_output.total_num_scheduled_tokens > 0:
            self.worker_connector.wait_for_save()

        self.worker_connector.clear_connector_metadata()

        # Create model runner output
        model_runner_output = create_model_runner_output(
            reqs=self.scheduler.running,
            token_id=token_id,
        )

        self.scheduler.update_from_output(scheduler_output, model_runner_output)

        return scheduler_output, kv_connector_metadata
|
||||
|
||||
|
||||
def test_decode_bench_connector_basic():
    """Basic end-to-end check: a fresh request gets its prompt-prefix KV
    filled with the connector's constant dummy value on the first step."""
    block_size = 16
    num_gpu_blocks = 100

    runner = DecodeBenchTestRunner(block_size=block_size, num_gpu_blocks=num_gpu_blocks)

    # Prompt spanning exactly three blocks.
    num_tokens = block_size * 3
    req = runner.new_request([1] * num_tokens)

    # First step should trigger the dummy-fill path.
    _, metadata = runner.run_single_step()

    # Every prompt token except the final one is pre-filled; the last
    # token is left for the decode step.
    expected_fill_tokens = num_tokens - 1

    assert len(metadata.reqs_to_fill) == 1
    assert req.request_id in metadata.reqs_to_fill

    groups, filled = metadata.reqs_to_fill[req.request_id]
    assert filled == expected_fill_tokens

    # Standard attention uses a single KV cache group.
    assert len(groups) == 1
    filled_block_ids = groups[0]

    # Ceiling division: blocks needed to hold the filled tokens.
    assert len(filled_block_ids) == -(-expected_fill_tokens // block_size)

    # Each touched block in every layer must hold the constant 0.015.
    for kv_cache in runner.kv_caches.values():
        for block_id in filled_block_ids:
            assert torch.allclose(kv_cache[block_id], torch.tensor(0.015))
|
||||
|
||||
|
||||
def test_decode_bench_connector_no_refill():
    """A request's KV cache is dummy-filled exactly once; subsequent steps
    schedule no further fill work for it."""
    block_size = 16
    runner = DecodeBenchTestRunner(block_size=block_size, num_gpu_blocks=100)

    # Two full blocks of prompt tokens.
    runner.new_request([1] * (block_size * 2))

    # Step 1: the connector schedules the dummy fill.
    _, first_metadata = runner.run_single_step()
    assert len(first_metadata.reqs_to_fill) == 1

    # Step 2: nothing left to fill for an already-filled request.
    _, second_metadata = runner.run_single_step()
    assert len(second_metadata.reqs_to_fill) == 0
|
||||
|
||||
|
||||
def test_decode_bench_connector_single_token():
    """A one-token prompt has nothing to pre-fill: at least two tokens are
    required (one to fill, one to decode), so no fill work is scheduled."""
    runner = DecodeBenchTestRunner(block_size=16, num_gpu_blocks=100)

    # Single-token prompt.
    runner.new_request([1])

    _, metadata = runner.run_single_step()
    assert len(metadata.reqs_to_fill) == 0
|
||||
|
||||
|
||||
def test_decode_bench_connector_two_tokens():
    """With a two-token prompt, exactly one token (the first) is filled,
    occupying a single block; the second token is left for decode."""
    runner = DecodeBenchTestRunner(block_size=16, num_gpu_blocks=100)

    req = runner.new_request([1, 2])

    _, metadata = runner.run_single_step()

    assert len(metadata.reqs_to_fill) == 1
    assert req.request_id in metadata.reqs_to_fill

    groups, filled = metadata.reqs_to_fill[req.request_id]
    # Only the first token is pre-filled.
    assert filled == 1
    # Standard attention -> a single KV cache group ...
    assert len(groups) == 1
    # ... holding one block for the single filled token.
    assert len(groups[0]) == 1
|
||||
|
||||
|
||||
def test_decode_bench_connector_large_context():
    """Test DecodeBenchConnector with large context size (20 blocks):
    fill count, block count, and fill value must all hold at scale."""
    block_size = 16
    num_gpu_blocks = 1000

    runner = DecodeBenchTestRunner(block_size=block_size, num_gpu_blocks=num_gpu_blocks)

    # Create a request with many blocks
    num_blocks = 20
    num_tokens = block_size * num_blocks
    token_ids = list(range(num_tokens))

    req = runner.new_request(token_ids)

    # Run step
    _, metadata = runner.run_single_step()

    assert len(metadata.reqs_to_fill) == 1
    assert req.request_id in metadata.reqs_to_fill

    block_ids_per_group, num_tokens_to_fill = metadata.reqs_to_fill[req.request_id]

    # Should fill all tokens except the last one (kept for decode)
    expected_fill_tokens = num_tokens - 1
    assert num_tokens_to_fill == expected_fill_tokens

    # For standard attention, there's only one group
    assert len(block_ids_per_group) == 1
    block_ids = block_ids_per_group[0]

    # Calculate expected number of blocks (ceiling division)
    expected_num_blocks = (expected_fill_tokens + block_size - 1) // block_size
    assert len(block_ids) == expected_num_blocks

    # Verify blocks were filled with the connector's constant dummy value
    for layer_name, kv_cache in runner.kv_caches.items():
        for block_id in block_ids:
            block_data = kv_cache[block_id]
            assert torch.allclose(block_data, torch.tensor(0.015))
|
||||
|
||||
|
||||
def test_decode_bench_connector_multiple_requests():
    """Test DecodeBenchConnector with multiple sequential requests:
    each new request gets its own fill with the correct token count."""
    block_size = 16
    num_gpu_blocks = 100

    runner = DecodeBenchTestRunner(block_size=block_size, num_gpu_blocks=num_gpu_blocks)

    # First request
    req1 = runner.new_request([1] * (block_size * 2))
    _, metadata1 = runner.run_single_step()

    assert len(metadata1.reqs_to_fill) == 1
    assert req1.request_id in metadata1.reqs_to_fill

    # Complete first request by stepping until the running queue drains
    # (requests cap at max_tokens=100 via new_request's SamplingParams).
    while runner.scheduler.running:
        runner.run_single_step()

    # Add EOS to finish
    # NOTE(review): by this point `running` is already empty, so this extra
    # schedule/update pair looks like a no-op flush — confirm it is needed.
    scheduler_output = runner.scheduler.schedule()
    model_runner_output = create_model_runner_output(
        reqs=runner.scheduler.running,
        token_id=EOS_TOKEN_ID,
        use_eos=True,
    )
    runner.scheduler.update_from_output(scheduler_output, model_runner_output)

    # Second request - should also get filled
    req2 = runner.new_request([2] * (block_size * 3))
    _, metadata2 = runner.run_single_step()

    assert len(metadata2.reqs_to_fill) == 1
    assert req2.request_id in metadata2.reqs_to_fill

    # Different request should have different metadata
    _, num_tokens1 = metadata1.reqs_to_fill[req1.request_id]
    _, num_tokens2 = metadata2.reqs_to_fill[req2.request_id]

    # Each fill covers the whole prompt minus the final decode token.
    assert num_tokens1 == block_size * 2 - 1
    assert num_tokens2 == block_size * 3 - 1
|
||||
|
||||
|
||||
def test_decode_bench_connector_partial_block():
    """Fill count and block allocation when the prompt does not end on a
    block boundary (2.5 blocks of tokens here)."""
    block_size = 16
    runner = DecodeBenchTestRunner(block_size=block_size, num_gpu_blocks=100)

    # Two full blocks plus half a block of prompt tokens.
    num_tokens = block_size * 2 + block_size // 2
    req = runner.new_request([1] * num_tokens)

    _, metadata = runner.run_single_step()

    assert len(metadata.reqs_to_fill) == 1
    assert req.request_id in metadata.reqs_to_fill

    groups, filled = metadata.reqs_to_fill[req.request_id]

    # All but the final token are pre-filled.
    assert filled == num_tokens - 1

    # One KV cache group (standard attention); three blocks are allocated
    # to cover the partially-used last block.
    assert len(groups) == 1
    assert len(groups[0]) == 3
|
||||
|
||||
|
||||
def test_decode_bench_connector_concurrent_requests():
    """Test DecodeBenchConnector with multiple concurrent requests in the same batch."""
    block_size = 16
    num_gpu_blocks = 1000

    runner = DecodeBenchTestRunner(block_size=block_size, num_gpu_blocks=num_gpu_blocks)

    # Create multiple requests that will be batched together
    req1 = runner.new_request([1] * (block_size * 2))
    req2 = runner.new_request([2] * (block_size * 3))
    req3 = runner.new_request([3] * (block_size * 1))

    # Run first step - all requests should be filled concurrently
    _, metadata = runner.run_single_step()

    # All three requests should be in the metadata
    assert len(metadata.reqs_to_fill) == 3
    assert req1.request_id in metadata.reqs_to_fill
    assert req2.request_id in metadata.reqs_to_fill
    assert req3.request_id in metadata.reqs_to_fill

    # Verify each request has correct fill info
    block_ids_per_group1, num_tokens1 = metadata.reqs_to_fill[req1.request_id]
    block_ids_per_group2, num_tokens2 = metadata.reqs_to_fill[req2.request_id]
    block_ids_per_group3, num_tokens3 = metadata.reqs_to_fill[req3.request_id]

    # Verify token counts (all tokens except last one, kept for decode)
    assert num_tokens1 == block_size * 2 - 1
    assert num_tokens2 == block_size * 3 - 1
    assert num_tokens3 == block_size * 1 - 1

    # Verify block counts for each request (one KV cache group each)
    assert len(block_ids_per_group1[0]) == 2  # 2 blocks
    assert len(block_ids_per_group2[0]) == 3  # 3 blocks
    assert len(block_ids_per_group3[0]) == 1  # 1 block

    # Verify all blocks are filled in KV cache with the constant value
    for req_id, (block_ids_per_group, _) in metadata.reqs_to_fill.items():
        block_ids = block_ids_per_group[0]
        for layer_name, kv_cache in runner.kv_caches.items():
            for block_id in block_ids:
                block_data = kv_cache[block_id]
                assert torch.allclose(block_data, torch.tensor(0.015))

    # Run second step - should NOT fill again (already filled)
    _, metadata2 = runner.run_single_step()
    assert len(metadata2.reqs_to_fill) == 0
|
||||
|
||||
|
||||
# Allow running this module directly (python <file>) in addition to being
# collected by a top-level pytest invocation.
if __name__ == "__main__":
    pytest.main([__file__, "-v"])
|
||||
147
tests/v1/kv_connector/unit/test_error_propagation.py
Normal file
147
tests/v1/kv_connector/unit/test_error_propagation.py
Normal file
@@ -0,0 +1,147 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from collections.abc import Callable
|
||||
from unittest.mock import Mock
|
||||
|
||||
import pytest
|
||||
|
||||
from vllm.v1.core.sched.scheduler import Scheduler
|
||||
from vllm.v1.request import FinishReason, Request, RequestStatus
|
||||
|
||||
from .utils import (
|
||||
create_model_runner_output,
|
||||
create_request,
|
||||
create_scheduler,
|
||||
create_vllm_config,
|
||||
)
|
||||
|
||||
pytestmark = pytest.mark.cpu_test
|
||||
|
||||
|
||||
def _make_get_num_new_matched_tokens(
    req_num_new_matched_tokens: dict[str, int],
    async_load: bool,
) -> Callable[[Request, int], tuple[int, bool]]:
    """Build a connector-style lookup returning each request's externally
    matched token count (0 when unknown) and a fixed async-load flag."""

    def get_num_new_matched_tokens(request: Request, _: int) -> tuple[int, bool]:
        return req_num_new_matched_tokens.get(request.request_id, 0), async_load

    return get_num_new_matched_tokens
|
||||
|
||||
|
||||
@pytest.fixture
def fail_scheduler():
    """Scheduler with kv_load_failure_policy='fail': a failed remote KV
    load should finish the affected request with an error rather than
    recomputing the invalid blocks."""
    vllm_config = create_vllm_config()
    vllm_config.kv_transfer_config.kv_load_failure_policy = "fail"
    return create_scheduler(vllm_config)
|
||||
|
||||
|
||||
def test_error_propagation_sync_load(fail_scheduler: Scheduler):
    """test invalid_block_ids with fail policy -> FINISHED_ERROR (sync load)"""
    num_prompt_blocks = 100
    # The mocked connector claims 99 of 100 prompt blocks are available
    # externally; one block in the middle of that range is later reported
    # back as invalid.
    num_external_computed_blocks = 99
    invalid_block_idx = 50

    num_prompt_tokens = num_prompt_blocks * fail_scheduler.block_size
    num_external_computed_tokens = (
        num_external_computed_blocks * fail_scheduler.block_size
    )

    request = create_request(num_tokens=num_prompt_tokens)
    fail_scheduler.add_request(request=request)

    req_num_new_matched_tokens = {
        request.request_id: num_external_computed_tokens,
    }

    # Replace the connector with a mock: sync load (async_load=False).
    fail_scheduler.connector = Mock()
    fail_scheduler.connector.get_num_new_matched_tokens.side_effect = (
        _make_get_num_new_matched_tokens(req_num_new_matched_tokens, False)
    )
    fail_scheduler.connector.request_finished.return_value = (False, None)
    fail_scheduler.connector.take_events.return_value = ()

    scheduler_output = fail_scheduler.schedule()

    # Sync load: the request is scheduled and running immediately.
    assert len(fail_scheduler.running) == 1
    assert len(scheduler_output.scheduled_new_reqs) == 1
    assert fail_scheduler.connector.get_num_new_matched_tokens.call_count == 1

    # Report one of the externally-loaded blocks as invalid.
    req_block_ids = scheduler_output.scheduled_new_reqs[0].block_ids[0]
    invalid_block_ids = {req_block_ids[invalid_block_idx]}
    model_runner_output = create_model_runner_output(
        [request],
        invalid_block_ids=invalid_block_ids,
        use_eos=True,
    )

    outputs = fail_scheduler.update_from_output(scheduler_output, model_runner_output)

    # With the 'fail' policy the request must end in an error state.
    assert request.status == RequestStatus.FINISHED_ERROR
    assert request.get_finished_reason() == FinishReason.ERROR

    # The error must be propagated in the engine outputs.
    assert len(outputs) == 1
    engine_outputs = next(iter(outputs.values()))
    assert len(engine_outputs.outputs) == 1
    output = engine_outputs.outputs[0]
    assert output.request_id == request.request_id
    assert output.finish_reason == FinishReason.ERROR

    # And the request must have been removed from the running queue.
    assert len(fail_scheduler.running) == 0
|
||||
|
||||
|
||||
def test_error_propagation_async_load(fail_scheduler: Scheduler):
    """test invalid_block_ids with fail policy -> FINISHED_ERROR (async load)"""
    num_prompt_blocks = 100
    # The mocked connector claims 99 of 100 prompt blocks are available
    # externally; one block in that range is later reported as invalid.
    num_external_computed_blocks = 99
    invalid_block_idx = 50

    num_prompt_tokens = num_prompt_blocks * fail_scheduler.block_size
    num_external_computed_tokens = (
        num_external_computed_blocks * fail_scheduler.block_size
    )

    request = create_request(num_tokens=num_prompt_tokens)
    fail_scheduler.add_request(request=request)

    req_num_new_matched_tokens = {
        request.request_id: num_external_computed_tokens,
    }

    # Replace the connector with a mock: async load (async_load=True).
    fail_scheduler.connector = Mock()
    fail_scheduler.connector.get_num_new_matched_tokens.side_effect = (
        _make_get_num_new_matched_tokens(req_num_new_matched_tokens, True)
    )
    fail_scheduler.connector.request_finished.return_value = (False, None)
    fail_scheduler.connector.take_events.return_value = ()

    scheduler_output = fail_scheduler.schedule()

    # Async load: the request waits for remote KVs instead of running.
    assert len(fail_scheduler.waiting) == 1
    assert request.status == RequestStatus.WAITING_FOR_REMOTE_KVS
    assert request.num_computed_tokens == 0

    # Report one of the blocks being received as invalid; no requests ran.
    (req_block_ids,) = fail_scheduler.kv_cache_manager.get_block_ids(request.request_id)
    invalid_block_ids = {req_block_ids[invalid_block_idx]}
    model_runner_output = create_model_runner_output(
        reqs=[],
        finished_recving=set(),
        invalid_block_ids=invalid_block_ids,
        use_eos=True,
    )

    outputs = fail_scheduler.update_from_output(scheduler_output, model_runner_output)

    # With the 'fail' policy the request must end in an error state.
    assert request.status == RequestStatus.FINISHED_ERROR
    assert request.get_finished_reason() == FinishReason.ERROR

    # The error must be propagated in the engine outputs.
    assert len(outputs) == 1
    engine_outputs = next(iter(outputs.values()))
    assert len(engine_outputs.outputs) == 1
    output = engine_outputs.outputs[0]
    assert output.request_id == request.request_id
    assert output.finish_reason == FinishReason.ERROR

    # And the request must have been removed from the waiting queue.
    assert len(fail_scheduler.waiting) == 0
|
||||
256
tests/v1/kv_connector/unit/test_example_connector.py
Normal file
256
tests/v1/kv_connector/unit/test_example_connector.py
Normal file
@@ -0,0 +1,256 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from dataclasses import asdict
|
||||
from typing import NamedTuple
|
||||
|
||||
import pytest
|
||||
from PIL import Image
|
||||
|
||||
from vllm import LLM, EngineArgs, SamplingParams
|
||||
from vllm.assets.image import ImageAsset
|
||||
from vllm.config import KVTransferConfig
|
||||
from vllm.multimodal.utils import encode_image_base64
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
MODEL_NAME = "RedHatAI/Qwen2.5-VL-3B-Instruct-quantized.w8a8"
|
||||
|
||||
SAMPLING_PARAMS = SamplingParams(temperature=0.0, top_k=1, max_tokens=128)
|
||||
|
||||
TEXT_PROMPTS = [
|
||||
"What's in the image(s)? Around 30 words. What's special in 2nd image?",
|
||||
"The future of AI is",
|
||||
]
|
||||
|
||||
|
||||
class InputCase(NamedTuple):
    """One input case for the shared-storage hash test.

    Fields:
        text: the text prompt.
        img: PIL images attached to the prompt (may be empty).
        expected_len: expected number of KV hash entries in the storage
            path after running this case.
        info: human-readable description of what the case checks.
    """

    text: str
    # `Image` imported via `from PIL import Image` is the module; the image
    # class is `Image.Image` (the bare module is not a valid element type).
    img: list[Image.Image]
    expected_len: int
    info: str
|
||||
|
||||
|
||||
def _check_path_len(path):
|
||||
"""Return the latest length in path"""
|
||||
return len(list(path.iterdir()))
|
||||
|
||||
|
||||
def _list_path(path):
|
||||
"""Return the list of foldername (hashes generated) under the path"""
|
||||
return list(path.iterdir())
|
||||
|
||||
|
||||
def run_test(
    tmp_path,
    processor,
    llm: LLM,
    question: str,
    image_urls: list[Image.Image],
    expected_len: int,
    info: str,
):
    """Run one generation and verify the KV storage entry count.

    Processes `question` (plus optional images) through `llm`, then asserts
    that the number of entries under `tmp_path` (one per unique KV hash)
    equals `expected_len`.

    Args:
        tmp_path: storage directory the connector writes KV hashes into.
        processor: chat-template processor used to build the prompt.
        llm: the LLM under test.
        question: text portion of the prompt.
        image_urls: PIL images to attach to the prompt (may be empty).
        expected_len: expected entry count under `tmp_path` afterwards.
        info: human-readable purpose/details of this individual test.
    """
    print(f"***info: {info}***")
    print(f"**Expected storage path length after llm generate: {expected_len}**")
    process_prompt(processor, llm, question, image_urls)

    # Snapshot the count once so the printed value, the assertion, and the
    # failure message all agree.
    actual_len = _check_path_len(tmp_path)
    print(f"Path matched expected length: {actual_len}")
    print(f"Hashes under the storage path: {_list_path(tmp_path)}")

    # Single f-string message: the original passed a comma-separated tuple
    # of fragments, which pytest renders as a tuple instead of a sentence.
    assert actual_len == expected_len, (
        f"Expect storage path length {expected_len}; "
        f"but end up {actual_len} instead. "
        f"Info: {info}"
    )
|
||||
|
||||
|
||||
def process_prompt(processor, llm: LLM, question: str, image_urls: list[Image.Image]):
    """
    Form the chat prompt from the text and image input, run llm.generate,
    and print the generated text.
    """
    # Encode each image as a base64 data URL placeholder for the chat template.
    placeholders = [
        {
            "type": "image_url",
            "image_url": {"url": f"data:image;base64,{encode_image_base64(image_pil)}"},
        }
        for image_pil in image_urls
    ]

    messages = [
        {"role": "system", "content": "You are a helpful assistant."},
        {
            "role": "user",
            "content": [
                *placeholders,
                {"type": "text", "text": question},
            ],
        },
    ]

    # Render the chat template to a plain prompt string (no tokenization).
    prompt = processor.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )

    # Attach the raw PIL images as multi-modal data only when present.
    outputs = llm.generate(
        {
            "prompt": prompt,
            **({"multi_modal_data": {"image": [*image_urls]}} if image_urls else {}),
        },
        sampling_params=SAMPLING_PARAMS,
    )

    print("-" * 50)
    print("Output:")
    for o in outputs:
        generated_text = o.outputs[0].text
        print(generated_text)
    print("-" * 50)
|
||||
|
||||
|
||||
@pytest.mark.skipif(
    current_platform.is_rocm(),
    reason=(
        "hipErrorLaunchFailure when running this test, see issue:"
        "https://github.com/ROCm/pytorch/issues/2822"
    ),
)
def test_shared_storage_connector_hashes(tmp_path):
    """
    Tests that ExampleConnector saves KV to the storage locations
    with proper hashes; that are unique for inputs with identical text but
    different images (same size), or same multiple images but different orders.

    Each case asserts the running count of hash entries under tmp_path:
    a new unique input adds one entry, a repeated input adds none.
    """
    # Using tmp_path as the storage path to store KV
    print(f"KV storage path at: {str(tmp_path)}")

    # Configure the ExampleConnector
    kv_transfer_config = KVTransferConfig(
        kv_connector="ExampleConnector",
        kv_role="kv_both",
        kv_connector_extra_config={"shared_storage_path": str(tmp_path)},
    )

    engine_args = EngineArgs(
        model=MODEL_NAME,
        max_model_len=8192,
        max_num_seqs=1,
        gpu_memory_utilization=0.4,
        enforce_eager=True,
        kv_transfer_config=kv_transfer_config,
        limit_mm_per_prompt={"image": 2},
    )

    # don't put this import at the top level
    # it will call torch.cuda.device_count()
    from transformers import AutoProcessor  # noqa: F401

    # Create processor to handle the chat prompt
    processor = AutoProcessor.from_pretrained(MODEL_NAME)

    # Prepare images for the tests
    # Resize to the same size to check hashes correctness
    image_1 = ImageAsset("stop_sign").pil_image.resize((1280, 720))
    image_2 = ImageAsset("cherry_blossom").pil_image.resize((1280, 720))

    # Make sure that they are not the same picture
    assert image_1 != image_2, "The images should not be identical"

    # Create the LLM instance
    engine_args = asdict(engine_args)
    llm = LLM(**engine_args)

    # Prepare the input cases; expected_len is cumulative across cases.
    input_cases = [
        InputCase(
            text=TEXT_PROMPTS[0],
            img=[image_1],
            expected_len=1,
            info="image_1 single input the first time.",
        ),
        InputCase(
            text=TEXT_PROMPTS[0],
            img=[image_2],
            expected_len=2,
            info=(
                "image_2 single input the first time. "
                "It is in same pixel size with image_1, yet it "
                "should be able to form a new unique hash."
            ),
        ),
        InputCase(
            text=TEXT_PROMPTS[0],
            img=[image_1],
            expected_len=2,
            info=(
                "image_1 single input the 2nd time. "
                "It should not form another new hash."
            ),
        ),
        InputCase(
            text=TEXT_PROMPTS[0],
            img=[image_2],
            expected_len=2,
            info=(
                "image_2 single input the 2nd time. "
                "It should not form another new hash."
            ),
        ),
        InputCase(
            text=TEXT_PROMPTS[0],
            img=[image_1, image_2],
            expected_len=3,
            info="image_1 with image_2 input the first time.",
        ),
        InputCase(
            text=TEXT_PROMPTS[0],
            img=[image_2, image_1],
            expected_len=4,
            info="The image order is swapped. Should form new hash.",
        ),
        InputCase(
            text=TEXT_PROMPTS[0],
            img=[image_1, image_2],
            expected_len=4,
            info=(
                "[image_1, image_2] input the 2nd time. "
                "It should not form another new hash."
            ),
        ),
        InputCase(
            text=TEXT_PROMPTS[0],
            img=[image_2, image_1],
            expected_len=4,
            info=(
                "[image_2, image_1] input the 2nd time. "
                "It should not form another new hash."
            ),
        ),
        InputCase(
            text=TEXT_PROMPTS[0],
            img=[],
            expected_len=5,
            info="Pure text input test as a case-control",
        ),
        InputCase(
            text=TEXT_PROMPTS[0],
            img=[],
            expected_len=5,
            info="Identical pure text input as a case-control",
        ),
        InputCase(
            text=TEXT_PROMPTS[1],
            img=[],
            expected_len=6,
            info="Another pure text input as a case-control",
        ),
    ]

    # Run tests
    for case_id, (text, img, expected_len, info) in enumerate(input_cases):
        print("\n", "=" * 25, f"Below running input case: {case_id}", "=" * 25)
        run_test(tmp_path, processor, llm, text, img, expected_len, info)

    print("All tests passed successfully!")
|
||||
454
tests/v1/kv_connector/unit/test_invalid_blocks_correctness.py
Normal file
454
tests/v1/kv_connector/unit/test_invalid_blocks_correctness.py
Normal file
@@ -0,0 +1,454 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
"""
|
||||
Tests for correctness in invalid block handling.
|
||||
|
||||
These tests verify correct behavior in three scenarios:
|
||||
1. Sync recompute case: Blocks should not be freed for running requests
|
||||
that need to recompute invalid blocks
|
||||
2. Sync fail case: Invalid blocks must be evicted from cache when request fails
|
||||
3. Async recompute case: Invalid blocks should not be cached after transfer
|
||||
"""
|
||||
|
||||
from collections.abc import Callable
|
||||
from unittest.mock import Mock
|
||||
|
||||
import pytest
|
||||
|
||||
from vllm.v1.core.sched.scheduler import Scheduler
|
||||
from vllm.v1.request import FinishReason, Request, RequestStatus
|
||||
|
||||
from .utils import (
|
||||
create_model_runner_output,
|
||||
create_request,
|
||||
create_scheduler,
|
||||
create_vllm_config,
|
||||
)
|
||||
|
||||
pytestmark = pytest.mark.cpu_test
|
||||
|
||||
|
||||
def _make_get_num_new_matched_tokens(
    req_num_new_matched_tokens: dict[str, int],
    async_load: bool,
) -> Callable[[Request, int], tuple[int, bool]]:
    """Build a stub for ``KVConnector.get_num_new_matched_tokens``.

    The returned callable looks up the per-request external-token count in
    ``req_num_new_matched_tokens`` (defaulting to 0 for unknown requests) and
    reports ``async_load`` as the "load is asynchronous" flag.
    """

    def get_num_new_matched_tokens(request: Request, _: int) -> tuple[int, bool]:
        matched = req_num_new_matched_tokens.get(request.request_id, 0)
        return matched, async_load

    return get_num_new_matched_tokens
|
||||
|
||||
|
||||
@pytest.fixture
def fail_scheduler():
    """Scheduler configured with ``kv_load_failure_policy='fail'``."""
    config = create_vllm_config()
    config.kv_transfer_config.kv_load_failure_policy = "fail"
    return create_scheduler(config)
|
||||
|
||||
|
||||
@pytest.fixture
def recompute_scheduler():
    """Scheduler configured with ``kv_load_failure_policy='recompute'``."""
    config = create_vllm_config()
    config.kv_transfer_config.kv_load_failure_policy = "recompute"
    return create_scheduler(config)
|
||||
|
||||
|
||||
def test_sync_recompute_blocks_not_freed_for_running_requests(
    recompute_scheduler: Scheduler,
):
    """
    Test sync recompute case - blocks must not be freed for running requests.

    When a running request has invalid blocks and retry_policy is 'recompute':
    1. Request should remain in RUNNING state
    2. num_computed_tokens should be truncated to invalid block boundary
    3. Blocks should NOT be freed (request still needs them for recomputation)
    4. Request should remain in scheduler.requests and scheduler.running
    """
    # Sizes are in blocks; token counts are derived from the scheduler's
    # block_size below.
    num_prompt_blocks = 100
    num_external_computed_blocks = 99
    invalid_block_idx = 50

    num_prompt_tokens = num_prompt_blocks * recompute_scheduler.block_size
    num_external_computed_tokens = (
        num_external_computed_blocks * recompute_scheduler.block_size
    )

    request = create_request(num_tokens=num_prompt_tokens)
    recompute_scheduler.add_request(request=request)

    req_num_new_matched_tokens = {
        request.request_id: num_external_computed_tokens,
    }

    # mock connector indicating sync load (async_load=False)
    recompute_scheduler.connector = Mock()
    recompute_scheduler.connector.get_num_new_matched_tokens.side_effect = (
        _make_get_num_new_matched_tokens(req_num_new_matched_tokens, False)
    )
    recompute_scheduler.connector.request_finished.return_value = (False, None)
    recompute_scheduler.connector.take_events.return_value = ()

    scheduler_output = recompute_scheduler.schedule()

    # request should be running with sync KV load
    assert len(recompute_scheduler.running) == 1
    assert len(scheduler_output.scheduled_new_reqs) == 1
    assert request.status == RequestStatus.RUNNING

    # get the allocated block IDs before invalid blocks are reported
    req_block_ids = scheduler_output.scheduled_new_reqs[0].block_ids[0]
    invalid_block_ids = {req_block_ids[invalid_block_idx]}

    # store original num_computed_tokens for comparison
    original_num_computed_tokens = request.num_computed_tokens

    model_runner_output = create_model_runner_output(
        [request],
        invalid_block_ids=invalid_block_ids,
        use_eos=False,  # not finished - should continue running
    )

    outputs = recompute_scheduler.update_from_output(
        scheduler_output, model_runner_output
    )

    # critical assertions for recompute case:

    # 1. request should still be RUNNING (not finished, not aborted)
    assert request.status == RequestStatus.RUNNING, (
        f"Request should remain RUNNING for recompute, got {request.status}"
    )

    # 2. num_computed_tokens should be truncated to first invalid block
    expected_truncated_tokens = invalid_block_idx * recompute_scheduler.block_size
    assert request.num_computed_tokens == expected_truncated_tokens, (
        f"num_computed_tokens should be truncated to {expected_truncated_tokens}, "
        f"got {request.num_computed_tokens}"
    )
    assert request.num_computed_tokens < original_num_computed_tokens, (
        "num_computed_tokens should be reduced after invalid block detection"
    )

    # 3. no output should be generated (request is still running)
    # the request should be skipped in the output loop
    assert len(outputs) == 0 or request.request_id not in [
        out.request_id for outs in outputs.values() for out in outs.outputs
    ], "No output should be generated for recompute requests"

    # 4. request should still be in running queue
    assert request in recompute_scheduler.running, (
        "Request should remain in running queue for recomputation"
    )

    # 5. request should still be in scheduler.requests (not deleted)
    assert request.request_id in recompute_scheduler.requests, (
        "Request should not be deleted from scheduler.requests"
    )

    # 6. blocks should NOT be freed - verify blocks are still allocated
    # (get_block_ids raises KeyError if the request was dropped entirely)
    try:
        allocated_blocks = recompute_scheduler.kv_cache_manager.get_block_ids(
            request.request_id
        )
        assert allocated_blocks is not None
        assert len(allocated_blocks[0]) > 0, (
            "Blocks should still be allocated for recomputation"
        )
    except KeyError:
        pytest.fail(
            "Blocks were freed incorrectly! Running requests need their blocks "
            "to recompute invalid portions."
        )

    # 7. verify request can be rescheduled in next step
    scheduler_output_2 = recompute_scheduler.schedule()

    # request should appear in the new schedule to recompute invalid blocks
    scheduled_req_ids = [
        req.request_id for req in scheduler_output_2.scheduled_new_reqs
    ]
    if scheduler_output_2.num_scheduled_tokens:
        scheduled_req_ids.extend(scheduler_output_2.num_scheduled_tokens.keys())

    assert (
        request.request_id in scheduled_req_ids or len(recompute_scheduler.running) > 0
    ), "Request should be reschedulable for recomputation"
|
||||
|
||||
|
||||
def test_sync_fail_invalid_blocks_evicted(fail_scheduler: Scheduler):
    """
    Test sync fail case - invalid blocks must be evicted from cache.

    When a request fails with policy='fail' and has invalid blocks from sync loading:
    1. Request should be finished with FINISHED_ERROR
    2. Invalid blocks should be evicted from the KV cache
    3. Valid blocks (if shared) should remain in cache
    4. Future requests should not reuse the invalid blocks

    This test verifies that invalid blocks are properly evicted to prevent
    cache corruption and reuse of invalid data.
    """
    # Sizes in blocks; token counts derived from the scheduler block_size.
    num_prompt_blocks = 100
    num_external_computed_blocks = 99
    invalid_block_idx = 50

    num_prompt_tokens = num_prompt_blocks * fail_scheduler.block_size
    num_external_computed_tokens = (
        num_external_computed_blocks * fail_scheduler.block_size
    )

    request = create_request(num_tokens=num_prompt_tokens)
    fail_scheduler.add_request(request=request)

    req_num_new_matched_tokens = {
        request.request_id: num_external_computed_tokens,
    }

    # mock connector indicating sync load (async_load=False)
    fail_scheduler.connector = Mock()
    fail_scheduler.connector.get_num_new_matched_tokens.side_effect = (
        _make_get_num_new_matched_tokens(req_num_new_matched_tokens, False)
    )
    fail_scheduler.connector.request_finished.return_value = (False, None)
    fail_scheduler.connector.take_events.return_value = ()

    scheduler_output = fail_scheduler.schedule()

    # request should be running with sync KV load
    assert len(fail_scheduler.running) == 1
    assert request.status == RequestStatus.RUNNING

    # get allocated block IDs
    req_block_ids = scheduler_output.scheduled_new_reqs[0].block_ids[0]
    invalid_block_id = req_block_ids[invalid_block_idx]
    invalid_block_ids = {invalid_block_id}

    # verify the block is in the block pool before we report it as invalid
    # (keep a reference so we can inspect its hash after eviction)
    block = fail_scheduler.kv_cache_manager.block_pool.blocks[invalid_block_id]
    assert block is not None

    # report invalid blocks - request should fail
    model_runner_output = create_model_runner_output(
        [request],
        invalid_block_ids=invalid_block_ids,
        use_eos=True,
    )

    outputs = fail_scheduler.update_from_output(scheduler_output, model_runner_output)

    # verify request is finished with error
    assert request.status == RequestStatus.FINISHED_ERROR
    assert request.get_finished_reason() == FinishReason.ERROR

    # verify output is generated
    assert len(outputs) == 1
    engine_outputs = next(iter(outputs.values()))
    assert len(engine_outputs.outputs) == 1
    output = engine_outputs.outputs[0]
    assert output.request_id == request.request_id
    assert output.finish_reason == FinishReason.ERROR

    # verify the request was removed from scheduler
    assert request.request_id not in fail_scheduler.requests
    assert len(fail_scheduler.running) == 0

    # critical: verify invalid block was actually freed from cache
    # this is the key assertion - the invalid block should no longer be
    # tracked by the KV cache manager for this request
    # if it's still there, a future request could reuse the invalid data
    try:
        block_ids = fail_scheduler.kv_cache_manager.get_block_ids(request.request_id)
        # if we get here, check if blocks were actually freed
        if block_ids is not None and len(block_ids[0]) > 0:
            pytest.fail(
                f"Invalid blocks still tracked for finished request! "
                f"Request {request.request_id} should have been freed but "
                f"still has {len(block_ids[0])} blocks allocated."
            )
        # blocks list exists but is empty - this is fine, they were freed
    except KeyError:
        # expected - request completely removed from tracking
        pass

    # critical: verify invalid block was evicted from prefix cache
    # the block should no longer have a hash (hash is reset on eviction)
    assert block.block_hash is None, (
        f"Invalid block {invalid_block_id} should have been evicted from cache "
        f"(hash should be None), but hash is still {block.block_hash}"
    )
|
||||
|
||||
|
||||
def test_async_recompute_blocks_not_cached_when_invalid(
    recompute_scheduler: Scheduler,
):
    """
    Test async recompute case - invalid blocks not cached after transfer.

    When async KV loading has invalid blocks and retry_policy is 'recompute':
    1. Blocks are allocated but not cached yet
    2. When async transfer completes, only valid blocks should be cached
    3. Invalid blocks should never enter the prefix cache

    This test verifies correctness, the failed_recving_kv_req_ids protection
    ensures only valid blocks are cached when the transfer completes, and we
    only evict blocks from cache that are already hashed in the block table.
    """
    from unittest.mock import patch

    # Sizes in blocks; token counts derived from the scheduler block_size.
    num_prompt_blocks = 100
    num_external_computed_blocks = 99
    invalid_block_idx = 50

    num_prompt_tokens = num_prompt_blocks * recompute_scheduler.block_size
    num_external_computed_tokens = (
        num_external_computed_blocks * recompute_scheduler.block_size
    )

    request = create_request(num_tokens=num_prompt_tokens)
    recompute_scheduler.add_request(request=request)

    req_num_new_matched_tokens = {
        request.request_id: num_external_computed_tokens,
    }

    # mock connector indicating async load (async_load=True)
    recompute_scheduler.connector = Mock()
    recompute_scheduler.connector.get_num_new_matched_tokens.side_effect = (
        _make_get_num_new_matched_tokens(req_num_new_matched_tokens, True)
    )
    recompute_scheduler.connector.request_finished.return_value = (False, None)
    recompute_scheduler.connector.take_events.return_value = ()

    scheduler_output = recompute_scheduler.schedule()

    # request should be waiting for remote KVs
    assert len(recompute_scheduler.waiting) == 1
    assert request.status == RequestStatus.WAITING_FOR_REMOTE_KVS
    assert request.num_computed_tokens == 0

    # get the allocated block IDs
    (req_block_ids,) = recompute_scheduler.kv_cache_manager.get_block_ids(
        request.request_id
    )
    invalid_block_id = req_block_ids[invalid_block_idx]
    invalid_block_ids = {invalid_block_id}

    # get the block object to verify it's not cached yet and stays uncached
    block = recompute_scheduler.kv_cache_manager.block_pool.blocks[invalid_block_id]

    # verify block has no hash before invalid blocks are reported
    assert block.block_hash is None, (
        "Async loading blocks should not be cached yet (no hash)"
    )

    # report invalid blocks (transfer not finished yet)
    model_runner_output = create_model_runner_output(
        reqs=[],
        finished_recving=None,  # transfer NOT finished
        invalid_block_ids=invalid_block_ids,
        use_eos=False,
    )

    # critical: spy on evict_blocks to verify it's NOT called for async blocks
    original_evict_blocks = recompute_scheduler.kv_cache_manager.evict_blocks
    evict_blocks_calls = []

    def evict_blocks_spy(block_ids):
        # record each call's block-id set, then delegate to the real method
        evict_blocks_calls.append(set(block_ids))
        return original_evict_blocks(block_ids)

    with patch.object(
        recompute_scheduler.kv_cache_manager, "evict_blocks", evict_blocks_spy
    ):
        recompute_scheduler.update_from_output(scheduler_output, model_runner_output)

    # verify evict_blocks was NOT called (async blocks excluded from eviction)
    assert len(evict_blocks_calls) == 0, (
        f"evict_blocks should not be called for async-only invalid blocks, "
        f"but was called {len(evict_blocks_calls)} time(s) with {evict_blocks_calls}"
    )

    # request should still be waiting (not finished with error due to recompute policy)
    assert request.status == RequestStatus.WAITING_FOR_REMOTE_KVS
    assert request.request_id in recompute_scheduler.failed_recving_kv_req_ids

    # verify num_computed_tokens was truncated to before invalid block
    expected_valid_tokens = invalid_block_idx * recompute_scheduler.block_size
    assert request.num_computed_tokens == expected_valid_tokens

    # verify invalid block still has no hash (was not evicted)
    assert block.block_hash is None, (
        f"Async loading blocks shouldn't be cached or evicted. "
        f"Block {invalid_block_id} hash should be None but is {block.block_hash}"
    )

    # now simulate async transfer completing
    model_runner_output_2 = create_model_runner_output(
        reqs=[],
        finished_recving={request.request_id},
        invalid_block_ids=None,
        use_eos=False,
    )

    recompute_scheduler.update_from_output(scheduler_output, model_runner_output_2)

    # verify request is now marked as finished receiving and ready to be processed
    assert request.request_id in recompute_scheduler.finished_recving_kv_req_ids
    assert request.request_id in recompute_scheduler.failed_recving_kv_req_ids

    # critical: verify invalid block still has no hash before recompute
    # the async transfer invalid data was never cached
    assert block.block_hash is None, (
        f"Invalid block {invalid_block_id} should not be cached before recompute "
        f"(hash should be None), but hash is {block.block_hash}"
    )

    # critical end-to-end test: spy on cache_blocks to verify it's called with
    # the truncated num_computed_tokens value
    original_cache_blocks = recompute_scheduler.kv_cache_manager.cache_blocks
    cache_blocks_calls = []

    def cache_blocks_spy(req, num_tokens):
        # record (request_id, num_tokens) per call, then delegate
        cache_blocks_calls.append((req.request_id, num_tokens))
        return original_cache_blocks(req, num_tokens)

    with patch.object(
        recompute_scheduler.kv_cache_manager, "cache_blocks", cache_blocks_spy
    ):
        # call schedule() again - this triggers _update_waiting_for_remote_kv()
        # which should call cache_blocks with the truncated value
        recompute_scheduler.schedule()

    # verify cache_blocks was called with the truncated value
    assert len(cache_blocks_calls) == 1, (
        f"cache_blocks should be called exactly once, "
        f"got {len(cache_blocks_calls)} calls"
    )
    cached_req_id, cached_num_tokens = cache_blocks_calls[0]
    assert cached_req_id == request.request_id
    assert cached_num_tokens == expected_valid_tokens, (
        f"cache_blocks should be called with truncated value {expected_valid_tokens}, "
        f"but was called with {cached_num_tokens}"
    )

    # request should now be RUNNING (scheduled immediately after transfer completes)
    # the flow is: WAITING_FOR_REMOTE_KVS -> WAITING -> RUNNING in same schedule() call
    assert request.status == RequestStatus.RUNNING

    # num_computed_tokens should be >= expected_valid_tokens because the scheduler
    # will schedule additional new tokens (up to max_num_batched_tokens) for the request
    assert request.num_computed_tokens >= expected_valid_tokens, (
        f"num_computed_tokens should be at least {expected_valid_tokens}, "
        f"got {request.num_computed_tokens}"
    )

    # request should no longer be in the failed/finished receiving sets
    assert request.request_id not in recompute_scheduler.failed_recving_kv_req_ids
    assert request.request_id not in recompute_scheduler.finished_recving_kv_req_ids

    # request should be in the running queue
    assert request in recompute_scheduler.running
|
||||
60
tests/v1/kv_connector/unit/test_kv_connector_lifecyle.py
Normal file
60
tests/v1/kv_connector/unit/test_kv_connector_lifecyle.py
Normal file
@@ -0,0 +1,60 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from vllm.distributed.kv_transfer.kv_connector.v1.example_connector import ( # noqa: E501
|
||||
ExampleConnectorMetadata,
|
||||
)
|
||||
from vllm.distributed.kv_transfer.kv_transfer_state import (
|
||||
ensure_kv_transfer_initialized,
|
||||
get_kv_transfer_group,
|
||||
)
|
||||
from vllm.v1.core.sched.output import CachedRequestData, SchedulerOutput
|
||||
from vllm.v1.worker.kv_connector_model_runner_mixin import KVConnectorModelRunnerMixin
|
||||
|
||||
# Importing utils registers TestExampleConnector with the factory
|
||||
from .utils import create_vllm_config
|
||||
|
||||
|
||||
def _make_empty_scheduler_output():
    """Return a SchedulerOutput describing an empty scheduling step.

    All request/token collections are empty; the KV-connector metadata is an
    empty ExampleConnectorMetadata so the model-runner mixin still has
    something to bind and clear.
    """
    empty_step = dict(
        scheduled_new_reqs=[],
        scheduled_cached_reqs=CachedRequestData.make_empty(),
        num_scheduled_tokens={},
        total_num_scheduled_tokens=0,
        scheduled_spec_decode_tokens={},
        scheduled_encoder_inputs={},
        num_common_prefix_blocks=[],
        finished_req_ids=set(),
        free_encoder_mm_hashes=[],
        kv_connector_metadata=ExampleConnectorMetadata(),
    )
    return SchedulerOutput(**empty_step)
|
||||
|
||||
|
||||
def test_kv_connector_mixin_clears_metadata():
    """Verify the model-runner mixin binds and then clears connector metadata.

    Uses the TestExampleConnector (registered by importing .utils) through the
    global KV-transfer group, drives the no-forward path, and checks the
    connector's call record.
    """
    vllm_config = create_vllm_config()
    vllm_config.kv_transfer_config.kv_connector = "TestExampleConnector"
    vllm_config.kv_transfer_config.kv_role = "kv_both"
    vllm_config.kv_transfer_config.kv_connector_extra_config["name"] = "unit"

    # Initialize the global connector instance
    ensure_kv_transfer_initialized(vllm_config)

    try:
        # Minimal scheduler output with empty metadata; mixin should still
        # bind/clear metadata even if no loads happen
        scheduler_output = _make_empty_scheduler_output()

        # Invoke the no-forward path which uses the mixin context manager
        KVConnectorModelRunnerMixin.kv_connector_no_forward(
            scheduler_output, vllm_config
        )

        # Verify clear_connector_metadata was called on the connector
        connector = get_kv_transfer_group()
        assert connector._connector_metadata is None
        # Test connector wrapper records method calls
        assert connector.call_record.get("bind_connector_metadata", 0) == 1
        assert connector.call_record.get("clear_connector_metadata", 0) == 1
    finally:
        # Ensure we clean up the global connector between tests
        KVConnectorModelRunnerMixin.ensure_kv_transfer_shutdown()
|
||||
335
tests/v1/kv_connector/unit/test_kv_load_failure_recovery.py
Normal file
335
tests/v1/kv_connector/unit/test_kv_load_failure_recovery.py
Normal file
@@ -0,0 +1,335 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from collections.abc import Callable
|
||||
from unittest.mock import Mock
|
||||
|
||||
import pytest
|
||||
|
||||
from vllm.v1.core.sched.scheduler import Scheduler
|
||||
from vllm.v1.request import Request, RequestStatus
|
||||
|
||||
from .utils import (
|
||||
create_model_runner_output,
|
||||
create_request,
|
||||
create_scheduler,
|
||||
create_vllm_config,
|
||||
)
|
||||
|
||||
|
||||
def _make_get_num_new_matched_tokens(
    req_num_new_matched_tokens: dict[str, int],
    async_load: bool,
) -> Callable[[Request, int], tuple[int, bool]]:
    """Build a stub for ``KVConnector.get_num_new_matched_tokens``.

    The returned callable looks up the per-request external-token count in
    ``req_num_new_matched_tokens`` (0 for unknown requests) and reports
    ``async_load`` as the "load is asynchronous" flag.
    """

    def get_num_new_matched_tokens(request: Request, _: int) -> tuple[int, bool]:
        value = req_num_new_matched_tokens.get(request.request_id, 0)
        return value, async_load

    return get_num_new_matched_tokens
|
||||
|
||||
|
||||
@pytest.fixture
def scheduler():
    """Scheduler built from a default vLLM config."""
    return create_scheduler(create_vllm_config())
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    "num_prompt_blocks,num_external_computed_blocks,invalid_block_idxs",
    [
        (100, 99, {0, 98}),
        (100, 99, {50, 98}),
        (100, 99, {98}),
    ],
)
def test_async_load_failure(
    scheduler: Scheduler,
    num_prompt_blocks: int,
    num_external_computed_blocks: int,
    invalid_block_idxs: set[int],
):
    """Async KV load failure: only the failing request is truncated.

    Three requests wait on async remote-KV loads; request2 has some invalid
    blocks reported. All three should remain WAITING_FOR_REMOTE_KVS, with
    request2's num_computed_tokens cut back to the first invalid block and
    its id recorded in failed_recving_kv_req_ids.
    """
    assert num_prompt_blocks >= num_external_computed_blocks

    num_prompt_tokens = num_prompt_blocks * scheduler.block_size
    num_external_computed_tokens = num_external_computed_blocks * scheduler.block_size

    request1 = create_request(num_tokens=num_prompt_tokens)
    scheduler.add_request(request=request1)
    request2 = create_request(num_tokens=num_prompt_tokens)
    scheduler.add_request(request=request2)
    request3 = create_request(num_tokens=num_prompt_tokens)
    scheduler.add_request(request=request3)

    # Mock KV connector method.
    # req_id -> num_external_computed_tokens
    req_num_new_matched_tokens = {
        request1.request_id: num_external_computed_tokens,
        request2.request_id: num_external_computed_tokens,
        request3.request_id: num_external_computed_tokens,
    }

    scheduler.connector = Mock()
    scheduler.connector.get_num_new_matched_tokens.side_effect = (
        _make_get_num_new_matched_tokens(req_num_new_matched_tokens, async_load=True)
    )
    scheduler.connector.take_events.return_value = ()

    scheduler_output = scheduler.schedule()

    # all three requests should be parked waiting on the async load
    assert len(scheduler.waiting) == 3
    for request in scheduler.waiting:
        assert request.num_computed_tokens == 0
        assert request.status == RequestStatus.WAITING_FOR_REMOTE_KVS
    assert scheduler.connector.get_num_new_matched_tokens.call_count == 3

    # Simulate a failure in loading some of request2 blocks.
    (req2_block_ids,) = scheduler.kv_cache_manager.get_block_ids(request2.request_id)
    invalid_block_ids = {req2_block_ids[i] for i in invalid_block_idxs}
    model_runner_output = create_model_runner_output(
        reqs=[],
        finished_recving={request1.request_id, request3.request_id},
        invalid_block_ids=invalid_block_ids,
        use_eos=True,
    )

    scheduler.update_from_output(scheduler_output, model_runner_output)

    # truncation happens at the earliest invalid block
    min_invalid_block_idx = min(invalid_block_idxs)

    assert len(scheduler.waiting) == 3
    for request in scheduler.waiting:
        if request.request_id == request2.request_id:
            assert request.num_computed_tokens == (
                min_invalid_block_idx * scheduler.block_size
            )
        else:
            assert request.num_computed_tokens == 0
        assert request.status == RequestStatus.WAITING_FOR_REMOTE_KVS
    assert scheduler.failed_recving_kv_req_ids == {request2.request_id}
    assert scheduler.connector.get_num_new_matched_tokens.call_count == 3
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    "num_prompt_blocks,num_external_computed_blocks,invalid_block_idxs",
    [
        (100, 99, {0, 98}),
        (100, 99, {50, 98}),
        (100, 99, {98}),
    ],
)
def test_sync_load_failure(
    scheduler: Scheduler,
    num_prompt_blocks: int,
    num_external_computed_blocks: int,
    invalid_block_idxs: set[int],
):
    """Sync KV load failure: only the failing request stays running.

    Three requests are scheduled with synchronously loaded external tokens;
    request2 has invalid blocks reported alongside EOS for all three. The
    other two finish, while request2 remains running with num_computed_tokens
    truncated to the first invalid block for recomputation.
    """
    assert num_prompt_blocks >= num_external_computed_blocks

    num_prompt_tokens = num_prompt_blocks * scheduler.block_size
    num_external_computed_tokens = num_external_computed_blocks * scheduler.block_size

    request1 = create_request(num_tokens=num_prompt_tokens)
    scheduler.add_request(request=request1)
    request2 = create_request(num_tokens=num_prompt_tokens)
    scheduler.add_request(request=request2)
    request3 = create_request(num_tokens=num_prompt_tokens)
    scheduler.add_request(request=request3)

    # Mock KV connector method.
    # req_id -> num_external_computed_tokens
    req_num_new_matched_tokens = {
        request1.request_id: num_external_computed_tokens,
        request2.request_id: num_external_computed_tokens,
        request3.request_id: num_external_computed_tokens,
    }

    scheduler.connector = Mock()
    scheduler.connector.get_num_new_matched_tokens.side_effect = (
        _make_get_num_new_matched_tokens(req_num_new_matched_tokens, async_load=False)
    )
    scheduler.connector.request_finished.return_value = (False, None)
    scheduler.connector.take_events.return_value = ()

    scheduler_output = scheduler.schedule()

    # req_id -> num_computed_tokens
    expected_computed_tokens = {
        request1.request_id: num_external_computed_tokens,
        request2.request_id: num_external_computed_tokens,
        request3.request_id: num_external_computed_tokens,
    }

    assert len(scheduler.running) == 3
    assert len(scheduler_output.scheduled_new_reqs) == 3
    for request in scheduler_output.scheduled_new_reqs:
        assert request.num_computed_tokens == expected_computed_tokens[request.req_id]
    assert scheduler.connector.get_num_new_matched_tokens.call_count == 3

    # Simulate a failure in loading some of request2 blocks.
    req2_block_ids = scheduler_output.scheduled_new_reqs[1].block_ids[0]
    invalid_block_ids = {req2_block_ids[i] for i in invalid_block_idxs}
    model_runner_output = create_model_runner_output(
        [request1, request2, request3],
        invalid_block_ids=invalid_block_ids,
        use_eos=True,
    )

    scheduler.update_from_output(scheduler_output, model_runner_output)

    # request1/request3 finished via EOS; request2 stays to recompute from the
    # earliest invalid block
    assert len(scheduler.running) == 1
    assert scheduler.running[0].request_id == request2.request_id
    assert scheduler.running[0].num_computed_tokens == (
        min(invalid_block_idxs) * scheduler.block_size
    )
    assert scheduler.connector.get_num_new_matched_tokens.call_count == 3
    assert scheduler.connector.request_finished.call_count == 2
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    "num_prompt_blocks,"
    "num_external_computed_blocks,"
    "num_common_prefix_blocks,"
    "invalid_block_idxs",
    [
        (100, 99, 50, {0, 49}),
        (100, 99, 50, {25, 49}),
        (100, 99, 50, {49}),
    ],
)
def test_sync_load_failure_with_shared_blocks(
    scheduler: Scheduler,
    num_prompt_blocks: int,
    num_external_computed_blocks: int,
    num_common_prefix_blocks: int,
    invalid_block_idxs: set[int],
):
    """Sync KV load failure on blocks shared via a common prefix.

    request1 loads external tokens; request2 shares a common prefix with it.
    When shared prefix blocks are reported invalid, request1 is truncated to
    the first invalid block while request2 keeps its prefix-cache hit (the
    common prefix will be recomputed by request1).
    """
    assert num_prompt_blocks >= num_external_computed_blocks >= num_common_prefix_blocks

    num_prompt_tokens = num_prompt_blocks * scheduler.block_size
    num_external_computed_tokens = num_external_computed_blocks * scheduler.block_size
    common_prefix_len = num_common_prefix_blocks * scheduler.block_size

    request1 = create_request(
        num_tokens=num_prompt_tokens, common_prefix_len=common_prefix_len
    )
    scheduler.add_request(request=request1)
    request2 = create_request(
        num_tokens=num_prompt_tokens, common_prefix_len=common_prefix_len
    )
    scheduler.add_request(request=request2)

    # Mock KV connector method.
    # req_id -> num_external_computed_tokens
    # (only request1 gets external tokens; request2 relies on the shared prefix)
    req_num_new_matched_tokens = {
        request1.request_id: num_external_computed_tokens,
    }

    scheduler.connector = Mock()
    scheduler.connector.get_num_new_matched_tokens.side_effect = (
        _make_get_num_new_matched_tokens(req_num_new_matched_tokens, async_load=False)
    )
    scheduler.connector.take_events.return_value = ()

    scheduler_output = scheduler.schedule()

    # req_id -> num_computed_tokens
    expected_computed_tokens = {
        request1.request_id: num_external_computed_tokens,
        request2.request_id: common_prefix_len,
    }

    assert len(scheduler.running) == 2
    assert len(scheduler_output.scheduled_new_reqs) == 2
    for request in scheduler_output.scheduled_new_reqs:
        assert request.num_computed_tokens == expected_computed_tokens[request.req_id]
    assert scheduler.connector.get_num_new_matched_tokens.call_count == 2

    # Simulate a failure in loading some of the shared blocks.
    req1_block_ids = scheduler_output.scheduled_new_reqs[0].block_ids[0]
    invalid_block_ids = {req1_block_ids[i] for i in invalid_block_idxs}
    model_runner_output = create_model_runner_output(
        [request1, request2], invalid_block_ids=invalid_block_ids, use_eos=True
    )

    scheduler.update_from_output(scheduler_output, model_runner_output)

    # req_id -> num_computed_tokens
    # all the common prefix blocks will be computed by request1
    expected_computed_tokens = {
        request1.request_id: min(invalid_block_idxs) * scheduler.block_size,
        request2.request_id: common_prefix_len,
    }

    assert len(scheduler.running) == 2
    for request in scheduler.running:
        assert (
            request.num_computed_tokens == expected_computed_tokens[request.request_id]
        )
    assert scheduler.connector.get_num_new_matched_tokens.call_count == 2
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    "num_prompt_blocks,num_external_computed_blocks,invalid_block_idxs",
    [
        # Same index set in two insertion orders: recovery must not depend on
        # the order in which block failures are reported.
        (100, 99, {0, 50, 98}),
        (100, 99, {98, 50, 0}),
    ],
)
def test_async_progressive_load_failure(
    scheduler: Scheduler,
    num_prompt_blocks: int,
    num_external_computed_blocks: int,
    invalid_block_idxs: set[int],
):
    """Exercise recovery when an async remote-KV load fails block by block.

    A single request is scheduled with all-but-one of its prompt blocks
    reported as externally computed (async load). Each loop iteration then
    marks one block as invalid, and we assert the scheduler keeps the request
    in WAITING_FOR_REMOTE_KVS, records it in failed_recving_kv_req_ids, and
    rolls num_computed_tokens back to the lowest failed block boundary —
    without re-querying the connector.
    """
    assert num_prompt_blocks >= num_external_computed_blocks

    num_prompt_tokens = num_prompt_blocks * scheduler.block_size
    num_external_computed_tokens = num_external_computed_blocks * scheduler.block_size

    request = create_request(num_tokens=num_prompt_tokens)
    scheduler.add_request(request=request)

    # Mock KV connector method.
    # req_id -> num_external_computed_tokens
    req_num_new_matched_tokens = {
        request.request_id: num_external_computed_tokens,
    }

    scheduler.connector = Mock()
    scheduler.connector.get_num_new_matched_tokens.side_effect = (
        _make_get_num_new_matched_tokens(req_num_new_matched_tokens, async_load=True)
    )
    scheduler.connector.take_events.return_value = ()

    scheduler_output = scheduler.schedule()

    # With async_load=True the request must park in the waiting queue until
    # the remote KVs arrive; nothing is computed yet.
    assert len(scheduler.waiting) == 1
    assert scheduler.waiting.peek_request().request_id == request.request_id
    assert request.num_computed_tokens == 0
    assert request.status == RequestStatus.WAITING_FOR_REMOTE_KVS
    assert scheduler.connector.get_num_new_matched_tokens.call_count == 1

    # Sentinel strictly above every tested index so min() below tracks the
    # smallest block index that has failed so far.
    min_invalid_block_idx = max(invalid_block_idxs) + 1
    # Simulate failures when progressively loading request blocks.
    for invalid_block_idx in invalid_block_idxs:
        (req_block_ids,) = scheduler.kv_cache_manager.get_block_ids(request.request_id)
        invalid_block_ids = {req_block_ids[invalid_block_idx]}
        model_runner_output = create_model_runner_output(
            reqs=[],
            finished_recving=set(),
            invalid_block_ids=invalid_block_ids,
            use_eos=True,
        )

        scheduler.update_from_output(scheduler_output, model_runner_output)

        min_invalid_block_idx = min(min_invalid_block_idx, invalid_block_idx)

        # The request stays parked; computed tokens are capped at the lowest
        # failed block boundary seen so far.
        assert len(scheduler.waiting) == 1
        assert scheduler.waiting.peek_request().request_id == request.request_id
        assert request.num_computed_tokens == (
            min_invalid_block_idx * scheduler.block_size
        )
        assert request.status == RequestStatus.WAITING_FOR_REMOTE_KVS
        assert scheduler.failed_recving_kv_req_ids == {request.request_id}
        # No re-query of the connector on failure.
        assert scheduler.connector.get_num_new_matched_tokens.call_count == 1
|
||||
756
tests/v1/kv_connector/unit/test_lmcache_connector.py
Normal file
756
tests/v1/kv_connector/unit/test_lmcache_connector.py
Normal file
@@ -0,0 +1,756 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from vllm.distributed.kv_events import BlockStored
|
||||
from vllm.distributed.kv_transfer.kv_connector.v1.lmcache_connector import (
|
||||
LMCacheConnectorV1,
|
||||
LMCacheKVEvents,
|
||||
)
|
||||
from vllm.v1.outputs import KVConnectorOutput
|
||||
|
||||
|
||||
@pytest.fixture
def mock_lmcache_engine_event():
    """Create a mock event object that mimics what the lmcache engine returns.

    The connector only reads attributes off the engine event, so any object
    exposing the six expected attributes is a valid stand-in.
    """

    class _FakeEngineEvent:
        def __init__(self, **attrs):
            for name, value in attrs.items():
                setattr(self, name, value)

    return _FakeEngineEvent(
        block_hashes=["hash1", "hash2"],
        parent_block_hash="parent_hash",
        token_ids=[1, 2, 3, 4],
        lora_id=None,
        block_size=16,
        medium="GPU",
    )
|
||||
|
||||
|
||||
@pytest.fixture
def mock_connector():
    """Create a mock LMCacheConnectorV1 instance with mocked dependencies.

    The mock is spec'd against LMCacheConnectorV1 so attribute access is
    constrained to the real interface, but the three methods under test are
    rebound to their *real* implementations so their logic actually runs
    against the mocked state (_kv_cache_events, _lmcache_engine).
    """
    connector = MagicMock(spec=LMCacheConnectorV1)
    # State the real methods read/write; starts empty.
    connector._kv_cache_events = None
    connector._lmcache_engine = MagicMock()

    # Make the methods use the real implementation.
    # Function.__get__(instance, cls) produces a bound method, so calling
    # connector.<method>() executes the genuine class code with the mock
    # standing in as `self`.
    connector.get_kv_connector_kv_cache_events = (
        LMCacheConnectorV1.get_kv_connector_kv_cache_events.__get__(
            connector, LMCacheConnectorV1
        )
    )
    connector.update_connector_output = (
        LMCacheConnectorV1.update_connector_output.__get__(
            connector, LMCacheConnectorV1
        )
    )
    connector.take_events = LMCacheConnectorV1.take_events.__get__(
        connector, LMCacheConnectorV1
    )

    return connector
|
||||
|
||||
|
||||
class TestGetKVConnectorKVCacheEvents:
    """Test get_kv_connector_kv_cache_events method.

    The method is expected to pull raw events from the lmcache engine and
    convert each into a vLLM BlockStored wrapped in LMCacheKVEvents,
    returning None when the engine has nothing to report.
    """

    def test_returns_none_when_no_events(self, mock_connector):
        """Test that None is returned when lmcache engine has no events."""
        mock_connector._lmcache_engine.get_kv_events.return_value = None

        result = mock_connector.get_kv_connector_kv_cache_events()

        assert result is None
        # The engine must be polled exactly once per call.
        mock_connector._lmcache_engine.get_kv_events.assert_called_once()

    def test_returns_none_when_empty_list(self, mock_connector):
        """Test that None is returned when lmcache engine returns empty list."""
        mock_connector._lmcache_engine.get_kv_events.return_value = []

        result = mock_connector.get_kv_connector_kv_cache_events()

        assert result is None

    def test_converts_single_event(self, mock_connector, mock_lmcache_engine_event):
        """Test conversion of a single event from lmcache engine format."""
        mock_connector._lmcache_engine.get_kv_events.return_value = [
            mock_lmcache_engine_event
        ]

        result = mock_connector.get_kv_connector_kv_cache_events()

        assert result is not None
        assert isinstance(result, LMCacheKVEvents)
        assert result.get_number_of_workers() == 1

        events = result.get_all_events()
        assert len(events) == 1
        assert isinstance(events[0], BlockStored)
        # Every attribute of the fixture event must survive conversion 1:1.
        assert events[0].block_hashes == ["hash1", "hash2"]
        assert events[0].parent_block_hash == "parent_hash"
        assert events[0].token_ids == [1, 2, 3, 4]
        assert events[0].lora_id is None
        assert events[0].block_size == 16
        assert events[0].medium == "GPU"

    def test_converts_multiple_events(self, mock_connector):
        """Test conversion of multiple events from lmcache engine format."""

        class MockEvent:
            def __init__(self, i):
                self.block_hashes = [f"hash{i}"]
                self.parent_block_hash = f"parent{i}"
                self.token_ids = [i]
                self.lora_id = None
                self.block_size = 16
                self.medium = "GPU"

        events = [MockEvent(i) for i in range(5)]
        mock_connector._lmcache_engine.get_kv_events.return_value = events

        result = mock_connector.get_kv_connector_kv_cache_events()

        assert result is not None
        assert isinstance(result, LMCacheKVEvents)

        converted_events = result.get_all_events()
        assert len(converted_events) == 5

        # Conversion must preserve per-event data and ordering.
        for i, event in enumerate(converted_events):
            assert isinstance(event, BlockStored)
            assert event.block_hashes == [f"hash{i}"]
            assert event.parent_block_hash == f"parent{i}"
            assert event.token_ids == [i]

    def test_preserves_event_attributes(self, mock_connector):
        """Test that all event attributes are correctly preserved."""

        # Non-default values (lora_id set, DISK medium, 32-token blocks) to
        # catch fields that a lossy conversion would drop or reset.
        class MockEventWithLora:
            def __init__(self):
                self.block_hashes = ["hash_a", "hash_b", "hash_c"]
                self.parent_block_hash = "parent_xyz"
                self.token_ids = [100, 200, 300]
                self.lora_id = 42
                self.block_size = 32
                self.medium = "DISK"

        mock_connector._lmcache_engine.get_kv_events.return_value = [
            MockEventWithLora()
        ]

        result = mock_connector.get_kv_connector_kv_cache_events()

        events = result.get_all_events()
        event = events[0]

        assert event.block_hashes == ["hash_a", "hash_b", "hash_c"]
        assert event.parent_block_hash == "parent_xyz"
        assert event.token_ids == [100, 200, 300]
        assert event.lora_id == 42
        assert event.block_size == 32
        assert event.medium == "DISK"

    def test_handles_none_parent_block_hash(self, mock_connector):
        """Test handling of events with None parent_block_hash."""

        class MockEventNoParent:
            def __init__(self):
                self.block_hashes = ["hash1"]
                self.parent_block_hash = None
                self.token_ids = [1, 2]
                self.lora_id = None
                self.block_size = 16
                self.medium = "GPU"

        mock_connector._lmcache_engine.get_kv_events.return_value = [
            MockEventNoParent()
        ]

        result = mock_connector.get_kv_connector_kv_cache_events()

        events = result.get_all_events()
        assert events[0].parent_block_hash is None
|
||||
|
||||
|
||||
class TestUpdateConnectorOutput:
    """Test update_connector_output method.

    The method accumulates LMCacheKVEvents from successive worker outputs
    into connector._kv_cache_events: it ignores None / foreign event types,
    installs the first batch as-is, and merges later batches (events appended,
    worker counts summed).
    """

    def test_does_nothing_when_kv_cache_events_is_none(self, mock_connector):
        """Test that method returns early when kv_cache_events is None."""
        connector_output = KVConnectorOutput(kv_cache_events=None)

        mock_connector.update_connector_output(connector_output)

        assert mock_connector._kv_cache_events is None

    def test_does_nothing_when_kv_cache_events_is_not_lmcache_kv_events(
        self, mock_connector
    ):
        """Test that method returns early when kv_cache_events is not
        LMCacheKVEvents."""
        # Create a mock object that is not LMCacheKVEvents
        fake_events = MagicMock()
        connector_output = KVConnectorOutput(kv_cache_events=fake_events)

        mock_connector.update_connector_output(connector_output)

        assert mock_connector._kv_cache_events is None

    def test_sets_kv_cache_events_when_none(self, mock_connector):
        """Test that _kv_cache_events is set when it was None."""
        kv_events = LMCacheKVEvents(num_workers=1)
        event = BlockStored(
            block_hashes=["hash1"],
            parent_block_hash=None,
            token_ids=[1, 2],
            block_size=16,
            lora_id=None,
            medium="GPU",
        )
        kv_events.add_events([event])

        connector_output = KVConnectorOutput(kv_cache_events=kv_events)

        mock_connector.update_connector_output(connector_output)

        # First batch is adopted by identity, not copied.
        assert mock_connector._kv_cache_events is kv_events

    def test_adds_events_when_kv_cache_events_already_exists(self, mock_connector):
        """Test that events are added when _kv_cache_events already exists."""
        # Set up existing events
        existing_events = LMCacheKVEvents(num_workers=2)
        event1 = BlockStored(
            block_hashes=["hash1"],
            parent_block_hash=None,
            token_ids=[1],
            block_size=16,
            lora_id=None,
            medium="GPU",
        )
        existing_events.add_events([event1])
        existing_events.add_events([event1])  # Simulate 2 workers reporting

        mock_connector._kv_cache_events = existing_events

        # Create new events to add
        new_events = LMCacheKVEvents(num_workers=1)
        event2 = BlockStored(
            block_hashes=["hash2"],
            parent_block_hash=None,
            token_ids=[2],
            block_size=16,
            lora_id=None,
            medium="GPU",
        )
        new_events.add_events([event2])

        connector_output = KVConnectorOutput(kv_cache_events=new_events)

        mock_connector.update_connector_output(connector_output)

        # Check that events were added
        all_events = mock_connector._kv_cache_events.get_all_events()
        assert len(all_events) == 3  # 2 from existing + 1 from new
        assert event1 in all_events
        assert event2 in all_events

    def test_increments_workers_when_kv_cache_events_already_exists(
        self, mock_connector
    ):
        """Test that worker count is incremented correctly."""
        # Set up existing events with 2 workers
        existing_events = LMCacheKVEvents(num_workers=2)
        mock_connector._kv_cache_events = existing_events

        # Create new events from 3 workers
        new_events = LMCacheKVEvents(num_workers=3)
        event = BlockStored(
            block_hashes=["hash1"],
            parent_block_hash=None,
            token_ids=[1],
            block_size=16,
            lora_id=None,
            medium="GPU",
        )
        new_events.add_events([event])

        connector_output = KVConnectorOutput(kv_cache_events=new_events)

        mock_connector.update_connector_output(connector_output)

        # Worker count should be 2 + 3 = 5
        assert mock_connector._kv_cache_events.get_number_of_workers() == 5

    def test_multiple_updates(self, mock_connector):
        """Test multiple consecutive updates."""
        # First update
        events1 = LMCacheKVEvents(num_workers=1)
        event1 = BlockStored(
            block_hashes=["hash1"],
            parent_block_hash=None,
            token_ids=[1],
            block_size=16,
            lora_id=None,
            medium="GPU",
        )
        events1.add_events([event1])
        output1 = KVConnectorOutput(kv_cache_events=events1)
        mock_connector.update_connector_output(output1)

        # Second update
        events2 = LMCacheKVEvents(num_workers=2)
        event2 = BlockStored(
            block_hashes=["hash2"],
            parent_block_hash=None,
            token_ids=[2],
            block_size=16,
            lora_id=None,
            medium="GPU",
        )
        events2.add_events([event2])
        output2 = KVConnectorOutput(kv_cache_events=events2)
        mock_connector.update_connector_output(output2)

        # Third update
        events3 = LMCacheKVEvents(num_workers=1)
        event3 = BlockStored(
            block_hashes=["hash3"],
            parent_block_hash=None,
            token_ids=[3],
            block_size=16,
            lora_id=None,
            medium="GPU",
        )
        events3.add_events([event3])
        output3 = KVConnectorOutput(kv_cache_events=events3)
        mock_connector.update_connector_output(output3)

        # Check final state: events accumulate and worker counts sum.
        all_events = mock_connector._kv_cache_events.get_all_events()
        assert len(all_events) == 3
        assert mock_connector._kv_cache_events.get_number_of_workers() == 4  # 1+2+1

    def test_updates_with_empty_events(self, mock_connector):
        """Test updating with empty event lists."""
        # First update with actual events
        events1 = LMCacheKVEvents(num_workers=1)
        event1 = BlockStored(
            block_hashes=["hash1"],
            parent_block_hash=None,
            token_ids=[1],
            block_size=16,
            lora_id=None,
            medium="GPU",
        )
        events1.add_events([event1])
        output1 = KVConnectorOutput(kv_cache_events=events1)
        mock_connector.update_connector_output(output1)

        # Second update with empty events
        events2 = LMCacheKVEvents(num_workers=2)
        # No events added
        output2 = KVConnectorOutput(kv_cache_events=events2)
        mock_connector.update_connector_output(output2)

        # Should still have the original event; an empty batch contributes
        # only its worker count.
        all_events = mock_connector._kv_cache_events.get_all_events()
        assert len(all_events) == 1
        assert mock_connector._kv_cache_events.get_number_of_workers() == 3
|
||||
|
||||
|
||||
class TestTakeEvents:
    """Test take_events method.

    take_events drains the accumulated _kv_cache_events: it aggregates
    across workers (only events reported by every worker survive), yields
    the survivors, and resets _kv_cache_events to None.
    """

    def test_yields_nothing_when_kv_cache_events_is_none(self, mock_connector):
        """Test that nothing is yielded when _kv_cache_events is None."""
        mock_connector._kv_cache_events = None

        events = list(mock_connector.take_events())

        assert events == []

    def test_yields_events_and_clears(self, mock_connector):
        """Test that events are yielded and then cleared."""
        # Set up events
        kv_events = LMCacheKVEvents(num_workers=1)
        event1 = BlockStored(
            block_hashes=["hash1"],
            parent_block_hash=None,
            token_ids=[1],
            block_size=16,
            lora_id=None,
            medium="GPU",
        )
        event2 = BlockStored(
            block_hashes=["hash2"],
            parent_block_hash=None,
            token_ids=[2],
            block_size=16,
            lora_id=None,
            medium="GPU",
        )
        kv_events.add_events([event1, event2])
        mock_connector._kv_cache_events = kv_events

        # Take events
        events = list(mock_connector.take_events())

        # Check that events were yielded (single worker: nothing filtered)
        assert len(events) == 2
        assert event1 in events
        assert event2 in events

        # Check that _kv_cache_events was cleared
        assert mock_connector._kv_cache_events is None

    def test_aggregates_before_yielding(self, mock_connector):
        """Test that events are aggregated before yielding."""
        # Set up events from multiple workers
        kv_events = LMCacheKVEvents(num_workers=3)
        common_event = BlockStored(
            block_hashes=["hash_common"],
            parent_block_hash=None,
            token_ids=[1],
            block_size=16,
            lora_id=None,
            medium="GPU",
        )
        uncommon_event = BlockStored(
            block_hashes=["hash_uncommon"],
            parent_block_hash=None,
            token_ids=[2],
            block_size=16,
            lora_id=None,
            medium="GPU",
        )

        # All 3 workers report common_event
        kv_events.add_events([common_event])
        kv_events.add_events([common_event])
        kv_events.add_events([common_event])

        # Only 1 worker reports uncommon_event
        kv_events.add_events([uncommon_event])

        mock_connector._kv_cache_events = kv_events

        # Take events
        events = list(mock_connector.take_events())

        # Only the common event should be yielded: with num_workers=3 an
        # event must be reported by all three workers to survive aggregation.
        assert len(events) == 1
        assert events[0] == common_event

    def test_multiple_take_events_calls(self, mock_connector):
        """Test calling take_events multiple times."""
        # First call with events
        kv_events1 = LMCacheKVEvents(num_workers=1)
        event1 = BlockStored(
            block_hashes=["hash1"],
            parent_block_hash=None,
            token_ids=[1],
            block_size=16,
            lora_id=None,
            medium="GPU",
        )
        kv_events1.add_events([event1])
        mock_connector._kv_cache_events = kv_events1

        events1 = list(mock_connector.take_events())
        assert len(events1) == 1
        assert events1[0] == event1
        assert mock_connector._kv_cache_events is None

        # Second call with no events: draining twice must be harmless.
        events2 = list(mock_connector.take_events())
        assert events2 == []

        # Third call after adding new events
        kv_events2 = LMCacheKVEvents(num_workers=1)
        event2 = BlockStored(
            block_hashes=["hash2"],
            parent_block_hash=None,
            token_ids=[2],
            block_size=16,
            lora_id=None,
            medium="GPU",
        )
        kv_events2.add_events([event2])
        mock_connector._kv_cache_events = kv_events2

        events3 = list(mock_connector.take_events())
        assert len(events3) == 1
        assert events3[0] == event2

    def test_yields_empty_after_aggregation_removes_all(self, mock_connector):
        """Test that nothing is yielded if aggregation removes all events."""
        # Set up events from 2 workers with no common events
        kv_events = LMCacheKVEvents(num_workers=2)
        event1 = BlockStored(
            block_hashes=["hash1"],
            parent_block_hash=None,
            token_ids=[1],
            block_size=16,
            lora_id=None,
            medium="GPU",
        )
        event2 = BlockStored(
            block_hashes=["hash2"],
            parent_block_hash=None,
            token_ids=[2],
            block_size=16,
            lora_id=None,
            medium="GPU",
        )

        # Worker 1 reports event1
        kv_events.add_events([event1])
        # Worker 2 reports event2
        kv_events.add_events([event2])

        mock_connector._kv_cache_events = kv_events

        # Take events
        events = list(mock_connector.take_events())

        # No common events, so nothing should be yielded — but the state
        # must still be cleared.
        assert events == []
        assert mock_connector._kv_cache_events is None
|
||||
|
||||
|
||||
class TestIntegrationScenarios:
    """Test integration scenarios.

    These tests chain the three methods exercised individually above
    (get_kv_connector_kv_cache_events -> update_connector_output ->
    take_events), plus one end-to-end run through KVOutputAggregator.
    """

    def test_full_workflow(self, mock_connector, mock_lmcache_engine_event):
        """Test a complete workflow from getting events to taking them."""
        # Step 1: Get events from lmcache engine
        mock_connector._lmcache_engine.get_kv_events.return_value = [
            mock_lmcache_engine_event
        ]
        kv_events = mock_connector.get_kv_connector_kv_cache_events()

        assert kv_events is not None
        assert len(kv_events.get_all_events()) == 1

        # Step 2: Update connector output (simulate receiving from worker)
        output1 = KVConnectorOutput(kv_cache_events=kv_events)
        mock_connector.update_connector_output(output1)

        assert mock_connector._kv_cache_events is not None

        # Step 3: Take events
        taken_events = list(mock_connector.take_events())

        assert len(taken_events) == 1
        assert mock_connector._kv_cache_events is None

    def test_multiple_workers_workflow(self, mock_connector):
        """Test workflow with multiple workers."""

        class MockEvent:
            def __init__(self, hash_val):
                self.block_hashes = [hash_val]
                self.parent_block_hash = None
                self.token_ids = [1]
                self.lora_id = None
                self.block_size = 16
                self.medium = "GPU"

        # Worker 1
        mock_connector._lmcache_engine.get_kv_events.return_value = [
            MockEvent("hash_common"),
            MockEvent("hash_worker1"),
        ]
        kv_events1 = mock_connector.get_kv_connector_kv_cache_events()
        output1 = KVConnectorOutput(kv_cache_events=kv_events1)
        mock_connector.update_connector_output(output1)

        # Worker 2
        mock_connector._lmcache_engine.get_kv_events.return_value = [
            MockEvent("hash_common"),
            MockEvent("hash_worker2"),
        ]
        kv_events2 = mock_connector.get_kv_connector_kv_cache_events()
        output2 = KVConnectorOutput(kv_cache_events=kv_events2)
        mock_connector.update_connector_output(output2)

        # Take events (should only get common events)
        taken_events = list(mock_connector.take_events())

        # With aggregation, only events reported by both workers should be present
        # In this case, hash_common was reported by both
        event_hashes = [e.block_hashes[0] for e in taken_events]
        assert "hash_common" in event_hashes

    def test_empty_workflow(self, mock_connector):
        """Test workflow when there are no events at any stage."""
        # Get events returns None
        mock_connector._lmcache_engine.get_kv_events.return_value = None
        kv_events = mock_connector.get_kv_connector_kv_cache_events()

        assert kv_events is None

        # Update with None
        output = KVConnectorOutput(kv_cache_events=None)
        mock_connector.update_connector_output(output)

        # Take events
        taken_events = list(mock_connector.take_events())

        assert taken_events == []
        assert mock_connector._kv_cache_events is None

    def test_repeated_cycles(self, mock_connector):
        """Test multiple cycles of the complete workflow."""

        class MockEvent:
            def __init__(self, cycle_num):
                self.block_hashes = [f"hash_cycle_{cycle_num}"]
                self.parent_block_hash = None
                self.token_ids = [cycle_num]
                self.lora_id = None
                self.block_size = 16
                self.medium = "GPU"

        # Each cycle must be independent: state cleared by take_events must
        # not leak into the next cycle.
        for cycle in range(3):
            # Get events
            mock_connector._lmcache_engine.get_kv_events.return_value = [
                MockEvent(cycle)
            ]
            kv_events = mock_connector.get_kv_connector_kv_cache_events()

            # Update
            output = KVConnectorOutput(kv_cache_events=kv_events)
            mock_connector.update_connector_output(output)

            # Take
            taken_events = list(mock_connector.take_events())

            # Verify
            assert len(taken_events) == 1
            assert taken_events[0].block_hashes[0] == f"hash_cycle_{cycle}"
            assert mock_connector._kv_cache_events is None

    def test_lmcache_kv_events_aggregation(self):
        """
        Test LMCacheKVEvents aggregation across TP ranks using
        KVOutputAggregator (used by MultiprocExecutor).
        """
        from vllm.distributed.kv_transfer.kv_connector.utils import KVOutputAggregator
        from vllm.v1.outputs import ModelRunnerOutput

        # Create KVOutputAggregator for 3 workers (simulating TP=3)
        aggregator = KVOutputAggregator(expected_finished_count=3)

        # Define common and unique events
        common_event = BlockStored(
            block_hashes=["hash_common"],
            parent_block_hash="parent_common",
            token_ids=[1, 2, 3],
            block_size=16,
            lora_id=None,
            medium="GPU",
        )

        worker1_unique_event = BlockStored(
            block_hashes=["hash_worker1"],
            parent_block_hash="parent_w1",
            token_ids=[4, 5],
            block_size=16,
            lora_id=None,
            medium="GPU",
        )

        worker2_unique_event = BlockStored(
            block_hashes=["hash_worker2"],
            parent_block_hash="parent_w2",
            token_ids=[6, 7],
            block_size=16,
            lora_id=None,
            medium="GPU",
        )

        worker3_unique_event = BlockStored(
            block_hashes=["hash_worker3"],
            parent_block_hash="parent_w3",
            token_ids=[8, 9],
            block_size=16,
            lora_id=None,
            medium="GPU",
        )

        # Create events for each worker
        # Worker 0: reports common event and its unique event
        worker0_events = LMCacheKVEvents(num_workers=1)
        worker0_events.add_events([common_event, worker1_unique_event])

        # Worker 1: reports common event and its unique event
        worker1_events = LMCacheKVEvents(num_workers=1)
        worker1_events.add_events([common_event, worker2_unique_event])

        # Worker 2: reports common event and its unique event
        worker2_events = LMCacheKVEvents(num_workers=1)
        worker2_events.add_events([common_event, worker3_unique_event])

        # Create ModelRunnerOutput instances for each worker
        worker_outputs = []
        for i, worker_events in enumerate(
            [worker0_events, worker1_events, worker2_events]
        ):
            output = ModelRunnerOutput(
                req_ids=[f"req_{i}"],
                req_id_to_index={f"req_{i}": 0},
                sampled_token_ids=[[123]],  # dummy token
                logprobs=None,
                prompt_logprobs_dict={},
                pooler_output=[None],
                kv_connector_output=KVConnectorOutput(
                    finished_sending=set([f"req_{i}_send"])
                    if i < 2
                    else None,  # Workers 0,1 finished sending
                    finished_recving=set([f"req_{i}_recv"])
                    if i > 0
                    else None,  # Workers 1,2 finished receiving
                    kv_cache_events=worker_events,
                ),
            )
            worker_outputs.append(output)

        # Use the real aggregation mechanism (like MultiprocExecutor.execute_model)
        aggregated_output = aggregator.aggregate(worker_outputs, output_rank=0)
        kv_cache_events = aggregated_output.kv_connector_output.kv_cache_events

        assert isinstance(kv_cache_events, LMCacheKVEvents)

        # After aggregation, events should be combined from all workers
        # The aggregator doesn't automatically aggregate events, so we need to call
        # aggregate() to get only common events
        kv_cache_events.aggregate()
        aggregated_events = kv_cache_events.get_all_events()

        # Only the common event should remain after aggregation
        # because it's the only event reported by all 3 workers
        assert len(aggregated_events) == 1
        assert aggregated_events[0] == common_event

        # Verify the common event properties
        assert aggregated_events[0].block_hashes == ["hash_common"]
        assert aggregated_events[0].parent_block_hash == "parent_common"
        assert aggregated_events[0].token_ids == [1, 2, 3]
|
||||
228
tests/v1/kv_connector/unit/test_lmcache_integration.py
Normal file
228
tests/v1/kv_connector/unit/test_lmcache_integration.py
Normal file
@@ -0,0 +1,228 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
# NOTE: if your PR has broken one of the tests here (sorry),
|
||||
# kindly patch the corresponding integration in
|
||||
# /vllm/distributed/kv_transfer/kv_connector/v1/lmcache_integration/vllm_v1_adapter.py
|
||||
# or reach out to @aposataC for assistance
|
||||
|
||||
# Assumption vs. Correctness Tests:
|
||||
# these unit tests do *not* test correctness of LMCache-side or vLLM-side logic
|
||||
# it is to ensure that assumptions LMCache makes about vLLM's interface are stable
|
||||
|
||||
import pytest
|
||||
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
|
||||
def assumes(obj, attr, is_callable=False, is_instance_of=None):
    """Assert that *obj* exposes *attr*, optionally with extra guarantees.

    Resolution order: a plain attribute lookup first; failing that, a
    dataclass field declaration (whose annotated type stands in for the
    value). Raises AssertionError when the attribute is missing, when
    ``is_callable`` is set but the value is not callable, or when
    ``is_instance_of`` is set and the value (or a property's return
    annotation, or a class-valued attribute) does not match.
    """
    import inspect
    from dataclasses import is_dataclass

    msg = f"LMCache connector currently assumes that {obj} has a(n) {attr} attribute"

    if hasattr(obj, attr):
        value = getattr(obj, attr)
    else:
        # Fields without defaults are not class attributes, so fall back to
        # the dataclass field declaration and use its annotated type.
        dc_fields = getattr(obj, "__dataclass_fields__", {})
        if not (is_dataclass(obj) and attr in dc_fields):
            raise AssertionError(msg)
        declared = dc_fields[attr].type
        # Generic aliases (e.g. list[int]) collapse to their origin type.
        origin = getattr(declared, "__origin__", None)
        value = declared if origin is None else origin

    if is_callable:
        msg += f" and that {obj}.{attr} is a callable"
        assert callable(value), msg

    if is_instance_of:
        msg += f" and that {obj}.{attr} is an instance of {is_instance_of}"
        if isinstance(value, property):
            # For a property, check the getter's declared return type rather
            # than invoking it.
            getter = value.fget
            assert getter is not None, f"Property {obj}.{attr} has no fget"
            annotation = inspect.signature(getter).return_annotation
            assert annotation is not inspect.Signature.empty, (
                f"Property {obj}.{attr} has no return annotation"
            )
            assert annotation == is_instance_of, msg
        elif isinstance(value, type):
            # Class-valued attribute (e.g. dataclass field type): require
            # identity with the expected class.
            assert value is is_instance_of, msg
        else:
            assert isinstance(value, is_instance_of), msg
|
||||
|
||||
|
||||
@pytest.mark.skipif(
    current_platform.is_rocm(), reason="Requires libcudart.so, not available on ROCm"
)
def test_multimodal_interface():
    """Guard the PlaceholderRange attributes the LMCache adapter relies on."""
    # protect against interface changes
    from vllm.multimodal.inputs import PlaceholderRange

    for expected_attr in ("offset", "length"):
        assumes(PlaceholderRange, expected_attr)
|
||||
|
||||
|
||||
@pytest.mark.skipif(
    current_platform.is_rocm(), reason="Requires libcudart.so, not available on ROCm"
)
def test_config_interface():
    """Guard the config attributes/methods the LMCache connector depends on.

    Protects against interface changes in VllmConfig and its sub-configs,
    then exercises a minimal KV-metadata construction path end to end.
    """
    from vllm.config import VllmConfig
    from vllm.config.cache import CacheConfig
    from vllm.config.kv_transfer import KVTransferConfig
    from vllm.config.model import ModelConfig
    from vllm.config.parallel import ParallelConfig

    # Top-level config object and the sub-configs reached through it.
    assumes(VllmConfig, "model_config")
    assumes(VllmConfig, "cache_config")
    assumes(VllmConfig, "parallel_config")
    assumes(VllmConfig, "kv_transfer_config")

    assumes(KVTransferConfig, "kv_role")
    assumes(KVTransferConfig, "kv_connector_extra_config")

    # NOTE: the original list checked get_num_kv_heads twice; the duplicate
    # assertion was removed.
    assumes(ModelConfig, "use_mla", is_instance_of=bool)
    assumes(ModelConfig, "dtype")
    assumes(ModelConfig, "max_model_len")
    assumes(ModelConfig, "get_vocab_size", is_callable=True)
    assumes(ModelConfig, "get_num_attention_heads", is_callable=True)
    assumes(ModelConfig, "get_num_kv_heads", is_callable=True)
    assumes(ModelConfig, "get_head_size", is_callable=True)
    assumes(ModelConfig, "get_num_layers", is_callable=True)
    assumes(ModelConfig, "model")

    assumes(ParallelConfig, "world_size")
    assumes(ParallelConfig, "rank")
    assumes(ParallelConfig, "tensor_parallel_size")
    assumes(ParallelConfig, "pipeline_parallel_size")
    assumes(ParallelConfig, "data_parallel_size_local")
    assumes(ParallelConfig, "data_parallel_rank_local")

    assumes(CacheConfig, "cache_dtype")
    assumes(CacheConfig, "block_size")
    assumes(CacheConfig, "gpu_memory_utilization")

    # kv metadata minimal case
    from vllm.utils.torch_utils import get_kv_cache_torch_dtype

    model_config = ModelConfig(dtype="bfloat16")
    parallel_config = ParallelConfig()
    cache_config = CacheConfig(cache_dtype="bfloat16")
    kv_dtype = get_kv_cache_torch_dtype(cache_config.cache_dtype, model_config.dtype)
    use_mla = False
    chunk_size = 256
    num_layer = model_config.get_num_layers(parallel_config)
    num_kv_head = model_config.get_num_kv_heads(parallel_config)
    head_size = model_config.get_head_size()
    # MLA stores a single latent cache per layer; otherwise K and V are separate.
    kv_shape = (num_layer, 1 if use_mla else 2, chunk_size, num_kv_head, head_size)

    # dummy lmcache metadata creation example
    _ = (
        model_config.model,
        parallel_config.world_size,
        parallel_config.rank,
        "vllm",
        kv_dtype,
        kv_shape,
        use_mla,
    )
|
||||
|
||||
|
||||
@pytest.mark.skipif(
    current_platform.is_rocm(), reason="Requires libcudart.so, not available on ROCm"
)
def test_request_interface():
    # protect against interface changes
    from types import NoneType

    from vllm.sampling_params import SamplingParams
    from vllm.v1.request import Request

    # Build a minimal real Request so instance attributes (not just class
    # annotations) are checked below.
    req = Request(
        request_id="test_request",
        prompt_token_ids=[1, 2, 3],
        sampling_params=SamplingParams(max_tokens=10),
        pooling_params=None,
        eos_token_id=100,
        lora_request=None,
    )
    # Attributes the connector reads off a scheduled request.
    assumes(req, "mm_features", is_instance_of=(list, NoneType))
    assumes(req, "request_id")
    assumes(req, "priority")
    assumes(req, "prompt_token_ids")
    assumes(req, "sampling_params")
    assumes(req, "num_tokens")
    assumes(req, "kv_transfer_params", is_instance_of=(dict, NoneType))

    from vllm.multimodal.inputs import MultiModalFeatureSpec

    # Fields read from each multimodal feature attached to a request.
    assumes(MultiModalFeatureSpec, "identifier")
    assumes(MultiModalFeatureSpec, "mm_position")
|
||||
|
||||
|
||||
def test_new_request_interface():
    """Guard the NewRequestData fields consumed by the connector scheduler."""
    from vllm.v1.core.sched.output import NewRequestData

    for required_attr in (
        "req_id",
        "block_ids",
        "prompt_token_ids",
        "sampling_params",
    ):
        assumes(NewRequestData, required_attr)
|
||||
assumes(NewRequestData, "sampling_params")
|
||||
|
||||
|
||||
def test_sampling_params_interface():
    """Guard the SamplingParams surface used to carry per-request KV params."""
    from vllm.sampling_params import SamplingParams

    assumes(SamplingParams, "extra_args")

    # Round-trip the same example payload LMCache would attach.
    expected_params = {
        "lmcache.tag.user": "example_user_1",
        "lmcache.ttl": 60,
    }
    params = SamplingParams(extra_args={"kv_transfer_params": expected_params})
    assert params.extra_args["kv_transfer_params"] == expected_params
|
||||
|
||||
|
||||
def test_tp_interface():
    """Guard the GroupCoordinator methods invoked over the TP group."""
    import inspect

    from vllm.distributed.parallel_state import get_tp_group

    # Recover the coordinator type from the accessor's return annotation
    # instead of importing it directly.
    coordinator_cls = inspect.signature(get_tp_group).return_annotation

    for method_name in ("broadcast", "broadcast_object"):
        assumes(coordinator_cls, method_name, is_callable=True)
|
||||
|
||||
|
||||
def test_forward_context_interface():
    """Guard the ForwardContext fields read during KV save/load hooks."""
    from vllm.forward_context import ForwardContext

    assumes(ForwardContext, "no_compile_layers", is_instance_of=dict)
    for plain_attr in ("virtual_engine", "attn_metadata"):
        assumes(ForwardContext, plain_attr)
|
||||
|
||||
|
||||
def test_scheduler_output_interface():
    """Guard SchedulerOutput / CachedRequestData fields the connector reads."""
    from vllm.v1.core.sched.output import CachedRequestData, SchedulerOutput

    assumes(SchedulerOutput, "finished_req_ids")
    assumes(SchedulerOutput, "scheduled_new_reqs", is_instance_of=list)
    assumes(SchedulerOutput, "num_scheduled_tokens", is_instance_of=dict)
    assumes(SchedulerOutput, "scheduled_cached_reqs")

    assumes(CachedRequestData, "req_ids", is_instance_of=list)
    assumes(CachedRequestData, "new_block_ids", is_instance_of=list)
|
||||
603
tests/v1/kv_connector/unit/test_multi_connector.py
Normal file
603
tests/v1/kv_connector/unit/test_multi_connector.py
Normal file
@@ -0,0 +1,603 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import filecmp
|
||||
import shutil
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import pytest
|
||||
|
||||
from vllm import LLM, SamplingParams
|
||||
from vllm.config import KVTransferConfig
|
||||
from vllm.distributed.kv_transfer.kv_connector.factory import KVConnectorFactory
|
||||
from vllm.distributed.kv_transfer.kv_connector.v1.base import KVConnectorBase_V1
|
||||
from vllm.distributed.kv_transfer.kv_connector.v1.metrics import KVConnectorStats
|
||||
from vllm.distributed.kv_transfer.kv_connector.v1.multi_connector import (
|
||||
MultiConnector,
|
||||
MultiKVConnectorStats,
|
||||
)
|
||||
from vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector import (
|
||||
NixlKVConnectorStats,
|
||||
)
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
# Small instruct model keeps these multi-connector runs cheap.
MODEL_NAME = "meta-llama/Llama-3.2-1B-Instruct"

# Long shared prefix prepended to both prompts; the event assertions below
# expect this to match as 96 tokens / 7 blocks on the reload pass.
PROMPT_CONTEXT = "Hi " * 100
PROMPTS = [
    PROMPT_CONTEXT + "Hello, my name is",
    PROMPT_CONTEXT + "The capital of France is",
]

# Greedy decoding (temperature=0) so repeated generations are comparable.
SAMPLING_PARAMS = SamplingParams(temperature=0, max_tokens=20)
|
||||
|
||||
|
||||
# Test connector with custom stats for testing MultiConnector
class MockConnectorStats(KVConnectorStats):
    """Mock stats class for testing."""

    # No extra behavior needed: the distinct subclass type is what the
    # MultiConnector stats-reconstruction tests assert on.
    pass
|
||||
|
||||
|
||||
class MockConnector(KVConnectorBase_V1):
    """Mock connector that implements build_kv_connector_stats for testing."""

    @classmethod
    def build_kv_connector_stats(
        cls, data: dict[str, Any] | None = None
    ) -> KVConnectorStats | None:
        # Mirror real connectors: no payload means no stats object at all.
        if data is None:
            return None
        return MockConnectorStats(data=data)
|
||||
|
||||
|
||||
# Register the mock connector
# Registration happens at import time so KVConnectorFactory can resolve
# "MockConnector" when MultiConnector reconstructs stats in these tests.
KVConnectorFactory.register_connector("MockConnector", __name__, MockConnector.__name__)
|
||||
|
||||
|
||||
# Helper function to compare directories recursively
|
||||
def _compare_directories(dir1: Path, dir2: Path) -> bool:
|
||||
"""Compares two directories recursively for identical content."""
|
||||
dcmp = filecmp.dircmp(dir1, dir2)
|
||||
if dcmp.left_only or dcmp.right_only or dcmp.diff_files:
|
||||
print(f"Differences found between {dir1} and {dir2}:")
|
||||
print(f" Left only: {dcmp.left_only}")
|
||||
print(f" Right only: {dcmp.right_only}")
|
||||
print(f" Different files: {dcmp.diff_files}")
|
||||
return False
|
||||
for sub_dir in dcmp.common_dirs:
|
||||
if not _compare_directories(dir1 / sub_dir, dir2 / sub_dir):
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
@pytest.mark.skipif(
    current_platform.is_rocm(),
    reason=(
        "hipErrorLaunchFailure when running this test, see issue:"
        "https://github.com/ROCm/pytorch/issues/2822"
    ),
)
def test_multi_example_connector_consistency():
    """
    Tests that MultiConnector with two ExampleConnectors saves
    identical KV cache data to separate storage locations.

    Also checks the per-connector scheduler/worker event ordering across
    three phases: initial save, reload served by the first connector, and
    reload served by the second connector after the first connector's
    storage has been deleted.
    """
    storage_1_path = Path("storage_1/")
    storage_2_path = Path("storage_2/")
    # Start from a clean slate in case a previous (failed) run left state.
    shutil.rmtree(storage_1_path, ignore_errors=True)
    shutil.rmtree(storage_2_path, ignore_errors=True)
    storage_1_path.mkdir()
    storage_2_path.mkdir()

    # Configure MultiConnector with two ExampleConnectors
    kv_transfer_config = KVTransferConfig(
        kv_connector="MultiConnector",
        kv_role="kv_both",
        kv_connector_extra_config={
            "connectors": [
                {
                    "kv_connector": "TestExampleConnector",
                    "kv_role": "kv_both",
                    "kv_connector_extra_config": {
                        "shared_storage_path": str(storage_1_path),
                        "name": "storage1",
                    },
                    "kv_connector_module_path": "tests.v1.kv_connector.unit.utils",
                },
                {
                    "kv_connector": "TestExampleConnector",
                    "kv_role": "kv_both",
                    "kv_connector_extra_config": {
                        "shared_storage_path": str(storage_2_path),
                        "name": "storage2",
                    },
                    "kv_connector_module_path": "tests.v1.kv_connector.unit.utils",
                },
            ]
        },
    )

    llm = LLM(
        model=MODEL_NAME,
        enforce_eager=True,
        gpu_memory_utilization=0.5,
        kv_transfer_config=kv_transfer_config,
    )
    # Run generation - this should trigger saving KV cache
    _ = llm.generate(PROMPTS, SAMPLING_PARAMS)

    # --- Verification ---

    # Check that both storage directories were populated
    local_subdirs = list(storage_1_path.iterdir())
    external_subdirs = list(storage_2_path.iterdir())

    assert len(local_subdirs) > 0, (
        f"Local storage path {storage_1_path} is empty after generation."
    )
    assert len(external_subdirs) > 0, (
        f"External storage path {storage_2_path} is empty after generation."
    )
    assert len(local_subdirs) == len(external_subdirs), (
        f"Mismatch in number of cache entries: "
        f"Local={len(local_subdirs)}, External={len(external_subdirs)}"
    )

    # The subdirectories should correspond to the prompt hashes
    # Since prompts are the same, the hash directories should be the same name
    local_subdir_names = sorted([d.name for d in local_subdirs])
    external_subdir_names = sorted([d.name for d in external_subdirs])
    assert local_subdir_names == external_subdir_names, (
        "Cache directory names do not match between local and external storage"
    )

    # Compare the contents of each corresponding cache directory
    for subdir_name in local_subdir_names:
        print(f"Comparing contents of cache directory: {subdir_name}")
        assert _compare_directories(
            storage_1_path / subdir_name, storage_2_path / subdir_name
        ), (
            f"Contents differ for cache directory '{subdir_name}' between "
            f"{storage_1_path} and {storage_2_path}"
        )

    events = get_connector_events()
    # get_num_new_matched_tokens and update_state_after_alloc will be called
    # on each connector in turn.
    assert events["storage1-SCHEDULER"][:3] == [
        "get_num_new_matched_tokens 0",
        "update_state_after_alloc num_blocks=[0] 0",
        "build_connector_meta",
    ]
    assert events["storage1-WORKER"][:5] == [
        "register_kv_caches",
        "bind_connector_metadata",
        "start_load_kv",
        "wait_for_layer_load",
        "save_kv_layer",
    ]
    assert events["storage2-SCHEDULER"][:3] == [
        "get_num_new_matched_tokens 0",
        "update_state_after_alloc num_blocks=[0] 0",
        "build_connector_meta",
    ]
    assert events["storage2-WORKER"][:5] == [
        "register_kv_caches",
        "bind_connector_metadata",
        "start_load_kv",
        "wait_for_layer_load",
        "save_kv_layer",
    ]

    # Reset prefix cache or else we'll just get the tokens back from there.
    llm.reset_prefix_cache()

    # Run generation again - this should trigger loading from the first
    # connector.
    _ = llm.generate(PROMPTS, SAMPLING_PARAMS)

    events = get_connector_events()
    # get_num_new_matched_tokens will return new tokens from the first
    # connector so update_state_after_alloc will be with allocated blocks
    # on that one but with zero blocks for others (first nonzero match is
    # chosen).
    assert events["storage1-SCHEDULER"][:3] == [
        "get_num_new_matched_tokens 0",
        "update_state_after_alloc num_blocks=[7] 96",
        "build_connector_meta",
    ]
    assert events["storage2-SCHEDULER"][:3] == [
        "get_num_new_matched_tokens 0",
        "update_state_after_alloc num_blocks=[0] 0",
        "build_connector_meta",
    ]

    # Delete storage1 connector state
    shutil.rmtree(storage_1_path)

    # Reset prefix cache or else we'll just get the tokens back from there.
    llm.reset_prefix_cache()

    # Run generation again - with storage1 gone this should trigger loading
    # from the second connector.
    _ = llm.generate(PROMPTS, SAMPLING_PARAMS)

    events = get_connector_events()
    # get_num_new_matched_tokens will be called for both connectors but will
    # return 0 from the first connector, but the second connector should have
    # a hit, so update_state_after_alloc will only be called with allocated
    # blocks for the second connector.
    assert events["storage1-SCHEDULER"][:3] == [
        "get_num_new_matched_tokens 0",
        "update_state_after_alloc num_blocks=[0] 0",
        "build_connector_meta",
    ]
    assert events["storage2-SCHEDULER"][:3] == [
        "get_num_new_matched_tokens 0",
        "update_state_after_alloc num_blocks=[7] 96",
        "build_connector_meta",
    ]

    # Clean up. storage_1 was deleted mid-test and is only present again if
    # the final generation re-saved to it, so tolerate a missing directory.
    shutil.rmtree(storage_1_path, ignore_errors=True)
    shutil.rmtree(storage_2_path, ignore_errors=True)
|
||||
|
||||
|
||||
def get_connector_events() -> dict[str, list[str]]:
    """Collect per-connector event logs from tempdir, truncating each file.

    Returns a mapping of connector name -> list of non-blank, stripped
    event lines recorded since the previous call.
    """
    import glob

    pattern = tempfile.gettempdir() + "/connector_*_events.log"
    events_by_name: dict[str, list[str]] = {}
    for path in glob.glob(pattern):
        # Recover "<name>" from ".../connector_<name>_events.log".
        name = path.split("connector_")[1].split("_events.log")[0]
        try:
            with open(path, "r+") as log_file:
                recorded = []
                for raw_line in log_file:
                    stripped = raw_line.strip()
                    if stripped:
                        recorded.append(stripped)
                events_by_name[name] = recorded
                # Reset the log so the next call only sees new events.
                log_file.truncate(0)
        except Exception as e:
            print(f"[ERROR] Could not read connector events for {name}: {e}")

    return events_by_name
|
||||
|
||||
|
||||
def test_engine_id_conflict():
    """Two independently constructed configs must get distinct engine IDs."""
    first = KVTransferConfig()
    second = KVTransferConfig()
    ids = [first.engine_id, second.engine_id]
    assert ids[0] != ids[1], (
        f"Engine IDs should be different for different configs. Got {ids}"
    )
|
||||
|
||||
|
||||
class TestMultiConnectorStats:
    """Tests for MultiConnector stats reconstruction and operations."""

    @staticmethod
    def _nixl_data(
        transfer_duration: list[float],
        post_duration: list[float],
        bytes_transferred: list[int],
        num_descriptors: list[int],
    ) -> dict[str, Any]:
        """Build a complete NixlKVConnectorStats data dict.

        Centralizes the six-key payload that was previously duplicated in
        every test; the failure-count lists are always empty here.
        """
        return {
            "transfer_duration": transfer_duration,
            "post_duration": post_duration,
            "bytes_transferred": bytes_transferred,
            "num_descriptors": num_descriptors,
            "num_failed_transfers": [],
            "num_failed_notifications": [],
        }

    def test_build_kv_connector_stats_with_none(self):
        """Test that build_kv_connector_stats returns empty stats when given None."""
        stats = MultiConnector.build_kv_connector_stats(data=None)

        assert stats is not None
        assert isinstance(stats, MultiKVConnectorStats)
        assert len(stats.data) == 0
        assert stats.is_empty()

    def test_build_kv_connector_stats_with_empty_dict(self):
        """Test that build_kv_connector_stats returns empty stats with empty dict."""
        stats = MultiConnector.build_kv_connector_stats(data={})

        assert stats is not None
        assert isinstance(stats, MultiKVConnectorStats)
        assert len(stats.data) == 0
        assert stats.is_empty()

    def test_build_kv_connector_stats_reconstructs_nixl_stats(self):
        """Test that NixlConnector stats are properly reconstructed with
        correct data."""
        serialized_data = {
            "NixlConnector": {
                "data": self._nixl_data([1.5, 2.3], [0.1, 0.2], [1024, 2048], [10, 20])
            }
        }

        stats = MultiConnector.build_kv_connector_stats(data=serialized_data)

        assert "NixlConnector" in stats.data
        nixl_stats = stats.data["NixlConnector"]
        assert isinstance(nixl_stats, NixlKVConnectorStats)
        assert nixl_stats.data["transfer_duration"] == [1.5, 2.3]
        assert nixl_stats.data["post_duration"] == [0.1, 0.2]
        assert nixl_stats.data["bytes_transferred"] == [1024, 2048]
        assert nixl_stats.data["num_descriptors"] == [10, 20]

    def test_build_kv_connector_stats_with_multiple_connectors(self):
        """Test reconstruction with multiple connector types that have custom stats."""
        serialized_data = {
            "NixlConnector": {"data": self._nixl_data([1.5], [0.1], [1024], [10])},
            "MockConnector": {"data": {"mock_field": [1, 2, 3]}},
        }

        stats = MultiConnector.build_kv_connector_stats(data=serialized_data)

        assert stats is not None
        assert isinstance(stats, MultiKVConnectorStats)
        # Both connectors should be reconstructed
        assert len(stats.data) == 2
        assert "NixlConnector" in stats.data
        assert "MockConnector" in stats.data
        assert isinstance(stats.data["NixlConnector"], NixlKVConnectorStats)
        assert isinstance(stats.data["MockConnector"], MockConnectorStats)
        # Verify data is preserved
        assert stats.data["MockConnector"].data == {"mock_field": [1, 2, 3]}

    def test_build_kv_connector_stats_raises_error_for_unknown_connector(self):
        """Test that unknown connectors raise an error."""
        serialized_data = {
            "UnknownConnector": {"data": {"some_field": [1, 2, 3]}},
            "NixlConnector": {"data": self._nixl_data([1.5], [0.1], [1024], [10])},
        }

        with pytest.raises(
            ValueError, match="Connector 'UnknownConnector' is not registered."
        ):
            MultiConnector.build_kv_connector_stats(data=serialized_data)

    def test_build_kv_connector_stats_with_already_instantiated_objects(self):
        """Test that already-instantiated stats objects are preserved (same process)."""
        # This simulates the in-process case where stats are not serialized
        nixl_stats = NixlKVConnectorStats(
            data=self._nixl_data([1.5], [0.1], [1024], [10])
        )
        mock_stats = MockConnectorStats(data={"mock_field": [1, 2, 3]})

        data_with_objects = {
            "NixlConnector": nixl_stats,
            "MockConnector": mock_stats,
        }

        stats = MultiConnector.build_kv_connector_stats(data=data_with_objects)

        assert stats is not None
        assert isinstance(stats, MultiKVConnectorStats)
        assert len(stats.data) == 2
        # Verify objects are preserved as-is
        assert stats.data["NixlConnector"] is nixl_stats
        assert stats.data["MockConnector"] is mock_stats

    def test_build_kv_connector_stats_with_mixed_objects_and_dicts(self):
        """Test handling mixed already-instantiated and serialized stats."""
        # This can happen during transition or partial serialization
        nixl_stats = NixlKVConnectorStats(
            data=self._nixl_data([1.5], [0.1], [1024], [10])
        )

        mixed_data = {
            "NixlConnector": nixl_stats,  # Already instantiated
            "MockConnector": {"data": {"mock_field": [1, 2, 3]}},  # Serialized
        }

        stats = MultiConnector.build_kv_connector_stats(data=mixed_data)

        assert stats is not None
        assert isinstance(stats, MultiKVConnectorStats)
        assert len(stats.data) == 2
        # Instantiated object preserved
        assert stats.data["NixlConnector"] is nixl_stats
        # Serialized object reconstructed
        assert isinstance(stats.data["MockConnector"], MockConnectorStats)
        assert stats.data["MockConnector"].data == {"mock_field": [1, 2, 3]}

    def test_build_kv_connector_stats_skips_connectors_without_custom_stats(self):
        """Test that connectors without custom stats (return None) are skipped."""
        # ExampleConnector doesn't override build_kv_connector_stats,
        # so it returns None and should be skipped
        serialized_data = {
            "NixlConnector": {"data": self._nixl_data([1.5], [0.1], [1024], [10])},
            "ExampleConnector": {"data": {"some_field": [1, 2, 3]}},
        }

        stats = MultiConnector.build_kv_connector_stats(data=serialized_data)

        assert stats is not None
        assert isinstance(stats, MultiKVConnectorStats)
        # Only NixlConnector should be reconstructed
        assert len(stats.data) == 1
        assert "NixlConnector" in stats.data
        assert isinstance(stats.data["NixlConnector"], NixlKVConnectorStats)
        # ExampleConnector should be skipped (returns None)
        assert "ExampleConnector" not in stats.data

    def test_build_kv_connector_stats_handles_malformed_data(self):
        """Test that malformed data raises appropriate errors."""
        serialized_data = {
            "NixlConnector": {"wrong_field": {"transfer_duration": [1.5]}}
        }

        with pytest.raises(AssertionError, match="Expected a dict with a 'data' field"):
            MultiConnector.build_kv_connector_stats(data=serialized_data)

    def test_aggregate_same_connector(self):
        """Test aggregating stats from the same connector type."""
        stats1 = MultiKVConnectorStats(
            data={
                "NixlConnector": NixlKVConnectorStats(
                    data=self._nixl_data([1.0], [0.1], [1024], [10])
                )
            }
        )
        stats2 = MultiKVConnectorStats(
            data={
                "NixlConnector": NixlKVConnectorStats(
                    data=self._nixl_data([2.0], [0.2], [2048], [20])
                )
            }
        )

        result = stats1.aggregate(stats2)

        assert result is stats1  # Should return self
        assert "NixlConnector" in result.data
        nixl_stats = result.data["NixlConnector"]
        assert nixl_stats.data["transfer_duration"] == [1.0, 2.0]
        assert nixl_stats.data["post_duration"] == [0.1, 0.2]
        assert nixl_stats.data["bytes_transferred"] == [1024, 2048]
        assert nixl_stats.data["num_descriptors"] == [10, 20]

    def test_aggregate_new_connector(self):
        """Test aggregating stats when a new connector type appears."""
        # KVConnectorStats is already imported at module level; the previous
        # redundant local import was removed.
        stats1 = MultiKVConnectorStats(
            data={
                "NixlConnector": NixlKVConnectorStats(
                    data=self._nixl_data([1.0], [0.1], [1024], [10])
                )
            }
        )
        stats2 = MultiKVConnectorStats(
            data={"ExampleConnector": KVConnectorStats(data={"field": [1, 2]})}
        )

        result = stats1.aggregate(stats2)

        assert "NixlConnector" in result.data
        assert "ExampleConnector" in result.data

    def test_reduce(self):
        """Test that reduce() correctly reduces all nested connector stats."""
        stats = MultiKVConnectorStats(
            data={
                "NixlConnector": NixlKVConnectorStats(
                    data=self._nixl_data(
                        [1.0, 2.0], [0.1, 0.2], [1024, 2048], [10, 20]
                    )
                )
            }
        )

        reduced = stats.reduce()

        assert "NixlConnector" in reduced
        assert isinstance(reduced["NixlConnector"], dict)
        # Check that the stats were reduced (should have aggregated values)
        assert "Num successful transfers" in reduced["NixlConnector"]
        assert reduced["NixlConnector"]["Num successful transfers"] == 2

    def test_reset(self):
        """Test that reset() resets all nested connector stats."""
        stats = MultiKVConnectorStats(
            data={
                "NixlConnector": NixlKVConnectorStats(
                    data=self._nixl_data(
                        [1.0, 2.0], [0.1, 0.2], [1024, 2048], [10, 20]
                    )
                )
            }
        )

        assert not stats.is_empty()

        stats.reset()

        # After reset, stats should be empty
        assert stats.is_empty()
        nixl_stats = stats.data["NixlConnector"]
        assert len(nixl_stats.data["transfer_duration"]) == 0

    def test_is_empty_with_multiple_connectors(self):
        """Test is_empty() returns correct value with multiple connectors."""
        # All empty
        stats = MultiKVConnectorStats(
            data={
                "NixlConnector": NixlKVConnectorStats(data={}),
            }
        )
        # Initialize empty stats
        stats.data["NixlConnector"].reset()
        assert stats.is_empty()

        # One non-empty
        stats.data["NixlConnector"].data["transfer_duration"].append(1.0)
        assert not stats.is_empty()
|
||||
1791
tests/v1/kv_connector/unit/test_nixl_connector.py
Normal file
1791
tests/v1/kv_connector/unit/test_nixl_connector.py
Normal file
File diff suppressed because it is too large
Load Diff
534
tests/v1/kv_connector/unit/test_offloading_connector.py
Normal file
534
tests/v1/kv_connector/unit/test_offloading_connector.py
Normal file
@@ -0,0 +1,534 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import copy
|
||||
from collections.abc import Iterable, Iterator
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from vllm import SamplingParams
|
||||
from vllm.config import KVTransferConfig, VllmConfig
|
||||
from vllm.distributed.kv_events import BlockRemoved, BlockStored
|
||||
from vllm.distributed.kv_transfer.kv_connector.v1 import KVConnectorRole
|
||||
from vllm.distributed.kv_transfer.kv_connector.v1.offloading_connector import (
|
||||
OffloadingConnector,
|
||||
OffloadingConnectorMetadata,
|
||||
)
|
||||
from vllm.forward_context import ForwardContext
|
||||
from vllm.utils.hashing import sha256
|
||||
from vllm.v1.attention.backends.flash_attn import FlashAttentionBackend
|
||||
from vllm.v1.core.kv_cache_utils import (
|
||||
BlockHash,
|
||||
get_request_block_hasher,
|
||||
init_none_hash,
|
||||
)
|
||||
from vllm.v1.core.sched.scheduler import Scheduler
|
||||
from vllm.v1.kv_offload.abstract import (
|
||||
LoadStoreSpec,
|
||||
OffloadingEvent,
|
||||
OffloadingManager,
|
||||
PrepareStoreOutput,
|
||||
)
|
||||
from vllm.v1.kv_offload.mediums import GPULoadStoreSpec
|
||||
from vllm.v1.kv_offload.spec import OffloadingSpec
|
||||
from vllm.v1.kv_offload.worker.worker import (
|
||||
OffloadingHandler,
|
||||
TransferResult,
|
||||
TransferSpec,
|
||||
)
|
||||
from vllm.v1.outputs import EMPTY_MODEL_RUNNER_OUTPUT, KVConnectorOutput
|
||||
from vllm.v1.request import Request
|
||||
|
||||
from .utils import (
|
||||
EOS_TOKEN_ID,
|
||||
create_model_runner_output,
|
||||
create_scheduler,
|
||||
create_vllm_config,
|
||||
)
|
||||
|
||||
|
||||
class MockLoadStoreSpec(LoadStoreSpec):
    """Minimal LoadStoreSpec stand-in that just records block hashes."""

    def __init__(self, block_hashes: Iterable[BlockHash]):
        # Materialize eagerly so the spec can be inspected/iterated repeatedly.
        self.block_hashes: list[BlockHash] = list(block_hashes)

    @staticmethod
    def medium() -> str:
        """Medium identifier used to pair transfer handlers."""
        return "Mock"

    def __repr__(self) -> str:
        return repr(self.block_hashes)
|
||||
|
||||
|
||||
class MockOffloadingHandler(OffloadingHandler):
    """Offloading handler that completes every transfer synchronously."""

    def __init__(self):
        self.completed_transfers: list[TransferResult] = []
        self.completed_specs: list[TransferSpec] = []

    def get_finished(self) -> list[TransferResult]:
        # Hand back everything accumulated so far and start a fresh list.
        finished, self.completed_transfers = self.completed_transfers, []
        return finished

    def transfer_async(self, job_id: int, spec: TransferSpec) -> bool:
        # Record the spec and immediately report the job as successful.
        self.completed_specs.append(spec)
        self.completed_transfers.append((job_id, True))
        return True
|
||||
|
||||
|
||||
class MockOffloadingSpec(OffloadingSpec):
    """OffloadingSpec with a mocked manager and a synchronous mock handler."""

    def __init__(self, vllm_config: VllmConfig):
        super().__init__(vllm_config)

        self.manager = MagicMock(spec=OffloadingManager)
        # lookup reports zero already-offloaded blocks by default.
        self.manager.lookup.return_value = 0
        # prepare_load simply wraps the requested hashes in a mock spec.
        self.manager.prepare_load = lambda block_hashes: MockLoadStoreSpec(
            block_hashes
        )
        self.handler = MockOffloadingHandler()

    def get_manager(self) -> OffloadingManager:
        return self.manager

    def get_handlers(
        self, _, __
    ) -> Iterator[tuple[type[LoadStoreSpec], type[LoadStoreSpec], OffloadingHandler]]:
        # The same handler serves both directions: GPU->Mock (store) and
        # Mock->GPU (load).
        yield GPULoadStoreSpec, MockLoadStoreSpec, self.handler
        yield MockLoadStoreSpec, GPULoadStoreSpec, self.handler

    def get_completed_transfers(self) -> list[TransferSpec]:
        # Drain the handler's record of completed transfer specs.
        specs, self.handler.completed_specs = self.handler.completed_specs, []
        return specs
|
||||
|
||||
|
||||
@dataclass
class TransferSummary:
    """Pairing of GPU block indices with their offloaded-side addresses
    for a single transfer, as observed by the test runner."""

    # Indices of the GPU blocks that took part in the transfer.
    gpu_block_indices: list[int]
    # Opaque per-medium addresses on the offloaded side.
    offload_addresses: list[Any]
|
||||
|
||||
|
||||
class RequestRunner:
    """Drives a scheduler + worker OffloadingConnector pair end-to-end.

    Builds a real Scheduler configured with the MockOffloadingSpec, steps the
    engine loop manually, and records completed load/store transfers so tests
    can assert exactly which GPU blocks were moved.
    """

    def __init__(
        self, offloaded_block_size: int, gpu_block_size: int, num_gpu_blocks: int
    ):
        self.offloaded_block_size: int = offloaded_block_size
        self.gpu_block_size: int = gpu_block_size
        self.num_gpu_blocks: int = num_gpu_blocks

        # Incremented per new_request(); -1 means no request created yet.
        self.req_id: int = -1

        vllm_config = create_vllm_config(
            block_size=gpu_block_size, max_num_batched_tokens=1000
        )
        # Point the connector at this very test module so it loads the mock spec.
        vllm_config.kv_transfer_config = KVTransferConfig(
            kv_connector="OffloadingConnector",
            kv_role="kv_both",
            kv_connector_extra_config={
                "spec_name": "MockOffloadingSpec",
                "spec_module_path": "tests.v1.kv_connector.unit.test_offloading_connector",  # noqa: E501
                "block_size": offloaded_block_size,
            },
        )

        self.scheduler: Scheduler = create_scheduler(
            vllm_config, num_blocks=num_gpu_blocks
        )
        self.worker_connector = OffloadingConnector(vllm_config, KVConnectorRole.WORKER)

        # register worker kv_caches to enable OffloadingWorker creations
        self.worker_connector.register_cross_layers_kv_cache(
            kv_cache=torch.empty(0),
            attn_backend=FlashAttentionBackend,
        )

        # extract connector of scheduler
        scheduler_connector = self.scheduler.connector
        assert scheduler_connector is not None
        assert isinstance(scheduler_connector, OffloadingConnector)
        self.scheduler_connector: OffloadingConnector = scheduler_connector

        # extract mocked OffloadingManager of scheduler connector
        connector_scheduler = scheduler_connector.connector_scheduler
        assert connector_scheduler is not None
        manager = connector_scheduler.manager
        assert isinstance(manager, MagicMock)
        self.manager: MagicMock = manager

        assert connector_scheduler.gpu_block_size == gpu_block_size
        assert connector_scheduler.offloaded_block_size == offloaded_block_size

        # extract OffloadingSpec of worker_connector
        connector_worker = self.worker_connector.connector_worker
        assert connector_worker is not None
        offloading_spec = connector_worker.spec
        assert isinstance(offloading_spec, MockOffloadingSpec)
        self.offloading_spec: MockOffloadingSpec = offloading_spec

        # mapping (offloading address) -> gpu_block_index
        self.offloaded: dict[Any, int] = {}

        # Transfers scheduled but not yet drained by _wait_for_transfers().
        self.pending_loads_count: int = 0
        self.pending_stores_count: int = 0

        self.completed_loads: list[TransferSummary] = []
        self.completed_stores: list[TransferSummary] = []

        # maps {block_id: block_offset}
        self.gpu_block_index: dict[int, int] = {}

        init_none_hash(sha256)
        self._block_hasher = get_request_block_hasher(gpu_block_size, sha256)

        # Minimal forward context; the mock handler never touches real layers.
        self._dummy_ctx: ForwardContext = ForwardContext(
            no_compile_layers={}, attn_metadata={}, virtual_engine=0
        )

    def new_request(self, token_ids: list[int]):
        """Create and enqueue a single request; only one may exist at a time."""
        assert not self.scheduler.requests
        self.req_id += 1

        req = Request(
            request_id=str(self.req_id),
            prompt_token_ids=token_ids,
            sampling_params=SamplingParams(max_tokens=1000),
            pooling_params=None,
            eos_token_id=EOS_TOKEN_ID,
            block_hasher=self._block_hasher,
        )

        self.scheduler.add_request(req)

    def _wait_for_transfers(self):
        """Drain completed transfers into completed_loads/completed_stores."""
        block_size_factor = self.offloaded_block_size // self.gpu_block_size

        while self.pending_loads_count or self.pending_stores_count:
            for transfer_spec in self.offloading_spec.get_completed_transfers():
                src_spec, dst_spec = transfer_spec

                # Direction is inferred from which side is the GPU spec.
                if isinstance(src_spec, GPULoadStoreSpec):
                    store = True
                    gpu_spec = src_spec
                    offload_spec = dst_spec
                else:
                    store = False
                    gpu_spec = dst_spec
                    offload_spec = src_spec

                assert isinstance(offload_spec, MockLoadStoreSpec)
                assert isinstance(gpu_spec, GPULoadStoreSpec)

                gpu_block_indices: list[int] = []
                for block_id in gpu_spec.block_ids:
                    gpu_block_indices.append(self.gpu_block_index[block_id.item()])

                # list of (block_hash, sub_block_offset)
                offload_addresses: list[Any] = []
                for block_hash in offload_spec.block_hashes:
                    for sub_block_idx in range(block_size_factor):
                        offload_addresses.append((block_hash, sub_block_idx))

                if store:
                    assert len(gpu_block_indices) == len(offload_addresses)

                    self.completed_stores.append(
                        TransferSummary(gpu_block_indices, offload_addresses)
                    )
                    self.pending_stores_count -= 1
                else:
                    # A load may start mid offloaded-block; drop the leading
                    # sub-block addresses that have no matching GPU block.
                    remainder_sub_block_count = len(offload_addresses) - len(
                        gpu_block_indices
                    )
                    assert remainder_sub_block_count >= 0
                    assert remainder_sub_block_count < block_size_factor
                    offload_addresses = offload_addresses[remainder_sub_block_count:]

                    self.completed_loads.append(
                        TransferSummary(gpu_block_indices, offload_addresses)
                    )
                    self.pending_loads_count -= 1

    def _update_gpu_block_idx(self):
        """Refresh block_id -> per-request block offset mapping after scheduling."""
        for blocks in self.scheduler.kv_cache_manager.coordinator.single_type_managers[
            0
        ].req_to_blocks.values():
            for block_idx, block in enumerate(blocks):
                self.gpu_block_index[block.block_id] = block_idx

    def _run(self, decoded_tokens: list[int]):
        """
        Runs multiple engine (scheduler + worker) steps.
        Assumes a single request is running.

        Args:
            decoded_tokens: the tokens to yield at each step.
        """

        tokens_iter = iter(decoded_tokens)
        token_id = next(tokens_iter, None)
        while token_id is not None:
            assert self.scheduler.requests

            scheduler_output = self.scheduler.schedule()
            self._update_gpu_block_idx()

            kv_connector_metadata = scheduler_output.kv_connector_metadata
            assert kv_connector_metadata is not None
            assert isinstance(kv_connector_metadata, OffloadingConnectorMetadata)

            self.pending_loads_count += len(kv_connector_metadata.reqs_to_load)
            self.pending_stores_count += len(kv_connector_metadata.reqs_to_store)

            self.worker_connector.bind_connector_metadata(kv_connector_metadata)
            self.worker_connector.start_load_kv(self._dummy_ctx)

            if scheduler_output.total_num_scheduled_tokens > 0:
                self.worker_connector.wait_for_save()

            finished_sending, finished_recving = self.worker_connector.get_finished(
                scheduler_output.finished_req_ids
            )

            self.worker_connector.clear_connector_metadata()

            model_runner_output = create_model_runner_output(
                reqs=self.scheduler.running,
                finished_sending=finished_sending,
                finished_recving=finished_recving,
                token_id=token_id,
            )

            # Only consume the next token once the request actually ran.
            if self.scheduler.running:
                token_id = next(tokens_iter, None)

            self.scheduler.update_from_output(scheduler_output, model_runner_output)

            self._wait_for_transfers()

        # run one more step to update finished stored
        if EOS_TOKEN_ID in decoded_tokens:
            assert not self.scheduler.running

            while self.scheduler.requests:
                scheduler_output = self.scheduler.schedule()

                finished_sending, finished_recving = self.worker_connector.get_finished(
                    scheduler_output.finished_req_ids
                )

                assert not finished_recving

                model_runner_output = copy.deepcopy(EMPTY_MODEL_RUNNER_OUTPUT)
                model_runner_output.kv_connector_output = KVConnectorOutput(
                    finished_sending=finished_sending
                )

                self.scheduler.update_from_output(scheduler_output, model_runner_output)

    def run(
        self,
        decoded_tokens: list[int],
        expected_stored_gpu_block_indexes: tuple[int, ...] = (),
        expected_loaded_gpu_block_indexes: tuple[int, ...] = (),
    ):
        """
        Runs multiple engine (scheduler + worker) steps.
        Assumes a single request is running.

        Args:
            decoded_tokens: the tokens to yield at each step.
            expected_stored_gpu_block_indexes: GPU block indexes
                that are expected to be written during the run.
            expected_loaded_gpu_block_indexes: GPU block indexes
                that are expected to be loaded during the run.
        """

        self.manager.reset_mock()
        self._run(decoded_tokens)

        # Verify loads: each loaded GPU block must match what was stored
        # earlier at the same offload address.
        loaded_gpu_block_indexes: set[int] = set()
        for transfer in self.completed_loads:
            for gpu_block_idx, offloaded_address in zip(
                transfer.gpu_block_indices, transfer.offload_addresses
            ):
                loaded_gpu_block_indexes.add(gpu_block_idx)
                assert gpu_block_idx == self.offloaded[offloaded_address]

        assert set(expected_loaded_gpu_block_indexes) == loaded_gpu_block_indexes
        self.completed_loads.clear()

        # Verify stores and remember address -> GPU block for later loads.
        stored_gpu_block_indexes: set[int] = set()
        for transfer in self.completed_stores:
            for gpu_block_idx, offloaded_address in zip(
                transfer.gpu_block_indices, transfer.offload_addresses
            ):
                stored_gpu_block_indexes.add(gpu_block_idx)
                self.offloaded[offloaded_address] = gpu_block_idx

        assert set(expected_stored_gpu_block_indexes) == stored_gpu_block_indexes
        self.completed_stores.clear()
|
||||
|
||||
|
||||
@pytest.fixture
def request_runner():
    """Yield a factory that builds RequestRunner harnesses.

    A factory (rather than a ready-made runner) is yielded so each test can
    pick its own block sizes and GPU block count.
    """
    # NOTE: the original version accumulated created runners in a list that
    # was never used after the yield (no teardown); that dead state has been
    # removed. Re-introduce a list here if per-runner teardown is ever added.

    def runner_factory(offloaded_block_size, gpu_block_size, num_gpu_blocks):
        # Construct a fully wired scheduler + worker harness.
        return RequestRunner(
            offloaded_block_size=offloaded_block_size,
            gpu_block_size=gpu_block_size,
            num_gpu_blocks=num_gpu_blocks,
        )

    yield runner_factory  # pass factory to the test
|
||||
|
||||
|
||||
def generate_store_output(block_hashes: Iterable[BlockHash]):
    """Build a PrepareStoreOutput that stores every given hash, evicting none."""
    hashes = list(block_hashes)
    return PrepareStoreOutput(
        block_hashes_to_store=list(hashes),
        store_spec=MockLoadStoreSpec(hashes),
        block_hashes_evicted=[],
    )
|
||||
|
||||
|
||||
def test_offloading_connector(request_runner):
    """End-to-end OffloadingConnector scenarios on a single scheduler.

    Exercises store selection, failed/empty prepare_store, touch bookkeeping,
    lookup hits/misses, partial-block requests, and take_events translation.
    """
    offloaded_block_size = 12
    gpu_block_size = 4
    num_gpu_blocks = 100
    block_size_factor = offloaded_block_size // gpu_block_size

    runner = request_runner(
        offloaded_block_size=offloaded_block_size,
        gpu_block_size=gpu_block_size,
        num_gpu_blocks=num_gpu_blocks,
    )

    # 3 blocks, store just the middle block (skip first and last)
    # blocks = [0, 1, 2], [3, 4, 5], [6, 7, 8]
    runner.new_request(token_ids=[0] * offloaded_block_size * 3)
    runner.manager.prepare_store.side_effect = (
        lambda block_hashes: generate_store_output(list(block_hashes)[1:2])
    )
    runner.run(decoded_tokens=[0], expected_stored_gpu_block_indexes=(3, 4, 5))

    # add block missing 1 token -> no offload
    runner.run(decoded_tokens=[0] * (offloaded_block_size - 1))
    runner.manager.prepare_store.assert_not_called()

    # +1 token -> single block, fail prepare_store
    runner.manager.prepare_store.side_effect = lambda block_hashes: None
    runner.run(decoded_tokens=[0])
    runner.manager.prepare_store.assert_called()

    # 1 more block, now set block_hashes_to_store = []
    runner.manager.prepare_store.side_effect = (
        lambda block_hashes: generate_store_output([])
    )
    runner.run(decoded_tokens=[0] * offloaded_block_size)

    # 1 more block, now check touch was called with all 6 blocks
    runner.manager.prepare_store.side_effect = (
        lambda block_hashes: generate_store_output(block_hashes)
    )
    runner.run(
        decoded_tokens=[0] * offloaded_block_size,
        expected_stored_gpu_block_indexes=(15, 16, 17),
    )
    runner.manager.touch.assert_called()
    block_hashes1 = list(runner.manager.touch.call_args.args[0])
    assert len(block_hashes1) == 6

    # terminate request
    runner.run(decoded_tokens=[EOS_TOKEN_ID])

    # create a new request differing only on the last token
    runner.new_request(token_ids=[0] * (offloaded_block_size * 6 - 1) + [1])
    runner.run(
        decoded_tokens=[0],
        expected_stored_gpu_block_indexes=tuple(range(6 * block_size_factor)),
    )
    runner.manager.touch.assert_called()
    block_hashes2 = list(runner.manager.touch.call_args.args[0])
    assert len(block_hashes2) == 6

    # verify hashes are the same, except for the last block
    assert block_hashes1[:5] == block_hashes2[:5]
    assert block_hashes1[5] != block_hashes2[5]

    # terminate request
    runner.run(decoded_tokens=[EOS_TOKEN_ID])

    # full_block_tokens - num_computed_tokens < offloaded_block_size
    runner.new_request(
        token_ids=[0] * gpu_block_size + [1] * (offloaded_block_size - gpu_block_size)
    )
    runner.manager.prepare_store.side_effect = (
        lambda block_hashes: generate_store_output([])
    )
    runner.run(decoded_tokens=[EOS_TOKEN_ID])
    runner.manager.lookup.assert_not_called()

    # single block lookup with no hits
    runner.new_request(token_ids=[1] * offloaded_block_size)
    runner.manager.prepare_store.side_effect = (
        lambda block_hashes: generate_store_output([])
    )
    runner.run(decoded_tokens=[EOS_TOKEN_ID])
    runner.manager.lookup.assert_called()
    assert len(list(runner.manager.lookup.call_args.args[0])) == 1

    # single block lookup with a hit
    runner.scheduler.reset_prefix_cache()
    runner.new_request(token_ids=[0] * offloaded_block_size)
    runner.manager.prepare_store.side_effect = (
        lambda block_hashes: generate_store_output([])
    )
    runner.manager.lookup.return_value = 1
    runner.run(
        decoded_tokens=[EOS_TOKEN_ID], expected_loaded_gpu_block_indexes=(0, 1, 2)
    )

    # single block lookup with a hit in a middle block
    runner.new_request(
        token_ids=[0] * offloaded_block_size * 2 + [1] * offloaded_block_size
    )
    runner.manager.prepare_store.side_effect = (
        lambda block_hashes: generate_store_output([])
    )
    runner.manager.lookup.return_value = 1
    runner.run(
        decoded_tokens=[EOS_TOKEN_ID], expected_loaded_gpu_block_indexes=(3, 4, 5)
    )

    # test take_events
    def to_hashes(int_hashes: list[int]) -> list[BlockHash]:
        return [BlockHash(str(i).encode()) for i in int_hashes]

    def take_events() -> Iterable[OffloadingEvent]:
        # Two synthetic events: one store, one removal.
        yield OffloadingEvent(
            block_hashes=to_hashes([1, 2, 3]), block_size=16, medium="A", removed=False
        )
        yield OffloadingEvent(
            block_hashes=to_hashes([4, 5, 6]), block_size=32, medium="B", removed=True
        )

    runner.manager.take_events.side_effect = take_events
    events = list(runner.scheduler_connector.take_events())
    assert len(events) == 2
    event = events[0]
    assert isinstance(event, BlockStored)
    assert event.block_hashes == to_hashes([1, 2, 3])
    assert event.block_size == 16
    assert event.medium == "A"
    assert event.token_ids == []
    assert event.parent_block_hash is None
    assert event.lora_id is None
    event = events[1]
    assert isinstance(event, BlockRemoved)
    assert event.block_hashes == to_hashes([4, 5, 6])
    assert event.medium == "B"
|
||||
122
tests/v1/kv_connector/unit/test_output_aggregator.py
Normal file
122
tests/v1/kv_connector/unit/test_output_aggregator.py
Normal file
@@ -0,0 +1,122 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import pytest
|
||||
|
||||
from vllm.distributed.kv_transfer.kv_connector.utils import KVOutputAggregator
|
||||
from vllm.v1.outputs import KVConnectorOutput, ModelRunnerOutput
|
||||
|
||||
pytestmark = pytest.mark.cpu_test
|
||||
|
||||
|
||||
class DummyModelRunnerOutput(ModelRunnerOutput):
    """Minimal ModelRunnerOutput carrying only a KVConnectorOutput.

    Used by the aggregator tests to simulate per-worker finished/invalid
    reports without constructing a full model-runner result.
    """

    def __init__(
        self,
        finished_sending: set[str] | None = None,
        finished_recving: set[str] | None = None,
        invalid_block_ids: set[int] | None = None,
        expected_finished_count: int = 0,
    ):
        self.kv_connector_output = KVConnectorOutput(
            finished_sending=finished_sending,
            finished_recving=finished_recving,
            invalid_block_ids=invalid_block_ids or set(),
            expected_finished_count=expected_finished_count,
        )

    def __repr__(self):
        # Fixed: the original closed the parenthesis after finished_recving
        # and omitted the separator before invalid_block_ids, yielding a
        # malformed repr like "...finished_recving=None)invalid_block_ids=...))".
        return (
            f"DummyModelRunnerOutput("
            f"finished_sending={self.kv_connector_output.finished_sending}, "
            f"finished_recving={self.kv_connector_output.finished_recving}, "
            f"invalid_block_ids={self.kv_connector_output.invalid_block_ids})"
        )
|
||||
|
||||
|
||||
def test_aggregate_workers_output():
    """Aggregation across 2 workers: a request is finished only once both
    workers report it; invalid_block_ids are unioned across workers."""
    aggregator = KVOutputAggregator(expected_finished_count=2)

    # Case 1: nothing reported by either worker -> nothing aggregated.
    output1 = DummyModelRunnerOutput()
    output2 = DummyModelRunnerOutput()

    aggregated = aggregator.aggregate([output1, output2])

    # Aggregation reuses the first worker's output object in place.
    assert aggregated is output1
    aggregated = aggregated.kv_connector_output
    assert aggregated.finished_sending is None
    assert aggregated.finished_recving is None
    assert not aggregated.invalid_block_ids

    # Case 2: only one worker reports req1/req2 -> not finished yet,
    # but invalid blocks propagate immediately.
    output1 = DummyModelRunnerOutput(
        finished_sending={"req1"}, finished_recving={"req2"}
    )
    output2 = DummyModelRunnerOutput(invalid_block_ids={1})

    aggregated = aggregator.aggregate([output1, output2])

    assert aggregated is output1
    aggregated = aggregated.kv_connector_output
    assert aggregated.finished_sending is None
    assert aggregated.finished_recving is None
    assert aggregated.invalid_block_ids == {1}

    # Case 3: second worker now also reports req1 -> req1 send is complete.
    output1 = DummyModelRunnerOutput(invalid_block_ids={2})
    output2 = DummyModelRunnerOutput(finished_sending={"req1"})

    aggregated = aggregator.aggregate([output1, output2])

    assert aggregated is output1
    aggregated = aggregated.kv_connector_output
    assert aggregated.finished_sending == {"req1"}
    assert aggregated.finished_recving is None
    assert aggregated.invalid_block_ids == {2}

    # Case 4: req2 recv completes; overlapping invalid ids are unioned.
    output1 = DummyModelRunnerOutput(invalid_block_ids={3, 4})
    output2 = DummyModelRunnerOutput(
        finished_recving={"req2"}, invalid_block_ids={4, 5}
    )

    aggregated = aggregator.aggregate([output1, output2])

    assert aggregated is output1
    aggregated = aggregated.kv_connector_output
    assert aggregated.finished_sending is None
    assert aggregated.finished_recving == {"req2"}
    assert aggregated.invalid_block_ids == {3, 4, 5}
|
||||
|
||||
|
||||
def test_aggregate_workers_output_with_expected_finished_count():
    """Workers can dynamically lower the number of acks a request needs."""
    # We create the aggregator expecting to collect from 4 workers
    aggregator = KVOutputAggregator(expected_finished_count=4)
    assert aggregator._expected_finished_count == 4
    # Some request with default expected finished requests
    output1 = DummyModelRunnerOutput(finished_sending={"req1"})
    aggregated = aggregator.aggregate([output1])
    # still expecting to collect from 4 workers
    assert aggregator._send_remaining_count["req1"] == 3
    assert not aggregated.kv_connector_output.finished_sending
    assert not aggregated.kv_connector_output.finished_recving

    # Workers discover and find that in this setup they only need to
    # collect from 2
    output1 = DummyModelRunnerOutput(
        finished_sending={"req1"}, expected_finished_count=2
    )
    output2 = DummyModelRunnerOutput(
        finished_recving={"req2"}, expected_finished_count=2
    )
    output3 = DummyModelRunnerOutput(finished_recving={"req2"})
    # Req2 only needs 2 acks
    aggregated = aggregator.aggregate([output1, output2, output3])
    assert aggregated.kv_connector_output.expected_finished_count == 2

    assert not aggregated.kv_connector_output.finished_sending

    # Req2 is finished
    assert "req2" not in aggregator._recv_remaining_count
    assert aggregated.kv_connector_output.finished_recving == {"req2"}

    # Req1 is still waiting for 2 more acks (expected_finished_count has no effect)
    # NOTE: This is to showcase dynamic update. Workers are responsible for
    # ensuring "req1" termination in this case
    assert aggregator._send_remaining_count["req1"] == 2
|
||||
262
tests/v1/kv_connector/unit/test_remote_decode_lifecycle.py
Normal file
262
tests/v1/kv_connector/unit/test_remote_decode_lifecycle.py
Normal file
@@ -0,0 +1,262 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import copy
|
||||
|
||||
import pytest
|
||||
|
||||
from vllm.v1.outputs import EMPTY_MODEL_RUNNER_OUTPUT, KVConnectorOutput
|
||||
from vllm.v1.request import FinishReason, RequestStatus
|
||||
|
||||
from .utils import (
|
||||
assert_scheduler_empty,
|
||||
create_model_runner_output,
|
||||
create_request,
|
||||
create_scheduler,
|
||||
create_vllm_config,
|
||||
)
|
||||
|
||||
pytestmark = pytest.mark.cpu_test
|
||||
|
||||
|
||||
def test_basic_lifecycle():
    """Test lifecycle of a Remote Decode request.

    Blocks must stay allocated after the request finishes locally, until the
    remote side acknowledges the KV transfer via finished_sending.
    """

    vllm_config = create_vllm_config()
    scheduler = create_scheduler(vllm_config)

    # 2 Full Blocks and 1 Half Block.
    BLOCK_SIZE = vllm_config.cache_config.block_size
    NUM_EXTERNAL_FULL_BLOCKS = 2
    NUM_TOKENS = int(BLOCK_SIZE * (NUM_EXTERNAL_FULL_BLOCKS + 0.5))

    request = create_request(
        request_id=1,
        block_size=BLOCK_SIZE,
        max_tokens=1,
        num_tokens=NUM_TOKENS,
        do_remote_decode=True,
    )

    scheduler.add_request(request)
    request_id = request.request_id

    # STEP (1): Prefill.
    # (1a): schedule()
    scheduler_output = scheduler.schedule()
    assert len(scheduler.requests) == 1
    assert len(scheduler.running) == 1
    assert len(scheduler_output.scheduled_new_reqs) == 1

    # (1b): execute_model()
    model_runner_output = create_model_runner_output(reqs=[request])

    # (1c): update_from_output()
    engine_core_outputs = scheduler.update_from_output(
        scheduler_output, model_runner_output
    )

    # Ensure the request is finished after 1 token.
    assert request.is_finished()
    assert request.status == RequestStatus.FINISHED_LENGTH_CAPPED
    output = engine_core_outputs[0].outputs[0]
    assert output.finish_reason == FinishReason.LENGTH
    assert output.kv_transfer_params is not None

    # Request freed in Scheduler and in Persistent Batch ...
    assert request_id in scheduler.finished_req_ids
    assert len(scheduler.running) == 0
    assert len(scheduler.waiting) == 0

    # ... but blocks should not be freed.
    assert len(scheduler.requests) == 1
    blocks = scheduler.kv_cache_manager.coordinator.single_type_managers[
        0
    ].req_to_blocks[request_id]
    for block in blocks:
        assert block.ref_cnt == 1

    # STEP (2): Send Finished to PB.
    # (2a): schedule() - pass finished request to PB.
    scheduler_output = scheduler.schedule()
    assert len(scheduler.requests) == 1
    assert len(scheduler.running) == 0
    assert len(scheduler_output.finished_req_ids) == 1
    assert request_id in scheduler_output.finished_req_ids
    assert len(scheduler_output.scheduled_new_reqs) == 0
    assert scheduler_output.scheduled_cached_reqs.num_reqs == 0
    assert len(scheduler.finished_req_ids) == 0

    # (2b): execute_model()
    model_runner_output = EMPTY_MODEL_RUNNER_OUTPUT

    # (2c): update_from_output()
    scheduler.update_from_output(scheduler_output, model_runner_output)

    # STEP (3): Finished sending.
    # (3a): schedule() - pass finished request to PB.
    scheduler_output = scheduler.schedule()
    assert len(scheduler.requests) == 1
    assert len(scheduler.running) == 0
    assert len(scheduler_output.finished_req_ids) == 0
    assert len(scheduler_output.scheduled_new_reqs) == 0
    assert scheduler_output.scheduled_cached_reqs.num_reqs == 0
    assert len(scheduler.finished_req_ids) == 0

    # (3b): execute_model()
    model_runner_output = copy.deepcopy(EMPTY_MODEL_RUNNER_OUTPUT)
    model_runner_output.kv_connector_output = KVConnectorOutput(
        finished_sending={request_id}
    )

    # (3c): update_from_output()
    scheduler.update_from_output(scheduler_output, model_runner_output)

    # Confirm we do not have any memory leaks after req lifecycle.
    assert_scheduler_empty(scheduler)
|
||||
|
||||
|
||||
def test_short_prompt_lifecycle():
    """Test lifecycle of a Remote Decode request with short prompt.

    A prompt shorter than one block must still produce a KV transfer for the
    single partial block.
    """

    vllm_config = create_vllm_config()
    scheduler = create_scheduler(vllm_config)

    # Not enough tokens for full block.
    BLOCK_SIZE = vllm_config.cache_config.block_size
    NUM_TOKENS = BLOCK_SIZE // 2
    request = create_request(
        request_id=1,
        block_size=BLOCK_SIZE,
        max_tokens=1,
        num_tokens=NUM_TOKENS,
        do_remote_decode=True,
    )

    scheduler.add_request(request)

    # STEP (1): Prefill.
    # (1a): schedule()
    scheduler_output = scheduler.schedule()
    assert len(scheduler.requests) == 1
    assert len(scheduler.running) == 1
    assert len(scheduler_output.scheduled_new_reqs) == 1

    # (1b): execute_model()
    model_runner_output = create_model_runner_output(reqs=[request])

    # (1c): update_from_output()
    # Even though tokens < block_size, there will be kv xfer for partial block.
    eco = scheduler.update_from_output(scheduler_output, model_runner_output)
    kv_transfer_params = eco[0].outputs[0].kv_transfer_params

    assert len(kv_transfer_params["remote_block_ids"]) == 1

    # Confirm we do not have any memory leaks after req lifecycle.
    # We need to mark sending finish to clear data for persistent batch.
    scheduler_output = scheduler.schedule()
    # Use create_model_runner_output to pass kv_connector_output along
    model_runner_output = create_model_runner_output(
        reqs=[request], finished_sending={request.request_id}
    )
    scheduler.update_from_output(scheduler_output, model_runner_output)
    assert_scheduler_empty(scheduler)
|
||||
|
||||
|
||||
def test_prefix_cache_lifecycle():
    """Test that remote decode params still work with a prefix cache hit.

    Primes the prefix cache with a normal request, then issues a remote-decode
    request sharing the prefix and checks ALL its block ids are sent.
    """

    vllm_config = create_vllm_config()
    scheduler = create_scheduler(vllm_config)

    # Prime the KVCache.
    BLOCK_SIZE = vllm_config.cache_config.block_size
    NUM_EXTERNAL_FULL_BLOCKS = 3
    NUM_TOKENS = int(BLOCK_SIZE * (NUM_EXTERNAL_FULL_BLOCKS + 0.5))

    request_normal = create_request(
        request_id=1, block_size=BLOCK_SIZE, num_tokens=NUM_TOKENS
    )

    scheduler.add_request(request_normal)
    scheduler_output = scheduler.schedule()
    model_runner_output = create_model_runner_output(
        reqs=[request_normal], use_eos=True
    )
    scheduler.update_from_output(scheduler_output, model_runner_output)
    scheduler_output = scheduler.schedule()
    scheduler.update_from_output(scheduler_output, EMPTY_MODEL_RUNNER_OUTPUT)

    #####################
    # Actual Test: confirm we send all blocks.

    # Step (1): Send the KV Transfer.
    NUM_EXTERNAL_FULL_BLOCKS -= 1
    NUM_TOKENS = int(BLOCK_SIZE * (NUM_EXTERNAL_FULL_BLOCKS + 0.5))

    request_remote = create_request(
        request_id=1,
        block_size=BLOCK_SIZE,
        num_tokens=NUM_TOKENS,
        do_remote_decode=True,
    )

    scheduler.add_request(request_remote)
    scheduler_output = scheduler.schedule()
    model_runner_output = create_model_runner_output(reqs=[request_remote])
    eco = scheduler.update_from_output(scheduler_output, model_runner_output)
    kv_transfer_params = eco[0].outputs[0].kv_transfer_params

    # Ensure we send all block ids, including the partial blocks,
    # even if there is a cache hit.
    assert len(kv_transfer_params["remote_block_ids"]) == (NUM_EXTERNAL_FULL_BLOCKS + 1)

    # STEP (2): Ensure it is freed.
    scheduler_output = scheduler.schedule()
    model_runner_output = copy.deepcopy(EMPTY_MODEL_RUNNER_OUTPUT)
    model_runner_output.kv_connector_output = KVConnectorOutput(
        finished_sending={request_remote.request_id}
    )
    scheduler.update_from_output(scheduler_output, model_runner_output)
    assert_scheduler_empty(scheduler)
|
||||
|
||||
|
||||
def test_abort_during_kv_transfer():
    """Test aborting request does not release blocks for remote decode.

    An aborted remote-decode request must keep its KV blocks until the
    connector reports finished_sending for it.
    """

    vllm_config = create_vllm_config()
    scheduler = create_scheduler(vllm_config)

    # Prime the KVCache.
    BLOCK_SIZE = vllm_config.cache_config.block_size
    NUM_EXTERNAL_FULL_BLOCKS = 2
    NUM_TOKENS = int(BLOCK_SIZE * (NUM_EXTERNAL_FULL_BLOCKS + 0.5))

    request = create_request(
        request_id=1,
        block_size=BLOCK_SIZE,
        num_tokens=NUM_TOKENS,
        do_remote_decode=True,
    )

    scheduler.add_request(request)
    scheduler_output = scheduler.schedule()
    model_runner_output = create_model_runner_output(reqs=[request])
    scheduler.update_from_output(scheduler_output, model_runner_output)
    scheduler_output = scheduler.schedule()
    scheduler.update_from_output(scheduler_output, EMPTY_MODEL_RUNNER_OUTPUT)

    # Request removed from PB but blocks should not be freed.
    assert len(scheduler.requests) == 1

    # Abort the request, and check the blocks are still not freed
    scheduler.finish_requests([request.request_id], RequestStatus.FINISHED_ABORTED)
    assert len(scheduler.requests) == 1

    # Simulate a finished sending notification
    scheduler_output = scheduler.schedule()
    model_runner_output = copy.deepcopy(EMPTY_MODEL_RUNNER_OUTPUT)
    # Fixed: finished_sending is a set of request ids (every other call site
    # passes a set); the original passed a one-element list here.
    model_runner_output.kv_connector_output = KVConnectorOutput(
        finished_sending={request.request_id}
    )
    scheduler.update_from_output(scheduler_output, model_runner_output)
    assert_scheduler_empty(scheduler)
|
||||
577
tests/v1/kv_connector/unit/test_remote_prefill_lifecycle.py
Normal file
577
tests/v1/kv_connector/unit/test_remote_prefill_lifecycle.py
Normal file
@@ -0,0 +1,577 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import copy
|
||||
|
||||
import pytest
|
||||
|
||||
from vllm.v1.outputs import EMPTY_MODEL_RUNNER_OUTPUT, KVConnectorOutput
|
||||
from vllm.v1.request import FinishReason, RequestStatus
|
||||
|
||||
from .utils import (
|
||||
assert_scheduler_empty,
|
||||
create_model_runner_output,
|
||||
create_request,
|
||||
create_scheduler,
|
||||
create_vllm_config,
|
||||
)
|
||||
|
||||
pytestmark = pytest.mark.cpu_test
|
||||
|
||||
|
||||
def test_basic_lifecycle():
    """Test lifecycle of a remote prefill.

    Walks one request through the full P/D decode-side state machine:
    (1) blocks are allocated but nothing is scheduled while KVs are in
    flight, (2) the connector reports the recv finished, (3) the request
    is scheduled with the externally-computed blocks counted as cached,
    and (4) the request hits EOS and the scheduler is left fully empty.
    """

    vllm_config = create_vllm_config()
    scheduler = create_scheduler(vllm_config)

    # 2 Full Blocks and 1 Half Block.
    BLOCK_SIZE = vllm_config.cache_config.block_size
    NUM_EXTERNAL_FULL_BLOCKS = 2
    NUM_TOKENS = int(BLOCK_SIZE * (NUM_EXTERNAL_FULL_BLOCKS + 0.5))
    START_FREE_BLOCK_QUEUE_SIZE = (
        scheduler.kv_cache_manager.block_pool.free_block_queue.num_free_blocks
    )

    request = create_request(
        request_id=1,
        block_size=BLOCK_SIZE,
        num_tokens=NUM_TOKENS,
        do_remote_prefill=True,
    )

    scheduler.add_request(request)
    request_id = request.request_id

    # STEP (1):
    # (1a): schedule()
    scheduler_output = scheduler.schedule()

    # Nothing running and empty scheduler output.
    assert len(scheduler.running) == 0
    assert len(scheduler_output.scheduled_new_reqs) == 0
    assert scheduler_output.scheduled_cached_reqs.num_reqs == 0
    assert len(scheduler_output.num_scheduled_tokens) == 0
    assert scheduler_output.total_num_scheduled_tokens == 0

    # Req waiting for KVs with no computed/scheduled toks ...
    assert len(scheduler.waiting) == 1
    assert request in scheduler.waiting
    assert request.status == RequestStatus.WAITING_FOR_REMOTE_KVS
    assert request.num_computed_tokens == 0

    # ... but should have (uncached) blocks allocated to it.
    block_pool = scheduler.kv_cache_manager.block_pool
    assert block_pool.free_block_queue.num_free_blocks < START_FREE_BLOCK_QUEUE_SIZE
    assert len(block_pool.cached_block_hash_to_block) == 0
    blocks = scheduler.kv_cache_manager.coordinator.single_type_managers[
        0
    ].req_to_blocks[request_id]
    for block in blocks:
        # Uncomputed blocks must never carry a hash (no prefix-cache hits).
        assert block._block_hash is None

    # (1b): forward()
    model_runner_output = EMPTY_MODEL_RUNNER_OUTPUT

    # (1c): update_from_output()
    engine_core_outputs = scheduler.update_from_output(
        scheduler_output, model_runner_output
    )
    assert not engine_core_outputs or not engine_core_outputs[0].outputs

    # STEP (2):
    # (2a): schedule(): nothing happens!
    scheduler_output = scheduler.schedule()
    assert len(scheduler.waiting) == 1
    assert len(scheduler.running) == 0

    # (2b): forward(): request finishes recv.
    model_runner_output = copy.deepcopy(EMPTY_MODEL_RUNNER_OUTPUT)
    model_runner_output.kv_connector_output = KVConnectorOutput(
        finished_recving={request_id}
    )

    # (2c): update_from_output():
    engine_core_outputs = scheduler.update_from_output(
        scheduler_output, model_runner_output
    )
    assert len(scheduler.waiting) == 1
    assert request_id in scheduler.finished_recving_kv_req_ids

    # STEP (3):
    # (3a): schedule(): this should actually schedule.
    scheduler_output = scheduler.schedule()
    assert len(scheduler.running) == 1

    # Confirm the block are actually allocated.
    num_hashed_blocks = 0
    blocks = scheduler.kv_cache_manager.coordinator.single_type_managers[
        0
    ].req_to_blocks[request_id]
    for block in blocks:
        assert block.ref_cnt == 1
        num_hashed_blocks += 1 if block._block_hash is not None else 0
    # Only the full blocks get hashed; the trailing half block does not.
    assert num_hashed_blocks == NUM_EXTERNAL_FULL_BLOCKS

    # Confirm the rest of the prompt is scheduled in this step.
    scheduled_req = scheduler_output.scheduled_new_reqs[0]
    num_scheduled_tokens = scheduler_output.num_scheduled_tokens[request_id]
    num_computed_tokens = scheduled_req.num_computed_tokens
    total_prompt_tokens = len(scheduled_req.prompt_token_ids)
    assert num_scheduled_tokens == total_prompt_tokens - num_computed_tokens

    # (3b): execute_model()
    model_runner_output = create_model_runner_output([request])
    # (3c): update_from_output()
    scheduler.update_from_output(scheduler_output, model_runner_output)

    # Step (4): Hit EOS.
    scheduler_output = scheduler.schedule()
    model_runner_output = create_model_runner_output([request], use_eos=True)
    engine_core_outputs = scheduler.update_from_output(
        scheduler_output, model_runner_output
    )
    # Extra schedule() lets the scheduler reap the finished request.
    scheduler.schedule()

    outputs = engine_core_outputs[0].outputs
    assert len(outputs) == 1
    output = outputs[0]
    assert output.finish_reason == FinishReason.STOP
    assert_scheduler_empty(scheduler)
||||
|
||||
|
||||
def test_interleaved_lifecycle():
    """Test Remote Prefills Work Well With Other Requests.

    Interleaves one remote-prefill request with two ordinary local
    requests and checks, step by step, that the remote request sits in
    the waiting queue (without blocking the locals) until its KVs arrive,
    then joins the running set, and finally that everything frees cleanly.
    """

    vllm_config = create_vllm_config()
    scheduler = create_scheduler(vllm_config)

    # 2 Full Blocks and 1 Half Block.
    BLOCK_SIZE = vllm_config.cache_config.block_size
    NUM_EXTERNAL_FULL_BLOCKS = 2
    NUM_TOKENS = int(BLOCK_SIZE * (NUM_EXTERNAL_FULL_BLOCKS + 0.5))

    request_remote = create_request(
        request_id=1,
        block_size=BLOCK_SIZE,
        num_tokens=NUM_TOKENS,
        do_remote_prefill=True,
    )
    request_local_a = create_request(
        request_id=2,
        block_size=BLOCK_SIZE,
        num_tokens=NUM_TOKENS,
    )
    request_local_b = create_request(
        request_id=3,
        block_size=BLOCK_SIZE,
        num_tokens=NUM_TOKENS,
    )

    # STEP 1: Regular request is running.
    scheduler.add_request(request_local_a)
    scheduler_output = scheduler.schedule()
    assert len(scheduler.running) == 1

    model_runner_output = create_model_runner_output([request_local_a])
    scheduler.update_from_output(scheduler_output, model_runner_output)

    # STEP 2: Add a local and remote request.
    # Local b is newly scheduled; remote waits for its KVs.
    scheduler.add_request(request_local_b)
    scheduler.add_request(request_remote)
    scheduler_output = scheduler.schedule()
    assert len(scheduler.running) == 2
    assert len(scheduler.waiting) == 1
    assert len(scheduler_output.scheduled_new_reqs) == 1
    assert scheduler_output.scheduled_cached_reqs.num_reqs == 1

    model_runner_output = create_model_runner_output([request_local_a, request_local_b])
    scheduler.update_from_output(scheduler_output, model_runner_output)

    # STEP 3: continue running, KVs not arrived yet.
    scheduler_output = scheduler.schedule()
    assert len(scheduler.running) == 2
    assert len(scheduler.waiting) == 1
    assert len(scheduler_output.scheduled_new_reqs) == 0
    assert scheduler_output.scheduled_cached_reqs.num_reqs == 2

    model_runner_output = create_model_runner_output(
        reqs=[request_local_a, request_local_b]
    )
    scheduler.update_from_output(scheduler_output, model_runner_output)
    assert len(scheduler.running) == 2
    assert len(scheduler.waiting) == 1
    assert len(scheduler_output.scheduled_new_reqs) == 0
    assert scheduler_output.scheduled_cached_reqs.num_reqs == 2

    # STEP 4: KVs arrive.
    # The remote request is still only waiting this step; it becomes
    # schedulable on the NEXT schedule() call.
    scheduler_output = scheduler.schedule()
    assert len(scheduler.running) == 2
    assert len(scheduler.waiting) == 1
    assert len(scheduler_output.scheduled_new_reqs) == 0
    assert scheduler_output.scheduled_cached_reqs.num_reqs == 2

    model_runner_output = create_model_runner_output(
        [request_local_a, request_local_b], finished_recving={request_remote.request_id}
    )
    scheduler.update_from_output(scheduler_output, model_runner_output)

    # STEP 5: RECVed KVs are sent to ModelRunner.
    scheduler_output = scheduler.schedule()
    assert len(scheduler.running) == 3
    assert len(scheduler.waiting) == 0
    assert len(scheduler_output.scheduled_new_reqs) == 1
    assert scheduler_output.scheduled_cached_reqs.num_reqs == 2

    model_runner_output = create_model_runner_output(
        [request_local_a, request_local_b, request_remote]
    )
    scheduler.update_from_output(scheduler_output, model_runner_output)

    # STEP 6: Hit EOS and free.
    scheduler_output = scheduler.schedule()
    model_runner_output = create_model_runner_output(
        [request_local_a, request_local_b, request_remote],
        use_eos=True,
    )
    scheduler.update_from_output(scheduler_output, model_runner_output)
    scheduler.schedule()
    assert_scheduler_empty(scheduler)
||||
|
||||
|
||||
def test_no_spurious_prefix_caching():
    """
    With P/D, blocks can be allocated but uncomputed for
    multiple engine steps. This test confirms that we do
    not accidentally have cache hits against uncomputed
    blocks.

    Both requests share an identical all-ones prompt, so a buggy prefix
    cache would let the local request "hit" the remote request's
    still-uncomputed blocks.
    """

    # NOTE: the original test built the config/scheduler twice back to
    # back; one pair is sufficient.
    vllm_config = create_vllm_config()
    scheduler = create_scheduler(vllm_config)

    # 2 and a half full external blocks.
    BLOCK_SIZE = vllm_config.cache_config.block_size
    NUM_EXTERNAL_FULL_BLOCKS = 2
    NUM_TOKENS = int(BLOCK_SIZE * (NUM_EXTERNAL_FULL_BLOCKS + 0.5))

    # Both of these requests have prompts like [1,1,1,1,1, ...]
    request_remote = create_request(
        request_id=1,
        block_size=BLOCK_SIZE,
        num_tokens=NUM_TOKENS,
        common_prefix_len=NUM_TOKENS,
        do_remote_prefill=True,
    )

    request_local = create_request(
        request_id=2,
        block_size=BLOCK_SIZE,
        num_tokens=NUM_TOKENS,
        common_prefix_len=NUM_TOKENS,
        do_remote_prefill=False,
    )

    # Schedule the remote prefill request. This should not
    # cause any blocks to be cached.
    scheduler.add_request(request_remote)
    scheduler_output = scheduler.schedule()
    scheduler.update_from_output(scheduler_output, EMPTY_MODEL_RUNNER_OUTPUT)
    assert len(scheduler.waiting) == 1

    # Schedule the local prefill request. This should
    # cause blocks to be cached, but separately from the
    # remote request's (uncomputed, identical-prefix) blocks.
    scheduler.add_request(request_local)
    scheduler_output = scheduler.schedule()
    assert len(scheduler.running) == 1
    assert len(scheduler.waiting) == 1

    local_blocks = scheduler.kv_cache_manager.coordinator.single_type_managers[
        0
    ].req_to_blocks[request_local.request_id]
    remote_blocks = scheduler.kv_cache_manager.coordinator.single_type_managers[
        0
    ].req_to_blocks[request_remote.request_id]

    # Local should have cached blocks (but not all due to preallocate).
    num_hashed_blocks = 0
    for block in local_blocks:
        assert block.ref_cnt == 1
        num_hashed_blocks += 1 if block._block_hash is not None else 0
    assert num_hashed_blocks > 0

    # Remote blocks should not be cached.
    for block in remote_blocks:
        assert block.ref_cnt == 1
        assert block._block_hash is None
||||
|
||||
|
||||
def test_full_block_prompt():
    """Test that we handle a prompt that is the full block size.

    When the remote prefill covers the entire prompt exactly, the last
    prompt token must still be recomputed locally to produce the first
    generated token — without allocating any extra block.
    """

    vllm_config = create_vllm_config()
    scheduler = create_scheduler(vllm_config)

    # Prompt exactly fills 2 full blocks (no partial trailing block).
    BLOCK_SIZE = vllm_config.cache_config.block_size
    NUM_EXTERNAL_FULL_BLOCKS = 2
    NUM_TOKENS = int(BLOCK_SIZE * NUM_EXTERNAL_FULL_BLOCKS)

    request = create_request(
        request_id=1,
        block_size=BLOCK_SIZE,
        num_tokens=NUM_TOKENS,
        do_remote_prefill=True,
    )

    scheduler.add_request(request)
    request_id = request.request_id

    # STEP (1): Initialize a recv.
    scheduler_output = scheduler.schedule()
    # All blocks should be allocated.
    num_blocks = len(
        scheduler.kv_cache_manager.coordinator.single_type_managers[0].req_to_blocks[
            request_id
        ]
    )
    assert num_blocks == NUM_EXTERNAL_FULL_BLOCKS
    model_runner_output = EMPTY_MODEL_RUNNER_OUTPUT
    scheduler.update_from_output(scheduler_output, model_runner_output)

    # STEP (2): Recv.
    scheduler_output = scheduler.schedule()
    model_runner_output = copy.deepcopy(EMPTY_MODEL_RUNNER_OUTPUT)
    model_runner_output.kv_connector_output = KVConnectorOutput(
        finished_recving={request_id}
    )
    scheduler.update_from_output(scheduler_output, model_runner_output)
    assert len(scheduler.waiting) == 1
    assert request_id in scheduler.finished_recving_kv_req_ids

    # STEP (3): Run as usual.
    scheduler_output = scheduler.schedule()

    # We need to recompute the final token of the prompt to generate
    # the first new token, so we should not have a new block.
    num_blocks = len(
        scheduler.kv_cache_manager.coordinator.single_type_managers[0].req_to_blocks[
            request_id
        ]
    )
    assert num_blocks == NUM_EXTERNAL_FULL_BLOCKS
    assert scheduler_output.scheduled_new_reqs[0].num_computed_tokens == NUM_TOKENS - 1
    assert scheduler_output.num_scheduled_tokens[request_id] == 1

    model_runner_output = create_model_runner_output([request])
    scheduler.update_from_output(scheduler_output, model_runner_output)

    # Step (4): Hit EOS.
    scheduler_output = scheduler.schedule()
    model_runner_output = create_model_runner_output([request], use_eos=True)
    engine_core_outputs = scheduler.update_from_output(
        scheduler_output, model_runner_output
    )
    # Extra schedule() lets the scheduler reap the finished request.
    scheduler.schedule()

    outputs = engine_core_outputs[0].outputs
    assert len(outputs) == 1
    output = outputs[0]
    assert output.finish_reason == FinishReason.STOP
    assert_scheduler_empty(scheduler)
||||
|
||||
|
||||
def test_cannot_schedule_after_recv():
    """
    Test that we can handle no schedule after recv due to not
    enough remaining KV blocks.

    The pool is sized so the remote request's KVs fit, but once it needs
    one more block for its first generated token there is none left: it
    must be preempted back to waiting and resume only after the other
    request frees its blocks.
    """

    # NOTE: the KVCacheManager will use 1 null block.
    # So there are 5 total working blocks.
    TOTAL_NUM_BLOCKS = 6
    vllm_config = create_vllm_config()
    scheduler = create_scheduler(vllm_config, num_blocks=TOTAL_NUM_BLOCKS)

    # Prime the KVCache.
    NUM_PROMPT_BLOCKS = 2
    BLOCK_SIZE = vllm_config.cache_config.block_size
    # Prompt will use 2 blocks + 1 block after we schedule.
    NUM_TOKENS_LOCAL = int(BLOCK_SIZE * NUM_PROMPT_BLOCKS)
    NUM_TOKENS_REMOTE = int(BLOCK_SIZE * NUM_PROMPT_BLOCKS)

    request_normal = create_request(
        request_id=1, block_size=BLOCK_SIZE, num_tokens=NUM_TOKENS_LOCAL
    )
    request_remote = create_request(
        request_id=2,
        block_size=BLOCK_SIZE,
        num_tokens=NUM_TOKENS_REMOTE,
        do_remote_prefill=True,
    )

    # STEP 1: 3 blocks are in use (2 for prompt, 1 for decode).
    scheduler.add_request(request_normal)
    scheduler_output = scheduler.schedule()
    model_runner_output = create_model_runner_output(reqs=[request_normal])
    scheduler.update_from_output(scheduler_output, model_runner_output)
    assert len(scheduler.running) == 1
    assert len(scheduler.waiting) == 0

    # Step 2: 5 blocks are in use (2 new for remote blocks).
    scheduler.add_request(request_remote)
    scheduler_output = scheduler.schedule()
    model_runner_output = create_model_runner_output(reqs=[request_normal])
    scheduler.update_from_output(scheduler_output, model_runner_output)
    assert len(scheduler.running) == 1
    assert len(scheduler.waiting) == 1

    # Step 3: finish recving (5 blocks in use)
    scheduler_output = scheduler.schedule()
    model_runner_output = create_model_runner_output(
        reqs=[request_normal], finished_recving={request_remote.request_id}
    )
    scheduler.update_from_output(scheduler_output, model_runner_output)
    assert len(scheduler.running) == 1
    assert len(scheduler.waiting) == 1

    # Step 4: try to schedule, remote request is put to running list
    # because the transfer is completed.
    scheduler_output = scheduler.schedule()
    model_runner_output = create_model_runner_output(
        reqs=[request_normal, request_remote]
    )
    scheduler.update_from_output(scheduler_output, model_runner_output)
    assert len(scheduler.running) == 2
    assert len(scheduler.waiting) == 0

    # Step 5: Remote request will be put back to waiting list
    # because it needs new block to hold generated token.
    scheduler_output = scheduler.schedule()
    model_runner_output = create_model_runner_output(reqs=[request_normal])
    scheduler.update_from_output(scheduler_output, model_runner_output)
    assert len(scheduler.running) == 1
    assert len(scheduler.waiting) == 1

    # Step 6: finish the request, free it.
    scheduler_output = scheduler.schedule()
    model_runner_output = create_model_runner_output(
        reqs=[request_normal], use_eos=True
    )
    scheduler.update_from_output(scheduler_output, model_runner_output)
    assert len(scheduler.running) == 0
    assert len(scheduler.waiting) == 1

    # Step 7: now we can schedule (with 2 blocks computed),
    # request is retrieved from preempted list.
    scheduler_output = scheduler.schedule()
    model_runner_output = create_model_runner_output(reqs=[request_remote])
    # The received prompt blocks count as already computed.
    assert (
        scheduler_output.scheduled_cached_reqs.num_computed_tokens[0]
        == NUM_PROMPT_BLOCKS * BLOCK_SIZE
    )
    scheduler.update_from_output(scheduler_output, model_runner_output)
    assert len(scheduler.running) == 1
    assert len(scheduler.waiting) == 0

    # Step 8: free everything.
    scheduler_output = scheduler.schedule()
    model_runner_output = create_model_runner_output(
        reqs=[request_remote], use_eos=True
    )
    scheduler.update_from_output(scheduler_output, model_runner_output)
    _ = scheduler.schedule()
    assert_scheduler_empty(scheduler)
||||
|
||||
|
||||
def test_cannot_recv():
    """
    Test that we can handle no schedule KV block transfer due to not
    enough remaining KV blocks.

    Here the pool is too small even to ALLOCATE the remote request's KV
    blocks, so the transfer itself must be deferred (request stays out of
    WAITING_FOR_REMOTE_KVS) until blocks are freed.
    """

    # NOTE: the KVCacheManager will use 1 null block.
    # So there are 5 total working blocks.
    TOTAL_NUM_BLOCKS = 6
    vllm_config = create_vllm_config()
    scheduler = create_scheduler(vllm_config, num_blocks=TOTAL_NUM_BLOCKS)

    # Prime the KVCache.
    NUM_PROMPT_BLOCKS = 2
    BLOCK_SIZE = vllm_config.cache_config.block_size
    # Prompt will use 2 blocks + 1 block after we schedule.
    NUM_TOKENS_LOCAL = int(BLOCK_SIZE * NUM_PROMPT_BLOCKS)
    # Remote prompt needs 3 blocks (2.5 rounded up).
    NUM_TOKENS_REMOTE = int(BLOCK_SIZE * (NUM_PROMPT_BLOCKS + 0.5))

    request_normal = create_request(
        request_id=1, block_size=BLOCK_SIZE, num_tokens=NUM_TOKENS_LOCAL
    )
    request_remote = create_request(
        request_id=2,
        block_size=BLOCK_SIZE,
        num_tokens=NUM_TOKENS_REMOTE,
        do_remote_prefill=True,
    )

    # STEP 1: 3 blocks are in use (2 for prompt, 1 for decode).
    scheduler.add_request(request_normal)
    scheduler_output = scheduler.schedule()
    model_runner_output = create_model_runner_output(reqs=[request_normal])
    scheduler.update_from_output(scheduler_output, model_runner_output)
    assert len(scheduler.running) == 1
    assert len(scheduler.waiting) == 0

    # Step 2: 3 blocks are in use,
    # need 3 new for remote blocks but only 2 are available.
    scheduler.add_request(request_remote)
    scheduler_output = scheduler.schedule()
    model_runner_output = create_model_runner_output(reqs=[request_normal])
    scheduler.update_from_output(scheduler_output, model_runner_output)
    assert len(scheduler.running) == 1
    assert len(scheduler.waiting) == 1
    # Should not have KV transfer in progress.
    assert request_remote.status != RequestStatus.WAITING_FOR_REMOTE_KVS

    # Step 3: finish the request, free it.
    scheduler_output = scheduler.schedule()
    model_runner_output = create_model_runner_output(
        reqs=[request_normal], use_eos=True
    )
    scheduler.update_from_output(scheduler_output, model_runner_output)
    assert len(scheduler.running) == 0
    assert len(scheduler.waiting) == 1

    # Step 4: now we can initiate KV transfer (with 2 blocks computed).
    scheduler_output = scheduler.schedule()
    model_runner_output = create_model_runner_output(reqs=[])
    scheduler.update_from_output(scheduler_output, model_runner_output)
    assert len(scheduler.running) == 0
    assert len(scheduler.waiting) == 1
    assert request_remote.status == RequestStatus.WAITING_FOR_REMOTE_KVS

    # Step 5: finish recving (5 blocks in use)
    scheduler_output = scheduler.schedule()
    model_runner_output = create_model_runner_output(
        reqs=[], finished_recving={request_remote.request_id}
    )
    scheduler.update_from_output(scheduler_output, model_runner_output)
    assert len(scheduler.running) == 0
    assert len(scheduler.waiting) == 1

    # Step 6: schedule remote request
    scheduler_output = scheduler.schedule()
    model_runner_output = create_model_runner_output(reqs=[request_remote])
    scheduler.update_from_output(scheduler_output, model_runner_output)
    assert len(scheduler.running) == 1
    assert len(scheduler.waiting) == 0

    # Step 7: free everything.
    scheduler_output = scheduler.schedule()
    model_runner_output = create_model_runner_output(
        reqs=[request_remote], use_eos=True
    )
    scheduler.update_from_output(scheduler_output, model_runner_output)
    _ = scheduler.schedule()
    assert_scheduler_empty(scheduler)
||||
402
tests/v1/kv_connector/unit/utils.py
Normal file
402
tests/v1/kv_connector/unit/utils.py
Normal file
@@ -0,0 +1,402 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import tempfile
|
||||
from collections import defaultdict
|
||||
from collections.abc import Callable
|
||||
from dataclasses import dataclass
|
||||
from itertools import chain, count
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
|
||||
from vllm import SamplingParams
|
||||
from vllm.config import (
|
||||
CacheConfig,
|
||||
DeviceConfig,
|
||||
KVTransferConfig,
|
||||
ModelConfig,
|
||||
SchedulerConfig,
|
||||
VllmConfig,
|
||||
)
|
||||
from vllm.distributed.kv_transfer.kv_connector.factory import KVConnectorFactory
|
||||
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
|
||||
KVConnectorBase_V1,
|
||||
KVConnectorMetadata,
|
||||
KVConnectorRole,
|
||||
)
|
||||
from vllm.distributed.kv_transfer.kv_connector.v1.example_connector import ( # noqa
|
||||
ExampleConnector,
|
||||
)
|
||||
from vllm.utils.hashing import sha256
|
||||
from vllm.v1.core.kv_cache_manager import KVCacheBlocks
|
||||
from vllm.v1.core.kv_cache_utils import get_request_block_hasher, init_none_hash
|
||||
from vllm.v1.core.sched.scheduler import Scheduler, SchedulerOutput
|
||||
from vllm.v1.kv_cache_interface import (
|
||||
FullAttentionSpec,
|
||||
KVCacheConfig,
|
||||
KVCacheGroupSpec,
|
||||
)
|
||||
from vllm.v1.outputs import KVConnectorOutput, ModelRunnerOutput
|
||||
from vllm.v1.request import Request
|
||||
from vllm.v1.structured_output import StructuredOutputManager
|
||||
|
||||
EOS_TOKEN_ID = 50256
|
||||
|
||||
|
||||
def assert_scheduler_empty(scheduler: Scheduler):
    """Verify the scheduler holds no leftover request state - i.e. no leaks."""
    # Scheduler-level request bookkeeping must be fully drained.
    assert not scheduler.requests
    assert not scheduler.waiting
    assert not scheduler.running
    assert not scheduler.finished_req_ids
    assert not scheduler.finished_recving_kv_req_ids

    # Encoder cache must be drained as well.
    assert not scheduler.encoder_cache_manager.freed
    assert not scheduler.encoder_cache_manager.cached

    # KV cache manager: no per-request block mappings remain.
    kv_manager = scheduler.kv_cache_manager
    type_manager = kv_manager.coordinator.single_type_managers[0]
    assert not type_manager.req_to_blocks
    assert not type_manager.num_cached_block

    # Every working block (all but the single null block) is back in
    # the free queue.
    pool = kv_manager.block_pool
    assert pool.free_block_queue.num_free_blocks == pool.num_gpu_blocks - 1

    # NOTE(rob): just the ref count on blocks will be 0. The hash
    # value, etc will remain since we lazily evict for prefix cache.
    for block in pool.blocks:
        assert block.ref_cnt == 0
||||
|
||||
|
||||
def create_vllm_config(
    model: str = "facebook/opt-125m",
    max_num_seqs: int = 16,
    max_num_batched_tokens: int = 64,
    block_size: int = 16,
    max_model_len: int = 10000,
    enable_chunked_prefill: bool = True,
    enable_permute_local_kv: bool = False,
    kv_connector_extra_config: dict[str, Any] | None = None,
    dtype: str = "float16",
    cache_dtype: str = "auto",
    hf_overrides: dict[str, Any] | None = None,
) -> VllmConfig:
    """Initialize VllmConfig For Testing.

    Assembles model/scheduler/cache/KV-transfer sub-configs on the CPU
    device, with prefix caching enabled and a NixlConnector in kv_both
    role.
    """
    model_cfg = ModelConfig(
        model=model,
        trust_remote_code=True,
        dtype=dtype,
        seed=42,
        hf_overrides=hf_overrides or {},
    )

    sched_cfg = SchedulerConfig(
        max_num_seqs=max_num_seqs,
        max_num_batched_tokens=max_num_batched_tokens,
        max_model_len=max_model_len,
        enable_chunked_prefill=enable_chunked_prefill,
        is_encoder_decoder=model_cfg.is_encoder_decoder,
    )

    # Cache config, optionally force APC
    cache_cfg = CacheConfig(
        block_size=block_size,
        gpu_memory_utilization=0.9,
        swap_space=0,
        cache_dtype=cache_dtype,
        enable_prefix_caching=True,
    )

    kv_cfg = KVTransferConfig(
        kv_connector="NixlConnector",
        kv_role="kv_both",
        enable_permute_local_kv=enable_permute_local_kv,
        kv_connector_extra_config=kv_connector_extra_config or {},
    )

    return VllmConfig(
        scheduler_config=sched_cfg,
        model_config=model_cfg,
        cache_config=cache_cfg,
        kv_transfer_config=kv_cfg,
        device_config=DeviceConfig("cpu"),
    )
||||
|
||||
|
||||
def create_scheduler(
    vllm_config: VllmConfig,
    num_blocks: int = 10000,
) -> Scheduler:
    """Initialize Scheduler For Testing.

    Builds a single full-attention KV cache group sized to *num_blocks*
    (a large default so tests rarely hit capacity limits).
    """
    block_size = vllm_config.cache_config.block_size
    attention_spec = FullAttentionSpec(block_size, 1, 1, torch.float32, False)
    kv_cache_config = KVCacheConfig(
        num_blocks=num_blocks,
        kv_cache_tensors=[],
        kv_cache_groups=[KVCacheGroupSpec(["layer"], attention_spec)],
    )
    # Keep the cache config in sync with the KV cache layout.
    vllm_config.cache_config.num_gpu_blocks = num_blocks
    return Scheduler(
        vllm_config=vllm_config,
        kv_cache_config=kv_cache_config,
        log_stats=True,
        structured_output_manager=StructuredOutputManager(vllm_config),
        block_size=block_size,
    )
||||
|
||||
|
||||
# Monotonic source of default request ids for create_request().
_request_count = count(1)
# One-shot guard: init_none_hash() must run once before any block hashing.
_none_hash_initialized = False
||||
|
||||
|
||||
def create_request(
    request_id: int | None = None,
    num_tokens: int = 10,
    common_prefix_len: int = 0,
    max_tokens: int = 16,
    do_remote_decode: bool = False,
    do_remote_prefill: bool = False,
    num_remote_blocks: int = 3,
    block_size: int = 16,
    hash_fn: Callable = sha256,
) -> Request:
    """Make dummy request for testing.

    Args:
        request_id: Numeric id; drawn from a global counter when None.
        num_tokens: Total prompt length.
        common_prefix_len: Leading prompt tokens that are all 1s, shared
            across requests (for prefix-caching tests).
        max_tokens: Sampling budget; forced to 1 for remote-decode requests.
        do_remote_decode / do_remote_prefill: Mutually exclusive P/D roles;
            populate the corresponding kv_transfer_params.
        num_remote_blocks: How many remote block ids to advertise for a
            remote-prefill request.
        block_size: Block size used by the request's block hasher.
        hash_fn: Hash function for block hashing.
    """
    assert num_tokens >= common_prefix_len >= 0

    if request_id is None:
        request_id = next(_request_count)

    # Lazily run the one-time hash initialization before any hashing.
    global _none_hash_initialized
    if not _none_hash_initialized:
        init_none_hash(hash_fn)
        _none_hash_initialized = True

    kv_transfer_params: dict[str, Any] | None = None

    if do_remote_decode:
        assert not do_remote_prefill
        kv_transfer_params = dict(do_remote_prefill=False, do_remote_decode=True)
    elif do_remote_prefill:
        kv_transfer_params = dict(
            do_remote_prefill=True,
            do_remote_decode=False,
            remote_engine_id="my-engine-id",
            remote_request_id=f"prefill-{request_id}",
            remote_block_ids=list(range(num_remote_blocks)),
            remote_host="my-host",
            remote_port=1234,
        )

    max_tokens = 1 if do_remote_decode else max_tokens
    sampling_params = SamplingParams(max_tokens=max_tokens)

    # Prompt = shared all-ones prefix + a suffix scaled by request_id so
    # different requests get different (non-prefix-matching) suffixes.
    common_prefix = [1] * common_prefix_len if common_prefix_len > 0 else []
    suffix = [i * request_id for i in range(num_tokens - common_prefix_len)]
    prompt_token_ids = common_prefix + suffix

    req = Request(
        request_id=f"id-{request_id}",
        prompt_token_ids=prompt_token_ids,
        sampling_params=sampling_params,
        pooling_params=None,
        mm_features=None,
        eos_token_id=EOS_TOKEN_ID,
        block_hasher=get_request_block_hasher(block_size, hash_fn),
    )
    req.kv_transfer_params = kv_transfer_params
    return req
||||
|
||||
|
||||
def create_model_runner_output(
    reqs: list[Request],
    finished_sending: set[str] | None = None,
    finished_recving: set[str] | None = None,
    invalid_block_ids: set[int] | None = None,
    use_eos: bool = False,
    token_id: int = 0,
) -> ModelRunnerOutput:
    """Make dummy model runner output for testing.

    Every request in *reqs* "samples" exactly one token (EOS when
    use_eos is set). Connector output is attached only when any of the
    transfer-related arguments was supplied.
    """
    # Per-request bookkeeping.
    req_ids = [r.request_id for r in reqs]
    index_of = {rid: pos for pos, rid in enumerate(req_ids)}

    # One sampled token per request.
    next_token = EOS_TOKEN_ID if use_eos else token_id
    sampled = [[next_token] for _ in req_ids]

    # Only build a KVConnectorOutput if there is transfer info to report.
    no_transfer_info = (
        finished_sending is None
        and finished_recving is None
        and invalid_block_ids is None
    )
    if no_transfer_info:
        connector_output = None
    else:
        connector_output = KVConnectorOutput(
            finished_sending=finished_sending,
            finished_recving=finished_recving,
            invalid_block_ids=invalid_block_ids or set(),
        )

    return ModelRunnerOutput(
        req_ids=req_ids,
        req_id_to_index=index_of,
        sampled_token_ids=sampled,
        logprobs=None,
        prompt_logprobs_dict={},
        pooler_output=None,
        kv_connector_output=connector_output,
    )
||||
|
||||
|
||||
class TestExampleConnector(ExampleConnector):
    """Wrapper around ExampleConnector that records every interface call.

    Each intercepted call increments a counter and appends a line to a
    per-connector temp file, so the event sequence can be read back in
    the main test process (the connector itself may live in a worker).
    """

    def __init__(self, config: VllmConfig, role, kv_cache_config):
        # Human-readable connector name, supplied via the extra config.
        self.name = config.kv_transfer_config.kv_connector_extra_config["name"]
        # The real connector everything is delegated to.
        self._connector = ExampleConnector(config, role)
        # Per-method call counters.
        self.call_record: dict[str, int] = defaultdict(int)
        # Use a unique temp file per connector
        # NOTE(review): `self.role` is not set here directly — it resolves
        # through __getattribute__ to the wrapped connector's attribute.
        self._event_file = (
            tempfile.gettempdir()
            + f"/connector_{self.name}-{self.role.name}_events.log"
        )
        # Start with an empty file
        with open(self._event_file, "w") as _:
            pass

    def __getattribute__(self, name):
        # Own attributes (and dunders needed to bootstrap) bypass the
        # delegation below to avoid infinite recursion.
        if name in (
            "_connector",
            "call_record",
            "name",
            "_event_file",
            "__class__",
            "__dict__",
            "__getattribute__",
            "__init__",
        ):  # avoid recursion
            return object.__getattribute__(self, name)
        # Names the wrapped connector lacks fall back to this object.
        if not hasattr(self._connector, name):
            return object.__getattribute__(self, name)
        attr = getattr(self._connector, name)

        # Intercept calls to the connector interface and write an event
        # for each one to a file, which can be read back in the main test proc.
        if callable(attr):

            def wrapper(*args, **kwargs):
                self.call_record[name] += 1

                # Include args that we're interested in
                to_log = [name]
                for arg in args:
                    if isinstance(arg, int):
                        to_log.append(str(arg))
                    elif isinstance(arg, KVCacheBlocks):
                        to_log.append(f"num_blocks={[len(b) for b in arg.blocks]}")

                # Log the event as a line to the file
                try:
                    with open(self._event_file, "a") as f:
                        f.write(" ".join(to_log) + "\n")
                except Exception as e:
                    print(f"[ERROR] Could not log event {name} for {self.name}: {e}")
                return attr(*args, **kwargs)

            return wrapper
        return attr
||||
|
||||
|
||||
@dataclass(frozen=True)
class MockKVConfig:
    """Immutable configuration for MockKVConnector, built from the
    kv_connector_extra_config of the test's kv_transfer_config."""

    # Number of externally-matched tokens get_num_new_matched_tokens reports.
    matched_tokens: int = 0
    # Whether the connector reports the external KV load as asynchronous.
    is_async: bool = False
|
||||
|
||||
|
||||
class MockKVConnectorMetadata(KVConnectorMetadata):
    """Bare-bones connector metadata; scheduler tests read ``requests``."""

    def __init__(self):
        # One {"req_id": ...} entry is appended per scheduled request.
        self.requests: list = list()
|
||||
|
||||
|
||||
class MockKVConnector(KVConnectorBase_V1):
    """Mock KV connector for scheduler tests, supporting both sync and async mode."""

    def __init__(
        self,
        vllm_config: VllmConfig,
        role: KVConnectorRole,
        kv_cache_config: KVCacheConfig | None = None,
    ):
        super().__init__(vllm_config, role, kv_cache_config)
        extra_config = self._kv_transfer_config.kv_connector_extra_config
        # Honor MockKVConfig's declared defaults when a key is absent from
        # the test's extra config, instead of raising KeyError.
        self.config = MockKVConfig(
            matched_tokens=extra_config.get("matched_tokens", 0),
            is_async=extra_config.get("is_async", False),
        )

    def get_num_new_matched_tokens(
        self,
        request: Request,
        num_computed_tokens: int,
    ) -> tuple[int | None, bool]:
        """Report the configured external-match token count and whether
        the load is asynchronous."""
        return (self.config.matched_tokens, self.config.is_async)

    def update_state_after_alloc(
        self,
        request: Request,
        blocks: KVCacheBlocks,
        num_external_tokens: int,
    ):
        """No-op: the mock keeps no per-request allocation state."""
        pass

    def build_connector_meta(
        self, scheduler_output: SchedulerOutput
    ) -> KVConnectorMetadata:
        """Record each newly scheduled or resumed request id in the metadata.

        Scheduler tests inspect ``metadata.requests`` to verify which
        requests the connector was asked to handle this step.
        """
        metadata = MockKVConnectorMetadata()
        cached_reqs = scheduler_output.scheduled_cached_reqs
        for req_id in chain(
            (req.req_id for req in scheduler_output.scheduled_new_reqs),
            (
                req_id
                for req_id in cached_reqs.req_ids
                if req_id in cached_reqs.resumed_req_ids
            ),
        ):
            metadata.requests.append({"req_id": req_id})
        return metadata

    def start_load_kv(self, kv_caches, finished_req_ids):
        """No-op: nothing to load in the mock."""
        pass

    def wait_for_layer_load(self, layer_name):
        """No-op."""
        pass

    def save_kv_layer(self, layer_name, kv_layer, attn_metadata, **kwargs):
        """No-op: nothing to save in the mock."""
        pass

    def wait_for_save(self):
        """No-op."""
        pass
|
||||
|
||||
|
||||
# Register the test connectors with the factory so configs in the tests
# can select them by name; the factory resolves the class in this module.
KVConnectorFactory.register_connector(
    "TestExampleConnector", __name__, "TestExampleConnector"
)
KVConnectorFactory.register_connector("MockKVConnector", __name__, "MockKVConnector")
|
||||
Reference in New Issue
Block a user