Sync from v0.13

This commit is contained in:
2026-01-19 10:38:50 +08:00
parent b2ef04d792
commit 5aef6c175a
3714 changed files with 854317 additions and 89342 deletions

View File

View File

@@ -0,0 +1,257 @@
#!/bin/bash
set -xe
# Parse command line arguments
KV_BUFFER_DEVICE="cuda" # Default to cuda
while [[ $# -gt 0 ]]; do
case $1 in
--kv_buffer_device)
KV_BUFFER_DEVICE="$2"
shift 2
;;
*)
echo "Unknown option $1"
echo "Usage: $0 [--kv_buffer_device <cuda|cpu>]"
exit 1
;;
esac
done
echo "Running accuracy tests with kv_buffer_device=$KV_BUFFER_DEVICE"
DECODER_KV_LAYOUT=${DECODER_KV_LAYOUT:-"HND"} # Default to HND, optional NHD
if [[ "$DECODER_KV_LAYOUT" == "NHD" ]]; then
KV_CONFIG_HETERO_LAYOUT=',"enable_permute_local_kv":"True"'
else
KV_CONFIG_HETERO_LAYOUT=''
fi
# Build the kv-transfer-config once
if [[ "$KV_BUFFER_DEVICE" == "cuda" ]]; then
KV_CONFIG='{"kv_connector":"NixlConnector","kv_role":"kv_both"'${KV_CONFIG_HETERO_LAYOUT}'}'
else
KV_CONFIG="{\"kv_connector\":\"NixlConnector\",\"kv_role\":\"kv_both\",\"kv_buffer_device\":\"$KV_BUFFER_DEVICE\""${KV_CONFIG_HETERO_LAYOUT}"}"
fi
# Models to run
MODEL_NAMES=${MODEL_NAMES:-}
if [[ -n "$MODEL_NAMES" ]]; then
MODELS=("$MODEL_NAMES")
else
MODELS=(
"Qwen/Qwen3-0.6B"
)
fi
# Number of prefill and decode instances to create
NUM_PREFILL_INSTANCES=${NUM_PREFILL_INSTANCES:-1} # Default to 1
NUM_DECODE_INSTANCES=${NUM_DECODE_INSTANCES:-1} # Default to 1
PREFILLER_TP_SIZE=${PREFILLER_TP_SIZE:-1}
DECODER_TP_SIZE=${DECODER_TP_SIZE:-1}
GPU_MEMORY_UTILIZATION=${GPU_MEMORY_UTILIZATION:-0.2}
PREFILL_BLOCK_SIZE=${PREFILL_BLOCK_SIZE:-128}
DECODE_BLOCK_SIZE=${DECODE_BLOCK_SIZE:-128}
# Find the git repository root directory
GIT_ROOT=$(git rev-parse --show-toplevel)
SMI_BIN=$(which nvidia-smi || which rocm-smi || echo "")
# Trap the SIGINT signal (triggered by Ctrl+C)
trap 'kill $(jobs -pr)' SIGINT SIGTERM EXIT
# Waits for vLLM to start.
wait_for_server() {
local port=$1
timeout 1200 bash -c "
until curl -s localhost:${port}/v1/completions > /dev/null; do
sleep 1
done" && return 0 || return 1
}
# Function to clean up previous instances
cleanup_instances() {
echo "Cleaning up any running vLLM instances..."
pkill -f "vllm serve" || true
sleep 2
}
# Handle to get model-specific arguments for deepseek
get_model_args() {
local model_name=$1
local extra_args=""
if [[ "$model_name" == "deepseek-ai/deepseek-vl2-tiny" ]]; then
extra_args="--hf_overrides '{\"architectures\": [\"DeepseekVLV2ForCausalLM\"]}' --trust-remote-code"
fi
echo "$extra_args"
}
get_num_gpus() {
if [[ "$SMI_BIN" == *"nvidia"* ]]; then
echo "$($SMI_BIN --query-gpu=name --format=csv,noheader | wc -l)"
elif [[ "$SMI_BIN" == *"rocm"* ]]; then
echo "$($SMI_BIN -l | grep GPU | wc -l)"
else
# works for non-cuda platforms,
# assuming at least 1 device and
# let system to decide which card to use
echo "1"
fi
}
# Function to run tests for a specific model
run_tests_for_model() {
local model_name=$1
echo "================================"
echo "Testing model: $model_name"
echo "================================"
# Get model-specific arguments
local model_args=$(get_model_args "$model_name")
# Arrays to store all hosts and ports
PREFILL_HOSTS=()
PREFILL_PORTS=()
DECODE_HOSTS=()
DECODE_PORTS=()
# Start prefill instances
for i in $(seq 0 $((NUM_PREFILL_INSTANCES-1))); do
# Calculate GPU ID - we'll distribute across available GPUs
GPU_ID=$((i % $(get_num_gpus)))
NEXT_GPU=${GPU_ID}
# If PREFILLER_TP_SIZE is more than 1
for (( j=1; j < PREFILLER_TP_SIZE; j++ )); do
NEXT_GPU=$(((GPU_ID + j) % $(get_num_gpus)))
GPU_ID="${GPU_ID},${NEXT_GPU}"
done
# Calculate port number (base port + instance number)
PORT=$((8100 + i))
# Calculate side channel port. Avoid clash with with TP workers.
SIDE_CHANNEL_PORT=$((5559 + i))
echo "Starting prefill instance $i on GPU $GPU_ID, port $PORT"
# Build the command with or without model-specific args
BASE_CMD="CUDA_VISIBLE_DEVICES=$GPU_ID \
VLLM_KV_CACHE_LAYOUT='HND' \
UCX_NET_DEVICES=all \
VLLM_NIXL_SIDE_CHANNEL_PORT=$SIDE_CHANNEL_PORT \
vllm serve $model_name \
--port $PORT \
--enforce-eager \
--block-size ${PREFILL_BLOCK_SIZE} \
--gpu-memory-utilization $GPU_MEMORY_UTILIZATION \
--tensor-parallel-size $PREFILLER_TP_SIZE \
--kv-transfer-config '$KV_CONFIG'"
if [ -n "$model_args" ]; then
FULL_CMD="$BASE_CMD $model_args"
else
FULL_CMD="$BASE_CMD"
fi
eval "$FULL_CMD &"
# Store host and port for proxy configuration
PREFILL_HOSTS+=("localhost")
PREFILL_PORTS+=($PORT)
done
# Start decode instances
for i in $(seq 0 $((NUM_DECODE_INSTANCES-1))); do
# Calculate GPU ID - we'll distribute across available GPUs, starting from after prefill GPUs
GPU_ID=$(((i + NEXT_GPU + 1) % $(get_num_gpus)))
# If DECODER_TP_SIZE is more than 1
for (( j=1; j < DECODER_TP_SIZE; j++ )); do
NEXT_GPU=$(((GPU_ID + j) % $(get_num_gpus)))
GPU_ID="${GPU_ID},${NEXT_GPU}"
done
# Calculate port number (base port + instance number)
PORT=$((8200 + i))
# Calculate side channel port
SIDE_CHANNEL_PORT=$((5659 + i * $DECODER_TP_SIZE))
echo "Starting decode instance $i on GPU $GPU_ID, port $PORT"
# Build the command with or without model-specific args
BASE_CMD="CUDA_VISIBLE_DEVICES=$GPU_ID \
VLLM_KV_CACHE_LAYOUT=$DECODER_KV_LAYOUT \
UCX_NET_DEVICES=all \
VLLM_NIXL_SIDE_CHANNEL_PORT=$SIDE_CHANNEL_PORT \
vllm serve $model_name \
--port $PORT \
--enforce-eager \
--block-size ${DECODE_BLOCK_SIZE} \
--gpu-memory-utilization $GPU_MEMORY_UTILIZATION \
--kv-transfer-config '$KV_CONFIG'"
# DP-EP attention mode
if [[ -z "$DP_EP" ]]; then
BASE_CMD="${BASE_CMD} --tensor-parallel-size $DECODER_TP_SIZE"
else
echo "DP-EP Attention enabled, deploying with dp=DECODER_TP_SIZE and tp=1"
BASE_CMD="${BASE_CMD} --data-parallel-size $DECODER_TP_SIZE \
--tensor-parallel-size 1 --enable-expert-parallel"
fi
if [ -n "$model_args" ]; then
FULL_CMD="$BASE_CMD $model_args"
else
FULL_CMD="$BASE_CMD"
fi
eval "$FULL_CMD &"
# Store host and port for proxy configuration
DECODE_HOSTS+=("localhost")
DECODE_PORTS+=($PORT)
done
# Wait for all instances to start
for PORT in "${PREFILL_PORTS[@]}"; do
echo "Waiting for prefill instance on port $PORT to start..."
wait_for_server $PORT
done
for PORT in "${DECODE_PORTS[@]}"; do
echo "Waiting for decode instance on port $PORT to start..."
wait_for_server $PORT
done
# Build the command for the proxy server with all the hosts and ports
PROXY_CMD="python3 ${GIT_ROOT}/tests/v1/kv_connector/nixl_integration/toy_proxy_server.py --port 8192"
# Add all prefill hosts and ports
PROXY_CMD+=" --prefiller-hosts ${PREFILL_HOSTS[@]}"
PROXY_CMD+=" --prefiller-ports ${PREFILL_PORTS[@]}"
# Add all decode hosts and ports
PROXY_CMD+=" --decoder-hosts ${DECODE_HOSTS[@]}"
PROXY_CMD+=" --decoder-ports ${DECODE_PORTS[@]}"
# Start the proxy server
echo "Starting proxy server with command: $PROXY_CMD"
$PROXY_CMD &
# Wait for the proxy to start
sleep 5
# Run lm eval for this model
echo "Running tests for $model_name"
TEST_MODEL=$model_name python3 -m pytest -s -x ${GIT_ROOT}/tests/v1/kv_connector/nixl_integration/test_accuracy.py
# Clean up before running next model
cleanup_instances
sleep 3
}
# Run tests for each model
for model in "${MODELS[@]}"; do
run_tests_for_model "$model"
done
echo "All tests completed!"

View File

@@ -0,0 +1,148 @@
#!/bin/bash
set -xe
# Parse command line arguments
KV_BUFFER_DEVICE="cuda" # Default to cuda
PREFILL_GPU_ID=4 # Default GPU IDs
DECODE_GPU_ID=5
while [[ $# -gt 0 ]]; do
case $1 in
--kv_buffer_device)
KV_BUFFER_DEVICE="$2"
shift 2
;;
*)
echo "Unknown option $1"
echo "Usage: $0 [--kv_buffer_device <cuda|cpu>]"
exit 1
;;
esac
done
echo "Running edge case tests with kv_buffer_device=$KV_BUFFER_DEVICE (GPUs: $PREFILL_GPU_ID, $DECODE_GPU_ID)"
# Build the kv-transfer-config once
if [[ "$KV_BUFFER_DEVICE" == "cuda" ]]; then
KV_CONFIG='{"kv_connector":"NixlConnector","kv_role":"kv_both"}'
else
KV_CONFIG="{\"kv_connector\":\"NixlConnector\",\"kv_role\":\"kv_both\",\"kv_buffer_device\":\"$KV_BUFFER_DEVICE\"}"
fi
# Models to run
MODELS=(
"Qwen/Qwen3-0.6B"
)
# Find the git repository root directory
GIT_ROOT=$(git rev-parse --show-toplevel)
# Trap the SIGINT signal (triggered by Ctrl+C)
trap 'kill $(jobs -pr)' SIGINT SIGTERM EXIT
# Waits for vLLM to start.
wait_for_server() {
local port=$1
timeout 1200 bash -c "
until curl -s localhost:${port}/v1/completions > /dev/null; do
sleep 1
done" && return 0 || return 1
}
# Function to clean up previous instances
cleanup_instances() {
echo "Cleaning up any running vLLM instances..."
pkill -f "vllm serve" || true
sleep 2
}
# Handle to get model-specific arguments for deepseek
get_model_args() {
local model_name=$1
local extra_args=""
if [[ "$model_name" == "deepseek-ai/deepseek-vl2-tiny" ]]; then
extra_args="--hf_overrides '{\"architectures\": [\"DeepseekVLV2ForCausalLM\"]}' --trust-remote-code"
fi
echo "$extra_args"
}
# Function to run tests for a specific model
run_tests_for_model() {
local model_name=$1
echo "================================"
echo "Testing model: $model_name"
echo "================================"
# Get model-specific arguments
local model_args=$(get_model_args "$model_name")
# Start prefill instance
PREFILL_PORT=8001
BASE_CMD="CUDA_VISIBLE_DEVICES=$PREFILL_GPU_ID VLLM_NIXL_SIDE_CHANNEL_PORT=5559 vllm serve $model_name \
--port $PREFILL_PORT \
--enforce-eager \
--gpu-memory-utilization 0.2 \
--kv-transfer-config '$KV_CONFIG'"
if [ -n "$model_args" ]; then
FULL_CMD="$BASE_CMD $model_args"
else
FULL_CMD="$BASE_CMD"
fi
eval "$FULL_CMD &"
# Start decode instance
DECODE_PORT=8002
# Build the command with or without model-specific args
BASE_CMD="CUDA_VISIBLE_DEVICES=$DECODE_GPU_ID VLLM_NIXL_SIDE_CHANNEL_PORT=6000 vllm serve $model_name \
--port $DECODE_PORT \
--enforce-eager \
--gpu-memory-utilization 0.2 \
--kv-transfer-config '$KV_CONFIG'"
if [ -n "$model_args" ]; then
FULL_CMD="$BASE_CMD $model_args"
else
FULL_CMD="$BASE_CMD"
fi
eval "$FULL_CMD &"
# Wait for all instances to start
echo "Waiting for prefill instance on port $PORT to start..."
wait_for_server $PREFILL_PORT
echo "Waiting for decode instance on port $PORT to start..."
wait_for_server $DECODE_PORT
# Build the command for the proxy server with all the hosts and ports
PROXY_PORT=8192
PROXY_CMD="python ${GIT_ROOT}/tests/v1/kv_connector/nixl_integration/toy_proxy_server.py --port $PROXY_PORT"
PROXY_CMD+=" --prefiller-ports ${PREFILL_PORT}"
PROXY_CMD+=" --decoder-ports ${DECODE_PORT}"
# Start the proxy server
echo "Starting proxy server with command: $PROXY_CMD"
$PROXY_CMD &
# Wait for the proxy to start
sleep 5
# Run lm eval for this model
echo "Running tests for $model_name"
PREFILL_PORT=$PREFILL_PORT DECODE_PORT=$DECODE_PORT PROXY_PORT=$PROXY_PORT python -m pytest -s -v ${GIT_ROOT}/tests/v1/kv_connector/nixl_integration/test_edge_cases.py
# Clean up before running next model
cleanup_instances
sleep 3
}
# Run tests for each model
for model in "${MODELS[@]}"; do
run_tests_for_model "$model"
done
echo "All tests completed!"

View File

@@ -0,0 +1,156 @@
#!/bin/bash
set -xe
# Hosts / ports
PREFILL_HOST=${PREFILL_HOST:-"localhost"}
PREFILL_PORT=${PREFILL_PORT:-8100}
PREFILL_NIXL_SIDE_PORT=${PREFILL_NIXL_SIDE_PORT:-5577}
DECODE_HOST=${DECODE_HOST:-"localhost"}
DECODE_PORT=${DECODE_PORT:-8200}
PROXY_HOST=${PROXY_HOST:-"localhost"}
PROXY_PORT=${PROXY_PORT:-8192}
BASELINE_HOST=${BASELINE_HOST:-"localhost"}
BASELINE_PORT=${BASELINE_PORT:-9290}
# Model to run.
MODEL_NAME=${MODEL_NAME:-"meta-llama/Llama-3.2-3B-Instruct"}
MAX_MODEL_LEN=${MAX_MODEL_LEN:-1024}
BLOCK_SIZE=${BLOCK_SIZE:-32}
# execution env
GIT_ROOT=$(git rev-parse --show-toplevel)
EXP_ROOT="${GIT_ROOT}/tests/v1/kv_connector/nixl_integration"
CONDA_PATH=${CONDA_PATH:-"/home/${USER}/anaconda3"}
CONDA_ENV_NAME=${CONDA_ENV_NAME:-"nixl"}
OUTPUT_FILE=${OUTPUT_FILE:-"${EXP_ROOT}/.tpu_accuracy_test_outputs.txt"}
# Trap the SIGINT signal (triggered by Ctrl+C)
trap 'kill $(jobs -pr)' SIGINT SIGTERM EXIT
# Waits for vLLM server to start.
wait_for_server() {
local host=$1
local port=$2
timeout 1200 bash -c "
until curl -s ${host}:${port}/v1/completions > /dev/null; do
sleep 1
done" && return 0 || return 1
}
# Cleanup function
cleanup() {
echo "Caught Ctrl+C, cleaning up..."
# Cleanup commands
pgrep python | xargs kill -9 || true
# pkill -f python || true
echo "Cleanup complete. Exiting."
}
launch_baseline() {
BASELINE_BASE_CMD="source ${CONDA_PATH}/bin/activate ${CONDA_ENV_NAME};
VLLM_LOGGING_LEVEL=DEBUG \
PJRT_DEVICE=TPU \
VLLM_WORKER_MULTIPROC_METHOD=spawn \
VLLM_ENABLE_V1_MULTIPROCESSING=0 vllm serve $MODEL_NAME \
--host ${BASELINE_HOST} \
--port ${BASELINE_PORT} \
--max-model-len ${MAX_MODEL_LEN}\
--seed 42 \
--block-size ${BLOCK_SIZE} \
--gpu-memory-utilization 0.5 \
--enforce-eager"
echo ${BASELINE_BASE_CMD}
ssh -tt ${BASELINE_HOST} "${BASELINE_BASE_CMD}" &
}
launch_pd() {
PREFILL_BASE_CMD="source ${CONDA_PATH}/bin/activate ${CONDA_ENV_NAME};
UCX_TLS=tcp \
VLLM_MULTIPROC_EXECUTE_MODEL_TIMEOUT_S=200 \
VLLM_LOGGING_LEVEL=DEBUG \
VLLM_NIXL_SIDE_CHANNEL_HOST=${PREFILL_HOST} \
VLLM_NIXL_SIDE_CHANNEL_PORT=${PREFILL_NIXL_SIDE_PORT} \
PJRT_DEVICE=TPU \
VLLM_WORKER_MULTIPROC_METHOD=spawn \
VLLM_ENABLE_V1_MULTIPROCESSING=0 vllm serve $MODEL_NAME \
--host ${PREFILL_HOST} \
--port ${PREFILL_PORT} \
--max-model-len ${MAX_MODEL_LEN}\
--seed 42 \
--block-size ${BLOCK_SIZE} \
--enforce-eager \
--gpu-memory-utilization 0.5 \
--kv-transfer-config '{\"kv_connector\":\"NixlConnector\",\"kv_role\":\"kv_both\",\"kv_buffer_device\":\"cpu\"}'"
DECODE_BASE_CMD="source ${CONDA_PATH}/bin/activate ${CONDA_ENV_NAME};
UCX_TLS=tcp \
VLLM_MULTIPROC_EXECUTE_MODEL_TIMEOUT_S=200 \
VLLM_LOGGING_LEVEL=DEBUG \
PJRT_DEVICE=TPU \
VLLM_WORKER_MULTIPROC_METHOD=spawn \
VLLM_ENABLE_V1_MULTIPROCESSING=0 vllm serve $MODEL_NAME \
--host ${DECODE_HOST} \
--port ${DECODE_PORT} \
--max-model-len ${MAX_MODEL_LEN}\
--seed 42 \
--block-size ${BLOCK_SIZE} \
--enforce-eager \
--gpu-memory-utilization 0.5 \
--kv-transfer-config '{\"kv_connector\":\"NixlConnector\",\"kv_role\":\"kv_both\",\"kv_buffer_device\":\"cpu\"}'"
echo ${PREFILL_BASE_CMD}
echo ${DECODE_BASE_CMD}
sleep 2
# execute on hosts
ssh -tt ${PREFILL_HOST} "${PREFILL_BASE_CMD}" &
ssh -tt ${DECODE_HOST} "${DECODE_BASE_CMD}" &
sleep 1
wait_for_server ${PREFILL_HOST} ${PREFILL_PORT}
sleep 1
wait_for_server ${DECODE_HOST} ${DECODE_PORT}
sleep 1
}
launch_pd_proxy(){
PROXY_BASE_CMD="source ${CONDA_PATH}/bin/activate ${CONDA_ENV_NAME};
python3 ${EXP_ROOT}/toy_proxy_server.py \
--prefiller-host ${PREFILL_HOST} --prefiller-port ${PREFILL_PORT} \
--decoder-host ${DECODE_HOST} --decoder-port ${DECODE_PORT} \
--host=${PROXY_HOST} --port ${PROXY_PORT}"
echo ${PROXY_BASE_CMD}
ssh -tt ${PROXY_HOST} "${PROXY_BASE_CMD}" &
}
run_tests(){
local service_url=$1
local mode=$2
python3 ${EXP_ROOT}/test_disagg_accuracy.py --service_url=${service_url} --model_name=${MODEL_NAME} --mode=${mode} --file_name=${OUTPUT_FILE}
}
# run non-disagg. baseline & save outputs
launch_baseline
sleep 2
wait_for_server ${BASELINE_HOST} ${BASELINE_PORT}
run_tests "http://${BASELINE_HOST}:${BASELINE_PORT}" "baseline"
cleanup
sleep 10
# run disagg. & do exact-match with the outputs from baseline
launch_pd
launch_pd_proxy
sleep 10
run_tests "http://${PROXY_HOST}:${PROXY_PORT}" "disagg"
echo "-----P/D success----"
rm ${OUTPUT_FILE}
cleanup
exit 0

View File

@@ -0,0 +1,124 @@
#!/bin/bash
set -xe
# Hosts / ports
PREFILL_HOST=${PREFILL_HOST:-"localhost"}
PREFILL_PORT=${PREFILL_PORT:-8100}
PREFILL_NIXL_SIDE_PORT=${PREFILL_NIXL_SIDE_PORT:-5577}
DECODE_HOST=${DECODE_HOST:-"localhost"}
DECODE_PORT=${DECODE_PORT:-8200}
PROXY_HOST=${PROXY_HOST:-"localhost"}
PROXY_PORT=${PROXY_PORT:-8192}
BASELINE_HOST=${BASELINE_HOST:-"localhost"}
BASELINE_PORT=${BASELINE_PORT:-9290}
# Model to run.
MODEL_NAME=${MODEL_NAME:-"meta-llama/Llama-3.2-3B-Instruct"}
MAX_MODEL_LEN=${MAX_MODEL_LEN:-1024}
BLOCK_SIZE=${BLOCK_SIZE:-32}
# execution env
GIT_ROOT=$(git rev-parse --show-toplevel)
EXP_ROOT="${GIT_ROOT}/tests/v1/kv_connector/nixl_integration"
CONDA_PATH=${CONDA_PATH:-"/home/${USER}/anaconda3"}
CONDA_ENV_NAME=${CONDA_ENV_NAME:-"nixl"}
OUTPUT_FILE=${OUTPUT_FILE:-"${EXP_ROOT}/.tpu_accuracy_test_outputs.txt"}
# Trap the SIGINT signal (triggered by Ctrl+C)
trap 'kill $(jobs -pr)' SIGINT SIGTERM EXIT
# Waits for vLLM server to start.
wait_for_server() {
local host=$1
local port=$2
timeout 1200 bash -c "
until curl -s ${host}:${port}/v1/completions > /dev/null; do
sleep 1
done" && return 0 || return 1
}
# Cleanup function
cleanup() {
echo "Caught Ctrl+C, cleaning up..."
# Cleanup commands
pgrep python | xargs kill -9 || true
# pkill -f python || true
echo "Cleanup complete. Exiting."
}
launch_pd() {
PREFILL_BASE_CMD="source ${CONDA_PATH}/bin/activate ${CONDA_ENV_NAME};
UCX_TLS=tcp \
VLLM_MULTIPROC_EXECUTE_MODEL_TIMEOUT_S=200 \
VLLM_LOGGING_LEVEL=DEBUG \
VLLM_NIXL_SIDE_CHANNEL_HOST=${PREFILL_HOST} \
VLLM_NIXL_SIDE_CHANNEL_PORT=${PREFILL_NIXL_SIDE_PORT} \
PJRT_DEVICE=TPU \
VLLM_WORKER_MULTIPROC_METHOD=spawn \
VLLM_ENABLE_V1_MULTIPROCESSING=0 vllm serve $MODEL_NAME \
--host ${PREFILL_HOST} \
--port ${PREFILL_PORT} \
--max-model-len ${MAX_MODEL_LEN}\
--seed 42 \
--block-size ${BLOCK_SIZE} \
--enforce-eager \
--gpu-memory-utilization 0.5 \
--kv-transfer-config '{\"kv_connector\":\"NixlConnector\",\"kv_role\":\"kv_both\",\"kv_buffer_device\":\"cpu\"}'"
DECODE_BASE_CMD="source ${CONDA_PATH}/bin/activate ${CONDA_ENV_NAME};
UCX_TLS=tcp \
VLLM_MULTIPROC_EXECUTE_MODEL_TIMEOUT_S=200 \
VLLM_LOGGING_LEVEL=DEBUG \
PJRT_DEVICE=TPU \
VLLM_WORKER_MULTIPROC_METHOD=spawn \
VLLM_ENABLE_V1_MULTIPROCESSING=0 vllm serve $MODEL_NAME \
--host ${DECODE_HOST} \
--port ${DECODE_PORT} \
--max-model-len ${MAX_MODEL_LEN}\
--seed 42 \
--block-size ${BLOCK_SIZE} \
--enforce-eager \
--gpu-memory-utilization 0.5 \
--kv-transfer-config '{\"kv_connector\":\"NixlConnector\",\"kv_role\":\"kv_both\",\"kv_buffer_device\":\"cpu\"}'"
echo ${PREFILL_BASE_CMD}
echo ${DECODE_BASE_CMD}
sleep 2
# execute on hosts
ssh -tt ${PREFILL_HOST} "${PREFILL_BASE_CMD}" &
ssh -tt ${DECODE_HOST} "${DECODE_BASE_CMD}" &
sleep 1
wait_for_server ${PREFILL_HOST} ${PREFILL_PORT}
sleep 1
wait_for_server ${DECODE_HOST} ${DECODE_PORT}
sleep 1
}
launch_pd_proxy(){
PROXY_BASE_CMD="source ${CONDA_PATH}/bin/activate ${CONDA_ENV_NAME};
python3 ${EXP_ROOT}/toy_proxy_server.py \
--prefiller-host ${PREFILL_HOST} --prefiller-port ${PREFILL_PORT} \
--decoder-host ${DECODE_HOST} --decoder-port ${DECODE_PORT} \
--host=${PROXY_HOST} --port ${PROXY_PORT}"
echo ${PROXY_BASE_CMD}
ssh -tt ${PROXY_HOST} "${PROXY_BASE_CMD}" &
}
# run disagg. & do exact-match with the outputs from baseline
launch_pd
launch_pd_proxy
sleep 10
PREFILL_HOST=${PREFILL_HOST} \
PREFILL_PORT=${PREFILL_PORT} \
DECODE_HOST=${DECODE_HOST} \
DECODE_PORT=${DECODE_PORT} \
PROXY_HOST=${PROXY_HOST} \
PROXY_PORT=${PROXY_PORT} python -m pytest -s -v ${GIT_ROOT}/tests/v1/kv_connector/nixl_integration/test_edge_cases.py

View File

