[CI] Add CI Testing for Prefill-Decode Disaggregation with Router (#7540)
This commit is contained in:
249
.github/workflows/pr-test-pd-router.yml
vendored
Normal file
249
.github/workflows/pr-test-pd-router.yml
vendored
Normal file
@@ -0,0 +1,249 @@
|
||||
# CI workflow: end-to-end test of prefill-decode (PD) disaggregation behind
# the sgl-router. It validates the runner, installs SGLang and the router,
# starts the servers via scripts/ci_start_disaggregation_servers.sh, then runs
# API smoke tests and a small benchmark with hard performance thresholds.
name: Test Disaggregation Mode

on:
  push:
    branches: [ main ]
    paths:
      - 'python/sglang/srt/disaggregation/**'
      - 'scripts/ci_start_disaggregation_servers.sh'
      - 'sgl-router/**'
      # Changes to this workflow must exercise the workflow itself.
      - '.github/workflows/pr-test-pd-router.yml'
  pull_request:
    branches: [ main ]
    paths:
      - 'python/sglang/srt/disaggregation/**'
      - 'scripts/ci_start_disaggregation_servers.sh'
      - 'sgl-router/**'
      - '.github/workflows/pr-test-pd-router.yml'
  workflow_dispatch:

# Only one run per ref; a newer run cancels the one in progress.
concurrency:
  group: test-disaggregation-${{ github.ref }}
  cancel-in-progress: true

# NOTE(review): no step in this workflow appears to write to pull requests or
# issues; consider dropping the two write scopes if nothing else relies on them.
permissions:
  contents: read
  pull-requests: write
  issues: write

jobs:
  test-disaggregation:
    if: github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request'
    # The runner must provide 8 GPUs, an active InfiniBand device and the model
    # under /raid/models (all checked in "Validate environment").
    runs-on: [h200]
    timeout-minutes: 45

    steps:
      - name: Checkout code
        uses: actions/checkout@v4
        with:
          fetch-depth: 10

      - name: Setup Python
        uses: actions/setup-python@v4
        with:
          python-version: '3.11'

      - name: Setup Rust
        run: |
          bash scripts/ci_install_rust.sh

      - name: Cache Rust dependencies
        uses: actions/cache@v4
        with:
          path: |
            ~/.cargo/bin/
            ~/.cargo/registry/index/
            ~/.cargo/registry/cache/
            ~/.cargo/git/db/
            sgl-router/target/
          key: ${{ runner.os }}-cargo-${{ hashFiles('sgl-router/Cargo.lock') }}
          restore-keys: |
            ${{ runner.os }}-cargo-

      - name: Cache pip dependencies
        uses: actions/cache@v4
        with:
          path: ~/.cache/pip
          key: ${{ runner.os }}-pip-${{ hashFiles('python/pyproject.toml') }}
          restore-keys: |
            ${{ runner.os }}-pip-

      - name: Validate environment
        run: |
          echo "=== System Validation ==="
          nvidia-smi
          # Query the GPU list once; the count feeds both the log line and the check.
          gpu_count=$(nvidia-smi -L | wc -l)
          echo "GPU count: ${gpu_count}"
          if [ "${gpu_count}" -lt 8 ]; then
            echo "Error: This test requires at least 8 GPUs"
            exit 1
          fi

          echo "=== RDMA Validation ==="
          if ! command -v ibv_devices >/dev/null 2>&1; then
            echo "Error: InfiniBand tools not found"
            exit 1
          fi

          # Check for active IB devices
          found_active_device=false
          for device in mlx5_{0..11}; do
            if ibv_devinfo "$device" >/dev/null 2>&1; then
              state=$(ibv_devinfo "$device" | grep "state:" | head -1 | awk '{print $2}')
              if [[ "$state" == "PORT_ACTIVE" ]]; then
                echo "✓ Found active device: $device"
                found_active_device=true
                break
              fi
            fi
          done

          if [ "$found_active_device" = false ]; then
            echo "Error: No active IB devices found"
            echo "Available devices:"
            ibv_devices || true
            exit 1
          fi

          echo "=== Model Validation ==="
          if [ ! -d "/raid/models/meta-llama/Llama-3.1-8B-Instruct" ]; then
            echo "Error: Model not found"
            ls -la /raid/models/ || echo "No models directory"
            exit 1
          fi
          echo "✓ Model found"

      - name: Install SGLang dependencies
        run: |
          echo "Installing SGLang with all extras..."
          python3 -m pip --no-cache-dir install -e "python[all]" --break-system-packages
          python3 -m pip --no-cache-dir install mooncake-transfer-engine==0.3.4.post1

      - name: Build and install sgl-router
        run: |
          source "$HOME/.cargo/env"
          echo "Building sgl-router..."
          cd sgl-router
          cargo build && python3 -m build && pip install --force-reinstall dist/*.whl

      - name: Start disaggregation servers
        id: start_servers
        run: |
          echo "Starting disaggregation servers..."
          bash scripts/ci_start_disaggregation_servers.sh &
          SERVER_PID=$!
          # Published as a step output so the cleanup step can kill the launcher.
          echo "server_pid=$SERVER_PID" >> "$GITHUB_OUTPUT"

          echo "Waiting for router to become healthy..."
          TIMEOUT=300
          ELAPSED=0
          while [ $ELAPSED -lt $TIMEOUT ]; do
            if curl --connect-timeout 5 --silent http://127.0.0.9:8000 > /dev/null 2>&1; then
              echo "✓ Router is reachable"
              break
            fi
            # Fail fast if the launcher script itself has already exited.
            if ! ps -p "$SERVER_PID" > /dev/null; then
              echo "Error: Server processes failed to start"
              exit 1
            fi
            echo "Waiting for router... (${ELAPSED}s/${TIMEOUT}s)"
            sleep 10
            ELAPSED=$((ELAPSED + 10))
          done

          # The loop only reaches TIMEOUT when it never hit the break above.
          if [ $ELAPSED -ge $TIMEOUT ]; then
            echo "Error: Router health check timeout after ${TIMEOUT}s"
            exit 1
          fi

          echo "✓ Servers started and healthy (PID: $SERVER_PID)"

      - name: Test API functionality
        timeout-minutes: 5
        run: |
          BASE_URL="http://127.0.0.9:8000"

          echo "Testing API completions..."
          response=$(curl -s -X POST "$BASE_URL/v1/chat/completions" \
            -H "Content-Type: application/json" \
            -H "Authorization: Bearer test-token" \
            -d '{
              "model": "/raid/models/meta-llama/Llama-3.1-8B-Instruct",
              "messages": [
                {"role": "user", "content": "Write a Python function to calculate fibonacci numbers recursively"}
              ],
              "stream": false,
              "max_tokens": 100
            }')

          # Pass only if the response carries a non-null message content.
          if echo "$response" | jq -e '.choices[0].message.content' > /dev/null 2>&1; then
            echo "✓ API test passed"
          else
            echo "✗ API test failed: $response"
            exit 1
          fi

          echo "Testing streaming API..."
          stream_response=$(timeout 30 curl -s -X POST "$BASE_URL/v1/chat/completions" \
            -H "Content-Type: application/json" \
            -H "Authorization: Bearer test-token" \
            -d '{
              "model": "/raid/models/meta-llama/Llama-3.1-8B-Instruct",
              "messages": [
                {"role": "user", "content": "Count from 1 to 5"}
              ],
              "stream": true,
              "max_tokens": 50
            }')

          # A streamed reply is recognised by the presence of "data:" lines.
          if echo "$stream_response" | grep -q "data:"; then
            echo "✓ Streaming API test passed"
          else
            echo "✗ Streaming API test failed"
            exit 1
          fi

      - name: Run benchmark test
        timeout-minutes: 5
        run: |
          echo "Running benchmark test..."
          benchmark_output=$(python3 -m sglang.bench_one_batch_server \
            --model-path "/raid/models/meta-llama/Llama-3.1-8B-Instruct" \
            --base-url "http://127.0.0.9:8000" \
            --batch-size 8 \
            --input-len 4096 \
            --output-len 5 \
            --skip-warmup)

          echo "$benchmark_output"

          # Extract metrics from output
          latency=$(echo "$benchmark_output" | grep "latency:" | awk '{print $2}' | sed 's/s//')
          input_throughput=$(echo "$benchmark_output" | grep "input throughput:" | awk '{print $3}')
          output_throughput=$(echo "$benchmark_output" | grep "output throughput:" | awk '{print $3}')

          # Validate performance (latency<1.5s, input>20k, output>1k)
          command -v bc >/dev/null || (apt-get update && apt-get install -y bc)

          echo "Performance: ${latency}s | ${input_throughput} | ${output_throughput} tok/s"

          # Refuse to pass on missing or malformed metrics. Without this, bc prints
          # nothing for a broken expression, (( )) evaluates to false, and every
          # threshold check below is silently skipped.
          for metric in latency input_throughput output_throughput; do
            if [[ ! "${!metric}" =~ ^[0-9]+(\.[0-9]+)?$ ]]; then
              echo "✗ Benchmark failed: could not parse ${metric} from benchmark output"
              exit 1
            fi
          done

          fail=""
          (( $(echo "$latency > 1.5" | bc -l) )) && fail="Latency too high (${latency}s>1.5s) "
          (( $(echo "$input_throughput < 20000" | bc -l) )) && fail="${fail}Input too low (${input_throughput}<20k) "
          (( $(echo "$output_throughput < 1000" | bc -l) )) && fail="${fail}Output too low (${output_throughput}<1k) "

          if [ -n "$fail" ]; then
            echo "✗ Benchmark failed: $fail"
            exit 1
          else
            echo "✓ Performance validation passed"
          fi

      - name: Cleanup servers
        if: always()
        run: |
          # Kill the launcher's children first, then the launcher itself.
          if [ -n "${{ steps.start_servers.outputs.server_pid }}" ]; then
            pkill -P ${{ steps.start_servers.outputs.server_pid }} || true
            kill ${{ steps.start_servers.outputs.server_pid }} || true
          fi
          pkill -f "sglang.launch_server" || true
          sleep 5
          # pgrep never counts itself (unlike `ps aux | grep -c`). With -c it prints
          # the count and exits 1 when the count is zero, hence `|| true`.
          remaining=$(pgrep -fc "sglang.launch_server" || true)
          echo "Cleanup completed. Remaining processes: $remaining"
|
||||
@@ -18,6 +18,7 @@ dependencies = ["aiohttp", "requests", "tqdm", "numpy", "IPython", "setproctitle
|
||||
[project.optional-dependencies]
|
||||
runtime_common = [
|
||||
"blobfile==3.0.0",
|
||||
"build",
|
||||
"compressed-tensors",
|
||||
"datasets",
|
||||
"fastapi",
|
||||
|
||||
106
scripts/ci_start_disaggregation_servers.sh
Executable file
106
scripts/ci_start_disaggregation_servers.sh
Executable file
@@ -0,0 +1,106 @@
|
||||
#!/bin/bash
#
# CI helper: launches prefill servers, decode servers and the sgl-router for
# the prefill-decode disaggregation test.

# Model directory passed to every server via --model-path. The value below is
# only a default; export MODEL_PATH before running the script to override it.
MODEL_PATH="${MODEL_PATH:-/raid/models/meta-llama/Llama-3.1-8B-Instruct}"
|
||||
|
||||
#######################################
# Find the first available active IB device.
# Probes mlx5_0 .. mlx5_11 in order with ibv_devinfo and selects the first
# device whose first reported "state:" line reads PORT_ACTIVE.
# Arguments: none
# Outputs:   device name on stdout; an error message on stderr if none is found
# Returns:   0 if an active device was found, 1 otherwise
#######################################
find_active_ib_device() {
  # Keep the scratch variables out of the caller's (global) scope.
  local device info state
  for device in mlx5_{0..11}; do
    # One query per device: a device that ibv_devinfo cannot report on is skipped.
    info=$(ibv_devinfo "$device" 2>/dev/null) || continue
    # Only the first matching line counts; later lines (e.g. phys_state) are ignored.
    state=$(awk '/state:/ { print $2; exit }' <<<"$info")
    if [[ "$state" == "PORT_ACTIVE" ]]; then
      echo "$device"
      return 0
    fi
  done
  echo "No active IB device found" >&2
  return 1
}
|
||||
|
||||
# Get the first available active IB device. Every server launched below is
# given this device via --disaggregation-ib-device, so stop here instead of
# starting them with an empty device name.
if ! DEVICE=$(find_active_ib_device); then
  echo "Error: no active IB device available; not starting servers" >&2
  exit 1
fi
echo "Using IB device: $DEVICE"
|
||||
|
||||
# Launch prefill servers on GPU 0–3.
# GPU i listens on 127.0.0.(i+1):(30001+i) and uses bootstrap port 9001+i.
# Each server is started in the background and pinned to its GPU through
# CUDA_VISIBLE_DEVICES.
for i in 0 1 2 3; do
  HOST="127.0.0.$((i + 1))"
  PORT=$((30001 + i))
  BOOTSTRAP_PORT=$((9001 + i))
  echo "Launching PREFILL server on GPU $i at $HOST:$PORT (bootstrap: $BOOTSTRAP_PORT)"

  prefill_args=(
    --model-path "$MODEL_PATH"
    --disaggregation-mode prefill
    --host "$HOST"
    --port "$PORT"
    --disaggregation-ib-device "$DEVICE"
    --disaggregation-bootstrap-port "$BOOTSTRAP_PORT"
  )
  CUDA_VISIBLE_DEVICES=$i python3 -m sglang.launch_server "${prefill_args[@]}" &
done
|
||||
|
||||
# Launch decode servers on GPU 4–7.
# GPU i listens on 127.0.0.(i+1):(30001+i). Each server is started in the
# background and sees only its own GPU through CUDA_VISIBLE_DEVICES; it is
# launched with --base-gpu-id 0.
for i in 4 5 6 7; do
  HOST="127.0.0.$((i + 1))"
  PORT=$((30001 + i))
  echo "Launching DECODE server on GPU $i at $HOST:$PORT"

  decode_args=(
    --model-path "$MODEL_PATH"
    --disaggregation-mode decode
    --host "$HOST"
    --port "$PORT"
    --disaggregation-ib-device "$DEVICE"
    --base-gpu-id 0
  )
  CUDA_VISIBLE_DEVICES=$i python3 -m sglang.launch_server "${decode_args[@]}" &
done
|
||||
|
||||
# Wait for disaggregation servers to initialize
echo "Waiting for disaggregation servers to initialize..."

# Health check with 5-minute timeout: poll /health on all 8 servers every
# 10 seconds until every one answers successfully or the deadline passes.
TIMEOUT=300
START_TIME=$(date +%s)

echo "Checking health of all 8 servers..."
while true; do
  CURRENT_TIME=$(date +%s)
  ELAPSED=$((CURRENT_TIME - START_TIME))

  if [ "$ELAPSED" -ge "$TIMEOUT" ]; then
    echo "❌ Timeout: Servers did not become healthy within 5 minutes" >&2
    exit 1
  fi

  HEALTHY_COUNT=0
  # Check all 8 servers (127.0.0.1-8:30001-30008)
  for i in {1..8}; do
    # Bound every probe: without a timeout, one connection that never answers
    # would block this loop and the deadline above would never be evaluated.
    if curl -s -f --connect-timeout 2 --max-time 5 \
      "http://127.0.0.$i:$((30000 + i))/health" >/dev/null 2>&1; then
      HEALTHY_COUNT=$((HEALTHY_COUNT + 1))
    fi
  done

  echo "Healthy servers: $HEALTHY_COUNT/8 (elapsed: ${ELAPSED}s)"

  if [ "$HEALTHY_COUNT" -eq 8 ]; then
    echo "✅ All 8 servers are healthy!"
    break
  else
    sleep 10 # Wait 10 seconds before next check
  fi
done
|
||||
|
||||
# Launch the router in PD-disaggregation mode with the power_of_two policy.
# Each --prefill entry is a server URL followed by a port number (the values
# match the prefill servers' bootstrap ports); each --decode entry is a URL.
echo "Launching router at 127.0.0.9:8000..."
router_args=(
  --pd-disaggregation
  --policy power_of_two
  --prefill http://127.0.0.1:30001 9001
  --prefill http://127.0.0.2:30002 9002
  --prefill http://127.0.0.3:30003 9003
  --prefill http://127.0.0.4:30004 9004
  --decode http://127.0.0.5:30005
  --decode http://127.0.0.6:30006
  --decode http://127.0.0.7:30007
  --decode http://127.0.0.8:30008
  --host 127.0.0.9
  --port 8000
)
python3 -m sglang_router.launch_router "${router_args[@]}" &

# Block until every background job started by this script has exited.
wait
|
||||
Reference in New Issue
Block a user