Sync from v0.13
This commit is contained in:
171
tests/v1/ec_connector/integration/README.md
Normal file
@@ -0,0 +1,171 @@
# EPD Correctness Test

This test verifies that EPD (Encoder-Prefill-Decode) disaggregation produces identical outputs to a baseline single instance.

## What It Tests

- **Baseline**: Single vLLM instance serving a multimodal model
- **EPD (1E+1PD)**: 1 Encoder + 1 Prefill-Decode instance
- **Baseline (1P+1D)**: 1 Prefill + 1 Decode instance
- **EPD (1E+1P+1D)**: 1 Encoder + 1 Prefill + 1 Decode instance

The test ensures that disaggregated encoding produces **identical** outputs to the baseline.

Note that a PD disaggregation setup may currently give slightly different results from a single instance, so the 1P+1D result serves as the baseline for 1E+1P+1D.

Please refer to [Disaggregated Encoder Feature](../../../docs/features/disagg_encoder.md) for a detailed explanation of the EPD feature.

## Files
- `run_epd_correctness_test.sh` - Main test script (starts all instances and runs the tests)
- `test_epd_correctness.py` - Python test script (compares outputs)

## Usage

### Multimodal Prompts (Default)

```bash
cd vllm
./tests/v1/ec_connector/integration/run_epd_correctness_test.sh
```

This runs the test with actual multimodal (image) prompts.

### Text-Only Prompts

```bash
cd vllm
USE_MM_PROMPTS=0 ./tests/v1/ec_connector/integration/run_epd_correctness_test.sh
```

This runs a quick test with text-only prompts to verify that the setup works.

### Custom Configuration

```bash
# Use specific GPUs
GPU_E=0 GPU_PD=1 GPU_P=1 GPU_D=2 bash ./tests/v1/ec_connector/integration/run_epd_correctness_test.sh

# Use a specific endpoint port
ENDPOINT_PORT=10001 bash ./tests/v1/ec_connector/integration/run_epd_correctness_test.sh

# Use a specific model
MODEL="Qwen/Qwen2.5-VL-3B-Instruct" bash ./tests/v1/ec_connector/integration/run_epd_correctness_test.sh

# Use a specific storage path
EC_SHARED_STORAGE_PATH="/tmp/my_ec_cache" bash ./tests/v1/ec_connector/integration/run_epd_correctness_test.sh
```
## How It Works

### Step 1: Baseline

1. Start a single vLLM instance on one GPU
2. Run the test prompts (multimodal or text-only)
3. Save the outputs to `.vllm_epd_baseline.txt`
4. Shut down the instance

### Step 2: EPD (1E + 1PD)

1. Clear the encoder cache storage
2. Start the instances and the proxy
3. Run the same test prompts
4. Assert that the outputs match the baseline exactly
5. Shut down the instances

### Step 3: PD Baseline (1P + 1D)

1. Clear the encoder cache storage
2. Start the prefill and decode instances and the proxy
3. Run the same test prompts
4. Save the outputs as the PD baseline
5. Shut down the instances

### Step 4: EPD (1E + 1P + 1D)

1. Clear the encoder cache storage
2. Start the instances and the proxy
3. Run the same test prompts
4. Assert that the outputs match the PD baseline exactly
5. Shut down the instances
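
The save-then-compare pattern in the steps above amounts to persisting a mapping from prompt description to generated output, then asserting exact equality on the second run. A minimal stdlib sketch (the function and file names here are illustrative, not the test's actual API):

```python
import json
import os
import tempfile


def save_baseline(path, outputs):
    # Baseline mode: persist outputs keyed by prompt description
    with open(path, "w") as f:
        json.dump(outputs, f, indent=4)


def compare_to_baseline(path, outputs):
    # Disagg mode: every output must match the baseline exactly
    with open(path) as f:
        baseline = json.load(f)
    assert baseline.keys() == outputs.keys(), "prompt sets differ"
    mismatches = [k for k in baseline if baseline[k] != outputs[k]]
    assert not mismatches, f"outputs differ for: {mismatches}"


path = os.path.join(tempfile.mkdtemp(), "baseline.json")
save_baseline(path, {"Single image query": "A stop sign."})
compare_to_baseline(path, {"Single image query": "A stop sign."})  # passes
```

Because the comparison is byte-for-byte string equality, any nondeterminism in generation immediately fails the test.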

## Test Scenarios

### Multimodal Prompts (`--use_mm_prompts`)

Tests encoder cache transfer:

- Single image query
- Multiple images in one request
- Mixed image and text
- Image with detailed questions

### Text-Only Prompts (without `--use_mm_prompts`)

Quick sanity check:

- Simple text queries
- Text-only explanations
- Verifies that proxy routing works

## Expected Behavior

### ✅ Test Passes When

- All disagg outputs match the baseline outputs exactly
- No errors occur during instance startup
- The encoder cache is properly saved and loaded
- The proxy correctly routes requests

### ❌ Test Fails When

- Outputs differ between baseline and disagg
- Server startup fails
- The encoder cache is not found (should fall back to local execution)
- Proxy routing errors occur

## Notes

- The test uses deterministic generation (`temperature=0.0`, `seed=42`)
- The encoder cache should enable exact output reproduction
- The test cleans up all instances and cache files after completion
- Safe to run multiple times (idempotent)
- The PD disagg part is set up with NixlConnector; see `examples/online_serving/disaggregated_encoder/README.md` for details about EPD
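
The determinism noted above comes from sending identical sampling parameters on every run. The request body the test sends looks roughly like the following (the prompt shown is one of the test's text-only prompts):

```python
import json

# Greedy decoding (temperature=0.0) plus a fixed seed makes generation
# reproducible, which is what permits byte-for-byte output comparison.
request_body = {
    "model": "Qwen/Qwen2.5-VL-3B-Instruct",
    "messages": [{"role": "user", "content": "What is the capital of France?"}],
    "max_tokens": 256,
    "temperature": 0.0,
    "seed": 42,
}
payload = json.dumps(request_body)
```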

## Requirements

- Multiple GPUs (3 for 1E+1P+1D, 2 for 1E+1PD, 1 for baseline)
  - 1E+1P+1D can also run on 2 GPUs by assigning E and P to the same GPU
- A multimodal model (e.g., Qwen2.5-VL-3B-Instruct)
- Internet access (for fetching the vLLM test images)

## Debugging

### Check Logs

Logs and the baseline output are saved in `/tmp/` by default.
These locations can be customized via environment variables.

### Check Encoder Cache

```bash
# Verify cache files are created
ls -la $EC_SHARED_STORAGE_PATH/

# You should see directories named after mm_hash values,
# each containing encoder_cache.safetensors
```
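
The same layout check can be scripted: the cache directory should contain one subdirectory per mm_hash, each holding an `encoder_cache.safetensors` file. A stdlib sketch against a synthetic layout (the directory names below are illustrative):

```python
import tempfile
from pathlib import Path


def list_cached_hashes(storage_path):
    # Return mm_hash directory names that contain an encoder cache file
    root = Path(storage_path)
    return sorted(
        d.name
        for d in root.iterdir()
        if d.is_dir() and (d / "encoder_cache.safetensors").exists()
    )


# Build a synthetic cache layout to demonstrate the check
root = Path(tempfile.mkdtemp())
(root / "abc123").mkdir()
(root / "abc123" / "encoder_cache.safetensors").touch()
(root / "incomplete").mkdir()  # no cache file -> should be skipped
```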

### Manual Testing

Run individual components:

```bash
# Baseline only
python test_epd_correctness.py \
    --service_url http://localhost:8000 \
    --model_name Qwen/Qwen2.5-VL-3B-Instruct \
    --mode baseline \
    --baseline_file test_output.txt \
    --use_mm_prompts

# Disagg only (requires the baseline output file!)
python test_epd_correctness.py \
    --service_url http://localhost:8000 \
    --model_name Qwen/Qwen2.5-VL-3B-Instruct \
    --mode disagg \
    --baseline_file test_output.txt \
    --use_mm_prompts
```

BIN
tests/v1/ec_connector/integration/hato.jpg
Normal file
Binary file not shown.
After Width: | Height: | Size: 821 KiB |
476
tests/v1/ec_connector/integration/run_epd_correctness_test.sh
Normal file
@@ -0,0 +1,476 @@
#!/bin/bash
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
#
# EPD (Encoder-Prefill-Decode) Correctness Test
#
# This script tests that EPD disaggregation produces the same outputs as baseline.
# It runs:
# 1. Baseline: Single vLLM instance
# 2. EPD: 1E + 1PD setup
# 3. Baseline for (E + P + D): 1P + 1D vLLM instances disagg
# 4. EPD: 1E + 1P + 1D setup

# Uncomment for verbose, fail-fast debugging
# set -xe

# Find the git repository root directory
GIT_ROOT=$(git rev-parse --show-toplevel)

# Model to test
MODEL="${MODEL:-Qwen/Qwen2.5-VL-3B-Instruct}"

# Set to 1 to use multimodal prompts; otherwise text-only prompts are used
USE_MM_PROMPTS="${USE_MM_PROMPTS:-1}"
MM_FLAG=""
if [ "$USE_MM_PROMPTS" = "1" ]; then
    MM_FLAG="--use_mm_prompts"
fi

# GPU configuration
GPU_E="${GPU_E:-0}"
GPU_P="${GPU_P:-1}"
GPU_D="${GPU_D:-2}"
GPU_SINGLE="${GPU_SINGLE:-$GPU_P}"
GPU_PD="${GPU_PD:-$GPU_P}"

# Ports
ENCODE_PORT="${ENCODE_PORT:-19534}"
PREFILL_PORT="${PREFILL_PORT:-19535}"
DECODE_PORT="${DECODE_PORT:-19536}"
PREFILL_DECODE_PORT="${PREFILL_DECODE_PORT:-19537}"
ENDPOINT_PORT="${ENDPOINT_PORT:-10001}"

# Storage path for the encoder cache
EC_SHARED_STORAGE_PATH="${EC_SHARED_STORAGE_PATH:-/tmp/ec_cache_test}"
TIMEOUT_SECONDS="${TIMEOUT_SECONDS:-600}"

# Output files for baseline comparison and logs
LOG_PATH="${LOG_PATH:-/tmp}"
BASELINE_FILE="${BASELINE_FILE:-/tmp/vllm_baseline.txt}"
BASELINE_PD_FILE="${BASELINE_PD_FILE:-/tmp/vllm_epd_baseline.txt}"

mkdir -p "$LOG_PATH"

# Kill all background jobs on interrupt, termination, or exit
trap 'kill $(jobs -pr)' SIGINT SIGTERM EXIT
# Wait for a server to be ready
wait_for_server() {
    local port=$1
    timeout "$TIMEOUT_SECONDS" bash -c "
        until curl -s localhost:${port}/v1/chat/completions > /dev/null; do
            sleep 1
        done" && return 0 || return 1
}

# Cleanup function
cleanup_instances() {
    echo "Cleaning up any running vLLM instances..."
    pkill -f "vllm serve" || true
    pkill -f "disagg_epd_proxy.py" || true
    sleep 2
}
# Function to run the baseline (single instance)
run_baseline() {
    echo "================================"
    echo "Running BASELINE (single instance)"
    echo "================================"

    cleanup_instances
    rm -rf "$EC_SHARED_STORAGE_PATH"

    local PORT=$ENDPOINT_PORT

    # Start the baseline instance
    echo "Starting baseline instance on GPU $GPU_SINGLE, port $PORT"
    CUDA_VISIBLE_DEVICES="$GPU_SINGLE" vllm serve "$MODEL" \
        --port $PORT \
        --enforce-eager \
        --gpu-memory-utilization 0.7 \
        --max-num-seqs 128 \
        --allowed-local-media-path ${GIT_ROOT}/tests/v1/ec_connector/integration \
        > $LOG_PATH/baseline.log 2>&1 &

    local BASELINE_PID=$!

    # Wait for the baseline instance to start
    echo "Waiting for baseline instance to start..."
    wait_for_server $PORT

    curl http://127.0.0.1:$PORT/v1/models
    echo ""

    # Run the test in baseline mode
    echo "Running baseline..."

    python "${GIT_ROOT}/tests/v1/ec_connector/integration/test_epd_correctness.py" \
        --service_url "http://localhost:$PORT" \
        --model_name "$MODEL" \
        --mode baseline \
        --baseline_file "$BASELINE_FILE" \
        $MM_FLAG

    # Cleanup baseline
    echo "Stopping baseline instance..."
    kill $BASELINE_PID 2>/dev/null || true
    sleep 2
    cleanup_instances
}
# Function to run EPD with 1E + 1PD
run_epd_1e_1pd() {
    echo "================================"
    echo "Running EPD (1E + 1PD)"
    echo "================================"

    cleanup_instances
    rm -rf "$EC_SHARED_STORAGE_PATH"
    mkdir -p "$EC_SHARED_STORAGE_PATH"

    local ENCODE_PORT=$ENCODE_PORT
    local PREFILL_DECODE_PORT=$PREFILL_DECODE_PORT
    local PROXY_PORT=$ENDPOINT_PORT

    declare -a PIDS=()

    # Start the encoder instance
    echo "Starting encoder instance on GPU $GPU_E, port $ENCODE_PORT"
    CUDA_VISIBLE_DEVICES="$GPU_E" vllm serve "$MODEL" \
        --port $ENCODE_PORT \
        --enforce-eager \
        --gpu-memory-utilization 0.01 \
        --enable-request-id-headers \
        --no-enable-prefix-caching \
        --max-num-batched-tokens 114688 \
        --max-num-seqs 128 \
        --allowed-local-media-path ${GIT_ROOT}/tests/v1/ec_connector/integration \
        --ec-transfer-config '{
            "ec_connector": "ECExampleConnector",
            "ec_role": "ec_producer",
            "ec_connector_extra_config": {
                "shared_storage_path": "'"$EC_SHARED_STORAGE_PATH"'"
            }
        }' \
        > $LOG_PATH/1e1pd_encoder.log 2>&1 &
    PIDS+=($!)

    # Start the prefill+decode instance
    echo "Starting PD instance on GPU $GPU_PD, port $PREFILL_DECODE_PORT"
    CUDA_VISIBLE_DEVICES="$GPU_PD" vllm serve "$MODEL" \
        --port $PREFILL_DECODE_PORT \
        --enforce-eager \
        --gpu-memory-utilization 0.7 \
        --enable-request-id-headers \
        --max-num-seqs 128 \
        --allowed-local-media-path ${GIT_ROOT}/tests/v1/ec_connector/integration \
        --ec-transfer-config '{
            "ec_connector": "ECExampleConnector",
            "ec_role": "ec_consumer",
            "ec_connector_extra_config": {
                "shared_storage_path": "'"$EC_SHARED_STORAGE_PATH"'"
            }
        }' \
        > $LOG_PATH/1e1pd_pd.log 2>&1 &
    PIDS+=($!)

    # Wait for the instances to start
    echo "Waiting for encoder instance..."
    wait_for_server $ENCODE_PORT
    echo "Waiting for PD instance..."
    wait_for_server $PREFILL_DECODE_PORT

    # Start the proxy
    echo "Starting EPD proxy on port $PROXY_PORT"
    python "${GIT_ROOT}/examples/online_serving/disaggregated_encoder/disagg_epd_proxy.py" \
        --host "0.0.0.0" \
        --port $PROXY_PORT \
        --encode-servers-urls "http://localhost:$ENCODE_PORT" \
        --prefill-servers-urls "disable" \
        --decode-servers-urls "http://localhost:$PREFILL_DECODE_PORT" \
        > $LOG_PATH/1e1pd_proxy.log 2>&1 &
    PIDS+=($!)

    # Wait for the proxy
    echo "Waiting for proxy..."
    wait_for_server $PROXY_PORT

    curl http://127.0.0.1:$PROXY_PORT/v1/models
    curl http://127.0.0.1:$PROXY_PORT/health
    echo ""

    echo "All EPD (1E+1PD) services are up!"

    # Run the test in disagg mode
    echo "Running EPD (1E+1PD) correctness test..."

    python "${GIT_ROOT}/tests/v1/ec_connector/integration/test_epd_correctness.py" \
        --service_url "http://localhost:$PROXY_PORT" \
        --model_name "$MODEL" \
        --mode disagg \
        --baseline_file "$BASELINE_FILE" \
        $MM_FLAG

    # Cleanup
    echo "✓✓ 1E+1PD Correctness Test finished"
    echo "Stopping EPD (1E+1PD) instances..."
    for pid in "${PIDS[@]}"; do
        kill $pid 2>/dev/null || true
    done
    sleep 2
    cleanup_instances
}
# Function to run the baseline for 1E + 1P + 1D (PD disagg)
run_baseline_1p_1d() {
    echo "================================"
    echo "Running PD BASELINE (1P + 1D)"
    echo "================================"

    cleanup_instances
    rm -rf "$EC_SHARED_STORAGE_PATH"
    mkdir -p "$EC_SHARED_STORAGE_PATH"

    local PREFILL_PORT=$PREFILL_PORT
    local DECODE_PORT=$DECODE_PORT
    local PROXY_PORT=$ENDPOINT_PORT

    declare -a PIDS=()

    # Start the prefill instance
    echo "Starting prefill instance on GPU $GPU_P, port $PREFILL_PORT"
    CUDA_VISIBLE_DEVICES="$GPU_P" \
    VLLM_NIXL_SIDE_CHANNEL_PORT=5559 \
    vllm serve "$MODEL" \
        --port $PREFILL_PORT \
        --enforce-eager \
        --gpu-memory-utilization 0.7 \
        --enable-request-id-headers \
        --max-num-seqs 128 \
        --allowed-local-media-path ${GIT_ROOT}/tests/v1/ec_connector/integration \
        --kv-transfer-config '{
            "kv_connector": "NixlConnector",
            "kv_role": "kv_producer"
        }' \
        > $LOG_PATH/1p1d_prefill.log 2>&1 &
    PIDS+=($!)

    # Start the decode instance
    echo "Starting decode instance on GPU $GPU_D, port $DECODE_PORT"
    CUDA_VISIBLE_DEVICES="$GPU_D" \
    VLLM_NIXL_SIDE_CHANNEL_PORT=6000 \
    vllm serve "$MODEL" \
        --port $DECODE_PORT \
        --enforce-eager \
        --gpu-memory-utilization 0.7 \
        --enable-request-id-headers \
        --max-num-seqs 128 \
        --allowed-local-media-path ${GIT_ROOT}/tests/v1/ec_connector/integration \
        --kv-transfer-config '{
            "kv_connector": "NixlConnector",
            "kv_role": "kv_consumer"
        }' \
        > $LOG_PATH/1p1d_decode.log 2>&1 &
    PIDS+=($!)

    # Wait for the instances to start
    echo "Waiting for prefill instance..."
    wait_for_server $PREFILL_PORT
    echo "Waiting for decode instance..."
    wait_for_server $DECODE_PORT

    # Start the NIXL toy proxy
    echo "Starting PD proxy on port $PROXY_PORT"
    python "${GIT_ROOT}/tests/v1/kv_connector/nixl_integration/toy_proxy_server.py" \
        --host "0.0.0.0" \
        --port $PROXY_PORT \
        --prefiller-ports $PREFILL_PORT \
        --decoder-ports $DECODE_PORT \
        > $LOG_PATH/1p1d_proxy.log 2>&1 &
    PIDS+=($!)

    # Wait for the proxy
    echo "Waiting for proxy..."
    wait_for_server $PROXY_PORT

    curl http://127.0.0.1:$PROXY_PORT/healthcheck
    echo ""

    echo "All PD (1P+1D) services are up!"

    # Run the test in baseline mode
    echo "Running PD disagg baseline..."

    python "${GIT_ROOT}/tests/v1/ec_connector/integration/test_epd_correctness.py" \
        --service_url "http://localhost:$PROXY_PORT" \
        --model_name "$MODEL" \
        --mode baseline_pd \
        --baseline_file "$BASELINE_PD_FILE" \
        $MM_FLAG

    # Cleanup
    echo "Stopping PD (1P+1D) instances..."
    for pid in "${PIDS[@]}"; do
        kill $pid 2>/dev/null || true
    done
    sleep 2
    cleanup_instances
}
# Function to run EPD with 1E + 1P + 1D
run_epd_1e_1p_1d() {
    echo "================================"
    echo "Running EPD (1E + 1P + 1D)"
    echo "================================"

    cleanup_instances
    rm -rf "$EC_SHARED_STORAGE_PATH"
    mkdir -p "$EC_SHARED_STORAGE_PATH"

    local ENCODE_PORT=$ENCODE_PORT
    local PREFILL_PORT=$PREFILL_PORT
    local DECODE_PORT=$DECODE_PORT
    local PROXY_PORT=$ENDPOINT_PORT

    declare -a PIDS=()

    # Start the encoder instance
    echo "Starting encoder instance on GPU $GPU_E, port $ENCODE_PORT"
    CUDA_VISIBLE_DEVICES="$GPU_E" vllm serve "$MODEL" \
        --port $ENCODE_PORT \
        --enforce-eager \
        --gpu-memory-utilization 0.01 \
        --enable-request-id-headers \
        --no-enable-prefix-caching \
        --max-num-batched-tokens 114688 \
        --max-num-seqs 128 \
        --allowed-local-media-path ${GIT_ROOT}/tests/v1/ec_connector/integration \
        --ec-transfer-config '{
            "ec_connector": "ECExampleConnector",
            "ec_role": "ec_producer",
            "ec_connector_extra_config": {
                "shared_storage_path": "'"$EC_SHARED_STORAGE_PATH"'"
            }
        }' \
        > $LOG_PATH/1e1p1d_encoder.log 2>&1 &
    PIDS+=($!)

    # Start the prefill instance
    echo "Starting prefill instance on GPU $GPU_P, port $PREFILL_PORT"
    CUDA_VISIBLE_DEVICES="$GPU_P" \
    VLLM_NIXL_SIDE_CHANNEL_PORT=5559 \
    vllm serve "$MODEL" \
        --port $PREFILL_PORT \
        --enforce-eager \
        --gpu-memory-utilization 0.7 \
        --enable-request-id-headers \
        --max-num-seqs 128 \
        --allowed-local-media-path ${GIT_ROOT}/tests/v1/ec_connector/integration \
        --ec-transfer-config '{
            "ec_connector": "ECExampleConnector",
            "ec_role": "ec_consumer",
            "ec_connector_extra_config": {
                "shared_storage_path": "'"$EC_SHARED_STORAGE_PATH"'"
            }
        }' \
        --kv-transfer-config '{
            "kv_connector": "NixlConnector",
            "kv_role": "kv_producer"
        }' \
        > $LOG_PATH/1e1p1d_prefill.log 2>&1 &
    PIDS+=($!)

    # Start the decode instance
    echo "Starting decode instance on GPU $GPU_D, port $DECODE_PORT"
    CUDA_VISIBLE_DEVICES="$GPU_D" \
    VLLM_NIXL_SIDE_CHANNEL_PORT=6000 \
    vllm serve "$MODEL" \
        --port $DECODE_PORT \
        --enforce-eager \
        --gpu-memory-utilization 0.7 \
        --enable-request-id-headers \
        --max-num-seqs 128 \
        --allowed-local-media-path ${GIT_ROOT}/tests/v1/ec_connector/integration \
        --kv-transfer-config '{
            "kv_connector": "NixlConnector",
            "kv_role": "kv_consumer"
        }' \
        > $LOG_PATH/1e1p1d_decode.log 2>&1 &
    PIDS+=($!)

    # Wait for the instances to start
    echo "Waiting for encoder instance..."
    wait_for_server $ENCODE_PORT
    echo "Waiting for prefill instance..."
    wait_for_server $PREFILL_PORT
    echo "Waiting for decode instance..."
    wait_for_server $DECODE_PORT

    # Start the proxy
    echo "Starting EPD proxy on port $PROXY_PORT"
    python "${GIT_ROOT}/examples/online_serving/disaggregated_encoder/disagg_epd_proxy.py" \
        --host "0.0.0.0" \
        --port $PROXY_PORT \
        --encode-servers-urls "http://localhost:$ENCODE_PORT" \
        --prefill-servers-urls "http://localhost:$PREFILL_PORT" \
        --decode-servers-urls "http://localhost:$DECODE_PORT" \
        > $LOG_PATH/1e1p1d_proxy.log 2>&1 &
    PIDS+=($!)

    # Wait for the proxy
    echo "Waiting for proxy..."
    wait_for_server $PROXY_PORT

    curl http://127.0.0.1:$PROXY_PORT/v1/models
    curl http://127.0.0.1:$PROXY_PORT/health
    echo ""

    echo "All EPD (1E+1P+1D) services are up!"

    # Run the test in disagg mode
    echo "Running EPD (1E+1P+1D) correctness test..."

    python "${GIT_ROOT}/tests/v1/ec_connector/integration/test_epd_correctness.py" \
        --service_url "http://localhost:$PROXY_PORT" \
        --model_name "$MODEL" \
        --mode disagg \
        --baseline_file "$BASELINE_PD_FILE" \
        $MM_FLAG

    # Cleanup
    echo "✓✓ 1E+1P+1D Correctness Test finished"
    echo "Stopping EPD (1E+1P+1D) instances..."
    for pid in "${PIDS[@]}"; do
        kill $pid 2>/dev/null || true
    done
    sleep 2
    cleanup_instances
}
# Main execution
echo "================================"
echo "EPD Correctness Test Suite"
echo "Model: $MODEL"
echo "================================"

# Step 1: Run the baseline
run_baseline

# Step 2: Test 1E + 1PD
run_epd_1e_1pd

# Step 3: Run the baseline 1P + 1D
run_baseline_1p_1d

# Step 4: Test 1E + 1P + 1D
run_epd_1e_1p_1d

# Clean up the output files
rm -f "$BASELINE_FILE"
rm -f "$BASELINE_PD_FILE"

echo "================================"
echo "✓✓ All EPD correctness tests finished!"
echo "================================"

304
tests/v1/ec_connector/integration/test_epd_correctness.py
Normal file
@@ -0,0 +1,304 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
EPD Correctness Test

Tests that EPD (Encoder-Prefill-Decode) disaggregation produces the same
outputs as a baseline single instance.

Usage:
    # Baseline mode (saves outputs):
    python test_epd_correctness.py \
        --service_url http://localhost:8000 \
        --model_name Qwen/Qwen2.5-VL-3B-Instruct \
        --mode baseline \
        --baseline_file .vllm_epd_baseline.txt

    # Disagg mode (compares outputs):
    python test_epd_correctness.py \
        --service_url http://localhost:8000 \
        --model_name Qwen/Qwen2.5-VL-3B-Instruct \
        --mode disagg \
        --baseline_file .vllm_epd_baseline.txt
"""

import argparse
import json
import os
import time

import openai
import requests

from vllm.assets.image import ImageAsset
from vllm.multimodal.utils import encode_image_base64

MAX_OUTPUT_LEN = 256

# Sample prompts with multimodal content
image_1 = ImageAsset("stop_sign").pil_image.resize((1280, 720))
image_2 = ImageAsset("cherry_blossom").pil_image.resize((1280, 720))

image_local_path = f"{os.path.dirname(os.path.abspath(__file__))}/hato.jpg"

SAMPLE_PROMPTS_MM: list[dict] = [
    {
        "messages": [
            {
                "role": "user",
                "content": [
                    {
                        "type": "image_url",
                        "image_url": {
                            "url": f"data:image;base64,{encode_image_base64(image_1)}"
                        },
                    },
                    {"type": "text", "text": "What's in this image?"},
                ],
            }
        ],
        "description": "Single image query",
    },
    {
        "messages": [
            {
                "role": "user",
                "content": [
                    {
                        "type": "image_url",
                        "image_url": {
                            "url": f"data:image;base64,{encode_image_base64(image_2)}"
                        },
                    },
                    {
                        "type": "image_url",
                        "image_url": {"url": f"file://{image_local_path}"},
                    },
                    {"type": "text", "text": "Describe these 2 images in detail."},
                ],
            }
        ],
        "description": "2 images with detailed query",
    },
]

# Text-only prompts for mixed testing
SAMPLE_PROMPTS_TEXT: list[dict] = [
    {
        "messages": [{"role": "user", "content": "What is the capital of France?"}],
        "description": "Simple text-only query",
    },
    {
        "messages": [
            {"role": "user", "content": "Explain quantum computing in simple terms."}
        ],
        "description": "Text-only explanation request",
    },
]

def check_vllm_server(url: str, timeout=5, retries=10) -> bool:
    """Check if the vLLM server is ready.

    Args:
        url: The URL to check (usually a /health or /healthcheck endpoint)
        timeout: Timeout in seconds for each request
        retries: Number of retries if the server is not ready

    Returns:
        True if the server is ready, False otherwise
    """
    for attempt in range(retries):
        try:
            response = requests.get(url, timeout=timeout)
            if response.status_code == 200:
                print(f"Server is ready at {url}")
                return True
            else:
                print(
                    f"Attempt {attempt + 1}/{retries}: Server returned "
                    f"status code {response.status_code}"
                )
        except requests.exceptions.RequestException as e:
            print(f"Attempt {attempt + 1}/{retries}: Error connecting: {e}")
        time.sleep(2)  # Wait before retrying
    return False

def run_chat_completion(
    base_url: str,
    model_name: str,
    messages: list,
    max_tokens: int = MAX_OUTPUT_LEN,
) -> str:
    """Run a chat completion request.

    Args:
        base_url: Base URL of the vLLM server
        model_name: Name of the model
        messages: Messages for the chat completion
        max_tokens: Maximum number of tokens to generate

    Returns:
        Generated text content
    """
    client = openai.OpenAI(api_key="EMPTY", base_url=base_url)

    completion = client.chat.completions.create(
        model=model_name,
        messages=messages,
        max_tokens=max_tokens,
        temperature=0.0,
        seed=42,
    )

    return completion.choices[0].message.content

def main():
    """Main test function."""
    parser = argparse.ArgumentParser(
        description="EPD correctness test - compare disagg vs baseline"
    )

    parser.add_argument(
        "--service_url",
        type=str,
        required=True,
        help="The vLLM service URL (e.g., http://localhost:8000)",
    )

    parser.add_argument(
        "--model_name",
        type=str,
        required=True,
        help="Model name",
    )

    parser.add_argument(
        "--mode",
        type=str,
        default="baseline",
        choices=["baseline", "baseline_pd", "disagg"],
        help="Mode: baseline/baseline_pd (saves outputs) or disagg (compares outputs)",
    )

    parser.add_argument(
        "--baseline_file",
        type=str,
        default=".vllm_epd_baseline.txt",
        help="File to save/load baseline outputs",
    )

    parser.add_argument(
        "--use_mm_prompts",
        action="store_true",
        help="Use multimodal prompts (default: use text-only for quick testing)",
    )

    args = parser.parse_args()

    print(f"Service URL: {args.service_url}")
    print(f"Model: {args.model_name}")
    print(f"Mode: {args.mode}")
    print(f"Output file: {args.baseline_file}")
    print(f"Use MM prompts: {args.use_mm_prompts}")

    # Determine the health check endpoint
    if args.mode == "baseline":
        health_check_url = f"{args.service_url}/health"
    elif args.mode == "baseline_pd":
        # The NIXL toy proxy uses /healthcheck
        health_check_url = f"{args.service_url}/healthcheck"
    else:
        # The disagg EPD proxy uses /health
        health_check_url = f"{args.service_url}/health"
        if not os.path.exists(args.baseline_file):
            raise ValueError(
                f"In disagg mode, the output file {args.baseline_file} from "
                "baseline does not exist. Run baseline mode first."
            )

    # Check if the server is ready
    if not check_vllm_server(health_check_url):
        raise RuntimeError(f"vLLM server at {args.service_url} is not ready!")

    # Select the prompts to use
    if args.use_mm_prompts:
        test_prompts = SAMPLE_PROMPTS_MM
        print("Using multimodal prompts")
    else:
        test_prompts = SAMPLE_PROMPTS_TEXT
        print("Using text-only prompts for quick testing")

    # Run completions
    service_url = f"{args.service_url}/v1"
    output_strs = {}

    for i, prompt_data in enumerate(test_prompts):
        print(
            f"\nRunning prompt {i + 1}/{len(test_prompts)}: "
            f"{prompt_data['description']}"
        )

        output_str = run_chat_completion(
            base_url=service_url,
            model_name=args.model_name,
            messages=prompt_data["messages"],
            max_tokens=MAX_OUTPUT_LEN,
        )

        # Use the description as the key for comparison
        key = prompt_data["description"]
        output_strs[key] = output_str
        print(f"Output: {output_str}")

    if args.mode in ("baseline", "baseline_pd"):
        # Baseline mode: save outputs
        print(f"\nSaving baseline outputs to {args.baseline_file}")
        try:
            with open(args.baseline_file, "w") as json_file:
                json.dump(output_strs, json_file, indent=4)
            print("✅ Baseline outputs saved successfully")
        except OSError as e:
            print(f"Error writing to file: {e}")
            raise
    else:
        # Disagg mode: load and compare outputs
        print(f"\nLoading baseline outputs from {args.baseline_file}")
        baseline_outputs = None
        try:
            with open(args.baseline_file) as json_file:
                baseline_outputs = json.load(json_file)
        except OSError as e:
            print(f"Error reading from file: {e}")
            raise

        # Verify the outputs match
        print("\nComparing disagg outputs with baseline...")
        assert isinstance(baseline_outputs, dict), "Baseline outputs should be a dict"
        assert len(baseline_outputs) == len(output_strs), (
            f"Length mismatch: baseline has {len(baseline_outputs)}, "
            f"disagg has {len(output_strs)}"
        )

        all_match = True
        for key, baseline_output in baseline_outputs.items():
            assert key in output_strs, f"{key} not in disagg outputs"

            disagg_output = output_strs[key]
            if baseline_output == disagg_output:
                print(f"✅ {key}: MATCH")
            else:
                print(f"❌ {key}: MISMATCH")
                print(f"  Baseline: {baseline_output}")
                print(f"  Disagg: {disagg_output}")
                all_match = False

        assert all_match, "❌❌ Disagg outputs do not match baseline! ❌❌"
        print("\n✅ All outputs match! Test PASSED")


if __name__ == "__main__":
    main()

609
tests/v1/ec_connector/unit/test_ec_example_connector.py
Normal file
@@ -0,0 +1,609 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Unit tests for ECExampleConnector.
"""

import os
from unittest.mock import Mock, patch

import pytest
import safetensors
import torch

from vllm.config import VllmConfig
from vllm.distributed.ec_transfer.ec_connector.base import ECConnectorRole
from vllm.distributed.ec_transfer.ec_connector.example_connector import (
    ECExampleConnector,
    ECExampleConnectorMetadata,
    MMMeta,
)
from vllm.multimodal.inputs import MultiModalFeatureSpec, PlaceholderRange
from vllm.v1.core.sched.output import SchedulerOutput


# ------------------ Mock Classes ------------------ #
class MockRequest:
    def __init__(self, request_id, mm_hashes: list[str], token_counts: list[int]):
        assert len(mm_hashes) == len(token_counts)
        self.request_id = request_id
        self._token_counts = token_counts
        self.mm_features = []
        for i, mm_hash in enumerate(mm_hashes):
            feature = MultiModalFeatureSpec(
                data=None,
                modality="image",
                identifier=mm_hash,
                mm_position=PlaceholderRange(offset=0, length=self._token_counts[i]),
            )
            self.mm_features.append(feature)

    def get_num_encoder_embeds(self, input_id: int) -> int:
        assert input_id < len(self._token_counts)
        return self._token_counts[input_id]
@pytest.fixture
def temp_storage(tmp_path):
    """Fixture providing temporary storage path."""
    return str(tmp_path)


@pytest.fixture
def mock_vllm_config_producer(temp_storage):
    """Fixture providing mock VllmConfig for producer role."""
    config = Mock(spec=VllmConfig)
    config.ec_transfer_config = Mock()
    config.ec_transfer_config.get_from_extra_config = Mock(return_value=temp_storage)
    config.ec_transfer_config.is_ec_producer = True
    return config


@pytest.fixture
def mock_vllm_config_consumer(temp_storage):
    """Fixture providing mock VllmConfig for consumer role."""
    config = Mock(spec=VllmConfig)
    config.ec_transfer_config = Mock()
    config.ec_transfer_config.get_from_extra_config = Mock(return_value=temp_storage)
    config.ec_transfer_config.is_ec_producer = False
    return config


@pytest.fixture
def mock_request_with_3_mm():
    """Fixture providing mock Request with 3 multimodal items."""
    request_id = "test_req_123"
    mm_hashes = ["img_hash_1", "img_hash_2", "img_hash_3"]
    token_counts = [100, 150, 200]

    request = MockRequest(request_id, mm_hashes, token_counts)
    return request
# ------------------ Unit Tests ------------------ #
class TestECExampleConnectorBasics:
    """Test basic EC connector functionality."""

    def test_initialization_producer(self, mock_vllm_config_producer, temp_storage):
        """Test connector initializes correctly as producer."""
        connector = ECExampleConnector(
            vllm_config=mock_vllm_config_producer,
            role=ECConnectorRole.SCHEDULER,
        )

        assert connector.role == ECConnectorRole.SCHEDULER
        assert connector.is_producer
        assert connector._storage_path == temp_storage
        assert connector._mm_datas_need_loads == {}

    def test_initialization_consumer(self, mock_vllm_config_consumer, temp_storage):
        """Test connector initializes correctly as consumer."""
        connector = ECExampleConnector(
            vllm_config=mock_vllm_config_consumer,
            role=ECConnectorRole.WORKER,
        )

        assert connector.role == ECConnectorRole.WORKER
        assert not connector.is_producer
        assert connector._storage_path == temp_storage

    def test_role_assignment(self, mock_vllm_config_producer):
        """Test role is correctly assigned."""
        scheduler_connector = ECExampleConnector(
            vllm_config=mock_vllm_config_producer,
            role=ECConnectorRole.SCHEDULER,
        )
        worker_connector = ECExampleConnector(
            vllm_config=mock_vllm_config_producer,
            role=ECConnectorRole.WORKER,
        )

        assert scheduler_connector.role == ECConnectorRole.SCHEDULER
        assert worker_connector.role == ECConnectorRole.WORKER
class TestCacheExistence:
    """Test cache existence checking using has_caches() API."""

    def test_has_caches_all_exist_3_items(
        self,
        mock_vllm_config_producer,
        mock_vllm_config_consumer,
        mock_request_with_3_mm,
    ):
        """Test has_caches returns True when all 3 caches exist."""
        # Test for producer first
        producer = ECExampleConnector(
            vllm_config=mock_vllm_config_producer,
            role=ECConnectorRole.SCHEDULER,
        )

        # Create cache files using save_caches (proper way)
        encoder_cache: dict[str, torch.Tensor] = {}

        for mm_feature in mock_request_with_3_mm.mm_features:
            mm_hash = mm_feature.identifier
            encoder_cache[mm_hash] = torch.randn(10, 768)
            producer.save_caches(encoder_cache, mm_hash)

        # Test using has_caches API
        producer_result = producer.has_caches(mock_request_with_3_mm)

        # Assert
        assert len(producer_result) == 3
        assert all(producer_result), f"Expected all True, got {producer_result}"

        # Also test consumer can check if cache exists
        consumer = ECExampleConnector(
            vllm_config=mock_vllm_config_consumer,
            role=ECConnectorRole.SCHEDULER,
        )

        # Test using has_caches API
        consumer_result = consumer.has_caches(mock_request_with_3_mm)

        # Assert
        assert len(consumer_result) == 3
        assert all(consumer_result), f"Expected all True, got {consumer_result}"

    def test_has_caches_none_exist(
        self, mock_vllm_config_producer, mock_request_with_3_mm
    ):
        """Test has_caches returns False when no caches exist."""
        connector = ECExampleConnector(
            vllm_config=mock_vllm_config_producer,
            role=ECConnectorRole.SCHEDULER,
        )

        # Test without creating any files
        result = connector.has_caches(mock_request_with_3_mm)

        # Assert
        assert len(result) == 3
        assert not any(result), f"Expected all False, got {result}"

    def test_has_caches_partial_exist(
        self, mock_vllm_config_producer, mock_request_with_3_mm
    ):
        """Test has_caches with some caches existing (1 of 3)."""
        connector = ECExampleConnector(
            vllm_config=mock_vllm_config_producer,
            role=ECConnectorRole.SCHEDULER,
        )

        # Create only the second cache file
        mm_hash_second = mock_request_with_3_mm.mm_features[1].identifier
        encoder_cache = {mm_hash_second: torch.randn(10, 768)}
        connector.save_caches(encoder_cache, mm_hash_second)

        # Test
        result = connector.has_caches(mock_request_with_3_mm)

        # Assert
        assert len(result) == 3
        assert not result[0]  # First doesn't exist
        assert result[1]  # Second exists
        assert not result[2]  # Third doesn't exist
class TestStateManagement:
    """Test connector state management."""

    def test_update_state_after_alloc_3_items(
        self, mock_vllm_config_producer, mock_request_with_3_mm
    ):
        """Test state update after allocation for 3 MM items."""
        connector = ECExampleConnector(
            vllm_config=mock_vllm_config_producer,
            role=ECConnectorRole.SCHEDULER,
        )

        # Initial state should be empty
        assert len(connector._mm_datas_need_loads) == 0

        # Update state for all 3 items
        for i in range(3):
            connector.update_state_after_alloc(mock_request_with_3_mm, index=i)

        # Check state updated for all 3
        assert len(connector._mm_datas_need_loads) == 3
        assert "img_hash_1" in connector._mm_datas_need_loads
        assert "img_hash_2" in connector._mm_datas_need_loads
        assert "img_hash_3" in connector._mm_datas_need_loads
        assert connector._mm_datas_need_loads["img_hash_1"] == 100
        assert connector._mm_datas_need_loads["img_hash_2"] == 150
        assert connector._mm_datas_need_loads["img_hash_3"] == 200

    def test_build_connector_meta_3_items(
        self, mock_vllm_config_producer, mock_request_with_3_mm
    ):
        """Test metadata building for 3 MM items."""
        connector = ECExampleConnector(
            vllm_config=mock_vllm_config_producer,
            role=ECConnectorRole.SCHEDULER,
        )

        # Setup state for all 3 items
        for i in range(3):
            connector.update_state_after_alloc(mock_request_with_3_mm, index=i)

        # Build metadata
        scheduler_output = Mock(spec=SchedulerOutput)
        metadata = connector.build_connector_meta(scheduler_output)

        # Assert
        assert isinstance(metadata, ECExampleConnectorMetadata)
        assert len(metadata.mm_datas) == 3
        assert metadata.mm_datas[0].mm_hash == "img_hash_1"
        assert metadata.mm_datas[0].num_token == 100
        assert metadata.mm_datas[1].mm_hash == "img_hash_2"
        assert metadata.mm_datas[1].num_token == 150
        assert metadata.mm_datas[2].mm_hash == "img_hash_3"
        assert metadata.mm_datas[2].num_token == 200

        # State should be cleared after building
        assert len(connector._mm_datas_need_loads) == 0

    def test_build_connector_meta_empty(self, mock_vllm_config_producer):
        """Test metadata building with empty state."""
        connector = ECExampleConnector(
            vllm_config=mock_vllm_config_producer,
            role=ECConnectorRole.SCHEDULER,
        )

        scheduler_output = Mock(spec=SchedulerOutput)
        metadata = connector.build_connector_meta(scheduler_output)

        assert isinstance(metadata, ECExampleConnectorMetadata)
        assert len(metadata.mm_datas) == 0

    def test_state_cleared_after_metadata_build(
        self, mock_vllm_config_producer, mock_request_with_3_mm
    ):
        """Test that state is properly cleared after building metadata."""
        connector = ECExampleConnector(
            vllm_config=mock_vllm_config_producer,
            role=ECConnectorRole.SCHEDULER,
        )

        # Add state
        for i in range(3):
            connector.update_state_after_alloc(mock_request_with_3_mm, index=i)
        assert len(connector._mm_datas_need_loads) == 3

        # Build metadata (should clear state)
        scheduler_output = Mock(spec=SchedulerOutput)
        connector.build_connector_meta(scheduler_output)

        # State should be empty
        assert len(connector._mm_datas_need_loads) == 0

        # Build again should return empty metadata
        metadata2 = connector.build_connector_meta(scheduler_output)
        assert len(metadata2.mm_datas) == 0
class TestCacheSaving:
    """Test encoder cache saving (producer only)."""

    def test_save_caches_producer_3_items(
        self, mock_vllm_config_producer, mock_request_with_3_mm, temp_storage
    ):
        """Test cache saving as producer for 3 different MM items."""
        connector = ECExampleConnector(
            vllm_config=mock_vllm_config_producer,
            role=ECConnectorRole.WORKER,
        )

        # Create and save 3 different caches
        mm_hashes = [f.identifier for f in mock_request_with_3_mm.mm_features]
        encoder_cache: dict[str, torch.Tensor] = {}

        for mm_hash in mm_hashes:
            encoder_cache[mm_hash] = torch.randn(10, 768)
            connector.save_caches(encoder_cache, mm_hash)

        # Verify all files exist using has_caches
        result = connector.has_caches(mock_request_with_3_mm)
        assert all(result), f"Not all caches were saved: {result}"

        # Verify each file's content
        for mm_hash in mm_hashes:
            filename = connector._generate_filename_debug(mm_hash)
            loaded = safetensors.torch.load_file(filename)
            assert "ec_cache" in loaded
            assert torch.allclose(loaded["ec_cache"], encoder_cache[mm_hash].cpu())

    def test_save_caches_consumer_skips(self, mock_vllm_config_consumer):
        """Test cache saving is skipped for consumer."""
        connector = ECExampleConnector(
            vllm_config=mock_vllm_config_consumer,
            role=ECConnectorRole.WORKER,
        )

        mm_hash = "test_hash_consumer"
        encoder_cache = {mm_hash: torch.randn(10, 768)}

        # Save should not raise but also not create file
        connector.save_caches(encoder_cache, mm_hash)

        # Verify file doesn't exist using has_caches
        mock_request = MockRequest("req_consumer", [mm_hash], [10])
        result = connector.has_caches(mock_request)
        assert not result[0], "Consumer should not save caches"
class TestCacheLoading:
    """Test encoder cache loading (consumer)."""

    @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
    def test_start_load_caches_consumer_3_items(
        self,
        mock_vllm_config_producer,
        mock_vllm_config_consumer,
        mock_request_with_3_mm,
        temp_storage,
    ):
        """Test consumer loads 3 caches from storage."""
        # First, create producer to save caches
        producer = ECExampleConnector(
            vllm_config=mock_vllm_config_producer,
            role=ECConnectorRole.WORKER,
        )

        # Producer saves 3 caches
        mm_hashes = [f.identifier for f in mock_request_with_3_mm.mm_features]
        saved_caches = {}
        for mm_hash in mm_hashes:
            saved_caches[mm_hash] = torch.randn(10, 768)
            producer.save_caches(saved_caches, mm_hash)

        # Now consumer loads
        consumer = ECExampleConnector(
            vllm_config=mock_vllm_config_consumer,
            role=ECConnectorRole.WORKER,
        )

        # Setup metadata for all 3
        metadata = ECExampleConnectorMetadata()
        for mm_hash in mm_hashes:
            metadata.add_mm_data(MMMeta.make_meta(mm_hash, 100))
        consumer.bind_connector_metadata(metadata)

        # Load
        encoder_cache: dict[str, torch.Tensor] = {}
        consumer.start_load_caches(encoder_cache=encoder_cache)

        # Verify all 3 loaded
        assert len(encoder_cache) == 3
        for mm_hash in mm_hashes:
            assert mm_hash in encoder_cache, f"{mm_hash} missing in encoder_cache"
            assert encoder_cache[mm_hash].is_cuda, (
                f"{mm_hash} cache is in {encoder_cache[mm_hash].device}"
            )
            assert torch.allclose(
                encoder_cache[mm_hash].cpu(), saved_caches[mm_hash]
            ), f"{mm_hash} saved and loaded cache tensors are not the same"

    def test_start_load_caches_skip_existing(
        self, mock_vllm_config_producer, mock_vllm_config_consumer, temp_storage
    ):
        """Test cache loading skips already cached items."""
        # Setup: producer saves cache
        producer = ECExampleConnector(
            vllm_config=mock_vllm_config_producer,
            role=ECConnectorRole.WORKER,
        )

        mm_hash = "existing_hash"
        saved_cache = torch.randn(10, 768)
        producer.save_caches({mm_hash: saved_cache}, mm_hash)

        # Consumer setup
        consumer = ECExampleConnector(
            vllm_config=mock_vllm_config_consumer,
            role=ECConnectorRole.WORKER,
        )

        metadata = ECExampleConnectorMetadata()
        metadata.add_mm_data(MMMeta.make_meta(mm_hash, 100))
        consumer.bind_connector_metadata(metadata)

        # Pre-populate encoder_cache with different value
        existing_cache = torch.randn(5, 512)
        encoder_cache = {mm_hash: existing_cache}

        # Load (should skip since already exists)
        with patch("safetensors.torch.load_file") as mock_load:
            consumer.start_load_caches(encoder_cache=encoder_cache)
            # Should not call load_file since cache exists
            mock_load.assert_not_called()

        # Verify original cache unchanged
        assert torch.equal(encoder_cache[mm_hash], existing_cache)

    def test_start_load_caches_empty_metadata(self, mock_vllm_config_consumer):
        """Test loading with empty metadata does nothing."""
        consumer = ECExampleConnector(
            vllm_config=mock_vllm_config_consumer,
            role=ECConnectorRole.WORKER,
        )

        # Setup empty metadata
        metadata = ECExampleConnectorMetadata()
        consumer.bind_connector_metadata(metadata)

        # Load (should not raise)
        encoder_cache: dict[str, torch.Tensor] = {}
        consumer.start_load_caches(encoder_cache=encoder_cache)

        # Cache should remain empty
        assert len(encoder_cache) == 0
class TestFilenameGeneration:
    """Test filename and path generation."""

    def test_generate_foldername(self, mock_vllm_config_producer, temp_storage):
        """Test folder name generation."""
        connector = ECExampleConnector(
            vllm_config=mock_vllm_config_producer,
            role=ECConnectorRole.WORKER,
        )

        mm_hash = "test_folder_hash"
        folder = connector._generate_foldername_debug(mm_hash)

        assert folder == os.path.join(temp_storage, mm_hash)
        assert os.path.isdir(folder)  # Should be created

    def test_generate_filename(self, mock_vllm_config_producer, temp_storage):
        """Test filename generation."""
        connector = ECExampleConnector(
            vllm_config=mock_vllm_config_producer,
            role=ECConnectorRole.WORKER,
        )

        mm_hash = "test_file_hash"
        filename = connector._generate_filename_debug(mm_hash)

        expected = os.path.join(temp_storage, mm_hash, "encoder_cache.safetensors")
        assert filename == expected
        assert os.path.isdir(os.path.dirname(filename))  # Folder created

    def test_generate_filename_consistency(self, mock_vllm_config_producer):
        """Test filename generation is consistent."""
        connector = ECExampleConnector(
            vllm_config=mock_vllm_config_producer,
            role=ECConnectorRole.WORKER,
        )

        mm_hash = "consistency_hash"
        filename1 = connector._generate_filename_debug(mm_hash)
        filename2 = connector._generate_filename_debug(mm_hash)

        assert filename1 == filename2
class TestMetadataBindingLifecycle:
    """Test metadata binding and clearing lifecycle."""

    def test_bind_connector_metadata(self, mock_vllm_config_consumer):
        """Test binding connector metadata."""
        connector = ECExampleConnector(
            vllm_config=mock_vllm_config_consumer,
            role=ECConnectorRole.WORKER,
        )

        metadata = ECExampleConnectorMetadata()
        metadata.add_mm_data(MMMeta.make_meta("hash_1", 100))

        connector.bind_connector_metadata(metadata)

        assert connector._connector_metadata is metadata

    def test_clear_connector_metadata(self, mock_vllm_config_consumer):
        """Test clearing connector metadata."""
        connector = ECExampleConnector(
            vllm_config=mock_vllm_config_consumer,
            role=ECConnectorRole.WORKER,
        )

        metadata = ECExampleConnectorMetadata()
        connector.bind_connector_metadata(metadata)

        connector.clear_connector_metadata()

        assert connector._connector_metadata is None

    def test_get_connector_metadata(self, mock_vllm_config_consumer):
        """Test getting connector metadata."""
        connector = ECExampleConnector(
            vllm_config=mock_vllm_config_consumer,
            role=ECConnectorRole.WORKER,
        )

        metadata = ECExampleConnectorMetadata()
        connector.bind_connector_metadata(metadata)

        retrieved = connector._get_connector_metadata()

        assert retrieved is metadata

    def test_get_connector_metadata_not_set(self, mock_vllm_config_consumer):
        """Test getting metadata when not set raises."""
        connector = ECExampleConnector(
            vllm_config=mock_vllm_config_consumer,
            role=ECConnectorRole.WORKER,
        )

        with pytest.raises(AssertionError):
            connector._get_connector_metadata()
class TestEdgeCases:
    """Test edge cases and error handling."""

    def test_save_empty_cache(self, mock_vllm_config_producer):
        """Test saving empty tensor."""
        connector = ECExampleConnector(
            vllm_config=mock_vllm_config_producer,
            role=ECConnectorRole.WORKER,
        )

        mm_hash = "empty_hash"
        encoder_cache = {mm_hash: torch.empty(0)}

        # Should not raise
        connector.save_caches(encoder_cache, mm_hash)

    def test_load_nonexistent_cache(self, mock_vllm_config_consumer):
        """Test loading cache that doesn't exist raises error."""
        connector = ECExampleConnector(
            vllm_config=mock_vllm_config_consumer,
            role=ECConnectorRole.WORKER,
        )

        metadata = ECExampleConnectorMetadata()
        metadata.add_mm_data(MMMeta.make_meta("nonexistent_hash", 100))
        connector.bind_connector_metadata(metadata)

        encoder_cache: dict[str, torch.Tensor] = {}

        # Should raise FileNotFoundError
        with pytest.raises(FileNotFoundError):
            connector.start_load_caches(encoder_cache=encoder_cache)

    def test_has_caches_empty_request(self, mock_vllm_config_producer):
        """Test has_caches with request that has no MM data."""
        connector = ECExampleConnector(
            vllm_config=mock_vllm_config_producer,
            role=ECConnectorRole.SCHEDULER,
        )

        mock_request = MockRequest("req_empty", [], [])

        result = connector.has_caches(mock_request)

        assert len(result) == 0
        assert result == []