@@ -0,0 +1,71 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import os
import lm_eval
import openai
BASE_URL = "http://localhost:8192/v1"
NUM_CONCURRENT = 100
TASK = "gsm8k"
FILTER = "exact_match,strict-match"
RTOL = 0.03
# Model-specific expected values
EXPECTED_VALUES = {
"Qwen/Qwen3-0.6B": 0.41,
"deepseek-ai/deepseek-vl2-small": 0.59,
"deepseek-ai/deepseek-vl2-tiny": 0.19,
"deepseek-ai/DeepSeek-V2-Lite-Chat": 0.65,
}
SIMPLE_PROMPT = (
"The best part about working on vLLM is that I got to meet so many people across "
"various different organizations like UCB, Google, and Meta which means",
)
# Get model name from environment variable
MODEL_NAME = os.environ.get("TEST_MODEL", "Qwen/Qwen3-0.6B")
def run_simple_prompt():
client = openai.OpenAI(api_key="EMPTY", base_url=BASE_URL)
completion = client.completions.create(model=MODEL_NAME, prompt=SIMPLE_PROMPT)
print("-" * 50)
print(f"Completion results for {MODEL_NAME}:")
print(completion)
print("-" * 50)
def test_accuracy():
"""Run the end to end accuracy test."""
run_simple_prompt()
model_args = (
f"model={MODEL_NAME},"
f"base_url={BASE_URL}/completions,"
f"num_concurrent={NUM_CONCURRENT},tokenized_requests=False"
)
results = lm_eval.simple_evaluate(
model="local-completions",
model_args=model_args,
tasks=TASK,
)
measured_value = results["results"][TASK][FILTER]
expected_value = EXPECTED_VALUES.get(MODEL_NAME)
if expected_value is None:
print(
f"Warning: No expected value found for {MODEL_NAME}. "
"Skipping accuracy check."
)
print(f"Measured value: {measured_value}")
return
assert (
measured_value - RTOL < expected_value
and measured_value + RTOL > expected_value
), f"Expected: {expected_value} | Measured: {measured_value}"

View File

@@ -0,0 +1,180 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import argparse
import json
import os
import time
import openai
import requests
MAX_OUTPUT_LEN = 30
SAMPLE_PROMPTS = (
"Red Hat is the best company in the world to work for because it works on "
"open source software, which means that all the contributions are "
"delivered to the community. As a result, when working on projects like "
"vLLM we are able to meet many amazing people from various organizations "
"like AMD, Google, NVIDIA, ",
"We hold these truths to be self-evident, that all men are created equal, "
"that they are endowed by their Creator with certain unalienable Rights, "
"that among these are Life, Liberty and the pursuit of Happiness.--That "
"to secure these rights, Governments are instituted among Men, deriving "
"their just powers from the consent of the governed, ",
)
def check_vllm_server(url: str, timeout=5, retries=3) -> bool:
"""
Checks if the vLLM server is ready by sending a GET request to the
/health endpoint.
Args:
url (str): The base URL of the vLLM server.
timeout (int): Timeout in seconds for the request.
retries (int): Number of retries if the server is not ready.
Returns:
bool: True if the server is ready, False otherwise.
"""
for attempt in range(retries):
try:
response = requests.get(url, timeout=timeout)
if response.status_code == 200:
return True
else:
print(
f"Attempt {attempt + 1}: Server returned status code "
"{response.status_code}"
)
except requests.exceptions.RequestException as e:
print(f"Attempt {attempt + 1}: Error connecting to server: {e}")
time.sleep(1) # Wait before retrying
return False
def run_simple_prompt(
base_url: str, model_name: str, input_prompt: str, use_chat_endpoint: bool
) -> str:
client = openai.OpenAI(api_key="EMPTY", base_url=base_url)
if use_chat_endpoint:
completion = client.chat.completions.create(
model=model_name,
messages=[
{"role": "user", "content": [{"type": "text", "text": input_prompt}]}
],
max_completion_tokens=MAX_OUTPUT_LEN,
temperature=0.0,
seed=42,
)
return completion.choices[0].message.content
else:
completion = client.completions.create(
model=model_name,
prompt=input_prompt,
max_tokens=MAX_OUTPUT_LEN,
temperature=0.0,
seed=42,
)
return completion.choices[0].text
def main():
"""
This script demonstrates how to accept two optional string arguments
("service_url" and "file_name") from the command line, each with a
default value of an empty string, using the argparse module.
"""
parser = argparse.ArgumentParser(description="vLLM client script")
parser.add_argument(
"--service_url", # Name of the first argument
type=str,
required=True,
help="The vLLM service URL.",
)
parser.add_argument(
"--model_name", # Name of the first argument
type=str,
required=True,
help="model_name",
)
parser.add_argument(
"--mode", # Name of the second argument
type=str,
default="baseline",
help="mode: baseline==non-disagg, or disagg",
)
parser.add_argument(
"--file_name", # Name of the second argument
type=str,
default=".vllm_output.txt",
help="the file that saves the output tokens ",
)
args = parser.parse_args()
for arg in vars(args):
print(f"{arg}: {getattr(args, arg)}")
if args.mode == "baseline":
# non-disagg
health_check_url = f"{args.service_url}/health"
else:
# disagg proxy
health_check_url = f"{args.service_url}/healthcheck"
if not os.path.exists(args.file_name):
raise ValueError(
f"In disagg mode, the output file {args.file_name} from "
"non-disagg. baseline does not exist."
)
service_url = f"{args.service_url}/v1"
if not check_vllm_server(health_check_url):
raise RuntimeError(f"vllm server: {args.service_url} is not ready yet!")
output_strs = dict()
for i, prompt in enumerate(SAMPLE_PROMPTS):
use_chat_endpoint = i % 2 == 1
output_str = run_simple_prompt(
base_url=service_url,
model_name=args.model_name,
input_prompt=prompt,
use_chat_endpoint=use_chat_endpoint,
)
print(f"Prompt: {prompt}, output: {output_str}")
output_strs[prompt] = output_str
if args.mode == "baseline":
# baseline: save outputs
try:
with open(args.file_name, "w") as json_file:
json.dump(output_strs, json_file, indent=4)
except OSError as e:
print(f"Error writing to file: {e}")
raise
else:
# disagg. verify outputs
baseline_outputs = None
try:
with open(args.file_name) as json_file:
baseline_outputs = json.load(json_file)
except OSError as e:
print(f"Error writing to file: {e}")
raise
assert isinstance(baseline_outputs, dict)
assert len(baseline_outputs) == len(output_strs)
for prompt, output in baseline_outputs.items():
assert prompt in output_strs, f"{prompt} not included"
assert output == output_strs[prompt], (
f"baseline_output: {output} != PD output: {output_strs[prompt]}"
)
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,80 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import os
import openai
PREFILL_HOST = os.getenv("PREFILL_HOST", "localhost")
PREFILL_PORT = os.getenv("PREFILL_PORT", None)
DECODE_HOST = os.getenv("DECODE_HOST", "localhost")
DECODE_PORT = os.getenv("DECODE_PORT", None)
PROXY_HOST = os.getenv("PROXY_HOST", "localhost")
PROXY_PORT = os.getenv("PROXY_PORT", None)
if PREFILL_PORT is None or DECODE_PORT is None or PROXY_PORT is None:
raise ValueError("Please set the PREFILL_PORT, DECODE_PORT, and PROXY_PORT.")
LONG_PROMPT = "Red Hat is the best company in the world to work for because it works on open source software, which means that all the contributions are delivered to the community. As a result, when working on projects like vLLM we are able to meet many amazing people from various organizations like AMD, Google, NVIDIA, " # noqa: E501
PROMPT = "Red Hat is the best company in the world to work for because it works on open source software, which means that all the contributions are delivered to the community. As a result," # noqa: E501
SHORT_PROMPT = "Red Hat is "
def test_edge_cases():
# Set the OpenAI API key and base URL
decode_client = openai.OpenAI(
api_key="MY_KEY",
base_url=f"http://{DECODE_HOST}:{DECODE_PORT}/v1",
)
prefill_client = openai.OpenAI(
api_key="MY_KEY",
base_url=f"http://{PREFILL_HOST}:{PREFILL_PORT}/v1",
)
proxy_client = openai.OpenAI(
api_key="MY_KEY",
base_url=f"http://{PROXY_HOST}:{PROXY_PORT}/v1",
)
# Get the list of models
models = decode_client.models.list()
MODEL = models.data[0].id
# (1) Check that we can handle a very short prompt,
# less than the length of the block size.
completion = proxy_client.completions.create(
model=MODEL, prompt=SHORT_PROMPT, temperature=0
)
proxy_response = completion.choices[0].text
completion = prefill_client.completions.create(
model=MODEL, prompt=SHORT_PROMPT, temperature=0
)
prefill_response = completion.choices[0].text
print(f"SMALL PROMPT: {proxy_response=}")
assert proxy_response == prefill_response
# (2) Check that we can handle a full prefix cache
# hit on the D worker but not on the P worker.
# (2a): prime the D worker.
completion = decode_client.completions.create(
model=MODEL, prompt=PROMPT, temperature=0
)
decode_response = completion.choices[0].text
# (2b): send via the P/D setup
completion = proxy_client.completions.create(
model=MODEL, prompt=PROMPT, temperature=0
)
proxy_response = completion.choices[0].text
print(f"FULL CACHE HIT: {proxy_response=}")
assert proxy_response == decode_response
# (3) Check that we can handle a partial prefix cache
# hit on the D worker.
completion = proxy_client.completions.create(
model=MODEL, prompt=LONG_PROMPT, temperature=0
)
proxy_response = completion.choices[0].text
completion = prefill_client.completions.create(
model=MODEL, prompt=LONG_PROMPT, temperature=0
)
prefill_response = completion.choices[0].text
print(f"PARTIAL CACHE HIT: {proxy_response=}")
assert proxy_response == prefill_response

View File

@@ -0,0 +1,283 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import argparse
import itertools
import logging
import os
import uuid
from contextlib import asynccontextmanager
import httpx
from fastapi import FastAPI, Request
from fastapi.responses import StreamingResponse
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)
@asynccontextmanager
async def lifespan(app: FastAPI):
"""
Lifespan context manager to handle startup and shutdown events.
"""
# Startup: Initialize client pools for prefiller and decoder services
app.state.prefill_clients = []
app.state.decode_clients = []
# Create prefill clients
for i, (host, port) in enumerate(global_args.prefiller_instances):
prefiller_base_url = f"http://{host}:{port}/v1"
app.state.prefill_clients.append(
{
"client": httpx.AsyncClient(
timeout=None,
base_url=prefiller_base_url,
limits=httpx.Limits(
max_connections=None,
max_keepalive_connections=None,
),
),
"host": host,
"port": port,
"id": i,
}
)
# Create decode clients
for i, (host, port) in enumerate(global_args.decoder_instances):
decoder_base_url = f"http://{host}:{port}/v1"
app.state.decode_clients.append(
{
"client": httpx.AsyncClient(
timeout=None,
base_url=decoder_base_url,
limits=httpx.Limits(
max_connections=None,
max_keepalive_connections=None,
),
),
"host": host,
"port": port,
"id": i,
}
)
# Initialize round-robin iterators
app.state.prefill_iterator = itertools.cycle(range(len(app.state.prefill_clients)))
app.state.decode_iterator = itertools.cycle(range(len(app.state.decode_clients)))
print(
f"Initialized {len(app.state.prefill_clients)} prefill clients "
f"and {len(app.state.decode_clients)} decode clients."
)
yield
# Shutdown: Close all clients
for client_info in app.state.prefill_clients:
await client_info["client"].aclose()
for client_info in app.state.decode_clients:
await client_info["client"].aclose()
# Update FastAPI app initialization to use lifespan
app = FastAPI(lifespan=lifespan)
def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument("--port", type=int, default=8000)
# Always use 127.0.0.1 as localhost binds to IPv6 which is blocked on CI
parser.add_argument("--host", type=str, default="127.0.0.1")
# For prefiller instances
parser.add_argument(
"--prefiller-hosts",
"--prefiller-host",
type=str,
nargs="+",
default=["localhost"],
)
parser.add_argument(
"--prefiller-ports", "--prefiller-port", type=int, nargs="+", default=[8100]
)
# For decoder instances
parser.add_argument(
"--decoder-hosts", "--decoder-host", type=str, nargs="+", default=["localhost"]
)
parser.add_argument(
"--decoder-ports", "--decoder-port", type=int, nargs="+", default=[8200]
)
args = parser.parse_args()
# Validate and pair hosts with ports
if len(args.prefiller_hosts) != len(args.prefiller_ports):
raise ValueError(
"Number of prefiller hosts must match number of prefiller ports"
)
if len(args.decoder_hosts) != len(args.decoder_ports):
raise ValueError("Number of decoder hosts must match number of decoder ports")
# Create tuples of (host, port) for each service type
args.prefiller_instances = list(zip(args.prefiller_hosts, args.prefiller_ports))
args.decoder_instances = list(zip(args.decoder_hosts, args.decoder_ports))
return args
def get_next_client(app, service_type: str):
"""
Get the next client in round-robin fashion.
Args:
app: The FastAPI app instance
service_type: Either 'prefill' or 'decode'
Returns:
The next client to use
"""
if service_type == "prefill":
client_idx = next(app.state.prefill_iterator)
return app.state.prefill_clients[client_idx]
elif service_type == "decode":
client_idx = next(app.state.decode_iterator)
return app.state.decode_clients[client_idx]
else:
raise ValueError(f"Unknown service type: {service_type}")
async def send_request_to_service(
client_info: dict, endpoint: str, req_data: dict, request_id: str
):
"""
Send a request to a service using a client from the pool.
"""
req_data = req_data.copy()
req_data["kv_transfer_params"] = {
"do_remote_decode": True,
"do_remote_prefill": False,
"remote_engine_id": None,
"remote_block_ids": None,
"remote_host": None,
"remote_port": None,
}
req_data["stream"] = False
req_data["max_tokens"] = 1
if "max_completion_tokens" in req_data:
req_data["max_completion_tokens"] = 1
if "stream_options" in req_data:
del req_data["stream_options"]
headers = {
"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}",
"X-Request-Id": request_id,
}
response = await client_info["client"].post(
endpoint, json=req_data, headers=headers
)
response.raise_for_status()
# read/consume the response body to release the connection
# otherwise, it would http.ReadError
await response.aread()
return response
async def stream_service_response(
client_info: dict, endpoint: str, req_data: dict, request_id: str
):
"""
Asynchronously stream response from a service using a client from the pool.
"""
headers = {
"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}",
"X-Request-Id": request_id,
}
async with client_info["client"].stream(
"POST", endpoint, json=req_data, headers=headers
) as response:
response.raise_for_status()
async for chunk in response.aiter_bytes():
yield chunk
async def _handle_completions(api: str, request: Request):
try:
req_data = await request.json()
request_id = str(uuid.uuid4())
# Get the next prefill client in round-robin fashion
prefill_client_info = get_next_client(request.app, "prefill")
# Send request to prefill service
response = await send_request_to_service(
prefill_client_info, api, req_data, request_id
)
# Extract the needed fields
response_json = response.json()
await response.aclose() # CRITICAL: Release connection back to pool
kv_transfer_params = response_json.get("kv_transfer_params", {})
if kv_transfer_params:
req_data["kv_transfer_params"] = kv_transfer_params
# Get the next decode client in round-robin fashion
decode_client_info = get_next_client(request.app, "decode")
logger.debug("Using %s %s", prefill_client_info, decode_client_info)
# Stream response from decode service
async def generate_stream():
async for chunk in stream_service_response(
decode_client_info, api, req_data, request_id=request_id
):
yield chunk
return StreamingResponse(generate_stream(), media_type="application/json")
except Exception as e:
import sys
import traceback
exc_info = sys.exc_info()
print(f"Error occurred in disagg prefill proxy server - {api} endpoint")
print(e)
print("".join(traceback.format_exception(*exc_info)))
raise
@app.post("/v1/completions")
async def handle_completions(request: Request):
return await _handle_completions("/completions", request)
@app.post("/v1/chat/completions")
async def handle_chat_completions(request: Request):
return await _handle_completions("/chat/completions", request)
@app.get("/healthcheck")
async def healthcheck():
"""Simple endpoint to check if the server is running."""
return {
"status": "ok",
"prefill_instances": len(app.state.prefill_clients),
"decode_instances": len(app.state.decode_clients),
}
if __name__ == "__main__":
global global_args
global_args = parse_args()
import uvicorn
uvicorn.run(app, host=global_args.host, port=global_args.port)

View File

@@ -0,0 +1,41 @@
#!/usr/bin/env bash
set -euo pipefail
# Utility to run integration tests sequentially with varying TP configurations.
SCRIPT="v1/kv_connector/nixl_integration/run_accuracy_test.sh"
# Define test configurations
configs=(
"GPU_MEMORY_UTILIZATION=0.6 PREFILLER_TP_SIZE=2 DECODER_TP_SIZE=2"
"GPU_MEMORY_UTILIZATION=0.6 PREFILLER_TP_SIZE=1 DECODER_TP_SIZE=2"
"GPU_MEMORY_UTILIZATION=0.8 MODEL_NAMES=deepseek-ai/deepseek-vl2-tiny" # MLA case
"GPU_MEMORY_UTILIZATION=0.8 PREFILLER_TP_SIZE=1 DECODER_TP_SIZE=2 MODEL_NAMES=deepseek-ai/deepseek-vl2-tiny"
"DP_EP=1 GPU_MEMORY_UTILIZATION=0.8 PREFILLER_TP_SIZE=1 DECODER_TP_SIZE=2 MODEL_NAMES=deepseek-ai/deepseek-vl2-tiny" # MLA+P-TP1, D-DPEP=2 (TP=1)
)
run_tests() {
local label=$1
local extra_env=$2
echo "=== Running tests (${label}) ==="
for cfg in "${configs[@]}"; do
echo "-> Running with ${cfg} ${extra_env:+and ${extra_env}}"
# Use 'env' to safely set variables without eval
if ! env ${extra_env} ${cfg} bash "${SCRIPT}"; then
echo "❌ Test failed for config: ${cfg} ${extra_env:+(${extra_env})}"
exit 1
fi
done
echo "✅ All ${label} tests passed!"
}
# Run tests
run_tests "default backend" ""
# Check if FLASHINFER is set (non-empty)
if [[ -n "${FLASHINFER:-}" ]]; then
echo "FLASHINFER is set, rerunning with VLLM_ATTENTION_BACKEND=FLASHINFER"
run_tests "FLASHINFER backend" "VLLM_ATTENTION_BACKEND=FLASHINFER"
else
echo "FLASHINFER not set, skipping FLASHINFER runs."
fi

View File

View File

@@ -0,0 +1,275 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Unit tests for backwards compatibility with external KV connector implementations.
This test ensures that external connectors (loaded via kv_connector_module_path)
implemented with the old signature continue to work:
- Old signature: __init__(self, vllm_config, role)
- New signature: __init__(self, vllm_config, role, kv_cache_config)
"""
from typing import TYPE_CHECKING
from unittest.mock import patch
import pytest
from vllm.attention.backends.abstract import AttentionMetadata
from vllm.distributed.kv_transfer.kv_connector.factory import KVConnectorFactory
from vllm.distributed.kv_transfer.kv_connector.v1 import (
KVConnectorBase_V1,
KVConnectorRole,
)
from vllm.v1.core.sched.output import SchedulerOutput
from .utils import create_scheduler, create_vllm_config
if TYPE_CHECKING:
from vllm.config import VllmConfig
from vllm.forward_context import ForwardContext
from vllm.v1.core.kv_cache_manager import KVCacheBlocks
from vllm.v1.kv_cache_interface import KVCacheConfig
from vllm.v1.request import Request
class OldStyleTestConnector(KVConnectorBase_V1):
"""
Test connector using the old signature with 2 required arguments.
This simulates external connectors that haven't been updated yet.
"""
def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole):
# Old-style call to super().__init__ with only 2 arguments
super().__init__(vllm_config=vllm_config, role=role)
def get_num_new_matched_tokens(
self, request: "Request", num_computed_tokens: int
) -> tuple[int | None, bool]:
return 0, False
def update_state_after_alloc(
self,
request: "Request",
blocks: "KVCacheBlocks",
num_external_tokens: int,
):
pass
def build_connector_meta(self, scheduler_output: SchedulerOutput):
return None
def start_load_kv(self, forward_context: "ForwardContext", **kwargs) -> None:
pass
def wait_for_layer_load(self, layer_name: str) -> None:
pass
def save_kv_layer(
self,
layer_name: str,
kv_layer,
attn_metadata: AttentionMetadata,
**kwargs,
) -> None:
pass
def wait_for_save(self):
pass
class NewStyleTestConnector(KVConnectorBase_V1):
"""
Test connector using the new signature with 3 required arguments.
"""
def __init__(
self,
vllm_config: "VllmConfig",
role: KVConnectorRole,
kv_cache_config: "KVCacheConfig",
):
# New-style call to super().__init__ with all 3 arguments
super().__init__(
vllm_config=vllm_config, role=role, kv_cache_config=kv_cache_config
)
def get_num_new_matched_tokens(
self, request: "Request", num_computed_tokens: int
) -> tuple[int | None, bool]:
return 0, False
def update_state_after_alloc(
self,
request: "Request",
blocks: "KVCacheBlocks",
num_external_tokens: int,
):
pass
def build_connector_meta(self, scheduler_output: SchedulerOutput):
return None
def start_load_kv(self, forward_context: "ForwardContext", **kwargs) -> None:
pass
def wait_for_layer_load(self, layer_name: str) -> None:
pass
def save_kv_layer(
self,
layer_name: str,
kv_layer,
attn_metadata: AttentionMetadata,
**kwargs,
) -> None:
pass
def wait_for_save(self):
pass
@pytest.mark.parametrize("role", [KVConnectorRole.SCHEDULER, KVConnectorRole.WORKER])
def test_external_old_signature_factory_instantiation(role):
"""
Test that external connectors with old signature (2 required args) loaded
via kv_connector_module_path are correctly instantiated with backwards
compatibility support.
"""
vllm_config = create_vllm_config()
vllm_config.kv_transfer_config.kv_connector = "OldStyleTestConnector"
vllm_config.kv_transfer_config.kv_connector_module_path = (
"tests.v1.kv_connector.unit.test_backwards_compatibility"
)
scheduler = create_scheduler(vllm_config)
kv_cache_config = scheduler.kv_cache_config
connector = KVConnectorFactory.create_connector(vllm_config, role, kv_cache_config)
assert connector is not None
assert isinstance(connector, OldStyleTestConnector)
assert connector.role == role
assert connector._kv_cache_config is None
@pytest.mark.parametrize("role", [KVConnectorRole.SCHEDULER, KVConnectorRole.WORKER])
def test_external_new_signature_factory_instantiation(role):
"""
Test that external connectors with new signature (3 required args) loaded
via kv_connector_module_path are correctly instantiated.
"""
vllm_config = create_vllm_config()
vllm_config.kv_transfer_config.kv_connector = "NewStyleTestConnector"
vllm_config.kv_transfer_config.kv_connector_module_path = (
"tests.v1.kv_connector.unit.test_backwards_compatibility"
)
scheduler = create_scheduler(vllm_config)
kv_cache_config = scheduler.kv_cache_config
connector = KVConnectorFactory.create_connector(vllm_config, role, kv_cache_config)
assert connector is not None
assert isinstance(connector, NewStyleTestConnector)
assert connector.role == role
assert connector._kv_cache_config is not None
assert connector._kv_cache_config == kv_cache_config
@pytest.mark.parametrize("role", [KVConnectorRole.SCHEDULER, KVConnectorRole.WORKER])
def test_old_signature_super_init(role):
"""
Test that old-style connectors can call super().__init__() without
kv_cache_config parameter.
"""
vllm_config = create_vllm_config()
connector = OldStyleTestConnector(vllm_config, role)
assert connector is not None
assert connector.role == role
assert connector._kv_cache_config is None
def test_old_signature_super_init_with_kwargs():
"""
Test that old-style connectors can call super().__init__() with keyword
arguments in different orders.
"""
vllm_config = create_vllm_config()
# Test with vllm_config= and role= kwargs
connector1 = OldStyleTestConnector(
vllm_config=vllm_config, role=KVConnectorRole.SCHEDULER
)
assert connector1 is not None
assert connector1._kv_cache_config is None
# Test with role= and vllm_config= in reversed order
connector2 = OldStyleTestConnector(
role=KVConnectorRole.WORKER, vllm_config=vllm_config
)
assert connector2 is not None
assert connector2._kv_cache_config is None
def test_internal_connector_uses_new_signature():
"""
Test that internal connectors (registered in factory) always use the new
signature and get kv_cache_config.
"""
from vllm.distributed.kv_transfer.kv_connector.v1.example_connector import (
ExampleConnector,
)
vllm_config = create_vllm_config()
vllm_config.kv_transfer_config.kv_connector = "ExampleConnector"
scheduler = create_scheduler(vllm_config)
kv_cache_config = scheduler.kv_cache_config
connector = KVConnectorFactory.create_connector(
vllm_config, KVConnectorRole.SCHEDULER, kv_cache_config
)
assert connector is not None
assert isinstance(connector, ExampleConnector)
assert connector._kv_cache_config is not None
assert connector._kv_cache_config == kv_cache_config
def test_signature_detection_with_mocking():
"""
Test that the factory correctly applies compat_sig flag returned from
_get_connector_class_with_compat.
"""
vllm_config = create_vllm_config()
scheduler = create_scheduler(vllm_config)
kv_cache_config = scheduler.kv_cache_config
# Mock _get_connector_class_with_compat to return old-style connector
with patch.object(
KVConnectorFactory,
"_get_connector_class_with_compat",
return_value=(OldStyleTestConnector, True),
):
old_connector = KVConnectorFactory.create_connector(
vllm_config, KVConnectorRole.SCHEDULER, kv_cache_config
)
assert old_connector is not None
assert isinstance(old_connector, OldStyleTestConnector)
assert old_connector._kv_cache_config is None
# Mock _get_connector_class_with_compat to return new-style connector
with patch.object(
KVConnectorFactory,
"_get_connector_class_with_compat",
return_value=(NewStyleTestConnector, False),
):
new_connector = KVConnectorFactory.create_connector(
vllm_config, KVConnectorRole.SCHEDULER, kv_cache_config
)
assert new_connector is not None
assert isinstance(new_connector, NewStyleTestConnector)
assert new_connector._kv_cache_config is not None
assert new_connector._kv_cache_config == kv_cache_config

View File

@@ -0,0 +1,163 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
test that invalid blocks are evicted from prefix cache to prevent pollution.
verifies that when sync-loading fails, invalid blocks are removed from the
prefix cache hash table so future requests cannot match and reuse corrupted data.
"""
from collections.abc import Callable
from unittest.mock import Mock
import pytest
from vllm.v1.core.sched.scheduler import Scheduler
from vllm.v1.request import Request, RequestStatus
from .utils import (
create_model_runner_output,
create_request,
create_scheduler,
create_vllm_config,
)
pytestmark = pytest.mark.cpu_test
def _make_get_num_new_matched_tokens(
req_num_new_matched_tokens: dict[str, int],
async_load: bool,
) -> Callable[[Request, int], tuple[int, bool]]:
def get_num_new_matched_tokens(request: Request, _: int) -> tuple[int, bool]:
value = req_num_new_matched_tokens.get(request.request_id, 0)
return value, async_load
return get_num_new_matched_tokens
@pytest.fixture
def fail_scheduler():
"""scheduler with kv_load_failure_policy='fail'"""
vllm_config = create_vllm_config()
vllm_config.kv_transfer_config.kv_load_failure_policy = "fail"
return create_scheduler(vllm_config)
def test_invalid_blocks_evicted_prevents_cache_pollution(
fail_scheduler: Scheduler,
):
"""
verify invalid blocks are evicted to prevent future cache hits.
scenario:
1. request 1 loads externally-computed blocks (sync mode)
2. some blocks fail to load and are marked invalid
3. with fail policy, invalid blocks should be evicted from prefix cache
4. request is marked as FINISHED_ERROR
"""
num_prompt_blocks = 100
num_external_computed_blocks = 99
invalid_block_idx = 50
num_prompt_tokens = num_prompt_blocks * fail_scheduler.block_size
num_external_computed_tokens = (
num_external_computed_blocks * fail_scheduler.block_size
)
# request 1: will have invalid blocks
request1 = create_request(num_tokens=num_prompt_tokens, request_id=1)
fail_scheduler.add_request(request=request1)
req_num_new_matched_tokens = {
request1.request_id: num_external_computed_tokens,
}
# mock connector indicating sync load
fail_scheduler.connector = Mock()
fail_scheduler.connector.get_num_new_matched_tokens.side_effect = (
_make_get_num_new_matched_tokens(req_num_new_matched_tokens, False)
)
fail_scheduler.connector.request_finished.return_value = (False, None)
fail_scheduler.connector.take_events.return_value = ()
scheduler_output = fail_scheduler.schedule()
# request should be running with sync KV load
assert len(fail_scheduler.running) == 1
assert request1.status == RequestStatus.RUNNING
# get allocated block IDs
req_block_ids = scheduler_output.scheduled_new_reqs[0].block_ids[0]
invalid_block_id = req_block_ids[invalid_block_idx]
invalid_block_ids = {invalid_block_id}
# get the block object to verify eviction later
block = fail_scheduler.kv_cache_manager.block_pool.blocks[invalid_block_id]
# cache the blocks to simulate they've been computed and cached
# (in real scenario blocks would be cached after compute)
fail_scheduler.kv_cache_manager.cache_blocks(request1, num_external_computed_tokens)
# verify block has a hash (is cached) before reporting invalid blocks
assert block.block_hash is not None, (
f"block {invalid_block_id} should be cached (have a hash) before "
f"eviction test, but hash is None"
)
# report invalid blocks
model_runner_output = create_model_runner_output(
[request1],
invalid_block_ids=invalid_block_ids,
use_eos=False,
)
fail_scheduler.update_from_output(scheduler_output, model_runner_output)
# verify request finished with error (fail policy)
assert request1.status == RequestStatus.FINISHED_ERROR
# critical assertion: invalid block and all subsequent blocks should be evicted
# all blocks from invalid_block_idx onwards become invalid since they were
# computed based on the failed block
for idx in range(invalid_block_idx, len(req_block_ids)):
block_id = req_block_ids[idx]
block_obj = fail_scheduler.kv_cache_manager.block_pool.blocks[block_id]
assert block_obj.block_hash is None, (
f"block {block_id} at index {idx} should have been evicted "
f"(hash reset to None), but hash is {block_obj.block_hash}. "
f"All blocks from index {invalid_block_idx} onwards should be evicted "
f"since they depend on the invalid block at index {invalid_block_idx}."
)
# verify cache contains exactly the valid blocks (before first affected block)
# and none of the invalid blocks (from first affected block onwards)
# valid blocks: all blocks before invalid_block_idx should be cached
for idx in range(invalid_block_idx):
block_id = req_block_ids[idx]
block_obj = fail_scheduler.kv_cache_manager.block_pool.blocks[block_id]
assert block_obj.block_hash is not None, (
f"valid block {block_id} at index {idx} should still be cached "
f"(have a hash), but hash is None. Only blocks from index "
f"{invalid_block_idx} onwards should be evicted."
)
# invalid blocks: verify they're not in the cached_block_hash_to_block map
cached_blocks = (
fail_scheduler.kv_cache_manager.block_pool.cached_block_hash_to_block
)
cached_block_ids = {
b.block_id
for blocks_val in cached_blocks._cache.values()
for b in (
[blocks_val] if not isinstance(blocks_val, dict) else blocks_val.values()
)
}
for idx in range(invalid_block_idx, len(req_block_ids)):
block_id = req_block_ids[idx]
assert block_id not in cached_block_ids, (
f"invalid block {block_id} at index {idx} should not be in cache hash table"
)

View File

@@ -0,0 +1,65 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Tests for KV cache offloading configuration."""
import pytest
from vllm.config import CacheConfig, KVTransferConfig, ParallelConfig, VllmConfig
pytestmark = pytest.mark.cpu_test
@pytest.mark.parametrize(
"kv_offloading_backend,kv_offloading_size,tp,pp,expected_backend,expected_bytes",
[
("native", 4.0, 1, 1, "OffloadingConnector", 4.0 * (1 << 30)),
# bytes per rank: 8.0 GiB / (2 * 2) = 2.0 GiB
("native", 8.0, 2, 2, "OffloadingConnector", 8.0 * (1 << 30) / 4),
("lmcache", 4.0, 1, 1, "LMCacheConnectorV1", 4.0),
# size per rank: 8.0 GiB / (2 * 2) = 2.0 GiB
("lmcache", 8.0, 2, 2, "LMCacheConnectorV1", 2.0),
(None, None, 1, 1, None, None),
],
)
def test_kv_connector(
kv_offloading_backend, kv_offloading_size, tp, pp, expected_backend, expected_bytes
):
kv_transfer_config = (
KVTransferConfig(kv_connector_extra_config={"existing_key": "existing_value"})
if expected_backend is not None
else None
)
vllm_config = VllmConfig(
cache_config=CacheConfig(
kv_offloading_backend=kv_offloading_backend,
kv_offloading_size=kv_offloading_size,
),
kv_transfer_config=kv_transfer_config,
parallel_config=ParallelConfig(
tensor_parallel_size=tp, pipeline_parallel_size=pp
),
)
# No KV transfer config expected
if expected_backend is None:
assert vllm_config.kv_transfer_config is expected_backend
return
kv_transfer_config = vllm_config.kv_transfer_config
kv_connector_extra_config = kv_transfer_config.kv_connector_extra_config
assert kv_transfer_config.kv_connector == expected_backend
assert kv_transfer_config.kv_role == "kv_both"
if kv_offloading_backend == "native":
assert kv_connector_extra_config["kv_bytes_per_rank"] == expected_bytes
assert kv_connector_extra_config["num_cpu_blocks"] == 0
# Existing config should be preserved
assert kv_connector_extra_config["existing_key"] == "existing_value"
elif kv_offloading_backend == "lmcache":
assert kv_connector_extra_config["lmcache.local_cpu"] is True
assert kv_connector_extra_config["lmcache.max_local_cpu_size"] == expected_bytes
# Existing config should be replaced
assert "existing_key" not in kv_connector_extra_config

View File

@@ -0,0 +1,415 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Unit tests for DecodeBenchConnector.
Tests the functionality of the DecodeBenchConnector which fills KV cache
with dummy values for decode performance benchmarking.
"""
import pytest
import torch
from vllm import SamplingParams
from vllm.config import KVTransferConfig
from vllm.distributed.kv_transfer.kv_connector.v1 import KVConnectorRole
# ruff: noqa: E501
from vllm.distributed.kv_transfer.kv_connector.v1.decode_bench_connector import (
DecodeBenchConnector,
DecodeBenchConnectorMetadata,
)
from vllm.forward_context import ForwardContext
from vllm.utils.hashing import sha256
from vllm.v1.core.kv_cache_utils import get_request_block_hasher, init_none_hash
from vllm.v1.core.sched.scheduler import Scheduler
from vllm.v1.request import Request
from .utils import (
EOS_TOKEN_ID,
create_model_runner_output,
create_scheduler,
create_vllm_config,
)
class DecodeBenchTestRunner:
"""Test runner for DecodeBenchConnector."""
def __init__(self, block_size: int, num_gpu_blocks: int):
self.block_size = block_size
self.num_gpu_blocks = num_gpu_blocks
self.req_id = -1
# Create vllm config with DecodeBenchConnector
vllm_config = create_vllm_config(
block_size=block_size, max_num_batched_tokens=1000
)
vllm_config.kv_transfer_config = KVTransferConfig(
kv_connector="DecodeBenchConnector",
kv_role="kv_both",
)
self.vllm_config = vllm_config
self.scheduler: Scheduler = create_scheduler(
vllm_config, num_blocks=num_gpu_blocks
)
# Create worker-side connector
self.worker_connector = DecodeBenchConnector(
vllm_config, KVConnectorRole.WORKER
)
# Create dummy KV caches for testing
# Shape: [num_blocks, 2, num_heads, block_size, head_dim]
# Using simplified shape for testing
num_heads = 4
head_dim = 64
self.kv_caches = {
f"layer_{i}": torch.zeros(
num_gpu_blocks, 2, num_heads, block_size, head_dim
)
for i in range(2) # 2 layers for testing
}
# Register KV caches with worker connector
self.worker_connector.register_kv_caches(self.kv_caches)
# Extract scheduler-side connector
scheduler_connector = self.scheduler.connector
assert scheduler_connector is not None
assert isinstance(scheduler_connector, DecodeBenchConnector)
self.scheduler_connector: DecodeBenchConnector = scheduler_connector
init_none_hash(sha256)
self._block_hasher = get_request_block_hasher(block_size, sha256)
self._dummy_ctx: ForwardContext = ForwardContext(
no_compile_layers={}, attn_metadata={}, virtual_engine=0
)
def new_request(self, token_ids: list[int]) -> Request:
"""Create a new request with given token IDs."""
self.req_id += 1
req = Request(
request_id=str(self.req_id),
prompt_token_ids=token_ids,
sampling_params=SamplingParams(max_tokens=100),
pooling_params=None,
eos_token_id=EOS_TOKEN_ID,
block_hasher=self._block_hasher,
)
self.scheduler.add_request(req)
return req
def run_single_step(self, token_id: int = 0):
"""Run a single scheduler + worker step."""
scheduler_output = self.scheduler.schedule()
# Get connector metadata
kv_connector_metadata = scheduler_output.kv_connector_metadata
assert kv_connector_metadata is not None
assert isinstance(kv_connector_metadata, DecodeBenchConnectorMetadata)
# Bind metadata and load KV
self.worker_connector.bind_connector_metadata(kv_connector_metadata)
self.worker_connector.start_load_kv(self._dummy_ctx)
if scheduler_output.total_num_scheduled_tokens > 0:
self.worker_connector.wait_for_save()
self.worker_connector.clear_connector_metadata()
# Create model runner output
model_runner_output = create_model_runner_output(
reqs=self.scheduler.running,
token_id=token_id,
)
self.scheduler.update_from_output(scheduler_output, model_runner_output)
return scheduler_output, kv_connector_metadata
def test_decode_bench_connector_basic():
"""Test basic functionality of DecodeBenchConnector."""
block_size = 16
num_gpu_blocks = 100
runner = DecodeBenchTestRunner(block_size=block_size, num_gpu_blocks=num_gpu_blocks)
# Create a request with multiple blocks worth of tokens
num_tokens = block_size * 3 # 3 blocks
token_ids = [1] * num_tokens
req = runner.new_request(token_ids)
# Run first step - should fill KV cache with dummy values
scheduler_output, metadata = runner.run_single_step()
# Check that get_num_new_matched_tokens returned correct value
# Should be num_tokens - 1 (all except the last token for decode)
expected_fill_tokens = num_tokens - 1
# Check metadata has the request to fill
assert len(metadata.reqs_to_fill) == 1
assert req.request_id in metadata.reqs_to_fill
block_ids_per_group, num_tokens_to_fill = metadata.reqs_to_fill[req.request_id]
assert num_tokens_to_fill == expected_fill_tokens
# For standard attention, there's only one group
assert len(block_ids_per_group) == 1
block_ids = block_ids_per_group[0]
# Calculate expected number of blocks
expected_num_blocks = (expected_fill_tokens + block_size - 1) // block_size
assert len(block_ids) == expected_num_blocks
# Verify KV caches were filled with constant value
for layer_name, kv_cache in runner.kv_caches.items():
for block_id in block_ids:
# Check that the block was filled
block_data = kv_cache[block_id]
# Should be filled with constant value 0.015
assert torch.allclose(block_data, torch.tensor(0.015))
def test_decode_bench_connector_no_refill():
"""Test that DecodeBenchConnector only fills once per request."""
block_size = 16
num_gpu_blocks = 100
runner = DecodeBenchTestRunner(block_size=block_size, num_gpu_blocks=num_gpu_blocks)
# Create a request
num_tokens = block_size * 2
token_ids = [1] * num_tokens
runner.new_request(token_ids)
# Run first step - should fill KV cache
_, metadata1 = runner.run_single_step()
assert len(metadata1.reqs_to_fill) == 1
# Run second step - should NOT fill again (already filled)
_, metadata2 = runner.run_single_step()
assert len(metadata2.reqs_to_fill) == 0
def test_decode_bench_connector_single_token():
"""Test DecodeBenchConnector with single token request."""
block_size = 16
num_gpu_blocks = 100
runner = DecodeBenchTestRunner(block_size=block_size, num_gpu_blocks=num_gpu_blocks)
# Create a request with just 1 token
# Should not fill anything (need at least 2 tokens: 1 to fill, 1 to decode)
token_ids = [1]
runner.new_request(token_ids)
# Run step - should NOT fill KV cache
_, metadata = runner.run_single_step()
assert len(metadata.reqs_to_fill) == 0
def test_decode_bench_connector_two_tokens():
"""Test DecodeBenchConnector with two token request."""
block_size = 16
num_gpu_blocks = 100
runner = DecodeBenchTestRunner(block_size=block_size, num_gpu_blocks=num_gpu_blocks)
# Create a request with 2 tokens
# Should fill 1 token (first token), decode the second
token_ids = [1, 2]
req = runner.new_request(token_ids)
# Run step
_, metadata = runner.run_single_step()
assert len(metadata.reqs_to_fill) == 1
assert req.request_id in metadata.reqs_to_fill
block_ids_per_group, num_tokens_to_fill = metadata.reqs_to_fill[req.request_id]
assert num_tokens_to_fill == 1
# For standard attention, there's only one group
assert len(block_ids_per_group) == 1
assert len(block_ids_per_group[0]) == 1 # 1 token needs 1 block
def test_decode_bench_connector_large_context():
"""Test DecodeBenchConnector with large context size."""
block_size = 16
num_gpu_blocks = 1000
runner = DecodeBenchTestRunner(block_size=block_size, num_gpu_blocks=num_gpu_blocks)
# Create a request with many blocks
num_blocks = 20
num_tokens = block_size * num_blocks
token_ids = list(range(num_tokens))
req = runner.new_request(token_ids)
# Run step
_, metadata = runner.run_single_step()
assert len(metadata.reqs_to_fill) == 1
assert req.request_id in metadata.reqs_to_fill
block_ids_per_group, num_tokens_to_fill = metadata.reqs_to_fill[req.request_id]
# Should fill all tokens except the last one
expected_fill_tokens = num_tokens - 1
assert num_tokens_to_fill == expected_fill_tokens
# For standard attention, there's only one group
assert len(block_ids_per_group) == 1
block_ids = block_ids_per_group[0]
# Calculate expected number of blocks
expected_num_blocks = (expected_fill_tokens + block_size - 1) // block_size
assert len(block_ids) == expected_num_blocks
# Verify blocks were filled
for layer_name, kv_cache in runner.kv_caches.items():
for block_id in block_ids:
block_data = kv_cache[block_id]
assert torch.allclose(block_data, torch.tensor(0.015))
def test_decode_bench_connector_multiple_requests():
"""Test DecodeBenchConnector with multiple sequential requests."""
block_size = 16
num_gpu_blocks = 100
runner = DecodeBenchTestRunner(block_size=block_size, num_gpu_blocks=num_gpu_blocks)
# First request
req1 = runner.new_request([1] * (block_size * 2))
_, metadata1 = runner.run_single_step()
assert len(metadata1.reqs_to_fill) == 1
assert req1.request_id in metadata1.reqs_to_fill
# Complete first request
while runner.scheduler.running:
runner.run_single_step()
# Add EOS to finish
scheduler_output = runner.scheduler.schedule()
model_runner_output = create_model_runner_output(
reqs=runner.scheduler.running,
token_id=EOS_TOKEN_ID,
use_eos=True,
)
runner.scheduler.update_from_output(scheduler_output, model_runner_output)
# Second request - should also get filled
req2 = runner.new_request([2] * (block_size * 3))
_, metadata2 = runner.run_single_step()
assert len(metadata2.reqs_to_fill) == 1
assert req2.request_id in metadata2.reqs_to_fill
# Different request should have different metadata
_, num_tokens1 = metadata1.reqs_to_fill[req1.request_id]
_, num_tokens2 = metadata2.reqs_to_fill[req2.request_id]
assert num_tokens1 == block_size * 2 - 1
assert num_tokens2 == block_size * 3 - 1
def test_decode_bench_connector_partial_block():
"""Test DecodeBenchConnector with partial block filling."""
block_size = 16
num_gpu_blocks = 100
runner = DecodeBenchTestRunner(block_size=block_size, num_gpu_blocks=num_gpu_blocks)
# Create a request that doesn't align to block boundaries
# e.g., 2.5 blocks worth of tokens
num_tokens = block_size * 2 + block_size // 2
token_ids = [1] * num_tokens
req = runner.new_request(token_ids)
# Run step
_, metadata = runner.run_single_step()
assert len(metadata.reqs_to_fill) == 1
assert req.request_id in metadata.reqs_to_fill
block_ids_per_group, num_tokens_to_fill = metadata.reqs_to_fill[req.request_id]
# Should fill all tokens except the last one
expected_fill_tokens = num_tokens - 1
assert num_tokens_to_fill == expected_fill_tokens
# For standard attention, there's only one group
assert len(block_ids_per_group) == 1
block_ids = block_ids_per_group[0]
# Should allocate 3 blocks to hold the partial data
expected_num_blocks = 3
assert len(block_ids) == expected_num_blocks
def test_decode_bench_connector_concurrent_requests():
"""Test DecodeBenchConnector with multiple concurrent requests in the same batch."""
block_size = 16
num_gpu_blocks = 1000
runner = DecodeBenchTestRunner(block_size=block_size, num_gpu_blocks=num_gpu_blocks)
# Create multiple requests that will be batched together
req1 = runner.new_request([1] * (block_size * 2))
req2 = runner.new_request([2] * (block_size * 3))
req3 = runner.new_request([3] * (block_size * 1))
# Run first step - all requests should be filled concurrently
_, metadata = runner.run_single_step()
# All three requests should be in the metadata
assert len(metadata.reqs_to_fill) == 3
assert req1.request_id in metadata.reqs_to_fill
assert req2.request_id in metadata.reqs_to_fill
assert req3.request_id in metadata.reqs_to_fill
# Verify each request has correct fill info
block_ids_per_group1, num_tokens1 = metadata.reqs_to_fill[req1.request_id]
block_ids_per_group2, num_tokens2 = metadata.reqs_to_fill[req2.request_id]
block_ids_per_group3, num_tokens3 = metadata.reqs_to_fill[req3.request_id]
# Verify token counts (all tokens except last one)
assert num_tokens1 == block_size * 2 - 1
assert num_tokens2 == block_size * 3 - 1
assert num_tokens3 == block_size * 1 - 1
# Verify block counts for each request
assert len(block_ids_per_group1[0]) == 2 # 2 blocks
assert len(block_ids_per_group2[0]) == 3 # 3 blocks
assert len(block_ids_per_group3[0]) == 1 # 1 block
# Verify all blocks are filled in KV cache
for req_id, (block_ids_per_group, _) in metadata.reqs_to_fill.items():
block_ids = block_ids_per_group[0]
for layer_name, kv_cache in runner.kv_caches.items():
for block_id in block_ids:
block_data = kv_cache[block_id]
assert torch.allclose(block_data, torch.tensor(0.015))
# Run second step - should NOT fill again (already filled)
_, metadata2 = runner.run_single_step()
assert len(metadata2.reqs_to_fill) == 0
if __name__ == "__main__":
pytest.main([__file__, "-v"])

View File

@@ -0,0 +1,147 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Callable
from unittest.mock import Mock
import pytest
from vllm.v1.core.sched.scheduler import Scheduler
from vllm.v1.request import FinishReason, Request, RequestStatus
from .utils import (
create_model_runner_output,
create_request,
create_scheduler,
create_vllm_config,
)
pytestmark = pytest.mark.cpu_test
def _make_get_num_new_matched_tokens(
req_num_new_matched_tokens: dict[str, int],
async_load: bool,
) -> Callable[[Request, int], tuple[int, bool]]:
def get_num_new_matched_tokens(request: Request, _: int) -> tuple[int, bool]:
value = req_num_new_matched_tokens.get(request.request_id, 0)
return value, async_load
return get_num_new_matched_tokens
@pytest.fixture
def fail_scheduler():
"""scheduler with kv_load_failure_policy='fail'"""
vllm_config = create_vllm_config()
vllm_config.kv_transfer_config.kv_load_failure_policy = "fail"
return create_scheduler(vllm_config)
def test_error_propagation_sync_load(fail_scheduler: Scheduler):
"""test invalid_block_ids with fail policy -> FINISHED_ERROR (sync load)"""
num_prompt_blocks = 100
num_external_computed_blocks = 99
invalid_block_idx = 50
num_prompt_tokens = num_prompt_blocks * fail_scheduler.block_size
num_external_computed_tokens = (
num_external_computed_blocks * fail_scheduler.block_size
)
request = create_request(num_tokens=num_prompt_tokens)
fail_scheduler.add_request(request=request)
req_num_new_matched_tokens = {
request.request_id: num_external_computed_tokens,
}
fail_scheduler.connector = Mock()
fail_scheduler.connector.get_num_new_matched_tokens.side_effect = (
_make_get_num_new_matched_tokens(req_num_new_matched_tokens, False)
)
fail_scheduler.connector.request_finished.return_value = (False, None)
fail_scheduler.connector.take_events.return_value = ()
scheduler_output = fail_scheduler.schedule()
assert len(fail_scheduler.running) == 1
assert len(scheduler_output.scheduled_new_reqs) == 1
assert fail_scheduler.connector.get_num_new_matched_tokens.call_count == 1
req_block_ids = scheduler_output.scheduled_new_reqs[0].block_ids[0]
invalid_block_ids = {req_block_ids[invalid_block_idx]}
model_runner_output = create_model_runner_output(
[request],
invalid_block_ids=invalid_block_ids,
use_eos=True,
)
outputs = fail_scheduler.update_from_output(scheduler_output, model_runner_output)
assert request.status == RequestStatus.FINISHED_ERROR
assert request.get_finished_reason() == FinishReason.ERROR
assert len(outputs) == 1
engine_outputs = next(iter(outputs.values()))
assert len(engine_outputs.outputs) == 1
output = engine_outputs.outputs[0]
assert output.request_id == request.request_id
assert output.finish_reason == FinishReason.ERROR
assert len(fail_scheduler.running) == 0
def test_error_propagation_async_load(fail_scheduler: Scheduler):
"""test invalid_block_ids with fail policy -> FINISHED_ERROR (async load)"""
num_prompt_blocks = 100
num_external_computed_blocks = 99
invalid_block_idx = 50
num_prompt_tokens = num_prompt_blocks * fail_scheduler.block_size
num_external_computed_tokens = (
num_external_computed_blocks * fail_scheduler.block_size
)
request = create_request(num_tokens=num_prompt_tokens)
fail_scheduler.add_request(request=request)
req_num_new_matched_tokens = {
request.request_id: num_external_computed_tokens,
}
fail_scheduler.connector = Mock()
fail_scheduler.connector.get_num_new_matched_tokens.side_effect = (
_make_get_num_new_matched_tokens(req_num_new_matched_tokens, True)
)
fail_scheduler.connector.request_finished.return_value = (False, None)
fail_scheduler.connector.take_events.return_value = ()
scheduler_output = fail_scheduler.schedule()
assert len(fail_scheduler.waiting) == 1
assert request.status == RequestStatus.WAITING_FOR_REMOTE_KVS
assert request.num_computed_tokens == 0
(req_block_ids,) = fail_scheduler.kv_cache_manager.get_block_ids(request.request_id)
invalid_block_ids = {req_block_ids[invalid_block_idx]}
model_runner_output = create_model_runner_output(
reqs=[],
finished_recving=set(),
invalid_block_ids=invalid_block_ids,
use_eos=True,
)
outputs = fail_scheduler.update_from_output(scheduler_output, model_runner_output)
assert request.status == RequestStatus.FINISHED_ERROR
assert request.get_finished_reason() == FinishReason.ERROR
assert len(outputs) == 1
engine_outputs = next(iter(outputs.values()))
assert len(engine_outputs.outputs) == 1
output = engine_outputs.outputs[0]
assert output.request_id == request.request_id
assert output.finish_reason == FinishReason.ERROR
assert len(fail_scheduler.waiting) == 0

View File

@@ -0,0 +1,256 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import asdict
from typing import NamedTuple
import pytest
from PIL import Image
from vllm import LLM, EngineArgs, SamplingParams
from vllm.assets.image import ImageAsset
from vllm.config import KVTransferConfig
from vllm.multimodal.utils import encode_image_base64
from vllm.platforms import current_platform
MODEL_NAME = "RedHatAI/Qwen2.5-VL-3B-Instruct-quantized.w8a8"
SAMPLING_PARAMS = SamplingParams(temperature=0.0, top_k=1, max_tokens=128)
TEXT_PROMPTS = [
"What's in the image(s)? Around 30 words. What's special in 2nd image?",
"The future of AI is",
]
class InputCase(NamedTuple):
text: str
img: list[Image]
expected_len: int
info: str
def _check_path_len(path):
"""Return the latest length in path"""
return len(list(path.iterdir()))
def _list_path(path):
"""Return the list of foldername (hashes generated) under the path"""
return list(path.iterdir())
def run_test(
tmp_path,
processor,
llm: LLM,
question: str,
image_urls: list[Image],
expected_len: int,
info: str,
):
"""
One individual test to process the prompt and output base on 1 set of input
Then check if the length in the storage path matches the expected length
`info` introduces details or purpose of the individual test
"""
print(f"***info: {info}***")
print(f"**Expected storage path length after llm generate: {expected_len}**")
process_prompt(processor, llm, question, image_urls)
print(f"Path matched expected length: {_check_path_len(tmp_path)}")
print(f"Hashes under the storage path: {_list_path(tmp_path)}")
assert _check_path_len(tmp_path) == expected_len, (
f"Expect storage path length {expected_len} ;",
f"but end up {_check_path_len(tmp_path)} instead. ",
f"Info: {info}",
)
def process_prompt(processor, llm: LLM, question: str, image_urls: list[Image]):
"""
Form the prompt based on the text and image input, then llm generate output
"""
placeholders = [
{
"type": "image_url",
"image_url": {"url": f"data:image;base64,{encode_image_base64(image_pil)}"},
}
for image_pil in image_urls
]
messages = [
{"role": "system", "content": "You are a helpful assistant."},
{
"role": "user",
"content": [
*placeholders,
{"type": "text", "text": question},
],
},
]
prompt = processor.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True
)
outputs = llm.generate(
{
"prompt": prompt,
**({"multi_modal_data": {"image": [*image_urls]}} if image_urls else {}),
},
sampling_params=SAMPLING_PARAMS,
)
print("-" * 50)
print("Output:")
for o in outputs:
generated_text = o.outputs[0].text
print(generated_text)
print("-" * 50)
@pytest.mark.skipif(
current_platform.is_rocm(),
reason=(
"hipErrorLaunchFailure when running this test, see issue:"
"https://github.com/ROCm/pytorch/issues/2822"
),
)
def test_shared_storage_connector_hashes(tmp_path):
"""
Tests that ExampleConnector saves KV to the storage locations
with proper hashes; that are unique for inputs with identical text but
different images (same size), or same multiple images but different orders.
"""
# Using tmp_path as the storage path to store KV
print(f"KV storage path at: {str(tmp_path)}")
# Configure the ExampleConnector
kv_transfer_config = KVTransferConfig(
kv_connector="ExampleConnector",
kv_role="kv_both",
kv_connector_extra_config={"shared_storage_path": str(tmp_path)},
)
engine_args = EngineArgs(
model=MODEL_NAME,
max_model_len=8192,
max_num_seqs=1,
gpu_memory_utilization=0.4,
enforce_eager=True,
kv_transfer_config=kv_transfer_config,
limit_mm_per_prompt={"image": 2},
)
# don't put this import at the top level
# it will call torch.cuda.device_count()
from transformers import AutoProcessor # noqa: F401
# Create processor to handle the chat prompt
processor = AutoProcessor.from_pretrained(MODEL_NAME)
# Prepare images for the tests
# Resize to the same size to check hashes correctness
image_1 = ImageAsset("stop_sign").pil_image.resize((1280, 720))
image_2 = ImageAsset("cherry_blossom").pil_image.resize((1280, 720))
# Make sure that they are not the same picture
assert image_1 != image_2, "The images should not be identical"
# Create the LLM instance
engine_args = asdict(engine_args)
llm = LLM(**engine_args)
# Prepare the input cases
input_cases = [
InputCase(
text=TEXT_PROMPTS[0],
img=[image_1],
expected_len=1,
info="image_1 single input the first time.",
),
InputCase(
text=TEXT_PROMPTS[0],
img=[image_2],
expected_len=2,
info=(
"image_2 single input the first time. "
"It is in same pixel size with image_1, yet it "
"should be able to form a new unique hash."
),
),
InputCase(
text=TEXT_PROMPTS[0],
img=[image_1],
expected_len=2,
info=(
"image_1 single input the 2nd time. "
"It should not form another new hash."
),
),
InputCase(
text=TEXT_PROMPTS[0],
img=[image_2],
expected_len=2,
info=(
"image_2 single input the 2nd time. "
"It should not form another new hash."
),
),
InputCase(
text=TEXT_PROMPTS[0],
img=[image_1, image_2],
expected_len=3,
info="image_1 with image_2 input the first time.",
),
InputCase(
text=TEXT_PROMPTS[0],
img=[image_2, image_1],
expected_len=4,
info="The image order is swapped. Should form new hash.",
),
InputCase(
text=TEXT_PROMPTS[0],
img=[image_1, image_2],
expected_len=4,
info=(
"[image_1, image_2] input the 2nd time. "
"It should not form another new hash."
),
),
InputCase(
text=TEXT_PROMPTS[0],
img=[image_2, image_1],
expected_len=4,
info=(
"[image_2, image_1] input the 2nd time. "
"It should not form another new hash."
),
),
InputCase(
text=TEXT_PROMPTS[0],
img=[],
expected_len=5,
info="Pure text input test as a case-control",
),
InputCase(
text=TEXT_PROMPTS[0],
img=[],
expected_len=5,
info="Identical pure text input as a case-control",
),
InputCase(
text=TEXT_PROMPTS[1],
img=[],
expected_len=6,
info="Another pure text input as a case-control",
),
]
# Run tests
for case_id, (text, img, expected_len, info) in enumerate(input_cases):
print("\n", "=" * 25, f"Below running input case: {case_id}", "=" * 25)
run_test(tmp_path, processor, llm, text, img, expected_len, info)
print("All tests passed successfully!")

View File

@@ -0,0 +1,454 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Tests for correctness in invalid block handling.
These tests verify correct behavior in three scenarios:
1. Sync recompute case: Blocks should not be freed for running requests
that need to recompute invalid blocks
2. Sync fail case: Invalid blocks must be evicted from cache when request fails
3. Async recompute case: Invalid blocks should not be cached after transfer
"""
from collections.abc import Callable
from unittest.mock import Mock
import pytest
from vllm.v1.core.sched.scheduler import Scheduler
from vllm.v1.request import FinishReason, Request, RequestStatus
from .utils import (
create_model_runner_output,
create_request,
create_scheduler,
create_vllm_config,
)
pytestmark = pytest.mark.cpu_test
def _make_get_num_new_matched_tokens(
req_num_new_matched_tokens: dict[str, int],
async_load: bool,
) -> Callable[[Request, int], tuple[int, bool]]:
def get_num_new_matched_tokens(request: Request, _: int) -> tuple[int, bool]:
value = req_num_new_matched_tokens.get(request.request_id, 0)
return value, async_load
return get_num_new_matched_tokens
@pytest.fixture
def fail_scheduler():
"""scheduler with kv_load_failure_policy='fail'"""
vllm_config = create_vllm_config()
vllm_config.kv_transfer_config.kv_load_failure_policy = "fail"
return create_scheduler(vllm_config)
@pytest.fixture
def recompute_scheduler():
"""scheduler with kv_load_failure_policy='recompute'"""
vllm_config = create_vllm_config()
vllm_config.kv_transfer_config.kv_load_failure_policy = "recompute"
return create_scheduler(vllm_config)
def test_sync_recompute_blocks_not_freed_for_running_requests(
recompute_scheduler: Scheduler,
):
"""
Test sync recompute case - blocks must not be freed for running requests.
When a running request has invalid blocks and retry_policy is 'recompute':
1. Request should remain in RUNNING state
2. num_computed_tokens should be truncated to invalid block boundary
3. Blocks should NOT be freed (request still needs them for recomputation)
4. Request should remain in scheduler.requests and scheduler.running
"""
num_prompt_blocks = 100
num_external_computed_blocks = 99
invalid_block_idx = 50
num_prompt_tokens = num_prompt_blocks * recompute_scheduler.block_size
num_external_computed_tokens = (
num_external_computed_blocks * recompute_scheduler.block_size
)
request = create_request(num_tokens=num_prompt_tokens)
recompute_scheduler.add_request(request=request)
req_num_new_matched_tokens = {
request.request_id: num_external_computed_tokens,
}
# mock connector indicating sync load
recompute_scheduler.connector = Mock()
recompute_scheduler.connector.get_num_new_matched_tokens.side_effect = (
_make_get_num_new_matched_tokens(req_num_new_matched_tokens, False)
)
recompute_scheduler.connector.request_finished.return_value = (False, None)
recompute_scheduler.connector.take_events.return_value = ()
scheduler_output = recompute_scheduler.schedule()
# request should be running with sync KV load
assert len(recompute_scheduler.running) == 1
assert len(scheduler_output.scheduled_new_reqs) == 1
assert request.status == RequestStatus.RUNNING
# get the allocated block IDs before invalid blocks are reported
req_block_ids = scheduler_output.scheduled_new_reqs[0].block_ids[0]
invalid_block_ids = {req_block_ids[invalid_block_idx]}
# store original num_computed_tokens for comparison
original_num_computed_tokens = request.num_computed_tokens
model_runner_output = create_model_runner_output(
[request],
invalid_block_ids=invalid_block_ids,
use_eos=False, # not finished - should continue running
)
outputs = recompute_scheduler.update_from_output(
scheduler_output, model_runner_output
)
# critical assertions for recompute case:
# 1. request should still be RUNNING (not finished, not aborted)
assert request.status == RequestStatus.RUNNING, (
f"Request should remain RUNNING for recompute, got {request.status}"
)
# 2. num_computed_tokens should be truncated to first invalid block
expected_truncated_tokens = invalid_block_idx * recompute_scheduler.block_size
assert request.num_computed_tokens == expected_truncated_tokens, (
f"num_computed_tokens should be truncated to {expected_truncated_tokens}, "
f"got {request.num_computed_tokens}"
)
assert request.num_computed_tokens < original_num_computed_tokens, (
"num_computed_tokens should be reduced after invalid block detection"
)
# 3. no output should be generated (request is still running)
# the request should be skipped in the output loop
assert len(outputs) == 0 or request.request_id not in [
out.request_id for outs in outputs.values() for out in outs.outputs
], "No output should be generated for recompute requests"
# 4. request should still be in running queue
assert request in recompute_scheduler.running, (
"Request should remain in running queue for recomputation"
)
# 5. request should still be in scheduler.requests (not deleted)
assert request.request_id in recompute_scheduler.requests, (
"Request should not be deleted from scheduler.requests"
)
# 6. blocks should NOT be freed - verify blocks are still allocated
try:
allocated_blocks = recompute_scheduler.kv_cache_manager.get_block_ids(
request.request_id
)
assert allocated_blocks is not None
assert len(allocated_blocks[0]) > 0, (
"Blocks should still be allocated for recomputation"
)
except KeyError:
pytest.fail(
"Blocks were freed incorrectly! Running requests need their blocks "
"to recompute invalid portions."
)
# 7. verify request can be rescheduled in next step
scheduler_output_2 = recompute_scheduler.schedule()
# request should appear in the new schedule to recompute invalid blocks
scheduled_req_ids = [
req.request_id for req in scheduler_output_2.scheduled_new_reqs
]
if scheduler_output_2.num_scheduled_tokens:
scheduled_req_ids.extend(scheduler_output_2.num_scheduled_tokens.keys())
assert (
request.request_id in scheduled_req_ids or len(recompute_scheduler.running) > 0
), "Request should be reschedulable for recomputation"
def test_sync_fail_invalid_blocks_evicted(fail_scheduler: Scheduler):
"""
Test sync fail case - invalid blocks must be evicted from cache.
When a request fails with policy='fail' and has invalid blocks from sync loading:
1. Request should be finished with FINISHED_ERROR
2. Invalid blocks should be evicted from the KV cache
3. Valid blocks (if shared) should remain in cache
4. Future requests should not reuse the invalid blocks
This test verifies that invalid blocks are properly evicted to prevent
cache corruption and reuse of invalid data.
"""
num_prompt_blocks = 100
num_external_computed_blocks = 99
invalid_block_idx = 50
num_prompt_tokens = num_prompt_blocks * fail_scheduler.block_size
num_external_computed_tokens = (
num_external_computed_blocks * fail_scheduler.block_size
)
request = create_request(num_tokens=num_prompt_tokens)
fail_scheduler.add_request(request=request)
req_num_new_matched_tokens = {
request.request_id: num_external_computed_tokens,
}
# mock connector indicating sync load
fail_scheduler.connector = Mock()
fail_scheduler.connector.get_num_new_matched_tokens.side_effect = (
_make_get_num_new_matched_tokens(req_num_new_matched_tokens, False)
)
fail_scheduler.connector.request_finished.return_value = (False, None)
fail_scheduler.connector.take_events.return_value = ()
scheduler_output = fail_scheduler.schedule()
# request should be running with sync KV load
assert len(fail_scheduler.running) == 1
assert request.status == RequestStatus.RUNNING
# get allocated block IDs
req_block_ids = scheduler_output.scheduled_new_reqs[0].block_ids[0]
invalid_block_id = req_block_ids[invalid_block_idx]
invalid_block_ids = {invalid_block_id}
# verify the block is in the block pool before we report it as invalid
block = fail_scheduler.kv_cache_manager.block_pool.blocks[invalid_block_id]
assert block is not None
# report invalid blocks - request should fail
model_runner_output = create_model_runner_output(
[request],
invalid_block_ids=invalid_block_ids,
use_eos=True,
)
outputs = fail_scheduler.update_from_output(scheduler_output, model_runner_output)
# verify request is finished with error
assert request.status == RequestStatus.FINISHED_ERROR
assert request.get_finished_reason() == FinishReason.ERROR
# verify output is generated
assert len(outputs) == 1
engine_outputs = next(iter(outputs.values()))
assert len(engine_outputs.outputs) == 1
output = engine_outputs.outputs[0]
assert output.request_id == request.request_id
assert output.finish_reason == FinishReason.ERROR
# verify the request was removed from scheduler
assert request.request_id not in fail_scheduler.requests
assert len(fail_scheduler.running) == 0
# critical: verify invalid block was actually freed from cache
# this is the key assertion - the invalid block should no longer be
# tracked by the KV cache manager for this request
# if it's still there, a future request could reuse the invalid data
try:
block_ids = fail_scheduler.kv_cache_manager.get_block_ids(request.request_id)
# if we get here, check if blocks were actually freed
if block_ids is not None and len(block_ids[0]) > 0:
pytest.fail(
f"Invalid blocks still tracked for finished request! "
f"Request {request.request_id} should have been freed but "
f"still has {len(block_ids[0])} blocks allocated."
)
# blocks list exists but is empty - this is fine, they were freed
except KeyError:
# expected - request completely removed from tracking
pass
# critical: verify invalid block was evicted from prefix cache
# the block should no longer have a hash (hash is reset on eviction)
assert block.block_hash is None, (
f"Invalid block {invalid_block_id} should have been evicted from cache "
f"(hash should be None), but hash is still {block.block_hash}"
)
def test_async_recompute_blocks_not_cached_when_invalid(
recompute_scheduler: Scheduler,
):
"""
Test async recompute case - invalid blocks not cached after transfer.
When async KV loading has invalid blocks and retry_policy is 'recompute':
1. Blocks are allocated but not cached yet
2. When async transfer completes, only valid blocks should be cached
3. Invalid blocks should never enter the prefix cache
This test verifies correctness, the failed_recving_kv_req_ids protection
ensures only valid blocks are cached when the transfer completes, and we
only evict blocks from cache that are already hashed in the block table.
"""
from unittest.mock import patch
num_prompt_blocks = 100
num_external_computed_blocks = 99
invalid_block_idx = 50
num_prompt_tokens = num_prompt_blocks * recompute_scheduler.block_size
num_external_computed_tokens = (
num_external_computed_blocks * recompute_scheduler.block_size
)
request = create_request(num_tokens=num_prompt_tokens)
recompute_scheduler.add_request(request=request)
req_num_new_matched_tokens = {
request.request_id: num_external_computed_tokens,
}
# mock connector indicating async load
recompute_scheduler.connector = Mock()
recompute_scheduler.connector.get_num_new_matched_tokens.side_effect = (
_make_get_num_new_matched_tokens(req_num_new_matched_tokens, True)
)
recompute_scheduler.connector.request_finished.return_value = (False, None)
recompute_scheduler.connector.take_events.return_value = ()
scheduler_output = recompute_scheduler.schedule()
# request should be waiting for remote KVs
assert len(recompute_scheduler.waiting) == 1
assert request.status == RequestStatus.WAITING_FOR_REMOTE_KVS
assert request.num_computed_tokens == 0
# get the allocated block IDs
(req_block_ids,) = recompute_scheduler.kv_cache_manager.get_block_ids(
request.request_id
)
invalid_block_id = req_block_ids[invalid_block_idx]
invalid_block_ids = {invalid_block_id}
# get the block object to verify it's not cached yet and stays uncached
block = recompute_scheduler.kv_cache_manager.block_pool.blocks[invalid_block_id]
# verify block has no hash before invalid blocks are reported
assert block.block_hash is None, (
"Async loading blocks should not be cached yet (no hash)"
)
# report invalid blocks (transfer not finished yet)
model_runner_output = create_model_runner_output(
reqs=[],
finished_recving=None, # transfer NOT finished
invalid_block_ids=invalid_block_ids,
use_eos=False,
)
# critical: spy on evict_blocks to verify it's NOT called for async blocks
original_evict_blocks = recompute_scheduler.kv_cache_manager.evict_blocks
evict_blocks_calls = []
def evict_blocks_spy(block_ids):
evict_blocks_calls.append(set(block_ids))
return original_evict_blocks(block_ids)
with patch.object(
recompute_scheduler.kv_cache_manager, "evict_blocks", evict_blocks_spy
):
recompute_scheduler.update_from_output(scheduler_output, model_runner_output)
# verify evict_blocks was NOT called (async blocks excluded from eviction)
assert len(evict_blocks_calls) == 0, (
f"evict_blocks should not be called for async-only invalid blocks, "
f"but was called {len(evict_blocks_calls)} time(s) with {evict_blocks_calls}"
)
# request should still be waiting (not finished with error due to recompute policy)
assert request.status == RequestStatus.WAITING_FOR_REMOTE_KVS
assert request.request_id in recompute_scheduler.failed_recving_kv_req_ids
# verify num_computed_tokens was truncated to before invalid block
expected_valid_tokens = invalid_block_idx * recompute_scheduler.block_size
assert request.num_computed_tokens == expected_valid_tokens
# verify invalid block still has no hash (was not evicted)
assert block.block_hash is None, (
f"Async loading blocks shouldn't be cached or evicted. "
f"Block {invalid_block_id} hash should be None but is {block.block_hash}"
)
# now simulate async transfer completing
model_runner_output_2 = create_model_runner_output(
reqs=[],
finished_recving={request.request_id},
invalid_block_ids=None,
use_eos=False,
)
recompute_scheduler.update_from_output(scheduler_output, model_runner_output_2)
# verify request is now marked as finished receiving and ready to be processed
assert request.request_id in recompute_scheduler.finished_recving_kv_req_ids
assert request.request_id in recompute_scheduler.failed_recving_kv_req_ids
# critical: verify invalid block still has no hash before recompute
# the async transfer invalid data was never cached
assert block.block_hash is None, (
f"Invalid block {invalid_block_id} should not be cached before recompute "
f"(hash should be None), but hash is {block.block_hash}"
)
# critical end-to-end test: spy on cache_blocks to verify it's called with
# the truncated num_computed_tokens value
original_cache_blocks = recompute_scheduler.kv_cache_manager.cache_blocks
cache_blocks_calls = []
def cache_blocks_spy(req, num_tokens):
cache_blocks_calls.append((req.request_id, num_tokens))
return original_cache_blocks(req, num_tokens)
with patch.object(
recompute_scheduler.kv_cache_manager, "cache_blocks", cache_blocks_spy
):
# call schedule() again - this triggers _update_waiting_for_remote_kv()
# which should call cache_blocks with the truncated value
recompute_scheduler.schedule()
# verify cache_blocks was called with the truncated value
assert len(cache_blocks_calls) == 1, (
f"cache_blocks should be called exactly once, "
f"got {len(cache_blocks_calls)} calls"
)
cached_req_id, cached_num_tokens = cache_blocks_calls[0]
assert cached_req_id == request.request_id
assert cached_num_tokens == expected_valid_tokens, (
f"cache_blocks should be called with truncated value {expected_valid_tokens}, "
f"but was called with {cached_num_tokens}"
)
# request should now be RUNNING (scheduled immediately after transfer completes)
# the flow is: WAITING_FOR_REMOTE_KVS -> WAITING -> RUNNING in same schedule() call
assert request.status == RequestStatus.RUNNING
# num_computed_tokens should be >= expected_valid_tokens because the scheduler
# will schedule additional new tokens (up to max_num_batched_tokens) for the request
assert request.num_computed_tokens >= expected_valid_tokens, (
f"num_computed_tokens should be at least {expected_valid_tokens}, "
f"got {request.num_computed_tokens}"
)
# request should no longer be in the failed/finished receiving sets
assert request.request_id not in recompute_scheduler.failed_recving_kv_req_ids
assert request.request_id not in recompute_scheduler.finished_recving_kv_req_ids
# request should be in the running queue
assert request in recompute_scheduler.running

View File

@@ -0,0 +1,60 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from vllm.distributed.kv_transfer.kv_connector.v1.example_connector import ( # noqa: E501
ExampleConnectorMetadata,
)
from vllm.distributed.kv_transfer.kv_transfer_state import (
ensure_kv_transfer_initialized,
get_kv_transfer_group,
)
from vllm.v1.core.sched.output import CachedRequestData, SchedulerOutput
from vllm.v1.worker.kv_connector_model_runner_mixin import KVConnectorModelRunnerMixin
# Importing utils registers TestExampleConnector with the factory
from .utils import create_vllm_config
def _make_empty_scheduler_output():
return SchedulerOutput(
scheduled_new_reqs=[],
scheduled_cached_reqs=CachedRequestData.make_empty(),
num_scheduled_tokens={},
total_num_scheduled_tokens=0,
scheduled_spec_decode_tokens={},
scheduled_encoder_inputs={},
num_common_prefix_blocks=[],
finished_req_ids=set(),
free_encoder_mm_hashes=[],
kv_connector_metadata=ExampleConnectorMetadata(),
)
def test_kv_connector_mixin_clears_metadata():
vllm_config = create_vllm_config()
vllm_config.kv_transfer_config.kv_connector = "TestExampleConnector"
vllm_config.kv_transfer_config.kv_role = "kv_both"
vllm_config.kv_transfer_config.kv_connector_extra_config["name"] = "unit"
# Initialize the global connector instance
ensure_kv_transfer_initialized(vllm_config)
try:
# Minimal scheduler output with empty metadata; mixin should still
# bind/clear metadata even if no loads happen
scheduler_output = _make_empty_scheduler_output()
# Invoke the no-forward path which uses the mixin context manager
KVConnectorModelRunnerMixin.kv_connector_no_forward(
scheduler_output, vllm_config
)
# Verify clear_connector_metadata was called on the connector
connector = get_kv_transfer_group()
assert connector._connector_metadata is None
# Test connector wrapper records method calls
assert connector.call_record.get("bind_connector_metadata", 0) == 1
assert connector.call_record.get("clear_connector_metadata", 0) == 1
finally:
# Ensure we clean up the global connector between tests
KVConnectorModelRunnerMixin.ensure_kv_transfer_shutdown()

View File

@@ -0,0 +1,335 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Callable
from unittest.mock import Mock
import pytest
from vllm.v1.core.sched.scheduler import Scheduler
from vllm.v1.request import Request, RequestStatus
from .utils import (
create_model_runner_output,
create_request,
create_scheduler,
create_vllm_config,
)
def _make_get_num_new_matched_tokens(
req_num_new_matched_tokens: dict[str, int],
async_load,
) -> Callable[[Request, int], tuple[int, bool]]:
def get_num_new_matched_tokens(request: Request, _: int) -> tuple[int, bool]:
value = req_num_new_matched_tokens.get(request.request_id, 0)
return value, async_load
return get_num_new_matched_tokens
@pytest.fixture
def scheduler():
vllm_config = create_vllm_config()
return create_scheduler(vllm_config)
@pytest.mark.parametrize(
"num_prompt_blocks,num_external_computed_blocks,invalid_block_idxs",
[
(100, 99, {0, 98}),
(100, 99, {50, 98}),
(100, 99, {98}),
],
)
def test_async_load_failure(
scheduler: Scheduler,
num_prompt_blocks: int,
num_external_computed_blocks: int,
invalid_block_idxs: set[int],
):
assert num_prompt_blocks >= num_external_computed_blocks
num_prompt_tokens = num_prompt_blocks * scheduler.block_size
num_external_computed_tokens = num_external_computed_blocks * scheduler.block_size
request1 = create_request(num_tokens=num_prompt_tokens)
scheduler.add_request(request=request1)
request2 = create_request(num_tokens=num_prompt_tokens)
scheduler.add_request(request=request2)
request3 = create_request(num_tokens=num_prompt_tokens)
scheduler.add_request(request=request3)
# Mock KV connector method.
# req_id -> num_external_computed_tokens
req_num_new_matched_tokens = {
request1.request_id: num_external_computed_tokens,
request2.request_id: num_external_computed_tokens,
request3.request_id: num_external_computed_tokens,
}
scheduler.connector = Mock()
scheduler.connector.get_num_new_matched_tokens.side_effect = (
_make_get_num_new_matched_tokens(req_num_new_matched_tokens, async_load=True)
)
scheduler.connector.take_events.return_value = ()
scheduler_output = scheduler.schedule()
assert len(scheduler.waiting) == 3
for request in scheduler.waiting:
assert request.num_computed_tokens == 0
assert request.status == RequestStatus.WAITING_FOR_REMOTE_KVS
assert scheduler.connector.get_num_new_matched_tokens.call_count == 3
# Simulate a failure in loading some of request2 blocks.
(req2_block_ids,) = scheduler.kv_cache_manager.get_block_ids(request2.request_id)
invalid_block_ids = {req2_block_ids[i] for i in invalid_block_idxs}
model_runner_output = create_model_runner_output(
reqs=[],
finished_recving={request1.request_id, request3.request_id},
invalid_block_ids=invalid_block_ids,
use_eos=True,
)
scheduler.update_from_output(scheduler_output, model_runner_output)
min_invalid_block_idx = min(invalid_block_idxs)
assert len(scheduler.waiting) == 3
for request in scheduler.waiting:
if request.request_id == request2.request_id:
assert request.num_computed_tokens == (
min_invalid_block_idx * scheduler.block_size
)
else:
assert request.num_computed_tokens == 0
assert request.status == RequestStatus.WAITING_FOR_REMOTE_KVS
assert scheduler.failed_recving_kv_req_ids == {request2.request_id}
assert scheduler.connector.get_num_new_matched_tokens.call_count == 3
@pytest.mark.parametrize(
"num_prompt_blocks,num_external_computed_blocks,invalid_block_idxs",
[
(100, 99, {0, 98}),
(100, 99, {50, 98}),
(100, 99, {98}),
],
)
def test_sync_load_failure(
scheduler: Scheduler,
num_prompt_blocks: int,
num_external_computed_blocks: int,
invalid_block_idxs: set[int],
):
assert num_prompt_blocks >= num_external_computed_blocks
num_prompt_tokens = num_prompt_blocks * scheduler.block_size
num_external_computed_tokens = num_external_computed_blocks * scheduler.block_size
request1 = create_request(num_tokens=num_prompt_tokens)
scheduler.add_request(request=request1)
request2 = create_request(num_tokens=num_prompt_tokens)
scheduler.add_request(request=request2)
request3 = create_request(num_tokens=num_prompt_tokens)
scheduler.add_request(request=request3)
# Mock KV connector method.
# req_id -> num_external_computed_tokens
req_num_new_matched_tokens = {
request1.request_id: num_external_computed_tokens,
request2.request_id: num_external_computed_tokens,
request3.request_id: num_external_computed_tokens,
}
scheduler.connector = Mock()
scheduler.connector.get_num_new_matched_tokens.side_effect = (
_make_get_num_new_matched_tokens(req_num_new_matched_tokens, async_load=False)
)
scheduler.connector.request_finished.return_value = (False, None)
scheduler.connector.take_events.return_value = ()
scheduler_output = scheduler.schedule()
# req_id -> num_computed_tokens
expected_computed_tokens = {
request1.request_id: num_external_computed_tokens,
request2.request_id: num_external_computed_tokens,
request3.request_id: num_external_computed_tokens,
}
assert len(scheduler.running) == 3
assert len(scheduler_output.scheduled_new_reqs) == 3
for request in scheduler_output.scheduled_new_reqs:
assert request.num_computed_tokens == expected_computed_tokens[request.req_id]
assert scheduler.connector.get_num_new_matched_tokens.call_count == 3
# Simulate a failure in loading some of request2 blocks.
req2_block_ids = scheduler_output.scheduled_new_reqs[1].block_ids[0]
invalid_block_ids = {req2_block_ids[i] for i in invalid_block_idxs}
model_runner_output = create_model_runner_output(
[request1, request2, request3],
invalid_block_ids=invalid_block_ids,
use_eos=True,
)
scheduler.update_from_output(scheduler_output, model_runner_output)
assert len(scheduler.running) == 1
assert scheduler.running[0].request_id == request2.request_id
assert scheduler.running[0].num_computed_tokens == (
min(invalid_block_idxs) * scheduler.block_size
)
assert scheduler.connector.get_num_new_matched_tokens.call_count == 3
assert scheduler.connector.request_finished.call_count == 2
@pytest.mark.parametrize(
"num_prompt_blocks,"
"num_external_computed_blocks,"
"num_common_prefix_blocks,"
"invalid_block_idxs",
[
(100, 99, 50, {0, 49}),
(100, 99, 50, {25, 49}),
(100, 99, 50, {49}),
],
)
def test_sync_load_failure_with_shared_blocks(
scheduler: Scheduler,
num_prompt_blocks: int,
num_external_computed_blocks: int,
num_common_prefix_blocks: int,
invalid_block_idxs: set[int],
):
assert num_prompt_blocks >= num_external_computed_blocks >= num_common_prefix_blocks
num_prompt_tokens = num_prompt_blocks * scheduler.block_size
num_external_computed_tokens = num_external_computed_blocks * scheduler.block_size
common_prefix_len = num_common_prefix_blocks * scheduler.block_size
request1 = create_request(
num_tokens=num_prompt_tokens, common_prefix_len=common_prefix_len
)
scheduler.add_request(request=request1)
request2 = create_request(
num_tokens=num_prompt_tokens, common_prefix_len=common_prefix_len
)
scheduler.add_request(request=request2)
# Mock KV connector method.
# req_id -> num_external_computed_tokens
req_num_new_matched_tokens = {
request1.request_id: num_external_computed_tokens,
}
scheduler.connector = Mock()
scheduler.connector.get_num_new_matched_tokens.side_effect = (
_make_get_num_new_matched_tokens(req_num_new_matched_tokens, async_load=False)
)
scheduler.connector.take_events.return_value = ()
scheduler_output = scheduler.schedule()
# req_id -> num_computed_tokens
expected_computed_tokens = {
request1.request_id: num_external_computed_tokens,
request2.request_id: common_prefix_len,
}
assert len(scheduler.running) == 2
assert len(scheduler_output.scheduled_new_reqs) == 2
for request in scheduler_output.scheduled_new_reqs:
assert request.num_computed_tokens == expected_computed_tokens[request.req_id]
assert scheduler.connector.get_num_new_matched_tokens.call_count == 2
# Simulate a failure in loading some of the shared blocks.
req1_block_ids = scheduler_output.scheduled_new_reqs[0].block_ids[0]
invalid_block_ids = {req1_block_ids[i] for i in invalid_block_idxs}
model_runner_output = create_model_runner_output(
[request1, request2], invalid_block_ids=invalid_block_ids, use_eos=True
)
scheduler.update_from_output(scheduler_output, model_runner_output)
# req_id -> num_computed_tokens
# all the common prefix blocks will be computed by request1
expected_computed_tokens = {
request1.request_id: min(invalid_block_idxs) * scheduler.block_size,
request2.request_id: common_prefix_len,
}
assert len(scheduler.running) == 2
for request in scheduler.running:
assert (
request.num_computed_tokens == expected_computed_tokens[request.request_id]
)
assert scheduler.connector.get_num_new_matched_tokens.call_count == 2
@pytest.mark.parametrize(
"num_prompt_blocks,num_external_computed_blocks,invalid_block_idxs",
[
(100, 99, {0, 50, 98}),
(100, 99, {98, 50, 0}),
],
)
def test_async_progressive_load_failure(
scheduler: Scheduler,
num_prompt_blocks: int,
num_external_computed_blocks: int,
invalid_block_idxs: set[int],
):
assert num_prompt_blocks >= num_external_computed_blocks
num_prompt_tokens = num_prompt_blocks * scheduler.block_size
num_external_computed_tokens = num_external_computed_blocks * scheduler.block_size
request = create_request(num_tokens=num_prompt_tokens)
scheduler.add_request(request=request)
# Mock KV connector method.
# req_id -> num_external_computed_tokens
req_num_new_matched_tokens = {
request.request_id: num_external_computed_tokens,
}
scheduler.connector = Mock()
scheduler.connector.get_num_new_matched_tokens.side_effect = (
_make_get_num_new_matched_tokens(req_num_new_matched_tokens, async_load=True)
)
scheduler.connector.take_events.return_value = ()
scheduler_output = scheduler.schedule()
assert len(scheduler.waiting) == 1
assert scheduler.waiting.peek_request().request_id == request.request_id
assert request.num_computed_tokens == 0
assert request.status == RequestStatus.WAITING_FOR_REMOTE_KVS
assert scheduler.connector.get_num_new_matched_tokens.call_count == 1
min_invalid_block_idx = max(invalid_block_idxs) + 1
# Simulate failures when progressively loading request blocks.
for invalid_block_idx in invalid_block_idxs:
(req_block_ids,) = scheduler.kv_cache_manager.get_block_ids(request.request_id)
invalid_block_ids = {req_block_ids[invalid_block_idx]}
model_runner_output = create_model_runner_output(
reqs=[],
finished_recving=set(),
invalid_block_ids=invalid_block_ids,
use_eos=True,
)
scheduler.update_from_output(scheduler_output, model_runner_output)
min_invalid_block_idx = min(min_invalid_block_idx, invalid_block_idx)
assert len(scheduler.waiting) == 1
assert scheduler.waiting.peek_request().request_id == request.request_id
assert request.num_computed_tokens == (
min_invalid_block_idx * scheduler.block_size
)
assert request.status == RequestStatus.WAITING_FOR_REMOTE_KVS
assert scheduler.failed_recving_kv_req_ids == {request.request_id}
assert scheduler.connector.get_num_new_matched_tokens.call_count == 1

View File

@@ -0,0 +1,756 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from unittest.mock import MagicMock
import pytest
from vllm.distributed.kv_events import BlockStored
from vllm.distributed.kv_transfer.kv_connector.v1.lmcache_connector import (
LMCacheConnectorV1,
LMCacheKVEvents,
)
from vllm.v1.outputs import KVConnectorOutput
@pytest.fixture
def mock_lmcache_engine_event():
"""Create a mock event object that mimics what the lmcache engine returns."""
class MockEvent:
def __init__(
self,
block_hashes,
parent_block_hash,
token_ids,
lora_id,
block_size,
medium,
):
self.block_hashes = block_hashes
self.parent_block_hash = parent_block_hash
self.token_ids = token_ids
self.lora_id = lora_id
self.block_size = block_size
self.medium = medium
return MockEvent(
block_hashes=["hash1", "hash2"],
parent_block_hash="parent_hash",
token_ids=[1, 2, 3, 4],
lora_id=None,
block_size=16,
medium="GPU",
)
@pytest.fixture
def mock_connector():
"""Create a mock LMCacheConnectorV1 instance with mocked dependencies."""
connector = MagicMock(spec=LMCacheConnectorV1)
connector._kv_cache_events = None
connector._lmcache_engine = MagicMock()
# Make the methods use the real implementation
connector.get_kv_connector_kv_cache_events = (
LMCacheConnectorV1.get_kv_connector_kv_cache_events.__get__(
connector, LMCacheConnectorV1
)
)
connector.update_connector_output = (
LMCacheConnectorV1.update_connector_output.__get__(
connector, LMCacheConnectorV1
)
)
connector.take_events = LMCacheConnectorV1.take_events.__get__(
connector, LMCacheConnectorV1
)
return connector
class TestGetKVConnectorKVCacheEvents:
"""Test get_kv_connector_kv_cache_events method."""
def test_returns_none_when_no_events(self, mock_connector):
"""Test that None is returned when lmcache engine has no events."""
mock_connector._lmcache_engine.get_kv_events.return_value = None
result = mock_connector.get_kv_connector_kv_cache_events()
assert result is None
mock_connector._lmcache_engine.get_kv_events.assert_called_once()
def test_returns_none_when_empty_list(self, mock_connector):
"""Test that None is returned when lmcache engine returns empty list."""
mock_connector._lmcache_engine.get_kv_events.return_value = []
result = mock_connector.get_kv_connector_kv_cache_events()
assert result is None
def test_converts_single_event(self, mock_connector, mock_lmcache_engine_event):
"""Test conversion of a single event from lmcache engine format."""
mock_connector._lmcache_engine.get_kv_events.return_value = [
mock_lmcache_engine_event
]
result = mock_connector.get_kv_connector_kv_cache_events()
assert result is not None
assert isinstance(result, LMCacheKVEvents)
assert result.get_number_of_workers() == 1
events = result.get_all_events()
assert len(events) == 1
assert isinstance(events[0], BlockStored)
assert events[0].block_hashes == ["hash1", "hash2"]
assert events[0].parent_block_hash == "parent_hash"
assert events[0].token_ids == [1, 2, 3, 4]
assert events[0].lora_id is None
assert events[0].block_size == 16
assert events[0].medium == "GPU"
def test_converts_multiple_events(self, mock_connector):
"""Test conversion of multiple events from lmcache engine format."""
class MockEvent:
def __init__(self, i):
self.block_hashes = [f"hash{i}"]
self.parent_block_hash = f"parent{i}"
self.token_ids = [i]
self.lora_id = None
self.block_size = 16
self.medium = "GPU"
events = [MockEvent(i) for i in range(5)]
mock_connector._lmcache_engine.get_kv_events.return_value = events
result = mock_connector.get_kv_connector_kv_cache_events()
assert result is not None
assert isinstance(result, LMCacheKVEvents)
converted_events = result.get_all_events()
assert len(converted_events) == 5
for i, event in enumerate(converted_events):
assert isinstance(event, BlockStored)
assert event.block_hashes == [f"hash{i}"]
assert event.parent_block_hash == f"parent{i}"
assert event.token_ids == [i]
def test_preserves_event_attributes(self, mock_connector):
"""Test that all event attributes are correctly preserved."""
class MockEventWithLora:
def __init__(self):
self.block_hashes = ["hash_a", "hash_b", "hash_c"]
self.parent_block_hash = "parent_xyz"
self.token_ids = [100, 200, 300]
self.lora_id = 42
self.block_size = 32
self.medium = "DISK"
mock_connector._lmcache_engine.get_kv_events.return_value = [
MockEventWithLora()
]
result = mock_connector.get_kv_connector_kv_cache_events()
events = result.get_all_events()
event = events[0]
assert event.block_hashes == ["hash_a", "hash_b", "hash_c"]
assert event.parent_block_hash == "parent_xyz"
assert event.token_ids == [100, 200, 300]
assert event.lora_id == 42
assert event.block_size == 32
assert event.medium == "DISK"
def test_handles_none_parent_block_hash(self, mock_connector):
"""Test handling of events with None parent_block_hash."""
class MockEventNoParent:
def __init__(self):
self.block_hashes = ["hash1"]
self.parent_block_hash = None
self.token_ids = [1, 2]
self.lora_id = None
self.block_size = 16
self.medium = "GPU"
mock_connector._lmcache_engine.get_kv_events.return_value = [
MockEventNoParent()
]
result = mock_connector.get_kv_connector_kv_cache_events()
events = result.get_all_events()
assert events[0].parent_block_hash is None
class TestUpdateConnectorOutput:
"""Test update_connector_output method."""
def test_does_nothing_when_kv_cache_events_is_none(self, mock_connector):
"""Test that method returns early when kv_cache_events is None."""
connector_output = KVConnectorOutput(kv_cache_events=None)
mock_connector.update_connector_output(connector_output)
assert mock_connector._kv_cache_events is None
def test_does_nothing_when_kv_cache_events_is_not_lmcache_kv_events(
self, mock_connector
):
"""Test that method returns early when kv_cache_events is not
LMCacheKVEvents."""
# Create a mock object that is not LMCacheKVEvents
fake_events = MagicMock()
connector_output = KVConnectorOutput(kv_cache_events=fake_events)
mock_connector.update_connector_output(connector_output)
assert mock_connector._kv_cache_events is None
def test_sets_kv_cache_events_when_none(self, mock_connector):
"""Test that _kv_cache_events is set when it was None."""
kv_events = LMCacheKVEvents(num_workers=1)
event = BlockStored(
block_hashes=["hash1"],
parent_block_hash=None,
token_ids=[1, 2],
block_size=16,
lora_id=None,
medium="GPU",
)
kv_events.add_events([event])
connector_output = KVConnectorOutput(kv_cache_events=kv_events)
mock_connector.update_connector_output(connector_output)
assert mock_connector._kv_cache_events is kv_events
def test_adds_events_when_kv_cache_events_already_exists(self, mock_connector):
"""Test that events are added when _kv_cache_events already exists."""
# Set up existing events
existing_events = LMCacheKVEvents(num_workers=2)
event1 = BlockStored(
block_hashes=["hash1"],
parent_block_hash=None,
token_ids=[1],
block_size=16,
lora_id=None,
medium="GPU",
)
existing_events.add_events([event1])
existing_events.add_events([event1]) # Simulate 2 workers reporting
mock_connector._kv_cache_events = existing_events
# Create new events to add
new_events = LMCacheKVEvents(num_workers=1)
event2 = BlockStored(
block_hashes=["hash2"],
parent_block_hash=None,
token_ids=[2],
block_size=16,
lora_id=None,
medium="GPU",
)
new_events.add_events([event2])
connector_output = KVConnectorOutput(kv_cache_events=new_events)
mock_connector.update_connector_output(connector_output)
# Check that events were added
all_events = mock_connector._kv_cache_events.get_all_events()
assert len(all_events) == 3 # 2 from existing + 1 from new
assert event1 in all_events
assert event2 in all_events
def test_increments_workers_when_kv_cache_events_already_exists(
self, mock_connector
):
"""Test that worker count is incremented correctly."""
# Set up existing events with 2 workers
existing_events = LMCacheKVEvents(num_workers=2)
mock_connector._kv_cache_events = existing_events
# Create new events from 3 workers
new_events = LMCacheKVEvents(num_workers=3)
event = BlockStored(
block_hashes=["hash1"],
parent_block_hash=None,
token_ids=[1],
block_size=16,
lora_id=None,
medium="GPU",
)
new_events.add_events([event])
connector_output = KVConnectorOutput(kv_cache_events=new_events)
mock_connector.update_connector_output(connector_output)
# Worker count should be 2 + 3 = 5
assert mock_connector._kv_cache_events.get_number_of_workers() == 5
def test_multiple_updates(self, mock_connector):
"""Test multiple consecutive updates."""
# First update
events1 = LMCacheKVEvents(num_workers=1)
event1 = BlockStored(
block_hashes=["hash1"],
parent_block_hash=None,
token_ids=[1],
block_size=16,
lora_id=None,
medium="GPU",
)
events1.add_events([event1])
output1 = KVConnectorOutput(kv_cache_events=events1)
mock_connector.update_connector_output(output1)
# Second update
events2 = LMCacheKVEvents(num_workers=2)
event2 = BlockStored(
block_hashes=["hash2"],
parent_block_hash=None,
token_ids=[2],
block_size=16,
lora_id=None,
medium="GPU",
)
events2.add_events([event2])
output2 = KVConnectorOutput(kv_cache_events=events2)
mock_connector.update_connector_output(output2)
# Third update
events3 = LMCacheKVEvents(num_workers=1)
event3 = BlockStored(
block_hashes=["hash3"],
parent_block_hash=None,
token_ids=[3],
block_size=16,
lora_id=None,
medium="GPU",
)
events3.add_events([event3])
output3 = KVConnectorOutput(kv_cache_events=events3)
mock_connector.update_connector_output(output3)
# Check final state
all_events = mock_connector._kv_cache_events.get_all_events()
assert len(all_events) == 3
assert mock_connector._kv_cache_events.get_number_of_workers() == 4 # 1+2+1
def test_updates_with_empty_events(self, mock_connector):
"""Test updating with empty event lists."""
# First update with actual events
events1 = LMCacheKVEvents(num_workers=1)
event1 = BlockStored(
block_hashes=["hash1"],
parent_block_hash=None,
token_ids=[1],
block_size=16,
lora_id=None,
medium="GPU",
)
events1.add_events([event1])
output1 = KVConnectorOutput(kv_cache_events=events1)
mock_connector.update_connector_output(output1)
# Second update with empty events
events2 = LMCacheKVEvents(num_workers=2)
# No events added
output2 = KVConnectorOutput(kv_cache_events=events2)
mock_connector.update_connector_output(output2)
# Should still have the original event
all_events = mock_connector._kv_cache_events.get_all_events()
assert len(all_events) == 1
assert mock_connector._kv_cache_events.get_number_of_workers() == 3
class TestTakeEvents:
"""Test take_events method."""
def test_yields_nothing_when_kv_cache_events_is_none(self, mock_connector):
"""Test that nothing is yielded when _kv_cache_events is None."""
mock_connector._kv_cache_events = None
events = list(mock_connector.take_events())
assert events == []
def test_yields_events_and_clears(self, mock_connector):
"""Test that events are yielded and then cleared."""
# Set up events
kv_events = LMCacheKVEvents(num_workers=1)
event1 = BlockStored(
block_hashes=["hash1"],
parent_block_hash=None,
token_ids=[1],
block_size=16,
lora_id=None,
medium="GPU",
)
event2 = BlockStored(
block_hashes=["hash2"],
parent_block_hash=None,
token_ids=[2],
block_size=16,
lora_id=None,
medium="GPU",
)
kv_events.add_events([event1, event2])
mock_connector._kv_cache_events = kv_events
# Take events
events = list(mock_connector.take_events())
# Check that events were yielded
assert len(events) == 2
assert event1 in events
assert event2 in events
# Check that _kv_cache_events was cleared
assert mock_connector._kv_cache_events is None
def test_aggregates_before_yielding(self, mock_connector):
"""Test that events are aggregated before yielding."""
# Set up events from multiple workers
kv_events = LMCacheKVEvents(num_workers=3)
common_event = BlockStored(
block_hashes=["hash_common"],
parent_block_hash=None,
token_ids=[1],
block_size=16,
lora_id=None,
medium="GPU",
)
uncommon_event = BlockStored(
block_hashes=["hash_uncommon"],
parent_block_hash=None,
token_ids=[2],
block_size=16,
lora_id=None,
medium="GPU",
)
# All 3 workers report common_event
kv_events.add_events([common_event])
kv_events.add_events([common_event])
kv_events.add_events([common_event])
# Only 1 worker reports uncommon_event
kv_events.add_events([uncommon_event])
mock_connector._kv_cache_events = kv_events
# Take events
events = list(mock_connector.take_events())
# Only the common event should be yielded
assert len(events) == 1
assert events[0] == common_event
def test_multiple_take_events_calls(self, mock_connector):
"""Test calling take_events multiple times."""
# First call with events
kv_events1 = LMCacheKVEvents(num_workers=1)
event1 = BlockStored(
block_hashes=["hash1"],
parent_block_hash=None,
token_ids=[1],
block_size=16,
lora_id=None,
medium="GPU",
)
kv_events1.add_events([event1])
mock_connector._kv_cache_events = kv_events1
events1 = list(mock_connector.take_events())
assert len(events1) == 1
assert events1[0] == event1
assert mock_connector._kv_cache_events is None
# Second call with no events
events2 = list(mock_connector.take_events())
assert events2 == []
# Third call after adding new events
kv_events2 = LMCacheKVEvents(num_workers=1)
event2 = BlockStored(
block_hashes=["hash2"],
parent_block_hash=None,
token_ids=[2],
block_size=16,
lora_id=None,
medium="GPU",
)
kv_events2.add_events([event2])
mock_connector._kv_cache_events = kv_events2
events3 = list(mock_connector.take_events())
assert len(events3) == 1
assert events3[0] == event2
def test_yields_empty_after_aggregation_removes_all(self, mock_connector):
"""Test that nothing is yielded if aggregation removes all events."""
# Set up events from 2 workers with no common events
kv_events = LMCacheKVEvents(num_workers=2)
event1 = BlockStored(
block_hashes=["hash1"],
parent_block_hash=None,
token_ids=[1],
block_size=16,
lora_id=None,
medium="GPU",
)
event2 = BlockStored(
block_hashes=["hash2"],
parent_block_hash=None,
token_ids=[2],
block_size=16,
lora_id=None,
medium="GPU",
)
# Worker 1 reports event1
kv_events.add_events([event1])
# Worker 2 reports event2
kv_events.add_events([event2])
mock_connector._kv_cache_events = kv_events
# Take events
events = list(mock_connector.take_events())
# No common events, so nothing should be yielded
assert events == []
assert mock_connector._kv_cache_events is None
class TestIntegrationScenarios:
"""Test integration scenarios."""
def test_full_workflow(self, mock_connector, mock_lmcache_engine_event):
"""Test a complete workflow from getting events to taking them."""
# Step 1: Get events from lmcache engine
mock_connector._lmcache_engine.get_kv_events.return_value = [
mock_lmcache_engine_event
]
kv_events = mock_connector.get_kv_connector_kv_cache_events()
assert kv_events is not None
assert len(kv_events.get_all_events()) == 1
# Step 2: Update connector output (simulate receiving from worker)
output1 = KVConnectorOutput(kv_cache_events=kv_events)
mock_connector.update_connector_output(output1)
assert mock_connector._kv_cache_events is not None
# Step 3: Take events
taken_events = list(mock_connector.take_events())
assert len(taken_events) == 1
assert mock_connector._kv_cache_events is None
def test_multiple_workers_workflow(self, mock_connector):
"""Test workflow with multiple workers."""
class MockEvent:
def __init__(self, hash_val):
self.block_hashes = [hash_val]
self.parent_block_hash = None
self.token_ids = [1]
self.lora_id = None
self.block_size = 16
self.medium = "GPU"
# Worker 1
mock_connector._lmcache_engine.get_kv_events.return_value = [
MockEvent("hash_common"),
MockEvent("hash_worker1"),
]
kv_events1 = mock_connector.get_kv_connector_kv_cache_events()
output1 = KVConnectorOutput(kv_cache_events=kv_events1)
mock_connector.update_connector_output(output1)
# Worker 2
mock_connector._lmcache_engine.get_kv_events.return_value = [
MockEvent("hash_common"),
MockEvent("hash_worker2"),
]
kv_events2 = mock_connector.get_kv_connector_kv_cache_events()
output2 = KVConnectorOutput(kv_cache_events=kv_events2)
mock_connector.update_connector_output(output2)
# Take events (should only get common events)
taken_events = list(mock_connector.take_events())
# With aggregation, only events reported by both workers should be present
# In this case, hash_common was reported by both
event_hashes = [e.block_hashes[0] for e in taken_events]
assert "hash_common" in event_hashes
def test_empty_workflow(self, mock_connector):
"""Test workflow when there are no events at any stage."""
# Get events returns None
mock_connector._lmcache_engine.get_kv_events.return_value = None
kv_events = mock_connector.get_kv_connector_kv_cache_events()
assert kv_events is None
# Update with None
output = KVConnectorOutput(kv_cache_events=None)
mock_connector.update_connector_output(output)
# Take events
taken_events = list(mock_connector.take_events())
assert taken_events == []
assert mock_connector._kv_cache_events is None
def test_repeated_cycles(self, mock_connector):
"""Test multiple cycles of the complete workflow."""
class MockEvent:
def __init__(self, cycle_num):
self.block_hashes = [f"hash_cycle_{cycle_num}"]
self.parent_block_hash = None
self.token_ids = [cycle_num]
self.lora_id = None
self.block_size = 16
self.medium = "GPU"
for cycle in range(3):
# Get events
mock_connector._lmcache_engine.get_kv_events.return_value = [
MockEvent(cycle)
]
kv_events = mock_connector.get_kv_connector_kv_cache_events()
# Update
output = KVConnectorOutput(kv_cache_events=kv_events)
mock_connector.update_connector_output(output)
# Take
taken_events = list(mock_connector.take_events())
# Verify
assert len(taken_events) == 1
assert taken_events[0].block_hashes[0] == f"hash_cycle_{cycle}"
assert mock_connector._kv_cache_events is None
def test_lmcache_kv_events_aggregation(self):
"""
Test LMCacheKVEvents aggregation across TP ranks using
KVOutputAggregator (used by MultiprocExecutor).
"""
from vllm.distributed.kv_transfer.kv_connector.utils import KVOutputAggregator
from vllm.v1.outputs import ModelRunnerOutput
# Create KVOutputAggregator for 3 workers (simulating TP=3)
aggregator = KVOutputAggregator(expected_finished_count=3)
# Define common and unique events
common_event = BlockStored(
block_hashes=["hash_common"],
parent_block_hash="parent_common",
token_ids=[1, 2, 3],
block_size=16,
lora_id=None,
medium="GPU",
)
worker1_unique_event = BlockStored(
block_hashes=["hash_worker1"],
parent_block_hash="parent_w1",
token_ids=[4, 5],
block_size=16,
lora_id=None,
medium="GPU",
)
worker2_unique_event = BlockStored(
block_hashes=["hash_worker2"],
parent_block_hash="parent_w2",
token_ids=[6, 7],
block_size=16,
lora_id=None,
medium="GPU",
)
worker3_unique_event = BlockStored(
block_hashes=["hash_worker3"],
parent_block_hash="parent_w3",
token_ids=[8, 9],
block_size=16,
lora_id=None,
medium="GPU",
)
# Create events for each worker
# Worker 0: reports common event and its unique event
worker0_events = LMCacheKVEvents(num_workers=1)
worker0_events.add_events([common_event, worker1_unique_event])
# Worker 1: reports common event and its unique event
worker1_events = LMCacheKVEvents(num_workers=1)
worker1_events.add_events([common_event, worker2_unique_event])
# Worker 2: reports common event and its unique event
worker2_events = LMCacheKVEvents(num_workers=1)
worker2_events.add_events([common_event, worker3_unique_event])
# Create ModelRunnerOutput instances for each worker
worker_outputs = []
for i, worker_events in enumerate(
[worker0_events, worker1_events, worker2_events]
):
output = ModelRunnerOutput(
req_ids=[f"req_{i}"],
req_id_to_index={f"req_{i}": 0},
sampled_token_ids=[[123]], # dummy token
logprobs=None,
prompt_logprobs_dict={},
pooler_output=[None],
kv_connector_output=KVConnectorOutput(
finished_sending=set([f"req_{i}_send"])
if i < 2
else None, # Workers 0,1 finished sending
finished_recving=set([f"req_{i}_recv"])
if i > 0
else None, # Workers 1,2 finished receiving
kv_cache_events=worker_events,
),
)
worker_outputs.append(output)
# Use the real aggregation mechanism (like MultiprocExecutor.execute_model)
aggregated_output = aggregator.aggregate(worker_outputs, output_rank=0)
kv_cache_events = aggregated_output.kv_connector_output.kv_cache_events
assert isinstance(kv_cache_events, LMCacheKVEvents)
# After aggregation, events should be combined from all workers
# The aggregator doesn't automatically aggregate events, so we need to call
# aggregate() to get only common events
kv_cache_events.aggregate()
aggregated_events = kv_cache_events.get_all_events()
# Only the common event should remain after aggregation
# because it's the only event reported by all 3 workers
assert len(aggregated_events) == 1
assert aggregated_events[0] == common_event
# Verify the common event properties
assert aggregated_events[0].block_hashes == ["hash_common"]
assert aggregated_events[0].parent_block_hash == "parent_common"
assert aggregated_events[0].token_ids == [1, 2, 3]

View File

@@ -0,0 +1,228 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# NOTE: if your PR has broken one of the tests here (sorry),
# kindly patch the corresponding integration in
# /vllm/distributed/kv_transfer/kv_connector/v1/lmcache_integration/vllm_v1_adapter.py
# or reach out to @aposataC for assistance
# Assumption vs. Correctness Tests:
# these unit tests do *not* test correctness of LMCache-side or vLLM-side logic
# it is to ensure that assumptions LMCache makes about vLLM's interface are stable
import pytest
from vllm.platforms import current_platform
def assumes(obj, attr, is_callable=False, is_instance_of=None):
import inspect
from dataclasses import is_dataclass
assumption_msg = (
f"LMCache connector currently assumes that {obj} has a(n) {attr} attribute"
)
if hasattr(obj, attr):
attr_value = getattr(obj, attr)
elif is_dataclass(obj) and attr in getattr(obj, "__dataclass_fields__", {}):
field = obj.__dataclass_fields__[attr]
field_type = field.type
origin = getattr(field_type, "__origin__", None)
if origin is not None:
field_type = origin
attr_value = field_type
else:
raise AssertionError(assumption_msg)
if is_callable:
assumption_msg += f" and that {obj}.{attr} is a callable"
assert callable(attr_value), assumption_msg
if is_instance_of:
assumption_msg += f" and that {obj}.{attr} is an instance of {is_instance_of}"
if isinstance(attr_value, property):
fget = attr_value.fget
assert fget is not None, f"Property {obj}.{attr} has no fget"
sig = inspect.signature(fget)
ret_anno = sig.return_annotation
assert ret_anno is not inspect._empty, (
f"Property {obj}.{attr} has no return annotation"
)
assert ret_anno == is_instance_of, assumption_msg
else:
if isinstance(attr_value, type):
assert attr_value is is_instance_of, assumption_msg
else:
assert isinstance(attr_value, is_instance_of), assumption_msg
@pytest.mark.skipif(
current_platform.is_rocm(), reason="Requires libcudart.so, not available on ROCm"
)
def test_multimodal_interface():
# protect against interface changes
from vllm.multimodal.inputs import PlaceholderRange
assumes(PlaceholderRange, "offset")
assumes(PlaceholderRange, "length")
@pytest.mark.skipif(
current_platform.is_rocm(), reason="Requires libcudart.so, not available on ROCm"
)
def test_config_interface():
# protect against interface changes
from vllm.config import VllmConfig
from vllm.config.cache import CacheConfig
from vllm.config.kv_transfer import KVTransferConfig
from vllm.config.model import ModelConfig
from vllm.config.parallel import ParallelConfig
assumes(VllmConfig, "model_config")
assumes(VllmConfig, "cache_config")
assumes(VllmConfig, "parallel_config")
assumes(VllmConfig, "kv_transfer_config")
assumes(KVTransferConfig, "kv_role")
assumes(KVTransferConfig, "kv_connector_extra_config")
assumes(ModelConfig, "use_mla", is_instance_of=bool)
assumes(ModelConfig, "dtype")
assumes(ModelConfig, "max_model_len")
assumes(ModelConfig, "get_vocab_size", is_callable=True)
assumes(ModelConfig, "get_num_attention_heads", is_callable=True)
assumes(ModelConfig, "get_num_kv_heads", is_callable=True)
assumes(ModelConfig, "get_head_size", is_callable=True)
assumes(ModelConfig, "get_num_layers", is_callable=True)
assumes(ModelConfig, "get_num_kv_heads", is_callable=True)
assumes(ModelConfig, "model")
assumes(ParallelConfig, "world_size")
assumes(ParallelConfig, "rank")
assumes(ParallelConfig, "tensor_parallel_size")
assumes(ParallelConfig, "pipeline_parallel_size")
assumes(ParallelConfig, "data_parallel_size_local")
assumes(ParallelConfig, "data_parallel_rank_local")
assumes(CacheConfig, "cache_dtype")
assumes(CacheConfig, "block_size")
assumes(CacheConfig, "gpu_memory_utilization")
# kv metadata minimal case
from vllm.utils.torch_utils import get_kv_cache_torch_dtype
model_config = ModelConfig(dtype="bfloat16")
parallel_config = ParallelConfig()
cache_config = CacheConfig(cache_dtype="bfloat16")
kv_dtype = get_kv_cache_torch_dtype(cache_config.cache_dtype, model_config.dtype)
use_mla = False
chunk_size = 256
num_layer = model_config.get_num_layers(parallel_config)
num_kv_head = model_config.get_num_kv_heads(parallel_config)
head_size = model_config.get_head_size()
kv_shape = (num_layer, 1 if use_mla else 2, chunk_size, num_kv_head, head_size)
# dummy lmcache metadata creation example
_ = (
model_config.model,
parallel_config.world_size,
parallel_config.rank,
"vllm",
kv_dtype,
kv_shape,
use_mla,
)
@pytest.mark.skipif(
current_platform.is_rocm(), reason="Requires libcudart.so, not available on ROCm"
)
def test_request_interface():
# protect against interface changes
from types import NoneType
from vllm.sampling_params import SamplingParams
from vllm.v1.request import Request
req = Request(
request_id="test_request",
prompt_token_ids=[1, 2, 3],
sampling_params=SamplingParams(max_tokens=10),
pooling_params=None,
eos_token_id=100,
lora_request=None,
)
assumes(req, "mm_features", is_instance_of=(list, NoneType))
assumes(req, "request_id")
assumes(req, "priority")
assumes(req, "prompt_token_ids")
assumes(req, "sampling_params")
assumes(req, "num_tokens")
assumes(req, "kv_transfer_params", is_instance_of=(dict, NoneType))
from vllm.multimodal.inputs import MultiModalFeatureSpec
assumes(MultiModalFeatureSpec, "identifier")
assumes(MultiModalFeatureSpec, "mm_position")
def test_new_request_interface():
# protect against interface changes
from vllm.v1.core.sched.output import NewRequestData
assumes(NewRequestData, "req_id")
assumes(NewRequestData, "block_ids")
assumes(NewRequestData, "prompt_token_ids")
assumes(NewRequestData, "sampling_params")
def test_sampling_params_interface():
# protect against interface changes
from vllm.sampling_params import SamplingParams
assumes(SamplingParams, "extra_args")
# dumb example use case in LMCache
kv_transfer_params = {
"lmcache.tag.user": "example_user_1",
"lmcache.ttl": 60,
}
sampling_params = SamplingParams(
extra_args={"kv_transfer_params": kv_transfer_params}
)
assert sampling_params.extra_args["kv_transfer_params"] == kv_transfer_params
def test_tp_interface():
# protect against interface changes
import inspect
from vllm.distributed.parallel_state import get_tp_group
sig = inspect.signature(get_tp_group)
GroupCoordinator = sig.return_annotation
assumes(GroupCoordinator, "broadcast", is_callable=True)
assumes(GroupCoordinator, "broadcast_object", is_callable=True)
def test_forward_context_interface():
# protect against interface changes
from vllm.forward_context import ForwardContext
assumes(ForwardContext, "no_compile_layers", is_instance_of=dict)
assumes(ForwardContext, "virtual_engine")
assumes(ForwardContext, "attn_metadata")
def test_scheduler_output_interface():
# protect against interface changes
from vllm.v1.core.sched.output import SchedulerOutput
assumes(SchedulerOutput, "finished_req_ids")
assumes(SchedulerOutput, "scheduled_new_reqs", is_instance_of=list)
assumes(SchedulerOutput, "num_scheduled_tokens", is_instance_of=dict)
assumes(SchedulerOutput, "scheduled_cached_reqs")
from vllm.v1.core.sched.output import CachedRequestData
assumes(CachedRequestData, "req_ids", is_instance_of=list)
assumes(CachedRequestData, "new_block_ids", is_instance_of=list)

View File

@@ -0,0 +1,603 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import filecmp
import shutil
import tempfile
from pathlib import Path
from typing import Any
import pytest
from vllm import LLM, SamplingParams
from vllm.config import KVTransferConfig
from vllm.distributed.kv_transfer.kv_connector.factory import KVConnectorFactory
from vllm.distributed.kv_transfer.kv_connector.v1.base import KVConnectorBase_V1
from vllm.distributed.kv_transfer.kv_connector.v1.metrics import KVConnectorStats
from vllm.distributed.kv_transfer.kv_connector.v1.multi_connector import (
MultiConnector,
MultiKVConnectorStats,
)
from vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector import (
NixlKVConnectorStats,
)
from vllm.platforms import current_platform
MODEL_NAME = "meta-llama/Llama-3.2-1B-Instruct"
PROMPT_CONTEXT = "Hi " * 100
PROMPTS = [
PROMPT_CONTEXT + "Hello, my name is",
PROMPT_CONTEXT + "The capital of France is",
]
SAMPLING_PARAMS = SamplingParams(temperature=0, max_tokens=20)
# Test connector with custom stats for testing MultiConnector
class MockConnectorStats(KVConnectorStats):
"""Mock stats class for testing."""
pass
class MockConnector(KVConnectorBase_V1):
"""Mock connector that implements build_kv_connector_stats for testing."""
@classmethod
def build_kv_connector_stats(
cls, data: dict[str, Any] | None = None
) -> KVConnectorStats | None:
return MockConnectorStats(data=data) if data is not None else None
# Register the mock connector
KVConnectorFactory.register_connector("MockConnector", __name__, MockConnector.__name__)
# Helper function to compare directories recursively
def _compare_directories(dir1: Path, dir2: Path) -> bool:
"""Compares two directories recursively for identical content."""
dcmp = filecmp.dircmp(dir1, dir2)
if dcmp.left_only or dcmp.right_only or dcmp.diff_files:
print(f"Differences found between {dir1} and {dir2}:")
print(f" Left only: {dcmp.left_only}")
print(f" Right only: {dcmp.right_only}")
print(f" Different files: {dcmp.diff_files}")
return False
for sub_dir in dcmp.common_dirs:
if not _compare_directories(dir1 / sub_dir, dir2 / sub_dir):
return False
return True
@pytest.mark.skipif(
current_platform.is_rocm(),
reason=(
"hipErrorLaunchFailure when running this test, see issue:"
"https://github.com/ROCm/pytorch/issues/2822"
),
)
def test_multi_example_connector_consistency():
"""
Tests that MultiConnector with two ExampleConnectors saves
identical KV cache data to separate storage locations.
"""
storage_1_path = Path("storage_1/")
storage_2_path = Path("storage_2/")
shutil.rmtree(storage_1_path, ignore_errors=True)
shutil.rmtree(storage_2_path, ignore_errors=True)
storage_1_path.mkdir()
storage_2_path.mkdir()
# Configure MultiConnector with two ExampleConnectors
kv_transfer_config = KVTransferConfig(
kv_connector="MultiConnector",
kv_role="kv_both",
kv_connector_extra_config={
"connectors": [
{
"kv_connector": "TestExampleConnector",
"kv_role": "kv_both",
"kv_connector_extra_config": {
"shared_storage_path": str(storage_1_path),
"name": "storage1",
},
"kv_connector_module_path": "tests.v1.kv_connector.unit.utils",
},
{
"kv_connector": "TestExampleConnector",
"kv_role": "kv_both",
"kv_connector_extra_config": {
"shared_storage_path": str(storage_2_path),
"name": "storage2",
},
"kv_connector_module_path": "tests.v1.kv_connector.unit.utils",
},
]
},
)
llm = LLM(
model=MODEL_NAME,
enforce_eager=True,
gpu_memory_utilization=0.5,
kv_transfer_config=kv_transfer_config,
)
# Run generation - this should trigger saving KV cache
_ = llm.generate(PROMPTS, SAMPLING_PARAMS)
# --- Verification ---
# Check that both storage directories were populated
local_subdirs = list(storage_1_path.iterdir())
external_subdirs = list(storage_2_path.iterdir())
assert len(local_subdirs) > 0, (
f"Local storage path {storage_1_path} is empty after generation."
)
assert len(external_subdirs) > 0, (
f"External storage path {storage_2_path} is empty after generation."
)
assert len(local_subdirs) == len(external_subdirs), (
f"Mismatch in number of cache entries: "
f"Local={len(local_subdirs)}, External={len(external_subdirs)}"
)
# The subdirectories should correspond to the prompt hashes
# Since prompts are the same, the hash directories should be the same name
local_subdir_names = sorted([d.name for d in local_subdirs])
external_subdir_names = sorted([d.name for d in external_subdirs])
assert local_subdir_names == external_subdir_names, (
"Cache directory names do not match between local and external storage"
)
# Compare the contents of each corresponding cache directory
for subdir_name in local_subdir_names:
print(f"Comparing contents of cache directory: {subdir_name}")
assert _compare_directories(
storage_1_path / subdir_name, storage_2_path / subdir_name
), (
f"Contents differ for cache directory '{subdir_name}' between "
f"{storage_1_path} and {storage_2_path}"
)
events = get_connector_events()
# get_num_new_matched_tokens and update_state_after_alloc will be called
# on each connector in turn.
assert events["storage1-SCHEDULER"][:3] == [
"get_num_new_matched_tokens 0",
"update_state_after_alloc num_blocks=[0] 0",
"build_connector_meta",
]
assert events["storage1-WORKER"][:5] == [
"register_kv_caches",
"bind_connector_metadata",
"start_load_kv",
"wait_for_layer_load",
"save_kv_layer",
]
assert events["storage2-SCHEDULER"][:3] == [
"get_num_new_matched_tokens 0",
"update_state_after_alloc num_blocks=[0] 0",
"build_connector_meta",
]
assert events["storage2-WORKER"][:5] == [
"register_kv_caches",
"bind_connector_metadata",
"start_load_kv",
"wait_for_layer_load",
"save_kv_layer",
]
# Reset prefix cache or else we'll just get the tokens back from there.
llm.reset_prefix_cache()
# Run generation again - this should trigger loading from the first
# connector.
_ = llm.generate(PROMPTS, SAMPLING_PARAMS)
events = get_connector_events()
# get_num_new_matched_tokens will return new tokens from the first
# connector so update_state_after_alloc will be with allocated blocks
# on that one but with zero blocks for others (first nonzero match is
# chosen).
assert events["storage1-SCHEDULER"][:3] == [
"get_num_new_matched_tokens 0",
"update_state_after_alloc num_blocks=[7] 96",
"build_connector_meta",
]
assert events["storage2-SCHEDULER"][:3] == [
"get_num_new_matched_tokens 0",
"update_state_after_alloc num_blocks=[0] 0",
"build_connector_meta",
]
# Delete storage1 connector state
shutil.rmtree(storage_1_path)
# Reset prefix cache or else we'll just get the tokens back from there.
llm.reset_prefix_cache()
# Run generation again - this should trigger loading from the first
# connector.
_ = llm.generate(PROMPTS, SAMPLING_PARAMS)
events = get_connector_events()
# get_num_new_matched_tokens will be called for both connectors but will
# return 0 from the first connector, but the second connector should have
# a hit, so update_state_after_alloc will only be called with allocated
# blocks for the second connector.
assert events["storage1-SCHEDULER"][:3] == [
"get_num_new_matched_tokens 0",
"update_state_after_alloc num_blocks=[0] 0",
"build_connector_meta",
]
assert events["storage2-SCHEDULER"][:3] == [
"get_num_new_matched_tokens 0",
"update_state_after_alloc num_blocks=[7] 96",
"build_connector_meta",
]
# Clean up
shutil.rmtree(storage_1_path)
shutil.rmtree(storage_2_path)
def get_connector_events() -> dict[str, list[str]]:
# Read in connector events and reset the files.
import glob
event_files = glob.glob(tempfile.gettempdir() + "/connector_*_events.log")
connector_events = {}
for fname in event_files:
name = fname.split("connector_")[1].split("_events.log")[0]
try:
with open(fname, "r+") as f:
connector_events[name] = [line.strip() for line in f if line.strip()]
f.truncate(0)
except Exception as e:
print(f"[ERROR] Could not read connector events for {name}: {e}")
return connector_events
def test_engine_id_conflict():
configs = [KVTransferConfig() for _ in range(2)]
ids = [config.engine_id for config in configs]
assert ids[0] != ids[1], (
f"Engine IDs should be different for different configs. Got {ids}"
)
class TestMultiConnectorStats:
"""Tests for MultiConnector stats reconstruction and operations."""
def test_build_kv_connector_stats_with_none(self):
"""Test that build_kv_connector_stats returns empty stats when given None."""
stats = MultiConnector.build_kv_connector_stats(data=None)
assert stats is not None
assert isinstance(stats, MultiKVConnectorStats)
assert len(stats.data) == 0
assert stats.is_empty()
def test_build_kv_connector_stats_with_empty_dict(self):
"""Test that build_kv_connector_stats returns empty stats with empty dict."""
stats = MultiConnector.build_kv_connector_stats(data={})
assert stats is not None
assert isinstance(stats, MultiKVConnectorStats)
assert len(stats.data) == 0
assert stats.is_empty()
def test_build_kv_connector_stats_reconstructs_nixl_stats(self):
"""Test that NixlConnector stats are properly reconstructed with
correct data."""
serialized_data = {
"NixlConnector": {
"data": {
"transfer_duration": [1.5, 2.3],
"post_duration": [0.1, 0.2],
"bytes_transferred": [1024, 2048],
"num_descriptors": [10, 20],
"num_failed_transfers": [],
"num_failed_notifications": [],
}
}
}
stats = MultiConnector.build_kv_connector_stats(data=serialized_data)
assert "NixlConnector" in stats.data
nixl_stats = stats.data["NixlConnector"]
assert isinstance(nixl_stats, NixlKVConnectorStats)
assert nixl_stats.data["transfer_duration"] == [1.5, 2.3]
assert nixl_stats.data["post_duration"] == [0.1, 0.2]
assert nixl_stats.data["bytes_transferred"] == [1024, 2048]
assert nixl_stats.data["num_descriptors"] == [10, 20]
def test_build_kv_connector_stats_with_multiple_connectors(self):
"""Test reconstruction with multiple connector types that have custom stats."""
serialized_data = {
"NixlConnector": {
"data": {
"transfer_duration": [1.5],
"post_duration": [0.1],
"bytes_transferred": [1024],
"num_descriptors": [10],
"num_failed_transfers": [],
"num_failed_notifications": [],
}
},
"MockConnector": {"data": {"mock_field": [1, 2, 3]}},
}
stats = MultiConnector.build_kv_connector_stats(data=serialized_data)
assert stats is not None
assert isinstance(stats, MultiKVConnectorStats)
# Both connectors should be reconstructed
assert len(stats.data) == 2
assert "NixlConnector" in stats.data
assert "MockConnector" in stats.data
assert isinstance(stats.data["NixlConnector"], NixlKVConnectorStats)
assert isinstance(stats.data["MockConnector"], MockConnectorStats)
# Verify data is preserved
assert stats.data["MockConnector"].data == {"mock_field": [1, 2, 3]}
def test_build_kv_connector_stats_raises_error_for_unknown_connector(self):
"""Test that unknown connectors raise an error."""
serialized_data = {
"UnknownConnector": {"data": {"some_field": [1, 2, 3]}},
"NixlConnector": {
"data": {
"transfer_duration": [1.5],
"post_duration": [0.1],
"bytes_transferred": [1024],
"num_descriptors": [10],
"num_failed_transfers": [],
"num_failed_notifications": [],
}
},
}
with pytest.raises(
ValueError, match="Connector 'UnknownConnector' is not registered."
):
MultiConnector.build_kv_connector_stats(data=serialized_data)
def test_build_kv_connector_stats_with_already_instantiated_objects(self):
"""Test that already-instantiated stats objects are preserved (same process)."""
# This simulates the in-process case where stats are not serialized
nixl_stats = NixlKVConnectorStats(
data={
"transfer_duration": [1.5],
"post_duration": [0.1],
"bytes_transferred": [1024],
"num_descriptors": [10],
"num_failed_transfers": [],
"num_failed_notifications": [],
}
)
mock_stats = MockConnectorStats(data={"mock_field": [1, 2, 3]})
data_with_objects = {
"NixlConnector": nixl_stats,
"MockConnector": mock_stats,
}
stats = MultiConnector.build_kv_connector_stats(data=data_with_objects)
assert stats is not None
assert isinstance(stats, MultiKVConnectorStats)
assert len(stats.data) == 2
# Verify objects are preserved as-is
assert stats.data["NixlConnector"] is nixl_stats
assert stats.data["MockConnector"] is mock_stats
def test_build_kv_connector_stats_with_mixed_objects_and_dicts(self):
"""Test handling mixed already-instantiated and serialized stats."""
# This can happen during transition or partial serialization
nixl_stats = NixlKVConnectorStats(
data={
"transfer_duration": [1.5],
"post_duration": [0.1],
"bytes_transferred": [1024],
"num_descriptors": [10],
"num_failed_transfers": [],
"num_failed_notifications": [],
}
)
mixed_data = {
"NixlConnector": nixl_stats, # Already instantiated
"MockConnector": {"data": {"mock_field": [1, 2, 3]}}, # Serialized
}
stats = MultiConnector.build_kv_connector_stats(data=mixed_data)
assert stats is not None
assert isinstance(stats, MultiKVConnectorStats)
assert len(stats.data) == 2
# Instantiated object preserved
assert stats.data["NixlConnector"] is nixl_stats
# Serialized object reconstructed
assert isinstance(stats.data["MockConnector"], MockConnectorStats)
assert stats.data["MockConnector"].data == {"mock_field": [1, 2, 3]}
def test_build_kv_connector_stats_skips_connectors_without_custom_stats(self):
"""Test that connectors without custom stats (return None) are skipped."""
# ExampleConnector doesn't override build_kv_connector_stats,
# so it returns None and should be skipped
serialized_data = {
"NixlConnector": {
"data": {
"transfer_duration": [1.5],
"post_duration": [0.1],
"bytes_transferred": [1024],
"num_descriptors": [10],
"num_failed_transfers": [],
"num_failed_notifications": [],
}
},
"ExampleConnector": {"data": {"some_field": [1, 2, 3]}},
}
stats = MultiConnector.build_kv_connector_stats(data=serialized_data)
assert stats is not None
assert isinstance(stats, MultiKVConnectorStats)
# Only NixlConnector should be reconstructed
assert len(stats.data) == 1
assert "NixlConnector" in stats.data
assert isinstance(stats.data["NixlConnector"], NixlKVConnectorStats)
# ExampleConnector should be skipped (returns None)
assert "ExampleConnector" not in stats.data
def test_build_kv_connector_stats_handles_malformed_data(self):
"""Test that malformed data raises appropriate errors."""
serialized_data = {
"NixlConnector": {"wrong_field": {"transfer_duration": [1.5]}}
}
with pytest.raises(AssertionError, match="Expected a dict with a 'data' field"):
MultiConnector.build_kv_connector_stats(data=serialized_data)
def test_aggregate_same_connector(self):
"""Test aggregating stats from the same connector type."""
stats1 = MultiKVConnectorStats(
data={
"NixlConnector": NixlKVConnectorStats(
data={
"transfer_duration": [1.0],
"post_duration": [0.1],
"bytes_transferred": [1024],
"num_descriptors": [10],
"num_failed_transfers": [],
"num_failed_notifications": [],
}
)
}
)
stats2 = MultiKVConnectorStats(
data={
"NixlConnector": NixlKVConnectorStats(
data={
"transfer_duration": [2.0],
"post_duration": [0.2],
"bytes_transferred": [2048],
"num_descriptors": [20],
"num_failed_transfers": [],
"num_failed_notifications": [],
}
)
}
)
result = stats1.aggregate(stats2)
assert result is stats1 # Should return self
assert "NixlConnector" in result.data
nixl_stats = result.data["NixlConnector"]
assert nixl_stats.data["transfer_duration"] == [1.0, 2.0]
assert nixl_stats.data["post_duration"] == [0.1, 0.2]
assert nixl_stats.data["bytes_transferred"] == [1024, 2048]
assert nixl_stats.data["num_descriptors"] == [10, 20]
def test_aggregate_new_connector(self):
"""Test aggregating stats when a new connector type appears."""
from vllm.distributed.kv_transfer.kv_connector.v1.metrics import (
KVConnectorStats,
)
stats1 = MultiKVConnectorStats(
data={
"NixlConnector": NixlKVConnectorStats(
data={
"transfer_duration": [1.0],
"post_duration": [0.1],
"bytes_transferred": [1024],
"num_descriptors": [10],
"num_failed_transfers": [],
"num_failed_notifications": [],
}
)
}
)
stats2 = MultiKVConnectorStats(
data={"ExampleConnector": KVConnectorStats(data={"field": [1, 2]})}
)
result = stats1.aggregate(stats2)
assert "NixlConnector" in result.data
assert "ExampleConnector" in result.data
def test_reduce(self):
"""Test that reduce() correctly reduces all nested connector stats."""
stats = MultiKVConnectorStats(
data={
"NixlConnector": NixlKVConnectorStats(
data={
"transfer_duration": [1.0, 2.0],
"post_duration": [0.1, 0.2],
"bytes_transferred": [1024, 2048],
"num_descriptors": [10, 20],
"num_failed_transfers": [],
"num_failed_notifications": [],
}
)
}
)
reduced = stats.reduce()
assert "NixlConnector" in reduced
assert isinstance(reduced["NixlConnector"], dict)
# Check that the stats were reduced (should have aggregated values)
assert "Num successful transfers" in reduced["NixlConnector"]
assert reduced["NixlConnector"]["Num successful transfers"] == 2
def test_reset(self):
"""Test that reset() resets all nested connector stats."""
stats = MultiKVConnectorStats(
data={
"NixlConnector": NixlKVConnectorStats(
data={
"transfer_duration": [1.0, 2.0],
"post_duration": [0.1, 0.2],
"bytes_transferred": [1024, 2048],
"num_descriptors": [10, 20],
"num_failed_transfers": [],
"num_failed_notifications": [],
}
)
}
)
assert not stats.is_empty()
stats.reset()
# After reset, stats should be empty
assert stats.is_empty()
nixl_stats = stats.data["NixlConnector"]
assert len(nixl_stats.data["transfer_duration"]) == 0
def test_is_empty_with_multiple_connectors(self):
"""Test is_empty() returns correct value with multiple connectors."""
# All empty
stats = MultiKVConnectorStats(
data={
"NixlConnector": NixlKVConnectorStats(data={}),
}
)
# Initialize empty stats
stats.data["NixlConnector"].reset()
assert stats.is_empty()
# One non-empty
stats.data["NixlConnector"].data["transfer_duration"].append(1.0)
assert not stats.is_empty()

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,534 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import copy
from collections.abc import Iterable, Iterator
from dataclasses import dataclass
from typing import Any
from unittest.mock import MagicMock
import pytest
import torch
from vllm import SamplingParams
from vllm.config import KVTransferConfig, VllmConfig
from vllm.distributed.kv_events import BlockRemoved, BlockStored
from vllm.distributed.kv_transfer.kv_connector.v1 import KVConnectorRole
from vllm.distributed.kv_transfer.kv_connector.v1.offloading_connector import (
OffloadingConnector,
OffloadingConnectorMetadata,
)
from vllm.forward_context import ForwardContext
from vllm.utils.hashing import sha256
from vllm.v1.attention.backends.flash_attn import FlashAttentionBackend
from vllm.v1.core.kv_cache_utils import (
BlockHash,
get_request_block_hasher,
init_none_hash,
)
from vllm.v1.core.sched.scheduler import Scheduler
from vllm.v1.kv_offload.abstract import (
LoadStoreSpec,
OffloadingEvent,
OffloadingManager,
PrepareStoreOutput,
)
from vllm.v1.kv_offload.mediums import GPULoadStoreSpec
from vllm.v1.kv_offload.spec import OffloadingSpec
from vllm.v1.kv_offload.worker.worker import (
OffloadingHandler,
TransferResult,
TransferSpec,
)
from vllm.v1.outputs import EMPTY_MODEL_RUNNER_OUTPUT, KVConnectorOutput
from vllm.v1.request import Request
from .utils import (
EOS_TOKEN_ID,
create_model_runner_output,
create_scheduler,
create_vllm_config,
)
class MockLoadStoreSpec(LoadStoreSpec):
def __init__(self, block_hashes: Iterable[BlockHash]):
self.block_hashes: list[BlockHash] = list(block_hashes)
@staticmethod
def medium() -> str:
return "Mock"
def __repr__(self) -> str:
return repr(self.block_hashes)
class MockOffloadingHandler(OffloadingHandler):
def __init__(self):
self.completed_transfers: list[TransferResult] = []
self.completed_specs: list[TransferSpec] = []
def get_finished(self) -> list[TransferResult]:
finished = self.completed_transfers
self.completed_transfers = []
return finished
def transfer_async(self, job_id: int, spec: TransferSpec) -> bool:
self.completed_specs.append(spec)
self.completed_transfers.append((job_id, True))
return True
class MockOffloadingSpec(OffloadingSpec):
def __init__(self, vllm_config: VllmConfig):
super().__init__(vllm_config)
self.manager = MagicMock(spec=OffloadingManager)
self.manager.lookup.return_value = 0
self.manager.prepare_load = lambda block_hashes: (
MockLoadStoreSpec(block_hashes)
)
self.handler = MockOffloadingHandler()
def get_manager(self) -> OffloadingManager:
return self.manager
def get_handlers(
self, _, __
) -> Iterator[tuple[type[LoadStoreSpec], type[LoadStoreSpec], OffloadingHandler]]:
yield GPULoadStoreSpec, MockLoadStoreSpec, self.handler
yield MockLoadStoreSpec, GPULoadStoreSpec, self.handler
def get_completed_transfers(self) -> list[TransferSpec]:
specs = self.handler.completed_specs
self.handler.completed_specs = []
return specs
@dataclass
class TransferSummary:
gpu_block_indices: list[int]
offload_addresses: list[Any]
class RequestRunner:
def __init__(
self, offloaded_block_size: int, gpu_block_size: int, num_gpu_blocks: int
):
self.offloaded_block_size: int = offloaded_block_size
self.gpu_block_size: int = gpu_block_size
self.num_gpu_blocks: int = num_gpu_blocks
self.req_id: int = -1
vllm_config = create_vllm_config(
block_size=gpu_block_size, max_num_batched_tokens=1000
)
vllm_config.kv_transfer_config = KVTransferConfig(
kv_connector="OffloadingConnector",
kv_role="kv_both",
kv_connector_extra_config={
"spec_name": "MockOffloadingSpec",
"spec_module_path": "tests.v1.kv_connector.unit.test_offloading_connector", # noqa: E501
"block_size": offloaded_block_size,
},
)
self.scheduler: Scheduler = create_scheduler(
vllm_config, num_blocks=num_gpu_blocks
)
self.worker_connector = OffloadingConnector(vllm_config, KVConnectorRole.WORKER)
# register worker kv_caches to enable OffloadingWorker creations
self.worker_connector.register_cross_layers_kv_cache(
kv_cache=torch.empty(0),
attn_backend=FlashAttentionBackend,
)
# extract connector of scheduler
scheduler_connector = self.scheduler.connector
assert scheduler_connector is not None
assert isinstance(scheduler_connector, OffloadingConnector)
self.scheduler_connector: OffloadingConnector = scheduler_connector
# extract mocked OffloadingManager of scheduler connector
connector_scheduler = scheduler_connector.connector_scheduler
assert connector_scheduler is not None
manager = connector_scheduler.manager
assert isinstance(manager, MagicMock)
self.manager: MagicMock = manager
assert connector_scheduler.gpu_block_size == gpu_block_size
assert connector_scheduler.offloaded_block_size == offloaded_block_size
# extract OffloadingSpec of worker_connector
connector_worker = self.worker_connector.connector_worker
assert connector_worker is not None
offloading_spec = connector_worker.spec
assert isinstance(offloading_spec, MockOffloadingSpec)
self.offloading_spec: MockOffloadingSpec = offloading_spec
# mapping (offloading address) -> gpu_block_index
self.offloaded: dict[Any, int] = {}
self.pending_loads_count: int = 0
self.pending_stores_count: int = 0
self.completed_loads: list[TransferSummary] = []
self.completed_stores: list[TransferSummary] = []
# maps {block_id: block_offset}
self.gpu_block_index: dict[int, int] = {}
init_none_hash(sha256)
self._block_hasher = get_request_block_hasher(gpu_block_size, sha256)
self._dummy_ctx: ForwardContext = ForwardContext(
no_compile_layers={}, attn_metadata={}, virtual_engine=0
)
def new_request(self, token_ids: list[int]):
assert not self.scheduler.requests
self.req_id += 1
req = Request(
request_id=str(self.req_id),
prompt_token_ids=token_ids,
sampling_params=SamplingParams(max_tokens=1000),
pooling_params=None,
eos_token_id=EOS_TOKEN_ID,
block_hasher=self._block_hasher,
)
self.scheduler.add_request(req)
def _wait_for_transfers(self):
block_size_factor = self.offloaded_block_size // self.gpu_block_size
while self.pending_loads_count or self.pending_stores_count:
for transfer_spec in self.offloading_spec.get_completed_transfers():
src_spec, dst_spec = transfer_spec
if isinstance(src_spec, GPULoadStoreSpec):
store = True
gpu_spec = src_spec
offload_spec = dst_spec
else:
store = False
gpu_spec = dst_spec
offload_spec = src_spec
assert isinstance(offload_spec, MockLoadStoreSpec)
assert isinstance(gpu_spec, GPULoadStoreSpec)
gpu_block_indices: list[int] = []
for block_id in gpu_spec.block_ids:
gpu_block_indices.append(self.gpu_block_index[block_id.item()])
# list of (block_hash, sub_block_offset)
offload_addresses: list[Any] = []
for block_hash in offload_spec.block_hashes:
for sub_block_idx in range(block_size_factor):
offload_addresses.append((block_hash, sub_block_idx))
if store:
assert len(gpu_block_indices) == len(offload_addresses)
self.completed_stores.append(
TransferSummary(gpu_block_indices, offload_addresses)
)
self.pending_stores_count -= 1
else:
remainder_sub_block_count = len(offload_addresses) - len(
gpu_block_indices
)
assert remainder_sub_block_count >= 0
assert remainder_sub_block_count < block_size_factor
offload_addresses = offload_addresses[remainder_sub_block_count:]
self.completed_loads.append(
TransferSummary(gpu_block_indices, offload_addresses)
)
self.pending_loads_count -= 1
def _update_gpu_block_idx(self):
for blocks in self.scheduler.kv_cache_manager.coordinator.single_type_managers[
0
].req_to_blocks.values():
for block_idx, block in enumerate(blocks):
self.gpu_block_index[block.block_id] = block_idx
def _run(self, decoded_tokens: list[int]):
"""
Runs multiple engine (scheduler + worker) steps.
Assumes a single request is running.
Args:
decoded_tokens: the tokens to yield at each step.
"""
tokens_iter = iter(decoded_tokens)
token_id = next(tokens_iter, None)
while token_id is not None:
assert self.scheduler.requests
scheduler_output = self.scheduler.schedule()
self._update_gpu_block_idx()
kv_connector_metadata = scheduler_output.kv_connector_metadata
assert kv_connector_metadata is not None
assert isinstance(kv_connector_metadata, OffloadingConnectorMetadata)
self.pending_loads_count += len(kv_connector_metadata.reqs_to_load)
self.pending_stores_count += len(kv_connector_metadata.reqs_to_store)
self.worker_connector.bind_connector_metadata(kv_connector_metadata)
self.worker_connector.start_load_kv(self._dummy_ctx)
if scheduler_output.total_num_scheduled_tokens > 0:
self.worker_connector.wait_for_save()
finished_sending, finished_recving = self.worker_connector.get_finished(
scheduler_output.finished_req_ids
)
self.worker_connector.clear_connector_metadata()
model_runner_output = create_model_runner_output(
reqs=self.scheduler.running,
finished_sending=finished_sending,
finished_recving=finished_recving,
token_id=token_id,
)
if self.scheduler.running:
token_id = next(tokens_iter, None)
self.scheduler.update_from_output(scheduler_output, model_runner_output)
self._wait_for_transfers()
# run one more step to update finished stored
if EOS_TOKEN_ID in decoded_tokens:
assert not self.scheduler.running
while self.scheduler.requests:
scheduler_output = self.scheduler.schedule()
finished_sending, finished_recving = self.worker_connector.get_finished(
scheduler_output.finished_req_ids
)
assert not finished_recving
model_runner_output = copy.deepcopy(EMPTY_MODEL_RUNNER_OUTPUT)
model_runner_output.kv_connector_output = KVConnectorOutput(
finished_sending=finished_sending
)
self.scheduler.update_from_output(scheduler_output, model_runner_output)
def run(
self,
decoded_tokens: list[int],
expected_stored_gpu_block_indexes: tuple[int, ...] = (),
expected_loaded_gpu_block_indexes: tuple[int, ...] = (),
):
"""
Runs multiple engine (scheduler + worker) steps.
Assumes a single request is running.
Args:
decoded_tokens: the tokens to yield at each step.
expected_stored_gpu_block_indexes: GPU block indexes
that are expected to be written during the run.
expected_loaded_gpu_block_indexes: GPU block indexes
that are expected to be loaded during the run.
"""
self.manager.reset_mock()
self._run(decoded_tokens)
loaded_gpu_block_indexes: set[int] = set()
for transfer in self.completed_loads:
for gpu_block_idx, offloaded_address in zip(
transfer.gpu_block_indices, transfer.offload_addresses
):
loaded_gpu_block_indexes.add(gpu_block_idx)
assert gpu_block_idx == self.offloaded[offloaded_address]
assert set(expected_loaded_gpu_block_indexes) == loaded_gpu_block_indexes
self.completed_loads.clear()
stored_gpu_block_indexes: set[int] = set()
for transfer in self.completed_stores:
for gpu_block_idx, offloaded_address in zip(
transfer.gpu_block_indices, transfer.offload_addresses
):
stored_gpu_block_indexes.add(gpu_block_idx)
self.offloaded[offloaded_address] = gpu_block_idx
assert set(expected_stored_gpu_block_indexes) == stored_gpu_block_indexes
self.completed_stores.clear()
@pytest.fixture
def request_runner():
runners = []
def runner_factory(offloaded_block_size, gpu_block_size, num_gpu_blocks):
runner = RequestRunner(
offloaded_block_size=offloaded_block_size,
gpu_block_size=gpu_block_size,
num_gpu_blocks=num_gpu_blocks,
)
runners.append(runner)
return runner
yield runner_factory # pass factory to the test
def generate_store_output(block_hashes: Iterable[BlockHash]):
block_hashes = list(block_hashes)
return PrepareStoreOutput(
block_hashes_to_store=list(block_hashes),
store_spec=MockLoadStoreSpec(block_hashes),
block_hashes_evicted=[],
)
def test_offloading_connector(request_runner):
offloaded_block_size = 12
gpu_block_size = 4
num_gpu_blocks = 100
block_size_factor = offloaded_block_size // gpu_block_size
runner = request_runner(
offloaded_block_size=offloaded_block_size,
gpu_block_size=gpu_block_size,
num_gpu_blocks=num_gpu_blocks,
)
# 3 blocks, store just the middle block (skip first and last)
# blocks = [0, 1, 2], [3, 4, 5], [6, 7, 8]
runner.new_request(token_ids=[0] * offloaded_block_size * 3)
runner.manager.prepare_store.side_effect = (
lambda block_hashes: generate_store_output(list(block_hashes)[1:2])
)
runner.run(decoded_tokens=[0], expected_stored_gpu_block_indexes=(3, 4, 5))
# add block missing 1 token -> no offload
runner.run(decoded_tokens=[0] * (offloaded_block_size - 1))
runner.manager.prepare_store.assert_not_called()
# +1 token -> single block, fail prepare_store
runner.manager.prepare_store.side_effect = lambda block_hashes: None
runner.run(decoded_tokens=[0])
runner.manager.prepare_store.assert_called()
# 1 more block, now set block_hashes_to_store = []
runner.manager.prepare_store.side_effect = (
lambda block_hashes: generate_store_output([])
)
runner.run(decoded_tokens=[0] * offloaded_block_size)
# 1 more block, now check touch was called with all 6 blocks
runner.manager.prepare_store.side_effect = (
lambda block_hashes: generate_store_output(block_hashes)
)
runner.run(
decoded_tokens=[0] * offloaded_block_size,
expected_stored_gpu_block_indexes=(15, 16, 17),
)
runner.manager.touch.assert_called()
block_hashes1 = list(runner.manager.touch.call_args.args[0])
assert len(block_hashes1) == 6
# terminate request
runner.run(decoded_tokens=[EOS_TOKEN_ID])
# create a new request differing only on the last token
runner.new_request(token_ids=[0] * (offloaded_block_size * 6 - 1) + [1])
runner.run(
decoded_tokens=[0],
expected_stored_gpu_block_indexes=tuple(range(6 * block_size_factor)),
)
runner.manager.touch.assert_called()
block_hashes2 = list(runner.manager.touch.call_args.args[0])
assert len(block_hashes2) == 6
# verify hashes are the same, except for the last block
assert block_hashes1[:5] == block_hashes2[:5]
assert block_hashes1[5] != block_hashes2[5]
# terminate request
runner.run(decoded_tokens=[EOS_TOKEN_ID])
# full_block_tokens - num_computed_tokens < offloaded_block_size
runner.new_request(
token_ids=[0] * gpu_block_size + [1] * (offloaded_block_size - gpu_block_size)
)
runner.manager.prepare_store.side_effect = (
lambda block_hashes: generate_store_output([])
)
runner.run(decoded_tokens=[EOS_TOKEN_ID])
runner.manager.lookup.assert_not_called()
# single block lookup with no hits
runner.new_request(token_ids=[1] * offloaded_block_size)
runner.manager.prepare_store.side_effect = (
lambda block_hashes: generate_store_output([])
)
runner.run(decoded_tokens=[EOS_TOKEN_ID])
runner.manager.lookup.assert_called()
assert len(list(runner.manager.lookup.call_args.args[0])) == 1
# single block lookup with a hit
runner.scheduler.reset_prefix_cache()
runner.new_request(token_ids=[0] * offloaded_block_size)
runner.manager.prepare_store.side_effect = (
lambda block_hashes: generate_store_output([])
)
runner.manager.lookup.return_value = 1
runner.run(
decoded_tokens=[EOS_TOKEN_ID], expected_loaded_gpu_block_indexes=(0, 1, 2)
)
# single block lookup with a hit in a middle block
runner.new_request(
token_ids=[0] * offloaded_block_size * 2 + [1] * offloaded_block_size
)
runner.manager.prepare_store.side_effect = (
lambda block_hashes: generate_store_output([])
)
runner.manager.lookup.return_value = 1
runner.run(
decoded_tokens=[EOS_TOKEN_ID], expected_loaded_gpu_block_indexes=(3, 4, 5)
)
# test take_events
def to_hashes(int_hashes: list[int]) -> list[BlockHash]:
return [BlockHash(str(i).encode()) for i in int_hashes]
def take_events() -> Iterable[OffloadingEvent]:
yield OffloadingEvent(
block_hashes=to_hashes([1, 2, 3]), block_size=16, medium="A", removed=False
)
yield OffloadingEvent(
block_hashes=to_hashes([4, 5, 6]), block_size=32, medium="B", removed=True
)
runner.manager.take_events.side_effect = take_events
events = list(runner.scheduler_connector.take_events())
assert len(events) == 2
event = events[0]
assert isinstance(event, BlockStored)
assert event.block_hashes == to_hashes([1, 2, 3])
assert event.block_size == 16
assert event.medium == "A"
assert event.token_ids == []
assert event.parent_block_hash is None
assert event.lora_id is None
event = events[1]
assert isinstance(event, BlockRemoved)
assert event.block_hashes == to_hashes([4, 5, 6])
assert event.medium == "B"

View File

@@ -0,0 +1,122 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pytest
from vllm.distributed.kv_transfer.kv_connector.utils import KVOutputAggregator
from vllm.v1.outputs import KVConnectorOutput, ModelRunnerOutput
pytestmark = pytest.mark.cpu_test
class DummyModelRunnerOutput(ModelRunnerOutput):
def __init__(
self,
finished_sending: set[str] | None = None,
finished_recving: set[str] | None = None,
invalid_block_ids: set[int] | None = None,
expected_finished_count: int = 0,
):
self.kv_connector_output = KVConnectorOutput(
finished_sending=finished_sending,
finished_recving=finished_recving,
invalid_block_ids=invalid_block_ids or set(),
expected_finished_count=expected_finished_count,
)
def __repr__(self):
return (
f"DummyModelRunnerOutput("
f"finished_sending={self.kv_connector_output.finished_sending},"
f"finished_recving={self.kv_connector_output.finished_recving})"
f"invalid_block_ids={self.kv_connector_output.invalid_block_ids})"
)
def test_aggregate_workers_output():
aggregator = KVOutputAggregator(expected_finished_count=2)
output1 = DummyModelRunnerOutput()
output2 = DummyModelRunnerOutput()
aggregated = aggregator.aggregate([output1, output2])
assert aggregated is output1
aggregated = aggregated.kv_connector_output
assert aggregated.finished_sending is None
assert aggregated.finished_recving is None
assert not aggregated.invalid_block_ids
output1 = DummyModelRunnerOutput(
finished_sending={"req1"}, finished_recving={"req2"}
)
output2 = DummyModelRunnerOutput(invalid_block_ids={1})
aggregated = aggregator.aggregate([output1, output2])
assert aggregated is output1
aggregated = aggregated.kv_connector_output
assert aggregated.finished_sending is None
assert aggregated.finished_recving is None
assert aggregated.invalid_block_ids == {1}
output1 = DummyModelRunnerOutput(invalid_block_ids={2})
output2 = DummyModelRunnerOutput(finished_sending={"req1"})
aggregated = aggregator.aggregate([output1, output2])
assert aggregated is output1
aggregated = aggregated.kv_connector_output
assert aggregated.finished_sending == {"req1"}
assert aggregated.finished_recving is None
assert aggregated.invalid_block_ids == {2}
output1 = DummyModelRunnerOutput(invalid_block_ids={3, 4})
output2 = DummyModelRunnerOutput(
finished_recving={"req2"}, invalid_block_ids={4, 5}
)
aggregated = aggregator.aggregate([output1, output2])
assert aggregated is output1
aggregated = aggregated.kv_connector_output
assert aggregated.finished_sending is None
assert aggregated.finished_recving == {"req2"}
assert aggregated.invalid_block_ids == {3, 4, 5}
def test_aggregate_workers_output_with_expected_finished_count():
# We create the aggregator expecting to collect from 4 workers
aggregator = KVOutputAggregator(expected_finished_count=4)
assert aggregator._expected_finished_count == 4
# Some request with default expected finished requests
output1 = DummyModelRunnerOutput(finished_sending={"req1"})
aggregated = aggregator.aggregate([output1])
# still expecting to collect from 4 workers
assert aggregator._send_remaining_count["req1"] == 3
assert not aggregated.kv_connector_output.finished_sending
assert not aggregated.kv_connector_output.finished_recving
# Workers discover and find that in this setup they only need to
# collect from 2
output1 = DummyModelRunnerOutput(
finished_sending={"req1"}, expected_finished_count=2
)
output2 = DummyModelRunnerOutput(
finished_recving={"req2"}, expected_finished_count=2
)
output3 = DummyModelRunnerOutput(finished_recving={"req2"})
# Req2 only needs 2 acks
aggregated = aggregator.aggregate([output1, output2, output3])
assert aggregated.kv_connector_output.expected_finished_count == 2
assert not aggregated.kv_connector_output.finished_sending
# Req2 is finished
assert "req2" not in aggregator._recv_remaining_count
assert aggregated.kv_connector_output.finished_recving == {"req2"}
# Req1 is still waiting for 2 more acks (expected_finished_count has no effect)
# NOTE: This is to showcase dynamic update. Workers are responsible for
# ensuring "req1" termination in this case
assert aggregator._send_remaining_count["req1"] == 2

View File

@@ -0,0 +1,262 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import copy
import pytest
from vllm.v1.outputs import EMPTY_MODEL_RUNNER_OUTPUT, KVConnectorOutput
from vllm.v1.request import FinishReason, RequestStatus
from .utils import (
assert_scheduler_empty,
create_model_runner_output,
create_request,
create_scheduler,
create_vllm_config,
)
pytestmark = pytest.mark.cpu_test
def test_basic_lifecycle():
"""Test lifecycle of a Remote Decode request."""
vllm_config = create_vllm_config()
scheduler = create_scheduler(vllm_config)
# 2 Full Blocks and 1 Half Block.
BLOCK_SIZE = vllm_config.cache_config.block_size
NUM_EXTERNAL_FULL_BLOCKS = 2
NUM_TOKENS = int(BLOCK_SIZE * (NUM_EXTERNAL_FULL_BLOCKS + 0.5))
request = create_request(
request_id=1,
block_size=BLOCK_SIZE,
max_tokens=1,
num_tokens=NUM_TOKENS,
do_remote_decode=True,
)
scheduler.add_request(request)
request_id = request.request_id
# STEP (1): Prefill.
# (1a): schedule()
scheduler_output = scheduler.schedule()
assert len(scheduler.requests) == 1
assert len(scheduler.running) == 1
assert len(scheduler_output.scheduled_new_reqs) == 1
# (1b): execute_model()
model_runner_output = create_model_runner_output(reqs=[request])
# (1c): update_from_output()
engine_core_outputs = scheduler.update_from_output(
scheduler_output, model_runner_output
)
# Ensure the request is finished after 1 token.
assert request.is_finished()
assert request.status == RequestStatus.FINISHED_LENGTH_CAPPED
output = engine_core_outputs[0].outputs[0]
assert output.finish_reason == FinishReason.LENGTH
assert output.kv_transfer_params is not None
# Request freed in Scheduler and in Persistent Batch ...
assert request_id in scheduler.finished_req_ids
assert len(scheduler.running) == 0
assert len(scheduler.waiting) == 0
# ... but blocks should not be freed.
assert len(scheduler.requests) == 1
blocks = scheduler.kv_cache_manager.coordinator.single_type_managers[
0
].req_to_blocks[request_id]
for block in blocks:
assert block.ref_cnt == 1
# STEP (2): Send Finished to PB.
# (2a): schedule() - pass finished request to PB.
scheduler_output = scheduler.schedule()
assert len(scheduler.requests) == 1
assert len(scheduler.running) == 0
assert len(scheduler_output.finished_req_ids) == 1
assert request_id in scheduler_output.finished_req_ids
assert len(scheduler_output.scheduled_new_reqs) == 0
assert scheduler_output.scheduled_cached_reqs.num_reqs == 0
assert len(scheduler.finished_req_ids) == 0
# (2b): execute_model()
model_runner_output = EMPTY_MODEL_RUNNER_OUTPUT
# (2c): update_from_output()
scheduler.update_from_output(scheduler_output, model_runner_output)
# STEP (3): Finished sending.
# (3a): schedule() - pass finished request to PB.
scheduler_output = scheduler.schedule()
assert len(scheduler.requests) == 1
assert len(scheduler.running) == 0
assert len(scheduler_output.finished_req_ids) == 0
assert len(scheduler_output.scheduled_new_reqs) == 0
assert scheduler_output.scheduled_cached_reqs.num_reqs == 0
assert len(scheduler.finished_req_ids) == 0
# (3b): execute_model()
model_runner_output = copy.deepcopy(EMPTY_MODEL_RUNNER_OUTPUT)
model_runner_output.kv_connector_output = KVConnectorOutput(
finished_sending={request_id}
)
# (3c): update_from_output()
scheduler.update_from_output(scheduler_output, model_runner_output)
# Confirm we do not have any memory leaks after req lifecycle.
assert_scheduler_empty(scheduler)
def test_short_prompt_lifecycle():
"""Test lifecycle of a Remote Decode request with short prompt."""
vllm_config = create_vllm_config()
scheduler = create_scheduler(vllm_config)
# Not enough tokens for full block.
BLOCK_SIZE = vllm_config.cache_config.block_size
NUM_TOKENS = BLOCK_SIZE // 2
request = create_request(
request_id=1,
block_size=BLOCK_SIZE,
max_tokens=1,
num_tokens=NUM_TOKENS,
do_remote_decode=True,
)
scheduler.add_request(request)
# STEP (1): Prefill.
# (1a): schedule()
scheduler_output = scheduler.schedule()
assert len(scheduler.requests) == 1
assert len(scheduler.running) == 1
assert len(scheduler_output.scheduled_new_reqs) == 1
# (1b): execute_model()
model_runner_output = create_model_runner_output(reqs=[request])
# (1c): update_from_output()
# Even though tokens < block_size, there will be kv xfer for partial block.
eco = scheduler.update_from_output(scheduler_output, model_runner_output)
kv_transfer_params = eco[0].outputs[0].kv_transfer_params
assert len(kv_transfer_params["remote_block_ids"]) == 1
# Confirm we do not have any memory leaks after req lifecycle.
# We need to mark sending finish to clear data for persistent batch.
scheduler_output = scheduler.schedule()
# Use create_model_runner_output to pass kv_connector_output along
model_runner_output = create_model_runner_output(
reqs=[request], finished_sending={request.request_id}
)
scheduler.update_from_output(scheduler_output, model_runner_output)
assert_scheduler_empty(scheduler)
def test_prefix_cache_lifecycle():
"""Test that remote decode params still work with a prefix cache hit."""
vllm_config = create_vllm_config()
scheduler = create_scheduler(vllm_config)
# Prime the KVCache.
BLOCK_SIZE = vllm_config.cache_config.block_size
NUM_EXTERNAL_FULL_BLOCKS = 3
NUM_TOKENS = int(BLOCK_SIZE * (NUM_EXTERNAL_FULL_BLOCKS + 0.5))
request_normal = create_request(
request_id=1, block_size=BLOCK_SIZE, num_tokens=NUM_TOKENS
)
scheduler.add_request(request_normal)
scheduler_output = scheduler.schedule()
model_runner_output = create_model_runner_output(
reqs=[request_normal], use_eos=True
)
scheduler.update_from_output(scheduler_output, model_runner_output)
scheduler_output = scheduler.schedule()
scheduler.update_from_output(scheduler_output, EMPTY_MODEL_RUNNER_OUTPUT)
#####################
# Actual Test: confirm we send all blocks.
# Step (1): Send the KV Transfer.
NUM_EXTERNAL_FULL_BLOCKS -= 1
NUM_TOKENS = int(BLOCK_SIZE * (NUM_EXTERNAL_FULL_BLOCKS + 0.5))
request_remote = create_request(
request_id=1,
block_size=BLOCK_SIZE,
num_tokens=NUM_TOKENS,
do_remote_decode=True,
)
scheduler.add_request(request_remote)
scheduler_output = scheduler.schedule()
model_runner_output = create_model_runner_output(reqs=[request_remote])
eco = scheduler.update_from_output(scheduler_output, model_runner_output)
kv_transfer_params = eco[0].outputs[0].kv_transfer_params
# Ensure we send all block ids, including the partial blocks,
# even if there is a cache hit.
assert len(kv_transfer_params["remote_block_ids"]) == (NUM_EXTERNAL_FULL_BLOCKS + 1)
# STEP (2): Ensure it is freed.
scheduler_output = scheduler.schedule()
model_runner_output = copy.deepcopy(EMPTY_MODEL_RUNNER_OUTPUT)
model_runner_output.kv_connector_output = KVConnectorOutput(
finished_sending={request_remote.request_id}
)
scheduler.update_from_output(scheduler_output, model_runner_output)
assert_scheduler_empty(scheduler)
def test_abort_during_kv_transfer():
"""Test aborting request does not release blocks for remote decode."""
vllm_config = create_vllm_config()
scheduler = create_scheduler(vllm_config)
# Prime the KVCache.
BLOCK_SIZE = vllm_config.cache_config.block_size
NUM_EXTERNAL_FULL_BLOCKS = 2
NUM_TOKENS = int(BLOCK_SIZE * (NUM_EXTERNAL_FULL_BLOCKS + 0.5))
request = create_request(
request_id=1,
block_size=BLOCK_SIZE,
num_tokens=NUM_TOKENS,
do_remote_decode=True,
)
scheduler.add_request(request)
scheduler_output = scheduler.schedule()
model_runner_output = create_model_runner_output(reqs=[request])
scheduler.update_from_output(scheduler_output, model_runner_output)
scheduler_output = scheduler.schedule()
scheduler.update_from_output(scheduler_output, EMPTY_MODEL_RUNNER_OUTPUT)
# Request removed from PB but blocks should not be freed.
assert len(scheduler.requests) == 1
# Abort the request, and check the blocks are still not freed
scheduler.finish_requests([request.request_id], RequestStatus.FINISHED_ABORTED)
assert len(scheduler.requests) == 1
# Simulate a finished sending notification
scheduler_output = scheduler.schedule()
model_runner_output = copy.deepcopy(EMPTY_MODEL_RUNNER_OUTPUT)
model_runner_output.kv_connector_output = KVConnectorOutput(
finished_sending=[request.request_id]
)
scheduler.update_from_output(scheduler_output, model_runner_output)
assert_scheduler_empty(scheduler)

View File

@@ -0,0 +1,577 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import copy
import pytest
from vllm.v1.outputs import EMPTY_MODEL_RUNNER_OUTPUT, KVConnectorOutput
from vllm.v1.request import FinishReason, RequestStatus
from .utils import (
assert_scheduler_empty,
create_model_runner_output,
create_request,
create_scheduler,
create_vllm_config,
)
pytestmark = pytest.mark.cpu_test
def test_basic_lifecycle():
"""Test lifecycle of a remote prefill."""
vllm_config = create_vllm_config()
scheduler = create_scheduler(vllm_config)
# 2 Full Blocks and 1 Half Block.
BLOCK_SIZE = vllm_config.cache_config.block_size
NUM_EXTERNAL_FULL_BLOCKS = 2
NUM_TOKENS = int(BLOCK_SIZE * (NUM_EXTERNAL_FULL_BLOCKS + 0.5))
START_FREE_BLOCK_QUEUE_SIZE = (
scheduler.kv_cache_manager.block_pool.free_block_queue.num_free_blocks
)
request = create_request(
request_id=1,
block_size=BLOCK_SIZE,
num_tokens=NUM_TOKENS,
do_remote_prefill=True,
)
scheduler.add_request(request)
request_id = request.request_id
# STEP (1):
# (1a): schedule()
scheduler_output = scheduler.schedule()
# Nothing running and empty scheduler output.
assert len(scheduler.running) == 0
assert len(scheduler_output.scheduled_new_reqs) == 0
assert scheduler_output.scheduled_cached_reqs.num_reqs == 0
assert len(scheduler_output.num_scheduled_tokens) == 0
assert scheduler_output.total_num_scheduled_tokens == 0
# Req waiting for KVs with no computed/scheduled toks ...
assert len(scheduler.waiting) == 1
assert request in scheduler.waiting
assert request.status == RequestStatus.WAITING_FOR_REMOTE_KVS
assert request.num_computed_tokens == 0
# ... but should have (uncached) blocks allocated to it.
block_pool = scheduler.kv_cache_manager.block_pool
assert block_pool.free_block_queue.num_free_blocks < START_FREE_BLOCK_QUEUE_SIZE
assert len(block_pool.cached_block_hash_to_block) == 0
blocks = scheduler.kv_cache_manager.coordinator.single_type_managers[
0
].req_to_blocks[request_id]
for block in blocks:
assert block._block_hash is None
# (1b): forward()
model_runner_output = EMPTY_MODEL_RUNNER_OUTPUT
# (1c): update_from_output()
engine_core_outputs = scheduler.update_from_output(
scheduler_output, model_runner_output
)
assert not engine_core_outputs or not engine_core_outputs[0].outputs
# STEP (2):
# (2a): schedule(): nothing happens!
scheduler_output = scheduler.schedule()
assert len(scheduler.waiting) == 1
assert len(scheduler.running) == 0
# (2b): forward(): request finishes recv.
model_runner_output = copy.deepcopy(EMPTY_MODEL_RUNNER_OUTPUT)
model_runner_output.kv_connector_output = KVConnectorOutput(
finished_recving={request_id}
)
# (2c): update_from_output():
engine_core_outputs = scheduler.update_from_output(
scheduler_output, model_runner_output
)
assert len(scheduler.waiting) == 1
assert request_id in scheduler.finished_recving_kv_req_ids
# STEP (3):
# (3a): schedule(): this should actually schedule.
scheduler_output = scheduler.schedule()
assert len(scheduler.running) == 1
# Confirm the block are actually allocated.
num_hashed_blocks = 0
blocks = scheduler.kv_cache_manager.coordinator.single_type_managers[
0
].req_to_blocks[request_id]
for block in blocks:
assert block.ref_cnt == 1
num_hashed_blocks += 1 if block._block_hash is not None else 0
assert num_hashed_blocks == NUM_EXTERNAL_FULL_BLOCKS
# Confirm the rest of the prompt is scheduled in this step.
scheduled_req = scheduler_output.scheduled_new_reqs[0]
num_scheduled_tokens = scheduler_output.num_scheduled_tokens[request_id]
num_computed_tokens = scheduled_req.num_computed_tokens
total_prompt_tokens = len(scheduled_req.prompt_token_ids)
assert num_scheduled_tokens == total_prompt_tokens - num_computed_tokens
# (3b): execute_model()
model_runner_output = create_model_runner_output([request])
# (3c): update_from_output()
scheduler.update_from_output(scheduler_output, model_runner_output)
# Step (4): Hit EOS.
scheduler_output = scheduler.schedule()
model_runner_output = create_model_runner_output([request], use_eos=True)
engine_core_outputs = scheduler.update_from_output(
scheduler_output, model_runner_output
)
scheduler.schedule()
outputs = engine_core_outputs[0].outputs
assert len(outputs) == 1
output = outputs[0]
assert output.finish_reason == FinishReason.STOP
assert_scheduler_empty(scheduler)
def test_interleaved_lifecycle():
"""Test Remote Prefills Work Well With Other Requests."""
vllm_config = create_vllm_config()
scheduler = create_scheduler(vllm_config)
# 2 Full Blocks and 1 Half Block.
BLOCK_SIZE = vllm_config.cache_config.block_size
NUM_EXTERNAL_FULL_BLOCKS = 2
NUM_TOKENS = int(BLOCK_SIZE * (NUM_EXTERNAL_FULL_BLOCKS + 0.5))
request_remote = create_request(
request_id=1,
block_size=BLOCK_SIZE,
num_tokens=NUM_TOKENS,
do_remote_prefill=True,
)
request_local_a = create_request(
request_id=2,
block_size=BLOCK_SIZE,
num_tokens=NUM_TOKENS,
)
request_local_b = create_request(
request_id=3,
block_size=BLOCK_SIZE,
num_tokens=NUM_TOKENS,
)
# STEP 1: Regular request is running.
scheduler.add_request(request_local_a)
scheduler_output = scheduler.schedule()
assert len(scheduler.running) == 1
model_runner_output = create_model_runner_output([request_local_a])
scheduler.update_from_output(scheduler_output, model_runner_output)
# STEP 2: Add a local and remote request.
scheduler.add_request(request_local_b)
scheduler.add_request(request_remote)
scheduler_output = scheduler.schedule()
assert len(scheduler.running) == 2
assert len(scheduler.waiting) == 1
assert len(scheduler_output.scheduled_new_reqs) == 1
assert scheduler_output.scheduled_cached_reqs.num_reqs == 1
model_runner_output = create_model_runner_output([request_local_a, request_local_b])
scheduler.update_from_output(scheduler_output, model_runner_output)
# STEP 3: continue running, KVs not arrived yet.
scheduler_output = scheduler.schedule()
assert len(scheduler.running) == 2
assert len(scheduler.waiting) == 1
assert len(scheduler_output.scheduled_new_reqs) == 0
assert scheduler_output.scheduled_cached_reqs.num_reqs == 2
model_runner_output = create_model_runner_output(
reqs=[request_local_a, request_local_b]
)
scheduler.update_from_output(scheduler_output, model_runner_output)
assert len(scheduler.running) == 2
assert len(scheduler.waiting) == 1
assert len(scheduler_output.scheduled_new_reqs) == 0
assert scheduler_output.scheduled_cached_reqs.num_reqs == 2
# STEP 4: KVs arrive.
scheduler_output = scheduler.schedule()
assert len(scheduler.running) == 2
assert len(scheduler.waiting) == 1
assert len(scheduler_output.scheduled_new_reqs) == 0
assert scheduler_output.scheduled_cached_reqs.num_reqs == 2
model_runner_output = create_model_runner_output(
[request_local_a, request_local_b], finished_recving={request_remote.request_id}
)
scheduler.update_from_output(scheduler_output, model_runner_output)
# STEP 5: RECVed KVs are sent to ModelRunner.
scheduler_output = scheduler.schedule()
assert len(scheduler.running) == 3
assert len(scheduler.waiting) == 0
assert len(scheduler_output.scheduled_new_reqs) == 1
assert scheduler_output.scheduled_cached_reqs.num_reqs == 2
model_runner_output = create_model_runner_output(
[request_local_a, request_local_b, request_remote]
)
scheduler.update_from_output(scheduler_output, model_runner_output)
# STEP 6: Hit EOS and free.
scheduler_output = scheduler.schedule()
model_runner_output = create_model_runner_output(
[request_local_a, request_local_b, request_remote],
use_eos=True,
)
scheduler.update_from_output(scheduler_output, model_runner_output)
scheduler.schedule()
assert_scheduler_empty(scheduler)
def test_no_spurious_prefix_caching():
"""
With P/D, blocks can be allocated but uncomputed for
multiple engine steps. This test confirms that we do
not accidentally have cache hits against uncomputed
blocks.
"""
vllm_config = create_vllm_config()
scheduler = create_scheduler(vllm_config)
vllm_config = create_vllm_config()
scheduler = create_scheduler(vllm_config)
# 2 and a half full external blocks.
BLOCK_SIZE = vllm_config.cache_config.block_size
NUM_EXTERNAL_FULL_BLOCKS = 2
NUM_TOKENS = int(BLOCK_SIZE * (NUM_EXTERNAL_FULL_BLOCKS + 0.5))
# Both of these requests have prompts like [1,1,1,1,1, ...]
request_remote = create_request(
request_id=1,
block_size=BLOCK_SIZE,
num_tokens=NUM_TOKENS,
common_prefix_len=NUM_TOKENS,
do_remote_prefill=True,
)
request_local = create_request(
request_id=2,
block_size=BLOCK_SIZE,
num_tokens=NUM_TOKENS,
common_prefix_len=NUM_TOKENS,
do_remote_prefill=False,
)
# Schedule the remote prefill request. This should not
# cause any blocks to be cached.
scheduler.add_request(request_remote)
scheduler_output = scheduler.schedule()
scheduler.update_from_output(scheduler_output, EMPTY_MODEL_RUNNER_OUTPUT)
assert len(scheduler.waiting) == 1
# Schedule the local prefill request. This should
# cause blocks to be cached, but separately from
scheduler.add_request(request_local)
scheduler_output = scheduler.schedule()
assert len(scheduler.running) == 1
assert len(scheduler.waiting) == 1
local_blocks = scheduler.kv_cache_manager.coordinator.single_type_managers[
0
].req_to_blocks[request_local.request_id]
remote_blocks = scheduler.kv_cache_manager.coordinator.single_type_managers[
0
].req_to_blocks[request_remote.request_id]
# Local should have cached blocks (but not all due to preallocate).
num_hashed_blocks = 0
for block in local_blocks:
assert block.ref_cnt == 1
num_hashed_blocks += 1 if block._block_hash is not None else 0
assert num_hashed_blocks > 0
# Remote blocks should not be cached.
for block in remote_blocks:
assert block.ref_cnt == 1
assert block._block_hash is None
def test_full_block_prompt():
"""Test that we handle a prompt that is the full block size."""
vllm_config = create_vllm_config()
scheduler = create_scheduler(vllm_config)
# 2 Full Blocks and 1 Half Block.
BLOCK_SIZE = vllm_config.cache_config.block_size
NUM_EXTERNAL_FULL_BLOCKS = 2
NUM_TOKENS = int(BLOCK_SIZE * NUM_EXTERNAL_FULL_BLOCKS)
request = create_request(
request_id=1,
block_size=BLOCK_SIZE,
num_tokens=NUM_TOKENS,
do_remote_prefill=True,
)
scheduler.add_request(request)
request_id = request.request_id
# STEP (1): Initialize a recv.
scheduler_output = scheduler.schedule()
# All blocks should be allocated.
num_blocks = len(
scheduler.kv_cache_manager.coordinator.single_type_managers[0].req_to_blocks[
request_id
]
)
assert num_blocks == NUM_EXTERNAL_FULL_BLOCKS
model_runner_output = EMPTY_MODEL_RUNNER_OUTPUT
scheduler.update_from_output(scheduler_output, model_runner_output)
# # STEP (2): Recv.
scheduler_output = scheduler.schedule()
model_runner_output = copy.deepcopy(EMPTY_MODEL_RUNNER_OUTPUT)
model_runner_output.kv_connector_output = KVConnectorOutput(
finished_recving={request_id}
)
scheduler.update_from_output(scheduler_output, model_runner_output)
assert len(scheduler.waiting) == 1
assert request_id in scheduler.finished_recving_kv_req_ids
# # STEP (3): Run as usual.
scheduler_output = scheduler.schedule()
# We need to recompute the final token of the prompt to generate
# the first new token, so we should not have a new block.
num_blocks = len(
scheduler.kv_cache_manager.coordinator.single_type_managers[0].req_to_blocks[
request_id
]
)
assert num_blocks == NUM_EXTERNAL_FULL_BLOCKS
assert scheduler_output.scheduled_new_reqs[0].num_computed_tokens == NUM_TOKENS - 1
assert scheduler_output.num_scheduled_tokens[request_id] == 1
model_runner_output = create_model_runner_output([request])
scheduler.update_from_output(scheduler_output, model_runner_output)
# # Step (4): Hit EOS.
scheduler_output = scheduler.schedule()
model_runner_output = create_model_runner_output([request], use_eos=True)
engine_core_outputs = scheduler.update_from_output(
scheduler_output, model_runner_output
)
scheduler.schedule()
outputs = engine_core_outputs[0].outputs
assert len(outputs) == 1
output = outputs[0]
assert output.finish_reason == FinishReason.STOP
assert_scheduler_empty(scheduler)
def test_cannot_schedule_after_recv():
"""
Test that we can handle no schedule after recv due to not
enough remaining KV blocks.
"""
# NOTE: the KVCacheManager will use 1 null block.
# So there are 5 total working blocks.
TOTAL_NUM_BLOCKS = 6
vllm_config = create_vllm_config()
scheduler = create_scheduler(vllm_config, num_blocks=TOTAL_NUM_BLOCKS)
# Prime the KVCache.
NUM_PROMPT_BLOCKS = 2
BLOCK_SIZE = vllm_config.cache_config.block_size
# Prompt will use 2 blocks + 1 block after we schedule.
NUM_TOKENS_LOCAL = int(BLOCK_SIZE * NUM_PROMPT_BLOCKS)
NUM_TOKENS_REMOTE = int(BLOCK_SIZE * NUM_PROMPT_BLOCKS)
request_normal = create_request(
request_id=1, block_size=BLOCK_SIZE, num_tokens=NUM_TOKENS_LOCAL
)
request_remote = create_request(
request_id=2,
block_size=BLOCK_SIZE,
num_tokens=NUM_TOKENS_REMOTE,
do_remote_prefill=True,
)
# STEP 1: 3 blocks are in use (2 for prompt, 1 for decode).
scheduler.add_request(request_normal)
scheduler_output = scheduler.schedule()
model_runner_output = create_model_runner_output(reqs=[request_normal])
scheduler.update_from_output(scheduler_output, model_runner_output)
assert len(scheduler.running) == 1
assert len(scheduler.waiting) == 0
# Step 2: 5 blocks are in use (2 new for remote blocks).
scheduler.add_request(request_remote)
scheduler_output = scheduler.schedule()
model_runner_output = create_model_runner_output(reqs=[request_normal])
scheduler.update_from_output(scheduler_output, model_runner_output)
assert len(scheduler.running) == 1
assert len(scheduler.waiting) == 1
# Step 3: finish recving (5 blocks in use)
scheduler_output = scheduler.schedule()
model_runner_output = create_model_runner_output(
reqs=[request_normal], finished_recving={request_remote.request_id}
)
scheduler.update_from_output(scheduler_output, model_runner_output)
assert len(scheduler.running) == 1
assert len(scheduler.waiting) == 1
# Step 4: try to schedule, remote request is put to running list
# because the transfer is completed.
scheduler_output = scheduler.schedule()
model_runner_output = create_model_runner_output(
reqs=[request_normal, request_remote]
)
scheduler.update_from_output(scheduler_output, model_runner_output)
assert len(scheduler.running) == 2
assert len(scheduler.waiting) == 0
# Step 5: Remote request will be put back to waiting list
# because it needs new block to hold generated token.
scheduler_output = scheduler.schedule()
model_runner_output = create_model_runner_output(reqs=[request_normal])
scheduler.update_from_output(scheduler_output, model_runner_output)
assert len(scheduler.running) == 1
assert len(scheduler.waiting) == 1
# Step 6: finish the request, free it.
scheduler_output = scheduler.schedule()
model_runner_output = create_model_runner_output(
reqs=[request_normal], use_eos=True
)
scheduler.update_from_output(scheduler_output, model_runner_output)
assert len(scheduler.running) == 0
assert len(scheduler.waiting) == 1
# Step 7: now we can schedule (with 2 blocks computed),
# request is retrieved from preempted list.
scheduler_output = scheduler.schedule()
model_runner_output = create_model_runner_output(reqs=[request_remote])
assert (
scheduler_output.scheduled_cached_reqs.num_computed_tokens[0]
== NUM_PROMPT_BLOCKS * BLOCK_SIZE
)
scheduler.update_from_output(scheduler_output, model_runner_output)
assert len(scheduler.running) == 1
assert len(scheduler.waiting) == 0
# Step 8: free everything.
scheduler_output = scheduler.schedule()
model_runner_output = create_model_runner_output(
reqs=[request_remote], use_eos=True
)
scheduler.update_from_output(scheduler_output, model_runner_output)
_ = scheduler.schedule()
assert_scheduler_empty(scheduler)
def test_cannot_recv():
"""
Test that we can handle no schedule KV block transfer due to not
enough remaining KV blocks.
"""
# NOTE: the KVCacheManager will use 1 null block.
# So there are 5 total working blocks.
TOTAL_NUM_BLOCKS = 6
vllm_config = create_vllm_config()
scheduler = create_scheduler(vllm_config, num_blocks=TOTAL_NUM_BLOCKS)
# Prime the KVCache.
NUM_PROMPT_BLOCKS = 2
BLOCK_SIZE = vllm_config.cache_config.block_size
# Prompt will use 2 blocks + 1 block after we schedule.
NUM_TOKENS_LOCAL = int(BLOCK_SIZE * NUM_PROMPT_BLOCKS)
NUM_TOKENS_REMOTE = int(BLOCK_SIZE * (NUM_PROMPT_BLOCKS + 0.5))
request_normal = create_request(
request_id=1, block_size=BLOCK_SIZE, num_tokens=NUM_TOKENS_LOCAL
)
request_remote = create_request(
request_id=2,
block_size=BLOCK_SIZE,
num_tokens=NUM_TOKENS_REMOTE,
do_remote_prefill=True,
)
# STEP 1: 3 blocks are in use (2 for prompt, 1 for decode).
scheduler.add_request(request_normal)
scheduler_output = scheduler.schedule()
model_runner_output = create_model_runner_output(reqs=[request_normal])
scheduler.update_from_output(scheduler_output, model_runner_output)
assert len(scheduler.running) == 1
assert len(scheduler.waiting) == 0
# Step 2: 3 blocks are in use,
# need 3 new for remote blocks but only 2 are available.
scheduler.add_request(request_remote)
scheduler_output = scheduler.schedule()
model_runner_output = create_model_runner_output(reqs=[request_normal])
scheduler.update_from_output(scheduler_output, model_runner_output)
assert len(scheduler.running) == 1
assert len(scheduler.waiting) == 1
# Should not have KV transfer in progress.
assert request_remote.status != RequestStatus.WAITING_FOR_REMOTE_KVS
# Step 3: finish the request, free it.
scheduler_output = scheduler.schedule()
model_runner_output = create_model_runner_output(
reqs=[request_normal], use_eos=True
)
scheduler.update_from_output(scheduler_output, model_runner_output)
assert len(scheduler.running) == 0
assert len(scheduler.waiting) == 1
# Step 4: now we can initiate KV transfer (with 2 blocks computed).
scheduler_output = scheduler.schedule()
model_runner_output = create_model_runner_output(reqs=[])
scheduler.update_from_output(scheduler_output, model_runner_output)
assert len(scheduler.running) == 0
assert len(scheduler.waiting) == 1
assert request_remote.status == RequestStatus.WAITING_FOR_REMOTE_KVS
# Step 5: finish recving (5 blocks in use)
scheduler_output = scheduler.schedule()
model_runner_output = create_model_runner_output(
reqs=[], finished_recving={request_remote.request_id}
)
scheduler.update_from_output(scheduler_output, model_runner_output)
assert len(scheduler.running) == 0
assert len(scheduler.waiting) == 1
# Step 6: schedule remote request
scheduler_output = scheduler.schedule()
model_runner_output = create_model_runner_output(reqs=[request_remote])
scheduler.update_from_output(scheduler_output, model_runner_output)
assert len(scheduler.running) == 1
assert len(scheduler.waiting) == 0
# Step 7: free everything.
scheduler_output = scheduler.schedule()
model_runner_output = create_model_runner_output(
reqs=[request_remote], use_eos=True
)
scheduler.update_from_output(scheduler_output, model_runner_output)
_ = scheduler.schedule()
assert_scheduler_empty(scheduler)

View File

@@ -0,0 +1,402 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import tempfile
from collections import defaultdict
from collections.abc import Callable
from dataclasses import dataclass
from itertools import chain, count
from typing import Any
import torch
from vllm import SamplingParams
from vllm.config import (
CacheConfig,
DeviceConfig,
KVTransferConfig,
ModelConfig,
SchedulerConfig,
VllmConfig,
)
from vllm.distributed.kv_transfer.kv_connector.factory import KVConnectorFactory
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
KVConnectorBase_V1,
KVConnectorMetadata,
KVConnectorRole,
)
from vllm.distributed.kv_transfer.kv_connector.v1.example_connector import ( # noqa
ExampleConnector,
)
from vllm.utils.hashing import sha256
from vllm.v1.core.kv_cache_manager import KVCacheBlocks
from vllm.v1.core.kv_cache_utils import get_request_block_hasher, init_none_hash
from vllm.v1.core.sched.scheduler import Scheduler, SchedulerOutput
from vllm.v1.kv_cache_interface import (
FullAttentionSpec,
KVCacheConfig,
KVCacheGroupSpec,
)
from vllm.v1.outputs import KVConnectorOutput, ModelRunnerOutput
from vllm.v1.request import Request
from vllm.v1.structured_output import StructuredOutputManager
EOS_TOKEN_ID = 50256
def assert_scheduler_empty(scheduler: Scheduler):
"""Confirm the scheduler is "empty" - i.e. no leaks."""
# Scheduler Metadata.
assert len(scheduler.requests) == 0
assert len(scheduler.waiting) == 0
assert len(scheduler.running) == 0
assert len(scheduler.finished_req_ids) == 0
assert len(scheduler.finished_recving_kv_req_ids) == 0
# EncoderCacheManager.
assert len(scheduler.encoder_cache_manager.freed) == 0
assert len(scheduler.encoder_cache_manager.cached) == 0
# KVCache Manager.
assert (
len(
scheduler.kv_cache_manager.coordinator.single_type_managers[0].req_to_blocks
)
== 0
)
assert (
len(
scheduler.kv_cache_manager.coordinator.single_type_managers[
0
].num_cached_block
)
== 0
)
num_free_blocks = (
scheduler.kv_cache_manager.block_pool.free_block_queue.num_free_blocks
)
assert num_free_blocks == (scheduler.kv_cache_manager.block_pool.num_gpu_blocks - 1)
# NOTE(rob): just the ref count on blocks will be 0. The hash
# value, etc will remain since we lazily evict for prefix cache.
for block in scheduler.kv_cache_manager.block_pool.blocks:
assert block.ref_cnt == 0
def create_vllm_config(
model: str = "facebook/opt-125m",
max_num_seqs: int = 16,
max_num_batched_tokens: int = 64,
block_size: int = 16,
max_model_len: int = 10000,
enable_chunked_prefill: bool = True,
enable_permute_local_kv: bool = False,
kv_connector_extra_config: dict[str, Any] | None = None,
dtype: str = "float16",
cache_dtype: str = "auto",
hf_overrides: dict[str, Any] | None = None,
) -> VllmConfig:
"""Initialize VllmConfig For Testing."""
model_config = ModelConfig(
model=model,
trust_remote_code=True,
dtype=dtype,
seed=42,
hf_overrides=hf_overrides or {},
)
scheduler_config = SchedulerConfig(
max_num_seqs=max_num_seqs,
max_num_batched_tokens=max_num_batched_tokens,
max_model_len=max_model_len,
enable_chunked_prefill=enable_chunked_prefill,
is_encoder_decoder=model_config.is_encoder_decoder,
)
# Cache config, optionally force APC
cache_config = CacheConfig(
block_size=block_size,
gpu_memory_utilization=0.9,
swap_space=0,
cache_dtype=cache_dtype,
enable_prefix_caching=True,
)
kv_transfer_config = KVTransferConfig(
kv_connector="NixlConnector",
kv_role="kv_both",
enable_permute_local_kv=enable_permute_local_kv,
kv_connector_extra_config=kv_connector_extra_config or {},
)
return VllmConfig(
scheduler_config=scheduler_config,
model_config=model_config,
cache_config=cache_config,
kv_transfer_config=kv_transfer_config,
device_config=DeviceConfig("cpu"),
)
def create_scheduler(
vllm_config: VllmConfig,
num_blocks: int = 10000,
) -> Scheduler:
"""Initialize Scheduler For Testing."""
block_size = vllm_config.cache_config.block_size
kv_cache_config = KVCacheConfig(
num_blocks=num_blocks, # A large number of blocks to hold all requests
kv_cache_tensors=[],
kv_cache_groups=[
KVCacheGroupSpec(
["layer"], FullAttentionSpec(block_size, 1, 1, torch.float32, False)
)
],
)
vllm_config.cache_config.num_gpu_blocks = num_blocks
return Scheduler(
vllm_config=vllm_config,
kv_cache_config=kv_cache_config,
log_stats=True,
structured_output_manager=StructuredOutputManager(vllm_config),
block_size=block_size,
)
_request_count = count(1)
_none_hash_initialized = False
def create_request(
request_id: int | None = None,
num_tokens: int = 10,
common_prefix_len=0,
max_tokens: int = 16,
do_remote_decode: bool = False,
do_remote_prefill: bool = False,
num_remote_blocks: int = 3,
block_size: int = 16,
hash_fn: Callable = sha256,
) -> Request:
"""Make dummy request for testing."""
assert num_tokens >= common_prefix_len >= 0
if request_id is None:
request_id = next(_request_count)
global _none_hash_initialized
if not _none_hash_initialized:
init_none_hash(hash_fn)
_none_hash_initialized = True
kv_transfer_params: dict[str, Any] | None = None
if do_remote_decode:
assert not do_remote_prefill
kv_transfer_params = dict(do_remote_prefill=False, do_remote_decode=True)
elif do_remote_prefill:
kv_transfer_params = dict(
do_remote_prefill=True,
do_remote_decode=False,
remote_engine_id="my-engine-id",
remote_request_id=f"prefill-{request_id}",
remote_block_ids=list(range(num_remote_blocks)),
remote_host="my-host",
remote_port=1234,
)
max_tokens = 1 if do_remote_decode else max_tokens
sampling_params = SamplingParams(max_tokens=max_tokens)
common_prefix = [1] * common_prefix_len if common_prefix_len > 0 else []
suffix = [i * request_id for i in range(num_tokens - common_prefix_len)]
prompt_token_ids = common_prefix + suffix
req = Request(
request_id=f"id-{request_id}",
prompt_token_ids=prompt_token_ids,
sampling_params=sampling_params,
pooling_params=None,
mm_features=None,
eos_token_id=EOS_TOKEN_ID,
block_hasher=get_request_block_hasher(block_size, hash_fn),
)
req.kv_transfer_params = kv_transfer_params
return req
def create_model_runner_output(
reqs: list[Request],
finished_sending: set[str] | None = None,
finished_recving: set[str] | None = None,
invalid_block_ids: set[int] | None = None,
use_eos: bool = False,
token_id: int = 0,
) -> ModelRunnerOutput:
"""Make dummy model runner output for testing."""
# Make request data.
req_ids = [req.request_id for req in reqs]
req_id_to_index = {req_id: idx for idx, req_id in enumerate(req_ids)}
# Make sampled tokens.
sampled_token = EOS_TOKEN_ID if use_eos else token_id
sampled_token_ids = [[sampled_token] for _ in req_ids]
kv_connector_output = (
None
if (
finished_sending is None
and finished_recving is None
and invalid_block_ids is None
)
else KVConnectorOutput(
finished_sending=finished_sending,
finished_recving=finished_recving,
invalid_block_ids=invalid_block_ids or set(),
)
)
# Make output data structure.
return ModelRunnerOutput(
req_ids=req_ids,
req_id_to_index=req_id_to_index,
sampled_token_ids=sampled_token_ids,
logprobs=None,
prompt_logprobs_dict={},
pooler_output=None,
kv_connector_output=kv_connector_output,
)
class TestExampleConnector(ExampleConnector):
def __init__(self, config: VllmConfig, role, kv_cache_config):
self.name = config.kv_transfer_config.kv_connector_extra_config["name"]
self._connector = ExampleConnector(config, role)
self.call_record: dict[str, int] = defaultdict(int)
# Use a unique temp file per connector
self._event_file = (
tempfile.gettempdir()
+ f"/connector_{self.name}-{self.role.name}_events.log"
)
# Start with an empty file
with open(self._event_file, "w") as _:
pass
def __getattribute__(self, name):
if name in (
"_connector",
"call_record",
"name",
"_event_file",
"__class__",
"__dict__",
"__getattribute__",
"__init__",
): # avoid recursion
return object.__getattribute__(self, name)
if not hasattr(self._connector, name):
return object.__getattribute__(self, name)
attr = getattr(self._connector, name)
# Intercept calls to the connector interface and write an event
# for each one to a file, which can be read back in the main test proc.
if callable(attr):
def wrapper(*args, **kwargs):
self.call_record[name] += 1
# Include args that we're interested in
to_log = [name]
for arg in args:
if isinstance(arg, int):
to_log.append(str(arg))
elif isinstance(arg, KVCacheBlocks):
to_log.append(f"num_blocks={[len(b) for b in arg.blocks]}")
# Log the event as a line to the file
try:
with open(self._event_file, "a") as f:
f.write(" ".join(to_log) + "\n")
except Exception as e:
print(f"[ERROR] Could not log event {name} for {self.name}: {e}")
return attr(*args, **kwargs)
return wrapper
return attr
@dataclass(frozen=True)
class MockKVConfig:
matched_tokens: int = 0
is_async: bool = False
class MockKVConnectorMetadata(KVConnectorMetadata):
def __init__(self):
# Scheduler tests check metadata.requests
self.requests: list = []
class MockKVConnector(KVConnectorBase_V1):
"""Mock KV connector for scheduler tests, supporting both sync and async mode."""
def __init__(
self,
vllm_config: VllmConfig,
role: KVConnectorRole,
kv_cache_config: KVCacheConfig | None = None,
):
super().__init__(vllm_config, role, kv_cache_config)
extra_config = self._kv_transfer_config.kv_connector_extra_config
self.config = MockKVConfig(
matched_tokens=extra_config["matched_tokens"],
is_async=extra_config["is_async"],
)
def get_num_new_matched_tokens(
self,
request: Request,
num_computed_tokens: int,
) -> tuple[int | None, bool]:
return (self.config.matched_tokens, self.config.is_async)
def update_state_after_alloc(
self,
request: Request,
blocks: KVCacheBlocks,
num_external_tokens: int,
):
pass
def build_connector_meta(
self, scheduler_output: SchedulerOutput
) -> KVConnectorMetadata:
metadata = MockKVConnectorMetadata()
cached_reqs = scheduler_output.scheduled_cached_reqs
for req_id in chain(
(req.req_id for req in scheduler_output.scheduled_new_reqs),
(
req_id
for req_id in cached_reqs.req_ids
if req_id in cached_reqs.resumed_req_ids
),
):
metadata.requests.append({"req_id": req_id})
return metadata
def start_load_kv(self, kv_caches, finished_req_ids):
pass
def wait_for_layer_load(self, layer_name):
pass
def save_kv_layer(self, layer_name, kv_layer, attn_metadata, **kwargs):
pass
def wait_for_save(self):
pass
KVConnectorFactory.register_connector(
"TestExampleConnector", __name__, TestExampleConnector.__name__
)
KVConnectorFactory.register_connector(
"MockKVConnector", __name__, MockKVConnector.__name__
)